brainstate 0.1.8__py2.py3-none-any.whl → 0.1.9__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (133)
  1. brainstate/__init__.py +58 -51
  2. brainstate/_compatible_import.py +148 -148
  3. brainstate/_state.py +1605 -1663
  4. brainstate/_state_test.py +52 -52
  5. brainstate/_utils.py +47 -47
  6. brainstate/augment/__init__.py +30 -30
  7. brainstate/augment/_autograd.py +778 -778
  8. brainstate/augment/_autograd_test.py +1289 -1289
  9. brainstate/augment/_eval_shape.py +99 -99
  10. brainstate/augment/_eval_shape_test.py +38 -38
  11. brainstate/augment/_mapping.py +1060 -1060
  12. brainstate/augment/_mapping_test.py +597 -597
  13. brainstate/augment/_random.py +151 -151
  14. brainstate/compile/__init__.py +38 -38
  15. brainstate/compile/_ad_checkpoint.py +204 -204
  16. brainstate/compile/_ad_checkpoint_test.py +49 -49
  17. brainstate/compile/_conditions.py +256 -256
  18. brainstate/compile/_conditions_test.py +220 -220
  19. brainstate/compile/_error_if.py +92 -92
  20. brainstate/compile/_error_if_test.py +52 -52
  21. brainstate/compile/_jit.py +346 -346
  22. brainstate/compile/_jit_test.py +143 -143
  23. brainstate/compile/_loop_collect_return.py +536 -536
  24. brainstate/compile/_loop_collect_return_test.py +58 -58
  25. brainstate/compile/_loop_no_collection.py +184 -184
  26. brainstate/compile/_loop_no_collection_test.py +50 -50
  27. brainstate/compile/_make_jaxpr.py +888 -888
  28. brainstate/compile/_make_jaxpr_test.py +156 -156
  29. brainstate/compile/_progress_bar.py +202 -202
  30. brainstate/compile/_unvmap.py +159 -159
  31. brainstate/compile/_util.py +147 -147
  32. brainstate/environ.py +563 -563
  33. brainstate/environ_test.py +62 -62
  34. brainstate/functional/__init__.py +27 -26
  35. brainstate/graph/__init__.py +29 -29
  36. brainstate/graph/_graph_node.py +244 -244
  37. brainstate/graph/_graph_node_test.py +73 -73
  38. brainstate/graph/_graph_operation.py +1738 -1738
  39. brainstate/graph/_graph_operation_test.py +563 -563
  40. brainstate/init/__init__.py +26 -26
  41. brainstate/init/_base.py +52 -52
  42. brainstate/init/_generic.py +244 -244
  43. brainstate/init/_random_inits.py +553 -553
  44. brainstate/init/_random_inits_test.py +149 -149
  45. brainstate/init/_regular_inits.py +105 -105
  46. brainstate/init/_regular_inits_test.py +50 -50
  47. brainstate/mixin.py +365 -363
  48. brainstate/mixin_test.py +77 -73
  49. brainstate/nn/__init__.py +135 -131
  50. brainstate/{functional → nn}/_activations.py +808 -813
  51. brainstate/{functional → nn}/_activations_test.py +331 -331
  52. brainstate/nn/_collective_ops.py +514 -514
  53. brainstate/nn/_collective_ops_test.py +43 -43
  54. brainstate/nn/_common.py +178 -178
  55. brainstate/nn/_conv.py +501 -501
  56. brainstate/nn/_conv_test.py +238 -238
  57. brainstate/nn/_delay.py +509 -502
  58. brainstate/nn/_delay_test.py +238 -184
  59. brainstate/nn/_dropout.py +426 -426
  60. brainstate/nn/_dropout_test.py +100 -100
  61. brainstate/nn/_dynamics.py +1343 -1343
  62. brainstate/nn/_dynamics_test.py +78 -78
  63. brainstate/nn/_elementwise.py +1119 -1119
  64. brainstate/nn/_elementwise_test.py +169 -169
  65. brainstate/nn/_embedding.py +58 -58
  66. brainstate/nn/_exp_euler.py +92 -92
  67. brainstate/nn/_exp_euler_test.py +35 -35
  68. brainstate/nn/_fixedprob.py +239 -239
  69. brainstate/nn/_fixedprob_test.py +114 -114
  70. brainstate/nn/_inputs.py +608 -608
  71. brainstate/nn/_linear.py +424 -424
  72. brainstate/nn/_linear_mv.py +83 -83
  73. brainstate/nn/_linear_mv_test.py +120 -120
  74. brainstate/nn/_linear_test.py +107 -107
  75. brainstate/nn/_ltp.py +28 -28
  76. brainstate/nn/_module.py +377 -377
  77. brainstate/nn/_module_test.py +40 -40
  78. brainstate/nn/_neuron.py +705 -705
  79. brainstate/nn/_neuron_test.py +161 -161
  80. brainstate/nn/_normalizations.py +975 -918
  81. brainstate/nn/_normalizations_test.py +73 -73
  82. brainstate/{functional → nn}/_others.py +46 -46
  83. brainstate/nn/_poolings.py +1177 -1177
  84. brainstate/nn/_poolings_test.py +217 -217
  85. brainstate/nn/_projection.py +486 -486
  86. brainstate/nn/_rate_rnns.py +554 -554
  87. brainstate/nn/_rate_rnns_test.py +63 -63
  88. brainstate/nn/_readout.py +209 -209
  89. brainstate/nn/_readout_test.py +53 -53
  90. brainstate/nn/_stp.py +236 -236
  91. brainstate/nn/_synapse.py +505 -505
  92. brainstate/nn/_synapse_test.py +131 -131
  93. brainstate/nn/_synaptic_projection.py +423 -423
  94. brainstate/nn/_synouts.py +162 -162
  95. brainstate/nn/_synouts_test.py +57 -57
  96. brainstate/nn/_utils.py +89 -89
  97. brainstate/nn/metrics.py +388 -388
  98. brainstate/optim/__init__.py +38 -38
  99. brainstate/optim/_base.py +64 -64
  100. brainstate/optim/_lr_scheduler.py +448 -448
  101. brainstate/optim/_lr_scheduler_test.py +50 -50
  102. brainstate/optim/_optax_optimizer.py +152 -152
  103. brainstate/optim/_optax_optimizer_test.py +53 -53
  104. brainstate/optim/_sgd_optimizer.py +1104 -1104
  105. brainstate/random/__init__.py +24 -24
  106. brainstate/random/_rand_funs.py +3616 -3616
  107. brainstate/random/_rand_funs_test.py +567 -567
  108. brainstate/random/_rand_seed.py +210 -210
  109. brainstate/random/_rand_seed_test.py +48 -48
  110. brainstate/random/_rand_state.py +1409 -1409
  111. brainstate/random/_random_for_unit.py +52 -52
  112. brainstate/surrogate.py +1957 -1957
  113. brainstate/transform.py +23 -23
  114. brainstate/typing.py +304 -304
  115. brainstate/util/__init__.py +50 -50
  116. brainstate/util/caller.py +98 -98
  117. brainstate/util/error.py +55 -55
  118. brainstate/util/filter.py +469 -469
  119. brainstate/util/others.py +540 -540
  120. brainstate/util/pretty_pytree.py +945 -945
  121. brainstate/util/pretty_pytree_test.py +159 -159
  122. brainstate/util/pretty_repr.py +328 -328
  123. brainstate/util/pretty_table.py +2954 -2954
  124. brainstate/util/scaling.py +258 -258
  125. brainstate/util/struct.py +523 -523
  126. {brainstate-0.1.8.dist-info → brainstate-0.1.9.dist-info}/METADATA +91 -99
  127. brainstate-0.1.9.dist-info/RECORD +130 -0
  128. {brainstate-0.1.8.dist-info → brainstate-0.1.9.dist-info}/WHEEL +1 -1
  129. {brainstate-0.1.8.dist-info → brainstate-0.1.9.dist-info/licenses}/LICENSE +202 -202
  130. brainstate/functional/_normalization.py +0 -81
  131. brainstate/functional/_spikes.py +0 -204
  132. brainstate-0.1.8.dist-info/RECORD +0 -132
  133. {brainstate-0.1.8.dist-info → brainstate-0.1.9.dist-info}/top_level.txt +0 -0
