brainstate 0.0.1.post20240622__py2.py3-none-any.whl → 0.0.1.post20240708__py2.py3-none-any.whl

This diff shows the changes between two publicly available versions of a package that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
brainstate/_module.py CHANGED
@@ -59,7 +59,7 @@ import numpy as np
59
59
  from . import environ
60
60
  from ._state import State, StateDictManager, visible_state_dict
61
61
  from ._utils import set_module_as
62
- from .mixin import Mixin, Mode, DelayedInit, AllOfTypes, Batching, UpdateReturn
62
+ from .mixin import Mixin, Mode, DelayedInit, JointTypes, Batching, UpdateReturn
63
63
  from .transform import jit_error
64
64
  from .util import unique_name, DictManager, get_unique_name
65
65
 
@@ -809,7 +809,6 @@ class Dynamics(ExtendedUpdateWithBA, ReceiveInputProj, UpdateReturn):
809
809
  keep_size: bool = False,
810
810
  name: Optional[str] = None,
811
811
  mode: Optional[Mode] = None,
812
- method: str = 'exp_auto'
813
812
  ):
814
813
  # size
815
814
  if isinstance(size, (list, tuple)):
@@ -831,9 +830,6 @@ class Dynamics(ExtendedUpdateWithBA, ReceiveInputProj, UpdateReturn):
831
830
  # number of neurons
832
831
  self.num = np.prod(size)
833
832
 
834
- # integration method
835
- self.method = method
836
-
837
833
  # -- Attribute for "InputProjMixIn" -- #
838
834
  # each instance of "SupportInputProj" should have
839
835
  # "_current_inputs" and "_delta_inputs" attributes
@@ -1213,10 +1209,10 @@ class Delay(ExtendedUpdateWithBA, DelayedInit):
1213
1209
  assert self.history is not None, 'The delay history is not initialized.'
1214
1210
  assert delay_step is not None, 'The delay step should be given.'
1215
1211
 
1216
- if environ.get(environ.JIT_ERROR_CHECK, True):
1212
+ if environ.get(environ.JIT_ERROR_CHECK, False):
1217
1213
  def _check_delay(delay_len):
1218
1214
  raise ValueError(f'The request delay length should be less than the '
1219
- f'maximum delay {self.max_length}. But we got {delay_len}')
1215
+ f'maximum delay {self.max_length - 1}. But we got {delay_len}')
1220
1216
 
1221
1217
  jit_error(delay_step >= self.max_length, _check_delay, delay_step)
1222
1218
 
@@ -1263,9 +1259,8 @@ class Delay(ExtendedUpdateWithBA, DelayedInit):
1263
1259
  current_time = environ.get(environ.T, desc='The current time.')
1264
1260
  dt = environ.get_dt()
1265
1261
 
1266
- if environ.get(environ.JIT_ERROR_CHECK, True):
1267
- def _check_delay(args):
1268
- t_now, t_delay = args
1262
+ if environ.get(environ.JIT_ERROR_CHECK, False):
1263
+ def _check_delay(t_now, t_delay):
1269
1264
  raise ValueError(f'The request delay time should be within '
1270
1265
  f'[{t_now - self.max_time - dt}, {t_now}], '
1271
1266
  f'but we got {t_delay}')
@@ -1273,7 +1268,7 @@ class Delay(ExtendedUpdateWithBA, DelayedInit):
1273
1268
  jit_error(jnp.logical_or(delay_time > current_time,
1274
1269
  delay_time < current_time - self.max_time - dt),
1275
1270
  _check_delay,
1276
- (current_time, delay_time))
1271
+ current_time, delay_time)
1277
1272
 
1278
1273
  diff = current_time - delay_time
1279
1274
  float_time_step = diff / dt
@@ -1415,7 +1410,7 @@ class DelayAccess(Module):
1415
1410
  return self.refs['delay'].at(self._delay_entry, *self.indices)
1416
1411
 
1417
1412
 
