brainstate 0.2.1__py2.py3-none-any.whl → 0.2.2__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +167 -169
- brainstate/_compatible_import.py +340 -340
- brainstate/_compatible_import_test.py +681 -681
- brainstate/_deprecation.py +210 -210
- brainstate/_deprecation_test.py +2297 -2319
- brainstate/_error.py +45 -45
- brainstate/_state.py +2157 -1652
- brainstate/_state_test.py +1129 -52
- brainstate/_utils.py +47 -47
- brainstate/environ.py +1495 -1495
- brainstate/environ_test.py +1223 -1223
- brainstate/graph/__init__.py +22 -22
- brainstate/graph/_node.py +240 -240
- brainstate/graph/_node_test.py +589 -589
- brainstate/graph/_operation.py +1620 -1624
- brainstate/graph/_operation_test.py +1147 -1147
- brainstate/mixin.py +1447 -1433
- brainstate/mixin_test.py +1017 -1017
- brainstate/nn/__init__.py +146 -137
- brainstate/nn/_activations.py +1100 -1100
- brainstate/nn/_activations_test.py +354 -354
- brainstate/nn/_collective_ops.py +635 -633
- brainstate/nn/_collective_ops_test.py +774 -774
- brainstate/nn/_common.py +226 -226
- brainstate/nn/_common_test.py +134 -154
- brainstate/nn/_conv.py +2010 -2010
- brainstate/nn/_conv_test.py +849 -849
- brainstate/nn/_delay.py +575 -575
- brainstate/nn/_delay_test.py +243 -243
- brainstate/nn/_dropout.py +618 -618
- brainstate/nn/_dropout_test.py +480 -477
- brainstate/nn/_dynamics.py +870 -1267
- brainstate/nn/_dynamics_test.py +53 -67
- brainstate/nn/_elementwise.py +1298 -1298
- brainstate/nn/_elementwise_test.py +829 -829
- brainstate/nn/_embedding.py +408 -408
- brainstate/nn/_embedding_test.py +156 -156
- brainstate/nn/_event_fixedprob.py +233 -233
- brainstate/nn/_event_fixedprob_test.py +115 -115
- brainstate/nn/_event_linear.py +83 -83
- brainstate/nn/_event_linear_test.py +121 -121
- brainstate/nn/_exp_euler.py +254 -254
- brainstate/nn/_exp_euler_test.py +377 -377
- brainstate/nn/_linear.py +744 -744
- brainstate/nn/_linear_test.py +475 -475
- brainstate/nn/_metrics.py +1070 -1070
- brainstate/nn/_metrics_test.py +611 -611
- brainstate/nn/_module.py +391 -384
- brainstate/nn/_module_test.py +40 -40
- brainstate/nn/_normalizations.py +1334 -1334
- brainstate/nn/_normalizations_test.py +699 -699
- brainstate/nn/_paddings.py +1020 -1020
- brainstate/nn/_paddings_test.py +722 -722
- brainstate/nn/_poolings.py +2239 -2239
- brainstate/nn/_poolings_test.py +952 -952
- brainstate/nn/_rnns.py +946 -946
- brainstate/nn/_rnns_test.py +592 -592
- brainstate/nn/_utils.py +216 -216
- brainstate/nn/_utils_test.py +401 -401
- brainstate/nn/init.py +809 -809
- brainstate/nn/init_test.py +180 -180
- brainstate/random/__init__.py +270 -270
- brainstate/random/{_rand_funs.py → _fun.py} +3938 -3938
- brainstate/random/{_rand_funs_test.py → _fun_test.py} +638 -640
- brainstate/random/_impl.py +672 -0
- brainstate/random/{_rand_seed.py → _seed.py} +675 -675
- brainstate/random/{_rand_seed_test.py → _seed_test.py} +48 -48
- brainstate/random/{_rand_state.py → _state.py} +1320 -1617
- brainstate/random/{_rand_state_test.py → _state_test.py} +551 -551
- brainstate/transform/__init__.py +56 -59
- brainstate/transform/_ad_checkpoint.py +176 -176
- brainstate/transform/_ad_checkpoint_test.py +49 -49
- brainstate/transform/_autograd.py +1025 -1025
- brainstate/transform/_autograd_test.py +1289 -1289
- brainstate/transform/_conditions.py +316 -316
- brainstate/transform/_conditions_test.py +220 -220
- brainstate/transform/_error_if.py +94 -94
- brainstate/transform/_error_if_test.py +52 -52
- brainstate/transform/_find_state.py +200 -0
- brainstate/transform/_find_state_test.py +84 -0
- brainstate/transform/_jit.py +399 -399
- brainstate/transform/_jit_test.py +143 -143
- brainstate/transform/_loop_collect_return.py +675 -675
- brainstate/transform/_loop_collect_return_test.py +58 -58
- brainstate/transform/_loop_no_collection.py +283 -283
- brainstate/transform/_loop_no_collection_test.py +50 -50
- brainstate/transform/_make_jaxpr.py +2176 -2016
- brainstate/transform/_make_jaxpr_test.py +1634 -1510
- brainstate/transform/_mapping.py +607 -529
- brainstate/transform/_mapping_test.py +104 -194
- brainstate/transform/_progress_bar.py +255 -255
- brainstate/transform/_unvmap.py +256 -256
- brainstate/transform/_util.py +286 -286
- brainstate/typing.py +837 -837
- brainstate/typing_test.py +780 -780
- brainstate/util/__init__.py +27 -27
- brainstate/util/_others.py +1024 -1024
- brainstate/util/_others_test.py +962 -962
- brainstate/util/_pretty_pytree.py +1301 -1301
- brainstate/util/_pretty_pytree_test.py +675 -675
- brainstate/util/_pretty_repr.py +462 -462
- brainstate/util/_pretty_repr_test.py +696 -696
- brainstate/util/filter.py +945 -945
- brainstate/util/filter_test.py +911 -911
- brainstate/util/struct.py +910 -910
- brainstate/util/struct_test.py +602 -602
- {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/METADATA +108 -108
- brainstate-0.2.2.dist-info/RECORD +111 -0
- {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/licenses/LICENSE +202 -202
- brainstate/transform/_eval_shape.py +0 -145
- brainstate/transform/_eval_shape_test.py +0 -38
- brainstate/transform/_random.py +0 -171
- brainstate-0.2.1.dist-info/RECORD +0 -111
- {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/WHEEL +0 -0
- {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,672 @@
|
|
1
|
+
# Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
from functools import partial
|
17
|
+
|
18
|
+
import brainunit as u
|
19
|
+
import jax
|
20
|
+
import jax.numpy as jnp
|
21
|
+
import jax.random as jr
|
22
|
+
import numpy as np
|
23
|
+
from jax import jit, vmap
|
24
|
+
from jax import lax, dtypes
|
25
|
+
from jax.scipy import special as jsp
|
26
|
+
|
27
|
+
from brainstate import environ
|
28
|
+
|
29
|
+
|
30
|
+
def _categorical(key, p, shape):
|
31
|
+
# this implementation is fast when event shape is small, and slow otherwise
|
32
|
+
# Ref: https://stackoverflow.com/a/34190035
|
33
|
+
shape = shape or p.shape[:-1]
|
34
|
+
s = jnp.cumsum(p, axis=-1)
|
35
|
+
r = jr.uniform(key, shape=shape + (1,))
|
36
|
+
return jnp.sum(s < r, axis=-1)
|
37
|
+
|
38
|
+
|
39
|
+
@partial(jit, static_argnames=('n_max', 'shape'))
def multinomial(key, p, n, *, n_max, shape=()):
    """Draw multinomial counts by scattering ``n_max`` categorical draws.

    Parameters
    ----------
    key : jax.Array
        PRNG key.
    p : array_like
        Event probabilities along the last axis.
    n : int or array_like
        Number of trials; broadcast against ``p.shape[:-1]``.
        NOTE(review): presumably expected to satisfy ``n <= n_max`` — confirm.
    n_max : int
        Static number of categorical draws made per batch element; draws beyond
        ``n`` are masked out.
    shape : tuple of int, optional
        Batch shape of the result; defaults to the (broadcast) batch shape of ``p``.

    Returns
    -------
    jax.Array
        Integer counts of shape ``shape + p.shape[-1:]``.
    """
    if u.math.shape(n) != u.math.shape(p)[:-1]:
        broadcast_shape = lax.broadcast_shapes(u.math.shape(n), u.math.shape(p)[:-1])
        n = jnp.broadcast_to(n, broadcast_shape)
        p = jnp.broadcast_to(p, broadcast_shape + u.math.shape(p)[-1:])
    shape = shape or p.shape[:-1]
    if n_max == 0:
        return jnp.zeros(shape + p.shape[-1:], dtype=jnp.result_type(int))
    # get indices from categorical distribution then gather the result
    indices = _categorical(key, p, (n_max,) + shape)
    # mask out values when counts is heterogeneous
    if jnp.ndim(n) > 0:
        mask = _promote_shapes(jnp.arange(n_max) < jnp.expand_dims(n, -1), shape=shape + (n_max,))[0]
        mask = jnp.moveaxis(mask, -1, 0).astype(indices.dtype)
        # Masked-out draws become index 0 (``indices * mask`` below), so the surplus
        # ``n_max - n`` is subtracted from category 0. Build the correction in the
        # integer dtype of ``indices``. Otherwise ``jnp.zeros`` defaults to a float
        # dtype and the returned counts would silently become floats.
        excess = jnp.concatenate(
            [jnp.expand_dims(n_max - n, -1).astype(indices.dtype),
             jnp.zeros(u.math.shape(n) + (p.shape[-1] - 1,), dtype=indices.dtype)],
            -1
        )
    else:
        mask = 1
        excess = 0
    # NB: we transpose to move batch shape to the front
    indices_2D = (jnp.reshape(indices * mask, (n_max, -1))).T
    samples_2D = vmap(_scatter_add_one)(
        jnp.zeros((indices_2D.shape[0], p.shape[-1]), dtype=indices.dtype),
        jnp.expand_dims(indices_2D, axis=-1),
        jnp.ones(indices_2D.shape, dtype=indices.dtype)
    )
    return jnp.reshape(samples_2D, shape + p.shape[-1:]) - excess
|
70
|
+
|
71
|
+
|
72
|
+
@partial(jit, static_argnums=(2, 3), static_argnames=['shape', 'dtype'])
def von_mises_centered(
    key,
    concentration,
    shape,
    dtype=None
):
    """Compute centered von Mises samples using rejection sampling from [1]_ with wrapped Cauchy proposal.

    Parameters
    ----------
    key : jax.Array
        PRNG key.
    concentration : array_like
        Concentration parameter(s); converted to ``dtype`` and broadcast to ``shape``.
    shape : tuple of int
        Output shape. When empty, the shape of ``concentration`` is used.
    dtype : dtype, optional
        One of float16, float32 or float64; defaults to ``environ.dftype()``.

    Returns
    -------
    out: array_like
        centered samples from von Mises

    Raises
    ------
    ValueError
        If ``dtype`` is not float16, float32 or float64.

    References
    ----------
    .. [1] Luc Devroye "Non-Uniform Random Variate Generation", Springer-Verlag, 1986;
       Chapter 9, p. 473-476. http://www.nrbook.com/devroye/Devroye_files/chapter_nine.pdf

    """
    shape = shape or u.math.shape(concentration)
    dtype = dtype or environ.dftype()
    concentration = lax.convert_element_type(concentration, dtype)
    concentration = jnp.broadcast_to(concentration, shape)

    # Below this concentration the closed form for ``s`` loses precision, so the
    # 1 / concentration approximation is used instead (cutoff depends on precision).
    if dtype == jnp.float16:
        s_cutoff = 1.8e-1
    elif dtype == jnp.float32:
        s_cutoff = 2e-2
    elif dtype == jnp.float64:
        s_cutoff = 1.2e-4
    else:
        raise ValueError(f"Unsupported dtype: {dtype}")

    r = 1.0 + jnp.sqrt(1.0 + 4.0 * concentration ** 2)
    rho = (r - jnp.sqrt(2.0 * r)) / (2.0 * concentration)
    s_exact = (1.0 + rho ** 2) / (2.0 * rho)

    s_approximate = 1.0 / concentration

    s = jnp.where(concentration > s_cutoff, s_exact, s_approximate)

    def cond_fn(*args):
        """check if all are done or reached max number of iterations"""
        i, _, done, _, _ = args[0]
        return jnp.bitwise_and(i < 100, jnp.logical_not(jnp.all(done)))

    def body_fn(*args):
        i, key, done, _, w = args[0]
        uni_ukey, uni_vkey, key = jr.split(key, 3)
        u_ = jr.uniform(
            key=uni_ukey,
            shape=shape,
            dtype=concentration.dtype,
            minval=-1.0,
            maxval=1.0,
        )
        z = jnp.cos(jnp.pi * u_)
        w = jnp.where(done, w, (1.0 + s * z) / (s + z))  # Update where not done
        y = concentration * (s - w)
        v = jr.uniform(key=uni_vkey, shape=shape, dtype=concentration.dtype)
        accept = (y * (2.0 - y) >= v) | (jnp.log(y / v) + 1.0 >= y)
        return i + 1, key, accept | done, u_, w

    init_done = jnp.zeros(shape, dtype=bool)
    # ``lax.while_loop`` requires the carry to keep identical dtypes. ``body_fn``
    # produces ``u_`` and ``w`` in ``concentration.dtype``, so the initial values
    # must use that dtype too. The default float dtype would differ for float16,
    # or for float32 when x64 is enabled.
    init_u = jnp.zeros(shape, dtype=concentration.dtype)
    init_w = jnp.zeros(shape, dtype=concentration.dtype)

    _, _, done, uu, w = lax.while_loop(
        cond_fun=cond_fn,
        body_fun=body_fn,
        init_val=(jnp.array(0), key, init_done, init_u, init_w),
    )

    return jnp.sign(uu) * jnp.arccos(w)
|
147
|
+
|
148
|
+
|
149
|
+
def _scatter_add_one(operand, indices, updates):
|
150
|
+
return lax.scatter_add(
|
151
|
+
operand,
|
152
|
+
indices,
|
153
|
+
updates,
|
154
|
+
lax.ScatterDimensionNumbers(
|
155
|
+
update_window_dims=(),
|
156
|
+
inserted_window_dims=(0,),
|
157
|
+
scatter_dims_to_operand_dims=(0,),
|
158
|
+
),
|
159
|
+
)
|
160
|
+
|
161
|
+
|
162
|
+
def _reshape(x, shape):
|
163
|
+
if isinstance(x, (int, float, np.ndarray, np.generic)):
|
164
|
+
return np.reshape(x, shape)
|
165
|
+
else:
|
166
|
+
return jnp.reshape(x, shape)
|
167
|
+
|
168
|
+
|
169
|
+
def _promote_shapes(*args, shape=()):
|
170
|
+
# adapted from lax.lax_numpy
|
171
|
+
if len(args) < 2 and not shape:
|
172
|
+
return args
|
173
|
+
else:
|
174
|
+
shapes = [u.math.shape(arg) for arg in args]
|
175
|
+
num_dims = len(lax.broadcast_shapes(shape, *shapes))
|
176
|
+
return [
|
177
|
+
_reshape(arg, (1,) * (num_dims - len(s)) + s)
|
178
|
+
if len(s) < num_dims else arg
|
179
|
+
for arg, s in zip(args, shapes)
|
180
|
+
]
|
181
|
+
|
182
|
+
|
183
|
+
# Mapping from the builtin Python scalar types to their 64-bit NumPy dtypes.
python_scalar_dtypes = {
    scalar_type: np.dtype(dtype_name)
    for scalar_type, dtype_name in (
        (bool, 'bool'),
        (int, 'int64'),
        (float, 'float64'),
        (complex, 'complex128'),
    )
}
|
189
|
+
|
190
|
+
|
191
|
+
def _dtype(x, *, canonicalize: bool = False):
|
192
|
+
"""Return the dtype object for a value or type, optionally canonicalized based on X64 mode."""
|
193
|
+
if x is None:
|
194
|
+
raise ValueError(f"Invalid argument to dtype: {x}.")
|
195
|
+
elif isinstance(x, type) and x in python_scalar_dtypes:
|
196
|
+
dt = python_scalar_dtypes[x]
|
197
|
+
elif type(x) in python_scalar_dtypes:
|
198
|
+
dt = python_scalar_dtypes[type(x)]
|
199
|
+
elif hasattr(x, 'dtype'):
|
200
|
+
dt = x.dtype
|
201
|
+
else:
|
202
|
+
dt = np.result_type(x)
|
203
|
+
return dtypes.canonicalize_dtype(dt) if canonicalize else dt
|
204
|
+
|
205
|
+
|
206
|
+
def _is_python_scalar(x):
|
207
|
+
if hasattr(x, 'aval'):
|
208
|
+
return x.aval.weak_type
|
209
|
+
elif np.ndim(x) == 0:
|
210
|
+
return True
|
211
|
+
elif isinstance(x, (bool, int, float, complex)):
|
212
|
+
return True
|
213
|
+
else:
|
214
|
+
return False
|
215
|
+
|
216
|
+
|
217
|
+
def const(example, val):
    """Return ``val`` as a constant matching the canonicalized dtype of ``example``.

    When ``example`` is considered a Python scalar (see ``_is_python_scalar``),
    ``val`` is first converted to the corresponding Python scalar type. That bare
    scalar is returned if its canonical dtype already matches; otherwise a NumPy
    array of the target dtype is returned. Non-scalar examples always produce a
    NumPy array.
    """
    # ``_dtype`` resolves Python scalars, objects with a ``dtype`` attribute and
    # other array-likes uniformly. ``canonicalize_dtype(type(example))`` would not
    # give the value's dtype for 0-d NumPy arrays or weakly typed JAX scalars,
    # whose ``type`` is an array class rather than a scalar type.
    dtype = _dtype(example, canonicalize=True)
    if _is_python_scalar(example):
        val = dtypes.scalar_type_of(example)(val)
        return val if dtype == _dtype(val, canonicalize=True) else np.array(val, dtype)
    return np.array(val, dtype)
|
225
|
+
|
226
|
+
|
227
|
+
# ---------------------------------------------------------------------------------------------------------------
|
228
|
+
|
229
|
+
|
230
|
+
def formalize_key(key, use_prng_key=True):
    """Normalize a seed or key into a JAX PRNG key.

    Parameters
    ----------
    key : int or array_like
        One of:

        - an integer seed (a Python or NumPy integer, or a one-element integer
          array);
        - an existing typed PRNG key array (returned unchanged);
        - a raw ``uint32`` key array with exactly two elements.
    use_prng_key : bool, optional
        When building a key from an integer seed, use ``jax.random.PRNGKey``
        (default) instead of ``jax.random.key``.

    Returns
    -------
    jax.Array
        The PRNG key.

    Raises
    ------
    TypeError
        If ``key`` is neither an integer seed nor a valid key array.
    """
    if isinstance(key, (int, np.integer)):
        return jr.PRNGKey(key) if use_prng_key else jr.key(key)
    if isinstance(key, (jax.Array, np.ndarray)):
        if jnp.issubdtype(key.dtype, jax.dtypes.prng_key):
            return key
        if key.size == 1 and jnp.issubdtype(key.dtype, jnp.integer):
            # ``jax.random.PRNGKey`` / ``jax.random.key`` only accept scalar seeds,
            # so collapse one-element arrays such as ``np.array([3])`` to 0-d.
            seed = key.reshape(())
            return jr.PRNGKey(seed) if use_prng_key else jr.key(seed)
        if key.dtype != jnp.uint32 or key.size != 2:
            raise TypeError('key must be a int or an array with two uint32.')
        return u.math.asarray(key, dtype=jnp.uint32)
    raise TypeError('key must be a int or an array with two uint32.')
|
246
|
+
|
247
|
+
|
248
|
+
def _size2shape(size):
|
249
|
+
if size is None:
|
250
|
+
return ()
|
251
|
+
elif isinstance(size, (tuple, list)):
|
252
|
+
return tuple(size)
|
253
|
+
else:
|
254
|
+
return (size,)
|
255
|
+
|
256
|
+
|
257
|
+
def _check_shape(name, shape, *param_shapes):
|
258
|
+
if param_shapes:
|
259
|
+
shape_ = lax.broadcast_shapes(shape, *param_shapes)
|
260
|
+
if shape != shape_:
|
261
|
+
msg = ("{} parameter shapes must be broadcast-compatible with shape "
|
262
|
+
"argument, and the result of broadcasting the shapes must equal "
|
263
|
+
"the shape argument, but got result {} for shape argument {}.")
|
264
|
+
raise ValueError(msg.format(name, shape_, shape))
|
265
|
+
|
266
|
+
|
267
|
+
def _loc_scale(
|
268
|
+
loc,
|
269
|
+
scale,
|
270
|
+
value
|
271
|
+
):
|
272
|
+
if loc is None:
|
273
|
+
if scale is None:
|
274
|
+
return value
|
275
|
+
else:
|
276
|
+
return value * scale
|
277
|
+
else:
|
278
|
+
if scale is None:
|
279
|
+
return value + loc
|
280
|
+
else:
|
281
|
+
return value * scale + loc
|
282
|
+
|
283
|
+
|
284
|
+
def _check_py_seq(seq):
|
285
|
+
return u.math.asarray(seq) if isinstance(seq, (tuple, list)) else seq
|
286
|
+
|
287
|
+
|
288
|
+
@partial(jit, static_argnames=['shape', 'dtype'])
def f(key, dfnum, dfden, *, shape, dtype=None):
    """Draw samples from the central F distribution.

    Each sample is ``(X1 / dfnum) / (X2 / dfden)``, where ``X1`` and ``X2`` are
    independent chi-square variates drawn as ``2 * Gamma(df / 2)``.

    Parameters
    ----------
    key : jax.Array
        PRNG key.
    dfnum, dfden : float or array_like
        Numerator and denominator degrees of freedom.
    shape : int, tuple of int or None
        Output shape; ``None`` uses the broadcast shape of the parameters.
    dtype : dtype, optional
        Floating dtype of the result; defaults to ``environ.dftype()``.

    Returns
    -------
    jax.Array
        Samples of shape ``shape``.
    """
    dtype = dtype or environ.dftype()
    dfnum, dfden = (lax.convert_element_type(v, dtype) for v in (dfnum, dfden))

    if shape is None:
        out_shape = lax.broadcast_shapes(u.math.shape(dfnum), u.math.shape(dfden))
    else:
        out_shape = (shape,) if isinstance(shape, int) else tuple(shape)

    dfnum = jnp.broadcast_to(dfnum, out_shape)
    dfden = jnp.broadcast_to(dfden, out_shape)

    # A zero-length axis means there is nothing to sample.
    if 0 in out_shape:
        return jnp.empty(out_shape, dtype=dtype)

    num_key, den_key = jr.split(key)
    chi2_num = 2.0 * jr.gamma(num_key, 0.5 * dfnum, shape=out_shape, dtype=dtype)
    chi2_den = 2.0 * jr.gamma(den_key, 0.5 * dfden, shape=out_shape, dtype=dtype)

    numerator = chi2_num / dfnum
    denominator = chi2_den / dfden
    return numerator / denominator
|
321
|
+
|
322
|
+
|
323
|
+
@partial(jit, static_argnames=['shape', 'dtype'])
def noncentral_f(
    key,
    dfnum,
    dfden,
    nonc,
    *,
    shape,
    dtype=None
):
    """
    Draw samples from the noncentral F distribution.

    The noncentral F distribution is a generalization of the F distribution.
    It is parameterized by dfnum (degrees of freedom of the numerator),
    dfden (degrees of freedom of the denominator), and nonc (noncentrality parameter).

    The implementation uses the relationship:
    If X ~ noncentral_chisquare(dfnum, nonc) and Y ~ chisquare(dfden), then
    F = (X / dfnum) / (Y / dfden) ~ noncentral_f(dfnum, dfden, nonc)

    Parameters
    ----------
    key : jax.random.PRNGKey
        Random key
    dfnum : float or array_like
        Degrees of freedom of the numerator, must be > 0
    dfden : float or array_like
        Degrees of freedom of the denominator, must be > 0
    nonc : float or array_like
        Noncentrality parameter, must be >= 0
    shape : int, tuple of int or None
        Output shape. ``None`` uses the broadcast shape of ``dfnum``, ``dfden``
        and ``nonc``; an ``int`` is treated as a 1-D shape.
    dtype : dtype, optional
        Data type of the output; defaults to ``environ.dftype()``.

    Returns
    -------
    out : array_like
        Samples from the noncentral F distribution
    """
    dtype = dtype or environ.dftype()
    dfnum = lax.convert_element_type(dfnum, dtype)
    dfden = lax.convert_element_type(dfden, dtype)
    nonc = lax.convert_element_type(nonc, dtype)

    # Normalize ``shape`` like the other samplers in this module. The
    # ``jax.random`` calls below expect a shape tuple, so ``None`` and bare ints
    # must be resolved here.
    if shape is None:
        shape = lax.broadcast_shapes(u.math.shape(dfnum), u.math.shape(dfden), u.math.shape(nonc))
    elif isinstance(shape, int):
        shape = (shape,)
    else:
        shape = tuple(shape)

    # Broadcasting up front does not change the random streams: the samplers
    # below broadcast their parameters to ``shape`` anyway.
    dfnum = jnp.broadcast_to(dfnum, shape)
    dfden = jnp.broadcast_to(dfden, shape)
    nonc = jnp.broadcast_to(nonc, shape)

    size = int(np.prod(shape)) if shape else 1
    if size == 0:
        return jnp.empty(shape, dtype=dtype)

    # Split key for two random samples
    key1, key2 = jr.split(key)

    # Generate noncentral chi-square for numerator
    # noncentral_chisquare(df, nonc) = chi-square(df - 1) + (normal(0,1) + sqrt(nonc))^2
    # when df > 1, else chi-square(df + 2*poisson(nonc/2))
    keys_numer = jr.split(key1, 3)
    i = jr.poisson(keys_numer[0], 0.5 * nonc, shape=shape, dtype=environ.ditype())
    n = jr.normal(keys_numer[1], shape=shape, dtype=dtype) + jnp.sqrt(nonc)
    cond = jnp.greater(dfnum, 1.0)
    df_numerator = jnp.where(cond, dfnum - 1.0, dfnum + 2.0 * i)
    chi2_numerator = 2.0 * jr.gamma(keys_numer[2], 0.5 * df_numerator, shape=shape, dtype=dtype)
    numerator = jnp.where(cond, chi2_numerator + n * n, chi2_numerator)

    # Generate central chi-square for denominator
    # chi-square(df) = 2 * gamma(df/2, 1)
    chi2_denominator = 2.0 * jr.gamma(key2, 0.5 * dfden, shape=shape, dtype=dtype)

    # Compute F statistic: (numerator / dfnum) / (denominator / dfden)
    f_stat = (numerator / dfnum) / (chi2_denominator / dfden)

    return f_stat
|
391
|
+
|
392
|
+
|
393
|
+
@partial(jit, static_argnames=['shape', 'dtype'])
def logseries(
    key,
    p,
    *,
    shape,
    dtype=None
):
    """Draw samples from the logarithmic series distribution.

    Uses a sequential inverse-CDF search in log space, with
    ``P(K = 1) = -p / log(1 - p)`` and ``P(K = k + 1) = P(K = k) * p * k / (k + 1)``.

    Parameters
    ----------
    key : jax.Array
        PRNG key.
    p : float or array_like
        Distribution parameter. Entries ``<= 0`` yield 1; the others are clipped
        into ``[tiny, 1 - eps]`` of the computation dtype.
    shape : int, tuple of int or None
        Output shape; ``None`` uses the shape of ``p``.
    dtype : dtype, optional
        Output dtype; defaults to ``environ.ditype()``.

    Returns
    -------
    jax.Array
        Samples (``k >= 1``) of shape ``shape``.
    """
    dtype = dtype or environ.ditype()
    float_dtype = dtypes.canonicalize_dtype(environ.dftype())
    calc_dtype = dtypes.canonicalize_dtype(jnp.promote_types(float_dtype, jnp.float64))

    p = lax.convert_element_type(p, float_dtype)

    if shape is None:
        shape = u.math.shape(p)
    elif isinstance(shape, int):
        shape = (shape,)
    else:
        shape = tuple(shape)

    p = jnp.broadcast_to(p, shape)

    size = int(np.prod(shape)) if shape else 1
    if size == 0:
        return jnp.empty(shape, dtype=dtype)

    p_flat = jnp.reshape(lax.convert_element_type(p, calc_dtype), (size,))
    keys = jr.split(key, size)

    tiny = jnp.array(np.finfo(calc_dtype).tiny, dtype=calc_dtype)
    one_minus_eps = jnp.nextafter(jnp.array(1.0, dtype=calc_dtype), jnp.array(0.0, dtype=calc_dtype))
    # Hard cap on the search length (same bound as ``zipf``), so the loop always
    # terminates even for ``p`` extremely close to 1.
    max_iters = jnp.array(1000000, dtype=jnp.int32)

    def _sample_one(single_key, p_scalar):
        p_scalar = lax.convert_element_type(p_scalar, calc_dtype)
        operand = (single_key, p_scalar)

        def _limit_case(_):
            return jnp.array(1.0, dtype=calc_dtype)

        def _positive_case(args):
            key_i, p_val = args
            p_val = jnp.clip(p_val, tiny, one_minus_eps)
            log_p = jnp.log(p_val)
            log_norm = jnp.log(-jnp.log1p(-p_val))
            log_prob = log_p - log_norm  # log P(K = 1)
            log_cdf = log_prob
            log_u = jnp.log(jr.uniform(key_i, shape=(), dtype=calc_dtype, minval=tiny, maxval=one_minus_eps))

            # State: (k, log P(K = k), log P(K <= k), log u, iteration, cdf_grew)
            init_state = (
                jnp.array(1.0, dtype=calc_dtype),
                log_prob,
                log_cdf,
                log_u,
                jnp.array(0, dtype=jnp.int32),
                jnp.array(True),
            )

            def cond_fn(state):
                _, _, log_cdf_val, log_u_val, it, grew = state
                # Continue while the CDF is still below ``u``. Stop early when the
                # last (strictly decreasing) term no longer changed the rounded CDF,
                # since later terms cannot either. Also stop after ``max_iters``.
                # Without these guards the loop could run forever.
                keep_going = jnp.logical_and(log_u_val > log_cdf_val, grew)
                return jnp.logical_and(keep_going, it < max_iters)

            def body_fn(state):
                k_val, log_prob_val, log_cdf_val, log_u_val, it, _ = state
                k_next = k_val + 1.0
                log_prob_next = log_prob_val + log_p + jnp.log(k_val) - jnp.log(k_next)
                log_cdf_next = jnp.logaddexp(log_cdf_val, log_prob_next)
                grew = log_cdf_next > log_cdf_val
                return k_next, log_prob_next, log_cdf_next, log_u_val, it + 1, grew

            k_val = lax.while_loop(cond_fn, body_fn, init_state)[0]
            return k_val

        return lax.cond(p_scalar <= 0.0, _limit_case, _positive_case, operand)

    samples = vmap(_sample_one)(keys, p_flat)
    samples = lax.convert_element_type(samples, dtype)
    return jnp.reshape(samples, shape)
|
464
|
+
|
465
|
+
|
466
|
+
@partial(jit, static_argnames=['shape', 'dtype'])
def zipf(
    key,
    a,
    *,
    shape,
    dtype=None
):
    """Draw samples from the Zipf (zeta) distribution.

    Inverse-CDF sampling: for each uniform variate ``u``, the smallest
    ``k >= 1`` with ``CDF(k) >= u`` is found by a linear search, where
    ``CDF(k) = 1 - zeta(a, k + 1) / zeta(a, 1)`` (Hurwitz zeta). The search is
    capped at ``max_iters`` steps.

    Parameters
    ----------
    key : jax.Array
        PRNG key.
    a : float or array_like
        Distribution parameter.
        NOTE(review): presumably must be > 1 (as for NumPy's zipf); not validated here.
    shape : int, tuple of int or None
        Output shape; ``None`` uses the shape of ``a``.
    dtype : dtype, optional
        Output dtype; defaults to ``environ.ditype()``.

    Returns
    -------
    jax.Array
        Samples of shape ``shape``.
    """
    dtype = dtype or environ.ditype()
    float_dtype = dtypes.canonicalize_dtype(environ.dftype())
    calc_dtype = dtypes.canonicalize_dtype(jnp.promote_types(float_dtype, jnp.float64))

    a = lax.convert_element_type(a, calc_dtype)

    if shape is None:
        shape = u.math.shape(a)
    elif isinstance(shape, int):
        shape = (shape,)
    else:
        shape = tuple(shape)

    a = jnp.broadcast_to(a, shape)

    size = int(np.prod(shape)) if shape else 1
    if size == 0:
        return jnp.empty(shape, dtype=dtype)

    u_ = jr.uniform(
        key,
        shape=shape,
        dtype=calc_dtype,
        minval=jnp.finfo(calc_dtype).tiny,
        maxval=jnp.array(1.0, dtype=calc_dtype)
    )

    a_flat = jnp.reshape(a, (size,))
    u_flat = jnp.reshape(u_, (size,))

    max_iters = jnp.array(1000000, dtype=jnp.int32)
    one = jnp.array(1.0, dtype=calc_dtype)

    def _sample_one(a_scalar, u_scalar):
        norm = jsp.zeta(a_scalar, one)

        def cdf(k_val):
            return one - jsp.zeta(a_scalar, k_val + one) / norm

        # The search only ever moves upward. It starts at k = 1 with
        # CDF(0) = 0 < u (u >= tiny), and each step is taken only because
        # u > CDF(k), so u <= CDF(k - 1) can never hold. A decrement branch would be
        # dead code, and under ``vmap`` a ``lax.cond`` with a batched predicate
        # evaluates both branches, wasting extra zeta evaluations per step.
        def cond_fn(state):
            _, c_curr, it = state
            return jnp.logical_and(u_scalar > c_curr, it < max_iters)

        def body_fn(state):
            k_val, _, it = state
            k_next = k_val + one
            return k_next, cdf(k_next), it + 1

        init_state = (one, cdf(one), jnp.array(0, dtype=jnp.int32))
        k_final, _, _ = lax.while_loop(cond_fn, body_fn, init_state)
        return lax.convert_element_type(k_final, dtype)

    samples_flat = jax.vmap(_sample_one)(a_flat, u_flat)
    samples = jnp.reshape(samples_flat, shape)
    return samples
|
556
|
+
|
557
|
+
|
558
|
+
@partial(jit, static_argnames=['shape', 'dtype'])
def power(key, a, *, shape, dtype=None):
    """Draw samples from the power distribution.

    Samples are ``U ** (1 / a)``, with ``U`` uniform on ``[tiny, 1)``. ``a`` is
    floored at ``tiny`` (the smallest positive normal float of the dtype).

    Parameters
    ----------
    key : jax.Array
        PRNG key.
    a : float or array_like
        Distribution parameter.
    shape : int, tuple of int or None
        Output shape; ``None`` uses the shape of ``a``.
    dtype : dtype, optional
        Floating output dtype; defaults to ``environ.dftype()``.

    Returns
    -------
    jax.Array
        Samples of shape ``shape``.
    """
    dtype = dtype or environ.dftype()
    float_dtype = dtypes.canonicalize_dtype(dtype)
    a = lax.convert_element_type(a, float_dtype)

    if shape is None:
        shape = u.math.shape(a)
    else:
        shape = (shape,) if isinstance(shape, int) else tuple(shape)

    a = jnp.broadcast_to(a, shape)

    # A zero-length axis means there is nothing to sample.
    if 0 in shape:
        return jnp.empty(shape, dtype=float_dtype)

    tiny = jnp.array(np.finfo(float_dtype).tiny, dtype=float_dtype)
    uniform = jr.uniform(key, shape=shape, dtype=float_dtype, minval=tiny, maxval=1.0)
    exponent = jnp.reciprocal(jnp.maximum(a, tiny))
    return lax.convert_element_type(jnp.power(uniform, exponent), dtype)
|
592
|
+
|
593
|
+
|
594
|
+
@partial(jit, static_argnames=['shape', 'dtype'])
def hypergeometric(key, ngood, nbad, nsample, *, shape, dtype=None):
    """Draw samples from the hypergeometric distribution.

    Each sample counts the good items obtained when ``nsample`` items are drawn
    without replacement from ``ngood`` good and ``nbad`` bad items. Draws are
    simulated one at a time inside a ``lax.while_loop``, so the work per sample
    grows linearly with ``nsample``. Negative counts are clamped to zero, and
    ``nsample`` is clamped to ``ngood + nbad``.

    Parameters
    ----------
    key : jax.Array
        PRNG key.
    ngood, nbad, nsample : int or array_like
        Urn contents and number of draws.
    shape : int, tuple of int or None
        Output shape; ``None`` uses the broadcast shape of the three counts.
    dtype : dtype, optional
        Integer output dtype; defaults to ``environ.ditype()``.

    Returns
    -------
    jax.Array
        Samples of shape ``shape``.
    """
    dtype = dtype or environ.ditype()
    out_dtype = dtypes.canonicalize_dtype(dtype)
    calc_dtype = dtypes.canonicalize_dtype(
        jnp.promote_types(dtypes.canonicalize_dtype(environ.dftype()), jnp.float64)
    )

    ngood, nbad, nsample = (lax.convert_element_type(v, out_dtype) for v in (ngood, nbad, nsample))

    if shape is None:
        shape = lax.broadcast_shapes(u.math.shape(ngood), u.math.shape(nbad), u.math.shape(nsample))
    else:
        shape = (shape,) if isinstance(shape, int) else tuple(shape)

    ngood, nbad, nsample = (jnp.broadcast_to(v, shape) for v in (ngood, nbad, nsample))

    size = int(np.prod(shape)) if shape else 1
    if size == 0:
        return jnp.empty(shape, dtype=out_dtype)

    flat_counts = [jnp.reshape(v, (size,)) for v in (ngood, nbad, nsample)]
    sample_keys = jr.split(key, size + 1)[1:]

    one = jnp.array(1, dtype=out_dtype)
    zero = jnp.array(0, dtype=out_dtype)

    def _draw_without_replacement(sample_key, good, bad, draws):
        good, bad, draws = (jnp.maximum(v, zero) for v in (good, bad, draws))
        draws = jnp.minimum(draws, good + bad)

        # Loop state: (step, rng, good_left, bad_left, hits, n_draws)
        def keep_drawing(state):
            step, _, good_left, bad_left, _, n_draws = state
            return jnp.logical_and(step < n_draws, good_left + bad_left > zero)

        def draw_once(state):
            step, rng, good_left, bad_left, hits, n_draws = state
            rng, subkey = jr.split(rng)
            remaining = good_left + bad_left
            p_good = jnp.where(
                remaining > zero,
                lax.convert_element_type(good_left, calc_dtype) / lax.convert_element_type(remaining, calc_dtype),
                jnp.array(0.0, dtype=calc_dtype),
            )
            draw = jr.uniform(subkey, shape=(), dtype=calc_dtype)
            picked_good = (draw < p_good).astype(out_dtype)
            return (
                step + one,
                rng,
                good_left - picked_good,
                bad_left - jnp.where(remaining > zero, one - picked_good, zero),
                hits + picked_good,
                n_draws,
            )

        start = (zero, sample_key, good, bad, zero, draws)
        return lax.while_loop(keep_drawing, draw_once, start)[4]

    samples = jax.vmap(_draw_without_replacement)(sample_keys, *flat_counts)
    return jnp.reshape(lax.convert_element_type(samples, out_dtype), shape)
|