brainstate 0.1.0.post20250211__py2.py3-none-any.whl → 0.1.0.post20250216__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (96)
  1. brainstate/_state.py +875 -93
  2. brainstate/_state_test.py +1 -3
  3. brainstate/augment/__init__.py +2 -2
  4. brainstate/augment/_autograd.py +257 -115
  5. brainstate/augment/_autograd_test.py +2 -3
  6. brainstate/augment/_eval_shape.py +3 -4
  7. brainstate/augment/_mapping.py +582 -62
  8. brainstate/augment/_mapping_test.py +114 -30
  9. brainstate/augment/_random.py +61 -7
  10. brainstate/compile/_ad_checkpoint.py +2 -3
  11. brainstate/compile/_conditions.py +4 -5
  12. brainstate/compile/_conditions_test.py +1 -2
  13. brainstate/compile/_error_if.py +1 -2
  14. brainstate/compile/_error_if_test.py +1 -2
  15. brainstate/compile/_jit.py +23 -16
  16. brainstate/compile/_jit_test.py +1 -2
  17. brainstate/compile/_loop_collect_return.py +18 -10
  18. brainstate/compile/_loop_collect_return_test.py +1 -1
  19. brainstate/compile/_loop_no_collection.py +5 -5
  20. brainstate/compile/_make_jaxpr.py +23 -21
  21. brainstate/compile/_make_jaxpr_test.py +1 -2
  22. brainstate/compile/_progress_bar.py +1 -2
  23. brainstate/compile/_unvmap.py +1 -0
  24. brainstate/compile/_util.py +4 -2
  25. brainstate/environ.py +4 -4
  26. brainstate/environ_test.py +1 -2
  27. brainstate/functional/_activations.py +1 -2
  28. brainstate/functional/_activations_test.py +1 -1
  29. brainstate/functional/_normalization.py +1 -2
  30. brainstate/functional/_others.py +1 -2
  31. brainstate/functional/_spikes.py +136 -20
  32. brainstate/graph/_graph_node.py +2 -43
  33. brainstate/graph/_graph_operation.py +4 -20
  34. brainstate/graph/_graph_operation_test.py +3 -4
  35. brainstate/init/_base.py +1 -2
  36. brainstate/init/_generic.py +1 -2
  37. brainstate/nn/__init__.py +4 -0
  38. brainstate/nn/_collective_ops.py +351 -48
  39. brainstate/nn/_collective_ops_test.py +36 -0
  40. brainstate/nn/_common.py +194 -0
  41. brainstate/nn/_dyn_impl/_dynamics_neuron.py +1 -2
  42. brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +1 -2
  43. brainstate/nn/_dyn_impl/_dynamics_synapse.py +1 -2
  44. brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +1 -2
  45. brainstate/nn/_dyn_impl/_inputs.py +1 -2
  46. brainstate/nn/_dyn_impl/_rate_rnns.py +1 -2
  47. brainstate/nn/_dyn_impl/_rate_rnns_test.py +1 -2
  48. brainstate/nn/_dyn_impl/_readout.py +2 -3
  49. brainstate/nn/_dyn_impl/_readout_test.py +1 -2
  50. brainstate/nn/_dynamics/_dynamics_base.py +2 -3
  51. brainstate/nn/_dynamics/_dynamics_base_test.py +1 -2
  52. brainstate/nn/_dynamics/_state_delay.py +3 -3
  53. brainstate/nn/_dynamics/_synouts_test.py +1 -2
  54. brainstate/nn/_elementwise/_dropout.py +6 -7
  55. brainstate/nn/_elementwise/_dropout_test.py +1 -2
  56. brainstate/nn/_elementwise/_elementwise.py +1 -2
  57. brainstate/nn/_exp_euler.py +1 -2
  58. brainstate/nn/_exp_euler_test.py +1 -2
  59. brainstate/nn/_interaction/_conv.py +1 -2
  60. brainstate/nn/_interaction/_conv_test.py +1 -0
  61. brainstate/nn/_interaction/_linear.py +1 -2
  62. brainstate/nn/_interaction/_linear_test.py +1 -2
  63. brainstate/nn/_interaction/_normalizations.py +1 -2
  64. brainstate/nn/_interaction/_poolings.py +3 -4
  65. brainstate/nn/_module.py +63 -19
  66. brainstate/nn/_module_test.py +1 -2
  67. brainstate/nn/metrics.py +3 -4
  68. brainstate/optim/_lr_scheduler.py +1 -2
  69. brainstate/optim/_lr_scheduler_test.py +2 -3
  70. brainstate/optim/_optax_optimizer_test.py +1 -2
  71. brainstate/optim/_sgd_optimizer.py +2 -3
  72. brainstate/random/_rand_funs.py +1 -2
  73. brainstate/random/_rand_funs_test.py +2 -3
  74. brainstate/random/_rand_seed.py +2 -3
  75. brainstate/random/_rand_seed_test.py +1 -2
  76. brainstate/random/_rand_state.py +3 -4
  77. brainstate/surrogate.py +183 -35
  78. brainstate/transform.py +0 -3
  79. brainstate/typing.py +28 -25
  80. brainstate/util/__init__.py +9 -7
  81. brainstate/util/_caller.py +1 -2
  82. brainstate/util/_error.py +27 -0
  83. brainstate/util/_others.py +60 -15
  84. brainstate/util/{_dict.py → _pretty_pytree.py} +108 -29
  85. brainstate/util/{_dict_test.py → _pretty_pytree_test.py} +1 -2
  86. brainstate/util/_pretty_repr.py +128 -10
  87. brainstate/util/_pretty_table.py +2900 -0
  88. brainstate/util/_struct.py +11 -11
  89. brainstate/util/filter.py +472 -0
  90. {brainstate-0.1.0.post20250211.dist-info → brainstate-0.1.0.post20250216.dist-info}/METADATA +2 -2
  91. brainstate-0.1.0.post20250216.dist-info/RECORD +127 -0
  92. brainstate/util/_filter.py +0 -178
  93. brainstate-0.1.0.post20250211.dist-info/RECORD +0 -124
  94. {brainstate-0.1.0.post20250211.dist-info → brainstate-0.1.0.post20250216.dist-info}/LICENSE +0 -0
  95. {brainstate-0.1.0.post20250211.dist-info → brainstate-0.1.0.post20250216.dist-info}/WHEEL +0 -0
  96. {brainstate-0.1.0.post20250211.dist-info → brainstate-0.1.0.post20250216.dist-info}/top_level.txt +0 -0
brainstate/surrogate.py CHANGED
@@ -21,6 +21,8 @@ import jax.numpy as jnp
 import jax.scipy as sci
 from jax.interpreters import batching, ad, mlir
 
+from brainstate.util._pretty_pytree import PrettyObject
+
 if jax.__version_info__ < (0, 4, 38):
     from jax.core import Primitive
 else:
@@ -77,7 +79,10 @@ def _heaviside_imp(x, dx):
 
 
 def _heaviside_batching(args, axes):
-    return heaviside_p.bind(*args), [axes[0]]
+    x, dx = args
+    if axes[0] != axes[1]:
+        dx = batching.moveaxis(dx, axes[1], axes[0])
+    return heaviside_p.bind(x, dx), tuple([axes[0]])
 
 
 def _heaviside_jvp(primals, tangents):
@@ -97,7 +102,7 @@ ad.primitive_jvps[heaviside_p] = _heaviside_jvp
 mlir.register_lowering(heaviside_p, mlir.lower_fun(_heaviside_imp, multiple_results=True))
 
 
-class Surrogate(object):
+class Surrogate(PrettyObject):
     """The base surrograte gradient function.
 
     To customize a surrogate gradient function, you can inherit this class and
@@ -142,9 +147,20 @@ class Surrogate(object):
 class Sigmoid(Surrogate):
     """Spike function with the sigmoid-shaped surrogate gradient.
 
+    This class implements a spiking neuron activation with a sigmoid-shaped
+    surrogate gradient for backpropagation. It can be used in spiking neural
+    networks to approximate the non-differentiable step function during training.
+
+    Parameters
+    ----------
+    alpha : float, optional
+        A parameter controlling the steepness of the sigmoid curve in the
+        surrogate gradient. Higher values make the transition sharper.
+        Default is 4.0.
+
     See Also
     --------
-    sigmoid
+    sigmoid : Function version of this class.
 
     """
 
@@ -153,9 +169,33 @@ class Sigmoid(Surrogate):
         self.alpha = alpha
 
     def surrogate_fun(self, x):
+        """Compute the surrogate function.
+
+        Parameters
+        ----------
+        x : jax.Array
+            The input array.
+
+        Returns
+        -------
+        jax.Array
+            The output of the surrogate function.
+        """
         return sci.special.expit(self.alpha * x)
 
     def surrogate_grad(self, x):
+        """Compute the gradient of the surrogate function.
+
+        Parameters
+        ----------
+        x : jax.Array
+            The input array.
+
+        Returns
+        -------
+        jax.Array
+            The gradient of the surrogate function.
+        """
         sgax = sci.special.expit(x * self.alpha)
         dx = (1. - sgax) * sgax * self.alpha
         return dx
@@ -171,7 +211,12 @@ def sigmoid(
     x: jax.Array,
     alpha: float = 4.,
 ):
-    r"""Spike function with the sigmoid-shaped surrogate gradient.
+    r"""
+    Compute a spike function with a sigmoid-shaped surrogate gradient.
+
+    This function implements a spiking neuron activation with a sigmoid-shaped
+    surrogate gradient for backpropagation. It can be used in spiking neural
+    networks to approximate the non-differentiable step function during training.
 
     If `origin=False`, return the forward function:
 
@@ -210,16 +255,28 @@ def sigmoid(
 
     Parameters
     ----------
-    x: jax.Array, Array
-      The input data.
-    alpha: float
-      Parameter to control smoothness of gradient
-
+    x : jax.Array
+        The input array representing the neuron's membrane potential.
+    alpha : float, optional
+        A parameter controlling the steepness of the sigmoid curve in the
+        surrogate gradient. Higher values make the transition sharper.
+        Default is 4.0.
 
     Returns
    -------
-    out: jax.Array
-      The spiking state.
+    jax.Array
+        An array of the same shape as the input, containing binary values (0 or 1)
+        representing the spiking state of each neuron.
+
+    Notes
+    -----
+    The forward pass uses a step function (1 for x >= 0, 0 for x < 0),
+    while the backward pass uses a sigmoid-shaped surrogate gradient for
+    smooth optimization.
+
+    The surrogate gradient is defined as:
+    g'(x) = alpha * (1 - sigmoid(alpha * x)) * sigmoid(alpha * x)
+
     """
     return Sigmoid(alpha=alpha)(x)
 
@@ -238,11 +295,15 @@ class PiecewiseQuadratic(Surrogate):
         self.alpha = alpha
 
     def surrogate_fun(self, x):
-        z = jnp.where(x < -1 / self.alpha,
-                      0.,
-                      jnp.where(x > 1 / self.alpha,
-                                1.,
-                                (-self.alpha * jnp.abs(x) / 2 + 1) * self.alpha * x + 0.5))
+        z = jnp.where(
+            x < -1 / self.alpha,
+            0.,
+            jnp.where(
+                x > 1 / self.alpha,
+                1.,
+                (-self.alpha * jnp.abs(x) / 2 + 1) * self.alpha * x + 0.5
+            )
+        )
         return z
 
     def surrogate_grad(self, x):
@@ -260,7 +321,12 @@ def piecewise_quadratic(
     x: jax.Array,
     alpha: float = 1.,
 ):
-    r"""Judge spiking state with a piecewise quadratic function [1]_ [2]_ [3]_ [4]_ [5]_.
+    r"""
+    Judge spiking state with a piecewise quadratic function [1]_ [2]_ [3]_ [4]_ [5]_.
+
+    This function implements a surrogate gradient method for spiking neural networks
+    using a piecewise quadratic function. It provides a differentiable approximation
+    of the step function used in the forward pass of spiking neurons.
 
     If `origin=False`, computes the forward function:
 
@@ -306,18 +372,29 @@ def piecewise_quadratic(
     >>> plt.legend()
     >>> plt.show()
 
-  Parameters
+    Parameters
     ----------
-    x: jax.Array, Array
-      The input data.
-    alpha: float
-      Parameter to control smoothness of gradient
-
+    x : jax.Array
+        The input array representing the neuron's membrane potential.
+    alpha : float, optional
+        A parameter controlling the steepness of the surrogate gradient.
+        Higher values result in a steeper gradient. Default is 1.0.
 
     Returns
     -------
-    out: jax.Array
-      The spiking state.
+    jax.Array
+        An array of the same shape as the input, containing binary values (0 or 1)
+        representing the spiking state of each neuron.
+
+    Notes
+    -----
+    The function uses different computations for forward and backward passes:
+    - Forward: Step function (1 for x >= 0, 0 for x < 0)
+    - Backward: Piecewise quadratic function for smooth gradient
+
+    The surrogate gradient is defined as:
+    g'(x) = 0 if |x| > 1/alpha
+            -alpha^2|x| + alpha if |x| <= 1/alpha
 
     References
     ----------
@@ -331,11 +408,22 @@ def piecewise_quadratic(
 
 
 class PiecewiseExp(Surrogate):
-    """Judge spiking state with a piecewise exponential function.
+    """
+    Judge spiking state with a piecewise exponential function.
+
+    This class implements a surrogate gradient method for spiking neural networks
+    using a piecewise exponential function. It provides a differentiable approximation
+    of the step function used in the forward pass of spiking neurons.
+
+    Parameters
+    ----------
+    alpha : float, optional
+        A parameter controlling the steepness of the surrogate gradient.
+        Higher values result in a steeper gradient. Default is 1.0.
 
     See Also
     --------
-    piecewise_exp
+    piecewise_exp : Function version of this class.
     """
 
     def __init__(self, alpha: float = 1.):
@@ -343,16 +431,62 @@ class PiecewiseExp(Surrogate):
         self.alpha = alpha
 
     def surrogate_grad(self, x):
+        """
+        Compute the surrogate gradient.
+
+        Parameters
+        ----------
+        x : jax.Array
+            The input array.
+
+        Returns
+        -------
+        jax.Array
+            The surrogate gradient.
+        """
         dx = (self.alpha / 2) * jnp.exp(-self.alpha * jnp.abs(x))
         return dx
 
     def surrogate_fun(self, x):
-        return jnp.where(x < 0, jnp.exp(self.alpha * x) / 2, 1 - jnp.exp(-self.alpha * x) / 2)
+        """
+        Compute the surrogate function.
+
+        Parameters
+        ----------
+        x : jax.Array
+            The input array.
+
+        Returns
+        -------
+        jax.Array
+            The output of the surrogate function.
+        """
+        return jnp.where(
+            x < 0,
+            jnp.exp(self.alpha * x) / 2,
+            1 - jnp.exp(-self.alpha * x) / 2
+        )
 
     def __repr__(self):
+        """
+        Return a string representation of the PiecewiseExp instance.
+
+        Returns
+        -------
+        str
+            A string representation of the instance.
+        """
         return f'{self.__class__.__name__}(alpha={self.alpha})'
 
     def __hash__(self):
+        """
+        Compute a hash value for the PiecewiseExp instance.
+
+        Returns
+        -------
+        int
+            A hash value for the instance.
+        """
         return hash((self.__class__, self.alpha))
 
 
@@ -363,6 +497,10 @@ def piecewise_exp(
 ):
     r"""Judge spiking state with a piecewise exponential function [1]_.
 
+    This function implements a surrogate gradient method for spiking neural networks
+    using a piecewise exponential function. It provides a differentiable approximation
+    of the step function used in the forward pass of spiking neurons.
+
     If `origin=False`, computes the forward function:
 
     .. math::
@@ -403,16 +541,26 @@ def piecewise_exp(
 
     Parameters
     ----------
-    x: jax.Array, Array
-      The input data.
-    alpha: float
-      Parameter to control smoothness of gradient
-
+    x : jax.Array
+        The input array representing the neuron's membrane potential.
+    alpha : float, optional
+        A parameter controlling the steepness of the surrogate gradient.
+        Higher values result in a steeper gradient. Default is 1.0.
 
     Returns
     -------
-    out: jax.Array
-      The spiking state.
+    jax.Array
+        An array of the same shape as the input, containing binary values (0 or 1)
+        representing the spiking state of each neuron.
+
+    Notes
+    -----
+    The function uses different computations for forward and backward passes:
+    - Forward: Step function (1 for x >= 0, 0 for x < 0)
+    - Backward: Piecewise exponential function for smooth gradient
+
+    The surrogate gradient is defined as:
+    g'(x) = (alpha / 2) * exp(-alpha * |x|)
 
     References
     ----------
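The hunks above all follow the straight-through pattern the new docstrings spell out: a hard Heaviside step in the forward pass, a smooth surrogate in the backward pass. Below is a minimal, self-contained sketch of that pattern using jax.custom_jvp; it is an illustration only, not brainstate's actual implementation, which registers a heaviside_p primitive with its own JVP, batching, and MLIR lowering rules as shown in the diff.

from functools import partial

import jax
import jax.numpy as jnp
import jax.scipy as sci


@partial(jax.custom_jvp, nondiff_argnums=(1,))
def sigmoid_spike(x, alpha=4.0):
    # Forward pass: hard step function, 1 for x >= 0 and 0 otherwise.
    return (x >= 0).astype(x.dtype)


@sigmoid_spike.defjvp
def _sigmoid_spike_jvp(alpha, primals, tangents):
    x, = primals
    x_dot, = tangents
    # Backward pass: the surrogate gradient quoted in the docstring,
    # g'(x) = alpha * (1 - sigmoid(alpha * x)) * sigmoid(alpha * x).
    sg = sci.special.expit(alpha * x)
    return sigmoid_spike(x, alpha), alpha * (1.0 - sg) * sg * x_dot


v = jnp.linspace(-1.0, 1.0, 5)
print(sigmoid_spike(v, 4.0))                               # [0. 0. 1. 1. 1.]
print(jax.grad(lambda y: sigmoid_spike(y, 4.0).sum())(v))  # smooth, nonzero

The same scaffold works for the piecewise quadratic and piecewise exponential variants by swapping in their surrogate-gradient formulas from the Notes sections above.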
brainstate/transform.py CHANGED
@@ -15,6 +15,3 @@
 
 # alias for compilation and augmentation functions
 
-from .compile import *
-from .augment import *
-
brainstate/typing.py CHANGED
@@ -16,13 +16,17 @@
 from __future__ import annotations
 
 import builtins
+
+import brainunit as u
 import functools as ft
 import importlib
 import inspect
-
-import brainunit as u
 import jax
 import numpy as np
+from typing import (
+    Any, Callable, Hashable, List, Protocol, Tuple, TypeVar, Union,
+    runtime_checkable, TYPE_CHECKING, Generic, Sequence
+)
 
 tp = importlib.import_module("typing")
 
@@ -41,35 +45,35 @@ __all__ = [
     'Missing',
 ]
 
-K = tp.TypeVar('K')
+K = TypeVar('K')
 
 
-@tp.runtime_checkable
-class Key(tp.Hashable, tp.Protocol):
+@runtime_checkable
+class Key(Hashable, Protocol):
     def __lt__(self: K, value: K, /) -> bool:
         ...
 
 
-Ellipsis = builtins.ellipsis if tp.TYPE_CHECKING else tp.Any
+Ellipsis = builtins.ellipsis if TYPE_CHECKING else Any
 
-PathParts = tp.Tuple[Key, ...]
-Predicate = tp.Callable[[PathParts, tp.Any], bool]
-FilterLiteral = tp.Union[type, str, Predicate, bool, Ellipsis, None]
-Filter = tp.Union[FilterLiteral, tp.Tuple['Filter', ...], tp.List['Filter']]
+PathParts = Tuple[Key, ...]
+Predicate = Callable[[PathParts, Any], bool]
+FilterLiteral = Union[type, str, Predicate, bool, Ellipsis, None]
+Filter = Union[FilterLiteral, Tuple['Filter', ...], List['Filter']]
 
-_T = tp.TypeVar("_T")
+_T = TypeVar("_T")
 
-_Annotation = tp.TypeVar("_Annotation")
+_Annotation = TypeVar("_Annotation")
 
 
-class _Array(tp.Generic[_Annotation]):
+class _Array(Generic[_Annotation]):
     pass
 
 
 _Array.__module__ = "builtins"
 
 
-def _item_to_str(item: tp.Union[str, type, slice]) -> str:
+def _item_to_str(item: Union[str, type, slice]) -> str:
     if isinstance(item, slice):
         if item.step is not None:
             raise NotImplementedError
@@ -83,7 +87,7 @@ def _item_to_str(item: tp.Union[str, type, slice]) -> str:
 
 
 def _maybe_tuple_to_str(
-    item: tp.Union[str, type, slice, tp.Tuple[tp.Union[str, type, slice], ...]]
+    item: Union[str, type, slice, Tuple[Union[str, type, slice], ...]]
 ) -> str:
     if isinstance(item, tuple):
         if len(item) == 0:
@@ -113,7 +117,7 @@ class Array:
 Array.__module__ = "builtins"
 
 
-class _FakePyTree(tp.Generic[_T]):
+class _FakePyTree(Generic[_T]):
     pass
 
 
@@ -255,11 +259,10 @@ f. A structure can end with a `...`, to denote that the PyTree must be a prefix
 cases, all named pieces must already have been seen and their structures bound.
 """  # noqa: E501
 
-Size = tp.Union[int, tp.Sequence[int]]
-Axes = tp.Union[int, tp.Sequence[int]]
-SeedOrKey = tp.Union[int, jax.Array, np.ndarray]
-Shape = tp.Sequence[int]
-
+Size = Union[int, Sequence[int]]
+Axes = Union[int, Sequence[int]]
+SeedOrKey = Union[int, jax.Array, np.ndarray]
+Shape = Sequence[int]
 
 # --- Array --- #
 
@@ -267,7 +270,7 @@ Shape = tp.Sequence[int]
 # standard JAX array (i.e. not including future non-standard array types like
 # KeyArray and BInt). It's different than np.typing.ArrayLike in that it doesn't
 # accept arbitrary sequences, nor does it accept string data.
-ArrayLike = tp.Union[
+ArrayLike = Union[
     jax.Array,  # JAX array type
     np.ndarray,  # NumPy array type
     np.bool_, np.number,  # NumPy scalar types
@@ -281,7 +284,7 @@ ArrayLike = Union[
 DType = np.dtype
 
 
-class SupportsDType(tp.Protocol):
+class SupportsDType(Protocol):
     @property
     def dtype(self) -> DType: ...
 
@@ -291,9 +294,9 @@ class SupportsDType(tp.Protocol):
 # because JAX doesn't support objects or structured dtypes.
 # Unlike np.typing.DTypeLike, we exclude None, and instead require
 # explicit annotations when None is acceptable.
-DTypeLike = tp.Union[
+DTypeLike = Union[
     str,  # like 'float32', 'int32'
-    type[tp.Any],  # like np.float32, np.int32, float, int
+    type[Any],  # like np.float32, np.int32, float, int
     np.dtype,  # like np.dtype('float32'), np.dtype('int32')
     SupportsDType,  # like jnp.float32, jnp.int32
 ]
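Aside from dropping the tp. prefix, the aliases themselves are unchanged. Note that Filter is recursive: a filter can be a type, a string tag, a (PathParts, value) predicate, a bool, Ellipsis, None, or a tuple/list of further filters. A small illustrative sketch (the predicate name has_weight is invented for this example):

from typing import Any, Tuple

def has_weight(path: Tuple[Any, ...], value: Any) -> bool:
    # Conforms to the Predicate alias: Callable[[PathParts, Any], bool].
    return 'weight' in path

# A valid Filter: a tuple mixing a predicate, a string literal, and Ellipsis.
train_filter = (has_weight, 'bias', ...)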
brainstate/util/__init__.py CHANGED
@@ -13,36 +13,38 @@
 # limitations under the License.
 # ==============================================================================
 
-from ._dict import *
-from ._dict import __all__ as _mapping_all
+from . import filter
 from ._error import *
 from ._error import __all__ as _error_all
-from ._filter import *
-from ._filter import __all__ as _filter_all
 from ._others import *
 from ._others import __all__ as _others_all
+from ._pretty_pytree import *
+from ._pretty_pytree import __all__ as _mapping_all
 from ._pretty_repr import *
 from ._pretty_repr import __all__ as _pretty_repr_all
+from ._pretty_table import *
+from ._pretty_table import __all__ as _table_all
 from ._scaling import *
 from ._scaling import __all__ as _mem_scale_all
 from ._struct import *
 from ._struct import __all__ as _struct_all
 
 __all__ = (
-    _others_all
+    ['filter']
+    + _others_all
     + _mem_scale_all
-    + _filter_all
     + _pretty_repr_all
     + _struct_all
     + _error_all
     + _mapping_all
+    + _table_all
 )
 del (
     _others_all,
     _mem_scale_all,
-    _filter_all,
     _pretty_repr_all,
     _struct_all,
     _error_all,
     _mapping_all,
+    _table_all,
 )
brainstate/util/_caller.py CHANGED
@@ -18,9 +18,8 @@
 from __future__ import annotations
 
 import dataclasses
-from typing import Any, TypeVar, Protocol, Generic
-
 import jax
+from typing import Any, TypeVar, Protocol, Generic
 
 __all__ = [
     'DelayedAccessor',
brainstate/util/_error.py CHANGED
@@ -21,8 +21,35 @@ __all__ = [
 
 
 class BrainStateError(Exception):
+    """
+    A custom exception class for BrainState-related errors.
+
+    This exception is raised when a BrainState-specific error occurs during
+    the execution of the program. It serves as a base class for more specific
+    BrainState exceptions.
+
+    Attributes:
+        Inherits all attributes from the built-in Exception class.
+
+    Usage::
+
+        raise BrainStateError("A BrainState-specific error occurred.")
+    """
     pass
 
 
 class TraceContextError(BrainStateError):
+    """
+    A custom exception class for trace context-related errors in BrainState.
+
+    This exception is raised when an error occurs specifically related to
+    trace context operations or manipulations within the BrainState framework.
+
+    Attributes:
+        Inherits all attributes from the BrainStateError class.
+
+    Usage::
+
+        raise TraceContextError("An error occurred while handling trace context.")
+    """
     pass
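Because TraceContextError inherits from BrainStateError, a handler written against the base class also catches the subclass. A quick sketch, assuming both names are re-exported from brainstate.util as the __init__.py hunk above suggests:

from brainstate.util import BrainStateError, TraceContextError

try:
    raise TraceContextError("An error occurred while handling trace context.")
except BrainStateError as err:
    # The base-class handler also catches the TraceContextError subclass.
    print(type(err).__name__, err)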
brainstate/util/_others.py CHANGED
@@ -15,20 +15,21 @@
 
 from __future__ import annotations
 
+import gc
+
 import copy
 import functools
-import gc
+import jax
 import threading
 import types
 from collections.abc import Iterable
-from typing import Any, Callable, Tuple, Union, Dict
-
-import jax
 from jax.lib import xla_bridge
+from typing import Any, Callable, Tuple, Union, Dict
 
 from brainstate._utils import set_module_as
 
 __all__ = [
+    'split_total',
     'clear_buffer_memory',
     'not_instance_eval',
     'is_instance_eval',
@@ -37,6 +38,61 @@ __all__ = [
 ]
 
 
+def split_total(
+    total: int,
+    fraction: Union[int, float],
+) -> int:
+    """
+    Calculate the number of epochs for simulation based on a total and a fraction.
+
+    This function determines the number of epochs to simulate given a total number
+    of epochs and either a fraction or a specific number of epochs to run.
+
+    Parameters:
+    -----------
+    total : int
+        The total number of epochs. Must be a positive integer.
+    fraction : Union[int, float]
+        If ``float``: A value between 0 and 1 representing the fraction of total epochs to run.
+        If ``int``: The specific number of epochs to run, must not exceed the total.
+
+    Returns:
+    --------
+    int
+        The calculated number of epochs to simulate.
+
+    Raises:
+    -------
+    ValueError
+        If total is not positive, fraction is negative, or if fraction as float is > 1
+        or as int is > total.
+    AssertionError
+        If total is not an integer.
+    """
+    assert isinstance(total, int), "Length must be an integer."
+    if total <= 0:
+        raise ValueError("'total' must be a positive integer.")
+    if fraction < 0:
+        raise ValueError("'fraction' value cannot be negative.")
+
+    if isinstance(fraction, float):
+        if fraction < 0:
+            raise ValueError("'fraction' value cannot be negative.")
+        if fraction > 1:
+            raise ValueError("'fraction' value cannot be greater than 1.")
+        return int(total * fraction)
+
+    elif isinstance(fraction, int):
+        if fraction < 0:
+            raise ValueError("'fraction' value cannot be negative.")
+        if fraction > total:
+            raise ValueError("'fraction' value cannot be greater than total.")
+        return fraction
+
+    else:
+        raise ValueError("'fraction' must be an integer or float.")
+
+
 class NameContext(threading.local):
     def __init__(self):
         self.typed_names: Dict[str, int] = {}
@@ -249,17 +305,6 @@ class DictManager(dict):
         else:
             raise ValueError(f'Unsupported method: {by}')
 
-    def union_by_value_ids(self, other: dict):
-        """
-        Union the stack by the value ids.
-
-        Args:
-            other:
-
-        Returns:
-
-        """
-
     def __add__(self, other: dict):
         """
         Compose other instance of dict.
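The new split_total helper above dispatches on the Python type of fraction: a float is read as a proportion of total (truncated toward zero by int()), an int as an absolute count. A usage sketch against that contract, assuming the brainstate.util re-export added in the __init__.py hunk:

from brainstate.util import split_total

assert split_total(100, 0.25) == 25   # float: fraction of the total
assert split_total(100, 25) == 25     # int: absolute epoch count
assert split_total(7, 0.5) == 3       # int(7 * 0.5) truncates to 3

try:
    split_total(100, 1.5)             # floats above 1 are rejected
except ValueError as err:
    print(err)                        # 'fraction' value cannot be greater than 1.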