brainstate 0.1.10-py2.py3-none-any.whl → 0.2.0-py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (163)
  1. brainstate/__init__.py +130 -19
  2. brainstate/_compatible_import.py +201 -9
  3. brainstate/_compatible_import_test.py +681 -0
  4. brainstate/_deprecation.py +210 -0
  5. brainstate/_deprecation_test.py +2319 -0
  6. brainstate/{util/error.py → _error.py} +10 -20
  7. brainstate/_state.py +94 -47
  8. brainstate/_state_test.py +1 -1
  9. brainstate/_utils.py +1 -1
  10. brainstate/environ.py +1279 -347
  11. brainstate/environ_test.py +1187 -26
  12. brainstate/graph/__init__.py +6 -13
  13. brainstate/graph/_node.py +240 -0
  14. brainstate/graph/_node_test.py +589 -0
  15. brainstate/graph/{_graph_operation.py → _operation.py} +632 -746
  16. brainstate/graph/_operation_test.py +1147 -0
  17. brainstate/mixin.py +1209 -141
  18. brainstate/mixin_test.py +991 -51
  19. brainstate/nn/__init__.py +74 -72
  20. brainstate/nn/_activations.py +587 -295
  21. brainstate/nn/_activations_test.py +109 -86
  22. brainstate/nn/_collective_ops.py +393 -274
  23. brainstate/nn/_collective_ops_test.py +746 -15
  24. brainstate/nn/_common.py +114 -66
  25. brainstate/nn/_common_test.py +154 -0
  26. brainstate/nn/_conv.py +1652 -143
  27. brainstate/nn/_conv_test.py +838 -227
  28. brainstate/nn/_delay.py +15 -28
  29. brainstate/nn/_delay_test.py +25 -20
  30. brainstate/nn/_dropout.py +359 -167
  31. brainstate/nn/_dropout_test.py +429 -52
  32. brainstate/nn/_dynamics.py +14 -90
  33. brainstate/nn/_dynamics_test.py +1 -12
  34. brainstate/nn/_elementwise.py +492 -313
  35. brainstate/nn/_elementwise_test.py +806 -145
  36. brainstate/nn/_embedding.py +369 -19
  37. brainstate/nn/_embedding_test.py +156 -0
  38. brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +10 -16
  39. brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +6 -5
  40. brainstate/nn/{_linear_mv.py → _event_linear.py} +2 -2
  41. brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +6 -5
  42. brainstate/nn/_exp_euler.py +200 -38
  43. brainstate/nn/_exp_euler_test.py +350 -8
  44. brainstate/nn/_linear.py +391 -71
  45. brainstate/nn/_linear_test.py +427 -59
  46. brainstate/nn/_metrics.py +1070 -0
  47. brainstate/nn/_metrics_test.py +611 -0
  48. brainstate/nn/_module.py +10 -3
  49. brainstate/nn/_module_test.py +1 -1
  50. brainstate/nn/_normalizations.py +688 -329
  51. brainstate/nn/_normalizations_test.py +663 -37
  52. brainstate/nn/_paddings.py +1020 -0
  53. brainstate/nn/_paddings_test.py +723 -0
  54. brainstate/nn/_poolings.py +1404 -342
  55. brainstate/nn/_poolings_test.py +828 -92
  56. brainstate/nn/{_rate_rnns.py → _rnns.py} +446 -54
  57. brainstate/nn/_rnns_test.py +593 -0
  58. brainstate/nn/_utils.py +132 -5
  59. brainstate/nn/_utils_test.py +402 -0
  60. brainstate/{init/_random_inits.py → nn/init.py} +301 -45
  61. brainstate/{init/_random_inits_test.py → nn/init_test.py} +51 -20
  62. brainstate/random/__init__.py +247 -1
  63. brainstate/random/_rand_funs.py +668 -346
  64. brainstate/random/_rand_funs_test.py +74 -1
  65. brainstate/random/_rand_seed.py +541 -76
  66. brainstate/random/_rand_seed_test.py +1 -1
  67. brainstate/random/_rand_state.py +601 -393
  68. brainstate/random/_rand_state_test.py +551 -0
  69. brainstate/transform/__init__.py +59 -0
  70. brainstate/transform/_ad_checkpoint.py +176 -0
  71. brainstate/{compile → transform}/_ad_checkpoint_test.py +1 -1
  72. brainstate/{augment → transform}/_autograd.py +360 -113
  73. brainstate/{augment → transform}/_autograd_test.py +2 -2
  74. brainstate/transform/_conditions.py +316 -0
  75. brainstate/{compile → transform}/_conditions_test.py +11 -11
  76. brainstate/{compile → transform}/_error_if.py +22 -20
  77. brainstate/{compile → transform}/_error_if_test.py +1 -1
  78. brainstate/transform/_eval_shape.py +145 -0
  79. brainstate/{augment → transform}/_eval_shape_test.py +1 -1
  80. brainstate/{compile → transform}/_jit.py +99 -46
  81. brainstate/{compile → transform}/_jit_test.py +3 -3
  82. brainstate/{compile → transform}/_loop_collect_return.py +219 -80
  83. brainstate/{compile → transform}/_loop_collect_return_test.py +1 -1
  84. brainstate/{compile → transform}/_loop_no_collection.py +133 -34
  85. brainstate/{compile → transform}/_loop_no_collection_test.py +2 -2
  86. brainstate/transform/_make_jaxpr.py +2016 -0
  87. brainstate/transform/_make_jaxpr_test.py +1510 -0
  88. brainstate/transform/_mapping.py +529 -0
  89. brainstate/transform/_mapping_test.py +194 -0
  90. brainstate/{compile → transform}/_progress_bar.py +78 -25
  91. brainstate/{augment → transform}/_random.py +65 -45
  92. brainstate/{compile → transform}/_unvmap.py +102 -5
  93. brainstate/transform/_util.py +286 -0
  94. brainstate/typing.py +594 -61
  95. brainstate/typing_test.py +780 -0
  96. brainstate/util/__init__.py +9 -32
  97. brainstate/util/_others.py +1025 -0
  98. brainstate/util/_others_test.py +962 -0
  99. brainstate/util/_pretty_pytree.py +1301 -0
  100. brainstate/util/_pretty_pytree_test.py +675 -0
  101. brainstate/util/{pretty_repr.py → _pretty_repr.py} +161 -27
  102. brainstate/util/_pretty_repr_test.py +696 -0
  103. brainstate/util/filter.py +557 -81
  104. brainstate/util/filter_test.py +912 -0
  105. brainstate/util/struct.py +769 -382
  106. brainstate/util/struct_test.py +602 -0
  107. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/METADATA +34 -17
  108. brainstate-0.2.0.dist-info/RECORD +111 -0
  109. brainstate/augment/__init__.py +0 -30
  110. brainstate/augment/_eval_shape.py +0 -99
  111. brainstate/augment/_mapping.py +0 -1060
  112. brainstate/augment/_mapping_test.py +0 -597
  113. brainstate/compile/__init__.py +0 -38
  114. brainstate/compile/_ad_checkpoint.py +0 -204
  115. brainstate/compile/_conditions.py +0 -256
  116. brainstate/compile/_make_jaxpr.py +0 -888
  117. brainstate/compile/_make_jaxpr_test.py +0 -156
  118. brainstate/compile/_util.py +0 -147
  119. brainstate/functional/__init__.py +0 -27
  120. brainstate/graph/_graph_node.py +0 -244
  121. brainstate/graph/_graph_node_test.py +0 -73
  122. brainstate/graph/_graph_operation_test.py +0 -563
  123. brainstate/init/__init__.py +0 -26
  124. brainstate/init/_base.py +0 -52
  125. brainstate/init/_generic.py +0 -244
  126. brainstate/init/_regular_inits.py +0 -105
  127. brainstate/init/_regular_inits_test.py +0 -50
  128. brainstate/nn/_inputs.py +0 -608
  129. brainstate/nn/_ltp.py +0 -28
  130. brainstate/nn/_neuron.py +0 -705
  131. brainstate/nn/_neuron_test.py +0 -161
  132. brainstate/nn/_others.py +0 -46
  133. brainstate/nn/_projection.py +0 -486
  134. brainstate/nn/_rate_rnns_test.py +0 -63
  135. brainstate/nn/_readout.py +0 -209
  136. brainstate/nn/_readout_test.py +0 -53
  137. brainstate/nn/_stp.py +0 -236
  138. brainstate/nn/_synapse.py +0 -505
  139. brainstate/nn/_synapse_test.py +0 -131
  140. brainstate/nn/_synaptic_projection.py +0 -423
  141. brainstate/nn/_synouts.py +0 -162
  142. brainstate/nn/_synouts_test.py +0 -57
  143. brainstate/nn/metrics.py +0 -388
  144. brainstate/optim/__init__.py +0 -38
  145. brainstate/optim/_base.py +0 -64
  146. brainstate/optim/_lr_scheduler.py +0 -448
  147. brainstate/optim/_lr_scheduler_test.py +0 -50
  148. brainstate/optim/_optax_optimizer.py +0 -152
  149. brainstate/optim/_optax_optimizer_test.py +0 -53
  150. brainstate/optim/_sgd_optimizer.py +0 -1104
  151. brainstate/random/_random_for_unit.py +0 -52
  152. brainstate/surrogate.py +0 -1957
  153. brainstate/transform.py +0 -23
  154. brainstate/util/caller.py +0 -98
  155. brainstate/util/others.py +0 -540
  156. brainstate/util/pretty_pytree.py +0 -945
  157. brainstate/util/pretty_pytree_test.py +0 -159
  158. brainstate/util/pretty_table.py +0 -2954
  159. brainstate/util/scaling.py +0 -258
  160. brainstate-0.1.10.dist-info/RECORD +0 -130
  161. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/WHEEL +0 -0
  162. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/licenses/LICENSE +0 -0
  163. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,4 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
