brainstate 0.0.2.post20241010__py2.py3-none-any.whl → 0.1.0.post20241122__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (184)
  1. benchmark/COBA_2005.py +125 -0
  2. benchmark/CUBA_2005.py +149 -0
  3. brainstate/__init__.py +31 -11
  4. brainstate/_state.py +760 -316
  5. brainstate/_state_test.py +41 -12
  6. brainstate/_utils.py +31 -4
  7. brainstate/augment/__init__.py +40 -0
  8. brainstate/augment/_autograd.py +611 -0
  9. brainstate/augment/_autograd_test.py +1193 -0
  10. brainstate/augment/_eval_shape.py +102 -0
  11. brainstate/augment/_eval_shape_test.py +40 -0
  12. brainstate/augment/_mapping.py +525 -0
  13. brainstate/augment/_mapping_test.py +210 -0
  14. brainstate/augment/_random.py +99 -0
  15. brainstate/{transform → compile}/__init__.py +25 -13
  16. brainstate/compile/_ad_checkpoint.py +204 -0
  17. brainstate/compile/_ad_checkpoint_test.py +51 -0
  18. brainstate/compile/_conditions.py +259 -0
  19. brainstate/compile/_conditions_test.py +221 -0
  20. brainstate/compile/_error_if.py +94 -0
  21. brainstate/compile/_error_if_test.py +54 -0
  22. brainstate/compile/_jit.py +314 -0
  23. brainstate/compile/_jit_test.py +143 -0
  24. brainstate/compile/_loop_collect_return.py +516 -0
  25. brainstate/compile/_loop_collect_return_test.py +59 -0
  26. brainstate/compile/_loop_no_collection.py +185 -0
  27. brainstate/compile/_loop_no_collection_test.py +51 -0
  28. brainstate/compile/_make_jaxpr.py +756 -0
  29. brainstate/compile/_make_jaxpr_test.py +134 -0
  30. brainstate/compile/_progress_bar.py +111 -0
  31. brainstate/compile/_unvmap.py +159 -0
  32. brainstate/compile/_util.py +147 -0
  33. brainstate/environ.py +408 -381
  34. brainstate/environ_test.py +34 -32
  35. brainstate/event/__init__.py +27 -0
  36. brainstate/event/_csr.py +316 -0
  37. brainstate/event/_csr_benchmark.py +14 -0
  38. brainstate/event/_csr_test.py +118 -0
  39. brainstate/event/_fixed_probability.py +708 -0
  40. brainstate/event/_fixed_probability_benchmark.py +128 -0
  41. brainstate/event/_fixed_probability_test.py +131 -0
  42. brainstate/event/_linear.py +359 -0
  43. brainstate/event/_linear_benckmark.py +82 -0
  44. brainstate/event/_linear_test.py +117 -0
  45. brainstate/{nn/event → event}/_misc.py +7 -7
  46. brainstate/event/_xla_custom_op.py +312 -0
  47. brainstate/event/_xla_custom_op_test.py +55 -0
  48. brainstate/functional/_activations.py +521 -511
  49. brainstate/functional/_activations_test.py +300 -300
  50. brainstate/functional/_normalization.py +43 -43
  51. brainstate/functional/_others.py +15 -15
  52. brainstate/functional/_spikes.py +49 -49
  53. brainstate/graph/__init__.py +33 -0
  54. brainstate/graph/_graph_context.py +443 -0
  55. brainstate/graph/_graph_context_test.py +65 -0
  56. brainstate/graph/_graph_convert.py +246 -0
  57. brainstate/graph/_graph_node.py +300 -0
  58. brainstate/graph/_graph_node_test.py +75 -0
  59. brainstate/graph/_graph_operation.py +1746 -0
  60. brainstate/graph/_graph_operation_test.py +724 -0
  61. brainstate/init/_base.py +28 -10
  62. brainstate/init/_generic.py +175 -172
  63. brainstate/init/_random_inits.py +470 -415
  64. brainstate/init/_random_inits_test.py +150 -0
  65. brainstate/init/_regular_inits.py +66 -69
  66. brainstate/init/_regular_inits_test.py +51 -0
  67. brainstate/mixin.py +236 -244
  68. brainstate/mixin_test.py +44 -46
  69. brainstate/nn/__init__.py +26 -51
  70. brainstate/nn/_collective_ops.py +199 -0
  71. brainstate/nn/_dyn_impl/__init__.py +46 -0
  72. brainstate/nn/_dyn_impl/_dynamics_neuron.py +290 -0
  73. brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +162 -0
  74. brainstate/nn/_dyn_impl/_dynamics_synapse.py +315 -0
  75. brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +132 -0
  76. brainstate/nn/_dyn_impl/_inputs.py +154 -0
  77. brainstate/nn/{event/__init__.py → _dyn_impl/_projection_alignpost.py} +8 -8
  78. brainstate/nn/_dyn_impl/_rate_rnns.py +400 -0
  79. brainstate/nn/_dyn_impl/_rate_rnns_test.py +64 -0
  80. brainstate/nn/_dyn_impl/_readout.py +128 -0
  81. brainstate/nn/_dyn_impl/_readout_test.py +54 -0
  82. brainstate/nn/_dynamics/__init__.py +37 -0
  83. brainstate/nn/_dynamics/_dynamics_base.py +631 -0
  84. brainstate/nn/_dynamics/_dynamics_base_test.py +79 -0
  85. brainstate/nn/_dynamics/_projection_base.py +346 -0
  86. brainstate/nn/_dynamics/_state_delay.py +453 -0
  87. brainstate/nn/_dynamics/_synouts.py +161 -0
  88. brainstate/nn/_dynamics/_synouts_test.py +58 -0
  89. brainstate/nn/_elementwise/__init__.py +22 -0
  90. brainstate/nn/_elementwise/_dropout.py +418 -0
  91. brainstate/nn/_elementwise/_dropout_test.py +100 -0
  92. brainstate/nn/_elementwise/_elementwise.py +1122 -0
  93. brainstate/nn/_elementwise/_elementwise_test.py +171 -0
  94. brainstate/nn/_exp_euler.py +97 -0
  95. brainstate/nn/_exp_euler_test.py +36 -0
  96. brainstate/nn/_interaction/__init__.py +41 -0
  97. brainstate/nn/_interaction/_conv.py +499 -0
  98. brainstate/nn/_interaction/_conv_test.py +239 -0
  99. brainstate/nn/_interaction/_embedding.py +59 -0
  100. brainstate/nn/_interaction/_linear.py +582 -0
  101. brainstate/nn/_interaction/_linear_test.py +42 -0
  102. brainstate/nn/_interaction/_normalizations.py +388 -0
  103. brainstate/nn/_interaction/_normalizations_test.py +75 -0
  104. brainstate/nn/_interaction/_poolings.py +1179 -0
  105. brainstate/nn/_interaction/_poolings_test.py +219 -0
  106. brainstate/nn/_module.py +328 -0
  107. brainstate/nn/_module_test.py +211 -0
  108. brainstate/nn/metrics.py +309 -309
  109. brainstate/optim/__init__.py +14 -2
  110. brainstate/optim/_base.py +66 -0
  111. brainstate/optim/_lr_scheduler.py +363 -400
  112. brainstate/optim/_lr_scheduler_test.py +25 -24
  113. brainstate/optim/_optax_optimizer.py +121 -176
  114. brainstate/optim/_optax_optimizer_test.py +41 -1
  115. brainstate/optim/_sgd_optimizer.py +950 -1025
  116. brainstate/random/_rand_funs.py +3269 -3268
  117. brainstate/random/_rand_funs_test.py +568 -0
  118. brainstate/random/_rand_seed.py +149 -117
  119. brainstate/random/_rand_seed_test.py +50 -0
  120. brainstate/random/_rand_state.py +1356 -1321
  121. brainstate/random/_random_for_unit.py +13 -13
  122. brainstate/surrogate.py +1262 -1243
  123. brainstate/{nn/_projection/_utils.py → transform.py} +1 -2
  124. brainstate/typing.py +157 -130
  125. brainstate/util/__init__.py +52 -0
  126. brainstate/util/_caller.py +100 -0
  127. brainstate/util/_dict.py +734 -0
  128. brainstate/util/_dict_test.py +160 -0
  129. brainstate/{nn/_projection/__init__.py → util/_error.py} +9 -13
  130. brainstate/util/_filter.py +178 -0
  131. brainstate/util/_others.py +497 -0
  132. brainstate/util/_pretty_repr.py +208 -0
  133. brainstate/util/_scaling.py +260 -0
  134. brainstate/util/_struct.py +524 -0
  135. brainstate/util/_tracers.py +75 -0
  136. brainstate/{_visualization.py → util/_visualization.py} +16 -16
  137. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/METADATA +11 -11
  138. brainstate-0.1.0.post20241122.dist-info/RECORD +144 -0
  139. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/top_level.txt +1 -0
  140. brainstate/_module.py +0 -1637
  141. brainstate/_module_test.py +0 -207
  142. brainstate/nn/_base.py +0 -251
  143. brainstate/nn/_connections.py +0 -686
  144. brainstate/nn/_dynamics.py +0 -426
  145. brainstate/nn/_elementwise.py +0 -1438
  146. brainstate/nn/_embedding.py +0 -66
  147. brainstate/nn/_misc.py +0 -133
  148. brainstate/nn/_normalizations.py +0 -389
  149. brainstate/nn/_others.py +0 -101
  150. brainstate/nn/_poolings.py +0 -1229
  151. brainstate/nn/_poolings_test.py +0 -231
  152. brainstate/nn/_projection/_align_post.py +0 -546
  153. brainstate/nn/_projection/_align_pre.py +0 -599
  154. brainstate/nn/_projection/_delta.py +0 -241
  155. brainstate/nn/_projection/_vanilla.py +0 -101
  156. brainstate/nn/_rate_rnns.py +0 -410
  157. brainstate/nn/_readout.py +0 -136
  158. brainstate/nn/_synouts.py +0 -166
  159. brainstate/nn/event/csr.py +0 -312
  160. brainstate/nn/event/csr_test.py +0 -118
  161. brainstate/nn/event/fixed_probability.py +0 -276
  162. brainstate/nn/event/fixed_probability_test.py +0 -127
  163. brainstate/nn/event/linear.py +0 -220
  164. brainstate/nn/event/linear_test.py +0 -111
  165. brainstate/random/random_test.py +0 -593
  166. brainstate/transform/_autograd.py +0 -585
  167. brainstate/transform/_autograd_test.py +0 -1181
  168. brainstate/transform/_conditions.py +0 -334
  169. brainstate/transform/_conditions_test.py +0 -220
  170. brainstate/transform/_error_if.py +0 -94
  171. brainstate/transform/_error_if_test.py +0 -55
  172. brainstate/transform/_jit.py +0 -265
  173. brainstate/transform/_jit_test.py +0 -118
  174. brainstate/transform/_loop_collect_return.py +0 -502
  175. brainstate/transform/_loop_no_collection.py +0 -170
  176. brainstate/transform/_make_jaxpr.py +0 -739
  177. brainstate/transform/_make_jaxpr_test.py +0 -131
  178. brainstate/transform/_mapping.py +0 -109
  179. brainstate/transform/_progress_bar.py +0 -111
  180. brainstate/transform/_unvmap.py +0 -143
  181. brainstate/util.py +0 -746
  182. brainstate-0.0.2.post20241010.dist-info/RECORD +0 -87
  183. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/LICENSE +0 -0
  184. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/WHEEL +0 -0
