brainstate 0.0.2.post20241010__py2.py3-none-any.whl → 0.1.0.post20241122__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (184) hide show
  1. benchmark/COBA_2005.py +125 -0
  2. benchmark/CUBA_2005.py +149 -0
  3. brainstate/__init__.py +31 -11
  4. brainstate/_state.py +760 -316
  5. brainstate/_state_test.py +41 -12
  6. brainstate/_utils.py +31 -4
  7. brainstate/augment/__init__.py +40 -0
  8. brainstate/augment/_autograd.py +611 -0
  9. brainstate/augment/_autograd_test.py +1193 -0
  10. brainstate/augment/_eval_shape.py +102 -0
  11. brainstate/augment/_eval_shape_test.py +40 -0
  12. brainstate/augment/_mapping.py +525 -0
  13. brainstate/augment/_mapping_test.py +210 -0
  14. brainstate/augment/_random.py +99 -0
  15. brainstate/{transform → compile}/__init__.py +25 -13
  16. brainstate/compile/_ad_checkpoint.py +204 -0
  17. brainstate/compile/_ad_checkpoint_test.py +51 -0
  18. brainstate/compile/_conditions.py +259 -0
  19. brainstate/compile/_conditions_test.py +221 -0
  20. brainstate/compile/_error_if.py +94 -0
  21. brainstate/compile/_error_if_test.py +54 -0
  22. brainstate/compile/_jit.py +314 -0
  23. brainstate/compile/_jit_test.py +143 -0
  24. brainstate/compile/_loop_collect_return.py +516 -0
  25. brainstate/compile/_loop_collect_return_test.py +59 -0
  26. brainstate/compile/_loop_no_collection.py +185 -0
  27. brainstate/compile/_loop_no_collection_test.py +51 -0
  28. brainstate/compile/_make_jaxpr.py +756 -0
  29. brainstate/compile/_make_jaxpr_test.py +134 -0
  30. brainstate/compile/_progress_bar.py +111 -0
  31. brainstate/compile/_unvmap.py +159 -0
  32. brainstate/compile/_util.py +147 -0
  33. brainstate/environ.py +408 -381
  34. brainstate/environ_test.py +34 -32
  35. brainstate/event/__init__.py +27 -0
  36. brainstate/event/_csr.py +316 -0
  37. brainstate/event/_csr_benchmark.py +14 -0
  38. brainstate/event/_csr_test.py +118 -0
  39. brainstate/event/_fixed_probability.py +708 -0
  40. brainstate/event/_fixed_probability_benchmark.py +128 -0
  41. brainstate/event/_fixed_probability_test.py +131 -0
  42. brainstate/event/_linear.py +359 -0
  43. brainstate/event/_linear_benckmark.py +82 -0
  44. brainstate/event/_linear_test.py +117 -0
  45. brainstate/{nn/event → event}/_misc.py +7 -7
  46. brainstate/event/_xla_custom_op.py +312 -0
  47. brainstate/event/_xla_custom_op_test.py +55 -0
  48. brainstate/functional/_activations.py +521 -511
  49. brainstate/functional/_activations_test.py +300 -300
  50. brainstate/functional/_normalization.py +43 -43
  51. brainstate/functional/_others.py +15 -15
  52. brainstate/functional/_spikes.py +49 -49
  53. brainstate/graph/__init__.py +33 -0
  54. brainstate/graph/_graph_context.py +443 -0
  55. brainstate/graph/_graph_context_test.py +65 -0
  56. brainstate/graph/_graph_convert.py +246 -0
  57. brainstate/graph/_graph_node.py +300 -0
  58. brainstate/graph/_graph_node_test.py +75 -0
  59. brainstate/graph/_graph_operation.py +1746 -0
  60. brainstate/graph/_graph_operation_test.py +724 -0
  61. brainstate/init/_base.py +28 -10
  62. brainstate/init/_generic.py +175 -172
  63. brainstate/init/_random_inits.py +470 -415
  64. brainstate/init/_random_inits_test.py +150 -0
  65. brainstate/init/_regular_inits.py +66 -69
  66. brainstate/init/_regular_inits_test.py +51 -0
  67. brainstate/mixin.py +236 -244
  68. brainstate/mixin_test.py +44 -46
  69. brainstate/nn/__init__.py +26 -51
  70. brainstate/nn/_collective_ops.py +199 -0
  71. brainstate/nn/_dyn_impl/__init__.py +46 -0
  72. brainstate/nn/_dyn_impl/_dynamics_neuron.py +290 -0
  73. brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +162 -0
  74. brainstate/nn/_dyn_impl/_dynamics_synapse.py +315 -0
  75. brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +132 -0
  76. brainstate/nn/_dyn_impl/_inputs.py +154 -0
  77. brainstate/nn/{event/__init__.py → _dyn_impl/_projection_alignpost.py} +8 -8
  78. brainstate/nn/_dyn_impl/_rate_rnns.py +400 -0
  79. brainstate/nn/_dyn_impl/_rate_rnns_test.py +64 -0
  80. brainstate/nn/_dyn_impl/_readout.py +128 -0
  81. brainstate/nn/_dyn_impl/_readout_test.py +54 -0
  82. brainstate/nn/_dynamics/__init__.py +37 -0
  83. brainstate/nn/_dynamics/_dynamics_base.py +631 -0
  84. brainstate/nn/_dynamics/_dynamics_base_test.py +79 -0
  85. brainstate/nn/_dynamics/_projection_base.py +346 -0
  86. brainstate/nn/_dynamics/_state_delay.py +453 -0
  87. brainstate/nn/_dynamics/_synouts.py +161 -0
  88. brainstate/nn/_dynamics/_synouts_test.py +58 -0
  89. brainstate/nn/_elementwise/__init__.py +22 -0
  90. brainstate/nn/_elementwise/_dropout.py +418 -0
  91. brainstate/nn/_elementwise/_dropout_test.py +100 -0
  92. brainstate/nn/_elementwise/_elementwise.py +1122 -0
  93. brainstate/nn/_elementwise/_elementwise_test.py +171 -0
  94. brainstate/nn/_exp_euler.py +97 -0
  95. brainstate/nn/_exp_euler_test.py +36 -0
  96. brainstate/nn/_interaction/__init__.py +41 -0
  97. brainstate/nn/_interaction/_conv.py +499 -0
  98. brainstate/nn/_interaction/_conv_test.py +239 -0
  99. brainstate/nn/_interaction/_embedding.py +59 -0
  100. brainstate/nn/_interaction/_linear.py +582 -0
  101. brainstate/nn/_interaction/_linear_test.py +42 -0
  102. brainstate/nn/_interaction/_normalizations.py +388 -0
  103. brainstate/nn/_interaction/_normalizations_test.py +75 -0
  104. brainstate/nn/_interaction/_poolings.py +1179 -0
  105. brainstate/nn/_interaction/_poolings_test.py +219 -0
  106. brainstate/nn/_module.py +328 -0
  107. brainstate/nn/_module_test.py +211 -0
  108. brainstate/nn/metrics.py +309 -309
  109. brainstate/optim/__init__.py +14 -2
  110. brainstate/optim/_base.py +66 -0
  111. brainstate/optim/_lr_scheduler.py +363 -400
  112. brainstate/optim/_lr_scheduler_test.py +25 -24
  113. brainstate/optim/_optax_optimizer.py +121 -176
  114. brainstate/optim/_optax_optimizer_test.py +41 -1
  115. brainstate/optim/_sgd_optimizer.py +950 -1025
  116. brainstate/random/_rand_funs.py +3269 -3268
  117. brainstate/random/_rand_funs_test.py +568 -0
  118. brainstate/random/_rand_seed.py +149 -117
  119. brainstate/random/_rand_seed_test.py +50 -0
  120. brainstate/random/_rand_state.py +1356 -1321
  121. brainstate/random/_random_for_unit.py +13 -13
  122. brainstate/surrogate.py +1262 -1243
  123. brainstate/{nn/_projection/_utils.py → transform.py} +1 -2
  124. brainstate/typing.py +157 -130
  125. brainstate/util/__init__.py +52 -0
  126. brainstate/util/_caller.py +100 -0
  127. brainstate/util/_dict.py +734 -0
  128. brainstate/util/_dict_test.py +160 -0
  129. brainstate/{nn/_projection/__init__.py → util/_error.py} +9 -13
  130. brainstate/util/_filter.py +178 -0
  131. brainstate/util/_others.py +497 -0
  132. brainstate/util/_pretty_repr.py +208 -0
  133. brainstate/util/_scaling.py +260 -0
  134. brainstate/util/_struct.py +524 -0
  135. brainstate/util/_tracers.py +75 -0
  136. brainstate/{_visualization.py → util/_visualization.py} +16 -16
  137. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/METADATA +11 -11
  138. brainstate-0.1.0.post20241122.dist-info/RECORD +144 -0
  139. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/top_level.txt +1 -0
  140. brainstate/_module.py +0 -1637
  141. brainstate/_module_test.py +0 -207
  142. brainstate/nn/_base.py +0 -251
  143. brainstate/nn/_connections.py +0 -686
  144. brainstate/nn/_dynamics.py +0 -426
  145. brainstate/nn/_elementwise.py +0 -1438
  146. brainstate/nn/_embedding.py +0 -66
  147. brainstate/nn/_misc.py +0 -133
  148. brainstate/nn/_normalizations.py +0 -389
  149. brainstate/nn/_others.py +0 -101
  150. brainstate/nn/_poolings.py +0 -1229
  151. brainstate/nn/_poolings_test.py +0 -231
  152. brainstate/nn/_projection/_align_post.py +0 -546
  153. brainstate/nn/_projection/_align_pre.py +0 -599
  154. brainstate/nn/_projection/_delta.py +0 -241
  155. brainstate/nn/_projection/_vanilla.py +0 -101
  156. brainstate/nn/_rate_rnns.py +0 -410
  157. brainstate/nn/_readout.py +0 -136
  158. brainstate/nn/_synouts.py +0 -166
  159. brainstate/nn/event/csr.py +0 -312
  160. brainstate/nn/event/csr_test.py +0 -118
  161. brainstate/nn/event/fixed_probability.py +0 -276
  162. brainstate/nn/event/fixed_probability_test.py +0 -127
  163. brainstate/nn/event/linear.py +0 -220
  164. brainstate/nn/event/linear_test.py +0 -111
  165. brainstate/random/random_test.py +0 -593
  166. brainstate/transform/_autograd.py +0 -585
  167. brainstate/transform/_autograd_test.py +0 -1181
  168. brainstate/transform/_conditions.py +0 -334
  169. brainstate/transform/_conditions_test.py +0 -220
  170. brainstate/transform/_error_if.py +0 -94
  171. brainstate/transform/_error_if_test.py +0 -55
  172. brainstate/transform/_jit.py +0 -265
  173. brainstate/transform/_jit_test.py +0 -118
  174. brainstate/transform/_loop_collect_return.py +0 -502
  175. brainstate/transform/_loop_no_collection.py +0 -170
  176. brainstate/transform/_make_jaxpr.py +0 -739
  177. brainstate/transform/_make_jaxpr_test.py +0 -131
  178. brainstate/transform/_mapping.py +0 -109
  179. brainstate/transform/_progress_bar.py +0 -111
  180. brainstate/transform/_unvmap.py +0 -143
  181. brainstate/util.py +0 -746
  182. brainstate-0.0.2.post20241010.dist-info/RECORD +0 -87
  183. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/LICENSE +0 -0
  184. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/WHEEL +0 -0