@@ -1,423 +1,423 @@
1
- # Copyright 2025 BDP Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
- # -*- coding: utf-8 -*-
16
-
17
-
18
- from typing import Callable, Union, Tuple
19
-
20
- import brainunit as u
21
-
22
- from brainstate import init
23
- from brainstate._state import ParamState
24
- from brainstate.typing import ArrayLike
25
- from ._dynamics import Dynamics, Projection
26
-
27
# Public API of this module: the two gap-junction projection classes.
__all__ = ['SymmetryGapJunction', 'AsymmetryGapJunction']
31
-
32
-
33
class align_pre_ltp(Projection):
    """Placeholder ``Projection`` subclass; it adds no attributes or behavior of its own."""
35
-
36
-
37
class align_post_ltp(Projection):
    """Placeholder ``Projection`` subclass; it adds no attributes or behavior of its own."""
39
-
40
-
41
def get_gap_junction_post_key(i: int):
    """Return the key used to register the ``i``-th gap-junction current on a post-synaptic group."""
    return 'gap_junction_post_{}'.format(i)
43
-
44
-
45
def get_gap_junction_pre_key(i: int):
    """Return the key used to register the ``i``-th gap-junction current on a pre-synaptic group."""
    return 'gap_junction_pre_{}'.format(i)
47
-
48
-
49
class SymmetryGapJunction(Projection):
    """
    Symmetric electrical coupling (gap junction) between neuron populations.

    Each connection uses a single conductance for both directions. On every
    update, the coupling current is proportional to the difference of the
    chosen state variable between the two connected neurons. It is injected
    into both populations with opposite signs.

    Parameters
    ----------
    couples : Union[Tuple[Dynamics, Dynamics], Dynamics]
        A single ``Dynamics`` object couples a population with itself. A
        ``(pre, post)`` pair couples two populations.
    states : Union[str, Tuple[str, str]]
        Name of the state attribute read on both sides, or a
        ``(pre_state, post_state)`` pair of names (typically the membrane
        potential).
    conn : Callable
        Called as ``conn(pre.out_size, post.out_size)``. It must return
        ``(pre_ids, post_ids)``, where ``pre_ids[k]`` is connected to ``post_ids[k]``.
    weight : Union[Callable, ArrayLike]
        Conductance(s). Initialized through ``init.param`` with shape
        ``(len(pre_ids),)``. The same value applies in both directions.
    param_type : type, optional
        State class wrapping the initialized weights. Defaults to ``ParamState``.

    Notes
    -----
    Identical conductance in both directions keeps the coupling balanced:
    whatever one endpoint gains, the other loses.

    See Also
    --------
    AsymmetryGapJunction : For gap junctions with different conductances in each direction.
    """

    def __init__(
        self,
        couples: Union[Tuple[Dynamics, Dynamics], Dynamics],
        states: Union[str, Tuple[str, str]],
        conn: Callable,
        weight: Union[Callable, ArrayLike],
        param_type: type = ParamState
    ):
        super().__init__()

        # A lone name / population means "the same on both sides"; otherwise
        # unpack the explicit (pre, post) pair.
        pre_state, post_state = (states, states) if isinstance(states, str) else states
        pre, post = (couples, couples) if isinstance(couples, Dynamics) else couples

        assert isinstance(pre_state, str), "pre_state must be a string representing the pre-synaptic state"
        assert isinstance(post_state, str), "post_state must be a string representing the post-synaptic state"
        assert isinstance(pre, Dynamics), "pre must be a Dynamics object"
        assert isinstance(post, Dynamics), "post must be a Dynamics object"

        # Assignment order is kept as: state names, populations, ids, weight.
        self.pre_state, self.post_state = pre_state, post_state
        self.pre, self.post = pre, post
        self.pre_ids, self.post_ids = conn(pre.out_size, post.out_size)
        n_connections = len(self.pre_ids)
        self.weight = param_type(init.param(weight, (n_connections,)))

    def update(self, *args, **kwargs):
        """
        Inject this step's gap-junction currents into both populations.

        Reads ``.value`` of the configured state attribute on each population
        and delegates to ``symmetry_gap_junction_projection``. Positional and
        keyword arguments are accepted but ignored.

        Returns
        -------
        ArrayLike
            Whatever ``symmetry_gap_junction_projection`` returns.

        Raises
        ------
        ValueError
            If either population lacks the configured state attribute.
        """
        if not hasattr(self.pre, self.pre_state):
            raise ValueError(f"pre_state {self.pre_state} not found in pre-synaptic neuron group")
        if not hasattr(self.post, self.post_state):
            raise ValueError(f"post_state {self.post_state} not found in post-synaptic neuron group")

        pre_value = getattr(self.pre, self.pre_state).value
        post_value = getattr(self.post, self.post_state).value
        return symmetry_gap_junction_projection(
            pre=self.pre,
            pre_value=pre_value,
            post=self.post,
            post_value=post_value,
            pre_ids=self.pre_ids,
            post_ids=self.post_ids,
            weight=self.weight.value,
        )
