删除排序数组中的重复项

原文链接

给定一个排序数组,你需要在原地删除重复出现的元素,使得每个元素只出现一次,返回移除后数组的新长度。

不要使用额外的数组空间,你必须在原地修改输入数组并在使用 O(1) 额外空间的条件下完成。

示例1:

给定数组 nums = [1,1,2], 

函数应该返回新的长度 2, 并且原数组 nums 的前两个元素被修改为 1, 2。 

你不需要考虑数组中超出新长度后面的元素。

示例 2:

给定 nums = [0,0,1,1,1,2,2,3,3,4],

函数应该返回新的长度 5, 并且原数组 nums 的前五个元素被修改为 0, 1, 2, 3, 4。

你不需要考虑数组中超出新长度后面的元素。

说明:

为什么返回数值是整数,但输出的答案是数组呢?

请注意,输入数组是以“引用”方式传递的,这意味着在函数里修改输入数组对于调用者是可见的。

你可以想象内部操作如下:

// nums 是以“引用”方式传递的。也就是说,不对实参做任何拷贝
int len = removeDuplicates(nums);

// 在函数里修改输入数组对于调用者是可见的。
// 根据你的函数返回的长度, 它会打印出数组中该长度范围内的所有元素。
for (int i = 0; i < len; i++) {
    print(nums[i]);
}

初步解答

从题目来看。首先题目给定条件为排序数组。其次,只需要返回最后排序后的长度。

题目中的已经给了相应的提示,就是说在不使用额外空间的情况下使用一次遍历。

那么给我的思路首先就是想到,在一次遍历的时候得到重复值与新排序值的一个交换位置。这样就能达到在遍历完成之后得到前N位为去除重复的排序 数组。

根据这个题目,首先写出第一轮的代码:

In [1]:
class Solution:
    def removeDuplicates(self, nums):
        """
        :type nums: List[int]
        :rtype: int
        """
        length = len(nums)
        if length == 0:
            return 0
        i = 1
        max_id = 0
        max_val = nums[max_id]
        while i < length:
            if max_val == nums[i]:
                i += 1
                continue
            elif nums[i] > max_val:
                max_val = nums[i]
                max_id += 1
                nums[max_id], nums[i] = nums[i], nums[max_id]
            i += 1
        return max_id + 1

这里的思路是,首先考虑边界条件(一个完整的算法应当考虑到边界条件)。在数组长度为0的情况下应当返回0,在长度为1的情况下应当返回1。

然后第一步,首先假设取到第一个值为最大值 max_val,其下标为 max_id

如果下一个值和max_val相等,则我们认为max_id + 1为我们需要替换判断的目标, 继续遍历;

直到下一个值大于max_val。则我们将max_id + 1 位置的元素与该值交换,此时max_id = max_id + 1

最后遍历完成,max_id 为排序后的长度,则我们需要返回 max_id + 1 取到正确的排序后的数组。

自定义几个测试用例:

In [2]:
a = [0, 0, 1, 1, 2, 2, 3, 4, 5]
b = []
c = [1]
d = [1, 1, 1]
In [3]:
a[:Solution().removeDuplicates(a)]
Out[3]:
[0, 1, 2, 3, 4, 5]
In [4]:
b[:Solution().removeDuplicates(b)]
Out[4]:
[]
In [5]:
c[:Solution().removeDuplicates(c)]
Out[5]:
[1]
In [6]:
d[:Solution().removeDuplicates(d)]
Out[6]:
[1]

看起来没什么问题的样子,把代码提交到 LeetCode 上也验证通过了。查看了一下效率,大概在100ms 左右, 在python提交答案里大概是中等水平。

我们自己测一下效率:

In [7]:
%%timeit 
a[:Solution().removeDuplicates(a)]
2.58 µs ± 13.3 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

再来一个针对大数组的调优, 先构建一个百万级的大数组。

In [8]:
import random

big_list = []
for i in range(0, 1000000):
    j = random.randint(0, 10)
    big_list += [i] * j
In [9]:
len(big_list)
Out[9]:
5003009

