brainstate 0.1.1__py2.py3-none-any.whl → 0.1.3__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +1 -1
- brainstate/_compatible_import.py +12 -9
- brainstate/_state.py +1 -1
- brainstate/augment/_autograd_test.py +132 -133
- brainstate/augment/_eval_shape_test.py +7 -9
- brainstate/augment/_mapping_test.py +75 -76
- brainstate/compile/_ad_checkpoint_test.py +6 -8
- brainstate/compile/_conditions_test.py +35 -36
- brainstate/compile/_error_if_test.py +10 -13
- brainstate/compile/_loop_collect_return_test.py +7 -9
- brainstate/compile/_loop_no_collection_test.py +7 -8
- brainstate/compile/_make_jaxpr.py +29 -14
- brainstate/compile/_make_jaxpr_test.py +20 -20
- brainstate/functional/_activations_test.py +61 -61
- brainstate/graph/_graph_node_test.py +16 -18
- brainstate/graph/_graph_operation_test.py +154 -156
- brainstate/init/_random_inits_test.py +20 -21
- brainstate/init/_regular_inits_test.py +4 -5
- brainstate/mixin.py +1 -14
- brainstate/nn/__init__.py +81 -17
- brainstate/nn/_collective_ops_test.py +8 -8
- brainstate/nn/{_interaction/_conv.py → _conv.py} +1 -1
- brainstate/nn/{_interaction/_conv_test.py → _conv_test.py} +31 -33
- brainstate/nn/{_dynamics/_state_delay.py → _delay.py} +6 -2
- brainstate/nn/{_elementwise/_dropout.py → _dropout.py} +1 -1
- brainstate/nn/{_elementwise/_dropout_test.py → _dropout_test.py} +9 -9
- brainstate/nn/{_dynamics/_dynamics_base.py → _dynamics.py} +139 -18
- brainstate/nn/{_dynamics/_dynamics_base_test.py → _dynamics_test.py} +14 -15
- brainstate/nn/{_elementwise/_elementwise.py → _elementwise.py} +1 -1
- brainstate/nn/_elementwise_test.py +169 -0
- brainstate/nn/{_interaction/_embedding.py → _embedding.py} +1 -1
- brainstate/nn/_exp_euler_test.py +5 -6
- brainstate/nn/{_event/_fixedprob_mv.py → _fixedprob_mv.py} +1 -1
- brainstate/nn/{_event/_fixedprob_mv_test.py → _fixedprob_mv_test.py} +0 -1
- brainstate/nn/{_dyn_impl/_inputs.py → _inputs.py} +4 -5
- brainstate/nn/{_interaction/_linear.py → _linear.py} +2 -5
- brainstate/nn/{_event/_linear_mv.py → _linear_mv.py} +1 -1
- brainstate/nn/{_event/_linear_mv_test.py → _linear_mv_test.py} +0 -1
- brainstate/nn/{_interaction/_linear_test.py → _linear_test.py} +15 -17
- brainstate/nn/{_event/__init__.py → _ltp.py} +7 -5
- brainstate/nn/_module_test.py +34 -37
- brainstate/nn/{_dyn_impl/_dynamics_neuron.py → _neuron.py} +2 -2
- brainstate/nn/{_dyn_impl/_dynamics_neuron_test.py → _neuron_test.py} +18 -19
- brainstate/nn/{_interaction/_normalizations.py → _normalizations.py} +1 -1
- brainstate/nn/{_interaction/_normalizations_test.py → _normalizations_test.py} +10 -12
- brainstate/nn/{_interaction/_poolings.py → _poolings.py} +1 -1
- brainstate/nn/{_interaction/_poolings_test.py → _poolings_test.py} +20 -22
- brainstate/nn/{_dynamics/_projection_base.py → _projection.py} +35 -3
- brainstate/nn/{_dyn_impl/_rate_rnns.py → _rate_rnns.py} +2 -2
- brainstate/nn/{_dyn_impl/_rate_rnns_test.py → _rate_rnns_test.py} +6 -7
- brainstate/nn/{_dyn_impl/_readout.py → _readout.py} +3 -3
- brainstate/nn/{_dyn_impl/_readout_test.py → _readout_test.py} +9 -10
- brainstate/nn/_stp.py +236 -0
- brainstate/nn/{_dyn_impl/_dynamics_synapse.py → _synapse.py} +17 -206
- brainstate/nn/{_dyn_impl/_dynamics_synapse_test.py → _synapse_test.py} +9 -10
- brainstate/nn/_synaptic_projection.py +133 -0
- brainstate/nn/{_dynamics/_synouts.py → _synouts.py} +4 -1
- brainstate/nn/{_dynamics/_synouts_test.py → _synouts_test.py} +4 -5
- brainstate/optim/_lr_scheduler_test.py +3 -3
- brainstate/optim/_optax_optimizer_test.py +8 -9
- brainstate/random/_rand_funs_test.py +183 -184
- brainstate/random/_rand_seed_test.py +10 -12
- {brainstate-0.1.1.dist-info → brainstate-0.1.3.dist-info}/METADATA +1 -1
- brainstate-0.1.3.dist-info/RECORD +131 -0
- brainstate/nn/_dyn_impl/__init__.py +0 -42
- brainstate/nn/_dynamics/__init__.py +0 -37
- brainstate/nn/_elementwise/__init__.py +0 -22
- brainstate/nn/_elementwise/_elementwise_test.py +0 -171
- brainstate/nn/_interaction/__init__.py +0 -41
- brainstate-0.1.1.dist-info/RECORD +0 -133
- {brainstate-0.1.1.dist-info → brainstate-0.1.3.dist-info}/LICENSE +0 -0
- {brainstate-0.1.1.dist-info → brainstate-0.1.3.dist-info}/WHEEL +0 -0
- {brainstate-0.1.1.dist-info → brainstate-0.1.3.dist-info}/top_level.txt +0 -0
brainstate/__init__.py
CHANGED
brainstate/_compatible_import.py
CHANGED
@@ -19,7 +19,7 @@
|
|
19
19
|
import importlib.util
|
20
20
|
from contextlib import contextmanager
|
21
21
|
from functools import partial
|
22
|
-
from typing import Iterable, Hashable, TypeVar, Callable
|
22
|
+
from typing import Iterable, Hashable, TypeVar, Callable, TYPE_CHECKING
|
23
23
|
|
24
24
|
import jax
|
25
25
|
|
@@ -37,6 +37,7 @@ __all__ = [
|
|
37
37
|
'unzip2',
|
38
38
|
'wraps',
|
39
39
|
'Device',
|
40
|
+
'wrap_init',
|
40
41
|
]
|
41
42
|
|
42
43
|
T = TypeVar("T")
|
@@ -44,6 +45,8 @@ T1 = TypeVar("T1")
|
|
44
45
|
T2 = TypeVar("T2")
|
45
46
|
T3 = TypeVar("T3")
|
46
47
|
|
48
|
+
from saiunit._compatible_import import wrap_init
|
49
|
+
|
47
50
|
brainevent_installed = importlib.util.find_spec('brainevent') is not None
|
48
51
|
|
49
52
|
from jax.core import get_aval, Tracer
|
@@ -148,13 +151,13 @@ def to_concrete_aval(aval):
|
|
148
151
|
return aval
|
149
152
|
|
150
153
|
|
151
|
-
if brainevent_installed:
|
152
|
-
|
153
|
-
|
154
|
-
|
155
|
-
|
156
|
-
def __getattr__(self, item):
|
157
|
-
raise ImportError('brainevent is not installed, please install brainevent first.')
|
154
|
+
if not brainevent_installed:
|
155
|
+
if not TYPE_CHECKING:
|
156
|
+
class BrainEvent:
|
157
|
+
def __getattr__(self, item):
|
158
|
+
raise ImportError('brainevent is not installed, please install brainevent first.')
|
158
159
|
|
160
|
+
brainevent = BrainEvent()
|
159
161
|
|
160
|
-
|
162
|
+
else:
|
163
|
+
import brainevent
|
brainstate/_state.py
CHANGED
@@ -1049,7 +1049,7 @@ class StateTraceStack(Generic[A]):
|
|
1049
1049
|
"""
|
1050
1050
|
if self._jax_trace_new_arg is not None:
|
1051
1051
|
# internal use
|
1052
|
-
state._value = jax.tree.map(
|
1052
|
+
state._value = jax.tree.map(self._jax_trace_new_arg, state._value)
|
1053
1053
|
|
1054
1054
|
def __enter__(self) -> 'StateTraceStack':
|
1055
1055
|
TRACE_CONTEXT.state_stack.append(self)
|