brainstate 0.1.10__py2.py3-none-any.whl → 0.2.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (163)
  1. brainstate/__init__.py +130 -19
  2. brainstate/_compatible_import.py +201 -9
  3. brainstate/_compatible_import_test.py +681 -0
  4. brainstate/_deprecation.py +210 -0
  5. brainstate/_deprecation_test.py +2319 -0
  6. brainstate/{util/error.py → _error.py} +10 -20
  7. brainstate/_state.py +94 -47
  8. brainstate/_state_test.py +1 -1
  9. brainstate/_utils.py +1 -1
  10. brainstate/environ.py +1279 -347
  11. brainstate/environ_test.py +1187 -26
  12. brainstate/graph/__init__.py +6 -13
  13. brainstate/graph/_node.py +240 -0
  14. brainstate/graph/_node_test.py +589 -0
  15. brainstate/graph/{_graph_operation.py → _operation.py} +632 -746
  16. brainstate/graph/_operation_test.py +1147 -0
  17. brainstate/mixin.py +1209 -141
  18. brainstate/mixin_test.py +991 -51
  19. brainstate/nn/__init__.py +74 -72
  20. brainstate/nn/_activations.py +587 -295
  21. brainstate/nn/_activations_test.py +109 -86
  22. brainstate/nn/_collective_ops.py +393 -274
  23. brainstate/nn/_collective_ops_test.py +746 -15
  24. brainstate/nn/_common.py +114 -66
  25. brainstate/nn/_common_test.py +154 -0
  26. brainstate/nn/_conv.py +1652 -143
  27. brainstate/nn/_conv_test.py +838 -227
  28. brainstate/nn/_delay.py +15 -28
  29. brainstate/nn/_delay_test.py +25 -20
  30. brainstate/nn/_dropout.py +359 -167
  31. brainstate/nn/_dropout_test.py +429 -52
  32. brainstate/nn/_dynamics.py +14 -90
  33. brainstate/nn/_dynamics_test.py +1 -12
  34. brainstate/nn/_elementwise.py +492 -313
  35. brainstate/nn/_elementwise_test.py +806 -145
  36. brainstate/nn/_embedding.py +369 -19
  37. brainstate/nn/_embedding_test.py +156 -0
  38. brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +10 -16
  39. brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +6 -5
  40. brainstate/nn/{_linear_mv.py → _event_linear.py} +2 -2
  41. brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +6 -5
  42. brainstate/nn/_exp_euler.py +200 -38
  43. brainstate/nn/_exp_euler_test.py +350 -8
  44. brainstate/nn/_linear.py +391 -71
  45. brainstate/nn/_linear_test.py +427 -59
  46. brainstate/nn/_metrics.py +1070 -0
  47. brainstate/nn/_metrics_test.py +611 -0
  48. brainstate/nn/_module.py +10 -3
  49. brainstate/nn/_module_test.py +1 -1
  50. brainstate/nn/_normalizations.py +688 -329
  51. brainstate/nn/_normalizations_test.py +663 -37
  52. brainstate/nn/_paddings.py +1020 -0
  53. brainstate/nn/_paddings_test.py +723 -0
  54. brainstate/nn/_poolings.py +1404 -342
  55. brainstate/nn/_poolings_test.py +828 -92
  56. brainstate/nn/{_rate_rnns.py → _rnns.py} +446 -54
  57. brainstate/nn/_rnns_test.py +593 -0
  58. brainstate/nn/_utils.py +132 -5
  59. brainstate/nn/_utils_test.py +402 -0
  60. brainstate/{init/_random_inits.py → nn/init.py} +301 -45
  61. brainstate/{init/_random_inits_test.py → nn/init_test.py} +51 -20
  62. brainstate/random/__init__.py +247 -1
  63. brainstate/random/_rand_funs.py +668 -346
  64. brainstate/random/_rand_funs_test.py +74 -1
  65. brainstate/random/_rand_seed.py +541 -76
  66. brainstate/random/_rand_seed_test.py +1 -1
  67. brainstate/random/_rand_state.py +601 -393
  68. brainstate/random/_rand_state_test.py +551 -0
  69. brainstate/transform/__init__.py +59 -0
  70. brainstate/transform/_ad_checkpoint.py +176 -0
  71. brainstate/{compile → transform}/_ad_checkpoint_test.py +1 -1
  72. brainstate/{augment → transform}/_autograd.py +360 -113
  73. brainstate/{augment → transform}/_autograd_test.py +2 -2
  74. brainstate/transform/_conditions.py +316 -0
  75. brainstate/{compile → transform}/_conditions_test.py +11 -11
  76. brainstate/{compile → transform}/_error_if.py +22 -20
  77. brainstate/{compile → transform}/_error_if_test.py +1 -1
  78. brainstate/transform/_eval_shape.py +145 -0
  79. brainstate/{augment → transform}/_eval_shape_test.py +1 -1
  80. brainstate/{compile → transform}/_jit.py +99 -46
  81. brainstate/{compile → transform}/_jit_test.py +3 -3
  82. brainstate/{compile → transform}/_loop_collect_return.py +219 -80
  83. brainstate/{compile → transform}/_loop_collect_return_test.py +1 -1
  84. brainstate/{compile → transform}/_loop_no_collection.py +133 -34
  85. brainstate/{compile → transform}/_loop_no_collection_test.py +2 -2
  86. brainstate/transform/_make_jaxpr.py +2016 -0
  87. brainstate/transform/_make_jaxpr_test.py +1510 -0
  88. brainstate/transform/_mapping.py +529 -0
  89. brainstate/transform/_mapping_test.py +194 -0
  90. brainstate/{compile → transform}/_progress_bar.py +78 -25
  91. brainstate/{augment → transform}/_random.py +65 -45
  92. brainstate/{compile → transform}/_unvmap.py +102 -5
  93. brainstate/transform/_util.py +286 -0
  94. brainstate/typing.py +594 -61
  95. brainstate/typing_test.py +780 -0
  96. brainstate/util/__init__.py +9 -32
  97. brainstate/util/_others.py +1025 -0
  98. brainstate/util/_others_test.py +962 -0
  99. brainstate/util/_pretty_pytree.py +1301 -0
  100. brainstate/util/_pretty_pytree_test.py +675 -0
  101. brainstate/util/{pretty_repr.py → _pretty_repr.py} +161 -27
  102. brainstate/util/_pretty_repr_test.py +696 -0
  103. brainstate/util/filter.py +557 -81
  104. brainstate/util/filter_test.py +912 -0
  105. brainstate/util/struct.py +769 -382
  106. brainstate/util/struct_test.py +602 -0
  107. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/METADATA +34 -17
  108. brainstate-0.2.0.dist-info/RECORD +111 -0
  109. brainstate/augment/__init__.py +0 -30
  110. brainstate/augment/_eval_shape.py +0 -99
  111. brainstate/augment/_mapping.py +0 -1060
  112. brainstate/augment/_mapping_test.py +0 -597
  113. brainstate/compile/__init__.py +0 -38
  114. brainstate/compile/_ad_checkpoint.py +0 -204
  115. brainstate/compile/_conditions.py +0 -256
  116. brainstate/compile/_make_jaxpr.py +0 -888
  117. brainstate/compile/_make_jaxpr_test.py +0 -156
  118. brainstate/compile/_util.py +0 -147
  119. brainstate/functional/__init__.py +0 -27
  120. brainstate/graph/_graph_node.py +0 -244
  121. brainstate/graph/_graph_node_test.py +0 -73
  122. brainstate/graph/_graph_operation_test.py +0 -563
  123. brainstate/init/__init__.py +0 -26
  124. brainstate/init/_base.py +0 -52
  125. brainstate/init/_generic.py +0 -244
  126. brainstate/init/_regular_inits.py +0 -105
  127. brainstate/init/_regular_inits_test.py +0 -50
  128. brainstate/nn/_inputs.py +0 -608
  129. brainstate/nn/_ltp.py +0 -28
  130. brainstate/nn/_neuron.py +0 -705
  131. brainstate/nn/_neuron_test.py +0 -161
  132. brainstate/nn/_others.py +0 -46
  133. brainstate/nn/_projection.py +0 -486
  134. brainstate/nn/_rate_rnns_test.py +0 -63
  135. brainstate/nn/_readout.py +0 -209
  136. brainstate/nn/_readout_test.py +0 -53
  137. brainstate/nn/_stp.py +0 -236
  138. brainstate/nn/_synapse.py +0 -505
  139. brainstate/nn/_synapse_test.py +0 -131
  140. brainstate/nn/_synaptic_projection.py +0 -423
  141. brainstate/nn/_synouts.py +0 -162
  142. brainstate/nn/_synouts_test.py +0 -57
  143. brainstate/nn/metrics.py +0 -388
  144. brainstate/optim/__init__.py +0 -38
  145. brainstate/optim/_base.py +0 -64
  146. brainstate/optim/_lr_scheduler.py +0 -448
  147. brainstate/optim/_lr_scheduler_test.py +0 -50
  148. brainstate/optim/_optax_optimizer.py +0 -152
  149. brainstate/optim/_optax_optimizer_test.py +0 -53
  150. brainstate/optim/_sgd_optimizer.py +0 -1104
  151. brainstate/random/_random_for_unit.py +0 -52
  152. brainstate/surrogate.py +0 -1957
  153. brainstate/transform.py +0 -23
  154. brainstate/util/caller.py +0 -98
  155. brainstate/util/others.py +0 -540
  156. brainstate/util/pretty_pytree.py +0 -945
  157. brainstate/util/pretty_pytree_test.py +0 -159
  158. brainstate/util/pretty_table.py +0 -2954
  159. brainstate/util/scaling.py +0 -258
  160. brainstate-0.1.10.dist-info/RECORD +0 -130
  161. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/WHEEL +0 -0
  162. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/licenses/LICENSE +0 -0
  163. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/top_level.txt +0 -0
