brainstate 0.0.2.post20241010__py2.py3-none-any.whl → 0.1.0.post20241122__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (184)
  1. benchmark/COBA_2005.py +125 -0
  2. benchmark/CUBA_2005.py +149 -0
  3. brainstate/__init__.py +31 -11
  4. brainstate/_state.py +760 -316
  5. brainstate/_state_test.py +41 -12
  6. brainstate/_utils.py +31 -4
  7. brainstate/augment/__init__.py +40 -0
  8. brainstate/augment/_autograd.py +611 -0
  9. brainstate/augment/_autograd_test.py +1193 -0
  10. brainstate/augment/_eval_shape.py +102 -0
  11. brainstate/augment/_eval_shape_test.py +40 -0
  12. brainstate/augment/_mapping.py +525 -0
  13. brainstate/augment/_mapping_test.py +210 -0
  14. brainstate/augment/_random.py +99 -0
  15. brainstate/{transform → compile}/__init__.py +25 -13
  16. brainstate/compile/_ad_checkpoint.py +204 -0
  17. brainstate/compile/_ad_checkpoint_test.py +51 -0
  18. brainstate/compile/_conditions.py +259 -0
  19. brainstate/compile/_conditions_test.py +221 -0
  20. brainstate/compile/_error_if.py +94 -0
  21. brainstate/compile/_error_if_test.py +54 -0
  22. brainstate/compile/_jit.py +314 -0
  23. brainstate/compile/_jit_test.py +143 -0
  24. brainstate/compile/_loop_collect_return.py +516 -0
  25. brainstate/compile/_loop_collect_return_test.py +59 -0
  26. brainstate/compile/_loop_no_collection.py +185 -0
  27. brainstate/compile/_loop_no_collection_test.py +51 -0
  28. brainstate/compile/_make_jaxpr.py +756 -0
  29. brainstate/compile/_make_jaxpr_test.py +134 -0
  30. brainstate/compile/_progress_bar.py +111 -0
  31. brainstate/compile/_unvmap.py +159 -0
  32. brainstate/compile/_util.py +147 -0
  33. brainstate/environ.py +408 -381
  34. brainstate/environ_test.py +34 -32
  35. brainstate/event/__init__.py +27 -0
  36. brainstate/event/_csr.py +316 -0
  37. brainstate/event/_csr_benchmark.py +14 -0
  38. brainstate/event/_csr_test.py +118 -0
  39. brainstate/event/_fixed_probability.py +708 -0
  40. brainstate/event/_fixed_probability_benchmark.py +128 -0
  41. brainstate/event/_fixed_probability_test.py +131 -0
  42. brainstate/event/_linear.py +359 -0
  43. brainstate/event/_linear_benckmark.py +82 -0
  44. brainstate/event/_linear_test.py +117 -0
  45. brainstate/{nn/event → event}/_misc.py +7 -7
  46. brainstate/event/_xla_custom_op.py +312 -0
  47. brainstate/event/_xla_custom_op_test.py +55 -0
  48. brainstate/functional/_activations.py +521 -511
  49. brainstate/functional/_activations_test.py +300 -300
  50. brainstate/functional/_normalization.py +43 -43
  51. brainstate/functional/_others.py +15 -15
  52. brainstate/functional/_spikes.py +49 -49
  53. brainstate/graph/__init__.py +33 -0
  54. brainstate/graph/_graph_context.py +443 -0
  55. brainstate/graph/_graph_context_test.py +65 -0
  56. brainstate/graph/_graph_convert.py +246 -0
  57. brainstate/graph/_graph_node.py +300 -0
  58. brainstate/graph/_graph_node_test.py +75 -0
  59. brainstate/graph/_graph_operation.py +1746 -0
  60. brainstate/graph/_graph_operation_test.py +724 -0
  61. brainstate/init/_base.py +28 -10
  62. brainstate/init/_generic.py +175 -172
  63. brainstate/init/_random_inits.py +470 -415
  64. brainstate/init/_random_inits_test.py +150 -0
  65. brainstate/init/_regular_inits.py +66 -69
  66. brainstate/init/_regular_inits_test.py +51 -0
  67. brainstate/mixin.py +236 -244
  68. brainstate/mixin_test.py +44 -46
  69. brainstate/nn/__init__.py +26 -51
  70. brainstate/nn/_collective_ops.py +199 -0
  71. brainstate/nn/_dyn_impl/__init__.py +46 -0
  72. brainstate/nn/_dyn_impl/_dynamics_neuron.py +290 -0
  73. brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +162 -0
  74. brainstate/nn/_dyn_impl/_dynamics_synapse.py +315 -0
  75. brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +132 -0
  76. brainstate/nn/_dyn_impl/_inputs.py +154 -0
  77. brainstate/nn/{event/__init__.py → _dyn_impl/_projection_alignpost.py} +8 -8
  78. brainstate/nn/_dyn_impl/_rate_rnns.py +400 -0
  79. brainstate/nn/_dyn_impl/_rate_rnns_test.py +64 -0
  80. brainstate/nn/_dyn_impl/_readout.py +128 -0
  81. brainstate/nn/_dyn_impl/_readout_test.py +54 -0
  82. brainstate/nn/_dynamics/__init__.py +37 -0
  83. brainstate/nn/_dynamics/_dynamics_base.py +631 -0
  84. brainstate/nn/_dynamics/_dynamics_base_test.py +79 -0
  85. brainstate/nn/_dynamics/_projection_base.py +346 -0
  86. brainstate/nn/_dynamics/_state_delay.py +453 -0
  87. brainstate/nn/_dynamics/_synouts.py +161 -0
  88. brainstate/nn/_dynamics/_synouts_test.py +58 -0
  89. brainstate/nn/_elementwise/__init__.py +22 -0
  90. brainstate/nn/_elementwise/_dropout.py +418 -0
  91. brainstate/nn/_elementwise/_dropout_test.py +100 -0
  92. brainstate/nn/_elementwise/_elementwise.py +1122 -0
  93. brainstate/nn/_elementwise/_elementwise_test.py +171 -0
  94. brainstate/nn/_exp_euler.py +97 -0
  95. brainstate/nn/_exp_euler_test.py +36 -0
  96. brainstate/nn/_interaction/__init__.py +41 -0
  97. brainstate/nn/_interaction/_conv.py +499 -0
  98. brainstate/nn/_interaction/_conv_test.py +239 -0
  99. brainstate/nn/_interaction/_embedding.py +59 -0
  100. brainstate/nn/_interaction/_linear.py +582 -0
  101. brainstate/nn/_interaction/_linear_test.py +42 -0
  102. brainstate/nn/_interaction/_normalizations.py +388 -0
  103. brainstate/nn/_interaction/_normalizations_test.py +75 -0
  104. brainstate/nn/_interaction/_poolings.py +1179 -0
  105. brainstate/nn/_interaction/_poolings_test.py +219 -0
  106. brainstate/nn/_module.py +328 -0
  107. brainstate/nn/_module_test.py +211 -0
  108. brainstate/nn/metrics.py +309 -309
  109. brainstate/optim/__init__.py +14 -2
  110. brainstate/optim/_base.py +66 -0
  111. brainstate/optim/_lr_scheduler.py +363 -400
  112. brainstate/optim/_lr_scheduler_test.py +25 -24
  113. brainstate/optim/_optax_optimizer.py +121 -176
  114. brainstate/optim/_optax_optimizer_test.py +41 -1
  115. brainstate/optim/_sgd_optimizer.py +950 -1025
  116. brainstate/random/_rand_funs.py +3269 -3268
  117. brainstate/random/_rand_funs_test.py +568 -0
  118. brainstate/random/_rand_seed.py +149 -117
  119. brainstate/random/_rand_seed_test.py +50 -0
  120. brainstate/random/_rand_state.py +1356 -1321
  121. brainstate/random/_random_for_unit.py +13 -13
  122. brainstate/surrogate.py +1262 -1243
  123. brainstate/{nn/_projection/_utils.py → transform.py} +1 -2
  124. brainstate/typing.py +157 -130
  125. brainstate/util/__init__.py +52 -0
  126. brainstate/util/_caller.py +100 -0
  127. brainstate/util/_dict.py +734 -0
  128. brainstate/util/_dict_test.py +160 -0
  129. brainstate/{nn/_projection/__init__.py → util/_error.py} +9 -13
  130. brainstate/util/_filter.py +178 -0
  131. brainstate/util/_others.py +497 -0
  132. brainstate/util/_pretty_repr.py +208 -0
  133. brainstate/util/_scaling.py +260 -0
  134. brainstate/util/_struct.py +524 -0
  135. brainstate/util/_tracers.py +75 -0
  136. brainstate/{_visualization.py → util/_visualization.py} +16 -16
  137. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/METADATA +11 -11
  138. brainstate-0.1.0.post20241122.dist-info/RECORD +144 -0
  139. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/top_level.txt +1 -0
  140. brainstate/_module.py +0 -1637
  141. brainstate/_module_test.py +0 -207
  142. brainstate/nn/_base.py +0 -251
  143. brainstate/nn/_connections.py +0 -686
  144. brainstate/nn/_dynamics.py +0 -426
  145. brainstate/nn/_elementwise.py +0 -1438
  146. brainstate/nn/_embedding.py +0 -66
  147. brainstate/nn/_misc.py +0 -133
  148. brainstate/nn/_normalizations.py +0 -389
  149. brainstate/nn/_others.py +0 -101
  150. brainstate/nn/_poolings.py +0 -1229
  151. brainstate/nn/_poolings_test.py +0 -231
  152. brainstate/nn/_projection/_align_post.py +0 -546
  153. brainstate/nn/_projection/_align_pre.py +0 -599
  154. brainstate/nn/_projection/_delta.py +0 -241
  155. brainstate/nn/_projection/_vanilla.py +0 -101
  156. brainstate/nn/_rate_rnns.py +0 -410
  157. brainstate/nn/_readout.py +0 -136
  158. brainstate/nn/_synouts.py +0 -166
  159. brainstate/nn/event/csr.py +0 -312
  160. brainstate/nn/event/csr_test.py +0 -118
  161. brainstate/nn/event/fixed_probability.py +0 -276
  162. brainstate/nn/event/fixed_probability_test.py +0 -127
  163. brainstate/nn/event/linear.py +0 -220
  164. brainstate/nn/event/linear_test.py +0 -111
  165. brainstate/random/random_test.py +0 -593
  166. brainstate/transform/_autograd.py +0 -585
  167. brainstate/transform/_autograd_test.py +0 -1181
  168. brainstate/transform/_conditions.py +0 -334
  169. brainstate/transform/_conditions_test.py +0 -220
  170. brainstate/transform/_error_if.py +0 -94
  171. brainstate/transform/_error_if_test.py +0 -55
  172. brainstate/transform/_jit.py +0 -265
  173. brainstate/transform/_jit_test.py +0 -118
  174. brainstate/transform/_loop_collect_return.py +0 -502
  175. brainstate/transform/_loop_no_collection.py +0 -170
  176. brainstate/transform/_make_jaxpr.py +0 -739
  177. brainstate/transform/_make_jaxpr_test.py +0 -131
  178. brainstate/transform/_mapping.py +0 -109
  179. brainstate/transform/_progress_bar.py +0 -111
  180. brainstate/transform/_unvmap.py +0 -143
  181. brainstate/util.py +0 -746
  182. brainstate-0.0.2.post20241010.dist-info/RECORD +0 -87
  183. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/LICENSE +0 -0
  184. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/WHEEL +0 -0
