brainstate 0.1.0.post20250216__py2.py3-none-any.whl → 0.1.0.post20250217__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
brainstate/nn/__init__.py CHANGED
@@ -31,6 +31,8 @@ from ._interaction import *
31
31
  from ._interaction import __all__ as interaction_all
32
32
  from ._module import *
33
33
  from ._module import __all__ as module_all
34
+ from ._utils import *
35
+ from ._utils import __all__ as utils_all
34
36
 
35
37
  __all__ = (
36
38
  ['metrics']
@@ -42,6 +44,7 @@ __all__ = (
42
44
  + module_all
43
45
  + exp_euler_all
44
46
  + interaction_all
47
+ + utils_all
45
48
  )
46
49
 
47
50
  del (
@@ -53,4 +56,5 @@ del (
53
56
  module_all,
54
57
  exp_euler_all,
55
58
  interaction_all,
59
+ utils_all,
56
60
  )
brainstate/nn/_common.py CHANGED
@@ -174,7 +174,6 @@ class Vmap(Module):
174
174
  axis_size=axis_size,
175
175
  )
176
176
  def vmap_run(*args, **kwargs):
177
- vmap_states_ = vmap_states
178
177
  return module(*args, **kwargs)
179
178
 
180
179
  # vmapped module
@@ -34,9 +34,10 @@ For handling the delays:
34
34
  """
35
35
  from __future__ import annotations
36
36
 
37
+ from typing import Any, Dict, Callable, Hashable, Optional, Union, TypeVar, TYPE_CHECKING
38
+
37
39
  import brainunit as u
38
40
  import numpy as np
39
- from typing import Any, Dict, Callable, Hashable, Optional, Union, TypeVar, TYPE_CHECKING
40
41
 
41
42
  from brainstate import environ
42
43
  from brainstate._state import State
@@ -156,6 +157,11 @@ class Dynamics(Module):
156
157
  # in-/out- size of neuron population
157
158
  self.out_size = self.in_size
158
159
 
160
+ def __pretty_repr_item__(self, name, value):
161
+ if name in ['_before_updates', '_after_updates', '_current_inputs', '_delta_inputs']:
162
+ return None if value is None else (name[1:], value) # skip the first `_`
163
+ return super().__pretty_repr_item__(name, value)
164
+
159
165
  @property
160
166
  def varshape(self):
161
167
  """The shape of variables in the neuron group."""
brainstate/nn/_module.py CHANGED
@@ -27,10 +27,11 @@ The basic classes include:
27
27
  """
28
28
  from __future__ import annotations
29
29
 
30
- import numpy as np
31
30
  import warnings
32
31
  from typing import Sequence, Optional, Tuple, Union, TYPE_CHECKING, Callable
33
32
 
33
+ import numpy as np
34
+
34
35
  from brainstate._state import State
35
36
  from brainstate.graph import Node, states, nodes, flatten
36
37
  from brainstate.mixin import ParamDescriber, ParamDesc
@@ -112,7 +113,11 @@ class Module(Node, ParamDesc):
112
113
  """
113
114
  The function to specify the updating rule.
114
115
  """
115
- raise NotImplementedError(f'Subclass of {self.__class__.__name__} must implement "update" function.')
116
+ raise NotImplementedError(
117
+ f'Subclass of {self.__class__.__name__} must implement "update" function. \n'
118
+ f'This instance is: \n'
119
+ f'{self}'
120
+ )
116
121
 
117
122
  def __call__(self, *args, **kwargs):
118
123
  return self.update(*args, **kwargs)
@@ -227,7 +232,7 @@ class Module(Node, ParamDesc):
227
232
 
228
233
  def __pretty_repr_item__(self, name, value):
229
234
  if name in ['_in_size', '_out_size', '_name']:
230
- return (name, value) if value is None else (name[1:], value) # skip the first `_`
235
+ return None if value is None else (name[1:], value) # skip the first `_`
231
236
  return name, value
232
237
 
233
238
 
@@ -0,0 +1,89 @@
1
+ # Copyright 2025 BDP Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ # -*- coding: utf-8 -*-
17
+
18
+ from typing import Union, Tuple
19
+
20
+ from brainstate._state import ParamState
21
+ from brainstate.util import PrettyTable
22
+ from ._module import Module
23
+
24
+ __all__ = [
25
+ "count_parameters",
26
+ ]
27
+
28
+
29
+ def _format_parameter_count(num_params, precision=2):
30
+ if num_params < 1000:
31
+ return str(num_params)
32
+
33
+ suffixes = ['', 'K', 'M', 'B', 'T', 'P', 'E']
34
+ magnitude = 0
35
+ while abs(num_params) >= 1000:
36
+ magnitude += 1
37
+ num_params /= 1000.0
38
+
39
+ format_string = '{:.' + str(precision) + 'f}{}'
40
+ formatted_value = format_string.format(num_params, suffixes[magnitude])
41
+
42
+ # 检查是否接近 1000,如果是,尝试使用更大的基数
43
+ if magnitude < len(suffixes) - 1 and num_params >= 1000 * (1 - 10 ** (-precision)):
44
+ magnitude += 1
45
+ num_params /= 1000.0
46
+ formatted_value = format_string.format(num_params, suffixes[magnitude])
47
+
48
+ return formatted_value
49
+
50
+
51
+ def count_parameters(
52
+ module: Module,
53
+ precision: int = 2,
54
+ return_table: bool = False,
55
+ ) -> Union[Tuple[PrettyTable, int], int]:
56
+ """
57
+ Count and display the number of trainable parameters in a neural network model.
58
+
59
+ This function iterates through all the parameters of the given model,
60
+ counts the number of parameters for each module, and displays them in a table.
61
+ It also calculates and returns the total number of trainable parameters.
62
+
63
+ Parameters:
64
+ -----------
65
+ model : bst.nn.Module
66
+ The neural network model for which to count parameters.
67
+
68
+ Returns:
69
+ --------
70
+ int
71
+ The total number of trainable parameters in the model.
72
+
73
+ Prints:
74
+ -------
75
+ A pretty-formatted table showing the number of parameters for each module,
76
+ followed by the total number of trainable parameters.
77
+ """
78
+ assert isinstance(module, Module), "Input must be a neural network module" # noqa: E501
79
+ table = PrettyTable(["Modules", "Parameters"])
80
+ total_params = 0
81
+ for name, parameter in module.states(ParamState).items():
82
+ param = parameter.numel()
83
+ table.add_row([name, _format_parameter_count(param, precision=precision)])
84
+ total_params += param
85
+ table.add_row(["Total", _format_parameter_count(total_params, precision=precision)])
86
+ print(table)
87
+ if return_table:
88
+ return table, total_params
89
+ return total_params
brainstate/surrogate.py CHANGED
@@ -132,9 +132,6 @@ class Surrogate(PrettyObject):
132
132
  dx = self.surrogate_grad(x)
133
133
  return heaviside_p.bind(x, dx)[0]
134
134
 
135
- def __repr__(self):
136
- return f'{self.__class__.__name__}()'
137
-
138
135
  def surrogate_fun(self, x) -> jax.Array:
139
136
  """The surrogate function."""
140
137
  raise NotImplementedError
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: brainstate
3
- Version: 0.1.0.post20250216
3
+ Version: 0.1.0.post20250217
4
4
  Summary: A ``State``-based Transformation System for Program Compilation and Augmentation.
5
5
  Home-page: https://github.com/chaobrain/brainstate
6
6
  Author: BrainState Developers
@@ -6,7 +6,7 @@ brainstate/environ.py,sha256=PllYYZKqany3G7NzIwoUPplLAePbyza6kJGXTPgJK-c,17698
6
6
  brainstate/environ_test.py,sha256=khZ_-SUJL6rQCgndeYV98ruUIHGTwFDtITrOs_olmuo,2043
7
7
  brainstate/mixin.py,sha256=g7uVUwZphZWsNs9pb48ozG2cDGaj0hs0g3lq8tDk-Sg,11310
8
8
  brainstate/mixin_test.py,sha256=Oq_0fwC9vpXDN4t4dTBhWzLdFDNlcYsrcip14F1yECI,3079
9
- brainstate/surrogate.py,sha256=OiI7l51s-yjxOdPjTAwWyBljDq9hRTadjZGpCF8zVkc,53654
9
+ brainstate/surrogate.py,sha256=wWYw-TxaFxHVneXuHjWD1UtTcOTk3XRSnhRtUkt_Hb8,53580
10
10
  brainstate/transform.py,sha256=vZWzO4F7qTsXL_SiVQPlTz0l9b_hRo9D-igETfgCTy0,758
11
11
  brainstate/typing.py,sha256=988gX1tvwtyYnYjmej90OaRxoMoBIPO0-DSrXXGxojM,10523
12
12
  brainstate/augment/__init__.py,sha256=Q9-JIwQ1FNn8VLS1MA9MrSylbvUjWSw98whrI3NIuKo,1229
@@ -53,14 +53,15 @@ brainstate/init/_random_inits.py,sha256=gyy9481ju7VEi-SFbSRU5iBACaHnf4wjI0596FNu
53
53
  brainstate/init/_random_inits_test.py,sha256=lBL2RQdBSZ88Zqz4IMdbHJMvDi7ooZq6caCpHfNtIRk,5197
54
54
  brainstate/init/_regular_inits.py,sha256=DmVMajugfyYFNUMzgFdDKMvbBu9hMWxkfDd-50uhoLg,3187
55
55
  brainstate/init/_regular_inits_test.py,sha256=tJl4aOkclllJIfKzJTbc0cfYCw2SoBsx8_G123RnqbU,1842
56
- brainstate/nn/__init__.py,sha256=Eb9q19tR29RAAvpjAQWubxfubzMzCY2gAHVpi6zSBjI,1725
56
+ brainstate/nn/__init__.py,sha256=ar1hDUYbSO6oadMpbuS9FWZvZB_iyFzM8CwMK-RNDzM,1823
57
57
  brainstate/nn/_collective_ops.py,sha256=yQNBnh-XVEFnTg-Ga14mHOCGtGxiTkL9MYKdNjJF1BI,17535
58
58
  brainstate/nn/_collective_ops_test.py,sha256=yW7NNYsGFglFRFkqVlpGSY6WLnU-h8GlK6wCmG5jtRc,1189
59
- brainstate/nn/_common.py,sha256=Decmt8uwhIEJ5ODpl8gwUXSjSTYQcP9n0i7p3kt7eIo,7176
59
+ brainstate/nn/_common.py,sha256=XQw0i0sH3Y_qUwHSMC7G9VQnDj-RuuTh1Ul-xRIPxxc,7136
60
60
  brainstate/nn/_exp_euler.py,sha256=cRgPNcjMs2C9x_8JabtYz5hm_FwqbiJ_U1VfRHYIlrE,3519
61
61
  brainstate/nn/_exp_euler_test.py,sha256=kvPf009DMYtla2uedKVKrPTHDyMTBepjlfsk5vDHqhI,1240
62
- brainstate/nn/_module.py,sha256=wJvgwztwjIQT_1o4o-DSjT081nEpfLDfmGWvg_-K5Ro,12764
62
+ brainstate/nn/_module.py,sha256=vrukVI0ylbymzilh9BZtb-d9dnsBsykqanUNTx9Eb6Y,12844
63
63
  brainstate/nn/_module_test.py,sha256=UrVA85fo0KVFN9ApPkxkRcvtXEskWOXPzZIBa4JSFo0,8891
64
+ brainstate/nn/_utils.py,sha256=epfELIy1COgdS9z5be-fmbFhagNugcIHpw4ww-HlkSY,3123
64
65
  brainstate/nn/metrics.py,sha256=p7eVwd5y8r0N5rMws-zOS_KaZCLOMdrXyQvLnoJeq1w,14736
65
66
  brainstate/nn/_dyn_impl/__init__.py,sha256=Oazar7h89dp1WA2Vx4Tj7gCBhxJKH4LAUEABkBEG7vU,1462
66
67
  brainstate/nn/_dyn_impl/_dynamics_neuron.py,sha256=mcDxVZlk56NAEkR6xcE74hOZ9up8Rua4SvKEeAhJKU4,10925
@@ -74,7 +75,7 @@ brainstate/nn/_dyn_impl/_rate_rnns_test.py,sha256=EYUajj50PL1V_yuQIANs6sTWfSDrXK
74
75
  brainstate/nn/_dyn_impl/_readout.py,sha256=iXqOahtDaLgMuMYXdMT-0scMPLHK1-fYeBb8SfEShY0,4368
75
76
  brainstate/nn/_dyn_impl/_readout_test.py,sha256=QYXoWlTXVwJoIVqAXm5UYF5bjHUMkY4bKqWWyXzXF10,2107
76
77
  brainstate/nn/_dynamics/__init__.py,sha256=j1HSWu01wf5-KjSaNhBC9utVGDALOhUsFPrLPcPPDsM,1208
77
- brainstate/nn/_dynamics/_dynamics_base.py,sha256=XADbPoPjNR_Uwx4oSwoA-qMuYwFR3gJRCnR4_gHtq4w,21994
78
+ brainstate/nn/_dynamics/_dynamics_base.py,sha256=rLjTgA_826EfD1OZ-NoEoO11EBNJOpY_Fq2YfDKXRe4,22288
78
79
  brainstate/nn/_dynamics/_dynamics_base_test.py,sha256=Sk6sSqJK_yesI-6Fb_x7gqMsoc0-RUU9GsGZu2jVsxU,2719
79
80
  brainstate/nn/_dynamics/_projection_base.py,sha256=jYe3WdBMgz2TJkcxPWEkyK7OA4IR1ChISd2GTfM6U2o,13528
80
81
  brainstate/nn/_dynamics/_state_delay.py,sha256=qKF1YelGXeBlImhIBdZIC0CAg2dV0o_gkaxS-R1N3qE,16905
@@ -120,8 +121,8 @@ brainstate/util/_pretty_table.py,sha256=NM_6VAW6oL9jojsK0-RkQGHnDzLy_fn_hgzl5R8o
120
121
  brainstate/util/_scaling.py,sha256=pc_eM_SZVwkY65I4tJh1ODiHNCoEhsfFXl2zBK0PLAg,7562
121
122
  brainstate/util/_struct.py,sha256=F5GfFURITAIYTwf17_xypkZU1wvoL4dUCviPnr_eCtw,17515
122
123
  brainstate/util/filter.py,sha256=Zw0H42NwAi2P7dBr3ISv2VpkB5jqoWnV4Kpd61gq66o,14126
123
- brainstate-0.1.0.post20250216.dist-info/LICENSE,sha256=VZe9u1jgUL2eCY6ZPOYgdb8KCblCHt8ECdbtJid6e1s,11550
124
- brainstate-0.1.0.post20250216.dist-info/METADATA,sha256=pLMfvR_kUK1DsX9tcUEvX1jCX2zdWj8YUx6YrFF7iug,3585
125
- brainstate-0.1.0.post20250216.dist-info/WHEEL,sha256=bb2Ot9scclHKMOLDEHY6B2sicWOgugjFKaJsT7vwMQo,110
126
- brainstate-0.1.0.post20250216.dist-info/top_level.txt,sha256=eQbGgKn0ptx7FDWuua0V0wr4K1VHi2iOUCYo3fUQBRA,11
127
- brainstate-0.1.0.post20250216.dist-info/RECORD,,
124
+ brainstate-0.1.0.post20250217.dist-info/LICENSE,sha256=VZe9u1jgUL2eCY6ZPOYgdb8KCblCHt8ECdbtJid6e1s,11550
125
+ brainstate-0.1.0.post20250217.dist-info/METADATA,sha256=eBHppB4DxysakCjcinzG4vVSxI_eS08l7LkSepmGGXI,3585
126
+ brainstate-0.1.0.post20250217.dist-info/WHEEL,sha256=bb2Ot9scclHKMOLDEHY6B2sicWOgugjFKaJsT7vwMQo,110
127
+ brainstate-0.1.0.post20250217.dist-info/top_level.txt,sha256=eQbGgKn0ptx7FDWuua0V0wr4K1VHi2iOUCYo3fUQBRA,11
128
+ brainstate-0.1.0.post20250217.dist-info/RECORD,,