brainstate 0.1.0.post20250218__py2.py3-none-any.whl → 0.1.0.post20250315__py2.py3-none-any.whl

This diff shows the differences between the contents of two publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those package versions as they appear in their respective public registries.
@@ -29,11 +29,12 @@ The wrapped gradient transformations here are made possible by using the followi
29
29
 
30
30
  from __future__ import annotations
31
31
 
32
- import brainunit as u
33
- import jax
34
32
  from functools import wraps, partial
35
33
  from typing import Union, Callable, Dict, Sequence, Optional, Any, Tuple, TypeVar, Iterator
36
34
 
35
+ import brainunit as u
36
+ import jax
37
+
37
38
  from brainstate._state import State
38
39
  from brainstate._utils import set_module_as
39
40
  from brainstate.compile._make_jaxpr import StatefulFunction
@@ -195,6 +196,7 @@ class GradientTransform(PrettyRepr):
195
196
  grad_states = {k: v for k, v in grad_states.items()}
196
197
  self._grad_states, self._grad_tree = jax.tree.flatten(grad_states)
197
198
  self._grad_state_ids = [id(v) for v in self._grad_states]
199
+ self._grad_id_to_state = {id(v): v for v in self._grad_states}
198
200
  if any(not isinstance(v, State) for v in self._grad_states):
199
201
  raise TypeError("All grad_states must be State instances.")
200
202
 
@@ -250,12 +252,20 @@ class GradientTransform(PrettyRepr):
250
252
  """
251
253
  grad_vals = dict()
252
254
  other_vals = dict()
255
+ all_ids = set(self._grad_state_ids)
253
256
  for st in state_trace.states:
254
257
  id_ = id(st)
255
- if id_ in self._grad_state_ids:
258
+ if id_ in all_ids:
256
259
  grad_vals[id_] = st.value
260
+ all_ids.remove(id_)
257
261
  else:
258
262
  other_vals[id_] = st.value
263
+ if len(all_ids):
264
+ err = f"Some states are not found in the state trace when performing gradient transformations.\n "
265
+ for i, id_ in enumerate(all_ids):
266
+ st = self._grad_id_to_state[id_]
267
+ st.raise_error_with_source_info(ValueError(err + str(st)))
268
+
259
269
  return grad_vals, other_vals
260
270
 
261
271
  def _merge_state_vals(self, grad_vals: Dict, other_vals: Dict, state_trace):
@@ -16,8 +16,6 @@
16
16
  from __future__ import annotations
17
17
 
18
18
  import functools
19
- import jax
20
- from jax.interpreters.batching import BatchTracer
21
19
  from typing import (
22
20
  Any,
23
21
  TypeVar,
@@ -32,6 +30,9 @@ from typing import (
32
30
  List
33
31
  )
34
32
 
33
+ import jax
34
+ from jax.interpreters.batching import BatchTracer
35
+
35
36
  from brainstate._state import State, catch_new_states
36
37
  from brainstate.compile import scan, StatefulFunction
37
38
  from brainstate.random import RandomState, DEFAULT
@@ -378,8 +379,15 @@ def _vmap_transform(
378
379
  # call the function
379
380
  return f(*args)
380
381
 
382
+ def _set_axis_env(batch_size):
383
+ axis_env = None if axis_name is None else [(axis_name, batch_size)]
384
+ stateful_fn.axis_env = axis_env
385
+
381
386
  # stateful function
382
- stateful_fn = StatefulFunction(_vmap_fn_for_compilation, name='vmap')
387
+ stateful_fn = StatefulFunction(
388
+ _vmap_fn_for_compilation,
389
+ name='vmap',
390
+ )
383
391
 
384
392
  @functools.wraps(f)
385
393
  def new_fn_for_vmap(
@@ -506,6 +514,10 @@ def _vmap_transform(
506
514
  st_in_axes = 0
507
515
 
508
516
  # compile stateful function
517
+ batch_size = None
518
+ if axis_name is not None:
519
+ batch_size = _get_batch_size(args, in_axes, axis_to_in_states, axis_size)
520
+ _set_axis_env(batch_size)
509
521
  cache_key = _compile_stateful_function(
510
522
  stateful_fn,
511
523
  (st_in_axes, in_axes),
@@ -518,7 +530,8 @@ def _vmap_transform(
518
530
  rng_sets = set(rngs)
519
531
  if len(rngs):
520
532
  # batch size
521
- batch_size = _get_batch_size(args, in_axes, axis_to_in_states, axis_size)
533
+ if batch_size is None:
534
+ batch_size = _get_batch_size(args, in_axes, axis_to_in_states, axis_size)
522
535
  rng_keys = tuple(rng.split_key(batch_size) for rng in rngs)
523
536
  rng_backup = tuple(rng.split_key() for rng in rngs)
524
537
  else:
@@ -905,11 +918,11 @@ def map(
905
918
  g = lambda _, x: ((), vmap(f)(*x))
906
919
  _, scan_ys = scan(g, (), scan_xs)
907
920
  if remainder_xs is None:
908
- ys = jax.tree.map(lambda x: flatten_(x), scan_ys)
921
+ ys = jax.tree.map(lambda x: _flatten(x), scan_ys)
909
922
  else:
910
923
  remainder_ys = vmap(f)(*remainder_xs)
911
924
  ys = jax.tree.map(
912
- lambda x, y: jax.lax.concatenate([flatten_(x), y], dimension=0),
925
+ lambda x, y: jax.lax.concatenate([_flatten(x), y], dimension=0),
913
926
  scan_ys,
914
927
  remainder_ys,
915
928
  )
@@ -919,7 +932,7 @@ def map(
919
932
  return ys
920
933
 
921
934
 
922
- def flatten_(x):
935
+ def _flatten(x):
923
936
  return x.reshape(-1, *x.shape[2:])
924
937
 
925
938
 
@@ -935,17 +948,19 @@ def _vmap_new_states_transform(
935
948
  # -- brainstate specific arguments -- #
936
949
  state_tag: str | None = None,
937
950
  state_to_exclude: Filter | None = None,
951
+ in_states: Dict[int, Dict] | Any | None = None,
952
+ out_states: Dict[int, Dict] | Any | None = None,
938
953
  ):
939
-
940
954
  # TODO: How about nested call ``vmap_new_states``?
941
955
 
942
-
943
956
  @vmap(
944
957
  in_axes=in_axes,
945
958
  out_axes=out_axes,
946
959
  axis_name=axis_name,
947
960
  axis_size=axis_size,
948
961
  spmd_axis_name=spmd_axis_name,
962
+ in_states=in_states,
963
+ out_states=out_states,
949
964
  )
950
965
  def new_fun(args):
951
966
  # call the function
@@ -988,6 +1003,8 @@ def vmap_new_states(
988
1003
  # -- brainstate specific arguments -- #
989
1004
  state_tag: str | None = None,
990
1005
  state_to_exclude: Filter = None,
1006
+ in_states: Dict[int, Dict] | Any | None = None,
1007
+ out_states: Dict[int, Dict] | Any | None = None,
991
1008
  ):
992
1009
  """