@@ -0,0 +1,497 @@
1
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ from __future__ import annotations
17
+
18
+ import copy
19
+ import functools
20
+ import gc
21
+ import threading
22
+ import types
23
+ from collections.abc import Iterable
24
+ from typing import Any, Callable, Tuple, Union, Dict
25
+
26
+ import jax
27
+ from jax.lib import xla_bridge
28
+
29
+ from brainstate._utils import set_module_as
30
+
31
# Names exported by ``from ... import *`` for this module.
__all__ = [
    'clear_buffer_memory',
    'not_instance_eval',
    'is_instance_eval',
    'DictManager',
    'DotDict',
]
38
+
39
+
40
class NameContext(threading.local):
    """Per-thread table of name counters, keyed by a type string."""

    def __init__(self):
        # ``threading.local`` runs ``__init__`` again the first time an instance
        # is used from a new thread, so every thread starts with empty counters.
        counters: Dict[str, int] = dict()
        self.typed_names = counters


# Shared counter table used by ``get_unique_name``.
NAME = NameContext()
46
+
47
+
48
def get_unique_name(type_: str):
    """Return a fresh name for objects of kind ``type_``.

    The name is ``f'{type_}{index}'``, where ``index`` is the number of names
    already issued for ``type_`` from the ``NAME`` counter table (starting at
    0, e.g. ``'Foo0'``, ``'Foo1'``, ...). Each call advances the counter.
    """
    counters = NAME.typed_names
    index = counters.setdefault(type_, 0)
    counters[type_] = index + 1
    return f'{type_}{index}'
