brainstate 0.1.0.post20250104__py2.py3-none-any.whl → 0.1.0.post20250120__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/_state.py +77 -44
- brainstate/_state_test.py +0 -17
- brainstate/augment/_eval_shape.py +9 -10
- brainstate/augment/_eval_shape_test.py +1 -1
- brainstate/augment/_mapping.py +265 -277
- brainstate/augment/_mapping_test.py +147 -175
- brainstate/compile/_ad_checkpoint.py +6 -4
- brainstate/compile/_error_if_test.py +1 -0
- brainstate/compile/_jit.py +37 -28
- brainstate/compile/_loop_collect_return.py +8 -5
- brainstate/compile/_loop_no_collection.py +2 -0
- brainstate/compile/_make_jaxpr.py +7 -3
- brainstate/compile/_make_jaxpr_test.py +2 -1
- brainstate/compile/_progress_bar.py +68 -40
- brainstate/compile/_unvmap.py +6 -2
- brainstate/environ.py +28 -18
- brainstate/environ_test.py +4 -0
- brainstate/event/__init__.py +0 -2
- brainstate/event/_csr.py +266 -23
- brainstate/event/_csr_test.py +187 -0
- brainstate/event/_fixedprob_mv.py +4 -2
- brainstate/event/_fixedprob_mv_test.py +2 -1
- brainstate/event/_xla_custom_op.py +16 -5
- brainstate/graph/__init__.py +8 -12
- brainstate/graph/_graph_node.py +1 -23
- brainstate/graph/_graph_operation.py +1 -1
- brainstate/graph/_graph_operation_test.py +0 -159
- brainstate/nn/_dyn_impl/_inputs.py +124 -39
- brainstate/nn/_interaction/_conv.py +4 -2
- brainstate/nn/_interaction/_linear.py +84 -10
- brainstate/random/_rand_funs.py +9 -2
- brainstate/random/_rand_seed.py +12 -2
- brainstate/random/_rand_state.py +50 -179
- brainstate/surrogate.py +5 -1
- brainstate/util/__init__.py +0 -4
- brainstate/util/_caller.py +1 -1
- brainstate/util/_dict.py +4 -1
- brainstate/util/_filter.py +1 -1
- brainstate/util/_pretty_repr.py +1 -1
- brainstate/util/_struct.py +1 -1
- {brainstate-0.1.0.post20250104.dist-info → brainstate-0.1.0.post20250120.dist-info}/METADATA +2 -1
- {brainstate-0.1.0.post20250104.dist-info → brainstate-0.1.0.post20250120.dist-info}/RECORD +46 -52
- brainstate/event/_csr_mv_test.py +0 -118
- brainstate/graph/_graph_context.py +0 -443
- brainstate/graph/_graph_context_test.py +0 -65
- brainstate/graph/_graph_convert.py +0 -246
- brainstate/util/_tracers.py +0 -68
- brainstate/util/_visualization.py +0 -47
- /brainstate/event/{_csr_mv_benchmark.py → _csr_benchmark.py} +0 -0
- {brainstate-0.1.0.post20250104.dist-info → brainstate-0.1.0.post20250120.dist-info}/LICENSE +0 -0
- {brainstate-0.1.0.post20250104.dist-info → brainstate-0.1.0.post20250120.dist-info}/WHEEL +0 -0
- {brainstate-0.1.0.post20250104.dist-info → brainstate-0.1.0.post20250120.dist-info}/top_level.txt +0 -0
@@ -34,6 +34,7 @@ __all__ = [
|
|
34
34
|
'SparseLinear',
|
35
35
|
'AllToAll',
|
36
36
|
'OneToOne',
|
37
|
+
'LoRA',
|
37
38
|
]
|
38
39
|
|
39
40
|
|
@@ -51,6 +52,7 @@ class Linear(Module):
|
|
51
52
|
b_init: Optional[Union[Callable, ArrayLike]] = init.ZeroInit(),
|
52
53
|
w_mask: Optional[Union[ArrayLike, Callable]] = None,
|
53
54
|
name: Optional[str] = None,
|
55
|
+
param_type: type = ParamState,
|
54
56
|
):
|
55
57
|
super().__init__(name=name)
|
56
58
|
|
@@ -67,7 +69,7 @@ class Linear(Module):
|
|
67
69
|
params = dict(weight=init.param(w_init, (self.in_size[-1], self.out_size[-1]), allow_none=False))
|
68
70
|
if b_init is not None:
|
69
71
|
params['bias'] = init.param(b_init, self.out_size[-1], allow_none=False)
|
70
|
-
self.weight =
|
72
|
+
self.weight = param_type(params)
|
71
73
|
|
72
74
|
def update(self, x):
|
73
75
|
params = self.weight.value
|
@@ -93,7 +95,7 @@ class SignedWLinear(Module):
|
|
93
95
|
w_init: Union[Callable, ArrayLike] = init.KaimingNormal(),
|
94
96
|
w_sign: Optional[ArrayLike] = None,
|
95
97
|
name: Optional[str] = None,
|
96
|
-
|
98
|
+
param_type: type = ParamState,
|
97
99
|
):
|
98
100
|
super().__init__(name=name)
|
99
101
|
|
@@ -108,7 +110,7 @@ class SignedWLinear(Module):
|
|
108
110
|
|
109
111
|
# weights
|
110
112
|
weight = init.param(w_init, self.in_size + self.out_size, allow_none=False)
|
111
|
-
self.weight =
|
113
|
+
self.weight = param_type(weight)
|
112
114
|
|
113
115
|
def update(self, x):
|
114
116
|
w = self.weight.value
|
@@ -156,6 +158,7 @@ class ScaledWSLinear(Module):
|
|
156
158
|
ws_gain: bool = True,
|
157
159
|
eps: float = 1e-4,
|
158
160
|
name: str = None,
|
161
|
+
param_type: type = ParamState,
|
159
162
|
):
|
160
163
|
super().__init__(name=name)
|
161
164
|
|
@@ -179,7 +182,7 @@ class ScaledWSLinear(Module):
|
|
179
182
|
if ws_gain:
|
180
183
|
s = params['weight'].shape
|
181
184
|
params['gain'] = jnp.ones((1,) * (len(s) - 1) + (s[-1],), dtype=params['weight'].dtype)
|
182
|
-
self.weight =
|
185
|
+
self.weight = param_type(params)
|
183
186
|
|
184
187
|
def update(self, x):
|
185
188
|
params = self.weight.value
|
@@ -211,6 +214,7 @@ class SparseLinear(Module):
|
|
211
214
|
b_init: Optional[Union[Callable, ArrayLike]] = None,
|
212
215
|
in_size: Size = None,
|
213
216
|
name: Optional[str] = None,
|
217
|
+
param_type: type = ParamState,
|
214
218
|
):
|
215
219
|
super().__init__(name=name)
|
216
220
|
|
@@ -230,7 +234,7 @@ class SparseLinear(Module):
|
|
230
234
|
params = dict(weight=spar_mat.data)
|
231
235
|
if b_init is not None:
|
232
236
|
params['bias'] = init.param(b_init, self.out_size[-1], allow_none=False)
|
233
|
-
self.weight =
|
237
|
+
self.weight = param_type(params)
|
234
238
|
|
235
239
|
def update(self, x):
|
236
240
|
data = self.weight.value['weight']
|
@@ -260,6 +264,7 @@ class AllToAll(Module):
|
|
260
264
|
b_init: Optional[Union[Callable, ArrayLike]] = None,
|
261
265
|
include_self: bool = True,
|
262
266
|
name: Optional[str] = None,
|
267
|
+
param_type: type = ParamState,
|
263
268
|
):
|
264
269
|
super().__init__(name=name)
|
265
270
|
|
@@ -277,7 +282,7 @@ class AllToAll(Module):
|
|
277
282
|
params = dict(weight=weight)
|
278
283
|
if b_init is not None:
|
279
284
|
params['bias'] = init.param(b_init, self.out_size[-1], allow_none=False)
|
280
|
-
self.weight =
|
285
|
+
self.weight = param_type(params)
|
281
286
|
|
282
287
|
def update(self, pre_val):
|
283
288
|
params = self.weight.value
|
@@ -332,6 +337,7 @@ class OneToOne(Module):
|
|
332
337
|
w_init: Union[Callable, ArrayLike] = init.Normal(),
|
333
338
|
b_init: Optional[Union[Callable, ArrayLike]] = None,
|
334
339
|
name: Optional[str] = None,
|
340
|
+
param_type: type = ParamState,
|
335
341
|
):
|
336
342
|
super().__init__(name=name)
|
337
343
|
|
@@ -343,13 +349,81 @@ class OneToOne(Module):
|
|
343
349
|
param = dict(weight=init.param(w_init, self.in_size, allow_none=False))
|
344
350
|
if b_init is not None:
|
345
351
|
param['bias'] = init.param(b_init, self.out_size, allow_none=False)
|
346
|
-
self.weight = param
|
352
|
+
self.weight = param_type(param)
|
347
353
|
|
348
354
|
def update(self, pre_val):
|
349
355
|
pre_val, pre_unit = u.get_mantissa(pre_val), u.get_unit(pre_val)
|
350
|
-
w_val, w_unit = u.get_mantissa(self.weight['weight']), u.get_unit(self.weight['weight'])
|
356
|
+
w_val, w_unit = u.get_mantissa(self.weight.value['weight']), u.get_unit(self.weight.value['weight'])
|
351
357
|
post_val = pre_val * w_val
|
352
358
|
post_val = u.maybe_decimal(u.Quantity(post_val, unit=w_unit * pre_unit))
|
353
|
-
if 'bias' in self.weight:
|
354
|
-
post_val = post_val + self.weight['bias']
|
359
|
+
if 'bias' in self.weight.value:
|
360
|
+
post_val = post_val + self.weight.value['bias']
|
355
361
|
return post_val
|
362
|
+
|
363
|
+
|
364
|
+
class LoRA(Module):
|
365
|
+
"""A standalone LoRA layer.
|
366
|
+
|
367
|
+
Example usage::
|
368
|
+
|
369
|
+
>>> import brainstate as bst
|
370
|
+
>>> import jax, jax.numpy as jnp
|
371
|
+
>>> layer = bst.nn.LoRA(3, 2, 4)
|
372
|
+
>>> layer.weight.value
|
373
|
+
{'lora_a': Array([[ 0.25141352, -0.09826107],
|
374
|
+
[ 0.2328382 , 0.38869813],
|
375
|
+
[ 0.27069277, 0.7678282 ]], dtype=float32),
|
376
|
+
'lora_b': Array([[-0.8372317 , 0.21012013, -0.52999765, -0.31939325],
|
377
|
+
[ 0.64234126, -0.42980042, 1.2549229 , -0.47134295]], dtype=float32)}
|
378
|
+
>>> # Wrap around existing layer
|
379
|
+
>>> linear = bst.nn.Linear(3, 4)
|
380
|
+
>>> wrapper = bst.nn.LoRA(3, 2, 4, base_module=linear)
|
381
|
+
>>> assert wrapper.base_module == linear
|
382
|
+
>>> y = layer(jnp.ones((16, 3)))
|
383
|
+
>>> y.shape
|
384
|
+
(16, 4)
|
385
|
+
|
386
|
+
Args:
|
387
|
+
in_features: the number of input features.
|
388
|
+
lora_rank: the rank of the LoRA dimension.
|
389
|
+
out_features: the number of output features.
|
390
|
+
base_module: a base module to call and substitute, if possible.
|
391
|
+
kernel_init: initializer function for the weight matrices.
|
392
|
+
param_type: the type of the LoRA params.
|
393
|
+
"""
|
394
|
+
|
395
|
+
def __init__(
|
396
|
+
self,
|
397
|
+
in_features: int,
|
398
|
+
lora_rank: int,
|
399
|
+
out_features: int,
|
400
|
+
*,
|
401
|
+
base_module: Optional[Module] = None,
|
402
|
+
kernel_init: Union[Callable, ArrayLike] = init.LecunNormal(),
|
403
|
+
param_type: type = ParamState,
|
404
|
+
):
|
405
|
+
super().__init__()
|
406
|
+
|
407
|
+
# input and output shape
|
408
|
+
self.in_size = in_features
|
409
|
+
self.out_size = out_features
|
410
|
+
self.in_features = in_features
|
411
|
+
self.out_features = out_features
|
412
|
+
|
413
|
+
# others
|
414
|
+
self.base_module = base_module
|
415
|
+
|
416
|
+
# weights
|
417
|
+
param = dict(
|
418
|
+
lora_a=kernel_init((in_features, lora_rank)),
|
419
|
+
lora_b=kernel_init((lora_rank, out_features))
|
420
|
+
)
|
421
|
+
self.weight = param_type(param)
|
422
|
+
|
423
|
+
def __call__(self, x: ArrayLike):
|
424
|
+
out = x @ self.weight.value['lora_a'] @ self.weight.value['lora_b']
|
425
|
+
if self.base_module is not None:
|
426
|
+
if not callable(self.base_module):
|
427
|
+
raise ValueError('`self.base_module` must be callable.')
|
428
|
+
out += self.base_module(x)
|
429
|
+
return out
|
brainstate/random/_rand_funs.py
CHANGED
@@ -1848,7 +1848,14 @@ def lognormal(mean=None, sigma=None, size: Optional[Size] = None,
|
|
1848
1848
|
return DEFAULT.lognormal(mean, sigma, size, key=key, dtype=dtype)
|
1849
1849
|
|
1850
1850
|
|
1851
|
-
def binomial(
|
1851
|
+
def binomial(
|
1852
|
+
n,
|
1853
|
+
p,
|
1854
|
+
size: Optional[Size] = None,
|
1855
|
+
key: Optional[SeedOrKey] = None,
|
1856
|
+
dtype: DTypeLike = None,
|
1857
|
+
check_valid: bool = True,
|
1858
|
+
):
|
1852
1859
|
r"""
|
1853
1860
|
Draw samples from a binomial distribution.
|
1854
1861
|
|
@@ -1933,7 +1940,7 @@ def binomial(n, p, size: Optional[Size] = None, key: Optional[SeedOrKey] = None,
|
|
1933
1940
|
>>> sum(brainstate.random.binomial(9, 0.1, 20000) == 0)/20000.
|
1934
1941
|
# answer = 0.38885, or 38%.
|
1935
1942
|
"""
|
1936
|
-
return DEFAULT.binomial(n, p, size, key=key, dtype=dtype)
|
1943
|
+
return DEFAULT.binomial(n, p, size, key=key, dtype=dtype, check_valid=check_valid)
|
1937
1944
|
|
1938
1945
|
|
1939
1946
|
def chisquare(df, size: Optional[Size] = None, key: Optional[SeedOrKey] = None, dtype: DTypeLike = None):
|
brainstate/random/_rand_seed.py
CHANGED
@@ -21,7 +21,7 @@ import jax
|
|
21
21
|
import numpy as np
|
22
22
|
|
23
23
|
from brainstate.typing import SeedOrKey
|
24
|
-
from ._rand_state import RandomState, DEFAULT
|
24
|
+
from ._rand_state import RandomState, DEFAULT, use_prng_key
|
25
25
|
|
26
26
|
__all__ = [
|
27
27
|
'seed', 'set_key', 'get_key', 'default_rng', 'split_key', 'split_keys', 'seed_context', 'restore_key',
|
@@ -123,7 +123,17 @@ def set_key(seed_or_key: SeedOrKey):
|
|
123
123
|
seed_or_key: int
|
124
124
|
The random key.
|
125
125
|
"""
|
126
|
-
|
126
|
+
if isinstance(seed_or_key, int):
|
127
|
+
# key = jax.random.key(seed_or_key)
|
128
|
+
key = jax.random.PRNGKey(seed_or_key) if use_prng_key else jrjax.random.key(seed_or_key)
|
129
|
+
elif isinstance(seed_or_key, (jax.numpy.ndarray, np.ndarray)):
|
130
|
+
if jax.numpy.issubdtype(seed_or_key.dtype, jax.dtypes.prng_key):
|
131
|
+
key = seed_or_key
|
132
|
+
elif seed_or_key.size == 2 and seed_or_key.dtype == jax.numpy.uint32:
|
133
|
+
key = seed_or_key
|
134
|
+
else:
|
135
|
+
raise ValueError(f"seed_or_key should be an integer or a tuple of two integers.")
|
136
|
+
DEFAULT.set_key(key)
|
127
137
|
|
128
138
|
|
129
139
|
def get_key():
|
brainstate/random/_rand_state.py
CHANGED
@@ -16,7 +16,6 @@
|
|
16
16
|
# -*- coding: utf-8 -*-
|
17
17
|
from __future__ import annotations
|
18
18
|
|
19
|
-
from collections import namedtuple
|
20
19
|
from functools import partial
|
21
20
|
from operator import index
|
22
21
|
from typing import Optional
|
@@ -37,6 +36,8 @@ from ._random_for_unit import uniform_for_unit, permutation_for_unit
|
|
37
36
|
|
38
37
|
__all__ = ['RandomState', 'DEFAULT', ]
|
39
38
|
|
39
|
+
use_prng_key = True
|
40
|
+
|
40
41
|
|
41
42
|
class RandomState(State):
|
42
43
|
"""RandomState that track the random generator state. """
|
@@ -56,12 +57,15 @@ class RandomState(State):
|
|
56
57
|
if seed_or_key is None:
|
57
58
|
seed_or_key = np.random.randint(0, 100000, 2, dtype=np.uint32)
|
58
59
|
if isinstance(seed_or_key, int):
|
59
|
-
key = jr.PRNGKey(seed_or_key)
|
60
|
+
key = jr.PRNGKey(seed_or_key) if use_prng_key else jr.key(seed_or_key)
|
60
61
|
else:
|
61
|
-
if
|
62
|
-
|
63
|
-
|
64
|
-
|
62
|
+
if jnp.issubdtype(seed_or_key.dtype, jax.dtypes.prng_key):
|
63
|
+
key = seed_or_key
|
64
|
+
else:
|
65
|
+
if len(seed_or_key) != 2 and seed_or_key.dtype != np.uint32:
|
66
|
+
raise ValueError('key must be an array with dtype uint32. '
|
67
|
+
f'But we got {seed_or_key}')
|
68
|
+
key = seed_or_key
|
65
69
|
super().__init__(key)
|
66
70
|
|
67
71
|
self._backup = None
|
@@ -70,6 +74,9 @@ class RandomState(State):
|
|
70
74
|
return f'{self.__class__.__name__}({self.value})'
|
71
75
|
|
72
76
|
def check_if_deleted(self):
|
77
|
+
if not use_prng_key and isinstance(self._value, np.ndarray):
|
78
|
+
self._value = jr.key(np.random.randint(0, 10000))
|
79
|
+
|
73
80
|
if (
|
74
81
|
isinstance(self._value, jax.Array) and
|
75
82
|
not isinstance(self._value, jax.core.Tracer) and
|
@@ -111,12 +118,19 @@ class RandomState(State):
|
|
111
118
|
if seed_or_key is None:
|
112
119
|
seed_or_key = np.random.randint(0, 100000, 2, dtype=np.uint32)
|
113
120
|
if np.size(seed_or_key) == 1:
|
114
|
-
|
121
|
+
if isinstance(seed_or_key, int):
|
122
|
+
key = jr.PRNGKey(seed_or_key) if use_prng_key else jr.key(seed_or_key)
|
123
|
+
elif jnp.issubdtype(seed_or_key.dtype, jax.dtypes.prng_key):
|
124
|
+
key = seed_or_key
|
125
|
+
elif isinstance(seed_or_key, (jnp.ndarray, np.ndarray)) and jnp.issubdtype(seed_or_key.dtype, jnp.integer):
|
126
|
+
key = jr.PRNGKey(seed_or_key) if use_prng_key else jr.key(seed_or_key)
|
127
|
+
else:
|
128
|
+
raise ValueError(f'Invalid seed_or_key: {seed_or_key}')
|
115
129
|
else:
|
116
|
-
if len(seed_or_key)
|
117
|
-
|
118
|
-
|
119
|
-
|
130
|
+
if len(seed_or_key) == 2 and seed_or_key.dtype == np.uint32:
|
131
|
+
key = seed_or_key
|
132
|
+
else:
|
133
|
+
raise ValueError(f'Invalid seed_or_key: {seed_or_key}')
|
120
134
|
self.value = key
|
121
135
|
|
122
136
|
def split_key(self, n: Optional[int] = None, backup: bool = False) -> SeedOrKey:
|
@@ -560,15 +574,15 @@ class RandomState(State):
|
|
560
574
|
)
|
561
575
|
return out if unit.is_unitless else u.Quantity(out, unit=unit)
|
562
576
|
|
563
|
-
def _check_p(self,
|
564
|
-
raise ValueError(
|
577
|
+
def _check_p(self, *args, **kwargs):
|
578
|
+
raise ValueError('Parameter p should be within [0, 1], but we got {p}')
|
565
579
|
|
566
580
|
def bernoulli(self,
|
567
581
|
p,
|
568
582
|
size: Optional[Size] = None,
|
569
583
|
key: Optional[SeedOrKey] = None):
|
570
584
|
p = _check_py_seq(p)
|
571
|
-
jit_error_if(jnp.any(jnp.
|
585
|
+
jit_error_if(jnp.any(jnp.logical_or(p < 0, p > 1)), self._check_p, p=p)
|
572
586
|
if size is None:
|
573
587
|
size = jnp.shape(p)
|
574
588
|
key = self.split_key() if key is None else _formalize_key(key)
|
@@ -603,19 +617,27 @@ class RandomState(State):
|
|
603
617
|
samples = jnp.exp(samples)
|
604
618
|
return samples if unit.is_unitless else u.Quantity(samples, unit=unit)
|
605
619
|
|
606
|
-
def binomial(
|
607
|
-
|
608
|
-
|
609
|
-
|
610
|
-
|
611
|
-
|
620
|
+
def binomial(
|
621
|
+
self,
|
622
|
+
n,
|
623
|
+
p,
|
624
|
+
size: Optional[Size] = None,
|
625
|
+
key: Optional[SeedOrKey] = None,
|
626
|
+
dtype: DTypeLike = None,
|
627
|
+
check_valid: bool = True,
|
628
|
+
):
|
612
629
|
n = _check_py_seq(n)
|
613
630
|
p = _check_py_seq(p)
|
614
|
-
|
631
|
+
if check_valid:
|
632
|
+
jit_error_if(
|
633
|
+
jnp.any(jnp.logical_or(p < 0, p > 1)),
|
634
|
+
'Parameter p should be within [0, 1], but we got {p}',
|
635
|
+
p=p
|
636
|
+
)
|
615
637
|
if size is None:
|
616
638
|
size = jnp.broadcast_shapes(jnp.shape(n), jnp.shape(p))
|
617
639
|
key = self.split_key() if key is None else _formalize_key(key)
|
618
|
-
r =
|
640
|
+
r = jr.binomial(key, n, p, shape=_size2shape(size))
|
619
641
|
dtype = dtype or environ.ditype()
|
620
642
|
return jnp.asarray(r, dtype=dtype)
|
621
643
|
|
@@ -1142,8 +1164,13 @@ DEFAULT = RandomState(np.random.randint(0, 10000, size=2, dtype=np.uint32))
|
|
1142
1164
|
|
1143
1165
|
def _formalize_key(key):
|
1144
1166
|
if isinstance(key, int):
|
1145
|
-
return jr.PRNGKey(key)
|
1167
|
+
return jr.PRNGKey(key) if use_prng_key else jr.key(key)
|
1146
1168
|
elif isinstance(key, (jax.Array, np.ndarray)):
|
1169
|
+
if jnp.issubdtype(key.dtype, jax.dtypes.prng_key):
|
1170
|
+
return key
|
1171
|
+
if key.size == 1 and jnp.issubdtype(key.dtype, jnp.integer):
|
1172
|
+
return jr.PRNGKey(key) if use_prng_key else jr.key(key)
|
1173
|
+
|
1147
1174
|
if key.dtype != jnp.uint32:
|
1148
1175
|
raise TypeError('key must be a int or an array with two uint32.')
|
1149
1176
|
if key.size != 2:
|
@@ -1216,162 +1243,6 @@ def _const(example, val):
|
|
1216
1243
|
return np.array(val, dtype)
|
1217
1244
|
|
1218
1245
|
|
1219
|
-
_tr_params = namedtuple(
|
1220
|
-
"tr_params", ["c", "b", "a", "alpha", "u_r", "v_r", "m", "log_p", "log1_p", "log_h"]
|
1221
|
-
)
|
1222
|
-
|
1223
|
-
|
1224
|
-
def _get_tr_params(n, p):
|
1225
|
-
# See Table 1. Additionally, we pre-compute log(p), log1(-p) and the
|
1226
|
-
# constant terms, that depend only on (n, p, m) in log(f(k)) (bottom of page 5).
|
1227
|
-
mu = n * p
|
1228
|
-
spq = jnp.sqrt(mu * (1 - p))
|
1229
|
-
c = mu + 0.5
|
1230
|
-
b = 1.15 + 2.53 * spq
|
1231
|
-
a = -0.0873 + 0.0248 * b + 0.01 * p
|
1232
|
-
alpha = (2.83 + 5.1 / b) * spq
|
1233
|
-
u_r = 0.43
|
1234
|
-
v_r = 0.92 - 4.2 / b
|
1235
|
-
m = jnp.floor((n + 1) * p).astype(n.dtype)
|
1236
|
-
log_p = jnp.log(p)
|
1237
|
-
log1_p = jnp.log1p(-p)
|
1238
|
-
log_h = ((m + 0.5) * (jnp.log((m + 1.0) / (n - m + 1.0)) + log1_p - log_p) +
|
1239
|
-
_stirling_approx_tail(m) + _stirling_approx_tail(n - m))
|
1240
|
-
return _tr_params(c, b, a, alpha, u_r, v_r, m, log_p, log1_p, log_h)
|
1241
|
-
|
1242
|
-
|
1243
|
-
def _stirling_approx_tail(k):
|
1244
|
-
precomputed = jnp.array([0.08106146679532726,
|
1245
|
-
0.04134069595540929,
|
1246
|
-
0.02767792568499834,
|
1247
|
-
0.02079067210376509,
|
1248
|
-
0.01664469118982119,
|
1249
|
-
0.01387612882307075,
|
1250
|
-
0.01189670994589177,
|
1251
|
-
0.01041126526197209,
|
1252
|
-
0.009255462182712733,
|
1253
|
-
0.008330563433362871],
|
1254
|
-
dtype=environ.dftype())
|
1255
|
-
kp1 = k + 1
|
1256
|
-
kp1sq = (k + 1) ** 2
|
1257
|
-
return jnp.where(k < 10,
|
1258
|
-
precomputed[k],
|
1259
|
-
(1.0 / 12 - (1.0 / 360 - (1.0 / 1260) / kp1sq) / kp1sq) / kp1)
|
1260
|
-
|
1261
|
-
|
1262
|
-
def _binomial_btrs(key, p, n):
|
1263
|
-
"""
|
1264
|
-
Based on the transformed rejection sampling algorithm (BTRS) from the
|
1265
|
-
following reference:
|
1266
|
-
|
1267
|
-
Hormann, "The Generation of Binonmial Random Variates"
|
1268
|
-
(https://core.ac.uk/download/pdf/11007254.pdf)
|
1269
|
-
"""
|
1270
|
-
|
1271
|
-
def _btrs_body_fn(val):
|
1272
|
-
_, key, _, _ = val
|
1273
|
-
key, key_u, key_v = jr.split(key, 3)
|
1274
|
-
u = jr.uniform(key_u)
|
1275
|
-
v = jr.uniform(key_v)
|
1276
|
-
u = u - 0.5
|
1277
|
-
k = jnp.floor(
|
1278
|
-
(2 * tr_params.a / (0.5 - jnp.abs(u)) + tr_params.b) * u + tr_params.c
|
1279
|
-
).astype(n.dtype)
|
1280
|
-
return k, key, u, v
|
1281
|
-
|
1282
|
-
def _btrs_cond_fn(val):
|
1283
|
-
def accept_fn(k, u, v):
|
1284
|
-
# See acceptance condition in Step 3. (Page 3) of TRS algorithm
|
1285
|
-
# v <= f(k) * g_grad(u) / alpha
|
1286
|
-
|
1287
|
-
m = tr_params.m
|
1288
|
-
log_p = tr_params.log_p
|
1289
|
-
log1_p = tr_params.log1_p
|
1290
|
-
# See: formula for log(f(k)) at bottom of Page 5.
|
1291
|
-
log_f = (
|
1292
|
-
(n + 1.0) * jnp.log((n - m + 1.0) / (n - k + 1.0))
|
1293
|
-
+ (k + 0.5) * (jnp.log((n - k + 1.0) / (k + 1.0)) + log_p - log1_p)
|
1294
|
-
+ (_stirling_approx_tail(k) - _stirling_approx_tail(n - k))
|
1295
|
-
+ tr_params.log_h
|
1296
|
-
)
|
1297
|
-
g = (tr_params.a / (0.5 - jnp.abs(u)) ** 2) + tr_params.b
|
1298
|
-
return jnp.log((v * tr_params.alpha) / g) <= log_f
|
1299
|
-
|
1300
|
-
k, key, u, v = val
|
1301
|
-
early_accept = (jnp.abs(u) <= tr_params.u_r) & (v <= tr_params.v_r)
|
1302
|
-
early_reject = (k < 0) | (k > n)
|
1303
|
-
return lax.cond(
|
1304
|
-
early_accept | early_reject,
|
1305
|
-
(),
|
1306
|
-
lambda _: ~early_accept,
|
1307
|
-
(k, u, v),
|
1308
|
-
lambda x: ~accept_fn(*x),
|
1309
|
-
)
|
1310
|
-
|
1311
|
-
tr_params = _get_tr_params(n, p)
|
1312
|
-
ret = lax.while_loop(
|
1313
|
-
_btrs_cond_fn, _btrs_body_fn, (-1, key, 1.0, 1.0)
|
1314
|
-
) # use k=-1 initially so that cond_fn returns True
|
1315
|
-
return ret[0]
|
1316
|
-
|
1317
|
-
|
1318
|
-
def _binomial_inversion(key, p, n):
|
1319
|
-
def _binom_inv_body_fn(val):
|
1320
|
-
i, key, geom_acc = val
|
1321
|
-
key, key_u = jr.split(key)
|
1322
|
-
u = jr.uniform(key_u)
|
1323
|
-
geom = jnp.floor(jnp.log1p(-u) / log1_p) + 1
|
1324
|
-
geom_acc = geom_acc + geom
|
1325
|
-
return i + 1, key, geom_acc
|
1326
|
-
|
1327
|
-
def _binom_inv_cond_fn(val):
|
1328
|
-
i, _, geom_acc = val
|
1329
|
-
return geom_acc <= n
|
1330
|
-
|
1331
|
-
log1_p = jnp.log1p(-p)
|
1332
|
-
ret = lax.while_loop(_binom_inv_cond_fn, _binom_inv_body_fn, (-1, key, 0.0))
|
1333
|
-
return ret[0]
|
1334
|
-
|
1335
|
-
|
1336
|
-
def _binomial_dispatch(key, p, n):
|
1337
|
-
def dispatch(key, p, n):
|
1338
|
-
is_le_mid = p <= 0.5
|
1339
|
-
pq = jnp.where(is_le_mid, p, 1 - p)
|
1340
|
-
mu = n * pq
|
1341
|
-
k = lax.cond(
|
1342
|
-
mu < 10,
|
1343
|
-
(key, pq, n),
|
1344
|
-
lambda x: _binomial_inversion(*x),
|
1345
|
-
(key, pq, n),
|
1346
|
-
lambda x: _binomial_btrs(*x),
|
1347
|
-
)
|
1348
|
-
return jnp.where(is_le_mid, k, n - k)
|
1349
|
-
|
1350
|
-
# Return 0 for nan `p` or negative `n`, since nan values are not allowed for integer types
|
1351
|
-
cond0 = jnp.isfinite(p) & (n > 0) & (p > 0)
|
1352
|
-
return lax.cond(
|
1353
|
-
cond0 & (p < 1),
|
1354
|
-
(key, p, n),
|
1355
|
-
lambda x: dispatch(*x),
|
1356
|
-
(),
|
1357
|
-
lambda _: jnp.where(cond0, n, 0),
|
1358
|
-
)
|
1359
|
-
|
1360
|
-
|
1361
|
-
@partial(jit, static_argnums=(3,))
|
1362
|
-
def _binomial(key, p, n, shape):
|
1363
|
-
shape = shape or lax.broadcast_shapes(jnp.shape(p), jnp.shape(n))
|
1364
|
-
# reshape to map over axis 0
|
1365
|
-
p = jnp.reshape(jnp.broadcast_to(p, shape), -1)
|
1366
|
-
n = jnp.reshape(jnp.broadcast_to(n, shape), -1)
|
1367
|
-
key = jr.split(key, jnp.size(p))
|
1368
|
-
if jax.default_backend() == "cpu":
|
1369
|
-
ret = lax.map(lambda x: _binomial_dispatch(*x), (key, p, n))
|
1370
|
-
else:
|
1371
|
-
ret = vmap(lambda *x: _binomial_dispatch(*x))(key, p, n)
|
1372
|
-
return jnp.reshape(ret, shape)
|
1373
|
-
|
1374
|
-
|
1375
1246
|
@partial(jit, static_argnums=(2,))
|
1376
1247
|
def _categorical(key, p, shape):
|
1377
1248
|
# this implementation is fast when event shape is small, and slow otherwise
|
brainstate/surrogate.py
CHANGED
@@ -19,9 +19,13 @@ from __future__ import annotations
|
|
19
19
|
import jax
|
20
20
|
import jax.numpy as jnp
|
21
21
|
import jax.scipy as sci
|
22
|
-
from jax.core import Primitive
|
23
22
|
from jax.interpreters import batching, ad, mlir
|
24
23
|
|
24
|
+
if jax.__version_info__ < (0, 4, 38):
|
25
|
+
from jax.core import Primitive
|
26
|
+
else:
|
27
|
+
from jax.extend.core import Primitive
|
28
|
+
|
25
29
|
__all__ = [
|
26
30
|
'Surrogate',
|
27
31
|
'Sigmoid',
|
brainstate/util/__init__.py
CHANGED
@@ -27,8 +27,6 @@ from ._scaling import *
|
|
27
27
|
from ._scaling import __all__ as _mem_scale_all
|
28
28
|
from ._struct import *
|
29
29
|
from ._struct import __all__ as _struct_all
|
30
|
-
from ._visualization import *
|
31
|
-
from ._visualization import __all__ as _visualization_all
|
32
30
|
|
33
31
|
__all__ = (
|
34
32
|
_others_all
|
@@ -38,7 +36,6 @@ __all__ = (
|
|
38
36
|
+ _struct_all
|
39
37
|
+ _error_all
|
40
38
|
+ _mapping_all
|
41
|
-
+ _visualization_all
|
42
39
|
)
|
43
40
|
del (
|
44
41
|
_others_all,
|
@@ -48,5 +45,4 @@ del (
|
|
48
45
|
_struct_all,
|
49
46
|
_error_all,
|
50
47
|
_mapping_all,
|
51
|
-
_visualization_all,
|
52
48
|
)
|
brainstate/util/_caller.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
# The file is adapted from the Flax library (https://github.com/google/flax).
|
2
2
|
# The credit should go to the Flax authors.
|
3
3
|
#
|
4
|
-
# Copyright 2024 The Flax Authors
|
4
|
+
# Copyright 2024 The Flax Authors.
|
5
5
|
#
|
6
6
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
7
7
|
# you may not use this file except in compliance with the License.
|
brainstate/util/_dict.py
CHANGED
@@ -1,4 +1,7 @@
|
|
1
|
-
#
|
1
|
+
# The file is adapted from the Flax library (https://github.com/google/flax).
|
2
|
+
# The credit should go to the Flax authors.
|
3
|
+
#
|
4
|
+
# Copyright 2024 The Flax Authors.
|
2
5
|
#
|
3
6
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
7
|
# you may not use this file except in compliance with the License.
|
brainstate/util/_filter.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
# The file is adapted from the Flax library (https://github.com/google/flax).
|
2
2
|
# The credit should go to the Flax authors.
|
3
3
|
#
|
4
|
-
# Copyright 2024 The Flax Authors
|
4
|
+
# Copyright 2024 The Flax Authors.
|
5
5
|
#
|
6
6
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
7
7
|
# you may not use this file except in compliance with the License.
|
brainstate/util/_pretty_repr.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
# The file is adapted from the Flax library (https://github.com/google/flax).
|
2
2
|
# The credit should go to the Flax authors.
|
3
3
|
#
|
4
|
-
# Copyright 2024 The Flax Authors
|
4
|
+
# Copyright 2024 The Flax Authors.
|
5
5
|
#
|
6
6
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
7
7
|
# you may not use this file except in compliance with the License.
|
brainstate/util/_struct.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
# The file is adapted from the Flax library (https://github.com/google/flax).
|
2
2
|
# The credit should go to the Flax authors.
|
3
3
|
#
|
4
|
-
# Copyright 2024 The Flax Authors
|
4
|
+
# Copyright 2024 The Flax Authors.
|
5
5
|
#
|
6
6
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
7
7
|
# you may not use this file except in compliance with the License.
|
{brainstate-0.1.0.post20250104.dist-info → brainstate-0.1.0.post20250120.dist-info}/METADATA
RENAMED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: brainstate
|
3
|
-
Version: 0.1.0.
|
3
|
+
Version: 0.1.0.post20250120
|
4
4
|
Summary: A ``State``-based Transformation System for Program Compilation and Augmentation.
|
5
5
|
Home-page: https://github.com/chaobrain/brainstate
|
6
6
|
Author: BrainState Developers
|
@@ -19,6 +19,7 @@ Classifier: Programming Language :: Python :: 3.9
|
|
19
19
|
Classifier: Programming Language :: Python :: 3.10
|
20
20
|
Classifier: Programming Language :: Python :: 3.11
|
21
21
|
Classifier: Programming Language :: Python :: 3.12
|
22
|
+
Classifier: Programming Language :: Python :: 3.13
|
22
23
|
Classifier: License :: OSI Approved :: Apache Software License
|
23
24
|
Classifier: Programming Language :: Python
|
24
25
|
Classifier: Topic :: Scientific/Engineering :: Bio-Informatics
|