Find K Pairs with Smallest Sums
Among all pairs (a, b) from two sorted arrays, return the k with the smallest sum — heap over the implicit grid.
7 min read
k-way-merge heaps
Problem#
You are given two integer arrays nums1 and nums2, each sorted in ascending order, and an integer k. Return the k pairs (nums1[i], nums2[j]) with the smallest sums nums1[i] + nums2[j].
Examples#
nums1 = [1,7,11], nums2 = [2,4,6], k = 3 → [[1,2],[1,4],[1,6]] (sums 3, 5, 7; the next-smallest pair would be [7,2] with sum 9)
Constraints#
1 <= nums1.length, nums2.length <= 10^4.
Hints#
Hint 1
Treat the grid of all pairs as an implicit graph. Start at (0,0); each popped (i,j) unlocks (i+1,j) and (i,j+1). Avoid duplicates with a visited set.
Approach#
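Keep a min-heap of (sum, i, j) entries, seeded with cell (0, 0). Each pop yields the next-smallest pair. After popping (i, j), push its neighbours (i+1, j) and (i, j+1) if they are inside the grid and have not been visited yet. Stop after k pops or when the heap runs empty.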
Solution#
class Solution {
public:
    // Returns the k pairs (a[i], b[j]) with the smallest sums, in non-decreasing
    // order of sum; a and b must be sorted ascending and non-empty.
    //
    // Cells of the (i, j) index grid are popped from a min-heap of
    // (sum, i, j) starting at (0, 0). Popping a cell offers (i + 1, j) and
    // (i, j + 1); `seen` stops a cell from being pushed twice. Sums are
    // computed in long long because a[i] + b[j] can leave the int range
    // when the values are large in magnitude.
    vector<vector<int>> kSmallestPairs(vector<int>& a, vector<int>& b, int k) {
        using Entry = tuple<long long, int, int>;  // (sum, i, j)
        priority_queue<Entry, vector<Entry>, greater<>> pq;
        set<pair<int, int>> seen;

        // Queues cell (i, j) unless it was queued before.
        auto offer = [&](int i, int j) {
            if (seen.insert({i, j}).second) {
                pq.emplace(static_cast<long long>(a[i]) + b[j], i, j);
            }
        };

        vector<vector<int>> ans;
        offer(0, 0);
        while (k-- > 0 && !pq.empty()) {
            auto [sum, i, j] = pq.top();
            pq.pop();
            ans.push_back({a[i], b[j]});
            if (i + 1 < (int)a.size()) offer(i + 1, j);
            if (j + 1 < (int)b.size()) offer(i, j + 1);
        }
        return ans;
    }
};
import "container/heap"
type pairItem struct{ sum, i, j int }type pairHeap []pairItem
func (h pairHeap) Len() int { return len(h) }func (h pairHeap) Less(i, j int) bool { return h[i].sum < h[j].sum }func (h pairHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] }func (h *pairHeap) Push(x interface{}) { *h = append(*h, x.(pairItem)) }func (h *pairHeap) Pop() interface{} { old := *h n := len(old) x := old[n-1] *h = old[:n-1] return x}
func kSmallestPairs(a []int, b []int, k int) [][]int { pq := &pairHeap{} seen := map[[2]int]bool{} heap.Push(pq, pairItem{a[0] + b[0], 0, 0}) seen[[2]int{0, 0}] = true ans := [][]int{} for k > 0 && pq.Len() > 0 { x := heap.Pop(pq).(pairItem) ans = append(ans, []int{a[x.i], b[x.j]}) if x.i+1 < len(a) && !seen[[2]int{x.i + 1, x.j}] { heap.Push(pq, pairItem{a[x.i+1] + b[x.j], x.i + 1, x.j}) seen[[2]int{x.i + 1, x.j}] = true } if x.j+1 < len(b) && !seen[[2]int{x.i, x.j + 1}] { heap.Push(pq, pairItem{a[x.i] + b[x.j+1], x.i, x.j + 1}) seen[[2]int{x.i, x.j + 1}] = true } k-- } return ans}import heapqfrom typing import List
class Solution: def kSmallestPairs(self, nums1: List[int], nums2: List[int], k: int) -> List[List[int]]: pq = [(nums1[0] + nums2[0], 0, 0)] seen = {(0, 0)} ans: List[List[int]] = [] while k > 0 and pq: s, i, j = heapq.heappop(pq) ans.append([nums1[i], nums2[j]]) if i + 1 < len(nums1) and (i + 1, j) not in seen: heapq.heappush(pq, (nums1[i + 1] + nums2[j], i + 1, j)) seen.add((i + 1, j)) if j + 1 < len(nums2) and (i, j + 1) not in seen: heapq.heappush(pq, (nums1[i] + nums2[j + 1], i, j + 1)) seen.add((i, j + 1)) k -= 1 return ansclass MinHeap { constructor() { this.h = []; } size() { return this.h.length; } push(x) { this.h.push(x); let i = this.h.length - 1; while (i > 0) { const p = (i - 1) >> 1; if (this.h[i][0] < this.h[p][0]) { [this.h[i], this.h[p]] = [this.h[p], this.h[i]]; i = p; } else break; } } pop() { const top = this.h[0]; const last = this.h.pop(); if (this.h.length > 0) { this.h[0] = last; let i = 0; const n = this.h.length; while (true) { const l = i * 2 + 1, r = i * 2 + 2; let best = i; if (l < n && this.h[l][0] < this.h[best][0]) best = l; if (r < n && this.h[r][0] < this.h[best][0]) best = r; if (best === i) break; [this.h[i], this.h[best]] = [this.h[best], this.h[i]]; i = best; } } return top; }}
function kSmallestPairs(nums1, nums2, k) { const pq = new MinHeap(); const seen = new Set(); pq.push([nums1[0] + nums2[0], 0, 0]); seen.add('0,0'); const ans = []; while (k > 0 && pq.size() > 0) { const [s, i, j] = pq.pop(); ans.push([nums1[i], nums2[j]]); if (i + 1 < nums1.length) { const key = `${i + 1},${j}`; if (!seen.has(key)) { pq.push([nums1[i + 1] + nums2[j], i + 1, j]); seen.add(key); } } if (j + 1 < nums2.length) { const key = `${i},${j + 1}`; if (!seen.has(key)) { pq.push([nums1[i] + nums2[j + 1], i, j + 1]); seen.add(key); } } k--; } return ans;}class Solution { public List<List<Integer>> kSmallestPairs(int[] nums1, int[] nums2, int k) { PriorityQueue<int[]> pq = new PriorityQueue<>((x, y) -> x[0] - y[0]); Set<Long> seen = new HashSet<>(); pq.offer(new int[]{nums1[0] + nums2[0], 0, 0}); seen.add(0L); List<List<Integer>> ans = new ArrayList<>(); while (k > 0 && !pq.isEmpty()) { int[] cur = pq.poll(); int i = cur[1], j = cur[2]; ans.add(Arrays.asList(nums1[i], nums2[j])); if (i + 1 < nums1.length) { long key = (long)(i + 1) * 100000L + j; if (seen.add(key)) { pq.offer(new int[]{nums1[i + 1] + nums2[j], i + 1, j}); } } if (j + 1 < nums2.length) { long key = (long)i * 100000L + (j + 1); if (seen.add(key)) { pq.offer(new int[]{nums1[i] + nums2[j + 1], i, j + 1}); } } k--; } return ans; }}class MinHeap<T extends [number, ...unknown[]]> { private h: T[] = []; size(): number { return this.h.length; } push(x: T): void { this.h.push(x); let i = this.h.length - 1; while (i > 0) { const p = (i - 1) >> 1; if (this.h[i][0] < this.h[p][0]) { [this.h[i], this.h[p]] = [this.h[p], this.h[i]]; i = p; } else break; } } pop(): T { const top = this.h[0]; const last = this.h.pop()!; if (this.h.length > 0) { this.h[0] = last; let i = 0; const n = this.h.length; while (true) { const l = i * 2 + 1, r = i * 2 + 2; let best = i; if (l < n && this.h[l][0] < this.h[best][0]) best = l; if (r < n && this.h[r][0] < this.h[best][0]) best = r; if (best === i) break; [this.h[i], 
this.h[best]] = [this.h[best], this.h[i]]; i = best; } } return top; }}
/**
 * Returns the k pairs [nums1[i], nums2[j]] with the smallest sums, in
 * non-decreasing order of sum. nums1 and nums2 must be sorted ascending and
 * non-empty. Grid cells are explored best-first from (0, 0) via a MinHeap of
 * [sum, i, j]; a set of "i,j" keys keeps any cell from being queued twice.
 */
function kSmallestPairs(nums1: number[], nums2: number[], k: number): number[][] {
  const pq = new MinHeap<[number, number, number]>();
  const seen = new Set<string>();
  const ans: number[][] = [];

  // Queues cell (i, j) unless it was queued before.
  const offer = (i: number, j: number): void => {
    const key = `${i},${j}`;
    if (seen.has(key)) return;
    seen.add(key);
    pq.push([nums1[i] + nums2[j], i, j]);
  };

  offer(0, 0);
  for (let left = k; left > 0 && pq.size() > 0; left--) {
    const [, i, j] = pq.pop();
    ans.push([nums1[i], nums2[j]]);
    if (i + 1 < nums1.length) offer(i + 1, j);
    if (j + 1 < nums2.length) offer(i, j + 1);
  }
  return ans;
}
Editorial#
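The pair grid behaves like a multi-way merge of sorted lists. Row i is a[i] + b[0..n-1], which is ascending because b is sorted; likewise, every column is ascending because a is sorted. The search starts from (0,0) and expands only right and down, to (i, j+1) and (i+1, j). Because of this, the heap pops cells in non-decreasing order of sum. Every unvisited cell lies on a monotone path from (0,0) that passes through some cell currently in the heap, and sums never decrease along such a path. So the heap minimum is always the smallest remaining sum. The visited set prevents duplicate pushes, since a cell (i+1, j+1) can be reached from both (i+1, j) and (i, j+1).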
Complexity#
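- Time: O(k log k). There are at most k pops, and each pop pushes at most two cells, so the heap and visited set hold O(k) entries.
- Space: O(k) for the heap and visited set, excluding the output.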
Concept revision#
Related#