brainstate 0.1.10__py2.py3-none-any.whl → 0.2.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (163) hide show
  1. brainstate/__init__.py +130 -19
  2. brainstate/_compatible_import.py +201 -9
  3. brainstate/_compatible_import_test.py +681 -0
  4. brainstate/_deprecation.py +210 -0
  5. brainstate/_deprecation_test.py +2319 -0
  6. brainstate/{util/error.py → _error.py} +10 -20
  7. brainstate/_state.py +94 -47
  8. brainstate/_state_test.py +1 -1
  9. brainstate/_utils.py +1 -1
  10. brainstate/environ.py +1279 -347
  11. brainstate/environ_test.py +1187 -26
  12. brainstate/graph/__init__.py +6 -13
  13. brainstate/graph/_node.py +240 -0
  14. brainstate/graph/_node_test.py +589 -0
  15. brainstate/graph/{_graph_operation.py → _operation.py} +632 -746
  16. brainstate/graph/_operation_test.py +1147 -0
  17. brainstate/mixin.py +1209 -141
  18. brainstate/mixin_test.py +991 -51
  19. brainstate/nn/__init__.py +74 -72
  20. brainstate/nn/_activations.py +587 -295
  21. brainstate/nn/_activations_test.py +109 -86
  22. brainstate/nn/_collective_ops.py +393 -274
  23. brainstate/nn/_collective_ops_test.py +746 -15
  24. brainstate/nn/_common.py +114 -66
  25. brainstate/nn/_common_test.py +154 -0
  26. brainstate/nn/_conv.py +1652 -143
  27. brainstate/nn/_conv_test.py +838 -227
  28. brainstate/nn/_delay.py +15 -28
  29. brainstate/nn/_delay_test.py +25 -20
  30. brainstate/nn/_dropout.py +359 -167
  31. brainstate/nn/_dropout_test.py +429 -52
  32. brainstate/nn/_dynamics.py +14 -90
  33. brainstate/nn/_dynamics_test.py +1 -12
  34. brainstate/nn/_elementwise.py +492 -313
  35. brainstate/nn/_elementwise_test.py +806 -145
  36. brainstate/nn/_embedding.py +369 -19
  37. brainstate/nn/_embedding_test.py +156 -0
  38. brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +10 -16
  39. brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +6 -5
  40. brainstate/nn/{_linear_mv.py → _event_linear.py} +2 -2
  41. brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +6 -5
  42. brainstate/nn/_exp_euler.py +200 -38
  43. brainstate/nn/_exp_euler_test.py +350 -8
  44. brainstate/nn/_linear.py +391 -71
  45. brainstate/nn/_linear_test.py +427 -59
  46. brainstate/nn/_metrics.py +1070 -0
  47. brainstate/nn/_metrics_test.py +611 -0
  48. brainstate/nn/_module.py +10 -3
  49. brainstate/nn/_module_test.py +1 -1
  50. brainstate/nn/_normalizations.py +688 -329
  51. brainstate/nn/_normalizations_test.py +663 -37
  52. brainstate/nn/_paddings.py +1020 -0
  53. brainstate/nn/_paddings_test.py +723 -0
  54. brainstate/nn/_poolings.py +1404 -342
  55. brainstate/nn/_poolings_test.py +828 -92
  56. brainstate/nn/{_rate_rnns.py → _rnns.py} +446 -54
  57. brainstate/nn/_rnns_test.py +593 -0
  58. brainstate/nn/_utils.py +132 -5
  59. brainstate/nn/_utils_test.py +402 -0
  60. brainstate/{init/_random_inits.py → nn/init.py} +301 -45
  61. brainstate/{init/_random_inits_test.py → nn/init_test.py} +51 -20
  62. brainstate/random/__init__.py +247 -1
  63. brainstate/random/_rand_funs.py +668 -346
  64. brainstate/random/_rand_funs_test.py +74 -1
  65. brainstate/random/_rand_seed.py +541 -76
  66. brainstate/random/_rand_seed_test.py +1 -1
  67. brainstate/random/_rand_state.py +601 -393
  68. brainstate/random/_rand_state_test.py +551 -0
  69. brainstate/transform/__init__.py +59 -0
  70. brainstate/transform/_ad_checkpoint.py +176 -0
  71. brainstate/{compile → transform}/_ad_checkpoint_test.py +1 -1
  72. brainstate/{augment → transform}/_autograd.py +360 -113
  73. brainstate/{augment → transform}/_autograd_test.py +2 -2
  74. brainstate/transform/_conditions.py +316 -0
  75. brainstate/{compile → transform}/_conditions_test.py +11 -11
  76. brainstate/{compile → transform}/_error_if.py +22 -20
  77. brainstate/{compile → transform}/_error_if_test.py +1 -1
  78. brainstate/transform/_eval_shape.py +145 -0
  79. brainstate/{augment → transform}/_eval_shape_test.py +1 -1
  80. brainstate/{compile → transform}/_jit.py +99 -46
  81. brainstate/{compile → transform}/_jit_test.py +3 -3
  82. brainstate/{compile → transform}/_loop_collect_return.py +219 -80
  83. brainstate/{compile → transform}/_loop_collect_return_test.py +1 -1
  84. brainstate/{compile → transform}/_loop_no_collection.py +133 -34
  85. brainstate/{compile → transform}/_loop_no_collection_test.py +2 -2
  86. brainstate/transform/_make_jaxpr.py +2016 -0
  87. brainstate/transform/_make_jaxpr_test.py +1510 -0
  88. brainstate/transform/_mapping.py +529 -0
  89. brainstate/transform/_mapping_test.py +194 -0
  90. brainstate/{compile → transform}/_progress_bar.py +78 -25
  91. brainstate/{augment → transform}/_random.py +65 -45
  92. brainstate/{compile → transform}/_unvmap.py +102 -5
  93. brainstate/transform/_util.py +286 -0
  94. brainstate/typing.py +594 -61
  95. brainstate/typing_test.py +780 -0
  96. brainstate/util/__init__.py +9 -32
  97. brainstate/util/_others.py +1025 -0
  98. brainstate/util/_others_test.py +962 -0
  99. brainstate/util/_pretty_pytree.py +1301 -0
  100. brainstate/util/_pretty_pytree_test.py +675 -0
  101. brainstate/util/{pretty_repr.py → _pretty_repr.py} +161 -27
  102. brainstate/util/_pretty_repr_test.py +696 -0
  103. brainstate/util/filter.py +557 -81
  104. brainstate/util/filter_test.py +912 -0
  105. brainstate/util/struct.py +769 -382
  106. brainstate/util/struct_test.py +602 -0
  107. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/METADATA +34 -17
  108. brainstate-0.2.0.dist-info/RECORD +111 -0
  109. brainstate/augment/__init__.py +0 -30
  110. brainstate/augment/_eval_shape.py +0 -99
  111. brainstate/augment/_mapping.py +0 -1060
  112. brainstate/augment/_mapping_test.py +0 -597
  113. brainstate/compile/__init__.py +0 -38
  114. brainstate/compile/_ad_checkpoint.py +0 -204
  115. brainstate/compile/_conditions.py +0 -256
  116. brainstate/compile/_make_jaxpr.py +0 -888
  117. brainstate/compile/_make_jaxpr_test.py +0 -156
  118. brainstate/compile/_util.py +0 -147
  119. brainstate/functional/__init__.py +0 -27
  120. brainstate/graph/_graph_node.py +0 -244
  121. brainstate/graph/_graph_node_test.py +0 -73
  122. brainstate/graph/_graph_operation_test.py +0 -563
  123. brainstate/init/__init__.py +0 -26
  124. brainstate/init/_base.py +0 -52
  125. brainstate/init/_generic.py +0 -244
  126. brainstate/init/_regular_inits.py +0 -105
  127. brainstate/init/_regular_inits_test.py +0 -50
  128. brainstate/nn/_inputs.py +0 -608
  129. brainstate/nn/_ltp.py +0 -28
  130. brainstate/nn/_neuron.py +0 -705
  131. brainstate/nn/_neuron_test.py +0 -161
  132. brainstate/nn/_others.py +0 -46
  133. brainstate/nn/_projection.py +0 -486
  134. brainstate/nn/_rate_rnns_test.py +0 -63
  135. brainstate/nn/_readout.py +0 -209
  136. brainstate/nn/_readout_test.py +0 -53
  137. brainstate/nn/_stp.py +0 -236
  138. brainstate/nn/_synapse.py +0 -505
  139. brainstate/nn/_synapse_test.py +0 -131
  140. brainstate/nn/_synaptic_projection.py +0 -423
  141. brainstate/nn/_synouts.py +0 -162
  142. brainstate/nn/_synouts_test.py +0 -57
  143. brainstate/nn/metrics.py +0 -388
  144. brainstate/optim/__init__.py +0 -38
  145. brainstate/optim/_base.py +0 -64
  146. brainstate/optim/_lr_scheduler.py +0 -448
  147. brainstate/optim/_lr_scheduler_test.py +0 -50
  148. brainstate/optim/_optax_optimizer.py +0 -152
  149. brainstate/optim/_optax_optimizer_test.py +0 -53
  150. brainstate/optim/_sgd_optimizer.py +0 -1104
  151. brainstate/random/_random_for_unit.py +0 -52
  152. brainstate/surrogate.py +0 -1957
  153. brainstate/transform.py +0 -23
  154. brainstate/util/caller.py +0 -98
  155. brainstate/util/others.py +0 -540
  156. brainstate/util/pretty_pytree.py +0 -945
  157. brainstate/util/pretty_pytree_test.py +0 -159
  158. brainstate/util/pretty_table.py +0 -2954
  159. brainstate/util/scaling.py +0 -258
  160. brainstate-0.1.10.dist-info/RECORD +0 -130
  161. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/WHEEL +0 -0
  162. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/licenses/LICENSE +0 -0
  163. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,194 @@