55
+
56
+
57
@jax.tree_util.register_pytree_node_class
class DictManager(dict):
    """
    DictManager, for collecting all pytree used in the program.

    :py:class:`~.DictManager` supports all features of python dict. It adds
    helpers to filter, split and de-duplicate entries by key, by type, or by
    value (identity or equality).

    Subclasses must implement ``_check_elem``. ``add_unique_key`` and
    ``add_unique_value`` call it to validate new values.

    Instances are JAX pytrees: the values are the children and the keys are
    the static auxiliary data.
    """
    __module__ = 'brainstate.util'
    # Lazily-built ``id(value) -> key`` cache used by ``add_unique_value``.
    _val_id_to_key: dict

    @staticmethod
    def _as_predicate(sep) -> Callable[[Any], bool]:
        """Normalize ``sep`` into a one-argument predicate.

        These are used as predicates directly:
        - functions and lambdas,
        - builtins,
        - bound methods,
        - ``functools.partial`` objects, such as those returned by
          ``is_instance_eval`` / ``not_instance_eval``.

        Anything else (a type, a tuple of types, ...) is treated as the
        second argument of ``isinstance``.
        """
        if isinstance(sep, (types.FunctionType, types.BuiltinFunctionType,
                            types.MethodType, functools.partial)):
            return sep
        return lambda v: isinstance(v, sep)

    @staticmethod
    def _as_collection(items: Iterable):
        """Return ``items`` in a form that supports repeated ``in`` checks.

        One-shot iterators such as generators would be exhausted by the first
        membership test. Anything without ``__contains__`` is therefore
        materialized into a tuple first.
        """
        return items if hasattr(items, '__contains__') else tuple(items)

    def subset(self, sep: Union[type, Tuple[type, ...], Callable]) -> 'DictManager':
        """
        Get a new stack with the entries whose value matches ``sep``.

        Args:
          sep: Either a type or tuple of types, matched with ``isinstance``.
            Or a predicate called with each value: a function, lambda,
            builtin, bound method or ``functools.partial``.
        """
        keep = self._as_predicate(sep)
        gather = type(self)()
        for k, v in self.items():
            if keep(v):
                gather[k] = v
        return gather

    def not_subset(self, sep: Union[type, Tuple[type, ...], Callable]) -> 'DictManager':
        """
        Get a new stack with the entries whose value does NOT match ``sep``.

        ``sep`` is interpreted exactly as in :meth:`subset`.
        """
        keep = self._as_predicate(sep)
        gather = type(self)()
        for k, v in self.items():
            if not keep(v):
                gather[k] = v
        return gather

    def add_unique_key(self, key: Any, val: Any):
        """
        Add ``val`` under ``key``. A key cannot be rebound to a different object.

        Re-adding the very same object under the same key is a no-op.

        Raises:
          ValueError: If ``key`` is already registered with a different object
            (compared by identity).
        """
        self._check_elem(val)
        if key in self:
            if val is not self[key]:
                raise ValueError(f'{key} has been registered by {self[key]}, the new value is different from it.')
        else:
            self[key] = val

    def add_unique_value(self, key: Any, val: Any) -> bool:
        """
        Add a new element and check if the val is unique.

        Uniqueness is by object identity.

        Parameters:
          key: The key of the element.
          val: The value of the element

        Returns:
          bool: True if the value was added (it was not present yet), False
          otherwise.
        """
        self._check_elem(val)
        if not hasattr(self, '_val_id_to_key'):
            self._val_id_to_key = {id(v): k for k, v in self.items()}
        val_id = id(val)
        if val_id in self._val_id_to_key:
            old_key = self._val_id_to_key[val_id]
            # The cache is not updated by pop()/__setitem__. Only trust a hit
            # if the cached key still maps to this exact object.
            if old_key in self and self[old_key] is val:
                return False
        self._val_id_to_key[val_id] = key
        self[key] = val
        return True

    def unique(self) -> 'DictManager':
        """
        Get a new stack in which every value object appears only once.

        Values are compared by ``id``. If one value is assigned to two or more
        keys, only the first (key, value) pair is kept.
        """
        gather = type(self)()
        seen = set()
        for k, v in self.items():
            if id(v) not in seen:
                seen.add(id(v))
                gather[k] = v
        return gather

    def unique_(self):
        """
        De-duplicate values in place and return ``self``.

        Values are compared by ``id``. If one value is assigned to two or more
        keys, only the first key is kept and the others are removed.
        """
        seen = set()
        for k in tuple(self.keys()):
            v = self[k]
            if id(v) not in seen:
                seen.add(id(v))
            else:
                self.pop(k)
        return self

    def assign(self, *args) -> None:
        """
        Assign the value for each element according to the given dicts.

        Later dicts override earlier ones.
        """
        for arg in args:
            assert isinstance(arg, dict), 'Must be an instance of dict.'
            for k, v in arg.items():
                self[k] = v

    def split(self, first: type, *others: type) -> Tuple['DictManager', ...]:
        """
        Split the stack into ``len(filters) + 1`` subsets by the given types.

        Each entry goes to the subset of the first type its value is an
        instance of. Entries matching none of the types go to the last subset.
        """
        filters = (first, *others)
        results = tuple(type(self)() for _ in range(len(filters) + 1))
        for k, v in self.items():
            for i, filt in enumerate(filters):
                if isinstance(v, filt):
                    results[i][k] = v
                    break
            else:
                results[-1][k] = v
        return results

    def pop_by_keys(self, keys: Iterable):
        """
        Pop (in place) the elements whose key is in ``keys``.
        """
        keys = self._as_collection(keys)
        for k in tuple(self.keys()):
            if k in keys:
                self.pop(k)

    def pop_by_values(self, values: Iterable, by: str = 'id'):
        """
        Pop (in place) the elements by the values.

        Args:
          values: The values to remove.
          by: str. The discard method, can be ``id`` or ``value``. Default is 'id'.

        Raises:
          ValueError: If ``by`` is not ``'id'`` or ``'value'``.
        """
        if by == 'id':
            value_ids = {id(v) for v in values}
            for k in tuple(self.keys()):
                if id(self[k]) in value_ids:
                    self.pop(k)
        elif by == 'value':
            values = self._as_collection(values)
            for k in tuple(self.keys()):
                if self[k] in values:
                    self.pop(k)
        else:
            raise ValueError(f'Unsupported method: {by}')

    def difference_by_keys(self, keys: Iterable):
        """
        Get a new stack without the entries whose key is in ``keys``.
        """
        keys = self._as_collection(keys)
        return type(self)({k: v for k, v in self.items() if k not in keys})

    def difference_by_values(self, values: Iterable, by: str = 'id'):
        """
        Get a new stack without the entries whose value is in ``values``.

        Args:
          values: The values to exclude.
          by: str. The discard method, can be ``id`` or ``value``. Default is 'id'.

        Raises:
          ValueError: If ``by`` is not ``'id'`` or ``'value'``.
        """
        if by == 'id':
            value_ids = {id(v) for v in values}
            return type(self)({k: v for k, v in self.items() if id(v) not in value_ids})
        elif by == 'value':
            values = self._as_collection(values)
            return type(self)({k: v for k, v in self.items() if v not in values})
        else:
            raise ValueError(f'Unsupported method: {by}')

    def intersection_by_keys(self, keys: Iterable):
        """
        Get a new stack with only the entries whose key is in ``keys``.
        """
        keys = self._as_collection(keys)
        return type(self)({k: v for k, v in self.items() if k in keys})

    def intersection_by_values(self, values: Iterable, by: str = 'id'):
        """
        Get a new stack with only the entries whose value is in ``values``.

        Args:
          values: The values to keep.
          by: str. The matching method, can be ``id`` or ``value``. Default is 'id'.

        Raises:
          ValueError: If ``by`` is not ``'id'`` or ``'value'``.
        """
        if by == 'id':
            value_ids = {id(v) for v in values}
            return type(self)({k: v for k, v in self.items() if id(v) in value_ids})
        elif by == 'value':
            values = self._as_collection(values)
            return type(self)({k: v for k, v in self.items() if v in values})
        else:
            raise ValueError(f'Unsupported method: {by}')

    def union_by_value_ids(self, other: dict):
        """
        Union the stack by the value ids.

        The result holds every entry of ``self``, plus each entry of ``other``
        whose value object (compared by ``id``) is not present yet. As with
        ``__add__``, an added entry replaces an existing entry with the same
        key.

        Args:
          other: The mapping to merge in.

        Returns:
          A new instance of ``type(self)``.
        """
        merged = type(self)(self)
        seen_ids = {id(v) for v in merged.values()}
        for k, v in other.items():
            if id(v) not in seen_ids:
                seen_ids.add(id(v))
                merged[k] = v
        return merged

    def __add__(self, other: dict):
        """
        Compose other instance of dict; entries of ``other`` win on key clashes.
        """
        new_dict = type(self)(self)
        new_dict.update(other)
        return new_dict

    def tree_flatten(self):
        # Values are the pytree children; keys are the static aux data.
        return tuple(self.values()), tuple(self.keys())

    @classmethod
    def tree_unflatten(cls, keys, values):
        # A plain ``zip`` with an explicit length check replaces
        # ``jax.util.safe_zip``, which is not part of JAX's documented API.
        if len(keys) != len(values):
            raise ValueError(f'Expected {len(keys)} values, got {len(values)}.')
        return cls(zip(keys, values))

    def _check_elem(self, elem: Any):
        # Hook for subclasses: validate a value before it is added by
        # ``add_unique_key`` / ``add_unique_value``.
        raise NotImplementedError

    def to_dict(self):
        """
        Convert the stack to a dict.

        Returns
        -------
        dict
          A shallow copy as a plain ``dict`` object.
        """
        return dict(self)

    def __copy__(self):
        return type(self)(self)
