brainstate 0.1.3__py2.py3-none-any.whl → 0.1.5__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +1 -1
- brainstate/_compatible_import.py +1 -16
- brainstate/_state.py +1 -0
- brainstate/augment/_mapping.py +9 -9
- brainstate/augment/_mapping_test.py +162 -0
- brainstate/compile/_jit.py +14 -5
- brainstate/compile/_make_jaxpr.py +78 -22
- brainstate/compile/_make_jaxpr_test.py +13 -2
- brainstate/graph/_graph_node.py +1 -1
- brainstate/graph/_graph_operation.py +4 -4
- brainstate/mixin.py +31 -2
- brainstate/nn/__init__.py +8 -5
- brainstate/nn/_common.py +7 -19
- brainstate/nn/_delay.py +13 -1
- brainstate/nn/_dropout.py +5 -4
- brainstate/nn/_dynamics.py +39 -44
- brainstate/nn/_exp_euler.py +13 -16
- brainstate/nn/{_fixedprob_mv.py → _fixedprob.py} +95 -24
- brainstate/nn/_inputs.py +1 -1
- brainstate/nn/_linear_mv.py +1 -1
- brainstate/nn/_module.py +5 -5
- brainstate/nn/_projection.py +190 -98
- brainstate/nn/_synapse.py +5 -9
- brainstate/nn/_synaptic_projection.py +376 -86
- brainstate/random/_rand_state.py +13 -7
- brainstate/surrogate.py +1 -1
- brainstate/typing.py +1 -1
- brainstate/util/__init__.py +14 -14
- brainstate/util/{_pretty_pytree.py → pretty_pytree.py} +2 -2
- {brainstate-0.1.3.dist-info → brainstate-0.1.5.dist-info}/METADATA +1 -1
- {brainstate-0.1.3.dist-info → brainstate-0.1.5.dist-info}/RECORD +42 -42
- /brainstate/nn/{_fixedprob_mv_test.py → _fixedprob_test.py} +0 -0
- /brainstate/util/{_caller.py → caller.py} +0 -0
- /brainstate/util/{_error.py → error.py} +0 -0
- /brainstate/util/{_others.py → others.py} +0 -0
- /brainstate/util/{_pretty_repr.py → pretty_repr.py} +0 -0
- /brainstate/util/{_pretty_table.py → pretty_table.py} +0 -0
- /brainstate/util/{_scaling.py → scaling.py} +0 -0
- /brainstate/util/{_struct.py → struct.py} +0 -0
- {brainstate-0.1.3.dist-info → brainstate-0.1.5.dist-info}/LICENSE +0 -0
- {brainstate-0.1.3.dist-info → brainstate-0.1.5.dist-info}/WHEEL +0 -0
- {brainstate-0.1.3.dist-info → brainstate-0.1.5.dist-info}/top_level.txt +0 -0
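Several `brainstate.util` submodules lose their leading underscore in 0.1.5 (for example `_pretty_repr.py` becomes `pretty_repr.py`), and `brainstate/nn/_fixedprob_mv.py` is renamed to `_fixedprob.py`. Code that imported through the old private module paths would need updating; below is a minimal migration sketch, assuming the affected names remain re-exported from `brainstate.util` (the star-imports later in this diff suggest this, but the exact `__all__` contents are not visible here).

```python
# Hypothetical migration sketch for the brainstate.util renames in 0.1.5.
# `PrettyRepr` is taken from the pretty_pytree imports shown later in this diff;
# whether it is re-exported at package level is an assumption.

# 0.1.3: private submodule path
# from brainstate.util._pretty_repr import PrettyRepr

# 0.1.5: the renamed public submodule ...
from brainstate.util.pretty_repr import PrettyRepr

# ... or the package-level re-export, which is robust to such renames
from brainstate.util import PrettyRepr
```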
brainstate/nn/_synaptic_projection.py
CHANGED
@@ -15,119 +15,409 @@
 # -*- coding: utf-8 -*-
 
 
-from typing import Callable, Union
+from typing import Callable, Union, Tuple
 
 import brainunit as u
 
-from brainstate
-from brainstate.
+from brainstate import init
+from brainstate._state import ParamState
+from brainstate.typing import ArrayLike
 from ._dynamics import Dynamics, Projection
-from ._projection import AlignPostProj, RawProj
-from ._stp import ShortTermPlasticity
-from ._synapse import Synapse
-from ._synouts import SynOut
 
 __all__ = [
-    '
-    '
+    'SymmetryGapJunction',
+    'AsymmetryGapJunction',
 ]
 
 
-class
+class align_pre_ltp(Projection):
+    pass
+
+
+class align_post_ltp(Projection):
+    pass
+
+
+def get_gap_junction_post_key(i: int):
+    return f'gap_junction_post_{i}'
+
+
+def get_gap_junction_pre_key(i: int):
+    return f'gap_junction_pre_{i}'
+
+
+class SymmetryGapJunction(Projection):
     """
-
-
-    This class
-
-
-
-
-
-
-
-
-
-
-
-
-
+    Implements a symmetric electrical coupling (gap junction) between neuron populations.
+
+    This class represents electrical synapses where the conductance is identical in both
+    directions. Gap junctions allow bidirectional flow of electrical current directly between
+    neurons, with the current magnitude proportional to the voltage difference between
+    connected neurons.
+
+    Parameters
+    ----------
+    couples : Union[Tuple[Dynamics, Dynamics], Dynamics]
+        Either a single Dynamics object (when pre and post populations are the same)
+        or a tuple of two Dynamics objects (pre, post) representing the coupled neuron populations.
+    states : Union[str, Tuple[str, str]]
+        Either a single string (when pre and post states are the same)
+        or a tuple of two strings (pre_state, post_state) representing the state variables
+        to use for calculating voltage differences (typically membrane potentials).
+    conn : Callable
+        Connection function that returns pre_ids and post_ids arrays defining connections.
+    weight : Union[Callable, ArrayLike]
+        Conductance weights for the gap junctions. The same weight applies in both directions
+        of the connection.
+    param_type : type, optional
+        The parameter state type to use for weights, defaults to ParamState.
+
+    Notes
+    -----
+    The symmetric gap junction applies identical conductance in both directions between
+    connected neurons, ensuring balanced electrical coupling in the network.
+
+    See Also
+    --------
+    AsymmetryGapJunction : For gap junctions with different conductances in each direction.
     """
 
     def __init__(
         self,
-
-
-
-
-
-        post: Dynamics,
-        stp: ShortTermPlasticity = None,
+        couples: Union[Tuple[Dynamics, Dynamics], Dynamics],
+        states: Union[str, Tuple[str, str]],
+        conn: Callable,
+        weight: Union[Callable, ArrayLike],
+        param_type: type = ParamState
     ):
         super().__init__()
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+        if isinstance(states, str):
+            pre_state = states
+            post_state = states
+        else:
+            pre_state, post_state = states
+        if isinstance(couples, Dynamics):
+            pre = couples
+            post = couples
+        else:
+            pre, post = couples
+        assert isinstance(pre_state, str), "pre_state must be a string representing the pre-synaptic state"
+        assert isinstance(post_state, str), "post_state must be a string representing the post-synaptic state"
+        assert isinstance(pre, Dynamics), "pre must be a Dynamics object"
+        assert isinstance(post, Dynamics), "post must be a Dynamics object"
+        self.pre_state = pre_state
+        self.post_state = post_state
+        self.pre = pre
+        self.post = post
+        self.pre_ids, self.post_ids = conn(pre.out_size, post.out_size)
+        self.weight = param_type(init.param(weight, (len(self.pre_ids),)))
+
+    def update(self, *args, **kwargs):
+        if not hasattr(self.pre, self.pre_state):
+            raise ValueError(f"pre_state {self.pre_state} not found in pre-synaptic neuron group")
+        if not hasattr(self.post, self.post_state):
+            raise ValueError(f"post_state {self.post_state} not found in post-synaptic neuron group")
+        pre = getattr(self.pre, self.pre_state).value
+        post = getattr(self.post, self.post_state).value
+
+        return symmetry_gap_junction_projection(
+            pre=self.pre,
+            pre_value=pre,
+            post=self.post,
+            post_value=post,
+            pre_ids=self.pre_ids,
+            post_ids=self.post_ids,
+            weight=self.weight.value,
+        )
+
+
+def symmetry_gap_junction_projection(
+    pre: Dynamics,
+    pre_value: ArrayLike,
+    post: Dynamics,
+    post_value: ArrayLike,
+    pre_ids: ArrayLike,
+    post_ids: ArrayLike,
+    weight: ArrayLike,
+):
+    """
+    Calculate symmetrical electrical coupling through gap junctions between neurons.
+
+    This function implements bidirectional gap junction coupling where the same weight is
+    applied in both directions. It computes the electrical current flowing between
+    connected neurons based on their potential differences and updates both pre-synaptic
+    and post-synaptic neuron groups with appropriate input currents.
+
+    Parameters
+    ----------
+    pre : Dynamics
+        The pre-synaptic neuron group dynamics object.
+    pre_value : ArrayLike
+        State values (typically membrane potentials) of the pre-synaptic neurons.
+    post : Dynamics
+        The post-synaptic neuron group dynamics object.
+    post_value : ArrayLike
+        State values (typically membrane potentials) of the post-synaptic neurons.
+    pre_ids : ArrayLike
+        Indices of pre-synaptic neurons that form gap junctions.
+    post_ids : ArrayLike
+        Indices of post-synaptic neurons that form gap junctions,
+        where each pre_ids[i] is connected to post_ids[i].
+    weight : ArrayLike
+        Conductance weights for the gap junctions. Can be a scalar (same weight for all connections)
+        or an array with length matching pre_ids.
+
+    Returns
+    -------
+    ArrayLike
+        The input currents that were added to the pre-synaptic neuron group.
+
+    Notes
+    -----
+    The electrical coupling is implemented as I = g(V_pre - V_post), where:
+    - I is the current flowing from pre to post neuron
+    - g is the gap junction conductance (weight)
+    - V_pre and V_post are the membrane potentials of connected neurons
+
+    Equal but opposite currents are applied to both connected neurons, ensuring
+    conservation of current in the network.
+
+    Raises
+    ------
+    AssertionError
+        If weight dimensionality is incorrect or pre_ids and post_ids have different lengths.
+    """
+    assert u.math.ndim(weight) == 0 or weight.shape[0] == len(pre_ids), \
+        "weight must be a scalar or have the same length as pre_ids"
+    assert len(pre_ids) == len(post_ids), "pre_ids and post_ids must have the same length"
+    # Calculate the voltage difference between connected pre-synaptic and post-synaptic neurons
+    # and multiply by the connection weights
+    diff = (pre_value[pre_ids] - post_value[post_ids]) * weight
+
+    # add to post-synaptic neuron group
+    # Initialize the input currents for the post-synaptic neuron group
+    inputs = u.math.zeros(post.out_size, unit=u.get_unit(diff))
+    # Add the calculated current to the corresponding post-synaptic neurons
+    inputs = inputs.at[post_ids].add(diff)
+    # Generate a unique key for the post-synaptic input currents
+    key = get_gap_junction_post_key(0 if post.current_inputs is None else len(post.current_inputs))
+    # Add the input currents to the post-synaptic neuron group
+    post.add_current_input(key, inputs)
+
+    # add to pre-synaptic neuron group
+    # Initialize the input currents for the pre-synaptic neuron group
+    inputs = u.math.zeros(pre.out_size, unit=u.get_unit(diff))
+    # Add the calculated current to the corresponding pre-synaptic neurons
+    inputs = inputs.at[pre_ids].add(diff)
+    # Generate a unique key for the pre-synaptic input currents
+    key = get_gap_junction_pre_key(0 if pre.current_inputs is None else len(pre.current_inputs))
+    # Add the input currents to the pre-synaptic neuron group with opposite polarity
+    pre.add_current_input(key, -inputs)
+    return inputs
+
+
+class AsymmetryGapJunction(Projection):
     """
-
+    Implements an asymmetric electrical coupling (gap junction) between neuron populations.
+
+    This class represents electrical synapses where the conductance in one direction can differ
+    from the conductance in the opposite direction. Unlike chemical synapses, gap junctions
+    allow bidirectional flow of electrical current directly between neurons, with the current
+    magnitude proportional to the voltage difference between connected neurons.
 
-
-
-
-
+    Parameters
+    ----------
+    pre : Dynamics
+        The pre-synaptic neuron group dynamics object.
+    pre_state : str
+        Name of the state variable in pre-synaptic neurons (typically membrane potential).
+    post : Dynamics
+        The post-synaptic neuron group dynamics object.
+    post_state : str
+        Name of the state variable in post-synaptic neurons (typically membrane potential).
+    conn : Callable
+        Connection function that returns pre_ids and post_ids arrays defining connections.
+    weight : Union[Callable, ArrayLike]
+        Conductance weights for the gap junctions. Must have shape [..., 2] where the last
+        dimension contains [pre_weight, post_weight] for each connection, defining
+        potentially different conductances in each direction.
+    param_type : type, optional
+        The parameter state type to use for weights, defaults to ParamState.
 
-
-
-
-
-
-
-
+    Examples
+    --------
+    >>> import brainstate
+    >>> import brainunit as u
+    >>> import numpy as np
+    >>>
+    >>> # Create two neuron populations
+    >>> n_neurons = 100
+    >>> pre_pop = brainstate.nn.LIF(n_neurons, V_rest=-70*u.mV, V_threshold=-50*u.mV)
+    >>> post_pop = brainstate.nn.LIF(n_neurons, V_rest=-70*u.mV, V_threshold=-50*u.mV)
+    >>> pre_pop.init_state()
+    >>> post_pop.init_state()
+    >>>
+    >>> # Create asymmetric gap junction with different weights in each direction
+    >>> weights = np.ones((n_neurons, 2)) * u.nS
+    >>> weights[:, 0] *= 2.0  # Double weight in pre->post direction
+    >>>
+    >>> gap_junction = brainstate.nn.AsymmetryGapJunction(
+    ...     pre=pre_pop,
+    ...     pre_state='V',
+    ...     post=post_pop,
+    ...     post_state='V',
+    ...     conn=one_to_one,
+    ...     weight=weights
+    ... )
 
+    Notes
+    -----
+    The asymmetric gap junction allows for different conductances in each direction between
+    the same pair of neurons. This can model rectifying electrical synapses that preferentially
+    allow current to flow in one direction.
+
+    See Also
+    --------
+    SymmetryGapJunction : For gap junctions with identical conductance in both directions.
     """
+
     def __init__(
         self,
-
-
-        syn: Union[AlignPost, ParamDescriber[AlignPost]],
-        out: Union[SynOut, ParamDescriber[SynOut]],
+        pre: Dynamics,
+        pre_state: str,
         post: Dynamics,
-
+        post_state: str,
+        conn: Callable,
+        weight: Union[Callable, ArrayLike],
+        param_type: type = ParamState
     ):
         super().__init__()
-        self.spike_generator = spike_generator
-        self.projection = AlignPostProj(comm=comm, syn=syn, out=out, post=post)
-        self.stp = stp
-
-    def update(self, *x):
-        for fun in self.spike_generator:
-            x = fun(*x)
-        if isinstance(x, (tuple, list)):
-            x = tuple(x)
-        else:
-            x = (x,)
-        assert len(x) == 1, "Spike generator must return a single value or a tuple/list of values"
-        x = brainevent.BinaryArray(x[0])  # Ensure input is a BinaryFloat for spike generation
-        if self.stp is not None:
-            x = brainevent.MaskedFloat(self.stp(x))  # Ensure STP output is a MaskedFloat
-        return self.projection(x)
 
+        assert isinstance(pre_state, str), "pre_state must be a string representing the pre-synaptic state"
+        assert isinstance(post_state, str), "post_state must be a string representing the post-synaptic state"
+        self.pre_state = pre_state
+        self.post_state = post_state
+        self.pre = pre
+        self.post = post
+        self.pre_ids, self.post_ids = conn(pre.out_size, post.out_size)
+        self.weight = param_type(init.param(weight, (len(self.pre_ids), 2)))
 
-
-
+    def update(self, *args, **kwargs):
+        if not hasattr(self.pre, self.pre_state):
+            raise ValueError(f"pre_state {self.pre_state} not found in pre-synaptic neuron group")
+        if not hasattr(self.post, self.post_state):
+            raise ValueError(f"post_state {self.post_state} not found in post-synaptic neuron group")
+        pre = getattr(self.pre, self.pre_state).value
+        post = getattr(self.post, self.post_state).value
 
+        return asymmetry_gap_junction_projection(
+            pre=self.pre,
+            pre_value=pre,
+            post=self.post,
+            post_value=post,
+            pre_ids=self.pre_ids,
+            post_ids=self.post_ids,
+            weight=self.weight.value,
+        )
 
-
-
+
+def asymmetry_gap_junction_projection(
+    pre: Dynamics,
+    pre_value: ArrayLike,
+    post: Dynamics,
+    post_value: ArrayLike,
+    pre_ids: ArrayLike,
+    post_ids: ArrayLike,
+    weight: ArrayLike,
+):
+    """
+    Calculate asymmetrical electrical coupling through gap junctions between neurons.
+
+    This function implements bidirectional gap junction coupling where different weights
+    can be applied in each direction. It computes the electrical current flowing between
+    connected neurons based on their potential differences and updates both pre-synaptic
+    and post-synaptic neuron groups with appropriate input currents.
+
+    Parameters
+    ----------
+    pre : Dynamics
+        The pre-synaptic neuron group dynamics object.
+    pre_value : ArrayLike
+        State values (typically membrane potentials) of the pre-synaptic neurons.
+    post : Dynamics
+        The post-synaptic neuron group dynamics object.
+    post_value : ArrayLike
+        State values (typically membrane potentials) of the post-synaptic neurons.
+    pre_ids : ArrayLike
+        Indices of pre-synaptic neurons that form gap junctions.
+    post_ids : ArrayLike
+        Indices of post-synaptic neurons that form gap junctions,
+        where each pre_ids[i] is connected to post_ids[i].
+    weight : ArrayLike
+        Conductance weights for the gap junctions. Must have shape [..., 2], where
+        the last dimension contains [pre_weight, post_weight] for each connection.
+        Can be a 1D array [pre_weight, post_weight] (same weights for all connections)
+        or a 2D array with shape [len(pre_ids), 2] for connection-specific weights.
+
+    Returns
+    -------
+    ArrayLike
+        The input currents that were added to the pre-synaptic neuron group.
+
+    Notes
+    -----
+    The electrical coupling is implemented with direction-specific conductances:
+    - I_pre2post = g_pre * (V_pre - V_post) flowing from pre to post neuron
+    - I_post2pre = g_post * (V_pre - V_post) flowing from post to pre neuron
+    where g_pre and g_post can be different, allowing for asymmetrical coupling.
+
+    Raises
+    ------
+    AssertionError
+        If weight dimensionality is incorrect or pre_ids and post_ids have different lengths.
+    ValueError
+        If weight shape is incompatible with asymmetrical gap junction requirements.
+    """
+    assert weight.shape[-1] == 2, 'weight must be a 2-element array for asymmetry gap junctions'
+    assert len(pre_ids) == len(post_ids), "pre_ids and post_ids must have the same length"
+    if u.math.ndim(weight) == 1:
+        # If weight is a 1D array, it should have two elements for pre and post weights
+        assert weight.shape[0] == 2, "weight must be a 2-element array for asymmetry gap junctions"
+        pre_weight = weight[0]
+        post_weight = weight[1]
+    elif u.math.ndim(weight) == 2:
+        # If weight is a 2D array, it should have two rows for pre and post weights
+        pre_weight = weight[:, 0]
+        post_weight = weight[:, 1]
+        assert pre_weight.shape[0] == len(pre_ids), "pre_weight must have the same length as pre_ids"
+        assert post_weight.shape[0] == len(post_ids), "post_weight must have the same length as post_ids"
+    else:
+        raise ValueError("weight must be a 1D or 2D array for asymmetry gap junctions")
+
+    # Calculate the voltage difference between connected pre-synaptic and post-synaptic neurons
+    # and multiply by the connection weights
+    diff = pre_value[pre_ids] - post_value[post_ids]
+    pre2post_current = diff * pre_weight
+    post2pre_current = diff * post_weight
+
+    # add to post-synaptic neuron group
+    # Initialize the input currents for the post-synaptic neuron group
+    inputs = u.math.zeros(post.out_size, unit=u.get_unit(pre2post_current))
+    # Add the calculated current to the corresponding post-synaptic neurons
+    inputs = inputs.at[post_ids].add(pre2post_current)
+    # Generate a unique key for the post-synaptic input currents
+    key = get_gap_junction_post_key(0 if post.current_inputs is None else len(post.current_inputs))
+    # Add the input currents to the post-synaptic neuron group
+    post.add_current_input(key, inputs)
+
+    # add to pre-synaptic neuron group
+    # Initialize the input currents for the pre-synaptic neuron group
+    inputs = u.math.zeros(pre.out_size, unit=u.get_unit(post2pre_current))
+    # Add the calculated current to the corresponding pre-synaptic neurons
+    inputs = inputs.at[pre_ids].add(post2pre_current)
+    # Generate a unique key for the pre-synaptic input currents
+    key = get_gap_junction_pre_key(0 if pre.current_inputs is None else len(pre.current_inputs))
+    # Add the input currents to the pre-synaptic neuron group with opposite polarity
+    pre.add_current_input(key, -inputs)
+    return inputs
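For orientation, here is a minimal usage sketch of the new `SymmetryGapJunction` class, modeled on the `AsymmetryGapJunction` docstring example above. The `one_to_one` connector is a hypothetical stand-in for any callable that maps `(pre_size, post_size)` to `(pre_ids, post_ids)` index arrays, and the `LIF` constructor arguments simply mirror the docstring example, so they may differ from the actual `brainstate.nn.LIF` signature.

```python
import numpy as np
import brainunit as u
import brainstate


def one_to_one(pre_size, post_size):
    # Hypothetical connector: couple neuron i in `pre` with neuron i in `post`.
    n = min(int(np.prod(pre_size)), int(np.prod(post_size)))
    ids = np.arange(n)
    return ids, ids


n_neurons = 100
pop = brainstate.nn.LIF(n_neurons, V_rest=-70 * u.mV, V_threshold=-50 * u.mV)
pop.init_state()

# Couple the population to itself: one conductance per connection, applied
# identically in both directions (I = g * (V_pre - V_post)).
gj = brainstate.nn.SymmetryGapJunction(
    couples=pop,        # same population on both sides
    states='V',         # membrane-potential state, as in the docstring example
    conn=one_to_one,
    weight=0.5 * u.nS,  # scalar weight broadcast over all connections
)
currents = gj.update()  # adds the gap-junction currents to the population's inputs
```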
brainstate/random/_rand_state.py
CHANGED
@@ -384,7 +384,10 @@ class RandomState(State):
         loc = _check_py_seq(loc)
         scale = _check_py_seq(scale)
         if size is None:
-            size = lax.broadcast_shapes(
+            size = lax.broadcast_shapes(
+                jnp.shape(loc) if loc is not None else (),
+                jnp.shape(scale) if scale is not None else ()
+            )
         key = self.split_key() if key is None else _formalize_key(key)
         dtype = dtype or environ.dftype()
         r = _loc_scale(loc, scale, jr.logistic(key, shape=_size2shape(size), dtype=dtype))
@@ -399,7 +402,10 @@ class RandomState(State):
         loc = _check_py_seq(loc)
         scale = _check_py_seq(scale)
         if size is None:
-            size = lax.broadcast_shapes(
+            size = lax.broadcast_shapes(
+                jnp.shape(scale) if scale is not None else (),
+                jnp.shape(loc) if loc is not None else ()
+            )
         key = self.split_key() if key is None else _formalize_key(key)
         dtype = dtype or environ.dftype()
         r = _loc_scale(loc, scale, jr.normal(key, shape=_size2shape(size), dtype=dtype))
@@ -456,7 +462,7 @@ class RandomState(State):
               dtype: DTypeLike = None):
         shape = _check_py_seq(shape)
         if size is None:
-            size = jnp.shape(shape)
+            size = jnp.shape(shape) if shape is not None else ()
         key = self.split_key() if key is None else _formalize_key(key)
         dtype = dtype or environ.dftype()
         r = jr.gamma(key, a=shape, shape=_size2shape(size), dtype=dtype)
@@ -477,7 +483,7 @@ class RandomState(State):
           dtype: DTypeLike = None):
         df = _check_py_seq(df)
         if size is None:
-            size = jnp.shape(size)
+            size = jnp.shape(size) if size is not None else ()
         key = self.split_key() if key is None else _formalize_key(key)
         dtype = dtype or environ.dftype()
         r = jr.t(key, df=df, shape=_size2shape(size), dtype=dtype)
@@ -606,8 +612,8 @@ class RandomState(State):
 
         if size is None:
             size = jnp.broadcast_shapes(
-                jnp.shape(mean),
-                jnp.shape(sigma)
+                jnp.shape(mean) if mean is not None else (),
+                jnp.shape(sigma) if sigma is not None else ()
             )
         key = self.split_key() if key is None else _formalize_key(key)
         dtype = dtype or environ.dftype()
@@ -822,7 +828,7 @@ class RandomState(State):
         a = _check_py_seq(a)
         scale = _check_py_seq(scale)
         if size is None:
-            size = jnp.broadcast_shapes(jnp.shape(a), jnp.shape(scale))
+            size = jnp.broadcast_shapes(jnp.shape(a), jnp.shape(scale) if scale is not None else ())
         else:
             if jnp.size(a) > 1:
                 raise ValueError(f'"a" should be a scalar when "size" is provided. But we got {a}')
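The common thread in these hunks is a `None` guard when inferring the sampling shape: an omitted `loc`, `scale`, `mean`, `sigma`, `shape`, or `df` now contributes the scalar shape `()` to the broadcast instead of being passed to `jnp.shape`. A standalone sketch of the idiom (not brainstate's API, just the pattern):

```python
import jax.numpy as jnp
from jax import lax


def broadcast_size(loc=None, scale=None):
    # A None parameter contributes the scalar shape () to the broadcast.
    return lax.broadcast_shapes(
        jnp.shape(loc) if loc is not None else (),
        jnp.shape(scale) if scale is not None else (),
    )


print(broadcast_size())                                          # ()
print(broadcast_size(scale=jnp.ones((3, 2))))                    # (3, 2)
print(broadcast_size(loc=jnp.zeros(2), scale=jnp.ones((3, 2))))  # (3, 2)
```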
brainstate/surrogate.py
CHANGED
@@ -21,7 +21,7 @@ import jax.scipy as sci
 from jax.interpreters import batching, ad, mlir
 
 from brainstate._compatible_import import Primitive
-from brainstate.util
+from brainstate.util import PrettyObject
 
 __all__ = [
     'Surrogate',
brainstate/typing.py
CHANGED
@@ -257,7 +257,7 @@ f. A structure can end with a `...`, to denote that the PyTree must be a prefix
 cases, all named pieces must already have been seen and their structures bound.
 """ # noqa: E501
 
-Size = Union[int, Sequence[int]]
+Size = Union[int, Sequence[int], np.integer, Sequence[np.integer]]
 Axes = Union[int, Sequence[int]]
 SeedOrKey = Union[int, jax.Array, np.ndarray]
 Shape = Sequence[int]
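The widened `Size` alias matters because NumPy arithmetic and indexing routinely produce `np.integer` scalars rather than Python `int`s, which 0.1.3's `Union[int, Sequence[int]]` did not formally cover. A small, self-contained illustration (the variables here are purely for demonstration):

```python
import numpy as np

x = np.zeros((4, 5))
n = np.int64(2) * x.shape[0]   # an np.int64, not a Python int
print(type(n))                 # <class 'numpy.int64'>

# With Size = Union[int, Sequence[int], np.integer, Sequence[np.integer]],
# both of these now satisfy the annotation without an explicit int(...) cast:
size_a = n            # a NumPy integer scalar
size_b = (n, n + 1)   # a sequence of NumPy integers
```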
brainstate/util/__init__.py
CHANGED
@@ -14,20 +14,20 @@
 # ==============================================================================
 
 from . import filter
-from .
-from .
-from .
-from .
-from .
-from .
-from .
-from .
-from .
-from .
-from .
-from .
-from .
-from .
+from .error import *
+from .error import __all__ as _error_all
+from .others import *
+from .others import __all__ as _others_all
+from .pretty_pytree import *
+from .pretty_pytree import __all__ as _mapping_all
+from .pretty_repr import *
+from .pretty_repr import __all__ as _pretty_repr_all
+from .pretty_table import *
+from .pretty_table import __all__ as _table_all
+from .scaling import *
+from .scaling import __all__ as _mem_scale_all
+from .struct import *
+from .struct import __all__ as _struct_all
 
 __all__ = (
     ['filter']
brainstate/util/{_pretty_pytree.py → pretty_pytree.py}
RENAMED
@@ -21,8 +21,8 @@ from typing import TypeVar, Hashable, Union, Iterable, Any, Optional, Tuple, Dic
 import jax
 
 from brainstate.typing import Filter, PathParts
-from .
-from .
+from .pretty_repr import PrettyRepr, PrettyType, PrettyAttr, yield_unique_pretty_repr_items, pretty_repr_object
+from .struct import dataclass
 from .filter import to_predicate
 
 __all__ = [
|