brainstate 0.1.9__py2.py3-none-any.whl → 0.2.0__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +130 -19
- brainstate/_compatible_import.py +201 -9
- brainstate/_compatible_import_test.py +681 -0
- brainstate/_deprecation.py +210 -0
- brainstate/_deprecation_test.py +2319 -0
- brainstate/{util/error.py → _error.py} +10 -20
- brainstate/_state.py +94 -47
- brainstate/_state_test.py +1 -1
- brainstate/_utils.py +1 -1
- brainstate/environ.py +1279 -347
- brainstate/environ_test.py +1187 -26
- brainstate/graph/__init__.py +6 -13
- brainstate/graph/_node.py +240 -0
- brainstate/graph/_node_test.py +589 -0
- brainstate/graph/{_graph_operation.py → _operation.py} +632 -746
- brainstate/graph/_operation_test.py +1147 -0
- brainstate/mixin.py +1209 -141
- brainstate/mixin_test.py +991 -51
- brainstate/nn/__init__.py +74 -72
- brainstate/nn/_activations.py +587 -295
- brainstate/nn/_activations_test.py +109 -86
- brainstate/nn/_collective_ops.py +393 -274
- brainstate/nn/_collective_ops_test.py +746 -15
- brainstate/nn/_common.py +114 -66
- brainstate/nn/_common_test.py +154 -0
- brainstate/nn/_conv.py +1652 -143
- brainstate/nn/_conv_test.py +838 -227
- brainstate/nn/_delay.py +95 -29
- brainstate/nn/_delay_test.py +25 -20
- brainstate/nn/_dropout.py +359 -167
- brainstate/nn/_dropout_test.py +429 -52
- brainstate/nn/_dynamics.py +14 -90
- brainstate/nn/_dynamics_test.py +1 -12
- brainstate/nn/_elementwise.py +492 -313
- brainstate/nn/_elementwise_test.py +806 -145
- brainstate/nn/_embedding.py +369 -19
- brainstate/nn/_embedding_test.py +156 -0
- brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +10 -16
- brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +6 -5
- brainstate/nn/{_linear_mv.py → _event_linear.py} +2 -2
- brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +6 -5
- brainstate/nn/_exp_euler.py +200 -38
- brainstate/nn/_exp_euler_test.py +350 -8
- brainstate/nn/_linear.py +391 -71
- brainstate/nn/_linear_test.py +427 -59
- brainstate/nn/_metrics.py +1070 -0
- brainstate/nn/_metrics_test.py +611 -0
- brainstate/nn/_module.py +10 -3
- brainstate/nn/_module_test.py +1 -1
- brainstate/nn/_normalizations.py +688 -329
- brainstate/nn/_normalizations_test.py +663 -37
- brainstate/nn/_paddings.py +1020 -0
- brainstate/nn/_paddings_test.py +723 -0
- brainstate/nn/_poolings.py +1404 -342
- brainstate/nn/_poolings_test.py +828 -92
- brainstate/nn/{_rate_rnns.py → _rnns.py} +446 -54
- brainstate/nn/_rnns_test.py +593 -0
- brainstate/nn/_utils.py +132 -5
- brainstate/nn/_utils_test.py +402 -0
- brainstate/{init/_random_inits.py → nn/init.py} +301 -45
- brainstate/{init/_random_inits_test.py → nn/init_test.py} +51 -20
- brainstate/random/__init__.py +247 -1
- brainstate/random/_rand_funs.py +668 -346
- brainstate/random/_rand_funs_test.py +74 -1
- brainstate/random/_rand_seed.py +541 -76
- brainstate/random/_rand_seed_test.py +1 -1
- brainstate/random/_rand_state.py +601 -393
- brainstate/random/_rand_state_test.py +551 -0
- brainstate/transform/__init__.py +59 -0
- brainstate/transform/_ad_checkpoint.py +176 -0
- brainstate/{compile → transform}/_ad_checkpoint_test.py +1 -1
- brainstate/{augment → transform}/_autograd.py +360 -113
- brainstate/{augment → transform}/_autograd_test.py +2 -2
- brainstate/transform/_conditions.py +316 -0
- brainstate/{compile → transform}/_conditions_test.py +11 -11
- brainstate/{compile → transform}/_error_if.py +22 -20
- brainstate/{compile → transform}/_error_if_test.py +1 -1
- brainstate/transform/_eval_shape.py +145 -0
- brainstate/{augment → transform}/_eval_shape_test.py +1 -1
- brainstate/{compile → transform}/_jit.py +99 -46
- brainstate/{compile → transform}/_jit_test.py +3 -3
- brainstate/{compile → transform}/_loop_collect_return.py +219 -80
- brainstate/{compile → transform}/_loop_collect_return_test.py +1 -1
- brainstate/{compile → transform}/_loop_no_collection.py +133 -34
- brainstate/{compile → transform}/_loop_no_collection_test.py +2 -2
- brainstate/transform/_make_jaxpr.py +2016 -0
- brainstate/transform/_make_jaxpr_test.py +1510 -0
- brainstate/transform/_mapping.py +529 -0
- brainstate/transform/_mapping_test.py +194 -0
- brainstate/{compile → transform}/_progress_bar.py +78 -25
- brainstate/{augment → transform}/_random.py +65 -45
- brainstate/{compile → transform}/_unvmap.py +102 -5
- brainstate/transform/_util.py +286 -0
- brainstate/typing.py +594 -61
- brainstate/typing_test.py +780 -0
- brainstate/util/__init__.py +9 -32
- brainstate/util/_others.py +1025 -0
- brainstate/util/_others_test.py +962 -0
- brainstate/util/_pretty_pytree.py +1301 -0
- brainstate/util/_pretty_pytree_test.py +675 -0
- brainstate/util/{pretty_repr.py → _pretty_repr.py} +161 -27
- brainstate/util/_pretty_repr_test.py +696 -0
- brainstate/util/filter.py +557 -81
- brainstate/util/filter_test.py +912 -0
- brainstate/util/struct.py +769 -382
- brainstate/util/struct_test.py +602 -0
- {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/METADATA +34 -17
- brainstate-0.2.0.dist-info/RECORD +111 -0
- brainstate/augment/__init__.py +0 -30
- brainstate/augment/_eval_shape.py +0 -99
- brainstate/augment/_mapping.py +0 -1060
- brainstate/augment/_mapping_test.py +0 -597
- brainstate/compile/__init__.py +0 -38
- brainstate/compile/_ad_checkpoint.py +0 -204
- brainstate/compile/_conditions.py +0 -256
- brainstate/compile/_make_jaxpr.py +0 -888
- brainstate/compile/_make_jaxpr_test.py +0 -156
- brainstate/compile/_util.py +0 -147
- brainstate/functional/__init__.py +0 -27
- brainstate/graph/_graph_node.py +0 -244
- brainstate/graph/_graph_node_test.py +0 -73
- brainstate/graph/_graph_operation_test.py +0 -563
- brainstate/init/__init__.py +0 -26
- brainstate/init/_base.py +0 -52
- brainstate/init/_generic.py +0 -244
- brainstate/init/_regular_inits.py +0 -105
- brainstate/init/_regular_inits_test.py +0 -50
- brainstate/nn/_inputs.py +0 -608
- brainstate/nn/_ltp.py +0 -28
- brainstate/nn/_neuron.py +0 -705
- brainstate/nn/_neuron_test.py +0 -161
- brainstate/nn/_others.py +0 -46
- brainstate/nn/_projection.py +0 -486
- brainstate/nn/_rate_rnns_test.py +0 -63
- brainstate/nn/_readout.py +0 -209
- brainstate/nn/_readout_test.py +0 -53
- brainstate/nn/_stp.py +0 -236
- brainstate/nn/_synapse.py +0 -505
- brainstate/nn/_synapse_test.py +0 -131
- brainstate/nn/_synaptic_projection.py +0 -423
- brainstate/nn/_synouts.py +0 -162
- brainstate/nn/_synouts_test.py +0 -57
- brainstate/nn/metrics.py +0 -388
- brainstate/optim/__init__.py +0 -38
- brainstate/optim/_base.py +0 -64
- brainstate/optim/_lr_scheduler.py +0 -448
- brainstate/optim/_lr_scheduler_test.py +0 -50
- brainstate/optim/_optax_optimizer.py +0 -152
- brainstate/optim/_optax_optimizer_test.py +0 -53
- brainstate/optim/_sgd_optimizer.py +0 -1104
- brainstate/random/_random_for_unit.py +0 -52
- brainstate/surrogate.py +0 -1957
- brainstate/transform.py +0 -23
- brainstate/util/caller.py +0 -98
- brainstate/util/others.py +0 -540
- brainstate/util/pretty_pytree.py +0 -945
- brainstate/util/pretty_pytree_test.py +0 -159
- brainstate/util/pretty_table.py +0 -2954
- brainstate/util/scaling.py +0 -258
- brainstate-0.1.9.dist-info/RECORD +0 -130
- {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/WHEEL +0 -0
- {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/top_level.txt +0 -0
@@ -1,423 +0,0 @@
|
|
1
|
-
# Copyright 2025 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
# -*- coding: utf-8 -*-
|
16
|
-
|
17
|
-
|
18
|
-
from typing import Callable, Union, Tuple
|
19
|
-
|
20
|
-
import brainunit as u
|
21
|
-
|
22
|
-
from brainstate import init
|
23
|
-
from brainstate._state import ParamState
|
24
|
-
from brainstate.typing import ArrayLike
|
25
|
-
from ._dynamics import Dynamics, Projection
|
26
|
-
|
27
|
-
__all__ = [
|
28
|
-
'SymmetryGapJunction',
|
29
|
-
'AsymmetryGapJunction',
|
30
|
-
]
|
31
|
-
|
32
|
-
|
33
|
-
class align_pre_ltp(Projection):
|
34
|
-
pass
|
35
|
-
|
36
|
-
|
37
|
-
class align_post_ltp(Projection):
|
38
|
-
pass
|
39
|
-
|
40
|
-
|
41
|
-
def get_gap_junction_post_key(i: int):
|
42
|
-
return f'gap_junction_post_{i}'
|
43
|
-
|
44
|
-
|
45
|
-
def get_gap_junction_pre_key(i: int):
|
46
|
-
return f'gap_junction_pre_{i}'
|
47
|
-
|
48
|
-
|
49
|
-
class SymmetryGapJunction(Projection):
|
50
|
-
"""
|
51
|
-
Implements a symmetric electrical coupling (gap junction) between neuron populations.
|
52
|
-
|
53
|
-
This class represents electrical synapses where the conductance is identical in both
|
54
|
-
directions. Gap junctions allow bidirectional flow of electrical current directly between
|
55
|
-
neurons, with the current magnitude proportional to the voltage difference between
|
56
|
-
connected neurons.
|
57
|
-
|
58
|
-
Parameters
|
59
|
-
----------
|
60
|
-
couples : Union[Tuple[Dynamics, Dynamics], Dynamics]
|
61
|
-
Either a single Dynamics object (when pre and post populations are the same)
|
62
|
-
or a tuple of two Dynamics objects (pre, post) representing the coupled neuron populations.
|
63
|
-
states : Union[str, Tuple[str, str]]
|
64
|
-
Either a single string (when pre and post states are the same)
|
65
|
-
or a tuple of two strings (pre_state, post_state) representing the state variables
|
66
|
-
to use for calculating voltage differences (typically membrane potentials).
|
67
|
-
conn : Callable
|
68
|
-
Connection function that returns pre_ids and post_ids arrays defining connections.
|
69
|
-
weight : Union[Callable, ArrayLike]
|
70
|
-
Conductance weights for the gap junctions. The same weight applies in both directions
|
71
|
-
of the connection.
|
72
|
-
param_type : type, optional
|
73
|
-
The parameter state type to use for weights, defaults to ParamState.
|
74
|
-
|
75
|
-
Notes
|
76
|
-
-----
|
77
|
-
The symmetric gap junction applies identical conductance in both directions between
|
78
|
-
connected neurons, ensuring balanced electrical coupling in the network.
|
79
|
-
|
80
|
-
See Also
|
81
|
-
--------
|
82
|
-
AsymmetryGapJunction : For gap junctions with different conductances in each direction.
|
83
|
-
"""
|
84
|
-
|
85
|
-
def __init__(
|
86
|
-
self,
|
87
|
-
couples: Union[Tuple[Dynamics, Dynamics], Dynamics],
|
88
|
-
states: Union[str, Tuple[str, str]],
|
89
|
-
conn: Callable,
|
90
|
-
weight: Union[Callable, ArrayLike],
|
91
|
-
param_type: type = ParamState
|
92
|
-
):
|
93
|
-
super().__init__()
|
94
|
-
|
95
|
-
if isinstance(states, str):
|
96
|
-
pre_state = states
|
97
|
-
post_state = states
|
98
|
-
else:
|
99
|
-
pre_state, post_state = states
|
100
|
-
if isinstance(couples, Dynamics):
|
101
|
-
pre = couples
|
102
|
-
post = couples
|
103
|
-
else:
|
104
|
-
pre, post = couples
|
105
|
-
assert isinstance(pre_state, str), "pre_state must be a string representing the pre-synaptic state"
|
106
|
-
assert isinstance(post_state, str), "post_state must be a string representing the post-synaptic state"
|
107
|
-
assert isinstance(pre, Dynamics), "pre must be a Dynamics object"
|
108
|
-
assert isinstance(post, Dynamics), "post must be a Dynamics object"
|
109
|
-
self.pre_state = pre_state
|
110
|
-
self.post_state = post_state
|
111
|
-
self.pre = pre
|
112
|
-
self.post = post
|
113
|
-
self.pre_ids, self.post_ids = conn(pre.out_size, post.out_size)
|
114
|
-
self.weight = param_type(init.param(weight, (len(self.pre_ids),)))
|
115
|
-
|
116
|
-
def update(self, *args, **kwargs):
|
117
|
-
if not hasattr(self.pre, self.pre_state):
|
118
|
-
raise ValueError(f"pre_state {self.pre_state} not found in pre-synaptic neuron group")
|
119
|
-
if not hasattr(self.post, self.post_state):
|
120
|
-
raise ValueError(f"post_state {self.post_state} not found in post-synaptic neuron group")
|
121
|
-
pre = getattr(self.pre, self.pre_state).value
|
122
|
-
post = getattr(self.post, self.post_state).value
|
123
|
-
|
124
|
-
return symmetry_gap_junction_projection(
|
125
|
-
pre=self.pre,
|
126
|
-
pre_value=pre,
|
127
|
-
post=self.post,
|
128
|
-
post_value=post,
|
129
|
-
pre_ids=self.pre_ids,
|
130
|
-
post_ids=self.post_ids,
|
131
|
-
weight=self.weight.value,
|
132
|
-
)
|
133
|
-
|
134
|
-
|
135
|
-
def symmetry_gap_junction_projection(
|
136
|
-
pre: Dynamics,
|
137
|
-
pre_value: ArrayLike,
|
138
|
-
post: Dynamics,
|
139
|
-
post_value: ArrayLike,
|
140
|
-
pre_ids: ArrayLike,
|
141
|
-
post_ids: ArrayLike,
|
142
|
-
weight: ArrayLike,
|
143
|
-
):
|
144
|
-
"""
|
145
|
-
Calculate symmetrical electrical coupling through gap junctions between neurons.
|
146
|
-
|
147
|
-
This function implements bidirectional gap junction coupling where the same weight is
|
148
|
-
applied in both directions. It computes the electrical current flowing between
|
149
|
-
connected neurons based on their potential differences and updates both pre-synaptic
|
150
|
-
and post-synaptic neuron groups with appropriate input currents.
|
151
|
-
|
152
|
-
Parameters
|
153
|
-
----------
|
154
|
-
pre : Dynamics
|
155
|
-
The pre-synaptic neuron group dynamics object.
|
156
|
-
pre_value : ArrayLike
|
157
|
-
State values (typically membrane potentials) of the pre-synaptic neurons.
|
158
|
-
post : Dynamics
|
159
|
-
The post-synaptic neuron group dynamics object.
|
160
|
-
post_value : ArrayLike
|
161
|
-
State values (typically membrane potentials) of the post-synaptic neurons.
|
162
|
-
pre_ids : ArrayLike
|
163
|
-
Indices of pre-synaptic neurons that form gap junctions.
|
164
|
-
post_ids : ArrayLike
|
165
|
-
Indices of post-synaptic neurons that form gap junctions,
|
166
|
-
where each pre_ids[i] is connected to post_ids[i].
|
167
|
-
weight : ArrayLike
|
168
|
-
Conductance weights for the gap junctions. Can be a scalar (same weight for all connections)
|
169
|
-
or an array with length matching pre_ids.
|
170
|
-
|
171
|
-
Returns
|
172
|
-
-------
|
173
|
-
ArrayLike
|
174
|
-
The input currents that were added to the pre-synaptic neuron group.
|
175
|
-
|
176
|
-
Notes
|
177
|
-
-----
|
178
|
-
The electrical coupling is implemented as I = g(V_pre - V_post), where:
|
179
|
-
- I is the current flowing from pre to post neuron
|
180
|
-
- g is the gap junction conductance (weight)
|
181
|
-
- V_pre and V_post are the membrane potentials of connected neurons
|
182
|
-
|
183
|
-
Equal but opposite currents are applied to both connected neurons, ensuring
|
184
|
-
conservation of current in the network.
|
185
|
-
|
186
|
-
Raises
|
187
|
-
------
|
188
|
-
AssertionError
|
189
|
-
If weight dimensionality is incorrect or pre_ids and post_ids have different lengths.
|
190
|
-
"""
|
191
|
-
assert u.math.ndim(weight) == 0 or weight.shape[0] == len(pre_ids), \
|
192
|
-
"weight must be a scalar or have the same length as pre_ids"
|
193
|
-
assert len(pre_ids) == len(post_ids), "pre_ids and post_ids must have the same length"
|
194
|
-
# Calculate the voltage difference between connected pre-synaptic and post-synaptic neurons
|
195
|
-
# and multiply by the connection weights
|
196
|
-
diff = (pre_value[pre_ids] - post_value[post_ids]) * weight
|
197
|
-
|
198
|
-
# add to post-synaptic neuron group
|
199
|
-
# Initialize the input currents for the post-synaptic neuron group
|
200
|
-
inputs = u.math.zeros(post.out_size, unit=u.get_unit(diff))
|
201
|
-
# Add the calculated current to the corresponding post-synaptic neurons
|
202
|
-
inputs = inputs.at[post_ids].add(diff)
|
203
|
-
# Generate a unique key for the post-synaptic input currents
|
204
|
-
key = get_gap_junction_post_key(0 if post.current_inputs is None else len(post.current_inputs))
|
205
|
-
# Add the input currents to the post-synaptic neuron group
|
206
|
-
post.add_current_input(key, inputs)
|
207
|
-
|
208
|
-
# add to pre-synaptic neuron group
|
209
|
-
# Initialize the input currents for the pre-synaptic neuron group
|
210
|
-
inputs = u.math.zeros(pre.out_size, unit=u.get_unit(diff))
|
211
|
-
# Add the calculated current to the corresponding pre-synaptic neurons
|
212
|
-
inputs = inputs.at[pre_ids].add(diff)
|
213
|
-
# Generate a unique key for the pre-synaptic input currents
|
214
|
-
key = get_gap_junction_pre_key(0 if pre.current_inputs is None else len(pre.current_inputs))
|
215
|
-
# Add the input currents to the pre-synaptic neuron group with opposite polarity
|
216
|
-
pre.add_current_input(key, -inputs)
|
217
|
-
return inputs
|
218
|
-
|
219
|
-
|
220
|
-
class AsymmetryGapJunction(Projection):
|
221
|
-
"""
|
222
|
-
Implements an asymmetric electrical coupling (gap junction) between neuron populations.
|
223
|
-
|
224
|
-
This class represents electrical synapses where the conductance in one direction can differ
|
225
|
-
from the conductance in the opposite direction. Unlike chemical synapses, gap junctions
|
226
|
-
allow bidirectional flow of electrical current directly between neurons, with the current
|
227
|
-
magnitude proportional to the voltage difference between connected neurons.
|
228
|
-
|
229
|
-
Parameters
|
230
|
-
----------
|
231
|
-
pre : Dynamics
|
232
|
-
The pre-synaptic neuron group dynamics object.
|
233
|
-
pre_state : str
|
234
|
-
Name of the state variable in pre-synaptic neurons (typically membrane potential).
|
235
|
-
post : Dynamics
|
236
|
-
The post-synaptic neuron group dynamics object.
|
237
|
-
post_state : str
|
238
|
-
Name of the state variable in post-synaptic neurons (typically membrane potential).
|
239
|
-
conn : Callable
|
240
|
-
Connection function that returns pre_ids and post_ids arrays defining connections.
|
241
|
-
weight : Union[Callable, ArrayLike]
|
242
|
-
Conductance weights for the gap junctions. Must have shape [..., 2] where the last
|
243
|
-
dimension contains [pre_weight, post_weight] for each connection, defining
|
244
|
-
potentially different conductances in each direction.
|
245
|
-
param_type : type, optional
|
246
|
-
The parameter state type to use for weights, defaults to ParamState.
|
247
|
-
|
248
|
-
Examples
|
249
|
-
--------
|
250
|
-
>>> import brainstate
|
251
|
-
>>> import brainunit as u
|
252
|
-
>>> import numpy as np
|
253
|
-
>>>
|
254
|
-
>>> # Create two neuron populations
|
255
|
-
>>> n_neurons = 100
|
256
|
-
>>> pre_pop = brainstate.nn.LIF(n_neurons, V_rest=-70*u.mV, V_threshold=-50*u.mV)
|
257
|
-
>>> post_pop = brainstate.nn.LIF(n_neurons, V_rest=-70*u.mV, V_threshold=-50*u.mV)
|
258
|
-
>>> pre_pop.init_state()
|
259
|
-
>>> post_pop.init_state()
|
260
|
-
>>>
|
261
|
-
>>> # Create asymmetric gap junction with different weights in each direction
|
262
|
-
>>> weights = np.ones((n_neurons, 2)) * u.nS
|
263
|
-
>>> weights[:, 0] *= 2.0 # Double weight in pre->post direction
|
264
|
-
>>>
|
265
|
-
>>> gap_junction = brainstate.nn.AsymmetryGapJunction(
|
266
|
-
... pre=pre_pop,
|
267
|
-
... pre_state='V',
|
268
|
-
... post=post_pop,
|
269
|
-
... post_state='V',
|
270
|
-
... conn=one_to_one,
|
271
|
-
... weight=weights
|
272
|
-
... )
|
273
|
-
|
274
|
-
Notes
|
275
|
-
-----
|
276
|
-
The asymmetric gap junction allows for different conductances in each direction between
|
277
|
-
the same pair of neurons. This can model rectifying electrical synapses that preferentially
|
278
|
-
allow current to flow in one direction.
|
279
|
-
|
280
|
-
See Also
|
281
|
-
--------
|
282
|
-
SymmetryGapJunction : For gap junctions with identical conductance in both directions.
|
283
|
-
"""
|
284
|
-
|
285
|
-
def __init__(
|
286
|
-
self,
|
287
|
-
pre: Dynamics,
|
288
|
-
pre_state: str,
|
289
|
-
post: Dynamics,
|
290
|
-
post_state: str,
|
291
|
-
conn: Callable,
|
292
|
-
weight: Union[Callable, ArrayLike],
|
293
|
-
param_type: type = ParamState
|
294
|
-
):
|
295
|
-
super().__init__()
|
296
|
-
|
297
|
-
assert isinstance(pre_state, str), "pre_state must be a string representing the pre-synaptic state"
|
298
|
-
assert isinstance(post_state, str), "post_state must be a string representing the post-synaptic state"
|
299
|
-
self.pre_state = pre_state
|
300
|
-
self.post_state = post_state
|
301
|
-
self.pre = pre
|
302
|
-
self.post = post
|
303
|
-
self.pre_ids, self.post_ids = conn(pre.out_size, post.out_size)
|
304
|
-
self.weight = param_type(init.param(weight, (len(self.pre_ids), 2)))
|
305
|
-
|
306
|
-
def update(self, *args, **kwargs):
|
307
|
-
if not hasattr(self.pre, self.pre_state):
|
308
|
-
raise ValueError(f"pre_state {self.pre_state} not found in pre-synaptic neuron group")
|
309
|
-
if not hasattr(self.post, self.post_state):
|
310
|
-
raise ValueError(f"post_state {self.post_state} not found in post-synaptic neuron group")
|
311
|
-
pre = getattr(self.pre, self.pre_state).value
|
312
|
-
post = getattr(self.post, self.post_state).value
|
313
|
-
|
314
|
-
return asymmetry_gap_junction_projection(
|
315
|
-
pre=self.pre,
|
316
|
-
pre_value=pre,
|
317
|
-
post=self.post,
|
318
|
-
post_value=post,
|
319
|
-
pre_ids=self.pre_ids,
|
320
|
-
post_ids=self.post_ids,
|
321
|
-
weight=self.weight.value,
|
322
|
-
)
|
323
|
-
|
324
|
-
|
325
|
-
def asymmetry_gap_junction_projection(
|
326
|
-
pre: Dynamics,
|
327
|
-
pre_value: ArrayLike,
|
328
|
-
post: Dynamics,
|
329
|
-
post_value: ArrayLike,
|
330
|
-
pre_ids: ArrayLike,
|
331
|
-
post_ids: ArrayLike,
|
332
|
-
weight: ArrayLike,
|
333
|
-
):
|
334
|
-
"""
|
335
|
-
Calculate asymmetrical electrical coupling through gap junctions between neurons.
|
336
|
-
|
337
|
-
This function implements bidirectional gap junction coupling where different weights
|
338
|
-
can be applied in each direction. It computes the electrical current flowing between
|
339
|
-
connected neurons based on their potential differences and updates both pre-synaptic
|
340
|
-
and post-synaptic neuron groups with appropriate input currents.
|
341
|
-
|
342
|
-
Parameters
|
343
|
-
----------
|
344
|
-
pre : Dynamics
|
345
|
-
The pre-synaptic neuron group dynamics object.
|
346
|
-
pre_value : ArrayLike
|
347
|
-
State values (typically membrane potentials) of the pre-synaptic neurons.
|
348
|
-
post : Dynamics
|
349
|
-
The post-synaptic neuron group dynamics object.
|
350
|
-
post_value : ArrayLike
|
351
|
-
State values (typically membrane potentials) of the post-synaptic neurons.
|
352
|
-
pre_ids : ArrayLike
|
353
|
-
Indices of pre-synaptic neurons that form gap junctions.
|
354
|
-
post_ids : ArrayLike
|
355
|
-
Indices of post-synaptic neurons that form gap junctions,
|
356
|
-
where each pre_ids[i] is connected to post_ids[i].
|
357
|
-
weight : ArrayLike
|
358
|
-
Conductance weights for the gap junctions. Must have shape [..., 2], where
|
359
|
-
the last dimension contains [pre_weight, post_weight] for each connection.
|
360
|
-
Can be a 1D array [pre_weight, post_weight] (same weights for all connections)
|
361
|
-
or a 2D array with shape [len(pre_ids), 2] for connection-specific weights.
|
362
|
-
|
363
|
-
Returns
|
364
|
-
-------
|
365
|
-
ArrayLike
|
366
|
-
The input currents that were added to the pre-synaptic neuron group.
|
367
|
-
|
368
|
-
Notes
|
369
|
-
-----
|
370
|
-
The electrical coupling is implemented with direction-specific conductances:
|
371
|
-
- I_pre2post = g_pre * (V_pre - V_post) flowing from pre to post neuron
|
372
|
-
- I_post2pre = g_post * (V_pre - V_post) flowing from post to pre neuron
|
373
|
-
where g_pre and g_post can be different, allowing for asymmetrical coupling.
|
374
|
-
|
375
|
-
Raises
|
376
|
-
------
|
377
|
-
AssertionError
|
378
|
-
If weight dimensionality is incorrect or pre_ids and post_ids have different lengths.
|
379
|
-
ValueError
|
380
|
-
If weight shape is incompatible with asymmetrical gap junction requirements.
|
381
|
-
"""
|
382
|
-
assert weight.shape[-1] == 2, 'weight must be a 2-element array for asymmetry gap junctions'
|
383
|
-
assert len(pre_ids) == len(post_ids), "pre_ids and post_ids must have the same length"
|
384
|
-
if u.math.ndim(weight) == 1:
|
385
|
-
# If weight is a 1D array, it should have two elements for pre and post weights
|
386
|
-
assert weight.shape[0] == 2, "weight must be a 2-element array for asymmetry gap junctions"
|
387
|
-
pre_weight = weight[0]
|
388
|
-
post_weight = weight[1]
|
389
|
-
elif u.math.ndim(weight) == 2:
|
390
|
-
# If weight is a 2D array, it should have two rows for pre and post weights
|
391
|
-
pre_weight = weight[:, 0]
|
392
|
-
post_weight = weight[:, 1]
|
393
|
-
assert pre_weight.shape[0] == len(pre_ids), "pre_weight must have the same length as pre_ids"
|
394
|
-
assert post_weight.shape[0] == len(post_ids), "post_weight must have the same length as post_ids"
|
395
|
-
else:
|
396
|
-
raise ValueError("weight must be a 1D or 2D array for asymmetry gap junctions")
|
397
|
-
|
398
|
-
# Calculate the voltage difference between connected pre-synaptic and post-synaptic neurons
|
399
|
-
# and multiply by the connection weights
|
400
|
-
diff = pre_value[pre_ids] - post_value[post_ids]
|
401
|
-
pre2post_current = diff * pre_weight
|
402
|
-
post2pre_current = diff * post_weight
|
403
|
-
|
404
|
-
# add to post-synaptic neuron group
|
405
|
-
# Initialize the input currents for the post-synaptic neuron group
|
406
|
-
inputs = u.math.zeros(post.out_size, unit=u.get_unit(pre2post_current))
|
407
|
-
# Add the calculated current to the corresponding post-synaptic neurons
|
408
|
-
inputs = inputs.at[post_ids].add(pre2post_current)
|
409
|
-
# Generate a unique key for the post-synaptic input currents
|
410
|
-
key = get_gap_junction_post_key(0 if post.current_inputs is None else len(post.current_inputs))
|
411
|
-
# Add the input currents to the post-synaptic neuron group
|
412
|
-
post.add_current_input(key, inputs)
|
413
|
-
|
414
|
-
# add to pre-synaptic neuron group
|
415
|
-
# Initialize the input currents for the pre-synaptic neuron group
|
416
|
-
inputs = u.math.zeros(pre.out_size, unit=u.get_unit(post2pre_current))
|
417
|
-
# Add the calculated current to the corresponding pre-synaptic neurons
|
418
|
-
inputs = inputs.at[pre_ids].add(post2pre_current)
|
419
|
-
# Generate a unique key for the pre-synaptic input currents
|
420
|
-
key = get_gap_junction_pre_key(0 if pre.current_inputs is None else len(pre.current_inputs))
|
421
|
-
# Add the input currents to the pre-synaptic neuron group with opposite polarity
|
422
|
-
pre.add_current_input(key, -inputs)
|
423
|
-
return inputs
|
brainstate/nn/_synouts.py
DELETED
@@ -1,162 +0,0 @@
|
|
1
|
-
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
|
16
|
-
# -*- coding: utf-8 -*-
|
17
|
-
|
18
|
-
import brainunit as u
|
19
|
-
import jax.numpy as jnp
|
20
|
-
|
21
|
-
from brainstate.mixin import BindCondData
|
22
|
-
from brainstate.typing import ArrayLike
|
23
|
-
from ._module import Module
|
24
|
-
|
25
|
-
__all__ = [
|
26
|
-
'SynOut', 'COBA', 'CUBA', 'MgBlock',
|
27
|
-
]
|
28
|
-
|
29
|
-
|
30
|
-
class SynOut(Module, BindCondData):
|
31
|
-
"""
|
32
|
-
Base class for synaptic outputs.
|
33
|
-
|
34
|
-
:py:class:`~.SynOut` is also subclass of :py:class:`~.ParamDesc` and :py:class:`~.BindCondData`.
|
35
|
-
"""
|
36
|
-
|
37
|
-
__module__ = 'brainstate.nn'
|
38
|
-
|
39
|
-
def __init__(self, ):
|
40
|
-
super().__init__()
|
41
|
-
self._conductance = None
|
42
|
-
|
43
|
-
def __call__(self, *args, **kwargs):
|
44
|
-
if self._conductance is None:
|
45
|
-
raise ValueError(f'Please first pack conductance data at the current step using '
|
46
|
-
f'".{BindCondData.bind_cond.__name__}(data)". {self}')
|
47
|
-
ret = self.update(self._conductance, *args, **kwargs)
|
48
|
-
return ret
|
49
|
-
|
50
|
-
def update(self, conductance, potential):
|
51
|
-
raise NotImplementedError
|
52
|
-
|
53
|
-
|
54
|
-
class COBA(SynOut):
|
55
|
-
r"""
|
56
|
-
Conductance-based synaptic output.
|
57
|
-
|
58
|
-
Given the synaptic conductance, the model output the post-synaptic current with
|
59
|
-
|
60
|
-
.. math::
|
61
|
-
|
62
|
-
I_{syn}(t) = g_{\mathrm{syn}}(t) (E - V(t))
|
63
|
-
|
64
|
-
Parameters
|
65
|
-
----------
|
66
|
-
E: ArrayLike
|
67
|
-
The reversal potential.
|
68
|
-
|
69
|
-
See Also
|
70
|
-
--------
|
71
|
-
CUBA
|
72
|
-
"""
|
73
|
-
__module__ = 'brainstate.nn'
|
74
|
-
|
75
|
-
def __init__(self, E: ArrayLike):
|
76
|
-
super().__init__()
|
77
|
-
|
78
|
-
self.E = E
|
79
|
-
|
80
|
-
def update(self, conductance, potential):
|
81
|
-
return conductance * (self.E - potential)
|
82
|
-
|
83
|
-
|
84
|
-
class CUBA(SynOut):
|
85
|
-
r"""Current-based synaptic output.
|
86
|
-
|
87
|
-
Given the conductance, this model outputs the post-synaptic current with a identity function:
|
88
|
-
|
89
|
-
.. math::
|
90
|
-
|
91
|
-
I_{\mathrm{syn}}(t) = g_{\mathrm{syn}}(t)
|
92
|
-
|
93
|
-
Parameters
|
94
|
-
----------
|
95
|
-
scale: ArrayLike
|
96
|
-
The scaling factor for the conductance. Default 1. [mV]
|
97
|
-
|
98
|
-
See Also
|
99
|
-
--------
|
100
|
-
COBA
|
101
|
-
"""
|
102
|
-
__module__ = 'brainstate.nn'
|
103
|
-
|
104
|
-
def __init__(self, scale: ArrayLike = u.volt):
|
105
|
-
super().__init__()
|
106
|
-
self.scale = scale
|
107
|
-
|
108
|
-
def update(self, conductance, potential=None):
|
109
|
-
return conductance * self.scale
|
110
|
-
|
111
|
-
|
112
|
-
class MgBlock(SynOut):
|
113
|
-
r"""Synaptic output based on Magnesium blocking.
|
114
|
-
|
115
|
-
Given the synaptic conductance, the model output the post-synaptic current with
|
116
|
-
|
117
|
-
.. math::
|
118
|
-
|
119
|
-
I_{syn}(t) = g_{\mathrm{syn}}(t) (E - V(t)) g_{\infty}(V,[{Mg}^{2+}]_{o})
|
120
|
-
|
121
|
-
where The fraction of channels :math:`g_{\infty}` that are not blocked by magnesium can be fitted to
|
122
|
-
|
123
|
-
.. math::
|
124
|
-
|
125
|
-
g_{\infty}(V,[{Mg}^{2+}]_{o}) = (1+{e}^{-\alpha V} \frac{[{Mg}^{2+}]_{o}} {\beta})^{-1}
|
126
|
-
|
127
|
-
Here :math:`[{Mg}^{2+}]_{o}` is the extracellular magnesium concentration.
|
128
|
-
|
129
|
-
Parameters
|
130
|
-
----------
|
131
|
-
E: ArrayLike
|
132
|
-
The reversal potential for the synaptic current. [mV]
|
133
|
-
alpha: ArrayLike
|
134
|
-
Binding constant. Default 0.062
|
135
|
-
beta: ArrayLike
|
136
|
-
Unbinding constant. Default 3.57
|
137
|
-
cc_Mg: ArrayLike
|
138
|
-
Concentration of Magnesium ion. Default 1.2 [mM].
|
139
|
-
V_offset: ArrayLike
|
140
|
-
The offset potential. Default 0. [mV]
|
141
|
-
"""
|
142
|
-
__module__ = 'brainstate.nn'
|
143
|
-
|
144
|
-
def __init__(
|
145
|
-
self,
|
146
|
-
E: ArrayLike = 0.,
|
147
|
-
cc_Mg: ArrayLike = 1.2,
|
148
|
-
alpha: ArrayLike = 0.062,
|
149
|
-
beta: ArrayLike = 3.57,
|
150
|
-
V_offset: ArrayLike = 0.,
|
151
|
-
):
|
152
|
-
super().__init__()
|
153
|
-
|
154
|
-
self.E = E
|
155
|
-
self.V_offset = V_offset
|
156
|
-
self.cc_Mg = cc_Mg
|
157
|
-
self.alpha = alpha
|
158
|
-
self.beta = beta
|
159
|
-
|
160
|
-
def update(self, conductance, potential):
|
161
|
-
norm = (1 + self.cc_Mg / self.beta * jnp.exp(self.alpha * (self.V_offset - potential)))
|
162
|
-
return conductance * (self.E - potential) / norm
|
brainstate/nn/_synouts_test.py
DELETED
@@ -1,57 +0,0 @@
|
|
1
|
-
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
|
16
|
-
|
17
|
-
import unittest
|
18
|
-
|
19
|
-
import brainunit as u
|
20
|
-
import jax.numpy as jnp
|
21
|
-
import numpy as np
|
22
|
-
|
23
|
-
import brainstate
|
24
|
-
|
25
|
-
|
26
|
-
class TestSynOutModels(unittest.TestCase):
|
27
|
-
def setUp(self):
|
28
|
-
self.conductance = jnp.array([0.5, 1.0, 1.5])
|
29
|
-
self.potential = jnp.array([-70.0, -65.0, -60.0])
|
30
|
-
self.E = jnp.array([-70.0])
|
31
|
-
self.alpha = jnp.array([0.062])
|
32
|
-
self.beta = jnp.array([3.57])
|
33
|
-
self.cc_Mg = jnp.array([1.2])
|
34
|
-
self.V_offset = jnp.array([0.0])
|
35
|
-
|
36
|
-
def test_COBA(self):
|
37
|
-
model = brainstate.nn.COBA(E=self.E)
|
38
|
-
output = model.update(self.conductance, self.potential)
|
39
|
-
expected_output = self.conductance * (self.E - self.potential)
|
40
|
-
np.testing.assert_array_almost_equal(output, expected_output)
|
41
|
-
|
42
|
-
def test_CUBA(self):
|
43
|
-
model = brainstate.nn.CUBA()
|
44
|
-
output = model.update(self.conductance)
|
45
|
-
expected_output = self.conductance * model.scale
|
46
|
-
self.assertTrue(u.math.allclose(output, expected_output))
|
47
|
-
|
48
|
-
def test_MgBlock(self):
|
49
|
-
model = brainstate.nn.MgBlock(E=self.E, cc_Mg=self.cc_Mg, alpha=self.alpha, beta=self.beta, V_offset=self.V_offset)
|
50
|
-
output = model.update(self.conductance, self.potential)
|
51
|
-
norm = (1 + self.cc_Mg / self.beta * jnp.exp(self.alpha * (self.V_offset - self.potential)))
|
52
|
-
expected_output = self.conductance * (self.E - self.potential) / norm
|
53
|
-
np.testing.assert_array_almost_equal(output, expected_output)
|
54
|
-
|
55
|
-
|
56
|
-
if __name__ == '__main__':
|
57
|
-
unittest.main()
|