brainstate 0.0.1__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (79) hide show
  1. brainstate/__init__.py +45 -0
  2. brainstate/_module.py +1466 -0
  3. brainstate/_module_test.py +133 -0
  4. brainstate/_state.py +378 -0
  5. brainstate/_state_test.py +41 -0
  6. brainstate/_utils.py +21 -0
  7. brainstate/environ.py +375 -0
  8. brainstate/functional/__init__.py +25 -0
  9. brainstate/functional/_activations.py +754 -0
  10. brainstate/functional/_normalization.py +69 -0
  11. brainstate/functional/_spikes.py +90 -0
  12. brainstate/init/__init__.py +26 -0
  13. brainstate/init/_base.py +36 -0
  14. brainstate/init/_generic.py +175 -0
  15. brainstate/init/_random_inits.py +489 -0
  16. brainstate/init/_regular_inits.py +109 -0
  17. brainstate/math/__init__.py +21 -0
  18. brainstate/math/_einops.py +787 -0
  19. brainstate/math/_einops_parsing.py +169 -0
  20. brainstate/math/_einops_parsing_test.py +126 -0
  21. brainstate/math/_einops_test.py +346 -0
  22. brainstate/math/_misc.py +298 -0
  23. brainstate/math/_misc_test.py +58 -0
  24. brainstate/mixin.py +373 -0
  25. brainstate/mixin_test.py +73 -0
  26. brainstate/nn/__init__.py +68 -0
  27. brainstate/nn/_base.py +248 -0
  28. brainstate/nn/_connections.py +686 -0
  29. brainstate/nn/_dynamics.py +406 -0
  30. brainstate/nn/_elementwise.py +1437 -0
  31. brainstate/nn/_misc.py +132 -0
  32. brainstate/nn/_normalizations.py +389 -0
  33. brainstate/nn/_others.py +100 -0
  34. brainstate/nn/_poolings.py +1228 -0
  35. brainstate/nn/_poolings_test.py +231 -0
  36. brainstate/nn/_projection/__init__.py +32 -0
  37. brainstate/nn/_projection/_align_post.py +528 -0
  38. brainstate/nn/_projection/_align_pre.py +599 -0
  39. brainstate/nn/_projection/_delta.py +241 -0
  40. brainstate/nn/_projection/_utils.py +17 -0
  41. brainstate/nn/_projection/_vanilla.py +101 -0
  42. brainstate/nn/_rate_rnns.py +393 -0
  43. brainstate/nn/_readout.py +130 -0
  44. brainstate/nn/_synouts.py +166 -0
  45. brainstate/nn/functional/__init__.py +25 -0
  46. brainstate/nn/functional/_activations.py +754 -0
  47. brainstate/nn/functional/_normalization.py +69 -0
  48. brainstate/nn/functional/_spikes.py +90 -0
  49. brainstate/nn/init/__init__.py +26 -0
  50. brainstate/nn/init/_base.py +36 -0
  51. brainstate/nn/init/_generic.py +175 -0
  52. brainstate/nn/init/_random_inits.py +489 -0
  53. brainstate/nn/init/_regular_inits.py +109 -0
  54. brainstate/nn/surrogate.py +1740 -0
  55. brainstate/optim/__init__.py +23 -0
  56. brainstate/optim/_lr_scheduler.py +486 -0
  57. brainstate/optim/_lr_scheduler_test.py +36 -0
  58. brainstate/optim/_sgd_optimizer.py +1148 -0
  59. brainstate/random.py +5148 -0
  60. brainstate/random_test.py +576 -0
  61. brainstate/surrogate.py +1740 -0
  62. brainstate/transform/__init__.py +36 -0
  63. brainstate/transform/_autograd.py +585 -0
  64. brainstate/transform/_autograd_test.py +1183 -0
  65. brainstate/transform/_control.py +665 -0
  66. brainstate/transform/_controls_test.py +220 -0
  67. brainstate/transform/_jit.py +239 -0
  68. brainstate/transform/_jit_error.py +158 -0
  69. brainstate/transform/_jit_test.py +102 -0
  70. brainstate/transform/_make_jaxpr.py +573 -0
  71. brainstate/transform/_make_jaxpr_test.py +133 -0
  72. brainstate/transform/_progress_bar.py +113 -0
  73. brainstate/typing.py +69 -0
  74. brainstate/util.py +747 -0
  75. brainstate-0.0.1.dist-info/LICENSE +202 -0
  76. brainstate-0.0.1.dist-info/METADATA +101 -0
  77. brainstate-0.0.1.dist-info/RECORD +79 -0
  78. brainstate-0.0.1.dist-info/WHEEL +6 -0
  79. brainstate-0.0.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,133 @@