1418
- def register_delay_of_target(target: AllOfTypes[ExtendedUpdateWithBA, UpdateReturn]):
1413
+ def register_delay_of_target(target: JointTypes[ExtendedUpdateWithBA, UpdateReturn]):
1419
1414
  """Register delay class for the given target.
1420
1415
 
1421
1416
  Args:
@@ -1425,7 +1420,7 @@ def register_delay_of_target(target: AllOfTypes[ExtendedUpdateWithBA, UpdateRetu
1425
1420
  The delay registered for the given target.
1426
1421
  """
1427
1422
  if not target.has_after_update(delay_identifier):
1428
- assert isinstance(target, AllOfTypes[ExtendedUpdateWithBA, UpdateReturn])
1423
+ assert isinstance(target, JointTypes[ExtendedUpdateWithBA, UpdateReturn])
1429
1424
  target.add_after_update(delay_identifier, Delay(target.update_return_info()))
1430
1425
  delay_cls = target.get_after_update(delay_identifier)
1431
1426
  return delay_cls
@@ -86,7 +86,7 @@ class TestDelay(unittest.TestCase):
86
86
  rotation_delay = bst.Delay(jnp.ones([1]), time=2., delay_method='concat', interp_method='round')
87
87
  rotation_delay.init_state()
88
88
 
89
- with bst.environ.context(i=0, t=0):
89
+ with bst.environ.context(i=0, t=0, jit_error_check=True):
90
90
  rotation_delay.retrieve_at_time(-2.0)
91
91
  with self.assertRaises(jaxlib.xla_extension.XlaRuntimeError):
92
92
  rotation_delay.retrieve_at_time(-2.1)
@@ -0,0 +1,48 @@
1
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ import brainunit as bu
16
+ import jax
17
+ import jax.random as jr
18
+
19
+ from .typing import ArrayLike, Size, DTypeLike
20
+
21
+
22
+ def uniform_for_unit(
23
+ key,
24
+ shape: Size = (),
25
+ dtype: DTypeLike = float,
26
+ minval: ArrayLike = 0.,
27
+ maxval: ArrayLike = 1.
28
+ ) -> jax.Array | bu.Quantity:
29
+ if isinstance(minval, bu.Quantity) and isinstance(maxval, bu.Quantity):
30
+ return bu.Quantity(jr.uniform(key, shape, dtype, minval.value, maxval.value), dim=minval.dim)
31
+ elif isinstance(minval, bu.Quantity):
32
+ assert minval.is_unitless, f'minval must be unitless when maxval is not a Quantity, got {minval}'
33
+ minval = minval.value
34
+ elif isinstance(maxval, bu.Quantity):
35
+ assert maxval.is_unitless, f'maxval must be unitless when minval is not a Quantity, got {maxval}'
36
+ maxval = maxval.value
37
+ return jr.uniform(key, shape, dtype, minval, maxval)
38
+
39
+
40
+ def permutation_for_unit(
41
+ key,
42
+ x: int | ArrayLike,
43
+ axis: int = 0,
44
+ independent: bool = False
45
+ ) -> jax.Array | bu.Quantity:
46
+ if isinstance(x, bu.Quantity):
47
+ return bu.Quantity(jr.permutation(key, x.value, axis, independent=independent), dim=x.dim)
48
+ return jr.permutation(key, x, axis, independent=independent)
brainstate/_state.py CHANGED
@@ -29,7 +29,9 @@ max_int = np.iinfo(np.int32)
29
29
 
30
30
  __all__ = [
31
31
  'State', 'ShortTermState', 'LongTermState', 'ParamState',
32
- 'StateDictManager', 'visible_state_dict',
32
+ 'StateDictManager',
33
+ 'StateTrace',
34
+ 'visible_state_dict',
33
35
  'check_state_value_tree',
34
36
  ]
35
37
 
