brainstate 0.1.9__py2.py3-none-any.whl → 0.2.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (163)
  1. brainstate/__init__.py +130 -19
  2. brainstate/_compatible_import.py +201 -9
  3. brainstate/_compatible_import_test.py +681 -0
  4. brainstate/_deprecation.py +210 -0
  5. brainstate/_deprecation_test.py +2319 -0
  6. brainstate/{util/error.py → _error.py} +10 -20
  7. brainstate/_state.py +94 -47
  8. brainstate/_state_test.py +1 -1
  9. brainstate/_utils.py +1 -1
  10. brainstate/environ.py +1279 -347
  11. brainstate/environ_test.py +1187 -26
  12. brainstate/graph/__init__.py +6 -13
  13. brainstate/graph/_node.py +240 -0
  14. brainstate/graph/_node_test.py +589 -0
  15. brainstate/graph/{_graph_operation.py → _operation.py} +632 -746
  16. brainstate/graph/_operation_test.py +1147 -0
  17. brainstate/mixin.py +1209 -141
  18. brainstate/mixin_test.py +991 -51
  19. brainstate/nn/__init__.py +74 -72
  20. brainstate/nn/_activations.py +587 -295
  21. brainstate/nn/_activations_test.py +109 -86
  22. brainstate/nn/_collective_ops.py +393 -274
  23. brainstate/nn/_collective_ops_test.py +746 -15
  24. brainstate/nn/_common.py +114 -66
  25. brainstate/nn/_common_test.py +154 -0
  26. brainstate/nn/_conv.py +1652 -143
  27. brainstate/nn/_conv_test.py +838 -227
  28. brainstate/nn/_delay.py +95 -29
  29. brainstate/nn/_delay_test.py +25 -20
  30. brainstate/nn/_dropout.py +359 -167
  31. brainstate/nn/_dropout_test.py +429 -52
  32. brainstate/nn/_dynamics.py +14 -90
  33. brainstate/nn/_dynamics_test.py +1 -12
  34. brainstate/nn/_elementwise.py +492 -313
  35. brainstate/nn/_elementwise_test.py +806 -145
  36. brainstate/nn/_embedding.py +369 -19
  37. brainstate/nn/_embedding_test.py +156 -0
  38. brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +10 -16
  39. brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +6 -5
  40. brainstate/nn/{_linear_mv.py → _event_linear.py} +2 -2
  41. brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +6 -5
  42. brainstate/nn/_exp_euler.py +200 -38
  43. brainstate/nn/_exp_euler_test.py +350 -8
  44. brainstate/nn/_linear.py +391 -71
  45. brainstate/nn/_linear_test.py +427 -59
  46. brainstate/nn/_metrics.py +1070 -0
  47. brainstate/nn/_metrics_test.py +611 -0
  48. brainstate/nn/_module.py +10 -3
  49. brainstate/nn/_module_test.py +1 -1
  50. brainstate/nn/_normalizations.py +688 -329
  51. brainstate/nn/_normalizations_test.py +663 -37
  52. brainstate/nn/_paddings.py +1020 -0
  53. brainstate/nn/_paddings_test.py +723 -0
  54. brainstate/nn/_poolings.py +1404 -342
  55. brainstate/nn/_poolings_test.py +828 -92
  56. brainstate/nn/{_rate_rnns.py → _rnns.py} +446 -54
  57. brainstate/nn/_rnns_test.py +593 -0
  58. brainstate/nn/_utils.py +132 -5
  59. brainstate/nn/_utils_test.py +402 -0
  60. brainstate/{init/_random_inits.py → nn/init.py} +301 -45
  61. brainstate/{init/_random_inits_test.py → nn/init_test.py} +51 -20
  62. brainstate/random/__init__.py +247 -1
  63. brainstate/random/_rand_funs.py +668 -346
  64. brainstate/random/_rand_funs_test.py +74 -1
  65. brainstate/random/_rand_seed.py +541 -76
  66. brainstate/random/_rand_seed_test.py +1 -1
  67. brainstate/random/_rand_state.py +601 -393
  68. brainstate/random/_rand_state_test.py +551 -0
  69. brainstate/transform/__init__.py +59 -0
  70. brainstate/transform/_ad_checkpoint.py +176 -0
  71. brainstate/{compile → transform}/_ad_checkpoint_test.py +1 -1
  72. brainstate/{augment → transform}/_autograd.py +360 -113
  73. brainstate/{augment → transform}/_autograd_test.py +2 -2
  74. brainstate/transform/_conditions.py +316 -0
  75. brainstate/{compile → transform}/_conditions_test.py +11 -11
  76. brainstate/{compile → transform}/_error_if.py +22 -20
  77. brainstate/{compile → transform}/_error_if_test.py +1 -1
  78. brainstate/transform/_eval_shape.py +145 -0
  79. brainstate/{augment → transform}/_eval_shape_test.py +1 -1
  80. brainstate/{compile → transform}/_jit.py +99 -46
  81. brainstate/{compile → transform}/_jit_test.py +3 -3
  82. brainstate/{compile → transform}/_loop_collect_return.py +219 -80
  83. brainstate/{compile → transform}/_loop_collect_return_test.py +1 -1
  84. brainstate/{compile → transform}/_loop_no_collection.py +133 -34
  85. brainstate/{compile → transform}/_loop_no_collection_test.py +2 -2
  86. brainstate/transform/_make_jaxpr.py +2016 -0
  87. brainstate/transform/_make_jaxpr_test.py +1510 -0
  88. brainstate/transform/_mapping.py +529 -0
  89. brainstate/transform/_mapping_test.py +194 -0
  90. brainstate/{compile → transform}/_progress_bar.py +78 -25
  91. brainstate/{augment → transform}/_random.py +65 -45
  92. brainstate/{compile → transform}/_unvmap.py +102 -5
  93. brainstate/transform/_util.py +286 -0
  94. brainstate/typing.py +594 -61
  95. brainstate/typing_test.py +780 -0
  96. brainstate/util/__init__.py +9 -32
  97. brainstate/util/_others.py +1025 -0
  98. brainstate/util/_others_test.py +962 -0
  99. brainstate/util/_pretty_pytree.py +1301 -0
  100. brainstate/util/_pretty_pytree_test.py +675 -0
  101. brainstate/util/{pretty_repr.py → _pretty_repr.py} +161 -27
  102. brainstate/util/_pretty_repr_test.py +696 -0
  103. brainstate/util/filter.py +557 -81
  104. brainstate/util/filter_test.py +912 -0
  105. brainstate/util/struct.py +769 -382
  106. brainstate/util/struct_test.py +602 -0
  107. {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/METADATA +34 -17
  108. brainstate-0.2.0.dist-info/RECORD +111 -0
  109. brainstate/augment/__init__.py +0 -30
  110. brainstate/augment/_eval_shape.py +0 -99
  111. brainstate/augment/_mapping.py +0 -1060
  112. brainstate/augment/_mapping_test.py +0 -597
  113. brainstate/compile/__init__.py +0 -38
  114. brainstate/compile/_ad_checkpoint.py +0 -204
  115. brainstate/compile/_conditions.py +0 -256
  116. brainstate/compile/_make_jaxpr.py +0 -888
  117. brainstate/compile/_make_jaxpr_test.py +0 -156
  118. brainstate/compile/_util.py +0 -147
  119. brainstate/functional/__init__.py +0 -27
  120. brainstate/graph/_graph_node.py +0 -244
  121. brainstate/graph/_graph_node_test.py +0 -73
  122. brainstate/graph/_graph_operation_test.py +0 -563
  123. brainstate/init/__init__.py +0 -26
  124. brainstate/init/_base.py +0 -52
  125. brainstate/init/_generic.py +0 -244
  126. brainstate/init/_regular_inits.py +0 -105
  127. brainstate/init/_regular_inits_test.py +0 -50
  128. brainstate/nn/_inputs.py +0 -608
  129. brainstate/nn/_ltp.py +0 -28
  130. brainstate/nn/_neuron.py +0 -705
  131. brainstate/nn/_neuron_test.py +0 -161
  132. brainstate/nn/_others.py +0 -46
  133. brainstate/nn/_projection.py +0 -486
  134. brainstate/nn/_rate_rnns_test.py +0 -63
  135. brainstate/nn/_readout.py +0 -209
  136. brainstate/nn/_readout_test.py +0 -53
  137. brainstate/nn/_stp.py +0 -236
  138. brainstate/nn/_synapse.py +0 -505
  139. brainstate/nn/_synapse_test.py +0 -131
  140. brainstate/nn/_synaptic_projection.py +0 -423
  141. brainstate/nn/_synouts.py +0 -162
  142. brainstate/nn/_synouts_test.py +0 -57
  143. brainstate/nn/metrics.py +0 -388
  144. brainstate/optim/__init__.py +0 -38
  145. brainstate/optim/_base.py +0 -64
  146. brainstate/optim/_lr_scheduler.py +0 -448
  147. brainstate/optim/_lr_scheduler_test.py +0 -50
  148. brainstate/optim/_optax_optimizer.py +0 -152
  149. brainstate/optim/_optax_optimizer_test.py +0 -53
  150. brainstate/optim/_sgd_optimizer.py +0 -1104
  151. brainstate/random/_random_for_unit.py +0 -52
  152. brainstate/surrogate.py +0 -1957
  153. brainstate/transform.py +0 -23
  154. brainstate/util/caller.py +0 -98
  155. brainstate/util/others.py +0 -540
  156. brainstate/util/pretty_pytree.py +0 -945
  157. brainstate/util/pretty_pytree_test.py +0 -159
  158. brainstate/util/pretty_table.py +0 -2954
  159. brainstate/util/scaling.py +0 -258
  160. brainstate-0.1.9.dist-info/RECORD +0 -130
  161. {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/WHEEL +0 -0
  162. {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/licenses/LICENSE +0 -0
  163. {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,176 @@
1
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ import functools
17
+ from typing import Callable, Tuple, Union
18
+
19
+ import jax
20
+
21
+ from brainstate._utils import set_module_as
22
+ from brainstate.typing import Missing
23
+ from ._make_jaxpr import StatefulFunction, _ensure_index_tuple
24
+
25
+ __all__ = [
26
+ 'checkpoint',
27
+ 'remat'
28
+ ]
29
+
30
+
31
@set_module_as('brainstate.transform')
def checkpoint(
    fun: Callable = Missing(),
    *,
    prevent_cse: bool = True,
    policy: Callable[..., bool] | None = None,
    static_argnums: int | Tuple[int, ...] = (),
) -> Union[Callable, Callable[[Callable], Callable]]:
    """Rematerialize the intermediates of ``fun`` when it is differentiated.

    State-aware counterpart of :func:`jax.checkpoint` (a.k.a. :func:`jax.remat`).
    Rather than keeping every linearization point alive between the forward and
    the backward pass, the wrapped function recomputes them during the backward
    pass, trading extra computation for a lower peak memory footprint under
    :func:`jax.grad`, :func:`jax.vjp`, or :func:`jax.linearize`.

    The function may be used directly (``checkpoint(f)`` / ``@checkpoint``) or as
    a decorator factory (``@checkpoint(policy=...)``); in the latter case ``fun``
    is left at its ``Missing`` default and a decorator is returned.

    Parameters
    ----------
    fun : Callable, optional
        The function to wrap. Its positional and keyword arguments may be
        arrays, scalars, or nested Python containers of those. When omitted, a
        decorator that carries the remaining keyword options is returned.
    prevent_cse : bool, default True
        Forwarded to :func:`jax.checkpoint`. Keeps common-subexpression
        elimination from undoing the rematerialization in the generated HLO,
        which is usually required under :func:`jax.jit`/:func:`jax.pmap`. May be
        set to ``False`` inside control-flow primitives such as
        :func:`jax.lax.scan`, where CSE is already handled safely.
    policy : Callable[..., bool], optional
        Forwarded to :func:`jax.checkpoint`. Typically taken from
        :mod:`jax.checkpoint_policies`; given type-level information about a
        primitive application, it returns ``True`` if that value may be saved as
        a residual instead of being recomputed.
    static_argnums : int or tuple of int, optional
        Positions of the arguments of ``fun`` that are treated as static while
        tracing. ``None`` is accepted and treated like an empty tuple. Static
        arguments help avoid :class:`jax.errors.ConcretizationTypeError` at the
        price of retracing whenever their values change.

    Returns
    -------
    callable
        If ``fun`` is given: a wrapper with the same call signature and outputs
        as ``fun`` that rematerializes intermediates when differentiated.
        Otherwise: a decorator that produces such a wrapper.

    Notes
    -----
    ``fun`` is wrapped in a ``StatefulFunction`` and the function handed to
    :func:`jax.checkpoint` is its ``jaxpr_call``. That callable receives the
    current state values as its first positional argument, so the indices in
    ``static_argnums`` are shifted by one before being passed on. After each
    call the state values returned by the checkpointed function are assigned
    back to the corresponding states.

    Checkpoints can be nested to build finer-grained rematerialization
    schedules. If the Python control flow of ``fun`` depends on argument values,
    mark those arguments with ``static_argnums`` (and, if needed, use
    :func:`jax.ensure_compile_time_eval`) so the branch is resolved at trace
    time; alternatively compute such values outside the wrapped function and
    close over them.

    Examples
    --------
    Trade computation for memory on a simple function:

    .. code-block:: python

        >>> import jax
        >>> import jax.numpy as jnp
        >>> import brainstate

        >>> @brainstate.transform.checkpoint
        ... def g(x):
        ...     y = jnp.sin(x)
        ...     z = jnp.sin(y)
        ...     return z

        >>> value, grad = jax.value_and_grad(g)(2.0)

    Mark an argument that drives Python control flow as static:

    .. code-block:: python

        >>> @brainstate.transform.checkpoint(static_argnums=(1,))
        ... def foo(x, is_training):
        ...     if is_training:
        ...         ...
        ...     else:
        ...         ...
    """
    # Decorator-factory usage: no function was supplied yet, so hand back a
    # decorator that re-enters ``checkpoint`` with the same keyword options.
    if isinstance(fun, Missing):
        def decorator(f):
            return checkpoint(
                f,
                prevent_cse=prevent_cse,
                policy=policy,
                static_argnums=static_argnums,
            )

        return decorator

    # Normalise ``static_argnums``; ``None`` is treated as "no static arguments".
    if static_argnums is None:
        static_argnums = tuple()
    static_argnums = _ensure_index_tuple(static_argnums)

    stateful_fn = StatefulFunction(fun, static_argnums=static_argnums, name='checkpoint')

    # ``jaxpr_call`` is invoked with the state values as its first positional
    # argument (see the wrapper below), hence every user index moves up by one.
    shifted_static_argnums = tuple(index + 1 for index in static_argnums)
    rematerialized_call = jax.checkpoint(
        stateful_fn.jaxpr_call,
        prevent_cse=prevent_cse,
        policy=policy,
        static_argnums=shifted_static_argnums,
    )

    @functools.wraps(stateful_fn.fun)
    def wrapper(*args, **kwargs):
        # Trace/compile for these arguments if necessary and fetch the state trace.
        trace = stateful_fn.get_state_trace(*args, **kwargs, compile_if_miss=True)
        # Capture the read-state values before running the checkpointed call.
        # NOTE(review): the meaning of the positional ``True`` is defined by the
        # state-trace API, which is not visible here.
        values_before = trace.get_read_state_values(True)
        # Run the checkpointed computation: state values first, then user arguments.
        new_state_vals, result = rematerialized_call(trace.get_state_values(), *args, **kwargs)
        # Propagate the state updates produced by the call back to the states.
        trace.assign_state_vals_v2(values_before, new_state_vals)
        return result

    return wrapper
174
+
175
+
176
+ remat = checkpoint
@@ -1,4 +1,4 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
1
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.