1
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+
17
+ import unittest
18
+
19
+ import jax
20
+ import jax.numpy as jnp
21
+ from jax import vmap
22
+ from jax.lax import psum, pmean, pmax
23
+
24
+ import brainstate
25
+ import brainstate.transform
26
+ from brainstate._error import BatchAxisError
27
+
28
+
29
+
30
+ class TestMap(unittest.TestCase):
31
+ def test_map(self):
32
+ for dim in [(10,), (10, 10), (10, 10, 10)]:
33
+ x = brainstate.random.rand(*dim)
34
+ r1 = brainstate.transform.map(lambda a: a + 1, x, batch_size=None)
35
+ r2 = brainstate.transform.map(lambda a: a + 1, x, batch_size=2)
36
+ r3 = brainstate.transform.map(lambda a: a + 1, x, batch_size=4)
37
+ r4 = brainstate.transform.map(lambda a: a + 1, x, batch_size=5)
38
+ true_r = x + 1
39
+
40
+ self.assertTrue(jnp.allclose(r1, true_r))
41
+ self.assertTrue(jnp.allclose(r2, true_r))
42
+ self.assertTrue(jnp.allclose(r3, true_r))
43
+ self.assertTrue(jnp.allclose(r4, true_r))
44
+
45
+
46
+ class TestAxisName:
47
+ def test1(self):
48
+ def compute_stats_with_axis_name(x):
49
+ """Compute statistics using named axis operations"""
50
+ # Sum across the named axis 'batch'
51
+ total_sum = psum(x, axis_name='batch')
52
+
53
+ # Mean across the named axis 'batch'
54
+ mean_val = pmean(x, axis_name='batch')
55
+
56
+ # Max across the named axis 'batch'
57
+ max_val = pmax(x, axis_name='batch')
58
+
59
+ return {
60
+ 'sum': total_sum,
61
+ 'mean': mean_val,
62
+ 'max': max_val,
63
+ 'original': x
64
+ }
65
+
66
+ batch_data = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
67
+ print("Input batch data:", batch_data)
68
+
69
+ # vmap with axis name 'batch'
70
+ vectorized_stats_jax = jax.jit(vmap(compute_stats_with_axis_name, axis_name='batch'))
71
+ result_jax = vectorized_stats_jax(batch_data)
72
+
73
+ # vmap with axis name 'batch'
74
+ vectorized_stats = brainstate.transform.vmap(compute_stats_with_axis_name, axis_name='batch')
75
+ result = vectorized_stats(batch_data)
76
+
77
+ # vmap with axis name 'batch'
78
+ vectorized_stats_v2 = brainstate.transform.jit(
79
+ brainstate.transform.vmap(compute_stats_with_axis_name, axis_name='batch')
80
+ )
81
+ result_v2 = vectorized_stats_v2(batch_data)
82
+
83
+ for key in result_jax.keys():
84
+ print(f" {key}: {result_jax[key]}")
85
+ assert jnp.allclose(result_jax[key], result[key]), f"Mismatch in {key}"
86
+ assert jnp.allclose(result_jax[key], result_v2[key]), f"Mismatch in {key}"
87
+
88
+ def test_nested_vmap(self):
89
+ def nested_computation(x):
90
+ """Computation with multiple named axes"""
91
+ # Sum over 'inner' axis, then mean over 'outer' axis
92
+ inner_sum = psum(x, axis_name='inner')
93
+ outer_mean = pmean(inner_sum, axis_name='outer')
94
+ return outer_mean
95
+
96
+ # Create 2D batch data
97
+ data_2d = jnp.arange(12.0).reshape(3, 4) # Shape: [outer_batch=3, inner_batch=4]
98
+ print("Input 2D data shape:", data_2d.shape)
99
+ print("Input 2D data:\n", data_2d)
100
+
101
+ # Nested vmap: first over inner dimension, then outer dimension
102
+ inner_vmap = vmap(nested_computation, axis_name='inner')
103
+ nested_vmap = vmap(inner_vmap, axis_name='outer')
104
+
105
+ result_2d = nested_vmap(data_2d)
106
+ print("Result after nested vmap:", result_2d)
107
+
108
+ inner_vmap_bst = brainstate.transform.vmap(nested_computation, axis_name='inner')
109
+ nested_vmap_bst = brainstate.transform.vmap(inner_vmap_bst, axis_name='outer')
110
+ result_2d_bst = nested_vmap_bst(data_2d)
111
+ print("Result after nested vmap:", result_2d_bst)
112
+
113
+ assert jnp.allclose(result_2d, result_2d_bst)
114
+
115
+ def _gradient_averaging_simulation_bst(self):
116
+ def loss_function(params, x, y):
117
+ """Simple quadratic loss"""
118
+ pred = params * x
119
+ return (pred - y) ** 2
120
+
121
+ def compute_gradients_with_averaging(params, batch_x, batch_y):
122
+ """Compute gradients and average them across the batch"""
123
+ # Compute per-sample gradients
124
+ grad_fn = jax.grad(loss_function, argnums=0)
125
+ per_sample_grads = vmap(grad_fn, in_axes=(None, 0, 0))(params, batch_x, batch_y)
126
+
127
+ # Average gradients across batch using named axis
128
+ def average_grads(grads):
129
+ return pmean(grads, axis_name='batch')
130
+
131
+ # Apply averaging with named axis
132
+ averaged_grads = vmap(average_grads, axis_name='batch')(per_sample_grads)
133
+ return averaged_grads
134
+
135
+ # Example data
136
+ params = 2.0
137
+ batch_x = jnp.array([1.0, 2.0, 3.0, 4.0])
138
+ batch_y = jnp.array([2.0, 4.0, 7.0, 8.0])
139
+
140
+ print("Parameters:", params)
141
+ print("Batch X:", batch_x)
142
+ print("Batch Y:", batch_y)
143
+
144
+ # Compute individual gradients first
145
+ grad_fn = jax.grad(loss_function, argnums=0)
146
+ individual_grads = vmap(grad_fn, in_axes=(None, 0, 0))(params, batch_x, batch_y)
147
+ print("Individual gradients:", individual_grads)
148
+
149
+ # Now compute averaged gradients using axis names
150
+ averaged_grads = compute_gradients_with_averaging(params, batch_x, batch_y)
151
+ print("Averaged gradients:", averaged_grads)
152
+
153
+ return individual_grads, averaged_grads
154
+
155
+ def _gradient_averaging_simulation_jax(self):
156
+ def loss_function(params, x, y):
157
+ """Simple quadratic loss"""
158
+ pred = params * x
159
+ return (pred - y) ** 2
160
+
161
+ def compute_gradients_with_averaging(params, batch_x, batch_y):
162
+ """Compute gradients and average them across the batch"""
163
+ # Compute per-sample gradients
164
+ grad_fn = jax.grad(loss_function, argnums=0)
165
+ per_sample_grads = brainstate.transform.vmap(grad_fn, in_axes=(None, 0, 0))(params, batch_x, batch_y)
166
+
167
+ # Average gradients across batch using named axis
168
+ def average_grads(grads):
169
+ return pmean(grads, axis_name='batch')
170
+
171
+ # Apply averaging with named axis
172
+ averaged_grads = brainstate.transform.vmap(average_grads, axis_name='batch')(per_sample_grads)
173
+ return averaged_grads
174
+
175
+ # Example data
176
+ params = 2.0
177
+ batch_x = jnp.array([1.0, 2.0, 3.0, 4.0])
178
+ batch_y = jnp.array([2.0, 4.0, 7.0, 8.0])
179
+
180
+ print("Parameters:", params)
181
+ print("Batch X:", batch_x)
182
+ print("Batch Y:", batch_y)
183
+
184
+ # Compute individual gradients first
185
+ grad_fn = jax.grad(loss_function, argnums=0)
186
+ individual_grads = brainstate.transform.vmap(grad_fn, in_axes=(None, 0, 0))(params, batch_x, batch_y)
187
+ print("Individual gradients:", individual_grads)
188
+
189
+ # Now compute averaged gradients using axis names
190
+ averaged_grads = compute_gradients_with_averaging(params, batch_x, batch_y)
191
+ print("Averaged gradients:", averaged_grads)
192
+
193
+ return individual_grads, averaged_grads
194
+
@@ -1,4 +1,4 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
1
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -41,42 +41,95 @@ class ProgressBar(object):
41
41
 