295
+
296
@set_module_as('brainstate.util')
def clear_buffer_memory(
    platform: str = None,
    array: bool = True,
    compilation: bool = False,
):
    """Clear all on-device buffers.

    This function will be very useful when you call models in a Python loop,
    because it can clear all cached arrays, and clear device memory.

    .. warning::

       This operation may cause errors when you use a deleted buffer.
       Therefore, regenerate data always.

    Parameters
    ----------
    platform: str
      The platform (backend) whose live arrays are deleted. ``None`` selects
      the default backend.
    array: bool
      Clear all buffer array. Default is True.
    compilation: bool
      Clear compilation cache. Default is False.

    """
    if array:
        if hasattr(jax, 'live_arrays'):
            # Public API for enumerating live device arrays.
            live = jax.live_arrays(platform)
        else:
            # Fallback for older JAX releases without ``jax.live_arrays``.
            live = xla_bridge.get_backend(platform).live_buffers()
        for buf in live:
            buf.delete()
    if compilation:
        jax.clear_caches()
    # Release Python-side references so freed buffers can be reclaimed.
    gc.collect()
328
+
329
+
330
@jax.tree_util.register_pytree_node_class
class DotDict(dict):
    """Python dictionaries with advanced dot notation access.

    For example:

    >>> d = DotDict({'a': 10, 'b': 20})
    >>> d.a
    10
    >>> d['a']
    10
    >>> d.c = 30  # you can assign a value to a non-existing item
    >>> d.c
    30

    Reading a missing attribute (e.g. ``d.missing``) raises an error that is
    both a ``KeyError`` and an ``AttributeError``. As a result,
    ``hasattr(d, 'missing')`` is ``False`` and
    ``getattr(d, 'missing', default)`` returns ``default``.

    Dicts passed to the constructor are converted to ``DotDict`` recursively,
    including dicts nested inside lists, tuples and namedtuples. Values
    assigned after construction are stored as-is.

    Instances are JAX pytrees: the values are the children and the keys are
    the static data.
    """

    __module__ = 'brainstate.util'

    class _MissingKey(KeyError, AttributeError):
        """Missing-attribute error, catchable as ``KeyError`` or ``AttributeError``."""

    def __init__(self, *args, **kwargs):
        # Optional ``__parent`` / ``__key`` keyword arguments link this
        # instance to a parent mapping. On the first item assignment this
        # instance is stored as ``parent[key]`` and the link is dropped.
        object.__setattr__(self, '__parent', kwargs.pop('__parent', None))
        object.__setattr__(self, '__key', kwargs.pop('__key', None))
        for arg in args:
            if not arg:
                continue
            elif isinstance(arg, dict):
                for key, val in arg.items():
                    self[key] = self._hook(val)
            elif isinstance(arg, tuple) and (not isinstance(arg[0], tuple)):
                # A single ``(key, value)`` pair.
                self[arg[0]] = self._hook(arg[1])
            else:
                # An iterable of ``(key, value)`` pairs.
                for key, val in iter(arg):
                    self[key] = self._hook(val)

        for key, val in kwargs.items():
            self[key] = self._hook(val)

    def __setattr__(self, name, value):
        # Class attributes (methods, ...) cannot be shadowed by items.
        if hasattr(self.__class__, name):
            raise AttributeError(f"Attribute '{name}' is read-only in '{type(self)}' object.")
        else:
            self[name] = value

    def __setitem__(self, name, value):
        super(DotDict, self).__setitem__(name, value)
        try:
            p = object.__getattribute__(self, '__parent')
            key = object.__getattribute__(self, '__key')
        except AttributeError:
            # The link was already consumed, or __init__ was bypassed (copy/pickle).
            p = None
            key = None
        if p is not None:
            p[key] = self
            object.__delattr__(self, '__parent')
            object.__delattr__(self, '__key')

    @classmethod
    def _hook(cls, item):
        """Recursively convert dicts (also inside lists/tuples) into ``cls``."""
        if isinstance(item, dict):
            return cls(item)
        if isinstance(item, (list, tuple)):
            converted = [cls._hook(elem) for elem in item]
            if isinstance(item, tuple) and hasattr(item, '_fields'):
                # namedtuple constructors take positional fields, not an iterable.
                return type(item)(*converted)
            return type(item)(converted)
        return item

    def __getattr__(self, item):
        # Only reached when normal attribute lookup fails: fall back to items.
        try:
            return self.__getitem__(item)
        except KeyError:
            # Stay a KeyError for backward compatibility, but also be an
            # AttributeError so hasattr()/getattr(default) and attribute
            # probes by other libraries behave correctly.
            raise type(self)._MissingKey(item) from None

    def __delattr__(self, name):
        del self[name]

    def copy(self):
        """Return a shallow copy."""
        return copy.copy(self)

    def deepcopy(self):
        """Return a deep copy."""
        return copy.deepcopy(self)

    def __deepcopy__(self, memo):
        other = self.__class__()
        memo[id(self)] = other
        for key, value in self.items():
            other[copy.deepcopy(key, memo)] = copy.deepcopy(value, memo)
        return other

    @classmethod
    def _to_plain(cls, value):
        """Recursively replace ``cls`` instances (also inside lists/tuples) by dicts."""
        if isinstance(value, cls):
            return value.to_dict()
        if isinstance(value, (list, tuple)):
            converted = [cls._to_plain(elem) for elem in value]
            if isinstance(value, tuple) and hasattr(value, '_fields'):
                return type(value)(*converted)
            return type(value)(converted)
        return value

    def to_dict(self):
        """Recursively convert to plain ``dict`` objects.

        Nested instances are converted too, including ones inside
        (arbitrarily nested) lists, tuples and namedtuples. Other values are
        kept unchanged.
        """
        return {key: self._to_plain(value) for key, value in self.items()}

    def update(self, *args, **kwargs):
        """Update from a mapping / iterable of pairs and keyword arguments.

        When both the existing value and the new value under a key are dicts,
        the existing one is updated in place instead of being replaced.

        Raises:
          TypeError: If more than one positional argument is given.
        """
        if len(args) > 1:
            raise TypeError(f'update expected at most 1 positional argument, got {len(args)}')
        other = {}
        if args:
            other.update(args[0])
        other.update(kwargs)
        for k, v in other.items():
            if (k not in self) or (not isinstance(self[k], dict)) or (not isinstance(v, dict)):
                self[k] = v
            else:
                self[k].update(v)

    def __getnewargs__(self):
        return tuple(self.items())

    def __getstate__(self):
        return self

    def __setstate__(self, state):
        self.update(state)

    def setdefault(self, key, default=None):
        # Routed through __setitem__ (unlike dict.setdefault) so parent links apply.
        if key in self:
            return self[key]
        else:
            self[key] = default
            return default

    def tree_flatten(self):
        # Values are the pytree children; keys are the static aux data.
        return tuple(self.values()), tuple(self.keys())

    @classmethod
    def tree_unflatten(cls, keys, values):
        # A plain ``zip`` with an explicit length check replaces
        # ``jax.util.safe_zip``, which is not part of JAX's documented API.
        if len(keys) != len(values):
            raise ValueError(f'Expected {len(keys)} values, got {len(values)}.')
        return cls(list(zip(keys, values)))
