brainstate 0.0.2.post20241009__py2.py3-none-any.whl → 0.1.0__py2.py3-none-any.whl

This diff compares two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their public registry.
Files changed (175)
  1. brainstate/__init__.py +31 -11
  2. brainstate/_state.py +760 -316
  3. brainstate/_state_test.py +41 -12
  4. brainstate/_utils.py +31 -4
  5. brainstate/augment/__init__.py +40 -0
  6. brainstate/augment/_autograd.py +608 -0
  7. brainstate/augment/_autograd_test.py +1193 -0
  8. brainstate/augment/_eval_shape.py +102 -0
  9. brainstate/augment/_eval_shape_test.py +40 -0
  10. brainstate/augment/_mapping.py +525 -0
  11. brainstate/augment/_mapping_test.py +210 -0
  12. brainstate/augment/_random.py +99 -0
  13. brainstate/{transform → compile}/__init__.py +25 -13
  14. brainstate/compile/_ad_checkpoint.py +204 -0
  15. brainstate/compile/_ad_checkpoint_test.py +51 -0
  16. brainstate/compile/_conditions.py +259 -0
  17. brainstate/compile/_conditions_test.py +221 -0
  18. brainstate/compile/_error_if.py +94 -0
  19. brainstate/compile/_error_if_test.py +54 -0
  20. brainstate/compile/_jit.py +314 -0
  21. brainstate/compile/_jit_test.py +143 -0
  22. brainstate/compile/_loop_collect_return.py +516 -0
  23. brainstate/compile/_loop_collect_return_test.py +59 -0
  24. brainstate/compile/_loop_no_collection.py +185 -0
  25. brainstate/compile/_loop_no_collection_test.py +51 -0
  26. brainstate/compile/_make_jaxpr.py +756 -0
  27. brainstate/compile/_make_jaxpr_test.py +134 -0
  28. brainstate/compile/_progress_bar.py +111 -0
  29. brainstate/compile/_unvmap.py +159 -0
  30. brainstate/compile/_util.py +147 -0
  31. brainstate/environ.py +408 -381
  32. brainstate/environ_test.py +34 -32
  33. brainstate/{nn/event → event}/__init__.py +6 -6
  34. brainstate/event/_csr.py +308 -0
  35. brainstate/event/_csr_test.py +118 -0
  36. brainstate/event/_fixed_probability.py +271 -0
  37. brainstate/event/_fixed_probability_test.py +128 -0
  38. brainstate/event/_linear.py +219 -0
  39. brainstate/event/_linear_test.py +112 -0
  40. brainstate/{nn/event → event}/_misc.py +7 -7
  41. brainstate/functional/_activations.py +521 -511
  42. brainstate/functional/_activations_test.py +300 -300
  43. brainstate/functional/_normalization.py +43 -43
  44. brainstate/functional/_others.py +15 -15
  45. brainstate/functional/_spikes.py +49 -49
  46. brainstate/graph/__init__.py +33 -0
  47. brainstate/graph/_graph_context.py +443 -0
  48. brainstate/graph/_graph_context_test.py +65 -0
  49. brainstate/graph/_graph_convert.py +246 -0
  50. brainstate/graph/_graph_node.py +300 -0
  51. brainstate/graph/_graph_node_test.py +75 -0
  52. brainstate/graph/_graph_operation.py +1746 -0
  53. brainstate/graph/_graph_operation_test.py +724 -0
  54. brainstate/init/_base.py +28 -10
  55. brainstate/init/_generic.py +175 -172
  56. brainstate/init/_random_inits.py +470 -415
  57. brainstate/init/_random_inits_test.py +150 -0
  58. brainstate/init/_regular_inits.py +66 -69
  59. brainstate/init/_regular_inits_test.py +51 -0
  60. brainstate/mixin.py +236 -244
  61. brainstate/mixin_test.py +44 -46
  62. brainstate/nn/__init__.py +26 -51
  63. brainstate/nn/_collective_ops.py +199 -0
  64. brainstate/nn/_dyn_impl/__init__.py +46 -0
  65. brainstate/nn/_dyn_impl/_dynamics_neuron.py +290 -0
  66. brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +162 -0
  67. brainstate/nn/_dyn_impl/_dynamics_synapse.py +320 -0
  68. brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +132 -0
  69. brainstate/nn/_dyn_impl/_inputs.py +154 -0
  70. brainstate/nn/{_projection/__init__.py → _dyn_impl/_projection_alignpost.py} +6 -13
  71. brainstate/nn/_dyn_impl/_rate_rnns.py +400 -0
  72. brainstate/nn/_dyn_impl/_rate_rnns_test.py +64 -0
  73. brainstate/nn/_dyn_impl/_readout.py +128 -0
  74. brainstate/nn/_dyn_impl/_readout_test.py +54 -0
  75. brainstate/nn/_dynamics/__init__.py +37 -0
  76. brainstate/nn/_dynamics/_dynamics_base.py +631 -0
  77. brainstate/nn/_dynamics/_dynamics_base_test.py +79 -0
  78. brainstate/nn/_dynamics/_projection_base.py +346 -0
  79. brainstate/nn/_dynamics/_state_delay.py +453 -0
  80. brainstate/nn/_dynamics/_synouts.py +161 -0
  81. brainstate/nn/_dynamics/_synouts_test.py +58 -0
  82. brainstate/nn/_elementwise/__init__.py +22 -0
  83. brainstate/nn/_elementwise/_dropout.py +418 -0
  84. brainstate/nn/_elementwise/_dropout_test.py +100 -0
  85. brainstate/nn/_elementwise/_elementwise.py +1122 -0
  86. brainstate/nn/_elementwise/_elementwise_test.py +171 -0
  87. brainstate/nn/_exp_euler.py +97 -0
  88. brainstate/nn/_exp_euler_test.py +36 -0
  89. brainstate/nn/_interaction/__init__.py +32 -0
  90. brainstate/nn/_interaction/_connections.py +726 -0
  91. brainstate/nn/_interaction/_connections_test.py +254 -0
  92. brainstate/nn/_interaction/_embedding.py +59 -0
  93. brainstate/nn/_interaction/_normalizations.py +388 -0
  94. brainstate/nn/_interaction/_normalizations_test.py +75 -0
  95. brainstate/nn/_interaction/_poolings.py +1179 -0
  96. brainstate/nn/_interaction/_poolings_test.py +219 -0
  97. brainstate/nn/_module.py +328 -0
  98. brainstate/nn/_module_test.py +211 -0
  99. brainstate/nn/metrics.py +309 -309
  100. brainstate/optim/__init__.py +14 -2
  101. brainstate/optim/_base.py +66 -0
  102. brainstate/optim/_lr_scheduler.py +363 -400
  103. brainstate/optim/_lr_scheduler_test.py +25 -24
  104. brainstate/optim/_optax_optimizer.py +103 -176
  105. brainstate/optim/_optax_optimizer_test.py +41 -1
  106. brainstate/optim/_sgd_optimizer.py +950 -1025
  107. brainstate/random/_rand_funs.py +3269 -3268
  108. brainstate/random/_rand_funs_test.py +568 -0
  109. brainstate/random/_rand_seed.py +149 -117
  110. brainstate/random/_rand_seed_test.py +50 -0
  111. brainstate/random/_rand_state.py +1360 -1318
  112. brainstate/random/_random_for_unit.py +13 -13
  113. brainstate/surrogate.py +1262 -1243
  114. brainstate/{nn/_projection/_utils.py → transform.py} +1 -2
  115. brainstate/typing.py +157 -130
  116. brainstate/util/__init__.py +52 -0
  117. brainstate/util/_caller.py +100 -0
  118. brainstate/util/_dict.py +734 -0
  119. brainstate/util/_dict_test.py +160 -0
  120. brainstate/util/_error.py +28 -0
  121. brainstate/util/_filter.py +178 -0
  122. brainstate/util/_others.py +497 -0
  123. brainstate/util/_pretty_repr.py +208 -0
  124. brainstate/util/_scaling.py +260 -0
  125. brainstate/util/_struct.py +524 -0
  126. brainstate/util/_tracers.py +75 -0
  127. brainstate/{_visualization.py → util/_visualization.py} +16 -16
  128. {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/METADATA +11 -11
  129. brainstate-0.1.0.dist-info/RECORD +135 -0
  130. brainstate/_module.py +0 -1637
  131. brainstate/_module_test.py +0 -207
  132. brainstate/nn/_base.py +0 -251
  133. brainstate/nn/_connections.py +0 -686
  134. brainstate/nn/_dynamics.py +0 -426
  135. brainstate/nn/_elementwise.py +0 -1438
  136. brainstate/nn/_embedding.py +0 -66
  137. brainstate/nn/_misc.py +0 -133
  138. brainstate/nn/_normalizations.py +0 -389
  139. brainstate/nn/_others.py +0 -101
  140. brainstate/nn/_poolings.py +0 -1229
  141. brainstate/nn/_poolings_test.py +0 -231
  142. brainstate/nn/_projection/_align_post.py +0 -546
  143. brainstate/nn/_projection/_align_pre.py +0 -599
  144. brainstate/nn/_projection/_delta.py +0 -241
  145. brainstate/nn/_projection/_vanilla.py +0 -101
  146. brainstate/nn/_rate_rnns.py +0 -410
  147. brainstate/nn/_readout.py +0 -136
  148. brainstate/nn/_synouts.py +0 -166
  149. brainstate/nn/event/csr.py +0 -312
  150. brainstate/nn/event/csr_test.py +0 -118
  151. brainstate/nn/event/fixed_probability.py +0 -276
  152. brainstate/nn/event/fixed_probability_test.py +0 -127
  153. brainstate/nn/event/linear.py +0 -220
  154. brainstate/nn/event/linear_test.py +0 -111
  155. brainstate/random/random_test.py +0 -593
  156. brainstate/transform/_autograd.py +0 -585
  157. brainstate/transform/_autograd_test.py +0 -1181
  158. brainstate/transform/_conditions.py +0 -334
  159. brainstate/transform/_conditions_test.py +0 -220
  160. brainstate/transform/_error_if.py +0 -94
  161. brainstate/transform/_error_if_test.py +0 -55
  162. brainstate/transform/_jit.py +0 -265
  163. brainstate/transform/_jit_test.py +0 -118
  164. brainstate/transform/_loop_collect_return.py +0 -502
  165. brainstate/transform/_loop_no_collection.py +0 -170
  166. brainstate/transform/_make_jaxpr.py +0 -739
  167. brainstate/transform/_make_jaxpr_test.py +0 -131
  168. brainstate/transform/_mapping.py +0 -109
  169. brainstate/transform/_progress_bar.py +0 -111
  170. brainstate/transform/_unvmap.py +0 -143
  171. brainstate/util.py +0 -746
  172. brainstate-0.0.2.post20241009.dist-info/RECORD +0 -87
  173. {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/LICENSE +0 -0
  174. {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/WHEEL +0 -0
  175. {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/top_level.txt +0 -0
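Taken together, the list above is mostly a reorganization: brainstate/transform/ is split into brainstate/compile/ (jit, conditionals, loops, jaxpr construction) and brainstate/augment/ (autograd, mapping, eval_shape), brainstate/nn/event/ is promoted to a top-level brainstate/event/ package, and the monolithic brainstate/util.py becomes the brainstate/util/ package. Below is a minimal migration sketch for the gradient API, assuming (unverified here) that the 0.1.0 subpackages re-export the same callables that the moved files defined in 0.0.2:

import brainstate as bst
import jax.numpy as jnp

def loss(x):
    return jnp.sum(x ** 2)  # simple scalar loss for illustration

# 0.0.2.post20241009: gradient transforms lived in brainstate.transform
# grad_fn = bst.transform.grad(loss, argnums=0)

# 0.1.0: _autograd.py now ships as brainstate/augment/_autograd.py (file 6
# above), so the equivalent call is presumably:
grad_fn = bst.augment.grad(loss, argnums=0)
print(grad_fn(jnp.arange(3.0)))  # d/dx sum(x^2) = 2x -> [0. 2. 4.]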
brainstate/transform/_autograd.py (file 156 above; superseded by brainstate/augment/_autograd.py, file 6)
@@ -1,585 +0,0 @@
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
-
- from __future__ import annotations
-
- import inspect
- from functools import partial, wraps
- from typing import Union, Callable, Dict, Sequence, Optional, Any, Tuple
-
- import jax
- from jax import numpy as jnp
- from jax._src.api import _vjp
- from jax.api_util import argnums_partial
- from jax.extend import linear_util
-
- from brainstate._state import State, StateTrace, StateDictManager
- from brainstate._utils import set_module_as
-
- __all__ = [
-   'vector_grad', 'grad', 'jacrev', 'jacfwd', 'jacobian', 'hessian',
- ]
-
-
- def _isgeneratorfunction(fun):
-   # re-implemented here because of https://bugs.python.org/issue33261
-   while inspect.ismethod(fun):
-     fun = fun.__func__
-   while isinstance(fun, partial):
-     fun = fun.func
-   return inspect.isfunction(fun) and bool(fun.__code__.co_flags & inspect.CO_GENERATOR)
-
-
- def _check_callable(fun):
-   # In Python 3.10+, the only thing stopping us from supporting staticmethods
-   # is that we can't take weak references to them, which the C++ JIT requires.
-   if isinstance(fun, staticmethod):
-     raise TypeError(f"staticmethod arguments are not supported, got {fun}")
-   if not callable(fun):
-     raise TypeError(f"Expected a callable value, got {fun}")
-   if _isgeneratorfunction(fun):
-     raise TypeError(f"Expected a function, got a generator function: {fun}")
-
-
- def functional_vector_grad(func, argnums=0, return_value: bool = False, has_aux: bool = False):
-   """
-   Compute the gradient of a vector with respect to the input.
-   """
-   _check_callable(func)
-
-   @wraps(func)
-   def grad_fun(*args, **kwargs):
-     f = linear_util.wrap_init(func, kwargs)
-     f_partial, dyn_args = argnums_partial(f, argnums, args, require_static_args_hashable=False)
-     if has_aux:
-       y, vjp_fn, aux = _vjp(f_partial, *dyn_args, has_aux=True)
-     else:
-       y, vjp_fn = _vjp(f_partial, *dyn_args, has_aux=False)
-     leaves, tree = jax.tree.flatten(y)
-     tangents = jax.tree.unflatten(tree, [jnp.ones(l.shape, dtype=l.dtype) for l in leaves])
-     grads = vjp_fn(tangents)
-     if isinstance(argnums, int):
-       grads = grads[0]
-     if has_aux:
-       return (grads, y, aux) if return_value else (grads, aux)
-     else:
-       return (grads, y) if return_value else grads
-
-   return grad_fun
-
-
- def _jacrev(fun, argnums=0, holomorphic=False, allow_int=False, has_aux=False, return_value=False):
-   @wraps(fun)
-   def fun_wrapped(*args, **kwargs):
-     if has_aux:
-       y, aux = fun(*args, **kwargs)
-       if return_value:
-         return y, (y, aux)
-       else:
-         return y, aux
-     else:
-       y = fun(*args, **kwargs)
-       if return_value:
-         return y, y
-       else:
-         return y, None
-
-   transform = jax.jacrev(fun_wrapped, argnums=argnums, holomorphic=holomorphic, allow_int=allow_int, has_aux=True)
-
-   @wraps(fun)
-   def jacfun(*args, **kwargs):
-     jac, aux = transform(*args, **kwargs)
-     if return_value:
-       return (jac, aux[0], aux[1]) if has_aux else (jac, aux)
-     else:
-       return (jac, aux) if has_aux else jac
-
-   return jacfun
-
-
- def _jacfwd(fun, argnums=0, holomorphic=False, has_aux=False, return_value=False):
-   @wraps(fun)
-   def fun_wrapped(*args, **kwargs):
-     if has_aux:
-       y, aux = fun(*args, **kwargs)
-       if return_value:
-         return y, (y, aux)
-       else:
-         return y, aux
-     else:
-       y = fun(*args, **kwargs)
-       if return_value:
-         return y, y
-       else:
-         return y, None
-
-   transform = jax.jacfwd(fun_wrapped, argnums=argnums, holomorphic=holomorphic, has_aux=True)
-
-   @wraps(fun)
-   def jacfun(*args, **kwargs):
-     jac, aux = transform(*args, **kwargs)
-     if return_value:
-       return (jac, aux[0], aux[1]) if has_aux else (jac, aux)
-     else:
-       return (jac, aux) if has_aux else jac
-
-   return jacfun
-
-
- class GradientTransform(object):
-   """
-   Automatic Differentiation Transformations for the ``State`` system.
-   """
-   __module__ = "brainstate.transform"
-
-   def __init__(
-       self,
-       target: Callable,
-       transform: Callable,
-       grad_vars: Any,
-       argnums: Optional[Union[int, Sequence[int]]],
-       return_value: bool,
-       has_aux: bool,
-       transform_params: Optional[Dict[str, Any]] = None,
-   ):
-     # gradient variables
-     if isinstance(grad_vars, StateDictManager):
-       grad_vars = {k: v for k, v in grad_vars.items()}
-     self._grad_vars, self._grad_tree = jax.tree.flatten(grad_vars)
-     if any(not isinstance(v, State) for v in self._grad_vars):
-       raise TypeError("All grad_vars must be State instances.")
-
-     # parameters
-     if argnums is None and len(self._grad_vars) == 0:
-       argnums = 0
-     if argnums is None:
-       assert len(self._grad_vars) > 0
-       _argnums = 0
-     elif isinstance(argnums, int):
-       _argnums = (0, argnums + 1) if len(self._grad_vars) > 0 else (argnums + 1)
-     else:
-       assert isinstance(argnums, (tuple, list))
-       _argnums = tuple(a + 1 for a in argnums)
-       if len(self._grad_vars) > 0:
-         _argnums = (0,) + _argnums
-     self._nonvar_argnums = argnums
-     self._argnums = _argnums
-     self._return_value = return_value
-     self._has_aux = has_aux
-
-     # target
-     self.target = target
-
-     # transform
-     self._states_to_be_written: Tuple[State, ...] = None
-     _grad_setting = dict() if transform_params is None else transform_params
-     if self._has_aux:
-       self._transform = transform(self._fun_with_aux, argnums=self._argnums, has_aux=True, **_grad_setting)
-     else:
-       self._transform = transform(self._fun_without_aux, argnums=self._argnums, has_aux=True, **_grad_setting)
-
-   def __repr__(self):
-     name = self.__class__.__name__
-     format_ref = (f'{name}(target={self.target}, \n' +
-                   f'{" " * len(name)} num_of_grad_vars={len(self._grad_vars)}, \n'
-                   f'{" " * len(name)} num_of_dyn_vars={len(self._states_to_be_written)})')
-     return format_ref
-
-   def __call_target(self, *args, **kwargs):
-     if self._states_to_be_written is None:
-       with StateTrace() as stack:
-         output = self.target(*args, **kwargs)
-       grad_ids = set([id(v) for v in self._grad_vars])
-       self._states_to_be_written = tuple(st for st, ty in zip(stack.states, stack.types)
-                                          if ty == 'write' and id(st) not in grad_ids)
-     else:
-       output = self.target(*args, **kwargs)
-     return output
-
-   def _fun_with_aux(self, grad_values: tuple, *args, **kwargs):
-     for v, d in zip(self._grad_vars, grad_values):
-       v._value = d
-     # Users should return the auxiliary data like::
-     # >>> # 1. example of return one data
-     # >>> return scalar_loss, data
-     # >>> # 2. example of return multiple data
-     # >>> return scalar_loss, (data1, data2, ...)
-     outs = self.__call_target(*args, **kwargs)
-     # outputs: [0] is the value for gradient,
-     #          [1] is other values for return
-     assert self._states_to_be_written is not None, "The states to be written should be collected."
-     return outs[0], (outs, [v.value for v in self._grad_vars], [v.value for v in self._states_to_be_written])
-
-   def _fun_without_aux(self, grad_values: tuple, *args, **kwargs):
-     for v, d in zip(self._grad_vars, grad_values):
-       v._value = d
-     # Users should return the scalar value like this::
-     # >>> return scalar_loss
-     out = self.__call_target(*args, **kwargs)
-     assert self._states_to_be_written is not None, "The states to be written should be collected."
-     return out, (out, [v.value for v in self._grad_vars], [v.value for v in self._states_to_be_written])
-
-   def __return(self, rets):
-     grads, (outputs, new_grad_vals, new_dyn_vals) = rets
-     for i, val in enumerate(new_grad_vals):
-       self._grad_vars[i].value = val
-     for i, val in enumerate(new_dyn_vals):
-       self._states_to_be_written[i].value = val
-
-     # check returned grads
-     if len(self._grad_vars) > 0:
-       if self._nonvar_argnums is None:
-         grads = self._grad_tree.unflatten(grads)
-       else:
-         var_grads = self._grad_tree.unflatten(grads[0])
-         arg_grads = grads[1] if isinstance(self._nonvar_argnums, int) else grads[1:]
-         grads = (var_grads, arg_grads)
-
-     # check returned value
-     if self._return_value:
-       # check aux
-       if self._has_aux:
-         return grads, outputs[0], outputs[1]
-       else:
-         return grads, outputs
-     else:
-       # check aux
-       if self._has_aux:
-         return grads, outputs[1]
-       else:
-         return grads
-
-   def __call__(self, *args, **kwargs):
-     rets = self._transform([v.value for v in self._grad_vars], *args, **kwargs)
-     return self.__return(rets)
-
-
- @set_module_as("brainstate.transform")
- def grad(
-     fun: Optional[Callable] = None,
-     grad_vars: Optional[Union[State, Sequence[State], Dict[str, State]]] = None,
-     argnums: Optional[Union[int, Sequence[int]]] = None,
-     holomorphic: Optional[bool] = False,
-     allow_int: Optional[bool] = False,
-     reduce_axes: Optional[Sequence[str]] = (),
-     has_aux: Optional[bool] = None,
-     return_value: Optional[bool] = False,
- ) -> GradientTransform | Callable[[Callable], GradientTransform]:
-   """
-   Compute the gradient of a scalar-valued function with respect to its arguments.
-
-   Args:
-     reduce_axes:
-     allow_int:
-     holomorphic:
-     grad_vars:
-     fun: the scalar-valued function to be differentiated.
-     argnums: (int or tuple of ints) optional. Specifies which positional
-       argument(s) to differentiate with respect to.
-     has_aux: (bool) optional. Indicates whether fun returns a pair where the
-       first element is considered the output of the mathematical function to be
-       differentiated and the second element is auxiliary data. Default False.
-     return_value: (bool) optional. Indicates whether to return the value of the
-       function along with the gradient. Default False.
-
-   Returns:
-     A function which computes the gradient of fun. The function takes the same
-     arguments as `fun`, but returns the gradient instead. If `has_aux` is True,
-     the function returns a pair where the first element is the gradient and the
-     second element is the auxiliary data. If `return_value` is True, the function
-     returns a pair where the first element is the gradient and the second element
-     is the value of the function.
-
-   """
-   if fun is None:
-     def transform(fun) -> GradientTransform:
-       return GradientTransform(target=fun,
-                                transform=jax.grad,
-                                grad_vars=grad_vars,
-                                argnums=argnums,
-                                return_value=return_value,
-                                has_aux=False if has_aux is None else has_aux,
-                                transform_params=dict(holomorphic=holomorphic,
-                                                      allow_int=allow_int,
-                                                      reduce_axes=reduce_axes))
-
-     return transform
-
-   return GradientTransform(target=fun,
-                            transform=jax.grad,
-                            grad_vars=grad_vars,
-                            argnums=argnums,
-                            return_value=return_value,
-                            has_aux=False if has_aux is None else has_aux,
-                            transform_params=dict(holomorphic=holomorphic,
-                                                  allow_int=allow_int,
-                                                  reduce_axes=reduce_axes))
-
-
- @set_module_as("brainstate.transform")
- def vector_grad(
-     func: Optional[Callable] = None,
-     grad_vars: Optional[Union[State, Sequence[State], Dict[str, State]]] = None,
-     argnums: Optional[Union[int, Sequence[int]]] = None,
-     return_value: bool = False,
-     has_aux: Optional[bool] = None,
- ) -> GradientTransform | Callable[[Callable], GradientTransform]:
-   """Take vector-valued gradients for function ``func``.
-
-   Same as `brainpy.math.grad <./brainpy.math.autograd.grad.html>`_,
-   `brainpy.math.jacrev <./brainpy.math.autograd.jacrev.html>`_ and
-   `brainpy.math.jacfwd <./brainpy.math.autograd.jacfwd.html>`_,
-   the returns in this function are different for different argument settings.
-
-   1. When "grad_vars" is None
-     - "has_aux=False" + "return_value=False" => ``arg_grads``.
-     - "has_aux=True" + "return_value=False" => ``(arg_grads, aux_data)``.
-     - "has_aux=False" + "return_value=True" => ``(arg_grads, loss_value)``.
-     - "has_aux=True" + "return_value=True" => ``(arg_grads, loss_value, aux_data)``.
-   2. When "grad_vars" is not None and "argnums" is None
-     - "has_aux=False" + "return_value=False" => ``var_grads``.
-     - "has_aux=True" + "return_value=False" => ``(var_grads, aux_data)``.
-     - "has_aux=False" + "return_value=True" => ``(var_grads, loss_value)``.
-     - "has_aux=True" + "return_value=True" => ``(var_grads, loss_value, aux_data)``.
-   3. When "grad_vars" is not None and "argnums" is not None
-     - "has_aux=False" + "return_value=False" => ``(var_grads, arg_grads)``.
-     - "has_aux=True" + "return_value=False" => ``((var_grads, arg_grads), aux_data)``.
-     - "has_aux=False" + "return_value=True" => ``((var_grads, arg_grads), loss_value)``.
-     - "has_aux=True" + "return_value=True" => ``((var_grads, arg_grads), loss_value, aux_data)``.
-
-
-   Parameters
-   ----------
-   func: Callable
-     Function whose gradient is to be computed.
-   grad_vars : optional, ArrayType, sequence of ArrayType, dict
-     The variables in ``func`` to take their gradients.
-   has_aux: optional, bool
-     Indicates whether ``fun`` returns a pair where the
-     first element is considered the output of the mathematical function to be
-     differentiated and the second element is auxiliary data. Default False.
-   return_value : bool
-     Whether return the loss value.
-   argnums: Optional, integer or sequence of integers. Specifies which
-     positional argument(s) to differentiate with respect to (default ``0``).
-
-   Returns
-   -------
-   func : GradientTransform
-     The vector gradient function.
-   """
-
-   if func is None:
-     def transform(fun) -> GradientTransform:
-       return GradientTransform(target=fun,
-                                transform=functional_vector_grad,
-                                grad_vars=grad_vars,
-                                argnums=argnums,
-                                return_value=return_value,
-                                has_aux=False if has_aux is None else has_aux)
-
-     return transform
-
-   else:
-     return GradientTransform(target=func,
-                              transform=functional_vector_grad,
-                              grad_vars=grad_vars,
-                              argnums=argnums,
-                              return_value=return_value,
-                              has_aux=False if has_aux is None else has_aux)
-
-
- @set_module_as("brainstate.transform")
- def jacrev(
-     func: Callable,
-     grad_vars: Optional[Union[State, Sequence[State], Dict[str, State]]] = None,
-     argnums: Optional[Union[int, Sequence[int]]] = None,
-     has_aux: Optional[bool] = None,
-     return_value: bool = False,
-     holomorphic: bool = False,
-     allow_int: bool = False,
- ) -> GradientTransform:
-   """
-   Extending automatic Jacobian (reverse-mode) of ``func`` to classes.
-
-   This function extends the JAX official ``jacrev`` to make automatic jacobian
-   computation on functions and class functions. Moreover, it supports returning
-   value ("return_value") and returning auxiliary data ("has_aux").
-
-   Same as `brainpy.math.grad <./brainpy.math.autograd.grad.html>`_, the returns are
-   different for different argument settings in ``brainpy.math.jacrev``.
-
-   1. When "grad_vars" is None
-     - "has_aux=False" + "return_value=False" => ``arg_grads``.
-     - "has_aux=True" + "return_value=False" => ``(arg_grads, aux_data)``.
-     - "has_aux=False" + "return_value=True" => ``(arg_grads, loss_value)``.
-     - "has_aux=True" + "return_value=True" => ``(arg_grads, loss_value, aux_data)``.
-   2. When "grad_vars" is not None and "argnums" is None
-     - "has_aux=False" + "return_value=False" => ``var_grads``.
-     - "has_aux=True" + "return_value=False" => ``(var_grads, aux_data)``.
-     - "has_aux=False" + "return_value=True" => ``(var_grads, loss_value)``.
-     - "has_aux=True" + "return_value=True" => ``(var_grads, loss_value, aux_data)``.
-   3. When "grad_vars" is not None and "argnums" is not None
-     - "has_aux=False" + "return_value=False" => ``(var_grads, arg_grads)``.
-     - "has_aux=True" + "return_value=False" => ``((var_grads, arg_grads), aux_data)``.
-     - "has_aux=False" + "return_value=True" => ``((var_grads, arg_grads), loss_value)``.
-     - "has_aux=True" + "return_value=True" => ``((var_grads, arg_grads), loss_value, aux_data)``.
-
-   Parameters
-   ----------
-   func: Function whose Jacobian is to be computed.
-   grad_vars : optional, ArrayType, sequence of ArrayType, dict
-     The variables in ``func`` to take their gradients.
-   has_aux: optional, bool
-     Indicates whether ``fun`` returns a pair where the
-     first element is considered the output of the mathematical function to be
-     differentiated and the second element is auxiliary data. Default False.
-   return_value : bool
-     Whether return the loss value.
-   argnums: Optional, integer or sequence of integers.
-     Specifies which
-     positional argument(s) to differentiate with respect to (default ``0``).
-   holomorphic: Optional, bool.
-     Indicates whether ``fun`` is promised to be
-     holomorphic. Default False.
-   allow_int: Optional, bool.
-     Whether to allow differentiating with
-     respect to integer valued inputs. The gradient of an integer input will
-     have a trivial vector-space dtype (float0). Default False.
-
-   Returns
-   -------
-   fun: GradientTransform
-     The transformed object.
-   """
-   return GradientTransform(target=func,
-                            transform=_jacrev,
-                            grad_vars=grad_vars,
-                            argnums=argnums,
-                            return_value=return_value,
-                            has_aux=False if has_aux is None else has_aux,
-                            transform_params=dict(holomorphic=holomorphic,
-                                                  allow_int=allow_int))
-
-
- jacobian = jacrev
-
-
- @set_module_as("brainstate.transform")
- def jacfwd(
-     func: Callable,
-     grad_vars: Optional[Union[State, Sequence[State], Dict[str, State]]] = None,
-     argnums: Optional[Union[int, Sequence[int]]] = None,
-     has_aux: Optional[bool] = None,
-     return_value: bool = False,
-     holomorphic: bool = False,
- ) -> GradientTransform:
-   """Extending automatic Jacobian (forward-mode) of ``func`` to classes.
-
-   This function extends the JAX official ``jacfwd`` to make automatic jacobian
-   computation on functions and class functions. Moreover, it supports returning
-   value ("return_value") and returning auxiliary data ("has_aux").
-
-   Same as `brainpy.math.grad <./brainpy.math.autograd.grad.html>`_, the returns are
-   different for different argument settings in ``brainpy.math.jacfwd``.
-
-   1. When "grad_vars" is None
-     - "has_aux=False" + "return_value=False" => ``arg_grads``.
-     - "has_aux=True" + "return_value=False" => ``(arg_grads, aux_data)``.
-     - "has_aux=False" + "return_value=True" => ``(arg_grads, loss_value)``.
-     - "has_aux=True" + "return_value=True" => ``(arg_grads, loss_value, aux_data)``.
-   2. When "grad_vars" is not None and "argnums" is None
-     - "has_aux=False" + "return_value=False" => ``var_grads``.
-     - "has_aux=True" + "return_value=False" => ``(var_grads, aux_data)``.
-     - "has_aux=False" + "return_value=True" => ``(var_grads, loss_value)``.
-     - "has_aux=True" + "return_value=True" => ``(var_grads, loss_value, aux_data)``.
-   3. When "grad_vars" is not None and "argnums" is not None
-     - "has_aux=False" + "return_value=False" => ``(var_grads, arg_grads)``.
-     - "has_aux=True" + "return_value=False" => ``((var_grads, arg_grads), aux_data)``.
-     - "has_aux=False" + "return_value=True" => ``((var_grads, arg_grads), loss_value)``.
-     - "has_aux=True" + "return_value=True" => ``((var_grads, arg_grads), loss_value, aux_data)``.
-
-   Parameters
-   ----------
-   func: Function whose Jacobian is to be computed.
-   grad_vars : optional, ArrayType, sequence of ArrayType, dict
-     The variables in ``func`` to take their gradients.
-   has_aux: optional, bool
-     Indicates whether ``fun`` returns a pair where the
-     first element is considered the output of the mathematical function to be
-     differentiated and the second element is auxiliary data. Default False.
-   return_value : bool
-     Whether return the loss value.
-   argnums: Optional, integer or sequence of integers. Specifies which
-     positional argument(s) to differentiate with respect to (default ``0``).
-   holomorphic: Optional, bool. Indicates whether ``fun`` is promised to be
-     holomorphic. Default False.
-
-   Returns
-   -------
-   obj: GradientTransform
-     The transformed object.
-   """
-
-   return GradientTransform(target=func,
-                            transform=_jacfwd,
-                            grad_vars=grad_vars,
-                            argnums=argnums,
-                            return_value=return_value,
-                            has_aux=False if has_aux is None else has_aux,
-                            transform_params=dict(holomorphic=holomorphic))
-
-
- @set_module_as("brainstate.transform")
- def hessian(
-     func: Callable,
-     grad_vars: Optional[Union[State, Sequence[State], Dict[str, State]]] = None,
-     argnums: Optional[Union[int, Sequence[int]]] = None,
-     return_value: bool = False,
-     holomorphic: bool = False,
- ) -> GradientTransform:
-   """Hessian of ``func`` as a dense array.
-
-   Parameters
-   ----------
-   func : callable
-     Function whose Hessian is to be computed. Its arguments at positions
-     specified by ``argnums`` should be arrays, scalars, or standard Python
-     containers thereof. It should return arrays, scalars, or standard Python
-     containers thereof.
-   grad_vars : optional, ArrayCollector, sequence of ArrayType
-     The variables required to compute their gradients.
-   argnums: Optional, integer or sequence of integers
-     Specifies which positional argument(s) to differentiate with respect to (default ``0``).
-   holomorphic : bool
-     Indicates whether ``fun`` is promised to be holomorphic. Default False.
-   return_value : bool
-     Whether return the hessian values.
-
-   Returns
-   -------
-   obj: ObjectTransform
-     The transformed object.
-   """
-   raise NotImplementedError("The hessian computation is not supported yet.")
-
-   # return jacfwd(jacrev(func,
-   #                      grad_vars=grad_vars,
-   #                      argnums=argnums,
-   #                      holomorphic=holomorphic),
-   #               grad_vars=grad_vars,
-   #               argnums=argnums,
-   #               holomorphic=holomorphic,
-   #               return_value=return_value)
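For reference, the class removed here works by prepending the current values of grad_vars as an extra leading argument before handing the wrapped function to the underlying JAX transform (hence the argnums + 1 shift in __init__), then writing the updated values back into the State objects after every call. Below is a sketch of the documented vector_grad return conventions under the 0.0.2 API; bst.State is assumed to be the top-level re-export of brainstate._state.State, and the example is illustrative rather than taken from the package:

import brainstate as bst
import jax.numpy as jnp

w = bst.State(jnp.ones(3))       # a trainable State captured as grad_vars

def f(x):
    # vector-valued output; vector_grad pulls back an all-ones cotangent
    return w.value * x ** 2

# grad_vars and argnums both given: docstring case 3. With return_value=True,
# the call returns ((var_grads, arg_grads), value).
vg = bst.transform.vector_grad(f, grad_vars=w, argnums=0, return_value=True)
(w_grads, x_grads), value = vg(jnp.arange(3.0))
# w_grads == x**2, x_grads == 2 * w.value * x, value == w.value * x**2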