brainstate 0.1.0.post20250104__py2.py3-none-any.whl → 0.1.0.post20250120__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. brainstate/_state.py +77 -44
  2. brainstate/_state_test.py +0 -17
  3. brainstate/augment/_eval_shape.py +9 -10
  4. brainstate/augment/_eval_shape_test.py +1 -1
  5. brainstate/augment/_mapping.py +265 -277
  6. brainstate/augment/_mapping_test.py +147 -175
  7. brainstate/compile/_ad_checkpoint.py +6 -4
  8. brainstate/compile/_error_if_test.py +1 -0
  9. brainstate/compile/_jit.py +37 -28
  10. brainstate/compile/_loop_collect_return.py +8 -5
  11. brainstate/compile/_loop_no_collection.py +2 -0
  12. brainstate/compile/_make_jaxpr.py +7 -3
  13. brainstate/compile/_make_jaxpr_test.py +2 -1
  14. brainstate/compile/_progress_bar.py +68 -40
  15. brainstate/compile/_unvmap.py +6 -2
  16. brainstate/environ.py +28 -18
  17. brainstate/environ_test.py +4 -0
  18. brainstate/event/__init__.py +0 -2
  19. brainstate/event/_csr.py +266 -23
  20. brainstate/event/_csr_test.py +187 -0
  21. brainstate/event/_fixedprob_mv.py +4 -2
  22. brainstate/event/_fixedprob_mv_test.py +2 -1
  23. brainstate/event/_xla_custom_op.py +16 -5
  24. brainstate/graph/__init__.py +8 -12
  25. brainstate/graph/_graph_node.py +1 -23
  26. brainstate/graph/_graph_operation.py +1 -1
  27. brainstate/graph/_graph_operation_test.py +0 -159
  28. brainstate/nn/_dyn_impl/_inputs.py +124 -39
  29. brainstate/nn/_interaction/_conv.py +4 -2
  30. brainstate/nn/_interaction/_linear.py +84 -10
  31. brainstate/random/_rand_funs.py +9 -2
  32. brainstate/random/_rand_seed.py +12 -2
  33. brainstate/random/_rand_state.py +50 -179
  34. brainstate/surrogate.py +5 -1
  35. brainstate/util/__init__.py +0 -4
  36. brainstate/util/_caller.py +1 -1
  37. brainstate/util/_dict.py +4 -1
  38. brainstate/util/_filter.py +1 -1
  39. brainstate/util/_pretty_repr.py +1 -1
  40. brainstate/util/_struct.py +1 -1
  41. {brainstate-0.1.0.post20250104.dist-info → brainstate-0.1.0.post20250120.dist-info}/METADATA +2 -1
  42. {brainstate-0.1.0.post20250104.dist-info → brainstate-0.1.0.post20250120.dist-info}/RECORD +46 -52
  43. brainstate/event/_csr_mv_test.py +0 -118
  44. brainstate/graph/_graph_context.py +0 -443
  45. brainstate/graph/_graph_context_test.py +0 -65
  46. brainstate/graph/_graph_convert.py +0 -246
  47. brainstate/util/_tracers.py +0 -68
  48. brainstate/util/_visualization.py +0 -47
  49. /brainstate/event/{_csr_mv_benchmark.py → _csr_benchmark.py} +0 -0
  50. {brainstate-0.1.0.post20250104.dist-info → brainstate-0.1.0.post20250120.dist-info}/LICENSE +0 -0
  51. {brainstate-0.1.0.post20250104.dist-info → brainstate-0.1.0.post20250120.dist-info}/WHEEL +0 -0
  52. {brainstate-0.1.0.post20250104.dist-info → brainstate-0.1.0.post20250120.dist-info}/top_level.txt +0 -0
brainstate/nn/_interaction/_linear.py CHANGED
@@ -34,6 +34,7 @@ __all__ = [
     'SparseLinear',
     'AllToAll',
     'OneToOne',
+    'LoRA',
 ]
 
 
@@ -51,6 +52,7 @@ class Linear(Module):
         b_init: Optional[Union[Callable, ArrayLike]] = init.ZeroInit(),
         w_mask: Optional[Union[ArrayLike, Callable]] = None,
         name: Optional[str] = None,
+        param_type: type = ParamState,
     ):
         super().__init__(name=name)
 
@@ -67,7 +69,7 @@ class Linear(Module):
         params = dict(weight=init.param(w_init, (self.in_size[-1], self.out_size[-1]), allow_none=False))
         if b_init is not None:
             params['bias'] = init.param(b_init, self.out_size[-1], allow_none=False)
-        self.weight = ParamState(params)
+        self.weight = param_type(params)
 
     def update(self, x):
         params = self.weight.value
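Note on the pattern above: the layers in this file now accept a `param_type` argument so callers can substitute their own `ParamState` subclass as the weight container. A minimal sketch of why that is useful (the `NoGradParam` subclass below is hypothetical, not part of brainstate):

    import brainstate as bst

    # Hypothetical marker subclass; anything constructible as `cls(params)`
    # can be plugged in via `param_type`.
    class NoGradParam(bst.ParamState):
        """Lets state filters exclude these weights from optimization."""

    frozen = bst.nn.Linear(3, 4, param_type=NoGradParam)
    trainable = bst.nn.Linear(3, 4)  # default container: bst.ParamState
    assert isinstance(frozen.weight, NoGradParam)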
@@ -93,7 +95,7 @@ class SignedWLinear(Module):
         w_init: Union[Callable, ArrayLike] = init.KaimingNormal(),
         w_sign: Optional[ArrayLike] = None,
         name: Optional[str] = None,
-
+        param_type: type = ParamState,
     ):
         super().__init__(name=name)
 
@@ -108,7 +110,7 @@ class SignedWLinear(Module):
 
         # weights
         weight = init.param(w_init, self.in_size + self.out_size, allow_none=False)
-        self.weight = ParamState(weight)
+        self.weight = param_type(weight)
 
     def update(self, x):
         w = self.weight.value
@@ -156,6 +158,7 @@ class ScaledWSLinear(Module):
         ws_gain: bool = True,
         eps: float = 1e-4,
         name: str = None,
+        param_type: type = ParamState,
     ):
         super().__init__(name=name)
 
@@ -179,7 +182,7 @@ class ScaledWSLinear(Module):
         if ws_gain:
             s = params['weight'].shape
             params['gain'] = jnp.ones((1,) * (len(s) - 1) + (s[-1],), dtype=params['weight'].dtype)
-        self.weight = ParamState(params)
+        self.weight = param_type(params)
 
     def update(self, x):
         params = self.weight.value
@@ -211,6 +214,7 @@ class SparseLinear(Module):
         b_init: Optional[Union[Callable, ArrayLike]] = None,
         in_size: Size = None,
         name: Optional[str] = None,
+        param_type: type = ParamState,
     ):
         super().__init__(name=name)
 
@@ -230,7 +234,7 @@ class SparseLinear(Module):
         params = dict(weight=spar_mat.data)
         if b_init is not None:
             params['bias'] = init.param(b_init, self.out_size[-1], allow_none=False)
-        self.weight = ParamState(params)
+        self.weight = param_type(params)
 
     def update(self, x):
         data = self.weight.value['weight']
@@ -260,6 +264,7 @@ class AllToAll(Module):
         b_init: Optional[Union[Callable, ArrayLike]] = None,
         include_self: bool = True,
         name: Optional[str] = None,
+        param_type: type = ParamState,
     ):
         super().__init__(name=name)
 
@@ -277,7 +282,7 @@ class AllToAll(Module):
         params = dict(weight=weight)
         if b_init is not None:
             params['bias'] = init.param(b_init, self.out_size[-1], allow_none=False)
-        self.weight = ParamState(params)
+        self.weight = param_type(params)
 
     def update(self, pre_val):
         params = self.weight.value
@@ -332,6 +337,7 @@ class OneToOne(Module):
         w_init: Union[Callable, ArrayLike] = init.Normal(),
         b_init: Optional[Union[Callable, ArrayLike]] = None,
         name: Optional[str] = None,
+        param_type: type = ParamState,
     ):
         super().__init__(name=name)
 
@@ -343,13 +349,81 @@ class OneToOne(Module):
         param = dict(weight=init.param(w_init, self.in_size, allow_none=False))
         if b_init is not None:
             param['bias'] = init.param(b_init, self.out_size, allow_none=False)
-        self.weight = param
+        self.weight = param_type(param)
 
     def update(self, pre_val):
         pre_val, pre_unit = u.get_mantissa(pre_val), u.get_unit(pre_val)
-        w_val, w_unit = u.get_mantissa(self.weight['weight']), u.get_unit(self.weight['weight'])
+        w_val, w_unit = u.get_mantissa(self.weight.value['weight']), u.get_unit(self.weight.value['weight'])
         post_val = pre_val * w_val
         post_val = u.maybe_decimal(u.Quantity(post_val, unit=w_unit * pre_unit))
-        if 'bias' in self.weight:
-            post_val = post_val + self.weight['bias']
+        if 'bias' in self.weight.value:
+            post_val = post_val + self.weight.value['bias']
         return post_val
+
+
+class LoRA(Module):
+    """A standalone LoRA layer.
+
+    Example usage::
+
+        >>> import brainstate as bst
+        >>> import jax, jax.numpy as jnp
+        >>> layer = bst.nn.LoRA(3, 2, 4)
+        >>> layer.weight.value
+        {'lora_a': Array([[ 0.25141352, -0.09826107],
+                          [ 0.2328382 ,  0.38869813],
+                          [ 0.27069277,  0.7678282 ]], dtype=float32),
+         'lora_b': Array([[-0.8372317 ,  0.21012013, -0.52999765, -0.31939325],
+                          [ 0.64234126, -0.42980042,  1.2549229 , -0.47134295]], dtype=float32)}
+        >>> # Wrap around existing layer
+        >>> linear = bst.nn.Linear(3, 4)
+        >>> wrapper = bst.nn.LoRA(3, 2, 4, base_module=linear)
+        >>> assert wrapper.base_module == linear
+        >>> y = layer(jnp.ones((16, 3)))
+        >>> y.shape
+        (16, 4)
+
+    Args:
+        in_features: the number of input features.
+        lora_rank: the rank of the LoRA dimension.
+        out_features: the number of output features.
+        base_module: a base module to call and substitute, if possible.
+        kernel_init: initializer function for the weight matrices.
+        param_type: the type of the LoRA params.
+    """
+
+    def __init__(
+        self,
+        in_features: int,
+        lora_rank: int,
+        out_features: int,
+        *,
+        base_module: Optional[Module] = None,
+        kernel_init: Union[Callable, ArrayLike] = init.LecunNormal(),
+        param_type: type = ParamState,
+    ):
+        super().__init__()
+
+        # input and output shape
+        self.in_size = in_features
+        self.out_size = out_features
+        self.in_features = in_features
+        self.out_features = out_features
+
+        # others
+        self.base_module = base_module
+
+        # weights
+        param = dict(
+            lora_a=kernel_init((in_features, lora_rank)),
+            lora_b=kernel_init((lora_rank, out_features))
+        )
+        self.weight = param_type(param)
+
+    def __call__(self, x: ArrayLike):
+        out = x @ self.weight.value['lora_a'] @ self.weight.value['lora_b']
+        if self.base_module is not None:
+            if not callable(self.base_module):
+                raise ValueError('`self.base_module` must be callable.')
+            out += self.base_module(x)
+        return out
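The forward pass above is the standard low-rank update `y = x @ lora_a @ lora_b (+ base_module(x))`, so a rank-r adapter stores `r * (in_features + out_features)` parameters instead of `in_features * out_features`. A rough sketch of merging a trained adapter back into a wrapped `Linear` layer (illustrative only; it assumes the wrapped layer keeps its kernel under the 'weight' key, as shown earlier in this file):

    import brainstate as bst

    base = bst.nn.Linear(512, 512)
    lora = bst.nn.LoRA(512, 8, 512, base_module=base)  # 8*(512+512) = 8192 params vs 512*512 = 262144

    # Fold lora_a @ lora_b into the base kernel so inference no longer
    # needs the extra matmuls.
    delta = lora.weight.value['lora_a'] @ lora.weight.value['lora_b']
    merged = dict(base.weight.value)
    merged['weight'] = merged['weight'] + delta
    base.weight.value = merged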
brainstate/random/_rand_funs.py CHANGED
@@ -1848,7 +1848,14 @@ def lognormal(mean=None, sigma=None, size: Optional[Size] = None,
     return DEFAULT.lognormal(mean, sigma, size, key=key, dtype=dtype)
 
 
-def binomial(n, p, size: Optional[Size] = None, key: Optional[SeedOrKey] = None, dtype: DTypeLike = None):
+def binomial(
+    n,
+    p,
+    size: Optional[Size] = None,
+    key: Optional[SeedOrKey] = None,
+    dtype: DTypeLike = None,
+    check_valid: bool = True,
+):
     r"""
     Draw samples from a binomial distribution.
 
@@ -1933,7 +1940,7 @@ def binomial(n, p, size: Optional[Size] = None, key: Optional[SeedOrKey] = None,
     >>> sum(brainstate.random.binomial(9, 0.1, 20000) == 0)/20000.
     # answer = 0.38885, or 38%.
     """
-    return DEFAULT.binomial(n, p, size, key=key, dtype=dtype)
+    return DEFAULT.binomial(n, p, size, key=key, dtype=dtype, check_valid=check_valid)
 
 
 def chisquare(df, size: Optional[Size] = None, key: Optional[SeedOrKey] = None, dtype: DTypeLike = None):
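The new `check_valid` flag lets callers skip the runtime range check on `p` when it is known to be valid, avoiding the embedded error-check branch inside jitted code. A small usage sketch:

    import jax
    import brainstate

    @jax.jit
    def sample(key):
        p = jax.nn.sigmoid(0.3)  # guaranteed to lie in [0, 1]
        return brainstate.random.binomial(10, p, size=(100,), key=key, check_valid=False)

    draws = sample(jax.random.PRNGKey(0))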
brainstate/random/_rand_seed.py CHANGED
@@ -21,7 +21,7 @@ import jax
 import numpy as np
 
 from brainstate.typing import SeedOrKey
-from ._rand_state import RandomState, DEFAULT
+from ._rand_state import RandomState, DEFAULT, use_prng_key
 
 __all__ = [
     'seed', 'set_key', 'get_key', 'default_rng', 'split_key', 'split_keys', 'seed_context', 'restore_key',
@@ -123,7 +123,17 @@ def set_key(seed_or_key: SeedOrKey):
     seed_or_key: int
         The random key.
     """
-    DEFAULT.set_key(jax.random.PRNGKey(seed_or_key) if jax.numpy.shape(seed_or_key) == () else seed_or_key)
+    if isinstance(seed_or_key, int):
+        # key = jax.random.key(seed_or_key)
+        key = jax.random.PRNGKey(seed_or_key) if use_prng_key else jax.random.key(seed_or_key)
+    elif isinstance(seed_or_key, (jax.numpy.ndarray, np.ndarray)):
+        if jax.numpy.issubdtype(seed_or_key.dtype, jax.dtypes.prng_key):
+            key = seed_or_key
+        elif seed_or_key.size == 2 and seed_or_key.dtype == jax.numpy.uint32:
+            key = seed_or_key
+        else:
+            raise ValueError(f"seed_or_key should be an integer or a tuple of two integers.")
+    DEFAULT.set_key(key)
 
 
 def get_key():
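`set_key` now accepts three forms: a Python int seed, a legacy raw `uint32[2]` key, or a new-style typed key (dtype under `jax.dtypes.prng_key`). A quick sketch of all three (with the default `use_prng_key = True`):

    import jax
    import brainstate

    brainstate.random.set_key(42)                      # int seed
    brainstate.random.set_key(jax.random.PRNGKey(42))  # legacy uint32[2] key
    brainstate.random.set_key(jax.random.key(42))      # new-style typed key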
brainstate/random/_rand_state.py CHANGED
@@ -16,7 +16,6 @@
 # -*- coding: utf-8 -*-
 from __future__ import annotations
 
-from collections import namedtuple
 from functools import partial
 from operator import index
 from typing import Optional
@@ -37,6 +36,8 @@ from ._random_for_unit import uniform_for_unit, permutation_for_unit
 
 __all__ = ['RandomState', 'DEFAULT', ]
 
+use_prng_key = True
+
 
 class RandomState(State):
     """RandomState that track the random generator state. """
@@ -56,12 +57,15 @@ class RandomState(State):
         if seed_or_key is None:
             seed_or_key = np.random.randint(0, 100000, 2, dtype=np.uint32)
         if isinstance(seed_or_key, int):
-            key = jr.PRNGKey(seed_or_key)
+            key = jr.PRNGKey(seed_or_key) if use_prng_key else jr.key(seed_or_key)
         else:
-            if len(seed_or_key) != 2 and seed_or_key.dtype != np.uint32:
-                raise ValueError('key must be an array with dtype uint32. '
-                                 f'But we got {seed_or_key}')
-            key = seed_or_key
+            if jnp.issubdtype(seed_or_key.dtype, jax.dtypes.prng_key):
+                key = seed_or_key
+            else:
+                if len(seed_or_key) != 2 and seed_or_key.dtype != np.uint32:
+                    raise ValueError('key must be an array with dtype uint32. '
+                                     f'But we got {seed_or_key}')
+                key = seed_or_key
         super().__init__(key)
 
         self._backup = None
@@ -70,6 +74,9 @@ class RandomState(State):
         return f'{self.__class__.__name__}({self.value})'
 
     def check_if_deleted(self):
+        if not use_prng_key and isinstance(self._value, np.ndarray):
+            self._value = jr.key(np.random.randint(0, 10000))
+
         if (
             isinstance(self._value, jax.Array) and
             not isinstance(self._value, jax.core.Tracer) and
@@ -111,12 +118,19 @@ class RandomState(State):
         if seed_or_key is None:
             seed_or_key = np.random.randint(0, 100000, 2, dtype=np.uint32)
         if np.size(seed_or_key) == 1:
-            key = jr.PRNGKey(seed_or_key)
+            if isinstance(seed_or_key, int):
+                key = jr.PRNGKey(seed_or_key) if use_prng_key else jr.key(seed_or_key)
+            elif jnp.issubdtype(seed_or_key.dtype, jax.dtypes.prng_key):
+                key = seed_or_key
+            elif isinstance(seed_or_key, (jnp.ndarray, np.ndarray)) and jnp.issubdtype(seed_or_key.dtype, jnp.integer):
+                key = jr.PRNGKey(seed_or_key) if use_prng_key else jr.key(seed_or_key)
+            else:
+                raise ValueError(f'Invalid seed_or_key: {seed_or_key}')
         else:
-            if len(seed_or_key) != 2 and seed_or_key.dtype != np.uint32:
-                raise ValueError('key must be an array with dtype uint32. '
-                                 f'But we got {seed_or_key}')
-            key = seed_or_key
+            if len(seed_or_key) == 2 and seed_or_key.dtype == np.uint32:
+                key = seed_or_key
+            else:
+                raise ValueError(f'Invalid seed_or_key: {seed_or_key}')
         self.value = key
 
     def split_key(self, n: Optional[int] = None, backup: bool = False) -> SeedOrKey:
@@ -560,15 +574,15 @@ class RandomState(State):
         )
         return out if unit.is_unitless else u.Quantity(out, unit=unit)
 
-    def _check_p(self, p):
-        raise ValueError(f'Parameter p should be within [0, 1], but we got {p}')
+    def _check_p(self, *args, **kwargs):
+        raise ValueError('Parameter p should be within [0, 1], but we got {p}')
 
     def bernoulli(self,
                   p,
                   size: Optional[Size] = None,
                   key: Optional[SeedOrKey] = None):
         p = _check_py_seq(p)
-        jit_error_if(jnp.any(jnp.logical_and(p < 0, p > 1)), self._check_p, p)
+        jit_error_if(jnp.any(jnp.logical_or(p < 0, p > 1)), self._check_p, p=p)
         if size is None:
             size = jnp.shape(p)
         key = self.split_key() if key is None else _formalize_key(key)
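This hunk fixes a real bug: `jnp.logical_and(p < 0, p > 1)` can never be true, since no value is simultaneously below 0 and above 1, so out-of-range probabilities were silently accepted. `logical_or` is the correct test, and passing `p=p` as a keyword lets the error message format it. A tiny illustration:

    import jax.numpy as jnp

    p = jnp.asarray(1.5)           # invalid probability
    jnp.logical_and(p < 0, p > 1)  # Array(False) -- old check never fired
    jnp.logical_or(p < 0, p > 1)   # Array(True)  -- correctly flagged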
@@ -603,19 +617,27 @@ class RandomState(State):
         samples = jnp.exp(samples)
         return samples if unit.is_unitless else u.Quantity(samples, unit=unit)
 
-    def binomial(self,
-                 n,
-                 p,
-                 size: Optional[Size] = None,
-                 key: Optional[SeedOrKey] = None,
-                 dtype: DTypeLike = None):
+    def binomial(
+        self,
+        n,
+        p,
+        size: Optional[Size] = None,
+        key: Optional[SeedOrKey] = None,
+        dtype: DTypeLike = None,
+        check_valid: bool = True,
+    ):
         n = _check_py_seq(n)
         p = _check_py_seq(p)
-        jit_error_if(jnp.any(jnp.logical_and(p < 0, p > 1)), self._check_p, p)
+        if check_valid:
+            jit_error_if(
+                jnp.any(jnp.logical_or(p < 0, p > 1)),
+                'Parameter p should be within [0, 1], but we got {p}',
+                p=p
+            )
         if size is None:
             size = jnp.broadcast_shapes(jnp.shape(n), jnp.shape(p))
         key = self.split_key() if key is None else _formalize_key(key)
-        r = _binomial(key, p, n, shape=_size2shape(size))
+        r = jr.binomial(key, n, p, shape=_size2shape(size))
         dtype = dtype or environ.ditype()
         return jnp.asarray(r, dtype=dtype)
@@ -1142,8 +1164,13 @@ DEFAULT = RandomState(np.random.randint(0, 10000, size=2, dtype=np.uint32))
 
 def _formalize_key(key):
     if isinstance(key, int):
-        return jr.PRNGKey(key)
+        return jr.PRNGKey(key) if use_prng_key else jr.key(key)
     elif isinstance(key, (jax.Array, np.ndarray)):
+        if jnp.issubdtype(key.dtype, jax.dtypes.prng_key):
+            return key
+        if key.size == 1 and jnp.issubdtype(key.dtype, jnp.integer):
+            return jr.PRNGKey(key) if use_prng_key else jr.key(key)
+
         if key.dtype != jnp.uint32:
             raise TypeError('key must be a int or an array with two uint32.')
         if key.size != 2:
@@ -1216,162 +1243,6 @@ def _const(example, val):
     return np.array(val, dtype)
 
 
-_tr_params = namedtuple(
-    "tr_params", ["c", "b", "a", "alpha", "u_r", "v_r", "m", "log_p", "log1_p", "log_h"]
-)
-
-
-def _get_tr_params(n, p):
-    # See Table 1. Additionally, we pre-compute log(p), log1(-p) and the
-    # constant terms, that depend only on (n, p, m) in log(f(k)) (bottom of page 5).
-    mu = n * p
-    spq = jnp.sqrt(mu * (1 - p))
-    c = mu + 0.5
-    b = 1.15 + 2.53 * spq
-    a = -0.0873 + 0.0248 * b + 0.01 * p
-    alpha = (2.83 + 5.1 / b) * spq
-    u_r = 0.43
-    v_r = 0.92 - 4.2 / b
-    m = jnp.floor((n + 1) * p).astype(n.dtype)
-    log_p = jnp.log(p)
-    log1_p = jnp.log1p(-p)
-    log_h = ((m + 0.5) * (jnp.log((m + 1.0) / (n - m + 1.0)) + log1_p - log_p) +
-             _stirling_approx_tail(m) + _stirling_approx_tail(n - m))
-    return _tr_params(c, b, a, alpha, u_r, v_r, m, log_p, log1_p, log_h)
-
-
-def _stirling_approx_tail(k):
-    precomputed = jnp.array([0.08106146679532726,
-                             0.04134069595540929,
-                             0.02767792568499834,
-                             0.02079067210376509,
-                             0.01664469118982119,
-                             0.01387612882307075,
-                             0.01189670994589177,
-                             0.01041126526197209,
-                             0.009255462182712733,
-                             0.008330563433362871],
-                            dtype=environ.dftype())
-    kp1 = k + 1
-    kp1sq = (k + 1) ** 2
-    return jnp.where(k < 10,
-                     precomputed[k],
-                     (1.0 / 12 - (1.0 / 360 - (1.0 / 1260) / kp1sq) / kp1sq) / kp1)
-
-
-def _binomial_btrs(key, p, n):
-    """
-    Based on the transformed rejection sampling algorithm (BTRS) from the
-    following reference:
-
-    Hormann, "The Generation of Binonmial Random Variates"
-    (https://core.ac.uk/download/pdf/11007254.pdf)
-    """
-
-    def _btrs_body_fn(val):
-        _, key, _, _ = val
-        key, key_u, key_v = jr.split(key, 3)
-        u = jr.uniform(key_u)
-        v = jr.uniform(key_v)
-        u = u - 0.5
-        k = jnp.floor(
-            (2 * tr_params.a / (0.5 - jnp.abs(u)) + tr_params.b) * u + tr_params.c
-        ).astype(n.dtype)
-        return k, key, u, v
-
-    def _btrs_cond_fn(val):
-        def accept_fn(k, u, v):
-            # See acceptance condition in Step 3. (Page 3) of TRS algorithm
-            # v <= f(k) * g_grad(u) / alpha
-
-            m = tr_params.m
-            log_p = tr_params.log_p
-            log1_p = tr_params.log1_p
-            # See: formula for log(f(k)) at bottom of Page 5.
-            log_f = (
-                (n + 1.0) * jnp.log((n - m + 1.0) / (n - k + 1.0))
-                + (k + 0.5) * (jnp.log((n - k + 1.0) / (k + 1.0)) + log_p - log1_p)
-                + (_stirling_approx_tail(k) - _stirling_approx_tail(n - k))
-                + tr_params.log_h
-            )
-            g = (tr_params.a / (0.5 - jnp.abs(u)) ** 2) + tr_params.b
-            return jnp.log((v * tr_params.alpha) / g) <= log_f
-
-        k, key, u, v = val
-        early_accept = (jnp.abs(u) <= tr_params.u_r) & (v <= tr_params.v_r)
-        early_reject = (k < 0) | (k > n)
-        return lax.cond(
-            early_accept | early_reject,
-            (),
-            lambda _: ~early_accept,
-            (k, u, v),
-            lambda x: ~accept_fn(*x),
-        )
-
-    tr_params = _get_tr_params(n, p)
-    ret = lax.while_loop(
-        _btrs_cond_fn, _btrs_body_fn, (-1, key, 1.0, 1.0)
-    )  # use k=-1 initially so that cond_fn returns True
-    return ret[0]
-
-
-def _binomial_inversion(key, p, n):
-    def _binom_inv_body_fn(val):
-        i, key, geom_acc = val
-        key, key_u = jr.split(key)
-        u = jr.uniform(key_u)
-        geom = jnp.floor(jnp.log1p(-u) / log1_p) + 1
-        geom_acc = geom_acc + geom
-        return i + 1, key, geom_acc
-
-    def _binom_inv_cond_fn(val):
-        i, _, geom_acc = val
-        return geom_acc <= n
-
-    log1_p = jnp.log1p(-p)
-    ret = lax.while_loop(_binom_inv_cond_fn, _binom_inv_body_fn, (-1, key, 0.0))
-    return ret[0]
-
-
-def _binomial_dispatch(key, p, n):
-    def dispatch(key, p, n):
-        is_le_mid = p <= 0.5
-        pq = jnp.where(is_le_mid, p, 1 - p)
-        mu = n * pq
-        k = lax.cond(
-            mu < 10,
-            (key, pq, n),
-            lambda x: _binomial_inversion(*x),
-            (key, pq, n),
-            lambda x: _binomial_btrs(*x),
-        )
-        return jnp.where(is_le_mid, k, n - k)
-
-    # Return 0 for nan `p` or negative `n`, since nan values are not allowed for integer types
-    cond0 = jnp.isfinite(p) & (n > 0) & (p > 0)
-    return lax.cond(
-        cond0 & (p < 1),
-        (key, p, n),
-        lambda x: dispatch(*x),
-        (),
-        lambda _: jnp.where(cond0, n, 0),
-    )
-
-
-@partial(jit, static_argnums=(3,))
-def _binomial(key, p, n, shape):
-    shape = shape or lax.broadcast_shapes(jnp.shape(p), jnp.shape(n))
-    # reshape to map over axis 0
-    p = jnp.reshape(jnp.broadcast_to(p, shape), -1)
-    n = jnp.reshape(jnp.broadcast_to(n, shape), -1)
-    key = jr.split(key, jnp.size(p))
-    if jax.default_backend() == "cpu":
-        ret = lax.map(lambda x: _binomial_dispatch(*x), (key, p, n))
-    else:
-        ret = vmap(lambda *x: _binomial_dispatch(*x))(key, p, n)
-    return jnp.reshape(ret, shape)
-
-
 @partial(jit, static_argnums=(2,))
 def _categorical(key, p, shape):
     # this implementation is fast when event shape is small, and slow otherwise
brainstate/surrogate.py CHANGED
@@ -19,9 +19,13 @@ from __future__ import annotations
 import jax
 import jax.numpy as jnp
 import jax.scipy as sci
-from jax.core import Primitive
 from jax.interpreters import batching, ad, mlir
 
+if jax.__version_info__ < (0, 4, 38):
+    from jax.core import Primitive
+else:
+    from jax.extend.core import Primitive
+
 __all__ = [
     'Surrogate',
     'Sigmoid',
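`Primitive` moved from `jax.core` to `jax.extend.core` in JAX 0.4.38, hence the version gate. An equivalent defensive pattern when a version check is undesirable (sketch):

    try:
        from jax.extend.core import Primitive  # JAX >= 0.4.38
    except ImportError:
        from jax.core import Primitive         # older JAX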
brainstate/util/__init__.py CHANGED
@@ -27,8 +27,6 @@ from ._scaling import *
 from ._scaling import __all__ as _mem_scale_all
 from ._struct import *
 from ._struct import __all__ as _struct_all
-from ._visualization import *
-from ._visualization import __all__ as _visualization_all
 
 __all__ = (
     _others_all
@@ -38,7 +36,6 @@ __all__ = (
     + _struct_all
     + _error_all
     + _mapping_all
-    + _visualization_all
 )
 del (
     _others_all,
@@ -48,5 +45,4 @@ del (
     _struct_all,
     _error_all,
     _mapping_all,
-    _visualization_all,
 )
brainstate/util/_caller.py CHANGED
@@ -1,7 +1,7 @@
 # The file is adapted from the Flax library (https://github.com/google/flax).
 # The credit should go to the Flax authors.
 #
-# Copyright 2024 The Flax Authors & 2024 BDP Ecosystem.
+# Copyright 2024 The Flax Authors.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
brainstate/util/_dict.py CHANGED
@@ -1,4 +1,7 @@
-# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+# The file is adapted from the Flax library (https://github.com/google/flax).
+# The credit should go to the Flax authors.
+#
+# Copyright 2024 The Flax Authors.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
brainstate/util/_filter.py CHANGED
@@ -1,7 +1,7 @@
 # The file is adapted from the Flax library (https://github.com/google/flax).
 # The credit should go to the Flax authors.
 #
-# Copyright 2024 The Flax Authors & 2024 BDP Ecosystem.
+# Copyright 2024 The Flax Authors.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
brainstate/util/_pretty_repr.py CHANGED
@@ -1,7 +1,7 @@
 # The file is adapted from the Flax library (https://github.com/google/flax).
 # The credit should go to the Flax authors.
 #
-# Copyright 2024 The Flax Authors & 2024 BDP Ecosystem.
+# Copyright 2024 The Flax Authors.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
brainstate/util/_struct.py CHANGED
@@ -1,7 +1,7 @@
 # The file is adapted from the Flax library (https://github.com/google/flax).
 # The credit should go to the Flax authors.
 #
-# Copyright 2024 The Flax Authors & 2024 BDP Ecosystem.
+# Copyright 2024 The Flax Authors.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
{brainstate-0.1.0.post20250104.dist-info → brainstate-0.1.0.post20250120.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: brainstate
-Version: 0.1.0.post20250104
+Version: 0.1.0.post20250120
 Summary: A ``State``-based Transformation System for Program Compilation and Augmentation.
 Home-page: https://github.com/chaobrain/brainstate
 Author: BrainState Developers
@@ -19,6 +19,7 @@ Classifier: Programming Language :: Python :: 3.9
 Classifier: Programming Language :: Python :: 3.10
 Classifier: Programming Language :: Python :: 3.11
 Classifier: Programming Language :: Python :: 3.12
+Classifier: Programming Language :: Python :: 3.13
 Classifier: License :: OSI Approved :: Apache Software License
 Classifier: Programming Language :: Python
 Classifier: Topic :: Scientific/Engineering :: Bio-Informatics