1
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ import unittest
17
+
18
+ import jax.numpy as jnp
19
+
20
+ import brainstate as bc
21
+
22
+
23
class TestVarDelay(unittest.TestCase):
    """Tests for ``bc.Delay`` entry registration and delayed read-out."""

    def tearDown(self):
        # Runs even when an assertion fails, so buffers are always released.
        # Previously the call was the last statement of each test and was
        # skipped on failure.
        bc.util.clear_buffer_memory()

    @staticmethod
    def _delayed(i, n):
        """Expected value ``n`` steps after input ``i``: ``i - n`` clipped at 0."""
        return jnp.maximum(jnp.ones((1,)) * i - n, 0.)

    def _assert_delayed_outputs(self, delay, n1, n2):
        """Feed 100 steps of ``ones * i`` and check entries 'a', 'b', 'c'.

        Entry 'a' must return the current input; 'b' and 'c' must return the
        input ``n1`` and ``n2`` steps in the past (0 before that history exists).
        """
        for i in range(100):
            bc.environ.set(i=i)
            delay(jnp.ones((1,)) * i)
            self.assertTrue(jnp.allclose(delay.at('a'), jnp.ones((1,)) * i))
            self.assertTrue(jnp.allclose(delay.at('b'), self._delayed(i, n1)))
            self.assertTrue(jnp.allclose(delay.at('c'), self._delayed(i, n2)))

    def test_delay1(self):
        a = bc.State(bc.random.random(10, 20))
        delay = bc.Delay(a.value)
        delay.register_entry('a', 1.)
        delay.register_entry('b', 2.)
        delay.register_entry('c', None)

        delay.init_state()
        # Registering an already-used entry name must be rejected.
        with self.assertRaises(KeyError):
            delay.register_entry('c', 10.)

    def test_rotation_delay(self):
        rotation_delay = bc.Delay(jnp.ones((1,)))
        # (delay time, expected delay in steps). The step counts presumably
        # rely on the default time step (dt = 0.1) -- TODO confirm.
        t0 = 0.
        t1, n1 = 1., 10
        t2, n2 = 2., 20

        rotation_delay.register_entry('a', t0)
        rotation_delay.register_entry('b', t1)
        rotation_delay.register_entry('c2', 1.9)
        rotation_delay.register_entry('c', t2)

        rotation_delay.init_state()
        self._assert_delayed_outputs(rotation_delay, n1, n2)

    def test_concat_delay(self):
        concat_delay = bc.Delay(jnp.ones([1]), method='concat')
        t0 = 0.
        t1, n1 = 1., 10
        t2, n2 = 2., 20

        concat_delay.register_entry('a', t0)
        concat_delay.register_entry('b', t1)
        concat_delay.register_entry('c', t2)

        concat_delay.init_state()
        self._assert_delayed_outputs(concat_delay, n1, n2)

    def test_rotation_and_concat_delay(self):
        # Both buffer methods must produce identical outputs for every entry.
        rotation_delay = bc.Delay(jnp.ones((1,)))
        concat_delay = bc.Delay(jnp.ones([1]), method='concat')
        t0 = 0.
        t1 = 1.
        t2 = 2.

        for delay in (rotation_delay, concat_delay):
            delay.register_entry('a', t0)
            delay.register_entry('b', t1)
            delay.register_entry('c', t2)

        rotation_delay.init_state()
        concat_delay.init_state()

        for i in range(100):
            bc.environ.set(i=i)
            new = jnp.ones((1,)) * i
            rotation_delay(new)
            concat_delay(new)
            for name in ('a', 'b', 'c'):
                self.assertTrue(jnp.allclose(rotation_delay.at(name), concat_delay.at(name)))
