brainstate 0.1.7__py2.py3-none-any.whl → 0.1.8__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +1 -1
- brainstate/compile/_make_jaxpr.py +8 -8
- brainstate/compile/_make_jaxpr_test.py +19 -9
- brainstate/nn/_delay.py +129 -97
- brainstate/nn/_delay_test.py +184 -0
- brainstate/nn/_dynamics.py +38 -56
- brainstate/nn/_elementwise.py +1 -2
- brainstate/nn/_linear_mv.py +1 -1
- brainstate/nn/_module_test.py +0 -168
- brainstate/nn/_synapse.py +1 -1
- {brainstate-0.1.7.dist-info → brainstate-0.1.8.dist-info}/METADATA +1 -1
- {brainstate-0.1.7.dist-info → brainstate-0.1.8.dist-info}/RECORD +15 -14
- {brainstate-0.1.7.dist-info → brainstate-0.1.8.dist-info}/LICENSE +0 -0
- {brainstate-0.1.7.dist-info → brainstate-0.1.8.dist-info}/WHEEL +0 -0
- {brainstate-0.1.7.dist-info → brainstate-0.1.8.dist-info}/top_level.txt +0 -0
brainstate/__init__.py
CHANGED
@@ -352,11 +352,13 @@ class StatefulFunction(PrettyObject):
|
|
352
352
|
cache_key = default_cache_key
|
353
353
|
return self.get_state_trace(cache_key).get_write_states()
|
354
354
|
|
355
|
-
def
|
355
|
+
def _check_input_ouput(self, x):
|
356
356
|
if isinstance(x, State):
|
357
|
-
|
358
|
-
|
359
|
-
|
357
|
+
x.raise_error_with_source_info(
|
358
|
+
ValueError(
|
359
|
+
'Inputs/outputs for brainstate transformations cannot be an instance of State. '
|
360
|
+
f'But we got {x}'
|
361
|
+
)
|
360
362
|
)
|
361
363
|
|
362
364
|
def get_arg_cache_key(self, *args, **kwargs) -> Tuple:
|
@@ -477,9 +479,7 @@ class StatefulFunction(PrettyObject):
|
|
477
479
|
|
478
480
|
# State instance as functional returns is not allowed.
|
479
481
|
# Checking whether the states are returned.
|
480
|
-
|
481
|
-
if isinstance(leaf, State):
|
482
|
-
leaf.raise_error_with_source_info(ValueError(f"State object is not allowed to be returned: {leaf}"))
|
482
|
+
jax.tree.map(self._check_input_ouput, out, is_leaf=lambda x: isinstance(x, State))
|
483
483
|
return out, state_values
|
484
484
|
|
485
485
|
def make_jaxpr(self, *args, return_only_write: bool = False, **kwargs):
|
@@ -501,7 +501,7 @@ class StatefulFunction(PrettyObject):
|
|
501
501
|
cache_key = self.get_arg_cache_key(*args, **kwargs)
|
502
502
|
|
503
503
|
# check input types
|
504
|
-
jax.tree.map(self.
|
504
|
+
jax.tree.map(self._check_input_ouput, (args, kwargs), is_leaf=lambda x: isinstance(x, State))
|
505
505
|
|
506
506
|
if cache_key not in self._cached_state_trace:
|
507
507
|
try:
|
@@ -13,7 +13,6 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
|
-
from __future__ import annotations
|
17
16
|
|
18
17
|
import unittest
|
19
18
|
|
@@ -132,15 +131,26 @@ class TestMakeJaxpr(unittest.TestCase):
|
|
132
131
|
print(jaxpr)
|
133
132
|
print(states)
|
134
133
|
|
134
|
+
def test_state_in(self):
|
135
|
+
def f(a):
|
136
|
+
return a.value
|
135
137
|
|
136
|
-
|
137
|
-
|
138
|
+
with pytest.raises(ValueError):
|
139
|
+
brainstate.compile.StatefulFunction(f).make_jaxpr(brainstate.State(1.))
|
138
140
|
|
139
|
-
|
141
|
+
def test_state_out(self):
|
142
|
+
def f(a):
|
143
|
+
return brainstate.State(a)
|
140
144
|
|
141
|
-
|
142
|
-
|
143
|
-
return a
|
145
|
+
with pytest.raises(ValueError):
|
146
|
+
brainstate.compile.StatefulFunction(f).make_jaxpr(1.)
|
144
147
|
|
145
|
-
|
146
|
-
|
148
|
+
def test_return_states(self):
|
149
|
+
a = brainstate.State(jnp.ones(3))
|
150
|
+
|
151
|
+
@brainstate.compile.jit
|
152
|
+
def f():
|
153
|
+
return a
|
154
|
+
|
155
|
+
with pytest.raises(ValueError):
|
156
|
+
f()
|
brainstate/nn/_delay.py
CHANGED
@@ -13,7 +13,6 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
|
-
import math
|
17
16
|
import numbers
|
18
17
|
from functools import partial
|
19
18
|
from typing import Optional, Dict, Callable, Union, Sequence
|
@@ -41,34 +40,27 @@ _INTERP_LINEAR = 'linear_interp'
|
|
41
40
|
_INTERP_ROUND = 'round'
|
42
41
|
|
43
42
|
|
44
|
-
def _get_delay(delay_time
|
45
|
-
|
46
|
-
if
|
47
|
-
return 0
|
48
|
-
|
49
|
-
|
50
|
-
if delay_step == 0:
|
51
|
-
return 0., 0
|
52
|
-
with jax.ensure_compile_time_eval():
|
53
|
-
delay_time = delay_step * environ.get_dt()
|
54
|
-
else:
|
55
|
-
assert delay_step is None, '"delay_step" should be None if "delay_time" is given.'
|
56
|
-
# assert isinstance(delay_time, (int, float))
|
57
|
-
with jax.ensure_compile_time_eval():
|
58
|
-
delay_step = delay_time / environ.get_dt()
|
59
|
-
delay_step = math.ceil(float(delay_step))
|
43
|
+
def _get_delay(delay_time):
|
44
|
+
with jax.ensure_compile_time_eval():
|
45
|
+
if delay_time is None:
|
46
|
+
return 0. * environ.get_dt(), 0
|
47
|
+
delay_step = delay_time / environ.get_dt()
|
48
|
+
delay_step = jnp.ceil(delay_step).astype(environ.ditype())
|
60
49
|
return delay_time, delay_step
|
61
50
|
|
62
51
|
|
63
52
|
class DelayAccess(Node):
|
64
53
|
"""
|
65
|
-
|
54
|
+
Accessor node for a registered entry in a Delay instance.
|
55
|
+
|
56
|
+
This node holds a reference to a Delay and a named entry that was
|
57
|
+
registered on that Delay. It is used by graphs to query delayed
|
58
|
+
values by delegating to the underlying Delay instance.
|
66
59
|
|
67
60
|
Args:
|
68
|
-
|
69
|
-
|
70
|
-
|
71
|
-
delay_entry: The delay entry.
|
61
|
+
delay: The delay instance.
|
62
|
+
time: The delay time.
|
63
|
+
delay_entry: The delay entry.
|
72
64
|
"""
|
73
65
|
|
74
66
|
__module__ = 'brainstate.nn'
|
@@ -78,17 +70,15 @@ class DelayAccess(Node):
|
|
78
70
|
delay: 'Delay',
|
79
71
|
time: Union[None, int, float],
|
80
72
|
delay_entry: str,
|
81
|
-
*indices,
|
82
73
|
):
|
83
74
|
super().__init__()
|
84
75
|
self.refs = {'delay': delay}
|
85
76
|
assert isinstance(delay, Delay), 'The input delay should be an instance of Delay.'
|
86
77
|
self._delay_entry = delay_entry
|
87
|
-
delay.register_entry(self._delay_entry, time)
|
88
|
-
self.indices = indices
|
78
|
+
self.delay_info = delay.register_entry(self._delay_entry, time)
|
89
79
|
|
90
80
|
def update(self):
|
91
|
-
return self.refs['delay'].at(self._delay_entry
|
81
|
+
return self.refs['delay'].at(self._delay_entry)
|
92
82
|
|
93
83
|
|
94
84
|
class Delay(Module):
|
@@ -141,17 +131,21 @@ class Delay(Module):
|
|
141
131
|
self.target_info = jax.tree.map(lambda a: jax.ShapeDtypeStruct(a.shape, a.dtype), target_info)
|
142
132
|
|
143
133
|
# delay method
|
144
|
-
assert delay_method in [_DELAY_ROTATE, _DELAY_CONCAT], (
|
145
|
-
|
134
|
+
assert delay_method in [_DELAY_ROTATE, _DELAY_CONCAT], (
|
135
|
+
f'Un-supported delay method {delay_method}. '
|
136
|
+
f'Only support {_DELAY_ROTATE} and {_DELAY_CONCAT}'
|
137
|
+
)
|
146
138
|
self.delay_method = delay_method
|
147
139
|
|
148
140
|
# interp method
|
149
|
-
assert interp_method in [_INTERP_LINEAR, _INTERP_ROUND], (
|
150
|
-
|
141
|
+
assert interp_method in [_INTERP_LINEAR, _INTERP_ROUND], (
|
142
|
+
f'Un-supported interpolation method {interp_method}. '
|
143
|
+
f'we only support: {[_INTERP_LINEAR, _INTERP_ROUND]}'
|
144
|
+
)
|
151
145
|
self.interp_method = interp_method
|
152
146
|
|
153
147
|
# delay length and time
|
154
|
-
self.max_time, delay_length = _get_delay(time
|
148
|
+
self.max_time, delay_length = _get_delay(time)
|
155
149
|
self.max_length = delay_length + 1
|
156
150
|
|
157
151
|
super().__init__()
|
@@ -169,7 +163,10 @@ class Delay(Module):
|
|
169
163
|
# other info
|
170
164
|
if entries is not None:
|
171
165
|
for entry, delay_time in entries.items():
|
172
|
-
|
166
|
+
if isinstance(delay_time, (tuple, list)):
|
167
|
+
self.register_entry(entry, *delay_time)
|
168
|
+
else:
|
169
|
+
self.register_entry(entry, delay_time)
|
173
170
|
|
174
171
|
self.take_aware_unit = take_aware_unit
|
175
172
|
self._unit = None
|
@@ -205,73 +202,107 @@ class Delay(Module):
|
|
205
202
|
fun = partial(self._f_to_init, length=self.max_length, batch_size=batch_size)
|
206
203
|
self.history.value = jax.tree.map(fun, self.target_info)
|
207
204
|
|
208
|
-
def register_delay(
|
209
|
-
|
210
|
-
|
211
|
-
|
212
|
-
|
213
|
-
if
|
214
|
-
|
215
|
-
delay_time = delay_time.item()
|
205
|
+
def register_delay(self, *delay_time):
|
206
|
+
"""
|
207
|
+
Register delay times and update the maximum delay configuration.
|
208
|
+
|
209
|
+
This method processes one or more delay times, validates their format and consistency,
|
210
|
+
and updates the delay buffer size if necessary. It handles both scalar and vector
|
211
|
+
delay times, ensuring all vector delays have the same size.
|
216
212
|
|
217
|
-
|
213
|
+
Args:
|
214
|
+
*delay_time: Variable number of delay time arguments. The first argument should be
|
215
|
+
the primary delay time (float, int, or array-like). Additional arguments are
|
216
|
+
treated as indices or secondary delay parameters. All delay times should be
|
217
|
+
non-negative numbers or arrays of the same size.
|
218
|
+
|
219
|
+
Returns:
|
220
|
+
tuple or None: If delay_time[0] is None, returns None. Otherwise, returns a tuple
|
221
|
+
containing (delay_step, *delay_time[1:]) where delay_step is the computed
|
222
|
+
delay step in integer time units, and the remaining elements are the
|
223
|
+
additional delay parameters passed in.
|
224
|
+
|
225
|
+
Raises:
|
226
|
+
AssertionError: If no delay time is provided (empty delay_time).
|
227
|
+
ValueError: If delay times have inconsistent sizes when using vector delays,
|
228
|
+
or if delay times are not scalar or 1D arrays.
|
229
|
+
|
230
|
+
Note:
|
231
|
+
- The method updates self.max_time and self.max_length if the new delay
|
232
|
+
requires a larger buffer size.
|
233
|
+
- Delay steps are computed using the current environment time step (dt).
|
234
|
+
- All delay indices (delay_time[1:]) must be integers.
|
235
|
+
- Vector delays must all have the same size as the first delay time.
|
236
|
+
|
237
|
+
Example:
|
238
|
+
>>> delay_obj.register_delay(5.0) # Register 5ms delay
|
239
|
+
>>> delay_obj.register_delay(jnp.array([2.0, 3.0]), 0, 1) # Vector delay with indices
|
240
|
+
"""
|
241
|
+
assert len(delay_time) >= 1, 'You should provide at least one delay time.'
|
242
|
+
delay_size = u.math.size(delay_time[0])
|
243
|
+
for dt in delay_time[1:]:
|
244
|
+
assert jnp.issubdtype(u.math.get_dtype(dt), jnp.integer), f'The index should be integer. But got {dt}.'
|
245
|
+
for dt in delay_time:
|
246
|
+
if u.math.ndim(dt) == 0:
|
247
|
+
pass
|
248
|
+
elif u.math.ndim(dt) == 1:
|
249
|
+
if u.math.size(dt) != delay_size:
|
250
|
+
raise ValueError(
|
251
|
+
f'The delay time should be a scalar or a vector with the same size. '
|
252
|
+
f'But got {delay_time}. The delay time {dt} has size {u.math.size(dt)}'
|
253
|
+
)
|
254
|
+
else:
|
255
|
+
raise ValueError(f'The delay time should be a scalar/vector. But got {dt}.')
|
256
|
+
if delay_time[0] is None:
|
257
|
+
return None
|
258
|
+
time, delay_step = _get_delay(delay_time[0])
|
259
|
+
max_delay_step = jnp.max(delay_step)
|
260
|
+
self.max_time = u.math.max(time)
|
218
261
|
|
219
262
|
# delay variable
|
220
|
-
if self.max_length <=
|
221
|
-
self.max_length =
|
222
|
-
|
223
|
-
return self
|
263
|
+
if self.max_length <= max_delay_step + 1:
|
264
|
+
self.max_length = max_delay_step + 1
|
265
|
+
return delay_step, *delay_time[1:]
|
224
266
|
|
225
|
-
def register_entry(
|
226
|
-
self,
|
227
|
-
entry: str,
|
228
|
-
delay_time: Optional[Union[int, float]] = None,
|
229
|
-
delay_step: Optional[int] = None,
|
230
|
-
) -> 'Delay':
|
267
|
+
def register_entry(self, entry: str, *delay_time) -> 'Delay':
|
231
268
|
"""
|
232
269
|
Register an entry to access the delay data.
|
233
270
|
|
234
271
|
Args:
|
235
|
-
|
236
|
-
|
237
|
-
|
238
|
-
|
239
|
-
Returns:
|
240
|
-
Return the self.
|
272
|
+
entry: str. The entry to access the delay data.
|
273
|
+
delay_time: The delay time of the entry, the first element is the delay time,
|
274
|
+
the second and later element is the index.
|
241
275
|
"""
|
242
276
|
if entry in self._registered_entries:
|
243
|
-
raise KeyError(
|
244
|
-
|
245
|
-
|
246
|
-
|
247
|
-
|
248
|
-
|
249
|
-
|
250
|
-
|
277
|
+
raise KeyError(
|
278
|
+
f'Entry {entry} has been registered. '
|
279
|
+
f'The existing delay for the key {entry} is {self._registered_entries[entry]}. '
|
280
|
+
f'The new delay for the key {entry} is {delay_time}. '
|
281
|
+
f'You can use another key. '
|
282
|
+
)
|
283
|
+
delay_info = self.register_delay(*delay_time)
|
284
|
+
self._registered_entries[entry] = delay_info
|
285
|
+
return delay_info
|
251
286
|
|
252
|
-
|
287
|
+
def access(self, entry: str, delay_time: Sequence) -> DelayAccess:
|
288
|
+
"""
|
289
|
+
Create a DelayAccess object for a specific delay entry and delay time.
|
253
290
|
|
254
|
-
|
255
|
-
|
256
|
-
|
257
|
-
self.max_time = delay_time
|
258
|
-
self._registered_entries[entry] = delay_step
|
259
|
-
return self
|
291
|
+
Args:
|
292
|
+
entry (str): The name of the delay entry to access.
|
293
|
+
delay_time (Sequence): The delay time or parameters associated with the entry.
|
260
294
|
|
261
|
-
|
262
|
-
|
263
|
-
|
264
|
-
|
265
|
-
) -> DelayAccess:
|
266
|
-
return DelayAccess(self, time, delay_entry=entry)
|
295
|
+
Returns:
|
296
|
+
DelayAccess: An object that provides access to the delay data for the specified entry and time.
|
297
|
+
"""
|
298
|
+
return DelayAccess(self, delay_time, delay_entry=entry)
|
267
299
|
|
268
|
-
def at(self, entry: str
|
300
|
+
def at(self, entry: str) -> ArrayLike:
|
269
301
|
"""
|
270
302
|
Get the data at the given entry.
|
271
303
|
|
272
304
|
Args:
|
273
305
|
entry: str. The entry to access the data.
|
274
|
-
*indices: The slicing indices. Not include the slice at the batch dimension.
|
275
306
|
|
276
307
|
Returns:
|
277
308
|
The data.
|
@@ -282,8 +313,8 @@ class Delay(Module):
|
|
282
313
|
raise KeyError(f'Does not find delay entry "{entry}".')
|
283
314
|
delay_step = self._registered_entries[entry]
|
284
315
|
if delay_step is None:
|
285
|
-
delay_step = 0
|
286
|
-
return self.retrieve_at_step(delay_step
|
316
|
+
delay_step = (0,)
|
317
|
+
return self.retrieve_at_step(*delay_step)
|
287
318
|
|
288
319
|
def retrieve_at_step(self, delay_step, *indices) -> PyTree:
|
289
320
|
"""
|
@@ -306,8 +337,10 @@ class Delay(Module):
|
|
306
337
|
|
307
338
|
if environ.get(environ.JIT_ERROR_CHECK, False):
|
308
339
|
def _check_delay(delay_len):
|
309
|
-
raise ValueError(
|
310
|
-
|
340
|
+
raise ValueError(
|
341
|
+
f'The request delay length should be less than the '
|
342
|
+
f'maximum delay {self.max_length - 1}. But we got {delay_len}'
|
343
|
+
)
|
311
344
|
|
312
345
|
jit_error_if(delay_step >= self.max_length, _check_delay, delay_step)
|
313
346
|
|
@@ -363,13 +396,17 @@ class Delay(Module):
|
|
363
396
|
|
364
397
|
if environ.get(environ.JIT_ERROR_CHECK, False):
|
365
398
|
def _check_delay(t_now, t_delay):
|
366
|
-
raise ValueError(
|
367
|
-
|
368
|
-
|
399
|
+
raise ValueError(
|
400
|
+
f'The request delay time should be within '
|
401
|
+
f'[{t_now - self.max_time - dt}, {t_now}], '
|
402
|
+
f'but we got {t_delay}'
|
403
|
+
)
|
369
404
|
|
370
405
|
jit_error_if(
|
371
|
-
jnp.logical_or(
|
372
|
-
|
406
|
+
jnp.logical_or(
|
407
|
+
delay_time > current_time,
|
408
|
+
delay_time < current_time - self.max_time - dt
|
409
|
+
),
|
373
410
|
_check_delay,
|
374
411
|
current_time,
|
375
412
|
delay_time
|
@@ -385,10 +422,7 @@ class Delay(Module):
|
|
385
422
|
return jax.tree.map(lambda a, b: a * (1 - t_diff) + b * t_diff, data_at_t0, data_at_t1)
|
386
423
|
|
387
424
|
elif self.interp_method == _INTERP_ROUND: # "round" interpolation
|
388
|
-
return self.retrieve_at_step(
|
389
|
-
jnp.asarray(jnp.round(float_time_step), dtype=jnp.int32),
|
390
|
-
*indices
|
391
|
-
)
|
425
|
+
return self.retrieve_at_step(jnp.asarray(jnp.round(float_time_step), dtype=jnp.int32), *indices)
|
392
426
|
|
393
427
|
else: # raise error
|
394
428
|
raise ValueError(f'Un-supported interpolation method {self.interp_method}, '
|
@@ -429,8 +463,6 @@ class Delay(Module):
|
|
429
463
|
raise ValueError(f'Unknown updating method "{self.delay_method}"')
|
430
464
|
|
431
465
|
|
432
|
-
|
433
|
-
|
434
466
|
class StateWithDelay(Delay):
|
435
467
|
"""
|
436
468
|
A ``State`` type that defines the state in a differential equation.
|
@@ -440,8 +472,8 @@ class StateWithDelay(Delay):
|
|
440
472
|
|
441
473
|
state: State # state
|
442
474
|
|
443
|
-
def __init__(self, target: Node, item: str):
|
444
|
-
super().__init__(None)
|
475
|
+
def __init__(self, target: Node, item: str, init: Callable = None):
|
476
|
+
super().__init__(None, init=init)
|
445
477
|
|
446
478
|
self._target = target
|
447
479
|
self._target_term = item
|
@@ -0,0 +1,184 @@
|
|
1
|
+
# Copyright 2025 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
|
17
|
+
import unittest
|
18
|
+
|
19
|
+
import jax.numpy as jnp
|
20
|
+
|
21
|
+
import brainstate
|
22
|
+
|
23
|
+
brainstate.environ.set(dt=0.1)
|
24
|
+
|
25
|
+
|
26
|
+
class TestDelay(unittest.TestCase):
|
27
|
+
def test_delay1(self):
|
28
|
+
a = brainstate.State(brainstate.random.random(10, 20))
|
29
|
+
delay = brainstate.nn.Delay(a.value)
|
30
|
+
delay.register_entry('a', 1.)
|
31
|
+
delay.register_entry('b', 2.)
|
32
|
+
delay.register_entry('c', None)
|
33
|
+
|
34
|
+
delay.init_state()
|
35
|
+
with self.assertRaises(KeyError):
|
36
|
+
delay.register_entry('c', 10.)
|
37
|
+
|
38
|
+
def test_rotation_delay(self):
|
39
|
+
rotation_delay = brainstate.nn.Delay(jnp.ones((1,)))
|
40
|
+
t0 = 0.
|
41
|
+
t1, n1 = 1., 10
|
42
|
+
t2, n2 = 2., 20
|
43
|
+
|
44
|
+
rotation_delay.register_entry('a', t0)
|
45
|
+
rotation_delay.register_entry('b', t1)
|
46
|
+
rotation_delay.register_entry('c2', 1.9)
|
47
|
+
rotation_delay.register_entry('c', t2)
|
48
|
+
|
49
|
+
rotation_delay.init_state()
|
50
|
+
|
51
|
+
print()
|
52
|
+
# print(rotation_delay)
|
53
|
+
# print(rotation_delay.max_length)
|
54
|
+
|
55
|
+
for i in range(100):
|
56
|
+
brainstate.environ.set(i=i)
|
57
|
+
rotation_delay.update(jnp.ones((1,)) * i)
|
58
|
+
# print(i, rotation_delay.at('a'), rotation_delay.at('b'), rotation_delay.at('c2'), rotation_delay.at('c'))
|
59
|
+
self.assertTrue(jnp.allclose(rotation_delay.at('a'), jnp.ones((1,)) * i))
|
60
|
+
self.assertTrue(jnp.allclose(rotation_delay.at('b'), jnp.maximum(jnp.ones((1,)) * i - n1, 0.)))
|
61
|
+
self.assertTrue(jnp.allclose(rotation_delay.at('c'), jnp.maximum(jnp.ones((1,)) * i - n2, 0.)))
|
62
|
+
|
63
|
+
def test_concat_delay(self):
|
64
|
+
rotation_delay = brainstate.nn.Delay(jnp.ones([1]), delay_method='concat')
|
65
|
+
t0 = 0.
|
66
|
+
t1, n1 = 1., 10
|
67
|
+
t2, n2 = 2., 20
|
68
|
+
|
69
|
+
rotation_delay.register_entry('a', t0)
|
70
|
+
rotation_delay.register_entry('b', t1)
|
71
|
+
rotation_delay.register_entry('c', t2)
|
72
|
+
|
73
|
+
rotation_delay.init_state()
|
74
|
+
|
75
|
+
print()
|
76
|
+
for i in range(100):
|
77
|
+
brainstate.environ.set(i=i)
|
78
|
+
rotation_delay.update(jnp.ones((1,)) * i)
|
79
|
+
print(i, rotation_delay.at('a'), rotation_delay.at('b'), rotation_delay.at('c'))
|
80
|
+
self.assertTrue(jnp.allclose(rotation_delay.at('a'), jnp.ones((1,)) * i))
|
81
|
+
self.assertTrue(jnp.allclose(rotation_delay.at('b'), jnp.maximum(jnp.ones((1,)) * i - n1, 0.)))
|
82
|
+
self.assertTrue(jnp.allclose(rotation_delay.at('c'), jnp.maximum(jnp.ones((1,)) * i - n2, 0.)))
|
83
|
+
# brainstate.util.clear_buffer_memory()
|
84
|
+
|
85
|
+
def test_jit_erro(self):
|
86
|
+
rotation_delay = brainstate.nn.Delay(jnp.ones([1]), time=2., delay_method='concat', interp_method='round')
|
87
|
+
rotation_delay.init_state()
|
88
|
+
|
89
|
+
with brainstate.environ.context(i=0, t=0, jit_error_check=True):
|
90
|
+
rotation_delay.retrieve_at_time(-2.0)
|
91
|
+
with self.assertRaises(Exception):
|
92
|
+
rotation_delay.retrieve_at_time(-2.1)
|
93
|
+
rotation_delay.retrieve_at_time(-2.01)
|
94
|
+
with self.assertRaises(Exception):
|
95
|
+
rotation_delay.retrieve_at_time(-2.09)
|
96
|
+
with self.assertRaises(Exception):
|
97
|
+
rotation_delay.retrieve_at_time(0.1)
|
98
|
+
with self.assertRaises(Exception):
|
99
|
+
rotation_delay.retrieve_at_time(0.01)
|
100
|
+
|
101
|
+
def test_round_interp(self):
|
102
|
+
for shape in [(1,), (1, 1), (1, 1, 1)]:
|
103
|
+
for delay_method in ['rotation', 'concat']:
|
104
|
+
rotation_delay = brainstate.nn.Delay(jnp.ones(shape), time=2., delay_method=delay_method,
|
105
|
+
interp_method='round')
|
106
|
+
t0, n1 = 0.01, 0
|
107
|
+
t1, n1 = 1.04, 10
|
108
|
+
t2, n2 = 1.06, 11
|
109
|
+
rotation_delay.init_state()
|
110
|
+
|
111
|
+
@brainstate.compile.jit
|
112
|
+
def retrieve(td, i):
|
113
|
+
with brainstate.environ.context(i=i, t=i * brainstate.environ.get_dt()):
|
114
|
+
return rotation_delay.retrieve_at_time(td)
|
115
|
+
|
116
|
+
print()
|
117
|
+
for i in range(100):
|
118
|
+
t = i * brainstate.environ.get_dt()
|
119
|
+
with brainstate.environ.context(i=i, t=t):
|
120
|
+
rotation_delay.update(jnp.ones(shape) * i)
|
121
|
+
print(i,
|
122
|
+
retrieve(t - t0, i),
|
123
|
+
retrieve(t - t1, i),
|
124
|
+
retrieve(t - t2, i))
|
125
|
+
self.assertTrue(jnp.allclose(retrieve(t - t0, i), jnp.ones(shape) * i))
|
126
|
+
self.assertTrue(jnp.allclose(retrieve(t - t1, i), jnp.maximum(jnp.ones(shape) * i - n1, 0.)))
|
127
|
+
self.assertTrue(jnp.allclose(retrieve(t - t2, i), jnp.maximum(jnp.ones(shape) * i - n2, 0.)))
|
128
|
+
|
129
|
+
def test_linear_interp(self):
|
130
|
+
for shape in [(1,), (1, 1), (1, 1, 1)]:
|
131
|
+
for delay_method in ['rotation', 'concat']:
|
132
|
+
print(shape, delay_method)
|
133
|
+
|
134
|
+
rotation_delay = brainstate.nn.Delay(jnp.ones(shape), time=2., delay_method=delay_method,
|
135
|
+
interp_method='linear_interp')
|
136
|
+
t0, n0 = 0.01, 0.1
|
137
|
+
t1, n1 = 1.04, 10.4
|
138
|
+
t2, n2 = 1.06, 10.6
|
139
|
+
rotation_delay.init_state()
|
140
|
+
|
141
|
+
@brainstate.compile.jit
|
142
|
+
def retrieve(td, i):
|
143
|
+
with brainstate.environ.context(i=i, t=i * brainstate.environ.get_dt()):
|
144
|
+
return rotation_delay.retrieve_at_time(td)
|
145
|
+
|
146
|
+
print()
|
147
|
+
for i in range(100):
|
148
|
+
t = i * brainstate.environ.get_dt()
|
149
|
+
with brainstate.environ.context(i=i, t=t):
|
150
|
+
rotation_delay.update(jnp.ones(shape) * i)
|
151
|
+
print(i,
|
152
|
+
retrieve(t - t0, i),
|
153
|
+
retrieve(t - t1, i),
|
154
|
+
retrieve(t - t2, i))
|
155
|
+
self.assertTrue(jnp.allclose(retrieve(t - t0, i), jnp.maximum(jnp.ones(shape) * i - n0, 0.)))
|
156
|
+
self.assertTrue(jnp.allclose(retrieve(t - t1, i), jnp.maximum(jnp.ones(shape) * i - n1, 0.)))
|
157
|
+
self.assertTrue(jnp.allclose(retrieve(t - t2, i), jnp.maximum(jnp.ones(shape) * i - n2, 0.)))
|
158
|
+
|
159
|
+
def test_rotation_and_concat_delay(self):
|
160
|
+
rotation_delay = brainstate.nn.Delay(jnp.ones((1,)))
|
161
|
+
concat_delay = brainstate.nn.Delay(jnp.ones([1]), delay_method='concat')
|
162
|
+
t0 = 0.
|
163
|
+
t1, n1 = 1., 10
|
164
|
+
t2, n2 = 2., 20
|
165
|
+
|
166
|
+
rotation_delay.register_entry('a', t0)
|
167
|
+
rotation_delay.register_entry('b', t1)
|
168
|
+
rotation_delay.register_entry('c', t2)
|
169
|
+
concat_delay.register_entry('a', t0)
|
170
|
+
concat_delay.register_entry('b', t1)
|
171
|
+
concat_delay.register_entry('c', t2)
|
172
|
+
|
173
|
+
rotation_delay.init_state()
|
174
|
+
concat_delay.init_state()
|
175
|
+
|
176
|
+
print()
|
177
|
+
for i in range(100):
|
178
|
+
brainstate.environ.set(i=i)
|
179
|
+
new = jnp.ones((1,)) * i
|
180
|
+
rotation_delay.update(new)
|
181
|
+
concat_delay.update(new)
|
182
|
+
self.assertTrue(jnp.allclose(rotation_delay.at('a'), concat_delay.at('a'), ))
|
183
|
+
self.assertTrue(jnp.allclose(rotation_delay.at('b'), concat_delay.at('b'), ))
|
184
|
+
self.assertTrue(jnp.allclose(rotation_delay.at('c'), concat_delay.at('c'), ))
|
brainstate/nn/_dynamics.py
CHANGED
@@ -33,9 +33,8 @@ For handling the delays:
|
|
33
33
|
|
34
34
|
"""
|
35
35
|
|
36
|
-
from typing import Any, Dict, Callable, Hashable, Optional, Union, TypeVar, TYPE_CHECKING
|
36
|
+
from typing import Any, Dict, Callable, Hashable, Optional, Union, TypeVar, TYPE_CHECKING, Tuple
|
37
37
|
|
38
|
-
import brainunit as u
|
39
38
|
import jax
|
40
39
|
import numpy as np
|
41
40
|
|
@@ -43,7 +42,7 @@ from brainstate import environ
|
|
43
42
|
from brainstate._state import State
|
44
43
|
from brainstate.graph import Node
|
45
44
|
from brainstate.mixin import ParamDescriber
|
46
|
-
from brainstate.typing import Size, ArrayLike
|
45
|
+
from brainstate.typing import Size, ArrayLike
|
47
46
|
from ._delay import StateWithDelay, Delay
|
48
47
|
from ._module import Module
|
49
48
|
|
@@ -794,10 +793,7 @@ class Dynamics(Module):
|
|
794
793
|
"""
|
795
794
|
return Prefetch(self, item)
|
796
795
|
|
797
|
-
def align_pre(
|
798
|
-
self,
|
799
|
-
dyn: Union[ParamDescriber[T], T]
|
800
|
-
) -> T:
|
796
|
+
def align_pre(self, dyn: Union[ParamDescriber[T], T]) -> T:
|
801
797
|
"""
|
802
798
|
Registers a dynamics module to execute after this module.
|
803
799
|
|
@@ -844,11 +840,7 @@ class Dynamics(Module):
|
|
844
840
|
else:
|
845
841
|
raise TypeError(f'The input {dyn} should be an instance of {Dynamics} or a delayed initializer.')
|
846
842
|
|
847
|
-
def prefetch_delay(
|
848
|
-
self,
|
849
|
-
state: str,
|
850
|
-
delay: Optional[ArrayLike] = None
|
851
|
-
) -> 'PrefetchDelayAt':
|
843
|
+
def prefetch_delay(self, state: str, delay_time, init: Callable = None) -> 'PrefetchDelayAt':
|
852
844
|
"""
|
853
845
|
Create a reference to a delayed state or variable in the module.
|
854
846
|
|
@@ -858,19 +850,17 @@ class Dynamics(Module):
|
|
858
850
|
|
859
851
|
Args:
|
860
852
|
state (str): The name of the state or variable to reference.
|
861
|
-
|
862
|
-
typically in time units (e.g., milliseconds).
|
853
|
+
delay_time (ArrayLike): The amount of time to delay the variable access,
|
854
|
+
typically in time units (e.g., milliseconds).
|
855
|
+
init (Callable, optional): An optional initialization function to provide
|
856
|
+
a default value if the delayed state is not yet available.
|
863
857
|
|
864
858
|
Returns:
|
865
859
|
PrefetchDelayAt: An object that provides access to the variable at the specified delay time.
|
866
860
|
"""
|
867
|
-
return self
|
861
|
+
return PrefetchDelayAt(self, state, delay_time, init=init)
|
868
862
|
|
869
|
-
def output_delay(
|
870
|
-
self,
|
871
|
-
delay: Optional[ArrayLike] = None,
|
872
|
-
variable_like: PyTree = None
|
873
|
-
) -> 'OutputDelayAt':
|
863
|
+
def output_delay(self, *delay_time) -> 'OutputDelayAt':
|
874
864
|
"""
|
875
865
|
Create a reference to the delayed output of the module.
|
876
866
|
|
@@ -881,12 +871,11 @@ class Dynamics(Module):
|
|
881
871
|
Args:
|
882
872
|
delay (Optional[ArrayLike]): The amount of time to delay the output access,
|
883
873
|
typically in time units (e.g., milliseconds). Defaults to None.
|
884
|
-
variable_like:
|
885
874
|
|
886
875
|
Returns:
|
887
876
|
OutputDelayAt: An object that provides access to the module's output at the specified delay time.
|
888
877
|
"""
|
889
|
-
return OutputDelayAt(self,
|
878
|
+
return OutputDelayAt(self, delay_time)
|
890
879
|
|
891
880
|
|
892
881
|
class Prefetch(Node):
|
@@ -1024,7 +1013,7 @@ class PrefetchDelay(Node):
|
|
1024
1013
|
self.module = module
|
1025
1014
|
self.item = item
|
1026
1015
|
|
1027
|
-
def at(self,
|
1016
|
+
def at(self, *delay_time):
|
1028
1017
|
"""
|
1029
1018
|
Specifies the delay time for accessing the variable.
|
1030
1019
|
|
@@ -1039,7 +1028,7 @@ class PrefetchDelay(Node):
|
|
1039
1028
|
PrefetchDelayAt
|
1040
1029
|
An object that provides access to the variable at the specified delay time.
|
1041
1030
|
"""
|
1042
|
-
return PrefetchDelayAt(self.module, self.item,
|
1031
|
+
return PrefetchDelayAt(self.module, self.item, delay_time)
|
1043
1032
|
|
1044
1033
|
|
1045
1034
|
class PrefetchDelayAt(Node):
|
@@ -1075,7 +1064,8 @@ class PrefetchDelayAt(Node):
|
|
1075
1064
|
self,
|
1076
1065
|
module: Dynamics,
|
1077
1066
|
item: str,
|
1078
|
-
|
1067
|
+
delay_time: Tuple,
|
1068
|
+
init: Callable = None
|
1079
1069
|
):
|
1080
1070
|
"""
|
1081
1071
|
Initialize a PrefetchDelayAt object.
|
@@ -1086,24 +1076,27 @@ class PrefetchDelayAt(Node):
|
|
1086
1076
|
The dynamics module that contains the referenced state or variable.
|
1087
1077
|
item : str
|
1088
1078
|
The name of the state or variable to access with delay.
|
1089
|
-
|
1090
|
-
The amount of time to delay access by, typically in time units.
|
1079
|
+
delay_time : Tuple
|
1080
|
+
The amount of time to delay access by, typically in time units (e.g., milliseconds).
|
1091
1081
|
"""
|
1092
1082
|
super().__init__()
|
1093
|
-
assert isinstance(module, Dynamics), ''
|
1083
|
+
assert isinstance(module, Dynamics), 'The module should be an instance of Dynamics.'
|
1094
1084
|
self.module = module
|
1095
1085
|
self.item = item
|
1096
|
-
|
1097
|
-
|
1098
|
-
|
1099
|
-
|
1100
|
-
|
1101
|
-
# register the delay
|
1086
|
+
if not isinstance(delay_time, (tuple, list)):
|
1087
|
+
delay_time = (delay_time,)
|
1088
|
+
self.delay_time = delay_time
|
1089
|
+
if len(delay_time) > 0:
|
1102
1090
|
key = _get_prefetch_delay_key(item)
|
1103
1091
|
if not module._has_after_update(key):
|
1104
|
-
module._add_after_update(
|
1092
|
+
module._add_after_update(
|
1093
|
+
key,
|
1094
|
+
not_receive_update_output(
|
1095
|
+
StateWithDelay(module, item, init=init)
|
1096
|
+
)
|
1097
|
+
)
|
1105
1098
|
self.state_delay: StateWithDelay = module._get_after_update(key)
|
1106
|
-
self.state_delay.register_delay(
|
1099
|
+
self.delay_info = self.state_delay.register_delay(*delay_time)
|
1107
1100
|
|
1108
1101
|
def __call__(self, *args, **kwargs):
|
1109
1102
|
"""
|
@@ -1114,10 +1107,10 @@ class PrefetchDelayAt(Node):
|
|
1114
1107
|
Any
|
1115
1108
|
The value of the state or variable at the specified delay time.
|
1116
1109
|
"""
|
1117
|
-
if self.
|
1110
|
+
if len(self.delay_time) == 0:
|
1118
1111
|
return _get_prefetch_item(self).value
|
1119
1112
|
else:
|
1120
|
-
return self.state_delay.retrieve_at_step(self.
|
1113
|
+
return self.state_delay.retrieve_at_step(*self.delay_info)
|
1121
1114
|
|
1122
1115
|
|
1123
1116
|
class OutputDelayAt(Node):
|
@@ -1150,32 +1143,20 @@ class OutputDelayAt(Node):
|
|
1150
1143
|
def __init__(
|
1151
1144
|
self,
|
1152
1145
|
module: Dynamics,
|
1153
|
-
|
1154
|
-
variable_like: Optional[PyTree] = None,
|
1146
|
+
delay_time: Tuple,
|
1155
1147
|
):
|
1156
1148
|
super().__init__()
|
1157
1149
|
assert isinstance(module, Dynamics), 'The module should be an instance of Dynamics.'
|
1158
1150
|
self.module = module
|
1159
|
-
dt = environ.get_dt()
|
1160
|
-
if time is None:
|
1161
|
-
time = u.math.zeros_like(dt)
|
1162
|
-
self.time = time
|
1163
|
-
self.step = u.math.asarray(time / dt, dtype=environ.ditype())
|
1164
|
-
|
1165
|
-
# register the delay
|
1166
1151
|
key = _get_output_delay_key()
|
1167
1152
|
if not module._has_after_update(key):
|
1168
|
-
delay = Delay(
|
1169
|
-
jax.ShapeDtypeStruct(module.out_size, dtype=environ.dftype()),
|
1170
|
-
time,
|
1171
|
-
take_aware_unit=True
|
1172
|
-
)
|
1153
|
+
delay = Delay(jax.ShapeDtypeStruct(module.out_size, dtype=environ.dftype()), take_aware_unit=True)
|
1173
1154
|
module._add_after_update(key, receive_update_output(delay))
|
1174
1155
|
self.out_delay: Delay = module._get_after_update(key)
|
1175
|
-
self.out_delay.register_delay(
|
1156
|
+
self.delay_info = self.out_delay.register_delay(*delay_time)
|
1176
1157
|
|
1177
1158
|
def __call__(self, *args, **kwargs):
|
1178
|
-
return self.out_delay.retrieve_at_step(self.
|
1159
|
+
return self.out_delay.retrieve_at_step(*self.delay_info)
|
1179
1160
|
|
1180
1161
|
|
1181
1162
|
def _get_prefetch_delay_key(item) -> str:
|
@@ -1242,8 +1223,9 @@ def maybe_init_prefetch(target, *args, **kwargs):
|
|
1242
1223
|
_get_prefetch_item_delay(target)
|
1243
1224
|
|
1244
1225
|
elif isinstance(target, PrefetchDelayAt):
|
1245
|
-
|
1246
|
-
delay
|
1226
|
+
pass
|
1227
|
+
# delay = _get_prefetch_item_delay(target)
|
1228
|
+
# delay.register_delay(*target.delay_time)
|
1247
1229
|
|
1248
1230
|
|
1249
1231
|
class DynamicsGroup(Module):
|
brainstate/nn/_elementwise.py
CHANGED
@@ -19,9 +19,8 @@ from typing import Optional
|
|
19
19
|
|
20
20
|
import brainunit as u
|
21
21
|
import jax.numpy as jnp
|
22
|
-
import jax.typing
|
23
22
|
|
24
|
-
from brainstate import
|
23
|
+
from brainstate import functional as F
|
25
24
|
from brainstate._state import ParamState
|
26
25
|
from brainstate.typing import ArrayLike
|
27
26
|
from ._module import ElementWiseBlock
|
brainstate/nn/_linear_mv.py
CHANGED
@@ -15,11 +15,11 @@
|
|
15
15
|
|
16
16
|
from typing import Union, Callable, Optional
|
17
17
|
|
18
|
+
import brainevent
|
18
19
|
import brainunit as u
|
19
20
|
import jax
|
20
21
|
|
21
22
|
from brainstate import init
|
22
|
-
import brainevent
|
23
23
|
from brainstate._state import ParamState
|
24
24
|
from brainstate.typing import Size, ArrayLike
|
25
25
|
from ._module import Module
|
brainstate/nn/_module_test.py
CHANGED
@@ -15,172 +15,9 @@
|
|
15
15
|
|
16
16
|
import unittest
|
17
17
|
|
18
|
-
import jax.numpy as jnp
|
19
|
-
|
20
18
|
import brainstate
|
21
19
|
|
22
20
|
|
23
|
-
class TestDelay(unittest.TestCase):
|
24
|
-
def test_delay1(self):
|
25
|
-
a = brainstate.State(brainstate.random.random(10, 20))
|
26
|
-
delay = brainstate.nn.Delay(a.value)
|
27
|
-
delay.register_entry('a', 1.)
|
28
|
-
delay.register_entry('b', 2.)
|
29
|
-
delay.register_entry('c', None)
|
30
|
-
|
31
|
-
delay.init_state()
|
32
|
-
with self.assertRaises(KeyError):
|
33
|
-
delay.register_entry('c', 10.)
|
34
|
-
|
35
|
-
def test_rotation_delay(self):
|
36
|
-
rotation_delay = brainstate.nn.Delay(jnp.ones((1,)))
|
37
|
-
t0 = 0.
|
38
|
-
t1, n1 = 1., 10
|
39
|
-
t2, n2 = 2., 20
|
40
|
-
|
41
|
-
rotation_delay.register_entry('a', t0)
|
42
|
-
rotation_delay.register_entry('b', t1)
|
43
|
-
rotation_delay.register_entry('c2', 1.9)
|
44
|
-
rotation_delay.register_entry('c', t2)
|
45
|
-
|
46
|
-
rotation_delay.init_state()
|
47
|
-
|
48
|
-
print()
|
49
|
-
# print(rotation_delay)
|
50
|
-
# print(rotation_delay.max_length)
|
51
|
-
|
52
|
-
for i in range(100):
|
53
|
-
brainstate.environ.set(i=i)
|
54
|
-
rotation_delay.update(jnp.ones((1,)) * i)
|
55
|
-
# print(i, rotation_delay.at('a'), rotation_delay.at('b'), rotation_delay.at('c2'), rotation_delay.at('c'))
|
56
|
-
self.assertTrue(jnp.allclose(rotation_delay.at('a'), jnp.ones((1,)) * i))
|
57
|
-
self.assertTrue(jnp.allclose(rotation_delay.at('b'), jnp.maximum(jnp.ones((1,)) * i - n1, 0.)))
|
58
|
-
self.assertTrue(jnp.allclose(rotation_delay.at('c'), jnp.maximum(jnp.ones((1,)) * i - n2, 0.)))
|
59
|
-
|
60
|
-
def test_concat_delay(self):
|
61
|
-
rotation_delay = brainstate.nn.Delay(jnp.ones([1]), delay_method='concat')
|
62
|
-
t0 = 0.
|
63
|
-
t1, n1 = 1., 10
|
64
|
-
t2, n2 = 2., 20
|
65
|
-
|
66
|
-
rotation_delay.register_entry('a', t0)
|
67
|
-
rotation_delay.register_entry('b', t1)
|
68
|
-
rotation_delay.register_entry('c', t2)
|
69
|
-
|
70
|
-
rotation_delay.init_state()
|
71
|
-
|
72
|
-
print()
|
73
|
-
for i in range(100):
|
74
|
-
brainstate.environ.set(i=i)
|
75
|
-
rotation_delay.update(jnp.ones((1,)) * i)
|
76
|
-
print(i, rotation_delay.at('a'), rotation_delay.at('b'), rotation_delay.at('c'))
|
77
|
-
self.assertTrue(jnp.allclose(rotation_delay.at('a'), jnp.ones((1,)) * i))
|
78
|
-
self.assertTrue(jnp.allclose(rotation_delay.at('b'), jnp.maximum(jnp.ones((1,)) * i - n1, 0.)))
|
79
|
-
self.assertTrue(jnp.allclose(rotation_delay.at('c'), jnp.maximum(jnp.ones((1,)) * i - n2, 0.)))
|
80
|
-
# brainstate.util.clear_buffer_memory()
|
81
|
-
|
82
|
-
def test_jit_erro(self):
|
83
|
-
rotation_delay = brainstate.nn.Delay(jnp.ones([1]), time=2., delay_method='concat', interp_method='round')
|
84
|
-
rotation_delay.init_state()
|
85
|
-
|
86
|
-
with brainstate.environ.context(i=0, t=0, jit_error_check=True):
|
87
|
-
rotation_delay.retrieve_at_time(-2.0)
|
88
|
-
with self.assertRaises(Exception):
|
89
|
-
rotation_delay.retrieve_at_time(-2.1)
|
90
|
-
rotation_delay.retrieve_at_time(-2.01)
|
91
|
-
with self.assertRaises(Exception):
|
92
|
-
rotation_delay.retrieve_at_time(-2.09)
|
93
|
-
with self.assertRaises(Exception):
|
94
|
-
rotation_delay.retrieve_at_time(0.1)
|
95
|
-
with self.assertRaises(Exception):
|
96
|
-
rotation_delay.retrieve_at_time(0.01)
|
97
|
-
|
98
|
-
def test_round_interp(self):
|
99
|
-
for shape in [(1,), (1, 1), (1, 1, 1)]:
|
100
|
-
for delay_method in ['rotation', 'concat']:
|
101
|
-
rotation_delay = brainstate.nn.Delay(jnp.ones(shape), time=2., delay_method=delay_method,
|
102
|
-
interp_method='round')
|
103
|
-
t0, n1 = 0.01, 0
|
104
|
-
t1, n1 = 1.04, 10
|
105
|
-
t2, n2 = 1.06, 11
|
106
|
-
rotation_delay.init_state()
|
107
|
-
|
108
|
-
@brainstate.compile.jit
|
109
|
-
def retrieve(td, i):
|
110
|
-
with brainstate.environ.context(i=i, t=i * brainstate.environ.get_dt()):
|
111
|
-
return rotation_delay.retrieve_at_time(td)
|
112
|
-
|
113
|
-
print()
|
114
|
-
for i in range(100):
|
115
|
-
t = i * brainstate.environ.get_dt()
|
116
|
-
with brainstate.environ.context(i=i, t=t):
|
117
|
-
rotation_delay.update(jnp.ones(shape) * i)
|
118
|
-
print(i,
|
119
|
-
retrieve(t - t0, i),
|
120
|
-
retrieve(t - t1, i),
|
121
|
-
retrieve(t - t2, i))
|
122
|
-
self.assertTrue(jnp.allclose(retrieve(t - t0, i), jnp.ones(shape) * i))
|
123
|
-
self.assertTrue(jnp.allclose(retrieve(t - t1, i), jnp.maximum(jnp.ones(shape) * i - n1, 0.)))
|
124
|
-
self.assertTrue(jnp.allclose(retrieve(t - t2, i), jnp.maximum(jnp.ones(shape) * i - n2, 0.)))
|
125
|
-
|
126
|
-
def test_linear_interp(self):
|
127
|
-
for shape in [(1,), (1, 1), (1, 1, 1)]:
|
128
|
-
for delay_method in ['rotation', 'concat']:
|
129
|
-
print(shape, delay_method)
|
130
|
-
|
131
|
-
rotation_delay = brainstate.nn.Delay(jnp.ones(shape), time=2., delay_method=delay_method,
|
132
|
-
interp_method='linear_interp')
|
133
|
-
t0, n0 = 0.01, 0.1
|
134
|
-
t1, n1 = 1.04, 10.4
|
135
|
-
t2, n2 = 1.06, 10.6
|
136
|
-
rotation_delay.init_state()
|
137
|
-
|
138
|
-
@brainstate.compile.jit
|
139
|
-
def retrieve(td, i):
|
140
|
-
with brainstate.environ.context(i=i, t=i * brainstate.environ.get_dt()):
|
141
|
-
return rotation_delay.retrieve_at_time(td)
|
142
|
-
|
143
|
-
print()
|
144
|
-
for i in range(100):
|
145
|
-
t = i * brainstate.environ.get_dt()
|
146
|
-
with brainstate.environ.context(i=i, t=t):
|
147
|
-
rotation_delay.update(jnp.ones(shape) * i)
|
148
|
-
print(i,
|
149
|
-
retrieve(t - t0, i),
|
150
|
-
retrieve(t - t1, i),
|
151
|
-
retrieve(t - t2, i))
|
152
|
-
self.assertTrue(jnp.allclose(retrieve(t - t0, i), jnp.maximum(jnp.ones(shape) * i - n0, 0.)))
|
153
|
-
self.assertTrue(jnp.allclose(retrieve(t - t1, i), jnp.maximum(jnp.ones(shape) * i - n1, 0.)))
|
154
|
-
self.assertTrue(jnp.allclose(retrieve(t - t2, i), jnp.maximum(jnp.ones(shape) * i - n2, 0.)))
|
155
|
-
|
156
|
-
def test_rotation_and_concat_delay(self):
|
157
|
-
rotation_delay = brainstate.nn.Delay(jnp.ones((1,)))
|
158
|
-
concat_delay = brainstate.nn.Delay(jnp.ones([1]), delay_method='concat')
|
159
|
-
t0 = 0.
|
160
|
-
t1, n1 = 1., 10
|
161
|
-
t2, n2 = 2., 20
|
162
|
-
|
163
|
-
rotation_delay.register_entry('a', t0)
|
164
|
-
rotation_delay.register_entry('b', t1)
|
165
|
-
rotation_delay.register_entry('c', t2)
|
166
|
-
concat_delay.register_entry('a', t0)
|
167
|
-
concat_delay.register_entry('b', t1)
|
168
|
-
concat_delay.register_entry('c', t2)
|
169
|
-
|
170
|
-
rotation_delay.init_state()
|
171
|
-
concat_delay.init_state()
|
172
|
-
|
173
|
-
print()
|
174
|
-
for i in range(100):
|
175
|
-
brainstate.environ.set(i=i)
|
176
|
-
new = jnp.ones((1,)) * i
|
177
|
-
rotation_delay.update(new)
|
178
|
-
concat_delay.update(new)
|
179
|
-
self.assertTrue(jnp.allclose(rotation_delay.at('a'), concat_delay.at('a'), ))
|
180
|
-
self.assertTrue(jnp.allclose(rotation_delay.at('b'), concat_delay.at('b'), ))
|
181
|
-
self.assertTrue(jnp.allclose(rotation_delay.at('c'), concat_delay.at('c'), ))
|
182
|
-
|
183
|
-
|
184
21
|
class TestModule(unittest.TestCase):
|
185
22
|
def test_states(self):
|
186
23
|
class A(brainstate.nn.Module):
|
@@ -201,8 +38,3 @@ class TestModule(unittest.TestCase):
|
|
201
38
|
print(b.states())
|
202
39
|
print(b.states(level=0))
|
203
40
|
print(b.states(level=0))
|
204
|
-
|
205
|
-
|
206
|
-
if __name__ == '__main__':
|
207
|
-
with brainstate.environ.context(dt=0.1):
|
208
|
-
unittest.main()
|
brainstate/nn/_synapse.py
CHANGED
@@ -23,7 +23,7 @@ import brainunit as u
|
|
23
23
|
from brainstate import init, environ
|
24
24
|
from brainstate._state import ShortTermState, HiddenState
|
25
25
|
from brainstate.mixin import AlignPost
|
26
|
-
from brainstate.typing import ArrayLike, Size
|
26
|
+
from brainstate.typing import ArrayLike, Size
|
27
27
|
from ._dynamics import Dynamics
|
28
28
|
from ._exp_euler import exp_euler_step
|
29
29
|
|
@@ -1,4 +1,4 @@
|
|
1
|
-
brainstate/__init__.py,sha256=
|
1
|
+
brainstate/__init__.py,sha256=bM6vJtv8DWVcxOUQ0GPpkCzmRkdTREdTutkSVv0OziQ,1496
|
2
2
|
brainstate/_compatible_import.py,sha256=LUSZlA0APWozxM8Kf9pZrM2YbwY7X3jzVVHZInaBL7Y,4630
|
3
3
|
brainstate/_state.py,sha256=PLzoYx13jIgnzyBxnktLTVPCFl6seG__aNIHS9A5Nms,60770
|
4
4
|
brainstate/_state_test.py,sha256=b6uvZdVRyC4n6-fYzmHNry1b-gJ6zE_kRSxGinqiHaw,1638
|
@@ -31,8 +31,8 @@ brainstate/compile/_loop_collect_return.py,sha256=qOYGoD2TMZd2Px6y241GWNwSjOLyaK
|
|
31
31
|
brainstate/compile/_loop_collect_return_test.py,sha256=pEJdcOthEM17q5kXYhgR6JfzzqibijD_O1VPXVe_Ml4,1808
|
32
32
|
brainstate/compile/_loop_no_collection.py,sha256=tPkSxee41VexWEALpN1uuT78BDrX3uT1FHv8WZtov4c,7549
|
33
33
|
brainstate/compile/_loop_no_collection_test.py,sha256=ivavF59xep_g9bV1SSdXd5E1f6nhc7EUBfcbgHpxbfg,1419
|
34
|
-
brainstate/compile/_make_jaxpr.py,sha256=
|
35
|
-
brainstate/compile/_make_jaxpr_test.py,sha256=
|
34
|
+
brainstate/compile/_make_jaxpr.py,sha256=BOJa3H1Fa0Esh3Al3CDaS351w45vk8xPDc9R_nk8jh8,37367
|
35
|
+
brainstate/compile/_make_jaxpr_test.py,sha256=_aRwYt932dXlcntFNYlm5asRRE97oGp4F8hrgNO63SY,5164
|
36
36
|
brainstate/compile/_progress_bar.py,sha256=AhAyyI_ckzgaj0PSj1ep1hq8rGzQLSyC2aVCYaT-e-o,7502
|
37
37
|
brainstate/compile/_unvmap.py,sha256=9S42MeTmFJa8nfBI_AjEfrAdUsDmM3KFat59O8zXIEw,4120
|
38
38
|
brainstate/compile/_util.py,sha256=QD7lvS4Zb3P2HPNIm6mG8Rl_Lk2tQj0znOCPcH95XnI,6275
|
@@ -60,12 +60,13 @@ brainstate/nn/_collective_ops_test.py,sha256=bwq0DApcsk0_2xpxMl0_e2cGKT63g5rSngp
|
|
60
60
|
brainstate/nn/_common.py,sha256=qHAOID_eeKiPUXk_ION65sbYXyF-ddH5w5BayvH8Thg,6431
|
61
61
|
brainstate/nn/_conv.py,sha256=Zk-yj34n6CkjntcM9xpMGLTxKNfWdIWsTsoGbtdL0yU,18448
|
62
62
|
brainstate/nn/_conv_test.py,sha256=2lcUTG7twkyhuyKwuBux-NgU8NU_W4Cp1-G8EyDJ_uk,8862
|
63
|
-
brainstate/nn/_delay.py,sha256=
|
63
|
+
brainstate/nn/_delay.py,sha256=_FFKQ09-NHLyYPcfuvoHzpL2zc9Twu_p3k33VDbgxKo,19523
|
64
|
+
brainstate/nn/_delay_test.py,sha256=O5eY-LTuzxLi6hjsC-lC12mhYe8xL36cWzvEwnG7h4g,8187
|
64
65
|
brainstate/nn/_dropout.py,sha256=Dq3hQrOBT6gODlDbcoag6zLGXTU_p_2MhnmVsV57Hds,17783
|
65
66
|
brainstate/nn/_dropout_test.py,sha256=L46PvC2OA7EnS4MsRhh_YnvKheHYNafOsKM8uzux_zo,4446
|
66
|
-
brainstate/nn/_dynamics.py,sha256=
|
67
|
+
brainstate/nn/_dynamics.py,sha256=QVCQqdglR1p-vMFzXvvEWfK1ltg4PJhUxSEEbbFP95s,48236
|
67
68
|
brainstate/nn/_dynamics_test.py,sha256=w7AV57LdhbBNYprdFpKq8MFSCbXKVkGgp_NbL3ANX3I,2769
|
68
|
-
brainstate/nn/_elementwise.py,sha256=
|
69
|
+
brainstate/nn/_elementwise.py,sha256=_9O4S-gBZiTmsx5lJ3WdQVLcACBmXTwPTY8pnUoK7KY,33460
|
69
70
|
brainstate/nn/_elementwise_test.py,sha256=_dd9eX2ZJ7p24ahuoapCaRTZ0g1boufXMyqHFx1d4WY,5688
|
70
71
|
brainstate/nn/_embedding.py,sha256=SaAJbgXmuJ8XlCOX9ob4yvmgh9Fk627wMguRzJMJ1H8,2138
|
71
72
|
brainstate/nn/_exp_euler.py,sha256=WTpZm-XQmsdMLNazY7wIu8eeO6pK0kRzt2lJnhEgMIk,3293
|
@@ -74,12 +75,12 @@ brainstate/nn/_fixedprob.py,sha256=KGXohiU0wZnFIQDuwiRUTFsbsr8R0p8zgi5UZDuv1Bk,1
|
|
74
75
|
brainstate/nn/_fixedprob_test.py,sha256=qbRBh-MpMtEOsg492gFu2w9-FOP9z_bXapm-Q0gLLYM,3929
|
75
76
|
brainstate/nn/_inputs.py,sha256=hMPkx9qDBpJWPshZXLF4H1QiYK1-46wntHUIlG7cT7c,20603
|
76
77
|
brainstate/nn/_linear.py,sha256=FnPxATdT66DecjTW0tUTyL6clQmN_cG8kPC3qDWUE6A,14500
|
77
|
-
brainstate/nn/_linear_mv.py,sha256=
|
78
|
+
brainstate/nn/_linear_mv.py,sha256=DyJ0OOHcvuJrBvBhV0CbmVgw-6BCg_ms-Rso9eB5e1E,2635
|
78
79
|
brainstate/nn/_linear_mv_test.py,sha256=ZCM1Zy6mImQfCfdZOGnTwkiLLPXK5yalv1Ts9sWZuPA,3864
|
79
80
|
brainstate/nn/_linear_test.py,sha256=eIS-VCR3QmXB_byO1Uexg65Pv48CBRUA_Je-UGrFVTY,2925
|
80
81
|
brainstate/nn/_ltp.py,sha256=_najNUyfaFYcOlUTm7ThJopInbos3kwJyrm-QUfI-hc,861
|
81
82
|
brainstate/nn/_module.py,sha256=jlFaoltT2F8a7cxsyu6fXp7mfIkEsk-JD3G1Xh3Ay8I,12783
|
82
|
-
brainstate/nn/_module_test.py,sha256=
|
83
|
+
brainstate/nn/_module_test.py,sha256=v4ZHwjvrEaqHzCgyR8MKIB-iNhdtaCLOc55EORXTjPg,1451
|
83
84
|
brainstate/nn/_neuron.py,sha256=2walTScvL034LS53pArDASXz6z26SSPbmCvchWWjkUU,27441
|
84
85
|
brainstate/nn/_neuron_test.py,sha256=QF8pixUqA5Oj7MrNi2NR8VAnfGpAvNpwV2mBc3e_pTY,6393
|
85
86
|
brainstate/nn/_normalizations.py,sha256=YSC1W7JaexoZ8tKcy5B-dm-_x8GuPzS_6XBfkaKpdXM,37464
|
@@ -92,7 +93,7 @@ brainstate/nn/_rate_rnns_test.py,sha256=__hhx7e6LX_1mDLLQyIi4TNCaFAWnOVSTIgwHNjz
|
|
92
93
|
brainstate/nn/_readout.py,sha256=OJjSba5Wr7dtUXqYhAv1D7BUGOI-lAmg6urxPBrZe3c,7116
|
93
94
|
brainstate/nn/_readout_test.py,sha256=L2T0-SkiACxkY_I5Pbnbmy0Zw3tbpV3l5xVzAw42f2g,2136
|
94
95
|
brainstate/nn/_stp.py,sha256=-ahDEqSp8bQsU_nUK4jks8fjMYKgIbO0v7zpyGVuXtA,8645
|
95
|
-
brainstate/nn/_synapse.py,sha256=
|
96
|
+
brainstate/nn/_synapse.py,sha256=I7npO9rg6bvzCiCNLiTqEOxUjGaDVlkOgmkT9M1V4xY,19860
|
96
97
|
brainstate/nn/_synapse_test.py,sha256=xmCWFxZUIM2YtmW5otKnADGCCK__4JpXmSYcZ3wzlQM,4994
|
97
98
|
brainstate/nn/_synaptic_projection.py,sha256=UFgzsMB1VZ9ieumwmaTIC7irLrZ4pmwfiuOYomkwXG4,17917
|
98
99
|
brainstate/nn/_synouts.py,sha256=jWQP1-qXFpdYgyUSJNFD7_bk4_-67ok36br-OzbcSXY,4524
|
@@ -124,8 +125,8 @@ brainstate/util/pretty_repr.py,sha256=7Xp7IFNUeP7cGlpvwwJyBslbQVnXEqC1I6neV1Jx1S
|
|
124
125
|
brainstate/util/pretty_table.py,sha256=uJVaamFGQ4nKP8TkEGPWXHpzjMecDo2q1Ah6XtRjdPY,108117
|
125
126
|
brainstate/util/scaling.py,sha256=U6DM-afPrLejiGqo1Nla7z4YbTBVicctsBEweurr_mk,7524
|
126
127
|
brainstate/util/struct.py,sha256=2Y_wuDFQ6ldl_H4_w0IjzAtkbHooVgdsVbnT7Z6_Efc,17528
|
127
|
-
brainstate-0.1.
|
128
|
-
brainstate-0.1.
|
129
|
-
brainstate-0.1.
|
130
|
-
brainstate-0.1.
|
131
|
-
brainstate-0.1.
|
128
|
+
brainstate-0.1.8.dist-info/LICENSE,sha256=VZe9u1jgUL2eCY6ZPOYgdb8KCblCHt8ECdbtJid6e1s,11550
|
129
|
+
brainstate-0.1.8.dist-info/METADATA,sha256=85Nudry-TozOxctIhW8zx7DUgOxPVmML3XCq7YqYDFs,4122
|
130
|
+
brainstate-0.1.8.dist-info/WHEEL,sha256=AHX6tWk3qWuce7vKLrj7lnulVHEdWoltgauo8bgCXgU,109
|
131
|
+
brainstate-0.1.8.dist-info/top_level.txt,sha256=eQbGgKn0ptx7FDWuua0V0wr4K1VHi2iOUCYo3fUQBRA,11
|
132
|
+
brainstate-0.1.8.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|