brainstate 0.1.0.post20250212__py2.py3-none-any.whl → 0.1.0.post20250217__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (97)
  1. brainstate/_state.py +853 -90
  2. brainstate/_state_test.py +1 -3
  3. brainstate/augment/__init__.py +2 -2
  4. brainstate/augment/_autograd.py +257 -115
  5. brainstate/augment/_autograd_test.py +2 -3
  6. brainstate/augment/_eval_shape.py +3 -4
  7. brainstate/augment/_mapping.py +582 -62
  8. brainstate/augment/_mapping_test.py +114 -30
  9. brainstate/augment/_random.py +61 -7
  10. brainstate/compile/_ad_checkpoint.py +2 -3
  11. brainstate/compile/_conditions.py +4 -5
  12. brainstate/compile/_conditions_test.py +1 -2
  13. brainstate/compile/_error_if.py +1 -2
  14. brainstate/compile/_error_if_test.py +1 -2
  15. brainstate/compile/_jit.py +23 -16
  16. brainstate/compile/_jit_test.py +1 -2
  17. brainstate/compile/_loop_collect_return.py +18 -10
  18. brainstate/compile/_loop_collect_return_test.py +1 -1
  19. brainstate/compile/_loop_no_collection.py +5 -5
  20. brainstate/compile/_make_jaxpr.py +23 -21
  21. brainstate/compile/_make_jaxpr_test.py +1 -2
  22. brainstate/compile/_progress_bar.py +1 -2
  23. brainstate/compile/_unvmap.py +1 -0
  24. brainstate/compile/_util.py +4 -2
  25. brainstate/environ.py +4 -4
  26. brainstate/environ_test.py +1 -2
  27. brainstate/functional/_activations.py +1 -2
  28. brainstate/functional/_activations_test.py +1 -1
  29. brainstate/functional/_normalization.py +1 -2
  30. brainstate/functional/_others.py +1 -2
  31. brainstate/functional/_spikes.py +136 -20
  32. brainstate/graph/_graph_node.py +2 -43
  33. brainstate/graph/_graph_operation.py +4 -20
  34. brainstate/graph/_graph_operation_test.py +3 -4
  35. brainstate/init/_base.py +1 -2
  36. brainstate/init/_generic.py +1 -2
  37. brainstate/nn/__init__.py +8 -0
  38. brainstate/nn/_collective_ops.py +351 -48
  39. brainstate/nn/_collective_ops_test.py +36 -0
  40. brainstate/nn/_common.py +193 -0
  41. brainstate/nn/_dyn_impl/_dynamics_neuron.py +1 -2
  42. brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +1 -2
  43. brainstate/nn/_dyn_impl/_dynamics_synapse.py +1 -2
  44. brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +1 -2
  45. brainstate/nn/_dyn_impl/_inputs.py +1 -2
  46. brainstate/nn/_dyn_impl/_rate_rnns.py +1 -2
  47. brainstate/nn/_dyn_impl/_rate_rnns_test.py +1 -2
  48. brainstate/nn/_dyn_impl/_readout.py +2 -3
  49. brainstate/nn/_dyn_impl/_readout_test.py +1 -2
  50. brainstate/nn/_dynamics/_dynamics_base.py +6 -1
  51. brainstate/nn/_dynamics/_dynamics_base_test.py +1 -2
  52. brainstate/nn/_dynamics/_state_delay.py +3 -3
  53. brainstate/nn/_dynamics/_synouts_test.py +1 -2
  54. brainstate/nn/_elementwise/_dropout.py +6 -7
  55. brainstate/nn/_elementwise/_dropout_test.py +1 -2
  56. brainstate/nn/_elementwise/_elementwise.py +1 -2
  57. brainstate/nn/_exp_euler.py +1 -2
  58. brainstate/nn/_exp_euler_test.py +1 -2
  59. brainstate/nn/_interaction/_conv.py +1 -2
  60. brainstate/nn/_interaction/_conv_test.py +1 -0
  61. brainstate/nn/_interaction/_linear.py +1 -2
  62. brainstate/nn/_interaction/_linear_test.py +1 -2
  63. brainstate/nn/_interaction/_normalizations.py +1 -2
  64. brainstate/nn/_interaction/_poolings.py +3 -4
  65. brainstate/nn/_module.py +68 -19
  66. brainstate/nn/_module_test.py +1 -2
  67. brainstate/nn/_utils.py +89 -0
  68. brainstate/nn/metrics.py +3 -4
  69. brainstate/optim/_lr_scheduler.py +1 -2
  70. brainstate/optim/_lr_scheduler_test.py +2 -3
  71. brainstate/optim/_optax_optimizer_test.py +1 -2
  72. brainstate/optim/_sgd_optimizer.py +2 -3
  73. brainstate/random/_rand_funs.py +1 -2
  74. brainstate/random/_rand_funs_test.py +2 -3
  75. brainstate/random/_rand_seed.py +2 -3
  76. brainstate/random/_rand_seed_test.py +1 -2
  77. brainstate/random/_rand_state.py +3 -4
  78. brainstate/surrogate.py +5 -5
  79. brainstate/transform.py +0 -3
  80. brainstate/typing.py +28 -25
  81. brainstate/util/__init__.py +9 -7
  82. brainstate/util/_caller.py +1 -2
  83. brainstate/util/_error.py +27 -0
  84. brainstate/util/_others.py +60 -15
  85. brainstate/util/{_dict.py → _pretty_pytree.py} +2 -2
  86. brainstate/util/{_dict_test.py → _pretty_pytree_test.py} +1 -2
  87. brainstate/util/_pretty_repr.py +1 -2
  88. brainstate/util/_pretty_table.py +2900 -0
  89. brainstate/util/_struct.py +11 -11
  90. brainstate/util/filter.py +472 -0
  91. {brainstate-0.1.0.post20250212.dist-info → brainstate-0.1.0.post20250217.dist-info}/METADATA +2 -2
  92. brainstate-0.1.0.post20250217.dist-info/RECORD +128 -0
  93. brainstate/util/_filter.py +0 -178
  94. brainstate-0.1.0.post20250212.dist-info/RECORD +0 -124
  95. {brainstate-0.1.0.post20250212.dist-info → brainstate-0.1.0.post20250217.dist-info}/LICENSE +0 -0
  96. {brainstate-0.1.0.post20250212.dist-info → brainstate-0.1.0.post20250217.dist-info}/WHEEL +0 -0
  97. {brainstate-0.1.0.post20250212.dist-info → brainstate-0.1.0.post20250217.dist-info}/top_level.txt +0 -0
