brainstate-0.1.10-py2.py3-none-any.whl → brainstate-0.2.0-py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (163)
  1. brainstate/__init__.py +130 -19
  2. brainstate/_compatible_import.py +201 -9
  3. brainstate/_compatible_import_test.py +681 -0
  4. brainstate/_deprecation.py +210 -0
  5. brainstate/_deprecation_test.py +2319 -0
  6. brainstate/{util/error.py → _error.py} +10 -20
  7. brainstate/_state.py +94 -47
  8. brainstate/_state_test.py +1 -1
  9. brainstate/_utils.py +1 -1
  10. brainstate/environ.py +1279 -347
  11. brainstate/environ_test.py +1187 -26
  12. brainstate/graph/__init__.py +6 -13
  13. brainstate/graph/_node.py +240 -0
  14. brainstate/graph/_node_test.py +589 -0
  15. brainstate/graph/{_graph_operation.py → _operation.py} +632 -746
  16. brainstate/graph/_operation_test.py +1147 -0
  17. brainstate/mixin.py +1209 -141
  18. brainstate/mixin_test.py +991 -51
  19. brainstate/nn/__init__.py +74 -72
  20. brainstate/nn/_activations.py +587 -295
  21. brainstate/nn/_activations_test.py +109 -86
  22. brainstate/nn/_collective_ops.py +393 -274
  23. brainstate/nn/_collective_ops_test.py +746 -15
  24. brainstate/nn/_common.py +114 -66
  25. brainstate/nn/_common_test.py +154 -0
  26. brainstate/nn/_conv.py +1652 -143
  27. brainstate/nn/_conv_test.py +838 -227
  28. brainstate/nn/_delay.py +15 -28
  29. brainstate/nn/_delay_test.py +25 -20
  30. brainstate/nn/_dropout.py +359 -167
  31. brainstate/nn/_dropout_test.py +429 -52
  32. brainstate/nn/_dynamics.py +14 -90
  33. brainstate/nn/_dynamics_test.py +1 -12
  34. brainstate/nn/_elementwise.py +492 -313
  35. brainstate/nn/_elementwise_test.py +806 -145
  36. brainstate/nn/_embedding.py +369 -19
  37. brainstate/nn/_embedding_test.py +156 -0
  38. brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +10 -16
  39. brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +6 -5
  40. brainstate/nn/{_linear_mv.py → _event_linear.py} +2 -2
  41. brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +6 -5
  42. brainstate/nn/_exp_euler.py +200 -38
  43. brainstate/nn/_exp_euler_test.py +350 -8
  44. brainstate/nn/_linear.py +391 -71
  45. brainstate/nn/_linear_test.py +427 -59
  46. brainstate/nn/_metrics.py +1070 -0
  47. brainstate/nn/_metrics_test.py +611 -0
  48. brainstate/nn/_module.py +10 -3
  49. brainstate/nn/_module_test.py +1 -1
  50. brainstate/nn/_normalizations.py +688 -329
  51. brainstate/nn/_normalizations_test.py +663 -37
  52. brainstate/nn/_paddings.py +1020 -0
  53. brainstate/nn/_paddings_test.py +723 -0
  54. brainstate/nn/_poolings.py +1404 -342
  55. brainstate/nn/_poolings_test.py +828 -92
  56. brainstate/nn/{_rate_rnns.py → _rnns.py} +446 -54
  57. brainstate/nn/_rnns_test.py +593 -0
  58. brainstate/nn/_utils.py +132 -5
  59. brainstate/nn/_utils_test.py +402 -0
  60. brainstate/{init/_random_inits.py → nn/init.py} +301 -45
  61. brainstate/{init/_random_inits_test.py → nn/init_test.py} +51 -20
  62. brainstate/random/__init__.py +247 -1
  63. brainstate/random/_rand_funs.py +668 -346
  64. brainstate/random/_rand_funs_test.py +74 -1
  65. brainstate/random/_rand_seed.py +541 -76
  66. brainstate/random/_rand_seed_test.py +1 -1
  67. brainstate/random/_rand_state.py +601 -393
  68. brainstate/random/_rand_state_test.py +551 -0
  69. brainstate/transform/__init__.py +59 -0
  70. brainstate/transform/_ad_checkpoint.py +176 -0
  71. brainstate/{compile → transform}/_ad_checkpoint_test.py +1 -1
  72. brainstate/{augment → transform}/_autograd.py +360 -113
  73. brainstate/{augment → transform}/_autograd_test.py +2 -2
  74. brainstate/transform/_conditions.py +316 -0
  75. brainstate/{compile → transform}/_conditions_test.py +11 -11
  76. brainstate/{compile → transform}/_error_if.py +22 -20
  77. brainstate/{compile → transform}/_error_if_test.py +1 -1
  78. brainstate/transform/_eval_shape.py +145 -0
  79. brainstate/{augment → transform}/_eval_shape_test.py +1 -1
  80. brainstate/{compile → transform}/_jit.py +99 -46
  81. brainstate/{compile → transform}/_jit_test.py +3 -3
  82. brainstate/{compile → transform}/_loop_collect_return.py +219 -80
  83. brainstate/{compile → transform}/_loop_collect_return_test.py +1 -1
  84. brainstate/{compile → transform}/_loop_no_collection.py +133 -34
  85. brainstate/{compile → transform}/_loop_no_collection_test.py +2 -2
  86. brainstate/transform/_make_jaxpr.py +2016 -0
  87. brainstate/transform/_make_jaxpr_test.py +1510 -0
  88. brainstate/transform/_mapping.py +529 -0
  89. brainstate/transform/_mapping_test.py +194 -0
  90. brainstate/{compile → transform}/_progress_bar.py +78 -25
  91. brainstate/{augment → transform}/_random.py +65 -45
  92. brainstate/{compile → transform}/_unvmap.py +102 -5
  93. brainstate/transform/_util.py +286 -0
  94. brainstate/typing.py +594 -61
  95. brainstate/typing_test.py +780 -0
  96. brainstate/util/__init__.py +9 -32
  97. brainstate/util/_others.py +1025 -0
  98. brainstate/util/_others_test.py +962 -0
  99. brainstate/util/_pretty_pytree.py +1301 -0
  100. brainstate/util/_pretty_pytree_test.py +675 -0
  101. brainstate/util/{pretty_repr.py → _pretty_repr.py} +161 -27
  102. brainstate/util/_pretty_repr_test.py +696 -0
  103. brainstate/util/filter.py +557 -81
  104. brainstate/util/filter_test.py +912 -0
  105. brainstate/util/struct.py +769 -382
  106. brainstate/util/struct_test.py +602 -0
  107. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/METADATA +34 -17
  108. brainstate-0.2.0.dist-info/RECORD +111 -0
  109. brainstate/augment/__init__.py +0 -30
  110. brainstate/augment/_eval_shape.py +0 -99
  111. brainstate/augment/_mapping.py +0 -1060
  112. brainstate/augment/_mapping_test.py +0 -597
  113. brainstate/compile/__init__.py +0 -38
  114. brainstate/compile/_ad_checkpoint.py +0 -204
  115. brainstate/compile/_conditions.py +0 -256
  116. brainstate/compile/_make_jaxpr.py +0 -888
  117. brainstate/compile/_make_jaxpr_test.py +0 -156
  118. brainstate/compile/_util.py +0 -147
  119. brainstate/functional/__init__.py +0 -27
  120. brainstate/graph/_graph_node.py +0 -244
  121. brainstate/graph/_graph_node_test.py +0 -73
  122. brainstate/graph/_graph_operation_test.py +0 -563
  123. brainstate/init/__init__.py +0 -26
  124. brainstate/init/_base.py +0 -52
  125. brainstate/init/_generic.py +0 -244
  126. brainstate/init/_regular_inits.py +0 -105
  127. brainstate/init/_regular_inits_test.py +0 -50
  128. brainstate/nn/_inputs.py +0 -608
  129. brainstate/nn/_ltp.py +0 -28
  130. brainstate/nn/_neuron.py +0 -705
  131. brainstate/nn/_neuron_test.py +0 -161
  132. brainstate/nn/_others.py +0 -46
  133. brainstate/nn/_projection.py +0 -486
  134. brainstate/nn/_rate_rnns_test.py +0 -63
  135. brainstate/nn/_readout.py +0 -209
  136. brainstate/nn/_readout_test.py +0 -53
  137. brainstate/nn/_stp.py +0 -236
  138. brainstate/nn/_synapse.py +0 -505
  139. brainstate/nn/_synapse_test.py +0 -131
  140. brainstate/nn/_synaptic_projection.py +0 -423
  141. brainstate/nn/_synouts.py +0 -162
  142. brainstate/nn/_synouts_test.py +0 -57
  143. brainstate/nn/metrics.py +0 -388
  144. brainstate/optim/__init__.py +0 -38
  145. brainstate/optim/_base.py +0 -64
  146. brainstate/optim/_lr_scheduler.py +0 -448
  147. brainstate/optim/_lr_scheduler_test.py +0 -50
  148. brainstate/optim/_optax_optimizer.py +0 -152
  149. brainstate/optim/_optax_optimizer_test.py +0 -53
  150. brainstate/optim/_sgd_optimizer.py +0 -1104
  151. brainstate/random/_random_for_unit.py +0 -52
  152. brainstate/surrogate.py +0 -1957
  153. brainstate/transform.py +0 -23
  154. brainstate/util/caller.py +0 -98
  155. brainstate/util/others.py +0 -540
  156. brainstate/util/pretty_pytree.py +0 -945
  157. brainstate/util/pretty_pytree_test.py +0 -159
  158. brainstate/util/pretty_table.py +0 -2954
  159. brainstate/util/scaling.py +0 -258
  160. brainstate-0.1.10.dist-info/RECORD +0 -130
  161. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/WHEEL +0 -0
  162. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/licenses/LICENSE +0 -0
  163. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/top_level.txt +0 -0
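The renames above (brainstate/{augment → transform}/..., brainstate/{compile → transform}/...) show the 0.1.x augment and compile namespaces being consolidated into a single brainstate.transform package. The diff excerpt that follows covers brainstate/{augment → transform}/_autograd.py and its test module. Below is a minimal migration sketch only, assuming the autograd helpers are re-exported from brainstate.transform as the set_module_as("brainstate.transform") changes in the diff suggest; the actual 0.2.0 re-exports are not shown in this diff.

    # Hypothetical downstream update; the import targets are assumptions based on
    # the renamed modules in this diff, not verified against the 0.2.0 wheel.
    import jax.numpy as jnp
    import brainstate

    # brainstate 0.1.10:
    #   from brainstate.augment import grad, vector_grad, jacrev, jacfwd, hessian
    # brainstate 0.2.0:
    from brainstate.transform import grad

    w = brainstate.State(jnp.array([1.0, 2.0]))

    def loss(x):
        # scalar loss that reads a State, mirroring the docstring examples in the diff
        return jnp.sum((x * w.value) ** 2)

    # gradient with respect to the State rather than the function argument
    state_grads = grad(loss, grad_states=[w])(jnp.array([3.0, 4.0]))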
brainstate/{augment → transform}/_autograd.py
@@ -1,4 +1,4 @@
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
  #
  # Licensed under the Apache License, Version 2.0 (the "License");
  # you may not use this file except in compliance with the License.
@@ -35,7 +35,7 @@ import jax

  from brainstate._state import State
  from brainstate._utils import set_module_as
- from brainstate.compile._make_jaxpr import StatefulFunction
+ from brainstate.transform._make_jaxpr import StatefulFunction
  from brainstate.typing import PyTree, Missing
  from brainstate.util import PrettyType, PrettyAttr, PrettyRepr

@@ -153,16 +153,90 @@ class GradientTransform(PrettyRepr):
  It allows for flexible configuration of gradient computation with respect to specified states
  and function arguments.

- Attributes:
- target (Callable): The function to be transformed.
- stateful_target (StatefulFunction): A wrapper around the target function for state management.
- raw_argnums (Optional[Union[int, Sequence[int]]]): The original argnums specified by the user.
- true_argnums (Union[int, Tuple[int, ...]]): The adjusted argnums used internally.
- return_value (bool): Whether to return the function's value along with gradients.
- has_aux (bool): Whether the function returns auxiliary data.
+ Parameters
+ ----------
+ target : callable
+ The function to be transformed.
+ transform : callable
+ The transformation function to apply.
+ grad_states : State, sequence of State, or dict of State, optional
+ States to compute gradients for.
+ argnums : int or sequence of int, optional
+ Indices of arguments to differentiate with respect to.
+ return_value : bool, default False
+ Whether to return the function's value along with gradients.
+ has_aux : bool, default False
+ Whether the function returns auxiliary data.
+ transform_params : dict, optional
+ Additional parameters for the transformation function.
+ check_states : bool, default True
+ Whether to check that all grad_states are found in the function.
+
+ Attributes
+ ----------
+ target : callable
+ The function to be transformed.
+ stateful_target : StatefulFunction
+ A wrapper around the target function for state management.
+ raw_argnums : int, sequence of int, or None
+ The original argnums specified by the user.
+ true_argnums : int or tuple of int
+ The adjusted argnums used internally.
+ return_value : bool
+ Whether to return the function's value along with gradients.
+ has_aux : bool
+ Whether the function returns auxiliary data.
+
+ Examples
+ --------
+ Basic gradient computation with states:
+
+ .. code-block:: python
+
+ >>> import brainstate
+ >>> import jax.numpy as jnp
+ >>>
+ >>> # Create states
+ >>> weight = brainstate.State(jnp.array([[1.0, 2.0], [3.0, 4.0]]))
+ >>> bias = brainstate.State(jnp.array([0.5, -0.5]))
+ >>>
+ >>> def loss_fn(x):
+ ... y = x @ weight.value + bias.value
+ ... return jnp.sum(y ** 2)
+ >>>
+ >>> # Create gradient transform
+ >>> grad_transform = brainstate.transform.GradientTransform(
+ ... target=loss_fn,
+ ... transform=jax.grad,
+ ... grad_states=[weight, bias]
+ ... )
+ >>>
+ >>> # Compute gradients
+ >>> x = jnp.array([1.0, 2.0])
+ >>> grads = grad_transform(x)
+
+ With function arguments and auxiliary data:
+
+ .. code-block:: python
+
+ >>> def loss_fn_with_aux(x, scale):
+ ... y = x @ weight.value + bias.value
+ ... loss = jnp.sum((y * scale) ** 2)
+ ... return loss, {"predictions": y, "scale": scale}
+ >>>
+ >>> grad_transform = brainstate.transform.GradientTransform(
+ ... target=loss_fn_with_aux,
+ ... transform=jax.grad,
+ ... grad_states=[weight, bias],
+ ... argnums=[0, 1], # gradient w.r.t x and scale
+ ... has_aux=True,
+ ... return_value=True
+ ... )
+ >>>
+ >>> grads, loss_value, aux_data = grad_transform(x, 2.0)
  """

- __module__ = "brainstate.augment"
+ __module__ = "brainstate.transform"

  def __init__(
  self,
@@ -178,17 +252,29 @@ class GradientTransform(PrettyRepr):
  """
  Initialize a ``GradientTransform`` instance.

- Args:
- target (Callable): The function to be transformed.
- transform (TransformFn): The transformation function to apply.
- grad_states (Optional[Union[State, Sequence[State], Dict[str, State]]]): States to compute gradients for.
- argnums (Optional[Union[int, Sequence[int]]]): Indices of arguments to differentiate with respect to.
- return_value (bool): Whether to return the function's value along with gradients.
- has_aux (bool): Whether the function returns auxiliary data.
- transform_params (Optional[Dict[str, Any]]): Additional parameters for the transformation function.
-
- Raises:
- TypeError: If any grad_states are not State instances.
+ Parameters
+ ----------
+ target : callable
+ The function to be transformed.
+ transform : callable
+ The transformation function to apply.
+ grad_states : State, sequence of State, or dict of State, optional
+ States to compute gradients for.
+ argnums : int or sequence of int, optional
+ Indices of arguments to differentiate with respect to.
+ return_value : bool, default False
+ Whether to return the function's value along with gradients.
+ has_aux : bool, default False
+ Whether the function returns auxiliary data.
+ transform_params : dict, optional
+ Additional parameters for the transformation function.
+ check_states : bool, default True
+ Whether to check that all grad_states are found in the function.
+
+ Raises
+ ------
+ TypeError
+ If any grad_states are not State instances.
  """
  # gradient variables
  if isinstance(grad_states, dict):
@@ -221,7 +307,7 @@ class GradientTransform(PrettyRepr):
  # target
  assert callable(target), "The target should be a callable object."
  self.target = target
- self.stateful_target = StatefulFunction(target, name='gradient')
+ self.stateful_target = StatefulFunction(target, name='gradient', return_only_write=False)

  # transform
  grad_setting = dict() if transform_params is None else transform_params
@@ -307,8 +393,7 @@ class GradientTransform(PrettyRepr):
  Returns:
  Tuple: A tuple containing updated state values and the function output.
  """
- cache = self.stateful_target.get_arg_cache_key(*args, **kwargs)
- state_trace = self.stateful_target.get_state_trace(cache)
+ state_trace = self.stateful_target.get_state_trace(*args, **kwargs, compile_if_miss=True)
  state_vals = self._merge_state_vals(grad_vals, other_vals, state_trace)
  state_vals, out = self.stateful_target.jaxpr_call(state_vals, *args, **kwargs)
  return state_vals, out
@@ -403,12 +488,18 @@ class GradientTransform(PrettyRepr):
  """
  Compute gradients by calling the transformed function.

- Args:
- *args: Positional arguments to pass to the target function.
- **kwargs: Keyword arguments to pass to the target function.
-
- Returns:
- Union[Gradient, Tuple]: The computed gradients, potentially including function value and/or auxiliary data.
+ Parameters
+ ----------
+ *args
+ Positional arguments to pass to the target function.
+ **kwargs
+ Keyword arguments to pass to the target function.
+
+ Returns
+ -------
+ Gradient or tuple
+ The computed gradients, potentially including function value and/or auxiliary data.
+ The exact return structure depends on the settings of return_value and has_aux.
  """

  # TODO: support jax.disable_jit()
@@ -418,79 +509,135 @@ class GradientTransform(PrettyRepr):
  cache = self.stateful_target.get_arg_cache_key(*args, **kwargs)

  # apply the gradient transformation
- state_trace = self.stateful_target.get_state_trace(cache)
+ state_trace = self.stateful_target.get_state_trace_by_cache(cache)
  rets = self._transform(*self._split_state_vals(state_trace), *args, **kwargs)

  # analyze and return the results
  return self._return(rets, state_trace)


- _doc_of_return = '''
+ @set_module_as("brainstate.transform")
+ def grad(
+ fun: Callable = Missing(),
+ grad_states: Optional[Union[State, Sequence[State], Dict[str, State]]] = None,
+ argnums: Optional[Union[int, Sequence[int]]] = None,
+ holomorphic: Optional[bool] = False,
+ allow_int: Optional[bool] = False,
+ has_aux: Optional[bool] = None,
+ return_value: Optional[bool] = False,
+ unit_aware: bool = False,
+ check_states: bool = True,
+ ) -> GradientTransform | Callable[[Callable], GradientTransform]:
+ """
+ Compute the gradient of a scalar-valued function with respect to its arguments.
+

  1. When ``grad_states`` is None
+
  - ``has_aux=False`` + ``return_value=False`` => ``arg_grads``.
  - ``has_aux=True`` + ``return_value=False`` => ``(arg_grads, aux_data)``.
  - ``has_aux=False`` + ``return_value=True`` => ``(arg_grads, loss_value)``.
  - ``has_aux=True`` + ``return_value=True`` => ``(arg_grads, loss_value, aux_data)``.
  2. When ``grad_states`` is not None and ``argnums`` is None
+
  - ``has_aux=False`` + ``return_value=False`` => ``var_grads``.
  - ``has_aux=True`` + ``return_value=False`` => ``(var_grads, aux_data)``.
  - ``has_aux=False`` + ``return_value=True`` => ``(var_grads, loss_value)``.
  - ``has_aux=True`` + ``return_value=True`` => ``(var_grads, loss_value, aux_data)``.
  3. When ``grad_states`` is not None and ``argnums`` is not None
+
  - ``has_aux=False`` + ``return_value=False`` => ``(var_grads, arg_grads)``.
  - ``has_aux=True`` + ``return_value=False`` => ``((var_grads, arg_grads), aux_data)``.
  - ``has_aux=False`` + ``return_value=True`` => ``((var_grads, arg_grads), loss_value)``.
  - ``has_aux=True`` + ``return_value=True`` => ``((var_grads, arg_grads), loss_value, aux_data)``.

- '''
-
-
- @set_module_as("brainstate.augment")
- def grad(
- fun: Callable = Missing(),
- grad_states: Optional[Union[State, Sequence[State], Dict[str, State]]] = None,
- argnums: Optional[Union[int, Sequence[int]]] = None,
- holomorphic: Optional[bool] = False,
- allow_int: Optional[bool] = False,
- has_aux: Optional[bool] = None,
- return_value: Optional[bool] = False,
- unit_aware: bool = False,
- check_states: bool = True,
- ) -> GradientTransform | Callable[[Callable], GradientTransform]:
- """
- Compute the gradient of a scalar-valued function with respect to its arguments.

- %s
-
- Args:
- fun: callable. the scalar-valued function to be differentiated.
- allow_int: (bool) optional. Whether to allow differentiating with respect to
- integer valued inputs. The gradient of an integer input will have a trivial
- vector-space dtype (float0). Default False.
- holomorphic: (bool) optional. Whether fun is promised to be holomorphic.
- Default False.
- grad_states: (State, Sequence[State], Dict[str, State]) optional. The variables
- in fun to take their gradients.
- fun: the scalar-valued function to be differentiated.
- argnums: (int or tuple of ints) optional. Specifies which positional
- argument(s) to differentiate with respect to.
- has_aux: (bool) optional. Indicates whether fun returns a pair where the
- first element is considered the output of the mathematical function to be
- differentiated and the second element is auxiliary data. Default False.
- return_value: (bool) optional. Indicates whether to return the value of the
- function along with the gradient. Default False.
- unit_aware: (bool) optional. Whether to return the gradient in the unit-aware
- mode. Default False.
-
- Returns:
- A function which computes the gradient of fun. The function takes the same
- arguments as `fun`, but returns the gradient instead. If `has_aux` is True,
- the function returns a pair where the first element is the gradient and the
- second element is the auxiliary data. If `return_value` is True, the function
- returns a pair where the first element is the gradient and the second element
- is the value of the function.
+ Parameters
+ ----------
+ fun : callable, optional
+ The scalar-valued function to be differentiated.
+ grad_states : State, sequence of State, or dict of State, optional
+ The variables in fun to take their gradients.
+ argnums : int or sequence of int, optional
+ Specifies which positional argument(s) to differentiate with respect to.
+ holomorphic : bool, default False
+ Whether fun is promised to be holomorphic.
+ allow_int : bool, default False
+ Whether to allow differentiating with respect to
+ integer valued inputs. The gradient of an integer input will have a trivial
+ vector-space dtype (float0).
+ has_aux : bool, optional
+ Indicates whether fun returns a pair where the
+ first element is considered the output of the mathematical function to be
+ differentiated and the second element is auxiliary data.
+ return_value : bool, default False
+ Indicates whether to return the value of the
+ function along with the gradient.
+ unit_aware : bool, default False
+ Whether to return the gradient in the unit-aware mode.
+ check_states : bool, default True
+ Whether to check that all grad_states are found in the function.

+ Returns
+ -------
+ GradientTransform or callable
+ A function which computes the gradient of fun. The function takes the same
+ arguments as `fun`, but returns the gradient instead. If `has_aux` is True,
+ the function returns a pair where the first element is the gradient and the
+ second element is the auxiliary data. If `return_value` is True, the function
+ returns a pair where the first element is the gradient and the second element
+ is the value of the function.
+
+ Examples
+ --------
+ Basic gradient computation:
+
+ .. code-block:: python
+
+ >>> import brainstate
+ >>> import jax.numpy as jnp
+ >>>
+ >>> # Simple function gradient
+ >>> def f(x):
+ ... return jnp.sum(x ** 2)
+ >>>
+ >>> grad_f = brainstate.transform.grad(f)
+ >>> x = jnp.array([1.0, 2.0, 3.0])
+ >>> gradient = grad_f(x)
+
+ Gradient with respect to states:
+
+ .. code-block:: python
+
+ >>> # Create states
+ >>> weight = brainstate.State(jnp.array([1.0, 2.0]))
+ >>> bias = brainstate.State(jnp.array([0.5]))
+ >>>
+ >>> def loss_fn(x):
+ ... prediction = jnp.dot(x, weight.value) + bias.value
+ ... return prediction ** 2
+ >>>
+ >>> # Compute gradients with respect to states
+ >>> grad_fn = brainstate.transform.grad(loss_fn, grad_states=[weight, bias])
+ >>> x = jnp.array([1.0, 2.0])
+ >>> state_grads = grad_fn(x)
+
+ With auxiliary data and return value:
+
+ .. code-block:: python
+
+ >>> def loss_with_aux(x):
+ ... prediction = jnp.dot(x, weight.value) + bias.value
+ ... loss = prediction ** 2
+ ... return loss, {"prediction": prediction}
+ >>>
+ >>> grad_fn = brainstate.transform.grad(
+ ... loss_with_aux,
+ ... grad_states=[weight, bias],
+ ... has_aux=True,
+ ... return_value=True
+ ... )
+ >>> grads, loss_value, aux_data = grad_fn(x)
  """
  if isinstance(fun, Missing):
  def transform(fun) -> GradientTransform:
@@ -519,10 +666,7 @@ def grad(
  )


- grad.__doc__ = grad.__doc__ % _doc_of_return
-
-
- @set_module_as("brainstate.augment")
+ @set_module_as("brainstate.transform")
  def vector_grad(
  func: Callable = Missing(),
  grad_states: Optional[Union[State, Sequence[State], Dict[str, State]]] = None,
@@ -532,34 +676,91 @@ def vector_grad(
  unit_aware: bool = False,
  check_states: bool = True,
  ) -> GradientTransform | Callable[[Callable], GradientTransform]:
- """Take vector-valued gradients for function ``func``.
+ """
+ Take vector-valued gradients for function ``func``.

- Same as :py:func:`grad`, :py:func:`jacrev`, and :py:func:`jacfwd`,
+ Same as :py:func:`grad`, :py:func:`jacrev`, and :py:func:`jacfwd`,
  the returns in this function are different for different argument settings.

- %s
+
+ 1. When ``grad_states`` is None
+
+ - ``has_aux=False`` + ``return_value=False`` => ``arg_grads``.
+ - ``has_aux=True`` + ``return_value=False`` => ``(arg_grads, aux_data)``.
+ - ``has_aux=False`` + ``return_value=True`` => ``(arg_grads, loss_value)``.
+ - ``has_aux=True`` + ``return_value=True`` => ``(arg_grads, loss_value, aux_data)``.
+ 2. When ``grad_states`` is not None and ``argnums`` is None
+
+ - ``has_aux=False`` + ``return_value=False`` => ``var_grads``.
+ - ``has_aux=True`` + ``return_value=False`` => ``(var_grads, aux_data)``.
+ - ``has_aux=False`` + ``return_value=True`` => ``(var_grads, loss_value)``.
+ - ``has_aux=True`` + ``return_value=True`` => ``(var_grads, loss_value, aux_data)``.
+ 3. When ``grad_states`` is not None and ``argnums`` is not None
+
+ - ``has_aux=False`` + ``return_value=False`` => ``(var_grads, arg_grads)``.
+ - ``has_aux=True`` + ``return_value=False`` => ``((var_grads, arg_grads), aux_data)``.
+ - ``has_aux=False`` + ``return_value=True`` => ``((var_grads, arg_grads), loss_value)``.
+ - ``has_aux=True`` + ``return_value=True`` => ``((var_grads, arg_grads), loss_value, aux_data)``.
+

  Parameters
  ----------
- func: Callable
+ func : callable, optional
  Function whose gradient is to be computed.
- grad_states : optional, ArrayType, sequence of ArrayType, dict
+ grad_states : State, sequence of State, or dict of State, optional
  The variables in ``func`` to take their gradients.
- has_aux: optional, bool
+ argnums : int or sequence of int, optional
+ Specifies which positional argument(s) to differentiate with respect to.
+ return_value : bool, default False
+ Whether to return the loss value.
+ has_aux : bool, optional
  Indicates whether ``fun`` returns a pair where the
  first element is considered the output of the mathematical function to be
- differentiated and the second element is auxiliary data. Default False.
- return_value : bool
- Whether return the loss value.
- argnums: Optional, integer or sequence of integers. Specifies which
- positional argument(s) to differentiate with respect to (default ``0``).
- unit_aware: (bool) optional. Whether to return the gradient in the unit-aware
- mode. Default False.
+ differentiated and the second element is auxiliary data.
+ unit_aware : bool, default False
+ Whether to return the gradient in the unit-aware mode.
+ check_states : bool, default True
+ Whether to check that all grad_states are found in the function.

  Returns
  -------
- func : GradientTransform
+ GradientTransform or callable
  The vector gradient function.
+
+ Examples
+ --------
+ Basic vector gradient computation:
+
+ .. code-block:: python
+
+ >>> import brainstate
+ >>> import jax.numpy as jnp
+ >>>
+ >>> # Vector-valued function
+ >>> def f(x):
+ ... return jnp.array([x[0]**2, x[1]**3, x[0]*x[1]])
+ >>>
+ >>> vector_grad_f = brainstate.transform.vector_grad(f)
+ >>> x = jnp.array([2.0, 3.0])
+ >>> gradients = vector_grad_f(x) # Shape: (3, 2)
+
+ With states:
+
+ .. code-block:: python
+
+ >>> params = brainstate.State(jnp.array([1.0, 2.0]))
+ >>>
+ >>> def model(x):
+ ... return jnp.array([
+ ... x * params.value[0],
+ ... x**2 * params.value[1]
+ ... ])
+ >>>
+ >>> vector_grad_fn = brainstate.transform.vector_grad(
+ ... model, grad_states=[params]
+ ... )
+ >>> x = 3.0
+ >>> param_grads = vector_grad_fn(x)
  """

  if isinstance(func, Missing):
@@ -588,10 +789,7 @@ def vector_grad(
  )


- vector_grad.__doc__ = vector_grad.__doc__ % _doc_of_return
-
-
- @set_module_as("brainstate.augment")
+ @set_module_as("brainstate.transform")
  def jacrev(
  fun: Callable,
  grad_states: Optional[Union[State, Sequence[State], Dict[str, State]]] = None,
@@ -610,7 +808,26 @@ def jacrev(
  computation on functions and class functions. Moreover, it supports returning
  value ("return_value") and returning auxiliary data ("has_aux").

- %s
+
+ 1. When ``grad_states`` is None
+
+ - ``has_aux=False`` + ``return_value=False`` => ``arg_grads``.
+ - ``has_aux=True`` + ``return_value=False`` => ``(arg_grads, aux_data)``.
+ - ``has_aux=False`` + ``return_value=True`` => ``(arg_grads, loss_value)``.
+ - ``has_aux=True`` + ``return_value=True`` => ``(arg_grads, loss_value, aux_data)``.
+ 2. When ``grad_states`` is not None and ``argnums`` is None
+
+ - ``has_aux=False`` + ``return_value=False`` => ``var_grads``.
+ - ``has_aux=True`` + ``return_value=False`` => ``(var_grads, aux_data)``.
+ - ``has_aux=False`` + ``return_value=True`` => ``(var_grads, loss_value)``.
+ - ``has_aux=True`` + ``return_value=True`` => ``(var_grads, loss_value, aux_data)``.
+ 3. When ``grad_states`` is not None and ``argnums`` is not None
+
+ - ``has_aux=False`` + ``return_value=False`` => ``(var_grads, arg_grads)``.
+ - ``has_aux=True`` + ``return_value=False`` => ``((var_grads, arg_grads), aux_data)``.
+ - ``has_aux=False`` + ``return_value=True`` => ``((var_grads, arg_grads), loss_value)``.
+ - ``has_aux=True`` + ``return_value=True`` => ``((var_grads, arg_grads), loss_value, aux_data)``.
+


  Parameters
@@ -657,12 +874,10 @@ def jacrev(
  )


- jacrev.__doc__ = jacrev.__doc__ % _doc_of_return
-
  jacobian = jacrev


- @set_module_as("brainstate.augment")
+ @set_module_as("brainstate.transform")
  def jacfwd(
  func: Callable,
  grad_states: Optional[Union[State, Sequence[State], Dict[str, State]]] = None,
@@ -679,7 +894,26 @@ def jacfwd(
  computation on functions and class functions. Moreover, it supports returning
  value ("return_value") and returning auxiliary data ("has_aux").

- %s
+
+ 1. When ``grad_states`` is None
+
+ - ``has_aux=False`` + ``return_value=False`` => ``arg_grads``.
+ - ``has_aux=True`` + ``return_value=False`` => ``(arg_grads, aux_data)``.
+ - ``has_aux=False`` + ``return_value=True`` => ``(arg_grads, loss_value)``.
+ - ``has_aux=True`` + ``return_value=True`` => ``(arg_grads, loss_value, aux_data)``.
+ 2. When ``grad_states`` is not None and ``argnums`` is None
+
+ - ``has_aux=False`` + ``return_value=False`` => ``var_grads``.
+ - ``has_aux=True`` + ``return_value=False`` => ``(var_grads, aux_data)``.
+ - ``has_aux=False`` + ``return_value=True`` => ``(var_grads, loss_value)``.
+ - ``has_aux=True`` + ``return_value=True`` => ``(var_grads, loss_value, aux_data)``.
+ 3. When ``grad_states`` is not None and ``argnums`` is not None
+
+ - ``has_aux=False`` + ``return_value=False`` => ``(var_grads, arg_grads)``.
+ - ``has_aux=True`` + ``return_value=False`` => ``((var_grads, arg_grads), aux_data)``.
+ - ``has_aux=False`` + ``return_value=True`` => ``((var_grads, arg_grads), loss_value)``.
+ - ``has_aux=True`` + ``return_value=True`` => ``((var_grads, arg_grads), loss_value, aux_data)``.
+

  Parameters
  ----------
@@ -717,10 +951,7 @@ def jacfwd(
  )


- jacfwd.__doc__ = jacfwd.__doc__ % _doc_of_return
-
-
- @set_module_as("brainstate.augment")
+ @set_module_as("brainstate.transform")
  def hessian(
  func: Callable,
  grad_states: Optional[Union[State, Sequence[State], Dict[str, State]]] = None,
@@ -734,7 +965,26 @@ def hessian(
  """
  Hessian of ``func`` as a dense array.

- %s
+
+ 1. When ``grad_states`` is None
+
+ - ``has_aux=False`` + ``return_value=False`` => ``arg_grads``.
+ - ``has_aux=True`` + ``return_value=False`` => ``(arg_grads, aux_data)``.
+ - ``has_aux=False`` + ``return_value=True`` => ``(arg_grads, loss_value)``.
+ - ``has_aux=True`` + ``return_value=True`` => ``(arg_grads, loss_value, aux_data)``.
+ 2. When ``grad_states`` is not None and ``argnums`` is None
+
+ - ``has_aux=False`` + ``return_value=False`` => ``var_grads``.
+ - ``has_aux=True`` + ``return_value=False`` => ``(var_grads, aux_data)``.
+ - ``has_aux=False`` + ``return_value=True`` => ``(var_grads, loss_value)``.
+ - ``has_aux=True`` + ``return_value=True`` => ``(var_grads, loss_value, aux_data)``.
+ 3. When ``grad_states`` is not None and ``argnums`` is not None
+
+ - ``has_aux=False`` + ``return_value=False`` => ``(var_grads, arg_grads)``.
+ - ``has_aux=True`` + ``return_value=False`` => ``((var_grads, arg_grads), aux_data)``.
+ - ``has_aux=False`` + ``return_value=True`` => ``((var_grads, arg_grads), loss_value)``.
+ - ``has_aux=True`` + ``return_value=True`` => ``((var_grads, arg_grads), loss_value, aux_data)``.
+

  Parameters
  ----------
@@ -773,6 +1023,3 @@ def hessian(
  transform_params=dict(holomorphic=holomorphic),
  check_states=check_states
  )
-
-
- hessian.__doc__ = hessian.__doc__ % _doc_of_return

brainstate/{augment → transform}/_autograd_test.py
@@ -1,4 +1,4 @@
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
  #
  # Licensed under the Apache License, Version 2.0 (the "License");
  # you may not use this file except in compliance with the License.
@@ -24,7 +24,7 @@ import jax.numpy as jnp
  import pytest

  import brainstate
- from brainstate.augment._autograd import _jacfwd
+ from brainstate.transform._autograd import _jacfwd

  class TestPureFuncGrad(unittest.TestCase):
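Beyond the namespace move, the hunks above also rename the StatefulFunction state-trace accessors: the 0.1.x get_state_trace(cache) call becomes either get_state_trace_by_cache(cache) or get_state_trace(*args, **kwargs, compile_if_miss=True). A rough sketch of the new calls follows; the method names and keyword arguments are copied from the hunks above, while the surrounding setup and compilation behavior are assumptions.

    # Sketch only: illustrates the renamed StatefulFunction calls seen in the diff;
    # not verified against the 0.2.0 wheel.
    import jax.numpy as jnp
    from brainstate.transform._make_jaxpr import StatefulFunction

    def f(x):
        return jnp.sum(x ** 2)

    sf = StatefulFunction(f, name='gradient', return_only_write=False)
    x = jnp.ones(3)

    # 0.2.0: trace states, compiling the function on a cache miss
    trace = sf.get_state_trace(x, compile_if_miss=True)

    # 0.2.0: cache-keyed lookup, as used in GradientTransform.__call__ above
    cache = sf.get_arg_cache_key(x)
    trace = sf.get_state_trace_by_cache(cache)

    # 0.1.10 equivalent (removed):
    #   cache = sf.get_arg_cache_key(x)
    #   trace = sf.get_state_trace(cache)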