brainstate 0.1.9__py2.py3-none-any.whl → 0.2.0__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +130 -19
- brainstate/_compatible_import.py +201 -9
- brainstate/_compatible_import_test.py +681 -0
- brainstate/_deprecation.py +210 -0
- brainstate/_deprecation_test.py +2319 -0
- brainstate/{util/error.py → _error.py} +10 -20
- brainstate/_state.py +94 -47
- brainstate/_state_test.py +1 -1
- brainstate/_utils.py +1 -1
- brainstate/environ.py +1279 -347
- brainstate/environ_test.py +1187 -26
- brainstate/graph/__init__.py +6 -13
- brainstate/graph/_node.py +240 -0
- brainstate/graph/_node_test.py +589 -0
- brainstate/graph/{_graph_operation.py → _operation.py} +632 -746
- brainstate/graph/_operation_test.py +1147 -0
- brainstate/mixin.py +1209 -141
- brainstate/mixin_test.py +991 -51
- brainstate/nn/__init__.py +74 -72
- brainstate/nn/_activations.py +587 -295
- brainstate/nn/_activations_test.py +109 -86
- brainstate/nn/_collective_ops.py +393 -274
- brainstate/nn/_collective_ops_test.py +746 -15
- brainstate/nn/_common.py +114 -66
- brainstate/nn/_common_test.py +154 -0
- brainstate/nn/_conv.py +1652 -143
- brainstate/nn/_conv_test.py +838 -227
- brainstate/nn/_delay.py +95 -29
- brainstate/nn/_delay_test.py +25 -20
- brainstate/nn/_dropout.py +359 -167
- brainstate/nn/_dropout_test.py +429 -52
- brainstate/nn/_dynamics.py +14 -90
- brainstate/nn/_dynamics_test.py +1 -12
- brainstate/nn/_elementwise.py +492 -313
- brainstate/nn/_elementwise_test.py +806 -145
- brainstate/nn/_embedding.py +369 -19
- brainstate/nn/_embedding_test.py +156 -0
- brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +10 -16
- brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +6 -5
- brainstate/nn/{_linear_mv.py → _event_linear.py} +2 -2
- brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +6 -5
- brainstate/nn/_exp_euler.py +200 -38
- brainstate/nn/_exp_euler_test.py +350 -8
- brainstate/nn/_linear.py +391 -71
- brainstate/nn/_linear_test.py +427 -59
- brainstate/nn/_metrics.py +1070 -0
- brainstate/nn/_metrics_test.py +611 -0
- brainstate/nn/_module.py +10 -3
- brainstate/nn/_module_test.py +1 -1
- brainstate/nn/_normalizations.py +688 -329
- brainstate/nn/_normalizations_test.py +663 -37
- brainstate/nn/_paddings.py +1020 -0
- brainstate/nn/_paddings_test.py +723 -0
- brainstate/nn/_poolings.py +1404 -342
- brainstate/nn/_poolings_test.py +828 -92
- brainstate/nn/{_rate_rnns.py → _rnns.py} +446 -54
- brainstate/nn/_rnns_test.py +593 -0
- brainstate/nn/_utils.py +132 -5
- brainstate/nn/_utils_test.py +402 -0
- brainstate/{init/_random_inits.py → nn/init.py} +301 -45
- brainstate/{init/_random_inits_test.py → nn/init_test.py} +51 -20
- brainstate/random/__init__.py +247 -1
- brainstate/random/_rand_funs.py +668 -346
- brainstate/random/_rand_funs_test.py +74 -1
- brainstate/random/_rand_seed.py +541 -76
- brainstate/random/_rand_seed_test.py +1 -1
- brainstate/random/_rand_state.py +601 -393
- brainstate/random/_rand_state_test.py +551 -0
- brainstate/transform/__init__.py +59 -0
- brainstate/transform/_ad_checkpoint.py +176 -0
- brainstate/{compile → transform}/_ad_checkpoint_test.py +1 -1
- brainstate/{augment → transform}/_autograd.py +360 -113
- brainstate/{augment → transform}/_autograd_test.py +2 -2
- brainstate/transform/_conditions.py +316 -0
- brainstate/{compile → transform}/_conditions_test.py +11 -11
- brainstate/{compile → transform}/_error_if.py +22 -20
- brainstate/{compile → transform}/_error_if_test.py +1 -1
- brainstate/transform/_eval_shape.py +145 -0
- brainstate/{augment → transform}/_eval_shape_test.py +1 -1
- brainstate/{compile → transform}/_jit.py +99 -46
- brainstate/{compile → transform}/_jit_test.py +3 -3
- brainstate/{compile → transform}/_loop_collect_return.py +219 -80
- brainstate/{compile → transform}/_loop_collect_return_test.py +1 -1
- brainstate/{compile → transform}/_loop_no_collection.py +133 -34
- brainstate/{compile → transform}/_loop_no_collection_test.py +2 -2
- brainstate/transform/_make_jaxpr.py +2016 -0
- brainstate/transform/_make_jaxpr_test.py +1510 -0
- brainstate/transform/_mapping.py +529 -0
- brainstate/transform/_mapping_test.py +194 -0
- brainstate/{compile → transform}/_progress_bar.py +78 -25
- brainstate/{augment → transform}/_random.py +65 -45
- brainstate/{compile → transform}/_unvmap.py +102 -5
- brainstate/transform/_util.py +286 -0
- brainstate/typing.py +594 -61
- brainstate/typing_test.py +780 -0
- brainstate/util/__init__.py +9 -32
- brainstate/util/_others.py +1025 -0
- brainstate/util/_others_test.py +962 -0
- brainstate/util/_pretty_pytree.py +1301 -0
- brainstate/util/_pretty_pytree_test.py +675 -0
- brainstate/util/{pretty_repr.py → _pretty_repr.py} +161 -27
- brainstate/util/_pretty_repr_test.py +696 -0
- brainstate/util/filter.py +557 -81
- brainstate/util/filter_test.py +912 -0
- brainstate/util/struct.py +769 -382
- brainstate/util/struct_test.py +602 -0
- {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/METADATA +34 -17
- brainstate-0.2.0.dist-info/RECORD +111 -0
- brainstate/augment/__init__.py +0 -30
- brainstate/augment/_eval_shape.py +0 -99
- brainstate/augment/_mapping.py +0 -1060
- brainstate/augment/_mapping_test.py +0 -597
- brainstate/compile/__init__.py +0 -38
- brainstate/compile/_ad_checkpoint.py +0 -204
- brainstate/compile/_conditions.py +0 -256
- brainstate/compile/_make_jaxpr.py +0 -888
- brainstate/compile/_make_jaxpr_test.py +0 -156
- brainstate/compile/_util.py +0 -147
- brainstate/functional/__init__.py +0 -27
- brainstate/graph/_graph_node.py +0 -244
- brainstate/graph/_graph_node_test.py +0 -73
- brainstate/graph/_graph_operation_test.py +0 -563
- brainstate/init/__init__.py +0 -26
- brainstate/init/_base.py +0 -52
- brainstate/init/_generic.py +0 -244
- brainstate/init/_regular_inits.py +0 -105
- brainstate/init/_regular_inits_test.py +0 -50
- brainstate/nn/_inputs.py +0 -608
- brainstate/nn/_ltp.py +0 -28
- brainstate/nn/_neuron.py +0 -705
- brainstate/nn/_neuron_test.py +0 -161
- brainstate/nn/_others.py +0 -46
- brainstate/nn/_projection.py +0 -486
- brainstate/nn/_rate_rnns_test.py +0 -63
- brainstate/nn/_readout.py +0 -209
- brainstate/nn/_readout_test.py +0 -53
- brainstate/nn/_stp.py +0 -236
- brainstate/nn/_synapse.py +0 -505
- brainstate/nn/_synapse_test.py +0 -131
- brainstate/nn/_synaptic_projection.py +0 -423
- brainstate/nn/_synouts.py +0 -162
- brainstate/nn/_synouts_test.py +0 -57
- brainstate/nn/metrics.py +0 -388
- brainstate/optim/__init__.py +0 -38
- brainstate/optim/_base.py +0 -64
- brainstate/optim/_lr_scheduler.py +0 -448
- brainstate/optim/_lr_scheduler_test.py +0 -50
- brainstate/optim/_optax_optimizer.py +0 -152
- brainstate/optim/_optax_optimizer_test.py +0 -53
- brainstate/optim/_sgd_optimizer.py +0 -1104
- brainstate/random/_random_for_unit.py +0 -52
- brainstate/surrogate.py +0 -1957
- brainstate/transform.py +0 -23
- brainstate/util/caller.py +0 -98
- brainstate/util/others.py +0 -540
- brainstate/util/pretty_pytree.py +0 -945
- brainstate/util/pretty_pytree_test.py +0 -159
- brainstate/util/pretty_table.py +0 -2954
- brainstate/util/scaling.py +0 -258
- brainstate-0.1.9.dist-info/RECORD +0 -130
- {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/WHEEL +0 -0
- {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/top_level.txt +0 -0
brainstate/nn/__init__.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright 2024
|
1
|
+
# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
|
2
2
|
#
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -13,8 +13,9 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
|
-
|
17
|
-
from . import metrics
|
16
|
+
from . import init
|
17
|
+
from ._activations import *
|
18
|
+
from ._activations import __all__ as activation_all
|
18
19
|
from ._collective_ops import *
|
19
20
|
from ._collective_ops import __all__ as collective_ops_all
|
20
21
|
from ._common import *
|
@@ -31,105 +32,106 @@ from ._elementwise import *
|
|
31
32
|
from ._elementwise import __all__ as elementwise_all
|
32
33
|
from ._embedding import *
|
33
34
|
from ._embedding import __all__ as embed_all
|
35
|
+
from ._event_fixedprob import *
|
36
|
+
from ._event_fixedprob import __all__ as fixedprob_all
|
37
|
+
from ._event_linear import *
|
38
|
+
from ._event_linear import __all__ as linear_mv_all
|
34
39
|
from ._exp_euler import *
|
35
40
|
from ._exp_euler import __all__ as exp_euler_all
|
36
|
-
from ._fixedprob import *
|
37
|
-
from ._fixedprob import __all__ as fixedprob_all
|
38
|
-
from ._inputs import *
|
39
|
-
from ._inputs import __all__ as inputs_all
|
40
41
|
from ._linear import *
|
41
42
|
from ._linear import __all__ as linear_all
|
42
|
-
from ._linear_mv import *
|
43
|
-
from ._linear_mv import __all__ as linear_mv_all
|
44
|
-
from ._ltp import *
|
45
|
-
from ._ltp import __all__ as ltp_all
|
43
|
+
from ._metrics import *
|
44
|
+
from ._metrics import __all__ as metrics_all
|
46
45
|
from ._module import *
|
47
46
|
from ._module import __all__ as module_all
|
48
|
-
from ._neuron import *
|
49
|
-
from ._neuron import __all__ as dyn_neuron_all
|
50
47
|
from ._normalizations import *
|
51
48
|
from ._normalizations import __all__ as normalizations_all
|
52
|
-
from ._others import *
|
53
|
-
from ._others import __all__ as _others_all
|
49
|
+
from ._paddings import *
|
50
|
+
from ._paddings import __all__ as paddings_all
|
54
51
|
from ._poolings import *
|
55
52
|
from ._poolings import __all__ as poolings_all
|
56
|
-
from ._projection import *
|
57
|
-
from ._projection import __all__ as projection_all
|
58
|
-
from ._rate_rnns import *
|
59
|
-
from ._rate_rnns import __all__ as rate_rnns
|
60
|
-
from ._readout import *
|
61
|
-
from ._readout import __all__ as readout_all
|
62
|
-
from ._stp import *
|
63
|
-
from ._stp import __all__ as stp_all
|
64
|
-
from ._synapse import *
|
65
|
-
from ._synapse import __all__ as dyn_synapse_all
|
66
|
-
from ._synaptic_projection import *
|
67
|
-
from ._synaptic_projection import __all__ as _syn_proj_all
|
68
|
-
from ._synouts import *
|
69
|
-
from ._synouts import __all__ as synouts_all
|
53
|
+
from ._rnns import *
|
54
|
+
from ._rnns import __all__ as rate_rnns
|
70
55
|
from ._utils import *
|
71
56
|
from ._utils import __all__ as utils_all
|
72
57
|
|
73
|
-
__all__ =
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
+ elementwise_all
|
80
|
-
+ module_all
|
81
|
-
+ exp_euler_all
|
82
|
-
+ utils_all
|
83
|
-
+ dyn_all
|
84
|
-
+ projection_all
|
85
|
-
+ state_delay_all
|
86
|
-
+ synouts_all
|
87
|
-
+ conv_all
|
88
|
-
+ linear_all
|
89
|
-
+ normalizations_all
|
90
|
-
+ poolings_all
|
91
|
-
+ fixedprob_all
|
92
|
-
+ linear_mv_all
|
93
|
-
+ embed_all
|
94
|
-
+ dropout_all
|
95
|
-
+ elementwise_all
|
96
|
-
+ dyn_neuron_all
|
97
|
-
+ dyn_synapse_all
|
98
|
-
+ inputs_all
|
99
|
-
+ rate_rnns
|
100
|
-
+ readout_all
|
101
|
-
+ stp_all
|
102
|
-
+ ltp_all
|
103
|
-
+ _syn_proj_all
|
104
|
-
+ _others_all
|
105
|
-
)
|
58
|
+
__all__ = ['init'] + activation_all + metrics_all
|
59
|
+
__all__ = __all__ + collective_ops_all + common_all + elementwise_all + module_all + exp_euler_all
|
60
|
+
__all__ = __all__ + utils_all + dyn_all + state_delay_all + conv_all
|
61
|
+
__all__ = __all__ + linear_all + normalizations_all + paddings_all + poolings_all + fixedprob_all + linear_mv_all
|
62
|
+
__all__ = __all__ + embed_all + dropout_all + elementwise_all
|
63
|
+
__all__ = __all__ + rate_rnns
|
106
64
|
|
107
65
|
del (
|
66
|
+
metrics_all,
|
67
|
+
activation_all,
|
108
68
|
collective_ops_all,
|
109
69
|
common_all,
|
110
70
|
module_all,
|
111
71
|
exp_euler_all,
|
112
72
|
utils_all,
|
113
73
|
dyn_all,
|
114
|
-
projection_all,
|
115
74
|
state_delay_all,
|
116
|
-
synouts_all,
|
117
75
|
conv_all,
|
118
76
|
linear_all,
|
119
77
|
normalizations_all,
|
78
|
+
paddings_all,
|
120
79
|
poolings_all,
|
121
80
|
embed_all,
|
122
81
|
fixedprob_all,
|
123
82
|
linear_mv_all,
|
124
83
|
dropout_all,
|
125
84
|
elementwise_all,
|
126
|
-
dyn_neuron_all,
|
127
|
-
dyn_synapse_all,
|
128
|
-
inputs_all,
|
129
|
-
readout_all,
|
130
85
|
rate_rnns,
|
131
|
-
stp_all,
|
132
|
-
ltp_all,
|
133
|
-
_syn_proj_all,
|
134
|
-
_others_all,
|
135
86
|
)
|
87
|
+
|
88
|
+
# Deprecated names that redirect to brainpy
|
89
|
+
_DEPRECATED_NAMES = {
|
90
|
+
'SpikeTime': 'brainpy.SpikeTime',
|
91
|
+
'PoissonSpike': 'brainpy.PoissonSpike',
|
92
|
+
'PoissonEncoder': 'brainpy.PoissonEncoder',
|
93
|
+
'PoissonInput': 'brainpy.PoissonInput',
|
94
|
+
'poisson_input': 'brainpy.poisson_input',
|
95
|
+
'Neuron': 'brainpy.Neuron',
|
96
|
+
'IF': 'brainpy.IF',
|
97
|
+
'LIF': 'brainpy.LIF',
|
98
|
+
'LIFRef': 'brainpy.LIFRef',
|
99
|
+
'ALIF': 'brainpy.ALIF',
|
100
|
+
'LeakyRateReadout': 'brainpy.LeakyRateReadout',
|
101
|
+
'LeakySpikeReadout': 'brainpy.LeakySpikeReadout',
|
102
|
+
'STP': 'brainpy.STP',
|
103
|
+
'STD': 'brainpy.STD',
|
104
|
+
'Synapse': 'brainpy.Synapse',
|
105
|
+
'Expon': 'brainpy.Expon',
|
106
|
+
'DualExpon': 'brainpy.DualExpon',
|
107
|
+
'Alpha': 'brainpy.Alpha',
|
108
|
+
'AMPA': 'brainpy.AMPA',
|
109
|
+
'GABAa': 'brainpy.GABAa',
|
110
|
+
'COBA': 'brainpy.COBA',
|
111
|
+
'CUBA': 'brainpy.CUBA',
|
112
|
+
'MgBlock': 'brainpy.MgBlock',
|
113
|
+
'SynOut': 'brainpy.SynOut',
|
114
|
+
'AlignPostProj': 'brainpy.AlignPostProj',
|
115
|
+
'DeltaProj': 'brainpy.DeltaProj',
|
116
|
+
'CurrentProj': 'brainpy.CurrentProj',
|
117
|
+
'align_pre_projection': 'brainpy.align_pre_projection',
|
118
|
+
'Projection': 'brainpy.Projection',
|
119
|
+
'SymmetryGapJunction': 'brainpy.SymmetryGapJunction',
|
120
|
+
'AsymmetryGapJunction': 'brainpy.AsymmetryGapJunction',
|
121
|
+
}
|
122
|
+
|
123
|
+
|
124
|
+
def __getattr__(name: str):
|
125
|
+
if name in _DEPRECATED_NAMES:
|
126
|
+
import warnings
|
127
|
+
new_name = _DEPRECATED_NAMES[name]
|
128
|
+
warnings.warn(
|
129
|
+
f"'brainstate.nn.{name}' is deprecated and will be removed in a future version. "
|
130
|
+
f"Please use '{new_name}' instead.",
|
131
|
+
DeprecationWarning,
|
132
|
+
stacklevel=2
|
133
|
+
)
|
134
|
+
# Import and return the actual brainpy object
|
135
|
+
import brainpy
|
136
|
+
return getattr(brainpy, name)
|
137
|
+
raise AttributeError(f"module 'brainstate.nn' has no attribute '{name}'")
|