993
1010
  Vectorize a function over new states created within it.
@@ -1019,6 +1036,8 @@ def vmap_new_states(
1019
1036
  spmd_axis_name=spmd_axis_name,
1020
1037
  state_tag=state_tag,
1021
1038
  state_to_exclude=state_to_exclude,
1039
+ in_states=in_states,
1040
+ out_states=out_states,
1022
1041
  )
1023
1042
  else:
1024
1043
  return _vmap_new_states_transform(
@@ -1030,4 +1049,6 @@ def vmap_new_states(
1030
1049
  spmd_axis_name=spmd_axis_name,
1031
1050
  state_tag=state_tag,
1032
1051
  state_to_exclude=state_to_exclude,
1052
+ in_states=in_states,
1053
+ out_states=out_states,
1033
1054
  )
@@ -15,12 +15,14 @@
15
15
 
16
16
  from __future__ import annotations
17
17
 
18
+ import unittest
19
+
18
20
  import jax
19
21
  import jax.numpy as jnp
20
22
  import numpy as np
21
- import unittest
22
23
 
23
24
  import brainstate as bst
25
+ import brainstate.augment
24
26
  from brainstate.augment._mapping import BatchAxisError
25
27
  from brainstate.augment._mapping import _remove_axis
26
28
 
@@ -264,6 +266,55 @@ class TestVmap(unittest.TestCase):
264
266
  res2 = jax.vmap(f, axis_size=10)()
265
267
  self.assertTrue(jnp.all((res2[0] == res2[1:])))
266
268
 
269
+ def test_axis(self):
270
+ def f(x):
271
+ return x - jax.lax.pmean(x, 'i')
272
+
273
+ r = jax.vmap(f, axis_name='i')(jnp.arange(10))
274
+ print(r)
275
+
276
+ r2 = brainstate.augment.vmap(f, axis_name='i')(jnp.arange(10))
277
+ print(r2)
278
+ self.assertTrue(jnp.allclose(r, r2))
279
+
280
+ def test_vmap_init(self):
281
+ class Foo(bst.nn.Module):
282
+ def __init__(self):
283
+ super().__init__()
284
+ self.a = bst.ParamState(jnp.arange(4))
285
+ self.b = bst.ShortTermState(jnp.arange(4))
286
+
287
+ def init_state_v1(self, *args, **kwargs):
288
+ self.c = bst.State(jnp.arange(4))
289
+
290
+ def init_state_v2(self):
291
+ self.d = bst.State(self.c.value * 2.)
292
+
293
+ foo = Foo()
294
+
295
+ @brainstate.augment.vmap_new_states(state_tag='new1', axis_size=5)
296
+ def init1():
297
+ foo.init_state_v1()
298
+
299
+ init1()
300
+ print(foo.c.value)
301
+
302
+ @brainstate.augment.vmap_new_states(state_tag='new2', axis_size=5, in_states=foo.states('new1'))
303
+ def init2():
304
+ foo.init_state_v2()
305
+
306
+ init2()
307
+ print(foo.c.value)
308
+ print(foo.d.value)
309
+
310
+ self.assertTrue(
311
+ jnp.allclose(
312
+ foo.d.value,
313
+ foo.c.value * 2.
314
+ )
315
+ )
316
+
317
+
267
318
 
268
319
  class TestMap(unittest.TestCase):
269
320
  def test_map(self):
@@ -753,8 +753,8 @@ def _make_jaxpr(
753
753
  if jax.__version_info__ < (0, 5, 0):
754
754
  debug_info = pe.debug_info(fun, in_tree, out_tree, True, 'make_jaxpr')
755
755
  with ExitStack() as stack:
756
- for axis_name, size in axis_env or []:
757
- stack.enter_context(jax.core.extend_axis_env(axis_name, size, None))
756
+ if axis_env is not None:
757
+ stack.enter_context(jax.core.extend_axis_env_nd(axis_env))
758
758
  if jax.__version_info__ < (0, 5, 0):
759
759
  jaxpr, out_type, consts = pe.trace_to_jaxpr_dynamic2(f, debug_info=debug_info)
760
760
  else:
@@ -159,7 +159,7 @@ def spike_bitwise_ixor(x, y):
159
159
 
160
160
 
161
161
  def spike_bitwise(x, y, op: str):
162
- """
162
+ r"""
163
163
  Perform bitwise operations on spike tensors.
164
164
 
165
165
  This function applies various bitwise operations on spike tensors based on the specified operation.
@@ -18,9 +18,7 @@ from __future__ import annotations
18
18
  from collections import namedtuple
19
19
 
20
20
  import jax
21
- from typing import (
22
- Callable, TypeVar, Tuple, Any, Dict
23
- )
21
+ from typing import Callable, TypeVar, Tuple, Any, Dict
24
22
 
25
23
  from brainstate._state import catch_new_states
26
24
  from brainstate._utils import set_module_as
@@ -103,7 +101,7 @@ def call_all_functions(
103
101
  on each node. It respects the call order of functions if defined, and provides options for
104
102
  handling cases where the specified function does not exist on a node.
105
103
 
106
- Parameters:
104
+ Parameters
107
105
  -----------
108
106
  target : T
109
107
  The target module on which to call functions.
@@ -121,12 +119,12 @@ def call_all_functions(
121
119
  - 'raise': Raise an exception (default)
122
120
  - 'pass' or 'none': Skip the node and continue
123
121
 
124
- Returns:
122
+ Returns
125
123
  --------
126
124
  T
127
125
  The target module after calling the specified function on all applicable nodes.
128
126
 
129
- Raises:
127
+ Raises
130
128
  -------
131
129
  AssertionError
132
130
  If fun_name is not a string or kwargs is not a dictionary.
@@ -186,7 +184,7 @@ def vmap_call_all_functions(
186
184
  This function vectorizes the process of calling a specified function across multiple instances
187
185
  of the target module, effectively batching the operation.
188
186
 
189
- Parameters:
187
+ Parameters
190
188
  -----------
191
189
  target : T
192
190
  The target module on which to call functions.
@@ -208,12 +206,12 @@ def vmap_call_all_functions(
208
206
  - 'raise': Raise an exception (default)
209
207
  - 'pass' or 'none': Skip the node and continue
210
208
 
211
- Returns:
209
+ Returns
212
210
  --------
213
211
  T
214
212
  The target module after applying the vectorized function call on all applicable nodes.
215
213
 
216
- Raises:
214
+ Raises
217
215
  -------
218
216
  AssertionError
219
217
  If axis_size is not specified or is not a positive integer.
@@ -304,7 +302,7 @@ def vmap_init_all_states(
304
302
  This function applies vectorized mapping (vmap) to initialize states across multiple
305
303
  instances of the target module, effectively batching the initialization process.
306
304
 
307
- Parameters:
305
+ Parameters
308
306
  -----------
309
307
  target : T
310
308
  The target module whose states are to be initialized.
@@ -319,12 +317,12 @@ def vmap_init_all_states(
319
317
  state_tag : str | None, optional
320
318
  A tag to be used for catching new states.
321
319
 
322
- Returns:
320
+ Returns
323
321
  --------
324
322
  T
325
323
  The target module with initialized states.
326
324
 
327
- Raises:
325
+ Raises
328
326
  -------
329
327
  AssertionError
330
328
  If axis_size is not specified or is not greater than 0.
@@ -413,7 +411,7 @@ def vmap_reset_all_states(
413
411
  This function applies vectorized mapping (vmap) to reset states across multiple
414
412
  instances of the target module, effectively batching the reset process.
415
413
 
416
- Parameters:
414
+ Parameters
417
415
  -----------
418
416
  target : T
419
417
  The target module whose states are to be reset.
@@ -428,12 +426,12 @@ def vmap_reset_all_states(
428
426
  tag : str | None, optional
429
427
  A tag to be used for catching new states.
430
428
 
431
- Returns:
429
+ Returns
432
430
  --------
433
431
  T
434
432
  The target module with reset states.
435
433
 
436
- Raises:
434
+ Raises
437
435
  -------
438
436
  AssertionError
439
437
  If axis_size is not specified or is not greater than 0.
@@ -486,7 +484,7 @@ def save_all_states(target: Module, **kwargs) -> Dict:
486
484
  Args:
487
485
  target: Module. The node to save its states.
488
486
 
489
- Returns:
487
+ Returns
490
488
  Dict. The state dict for serialization.
491
489
  """
492
490
  return {key: node.save_state(**kwargs) for key, node in target.nodes().items()}
@@ -17,8 +17,9 @@
17
17
 
18
18
  from __future__ import annotations
19
19
 
20
+ from typing import Optional, Callable
21
+
20
22
  import brainunit as u
21
- from typing import Optional
22
23
 
23
24
  from brainstate import init, environ
24
25
  from brainstate._state import ShortTermState, HiddenState
@@ -54,7 +55,7 @@ class Expon(Synapse, AlignPost):
54
55
  in_size: Size,
55
56
  name: Optional[str] = None,
56
57
  tau: ArrayLike = 8.0 * u.ms,
57
- g_initializer: ArrayLike = init.ZeroInit(unit=u.mS),
58
+ g_initializer: ArrayLike | Callable = init.ZeroInit(unit=u.mS),
58
59
  ):
59
60
  super().__init__(name=name, in_size=in_size)
60
61
 
@@ -85,7 +86,7 @@ class DualExpon(Synapse, AlignPost):
85
86
  tau_decay: ArrayLike = 10.0 * u.ms,
86
87
  tau_rise: ArrayLike = 1.0 * u.ms,
87
88
  A: Optional[ArrayLike] = None,
88
- g_initializer: ArrayLike = init.ZeroInit(unit=u.mS),
89
+ g_initializer: ArrayLike | Callable = init.ZeroInit(unit=u.mS),
89
90
  ):
90
91
  super().__init__(name=name, in_size=in_size)
91
92
 
@@ -133,7 +134,7 @@ class Alpha(Synapse):
133
134
  in_size: Size,
134
135
  name: Optional[str] = None,
135
136
  tau: ArrayLike = 8.0 * u.ms,
136
- g_initializer: ArrayLike = init.ZeroInit(unit=u.mS),
137
+ g_initializer: ArrayLike | Callable = init.ZeroInit(unit=u.mS),
137
138
  ):
138
139
  super().__init__(name=name, in_size=in_size)
139
140
 
@@ -321,7 +322,7 @@ class AMPA(Synapse):
321
322
  beta: ArrayLike = 0.18 / u.ms,
322
323
  T: ArrayLike = 0.5 * u.mM,
323
324
  T_dur: ArrayLike = 0.5 * u.ms,
324
- g_initializer: ArrayLike = init.ZeroInit(),
325
+ g_initializer: ArrayLike | Callable = init.ZeroInit(),
325
326
  ):
326
327
  super().__init__(name=name, in_size=in_size)
327
328
 
@@ -394,5 +395,14 @@ class GABAa(AMPA):
394
395
  beta: ArrayLike = 0.18 / u.ms,
395
396
  T: ArrayLike = 1.0 * u.mM,
396
397
  T_dur: ArrayLike = 1.0 * u.ms,
398
+ g_initializer: ArrayLike | Callable = init.ZeroInit(),
397
399
  ):
398
- super().__init__(alpha=alpha, beta=beta, T=T, T_dur=T_dur, name=name, in_size=in_size)
400
+ super().__init__(
401
+ alpha=alpha,
402
+ beta=beta,
403
+ T=T,
404
+ T_dur=T_dur,
405
+ name=name,
406
+ in_size=in_size,
407
+ g_initializer=g_initializer
408
+ )
@@ -49,13 +49,13 @@ def exp_euler_step(
49
49
  should have units of ( [X]/\sqrt{[T]} ).
50
50
 
51
51
  Args:
52
- fun: Callable. The function to be solved.
53
- diffusion: Callable. The diffusion function.
54
- *args: The input arguments.
55
- drift: Callable. The drift function.
52
+ fun: Callable. The function to be solved.
53
+ diffusion: Callable. The diffusion function.
54
+ *args: The input arguments.
55
+ drift: Callable. The drift function.
56
56
 
57
57
  Returns:
58
- The one-step solution of the ODE.
58
+ The one-step solution of the ODE.
59
59
  """
60
60
  assert callable(fn), 'The input function should be callable.'
61
61
  assert len(args) > 0, 'The input arguments should not be empty.'
@@ -201,9 +201,11 @@ class SGD(_WeightDecayOptimizer):
201
201
  def update(self, grads: dict):
202
202
  lr = self.lr()
203
203
  weight_values, grad_values = to_same_dict_tree(self.param_states, grads)
204
- updates = jax.tree.map(functools.partial(_sgd, lr=lr, weight_decay=self.weight_decay),
205
- weight_values,
206
- grad_values)
204
+ updates = jax.tree.map(
205
+ functools.partial(_sgd, lr=lr, weight_decay=self.weight_decay),
206
+ weight_values,
207
+ grad_values
208
+ )
207
209
  self.param_states.assign_values(updates)
208
210
  self.lr.step_call()
209
211
 
@@ -324,12 +326,16 @@ class MomentumNesterov(_WeightDecayOptimizer):
324
326
  def update(self, grads: dict):
325
327
  lr = self.lr()
326
328
  states_values, grad_values, momentum_values = to_same_dict_tree(self.param_states, grads, self.momentum_states)
327
- momentum_values = jax.tree.map(lambda mv, gv: self.momentum * mv - lr * gv,
328
- momentum_values,
329
- grad_values)
330
- weight_values = jax.tree.map(functools.partial(_sgd, lr=lr, weight_decay=self.weight_decay),
331
- states_values,
332
- momentum_values)
329
+ momentum_values = jax.tree.map(
330
+ lambda mv, gv: self.momentum * mv - lr * gv,
331
+ momentum_values,
332
+ grad_values
333
+ )
334
+ weight_values = jax.tree.map(
335
+ functools.partial(_sgd, lr=lr, weight_decay=self.weight_decay),
336
+ states_values,
337
+ momentum_values
338
+ )
333
339
  self.param_states.assign_values(weight_values)
334
340
  self.momentum_states.assign_values(momentum_values)
335
341
  self.lr.step_call()
@@ -388,11 +394,21 @@ class Adagrad(_WeightDecayOptimizer):
388
394
  def update(self, grads: dict):
389
395
  lr = self.lr()
390
396
  cache_values, grad_values, weight_values = to_same_dict_tree(self.cache_states, grads, self.param_states)
391
- cache_values = jax.tree.map(lambda cv, gv: cv + gv ** 2, cache_values, grad_values)
392
- updates = jax.tree.map(lambda cv, gv: lr * gv / jnp.sqrt(cv + self.epsilon), cache_values, grad_values)
393
- weight_values = jax.tree.map(functools.partial(_sgd, weight_decay=self.weight_decay),
394
- weight_values,
395
- updates)
397
+ cache_values = jax.tree.map(
398
+ lambda cv, gv: cv + gv ** 2,
399
+ cache_values,
400
+ grad_values
401
+ )
402
+ updates = jax.tree.map(
403
+ lambda cv, gv: lr * gv / jnp.sqrt(cv + self.epsilon),
404
+ cache_values,
405
+ grad_values
406
+ )
407
+ weight_values = jax.tree.map(
408
+ functools.partial(_sgd, weight_decay=self.weight_decay),
409
+ weight_values,
410
+ updates
411
+ )
396
412
  self.cache_states.assign_values(cache_values)
397
413
  self.param_states.assign_values(weight_values)
398
414
  self.lr.step_call()
@@ -605,13 +621,28 @@ class Adam(_WeightDecayOptimizer):
605
621
  lr = lr / (1 - self.beta1 ** (self.lr.last_epoch.value + 2))
606
622
  lr = lr * jnp.sqrt(1 - self.beta2 ** (self.lr.last_epoch.value + 2))
607
623
  weight_values, grad_values, m1_values, m2_values = to_same_dict_tree(
608
- self.param_states, grads, self.m1_states, self.m2_states)
609
- m1_values = jax.tree.map(lambda m1, gv: self.beta1 * m1 + (1 - self.beta1) * gv, m1_values, grad_values)
610
- m2_values = jax.tree.map(lambda m2, gv: self.beta2 * m2 + (1 - self.beta2) * gv ** 2, m2_values, grad_values)
611
- update = jax.tree.map(lambda m1, m2: lr * m1 / (jnp.sqrt(m2) + self.eps), m1_values, m2_values)
612
- weight_values = jax.tree.map(functools.partial(_sgd, weight_decay=self.weight_decay),
613
- weight_values,
614
- update)
624
+ self.param_states, grads, self.m1_states, self.m2_states
625
+ )
626
+ m1_values = jax.tree.map(
627
+ lambda m1, gv: self.beta1 * m1 + (1 - self.beta1) * gv,
628
+ m1_values,
629
+ grad_values
630
+ )
631
+ m2_values = jax.tree.map(
632
+ lambda m2, gv: self.beta2 * m2 + (1 - self.beta2) * gv ** 2,
633
+ m2_values,
634
+ grad_values
635
+ )
636
+ update = jax.tree.map(
637
+ lambda m1, m2: lr * m1 / (jnp.sqrt(m2) + self.eps),
638
+ m1_values,
639
+ m2_values
640
+ )
641
+ weight_values = jax.tree.map(
642
+ functools.partial(_sgd, weight_decay=self.weight_decay),
643
+ weight_values,
644
+ update
645
+ )
615
646
  self.param_states.assign_values(weight_values)
616
647
  self.m1_states.assign_values(m1_values)
617
648
  self.m2_states.assign_values(m2_values)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: brainstate
3
- Version: 0.1.0.post20250218
3
+ Version: 0.1.0.post20250315
4
4
  Summary: A ``State``-based Transformation System for Program Compilation and Augmentation.
5
5
  Home-page: https://github.com/chaobrain/brainstate
6
6
  Author: BrainState Developers
@@ -62,6 +62,7 @@ Requires-Dist: jaxlib[tpu] ; extra == 'tpu'
62
62
  <a href="https://badge.fury.io/py/brainstate"><img alt="PyPI version" src="https://badge.fury.io/py/brainstate.svg"></a>
63
63
  <a href="https://github.com/chaobrain/brainstate/actions/workflows/CI.yml"><img alt="Continuous Integration" src="https://github.com/chaobrain/brainstate/actions/workflows/CI.yml/badge.svg"></a>
64
64
  <a href="https://pepy.tech/projects/brainstate"><img src="https://static.pepy.tech/badge/brainstate" alt="PyPI Downloads"></a>
65
+ <a href="https://doi.org/10.5281/zenodo.14970015"><img src="https://zenodo.org/badge/811300394.svg" alt="DOI"></a>
65
66
  </p>
66
67
 
67
68
 
@@ -81,8 +82,8 @@ The official documentation is hosted on Read the Docs: [https://brainstate.readt
81
82
 
82
83
 
83
84
 
84
- ## See also the BDP ecosystem
85
+ ## See also the brain modeling ecosystem
85
86
 
86
- We are building the Brain Dynamics Programming ecosystem: https://ecosystem-for-brain-dynamics.readthedocs.io/
87
+ We are building the brain modeling ecosystem: https://brainmodeling.readthedocs.io/
87
88
 
88
89
 
@@ -10,12 +10,12 @@ brainstate/surrogate.py,sha256=wWYw-TxaFxHVneXuHjWD1UtTcOTk3XRSnhRtUkt_Hb8,53580
10
10
  brainstate/transform.py,sha256=vZWzO4F7qTsXL_SiVQPlTz0l9b_hRo9D-igETfgCTy0,758
11
11
  brainstate/typing.py,sha256=988gX1tvwtyYnYjmej90OaRxoMoBIPO0-DSrXXGxojM,10523
12
12
  brainstate/augment/__init__.py,sha256=Q9-JIwQ1FNn8VLS1MA9MrSylbvUjWSw98whrI3NIuKo,1229
13
- brainstate/augment/_autograd.py,sha256=eAXdFPOEei2yTHtrhMXQBpC-36fmc6YZu3YWw4q_Wmk,29682
13
+ brainstate/augment/_autograd.py,sha256=hfDoa2HbkRn-InOS0yOcb6gEZ2DLNqtWA133P8-hvIo,30138
14
14
  brainstate/augment/_autograd_test.py,sha256=2wCC8aUcDp2IHgF7wr1GK5HwWfELXni5PpA-082azuU,44058
15
15
  brainstate/augment/_eval_shape.py,sha256=jgsS197Nizehr9A2nGaQPE7NuNujhFhmR3J96hTicX8,3890
16
16
  brainstate/augment/_eval_shape_test.py,sha256=LFOJx7CWltmRLXdGY175UebLwtEMz2CzJ_gLqMZsJTw,1393
17
- brainstate/augment/_mapping.py,sha256=XqKRM88yAvvhWzSUqwq67sX49xV8oYtUdcmQ0KO9hVs,42532
18
- brainstate/augment/_mapping_test.py,sha256=V9NOpiZydDDUUYnLLCJVU9KJ-2IRNuVbL-ziuW-pU1I,12033
17
+ brainstate/augment/_mapping.py,sha256=BPwpD7jX4xRNl4BdAsKGoF45MKbmEF9Lyyp11pJucIg,43356
18
+ brainstate/augment/_mapping_test.py,sha256=-4HJXmJw_6SD9dQnHTBjgYVuq6VTVjz0xpc9v2CJVNw,13414
19
19
  brainstate/augment/_random.py,sha256=ikRzNoDDE2BkARajDsBhNlngCUrghzGSZUDmEGvVors,5386
20
20
  brainstate/compile/__init__.py,sha256=fQtG316MLkeeu1Ssp54Kghw1PwbGK5gNq9yRVJu0wjA,1474
21
21
  brainstate/compile/_ad_checkpoint.py,sha256=3wv-f89oo94XeWwRV5LcRot0Nz7xTk5_PdjEDyUMsoo,9394
@@ -30,7 +30,7 @@ brainstate/compile/_loop_collect_return.py,sha256=-LsP7fkHmAyGnDOKa3BxxYOEWe8M2J
30
30
  brainstate/compile/_loop_collect_return_test.py,sha256=D9RQ5RyQHkqBr4nmSK-yM_uge3EC6uVm_Dzy42g3vtg,1802
31
31
  brainstate/compile/_loop_no_collection.py,sha256=2OEVtv5XztOx-e0focZ1UnWkXmFzmDskjHJXuVXmuhA,7587
32
32
  brainstate/compile/_loop_no_collection_test.py,sha256=oStB1CSG_iLp9sHdXd1hJNFvlxbzjck9Iy4sABoJDj4,1419
33
- brainstate/compile/_make_jaxpr.py,sha256=QGbVU_Y6pEfFgr61v_gmsRS3HXcHp7ILV_JZk_e3J4o,33213
33
+ brainstate/compile/_make_jaxpr.py,sha256=8iV8XyvkMH3n3wbEWZAgZtbrUxryljwQJD6o5DMW9Lc,33189
34
34
  brainstate/compile/_make_jaxpr_test.py,sha256=fZe3K4RHFLmMAeXZoFZ5RyxgXvncTcuMQdjmOROJtKU,4365
35
35
  brainstate/compile/_progress_bar.py,sha256=3Z3OVcc5sl9FK9Fkt813l20MNzEfa6UZ9lJrvSgXTCU,7522
36
36
  brainstate/compile/_unvmap.py,sha256=uCvQjvb8J7kT0kalX576mrAPvQuCh_W76EPdgZ53kTM,4230
@@ -40,7 +40,7 @@ brainstate/functional/_activations.py,sha256=VmCU9HOKWbysxuJFBN-JsShS4loNMG_E6IX
40
40
  brainstate/functional/_activations_test.py,sha256=-bCijTvo4Wo_P283RYKYMPcTLsjhu5i2X9ySdf1ayEY,13034
41
41
  brainstate/functional/_normalization.py,sha256=L3S4DIF1EztrlE4_KHX7j_m6Mw0mpAwnx5UTAX6YYBU,2603
42
42
  brainstate/functional/_others.py,sha256=eBV43WqQsDvHkkwX0xbqCRoIoJlngMFLSUKgleH2dt0,1735
43
- brainstate/functional/_spikes.py,sha256=EZqwJctEElHhxaC9tsm8WKkXTQJJzo901Db0I3QS3iM,7086
43
+ brainstate/functional/_spikes.py,sha256=7FTfCfEN1mjlY-EULzCisk7_NOmxZPj-mp-ODncW7R0,7087
44
44
  brainstate/graph/__init__.py,sha256=noo4TjBg6iEhjjwk0sAGUhR7Ge-z8Vnc2rLYUvnqttw,1295
45
45
  brainstate/graph/_graph_node.py,sha256=JE1Tc0mK3nJFWUFzXE53MWWiEYiXJO5VdqZEYKbXlw0,6872
46
46
  brainstate/graph/_graph_node_test.py,sha256=BFGfdzZFDHI0XK7hHotSVWKt3em1taGvn8FHF9NCXx8,2702
@@ -54,10 +54,10 @@ brainstate/init/_random_inits_test.py,sha256=lBL2RQdBSZ88Zqz4IMdbHJMvDi7ooZq6caC
54
54
  brainstate/init/_regular_inits.py,sha256=DmVMajugfyYFNUMzgFdDKMvbBu9hMWxkfDd-50uhoLg,3187
55
55
  brainstate/init/_regular_inits_test.py,sha256=tJl4aOkclllJIfKzJTbc0cfYCw2SoBsx8_G123RnqbU,1842
56
56
  brainstate/nn/__init__.py,sha256=ar1hDUYbSO6oadMpbuS9FWZvZB_iyFzM8CwMK-RNDzM,1823
57
- brainstate/nn/_collective_ops.py,sha256=yQNBnh-XVEFnTg-Ga14mHOCGtGxiTkL9MYKdNjJF1BI,17535
57
+ brainstate/nn/_collective_ops.py,sha256=NI9BT-908TbIlXLMjbWsPyI5YLZD_cCkSKGeOY-qO60,17512
58
58
  brainstate/nn/_collective_ops_test.py,sha256=yW7NNYsGFglFRFkqVlpGSY6WLnU-h8GlK6wCmG5jtRc,1189
59
59
  brainstate/nn/_common.py,sha256=XQw0i0sH3Y_qUwHSMC7G9VQnDj-RuuTh1Ul-xRIPxxc,7136
60
- brainstate/nn/_exp_euler.py,sha256=cRgPNcjMs2C9x_8JabtYz5hm_FwqbiJ_U1VfRHYIlrE,3519
60
+ brainstate/nn/_exp_euler.py,sha256=s-Z_cT_oYvCvE-OaXuUidIxQs3KOy1pzkx1lwtfPo00,3529
61
61
  brainstate/nn/_exp_euler_test.py,sha256=kvPf009DMYtla2uedKVKrPTHDyMTBepjlfsk5vDHqhI,1240
62
62
  brainstate/nn/_module.py,sha256=vrukVI0ylbymzilh9BZtb-d9dnsBsykqanUNTx9Eb6Y,12844
63
63
  brainstate/nn/_module_test.py,sha256=UrVA85fo0KVFN9ApPkxkRcvtXEskWOXPzZIBa4JSFo0,8891
@@ -66,7 +66,7 @@ brainstate/nn/metrics.py,sha256=p7eVwd5y8r0N5rMws-zOS_KaZCLOMdrXyQvLnoJeq1w,1473
66
66
  brainstate/nn/_dyn_impl/__init__.py,sha256=Oazar7h89dp1WA2Vx4Tj7gCBhxJKH4LAUEABkBEG7vU,1462
67
67
  brainstate/nn/_dyn_impl/_dynamics_neuron.py,sha256=mcDxVZlk56NAEkR6xcE74hOZ9up8Rua4SvKEeAhJKU4,10925
68
68
  brainstate/nn/_dyn_impl/_dynamics_neuron_test.py,sha256=_wPp6UvWVZI9EYba-DWL_JZXyMxm0-SHDkZHI8lKp8w,6315
69
- brainstate/nn/_dyn_impl/_dynamics_synapse.py,sha256=Nk_c62nCJjjjhihQV_xqYeT-x_34xzc6OhdScO19ffw,15267
69
+ brainstate/nn/_dyn_impl/_dynamics_synapse.py,sha256=xXMNlDWX0tQ9N0zORfT4DFKoEXtrsRbeetqOq-bYovs,15518
70
70
  brainstate/nn/_dyn_impl/_dynamics_synapse_test.py,sha256=VUDMlHNcyeqHrBd1eAXg_VD0HCSg5C-eqMmcJVzYcJA,4979
71
71
  brainstate/nn/_dyn_impl/_inputs.py,sha256=ubM5Z2q0gXpJ2M5Das3A5MJpFOorVomfW6-079mqJ9k,12935
72
72
  brainstate/nn/_dyn_impl/_projection_alignpost.py,sha256=PNC1Tzx_SF2DHAHeJCufXzO_Q4qLoBpWABI45B3GRuc,876
@@ -102,7 +102,7 @@ brainstate/optim/_lr_scheduler.py,sha256=Ua_H3VWUt9QZ0pHtGpnq2wrqsTOUidXJ3VDz9s-
102
102
  brainstate/optim/_lr_scheduler_test.py,sha256=W0F1eHb9S4seE468c26owcJIWTtgNhZYrOi2GrysVNI,1774
103
103
  brainstate/optim/_optax_optimizer.py,sha256=SuXV_xUBfhOw1_C2J5TIpy3dXDtI9VJFaSMLy8hLcXE,5312
104
104
  brainstate/optim/_optax_optimizer_test.py,sha256=J7zvmeSaWmTlfbpjx1ILb9cSC5qlj1wn4H2QMw9jUY0,1760
105
- brainstate/optim/_sgd_optimizer.py,sha256=oOFUEqCFX-WfhMnB614AjScbgFYP8y-zIwb94FZ_olA,46006
105
+ brainstate/optim/_sgd_optimizer.py,sha256=RDQLrWsJFeEpTc93toDbf4hXnHmfY0hpTs38z1AZYPY,46144
106
106
  brainstate/random/__init__.py,sha256=c5q-RC3grRIjx-HBb2IhKZpi_xzbFmUUxzRAzqfREic,1045
107
107
  brainstate/random/_rand_funs.py,sha256=fmyaBTl_P7M63RRp5V2dOR_ttXQ0vB2qU2C8RqfkjY0,137680
108
108
  brainstate/random/_rand_funs_test.py,sha256=Nhy5gXuJ2ld9u8CTCqU1V94FPm0TvYQ-oMy2bP_CZ7I,19436
@@ -121,8 +121,8 @@ brainstate/util/_pretty_table.py,sha256=NM_6VAW6oL9jojsK0-RkQGHnDzLy_fn_hgzl5R8o
121
121
  brainstate/util/_scaling.py,sha256=pc_eM_SZVwkY65I4tJh1ODiHNCoEhsfFXl2zBK0PLAg,7562
122
122
  brainstate/util/_struct.py,sha256=F5GfFURITAIYTwf17_xypkZU1wvoL4dUCviPnr_eCtw,17515
123
123
  brainstate/util/filter.py,sha256=Zw0H42NwAi2P7dBr3ISv2VpkB5jqoWnV4Kpd61gq66o,14126
124
- brainstate-0.1.0.post20250218.dist-info/LICENSE,sha256=VZe9u1jgUL2eCY6ZPOYgdb8KCblCHt8ECdbtJid6e1s,11550
125
- brainstate-0.1.0.post20250218.dist-info/METADATA,sha256=INsdNphwnxGh07Urn2gjiOYzPEEQ8Lcr19bugHUYrU4,3585
126
- brainstate-0.1.0.post20250218.dist-info/WHEEL,sha256=bb2Ot9scclHKMOLDEHY6B2sicWOgugjFKaJsT7vwMQo,110
127
- brainstate-0.1.0.post20250218.dist-info/top_level.txt,sha256=eQbGgKn0ptx7FDWuua0V0wr4K1VHi2iOUCYo3fUQBRA,11
128
- brainstate-0.1.0.post20250218.dist-info/RECORD,,
124
+ brainstate-0.1.0.post20250315.dist-info/LICENSE,sha256=VZe9u1jgUL2eCY6ZPOYgdb8KCblCHt8ECdbtJid6e1s,11550
125
+ brainstate-0.1.0.post20250315.dist-info/METADATA,sha256=8gOEdv6PiXBLr_gAvx70Yik7G7XxidMVPPLOLx3ndPc,3689
126
+ brainstate-0.1.0.post20250315.dist-info/WHEEL,sha256=bb2Ot9scclHKMOLDEHY6B2sicWOgugjFKaJsT7vwMQo,110
127
+ brainstate-0.1.0.post20250315.dist-info/top_level.txt,sha256=eQbGgKn0ptx7FDWuua0V0wr4K1VHi2iOUCYo3fUQBRA,11
128
+ brainstate-0.1.0.post20250315.dist-info/RECORD,,