brainstate 0.1.10__py2.py3-none-any.whl → 0.2.1__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (163)
  1. brainstate/__init__.py +169 -58
  2. brainstate/_compatible_import.py +340 -148
  3. brainstate/_compatible_import_test.py +681 -0
  4. brainstate/_deprecation.py +210 -0
  5. brainstate/_deprecation_test.py +2319 -0
  6. brainstate/{util/error.py → _error.py} +45 -55
  7. brainstate/_state.py +1652 -1605
  8. brainstate/_state_test.py +52 -52
  9. brainstate/_utils.py +47 -47
  10. brainstate/environ.py +1495 -563
  11. brainstate/environ_test.py +1223 -62
  12. brainstate/graph/__init__.py +22 -29
  13. brainstate/graph/_node.py +240 -0
  14. brainstate/graph/_node_test.py +589 -0
  15. brainstate/graph/{_graph_operation.py → _operation.py} +1624 -1738
  16. brainstate/graph/_operation_test.py +1147 -0
  17. brainstate/mixin.py +1433 -365
  18. brainstate/mixin_test.py +1017 -77
  19. brainstate/nn/__init__.py +137 -135
  20. brainstate/nn/_activations.py +1100 -808
  21. brainstate/nn/_activations_test.py +354 -331
  22. brainstate/nn/_collective_ops.py +633 -514
  23. brainstate/nn/_collective_ops_test.py +774 -43
  24. brainstate/nn/_common.py +226 -178
  25. brainstate/nn/_common_test.py +154 -0
  26. brainstate/nn/_conv.py +2010 -501
  27. brainstate/nn/_conv_test.py +849 -238
  28. brainstate/nn/_delay.py +575 -588
  29. brainstate/nn/_delay_test.py +243 -238
  30. brainstate/nn/_dropout.py +618 -426
  31. brainstate/nn/_dropout_test.py +477 -100
  32. brainstate/nn/_dynamics.py +1267 -1343
  33. brainstate/nn/_dynamics_test.py +67 -78
  34. brainstate/nn/_elementwise.py +1298 -1119
  35. brainstate/nn/_elementwise_test.py +830 -169
  36. brainstate/nn/_embedding.py +408 -58
  37. brainstate/nn/_embedding_test.py +156 -0
  38. brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +233 -239
  39. brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +115 -114
  40. brainstate/nn/{_linear_mv.py → _event_linear.py} +83 -83
  41. brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +121 -120
  42. brainstate/nn/_exp_euler.py +254 -92
  43. brainstate/nn/_exp_euler_test.py +377 -35
  44. brainstate/nn/_linear.py +744 -424
  45. brainstate/nn/_linear_test.py +475 -107
  46. brainstate/nn/_metrics.py +1070 -0
  47. brainstate/nn/_metrics_test.py +611 -0
  48. brainstate/nn/_module.py +384 -377
  49. brainstate/nn/_module_test.py +40 -40
  50. brainstate/nn/_normalizations.py +1334 -975
  51. brainstate/nn/_normalizations_test.py +699 -73
  52. brainstate/nn/_paddings.py +1020 -0
  53. brainstate/nn/_paddings_test.py +723 -0
  54. brainstate/nn/_poolings.py +2239 -1177
  55. brainstate/nn/_poolings_test.py +953 -217
  56. brainstate/nn/{_rate_rnns.py → _rnns.py} +946 -554
  57. brainstate/nn/_rnns_test.py +593 -0
  58. brainstate/nn/_utils.py +216 -89
  59. brainstate/nn/_utils_test.py +402 -0
  60. brainstate/{init/_random_inits.py → nn/init.py} +809 -553
  61. brainstate/{init/_random_inits_test.py → nn/init_test.py} +180 -149
  62. brainstate/random/__init__.py +270 -24
  63. brainstate/random/_rand_funs.py +3938 -3616
  64. brainstate/random/_rand_funs_test.py +640 -567
  65. brainstate/random/_rand_seed.py +675 -210
  66. brainstate/random/_rand_seed_test.py +48 -48
  67. brainstate/random/_rand_state.py +1617 -1409
  68. brainstate/random/_rand_state_test.py +551 -0
  69. brainstate/transform/__init__.py +59 -0
  70. brainstate/transform/_ad_checkpoint.py +176 -0
  71. brainstate/{compile → transform}/_ad_checkpoint_test.py +49 -49
  72. brainstate/{augment → transform}/_autograd.py +1025 -778
  73. brainstate/{augment → transform}/_autograd_test.py +1289 -1289
  74. brainstate/transform/_conditions.py +316 -0
  75. brainstate/{compile → transform}/_conditions_test.py +220 -220
  76. brainstate/{compile → transform}/_error_if.py +94 -92
  77. brainstate/{compile → transform}/_error_if_test.py +52 -52
  78. brainstate/transform/_eval_shape.py +145 -0
  79. brainstate/{augment → transform}/_eval_shape_test.py +38 -38
  80. brainstate/{compile → transform}/_jit.py +399 -346
  81. brainstate/{compile → transform}/_jit_test.py +143 -143
  82. brainstate/{compile → transform}/_loop_collect_return.py +675 -536
  83. brainstate/{compile → transform}/_loop_collect_return_test.py +58 -58
  84. brainstate/{compile → transform}/_loop_no_collection.py +283 -184
  85. brainstate/{compile → transform}/_loop_no_collection_test.py +50 -50
  86. brainstate/transform/_make_jaxpr.py +2016 -0
  87. brainstate/transform/_make_jaxpr_test.py +1510 -0
  88. brainstate/transform/_mapping.py +529 -0
  89. brainstate/transform/_mapping_test.py +194 -0
  90. brainstate/{compile → transform}/_progress_bar.py +255 -202
  91. brainstate/{augment → transform}/_random.py +171 -151
  92. brainstate/{compile → transform}/_unvmap.py +256 -159
  93. brainstate/transform/_util.py +286 -0
  94. brainstate/typing.py +837 -304
  95. brainstate/typing_test.py +780 -0
  96. brainstate/util/__init__.py +27 -50
  97. brainstate/util/_others.py +1025 -0
  98. brainstate/util/_others_test.py +962 -0
  99. brainstate/util/_pretty_pytree.py +1301 -0
  100. brainstate/util/_pretty_pytree_test.py +675 -0
  101. brainstate/util/{pretty_repr.py → _pretty_repr.py} +462 -328
  102. brainstate/util/_pretty_repr_test.py +696 -0
  103. brainstate/util/filter.py +945 -469
  104. brainstate/util/filter_test.py +912 -0
  105. brainstate/util/struct.py +910 -523
  106. brainstate/util/struct_test.py +602 -0
  107. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/METADATA +108 -91
  108. brainstate-0.2.1.dist-info/RECORD +111 -0
  109. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/licenses/LICENSE +202 -202
  110. brainstate/augment/__init__.py +0 -30
  111. brainstate/augment/_eval_shape.py +0 -99
  112. brainstate/augment/_mapping.py +0 -1060
  113. brainstate/augment/_mapping_test.py +0 -597
  114. brainstate/compile/__init__.py +0 -38
  115. brainstate/compile/_ad_checkpoint.py +0 -204
  116. brainstate/compile/_conditions.py +0 -256
  117. brainstate/compile/_make_jaxpr.py +0 -888
  118. brainstate/compile/_make_jaxpr_test.py +0 -156
  119. brainstate/compile/_util.py +0 -147
  120. brainstate/functional/__init__.py +0 -27
  121. brainstate/graph/_graph_node.py +0 -244
  122. brainstate/graph/_graph_node_test.py +0 -73
  123. brainstate/graph/_graph_operation_test.py +0 -563
  124. brainstate/init/__init__.py +0 -26
  125. brainstate/init/_base.py +0 -52
  126. brainstate/init/_generic.py +0 -244
  127. brainstate/init/_regular_inits.py +0 -105
  128. brainstate/init/_regular_inits_test.py +0 -50
  129. brainstate/nn/_inputs.py +0 -608
  130. brainstate/nn/_ltp.py +0 -28
  131. brainstate/nn/_neuron.py +0 -705
  132. brainstate/nn/_neuron_test.py +0 -161
  133. brainstate/nn/_others.py +0 -46
  134. brainstate/nn/_projection.py +0 -486
  135. brainstate/nn/_rate_rnns_test.py +0 -63
  136. brainstate/nn/_readout.py +0 -209
  137. brainstate/nn/_readout_test.py +0 -53
  138. brainstate/nn/_stp.py +0 -236
  139. brainstate/nn/_synapse.py +0 -505
  140. brainstate/nn/_synapse_test.py +0 -131
  141. brainstate/nn/_synaptic_projection.py +0 -423
  142. brainstate/nn/_synouts.py +0 -162
  143. brainstate/nn/_synouts_test.py +0 -57
  144. brainstate/nn/metrics.py +0 -388
  145. brainstate/optim/__init__.py +0 -38
  146. brainstate/optim/_base.py +0 -64
  147. brainstate/optim/_lr_scheduler.py +0 -448
  148. brainstate/optim/_lr_scheduler_test.py +0 -50
  149. brainstate/optim/_optax_optimizer.py +0 -152
  150. brainstate/optim/_optax_optimizer_test.py +0 -53
  151. brainstate/optim/_sgd_optimizer.py +0 -1104
  152. brainstate/random/_random_for_unit.py +0 -52
  153. brainstate/surrogate.py +0 -1957
  154. brainstate/transform.py +0 -23
  155. brainstate/util/caller.py +0 -98
  156. brainstate/util/others.py +0 -540
  157. brainstate/util/pretty_pytree.py +0 -945
  158. brainstate/util/pretty_pytree_test.py +0 -159
  159. brainstate/util/pretty_table.py +0 -2954
  160. brainstate/util/scaling.py +0 -258
  161. brainstate-0.1.10.dist-info/RECORD +0 -130
  162. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/WHEEL +0 -0
  163. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/top_level.txt +0 -0