111
+
112
+
113
class TestModule(unittest.TestCase):
    """Smoke test for ``Module.states()`` on a nested module tree."""

    def test_states(self):
        # Leaf module holding two states.
        class A(bc.Module):
            def __init__(self):
                super().__init__()
                self.a = bc.State(bc.random.random(10, 20))
                self.b = bc.State(bc.random.random(10, 20))

        # Parent module holding a child module and one state of its own.
        class B(bc.Module):
            def __init__(self):
                super().__init__()
                self.a = A()
                self.b = bc.State(bc.random.random(10, 20))

        module = B()
        print()
        # Each collection is printed twice to show repeated calls side by side.
        for call_kwargs in ({}, {}, {'level': 0}, {'level': 0}):
            print(module.states(**call_kwargs))
133
+
brainstate/_state.py ADDED
@@ -0,0 +1,378 @@
1
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ import contextlib
17
+ import threading
18
+ from typing import Any, Tuple, Dict, List, Callable
19
+
20
+ import jax
21
+ import numpy as np
22
+ from jax.api_util import shaped_abstractify
23
+ from jax.extend import source_info_util
24
+
25
+ from .util import DictManager
26
+
27
PyTree = Any
# Largest int32 value. ``State.tree_unflatten`` assigns it as the ``_level`` of
# rebuilt states, and ``_level`` is used as a slice start on the trace stack.
# It must therefore be a plain ``int``: ``np.iinfo(np.int32)`` itself is an
# ``iinfo`` object and raises TypeError when used as a slice index.
max_int = np.iinfo(np.int32).max
29
+
30
# Public names exported by ``from ... import *`` on this module.
__all__ = [
    'State',
    'ShortTermState',
    'LongTermState',
    'ParamState',
    'StateDictManager',
    'visible_state_dict',
    'check_state_value_tree',
]
35
+
36
+ _pytree_registered_objects = set()
37
+
38
+
39
+ def _register_pytree_cls(cls):
40
+ if cls not in _pytree_registered_objects:
41
+ jax.tree_util.register_pytree_node_class(cls)
42
+ _pytree_registered_objects.add(cls)
43
+
44
+
45
+ # The global state of the state stack is accessed by a thread-local object.
46
+ # This allows concurrent tracing in separate threads; passing traced objects
47
+ # between threads is forbidden.
48
class ThreadLocalStack(threading.local):
    """Per-thread container for the stack of active ``StateTrace`` objects.

    Because this subclasses ``threading.local``, every thread that touches the
    instance sees its own freshly initialised ``stack``. Tracing in separate
    threads therefore does not interfere. Passing traced objects between
    threads is not supported.
    """

    def __init__(self):
        # Traces are appended on ``StateTrace.__enter__`` and popped on exit,
        # so the innermost active trace is last.
        self.stack: List['StateTrace'] = []