461
+
462
+
463
def _is_not_instance(x, cls):
    """Return ``True`` when ``x`` is not an instance of ``cls`` (a class or tuple of classes)."""
    matches = isinstance(x, cls)
    return not matches
465
+
466
+
467
def _is_instance(x, cls):
    """Return ``True`` when ``x`` is an instance of ``cls`` (a class or tuple of classes)."""
    result = isinstance(x, cls)
    return result
469
+
470
+
471
@set_module_as('brainstate.util')
def not_instance_eval(*cls):
    """
    Build a predicate that is ``True`` for objects that are NOT instances of ``cls``.

    Args:
        *cls: The classes to check against. They are passed as one tuple to
            ``isinstance``.

    Returns:
        A ``functools.partial`` of ``_is_not_instance`` with ``cls`` bound,
        called as ``predicate(x)``.
    """
    predicate = functools.partial(_is_not_instance, cls=cls)
    return predicate
484
+
485
+
486
@set_module_as('brainstate.util')
def is_instance_eval(*cls):
    """
    Build a predicate that is ``True`` for objects that are instances of ``cls``.

    Args:
        *cls: The classes to check against. They are passed as one tuple to
            ``isinstance``.

    Returns:
        A ``functools.partial`` of ``_is_instance`` with ``cls`` bound,
        called as ``predicate(x)``.
    """
    predicate = functools.partial(_is_instance, cls=cls)
    return predicate