@@ -1,151 +1,171 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- import functools
17
- from typing import Callable, Sequence, Union
18
-
19
- from brainstate.random import DEFAULT, RandomState
20
- from brainstate.typing import Missing
21
- from brainstate.util import PrettyObject
22
-
23
- __all__ = [
24
- 'restore_rngs'
25
- ]
26
-
27
-
28
- class RngRestore(PrettyObject):
29
- """
30
- Backup and restore the random state of a sequence of RandomState instances.
31
-
32
- This class provides functionality to save the current state of multiple
33
- RandomState instances and later restore them to their saved states.
34
-
35
- Attributes:
36
- rngs (Sequence[RandomState]): A sequence of RandomState instances to manage.
37
- rng_keys (list): A list to store the backed up random keys.
38
- """
39
-
40
- def __init__(self, rngs: Sequence[RandomState]):
41
- """
42
- Initialize the RngRestore instance.
43
-
44
- Args:
45
- rngs (Sequence[RandomState]): A sequence of RandomState instances
46
- whose states will be managed.
47
- """
48
- self.rngs: Sequence[RandomState] = rngs
49
- self.rng_keys = []
50
-
51
- def backup(self):
52
- """
53
- Backup the current random key of the RandomState instances.
54
-
55
- This method saves the current value (state) of each RandomState
56
- instance in the rngs sequence.
57
- """
58
- self.rng_keys = [rng.value for rng in self.rngs]
59
-
60
- def restore(self):
61
- """
62
- Restore the random key of the RandomState instances.
63
-
64
- This method restores each RandomState instance to its previously
65
- saved state. It raises an error if the number of saved keys doesn't
66
- match the number of RandomState instances.
67
-
68
- Raises:
69
- ValueError: If the number of saved random keys does not match
70
- the number of RandomState instances.
71
- """
72
- if len(self.rng_keys) != len(self.rngs):
73
- raise ValueError('The number of random keys does not match the number of random states.')
74
- for rng, key in zip(self.rngs, self.rng_keys):
75
- rng.restore_value(key)
76
- self.rng_keys.clear()
77
-
78
-
79
- def _rng_backup(
80
- fn: Callable,
81
- rngs: Union[RandomState, Sequence[RandomState]]
82
- ) -> Callable:
83
- rng_restorer = RngRestore(rngs)
84
-
85
- @functools.wraps(fn)
86
- def wrapper(*args, **kwargs):
87
- # backup the random state
88
- rng_restorer.backup()
89
- # call the function
90
- out = fn(*args, **kwargs)
91
- # restore the random state
92
- rng_restorer.restore()
93
- return out
94
-
95
- return wrapper
96
-
97
-
98
- def restore_rngs(
99
- fn: Callable = Missing(),
100
- rngs: Union[RandomState, Sequence[RandomState]] = DEFAULT,
101
- ) -> Callable:
102
- """
103
- Decorator to backup and restore the random state before and after a function call.
104
-
105
- This function can be used as a decorator or called directly. It ensures that the
106
- random state of the specified RandomState instances is preserved across function calls,
107
- which is useful for maintaining reproducibility in stochastic operations.
108
-
109
- Parameters
110
- ----------
111
- fn : Callable, optional
112
- The function to be wrapped. If not provided, the decorator can be used
113
- with parameters.
114
- rngs : Union[RandomState, Sequence[RandomState]], optional
115
- The random state(s) to be backed up and restored. This can be a single
116
- RandomState instance or a sequence of RandomState instances. If not provided,
117
- the default RandomState instance will be used.
118
-
119
- Returns
120
- -------
121
- Callable
122
- If `fn` is provided, returns the wrapped function that will backup the
123
- random state before execution and restore it afterwards.
124
- If `fn` is not provided, returns a partial function that can be used as
125
- a decorator with the specified `rngs`.
126
-
127
- Raises
128
- ------
129
- AssertionError
130
- If `rngs` is not a RandomState instance or a sequence of RandomState instances.
131
-
132
- Examples
133
- --------
134
- >>> @restore_rngs
135
- ... def my_random_function():
136
- ... return random.random()
137
-
138
- >>> rng = RandomState(42)
139
- >>> @restore_rngs(rngs=rng)
140
- ... def another_random_function():
141
- ... return rng.random()
142
- """
143
- if isinstance(fn, Missing):
144
- return functools.partial(restore_rngs, rngs=rngs)
145
-
146
- if isinstance(rngs, RandomState):
147
- rngs = [rngs]
148
- assert isinstance(rngs, Sequence), 'rngs must be a RandomState or a sequence of RandomState instances.'
149
- for rng in rngs:
150
- assert isinstance(rng, RandomState), 'rngs must be a RandomState or a sequence of RandomState instances.'
151
- return _rng_backup(fn, rngs=rngs)
1
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ import functools
17
+ from typing import Callable, Sequence, Union
18
+
19
+ from brainstate._utils import set_module_as
20
+ from brainstate.random import DEFAULT, RandomState
21
+ from brainstate.typing import Missing
22
+ from brainstate.util import PrettyObject
23
+
24
+ __all__ = [
25
+ 'restore_rngs'
26
+ ]
27
+
28
+
29
class RngRestore(PrettyObject):
    """
    Snapshot the keys of several random states and roll them back later.

    Parameters
    ----------
    rngs : Sequence[RandomState]
        The :class:`~brainstate.random.RandomState` objects to manage.

    Attributes
    ----------
    rngs : Sequence[RandomState]
        The random states handled by this object, stored exactly as given.
    rng_keys : list
        Values recorded by the most recent :meth:`backup`; emptied again once
        :meth:`restore` has run.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>>
        >>> rng = brainstate.random.RandomState(0)
        >>> restorer = brainstate.transform.RngRestore([rng])
        >>> restorer.backup()
        >>> _ = rng.random()
        >>> restorer.restore()
    """
    __module__ = 'brainstate.transform'

    def __init__(self, rngs: Sequence[RandomState]):
        """
        Create a restorer bound to ``rngs``.

        Parameters
        ----------
        rngs : Sequence[RandomState]
            Random states to snapshot and roll back.
        """
        self.rngs: Sequence[RandomState] = rngs
        self.rng_keys = []

    def backup(self):
        """
        Record the current ``value`` of every managed random state.

        Notes
        -----
        Each call replaces whatever was recorded before. The recorded values
        stay available until :meth:`restore` consumes them.
        """
        snapshot = []
        for state in self.rngs:
            snapshot.append(state.value)
        # Assign only after every value was read, so a failure while reading
        # leaves the previous snapshot in place.
        self.rng_keys = snapshot

    def restore(self):
        """
        Write the recorded values back into the managed random states.

        Raises
        ------
        ValueError
            If the number of recorded keys differs from the number of managed
            random states (for example when :meth:`backup` was never called).
        """
        n_saved = len(self.rng_keys)
        n_states = len(self.rngs)
        if n_saved != n_states:
            raise ValueError('The number of random keys does not match the number of random states.')
        for state, saved_key in zip(self.rngs, self.rng_keys):
            state.restore_value(saved_key)
        # The snapshot is single-use: drop it once it has been applied.
        self.rng_keys.clear()