为了避免排序后的数组导致对算法排序的影响,这里需要copy一份完全一样的数组。

In [10]:
import copy

big_list_one = copy.copy(big_list)
In [11]:
%%timeit
big_list_one[:Solution().removeDuplicates(big_list_one)]
995 ms ± 8.74 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

可以看到对一个百万级的数组排序大概需要1秒左右。作为一个基准,接下来对算法进行调优。

算法调优

在调优思路上,首先遍历循环上一次循环,没毛病。那么可以调优的地方,可以从赋值、变量、判断这几个思路出发。

首先第一个思路就是,交换赋值的必要性。因为这里我们无需保留原始数据的所有数组,所以我们可以直接把 max_id + 1 的位置替换为num[i], 从而少了一个赋值操作。

对上文算法第20行进行修改。修改为 nums[max_id] = nums[i]

修改过后的代码如下:

In [12]:
class Solution:
    def removeDuplicates(self, nums):
        """
        :type nums: List[int]
        :rtype: int
        """
        length = len(nums)
        if length == 0:
            return 0
        i = 1
        max_id = 0
        max_val = nums[max_id]
        while i < length:
            if max_val == nums[i]:
                i += 1
                continue
            elif nums[i] > max_val:
                max_val = nums[i]
                max_id += 1
                nums[max_id] = nums[i]
            i += 1
        return max_id + 1

big_list_two = copy.copy(big_list)

测试一下同样对大数组的执行效率:

In [13]:
%%timeit
big_list_two[:Solution().removeDuplicates(big_list_two)]
949 ms ± 1.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

可以看到这次改动带来了差不多5%的效率提高。继续再来调优一下算法。

避免过度的优化

其次,因为条件已知是排序数组, 所以第17行的判断也是没有必要的。修改为 else:

最后,可以看到 max_val 的作用只是为了比较最大值,为了节省一个赋值操作。可以将 max_val 修改为 num[max_id]

