Skip to content

Commit

Permalink
UPDATED: Multinomial distribution rand does not support Float32 proba…
Browse files Browse the repository at this point in the history
…bility vectors (#1738)

* Allow probability vectors other than Float64 vector for Multinomial distributions.
Tests still pass.

* Remove explicit concrete types from multinom_rand

* Added test to make sure rand works for real types other than Float64

* Removed extra test

* updated tests for type stability

* same test but not using internal function

* repeat tests under different types instead

* type stability within multinomial. may want to extend into Binomial

* Apply suggestions from code review

Type conversions implicitly at runtime covers more systems better

Co-authored-by: David Widmann <[email protected]>

* delete leftover type variable

* Simplify tests

---------

Co-authored-by: Kyle Daruwalla <[email protected]>
Co-authored-by: David Widmann <[email protected]>
Co-authored-by: David Widmann <[email protected]>
  • Loading branch information
4 people committed Jun 30, 2023
1 parent b9d3093 commit 55d4d6e
Show file tree
Hide file tree
Showing 2 changed files with 202 additions and 198 deletions.
7 changes: 4 additions & 3 deletions src/samplers/multinomial.jl
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
function multinom_rand!(rng::AbstractRNG, n::Int, p::AbstractVector{Float64},
function multinom_rand!(rng::AbstractRNG, n::Int, p::AbstractVector{<:Real},
x::AbstractVector{<:Real})
k = length(p)
length(x) == k || throw(DimensionMismatch("Invalid argument dimension."))

rp = 1.0 # remaining total probability
z = zero(eltype(p))
rp = oftype(z + z, 1) # remaining total probability (widens type if needed)
i = 0
km1 = k - 1

while i < km1 && n > 0
i += 1
@inbounds pi = p[i]
if pi < rp
xi = rand(rng, Binomial(n, pi / rp))
xi = rand(rng, Binomial(n, Float64(pi / rp)))
@inbounds x[i] = xi
n -= xi
rp -= pi
Expand Down
Loading

0 comments on commit 55d4d6e

Please sign in to comment.