brainstate 0.1.0.post20250218__py2.py3-none-any.whl → 0.1.0.post20250222__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/augment/_autograd.py +13 -3
- brainstate/augment/_mapping.py +17 -6
- brainstate/augment/_mapping_test.py +12 -0
- brainstate/compile/_make_jaxpr.py +2 -2
- brainstate/functional/_spikes.py +1 -1
- brainstate/optim/_sgd_optimizer.py +52 -21
- {brainstate-0.1.0.post20250218.dist-info → brainstate-0.1.0.post20250222.dist-info}/METADATA +1 -1
- {brainstate-0.1.0.post20250218.dist-info → brainstate-0.1.0.post20250222.dist-info}/RECORD +11 -11
- {brainstate-0.1.0.post20250218.dist-info → brainstate-0.1.0.post20250222.dist-info}/LICENSE +0 -0
- {brainstate-0.1.0.post20250218.dist-info → brainstate-0.1.0.post20250222.dist-info}/WHEEL +0 -0
- {brainstate-0.1.0.post20250218.dist-info → brainstate-0.1.0.post20250222.dist-info}/top_level.txt +0 -0
brainstate/augment/_autograd.py
CHANGED
@@ -29,11 +29,12 @@ The wrapped gradient transformations here are made possible by using the followi
|
|
29
29
|
|
30
30
|
from __future__ import annotations
|
31
31
|
|
32
|
-
import brainunit as u
|
33
|
-
import jax
|
34
32
|
from functools import wraps, partial
|
35
33
|
from typing import Union, Callable, Dict, Sequence, Optional, Any, Tuple, TypeVar, Iterator
|
36
34
|
|
35
|
+
import brainunit as u
|
36
|
+
import jax
|
37
|
+
|
37
38
|
from brainstate._state import State
|
38
39
|
from brainstate._utils import set_module_as
|
39
40
|
from brainstate.compile._make_jaxpr import StatefulFunction
|
@@ -195,6 +196,7 @@ class GradientTransform(PrettyRepr):
|
|
195
196
|
grad_states = {k: v for k, v in grad_states.items()}
|
196
197
|
self._grad_states, self._grad_tree = jax.tree.flatten(grad_states)
|
197
198
|
self._grad_state_ids = [id(v) for v in self._grad_states]
|
199
|
+
self._grad_id_to_state = {id(v): v for v in self._grad_states}
|
198
200
|
if any(not isinstance(v, State) for v in self._grad_states):
|
199
201
|
raise TypeError("All grad_states must be State instances.")
|
200
202
|
|
@@ -250,12 +252,20 @@ class GradientTransform(PrettyRepr):
|
|
250
252
|
"""
|
251
253
|
grad_vals = dict()
|
252
254
|
other_vals = dict()
|
255
|
+
all_ids = set(self._grad_state_ids)
|
253
256
|
for st in state_trace.states:
|
254
257
|
id_ = id(st)
|
255
|
-
if id_ in
|
258
|
+
if id_ in all_ids:
|
256
259
|
grad_vals[id_] = st.value
|
260
|
+
all_ids.remove(id_)
|
257
261
|
else:
|
258
262
|
other_vals[id_] = st.value
|
263
|
+
if len(all_ids):
|
264
|
+
err = f"Some states are not found in the state trace when performing gradient transformations.\n "
|
265
|
+
for i, id_ in enumerate(all_ids):
|
266
|
+
st = self._grad_id_to_state[id_]
|
267
|
+
st.raise_error_with_source_info(ValueError(err + str(st)))
|
268
|
+
|
259
269
|
return grad_vals, other_vals
|
260
270
|
|
261
271
|
def _merge_state_vals(self, grad_vals: Dict, other_vals: Dict, state_trace):
|
brainstate/augment/_mapping.py
CHANGED
@@ -16,8 +16,6 @@
|
|
16
16
|
from __future__ import annotations
|
17
17
|
|
18
18
|
import functools
|
19
|
-
import jax
|
20
|
-
from jax.interpreters.batching import BatchTracer
|
21
19
|
from typing import (
|
22
20
|
Any,
|
23
21
|
TypeVar,
|
@@ -32,6 +30,9 @@ from typing import (
|
|
32
30
|
List
|
33
31
|
)
|
34
32
|
|
33
|
+
import jax
|
34
|
+
from jax.interpreters.batching import BatchTracer
|
35
|
+
|
35
36
|
from brainstate._state import State, catch_new_states
|
36
37
|
from brainstate.compile import scan, StatefulFunction
|
37
38
|
from brainstate.random import RandomState, DEFAULT
|
@@ -378,8 +379,15 @@ def _vmap_transform(
|
|
378
379
|
# call the function
|
379
380
|
return f(*args)
|
380
381
|
|
382
|
+
def _set_axis_env(batch_size):
|
383
|
+
axis_env = None if axis_name is None else [(axis_name, batch_size)]
|
384
|
+
stateful_fn.axis_env = axis_env
|
385
|
+
|
381
386
|
# stateful function
|
382
|
-
stateful_fn = StatefulFunction(
|
387
|
+
stateful_fn = StatefulFunction(
|
388
|
+
_vmap_fn_for_compilation,
|
389
|
+
name='vmap',
|
390
|
+
)
|
383
391
|
|
384
392
|
@functools.wraps(f)
|
385
393
|
def new_fn_for_vmap(
|
@@ -506,6 +514,10 @@ def _vmap_transform(
|
|
506
514
|
st_in_axes = 0
|
507
515
|
|
508
516
|
# compile stateful function
|
517
|
+
batch_size = None
|
518
|
+
if axis_name is not None:
|
519
|
+
batch_size = _get_batch_size(args, in_axes, axis_to_in_states, axis_size)
|
520
|
+
_set_axis_env(batch_size)
|
509
521
|
cache_key = _compile_stateful_function(
|
510
522
|
stateful_fn,
|
511
523
|
(st_in_axes, in_axes),
|
@@ -518,7 +530,8 @@ def _vmap_transform(
|
|
518
530
|
rng_sets = set(rngs)
|
519
531
|
if len(rngs):
|
520
532
|
# batch size
|
521
|
-
batch_size
|
533
|
+
if batch_size is None:
|
534
|
+
batch_size = _get_batch_size(args, in_axes, axis_to_in_states, axis_size)
|
522
535
|
rng_keys = tuple(rng.split_key(batch_size) for rng in rngs)
|
523
536
|
rng_backup = tuple(rng.split_key() for rng in rngs)
|
524
537
|
else:
|
@@ -936,10 +949,8 @@ def _vmap_new_states_transform(
|
|
936
949
|
state_tag: str | None = None,
|
937
950
|
state_to_exclude: Filter | None = None,
|
938
951
|
):
|
939
|
-
|
940
952
|
# TODO: How about nested call ``vmap_new_states``?
|
941
953
|
|
942
|
-
|
943
954
|
@vmap(
|
944
955
|
in_axes=in_axes,
|
945
956
|
out_axes=out_axes,
|
@@ -21,6 +21,7 @@ import numpy as np
|
|
21
21
|
import unittest
|
22
22
|
|
23
23
|
import brainstate as bst
|
24
|
+
import brainstate.augment
|
24
25
|
from brainstate.augment._mapping import BatchAxisError
|
25
26
|
from brainstate.augment._mapping import _remove_axis
|
26
27
|
|
@@ -264,6 +265,17 @@ class TestVmap(unittest.TestCase):
|
|
264
265
|
res2 = jax.vmap(f, axis_size=10)()
|
265
266
|
self.assertTrue(jnp.all((res2[0] == res2[1:])))
|
266
267
|
|
268
|
+
def test_axis(self):
|
269
|
+
def f(x):
|
270
|
+
return x - jax.lax.pmean(x, 'i')
|
271
|
+
r = jax.vmap(f, axis_name='i')(jnp.arange(10))
|
272
|
+
print(r)
|
273
|
+
|
274
|
+
r2 = brainstate.augment.vmap(f, axis_name='i')(jnp.arange(10))
|
275
|
+
print(r2)
|
276
|
+
self.assertTrue(jnp.allclose(r, r2))
|
277
|
+
|
278
|
+
|
267
279
|
|
268
280
|
class TestMap(unittest.TestCase):
|
269
281
|
def test_map(self):
|
@@ -753,8 +753,8 @@ def _make_jaxpr(
|
|
753
753
|
if jax.__version_info__ < (0, 5, 0):
|
754
754
|
debug_info = pe.debug_info(fun, in_tree, out_tree, True, 'make_jaxpr')
|
755
755
|
with ExitStack() as stack:
|
756
|
-
|
757
|
-
stack.enter_context(jax.core.
|
756
|
+
if axis_env is not None:
|
757
|
+
stack.enter_context(jax.core.extend_axis_env_nd(axis_env))
|
758
758
|
if jax.__version_info__ < (0, 5, 0):
|
759
759
|
jaxpr, out_type, consts = pe.trace_to_jaxpr_dynamic2(f, debug_info=debug_info)
|
760
760
|
else:
|
brainstate/functional/_spikes.py
CHANGED
@@ -201,9 +201,11 @@ class SGD(_WeightDecayOptimizer):
|
|
201
201
|
def update(self, grads: dict):
|
202
202
|
lr = self.lr()
|
203
203
|
weight_values, grad_values = to_same_dict_tree(self.param_states, grads)
|
204
|
-
updates = jax.tree.map(
|
205
|
-
|
206
|
-
|
204
|
+
updates = jax.tree.map(
|
205
|
+
functools.partial(_sgd, lr=lr, weight_decay=self.weight_decay),
|
206
|
+
weight_values,
|
207
|
+
grad_values
|
208
|
+
)
|
207
209
|
self.param_states.assign_values(updates)
|
208
210
|
self.lr.step_call()
|
209
211
|
|
@@ -324,12 +326,16 @@ class MomentumNesterov(_WeightDecayOptimizer):
|
|
324
326
|
def update(self, grads: dict):
|
325
327
|
lr = self.lr()
|
326
328
|
states_values, grad_values, momentum_values = to_same_dict_tree(self.param_states, grads, self.momentum_states)
|
327
|
-
momentum_values = jax.tree.map(
|
328
|
-
|
329
|
-
|
330
|
-
|
331
|
-
|
332
|
-
|
329
|
+
momentum_values = jax.tree.map(
|
330
|
+
lambda mv, gv: self.momentum * mv - lr * gv,
|
331
|
+
momentum_values,
|
332
|
+
grad_values
|
333
|
+
)
|
334
|
+
weight_values = jax.tree.map(
|
335
|
+
functools.partial(_sgd, lr=lr, weight_decay=self.weight_decay),
|
336
|
+
states_values,
|
337
|
+
momentum_values
|
338
|
+
)
|
333
339
|
self.param_states.assign_values(weight_values)
|
334
340
|
self.momentum_states.assign_values(momentum_values)
|
335
341
|
self.lr.step_call()
|
@@ -388,11 +394,21 @@ class Adagrad(_WeightDecayOptimizer):
|
|
388
394
|
def update(self, grads: dict):
|
389
395
|
lr = self.lr()
|
390
396
|
cache_values, grad_values, weight_values = to_same_dict_tree(self.cache_states, grads, self.param_states)
|
391
|
-
cache_values = jax.tree.map(
|
392
|
-
|
393
|
-
|
394
|
-
|
395
|
-
|
397
|
+
cache_values = jax.tree.map(
|
398
|
+
lambda cv, gv: cv + gv ** 2,
|
399
|
+
cache_values,
|
400
|
+
grad_values
|
401
|
+
)
|
402
|
+
updates = jax.tree.map(
|
403
|
+
lambda cv, gv: lr * gv / jnp.sqrt(cv + self.epsilon),
|
404
|
+
cache_values,
|
405
|
+
grad_values
|
406
|
+
)
|
407
|
+
weight_values = jax.tree.map(
|
408
|
+
functools.partial(_sgd, weight_decay=self.weight_decay),
|
409
|
+
weight_values,
|
410
|
+
updates
|
411
|
+
)
|
396
412
|
self.cache_states.assign_values(cache_values)
|
397
413
|
self.param_states.assign_values(weight_values)
|
398
414
|
self.lr.step_call()
|
@@ -605,13 +621,28 @@ class Adam(_WeightDecayOptimizer):
|
|
605
621
|
lr = lr / (1 - self.beta1 ** (self.lr.last_epoch.value + 2))
|
606
622
|
lr = lr * jnp.sqrt(1 - self.beta2 ** (self.lr.last_epoch.value + 2))
|
607
623
|
weight_values, grad_values, m1_values, m2_values = to_same_dict_tree(
|
608
|
-
self.param_states, grads, self.m1_states, self.m2_states
|
609
|
-
|
610
|
-
|
611
|
-
|
612
|
-
|
613
|
-
|
614
|
-
|
624
|
+
self.param_states, grads, self.m1_states, self.m2_states
|
625
|
+
)
|
626
|
+
m1_values = jax.tree.map(
|
627
|
+
lambda m1, gv: self.beta1 * m1 + (1 - self.beta1) * gv,
|
628
|
+
m1_values,
|
629
|
+
grad_values
|
630
|
+
)
|
631
|
+
m2_values = jax.tree.map(
|
632
|
+
lambda m2, gv: self.beta2 * m2 + (1 - self.beta2) * gv ** 2,
|
633
|
+
m2_values,
|
634
|
+
grad_values
|
635
|
+
)
|
636
|
+
update = jax.tree.map(
|
637
|
+
lambda m1, m2: lr * m1 / (jnp.sqrt(m2) + self.eps),
|
638
|
+
m1_values,
|
639
|
+
m2_values
|
640
|
+
)
|
641
|
+
weight_values = jax.tree.map(
|
642
|
+
functools.partial(_sgd, weight_decay=self.weight_decay),
|
643
|
+
weight_values,
|
644
|
+
update
|
645
|
+
)
|
615
646
|
self.param_states.assign_values(weight_values)
|
616
647
|
self.m1_states.assign_values(m1_values)
|
617
648
|
self.m2_states.assign_values(m2_values)
|
{brainstate-0.1.0.post20250218.dist-info → brainstate-0.1.0.post20250222.dist-info}/METADATA
RENAMED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: brainstate
|
3
|
-
Version: 0.1.0.
|
3
|
+
Version: 0.1.0.post20250222
|
4
4
|
Summary: A ``State``-based Transformation System for Program Compilation and Augmentation.
|
5
5
|
Home-page: https://github.com/chaobrain/brainstate
|
6
6
|
Author: BrainState Developers
|
@@ -10,12 +10,12 @@ brainstate/surrogate.py,sha256=wWYw-TxaFxHVneXuHjWD1UtTcOTk3XRSnhRtUkt_Hb8,53580
|
|
10
10
|
brainstate/transform.py,sha256=vZWzO4F7qTsXL_SiVQPlTz0l9b_hRo9D-igETfgCTy0,758
|
11
11
|
brainstate/typing.py,sha256=988gX1tvwtyYnYjmej90OaRxoMoBIPO0-DSrXXGxojM,10523
|
12
12
|
brainstate/augment/__init__.py,sha256=Q9-JIwQ1FNn8VLS1MA9MrSylbvUjWSw98whrI3NIuKo,1229
|
13
|
-
brainstate/augment/_autograd.py,sha256=
|
13
|
+
brainstate/augment/_autograd.py,sha256=hfDoa2HbkRn-InOS0yOcb6gEZ2DLNqtWA133P8-hvIo,30138
|
14
14
|
brainstate/augment/_autograd_test.py,sha256=2wCC8aUcDp2IHgF7wr1GK5HwWfELXni5PpA-082azuU,44058
|
15
15
|
brainstate/augment/_eval_shape.py,sha256=jgsS197Nizehr9A2nGaQPE7NuNujhFhmR3J96hTicX8,3890
|
16
16
|
brainstate/augment/_eval_shape_test.py,sha256=LFOJx7CWltmRLXdGY175UebLwtEMz2CzJ_gLqMZsJTw,1393
|
17
|
-
brainstate/augment/_mapping.py,sha256=
|
18
|
-
brainstate/augment/_mapping_test.py,sha256=
|
17
|
+
brainstate/augment/_mapping.py,sha256=X6C7IcYYXqDkYlWRXE_pdZIcCXvnltyNdESBoEm2fms,42940
|
18
|
+
brainstate/augment/_mapping_test.py,sha256=M7pcXnRwBe2LyJh7VsWROGLR3N2bRfMUU1wj0ivbjNQ,12368
|
19
19
|
brainstate/augment/_random.py,sha256=ikRzNoDDE2BkARajDsBhNlngCUrghzGSZUDmEGvVors,5386
|
20
20
|
brainstate/compile/__init__.py,sha256=fQtG316MLkeeu1Ssp54Kghw1PwbGK5gNq9yRVJu0wjA,1474
|
21
21
|
brainstate/compile/_ad_checkpoint.py,sha256=3wv-f89oo94XeWwRV5LcRot0Nz7xTk5_PdjEDyUMsoo,9394
|
@@ -30,7 +30,7 @@ brainstate/compile/_loop_collect_return.py,sha256=-LsP7fkHmAyGnDOKa3BxxYOEWe8M2J
|
|
30
30
|
brainstate/compile/_loop_collect_return_test.py,sha256=D9RQ5RyQHkqBr4nmSK-yM_uge3EC6uVm_Dzy42g3vtg,1802
|
31
31
|
brainstate/compile/_loop_no_collection.py,sha256=2OEVtv5XztOx-e0focZ1UnWkXmFzmDskjHJXuVXmuhA,7587
|
32
32
|
brainstate/compile/_loop_no_collection_test.py,sha256=oStB1CSG_iLp9sHdXd1hJNFvlxbzjck9Iy4sABoJDj4,1419
|
33
|
-
brainstate/compile/_make_jaxpr.py,sha256=
|
33
|
+
brainstate/compile/_make_jaxpr.py,sha256=8iV8XyvkMH3n3wbEWZAgZtbrUxryljwQJD6o5DMW9Lc,33189
|
34
34
|
brainstate/compile/_make_jaxpr_test.py,sha256=fZe3K4RHFLmMAeXZoFZ5RyxgXvncTcuMQdjmOROJtKU,4365
|
35
35
|
brainstate/compile/_progress_bar.py,sha256=3Z3OVcc5sl9FK9Fkt813l20MNzEfa6UZ9lJrvSgXTCU,7522
|
36
36
|
brainstate/compile/_unvmap.py,sha256=uCvQjvb8J7kT0kalX576mrAPvQuCh_W76EPdgZ53kTM,4230
|
@@ -40,7 +40,7 @@ brainstate/functional/_activations.py,sha256=VmCU9HOKWbysxuJFBN-JsShS4loNMG_E6IX
|
|
40
40
|
brainstate/functional/_activations_test.py,sha256=-bCijTvo4Wo_P283RYKYMPcTLsjhu5i2X9ySdf1ayEY,13034
|
41
41
|
brainstate/functional/_normalization.py,sha256=L3S4DIF1EztrlE4_KHX7j_m6Mw0mpAwnx5UTAX6YYBU,2603
|
42
42
|
brainstate/functional/_others.py,sha256=eBV43WqQsDvHkkwX0xbqCRoIoJlngMFLSUKgleH2dt0,1735
|
43
|
-
brainstate/functional/_spikes.py,sha256=
|
43
|
+
brainstate/functional/_spikes.py,sha256=7FTfCfEN1mjlY-EULzCisk7_NOmxZPj-mp-ODncW7R0,7087
|
44
44
|
brainstate/graph/__init__.py,sha256=noo4TjBg6iEhjjwk0sAGUhR7Ge-z8Vnc2rLYUvnqttw,1295
|
45
45
|
brainstate/graph/_graph_node.py,sha256=JE1Tc0mK3nJFWUFzXE53MWWiEYiXJO5VdqZEYKbXlw0,6872
|
46
46
|
brainstate/graph/_graph_node_test.py,sha256=BFGfdzZFDHI0XK7hHotSVWKt3em1taGvn8FHF9NCXx8,2702
|
@@ -102,7 +102,7 @@ brainstate/optim/_lr_scheduler.py,sha256=Ua_H3VWUt9QZ0pHtGpnq2wrqsTOUidXJ3VDz9s-
|
|
102
102
|
brainstate/optim/_lr_scheduler_test.py,sha256=W0F1eHb9S4seE468c26owcJIWTtgNhZYrOi2GrysVNI,1774
|
103
103
|
brainstate/optim/_optax_optimizer.py,sha256=SuXV_xUBfhOw1_C2J5TIpy3dXDtI9VJFaSMLy8hLcXE,5312
|
104
104
|
brainstate/optim/_optax_optimizer_test.py,sha256=J7zvmeSaWmTlfbpjx1ILb9cSC5qlj1wn4H2QMw9jUY0,1760
|
105
|
-
brainstate/optim/_sgd_optimizer.py,sha256=
|
105
|
+
brainstate/optim/_sgd_optimizer.py,sha256=RDQLrWsJFeEpTc93toDbf4hXnHmfY0hpTs38z1AZYPY,46144
|
106
106
|
brainstate/random/__init__.py,sha256=c5q-RC3grRIjx-HBb2IhKZpi_xzbFmUUxzRAzqfREic,1045
|
107
107
|
brainstate/random/_rand_funs.py,sha256=fmyaBTl_P7M63RRp5V2dOR_ttXQ0vB2qU2C8RqfkjY0,137680
|
108
108
|
brainstate/random/_rand_funs_test.py,sha256=Nhy5gXuJ2ld9u8CTCqU1V94FPm0TvYQ-oMy2bP_CZ7I,19436
|
@@ -121,8 +121,8 @@ brainstate/util/_pretty_table.py,sha256=NM_6VAW6oL9jojsK0-RkQGHnDzLy_fn_hgzl5R8o
|
|
121
121
|
brainstate/util/_scaling.py,sha256=pc_eM_SZVwkY65I4tJh1ODiHNCoEhsfFXl2zBK0PLAg,7562
|
122
122
|
brainstate/util/_struct.py,sha256=F5GfFURITAIYTwf17_xypkZU1wvoL4dUCviPnr_eCtw,17515
|
123
123
|
brainstate/util/filter.py,sha256=Zw0H42NwAi2P7dBr3ISv2VpkB5jqoWnV4Kpd61gq66o,14126
|
124
|
-
brainstate-0.1.0.
|
125
|
-
brainstate-0.1.0.
|
126
|
-
brainstate-0.1.0.
|
127
|
-
brainstate-0.1.0.
|
128
|
-
brainstate-0.1.0.
|
124
|
+
brainstate-0.1.0.post20250222.dist-info/LICENSE,sha256=VZe9u1jgUL2eCY6ZPOYgdb8KCblCHt8ECdbtJid6e1s,11550
|
125
|
+
brainstate-0.1.0.post20250222.dist-info/METADATA,sha256=Rzs6V20AVpOxQhtd8EJLidBkF7UJriVF4g2zkr0S92M,3585
|
126
|
+
brainstate-0.1.0.post20250222.dist-info/WHEEL,sha256=bb2Ot9scclHKMOLDEHY6B2sicWOgugjFKaJsT7vwMQo,110
|
127
|
+
brainstate-0.1.0.post20250222.dist-info/top_level.txt,sha256=eQbGgKn0ptx7FDWuua0V0wr4K1VHi2iOUCYo3fUQBRA,11
|
128
|
+
brainstate-0.1.0.post20250222.dist-info/RECORD,,
|
File without changes
|
File without changes
|
{brainstate-0.1.0.post20250218.dist-info → brainstate-0.1.0.post20250222.dist-info}/top_level.txt
RENAMED
File without changes
|