42
42
  The message displayed in the progress bar can be customized by the following two methods:
43
43
 
44
- 1. By passing a string to the `desc` argument. For example:
44
+ 1. By passing a string to the `desc` argument.
45
+ 2. By passing a tuple with a string and a callable function to the `desc` argument. The callable
46
+ function should take a dictionary as input and return a dictionary. The returned dictionary
47
+ will be used to format the string.
48
+
49
+ In the second case, ``"i"`` denotes the iteration number and other keys can be computed from the
50
+ loop outputs and carry values.
51
+
52
+ Parameters
53
+ ----------
54
+ freq : int, optional
55
+ The frequency at which to print the progress bar. If not specified, the progress
56
+ bar will be printed every 5% of the total iterations.
57
+ count : int, optional
58
+ The number of times to print the progress bar. If not specified, the progress
59
+ bar will be printed every 5% of the total iterations. Cannot be used together with `freq`.
60
+ desc : str or tuple, optional
61
+ A description of the progress bar. If not specified, a default message will be
62
+ displayed. Can be either a string or a tuple of (format_string, format_function).
63
+ **kwargs
64
+ Additional keyword arguments to pass to the progress bar.
65
+
66
+ Examples
67
+ --------
68
+ Basic usage with default description:
45
69
 
46
70
  .. code-block:: python