brainstate/nn/_synaptic_projection.py DELETED
@@ -1,423 +0,0 @@
1
- # Copyright 2025 BDP Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
- # -*- coding: utf-8 -*-
16
-
17
-
18
- from typing import Callable, Union, Tuple
19
-
20
- import brainunit as u
21
-
22
- from brainstate import init
23
- from brainstate._state import ParamState
24
- from brainstate.typing import ArrayLike
25
- from ._dynamics import Dynamics, Projection
26
-
27
# Public gap-junction projection classes exported by this module.
__all__ = ['SymmetryGapJunction', 'AsymmetryGapJunction']
31
-
32
-
33
class align_pre_ltp(Projection):
    """
    Subclass of :class:`Projection` that adds no behavior of its own.

    NOTE(review): the name suggests pre-synaptic LTP alignment, but nothing in the
    visible code uses this class yet -- confirm the intended role.
    """
35
-
36
-
37
class align_post_ltp(Projection):
    """
    Subclass of :class:`Projection` that adds no behavior of its own.

    NOTE(review): the name suggests post-synaptic LTP alignment, but nothing in the
    visible code uses this class yet -- confirm the intended role.
    """
39
-
40
-
41
def get_gap_junction_post_key(i: int) -> str:
    """Build the current-input key for the ``i``-th post-synaptic gap-junction input."""
    return 'gap_junction_post_{}'.format(i)
