Skip to content

Commit

Permalink
Merge pull request numpy#24220 from WarrenWeckesser/dirichlet-zeros
Browse files Browse the repository at this point in the history
BUG: random: Fix generation of nan by dirichlet.
  • Loading branch information
charris committed Aug 3, 2023
2 parents 4237991 + c27678e commit 09dac65
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 17 deletions.
29 changes: 18 additions & 11 deletions numpy/random/_generator.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -4408,6 +4408,7 @@ cdef class Generator:
np.NPY_ARRAY_ALIGNED | np.NPY_ARRAY_C_CONTIGUOUS)
if np.any(np.less(alpha_arr, 0)):
raise ValueError('alpha < 0')

alpha_data = <double*>np.PyArray_DATA(alpha_arr)

if size is None:
Expand Down Expand Up @@ -4467,17 +4468,23 @@ cdef class Generator:
csum += alpha_data[j]
alpha_csum_data[j] = csum

with self.lock, nogil:
while i < totsize:
acc = 1.
for j in range(k - 1):
v = random_beta(&self._bitgen, alpha_data[j],
alpha_csum_data[j + 1])
val_data[i + j] = acc * v
acc *= (1. - v)
val_data[i + k - 1] = acc
i = i + k

# If csum == 0, then all the values in alpha are 0, and there is
# nothing to do, because diric was created with np.zeros().
if csum > 0:
with self.lock, nogil:
while i < totsize:
acc = 1.
for j in range(k - 1):
v = random_beta(&self._bitgen, alpha_data[j],
alpha_csum_data[j + 1])
val_data[i + j] = acc * v
acc *= (1. - v)
if alpha_csum_data[j + 1] == 0:
# v must be 1, so acc is now 0. All
# remaining elements will be left at 0.
break
val_data[i + k - 1] = acc
i = i + k
else:
# Standard case: Unit normalisation of a vector of gamma random
# variates
Expand Down
24 changes: 22 additions & 2 deletions numpy/random/tests/test_generator_mt19937.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
},
]


@pytest.fixture(scope='module', params=[True, False])
def endpoint(request):
    """Provide each parametrized flag value, ``True`` then ``False``.

    The fixture is module-scoped, so every test that requests it is run
    once per value in ``params``.
    """
    flag = request.param
    return flag
Expand Down Expand Up @@ -145,6 +146,7 @@ def test_multinomial_pvals_float32(self):
with pytest.raises(ValueError, match=match):
random.multinomial(1, pvals)


class TestMultivariateHypergeometric:

def setup_method(self):
Expand Down Expand Up @@ -1238,6 +1240,25 @@ def test_dirichlet_moderately_small_alpha(self):
sample_mean = sample.mean(axis=0)
assert_allclose(sample_mean, exact_mean, rtol=1e-3)

# This set of parameters includes inputs with alpha.max() >= 0.1 and
# alpha.max() < 0.1 to exercise both generation methods within the
# dirichlet code.
@pytest.mark.parametrize(
    'alpha',
    [[5, 9, 0, 8],
     [0.5, 0, 0, 0],
     [1, 5, 0, 0, 1.5, 0, 0, 0],
     [0.01, 0.03, 0, 0.005],
     [1e-5, 0, 0, 0],
     [0.002, 0.015, 0, 0, 0.04, 0, 0, 0],
     [0.0],
     [0, 0, 0]],
)
def test_dirichlet_multiple_zeros_in_alpha(self, alpha):
    """Zeros in alpha must give exact zeros and never produce nan."""
    alpha = np.array(alpha)
    y = random.dirichlet(alpha)
    # Reject nan anywhere in the sample, not only at the positions where
    # alpha == 0; the check below would miss nan in the other slots.
    assert not np.isnan(y).any()
    assert_equal(y[alpha == 0], 0.0)
    # With at least one positive alpha the sample must still be a point
    # on the simplex.  When every alpha is 0 the output is all zeros, so
    # there is nothing to normalise.
    if alpha.any():
        assert_allclose(y.sum(), 1.0)

def test_exponential(self):
random = Generator(MT19937(self.seed))
actual = random.exponential(1.1234, size=(3, 2))
Expand Down Expand Up @@ -1467,7 +1488,7 @@ def test_multivariate_normal(self, method):
mu, np.empty((3, 2)))
assert_raises(ValueError, random.multivariate_normal,
mu, np.eye(3))

@pytest.mark.parametrize('mean, cov', [([0], [[1+1j]]), ([0j], [[1]])])
def test_multivariate_normal_disallow_complex(self, mean, cov):
random = Generator(MT19937(self.seed))
Expand Down Expand Up @@ -1847,7 +1868,6 @@ class TestBroadcast:
def setup_method(self):
    """Store the fixed seed the tests in this class build generators from."""
    self.seed = 123_456_789


def test_uniform(self):
random = Generator(MT19937(self.seed))
low = [0]
Expand Down
5 changes: 1 addition & 4 deletions numpy/random/tests/test_randomstate.py
Original file line number Diff line number Diff line change
Expand Up @@ -812,10 +812,6 @@ def test_dirichlet_bad_alpha(self):
alpha = np.array([5.4e-01, -1.0e-16])
assert_raises(ValueError, random.dirichlet, alpha)

def test_dirichlet_zero_alpha(self):
y = random.default_rng().dirichlet([5, 9, 0, 8])
assert_equal(y[2], 0)

def test_dirichlet_alpha_non_contiguous(self):
a = np.array([51.72840233779265162, -1.0, 39.74494232180943953])
alpha = a[::2]
Expand Down Expand Up @@ -2061,6 +2057,7 @@ def test_randomstate_ctor_old_style_pickle():
assert_equal(state_a['has_gauss'], state_b['has_gauss'])
assert_equal(state_a['gauss'], state_b['gauss'])


def test_hot_swap(restore_singleton_bitgen):
# GH 21808
def_bg = np.random.default_rng(0)
Expand Down

0 comments on commit 09dac65

Please sign in to comment.