分治算法精讲:从归并排序到快速选择

分治算法的本质
🎯 分治算法核心三步骤
1
分解 Divide
将复杂问题分解为多个规模较小的相同子问题
2
解决 Conquer
递归地解决各个子问题,基线条件直接求解
3
合并 Combine
将子问题的解合并构建原问题的解
分治(Divide and Conquer)是算法设计中最重要的思想之一。它将复杂问题分解成若干个规模较小的相同问题,递归地解决这些子问题,然后将子问题的解合并得到原问题的解。这种”分而治之”的策略,在许多经典算法中都有体现。
分治算法的三个步骤
步骤 | 英文 | 说明 | 关键点 |
---|---|---|---|
分解 | Divide | 将问题分解成若干子问题 | 子问题应该是原问题的较小实例 |
解决 | Conquer | 递归地解决子问题 | 子问题足够小时直接求解 |
合并 | Combine | 将子问题的解合并 | 合并过程的效率很关键 |
经典分治算法详解
1. 归并排序(Merge Sort)
归并排序是分治思想的典型应用,时间复杂度稳定O(nlogn)。
// 合并两个有序数组
void merge(vector<int>& arr, int left, int mid, int right) {
vector<int> temp(right - left + 1);
int i = left, j = mid + 1, k = 0;
// 合并两个有序部分
while (i <= mid && j <= right) {
if (arr[i] <= arr[j]) {
temp[k++] = arr[i++];
} else {
temp[k++] = arr[j++];
}
}
// 处理剩余元素
while (i <= mid) temp[k++] = arr[i++];
while (j <= right) temp[k++] = arr[j++];
// 复制回原数组
for (int i = 0; i < temp.size(); i++) {
arr[left + i] = temp[i];
}
}
void mergeSort(vector<int>& arr, int left, int right) {
if (left >= right) return;
int mid = left + (right - left) / 2;
// 分治
mergeSort(arr, left, mid);
mergeSort(arr, mid + 1, right);
// 合并
merge(arr, left, mid, right);
}
// 优化版本:小数组使用插入排序
void mergeSortOptimized(vector<int>& arr, int left, int right) {
if (right - left <= 15) {
// 小数组使用插入排序
for (int i = left + 1; i <= right; i++) {
int key = arr[i];
int j = i - 1;
while (j >= left && arr[j] > key) {
arr[j + 1] = arr[j];
j--;
}
arr[j + 1] = key;
}
return;
}
int mid = left + (right - left) / 2;
mergeSortOptimized(arr, left, mid);
mergeSortOptimized(arr, mid + 1, right);
// 如果已经有序,跳过合并
if (arr[mid] <= arr[mid + 1]) return;
merge(arr, left, mid, right);
}
2. 快速排序(Quick Sort)
快速排序平均时间复杂度O(nlogn),最坏O(n²),但实际表现优秀。
// 三路快排(处理重复元素)
void quickSort3Way(vector<int>& arr, int left, int right) {
if (left >= right) return;
// 随机选择pivot
int pivotIdx = left + rand() % (right - left + 1);
swap(arr[left], arr[pivotIdx]);
int pivot = arr[left];
int lt = left; // arr[left+1...lt] < pivot
int gt = right + 1; // arr[gt...right] > pivot
int i = left + 1; // arr[lt+1...i) == pivot
while (i < gt) {
if (arr[i] < pivot) {
swap(arr[++lt], arr[i++]);
} else if (arr[i] > pivot) {
swap(arr[i], arr[--gt]);
} else {
i++;
}
}
swap(arr[left], arr[lt]);
quickSort3Way(arr, left, lt - 1);
quickSort3Way(arr, gt, right);
}
// 快速选择(找第k小元素)
int quickSelect(vector<int>& arr, int left, int right, int k) {
if (left == right) return arr[left];
// 随机选择pivot
int pivotIdx = left + rand() % (right - left + 1);
swap(arr[pivotIdx], arr[right]);
// 分区
int i = left;
for (int j = left; j < right; j++) {
if (arr[j] <= arr[right]) {
swap(arr[i++], arr[j]);
}
}
swap(arr[i], arr[right]);
// 判断位置
int count = i - left + 1;
if (count == k) {
return arr[i];
} else if (count > k) {
return quickSelect(arr, left, i - 1, k);
} else {
return quickSelect(arr, i + 1, right, k - count);
}
}
3. 最近点对问题
平面上n个点,找出距离最近的两个点,时间复杂度O(nlogn)。
struct Point {
double x, y;
};
double distance(const Point& a, const Point& b) {
return sqrt((a.x - b.x) * (a.x - b.x) + (a.y - b.y) * (a.y - b.y));
}
// 暴力法:用于小规模问题
double bruteForce(vector<Point>& points, int left, int right) {
double minDist = DBL_MAX;
for (int i = left; i <= right; i++) {
for (int j = i + 1; j <= right; j++) {
minDist = min(minDist, distance(points[i], points[j]));
}
}
return minDist;
}
// 合并步骤:处理跨越中线的点对
double stripClosest(vector<Point>& strip, double d) {
double minDist = d;
// 按y坐标排序
sort(strip.begin(), strip.end(), [](const Point& a, const Point& b) {
return a.y < b.y;
});
// 只需要检查y坐标差小于d的点对
for (int i = 0; i < strip.size(); i++) {
for (int j = i + 1; j < strip.size() &&
(strip[j].y - strip[i].y) < minDist; j++) {
minDist = min(minDist, distance(strip[i], strip[j]));
}
}
return minDist;
}
double closestPairRecursive(vector<Point>& points, int left, int right) {
// 小规模问题直接暴力
if (right - left <= 3) {
return bruteForce(points, left, right);
}
int mid = left + (right - left) / 2;
Point midPoint = points[mid];
// 分治
double dl = closestPairRecursive(points, left, mid);
double dr = closestPairRecursive(points, mid + 1, right);
double d = min(dl, dr);
// 找出中线附近的点
vector<Point> strip;
for (int i = left; i <= right; i++) {
if (abs(points[i].x - midPoint.x) < d) {
strip.push_back(points[i]);
}
}
// 合并
return min(d, stripClosest(strip, d));
}
double closestPair(vector<Point>& points) {
// 按x坐标排序
sort(points.begin(), points.end(), [](const Point& a, const Point& b) {
return a.x < b.x;
});
return closestPairRecursive(points, 0, points.size() - 1);
}
4. 大整数乘法(Karatsuba算法)
将O(n²)的乘法优化到O(n^1.58)。
string addStrings(string num1, string num2) {
string result = "";
int carry = 0;
int i = num1.length() - 1;
int j = num2.length() - 1;
while (i >= 0 || j >= 0 || carry) {
int sum = carry;
if (i >= 0) sum += num1[i--] - 0;
if (j >= 0) sum += num2[j--] - 0;
result = char(sum % 10 + 0) + result;
carry = sum / 10;
}
return result;
}
string subtractStrings(string num1, string num2) {
// 假设num1 >= num2
string result = "";
int borrow = 0;
int i = num1.length() - 1;
int j = num2.length() - 1;
while (i >= 0) {
int diff = (num1[i] - 0) - borrow;
if (j >= 0) diff -= (num2[j--] - 0);
if (diff < 0) {
diff += 10;
borrow = 1;
} else {
borrow = 0;
}
result = char(diff + 0) + result;
i--;
}
// 去除前导零
size_t pos = result.find_first_not_of(0);
if (pos == string::npos) return "0";
return result.substr(pos);
}
string karatsuba(string num1, string num2) {
int n = max(num1.length(), num2.length());
// 补齐长度
while (num1.length() < n) num1 = "0" + num1;
while (num2.length() < n) num2 = "0" + num2;
// 基础情况
if (n == 1) {
int product = (num1[0] - 0) * (num2[0] - 0);
return to_string(product);
}
int mid = n / 2;
// 分割
string a = num1.substr(0, n - mid);
string b = num1.substr(n - mid);
string c = num2.substr(0, n - mid);
string d = num2.substr(n - mid);
// 三次递归乘法
string ac = karatsuba(a, c);
string bd = karatsuba(b, d);
string abcd = karatsuba(addStrings(a, b), addStrings(c, d));
// 计算ad + bc = (a+b)(c+d) - ac - bd
string adbc = subtractStrings(subtractStrings(abcd, ac), bd);
// 结果 = ac * 10^(2*mid) + (ad+bc) * 10^mid + bd
for (int i = 0; i < 2 * mid; i++) ac += "0";
for (int i = 0; i < mid; i++) adbc += "0";
return addStrings(addStrings(ac, adbc), bd);
}
分治算法的应用技巧
1. 逆序对计数
long long mergeCount(vector<int>& arr, vector<int>& temp, int left, int mid, int right) {
int i = left, j = mid + 1, k = left;
long long invCount = 0;
while (i <= mid && j <= right) {
if (arr[i] <= arr[j]) {
temp[k++] = arr[i++];
} else {
temp[k++] = arr[j++];
invCount += (mid - i + 1); // 关键:统计逆序对
}
}
while (i <= mid) temp[k++] = arr[i++];
while (j <= right) temp[k++] = arr[j++];
for (int i = left; i <= right; i++) {
arr[i] = temp[i];
}
return invCount;
}
long long countInversions(vector<int>& arr, vector<int>& temp, int left, int right) {
long long invCount = 0;
if (left < right) {
int mid = left + (right - left) / 2;
invCount += countInversions(arr, temp, left, mid);
invCount += countInversions(arr, temp, mid + 1, right);
invCount += mergeCount(arr, temp, left, mid, right);
}
return invCount;
}
2. 矩阵快速幂
vector<vector<long long>> matrixMultiply(
const vector<vector<long long>>& A,
const vector<vector<long long>>& B,
long long mod) {
int n = A.size();
vector<vector<long long>> C(n, vector<long long>(n, 0));
for (int i = 0; i < n; i++) {
for (int j = 0; j < n; j++) {
for (int k = 0; k < n; k++) {
C[i][j] = (C[i][j] + A[i][k] * B[k][j]) % mod;
}
}
}
return C;
}
vector<vector<long long>> matrixPower(
vector<vector<long long>> base,
long long exp,
long long mod) {
int n = base.size();
vector<vector<long long>> result(n, vector<long long>(n, 0));
// 初始化为单位矩阵
for (int i = 0; i < n; i++) {
result[i][i] = 1;
}
while (exp > 0) {
if (exp & 1) {
result = matrixMultiply(result, base, mod);
}
base = matrixMultiply(base, base, mod);
exp >>= 1;
}
return result;
}
// 应用:斐波那契数列的第n项
long long fibonacci(long long n, long long mod) {
if (n <= 1) return n;
vector<vector<long long>> base = {{1, 1}, {1, 0}};
vector<vector<long long>> result = matrixPower(base, n - 1, mod);
return result[0][0];
}
3. 分治FFT(快速傅里叶变换)
const double PI = acos(-1);
struct Complex {
double real, imag;
Complex(double r = 0, double i = 0) : real(r), imag(i) {}
Complex operator+(const Complex& other) const {
return Complex(real + other.real, imag + other.imag);
}
Complex operator-(const Complex& other) const {
return Complex(real - other.real, imag - other.imag);
}
Complex operator*(const Complex& other) const {
return Complex(real * other.real - imag * other.imag,
real * other.imag + imag * other.real);
}
};
void FFT(vector<Complex>& a, bool inverse) {
int n = a.size();
if (n == 1) return;
// 分治
vector<Complex> even(n / 2), odd(n / 2);
for (int i = 0; i < n / 2; i++) {
even[i] = a[2 * i];
odd[i] = a[2 * i + 1];
}
FFT(even, inverse);
FFT(odd, inverse);
// 合并
double angle = 2 * PI / n * (inverse ? -1 : 1);
Complex w(1), wn(cos(angle), sin(angle));
for (int i = 0; i < n / 2; i++) {
Complex t = w * odd[i];
a[i] = even[i] + t;
a[i + n / 2] = even[i] - t;
if (inverse) {
a[i].real /= 2;
a[i].imag /= 2;
a[i + n / 2].real /= 2;
a[i + n / 2].imag /= 2;
}
w = w * wn;
}
}
// 多项式乘法
vector<int> polynomialMultiply(vector<int>& A, vector<int>& B) {
int n = 1;
while (n < A.size() + B.size() - 1) n <<= 1;
vector<Complex> fa(n), fb(n);
for (int i = 0; i < A.size(); i++) fa[i] = Complex(A[i]);
for (int i = 0; i < B.size(); i++) fb[i] = Complex(B[i]);
FFT(fa, false);
FFT(fb, false);
for (int i = 0; i < n; i++) {
fa[i] = fa[i] * fb[i];
}
FFT(fa, true);
vector<int> result(A.size() + B.size() - 1);
for (int i = 0; i < result.size(); i++) {
result[i] = round(fa[i].real);
}
return result;
}
分治算法的优化技巧
优化策略对比
优化技巧 | 适用场景 | 效果 |
---|---|---|
小规模直接求解 | 递归深度大时 | 减少递归开销 |
尾递归优化 | 单边递归 | 节省栈空间 |
记忆化 | 有重叠子问题 | 避免重复计算 |
并行化 | 子问题独立 | 利用多核性能 |
缓存优化 | 数据局部性好 | 提高缓存命中率 |
分治算法练习题
基础题目
题目 | 算法 | 难度 |
---|---|---|
数组中的第K大元素 | 快速选择 | ★★☆ |
逆序对计数 | 归并排序 | ★★☆ |
最大子数组和 | 分治法 | ★★☆ |
汉诺塔 | 递归分治 | ★☆☆ |
进阶题目
题目 | 算法 | 难度 |
---|---|---|
平面最近点对 | 分治+剪枝 | ★★★ |
线段树 | 分治维护 | ★★★ |
CDQ分治 | 三维偏序 | ★★★★ |
FFT多项式乘法 | 分治FFT | ★★★★ |
总结
分治算法是算法设计的基础范式,掌握分治思想需要:
- 理解本质:将大问题分解为小问题
- 掌握模式:分解-解决-合并三步骤
- 注意效率:合并步骤往往是关键
- 灵活应用:识别可以分治的问题特征
- 优化技巧:小问题直接求解,避免过度递归
分治不仅是一种算法,更是一种思维方式。通过将复杂问题简化,我们能够优雅地解决看似困难的问题。掌握分治算法,你将拥有解决复杂问题的利器!
Views: 0
答复