brainstate 0.0.2.post20241010__py2.py3-none-any.whl → 0.1.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (175)
  1. brainstate/__init__.py +31 -11
  2. brainstate/_state.py +760 -316
  3. brainstate/_state_test.py +41 -12
  4. brainstate/_utils.py +31 -4
  5. brainstate/augment/__init__.py +40 -0
  6. brainstate/augment/_autograd.py +608 -0
  7. brainstate/augment/_autograd_test.py +1193 -0
  8. brainstate/augment/_eval_shape.py +102 -0
  9. brainstate/augment/_eval_shape_test.py +40 -0
  10. brainstate/augment/_mapping.py +525 -0
  11. brainstate/augment/_mapping_test.py +210 -0
  12. brainstate/augment/_random.py +99 -0
  13. brainstate/{transform → compile}/__init__.py +25 -13
  14. brainstate/compile/_ad_checkpoint.py +204 -0
  15. brainstate/compile/_ad_checkpoint_test.py +51 -0
  16. brainstate/compile/_conditions.py +259 -0
  17. brainstate/compile/_conditions_test.py +221 -0
  18. brainstate/compile/_error_if.py +94 -0
  19. brainstate/compile/_error_if_test.py +54 -0
  20. brainstate/compile/_jit.py +314 -0
  21. brainstate/compile/_jit_test.py +143 -0
  22. brainstate/compile/_loop_collect_return.py +516 -0
  23. brainstate/compile/_loop_collect_return_test.py +59 -0
  24. brainstate/compile/_loop_no_collection.py +185 -0
  25. brainstate/compile/_loop_no_collection_test.py +51 -0
  26. brainstate/compile/_make_jaxpr.py +756 -0
  27. brainstate/compile/_make_jaxpr_test.py +134 -0
  28. brainstate/compile/_progress_bar.py +111 -0
  29. brainstate/compile/_unvmap.py +159 -0
  30. brainstate/compile/_util.py +147 -0
  31. brainstate/environ.py +408 -381
  32. brainstate/environ_test.py +34 -32
  33. brainstate/{nn/event → event}/__init__.py +6 -6
  34. brainstate/event/_csr.py +308 -0
  35. brainstate/event/_csr_test.py +118 -0
  36. brainstate/event/_fixed_probability.py +271 -0
  37. brainstate/event/_fixed_probability_test.py +128 -0
  38. brainstate/event/_linear.py +219 -0
  39. brainstate/event/_linear_test.py +112 -0
  40. brainstate/{nn/event → event}/_misc.py +7 -7
  41. brainstate/functional/_activations.py +521 -511
  42. brainstate/functional/_activations_test.py +300 -300
  43. brainstate/functional/_normalization.py +43 -43
  44. brainstate/functional/_others.py +15 -15
  45. brainstate/functional/_spikes.py +49 -49
  46. brainstate/graph/__init__.py +33 -0
  47. brainstate/graph/_graph_context.py +443 -0
  48. brainstate/graph/_graph_context_test.py +65 -0
  49. brainstate/graph/_graph_convert.py +246 -0
  50. brainstate/graph/_graph_node.py +300 -0
  51. brainstate/graph/_graph_node_test.py +75 -0
  52. brainstate/graph/_graph_operation.py +1746 -0
  53. brainstate/graph/_graph_operation_test.py +724 -0
  54. brainstate/init/_base.py +28 -10
  55. brainstate/init/_generic.py +175 -172
  56. brainstate/init/_random_inits.py +470 -415
  57. brainstate/init/_random_inits_test.py +150 -0
  58. brainstate/init/_regular_inits.py +66 -69
  59. brainstate/init/_regular_inits_test.py +51 -0
  60. brainstate/mixin.py +236 -244
  61. brainstate/mixin_test.py +44 -46
  62. brainstate/nn/__init__.py +26 -51
  63. brainstate/nn/_collective_ops.py +199 -0
  64. brainstate/nn/_dyn_impl/__init__.py +46 -0
  65. brainstate/nn/_dyn_impl/_dynamics_neuron.py +290 -0
  66. brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +162 -0
  67. brainstate/nn/_dyn_impl/_dynamics_synapse.py +320 -0
  68. brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +132 -0
  69. brainstate/nn/_dyn_impl/_inputs.py +154 -0
  70. brainstate/nn/{_projection/__init__.py → _dyn_impl/_projection_alignpost.py} +6 -13
  71. brainstate/nn/_dyn_impl/_rate_rnns.py +400 -0
  72. brainstate/nn/_dyn_impl/_rate_rnns_test.py +64 -0
  73. brainstate/nn/_dyn_impl/_readout.py +128 -0
  74. brainstate/nn/_dyn_impl/_readout_test.py +54 -0
  75. brainstate/nn/_dynamics/__init__.py +37 -0
  76. brainstate/nn/_dynamics/_dynamics_base.py +631 -0
  77. brainstate/nn/_dynamics/_dynamics_base_test.py +79 -0
  78. brainstate/nn/_dynamics/_projection_base.py +346 -0
  79. brainstate/nn/_dynamics/_state_delay.py +453 -0
  80. brainstate/nn/_dynamics/_synouts.py +161 -0
  81. brainstate/nn/_dynamics/_synouts_test.py +58 -0
  82. brainstate/nn/_elementwise/__init__.py +22 -0
  83. brainstate/nn/_elementwise/_dropout.py +418 -0
  84. brainstate/nn/_elementwise/_dropout_test.py +100 -0
  85. brainstate/nn/_elementwise/_elementwise.py +1122 -0
  86. brainstate/nn/_elementwise/_elementwise_test.py +171 -0
  87. brainstate/nn/_exp_euler.py +97 -0
  88. brainstate/nn/_exp_euler_test.py +36 -0
  89. brainstate/nn/_interaction/__init__.py +32 -0
  90. brainstate/nn/_interaction/_connections.py +726 -0
  91. brainstate/nn/_interaction/_connections_test.py +254 -0
  92. brainstate/nn/_interaction/_embedding.py +59 -0
  93. brainstate/nn/_interaction/_normalizations.py +388 -0
  94. brainstate/nn/_interaction/_normalizations_test.py +75 -0
  95. brainstate/nn/_interaction/_poolings.py +1179 -0
  96. brainstate/nn/_interaction/_poolings_test.py +219 -0
  97. brainstate/nn/_module.py +328 -0
  98. brainstate/nn/_module_test.py +211 -0
  99. brainstate/nn/metrics.py +309 -309
  100. brainstate/optim/__init__.py +14 -2
  101. brainstate/optim/_base.py +66 -0
  102. brainstate/optim/_lr_scheduler.py +363 -400
  103. brainstate/optim/_lr_scheduler_test.py +25 -24
  104. brainstate/optim/_optax_optimizer.py +103 -176
  105. brainstate/optim/_optax_optimizer_test.py +41 -1
  106. brainstate/optim/_sgd_optimizer.py +950 -1025
  107. brainstate/random/_rand_funs.py +3269 -3268
  108. brainstate/random/_rand_funs_test.py +568 -0
  109. brainstate/random/_rand_seed.py +149 -117
  110. brainstate/random/_rand_seed_test.py +50 -0
  111. brainstate/random/_rand_state.py +1356 -1321
  112. brainstate/random/_random_for_unit.py +13 -13
  113. brainstate/surrogate.py +1262 -1243
  114. brainstate/{nn/_projection/_utils.py → transform.py} +1 -2
  115. brainstate/typing.py +157 -130
  116. brainstate/util/__init__.py +52 -0
  117. brainstate/util/_caller.py +100 -0
  118. brainstate/util/_dict.py +734 -0
  119. brainstate/util/_dict_test.py +160 -0
  120. brainstate/util/_error.py +28 -0
  121. brainstate/util/_filter.py +178 -0
  122. brainstate/util/_others.py +497 -0
  123. brainstate/util/_pretty_repr.py +208 -0
  124. brainstate/util/_scaling.py +260 -0
  125. brainstate/util/_struct.py +524 -0
  126. brainstate/util/_tracers.py +75 -0
  127. brainstate/{_visualization.py → util/_visualization.py} +16 -16
  128. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.dist-info}/METADATA +11 -11
  129. brainstate-0.1.0.dist-info/RECORD +135 -0
  130. brainstate/_module.py +0 -1637
  131. brainstate/_module_test.py +0 -207
  132. brainstate/nn/_base.py +0 -251
  133. brainstate/nn/_connections.py +0 -686
  134. brainstate/nn/_dynamics.py +0 -426
  135. brainstate/nn/_elementwise.py +0 -1438
  136. brainstate/nn/_embedding.py +0 -66
  137. brainstate/nn/_misc.py +0 -133
  138. brainstate/nn/_normalizations.py +0 -389
  139. brainstate/nn/_others.py +0 -101
  140. brainstate/nn/_poolings.py +0 -1229
  141. brainstate/nn/_poolings_test.py +0 -231
  142. brainstate/nn/_projection/_align_post.py +0 -546
  143. brainstate/nn/_projection/_align_pre.py +0 -599
  144. brainstate/nn/_projection/_delta.py +0 -241
  145. brainstate/nn/_projection/_vanilla.py +0 -101
  146. brainstate/nn/_rate_rnns.py +0 -410
  147. brainstate/nn/_readout.py +0 -136
  148. brainstate/nn/_synouts.py +0 -166
  149. brainstate/nn/event/csr.py +0 -312
  150. brainstate/nn/event/csr_test.py +0 -118
  151. brainstate/nn/event/fixed_probability.py +0 -276
  152. brainstate/nn/event/fixed_probability_test.py +0 -127
  153. brainstate/nn/event/linear.py +0 -220
  154. brainstate/nn/event/linear_test.py +0 -111
  155. brainstate/random/random_test.py +0 -593
  156. brainstate/transform/_autograd.py +0 -585
  157. brainstate/transform/_autograd_test.py +0 -1181
  158. brainstate/transform/_conditions.py +0 -334
  159. brainstate/transform/_conditions_test.py +0 -220
  160. brainstate/transform/_error_if.py +0 -94
  161. brainstate/transform/_error_if_test.py +0 -55
  162. brainstate/transform/_jit.py +0 -265
  163. brainstate/transform/_jit_test.py +0 -118
  164. brainstate/transform/_loop_collect_return.py +0 -502
  165. brainstate/transform/_loop_no_collection.py +0 -170
  166. brainstate/transform/_make_jaxpr.py +0 -739
  167. brainstate/transform/_make_jaxpr_test.py +0 -131
  168. brainstate/transform/_mapping.py +0 -109
  169. brainstate/transform/_progress_bar.py +0 -111
  170. brainstate/transform/_unvmap.py +0 -143
  171. brainstate/util.py +0 -746
  172. brainstate-0.0.2.post20241010.dist-info/RECORD +0 -87
  173. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.dist-info}/LICENSE +0 -0
  174. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.dist-info}/WHEEL +0 -0
  175. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.dist-info}/top_level.txt +0 -0
brainstate/nn/_dynamics/_state_delay.py (new file)
@@ -0,0 +1,453 @@
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ from __future__ import annotations
+
+ import math
+ import numbers
+ from functools import partial
+ from typing import Optional, Dict, Callable, Union, Sequence
+
+ import brainunit as u
+ import jax
+ import jax.numpy as jnp
+ import numpy as np
+
+ from brainstate import environ
+ from brainstate._state import ShortTermState, State
+ from brainstate.compile import jit_error_if
+ from brainstate.graph import Node
+ from brainstate.nn._collective_ops import call_order
+ from brainstate.nn._module import Module
+ from brainstate.typing import ArrayLike, PyTree
+
+ __all__ = [
+     'Delay', 'DelayAccess', 'StateWithDelay',
+ ]
+
+ _DELAY_ROTATE = 'rotation'
+ _DELAY_CONCAT = 'concat'
+ _INTERP_LINEAR = 'linear_interp'
+ _INTERP_ROUND = 'round'
+
+
+ def _get_delay(delay_time, delay_step):
+     if delay_time is None:
+         if delay_step is None:
+             return 0., 0
+         else:
+             assert isinstance(delay_step, int), '"delay_step" should be an integer.'
+             if delay_step == 0:
+                 return 0., 0
+             delay_time = delay_step * environ.get_dt()
+     else:
+         assert delay_step is None, '"delay_step" should be None if "delay_time" is given.'
+         # assert isinstance(delay_time, (int, float))
+         delay_step = math.ceil(delay_time / environ.get_dt())
+     return delay_time, delay_step
+
+
+ class DelayAccess(Node):
+     """
+     The delay access class.
+
+     Args:
+         delay: The delay instance.
+         time: The delay time.
+         indices: The indices of the delay data.
+         delay_entry: The delay entry.
+     """
+
+     __module__ = 'brainstate.nn'
+
+     def __init__(
+         self,
+         delay: 'Delay',
+         time: Union[None, int, float],
+         delay_entry: str,
+         *indices,
+     ):
+         super().__init__()
+         self.refs = {'delay': delay}
+         assert isinstance(delay, Delay), 'The input delay should be an instance of Delay.'
+         self._delay_entry = delay_entry
+         delay.register_entry(self._delay_entry, time)
+         self.indices = indices
+
+     def update(self):
+         return self.refs['delay'].at(self._delay_entry, *self.indices)
+
+
+ class Delay(Module):
+     """
+     Generate Delays for the given :py:class:`~.State` instance.
+
+     The data in this delay variable is arranged as::
+
+         delay = 0             [ data
+         delay = 1               data
+         delay = 2               data
+         ...                     ....
+         ...                     ....
+         delay = length-1        data
+         delay = length          data ]
+
+     Args:
+         time: int, float. The delay time.
+         init: Any. The delay data. It can be a Python number, like float, int, boolean values.
+             It can also be arrays. Or a callable function or instance of ``Connector``.
+             Note that ``initial_delay_data`` should be arranged in the following way::
+
+                 delay = 1             [ data
+                 delay = 2               data
+                 ...                     ....
+                 ...                     ....
+                 delay = length-1        data
+                 delay = length          data ]
+         entries: optional, dict. The delay access entries.
+         delay_method: str. The method used for updating delay. Default ``'rotation'``.
+     """
+
+     __module__ = 'brainstate.nn'
+
+     max_time: float  #
+     max_length: int
+     history: Optional[ShortTermState]
+
+     def __init__(
+         self,
+         target_info: PyTree,
+         time: Optional[Union[int, float, u.Quantity]] = None,  # delay time
+         init: Optional[Union[ArrayLike, Callable]] = None,  # delay data before t0
+         entries: Optional[Dict] = None,  # delay access entry
+         delay_method: Optional[str] = _DELAY_ROTATE,  # delay method
+         interp_method: str = _INTERP_LINEAR,  # interpolation method
+     ):
+         # target information
+         self.target_info = jax.tree.map(lambda a: jax.ShapeDtypeStruct(a.shape, a.dtype), target_info)
+
+         # delay method
+         assert delay_method in [_DELAY_ROTATE, _DELAY_CONCAT], (f'Un-supported delay method {delay_method}. '
+                                                                 f'Only support {_DELAY_ROTATE} and {_DELAY_CONCAT}')
+         self.delay_method = delay_method
+
+         # interp method
+         assert interp_method in [_INTERP_LINEAR, _INTERP_ROUND], (f'Un-supported interpolation method {interp_method}. '
+                                                                   f'we only support: {[_INTERP_LINEAR, _INTERP_ROUND]}')
+         self.interp_method = interp_method
+
+         # delay length and time
+         self.max_time, delay_length = _get_delay(time, None)
+         self.max_length = delay_length + 1
+
+         super().__init__()
+
+         # delay data
+         if init is not None:
+             if not isinstance(init, (numbers.Number, jax.Array, np.ndarray, Callable)):
+                 raise TypeError(f'init should be Array, Callable, or None. But got {init}')
+         self._init = init
+         self._history = None
+
+         # other info
+         self._registered_entries = dict()
+
+         # delay access entries
+         if entries is not None:
+             for entry, delay_time in entries.items():
+                 self.register_entry(entry, delay_time)
+
+     @property
+     def history(self):
+         return self._history
+
+     @history.setter
+     def history(self, value):
+         self._history = value
+
+     def _f_to_init(self, a, batch_size, length):
+         shape = list(a.shape)
+         if batch_size is not None:
+             shape.insert(0, batch_size)
+         shape.insert(0, length)
+         if isinstance(self._init, (jax.Array, np.ndarray, numbers.Number)):
+             data = jnp.broadcast_to(jnp.asarray(self._init, a.dtype), shape)
+         elif callable(self._init):
+             data = self._init(shape, dtype=a.dtype)
+         else:
+             assert self._init is None, f'init should be Array, Callable, or None. but got {self._init}'
+             data = jnp.zeros(shape, dtype=a.dtype)
+         return data
+
+     @call_order(3)
+     def init_state(self, batch_size: int = None, **kwargs):
+         fun = partial(self._f_to_init, length=self.max_length, batch_size=batch_size)
+         self.history = ShortTermState(jax.tree.map(fun, self.target_info))
+
+     def reset_state(self, batch_size: int = None, **kwargs):
+         fun = partial(self._f_to_init, length=self.max_length, batch_size=batch_size)
+         self.history.value = jax.tree.map(fun, self.target_info)
+
+     def register_delay(
+         self,
+         delay_time: Optional[Union[int, float]] = None,
+         delay_step: Optional[int] = None,
+     ):
+         if isinstance(delay_time, (np.ndarray, jax.Array)):
+             assert delay_time.size == 1 and delay_time.ndim == 0
+             delay_time = delay_time.item()
+
+         _, delay_step = _get_delay(delay_time, delay_step)
+
+         # delay variable
+         if self.max_length <= delay_step + 1:
+             self.max_length = delay_step + 1
+             self.max_time = delay_time
+         return self
+
+     def register_entry(
+         self,
+         entry: str,
+         delay_time: Optional[Union[int, float]] = None,
+         delay_step: Optional[int] = None,
+     ) -> 'Delay':
+         """
+         Register an entry to access the delay data.
+
+         Args:
+             entry: str. The entry to access the delay data.
+             delay_time: The delay time of the entry (can be a float).
+             delay_step: The delay step of the entry (must be an int). ``delay_step = delay_time / dt``.
+
+         Returns:
+             The ``Delay`` instance itself.
+         """
+         if entry in self._registered_entries:
+             raise KeyError(f'Entry {entry} has been registered. '
+                            f'The existing delay for the key {entry} is {self._registered_entries[entry]}. '
+                            f'The new delay for the key {entry} is {delay_time}. '
+                            f'You can use another key. ')
+
+         if isinstance(delay_time, (np.ndarray, jax.Array)):
+             assert delay_time.size == 1 and delay_time.ndim == 0
+             delay_time = delay_time.item()
+
+         _, delay_step = _get_delay(delay_time, delay_step)
+
+         # delay variable
+         if self.max_length <= delay_step + 1:
+             self.max_length = delay_step + 1
+             self.max_time = delay_time
+         self._registered_entries[entry] = delay_step
+         return self
+
+     def access(
+         self,
+         entry: str = None,
+         time: Sequence = None,
+     ) -> DelayAccess:
+         return DelayAccess(self, time, delay_entry=entry)
+
+     def at(self, entry: str, *indices) -> ArrayLike:
+         """
+         Get the data at the given entry.
+
+         Args:
+             entry: str. The entry to access the data.
+             *indices: The slicing indices (the slice at the batch dimension is not included).
+
+         Returns:
+             The data.
+         """
+         assert isinstance(entry, str), (f'entry should be a string for describing the '
+                                         f'entry of the delay data. But we got {entry}.')
+         if entry not in self._registered_entries:
+             raise KeyError(f'Does not find delay entry "{entry}".')
+         delay_step = self._registered_entries[entry]
+         if delay_step is None:
+             delay_step = 0
+         return self.retrieve_at_step(delay_step, *indices)
+
+     def retrieve_at_step(self, delay_step, *indices) -> PyTree:
+         """
+         Retrieve the delay data at the given delay time step (the integer to indicate the time step).
+
+         Parameters
+         ----------
+         delay_step: int_like
+             Retrieve the data at the given time step.
+         indices: tuple
+             The indices to slice the data.
+
+         Returns
+         -------
+         delay_data: The delay data at the given delay step.
+
+         """
+         assert self.history is not None, 'The delay history is not initialized.'
+         assert delay_step is not None, 'The delay step should be given.'
+
+         if environ.get(environ.JIT_ERROR_CHECK, False):
+             def _check_delay(delay_len):
+                 raise ValueError(f'The request delay length should be less than the '
+                                  f'maximum delay {self.max_length - 1}. But we got {delay_len}')
+
+             jit_error_if(delay_step >= self.max_length, _check_delay, delay_step)
+
+         # rotation method
+         if self.delay_method == _DELAY_ROTATE:
+             i = environ.get(environ.I, desc='The time step index.')
+             di = i - delay_step
+             delay_idx = jnp.asarray(di % self.max_length, dtype=jnp.int32)
+             delay_idx = jax.lax.stop_gradient(delay_idx)
+
+         elif self.delay_method == _DELAY_CONCAT:
+             delay_idx = delay_step
+
+         else:
+             raise ValueError(f'Unknown delay updating method "{self.delay_method}"')
+
+         # the delay index
+         if hasattr(delay_idx, 'dtype') and not jnp.issubdtype(delay_idx.dtype, jnp.integer):
+             raise ValueError(f'"delay_len" must be integer, but we got {delay_idx}')
+         indices = (delay_idx,) + indices
+
+         # the delay data
+         return jax.tree.map(lambda a: a[indices], self.history.value)
+
+     def retrieve_at_time(self, delay_time, *indices) -> PyTree:
+         """
+         Retrieve the delay data at the given delay time (a float time value).
+
+         Parameters
+         ----------
+         delay_time: float
+             Retrieve the data at the given time.
+         indices: tuple
+             The indices to slice the data.
+
+         Returns
+         -------
+         delay_data: The delay data at the given delay time.
+
+         """
+         assert self.history is not None, 'The delay history is not initialized.'
+         assert delay_time is not None, 'The delay time should be given.'
+
+         current_time = environ.get(environ.T, desc='The current time.')
+         dt = environ.get_dt()
+
+         if environ.get(environ.JIT_ERROR_CHECK, False):
+             def _check_delay(t_now, t_delay):
+                 raise ValueError(f'The request delay time should be within '
+                                  f'[{t_now - self.max_time - dt}, {t_now}], '
+                                  f'but we got {t_delay}')
+
+             jit_error_if(
+                 jnp.logical_or(delay_time > current_time,
+                                delay_time < current_time - self.max_time - dt),
+                 _check_delay,
+                 current_time,
+                 delay_time
+             )
+
+         diff = current_time - delay_time
+         float_time_step = diff / dt
+
+         if self.interp_method == _INTERP_LINEAR:  # "linear" interpolation
+             data_at_t0 = self.retrieve_at_step(jnp.asarray(jnp.floor(float_time_step), dtype=jnp.int32), *indices)
+             data_at_t1 = self.retrieve_at_step(jnp.asarray(jnp.ceil(float_time_step), dtype=jnp.int32), *indices)
+             t_diff = float_time_step - jnp.floor(float_time_step)
+             return jax.tree.map(lambda a, b: a * (1 - t_diff) + b * t_diff, data_at_t0, data_at_t1)
+
+         elif self.interp_method == _INTERP_ROUND:  # "round" interpolation
+             return self.retrieve_at_step(
+                 jnp.asarray(jnp.round(float_time_step), dtype=jnp.int32),
+                 *indices
+             )
+
+         else:  # raise error
+             raise ValueError(f'Un-supported interpolation method {self.interp_method}, '
+                              f'we only support: {[_INTERP_LINEAR, _INTERP_ROUND]}')
+
+     def update(self, current: PyTree) -> None:
+         """
+         Update delay variable with the new data.
+         """
+         assert self.history is not None, 'The delay history is not initialized.'
+
+         # update the delay data at the rotation index
+         if self.delay_method == _DELAY_ROTATE:
+             i = environ.get(environ.I)
+             idx = jnp.asarray(i % self.max_length, dtype=environ.dutype())
+             idx = jax.lax.stop_gradient(idx)
+             self.history.value = jax.tree.map(
+                 lambda hist, cur: hist.at[idx].set(cur),
+                 self.history.value,
+                 current
+             )
+         # update the delay data at the first position
+         elif self.delay_method == _DELAY_CONCAT:
+             current = jax.tree.map(lambda a: jnp.expand_dims(a, 0), current)
+             if self.max_length > 1:
+                 self.history.value = jax.tree.map(
+                     lambda hist, cur: jnp.concatenate([cur, hist[:-1]], axis=0),
+                     self.history.value,
+                     current
+                 )
+             else:
+                 self.history.value = current
+
+         else:
+             raise ValueError(f'Unknown updating method "{self.delay_method}"')
+
+
+ class StateWithDelay(Delay):
+     """
+     A ``State`` type that defines the state in a differential equation.
+     """
+
+     __module__ = 'brainstate.nn'
+
+     state: State  # state
+
+     def __init__(self, target: Node, item: str):
+         super().__init__(None)
+
+         self._target = target
+         self._target_term = item
+
+     @property
+     def state(self) -> State:
+         r = getattr(self._target, self._target_term)
+         if not isinstance(r, State):
+             raise TypeError(f'The term "{self._target_term}" in the module "{self._target}" is not a State.')
+         return r
+
+     @call_order(3)
+     def init_state(self, *args, **kwargs):
+         """
+         State initialization function.
+         """
+         state = self.state
+         self.target_info = jax.tree.map(lambda a: jax.ShapeDtypeStruct(a.shape, a.dtype), state.value)
+         super().init_state(*args, **kwargs)
+
+     def update(self, *args) -> None:
+         """
+         Update the delay variable with the new data.
+         """
+         value = self.state.value
+         return super().update(value)
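
A minimal usage sketch of the `Delay` class added above: build a rotation buffer for a target array, push a value each step, and read it back through a registered entry. This is illustrative only and not part of the diff; the `environ.set(dt=...)` and `environ.context(i=..., t=...)` calls are assumptions about how `brainstate.environ` supplies the time step, the step index `i`, and the time `t` that `Delay` reads internally.

import jax.numpy as jnp
import brainstate as bst

bst.environ.set(dt=0.1)                       # assumed environ API for setting the time step
target = jnp.zeros(4)                         # the variable whose history we want to keep
delay = bst.nn.Delay(target, time=2.0, entries={'readout': 2.0})
delay.init_state()                            # allocates the (max_length, 4) history buffer

with bst.environ.context(i=0, t=0.0):         # assumed way to provide the step index/time
    delay.update(jnp.ones(4))                 # write the current value at index i % max_length
    delayed = delay.at('readout')             # read the value from 2.0 time units (20 steps) back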
brainstate/nn/_dynamics/_synouts.py (new file)
@@ -0,0 +1,161 @@
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ # -*- coding: utf-8 -*-
+
+ from __future__ import annotations
+
+ import brainunit as u
+ import jax.numpy as jnp
+
+ from brainstate.mixin import BindCondData
+ from brainstate.nn._module import Module
+ from brainstate.typing import ArrayLike
+
+ __all__ = [
+     'SynOut', 'COBA', 'CUBA', 'MgBlock',
+ ]
+
+
+ class SynOut(Module, BindCondData):
+     """
+     Base class for synaptic outputs.
+
+     :py:class:`~.SynOut` is also a subclass of :py:class:`~.ParamDesc` and :py:class:`~.BindCondData`.
+     """
+
+     __module__ = 'brainstate.nn'
+
+     def __init__(self, ):
+         super().__init__()
+         self._conductance = None
+
+     def __call__(self, *args, **kwargs):
+         if self._conductance is None:
+             raise ValueError(f'Please first pack conductance data at the current step using '
+                              f'".{BindCondData.bind_cond.__name__}(data)". {self}')
+         ret = self.update(self._conductance, *args, **kwargs)
+         return ret
+
+
+ class COBA(SynOut):
+     r"""
+     Conductance-based synaptic output.
+
+     Given the synaptic conductance, the model outputs the post-synaptic current as
+
+     .. math::
+
+         I_{syn}(t) = g_{\mathrm{syn}}(t) (E - V(t))
+
+     Parameters
+     ----------
+     E: ArrayLike
+         The reversal potential.
+
+     See Also
+     --------
+     CUBA
+     """
+     __module__ = 'brainstate.nn'
+
+     def __init__(self, E: ArrayLike):
+         super().__init__()
+
+         self.E = E
+
+     def update(self, conductance, potential):
+         return conductance * (self.E - potential)
+
+
+ class CUBA(SynOut):
+     r"""Current-based synaptic output.
+
+     Given the conductance, this model outputs the post-synaptic current with an identity function:
+
+     .. math::
+
+         I_{\mathrm{syn}}(t) = g_{\mathrm{syn}}(t)
+
+     Parameters
+     ----------
+     scale: ArrayLike
+         The scaling factor for the conductance. Default 1. [mV]
+
+     See Also
+     --------
+     COBA
+     """
+     __module__ = 'brainstate.nn'
+
+     def __init__(self, scale: ArrayLike = u.volt):
+         super().__init__()
+         self.scale = scale
+
+     def update(self, conductance, potential=None):
+         return conductance * self.scale
+
+
+ class MgBlock(SynOut):
+     r"""Synaptic output based on Magnesium blocking.
+
+     Given the synaptic conductance, the model outputs the post-synaptic current as
+
+     .. math::
+
+         I_{syn}(t) = g_{\mathrm{syn}}(t) (E - V(t)) g_{\infty}(V,[{Mg}^{2+}]_{o})
+
+     where the fraction of channels :math:`g_{\infty}` that are not blocked by magnesium can be fitted to
+
+     .. math::
+
+         g_{\infty}(V,[{Mg}^{2+}]_{o}) = (1+{e}^{-\alpha V} \frac{[{Mg}^{2+}]_{o}} {\beta})^{-1}
+
+     Here :math:`[{Mg}^{2+}]_{o}` is the extracellular magnesium concentration.
+
+     Parameters
+     ----------
+     E: ArrayLike
+         The reversal potential for the synaptic current. [mV]
+     alpha: ArrayLike
+         Binding constant. Default 0.062
+     beta: ArrayLike
+         Unbinding constant. Default 3.57
+     cc_Mg: ArrayLike
+         Concentration of Magnesium ion. Default 1.2 [mM].
+     V_offset: ArrayLike
+         The offset potential. Default 0. [mV]
+     """
+     __module__ = 'brainstate.nn'
+
+     def __init__(
+         self,
+         E: ArrayLike = 0.,
+         cc_Mg: ArrayLike = 1.2,
+         alpha: ArrayLike = 0.062,
+         beta: ArrayLike = 3.57,
+         V_offset: ArrayLike = 0.,
+     ):
+         super().__init__()
+
+         self.E = E
+         self.V_offset = V_offset
+         self.cc_Mg = cc_Mg
+         self.alpha = alpha
+         self.beta = beta
+
+     def update(self, conductance, potential):
+         norm = (1 + self.cc_Mg / self.beta * jnp.exp(self.alpha * (self.V_offset - potential)))
+         return conductance * (self.E - potential) / norm
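
A rough sketch of how a `SynOut` instance is driven through the `BindCondData` mixin; the `bind_cond` name is taken from the error message in `SynOut.__call__` above. In practice a projection performs the binding; it is done by hand here only for illustration and is not part of the diff.

import jax.numpy as jnp
import brainstate as bst

out = bst.nn.COBA(E=0.0)                  # conductance-based output with reversal potential E

g = jnp.array([0.5, 1.0, 1.5])            # synaptic conductances
v = jnp.array([-70.0, -65.0, -60.0])      # post-synaptic membrane potentials

out.bind_cond(g)                          # attach the conductance for this step (BindCondData)
i_syn = out(v)                            # same as out.update(g, v): g * (E - v)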
brainstate/nn/_dynamics/_synouts_test.py (new file)
@@ -0,0 +1,58 @@
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ from __future__ import annotations
+
+ import unittest
+
+ import brainunit as u
+ import jax.numpy as jnp
+ import numpy as np
+
+ import brainstate as bst
+
+
+ class TestSynOutModels(unittest.TestCase):
+     def setUp(self):
+         self.conductance = jnp.array([0.5, 1.0, 1.5])
+         self.potential = jnp.array([-70.0, -65.0, -60.0])
+         self.E = jnp.array([-70.0])
+         self.alpha = jnp.array([0.062])
+         self.beta = jnp.array([3.57])
+         self.cc_Mg = jnp.array([1.2])
+         self.V_offset = jnp.array([0.0])
+
+     def test_COBA(self):
+         model = bst.nn.COBA(E=self.E)
+         output = model.update(self.conductance, self.potential)
+         expected_output = self.conductance * (self.E - self.potential)
+         np.testing.assert_array_almost_equal(output, expected_output)
+
+     def test_CUBA(self):
+         model = bst.nn.CUBA()
+         output = model.update(self.conductance)
+         expected_output = self.conductance * model.scale
+         self.assertTrue(u.math.allclose(output, expected_output))
+
+     def test_MgBlock(self):
+         model = bst.nn.MgBlock(E=self.E, cc_Mg=self.cc_Mg, alpha=self.alpha, beta=self.beta, V_offset=self.V_offset)
+         output = model.update(self.conductance, self.potential)
+         norm = (1 + self.cc_Mg / self.beta * jnp.exp(self.alpha * (self.V_offset - self.potential)))
+         expected_output = self.conductance * (self.E - self.potential) / norm
+         np.testing.assert_array_almost_equal(output, expected_output)
+
+
+ if __name__ == '__main__':
+     unittest.main()
brainstate/nn/_elementwise/__init__.py (new file)
@@ -0,0 +1,22 @@
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ from ._dropout import *
+ from ._dropout import __all__ as dropout_all
+ from ._elementwise import *
+ from ._elementwise import __all__ as elementwise_all
+
+ __all__ = dropout_all + elementwise_all
+ del dropout_all, elementwise_all
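
The `__init__.py` above shows the re-export idiom used throughout the reorganized subpackages: star-import each private submodule, concatenate the submodules' `__all__` lists, and delete the temporaries so only the public names remain in the package namespace. The same idiom in isolation, with hypothetical submodule names:

# mypkg/__init__.py -- hypothetical package, same aggregation idiom
from ._layers import *                    # re-export the submodule's public names
from ._layers import __all__ as layers_all
from ._ops import *
from ._ops import __all__ as ops_all

__all__ = layers_all + ops_all            # the package advertises exactly the submodule names
del layers_all, ops_all                   # keep the helper lists out of the namespace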