brainstate 0.1.0.post20250218__py2.py3-none-any.whl → 0.1.0.post20250315__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/augment/_autograd.py +13 -3
- brainstate/augment/_mapping.py +30 -9
- brainstate/augment/_mapping_test.py +52 -1
- brainstate/compile/_make_jaxpr.py +2 -2
- brainstate/functional/_spikes.py +1 -1
- brainstate/nn/_collective_ops.py +14 -16
- brainstate/nn/_dyn_impl/_dynamics_synapse.py +16 -6
- brainstate/nn/_exp_euler.py +5 -5
- brainstate/optim/_sgd_optimizer.py +52 -21
- {brainstate-0.1.0.post20250218.dist-info → brainstate-0.1.0.post20250315.dist-info}/METADATA +4 -3
- {brainstate-0.1.0.post20250218.dist-info → brainstate-0.1.0.post20250315.dist-info}/RECORD +14 -14
- {brainstate-0.1.0.post20250218.dist-info → brainstate-0.1.0.post20250315.dist-info}/LICENSE +0 -0
- {brainstate-0.1.0.post20250218.dist-info → brainstate-0.1.0.post20250315.dist-info}/WHEEL +0 -0
- {brainstate-0.1.0.post20250218.dist-info → brainstate-0.1.0.post20250315.dist-info}/top_level.txt +0 -0
brainstate/augment/_autograd.py
CHANGED
@@ -29,11 +29,12 @@ The wrapped gradient transformations here are made possible by using the followi
|
|
29
29
|
|
30
30
|
from __future__ import annotations
|
31
31
|
|
32
|
-
import brainunit as u
|
33
|
-
import jax
|
34
32
|
from functools import wraps, partial
|
35
33
|
from typing import Union, Callable, Dict, Sequence, Optional, Any, Tuple, TypeVar, Iterator
|
36
34
|
|
35
|
+
import brainunit as u
|
36
|
+
import jax
|
37
|
+
|
37
38
|
from brainstate._state import State
|
38
39
|
from brainstate._utils import set_module_as
|
39
40
|
from brainstate.compile._make_jaxpr import StatefulFunction
|
@@ -195,6 +196,7 @@ class GradientTransform(PrettyRepr):
|
|
195
196
|
grad_states = {k: v for k, v in grad_states.items()}
|
196
197
|
self._grad_states, self._grad_tree = jax.tree.flatten(grad_states)
|
197
198
|
self._grad_state_ids = [id(v) for v in self._grad_states]
|
199
|
+
self._grad_id_to_state = {id(v): v for v in self._grad_states}
|
198
200
|
if any(not isinstance(v, State) for v in self._grad_states):
|
199
201
|
raise TypeError("All grad_states must be State instances.")
|
200
202
|
|
@@ -250,12 +252,20 @@ class GradientTransform(PrettyRepr):
|
|
250
252
|
"""
|
251
253
|
grad_vals = dict()
|
252
254
|
other_vals = dict()
|
255
|
+
all_ids = set(self._grad_state_ids)
|
253
256
|
for st in state_trace.states:
|
254
257
|
id_ = id(st)
|
255
|
-
if id_ in
|
258
|
+
if id_ in all_ids:
|
256
259
|
grad_vals[id_] = st.value
|
260
|
+
all_ids.remove(id_)
|
257
261
|
else:
|
258
262
|
other_vals[id_] = st.value
|
263
|
+
if len(all_ids):
|
264
|
+
err = f"Some states are not found in the state trace when performing gradient transformations.\n "
|
265
|
+
for i, id_ in enumerate(all_ids):
|
266
|
+
st = self._grad_id_to_state[id_]
|
267
|
+
st.raise_error_with_source_info(ValueError(err + str(st)))
|
268
|
+
|
259
269
|
return grad_vals, other_vals
|
260
270
|
|
261
271
|
def _merge_state_vals(self, grad_vals: Dict, other_vals: Dict, state_trace):
|
brainstate/augment/_mapping.py
CHANGED
@@ -16,8 +16,6 @@
|
|
16
16
|
from __future__ import annotations
|
17
17
|
|
18
18
|
import functools
|
19
|
-
import jax
|
20
|
-
from jax.interpreters.batching import BatchTracer
|
21
19
|
from typing import (
|
22
20
|
Any,
|
23
21
|
TypeVar,
|
@@ -32,6 +30,9 @@ from typing import (
|
|
32
30
|
List
|
33
31
|
)
|
34
32
|
|
33
|
+
import jax
|
34
|
+
from jax.interpreters.batching import BatchTracer
|
35
|
+
|
35
36
|
from brainstate._state import State, catch_new_states
|
36
37
|
from brainstate.compile import scan, StatefulFunction
|
37
38
|
from brainstate.random import RandomState, DEFAULT
|
@@ -378,8 +379,15 @@ def _vmap_transform(
|
|
378
379
|
# call the function
|
379
380
|
return f(*args)
|
380
381
|
|
382
|
+
def _set_axis_env(batch_size):
|
383
|
+
axis_env = None if axis_name is None else [(axis_name, batch_size)]
|
384
|
+
stateful_fn.axis_env = axis_env
|
385
|
+
|
381
386
|
# stateful function
|
382
|
-
stateful_fn = StatefulFunction(
|
387
|
+
stateful_fn = StatefulFunction(
|
388
|
+
_vmap_fn_for_compilation,
|
389
|
+
name='vmap',
|
390
|
+
)
|
383
391
|
|
384
392
|
@functools.wraps(f)
|
385
393
|
def new_fn_for_vmap(
|
@@ -506,6 +514,10 @@ def _vmap_transform(
|
|
506
514
|
st_in_axes = 0
|
507
515
|
|
508
516
|
# compile stateful function
|
517
|
+
batch_size = None
|
518
|
+
if axis_name is not None:
|
519
|
+
batch_size = _get_batch_size(args, in_axes, axis_to_in_states, axis_size)
|
520
|
+
_set_axis_env(batch_size)
|
509
521
|
cache_key = _compile_stateful_function(
|
510
522
|
stateful_fn,
|
511
523
|
(st_in_axes, in_axes),
|
@@ -518,7 +530,8 @@ def _vmap_transform(
|
|
518
530
|
rng_sets = set(rngs)
|
519
531
|
if len(rngs):
|
520
532
|
# batch size
|
521
|
-
batch_size
|
533
|
+
if batch_size is None:
|
534
|
+
batch_size = _get_batch_size(args, in_axes, axis_to_in_states, axis_size)
|
522
535
|
rng_keys = tuple(rng.split_key(batch_size) for rng in rngs)
|
523
536
|
rng_backup = tuple(rng.split_key() for rng in rngs)
|
524
537
|
else:
|
@@ -905,11 +918,11 @@ def map(
|
|
905
918
|
g = lambda _, x: ((), vmap(f)(*x))
|
906
919
|
_, scan_ys = scan(g, (), scan_xs)
|
907
920
|
if remainder_xs is None:
|
908
|
-
ys = jax.tree.map(lambda x:
|
921
|
+
ys = jax.tree.map(lambda x: _flatten(x), scan_ys)
|
909
922
|
else:
|
910
923
|
remainder_ys = vmap(f)(*remainder_xs)
|
911
924
|
ys = jax.tree.map(
|
912
|
-
lambda x, y: jax.lax.concatenate([
|
925
|
+
lambda x, y: jax.lax.concatenate([_flatten(x), y], dimension=0),
|
913
926
|
scan_ys,
|
914
927
|
remainder_ys,
|
915
928
|
)
|
@@ -919,7 +932,7 @@ def map(
|
|
919
932
|
return ys
|
920
933
|
|
921
934
|
|
922
|
-
def
|
935
|
+
def _flatten(x):
|
923
936
|
return x.reshape(-1, *x.shape[2:])
|
924
937
|
|
925
938
|
|
@@ -935,17 +948,19 @@ def _vmap_new_states_transform(
|
|
935
948
|
# -- brainstate specific arguments -- #
|
936
949
|
state_tag: str | None = None,
|
937
950
|
state_to_exclude: Filter | None = None,
|
951
|
+
in_states: Dict[int, Dict] | Any | None = None,
|
952
|
+
out_states: Dict[int, Dict] | Any | None = None,
|
938
953
|
):
|
939
|
-
|
940
954
|
# TODO: How about nested call ``vmap_new_states``?
|
941
955
|
|
942
|
-
|
943
956
|
@vmap(
|
944
957
|
in_axes=in_axes,
|
945
958
|
out_axes=out_axes,
|
946
959
|
axis_name=axis_name,
|
947
960
|
axis_size=axis_size,
|
948
961
|
spmd_axis_name=spmd_axis_name,
|
962
|
+
in_states=in_states,
|
963
|
+
out_states=out_states,
|
949
964
|
)
|
950
965
|
def new_fun(args):
|
951
966
|
# call the function
|
@@ -988,6 +1003,8 @@ def vmap_new_states(
|
|
988
1003
|
# -- brainstate specific arguments -- #
|
989
1004
|
state_tag: str | None = None,
|
990
1005
|
state_to_exclude: Filter = None,
|
1006
|
+
in_states: Dict[int, Dict] | Any | None = None,
|
1007
|
+
out_states: Dict[int, Dict] | Any | None = None,
|
991
1008
|
):
|
992
1009
|
"""
|
993
1010
|
Vectorize a function over new states created within it.
|
@@ -1019,6 +1036,8 @@ def vmap_new_states(
|
|
1019
1036
|
spmd_axis_name=spmd_axis_name,
|
1020
1037
|
state_tag=state_tag,
|
1021
1038
|
state_to_exclude=state_to_exclude,
|
1039
|
+
in_states=in_states,
|
1040
|
+
out_states=out_states,
|
1022
1041
|
)
|
1023
1042
|
else:
|
1024
1043
|
return _vmap_new_states_transform(
|
@@ -1030,4 +1049,6 @@ def vmap_new_states(
|
|
1030
1049
|
spmd_axis_name=spmd_axis_name,
|
1031
1050
|
state_tag=state_tag,
|
1032
1051
|
state_to_exclude=state_to_exclude,
|
1052
|
+
in_states=in_states,
|
1053
|
+
out_states=out_states,
|
1033
1054
|
)
|
@@ -15,12 +15,14 @@
|
|
15
15
|
|
16
16
|
from __future__ import annotations
|
17
17
|
|
18
|
+
import unittest
|
19
|
+
|
18
20
|
import jax
|
19
21
|
import jax.numpy as jnp
|
20
22
|
import numpy as np
|
21
|
-
import unittest
|
22
23
|
|
23
24
|
import brainstate as bst
|
25
|
+
import brainstate.augment
|
24
26
|
from brainstate.augment._mapping import BatchAxisError
|
25
27
|
from brainstate.augment._mapping import _remove_axis
|
26
28
|
|
@@ -264,6 +266,55 @@ class TestVmap(unittest.TestCase):
|
|
264
266
|
res2 = jax.vmap(f, axis_size=10)()
|
265
267
|
self.assertTrue(jnp.all((res2[0] == res2[1:])))
|
266
268
|
|
269
|
+
def test_axis(self):
|
270
|
+
def f(x):
|
271
|
+
return x - jax.lax.pmean(x, 'i')
|
272
|
+
|
273
|
+
r = jax.vmap(f, axis_name='i')(jnp.arange(10))
|
274
|
+
print(r)
|
275
|
+
|
276
|
+
r2 = brainstate.augment.vmap(f, axis_name='i')(jnp.arange(10))
|
277
|
+
print(r2)
|
278
|
+
self.assertTrue(jnp.allclose(r, r2))
|
279
|
+
|
280
|
+
def test_vmap_init(self):
|
281
|
+
class Foo(bst.nn.Module):
|
282
|
+
def __init__(self):
|
283
|
+
super().__init__()
|
284
|
+
self.a = bst.ParamState(jnp.arange(4))
|
285
|
+
self.b = bst.ShortTermState(jnp.arange(4))
|
286
|
+
|
287
|
+
def init_state_v1(self, *args, **kwargs):
|
288
|
+
self.c = bst.State(jnp.arange(4))
|
289
|
+
|
290
|
+
def init_state_v2(self):
|
291
|
+
self.d = bst.State(self.c.value * 2.)
|
292
|
+
|
293
|
+
foo = Foo()
|
294
|
+
|
295
|
+
@brainstate.augment.vmap_new_states(state_tag='new1', axis_size=5)
|
296
|
+
def init1():
|
297
|
+
foo.init_state_v1()
|
298
|
+
|
299
|
+
init1()
|
300
|
+
print(foo.c.value)
|
301
|
+
|
302
|
+
@brainstate.augment.vmap_new_states(state_tag='new2', axis_size=5, in_states=foo.states('new1'))
|
303
|
+
def init2():
|
304
|
+
foo.init_state_v2()
|
305
|
+
|
306
|
+
init2()
|
307
|
+
print(foo.c.value)
|
308
|
+
print(foo.d.value)
|
309
|
+
|
310
|
+
self.assertTrue(
|
311
|
+
jnp.allclose(
|
312
|
+
foo.d.value,
|
313
|
+
foo.c.value * 2.
|
314
|
+
)
|
315
|
+
)
|
316
|
+
|
317
|
+
|
267
318
|
|
268
319
|
class TestMap(unittest.TestCase):
|
269
320
|
def test_map(self):
|
@@ -753,8 +753,8 @@ def _make_jaxpr(
|
|
753
753
|
if jax.__version_info__ < (0, 5, 0):
|
754
754
|
debug_info = pe.debug_info(fun, in_tree, out_tree, True, 'make_jaxpr')
|
755
755
|
with ExitStack() as stack:
|
756
|
-
|
757
|
-
stack.enter_context(jax.core.
|
756
|
+
if axis_env is not None:
|
757
|
+
stack.enter_context(jax.core.extend_axis_env_nd(axis_env))
|
758
758
|
if jax.__version_info__ < (0, 5, 0):
|
759
759
|
jaxpr, out_type, consts = pe.trace_to_jaxpr_dynamic2(f, debug_info=debug_info)
|
760
760
|
else:
|
brainstate/functional/_spikes.py
CHANGED
brainstate/nn/_collective_ops.py
CHANGED
@@ -18,9 +18,7 @@ from __future__ import annotations
|
|
18
18
|
from collections import namedtuple
|
19
19
|
|
20
20
|
import jax
|
21
|
-
from typing import
|
22
|
-
Callable, TypeVar, Tuple, Any, Dict
|
23
|
-
)
|
21
|
+
from typing import Callable, TypeVar, Tuple, Any, Dict
|
24
22
|
|
25
23
|
from brainstate._state import catch_new_states
|
26
24
|
from brainstate._utils import set_module_as
|
@@ -103,7 +101,7 @@ def call_all_functions(
|
|
103
101
|
on each node. It respects the call order of functions if defined, and provides options for
|
104
102
|
handling cases where the specified function does not exist on a node.
|
105
103
|
|
106
|
-
Parameters
|
104
|
+
Parameters
|
107
105
|
-----------
|
108
106
|
target : T
|
109
107
|
The target module on which to call functions.
|
@@ -121,12 +119,12 @@ def call_all_functions(
|
|
121
119
|
- 'raise': Raise an exception (default)
|
122
120
|
- 'pass' or 'none': Skip the node and continue
|
123
121
|
|
124
|
-
Returns
|
122
|
+
Returns
|
125
123
|
--------
|
126
124
|
T
|
127
125
|
The target module after calling the specified function on all applicable nodes.
|
128
126
|
|
129
|
-
Raises
|
127
|
+
Raises
|
130
128
|
-------
|
131
129
|
AssertionError
|
132
130
|
If fun_name is not a string or kwargs is not a dictionary.
|
@@ -186,7 +184,7 @@ def vmap_call_all_functions(
|
|
186
184
|
This function vectorizes the process of calling a specified function across multiple instances
|
187
185
|
of the target module, effectively batching the operation.
|
188
186
|
|
189
|
-
Parameters
|
187
|
+
Parameters
|
190
188
|
-----------
|
191
189
|
target : T
|
192
190
|
The target module on which to call functions.
|
@@ -208,12 +206,12 @@ def vmap_call_all_functions(
|
|
208
206
|
- 'raise': Raise an exception (default)
|
209
207
|
- 'pass' or 'none': Skip the node and continue
|
210
208
|
|
211
|
-
Returns
|
209
|
+
Returns
|
212
210
|
--------
|
213
211
|
T
|
214
212
|
The target module after applying the vectorized function call on all applicable nodes.
|
215
213
|
|
216
|
-
Raises
|
214
|
+
Raises
|
217
215
|
-------
|
218
216
|
AssertionError
|
219
217
|
If axis_size is not specified or is not a positive integer.
|
@@ -304,7 +302,7 @@ def vmap_init_all_states(
|
|
304
302
|
This function applies vectorized mapping (vmap) to initialize states across multiple
|
305
303
|
instances of the target module, effectively batching the initialization process.
|
306
304
|
|
307
|
-
Parameters
|
305
|
+
Parameters
|
308
306
|
-----------
|
309
307
|
target : T
|
310
308
|
The target module whose states are to be initialized.
|
@@ -319,12 +317,12 @@ def vmap_init_all_states(
|
|
319
317
|
state_tag : str | None, optional
|
320
318
|
A tag to be used for catching new states.
|
321
319
|
|
322
|
-
Returns
|
320
|
+
Returns
|
323
321
|
--------
|
324
322
|
T
|
325
323
|
The target module with initialized states.
|
326
324
|
|
327
|
-
Raises
|
325
|
+
Raises
|
328
326
|
-------
|
329
327
|
AssertionError
|
330
328
|
If axis_size is not specified or is not greater than 0.
|
@@ -413,7 +411,7 @@ def vmap_reset_all_states(
|
|
413
411
|
This function applies vectorized mapping (vmap) to reset states across multiple
|
414
412
|
instances of the target module, effectively batching the reset process.
|
415
413
|
|
416
|
-
Parameters
|
414
|
+
Parameters
|
417
415
|
-----------
|
418
416
|
target : T
|
419
417
|
The target module whose states are to be reset.
|
@@ -428,12 +426,12 @@ def vmap_reset_all_states(
|
|
428
426
|
tag : str | None, optional
|
429
427
|
A tag to be used for catching new states.
|
430
428
|
|
431
|
-
Returns
|
429
|
+
Returns
|
432
430
|
--------
|
433
431
|
T
|
434
432
|
The target module with reset states.
|
435
433
|
|
436
|
-
Raises
|
434
|
+
Raises
|
437
435
|
-------
|
438
436
|
AssertionError
|
439
437
|
If axis_size is not specified or is not greater than 0.
|
@@ -486,7 +484,7 @@ def save_all_states(target: Module, **kwargs) -> Dict:
|
|
486
484
|
Args:
|
487
485
|
target: Module. The node to save its states.
|
488
486
|
|
489
|
-
Returns
|
487
|
+
Returns
|
490
488
|
Dict. The state dict for serialization.
|
491
489
|
"""
|
492
490
|
return {key: node.save_state(**kwargs) for key, node in target.nodes().items()}
|
@@ -17,8 +17,9 @@
|
|
17
17
|
|
18
18
|
from __future__ import annotations
|
19
19
|
|
20
|
+
from typing import Optional, Callable
|
21
|
+
|
20
22
|
import brainunit as u
|
21
|
-
from typing import Optional
|
22
23
|
|
23
24
|
from brainstate import init, environ
|
24
25
|
from brainstate._state import ShortTermState, HiddenState
|
@@ -54,7 +55,7 @@ class Expon(Synapse, AlignPost):
|
|
54
55
|
in_size: Size,
|
55
56
|
name: Optional[str] = None,
|
56
57
|
tau: ArrayLike = 8.0 * u.ms,
|
57
|
-
g_initializer: ArrayLike = init.ZeroInit(unit=u.mS),
|
58
|
+
g_initializer: ArrayLike | Callable = init.ZeroInit(unit=u.mS),
|
58
59
|
):
|
59
60
|
super().__init__(name=name, in_size=in_size)
|
60
61
|
|
@@ -85,7 +86,7 @@ class DualExpon(Synapse, AlignPost):
|
|
85
86
|
tau_decay: ArrayLike = 10.0 * u.ms,
|
86
87
|
tau_rise: ArrayLike = 1.0 * u.ms,
|
87
88
|
A: Optional[ArrayLike] = None,
|
88
|
-
g_initializer: ArrayLike = init.ZeroInit(unit=u.mS),
|
89
|
+
g_initializer: ArrayLike | Callable = init.ZeroInit(unit=u.mS),
|
89
90
|
):
|
90
91
|
super().__init__(name=name, in_size=in_size)
|
91
92
|
|
@@ -133,7 +134,7 @@ class Alpha(Synapse):
|
|
133
134
|
in_size: Size,
|
134
135
|
name: Optional[str] = None,
|
135
136
|
tau: ArrayLike = 8.0 * u.ms,
|
136
|
-
g_initializer: ArrayLike = init.ZeroInit(unit=u.mS),
|
137
|
+
g_initializer: ArrayLike | Callable = init.ZeroInit(unit=u.mS),
|
137
138
|
):
|
138
139
|
super().__init__(name=name, in_size=in_size)
|
139
140
|
|
@@ -321,7 +322,7 @@ class AMPA(Synapse):
|
|
321
322
|
beta: ArrayLike = 0.18 / u.ms,
|
322
323
|
T: ArrayLike = 0.5 * u.mM,
|
323
324
|
T_dur: ArrayLike = 0.5 * u.ms,
|
324
|
-
g_initializer: ArrayLike = init.ZeroInit(),
|
325
|
+
g_initializer: ArrayLike | Callable = init.ZeroInit(),
|
325
326
|
):
|
326
327
|
super().__init__(name=name, in_size=in_size)
|
327
328
|
|
@@ -394,5 +395,14 @@ class GABAa(AMPA):
|
|
394
395
|
beta: ArrayLike = 0.18 / u.ms,
|
395
396
|
T: ArrayLike = 1.0 * u.mM,
|
396
397
|
T_dur: ArrayLike = 1.0 * u.ms,
|
398
|
+
g_initializer: ArrayLike | Callable = init.ZeroInit(),
|
397
399
|
):
|
398
|
-
super().__init__(
|
400
|
+
super().__init__(
|
401
|
+
alpha=alpha,
|
402
|
+
beta=beta,
|
403
|
+
T=T,
|
404
|
+
T_dur=T_dur,
|
405
|
+
name=name,
|
406
|
+
in_size=in_size,
|
407
|
+
g_initializer=g_initializer
|
408
|
+
)
|
brainstate/nn/_exp_euler.py
CHANGED
@@ -49,13 +49,13 @@ def exp_euler_step(
|
|
49
49
|
should have units of ( [X]/\sqrt{[T]} ).
|
50
50
|
|
51
51
|
Args:
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
|
52
|
+
fun: Callable. The function to be solved.
|
53
|
+
diffusion: Callable. The diffusion function.
|
54
|
+
*args: The input arguments.
|
55
|
+
drift: Callable. The drift function.
|
56
56
|
|
57
57
|
Returns:
|
58
|
-
|
58
|
+
The one-step solution of the ODE.
|
59
59
|
"""
|
60
60
|
assert callable(fn), 'The input function should be callable.'
|
61
61
|
assert len(args) > 0, 'The input arguments should not be empty.'
|
@@ -201,9 +201,11 @@ class SGD(_WeightDecayOptimizer):
|
|
201
201
|
def update(self, grads: dict):
|
202
202
|
lr = self.lr()
|
203
203
|
weight_values, grad_values = to_same_dict_tree(self.param_states, grads)
|
204
|
-
updates = jax.tree.map(
|
205
|
-
|
206
|
-
|
204
|
+
updates = jax.tree.map(
|
205
|
+
functools.partial(_sgd, lr=lr, weight_decay=self.weight_decay),
|
206
|
+
weight_values,
|
207
|
+
grad_values
|
208
|
+
)
|
207
209
|
self.param_states.assign_values(updates)
|
208
210
|
self.lr.step_call()
|
209
211
|
|
@@ -324,12 +326,16 @@ class MomentumNesterov(_WeightDecayOptimizer):
|
|
324
326
|
def update(self, grads: dict):
|
325
327
|
lr = self.lr()
|
326
328
|
states_values, grad_values, momentum_values = to_same_dict_tree(self.param_states, grads, self.momentum_states)
|
327
|
-
momentum_values = jax.tree.map(
|
328
|
-
|
329
|
-
|
330
|
-
|
331
|
-
|
332
|
-
|
329
|
+
momentum_values = jax.tree.map(
|
330
|
+
lambda mv, gv: self.momentum * mv - lr * gv,
|
331
|
+
momentum_values,
|
332
|
+
grad_values
|
333
|
+
)
|
334
|
+
weight_values = jax.tree.map(
|
335
|
+
functools.partial(_sgd, lr=lr, weight_decay=self.weight_decay),
|
336
|
+
states_values,
|
337
|
+
momentum_values
|
338
|
+
)
|
333
339
|
self.param_states.assign_values(weight_values)
|
334
340
|
self.momentum_states.assign_values(momentum_values)
|
335
341
|
self.lr.step_call()
|
@@ -388,11 +394,21 @@ class Adagrad(_WeightDecayOptimizer):
|
|
388
394
|
def update(self, grads: dict):
|
389
395
|
lr = self.lr()
|
390
396
|
cache_values, grad_values, weight_values = to_same_dict_tree(self.cache_states, grads, self.param_states)
|
391
|
-
cache_values = jax.tree.map(
|
392
|
-
|
393
|
-
|
394
|
-
|
395
|
-
|
397
|
+
cache_values = jax.tree.map(
|
398
|
+
lambda cv, gv: cv + gv ** 2,
|
399
|
+
cache_values,
|
400
|
+
grad_values
|
401
|
+
)
|
402
|
+
updates = jax.tree.map(
|
403
|
+
lambda cv, gv: lr * gv / jnp.sqrt(cv + self.epsilon),
|
404
|
+
cache_values,
|
405
|
+
grad_values
|
406
|
+
)
|
407
|
+
weight_values = jax.tree.map(
|
408
|
+
functools.partial(_sgd, weight_decay=self.weight_decay),
|
409
|
+
weight_values,
|
410
|
+
updates
|
411
|
+
)
|
396
412
|
self.cache_states.assign_values(cache_values)
|
397
413
|
self.param_states.assign_values(weight_values)
|
398
414
|
self.lr.step_call()
|
@@ -605,13 +621,28 @@ class Adam(_WeightDecayOptimizer):
|
|
605
621
|
lr = lr / (1 - self.beta1 ** (self.lr.last_epoch.value + 2))
|
606
622
|
lr = lr * jnp.sqrt(1 - self.beta2 ** (self.lr.last_epoch.value + 2))
|
607
623
|
weight_values, grad_values, m1_values, m2_values = to_same_dict_tree(
|
608
|
-
self.param_states, grads, self.m1_states, self.m2_states
|
609
|
-
|
610
|
-
|
611
|
-
|
612
|
-
|
613
|
-
|
614
|
-
|
624
|
+
self.param_states, grads, self.m1_states, self.m2_states
|
625
|
+
)
|
626
|
+
m1_values = jax.tree.map(
|
627
|
+
lambda m1, gv: self.beta1 * m1 + (1 - self.beta1) * gv,
|
628
|
+
m1_values,
|
629
|
+
grad_values
|
630
|
+
)
|
631
|
+
m2_values = jax.tree.map(
|
632
|
+
lambda m2, gv: self.beta2 * m2 + (1 - self.beta2) * gv ** 2,
|
633
|
+
m2_values,
|
634
|
+
grad_values
|
635
|
+
)
|
636
|
+
update = jax.tree.map(
|
637
|
+
lambda m1, m2: lr * m1 / (jnp.sqrt(m2) + self.eps),
|
638
|
+
m1_values,
|
639
|
+
m2_values
|
640
|
+
)
|
641
|
+
weight_values = jax.tree.map(
|
642
|
+
functools.partial(_sgd, weight_decay=self.weight_decay),
|
643
|
+
weight_values,
|
644
|
+
update
|
645
|
+
)
|
615
646
|
self.param_states.assign_values(weight_values)
|
616
647
|
self.m1_states.assign_values(m1_values)
|
617
648
|
self.m2_states.assign_values(m2_values)
|
{brainstate-0.1.0.post20250218.dist-info → brainstate-0.1.0.post20250315.dist-info}/METADATA
RENAMED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: brainstate
|
3
|
-
Version: 0.1.0.post20250218
|
3
|
+
Version: 0.1.0.post20250315
|
4
4
|
Summary: A ``State``-based Transformation System for Program Compilation and Augmentation.
|
5
5
|
Home-page: https://github.com/chaobrain/brainstate
|
6
6
|
Author: BrainState Developers
|
@@ -62,6 +62,7 @@ Requires-Dist: jaxlib[tpu] ; extra == 'tpu'
|
|
62
62
|
<a href="https://badge.fury.io/py/brainstate"><img alt="PyPI version" src="https://badge.fury.io/py/brainstate.svg"></a>
|
63
63
|
<a href="https://github.com/chaobrain/brainstate/actions/workflows/CI.yml"><img alt="Continuous Integration" src="https://github.com/chaobrain/brainstate/actions/workflows/CI.yml/badge.svg"></a>
|
64
64
|
<a href="https://pepy.tech/projects/brainstate"><img src="https://static.pepy.tech/badge/brainstate" alt="PyPI Downloads"></a>
|
65
|
+
<a href="https://doi.org/10.5281/zenodo.14970015"><img src="https://zenodo.org/badge/811300394.svg" alt="DOI"></a>
|
65
66
|
</p>
|
66
67
|
|
67
68
|
|
@@ -81,8 +82,8 @@ The official documentation is hosted on Read the Docs: [https://brainstate.readt
|
|
81
82
|
|
82
83
|
|
83
84
|
|
84
|
-
## See also the
|
85
|
+
## See also the brain modeling ecosystem
|
85
86
|
|
86
|
-
We are building the
|
87
|
+
We are building the brain modeling ecosystem: https://brainmodeling.readthedocs.io/
|
87
88
|
|
88
89
|
|
@@ -10,12 +10,12 @@ brainstate/surrogate.py,sha256=wWYw-TxaFxHVneXuHjWD1UtTcOTk3XRSnhRtUkt_Hb8,53580
|
|
10
10
|
brainstate/transform.py,sha256=vZWzO4F7qTsXL_SiVQPlTz0l9b_hRo9D-igETfgCTy0,758
|
11
11
|
brainstate/typing.py,sha256=988gX1tvwtyYnYjmej90OaRxoMoBIPO0-DSrXXGxojM,10523
|
12
12
|
brainstate/augment/__init__.py,sha256=Q9-JIwQ1FNn8VLS1MA9MrSylbvUjWSw98whrI3NIuKo,1229
|
13
|
-
brainstate/augment/_autograd.py,sha256=
|
13
|
+
brainstate/augment/_autograd.py,sha256=hfDoa2HbkRn-InOS0yOcb6gEZ2DLNqtWA133P8-hvIo,30138
|
14
14
|
brainstate/augment/_autograd_test.py,sha256=2wCC8aUcDp2IHgF7wr1GK5HwWfELXni5PpA-082azuU,44058
|
15
15
|
brainstate/augment/_eval_shape.py,sha256=jgsS197Nizehr9A2nGaQPE7NuNujhFhmR3J96hTicX8,3890
|
16
16
|
brainstate/augment/_eval_shape_test.py,sha256=LFOJx7CWltmRLXdGY175UebLwtEMz2CzJ_gLqMZsJTw,1393
|
17
|
-
brainstate/augment/_mapping.py,sha256=
|
18
|
-
brainstate/augment/_mapping_test.py,sha256
|
17
|
+
brainstate/augment/_mapping.py,sha256=BPwpD7jX4xRNl4BdAsKGoF45MKbmEF9Lyyp11pJucIg,43356
|
18
|
+
brainstate/augment/_mapping_test.py,sha256=-4HJXmJw_6SD9dQnHTBjgYVuq6VTVjz0xpc9v2CJVNw,13414
|
19
19
|
brainstate/augment/_random.py,sha256=ikRzNoDDE2BkARajDsBhNlngCUrghzGSZUDmEGvVors,5386
|
20
20
|
brainstate/compile/__init__.py,sha256=fQtG316MLkeeu1Ssp54Kghw1PwbGK5gNq9yRVJu0wjA,1474
|
21
21
|
brainstate/compile/_ad_checkpoint.py,sha256=3wv-f89oo94XeWwRV5LcRot0Nz7xTk5_PdjEDyUMsoo,9394
|
@@ -30,7 +30,7 @@ brainstate/compile/_loop_collect_return.py,sha256=-LsP7fkHmAyGnDOKa3BxxYOEWe8M2J
|
|
30
30
|
brainstate/compile/_loop_collect_return_test.py,sha256=D9RQ5RyQHkqBr4nmSK-yM_uge3EC6uVm_Dzy42g3vtg,1802
|
31
31
|
brainstate/compile/_loop_no_collection.py,sha256=2OEVtv5XztOx-e0focZ1UnWkXmFzmDskjHJXuVXmuhA,7587
|
32
32
|
brainstate/compile/_loop_no_collection_test.py,sha256=oStB1CSG_iLp9sHdXd1hJNFvlxbzjck9Iy4sABoJDj4,1419
|
33
|
-
brainstate/compile/_make_jaxpr.py,sha256=
|
33
|
+
brainstate/compile/_make_jaxpr.py,sha256=8iV8XyvkMH3n3wbEWZAgZtbrUxryljwQJD6o5DMW9Lc,33189
|
34
34
|
brainstate/compile/_make_jaxpr_test.py,sha256=fZe3K4RHFLmMAeXZoFZ5RyxgXvncTcuMQdjmOROJtKU,4365
|
35
35
|
brainstate/compile/_progress_bar.py,sha256=3Z3OVcc5sl9FK9Fkt813l20MNzEfa6UZ9lJrvSgXTCU,7522
|
36
36
|
brainstate/compile/_unvmap.py,sha256=uCvQjvb8J7kT0kalX576mrAPvQuCh_W76EPdgZ53kTM,4230
|
@@ -40,7 +40,7 @@ brainstate/functional/_activations.py,sha256=VmCU9HOKWbysxuJFBN-JsShS4loNMG_E6IX
|
|
40
40
|
brainstate/functional/_activations_test.py,sha256=-bCijTvo4Wo_P283RYKYMPcTLsjhu5i2X9ySdf1ayEY,13034
|
41
41
|
brainstate/functional/_normalization.py,sha256=L3S4DIF1EztrlE4_KHX7j_m6Mw0mpAwnx5UTAX6YYBU,2603
|
42
42
|
brainstate/functional/_others.py,sha256=eBV43WqQsDvHkkwX0xbqCRoIoJlngMFLSUKgleH2dt0,1735
|
43
|
-
brainstate/functional/_spikes.py,sha256=
|
43
|
+
brainstate/functional/_spikes.py,sha256=7FTfCfEN1mjlY-EULzCisk7_NOmxZPj-mp-ODncW7R0,7087
|
44
44
|
brainstate/graph/__init__.py,sha256=noo4TjBg6iEhjjwk0sAGUhR7Ge-z8Vnc2rLYUvnqttw,1295
|
45
45
|
brainstate/graph/_graph_node.py,sha256=JE1Tc0mK3nJFWUFzXE53MWWiEYiXJO5VdqZEYKbXlw0,6872
|
46
46
|
brainstate/graph/_graph_node_test.py,sha256=BFGfdzZFDHI0XK7hHotSVWKt3em1taGvn8FHF9NCXx8,2702
|
@@ -54,10 +54,10 @@ brainstate/init/_random_inits_test.py,sha256=lBL2RQdBSZ88Zqz4IMdbHJMvDi7ooZq6caC
|
|
54
54
|
brainstate/init/_regular_inits.py,sha256=DmVMajugfyYFNUMzgFdDKMvbBu9hMWxkfDd-50uhoLg,3187
|
55
55
|
brainstate/init/_regular_inits_test.py,sha256=tJl4aOkclllJIfKzJTbc0cfYCw2SoBsx8_G123RnqbU,1842
|
56
56
|
brainstate/nn/__init__.py,sha256=ar1hDUYbSO6oadMpbuS9FWZvZB_iyFzM8CwMK-RNDzM,1823
|
57
|
-
brainstate/nn/_collective_ops.py,sha256=
|
57
|
+
brainstate/nn/_collective_ops.py,sha256=NI9BT-908TbIlXLMjbWsPyI5YLZD_cCkSKGeOY-qO60,17512
|
58
58
|
brainstate/nn/_collective_ops_test.py,sha256=yW7NNYsGFglFRFkqVlpGSY6WLnU-h8GlK6wCmG5jtRc,1189
|
59
59
|
brainstate/nn/_common.py,sha256=XQw0i0sH3Y_qUwHSMC7G9VQnDj-RuuTh1Ul-xRIPxxc,7136
|
60
|
-
brainstate/nn/_exp_euler.py,sha256=
|
60
|
+
brainstate/nn/_exp_euler.py,sha256=s-Z_cT_oYvCvE-OaXuUidIxQs3KOy1pzkx1lwtfPo00,3529
|
61
61
|
brainstate/nn/_exp_euler_test.py,sha256=kvPf009DMYtla2uedKVKrPTHDyMTBepjlfsk5vDHqhI,1240
|
62
62
|
brainstate/nn/_module.py,sha256=vrukVI0ylbymzilh9BZtb-d9dnsBsykqanUNTx9Eb6Y,12844
|
63
63
|
brainstate/nn/_module_test.py,sha256=UrVA85fo0KVFN9ApPkxkRcvtXEskWOXPzZIBa4JSFo0,8891
|
@@ -66,7 +66,7 @@ brainstate/nn/metrics.py,sha256=p7eVwd5y8r0N5rMws-zOS_KaZCLOMdrXyQvLnoJeq1w,1473
|
|
66
66
|
brainstate/nn/_dyn_impl/__init__.py,sha256=Oazar7h89dp1WA2Vx4Tj7gCBhxJKH4LAUEABkBEG7vU,1462
|
67
67
|
brainstate/nn/_dyn_impl/_dynamics_neuron.py,sha256=mcDxVZlk56NAEkR6xcE74hOZ9up8Rua4SvKEeAhJKU4,10925
|
68
68
|
brainstate/nn/_dyn_impl/_dynamics_neuron_test.py,sha256=_wPp6UvWVZI9EYba-DWL_JZXyMxm0-SHDkZHI8lKp8w,6315
|
69
|
-
brainstate/nn/_dyn_impl/_dynamics_synapse.py,sha256=
|
69
|
+
brainstate/nn/_dyn_impl/_dynamics_synapse.py,sha256=xXMNlDWX0tQ9N0zORfT4DFKoEXtrsRbeetqOq-bYovs,15518
|
70
70
|
brainstate/nn/_dyn_impl/_dynamics_synapse_test.py,sha256=VUDMlHNcyeqHrBd1eAXg_VD0HCSg5C-eqMmcJVzYcJA,4979
|
71
71
|
brainstate/nn/_dyn_impl/_inputs.py,sha256=ubM5Z2q0gXpJ2M5Das3A5MJpFOorVomfW6-079mqJ9k,12935
|
72
72
|
brainstate/nn/_dyn_impl/_projection_alignpost.py,sha256=PNC1Tzx_SF2DHAHeJCufXzO_Q4qLoBpWABI45B3GRuc,876
|
@@ -102,7 +102,7 @@ brainstate/optim/_lr_scheduler.py,sha256=Ua_H3VWUt9QZ0pHtGpnq2wrqsTOUidXJ3VDz9s-
|
|
102
102
|
brainstate/optim/_lr_scheduler_test.py,sha256=W0F1eHb9S4seE468c26owcJIWTtgNhZYrOi2GrysVNI,1774
|
103
103
|
brainstate/optim/_optax_optimizer.py,sha256=SuXV_xUBfhOw1_C2J5TIpy3dXDtI9VJFaSMLy8hLcXE,5312
|
104
104
|
brainstate/optim/_optax_optimizer_test.py,sha256=J7zvmeSaWmTlfbpjx1ILb9cSC5qlj1wn4H2QMw9jUY0,1760
|
105
|
-
brainstate/optim/_sgd_optimizer.py,sha256=
|
105
|
+
brainstate/optim/_sgd_optimizer.py,sha256=RDQLrWsJFeEpTc93toDbf4hXnHmfY0hpTs38z1AZYPY,46144
|
106
106
|
brainstate/random/__init__.py,sha256=c5q-RC3grRIjx-HBb2IhKZpi_xzbFmUUxzRAzqfREic,1045
|
107
107
|
brainstate/random/_rand_funs.py,sha256=fmyaBTl_P7M63RRp5V2dOR_ttXQ0vB2qU2C8RqfkjY0,137680
|
108
108
|
brainstate/random/_rand_funs_test.py,sha256=Nhy5gXuJ2ld9u8CTCqU1V94FPm0TvYQ-oMy2bP_CZ7I,19436
|
@@ -121,8 +121,8 @@ brainstate/util/_pretty_table.py,sha256=NM_6VAW6oL9jojsK0-RkQGHnDzLy_fn_hgzl5R8o
|
|
121
121
|
brainstate/util/_scaling.py,sha256=pc_eM_SZVwkY65I4tJh1ODiHNCoEhsfFXl2zBK0PLAg,7562
|
122
122
|
brainstate/util/_struct.py,sha256=F5GfFURITAIYTwf17_xypkZU1wvoL4dUCviPnr_eCtw,17515
|
123
123
|
brainstate/util/filter.py,sha256=Zw0H42NwAi2P7dBr3ISv2VpkB5jqoWnV4Kpd61gq66o,14126
|
124
|
-
brainstate-0.1.0.
|
125
|
-
brainstate-0.1.0.
|
126
|
-
brainstate-0.1.0.
|
127
|
-
brainstate-0.1.0.
|
128
|
-
brainstate-0.1.0.
|
124
|
+
brainstate-0.1.0.post20250315.dist-info/LICENSE,sha256=VZe9u1jgUL2eCY6ZPOYgdb8KCblCHt8ECdbtJid6e1s,11550
|
125
|
+
brainstate-0.1.0.post20250315.dist-info/METADATA,sha256=8gOEdv6PiXBLr_gAvx70Yik7G7XxidMVPPLOLx3ndPc,3689
|
126
|
+
brainstate-0.1.0.post20250315.dist-info/WHEEL,sha256=bb2Ot9scclHKMOLDEHY6B2sicWOgugjFKaJsT7vwMQo,110
|
127
|
+
brainstate-0.1.0.post20250315.dist-info/top_level.txt,sha256=eQbGgKn0ptx7FDWuua0V0wr4K1VHi2iOUCYo3fUQBRA,11
|
128
|
+
brainstate-0.1.0.post20250315.dist-info/RECORD,,
|
File without changes
|
File without changes
|
{brainstate-0.1.0.post20250218.dist-info → brainstate-0.1.0.post20250315.dist-info}/top_level.txt
RENAMED
File without changes
|