brainstate 0.1.0.post20250211__py2.py3-none-any.whl → 0.1.0.post20250216__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/_state.py +875 -93
- brainstate/_state_test.py +1 -3
- brainstate/augment/__init__.py +2 -2
- brainstate/augment/_autograd.py +257 -115
- brainstate/augment/_autograd_test.py +2 -3
- brainstate/augment/_eval_shape.py +3 -4
- brainstate/augment/_mapping.py +582 -62
- brainstate/augment/_mapping_test.py +114 -30
- brainstate/augment/_random.py +61 -7
- brainstate/compile/_ad_checkpoint.py +2 -3
- brainstate/compile/_conditions.py +4 -5
- brainstate/compile/_conditions_test.py +1 -2
- brainstate/compile/_error_if.py +1 -2
- brainstate/compile/_error_if_test.py +1 -2
- brainstate/compile/_jit.py +23 -16
- brainstate/compile/_jit_test.py +1 -2
- brainstate/compile/_loop_collect_return.py +18 -10
- brainstate/compile/_loop_collect_return_test.py +1 -1
- brainstate/compile/_loop_no_collection.py +5 -5
- brainstate/compile/_make_jaxpr.py +23 -21
- brainstate/compile/_make_jaxpr_test.py +1 -2
- brainstate/compile/_progress_bar.py +1 -2
- brainstate/compile/_unvmap.py +1 -0
- brainstate/compile/_util.py +4 -2
- brainstate/environ.py +4 -4
- brainstate/environ_test.py +1 -2
- brainstate/functional/_activations.py +1 -2
- brainstate/functional/_activations_test.py +1 -1
- brainstate/functional/_normalization.py +1 -2
- brainstate/functional/_others.py +1 -2
- brainstate/functional/_spikes.py +136 -20
- brainstate/graph/_graph_node.py +2 -43
- brainstate/graph/_graph_operation.py +4 -20
- brainstate/graph/_graph_operation_test.py +3 -4
- brainstate/init/_base.py +1 -2
- brainstate/init/_generic.py +1 -2
- brainstate/nn/__init__.py +4 -0
- brainstate/nn/_collective_ops.py +351 -48
- brainstate/nn/_collective_ops_test.py +36 -0
- brainstate/nn/_common.py +194 -0
- brainstate/nn/_dyn_impl/_dynamics_neuron.py +1 -2
- brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +1 -2
- brainstate/nn/_dyn_impl/_dynamics_synapse.py +1 -2
- brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +1 -2
- brainstate/nn/_dyn_impl/_inputs.py +1 -2
- brainstate/nn/_dyn_impl/_rate_rnns.py +1 -2
- brainstate/nn/_dyn_impl/_rate_rnns_test.py +1 -2
- brainstate/nn/_dyn_impl/_readout.py +2 -3
- brainstate/nn/_dyn_impl/_readout_test.py +1 -2
- brainstate/nn/_dynamics/_dynamics_base.py +2 -3
- brainstate/nn/_dynamics/_dynamics_base_test.py +1 -2
- brainstate/nn/_dynamics/_state_delay.py +3 -3
- brainstate/nn/_dynamics/_synouts_test.py +1 -2
- brainstate/nn/_elementwise/_dropout.py +6 -7
- brainstate/nn/_elementwise/_dropout_test.py +1 -2
- brainstate/nn/_elementwise/_elementwise.py +1 -2
- brainstate/nn/_exp_euler.py +1 -2
- brainstate/nn/_exp_euler_test.py +1 -2
- brainstate/nn/_interaction/_conv.py +1 -2
- brainstate/nn/_interaction/_conv_test.py +1 -0
- brainstate/nn/_interaction/_linear.py +1 -2
- brainstate/nn/_interaction/_linear_test.py +1 -2
- brainstate/nn/_interaction/_normalizations.py +1 -2
- brainstate/nn/_interaction/_poolings.py +3 -4
- brainstate/nn/_module.py +63 -19
- brainstate/nn/_module_test.py +1 -2
- brainstate/nn/metrics.py +3 -4
- brainstate/optim/_lr_scheduler.py +1 -2
- brainstate/optim/_lr_scheduler_test.py +2 -3
- brainstate/optim/_optax_optimizer_test.py +1 -2
- brainstate/optim/_sgd_optimizer.py +2 -3
- brainstate/random/_rand_funs.py +1 -2
- brainstate/random/_rand_funs_test.py +2 -3
- brainstate/random/_rand_seed.py +2 -3
- brainstate/random/_rand_seed_test.py +1 -2
- brainstate/random/_rand_state.py +3 -4
- brainstate/surrogate.py +183 -35
- brainstate/transform.py +0 -3
- brainstate/typing.py +28 -25
- brainstate/util/__init__.py +9 -7
- brainstate/util/_caller.py +1 -2
- brainstate/util/_error.py +27 -0
- brainstate/util/_others.py +60 -15
- brainstate/util/{_dict.py → _pretty_pytree.py} +108 -29
- brainstate/util/{_dict_test.py → _pretty_pytree_test.py} +1 -2
- brainstate/util/_pretty_repr.py +128 -10
- brainstate/util/_pretty_table.py +2900 -0
- brainstate/util/_struct.py +11 -11
- brainstate/util/filter.py +472 -0
- {brainstate-0.1.0.post20250211.dist-info → brainstate-0.1.0.post20250216.dist-info}/METADATA +2 -2
- brainstate-0.1.0.post20250216.dist-info/RECORD +127 -0
- brainstate/util/_filter.py +0 -178
- brainstate-0.1.0.post20250211.dist-info/RECORD +0 -124
- {brainstate-0.1.0.post20250211.dist-info → brainstate-0.1.0.post20250216.dist-info}/LICENSE +0 -0
- {brainstate-0.1.0.post20250211.dist-info → brainstate-0.1.0.post20250216.dist-info}/WHEEL +0 -0
- {brainstate-0.1.0.post20250211.dist-info → brainstate-0.1.0.post20250216.dist-info}/top_level.txt +0 -0
brainstate/surrogate.py
CHANGED
@@ -21,6 +21,8 @@ import jax.numpy as jnp
|
|
21
21
|
import jax.scipy as sci
|
22
22
|
from jax.interpreters import batching, ad, mlir
|
23
23
|
|
24
|
+
from brainstate.util._pretty_pytree import PrettyObject
|
25
|
+
|
24
26
|
if jax.__version_info__ < (0, 4, 38):
|
25
27
|
from jax.core import Primitive
|
26
28
|
else:
|
@@ -77,7 +79,10 @@ def _heaviside_imp(x, dx):
|
|
77
79
|
|
78
80
|
|
79
81
|
def _heaviside_batching(args, axes):
|
80
|
-
|
82
|
+
x, dx = args
|
83
|
+
if axes[0] != axes[1]:
|
84
|
+
dx = batching.moveaxis(dx, axes[1], axes[0])
|
85
|
+
return heaviside_p.bind(x, dx), tuple([axes[0]])
|
81
86
|
|
82
87
|
|
83
88
|
def _heaviside_jvp(primals, tangents):
|
@@ -97,7 +102,7 @@ ad.primitive_jvps[heaviside_p] = _heaviside_jvp
|
|
97
102
|
mlir.register_lowering(heaviside_p, mlir.lower_fun(_heaviside_imp, multiple_results=True))
|
98
103
|
|
99
104
|
|
100
|
-
class Surrogate(
|
105
|
+
class Surrogate(PrettyObject):
|
101
106
|
"""The base surrograte gradient function.
|
102
107
|
|
103
108
|
To customize a surrogate gradient function, you can inherit this class and
|
@@ -142,9 +147,20 @@ class Surrogate(object):
|
|
142
147
|
class Sigmoid(Surrogate):
|
143
148
|
"""Spike function with the sigmoid-shaped surrogate gradient.
|
144
149
|
|
150
|
+
This class implements a spiking neuron activation with a sigmoid-shaped
|
151
|
+
surrogate gradient for backpropagation. It can be used in spiking neural
|
152
|
+
networks to approximate the non-differentiable step function during training.
|
153
|
+
|
154
|
+
Parameters
|
155
|
+
----------
|
156
|
+
alpha : float, optional
|
157
|
+
A parameter controlling the steepness of the sigmoid curve in the
|
158
|
+
surrogate gradient. Higher values make the transition sharper.
|
159
|
+
Default is 4.0.
|
160
|
+
|
145
161
|
See Also
|
146
162
|
--------
|
147
|
-
sigmoid
|
163
|
+
sigmoid : Function version of this class.
|
148
164
|
|
149
165
|
"""
|
150
166
|
|
@@ -153,9 +169,33 @@ class Sigmoid(Surrogate):
|
|
153
169
|
self.alpha = alpha
|
154
170
|
|
155
171
|
def surrogate_fun(self, x):
|
172
|
+
"""Compute the surrogate function.
|
173
|
+
|
174
|
+
Parameters
|
175
|
+
----------
|
176
|
+
x : jax.Array
|
177
|
+
The input array.
|
178
|
+
|
179
|
+
Returns
|
180
|
+
-------
|
181
|
+
jax.Array
|
182
|
+
The output of the surrogate function.
|
183
|
+
"""
|
156
184
|
return sci.special.expit(self.alpha * x)
|
157
185
|
|
158
186
|
def surrogate_grad(self, x):
|
187
|
+
"""Compute the gradient of the surrogate function.
|
188
|
+
|
189
|
+
Parameters
|
190
|
+
----------
|
191
|
+
x : jax.Array
|
192
|
+
The input array.
|
193
|
+
|
194
|
+
Returns
|
195
|
+
-------
|
196
|
+
jax.Array
|
197
|
+
The gradient of the surrogate function.
|
198
|
+
"""
|
159
199
|
sgax = sci.special.expit(x * self.alpha)
|
160
200
|
dx = (1. - sgax) * sgax * self.alpha
|
161
201
|
return dx
|
@@ -171,7 +211,12 @@ def sigmoid(
|
|
171
211
|
x: jax.Array,
|
172
212
|
alpha: float = 4.,
|
173
213
|
):
|
174
|
-
r"""
|
214
|
+
r"""
|
215
|
+
Compute a spike function with a sigmoid-shaped surrogate gradient.
|
216
|
+
|
217
|
+
This function implements a spiking neuron activation with a sigmoid-shaped
|
218
|
+
surrogate gradient for backpropagation. It can be used in spiking neural
|
219
|
+
networks to approximate the non-differentiable step function during training.
|
175
220
|
|
176
221
|
If `origin=False`, return the forward function:
|
177
222
|
|
@@ -210,16 +255,28 @@ def sigmoid(
|
|
210
255
|
|
211
256
|
Parameters
|
212
257
|
----------
|
213
|
-
x: jax.Array
|
214
|
-
|
215
|
-
alpha: float
|
216
|
-
|
217
|
-
|
258
|
+
x : jax.Array
|
259
|
+
The input array representing the neuron's membrane potential.
|
260
|
+
alpha : float, optional
|
261
|
+
A parameter controlling the steepness of the sigmoid curve in the
|
262
|
+
surrogate gradient. Higher values make the transition sharper.
|
263
|
+
Default is 4.0.
|
218
264
|
|
219
265
|
Returns
|
220
266
|
-------
|
221
|
-
|
222
|
-
|
267
|
+
jax.Array
|
268
|
+
An array of the same shape as the input, containing binary values (0 or 1)
|
269
|
+
representing the spiking state of each neuron.
|
270
|
+
|
271
|
+
Notes
|
272
|
+
-----
|
273
|
+
The forward pass uses a step function (1 for x >= 0, 0 for x < 0),
|
274
|
+
while the backward pass uses a sigmoid-shaped surrogate gradient for
|
275
|
+
smooth optimization.
|
276
|
+
|
277
|
+
The surrogate gradient is defined as:
|
278
|
+
g'(x) = alpha * (1 - sigmoid(alpha * x)) * sigmoid(alpha * x)
|
279
|
+
|
223
280
|
"""
|
224
281
|
return Sigmoid(alpha=alpha)(x)
|
225
282
|
|
@@ -238,11 +295,15 @@ class PiecewiseQuadratic(Surrogate):
|
|
238
295
|
self.alpha = alpha
|
239
296
|
|
240
297
|
def surrogate_fun(self, x):
|
241
|
-
z = jnp.where(
|
242
|
-
|
243
|
-
|
244
|
-
|
245
|
-
|
298
|
+
z = jnp.where(
|
299
|
+
x < -1 / self.alpha,
|
300
|
+
0.,
|
301
|
+
jnp.where(
|
302
|
+
x > 1 / self.alpha,
|
303
|
+
1.,
|
304
|
+
(-self.alpha * jnp.abs(x) / 2 + 1) * self.alpha * x + 0.5
|
305
|
+
)
|
306
|
+
)
|
246
307
|
return z
|
247
308
|
|
248
309
|
def surrogate_grad(self, x):
|
@@ -260,7 +321,12 @@ def piecewise_quadratic(
|
|
260
321
|
x: jax.Array,
|
261
322
|
alpha: float = 1.,
|
262
323
|
):
|
263
|
-
r"""
|
324
|
+
r"""
|
325
|
+
Judge spiking state with a piecewise quadratic function [1]_ [2]_ [3]_ [4]_ [5]_.
|
326
|
+
|
327
|
+
This function implements a surrogate gradient method for spiking neural networks
|
328
|
+
using a piecewise quadratic function. It provides a differentiable approximation
|
329
|
+
of the step function used in the forward pass of spiking neurons.
|
264
330
|
|
265
331
|
If `origin=False`, computes the forward function:
|
266
332
|
|
@@ -306,18 +372,29 @@ def piecewise_quadratic(
|
|
306
372
|
>>> plt.legend()
|
307
373
|
>>> plt.show()
|
308
374
|
|
309
|
-
|
375
|
+
Parameters
|
310
376
|
----------
|
311
|
-
x: jax.Array
|
312
|
-
|
313
|
-
alpha: float
|
314
|
-
|
315
|
-
|
377
|
+
x : jax.Array
|
378
|
+
The input array representing the neuron's membrane potential.
|
379
|
+
alpha : float, optional
|
380
|
+
A parameter controlling the steepness of the surrogate gradient.
|
381
|
+
Higher values result in a steeper gradient. Default is 1.0.
|
316
382
|
|
317
383
|
Returns
|
318
384
|
-------
|
319
|
-
|
320
|
-
|
385
|
+
jax.Array
|
386
|
+
An array of the same shape as the input, containing binary values (0 or 1)
|
387
|
+
representing the spiking state of each neuron.
|
388
|
+
|
389
|
+
Notes
|
390
|
+
-----
|
391
|
+
The function uses different computations for forward and backward passes:
|
392
|
+
- Forward: Step function (1 for x >= 0, 0 for x < 0)
|
393
|
+
- Backward: Piecewise quadratic function for smooth gradient
|
394
|
+
|
395
|
+
The surrogate gradient is defined as:
|
396
|
+
g'(x) = 0 if |x| > 1/alpha
|
397
|
+
-alpha^2|x| + alpha if |x| <= 1/alpha
|
321
398
|
|
322
399
|
References
|
323
400
|
----------
|
@@ -331,11 +408,22 @@ def piecewise_quadratic(
|
|
331
408
|
|
332
409
|
|
333
410
|
class PiecewiseExp(Surrogate):
|
334
|
-
"""
|
411
|
+
"""
|
412
|
+
Judge spiking state with a piecewise exponential function.
|
413
|
+
|
414
|
+
This class implements a surrogate gradient method for spiking neural networks
|
415
|
+
using a piecewise exponential function. It provides a differentiable approximation
|
416
|
+
of the step function used in the forward pass of spiking neurons.
|
417
|
+
|
418
|
+
Parameters
|
419
|
+
----------
|
420
|
+
alpha : float, optional
|
421
|
+
A parameter controlling the steepness of the surrogate gradient.
|
422
|
+
Higher values result in a steeper gradient. Default is 1.0.
|
335
423
|
|
336
424
|
See Also
|
337
425
|
--------
|
338
|
-
piecewise_exp
|
426
|
+
piecewise_exp : Function version of this class.
|
339
427
|
"""
|
340
428
|
|
341
429
|
def __init__(self, alpha: float = 1.):
|
@@ -343,16 +431,62 @@ class PiecewiseExp(Surrogate):
|
|
343
431
|
self.alpha = alpha
|
344
432
|
|
345
433
|
def surrogate_grad(self, x):
|
434
|
+
"""
|
435
|
+
Compute the surrogate gradient.
|
436
|
+
|
437
|
+
Parameters
|
438
|
+
----------
|
439
|
+
x : jax.Array
|
440
|
+
The input array.
|
441
|
+
|
442
|
+
Returns
|
443
|
+
-------
|
444
|
+
jax.Array
|
445
|
+
The surrogate gradient.
|
446
|
+
"""
|
346
447
|
dx = (self.alpha / 2) * jnp.exp(-self.alpha * jnp.abs(x))
|
347
448
|
return dx
|
348
449
|
|
349
450
|
def surrogate_fun(self, x):
|
350
|
-
|
451
|
+
"""
|
452
|
+
Compute the surrogate function.
|
453
|
+
|
454
|
+
Parameters
|
455
|
+
----------
|
456
|
+
x : jax.Array
|
457
|
+
The input array.
|
458
|
+
|
459
|
+
Returns
|
460
|
+
-------
|
461
|
+
jax.Array
|
462
|
+
The output of the surrogate function.
|
463
|
+
"""
|
464
|
+
return jnp.where(
|
465
|
+
x < 0,
|
466
|
+
jnp.exp(self.alpha * x) / 2,
|
467
|
+
1 - jnp.exp(-self.alpha * x) / 2
|
468
|
+
)
|
351
469
|
|
352
470
|
def __repr__(self):
|
471
|
+
"""
|
472
|
+
Return a string representation of the PiecewiseExp instance.
|
473
|
+
|
474
|
+
Returns
|
475
|
+
-------
|
476
|
+
str
|
477
|
+
A string representation of the instance.
|
478
|
+
"""
|
353
479
|
return f'{self.__class__.__name__}(alpha={self.alpha})'
|
354
480
|
|
355
481
|
def __hash__(self):
|
482
|
+
"""
|
483
|
+
Compute a hash value for the PiecewiseExp instance.
|
484
|
+
|
485
|
+
Returns
|
486
|
+
-------
|
487
|
+
int
|
488
|
+
A hash value for the instance.
|
489
|
+
"""
|
356
490
|
return hash((self.__class__, self.alpha))
|
357
491
|
|
358
492
|
|
@@ -363,6 +497,10 @@ def piecewise_exp(
|
|
363
497
|
):
|
364
498
|
r"""Judge spiking state with a piecewise exponential function [1]_.
|
365
499
|
|
500
|
+
This function implements a surrogate gradient method for spiking neural networks
|
501
|
+
using a piecewise exponential function. It provides a differentiable approximation
|
502
|
+
of the step function used in the forward pass of spiking neurons.
|
503
|
+
|
366
504
|
If `origin=False`, computes the forward function:
|
367
505
|
|
368
506
|
.. math::
|
@@ -403,16 +541,26 @@ def piecewise_exp(
|
|
403
541
|
|
404
542
|
Parameters
|
405
543
|
----------
|
406
|
-
x: jax.Array
|
407
|
-
|
408
|
-
alpha: float
|
409
|
-
|
410
|
-
|
544
|
+
x : jax.Array
|
545
|
+
The input array representing the neuron's membrane potential.
|
546
|
+
alpha : float, optional
|
547
|
+
A parameter controlling the steepness of the surrogate gradient.
|
548
|
+
Higher values result in a steeper gradient. Default is 1.0.
|
411
549
|
|
412
550
|
Returns
|
413
551
|
-------
|
414
|
-
|
415
|
-
|
552
|
+
jax.Array
|
553
|
+
An array of the same shape as the input, containing binary values (0 or 1)
|
554
|
+
representing the spiking state of each neuron.
|
555
|
+
|
556
|
+
Notes
|
557
|
+
-----
|
558
|
+
The function uses different computations for forward and backward passes:
|
559
|
+
- Forward: Step function (1 for x >= 0, 0 for x < 0)
|
560
|
+
- Backward: Piecewise exponential function for smooth gradient
|
561
|
+
|
562
|
+
The surrogate gradient is defined as:
|
563
|
+
g'(x) = (alpha / 2) * exp(-alpha * |x|)
|
416
564
|
|
417
565
|
References
|
418
566
|
----------
|
brainstate/transform.py
CHANGED
brainstate/typing.py
CHANGED
@@ -16,13 +16,17 @@
|
|
16
16
|
from __future__ import annotations
|
17
17
|
|
18
18
|
import builtins
|
19
|
+
|
20
|
+
import brainunit as u
|
19
21
|
import functools as ft
|
20
22
|
import importlib
|
21
23
|
import inspect
|
22
|
-
|
23
|
-
import brainunit as u
|
24
24
|
import jax
|
25
25
|
import numpy as np
|
26
|
+
from typing import (
|
27
|
+
Any, Callable, Hashable, List, Protocol, Tuple, TypeVar, Union,
|
28
|
+
runtime_checkable, TYPE_CHECKING, Generic, Sequence
|
29
|
+
)
|
26
30
|
|
27
31
|
tp = importlib.import_module("typing")
|
28
32
|
|
@@ -41,35 +45,35 @@ __all__ = [
|
|
41
45
|
'Missing',
|
42
46
|
]
|
43
47
|
|
44
|
-
K =
|
48
|
+
K = TypeVar('K')
|
45
49
|
|
46
50
|
|
47
|
-
@
|
48
|
-
class Key(
|
51
|
+
@runtime_checkable
|
52
|
+
class Key(Hashable, Protocol):
|
49
53
|
def __lt__(self: K, value: K, /) -> bool:
|
50
54
|
...
|
51
55
|
|
52
56
|
|
53
|
-
Ellipsis = builtins.ellipsis if
|
57
|
+
Ellipsis = builtins.ellipsis if TYPE_CHECKING else Any
|
54
58
|
|
55
|
-
PathParts =
|
56
|
-
Predicate =
|
57
|
-
FilterLiteral =
|
58
|
-
Filter =
|
59
|
+
PathParts = Tuple[Key, ...]
|
60
|
+
Predicate = Callable[[PathParts, Any], bool]
|
61
|
+
FilterLiteral = Union[type, str, Predicate, bool, Ellipsis, None]
|
62
|
+
Filter = Union[FilterLiteral, Tuple['Filter', ...], List['Filter']]
|
59
63
|
|
60
|
-
_T =
|
64
|
+
_T = TypeVar("_T")
|
61
65
|
|
62
|
-
_Annotation =
|
66
|
+
_Annotation = TypeVar("_Annotation")
|
63
67
|
|
64
68
|
|
65
|
-
class _Array(
|
69
|
+
class _Array(Generic[_Annotation]):
|
66
70
|
pass
|
67
71
|
|
68
72
|
|
69
73
|
_Array.__module__ = "builtins"
|
70
74
|
|
71
75
|
|
72
|
-
def _item_to_str(item:
|
76
|
+
def _item_to_str(item: Union[str, type, slice]) -> str:
|
73
77
|
if isinstance(item, slice):
|
74
78
|
if item.step is not None:
|
75
79
|
raise NotImplementedError
|
@@ -83,7 +87,7 @@ def _item_to_str(item: tp.Union[str, type, slice]) -> str:
|
|
83
87
|
|
84
88
|
|
85
89
|
def _maybe_tuple_to_str(
|
86
|
-
item:
|
90
|
+
item: Union[str, type, slice, Tuple[Union[str, type, slice], ...]]
|
87
91
|
) -> str:
|
88
92
|
if isinstance(item, tuple):
|
89
93
|
if len(item) == 0:
|
@@ -113,7 +117,7 @@ class Array:
|
|
113
117
|
Array.__module__ = "builtins"
|
114
118
|
|
115
119
|
|
116
|
-
class _FakePyTree(
|
120
|
+
class _FakePyTree(Generic[_T]):
|
117
121
|
pass
|
118
122
|
|
119
123
|
|
@@ -255,11 +259,10 @@ f. A structure can end with a `...`, to denote that the PyTree must be a prefix
|
|
255
259
|
cases, all named pieces must already have been seen and their structures bound.
|
256
260
|
""" # noqa: E501
|
257
261
|
|
258
|
-
Size =
|
259
|
-
Axes =
|
260
|
-
SeedOrKey =
|
261
|
-
Shape =
|
262
|
-
|
262
|
+
Size = Union[int, Sequence[int]]
|
263
|
+
Axes = Union[int, Sequence[int]]
|
264
|
+
SeedOrKey = Union[int, jax.Array, np.ndarray]
|
265
|
+
Shape = Sequence[int]
|
263
266
|
|
264
267
|
# --- Array --- #
|
265
268
|
|
@@ -267,7 +270,7 @@ Shape = tp.Sequence[int]
|
|
267
270
|
# standard JAX array (i.e. not including future non-standard array types like
|
268
271
|
# KeyArray and BInt). It's different than np.typing.ArrayLike in that it doesn't
|
269
272
|
# accept arbitrary sequences, nor does it accept string data.
|
270
|
-
ArrayLike =
|
273
|
+
ArrayLike = Union[
|
271
274
|
jax.Array, # JAX array type
|
272
275
|
np.ndarray, # NumPy array type
|
273
276
|
np.bool_, np.number, # NumPy scalar types
|
@@ -281,7 +284,7 @@ ArrayLike = tp.Union[
|
|
281
284
|
DType = np.dtype
|
282
285
|
|
283
286
|
|
284
|
-
class SupportsDType(
|
287
|
+
class SupportsDType(Protocol):
|
285
288
|
@property
|
286
289
|
def dtype(self) -> DType: ...
|
287
290
|
|
@@ -291,9 +294,9 @@ class SupportsDType(tp.Protocol):
|
|
291
294
|
# because JAX doesn't support objects or structured dtypes.
|
292
295
|
# Unlike np.typing.DTypeLike, we exclude None, and instead require
|
293
296
|
# explicit annotations when None is acceptable.
|
294
|
-
DTypeLike =
|
297
|
+
DTypeLike = Union[
|
295
298
|
str, # like 'float32', 'int32'
|
296
|
-
type[
|
299
|
+
type[Any], # like np.float32, np.int32, float, int
|
297
300
|
np.dtype, # like np.dtype('float32'), np.dtype('int32')
|
298
301
|
SupportsDType, # like jnp.float32, jnp.int32
|
299
302
|
]
|
brainstate/util/__init__.py
CHANGED
@@ -13,36 +13,38 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
|
-
from .
|
17
|
-
from ._dict import __all__ as _mapping_all
|
16
|
+
from . import filter
|
18
17
|
from ._error import *
|
19
18
|
from ._error import __all__ as _error_all
|
20
|
-
from ._filter import *
|
21
|
-
from ._filter import __all__ as _filter_all
|
22
19
|
from ._others import *
|
23
20
|
from ._others import __all__ as _others_all
|
21
|
+
from ._pretty_pytree import *
|
22
|
+
from ._pretty_pytree import __all__ as _mapping_all
|
24
23
|
from ._pretty_repr import *
|
25
24
|
from ._pretty_repr import __all__ as _pretty_repr_all
|
25
|
+
from ._pretty_table import *
|
26
|
+
from ._pretty_table import __all__ as _table_all
|
26
27
|
from ._scaling import *
|
27
28
|
from ._scaling import __all__ as _mem_scale_all
|
28
29
|
from ._struct import *
|
29
30
|
from ._struct import __all__ as _struct_all
|
30
31
|
|
31
32
|
__all__ = (
|
32
|
-
|
33
|
+
['filter']
|
34
|
+
+ _others_all
|
33
35
|
+ _mem_scale_all
|
34
|
-
+ _filter_all
|
35
36
|
+ _pretty_repr_all
|
36
37
|
+ _struct_all
|
37
38
|
+ _error_all
|
38
39
|
+ _mapping_all
|
40
|
+
+ _table_all
|
39
41
|
)
|
40
42
|
del (
|
41
43
|
_others_all,
|
42
44
|
_mem_scale_all,
|
43
|
-
_filter_all,
|
44
45
|
_pretty_repr_all,
|
45
46
|
_struct_all,
|
46
47
|
_error_all,
|
47
48
|
_mapping_all,
|
49
|
+
_table_all,
|
48
50
|
)
|
brainstate/util/_caller.py
CHANGED
brainstate/util/_error.py
CHANGED
@@ -21,8 +21,35 @@ __all__ = [
|
|
21
21
|
|
22
22
|
|
23
23
|
class BrainStateError(Exception):
|
24
|
+
"""
|
25
|
+
A custom exception class for BrainState-related errors.
|
26
|
+
|
27
|
+
This exception is raised when a BrainState-specific error occurs during
|
28
|
+
the execution of the program. It serves as a base class for more specific
|
29
|
+
BrainState exceptions.
|
30
|
+
|
31
|
+
Attributes:
|
32
|
+
Inherits all attributes from the built-in Exception class.
|
33
|
+
|
34
|
+
Usage::
|
35
|
+
|
36
|
+
raise BrainStateError("A BrainState-specific error occurred.")
|
37
|
+
"""
|
24
38
|
pass
|
25
39
|
|
26
40
|
|
27
41
|
class TraceContextError(BrainStateError):
|
42
|
+
"""
|
43
|
+
A custom exception class for trace context-related errors in BrainState.
|
44
|
+
|
45
|
+
This exception is raised when an error occurs specifically related to
|
46
|
+
trace context operations or manipulations within the BrainState framework.
|
47
|
+
|
48
|
+
Attributes:
|
49
|
+
Inherits all attributes from the BrainStateError class.
|
50
|
+
|
51
|
+
Usage::
|
52
|
+
|
53
|
+
raise TraceContextError("An error occurred while handling trace context.")
|
54
|
+
"""
|
28
55
|
pass
|
brainstate/util/_others.py
CHANGED
@@ -15,20 +15,21 @@
|
|
15
15
|
|
16
16
|
from __future__ import annotations
|
17
17
|
|
18
|
+
import gc
|
19
|
+
|
18
20
|
import copy
|
19
21
|
import functools
|
20
|
-
import
|
22
|
+
import jax
|
21
23
|
import threading
|
22
24
|
import types
|
23
25
|
from collections.abc import Iterable
|
24
|
-
from typing import Any, Callable, Tuple, Union, Dict
|
25
|
-
|
26
|
-
import jax
|
27
26
|
from jax.lib import xla_bridge
|
27
|
+
from typing import Any, Callable, Tuple, Union, Dict
|
28
28
|
|
29
29
|
from brainstate._utils import set_module_as
|
30
30
|
|
31
31
|
__all__ = [
|
32
|
+
'split_total',
|
32
33
|
'clear_buffer_memory',
|
33
34
|
'not_instance_eval',
|
34
35
|
'is_instance_eval',
|
@@ -37,6 +38,61 @@ __all__ = [
|
|
37
38
|
]
|
38
39
|
|
39
40
|
|
41
|
+
def split_total(
|
42
|
+
total: int,
|
43
|
+
fraction: Union[int, float],
|
44
|
+
) -> int:
|
45
|
+
"""
|
46
|
+
Calculate the number of epochs for simulation based on a total and a fraction.
|
47
|
+
|
48
|
+
This function determines the number of epochs to simulate given a total number
|
49
|
+
of epochs and either a fraction or a specific number of epochs to run.
|
50
|
+
|
51
|
+
Parameters:
|
52
|
+
-----------
|
53
|
+
total : int
|
54
|
+
The total number of epochs. Must be a positive integer.
|
55
|
+
fraction : Union[int, float]
|
56
|
+
If ``float``: A value between 0 and 1 representing the fraction of total epochs to run.
|
57
|
+
If ``int``: The specific number of epochs to run, must not exceed the total.
|
58
|
+
|
59
|
+
Returns:
|
60
|
+
--------
|
61
|
+
int
|
62
|
+
The calculated number of epochs to simulate.
|
63
|
+
|
64
|
+
Raises:
|
65
|
+
-------
|
66
|
+
ValueError
|
67
|
+
If total is not positive, fraction is negative, or if fraction as float is > 1
|
68
|
+
or as int is > total.
|
69
|
+
AssertionError
|
70
|
+
If total is not an integer.
|
71
|
+
"""
|
72
|
+
assert isinstance(total, int), "Length must be an integer."
|
73
|
+
if total <= 0:
|
74
|
+
raise ValueError("'total' must be a positive integer.")
|
75
|
+
if fraction < 0:
|
76
|
+
raise ValueError("'fraction' value cannot be negative.")
|
77
|
+
|
78
|
+
if isinstance(fraction, float):
|
79
|
+
if fraction < 0:
|
80
|
+
raise ValueError("'fraction' value cannot be negative.")
|
81
|
+
if fraction > 1:
|
82
|
+
raise ValueError("'fraction' value cannot be greater than 1.")
|
83
|
+
return int(total * fraction)
|
84
|
+
|
85
|
+
elif isinstance(fraction, int):
|
86
|
+
if fraction < 0:
|
87
|
+
raise ValueError("'fraction' value cannot be negative.")
|
88
|
+
if fraction > total:
|
89
|
+
raise ValueError("'fraction' value cannot be greater than total.")
|
90
|
+
return fraction
|
91
|
+
|
92
|
+
else:
|
93
|
+
raise ValueError("'fraction' must be an integer or float.")
|
94
|
+
|
95
|
+
|
40
96
|
class NameContext(threading.local):
|
41
97
|
def __init__(self):
|
42
98
|
self.typed_names: Dict[str, int] = {}
|
@@ -249,17 +305,6 @@ class DictManager(dict):
|
|
249
305
|
else:
|
250
306
|
raise ValueError(f'Unsupported method: {by}')
|
251
307
|
|
252
|
-
def union_by_value_ids(self, other: dict):
|
253
|
-
"""
|
254
|
-
Union the stack by the value ids.
|
255
|
-
|
256
|
-
Args:
|
257
|
-
other:
|
258
|
-
|
259
|
-
Returns:
|
260
|
-
|
261
|
-
"""
|
262
|
-
|
263
308
|
def __add__(self, other: dict):
|
264
309
|
"""
|
265
310
|
Compose other instance of dict.
|