递归的 lambda 函数

发布于 2022-10-30  180 次阅读


Motivation: 有时候可能在 LeetCode 上遇见一些需要 DFS 的题目,为了做到相应的递归,要么就要把许多变量存成类的成员变量,要么就要在递归函数签名上加上一大堆的引用参数,让递归的一些操作能把结果传递出来。无论是哪种操作,都是很麻烦的。那么有没有更简洁的办法呢?答案是肯定的。下文以典型的斐波那契数列为例,介绍一些基于 lambda 函数构造递归的技术。

需要说明的是,因为斐波那契数列的函数本身很简单,本文的许多技术看起来似乎就是多此一举的。这是因为暴力递归的斐波那契函数不需要保存任何状态。因此,考虑如何将以下各种实现改造成记忆化递归的版本(需要保存状态)。对于普通的递归函数,如果需要保存状态,需要借助于全局变量(如果该函数定义在顶层)或类的成员变量(如果该函数定义为某个成员函数);如果对于全局变量/成员变量的随意增删有洁癖,一个可能更好的写法就是把状态以引用的方式写在函数参数里(作为输出参数),然后在调用时传递给函数。而如果使用 lambda 函数,则可以直接将相应的状态定义在任何位置(最方便的位置就是调用栈上),然后利用引用捕获,几乎不需要任何修改。

不使用 Lambda 函数的递归

int64_t FibNormal(int n) {
  if (n == 0 || n == 1) {
    return 1;
  }
  return FibNormal(n - 1) + FibNormal(n - 2);
}

一个直接但错误的 lambda 递归

那么一个直观的想法就是:直接像写递归函数一样写 lambda 不就好了?基于这种想法,可能会写出这种东西来:

// Won't compile.
auto NaiveWrongLambda = [](int n) {
  if (n == 0 || n == 1) {
    return 1;
  }
  return NaiveWrongLambda(n - 1) + NaiveWrongLambda(n - 2);
};

这段代码是非法的,并不能通过编译,这是因为等号右边那一大块 lambda 函数定义完成之后才会绑定给 NaiveWrongLambda,因此在该函数内调用 NaiveWrongLambda 等于引用了一个未定义的函数。

What If C++ 23

如果我们有 C++23,利用 C++23 的 deducing this 特性,可以让编译器推到目前我们正在定义的这个函数。

// Define lambda
auto FibCxx23 = [](this auto && self, int n) {
   if(n == 0 || n == 1){
     return 1;
   }
   return self(n - 1) + self(n - 2);
}

// Use lambda
std::cout<< FibCxx23(5) << std::endl;

这段代码是可以在支持 C++23 的编译器上成功编译运行的。当然,如果 GCC 或 Clang 版本过低,这个代码是会报错的。最简单的办法是直接打开(宇宙第一 IDE)Visual Studio,打开 C++23 选项,即可成功编译。

不过,要等到 C++23 的支持成为主流可能需要等到猴年马月了,上述的黑魔法根本不能在实际的工作中玩起来。我们还是需要基于 C++14/C++17 的标准来完成这件事。

下文中涉及的与 auto 相关的推导特性基本都是来源于 C++17,如果使用 C++14 标准甚至于更低版本的 C++11,许多地方需要换成相应的更冗长的 decltype(auto) 或者相应的具体的类型。

引入冗余的参数

其实,参照我们经常在 Python 里面做的事情,可以给函数引入一个冗余的参数,表示它自身。于是,我们可以写出这种东西:

auto SelfRefFib = [](auto &&self, int n) {
  if (n == 0 || n == 1) {
    return 1;
  }
  return self(self, n - 1) + self(self, n - 2);
};

// Use lambda
std::cout<< SelfRefFib(SelfRefFib, 5) << std::endl;

这玩意当然是可以正常递归的了,只不过使用它的时候始终需要额外带一个参数。当然,我们可以再用一个 lambda 来把它包装一层:

auto FibWrap = [](int n) { return SelfRefFib(SelfRefFib, n); };

// Use lambda
std::cout<< FibWrap(5) << std::endl;

我们也可以不用分成两部分,可以一步到位地写成:

auto FibWrapV2 = [](int n) {
  auto FibImpl = [](auto &&self, int n) {
    if (n == 0 || n == 1) {
      return 1;
    }
    return self(self, n - 1) + self(self, n - 2);
  };
  return FibImpl(FibImpl, n);
};

(怎么样,是不是有 JS 那味了🤣?)

基于 Y 组合子的构造

