brainstate 0.1.1__py2.py3-none-any.whl → 0.1.2__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +1 -1
- brainstate/_compatible_import.py +3 -0
- brainstate/_state.py +1 -1
- brainstate/augment/_autograd_test.py +132 -133
- brainstate/augment/_eval_shape_test.py +7 -9
- brainstate/augment/_mapping_test.py +75 -76
- brainstate/compile/_ad_checkpoint_test.py +6 -8
- brainstate/compile/_conditions_test.py +35 -36
- brainstate/compile/_error_if_test.py +10 -13
- brainstate/compile/_loop_collect_return_test.py +7 -9
- brainstate/compile/_loop_no_collection_test.py +7 -8
- brainstate/compile/_make_jaxpr.py +29 -14
- brainstate/compile/_make_jaxpr_test.py +20 -20
- brainstate/functional/_activations_test.py +61 -61
- brainstate/graph/_graph_node_test.py +16 -18
- brainstate/graph/_graph_operation_test.py +154 -156
- brainstate/init/_random_inits_test.py +20 -21
- brainstate/init/_regular_inits_test.py +4 -5
- brainstate/nn/_collective_ops_test.py +8 -8
- brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +18 -19
- brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +9 -10
- brainstate/nn/_dyn_impl/_rate_rnns_test.py +6 -7
- brainstate/nn/_dyn_impl/_readout_test.py +9 -10
- brainstate/nn/_dynamics/_dynamics_base_test.py +14 -15
- brainstate/nn/_dynamics/_synouts_test.py +4 -5
- brainstate/nn/_elementwise/_dropout_test.py +9 -9
- brainstate/nn/_elementwise/_elementwise_test.py +57 -59
- brainstate/nn/_event/_fixedprob_mv_test.py +0 -1
- brainstate/nn/_event/_linear_mv_test.py +0 -1
- brainstate/nn/_exp_euler_test.py +5 -6
- brainstate/nn/_interaction/_conv_test.py +31 -33
- brainstate/nn/_interaction/_linear_test.py +15 -17
- brainstate/nn/_interaction/_normalizations_test.py +10 -12
- brainstate/nn/_interaction/_poolings_test.py +19 -21
- brainstate/nn/_module_test.py +34 -37
- brainstate/optim/_lr_scheduler_test.py +3 -3
- brainstate/optim/_optax_optimizer_test.py +8 -9
- brainstate/random/_rand_funs_test.py +183 -184
- brainstate/random/_rand_seed_test.py +10 -12
- {brainstate-0.1.1.dist-info → brainstate-0.1.2.dist-info}/METADATA +1 -1
- {brainstate-0.1.1.dist-info → brainstate-0.1.2.dist-info}/RECORD +44 -44
- {brainstate-0.1.1.dist-info → brainstate-0.1.2.dist-info}/LICENSE +0 -0
- {brainstate-0.1.1.dist-info → brainstate-0.1.2.dist-info}/WHEEL +0 -0
- {brainstate-0.1.1.dist-info → brainstate-0.1.2.dist-info}/top_level.txt +0 -0
brainstate/__init__.py
CHANGED
brainstate/_compatible_import.py
CHANGED
@@ -37,6 +37,7 @@ __all__ = [
|
|
37
37
|
'unzip2',
|
38
38
|
'wraps',
|
39
39
|
'Device',
|
40
|
+
'wrap_init',
|
40
41
|
]
|
41
42
|
|
42
43
|
T = TypeVar("T")
|
@@ -44,6 +45,8 @@ T1 = TypeVar("T1")
|
|
44
45
|
T2 = TypeVar("T2")
|
45
46
|
T3 = TypeVar("T3")
|
46
47
|
|
48
|
+
|
49
|
+
from saiunit._compatible_import import wrap_init
|
47
50
|
brainevent_installed = importlib.util.find_spec('brainevent') is not None
|
48
51
|
|
49
52
|
from jax.core import get_aval, Tracer
|
brainstate/_state.py
CHANGED
@@ -1049,7 +1049,7 @@ class StateTraceStack(Generic[A]):
|
|
1049
1049
|
"""
|
1050
1050
|
if self._jax_trace_new_arg is not None:
|
1051
1051
|
# internal use
|
1052
|
-
state._value = jax.tree.map(
|
1052
|
+
state._value = jax.tree.map(self._jax_trace_new_arg, state._value)
|
1053
1053
|
|
1054
1054
|
def __enter__(self) -> 'StateTraceStack':
|
1055
1055
|
TRACE_CONTEXT.state_stack.append(self)
|