brainstate 0.1.10__py2.py3-none-any.whl → 0.2.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (163)
  1. brainstate/__init__.py +130 -19
  2. brainstate/_compatible_import.py +201 -9
  3. brainstate/_compatible_import_test.py +681 -0
  4. brainstate/_deprecation.py +210 -0
  5. brainstate/_deprecation_test.py +2319 -0
  6. brainstate/{util/error.py → _error.py} +10 -20
  7. brainstate/_state.py +94 -47
  8. brainstate/_state_test.py +1 -1
  9. brainstate/_utils.py +1 -1
  10. brainstate/environ.py +1279 -347
  11. brainstate/environ_test.py +1187 -26
  12. brainstate/graph/__init__.py +6 -13
  13. brainstate/graph/_node.py +240 -0
  14. brainstate/graph/_node_test.py +589 -0
  15. brainstate/graph/{_graph_operation.py → _operation.py} +632 -746
  16. brainstate/graph/_operation_test.py +1147 -0
  17. brainstate/mixin.py +1209 -141
  18. brainstate/mixin_test.py +991 -51
  19. brainstate/nn/__init__.py +74 -72
  20. brainstate/nn/_activations.py +587 -295
  21. brainstate/nn/_activations_test.py +109 -86
  22. brainstate/nn/_collective_ops.py +393 -274
  23. brainstate/nn/_collective_ops_test.py +746 -15
  24. brainstate/nn/_common.py +114 -66
  25. brainstate/nn/_common_test.py +154 -0
  26. brainstate/nn/_conv.py +1652 -143
  27. brainstate/nn/_conv_test.py +838 -227
  28. brainstate/nn/_delay.py +15 -28
  29. brainstate/nn/_delay_test.py +25 -20
  30. brainstate/nn/_dropout.py +359 -167
  31. brainstate/nn/_dropout_test.py +429 -52
  32. brainstate/nn/_dynamics.py +14 -90
  33. brainstate/nn/_dynamics_test.py +1 -12
  34. brainstate/nn/_elementwise.py +492 -313
  35. brainstate/nn/_elementwise_test.py +806 -145
  36. brainstate/nn/_embedding.py +369 -19
  37. brainstate/nn/_embedding_test.py +156 -0
  38. brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +10 -16
  39. brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +6 -5
  40. brainstate/nn/{_linear_mv.py → _event_linear.py} +2 -2
  41. brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +6 -5
  42. brainstate/nn/_exp_euler.py +200 -38
  43. brainstate/nn/_exp_euler_test.py +350 -8
  44. brainstate/nn/_linear.py +391 -71
  45. brainstate/nn/_linear_test.py +427 -59
  46. brainstate/nn/_metrics.py +1070 -0
  47. brainstate/nn/_metrics_test.py +611 -0
  48. brainstate/nn/_module.py +10 -3
  49. brainstate/nn/_module_test.py +1 -1
  50. brainstate/nn/_normalizations.py +688 -329
  51. brainstate/nn/_normalizations_test.py +663 -37
  52. brainstate/nn/_paddings.py +1020 -0
  53. brainstate/nn/_paddings_test.py +723 -0
  54. brainstate/nn/_poolings.py +1404 -342
  55. brainstate/nn/_poolings_test.py +828 -92
  56. brainstate/nn/{_rate_rnns.py → _rnns.py} +446 -54
  57. brainstate/nn/_rnns_test.py +593 -0
  58. brainstate/nn/_utils.py +132 -5
  59. brainstate/nn/_utils_test.py +402 -0
  60. brainstate/{init/_random_inits.py → nn/init.py} +301 -45
  61. brainstate/{init/_random_inits_test.py → nn/init_test.py} +51 -20
  62. brainstate/random/__init__.py +247 -1
  63. brainstate/random/_rand_funs.py +668 -346
  64. brainstate/random/_rand_funs_test.py +74 -1
  65. brainstate/random/_rand_seed.py +541 -76
  66. brainstate/random/_rand_seed_test.py +1 -1
  67. brainstate/random/_rand_state.py +601 -393
  68. brainstate/random/_rand_state_test.py +551 -0
  69. brainstate/transform/__init__.py +59 -0
  70. brainstate/transform/_ad_checkpoint.py +176 -0
  71. brainstate/{compile → transform}/_ad_checkpoint_test.py +1 -1
  72. brainstate/{augment → transform}/_autograd.py +360 -113
  73. brainstate/{augment → transform}/_autograd_test.py +2 -2
  74. brainstate/transform/_conditions.py +316 -0
  75. brainstate/{compile → transform}/_conditions_test.py +11 -11
  76. brainstate/{compile → transform}/_error_if.py +22 -20
  77. brainstate/{compile → transform}/_error_if_test.py +1 -1
  78. brainstate/transform/_eval_shape.py +145 -0
  79. brainstate/{augment → transform}/_eval_shape_test.py +1 -1
  80. brainstate/{compile → transform}/_jit.py +99 -46
  81. brainstate/{compile → transform}/_jit_test.py +3 -3
  82. brainstate/{compile → transform}/_loop_collect_return.py +219 -80
  83. brainstate/{compile → transform}/_loop_collect_return_test.py +1 -1
  84. brainstate/{compile → transform}/_loop_no_collection.py +133 -34
  85. brainstate/{compile → transform}/_loop_no_collection_test.py +2 -2
  86. brainstate/transform/_make_jaxpr.py +2016 -0
  87. brainstate/transform/_make_jaxpr_test.py +1510 -0
  88. brainstate/transform/_mapping.py +529 -0
  89. brainstate/transform/_mapping_test.py +194 -0
  90. brainstate/{compile → transform}/_progress_bar.py +78 -25
  91. brainstate/{augment → transform}/_random.py +65 -45
  92. brainstate/{compile → transform}/_unvmap.py +102 -5
  93. brainstate/transform/_util.py +286 -0
  94. brainstate/typing.py +594 -61
  95. brainstate/typing_test.py +780 -0
  96. brainstate/util/__init__.py +9 -32
  97. brainstate/util/_others.py +1025 -0
  98. brainstate/util/_others_test.py +962 -0
  99. brainstate/util/_pretty_pytree.py +1301 -0
  100. brainstate/util/_pretty_pytree_test.py +675 -0
  101. brainstate/util/{pretty_repr.py → _pretty_repr.py} +161 -27
  102. brainstate/util/_pretty_repr_test.py +696 -0
  103. brainstate/util/filter.py +557 -81
  104. brainstate/util/filter_test.py +912 -0
  105. brainstate/util/struct.py +769 -382
  106. brainstate/util/struct_test.py +602 -0
  107. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/METADATA +34 -17
  108. brainstate-0.2.0.dist-info/RECORD +111 -0
  109. brainstate/augment/__init__.py +0 -30
  110. brainstate/augment/_eval_shape.py +0 -99
  111. brainstate/augment/_mapping.py +0 -1060
  112. brainstate/augment/_mapping_test.py +0 -597
  113. brainstate/compile/__init__.py +0 -38
  114. brainstate/compile/_ad_checkpoint.py +0 -204
  115. brainstate/compile/_conditions.py +0 -256
  116. brainstate/compile/_make_jaxpr.py +0 -888
  117. brainstate/compile/_make_jaxpr_test.py +0 -156
  118. brainstate/compile/_util.py +0 -147
  119. brainstate/functional/__init__.py +0 -27
  120. brainstate/graph/_graph_node.py +0 -244
  121. brainstate/graph/_graph_node_test.py +0 -73
  122. brainstate/graph/_graph_operation_test.py +0 -563
  123. brainstate/init/__init__.py +0 -26
  124. brainstate/init/_base.py +0 -52
  125. brainstate/init/_generic.py +0 -244
  126. brainstate/init/_regular_inits.py +0 -105
  127. brainstate/init/_regular_inits_test.py +0 -50
  128. brainstate/nn/_inputs.py +0 -608
  129. brainstate/nn/_ltp.py +0 -28
  130. brainstate/nn/_neuron.py +0 -705
  131. brainstate/nn/_neuron_test.py +0 -161
  132. brainstate/nn/_others.py +0 -46
  133. brainstate/nn/_projection.py +0 -486
  134. brainstate/nn/_rate_rnns_test.py +0 -63
  135. brainstate/nn/_readout.py +0 -209
  136. brainstate/nn/_readout_test.py +0 -53
  137. brainstate/nn/_stp.py +0 -236
  138. brainstate/nn/_synapse.py +0 -505
  139. brainstate/nn/_synapse_test.py +0 -131
  140. brainstate/nn/_synaptic_projection.py +0 -423
  141. brainstate/nn/_synouts.py +0 -162
  142. brainstate/nn/_synouts_test.py +0 -57
  143. brainstate/nn/metrics.py +0 -388
  144. brainstate/optim/__init__.py +0 -38
  145. brainstate/optim/_base.py +0 -64
  146. brainstate/optim/_lr_scheduler.py +0 -448
  147. brainstate/optim/_lr_scheduler_test.py +0 -50
  148. brainstate/optim/_optax_optimizer.py +0 -152
  149. brainstate/optim/_optax_optimizer_test.py +0 -53
  150. brainstate/optim/_sgd_optimizer.py +0 -1104
  151. brainstate/random/_random_for_unit.py +0 -52
  152. brainstate/surrogate.py +0 -1957
  153. brainstate/transform.py +0 -23
  154. brainstate/util/caller.py +0 -98
  155. brainstate/util/others.py +0 -540
  156. brainstate/util/pretty_pytree.py +0 -945
  157. brainstate/util/pretty_pytree_test.py +0 -159
  158. brainstate/util/pretty_table.py +0 -2954
  159. brainstate/util/scaling.py +0 -258
  160. brainstate-0.1.10.dist-info/RECORD +0 -130
  161. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/WHEEL +0 -0
  162. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/licenses/LICENSE +0 -0
  163. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,4 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