@@ -0,0 +1,208 @@
1
+ # The file is adapted from the Flax library (https://github.com/google/flax).
2
+ # The credit should go to the Flax authors.
3
+ #
4
+ # Copyright 2024 The Flax Authors & 2024 BDP Ecosystem.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ from __future__ import annotations
19
+
20
+ import dataclasses
21
+ import threading
22
+ from abc import ABC, abstractmethod
23
+ from functools import partial
24
+ from typing import Any, Iterator, Mapping, TypeVar, Union, Callable, Optional
25
+
26
# Names exported by ``from ... import *`` for this module.
__all__ = [
    'PrettyType',
    'PrettyAttr',
    'PrettyRepr',
    'PrettyMapping',
    'MappingReprMixin',
]

# Key / value type variables used to parameterize ``MappingReprMixin``.
A = TypeVar('A')
B = TypeVar('B')
36
+
37
+
38
@dataclasses.dataclass
class PrettyType:
    """
    Header of a pretty representation: the type label and its delimiters.

    ``get_repr`` renders an object as ``{type}{start}...{end}``, with the
    attributes between ``start`` and ``end``.
    """
    # The object type, or the label to print for it.
    type: Union[str, type]
    # Opening delimiter written right after the type label.
    start: str = '('
    # Closing delimiter.
    end: str = ')'
    # Separator written between an attribute key and its value.
    value_sep: str = '='
    # Prefix of every attribute line (and of its continuation lines).
    # NOTE(review): upstream Flax uses two spaces here; this listing may have
    # collapsed whitespace -- confirm against the real source.
    elem_indent: str = ' '
    # Text placed between ``start`` and ``end`` when there are no attributes.
    empty_repr: str = ''