1
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -22,7 +22,6 @@ from brainstate._utils import set_module_as
22
22
  from ._loop_collect_return import _bounded_while_loop
23
23
  from ._make_jaxpr import StatefulFunction
24
24
  from ._util import wrap_single_fun_in_multi_branches_while_loop as wrap_fn
25
- from ._util import write_back_state_values
26
25
 
27
26
  X = TypeVar('X')
28
27
  Y = TypeVar('Y')
@@ -35,7 +34,7 @@ __all__ = [
35
34
  ]
36
35
 
37
36
 
38
- @set_module_as('brainstate.compile')
37
+ @set_module_as('brainstate.transform')
39
38
  def while_loop(
40
39
  cond_fun: Callable[[T], BooleanNumeric],
41
40
  body_fun: Callable[[T], T],
@@ -50,13 +49,15 @@ def while_loop(
50
49
 
51
50
  while_loop :: (a -> Bool) -> (a -> a) -> a -> a
52
51
 
53
- The semantics of ``while_loop`` are given by this Python implementation::
52
+ The semantics of ``while_loop`` are given by this Python implementation:
54
53
 
55
- def while_loop(cond_fun, body_fun, init_val):
56
- val = init_val
57
- while cond_fun(val):
58
- val = body_fun(val)
59
- return val
54
+ .. code-block:: python
55
+
56
+ >>> def while_loop(cond_fun, body_fun, init_val):
57
+ ... val = init_val
58
+ ... while cond_fun(val):
59
+ ... val = body_fun(val)
60
+ ... return val
60
61
 
61
62
  Unlike that Python version, ``while_loop`` is a JAX primitive and is lowered
62
63
  to a single WhileOp. That makes it useful for reducing compilation times
@@ -75,16 +76,55 @@ def while_loop(
75
76
  ``while_loop`` is not reverse-mode differentiable because XLA computations
76
77
  require static bounds on memory requirements.
77
78
 
78
- Args:
79
- cond_fun: function of type ``a -> Bool``.
80
- body_fun: function of type ``a -> a``.
81
- init_val: value of type ``a``, a type that can be a scalar, array, or any
79
+ Parameters
80
+ ----------
81
+ cond_fun : callable
82
+ Function of type ``a -> Bool``.
83
+ body_fun : callable
84
+ Function of type ``a -> a``.
85
+ init_val : T
86
+ Value of type ``a``, a type that can be a scalar, array, or any
82
87
  pytree (nested Python tuple/list/dict) thereof, representing the initial
83
88
  loop carry value.
84
89
 
85
- Returns:
86
- The output from the final iteration of body_fun, of type ``a``.
87
-
90
+ Returns
91
+ -------
92
+ T
93
+ The output from the final iteration of body_fun, of type ``a``.
94
+
95
+ Examples
96
+ --------
97
+ Basic while loop operation:
98
+
99
+ .. code-block:: python
100
+
101
+ >>> import brainstate
102
+ >>> import jax.numpy as jnp
103
+ >>>
104
+ >>> def cond_fn(val):
105
+ ... return val < 10
106
+ >>>
107
+ >>> def body_fn(val):
108
+ ... return val + 1
109
+ >>>
110
+ >>> result = brainstate.transform.while_loop(cond_fn, body_fn, 0)
111
+ >>> # result will be 10
112
+
113
+ While loop with array state:
114
+
115
+ .. code-block:: python
116
+
117
+ >>> def cond_fn(state):
118
+ ... return jnp.sum(state) < 100
119
+ >>>
120
+ >>> def body_fn(state):
121
+ ... return state * 1.1
122
+ >>>
123
+ >>> init_state = jnp.array([1.0, 2.0, 3.0])
124
+ >>> final_state = brainstate.transform.while_loop(cond_fn, body_fn, init_state)
125
+
126
+ References
127
+ ----------
88
128
  .. _Haskell-like type signature: https://wiki.haskell.org/Type_signature
89
129
  """
90
130
  if not (callable(body_fun) and callable(cond_fun)):
@@ -103,24 +143,28 @@ def while_loop(
103
143
  # evaluate jaxpr
104
144
  stateful_cond = StatefulFunction(cond_fun, name='while:cond').make_jaxpr(init_val)
105
145
  stateful_body = StatefulFunction(body_fun, name='while:body').make_jaxpr(init_val)
106
- if len(stateful_cond.get_write_states()) != 0:
146
+ cond_cache_key = stateful_cond.get_arg_cache_key(init_val)
147
+ body_cache_key = stateful_body.get_arg_cache_key(init_val)
148
+ if len(stateful_cond.get_write_states_by_cache(cond_cache_key)) != 0:
107
149
  raise ValueError("while_loop: cond_fun should not have any write states.")
108
150
 
109
151
  # state trace and state values
110
- state_trace = stateful_cond.get_state_trace() + stateful_body.get_state_trace()
152
+ state_trace = (stateful_cond.get_state_trace_by_cache(cond_cache_key) +
153
+ stateful_body.get_state_trace_by_cache(body_cache_key))
111
154
  read_state_vals = state_trace.get_read_state_values(True)
112
155
  write_state_vals = state_trace.get_write_state_values(True)
113
- new_cond_fn = wrap_fn(stateful_cond, state_trace, read_state_vals, False)
114
- new_body_fn = wrap_fn(stateful_body, state_trace, read_state_vals, True)
156
+ new_cond_fn = wrap_fn(stateful_cond, state_trace, read_state_vals, False, cond_cache_key)
157
+ new_body_fn = wrap_fn(stateful_body, state_trace, read_state_vals, True, body_cache_key)
115
158
 
116
159
  # while_loop
117
160
  state_vals, final_val = jax.lax.while_loop(new_cond_fn, new_body_fn, (write_state_vals, init_val))
118
161
 
119
162
  # write back state values or restore them
120
- write_back_state_values(state_trace, read_state_vals, state_vals)
163
+ state_trace.assign_state_vals_v2(read_state_vals, state_vals)
121
164
  return final_val
122
165
 
123
166
 
167
+ @set_module_as('brainstate.transform')
124
168
  def bounded_while_loop(
125
169
  cond_fun: Callable[[T], BooleanNumeric],
126
170
  body_fun: Callable[[T], T],
@@ -138,18 +182,70 @@ def bounded_while_loop(
138
182
  even if the condition function is never false. The function is implemented
139
183
  using a scan operation, so it is reverse-mode differentiable.
140
184
 
141
- Args:
142
- cond_fun: A function of type ``a -> Bool``.
143
- body_fun: A function of type ``a -> a``.
144
- init_val: The initial value of type ``a``.
145
- max_steps: A bound on the maximum number of steps, after which the loop
185
+ Parameters
186
+ ----------
187
+ cond_fun : callable
188
+ A function of type ``a -> Bool``.
189
+ body_fun : callable
190
+ A function of type ``a -> a``.
191
+ init_val : T
192
+ The initial value of type ``a``.
193
+ max_steps : int
194
+ A bound on the maximum number of steps, after which the loop
146
195
  terminates unconditionally.
147
- base: Run time will increase slightly as `base` increases. Compilation time will
196
+ base : int, default 16
197
+ Run time will increase slightly as `base` increases. Compilation time will
148
198
  decrease substantially as `math.ceil(math.log(max_steps, base))` decreases.
149
199
  (Which happens as `base` increases.)
150
200
 
151
- Returns:
152
- The final value, as if computed by a `lax.while_loop`.
201
+ Returns
202
+ -------
203
+ T
204
+ The final value, as if computed by a `lax.while_loop`.
205
+
206
+ Examples
207
+ --------
208
+ Basic bounded while loop:
209
+
210
+ .. code-block:: python
211
+
212
+ >>> import brainstate
213
+ >>> import jax.numpy as jnp
214
+ >>>
215
+ >>> def cond_fn(val):
216
+ ... return val < 1000 # This might never be false
217
+ >>>
218
+ >>> def body_fn(val):
219
+ ... return val * 2
220
+ >>>
221
+ >>> # Loop will terminate after at most 10 steps
222
+ >>> result = brainstate.transform.bounded_while_loop(
223
+ ... cond_fn, body_fn, 1, max_steps=10
224
+ ... )
225
+
226
+ Bounded while loop with custom base:
227
+
228
+ .. code-block:: python
229
+
230
+ >>> # Use a smaller base for potentially faster compilation
231
+ >>> result = brainstate.transform.bounded_while_loop(
232
+ ... cond_fn, body_fn, 1, max_steps=100, base=8
233
+ ... )
234
+
235
+ Bounded while loop with array state:
236
+
237
+ .. code-block:: python
238
+
239
+ >>> def cond_fn(state):
240
+ ... return jnp.max(state) < 100
241
+ >>>
242
+ >>> def body_fn(state):
243
+ ... return state + jnp.array([1.0, 2.0, 3.0])
244
+ >>>
245
+ >>> init_state = jnp.array([0.0, 0.0, 0.0])
246
+ >>> final_state = brainstate.transform.bounded_while_loop(
247
+ ... cond_fn, body_fn, init_state, max_steps=50
248
+ ... )
153
249
  """
154
250
 
155
251
  # checking
@@ -162,15 +258,18 @@ def bounded_while_loop(
162
258
  # evaluate jaxpr
163
259
  stateful_cond = StatefulFunction(cond_fun, name='bounded_while:cond').make_jaxpr(init_val)
164
260
  stateful_body = StatefulFunction(body_fun, name='bounded_while:body').make_jaxpr(init_val)
165
- if len(stateful_cond.get_write_states()) != 0:
261
+ cond_cache_key = stateful_cond.get_arg_cache_key(init_val)
262
+ body_cache_key = stateful_body.get_arg_cache_key(init_val)
263
+ if len(stateful_cond.get_write_states_by_cache(cond_cache_key)) != 0:
166
264
  raise ValueError("while_loop: cond_fun should not have any write states.")
167
265
 
168
266
  # state trace and state values
169
- state_trace = stateful_cond.get_state_trace() + stateful_body.get_state_trace()
267
+ state_trace = (stateful_cond.get_state_trace(init_val) +
268
+ stateful_body.get_state_trace(init_val))
170
269
  read_state_vals = state_trace.get_read_state_values(True)
171
270
  write_state_vals = state_trace.get_write_state_values(True)
172
- new_cond_fn = wrap_fn(stateful_cond, state_trace, read_state_vals, False)
173
- new_body_fn = wrap_fn(stateful_body, state_trace, read_state_vals, True)
271
+ new_cond_fn = wrap_fn(stateful_cond, state_trace, read_state_vals, False, cond_cache_key)
272
+ new_body_fn = wrap_fn(stateful_body, state_trace, read_state_vals, True, body_cache_key)
174
273
 
175
274
  # initial value
176
275
  init_val = (write_state_vals, init_val)
@@ -180,5 +279,5 @@ def bounded_while_loop(
180
279
  state_vals, val = _bounded_while_loop(new_cond_fn, new_body_fn, init_val, rounded_max_steps, base, None)
181
280
 
182
281
  # write back state values or restore them
183
- write_back_state_values(state_trace, read_state_vals, state_vals)
282
+ state_trace.assign_state_vals_v2(read_state_vals, state_vals)
184
283
  return val
@@ -1,4 +1,4 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
1
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -16,7 +16,7 @@
16
16
 
17
17
  from unittest import TestCase
18
18
 
19
- import brainstate
19
+ import brainstate
20
20
 
21
21
 
22
22
  class TestWhileLoop(TestCase):