43
-
44
-
45
def get_gap_junction_pre_key(i: int) -> str:
    """Build the current-input key for the ``i``-th pre-synaptic gap-junction input."""
    return 'gap_junction_pre_{}'.format(i)
47
-
48
-
49
class SymmetryGapJunction(Projection):
    """
    Symmetric electrical coupling (gap junction) between neuron populations.

    The same conductance is used in both directions of every connection. On each
    update, ``weight * (V_pre - V_post)`` is added to the post-synaptic neurons and
    subtracted from the pre-synaptic neurons (see
    :func:`symmetry_gap_junction_projection`).

    Parameters
    ----------
    couples : Union[Tuple[Dynamics, Dynamics], Dynamics]
        Either a single ``Dynamics`` object (pre and post populations are the same)
        or a ``(pre, post)`` tuple of two ``Dynamics`` objects.
    states : Union[str, Tuple[str, str]]
        Either one attribute name used on both sides, or a ``(pre_state, post_state)``
        tuple. Each name must refer to an attribute exposing ``.value`` on the
        corresponding population (typically the membrane potential).
    conn : Callable
        Called as ``conn(pre.out_size, post.out_size)``. It must return the
        ``(pre_ids, post_ids)`` index arrays describing the connections.
    weight : Union[Callable, ArrayLike]
        Conductance initializer or values. It is materialized through ``init.param``
        with shape ``(len(pre_ids),)`` and applied identically in both directions.
    param_type : type, optional
        State class wrapping the weights. Defaults to ``ParamState``.

    Raises
    ------
    AssertionError
        If a state name is not a string, or a population is not a ``Dynamics``.

    See Also
    --------
    AsymmetryGapJunction : Gap junctions with a different conductance per direction.
    """

    def __init__(
        self,
        couples: Union[Tuple[Dynamics, Dynamics], Dynamics],
        states: Union[str, Tuple[str, str]],
        conn: Callable,
        weight: Union[Callable, ArrayLike],
        param_type: type = ParamState
    ):
        super().__init__()

        # A single name / population is shared by both ends of the coupling.
        pre_state, post_state = (states, states) if isinstance(states, str) else states
        pre, post = (couples, couples) if isinstance(couples, Dynamics) else couples

        assert isinstance(pre_state, str), "pre_state must be a string representing the pre-synaptic state"
        assert isinstance(post_state, str), "post_state must be a string representing the post-synaptic state"
        assert isinstance(pre, Dynamics), "pre must be a Dynamics object"
        assert isinstance(post, Dynamics), "post must be a Dynamics object"

        self.pre_state, self.post_state = pre_state, post_state
        self.pre, self.post = pre, post

        # Connection indices: pre_ids[k] is coupled to post_ids[k].
        self.pre_ids, self.post_ids = conn(pre.out_size, post.out_size)
        n_connections = len(self.pre_ids)
        self.weight = param_type(init.param(weight, (n_connections,)))

    def update(self, *args, **kwargs):
        """
        Inject this step's gap-junction currents into both populations.

        Extra positional and keyword arguments are accepted and ignored.

        Returns
        -------
        ArrayLike
            The value returned by :func:`symmetry_gap_junction_projection`.

        Raises
        ------
        ValueError
            If either population lacks the configured state attribute.
        """
        if not hasattr(self.pre, self.pre_state):
            raise ValueError(f"pre_state {self.pre_state} not found in pre-synaptic neuron group")
        if not hasattr(self.post, self.post_state):
            raise ValueError(f"post_state {self.post_state} not found in post-synaptic neuron group")

        pre_value = getattr(self.pre, self.pre_state).value
        post_value = getattr(self.post, self.post_state).value
        return symmetry_gap_junction_projection(
            pre=self.pre,
            pre_value=pre_value,
            post=self.post,
            post_value=post_value,
            pre_ids=self.pre_ids,
            post_ids=self.post_ids,
            weight=self.weight.value,
        )
