As mentioned in the previous post, sampling from discrete distributions is tricky. Here we demonstrate yet another important trick for doing it right, whether you work in the original space or in log-space. A first attempt usually looks something like this (we use Java as the example):
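    import java.util.concurrent.ThreadLocalRandom;

    public int sample() {
        // p[p.length - 1] is the total un-normalized mass, so u lies in [0, total).
        double u = ThreadLocalRandom.current().nextDouble() * p[p.length - 1];
        int index;
        // Linear scan: stop at the first cumulative value that exceeds u.
        for (index = 0; index < p.length; index++) {
            if (u < p[index])
                break;
        }
        return index;
    }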
where \( p \) holds the accumulated (cumulative) un-normalized probabilities. The time complexity is \( \mathcal{O}(N) \), where \( N \) is the number of items in the array \( p \).
It turns out that the linear scan above can easily be optimized to \( \mathcal{O}(\log N) \) by using binary search. The reason is simple: the accumulated un-normalized probabilities stored in \( p \) are, by definition, sorted in non-decreasing order, so binary search applies. In particular, we want to find the smallest key that is greater than the randomly generated number \( u \). This operation is called ceiling in Section 3.1 of Algorithms. We implement it in our context as follows:
    import java.util.concurrent.ThreadLocalRandom;

    public int sample() {
        double u = ThreadLocalRandom.current().nextDouble() * p[p.length - 1];
        int lower = 0;
        int upper = p.length - 1;
        // Invariant: p[i] <= u for all i < lower, and p[i] > u for all i > upper.
        while (lower <= upper) {
            int mid = lower + (upper - lower) / 2;
            if (p[mid] > u) {
                upper = mid - 1;
            } else {
                lower = mid + 1;
            }
        }
        // lower is now the smallest index whose cumulative value exceeds u.
        return lower;
    }
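To see how the pieces fit together, here is a minimal self-contained sketch that builds the cumulative array \( p \) from raw weights and wraps the binary search. The class name `CumulativeSampler`, the method name `ceilingIndex`, and the constructor taking raw weights are illustrative assumptions and not part of the code above. Separating the search from the random draw is just a convenience that lets the search be checked with a fixed \( u \). The sketch assumes non-negative weights with a positive sum.

```java
import java.util.concurrent.ThreadLocalRandom;

// Hypothetical wrapper class; names are illustrative only.
final class CumulativeSampler {
    // p[i] = weights[0] + ... + weights[i] (cumulative, un-normalized).
    private final double[] p;

    // Assumes weights are non-negative and sum to a positive value.
    public CumulativeSampler(double[] weights) {
        p = new double[weights.length];
        double running = 0.0;
        for (int i = 0; i < weights.length; i++) {
            running += weights[i];
            p[i] = running;
        }
    }

    // Returns the smallest index i with p[i] > u.
    // Assumes 0 <= u < p[p.length - 1].
    public int ceilingIndex(double u) {
        int lower = 0;
        int upper = p.length - 1;
        while (lower <= upper) {
            int mid = lower + (upper - lower) / 2;
            if (p[mid] > u) {
                upper = mid - 1;
            } else {
                lower = mid + 1;
            }
        }
        return lower;
    }

    // Draws one index with probability proportional to its weight.
    public int sample() {
        double u = ThreadLocalRandom.current().nextDouble() * p[p.length - 1];
        return ceilingIndex(u);
    }
}
```

For example, the weights {1, 2, 3, 4} give the cumulative array {1, 3, 6, 10}, so `ceilingIndex(2.9)` yields 1 and `ceilingIndex(3.0)` yields 2. Because the comparison is strict, an item with zero weight is never returned.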
Interestingly, even though this trick seems trivial, it is rarely mentioned in the literature. The only discussion we are aware of is in:
- Hsiang-Fu Yu, Cho-Jui Hsieh, Hyokun Yun, S.V.N. Vishwanathan, and Inderjit S. Dhillon. 2015. A Scalable Asynchronous Distributed Algorithm for Topic Modeling. In Proceedings of the 24th International Conference on World Wide Web (WWW ’15). ACM, New York, NY, USA, 1340-1350.
In fact, sampling from discrete distributions can be performed in worst-case constant time per sample, after a one-time preprocessing step, using the alias method: https://en.wikipedia.org/wiki/Alias_method
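For comparison with the binary search approach, here is a rough sketch of the alias method using one common table construction (often attributed to Vose). The class name `AliasSampler` and the helper `probabilityOf` are illustrative assumptions, and the sketch assumes non-negative weights with a positive sum. Each of the \( N \) columns stores a probability of keeping its own index and an alias index to fall back to, so a draw needs only one random column and one random coin.

```java
import java.util.ArrayDeque;
import java.util.concurrent.ThreadLocalRandom;

// Hypothetical class; a sketch of the alias method, not a tuned implementation.
final class AliasSampler {
    private final double[] prob; // prob[i]: probability of keeping column i
    private final int[] alias;   // alias[i]: outcome used when column i is rejected

    // Assumes weights are non-negative and sum to a positive value.
    public AliasSampler(double[] weights) {
        int n = weights.length;
        double total = 0.0;
        for (double w : weights) total += w;

        prob = new double[n];
        alias = new int[n];
        double[] scaled = new double[n]; // weights rescaled so the average is 1
        ArrayDeque<Integer> small = new ArrayDeque<>();
        ArrayDeque<Integer> large = new ArrayDeque<>();
        for (int i = 0; i < n; i++) {
            scaled[i] = weights[i] * n / total;
            if (scaled[i] < 1.0) small.push(i); else large.push(i);
        }
        // Pair each under-full column with an over-full one that tops it up.
        while (!small.isEmpty() && !large.isEmpty()) {
            int s = small.pop();
            int l = large.pop();
            prob[s] = scaled[s];
            alias[s] = l;
            scaled[l] = (scaled[l] + scaled[s]) - 1.0;
            if (scaled[l] < 1.0) small.push(l); else large.push(l);
        }
        // Leftovers are exactly full up to rounding error.
        while (!large.isEmpty()) { int i = large.pop(); prob[i] = 1.0; alias[i] = i; }
        while (!small.isEmpty()) { int i = small.pop(); prob[i] = 1.0; alias[i] = i; }
    }

    // Deterministic core: column in [0, N), coin in [0, 1).
    public int sample(int column, double coin) {
        return coin < prob[column] ? column : alias[column];
    }

    // Constant-time draw: one uniform column plus one uniform coin.
    public int sample() {
        ThreadLocalRandom rng = ThreadLocalRandom.current();
        return sample(rng.nextInt(prob.length), rng.nextDouble());
    }

    // Probability of an outcome implied by the tables (for checking only).
    public double probabilityOf(int outcome) {
        double mass = 0.0;
        for (int i = 0; i < prob.length; i++) {
            if (i == outcome) mass += prob[i];
            if (alias[i] == outcome) mass += 1.0 - prob[i];
        }
        return mass / prob.length;
    }
}
```

The trade-off is that the tables must be rebuilt whenever the weights change, whereas the cumulative array used with binary search is simpler to construct.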