47
71
 
48
- ProgressBar(desc="Running 1000 iterations")
72
+ >>> import brainstate
73
+ >>> import jax.numpy as jnp
74
+ >>>
75
+ >>> def loop_fn(x):
76
+ ... return x ** 2
77
+ >>>
78
+ >>> xs = jnp.arange(100)
79
+ >>> pbar = brainstate.transform.ProgressBar()
80
+ >>> results = brainstate.transform.for_loop(loop_fn, xs, pbar=pbar)
49
81
 
50
- 2. By passing a tuple with a string and a callable function to the `desc` argument. The callable
51
- function should take a dictionary as input and return a dictionary. The returned dictionary
52
- will be used to format the string. For example:
82
+ With custom description string:
83
+
84
+ .. code-block:: python
85
+
86
+ >>> pbar = brainstate.transform.ProgressBar(desc="Running 1000 iterations")
87
+ >>> results = brainstate.transform.for_loop(loop_fn, xs, pbar=pbar)
88
+
89
+ With frequency control:
53
90
 
54
91
  .. code-block:: python
55
92
 
56
- a = brainstate.State(1.)
57
- def loop_fn(x):
58
- a.value = x.value + 1.
59
- return jnp.sum(x ** 2)
93
+ >>> # Update every 10 iterations
94
+ >>> pbar = brainstate.transform.ProgressBar(freq=10)
95
+ >>> results = brainstate.transform.for_loop(loop_fn, xs, pbar=pbar)
96
+ >>>
97
+ >>> # Update exactly 20 times during execution
98
+ >>> pbar = brainstate.transform.ProgressBar(count=20)
99
+ >>> results = brainstate.transform.for_loop(loop_fn, xs, pbar=pbar)
60
100
 
