组合问题指的是从给定的 N 个数中找出任意 K 个数的所有组合。例如,从 [1, 2, 3, 4] 这 4 个数中找到任意 2 个数的所有组合,它们是:
1 2
1 3
1 4
2 3
2 4
3 4
注意,组合内的元素是没有顺序的,这就意味着 [1, 2] 和 [2, 1] 是同一个组合。因此从 [1, 2, 3, 4] 中找到任意 2 个数的组合,一共有 6 组。
再比如,从 [1,2,3,4] 这 4 个数中找到任意 3 个数的所有组合,它们是:
1 2 3
1 2 4
1 3 4
2 3 4
编程解决组合问题,多数读者最先想到的思路是用嵌套的循环结构。比如从 [1, 2, 3, 4] 中找到任意 2 个数的组合,对应的 C 语言实现代码为:
- #include <stdio.h>
- int main()
- {
- int nums[] = { 1,2,3,4 };
- // i 的值作为选择的第一个数
- for (int i = 0; i < 4; i++) {
- // j 的值作为选择的第二个数
- for (int j = i + 1; j < 4; j++) {
- printf("%d %d\n", nums[i], nums[j]);
- }
- }
- return 0;
- }
类似的,查找任意 3 个数的组合,需要用 3 个嵌套的循环结构实现。也就是说,从 N 个数中找到任意 K 个数的所有组合,用循环结构实现的话,需要用 K 个嵌套的循环结构。
嵌套的循环结构解决组合问题,只适合 K 值较小的情况,比如 K 是 2 或者 3。当 K 值较大时,比如 10、20 甚至更大,循环嵌套的层数非常多,代码变得复杂、难以维护,算法的时间复杂度也会呈指数级增长,导致效率极低,这种情况下更推荐用回溯算法解决组合问题。
回溯算法解决组合问题
假设在 {0, 1, 2, 3} 这 4 个数中找到任意 3 个数的所有组合,下面是回溯算法的整个实现思路:
从 {0,1,2,3} 里,选择元素 0;
从 {1,2,3} 里,选择元素 1;
从 {2,3} 里,选择元素 2,找到了 {0,1,2} 组合;
回溯到 {2,3},选择元素 3,找到了 {0,1,3} 组合;
回溯到 {2,3},没有元素可以选择,再次回溯;
回溯到 {1,2,3},选择元素 2;
从 {3} 里选择元素 3,找到了 {0,2,3} 组合;
回溯到 {3},没有元素可以选择,再次回溯;
回溯到 {1,2,3},只剩 3 可以选择,无法构成一个新的组合,直接回溯;
回溯到 {0,1,2,3},选择元素 1;
从 {2,3} 里,选择元素 2;
从 {3} 里选择元素 3,找到了 {1,2,3} 组合;
回溯到 {2,3},只剩 3 可以选择,无法构成一个新的组合,直接回溯;
回溯到 {0,1,2,3},只剩 2 和 3 可以选择,无法构成一个新的组合,算法执行结束。
仔细观察整个过程不难发现,每选择一个元素 e,下次选择是在 e 之后的元素里进行。比如选择了元素 1 和 2,下次选择就是在 {3} 中进行,而不是 {0, 3}。这样做的好处是,可以确保每种组合只找到一次,不会重复。
实现回溯算法,通常会采用递归的方式,下面是用回溯算法解决组合问题的 C 语言代码:
- #include <stdio.h>
- #include <stdlib.h>
- #define N 4 // 定义数组的大小
- #define K 2 // 定义组合的大小
- // 为实现回溯算法做前期准备的函数
- void printCombination(int arr[], int n, int k);
- // 实现用回溯算法解决组合问题的递归函数
- void combinationUtil(int arr[], int data[], int start, int end, int index, int k);
- int main() {
- int arr[N]; // 声明一个大小为 N 的数组
- // 初始化数组为 {1, 2, 3, 4}
- for (int i = 0; i < N; i++) {
- arr[i] = i + 1;
- }
- // 调用 printCombination 函数来打印所有可能的组合
- printCombination(arr, N, K);
- return 0;
- }
- // 用于初始化临时数组并开始组合过程
- void printCombination(int arr[], int n, int k) {
- // 动态创建一个数组,用于存储选择了的元素
- int* data = (int*)malloc(k * sizeof(int));
- // 使用回溯生成所有组合
- combinationUtil(arr, data, 0, n, 0, k);
- free(data); // 释放之前分配的内存
- }
- // 在 arr 数组中的[start,end)下标范围内,找到 k 个可选择的元素
- // data[] 数组用于存储被选择了的元素
- // index 参数用来统计被选择的元素个数
- void combinationUtil(int arr[], int data[], int start, int end, int index, int k) {
- // 如果已找到一个完整的组合
- if (index == k) {
- // 打印当前组合
- for (int j = 0; j < k; j++) {
- printf("%d ", data[j]);
- }
- printf("\n");
- return;
- }
- // 递归生成所有可能的组合
- for (int i = start; (i < end) && (end - i + 1 >= k - index); i++) {
- data[index] = arr[i]; // 将当前元素加入到当前组合中
- combinationUtil(arr, data, i + 1, end, index + 1, k); // 递归调用以选择下一个元素
- }
- }
程序中自定义了两个函数,真正实现“回溯算法解决组合问题”的是 combinationUtil() 函数,而 printCombination() 函数存在的意义是创建一个数组,用来存储回溯过程被选择的元素,以便后续输出符合要求的组合。
虽然 combinationUtil() 函数的参数数量比较多,但非常容易理解:
- [start, end):start 和 end 用来表示 arr 数组中的下标范围,也是可选择的元素范围。递归过程中,end 的值不变,但 start 的值是变化的;
- arr 和 k:arr 数组用于存储所有的元素;k 用于指定要选择的元素数量。整个递归过程中,arr 和 k 都是不变的。
- data 和 index:data 数组用来存储被选择了的元素;index 用来记录已经选择了的元素数量。
程序中,需要重点解释的是第 49 行的循环条件(i < end) && (end - i + 1 >= k - index),其中 (end - i + 1 >= k - index) 的含义是:
end - i + 1:计算的是从当前元素(包括当前元素)到数组末尾的元素数量,也就是从当前元素开始,数组中还剩下多少个元素可以被选择。k - index:表示为了达到指定的组合大小,还需要选择多少个元素。
如果可选元素的数量能满足需求,即end - i + 1 >= k - index成立,可以继续筛选;反之,如果可选元素的数量无法满足需求,说明当前情况下已经无法找到符合条件的组合,直接回溯即可。
通常情况下,多数读者只会想到i < end作为循环条件,程序也是能正常运行的。添加end - i + 1 >= k - index作为循环条件,可以过滤掉一些无效的、没必要存在的递归过程,进一步优化程序的运行效率。
下面是回溯算法解决组合问题的 Python 程序:
- def print_combination(arr, n, k):
- """
- 用于初始化临时列表并开始组合过程
- """
- # 创建一个列表,用于存储选择的元素
- data = [0] * k
- # 使用回溯生成所有组合
- combination_util(arr, data, 0, n, 0, k)
- def combination_util(arr, data, start, end, index, k):
- """
- 在 arr 列表中 [start, end) 下标范围内,找到 k 个可选择的元素
- data 列表用于存储被选择的元素
- index 参数用来统计被选择的元素个数
- """
- # 如果已找到一个完整的组合
- if index == k:
- # 打印当前组合
- for i in data:
- print(i, end=' ')
- print()
- return
- # 递归生成所有可能的组合
- for i in range(start, end):
- if end - i >= k - index:
- data[index] = arr[i] # 将当前元素加入到当前组合中
- combination_util(arr, data, i + 1, end, index + 1, k) # 递归调用以选择下一个元素
- # 定义列表的大小和组合的大小
- N = 4
- K = 2
- if __name__ == "__main__":
- # 声明一个大小为 N 的列表并初始化为 {1, 2, 3, 4}
- arr = [i + 1 for i in range(N)]
- # 调用 print_combination 函数来打印所有可能的组合
- print_combination(arr, N, K)
下面是回溯算法解决组合问题的 Java 程序:
- public class Combination {
- public static void main(String[] args) {
- final int N = 4; // 定义数组的大小
- final int K = 2; // 定义组合的大小
- int[] arr = new int[N];
- // 初始化数组为 {1, 2, 3, 4}
- for (int i = 0; i < N; i++) {
- arr[i] = i + 1;
- }
- // 调用 printCombination 函数来打印所有可能的组合
- printCombination(arr, N, K);
- }
- // 用于初始化临时数组并开始组合过程
- public static void printCombination(int[] arr, int n, int k) {
- int[] data = new int[k];
- // 使用回溯生成所有组合
- combinationUtil(arr, data, 0, n, 0, k);
- }
- // 在 arr 数组中的[start, end)下标范围内,找到 k 个可选择的元素
- // data 数组用于存储被选择了的元素
- // index 参数用来统计被选择的元素个数
- public static void combinationUtil(int[] arr, int[] data, int start, int end, int index, int k) {
- // 如果已找到一个完整的组合
- if (index == k) {
- // 打印当前组合
- for (int i = 0; i < k; i++) {
- System.out.print(data[i] + " ");
- }
- System.out.println();
- return;
- }
- // 递归生成所有可能的组合
- for (int i = start; (i < end) && (end - i + 1 >= k - index); i++) {
- data[index] = arr[i]; // 将当前元素加入到当前组合中
- combinationUtil(arr, data, i + 1, end, index + 1, k); // 递归调用以选择下一个元素
- }
- }
- }
运行结果,结果为:
1 2
1 3
1 4
2 3
2 4
3 4