brainstate 0.1.9__py2.py3-none-any.whl → 0.2.0__py2.py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
- brainstate/__init__.py +130 -19
- brainstate/_compatible_import.py +201 -9
- brainstate/_compatible_import_test.py +681 -0
- brainstate/_deprecation.py +210 -0
- brainstate/_deprecation_test.py +2319 -0
- brainstate/{util/error.py → _error.py} +10 -20
- brainstate/_state.py +94 -47
- brainstate/_state_test.py +1 -1
- brainstate/_utils.py +1 -1
- brainstate/environ.py +1279 -347
- brainstate/environ_test.py +1187 -26
- brainstate/graph/__init__.py +6 -13
- brainstate/graph/_node.py +240 -0
- brainstate/graph/_node_test.py +589 -0
- brainstate/graph/{_graph_operation.py → _operation.py} +632 -746
- brainstate/graph/_operation_test.py +1147 -0
- brainstate/mixin.py +1209 -141
- brainstate/mixin_test.py +991 -51
- brainstate/nn/__init__.py +74 -72
- brainstate/nn/_activations.py +587 -295
- brainstate/nn/_activations_test.py +109 -86
- brainstate/nn/_collective_ops.py +393 -274
- brainstate/nn/_collective_ops_test.py +746 -15
- brainstate/nn/_common.py +114 -66
- brainstate/nn/_common_test.py +154 -0
- brainstate/nn/_conv.py +1652 -143
- brainstate/nn/_conv_test.py +838 -227
- brainstate/nn/_delay.py +95 -29
- brainstate/nn/_delay_test.py +25 -20
- brainstate/nn/_dropout.py +359 -167
- brainstate/nn/_dropout_test.py +429 -52
- brainstate/nn/_dynamics.py +14 -90
- brainstate/nn/_dynamics_test.py +1 -12
- brainstate/nn/_elementwise.py +492 -313
- brainstate/nn/_elementwise_test.py +806 -145
- brainstate/nn/_embedding.py +369 -19
- brainstate/nn/_embedding_test.py +156 -0
- brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +10 -16
- brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +6 -5
- brainstate/nn/{_linear_mv.py → _event_linear.py} +2 -2
- brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +6 -5
- brainstate/nn/_exp_euler.py +200 -38
- brainstate/nn/_exp_euler_test.py +350 -8
- brainstate/nn/_linear.py +391 -71
- brainstate/nn/_linear_test.py +427 -59
- brainstate/nn/_metrics.py +1070 -0
- brainstate/nn/_metrics_test.py +611 -0
- brainstate/nn/_module.py +10 -3
- brainstate/nn/_module_test.py +1 -1
- brainstate/nn/_normalizations.py +688 -329
- brainstate/nn/_normalizations_test.py +663 -37
- brainstate/nn/_paddings.py +1020 -0
- brainstate/nn/_paddings_test.py +723 -0
- brainstate/nn/_poolings.py +1404 -342
- brainstate/nn/_poolings_test.py +828 -92
- brainstate/nn/{_rate_rnns.py → _rnns.py} +446 -54
- brainstate/nn/_rnns_test.py +593 -0
- brainstate/nn/_utils.py +132 -5
- brainstate/nn/_utils_test.py +402 -0
- brainstate/{init/_random_inits.py → nn/init.py} +301 -45
- brainstate/{init/_random_inits_test.py → nn/init_test.py} +51 -20
- brainstate/random/__init__.py +247 -1
- brainstate/random/_rand_funs.py +668 -346
- brainstate/random/_rand_funs_test.py +74 -1
- brainstate/random/_rand_seed.py +541 -76
- brainstate/random/_rand_seed_test.py +1 -1
- brainstate/random/_rand_state.py +601 -393
- brainstate/random/_rand_state_test.py +551 -0
- brainstate/transform/__init__.py +59 -0
- brainstate/transform/_ad_checkpoint.py +176 -0
- brainstate/{compile → transform}/_ad_checkpoint_test.py +1 -1
- brainstate/{augment → transform}/_autograd.py +360 -113
- brainstate/{augment → transform}/_autograd_test.py +2 -2
- brainstate/transform/_conditions.py +316 -0
- brainstate/{compile → transform}/_conditions_test.py +11 -11
- brainstate/{compile → transform}/_error_if.py +22 -20
- brainstate/{compile → transform}/_error_if_test.py +1 -1
- brainstate/transform/_eval_shape.py +145 -0
- brainstate/{augment → transform}/_eval_shape_test.py +1 -1
- brainstate/{compile → transform}/_jit.py +99 -46
- brainstate/{compile → transform}/_jit_test.py +3 -3
- brainstate/{compile → transform}/_loop_collect_return.py +219 -80
- brainstate/{compile → transform}/_loop_collect_return_test.py +1 -1
- brainstate/{compile → transform}/_loop_no_collection.py +133 -34
- brainstate/{compile → transform}/_loop_no_collection_test.py +2 -2
- brainstate/transform/_make_jaxpr.py +2016 -0
- brainstate/transform/_make_jaxpr_test.py +1510 -0
- brainstate/transform/_mapping.py +529 -0
- brainstate/transform/_mapping_test.py +194 -0
- brainstate/{compile → transform}/_progress_bar.py +78 -25
- brainstate/{augment → transform}/_random.py +65 -45
- brainstate/{compile → transform}/_unvmap.py +102 -5
- brainstate/transform/_util.py +286 -0
- brainstate/typing.py +594 -61
- brainstate/typing_test.py +780 -0
- brainstate/util/__init__.py +9 -32
- brainstate/util/_others.py +1025 -0
- brainstate/util/_others_test.py +962 -0
- brainstate/util/_pretty_pytree.py +1301 -0
- brainstate/util/_pretty_pytree_test.py +675 -0
- brainstate/util/{pretty_repr.py → _pretty_repr.py} +161 -27
- brainstate/util/_pretty_repr_test.py +696 -0
- brainstate/util/filter.py +557 -81
- brainstate/util/filter_test.py +912 -0
- brainstate/util/struct.py +769 -382
- brainstate/util/struct_test.py +602 -0
- {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/METADATA +34 -17
- brainstate-0.2.0.dist-info/RECORD +111 -0
- brainstate/augment/__init__.py +0 -30
- brainstate/augment/_eval_shape.py +0 -99
- brainstate/augment/_mapping.py +0 -1060
- brainstate/augment/_mapping_test.py +0 -597
- brainstate/compile/__init__.py +0 -38
- brainstate/compile/_ad_checkpoint.py +0 -204
- brainstate/compile/_conditions.py +0 -256
- brainstate/compile/_make_jaxpr.py +0 -888
- brainstate/compile/_make_jaxpr_test.py +0 -156
- brainstate/compile/_util.py +0 -147
- brainstate/functional/__init__.py +0 -27
- brainstate/graph/_graph_node.py +0 -244
- brainstate/graph/_graph_node_test.py +0 -73
- brainstate/graph/_graph_operation_test.py +0 -563
- brainstate/init/__init__.py +0 -26
- brainstate/init/_base.py +0 -52
- brainstate/init/_generic.py +0 -244
- brainstate/init/_regular_inits.py +0 -105
- brainstate/init/_regular_inits_test.py +0 -50
- brainstate/nn/_inputs.py +0 -608
- brainstate/nn/_ltp.py +0 -28
- brainstate/nn/_neuron.py +0 -705
- brainstate/nn/_neuron_test.py +0 -161
- brainstate/nn/_others.py +0 -46
- brainstate/nn/_projection.py +0 -486
- brainstate/nn/_rate_rnns_test.py +0 -63
- brainstate/nn/_readout.py +0 -209
- brainstate/nn/_readout_test.py +0 -53
- brainstate/nn/_stp.py +0 -236
- brainstate/nn/_synapse.py +0 -505
- brainstate/nn/_synapse_test.py +0 -131
- brainstate/nn/_synaptic_projection.py +0 -423
- brainstate/nn/_synouts.py +0 -162
- brainstate/nn/_synouts_test.py +0 -57
- brainstate/nn/metrics.py +0 -388
- brainstate/optim/__init__.py +0 -38
- brainstate/optim/_base.py +0 -64
- brainstate/optim/_lr_scheduler.py +0 -448
- brainstate/optim/_lr_scheduler_test.py +0 -50
- brainstate/optim/_optax_optimizer.py +0 -152
- brainstate/optim/_optax_optimizer_test.py +0 -53
- brainstate/optim/_sgd_optimizer.py +0 -1104
- brainstate/random/_random_for_unit.py +0 -52
- brainstate/surrogate.py +0 -1957
- brainstate/transform.py +0 -23
- brainstate/util/caller.py +0 -98
- brainstate/util/others.py +0 -540
- brainstate/util/pretty_pytree.py +0 -945
- brainstate/util/pretty_pytree_test.py +0 -159
- brainstate/util/pretty_table.py +0 -2954
- brainstate/util/scaling.py +0 -258
- brainstate-0.1.9.dist-info/RECORD +0 -130
- {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/WHEEL +0 -0
- {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/top_level.txt +0 -0
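The most visible structural change in this release is the consolidation of the former brainstate.augment and brainstate.compile packages into a single brainstate.transform package, as the renames above and the diff below show. A minimal migration sketch for downstream code follows; the brainstate.transform.grad call is taken from the docstrings added in this diff, while the old brainstate.augment.grad location and the package-level re-exports are assumptions inferred from the removed @set_module_as("brainstate.augment") decorators.

    # Hypothetical migration sketch (assumed re-exports; names taken from this diff's docstrings).
    import brainstate
    import jax.numpy as jnp

    def loss(x):
        # simple scalar-valued function to differentiate
        return jnp.sum(x ** 2)

    # brainstate 0.1.9 layout (removed in this release, assumed old import path):
    # grad_loss = brainstate.augment.grad(loss)

    # brainstate 0.2.0 layout:
    grad_loss = brainstate.transform.grad(loss)
    print(grad_loss(jnp.array([1.0, 2.0, 3.0])))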
brainstate/{augment → transform}/_autograd.py

@@ -1,4 +1,4 @@
-# Copyright 2024
+# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -35,7 +35,7 @@ import jax

 from brainstate._state import State
 from brainstate._utils import set_module_as
-from brainstate.
+from brainstate.transform._make_jaxpr import StatefulFunction
 from brainstate.typing import PyTree, Missing
 from brainstate.util import PrettyType, PrettyAttr, PrettyRepr

@@ -153,16 +153,90 @@ class GradientTransform(PrettyRepr):
     It allows for flexible configuration of gradient computation with respect to specified states
     and function arguments.

-
-
-
-
-
-
-
+    Parameters
+    ----------
+    target : callable
+        The function to be transformed.
+    transform : callable
+        The transformation function to apply.
+    grad_states : State, sequence of State, or dict of State, optional
+        States to compute gradients for.
+    argnums : int or sequence of int, optional
+        Indices of arguments to differentiate with respect to.
+    return_value : bool, default False
+        Whether to return the function's value along with gradients.
+    has_aux : bool, default False
+        Whether the function returns auxiliary data.
+    transform_params : dict, optional
+        Additional parameters for the transformation function.
+    check_states : bool, default True
+        Whether to check that all grad_states are found in the function.
+
+    Attributes
+    ----------
+    target : callable
+        The function to be transformed.
+    stateful_target : StatefulFunction
+        A wrapper around the target function for state management.
+    raw_argnums : int, sequence of int, or None
+        The original argnums specified by the user.
+    true_argnums : int or tuple of int
+        The adjusted argnums used internally.
+    return_value : bool
+        Whether to return the function's value along with gradients.
+    has_aux : bool
+        Whether the function returns auxiliary data.
+
+    Examples
+    --------
+    Basic gradient computation with states:
+
+    .. code-block:: python
+
+        >>> import brainstate
+        >>> import jax.numpy as jnp
+        >>>
+        >>> # Create states
+        >>> weight = brainstate.State(jnp.array([[1.0, 2.0], [3.0, 4.0]]))
+        >>> bias = brainstate.State(jnp.array([0.5, -0.5]))
+        >>>
+        >>> def loss_fn(x):
+        ...     y = x @ weight.value + bias.value
+        ...     return jnp.sum(y ** 2)
+        >>>
+        >>> # Create gradient transform
+        >>> grad_transform = brainstate.transform.GradientTransform(
+        ...     target=loss_fn,
+        ...     transform=jax.grad,
+        ...     grad_states=[weight, bias]
+        ... )
+        >>>
+        >>> # Compute gradients
+        >>> x = jnp.array([1.0, 2.0])
+        >>> grads = grad_transform(x)
+
+    With function arguments and auxiliary data:
+
+    .. code-block:: python
+
+        >>> def loss_fn_with_aux(x, scale):
+        ...     y = x @ weight.value + bias.value
+        ...     loss = jnp.sum((y * scale) ** 2)
+        ...     return loss, {"predictions": y, "scale": scale}
+        >>>
+        >>> grad_transform = brainstate.transform.GradientTransform(
+        ...     target=loss_fn_with_aux,
+        ...     transform=jax.grad,
+        ...     grad_states=[weight, bias],
+        ...     argnums=[0, 1], # gradient w.r.t x and scale
+        ...     has_aux=True,
+        ...     return_value=True
+        ... )
+        >>>
+        >>> grads, loss_value, aux_data = grad_transform(x, 2.0)
     """

-    __module__ = "brainstate.
+    __module__ = "brainstate.transform"

     def __init__(
         self,
@@ -178,17 +252,29 @@ class GradientTransform(PrettyRepr):
         """
         Initialize a ``GradientTransform`` instance.

-
-
-
-
-
-
-
-
-
-
-
+        Parameters
+        ----------
+        target : callable
+            The function to be transformed.
+        transform : callable
+            The transformation function to apply.
+        grad_states : State, sequence of State, or dict of State, optional
+            States to compute gradients for.
+        argnums : int or sequence of int, optional
+            Indices of arguments to differentiate with respect to.
+        return_value : bool, default False
+            Whether to return the function's value along with gradients.
+        has_aux : bool, default False
+            Whether the function returns auxiliary data.
+        transform_params : dict, optional
+            Additional parameters for the transformation function.
+        check_states : bool, default True
+            Whether to check that all grad_states are found in the function.
+
+        Raises
+        ------
+        TypeError
+            If any grad_states are not State instances.
         """
         # gradient variables
         if isinstance(grad_states, dict):
@@ -221,7 +307,7 @@ class GradientTransform(PrettyRepr):
         # target
         assert callable(target), "The target should be a callable object."
         self.target = target
-        self.stateful_target = StatefulFunction(target, name='gradient')
+        self.stateful_target = StatefulFunction(target, name='gradient', return_only_write=False)

         # transform
         grad_setting = dict() if transform_params is None else transform_params
@@ -307,8 +393,7 @@ class GradientTransform(PrettyRepr):
         Returns:
             Tuple: A tuple containing updated state values and the function output.
         """
-
-        state_trace = self.stateful_target.get_state_trace(cache)
+        state_trace = self.stateful_target.get_state_trace(*args, **kwargs, compile_if_miss=True)
         state_vals = self._merge_state_vals(grad_vals, other_vals, state_trace)
         state_vals, out = self.stateful_target.jaxpr_call(state_vals, *args, **kwargs)
         return state_vals, out
@@ -403,12 +488,18 @@ class GradientTransform(PrettyRepr):
         """
         Compute gradients by calling the transformed function.

-
-
-
-
-
-
+        Parameters
+        ----------
+        *args
+            Positional arguments to pass to the target function.
+        **kwargs
+            Keyword arguments to pass to the target function.
+
+        Returns
+        -------
+        Gradient or tuple
+            The computed gradients, potentially including function value and/or auxiliary data.
+            The exact return structure depends on the settings of return_value and has_aux.
         """

         # TODO: support jax.disable_jit()
@@ -418,79 +509,135 @@ class GradientTransform(PrettyRepr):
         cache = self.stateful_target.get_arg_cache_key(*args, **kwargs)

         # apply the gradient transformation
-        state_trace = self.stateful_target.
+        state_trace = self.stateful_target.get_state_trace_by_cache(cache)
         rets = self._transform(*self._split_state_vals(state_trace), *args, **kwargs)

         # analyze and return the results
         return self._return(rets, state_trace)


-
+@set_module_as("brainstate.transform")
+def grad(
+    fun: Callable = Missing(),
+    grad_states: Optional[Union[State, Sequence[State], Dict[str, State]]] = None,
+    argnums: Optional[Union[int, Sequence[int]]] = None,
+    holomorphic: Optional[bool] = False,
+    allow_int: Optional[bool] = False,
+    has_aux: Optional[bool] = None,
+    return_value: Optional[bool] = False,
+    unit_aware: bool = False,
+    check_states: bool = True,
+) -> GradientTransform | Callable[[Callable], GradientTransform]:
+    """
+    Compute the gradient of a scalar-valued function with respect to its arguments.
+

     1. When ``grad_states`` is None
+
        - ``has_aux=False`` + ``return_value=False`` => ``arg_grads``.
        - ``has_aux=True`` + ``return_value=False`` => ``(arg_grads, aux_data)``.
        - ``has_aux=False`` + ``return_value=True`` => ``(arg_grads, loss_value)``.
        - ``has_aux=True`` + ``return_value=True`` => ``(arg_grads, loss_value, aux_data)``.
     2. When ``grad_states`` is not None and ``argnums`` is None
+
        - ``has_aux=False`` + ``return_value=False`` => ``var_grads``.
        - ``has_aux=True`` + ``return_value=False`` => ``(var_grads, aux_data)``.
        - ``has_aux=False`` + ``return_value=True`` => ``(var_grads, loss_value)``.
        - ``has_aux=True`` + ``return_value=True`` => ``(var_grads, loss_value, aux_data)``.
     3. When ``grad_states`` is not None and ``argnums`` is not None
+
        - ``has_aux=False`` + ``return_value=False`` => ``(var_grads, arg_grads)``.
        - ``has_aux=True`` + ``return_value=False`` => ``((var_grads, arg_grads), aux_data)``.
        - ``has_aux=False`` + ``return_value=True`` => ``((var_grads, arg_grads), loss_value)``.
        - ``has_aux=True`` + ``return_value=True`` => ``((var_grads, arg_grads), loss_value, aux_data)``.

-'''
-
-
-@set_module_as("brainstate.augment")
-def grad(
-    fun: Callable = Missing(),
-    grad_states: Optional[Union[State, Sequence[State], Dict[str, State]]] = None,
-    argnums: Optional[Union[int, Sequence[int]]] = None,
-    holomorphic: Optional[bool] = False,
-    allow_int: Optional[bool] = False,
-    has_aux: Optional[bool] = None,
-    return_value: Optional[bool] = False,
-    unit_aware: bool = False,
-    check_states: bool = True,
-) -> GradientTransform | Callable[[Callable], GradientTransform]:
-    """
-    Compute the gradient of a scalar-valued function with respect to its arguments.

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-    the function returns a pair where the first element is the gradient and the
-    second element is the auxiliary data. If `return_value` is True, the function
-    returns a pair where the first element is the gradient and the second element
-    is the value of the function.
+    Parameters
+    ----------
+    fun : callable, optional
+        The scalar-valued function to be differentiated.
+    grad_states : State, sequence of State, or dict of State, optional
+        The variables in fun to take their gradients.
+    argnums : int or sequence of int, optional
+        Specifies which positional argument(s) to differentiate with respect to.
+    holomorphic : bool, default False
+        Whether fun is promised to be holomorphic.
+    allow_int : bool, default False
+        Whether to allow differentiating with respect to
+        integer valued inputs. The gradient of an integer input will have a trivial
+        vector-space dtype (float0).
+    has_aux : bool, optional
+        Indicates whether fun returns a pair where the
+        first element is considered the output of the mathematical function to be
+        differentiated and the second element is auxiliary data.
+    return_value : bool, default False
+        Indicates whether to return the value of the
+        function along with the gradient.
+    unit_aware : bool, default False
+        Whether to return the gradient in the unit-aware mode.
+    check_states : bool, default True
+        Whether to check that all grad_states are found in the function.

+    Returns
+    -------
+    GradientTransform or callable
+        A function which computes the gradient of fun. The function takes the same
+        arguments as `fun`, but returns the gradient instead. If `has_aux` is True,
+        the function returns a pair where the first element is the gradient and the
+        second element is the auxiliary data. If `return_value` is True, the function
+        returns a pair where the first element is the gradient and the second element
+        is the value of the function.
+
+    Examples
+    --------
+    Basic gradient computation:
+
+    .. code-block:: python
+
+        >>> import brainstate
+        >>> import jax.numpy as jnp
+        >>>
+        >>> # Simple function gradient
+        >>> def f(x):
+        ...     return jnp.sum(x ** 2)
+        >>>
+        >>> grad_f = brainstate.transform.grad(f)
+        >>> x = jnp.array([1.0, 2.0, 3.0])
+        >>> gradient = grad_f(x)
+
+    Gradient with respect to states:
+
+    .. code-block:: python
+
+        >>> # Create states
+        >>> weight = brainstate.State(jnp.array([1.0, 2.0]))
+        >>> bias = brainstate.State(jnp.array([0.5]))
+        >>>
+        >>> def loss_fn(x):
+        ...     prediction = jnp.dot(x, weight.value) + bias.value
+        ...     return prediction ** 2
+        >>>
+        >>> # Compute gradients with respect to states
+        >>> grad_fn = brainstate.transform.grad(loss_fn, grad_states=[weight, bias])
+        >>> x = jnp.array([1.0, 2.0])
+        >>> state_grads = grad_fn(x)
+
+    With auxiliary data and return value:
+
+    .. code-block:: python
+
+        >>> def loss_with_aux(x):
+        ...     prediction = jnp.dot(x, weight.value) + bias.value
+        ...     loss = prediction ** 2
+        ...     return loss, {"prediction": prediction}
+        >>>
+        >>> grad_fn = brainstate.transform.grad(
+        ...     loss_with_aux,
+        ...     grad_states=[weight, bias],
+        ...     has_aux=True,
+        ...     return_value=True
+        ... )
+        >>> grads, loss_value, aux_data = grad_fn(x)
     """
     if isinstance(fun, Missing):
         def transform(fun) -> GradientTransform:
@@ -519,10 +666,7 @@ def grad(
     )


-
-
-
-@set_module_as("brainstate.augment")
+@set_module_as("brainstate.transform")
 def vector_grad(
     func: Callable = Missing(),
     grad_states: Optional[Union[State, Sequence[State], Dict[str, State]]] = None,
@@ -532,34 +676,91 @@ def vector_grad(
     unit_aware: bool = False,
     check_states: bool = True,
 ) -> GradientTransform | Callable[[Callable], GradientTransform]:
-    """
+    """
+    Take vector-valued gradients for function ``func``.

-    Same as :py:func:`grad`, :py:func:`jacrev`, and :py:func:`jacfwd`,
+    Same as :py:func:`grad`, :py:func:`jacrev`, and :py:func:`jacfwd`,
     the returns in this function are different for different argument settings.

-
+
+    1. When ``grad_states`` is None
+
+       - ``has_aux=False`` + ``return_value=False`` => ``arg_grads``.
+       - ``has_aux=True`` + ``return_value=False`` => ``(arg_grads, aux_data)``.
+       - ``has_aux=False`` + ``return_value=True`` => ``(arg_grads, loss_value)``.
+       - ``has_aux=True`` + ``return_value=True`` => ``(arg_grads, loss_value, aux_data)``.
+    2. When ``grad_states`` is not None and ``argnums`` is None
+
+       - ``has_aux=False`` + ``return_value=False`` => ``var_grads``.
+       - ``has_aux=True`` + ``return_value=False`` => ``(var_grads, aux_data)``.
+       - ``has_aux=False`` + ``return_value=True`` => ``(var_grads, loss_value)``.
+       - ``has_aux=True`` + ``return_value=True`` => ``(var_grads, loss_value, aux_data)``.
+    3. When ``grad_states`` is not None and ``argnums`` is not None
+
+       - ``has_aux=False`` + ``return_value=False`` => ``(var_grads, arg_grads)``.
+       - ``has_aux=True`` + ``return_value=False`` => ``((var_grads, arg_grads), aux_data)``.
+       - ``has_aux=False`` + ``return_value=True`` => ``((var_grads, arg_grads), loss_value)``.
+       - ``has_aux=True`` + ``return_value=True`` => ``((var_grads, arg_grads), loss_value, aux_data)``.
+

     Parameters
     ----------
-    func:
+    func : callable, optional
         Function whose gradient is to be computed.
-    grad_states :
+    grad_states : State, sequence of State, or dict of State, optional
         The variables in ``func`` to take their gradients.
-
+    argnums : int or sequence of int, optional
+        Specifies which positional argument(s) to differentiate with respect to.
+    return_value : bool, default False
+        Whether to return the loss value.
+    has_aux : bool, optional
         Indicates whether ``fun`` returns a pair where the
         first element is considered the output of the mathematical function to be
-        differentiated and the second element is auxiliary data.
-
-        Whether return the
-
-
-    unit_aware: (bool) optional. Whether to return the gradient in the unit-aware
-        mode. Default False.
+        differentiated and the second element is auxiliary data.
+    unit_aware : bool, default False
+        Whether to return the gradient in the unit-aware mode.
+    check_states : bool, default True
+        Whether to check that all grad_states are found in the function.

     Returns
     -------
-
+    GradientTransform or callable
         The vector gradient function.
+
+    Examples
+    --------
+    Basic vector gradient computation:
+
+    .. code-block:: python
+
+        >>> import brainstate
+        >>> import jax.numpy as jnp
+        >>>
+        >>> # Vector-valued function
+        >>> def f(x):
+        ...     return jnp.array([x[0]**2, x[1]**3, x[0]*x[1]])
+        >>>
+        >>> vector_grad_f = brainstate.transform.vector_grad(f)
+        >>> x = jnp.array([2.0, 3.0])
+        >>> gradients = vector_grad_f(x) # Shape: (3, 2)
+
+    With states:
+
+    .. code-block:: python
+
+        >>> params = brainstate.State(jnp.array([1.0, 2.0]))
+        >>>
+        >>> def model(x):
+        ...     return jnp.array([
+        ...         x * params.value[0],
+        ...         x**2 * params.value[1]
+        ...     ])
+        >>>
+        >>> vector_grad_fn = brainstate.transform.vector_grad(
+        ...     model, grad_states=[params]
+        ... )
+        >>> x = 3.0
+        >>> param_grads = vector_grad_fn(x)
     """

     if isinstance(func, Missing):
@@ -588,10 +789,7 @@ def vector_grad(
     )


-
-
-
-@set_module_as("brainstate.augment")
+@set_module_as("brainstate.transform")
 def jacrev(
     fun: Callable,
     grad_states: Optional[Union[State, Sequence[State], Dict[str, State]]] = None,
@@ -610,7 +808,26 @@ def jacrev(
     computation on functions and class functions. Moreover, it supports returning
     value ("return_value") and returning auxiliary data ("has_aux").

-
+
+    1. When ``grad_states`` is None
+
+       - ``has_aux=False`` + ``return_value=False`` => ``arg_grads``.
+       - ``has_aux=True`` + ``return_value=False`` => ``(arg_grads, aux_data)``.
+       - ``has_aux=False`` + ``return_value=True`` => ``(arg_grads, loss_value)``.
+       - ``has_aux=True`` + ``return_value=True`` => ``(arg_grads, loss_value, aux_data)``.
+    2. When ``grad_states`` is not None and ``argnums`` is None
+
+       - ``has_aux=False`` + ``return_value=False`` => ``var_grads``.
+       - ``has_aux=True`` + ``return_value=False`` => ``(var_grads, aux_data)``.
+       - ``has_aux=False`` + ``return_value=True`` => ``(var_grads, loss_value)``.
+       - ``has_aux=True`` + ``return_value=True`` => ``(var_grads, loss_value, aux_data)``.
+    3. When ``grad_states`` is not None and ``argnums`` is not None
+
+       - ``has_aux=False`` + ``return_value=False`` => ``(var_grads, arg_grads)``.
+       - ``has_aux=True`` + ``return_value=False`` => ``((var_grads, arg_grads), aux_data)``.
+       - ``has_aux=False`` + ``return_value=True`` => ``((var_grads, arg_grads), loss_value)``.
+       - ``has_aux=True`` + ``return_value=True`` => ``((var_grads, arg_grads), loss_value, aux_data)``.
+


     Parameters
@@ -657,12 +874,10 @@
     )


-jacrev.__doc__ = jacrev.__doc__ % _doc_of_return
-
 jacobian = jacrev


-@set_module_as("brainstate.
+@set_module_as("brainstate.transform")
 def jacfwd(
     func: Callable,
     grad_states: Optional[Union[State, Sequence[State], Dict[str, State]]] = None,
@@ -679,7 +894,26 @@ def jacfwd(
     computation on functions and class functions. Moreover, it supports returning
     value ("return_value") and returning auxiliary data ("has_aux").

-
+
+    1. When ``grad_states`` is None
+
+       - ``has_aux=False`` + ``return_value=False`` => ``arg_grads``.
+       - ``has_aux=True`` + ``return_value=False`` => ``(arg_grads, aux_data)``.
+       - ``has_aux=False`` + ``return_value=True`` => ``(arg_grads, loss_value)``.
+       - ``has_aux=True`` + ``return_value=True`` => ``(arg_grads, loss_value, aux_data)``.
+    2. When ``grad_states`` is not None and ``argnums`` is None
+
+       - ``has_aux=False`` + ``return_value=False`` => ``var_grads``.
+       - ``has_aux=True`` + ``return_value=False`` => ``(var_grads, aux_data)``.
+       - ``has_aux=False`` + ``return_value=True`` => ``(var_grads, loss_value)``.
+       - ``has_aux=True`` + ``return_value=True`` => ``(var_grads, loss_value, aux_data)``.
+    3. When ``grad_states`` is not None and ``argnums`` is not None
+
+       - ``has_aux=False`` + ``return_value=False`` => ``(var_grads, arg_grads)``.
+       - ``has_aux=True`` + ``return_value=False`` => ``((var_grads, arg_grads), aux_data)``.
+       - ``has_aux=False`` + ``return_value=True`` => ``((var_grads, arg_grads), loss_value)``.
+       - ``has_aux=True`` + ``return_value=True`` => ``((var_grads, arg_grads), loss_value, aux_data)``.
+

     Parameters
     ----------
@@ -717,10 +951,7 @@ def jacfwd(
     )


-
-
-
-@set_module_as("brainstate.augment")
+@set_module_as("brainstate.transform")
 def hessian(
     func: Callable,
     grad_states: Optional[Union[State, Sequence[State], Dict[str, State]]] = None,
@@ -734,7 +965,26 @@ def hessian(
     """
     Hessian of ``func`` as a dense array.

-
+
+    1. When ``grad_states`` is None
+
+       - ``has_aux=False`` + ``return_value=False`` => ``arg_grads``.
+       - ``has_aux=True`` + ``return_value=False`` => ``(arg_grads, aux_data)``.
+       - ``has_aux=False`` + ``return_value=True`` => ``(arg_grads, loss_value)``.
+       - ``has_aux=True`` + ``return_value=True`` => ``(arg_grads, loss_value, aux_data)``.
+    2. When ``grad_states`` is not None and ``argnums`` is None
+
+       - ``has_aux=False`` + ``return_value=False`` => ``var_grads``.
+       - ``has_aux=True`` + ``return_value=False`` => ``(var_grads, aux_data)``.
+       - ``has_aux=False`` + ``return_value=True`` => ``(var_grads, loss_value)``.
+       - ``has_aux=True`` + ``return_value=True`` => ``(var_grads, loss_value, aux_data)``.
+    3. When ``grad_states`` is not None and ``argnums`` is not None
+
+       - ``has_aux=False`` + ``return_value=False`` => ``(var_grads, arg_grads)``.
+       - ``has_aux=True`` + ``return_value=False`` => ``((var_grads, arg_grads), aux_data)``.
+       - ``has_aux=False`` + ``return_value=True`` => ``((var_grads, arg_grads), loss_value)``.
+       - ``has_aux=True`` + ``return_value=True`` => ``((var_grads, arg_grads), loss_value, aux_data)``.
+

     Parameters
     ----------
@@ -773,6 +1023,3 @@
         transform_params=dict(holomorphic=holomorphic),
         check_states=check_states
     )
-
-
-hessian.__doc__ = hessian.__doc__ % _doc_of_return
brainstate/{augment → transform}/_autograd_test.py

@@ -1,4 +1,4 @@
-# Copyright 2024
+# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -24,7 +24,7 @@ import jax.numpy as jnp
 import pytest

 import brainstate
-from brainstate.
+from brainstate.transform._autograd import _jacfwd


 class TestPureFuncGrad(unittest.TestCase):
|