@@ -141,7 +143,7 @@ class State(object):
141
143
  """
142
144
  # value checking
143
145
  v = v.value if isinstance(v, State) else v
144
- self._check_value(v)
146
+ self._check_value_tree(v)
145
147
  # write the value by the stack (>= level)
146
148
  trace: StateTrace
147
149
  for trace in thread_local_stack.stack[self._level:]:
@@ -149,9 +151,9 @@ class State(object):
149
151
  # set the value
150
152
  self._value = v
151
153
 
152
- def _check_value(self, v):
154
+ def _check_value_tree(self, v):
153
155
  if self._check_tree or _global_context_to_check_state_tree[-1]:
154
- in_tree = jax.tree_util.tree_structure(v)
156
+ in_tree = jax.tree.structure(v)
155
157
  if in_tree != self._tree:
156
158
  self._raise_error_with_source_info(
157
159
  ValueError(f'The given value {in_tree} does not '
@@ -370,12 +372,13 @@ class StateTrace(object):
370
372
  self.types[index] = 'write'
371
373
  self._written_ids.add(id_)
372
374
 
373
- def collect_values(self, *categories: str) -> Tuple:
375
+ def collect_values(self, *categories: str, check_val_tree: bool = False) -> Tuple:
374
376
  """
375
377
  Collect the values by the given categories.
376
378
 
377
379
  Args:
378
380
  *categories: The categories.
381
+ check_val_tree: Whether to check the tree structure of the value.
379
382
 
380
383
  Returns:
381
384
  results: The values.
@@ -383,7 +386,10 @@ class StateTrace(object):
383
386
  results = []
384
387
  for st, ty in zip(self.states, self.types):
385
388
  if ty in categories:
386
- results.append(st.value)
389
+ val = st.value
390
+ if check_val_tree:
391
+ st._check_value_tree(val)
392
+ results.append(val)
387
393
  return tuple(results)
388
394
 
389
395
  def recovery_original_values(self) -> None:
@@ -15,15 +15,15 @@
15
15
 
16
16
  # -*- coding: utf-8 -*-
17
17
 
18
- import numbers
19
18
  from typing import Union, Callable, Optional, Sequence
20
19
 
20
+ import brainunit as bu
21
21
  import jax
22
- import jax.numpy as jnp
23
22
  import numpy as np
24
23
 
25
24
  from brainstate._state import State
26
25
  from ._base import to_size
26
+ from ..typing import ArrayLike
27
27
 
