215. Kth Largest Element in an Array
尋找第 k 大的數字,注意是排序後的第 k 個最大元素,而不是第 k 個不同的元素。解題上是經典的使用 priority_queue 的題目,但對於這個問題,還可以使用 quick select ,也是個蠻好的解法。
思路
假設 n 代表 array 的長度,而 k 就是要取第 k 個最大元素:
直接排序(Sorting)
時間複雜度是: O(nlogn) 此為最簡單的 brute force 解,就是先將陣列排序,再取出第 k 個元素。由於在同一個 array 排序,故空間複雜度為 O(1)。
Heap 資料結構
Max Heap
O(klogn)- 先把原 array 使用
heapify建成 Heap,這個操作的時間複雜度是O(n)。由於預設是 min-heap,所以可用負號技巧轉成 max-heap - 再取出 k 次堆頂元素,因為每次取出最大值都要做一次
O(logn)的調整,所以總共是O(klogn)如果不是用 heapify,而是把 n 個元素一個一個 push 進 heap,則每次 push:
O(logn), n 次就是O(nlogn) - 因為要另外空間儲存原始 array 資料,故空間複雜度為
O(n)
- 先把原 array 使用
Min Heap
O(nlogk)- 維護一個大小為 k 的 min-heap,故空間複雜度為
O(k) - 如果新元素大於或等於堆頂元素,就加入min-heap,數亮超過時要調整 heap 的大小
- 最後堆頂元素就是第 k 大
因為我們每次加入堆中的,都是較大的元素,所以最後 Min Heap 的 top,雖然是 Heap 裡面最小的元素,但它已經大於 array 中全部其他沒留在 heap 裡的元素:換句話說,它就是整個 array 中的第 k 大元素。
- 維護一個大小為 k 的 min-heap,故空間複雜度為
解答
先給 Heap 的解法:
Max Heap
class Solution:
def findKthLargest(self, nums: List[int], k: int) -> int:
nums = [-x for x in nums]
heapq.heapify(nums)
ans = 0
for _ in range(k):
ans = -heapq.heappop(nums)
return ans
heapq 預設是 min heap,技巧上可以先把 nums 的值都加上負號,這樣概念上就是轉為 max heap 了,最後答案記得再把負號轉回去。
Min Heap
class Solution:
def findKthLargest(self, nums: List[int], k: int) -> int:
heap = []
for x in nums:
if len(heap) < k:
heapq.heappush(heap, x)
elif x > heap[0]:
heapq.heapreplace(heap, x)
return heap[0]
當 heap 還沒滿 k 個時,都要把 nums[i] 放入,等到 heap 滿了之後再做比較。這裏就是使用 heapreplace 的好時機。
Quick Select Algorithm
比較特別的是,這題主要是考察 Quick Select ,可以把時間複雜度平均下降到 O(n),當然最壞的情況會是 O(n^2),會用到了快速排序 Quick Sort 的思想。
核心是每次都要先找一個「中樞點 Pivot」,然後遍歷其他所有的數字,像這道題需要從大往小排:
- 把大於中樞點的數字放到左半邊
- 把小於中樞點的放在右半邊
這樣中樞點是整個數組中第幾大的數字就確定了。如果:
- 正好是
k-1,那麼直接返回該位置上的數字 - 大於
k-1,說明要求的數字在左半部分,更新右邊界,再求新的中樞點位置 - 反之,則更新右半部分,求中樞點的位置;
class Solution:
def findKthLargest(self, nums: List[int], k: int) -> int:
n = len(nums)
target = n - k
left_index = 0
right_index = n - 1
while left_index <= right_index:
index = self.partition(nums, left_index, right_index)
if index == target:
return nums[target]
elif index > target:
right_index = index - 1
else:
left_index = index + 1
def partition(self, nums: List[int], left_index: int, right_index: int) -> int:
pivot_index = random.randint(left_index, right_index)
pivot = nums[pivot_index]
self.swap(nums, pivot_index, right_index)
i = left_index
for j in range(left_index, right_index):
if nums[j] <= pivot:
self.swap(nums, i, j)
i += 1
self.swap(nums, i, right_index)
return i
def swap(self, nums: List[int], i:int, j:int):
nums[i], nums[j] = nums[j], nums[i]
target = n - k 為什麼呢 ? 可以思考一下,nums 是小到大排序,那麼:
- 第 1 大 index 是 :
n - 1 - 第 2 大是 index 是 :
n - 2
以此類推,第 k 大 index 就是 : n - k。接下來思考一下, index > target 代表什麼意思,可以先想成是下圖:
0 1 2 3 4 5 6
|----|----|----|----|----|----|
target index
因為 index > target,如圖 index 一定落在 target 的右邊,代表答案不可能在右側,答案一定在左半邊,所以把 right_index 縮小成 index - 1
雖然 quick selection 解法很不錯,但是現在這種一般寫法是會 Time Limit Exceeded 的,原因是有 test_case 是 [1,2,3,4,5,1,....(too many 1),-5,-4,-3,-2,-1 ], array 裡有很多和 pivot 一樣的值時,會導致 partition 每次只縮小一點點,時間退化到 O(n^2)。
這時候就要稍微進階一點使用 3-way partition 技巧,partition function 修改成會回傳一個區間 [lt, gt],意義上其區間內值都是一樣的,這樣遇到很多重複值時,可以一次全部排除:
class Solution:
def findKthLargest(self, nums: List[int], k: int) -> int:
n = len(nums)
target = n - k
left_index = 0
right_index = n - 1
while left_index <= right_index:
lt, gt = self.partition(nums, left_index, right_index)
if target < lt:
right_index = lt - 1
elif target > gt:
left_index = gt + 1
else:
return nums[target]
def partition(self, nums: List[int], left_index: int, right_index: int) -> tuple[int, int]:
pivot_index = random.randint(left_index, right_index)
pivot = nums[pivot_index]
lt = left_index
i = left_index
gt = right_index
while i <= gt:
if nums[i] < pivot:
self.swap(nums, lt, i)
lt += 1
i += 1
elif nums[i] > pivot:
self.swap(nums, i, gt)
gt -= 1
else:
i += 1
return lt, gt
def swap(self, nums: List[int], i: int, j: int) -> None:
nums[i], nums[j] = nums[j], nums[i]
修正的重點是:
- 原本 partition 只回傳一個 index,現在改成回傳 tuple
(lt, gt) nums[lt:gt+1]這一段內容都是 pivot- 如果 target 落在這段裡,就直接找到答案
時間空間複雜度
假設 nums 總共有 n 個資料元素
時間複雜度: O(n)
quick select 跟 quick sort 最大差別是 quick sort 每次 partition 後,左邊要排且右邊也要排 ; 但 quick select 每次 partition 後,只會選則一邊做,所以平均為: O(n)
quick select 和 quick sort,兩個是不一樣的
空間複雜度:O(1)
若用遞迴,則平均空間 O(logn),最差是 O(n)
