brainstate 0.1.0.post20250105__py2.py3-none-any.whl → 0.1.0.post20250120__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/_state.py +77 -44
- brainstate/_state_test.py +0 -17
- brainstate/augment/_eval_shape.py +9 -10
- brainstate/augment/_eval_shape_test.py +1 -1
- brainstate/augment/_mapping.py +265 -277
- brainstate/augment/_mapping_test.py +147 -175
- brainstate/compile/_ad_checkpoint.py +6 -4
- brainstate/compile/_jit.py +37 -28
- brainstate/compile/_loop_collect_return.py +6 -3
- brainstate/compile/_loop_no_collection.py +2 -0
- brainstate/compile/_make_jaxpr.py +7 -3
- brainstate/compile/_progress_bar.py +68 -40
- brainstate/compile/_unvmap.py +6 -3
- brainstate/event/__init__.py +0 -2
- brainstate/event/_csr.py +266 -23
- brainstate/event/_csr_test.py +187 -0
- brainstate/event/_xla_custom_op.py +7 -3
- brainstate/graph/__init__.py +8 -12
- brainstate/graph/_graph_node.py +1 -23
- brainstate/graph/_graph_operation.py +1 -1
- brainstate/graph/_graph_operation_test.py +0 -159
- brainstate/nn/_dyn_impl/_inputs.py +124 -39
- brainstate/nn/_interaction/_conv.py +4 -2
- brainstate/nn/_interaction/_linear.py +84 -10
- brainstate/random/_rand_funs.py +9 -2
- brainstate/random/_rand_seed.py +12 -2
- brainstate/random/_rand_state.py +50 -179
- brainstate/surrogate.py +5 -1
- brainstate/util/__init__.py +0 -4
- brainstate/util/_caller.py +1 -1
- brainstate/util/_dict.py +4 -1
- brainstate/util/_filter.py +1 -1
- brainstate/util/_pretty_repr.py +1 -1
- brainstate/util/_struct.py +1 -1
- {brainstate-0.1.0.post20250105.dist-info → brainstate-0.1.0.post20250120.dist-info}/METADATA +2 -1
- {brainstate-0.1.0.post20250105.dist-info → brainstate-0.1.0.post20250120.dist-info}/RECORD +40 -46
- brainstate/event/_csr_mv_test.py +0 -118
- brainstate/graph/_graph_context.py +0 -443
- brainstate/graph/_graph_context_test.py +0 -65
- brainstate/graph/_graph_convert.py +0 -246
- brainstate/util/_tracers.py +0 -68
- brainstate/util/_visualization.py +0 -47
- /brainstate/event/{_csr_mv_benchmark.py → _csr_benchmark.py} +0 -0
- {brainstate-0.1.0.post20250105.dist-info → brainstate-0.1.0.post20250120.dist-info}/LICENSE +0 -0
- {brainstate-0.1.0.post20250105.dist-info → brainstate-0.1.0.post20250120.dist-info}/WHEEL +0 -0
- {brainstate-0.1.0.post20250105.dist-info → brainstate-0.1.0.post20250120.dist-info}/top_level.txt +0 -0
@@ -22,8 +22,8 @@ import numpy as np
|
|
22
22
|
|
23
23
|
from brainstate import environ, init, random
|
24
24
|
from brainstate._state import ShortTermState
|
25
|
-
from brainstate._state import State
|
26
|
-
from brainstate.compile import while_loop
|
25
|
+
from brainstate._state import State, maybe_state
|
26
|
+
from brainstate.compile import while_loop
|
27
27
|
from brainstate.nn._dynamics._dynamics_base import Dynamics, Prefetch
|
28
28
|
from brainstate.nn._module import Module
|
29
29
|
from brainstate.typing import ArrayLike, Size, DTypeLike
|
@@ -198,55 +198,97 @@ class PoissonInput(Module):
|
|
198
198
|
self.weight = weight
|
199
199
|
|
200
200
|
def update(self):
|
201
|
-
p = self.freq * environ.get_dt()
|
202
|
-
a = self.num_input * p
|
203
|
-
b = self.num_input * (1 - p)
|
204
|
-
|
205
|
-
target = self.target()
|
206
201
|
target_state = getattr(self.target.module, self.target.item)
|
207
202
|
|
208
203
|
# generate Poisson input
|
209
|
-
|
210
|
-
|
211
|
-
|
212
|
-
|
204
|
+
poisson_input(
|
205
|
+
self.freq,
|
206
|
+
self.num_input,
|
207
|
+
self.weight,
|
208
|
+
target_state,
|
209
|
+
self.indices,
|
213
210
|
)
|
214
211
|
|
215
|
-
# update target variable
|
216
|
-
target_state.value = target.at[self.indices].add(inp * self.weight)
|
217
|
-
|
218
212
|
|
219
213
|
def poisson_input(
|
220
|
-
freq:
|
214
|
+
freq: u.Quantity[u.Hz],
|
221
215
|
num_input: int,
|
222
|
-
weight:
|
216
|
+
weight: u.Quantity,
|
223
217
|
target: State,
|
224
218
|
indices: Optional[Union[np.ndarray, jax.Array]] = None,
|
225
219
|
):
|
226
220
|
"""
|
227
221
|
Poisson Input to the given :py:class:`brainstate.State`.
|
228
222
|
"""
|
223
|
+
freq = maybe_state(freq)
|
224
|
+
weight = maybe_state(weight)
|
225
|
+
|
229
226
|
assert isinstance(target, State), 'The target must be a State.'
|
230
|
-
p = freq * environ.get_dt()
|
227
|
+
p = (freq * environ.get_dt()).to_decimal()
|
231
228
|
a = num_input * p
|
232
229
|
b = num_input * (1 - p)
|
233
230
|
tar_val = target.value
|
231
|
+
cond = u.math.logical_and(a > 5, b > 5)
|
232
|
+
|
234
233
|
if indices is None:
|
235
234
|
# generate Poisson input
|
236
|
-
|
237
|
-
|
238
|
-
|
239
|
-
|
240
|
-
|
241
|
-
|
235
|
+
branch1 = jax.tree.map(
|
236
|
+
lambda tar: random.normal(
|
237
|
+
a,
|
238
|
+
b * p,
|
239
|
+
tar.shape,
|
240
|
+
dtype=tar.dtype
|
242
241
|
),
|
243
|
-
|
244
|
-
|
245
|
-
|
246
|
-
|
247
|
-
|
242
|
+
tar_val,
|
243
|
+
is_leaf=u.math.is_quantity
|
244
|
+
)
|
245
|
+
branch2 = jax.tree.map(
|
246
|
+
lambda tar: random.binomial(
|
247
|
+
num_input,
|
248
|
+
p,
|
249
|
+
tar.shape,
|
250
|
+
check_valid=False,
|
251
|
+
dtype=tar.dtype
|
252
|
+
),
|
253
|
+
tar_val,
|
254
|
+
is_leaf=u.math.is_quantity,
|
255
|
+
)
|
256
|
+
|
257
|
+
inp = jax.tree.map(
|
258
|
+
lambda b1, b2: u.math.where(cond, b1, b2),
|
259
|
+
branch1,
|
260
|
+
branch2,
|
261
|
+
is_leaf=u.math.is_quantity,
|
248
262
|
)
|
249
263
|
|
264
|
+
# inp = jax.lax.cond(
|
265
|
+
# cond,
|
266
|
+
# lambda rand_key: jax.tree.map(
|
267
|
+
# lambda tar: random.normal(
|
268
|
+
# a,
|
269
|
+
# b * p,
|
270
|
+
# tar.shape,
|
271
|
+
# key=rand_key,
|
272
|
+
# dtype=tar.dtype
|
273
|
+
# ),
|
274
|
+
# tar_val,
|
275
|
+
# is_leaf=u.math.is_quantity
|
276
|
+
# ),
|
277
|
+
# lambda rand_key: jax.tree.map(
|
278
|
+
# lambda tar: random.binomial(
|
279
|
+
# num_input,
|
280
|
+
# p,
|
281
|
+
# tar.shape,
|
282
|
+
# key=rand_key,
|
283
|
+
# check_valid=False,
|
284
|
+
# dtype=tar.dtype
|
285
|
+
# ),
|
286
|
+
# tar_val,
|
287
|
+
# is_leaf=u.math.is_quantity,
|
288
|
+
# ),
|
289
|
+
# random.split_key()
|
290
|
+
# )
|
291
|
+
|
250
292
|
# update target variable
|
251
293
|
target.value = jax.tree.map(
|
252
294
|
lambda x: x * weight,
|
@@ -256,19 +298,62 @@ def poisson_input(
|
|
256
298
|
|
257
299
|
else:
|
258
300
|
# generate Poisson input
|
259
|
-
|
260
|
-
|
261
|
-
|
262
|
-
|
263
|
-
|
264
|
-
|
301
|
+
branch1 = jax.tree.map(
|
302
|
+
lambda tar: random.normal(
|
303
|
+
a,
|
304
|
+
b * p,
|
305
|
+
tar[indices].shape,
|
306
|
+
dtype=tar.dtype
|
265
307
|
),
|
266
|
-
|
267
|
-
|
268
|
-
tar_val,
|
269
|
-
is_leaf=u.math.is_quantity
|
270
|
-
)
|
308
|
+
tar_val,
|
309
|
+
is_leaf=u.math.is_quantity
|
271
310
|
)
|
311
|
+
branch2 = jax.tree.map(
|
312
|
+
lambda tar: random.binomial(
|
313
|
+
num_input,
|
314
|
+
p,
|
315
|
+
tar[indices].shape,
|
316
|
+
# check_valid=False,
|
317
|
+
dtype=tar.dtype
|
318
|
+
),
|
319
|
+
tar_val,
|
320
|
+
is_leaf=u.math.is_quantity
|
321
|
+
)
|
322
|
+
|
323
|
+
inp = jax.tree.map(
|
324
|
+
lambda b1, b2: u.math.where(cond, b1, b2),
|
325
|
+
branch1,
|
326
|
+
branch2,
|
327
|
+
is_leaf=u.math.is_quantity,
|
328
|
+
)
|
329
|
+
|
330
|
+
# inp = jax.lax.cond(
|
331
|
+
# cond,
|
332
|
+
# lambda rand_key: jax.tree.map(
|
333
|
+
# lambda tar: random.normal(
|
334
|
+
# a,
|
335
|
+
# b * p,
|
336
|
+
# tar[indices].shape,
|
337
|
+
# key=rand_key,
|
338
|
+
# dtype=tar.dtype
|
339
|
+
# ),
|
340
|
+
# tar_val,
|
341
|
+
# is_leaf=u.math.is_quantity
|
342
|
+
# ),
|
343
|
+
# lambda rand_key: jax.tree.map(
|
344
|
+
# lambda tar: random.binomial(
|
345
|
+
# num_input,
|
346
|
+
# p,
|
347
|
+
# tar[indices].shape,
|
348
|
+
# key=rand_key,
|
349
|
+
# check_valid=False,
|
350
|
+
# dtype=tar.dtype
|
351
|
+
# ),
|
352
|
+
# tar_val,
|
353
|
+
# is_leaf=u.math.is_quantity
|
354
|
+
# ),
|
355
|
+
# random.split_key()
|
356
|
+
# )
|
272
357
|
|
273
358
|
# update target variable
|
274
359
|
target.value = jax.tree.map(
|
@@ -191,6 +191,7 @@ class _Conv(_BaseConv):
|
|
191
191
|
b_init: Optional[Union[Callable, ArrayLike]] = None,
|
192
192
|
w_mask: Optional[Union[ArrayLike, Callable]] = None,
|
193
193
|
name: str = None,
|
194
|
+
param_type: type = ParamState,
|
194
195
|
):
|
195
196
|
super().__init__(in_size=in_size,
|
196
197
|
out_channels=out_channels,
|
@@ -215,7 +216,7 @@ class _Conv(_BaseConv):
|
|
215
216
|
params['bias'] = bias
|
216
217
|
|
217
218
|
# The weight operation
|
218
|
-
self.weight =
|
219
|
+
self.weight = param_type(params)
|
219
220
|
|
220
221
|
# Evaluate the output shape
|
221
222
|
abstract_y = jax.eval_shape(
|
@@ -346,6 +347,7 @@ class _ScaledWSConv(_BaseConv):
|
|
346
347
|
b_init: Optional[Union[Callable, ArrayLike]] = None,
|
347
348
|
w_mask: Optional[Union[ArrayLike, Callable]] = None,
|
348
349
|
name: str = None,
|
350
|
+
param_type: type = ParamState,
|
349
351
|
):
|
350
352
|
super().__init__(in_size=in_size,
|
351
353
|
out_channels=out_channels,
|
@@ -379,7 +381,7 @@ class _ScaledWSConv(_BaseConv):
|
|
379
381
|
self.eps = eps
|
380
382
|
|
381
383
|
# The weight operation
|
382
|
-
self.weight =
|
384
|
+
self.weight = param_type(params)
|
383
385
|
|
384
386
|
# Evaluate the output shape
|
385
387
|
abstract_y = jax.eval_shape(
|
@@ -34,6 +34,7 @@ __all__ = [
|
|
34
34
|
'SparseLinear',
|
35
35
|
'AllToAll',
|
36
36
|
'OneToOne',
|
37
|
+
'LoRA',
|
37
38
|
]
|
38
39
|
|
39
40
|
|
@@ -51,6 +52,7 @@ class Linear(Module):
|
|
51
52
|
b_init: Optional[Union[Callable, ArrayLike]] = init.ZeroInit(),
|
52
53
|
w_mask: Optional[Union[ArrayLike, Callable]] = None,
|
53
54
|
name: Optional[str] = None,
|
55
|
+
param_type: type = ParamState,
|
54
56
|
):
|
55
57
|
super().__init__(name=name)
|
56
58
|
|
@@ -67,7 +69,7 @@ class Linear(Module):
|
|
67
69
|
params = dict(weight=init.param(w_init, (self.in_size[-1], self.out_size[-1]), allow_none=False))
|
68
70
|
if b_init is not None:
|
69
71
|
params['bias'] = init.param(b_init, self.out_size[-1], allow_none=False)
|
70
|
-
self.weight =
|
72
|
+
self.weight = param_type(params)
|
71
73
|
|
72
74
|
def update(self, x):
|
73
75
|
params = self.weight.value
|
@@ -93,7 +95,7 @@ class SignedWLinear(Module):
|
|
93
95
|
w_init: Union[Callable, ArrayLike] = init.KaimingNormal(),
|
94
96
|
w_sign: Optional[ArrayLike] = None,
|
95
97
|
name: Optional[str] = None,
|
96
|
-
|
98
|
+
param_type: type = ParamState,
|
97
99
|
):
|
98
100
|
super().__init__(name=name)
|
99
101
|
|
@@ -108,7 +110,7 @@ class SignedWLinear(Module):
|
|
108
110
|
|
109
111
|
# weights
|
110
112
|
weight = init.param(w_init, self.in_size + self.out_size, allow_none=False)
|
111
|
-
self.weight =
|
113
|
+
self.weight = param_type(weight)
|
112
114
|
|
113
115
|
def update(self, x):
|
114
116
|
w = self.weight.value
|
@@ -156,6 +158,7 @@ class ScaledWSLinear(Module):
|
|
156
158
|
ws_gain: bool = True,
|
157
159
|
eps: float = 1e-4,
|
158
160
|
name: str = None,
|
161
|
+
param_type: type = ParamState,
|
159
162
|
):
|
160
163
|
super().__init__(name=name)
|
161
164
|
|
@@ -179,7 +182,7 @@ class ScaledWSLinear(Module):
|
|
179
182
|
if ws_gain:
|
180
183
|
s = params['weight'].shape
|
181
184
|
params['gain'] = jnp.ones((1,) * (len(s) - 1) + (s[-1],), dtype=params['weight'].dtype)
|
182
|
-
self.weight =
|
185
|
+
self.weight = param_type(params)
|
183
186
|
|
184
187
|
def update(self, x):
|
185
188
|
params = self.weight.value
|
@@ -211,6 +214,7 @@ class SparseLinear(Module):
|
|
211
214
|
b_init: Optional[Union[Callable, ArrayLike]] = None,
|
212
215
|
in_size: Size = None,
|
213
216
|
name: Optional[str] = None,
|
217
|
+
param_type: type = ParamState,
|
214
218
|
):
|
215
219
|
super().__init__(name=name)
|
216
220
|
|
@@ -230,7 +234,7 @@ class SparseLinear(Module):
|
|
230
234
|
params = dict(weight=spar_mat.data)
|
231
235
|
if b_init is not None:
|
232
236
|
params['bias'] = init.param(b_init, self.out_size[-1], allow_none=False)
|
233
|
-
self.weight =
|
237
|
+
self.weight = param_type(params)
|
234
238
|
|
235
239
|
def update(self, x):
|
236
240
|
data = self.weight.value['weight']
|
@@ -260,6 +264,7 @@ class AllToAll(Module):
|
|
260
264
|
b_init: Optional[Union[Callable, ArrayLike]] = None,
|
261
265
|
include_self: bool = True,
|
262
266
|
name: Optional[str] = None,
|
267
|
+
param_type: type = ParamState,
|
263
268
|
):
|
264
269
|
super().__init__(name=name)
|
265
270
|
|
@@ -277,7 +282,7 @@ class AllToAll(Module):
|
|
277
282
|
params = dict(weight=weight)
|
278
283
|
if b_init is not None:
|
279
284
|
params['bias'] = init.param(b_init, self.out_size[-1], allow_none=False)
|
280
|
-
self.weight =
|
285
|
+
self.weight = param_type(params)
|
281
286
|
|
282
287
|
def update(self, pre_val):
|
283
288
|
params = self.weight.value
|
@@ -332,6 +337,7 @@ class OneToOne(Module):
|
|
332
337
|
w_init: Union[Callable, ArrayLike] = init.Normal(),
|
333
338
|
b_init: Optional[Union[Callable, ArrayLike]] = None,
|
334
339
|
name: Optional[str] = None,
|
340
|
+
param_type: type = ParamState,
|
335
341
|
):
|
336
342
|
super().__init__(name=name)
|
337
343
|
|
@@ -343,13 +349,81 @@ class OneToOne(Module):
|
|
343
349
|
param = dict(weight=init.param(w_init, self.in_size, allow_none=False))
|
344
350
|
if b_init is not None:
|
345
351
|
param['bias'] = init.param(b_init, self.out_size, allow_none=False)
|
346
|
-
self.weight = param
|
352
|
+
self.weight = param_type(param)
|
347
353
|
|
348
354
|
def update(self, pre_val):
|
349
355
|
pre_val, pre_unit = u.get_mantissa(pre_val), u.get_unit(pre_val)
|
350
|
-
w_val, w_unit = u.get_mantissa(self.weight['weight']), u.get_unit(self.weight['weight'])
|
356
|
+
w_val, w_unit = u.get_mantissa(self.weight.value['weight']), u.get_unit(self.weight.value['weight'])
|
351
357
|
post_val = pre_val * w_val
|
352
358
|
post_val = u.maybe_decimal(u.Quantity(post_val, unit=w_unit * pre_unit))
|
353
|
-
if 'bias' in self.weight:
|
354
|
-
post_val = post_val + self.weight['bias']
|
359
|
+
if 'bias' in self.weight.value:
|
360
|
+
post_val = post_val + self.weight.value['bias']
|
355
361
|
return post_val
|
362
|
+
|
363
|
+
|
364
|
+
class LoRA(Module):
|
365
|
+
"""A standalone LoRA layer.
|
366
|
+
|
367
|
+
Example usage::
|
368
|
+
|
369
|
+
>>> import brainstate as bst
|
370
|
+
>>> import jax, jax.numpy as jnp
|
371
|
+
>>> layer = bst.nn.LoRA(3, 2, 4)
|
372
|
+
>>> layer.weight.value
|
373
|
+
{'lora_a': Array([[ 0.25141352, -0.09826107],
|
374
|
+
[ 0.2328382 , 0.38869813],
|
375
|
+
[ 0.27069277, 0.7678282 ]], dtype=float32),
|
376
|
+
'lora_b': Array([[-0.8372317 , 0.21012013, -0.52999765, -0.31939325],
|
377
|
+
[ 0.64234126, -0.42980042, 1.2549229 , -0.47134295]], dtype=float32)}
|
378
|
+
>>> # Wrap around existing layer
|
379
|
+
>>> linear = bst.nn.Linear(3, 4)
|
380
|
+
>>> wrapper = bst.nn.LoRA(3, 2, 4, base_module=linear)
|
381
|
+
>>> assert wrapper.base_module == linear
|
382
|
+
>>> y = layer(jnp.ones((16, 3)))
|
383
|
+
>>> y.shape
|
384
|
+
(16, 4)
|
385
|
+
|
386
|
+
Args:
|
387
|
+
in_features: the number of input features.
|
388
|
+
lora_rank: the rank of the LoRA dimension.
|
389
|
+
out_features: the number of output features.
|
390
|
+
base_module: a base module to call and substitute, if possible.
|
391
|
+
kernel_init: initializer function for the weight matrices.
|
392
|
+
param_type: the type of the LoRA params.
|
393
|
+
"""
|
394
|
+
|
395
|
+
def __init__(
|
396
|
+
self,
|
397
|
+
in_features: int,
|
398
|
+
lora_rank: int,
|
399
|
+
out_features: int,
|
400
|
+
*,
|
401
|
+
base_module: Optional[Module] = None,
|
402
|
+
kernel_init: Union[Callable, ArrayLike] = init.LecunNormal(),
|
403
|
+
param_type: type = ParamState,
|
404
|
+
):
|
405
|
+
super().__init__()
|
406
|
+
|
407
|
+
# input and output shape
|
408
|
+
self.in_size = in_features
|
409
|
+
self.out_size = out_features
|
410
|
+
self.in_features = in_features
|
411
|
+
self.out_features = out_features
|
412
|
+
|
413
|
+
# others
|
414
|
+
self.base_module = base_module
|
415
|
+
|
416
|
+
# weights
|
417
|
+
param = dict(
|
418
|
+
lora_a=kernel_init((in_features, lora_rank)),
|
419
|
+
lora_b=kernel_init((lora_rank, out_features))
|
420
|
+
)
|
421
|
+
self.weight = param_type(param)
|
422
|
+
|
423
|
+
def __call__(self, x: ArrayLike):
|
424
|
+
out = x @ self.weight.value['lora_a'] @ self.weight.value['lora_b']
|
425
|
+
if self.base_module is not None:
|
426
|
+
if not callable(self.base_module):
|
427
|
+
raise ValueError('`self.base_module` must be callable.')
|
428
|
+
out += self.base_module(x)
|
429
|
+
return out
|
brainstate/random/_rand_funs.py
CHANGED
@@ -1848,7 +1848,14 @@ def lognormal(mean=None, sigma=None, size: Optional[Size] = None,
|
|
1848
1848
|
return DEFAULT.lognormal(mean, sigma, size, key=key, dtype=dtype)
|
1849
1849
|
|
1850
1850
|
|
1851
|
-
def binomial(
|
1851
|
+
def binomial(
|
1852
|
+
n,
|
1853
|
+
p,
|
1854
|
+
size: Optional[Size] = None,
|
1855
|
+
key: Optional[SeedOrKey] = None,
|
1856
|
+
dtype: DTypeLike = None,
|
1857
|
+
check_valid: bool = True,
|
1858
|
+
):
|
1852
1859
|
r"""
|
1853
1860
|
Draw samples from a binomial distribution.
|
1854
1861
|
|
@@ -1933,7 +1940,7 @@ def binomial(n, p, size: Optional[Size] = None, key: Optional[SeedOrKey] = None,
|
|
1933
1940
|
>>> sum(brainstate.random.binomial(9, 0.1, 20000) == 0)/20000.
|
1934
1941
|
# answer = 0.38885, or 38%.
|
1935
1942
|
"""
|
1936
|
-
return DEFAULT.binomial(n, p, size, key=key, dtype=dtype)
|
1943
|
+
return DEFAULT.binomial(n, p, size, key=key, dtype=dtype, check_valid=check_valid)
|
1937
1944
|
|
1938
1945
|
|
1939
1946
|
def chisquare(df, size: Optional[Size] = None, key: Optional[SeedOrKey] = None, dtype: DTypeLike = None):
|
brainstate/random/_rand_seed.py
CHANGED
@@ -21,7 +21,7 @@ import jax
|
|
21
21
|
import numpy as np
|
22
22
|
|
23
23
|
from brainstate.typing import SeedOrKey
|
24
|
-
from ._rand_state import RandomState, DEFAULT
|
24
|
+
from ._rand_state import RandomState, DEFAULT, use_prng_key
|
25
25
|
|
26
26
|
__all__ = [
|
27
27
|
'seed', 'set_key', 'get_key', 'default_rng', 'split_key', 'split_keys', 'seed_context', 'restore_key',
|
@@ -123,7 +123,17 @@ def set_key(seed_or_key: SeedOrKey):
|
|
123
123
|
seed_or_key: int
|
124
124
|
The random key.
|
125
125
|
"""
|
126
|
-
|
126
|
+
if isinstance(seed_or_key, int):
|
127
|
+
# key = jax.random.key(seed_or_key)
|
128
|
+
key = jax.random.PRNGKey(seed_or_key) if use_prng_key else jrjax.random.key(seed_or_key)
|
129
|
+
elif isinstance(seed_or_key, (jax.numpy.ndarray, np.ndarray)):
|
130
|
+
if jax.numpy.issubdtype(seed_or_key.dtype, jax.dtypes.prng_key):
|
131
|
+
key = seed_or_key
|
132
|
+
elif seed_or_key.size == 2 and seed_or_key.dtype == jax.numpy.uint32:
|
133
|
+
key = seed_or_key
|
134
|
+
else:
|
135
|
+
raise ValueError(f"seed_or_key should be an integer or a tuple of two integers.")
|
136
|
+
DEFAULT.set_key(key)
|
127
137
|
|
128
138
|
|
129
139
|
def get_key():
|