brainstate 0.1.7__py2.py3-none-any.whl → 0.1.9__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +58 -51
- brainstate/_compatible_import.py +148 -148
- brainstate/_state.py +1605 -1663
- brainstate/_state_test.py +52 -52
- brainstate/_utils.py +47 -47
- brainstate/augment/__init__.py +30 -30
- brainstate/augment/_autograd.py +778 -778
- brainstate/augment/_autograd_test.py +1289 -1289
- brainstate/augment/_eval_shape.py +99 -99
- brainstate/augment/_eval_shape_test.py +38 -38
- brainstate/augment/_mapping.py +1060 -1060
- brainstate/augment/_mapping_test.py +597 -597
- brainstate/augment/_random.py +151 -151
- brainstate/compile/__init__.py +38 -38
- brainstate/compile/_ad_checkpoint.py +204 -204
- brainstate/compile/_ad_checkpoint_test.py +49 -49
- brainstate/compile/_conditions.py +256 -256
- brainstate/compile/_conditions_test.py +220 -220
- brainstate/compile/_error_if.py +92 -92
- brainstate/compile/_error_if_test.py +52 -52
- brainstate/compile/_jit.py +346 -346
- brainstate/compile/_jit_test.py +143 -143
- brainstate/compile/_loop_collect_return.py +536 -536
- brainstate/compile/_loop_collect_return_test.py +58 -58
- brainstate/compile/_loop_no_collection.py +184 -184
- brainstate/compile/_loop_no_collection_test.py +50 -50
- brainstate/compile/_make_jaxpr.py +888 -888
- brainstate/compile/_make_jaxpr_test.py +156 -146
- brainstate/compile/_progress_bar.py +202 -202
- brainstate/compile/_unvmap.py +159 -159
- brainstate/compile/_util.py +147 -147
- brainstate/environ.py +563 -563
- brainstate/environ_test.py +62 -62
- brainstate/functional/__init__.py +27 -26
- brainstate/graph/__init__.py +29 -29
- brainstate/graph/_graph_node.py +244 -244
- brainstate/graph/_graph_node_test.py +73 -73
- brainstate/graph/_graph_operation.py +1738 -1738
- brainstate/graph/_graph_operation_test.py +563 -563
- brainstate/init/__init__.py +26 -26
- brainstate/init/_base.py +52 -52
- brainstate/init/_generic.py +244 -244
- brainstate/init/_random_inits.py +553 -553
- brainstate/init/_random_inits_test.py +149 -149
- brainstate/init/_regular_inits.py +105 -105
- brainstate/init/_regular_inits_test.py +50 -50
- brainstate/mixin.py +365 -363
- brainstate/mixin_test.py +77 -73
- brainstate/nn/__init__.py +135 -131
- brainstate/{functional → nn}/_activations.py +808 -813
- brainstate/{functional → nn}/_activations_test.py +331 -331
- brainstate/nn/_collective_ops.py +514 -514
- brainstate/nn/_collective_ops_test.py +43 -43
- brainstate/nn/_common.py +178 -178
- brainstate/nn/_conv.py +501 -501
- brainstate/nn/_conv_test.py +238 -238
- brainstate/nn/_delay.py +509 -470
- brainstate/nn/_delay_test.py +238 -0
- brainstate/nn/_dropout.py +426 -426
- brainstate/nn/_dropout_test.py +100 -100
- brainstate/nn/_dynamics.py +1343 -1361
- brainstate/nn/_dynamics_test.py +78 -78
- brainstate/nn/_elementwise.py +1119 -1120
- brainstate/nn/_elementwise_test.py +169 -169
- brainstate/nn/_embedding.py +58 -58
- brainstate/nn/_exp_euler.py +92 -92
- brainstate/nn/_exp_euler_test.py +35 -35
- brainstate/nn/_fixedprob.py +239 -239
- brainstate/nn/_fixedprob_test.py +114 -114
- brainstate/nn/_inputs.py +608 -608
- brainstate/nn/_linear.py +424 -424
- brainstate/nn/_linear_mv.py +83 -83
- brainstate/nn/_linear_mv_test.py +120 -120
- brainstate/nn/_linear_test.py +107 -107
- brainstate/nn/_ltp.py +28 -28
- brainstate/nn/_module.py +377 -377
- brainstate/nn/_module_test.py +40 -208
- brainstate/nn/_neuron.py +705 -705
- brainstate/nn/_neuron_test.py +161 -161
- brainstate/nn/_normalizations.py +975 -918
- brainstate/nn/_normalizations_test.py +73 -73
- brainstate/{functional → nn}/_others.py +46 -46
- brainstate/nn/_poolings.py +1177 -1177
- brainstate/nn/_poolings_test.py +217 -217
- brainstate/nn/_projection.py +486 -486
- brainstate/nn/_rate_rnns.py +554 -554
- brainstate/nn/_rate_rnns_test.py +63 -63
- brainstate/nn/_readout.py +209 -209
- brainstate/nn/_readout_test.py +53 -53
- brainstate/nn/_stp.py +236 -236
- brainstate/nn/_synapse.py +505 -505
- brainstate/nn/_synapse_test.py +131 -131
- brainstate/nn/_synaptic_projection.py +423 -423
- brainstate/nn/_synouts.py +162 -162
- brainstate/nn/_synouts_test.py +57 -57
- brainstate/nn/_utils.py +89 -89
- brainstate/nn/metrics.py +388 -388
- brainstate/optim/__init__.py +38 -38
- brainstate/optim/_base.py +64 -64
- brainstate/optim/_lr_scheduler.py +448 -448
- brainstate/optim/_lr_scheduler_test.py +50 -50
- brainstate/optim/_optax_optimizer.py +152 -152
- brainstate/optim/_optax_optimizer_test.py +53 -53
- brainstate/optim/_sgd_optimizer.py +1104 -1104
- brainstate/random/__init__.py +24 -24
- brainstate/random/_rand_funs.py +3616 -3616
- brainstate/random/_rand_funs_test.py +567 -567
- brainstate/random/_rand_seed.py +210 -210
- brainstate/random/_rand_seed_test.py +48 -48
- brainstate/random/_rand_state.py +1409 -1409
- brainstate/random/_random_for_unit.py +52 -52
- brainstate/surrogate.py +1957 -1957
- brainstate/transform.py +23 -23
- brainstate/typing.py +304 -304
- brainstate/util/__init__.py +50 -50
- brainstate/util/caller.py +98 -98
- brainstate/util/error.py +55 -55
- brainstate/util/filter.py +469 -469
- brainstate/util/others.py +540 -540
- brainstate/util/pretty_pytree.py +945 -945
- brainstate/util/pretty_pytree_test.py +159 -159
- brainstate/util/pretty_repr.py +328 -328
- brainstate/util/pretty_table.py +2954 -2954
- brainstate/util/scaling.py +258 -258
- brainstate/util/struct.py +523 -523
- {brainstate-0.1.7.dist-info → brainstate-0.1.9.dist-info}/METADATA +91 -99
- brainstate-0.1.9.dist-info/RECORD +130 -0
- {brainstate-0.1.7.dist-info → brainstate-0.1.9.dist-info}/WHEEL +1 -1
- {brainstate-0.1.7.dist-info → brainstate-0.1.9.dist-info/licenses}/LICENSE +202 -202
- brainstate/functional/_normalization.py +0 -81
- brainstate/functional/_spikes.py +0 -204
- brainstate-0.1.7.dist-info/RECORD +0 -131
- {brainstate-0.1.7.dist-info → brainstate-0.1.9.dist-info}/top_level.txt +0 -0
brainstate/nn/_synouts.py
CHANGED
@@ -1,162 +1,162 @@
|
|
1
|
-
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
|
16
|
-
# -*- coding: utf-8 -*-
|
17
|
-
|
18
|
-
import brainunit as u
|
19
|
-
import jax.numpy as jnp
|
20
|
-
|
21
|
-
from brainstate.mixin import BindCondData
|
22
|
-
from brainstate.typing import ArrayLike
|
23
|
-
from ._module import Module
|
24
|
-
|
25
|
-
__all__ = [
|
26
|
-
'SynOut', 'COBA', 'CUBA', 'MgBlock',
|
27
|
-
]
|
28
|
-
|
29
|
-
|
30
|
-
class SynOut(Module, BindCondData):
|
31
|
-
"""
|
32
|
-
Base class for synaptic outputs.
|
33
|
-
|
34
|
-
:py:class:`~.SynOut` is also subclass of :py:class:`~.ParamDesc` and :py:class:`~.BindCondData`.
|
35
|
-
"""
|
36
|
-
|
37
|
-
__module__ = 'brainstate.nn'
|
38
|
-
|
39
|
-
def __init__(self, ):
|
40
|
-
super().__init__()
|
41
|
-
self._conductance = None
|
42
|
-
|
43
|
-
def __call__(self, *args, **kwargs):
|
44
|
-
if self._conductance is None:
|
45
|
-
raise ValueError(f'Please first pack conductance data at the current step using '
|
46
|
-
f'".{BindCondData.bind_cond.__name__}(data)". {self}')
|
47
|
-
ret = self.update(self._conductance, *args, **kwargs)
|
48
|
-
return ret
|
49
|
-
|
50
|
-
def update(self, conductance, potential):
|
51
|
-
raise NotImplementedError
|
52
|
-
|
53
|
-
|
54
|
-
class COBA(SynOut):
|
55
|
-
r"""
|
56
|
-
Conductance-based synaptic output.
|
57
|
-
|
58
|
-
Given the synaptic conductance, the model output the post-synaptic current with
|
59
|
-
|
60
|
-
.. math::
|
61
|
-
|
62
|
-
I_{syn}(t) = g_{\mathrm{syn}}(t) (E - V(t))
|
63
|
-
|
64
|
-
Parameters
|
65
|
-
----------
|
66
|
-
E: ArrayLike
|
67
|
-
The reversal potential.
|
68
|
-
|
69
|
-
See Also
|
70
|
-
--------
|
71
|
-
CUBA
|
72
|
-
"""
|
73
|
-
__module__ = 'brainstate.nn'
|
74
|
-
|
75
|
-
def __init__(self, E: ArrayLike):
|
76
|
-
super().__init__()
|
77
|
-
|
78
|
-
self.E = E
|
79
|
-
|
80
|
-
def update(self, conductance, potential):
|
81
|
-
return conductance * (self.E - potential)
|
82
|
-
|
83
|
-
|
84
|
-
class CUBA(SynOut):
|
85
|
-
r"""Current-based synaptic output.
|
86
|
-
|
87
|
-
Given the conductance, this model outputs the post-synaptic current with a identity function:
|
88
|
-
|
89
|
-
.. math::
|
90
|
-
|
91
|
-
I_{\mathrm{syn}}(t) = g_{\mathrm{syn}}(t)
|
92
|
-
|
93
|
-
Parameters
|
94
|
-
----------
|
95
|
-
scale: ArrayLike
|
96
|
-
The scaling factor for the conductance. Default 1. [mV]
|
97
|
-
|
98
|
-
See Also
|
99
|
-
--------
|
100
|
-
COBA
|
101
|
-
"""
|
102
|
-
__module__ = 'brainstate.nn'
|
103
|
-
|
104
|
-
def __init__(self, scale: ArrayLike = u.volt):
|
105
|
-
super().__init__()
|
106
|
-
self.scale = scale
|
107
|
-
|
108
|
-
def update(self, conductance, potential=None):
|
109
|
-
return conductance * self.scale
|
110
|
-
|
111
|
-
|
112
|
-
class MgBlock(SynOut):
|
113
|
-
r"""Synaptic output based on Magnesium blocking.
|
114
|
-
|
115
|
-
Given the synaptic conductance, the model output the post-synaptic current with
|
116
|
-
|
117
|
-
.. math::
|
118
|
-
|
119
|
-
I_{syn}(t) = g_{\mathrm{syn}}(t) (E - V(t)) g_{\infty}(V,[{Mg}^{2+}]_{o})
|
120
|
-
|
121
|
-
where The fraction of channels :math:`g_{\infty}` that are not blocked by magnesium can be fitted to
|
122
|
-
|
123
|
-
.. math::
|
124
|
-
|
125
|
-
g_{\infty}(V,[{Mg}^{2+}]_{o}) = (1+{e}^{-\alpha V} \frac{[{Mg}^{2+}]_{o}} {\beta})^{-1}
|
126
|
-
|
127
|
-
Here :math:`[{Mg}^{2+}]_{o}` is the extracellular magnesium concentration.
|
128
|
-
|
129
|
-
Parameters
|
130
|
-
----------
|
131
|
-
E: ArrayLike
|
132
|
-
The reversal potential for the synaptic current. [mV]
|
133
|
-
alpha: ArrayLike
|
134
|
-
Binding constant. Default 0.062
|
135
|
-
beta: ArrayLike
|
136
|
-
Unbinding constant. Default 3.57
|
137
|
-
cc_Mg: ArrayLike
|
138
|
-
Concentration of Magnesium ion. Default 1.2 [mM].
|
139
|
-
V_offset: ArrayLike
|
140
|
-
The offset potential. Default 0. [mV]
|
141
|
-
"""
|
142
|
-
__module__ = 'brainstate.nn'
|
143
|
-
|
144
|
-
def __init__(
|
145
|
-
self,
|
146
|
-
E: ArrayLike = 0.,
|
147
|
-
cc_Mg: ArrayLike = 1.2,
|
148
|
-
alpha: ArrayLike = 0.062,
|
149
|
-
beta: ArrayLike = 3.57,
|
150
|
-
V_offset: ArrayLike = 0.,
|
151
|
-
):
|
152
|
-
super().__init__()
|
153
|
-
|
154
|
-
self.E = E
|
155
|
-
self.V_offset = V_offset
|
156
|
-
self.cc_Mg = cc_Mg
|
157
|
-
self.alpha = alpha
|
158
|
-
self.beta = beta
|
159
|
-
|
160
|
-
def update(self, conductance, potential):
|
161
|
-
norm = (1 + self.cc_Mg / self.beta * jnp.exp(self.alpha * (self.V_offset - potential)))
|
162
|
-
return conductance * (self.E - potential) / norm
|
1
|
+
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
# -*- coding: utf-8 -*-
|
17
|
+
|
18
|
+
import brainunit as u
|
19
|
+
import jax.numpy as jnp
|
20
|
+
|
21
|
+
from brainstate.mixin import BindCondData
|
22
|
+
from brainstate.typing import ArrayLike
|
23
|
+
from ._module import Module
|
24
|
+
|
25
|
+
__all__ = [
|
26
|
+
'SynOut', 'COBA', 'CUBA', 'MgBlock',
|
27
|
+
]
|
28
|
+
|
29
|
+
|
30
|
+
class SynOut(Module, BindCondData):
|
31
|
+
"""
|
32
|
+
Base class for synaptic outputs.
|
33
|
+
|
34
|
+
:py:class:`~.SynOut` is also subclass of :py:class:`~.ParamDesc` and :py:class:`~.BindCondData`.
|
35
|
+
"""
|
36
|
+
|
37
|
+
__module__ = 'brainstate.nn'
|
38
|
+
|
39
|
+
def __init__(self, ):
|
40
|
+
super().__init__()
|
41
|
+
self._conductance = None
|
42
|
+
|
43
|
+
def __call__(self, *args, **kwargs):
|
44
|
+
if self._conductance is None:
|
45
|
+
raise ValueError(f'Please first pack conductance data at the current step using '
|
46
|
+
f'".{BindCondData.bind_cond.__name__}(data)". {self}')
|
47
|
+
ret = self.update(self._conductance, *args, **kwargs)
|
48
|
+
return ret
|
49
|
+
|
50
|
+
def update(self, conductance, potential):
|
51
|
+
raise NotImplementedError
|
52
|
+
|
53
|
+
|
54
|
+
class COBA(SynOut):
|
55
|
+
r"""
|
56
|
+
Conductance-based synaptic output.
|
57
|
+
|
58
|
+
Given the synaptic conductance, the model output the post-synaptic current with
|
59
|
+
|
60
|
+
.. math::
|
61
|
+
|
62
|
+
I_{syn}(t) = g_{\mathrm{syn}}(t) (E - V(t))
|
63
|
+
|
64
|
+
Parameters
|
65
|
+
----------
|
66
|
+
E: ArrayLike
|
67
|
+
The reversal potential.
|
68
|
+
|
69
|
+
See Also
|
70
|
+
--------
|
71
|
+
CUBA
|
72
|
+
"""
|
73
|
+
__module__ = 'brainstate.nn'
|
74
|
+
|
75
|
+
def __init__(self, E: ArrayLike):
|
76
|
+
super().__init__()
|
77
|
+
|
78
|
+
self.E = E
|
79
|
+
|
80
|
+
def update(self, conductance, potential):
|
81
|
+
return conductance * (self.E - potential)
|
82
|
+
|
83
|
+
|
84
|
+
class CUBA(SynOut):
|
85
|
+
r"""Current-based synaptic output.
|
86
|
+
|
87
|
+
Given the conductance, this model outputs the post-synaptic current with a identity function:
|
88
|
+
|
89
|
+
.. math::
|
90
|
+
|
91
|
+
I_{\mathrm{syn}}(t) = g_{\mathrm{syn}}(t)
|
92
|
+
|
93
|
+
Parameters
|
94
|
+
----------
|
95
|
+
scale: ArrayLike
|
96
|
+
The scaling factor for the conductance. Default 1. [mV]
|
97
|
+
|
98
|
+
See Also
|
99
|
+
--------
|
100
|
+
COBA
|
101
|
+
"""
|
102
|
+
__module__ = 'brainstate.nn'
|
103
|
+
|
104
|
+
def __init__(self, scale: ArrayLike = u.volt):
|
105
|
+
super().__init__()
|
106
|
+
self.scale = scale
|
107
|
+
|
108
|
+
def update(self, conductance, potential=None):
|
109
|
+
return conductance * self.scale
|
110
|
+
|
111
|
+
|
112
|
+
class MgBlock(SynOut):
|
113
|
+
r"""Synaptic output based on Magnesium blocking.
|
114
|
+
|
115
|
+
Given the synaptic conductance, the model output the post-synaptic current with
|
116
|
+
|
117
|
+
.. math::
|
118
|
+
|
119
|
+
I_{syn}(t) = g_{\mathrm{syn}}(t) (E - V(t)) g_{\infty}(V,[{Mg}^{2+}]_{o})
|
120
|
+
|
121
|
+
where The fraction of channels :math:`g_{\infty}` that are not blocked by magnesium can be fitted to
|
122
|
+
|
123
|
+
.. math::
|
124
|
+
|
125
|
+
g_{\infty}(V,[{Mg}^{2+}]_{o}) = (1+{e}^{-\alpha V} \frac{[{Mg}^{2+}]_{o}} {\beta})^{-1}
|
126
|
+
|
127
|
+
Here :math:`[{Mg}^{2+}]_{o}` is the extracellular magnesium concentration.
|
128
|
+
|
129
|
+
Parameters
|
130
|
+
----------
|
131
|
+
E: ArrayLike
|
132
|
+
The reversal potential for the synaptic current. [mV]
|
133
|
+
alpha: ArrayLike
|
134
|
+
Binding constant. Default 0.062
|
135
|
+
beta: ArrayLike
|
136
|
+
Unbinding constant. Default 3.57
|
137
|
+
cc_Mg: ArrayLike
|
138
|
+
Concentration of Magnesium ion. Default 1.2 [mM].
|
139
|
+
V_offset: ArrayLike
|
140
|
+
The offset potential. Default 0. [mV]
|
141
|
+
"""
|
142
|
+
__module__ = 'brainstate.nn'
|
143
|
+
|
144
|
+
def __init__(
|
145
|
+
self,
|
146
|
+
E: ArrayLike = 0.,
|
147
|
+
cc_Mg: ArrayLike = 1.2,
|
148
|
+
alpha: ArrayLike = 0.062,
|
149
|
+
beta: ArrayLike = 3.57,
|
150
|
+
V_offset: ArrayLike = 0.,
|
151
|
+
):
|
152
|
+
super().__init__()
|
153
|
+
|
154
|
+
self.E = E
|
155
|
+
self.V_offset = V_offset
|
156
|
+
self.cc_Mg = cc_Mg
|
157
|
+
self.alpha = alpha
|
158
|
+
self.beta = beta
|
159
|
+
|
160
|
+
def update(self, conductance, potential):
|
161
|
+
norm = (1 + self.cc_Mg / self.beta * jnp.exp(self.alpha * (self.V_offset - potential)))
|
162
|
+
return conductance * (self.E - potential) / norm
|
brainstate/nn/_synouts_test.py
CHANGED
@@ -1,57 +1,57 @@
|
|
1
|
-
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
|
16
|
-
|
17
|
-
import unittest
|
18
|
-
|
19
|
-
import brainunit as u
|
20
|
-
import jax.numpy as jnp
|
21
|
-
import numpy as np
|
22
|
-
|
23
|
-
import brainstate
|
24
|
-
|
25
|
-
|
26
|
-
class TestSynOutModels(unittest.TestCase):
|
27
|
-
def setUp(self):
|
28
|
-
self.conductance = jnp.array([0.5, 1.0, 1.5])
|
29
|
-
self.potential = jnp.array([-70.0, -65.0, -60.0])
|
30
|
-
self.E = jnp.array([-70.0])
|
31
|
-
self.alpha = jnp.array([0.062])
|
32
|
-
self.beta = jnp.array([3.57])
|
33
|
-
self.cc_Mg = jnp.array([1.2])
|
34
|
-
self.V_offset = jnp.array([0.0])
|
35
|
-
|
36
|
-
def test_COBA(self):
|
37
|
-
model = brainstate.nn.COBA(E=self.E)
|
38
|
-
output = model.update(self.conductance, self.potential)
|
39
|
-
expected_output = self.conductance * (self.E - self.potential)
|
40
|
-
np.testing.assert_array_almost_equal(output, expected_output)
|
41
|
-
|
42
|
-
def test_CUBA(self):
|
43
|
-
model = brainstate.nn.CUBA()
|
44
|
-
output = model.update(self.conductance)
|
45
|
-
expected_output = self.conductance * model.scale
|
46
|
-
self.assertTrue(u.math.allclose(output, expected_output))
|
47
|
-
|
48
|
-
def test_MgBlock(self):
|
49
|
-
model = brainstate.nn.MgBlock(E=self.E, cc_Mg=self.cc_Mg, alpha=self.alpha, beta=self.beta, V_offset=self.V_offset)
|
50
|
-
output = model.update(self.conductance, self.potential)
|
51
|
-
norm = (1 + self.cc_Mg / self.beta * jnp.exp(self.alpha * (self.V_offset - self.potential)))
|
52
|
-
expected_output = self.conductance * (self.E - self.potential) / norm
|
53
|
-
np.testing.assert_array_almost_equal(output, expected_output)
|
54
|
-
|
55
|
-
|
56
|
-
if __name__ == '__main__':
|
57
|
-
unittest.main()
|
1
|
+
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
|
17
|
+
import unittest
|
18
|
+
|
19
|
+
import brainunit as u
|
20
|
+
import jax.numpy as jnp
|
21
|
+
import numpy as np
|
22
|
+
|
23
|
+
import brainstate
|
24
|
+
|
25
|
+
|
26
|
+
class TestSynOutModels(unittest.TestCase):
|
27
|
+
def setUp(self):
|
28
|
+
self.conductance = jnp.array([0.5, 1.0, 1.5])
|
29
|
+
self.potential = jnp.array([-70.0, -65.0, -60.0])
|
30
|
+
self.E = jnp.array([-70.0])
|
31
|
+
self.alpha = jnp.array([0.062])
|
32
|
+
self.beta = jnp.array([3.57])
|
33
|
+
self.cc_Mg = jnp.array([1.2])
|
34
|
+
self.V_offset = jnp.array([0.0])
|
35
|
+
|
36
|
+
def test_COBA(self):
|
37
|
+
model = brainstate.nn.COBA(E=self.E)
|
38
|
+
output = model.update(self.conductance, self.potential)
|
39
|
+
expected_output = self.conductance * (self.E - self.potential)
|
40
|
+
np.testing.assert_array_almost_equal(output, expected_output)
|
41
|
+
|
42
|
+
def test_CUBA(self):
|
43
|
+
model = brainstate.nn.CUBA()
|
44
|
+
output = model.update(self.conductance)
|
45
|
+
expected_output = self.conductance * model.scale
|
46
|
+
self.assertTrue(u.math.allclose(output, expected_output))
|
47
|
+
|
48
|
+
def test_MgBlock(self):
|
49
|
+
model = brainstate.nn.MgBlock(E=self.E, cc_Mg=self.cc_Mg, alpha=self.alpha, beta=self.beta, V_offset=self.V_offset)
|
50
|
+
output = model.update(self.conductance, self.potential)
|
51
|
+
norm = (1 + self.cc_Mg / self.beta * jnp.exp(self.alpha * (self.V_offset - self.potential)))
|
52
|
+
expected_output = self.conductance * (self.E - self.potential) / norm
|
53
|
+
np.testing.assert_array_almost_equal(output, expected_output)
|
54
|
+
|
55
|
+
|
56
|
+
if __name__ == '__main__':
|
57
|
+
unittest.main()
|
brainstate/nn/_utils.py
CHANGED
@@ -1,89 +1,89 @@
|
|
1
|
-
# Copyright 2025 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
|
16
|
-
# -*- coding: utf-8 -*-
|
17
|
-
|
18
|
-
from typing import Union, Tuple
|
19
|
-
|
20
|
-
from brainstate._state import ParamState
|
21
|
-
from brainstate.util import PrettyTable
|
22
|
-
from ._module import Module
|
23
|
-
|
24
|
-
__all__ = [
|
25
|
-
"count_parameters",
|
26
|
-
]
|
27
|
-
|
28
|
-
|
29
|
-
def _format_parameter_count(num_params, precision=2):
|
30
|
-
if num_params < 1000:
|
31
|
-
return str(num_params)
|
32
|
-
|
33
|
-
suffixes = ['', 'K', 'M', 'B', 'T', 'P', 'E']
|
34
|
-
magnitude = 0
|
35
|
-
while abs(num_params) >= 1000:
|
36
|
-
magnitude += 1
|
37
|
-
num_params /= 1000.0
|
38
|
-
|
39
|
-
format_string = '{:.' + str(precision) + 'f}{}'
|
40
|
-
formatted_value = format_string.format(num_params, suffixes[magnitude])
|
41
|
-
|
42
|
-
# 检查是否接近 1000,如果是,尝试使用更大的基数
|
43
|
-
if magnitude < len(suffixes) - 1 and num_params >= 1000 * (1 - 10 ** (-precision)):
|
44
|
-
magnitude += 1
|
45
|
-
num_params /= 1000.0
|
46
|
-
formatted_value = format_string.format(num_params, suffixes[magnitude])
|
47
|
-
|
48
|
-
return formatted_value
|
49
|
-
|
50
|
-
|
51
|
-
def count_parameters(
|
52
|
-
module: Module,
|
53
|
-
precision: int = 2,
|
54
|
-
return_table: bool = False,
|
55
|
-
) -> Union[Tuple[PrettyTable, int], int]:
|
56
|
-
"""
|
57
|
-
Count and display the number of trainable parameters in a neural network model.
|
58
|
-
|
59
|
-
This function iterates through all the parameters of the given model,
|
60
|
-
counts the number of parameters for each module, and displays them in a table.
|
61
|
-
It also calculates and returns the total number of trainable parameters.
|
62
|
-
|
63
|
-
Parameters:
|
64
|
-
-----------
|
65
|
-
model : brainstate.nn.Module
|
66
|
-
The neural network model for which to count parameters.
|
67
|
-
|
68
|
-
Returns:
|
69
|
-
--------
|
70
|
-
int
|
71
|
-
The total number of trainable parameters in the model.
|
72
|
-
|
73
|
-
Prints:
|
74
|
-
-------
|
75
|
-
A pretty-formatted table showing the number of parameters for each module,
|
76
|
-
followed by the total number of trainable parameters.
|
77
|
-
"""
|
78
|
-
assert isinstance(module, Module), "Input must be a neural network module" # noqa: E501
|
79
|
-
table = PrettyTable(["Modules", "Parameters"])
|
80
|
-
total_params = 0
|
81
|
-
for name, parameter in module.states(ParamState).items():
|
82
|
-
param = parameter.numel()
|
83
|
-
table.add_row([name, _format_parameter_count(param, precision=precision)])
|
84
|
-
total_params += param
|
85
|
-
table.add_row(["Total", _format_parameter_count(total_params, precision=precision)])
|
86
|
-
print(table)
|
87
|
-
if return_table:
|
88
|
-
return table, total_params
|
89
|
-
return total_params
|
1
|
+
# Copyright 2025 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
# -*- coding: utf-8 -*-
|
17
|
+
|
18
|
+
from typing import Union, Tuple
|
19
|
+
|
20
|
+
from brainstate._state import ParamState
|
21
|
+
from brainstate.util import PrettyTable
|
22
|
+
from ._module import Module
|
23
|
+
|
24
|
+
__all__ = [
|
25
|
+
"count_parameters",
|
26
|
+
]
|
27
|
+
|
28
|
+
|
29
|
+
def _format_parameter_count(num_params, precision=2):
|
30
|
+
if num_params < 1000:
|
31
|
+
return str(num_params)
|
32
|
+
|
33
|
+
suffixes = ['', 'K', 'M', 'B', 'T', 'P', 'E']
|
34
|
+
magnitude = 0
|
35
|
+
while abs(num_params) >= 1000:
|
36
|
+
magnitude += 1
|
37
|
+
num_params /= 1000.0
|
38
|
+
|
39
|
+
format_string = '{:.' + str(precision) + 'f}{}'
|
40
|
+
formatted_value = format_string.format(num_params, suffixes[magnitude])
|
41
|
+
|
42
|
+
# 检查是否接近 1000,如果是,尝试使用更大的基数
|
43
|
+
if magnitude < len(suffixes) - 1 and num_params >= 1000 * (1 - 10 ** (-precision)):
|
44
|
+
magnitude += 1
|
45
|
+
num_params /= 1000.0
|
46
|
+
formatted_value = format_string.format(num_params, suffixes[magnitude])
|
47
|
+
|
48
|
+
return formatted_value
|
49
|
+
|
50
|
+
|
51
|
+
def count_parameters(
|
52
|
+
module: Module,
|
53
|
+
precision: int = 2,
|
54
|
+
return_table: bool = False,
|
55
|
+
) -> Union[Tuple[PrettyTable, int], int]:
|
56
|
+
"""
|
57
|
+
Count and display the number of trainable parameters in a neural network model.
|
58
|
+
|
59
|
+
This function iterates through all the parameters of the given model,
|
60
|
+
counts the number of parameters for each module, and displays them in a table.
|
61
|
+
It also calculates and returns the total number of trainable parameters.
|
62
|
+
|
63
|
+
Parameters:
|
64
|
+
-----------
|
65
|
+
model : brainstate.nn.Module
|
66
|
+
The neural network model for which to count parameters.
|
67
|
+
|
68
|
+
Returns:
|
69
|
+
--------
|
70
|
+
int
|
71
|
+
The total number of trainable parameters in the model.
|
72
|
+
|
73
|
+
Prints:
|
74
|
+
-------
|
75
|
+
A pretty-formatted table showing the number of parameters for each module,
|
76
|
+
followed by the total number of trainable parameters.
|
77
|
+
"""
|
78
|
+
assert isinstance(module, Module), "Input must be a neural network module" # noqa: E501
|
79
|
+
table = PrettyTable(["Modules", "Parameters"])
|
80
|
+
total_params = 0
|
81
|
+
for name, parameter in module.states(ParamState).items():
|
82
|
+
param = parameter.numel()
|
83
|
+
table.add_row([name, _format_parameter_count(param, precision=precision)])
|
84
|
+
total_params += param
|
85
|
+
table.add_row(["Total", _format_parameter_count(total_params, precision=precision)])
|
86
|
+
print(table)
|
87
|
+
if return_table:
|
88
|
+
return table, total_params
|
89
|
+
return total_params
|