133
-
134
-
135
def symmetry_gap_junction_projection(
    pre: Dynamics,
    pre_value: ArrayLike,
    post: Dynamics,
    post_value: ArrayLike,
    pre_ids: ArrayLike,
    post_ids: ArrayLike,
    weight: ArrayLike,
):
    """
    Calculate symmetrical electrical coupling through gap junctions between neurons.

    For each connection ``k`` the coupling current is::

        I_k = weight_k * (pre_value[pre_ids[k]] - post_value[post_ids[k]])

    ``I_k`` is summed onto neuron ``post_ids[k]`` of ``post``, and ``-I_k``
    onto neuron ``pre_ids[k]`` of ``pre``. Each connection therefore injects
    equal and opposite currents into its two endpoints.

    Parameters
    ----------
    pre : Dynamics
        The pre-synaptic neuron group. It receives the negated currents.
    pre_value : ArrayLike
        State values (typically membrane potentials) of the pre-synaptic neurons.
    post : Dynamics
        The post-synaptic neuron group. It receives the positive currents.
    post_value : ArrayLike
        State values (typically membrane potentials) of the post-synaptic neurons.
    pre_ids : ArrayLike
        Indices of pre-synaptic neurons that form gap junctions.
    post_ids : ArrayLike
        Indices of post-synaptic neurons that form gap junctions,
        where each pre_ids[i] is connected to post_ids[i].
    weight : ArrayLike
        Conductance weights for the gap junctions. Either a scalar (the same
        weight for all connections) or an array whose first dimension matches
        ``len(pre_ids)``.

    Returns
    -------
    ArrayLike
        Array of shape ``pre.out_size`` holding, at each pre-synaptic index, the
        sum of ``I_k`` over that neuron's connections. This is the array
        *before* the sign flip: the pre-synaptic group actually receives
        ``-result``.

    Notes
    -----
    Side effects: registers one new current input on ``post`` and one on
    ``pre`` via ``add_current_input``. Each key's index is the number of
    entries already in that group's ``current_inputs`` (``0`` when it is
    ``None``).

    Raises
    ------
    AssertionError
        If weight dimensionality is incorrect or pre_ids and post_ids have different lengths.
    """
    assert u.math.ndim(weight) == 0 or weight.shape[0] == len(pre_ids), \
        "weight must be a scalar or have the same length as pre_ids"
    assert len(pre_ids) == len(post_ids), "pre_ids and post_ids must have the same length"

    # Conductance-weighted potential difference for every connection.
    diff = (pre_value[pre_ids] - post_value[post_ids]) * weight
    unit = u.get_unit(diff)

    # Post-synaptic side: scatter-add +diff. With JAX ``.at[...].add`` semantics,
    # repeated indices in ``post_ids`` accumulate instead of overwriting.
    inputs = u.math.zeros(post.out_size, unit=unit)
    inputs = inputs.at[post_ids].add(diff)
    key = get_gap_junction_post_key(0 if post.current_inputs is None else len(post.current_inputs))
    post.add_current_input(key, inputs)

    # Pre-synaptic side: same magnitude, opposite polarity.
    inputs = u.math.zeros(pre.out_size, unit=unit)
    inputs = inputs.at[pre_ids].add(diff)
    key = get_gap_junction_pre_key(0 if pre.current_inputs is None else len(pre.current_inputs))
    pre.add_current_input(key, -inputs)

    # Returned un-negated; see the Returns section of the docstring.
    return inputs
