brainstate 0.0.1.1.post20240708__py2.py3-none-any.whl → 0.0.1.1.post20240802__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
brainstate/_state.py CHANGED
@@ -15,7 +15,7 @@
15
15
 
16
16
  import contextlib
17
17
  import threading
18
- from typing import Any, Tuple, Dict, List, Callable
18
+ from typing import Any, Tuple, Dict, List, Callable, Optional
19
19
 
20
20
  import jax
21
21
  import numpy as np
@@ -108,9 +108,9 @@ class State(object):
108
108
  value: PyTree. It can be anything as a pyTree.
109
109
  """
110
110
  __module__ = 'brainstate'
111
- __slots__ = ('_value', '_tree', '_level', '_source_info', '_check_tree')
111
+ __slots__ = ('_value', '_name', '_tree', '_level', '_source_info', '_check_tree')
112
112
 
113
- def __init__(self, value: PyTree):
113
+ def __init__(self, value: PyTree, name: Optional[str] = None):
114
114
  if isinstance(value, State):
115
115
  value = value.value
116
116
  self._value = value
@@ -118,6 +118,21 @@ class State(object):
118
118
  self._check_tree = False
119
119
  self._level = len(thread_local_stack.stack)
120
120
  self._source_info = source_info_util.current()
121
+ self._name = name
122
+
123
+ @property
124
+ def name(self) -> Optional[str]:
125
+ """
126
+ The name of the state.
127
+ """
128
+ return self._name
129
+
130
+ @name.setter
131
+ def name(self, name: str) -> None:
132
+ """
133
+ Set the name of the state.
134
+ """
135
+ self._name = name
121
136
 
122
137
  @property
123
138
  def value(self) -> PyTree:
@@ -210,7 +225,10 @@ class State(object):
210
225
  leaves, tree = jax.tree.flatten(self._value)
211
226
  leaves_info = [ShapeDtype(leaf.shape, leaf.dtype) for leaf in leaves]
212
227
  tree_info = jax.tree.unflatten(tree, leaves_info)
213
- return f'{self.__class__.__name__}({tree_info})'
228
+ if self.name is None:
229
+ return f'{self.__class__.__name__}({tree_info})'
230
+ else:
231
+ return f'{self.__class__.__name__}({self.name}: {tree_info})'
214
232
 
215
233
 
216
234
  class ShapeDtype:
@@ -83,7 +83,7 @@ def _expand_params_to_match_sizes(params, sizes):
83
83
 
84
84
 
85
85
  def param(
86
- parameter: Union[Callable, ArrayLike],
86
+ parameter: Union[Callable, ArrayLike, State],
87
87
  sizes: Union[int, Sequence[int]],
88
88
  batch_size: Optional[int] = None,
89
89
  allow_none: bool = True,
brainstate/mixin.py CHANGED
@@ -207,7 +207,7 @@ class _JointGenericAlias(_UnionGenericAlias, _root=True):
207
207
 
208
208
  @_SpecialForm
209
209
  def JointTypes(self, parameters):
210
- """All of types; AllOfTypes[X, Y] means both X and Y.
210
+ """Joint types; JointTypes[X, Y] means both X and Y.
211
211
 
212
212
  To define a union, use e.g. Union[int, str].
213
213
 
@@ -216,28 +216,28 @@ def JointTypes(self, parameters):
216
216
  - None as an argument is a special case and is replaced by `type(None)`.
217
217
  - Unions of unions are flattened, e.g.::
218
218
 
219
- AllOfTypes[AllOfTypes[int, str], float] == AllOfTypes[int, str, float]
219
+ JointTypes[JointTypes[int, str], float] == JointTypes[int, str, float]
220
220
 
221
221
  - Unions of a single argument vanish, e.g.::
222
222
 
223
- AllOfTypes[int] == int # The constructor actually returns int
223
+ JointTypes[int] == int # The constructor actually returns int
224
224
 
225
225
  - Redundant arguments are skipped, e.g.::
226
226
 
227
- AllOfTypes[int, str, int] == AllOfTypes[int, str]
227
+ JointTypes[int, str, int] == JointTypes[int, str]
228
228
 
229
229
  - When comparing unions, the argument order is ignored, e.g.::
230
230
 
231
- AllOfTypes[int, str] == AllOfTypes[str, int]
231
+ JointTypes[int, str] == JointTypes[str, int]
232
232
 
233
- - You cannot subclass or instantiate a AllOfTypes.
234
- - You can use Optional[X] as a shorthand for AllOfTypes[X, None].
233
+ - You cannot subclass or instantiate a JointTypes.
234
+ - You can use Optional[X] as a shorthand for JointTypes[X, None].
235
235
  """