97
+
98
+
99
def _rng_backup(
    fn: Callable,
    rngs: Union[RandomState, Sequence[RandomState]]
) -> Callable:
    """
    Wrap ``fn`` so the given random states are rolled back after every call.

    Parameters
    ----------
    fn : Callable
        Function to wrap.
    rngs : Union[RandomState, Sequence[RandomState]]
        Random state(s) whose keys are saved before ``fn`` runs and written
        back once it finishes. A single ``RandomState`` is treated as a
        one-element sequence.

    Returns
    -------
    Callable
        A wrapper with the same signature and return value as ``fn``.

    Notes
    -----
    The random states are restored even when ``fn`` raises; the exception is
    then propagated unchanged.
    """
    # The annotation allows a bare RandomState, but RngRestore needs a sequence.
    if isinstance(rngs, RandomState):
        rngs = [rngs]

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        # A fresh restorer per call keeps the saved keys local to this
        # invocation. With one shared restorer, a recursive or re-entrant call
        # would overwrite and then clear the outer call's backup, making the
        # outer restore() fail with a ValueError.
        rng_restorer = RngRestore(rngs)
        # backup the random state
        rng_restorer.backup()
        try:
            # call the function
            return fn(*args, **kwargs)
        finally:
            # restore the random state, also on the error path, so a failing
            # call does not leave the random states advanced
            rng_restorer.restore()

    return wrapper