133
-
134
-
135
def symmetry_gap_junction_projection(
    pre: Dynamics,
    pre_value: ArrayLike,
    post: Dynamics,
    post_value: ArrayLike,
    pre_ids: ArrayLike,
    post_ids: ArrayLike,
    weight: ArrayLike,
):
    """
    Apply symmetric gap-junction coupling between two neuron groups.

    The same weight is used in both directions. For every connection ``k`` the
    current ``weight[k] * (pre_value[pre_ids[k]] - post_value[post_ids[k]])`` is
    added to post-synaptic neuron ``post_ids[k]`` and subtracted from pre-synaptic
    neuron ``pre_ids[k]``. Each group receives a full per-neuron array through
    ``add_current_input``.

    Parameters
    ----------
    pre : Dynamics
        The pre-synaptic neuron group.
    pre_value : ArrayLike
        State values (typically membrane potentials) of the pre-synaptic neurons.
    post : Dynamics
        The post-synaptic neuron group.
    post_value : ArrayLike
        State values (typically membrane potentials) of the post-synaptic neurons.
    pre_ids : ArrayLike
        Indices of pre-synaptic neurons that form gap junctions.
    post_ids : ArrayLike
        Indices of post-synaptic neurons; ``pre_ids[k]`` is connected to ``post_ids[k]``.
    weight : ArrayLike
        Conductances. Either a scalar shared by all connections, or an array whose
        first dimension equals ``len(pre_ids)``.

    Returns
    -------
    ArrayLike
        Per-neuron currents of shape ``pre.out_size`` accumulated for the
        pre-synaptic group *before* the sign flip. The pre-synaptic group actually
        receives the negation of this array.

    Notes
    -----
    The coupling is ``I = g (V_pre - V_post)``, flowing from pre to post. Equal and
    opposite currents are applied to the two ends of each connection. Duplicate
    indices accumulate because the scatter uses ``.at[...].add``.

    When ``pre`` and ``post`` are the same object, the pre-synaptic key is derived
    *after* the post-synaptic input has been registered. The two keys therefore use
    different counts as well as different prefixes.

    Raises
    ------
    AssertionError
        If ``weight`` is neither a scalar nor of length ``len(pre_ids)``, or if
        ``pre_ids`` and ``post_ids`` differ in length.
    """
    assert u.math.ndim(weight) == 0 or weight.shape[0] == len(pre_ids), \
        "weight must be a scalar or have the same length as pre_ids"
    assert len(pre_ids) == len(post_ids), "pre_ids and post_ids must have the same length"

    # Weighted voltage difference per connection (positive means pre -> post).
    diff = (pre_value[pre_ids] - post_value[post_ids]) * weight
    # Both sides carry the same physical unit, so look it up once.
    current_unit = u.get_unit(diff)

    # Post-synaptic side: scatter-add each connection's current onto its target.
    post_inputs = u.math.zeros(post.out_size, unit=current_unit).at[post_ids].add(diff)
    # Key index = number of current inputs the group already holds.
    post_key = get_gap_junction_post_key(0 if post.current_inputs is None else len(post.current_inputs))
    post.add_current_input(post_key, post_inputs)

    # Pre-synaptic side: same magnitudes, registered with opposite polarity.
    pre_inputs = u.math.zeros(pre.out_size, unit=current_unit).at[pre_ids].add(diff)
    pre_key = get_gap_junction_pre_key(0 if pre.current_inputs is None else len(pre.current_inputs))
    pre.add_current_input(pre_key, -pre_inputs)
    return pre_inputs