218
-
219
-
220
class AsymmetryGapJunction(Projection):
    """
    Implements an asymmetric electrical coupling (gap junction) between neuron populations.

    Each connection carries two conductances, one for each direction, so the
    current delivered to the post-synaptic neuron can differ in magnitude
    from the current delivered to the pre-synaptic neuron. Both currents are
    proportional to the difference of the chosen state variable between the
    connected neurons.

    Parameters
    ----------
    pre : Dynamics
        The pre-synaptic neuron group dynamics object.
    pre_state : str
        Name of the state attribute read on ``pre`` (typically the membrane potential).
    post : Dynamics
        The post-synaptic neuron group dynamics object.
    post_state : str
        Name of the state attribute read on ``post`` (typically the membrane potential).
    conn : Callable
        Called as ``conn(pre.out_size, post.out_size)``. It must return
        ``(pre_ids, post_ids)``, where ``pre_ids[k]`` is connected to ``post_ids[k]``.
    weight : Union[Callable, ArrayLike]
        Conductances, initialized through ``init.param`` with shape
        ``(len(pre_ids), 2)``. Column 0 scales the current delivered to
        ``post``. Column 1 scales the current delivered to ``pre``.
    param_type : type, optional
        The parameter state type to use for weights, defaults to ParamState.

    Examples
    --------
    >>> import brainstate
    >>> import brainunit as u
    >>> import numpy as np
    >>>
    >>> n_neurons = 100
    >>> pre_pop = brainstate.nn.LIF(n_neurons)
    >>> post_pop = brainstate.nn.LIF(n_neurons)
    >>> pre_pop.init_state()
    >>> post_pop.init_state()
    >>>
    >>> # ``conn`` receives (pre.out_size, post.out_size) and returns (pre_ids, post_ids).
    >>> def one_to_one(pre_size, post_size):
    ...     ids = np.arange(int(np.prod(pre_size)))
    ...     return ids, ids
    >>>
    >>> # Column 0 -> current into post, column 1 -> current into pre.
    >>> w = np.ones((n_neurons, 2))
    >>> w[:, 0] = 2.0  # double conductance in the pre -> post direction
    >>>
    >>> gap_junction = brainstate.nn.AsymmetryGapJunction(
    ...     pre=pre_pop,
    ...     pre_state='V',
    ...     post=post_pop,
    ...     post_state='V',
    ...     conn=one_to_one,
    ...     weight=w * u.nS,
    ... )

    Notes
    -----
    Different conductances per direction can model rectifying electrical
    synapses that preferentially pass current one way.

    See Also
    --------
    SymmetryGapJunction : For gap junctions with identical conductance in both directions.
    """

    def __init__(
        self,
        pre: Dynamics,
        pre_state: str,
        post: Dynamics,
        post_state: str,
        conn: Callable,
        weight: Union[Callable, ArrayLike],
        param_type: type = ParamState
    ):
        super().__init__()

        assert isinstance(pre_state, str), "pre_state must be a string representing the pre-synaptic state"
        assert isinstance(post_state, str), "post_state must be a string representing the post-synaptic state"
        self.pre_state = pre_state
        self.post_state = post_state
        self.pre = pre
        self.post = post
        self.pre_ids, self.post_ids = conn(pre.out_size, post.out_size)
        # One (to-post, to-pre) conductance pair per connection.
        self.weight = param_type(init.param(weight, (len(self.pre_ids), 2)))

    def update(self, *args, **kwargs):
        """
        Inject this step's asymmetric gap-junction currents into both populations.

        Reads ``.value`` of the configured state attribute on each population
        and delegates to ``asymmetry_gap_junction_projection``. Positional and
        keyword arguments are accepted but ignored.

        Raises
        ------
        ValueError
            If either population lacks the configured state attribute.
        """
        if not hasattr(self.pre, self.pre_state):
            raise ValueError(f"pre_state {self.pre_state} not found in pre-synaptic neuron group")
        if not hasattr(self.post, self.post_state):
            raise ValueError(f"post_state {self.post_state} not found in post-synaptic neuron group")
        pre = getattr(self.pre, self.pre_state).value
        post = getattr(self.post, self.post_state).value

        return asymmetry_gap_junction_projection(
            pre=self.pre,
            pre_value=pre,
            post=self.post,
            post_value=post,
            pre_ids=self.pre_ids,
            post_ids=self.post_ids,
            weight=self.weight.value,
        )
