统计与概率相关算法

常用算法与数据结构模板系列(五)

Posted by John Mactavish on September 12, 2020

用 Rand7() 实现 Rand10()

如果我们有 Rand10N() 的话可以通过对 10 除取余很容易地实现 Rand10(), 所以可以自然地想到把 Rand7() 向更大的 rand 映射。一个朴素的想法是进行两次 Rand7() 然后 把结果相加或相乘得到 rand of [2,14] 或者 rand of [1,49],但是仔细想想这样得到的数 并不在区间内均匀出现。然后我们可以想到一个一一映射操作:元组构建,即 a op b -> (a,b)。 元组又可以映射回整数(hash 码):hash of (a,b) is a*weight+b,其中 weight 是权值, 等于元组中该位置所允许出现的数字的总数目。所以进行操作 (Rand7() - 1) * 7 + Rand7() - 1 即可 得到 rand of [0,48],减一是为了让随机数从 0 开始并连续分布。但 [0,48] 不是 [0,9] 的整数倍, 但是没关系,我们可以消去比整数倍多出的部分,如果生成的随机数大于 39,直接拒绝采样,重新生成

int rand10() {
    int n = (rand7() - 1) * 7 + rand7() - 1;
    if (n >= 40) {
        return rand10();
    }
    return n % 10 + 1;
}

进一步地,我们可以进行一些优化。利用范围外的数字构建新的 rand,以减少丢弃的值, 提高命中率从而提高随机数生成效率。

int rand10() {
    while(true) {
        int a = rand7();
        int b = rand7();
        int num = (a-1)*7 + b; // rand 49
        if(num <= 40) return num % 10 + 1; 
        
        a = num - 40; // rand 9
        b = rand7();
        num = (a-1)*7 + b; // rand 63
        if(num <= 60) return num % 10 + 1;
        
        a = num - 60; // rand 3
        b = rand7();
        num = (a-1)*7 + b; // rand 21
        if(num <= 20) return num % 10 + 1;
    }
}

蓄水池抽样

蓄水池抽样解决了一个这样的问题:给定一个数据流,数据流长度 N 很大(不能一次性存入内存), 且 N 直到处理完所有数据之前都不可知(比如链表形式), 请问如何在只遍历一遍数据的情况下,能够随机选取出 m 个不重复的数据(每个数被选中的概率为 m/N)。

证明可以看这里,代码模板如下:

int[] reservoir = new int[m];

// init
for (int i = 0; i < reservoir.length; i++)
{
    reservoir[i] = dataStream[i];
}

for (int i = m; i < dataStream.length; i++)
{
    // 随机获得一个[0, i]内的随机整数
    int d = rand.nextInt(i + 1);
    // 如果随机整数落在[0, m-1]范围内,则替换蓄水池中的元素
    if (d < m)
    {
        reservoir[d] = dataStream[i];
    }
}

模板中遍历 dataStream 时是知道长度的,实际上这样是可以直接用 rand.nextInt(dataStream.length) 选数的。 真实例子中 dataStream 要么是链表,要么是下面这种形式:

// 给定一个可能含有重复元素的整数数组,要求随机输出给定的数字的索引。 您可以假设给定的数字一定存在于数组中。

// 注意:
// 数组大小可能非常大。 使用太多额外空间的解决方案将不会通过测试。

// 示例:

// int[] nums = new int[] {1,2,3,3,3};
// Solution solution = new Solution(nums);

// // pick(3) 应该返回索引 2,3 或者 4。每个索引的返回概率应该相等。
// solution.pick(3);

// // pick(1) 应该返回 0。因为只有nums[0]等于1。
// solution.pick(1);

import java.util.Random;

class Solution { 
    private final int[] nums;
    private final Random rand = new Random();

    public Solution(int[] nums) {
        this.nums = nums;
    }

