brainstate 0.0.1.post20240623__py2.py3-none-any.whl → 0.0.1.post20240708__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/_module.py +6 -11
- brainstate/_module_test.py +1 -1
- brainstate/_random_for_unit.py +48 -0
- brainstate/_state.py +12 -6
- brainstate/init/_generic.py +97 -32
- brainstate/init/_random_inits.py +17 -7
- brainstate/init/_regular_inits.py +8 -7
- brainstate/mixin.py +3 -3
- brainstate/mixin_test.py +9 -9
- brainstate/nn/_projection/_align_post.py +11 -11
- brainstate/nn/_projection/_align_pre.py +3 -3
- brainstate/random.py +66 -36
- brainstate/transform/_jit_error.py +71 -49
- brainstate/transform/_jit_error_test.py +55 -0
- brainstate/transform/_make_jaxpr.py +10 -5
- brainstate/typing.py +2 -0
- {brainstate-0.0.1.post20240623.dist-info → brainstate-0.0.1.post20240708.dist-info}/METADATA +1 -1
- {brainstate-0.0.1.post20240623.dist-info → brainstate-0.0.1.post20240708.dist-info}/RECORD +21 -19
- {brainstate-0.0.1.post20240623.dist-info → brainstate-0.0.1.post20240708.dist-info}/LICENSE +0 -0
- {brainstate-0.0.1.post20240623.dist-info → brainstate-0.0.1.post20240708.dist-info}/WHEEL +0 -0
- {brainstate-0.0.1.post20240623.dist-info → brainstate-0.0.1.post20240708.dist-info}/top_level.txt +0 -0
brainstate/_module.py
CHANGED
@@ -59,7 +59,7 @@ import numpy as np
 from . import environ
 from ._state import State, StateDictManager, visible_state_dict
 from ._utils import set_module_as
-from .mixin import Mixin, Mode, DelayedInit, AllOfTypes, Batching, UpdateReturn
+from .mixin import Mixin, Mode, DelayedInit, JointTypes, Batching, UpdateReturn
 from .transform import jit_error
 from .util import unique_name, DictManager, get_unique_name
 
@@ -809,7 +809,6 @@ class Dynamics(ExtendedUpdateWithBA, ReceiveInputProj, UpdateReturn):
       keep_size: bool = False,
       name: Optional[str] = None,
       mode: Optional[Mode] = None,
-      method: str = 'exp_auto'
   ):
     # size
     if isinstance(size, (list, tuple)):
@@ -831,9 +830,6 @@ class Dynamics(ExtendedUpdateWithBA, ReceiveInputProj, UpdateReturn):
     # number of neurons
     self.num = np.prod(size)
 
-    # integration method
-    self.method = method
-
     # -- Attribute for "InputProjMixIn" -- #
     # each instance of "SupportInputProj" should have
     # "_current_inputs" and "_delta_inputs" attributes
@@ -1216,7 +1212,7 @@ class Delay(ExtendedUpdateWithBA, DelayedInit):
    if environ.get(environ.JIT_ERROR_CHECK, False):
      def _check_delay(delay_len):
        raise ValueError(f'The request delay length should be less than the '
-                         f'maximum delay {self.max_length}. But we got {delay_len}')
+                         f'maximum delay {self.max_length - 1}. But we got {delay_len}')
 
      jit_error(delay_step >= self.max_length, _check_delay, delay_step)
 
@@ -1264,8 +1260,7 @@ class Delay(ExtendedUpdateWithBA, DelayedInit):
    dt = environ.get_dt()
 
    if environ.get(environ.JIT_ERROR_CHECK, False):
-      def _check_delay(args):
-        t_now, t_delay = args
+      def _check_delay(t_now, t_delay):
        raise ValueError(f'The request delay time should be within '
                         f'[{t_now - self.max_time - dt}, {t_now}], '
                         f'but we got {t_delay}')
@@ -1273,7 +1268,7 @@ class Delay(ExtendedUpdateWithBA, DelayedInit):
      jit_error(jnp.logical_or(delay_time > current_time,
                               delay_time < current_time - self.max_time - dt),
                _check_delay,
-                (current_time, delay_time))
+                current_time, delay_time)
 
    diff = current_time - delay_time
    float_time_step = diff / dt
@@ -1415,7 +1410,7 @@ class DelayAccess(Module):
    return self.refs['delay'].at(self._delay_entry, *self.indices)
 
 
-def register_delay_of_target(target: AllOfTypes[ExtendedUpdateWithBA, UpdateReturn]):
+def register_delay_of_target(target: JointTypes[ExtendedUpdateWithBA, UpdateReturn]):
  """Register delay class for the given target.
 
  Args:
@@ -1425,7 +1420,7 @@ def register_delay_of_target(target: AllOfTypes[ExtendedUpdateWithBA, UpdateReturn]):
    The delay registered for the given target.
  """
  if not target.has_after_update(delay_identifier):
-    assert isinstance(target, AllOfTypes[ExtendedUpdateWithBA, UpdateReturn])
+    assert isinstance(target, JointTypes[ExtendedUpdateWithBA, UpdateReturn])
    target.add_after_update(delay_identifier, Delay(target.update_return_info()))
  delay_cls = target.get_after_update(delay_identifier)
  return delay_cls
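The `Delay` hunks above track an API change in `brainstate.transform.jit_error`: the error callback now receives its operands as separate positional arguments rather than one packed tuple. A minimal sketch of the new call pattern, mirroring `Delay.retrieve_at_time`; the function names and values here are illustrative, not the module's own code:

```python
import jax.numpy as jnp
from brainstate.transform import jit_error


def _check_delay(t_now, t_delay):
  # Raised from inside the jitted region only when the predicate below is True.
  raise ValueError(f'The requested delay time {t_delay} is out of range at t={t_now}.')


def retrieve(current_time, delay_time, max_time, dt):
  # Operands follow the callback positionally, matching the post-20240708 call sites.
  jit_error(
    jnp.logical_or(delay_time > current_time,
                   delay_time < current_time - max_time - dt),
    _check_delay,
    current_time, delay_time,
  )
  return (current_time - delay_time) / dt
```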
brainstate/_module_test.py
CHANGED
@@ -86,7 +86,7 @@ class TestDelay(unittest.TestCase):
    rotation_delay = bst.Delay(jnp.ones([1]), time=2., delay_method='concat', interp_method='round')
    rotation_delay.init_state()
 
-    with bst.environ.context(i=0, t=0):
+    with bst.environ.context(i=0, t=0, jit_error_check=True):
      rotation_delay.retrieve_at_time(-2.0)
      with self.assertRaises(jaxlib.xla_extension.XlaRuntimeError):
        rotation_delay.retrieve_at_time(-2.1)
brainstate/_random_for_unit.py
ADDED
@@ -0,0 +1,48 @@
+# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+import brainunit as bu
+import jax
+import jax.random as jr
+
+from .typing import ArrayLike, Size, DTypeLike
+
+
+def uniform_for_unit(
+    key,
+    shape: Size = (),
+    dtype: DTypeLike = float,
+    minval: ArrayLike = 0.,
+    maxval: ArrayLike = 1.
+) -> jax.Array | bu.Quantity:
+  if isinstance(minval, bu.Quantity) and isinstance(maxval, bu.Quantity):
+    return bu.Quantity(jr.uniform(key, shape, dtype, minval.value, maxval.value), dim=minval.dim)
+  elif isinstance(minval, bu.Quantity):
+    assert minval.is_unitless, f'minval must be unitless when maxval is not a Quantity, got {minval}'
+    minval = minval.value
+  elif isinstance(maxval, bu.Quantity):
+    assert maxval.is_unitless, f'maxval must be unitless when minval is not a Quantity, got {maxval}'
+    maxval = maxval.value
+  return jr.uniform(key, shape, dtype, minval, maxval)
+
+
+def permutation_for_unit(
+    key,
+    x: int | ArrayLike,
+    axis: int = 0,
+    independent: bool = False
+) -> jax.Array | bu.Quantity:
+  if isinstance(x, bu.Quantity):
+    return bu.Quantity(jr.permutation(key, x.value, axis, independent=independent), dim=x.dim)
+  return jr.permutation(key, x, axis, independent=independent)
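The new module keeps physical units attached to random draws: when the bounds are `brainunit` Quantities, the sample is rebuilt as a Quantity with the same dimension. A usage sketch, assuming the private import path of the file above; the millivolt unit and values are illustrative:

```python
import brainunit as bu
import jax.random as jr

from brainstate._random_for_unit import uniform_for_unit

key = jr.PRNGKey(0)

# Both bounds carry the same dimension, so the result comes back as a bu.Quantity.
sample = uniform_for_unit(key, shape=(3,), minval=0. * bu.mV, maxval=10. * bu.mV)

# Plain floats fall through to jax.random.uniform and return a bare jax.Array.
bare = uniform_for_unit(key, shape=(3,), minval=0., maxval=1.)
```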
brainstate/_state.py
CHANGED
@@ -29,7 +29,9 @@ max_int = np.iinfo(np.int32)
 
 __all__ = [
   'State', 'ShortTermState', 'LongTermState', 'ParamState',
-  'StateDictManager', 'visible_state_dict',
+  'StateDictManager',
+  'StateTrace',
+  'visible_state_dict',
   'check_state_value_tree',
 ]
 
@@ -141,7 +143,7 @@ class State(object):
    """
    # value checking
    v = v.value if isinstance(v, State) else v
-    self.
+    self._check_value_tree(v)
    # write the value by the stack (>= level)
    trace: StateTrace
    for trace in thread_local_stack.stack[self._level:]:
@@ -149,9 +151,9 @@ class State(object):
    # set the value
    self._value = v
 
-  def
+  def _check_value_tree(self, v):
    if self._check_tree or _global_context_to_check_state_tree[-1]:
-      in_tree = jax.tree_structure(v)
+      in_tree = jax.tree.structure(v)
      if in_tree != self._tree:
        self._raise_error_with_source_info(
          ValueError(f'The given value {in_tree} does not '
@@ -370,12 +372,13 @@ class StateTrace(object):
      self.types[index] = 'write'
      self._written_ids.add(id_)
 
-  def collect_values(self, *categories: str) -> Tuple:
+  def collect_values(self, *categories: str, check_val_tree: bool = False) -> Tuple:
    """
    Collect the values by the given categories.
 
    Args:
      *categories: The categories.
+      check_val_tree: Whether to check the tree structure of the value.
 
    Returns:
      results: The values.
@@ -383,7 +386,10 @@ class StateTrace(object):
    results = []
    for st, ty in zip(self.states, self.types):
      if ty in categories:
-        results.append(st.value)
+        val = st.value
+        if check_val_tree:
+          st._check_value_tree(val)
+        results.append(val)
    return tuple(results)
 
  def recovery_original_values(self) -> None:
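`collect_values` can now re-validate every written state against its registered pytree structure via `check_val_tree=True`, reusing the `_check_value_tree` helper factored out above. The same structure check guards direct assignment under the exported `check_state_value_tree` context manager; a small sketch of that behavior, assuming the usual top-level re-exports:

```python
import jax.numpy as jnp
import brainstate as bst

state = bst.ShortTermState(jnp.zeros(3))

with bst.check_state_value_tree():
  state.value = jnp.ones(3)       # same tree structure: accepted
  try:
    state.value = (jnp.ones(3),)  # tuple vs. bare array: structure mismatch
  except ValueError as e:
    print('rejected:', e)
```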
brainstate/init/_generic.py
CHANGED
@@ -15,15 +15,15 @@
 
 # -*- coding: utf-8 -*-
 
-import numbers
 from typing import Union, Callable, Optional, Sequence
 
+import brainunit as bu
 import jax
-import jax.numpy as jnp
 import numpy as np
 
 from brainstate._state import State
 from ._base import to_size
+from ..typing import ArrayLike
 
 __all__ = [
  'param',
@@ -33,11 +33,57 @@ __all__ = [
 
 
 def _is_scalar(x):
-  return isinstance(x, numbers.Number)
+  return bu.math.isscalar(x)
+
+
+def are_shapes_broadcastable(shape1, shape2):
+  """
+  Check if two shapes are broadcastable.
+
+  Parameters:
+  - shape1: Tuple[int], the shape of the first array.
+  - shape2: Tuple[int], the shape of the second array.
+
+  Returns:
+  - bool: True if shapes are broadcastable, False otherwise.
+  """
+  # Reverse the shapes to compare from the last dimension
+  shape1_reversed = shape1[::-1]
+  shape2_reversed = shape2[::-1]
+
+  # Iterate over the dimensions of the shorter shape
+  for dim1, dim2 in zip(shape1_reversed, shape2_reversed):
+    # Check if the dimensions are not equal and neither is 1
+    if dim1 != dim2 and 1 not in (dim1, dim2):
+      return False
+
+  # If all dimensions are compatible, the shapes are broadcastable
+  return True
+
+
+def _expand_params_to_match_sizes(params, sizes):
+  """
+  Expand the dimensions of params to match the dimensions of sizes.
+
+  Parameters:
+  - params: jax.Array or np.ndarray, the parameter array to be expanded.
+  - sizes: tuple[int] or list[int], the target shape dimensions.
+
+  Returns:
+  - Expanded params with dimensions matching sizes.
+  """
+  params_dim = params.ndim
+  sizes_dim = len(sizes)
+  dim_diff = sizes_dim - params_dim
+
+  # Add new axes to params if it has fewer dimensions than sizes
+  for _ in range(dim_diff):
+    params = bu.math.expand_dims(params, axis=0)  # Add new axis at the last dimension
+  return params
 
 
 def param(
-
+    parameter: Union[Callable, ArrayLike],
    sizes: Union[int, Sequence[int]],
    batch_size: Optional[int] = None,
    allow_none: bool = True,
@@ -47,7 +93,7 @@ def param(
 
  Parameters
  ----------
-  param: callable, ArrayLike, State
+  parameter: callable, ArrayLike, State
    The initialization of the parameter.
    - If it is None, the created parameter will be None.
    - If it is a callable function :math:`f`, the ``f(size)`` will be returned.
@@ -71,42 +117,55 @@ def param(
  --------
  noise, state
  """
-  if param is None:
+  # Check if the parameter is None
+  if parameter is None:
    if allow_none:
      return None
    else:
      raise ValueError(f'Expect a parameter with type of float, ArrayType, Initializer, or '
                       f'Callable function, but we got None. ')
-  sizes = list(to_size(sizes))
-  if allow_scalar and _is_scalar(param):
-    return param
 
-  if callable(param):
-
+  # Check if the parameter is a scalar value
+  if allow_scalar and _is_scalar(parameter):
+    return parameter
 
-
-
-
-
-    param = jnp.repeat(jnp.expand_dims(param, axis=0), batch_size, axis=0)
-  elif isinstance(param, State):
-    param = param
+  # Convert sizes to a tuple
+  sizes = tuple(to_size(sizes))
+
+  # Check if the parameter is a callable function
+  if callable(parameter):
    if batch_size is not None:
-
+      sizes = (batch_size,) + sizes
+    return parameter(sizes)
+  elif isinstance(parameter, (np.ndarray, jax.Array, bu.Quantity, State)):
+    parameter = parameter
  else:
-    raise ValueError(f'Unknown parameter type: {type(param)}')
+    raise ValueError(f'Unknown parameter type: {type(parameter)}')
+
+  # Check if the shape of the parameter matches the given size
+  if not are_shapes_broadcastable(parameter.shape, sizes):
+    raise ValueError(f'The shape of the parameter {parameter.shape} does not match with the given size {sizes}')
 
-
-
-
-
-
-
+  # Expand the parameter to match the given batch size
+  param_value = parameter.value if isinstance(parameter, State) else parameter
+  if batch_size is not None:
+    if param_value.ndim <= len(sizes):
+      # add a new axis to the params so that it matches the dimensionality of the given shape ``sizes``
+      param_value = _expand_params_to_match_sizes(param_value, sizes)
+      param_value = bu.math.repeat(
+        bu.math.expand_dims(param_value, axis=0),
+        batch_size,
+        axis=0
+      )
+    else:
+      if param_value.shape[0] != batch_size:
+        raise ValueError(f'The batch size of the parameter {param_value.shape[0]} '
+                         f'does not match with the given batch size {batch_size}')
+  return type(parameter)(param_value) if isinstance(parameter, State) else param_value
 
 
 def state(
-    init: Union[Callable,
+    init: Union[Callable, jax.typing.ArrayLike],
    sizes: Union[int, Sequence[int]] = None,
    batch_size: Optional[int] = None,
 ):
@@ -124,18 +183,24 @@ def state(
 
  else:
    if sizes is not None:
-      if jnp.shape(init) != sizes:
-        raise ValueError(f'The shape of "data" {jnp.shape(init)} does not match with "var_shape" {sizes}')
+      if bu.math.shape(init) != sizes:
+        raise ValueError(f'The shape of "data" {bu.math.shape(init)} does not match with "var_shape" {sizes}')
    if isinstance(batch_size, int):
      batch_size = batch_size
-      data = State(jnp.repeat(jnp.expand_dims(init, axis=0), batch_size, axis=0))
+      data = State(
+        bu.math.repeat(
+          bu.math.expand_dims(init, axis=0),
+          batch_size,
+          axis=0
+        )
+      )
    else:
      data = State(init)
  return data
 
 
 def noise(
-    noises: Optional[Union[
+    noises: Optional[Union[ArrayLike, Callable]],
    size: Union[int, Sequence[int]],
    num_vars: int = 1,
    noise_idx: int = 0,
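`param` (with its argument renamed to `parameter`) now validates shapes by broadcastability instead of exact equality, so a `(1, n)` or `(n,)` parameter is accepted for a `(m, n)` size and batching expands dimensions as needed. The helper's semantics in isolation; this is a standalone copy of the function above with illustrative checks:

```python
def are_shapes_broadcastable(shape1, shape2):
  # A trailing dimension matches if the sizes are equal or either is 1.
  for dim1, dim2 in zip(shape1[::-1], shape2[::-1]):
    if dim1 != dim2 and 1 not in (dim1, dim2):
      return False
  return True


assert are_shapes_broadcastable((1, 4), (3, 4))      # leading 1 broadcasts
assert are_shapes_broadcastable((4,), (3, 4))        # missing leading dims are implied
assert not are_shapes_broadcastable((2, 4), (3, 4))  # 2 vs. 3 cannot broadcast
```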
brainstate/init/_random_inits.py
CHANGED
@@ -17,11 +17,13 @@
 
 import math
 
+import brainunit as bu
 import jax.numpy as jnp
 import numpy as np
 
 from brainstate import environ, random
 from ._base import Initializer, to_size
+from ..typing import ArrayLike
 
 __all__ = [
  'Normal',
@@ -260,7 +262,7 @@ class Uniform(Initializer):
 class VarianceScaling(Initializer):
  def __init__(
      self,
-      scale:
+      scale: ArrayLike,
      mode: str,
      distribution: str,
      in_axis: int = -2,
@@ -287,7 +289,9 @@ class VarianceScaling(Initializer):
      denominator = (fan_in + fan_out) / 2
    else:
      raise ValueError("invalid mode for variance scaling initializer: {}".format(self.mode))
-    variance = (self.scale / denominator).astype(self.dtype)
+    scale = self.scale.value if isinstance(self.scale, bu.Quantity) else self.scale
+    dim = self.scale.dim if isinstance(self.scale, bu.Quantity) else bu.DIMENSIONLESS
+    variance = (scale / denominator).astype(self.dtype)
    if self.distribution == "truncated_normal":
      stddev = (jnp.sqrt(variance) / .87962566103423978).astype(self.dtype)
      res = random.truncated_normal(-2, 2, shape, dtype=self.dtype) * stddev
@@ -298,7 +302,7 @@ class VarianceScaling(Initializer):
                         jnp.sqrt(3 * variance).astype(self.dtype))
    else:
      raise ValueError("invalid distribution for variance scaling initializer")
-    return res
+    return res if dim == bu.DIMENSIONLESS else res * dim
 
  def __repr__(self):
    name = self.__class__.__name__
@@ -425,7 +429,7 @@ class Orthogonal(Initializer):
 
  def __init__(
      self,
-      scale:
+      scale: ArrayLike = 1.,
      axis: int = -1,
      dtype=None
  ):
@@ -440,6 +444,9 @@ class Orthogonal(Initializer):
    n_cols = np.prod(shape) // n_rows
    matrix_shape = (n_rows, n_cols) if n_rows > n_cols else (n_cols, n_rows)
    norm_dst = random.normal(size=matrix_shape, dtype=self.dtype)
+
+    scale = self.scale.value if isinstance(self.scale, bu.Quantity) else self.scale
+    dim = self.scale.dim if isinstance(self.scale, bu.Quantity) else bu.DIMENSIONLESS
    q_mat, r_mat = jnp.linalg.qr(norm_dst)
    # Enforce Q is uniformly distributed
    q_mat *= jnp.sign(jnp.diag(r_mat))
@@ -447,7 +454,8 @@ class Orthogonal(Initializer):
    q_mat = q_mat.T
    q_mat = jnp.reshape(q_mat, (n_rows,) + tuple(np.delete(shape, self.axis)))
    q_mat = jnp.moveaxis(q_mat, 0, self.axis)
-    return jnp.asarray(self.scale, dtype=self.dtype) * q_mat
+    r = jnp.asarray(scale, dtype=self.dtype) * q_mat
+    return r if dim == bu.DIMENSIONLESS else r * dim
 
  def __repr__(self):
    return f'{self.__class__.__name__}(scale={self.scale}, axis={self.axis}, dtype={self.dtype})'
@@ -472,7 +480,9 @@ class DeltaOrthogonal(Initializer):
      raise ValueError("Delta orthogonal initializer requires a 3D, 4D or 5D shape.")
    if shape[-1] < shape[-2]:
      raise ValueError("`fan_in` must be less or equal than `fan_out`. ")
-    ortho_matrix = Orthogonal(scale=self.scale, axis=self.axis, dtype=self.dtype)(*shape[-2:])
+    scale = self.scale.value if isinstance(self.scale, bu.Quantity) else self.scale
+    dim = self.scale.dim if isinstance(self.scale, bu.Quantity) else bu.DIMENSIONLESS
+    ortho_matrix = Orthogonal(scale=scale, axis=self.axis, dtype=self.dtype)(*shape[-2:])
    W = jnp.zeros(shape, dtype=self.dtype)
    if len(shape) == 3:
      k = shape[0]
@@ -483,7 +493,7 @@ class DeltaOrthogonal(Initializer):
    else:
      k1, k2, k3 = shape[:3]
      W = W.at[(k1 - 1) // 2, (k2 - 1) // 2, (k3 - 1) // 2].set(ortho_matrix)
-    return W
+    return W if dim == bu.DIMENSIONLESS else W * dim
 
  def __repr__(self):
    return f'{self.__class__.__name__}(scale={self.scale}, axis={self.axis}, dtype={self.dtype})'
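`VarianceScaling`, `Orthogonal`, and `DeltaOrthogonal` now accept a `brainunit` Quantity as `scale`: the magnitude scales the random draw, and the dimension is multiplied back onto the result. A standalone sketch of that split-and-reattach pattern, mirroring the hunks above; the unit is chosen for illustration:

```python
import brainunit as bu
import jax.numpy as jnp
import jax.random as jr

scale_q = 2. * bu.siemens  # a Quantity scale

# Split the Quantity into magnitude and dimension, as the initializers do:
scale = scale_q.value if isinstance(scale_q, bu.Quantity) else scale_q
dim = scale_q.dim if isinstance(scale_q, bu.Quantity) else bu.DIMENSIONLESS

res = jnp.asarray(scale) * jr.normal(jr.PRNGKey(0), (4, 4))
res = res if dim == bu.DIMENSIONLESS else res * dim  # reattach the dimension
print(type(res))  # a bu.Quantity when the scale carried a dimension
```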
brainstate/init/_regular_inits.py
CHANGED
@@ -15,7 +15,8 @@
 
 # -*- coding: utf-8 -*-
 
-import jax.numpy as jnp
+
+import brainunit as bu
 
 from brainstate import environ
 from ._base import Initializer, to_size
@@ -39,7 +40,7 @@ class ZeroInit(Initializer):
 
  def __call__(self, shape):
    shape = to_size(shape)
-    return jnp.zeros(shape, dtype=self.dtype)
+    return bu.math.zeros(shape, dtype=self.dtype)
 
  def __repr__(self):
    return f"{self.__class__.__name__}(dtype={self.dtype})"
@@ -59,11 +60,11 @@ class Constant(Initializer):
  def __init__(self, value=1., dtype=None):
    super(Constant, self).__init__()
    self.dtype = dtype or environ.dftype()
-    self.value = jnp.asarray(value, dtype=self.dtype)
+    self.value = bu.math.asarray(value, dtype=self.dtype)
 
  def __call__(self, shape):
    shape = to_size(shape)
-    return jnp.full(shape, self.value, dtype=self.dtype)
+    return bu.math.full(shape, self.value, dtype=self.dtype)
 
  def __repr__(self):
    return f'{self.__class__.__name__}(value={self.value}, dtype={self.dtype})'
@@ -94,15 +95,15 @@ class Identity(Initializer):
  def __init__(self, value=1., dtype=None):
    super(Identity, self).__init__()
    self.dtype = dtype or environ.dftype()
-    self.value = jnp.asarray(value, dtype=self.dtype)
+    self.value = bu.math.asarray(value, dtype=self.dtype)
 
  def __call__(self, shape):
    shape = to_size(shape)
    if isinstance(shape, (tuple, list)):
      if len(shape) > 2:
        raise ValueError(f'Only support initialize 2D weights for {self.__class__.__name__}.')
-    r =
-    r =
+    r = bu.math.eye(shape, dtype=self.dtype)
+    r = bu.math.fill_diagonal(r, self.value)
    return r
 
  def __repr__(self):
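The regular initializers now build their outputs through `brainunit.math`, which behaves like `jax.numpy` for plain inputs but can propagate Quantity fill values. A brief sketch, assuming `bu.math.asarray`/`bu.math.full` accept a unit-bearing value; the millivolt constant is illustrative:

```python
import brainunit as bu
import brainstate as bst

# Plain usage is unchanged and returns an ordinary array.
zeros = bst.init.ZeroInit()((2, 3))

# A Quantity constant can now flow through bu.math.full and keep its unit.
const = bst.init.Constant(3. * bu.mV)((2, 3))
```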
brainstate/mixin.py
CHANGED
@@ -32,7 +32,7 @@ __all__ = [
  'UpdateReturn',
 
  # types
-  'AllOfTypes',
+  'JointTypes',
  'OneOfTypes',
 
  # behavior modes
@@ -206,7 +206,7 @@ class _JointGenericAlias(_UnionGenericAlias, _root=True):
 
 
 @_SpecialForm
-def AllOfTypes(self, parameters):
+def JointTypes(self, parameters):
  """All of types; AllOfTypes[X, Y] means both X and Y.
 
  To define a union, use e.g. Union[int, str].
@@ -341,7 +341,7 @@ class JointMode(Mode):
    """
    Check whether the mode is exactly the desired mode.
    """
-    return AllOfTypes[tuple(self.types)] == cls
+    return JointTypes[tuple(self.types)] == cls
 
  def __getattr__(self, item):
    """
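`AllOfTypes` is renamed to `JointTypes` across the package: `JointTypes[X, Y]` is the intersection counterpart of `Union`, asserting that a value satisfies both X and Y. A short sketch of the annotation and runtime checks, matching the usage in `_module.py` above and the tests below:

```python
from brainstate.mixin import JointTypes, JointMode, Batching, Training

# As a type annotation requiring both capabilities at once:
def run(mode: JointTypes[Batching, Training]):
  ...

# As a runtime check, the way JointMode.is_a uses it internally:
m = JointMode(Batching(), Training())
assert m.is_a(JointTypes[Batching, Training])
```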
brainstate/mixin_test.py
CHANGED
@@ -23,7 +23,7 @@ class TestMixin(unittest.TestCase):
    self.assertTrue(bc.mixin.Mixin)
    self.assertTrue(bc.mixin.DelayedInit)
    self.assertTrue(bc.mixin.DelayedInitializer)
-    self.assertTrue(bc.mixin.AllOfTypes)
+    self.assertTrue(bc.mixin.JointTypes)
    self.assertTrue(bc.mixin.OneOfTypes)
    self.assertTrue(bc.mixin.Mode)
    self.assertTrue(bc.mixin.Batching)
@@ -33,29 +33,29 @@ class TestMixin(unittest.TestCase):
 
 class TestMode(unittest.TestCase):
  def test_JointMode(self):
    a = bc.mixin.JointMode(bc.mixin.Batching(), bc.mixin.Training())
-    self.assertTrue(a.is_a(bc.mixin.AllOfTypes[bc.mixin.Batching, bc.mixin.Training]))
+    self.assertTrue(a.is_a(bc.mixin.JointTypes[bc.mixin.Batching, bc.mixin.Training]))
    self.assertTrue(a.has(bc.mixin.Batching))
    self.assertTrue(a.has(bc.mixin.Training))
    b = bc.mixin.JointMode(bc.mixin.Batching())
-    self.assertTrue(b.is_a(bc.mixin.AllOfTypes[bc.mixin.Batching]))
+    self.assertTrue(b.is_a(bc.mixin.JointTypes[bc.mixin.Batching]))
    self.assertTrue(b.is_a(bc.mixin.Batching))
    self.assertTrue(b.has(bc.mixin.Batching))
 
  def test_Training(self):
    a = bc.mixin.Training()
    self.assertTrue(a.is_a(bc.mixin.Training))
-    self.assertTrue(a.is_a(bc.mixin.AllOfTypes[bc.mixin.Training]))
+    self.assertTrue(a.is_a(bc.mixin.JointTypes[bc.mixin.Training]))
    self.assertTrue(a.has(bc.mixin.Training))
-    self.assertTrue(a.has(bc.mixin.AllOfTypes[bc.mixin.Training]))
+    self.assertTrue(a.has(bc.mixin.JointTypes[bc.mixin.Training]))
    self.assertFalse(a.is_a(bc.mixin.Batching))
    self.assertFalse(a.has(bc.mixin.Batching))
 
  def test_Batching(self):
    a = bc.mixin.Batching()
    self.assertTrue(a.is_a(bc.mixin.Batching))
-    self.assertTrue(a.is_a(bc.mixin.AllOfTypes[bc.mixin.Batching]))
+    self.assertTrue(a.is_a(bc.mixin.JointTypes[bc.mixin.Batching]))
    self.assertTrue(a.has(bc.mixin.Batching))
-    self.assertTrue(a.has(bc.mixin.AllOfTypes[bc.mixin.Batching]))
+    self.assertTrue(a.has(bc.mixin.JointTypes[bc.mixin.Batching]))
 
    self.assertFalse(a.is_a(bc.mixin.Training))
    self.assertFalse(a.has(bc.mixin.Training))
@@ -63,9 +63,9 @@ class TestMode(unittest.TestCase):
  def test_Mode(self):
    a = bc.mixin.Mode()
    self.assertTrue(a.is_a(bc.mixin.Mode))
-    self.assertTrue(a.is_a(bc.mixin.AllOfTypes[bc.mixin.Mode]))
+    self.assertTrue(a.is_a(bc.mixin.JointTypes[bc.mixin.Mode]))
    self.assertTrue(a.has(bc.mixin.Mode))
-    self.assertTrue(a.has(bc.mixin.AllOfTypes[bc.mixin.Mode]))
+    self.assertTrue(a.has(bc.mixin.JointTypes[bc.mixin.Mode]))
 
    self.assertFalse(a.is_a(bc.mixin.Training))
    self.assertFalse(a.has(bc.mixin.Training))
|