@@ -17,11 +17,10 @@
17
17
 
18
18
  from __future__ import annotations
19
19
 
20
- import unittest
21
-
22
20
  import brainunit as u
23
21
  import jax
24
22
  import jax.numpy as jnp
23
+ import unittest
25
24
 
26
25
  import brainstate as bst
27
26
  from brainstate.nn import IF, LIF, ALIF
@@ -17,9 +17,8 @@
17
17
 
18
18
  from __future__ import annotations
19
19
 
20
- from typing import Optional
21
-
22
20
  import brainunit as u
21
+ from typing import Optional
23
22
 
24
23
  from brainstate import init, environ
25
24
  from brainstate._state import ShortTermState, HiddenState
@@ -15,11 +15,10 @@
15
15
 
16
16
  from __future__ import annotations
17
17
 
18
- import unittest
19
-
20
18
  import brainunit as u
21
19
  import jax.numpy as jnp
22
20
  import pytest
21
+ import unittest
23
22
 
24
23
  import brainstate as bst
25
24
  from brainstate.nn import Expon, STP, STD
@@ -14,11 +14,10 @@
14
14
  # ==============================================================================
15
15
  from __future__ import annotations
16
16
 
17
- from typing import Union, Optional, Sequence, Callable
18
-
19
17
  import brainunit as u
20
18
  import jax
21
19
  import numpy as np
20
+ from typing import Union, Optional, Sequence, Callable
22
21
 
23
22
  from brainstate import environ, init, random
24
23
  from brainstate._state import ShortTermState
@@ -17,9 +17,8 @@
17
17
 
18
18
  from __future__ import annotations
19
19
 
