brainstate 0.1.10__py2.py3-none-any.whl → 0.2.0__py2.py3-none-any.whl

This diff shows the changes between publicly released versions of the package, as they appear in their respective public registries, and is provided for informational purposes only.
Files changed (163)
  1. brainstate/__init__.py +130 -19
  2. brainstate/_compatible_import.py +201 -9
  3. brainstate/_compatible_import_test.py +681 -0
  4. brainstate/_deprecation.py +210 -0
  5. brainstate/_deprecation_test.py +2319 -0
  6. brainstate/{util/error.py → _error.py} +10 -20
  7. brainstate/_state.py +94 -47
  8. brainstate/_state_test.py +1 -1
  9. brainstate/_utils.py +1 -1
  10. brainstate/environ.py +1279 -347
  11. brainstate/environ_test.py +1187 -26
  12. brainstate/graph/__init__.py +6 -13
  13. brainstate/graph/_node.py +240 -0
  14. brainstate/graph/_node_test.py +589 -0
  15. brainstate/graph/{_graph_operation.py → _operation.py} +632 -746
  16. brainstate/graph/_operation_test.py +1147 -0
  17. brainstate/mixin.py +1209 -141
  18. brainstate/mixin_test.py +991 -51
  19. brainstate/nn/__init__.py +74 -72
  20. brainstate/nn/_activations.py +587 -295
  21. brainstate/nn/_activations_test.py +109 -86
  22. brainstate/nn/_collective_ops.py +393 -274
  23. brainstate/nn/_collective_ops_test.py +746 -15
  24. brainstate/nn/_common.py +114 -66
  25. brainstate/nn/_common_test.py +154 -0
  26. brainstate/nn/_conv.py +1652 -143
  27. brainstate/nn/_conv_test.py +838 -227
  28. brainstate/nn/_delay.py +15 -28
  29. brainstate/nn/_delay_test.py +25 -20
  30. brainstate/nn/_dropout.py +359 -167
  31. brainstate/nn/_dropout_test.py +429 -52
  32. brainstate/nn/_dynamics.py +14 -90
  33. brainstate/nn/_dynamics_test.py +1 -12
  34. brainstate/nn/_elementwise.py +492 -313
  35. brainstate/nn/_elementwise_test.py +806 -145
  36. brainstate/nn/_embedding.py +369 -19
  37. brainstate/nn/_embedding_test.py +156 -0
  38. brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +10 -16
  39. brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +6 -5
  40. brainstate/nn/{_linear_mv.py → _event_linear.py} +2 -2
  41. brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +6 -5
  42. brainstate/nn/_exp_euler.py +200 -38
  43. brainstate/nn/_exp_euler_test.py +350 -8
  44. brainstate/nn/_linear.py +391 -71
  45. brainstate/nn/_linear_test.py +427 -59
  46. brainstate/nn/_metrics.py +1070 -0
  47. brainstate/nn/_metrics_test.py +611 -0
  48. brainstate/nn/_module.py +10 -3
  49. brainstate/nn/_module_test.py +1 -1
  50. brainstate/nn/_normalizations.py +688 -329
  51. brainstate/nn/_normalizations_test.py +663 -37
  52. brainstate/nn/_paddings.py +1020 -0
  53. brainstate/nn/_paddings_test.py +723 -0
  54. brainstate/nn/_poolings.py +1404 -342
  55. brainstate/nn/_poolings_test.py +828 -92
  56. brainstate/nn/{_rate_rnns.py → _rnns.py} +446 -54
  57. brainstate/nn/_rnns_test.py +593 -0
  58. brainstate/nn/_utils.py +132 -5
  59. brainstate/nn/_utils_test.py +402 -0
  60. brainstate/{init/_random_inits.py → nn/init.py} +301 -45
  61. brainstate/{init/_random_inits_test.py → nn/init_test.py} +51 -20
  62. brainstate/random/__init__.py +247 -1
  63. brainstate/random/_rand_funs.py +668 -346
  64. brainstate/random/_rand_funs_test.py +74 -1
  65. brainstate/random/_rand_seed.py +541 -76
  66. brainstate/random/_rand_seed_test.py +1 -1
  67. brainstate/random/_rand_state.py +601 -393
  68. brainstate/random/_rand_state_test.py +551 -0
  69. brainstate/transform/__init__.py +59 -0
  70. brainstate/transform/_ad_checkpoint.py +176 -0
  71. brainstate/{compile → transform}/_ad_checkpoint_test.py +1 -1
  72. brainstate/{augment → transform}/_autograd.py +360 -113
  73. brainstate/{augment → transform}/_autograd_test.py +2 -2
  74. brainstate/transform/_conditions.py +316 -0
  75. brainstate/{compile → transform}/_conditions_test.py +11 -11
  76. brainstate/{compile → transform}/_error_if.py +22 -20
  77. brainstate/{compile → transform}/_error_if_test.py +1 -1
  78. brainstate/transform/_eval_shape.py +145 -0
  79. brainstate/{augment → transform}/_eval_shape_test.py +1 -1
  80. brainstate/{compile → transform}/_jit.py +99 -46
  81. brainstate/{compile → transform}/_jit_test.py +3 -3
  82. brainstate/{compile → transform}/_loop_collect_return.py +219 -80
  83. brainstate/{compile → transform}/_loop_collect_return_test.py +1 -1
  84. brainstate/{compile → transform}/_loop_no_collection.py +133 -34
  85. brainstate/{compile → transform}/_loop_no_collection_test.py +2 -2
  86. brainstate/transform/_make_jaxpr.py +2016 -0
  87. brainstate/transform/_make_jaxpr_test.py +1510 -0
  88. brainstate/transform/_mapping.py +529 -0
  89. brainstate/transform/_mapping_test.py +194 -0
  90. brainstate/{compile → transform}/_progress_bar.py +78 -25
  91. brainstate/{augment → transform}/_random.py +65 -45
  92. brainstate/{compile → transform}/_unvmap.py +102 -5
  93. brainstate/transform/_util.py +286 -0
  94. brainstate/typing.py +594 -61
  95. brainstate/typing_test.py +780 -0
  96. brainstate/util/__init__.py +9 -32
  97. brainstate/util/_others.py +1025 -0
  98. brainstate/util/_others_test.py +962 -0
  99. brainstate/util/_pretty_pytree.py +1301 -0
  100. brainstate/util/_pretty_pytree_test.py +675 -0
  101. brainstate/util/{pretty_repr.py → _pretty_repr.py} +161 -27
  102. brainstate/util/_pretty_repr_test.py +696 -0
  103. brainstate/util/filter.py +557 -81
  104. brainstate/util/filter_test.py +912 -0
  105. brainstate/util/struct.py +769 -382
  106. brainstate/util/struct_test.py +602 -0
  107. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/METADATA +34 -17
  108. brainstate-0.2.0.dist-info/RECORD +111 -0
  109. brainstate/augment/__init__.py +0 -30
  110. brainstate/augment/_eval_shape.py +0 -99
  111. brainstate/augment/_mapping.py +0 -1060
  112. brainstate/augment/_mapping_test.py +0 -597
  113. brainstate/compile/__init__.py +0 -38
  114. brainstate/compile/_ad_checkpoint.py +0 -204
  115. brainstate/compile/_conditions.py +0 -256
  116. brainstate/compile/_make_jaxpr.py +0 -888
  117. brainstate/compile/_make_jaxpr_test.py +0 -156
  118. brainstate/compile/_util.py +0 -147
  119. brainstate/functional/__init__.py +0 -27
  120. brainstate/graph/_graph_node.py +0 -244
  121. brainstate/graph/_graph_node_test.py +0 -73
  122. brainstate/graph/_graph_operation_test.py +0 -563
  123. brainstate/init/__init__.py +0 -26
  124. brainstate/init/_base.py +0 -52
  125. brainstate/init/_generic.py +0 -244
  126. brainstate/init/_regular_inits.py +0 -105
  127. brainstate/init/_regular_inits_test.py +0 -50
  128. brainstate/nn/_inputs.py +0 -608
  129. brainstate/nn/_ltp.py +0 -28
  130. brainstate/nn/_neuron.py +0 -705
  131. brainstate/nn/_neuron_test.py +0 -161
  132. brainstate/nn/_others.py +0 -46
  133. brainstate/nn/_projection.py +0 -486
  134. brainstate/nn/_rate_rnns_test.py +0 -63
  135. brainstate/nn/_readout.py +0 -209
  136. brainstate/nn/_readout_test.py +0 -53
  137. brainstate/nn/_stp.py +0 -236
  138. brainstate/nn/_synapse.py +0 -505
  139. brainstate/nn/_synapse_test.py +0 -131
  140. brainstate/nn/_synaptic_projection.py +0 -423
  141. brainstate/nn/_synouts.py +0 -162
  142. brainstate/nn/_synouts_test.py +0 -57
  143. brainstate/nn/metrics.py +0 -388
  144. brainstate/optim/__init__.py +0 -38
  145. brainstate/optim/_base.py +0 -64
  146. brainstate/optim/_lr_scheduler.py +0 -448
  147. brainstate/optim/_lr_scheduler_test.py +0 -50
  148. brainstate/optim/_optax_optimizer.py +0 -152
  149. brainstate/optim/_optax_optimizer_test.py +0 -53
  150. brainstate/optim/_sgd_optimizer.py +0 -1104
  151. brainstate/random/_random_for_unit.py +0 -52
  152. brainstate/surrogate.py +0 -1957
  153. brainstate/transform.py +0 -23
  154. brainstate/util/caller.py +0 -98
  155. brainstate/util/others.py +0 -540
  156. brainstate/util/pretty_pytree.py +0 -945
  157. brainstate/util/pretty_pytree_test.py +0 -159
  158. brainstate/util/pretty_table.py +0 -2954
  159. brainstate/util/scaling.py +0 -258
  160. brainstate-0.1.10.dist-info/RECORD +0 -130
  161. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/WHEEL +0 -0
  162. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/licenses/LICENSE +0 -0
  163. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,316 @@
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ from collections.abc import Callable, Sequence
+
+ import jax
+ import jax.numpy as jnp
+ import numpy as np
+
+ from brainstate._compatible_import import to_concrete_aval, Tracer
+ from brainstate._utils import set_module_as
+ from ._error_if import jit_error_if
+ from ._make_jaxpr import StatefulFunction
+ from ._util import wrap_single_fun_in_multi_branches
+
+ __all__ = [
+     'cond', 'switch', 'ifelse',
+ ]
+
+
+ @set_module_as('brainstate.transform')
+ def cond(pred, true_fun: Callable, false_fun: Callable, *operands):
+     """
+     Conditionally apply ``true_fun`` or ``false_fun``.
+
+     Parameters
+     ----------
+     pred : bool or array-like
+         Boolean scalar selecting which branch to execute. Numeric inputs are
+         treated as ``True`` when non-zero.
+     true_fun : Callable
+         Function that receives ``*operands`` when ``pred`` is ``True``.
+     false_fun : Callable
+         Function that receives ``*operands`` when ``pred`` is ``False``.
+     *operands : Any
+         Operands forwarded to either branch. May be any pytree of arrays,
+         scalars, or nested containers thereof.
+
+     Returns
+     -------
+     Any
+         Value returned by the selected branch with the same pytree structure
+         as produced by ``true_fun`` or ``false_fun``.
+
+     Notes
+     -----
+     Provided the arguments are correctly typed, :func:`cond` has semantics
+     that match the following Python implementation, where ``pred`` must be a
+     scalar:
+
+     .. code-block:: python
+
+         >>> def cond(pred, true_fun, false_fun, *operands):
+         ...     if pred:
+         ...         return true_fun(*operands)
+         ...     return false_fun(*operands)
+
+     In contrast with :func:`jax.lax.select`, using :func:`cond` indicates that only
+     one branch runs (subject to compiler rewrites and optimizations). When
+     transformed with :func:`~jax.vmap` over a batch of predicates, :func:`cond` is
+     converted to :func:`~jax.lax.select`.
+
+     Examples
+     --------
+     .. code-block:: python
+
+         >>> import brainstate
+         >>>
+         >>> def branch_true(x):
+         ...     return x + 1
+         >>>
+         >>> def branch_false(x):
+         ...     return x - 1
+         >>>
+         >>> brainstate.transform.cond(True, branch_true, branch_false, 3)
+     """
+     if not (callable(true_fun) and callable(false_fun)):
+         raise TypeError("true_fun and false_fun arguments should be callable.")
+
+     if pred is None:
+         raise TypeError("cond predicate is None")
+     if isinstance(pred, Sequence) or np.ndim(pred) != 0:
+         raise TypeError(f"Pred must be a scalar, got {pred} of " +
+                         (f"type {type(pred)}" if isinstance(pred, Sequence) else f"shape {np.shape(pred)}."))
+
+     # check pred
+     try:
+         pred_dtype = jax.dtypes.result_type(pred)
+     except TypeError as err:
+         raise TypeError("Pred type must be either boolean or number, got {}.".format(pred)) from err
+     if pred_dtype.kind != 'b':
+         if pred_dtype.kind in 'iuf':
+             pred = pred != 0
+         else:
+             raise TypeError("Pred type must be either boolean or number, got {}.".format(pred_dtype))
+
+     # not jit
+     if jax.config.jax_disable_jit and not isinstance(to_concrete_aval(pred), Tracer):
+         if pred:
+             return true_fun(*operands)
+         else:
+             return false_fun(*operands)
+
+     # evaluate jaxpr
+     stateful_true = StatefulFunction(true_fun, name='cond:true').make_jaxpr(*operands)
+     stateful_false = StatefulFunction(false_fun, name='conda:false').make_jaxpr(*operands)
+
+     # state trace and state values
+     state_trace = (stateful_true.get_state_trace(*operands) +
+                    stateful_false.get_state_trace(*operands))
+     read_state_vals = state_trace.get_read_state_values(True)
+     write_state_vals = state_trace.get_write_state_values(True)
+
+     # wrap the functions
+     true_fun = wrap_single_fun_in_multi_branches(
+         stateful_true, state_trace, read_state_vals, True, stateful_true.get_arg_cache_key(*operands)
+     )
+     false_fun = wrap_single_fun_in_multi_branches(
+         stateful_false, state_trace, read_state_vals, True, stateful_false.get_arg_cache_key(*operands)
+     )
+
+     # cond
+     write_state_vals, out = jax.lax.cond(pred, true_fun, false_fun, write_state_vals, *operands)
+
+     # assign the written state values and restore the read state values
+     state_trace.assign_state_vals_v2(read_state_vals, write_state_vals)
+     return out
+
+
+ @set_module_as('brainstate.transform')
+ def switch(index, branches: Sequence[Callable], *operands):
+     """
+     Apply exactly one branch from ``branches`` based on ``index``.
+
+     Parameters
+     ----------
+     index : int or array-like
+         Scalar integer specifying which branch to execute.
+     branches : Sequence[Callable]
+         Sequence of callables; each receives ``*operands``.
+     *operands : Any
+         Operands forwarded to the selected branch. May be any pytree of arrays,
+         scalars, or nested containers thereof.
+
+     Returns
+     -------
+     Any
+         Value returned by the selected branch with the same pytree structure
+         as the selected callable.
+
+     Notes
+     -----
+     If ``index`` is out of bounds, it is clamped to ``[0, len(branches) - 1]``.
+     Conceptually, :func:`switch` behaves like:
+
+     .. code-block:: python
+
+         >>> def switch(index, branches, *operands):
+         ...     safe_index = clamp(0, index, len(branches) - 1)
+         ...     return branches[safe_index](*operands)
+
+     Internally this wraps XLA's `Conditional <https://www.tensorflow.org/xla/operation_semantics#conditional>`_
+     operator. When transformed with :func:`~jax.vmap` over a batch of predicates,
+     :func:`switch` is converted to :func:`~jax.lax.select`.
+
+     Examples
+     --------
+     .. code-block:: python
+
+         >>> import brainstate
+         >>>
+         >>> branches = (
+         ...     lambda x: x - 1,
+         ...     lambda x: x,
+         ...     lambda x: x + 1,
+         ... )
+         >>>
+         >>> brainstate.transform.switch(2, branches, 3)
+     """
+     # check branches
+     if not all(callable(branch) for branch in branches):
+         raise TypeError("branches argument should be a sequence of callables.")
+
+     # check index
+     if len(np.shape(index)) != 0:
+         raise TypeError(f"Branch index must be scalar, got {index} of shape {np.shape(index)}.")
+     try:
+         index_dtype = jax.dtypes.result_type(index)
+     except TypeError as err:
+         msg = f"Index type must be an integer, got {index}."
+         raise TypeError(msg) from err
+     if index_dtype.kind not in 'iu':
+         raise TypeError(f"Index type must be an integer, got {index} as {index_dtype}")
+
+     # format branches
+     branches = tuple(branches)
+     if len(branches) == 0:
+         raise ValueError("Empty branch sequence")
+     elif len(branches) == 1:
+         return branches[0](*operands)
+
+     # format index
+     index = jax.lax.convert_element_type(index, np.int32)
+     lo = np.array(0, np.int32)
+     hi = np.array(len(branches) - 1, np.int32)
+     index = jax.lax.clamp(lo, index, hi)
+
+     # not jit
+     if jax.config.jax_disable_jit and isinstance(jax.core.core.get_aval(index), jax.core.ConcreteArray):
+         return branches[int(index)](*operands)
+
+     # evaluate jaxpr
+     wrapped_branches = [StatefulFunction(branch, name='switch').make_jaxpr(*operands) for branch in branches]
+
+     # wrap the functions
+     state_trace = (wrapped_branches[0].get_state_trace(*operands) +
+                    wrapped_branches[1].get_state_trace(*operands))
+     state_trace.merge(*[wrapped_branch.get_state_trace(*operands)
+                         for wrapped_branch in wrapped_branches[2:]])
+     read_state_vals = state_trace.get_read_state_values(True)
+     write_state_vals = state_trace.get_write_state_values(True)
+     branches = [
+         wrap_single_fun_in_multi_branches(
+             wrapped_branch, state_trace, read_state_vals, True, wrapped_branch.get_arg_cache_key(*operands)
+         )
+         for wrapped_branch in wrapped_branches
+     ]
+
+     # switch
+     write_state_vals, out = jax.lax.switch(index, branches, write_state_vals, *operands)
+
+     # write back state values or restore them
+     state_trace.assign_state_vals_v2(read_state_vals, write_state_vals)
+     return out
+
+
+ @set_module_as('brainstate.transform')
+ def ifelse(conditions, branches, *operands, check_cond: bool = True):
+     """
+     Represent multi-way ``if``/``elif``/``else`` control flow.
+
+     Parameters
+     ----------
+     conditions : Sequence[bool] or Array
+         Sequence of mutually exclusive boolean predicates. When ``check_cond`` is
+         ``True``, exactly one entry must evaluate to ``True``.
+     branches : Sequence[Callable]
+         Sequence of branch callables evaluated lazily. Must have the same length as
+         ``conditions``, contain at least two callables, and each branch receives
+         ``*operands`` when selected.
+     *operands : Any
+         Operands forwarded to the selected branch as positional arguments.
+     check_cond : bool, default=True
+         Whether to verify that exactly one condition evaluates to ``True``.
+
+     Returns
+     -------
+     Any
+         Value produced by the branch corresponding to the active condition.
+
+     Notes
+     -----
+     When ``check_cond`` is ``True``, exactly one condition must evaluate to ``True``.
+     A common pattern is to make the final condition ``True`` to encode a default
+     branch.
+
+     Examples
+     --------
+     .. code-block:: python
+
+         >>> import brainstate
+         >>>
+         >>> def describe(a):
+         ...     return brainstate.transform.ifelse(
+         ...         conditions=[a > 5, a > 0, True],
+         ...         branches=[
+         ...             lambda: "greater than five",
+         ...             lambda: "positive",
+         ...             lambda: "non-positive",
+         ...         ],
+         ...     )
+         >>>
+         >>> describe(7)
+         >>> describe(-1)
+     """
+     # check branches
+     if not all(callable(branch) for branch in branches):
+         raise TypeError("branches argument should be a sequence of callables.")
+
+     # format branches
+     branches = tuple(branches)
+     if len(branches) == 0:
+         raise ValueError("Empty branch sequence")
+     elif len(branches) == 1:
+         return branches[0](*operands)
+     if len(conditions) != len(branches):
+         raise ValueError("The number of conditions should be equal to the number of branches.")
+
+     # format index
+     conditions = jnp.asarray(conditions, np.int32)
+     if check_cond:
+         jit_error_if(jnp.sum(conditions) != 1, "Only one condition can be True. But got {}.", err_arg=conditions)
+     index = jnp.where(conditions, size=1, fill_value=len(conditions) - 1)[0][0]
+     return switch(index, branches, *operands)
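
A minimal usage sketch of the behaviour documented above (batched predicates lowering to a select under jax.vmap, and out-of-range ``switch`` indices being clamped); these are illustrative calls based on the docstrings, not part of the packaged diff:

    >>> import jax
    >>> import jax.numpy as jnp
    >>> import brainstate
    >>>
    >>> # Under vmap, cond over a batch of predicates is lowered to a select of both branches.
    >>> preds = jnp.array([True, False, True])
    >>> xs = jnp.arange(3.0)
    >>> jax.vmap(
    ...     lambda p, x: brainstate.transform.cond(p, lambda v: v + 1.0, lambda v: v - 1.0, x)
    ... )(preds, xs)
    >>>
    >>> # Out-of-range indices are clamped, so index 10 selects the last branch.
    >>> branches = (lambda x: x - 1, lambda x: x, lambda x: x + 1)
    >>> brainstate.transform.switch(10, branches, 3)  # same result as switch(2, branches, 3)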
@@ -1,4 +1,4 @@
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
  #
  # Licensed under the Apache License, Version 2.0 (the "License");
  # you may not use this file except in compliance with the License.
@@ -172,10 +172,10 @@ class TestIfElse(unittest.TestCase):
  a >= 5 and a < 10,
  a >= 10],
  branches=[lambda: 1,
- lambda: 2,
- lambda: 3,
- lambda: 4,
- lambda: 5])
+ lambda: 2,
+ lambda: 3,
+ lambda: 4,
+ lambda: 5])

  self.assertTrue(f(3) == 3)
  self.assertTrue(f(1) == 2)
@@ -189,10 +189,10 @@ class TestIfElse(unittest.TestCase):
  jnp.logical_and(a <= 2, a > 0),
  a <= 0],
  [lambda _: 1,
- lambda _: 2,
- lambda _: 3,
- lambda _: 4,
- lambda _: 5, ],
+ lambda _: 2,
+ lambda _: 3,
+ lambda _: 4,
+ lambda _: 5, ],
  a)
  return jax.vmap(f)(operands)

@@ -212,8 +212,8 @@ class TestIfElse(unittest.TestCase):
  def F3(x):
  return brainstate.compile.ifelse((x >= 10, jnp.logical_and(x >= 0, x < 10), x < 0),
  [lambda x: x,
- lambda x: x ** 2,
- lambda x: x ** 4, ],
+ lambda x: x ** 2,
+ lambda x: x ** 4, ],
  x)

  self.assertTrue(jax.grad(F3)(9.0) == 18.)
@@ -1,4 +1,4 @@
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
  #
  # Licensed under the Apache License, Version 2.0 (the "License");
  # you may not use this file except in compliance with the License.
@@ -43,7 +43,7 @@ def _error_msg(msg, *arg, **kwargs):
      raise ValueError(msg)


- @set_module_as('brainstate.compile')
+ @set_module_as('brainstate.transform')
  def jit_error_if(
      pred,
      error: Union[Callable, str],
@@ -53,32 +53,34 @@ def jit_error_if(
      """
      Check errors in a jit function.

+     Parameters
+     ----------
+     pred : bool or Array
+         The boolean prediction.
+     error : callable or str
+         The error function, which raise errors, or a string indicating the error message.
+     *err_args
+         The arguments which passed into the error function.
+     **err_kwargs
+         The keywords which passed into the error function.
+
      Examples
      --------
-
      It can give a function which receive arguments that passed from the JIT variables and raise errors.

-     >>> def error(x):
-     >>> raise ValueError(f'error {x}')
-     >>> x = jax.random.uniform(jax.random.PRNGKey(0), (10,))
-     >>> jit_error_if(x.sum() < 5., error, x)
+     .. code-block:: python

-     Or, it can be a simple string message.
+         >>> def error(x):
+         ...     raise ValueError(f'error {x}')
+         >>> x = jax.random.uniform(jax.random.PRNGKey(0), (10,))
+         >>> jit_error_if(x.sum() < 5., error, x)

-     >>> x = jax.random.uniform(jax.random.PRNGKey(0), (10,))
-     >>> jit_error_if(x.sum() < 5., "Error: the sum is less than 5. Got {s}", s=x.sum())
+     Or, it can be a simple string message.

+     .. code-block:: python

-     Parameters
-     ----------
-     pred: bool, Array
-         The boolean prediction.
-     error: callable, str
-         The error function, which raise errors, or a string indicating the error message.
-     err_args:
-         The arguments which passed into `err_f`.
-     err_kwargs:
-         The keywords which passed into `err_f`.
+         >>> x = jax.random.uniform(jax.random.PRNGKey(0), (10,))
+         >>> jit_error_if(x.sum() < 5., "Error: the sum is less than 5. Got {s}", s=x.sum())
      """
      if isinstance(error, str):
          error = partial(_error_msg, error)
@@ -1,4 +1,4 @@
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
  #
  # Licensed under the Apache License, Version 2.0 (the "License");
  # you may not use this file except in compliance with the License.
@@ -0,0 +1,145 @@
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ import functools
+ from typing import Any, TypeVar, Callable, Sequence, Union
+
+ import jax
+
+ from brainstate import random
+ from brainstate._utils import set_module_as
+ from brainstate.graph import Node, flatten, unflatten
+ from ._random import restore_rngs
+
+ __all__ = [
+     'abstract_init',
+ ]
+
+ A = TypeVar('A')
+
+
+ @set_module_as('brainstate.transform')
+ def abstract_init(
+     fn: Callable[..., A],
+     *args: Any,
+     rngs: Union[random.RandomState, Sequence[random.RandomState]] = random.DEFAULT,
+     **kwargs: Any,
+ ) -> A:
+     """
+     Compute the shape/dtype of ``fn`` without any FLOPs.
+
+     This function evaluates the shape and dtype of the output of a function without
+     actually executing the computational operations. It's particularly useful for
+     initializing neural network models to understand their structure and parameter
+     shapes without performing expensive computations.
+
+     Parameters
+     ----------
+     fn : callable
+         The function whose output shape should be evaluated.
+     *args
+         Positional argument tuple of arrays, scalars, or (nested) standard
+         Python containers (tuples, lists, dicts, namedtuples, i.e. pytrees) of
+         those types. Since only the ``shape`` and ``dtype`` attributes are
+         accessed, one can use :class:`jax.ShapeDtypeStruct` or another container
+         that duck-types as ndarrays (note however that duck-typed objects cannot
+         be namedtuples because those are treated as standard Python containers).
+     rngs : RandomState or sequence of RandomState, default random.DEFAULT
+         A :class:`RandomState` or a sequence of :class:`RandomState` objects
+         representing the random number generators to use. If not provided, the
+         default random number generator will be used.
+     **kwargs
+         Keyword argument dict of arrays, scalars, or (nested) standard
+         Python containers (pytrees) of those types. As in ``args``, array values
+         need only be duck-typed to have ``shape`` and ``dtype`` attributes.
+
+     Returns
+     -------
+     A
+         A nested PyTree containing :class:`jax.ShapeDtypeStruct` objects as leaves,
+         representing the structure and shape/dtype information of the function output.
+
+     Examples
+     --------
+     Basic usage with neural network initialization:
+
+     .. code-block:: python
+
+         >>> import brainstate
+         >>> import jax.numpy as jnp
+         >>>
+         >>> class MLP:
+         ...     def __init__(self, n_in, n_mid, n_out):
+         ...         self.dense1 = brainstate.nn.Linear(n_in, n_mid)
+         ...         self.dense2 = brainstate.nn.Linear(n_mid, n_out)
+         >>>
+         >>> # Get shape information without actual computation
+         >>> model_shape = brainstate.transform.abstract_init(lambda: MLP(1, 2, 3))
+
+     With function arguments:
+
+     .. code-block:: python
+
+         >>> def create_model(input_size, hidden_size, output_size):
+         ...     return brainstate.nn.Sequential([
+         ...         brainstate.nn.Linear(input_size, hidden_size),
+         ...         brainstate.nn.ReLU(),
+         ...         brainstate.nn.Linear(hidden_size, output_size)
+         ...     ])
+         >>>
+         >>> # Abstract initialization with arguments
+         >>> model_shape = brainstate.transform.abstract_init(
+         ...     create_model, 784, 256, 10
+         ... )
+
+     Using custom random number generators:
+
+     .. code-block:: python
+
+         >>> import brainstate.random as random
+         >>>
+         >>> # Create custom RNG
+         >>> rng = random.RandomState(42)
+         >>>
+         >>> def init_with_custom_weights():
+         ...     return brainstate.nn.Linear(10, 5)
+         >>>
+         >>> model_shape = brainstate.transform.abstract_init(
+         ...     init_with_custom_weights, rngs=rng
+         ... )
+
+     Evaluating function with array inputs:
+
+     .. code-block:: python
+
+         >>> def model_forward(x):
+         ...     layer = brainstate.nn.Linear(x.shape[-1], 128)
+         ...     return layer(x)
+         >>>
+         >>> # Use ShapeDtypeStruct to represent input without actual data
+         >>> input_shape = jax.ShapeDtypeStruct((32, 784), jnp.float32)
+         >>> output_shape = brainstate.transform.abstract_init(model_forward, input_shape)
+     """
+
+     @functools.wraps(fn)
+     @restore_rngs(rngs=rngs)
+     def _eval_shape_fn(*args_, **kwargs_):
+         out = fn(*args_, **kwargs_)
+         assert isinstance(out, Node), 'The output of the function must be Node'
+         graph_def, treefy_states = flatten(out)
+         return graph_def, treefy_states
+
+     graph_def_, treefy_states_ = jax.eval_shape(_eval_shape_fn, *args, **kwargs)
+     return unflatten(graph_def_, treefy_states_)
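
A minimal sketch of inspecting the abstract result, using brainstate.graph.flatten (the same helper the new module imports); illustrative only, not part of the packaged diff:

    >>> import jax
    >>> import brainstate
    >>>
    >>> # The returned module carries only shape/dtype information, no allocated buffers.
    >>> abstract_model = brainstate.transform.abstract_init(lambda: brainstate.nn.Linear(10, 5))
    >>> graph_def, abstract_states = brainstate.graph.flatten(abstract_model)
    >>> jax.tree_util.tree_map(lambda leaf: (leaf.shape, leaf.dtype), abstract_states)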
@@ -1,4 +1,4 @@
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
  #
  # Licensed under the Apache License, Version 2.0 (the "License");
  # you may not use this file except in compliance with the License.