直觉上这会给我们带来性能提升,但是很可惜,并不是这样子的。:(

In [14]:
class Solution:
    def removeDuplicates(self, nums):
        """
        :type nums: List[int]
        :rtype: int
        """
        length = len(nums)
        if length == 0:
            return 0
        i = 1
        max_id = 0
        max_val = nums[max_id]
        while i < length:
            if max_val == nums[i]:
                i += 1
                continue
            else: # 修改为 else
                max_val = nums[i]
                max_id += 1
                nums[max_id] = nums[i]
            i += 1
        return max_id + 1
    
big_list_three = copy.copy(big_list)
In [15]:
%%timeit
big_list_three[:Solution().removeDuplicates(big_list_three)]
1.05 s ± 136 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
In [16]:
class Solution:
    def removeDuplicates(self, nums):
        """
        :type nums: List[int]
        :rtype: int
        """
        length = len(nums)
        if length == 0:
            return 0
        i = 1
        max_id = 0
        while i < length:
            if nums[max_id] == nums[i]: # 修改为直接取值
                i += 1
                continue
            elif nums[i] > nums[max_id]: # 修改为直接取值
                max_id += 1
                nums[max_id] = nums[i]
            i += 1
        return max_id + 1

big_list_four = copy.copy(big_list)
In [17]:
%%timeit
big_list_four[:Solution().removeDuplicates(big_list_four)]
1.11 s ± 3.12 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

可以看到。两种改动分别给我们带来 10% 和 16% 的性能下降。

再来看看我们的两个改动,这里第一个不能理解为什么会变慢,难道else更消耗时间?

第二个在优化了一个赋值的同时,却让我们增加了更多的内存寻址时间(从nums取到max_id值)。

但是,这两个地方都没有达到我们想要的效果,并且还造成了效率降低。

这充分说明了一个问题,就是说在思考算法优化的时候,必须保证我们的优化是针对实际问题而优化的,不要自己瞎下定论。

利用 line_profiler 进行分析优化

现在我们利用工具来分析一下以下这段代码的运行。

In [18]:
!cat solution.py
import random

big_list = []
for i in range(0, 1000000):
    j = random.randint(0, 10)
    big_list += [i] * j

print("len big_list ", len(big_list))


class Solution:
    @profile
    def removeDuplicates(self, nums):
        """
        :type nums: List[int]
        :rtype: int
        """
        length = len(nums)
        if length == 0:
            return 0
        i = 1
        max_id = 0
        max_val = nums[max_id]
        while i < length:
            if max_val == nums[i]:
                i += 1
                continue
            elif nums[i] > max_val:
                max_val = nums[i]
                max_id += 1
                nums[max_id] = nums[i]
            i += 1
        return max_id + 1


print(len(big_list[:Solution().removeDuplicates(big_list)]))
In [19]:
!kernprof -l -v solution.py
len big_list  4993573
908353
Wrote profile results to solution.py.lprof
Timer unit: 1e-06 s

Total time: 9.55295 s
File: solution.py
Function: removeDuplicates at line 12

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
    12                                               @profile
    13                                               def removeDuplicates(self, nums):
    14                                                   """
    15                                                   :type nums: List[int]
    16                                                   :rtype: int
    17                                                   """
    18         1          9.0      9.0      0.0          length = len(nums)
    19         1          1.0      1.0      0.0          if length == 0:
    20                                                       return 0
    21         1          0.0      0.0      0.0          i = 1
    22         1          0.0      0.0      0.0          max_id = 0
    23         1          0.0      0.0      0.0          max_val = nums[max_id]
    24   4993573    2036572.0      0.4     21.3          while i < length:
    25   4993572    2254330.0      0.5     23.6              if max_val == nums[i]:
    26   4085220    1736952.0      0.4     18.2                  i += 1
    27   4085220    1479735.0      0.4     15.5                  continue
    28    908352     429401.0      0.5      4.5              elif nums[i] > max_val:
    29    908352     380742.0      0.4      4.0                  max_val = nums[i]
    30    908352     398103.0      0.4      4.2                  max_id += 1
    31    908352     437387.0      0.5      4.6                  nums[max_id] = nums[i]
    32    908352     399715.0      0.4      4.2              i += 1
    33         1          1.0      1.0      0.0          return max_id + 1

这里看了下,发现好像有一句本来就不是必要的 continue 被执行了 4085220 次。而这个是可以去掉的,一般情况下,较少的指令操作会带来更快的操作速度(一般情况下,根据指令的复杂程度决定执行效率)。

另外,这里 i+=1if ... else 的两个分支里都会执行到。所以我们可以尝试修改一下。改成如下样子:

while i < length:
    if nums[i] > max_val:
        max_id += 1
        max_val = nums[i]
        nums[max_id] = nums[i]
    i += 1

再次执行一下分析。

In [20]:
!kernprof -l -v solution.1.py
len big_list  4995163
908617
Wrote profile results to solution.1.py.lprof
Timer unit: 1e-06 s

Total time: 7.73485 s
File: solution.1.py
Function: removeDuplicates at line 12

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
    12                                               @profile
    13                                               def removeDuplicates(self, nums):
    14                                                   """
    15                                                   :type nums: List[int]
    16                                                   :rtype: int
    17                                                   """
    18         1          8.0      8.0      0.0          length = len(nums)
    19         1          1.0      1.0      0.0          if length == 0:
    20                                                       return 0
    21         1          0.0      0.0      0.0          i = 1
    22         1          0.0      0.0      0.0          max_id = 0
    23         1          1.0      1.0      0.0          max_val = nums[max_id]
    24   4995163    2098343.0      0.4     27.1          while i < length:
    25   4995162    2301579.0      0.5     29.8              if nums[i] > max_val:
    26    908616     393738.0      0.4      5.1                  max_id += 1
    27    908616     375304.0      0.4      4.9                  max_val = nums[i]
    28    908616     427793.0      0.5      5.5                  nums[max_id] = nums[i]
    29   4995162    2138087.0      0.4     27.6              i += 1
    30         1          1.0      1.0      0.0          return max_id + 1

可以看到,现在我们只要必要的判断处执行了 908616 次。相比之前少了 81%。

现在我们尝试一下计算上面的耗时, 看看是否有达到我们的预期效果。

In [21]:
class Solution:
    def removeDuplicates(self, nums):
        """
        :type nums: List[int]
        :rtype: int
        """
        length = len(nums)
        if length == 0:
            return 0
        i = 1
        max_id = 0
        max_val = nums[max_id]
        while i < length:
            if nums[i] > max_val:
                max_id += 1
                max_val = nums[i]
                nums[max_id] = nums[i]
            i += 1
        return max_id + 1

big_list_four = copy.copy(big_list)
In [22]:
%%timeit
big_list_four[:Solution().removeDuplicates(big_list_four)]
707 ms ± 2.66 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

可以看到这次优化,和原来的算法相比带来了 30% 的性能提升。

最后把优化过的算法提交到 LeetCode 上。可以看到大概通过测试用例是 84ms 的样子。比之前有提升,不过好像 LeetCode 好像同一个算法的执行时间会有所不同,所以这个时间看看就好。:)