# Module-wide handle; each thread gets an independent ``stack`` from it.
thread_local_stack = ThreadLocalStack()
54
+
55
+ _global_context_to_check_state_tree = [False]
56
+
57
+
58
+ @contextlib.contextmanager
59
+ def check_state_value_tree() -> None:
60
+ """
61
+ The contex manager to check weather the tree structure of the state value keeps consistently.
62
+ """
63
+ try:
64
+ _global_context_to_check_state_tree.append(True)
65
+ yield
66
+ finally:
67
+ _global_context_to_check_state_tree.pop()
68
+
69
+
70
class State(object):
    """
    The pointer to specify the dynamical data.

    To implement a new subclass of :py:class:`~.State`, you only need to inherit this class:

    Example::

      class MyState(State):
        pass

    The typical examples of :py:class:`~.State` subclass are:

    - :py:class:`~.ShortTermState`: The short-term state, which is used to store the short-term data in the program.
    - :py:class:`~.LongTermState`: The long-term state, which is used to store the long-term data in the program.
    - :py:class:`~.ParamState`: The parameter state, which is used to store the parameters in the program.
    - :py:class:`~.RandomState`: The random generator state, which is used to store the random key in the program.

    Reads and writes of :py:attr:`value` are reported to every active
    ``StateTrace`` whose position in the thread-local trace stack is at or
    above this state's creation level.

    Args:
      value: PyTree. It can be anything as a pyTree. If another ``State`` is
        given, its current value is used.
    """
    __module__ = 'brainstate'
    __slots__ = ('_value', '_tree', '_level', '_source_info', '_check_tree')

    def __init__(self, value: PyTree):
        # Unwrap a State so that states never contain other states.
        if isinstance(value, State):
            value = value.value
        self._value = value
        # Structure of the initial value, compared against new values in
        # ``_check_value`` when tree checking is enabled.
        self._tree = jax.tree.structure(value)
        self._check_tree = False
        # Trace-stack depth at creation: only traces at index >= this level
        # observe reads and writes of this state.
        self._level = len(thread_local_stack.stack)
        # Where this state was created; used to annotate raised errors.
        self._source_info = source_info_util.current()

    @property
    def value(self) -> PyTree:
        """
        The data and its value.

        Reading registers this state with every active trace at or above its
        level. A trace may replace the stored value with a traced placeholder,
        which is then what gets returned.
        """
        self._check_if_deleted()
        for trace in thread_local_stack.stack[self._level:]:
            trace.read_its_value(self)
        return self._value

    @value.setter
    def value(self, v) -> None:
        """
        Set the value of the state.

        Args:
          v: The value. A ``State`` is unwrapped to its value first.

        Raises:
          ValueError: if tree checking is enabled and ``v`` has a different
            pytree structure from the original value.
        """
        v = v.value if isinstance(v, State) else v
        self._check_value(v)
        # Record the write on traces before overwriting the stored value.
        for trace in thread_local_stack.stack[self._level:]:
            trace.write_its_value(self)
        self._value = v

    def _check_value(self, v):
        # Checking is on per instance (``_check_tree``) or globally via
        # ``check_state_value_tree()``.
        if self._check_tree or _global_context_to_check_state_tree[-1]:
            in_tree = jax.tree_util.tree_structure(v)
            if in_tree != self._tree:
                self._raise_error_with_source_info(
                    ValueError(f'The given value {in_tree} does not '
                               f'match with the origin tree structure '
                               f'{self._tree}.')
                )

    def _raise_error_with_source_info(self, error: Exception):
        # Raise ``error`` attributed to the code location that created this state.
        name_stack = source_info_util.current_name_stack() + self.source_info.name_stack
        with source_info_util.user_context(self.source_info.traceback, name_stack=name_stack):
            raise error

    def _check_if_deleted(self):
        # Hook for subclasses; the base state is never considered deleted.
        pass

    @property
    def source_info(self) -> source_info_util.SourceInfo:
        """
        The source information of the state, which can be used to locate the
        code where the state was defined.

        Returns:
          The source information.
        """
        return self._source_info

    def tree_flatten(self):
        """Flattens this variable.

        Returns:
          A pair where the first element is a tuple holding the state's value
          (the single child) and the second element is the auxiliary data
          ``(level,)``.
        """
        return (self._value,), (self._level,)

    @classmethod
    def tree_unflatten(cls, aux_data, flat_contents):
        """Reconstructs a variable from the aux_data and the leaves.

        The level stored in ``aux_data`` is intentionally not restored. The new
        state's level is set to ``max_int``, which is meant to exceed any
        realistic trace-stack depth so that the rebuilt copy is not reported
        to active traces.

        Args:
          aux_data: Auxiliary data produced by :py:meth:`tree_flatten` (ignored).
          flat_contents: A sequence whose first element is the state's value.

        Returns:
          The variable.
        """
        del aux_data  # the original level is deliberately discarded
        self = cls(flat_contents[0])
        self._level = max_int
        return self

    def __repr__(self):
        leaves, tree = jax.tree.flatten(self._value)
        # Leaves need not be arrays (e.g. Python scalars). ``np.shape`` and
        # ``np.asarray(...).dtype`` cover those instead of raising AttributeError.
        leaves_info = [
            ShapeDtype(np.shape(leaf),
                       leaf.dtype if hasattr(leaf, 'dtype') else np.asarray(leaf).dtype)
            for leaf in leaves
        ]
        tree_info = jax.tree.unflatten(tree, leaves_info)
        return f'{self.__class__.__name__}({tree_info})'