323
-
324
-
325
def asymmetry_gap_junction_projection(
    pre: Dynamics,
    pre_value: ArrayLike,
    post: Dynamics,
    post_value: ArrayLike,
    pre_ids: ArrayLike,
    post_ids: ArrayLike,
    weight: ArrayLike,
):
    """
    Calculate asymmetrical electrical coupling through gap junctions between neurons.

    For each connection ``k``, let ``d_k = pre_value[pre_ids[k]] - post_value[post_ids[k]]``.
    Then:

    - ``g_pre * d_k`` is summed onto neuron ``post_ids[k]`` of ``post``;
    - ``-(g_post * d_k)``, i.e. ``g_post * (V_post - V_pre)``, is summed onto
      neuron ``pre_ids[k]`` of ``pre``.

    Parameters
    ----------
    pre : Dynamics
        The pre-synaptic neuron group dynamics object.
    pre_value : ArrayLike
        State values (typically membrane potentials) of the pre-synaptic neurons.
    post : Dynamics
        The post-synaptic neuron group dynamics object.
    post_value : ArrayLike
        State values (typically membrane potentials) of the post-synaptic neurons.
    pre_ids : ArrayLike
        Indices of pre-synaptic neurons that form gap junctions.
    post_ids : ArrayLike
        Indices of post-synaptic neurons that form gap junctions,
        where each pre_ids[i] is connected to post_ids[i].
    weight : ArrayLike
        Conductances with last dimension 2: ``[g_pre, g_post]``. ``g_pre``
        scales the current delivered to ``post``; ``g_post`` scales the current
        delivered to ``pre``. Either a 1-D array of length 2 (shared by all
        connections) or a 2-D array of shape ``(len(pre_ids), 2)``.

    Returns
    -------
    ArrayLike
        Array of shape ``pre.out_size`` holding, at each pre-synaptic index, the
        sum of ``g_post * d_k`` over that neuron's connections. This is the array
        *before* the sign flip: the pre-synaptic group actually receives
        ``-result``.

    Notes
    -----
    Side effects: registers one new current input on ``post`` and one on
    ``pre`` via ``add_current_input``. Each key's index is the number of
    entries already in that group's ``current_inputs`` (``0`` when it is
    ``None``).

    Raises
    ------
    ValueError
        If ``weight`` is not 1-D or 2-D, including a scalar weight.
    AssertionError
        If the last dimension of ``weight`` is not 2, if pre_ids and post_ids
        have different lengths, or if a 2-D weight does not have one row per
        connection.
    """
    # Check the rank before touching ``weight.shape``. A scalar has an empty
    # shape, so ``weight.shape[-1]`` would raise IndexError instead of the
    # documented ValueError.
    ndim = u.math.ndim(weight)
    if ndim not in (1, 2):
        raise ValueError("weight must be a 1D or 2D array for asymmetry gap junctions")
    assert weight.shape[-1] == 2, 'weight must be a 2-element array for asymmetry gap junctions'
    assert len(pre_ids) == len(post_ids), "pre_ids and post_ids must have the same length"

    if ndim == 1:
        # One [g_pre, g_post] pair shared by every connection.
        pre_weight = weight[0]
        post_weight = weight[1]
    else:
        # One row per connection; column 0 is g_pre, column 1 is g_post.
        pre_weight = weight[:, 0]
        post_weight = weight[:, 1]
        assert pre_weight.shape[0] == len(pre_ids), "pre_weight must have the same length as pre_ids"
        assert post_weight.shape[0] == len(post_ids), "post_weight must have the same length as post_ids"

    # Potential difference across each connection, scaled per direction.
    diff = pre_value[pre_ids] - post_value[post_ids]
    pre2post_current = diff * pre_weight
    post2pre_current = diff * post_weight

    # Post-synaptic side: scatter-add; repeated indices accumulate (JAX ``.at[...].add``).
    inputs = u.math.zeros(post.out_size, unit=u.get_unit(pre2post_current))
    inputs = inputs.at[post_ids].add(pre2post_current)
    key = get_gap_junction_post_key(0 if post.current_inputs is None else len(post.current_inputs))
    post.add_current_input(key, inputs)

    # Pre-synaptic side: negated so the current flows from post toward pre.
    inputs = u.math.zeros(pre.out_size, unit=u.get_unit(post2pre_current))
    inputs = inputs.at[pre_ids].add(post2pre_current)
    key = get_gap_junction_pre_key(0 if pre.current_inputs is None else len(pre.current_inputs))
    pre.add_current_input(key, -inputs)

    # Returned un-negated; see the Returns section of the docstring.
    return inputs
1
+ # Copyright 2025 BDP Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ # -*- coding: utf-8 -*-
16
+
17
+
18
+ from typing import Callable, Union, Tuple
19
+
20
+ import brainunit as u
21
+
22
+ from brainstate import init
23
+ from brainstate._state import ParamState
24
+ from brainstate.typing import ArrayLike
25
+ from ._dynamics import Dynamics, Projection
26
+
27
# Public API of this module: the two gap-junction projection classes.
__all__ = ['SymmetryGapJunction', 'AsymmetryGapJunction']
31
+
32
+
33
class align_pre_ltp(Projection):
    """Placeholder ``Projection`` subclass; it adds no attributes or behavior of its own."""
35
+
36
+
37
class align_post_ltp(Projection):
    """Placeholder ``Projection`` subclass; it adds no attributes or behavior of its own."""
39
+
40
+
41
def get_gap_junction_post_key(i: int):
    """Return the key used to register the ``i``-th gap-junction current on a post-synaptic group."""
    return 'gap_junction_post_{}'.format(i)
43
+
44
+
45
def get_gap_junction_pre_key(i: int):
    """Return the key used to register the ``i``-th gap-junction current on a pre-synaptic group."""
    return 'gap_junction_pre_{}'.format(i)
