
[Leetcode #4] Median of Two Sorted Arrays

Original problem: https://leetcode.com/problems/median-of-two-sorted-arrays/

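The task: given two sorted arrays nums1 (length m) and nums2 (length n), compute the median of all their elements. The algorithm must run in O(log(m+n)) time. Examples: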

nums1 = [1, 3], nums2 = [2]: the median is 2.0

nums1 = [1, 2], nums2 = [3, 4]: the median is (2 + 3) / 2 = 2.5
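Both examples use the standard definition of the median. Writing N = m + n for the total length and c_0 <= c_1 <= ... <= c_{N-1} for the merged sorted sequence (notation introduced here only for illustration):

```latex
\operatorname{median} =
\begin{cases}
c_{(N-1)/2}, & N \text{ odd},\\[4pt]
\dfrac{c_{N/2-1} + c_{N/2}}{2}, & N \text{ even}.
\end{cases}
```

For either parity, floor((N - 1) / 2) is the index of the (lower) median, which is the value the code below stores in `middle`; for even N the second value needed is simply the next element in merged order.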

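At first glance this looks easy: merge the two arrays into one and take its median. But merging costs O(m+n) time, which clearly does not meet the required O(log(m+n)).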
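For reference, the straightforward approach dismissed above can be sketched as a two-pointer merge that stops as soon as it reaches the median position. It runs in O(m + n) time, so it does not satisfy the problem's requirement, but it is handy as a brute-force oracle when testing a faster solution. The class and method names are hypothetical, and returning 0 for two empty arrays simply mirrors the convention in the solution below.

```java
public class MergeMedian {
    // O(m + n) baseline: walk both arrays like the merge step of merge sort,
    // remembering the last two values taken, and stop at index total / 2.
    public static double findMedianByMerge(int[] nums1, int[] nums2) {
        int total = nums1.length + nums2.length;
        if (total == 0) {
            return 0; // same convention as the solution below for two empty arrays
        }
        int target = total / 2; // index of the upper median (0-based)
        int i = 0, j = 0;
        int prev = 0, curr = 0;
        for (int k = 0; k <= target; k++) {
            prev = curr;
            // take from nums1 if nums2 is exhausted or nums1 has the smaller (or equal) head
            if (j >= nums2.length || (i < nums1.length && nums1[i] <= nums2[j])) {
                curr = nums1[i++];
            } else {
                curr = nums2[j++];
            }
        }
        // odd total: element at index total / 2; even total: average of indices total / 2 - 1 and total / 2
        return total % 2 != 0 ? curr : ((double) prev + curr) / 2;
    }

    public static void main(String[] args) {
        System.out.println(findMedianByMerge(new int[]{1, 3}, new int[]{2}));    // 2.0
        System.out.println(findMedianByMerge(new int[]{1, 2}, new int[]{3, 4})); // 2.5
    }
}
```

Widening to double before adding keeps the average correct even for values near the int limits.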

Logarithmic complexity naturally points to binary search. But what exactly do we binary-search over? Let's work through an example:

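nums1 = [1, 5, 7], nums2 = [2, 3, 4, 9]. There are 7 numbers in total, so we are looking for the 4th smallest. Pick any number, say the 5 in nums1. How can we tell whether it is 4th? Being 4th means exactly 3 numbers come before it. Within nums1 it is 2nd, so nums1 contributes only 1 number before it; therefore nums2 must contribute the other 2. So we compare nums2[1] and nums2[2] with 5: if nums2[1] <= 5 and nums2[2] >= 5, then 5 is the median; otherwise it is not. Unfortunately nums2[1] (3) and nums2[2] (4) are both smaller than 5, which means the number we picked is too large, and we should try a smaller number to the left of 5. This is where binary search comes in: each comparison tells us which half of nums1 to discard, and recursing on the remaining half converges step by step on the answer. If the median is not in nums1 at all (here it is 4, which lies in nums2), the same search is then run on nums2.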
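The check just described can be written as a small helper plus a binary search over it. This is a minimal sketch, assuming 0-based indices and that we want the element at index k = (m + n - 1) / 2 of the merged order; the names `classify` and `search` and the -1/0/1 return codes are hypothetical and not taken from the solution below. For a candidate nums1[i], nums2 must supply exactly k - i elements before it, i.e. nums2[k - i - 1] <= nums1[i] <= nums2[k - i] wherever those indices exist.

```java
public class RankCheck {
    // Classifies nums1[i] as a candidate for index k (0-based) of the merged order:
    //   -1 = too small (search right), 1 = too large (search left), 0 = found.
    public static int classify(int[] nums1, int[] nums2, int i, int k) {
        int need = k - i; // how many nums2 elements must come before nums1[i]
        if (need < 0) return 1;              // nums1 alone already puts more than k elements before it
        if (need > nums2.length) return -1;  // nums2 is too short to supply that many
        if (need > 0 && nums2[need - 1] > nums1[i]) return -1;      // a required predecessor is larger
        if (need < nums2.length && nums2[need] < nums1[i]) return 1; // a required successor is smaller
        return 0;
    }

    // Binary search over nums1 for an index that classifies as 0; returns -1 if none exists.
    public static int search(int[] nums1, int[] nums2, int k) {
        int lo = 0, hi = nums1.length - 1;
        while (lo <= hi) {
            int mid = lo + (hi - lo) / 2;
            int c = classify(nums1, nums2, mid, k);
            if (c == 0) return mid;
            if (c < 0) lo = mid + 1; else hi = mid - 1;
        }
        return -1;
    }

    public static void main(String[] args) {
        int[] nums1 = {1, 5, 7};
        int[] nums2 = {2, 3, 4, 9};
        int k = (nums1.length + nums2.length - 1) / 2; // 3, i.e. the 4th number
        System.out.println(classify(nums1, nums2, 1, k)); // 1: the 5 is too large
        System.out.println(search(nums1, nums2, k));      // -1: the median is not in nums1
        System.out.println(search(nums2, nums1, k));      // 2: nums2[2] = 4 is the median
    }
}
```

The search is valid because, as i grows, nums1[i] can only grow while the nums2 elements it is compared against can only shrink, so "too small" answers always precede "too large" ones.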

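Many people online have described this idea, but most implementations are full of holes: array indices out of bounds, wrong results when the arrays contain duplicate elements, overshooting the search range until the program crashes, and so on. The root cause is that the various edge cases are not handled thoroughly. I spent some time putting together code that passes the LeetCode tests: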

public class Solution {
    public double findMedianSortedArrays(int[] nums1, int[] nums2) {
        // handle boundary cases: one or both arrays empty
        if (nums1.length == 0 && nums2.length == 0) {
            return 0;
        } else if (nums1.length == 0) {
            return calculateMedian(nums2);
        } else if (nums2.length == 0) {
            return calculateMedian(nums1);
        }
        
        return findMedian(nums1, nums2, 0, nums1.length - 1);
    }
    
    private double findMedian(int[] nums1, int[] nums2, int start, int end) {
        int middle = (nums1.length + nums2.length - 1) / 2; 
        int current = (start + end) / 2;
        
        // if median not in nums1, search in nums2
        if (start > end) {
            return findMedian(nums2, nums1, 0, nums2.length - 1);
        }
        
        // current is past the median index (overshot), search the left half
        // e.g. [1, 2, 3, 5, 6, 7] [4], middle is 3, current is (3+5)/2 = 4
        if (middle < current) {
            return findMedian(nums1, nums2, start, current - 1);
        }
        
        // not enough small numbers in nums2, need to increase current
        if (middle - current > nums2.length) {
            return findMedian(nums1, nums2, current + 1, end);
        }
        
        // hit condition: all numbers less than median are in nums1
        if (middle == current) {
            if (nums1[current] <= nums2[0]) {
                // Bingo! Found the median here
                return calculateMedian(nums1, nums2, current, 0);
            } else {
                // current value too large, search the left half
                return findMedian(nums1, nums2, start, current - 1);
            }
        }
        
        // hit condition: all numbers in nums2 come before the median
        if (middle - current == nums2.length) {
            if (nums1[current] >= nums2[nums2.length - 1]) {
                // Bingo! Found the median here. A tie with nums2's last value still counts
                // as "before the median", so nums2 is used up: -1 means the next value lies in nums1
                return calculateMedian(nums1, nums2, current, -1);
            } else {
                // current value too small, search the right half
                return findMedian(nums1, nums2, current + 1, end);
            }
        }
        
        // hit condition: has "middle - current" numbers in nums2 not greater than current
        if (nums1[current] >= nums2[middle-current-1] && nums1[current] <= nums2[middle-current]) {
            // Bingo! Found the median here
            return calculateMedian(nums1, nums2, current, middle-current);
        } else if (nums1[current] < nums2[middle-current-1]) {
            // current value too small, search the right half
            return findMedian(nums1, nums2, current + 1, end);
        } else {
            // current value too large, search the left half
            return findMedian(nums1, nums2, start, current - 1);
        }
    }
    
    private double calculateMedian(int[] nums) {
        int m = nums.length / 2;
        // widen to double before adding so two large ints cannot overflow
        return (nums.length % 2 != 0 ? nums[m] : ((double) nums[m-1] + nums[m]) / 2);
    }
    
    private double calculateMedian(int[] nums1, int[] nums2, int current, int nextInNums2) {
        int totalLen = nums1.length + nums2.length;
        if (totalLen % 2 != 0) {
            return nums1[current];
        } else if (current < nums1.length - 1) {
            int next = nextInNums2 < 0 ? nums1[current+1] : Math.min(nums1[current+1], nums2[nextInNums2]);
            return ((double) nums1[current] + next) / 2;
        } else {
            return ((double) nums1[current] + nums2[nextInNums2]) / 2;
        }
    }
}
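For comparison, here is a sketch of a different, widely used formulation of the same binary-search idea (not the code above): instead of testing one candidate element at a time, it binary-searches a cut position in the shorter array so that the left parts of both arrays together hold (m + n + 1) / 2 elements, and it uses Integer.MIN_VALUE / Integer.MAX_VALUE as sentinels when a cut falls at an array end. The sentinels remove most of the out-of-bounds special cases, and searching the shorter array gives O(log(min(m, n))) time. The class name is hypothetical; returning 0 for two empty arrays mirrors the convention above.

```java
public class PartitionMedian {
    public static double findMedianSortedArrays(int[] a, int[] b) {
        // Search over the shorter array so the matching cut in the longer one stays in range.
        if (a.length > b.length) {
            int[] t = a; a = b; b = t;
        }
        int m = a.length, n = b.length;
        if (m + n == 0) {
            return 0; // same convention as the solution above
        }
        int half = (m + n + 1) / 2; // size of the combined left part
        int lo = 0, hi = m;         // i = number of elements taken from a
        while (lo <= hi) {
            int i = lo + (hi - lo) / 2;
            int j = half - i;       // number of elements taken from b
            // Sentinels stand in for "nothing on this side of the cut".
            int aLeft  = i == 0 ? Integer.MIN_VALUE : a[i - 1];
            int aRight = i == m ? Integer.MAX_VALUE : a[i];
            int bLeft  = j == 0 ? Integer.MIN_VALUE : b[j - 1];
            int bRight = j == n ? Integer.MAX_VALUE : b[j];
            if (aLeft > bRight) {
                hi = i - 1;         // took too many from a
            } else if (bLeft > aRight) {
                lo = i + 1;         // took too few from a
            } else {
                int leftMax = Math.max(aLeft, bLeft);
                if ((m + n) % 2 != 0) {
                    return leftMax; // odd total: the left part holds the median as its largest value
                }
                int rightMin = Math.min(aRight, bRight);
                return ((double) leftMax + rightMin) / 2; // widen before adding to avoid int overflow
            }
        }
        throw new IllegalArgumentException("input arrays must be sorted");
    }

    public static void main(String[] args) {
        System.out.println(findMedianSortedArrays(new int[]{1, 3}, new int[]{2}));             // 2.0
        System.out.println(findMedianSortedArrays(new int[]{1, 2}, new int[]{3, 4}));          // 2.5
        System.out.println(findMedianSortedArrays(new int[]{1, 5, 7}, new int[]{2, 3, 4, 9})); // 4.0
        System.out.println(findMedianSortedArrays(new int[]{2, 5, 6}, new int[]{2}));          // 3.5
    }
}
```

The last call exercises duplicates across the two arrays: the merged order is [2, 2, 5, 6], so the expected median is (2 + 5) / 2 = 3.5.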