    // nums 长度已知,但是 target 的数目未知
    // N = nums.filter(x => x==target).size
    public int pick(int target) {
        int cnt = 0, ret = -1;
        for (int i = 0; i < nums.length; i++) {
            if (nums[i] != target) { 
                continue;
            }

            cnt++;
            if (rand.nextInt(cnt) == 0) { // 抽样的 m = 1
                ret = i;
            }
        }
        return ret;
    }
}

/**
 * Your Solution object will be instantiated and called as such:
 * Solution obj = new Solution(nums);
 * int param_1 = obj.pick(target);
 */

洗牌算法

对于洗牌问题:从一个有限集合生成一个随机排列(数组随机排序),一般存在着两种算法。

Fisher-Yates Shuffle 是比较通俗的算法:

  1. 初始化原始数组和新数组,原始数组长度为 n(已知);
  2. 从还没处理的数组(假如还剩 k 个)中,随机产生一个 [0, k) 之间的数字 p(假设数组从 0 开始);
  3. 从剩下的 k 个数中把第 p 个数取出添加到新数组最后;
  4. 重复步骤 2 和 3 直到数字全部取完;
  5. 从步骤 3 取出的数字序列便是一个打乱了的数列。

问题是时间复杂度为 O(n^2),因为 list.remove(p) 操作是线性时间的,总共发生 n 次。

改进算法 Knuth-Durstenfeld Shuffle 在原始数组(或拷贝数组)上对数字进行交换, 时间复杂度优化到 O(n),是最推荐的洗牌算法。模板如下:

int[] shuffle(int[] nums) {
    int n = nums.length;
    var res = Arrays.copyOf(nums, n); // 如果可以修改原数组就不用拷贝
    Random rand = new Random();
    for (int i = n - 1; i >= 0; i--) { // 也可以从前向后扫描,但是要修改 rand
        int tmp = res[i];
        // 每次都从闭区间 [0, i] 中随机选取元素进行交换
        // 注意包括 i,即它可能与自己交换
        int id = rand.nextInt(i + 1); 
        res[i] = res[id];
        res[id] = tmp;
    }
    return res;
}

分析洗牌算法正确性的必要条件:产生的结果必须有 n! 种可能,否则就是错误的。 因为一个长度为 n 的数组的全排列就有 n! 种,也就是说打乱结果总共有 n! 种。 算法必须能够反映这个事实,才是正确的。依此可以说明下面的写法是错误的, 因为这种写法得到的所有可能结果有 n^n 种,而不是 n! 种,而 n^n 不是 n! 的整数倍。

void shuffle(int[] arr) {
    int n = arr.length();
    for (int i = 0 ; i < n; i++) {
        // 每次都从闭区间 [0, n-1] 中随机选取元素进行交换
        int rand = randInt(0, n - 1);
        swap(arr[i], arr[rand]);
    }
}

蒙特卡罗

蒙特卡罗方法(Monte Carlo method),也称统计模拟方法,是 1940 年代中期由于科学技术的发展和电子计算机的发明, 而提出的一种以概率统计理论为指导的数值计算方法。是指使用随机数(或更常见的伪随机数)来解决很多计算问题的方法。 通常可以用来验证概率相关算法的正确性。比如高中有道数学题:往一个正方形里面随机打点,这个正方形里紧贴着一个圆, 告诉你打点的总数和落在圆里的点的数量,让你计算圆周率。用蒙特卡罗验证上面的洗牌算法的正确性如下:

void shuffle(int[] arr);

int N = 1000000; // 重复实验 1000000 次
HashMap count; // 作为直方图
for (i = 0; i < N; i++) {
    int[] arr = {1,2,3};
    shuffle(arr);
    // 此时 arr 已被打乱
    count[arr] += 1;
}
for (int feq : count.values()) 
    print(feq / N + " "); // 频率应当相近

参考:

labuladong 的 LeetCode 解答

kkbill 的 LeetCode 解答

最后附上 GitHub:https://github.com/gonearewe