236
236
  if parameters == ():
237
237
  raise TypeError("Cannot take a Joint of no types.")
238
238
  if not isinstance(parameters, tuple):
239
239
  parameters = (parameters,)
240
- msg = "AllOfTypes[arg, ...]: each arg must be a type."
240
+ msg = "JointTypes[arg, ...]: each arg must be a type."
241
241
  parameters = tuple(_type_check(p, msg) for p in parameters)
242
242
  parameters = _remove_dups_flatten(parameters)
243
243
  if len(parameters) == 1:
brainstate/mixin_test.py CHANGED
@@ -30,6 +30,8 @@ class TestMixin(unittest.TestCase):
30
30
  self.assertTrue(bc.mixin.Training)
31
31
 
32
32
 
33
+
34
+
33
35
  class TestMode(unittest.TestCase):
34
36
  def test_JointMode(self):
35
37
  a = bc.mixin.JointMode(bc.mixin.Batching(), bc.mixin.Training())
@@ -1139,13 +1139,13 @@ class Dropout(Module, ElementWiseBlock):
1139
1139
  name: Optional[str] = None
1140
1140
  ) -> None:
1141
1141
  super().__init__(mode=mode, name=name)
1142
- assert 0. <= prob < 1., f"Dropout probability must be in the range [0, 1). But got {prob}."
1142
+ assert 0. <= prob <= 1., f"Dropout probability must be in the range [0, 1]. But got {prob}."
1143
1143
  self.prob = prob
1144
1144
 
1145
1145
  def __call__(self, x):
1146
1146
  dtype = bu.math.get_dtype(x)
1147
1147
  fit_phase = environ.get('fit', desc='Whether this is a fitting process. Bool.')
1148
- if fit_phase:
1148
+ if fit_phase and self.prob < 1.:
1149
1149
  keep_mask = random.bernoulli(self.prob, x.shape)