61
- pbar = ProgressBar(desc=("Running {i} iterations, loss = {loss}",
62
- lambda i_carray_y: {"i": i_carray_y["i"], "loss": i_carray_y["y"]}))
101
+ With dynamic description based on loop variables:
63
102
 
64
- brainstate.compile.for_loop(loop_fn, xs, pbar=pbar)
103
+ .. code-block:: python
65
104
 
66
- In this example, ``"i"`` denotes the iteration number and ``"loss"`` is computed from the output,
67
- the ``"carry"`` is the dynamic state in the loop, for example ``a.value`` in this case.
105
+ >>> state = brainstate.State(1.0)
106
+ >>>
107
+ >>> def loop_fn(x):
108
+ ... state.value += x
109
+ ... loss = jnp.sum(x ** 2)
110
+ ... return loss
111
+ >>>
112
+ >>> def format_desc(data):
113
+ ... return {"i": data["i"], "loss": data["y"], "state": data["carry"]}
114
+ >>>
115
+ >>> pbar = brainstate.transform.ProgressBar(
116
+ ... desc=("Iteration {i}, loss = {loss:.4f}, state = {state:.2f}", format_desc)
117
+ ... )
118
+ >>> results = brainstate.transform.for_loop(loop_fn, xs, pbar=pbar)
119
+
120
+ With scan function:
68
121
 