47
+
48
+
49
class SymmetryGapJunction(Projection):
    """
    Symmetric electrical coupling (gap junction) between neuron populations.

    Each connection uses a single conductance for both directions. On every
    update, the coupling current is proportional to the difference of the
    chosen state variable between the two connected neurons. It is injected
    into both populations with opposite signs.

    Parameters
    ----------
    couples : Union[Tuple[Dynamics, Dynamics], Dynamics]
        A single ``Dynamics`` object couples a population with itself. A
        ``(pre, post)`` pair couples two populations.
    states : Union[str, Tuple[str, str]]
        Name of the state attribute read on both sides, or a
        ``(pre_state, post_state)`` pair of names (typically the membrane
        potential).
    conn : Callable
        Called as ``conn(pre.out_size, post.out_size)``. It must return
        ``(pre_ids, post_ids)``, where ``pre_ids[k]`` is connected to ``post_ids[k]``.
    weight : Union[Callable, ArrayLike]
        Conductance(s). Initialized through ``init.param`` with shape
        ``(len(pre_ids),)``. The same value applies in both directions.
    param_type : type, optional
        State class wrapping the initialized weights. Defaults to ``ParamState``.

    Notes
    -----
    Identical conductance in both directions keeps the coupling balanced:
    whatever one endpoint gains, the other loses.

    See Also
    --------
    AsymmetryGapJunction : For gap junctions with different conductances in each direction.
    """

    def __init__(
        self,
        couples: Union[Tuple[Dynamics, Dynamics], Dynamics],
        states: Union[str, Tuple[str, str]],
        conn: Callable,
        weight: Union[Callable, ArrayLike],
        param_type: type = ParamState
    ):
        super().__init__()

        # A lone name / population means "the same on both sides"; otherwise
        # unpack the explicit (pre, post) pair.
        pre_state, post_state = (states, states) if isinstance(states, str) else states
        pre, post = (couples, couples) if isinstance(couples, Dynamics) else couples

        assert isinstance(pre_state, str), "pre_state must be a string representing the pre-synaptic state"
        assert isinstance(post_state, str), "post_state must be a string representing the post-synaptic state"
        assert isinstance(pre, Dynamics), "pre must be a Dynamics object"
        assert isinstance(post, Dynamics), "post must be a Dynamics object"

        # Assignment order is kept as: state names, populations, ids, weight.
        self.pre_state, self.post_state = pre_state, post_state
        self.pre, self.post = pre, post
        self.pre_ids, self.post_ids = conn(pre.out_size, post.out_size)
        n_connections = len(self.pre_ids)
        self.weight = param_type(init.param(weight, (n_connections,)))

    def update(self, *args, **kwargs):
        """
        Inject this step's gap-junction currents into both populations.

        Reads ``.value`` of the configured state attribute on each population
        and delegates to ``symmetry_gap_junction_projection``. Positional and
        keyword arguments are accepted but ignored.

        Returns
        -------
        ArrayLike
            Whatever ``symmetry_gap_junction_projection`` returns.

        Raises
        ------
        ValueError
            If either population lacks the configured state attribute.
        """
        if not hasattr(self.pre, self.pre_state):
            raise ValueError(f"pre_state {self.pre_state} not found in pre-synaptic neuron group")
        if not hasattr(self.post, self.post_state):
            raise ValueError(f"post_state {self.post_state} not found in post-synaptic neuron group")

        pre_value = getattr(self.pre, self.pre_state).value
        post_value = getattr(self.post, self.post_state).value
        return symmetry_gap_junction_projection(
            pre=self.pre,
            pre_value=pre_value,
            post=self.post,
            post_value=post_value,
            pre_ids=self.pre_ids,
            post_ids=self.post_ids,
            weight=self.weight.value,
        )
133
+
134
+
135
def symmetry_gap_junction_projection(
    pre: Dynamics,
    pre_value: ArrayLike,
    post: Dynamics,
    post_value: ArrayLike,
    pre_ids: ArrayLike,
    post_ids: ArrayLike,
    weight: ArrayLike,
):
    """
    Calculate symmetrical electrical coupling through gap junctions between neurons.

    For each connection ``k`` the coupling current is::

        I_k = weight_k * (pre_value[pre_ids[k]] - post_value[post_ids[k]])

    ``I_k`` is summed onto neuron ``post_ids[k]`` of ``post``, and ``-I_k``
    onto neuron ``pre_ids[k]`` of ``pre``. Each connection therefore injects
    equal and opposite currents into its two endpoints.

    Parameters
    ----------
    pre : Dynamics
        The pre-synaptic neuron group. It receives the negated currents.
    pre_value : ArrayLike
        State values (typically membrane potentials) of the pre-synaptic neurons.
    post : Dynamics
        The post-synaptic neuron group. It receives the positive currents.
    post_value : ArrayLike
        State values (typically membrane potentials) of the post-synaptic neurons.
    pre_ids : ArrayLike
        Indices of pre-synaptic neurons that form gap junctions.
    post_ids : ArrayLike
        Indices of post-synaptic neurons that form gap junctions,
        where each pre_ids[i] is connected to post_ids[i].
    weight : ArrayLike
        Conductance weights for the gap junctions. Either a scalar (the same
        weight for all connections) or an array whose first dimension matches
        ``len(pre_ids)``.

    Returns
    -------
    ArrayLike
        Array of shape ``pre.out_size`` holding, at each pre-synaptic index, the
        sum of ``I_k`` over that neuron's connections. This is the array
        *before* the sign flip: the pre-synaptic group actually receives
        ``-result``.

    Notes
    -----
    Side effects: registers one new current input on ``post`` and one on
    ``pre`` via ``add_current_input``. Each key's index is the number of
    entries already in that group's ``current_inputs`` (``0`` when it is
    ``None``).

    Raises
    ------
    AssertionError
        If weight dimensionality is incorrect or pre_ids and post_ids have different lengths.
    """
    assert u.math.ndim(weight) == 0 or weight.shape[0] == len(pre_ids), \
        "weight must be a scalar or have the same length as pre_ids"
    assert len(pre_ids) == len(post_ids), "pre_ids and post_ids must have the same length"

    # Conductance-weighted potential difference for every connection.
    diff = (pre_value[pre_ids] - post_value[post_ids]) * weight
    unit = u.get_unit(diff)

    # Post-synaptic side: scatter-add +diff. With JAX ``.at[...].add`` semantics,
    # repeated indices in ``post_ids`` accumulate instead of overwriting.
    inputs = u.math.zeros(post.out_size, unit=unit)
    inputs = inputs.at[post_ids].add(diff)
    key = get_gap_junction_post_key(0 if post.current_inputs is None else len(post.current_inputs))
    post.add_current_input(key, inputs)

    # Pre-synaptic side: same magnitude, opposite polarity.
    inputs = u.math.zeros(pre.out_size, unit=unit)
    inputs = inputs.at[pre_ids].add(diff)
    key = get_gap_junction_pre_key(0 if pre.current_inputs is None else len(pre.current_inputs))
    pre.add_current_input(key, -inputs)

    # Returned un-negated; see the Returns section of the docstring.
    return inputs