28
28
  __all__ = [
29
29
  'param',
@@ -33,11 +33,57 @@ __all__ = [
33
33
 
34
34
 
35
35
  def _is_scalar(x):
36
- return isinstance(x, numbers.Number)
36
+ return bu.math.isscalar(x)
37
+
38
+
39
+ def are_shapes_broadcastable(shape1, shape2):
40
+ """
41
+ Check if two shapes are broadcastable.
42
+
43
+ Parameters:
44
+ - shape1: Tuple[int], the shape of the first array.
45
+ - shape2: Tuple[int], the shape of the second array.
46
+
47
+ Returns:
48
+ - bool: True if shapes are broadcastable, False otherwise.
49
+ """
50
+ # Reverse the shapes to compare from the last dimension
51
+ shape1_reversed = shape1[::-1]
52
+ shape2_reversed = shape2[::-1]
53
+
54
+ # Iterate over the dimensions of the shorter shape
55
+ for dim1, dim2 in zip(shape1_reversed, shape2_reversed):
56
+ # Check if the dimensions are not equal and neither is 1
57
+ if dim1 != dim2 and 1 not in (dim1, dim2):
58
+ return False
59
+
60
+ # If all dimensions are compatible, the shapes are broadcastable
61
+ return True
62
+
63
+
64
+ def _expand_params_to_match_sizes(params, sizes):
65
+ """
66
+ Expand the dimensions of params to match the dimensions of sizes.
67
+
68
+ Parameters:
69
+ - params: jax.Array or np.ndarray, the parameter array to be expanded.
70
+ - sizes: tuple[int] or list[int], the target shape dimensions.
71
+
72
+ Returns:
73
+ - Expanded params with dimensions matching sizes.
74
+ """
75
+ params_dim = params.ndim
76
+ sizes_dim = len(sizes)
77
+ dim_diff = sizes_dim - params_dim
78
+
79
+ # Add new axes to params if it has fewer dimensions than sizes
80
+ for _ in range(dim_diff):
81
+ params = bu.math.expand_dims(params, axis=0) # Add new axis at the last dimension
82
+ return params
37
83
 
38
84
 
39
85
  def param(
40
- param: Union[Callable, np.ndarray, jax.Array, float, int, bool],
86
+ parameter: Union[Callable, ArrayLike],
41
87
  sizes: Union[int, Sequence[int]],
42
88
  batch_size: Optional[int] = None,
43
89
  allow_none: bool = True,
@@ -47,7 +93,7 @@ def param(
47
93
 
48
94
  Parameters
49
95
  ----------
50
- param: callable, Initializer, bm.ndarray, jnp.ndarray, onp.ndarray, float, int, bool
96
+ parameter: callable, ArrayLike, State
51
97
  The initialization of the parameter.
52
98
  - If it is None, the created parameter will be None.
53
99
  - If it is a callable function :math:`f`, the ``f(size)`` will be returned.
@@ -71,42 +117,55 @@ def param(
71
117
  --------
72
118
  noise, state
73
119
  """
74
- if param is None:
120
+ # Check if the parameter is None
121
+ if parameter is None:
75
122
  if allow_none:
76
123
  return None
77
124
  else:
78
125
  raise ValueError(f'Expect a parameter with type of float, ArrayType, Initializer, or '
79
126
  f'Callable function, but we got None. ')
80
- sizes = list(to_size(sizes))
81
- if allow_scalar and _is_scalar(param):
82
- return param
83
127
 
84
- if batch_size is not None:
85
- sizes.insert(0, batch_size)
128
+ # Check if the parameter is a scalar value
129
+ if allow_scalar and _is_scalar(parameter):
130
+ return parameter
86
131
 
87
- if callable(param):
88
- return param(sizes)
89
- elif isinstance(param, (np.ndarray, jax.Array)):
90
- param = jnp.asarray(param)
91
- if batch_size is not None:
92
- param = jnp.repeat(jnp.expand_dims(param, axis=0), batch_size, axis=0)
93
- elif isinstance(param, State):
94
- param = param
132
+ # Convert sizes to a tuple
133
+ sizes = tuple(to_size(sizes))
134
+
135
+ # Check if the parameter is a callable function
136
+ if callable(parameter):
95
137
  if batch_size is not None:
96
- param = type(param)(jnp.repeat(jnp.expand_dims(param.value, axis=batch_axis), batch_size, axis=batch_axis))
138
+ sizes = (batch_size,) + sizes
139
+ return parameter(sizes)
140
+ elif isinstance(parameter, (np.ndarray, jax.Array, bu.Quantity, State)):
141
+ parameter = parameter
97
142
  else:
98
- raise ValueError(f'Unknown parameter type: {type(param)}')
143
+ raise ValueError(f'Unknown parameter type: {type(parameter)}')
144
+
145
+ # Check if the shape of the parameter matches the given size
146
+ if not are_shapes_broadcastable(parameter.shape, sizes):
147
+ raise ValueError(f'The shape of the parameter {parameter.shape} does not match with the given size {sizes}')
99
148
 
100
- if allow_scalar:
101
- if param.shape == () or param.shape == (1,):
102
- return param
103
- if param.shape != tuple(sizes):
104
- raise ValueError(f'The shape of the parameters should be {sizes}, but we got {param.shape}')
105
- return param
149
+ # Expand the parameter to match the given batch size
150
+ param_value = parameter.value if isinstance(parameter, State) else parameter
151
+ if batch_size is not None:
152
+ if param_value.ndim <= len(sizes):
153
+ # add a new axis to the params so that it matches the dimensionality of the given shape ``sizes``
154
+ param_value = _expand_params_to_match_sizes(param_value, sizes)
155
+ param_value = bu.math.repeat(
156
+ bu.math.expand_dims(param_value, axis=0),
157
+ batch_size,
158
+ axis=0
159
+ )
160
+ else:
161
+ if param_value.shape[0] != batch_size:
162
+ raise ValueError(f'The batch size of the parameter {param_value.shape[0]} '
163
+ f'does not match with the given batch size {batch_size}')
164
+ return type(parameter)(param_value) if isinstance(parameter, State) else param_value
106
165
 
107
166
 
108
167
  def state(
109
- init: Union[Callable, np.ndarray, jax.Array],
168
+ init: Union[Callable, jax.typing.ArrayLike],
110
169
  sizes: Union[int, Sequence[int]] = None,
111
170
  batch_size: Optional[int] = None,
112
171
  ):
@@ -124,18 +183,24 @@ def state(
124
183
 
125
184
  else:
126
185
  if sizes is not None:
127
- if jnp.shape(init) != sizes:
128
- raise ValueError(f'The shape of "data" {jnp.shape(init)} does not match with "var_shape" {sizes}')
186
+ if bu.math.shape(init) != sizes:
187
+ raise ValueError(f'The shape of "data" {bu.math.shape(init)} does not match with "var_shape" {sizes}')
129
188
  if isinstance(batch_size, int):
130
189
  batch_size = batch_size
131
- data = State(jnp.repeat(jnp.expand_dims(init, axis=0), batch_size, axis=0))
190
+ data = State(
191
+ bu.math.repeat(
192
+ bu.math.expand_dims(init, axis=0),
193
+ batch_size,
194
+ axis=0
195
+ )
196
+ )
132
197
  else:
133
198
  data = State(init)
134
199
  return data
135
200
 
136
201
 
137
202
  def noise(
138
- noises: Optional[Union[int, float, np.ndarray, jax.Array, Callable]],
203
+ noises: Optional[Union[ArrayLike, Callable]],
139
204
  size: Union[int, Sequence[int]],
140
205
  num_vars: int = 1,
141
206
  noise_idx: int = 0,
@@ -17,11 +17,13 @@
17
17
 
18
18
  import math
19
19
 
20
+ import brainunit as bu
20
21
  import jax.numpy as jnp
21
22
  import numpy as np
22
23
 
23
24
  from brainstate import environ, random
24
25
  from ._base import Initializer, to_size
26
+ from ..typing import ArrayLike
25
27
 
26
28
  __all__ = [
27
29
  'Normal',
@@ -260,7 +262,7 @@ class Uniform(Initializer):
260
262
  class VarianceScaling(Initializer):
261
263
  def __init__(
262
264
  self,
263
- scale: float,
265
+ scale: ArrayLike,
264
266
  mode: str,
265
267
  distribution: str,
266
268
  in_axis: int = -2,
@@ -287,7 +289,9 @@ class VarianceScaling(Initializer):
287
289
  denominator = (fan_in + fan_out) / 2
288
290
  else:
289
291
  raise ValueError("invalid mode for variance scaling initializer: {}".format(self.mode))
290
- variance = (self.scale / denominator).astype(self.dtype)
292
+ scale = self.scale.value if isinstance(self.scale, bu.Quantity) else self.scale
293
+ dim = self.scale.dim if isinstance(self.scale, bu.Quantity) else bu.DIMENSIONLESS
294
+ variance = (scale / denominator).astype(self.dtype)
291
295
  if self.distribution == "truncated_normal":
292
296
  stddev = (jnp.sqrt(variance) / .87962566103423978).astype(self.dtype)
293
297
  res = random.truncated_normal(-2, 2, shape, dtype=self.dtype) * stddev
@@ -298,7 +302,7 @@ class VarianceScaling(Initializer):
298
302
  jnp.sqrt(3 * variance).astype(self.dtype))
299
303
  else:
300
304
  raise ValueError("invalid distribution for variance scaling initializer")
301
- return res
305
+ return res if dim == bu.DIMENSIONLESS else res * dim
302
306
 
303
307
  def __repr__(self):
304
308
  name = self.__class__.__name__
@@ -425,7 +429,7 @@ class Orthogonal(Initializer):
425
429
 
426
430
  def __init__(
427
431
  self,
428
- scale: float = 1.,
432
+ scale: ArrayLike = 1.,
429
433
  axis: int = -1,
430
434
  dtype=None
431
435
  ):
@@ -440,6 +444,9 @@ class Orthogonal(Initializer):
440
444
  n_cols = np.prod(shape) // n_rows
441
445
  matrix_shape = (n_rows, n_cols) if n_rows > n_cols else (n_cols, n_rows)
442
446
  norm_dst = random.normal(size=matrix_shape, dtype=self.dtype)
447
+
448
+ scale = self.scale.value if isinstance(self.scale, bu.Quantity) else self.scale
449
+ dim = self.scale.dim if isinstance(self.scale, bu.Quantity) else bu.DIMENSIONLESS
443
450
  q_mat, r_mat = jnp.linalg.qr(norm_dst)
444
451
  # Enforce Q is uniformly distributed
445
452
  q_mat *= jnp.sign(jnp.diag(r_mat))
@@ -447,7 +454,8 @@ class Orthogonal(Initializer):
447
454
  q_mat = q_mat.T
448
455
  q_mat = jnp.reshape(q_mat, (n_rows,) + tuple(np.delete(shape, self.axis)))
449
456
  q_mat = jnp.moveaxis(q_mat, 0, self.axis)
450
- return jnp.asarray(self.scale, dtype=self.dtype) * q_mat
457
+ r = jnp.asarray(scale, dtype=self.dtype) * q_mat
458
+ return r if dim == bu.DIMENSIONLESS else r * dim
451
459
 
452
460
  def __repr__(self):
453
461
  return f'{self.__class__.__name__}(scale={self.scale}, axis={self.axis}, dtype={self.dtype})'
@@ -472,7 +480,9 @@ class DeltaOrthogonal(Initializer):
472
480
  raise ValueError("Delta orthogonal initializer requires a 3D, 4D or 5D shape.")
473
481
  if shape[-1] < shape[-2]:
474
482
  raise ValueError("`fan_in` must be less or equal than `fan_out`. ")
475
- ortho_matrix = Orthogonal(scale=self.scale, axis=self.axis, dtype=self.dtype)(*shape[-2:])
483
+ scale = self.scale.value if isinstance(self.scale, bu.Quantity) else self.scale
484
+ dim = self.scale.dim if isinstance(self.scale, bu.Quantity) else bu.DIMENSIONLESS
485
+ ortho_matrix = Orthogonal(scale=scale, axis=self.axis, dtype=self.dtype)(*shape[-2:])
476
486
  W = jnp.zeros(shape, dtype=self.dtype)
477
487
  if len(shape) == 3:
478
488
  k = shape[0]
@@ -483,7 +493,7 @@ class DeltaOrthogonal(Initializer):
483
493
  else:
484
494
  k1, k2, k3 = shape[:3]
485
495
  W = W.at[(k1 - 1) // 2, (k2 - 1) // 2, (k3 - 1) // 2].set(ortho_matrix)
486
- return W
496
+ return W if dim == bu.DIMENSIONLESS else W * dim
487
497
 
488
498
  def __repr__(self):
489
499
  return f'{self.__class__.__name__}(scale={self.scale}, axis={self.axis}, dtype={self.dtype})'
@@ -15,7 +15,8 @@
15
15
 
16
16
  # -*- coding: utf-8 -*-
17
17
 
18
- import jax.numpy as jnp
18
+
19
+ import brainunit as bu
19
20
 
20
21
  from brainstate import environ
21
22
  from ._base import Initializer, to_size
@@ -39,7 +40,7 @@ class ZeroInit(Initializer):
39
40
 
40
41
  def __call__(self, shape):
41
42
  shape = to_size(shape)
42
- return jnp.zeros(shape, dtype=self.dtype)
43
+ return bu.math.zeros(shape, dtype=self.dtype)
43
44
 
44
45
  def __repr__(self):
45
46
  return f"{self.__class__.__name__}(dtype={self.dtype})"
@@ -59,11 +60,11 @@ class Constant(Initializer):
59
60
  def __init__(self, value=1., dtype=None):
60
61
  super(Constant, self).__init__()
61
62
  self.dtype = dtype or environ.dftype()
62
- self.value = jnp.asarray(value, dtype=self.dtype)
63
+ self.value = bu.math.asarray(value, dtype=self.dtype)
63
64
 
64
65
  def __call__(self, shape):
65
66
  shape = to_size(shape)
66
- return jnp.full(shape, self.value, dtype=self.dtype)
67
+ return bu.math.full(shape, self.value, dtype=self.dtype)
67
68
 
68
69
  def __repr__(self):
69
70
  return f'{self.__class__.__name__}(value={self.value}, dtype={self.dtype})'
@@ -94,15 +95,15 @@ class Identity(Initializer):
94
95
  def __init__(self, value=1., dtype=None):
95
96
  super(Identity, self).__init__()
96
97
  self.dtype = dtype or environ.dftype()
97
- self.value = jnp.asarray(value, dtype=self.dtype)
98
+ self.value = bu.math.asarray(value, dtype=self.dtype)
98
99
 
99
100
  def __call__(self, shape):
100
101
  shape = to_size(shape)
101
102
  if isinstance(shape, (tuple, list)):
102
103
  if len(shape) > 2:
103
104
  raise ValueError(f'Only support initialize 2D weights for {self.__class__.__name__}.')
104
- r = jnp.eye(shape, dtype=self.dtype)
105
- r = jnp.fill_diagonal(r, self.value)
105
+ r = bu.math.eye(shape, dtype=self.dtype)
106
+ r = bu.math.fill_diagonal(r, self.value)
106
107
  return r
107
108
 
108
109
  def __repr__(self):
brainstate/mixin.py CHANGED
@@ -32,7 +32,7 @@ __all__ = [
32
32
  'UpdateReturn',
33
33
 
34
34
  # types
35
- 'AllOfTypes',
35
+ 'JointTypes',
36
36
  'OneOfTypes',
37
37
 
38
38
  # behavior modes
@@ -206,7 +206,7 @@ class _JointGenericAlias(_UnionGenericAlias, _root=True):
206
206
 
207
207
 
208
208
  @_SpecialForm
209
- def AllOfTypes(self, parameters):
209
+ def JointTypes(self, parameters):
210
210
  """All of types; AllOfTypes[X, Y] means both X and Y.
211
211
 
212
212
  To define a union, use e.g. Union[int, str].
@@ -341,7 +341,7 @@ class JointMode(Mode):
341
341
  """
342
342
  Check whether the mode is exactly the desired mode.
343
343
  """
344
- return AllOfTypes[tuple(self.types)] == cls
344
+ return JointTypes[tuple(self.types)] == cls
345
345
 
346
346
  def __getattr__(self, item):
347
347
  """
brainstate/mixin_test.py CHANGED
@@ -23,7 +23,7 @@ class TestMixin(unittest.TestCase):
23
23
  self.assertTrue(bc.mixin.Mixin)
24
24
  self.assertTrue(bc.mixin.DelayedInit)
25
25
  self.assertTrue(bc.mixin.DelayedInitializer)
26
- self.assertTrue(bc.mixin.AllOfTypes)
26
+ self.assertTrue(bc.mixin.JointTypes)
27
27
  self.assertTrue(bc.mixin.OneOfTypes)
28
28
  self.assertTrue(bc.mixin.Mode)
29
29
  self.assertTrue(bc.mixin.Batching)
@@ -33,29 +33,29 @@ class TestMixin(unittest.TestCase):
33
33
  class TestMode(unittest.TestCase):
34
34
  def test_JointMode(self):
35
35
  a = bc.mixin.JointMode(bc.mixin.Batching(), bc.mixin.Training())
36
- self.assertTrue(a.is_a(bc.mixin.AllOfTypes[bc.mixin.Batching, bc.mixin.Training]))
36
+ self.assertTrue(a.is_a(bc.mixin.JointTypes[bc.mixin.Batching, bc.mixin.Training]))
37
37
  self.assertTrue(a.has(bc.mixin.Batching))
38
38
  self.assertTrue(a.has(bc.mixin.Training))
39
39
  b = bc.mixin.JointMode(bc.mixin.Batching())
40
- self.assertTrue(b.is_a(bc.mixin.AllOfTypes[bc.mixin.Batching]))
40
+ self.assertTrue(b.is_a(bc.mixin.JointTypes[bc.mixin.Batching]))
41
41
  self.assertTrue(b.is_a(bc.mixin.Batching))
42
42
  self.assertTrue(b.has(bc.mixin.Batching))
43
43
 
44
44
  def test_Training(self):
45
45
  a = bc.mixin.Training()
46
46
  self.assertTrue(a.is_a(bc.mixin.Training))
47
- self.assertTrue(a.is_a(bc.mixin.AllOfTypes[bc.mixin.Training]))
47
+ self.assertTrue(a.is_a(bc.mixin.JointTypes[bc.mixin.Training]))
48
48
  self.assertTrue(a.has(bc.mixin.Training))
49
- self.assertTrue(a.has(bc.mixin.AllOfTypes[bc.mixin.Training]))
49
+ self.assertTrue(a.has(bc.mixin.JointTypes[bc.mixin.Training]))
50
50
  self.assertFalse(a.is_a(bc.mixin.Batching))
51
51
  self.assertFalse(a.has(bc.mixin.Batching))
52
52
 
53
53
  def test_Batching(self):
54
54
  a = bc.mixin.Batching()
55
55
  self.assertTrue(a.is_a(bc.mixin.Batching))
56
- self.assertTrue(a.is_a(bc.mixin.AllOfTypes[bc.mixin.Batching]))
56
+ self.assertTrue(a.is_a(bc.mixin.JointTypes[bc.mixin.Batching]))
57
57
  self.assertTrue(a.has(bc.mixin.Batching))
58
- self.assertTrue(a.has(bc.mixin.AllOfTypes[bc.mixin.Batching]))
58
+ self.assertTrue(a.has(bc.mixin.JointTypes[bc.mixin.Batching]))
59
59
 
60
60
  self.assertFalse(a.is_a(bc.mixin.Training))
61
61
  self.assertFalse(a.has(bc.mixin.Training))
@@ -63,9 +63,9 @@ class TestMode(unittest.TestCase):
63
63
  def test_Mode(self):
64
64
  a = bc.mixin.Mode()
65
65
  self.assertTrue(a.is_a(bc.mixin.Mode))
66
- self.assertTrue(a.is_a(bc.mixin.AllOfTypes[bc.mixin.Mode]))
66
+ self.assertTrue(a.is_a(bc.mixin.JointTypes[bc.mixin.Mode]))
67
67
  self.assertTrue(a.has(bc.mixin.Mode))
68
- self.assertTrue(a.has(bc.mixin.AllOfTypes[bc.mixin.Mode]))
68
+ self.assertTrue(a.has(bc.mixin.JointTypes[bc.mixin.Mode]))
69
69
 
70
70
  self.assertFalse(a.is_a(bc.mixin.Training))
71
71
  self.assertFalse(a.has(bc.mixin.Training))