122
+ .. code-block:: python
69
123
 
70
- Args:
71
- freq: The frequency at which to print the progress bar. If not specified, the progress
72
- bar will be printed every 5% of the total iterations.
73
- count: The number of times to print the progress bar. If not specified, the progress
74
- bar will be printed every 5% of the total iterations.
75
- desc: A description of the progress bar. If not specified, a default message will be
76
- displayed.
77
- kwargs: Additional keyword arguments to pass to the progress bar.
124
+ >>> def scan_fn(carry, x):
125
+ ... new_carry = carry + x
126
+ ... return new_carry, new_carry ** 2
127
+ >>>
128
+ >>> init_carry = 0.0
129
+ >>> pbar = brainstate.transform.ProgressBar(freq=5)
130
+ >>> final_carry, ys = brainstate.transform.scan(scan_fn, init_carry, xs, pbar=pbar)
78
131
  """
79
- __module__ = "brainstate.compile"
132
+ __module__ = "brainstate.transform"
80
133
 
81
134
  def __init__(
82
135
  self,
@@ -141,7 +194,7 @@ class ProgressBar(object):
141
194
 
142
195
 
143
196
  class ProgressBarRunner(object):
144
- __module__ = "brainstate.compile"
197
+ __module__ = "brainstate.transform"
145
198
 
146
199
  def __init__(
147
200
  self,
@@ -1,4 +1,4 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
1
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -16,6 +16,7 @@
16
16
  import functools
17
17
  from typing import Callable, Sequence, Union
18
18
 
19
+ from brainstate._utils import set_module_as
19
20
  from brainstate.random import DEFAULT, RandomState
20
21
  from brainstate.typing import Missing
21
22
  from brainstate.util import PrettyObject
@@ -27,47 +28,66 @@ __all__ = [
27
28
 
28
29
  class RngRestore(PrettyObject):
29
30
  """
30
- Backup and restore the random state of a sequence of RandomState instances.
31
+ Manage backing up and restoring multiple random states.
31
32
 
32
- This class provides functionality to save the current state of multiple
33
- RandomState instances and later restore them to their saved states.
33
+ Parameters
34
+ ----------
35
+ rngs : Sequence[RandomState]
36
+ Sequence of :class:`~brainstate.random.RandomState` instances whose
37
+ states should be captured and restored.
38
+
39
+ Attributes
40
+ ----------
41
+ rngs : Sequence[RandomState]
42
+ Managed random-state instances.
43
+ rng_keys : list
44
+ Cached keys captured by :meth:`backup` until :meth:`restore` runs.
34
45
 
35
- Attributes:
36
- rngs (Sequence[RandomState]): A sequence of RandomState instances to manage.
37
- rng_keys (list): A list to store the backed up random keys.
46
+ Examples
47
+ --------
48
+ .. code-block:: python
49
+
50
+ >>> import brainstate
51
+ >>>
52
+ >>> rng = brainstate.random.RandomState(0)
53
+ >>> restorer = brainstate.transform.RngRestore([rng])
54
+ >>> restorer.backup()
55
+ >>> _ = rng.random()
56
+ >>> restorer.restore()
38
57
  """
58
+ __module__ = 'brainstate.transform'
39
59
 
40
60
  def __init__(self, rngs: Sequence[RandomState]):
41
61
  """
42
- Initialize the RngRestore instance.
62
+ Initialize a restorer for the provided random states.
43
63
 
44
- Args:
45
- rngs (Sequence[RandomState]): A sequence of RandomState instances
46
- whose states will be managed.
64
+ Parameters
65
+ ----------
66
+ rngs : Sequence[RandomState]
67
+ Random states that will be backed up and restored.
47
68
  """
48
69
  self.rngs: Sequence[RandomState] = rngs
49
70
  self.rng_keys = []
50
71
 
51
72
  def backup(self):
52
73
  """