218
+
219
+
220
class AsymmetryGapJunction(Projection):
    """
    Implements an asymmetric electrical coupling (gap junction) between neuron populations.

    This class represents electrical synapses where the conductance in one direction can differ
    from the conductance in the opposite direction. Unlike chemical synapses, gap junctions
    allow bidirectional flow of electrical current directly between neurons, with the current
    magnitude proportional to the voltage difference between connected neurons.

    Parameters
    ----------
    pre : Dynamics
        The pre-synaptic neuron group dynamics object.
    pre_state : str
        Name of the state attribute on ``pre`` (typically membrane potential). It is
        looked up on every :meth:`update` call and its ``.value`` is read.
    post : Dynamics
        The post-synaptic neuron group dynamics object.
    post_state : str
        Name of the state attribute on ``post`` (typically membrane potential). It is
        looked up on every :meth:`update` call and its ``.value`` is read.
    conn : Callable
        Connection function, called once as ``conn(pre.out_size, post.out_size)``. It must
        return a pair ``(pre_ids, post_ids)`` of equal length, where ``pre_ids[i]`` is
        coupled to ``post_ids[i]``.
    weight : Union[Callable, ArrayLike]
        Conductance weights for the gap junctions, initialized via ``init.param`` with the
        target shape ``(len(pre_ids), 2)``. Column 0 holds the pre->post conductance and
        column 1 the post->pre conductance of each connection, so the two directions of
        the same pair can differ.
    param_type : type, optional
        The parameter state type used to wrap the weights, defaults to ParamState.

    Raises
    ------
    AssertionError
        If ``pre_state`` or ``post_state`` is not a string.

    Examples
    --------
    >>> import brainstate
    >>> import brainunit as u
    >>> import numpy as np
    >>>
    >>> # Create two neuron populations
    >>> n_neurons = 100
    >>> pre_pop = brainstate.nn.LIF(n_neurons, V_rest=-70*u.mV, V_threshold=-50*u.mV)
    >>> post_pop = brainstate.nn.LIF(n_neurons, V_rest=-70*u.mV, V_threshold=-50*u.mV)
    >>> pre_pop.init_state()
    >>> post_pop.init_state()
    >>>
    >>> # ``conn`` receives the two group sizes and returns (pre_ids, post_ids);
    >>> # here neuron i of ``pre_pop`` is coupled to neuron i of ``post_pop``.
    >>> def one_to_one(pre_size, post_size):
    ...     ids = np.arange(n_neurons)
    ...     return ids, ids
    >>>
    >>> # Column 0: pre->post conductance (doubled), column 1: post->pre conductance.
    >>> # Build the array directly rather than mutating a unit-bearing array in place.
    >>> weights = np.stack([np.full(n_neurons, 2.0), np.ones(n_neurons)], axis=-1) * u.nS
    >>>
    >>> gap_junction = brainstate.nn.AsymmetryGapJunction(
    ...     pre=pre_pop,
    ...     pre_state='V',
    ...     post=post_pop,
    ...     post_state='V',
    ...     conn=one_to_one,
    ...     weight=weights
    ... )

    Notes
    -----
    The asymmetric gap junction allows for different conductances in each direction between
    the same pair of neurons. This can model rectifying electrical synapses that preferentially
    allow current to flow in one direction.

    See Also
    --------
    SymmetryGapJunction : For gap junctions with identical conductance in both directions.
    """

    def __init__(
        self,
        pre: Dynamics,
        pre_state: str,
        post: Dynamics,
        post_state: str,
        conn: Callable,
        weight: Union[Callable, ArrayLike],
        param_type: type = ParamState
    ):
        super().__init__()

        assert isinstance(pre_state, str), "pre_state must be a string representing the pre-synaptic state"
        assert isinstance(post_state, str), "post_state must be a string representing the post-synaptic state"
        # Attribute names of the coupled state variables, resolved lazily in ``update``.
        self.pre_state = pre_state
        self.post_state = post_state
        self.pre = pre
        self.post = post
        # Connectivity is fixed at construction time: pre_ids[i] <-> post_ids[i].
        self.pre_ids, self.post_ids = conn(pre.out_size, post.out_size)
        # One [pre->post, post->pre] conductance pair per connection.
        self.weight = param_type(init.param(weight, (len(self.pre_ids), 2)))

    def update(self, *args, **kwargs):
        """
        Inject the gap-junction currents into both neuron groups for the current step.

        Extra positional and keyword arguments are accepted and ignored.

        Returns
        -------
        ArrayLike
            Whatever ``asymmetry_gap_junction_projection`` returns for the current
            state values and weights.

        Raises
        ------
        ValueError
            If ``pre_state`` / ``post_state`` is not an attribute of the corresponding group.
        """
        if not hasattr(self.pre, self.pre_state):
            raise ValueError(f"pre_state {self.pre_state} not found in pre-synaptic neuron group")
        if not hasattr(self.post, self.post_state):
            raise ValueError(f"post_state {self.post_state} not found in post-synaptic neuron group")
        pre = getattr(self.pre, self.pre_state).value
        post = getattr(self.post, self.post_state).value

        return asymmetry_gap_junction_projection(
            pre=self.pre,
            pre_value=pre,
            post=self.post,
            post_value=post,
            pre_ids=self.pre_ids,
            post_ids=self.post_ids,
            weight=self.weight.value,
        )
