brainstate-0.1.9-py2.py3-none-any.whl → brainstate-0.2.0-py2.py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Files changed (163)
  1. brainstate/__init__.py +130 -19
  2. brainstate/_compatible_import.py +201 -9
  3. brainstate/_compatible_import_test.py +681 -0
  4. brainstate/_deprecation.py +210 -0
  5. brainstate/_deprecation_test.py +2319 -0
  6. brainstate/{util/error.py → _error.py} +10 -20
  7. brainstate/_state.py +94 -47
  8. brainstate/_state_test.py +1 -1
  9. brainstate/_utils.py +1 -1
  10. brainstate/environ.py +1279 -347
  11. brainstate/environ_test.py +1187 -26
  12. brainstate/graph/__init__.py +6 -13
  13. brainstate/graph/_node.py +240 -0
  14. brainstate/graph/_node_test.py +589 -0
  15. brainstate/graph/{_graph_operation.py → _operation.py} +632 -746
  16. brainstate/graph/_operation_test.py +1147 -0
  17. brainstate/mixin.py +1209 -141
  18. brainstate/mixin_test.py +991 -51
  19. brainstate/nn/__init__.py +74 -72
  20. brainstate/nn/_activations.py +587 -295
  21. brainstate/nn/_activations_test.py +109 -86
  22. brainstate/nn/_collective_ops.py +393 -274
  23. brainstate/nn/_collective_ops_test.py +746 -15
  24. brainstate/nn/_common.py +114 -66
  25. brainstate/nn/_common_test.py +154 -0
  26. brainstate/nn/_conv.py +1652 -143
  27. brainstate/nn/_conv_test.py +838 -227
  28. brainstate/nn/_delay.py +95 -29
  29. brainstate/nn/_delay_test.py +25 -20
  30. brainstate/nn/_dropout.py +359 -167
  31. brainstate/nn/_dropout_test.py +429 -52
  32. brainstate/nn/_dynamics.py +14 -90
  33. brainstate/nn/_dynamics_test.py +1 -12
  34. brainstate/nn/_elementwise.py +492 -313
  35. brainstate/nn/_elementwise_test.py +806 -145
  36. brainstate/nn/_embedding.py +369 -19
  37. brainstate/nn/_embedding_test.py +156 -0
  38. brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +10 -16
  39. brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +6 -5
  40. brainstate/nn/{_linear_mv.py → _event_linear.py} +2 -2
  41. brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +6 -5
  42. brainstate/nn/_exp_euler.py +200 -38
  43. brainstate/nn/_exp_euler_test.py +350 -8
  44. brainstate/nn/_linear.py +391 -71
  45. brainstate/nn/_linear_test.py +427 -59
  46. brainstate/nn/_metrics.py +1070 -0
  47. brainstate/nn/_metrics_test.py +611 -0
  48. brainstate/nn/_module.py +10 -3
  49. brainstate/nn/_module_test.py +1 -1
  50. brainstate/nn/_normalizations.py +688 -329
  51. brainstate/nn/_normalizations_test.py +663 -37
  52. brainstate/nn/_paddings.py +1020 -0
  53. brainstate/nn/_paddings_test.py +723 -0
  54. brainstate/nn/_poolings.py +1404 -342
  55. brainstate/nn/_poolings_test.py +828 -92
  56. brainstate/nn/{_rate_rnns.py → _rnns.py} +446 -54
  57. brainstate/nn/_rnns_test.py +593 -0
  58. brainstate/nn/_utils.py +132 -5
  59. brainstate/nn/_utils_test.py +402 -0
  60. brainstate/{init/_random_inits.py → nn/init.py} +301 -45
  61. brainstate/{init/_random_inits_test.py → nn/init_test.py} +51 -20
  62. brainstate/random/__init__.py +247 -1
  63. brainstate/random/_rand_funs.py +668 -346
  64. brainstate/random/_rand_funs_test.py +74 -1
  65. brainstate/random/_rand_seed.py +541 -76
  66. brainstate/random/_rand_seed_test.py +1 -1
  67. brainstate/random/_rand_state.py +601 -393
  68. brainstate/random/_rand_state_test.py +551 -0
  69. brainstate/transform/__init__.py +59 -0
  70. brainstate/transform/_ad_checkpoint.py +176 -0
  71. brainstate/{compile → transform}/_ad_checkpoint_test.py +1 -1
  72. brainstate/{augment → transform}/_autograd.py +360 -113
  73. brainstate/{augment → transform}/_autograd_test.py +2 -2
  74. brainstate/transform/_conditions.py +316 -0
  75. brainstate/{compile → transform}/_conditions_test.py +11 -11
  76. brainstate/{compile → transform}/_error_if.py +22 -20
  77. brainstate/{compile → transform}/_error_if_test.py +1 -1
  78. brainstate/transform/_eval_shape.py +145 -0
  79. brainstate/{augment → transform}/_eval_shape_test.py +1 -1
  80. brainstate/{compile → transform}/_jit.py +99 -46
  81. brainstate/{compile → transform}/_jit_test.py +3 -3
  82. brainstate/{compile → transform}/_loop_collect_return.py +219 -80
  83. brainstate/{compile → transform}/_loop_collect_return_test.py +1 -1
  84. brainstate/{compile → transform}/_loop_no_collection.py +133 -34
  85. brainstate/{compile → transform}/_loop_no_collection_test.py +2 -2
  86. brainstate/transform/_make_jaxpr.py +2016 -0
  87. brainstate/transform/_make_jaxpr_test.py +1510 -0
  88. brainstate/transform/_mapping.py +529 -0
  89. brainstate/transform/_mapping_test.py +194 -0
  90. brainstate/{compile → transform}/_progress_bar.py +78 -25
  91. brainstate/{augment → transform}/_random.py +65 -45
  92. brainstate/{compile → transform}/_unvmap.py +102 -5
  93. brainstate/transform/_util.py +286 -0
  94. brainstate/typing.py +594 -61
  95. brainstate/typing_test.py +780 -0
  96. brainstate/util/__init__.py +9 -32
  97. brainstate/util/_others.py +1025 -0
  98. brainstate/util/_others_test.py +962 -0
  99. brainstate/util/_pretty_pytree.py +1301 -0
  100. brainstate/util/_pretty_pytree_test.py +675 -0
  101. brainstate/util/{pretty_repr.py → _pretty_repr.py} +161 -27
  102. brainstate/util/_pretty_repr_test.py +696 -0
  103. brainstate/util/filter.py +557 -81
  104. brainstate/util/filter_test.py +912 -0
  105. brainstate/util/struct.py +769 -382
  106. brainstate/util/struct_test.py +602 -0
  107. {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/METADATA +34 -17
  108. brainstate-0.2.0.dist-info/RECORD +111 -0
  109. brainstate/augment/__init__.py +0 -30
  110. brainstate/augment/_eval_shape.py +0 -99
  111. brainstate/augment/_mapping.py +0 -1060
  112. brainstate/augment/_mapping_test.py +0 -597
  113. brainstate/compile/__init__.py +0 -38
  114. brainstate/compile/_ad_checkpoint.py +0 -204
  115. brainstate/compile/_conditions.py +0 -256
  116. brainstate/compile/_make_jaxpr.py +0 -888
  117. brainstate/compile/_make_jaxpr_test.py +0 -156
  118. brainstate/compile/_util.py +0 -147
  119. brainstate/functional/__init__.py +0 -27
  120. brainstate/graph/_graph_node.py +0 -244
  121. brainstate/graph/_graph_node_test.py +0 -73
  122. brainstate/graph/_graph_operation_test.py +0 -563
  123. brainstate/init/__init__.py +0 -26
  124. brainstate/init/_base.py +0 -52
  125. brainstate/init/_generic.py +0 -244
  126. brainstate/init/_regular_inits.py +0 -105
  127. brainstate/init/_regular_inits_test.py +0 -50
  128. brainstate/nn/_inputs.py +0 -608
  129. brainstate/nn/_ltp.py +0 -28
  130. brainstate/nn/_neuron.py +0 -705
  131. brainstate/nn/_neuron_test.py +0 -161
  132. brainstate/nn/_others.py +0 -46
  133. brainstate/nn/_projection.py +0 -486
  134. brainstate/nn/_rate_rnns_test.py +0 -63
  135. brainstate/nn/_readout.py +0 -209
  136. brainstate/nn/_readout_test.py +0 -53
  137. brainstate/nn/_stp.py +0 -236
  138. brainstate/nn/_synapse.py +0 -505
  139. brainstate/nn/_synapse_test.py +0 -131
  140. brainstate/nn/_synaptic_projection.py +0 -423
  141. brainstate/nn/_synouts.py +0 -162
  142. brainstate/nn/_synouts_test.py +0 -57
  143. brainstate/nn/metrics.py +0 -388
  144. brainstate/optim/__init__.py +0 -38
  145. brainstate/optim/_base.py +0 -64
  146. brainstate/optim/_lr_scheduler.py +0 -448
  147. brainstate/optim/_lr_scheduler_test.py +0 -50
  148. brainstate/optim/_optax_optimizer.py +0 -152
  149. brainstate/optim/_optax_optimizer_test.py +0 -53
  150. brainstate/optim/_sgd_optimizer.py +0 -1104
  151. brainstate/random/_random_for_unit.py +0 -52
  152. brainstate/surrogate.py +0 -1957
  153. brainstate/transform.py +0 -23
  154. brainstate/util/caller.py +0 -98
  155. brainstate/util/others.py +0 -540
  156. brainstate/util/pretty_pytree.py +0 -945
  157. brainstate/util/pretty_pytree_test.py +0 -159
  158. brainstate/util/pretty_table.py +0 -2954
  159. brainstate/util/scaling.py +0 -258
  160. brainstate-0.1.9.dist-info/RECORD +0 -130
  161. {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/WHEEL +0 -0
  162. {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/licenses/LICENSE +0 -0
  163. {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/top_level.txt +0 -0
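Note: the hunks reproduced below come from a single file in the listing, brainstate/random/_rand_state.py (item 67). The wider reorganization is visible in the paths: the old compile and augment subpackages are consolidated into brainstate.transform (items 69-93), the random initializers move from brainstate/init into brainstate/nn/init.py (items 60-61), and modules such as brainstate/optim and brainstate/surrogate.py are removed. A minimal migration sketch, using jit_error_if because both its old and new private paths appear in the hunks below; the public re-exports of brainstate.transform are not shown in this diff, so treat the exact import path as illustrative:

    import jax.numpy as jnp

    # brainstate 0.1.9:
    # from brainstate.compile._error_if import jit_error_if
    # brainstate 0.2.0:
    from brainstate.transform._error_if import jit_error_if

    p = jnp.asarray([0.2, 0.5, 0.9])
    # Passes silently here; an entry outside [0, 1] would raise, also under jit.
    jit_error_if(jnp.any(jnp.logical_or(p < 0, p > 1)),
                 'Parameter p should be within [0, 1], but we got {p}', p=p)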
@@ -1,4 +1,4 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
1
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -29,11 +29,12 @@ from jax import lax, core, dtypes
29
29
 
30
30
  from brainstate import environ
31
31
  from brainstate._state import State
32
- from brainstate.compile._error_if import jit_error_if
33
32
  from brainstate.typing import DTypeLike, Size, SeedOrKey
34
- from ._random_for_unit import uniform_for_unit, permutation_for_unit
35
33
 
36
- __all__ = ['RandomState', 'DEFAULT', ]
34
+ __all__ = [
35
+ 'RandomState',
36
+ 'DEFAULT',
37
+ ]
37
38
 
38
39
  use_prng_key = True
39
40
 
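With the module-level imports of jit_error_if and the _random_for_unit helpers gone, __all__ now lists only RandomState and DEFAULT. A small usage sketch, assuming both names remain re-exported at brainstate.random as in 0.1.9:

    import brainstate

    rs = brainstate.random.RandomState(42)
    x = rs.rand(2, 3)    # uniform [0, 1); dtype defaults to brainstate.environ.dftype()
    y = rs.randn(2, 3)   # standard normal samples with the same default dtype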
@@ -43,7 +44,10 @@ class RandomState(State):
43
44
 
44
45
  # __slots__ = ('_backup', '_value')
45
46
 
46
- def __init__(self, seed_or_key: Optional[SeedOrKey] = None):
47
+ def __init__(
48
+ self,
49
+ seed_or_key: Optional[SeedOrKey] = None
50
+ ):
47
51
  """RandomState constructor.
48
52
 
49
53
  Parameters
@@ -69,10 +73,14 @@ class RandomState(State):
69
73
 
70
74
  self._backup = None
71
75
 
72
- def __repr__(self):
76
+ def __repr__(
77
+ self
78
+ ):
73
79
  return f'{self.__class__.__name__}({self.value})'
74
80
 
75
- def check_if_deleted(self):
81
+ def check_if_deleted(
82
+ self
83
+ ):
76
84
  if not use_prng_key and isinstance(self._value, np.ndarray):
77
85
  self._value = jr.key(np.random.randint(0, 10000))
78
86
 
@@ -104,7 +112,10 @@ class RandomState(State):
104
112
  def set_key(self, key: SeedOrKey):
105
113
  self.value = key
106
114
 
107
- def seed(self, seed_or_key: Optional[SeedOrKey] = None):
115
+ def seed(
116
+ self,
117
+ seed_or_key: Optional[SeedOrKey] = None
118
+ ):
108
119
  """Sets a new random seed.
109
120
 
110
121
  Parameters
@@ -132,7 +143,11 @@ class RandomState(State):
132
143
  raise ValueError(f'Invalid seed_or_key: {seed_or_key}')
133
144
  self.value = key
134
145
 
135
- def split_key(self, n: Optional[int] = None, backup: bool = False) -> SeedOrKey:
146
+ def split_key(
147
+ self,
148
+ n: Optional[int] = None,
149
+ backup: bool = False
150
+ ) -> SeedOrKey:
136
151
  """
137
152
  Create a new seed from the current seed.
138
153
 
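A sketch of the seeding and key-splitting API touched by this and the following hunks; each call advances the stored key so successive draws stay independent (same brainstate.random re-export assumption as above):

    import brainstate

    rs = brainstate.random.RandomState(0)
    rs.seed(123)            # reseed in place from an integer seed or an existing key
    k = rs.split_key()      # derive a fresh key; the stored key advances as a side effect
    ks = rs.split_key(n=4)  # n >= 1 fresh keys at once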
@@ -152,7 +167,7 @@ class RandomState(State):
152
167
  assert isinstance(n, int) and n >= 1, f'n should be an integer greater than 1, but we got {n}'
153
168
 
154
169
  if not isinstance(self.value, jax.Array):
155
- self.value = jnp.asarray(self.value, dtype=jnp.uint32)
170
+ self.value = u.math.asarray(self.value, dtype=jnp.uint32)
156
171
  keys = jr.split(self.value, num=2 if n is None else n + 1)
157
172
  self.value = keys[0]
158
173
  if backup:
@@ -162,7 +177,11 @@ class RandomState(State):
162
177
  else:
163
178
  return keys[1:]
164
179
 
165
- def self_assign_multi_keys(self, n: int, backup: bool = True):
180
+ def self_assign_multi_keys(
181
+ self,
182
+ n: int,
183
+ backup: bool = True
184
+ ):
166
185
  """
167
186
  Self-assign multiple keys to the current random state.
168
187
  """
@@ -178,10 +197,15 @@ class RandomState(State):
178
197
  # random functions #
179
198
  # ---------------- #
180
199
 
181
- def rand(self, *dn, key: Optional[SeedOrKey] = None, dtype: DTypeLike = None):
200
+ def rand(
201
+ self,
202
+ *dn,
203
+ key: Optional[SeedOrKey] = None,
204
+ dtype: DTypeLike = None
205
+ ):
182
206
  key = self.split_key() if key is None else _formalize_key(key)
183
207
  dtype = dtype or environ.dftype()
184
- r = uniform_for_unit(key, shape=dn, minval=0., maxval=1., dtype=dtype)
208
+ r = jr.uniform(key, dn, dtype)
185
209
  return r
186
210
 
187
211
  def randint(
@@ -198,8 +222,8 @@ class RandomState(State):
198
222
  high = _check_py_seq(high)
199
223
  low = _check_py_seq(low)
200
224
  if size is None:
201
- size = lax.broadcast_shapes(jnp.shape(low),
202
- jnp.shape(high))
225
+ size = lax.broadcast_shapes(u.math.shape(low),
226
+ u.math.shape(high))
203
227
  key = self.split_key() if key is None else _formalize_key(key)
204
228
  dtype = dtype or environ.ditype()
205
229
  r = jr.randint(key,
@@ -213,7 +237,7 @@ class RandomState(State):
213
237
  high=None,
214
238
  size: Optional[Size] = None,
215
239
  key: Optional[SeedOrKey] = None,
216
- dtype: DTypeLike = None,
240
+ dtype: DTypeLike = None
217
241
  ):
218
242
  low = _check_py_seq(low)
219
243
  high = _check_py_seq(high)
@@ -222,7 +246,7 @@ class RandomState(State):
222
246
  low = 1
223
247
  high += 1
224
248
  if size is None:
225
- size = lax.broadcast_shapes(jnp.shape(low), jnp.shape(high))
249
+ size = lax.broadcast_shapes(u.math.shape(low), u.math.shape(high))
226
250
  key = self.split_key() if key is None else _formalize_key(key)
227
251
  dtype = dtype or environ.ditype()
228
252
  r = jr.randint(key,
@@ -232,112 +256,137 @@ class RandomState(State):
232
256
  dtype=dtype)
233
257
  return r
234
258
 
235
- def randn(self, *dn, key: Optional[SeedOrKey] = None, dtype: DTypeLike = None):
259
+ def randn(
260
+ self,
261
+ *dn,
262
+ key: Optional[SeedOrKey] = None,
263
+ dtype: DTypeLike = None
264
+ ):
236
265
  key = self.split_key() if key is None else _formalize_key(key)
237
266
  dtype = dtype or environ.dftype()
238
267
  r = jr.normal(key, shape=dn, dtype=dtype)
239
268
  return r
240
269
 
241
- def random(self,
242
- size: Optional[Size] = None,
243
- key: Optional[SeedOrKey] = None,
244
- dtype: DTypeLike = None):
270
+ def random(
271
+ self,
272
+ size: Optional[Size] = None,
273
+ key: Optional[SeedOrKey] = None,
274
+ dtype: DTypeLike = None
275
+ ):
245
276
  dtype = dtype or environ.dftype()
246
277
  key = self.split_key() if key is None else _formalize_key(key)
247
- r = uniform_for_unit(key, shape=_size2shape(size), minval=0., maxval=1., dtype=dtype)
278
+ r = jr.uniform(key, _size2shape(size), dtype)
248
279
  return r
249
280
 
250
- def random_sample(self,
251
- size: Optional[Size] = None,
252
- key: Optional[SeedOrKey] = None,
253
- dtype: DTypeLike = None):
281
+ def random_sample(
282
+ self,
283
+ size: Optional[Size] = None,
284
+ key: Optional[SeedOrKey] = None,
285
+ dtype: DTypeLike = None
286
+ ):
254
287
  r = self.random(size=size, key=key, dtype=dtype)
255
288
  return r
256
289
 
257
- def ranf(self,
258
- size: Optional[Size] = None,
259
- key: Optional[SeedOrKey] = None,
260
- dtype: DTypeLike = None):
290
+ def ranf(
291
+ self,
292
+ size: Optional[Size] = None,
293
+ key: Optional[SeedOrKey] = None,
294
+ dtype: DTypeLike = None
295
+ ):
261
296
  r = self.random(size=size, key=key, dtype=dtype)
262
297
  return r
263
298
 
264
- def sample(self,
265
- size: Optional[Size] = None,
266
- key: Optional[SeedOrKey] = None,
267
- dtype: DTypeLike = None):
299
+ def sample(
300
+ self,
301
+ size: Optional[Size] = None,
302
+ key: Optional[SeedOrKey] = None,
303
+ dtype: DTypeLike = None
304
+ ):
268
305
  r = self.random(size=size, key=key, dtype=dtype)
269
306
  return r
270
307
 
271
- def choice(self,
272
- a,
273
- size: Optional[Size] = None,
274
- replace=True,
275
- p=None,
276
- key: Optional[SeedOrKey] = None):
308
+ def choice(
309
+ self,
310
+ a,
311
+ size: Optional[Size] = None,
312
+ replace=True,
313
+ p=None,
314
+ key: Optional[SeedOrKey] = None
315
+ ):
277
316
  a = _check_py_seq(a)
317
+ a, unit = u.split_mantissa_unit(a)
278
318
  p = _check_py_seq(p)
279
319
  key = self.split_key() if key is None else _formalize_key(key)
280
320
  r = jr.choice(key, a=a, shape=_size2shape(size), replace=replace, p=p)
281
- return r
321
+ return u.maybe_decimal(r * unit)
282
322
 
283
- def permutation(self,
284
- x,
285
- axis: int = 0,
286
- independent: bool = False,
287
- key: Optional[SeedOrKey] = None):
323
+ def permutation(
324
+ self,
325
+ x,
326
+ axis: int = 0,
327
+ independent: bool = False,
328
+ key: Optional[SeedOrKey] = None
329
+ ):
288
330
  x = _check_py_seq(x)
331
+ x, unit = u.split_mantissa_unit(x)
289
332
  key = self.split_key() if key is None else _formalize_key(key)
290
- r = permutation_for_unit(key, x, axis=axis, independent=independent)
291
- return r
333
+ r = jr.permutation(key, x, axis, independent=independent)
334
+ return u.maybe_decimal(r * unit)
292
335
 
293
- def shuffle(self,
294
- x,
295
- axis=0,
296
- key: Optional[SeedOrKey] = None):
297
- key = self.split_key() if key is None else _formalize_key(key)
298
- x = permutation_for_unit(key, x, axis=axis)
299
- return x
300
-
301
- def beta(self,
302
- a,
303
- b,
304
- size: Optional[Size] = None,
305
- key: Optional[SeedOrKey] = None,
306
- dtype: DTypeLike = None):
336
+ def shuffle(
337
+ self,
338
+ x,
339
+ axis=0,
340
+ key: Optional[SeedOrKey] = None
341
+ ):
342
+ return self.permutation(x, axis=axis, key=key, independent=False)
343
+
344
+ def beta(
345
+ self,
346
+ a,
347
+ b,
348
+ size: Optional[Size] = None,
349
+ key: Optional[SeedOrKey] = None,
350
+ dtype: DTypeLike = None
351
+ ):
307
352
  a = _check_py_seq(a)
308
353
  b = _check_py_seq(b)
309
354
  if size is None:
310
- size = lax.broadcast_shapes(jnp.shape(a), jnp.shape(b))
355
+ size = lax.broadcast_shapes(u.math.shape(a), u.math.shape(b))
311
356
  key = self.split_key() if key is None else _formalize_key(key)
312
357
  dtype = dtype or environ.dftype()
313
358
  r = jr.beta(key, a=a, b=b, shape=_size2shape(size), dtype=dtype)
314
359
  return r
315
360
 
316
- def exponential(self,
317
- scale=None,
318
- size: Optional[Size] = None,
319
- key: Optional[SeedOrKey] = None,
320
- dtype: DTypeLike = None):
361
+ def exponential(
362
+ self,
363
+ scale=None,
364
+ size: Optional[Size] = None,
365
+ key: Optional[SeedOrKey] = None,
366
+ dtype: DTypeLike = None
367
+ ):
321
368
  if size is None:
322
- size = jnp.shape(scale)
369
+ size = u.math.shape(scale)
323
370
  key = self.split_key() if key is None else _formalize_key(key)
324
371
  dtype = dtype or environ.dftype()
325
- scale = jnp.asarray(scale, dtype=dtype)
326
372
  r = jr.exponential(key, shape=_size2shape(size), dtype=dtype)
327
373
  if scale is not None:
374
+ scale = u.math.asarray(scale, dtype=dtype)
328
375
  r = r / scale
329
376
  return r
330
377
 
331
- def gamma(self,
332
- shape,
333
- scale=None,
334
- size: Optional[Size] = None,
335
- key: Optional[SeedOrKey] = None,
336
- dtype: DTypeLike = None):
378
+ def gamma(
379
+ self,
380
+ shape,
381
+ scale=None,
382
+ size: Optional[Size] = None,
383
+ key: Optional[SeedOrKey] = None,
384
+ dtype: DTypeLike = None
385
+ ):
337
386
  shape = _check_py_seq(shape)
338
387
  scale = _check_py_seq(scale)
339
388
  if size is None:
340
- size = lax.broadcast_shapes(jnp.shape(shape), jnp.shape(scale))
389
+ size = lax.broadcast_shapes(u.math.shape(shape), u.math.shape(scale))
341
390
  key = self.split_key() if key is None else _formalize_key(key)
342
391
  dtype = dtype or environ.dftype()
343
392
  r = jr.gamma(key, a=shape, shape=_size2shape(size), dtype=dtype)
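In the hunk above, choice and permutation now split off the physical unit with u.split_mantissa_unit, sample on the bare mantissa, and re-attach the unit through u.maybe_decimal, while shuffle simply delegates to permutation. A sketch of what this enables, assuming brainunit exposes the mV unit used here purely for illustration:

    import brainunit as u
    import brainstate

    rs = brainstate.random.RandomState(1)
    x = u.math.asarray([1., 2., 3., 4.]) * u.mV   # a Quantity carrying mV

    perm = rs.permutation(x)     # permuted values, still in mV
    pick = rs.choice(x, size=2)  # sampled values keep the unit of x
    mix = rs.shuffle(x)          # shuffle() is now permutation() with independent=False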
@@ -345,166 +394,196 @@ class RandomState(State):
345
394
  r = r * scale
346
395
  return r
347
396
 
348
- def gumbel(self,
349
- loc=None,
350
- scale=None,
351
- size: Optional[Size] = None,
352
- key: Optional[SeedOrKey] = None,
353
- dtype: DTypeLike = None):
397
+ def gumbel(
398
+ self,
399
+ loc=None,
400
+ scale=None,
401
+ size: Optional[Size] = None,
402
+ key: Optional[SeedOrKey] = None,
403
+ dtype: DTypeLike = None
404
+ ):
354
405
  loc = _check_py_seq(loc)
355
406
  scale = _check_py_seq(scale)
356
407
  if size is None:
357
- size = lax.broadcast_shapes(jnp.shape(loc), jnp.shape(scale))
408
+ size = lax.broadcast_shapes(u.math.shape(loc), u.math.shape(scale))
358
409
  key = self.split_key() if key is None else _formalize_key(key)
359
410
  dtype = dtype or environ.dftype()
360
411
  r = _loc_scale(loc, scale, jr.gumbel(key, shape=_size2shape(size), dtype=dtype))
361
412
  return r
362
413
 
363
- def laplace(self,
364
- loc=None,
365
- scale=None,
366
- size: Optional[Size] = None,
367
- key: Optional[SeedOrKey] = None,
368
- dtype: DTypeLike = None):
414
+ def laplace(
415
+ self,
416
+ loc=None,
417
+ scale=None,
418
+ size: Optional[Size] = None,
419
+ key: Optional[SeedOrKey] = None,
420
+ dtype: DTypeLike = None
421
+ ):
369
422
  loc = _check_py_seq(loc)
370
423
  scale = _check_py_seq(scale)
371
424
  if size is None:
372
- size = lax.broadcast_shapes(jnp.shape(loc), jnp.shape(scale))
425
+ size = lax.broadcast_shapes(u.math.shape(loc), u.math.shape(scale))
373
426
  key = self.split_key() if key is None else _formalize_key(key)
374
427
  dtype = dtype or environ.dftype()
375
428
  r = _loc_scale(loc, scale, jr.laplace(key, shape=_size2shape(size), dtype=dtype))
376
429
  return r
377
430
 
378
- def logistic(self,
379
- loc=None,
380
- scale=None,
381
- size: Optional[Size] = None,
382
- key: Optional[SeedOrKey] = None,
383
- dtype: DTypeLike = None):
431
+ def logistic(
432
+ self,
433
+ loc=None,
434
+ scale=None,
435
+ size: Optional[Size] = None,
436
+ key: Optional[SeedOrKey] = None,
437
+ dtype: DTypeLike = None
438
+ ):
384
439
  loc = _check_py_seq(loc)
385
440
  scale = _check_py_seq(scale)
386
441
  if size is None:
387
442
  size = lax.broadcast_shapes(
388
- jnp.shape(loc) if loc is not None else (),
389
- jnp.shape(scale) if scale is not None else ()
443
+ u.math.shape(loc) if loc is not None else (),
444
+ u.math.shape(scale) if scale is not None else ()
390
445
  )
391
446
  key = self.split_key() if key is None else _formalize_key(key)
392
447
  dtype = dtype or environ.dftype()
393
448
  r = _loc_scale(loc, scale, jr.logistic(key, shape=_size2shape(size), dtype=dtype))
394
449
  return r
395
450
 
396
- def normal(self,
397
- loc=None,
398
- scale=None,
399
- size: Optional[Size] = None,
400
- key: Optional[SeedOrKey] = None,
401
- dtype: DTypeLike = None):
451
+ def normal(
452
+ self,
453
+ loc=None,
454
+ scale=None,
455
+ size: Optional[Size] = None,
456
+ key: Optional[SeedOrKey] = None,
457
+ dtype: DTypeLike = None
458
+ ):
402
459
  loc = _check_py_seq(loc)
403
460
  scale = _check_py_seq(scale)
404
461
  if size is None:
405
462
  size = lax.broadcast_shapes(
406
- jnp.shape(scale) if scale is not None else (),
407
- jnp.shape(loc) if loc is not None else ()
463
+ u.math.shape(scale) if scale is not None else (),
464
+ u.math.shape(loc) if loc is not None else ()
408
465
  )
409
466
  key = self.split_key() if key is None else _formalize_key(key)
410
467
  dtype = dtype or environ.dftype()
411
468
  r = _loc_scale(loc, scale, jr.normal(key, shape=_size2shape(size), dtype=dtype))
412
469
  return r
413
470
 
414
- def pareto(self,
415
- a,
416
- size: Optional[Size] = None,
417
- key: Optional[SeedOrKey] = None,
418
- dtype: DTypeLike = None):
471
+ def pareto(
472
+ self,
473
+ a,
474
+ size: Optional[Size] = None,
475
+ key: Optional[SeedOrKey] = None,
476
+ dtype: DTypeLike = None
477
+ ):
419
478
  if size is None:
420
- size = jnp.shape(a)
479
+ size = u.math.shape(a)
421
480
  key = self.split_key() if key is None else _formalize_key(key)
422
481
  dtype = dtype or environ.dftype()
423
- a = jnp.asarray(a, dtype=dtype)
482
+ a = u.math.asarray(a, dtype=dtype)
424
483
  r = jr.pareto(key, b=a, shape=_size2shape(size), dtype=dtype)
425
484
  return r
426
485
 
427
- def poisson(self,
428
- lam=1.0,
429
- size: Optional[Size] = None,
430
- key: Optional[SeedOrKey] = None,
431
- dtype: DTypeLike = None):
486
+ def poisson(
487
+ self,
488
+ lam=1.0,
489
+ size: Optional[Size] = None,
490
+ key: Optional[SeedOrKey] = None,
491
+ dtype: DTypeLike = None
492
+ ):
432
493
  lam = _check_py_seq(lam)
433
494
  if size is None:
434
- size = jnp.shape(lam)
495
+ size = u.math.shape(lam)
435
496
  key = self.split_key() if key is None else _formalize_key(key)
436
497
  dtype = dtype or environ.ditype()
437
498
  r = jr.poisson(key, lam=lam, shape=_size2shape(size), dtype=dtype)
438
499
  return r
439
500
 
440
- def standard_cauchy(self,
441
- size: Optional[Size] = None,
442
- key: Optional[SeedOrKey] = None,
443
- dtype: DTypeLike = None):
501
+ def standard_cauchy(
502
+ self,
503
+ size: Optional[Size] = None,
504
+ key: Optional[SeedOrKey] = None,
505
+ dtype: DTypeLike = None
506
+ ):
444
507
  key = self.split_key() if key is None else _formalize_key(key)
445
508
  dtype = dtype or environ.dftype()
446
509
  r = jr.cauchy(key, shape=_size2shape(size), dtype=dtype)
447
510
  return r
448
511
 
449
- def standard_exponential(self,
450
- size: Optional[Size] = None,
451
- key: Optional[SeedOrKey] = None,
452
- dtype: DTypeLike = None):
512
+ def standard_exponential(
513
+ self,
514
+ size: Optional[Size] = None,
515
+ key: Optional[SeedOrKey] = None,
516
+ dtype: DTypeLike = None
517
+ ):
453
518
  key = self.split_key() if key is None else _formalize_key(key)
454
519
  dtype = dtype or environ.dftype()
455
520
  r = jr.exponential(key, shape=_size2shape(size), dtype=dtype)
456
521
  return r
457
522
 
458
- def standard_gamma(self,
459
- shape,
460
- size: Optional[Size] = None,
461
- key: Optional[SeedOrKey] = None,
462
- dtype: DTypeLike = None):
523
+ def standard_gamma(
524
+ self,
525
+ shape,
526
+ size: Optional[Size] = None,
527
+ key: Optional[SeedOrKey] = None,
528
+ dtype: DTypeLike = None
529
+ ):
463
530
  shape = _check_py_seq(shape)
464
531
  if size is None:
465
- size = jnp.shape(shape) if shape is not None else ()
532
+ size = u.math.shape(shape) if shape is not None else ()
466
533
  key = self.split_key() if key is None else _formalize_key(key)
467
534
  dtype = dtype or environ.dftype()
468
535
  r = jr.gamma(key, a=shape, shape=_size2shape(size), dtype=dtype)
469
536
  return r
470
537
 
471
- def standard_normal(self,
472
- size: Optional[Size] = None,
473
- key: Optional[SeedOrKey] = None,
474
- dtype: DTypeLike = None):
538
+ def standard_normal(
539
+ self,
540
+ size: Optional[Size] = None,
541
+ key: Optional[SeedOrKey] = None,
542
+ dtype: DTypeLike = None
543
+ ):
475
544
  key = self.split_key() if key is None else _formalize_key(key)
476
545
  dtype = dtype or environ.dftype()
477
546
  r = jr.normal(key, shape=_size2shape(size), dtype=dtype)
478
547
  return r
479
548
 
480
- def standard_t(self, df,
481
- size: Optional[Size] = None,
482
- key: Optional[SeedOrKey] = None,
483
- dtype: DTypeLike = None):
549
+ def standard_t(
550
+ self,
551
+ df,
552
+ size: Optional[Size] = None,
553
+ key: Optional[SeedOrKey] = None,
554
+ dtype: DTypeLike = None
555
+ ):
484
556
  df = _check_py_seq(df)
485
557
  if size is None:
486
- size = jnp.shape(size) if size is not None else ()
558
+ size = u.math.shape(size) if size is not None else ()
487
559
  key = self.split_key() if key is None else _formalize_key(key)
488
560
  dtype = dtype or environ.dftype()
489
561
  r = jr.t(key, df=df, shape=_size2shape(size), dtype=dtype)
490
562
  return r
491
563
 
492
- def uniform(self,
493
- low=0.0,
494
- high=1.0,
495
- size: Optional[Size] = None,
496
- key: Optional[SeedOrKey] = None,
497
- dtype: DTypeLike = None):
498
- low = _check_py_seq(low)
499
- high = _check_py_seq(high)
564
+ def uniform(
565
+ self,
566
+ low=0.0,
567
+ high=1.0,
568
+ size: Optional[Size] = None,
569
+ key: Optional[SeedOrKey] = None,
570
+ dtype: DTypeLike = None
571
+ ):
572
+ low, unit = u.split_mantissa_unit(_check_py_seq(low))
573
+ high = u.Quantity(_check_py_seq(high)).to(unit).mantissa
500
574
  if size is None:
501
- size = lax.broadcast_shapes(jnp.shape(low), jnp.shape(high))
575
+ size = lax.broadcast_shapes(u.math.shape(low), u.math.shape(high))
502
576
  key = self.split_key() if key is None else _formalize_key(key)
503
577
  dtype = dtype or environ.dftype()
504
- r = uniform_for_unit(key, shape=_size2shape(size), minval=low, maxval=high, dtype=dtype)
505
- return r
578
+ r = jr.uniform(key, _size2shape(size), dtype=dtype, minval=low, maxval=high)
579
+ return u.maybe_decimal(r * unit)
506
580
 
507
- def __norm_cdf(self, x, sqrt2, dtype):
581
+ def __norm_cdf(
582
+ self,
583
+ x,
584
+ sqrt2,
585
+ dtype
586
+ ):
508
587
  # Computes standard normal cumulative distribution function
509
588
  return (np.asarray(1., dtype) + lax.erf(x / sqrt2)) / np.asarray(2., dtype)
510
589
 
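The rewritten uniform above is unit-aware in the same way: the unit is taken from low, high is converted into that unit before sampling, and the result comes back with the unit re-attached (again assuming brainunit's mV, chosen only for illustration):

    import brainunit as u
    import brainstate

    rs = brainstate.random.RandomState(2)
    v = rs.uniform(0. * u.mV, 10. * u.mV, size=(3,))   # Quantity in mV, values in [0, 10)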
@@ -513,10 +592,11 @@ class RandomState(State):
513
592
  lower,
514
593
  upper,
515
594
  size: Optional[Size] = None,
516
- loc=0.,
517
- scale=1.,
595
+ loc=0.0,
596
+ scale=1.0,
518
597
  key: Optional[SeedOrKey] = None,
519
- dtype: DTypeLike = None
598
+ dtype: DTypeLike = None,
599
+ check_valid: bool = True
520
600
  ):
521
601
  lower = _check_py_seq(lower)
522
602
  upper = _check_py_seq(upper)
@@ -524,29 +604,31 @@ class RandomState(State):
524
604
  scale = _check_py_seq(scale)
525
605
  dtype = dtype or environ.dftype()
526
606
 
527
- lower = u.math.asarray(lower, dtype=dtype)
607
+ lower, unit = u.split_mantissa_unit(u.math.asarray(lower, dtype=dtype))
528
608
  upper = u.math.asarray(upper, dtype=dtype)
529
609
  loc = u.math.asarray(loc, dtype=dtype)
530
610
  scale = u.math.asarray(scale, dtype=dtype)
531
- unit = u.get_unit(lower)
532
- lower, upper, loc, scale = (
533
- lower.mantissa if isinstance(lower, u.Quantity) else lower,
611
+ upper, loc, scale = (
534
612
  u.Quantity(upper).in_unit(unit).mantissa,
535
613
  u.Quantity(loc).in_unit(unit).mantissa,
536
614
  u.Quantity(scale).in_unit(unit).mantissa
537
615
  )
538
616
 
539
- jit_error_if(
540
- u.math.any(u.math.logical_or(loc < lower - 2 * scale, loc > upper + 2 * scale)),
541
- "mean is more than 2 std from [lower, upper] in truncated_normal. "
542
- "The distribution of values may be incorrect."
543
- )
617
+ if check_valid:
618
+ from brainstate.transform._error_if import jit_error_if
619
+ jit_error_if(
620
+ u.math.any(u.math.logical_or(loc < lower - 2 * scale, loc > upper + 2 * scale)),
621
+ "mean is more than 2 std from [lower, upper] in truncated_normal. "
622
+ "The distribution of values may be incorrect."
623
+ )
544
624
 
545
625
  if size is None:
546
- size = u.math.broadcast_shapes(jnp.shape(lower),
547
- jnp.shape(upper),
548
- jnp.shape(loc),
549
- jnp.shape(scale))
626
+ size = u.math.broadcast_shapes(
627
+ u.math.shape(lower),
628
+ u.math.shape(upper),
629
+ u.math.shape(loc),
630
+ u.math.shape(scale)
631
+ )
550
632
 
551
633
  # Values are generated by using a truncated uniform distribution and
552
634
  # then using the inverse CDF for the normal distribution.
@@ -558,7 +640,7 @@ class RandomState(State):
558
640
  # Uniformly fill tensor with values from [l, u], then translate to
559
641
  # [2l-1, 2u-1].
560
642
  key = self.split_key() if key is None else _formalize_key(key)
561
- out = uniform_for_unit(
643
+ out = jr.uniform(
562
644
  key, size, dtype,
563
645
  minval=lax.nextafter(2 * l - 1, np.array(np.inf, dtype=dtype)),
564
646
  maxval=lax.nextafter(2 * u_ - 1, np.array(-np.inf, dtype=dtype))
@@ -577,19 +659,24 @@ class RandomState(State):
577
659
  lax.nextafter(lax.stop_gradient(lower), np.array(np.inf, dtype=dtype)),
578
660
  lax.nextafter(lax.stop_gradient(upper), np.array(-np.inf, dtype=dtype))
579
661
  )
580
- return out if unit.is_unitless else u.Quantity(out, unit=unit)
662
+ return u.maybe_decimal(out * unit)
581
663
 
582
664
  def _check_p(self, *args, **kwargs):
583
665
  raise ValueError('Parameter p should be within [0, 1], but we got {p}')
584
666
 
585
- def bernoulli(self,
586
- p,
587
- size: Optional[Size] = None,
588
- key: Optional[SeedOrKey] = None):
667
+ def bernoulli(
668
+ self,
669
+ p,
670
+ size: Optional[Size] = None,
671
+ key: Optional[SeedOrKey] = None,
672
+ check_valid: bool = True
673
+ ):
589
674
  p = _check_py_seq(p)
590
- jit_error_if(jnp.any(jnp.logical_or(p < 0, p > 1)), self._check_p, p=p)
675
+ if check_valid:
676
+ from brainstate.transform._error_if import jit_error_if
677
+ jit_error_if(jnp.any(jnp.logical_or(p < 0, p > 1)), self._check_p, p=p)
591
678
  if size is None:
592
- size = jnp.shape(p)
679
+ size = u.math.shape(p)
593
680
  key = self.split_key() if key is None else _formalize_key(key)
594
681
  r = jr.bernoulli(key, p=p, shape=_size2shape(size))
595
682
  return r
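truncated_normal and bernoulli (above) gain a check_valid flag: the jit_error_if validation is now imported lazily from brainstate.transform and can be switched off when the arguments are already known to be valid:

    import brainstate

    rs = brainstate.random.RandomState(3)
    spikes = rs.bernoulli(p=0.1, size=(100,))                    # validated by default
    fast = rs.bernoulli(p=0.1, size=(100,), check_valid=False)   # skips the runtime check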
@@ -606,21 +693,21 @@ class RandomState(State):
606
693
  sigma = _check_py_seq(sigma)
607
694
  mean = u.math.asarray(mean, dtype=dtype)
608
695
  sigma = u.math.asarray(sigma, dtype=dtype)
609
- unit = mean.unit if isinstance(mean, u.Quantity) else u.Unit()
696
+ unit = mean.unit if isinstance(mean, u.Quantity) else u.UNITLESS
610
697
  mean = mean.mantissa if isinstance(mean, u.Quantity) else mean
611
698
  sigma = sigma.in_unit(unit).mantissa if isinstance(sigma, u.Quantity) else sigma
612
699
 
613
700
  if size is None:
614
701
  size = jnp.broadcast_shapes(
615
- jnp.shape(mean) if mean is not None else (),
616
- jnp.shape(sigma) if sigma is not None else ()
702
+ u.math.shape(mean) if mean is not None else (),
703
+ u.math.shape(sigma) if sigma is not None else ()
617
704
  )
618
705
  key = self.split_key() if key is None else _formalize_key(key)
619
706
  dtype = dtype or environ.dftype()
620
707
  samples = jr.normal(key, shape=_size2shape(size), dtype=dtype)
621
708
  samples = _loc_scale(mean, sigma, samples)
622
709
  samples = jnp.exp(samples)
623
- return samples if unit.is_unitless else u.Quantity(samples, unit=unit)
710
+ return u.maybe_decimal(samples * unit)
624
711
 
625
712
  def binomial(
626
713
  self,
@@ -629,28 +716,31 @@ class RandomState(State):
629
716
  size: Optional[Size] = None,
630
717
  key: Optional[SeedOrKey] = None,
631
718
  dtype: DTypeLike = None,
632
- check_valid: bool = True,
719
+ check_valid: bool = True
633
720
  ):
634
721
  n = _check_py_seq(n)
635
722
  p = _check_py_seq(p)
636
723
  if check_valid:
724
+ from brainstate.transform._error_if import jit_error_if
637
725
  jit_error_if(
638
726
  jnp.any(jnp.logical_or(p < 0, p > 1)),
639
727
  'Parameter p should be within [0, 1], but we got {p}',
640
728
  p=p
641
729
  )
642
730
  if size is None:
643
- size = jnp.broadcast_shapes(jnp.shape(n), jnp.shape(p))
731
+ size = jnp.broadcast_shapes(u.math.shape(n), u.math.shape(p))
644
732
  key = self.split_key() if key is None else _formalize_key(key)
645
733
  r = jr.binomial(key, n, p, shape=_size2shape(size))
646
734
  dtype = dtype or environ.ditype()
647
- return jnp.asarray(r, dtype=dtype)
735
+ return u.math.asarray(r, dtype=dtype)
648
736
 
649
- def chisquare(self,
650
- df,
651
- size: Optional[Size] = None,
652
- key: Optional[SeedOrKey] = None,
653
- dtype: DTypeLike = None):
737
+ def chisquare(
738
+ self,
739
+ df,
740
+ size: Optional[Size] = None,
741
+ key: Optional[SeedOrKey] = None,
742
+ dtype: DTypeLike = None
743
+ ):
654
744
  df = _check_py_seq(df)
655
745
  key = self.split_key() if key is None else _formalize_key(key)
656
746
  dtype = dtype or environ.dftype()
@@ -665,52 +755,61 @@ class RandomState(State):
665
755
  dist = dist.sum(axis=0)
666
756
  return dist
667
757
 
668
- def dirichlet(self,
669
- alpha,
670
- size: Optional[Size] = None,
671
- key: Optional[SeedOrKey] = None,
672
- dtype: DTypeLike = None):
758
+ def dirichlet(
759
+ self,
760
+ alpha,
761
+ size: Optional[Size] = None,
762
+ key: Optional[SeedOrKey] = None,
763
+ dtype: DTypeLike = None
764
+ ):
673
765
  key = self.split_key() if key is None else _formalize_key(key)
674
766
  alpha = _check_py_seq(alpha)
675
767
  dtype = dtype or environ.dftype()
676
768
  r = jr.dirichlet(key, alpha=alpha, shape=_size2shape(size), dtype=dtype)
677
769
  return r
678
770
 
679
- def geometric(self,
680
- p,
681
- size: Optional[Size] = None,
682
- key: Optional[SeedOrKey] = None,
683
- dtype: DTypeLike = None):
771
+ def geometric(
772
+ self,
773
+ p,
774
+ size: Optional[Size] = None,
775
+ key: Optional[SeedOrKey] = None,
776
+ dtype: DTypeLike = None
777
+ ):
684
778
  p = _check_py_seq(p)
685
779
  if size is None:
686
- size = jnp.shape(p)
780
+ size = u.math.shape(p)
687
781
  key = self.split_key() if key is None else _formalize_key(key)
688
782
  dtype = dtype or environ.dftype()
689
- u_ = uniform_for_unit(key, size, dtype=dtype)
783
+ u_ = jr.uniform(key, size, dtype)
690
784
  r = jnp.floor(jnp.log1p(-u_) / jnp.log1p(-p))
691
785
  return r
692
786
 
693
787
  def _check_p2(self, p):
694
788
  raise ValueError(f'We require `sum(pvals[:-1]) <= 1`. But we got {p}')
695
789
 
696
- def multinomial(self,
697
- n,
698
- pvals,
699
- size: Optional[Size] = None,
700
- key: Optional[SeedOrKey] = None,
701
- dtype: DTypeLike = None):
790
+ def multinomial(
791
+ self,
792
+ n,
793
+ pvals,
794
+ size: Optional[Size] = None,
795
+ key: Optional[SeedOrKey] = None,
796
+ dtype: DTypeLike = None,
797
+ check_valid: bool = True
798
+ ):
702
799
  key = self.split_key() if key is None else _formalize_key(key)
703
800
  n = _check_py_seq(n)
704
801
  pvals = _check_py_seq(pvals)
705
- jit_error_if(jnp.sum(pvals[:-1]) > 1., self._check_p2, pvals)
802
+ if check_valid:
803
+ from brainstate.transform._error_if import jit_error_if
804
+ jit_error_if(jnp.sum(pvals[:-1]) > 1., self._check_p2, pvals)
706
805
  if isinstance(n, jax.core.Tracer):
707
806
  raise ValueError("The total count parameter `n` should not be a jax abstract array.")
708
807
  size = _size2shape(size)
709
808
  n_max = int(np.max(jax.device_get(n)))
710
- batch_shape = lax.broadcast_shapes(jnp.shape(pvals)[:-1], jnp.shape(n))
809
+ batch_shape = lax.broadcast_shapes(u.math.shape(pvals)[:-1], u.math.shape(n))
711
810
  r = _multinomial(key, pvals, n, n_max, batch_shape + size)
712
811
  dtype = dtype or environ.ditype()
713
- return jnp.asarray(r, dtype=dtype)
812
+ return u.math.asarray(r, dtype=dtype)
714
813
 
715
814
  def multivariate_normal(
716
815
  self,
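multinomial (above) gets the same optional validation, and it still requires the total count n to be a concrete value rather than a traced array. A short sketch:

    import brainstate

    rs = brainstate.random.RandomState(4)
    counts = rs.multinomial(n=10, pvals=[0.2, 0.3, 0.5], size=(5,), check_valid=False)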
@@ -739,9 +838,9 @@ class RandomState(State):
739
838
  if not jnp.ndim(cov) >= 2:
740
839
  raise ValueError(f"multivariate_normal requires cov.ndim >= 2, got cov.ndim == {jnp.ndim(cov)}")
741
840
  n = mean.shape[-1]
742
- if jnp.shape(cov)[-2:] != (n, n):
841
+ if u.math.shape(cov)[-2:] != (n, n):
743
842
  raise ValueError(f"multivariate_normal requires cov.shape == (..., n, n) for n={n}, "
744
- f"but got cov.shape == {jnp.shape(cov)}.")
843
+ f"but got cov.shape == {u.math.shape(cov)}.")
745
844
  if size is None:
746
845
  size = lax.broadcast_shapes(mean.shape[:-1], cov.shape[:-2])
747
846
  else:
@@ -758,92 +857,104 @@ class RandomState(State):
758
857
  factor = jnp.linalg.cholesky(cov)
759
858
  normal_samples = jr.normal(key, size + mean.shape[-1:], dtype=dtype)
760
859
  r = mean + jnp.einsum('...ij,...j->...i', factor, normal_samples)
761
- return r if unit.is_unitless else u.Quantity(r, unit=unit)
860
+ return u.maybe_decimal(r * unit)
762
861
 
763
- def rayleigh(self,
764
- scale=1.0,
765
- size: Optional[Size] = None,
766
- key: Optional[SeedOrKey] = None,
767
- dtype: DTypeLike = None):
862
+ def rayleigh(
863
+ self,
864
+ scale=1.0,
865
+ size: Optional[Size] = None,
866
+ key: Optional[SeedOrKey] = None,
867
+ dtype: DTypeLike = None
868
+ ):
768
869
  scale = _check_py_seq(scale)
769
870
  if size is None:
770
- size = jnp.shape(scale)
871
+ size = u.math.shape(scale)
771
872
  key = self.split_key() if key is None else _formalize_key(key)
772
873
  dtype = dtype or environ.dftype()
773
- x = jnp.sqrt(-2. * jnp.log(uniform_for_unit(key, shape=_size2shape(size), minval=0, maxval=1, dtype=dtype)))
874
+ x = jnp.sqrt(-2. * jnp.log(jr.uniform(key, shape=_size2shape(size), dtype=dtype)))
774
875
  r = x * scale
775
876
  return r
776
877
 
777
- def triangular(self,
778
- size: Optional[Size] = None,
779
- key: Optional[SeedOrKey] = None):
878
+ def triangular(
879
+ self,
880
+ size: Optional[Size] = None,
881
+ key: Optional[SeedOrKey] = None
882
+ ):
780
883
  key = self.split_key() if key is None else _formalize_key(key)
781
884
  bernoulli_samples = jr.bernoulli(key, p=0.5, shape=_size2shape(size))
782
885
  r = 2 * bernoulli_samples - 1
783
886
  return r
784
887
 
785
- def vonmises(self,
786
- mu,
787
- kappa,
788
- size: Optional[Size] = None,
789
- key: Optional[SeedOrKey] = None,
790
- dtype: DTypeLike = None):
888
+ def vonmises(
889
+ self,
890
+ mu,
891
+ kappa,
892
+ size: Optional[Size] = None,
893
+ key: Optional[SeedOrKey] = None,
894
+ dtype: DTypeLike = None
895
+ ):
791
896
  key = self.split_key() if key is None else _formalize_key(key)
792
897
  dtype = dtype or environ.dftype()
793
- mu = jnp.asarray(_check_py_seq(mu), dtype=dtype)
794
- kappa = jnp.asarray(_check_py_seq(kappa), dtype=dtype)
898
+ mu = u.math.asarray(_check_py_seq(mu), dtype=dtype)
899
+ kappa = u.math.asarray(_check_py_seq(kappa), dtype=dtype)
795
900
  if size is None:
796
- size = lax.broadcast_shapes(jnp.shape(mu), jnp.shape(kappa))
901
+ size = lax.broadcast_shapes(u.math.shape(mu), u.math.shape(kappa))
797
902
  size = _size2shape(size)
798
903
  samples = _von_mises_centered(key, kappa, size, dtype=dtype)
799
904
  samples = samples + mu
800
905
  samples = (samples + jnp.pi) % (2.0 * jnp.pi) - jnp.pi
801
906
  return samples
802
907
 
803
- def weibull(self,
804
- a,
805
- size: Optional[Size] = None,
806
- key: Optional[SeedOrKey] = None,
807
- dtype: DTypeLike = None):
908
+ def weibull(
909
+ self,
910
+ a,
911
+ size: Optional[Size] = None,
912
+ key: Optional[SeedOrKey] = None,
913
+ dtype: DTypeLike = None
914
+ ):
808
915
  key = self.split_key() if key is None else _formalize_key(key)
809
916
  a = _check_py_seq(a)
810
917
  if size is None:
811
- size = jnp.shape(a)
918
+ size = u.math.shape(a)
812
919
  else:
813
920
  if jnp.size(a) > 1:
814
921
  raise ValueError(f'"a" should be a scalar when "size" is provided. But we got {a}')
815
922
  size = _size2shape(size)
816
923
  dtype = dtype or environ.dftype()
817
- random_uniform = uniform_for_unit(key=key, shape=size, minval=0, maxval=1, dtype=dtype)
924
+ random_uniform = jr.uniform(key=key, shape=size, dtype=dtype)
818
925
  r = jnp.power(-jnp.log1p(-random_uniform), 1.0 / a)
819
926
  return r
820
927
 
821
- def weibull_min(self,
822
- a,
823
- scale=None,
824
- size: Optional[Size] = None,
825
- key: Optional[SeedOrKey] = None,
826
- dtype: DTypeLike = None):
928
+ def weibull_min(
929
+ self,
930
+ a,
931
+ scale=None,
932
+ size: Optional[Size] = None,
933
+ key: Optional[SeedOrKey] = None,
934
+ dtype: DTypeLike = None
935
+ ):
827
936
  key = self.split_key() if key is None else _formalize_key(key)
828
937
  a = _check_py_seq(a)
829
938
  scale = _check_py_seq(scale)
830
939
  if size is None:
831
- size = jnp.broadcast_shapes(jnp.shape(a), jnp.shape(scale) if scale is not None else ())
940
+ size = jnp.broadcast_shapes(u.math.shape(a), u.math.shape(scale) if scale is not None else ())
832
941
  else:
833
942
  if jnp.size(a) > 1:
834
943
  raise ValueError(f'"a" should be a scalar when "size" is provided. But we got {a}')
835
944
  size = _size2shape(size)
836
945
  dtype = dtype or environ.dftype()
837
- random_uniform = uniform_for_unit(key=key, shape=size, minval=0, maxval=1, dtype=dtype)
946
+ random_uniform = jr.uniform(key=key, shape=size, dtype=dtype)
838
947
  r = jnp.power(-jnp.log1p(-random_uniform), 1.0 / a)
839
948
  if scale is not None:
840
949
  r /= scale
841
950
  return r
842
951
 
843
- def maxwell(self,
844
- size: Optional[Size] = None,
845
- key: Optional[SeedOrKey] = None,
846
- dtype: DTypeLike = None):
952
+ def maxwell(
953
+ self,
954
+ size: Optional[Size] = None,
955
+ key: Optional[SeedOrKey] = None,
956
+ dtype: DTypeLike = None
957
+ ):
847
958
  key = self.split_key() if key is None else _formalize_key(key)
848
959
  shape = _size2shape(size) + (3,)
849
960
  dtype = dtype or environ.dftype()
@@ -851,16 +962,18 @@ class RandomState(State):
851
962
  r = jnp.linalg.norm(norm_rvs, axis=-1)
852
963
  return r
853
964
 
854
- def negative_binomial(self,
855
- n,
856
- p,
857
- size: Optional[Size] = None,
858
- key: Optional[SeedOrKey] = None,
859
- dtype: DTypeLike = None):
965
+ def negative_binomial(
966
+ self,
967
+ n,
968
+ p,
969
+ size: Optional[Size] = None,
970
+ key: Optional[SeedOrKey] = None,
971
+ dtype: DTypeLike = None
972
+ ):
860
973
  n = _check_py_seq(n)
861
974
  p = _check_py_seq(p)
862
975
  if size is None:
863
- size = lax.broadcast_shapes(jnp.shape(n), jnp.shape(p))
976
+ size = lax.broadcast_shapes(u.math.shape(n), u.math.shape(p))
864
977
  size = _size2shape(size)
865
978
  logits = jnp.log(p) - jnp.log1p(-p)
866
979
  if key is None:
@@ -871,18 +984,20 @@ class RandomState(State):
871
984
  r = self.poisson(lam=rate, key=keys[1], dtype=dtype or environ.ditype())
872
985
  return r
873
986
 
874
- def wald(self,
875
- mean,
876
- scale,
877
- size: Optional[Size] = None,
878
- key: Optional[SeedOrKey] = None,
879
- dtype: DTypeLike = None):
987
+ def wald(
988
+ self,
989
+ mean,
990
+ scale,
991
+ size: Optional[Size] = None,
992
+ key: Optional[SeedOrKey] = None,
993
+ dtype: DTypeLike = None
994
+ ):
880
995
  dtype = dtype or environ.dftype()
881
996
  key = self.split_key() if key is None else _formalize_key(key)
882
- mean = jnp.asarray(_check_py_seq(mean), dtype=dtype)
883
- scale = jnp.asarray(_check_py_seq(scale), dtype=dtype)
997
+ mean = u.math.asarray(_check_py_seq(mean), dtype=dtype)
998
+ scale = u.math.asarray(_check_py_seq(scale), dtype=dtype)
884
999
  if size is None:
885
- size = lax.broadcast_shapes(jnp.shape(mean), jnp.shape(scale))
1000
+ size = lax.broadcast_shapes(u.math.shape(mean), u.math.shape(scale))
886
1001
  size = _size2shape(size)
887
1002
  sampled_chi2 = jnp.square(self.randn(*size))
888
1003
  sampled_uniform = self.uniform(size=size, key=key, dtype=dtype)
@@ -917,13 +1032,15 @@ class RandomState(State):
917
1032
  jnp.square(mean) / sampled)
918
1033
  return res
919
1034
 
920
- def t(self,
921
- df,
922
- size: Optional[Size] = None,
923
- key: Optional[SeedOrKey] = None,
924
- dtype: DTypeLike = None):
1035
+ def t(
1036
+ self,
1037
+ df,
1038
+ size: Optional[Size] = None,
1039
+ key: Optional[SeedOrKey] = None,
1040
+ dtype: DTypeLike = None
1041
+ ):
925
1042
  dtype = dtype or environ.dftype()
926
- df = jnp.asarray(_check_py_seq(df), dtype=dtype)
1043
+ df = u.math.asarray(_check_py_seq(df), dtype=dtype)
927
1044
  if size is None:
928
1045
  size = np.shape(df)
929
1046
  else:
@@ -940,11 +1057,13 @@ class RandomState(State):
940
1057
  r = n * jnp.sqrt(half_df / g)
941
1058
  return r
942
1059
 
943
- def orthogonal(self,
944
- n: int,
945
- size: Optional[Size] = None,
946
- key: Optional[SeedOrKey] = None,
947
- dtype: DTypeLike = None):
1060
+ def orthogonal(
1061
+ self,
1062
+ n: int,
1063
+ size: Optional[Size] = None,
1064
+ key: Optional[SeedOrKey] = None,
1065
+ dtype: DTypeLike = None
1066
+ ):
948
1067
  dtype = dtype or environ.dftype()
949
1068
  key = self.split_key() if key is None else _formalize_key(key)
950
1069
  size = _size2shape(size)
@@ -956,17 +1075,19 @@ class RandomState(State):
956
1075
  r = q * jnp.expand_dims(d / abs(d), -2)
957
1076
  return r
958
1077
 
959
- def noncentral_chisquare(self,
960
- df,
961
- nonc,
962
- size: Optional[Size] = None,
963
- key: Optional[SeedOrKey] = None,
964
- dtype: DTypeLike = None):
1078
+ def noncentral_chisquare(
1079
+ self,
1080
+ df,
1081
+ nonc,
1082
+ size: Optional[Size] = None,
1083
+ key: Optional[SeedOrKey] = None,
1084
+ dtype: DTypeLike = None
1085
+ ):
965
1086
  dtype = dtype or environ.dftype()
966
- df = jnp.asarray(_check_py_seq(df), dtype=dtype)
967
- nonc = jnp.asarray(_check_py_seq(nonc), dtype=dtype)
1087
+ df = u.math.asarray(_check_py_seq(df), dtype=dtype)
1088
+ nonc = u.math.asarray(_check_py_seq(nonc), dtype=dtype)
968
1089
  if size is None:
969
- size = lax.broadcast_shapes(jnp.shape(df), jnp.shape(nonc))
1090
+ size = lax.broadcast_shapes(u.math.shape(df), u.math.shape(nonc))
970
1091
  size = _size2shape(size)
971
1092
  if key is None:
972
1093
  keys = self.split_key(3)
@@ -980,54 +1101,62 @@ class RandomState(State):
980
1101
  r = jnp.where(cond, chi2 + n * n, chi2)
981
1102
  return r
982
1103
 
983
- def loggamma(self,
984
- a,
985
- size: Optional[Size] = None,
986
- key: Optional[SeedOrKey] = None,
987
- dtype: DTypeLike = None):
1104
+ def loggamma(
1105
+ self,
1106
+ a,
1107
+ size: Optional[Size] = None,
1108
+ key: Optional[SeedOrKey] = None,
1109
+ dtype: DTypeLike = None
1110
+ ):
988
1111
  dtype = dtype or environ.dftype()
989
1112
  key = self.split_key() if key is None else _formalize_key(key)
990
1113
  a = _check_py_seq(a)
991
1114
  if size is None:
992
- size = jnp.shape(a)
1115
+ size = u.math.shape(a)
993
1116
  r = jr.loggamma(key, a, shape=_size2shape(size), dtype=dtype)
994
1117
  return r
995
1118
 
996
- def categorical(self,
997
- logits,
998
- axis: int = -1,
999
- size: Optional[Size] = None,
1000
- key: Optional[SeedOrKey] = None):
1119
+ def categorical(
1120
+ self,
1121
+ logits,
1122
+ axis: int = -1,
1123
+ size: Optional[Size] = None,
1124
+ key: Optional[SeedOrKey] = None
1125
+ ):
1001
1126
  key = self.split_key() if key is None else _formalize_key(key)
1002
1127
  logits = _check_py_seq(logits)
1003
1128
  if size is None:
1004
- size = list(jnp.shape(logits))
1129
+ size = list(u.math.shape(logits))
1005
1130
  size.pop(axis)
1006
1131
  r = jr.categorical(key, logits, axis=axis, shape=_size2shape(size))
1007
1132
  return r
1008
1133
 
1009
- def zipf(self,
1010
- a,
1011
- size: Optional[Size] = None,
1012
- key: Optional[SeedOrKey] = None,
1013
- dtype: DTypeLike = None):
1134
+ def zipf(
1135
+ self,
1136
+ a,
1137
+ size: Optional[Size] = None,
1138
+ key: Optional[SeedOrKey] = None,
1139
+ dtype: DTypeLike = None
1140
+ ):
1014
1141
  a = _check_py_seq(a)
1015
1142
  if size is None:
1016
- size = jnp.shape(a)
1143
+ size = u.math.shape(a)
1017
1144
  dtype = dtype or environ.ditype()
1018
1145
  r = jax.pure_callback(lambda x: np.random.zipf(x, size).astype(dtype),
1019
1146
  jax.ShapeDtypeStruct(size, dtype),
1020
1147
  a)
1021
1148
  return r
1022
1149
 
1023
- def power(self,
1024
- a,
1025
- size: Optional[Size] = None,
1026
- key: Optional[SeedOrKey] = None,
1027
- dtype: DTypeLike = None):
1150
+ def power(
1151
+ self,
1152
+ a,
1153
+ size: Optional[Size] = None,
1154
+ key: Optional[SeedOrKey] = None,
1155
+ dtype: DTypeLike = None
1156
+ ):
1028
1157
  a = _check_py_seq(a)
1029
1158
  if size is None:
1030
- size = jnp.shape(a)
1159
+ size = u.math.shape(a)
1031
1160
  size = _size2shape(size)
1032
1161
  dtype = dtype or environ.dftype()
1033
1162
  r = jax.pure_callback(lambda a: np.random.power(a=a, size=size).astype(dtype),
@@ -1035,24 +1164,28 @@ class RandomState(State):
1035
1164
  a)
1036
1165
  return r
1037
1166
 
1038
- def f(self,
1039
- dfnum,
1040
- dfden,
1041
- size: Optional[Size] = None,
1042
- key: Optional[SeedOrKey] = None,
1043
- dtype: DTypeLike = None):
1167
+ def f(
1168
+ self,
1169
+ dfnum,
1170
+ dfden,
1171
+ size: Optional[Size] = None,
1172
+ key: Optional[SeedOrKey] = None,
1173
+ dtype: DTypeLike = None
1174
+ ):
1044
1175
  dfnum = _check_py_seq(dfnum)
1045
1176
  dfden = _check_py_seq(dfden)
1046
1177
  if size is None:
1047
- size = jnp.broadcast_shapes(jnp.shape(dfnum), jnp.shape(dfden))
1178
+ size = jnp.broadcast_shapes(u.math.shape(dfnum), u.math.shape(dfden))
1048
1179
  size = _size2shape(size)
1049
1180
  d = {'dfnum': dfnum, 'dfden': dfden}
1050
1181
  dtype = dtype or environ.dftype()
1051
- r = jax.pure_callback(lambda dfnum_, dfden_: np.random.f(dfnum=dfnum_,
1052
- dfden=dfden_,
1053
- size=size).astype(dtype),
1054
- jax.ShapeDtypeStruct(size, dtype),
1055
- dfnum, dfden)
1182
+ r = jax.pure_callback(
1183
+ lambda dfnum_, dfden_: np.random.f(dfnum=dfnum_,
1184
+ dfden=dfden_,
1185
+ size=size).astype(dtype),
1186
+ jax.ShapeDtypeStruct(size, dtype),
1187
+ dfnum, dfden
1188
+ )
1056
1189
  return r
1057
1190
 
1058
1191
  def hypergeometric(
@@ -1069,64 +1202,82 @@ class RandomState(State):
1069
1202
  nsample = _check_py_seq(nsample)
1070
1203
 
1071
1204
  if size is None:
1072
- size = lax.broadcast_shapes(jnp.shape(ngood),
1073
- jnp.shape(nbad),
1074
- jnp.shape(nsample))
1205
+ size = lax.broadcast_shapes(u.math.shape(ngood),
1206
+ u.math.shape(nbad),
1207
+ u.math.shape(nsample))
1075
1208
  size = _size2shape(size)
1076
1209
  dtype = dtype or environ.ditype()
1077
1210
  d = {'ngood': ngood, 'nbad': nbad, 'nsample': nsample}
1078
- r = jax.pure_callback(lambda d: np.random.hypergeometric(ngood=d['ngood'],
1079
- nbad=d['nbad'],
1080
- nsample=d['nsample'],
1081
- size=size).astype(dtype),
1082
- jax.ShapeDtypeStruct(size, dtype),
1083
- d)
1211
+ r = jax.pure_callback(
1212
+ lambda d: np.random.hypergeometric(
1213
+ ngood=d['ngood'],
1214
+ nbad=d['nbad'],
1215
+ nsample=d['nsample'],
1216
+ size=size
1217
+ ).astype(dtype),
1218
+ jax.ShapeDtypeStruct(size, dtype),
1219
+ d
1220
+ )
1084
1221
  return r
1085
1222
 
1086
- def logseries(self,
1087
- p,
1088
- size: Optional[Size] = None,
1089
- key: Optional[SeedOrKey] = None,
1090
- dtype: DTypeLike = None):
1223
+ def logseries(
1224
+ self,
1225
+ p,
1226
+ size: Optional[Size] = None,
1227
+ key: Optional[SeedOrKey] = None,
1228
+ dtype: DTypeLike = None
1229
+ ):
1091
1230
  p = _check_py_seq(p)
1092
1231
  if size is None:
1093
- size = jnp.shape(p)
1232
+ size = u.math.shape(p)
1094
1233
  size = _size2shape(size)
1095
1234
  dtype = dtype or environ.ditype()
1096
- r = jax.pure_callback(lambda p: np.random.logseries(p=p, size=size).astype(dtype),
1097
- jax.ShapeDtypeStruct(size, dtype),
1098
- p)
1235
+ r = jax.pure_callback(
1236
+ lambda p: np.random.logseries(p=p, size=size).astype(dtype),
1237
+ jax.ShapeDtypeStruct(size, dtype),
1238
+ p
1239
+ )
1099
1240
  return r
1100
1241
 
1101
- def noncentral_f(self,
1102
- dfnum,
1103
- dfden,
1104
- nonc,
1105
- size: Optional[Size] = None,
1106
- key: Optional[SeedOrKey] = None,
1107
- dtype: DTypeLike = None):
1242
+ def noncentral_f(
1243
+ self,
1244
+ dfnum,
1245
+ dfden,
1246
+ nonc,
1247
+ size: Optional[Size] = None,
1248
+ key: Optional[SeedOrKey] = None,
1249
+ dtype: DTypeLike = None
1250
+ ):
1108
1251
  dfnum = _check_py_seq(dfnum)
1109
1252
  dfden = _check_py_seq(dfden)
1110
1253
  nonc = _check_py_seq(nonc)
1111
1254
  if size is None:
1112
- size = lax.broadcast_shapes(jnp.shape(dfnum),
1113
- jnp.shape(dfden),
1114
- jnp.shape(nonc))
1255
+ size = lax.broadcast_shapes(u.math.shape(dfnum),
1256
+ u.math.shape(dfden),
1257
+ u.math.shape(nonc))
1115
1258
  size = _size2shape(size)
1116
1259
  d = {'dfnum': dfnum, 'dfden': dfden, 'nonc': nonc}
1117
1260
  dtype = dtype or environ.dftype()
1118
- r = jax.pure_callback(lambda x: np.random.noncentral_f(dfnum=x['dfnum'],
1119
- dfden=x['dfden'],
1120
- nonc=x['nonc'],
1121
- size=size).astype(dtype),
1122
- jax.ShapeDtypeStruct(size, dtype),
1123
- d)
1261
+ r = jax.pure_callback(
1262
+ lambda x: np.random.noncentral_f(dfnum=x['dfnum'],
1263
+ dfden=x['dfden'],
1264
+ nonc=x['nonc'],
1265
+ size=size).astype(dtype),
1266
+ jax.ShapeDtypeStruct(size, dtype),
1267
+ d
1268
+ )
1124
1269
  return r

  # PyTorch compatibility #
  # --------------------- #

- def rand_like(self, input, *, dtype=None, key: Optional[SeedOrKey] = None):
+ def rand_like(
+ self,
+ input,
+ *,
+ dtype=None,
+ key: Optional[SeedOrKey] = None
+ ):
  """Returns a tensor with the same size as input that is filled with random
  numbers from a uniform distribution on the interval ``[0, 1)``.

@@ -1138,9 +1289,15 @@ class RandomState(State):
  Returns:
  The random data.
  """
- return self.random(jnp.shape(input), key=key).astype(dtype)
+ return self.random(u.math.shape(input), key=key).astype(dtype)

- def randn_like(self, input, *, dtype=None, key: Optional[SeedOrKey] = None):
+ def randn_like(
+ self,
+ input,
+ *,
+ dtype=None,
+ key: Optional[SeedOrKey] = None
+ ):
  """Returns a tensor with the same size as ``input`` that is filled with
  random numbers from a normal distribution with mean 0 and variance 1.

@@ -1152,12 +1309,20 @@ class RandomState(State):
  Returns:
  The random data.
  """
- return self.randn(*jnp.shape(input), key=key).astype(dtype)
+ return self.randn(*u.math.shape(input), key=key).astype(dtype)

- def randint_like(self, input, low=0, high=None, *, dtype=None, key: Optional[SeedOrKey] = None):
+ def randint_like(
+ self,
+ input,
+ low=0,
+ high=None,
+ *,
+ dtype=None,
+ key: Optional[SeedOrKey] = None
+ ):
  if high is None:
  high = max(input)
- return self.randint(low, high=high, size=jnp.shape(input), dtype=dtype, key=key)
+ return self.randint(low, high=high, size=u.math.shape(input), dtype=dtype, key=key)
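The three `*_like` helpers mirror `torch.rand_like`, `torch.randn_like`, and `torch.randint_like`: they read only the shape of `input` and draw fresh values. A sketch, again assuming a seed-constructed `RandomState`:

import jax.numpy as jnp
import brainstate

rs = brainstate.random.RandomState(42)       # assumed constructor
x = jnp.zeros((2, 3))
a = rs.rand_like(x)                          # uniform on [0, 1), shape (2, 3)
b = rs.randn_like(x)                         # standard normal, shape (2, 3)
c = rs.randint_like(x, low=0, high=10)       # integers in [0, 10), shape (2, 3)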


  # default random generator
@@ -1180,7 +1345,7 @@ def _formalize_key(key):
  raise TypeError('key must be a int or an array with two uint32.')
  if key.size != 2:
  raise TypeError('key must be a int or an array with two uint32.')
- return jnp.asarray(key, dtype=jnp.uint32)
+ return u.math.asarray(key, dtype=jnp.uint32)
  else:
  raise TypeError('key must be a int or an array with two uint32.')
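`_formalize_key` accepts either a Python integer seed or an existing key array of two uint32 values; anything else raises the `TypeError` above. For reference, JAX's classic (non-typed) keys already have that layout under the default threefry implementation:

import jax.numpy as jnp
import jax.random as jr

k = jr.PRNGKey(42)                  # classic key: shape (2,), dtype uint32
assert k.shape == (2,) and k.dtype == jnp.uint32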

@@ -1194,7 +1359,11 @@ def _size2shape(size):
  return (size,)


- def _check_shape(name, shape, *param_shapes):
+ def _check_shape(
+ name,
+ shape,
+ *param_shapes
+ ):
  if param_shapes:
  shape_ = lax.broadcast_shapes(shape, *param_shapes)
  if shape != shape_:
@@ -1223,7 +1392,11 @@ python_scalar_dtypes = {
  }


- def _dtype(x, *, canonicalize: bool = False):
+ def _dtype(
+ x,
+ *,
+ canonicalize: bool = False
+ ):
  """Return the dtype object for a value or type, optionally canonicalized based on X64 mode."""
  if x is None:
  raise ValueError(f"Invalid argument to dtype: {x}.")
@@ -1238,7 +1411,10 @@ def _dtype(x, *, canonicalize: bool = False):
  return dtypes.canonicalize_dtype(dt) if canonicalize else dt


- def _const(example, val):
+ def _const(
+ example,
+ val
+ ):
  if _is_python_scalar(example):
  dtype = dtypes.canonicalize_dtype(type(example))
  val = dtypes.scalar_type_of(example)(val)
@@ -1249,7 +1425,11 @@ def _const(example, val):


  @partial(jit, static_argnums=(2,))
- def _categorical(key, p, shape):
+ def _categorical(
+ key,
+ p,
+ shape
+ ):
  # this implementation is fast when event shape is small, and slow otherwise
  # Ref: https://stackoverflow.com/a/34190035
  shape = shape or p.shape[:-1]
@@ -1258,7 +1438,11 @@ def _categorical(key, p, shape):
  return jnp.sum(s < r, axis=-1)
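The Stack Overflow reference in the comment describes the construction used here: compare one uniform draw per sample against the cumulative probabilities and count how many bins fall below it. A minimal sketch of that trick:

import jax.numpy as jnp
import jax.random as jr

p = jnp.array([0.2, 0.5, 0.3])
r = jr.uniform(jr.PRNGKey(0), (1000, 1))   # one uniform per sample
s = jnp.cumsum(p)                          # [0.2, 0.7, 1.0]
draws = jnp.sum(s < r, axis=-1)            # count of bins below r == sampled index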


- def _scatter_add_one(operand, indices, updates):
+ def _scatter_add_one(
+ operand,
+ indices,
+ updates
+ ):
  return lax.scatter_add(
  operand,
  indices,
@@ -1278,12 +1462,15 @@ def _reshape(x, shape):
  return jnp.reshape(x, shape)


- def _promote_shapes(*args, shape=()):
+ def _promote_shapes(
+ *args,
+ shape=()
+ ):
  # adapted from lax.lax_numpy
  if len(args) < 2 and not shape:
  return args
  else:
- shapes = [jnp.shape(arg) for arg in args]
+ shapes = [u.math.shape(arg) for arg in args]
  num_dims = len(lax.broadcast_shapes(shape, *shapes))
  return [
  _reshape(arg, (1,) * (num_dims - len(s)) + s) if len(s) < num_dims else arg
@@ -1292,11 +1479,17 @@ def _promote_shapes(*args, shape=()):


  @partial(jit, static_argnums=(3, 4))
- def _multinomial(key, p, n, n_max, shape=()):
- if jnp.shape(n) != jnp.shape(p)[:-1]:
- broadcast_shape = lax.broadcast_shapes(jnp.shape(n), jnp.shape(p)[:-1])
+ def _multinomial(
+ key,
+ p,
+ n,
+ n_max,
+ shape=()
+ ):
+ if u.math.shape(n) != u.math.shape(p)[:-1]:
+ broadcast_shape = lax.broadcast_shapes(u.math.shape(n), u.math.shape(p)[:-1])
  n = jnp.broadcast_to(n, broadcast_shape)
- p = jnp.broadcast_to(p, broadcast_shape + jnp.shape(p)[-1:])
+ p = jnp.broadcast_to(p, broadcast_shape + u.math.shape(p)[-1:])
  shape = shape or p.shape[:-1]
  if n_max == 0:
  return jnp.zeros(shape + p.shape[-1:], dtype=jnp.result_type(int))
@@ -1307,21 +1500,28 @@ def _multinomial(key, p, n, n_max, shape=()):
  mask = _promote_shapes(jnp.arange(n_max) < jnp.expand_dims(n, -1), shape=shape + (n_max,))[0]
  mask = jnp.moveaxis(mask, -1, 0).astype(indices.dtype)
  excess = jnp.concatenate([jnp.expand_dims(n_max - n, -1),
- jnp.zeros(jnp.shape(n) + (p.shape[-1] - 1,))],
+ jnp.zeros(u.math.shape(n) + (p.shape[-1] - 1,))],
  -1)
  else:
  mask = 1
  excess = 0
  # NB: we transpose to move batch shape to the front
  indices_2D = (jnp.reshape(indices * mask, (n_max, -1))).T
- samples_2D = vmap(_scatter_add_one)(jnp.zeros((indices_2D.shape[0], p.shape[-1]), dtype=indices.dtype),
- jnp.expand_dims(indices_2D, axis=-1),
- jnp.ones(indices_2D.shape, dtype=indices.dtype))
+ samples_2D = vmap(_scatter_add_one)(
+ jnp.zeros((indices_2D.shape[0], p.shape[-1]), dtype=indices.dtype),
+ jnp.expand_dims(indices_2D, axis=-1),
+ jnp.ones(indices_2D.shape, dtype=indices.dtype)
+ )
  return jnp.reshape(samples_2D, shape + p.shape[-1:]) - excess
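The `vmap(_scatter_add_one)` call above is effectively a per-row bincount of the `n_max` categorical draws (masked draws land in category 0 and are subtracted back out via `excess`). The same accumulation in plain JAX, for illustration only:

import jax.numpy as jnp

idx = jnp.array([0, 2, 2, 1])                          # draws for a single row
counts = jnp.zeros(3, dtype=idx.dtype).at[idx].add(1)  # -> [1, 1, 2]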


  @partial(jit, static_argnums=(2, 3), static_argnames=['shape', 'dtype'])
- def _von_mises_centered(key, concentration, shape, dtype=None):
+ def _von_mises_centered(
+ key,
+ concentration,
+ shape,
+ dtype=None
+ ):
  """Compute centered von Mises samples using rejection sampling from [1]_ with wrapped Cauchy proposal.

  Returns
@@ -1335,7 +1535,7 @@ def _von_mises_centered(key, concentration, shape, dtype=None):
  Chapter 9, p. 473-476. http://www.nrbook.com/devroye/Devroye_files/chapter_nine.pdf

  """
- shape = shape or jnp.shape(concentration)
+ shape = shape or u.math.shape(concentration)
  dtype = dtype or environ.dftype()
  concentration = lax.convert_element_type(concentration, dtype)
  concentration = jnp.broadcast_to(concentration, shape)
@@ -1357,42 +1557,50 @@ def _von_mises_centered(key, concentration, shape, dtype=None):

  s = jnp.where(concentration > s_cutoff, s_exact, s_approximate)

- def cond_fn(*args):
+ def cond_fn(
+ *args
+ ):
  """check if all are done or reached max number of iterations"""
  i, _, done, _, _ = args[0]
  return jnp.bitwise_and(i < 100, jnp.logical_not(jnp.all(done)))

- def body_fn(*args):
+ def body_fn(
+ *args
+ ):
  i, key, done, _, w = args[0]
  uni_ukey, uni_vkey, key = jr.split(key, 3)
- u = jr.uniform(
+ u_ = jr.uniform(
  key=uni_ukey,
  shape=shape,
  dtype=concentration.dtype,
  minval=-1.0,
  maxval=1.0,
  )
- z = jnp.cos(jnp.pi * u)
+ z = jnp.cos(jnp.pi * u_)
  w = jnp.where(done, w, (1.0 + s * z) / (s + z)) # Update where not done
  y = concentration * (s - w)
  v = jr.uniform(key=uni_vkey, shape=shape, dtype=concentration.dtype)
  accept = (y * (2.0 - y) >= v) | (jnp.log(y / v) + 1.0 >= y)
- return i + 1, key, accept | done, u, w
+ return i + 1, key, accept | done, u_, w

  init_done = jnp.zeros(shape, dtype=bool)
  init_u = jnp.zeros(shape)
  init_w = jnp.zeros(shape)

- _, _, done, u, w = lax.while_loop(
+ _, _, done, uu, w = lax.while_loop(
  cond_fun=cond_fn,
  body_fun=body_fn,
  init_val=(jnp.array(0), key, init_done, init_u, init_w),
  )

- return jnp.sign(u) * jnp.arccos(w)
+ return jnp.sign(uu) * jnp.arccos(w)
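Throughout this function the local variable previously named `u` becomes `u_` (and `uu` after the while loop), presumably so it no longer shadows the unit-library alias `u` whose `u.math.*` helpers this file now calls. A small illustration of the hazard, assuming the module-level alias is `import brainunit as u`:

import brainunit as u            # assumed alias, matching the u.math.* calls above
import jax.random as jr

def shadowed(key, shape):
    u = jr.uniform(key, shape)   # rebinding `u` locally hides the module...
    return u.math.shape(u)       # ...so this call fails: arrays have no `.math`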


- def _loc_scale(loc, scale, value):
+ def _loc_scale(
+ loc,
+ scale,
+ value
+ ):
  if loc is None:
  if scale is None:
  return value
@@ -1406,4 +1614,4 @@ def _loc_scale(loc, scale, value):


  def _check_py_seq(seq):
- return jnp.asarray(seq) if isinstance(seq, (tuple, list)) else seq
+ return u.math.asarray(seq) if isinstance(seq, (tuple, list)) else seq