brainstate/util.py DELETED
@@ -1,746 +0,0 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- import copy
17
- import functools
18
- import gc
19
- import types
20
- import warnings
21
- from collections.abc import Iterable
22
- from typing import Any, Callable, Tuple, Union, Sequence
23
-
24
- import jax
25
- from jax.lib import xla_bridge
26
-
27
- from ._utils import set_module_as
28
-
29
- __all__ = [
30
- 'unique_name',
31
- 'clear_buffer_memory',
32
- 'not_instance_eval',
33
- 'is_instance_eval',
34
- 'DictManager',
35
- 'MemScaling',
36
- 'IdMemScaling',
37
- 'DotDict',
38
- ]
39
-
40
- _name2id = dict()
41
- _typed_names = {}
42
-
43
-
44
- @set_module_as('brainstate.util')
45
- def check_name_uniqueness(name, obj):
46
- """Check the uniqueness of the name for the object type."""
47
- if not name.isidentifier():
48
- raise ValueError(
49
- f'"{name}" isn\'t a valid identifier '
50
- f'according to Python language definition. '
51
- f'Please choose another name.'
52
- )
53
- if name in _name2id:
54
- if _name2id[name] != id(obj):
55
- raise ValueError(
56
- f'In BrainPy, each object should have a unique name. '
57
- f'However, we detect that {obj} has a used name "{name}". \n'
58
- f'If you try to run multiple trials, you may need \n\n'
59
- f'>>> brainpy.brainpy_object.clear_name_cache() \n\n'
60
- f'to clear all cached names. '
61
- )
62
- else:
63
- _name2id[name] = id(obj)
64
-
65
-
66
- def get_unique_name(type_: str):
67
- """Get the unique name for the given object type."""
68
- if type_ not in _typed_names:
69
- _typed_names[type_] = 0
70
- name = f'{type_}{_typed_names[type_]}'
71
- _typed_names[type_] += 1
72
- return name
73
-
74
-
75
- @set_module_as('brainstate.util')
76
- def unique_name(name=None, self=None):
77
- """Get the unique name for this object.
78
-
79
- Parameters
80
- ----------
81
- name : str, optional
82
- The expected name. If None, the default unique name will be returned.
83
- Otherwise, the provided name will be checked to guarantee its uniqueness.
84
- self : str, optional
85
- The name of this class, used for object naming.
86
-
87
- Returns
88
- -------
89
- name : str
90
- The unique name for this object.
91
- """
92
- if name is None:
93
- assert self is not None, 'If name is None, self should be provided.'
94
- return get_unique_name(type_=self.__class__.__name__)
95
- else:
96
- check_name_uniqueness(name=name, obj=self)
97
- return name
98
-
99
-
100
- @set_module_as('brainstate.util')
101
- def clear_name_cache(ignore_warn: bool = True):
102
- """Clear the cached names."""
103
- _name2id.clear()
104
- _typed_names.clear()
105
- if not ignore_warn:
106
- warnings.warn(f'All named models and their ids are cleared.', UserWarning)
107
-
108
-
109
- @jax.tree_util.register_pytree_node_class
110
- class DictManager(dict):
111
- """
112
- DictManager, for collecting all pytree used in the program.
113
-
114
- :py:class:`~.DictManager` supports all features of python dict.
115
- """
116
- __module__ = 'brainstate.util'
117
-
118
- def subset(self, sep: Union[type, Tuple[type, ...], Callable]) -> 'DictManager':
119
- """
120
- Get a new stack with the subset of keys.
121
- """
122
- gather = type(self)()
123
- if isinstance(sep, types.FunctionType):
124
- for k, v in self.items():
125
- if sep(v):
126
- gather[k] = v
127
- return gather
128
- else:
129
- for k, v in self.items():
130
- if isinstance(v, sep):
131
- gather[k] = v
132
- return gather
133
-
134
- def not_subset(self, sep: Union[type, Tuple[type, ...]]) -> 'DictManager':
135
- """
136
- Get a new stack with the subset of keys.
137
- """
138
- gather = type(self)()
139
- for k, v in self.items():
140
- if not isinstance(v, sep):
141
- gather[k] = v
142
- return gather
143
-
144
- def add_unique_elem(self, key: Any, var: Any):
145
- """Add a new element."""
146
- self._check_elem(var)
147
- if key in self:
148
- if id(var) != id(self[key]):
149
- raise ValueError(f'{key} has been registered by {self[key]}, the new value is different from it.')
150
- else:
151
- self[key] = var
152
-
153
- def unique(self) -> 'DictManager':
154
- """
155
- Get a new type of collections with unique values.
156
-
157
- If one value is assigned to two or more keys,
158
- then only one pair of (key, value) will be returned.
159
- """
160
- gather = type(self)()
161
- seen = set()
162
- for k, v in self.items():
163
- if id(v) not in seen:
164
- seen.add(id(v))
165
- gather[k] = v
166
- return gather
167
-
168
- def assign(self, *args) -> None:
169
- """
170
- Assign the value for each element according to the given ``data``.
171
- """
172
- for arg in args:
173
- assert isinstance(arg, dict), 'Must be an instance of dict.'
174
- for k, v in arg.items():
175
- self[k] = v
176
-
177
- def split(self, first: type, *others: type) -> Tuple['DictManager', ...]:
178
- """
179
- Split the stack into subsets of stack by the given types.
180
- """
181
- filters = (first, *others)
182
- results = tuple(type(self)() for _ in range(len(filters) + 1))
183
- for k, v in self.items():
184
- for i, filt in enumerate(filters):
185
- if isinstance(v, filt):
186
- results[i][k] = v
187
- break
188
- else:
189
- results[-1][k] = v
190
- return results
191
-
192
- def pop_by_keys(self, keys: Iterable):
193
- """
194
- Pop the elements by the keys.
195
- """
196
- for k in tuple(self.keys()):
197
- if k in keys:
198
- self.pop(k)
199
-
200
- def pop_by_values(self, values: Iterable, by: str = 'id'):
201
- """
202
- Pop the elements by the values.
203
-
204
- Args:
205
- values: The value ids.
206
- by: str. The discard method, can be ``id`` or ``value``. Default is 'id'.
207
- """
208
- if by == 'id':
209
- value_ids = {id(v) for v in values}
210
- for k in tuple(self.keys()):
211
- if id(self[k]) in value_ids:
212
- self.pop(k)
213
- elif by == 'value':
214
- for k in tuple(self.keys()):
215
- if self[k] in values:
216
- self.pop(k)
217
- else:
218
- raise ValueError(f'Unsupported method: {by}')
219
-
220
- def difference_by_keys(self, keys: Iterable):
221
- """
222
- Get the difference of the stack by the keys.
223
- """
224
- return type(self)({k: v for k, v in self.items() if k not in keys})
225
-
226
- def difference_by_values(self, values: Iterable, by: str = 'id'):
227
- """
228
- Get the difference of the stack by the values.
229
-
230
- Args:
231
- values: The value ids.
232
- by: str. The discard method, can be ``id`` or ``value``. Default is 'id'.
233
- """
234
- if by == 'id':
235
- value_ids = {id(v) for v in values}
236
- return type(self)({k: v for k, v in self.items() if id(v) not in value_ids})
237
- elif by == 'value':
238
- return type(self)({k: v for k, v in self.items() if v not in values})
239
- else:
240
- raise ValueError(f'Unsupported method: {by}')
241
-
242
- def intersection_by_keys(self, keys: Iterable):
243
- """
244
- Get the intersection of the stack by the keys.
245
- """
246
- return type(self)({k: v for k, v in self.items() if k in keys})
247
-
248
- def intersection_by_values(self, values: Iterable, by: str = 'id'):
249
- """
250
- Get the intersection of the stack by the values.
251
-
252
- Args:
253
- values: The value ids.
254
- by: str. The discard method, can be ``id`` or ``value``. Default is 'id'.
255
- """
256
- if by == 'id':
257
- value_ids = {id(v) for v in values}
258
- return type(self)({k: v for k, v in self.items() if id(v) in value_ids})
259
- elif by == 'value':
260
- return type(self)({k: v for k, v in self.items() if v in values})
261
- else:
262
- raise ValueError(f'Unsupported method: {by}')
263
-
264
- def union_by_value_ids(self, other: dict):
265
- """
266
- Union the stack by the value ids.
267
-
268
- Args:
269
- other:
270
-
271
- Returns:
272
-
273
- """
274
-
275
- def __add__(self, other: dict):
276
- """
277
- Compose other instance of dict.
278
- """
279
- new_dict = type(self)(self)
280
- new_dict.update(other)
281
- return new_dict
282
-
283
- def tree_flatten(self):
284
- return tuple(self.values()), tuple(self.keys())
285
-
286
- @classmethod
287
- def tree_unflatten(cls, keys, values):
288
- return cls(jax.util.safe_zip(keys, values))
289
-
290
- def _check_elem(self, elem: Any):
291
- raise NotImplementedError
292
-
293
- def to_dict(self):
294
- """
295
- Convert the stack to a dict.
296
-
297
- Returns
298
- -------
299
- dict
300
- The dict object.
301
- """
302
- return dict(self)
303
-
304
- def __copy__(self):
305
- return type(self)(self)
306
-
307
-
308
- @set_module_as('brainstate.util')
309
- def clear_buffer_memory(
310
- platform: str = None,
311
- array: bool = True,
312
- compilation: bool = False,
313
- ):
314
- """Clear all on-device buffers.
315
-
316
- This function will be very useful when you call models in a Python loop,
317
- because it can clear all cached arrays, and clear device memory.
318
-
319
- .. warning::
320
-
321
- This operation may cause errors when you use a deleted buffer.
322
- Therefore, regenerate data always.
323
-
324
- Parameters
325
- ----------
326
- platform: str
327
- The device to clear its memory.
328
- array: bool
329
- Clear all buffer array. Default is True.
330
- compilation: bool
331
- Clear compilation cache. Default is False.
332
-
333
- """
334
- if array:
335
- for buf in xla_bridge.get_backend(platform).live_buffers():
336
- buf.delete()
337
- if compilation:
338
- jax.clear_caches()
339
- gc.collect()
340
-
341
-
342
- class MemScaling(object):
343
- """
344
- The scaling object for membrane potential.
345
-
346
- The scaling object is used to transform the membrane potential range to a
347
- standard range. The scaling object can be used to transform the membrane
348
- potential to a standard range, and transform the standard range to the
349
- membrane potential.
350
-
351
- """
352
- __module__ = 'brainstate.util'
353
-
354
- def __init__(self, scale, bias):
355
- self._scale = scale
356
- self._bias = bias
357
-
358
- @classmethod
359
- def transform(
360
- cls,
361
- oring_range: Sequence[Union[float, int]],
362
- target_range: Sequence[Union[float, int]] = (0., 1.)
363
- ) -> 'MemScaling':
364
- """Transform the membrane potential range to a ``Scaling`` instance.
365
-
366
- Args:
367
- oring_range: [V_min, V_max]
368
- target_range: [scaled_V_min, scaled_V_max]
369
-
370
- Returns:
371
- The instanced scaling object.
372
- """
373
- V_min, V_max = oring_range
374
- scaled_V_min, scaled_V_max = target_range
375
- scale = (V_max - V_min) / (scaled_V_max - scaled_V_min)
376
- bias = scaled_V_min * scale - V_min
377
- return cls(scale=scale, bias=bias)
378
-
379
- def scale_offset(self, x, bias=None, scale=None):
380
- """
381
- Transform the membrane potential to the standard range.
382
-
383
- Parameters
384
- ----------
385
- x : array_like
386
- The membrane potential.
387
- bias : float, optional
388
- The bias of the scaling object. If None, the default bias will be used.
389
- scale : float, optional
390
- The scale of the scaling object. If None, the default scale will be used.
391
-
392
- Returns
393
- -------
394
- x : array_like
395
- The standard range of the membrane potential.
396
- """
397
- if bias is None:
398
- bias = self._bias
399
- if scale is None:
400
- scale = self._scale
401
- return (x + bias) / scale
402
-
403
- def scale(self, x, scale=None):
404
- """
405
- Transform the membrane potential to the standard range.
406
-
407
- Parameters
408
- ----------
409
- x : array_like
410
- The membrane potential.
411
- scale : float, optional
412
- The scale of the scaling object. If None, the default scale will be used.
413
-
414
- Returns
415
- -------
416
- x : array_like
417
- The standard range of the membrane potential.
418
- """
419
- if scale is None:
420
- scale = self._scale
421
- return x / scale
422
-
423
- def offset(self, x, bias=None):
424
- """
425
- Transform the membrane potential to the standard range.
426
-
427
- Parameters
428
- ----------
429
- x : array_like
430
- The membrane potential.
431
- bias : float, optional
432
- The bias of the scaling object. If None, the default bias will be used.
433
-
434
- Returns
435
- -------
436
- x : array_like
437
- The standard range of the membrane potential.
438
- """
439
- if bias is None:
440
- bias = self._bias
441
- return x + bias
442
-
443
- def rev_scale(self, x, scale=None):
444
- """
445
- Reversely transform the standard range to the original membrane potential.
446
-
447
- Parameters
448
- ----------
449
- x : array_like
450
- The standard range of the membrane potential.
451
- scale : float, optional
452
- The scale of the scaling object. If None, the default scale will be used.
453
-
454
- Returns
455
- -------
456
- x : array_like
457
- The original membrane potential.
458
- """
459
- if scale is None:
460
- scale = self._scale
461
- return x * scale
462
-
463
- def rev_offset(self, x, bias=None):
464
- """
465
- Reversely transform the standard range to the original membrane potential.
466
-
467
- Parameters
468
- ----------
469
- x : array_like
470
- The standard range of the membrane potential.
471
- bias : float, optional
472
- The bias of the scaling object. If None, the default bias will be used.
473
-
474
- Returns
475
- -------
476
- x : array_like
477
- The original membrane potential.
478
- """
479
- if bias is None:
480
- bias = self._bias
481
- return x - bias
482
-
483
- def rev_scale_offset(self, x, bias=None, scale=None):
484
- """
485
- Reversely transform the standard range to the original membrane potential.
486
-
487
- Parameters
488
- ----------
489
- x : array_like
490
- The standard range of the membrane potential.
491
- bias : float, optional
492
- The bias of the scaling object. If None, the default bias will be used.
493
- scale : float, optional
494
- The scale of the scaling object. If None, the default scale will be used.
495
-
496
- Returns
497
- -------
498
- x : array_like
499
- The original membrane potential.
500
- """
501
- if bias is None:
502
- bias = self._bias
503
- if scale is None:
504
- scale = self._scale
505
- return x * scale - bias
506
-
507
- def clone(self):
508
- """
509
- Clone the scaling object.
510
-
511
- Returns
512
- -------
513
- scaling : MemScaling
514
- The cloned scaling object.
515
- """
516
- return MemScaling(bias=self._bias, scale=self._scale)
517
-
518
-
519
- class IdMemScaling(MemScaling):
520
- """
521
- The identity scaling object.
522
-
523
- The identity scaling object is used to transform the membrane potential to
524
- the standard range, and reversely transform the standard range to the
525
- membrane potential.
526
-
527
- """
528
- __module__ = 'brainstate.util'
529
-
530
- def __init__(self):
531
- super().__init__(scale=1., bias=0.)
532
-
533
- def scale_offset(self, x, bias=None, scale=None):
534
- """
535
- Transform the membrane potential to the standard range.
536
- """
537
- return x
538
-
539
- def scale(self, x, scale=None):
540
- """
541
- Transform the membrane potential to the standard range.
542
- """
543
- return x
544
-
545
- def offset(self, x, bias=None):
546
- """
547
- Transform the membrane potential to the standard range.
548
- """
549
- return x
550
-
551
- def rev_scale(self, x, scale=None):
552
- """
553
- Reversely transform the standard range to the original membrane potential.
554
-
555
- """
556
- return x
557
-
558
- def rev_offset(self, x, bias=None):
559
- """
560
- Reversely transform the standard range to the original membrane potential.
561
-
562
-
563
- """
564
- return x
565
-
566
- def rev_scale_offset(self, x, bias=None, scale=None):
567
- """
568
- Reversely transform the standard range to the original membrane potential.
569
- """
570
- return x
571
-
572
- def clone(self):
573
- """
574
- Clone the scaling object.
575
- """
576
- return IdMemScaling()
577
-
578
-
579
- @jax.tree_util.register_pytree_node_class
580
- class DotDict(dict):
581
- """Python dictionaries with advanced dot notation access.
582
-
583
- For example:
584
-
585
- >>> d = DotDict({'a': 10, 'b': 20})
586
- >>> d.a
587
- 10
588
- >>> d['a']
589
- 10
590
- >>> d.c # this will raise a KeyError
591
- KeyError: 'c'
592
- >>> d.c = 30 # but you can assign a value to a non-existing item
593
- >>> d.c
594
- 30
595
- """
596
-
597
- __module__ = 'brainstate.util'
598
-
599
- def __init__(self, *args, **kwargs):
600
- object.__setattr__(self, '__parent', kwargs.pop('__parent', None))
601
- object.__setattr__(self, '__key', kwargs.pop('__key', None))
602
- for arg in args:
603
- if not arg:
604
- continue
605
- elif isinstance(arg, dict):
606
- for key, val in arg.items():
607
- self[key] = self._hook(val)
608
- elif isinstance(arg, tuple) and (not isinstance(arg[0], tuple)):
609
- self[arg[0]] = self._hook(arg[1])
610
- else:
611
- for key, val in iter(arg):
612
- self[key] = self._hook(val)
613
-
614
- for key, val in kwargs.items():
615
- self[key] = self._hook(val)
616
-
617
- def __setattr__(self, name, value):
618
- if hasattr(self.__class__, name):
619
- raise AttributeError(f"Attribute '{name}' is read-only in '{type(self)}' object.")
620
- else:
621
- self[name] = value
622
-
623
- def __setitem__(self, name, value):
624
- super(DotDict, self).__setitem__(name, value)
625
- try:
626
- p = object.__getattribute__(self, '__parent')
627
- key = object.__getattribute__(self, '__key')
628
- except AttributeError:
629
- p = None
630
- key = None
631
- if p is not None:
632
- p[key] = self
633
- object.__delattr__(self, '__parent')
634
- object.__delattr__(self, '__key')
635
-
636
- @classmethod
637
- def _hook(cls, item):
638
- if isinstance(item, dict):
639
- return cls(item)
640
- elif isinstance(item, (list, tuple)):
641
- return type(item)(cls._hook(elem) for elem in item)
642
- return item
643
-
644
- def __getattr__(self, item):
645
- return self.__getitem__(item)
646
-
647
- def __delattr__(self, name):
648
- del self[name]
649
-
650
- def copy(self):
651
- return copy.copy(self)
652
-
653
- def deepcopy(self):
654
- return copy.deepcopy(self)
655
-
656
- def __deepcopy__(self, memo):
657
- other = self.__class__()
658
- memo[id(self)] = other
659
- for key, value in self.items():
660
- other[copy.deepcopy(key, memo)] = copy.deepcopy(value, memo)
661
- return other
662
-
663
- def to_dict(self):
664
- base = {}
665
- for key, value in self.items():
666
- if isinstance(value, type(self)):
667
- base[key] = value.to_dict()
668
- elif isinstance(value, (list, tuple)):
669
- base[key] = type(value)(item.to_dict() if isinstance(item, type(self)) else item
670
- for item in value)
671
- else:
672
- base[key] = value
673
- return base
674
-
675
- def update(self, *args, **kwargs):
676
- other = {}
677
- if args:
678
- if len(args) > 1:
679
- raise TypeError()
680
- other.update(args[0])
681
- other.update(kwargs)
682
- for k, v in other.items():
683
- if (k not in self) or (not isinstance(self[k], dict)) or (not isinstance(v, dict)):
684
- self[k] = v
685
- else:
686
- self[k].update(v)
687
-
688
- def __getnewargs__(self):
689
- return tuple(self.items())
690
-
691
- def __getstate__(self):
692
- return self
693
-
694
- def __setstate__(self, state):
695
- self.update(state)
696
-
697
- def setdefault(self, key, default=None):
698
- if key in self:
699
- return self[key]
700
- else:
701
- self[key] = default
702
- return default
703
-
704
- def tree_flatten(self):
705
- return tuple(self.values()), tuple(self.keys())
706
-
707
- @classmethod
708
- def tree_unflatten(cls, keys, values):
709
- return cls(jax.util.safe_zip(keys, values))
710
-
711
-
712
- def _is_not_instance(x, cls):
713
- return not isinstance(x, cls)
714
-
715
-
716
- def _is_instance(x, cls):
717
- return isinstance(x, cls)
718
-
719
-
720
- @set_module_as('brainstate.util')
721
- def not_instance_eval(*cls):
722
- """
723
- Create a partial function to evaluate if the input is not an instance of the given class.
724
-
725
- Args:
726
- *cls: The classes to check.
727
-
728
- Returns:
729
- The partial function.
730
-
731
- """
732
- return functools.partial(_is_not_instance, cls=cls)
733
-
734
-
735
- @set_module_as('brainstate.util')
736
- def is_instance_eval(*cls):
737
- """
738
- Create a partial function to evaluate if the input is an instance of the given class.
739
-
740
- Args:
741
- *cls: The classes to check.
742
-
743
- Returns:
744
- The partial function.
745
- """
746
- return functools.partial(_is_instance, cls=cls)