49
+
50
+
51
@dataclasses.dataclass
class PrettyAttr:
    """
    One key/value entry of a pretty representation.

    ``_repr_elem`` renders it as ``{start}{key}{value_sep}{value}{end}`` on
    its own indented line.
    """
    # Attribute name shown before the separator.
    key: str
    # Value to show: strings are used verbatim, anything else via ``repr``.
    value: Union[str, Any]
    # Optional text placed before the key.
    start: str = ''
    # Optional text placed after the value.
    end: str = ''
60
+
61
+
62
class PrettyRepr(ABC):
    """
    Interface for objects with a structured, pretty ``repr``.

    Subclasses implement ``__pretty_repr__`` as a generator. It first yields a
    ``PrettyType`` describing the object, then one ``PrettyAttr`` per
    attribute to show. ``__repr__`` turns that stream into text with
    ``get_repr``.

    Example::

        >>> class MyObject(PrettyRepr):
        ...     def __init__(self, key, value):
        ...         self.key = key
        ...         self.value = value
        ...
        ...     def __pretty_repr__(self):
        ...         yield PrettyType(type='MyObject', start='{', end='}')
        ...         yield PrettyAttr('key', self.key)
        ...         yield PrettyAttr('value', self.value)
    """
    __slots__ = ()

    @abstractmethod
    def __pretty_repr__(self) -> Iterator[Union[PrettyType, PrettyAttr]]:
        """Yield a ``PrettyType`` followed by zero or more ``PrettyAttr`` items."""
        raise NotImplementedError

    def __repr__(self) -> str:
        # Formatting is delegated to the module-level renderer.
        rendered = get_repr(self)
        return rendered
84
+
85
+
86
def _repr_elem(obj: PrettyType, elem: Any) -> str:
    """Render one attribute line of a pretty representation.

    Args:
      obj: The ``PrettyType`` header. It supplies ``elem_indent`` and
        ``value_sep``.
      elem: The attribute to render; must be a ``PrettyAttr``.

    Returns:
      ``{indent}{start}{key}{sep}{value}{end}``. A non-string value is
      rendered with ``repr``. Every embedded newline is followed by the indent,
      so nested multi-line reprs stay aligned.

    Raises:
      TypeError: If ``elem`` is not a ``PrettyAttr``.
    """
    if not isinstance(elem, PrettyAttr):
        raise TypeError(f'Item must be PrettyAttr, got {type(elem).__name__}')

    value = elem.value if isinstance(elem.value, str) else repr(elem.value)
    # Indent continuation lines of nested multi-line representations.
    value = value.replace('\n', '\n' + obj.elem_indent)

    return f'{obj.elem_indent}{elem.start}{elem.key}{obj.value_sep}{value}{elem.end}'