实际上,在上一节我们最后使用的那个,把一个带冗余参数的函数封装为一个普通的函数的过程,是一个通用的过程,各种各样的递归函数都可以通过这种方式构造出相应的 lambda 函数:

  1. 写一个带冗余参数的 lambda,该冗余参数表示正在定义的函数本身。
  2. 用一个包装函数,隐藏这个冗余参数。

实际上,这就是 lambda 演算中大名鼎鼎的 Y 组合子。我们真正想要的递归函数,就是那个所谓的“带冗余参数的函数”的不动点。我们可以使用模板类,把上面的 FibWrapFibWrapV2 这种包装过程给统一实现出来(ref):

template <typename F>
class YCombinator {
 public:
  explicit YCombinator(F &&f) : f_(std::forward<F>(f)) {}

  template <typename... Args>
  auto operator()(Args &&...args) const {
    return f_(*this, std::forward<Args>(args)...);
  }

 private:
  std::decay_t<F> f_;
};

template <typename F>
auto MakeYCombinator(F &&f) {
  return YCombinator(std::forward<F>(f));
}

此时,利用这个 Y 组合子,就可以直接简单地构造递归 lambda:

auto FibYComb = MakeYCombinator([](auto &&self, int n) {
  if (n == 0 || n == 1) {
    return 1;
  }
  return self(n - 1) + self(n - 2);
});

// Use lambda:
std::cout<< FibYComb(5) << std::endl;

这里有个小问题,lambda 函数的返回值在这种情况下有时候可能推导不出来,需要手动指定:[]() -> T {}

实战演练

我们随便抓一道 LeetCode 上通过 DFS 来实现的题目 2458,用我们的组合拳一套搞定:

#include <bits/stdc++.h>

template <typename F>
class YCombinator {
 public:
  explicit YCombinator(F &&f) : f_(std::forward<F>(f)) {}

  template <typename... Args>
  auto operator()(Args &&...args) const {
    return f_(*this, std::forward<Args>(args)...);
  }

 private:
  std::decay_t<F> f_;
};

template <typename F>
auto MakeYCombinator(F &&f) {
  return YCombinator(std::forward<F>(f));
}

// struct TreeNode {
//   int val;
//   TreeNode *left;
//   TreeNode *right;
//   TreeNode() : val(0), left(nullptr), right(nullptr) {}
//   TreeNode(int x) : val(x), left(nullptr), right(nullptr) {}
//   TreeNode(int x, TreeNode *left, TreeNode *right)
//       : val(x), left(left), right(right) {}
// };

class Solution {
 public:
  std::vector<int> treeQueries(TreeNode *root, std::vector<int> &queries) {
    hmap<TreeNode *, int> height;
    hmap<int, TreeNode *> node_by_id;
    hmap<TreeNode *, int> level;
    hmap<int, std::map<int, int>> level_dist;

    auto Dfs =
        MakeYCombinator([&](auto &&self, TreeNode *root, int cur_level) -> int {
          if (root == nullptr) {
            return -1;
          }
          node_by_id[root->val] = root;
          level[root] = cur_level;
          int l_height = self(root->left, cur_level + 1),
              r_height = self(root->right, cur_level + 1);
          int h = height[root] = 1 + std::max(l_height, r_height);

          if (level_dist.count(cur_level) == 0) {
            level_dist[cur_level] = {};
          }

          auto &dist = level_dist[cur_level];
          if (dist.count(h) == 0) {
            dist[h] = 0;
          }
          dist[h] += 1;
          return h;
        });

    Dfs(root, 0);

    int h = height[root];
    std::vector<int> answer;
    for (int query : queries) {
      auto cur = node_by_id[query];
      int cur_h = -1;
      int cur_level = level[cur];
      auto &dist = level_dist[cur_level];
      if (dist.rbegin()->first > height[cur] ||
          (dist.rbegin()->first == height[cur] && dist[height[cur]] > 1)) {
        cur_h = h;
      } else if (dist.size() == 1) {
        cur_h = h - height[cur] - 1;
      } else {
        cur_h = h - height[cur] + std::next(dist.rbegin())->first;
      }
      answer.push_back(cur_h);
    }
    return answer;
  }

 private:
  template <typename K, typename V>
  using hmap = std::unordered_map<K, V>;
};

结语

我们做这一切的出发点是:利用闭包(lambda)的引用捕获来化简递归函数的实现。通过探索一些手动的封装,最终基于 Y 组合子给出了一个通用的递归 lambda 函数构造。不禁感慨:现代 C++ 还挺好玩的。


终有一日, 仰望星空