brainstate 0.1.0.post20250211__py2.py3-none-any.whl → 0.1.0.post20250216__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/_state.py +875 -93
- brainstate/_state_test.py +1 -3
- brainstate/augment/__init__.py +2 -2
- brainstate/augment/_autograd.py +257 -115
- brainstate/augment/_autograd_test.py +2 -3
- brainstate/augment/_eval_shape.py +3 -4
- brainstate/augment/_mapping.py +582 -62
- brainstate/augment/_mapping_test.py +114 -30
- brainstate/augment/_random.py +61 -7
- brainstate/compile/_ad_checkpoint.py +2 -3
- brainstate/compile/_conditions.py +4 -5
- brainstate/compile/_conditions_test.py +1 -2
- brainstate/compile/_error_if.py +1 -2
- brainstate/compile/_error_if_test.py +1 -2
- brainstate/compile/_jit.py +23 -16
- brainstate/compile/_jit_test.py +1 -2
- brainstate/compile/_loop_collect_return.py +18 -10
- brainstate/compile/_loop_collect_return_test.py +1 -1
- brainstate/compile/_loop_no_collection.py +5 -5
- brainstate/compile/_make_jaxpr.py +23 -21
- brainstate/compile/_make_jaxpr_test.py +1 -2
- brainstate/compile/_progress_bar.py +1 -2
- brainstate/compile/_unvmap.py +1 -0
- brainstate/compile/_util.py +4 -2
- brainstate/environ.py +4 -4
- brainstate/environ_test.py +1 -2
- brainstate/functional/_activations.py +1 -2
- brainstate/functional/_activations_test.py +1 -1
- brainstate/functional/_normalization.py +1 -2
- brainstate/functional/_others.py +1 -2
- brainstate/functional/_spikes.py +136 -20
- brainstate/graph/_graph_node.py +2 -43
- brainstate/graph/_graph_operation.py +4 -20
- brainstate/graph/_graph_operation_test.py +3 -4
- brainstate/init/_base.py +1 -2
- brainstate/init/_generic.py +1 -2
- brainstate/nn/__init__.py +4 -0
- brainstate/nn/_collective_ops.py +351 -48
- brainstate/nn/_collective_ops_test.py +36 -0
- brainstate/nn/_common.py +194 -0
- brainstate/nn/_dyn_impl/_dynamics_neuron.py +1 -2
- brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +1 -2
- brainstate/nn/_dyn_impl/_dynamics_synapse.py +1 -2
- brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +1 -2
- brainstate/nn/_dyn_impl/_inputs.py +1 -2
- brainstate/nn/_dyn_impl/_rate_rnns.py +1 -2
- brainstate/nn/_dyn_impl/_rate_rnns_test.py +1 -2
- brainstate/nn/_dyn_impl/_readout.py +2 -3
- brainstate/nn/_dyn_impl/_readout_test.py +1 -2
- brainstate/nn/_dynamics/_dynamics_base.py +2 -3
- brainstate/nn/_dynamics/_dynamics_base_test.py +1 -2
- brainstate/nn/_dynamics/_state_delay.py +3 -3
- brainstate/nn/_dynamics/_synouts_test.py +1 -2
- brainstate/nn/_elementwise/_dropout.py +6 -7
- brainstate/nn/_elementwise/_dropout_test.py +1 -2
- brainstate/nn/_elementwise/_elementwise.py +1 -2
- brainstate/nn/_exp_euler.py +1 -2
- brainstate/nn/_exp_euler_test.py +1 -2
- brainstate/nn/_interaction/_conv.py +1 -2
- brainstate/nn/_interaction/_conv_test.py +1 -0
- brainstate/nn/_interaction/_linear.py +1 -2
- brainstate/nn/_interaction/_linear_test.py +1 -2
- brainstate/nn/_interaction/_normalizations.py +1 -2
- brainstate/nn/_interaction/_poolings.py +3 -4
- brainstate/nn/_module.py +63 -19
- brainstate/nn/_module_test.py +1 -2
- brainstate/nn/metrics.py +3 -4
- brainstate/optim/_lr_scheduler.py +1 -2
- brainstate/optim/_lr_scheduler_test.py +2 -3
- brainstate/optim/_optax_optimizer_test.py +1 -2
- brainstate/optim/_sgd_optimizer.py +2 -3
- brainstate/random/_rand_funs.py +1 -2
- brainstate/random/_rand_funs_test.py +2 -3
- brainstate/random/_rand_seed.py +2 -3
- brainstate/random/_rand_seed_test.py +1 -2
- brainstate/random/_rand_state.py +3 -4
- brainstate/surrogate.py +183 -35
- brainstate/transform.py +0 -3
- brainstate/typing.py +28 -25
- brainstate/util/__init__.py +9 -7
- brainstate/util/_caller.py +1 -2
- brainstate/util/_error.py +27 -0
- brainstate/util/_others.py +60 -15
- brainstate/util/{_dict.py → _pretty_pytree.py} +108 -29
- brainstate/util/{_dict_test.py → _pretty_pytree_test.py} +1 -2
- brainstate/util/_pretty_repr.py +128 -10
- brainstate/util/_pretty_table.py +2900 -0
- brainstate/util/_struct.py +11 -11
- brainstate/util/filter.py +472 -0
- {brainstate-0.1.0.post20250211.dist-info → brainstate-0.1.0.post20250216.dist-info}/METADATA +2 -2
- brainstate-0.1.0.post20250216.dist-info/RECORD +127 -0
- brainstate/util/_filter.py +0 -178
- brainstate-0.1.0.post20250211.dist-info/RECORD +0 -124
- {brainstate-0.1.0.post20250211.dist-info → brainstate-0.1.0.post20250216.dist-info}/LICENSE +0 -0
- {brainstate-0.1.0.post20250211.dist-info → brainstate-0.1.0.post20250216.dist-info}/WHEEL +0 -0
- {brainstate-0.1.0.post20250211.dist-info → brainstate-0.1.0.post20250216.dist-info}/top_level.txt +0 -0
@@ -14,11 +14,10 @@
|
|
14
14
|
# ==============================================================================
|
15
15
|
from __future__ import annotations
|
16
16
|
|
17
|
-
from typing import Union, Optional, Sequence, Callable
|
18
|
-
|
19
17
|
import brainunit as u
|
20
18
|
import jax
|
21
19
|
import numpy as np
|
20
|
+
from typing import Union, Optional, Sequence, Callable
|
22
21
|
|
23
22
|
from brainstate import environ, init, random
|
24
23
|
from brainstate._state import ShortTermState
|
@@ -17,9 +17,8 @@
|
|
17
17
|
|
18
18
|
from __future__ import annotations
|
19
19
|
|
20
|
-
from typing import Callable, Union
|
21
|
-
|
22
20
|
import jax.numpy as jnp
|
21
|
+
from typing import Callable, Union
|
23
22
|
|
24
23
|
from brainstate import random, init, functional
|
25
24
|
from brainstate._state import HiddenState, ParamState
|
@@ -17,11 +17,10 @@
|
|
17
17
|
|
18
18
|
from __future__ import annotations
|
19
19
|
|
20
|
-
import numbers
|
21
|
-
from typing import Callable
|
22
|
-
|
23
20
|
import brainunit as u
|
24
21
|
import jax
|
22
|
+
import numbers
|
23
|
+
from typing import Callable
|
25
24
|
|
26
25
|
from brainstate import environ, init, surrogate
|
27
26
|
from brainstate._state import HiddenState, ParamState
|
@@ -34,10 +34,9 @@ For handling the delays:
|
|
34
34
|
"""
|
35
35
|
from __future__ import annotations
|
36
36
|
|
37
|
-
from typing import Any, Dict, Callable, Hashable, Optional, Union, TypeVar, TYPE_CHECKING
|
38
|
-
|
39
37
|
import brainunit as u
|
40
38
|
import numpy as np
|
39
|
+
from typing import Any, Dict, Callable, Hashable, Optional, Union, TypeVar, TYPE_CHECKING
|
41
40
|
|
42
41
|
from brainstate import environ
|
43
42
|
from brainstate._state import State
|
@@ -420,7 +419,7 @@ class Dynamics(Module):
|
|
420
419
|
else:
|
421
420
|
raise TypeError(f'The input {dyn} should be an instance of {Dynamics} or a delayed initializer.')
|
422
421
|
|
423
|
-
def
|
422
|
+
def __pretty_repr_item__(self, name, value):
|
424
423
|
if name in ['_in_size', '_out_size', '_name', '_mode',
|
425
424
|
'_before_updates', '_after_updates', '_current_inputs', '_delta_inputs']:
|
426
425
|
return (name, value) if value is None else (name[1:], value) # skip the first `_`
|
@@ -16,14 +16,14 @@
|
|
16
16
|
from __future__ import annotations
|
17
17
|
|
18
18
|
import math
|
19
|
-
import numbers
|
20
|
-
from functools import partial
|
21
|
-
from typing import Optional, Dict, Callable, Union, Sequence
|
22
19
|
|
23
20
|
import brainunit as u
|
24
21
|
import jax
|
25
22
|
import jax.numpy as jnp
|
23
|
+
import numbers
|
26
24
|
import numpy as np
|
25
|
+
from functools import partial
|
26
|
+
from typing import Optional, Dict, Callable, Union, Sequence
|
27
27
|
|
28
28
|
from brainstate import environ
|
29
29
|
from brainstate._state import ShortTermState, State
|
@@ -16,11 +16,10 @@
|
|
16
16
|
|
17
17
|
from __future__ import annotations
|
18
18
|
|
19
|
-
from functools import partial
|
20
|
-
from typing import Optional, Sequence
|
21
|
-
|
22
19
|
import brainunit as u
|
23
20
|
import jax.numpy as jnp
|
21
|
+
from functools import partial
|
22
|
+
from typing import Optional, Sequence
|
24
23
|
|
25
24
|
from brainstate import random, environ, init
|
26
25
|
from brainstate._state import ShortTermState
|
@@ -72,11 +71,11 @@ class Dropout(ElementWiseBlock):
|
|
72
71
|
for dim in self.broadcast_dims:
|
73
72
|
broadcast_shape[dim] = 1
|
74
73
|
keep_mask = random.bernoulli(self.prob, broadcast_shape)
|
75
|
-
keep_mask =
|
76
|
-
return
|
74
|
+
keep_mask = u.math.broadcast_to(keep_mask, x.shape)
|
75
|
+
return u.math.where(
|
77
76
|
keep_mask,
|
78
|
-
|
79
|
-
|
77
|
+
u.math.asarray(x / self.prob, dtype=dtype),
|
78
|
+
u.math.asarray(0., dtype=dtype)
|
80
79
|
)
|
81
80
|
else:
|
82
81
|
return x
|
@@ -17,11 +17,10 @@
|
|
17
17
|
|
18
18
|
from __future__ import annotations
|
19
19
|
|
20
|
-
from typing import Optional
|
21
|
-
|
22
20
|
import brainunit as u
|
23
21
|
import jax.numpy as jnp
|
24
22
|
import jax.typing
|
23
|
+
from typing import Optional
|
25
24
|
|
26
25
|
from brainstate import random, functional as F
|
27
26
|
from brainstate._state import ParamState
|
brainstate/nn/_exp_euler.py
CHANGED
@@ -16,10 +16,9 @@
|
|
16
16
|
|
17
17
|
from __future__ import annotations
|
18
18
|
|
19
|
-
from typing import Callable
|
20
|
-
|
21
19
|
import brainunit as u
|
22
20
|
import jax.numpy as jnp
|
21
|
+
from typing import Callable
|
23
22
|
|
24
23
|
from brainstate import environ, random
|
25
24
|
from brainstate.augment import vector_grad
|
brainstate/nn/_exp_euler_test.py
CHANGED
@@ -18,10 +18,9 @@
|
|
18
18
|
from __future__ import annotations
|
19
19
|
|
20
20
|
import collections.abc
|
21
|
-
from typing import Callable, Tuple, Union, Sequence, Optional, TypeVar
|
22
|
-
|
23
21
|
import jax
|
24
22
|
import jax.numpy as jnp
|
23
|
+
from typing import Callable, Tuple, Union, Sequence, Optional, TypeVar
|
25
24
|
|
26
25
|
from brainstate import init, functional
|
27
26
|
from brainstate._state import ParamState
|
@@ -17,10 +17,9 @@
|
|
17
17
|
|
18
18
|
from __future__ import annotations
|
19
19
|
|
20
|
-
from typing import Callable, Union, Optional
|
21
|
-
|
22
20
|
import brainunit as u
|
23
21
|
import jax.numpy as jnp
|
22
|
+
from typing import Callable, Union, Optional
|
24
23
|
|
25
24
|
from brainstate import init, functional
|
26
25
|
from brainstate._state import ParamState
|
@@ -17,10 +17,9 @@
|
|
17
17
|
|
18
18
|
from __future__ import annotations
|
19
19
|
|
20
|
-
from typing import Callable, Union, Sequence, Optional, Any, Dict
|
21
|
-
|
22
20
|
import jax
|
23
21
|
import jax.numpy as jnp
|
22
|
+
from typing import Callable, Union, Sequence, Optional, Any
|
24
23
|
|
25
24
|
from brainstate import environ, init
|
26
25
|
from brainstate._state import ParamState, BatchState
|
@@ -17,14 +17,13 @@
|
|
17
17
|
|
18
18
|
from __future__ import annotations
|
19
19
|
|
20
|
-
import functools
|
21
|
-
from typing import Sequence, Optional
|
22
|
-
from typing import Union, Tuple, Callable, List
|
23
|
-
|
24
20
|
import brainunit as u
|
21
|
+
import functools
|
25
22
|
import jax
|
26
23
|
import jax.numpy as jnp
|
27
24
|
import numpy as np
|
25
|
+
from typing import Sequence, Optional
|
26
|
+
from typing import Union, Tuple, Callable, List
|
28
27
|
|
29
28
|
from brainstate import environ
|
30
29
|
from brainstate.nn._module import Module
|
brainstate/nn/_module.py
CHANGED
@@ -27,16 +27,15 @@ The basic classes include:
|
|
27
27
|
"""
|
28
28
|
from __future__ import annotations
|
29
29
|
|
30
|
-
import warnings
|
31
|
-
from typing import Sequence, Optional, Tuple, Union, TYPE_CHECKING
|
32
|
-
|
33
30
|
import numpy as np
|
31
|
+
import warnings
|
32
|
+
from typing import Sequence, Optional, Tuple, Union, TYPE_CHECKING, Callable
|
34
33
|
|
35
34
|
from brainstate._state import State
|
36
35
|
from brainstate.graph import Node, states, nodes, flatten
|
37
36
|
from brainstate.mixin import ParamDescriber, ParamDesc
|
38
37
|
from brainstate.typing import PathParts
|
39
|
-
from brainstate.util import FlattedDict, NestedDict
|
38
|
+
from brainstate.util import FlattedDict, NestedDict, BrainStateError
|
40
39
|
|
41
40
|
# maximum integer
|
42
41
|
max_int = np.iinfo(np.int32).max
|
@@ -226,7 +225,7 @@ class Module(Node, ParamDesc):
|
|
226
225
|
"""
|
227
226
|
pass
|
228
227
|
|
229
|
-
def
|
228
|
+
def __pretty_repr_item__(self, name, value):
|
230
229
|
if name in ['_in_size', '_out_size', '_name']:
|
231
230
|
return (name, value) if value is None else (name[1:], value) # skip the first `_`
|
232
231
|
return name, value
|
@@ -288,7 +287,7 @@ class Sequential(Module):
|
|
288
287
|
in_size = first.out_size
|
289
288
|
self.layers.append(first)
|
290
289
|
for module in layers:
|
291
|
-
module, in_size = _format_module(module, in_size)
|
290
|
+
module, in_size = self._format_module(module, in_size)
|
292
291
|
self.layers.append(module)
|
293
292
|
|
294
293
|
# the input and output shape
|
@@ -301,7 +300,14 @@ class Sequential(Module):
|
|
301
300
|
"""Update function of a sequential model.
|
302
301
|
"""
|
303
302
|
for m in self.layers:
|
304
|
-
|
303
|
+
try:
|
304
|
+
x = m(x)
|
305
|
+
except Exception as e:
|
306
|
+
raise BrainStateError(
|
307
|
+
f'The module \n'
|
308
|
+
f'{m}\n'
|
309
|
+
f'failed to update with input {x}\n'
|
310
|
+
) from e
|
305
311
|
return x
|
306
312
|
|
307
313
|
def __getitem__(self, key: Union[int, slice]):
|
@@ -314,16 +320,54 @@ class Sequential(Module):
|
|
314
320
|
else:
|
315
321
|
raise KeyError(f'Unknown type of key: {type(key)}')
|
316
322
|
|
323
|
+
def append(self, layer: Callable):
|
324
|
+
"""
|
325
|
+
Append a layer to the sequential model.
|
326
|
+
|
327
|
+
This method adds a new layer to the end of the sequential model. The layer can be
|
328
|
+
either a Module instance, an ElementWiseBlock instance, or a callable function. If the
|
329
|
+
layer is a callable function, it will be wrapped in an ElementWiseBlock instance.
|
330
|
+
|
331
|
+
Parameters:
|
332
|
+
----------
|
333
|
+
layer : Callable
|
334
|
+
The layer to be appended to the sequential model. It can be a Module instance,
|
335
|
+
an ElementWiseBlock instance, or a callable function.
|
336
|
+
|
337
|
+
Raises:
|
338
|
+
-------
|
339
|
+
ValueError
|
340
|
+
If the sequential model is empty and the first layer is a callable function.
|
317
341
|
|
318
|
-
|
319
|
-
|
320
|
-
|
321
|
-
|
322
|
-
|
323
|
-
|
324
|
-
|
325
|
-
|
326
|
-
|
327
|
-
|
328
|
-
|
329
|
-
|
342
|
+
Returns:
|
343
|
+
--------
|
344
|
+
None
|
345
|
+
The method does not return any value. It modifies the sequential model by adding
|
346
|
+
the new layer to the end.
|
347
|
+
"""
|
348
|
+
if len(self.layers) == 0:
|
349
|
+
raise ValueError('The first layer should be a module, not a function.')
|
350
|
+
module, in_size = self._format_module(layer, self.out_size)
|
351
|
+
self.layers.append(module)
|
352
|
+
self.out_size = in_size
|
353
|
+
|
354
|
+
def _format_module(self, module, in_size):
|
355
|
+
if isinstance(module, ParamDescriber):
|
356
|
+
if in_size is None:
|
357
|
+
raise ValueError(
|
358
|
+
'The input size should be specified. '
|
359
|
+
f'Please set the in_size attribute of the previous module: \n'
|
360
|
+
f'{self.layers[-1]}'
|
361
|
+
)
|
362
|
+
module = module(in_size=in_size)
|
363
|
+
assert isinstance(module, Module), 'The module should be an instance of Module.'
|
364
|
+
out_size = module.out_size
|
365
|
+
elif isinstance(module, ElementWiseBlock):
|
366
|
+
out_size = in_size
|
367
|
+
elif isinstance(module, Module):
|
368
|
+
out_size = module.out_size
|
369
|
+
elif callable(module):
|
370
|
+
out_size = in_size
|
371
|
+
else:
|
372
|
+
raise TypeError(f"Unsupported type {type(module)}. ")
|
373
|
+
return module, out_size
|
brainstate/nn/_module_test.py
CHANGED
brainstate/nn/metrics.py
CHANGED
@@ -16,13 +16,12 @@
|
|
16
16
|
|
17
17
|
from __future__ import annotations
|
18
18
|
|
19
|
-
import typing as tp
|
20
|
-
from dataclasses import dataclass
|
21
|
-
from functools import partial
|
22
|
-
|
23
19
|
import jax
|
24
20
|
import jax.numpy as jnp
|
25
21
|
import numpy as np
|
22
|
+
import typing as tp
|
23
|
+
from dataclasses import dataclass
|
24
|
+
from functools import partial
|
26
25
|
|
27
26
|
from brainstate._state import State
|
28
27
|
|
@@ -16,11 +16,10 @@
|
|
16
16
|
# -*- coding: utf-8 -*-
|
17
17
|
from __future__ import annotations
|
18
18
|
|
19
|
-
from typing import Sequence, Union
|
20
|
-
|
21
19
|
import jax
|
22
20
|
import jax.numpy as jnp
|
23
21
|
import numpy as np
|
22
|
+
from typing import Sequence, Union
|
24
23
|
|
25
24
|
from brainstate import environ
|
26
25
|
from brainstate._state import State, LongTermState
|
@@ -15,9 +15,8 @@
|
|
15
15
|
|
16
16
|
from __future__ import annotations
|
17
17
|
|
18
|
-
import unittest
|
19
|
-
|
20
18
|
import jax.numpy as jnp
|
19
|
+
import unittest
|
21
20
|
|
22
21
|
import brainstate as bst
|
23
22
|
|
@@ -37,7 +36,7 @@ class TestMultiStepLR(unittest.TestCase):
|
|
37
36
|
self.assertTrue(jnp.allclose(r, 0.0001))
|
38
37
|
|
39
38
|
def test2(self):
|
40
|
-
lr = bst.
|
39
|
+
lr = bst.compile.jit(bst.optim.MultiStepLR(0.1, [10, 20, 30], gamma=0.1))
|
41
40
|
for i in range(40):
|
42
41
|
r = lr(i)
|
43
42
|
if i < 10:
|
@@ -16,12 +16,11 @@
|
|
16
16
|
# -*- coding: utf-8 -*-
|
17
17
|
from __future__ import annotations
|
18
18
|
|
19
|
-
import functools
|
20
|
-
from typing import Union, Dict, Optional, Tuple, Any, TypeVar
|
21
|
-
|
22
19
|
import brainunit as u
|
20
|
+
import functools
|
23
21
|
import jax
|
24
22
|
import jax.numpy as jnp
|
23
|
+
from typing import Union, Dict, Optional, Tuple, Any, TypeVar
|
25
24
|
|
26
25
|
from brainstate import environ
|
27
26
|
from brainstate._state import State, LongTermState, StateDictManager
|
brainstate/random/_rand_funs.py
CHANGED
@@ -17,9 +17,8 @@
|
|
17
17
|
# -*- coding: utf-8 -*-
|
18
18
|
from __future__ import annotations
|
19
19
|
|
20
|
-
from typing import Optional
|
21
|
-
|
22
20
|
import numpy as np
|
21
|
+
from typing import Optional
|
23
22
|
|
24
23
|
from brainstate.typing import DTypeLike, Size, SeedOrKey
|
25
24
|
from ._rand_state import RandomState, DEFAULT
|
@@ -15,13 +15,12 @@
|
|
15
15
|
|
16
16
|
from __future__ import annotations
|
17
17
|
|
18
|
-
import platform
|
19
|
-
import unittest
|
20
|
-
|
21
18
|
import jax.numpy as jnp
|
22
19
|
import jax.random as jr
|
23
20
|
import numpy as np
|
21
|
+
import platform
|
24
22
|
import pytest
|
23
|
+
import unittest
|
25
24
|
|
26
25
|
import brainstate as bst
|
27
26
|
|
brainstate/random/_rand_seed.py
CHANGED
@@ -14,11 +14,10 @@
|
|
14
14
|
# ==============================================================================
|
15
15
|
from __future__ import annotations
|
16
16
|
|
17
|
-
from contextlib import contextmanager
|
18
|
-
from typing import Optional
|
19
|
-
|
20
17
|
import jax
|
21
18
|
import numpy as np
|
19
|
+
from contextlib import contextmanager
|
20
|
+
from typing import Optional
|
22
21
|
|
23
22
|
from brainstate.typing import SeedOrKey
|
24
23
|
from ._rand_state import RandomState, DEFAULT, use_prng_key
|
brainstate/random/_rand_state.py
CHANGED
@@ -16,17 +16,16 @@
|
|
16
16
|
# -*- coding: utf-8 -*-
|
17
17
|
from __future__ import annotations
|
18
18
|
|
19
|
-
from functools import partial
|
20
|
-
from operator import index
|
21
|
-
from typing import Optional
|
22
|
-
|
23
19
|
import brainunit as u
|
24
20
|
import jax
|
25
21
|
import jax.numpy as jnp
|
26
22
|
import jax.random as jr
|
27
23
|
import numpy as np
|
24
|
+
from functools import partial
|
28
25
|
from jax import jit, vmap
|
29
26
|
from jax import lax, core, dtypes
|
27
|
+
from operator import index
|
28
|
+
from typing import Optional
|
30
29
|
|
31
30
|
from brainstate import environ
|
32
31
|
from brainstate._state import State
|