`
yangdong
  • 浏览: 66885 次
  • 性别: Icon_minigender_1
  • 来自: 杭州
社区版块
存档分类
最新评论

TimSort 中的核心过程

 
阅读更多
    TimSort 是 Python 中 list.sort 的默认实现。Java 7 也将非原始类型列表的排序实现替换成了 TimSort。网上关于 TimSort 是什么,性能特点分析的文章不少,但是介绍它的具体实现步骤的文章很少。这里有一篇:Understanding timsort, Part 1: Adaptive Mergesort,用 C 作为示例代码。

基于这个文章的介绍,我用 python 实现一遍 TimSort,并说一下其中的关键步骤。因为原文只讲解了 TimSort 中最基本最重要的部分,所以本文也没有超过这个范围。本文不是对 TimSort 的分析,只是介绍一下其基本实现。

TimSort 概览
    TimSort 是一个归并排序做了大量优化的版本。对归并排序排在已经反向排好序的输入时表现O(n^2)的特点做了特别优化。对已经正向排好序的输入减少回溯。对两种情况混合(一会升序,一会降序)的输入处理比较好。

TimSort 核心过程
    假定,我们的 TimSort 是进行升序排序。TimSort 为了减少对升序部分的回溯和对降序部分的性能倒退,将输入按其升序和降序特点进行了分区。排序的输入的单位不是一个个单独的数字了,而一个个的分区。其中每一个分区我们叫一个“run“。针对这个 run 序列,每次我们拿一个 run 出来进行归并。每次归并会将两个 runs 合并成一个 run。归并的结果保存到 "run_stack" 上。如果我们觉得有必要归并了,那么进行归并,直到消耗掉所有的 runs。这时将 run_stack 上剩余的 runs 归并到只剩一个 run 为止。这时这个仅剩的 run 即为我们需要的排好序的结果。

def timsort(arr):
    arr = arr or []
    if len(arr) <= 0: return []
    runs = _partition_to_runs(arr)
    run_stack = []
    for run in runs:
        run_stack.append(run)
        while _should_merge(run_stack):
            _merge_stack(run_stack)
    while len(run_stack) > 1:
        _merge_stack(run_stack)
    return run_stack[0]


这里“觉得有必要”这句话很模糊,到底什么时候有必要后面会给出定义。

如何分区
    为了在已经按升序排好序的输入面前减少回溯,我们把输入当中已经有序的这些段分组,使得它们成为一个基本单元,这样我们就不必在这个基本单元内部浪费时间进行回溯了。比如[1, 2, 3, 2] 进行分区后就变成了 [[1, 2, 3], [2]]。

为了在已经按降序排好序的输入面前避免归并排序倒退成 O(n^2),我们把输入当中降序的部分翻转成升序,也作为一个单元。比如 [3, 2, 1, 3] 进行分区后就变成了 [[1, 2, 3], [3]]。

def _partition_to_runs(arr):
    partitioned_up_to = 0
    while partitioned_up_to < len(arr):
        if not len(arr) - partitioned_up_to:
            return
        if len(arr) - partitioned_up_to == 1:
            part = list(arr[-1:])
            partitioned_up_to += 1
            yield part
        else:
            if arr[partitioned_up_to] > arr[partitioned_up_to + 1]: # 这里必须是严格降序
                next_pos = _find_desc_boundary(arr, partitioned_up_to)
                _reverse(arr, partitioned_up_to, next_pos)
            else:
                next_pos = _find_asc_boundary(arr, partitioned_up_to)

            part = arr[partitioned_up_to:next_pos]
            partitioned_up_to = next_pos
            yield part

def _find_desc_boundary(arr, start):
    if start >= len(arr) - 1:
        return start + 1
    if arr[start] > arr[start+1]: # 这里必须是严格降序
        return _find_desc_boundary(arr, start + 1)
    else:
        return start + 1

def _reverse(arr, start=0, end=None):
    # 正常的翻转函数,实现省略

def _find_asc_boundary(arr, start):
    if start >= len(arr) - 1:
        return start + 1
    if arr[start] <= arr[start+1]:
        return _find_asc_boundary(arr, start + 1)
    else:
        return start + 1


这里注意降序的部分必须是“严格”降序才能进行翻转。因为 TimSort 的一个重要目标是保持稳定性(stability)。如果在 >= 的情况下进行翻转这个算法就不再是 stable sorting algorithm 了。

逆向分解
    传统的归并排序是通过递归,用函数栈把每次 "divide" 的结果保存下来的。divide 的最终结果是一个个的基本单元-单个数字。但是我们看到 TimSort 把这个过程反过来了。我们经过一次分区,已经拿到了了基本单元列表,只不过这次基本单元是一串数字。所以我们只能自己手工将将基本单元列表进行合并。

如何合并
    那么何时进行合并?合并的策略是要在 "run_stack" 上维护一个不变式。当这个不变式被打破时即进行合并。传统的归并排序通过二分法可以保证函数栈的深度为 log(n)。我们也模拟这个策略,也让 run_stack 的长度不超过 log(n)。假如 runN 先入栈,runN+1 紧随其后入栈。那么就要求 runN 的长度要是 runN+1 长度的 2 倍。所以归并的条件是:如果 runN 的长度 < (runN+1 的长度 * 2) 即进行归并。

# 因为我们每次新添 run 进入 run_stack 时都判断是否需要归并,
# 并且在每次归并之后还要进一步确保 run_stack 是满足不变式的,
# 所以这里只判断栈头的两个 run 就够了。
def _should_merge(run_stack):
    if len(run_stack) < 2:
        return False
    return len(run_stack[-2]) < 2*len(run_stack[-1])

def _merge(ls1, ls2):
    # 正常的归并函数,实现省略

def _merge_stack(run_stack):
    head = run_stack.pop()
    next = run_stack.pop()
    new_run = _merge(next, head)
    run_stack.append(new_run)


跟分区的情况类似,这里在归并的时候也要用 stable merge。

插入排序优化
    到上面的步骤为止,程序已经可以正确地排序了。但是我们知道插入排序在输入元素数小于一个阀值的时候相比其它排序会更快,所以很多排序算法在 divide 这一步进行到只剩不到这个阀值个数的元素的时候会改用插入排序(比如 JDK6 的快排,参考这里),所以我们也要做这个优化。

在分区的时候,如果我们观察到新产生出来的 run 的长度小于适用于插入排序的阀值,我们就用插入排序把这个 run 的长度扩充到这个阀值。

def _partition_to_runs(arr):
    partitioned_up_to = 0
    while partitioned_up_to < len(arr):
        if not len(arr) - partitioned_up_to:
            return
        if len(arr) - partitioned_up_to == 1:
            part = list(arr[-1:])
            partitioned_up_to += 1
            yield part
        else:
            if arr[partitioned_up_to] > arr[partitioned_up_to + 1]:
                next_pos = _find_desc_boundary(arr, partitioned_up_to)
                _reverse(arr, partitioned_up_to, next_pos)
            else:
                next_pos = _find_asc_boundary(arr, partitioned_up_to)

            # 只加了这一句话
            next_pos = _do_insertion_sort_optimization(arr, partitioned_up_to, next_pos)

            part = arr[partitioned_up_to:next_pos]
            partitioned_up_to = next_pos
            yield part

def _insertion_sort(arr, start, end):
    # 标准插入排序实现

def _do_insertion_sort_optimization(arr, start, end):
    length = end - start
    if length < INSERTION_SORT_THRESHOLD:
        end = min(start+INSERTION_SORT_THRESHOLD, len(arr))
        _insertion_sort(arr, start, end)
    return end


这里我们只加一句话就够了。剩余的就是标准的插入排序实现。

与原文代码的差异
    TimSort 最多使用 O(n) 临时内存空间。由于原文是 C 的代码,为了减少 malloc 的次数而一次性分配了 O(n) 的数组空间。我们这里因为是用 python,也这么做会显得很怪异。所以内存是在每次归并的时候一点点分配的。

TimSort 的实现逻辑上可以看成分区和归并两部分。但由于 C 不支持协程,而 python 通过 generator 部分支持协程。所以为了提高可读性,分区的部分我是用 generator 的方式做的。在代码上与归并的部分完全分离。而原文为了达到 lazy 的目的,是一边分区一边归并的。

完整的实现和测试代码
# -*- coding: utf-8 -*-
import functools
from unittest import TestCase

INSERTION_SORT_THRESHOLD = 6

def _find_desc_boundary(arr, start):
    if start >= len(arr) - 1:
        return start + 1
    if arr[start] > arr[start+1]:
        return _find_desc_boundary(arr, start + 1)
    else:
        return start + 1

def _reverse(arr, start=0, end=None):
    if end is None:
        end = len(arr)
    for i in range(start, start + (end-start)//2):
        opposite = end - i - 1
        arr[i], arr[opposite] = arr[opposite], arr[i]

def _find_asc_boundary(arr, start):
    if start >= len(arr) - 1:
        return start + 1
    if arr[start] <= arr[start+1]:
        return _find_asc_boundary(arr, start + 1)
    else:
        return start + 1

def _insertion_sort(arr, start, end):
    if end - start <= 1:
        return
    for i in range(start, end):
        v = arr[i]
        j = i - 1
        while j>=0 and arr[j] > v:
            arr[j+1] = arr[j]
            j -= 1
        arr[j+1] = v

def _do_insertion_sort_optimization(arr, start, end):
    length = end - start
    if length < INSERTION_SORT_THRESHOLD:
        end = min(start+INSERTION_SORT_THRESHOLD, len(arr))
        _insertion_sort(arr, start, end)
    return end

def _partition_to_runs(arr):
    partitioned_up_to = 0
    while partitioned_up_to < len(arr):
        if not len(arr) - partitioned_up_to:
            return
        if len(arr) - partitioned_up_to == 1:
            part = list(arr[-1:])
            partitioned_up_to += 1
            yield part
        else:
            if arr[partitioned_up_to] > arr[partitioned_up_to + 1]:
                next_pos = _find_desc_boundary(arr, partitioned_up_to)
                _reverse(arr, partitioned_up_to, next_pos)
            else:
                next_pos = _find_asc_boundary(arr, partitioned_up_to)

            next_pos = _do_insertion_sort_optimization(arr, partitioned_up_to, next_pos)

            part = arr[partitioned_up_to:next_pos]
            partitioned_up_to = next_pos
            yield part

def _should_merge(run_stack):
    if len(run_stack) < 2:
        return False
    return len(run_stack[-2]) < 2*len(run_stack[-1])

def _merge(ls1, ls2, merge_storage=None):
    ret = merge_storage or []
    i1 = 0
    i2 = 0
    while i1 < len(ls1) and i2 < len(ls2):
        a = ls1[i1]
        b = ls2[i2]
        if a <= b:
            ret.append(a)
            i1 += 1
        else:
            ret.append(b)
            i2 += 1
    ret += ls1[i1:]
    ret += ls2[i2:]
    return ret

def _merge_stack(run_stack, merge_storage=None):
    head = run_stack.pop()
    next = run_stack.pop()
    new_run = _merge(next, head, merge_storage=merge_storage)
    run_stack.append(new_run)

def timsort(arr):
    arr = arr or []
    if len(arr) <= 0: return []
    runs = _partition_to_runs(arr)
    run_stack = []
    for run in runs:
        run_stack.append(run)
        while _should_merge(run_stack):
            _merge_stack(run_stack)
    while len(run_stack) > 1:
        _merge_stack(run_stack)
    return run_stack[0]

class Test(TestCase):
    class Elem:
        seq_no = 0
        def __init__(self, n):
            Elem = Test.Elem
            self.n = n
            self.seq_no = Elem.seq_no
            Elem.seq_no += 1

        def __lt__(self, other):
            return self.n < other.n

        def __str__(self):
            return "E" + str(self.n) + "S" + str(self.seq_no)
    Elem = functools.total_ordering(Elem)

    def setUp(self):
        Test.Elem.seq_no = 0

    def test_reverse(self):
        arr = [3, 2, 1, 4, 7, 5, 6]
        _reverse(arr)
        self.assertEquals(arr, [6, 5, 7, 4, 1, 2, 3])

        arr = [3, 2, 1]
        _reverse(arr)
        self.assertEquals(arr, [1, 2, 3])

    def test_find_asc_boundary(self):
        arr = [1, 2, 3, 3, 2]
        self.assertEqual(_find_asc_boundary(arr, 0), 4)

        arr = [1, 2, 3, 3]
        self.assertEqual(_find_asc_boundary(arr, 0), 4)

    def test_find_desc_boundary(self):
        arr = [3, 2, 1]
        self.assertEqual(_find_desc_boundary(arr, 0), 3)

        arr = [3, 2, 1, 1]
        self.assertEqual(_find_desc_boundary(arr, 0), 3)

    def test_merge_stack(self):
        arr1 = [1, 2, 3]
        arr2 = [2, 3, 4]
        stack = [arr1, arr2]
        _merge_stack(stack)
        self.assertEqual(stack, [[1, 2, 2, 3, 3, 4]])

    def test_merge_stability(self):
        Elem = Test.Elem
        arr1 = map(lambda e: Elem(e), [1, 2, 3])
        arr2 = map(lambda e: Elem(e), [2, 3, 4])
        stack = [arr1, arr2]
        _merge_stack(stack)
        self.assertEqual(map(lambda lst: map(str, lst), stack), [['E1S0', 'E2S1', 'E2S3', 'E3S2', 'E3S4', 'E4S5']])

    def test_timsort(self):
        Elem = Test.Elem
        arr = map(lambda e: Elem(e), [3, 1, 2, 2, 7, 5])
        ret = timsort(arr)
        self.assertEquals(map(str, ret), ['E1S1', 'E2S2', 'E2S3', 'E3S0', 'E5S5', 'E7S4'])

        self.assertEqual(timsort([]), [])
        self.assertEqual(timsort(None), [])
分享到:
评论

相关推荐

    cpp-TimSort:timsort的C ++实现

    4. **最小堆(MinHeap)**:在合并过程中,为了高效地找到最小元素,C++实现可能使用最小堆数据结构。这使得每次从堆顶取出最小元素变得高效。 5. **内存效率**:尽管归并排序通常需要额外的空间,但TimSort通过...

    java-timsort-bug:如何破坏 TimSort 以及如何修复它

    TimSort的核心思想是将输入数组分为多个“运行”(runs),这些运行是已排序或接近排序的子序列。然后,它通过归并这些运行来完成排序。TimSort的时间复杂度为O(n log n),并且在最佳情况下(输入已经部分排序)可以...

    深入探究TimSort对归并排序算法的优化及Java实现

    3. **交换数组角色**:在归并过程中,为了避免频繁复制元素到辅助数组,可以交换原始数组和辅助数组的角色,这样每次只需要将未处理的部分复制到辅助数组,简化了合并操作。 4. **反向序列处理**:对于反向序列,...

    Java Arrays.sort和Collections.sort排序实现原理解析

    4. **归并**:在找到所有运行段后,`ComparableTimSort`类会负责整个排序过程,包括运行段的合并,这是TimSort的核心部分。 5. **最小运行长度**:`minRunLength()`计算出最小的运行段长度,以决定何时直接使用插入...

    排序算法_java

    在编程领域,排序算法是计算机科学中的核心概念,尤其是在数据处理和算法效率分析方面。Java作为广泛应用的编程语言,提供了实现各种排序算法的平台。本文将深入探讨标题和描述中提及的五种经典排序算法——插入排序...

    数据结构课件

    - 动态查找表:数据可以在查找过程中动态添加或删除,如二分查找树和B树。 这些课件涵盖了数据结构的关键概念,从基本的线性结构(如栈和队列)到复杂的非线性结构(如树和图),再到高级的排序和查找算法。通过...

    Java常用排序算法

    插入排序是一种简单直观的排序算法,它的工作原理类似于我们日常生活中整理扑克牌的过程。算法分为两个阶段:遍历待排序的数组,将每个元素插入到已排序部分的正确位置。在Java中,可以使用两层循环实现插入排序,...

    Java排序算法详解.rar

    在编程领域,排序算法是计算机科学中的核心概念,尤其是在Java这样的高级编程语言...在深入学习这些算法的过程中,还可以探索更多关于算法分析、数据结构和复杂度理论的知识,这对于任何Java开发者来说都是宝贵的财富。

    C#排序算法详解.rar

    在编程领域,排序算法是计算机科学中的核心概念,尤其是在C#这样的高级编程语言中。排序算法是用来组织数据,使其按照特定顺序排列的程序设计技术。本资料“C#排序算法详解”聚焦于如何在C#中实现各种排序算法,帮助...

    python五大排序-五大排序算法(Python),算法数据结构

    在计算机科学中,排序算法是处理数据序列,使...在实际编程中,Python内置的`sorted()`函数和列表的`sort()`方法已经实现了高效的Timsort算法,它是基于归并排序和插入排序的一种优化算法,可以适应各种数据输入模式。

    Java面向对象思想的排序方法

    在Java编程语言中,面向对象思想是核心设计原则之一,它包括封装、继承和多态三个主要特性。本文将深入探讨如何运用面向对象思想来...在开发过程中,理解并熟练运用这些面向对象的概念,将有助于提高代码质量和效率。

    java 版数据结构教程

    时间复杂度描述了算法执行时间与输入规模的关系,而空间复杂度则表示算法运行过程中所需的内存。了解这些可以帮助我们选择合适的数据结构和优化算法。 例如,排序算法有冒泡排序、选择排序、插入排序、快速排序、...

    Python语言程序设计课教程 中英双语课件 Python中的1ADS算法-6-排序算法 共118页.pptx

    排序算法是计算机科学中的核心概念,特别是在编程语言如Python中,它们对于数据处理和分析至关重要。1ADS算法(可能是指"1st Approach to Data Structures",即数据结构的第一种方法)在本教程中可能涵盖了对排序...

    第二章 ArrayList源码解析1

    transient表示该字段不会在序列化过程中被保存。尽管如此,ArrayList还是实现了Serializable接口,因此它具有序列化能力。ArrayList通过自定义的`writeObject`和`readObject`方法来进行序列化和反序列化,而不是依赖...

    常用排序算法介绍_示例程序|排序算法_程序.rar

    6. **希尔排序**:改进的插入排序,通过比较距离较远的元素来减少排序过程中元素的移动次数。 7. **堆排序**:构造一个大顶堆或小顶堆,然后将堆顶元素与末尾元素交换,调整堆,重复此过程。 8. **计数排序**:...

    Android例子源码非第三方实现根据字母排序的城市列表

    - 开发过程中,单元测试和集成测试是必不可少的,确保排序逻辑的正确性。 - 使用Android Studio的模拟器或真实设备进行调试,检查列表的显示和交互是否符合预期。 通过这个例子,开发者不仅可以学习到如何在...

    c#各种排序算法动态图形演示-数据结构经典算法动态演示

    在IT领域,排序算法是计算机科学中的核心概念,特别是在数据结构和算法分析中。C#是一种广泛用于开发桌面、Web和移动应用的编程语言,它提供了丰富的库支持各种排序算法的实现。本资源"c#各种排序算法动态图形演示-...

    java中常见排序的所有demo,包含冒泡,选择,插入,快速,堆,希尔,二叉树等

    插入排序在实现上,通常采用in-place排序(即只需用到O(1)的额外空间的排序),因而在从后向前扫描过程中,需要反复把已排序元素逐步向后挪位,为最新元素提供插入空间。 4. **快速排序(Quick Sort)**:快速排序...

    java四大排序算法总结.zip

    在编程领域,排序算法是计算机科学中的核心概念,尤其是在数据处理和算法分析中。Java作为广泛应用的编程语言,其对排序算法的支持使得开发者能够高效地处理大量数据。本篇文章将详细探讨Java中冒泡排序、选择排序、...

    google collection

    - **TimSort**:Java 7中使用的排序算法将采用Tim Peters在Python中实现的TimSort算法。 - **特点**:稳定、自适应、迭代式合并排序,特别适合处理部分已排序的数据集。 - **优势**:在处理部分已排序数据时比...

Global site tag (gtag.js) - Google Analytics