1
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -24,7 +24,6 @@ from brainstate._compatible_import import Device
24
24
  from brainstate._utils import set_module_as
25
25
  from brainstate.typing import Missing
26
26
  from ._make_jaxpr import StatefulFunction, _ensure_index_tuple
27
- from ._util import write_back_state_values
28
27
 
29
28
  __all__ = ['jit']
30
29
 
@@ -33,6 +32,8 @@ class JittedFunction(Callable):
33
32
  """
34
33
  A wrapped version of ``fun``, set up for just-in-time compilation.
35
34
  """
35
+ __module__ = 'brainstate.transform'
36
+
36
37
  origin_fun: Callable # the original function
37
38
  stateful_fun: StatefulFunction # the stateful function for extracting states
38
39
  jitted_fun: jax.stages.Wrapped # the jitted function
@@ -67,8 +68,8 @@ def _get_jitted_fun(
67
68
  static_argnums=static_argnums,
68
69
  static_argnames=static_argnames,
69
70
  abstracted_axes=abstracted_axes,
70
- cache_type='jit',
71
- name='jit'
71
+ name='jit',
72
+ return_only_write=True
72
73
  )
73
74
  jit_fun = jax.jit(
74
75
  fun.jaxpr_call,
@@ -92,14 +93,14 @@ def _get_jitted_fun(
92
93
  return fun.fun(*args, **params)
93
94
 
94
95
  # compile the function and get the state trace
95
- state_trace = fun.compile_function_and_get_state_trace(*args, **params, return_only_write=True)
96
+ state_trace = fun.get_state_trace(*args, **params, compile_if_miss=True)
96
97
  read_state_vals = state_trace.get_read_state_values(True)
97
98
 
98
99
  # call the jitted function
99
100
  write_state_vals, outs = jit_fun(state_trace.get_state_values(), *args, **params)
100
101
 
101
102
  # write the state values back to the states
102
- write_back_state_values(state_trace, read_state_vals, write_state_vals)
103
+ state_trace.assign_state_vals_v2(read_state_vals, write_state_vals)
103
104
  return outs
104
105
 
105
106
  def clear_cache():
@@ -139,7 +140,7 @@ def _get_jitted_fun(
139
140
  A ``Lowered`` instance representing the lowering.
140
141
  """
141
142
  # compile the function and get the state trace
142
- state_trace = fun.compile_function_and_get_state_trace(*args, **params, return_only_write=True)
143
+ state_trace = fun.get_state_trace(*args, **params, compile_if_miss=True)
143
144
  read_state_vals = state_trace.get_read_state_values(replace_writen=True)
144
145
  write_state_vals = state_trace.get_write_state_values(replace_read=True)
145
146
 
@@ -147,7 +148,7 @@ def _get_jitted_fun(
147
148
  ret = jit_fun.lower(state_trace.get_state_values(), *args, **params).compile()
148
149
 
149
150
  # write the state values back to the states
150
- write_back_state_values(state_trace, read_state_vals, write_state_vals)
151
+ state_trace.assign_state_vals_v2(read_state_vals, write_state_vals)
151
152
  return ret
152
153
 
153
154
  jitted_fun: JittedFunction
@@ -176,7 +177,7 @@ def _get_jitted_fun(
176
177
  return jitted_fun
177
178
 
178
179
 
179
- @set_module_as('brainstate.compile')
180
+ @set_module_as('brainstate.transform')
180
181
  def jit(
181
182
  fun: Callable | Missing = Missing(),
182
183
  in_shardings=sharding_impls.UNSPECIFIED,
@@ -195,9 +196,12 @@ def jit(
195
196
  """
196
197
  Sets up ``fun`` for just-in-time compilation with XLA.
197
198
 
198
- Args:
199
- fun: Function to be jitted.
200
- in_shardings: Pytree of structure matching that of arguments to ``fun``,
199
+ Parameters
200
+ ----------
201
+ fun : callable or Missing, optional
202
+ Function to be jitted.
203
+ in_shardings : pytree, optional
204
+ Pytree of structure matching that of arguments to ``fun``,
201
205
  with all actual arguments replaced by resource assignment specifications.
202
206
  It is also valid to specify a pytree prefix (e.g. one value in place of a
203
207
  whole subtree), in which case the leaves get broadcast to all values in
@@ -208,26 +212,29 @@ def jit(
208
212
  if the sharding cannot be inferred.
209
213
 
210
214
  The valid resource assignment specifications are:
211
- - :py:class:`XLACompatibleSharding`, which will decide how the value
212
- will be partitioned. With this, using a mesh context manager is not
213
- required.
214
- - :py:obj:`None`, will give JAX the freedom to choose whatever sharding
215
- it wants.
216
- For in_shardings, JAX will mark is as replicated but this behavior
217
- can change in the future.
218
- For out_shardings, we will rely on the XLA GSPMD partitioner to
219
- determine the output shardings.
215
+
216
+ - :py:class:`XLACompatibleSharding`, which will decide how the value
217
+ will be partitioned. With this, using a mesh context manager is not
218
+ required.
219
+ - :py:obj:`None`, will give JAX the freedom to choose whatever sharding
220
+ it wants.
221
+ For in_shardings, JAX will mark it as replicated but this behavior
222
+ can change in the future.
223
+ For out_shardings, we will rely on the XLA GSPMD partitioner to
224
+ determine the output shardings.
220
225
 
221
226
  The size of every dimension has to be a multiple of the total number of
222
227
  resources assigned to it. This is similar to pjit's in_shardings.
223
- out_shardings: Like ``in_shardings``, but specifies resource
228
+ out_shardings : pytree, optional
229
+ Like ``in_shardings``, but specifies resource
224
230
  assignment for function outputs. This is similar to pjit's
225
231
  out_shardings.
226
232
 
227
233
  The ``out_shardings`` argument is optional. If not specified, :py:func:`jax.jit`
228
234
  will use GSPMD's sharding propagation to figure out what the sharding of the
229
235
  output(s) should be.
230
- static_argnums: An optional int or collection of ints that specify which
236
+ static_argnums : int or sequence of int, optional
237
+ An optional int or collection of ints that specify which
231
238
  positional arguments to treat as static (compile-time constant).
232
239
  Operations that only depend on static arguments will be constant-folded in
233
240
  Python (during tracing), and so the corresponding argument values can be
@@ -248,12 +255,8 @@ def jit(
248
255
  provided, ``inspect.signature`` is not used, and only actual
249
256
  parameters listed in either ``static_argnums`` or ``static_argnames`` will
250
257
  be treated as static.
251
- static_argnames: An optional string or collection of strings specifying
252
- which named arguments are treated as static (compile-time constant).
253
- Operations that only depend on static arguments will be constant-folded in
254
- Python (during tracing), and so the corresponding argument values can be
255
- any Python object.
256
- donate_argnums: Specify which positional argument buffers are "donated" to
258
+ donate_argnums : int or sequence of int, optional
259
+ Specify which positional argument buffers are "donated" to
257
260
  the computation. It is safe to donate argument buffers if you no longer
258
261
  need them once the computation has finished. In some cases XLA can make
259
262
  use of donated buffers to reduce the amount of memory needed to perform a
@@ -274,38 +277,88 @@ def jit(
274
277
 
275
278
  For more details on buffer donation see the
276
279
  `FAQ <https://jax.readthedocs.io/en/latest/faq.html#buffer-donation>`_.
277
- donate_argnames: An optional string or collection of strings specifying
280
+ static_argnames : str or sequence of str, optional
281
+ An optional string or collection of strings specifying
282
+ which named arguments are treated as static (compile-time constant).
283
+ Operations that only depend on static arguments will be constant-folded in
284
+ Python (during tracing), and so the corresponding argument values can be
285
+ any Python object.
286
+ donate_argnames : str or iterable of str, optional
287
+ An optional string or collection of strings specifying
278
288
  which named arguments are donated to the computation. See the
279
289
  comment on ``donate_argnums`` for details. If not
280
290
  provided but ``donate_argnums`` is set, the default is based on calling
281
291
  ``inspect.signature(fun)`` to find corresponding named arguments.
282
- keep_unused: If `False` (the default), arguments that JAX determines to be
292
+ keep_unused : bool, default False
293
+ If `False` (the default), arguments that JAX determines to be
283
294
  unused by `fun` *may* be dropped from resulting compiled XLA executables.
284
295
  Such arguments will not be transferred to the device nor provided to the
285
296
  underlying executable. If `True`, unused arguments will not be pruned.
286
- device: This is an experimental feature and the API is likely to change.
297
+ device : Device, optional
298
+ This is an experimental feature and the API is likely to change.
287
299
  Optional, the Device the jitted function will run on. (Available devices
288
300
  can be retrieved via :py:func:`jax.devices`.) The default is inherited
289
301
  from XLA's DeviceAssignment logic and is usually to use
290
302
  ``jax.devices()[0]``.
291
- backend: This is an experimental feature and the API is likely to change.
303
+ backend : str, optional
304
+ This is an experimental feature and the API is likely to change.
292
305
  Optional, a string representing the XLA backend: ``'cpu'``, ``'gpu'``, or
293
306
  ``'tpu'``.
294
- inline: Specify whether this function should be inlined into enclosing
307
+ inline : bool, default False
308
+ Specify whether this function should be inlined into enclosing
295
309
  jaxprs (rather than being represented as an application of the xla_call
296
310
  primitive with its own subjaxpr). Default False.
297
- abstracted_axes:
298
-
299
- Returns:
300
- A wrapped version of ``fun``, set up for just-in-time compilation.
301
- The returned object is a :py:class:`JittedFunction` that can be called with the same arguments
302
- and has the following attributes and methods:
303
-
304
- - ``stateful_fun``: the stateful function for extracting states, an instance of :py:class:`StatefulFunction`.
305
- - ``origin_fun(*args, **kwargs)``: the original function
306
- - ``jitted_fun(*args, **kwargs)``: the jitted function
307
- - ``clear_cache(*args, **kwargs)``: clear the cache of the jitted function
308
-
311
+ abstracted_axes : Any, optional
312
+ Abstracted axes specification.
313
+ **kwargs
314
+ Additional keyword arguments passed to the underlying JAX jit function.
315
+
316
+ Returns
317
+ -------
318
+ JittedFunction or callable
319
+ A wrapped version of ``fun``, set up for just-in-time compilation.
320
+ The returned object is a :py:class:`JittedFunction` that can be called with the same arguments
321
+ and has the following attributes and methods:
322
+
323
+ - ``stateful_fun`` : the stateful function for extracting states, an instance of :py:class:`StatefulFunction`.
324
+ - ``origin_fun(*args, **kwargs)`` : the original function
325
+ - ``jitted_fun(*args, **kwargs)`` : the jitted function
326
+ - ``clear_cache(*args, **kwargs)`` : clear the cache of the jitted function
327
+
328
+ Examples
329
+ --------
330
+ Basic usage with a simple function:
331
+
332
+ .. code-block:: python
333
+
334
+ >>> import brainstate
335
+ >>> import jax.numpy as jnp
336
+ >>>
337
+ >>> @brainstate.transform.jit
338
+ ... def f(x):
339
+ ... return x ** 2
340
+ >>>
341
+ >>> result = f(jnp.array([1, 2, 3]))
342
+
343
+ Using static arguments:
344
+
345
+ .. code-block:: python
346
+
347
+ >>> @brainstate.transform.jit(static_argnums=(1,))
348
+ ... def g(x, n):
349
+ ... return x ** n
350
+ >>>
351
+ >>> result = g(jnp.array([1, 2, 3]), 2)
352
+
353
+ Manual jitting:
354
+
355
+ .. code-block:: python
356
+
357
+ >>> def h(x):
358
+ ... return x * 2
359
+ >>>
360
+ >>> jitted_h = brainstate.transform.jit(h)
361
+ >>> result = jitted_h(jnp.array([1, 2, 3]))
309
362
  """
310
363
 
311
364
  if isinstance(fun, Missing):
@@ -1,4 +1,4 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
1
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -41,12 +41,12 @@ class TestJIT(unittest.TestCase):
41
41
 
42
42
  print(fun1(1.))
43
43
  key = fun1.stateful_fun.get_arg_cache_key(1.)
44
- self.assertTrue(len(fun1.stateful_fun.get_states(key)) == 2)
44
+ self.assertTrue(len(fun1.stateful_fun.get_states_by_cache(key)) == 2)
45
45
 
46
46
  x = bst.random.randn(10)
47
47
  print(fun1(x))
48
48
  key = fun1.stateful_fun.get_arg_cache_key(x)
49
- self.assertTrue(len(fun1.stateful_fun.get_states(key)) == 2)
49
+ self.assertTrue(len(fun1.stateful_fun.get_states_by_cache(key)) == 2)
50
50
 
51
51
  def test_kwargs(self):
52
52
  a = bst.State(bst.random.randn(10))