53
- Backup the current random key of the RandomState instances.
74
+ Cache the current key for each managed random state.
54
75
 
55
- This method saves the current value (state) of each RandomState
56
- instance in the rngs sequence.
76
+ Notes
77
+ -----
78
+ The cached keys persist until :meth:`restore` is called, after which the
79
+ internal cache is cleared.
57
80
  """
58
81
  self.rng_keys = [rng.value for rng in self.rngs]
59
82
 
60
83
  def restore(self):
61
84
  """
62
- Restore the random key of the RandomState instances.
85
+ Restore each random state to the cached key.
63
86
 
64
- This method restores each RandomState instance to its previously
65
- saved state. It raises an error if the number of saved keys doesn't
66
- match the number of RandomState instances.
67
-
68
- Raises:
69
- ValueError: If the number of saved random keys does not match
70
- the number of RandomState instances.
87
+ Raises
88
+ ------
89
+ ValueError
90
+ Raised when the number of stored keys does not match ``rngs``.
71
91
  """
72
92
  if len(self.rng_keys) != len(self.rngs):
73
93
  raise ValueError('The number of random keys does not match the number of random states.')
@@ -95,50 +115,50 @@ def _rng_backup(
95
115
  return wrapper
96
116
 
97
117
 
118
+ @set_module_as('brainstate.transform')
98
119
  def restore_rngs(
99
120
  fn: Callable = Missing(),
100
121
  rngs: Union[RandomState, Sequence[RandomState]] = DEFAULT,
101
122
  ) -> Callable:
102
123
  """
103
- Decorator to backup and restore the random state before and after a function call.
104
-
105
- This function can be used as a decorator or called directly. It ensures that the
106
- random state of the specified RandomState instances is preserved across function calls,
107
- which is useful for maintaining reproducibility in stochastic operations.
124
+ Decorate a function so specified random states are restored after execution.
108
125
 
109
126
  Parameters
110
127
  ----------
111
128
  fn : Callable, optional
112
- The function to be wrapped. If not provided, the decorator can be used
113
- with parameters.
129
+ Function to wrap. When omitted, :func:`restore_rngs` returns a decorator
130
+ preconfigured with ``rngs``.
114
131
  rngs : Union[RandomState, Sequence[RandomState]], optional
115
- The random state(s) to be backed up and restored. This can be a single
116
- RandomState instance or a sequence of RandomState instances. If not provided,
117
- the default RandomState instance will be used.
132
+ Random states whose keys should be backed up before running ``fn`` and
133
+ restored afterwards. Defaults to :data:`brainstate.random.DEFAULT`.
118
134
 
119
135
  Returns
120
136
  -------
121
137
  Callable
122
- If `fn` is provided, returns the wrapped function that will backup the
123
- random state before execution and restore it afterwards.
124
- If `fn` is not provided, returns a partial function that can be used as
125
- a decorator with the specified `rngs`.
138
+ Wrapped callable that restores the random state or a partially applied
139
+ decorator depending on how :func:`restore_rngs` is used.
126
140
 
127
141
  Raises
128
142
  ------
129
143
  AssertionError
130
- If `rngs` is not a RandomState instance or a sequence of RandomState instances.
144
+ If ``rngs`` is neither a :class:`~brainstate.random.RandomState` instance nor
145
+ a sequence of such instances.
131
146
 
132
147
  Examples
133
148
  --------
