快速排序简洁实现

QuickSort 快速排序是常见的考察代码基本功面试题。简洁易读的实现可以一定程度展现面试者的代码功底。

算法导论对于快排讲解的很透彻,也有伪代码,即便如此,网上许多实现还是错的,不能通过一些边界用例,如数组已排序或数组中有重复元素的情况。还有一些用python的实现不满足原位排序(in-place)的要求,直接新建两个新List,这种方法其实也是取巧,简化了partition函数的实现难度。

快排框架

快排基本框架实现很好地体现了分治的思想:

void qsort(vector<int> &v, int begin, int end)
{
    if (begin >= end) return;
    int p = partition(v, begin, end);
    qsort(v, begin, p - 1);
    qsort(v, p + 1, end);
}

partition算法实现

单边扫描

快排的核心和难度在于partition算法,这里参考算法导论第三版的实现。基本思想:将最右元素作为pivot,对整个子数组进行分割,分割后的数据满足在pivot左边的元素都小于等于(no greater than) pivot,pivot右边的元素都大于(greater than) pivot。话不多说,先看实现:

int partition(vector<int> &v, int begin, int end)
{
    int pivot = v[end], i = begin - 1;
    for (int j = begin; j <= end - 1; ++j)
    {
        if (v[j] <= pivot) swap(v[j], v[++i]);
    }
    swap(v[i + 1], v[end]);
    return i + 1;
}

qsort partition illustration

代码很短但并不那么直观,引用算法导论的一个例子看起来比较容易理解,将子数组从自至右分为四个区域:小于等于pivot的,大于pivot的,未扫描到的,pivot本身即最右端元素。四个指针:

变量 意义
i i(含)左边的元素都小于等于pivot
j 当前扫描到的元素
p begin
r end

i和j两个指针将[begin, end - 1]区间分成了三个部分(位置j待定):

区间 意义
[begin, i] 元素都小于等于pivot
[i + 1, j) 元素都大于pivot
(j, end - 1) 待扫描

此算法巧妙地处理了最终pivot所在位置的问题,即i + 1。并且,在扫描到需要交换的元素(即<=pivot)时,比如状态(e),交换i和j位置的元素,可以想像成pivot两侧区域自然增长的状态。

上述实现是《算法导论》的原始算法,实际上将i初始为begin更好些(++i改为i++):

int partition(vector<int> &v, int begin, int end)
{
    int pivot = v[end], i = begin;
    for (int j = begin; j <= end - 1; ++j)
    {
        if (v[j] <= pivot) swap(v[j], v[i++]);
    }
    swap(v[i], v[end]);
    return i;
}

双边扫描

在实现算法导论的算法之前,手写了一个双边扫描的代码,debug了很长时间,原因在于最终pivot值的边界条件很难处理。暂时没有想到更好的实现,此实现读起来有点难懂,不推荐。

int partition(vector<int> &v, int begin, int end)
{
    int pivot = v[end], l = begin, r = end - 1;
    while (l <= r)
    {
        while(v[l] < pivot && l <= r) ++l;
        while(v[r] >= pivot && l <= r) --r;
        if (l >= r) break;
        swap(v[l], v[r]);
    }
    swap(v[l], v[end]);
    return l;
}

随机选择pivot

此处实现并未在partition函数中引入随机选择pivot元素,引入的话可在parition函数头部添加两行:

int r = begin + rand() % (end - begin);
swap(v[r], v[end]);

完整代码

完整代码如下,测试排序一个含重复元素的数组的全排列:

#include <iostream>
#include <algorithm>
#include <vector>

using namespace std;

int partition(vector<int> &v, int begin, int end)
{
    // int r = begin + rand() % (end - begin);
    // swap(v[r], v[end]);
    int pivot = v[end], i = begin;
    for (int j = begin; j <= end - 1; ++j)
    {
        if (v[j] <= pivot) swap(v[j], v[i++]);
    }
    swap(v[i], v[end]);
    return i;
}

void qsort(vector<int> &v, int begin, int end)
{
    if (begin >= end) return;
    int p = partition(v, begin, end);
    qsort(v, begin, p - 1);
    qsort(v, p + 1, end);
}

void qsort(vector<int> &v)
{
    qsort(v, 0, v.size() - 1);
}

void print(vector<int> &v)
{
    for (int i = 0; i < v.size(); ++i)
    {
        cout<<v[i]<<" ";
    }
    cout<<endl;
}

bool check(vector<int> &v)
{
    for (int i = 0; i < v.size() - 1; ++i)
    {
        if (v[i] > v[i + 1]) return false;
    }
    return true;
}

int main()
{
    vector<int> v{1, 2, 2, 3, 3, 3, 4, 5};

    int c = 0;
    while(next_permutation(v.begin(), v.end()))
    {
        ++c;
        vector<int> a(v);
        qsort(a);

        if (check(a) == false)
        {
            cout<<"Wrong: "<<endl;
            print(v);
            cout<<"-->"<<endl;
            print(a);
            break;
        }
    }
    cout<<"Tested "<<c<<" cases"<<endl;
}