brainstate 0.0.1.1.post20240708__py2.py3-none-any.whl → 0.0.1.1.post20240802__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/_state.py +22 -4
- brainstate/init/_generic.py +1 -1
- brainstate/mixin.py +8 -8
- brainstate/mixin_test.py +2 -0
- brainstate/nn/_elementwise.py +2 -2
- brainstate/random.py +27 -14
- brainstate/transform/_control.py +2 -2
- {brainstate-0.0.1.1.post20240708.dist-info → brainstate-0.0.1.1.post20240802.dist-info}/METADATA +2 -2
- {brainstate-0.0.1.1.post20240708.dist-info → brainstate-0.0.1.1.post20240802.dist-info}/RECORD +12 -12
- {brainstate-0.0.1.1.post20240708.dist-info → brainstate-0.0.1.1.post20240802.dist-info}/LICENSE +0 -0
- {brainstate-0.0.1.1.post20240708.dist-info → brainstate-0.0.1.1.post20240802.dist-info}/WHEEL +0 -0
- {brainstate-0.0.1.1.post20240708.dist-info → brainstate-0.0.1.1.post20240802.dist-info}/top_level.txt +0 -0
brainstate/_state.py
CHANGED
@@ -15,7 +15,7 @@
|
|
15
15
|
|
16
16
|
import contextlib
|
17
17
|
import threading
|
18
|
-
from typing import Any, Tuple, Dict, List, Callable
|
18
|
+
from typing import Any, Tuple, Dict, List, Callable, Optional
|
19
19
|
|
20
20
|
import jax
|
21
21
|
import numpy as np
|
@@ -108,9 +108,9 @@ class State(object):
|
|
108
108
|
value: PyTree. It can be anything as a pyTree.
|
109
109
|
"""
|
110
110
|
__module__ = 'brainstate'
|
111
|
-
__slots__ = ('_value', '_tree', '_level', '_source_info', '_check_tree')
|
111
|
+
__slots__ = ('_value', '_name', '_tree', '_level', '_source_info', '_check_tree')
|
112
112
|
|
113
|
-
def __init__(self, value: PyTree):
|
113
|
+
def __init__(self, value: PyTree, name: Optional[str] = None):
|
114
114
|
if isinstance(value, State):
|
115
115
|
value = value.value
|
116
116
|
self._value = value
|
@@ -118,6 +118,21 @@ class State(object):
|
|
118
118
|
self._check_tree = False
|
119
119
|
self._level = len(thread_local_stack.stack)
|
120
120
|
self._source_info = source_info_util.current()
|
121
|
+
self._name = name
|
122
|
+
|
123
|
+
@property
|
124
|
+
def name(self) -> Optional[str]:
|
125
|
+
"""
|
126
|
+
The name of the state.
|
127
|
+
"""
|
128
|
+
return self._name
|
129
|
+
|
130
|
+
@name.setter
|
131
|
+
def name(self, name: str) -> None:
|
132
|
+
"""
|
133
|
+
Set the name of the state.
|
134
|
+
"""
|
135
|
+
self._name = name
|
121
136
|
|
122
137
|
@property
|
123
138
|
def value(self) -> PyTree:
|
@@ -210,7 +225,10 @@ class State(object):
|
|
210
225
|
leaves, tree = jax.tree.flatten(self._value)
|
211
226
|
leaves_info = [ShapeDtype(leaf.shape, leaf.dtype) for leaf in leaves]
|
212
227
|
tree_info = jax.tree.unflatten(tree, leaves_info)
|
213
|
-
|
228
|
+
if self.name is None:
|
229
|
+
return f'{self.__class__.__name__}({tree_info})'
|
230
|
+
else:
|
231
|
+
return f'{self.__class__.__name__}({self.name}: {tree_info})'
|
214
232
|
|
215
233
|
|
216
234
|
class ShapeDtype:
|
brainstate/init/_generic.py
CHANGED
@@ -83,7 +83,7 @@ def _expand_params_to_match_sizes(params, sizes):
|
|
83
83
|
|
84
84
|
|
85
85
|
def param(
|
86
|
-
parameter: Union[Callable, ArrayLike],
|
86
|
+
parameter: Union[Callable, ArrayLike, State],
|
87
87
|
sizes: Union[int, Sequence[int]],
|
88
88
|
batch_size: Optional[int] = None,
|
89
89
|
allow_none: bool = True,
|
brainstate/mixin.py
CHANGED
@@ -207,7 +207,7 @@ class _JointGenericAlias(_UnionGenericAlias, _root=True):
|
|
207
207
|
|
208
208
|
@_SpecialForm
|
209
209
|
def JointTypes(self, parameters):
|
210
|
-
"""
|
210
|
+
"""Joint types; JointTypes[X, Y] means both X and Y.
|
211
211
|
|
212
212
|
To define a union, use e.g. Union[int, str].
|
213
213
|
|
@@ -216,28 +216,28 @@ def JointTypes(self, parameters):
|
|
216
216
|
- None as an argument is a special case and is replaced by `type(None)`.
|
217
217
|
- Unions of unions are flattened, e.g.::
|
218
218
|
|
219
|
-
|
219
|
+
JointTypes[JointTypes[int, str], float] == JointTypes[int, str, float]
|
220
220
|
|
221
221
|
- Unions of a single argument vanish, e.g.::
|
222
222
|
|
223
|
-
|
223
|
+
JointTypes[int] == int # The constructor actually returns int
|
224
224
|
|
225
225
|
- Redundant arguments are skipped, e.g.::
|
226
226
|
|
227
|
-
|
227
|
+
JointTypes[int, str, int] == JointTypes[int, str]
|
228
228
|
|
229
229
|
- When comparing unions, the argument order is ignored, e.g.::
|
230
230
|
|
231
|
-
|
231
|
+
JointTypes[int, str] == JointTypes[str, int]
|
232
232
|
|
233
|
-
- You cannot subclass or instantiate a
|
234
|
-
- You can use Optional[X] as a shorthand for
|
233
|
+
- You cannot subclass or instantiate a JointTypes.
|
234
|
+
- You can use Optional[X] as a shorthand for JointTypes[X, None].
|
235
235
|
"""
|
236
236
|
if parameters == ():
|
237
237
|
raise TypeError("Cannot take a Joint of no types.")
|
238
238
|
if not isinstance(parameters, tuple):
|
239
239
|
parameters = (parameters,)
|
240
|
-
msg = "
|
240
|
+
msg = "JointTypes[arg, ...]: each arg must be a type."
|
241
241
|
parameters = tuple(_type_check(p, msg) for p in parameters)
|
242
242
|
parameters = _remove_dups_flatten(parameters)
|
243
243
|
if len(parameters) == 1:
|
brainstate/mixin_test.py
CHANGED
brainstate/nn/_elementwise.py
CHANGED
@@ -1139,13 +1139,13 @@ class Dropout(Module, ElementWiseBlock):
|
|
1139
1139
|
name: Optional[str] = None
|
1140
1140
|
) -> None:
|
1141
1141
|
super().__init__(mode=mode, name=name)
|
1142
|
-
assert 0. <= prob
|
1142
|
+
assert 0. <= prob <= 1., f"Dropout probability must be in the range [0, 1]. But got {prob}."
|
1143
1143
|
self.prob = prob
|
1144
1144
|
|
1145
1145
|
def __call__(self, x):
|
1146
1146
|
dtype = bu.math.get_dtype(x)
|
1147
1147
|
fit_phase = environ.get('fit', desc='Whether this is a fitting process. Bool.')
|
1148
|
-
if fit_phase
|
1148
|
+
if fit_phase and self.prob < 1.:
|
1149
1149
|
keep_mask = random.bernoulli(self.prob, x.shape)
|
1150
1150
|
return jnp.where(keep_mask,
|
1151
1151
|
jnp.asarray(x / self.prob, dtype=dtype),
|
brainstate/random.py
CHANGED
@@ -1167,23 +1167,32 @@ def default_rng(seed_or_key=None, clone: bool = True) -> RandomState:
|
|
1167
1167
|
return RandomState(seed_or_key)
|
1168
1168
|
|
1169
1169
|
|
1170
|
-
def seed(
|
1170
|
+
def seed(seed_or_key: int = None):
|
1171
1171
|
"""Sets a new random seed.
|
1172
1172
|
|
1173
1173
|
Parameters
|
1174
1174
|
----------
|
1175
|
-
|
1176
|
-
The random seed.
|
1175
|
+
seed_or_key: int, optional
|
1176
|
+
The random seed (an integer) or jax random key.
|
1177
1177
|
"""
|
1178
1178
|
with jax.ensure_compile_time_eval():
|
1179
|
-
if
|
1180
|
-
|
1181
|
-
|
1182
|
-
|
1179
|
+
if seed_or_key is None:
|
1180
|
+
seed_or_key = np.random.randint(0, 100000)
|
1181
|
+
|
1182
|
+
# numpy random seed
|
1183
|
+
if np.size(seed_or_key) == 1: # seed
|
1184
|
+
np.random.seed(seed_or_key)
|
1185
|
+
elif np.size(seed_or_key) == 2: # jax random key
|
1186
|
+
np.random.seed(seed_or_key[0])
|
1187
|
+
else:
|
1188
|
+
raise ValueError(f"seed_or_key should be an integer or a tuple of two integers.")
|
1189
|
+
|
1190
|
+
# jax random seed
|
1191
|
+
DEFAULT.seed(seed_or_key)
|
1183
1192
|
|
1184
1193
|
|
1185
1194
|
@contextmanager
|
1186
|
-
def seed_context(
|
1195
|
+
def seed_context(seed_or_key: SeedOrKey):
|
1187
1196
|
"""
|
1188
1197
|
A context manager that sets the random seed for the duration of the block.
|
1189
1198
|
|
@@ -1206,16 +1215,19 @@ def seed_context(seed: int):
|
|
1206
1215
|
The context manager does not only set the seed for the AX random state, but also for the numpy random state.
|
1207
1216
|
|
1208
1217
|
Args:
|
1209
|
-
|
1218
|
+
seed_or_key: The seed (an integer) or jax random key.
|
1210
1219
|
|
1211
|
-
Returns:
|
1212
|
-
The random state.
|
1213
1220
|
"""
|
1214
1221
|
old_jrand_key = DEFAULT.value
|
1215
1222
|
old_np_state = np.random.get_state()
|
1216
1223
|
try:
|
1217
|
-
np.
|
1218
|
-
|
1224
|
+
if np.size(seed_or_key) == 1: # seed
|
1225
|
+
np.random.seed(seed_or_key)
|
1226
|
+
elif np.size(seed_or_key) == 2: # jax random key
|
1227
|
+
np.random.seed(seed_or_key[0])
|
1228
|
+
else:
|
1229
|
+
raise ValueError(f"seed_or_key should be an integer or a tuple of two integers.")
|
1230
|
+
DEFAULT.seed(seed_or_key)
|
1219
1231
|
yield
|
1220
1232
|
finally:
|
1221
1233
|
np.random.set_state(old_np_state)
|
@@ -1223,7 +1235,8 @@ def seed_context(seed: int):
|
|
1223
1235
|
|
1224
1236
|
|
1225
1237
|
def rand(*dn, key: Optional[SeedOrKey] = None, dtype: DTypeLike = None):
|
1226
|
-
r"""
|
1238
|
+
r"""
|
1239
|
+
Random values in a given shape.
|
1227
1240
|
|
1228
1241
|
.. note::
|
1229
1242
|
This is a convenience function for users porting code from Matlab,
|
brainstate/transform/_control.py
CHANGED
@@ -25,7 +25,7 @@ import jax.numpy as jnp
|
|
25
25
|
import numpy as np
|
26
26
|
|
27
27
|
from brainstate._utils import set_module_as
|
28
|
-
from ._jit_error import jit_error
|
28
|
+
from ._jit_error import jit_error, remove_vmap
|
29
29
|
from ._make_jaxpr import StatefulFunction, _assign_state_values
|
30
30
|
from ._progress_bar import ProgressBar
|
31
31
|
|
@@ -347,7 +347,7 @@ def _wrap_fun_with_pbar(fun, pbar_runner):
|
|
347
347
|
def new_fun(new_carry, inputs):
|
348
348
|
i, old_carry = new_carry
|
349
349
|
old_carry, old_outputs = fun(old_carry, inputs)
|
350
|
-
pbar_runner(i)
|
350
|
+
pbar_runner(remove_vmap(i, op='none'))
|
351
351
|
return (i + 1, old_carry), old_outputs
|
352
352
|
|
353
353
|
return new_fun
|
{brainstate-0.0.1.1.post20240708.dist-info → brainstate-0.0.1.1.post20240802.dist-info}/METADATA
RENAMED
@@ -1,9 +1,9 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: brainstate
|
3
|
-
Version: 0.0.1.1.
|
3
|
+
Version: 0.0.1.1.post20240802
|
4
4
|
Summary: A State-based Transformation System for Brain Dynamics Programming.
|
5
5
|
Home-page: https://github.com/brainpy/brainstate
|
6
|
-
Author:
|
6
|
+
Author: BDP
|
7
7
|
Author-email: BrainPy Team <chao.brain@qq.com>
|
8
8
|
License: Apache-2.0 license
|
9
9
|
Project-URL: homepage, http://github.com/brainpy
|
{brainstate-0.0.1.1.post20240708.dist-info → brainstate-0.0.1.1.post20240802.dist-info}/RECORD
RENAMED
@@ -2,13 +2,13 @@ brainstate/__init__.py,sha256=oxslZrm6wxtBQDqwJFb2BaAKZFmnp4d_esDkaeuGMWE,1410
|
|
2
2
|
brainstate/_module.py,sha256=UjhfmY26VHQ-kF6U4l68AslGU6vhfD-dR7gF-4Io5ic,52520
|
3
3
|
brainstate/_module_test.py,sha256=oQaoaZBTo1o3wHrMEJTInQCc7RdcVs1gcfQGvdSb1SI,7843
|
4
4
|
brainstate/_random_for_unit.py,sha256=eW4NJkX27VCCNWUwAlyt2otkeEthGKOpUoX6XJ6i95Y,1946
|
5
|
-
brainstate/_state.py,sha256=
|
5
|
+
brainstate/_state.py,sha256=t4lEikvxTfeL2TW0chLUvsQuuRoJSO-iXylUydl1i7k,12057
|
6
6
|
brainstate/_state_test.py,sha256=HDdipndRLhEHWEdTmyT1ayEBkbv6qJKykfCWKI6yJ_E,1253
|
7
7
|
brainstate/_utils.py,sha256=RLorgGJkt2BhbX4C-ygd-PPG0wfcGCghjSP93sRvzqM,833
|
8
8
|
brainstate/environ.py,sha256=LwRwnFaTbv8l7nHRIbSV46WzcN7pGLQFhT_xDUox2yA,10240
|
9
|
-
brainstate/mixin.py,sha256=
|
10
|
-
brainstate/mixin_test.py,sha256
|
11
|
-
brainstate/random.py,sha256=
|
9
|
+
brainstate/mixin.py,sha256=2f2toMUmgJIiovX1wi8OIBM8sbWH6s9Usa1ixL9J4tg,10747
|
10
|
+
brainstate/mixin_test.py,sha256=-Ej9oUOu8O1M4oy37SVMj7xNRYhHHyAHwrjS_aISayo,2923
|
11
|
+
brainstate/random.py,sha256=UbXfC0nrxk5FOsld0rCxBp2bIaeQHH5bj-NWQTR8bbQ,188447
|
12
12
|
brainstate/random_test.py,sha256=cCeuYvlZkCS2_RgG0vipZFNSHG8b-uJ7SXM9SZDCYQM,17866
|
13
13
|
brainstate/surrogate.py,sha256=1kgbn82GSlpReIytIVl29yozk75gkdZv0gTBlixQ4C4,43798
|
14
14
|
brainstate/typing.py,sha256=tZyXpopNtp1UvqbntnmlUP6OBTM_ywyV9QBQ6Gu3IdU,2232
|
@@ -20,14 +20,14 @@ brainstate/functional/_others.py,sha256=ifB-l82y7ZB632yLUJOEcpkRY-yOoiJ0mtDOxNil
|
|
20
20
|
brainstate/functional/_spikes.py,sha256=70qGvo4B--QtxfJMjLwGmk9pVsf2x2YNEEgjT-il_Jw,2574
|
21
21
|
brainstate/init/__init__.py,sha256=R1dHgub47o-WJM9QkFLc7x_Q7GsyaKKDtrRHTFPpC5g,1097
|
22
22
|
brainstate/init/_base.py,sha256=jRTmfoUsH_315vW9YMZzyIn2KDAAsv56SplBnvOyBW0,1148
|
23
|
-
brainstate/init/_generic.py,sha256=
|
23
|
+
brainstate/init/_generic.py,sha256=G5WHSQs3M9htREIZ_OXqj1ffSF_BdYdPTaMqe9AKj-k,7315
|
24
24
|
brainstate/init/_random_inits.py,sha256=ycxH9WyKvPhRbk9PkamCLlc9aX5YokREBfSCbrbFQQ4,16021
|
25
25
|
brainstate/init/_regular_inits.py,sha256=u77aSM0BkK9VULFJQZ1lIEYA_sJJzEZBTEttBSJ79RI,3090
|
26
26
|
brainstate/nn/__init__.py,sha256=YJHoI8cXKVRS8f2vUl3Zegp5wm0svMz3qo9JmQJiMQk,2162
|
27
27
|
brainstate/nn/_base.py,sha256=lzbZpku3Q2arH6ZaAwjs6bhbV0RcFChxo2UcpnX5t84,8481
|
28
28
|
brainstate/nn/_connections.py,sha256=GSOW2IbpJRHdPyF4nFJ2RPgO8y6SVHT1Gn-pbri9pMk,22970
|
29
29
|
brainstate/nn/_dynamics.py,sha256=OeYYXv1dqjUDcCsRhZo1XS7SP2li1vlH9uhME_PE9v0,13205
|
30
|
-
brainstate/nn/_elementwise.py,sha256=
|
30
|
+
brainstate/nn/_elementwise.py,sha256=Br2yd1kdr06iWGSvpoebWWO6suXFDiF8PQv_hOX9kZQ,43599
|
31
31
|
brainstate/nn/_embedding.py,sha256=WbgrIaM_14abN8zBDr0xipBOsFc8dXP2m7Z_aRLAfmU,2249
|
32
32
|
brainstate/nn/_misc.py,sha256=Xc4U4NLmvfnKdBNDayFrRBPAy3p0beS6T9C59rIDP00,3790
|
33
33
|
brainstate/nn/_normalizations.py,sha256=9yVDORAEpqEkL9MYSPU4m7C4q8Qj5UNsPh9sKmIt5gQ,14329
|
@@ -50,7 +50,7 @@ brainstate/optim/_sgd_optimizer.py,sha256=JiK_AVGregL0wn8uHhRQvK9Qq7Qja7dEyLW6Aa
|
|
50
50
|
brainstate/transform/__init__.py,sha256=my2X4ZW0uKZRfN82zyGEPizWNJ0fsSP2akvmkjn43ck,1458
|
51
51
|
brainstate/transform/_autograd.py,sha256=Pj_YxpU52guaxQs1NcB6qDtXgkvaPcoJbuvIF8T-Wmk,23964
|
52
52
|
brainstate/transform/_autograd_test.py,sha256=RWriMemIF9FVFUjQh4IHzLhT9LGyd1JXpjXfFZKHn10,38654
|
53
|
-
brainstate/transform/_control.py,sha256=
|
53
|
+
brainstate/transform/_control.py,sha256=0NFUGLIenqKuBhBiTmY0YgCrl2GI1ZbuWMW0DSOolpE,26874
|
54
54
|
brainstate/transform/_controls_test.py,sha256=mPUa_qmXXVxDziAJrPWRBwsGnc3cHR9co08eJB_fJwA,7648
|
55
55
|
brainstate/transform/_jit.py,sha256=sjQHFV8Tt75fpdl12jjPRDPT92_IZxBBJAG4gapdbNQ,11471
|
56
56
|
brainstate/transform/_jit_error.py,sha256=8rGRx8dtvmPWmHVOsfz30EUMXSix-m2PKM3Ni_9-_7I,4829
|
@@ -59,8 +59,8 @@ brainstate/transform/_jit_test.py,sha256=5ltT7izh_OS9dcHnRymmVhq01QomjwZGdA8XzwJ
|
|
59
59
|
brainstate/transform/_make_jaxpr.py,sha256=vMPruKfp5Ugv8RL-9wGfQdSumLZdLtThZvv3sU9MDjE,30426
|
60
60
|
brainstate/transform/_make_jaxpr_test.py,sha256=K3vRUBroDTCCx0lnmhgHtgrlWvWglJO2f1K2phTvU70,3819
|
61
61
|
brainstate/transform/_progress_bar.py,sha256=VGoRZPRBmB8ELNwLc6c7S8QhUUTvn0FY46IbBm9cuYM,3502
|
62
|
-
brainstate-0.0.1.1.
|
63
|
-
brainstate-0.0.1.1.
|
64
|
-
brainstate-0.0.1.1.
|
65
|
-
brainstate-0.0.1.1.
|
66
|
-
brainstate-0.0.1.1.
|
62
|
+
brainstate-0.0.1.1.post20240802.dist-info/LICENSE,sha256=VZe9u1jgUL2eCY6ZPOYgdb8KCblCHt8ECdbtJid6e1s,11550
|
63
|
+
brainstate-0.0.1.1.post20240802.dist-info/METADATA,sha256=LYHl7AJ94js5MhJVghX5mJ1KB_VA4kWVc1bCmZ0O3GY,3807
|
64
|
+
brainstate-0.0.1.1.post20240802.dist-info/WHEEL,sha256=bb2Ot9scclHKMOLDEHY6B2sicWOgugjFKaJsT7vwMQo,110
|
65
|
+
brainstate-0.0.1.1.post20240802.dist-info/top_level.txt,sha256=eQbGgKn0ptx7FDWuua0V0wr4K1VHi2iOUCYo3fUQBRA,11
|
66
|
+
brainstate-0.0.1.1.post20240802.dist-info/RECORD,,
|
{brainstate-0.0.1.1.post20240708.dist-info → brainstate-0.0.1.1.post20240802.dist-info}/LICENSE
RENAMED
File without changes
|
{brainstate-0.0.1.1.post20240708.dist-info → brainstate-0.0.1.1.post20240802.dist-info}/WHEEL
RENAMED
File without changes
|
File without changes
|