218
-
219
-
220
class AsymmetryGapJunction(Projection):
    """
    Asymmetric electrical coupling (gap junction) between neuron populations.

    Each connection carries two conductances, one per direction. On each update the
    voltage difference ``V_pre - V_post`` is scaled by the first conductance and
    added to the post-synaptic neuron. It is also scaled by the second conductance
    and subtracted from the pre-synaptic neuron (see
    :func:`asymmetry_gap_junction_projection`).

    Parameters
    ----------
    pre : Dynamics
        The pre-synaptic neuron group.
    pre_state : str
        Name of the pre-synaptic state attribute (typically the membrane potential).
    post : Dynamics
        The post-synaptic neuron group.
    post_state : str
        Name of the post-synaptic state attribute (typically the membrane potential).
    conn : Callable
        Called as ``conn(pre.out_size, post.out_size)``. It must return the
        ``(pre_ids, post_ids)`` index arrays describing the connections.
    weight : Union[Callable, ArrayLike]
        Conductance initializer or values, materialized with shape
        ``(len(pre_ids), 2)``. Column 0 scales the current delivered to the
        post-synaptic neuron; column 1 scales the current removed from the
        pre-synaptic neuron.
    param_type : type, optional
        State class wrapping the weights. Defaults to ``ParamState``.

    Examples
    --------
    >>> import brainstate
    >>> import brainunit as u
    >>> import numpy as np
    >>>
    >>> n_neurons = 100
    >>> pre_pop = brainstate.nn.LIF(n_neurons, V_rest=-70*u.mV, V_threshold=-50*u.mV)
    >>> post_pop = brainstate.nn.LIF(n_neurons, V_rest=-70*u.mV, V_threshold=-50*u.mV)
    >>> pre_pop.init_state()
    >>> post_pop.init_state()
    >>>
    >>> # Twice the conductance in the pre->post column as in the post->pre column.
    >>> weights = np.tile([2.0, 1.0], (n_neurons, 1)) * u.nS
    >>>
    >>> # ``one_to_one`` stands for any connection function returning (pre_ids, post_ids).
    >>> gap_junction = brainstate.nn.AsymmetryGapJunction(
    ...     pre=pre_pop,
    ...     pre_state='V',
    ...     post=post_pop,
    ...     post_state='V',
    ...     conn=one_to_one,
    ...     weight=weights
    ... )

    Notes
    -----
    Different conductances per direction can model rectifying electrical synapses
    that preferentially pass current one way.

    See Also
    --------
    SymmetryGapJunction : Gap junctions with identical conductance in both directions.
    """

    def __init__(
        self,
        pre: Dynamics,
        pre_state: str,
        post: Dynamics,
        post_state: str,
        conn: Callable,
        weight: Union[Callable, ArrayLike],
        param_type: type = ParamState
    ):
        super().__init__()

        assert isinstance(pre_state, str), "pre_state must be a string representing the pre-synaptic state"
        assert isinstance(post_state, str), "post_state must be a string representing the post-synaptic state"

        self.pre_state, self.post_state = pre_state, post_state
        self.pre, self.post = pre, post

        # Connection indices: pre_ids[k] is coupled to post_ids[k].
        self.pre_ids, self.post_ids = conn(pre.out_size, post.out_size)
        n_connections = len(self.pre_ids)
        # One (pre->post, post->pre) conductance pair per connection.
        self.weight = param_type(init.param(weight, (n_connections, 2)))

    def update(self, *args, **kwargs):
        """
        Inject this step's gap-junction currents into both populations.

        Extra positional and keyword arguments are accepted and ignored.

        Returns
        -------
        ArrayLike
            The value returned by :func:`asymmetry_gap_junction_projection`.

        Raises
        ------
        ValueError
            If either population lacks the configured state attribute.
        """
        if not hasattr(self.pre, self.pre_state):
            raise ValueError(f"pre_state {self.pre_state} not found in pre-synaptic neuron group")
        if not hasattr(self.post, self.post_state):
            raise ValueError(f"post_state {self.post_state} not found in post-synaptic neuron group")

        pre_value = getattr(self.pre, self.pre_state).value
        post_value = getattr(self.post, self.post_state).value
        return asymmetry_gap_junction_projection(
            pre=self.pre,
            pre_value=pre_value,
            post=self.post,
            post_value=post_value,
            pre_ids=self.pre_ids,
            post_ids=self.post_ids,
            weight=self.weight.value,
        )