195
+
196
+
197
class ShapeDtype:
    """Placeholder for an array leaf that renders as ``<dtype>[<dims>]``.

    Used by ``State.__repr__`` to summarise values without printing data.
    """

    def __init__(self, shape, dtype):
        self.shape, self.dtype = shape, dtype

    def __repr__(self):
        return f'{self.dtype}{list(self.shape)}'
204
+
205
+
206
class ShortTermState(State):
    """
    State for short-lived data of a program.

    Example: the gradients computed during a training step.
    """
    __module__ = 'brainstate'


class LongTermState(State):
    """
    State for long-lived data of a program.

    Example: the weights of a model during training.
    """
    __module__ = 'brainstate'


class ParamState(LongTermState):
    """
    Long-term state holding the trainable parameters of a model.
    """
    __module__ = 'brainstate'
232
+
233
+
234
class StateDictManager(DictManager):
    """
    Dictionary of :py:class:`~.State` objects used by a program.

    Behaves like a python dict (via ``DictManager``). Every element must be a
    :py:class:`~.State`, and the helpers below read or write the wrapped values.
    """

    __module__ = 'brainstate'

    def assign_values(self, *args: Dict) -> None:
        """
        Write values into the managed states.

        Each positional argument is a dict mapping keys of this manager to the
        new value for the corresponding state.
        """
        for mapping in args:
            assert isinstance(mapping, dict), 'Must be an instance of dict.'
            for key, new_value in mapping.items():
                self._set_elem(key, new_value)

    def split_values(self, *filters: type) -> Tuple[Dict, ...]:
        """
        Partition the state values by state type.

        Returns ``len(filters) + 1`` ``DictManager`` objects. A state goes to
        the bucket of the first filter it is an instance of; states that match
        no filter go to the last bucket.
        """
        buckets = tuple(DictManager() for _ in range(len(filters) + 1))
        for key, st in self.items():
            target = next(
                (idx for idx, filt in enumerate(filters) if isinstance(st, filt)),
                len(filters),
            )
            buckets[target][key] = st.value
        return buckets

    def collect_values(self) -> Dict:
        """
        Return a ``DictManager`` mapping every key to its state's current value.
        """
        collected = DictManager()
        for key, st in self.items():
            collected[key] = st.value
        return collected

    def split(self, first: type, *others: type) -> Tuple['StateDictManager', ...]:
        # Delegates to ``DictManager.split``; overridden to narrow the return annotation.
        return super().split(first, *others)

    def to_dict_values(self) -> Dict:
        """
        Return a plain ``dict`` mapping every key to its state's current value.
        """
        return {key: st.value for key, st in self.items()}

    def _check_elem(self, elem):
        assert isinstance(elem, State), f'must be instance of {State}'

    def _set_elem(self, key: Any, value: Any) -> None:
        # Assign through ``State.value`` so tracing and tree checks apply.
        self[key].value = value
289
+
290
+
291
class visible_state_dict(StateDictManager):
    """
    A :py:class:`~.StateDictManager` whose elements are visible to the
    ``.states()`` collection functions.
    """