其他解

这里毕竟只是一个解。接下来分析一下 在 LeetCode 上的执行耗时最短的几个解。

解一

这个执行时间是64ms 。见如下代码:

In [24]:
class SolutionSecond:
    def removeDuplicates(self, nums):
        """
        :type nums: List[int]
        :rtype: int
        """
        if not nums:
            return 0
        length = len(nums)
        if length == 1:
            return 1
        i = 0
        j = 1
        while j < length:
            if(nums[i] == nums[j]):
                j += 1
            else:
                i += 1
                nums[i] = nums[j]
                j += 1
        return i+1

big_list_five = copy.copy(big_list)

这个思路大概和我的差不多,不同的是多了几个操作。这里尝试执行一下, 看看效率如何:

In [25]:
%%timeit
big_list_five[:SolutionSecond().removeDuplicates(big_list_five)]
1.06 s ± 109 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

可以看出来,在针对大数组的执行上,这个算法的效率有所降低,和上面的算法相比,多了几步操作。具体可以参看上文解释。

这里和我们原本的第一次的算法比较,差距并不是很大。可见赋值和判断等我们直观上会带来开销的操作,反而是无足轻重的。

解二

这个解是 LeetCode 执行时间最短的,只有60ms。我们来看看他的代码:

In [26]:
class SolutionFirst:
    def removeDuplicates(self, nums):
        """
        :type nums: List[int]
        :rtype: int
        """
        nums[:] = sorted(set(nums))
        return len(nums)

big_list_six = copy.copy(big_list)

emmmmm,这里代码很简单,让我们进行第一印象的分析。

首先这里通过一个数据结构转为了字典进行去重,要知道数组转为字典的复杂度可是 O(n), 可以预计到这个算法在大数组的前提下性能可能不是很好了;

其次这里使用了 sorted 的内置函数进行排序,尽管python内部的排序也是一个混合了各种排序算法的高性能排序函数,但是这本身已经增加了算法的复杂度了。

执行一下看看同样大小的大数组,它的执行时间:

In [27]:
%%timeit
big_list_six[:SolutionFirst().removeDuplicates(big_list_six)]
74.4 ms ± 6.28 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

what ??? 这里有点不敢执行,居然差了一个数量级。难道是我的数据有问题?重新检测一下。

In [28]:
len(big_list_six) # 这里打印出这么小是因为 原数组在函数中被改变大小了
Out[28]:
909265
In [29]:
len(big_list) # 但原数据是没有问题的
Out[29]:
5003009

这和我们的预期不太一样,理论上来说,这个算法的复杂度应该要大于我们上面的解法。但事实上内置函数的执行效率上比我们快了一个数量级。

还有几点存疑?

  • 和转为set相比,这个算法的性能消耗点在哪里?
  • 为什么将list转置为set的时候,效率执行很快?什么原因导致的效率这么高?
  • sorted 函数的排序效率?

这里暂且先记下来这两个问题,等待后续解答吧。如果你有更好的解释,也可以在下面留言告诉我。

Comments !