323
-
324
-
325
def asymmetry_gap_junction_projection(
    pre: Dynamics,
    pre_value: ArrayLike,
    post: Dynamics,
    post_value: ArrayLike,
    pre_ids: ArrayLike,
    post_ids: ArrayLike,
    weight: ArrayLike,
):
    """
    Apply asymmetric gap-junction coupling between two neuron groups.

    For every connection ``k``, let ``d_k = pre_value[pre_ids[k]] - post_value[post_ids[k]]``.
    Then ``pre_weight_k * d_k`` is added to post-synaptic neuron ``post_ids[k]``, and
    ``post_weight_k * d_k`` is subtracted from pre-synaptic neuron ``pre_ids[k]``.
    Each group receives a full per-neuron array through ``add_current_input``.

    Parameters
    ----------
    pre : Dynamics
        The pre-synaptic neuron group.
    pre_value : ArrayLike
        State values (typically membrane potentials) of the pre-synaptic neurons.
    post : Dynamics
        The post-synaptic neuron group.
    post_value : ArrayLike
        State values (typically membrane potentials) of the post-synaptic neurons.
    pre_ids : ArrayLike
        Indices of pre-synaptic neurons that form gap junctions.
    post_ids : ArrayLike
        Indices of post-synaptic neurons; ``pre_ids[k]`` is connected to ``post_ids[k]``.
    weight : ArrayLike
        Conductances whose last dimension is ``[pre_weight, post_weight]``. Either a
        1-D array of length 2 (shared by every connection), or a 2-D array of shape
        ``(len(pre_ids), 2)`` (one pair per connection).

    Returns
    -------
    ArrayLike
        Per-neuron currents of shape ``pre.out_size`` accumulated for the
        pre-synaptic group *before* the sign flip. The pre-synaptic group actually
        receives the negation of this array.

    Notes
    -----
    - ``I_pre2post = g_pre * (V_pre - V_post)`` is delivered to the post-synaptic neuron.
    - ``I_post2pre = g_post * (V_pre - V_post)`` is removed from the pre-synaptic neuron.

    Because ``g_pre`` and ``g_post`` may differ, the coupling can be asymmetric.

    Raises
    ------
    ValueError
        If ``weight`` is not 1-D or 2-D (this includes scalars).
    AssertionError
        If the last dimension of ``weight`` is not 2, if ``pre_ids`` and ``post_ids``
        differ in length, or if a 2-D ``weight`` does not have ``len(pre_ids)`` rows.
    """
    # Check the rank before touching ``weight.shape[-1]``. A scalar has an empty
    # shape, which would otherwise raise IndexError instead of the documented ValueError.
    weight_ndim = u.math.ndim(weight)
    if weight_ndim not in (1, 2):
        raise ValueError("weight must be a 1D or 2D array for asymmetry gap junctions")
    assert weight.shape[-1] == 2, 'weight must be a 2-element array for asymmetry gap junctions'
    assert len(pre_ids) == len(post_ids), "pre_ids and post_ids must have the same length"

    if weight_ndim == 1:
        # A single (pre_weight, post_weight) pair shared by every connection.
        pre_weight, post_weight = weight[0], weight[1]
    else:
        # One (pre_weight, post_weight) row per connection.
        pre_weight, post_weight = weight[:, 0], weight[:, 1]
        assert pre_weight.shape[0] == len(pre_ids), "pre_weight must have the same length as pre_ids"
        assert post_weight.shape[0] == len(post_ids), "post_weight must have the same length as post_ids"

    # Voltage difference per connection, scaled separately for each direction.
    diff = pre_value[pre_ids] - post_value[post_ids]
    pre2post_current = diff * pre_weight
    post2pre_current = diff * post_weight

    # Post-synaptic side: scatter-add each connection's current onto its target.
    post_inputs = u.math.zeros(post.out_size, unit=u.get_unit(pre2post_current))
    post_inputs = post_inputs.at[post_ids].add(pre2post_current)
    # Key index = number of current inputs the group already holds.
    post_key = get_gap_junction_post_key(0 if post.current_inputs is None else len(post.current_inputs))
    post.add_current_input(post_key, post_inputs)

    # Pre-synaptic side: registered with opposite polarity.
    pre_inputs = u.math.zeros(pre.out_size, unit=u.get_unit(post2pre_current))
    pre_inputs = pre_inputs.at[pre_ids].add(post2pre_current)
    pre_key = get_gap_junction_pre_key(0 if pre.current_inputs is None else len(pre.current_inputs))
    pre.add_current_input(pre_key, -pre_inputs)
    return pre_inputs
