brainstate 0.1.0.post20250216__py2.py3-none-any.whl → 0.1.0.post20250217__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/nn/__init__.py +4 -0
- brainstate/nn/_common.py +0 -1
- brainstate/nn/_dynamics/_dynamics_base.py +7 -1
- brainstate/nn/_module.py +8 -3
- brainstate/nn/_utils.py +89 -0
- brainstate/surrogate.py +0 -3
- {brainstate-0.1.0.post20250216.dist-info → brainstate-0.1.0.post20250217.dist-info}/METADATA +1 -1
- {brainstate-0.1.0.post20250216.dist-info → brainstate-0.1.0.post20250217.dist-info}/RECORD +11 -10
- {brainstate-0.1.0.post20250216.dist-info → brainstate-0.1.0.post20250217.dist-info}/LICENSE +0 -0
- {brainstate-0.1.0.post20250216.dist-info → brainstate-0.1.0.post20250217.dist-info}/WHEEL +0 -0
- {brainstate-0.1.0.post20250216.dist-info → brainstate-0.1.0.post20250217.dist-info}/top_level.txt +0 -0
brainstate/nn/__init__.py
CHANGED
@@ -31,6 +31,8 @@ from ._interaction import *
|
|
31
31
|
from ._interaction import __all__ as interaction_all
|
32
32
|
from ._module import *
|
33
33
|
from ._module import __all__ as module_all
|
34
|
+
from ._utils import *
|
35
|
+
from ._utils import __all__ as utils_all
|
34
36
|
|
35
37
|
__all__ = (
|
36
38
|
['metrics']
|
@@ -42,6 +44,7 @@ __all__ = (
|
|
42
44
|
+ module_all
|
43
45
|
+ exp_euler_all
|
44
46
|
+ interaction_all
|
47
|
+
+ utils_all
|
45
48
|
)
|
46
49
|
|
47
50
|
del (
|
@@ -53,4 +56,5 @@ del (
|
|
53
56
|
module_all,
|
54
57
|
exp_euler_all,
|
55
58
|
interaction_all,
|
59
|
+
utils_all,
|
56
60
|
)
|
brainstate/nn/_common.py
CHANGED
@@ -34,9 +34,10 @@ For handling the delays:
|
|
34
34
|
"""
|
35
35
|
from __future__ import annotations
|
36
36
|
|
37
|
+
from typing import Any, Dict, Callable, Hashable, Optional, Union, TypeVar, TYPE_CHECKING
|
38
|
+
|
37
39
|
import brainunit as u
|
38
40
|
import numpy as np
|
39
|
-
from typing import Any, Dict, Callable, Hashable, Optional, Union, TypeVar, TYPE_CHECKING
|
40
41
|
|
41
42
|
from brainstate import environ
|
42
43
|
from brainstate._state import State
|
@@ -156,6 +157,11 @@ class Dynamics(Module):
|
|
156
157
|
# in-/out- size of neuron population
|
157
158
|
self.out_size = self.in_size
|
158
159
|
|
160
|
+
def __pretty_repr_item__(self, name, value):
    """Decide how one attribute is rendered in this object's pretty repr.

    For the private hook attributes ``_before_updates``, ``_after_updates``,
    ``_current_inputs`` and ``_delta_inputs``, this returns ``None`` when the
    value is ``None`` and otherwise a ``(name_without_leading_underscore, value)``
    pair. Any other attribute is handed to the parent class implementation.
    NOTE(review): a ``None`` return presumably tells the repr machinery to omit
    the attribute — confirm against the caller.
    """
    hook_names = ('_before_updates', '_after_updates', '_current_inputs', '_delta_inputs')
    if name not in hook_names:
        return super().__pretty_repr_item__(name, value)
    if value is None:
        return None
    # Drop the leading underscore so the public-looking name is displayed.
    return name[1:], value
|
164
|
+
|
159
165
|
@property
|
160
166
|
def varshape(self):
|
161
167
|
"""The shape of variables in the neuron group."""
|
brainstate/nn/_module.py
CHANGED
@@ -27,10 +27,11 @@ The basic classes include:
|
|
27
27
|
"""
|
28
28
|
from __future__ import annotations
|
29
29
|
|
30
|
-
import numpy as np
|
31
30
|
import warnings
|
32
31
|
from typing import Sequence, Optional, Tuple, Union, TYPE_CHECKING, Callable
|
33
32
|
|
33
|
+
import numpy as np
|
34
|
+
|
34
35
|
from brainstate._state import State
|
35
36
|
from brainstate.graph import Node, states, nodes, flatten
|
36
37
|
from brainstate.mixin import ParamDescriber, ParamDesc
|
@@ -112,7 +113,11 @@ class Module(Node, ParamDesc):
|
|
112
113
|
"""
|
113
114
|
The function to specify the updating rule.
|
114
115
|
"""
|
115
|
-
raise NotImplementedError(
|
116
|
+
raise NotImplementedError(
|
117
|
+
f'Subclass of {self.__class__.__name__} must implement "update" function. \n'
|
118
|
+
f'This instance is: \n'
|
119
|
+
f'{self}'
|
120
|
+
)
|
116
121
|
|
117
122
|
def __call__(self, *args, **kwargs):
|
118
123
|
return self.update(*args, **kwargs)
|
@@ -227,7 +232,7 @@ class Module(Node, ParamDesc):
|
|
227
232
|
|
228
233
|
def __pretty_repr_item__(self, name, value):
|
229
234
|
if name in ['_in_size', '_out_size', '_name']:
|
230
|
-
return
|
235
|
+
return None if value is None else (name[1:], value) # skip the first `_`
|
231
236
|
return name, value
|
232
237
|
|
233
238
|
|
brainstate/nn/_utils.py
ADDED
@@ -0,0 +1,89 @@
|
|
1
|
+
# Copyright 2025 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
# -*- coding: utf-8 -*-
|
17
|
+
|
18
|
+
from typing import Union, Tuple
|
19
|
+
|
20
|
+
from brainstate._state import ParamState
|
21
|
+
from brainstate.util import PrettyTable
|
22
|
+
from ._module import Module
|
23
|
+
|
24
|
+
__all__ = [
|
25
|
+
"count_parameters",
|
26
|
+
]
|
27
|
+
|
28
|
+
|
29
|
+
def _format_parameter_count(num_params, precision=2):
    """Format a parameter count as a short human-readable string.

    Counts below 1000 are returned unchanged as ``str(num_params)``. Larger
    counts are divided by powers of 1000 and suffixed with ``K``, ``M``,
    ``B``, ``T``, ``P`` or ``E`` (e.g. ``1_500_000 -> '1.50M'``).

    Parameters
    ----------
    num_params : int or float
        The number of parameters to format.
    precision : int, optional
        Number of digits after the decimal point for abbreviated values.
        Default is 2.

    Returns
    -------
    str
        The formatted count.
    """
    if num_params < 1000:
        return str(num_params)

    suffixes = ['', 'K', 'M', 'B', 'T', 'P', 'E']
    max_magnitude = len(suffixes) - 1
    magnitude = 0
    value = num_params
    # Stop at the largest known suffix so huge counts cannot index past the table.
    while abs(value) >= 1000 and magnitude < max_magnitude:
        magnitude += 1
        value /= 1000.0

    # Rounding to `precision` digits can turn e.g. 999.999 into "1000.00".
    # Only in that case promote to the next suffix ("1.00M" instead of "1000.00K").
    # The check uses the value exactly as it would be displayed.
    if magnitude < max_magnitude and abs(float(f'{value:.{precision}f}')) >= 1000:
        magnitude += 1
        value /= 1000.0

    return f'{value:.{precision}f}{suffixes[magnitude]}'
|
49
|
+
|
50
|
+
|
51
|
+
def count_parameters(
    module: Module,
    precision: int = 2,
    return_table: bool = False,
) -> Union[Tuple[PrettyTable, int], int]:
    """
    Count and display the number of parameters in a neural network module.

    Every ``ParamState`` returned by ``module.states(ParamState)`` is visited.
    Its element count (``numel()``) is added as one row of a table. A final
    ``Total`` row is appended, the table is printed, and the total is returned.

    Parameters
    ----------
    module : Module
        The module whose parameters are counted.
    precision : int, optional
        Number of decimal digits used when abbreviating large counts in the
        table (e.g. ``1.50M``). Default is 2.
    return_table : bool, optional
        If True, also return the table that was printed. Default is False.

    Returns
    -------
    int or tuple of (PrettyTable, int)
        The total number of parameters. When ``return_table`` is True, a
        ``(table, total)`` pair is returned instead.

    Raises
    ------
    TypeError
        If ``module`` is not a ``Module`` instance.

    Prints
    ------
    A table with columns ``Modules`` and ``Parameters`` (one row per parameter
    state), followed by a ``Total`` row.
    """
    # Explicit check instead of ``assert`` so validation survives ``python -O``.
    if not isinstance(module, Module):
        raise TypeError("Input must be a neural network module")

    table = PrettyTable(["Modules", "Parameters"])
    total_params = 0
    for name, state in module.states(ParamState).items():
        count = state.numel()
        table.add_row([name, _format_parameter_count(count, precision=precision)])
        total_params += count
    table.add_row(["Total", _format_parameter_count(total_params, precision=precision)])
    print(table)
    if return_table:
        return table, total_params
    return total_params
|
brainstate/surrogate.py
CHANGED
@@ -132,9 +132,6 @@ class Surrogate(PrettyObject):
|
|
132
132
|
dx = self.surrogate_grad(x)
|
133
133
|
return heaviside_p.bind(x, dx)[0]
|
134
134
|
|
135
|
-
def __repr__(self):
|
136
|
-
return f'{self.__class__.__name__}()'
|
137
|
-
|
138
135
|
def surrogate_fun(self, x) -> jax.Array:
|
139
136
|
"""The surrogate function."""
|
140
137
|
raise NotImplementedError
|
{brainstate-0.1.0.post20250216.dist-info → brainstate-0.1.0.post20250217.dist-info}/METADATA
RENAMED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: brainstate
|
3
|
-
Version: 0.1.0.post20250216
|
3
|
+
Version: 0.1.0.post20250217
|
4
4
|
Summary: A ``State``-based Transformation System for Program Compilation and Augmentation.
|
5
5
|
Home-page: https://github.com/chaobrain/brainstate
|
6
6
|
Author: BrainState Developers
|
@@ -6,7 +6,7 @@ brainstate/environ.py,sha256=PllYYZKqany3G7NzIwoUPplLAePbyza6kJGXTPgJK-c,17698
|
|
6
6
|
brainstate/environ_test.py,sha256=khZ_-SUJL6rQCgndeYV98ruUIHGTwFDtITrOs_olmuo,2043
|
7
7
|
brainstate/mixin.py,sha256=g7uVUwZphZWsNs9pb48ozG2cDGaj0hs0g3lq8tDk-Sg,11310
|
8
8
|
brainstate/mixin_test.py,sha256=Oq_0fwC9vpXDN4t4dTBhWzLdFDNlcYsrcip14F1yECI,3079
|
9
|
-
brainstate/surrogate.py,sha256=
|
9
|
+
brainstate/surrogate.py,sha256=wWYw-TxaFxHVneXuHjWD1UtTcOTk3XRSnhRtUkt_Hb8,53580
|
10
10
|
brainstate/transform.py,sha256=vZWzO4F7qTsXL_SiVQPlTz0l9b_hRo9D-igETfgCTy0,758
|
11
11
|
brainstate/typing.py,sha256=988gX1tvwtyYnYjmej90OaRxoMoBIPO0-DSrXXGxojM,10523
|
12
12
|
brainstate/augment/__init__.py,sha256=Q9-JIwQ1FNn8VLS1MA9MrSylbvUjWSw98whrI3NIuKo,1229
|
@@ -53,14 +53,15 @@ brainstate/init/_random_inits.py,sha256=gyy9481ju7VEi-SFbSRU5iBACaHnf4wjI0596FNu
|
|
53
53
|
brainstate/init/_random_inits_test.py,sha256=lBL2RQdBSZ88Zqz4IMdbHJMvDi7ooZq6caCpHfNtIRk,5197
|
54
54
|
brainstate/init/_regular_inits.py,sha256=DmVMajugfyYFNUMzgFdDKMvbBu9hMWxkfDd-50uhoLg,3187
|
55
55
|
brainstate/init/_regular_inits_test.py,sha256=tJl4aOkclllJIfKzJTbc0cfYCw2SoBsx8_G123RnqbU,1842
|
56
|
-
brainstate/nn/__init__.py,sha256=
|
56
|
+
brainstate/nn/__init__.py,sha256=ar1hDUYbSO6oadMpbuS9FWZvZB_iyFzM8CwMK-RNDzM,1823
|
57
57
|
brainstate/nn/_collective_ops.py,sha256=yQNBnh-XVEFnTg-Ga14mHOCGtGxiTkL9MYKdNjJF1BI,17535
|
58
58
|
brainstate/nn/_collective_ops_test.py,sha256=yW7NNYsGFglFRFkqVlpGSY6WLnU-h8GlK6wCmG5jtRc,1189
|
59
|
-
brainstate/nn/_common.py,sha256=
|
59
|
+
brainstate/nn/_common.py,sha256=XQw0i0sH3Y_qUwHSMC7G9VQnDj-RuuTh1Ul-xRIPxxc,7136
|
60
60
|
brainstate/nn/_exp_euler.py,sha256=cRgPNcjMs2C9x_8JabtYz5hm_FwqbiJ_U1VfRHYIlrE,3519
|
61
61
|
brainstate/nn/_exp_euler_test.py,sha256=kvPf009DMYtla2uedKVKrPTHDyMTBepjlfsk5vDHqhI,1240
|
62
|
-
brainstate/nn/_module.py,sha256=
|
62
|
+
brainstate/nn/_module.py,sha256=vrukVI0ylbymzilh9BZtb-d9dnsBsykqanUNTx9Eb6Y,12844
|
63
63
|
brainstate/nn/_module_test.py,sha256=UrVA85fo0KVFN9ApPkxkRcvtXEskWOXPzZIBa4JSFo0,8891
|
64
|
+
brainstate/nn/_utils.py,sha256=epfELIy1COgdS9z5be-fmbFhagNugcIHpw4ww-HlkSY,3123
|
64
65
|
brainstate/nn/metrics.py,sha256=p7eVwd5y8r0N5rMws-zOS_KaZCLOMdrXyQvLnoJeq1w,14736
|
65
66
|
brainstate/nn/_dyn_impl/__init__.py,sha256=Oazar7h89dp1WA2Vx4Tj7gCBhxJKH4LAUEABkBEG7vU,1462
|
66
67
|
brainstate/nn/_dyn_impl/_dynamics_neuron.py,sha256=mcDxVZlk56NAEkR6xcE74hOZ9up8Rua4SvKEeAhJKU4,10925
|
@@ -74,7 +75,7 @@ brainstate/nn/_dyn_impl/_rate_rnns_test.py,sha256=EYUajj50PL1V_yuQIANs6sTWfSDrXK
|
|
74
75
|
brainstate/nn/_dyn_impl/_readout.py,sha256=iXqOahtDaLgMuMYXdMT-0scMPLHK1-fYeBb8SfEShY0,4368
|
75
76
|
brainstate/nn/_dyn_impl/_readout_test.py,sha256=QYXoWlTXVwJoIVqAXm5UYF5bjHUMkY4bKqWWyXzXF10,2107
|
76
77
|
brainstate/nn/_dynamics/__init__.py,sha256=j1HSWu01wf5-KjSaNhBC9utVGDALOhUsFPrLPcPPDsM,1208
|
77
|
-
brainstate/nn/_dynamics/_dynamics_base.py,sha256=
|
78
|
+
brainstate/nn/_dynamics/_dynamics_base.py,sha256=rLjTgA_826EfD1OZ-NoEoO11EBNJOpY_Fq2YfDKXRe4,22288
|
78
79
|
brainstate/nn/_dynamics/_dynamics_base_test.py,sha256=Sk6sSqJK_yesI-6Fb_x7gqMsoc0-RUU9GsGZu2jVsxU,2719
|
79
80
|
brainstate/nn/_dynamics/_projection_base.py,sha256=jYe3WdBMgz2TJkcxPWEkyK7OA4IR1ChISd2GTfM6U2o,13528
|
80
81
|
brainstate/nn/_dynamics/_state_delay.py,sha256=qKF1YelGXeBlImhIBdZIC0CAg2dV0o_gkaxS-R1N3qE,16905
|
@@ -120,8 +121,8 @@ brainstate/util/_pretty_table.py,sha256=NM_6VAW6oL9jojsK0-RkQGHnDzLy_fn_hgzl5R8o
|
|
120
121
|
brainstate/util/_scaling.py,sha256=pc_eM_SZVwkY65I4tJh1ODiHNCoEhsfFXl2zBK0PLAg,7562
|
121
122
|
brainstate/util/_struct.py,sha256=F5GfFURITAIYTwf17_xypkZU1wvoL4dUCviPnr_eCtw,17515
|
122
123
|
brainstate/util/filter.py,sha256=Zw0H42NwAi2P7dBr3ISv2VpkB5jqoWnV4Kpd61gq66o,14126
|
123
|
-
brainstate-0.1.0.
|
124
|
-
brainstate-0.1.0.
|
125
|
-
brainstate-0.1.0.
|
126
|
-
brainstate-0.1.0.
|
127
|
-
brainstate-0.1.0.
|
124
|
+
brainstate-0.1.0.post20250217.dist-info/LICENSE,sha256=VZe9u1jgUL2eCY6ZPOYgdb8KCblCHt8ECdbtJid6e1s,11550
|
125
|
+
brainstate-0.1.0.post20250217.dist-info/METADATA,sha256=eBHppB4DxysakCjcinzG4vVSxI_eS08l7LkSepmGGXI,3585
|
126
|
+
brainstate-0.1.0.post20250217.dist-info/WHEEL,sha256=bb2Ot9scclHKMOLDEHY6B2sicWOgugjFKaJsT7vwMQo,110
|
127
|
+
brainstate-0.1.0.post20250217.dist-info/top_level.txt,sha256=eQbGgKn0ptx7FDWuua0V0wr4K1VHi2iOUCYo3fUQBRA,11
|
128
|
+
brainstate-0.1.0.post20250217.dist-info/RECORD,,
|
File without changes
|
File without changes
|
{brainstate-0.1.0.post20250216.dist-info → brainstate-0.1.0.post20250217.dist-info}/top_level.txt
RENAMED
File without changes
|