brainstate 0.0.2.post20241009__py2.py3-none-any.whl → 0.1.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (175)
  1. brainstate/__init__.py +31 -11
  2. brainstate/_state.py +760 -316
  3. brainstate/_state_test.py +41 -12
  4. brainstate/_utils.py +31 -4
  5. brainstate/augment/__init__.py +40 -0
  6. brainstate/augment/_autograd.py +608 -0
  7. brainstate/augment/_autograd_test.py +1193 -0
  8. brainstate/augment/_eval_shape.py +102 -0
  9. brainstate/augment/_eval_shape_test.py +40 -0
  10. brainstate/augment/_mapping.py +525 -0
  11. brainstate/augment/_mapping_test.py +210 -0
  12. brainstate/augment/_random.py +99 -0
  13. brainstate/{transform → compile}/__init__.py +25 -13
  14. brainstate/compile/_ad_checkpoint.py +204 -0
  15. brainstate/compile/_ad_checkpoint_test.py +51 -0
  16. brainstate/compile/_conditions.py +259 -0
  17. brainstate/compile/_conditions_test.py +221 -0
  18. brainstate/compile/_error_if.py +94 -0
  19. brainstate/compile/_error_if_test.py +54 -0
  20. brainstate/compile/_jit.py +314 -0
  21. brainstate/compile/_jit_test.py +143 -0
  22. brainstate/compile/_loop_collect_return.py +516 -0
  23. brainstate/compile/_loop_collect_return_test.py +59 -0
  24. brainstate/compile/_loop_no_collection.py +185 -0
  25. brainstate/compile/_loop_no_collection_test.py +51 -0
  26. brainstate/compile/_make_jaxpr.py +756 -0
  27. brainstate/compile/_make_jaxpr_test.py +134 -0
  28. brainstate/compile/_progress_bar.py +111 -0
  29. brainstate/compile/_unvmap.py +159 -0
  30. brainstate/compile/_util.py +147 -0
  31. brainstate/environ.py +408 -381
  32. brainstate/environ_test.py +34 -32
  33. brainstate/{nn/event → event}/__init__.py +6 -6
  34. brainstate/event/_csr.py +308 -0
  35. brainstate/event/_csr_test.py +118 -0
  36. brainstate/event/_fixed_probability.py +271 -0
  37. brainstate/event/_fixed_probability_test.py +128 -0
  38. brainstate/event/_linear.py +219 -0
  39. brainstate/event/_linear_test.py +112 -0
  40. brainstate/{nn/event → event}/_misc.py +7 -7
  41. brainstate/functional/_activations.py +521 -511
  42. brainstate/functional/_activations_test.py +300 -300
  43. brainstate/functional/_normalization.py +43 -43
  44. brainstate/functional/_others.py +15 -15
  45. brainstate/functional/_spikes.py +49 -49
  46. brainstate/graph/__init__.py +33 -0
  47. brainstate/graph/_graph_context.py +443 -0
  48. brainstate/graph/_graph_context_test.py +65 -0
  49. brainstate/graph/_graph_convert.py +246 -0
  50. brainstate/graph/_graph_node.py +300 -0
  51. brainstate/graph/_graph_node_test.py +75 -0
  52. brainstate/graph/_graph_operation.py +1746 -0
  53. brainstate/graph/_graph_operation_test.py +724 -0
  54. brainstate/init/_base.py +28 -10
  55. brainstate/init/_generic.py +175 -172
  56. brainstate/init/_random_inits.py +470 -415
  57. brainstate/init/_random_inits_test.py +150 -0
  58. brainstate/init/_regular_inits.py +66 -69
  59. brainstate/init/_regular_inits_test.py +51 -0
  60. brainstate/mixin.py +236 -244
  61. brainstate/mixin_test.py +44 -46
  62. brainstate/nn/__init__.py +26 -51
  63. brainstate/nn/_collective_ops.py +199 -0
  64. brainstate/nn/_dyn_impl/__init__.py +46 -0
  65. brainstate/nn/_dyn_impl/_dynamics_neuron.py +290 -0
  66. brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +162 -0
  67. brainstate/nn/_dyn_impl/_dynamics_synapse.py +320 -0
  68. brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +132 -0
  69. brainstate/nn/_dyn_impl/_inputs.py +154 -0
  70. brainstate/nn/{_projection/__init__.py → _dyn_impl/_projection_alignpost.py} +6 -13
  71. brainstate/nn/_dyn_impl/_rate_rnns.py +400 -0
  72. brainstate/nn/_dyn_impl/_rate_rnns_test.py +64 -0
  73. brainstate/nn/_dyn_impl/_readout.py +128 -0
  74. brainstate/nn/_dyn_impl/_readout_test.py +54 -0
  75. brainstate/nn/_dynamics/__init__.py +37 -0
  76. brainstate/nn/_dynamics/_dynamics_base.py +631 -0
  77. brainstate/nn/_dynamics/_dynamics_base_test.py +79 -0
  78. brainstate/nn/_dynamics/_projection_base.py +346 -0
  79. brainstate/nn/_dynamics/_state_delay.py +453 -0
  80. brainstate/nn/_dynamics/_synouts.py +161 -0
  81. brainstate/nn/_dynamics/_synouts_test.py +58 -0
  82. brainstate/nn/_elementwise/__init__.py +22 -0
  83. brainstate/nn/_elementwise/_dropout.py +418 -0
  84. brainstate/nn/_elementwise/_dropout_test.py +100 -0
  85. brainstate/nn/_elementwise/_elementwise.py +1122 -0
  86. brainstate/nn/_elementwise/_elementwise_test.py +171 -0
  87. brainstate/nn/_exp_euler.py +97 -0
  88. brainstate/nn/_exp_euler_test.py +36 -0
  89. brainstate/nn/_interaction/__init__.py +32 -0
  90. brainstate/nn/_interaction/_connections.py +726 -0
  91. brainstate/nn/_interaction/_connections_test.py +254 -0
  92. brainstate/nn/_interaction/_embedding.py +59 -0
  93. brainstate/nn/_interaction/_normalizations.py +388 -0
  94. brainstate/nn/_interaction/_normalizations_test.py +75 -0
  95. brainstate/nn/_interaction/_poolings.py +1179 -0
  96. brainstate/nn/_interaction/_poolings_test.py +219 -0
  97. brainstate/nn/_module.py +328 -0
  98. brainstate/nn/_module_test.py +211 -0
  99. brainstate/nn/metrics.py +309 -309
  100. brainstate/optim/__init__.py +14 -2
  101. brainstate/optim/_base.py +66 -0
  102. brainstate/optim/_lr_scheduler.py +363 -400
  103. brainstate/optim/_lr_scheduler_test.py +25 -24
  104. brainstate/optim/_optax_optimizer.py +103 -176
  105. brainstate/optim/_optax_optimizer_test.py +41 -1
  106. brainstate/optim/_sgd_optimizer.py +950 -1025
  107. brainstate/random/_rand_funs.py +3269 -3268
  108. brainstate/random/_rand_funs_test.py +568 -0
  109. brainstate/random/_rand_seed.py +149 -117
  110. brainstate/random/_rand_seed_test.py +50 -0
  111. brainstate/random/_rand_state.py +1360 -1318
  112. brainstate/random/_random_for_unit.py +13 -13
  113. brainstate/surrogate.py +1262 -1243
  114. brainstate/{nn/_projection/_utils.py → transform.py} +1 -2
  115. brainstate/typing.py +157 -130
  116. brainstate/util/__init__.py +52 -0
  117. brainstate/util/_caller.py +100 -0
  118. brainstate/util/_dict.py +734 -0
  119. brainstate/util/_dict_test.py +160 -0
  120. brainstate/util/_error.py +28 -0
  121. brainstate/util/_filter.py +178 -0
  122. brainstate/util/_others.py +497 -0
  123. brainstate/util/_pretty_repr.py +208 -0
  124. brainstate/util/_scaling.py +260 -0
  125. brainstate/util/_struct.py +524 -0
  126. brainstate/util/_tracers.py +75 -0
  127. brainstate/{_visualization.py → util/_visualization.py} +16 -16
  128. {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/METADATA +11 -11
  129. brainstate-0.1.0.dist-info/RECORD +135 -0
  130. brainstate/_module.py +0 -1637
  131. brainstate/_module_test.py +0 -207
  132. brainstate/nn/_base.py +0 -251
  133. brainstate/nn/_connections.py +0 -686
  134. brainstate/nn/_dynamics.py +0 -426
  135. brainstate/nn/_elementwise.py +0 -1438
  136. brainstate/nn/_embedding.py +0 -66
  137. brainstate/nn/_misc.py +0 -133
  138. brainstate/nn/_normalizations.py +0 -389
  139. brainstate/nn/_others.py +0 -101
  140. brainstate/nn/_poolings.py +0 -1229
  141. brainstate/nn/_poolings_test.py +0 -231
  142. brainstate/nn/_projection/_align_post.py +0 -546
  143. brainstate/nn/_projection/_align_pre.py +0 -599
  144. brainstate/nn/_projection/_delta.py +0 -241
  145. brainstate/nn/_projection/_vanilla.py +0 -101
  146. brainstate/nn/_rate_rnns.py +0 -410
  147. brainstate/nn/_readout.py +0 -136
  148. brainstate/nn/_synouts.py +0 -166
  149. brainstate/nn/event/csr.py +0 -312
  150. brainstate/nn/event/csr_test.py +0 -118
  151. brainstate/nn/event/fixed_probability.py +0 -276
  152. brainstate/nn/event/fixed_probability_test.py +0 -127
  153. brainstate/nn/event/linear.py +0 -220
  154. brainstate/nn/event/linear_test.py +0 -111
  155. brainstate/random/random_test.py +0 -593
  156. brainstate/transform/_autograd.py +0 -585
  157. brainstate/transform/_autograd_test.py +0 -1181
  158. brainstate/transform/_conditions.py +0 -334
  159. brainstate/transform/_conditions_test.py +0 -220
  160. brainstate/transform/_error_if.py +0 -94
  161. brainstate/transform/_error_if_test.py +0 -55
  162. brainstate/transform/_jit.py +0 -265
  163. brainstate/transform/_jit_test.py +0 -118
  164. brainstate/transform/_loop_collect_return.py +0 -502
  165. brainstate/transform/_loop_no_collection.py +0 -170
  166. brainstate/transform/_make_jaxpr.py +0 -739
  167. brainstate/transform/_make_jaxpr_test.py +0 -131
  168. brainstate/transform/_mapping.py +0 -109
  169. brainstate/transform/_progress_bar.py +0 -111
  170. brainstate/transform/_unvmap.py +0 -143
  171. brainstate/util.py +0 -746
  172. brainstate-0.0.2.post20241009.dist-info/RECORD +0 -87
  173. {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/LICENSE +0 -0
  174. {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/WHEEL +0 -0
  175. {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/top_level.txt +0 -0
brainstate/util.py DELETED
@@ -1,746 +0,0 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- import copy
17
- import functools
18
- import gc
19
- import types
20
- import warnings
21
- from collections.abc import Iterable
22
- from typing import Any, Callable, Tuple, Union, Sequence
23
-
24
- import jax
25
- from jax.lib import xla_bridge
26
-
27
- from ._utils import set_module_as
28
-
29
__all__ = [
    'unique_name',
    'clear_name_cache',
    'clear_buffer_memory',
    'not_instance_eval',
    'is_instance_eval',
    'DictManager',
    'MemScaling',
    'IdMemScaling',
    'DotDict',
]

# Registry of explicitly requested names -> ``id()`` of the object that first
# claimed the name; consulted by ``check_name_uniqueness`` to reject duplicates.
_name2id = dict()

# Per-type counters used by ``get_unique_name`` to generate names of the form
# ``f'{type_}{index}'`` (``'Foo0'``, ``'Foo1'``, ...).
_typed_names = {}
42
-
43
-
44
@set_module_as('brainstate.util')
def check_name_uniqueness(name, obj):
    """Check that ``name`` is a valid identifier not already used by another object.

    The first call with a given ``name`` records ``id(obj)`` in the module-level
    name registry. Later calls with the same ``name`` succeed only when they are
    made for that same object.

    Parameters
    ----------
    name : str
        The candidate name.
    obj : Any
        The object that wants to use ``name``.

    Raises
    ------
    ValueError
        If ``name`` is not a valid Python identifier, or if ``name`` has already
        been registered by a different object.
    """
    if not name.isidentifier():
        raise ValueError(
            f'"{name}" isn\'t a valid identifier '
            f'according to Python language definition. '
            f'Please choose another name.'
        )
    if name in _name2id:
        if _name2id[name] != id(obj):
            # Point users at the helper that actually exists in this package.
            raise ValueError(
                f'In brainstate, each object should have a unique name. '
                f'However, we detect that {obj} has a used name "{name}". \n'
                f'If you try to run multiple trials, you may need \n\n'
                f'>>> brainstate.util.clear_name_cache() \n\n'
                f'to clear all cached names. '
            )
    else:
        _name2id[name] = id(obj)
64
-
65
-
66
def get_unique_name(type_: str):
    """Return the next auto-generated name for the object type ``type_``.

    Names are ``f'{type_}{index}'``, where ``index`` starts at 0 and grows by
    one on every call for the same ``type_``. The counter lives in the
    module-level ``_typed_names`` registry.
    """
    index = _typed_names.setdefault(type_, 0)
    _typed_names[type_] = index + 1
    return f'{type_}{index}'
73
-
74
-
75
@set_module_as('brainstate.util')
def unique_name(name=None, self=None):
    """Get a unique name for an object.

    Parameters
    ----------
    name : str, optional
        The requested name. If None, a fresh name is generated from the class
        name of ``self`` (``f'{ClassName}{index}'``). Otherwise ``name`` is
        validated and registered for ``self`` via ``check_name_uniqueness``.
    self : Any, optional
        The object being named (not a string). Required when ``name`` is None.

    Returns
    -------
    name : str
        The unique name for this object.

    Raises
    ------
    ValueError
        If both ``name`` and ``self`` are None, if ``name`` is not a valid
        identifier, or if ``name`` is already used by another object.
    """
    if name is None:
        # An explicit check rather than ``assert``: assertions are stripped
        # under ``python -O``, after which ``None`` would silently be named
        # ``'NoneType<index>'``.
        if self is None:
            raise ValueError('If name is None, self should be provided.')
        return get_unique_name(type_=self.__class__.__name__)
    check_name_uniqueness(name=name, obj=self)
    return name
98
-
99
-
100
@set_module_as('brainstate.util')
def clear_name_cache(ignore_warn: bool = True):
    """Forget every registered and auto-generated name.

    Empties both the explicit-name registry (used by ``check_name_uniqueness``)
    and the per-type counters (used by ``get_unique_name``).

    Parameters
    ----------
    ignore_warn : bool
        When False, emit a ``UserWarning`` after clearing. Default is True.
    """
    for registry in (_name2id, _typed_names):
        registry.clear()
    if ignore_warn:
        return
    warnings.warn('All named models and their ids are cleared.', UserWarning)
107
-
108
-
109
def _as_predicate(sep):
    """Turn a DictManager filter into a one-argument predicate ``pred(value) -> bool``.

    A type, or anything that is not callable (e.g. a tuple of types), is tested
    with ``isinstance``. Any other callable (plain functions, lambdas,
    ``functools.partial`` objects such as those returned by ``is_instance_eval``,
    builtins, ...) is used directly as the predicate.
    """
    # Classes are callable too, so they must be recognized before the callable check.
    if isinstance(sep, type) or not callable(sep):
        return lambda v: isinstance(v, sep)
    return sep


def _as_key_lookup(keys):
    """Materialize ``keys`` into a container that supports repeated ``in`` tests.

    ``keys`` may be a one-shot iterable (e.g. a generator); testing membership
    against it directly would consume it. A ``frozenset`` is returned when every
    element is hashable (O(1) lookups); otherwise a ``tuple``.
    """
    keys = tuple(keys)
    try:
        return frozenset(keys)
    except TypeError:
        return keys


@jax.tree_util.register_pytree_node_class
class DictManager(dict):
    """
    DictManager, for collecting all pytree used in the program.

    :py:class:`~.DictManager` supports all features of python dict, plus helpers
    for filtering, splitting and set-like operations on its entries. It is
    registered as a JAX pytree: the values are the children and the keys are the
    auxiliary data.

    Wherever a method takes a filter (``sep``, ``first``, ``others``), the filter
    may be a type or a tuple of types (checked with ``isinstance``), or any other
    callable used as a predicate ``filter(value) -> bool``.
    """
    __module__ = 'brainstate.util'

    def subset(self, sep: Union[type, Tuple[type, ...], Callable]) -> 'DictManager':
        """
        Get a new manager holding the entries whose value matches ``sep``.
        """
        match = _as_predicate(sep)
        gather = type(self)()
        for k, v in self.items():
            if match(v):
                gather[k] = v
        return gather

    def not_subset(self, sep: Union[type, Tuple[type, ...], Callable]) -> 'DictManager':
        """
        Get a new manager holding the entries whose value does NOT match ``sep``.
        """
        match = _as_predicate(sep)
        gather = type(self)()
        for k, v in self.items():
            if not match(v):
                gather[k] = v
        return gather

    def add_unique_elem(self, key: Any, var: Any):
        """Add ``var`` under ``key``.

        The value is first validated with ``_check_elem``. Re-adding the very
        same object under an existing key is a no-op.

        Raises:
          ValueError: If ``key`` is already bound to a different object.
        """
        self._check_elem(var)
        if key in self:
            if id(var) != id(self[key]):
                raise ValueError(f'{key} has been registered by {self[key]}, the new value is different from it.')
        else:
            self[key] = var

    def unique(self) -> 'DictManager':
        """
        Get a new type of collections with unique values.

        If one value (by identity) is assigned to two or more keys,
        then only the first (key, value) pair in iteration order is kept.
        """
        gather = type(self)()
        seen = set()
        for k, v in self.items():
            if id(v) not in seen:
                seen.add(id(v))
                gather[k] = v
        return gather

    def assign(self, *args) -> None:
        """
        Assign every (key, value) pair of each given dict into this manager.
        """
        for arg in args:
            assert isinstance(arg, dict), 'Must be an instance of dict.'
            for k, v in arg.items():
                self[k] = v

    def split(self, first: Union[type, Callable], *others: Union[type, Callable]) -> Tuple['DictManager', ...]:
        """
        Split the manager into subsets by the given filters.

        Returns ``len(filters) + 1`` managers: entry ``i`` holds the items whose
        value matches filter ``i`` first (earlier filters win); the last one
        holds the items that matched no filter.
        """
        filters = tuple(_as_predicate(f) for f in (first, *others))
        results = tuple(type(self)() for _ in range(len(filters) + 1))
        for k, v in self.items():
            for i, match in enumerate(filters):
                if match(v):
                    results[i][k] = v
                    break
            else:
                results[-1][k] = v
        return results

    def pop_by_keys(self, keys: Iterable):
        """
        Pop the elements whose key is in ``keys`` (in place).
        """
        keys = _as_key_lookup(keys)
        for k in tuple(self.keys()):
            if k in keys:
                self.pop(k)

    def pop_by_values(self, values: Iterable, by: str = 'id'):
        """
        Pop the elements by the values (in place).

        Args:
          values: The values to remove.
          by: str. The discard method, can be ``id`` (identity) or ``value``
            (equality). Default is 'id'.

        Raises:
          ValueError: If ``by`` is not ``'id'`` or ``'value'``.
        """
        if by == 'id':
            value_ids = {id(v) for v in values}
            for k in tuple(self.keys()):
                if id(self[k]) in value_ids:
                    self.pop(k)
        elif by == 'value':
            values = tuple(values)  # tolerate one-shot iterables
            for k in tuple(self.keys()):
                if self[k] in values:
                    self.pop(k)
        else:
            raise ValueError(f'Unsupported method: {by}')

    def difference_by_keys(self, keys: Iterable):
        """
        Get a new manager without the entries whose key is in ``keys``.
        """
        keys = _as_key_lookup(keys)
        return type(self)({k: v for k, v in self.items() if k not in keys})

    def difference_by_values(self, values: Iterable, by: str = 'id'):
        """
        Get a new manager without the entries whose value is in ``values``.

        Args:
          values: The values to exclude.
          by: str. The comparison method, can be ``id`` or ``value``. Default is 'id'.

        Raises:
          ValueError: If ``by`` is not ``'id'`` or ``'value'``.
        """
        if by == 'id':
            value_ids = {id(v) for v in values}
            return type(self)({k: v for k, v in self.items() if id(v) not in value_ids})
        elif by == 'value':
            values = tuple(values)  # tolerate one-shot iterables
            return type(self)({k: v for k, v in self.items() if v not in values})
        else:
            raise ValueError(f'Unsupported method: {by}')

    def intersection_by_keys(self, keys: Iterable):
        """
        Get a new manager with only the entries whose key is in ``keys``.
        """
        keys = _as_key_lookup(keys)
        return type(self)({k: v for k, v in self.items() if k in keys})

    def intersection_by_values(self, values: Iterable, by: str = 'id'):
        """
        Get a new manager with only the entries whose value is in ``values``.

        Args:
          values: The values to keep.
          by: str. The comparison method, can be ``id`` or ``value``. Default is 'id'.

        Raises:
          ValueError: If ``by`` is not ``'id'`` or ``'value'``.
        """
        if by == 'id':
            value_ids = {id(v) for v in values}
            return type(self)({k: v for k, v in self.items() if id(v) in value_ids})
        elif by == 'value':
            values = tuple(values)  # tolerate one-shot iterables
            return type(self)({k: v for k, v in self.items() if v in values})
        else:
            raise ValueError(f'Unsupported method: {by}')

    def union_by_value_ids(self, other: dict):
        """
        Union the manager with ``other``, de-duplicating by value identity.

        All entries of ``self`` are kept. An entry ``(k, v)`` of ``other`` is
        added only if the very same object ``v`` (compared by ``id``) is not
        already a value of the result. As with ``__add__``, an added entry
        overwrites an existing entry with the same key.

        Args:
          other: The dict whose entries are merged in.

        Returns:
          A new instance of ``type(self)``; neither input is modified.
        """
        result = type(self)(self)
        seen = {id(v) for v in result.values()}
        for k, v in other.items():
            if id(v) not in seen:
                seen.add(id(v))
                result[k] = v
        return result

    def __add__(self, other: dict):
        """
        Compose other instance of dict; entries of ``other`` win on key clashes.
        """
        new_dict = type(self)(self)
        new_dict.update(other)
        return new_dict

    def tree_flatten(self):
        # Values are the pytree children; keys are kept as auxiliary data.
        return tuple(self.values()), tuple(self.keys())

    @classmethod
    def tree_unflatten(cls, keys, values):
        keys, values = tuple(keys), tuple(values)
        if len(keys) != len(values):
            raise ValueError(f'Length mismatch: {len(keys)} keys but {len(values)} values.')
        return cls(zip(keys, values))

    def _check_elem(self, elem: Any):
        # Validation hook used by ``add_unique_elem``; subclasses must override it.
        raise NotImplementedError

    def to_dict(self):
        """
        Convert the manager to a plain dict (shallow copy).

        Returns
        -------
        dict
            The dict object.
        """
        return dict(self)

    def __copy__(self):
        return type(self)(self)
306
-
307
-
308
@set_module_as('brainstate.util')
def clear_buffer_memory(
    platform: str = None,
    array: bool = True,
    compilation: bool = False,
):
    """Clear all on-device buffers.

    This function is useful when you call models in a Python loop, because it
    deletes all live JAX arrays and frees the device memory they hold.

    .. warning::

       This operation may cause errors when you use a deleted buffer.
       Therefore, regenerate data always.

    Parameters
    ----------
    platform: str
        The platform (backend) whose live arrays are deleted. None means the
        default backend.
    array: bool
        Delete all live arrays. Default is True.
    compilation: bool
        Clear JAX's compilation caches. Default is False.

    """
    if array:
        # ``jax.live_arrays`` is the public replacement for the removed
        # ``xla_bridge.get_backend(platform).live_buffers()`` API.
        for arr in jax.live_arrays(platform):
            arr.delete()
    if compilation:
        jax.clear_caches()
    gc.collect()
340
-
341
-
342
class MemScaling:
    """
    Affine scaling between a membrane-potential range and a standard range.

    The forward transform is ``(x + bias) / scale`` and the reverse transform
    is ``x * scale - bias``. The pieces ``scale``/``offset`` and their
    ``rev_*`` counterparts apply one half of that mapping each.

    """
    __module__ = 'brainstate.util'

    def __init__(self, scale, bias):
        self._scale, self._bias = scale, bias

    @classmethod
    def transform(
        cls,
        oring_range: Sequence[Union[float, int]],
        target_range: Sequence[Union[float, int]] = (0., 1.)
    ) -> 'MemScaling':
        """Build a scaling that maps ``oring_range`` onto ``target_range``.

        Args:
          oring_range: [V_min, V_max]
          target_range: [scaled_V_min, scaled_V_max]

        Returns:
          The instanced scaling object.
        """
        lo, hi = oring_range
        target_lo, target_hi = target_range
        factor = (hi - lo) / (target_hi - target_lo)
        return cls(scale=factor, bias=target_lo * factor - lo)

    def scale_offset(self, x, bias=None, scale=None):
        """
        Map the membrane potential into the standard range: ``(x + bias) / scale``.

        ``bias``/``scale`` default to the values stored on the object when None.
        """
        b = self._bias if bias is None else bias
        s = self._scale if scale is None else scale
        return (x + b) / s

    def scale(self, x, scale=None):
        """
        Divide ``x`` by the scale (the stored scale when ``scale`` is None).
        """
        s = self._scale if scale is None else scale
        return x / s

    def offset(self, x, bias=None):
        """
        Add the bias to ``x`` (the stored bias when ``bias`` is None).
        """
        b = self._bias if bias is None else bias
        return x + b

    def rev_scale(self, x, scale=None):
        """
        Multiply ``x`` by the scale (the stored scale when ``scale`` is None).
        """
        s = self._scale if scale is None else scale
        return x * s

    def rev_offset(self, x, bias=None):
        """
        Subtract the bias from ``x`` (the stored bias when ``bias`` is None).
        """
        b = self._bias if bias is None else bias
        return x - b

    def rev_scale_offset(self, x, bias=None, scale=None):
        """
        Map a standard-range value back to the membrane potential: ``x * scale - bias``.

        ``bias``/``scale`` default to the values stored on the object when None.
        """
        b = self._bias if bias is None else bias
        s = self._scale if scale is None else scale
        return x * s - b

    def clone(self):
        """
        Return a new ``MemScaling`` carrying the same scale and bias.
        """
        return MemScaling(scale=self._scale, bias=self._bias)
517
-
518
-
519
class IdMemScaling(MemScaling):
    """
    Identity membrane-potential scaling.

    It is constructed with ``scale=1.`` and ``bias=0.``. Every forward and
    reverse transform below ignores its optional arguments and returns ``x``
    unchanged.

    """
    __module__ = 'brainstate.util'

    def __init__(self):
        super(IdMemScaling, self).__init__(1., 0.)

    def scale_offset(self, x, bias=None, scale=None):
        """Identity forward scale-and-offset: returns ``x`` as-is."""
        return x

    def scale(self, x, scale=None):
        """Identity forward scale: returns ``x`` as-is."""
        return x

    def offset(self, x, bias=None):
        """Identity forward offset: returns ``x`` as-is."""
        return x

    def rev_scale(self, x, scale=None):
        """Identity reverse scale: returns ``x`` as-is."""
        return x

    def rev_offset(self, x, bias=None):
        """Identity reverse offset: returns ``x`` as-is."""
        return x

    def rev_scale_offset(self, x, bias=None, scale=None):
        """Identity reverse scale-and-offset: returns ``x`` as-is."""
        return x

    def clone(self):
        """Return a fresh ``IdMemScaling``."""
        return IdMemScaling()
577
-
578
-
579
class _DotDictMissingKey(KeyError, AttributeError):
    """Raised when a missing key of a ``DotDict`` is accessed or deleted as an attribute.

    It is a ``KeyError``, the documented ``DotDict`` behavior. It is also an
    ``AttributeError``, so ``hasattr(d, name)``, ``getattr(d, name, default)``
    and other attribute-probing code work instead of leaking a ``KeyError``.
    """


@jax.tree_util.register_pytree_node_class
class DotDict(dict):
    """Python dictionaries with advanced dot notation access.

    For example:

    >>> d = DotDict({'a': 10, 'b': 20})
    >>> d.a
    10
    >>> d['a']
    10
    >>> d.c  # this will raise a KeyError
    KeyError: 'c'
    >>> d.c = 30  # but you can assign a value to a non-existing item
    >>> d.c
    30
    """

    __module__ = 'brainstate.util'

    def __init__(self, *args, **kwargs):
        # Optional back-link: when given, the first item assignment stores this
        # dict into ``__parent[__key]`` and then drops the link (see __setitem__).
        object.__setattr__(self, '__parent', kwargs.pop('__parent', None))
        object.__setattr__(self, '__key', kwargs.pop('__key', None))
        for arg in args:
            if not arg:
                continue
            elif isinstance(arg, dict):
                for key, val in arg.items():
                    self[key] = self._hook(val)
            elif isinstance(arg, tuple) and (not isinstance(arg[0], tuple)):
                # A single ``(key, value)`` pair.
                self[arg[0]] = self._hook(arg[1])
            else:
                # An iterable of ``(key, value)`` pairs.
                for key, val in iter(arg):
                    self[key] = self._hook(val)

        for key, val in kwargs.items():
            self[key] = self._hook(val)

    def __setattr__(self, name, value):
        # Names defined on the class (methods such as ``update``) cannot be shadowed.
        if hasattr(self.__class__, name):
            raise AttributeError(f"Attribute '{name}' is read-only in '{type(self)}' object.")
        else:
            self[name] = value

    def __setitem__(self, name, value):
        super(DotDict, self).__setitem__(name, value)
        try:
            p = object.__getattribute__(self, '__parent')
            key = object.__getattribute__(self, '__key')
        except AttributeError:
            p = None
            key = None
        if p is not None:
            p[key] = self
            object.__delattr__(self, '__parent')
            object.__delattr__(self, '__key')

    @classmethod
    def _hook(cls, item):
        # Recursively convert nested dicts (also inside lists/tuples) to DotDict.
        if isinstance(item, dict):
            return cls(item)
        elif isinstance(item, (list, tuple)):
            return type(item)(cls._hook(elem) for elem in item)
        return item

    def __getattr__(self, item):
        try:
            return self.__getitem__(item)
        except KeyError:
            raise _DotDictMissingKey(item) from None

    def __delattr__(self, name):
        try:
            del self[name]
        except KeyError:
            raise _DotDictMissingKey(name) from None

    def copy(self):
        return copy.copy(self)

    def deepcopy(self):
        return copy.deepcopy(self)

    def __deepcopy__(self, memo):
        other = self.__class__()
        memo[id(self)] = other
        for key, value in self.items():
            other[copy.deepcopy(key, memo)] = copy.deepcopy(value, memo)
        return other

    def to_dict(self):
        """Recursively convert to plain dicts (also inside lists/tuples)."""
        base = {}
        for key, value in self.items():
            if isinstance(value, type(self)):
                base[key] = value.to_dict()
            elif isinstance(value, (list, tuple)):
                base[key] = type(value)(item.to_dict() if isinstance(item, type(self)) else item
                                        for item in value)
            else:
                base[key] = value
        return base

    def update(self, *args, **kwargs):
        """Update recursively: when both the existing and the new value are dicts,
        the existing one is updated in place instead of being replaced."""
        other = {}
        if args:
            if len(args) > 1:
                raise TypeError(f'update expected at most 1 positional argument, got {len(args)}')
            other.update(args[0])
        other.update(kwargs)
        for k, v in other.items():
            if (k not in self) or (not isinstance(self[k], dict)) or (not isinstance(v, dict)):
                self[k] = v
            else:
                self[k].update(v)

    def __getnewargs__(self):
        return tuple(self.items())

    def __getstate__(self):
        return self

    def __setstate__(self, state):
        self.update(state)

    def setdefault(self, key, default=None):
        if key in self:
            return self[key]
        else:
            self[key] = default
            return default

    def tree_flatten(self):
        # Values are the pytree children; keys are kept as auxiliary data.
        return tuple(self.values()), tuple(self.keys())

    @classmethod
    def tree_unflatten(cls, keys, values):
        keys, values = tuple(keys), tuple(values)
        if len(keys) != len(values):
            raise ValueError(f'Length mismatch: {len(keys)} keys but {len(values)} values.')
        return cls(zip(keys, values))
710
-
711
-
712
- def _is_not_instance(x, cls):
713
- return not isinstance(x, cls)
714
-
715
-
716
- def _is_instance(x, cls):
717
- return isinstance(x, cls)
718
-
719
-
720
@set_module_as('brainstate.util')
def not_instance_eval(*cls):
    """
    Build a predicate that is True for inputs that are NOT instances of ``cls``.

    Args:
      *cls: The classes to check against (bound together as one tuple).

    Returns:
      A ``functools.partial`` of ``_is_not_instance`` with ``cls`` bound;
      calling it as ``pred(x)`` evaluates ``not isinstance(x, cls)``.
    """
    predicate = functools.partial(_is_not_instance, cls=cls)
    return predicate
733
-
734
-
735
@set_module_as('brainstate.util')
def is_instance_eval(*cls):
    """
    Build a predicate that is True for inputs that ARE instances of ``cls``.

    Args:
      *cls: The classes to check against (bound together as one tuple).

    Returns:
      A ``functools.partial`` of ``_is_instance`` with ``cls`` bound;
      calling it as ``pred(x)`` evaluates ``isinstance(x, cls)``.
    """
    predicate = functools.partial(_is_instance, cls=cls)
    return predicate