116
+
117
+
118
@set_module_as('brainstate.transform')
def restore_rngs(
    fn: Callable = Missing(),
    rngs: Union[RandomState, Sequence[RandomState]] = DEFAULT,
) -> Callable:
    """
    Decorate a function so specified random states are restored after execution.

    Can be applied directly (``@restore_rngs``) or with arguments
    (``@restore_rngs(rngs=...)``).

    Parameters
    ----------
    fn : Callable, optional
        Function to wrap. When omitted, :func:`restore_rngs` returns a decorator
        preconfigured with ``rngs``.
    rngs : Union[RandomState, Sequence[RandomState]], optional
        Random states whose keys should be backed up before running ``fn`` and
        restored afterwards. Defaults to :data:`brainstate.random.DEFAULT`.

    Returns
    -------
    Callable
        The wrapped callable when ``fn`` is given, otherwise a partially
        applied :func:`restore_rngs` waiting for the function.

    Raises
    ------
    AssertionError
        If ``rngs`` is neither a :class:`~brainstate.random.RandomState` instance nor
        a sequence of such instances.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>>
        >>> rng = brainstate.random.RandomState(0)
        >>>
        >>> @brainstate.transform.restore_rngs(rngs=rng)
        ... def sample_pair():
        ...     first = rng.random()
        ...     second = rng.random()
        ...     return first, second
        >>>
        >>> assert sample_pair()[0] == sample_pair()[0]
    """
    # Called as ``@restore_rngs(rngs=...)``: no function yet, hand back a decorator.
    if isinstance(fn, Missing):
        return functools.partial(restore_rngs, rngs=rngs)

    if isinstance(rngs, RandomState):
        rngs = [rngs]
    # Raise explicitly instead of using ``assert`` statements: those are removed
    # under ``python -O``, which would silently disable this validation. The
    # exception type and message stay the same for existing callers.
    if not isinstance(rngs, Sequence) or not all(isinstance(rng, RandomState) for rng in rngs):
        raise AssertionError('rngs must be a RandomState or a sequence of RandomState instances.')
    return _rng_backup(fn, rngs=rngs)