快速排序简洁实现
QuickSort 快速排序是常见的考察代码基本功面试题。简洁易读的实现可以一定程度展现面试者的代码功底。
算法导论对于快排讲解的很透彻,也有伪代码,即便如此,网上许多实现还是错的,不能通过一些边界用例,如数组已排序或数组中有重复元素的情况。还有一些用python的实现不满足原位排序(in-place)的要求,直接新建两个新List,这种方法其实也是取巧,简化了partition函数的实现难度。
快排框架
快排基本框架实现很好地体现了分治的思想:
void qsort(vector<int> &v, int begin, int end)
{
if (begin >= end) return;
int p = partition(v, begin, end);
qsort(v, begin, p - 1);
qsort(v, p + 1, end);
}
partition算法实现
单边扫描
快排的核心和难度在于partition算法,这里参考算法导论第三版的实现。基本思想:将最右元素作为pivot,对整个子数组进行分割,分割后的数据满足在pivot左边的元素都小于等于(no greater than) pivot,pivot右边的元素都大于(greater than) pivot。话不多说,先看实现:
int partition(vector<int> &v, int begin, int end)
{
int pivot = v[end], i = begin - 1;
for (int j = begin; j <= end - 1; ++j)
{
if (v[j] <= pivot) swap(v[j], v[++i]);
}
swap(v[i + 1], v[end]);
return i + 1;
}
代码很短但并不那么直观,引用算法导论的一个例子看起来比较容易理解,将子数组从自至右分为四个区域:小于等于pivot的,大于pivot的,未扫描到的,pivot本身即最右端元素。四个指针:
变量 | 意义 |
---|---|
i | i(含)左边的元素都小于等于pivot |
j | 当前扫描到的元素 |
p | begin |
r | end |
i和j两个指针将[begin, end - 1]区间分成了三个部分(位置j待定):
区间 | 意义 |
---|---|
[begin, i] | 元素都小于等于pivot |
[i + 1, j) | 元素都大于pivot |
(j, end - 1) | 待扫描 |
此算法巧妙地处理了最终pivot所在位置的问题,即i + 1。并且,在扫描到需要交换的元素(即<=pivot)时,比如状态(e),交换i和j位置的元素,可以想像成pivot两侧区域自然增长的状态。
上述实现是《算法导论》的原始算法,实际上将i初始为begin更好些(++i改为i++):
int partition(vector<int> &v, int begin, int end)
{
int pivot = v[end], i = begin;
for (int j = begin; j <= end - 1; ++j)
{
if (v[j] <= pivot) swap(v[j], v[i++]);
}
swap(v[i], v[end]);
return i;
}
双边扫描
在实现算法导论的算法之前,手写了一个双边扫描的代码,debug了很长时间,原因在于最终pivot值的边界条件很难处理。暂时没有想到更好的实现,此实现读起来有点难懂,不推荐。
int partition(vector<int> &v, int begin, int end)
{
int pivot = v[end], l = begin, r = end - 1;
while (l <= r)
{
while(v[l] < pivot && l <= r) ++l;
while(v[r] >= pivot && l <= r) --r;
if (l >= r) break;
swap(v[l], v[r]);
}
swap(v[l], v[end]);
return l;
}
随机选择pivot
此处实现并未在partition函数中引入随机选择pivot元素,引入的话可在parition函数头部添加两行:
int r = begin + rand() % (end - begin);
swap(v[r], v[end]);
完整代码
完整代码如下,测试排序一个含重复元素的数组的全排列:
#include <iostream>
#include <algorithm>
#include <vector>
using namespace std;
int partition(vector<int> &v, int begin, int end)
{
// int r = begin + rand() % (end - begin);
// swap(v[r], v[end]);
int pivot = v[end], i = begin;
for (int j = begin; j <= end - 1; ++j)
{
if (v[j] <= pivot) swap(v[j], v[i++]);
}
swap(v[i], v[end]);
return i;
}
void qsort(vector<int> &v, int begin, int end)
{
if (begin >= end) return;
int p = partition(v, begin, end);
qsort(v, begin, p - 1);
qsort(v, p + 1, end);
}
void qsort(vector<int> &v)
{
qsort(v, 0, v.size() - 1);
}
void print(vector<int> &v)
{
for (int i = 0; i < v.size(); ++i)
{
cout<<v[i]<<" ";
}
cout<<endl;
}
bool check(vector<int> &v)
{
for (int i = 0; i < v.size() - 1; ++i)
{
if (v[i] > v[i + 1]) return false;
}
return true;
}
int main()
{
vector<int> v{1, 2, 2, 3, 3, 3, 4, 5};
int c = 0;
while(next_permutation(v.begin(), v.end()))
{
++c;
vector<int> a(v);
qsort(a);
if (check(a) == false)
{
cout<<"Wrong: "<<endl;
print(v);
cout<<"-->"<<endl;
print(a);
break;
}
}
cout<<"Tested "<<c<<" cases"<<endl;
}