296
+
297
+
298
class StateTrace(object):
    """
    Records every :py:class:`~.State` read or written while this trace is active.

    Used as a context manager: entering pushes the trace onto the thread-local
    trace stack and exiting pops it. Optionally, each newly seen state's value
    is replaced by a placeholder created with ``new_arg``; the original values
    can be put back with :py:meth:`recovery_original_values`.
    """

    def __init__(self, new_arg: Callable = None):
        # Parallel lists: ``states[k]`` was recorded with access kind
        # ``types[k]`` ('read' or 'write') and held ``_org_values[k]``
        # before this trace first touched it.
        self.states: List[State] = []
        self.types: List[str] = []
        self._id2index = {}
        self._org_values = []
        self._jax_trace_new_arg = new_arg
        self._written_ids = set()

    def set_new_arg(self, new_arg: Callable) -> None:
        self._jax_trace_new_arg = new_arg

    def new_arg(self, state: State) -> None:
        # Replace each leaf of the state's value with ``new_arg`` applied to its
        # abstract shape/dtype; a no-op when no ``new_arg`` callback is set.
        make_arg = self._jax_trace_new_arg
        if make_arg is None:
            return
        state._value = jax.tree.map(lambda leaf: make_arg(shaped_abstractify(leaf)), state._value)

    def __enter__(self) -> 'StateTrace':
        thread_local_stack.stack.append(self)
        return self

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        thread_local_stack.stack.pop()

    def read_its_value(self, state: State) -> None:
        """
        Record a read of ``state``. Only the first access of each state is recorded.

        Args:
          state: The state.
        """
        key = id(state)
        if key in self._id2index:
            return
        self._id2index[key] = len(self.states)
        self.states.append(state)
        self.types.append('read')
        self._org_values.append(state._value)  # saved before ``new_arg`` replaces it
        self.new_arg(state)

    def write_its_value(self, state: State) -> None:
        """
        Record a write of ``state``, upgrading its recorded kind to 'write'.

        Args:
          state: The state.
        """
        key = id(state)
        if key not in self._id2index:
            self.read_its_value(state)
        if key in self._written_ids:
            return
        self.types[self._id2index[key]] = 'write'
        self._written_ids.add(key)

    def collect_values(self, *categories: str) -> Tuple:
        """
        Collect the current values of recorded states whose kind is in ``categories``.

        Args:
          *categories: Access kinds to include ('read' and/or 'write').

        Returns:
          results: The values, in first-access order.
        """
        return tuple(st.value for st, kind in zip(self.states, self.types) if kind in categories)

    def recovery_original_values(self) -> None:
        """
        Restore every recorded state to the value it held when first recorded.
        """
        for st, original in zip(self.states, self._org_values):
            st._value = original
@@ -0,0 +1,41 @@
1
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+
17
+ import unittest
18
+
19
+ import brainstate as bc
20
+
21
+
22
class TestStateSourceInfo(unittest.TestCase):
    """Checks that a ``State`` records where it was created."""

    def test_state_source_info(self):
        state = bc.State(bc.random.randn(10))
        # The public property must expose the info captured at construction.
        self.assertIsNotNone(state.source_info)
        self.assertIs(state.source_info, state._source_info)
27
+
28
+
29
class TestStateRepr(unittest.TestCase):
    """Checks that ``State.__repr__`` summarises values as ``dtype[shape]``."""

    def test_state_repr(self):
        # Single array leaf.
        text = repr(bc.State(bc.random.randn(10)))
        self.assertTrue(text.startswith('State('))
        self.assertTrue(text.endswith(')'))
        self.assertEqual(text.count('[10]'), 1)

        # Dict of arrays: keys are kept, each leaf is summarised.
        text2 = repr(bc.State({'a': bc.random.randn(10), 'b': bc.random.randn(10)}))
        self.assertTrue(text2.startswith('State({'))
        self.assertIn("'a'", text2)
        self.assertIn("'b'", text2)
        self.assertEqual(text2.count('[10]'), 2)

        # List of arrays.
        text3 = repr(bc.State([bc.random.randn(10), bc.random.randn(10)]))
        self.assertTrue(text3.startswith('State(['))
        self.assertEqual(text3.count('[10]'), 2)
brainstate/_utils.py ADDED
@@ -0,0 +1,21 @@
1
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
def set_module_as(module: str):
    """
    Build a decorator that overwrites the ``__module__`` attribute of its target.

    This lets an object defined in a private submodule report a different
    (e.g. public) module path.

    Args:
      module: The module name to assign.

    Returns:
      A decorator that sets ``fun.__module__ = module`` and returns ``fun``.
    """
    from typing import Callable

    def wrapper(fun: Callable) -> Callable:
        fun.__module__ = module
        return fun

    return wrapper