@@ -0,0 +1,453 @@
1
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ from __future__ import annotations
17
+
18
+ import math
19
+ import numbers
20
+ from functools import partial
21
+ from typing import Optional, Dict, Callable, Union, Sequence
22
+
23
+ import brainunit as u
24
+ import jax
25
+ import jax.numpy as jnp
26
+ import numpy as np
27
+
28
+ from brainstate import environ
29
+ from brainstate._state import ShortTermState, State
30
+ from brainstate.compile import jit_error_if
31
+ from brainstate.graph import Node
32
+ from brainstate.nn._collective_ops import call_order
33
+ from brainstate.nn._module import Module
34
+ from brainstate.typing import ArrayLike, PyTree
35
+
36
+ __all__ = [
37
+ 'Delay', 'DelayAccess', 'StateWithDelay',
38
+ ]
39
+
40
+ _DELAY_ROTATE = 'rotation'
41
+ _DELAY_CONCAT = 'concat'
42
+ _INTERP_LINEAR = 'linear_interp'
43
+ _INTERP_ROUND = 'round'
44
+
45
+
46
+ def _get_delay(delay_time, delay_step):
47
+ if delay_time is None:
48
+ if delay_step is None:
49
+ return 0., 0
50
+ else:
51
+ assert isinstance(delay_step, int), '"delay_step" should be an integer.'
52
+ if delay_step == 0:
53
+ return 0., 0
54
+ delay_time = delay_step * environ.get_dt()
55
+ else:
56
+ assert delay_step is None, '"delay_step" should be None if "delay_time" is given.'
57
+ # assert isinstance(delay_time, (int, float))
58
+ delay_step = math.ceil(delay_time / environ.get_dt())
59
+ return delay_time, delay_step
60
+
61
+
62
+ class DelayAccess(Node):
63
+ """
64
+ The delay access class.
65
+
66
+ Args:
67
+ delay: The delay instance.
68
+ time: The delay time.
69
+ indices: The indices of the delay data.
70
+ delay_entry: The delay entry.
71
+ """
72
+
73
+ __module__ = 'brainstate.nn'
74
+
75
+ def __init__(
76
+ self,
77
+ delay: 'Delay',
78
+ time: Union[None, int, float],
79
+ delay_entry: str,
80
+ *indices,
81
+ ):
82
+ super().__init__()
83
+ self.refs = {'delay': delay}
84
+ assert isinstance(delay, Delay), 'The input delay should be an instance of Delay.'
85
+ self._delay_entry = delay_entry
86
+ delay.register_entry(self._delay_entry, time)
87
+ self.indices = indices
88
+
89
+ def update(self):
90
+ return self.refs['delay'].at(self._delay_entry, *self.indices)
91
+
92
+
93
+ class Delay(Module):
94
+ """
95
+ Generate Delays for the given :py:class:`~.State` instance.
96
+
97
+ The data in this delay variable is arranged as::
98
+
99
+ delay = 0 [ data
100
+ delay = 1 data
101
+ delay = 2 data
102
+ ... ....
103
+ ... ....
104
+ delay = length-1 data
105
+ delay = length data ]
106
+
107
+ Args:
108
+ time: int, float. The delay time.
109
+ init: Any. The delay data. It can be a Python number, like float, int, boolean values.
110
+ It can also be arrays. Or a callable function or instance of ``Connector``.
111
+ Note that the initial delay data given by ``init`` should be arranged in the following way::
112
+
113
+ delay = 1 [ data
114
+ delay = 2 data
115
+ ... ....
116
+ ... ....
117
+ delay = length-1 data
118
+ delay = length data ]
119
+ entries: optional, dict. The delay access entries.
120
+ delay_method: str. The method used for updating delay, either ``'rotation'`` or ``'concat'``. Default ``'rotation'``.
121
+ """
122
+
123
+ __module__ = 'brainstate.nn'
124
+
125
+ max_time: float #
126
+ max_length: int
127
+ history: Optional[ShortTermState]
128
+
129
+ def __init__(
130
+ self,
131
+ target_info: PyTree,
132
+ time: Optional[Union[int, float, u.Quantity]] = None, # delay time
133
+ init: Optional[Union[ArrayLike, Callable]] = None, # delay data before t0
134
+ entries: Optional[Dict] = None, # delay access entry
135
+ delay_method: Optional[str] = _DELAY_ROTATE, # delay method
136
+ interp_method: str = _INTERP_LINEAR, # interpolation method
137
+ ):
138
+ # target information
139
+ self.target_info = jax.tree.map(lambda a: jax.ShapeDtypeStruct(a.shape, a.dtype), target_info)
140
+
141
+ # delay method
142
+ assert delay_method in [_DELAY_ROTATE, _DELAY_CONCAT], (f'Un-supported delay method {delay_method}. '
143
+ f'Only support {_DELAY_ROTATE} and {_DELAY_CONCAT}')
144
+ self.delay_method = delay_method
145
+
146
+ # interp method
147
+ assert interp_method in [_INTERP_LINEAR, _INTERP_ROUND], (f'Un-supported interpolation method {interp_method}. '
148
+ f'we only support: {[_INTERP_LINEAR, _INTERP_ROUND]}')
149
+ self.interp_method = interp_method
150
+
151
+ # delay length and time
152
+ self.max_time, delay_length = _get_delay(time, None)
153
+ self.max_length = delay_length + 1
154
+
155
+ super().__init__()
156
+
157
+ # delay data
158
+ if init is not None:
159
+ if not isinstance(init, (numbers.Number, jax.Array, np.ndarray, Callable)):
160
+ raise TypeError(f'init should be Array, Callable, or None. But got {init}')
161
+ self._init = init
162
+ self._history = None
163
+
164
+ # other info
165
+ self._registered_entries = dict()
166
+
167
+ # other info
168
+ if entries is not None:
169
+ for entry, delay_time in entries.items():
170
+ self.register_entry(entry, delay_time)
171
+
172
+ @property
173
+ def history(self):
174
+ return self._history
175
+
176
+ @history.setter
177
+ def history(self, value):
178
+ self._history = value
179
+
180
+ def _f_to_init(self, a, batch_size, length):
181
+ shape = list(a.shape)
182
+ if batch_size is not None:
183
+ shape.insert(0, batch_size)
184
+ shape.insert(0, length)
185
+ if isinstance(self._init, (jax.Array, np.ndarray, numbers.Number)):
186
+ data = jnp.broadcast_to(jnp.asarray(self._init, a.dtype), shape)
187
+ elif callable(self._init):
188
+ data = self._init(shape, dtype=a.dtype)
189
+ else:
190
+ assert self._init is None, f'init should be Array, Callable, or None. but got {self._init}'
191
+ data = jnp.zeros(shape, dtype=a.dtype)
192
+ return data
193
+
194
+ @call_order(3)
195
+ def init_state(self, batch_size: int = None, **kwargs):
196
+ fun = partial(self._f_to_init, length=self.max_length, batch_size=batch_size)
197
+ self.history = ShortTermState(jax.tree.map(fun, self.target_info))
198
+
199
+ def reset_state(self, batch_size: int = None, **kwargs):
200
+ fun = partial(self._f_to_init, length=self.max_length, batch_size=batch_size)
201
+ self.history.value = jax.tree.map(fun, self.target_info)
202
+
203
+ def register_delay(
204
+ self,
205
+ delay_time: Optional[Union[int, float]] = None,
206
+ delay_step: Optional[int] = None,
207
+ ):
208
+ if isinstance(delay_time, (np.ndarray, jax.Array)):
209
+ assert delay_time.size == 1 and delay_time.ndim == 0
210
+ delay_time = delay_time.item()
211
+
212
+ _, delay_step = _get_delay(delay_time, delay_step)
213
+
214
+ # delay variable
215
+ if self.max_length <= delay_step + 1:
216
+ self.max_length = delay_step + 1
217
+ self.max_time = delay_time
218
+ return self
219
+
220
+ def register_entry(
221
+ self,
222
+ entry: str,
223
+ delay_time: Optional[Union[int, float]] = None,
224
+ delay_step: Optional[int] = None,
225
+ ) -> 'Delay':
226
+ """
227
+ Register an entry to access the delay data.
228
+
229
+ Args:
230
+ entry: str. The entry to access the delay data.
231
+ delay_time: The delay time of the entry (can be a float).
232
+ delay_step: The delay step of the entry (must be an int). ``delay_step = ceil(delay_time / dt)``.
233
+
234
+ Returns:
235
+ Return the self.
236
+ """
237
+ if entry in self._registered_entries:
238
+ raise KeyError(f'Entry {entry} has been registered. '
239
+ f'The existing delay for the key {entry} is {self._registered_entries[entry]}. '
240
+ f'The new delay for the key {entry} is {delay_time}. '
241
+ f'You can use another key. ')
242
+
243
+ if isinstance(delay_time, (np.ndarray, jax.Array)):
244
+ assert delay_time.size == 1 and delay_time.ndim == 0
245
+ delay_time = delay_time.item()
246
+
247
+ _, delay_step = _get_delay(delay_time, delay_step)
248
+
249
+ # delay variable
250
+ if self.max_length <= delay_step + 1:
251
+ self.max_length = delay_step + 1
252
+ self.max_time = delay_time
253
+ self._registered_entries[entry] = delay_step
254
+ return self
255
+
256
+ def access(
257
+ self,
258
+ entry: str = None,
259
+ time: Sequence = None,
260
+ ) -> DelayAccess:
261
+ return DelayAccess(self, time, delay_entry=entry)
262
+
263
+ def at(self, entry: str, *indices) -> ArrayLike:
264
+ """
265
+ Get the data at the given entry.
266
+
267
+ Args:
268
+ entry: str. The entry to access the data.
269
+ *indices: The slicing indices. Not include the slice at the batch dimension.
270
+
271
+ Returns:
272
+ The data.
273
+ """
274
+ assert isinstance(entry, str), (f'entry should be a string for describing the '
275
+ f'entry of the delay data. But we got {entry}.')
276
+ if entry not in self._registered_entries:
277
+ raise KeyError(f'Does not find delay entry "{entry}".')
278
+ delay_step = self._registered_entries[entry]
279
+ if delay_step is None:
280
+ delay_step = 0
281
+ return self.retrieve_at_step(delay_step, *indices)
282
+
283
+ def retrieve_at_step(self, delay_step, *indices) -> PyTree:
284
+ """
285
+ Retrieve the delay data at the given delay time step (the integer to indicate the time step).
286
+
287
+ Parameters
288
+ ----------
289
+ delay_step: int_like
290
+ Retrieve the data at the given time step.
291
+ indices: tuple
292
+ The indices to slice the data.
293
+
294
+ Returns
295
+ -------
296
+ delay_data: The delay data at the given delay step.
297
+
298
+ """
299
+ assert self.history is not None, 'The delay history is not initialized.'
300
+ assert delay_step is not None, 'The delay step should be given.'
301
+
302
+ if environ.get(environ.JIT_ERROR_CHECK, False):
303
+ def _check_delay(delay_len):
304
+ raise ValueError(f'The request delay length should be less than the '
305
+ f'maximum delay {self.max_length - 1}. But we got {delay_len}')
306
+
307
+ jit_error_if(delay_step >= self.max_length, _check_delay, delay_step)
308
+
309
+ # rotation method
310
+ if self.delay_method == _DELAY_ROTATE:
311
+ i = environ.get(environ.I, desc='The time step index.')
312
+ di = i - delay_step
313
+ delay_idx = jnp.asarray(di % self.max_length, dtype=jnp.int32)
314
+ delay_idx = jax.lax.stop_gradient(delay_idx)
315
+
316
+ elif self.delay_method == _DELAY_CONCAT:
317
+ delay_idx = delay_step
318
+
319
+ else:
320
+ raise ValueError(f'Unknown delay updating method "{self.delay_method}"')
321
+
322
+ # the delay index
323
+ if hasattr(delay_idx, 'dtype') and not jnp.issubdtype(delay_idx.dtype, jnp.integer):
324
+ raise ValueError(f'"delay_len" must be integer, but we got {delay_idx}')
325
+ indices = (delay_idx,) + indices
326
+
327
+ # the delay data
328
+ return jax.tree.map(lambda a: a[indices], self.history.value)
329
+
330
+ def retrieve_at_time(self, delay_time, *indices) -> PyTree:
331
+ """
332
+ Retrieve the delay data at the given time (a float), mapping it onto the stored time steps with the configured interpolation method.
333
+
334
+ Parameters
335
+ ----------
336
+ delay_time: float
337
+ Retrieve the data at the given time.
338
+ indices: tuple
339
+ The indices to slice the data.
340
+
341
+ Returns
342
+ -------
343
+ delay_data: The delay data at the given time.
344
+
345
+ """
346
+ assert self.history is not None, 'The delay history is not initialized.'
347
+ assert delay_time is not None, 'The delay time should be given.'
348
+
349
+ current_time = environ.get(environ.T, desc='The current time.')
350
+ dt = environ.get_dt()
351
+
352
+ if environ.get(environ.JIT_ERROR_CHECK, False):
353
+ def _check_delay(t_now, t_delay):
354
+ raise ValueError(f'The request delay time should be within '
355
+ f'[{t_now - self.max_time - dt}, {t_now}], '
356
+ f'but we got {t_delay}')
357
+
358
+ jit_error_if(
359
+ jnp.logical_or(delay_time > current_time,
360
+ delay_time < current_time - self.max_time - dt),
361
+ _check_delay,
362
+ current_time,
363
+ delay_time
364
+ )
365
+
366
+ diff = current_time - delay_time
367
+ float_time_step = diff / dt
368
+
369
+ if self.interp_method == _INTERP_LINEAR: # "linear" interpolation
370
+ data_at_t0 = self.retrieve_at_step(jnp.asarray(jnp.floor(float_time_step), dtype=jnp.int32), *indices)
371
+ data_at_t1 = self.retrieve_at_step(jnp.asarray(jnp.ceil(float_time_step), dtype=jnp.int32), *indices)
372
+ t_diff = float_time_step - jnp.floor(float_time_step)
373
+ return jax.tree.map(lambda a, b: a * (1 - t_diff) + b * t_diff, data_at_t0, data_at_t1)
374
+
375
+ elif self.interp_method == _INTERP_ROUND: # "round" interpolation
376
+ return self.retrieve_at_step(
377
+ jnp.asarray(jnp.round(float_time_step), dtype=jnp.int32),
378
+ *indices
379
+ )
380
+
381
+ else: # raise error
382
+ raise ValueError(f'Un-supported interpolation method {self.interp_method}, '
383
+ f'we only support: {[_INTERP_LINEAR, _INTERP_ROUND]}')
384
+
385
+ def update(self, current: PyTree) -> None:
386
+ """
387
+ Update delay variable with the new data.
388
+ """
389
+ assert self.history is not None, 'The delay history is not initialized.'
390
+
391
+ # update the delay data at the rotation index
392
+ if self.delay_method == _DELAY_ROTATE:
393
+ i = environ.get(environ.I)
394
+ idx = jnp.asarray(i % self.max_length, dtype=environ.dutype())
395
+ idx = jax.lax.stop_gradient(idx)
396
+ self.history.value = jax.tree.map(
397
+ lambda hist, cur: hist.at[idx].set(cur),
398
+ self.history.value,
399
+ current
400
+ )
401
+ # update the delay data at the first position
402
+ elif self.delay_method == _DELAY_CONCAT:
403
+ current = jax.tree.map(lambda a: jnp.expand_dims(a, 0), current)
404
+ if self.max_length > 1:
405
+ self.history.value = jax.tree.map(
406
+ lambda hist, cur: jnp.concatenate([cur, hist[:-1]], axis=0),
407
+ self.history.value,
408
+ current
409
+ )
410
+ else:
411
+ self.history.value = current
412
+
413
+ else:
414
+ raise ValueError(f'Unknown updating method "{self.delay_method}"')
415
+
416
+
417
+ class StateWithDelay(Delay):
418
+ """
419
+ A delay bound to a ``State`` attribute of a target node: each update records the current value of ``getattr(target, item)``.
420
+ """
421
+
422
+ __module__ = 'brainstate.nn'
423
+
424
+ state: State # state
425
+
426
+ def __init__(self, target: Node, item: str):
427
+ super().__init__(None)
428
+
429
+ self._target = target
430
+ self._target_term = item
431
+
432
+ @property
433
+ def state(self) -> State:
434
+ r = getattr(self._target, self._target_term)
435
+ if not isinstance(r, State):
436
+ raise TypeError(f'The term "{self._target_term}" in the module "{self._target}" is not a State.')
437
+ return r
438
+
439
+ @call_order(3)
440
+ def init_state(self, *args, **kwargs):
441
+ """
442
+ State initialization function.
443
+ """
444
+ state = self.state
445
+ self.target_info = jax.tree.map(lambda a: jax.ShapeDtypeStruct(a.shape, a.dtype), state.value)
446
+ super().init_state(*args, **kwargs)
447
+
448
+ def update(self, *args) -> None:
449
+ """
450
+ Update the delay variable with the new data.
451
+ """
452
+ value = self.state.value
453
+ return super().update(value)
@@ -0,0 +1,161 @@
1
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ # -*- coding: utf-8 -*-
17
+
18
+ from __future__ import annotations
19
+
20
+ import brainunit as u
21
+ import jax.numpy as jnp
22
+
23
+ from brainstate.mixin import BindCondData
24
+ from brainstate.nn._module import Module
25
+ from brainstate.typing import ArrayLike
26
+
27
+ __all__ = [
28
+ 'SynOut', 'COBA', 'CUBA', 'MgBlock',
29
+ ]
30
+
31
+
32
+ class SynOut(Module, BindCondData):
33
+ """
34
+ Base class for synaptic outputs.
35
+
36
+ :py:class:`~.SynOut` is also a subclass of :py:class:`~.ParamDesc` and :py:class:`~.BindCondData`.
37
+ """
38
+
39
+ __module__ = 'brainstate.nn'
40
+
41
+ def __init__(self, ):
42
+ super().__init__()
43
+ self._conductance = None
44
+
45
+ def __call__(self, *args, **kwargs):
46
+ if self._conductance is None:
47
+ raise ValueError(f'Please first pack conductance data at the current step using '
48
+ f'".{BindCondData.bind_cond.__name__}(data)". {self}')
49
+ ret = self.update(self._conductance, *args, **kwargs)
50
+ return ret
51
+
52
+
53
+ class COBA(SynOut):
54
+ r"""
55
+ Conductance-based synaptic output.
56
+
57
+ Given the synaptic conductance, the model outputs the post-synaptic current as
58
+
59
+ .. math::
60
+
61
+ I_{syn}(t) = g_{\mathrm{syn}}(t) (E - V(t))
62
+
63
+ Parameters
64
+ ----------
65
+ E: ArrayLike
66
+ The reversal potential.
67
+
68
+ See Also
69
+ --------
70
+ CUBA
71
+ """
72
+ __module__ = 'brainstate.nn'
73
+
74
+ def __init__(self, E: ArrayLike):
75
+ super().__init__()
76
+
77
+ self.E = E
78
+
79
+ def update(self, conductance, potential):
80
+ return conductance * (self.E - potential)
81
+
82
+
83
+ class CUBA(SynOut):
84
+ r"""Current-based synaptic output.
85
+
86
+ Given the conductance, this model outputs the post-synaptic current with an identity function (multiplied by ``scale``):
87
+
88
+ .. math::
89
+
90
+ I_{\mathrm{syn}}(t) = g_{\mathrm{syn}}(t)
91
+
92
+ Parameters
93
+ ----------
94
+ scale: ArrayLike
95
+ The scaling factor for the conductance. Default is 1 volt (``u.volt``).
96
+
97
+ See Also
98
+ --------
99
+ COBA
100
+ """
101
+ __module__ = 'brainstate.nn'
102
+
103
+ def __init__(self, scale: ArrayLike = u.volt):
104
+ super().__init__()
105
+ self.scale = scale
106
+
107
+ def update(self, conductance, potential=None):
108
+ return conductance * self.scale
109
+
110
+
111
+ class MgBlock(SynOut):
112
+ r"""Synaptic output based on Magnesium blocking.
113
+
114
+ Given the synaptic conductance, the model outputs the post-synaptic current as
115
+
116
+ .. math::
117
+
118
+ I_{syn}(t) = g_{\mathrm{syn}}(t) (E - V(t)) g_{\infty}(V,[{Mg}^{2+}]_{o})
119
+
120
+ where the fraction of channels :math:`g_{\infty}` that are not blocked by magnesium can be fitted to
121
+
122
+ .. math::
123
+
124
+ g_{\infty}(V,[{Mg}^{2+}]_{o}) = (1+{e}^{-\alpha (V - V_{\mathrm{offset}})} \frac{[{Mg}^{2+}]_{o}} {\beta})^{-1}
125
+
126
+ Here :math:`[{Mg}^{2+}]_{o}` is the extracellular magnesium concentration.
127
+
128
+ Parameters
129
+ ----------
130
+ E: ArrayLike
131
+ The reversal potential for the synaptic current. [mV]
132
+ alpha: ArrayLike
133
+ Binding constant. Default 0.062
134
+ beta: ArrayLike
135
+ Unbinding constant. Default 3.57
136
+ cc_Mg: ArrayLike
137
+ Concentration of Magnesium ion. Default 1.2 [mM].
138
+ V_offset: ArrayLike
139
+ The offset potential. Default 0. [mV]
140
+ """
141
+ __module__ = 'brainstate.nn'
142
+
143
+ def __init__(
144
+ self,
145
+ E: ArrayLike = 0.,
146
+ cc_Mg: ArrayLike = 1.2,
147
+ alpha: ArrayLike = 0.062,
148
+ beta: ArrayLike = 3.57,
149
+ V_offset: ArrayLike = 0.,
150
+ ):
151
+ super().__init__()
152
+
153
+ self.E = E
154
+ self.V_offset = V_offset
155
+ self.cc_Mg = cc_Mg
156
+ self.alpha = alpha
157
+ self.beta = beta
158
+
159
+ def update(self, conductance, potential):
160
+ norm = (1 + self.cc_Mg / self.beta * jnp.exp(self.alpha * (self.V_offset - potential)))
161
+ return conductance * (self.E - potential) / norm
@@ -0,0 +1,58 @@
1
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ from __future__ import annotations
17
+
18
+ import unittest
19
+
20
+ import brainunit as u
21
+ import jax.numpy as jnp
22
+ import numpy as np
23
+
24
+ import brainstate as bst
25
+
26
+
27
+ class TestSynOutModels(unittest.TestCase):
28
+ def setUp(self):
29
+ self.conductance = jnp.array([0.5, 1.0, 1.5])
30
+ self.potential = jnp.array([-70.0, -65.0, -60.0])
31
+ self.E = jnp.array([-70.0])
32
+ self.alpha = jnp.array([0.062])
33
+ self.beta = jnp.array([3.57])
34
+ self.cc_Mg = jnp.array([1.2])
35
+ self.V_offset = jnp.array([0.0])
36
+
37
+ def test_COBA(self):
38
+ model = bst.nn.COBA(E=self.E)
39
+ output = model.update(self.conductance, self.potential)
40
+ expected_output = self.conductance * (self.E - self.potential)
41
+ np.testing.assert_array_almost_equal(output, expected_output)
42
+
43
+ def test_CUBA(self):
44
+ model = bst.nn.CUBA()
45
+ output = model.update(self.conductance)
46
+ expected_output = self.conductance * model.scale
47
+ self.assertTrue(u.math.allclose(output, expected_output))
48
+
49
+ def test_MgBlock(self):
50
+ model = bst.nn.MgBlock(E=self.E, cc_Mg=self.cc_Mg, alpha=self.alpha, beta=self.beta, V_offset=self.V_offset)
51
+ output = model.update(self.conductance, self.potential)
52
+ norm = (1 + self.cc_Mg / self.beta * jnp.exp(self.alpha * (self.V_offset - self.potential)))
53
+ expected_output = self.conductance * (self.E - self.potential) / norm
54
+ np.testing.assert_array_almost_equal(output, expected_output)
55
+
56
+
57
+ if __name__ == '__main__':
58
+ unittest.main()
@@ -0,0 +1,22 @@
1
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ from ._dropout import *
17
+ from ._dropout import __all__ as dropout_all
18
+ from ._elementwise import *
19
+ from ._elementwise import __all__ as elementwise_all
20
+
21
+ __all__ = dropout_all + elementwise_all
22
+ del dropout_all, elementwise_all