94
+
95
+
96
def get_repr(obj: PrettyRepr) -> str:
    """
    Get the pretty representation of an object.

    ``obj.__pretty_repr__()`` must yield a ``PrettyType`` first, followed by
    any number of ``PrettyAttr`` items. The result has the form
    ``{type}{start}{attrs}{end}``:
    - with attributes, ``attrs`` is a newline, the attribute lines joined by
      a comma and a newline, then a final newline;
    - without attributes, ``attrs`` is ``empty_repr``.

    Raises:
      TypeError: In any of these cases:
        - ``obj`` is not a ``PrettyRepr``;
        - ``__pretty_repr__`` yields nothing;
        - the first item is not a ``PrettyType``;
        - a later item is not a ``PrettyAttr``.
    """
    if not isinstance(obj, PrettyRepr):
        raise TypeError(f'Object {obj!r} is not representable')

    iterator = iter(obj.__pretty_repr__())
    try:
        missing = object()
        obj_repr = next(iterator, missing)
        if obj_repr is missing:
            # Avoid leaking a bare StopIteration to the caller.
            raise TypeError(f'{type(obj).__name__}.__pretty_repr__ yielded nothing; '
                            f'expected a PrettyType first')

        # repr object
        if not isinstance(obj_repr, PrettyType):
            raise TypeError(f'First item must be PrettyType, got {type(obj_repr).__name__}')

        # repr attributes
        elem_reprs = [_repr_elem(obj_repr, elem) for elem in iterator]
    finally:
        # Close the generator deterministically. Generators such as
        # ``pretty_repr_avoid_duplicate`` reset thread-local state in a
        # ``finally`` clause, and that must run even if rendering fails here.
        close = getattr(iterator, 'close', None)
        if close is not None:
            close()

    elems = ',\n'.join(elem_reprs)
    if elems:
        elems = '\n' + elems + '\n'
    else:
        elems = obj_repr.empty_repr

    # repr object type
    type_repr = obj_repr.type if isinstance(obj_repr.type, str) else obj_repr.type.__name__

    # return repr
    return f'{type_repr}{obj_repr.start}{elems}{obj_repr.end}'
123
+
124
+
125
class MappingReprMixin(Mapping[A, B]):
    """
    Mapping mixin providing ``__pretty_repr__`` as a ``{key: value}`` listing.

    Keys are rendered with ``repr``; values are passed through unchanged.

    NOTE(review): this class does not inherit from ``PrettyRepr``, so it does
    not override ``__repr__`` itself. Presumably subclasses combine it with
    ``PrettyRepr`` -- confirm.
    """

    def __pretty_repr__(self):
        yield PrettyType(type='', value_sep=': ', start='{', end='}')
        yield from (PrettyAttr(repr(key), value) for key, value in self.items())
135
+
136
+
137
@dataclasses.dataclass(repr=False)
class PrettyMapping(PrettyRepr):
    """
    Wrap a mapping so that it is shown as a brace-delimited ``key: value`` listing.

    ``repr=False`` stops the dataclass from generating a ``__repr__``, so the
    one inherited from ``PrettyRepr`` is used.
    """
    mapping: Mapping

    def __pretty_repr__(self):
        yield PrettyType(type='', value_sep=': ', start='{', end='}')
        entries = self.mapping.items()
        for k, v in entries:
            yield PrettyAttr(repr(k), v)
149
+
150
+
151
@dataclasses.dataclass
class PrettyReprContext(threading.local):
    """Thread-local bookkeeping for ``pretty_repr_avoid_duplicate``.

    ``seen_modules_repr`` is ``None`` outside a rendering pass. During a pass
    it maps ``id(node)`` to the node itself for every node already rendered.
    Storing the node, not just its id, keeps it alive, so no other object can
    reuse that id during the pass.
    """
    seen_modules_repr: dict[int, Any] | None = None


# Per-thread context shared by all pretty-repr calls in this module.
CONTEXT = PrettyReprContext()
158
+
159
+
160
def _default_repr_object(node):
    """Yield the default header for ``node``: its type with ``PrettyType``'s default delimiters."""
    header = PrettyType(type=type(node))
    yield header
162
+
163
+
164
def _default_repr_attr(node):
    """Yield a ``PrettyAttr`` for each entry of ``vars(node)`` whose name has no leading underscore.

    Values are pre-rendered with ``repr``.
    """
    yield from (
        PrettyAttr(name, repr(value))
        for name, value in vars(node).items()
        if not name.startswith('_')
    )
169
+
170
+
171
def pretty_repr_avoid_duplicate(
    node,
    repr_object: Optional[Callable] = None,
    repr_attr: Optional[Callable] = None
):
    """
    Pretty representation of an object avoiding duplicate representations.

    This generator yields the items of ``repr_object(node)`` (the header),
    then the items of ``repr_attr(node)`` (the attributes).

    Within one top-level pass on the current thread, a node that was already
    rendered is emitted only as a ``PrettyType`` with ``empty_repr='...'``.
    This also stops infinite recursion on reference cycles.

    Args:
      node: The object to represent.
      repr_object: Callable returning an iterator over the header items.
        Defaults to ``_default_repr_object``.
      repr_attr: Callable returning an iterator over the attribute items.
        Defaults to ``_default_repr_attr``.
    """
    if repr_object is None:
        repr_object = _default_repr_object
    if repr_attr is None:
        repr_attr = _default_repr_attr

    # The outermost call on this thread owns the bookkeeping dict and resets it.
    clear_seen = CONTEXT.seen_modules_repr is None
    if clear_seen:
        CONTEXT.seen_modules_repr = dict()

    try:
        # Already rendered in this pass: emit a placeholder (also breaks cycles).
        if id(node) in CONTEXT.seen_modules_repr:
            yield PrettyType(type=type(node), empty_repr='...')
            return

        # repr object
        yield from repr_object(node)

        # Store the node itself (not just its id) so its id cannot be reused
        # by another object while this pass is still running.
        CONTEXT.seen_modules_repr[id(node)] = node

        # repr attributes
        yield from repr_attr(node)
    finally:
        # Also runs when rendering raises or the consumer closes this
        # generator early, so the thread-local state never leaks into later
        # repr calls.
        if clear_seen:
            CONTEXT.seen_modules_repr = None