brainstate/nn/_synouts.py DELETED
@@ -1,162 +0,0 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- # -*- coding: utf-8 -*-
17
-
18
- import brainunit as u
19
- import jax.numpy as jnp
20
-
21
- from brainstate.mixin import BindCondData
22
- from brainstate.typing import ArrayLike
23
- from ._module import Module
24
-
25
# Synaptic-output models exported by this module.
__all__ = [
    'SynOut',
    'COBA',
    'CUBA',
    'MgBlock',
]
28
-
29
-
30
class SynOut(Module, BindCondData):
    """
    Base class for synaptic output models.

    :class:`SynOut` derives from :class:`Module` and :class:`BindCondData`.
    Calling an instance forwards the stored conductance (``self._conductance``),
    followed by the call arguments, to :meth:`update`. Subclasses implement
    :meth:`update`.

    NOTE(review): ``self._conductance`` is expected to be populated via
    ``BindCondData.bind_cond`` before each call -- confirm against ``BindCondData``.
    """

    __module__ = 'brainstate.nn'

    def __init__(self, ):
        super().__init__()
        # No conductance is bound until ``bind_cond`` is called.
        self._conductance = None

    def __call__(self, *args, **kwargs):
        """
        Compute the synaptic output from the bound conductance.

        Raises
        ------
        ValueError
            If no conductance has been bound yet.
        """
        if self._conductance is None:
            raise ValueError(f'Please first pack conductance data at the current step using '
                             f'".{BindCondData.bind_cond.__name__}(data)". {self}')
        ret = self.update(self._conductance, *args, **kwargs)
        return ret

    def update(self, conductance, potential):
        """Map ``conductance`` and ``potential`` to an output; must be overridden."""
        raise NotImplementedError
52
-
53
-
54
class COBA(SynOut):
    r"""
    Conductance-based synaptic output.

    The post-synaptic current is the conductance multiplied by the driving force:

    .. math::

        I_{syn}(t) = g_{\mathrm{syn}}(t) (E - V(t))

    Parameters
    ----------
    E : ArrayLike
        The reversal potential.

    See Also
    --------
    CUBA
    """
    __module__ = 'brainstate.nn'

    def __init__(self, E: ArrayLike):
        super().__init__()
        self.E = E

    def update(self, conductance, potential):
        """Return ``conductance * (E - potential)``."""
        driving_force = self.E - potential
        return conductance * driving_force
