brainstate-0.1.9-py2.py3-none-any.whl → brainstate-0.2.0-py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +130 -19
- brainstate/_compatible_import.py +201 -9
- brainstate/_compatible_import_test.py +681 -0
- brainstate/_deprecation.py +210 -0
- brainstate/_deprecation_test.py +2319 -0
- brainstate/{util/error.py → _error.py} +10 -20
- brainstate/_state.py +94 -47
- brainstate/_state_test.py +1 -1
- brainstate/_utils.py +1 -1
- brainstate/environ.py +1279 -347
- brainstate/environ_test.py +1187 -26
- brainstate/graph/__init__.py +6 -13
- brainstate/graph/_node.py +240 -0
- brainstate/graph/_node_test.py +589 -0
- brainstate/graph/{_graph_operation.py → _operation.py} +632 -746
- brainstate/graph/_operation_test.py +1147 -0
- brainstate/mixin.py +1209 -141
- brainstate/mixin_test.py +991 -51
- brainstate/nn/__init__.py +74 -72
- brainstate/nn/_activations.py +587 -295
- brainstate/nn/_activations_test.py +109 -86
- brainstate/nn/_collective_ops.py +393 -274
- brainstate/nn/_collective_ops_test.py +746 -15
- brainstate/nn/_common.py +114 -66
- brainstate/nn/_common_test.py +154 -0
- brainstate/nn/_conv.py +1652 -143
- brainstate/nn/_conv_test.py +838 -227
- brainstate/nn/_delay.py +95 -29
- brainstate/nn/_delay_test.py +25 -20
- brainstate/nn/_dropout.py +359 -167
- brainstate/nn/_dropout_test.py +429 -52
- brainstate/nn/_dynamics.py +14 -90
- brainstate/nn/_dynamics_test.py +1 -12
- brainstate/nn/_elementwise.py +492 -313
- brainstate/nn/_elementwise_test.py +806 -145
- brainstate/nn/_embedding.py +369 -19
- brainstate/nn/_embedding_test.py +156 -0
- brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +10 -16
- brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +6 -5
- brainstate/nn/{_linear_mv.py → _event_linear.py} +2 -2
- brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +6 -5
- brainstate/nn/_exp_euler.py +200 -38
- brainstate/nn/_exp_euler_test.py +350 -8
- brainstate/nn/_linear.py +391 -71
- brainstate/nn/_linear_test.py +427 -59
- brainstate/nn/_metrics.py +1070 -0
- brainstate/nn/_metrics_test.py +611 -0
- brainstate/nn/_module.py +10 -3
- brainstate/nn/_module_test.py +1 -1
- brainstate/nn/_normalizations.py +688 -329
- brainstate/nn/_normalizations_test.py +663 -37
- brainstate/nn/_paddings.py +1020 -0
- brainstate/nn/_paddings_test.py +723 -0
- brainstate/nn/_poolings.py +1404 -342
- brainstate/nn/_poolings_test.py +828 -92
- brainstate/nn/{_rate_rnns.py → _rnns.py} +446 -54
- brainstate/nn/_rnns_test.py +593 -0
- brainstate/nn/_utils.py +132 -5
- brainstate/nn/_utils_test.py +402 -0
- brainstate/{init/_random_inits.py → nn/init.py} +301 -45
- brainstate/{init/_random_inits_test.py → nn/init_test.py} +51 -20
- brainstate/random/__init__.py +247 -1
- brainstate/random/_rand_funs.py +668 -346
- brainstate/random/_rand_funs_test.py +74 -1
- brainstate/random/_rand_seed.py +541 -76
- brainstate/random/_rand_seed_test.py +1 -1
- brainstate/random/_rand_state.py +601 -393
- brainstate/random/_rand_state_test.py +551 -0
- brainstate/transform/__init__.py +59 -0
- brainstate/transform/_ad_checkpoint.py +176 -0
- brainstate/{compile → transform}/_ad_checkpoint_test.py +1 -1
- brainstate/{augment → transform}/_autograd.py +360 -113
- brainstate/{augment → transform}/_autograd_test.py +2 -2
- brainstate/transform/_conditions.py +316 -0
- brainstate/{compile → transform}/_conditions_test.py +11 -11
- brainstate/{compile → transform}/_error_if.py +22 -20
- brainstate/{compile → transform}/_error_if_test.py +1 -1
- brainstate/transform/_eval_shape.py +145 -0
- brainstate/{augment → transform}/_eval_shape_test.py +1 -1
- brainstate/{compile → transform}/_jit.py +99 -46
- brainstate/{compile → transform}/_jit_test.py +3 -3
- brainstate/{compile → transform}/_loop_collect_return.py +219 -80
- brainstate/{compile → transform}/_loop_collect_return_test.py +1 -1
- brainstate/{compile → transform}/_loop_no_collection.py +133 -34
- brainstate/{compile → transform}/_loop_no_collection_test.py +2 -2
- brainstate/transform/_make_jaxpr.py +2016 -0
- brainstate/transform/_make_jaxpr_test.py +1510 -0
- brainstate/transform/_mapping.py +529 -0
- brainstate/transform/_mapping_test.py +194 -0
- brainstate/{compile → transform}/_progress_bar.py +78 -25
- brainstate/{augment → transform}/_random.py +65 -45
- brainstate/{compile → transform}/_unvmap.py +102 -5
- brainstate/transform/_util.py +286 -0
- brainstate/typing.py +594 -61
- brainstate/typing_test.py +780 -0
- brainstate/util/__init__.py +9 -32
- brainstate/util/_others.py +1025 -0
- brainstate/util/_others_test.py +962 -0
- brainstate/util/_pretty_pytree.py +1301 -0
- brainstate/util/_pretty_pytree_test.py +675 -0
- brainstate/util/{pretty_repr.py → _pretty_repr.py} +161 -27
- brainstate/util/_pretty_repr_test.py +696 -0
- brainstate/util/filter.py +557 -81
- brainstate/util/filter_test.py +912 -0
- brainstate/util/struct.py +769 -382
- brainstate/util/struct_test.py +602 -0
- {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/METADATA +34 -17
- brainstate-0.2.0.dist-info/RECORD +111 -0
- brainstate/augment/__init__.py +0 -30
- brainstate/augment/_eval_shape.py +0 -99
- brainstate/augment/_mapping.py +0 -1060
- brainstate/augment/_mapping_test.py +0 -597
- brainstate/compile/__init__.py +0 -38
- brainstate/compile/_ad_checkpoint.py +0 -204
- brainstate/compile/_conditions.py +0 -256
- brainstate/compile/_make_jaxpr.py +0 -888
- brainstate/compile/_make_jaxpr_test.py +0 -156
- brainstate/compile/_util.py +0 -147
- brainstate/functional/__init__.py +0 -27
- brainstate/graph/_graph_node.py +0 -244
- brainstate/graph/_graph_node_test.py +0 -73
- brainstate/graph/_graph_operation_test.py +0 -563
- brainstate/init/__init__.py +0 -26
- brainstate/init/_base.py +0 -52
- brainstate/init/_generic.py +0 -244
- brainstate/init/_regular_inits.py +0 -105
- brainstate/init/_regular_inits_test.py +0 -50
- brainstate/nn/_inputs.py +0 -608
- brainstate/nn/_ltp.py +0 -28
- brainstate/nn/_neuron.py +0 -705
- brainstate/nn/_neuron_test.py +0 -161
- brainstate/nn/_others.py +0 -46
- brainstate/nn/_projection.py +0 -486
- brainstate/nn/_rate_rnns_test.py +0 -63
- brainstate/nn/_readout.py +0 -209
- brainstate/nn/_readout_test.py +0 -53
- brainstate/nn/_stp.py +0 -236
- brainstate/nn/_synapse.py +0 -505
- brainstate/nn/_synapse_test.py +0 -131
- brainstate/nn/_synaptic_projection.py +0 -423
- brainstate/nn/_synouts.py +0 -162
- brainstate/nn/_synouts_test.py +0 -57
- brainstate/nn/metrics.py +0 -388
- brainstate/optim/__init__.py +0 -38
- brainstate/optim/_base.py +0 -64
- brainstate/optim/_lr_scheduler.py +0 -448
- brainstate/optim/_lr_scheduler_test.py +0 -50
- brainstate/optim/_optax_optimizer.py +0 -152
- brainstate/optim/_optax_optimizer_test.py +0 -53
- brainstate/optim/_sgd_optimizer.py +0 -1104
- brainstate/random/_random_for_unit.py +0 -52
- brainstate/surrogate.py +0 -1957
- brainstate/transform.py +0 -23
- brainstate/util/caller.py +0 -98
- brainstate/util/others.py +0 -540
- brainstate/util/pretty_pytree.py +0 -945
- brainstate/util/pretty_pytree_test.py +0 -159
- brainstate/util/pretty_table.py +0 -2954
- brainstate/util/scaling.py +0 -258
- brainstate-0.1.9.dist-info/RECORD +0 -130
- {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/WHEEL +0 -0
- {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/top_level.txt +0 -0
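The path renames above concentrate the old brainstate.compile and brainstate.augment transformations under a single brainstate.transform package, fold the initializers into brainstate.nn.init, and drop the functional, init, optim and surrogate modules from the wheel. A hedged migration sketch of what this suggests for downstream imports follows; only the file moves are visible here, so whether the old public names are re-exported unchanged (the new _deprecation.py hints at shims) is an assumption, not something this diff confirms:

    # Hedged sketch inferred from the renamed files above, not from 0.2.0 API docs.
    # 0.1.9 layout                              0.2.0 layout (per the moved files)
    #   brainstate/compile/_jit.py          ->  brainstate/transform/_jit.py
    #   brainstate/augment/_autograd.py     ->  brainstate/transform/_autograd.py
    #   brainstate/augment/_mapping.py      ->  brainstate/transform/_mapping.py
    #   brainstate/init/_random_inits.py    ->  brainstate/nn/init.py
    import brainstate

    # Code that used brainstate.compile.jit / brainstate.augment.grad in 0.1.9 would
    # therefore be expected to import from the transform namespace instead
    # (assumption: the public names survive the move unchanged).
    step = brainstate.transform.jit(lambda x: x + 1.0)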
brainstate/random/_rand_state.py
CHANGED
@@ -1,4 +1,4 @@
-# Copyright 2024
+# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -29,11 +29,12 @@ from jax import lax, core, dtypes

 from brainstate import environ
 from brainstate._state import State
-from brainstate.compile._error_if import jit_error_if
 from brainstate.typing import DTypeLike, Size, SeedOrKey
-from ._random_for_unit import uniform_for_unit, permutation_for_unit

-__all__ = [
+__all__ = [
+    'RandomState',
+    'DEFAULT',
+]

 use_prng_key = True
@@ -43,7 +44,10 @@ … @@ -178,10 +197,15 @@ (RandomState.__init__ through RandomState.rand)
- The one-line signatures of __init__, __repr__, check_if_deleted, seed, split_key, self_assign_multi_keys and rand are rewritten onto multiple lines, one parameter per line, e.g.
    def split_key(
        self,
        n: Optional[int] = None,
        backup: bool = False
    ) -> SeedOrKey:
- split_key now coerces a non-jax.Array key before splitting: self.value = u.math.asarray(self.value, dtype=jnp.uint32).
- rand samples directly via r = jr.uniform(key, dn, dtype).
@@ -198,8 +222,8 @@ … @@ -222,7 +246,7 @@ (RandomState.randint and RandomState.random_integers)
- When size is None, the default shape is now computed as size = lax.broadcast_shapes(u.math.shape(low), u.math.shape(high)).
@@ -232,112 +256,137 @@ (randn through gamma)
- randn, random, random_sample, ranf, sample, choice, permutation, shuffle, beta, exponential and gamma get the same multi-line signatures; random now samples with r = jr.uniform(key, _size2shape(size), dtype).
- choice and permutation become unit-aware: the input is split with a, unit = u.split_mantissa_unit(a) (x for permutation), sampling runs on the mantissa via jr.choice / jr.permutation(key, x, axis, independent=independent), and the result is returned as u.maybe_decimal(r * unit).
- shuffle is reduced to a delegation: return self.permutation(x, axis=axis, key=key, independent=False).
- beta and gamma broadcast their default size from u.math.shape(a) / u.math.shape(b) (shape/scale for gamma); exponential uses size = u.math.shape(scale) and converts scale with u.math.asarray(scale, dtype=dtype) only when it is not None.
@@ -345,166 +394,196 @@ (gumbel through __norm_cdf)
- gumbel, laplace, logistic, normal, pareto, poisson, standard_cauchy, standard_exponential, standard_gamma, standard_normal, standard_t, uniform and the private __norm_cdf get multi-line signatures.
- Default sizes now come from u.math.shape(...): loc/scale pairs go through lax.broadcast_shapes(u.math.shape(loc), u.math.shape(scale)) (with () substituted when an argument is None), and single-parameter samplers such as pareto and poisson use u.math.shape(a) / u.math.shape(lam).
- uniform becomes unit-aware:
    low, unit = u.split_mantissa_unit(_check_py_seq(low))
    high = u.Quantity(_check_py_seq(high)).to(unit).mantissa
    if size is None:
        size = lax.broadcast_shapes(u.math.shape(low), u.math.shape(high))
    key = self.split_key() if key is None else _formalize_key(key)
    dtype = dtype or environ.dftype()
    r = jr.uniform(key, _size2shape(size), dtype=dtype, minval=low, maxval=high)
    return u.maybe_decimal(r * unit)
@@ -513,10 +592,11 @@ … @@ -577,19 +659,24 @@ (truncated_normal and bernoulli)
- truncated_normal changes its defaults to loc=0.0 and scale=1.0, gains a check_valid: bool = True parameter, and becomes unit-aware: lower is split with u.split_mantissa_unit, upper/loc/scale are converted into the same unit via u.Quantity(...).in_unit(unit).mantissa, the default size broadcasts u.math.shape of lower/upper/loc/scale, the inner draw is written out as out = jr.uniform(key, size, dtype, ...) with the lax.nextafter bounds, and the result is returned as u.maybe_decimal(out * unit).
- The sanity check is now optional and imports its checker lazily:
    if check_valid:
        from brainstate.transform._error_if import jit_error_if
        jit_error_if(
            u.math.any(u.math.logical_or(loc < lower - 2 * scale, loc > upper + 2 * scale)),
            "mean is more than 2 std from [lower, upper] in truncated_normal. "
            "The distribution of values may be incorrect."
        )
- bernoulli gains the same check_valid flag (validating p within [0, 1] through jit_error_if) and defaults size to u.math.shape(p).
@@ -606,21 +693,21 @@ … @@ -739,9 +838,9 @@ (lognormal through multivariate_normal)
- lognormal keeps the unit of mean (unit = mean.unit if isinstance(mean, u.Quantity) else u.UNITLESS), broadcasts the default size from u.math.shape(mean) and u.math.shape(sigma), and returns u.maybe_decimal(samples * unit).
- binomial and multinomial import jit_error_if lazily from brainstate.transform._error_if inside their check_valid branches (multinomial gains an explicit check_valid: bool = True parameter), broadcast the default size from u.math.shape(n) / u.math.shape(p) (u.math.shape(pvals)[:-1] for multinomial), and return u.math.asarray(r, dtype=dtype).
- chisquare, dirichlet and geometric get multi-line signatures; geometric defaults size to u.math.shape(p) and draws u_ = jr.uniform(key, size, dtype).
- multivariate_normal checks the covariance with u.math.shape(cov)[-2:] != (n, n) and reports u.math.shape(cov) in its error message.
@@ -758,92 +857,104 @@ … @@ -1035,24 +1164,28 @@ (rayleigh through f)
- multivariate_normal now returns u.maybe_decimal(r * unit).
- rayleigh (scale=1.0), triangular, vonmises, weibull, weibull_min, maxwell, negative_binomial, wald, t, orthogonal, noncentral_chisquare, loggamma, categorical, zipf, power and f all get multi-line signatures; scalar parameters are normalised with u.math.asarray(_check_py_seq(...), dtype=dtype) (vonmises, wald, t, noncentral_chisquare), and default sizes come from u.math.shape(...) or broadcast_shapes over u.math.shape(...) of the parameters.
- The sampling expressions are written out, e.g. rayleigh's x = jnp.sqrt(-2. * jnp.log(jr.uniform(key, shape=_size2shape(size), dtype=dtype))), weibull's random_uniform = jr.uniform(key=key, shape=size, dtype=dtype), and f's NumPy fallback:
    r = jax.pure_callback(
        lambda dfnum_, dfden_: np.random.f(dfnum=dfnum_,
                                           dfden=dfden_,
                                           size=size).astype(dtype),
        jax.ShapeDtypeStruct(size, dtype),
        dfnum, dfden
    )
@@ -1069,64 +1202,82 @@ … @@ -1152,12 +1309,20 @@ (hypergeometric through randint_like)
- hypergeometric, logseries and noncentral_f get multi-line signatures, derive their default sizes from u.math.shape(...) of the parameters, and route their NumPy fallbacks through jax.pure_callback(..., jax.ShapeDtypeStruct(size, dtype), ...), e.g. lambda p: np.random.logseries(p=p, size=size).astype(dtype) for logseries.
- The PyTorch-style helpers take the target shape from the input: rand_like returns self.random(u.math.shape(input), key=key).astype(dtype), randn_like returns self.randn(*u.math.shape(input), key=key).astype(dtype), and randint_like returns self.randint(low, high=high, size=u.math.shape(input), dtype=dtype, key=key).
@@ -1180,7 +1345,7 @@ … @@ -1406,4 +1614,4 @@ (module-level helpers)
- _formalize_key returns u.math.asarray(key, dtype=jnp.uint32).
- _check_shape, _dtype, _const, _categorical, _scatter_add_one, _promote_shapes, _multinomial, _von_mises_centered and _loc_scale get multi-line signatures; _promote_shapes collects shapes = [u.math.shape(arg) for arg in args], and _multinomial broadcasts n and p with u.math.shape(...) before scattering counts through vmap(_scatter_add_one)(...).
- _von_mises_centered defaults shape to u.math.shape(concentration); its rejection loop names the uniform draw u_, computes z = jnp.cos(jnp.pi * u_), threads (i, key, done, u_, w) through lax.while_loop, and returns jnp.sign(uu) * jnp.arccos(w).
- _check_py_seq returns u.math.asarray(seq) if isinstance(seq, (tuple, list)) else seq.