20
- from typing import Callable, Union
21
-
22
20
  import jax.numpy as jnp
21
+ from typing import Callable, Union
23
22
 
24
23
  from brainstate import random, init, functional
25
24
  from brainstate._state import HiddenState, ParamState
@@ -15,9 +15,8 @@
15
15
 
16
16
  from __future__ import annotations
17
17
 
18
- import unittest
19
-
20
18
  import jax.numpy as jnp
19
+ import unittest
21
20
 
22
21
  import brainstate as bst
23
22
 
@@ -17,11 +17,10 @@
17
17
 
18
18
  from __future__ import annotations
19
19
 
20
- import numbers
21
- from typing import Callable
22
-
23
20
  import brainunit as u
24
21
  import jax
22
+ import numbers
23
+ from typing import Callable
25
24
 
26
25
  from brainstate import environ, init, surrogate
27
26
  from brainstate._state import HiddenState, ParamState
@@ -15,9 +15,8 @@
15
15
 
16
16
  from __future__ import annotations
17
17
 
18
- import unittest
19
-
20
18
  import jax.numpy as jnp
19
+ import unittest
21
20
 
22
21
  import brainstate as bst
23
22
 
@@ -157,6 +157,11 @@ class Dynamics(Module):
157
157
  # in-/out- size of neuron population
158
158
  self.out_size = self.in_size
159
159
 
160
+ def __pretty_repr_item__(self, name, value):
161
+ if name in ['_before_updates', '_after_updates', '_current_inputs', '_delta_inputs']:
162
+ return None if value is None else (name[1:], value) # skip the first `_`
163
+ return super().__pretty_repr_item__(name, value)
164
+
160
165
  @property
161
166
  def varshape(self):
162
167
  """The shape of variables in the neuron group."""
@@ -420,7 +425,7 @@ class Dynamics(Module):
420
425
  else:
421
426
  raise TypeError(f'The input {dyn} should be an instance of {Dynamics} or a delayed initializer.')
422
427
 
423
- def __leaf_fn__(self, name, value):
428
+ def __pretty_repr_item__(self, name, value):
424
429
  if name in ['_in_size', '_out_size', '_name', '_mode',
425
430
  '_before_updates', '_after_updates', '_current_inputs', '_delta_inputs']:
426
431
  return (name, value) if value is None else (name[1:], value) # skip the first `_`
@@ -17,9 +17,8 @@
17
17
 
18
18
  from __future__ import annotations
19
19
 
20
- import unittest
21
-
22
20
  import numpy as np
21
+ import unittest
23
22
 
24
23
  import brainstate as bst
25
24
 
@@ -16,14 +16,14 @@
16
16
  from __future__ import annotations
17
17
 
18
18
  import math
19
- import numbers
20
- from functools import partial
21
- from typing import Optional, Dict, Callable, Union, Sequence
22
19
 
23
20
  import brainunit as u
24
21
  import jax
25
22
  import jax.numpy as jnp
23
+ import numbers
26
24
  import numpy as np
25
+ from functools import partial
26
+ from typing import Optional, Dict, Callable, Union, Sequence
27
27
 
28
28
  from brainstate import environ
29
29
  from brainstate._state import ShortTermState, State
@@ -15,11 +15,10 @@
15
15
 
16
16
  from __future__ import annotations
17
17
 
18
- import unittest
19
-
20
18
  import brainunit as u
21
19
  import jax.numpy as jnp
22
20
  import numpy as np
21
+ import unittest
23
22
 
24
23
  import brainstate as bst
25
24
 
@@ -16,11 +16,10 @@
16
16
 
17
17
  from __future__ import annotations
18
18
 
19
- from functools import partial
20
- from typing import Optional, Sequence
21
-
22
19
  import brainunit as u
23
20
  import jax.numpy as jnp
21
+ from functools import partial
22
+ from typing import Optional, Sequence
24
23
 
25
24
  from brainstate import random, environ, init
26
25
  from brainstate._state import ShortTermState
@@ -72,11 +71,11 @@ class Dropout(ElementWiseBlock):
72
71
  for dim in self.broadcast_dims:
73
72
  broadcast_shape[dim] = 1
74
73
  keep_mask = random.bernoulli(self.prob, broadcast_shape)
75
- keep_mask = jnp.broadcast_to(keep_mask, x.shape)
76
- return jnp.where(
74
+ keep_mask = u.math.broadcast_to(keep_mask, x.shape)
75
+ return u.math.where(
77
76
  keep_mask,
78
- jnp.asarray(x / self.prob, dtype=dtype),
79
- jnp.asarray(0., dtype=dtype)
77
+ u.math.asarray(x / self.prob, dtype=dtype),
78
+ u.math.asarray(0., dtype=dtype)
80
79
  )
81
80
  else:
82
81
  return x
@@ -14,9 +14,8 @@
14
14
  # ==============================================================================
15
15
 
16
16
 
17
- import unittest
18
-
19
17
  import numpy as np
18
+ import unittest
20
19
 
21
20
  import brainstate as bst
22
21
 
@@ -17,11 +17,10 @@
17
17
 
18
18
  from __future__ import annotations
19
19
 
20
- from typing import Optional
21
-
22
20
  import brainunit as u
23
21
  import jax.numpy as jnp
24
22
  import jax.typing
23
+ from typing import Optional
25
24
 
26
25
  from brainstate import random, functional as F
27
26
  from brainstate._state import ParamState
@@ -16,10 +16,9 @@
16
16
 
17
17
  from __future__ import annotations
18
18
 
19
- from typing import Callable
20
-
21
19
  import brainunit as u
22
20
  import jax.numpy as jnp
21
+ from typing import Callable
23
22
 
24
23
  from brainstate import environ, random
25
24
  from brainstate.augment import vector_grad
@@ -15,9 +15,8 @@
15
15
 
16
16
  from __future__ import annotations
17
17
 
18
- import unittest
19
-
20
18
  import brainunit as u
19
+ import unittest
21
20
 
22
21
  import brainstate as bst
23
22
 
@@ -18,10 +18,9 @@
18
18
  from __future__ import annotations
19
19
 
20
20
  import collections.abc
21
- from typing import Callable, Tuple, Union, Sequence, Optional, TypeVar
22
-
23
21
  import jax
24
22
  import jax.numpy as jnp
23
+ from typing import Callable, Tuple, Union, Sequence, Optional, TypeVar
25
24
 
26
25
  from brainstate import init, functional
27
26
  from brainstate._state import ParamState
@@ -235,5 +235,6 @@ class TestConvTranspose3d(parameterized.TestCase):
235
235
  y = conv_transpose_module(x)
236
236
  print(y.shape)
237
237
 
238
+
238
239
  if __name__ == '__main__':
239
240
  absltest.main()
@@ -17,10 +17,9 @@
17
17
 
18
18
  from __future__ import annotations
19
19
 
20
- from typing import Callable, Union, Optional
21
-
22
20
  import brainunit as u
23
21
  import jax.numpy as jnp
22
+ from typing import Callable, Union, Optional
24
23
 
25
24
  from brainstate import init, functional
26
25
  from brainstate._state import ParamState
@@ -16,9 +16,8 @@
16
16
 
17
17
  from __future__ import annotations
18
18
 
19
- import unittest
20
-
21
19
  import brainunit as u
20
+ import unittest
22
21
  from absl.testing import parameterized
23
22
 
24
23
  import brainstate as bst
@@ -17,10 +17,9 @@
17
17
 
18
18
  from __future__ import annotations
19
19
 
20
- from typing import Callable, Union, Sequence, Optional, Any, Dict
21
-
22
20
  import jax
23
21
  import jax.numpy as jnp
22
+ from typing import Callable, Union, Sequence, Optional, Any
24
23
 
25
24
  from brainstate import environ, init
26
25
  from brainstate._state import ParamState, BatchState
@@ -17,14 +17,13 @@
17
17
 
18
18
  from __future__ import annotations
19
19
 
20
- import functools
21
- from typing import Sequence, Optional
22
- from typing import Union, Tuple, Callable, List
23
-
24
20
  import brainunit as u
21
+ import functools
25
22
  import jax
26
23
  import jax.numpy as jnp
27
24
  import numpy as np
25
+ from typing import Sequence, Optional
26
+ from typing import Union, Tuple, Callable, List
28
27
 
29
28
  from brainstate import environ
30
29
  from brainstate.nn._module import Module
brainstate/nn/_module.py CHANGED
@@ -28,7 +28,7 @@ The basic classes include:
28
28
  from __future__ import annotations
29
29
 
30
30
  import warnings
31
- from typing import Sequence, Optional, Tuple, Union, TYPE_CHECKING
31
+ from typing import Sequence, Optional, Tuple, Union, TYPE_CHECKING, Callable
32
32
 
33
33
  import numpy as np
34
34
 
@@ -36,7 +36,7 @@ from brainstate._state import State
36
36
  from brainstate.graph import Node, states, nodes, flatten
37
37
  from brainstate.mixin import ParamDescriber, ParamDesc
38
38
  from brainstate.typing import PathParts
39
- from brainstate.util import FlattedDict, NestedDict
39
+ from brainstate.util import FlattedDict, NestedDict, BrainStateError
40
40
 
41
41
  # maximum integer
42
42
  max_int = np.iinfo(np.int32).max
@@ -113,7 +113,11 @@ class Module(Node, ParamDesc):
113
113
  """
114
114
  The function to specify the updating rule.
115
115
  """
116
- raise NotImplementedError(f'Subclass of {self.__class__.__name__} must implement "update" function.')
116
+ raise NotImplementedError(
117
+ f'Subclass of {self.__class__.__name__} must implement "update" function. \n'
118
+ f'This instance is: \n'
119
+ f'{self}'
120
+ )
117
121
 
118
122
  def __call__(self, *args, **kwargs):
119
123
  return self.update(*args, **kwargs)
@@ -226,9 +230,9 @@ class Module(Node, ParamDesc):
226
230
  """
227
231
  pass
228
232
 
229
- def __leaf_fn__(self, name, value):
233
+ def __pretty_repr_item__(self, name, value):
230
234
  if name in ['_in_size', '_out_size', '_name']:
231
- return (name, value) if value is None else (name[1:], value) # skip the first `_`
235
+ return None if value is None else (name[1:], value) # skip the first `_`
232
236
  return name, value
233
237
 
234
238
 
@@ -288,7 +292,7 @@ class Sequential(Module):
288
292
  in_size = first.out_size
289
293
  self.layers.append(first)
290
294
  for module in layers:
291
- module, in_size = _format_module(module, in_size)
295
+ module, in_size = self._format_module(module, in_size)
292
296
  self.layers.append(module)
293
297
 
294
298
  # the input and output shape
@@ -301,7 +305,14 @@ class Sequential(Module):
301
305
  """Update function of a sequential model.
302
306
  """
303
307
  for m in self.layers:
304
- x = m(x)
308
+ try:
309
+ x = m(x)
310
+ except Exception as e:
311
+ raise BrainStateError(
312
+ f'The module \n'
313
+ f'{m}\n'
314
+ f'failed to update with input {x}\n'
315
+ ) from e
305
316
  return x
306
317
 
307
318
  def __getitem__(self, key: Union[int, slice]):
@@ -314,16 +325,54 @@ class Sequential(Module):
314
325
  else:
315
326
  raise KeyError(f'Unknown type of key: {type(key)}')
316
327
 
328
+ def append(self, layer: Callable):
329
+ """
330
+ Append a layer to the sequential model.
331
+
332
+ This method adds a new layer to the end of the sequential model. The layer can be
333
+ either a Module instance, an ElementWiseBlock instance, or a callable function. If the
334
+ layer is a callable function, it will be wrapped in an ElementWiseBlock instance.
317
335
 
318
- def _format_module(module, in_size):
319
- if isinstance(module, ParamDescriber):
320
- module = module(in_size=in_size)
321
- assert isinstance(module, Module), 'The module should be an instance of Module.'
322
- out_size = module.out_size
323
- elif isinstance(module, ElementWiseBlock):
324
- out_size = in_size
325
- elif isinstance(module, Module):
326
- out_size = module.out_size
327
- else:
328
- raise TypeError(f"Unsupported type {type(module)}. ")
329
- return module, out_size
336
+ Parameters:
337
+ ----------
338
+ layer : Callable
339
+ The layer to be appended to the sequential model. It can be a Module instance,
340
+ an ElementWiseBlock instance, or a callable function.
341
+
342
+ Raises:
343
+ -------
344
+ ValueError
345
+ If the sequential model is empty and the first layer is a callable function.
346
+
347
+ Returns:
348
+ --------
349
+ None
350
+ The method does not return any value. It modifies the sequential model by adding
351
+ the new layer to the end.
352
+ """
353
+ if len(self.layers) == 0:
354
+ raise ValueError('The first layer should be a module, not a function.')
355
+ module, in_size = self._format_module(layer, self.out_size)
356
+ self.layers.append(module)
357
+ self.out_size = in_size
358
+
359
+ def _format_module(self, module, in_size):
360
+ if isinstance(module, ParamDescriber):
361
+ if in_size is None:
362
+ raise ValueError(
363
+ 'The input size should be specified. '
364
+ f'Please set the in_size attribute of the previous module: \n'
365
+ f'{self.layers[-1]}'
366
+ )
367
+ module = module(in_size=in_size)
368
+ assert isinstance(module, Module), 'The module should be an instance of Module.'
369
+ out_size = module.out_size
370
+ elif isinstance(module, ElementWiseBlock):
371
+ out_size = in_size
372
+ elif isinstance(module, Module):
373
+ out_size = module.out_size
374
+ elif callable(module):
375
+ out_size = in_size
376
+ else:
377
+ raise TypeError(f"Unsupported type {type(module)}. ")
378
+ return module, out_size
@@ -15,10 +15,9 @@
15
15
 
16
16
  from __future__ import annotations
17
17
 
18
- import unittest
19
-
20
18
  import jax.numpy as jnp
21
19
  import jaxlib.xla_extension
20
+ import unittest
22
21
 
23
22
  import brainstate as bst
24
23
 
@@ -0,0 +1,89 @@
1
+ # Copyright 2025 BDP Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ # -*- coding: utf-8 -*-
17
+
18
+ from typing import Union, Tuple
19
+
20
+ from brainstate._state import ParamState
21
+ from brainstate.util import PrettyTable
22
+ from ._module import Module
23
+
24
+ __all__ = [
25
+ "count_parameters",
26
+ ]
27
+
28
+
29
+ def _format_parameter_count(num_params, precision=2):
30
+ if num_params < 1000:
31
+ return str(num_params)
32
+
33
+ suffixes = ['', 'K', 'M', 'B', 'T', 'P', 'E']
34
+ magnitude = 0
35
+ while abs(num_params) >= 1000:
36
+ magnitude += 1
37
+ num_params /= 1000.0
38
+
39
+ format_string = '{:.' + str(precision) + 'f}{}'
40
+ formatted_value = format_string.format(num_params, suffixes[magnitude])
41
+
42
+ # 检查是否接近 1000,如果是,尝试使用更大的基数
43
+ if magnitude < len(suffixes) - 1 and num_params >= 1000 * (1 - 10 ** (-precision)):
44
+ magnitude += 1
45
+ num_params /= 1000.0
46
+ formatted_value = format_string.format(num_params, suffixes[magnitude])
47
+
48
+ return formatted_value
49
+
50
+
51
+ def count_parameters(
52
+ module: Module,
53
+ precision: int = 2,
54
+ return_table: bool = False,
55
+ ) -> Union[Tuple[PrettyTable, int], int]:
56
+ """
57
+ Count and display the number of trainable parameters in a neural network model.
58
+
59
+ This function iterates through all the parameters of the given model,
60
+ counts the number of parameters for each module, and displays them in a table.
61
+ It also calculates and returns the total number of trainable parameters.
62
+
63
+ Parameters:
64
+ -----------
65
+ model : bst.nn.Module
66
+ The neural network model for which to count parameters.
67
+
68
+ Returns:
69
+ --------
70
+ int
71
+ The total number of trainable parameters in the model.
72
+
73
+ Prints:
74
+ -------
75
+ A pretty-formatted table showing the number of parameters for each module,
76
+ followed by the total number of trainable parameters.
77
+ """
78
+ assert isinstance(module, Module), "Input must be a neural network module" # noqa: E501
79
+ table = PrettyTable(["Modules", "Parameters"])
80
+ total_params = 0
81
+ for name, parameter in module.states(ParamState).items():
82
+ param = parameter.numel()
83
+ table.add_row([name, _format_parameter_count(param, precision=precision)])
84
+ total_params += param
85
+ table.add_row(["Total", _format_parameter_count(total_params, precision=precision)])
86
+ print(table)
87
+ if return_table:
88
+ return table, total_params
89
+ return total_params
brainstate/nn/metrics.py CHANGED
@@ -16,13 +16,12 @@
16
16
 
17
17
  from __future__ import annotations
18
18
 
19
- import typing as tp
20
- from dataclasses import dataclass
21
- from functools import partial
22
-
23
19
  import jax
24
20
  import jax.numpy as jnp
25
21
  import numpy as np
22
+ import typing as tp
23
+ from dataclasses import dataclass
24
+ from functools import partial
26
25
 
27
26
  from brainstate._state import State
28
27
 
@@ -16,11 +16,10 @@
16
16
  # -*- coding: utf-8 -*-
17
17
  from __future__ import annotations
18
18
 
19
- from typing import Sequence, Union
20
-
21
19
  import jax
22
20
  import jax.numpy as jnp
23
21
  import numpy as np
22
+ from typing import Sequence, Union
24
23
 
25
24
  from brainstate import environ
26
25
  from brainstate._state import State, LongTermState
@@ -15,9 +15,8 @@
15
15
 
16
16
  from __future__ import annotations
17
17
 
18
- import unittest
19
-
20
18
  import jax.numpy as jnp
19
+ import unittest
21
20
 
22
21
  import brainstate as bst
23
22
 
@@ -37,7 +36,7 @@ class TestMultiStepLR(unittest.TestCase):
37
36
  self.assertTrue(jnp.allclose(r, 0.0001))
38
37
 
39
38
  def test2(self):
40
- lr = bst.transform.jit(bst.optim.MultiStepLR(0.1, [10, 20, 30], gamma=0.1))
39
+ lr = bst.compile.jit(bst.optim.MultiStepLR(0.1, [10, 20, 30], gamma=0.1))
41
40
  for i in range(40):
42
41
  r = lr(i)
43
42
  if i < 10:
@@ -15,10 +15,9 @@
15
15
 
16
16
  from __future__ import annotations
17
17
 
18
- import unittest
19
-
20
18
  import jax
21
19
  import optax
20
+ import unittest
22
21
 
23
22
  import brainstate as bst
24
23
 
@@ -16,12 +16,11 @@
16
16
  # -*- coding: utf-8 -*-
17
17
  from __future__ import annotations
18
18
 
19
- import functools
20
- from typing import Union, Dict, Optional, Tuple, Any, TypeVar
21
-
22
19
  import brainunit as u
20
+ import functools
23
21
  import jax
24
22
  import jax.numpy as jnp
23
+ from typing import Union, Dict, Optional, Tuple, Any, TypeVar
25
24
 
26
25
  from brainstate import environ
27
26
  from brainstate._state import State, LongTermState, StateDictManager
@@ -17,9 +17,8 @@
17
17
  # -*- coding: utf-8 -*-
18
18
  from __future__ import annotations
19
19
 
20
- from typing import Optional
21
-
22
20
  import numpy as np
21
+ from typing import Optional
23
22
 
24
23
  from brainstate.typing import DTypeLike, Size, SeedOrKey
25
24
  from ._rand_state import RandomState, DEFAULT
@@ -15,13 +15,12 @@
15
15
 
16
16
  from __future__ import annotations
17
17
 
18
- import platform
19
- import unittest
20
-
21
18
  import jax.numpy as jnp
22
19
  import jax.random as jr
23
20
  import numpy as np
21
+ import platform
24
22
  import pytest
23
+ import unittest
25
24
 
26
25
  import brainstate as bst
27
26