82
-
83
-
84
class CUBA(SynOut):
    r"""
    Current-based synaptic output.

    The conductance is multiplied by a constant factor, independent of the
    membrane potential:

    .. math::

        I_{\mathrm{syn}}(t) = \mathrm{scale} \cdot g_{\mathrm{syn}}(t)

    Parameters
    ----------
    scale : ArrayLike
        Scaling factor applied to the conductance. Defaults to ``u.volt`` (1 V).

    See Also
    --------
    COBA
    """
    __module__ = 'brainstate.nn'

    def __init__(self, scale: ArrayLike = u.volt):
        super().__init__()
        self.scale = scale

    def update(self, conductance, potential=None):
        """Return ``conductance * scale``; ``potential`` is accepted but ignored."""
        return conductance * self.scale
110
-
111
-
112
class MgBlock(SynOut):
    r"""
    Synaptic output with voltage-dependent magnesium block.

    Given the synaptic conductance, the post-synaptic current is

    .. math::

        I_{syn}(t) = g_{\mathrm{syn}}(t) (E - V(t)) g_{\infty}(V,[{Mg}^{2+}]_{o})

    where the fraction of channels not blocked by magnesium is

    .. math::

        g_{\infty}(V,[{Mg}^{2+}]_{o}) = \left(1 + e^{-\alpha (V - V_{\mathrm{offset}})} \frac{[{Mg}^{2+}]_{o}}{\beta}\right)^{-1}

    Here :math:`[{Mg}^{2+}]_{o}` is the extracellular magnesium concentration.

    Parameters
    ----------
    E : ArrayLike
        Reversal potential of the synaptic current. Default 0.
    cc_Mg : ArrayLike
        Magnesium concentration. Default 1.2.
    alpha : ArrayLike
        Binding constant. Default 0.062.
    beta : ArrayLike
        Unbinding constant. Default 3.57.
    V_offset : ArrayLike
        Offset potential. Default 0.

    Notes
    -----
    The defaults are plain floats. Units, if any, are whatever the caller supplies.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        E: ArrayLike = 0.,
        cc_Mg: ArrayLike = 1.2,
        alpha: ArrayLike = 0.062,
        beta: ArrayLike = 3.57,
        V_offset: ArrayLike = 0.,
    ):
        super().__init__()

        self.E = E
        self.V_offset = V_offset
        self.cc_Mg = cc_Mg
        self.alpha = alpha
        self.beta = beta

    def update(self, conductance, potential):
        """Return the COBA current divided by the magnesium-block factor ``1 / g_inf``."""
        block_factor = 1 + self.cc_Mg / self.beta * jnp.exp(self.alpha * (self.V_offset - potential))
        unblocked_current = conductance * (self.E - potential)
        return unblocked_current / block_factor
brainstate/nn/_synouts_test.py DELETED
@@ -1,57 +0,0 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
-
17
- import unittest
18
-
19
- import brainunit as u
20
- import jax.numpy as jnp
21
- import numpy as np
22
-
23
- import brainstate
24
-
25
-
26
class TestSynOutModels(unittest.TestCase):
    """Check each synaptic-output model against its closed-form expression."""

    def setUp(self):
        self.conductance = jnp.array([0.5, 1.0, 1.5])
        self.potential = jnp.array([-70.0, -65.0, -60.0])
        # Single-element parameter arrays broadcast against the 3-neuron inputs.
        for attr, scalar in (
            ('E', -70.0),
            ('alpha', 0.062),
            ('beta', 3.57),
            ('cc_Mg', 1.2),
            ('V_offset', 0.0),
        ):
            setattr(self, attr, jnp.array([scalar]))

    def test_COBA(self):
        model = brainstate.nn.COBA(E=self.E)
        expected = self.conductance * (self.E - self.potential)
        np.testing.assert_array_almost_equal(model.update(self.conductance, self.potential), expected)

    def test_CUBA(self):
        model = brainstate.nn.CUBA()
        expected = self.conductance * model.scale
        self.assertTrue(u.math.allclose(model.update(self.conductance), expected))

    def test_MgBlock(self):
        params = dict(E=self.E, cc_Mg=self.cc_Mg, alpha=self.alpha, beta=self.beta, V_offset=self.V_offset)
        model = brainstate.nn.MgBlock(**params)
        norm = (1 + self.cc_Mg / self.beta * jnp.exp(self.alpha * (self.V_offset - self.potential)))
        expected = self.conductance * (self.E - self.potential) / norm
        np.testing.assert_array_almost_equal(model.update(self.conductance, self.potential), expected)
54
-
55
-
56
if __name__ == '__main__':
    # Run the tests in this module when it is executed directly.
    unittest.main()