134
- >>> @restore_rngs
135
- ... def my_random_function():
136
- ... return random.random()
137
-
138
- >>> rng = RandomState(42)
139
- >>> @restore_rngs(rngs=rng)
140
- ... def another_random_function():
141
- ... return rng.random()
149
+ .. code-block:: python
150
+
151
+ >>> import brainstate
152
+ >>>
153
+ >>> rng = brainstate.random.RandomState(0)
154
+ >>>
155
+ >>> @brainstate.transform.restore_rngs(rngs=rng)
156
+ ... def sample_pair():
157
+ ... first = rng.random()
158
+ ... second = rng.random()
159
+ ... return first, second
160
+ >>>
161
+ >>> assert sample_pair()[0] == sample_pair()[0]
142
162
  """
143
163
  if isinstance(fn, Missing):
144
164
  return functools.partial(restore_rngs, rngs=rngs)
@@ -1,4 +1,4 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
1
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -27,8 +27,39 @@ __all__ = [
27
27
  ]
28
28
 
29
29
 
30
- @set_module_as('brainstate.augment')
30
+ @set_module_as('brainstate.transform')
31
31
  def unvmap(x, op: str = 'any'):
32
+ """
33
+ Remove a leading vmap dimension by aggregating batched values.
34
+
35
+ Parameters
36
+ ----------
37
+ x : Any
38
+ Value produced inside a :func:`jax.vmap`-transformed function.
39
+ op : {'all', 'any', 'none', 'max'}, default='any'
40
+ Reduction to apply across the vmapped axis. ``'none'`` returns ``x`` without
41
+ reduction, while ``'max'`` computes the maximum element.
42
+
43
+ Returns
44
+ -------
45
+ Any
46
+ Result of applying the requested reduction with vmap metadata removed.
47
+
48
+ Raises
49
+ ------
50
+ ValueError
51
+ If ``op`` is not one of ``'all'``, ``'any'``, ``'none'``, or ``'max'``.
52
+
53
+ Examples
54
+ --------
55
+ .. code-block:: python
56
+
57
+ >>> import jax.numpy as jnp
58
+ >>> import brainstate
59
+ >>>
60
+ >>> xs = jnp.array([[True, False], [True, True]])
61
+ >>> brainstate.transform.unvmap(xs, op='all')
62
+ """
32
63
  if op == 'all':
33
64
  return unvmap_all(x)
34
65
  elif op == 'any':
@@ -47,7 +78,29 @@ unvmap_all_p = Primitive("unvmap_all")
47
78
 
48
79
 
49
80
  def unvmap_all(x):
50
- """As `jnp.all`, but ignores batch dimensions."""
81
+ """
82
+ Evaluate :func:`jax.numpy.all` while ignoring vmapped batch dimensions.
83
+
84
+ Parameters
85
+ ----------
86
+ x : Any
87
+ Input array or pytree produced under :func:`jax.vmap`.
88
+
89
+ Returns
90
+ -------
91
+ jax.Array
92
+ Scalar boolean result of ``jnp.all(x)``.
93
+
94
+ Examples
95
+ --------
96
+ .. code-block:: python
97
+
98
+ >>> import jax.numpy as jnp
99
+ >>> import brainstate
100
+ >>>
101
+ >>> values = jnp.array([[True, False], [True, True]])
102
+ >>> brainstate.transform.unvmap(values, op='all')
103
+ """
51
104
  return unvmap_all_p.bind(x)
52
105
 
53
106
 
@@ -78,7 +131,29 @@ unvmap_any_p = Primitive("unvmap_any")
78
131
 
79
132
 
80
133
  def unvmap_any(x):
81
- """As `jnp.any`, but ignores batch dimensions."""
134
+ """
135
+ Evaluate :func:`jax.numpy.any` while ignoring vmapped batch dimensions.
136
+
137
+ Parameters
138
+ ----------
139
+ x : Any
140
+ Input array or pytree produced under :func:`jax.vmap`.
141
+
142
+ Returns
143
+ -------
144
+ jax.Array
145
+ Scalar boolean result of ``jnp.any(x)``.
146
+
147
+ Examples
148
+ --------
149
+ .. code-block:: python
150
+
151
+ >>> import jax.numpy as jnp
152
+ >>> import brainstate
153
+ >>>
154
+ >>> values = jnp.array([[False, False], [False, True]])
155
+ >>> brainstate.transform.unvmap(values, op='any')
156
+ """
82
157
  return unvmap_any_p.bind(x)
83
158
 
84
159
 
@@ -109,7 +184,29 @@ unvmap_max_p = Primitive("unvmap_max")
109
184
 
110
185
 
111
186
  def unvmap_max(x):
112
- """As `jnp.max`, but ignores batch dimensions."""
187
+ """
188
+ Evaluate :func:`jax.numpy.max` while ignoring vmapped batch dimensions.
189
+
190
+ Parameters
191
+ ----------
192
+ x : Any
193
+ Input array or pytree produced under :func:`jax.vmap`.
194
+
195
+ Returns
196
+ -------
197
+ jax.Array
198
+ Scalar containing the maximum value of ``x`` with the same dtype.
199
+
200
+ Examples
201
+ --------
202
+ .. code-block:: python
203
+
204
+ >>> import jax.numpy as jnp
205
+ >>> import brainstate
206
+ >>>
207
+ >>> values = jnp.array([[1.0, 2.0], [0.5, 3.5]])
208
+ >>> brainstate.transform.unvmap(values, op='max')
209
+ """
113
210
  return unvmap_max_p.bind(x)
114
211
 
115
212