brainstate-0.2.1-py2.py3-none-any.whl → brainstate-0.2.2-py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (115)
  1. brainstate/__init__.py +167 -169
  2. brainstate/_compatible_import.py +340 -340
  3. brainstate/_compatible_import_test.py +681 -681
  4. brainstate/_deprecation.py +210 -210
  5. brainstate/_deprecation_test.py +2297 -2319
  6. brainstate/_error.py +45 -45
  7. brainstate/_state.py +2157 -1652
  8. brainstate/_state_test.py +1129 -52
  9. brainstate/_utils.py +47 -47
  10. brainstate/environ.py +1495 -1495
  11. brainstate/environ_test.py +1223 -1223
  12. brainstate/graph/__init__.py +22 -22
  13. brainstate/graph/_node.py +240 -240
  14. brainstate/graph/_node_test.py +589 -589
  15. brainstate/graph/_operation.py +1620 -1624
  16. brainstate/graph/_operation_test.py +1147 -1147
  17. brainstate/mixin.py +1447 -1433
  18. brainstate/mixin_test.py +1017 -1017
  19. brainstate/nn/__init__.py +146 -137
  20. brainstate/nn/_activations.py +1100 -1100
  21. brainstate/nn/_activations_test.py +354 -354
  22. brainstate/nn/_collective_ops.py +635 -633
  23. brainstate/nn/_collective_ops_test.py +774 -774
  24. brainstate/nn/_common.py +226 -226
  25. brainstate/nn/_common_test.py +134 -154
  26. brainstate/nn/_conv.py +2010 -2010
  27. brainstate/nn/_conv_test.py +849 -849
  28. brainstate/nn/_delay.py +575 -575
  29. brainstate/nn/_delay_test.py +243 -243
  30. brainstate/nn/_dropout.py +618 -618
  31. brainstate/nn/_dropout_test.py +480 -477
  32. brainstate/nn/_dynamics.py +870 -1267
  33. brainstate/nn/_dynamics_test.py +53 -67
  34. brainstate/nn/_elementwise.py +1298 -1298
  35. brainstate/nn/_elementwise_test.py +829 -829
  36. brainstate/nn/_embedding.py +408 -408
  37. brainstate/nn/_embedding_test.py +156 -156
  38. brainstate/nn/_event_fixedprob.py +233 -233
  39. brainstate/nn/_event_fixedprob_test.py +115 -115
  40. brainstate/nn/_event_linear.py +83 -83
  41. brainstate/nn/_event_linear_test.py +121 -121
  42. brainstate/nn/_exp_euler.py +254 -254
  43. brainstate/nn/_exp_euler_test.py +377 -377
  44. brainstate/nn/_linear.py +744 -744
  45. brainstate/nn/_linear_test.py +475 -475
  46. brainstate/nn/_metrics.py +1070 -1070
  47. brainstate/nn/_metrics_test.py +611 -611
  48. brainstate/nn/_module.py +391 -384
  49. brainstate/nn/_module_test.py +40 -40
  50. brainstate/nn/_normalizations.py +1334 -1334
  51. brainstate/nn/_normalizations_test.py +699 -699
  52. brainstate/nn/_paddings.py +1020 -1020
  53. brainstate/nn/_paddings_test.py +722 -722
  54. brainstate/nn/_poolings.py +2239 -2239
  55. brainstate/nn/_poolings_test.py +952 -952
  56. brainstate/nn/_rnns.py +946 -946
  57. brainstate/nn/_rnns_test.py +592 -592
  58. brainstate/nn/_utils.py +216 -216
  59. brainstate/nn/_utils_test.py +401 -401
  60. brainstate/nn/init.py +809 -809
  61. brainstate/nn/init_test.py +180 -180
  62. brainstate/random/__init__.py +270 -270
  63. brainstate/random/{_rand_funs.py → _fun.py} +3938 -3938
  64. brainstate/random/{_rand_funs_test.py → _fun_test.py} +638 -640
  65. brainstate/random/_impl.py +672 -0
  66. brainstate/random/{_rand_seed.py → _seed.py} +675 -675
  67. brainstate/random/{_rand_seed_test.py → _seed_test.py} +48 -48
  68. brainstate/random/{_rand_state.py → _state.py} +1320 -1617
  69. brainstate/random/{_rand_state_test.py → _state_test.py} +551 -551
  70. brainstate/transform/__init__.py +56 -59
  71. brainstate/transform/_ad_checkpoint.py +176 -176
  72. brainstate/transform/_ad_checkpoint_test.py +49 -49
  73. brainstate/transform/_autograd.py +1025 -1025
  74. brainstate/transform/_autograd_test.py +1289 -1289
  75. brainstate/transform/_conditions.py +316 -316
  76. brainstate/transform/_conditions_test.py +220 -220
  77. brainstate/transform/_error_if.py +94 -94
  78. brainstate/transform/_error_if_test.py +52 -52
  79. brainstate/transform/_find_state.py +200 -0
  80. brainstate/transform/_find_state_test.py +84 -0
  81. brainstate/transform/_jit.py +399 -399
  82. brainstate/transform/_jit_test.py +143 -143
  83. brainstate/transform/_loop_collect_return.py +675 -675
  84. brainstate/transform/_loop_collect_return_test.py +58 -58
  85. brainstate/transform/_loop_no_collection.py +283 -283
  86. brainstate/transform/_loop_no_collection_test.py +50 -50
  87. brainstate/transform/_make_jaxpr.py +2176 -2016
  88. brainstate/transform/_make_jaxpr_test.py +1634 -1510
  89. brainstate/transform/_mapping.py +607 -529
  90. brainstate/transform/_mapping_test.py +104 -194
  91. brainstate/transform/_progress_bar.py +255 -255
  92. brainstate/transform/_unvmap.py +256 -256
  93. brainstate/transform/_util.py +286 -286
  94. brainstate/typing.py +837 -837
  95. brainstate/typing_test.py +780 -780
  96. brainstate/util/__init__.py +27 -27
  97. brainstate/util/_others.py +1024 -1024
  98. brainstate/util/_others_test.py +962 -962
  99. brainstate/util/_pretty_pytree.py +1301 -1301
  100. brainstate/util/_pretty_pytree_test.py +675 -675
  101. brainstate/util/_pretty_repr.py +462 -462
  102. brainstate/util/_pretty_repr_test.py +696 -696
  103. brainstate/util/filter.py +945 -945
  104. brainstate/util/filter_test.py +911 -911
  105. brainstate/util/struct.py +910 -910
  106. brainstate/util/struct_test.py +602 -602
  107. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/METADATA +108 -108
  108. brainstate-0.2.2.dist-info/RECORD +111 -0
  109. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/licenses/LICENSE +202 -202
  110. brainstate/transform/_eval_shape.py +0 -145
  111. brainstate/transform/_eval_shape_test.py +0 -38
  112. brainstate/transform/_random.py +0 -171
  113. brainstate-0.2.1.dist-info/RECORD +0 -111
  114. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/WHEEL +0 -0
  115. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/top_level.txt +0 -0
brainstate/random/_impl.py (new file)
@@ -0,0 +1,672 @@
# Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from functools import partial

import brainunit as u
import jax
import jax.numpy as jnp
import jax.random as jr
import numpy as np
from jax import jit, vmap
from jax import lax, dtypes
from jax.scipy import special as jsp

from brainstate import environ

def _categorical(key, p, shape):
    # This implementation is fast when the event shape is small, and slow otherwise.
    # Ref: https://stackoverflow.com/a/34190035
    shape = shape or p.shape[:-1]
    s = jnp.cumsum(p, axis=-1)
    r = jr.uniform(key, shape=shape + (1,))
    return jnp.sum(s < r, axis=-1)

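# Editorial usage sketch (not part of the released file): `_categorical` inverts
# the cumulative distribution along the last axis of `p`, so a length-3
# probability vector yields integer indices in {0, 1, 2}. The seed is a
# hypothetical example value.
#
#     key = jr.PRNGKey(0)
#     idx = _categorical(key, jnp.array([0.1, 0.2, 0.7]), (1000,))
#     # idx.shape == (1000,); roughly 70% of entries equal 2
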
@partial(jit, static_argnames=('n_max', 'shape'))
def multinomial(key, p, n, *, n_max, shape=()):
    if u.math.shape(n) != u.math.shape(p)[:-1]:
        broadcast_shape = lax.broadcast_shapes(u.math.shape(n), u.math.shape(p)[:-1])
        n = jnp.broadcast_to(n, broadcast_shape)
        p = jnp.broadcast_to(p, broadcast_shape + u.math.shape(p)[-1:])
    shape = shape or p.shape[:-1]
    if n_max == 0:
        return jnp.zeros(shape + p.shape[-1:], dtype=jnp.result_type(int))
    # get indices from the categorical distribution, then gather the result
    indices = _categorical(key, p, (n_max,) + shape)
    # mask out values when counts are heterogeneous
    if jnp.ndim(n) > 0:
        mask = _promote_shapes(jnp.arange(n_max) < jnp.expand_dims(n, -1), shape=shape + (n_max,))[0]
        mask = jnp.moveaxis(mask, -1, 0).astype(indices.dtype)
        excess = jnp.concatenate(
            [jnp.expand_dims(n_max - n, -1),
             jnp.zeros(u.math.shape(n) + (p.shape[-1] - 1,))],
            -1
        )
    else:
        mask = 1
        excess = 0
    # NB: we transpose to move the batch shape to the front
    indices_2D = (jnp.reshape(indices * mask, (n_max, -1))).T
    samples_2D = vmap(_scatter_add_one)(
        jnp.zeros((indices_2D.shape[0], p.shape[-1]), dtype=indices.dtype),
        jnp.expand_dims(indices_2D, axis=-1),
        jnp.ones(indices_2D.shape, dtype=indices.dtype)
    )
    return jnp.reshape(samples_2D, shape + p.shape[-1:]) - excess

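# Editorial usage sketch (not part of the released file): `n_max` must be a
# static upper bound on the trial count so the loop-free masking above can be
# traced under jit. Hypothetical seed and probabilities:
#
#     key = jr.PRNGKey(0)
#     counts = multinomial(key, jnp.array([0.2, 0.3, 0.5]), 10, n_max=10)
#     # counts.shape == (3,) and counts.sum() == 10
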
@partial(jit, static_argnums=(2, 3), static_argnames=['shape', 'dtype'])
def von_mises_centered(
    key,
    concentration,
    shape,
    dtype=None
):
    """Compute centered von Mises samples using rejection sampling from [1]_ with a wrapped Cauchy proposal.

    Returns
    -------
    out : array_like
        centered samples from the von Mises distribution

    References
    ----------
    .. [1] Luc Devroye "Non-Uniform Random Variate Generation", Springer-Verlag, 1986;
           Chapter 9, p. 473-476. http://www.nrbook.com/devroye/Devroye_files/chapter_nine.pdf

    """
    shape = shape or u.math.shape(concentration)
    dtype = dtype or environ.dftype()
    concentration = lax.convert_element_type(concentration, dtype)
    concentration = jnp.broadcast_to(concentration, shape)

    if dtype == jnp.float16:
        s_cutoff = 1.8e-1
    elif dtype == jnp.float32:
        s_cutoff = 2e-2
    elif dtype == jnp.float64:
        s_cutoff = 1.2e-4
    else:
        raise ValueError(f"Unsupported dtype: {dtype}")

    r = 1.0 + jnp.sqrt(1.0 + 4.0 * concentration ** 2)
    rho = (r - jnp.sqrt(2.0 * r)) / (2.0 * concentration)
    s_exact = (1.0 + rho ** 2) / (2.0 * rho)

    s_approximate = 1.0 / concentration

    s = jnp.where(concentration > s_cutoff, s_exact, s_approximate)

    def cond_fn(*args):
        """Check whether all samples are done or the max number of iterations is reached."""
        i, _, done, _, _ = args[0]
        return jnp.bitwise_and(i < 100, jnp.logical_not(jnp.all(done)))

    def body_fn(*args):
        i, key, done, _, w = args[0]
        uni_ukey, uni_vkey, key = jr.split(key, 3)
        u_ = jr.uniform(
            key=uni_ukey,
            shape=shape,
            dtype=concentration.dtype,
            minval=-1.0,
            maxval=1.0,
        )
        z = jnp.cos(jnp.pi * u_)
        w = jnp.where(done, w, (1.0 + s * z) / (s + z))  # update where not done
        y = concentration * (s - w)
        v = jr.uniform(key=uni_vkey, shape=shape, dtype=concentration.dtype)
        accept = (y * (2.0 - y) >= v) | (jnp.log(y / v) + 1.0 >= y)
        return i + 1, key, accept | done, u_, w

    init_done = jnp.zeros(shape, dtype=bool)
    init_u = jnp.zeros(shape)
    init_w = jnp.zeros(shape)

    _, _, done, uu, w = lax.while_loop(
        cond_fun=cond_fn,
        body_fun=body_fn,
        init_val=(jnp.array(0), key, init_done, init_u, init_w),
    )

    return jnp.sign(uu) * jnp.arccos(w)

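# Editorial usage sketch (not part of the released file): the sampler returns
# draws centered at location 0 on [-pi, pi]; a nonzero location can be added
# and re-wrapped afterwards. Hypothetical example:
#
#     key = jr.PRNGKey(0)
#     theta = von_mises_centered(key, 2.0, (1000,))
#     # theta lies in [-pi, pi]; for loc != 0 use
#     # (theta + loc + jnp.pi) % (2 * jnp.pi) - jnp.pi
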
def _scatter_add_one(operand, indices, updates):
    return lax.scatter_add(
        operand,
        indices,
        updates,
        lax.ScatterDimensionNumbers(
            update_window_dims=(),
            inserted_window_dims=(0,),
            scatter_dims_to_operand_dims=(0,),
        ),
    )

def _reshape(x, shape):
    if isinstance(x, (int, float, np.ndarray, np.generic)):
        return np.reshape(x, shape)
    else:
        return jnp.reshape(x, shape)

def _promote_shapes(*args, shape=()):
    # adapted from lax.lax_numpy
    if len(args) < 2 and not shape:
        return args
    else:
        shapes = [u.math.shape(arg) for arg in args]
        num_dims = len(lax.broadcast_shapes(shape, *shapes))
        return [
            _reshape(arg, (1,) * (num_dims - len(s)) + s)
            if len(s) < num_dims else arg
            for arg, s in zip(args, shapes)
        ]

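# Editorial usage sketch (not part of the released file): `_promote_shapes`
# only pads leading unit axes so every argument has the same rank; it does not
# broadcast the data itself. Hypothetical example:
#
#     a, b = _promote_shapes(jnp.ones((3,)), jnp.ones((2, 1)), shape=(2, 3))
#     # a.shape == (1, 3), b.shape == (2, 1)
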
python_scalar_dtypes = {
    bool: np.dtype('bool'),
    int: np.dtype('int64'),
    float: np.dtype('float64'),
    complex: np.dtype('complex128'),
}

def _dtype(x, *, canonicalize: bool = False):
    """Return the dtype object for a value or type, optionally canonicalized based on X64 mode."""
    if x is None:
        raise ValueError(f"Invalid argument to dtype: {x}.")
    elif isinstance(x, type) and x in python_scalar_dtypes:
        dt = python_scalar_dtypes[x]
    elif type(x) in python_scalar_dtypes:
        dt = python_scalar_dtypes[type(x)]
    elif hasattr(x, 'dtype'):
        dt = x.dtype
    else:
        dt = np.result_type(x)
    return dtypes.canonicalize_dtype(dt) if canonicalize else dt

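# Editorial usage sketch (not part of the released file): without
# canonicalization, Python scalars keep their 64-bit NumPy dtype; with it, the
# result respects JAX's X64 mode. Hypothetical example:
#
#     _dtype(1.5)                     # dtype('float64')
#     _dtype(1.5, canonicalize=True)  # dtype('float32') unless jax_enable_x64
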
def _is_python_scalar(x):
    if hasattr(x, 'aval'):
        return x.aval.weak_type
    elif np.ndim(x) == 0:
        return True
    elif isinstance(x, (bool, int, float, complex)):
        return True
    else:
        return False

def const(example, val):
    if _is_python_scalar(example):
        dtype = dtypes.canonicalize_dtype(type(example))
        val = dtypes.scalar_type_of(example)(val)
        return val if dtype == _dtype(val, canonicalize=True) else np.array(val, dtype)
    else:
        dtype = dtypes.canonicalize_dtype(example.dtype)
        return np.array(val, dtype)


# ---------------------------------------------------------------------------------------------------------------

def formalize_key(key, use_prng_key=True):
    if isinstance(key, int):
        return jr.PRNGKey(key) if use_prng_key else jr.key(key)
    elif isinstance(key, (jax.Array, np.ndarray)):
        if jnp.issubdtype(key.dtype, jax.dtypes.prng_key):
            return key
        if key.size == 1 and jnp.issubdtype(key.dtype, jnp.integer):
            return jr.PRNGKey(key) if use_prng_key else jr.key(key)

        if key.dtype != jnp.uint32:
            raise TypeError('key must be an int or an array of two uint32.')
        if key.size != 2:
            raise TypeError('key must be an int or an array of two uint32.')
        return u.math.asarray(key, dtype=jnp.uint32)
    else:
        raise TypeError('key must be an int or an array of two uint32.')

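# Editorial usage sketch (not part of the released file): integer seeds are
# promoted to PRNG keys, typed keys pass through unchanged, and raw uint32
# pairs are kept as legacy-style keys. Hypothetical examples:
#
#     formalize_key(42)                                  # PRNGKey from a seed
#     formalize_key(jr.key(42))                          # returned unchanged
#     formalize_key(np.array([0, 42], dtype=np.uint32))  # legacy uint32 pair
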
def _size2shape(size):
    if size is None:
        return ()
    elif isinstance(size, (tuple, list)):
        return tuple(size)
    else:
        return (size,)


def _check_shape(name, shape, *param_shapes):
    if param_shapes:
        shape_ = lax.broadcast_shapes(shape, *param_shapes)
        if shape != shape_:
            msg = ("{} parameter shapes must be broadcast-compatible with shape "
                   "argument, and the result of broadcasting the shapes must equal "
                   "the shape argument, but got result {} for shape argument {}.")
            raise ValueError(msg.format(name, shape_, shape))

def _loc_scale(
    loc,
    scale,
    value
):
    if loc is None:
        if scale is None:
            return value
        else:
            return value * scale
    else:
        if scale is None:
            return value + loc
        else:
            return value * scale + loc


def _check_py_seq(seq):
    return u.math.asarray(seq) if isinstance(seq, (tuple, list)) else seq

@partial(jit, static_argnames=['shape', 'dtype'])
def f(
    key,
    dfnum,
    dfden,
    *,
    shape,
    dtype=None
):
    """Draw samples from the central F distribution."""
    dtype = dtype or environ.dftype()
    dfnum = lax.convert_element_type(dfnum, dtype)
    dfden = lax.convert_element_type(dfden, dtype)

    if shape is None:
        shape = lax.broadcast_shapes(u.math.shape(dfnum), u.math.shape(dfden))
    elif isinstance(shape, int):
        shape = (shape,)
    else:
        shape = tuple(shape)

    dfnum = jnp.broadcast_to(dfnum, shape)
    dfden = jnp.broadcast_to(dfden, shape)

    size = int(np.prod(shape)) if shape else 1
    if size == 0:
        return jnp.empty(shape, dtype=dtype)

    key_num, key_den = jr.split(key)
    chi2_num = 2.0 * jr.gamma(key_num, 0.5 * dfnum, shape=shape, dtype=dtype)
    chi2_den = 2.0 * jr.gamma(key_den, 0.5 * dfden, shape=shape, dtype=dtype)

    return (chi2_num / dfnum) / (chi2_den / dfden)

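# Editorial usage sketch (not part of the released file): an F(dfnum, dfden)
# variate is a ratio of scaled chi-square draws, with chi2(df) = 2 * Gamma(df / 2).
# For dfden > 2 the mean is dfden / (dfden - 2). Hypothetical example:
#
#     key = jr.PRNGKey(0)
#     x = f(key, 5.0, 10.0, shape=(10000,))
#     # x.mean() is close to 10.0 / (10.0 - 2.0) == 1.25
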
@partial(jit, static_argnames=['shape', 'dtype'])
def noncentral_f(
    key,
    dfnum,
    dfden,
    nonc,
    *,
    shape,
    dtype=None
):
    """
    Draw samples from the noncentral F distribution.

    The noncentral F distribution is a generalization of the F distribution.
    It is parameterized by dfnum (degrees of freedom of the numerator),
    dfden (degrees of freedom of the denominator), and nonc (the noncentrality
    parameter).

    The implementation uses the relationship:
    if X ~ noncentral_chisquare(dfnum, nonc) and Y ~ chisquare(dfden), then
    F = (X / dfnum) / (Y / dfden) ~ noncentral_f(dfnum, dfden, nonc).

    Parameters
    ----------
    key : jax.random.PRNGKey
        Random key.
    dfnum : float or array_like
        Degrees of freedom of the numerator, must be > 0.
    dfden : float or array_like
        Degrees of freedom of the denominator, must be > 0.
    nonc : float or array_like
        Noncentrality parameter, must be >= 0.
    shape : tuple
        Output shape.
    dtype : dtype, optional
        Data type of the output.

    Returns
    -------
    out : array_like
        Samples from the noncentral F distribution.
    """
    dtype = dtype or environ.dftype()
    dfnum = lax.convert_element_type(dfnum, dtype)
    dfden = lax.convert_element_type(dfden, dtype)
    nonc = lax.convert_element_type(nonc, dtype)

    # Split key for two random samples
    key1, key2 = jr.split(key)

    # Generate a noncentral chi-square variate for the numerator:
    # noncentral_chisquare(df, nonc) = chisquare(df - 1) + (normal(0, 1) + sqrt(nonc))**2
    # when df > 1, else chisquare(df + 2 * poisson(nonc / 2))
    keys_numer = jr.split(key1, 3)
    i = jr.poisson(keys_numer[0], 0.5 * nonc, shape=shape, dtype=environ.ditype())
    n = jr.normal(keys_numer[1], shape=shape, dtype=dtype) + jnp.sqrt(nonc)
    cond = jnp.greater(dfnum, 1.0)
    df_numerator = jnp.where(cond, dfnum - 1.0, dfnum + 2.0 * i)
    chi2_numerator = 2.0 * jr.gamma(keys_numer[2], 0.5 * df_numerator, shape=shape, dtype=dtype)
    numerator = jnp.where(cond, chi2_numerator + n * n, chi2_numerator)

    # Generate a central chi-square variate for the denominator:
    # chisquare(df) = 2 * gamma(df / 2, 1)
    chi2_denominator = 2.0 * jr.gamma(key2, 0.5 * dfden, shape=shape, dtype=dtype)

    # Compute the F statistic: (numerator / dfnum) / (denominator / dfden)
    f_stat = (numerator / dfnum) / (chi2_denominator / dfden)

    return f_stat

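# Editorial usage sketch (not part of the released file): for dfden > 2 the
# noncentral F mean is dfden * (dfnum + nonc) / (dfnum * (dfden - 2)).
# Hypothetical example:
#
#     key = jr.PRNGKey(0)
#     x = noncentral_f(key, 5.0, 10.0, 2.0, shape=(10000,))
#     # x.mean() is close to 10.0 * (5.0 + 2.0) / (5.0 * 8.0) == 1.75
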
@partial(jit, static_argnames=['shape', 'dtype'])
def logseries(
    key,
    p,
    *,
    shape,
    dtype=None
):
    """Draw samples from the logarithmic series distribution."""
    dtype = dtype or environ.ditype()
    float_dtype = dtypes.canonicalize_dtype(environ.dftype())
    calc_dtype = dtypes.canonicalize_dtype(jnp.promote_types(float_dtype, jnp.float64))

    p = lax.convert_element_type(p, float_dtype)

    if shape is None:
        shape = u.math.shape(p)
    elif isinstance(shape, int):
        shape = (shape,)
    else:
        shape = tuple(shape)

    p = jnp.broadcast_to(p, shape)

    size = int(np.prod(shape)) if shape else 1
    if size == 0:
        return jnp.empty(shape, dtype=dtype)

    p_flat = jnp.reshape(lax.convert_element_type(p, calc_dtype), (size,))
    keys = jr.split(key, size)

    tiny = jnp.array(np.finfo(calc_dtype).tiny, dtype=calc_dtype)
    one_minus_eps = jnp.nextafter(jnp.array(1.0, dtype=calc_dtype), jnp.array(0.0, dtype=calc_dtype))

    def _sample_one(single_key, p_scalar):
        p_scalar = lax.convert_element_type(p_scalar, calc_dtype)
        operand = (single_key, p_scalar)

        def _limit_case(_):
            return jnp.array(1.0, dtype=calc_dtype)

        def _positive_case(args):
            key_i, p_val = args
            p_val = jnp.clip(p_val, tiny, one_minus_eps)
            log_p = jnp.log(p_val)
            log_norm = jnp.log(-jnp.log1p(-p_val))
            log_prob = log_p - log_norm
            log_cdf = log_prob
            log_u = jnp.log(jr.uniform(key_i, shape=(), dtype=calc_dtype, minval=tiny, maxval=one_minus_eps))

            init_state = (jnp.array(1.0, dtype=calc_dtype), log_prob, log_cdf, log_u)

            def cond_fn(state):
                _, _, log_cdf_val, log_u_val = state
                return log_u_val > log_cdf_val

            def body_fn(state):
                k_val, log_prob_val, log_cdf_val, log_u_val = state
                k_next = k_val + 1.0
                log_prob_next = log_prob_val + log_p + jnp.log(k_val) - jnp.log(k_next)
                log_cdf_next = jnp.logaddexp(log_cdf_val, log_prob_next)
                return k_next, log_prob_next, log_cdf_next, log_u_val

            k_val, _, _, _ = lax.while_loop(cond_fn, body_fn, init_state)
            return k_val

        return lax.cond(p_scalar <= 0.0, _limit_case, _positive_case, operand)

    samples = vmap(_sample_one)(keys, p_flat)
    samples = lax.convert_element_type(samples, dtype)
    return jnp.reshape(samples, shape)

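# Editorial usage sketch (not part of the released file): logseries(p) takes
# values k = 1, 2, ... with P(k) = -p**k / (k * log(1 - p)), sampled here by
# inverting the CDF in log space. Hypothetical example:
#
#     key = jr.PRNGKey(0)
#     k = logseries(key, 0.5, shape=(10000,))
#     # integer samples >= 1; k.mean() near -0.5 / (0.5 * jnp.log(0.5)) ~ 1.44
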
@partial(jit, static_argnames=['shape', 'dtype'])
def zipf(
    key,
    a,
    *,
    shape,
    dtype=None
):
    """Draw samples from the Zipf (zeta) distribution."""
    dtype = dtype or environ.ditype()
    float_dtype = dtypes.canonicalize_dtype(environ.dftype())
    calc_dtype = dtypes.canonicalize_dtype(jnp.promote_types(float_dtype, jnp.float64))

    a = lax.convert_element_type(a, calc_dtype)

    if shape is None:
        shape = u.math.shape(a)
    elif isinstance(shape, int):
        shape = (shape,)
    else:
        shape = tuple(shape)

    a = jnp.broadcast_to(a, shape)

    size = int(np.prod(shape)) if shape else 1
    if size == 0:
        return jnp.empty(shape, dtype=dtype)

    u_ = jr.uniform(
        key,
        shape=shape,
        dtype=calc_dtype,
        minval=jnp.finfo(calc_dtype).tiny,
        maxval=jnp.array(1.0, dtype=calc_dtype)
    )

    a_flat = jnp.reshape(a, (size,))
    u_flat = jnp.reshape(u_, (size,))

    max_iters = jnp.array(1000000, dtype=jnp.int32)

    def _sample_one(a_scalar, u_scalar):
        norm = jsp.zeta(a_scalar, jnp.array(1.0, dtype=calc_dtype))

        def cdf(k_val):
            return (
                jnp.array(1.0, dtype=calc_dtype) -
                jsp.zeta(a_scalar, k_val + jnp.array(1.0, dtype=calc_dtype)) / norm
            )

        initial = jnp.array(1.0, dtype=calc_dtype)
        cdf_prev = jnp.array(0.0, dtype=calc_dtype)
        cdf_curr = cdf(initial)

        state = (
            initial,
            cdf_prev,
            cdf_curr,
            jnp.array(0, dtype=jnp.int32)
        )

        def cond_fn(state):
            _, c_prev, c_curr, it = state
            not_ok = jnp.logical_or(u_scalar > c_curr, u_scalar <= c_prev)
            return jnp.logical_and(not_ok, it < max_iters)

        def body_fn(state):
            k_val, c_prev, c_curr, it = state
            need_increase = u_scalar > c_curr

            def inc(_):
                k_next = k_val + jnp.array(1.0, dtype=calc_dtype)
                c_prev_next = jnp.array(1.0, dtype=calc_dtype) - jsp.zeta(a_scalar, k_next) / norm
                c_curr_next = cdf(k_next)
                return k_next, c_prev_next, c_curr_next, it + 1

            def dec(_):
                k_next = jnp.maximum(jnp.array(1.0, dtype=calc_dtype), k_val - jnp.array(1.0, dtype=calc_dtype))
                c_prev_next = jnp.array(1.0, dtype=calc_dtype) - jsp.zeta(a_scalar, k_next) / norm
                c_curr_next = cdf(k_next)
                return k_next, c_prev_next, c_curr_next, it + 1

            return lax.cond(need_increase, inc, dec, operand=None)

        k_final, _, _, _ = lax.while_loop(cond_fn, body_fn, state)
        return lax.convert_element_type(k_final, dtype)

    samples_flat = jax.vmap(_sample_one)(a_flat, u_flat)
    samples = jnp.reshape(samples_flat, shape)
    return samples

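# Editorial usage sketch (not part of the released file): zipf(a) with a > 1
# has P(k) proportional to k**(-a); the walk above finds the smallest k whose
# CDF bracket contains the uniform draw. Hypothetical example:
#
#     key = jr.PRNGKey(0)
#     k = zipf(key, 3.0, shape=(10000,))
#     # integer samples >= 1; for a > 2, k.mean() is near zeta(a - 1) / zeta(a)
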
@partial(jit, static_argnames=['shape', 'dtype'])
def power(
    key,
    a,
    *,
    shape,
    dtype=None
):
    """Draw samples from the power distribution."""
    dtype = dtype or environ.dftype()
    float_dtype = dtypes.canonicalize_dtype(dtype)

    a = lax.convert_element_type(a, float_dtype)

    if shape is None:
        shape = u.math.shape(a)
    elif isinstance(shape, int):
        shape = (shape,)
    else:
        shape = tuple(shape)

    a = jnp.broadcast_to(a, shape)

    size = int(np.prod(shape)) if shape else 1
    if size == 0:
        return jnp.empty(shape, dtype=float_dtype)

    eps = jnp.array(np.finfo(float_dtype).tiny, dtype=float_dtype)
    a_safe = jnp.maximum(a, eps)

    u_ = jr.uniform(key, shape=shape, dtype=float_dtype, minval=eps, maxval=1.0)
    samples = jnp.power(u_, jnp.reciprocal(a_safe))

    return lax.convert_element_type(samples, dtype)

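# Editorial usage sketch (not part of the released file): the power
# distribution has density a * x**(a - 1) on (0, 1], so inverse-CDF sampling
# is simply U**(1 / a). Hypothetical example:
#
#     key = jr.PRNGKey(0)
#     x = power(key, 5.0, shape=(10000,))
#     # samples in (0, 1]; x.mean() is near a / (a + 1) == 5.0 / 6.0
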
@partial(jit, static_argnames=['shape', 'dtype'])
def hypergeometric(
    key,
    ngood,
    nbad,
    nsample,
    *,
    shape,
    dtype=None
):
    """Draw samples from the hypergeometric distribution."""
    dtype = dtype or environ.ditype()
    out_dtype = dtypes.canonicalize_dtype(dtype)
    float_dtype = dtypes.canonicalize_dtype(environ.dftype())
    calc_dtype = dtypes.canonicalize_dtype(jnp.promote_types(float_dtype, jnp.float64))

    ngood = lax.convert_element_type(ngood, out_dtype)
    nbad = lax.convert_element_type(nbad, out_dtype)
    nsample = lax.convert_element_type(nsample, out_dtype)

    if shape is None:
        shape = lax.broadcast_shapes(u.math.shape(ngood), u.math.shape(nbad), u.math.shape(nsample))
    elif isinstance(shape, int):
        shape = (shape,)
    else:
        shape = tuple(shape)

    ngood = jnp.broadcast_to(ngood, shape)
    nbad = jnp.broadcast_to(nbad, shape)
    nsample = jnp.broadcast_to(nsample, shape)

    size = int(np.prod(shape)) if shape else 1
    if size == 0:
        return jnp.empty(shape, dtype=out_dtype)

    flat_ngood = jnp.reshape(ngood, (size,))
    flat_nbad = jnp.reshape(nbad, (size,))
    flat_nsample = jnp.reshape(nsample, (size,))
    sample_keys = jr.split(key, size + 1)[1:]

    one = jnp.array(1, dtype=out_dtype)
    zero = jnp.array(0, dtype=out_dtype)

    def _sample_one(sample_key, good, bad, draws):
        good = jnp.maximum(good, zero)
        bad = jnp.maximum(bad, zero)
        draws = jnp.maximum(draws, zero)
        total = good + bad
        draws = jnp.minimum(draws, total)

        init_state = (zero, sample_key, good, bad, zero, draws)

        def cond_fn(state):
            i, _, good_i, bad_i, _, draws_i = state
            total_i = good_i + bad_i
            return jnp.logical_and(i < draws_i, total_i > zero)

        def body_fn(state):
            i, key_i, good_i, bad_i, succ_i, draws_i = state
            key_i, subkey = jr.split(key_i)
            total_i = good_i + bad_i
            prob = jnp.where(
                total_i > zero,
                lax.convert_element_type(good_i, calc_dtype) / lax.convert_element_type(total_i, calc_dtype),
                jnp.array(0.0, dtype=calc_dtype),
            )
            # named `u_` rather than `u` so the brainunit import is not shadowed
            u_ = jr.uniform(subkey, shape=(), dtype=calc_dtype)
            success = (u_ < prob).astype(out_dtype)
            good_i = good_i - success
            bad_i = bad_i - jnp.where(total_i > zero, one - success, zero)
            succ_i = succ_i + success
            return (i + one, key_i, good_i, bad_i, succ_i, draws_i)

        _, _, _, _, successes, _ = lax.while_loop(cond_fn, body_fn, init_state)
        return successes

    samples = jax.vmap(_sample_one)(sample_keys, flat_ngood, flat_nbad, flat_nsample)
    samples = lax.convert_element_type(samples, out_dtype)
    return jnp.reshape(samples, shape)
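# Editorial usage sketch (not part of the released file): the sampler draws one
# ball at a time without replacement, so it is exact but O(nsample) work per
# element. Hypothetical example:
#
#     key = jr.PRNGKey(0)
#     k = hypergeometric(key, 10, 20, 5, shape=(10000,))
#     # k counts "good" draws out of 5; k.mean() is near 5 * 10 / 30 ~ 1.67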