brainstate 0.0.2.post20241010__py2.py3-none-any.whl → 0.1.0__py2.py3-none-any.whl
This diff shows the contents of two publicly released versions of the package as published to a supported public registry. It is provided for informational purposes only and reflects the changes between those versions as they appear in that registry.
- brainstate/__init__.py +31 -11
- brainstate/_state.py +760 -316
- brainstate/_state_test.py +41 -12
- brainstate/_utils.py +31 -4
- brainstate/augment/__init__.py +40 -0
- brainstate/augment/_autograd.py +608 -0
- brainstate/augment/_autograd_test.py +1193 -0
- brainstate/augment/_eval_shape.py +102 -0
- brainstate/augment/_eval_shape_test.py +40 -0
- brainstate/augment/_mapping.py +525 -0
- brainstate/augment/_mapping_test.py +210 -0
- brainstate/augment/_random.py +99 -0
- brainstate/{transform → compile}/__init__.py +25 -13
- brainstate/compile/_ad_checkpoint.py +204 -0
- brainstate/compile/_ad_checkpoint_test.py +51 -0
- brainstate/compile/_conditions.py +259 -0
- brainstate/compile/_conditions_test.py +221 -0
- brainstate/compile/_error_if.py +94 -0
- brainstate/compile/_error_if_test.py +54 -0
- brainstate/compile/_jit.py +314 -0
- brainstate/compile/_jit_test.py +143 -0
- brainstate/compile/_loop_collect_return.py +516 -0
- brainstate/compile/_loop_collect_return_test.py +59 -0
- brainstate/compile/_loop_no_collection.py +185 -0
- brainstate/compile/_loop_no_collection_test.py +51 -0
- brainstate/compile/_make_jaxpr.py +756 -0
- brainstate/compile/_make_jaxpr_test.py +134 -0
- brainstate/compile/_progress_bar.py +111 -0
- brainstate/compile/_unvmap.py +159 -0
- brainstate/compile/_util.py +147 -0
- brainstate/environ.py +408 -381
- brainstate/environ_test.py +34 -32
- brainstate/{nn/event → event}/__init__.py +6 -6
- brainstate/event/_csr.py +308 -0
- brainstate/event/_csr_test.py +118 -0
- brainstate/event/_fixed_probability.py +271 -0
- brainstate/event/_fixed_probability_test.py +128 -0
- brainstate/event/_linear.py +219 -0
- brainstate/event/_linear_test.py +112 -0
- brainstate/{nn/event → event}/_misc.py +7 -7
- brainstate/functional/_activations.py +521 -511
- brainstate/functional/_activations_test.py +300 -300
- brainstate/functional/_normalization.py +43 -43
- brainstate/functional/_others.py +15 -15
- brainstate/functional/_spikes.py +49 -49
- brainstate/graph/__init__.py +33 -0
- brainstate/graph/_graph_context.py +443 -0
- brainstate/graph/_graph_context_test.py +65 -0
- brainstate/graph/_graph_convert.py +246 -0
- brainstate/graph/_graph_node.py +300 -0
- brainstate/graph/_graph_node_test.py +75 -0
- brainstate/graph/_graph_operation.py +1746 -0
- brainstate/graph/_graph_operation_test.py +724 -0
- brainstate/init/_base.py +28 -10
- brainstate/init/_generic.py +175 -172
- brainstate/init/_random_inits.py +470 -415
- brainstate/init/_random_inits_test.py +150 -0
- brainstate/init/_regular_inits.py +66 -69
- brainstate/init/_regular_inits_test.py +51 -0
- brainstate/mixin.py +236 -244
- brainstate/mixin_test.py +44 -46
- brainstate/nn/__init__.py +26 -51
- brainstate/nn/_collective_ops.py +199 -0
- brainstate/nn/_dyn_impl/__init__.py +46 -0
- brainstate/nn/_dyn_impl/_dynamics_neuron.py +290 -0
- brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +162 -0
- brainstate/nn/_dyn_impl/_dynamics_synapse.py +320 -0
- brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +132 -0
- brainstate/nn/_dyn_impl/_inputs.py +154 -0
- brainstate/nn/{_projection/__init__.py → _dyn_impl/_projection_alignpost.py} +6 -13
- brainstate/nn/_dyn_impl/_rate_rnns.py +400 -0
- brainstate/nn/_dyn_impl/_rate_rnns_test.py +64 -0
- brainstate/nn/_dyn_impl/_readout.py +128 -0
- brainstate/nn/_dyn_impl/_readout_test.py +54 -0
- brainstate/nn/_dynamics/__init__.py +37 -0
- brainstate/nn/_dynamics/_dynamics_base.py +631 -0
- brainstate/nn/_dynamics/_dynamics_base_test.py +79 -0
- brainstate/nn/_dynamics/_projection_base.py +346 -0
- brainstate/nn/_dynamics/_state_delay.py +453 -0
- brainstate/nn/_dynamics/_synouts.py +161 -0
- brainstate/nn/_dynamics/_synouts_test.py +58 -0
- brainstate/nn/_elementwise/__init__.py +22 -0
- brainstate/nn/_elementwise/_dropout.py +418 -0
- brainstate/nn/_elementwise/_dropout_test.py +100 -0
- brainstate/nn/_elementwise/_elementwise.py +1122 -0
- brainstate/nn/_elementwise/_elementwise_test.py +171 -0
- brainstate/nn/_exp_euler.py +97 -0
- brainstate/nn/_exp_euler_test.py +36 -0
- brainstate/nn/_interaction/__init__.py +32 -0
- brainstate/nn/_interaction/_connections.py +726 -0
- brainstate/nn/_interaction/_connections_test.py +254 -0
- brainstate/nn/_interaction/_embedding.py +59 -0
- brainstate/nn/_interaction/_normalizations.py +388 -0
- brainstate/nn/_interaction/_normalizations_test.py +75 -0
- brainstate/nn/_interaction/_poolings.py +1179 -0
- brainstate/nn/_interaction/_poolings_test.py +219 -0
- brainstate/nn/_module.py +328 -0
- brainstate/nn/_module_test.py +211 -0
- brainstate/nn/metrics.py +309 -309
- brainstate/optim/__init__.py +14 -2
- brainstate/optim/_base.py +66 -0
- brainstate/optim/_lr_scheduler.py +363 -400
- brainstate/optim/_lr_scheduler_test.py +25 -24
- brainstate/optim/_optax_optimizer.py +103 -176
- brainstate/optim/_optax_optimizer_test.py +41 -1
- brainstate/optim/_sgd_optimizer.py +950 -1025
- brainstate/random/_rand_funs.py +3269 -3268
- brainstate/random/_rand_funs_test.py +568 -0
- brainstate/random/_rand_seed.py +149 -117
- brainstate/random/_rand_seed_test.py +50 -0
- brainstate/random/_rand_state.py +1356 -1321
- brainstate/random/_random_for_unit.py +13 -13
- brainstate/surrogate.py +1262 -1243
- brainstate/{nn/_projection/_utils.py → transform.py} +1 -2
- brainstate/typing.py +157 -130
- brainstate/util/__init__.py +52 -0
- brainstate/util/_caller.py +100 -0
- brainstate/util/_dict.py +734 -0
- brainstate/util/_dict_test.py +160 -0
- brainstate/util/_error.py +28 -0
- brainstate/util/_filter.py +178 -0
- brainstate/util/_others.py +497 -0
- brainstate/util/_pretty_repr.py +208 -0
- brainstate/util/_scaling.py +260 -0
- brainstate/util/_struct.py +524 -0
- brainstate/util/_tracers.py +75 -0
- brainstate/{_visualization.py → util/_visualization.py} +16 -16
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.dist-info}/METADATA +11 -11
- brainstate-0.1.0.dist-info/RECORD +135 -0
- brainstate/_module.py +0 -1637
- brainstate/_module_test.py +0 -207
- brainstate/nn/_base.py +0 -251
- brainstate/nn/_connections.py +0 -686
- brainstate/nn/_dynamics.py +0 -426
- brainstate/nn/_elementwise.py +0 -1438
- brainstate/nn/_embedding.py +0 -66
- brainstate/nn/_misc.py +0 -133
- brainstate/nn/_normalizations.py +0 -389
- brainstate/nn/_others.py +0 -101
- brainstate/nn/_poolings.py +0 -1229
- brainstate/nn/_poolings_test.py +0 -231
- brainstate/nn/_projection/_align_post.py +0 -546
- brainstate/nn/_projection/_align_pre.py +0 -599
- brainstate/nn/_projection/_delta.py +0 -241
- brainstate/nn/_projection/_vanilla.py +0 -101
- brainstate/nn/_rate_rnns.py +0 -410
- brainstate/nn/_readout.py +0 -136
- brainstate/nn/_synouts.py +0 -166
- brainstate/nn/event/csr.py +0 -312
- brainstate/nn/event/csr_test.py +0 -118
- brainstate/nn/event/fixed_probability.py +0 -276
- brainstate/nn/event/fixed_probability_test.py +0 -127
- brainstate/nn/event/linear.py +0 -220
- brainstate/nn/event/linear_test.py +0 -111
- brainstate/random/random_test.py +0 -593
- brainstate/transform/_autograd.py +0 -585
- brainstate/transform/_autograd_test.py +0 -1181
- brainstate/transform/_conditions.py +0 -334
- brainstate/transform/_conditions_test.py +0 -220
- brainstate/transform/_error_if.py +0 -94
- brainstate/transform/_error_if_test.py +0 -55
- brainstate/transform/_jit.py +0 -265
- brainstate/transform/_jit_test.py +0 -118
- brainstate/transform/_loop_collect_return.py +0 -502
- brainstate/transform/_loop_no_collection.py +0 -170
- brainstate/transform/_make_jaxpr.py +0 -739
- brainstate/transform/_make_jaxpr_test.py +0 -131
- brainstate/transform/_mapping.py +0 -109
- brainstate/transform/_progress_bar.py +0 -111
- brainstate/transform/_unvmap.py +0 -143
- brainstate/util.py +0 -746
- brainstate-0.0.2.post20241010.dist-info/RECORD +0 -87
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.dist-info}/LICENSE +0 -0
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.dist-info}/WHEEL +0 -0
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.dist-info}/top_level.txt +0 -0
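The listing above shows the main reorganization in 0.1.0: the old `brainstate/transform` package is split into `brainstate/compile` (jit, conditionals, loops, `_make_jaxpr`) and a new `brainstate/augment` package (autograd, mapping, eval_shape), `brainstate/nn/event` moves to a top-level `brainstate/event` package, and `brainstate/util.py` becomes the `brainstate/util` package. The sketch below is one way to check where public names ended up after upgrading; it assumes only that the subpackage paths listed above are importable and does not assume any particular re-exported symbol.

```python
# Hedged sketch: inspect the reorganized subpackages after installing
# brainstate 0.1.0.  Assumes only that the subpackage paths shown in the
# file listing above are importable; adjust the names if they differ.
import importlib

for mod_name in ("brainstate.compile", "brainstate.augment", "brainstate.event", "brainstate.util"):
    try:
        mod = importlib.import_module(mod_name)
    except ImportError as err:  # the layout may differ in other releases
        print(f"{mod_name}: not importable ({err})")
        continue
    public = [n for n in dir(mod) if not n.startswith("_")]
    print(f"{mod_name}: {len(public)} public names, e.g. {public[:5]}")
```

Code that imported from `brainstate.transform` in 0.0.2 should be checked against these new locations; the one-line `brainstate/transform.py` that remains in 0.1.0 (see the rename above) suggests a compatibility shim, but that is not confirmed by the listing alone.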
brainstate/transform/_autograd.py (removed; 585 lines)

@@ -1,585 +0,0 @@
-# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-
-from __future__ import annotations
-
-import inspect
-from functools import partial, wraps
-from typing import Union, Callable, Dict, Sequence, Optional, Any, Tuple
-
-import jax
-from jax import numpy as jnp
-from jax._src.api import _vjp
-from jax.api_util import argnums_partial
-from jax.extend import linear_util
-
-from brainstate._state import State, StateTrace, StateDictManager
-from brainstate._utils import set_module_as
-
-__all__ = [
-  'vector_grad', 'grad', 'jacrev', 'jacfwd', 'jacobian', 'hessian',
-]
-
-
-def _isgeneratorfunction(fun):
-  # re-implemented here because of https://bugs.python.org/issue33261
-  while inspect.ismethod(fun):
-    fun = fun.__func__
-  while isinstance(fun, partial):
-    fun = fun.func
-  return inspect.isfunction(fun) and bool(fun.__code__.co_flags & inspect.CO_GENERATOR)
-
-
-def _check_callable(fun):
-  # In Python 3.10+, the only thing stopping us from supporting staticmethods
-  # is that we can't take weak references to them, which the C++ JIT requires.
-  if isinstance(fun, staticmethod):
-    raise TypeError(f"staticmethod arguments are not supported, got {fun}")
-  if not callable(fun):
-    raise TypeError(f"Expected a callable value, got {fun}")
-  if _isgeneratorfunction(fun):
-    raise TypeError(f"Expected a function, got a generator function: {fun}")
-
-
-def functional_vector_grad(func, argnums=0, return_value: bool = False, has_aux: bool = False):
-  """
-  Compute the gradient of a vector with respect to the input.
-  """
-  _check_callable(func)
-
-  @wraps(func)
-  def grad_fun(*args, **kwargs):
-    f = linear_util.wrap_init(func, kwargs)
-    f_partial, dyn_args = argnums_partial(f, argnums, args, require_static_args_hashable=False)
-    if has_aux:
-      y, vjp_fn, aux = _vjp(f_partial, *dyn_args, has_aux=True)
-    else:
-      y, vjp_fn = _vjp(f_partial, *dyn_args, has_aux=False)
-    leaves, tree = jax.tree.flatten(y)
-    tangents = jax.tree.unflatten(tree, [jnp.ones(l.shape, dtype=l.dtype) for l in leaves])
-    grads = vjp_fn(tangents)
-    if isinstance(argnums, int):
-      grads = grads[0]
-    if has_aux:
-      return (grads, y, aux) if return_value else (grads, aux)
-    else:
-      return (grads, y) if return_value else grads
-
-  return grad_fun
-
-
-def _jacrev(fun, argnums=0, holomorphic=False, allow_int=False, has_aux=False, return_value=False):
-  @wraps(fun)
-  def fun_wrapped(*args, **kwargs):
-    if has_aux:
-      y, aux = fun(*args, **kwargs)
-      if return_value:
-        return y, (y, aux)
-      else:
-        return y, aux
-    else:
-      y = fun(*args, **kwargs)
-      if return_value:
-        return y, y
-      else:
-        return y, None
-
-  transform = jax.jacrev(fun_wrapped, argnums=argnums, holomorphic=holomorphic, allow_int=allow_int, has_aux=True)
-
-  @wraps(fun)
-  def jacfun(*args, **kwargs):
-    jac, aux = transform(*args, **kwargs)
-    if return_value:
-      return (jac, aux[0], aux[1]) if has_aux else (jac, aux)
-    else:
-      return (jac, aux) if has_aux else jac
-
-  return jacfun
-
-
-def _jacfwd(fun, argnums=0, holomorphic=False, has_aux=False, return_value=False):
-  @wraps(fun)
-  def fun_wrapped(*args, **kwargs):
-    if has_aux:
-      y, aux = fun(*args, **kwargs)
-      if return_value:
-        return y, (y, aux)
-      else:
-        return y, aux
-    else:
-      y = fun(*args, **kwargs)
-      if return_value:
-        return y, y
-      else:
-        return y, None
-
-  transform = jax.jacfwd(fun_wrapped, argnums=argnums, holomorphic=holomorphic, has_aux=True)
-
-  @wraps(fun)
-  def jacfun(*args, **kwargs):
-    jac, aux = transform(*args, **kwargs)
-    if return_value:
-      return (jac, aux[0], aux[1]) if has_aux else (jac, aux)
-    else:
-      return (jac, aux) if has_aux else jac
-
-  return jacfun
-
-
-class GradientTransform(object):
-  """
-  Automatic Differentiation Transformations for the ``State`` system.
-  """
-  __module__ = "brainstate.transform"
-
-  def __init__(
-      self,
-      target: Callable,
-      transform: Callable,
-      grad_vars: Any,
-      argnums: Optional[Union[int, Sequence[int]]],
-      return_value: bool,
-      has_aux: bool,
-      transform_params: Optional[Dict[str, Any]] = None,
-  ):
-    # gradient variables
-    if isinstance(grad_vars, StateDictManager):
-      grad_vars = {k: v for k, v in grad_vars.items()}
-    self._grad_vars, self._grad_tree = jax.tree.flatten(grad_vars)
-    if any(not isinstance(v, State) for v in self._grad_vars):
-      raise TypeError("All grad_vars must be State instances.")
-
-    # parameters
-    if argnums is None and len(self._grad_vars) == 0:
-      argnums = 0
-    if argnums is None:
-      assert len(self._grad_vars) > 0
-      _argnums = 0
-    elif isinstance(argnums, int):
-      _argnums = (0, argnums + 1) if len(self._grad_vars) > 0 else (argnums + 1)
-    else:
-      assert isinstance(argnums, (tuple, list))
-      _argnums = tuple(a + 1 for a in argnums)
-      if len(self._grad_vars) > 0:
-        _argnums = (0,) + _argnums
-    self._nonvar_argnums = argnums
-    self._argnums = _argnums
-    self._return_value = return_value
-    self._has_aux = has_aux
-
-    # target
-    self.target = target
-
-    # transform
-    self._states_to_be_written: Tuple[State, ...] = None
-    _grad_setting = dict() if transform_params is None else transform_params
-    if self._has_aux:
-      self._transform = transform(self._fun_with_aux, argnums=self._argnums, has_aux=True, **_grad_setting)
-    else:
-      self._transform = transform(self._fun_without_aux, argnums=self._argnums, has_aux=True, **_grad_setting)
-
-  def __repr__(self):
-    name = self.__class__.__name__
-    format_ref = (f'{name}(target={self.target}, \n' +
-                  f'{" " * len(name)} num_of_grad_vars={len(self._grad_vars)}, \n'
-                  f'{" " * len(name)} num_of_dyn_vars={len(self._states_to_be_written)})')
-    return format_ref
-
-  def __call_target(self, *args, **kwargs):
-    if self._states_to_be_written is None:
-      with StateTrace() as stack:
-        output = self.target(*args, **kwargs)
-        grad_ids = set([id(v) for v in self._grad_vars])
-        self._states_to_be_written = tuple(st for st, ty in zip(stack.states, stack.types)
-                                           if ty == 'write' and id(st) not in grad_ids)
-    else:
-      output = self.target(*args, **kwargs)
-    return output
-
-  def _fun_with_aux(self, grad_values: tuple, *args, **kwargs):
-    for v, d in zip(self._grad_vars, grad_values):
-      v._value = d
-    # Users should return the auxiliary data like::
-    # >>> # 1. example of return one data
-    # >>> return scalar_loss, data
-    # >>> # 2. example of return multiple data
-    # >>> return scalar_loss, (data1, data2, ...)
-    outs = self.__call_target(*args, **kwargs)
-    # outputs: [0] is the value for gradient,
-    # [1] is other values for return
-    assert self._states_to_be_written is not None, "The states to be written should be collected."
-    return outs[0], (outs, [v.value for v in self._grad_vars], [v.value for v in self._states_to_be_written])
-
-  def _fun_without_aux(self, grad_values: tuple, *args, **kwargs):
-    for v, d in zip(self._grad_vars, grad_values):
-      v._value = d
-    # Users should return the scalar value like this::
-    # >>> return scalar_loss
-    out = self.__call_target(*args, **kwargs)
-    assert self._states_to_be_written is not None, "The states to be written should be collected."
-    return out, (out, [v.value for v in self._grad_vars], [v.value for v in self._states_to_be_written])
-
-  def __return(self, rets):
-    grads, (outputs, new_grad_vals, new_dyn_vals) = rets
-    for i, val in enumerate(new_grad_vals):
-      self._grad_vars[i].value = val
-    for i, val in enumerate(new_dyn_vals):
-      self._states_to_be_written[i].value = val
-
-    # check returned grads
-    if len(self._grad_vars) > 0:
-      if self._nonvar_argnums is None:
-        grads = self._grad_tree.unflatten(grads)
-      else:
-        var_grads = self._grad_tree.unflatten(grads[0])
-        arg_grads = grads[1] if isinstance(self._nonvar_argnums, int) else grads[1:]
-        grads = (var_grads, arg_grads)
-
-    # check returned value
-    if self._return_value:
-      # check aux
-      if self._has_aux:
-        return grads, outputs[0], outputs[1]
-      else:
-        return grads, outputs
-    else:
-      # check aux
-      if self._has_aux:
-        return grads, outputs[1]
-      else:
-        return grads
-
-  def __call__(self, *args, **kwargs):
-    rets = self._transform([v.value for v in self._grad_vars], *args, **kwargs)
-    return self.__return(rets)
-
-
-@set_module_as("brainstate.transform")
-def grad(
-    fun: Optional[Callable] = None,
-    grad_vars: Optional[Union[State, Sequence[State], Dict[str, State]]] = None,
-    argnums: Optional[Union[int, Sequence[int]]] = None,
-    holomorphic: Optional[bool] = False,
-    allow_int: Optional[bool] = False,
-    reduce_axes: Optional[Sequence[str]] = (),
-    has_aux: Optional[bool] = None,
-    return_value: Optional[bool] = False,
-) -> GradientTransform | Callable[[Callable], GradientTransform]:
-  """
-  Compute the gradient of a scalar-valued function with respect to its arguments.
-
-  Args:
-    reduce_axes:
-    allow_int:
-    holomorphic:
-    grad_vars:
-    fun: the scalar-valued function to be differentiated.
-    argnums: (int or tuple of ints) optional. Specifies which positional
-      argument(s) to differentiate with respect to.
-    has_aux: (bool) optional. Indicates whether fun returns a pair where the
-      first element is considered the output of the mathematical function to be
-      differentiated and the second element is auxiliary data. Default False.
-    return_value: (bool) optional. Indicates whether to return the value of the
-      function along with the gradient. Default False.
-
-  Returns:
-    A function which computes the gradient of fun. The function takes the same
-    arguments as `fun`, but returns the gradient instead. If `has_aux` is True,
-    the function returns a pair where the first element is the gradient and the
-    second element is the auxiliary data. If `return_value` is True, the function
-    returns a pair where the first element is the gradient and the second element
-    is the value of the function.
-
-  """
-  if fun is None:
-    def transform(fun) -> GradientTransform:
-      return GradientTransform(target=fun,
-                               transform=jax.grad,
-                               grad_vars=grad_vars,
-                               argnums=argnums,
-                               return_value=return_value,
-                               has_aux=False if has_aux is None else has_aux,
-                               transform_params=dict(holomorphic=holomorphic,
-                                                     allow_int=allow_int,
-                                                     reduce_axes=reduce_axes))
-
-    return transform
-
-  return GradientTransform(target=fun,
-                           transform=jax.grad,
-                           grad_vars=grad_vars,
-                           argnums=argnums,
-                           return_value=return_value,
-                           has_aux=False if has_aux is None else has_aux,
-                           transform_params=dict(holomorphic=holomorphic,
-                                                 allow_int=allow_int,
-                                                 reduce_axes=reduce_axes))
-
-
-@set_module_as("brainstate.transform")
-def vector_grad(
-    func: Optional[Callable] = None,
-    grad_vars: Optional[Union[State, Sequence[State], Dict[str, State]]] = None,
-    argnums: Optional[Union[int, Sequence[int]]] = None,
-    return_value: bool = False,
-    has_aux: Optional[bool] = None,
-) -> GradientTransform | Callable[[Callable], GradientTransform]:
-  """Take vector-valued gradients for function ``func``.
-
-  Same as `brainpy.math.grad <./brainpy.math.autograd.grad.html>`_,
-  `brainpy.math.jacrev <./brainpy.math.autograd.jacrev.html>`_ and
-  `brainpy.math.jacfwd <./brainpy.math.autograd.jacfwd.html>`_,
-  the returns in this function are different for different argument settings.
-
-  1. When "grad_vars" is None
-    - "has_aux=False" + "return_value=False" => ``arg_grads``.
-    - "has_aux=True" + "return_value=False" => ``(arg_grads, aux_data)``.
-    - "has_aux=False" + "return_value=True" => ``(arg_grads, loss_value)``.
-    - "has_aux=True" + "return_value=True" => ``(arg_grads, loss_value, aux_data)``.
-  2. When "grad_vars" is not None and "argnums" is None
-    - "has_aux=False" + "return_value=False" => ``var_grads``.
-    - "has_aux=True" + "return_value=False" => ``(var_grads, aux_data)``.
-    - "has_aux=False" + "return_value=True" => ``(var_grads, loss_value)``.
-    - "has_aux=True" + "return_value=True" => ``(var_grads, loss_value, aux_data)``.
-  3. When "grad_vars" is not None and "argnums" is not None
-    - "has_aux=False" + "return_value=False" => ``(var_grads, arg_grads)``.
-    - "has_aux=True" + "return_value=False" => ``((var_grads, arg_grads), aux_data)``.
-    - "has_aux=False" + "return_value=True" => ``((var_grads, arg_grads), loss_value)``.
-    - "has_aux=True" + "return_value=True" => ``((var_grads, arg_grads), loss_value, aux_data)``.
-
-
-  Parameters
-  ----------
-  func: Callable
-    Function whose gradient is to be computed.
-  grad_vars : optional, ArrayType, sequence of ArrayType, dict
-    The variables in ``func`` to take their gradients.
-  has_aux: optional, bool
-    Indicates whether ``fun`` returns a pair where the
-    first element is considered the output of the mathematical function to be
-    differentiated and the second element is auxiliary data. Default False.
-  return_value : bool
-    Whether return the loss value.
-  argnums: Optional, integer or sequence of integers. Specifies which
-    positional argument(s) to differentiate with respect to (default ``0``).
-
-  Returns
-  -------
-  func : GradientTransform
-    The vector gradient function.
-  """
-
-  if func is None:
-    def transform(fun) -> GradientTransform:
-      return GradientTransform(target=fun,
-                               transform=functional_vector_grad,
-                               grad_vars=grad_vars,
-                               argnums=argnums,
-                               return_value=return_value,
-                               has_aux=False if has_aux is None else has_aux)
-
-    return transform
-
-  else:
-    return GradientTransform(target=func,
-                             transform=functional_vector_grad,
-                             grad_vars=grad_vars,
-                             argnums=argnums,
-                             return_value=return_value,
-                             has_aux=False if has_aux is None else has_aux)
-
-
-@set_module_as("brainstate.transform")
-def jacrev(
-    func: Callable,
-    grad_vars: Optional[Union[State, Sequence[State], Dict[str, State]]] = None,
-    argnums: Optional[Union[int, Sequence[int]]] = None,
-    has_aux: Optional[bool] = None,
-    return_value: bool = False,
-    holomorphic: bool = False,
-    allow_int: bool = False,
-) -> GradientTransform:
-  """
-  Extending automatic Jacobian (reverse-mode) of ``func`` to classes.
-
-  This function extends the JAX official ``jacrev`` to make automatic jacobian
-  computation on functions and class functions. Moreover, it supports returning
-  value ("return_value") and returning auxiliary data ("has_aux").
-
-  Same as `brainpy.math.grad <./brainpy.math.autograd.grad.html>`_, the returns are
-  different for different argument settings in ``brainpy.math.jacrev``.
-
-  1. When "grad_vars" is None
-    - "has_aux=False" + "return_value=False" => ``arg_grads``.
-    - "has_aux=True" + "return_value=False" => ``(arg_grads, aux_data)``.
-    - "has_aux=False" + "return_value=True" => ``(arg_grads, loss_value)``.
-    - "has_aux=True" + "return_value=True" => ``(arg_grads, loss_value, aux_data)``.
-  2. When "grad_vars" is not None and "argnums" is None
-    - "has_aux=False" + "return_value=False" => ``var_grads``.
-    - "has_aux=True" + "return_value=False" => ``(var_grads, aux_data)``.
-    - "has_aux=False" + "return_value=True" => ``(var_grads, loss_value)``.
-    - "has_aux=True" + "return_value=True" => ``(var_grads, loss_value, aux_data)``.
-  3. When "grad_vars" is not None and "argnums" is not None
-    - "has_aux=False" + "return_value=False" => ``(var_grads, arg_grads)``.
-    - "has_aux=True" + "return_value=False" => ``((var_grads, arg_grads), aux_data)``.
-    - "has_aux=False" + "return_value=True" => ``((var_grads, arg_grads), loss_value)``.
-    - "has_aux=True" + "return_value=True" => ``((var_grads, arg_grads), loss_value, aux_data)``.
-
-  Parameters
-  ----------
-  func: Function whose Jacobian is to be computed.
-  grad_vars : optional, ArrayType, sequence of ArrayType, dict
-    The variables in ``func`` to take their gradients.
-  has_aux: optional, bool
-    Indicates whether ``fun`` returns a pair where the
-    first element is considered the output of the mathematical function to be
-    differentiated and the second element is auxiliary data. Default False.
-  return_value : bool
-    Whether return the loss value.
-  argnums: Optional, integer or sequence of integers.
-    Specifies which
-    positional argument(s) to differentiate with respect to (default ``0``).
-  holomorphic: Optional, bool.
-    Indicates whether ``fun`` is promised to be
-    holomorphic. Default False.
-  allow_int: Optional, bool.
-    Whether to allow differentiating with
-    respect to integer valued inputs. The gradient of an integer input will
-    have a trivial vector-space dtype (float0). Default False.
-
-  Returns
-  -------
-  fun: GradientTransform
-    The transformed object.
-  """
-  return GradientTransform(target=func,
-                           transform=_jacrev,
-                           grad_vars=grad_vars,
-                           argnums=argnums,
-                           return_value=return_value,
-                           has_aux=False if has_aux is None else has_aux,
-                           transform_params=dict(holomorphic=holomorphic,
-                                                 allow_int=allow_int))
-
-
-jacobian = jacrev
-
-
-@set_module_as("brainstate.transform")
-def jacfwd(
-    func: Callable,
-    grad_vars: Optional[Union[State, Sequence[State], Dict[str, State]]] = None,
-    argnums: Optional[Union[int, Sequence[int]]] = None,
-    has_aux: Optional[bool] = None,
-    return_value: bool = False,
-    holomorphic: bool = False,
-) -> GradientTransform:
-  """Extending automatic Jacobian (forward-mode) of ``func`` to classes.
-
-  This function extends the JAX official ``jacfwd`` to make automatic jacobian
-  computation on functions and class functions. Moreover, it supports returning
-  value ("return_value") and returning auxiliary data ("has_aux").
-
-  Same as `brainpy.math.grad <./brainpy.math.autograd.grad.html>`_, the returns are
-  different for different argument settings in ``brainpy.math.jacfwd``.
-
-  1. When "grad_vars" is None
-    - "has_aux=False" + "return_value=False" => ``arg_grads``.
-    - "has_aux=True" + "return_value=False" => ``(arg_grads, aux_data)``.
-    - "has_aux=False" + "return_value=True" => ``(arg_grads, loss_value)``.
-    - "has_aux=True" + "return_value=True" => ``(arg_grads, loss_value, aux_data)``.
-  2. When "grad_vars" is not None and "argnums" is None
-    - "has_aux=False" + "return_value=False" => ``var_grads``.
-    - "has_aux=True" + "return_value=False" => ``(var_grads, aux_data)``.
-    - "has_aux=False" + "return_value=True" => ``(var_grads, loss_value)``.
-    - "has_aux=True" + "return_value=True" => ``(var_grads, loss_value, aux_data)``.
-  3. When "grad_vars" is not None and "argnums" is not None
-    - "has_aux=False" + "return_value=False" => ``(var_grads, arg_grads)``.
-    - "has_aux=True" + "return_value=False" => ``((var_grads, arg_grads), aux_data)``.
-    - "has_aux=False" + "return_value=True" => ``((var_grads, arg_grads), loss_value)``.
-    - "has_aux=True" + "return_value=True" => ``((var_grads, arg_grads), loss_value, aux_data)``.
-
-  Parameters
-  ----------
-  func: Function whose Jacobian is to be computed.
-  grad_vars : optional, ArrayType, sequence of ArrayType, dict
-    The variables in ``func`` to take their gradients.
-  has_aux: optional, bool
-    Indicates whether ``fun`` returns a pair where the
-    first element is considered the output of the mathematical function to be
-    differentiated and the second element is auxiliary data. Default False.
-  return_value : bool
-    Whether return the loss value.
-  argnums: Optional, integer or sequence of integers. Specifies which
-    positional argument(s) to differentiate with respect to (default ``0``).
-  holomorphic: Optional, bool. Indicates whether ``fun`` is promised to be
-    holomorphic. Default False.
-
-  Returns
-  -------
-  obj: GradientTransform
-    The transformed object.
-  """
-
-  return GradientTransform(target=func,
-                           transform=_jacfwd,
-                           grad_vars=grad_vars,
-                           argnums=argnums,
-                           return_value=return_value,
-                           has_aux=False if has_aux is None else has_aux,
-                           transform_params=dict(holomorphic=holomorphic))
-
-
-@set_module_as("brainstate.transform")
-def hessian(
-    func: Callable,
-    grad_vars: Optional[Union[State, Sequence[State], Dict[str, State]]] = None,
-    argnums: Optional[Union[int, Sequence[int]]] = None,
-    return_value: bool = False,
-    holomorphic: bool = False,
-) -> GradientTransform:
-  """Hessian of ``func`` as a dense array.
-
-  Parameters
-  ----------
-  func : callable
-    Function whose Hessian is to be computed. Its arguments at positions
-    specified by ``argnums`` should be arrays, scalars, or standard Python
-    containers thereof. It should return arrays, scalars, or standard Python
-    containers thereof.
-  grad_vars : optional, ArrayCollector, sequence of ArrayType
-    The variables required to compute their gradients.
-  argnums: Optional, integer or sequence of integers
-    Specifies which positional argument(s) to differentiate with respect to (default ``0``).
-  holomorphic : bool
-    Indicates whether ``fun`` is promised to be holomorphic. Default False.
-  return_value : bool
-    Whether return the hessian values.
-
-  Returns
-  -------
-  obj: ObjectTransform
-    The transformed object.
-  """
-  raise NotImplementedError("The hessian computation is not supported yet.")
-
-  # return jacfwd(jacrev(func,
-  # grad_vars=grad_vars,
-  # argnums=argnums,
-  # holomorphic=holomorphic),
-  # grad_vars=grad_vars,
-  # argnums=argnums,
-  # holomorphic=holomorphic,
-  # return_value=return_value)