1150
1150
  return jnp.where(keep_mask,
1151
1151
  jnp.asarray(x / self.prob, dtype=dtype),
brainstate/random.py CHANGED
@@ -1167,23 +1167,32 @@ def default_rng(seed_or_key=None, clone: bool = True) -> RandomState:
1167
1167
  return RandomState(seed_or_key)
1168
1168
 
1169
1169
 
1170
- def seed(seed: int = None):
1170
+ def seed(seed_or_key: int = None):
1171
1171
  """Sets a new random seed.
1172
1172
 
1173
1173
  Parameters
1174
1174
  ----------
1175
- seed: int, optional
1176
- The random seed.
1175
+ seed_or_key: int, optional
1176
+ The random seed (an integer) or jax random key.
1177
1177
  """
1178
1178
  with jax.ensure_compile_time_eval():
1179
- if seed is None:
1180
- seed = np.random.randint(0, 100000)
1181
- np.random.seed(seed)
1182
- DEFAULT.seed(seed)
1179
+ if seed_or_key is None:
1180
+ seed_or_key = np.random.randint(0, 100000)
1181
+
1182
+ # numpy random seed
1183
+ if np.size(seed_or_key) == 1: # seed
1184
+ np.random.seed(seed_or_key)
1185
+ elif np.size(seed_or_key) == 2: # jax random key
1186
+ np.random.seed(seed_or_key[0])
1187
+ else:
1188
+ raise ValueError(f"seed_or_key should be an integer or a tuple of two integers.")
1189
+
1190
+ # jax random seed
1191
+ DEFAULT.seed(seed_or_key)
1183
1192
 
1184
1193
 
1185
1194
  @contextmanager
1186
- def seed_context(seed: int):
1195
+ def seed_context(seed_or_key: SeedOrKey):
1187
1196
  """
1188
1197
  A context manager that sets the random seed for the duration of the block.
1189
1198
 
@@ -1206,16 +1215,19 @@ def seed_context(seed: int):
1206
1215
  The context manager does not only set the seed for the AX random state, but also for the numpy random state.
1207
1216
 
1208
1217
  Args:
1209
- seed: The seed (an integer).
1218
+ seed_or_key: The seed (an integer) or jax random key.
1210
1219
 
1211
- Returns:
1212
- The random state.
1213
1220
  """
1214
1221
  old_jrand_key = DEFAULT.value
1215
1222
  old_np_state = np.random.get_state()
1216
1223
  try:
1217
- np.random.seed(seed)
1218
- DEFAULT.seed(seed)
1224
+ if np.size(seed_or_key) == 1: # seed
1225
+ np.random.seed(seed_or_key)
1226
+ elif np.size(seed_or_key) == 2: # jax random key
1227
+ np.random.seed(seed_or_key[0])
1228
+ else:
1229
+ raise ValueError(f"seed_or_key should be an integer or a tuple of two integers.")
1230
+ DEFAULT.seed(seed_or_key)
1219
1231
  yield
1220
1232
  finally:
1221
1233
  np.random.set_state(old_np_state)
@@ -1223,7 +1235,8 @@ def seed_context(seed: int):
1223
1235
 
1224
1236
 
1225
1237
  def rand(*dn, key: Optional[SeedOrKey] = None, dtype: DTypeLike = None):
1226
- r"""Random values in a given shape.
1238
+ r"""
1239
+ Random values in a given shape.
1227
1240
 
1228
1241
  .. note::
1229
1242
  This is a convenience function for users porting code from Matlab,
@@ -25,7 +25,7 @@ import jax.numpy as jnp
25
25
  import numpy as np
26
26
 
27
27
  from brainstate._utils import set_module_as
28
- from ._jit_error import jit_error
28
+ from ._jit_error import jit_error, remove_vmap
29
29
  from ._make_jaxpr import StatefulFunction, _assign_state_values
30
30
  from ._progress_bar import ProgressBar
31
31
 
@@ -347,7 +347,7 @@ def _wrap_fun_with_pbar(fun, pbar_runner):
347
347
  def new_fun(new_carry, inputs):
348
348
  i, old_carry = new_carry
349
349
  old_carry, old_outputs = fun(old_carry, inputs)
350
- pbar_runner(i)
350
+ pbar_runner(remove_vmap(i, op='none'))
351
351
  return (i + 1, old_carry), old_outputs
352
352
 
353
353
  return new_fun
@@ -1,9 +1,9 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: brainstate
3
- Version: 0.0.1.1.post20240708
3
+ Version: 0.0.1.1.post20240802
4
4
  Summary: A State-based Transformation System for Brain Dynamics Programming.
5
5
  Home-page: https://github.com/brainpy/brainstate
6
- Author: BrainPy Team
6
+ Author: BDP
7
7
  Author-email: BrainPy Team <chao.brain@qq.com>
8
8
  License: Apache-2.0 license
9
9
  Project-URL: homepage, http://github.com/brainpy
@@ -2,13 +2,13 @@ brainstate/__init__.py,sha256=oxslZrm6wxtBQDqwJFb2BaAKZFmnp4d_esDkaeuGMWE,1410
2
2
  brainstate/_module.py,sha256=UjhfmY26VHQ-kF6U4l68AslGU6vhfD-dR7gF-4Io5ic,52520
3
3
  brainstate/_module_test.py,sha256=oQaoaZBTo1o3wHrMEJTInQCc7RdcVs1gcfQGvdSb1SI,7843
4
4
  brainstate/_random_for_unit.py,sha256=eW4NJkX27VCCNWUwAlyt2otkeEthGKOpUoX6XJ6i95Y,1946
5
- brainstate/_state.py,sha256=ykQluBkdKIcQsd7-pU_UlpnByWFof_8TYQl1hGn7HS8,11629
5
+ brainstate/_state.py,sha256=t4lEikvxTfeL2TW0chLUvsQuuRoJSO-iXylUydl1i7k,12057
6
6
  brainstate/_state_test.py,sha256=HDdipndRLhEHWEdTmyT1ayEBkbv6qJKykfCWKI6yJ_E,1253
7
7
  brainstate/_utils.py,sha256=RLorgGJkt2BhbX4C-ygd-PPG0wfcGCghjSP93sRvzqM,833
8
8
  brainstate/environ.py,sha256=LwRwnFaTbv8l7nHRIbSV46WzcN7pGLQFhT_xDUox2yA,10240
9
- brainstate/mixin.py,sha256=f0XovWlyTOQKC63WTtzSbztRSld3pKEHy9-I1gVZWLg,10748
10
- brainstate/mixin_test.py,sha256=qFLw9Bq8TkOMg8M8CcI92BYvFekXkjyCC9lXSEVa8Ck,2919
11
- brainstate/random.py,sha256=OMa4739GbrQpKspEx0TbeBqjA4yvwwJzdtpr-kJZN6s,187841
9
+ brainstate/mixin.py,sha256=2f2toMUmgJIiovX1wi8OIBM8sbWH6s9Usa1ixL9J4tg,10747
10
+ brainstate/mixin_test.py,sha256=-Ej9oUOu8O1M4oy37SVMj7xNRYhHHyAHwrjS_aISayo,2923
11
+ brainstate/random.py,sha256=UbXfC0nrxk5FOsld0rCxBp2bIaeQHH5bj-NWQTR8bbQ,188447
12
12
  brainstate/random_test.py,sha256=cCeuYvlZkCS2_RgG0vipZFNSHG8b-uJ7SXM9SZDCYQM,17866
13
13
  brainstate/surrogate.py,sha256=1kgbn82GSlpReIytIVl29yozk75gkdZv0gTBlixQ4C4,43798
14
14
  brainstate/typing.py,sha256=tZyXpopNtp1UvqbntnmlUP6OBTM_ywyV9QBQ6Gu3IdU,2232
@@ -20,14 +20,14 @@ brainstate/functional/_others.py,sha256=ifB-l82y7ZB632yLUJOEcpkRY-yOoiJ0mtDOxNil
20
20
  brainstate/functional/_spikes.py,sha256=70qGvo4B--QtxfJMjLwGmk9pVsf2x2YNEEgjT-il_Jw,2574
21
21
  brainstate/init/__init__.py,sha256=R1dHgub47o-WJM9QkFLc7x_Q7GsyaKKDtrRHTFPpC5g,1097
22
22
  brainstate/init/_base.py,sha256=jRTmfoUsH_315vW9YMZzyIn2KDAAsv56SplBnvOyBW0,1148
23
- brainstate/init/_generic.py,sha256=yLkmYRWpGHUK_nYx_fWPe74s_Cd-JDdQ6mUqpi4yqcc,7308
23
+ brainstate/init/_generic.py,sha256=G5WHSQs3M9htREIZ_OXqj1ffSF_BdYdPTaMqe9AKj-k,7315
24
24
  brainstate/init/_random_inits.py,sha256=ycxH9WyKvPhRbk9PkamCLlc9aX5YokREBfSCbrbFQQ4,16021
25
25
  brainstate/init/_regular_inits.py,sha256=u77aSM0BkK9VULFJQZ1lIEYA_sJJzEZBTEttBSJ79RI,3090
26
26
  brainstate/nn/__init__.py,sha256=YJHoI8cXKVRS8f2vUl3Zegp5wm0svMz3qo9JmQJiMQk,2162
27
27
  brainstate/nn/_base.py,sha256=lzbZpku3Q2arH6ZaAwjs6bhbV0RcFChxo2UcpnX5t84,8481
28
28
  brainstate/nn/_connections.py,sha256=GSOW2IbpJRHdPyF4nFJ2RPgO8y6SVHT1Gn-pbri9pMk,22970
29
29
  brainstate/nn/_dynamics.py,sha256=OeYYXv1dqjUDcCsRhZo1XS7SP2li1vlH9uhME_PE9v0,13205
30
- brainstate/nn/_elementwise.py,sha256=6BTqSvSnaHhldwB5ol5OV0hPJ5yJ-Jpm4WSrtFKMNoQ,43579
30
+ brainstate/nn/_elementwise.py,sha256=Br2yd1kdr06iWGSvpoebWWO6suXFDiF8PQv_hOX9kZQ,43599
31
31
  brainstate/nn/_embedding.py,sha256=WbgrIaM_14abN8zBDr0xipBOsFc8dXP2m7Z_aRLAfmU,2249
32
32
  brainstate/nn/_misc.py,sha256=Xc4U4NLmvfnKdBNDayFrRBPAy3p0beS6T9C59rIDP00,3790
33
33
  brainstate/nn/_normalizations.py,sha256=9yVDORAEpqEkL9MYSPU4m7C4q8Qj5UNsPh9sKmIt5gQ,14329
@@ -50,7 +50,7 @@ brainstate/optim/_sgd_optimizer.py,sha256=JiK_AVGregL0wn8uHhRQvK9Qq7Qja7dEyLW6Aa
50
50
  brainstate/transform/__init__.py,sha256=my2X4ZW0uKZRfN82zyGEPizWNJ0fsSP2akvmkjn43ck,1458
51
51
  brainstate/transform/_autograd.py,sha256=Pj_YxpU52guaxQs1NcB6qDtXgkvaPcoJbuvIF8T-Wmk,23964
52
52
  brainstate/transform/_autograd_test.py,sha256=RWriMemIF9FVFUjQh4IHzLhT9LGyd1JXpjXfFZKHn10,38654
53
- brainstate/transform/_control.py,sha256=NWceTIuLlj2uGTdNcqBAXgnaLuChOGgAtIXtFn5vdLU,26837
53
+ brainstate/transform/_control.py,sha256=0NFUGLIenqKuBhBiTmY0YgCrl2GI1ZbuWMW0DSOolpE,26874
54
54
  brainstate/transform/_controls_test.py,sha256=mPUa_qmXXVxDziAJrPWRBwsGnc3cHR9co08eJB_fJwA,7648
55
55
  brainstate/transform/_jit.py,sha256=sjQHFV8Tt75fpdl12jjPRDPT92_IZxBBJAG4gapdbNQ,11471
56
56
  brainstate/transform/_jit_error.py,sha256=8rGRx8dtvmPWmHVOsfz30EUMXSix-m2PKM3Ni_9-_7I,4829
@@ -59,8 +59,8 @@ brainstate/transform/_jit_test.py,sha256=5ltT7izh_OS9dcHnRymmVhq01QomjwZGdA8XzwJ
59
59
  brainstate/transform/_make_jaxpr.py,sha256=vMPruKfp5Ugv8RL-9wGfQdSumLZdLtThZvv3sU9MDjE,30426
60
60
  brainstate/transform/_make_jaxpr_test.py,sha256=K3vRUBroDTCCx0lnmhgHtgrlWvWglJO2f1K2phTvU70,3819
61
61
  brainstate/transform/_progress_bar.py,sha256=VGoRZPRBmB8ELNwLc6c7S8QhUUTvn0FY46IbBm9cuYM,3502
62
- brainstate-0.0.1.1.post20240708.dist-info/LICENSE,sha256=VZe9u1jgUL2eCY6ZPOYgdb8KCblCHt8ECdbtJid6e1s,11550
63
- brainstate-0.0.1.1.post20240708.dist-info/METADATA,sha256=HN33Hoom47puNsauzxdmXvsmcP5RqrUpeR_2tVf7Y5U,3816
64
- brainstate-0.0.1.1.post20240708.dist-info/WHEEL,sha256=bb2Ot9scclHKMOLDEHY6B2sicWOgugjFKaJsT7vwMQo,110
65
- brainstate-0.0.1.1.post20240708.dist-info/top_level.txt,sha256=eQbGgKn0ptx7FDWuua0V0wr4K1VHi2iOUCYo3fUQBRA,11
66
- brainstate-0.0.1.1.post20240708.dist-info/RECORD,,
62
+ brainstate-0.0.1.1.post20240802.dist-info/LICENSE,sha256=VZe9u1jgUL2eCY6ZPOYgdb8KCblCHt8ECdbtJid6e1s,11550
63
+ brainstate-0.0.1.1.post20240802.dist-info/METADATA,sha256=LYHl7AJ94js5MhJVghX5mJ1KB_VA4kWVc1bCmZ0O3GY,3807
64
+ brainstate-0.0.1.1.post20240802.dist-info/WHEEL,sha256=bb2Ot9scclHKMOLDEHY6B2sicWOgugjFKaJsT7vwMQo,110
65
+ brainstate-0.0.1.1.post20240802.dist-info/top_level.txt,sha256=eQbGgKn0ptx7FDWuua0V0wr4K1VHi2iOUCYo3fUQBRA,11
66
+ brainstate-0.0.1.1.post20240802.dist-info/RECORD,,