brainstate 0.1.2__py2.py3-none-any.whl → 0.1.4__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66)
  1. brainstate/__init__.py +1 -1
  2. brainstate/_compatible_import.py +0 -15
  3. brainstate/compile/_jit.py +14 -5
  4. brainstate/compile/_make_jaxpr.py +78 -22
  5. brainstate/compile/_make_jaxpr_test.py +13 -2
  6. brainstate/graph/_graph_node.py +1 -1
  7. brainstate/graph/_graph_operation.py +4 -4
  8. brainstate/mixin.py +30 -14
  9. brainstate/nn/__init__.py +84 -17
  10. brainstate/nn/{_interaction/_conv.py → _conv.py} +1 -1
  11. brainstate/nn/{_dynamics/_state_delay.py → _delay.py} +19 -3
  12. brainstate/nn/{_elementwise/_dropout.py → _dropout.py} +6 -5
  13. brainstate/nn/{_dynamics/_dynamics_base.py → _dynamics.py} +137 -21
  14. brainstate/nn/{_elementwise/_elementwise.py → _elementwise.py} +1 -1
  15. brainstate/nn/{_interaction/_embedding.py → _embedding.py} +1 -1
  16. brainstate/nn/{_event/_fixedprob_mv.py → _fixedprob.py} +96 -25
  17. brainstate/nn/{_dyn_impl/_inputs.py → _inputs.py} +4 -5
  18. brainstate/nn/{_interaction/_linear.py → _linear.py} +2 -5
  19. brainstate/nn/{_event/_linear_mv.py → _linear_mv.py} +2 -2
  20. brainstate/nn/{_event/__init__.py → _ltp.py} +7 -5
  21. brainstate/nn/_module.py +5 -5
  22. brainstate/nn/{_dyn_impl/_dynamics_neuron.py → _neuron.py} +2 -2
  23. brainstate/nn/{_interaction/_normalizations.py → _normalizations.py} +1 -1
  24. brainstate/nn/{_interaction/_poolings.py → _poolings.py} +1 -1
  25. brainstate/nn/{_interaction/_poolings_test.py → _poolings_test.py} +1 -1
  26. brainstate/nn/_projection.py +486 -0
  27. brainstate/nn/{_dyn_impl/_rate_rnns.py → _rate_rnns.py} +2 -2
  28. brainstate/nn/{_dyn_impl/_readout.py → _readout.py} +3 -3
  29. brainstate/nn/_stp.py +236 -0
  30. brainstate/nn/{_dyn_impl/_dynamics_synapse.py → _synapse.py} +19 -212
  31. brainstate/nn/_synaptic_projection.py +423 -0
  32. brainstate/nn/{_dynamics/_synouts.py → _synouts.py} +4 -1
  33. brainstate/surrogate.py +1 -1
  34. brainstate/typing.py +1 -1
  35. brainstate/util/__init__.py +14 -14
  36. brainstate/util/{_pretty_pytree.py → pretty_pytree.py} +2 -2
  37. {brainstate-0.1.2.dist-info → brainstate-0.1.4.dist-info}/METADATA +1 -1
  38. {brainstate-0.1.2.dist-info → brainstate-0.1.4.dist-info}/RECORD +61 -63
  39. brainstate/nn/_dyn_impl/__init__.py +0 -42
  40. brainstate/nn/_dynamics/__init__.py +0 -37
  41. brainstate/nn/_dynamics/_projection_base.py +0 -362
  42. brainstate/nn/_elementwise/__init__.py +0 -22
  43. brainstate/nn/_interaction/__init__.py +0 -41
  44. /brainstate/nn/{_interaction/_conv_test.py → _conv_test.py} +0 -0
  45. /brainstate/nn/{_elementwise/_dropout_test.py → _dropout_test.py} +0 -0
  46. /brainstate/nn/{_dynamics/_dynamics_base_test.py → _dynamics_test.py} +0 -0
  47. /brainstate/nn/{_elementwise/_elementwise_test.py → _elementwise_test.py} +0 -0
  48. /brainstate/nn/{_event/_fixedprob_mv_test.py → _fixedprob_test.py} +0 -0
  49. /brainstate/nn/{_event/_linear_mv_test.py → _linear_mv_test.py} +0 -0
  50. /brainstate/nn/{_interaction/_linear_test.py → _linear_test.py} +0 -0
  51. /brainstate/nn/{_dyn_impl/_dynamics_neuron_test.py → _neuron_test.py} +0 -0
  52. /brainstate/nn/{_interaction/_normalizations_test.py → _normalizations_test.py} +0 -0
  53. /brainstate/nn/{_dyn_impl/_rate_rnns_test.py → _rate_rnns_test.py} +0 -0
  54. /brainstate/nn/{_dyn_impl/_readout_test.py → _readout_test.py} +0 -0
  55. /brainstate/nn/{_dyn_impl/_dynamics_synapse_test.py → _synapse_test.py} +0 -0
  56. /brainstate/nn/{_dynamics/_synouts_test.py → _synouts_test.py} +0 -0
  57. /brainstate/util/{_caller.py → caller.py} +0 -0
  58. /brainstate/util/{_error.py → error.py} +0 -0
  59. /brainstate/util/{_others.py → others.py} +0 -0
  60. /brainstate/util/{_pretty_repr.py → pretty_repr.py} +0 -0
  61. /brainstate/util/{_pretty_table.py → pretty_table.py} +0 -0
  62. /brainstate/util/{_scaling.py → scaling.py} +0 -0
  63. /brainstate/util/{_struct.py → struct.py} +0 -0
  64. {brainstate-0.1.2.dist-info → brainstate-0.1.4.dist-info}/LICENSE +0 -0
  65. {brainstate-0.1.2.dist-info → brainstate-0.1.4.dist-info}/WHEEL +0 -0
  66. {brainstate-0.1.2.dist-info → brainstate-0.1.4.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,423 @@
1
+ # Copyright 2025 BDP Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ # -*- coding: utf-8 -*-
16
+
17
+
18
+ from typing import Callable, Union, Tuple
19
+
20
+ import brainunit as u
21
+
22
+ from brainstate import init
23
+ from brainstate._state import ParamState
24
+ from brainstate.typing import ArrayLike
25
+ from ._dynamics import Dynamics, Projection
26
+
27
# Public names exported by ``from ... import *``: only the two projection classes.
# The helper functions and the ``align_*_ltp`` placeholder classes defined in this
# module are not listed here.
__all__ = [
    'SymmetryGapJunction',
    'AsymmetryGapJunction',
]
31
+
32
+
33
class align_pre_ltp(Projection):
    """
    Placeholder projection class with no behavior beyond ``Projection``.

    NOTE(review): the name suggests a pre-aligned long-term-plasticity projection;
    nothing is implemented yet and the class is not exported in ``__all__``.
    """
35
+
36
+
37
class align_post_ltp(Projection):
    """
    Placeholder projection class with no behavior beyond ``Projection``.

    NOTE(review): the name suggests a post-aligned long-term-plasticity projection;
    nothing is implemented yet and the class is not exported in ``__all__``.
    """
39
+
40
+
41
def get_gap_junction_post_key(i: int):
    """Return the input key ``'gap_junction_post_<i>'`` for a post-synaptic group."""
    return 'gap_junction_post_{}'.format(i)
43
+
44
+
45
def get_gap_junction_pre_key(i: int):
    """Return the input key ``'gap_junction_pre_<i>'`` for a pre-synaptic group."""
    return 'gap_junction_pre_{}'.format(i)
47
+
48
+
49
class SymmetryGapJunction(Projection):
    """
    Symmetric electrical coupling (gap junction) between neuron populations.

    Connection ``k`` couples neuron ``pre_ids[k]`` of the pre-synaptic group with
    neuron ``post_ids[k]`` of the post-synaptic group through a single conductance.
    On every update the current ``weight * (V_pre - V_post)`` is injected into the
    post-synaptic neuron, and the same current with opposite sign into the
    pre-synaptic neuron. Both directions therefore share one conductance.

    Parameters
    ----------
    couples : Union[Tuple[Dynamics, Dynamics], Dynamics]
        A single ``Dynamics`` object, which couples a population with itself, or a
        ``(pre, post)`` tuple of two ``Dynamics`` objects.
    states : Union[str, Tuple[str, str]]
        Name of the state attribute read on both groups, or a ``(pre_state, post_state)``
        tuple of names. The attribute must expose ``.value`` (typically the membrane
        potential).
    conn : Callable
        Called as ``conn(pre.out_size, post.out_size)``. Must return the
        ``(pre_ids, post_ids)`` index arrays defining the connections.
    weight : Union[Callable, ArrayLike]
        Conductance of the connections, built with ``init.param(weight, (len(pre_ids),))``.
        The same value is used in both directions.
    param_type : type, optional
        State class that wraps the weights. Defaults to ``ParamState``.

    Raises
    ------
    AssertionError
        If the state names are not strings or the populations are not ``Dynamics``.

    Notes
    -----
    Because equal and opposite currents are applied, the coupling is balanced
    between the two connected neurons.

    See Also
    --------
    AsymmetryGapJunction : For gap junctions with different conductances in each direction.
    """

    def __init__(
        self,
        couples: Union[Tuple[Dynamics, Dynamics], Dynamics],
        states: Union[str, Tuple[str, str]],
        conn: Callable,
        weight: Union[Callable, ArrayLike],
        param_type: type = ParamState
    ):
        super().__init__()

        # A bare name / a bare population means "the same one on both sides".
        pre_state, post_state = (states, states) if isinstance(states, str) else states
        pre, post = (couples, couples) if isinstance(couples, Dynamics) else couples

        assert isinstance(pre_state, str), "pre_state must be a string representing the pre-synaptic state"
        assert isinstance(post_state, str), "post_state must be a string representing the post-synaptic state"
        assert isinstance(pre, Dynamics), "pre must be a Dynamics object"
        assert isinstance(post, Dynamics), "post must be a Dynamics object"

        self.pre_state, self.post_state = pre_state, post_state
        self.pre, self.post = pre, post

        # Connectivity is fixed at construction time; there is one weight per connection.
        self.pre_ids, self.post_ids = conn(pre.out_size, post.out_size)
        n_connections = len(self.pre_ids)
        self.weight = param_type(init.param(weight, (n_connections,)))

    def update(self, *args, **kwargs):
        """
        Read the coupled states and register the gap-junction currents on both groups.

        Returns
        -------
        ArrayLike
            The value returned by ``symmetry_gap_junction_projection``.

        Raises
        ------
        ValueError
            If either group lacks the configured state attribute.
        """
        # The states may be created after construction, so they are looked up lazily.
        if not hasattr(self.pre, self.pre_state):
            raise ValueError(f"pre_state {self.pre_state} not found in pre-synaptic neuron group")
        if not hasattr(self.post, self.post_state):
            raise ValueError(f"post_state {self.post_state} not found in post-synaptic neuron group")
        pre_value = getattr(self.pre, self.pre_state).value
        post_value = getattr(self.post, self.post_state).value

        return symmetry_gap_junction_projection(
            pre=self.pre,
            pre_value=pre_value,
            post=self.post,
            post_value=post_value,
            pre_ids=self.pre_ids,
            post_ids=self.post_ids,
            weight=self.weight.value,
        )
133
+
134
+
135
def symmetry_gap_junction_projection(
    pre: Dynamics,
    pre_value: ArrayLike,
    post: Dynamics,
    post_value: ArrayLike,
    pre_ids: ArrayLike,
    post_ids: ArrayLike,
    weight: ArrayLike,
):
    """
    Calculate symmetrical electrical coupling through gap junctions between neurons.

    This function implements bidirectional gap junction coupling where the same weight is
    applied in both directions. It computes the electrical current flowing between
    connected neurons from their potential differences, and registers the resulting input
    currents on both the pre-synaptic and the post-synaptic neuron groups.

    Parameters
    ----------
    pre : Dynamics
        The pre-synaptic neuron group dynamics object.
    pre_value : ArrayLike
        State values (typically membrane potentials) of the pre-synaptic neurons.
    post : Dynamics
        The post-synaptic neuron group dynamics object.
    post_value : ArrayLike
        State values (typically membrane potentials) of the post-synaptic neurons.
    pre_ids : ArrayLike
        Indices of pre-synaptic neurons that form gap junctions.
    post_ids : ArrayLike
        Indices of post-synaptic neurons that form gap junctions,
        where each pre_ids[i] is connected to post_ids[i].
    weight : ArrayLike
        Conductance weights for the gap junctions. Can be a scalar (same weight for all
        connections) or an array with length matching pre_ids.

    Returns
    -------
    ArrayLike
        The currents scattered onto the pre-synaptic group *before* the sign flip, i.e.
        ``weight * (pre_value[pre_ids] - post_value[post_ids])`` accumulated at ``pre_ids``.
        The input actually registered on ``pre`` is the negation of this array.

    Notes
    -----
    The electrical coupling is implemented as I = g(V_pre - V_post), where:
    - I is the current flowing from pre to post neuron
    - g is the gap junction conductance (weight)
    - V_pre and V_post are the membrane potentials of connected neurons

    The post-synaptic neuron receives +I and the pre-synaptic neuron receives -I,
    so the two connected neurons get equal and opposite currents.

    Raises
    ------
    AssertionError
        If weight dimensionality is incorrect or pre_ids and post_ids have different lengths.
    """
    assert u.math.ndim(weight) == 0 or weight.shape[0] == len(pre_ids), \
        "weight must be a scalar or have the same length as pre_ids"
    assert len(pre_ids) == len(post_ids), "pre_ids and post_ids must have the same length"

    # I = g * (V_pre - V_post) for every connection.
    current = (pre_value[pre_ids] - post_value[post_ids]) * weight
    current_unit = u.get_unit(current)

    # Post-synaptic side receives +I. Scatter-add lets several connections onto the
    # same neuron accumulate.
    post_inputs = u.math.zeros(post.out_size, unit=current_unit)
    post_inputs = post_inputs.at[post_ids].add(current)
    # The key index is the number of inputs already registered on the group.
    post_key = get_gap_junction_post_key(0 if post.current_inputs is None else len(post.current_inputs))
    post.add_current_input(post_key, post_inputs)

    # Pre-synaptic side receives -I (opposite polarity).
    pre_inputs = u.math.zeros(pre.out_size, unit=current_unit)
    pre_inputs = pre_inputs.at[pre_ids].add(current)
    pre_key = get_gap_junction_pre_key(0 if pre.current_inputs is None else len(pre.current_inputs))
    pre.add_current_input(pre_key, -pre_inputs)
    return pre_inputs
218
+
219
+
220
class AsymmetryGapJunction(Projection):
    """
    Implements an asymmetric electrical coupling (gap junction) between neuron populations.

    This class represents electrical synapses where the conductance in one direction can differ
    from the conductance in the opposite direction. Unlike chemical synapses, gap junctions
    allow bidirectional flow of electrical current directly between neurons, with the current
    magnitude proportional to the voltage difference between connected neurons.

    Parameters
    ----------
    pre : Dynamics
        The pre-synaptic neuron group dynamics object.
    pre_state : str
        Name of the state variable in pre-synaptic neurons (typically membrane potential).
    post : Dynamics
        The post-synaptic neuron group dynamics object.
    post_state : str
        Name of the state variable in post-synaptic neurons (typically membrane potential).
    conn : Callable
        Connection function, called as ``conn(pre.out_size, post.out_size)``. Must return
        the ``(pre_ids, post_ids)`` index arrays defining the connections.
    weight : Union[Callable, ArrayLike]
        Conductance weights, initialized to shape ``(len(pre_ids), 2)``. Column 0
        (``pre_weight``) scales the current delivered to the post-synaptic neuron.
        Column 1 (``post_weight``) scales the current delivered to the pre-synaptic
        neuron. The two directions may therefore have different conductances.
    param_type : type, optional
        The parameter state type to use for weights, defaults to ParamState.

    Raises
    ------
    AssertionError
        If ``pre_state`` or ``post_state`` is not a string.

    Examples
    --------
    >>> import brainstate
    >>> import brainunit as u
    >>> import numpy as np
    >>>
    >>> # Create two neuron populations
    >>> n_neurons = 100
    >>> pre_pop = brainstate.nn.LIF(n_neurons, V_rest=-70*u.mV, V_threshold=-50*u.mV)
    >>> post_pop = brainstate.nn.LIF(n_neurons, V_rest=-70*u.mV, V_threshold=-50*u.mV)
    >>> pre_pop.init_state()
    >>> post_pop.init_state()
    >>>
    >>> # Connect neuron i of ``pre_pop`` to neuron i of ``post_pop``. ``conn`` is called
    >>> # as ``conn(pre.out_size, post.out_size)`` and returns ``(pre_ids, post_ids)``.
    >>> def one_to_one(pre_size, post_size):
    ...     ids = np.arange(n_neurons)
    ...     return ids, ids
    >>>
    >>> # Column 0 scales the current into post, column 1 the current into pre:
    >>> # the pre->post direction gets twice the conductance here.
    >>> weights = np.tile([2.0, 1.0], (n_neurons, 1)) * u.nS
    >>>
    >>> gap_junction = brainstate.nn.AsymmetryGapJunction(
    ...     pre=pre_pop,
    ...     pre_state='V',
    ...     post=post_pop,
    ...     post_state='V',
    ...     conn=one_to_one,
    ...     weight=weights
    ... )

    Notes
    -----
    The asymmetric gap junction allows for different conductances in each direction between
    the same pair of neurons. This can model rectifying electrical synapses that preferentially
    allow current to flow in one direction.

    See Also
    --------
    SymmetryGapJunction : For gap junctions with identical conductance in both directions.
    """

    def __init__(
        self,
        pre: Dynamics,
        pre_state: str,
        post: Dynamics,
        post_state: str,
        conn: Callable,
        weight: Union[Callable, ArrayLike],
        param_type: type = ParamState
    ):
        super().__init__()

        assert isinstance(pre_state, str), "pre_state must be a string representing the pre-synaptic state"
        assert isinstance(post_state, str), "post_state must be a string representing the post-synaptic state"
        self.pre_state = pre_state
        self.post_state = post_state
        self.pre = pre
        self.post = post
        # Connectivity is fixed at construction; one [pre_weight, post_weight] row per connection.
        self.pre_ids, self.post_ids = conn(pre.out_size, post.out_size)
        self.weight = param_type(init.param(weight, (len(self.pre_ids), 2)))

    def update(self, *args, **kwargs):
        """
        Read the coupled states and register the gap-junction currents on both groups.

        Returns
        -------
        ArrayLike
            The value returned by ``asymmetry_gap_junction_projection``.

        Raises
        ------
        ValueError
            If either group lacks the configured state attribute.
        """
        # The states may be created after construction, so they are looked up lazily.
        if not hasattr(self.pre, self.pre_state):
            raise ValueError(f"pre_state {self.pre_state} not found in pre-synaptic neuron group")
        if not hasattr(self.post, self.post_state):
            raise ValueError(f"post_state {self.post_state} not found in post-synaptic neuron group")
        pre = getattr(self.pre, self.pre_state).value
        post = getattr(self.post, self.post_state).value

        return asymmetry_gap_junction_projection(
            pre=self.pre,
            pre_value=pre,
            post=self.post,
            post_value=post,
            pre_ids=self.pre_ids,
            post_ids=self.post_ids,
            weight=self.weight.value,
        )
323
+
324
+
325
def asymmetry_gap_junction_projection(
    pre: Dynamics,
    pre_value: ArrayLike,
    post: Dynamics,
    post_value: ArrayLike,
    pre_ids: ArrayLike,
    post_ids: ArrayLike,
    weight: ArrayLike,
):
    """
    Calculate asymmetrical electrical coupling through gap junctions between neurons.

    This function implements bidirectional gap junction coupling where different weights
    can be applied in each direction. It computes the electrical current flowing between
    connected neurons from their potential differences, and registers the resulting input
    currents on both the pre-synaptic and the post-synaptic neuron groups.

    Parameters
    ----------
    pre : Dynamics
        The pre-synaptic neuron group dynamics object.
    pre_value : ArrayLike
        State values (typically membrane potentials) of the pre-synaptic neurons.
    post : Dynamics
        The post-synaptic neuron group dynamics object.
    post_value : ArrayLike
        State values (typically membrane potentials) of the post-synaptic neurons.
    pre_ids : ArrayLike
        Indices of pre-synaptic neurons that form gap junctions.
    post_ids : ArrayLike
        Indices of post-synaptic neurons that form gap junctions,
        where each pre_ids[i] is connected to post_ids[i].
    weight : ArrayLike
        Conductance weights. The last axis must have length 2 and holds
        ``[pre_weight, post_weight]``. Either a 1-D array ``[pre_weight, post_weight]``
        shared by all connections, or a 2-D array of shape ``(len(pre_ids), 2)`` with
        connection-specific weights.

    Returns
    -------
    ArrayLike
        The currents scattered onto the pre-synaptic group *before* the sign flip, i.e.
        ``post_weight * (V_pre - V_post)`` accumulated at ``pre_ids``. The input actually
        registered on ``pre`` is the negation of this array.

    Notes
    -----
    With ``dV = pre_value[pre_ids] - post_value[post_ids]``:

    - each post-synaptic neuron receives ``pre_weight * dV``;
    - each pre-synaptic neuron receives ``-post_weight * dV``.

    ``pre_weight`` and ``post_weight`` may differ, which makes the coupling asymmetric.
    With equal weights this reduces to the symmetric gap junction.

    Raises
    ------
    ValueError
        If ``weight`` is not a 1-D or 2-D array. This is checked first, so scalars are
        rejected with this error rather than an IndexError.
    AssertionError
        If the last axis of ``weight`` is not 2, if ``pre_ids`` and ``post_ids`` have
        different lengths, or if a 2-D ``weight`` does not have one row per connection.
    """
    # Validate the rank before touching ``weight.shape[-1]``: a 0-d weight has an empty
    # shape and would otherwise fail with IndexError instead of the documented ValueError.
    weight_ndim = u.math.ndim(weight)
    if weight_ndim not in (1, 2):
        raise ValueError("weight must be a 1D or 2D array for asymmetry gap junctions")
    assert weight.shape[-1] == 2, 'weight must be a 2-element array for asymmetry gap junctions'
    assert len(pre_ids) == len(post_ids), "pre_ids and post_ids must have the same length"

    if weight_ndim == 1:
        # A single [pre_weight, post_weight] pair shared by every connection.
        pre_weight, post_weight = weight[0], weight[1]
    else:
        # Column 0 holds pre_weight and column 1 holds post_weight, one row per connection.
        pre_weight, post_weight = weight[:, 0], weight[:, 1]
        assert pre_weight.shape[0] == len(pre_ids), "pre_weight must have the same length as pre_ids"
        assert post_weight.shape[0] == len(post_ids), "post_weight must have the same length as post_ids"

    # Voltage difference across every connection, scaled separately for each direction.
    diff = pre_value[pre_ids] - post_value[post_ids]
    pre2post_current = diff * pre_weight
    post2pre_current = diff * post_weight

    # Post-synaptic side. Scatter-add lets several connections onto one neuron accumulate.
    post_inputs = u.math.zeros(post.out_size, unit=u.get_unit(pre2post_current))
    post_inputs = post_inputs.at[post_ids].add(pre2post_current)
    # The key index is the number of inputs already registered on the group.
    key = get_gap_junction_post_key(0 if post.current_inputs is None else len(post.current_inputs))
    post.add_current_input(key, post_inputs)

    # Pre-synaptic side: the same scatter-add, registered with opposite polarity.
    pre_inputs = u.math.zeros(pre.out_size, unit=u.get_unit(post2pre_current))
    pre_inputs = pre_inputs.at[pre_ids].add(post2pre_current)
    key = get_gap_junction_pre_key(0 if pre.current_inputs is None else len(pre.current_inputs))
    pre.add_current_input(key, -pre_inputs)
    return pre_inputs
@@ -19,8 +19,8 @@ import brainunit as u
19
19
  import jax.numpy as jnp
20
20
 
21
21
  from brainstate.mixin import BindCondData
22
- from brainstate.nn._module import Module
23
22
  from brainstate.typing import ArrayLike
23
+ from ._module import Module
24
24
 
25
25
  __all__ = [
26
26
  'SynOut', 'COBA', 'CUBA', 'MgBlock',
@@ -47,6 +47,9 @@ class SynOut(Module, BindCondData):
47
47
  ret = self.update(self._conductance, *args, **kwargs)
48
48
  return ret
49
49
 
50
def update(self, conductance, potential):
    """
    Hook that computes the synaptic output; subclasses must override it.

    The caller in ``SynOut`` invokes it as
    ``self.update(self._conductance, *args, **kwargs)``.

    Raises
    ------
    NotImplementedError
        Always, on the base class.
    """
    raise NotImplementedError()
52
+
50
53
 
51
54
  class COBA(SynOut):
52
55
  r"""
brainstate/surrogate.py CHANGED
@@ -21,7 +21,7 @@ import jax.scipy as sci
21
21
  from jax.interpreters import batching, ad, mlir
22
22
 
23
23
  from brainstate._compatible_import import Primitive
24
- from brainstate.util._pretty_pytree import PrettyObject
24
+ from brainstate.util import PrettyObject
25
25
 
26
26
  __all__ = [
27
27
  'Surrogate',
brainstate/typing.py CHANGED
@@ -257,7 +257,7 @@ f. A structure can end with a `...`, to denote that the PyTree must be a prefix
257
257
  cases, all named pieces must already have been seen and their structures bound.
258
258
  """ # noqa: E501
259
259
 
260
- Size = Union[int, Sequence[int]]
260
# A size specification: a single integer, or a sequence of integers.
# NumPy integer scalars are accepted as well as Python ints.
Size = Union[int, Sequence[int], np.integer, Sequence[np.integer]]
# One axis index or a sequence of axis indices.
Axes = Union[int, Sequence[int]]
# An integer seed, a JAX array or a NumPy array.
SeedOrKey = Union[int, jax.Array, np.ndarray]
# A sequence of integers.
Shape = Sequence[int]
@@ -14,20 +14,20 @@
14
14
  # ==============================================================================
15
15
 
16
16
  from . import filter
17
- from ._error import *
18
- from ._error import __all__ as _error_all
19
- from ._others import *
20
- from ._others import __all__ as _others_all
21
- from ._pretty_pytree import *
22
- from ._pretty_pytree import __all__ as _mapping_all
23
- from ._pretty_repr import *
24
- from ._pretty_repr import __all__ as _pretty_repr_all
25
- from ._pretty_table import *
26
- from ._pretty_table import __all__ as _table_all
27
- from ._scaling import *
28
- from ._scaling import __all__ as _mem_scale_all
29
- from ._struct import *
30
- from ._struct import __all__ as _struct_all
17
+ from .error import *
18
+ from .error import __all__ as _error_all
19
+ from .others import *
20
+ from .others import __all__ as _others_all
21
+ from .pretty_pytree import *
22
+ from .pretty_pytree import __all__ as _mapping_all
23
+ from .pretty_repr import *
24
+ from .pretty_repr import __all__ as _pretty_repr_all
25
+ from .pretty_table import *
26
+ from .pretty_table import __all__ as _table_all
27
+ from .scaling import *
28
+ from .scaling import __all__ as _mem_scale_all
29
+ from .struct import *
30
+ from .struct import __all__ as _struct_all
31
31
 
32
32
  __all__ = (
33
33
  ['filter']
@@ -21,8 +21,8 @@ from typing import TypeVar, Hashable, Union, Iterable, Any, Optional, Tuple, Dic
21
21
  import jax
22
22
 
23
23
  from brainstate.typing import Filter, PathParts
24
- from ._pretty_repr import PrettyRepr, PrettyType, PrettyAttr, yield_unique_pretty_repr_items, pretty_repr_object
25
- from ._struct import dataclass
24
+ from .pretty_repr import PrettyRepr, PrettyType, PrettyAttr, yield_unique_pretty_repr_items, pretty_repr_object
25
+ from .struct import dataclass
26
26
  from .filter import to_predicate
27
27
 
28
28
  __all__ = [
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: brainstate
3
- Version: 0.1.2
3
+ Version: 0.1.4
4
4
  Summary: A ``State``-based Transformation System for Program Compilation and Augmentation.
5
5
  Home-page: https://github.com/chaobrain/brainstate
6
6
  Author: BrainState Developers