323
+
324
+
325
def asymmetry_gap_junction_projection(
    pre: Dynamics,
    pre_value: ArrayLike,
    post: Dynamics,
    post_value: ArrayLike,
    pre_ids: ArrayLike,
    post_ids: ArrayLike,
    weight: ArrayLike,
):
    """
    Calculate asymmetrical electrical coupling through gap junctions between neurons.

    This function implements bidirectional gap junction coupling where different weights
    can be applied in each direction. It computes the electrical current flowing between
    connected neurons based on their potential differences and updates both pre-synaptic
    and post-synaptic neuron groups with appropriate input currents.

    Parameters
    ----------
    pre : Dynamics
        The pre-synaptic neuron group dynamics object.
    pre_value : ArrayLike
        State values (typically membrane potentials) of the pre-synaptic neurons.
    post : Dynamics
        The post-synaptic neuron group dynamics object.
    post_value : ArrayLike
        State values (typically membrane potentials) of the post-synaptic neurons.
    pre_ids : ArrayLike
        Indices of pre-synaptic neurons that form gap junctions.
    post_ids : ArrayLike
        Indices of post-synaptic neurons that form gap junctions,
        where each pre_ids[i] is connected to post_ids[i].
    weight : ArrayLike
        Conductance weights for the gap junctions. The last dimension must have size 2 and
        holds [pre_weight, post_weight] for each connection. It can be a 1D array
        [pre_weight, post_weight] (same weights for all connections) or a 2D array with
        shape [len(pre_ids), 2] for connection-specific weights.

    Returns
    -------
    ArrayLike
        The post-to-pre coupling currents scattered onto the pre-synaptic neuron group,
        *before* the sign flip. The pre-synaptic group actually receives the negation of
        this array as its current input.

    Notes
    -----
    The electrical coupling is implemented with direction-specific conductances:
    - I_pre2post = g_pre * (V_pre - V_post) is added to the post neuron
    - I_post2pre = g_post * (V_pre - V_post) is subtracted from the pre neuron,
      i.e. the pre neuron receives g_post * (V_post - V_pre)
    where g_pre and g_post can be different, allowing for asymmetrical coupling.

    Raises
    ------
    ValueError
        If weight is not a 1D or 2D array (this includes scalar weights).
    AssertionError
        If the last dimension of weight is not 2, if a 2D weight does not have one row
        per connection, or if pre_ids and post_ids have different lengths.
    """
    ndim = u.math.ndim(weight)
    # Validate the rank first: for a scalar (0-d) weight, ``weight.shape[-1]`` would
    # raise an IndexError instead of the documented ValueError.
    if ndim not in (1, 2):
        raise ValueError("weight must be a 1D or 2D array for asymmetry gap junctions")
    assert weight.shape[-1] == 2, 'weight must be a 2-element array for asymmetry gap junctions'
    assert len(pre_ids) == len(post_ids), "pre_ids and post_ids must have the same length"
    if ndim == 1:
        # A single [pre_weight, post_weight] pair shared by every connection.
        pre_weight = weight[0]
        post_weight = weight[1]
    else:
        # One row per connection; column 0 is pre_weight, column 1 is post_weight.
        # Since len(pre_ids) == len(post_ids), this single check covers both columns.
        assert weight.shape[0] == len(pre_ids), "pre_weight must have the same length as pre_ids"
        pre_weight = weight[:, 0]
        post_weight = weight[:, 1]

    # Calculate the voltage difference between connected pre-synaptic and post-synaptic neurons
    # and multiply by the direction-specific connection weights
    diff = pre_value[pre_ids] - post_value[post_ids]
    pre2post_current = diff * pre_weight
    post2pre_current = diff * post_weight

    # add to post-synaptic neuron group
    # Initialize the input currents for the post-synaptic neuron group
    inputs = u.math.zeros(post.out_size, unit=u.get_unit(pre2post_current))
    # Scatter-add so that several connections onto the same neuron accumulate
    inputs = inputs.at[post_ids].add(pre2post_current)
    # Derive the input key from how many current inputs the group already has
    key = get_gap_junction_post_key(0 if post.current_inputs is None else len(post.current_inputs))
    # Add the input currents to the post-synaptic neuron group
    post.add_current_input(key, inputs)

    # add to pre-synaptic neuron group
    # Initialize the input currents for the pre-synaptic neuron group
    inputs = u.math.zeros(pre.out_size, unit=u.get_unit(post2pre_current))
    # Scatter-add so that several connections from the same neuron accumulate
    inputs = inputs.at[pre_ids].add(post2pre_current)
    # Derive the input key from how many current inputs the group already has
    key = get_gap_junction_pre_key(0 if pre.current_inputs is None else len(pre.current_inputs))
    # Add the input currents to the pre-synaptic neuron group with opposite polarity
    pre.add_current_input(key, -inputs)
    # NOTE: returns the un-negated array (kept for backward compatibility).
    return inputs