brainstate 0.1.7__py2.py3-none-any.whl → 0.1.8__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
brainstate/__init__.py CHANGED
@@ -17,7 +17,7 @@
17
17
  A ``State``-based Transformation System for Program Compilation and Augmentation
18
18
  """
19
19
 
20
- __version__ = "0.1.7"
20
+ __version__ = "0.1.8"
21
21
 
22
22
  from . import augment
23
23
  from . import compile
@@ -352,11 +352,13 @@ class StatefulFunction(PrettyObject):
352
352
  cache_key = default_cache_key
353
353
  return self.get_state_trace(cache_key).get_write_states()
354
354
 
355
- def _check_input(self, x):
355
+ def _check_input_ouput(self, x):
356
356
  if isinstance(x, State):
357
- raise ValueError(
358
- 'Inputs for brainstate transformations cannot be an instance of State. '
359
- f'But we got {x}'
357
+ x.raise_error_with_source_info(
358
+ ValueError(
359
+ 'Inputs/outputs for brainstate transformations cannot be an instance of State. '
360
+ f'But we got {x}'
361
+ )
360
362
  )
361
363
 
362
364
  def get_arg_cache_key(self, *args, **kwargs) -> Tuple:
@@ -477,9 +479,7 @@ class StatefulFunction(PrettyObject):
477
479
 
478
480
  # State instance as functional returns is not allowed.
479
481
  # Checking whether the states are returned.
480
- for leaf in jax.tree.leaves(out):
481
- if isinstance(leaf, State):
482
- leaf.raise_error_with_source_info(ValueError(f"State object is not allowed to be returned: {leaf}"))
482
+ jax.tree.map(self._check_input_ouput, out, is_leaf=lambda x: isinstance(x, State))
483
483
  return out, state_values
484
484
 
485
485
  def make_jaxpr(self, *args, return_only_write: bool = False, **kwargs):
@@ -501,7 +501,7 @@ class StatefulFunction(PrettyObject):
501
501
  cache_key = self.get_arg_cache_key(*args, **kwargs)
502
502
 
503
503
  # check input types
504
- jax.tree.map(self._check_input, (args, kwargs), is_leaf=lambda x: isinstance(x, State))
504
+ jax.tree.map(self._check_input_ouput, (args, kwargs), is_leaf=lambda x: isinstance(x, State))
505
505
 
506
506
  if cache_key not in self._cached_state_trace:
507
507
  try:
@@ -13,7 +13,6 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- from __future__ import annotations
17
16
 
18
17
  import unittest
19
18
 
@@ -132,15 +131,26 @@ class TestMakeJaxpr(unittest.TestCase):
132
131
  print(jaxpr)
133
132
  print(states)
134
133
 
134
+ def test_state_in(self):
135
+ def f(a):
136
+ return a.value
135
137
 
136
- def test_return_states():
137
- import jax.numpy
138
+ with pytest.raises(ValueError):
139
+ brainstate.compile.StatefulFunction(f).make_jaxpr(brainstate.State(1.))
138
140
 
139
- a = brainstate.State(jax.numpy.ones(3))
141
+ def test_state_out(self):
142
+ def f(a):
143
+ return brainstate.State(a)
140
144
 
141
- @brainstate.compile.jit
142
- def f():
143
- return a
145
+ with pytest.raises(ValueError):
146
+ brainstate.compile.StatefulFunction(f).make_jaxpr(1.)
144
147
 
145
- with pytest.raises(ValueError):
146
- f()
148
+ def test_return_states(self):
149
+ a = brainstate.State(jnp.ones(3))
150
+
151
+ @brainstate.compile.jit
152
+ def f():
153
+ return a
154
+
155
+ with pytest.raises(ValueError):
156
+ f()
brainstate/nn/_delay.py CHANGED
@@ -13,7 +13,6 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- import math
17
16
  import numbers
18
17
  from functools import partial
19
18
  from typing import Optional, Dict, Callable, Union, Sequence
@@ -41,34 +40,27 @@ _INTERP_LINEAR = 'linear_interp'
41
40
  _INTERP_ROUND = 'round'
42
41
 
43
42
 
44
- def _get_delay(delay_time, delay_step):
45
- if delay_time is None:
46
- if delay_step is None:
47
- return 0., 0
48
- else:
49
- assert isinstance(delay_step, int), '"delay_step" should be an integer.'
50
- if delay_step == 0:
51
- return 0., 0
52
- with jax.ensure_compile_time_eval():
53
- delay_time = delay_step * environ.get_dt()
54
- else:
55
- assert delay_step is None, '"delay_step" should be None if "delay_time" is given.'
56
- # assert isinstance(delay_time, (int, float))
57
- with jax.ensure_compile_time_eval():
58
- delay_step = delay_time / environ.get_dt()
59
- delay_step = math.ceil(float(delay_step))
43
+ def _get_delay(delay_time):
44
+ with jax.ensure_compile_time_eval():
45
+ if delay_time is None:
46
+ return 0. * environ.get_dt(), 0
47
+ delay_step = delay_time / environ.get_dt()
48
+ delay_step = jnp.ceil(delay_step).astype(environ.ditype())
60
49
  return delay_time, delay_step
61
50
 
62
51
 
63
52
  class DelayAccess(Node):
64
53
  """
65
- The delay access class.
54
+ Accessor node for a registered entry in a Delay instance.
55
+
56
+ This node holds a reference to a Delay and a named entry that was
57
+ registered on that Delay. It is used by graphs to query delayed
58
+ values by delegating to the underlying Delay instance.
66
59
 
67
60
  Args:
68
- delay: The delay instance.
69
- time: The delay time.
70
- indices: The indices of the delay data.
71
- delay_entry: The delay entry.
61
+ delay: The delay instance.
62
+ time: The delay time.
63
+ delay_entry: The delay entry.
72
64
  """
73
65
 
74
66
  __module__ = 'brainstate.nn'
@@ -78,17 +70,15 @@ class DelayAccess(Node):
78
70
  delay: 'Delay',
79
71
  time: Union[None, int, float],
80
72
  delay_entry: str,
81
- *indices,
82
73
  ):
83
74
  super().__init__()
84
75
  self.refs = {'delay': delay}
85
76
  assert isinstance(delay, Delay), 'The input delay should be an instance of Delay.'
86
77
  self._delay_entry = delay_entry
87
- delay.register_entry(self._delay_entry, time)
88
- self.indices = indices
78
+ self.delay_info = delay.register_entry(self._delay_entry, time)
89
79
 
90
80
  def update(self):
91
- return self.refs['delay'].at(self._delay_entry, *self.indices)
81
+ return self.refs['delay'].at(self._delay_entry)
92
82
 
93
83
 
94
84
  class Delay(Module):
@@ -141,17 +131,21 @@ class Delay(Module):
141
131
  self.target_info = jax.tree.map(lambda a: jax.ShapeDtypeStruct(a.shape, a.dtype), target_info)
142
132
 
143
133
  # delay method
144
- assert delay_method in [_DELAY_ROTATE, _DELAY_CONCAT], (f'Un-supported delay method {delay_method}. '
145
- f'Only support {_DELAY_ROTATE} and {_DELAY_CONCAT}')
134
+ assert delay_method in [_DELAY_ROTATE, _DELAY_CONCAT], (
135
+ f'Un-supported delay method {delay_method}. '
136
+ f'Only support {_DELAY_ROTATE} and {_DELAY_CONCAT}'
137
+ )
146
138
  self.delay_method = delay_method
147
139
 
148
140
  # interp method
149
- assert interp_method in [_INTERP_LINEAR, _INTERP_ROUND], (f'Un-supported interpolation method {interp_method}. '
150
- f'we only support: {[_INTERP_LINEAR, _INTERP_ROUND]}')
141
+ assert interp_method in [_INTERP_LINEAR, _INTERP_ROUND], (
142
+ f'Un-supported interpolation method {interp_method}. '
143
+ f'we only support: {[_INTERP_LINEAR, _INTERP_ROUND]}'
144
+ )
151
145
  self.interp_method = interp_method
152
146
 
153
147
  # delay length and time
154
- self.max_time, delay_length = _get_delay(time, None)
148
+ self.max_time, delay_length = _get_delay(time)
155
149
  self.max_length = delay_length + 1
156
150
 
157
151
  super().__init__()
@@ -169,7 +163,10 @@ class Delay(Module):
169
163
  # other info
170
164
  if entries is not None:
171
165
  for entry, delay_time in entries.items():
172
- self.register_entry(entry, delay_time)
166
+ if isinstance(delay_time, (tuple, list)):
167
+ self.register_entry(entry, *delay_time)
168
+ else:
169
+ self.register_entry(entry, delay_time)
173
170
 
174
171
  self.take_aware_unit = take_aware_unit
175
172
  self._unit = None
@@ -205,73 +202,107 @@ class Delay(Module):
205
202
  fun = partial(self._f_to_init, length=self.max_length, batch_size=batch_size)
206
203
  self.history.value = jax.tree.map(fun, self.target_info)
207
204
 
208
- def register_delay(
209
- self,
210
- delay_time: Optional[Union[int, float]] = None,
211
- delay_step: Optional[int] = None,
212
- ):
213
- if isinstance(delay_time, (np.ndarray, jax.Array)):
214
- assert delay_time.size == 1 and delay_time.ndim == 0
215
- delay_time = delay_time.item()
205
+ def register_delay(self, *delay_time):
206
+ """
207
+ Register delay times and update the maximum delay configuration.
208
+
209
+ This method processes one or more delay times, validates their format and consistency,
210
+ and updates the delay buffer size if necessary. It handles both scalar and vector
211
+ delay times, ensuring all vector delays have the same size.
216
212
 
217
- _, delay_step = _get_delay(delay_time, delay_step)
213
+ Args:
214
+ *delay_time: Variable number of delay time arguments. The first argument should be
215
+ the primary delay time (float, int, or array-like). Additional arguments are
216
+ treated as indices or secondary delay parameters. All delay times should be
217
+ non-negative numbers or arrays of the same size.
218
+
219
+ Returns:
220
+ tuple or None: If delay_time[0] is None, returns None. Otherwise, returns a tuple
221
+ containing (delay_step, *delay_time[1:]) where delay_step is the computed
222
+ delay step in integer time units, and the remaining elements are the
223
+ additional delay parameters passed in.
224
+
225
+ Raises:
226
+ AssertionError: If no delay time is provided (empty delay_time).
227
+ ValueError: If delay times have inconsistent sizes when using vector delays,
228
+ or if delay times are not scalar or 1D arrays.
229
+
230
+ Note:
231
+ - The method updates self.max_time and self.max_length if the new delay
232
+ requires a larger buffer size.
233
+ - Delay steps are computed using the current environment time step (dt).
234
+ - All delay indices (delay_time[1:]) must be integers.
235
+ - Vector delays must all have the same size as the first delay time.
236
+
237
+ Example:
238
+ >>> delay_obj.register_delay(5.0) # Register 5ms delay
239
+ >>> delay_obj.register_delay(jnp.array([2.0, 3.0]), 0, 1) # Vector delay with indices
240
+ """
241
+ assert len(delay_time) >= 1, 'You should provide at least one delay time.'
242
+ delay_size = u.math.size(delay_time[0])
243
+ for dt in delay_time[1:]:
244
+ assert jnp.issubdtype(u.math.get_dtype(dt), jnp.integer), f'The index should be integer. But got {dt}.'
245
+ for dt in delay_time:
246
+ if u.math.ndim(dt) == 0:
247
+ pass
248
+ elif u.math.ndim(dt) == 1:
249
+ if u.math.size(dt) != delay_size:
250
+ raise ValueError(
251
+ f'The delay time should be a scalar or a vector with the same size. '
252
+ f'But got {delay_time}. The delay time {dt} has size {u.math.size(dt)}'
253
+ )
254
+ else:
255
+ raise ValueError(f'The delay time should be a scalar/vector. But got {dt}.')
256
+ if delay_time[0] is None:
257
+ return None
258
+ time, delay_step = _get_delay(delay_time[0])
259
+ max_delay_step = jnp.max(delay_step)
260
+ self.max_time = u.math.max(time)
218
261
 
219
262
  # delay variable
220
- if self.max_length <= delay_step + 1:
221
- self.max_length = delay_step + 1
222
- self.max_time = delay_time
223
- return self
263
+ if self.max_length <= max_delay_step + 1:
264
+ self.max_length = max_delay_step + 1
265
+ return delay_step, *delay_time[1:]
224
266
 
225
- def register_entry(
226
- self,
227
- entry: str,
228
- delay_time: Optional[Union[int, float]] = None,
229
- delay_step: Optional[int] = None,
230
- ) -> 'Delay':
267
+ def register_entry(self, entry: str, *delay_time) -> 'Delay':
231
268
  """
232
269
  Register an entry to access the delay data.
233
270
 
234
271
  Args:
235
- entry: str. The entry to access the delay data.
236
- delay_time: The delay time of the entry (can be a float).
237
- delay_step: The delay step of the entry (must be an int). ``delat_step = delay_time / dt``.
238
-
239
- Returns:
240
- Return the self.
272
+ entry: str. The entry to access the delay data.
273
+ delay_time: The delay time of the entry, the first element is the delay time,
274
+ the second and later element is the index.
241
275
  """
242
276
  if entry in self._registered_entries:
243
- raise KeyError(f'Entry {entry} has been registered. '
244
- f'The existing delay for the key {entry} is {self._registered_entries[entry]}. '
245
- f'The new delay for the key {entry} is {delay_time}. '
246
- f'You can use another key. ')
247
-
248
- if isinstance(delay_time, (np.ndarray, jax.Array)):
249
- assert delay_time.size == 1 and delay_time.ndim == 0
250
- delay_time = delay_time.item()
277
+ raise KeyError(
278
+ f'Entry {entry} has been registered. '
279
+ f'The existing delay for the key {entry} is {self._registered_entries[entry]}. '
280
+ f'The new delay for the key {entry} is {delay_time}. '
281
+ f'You can use another key. '
282
+ )
283
+ delay_info = self.register_delay(*delay_time)
284
+ self._registered_entries[entry] = delay_info
285
+ return delay_info
251
286
 
252
- _, delay_step = _get_delay(delay_time, delay_step)
287
+ def access(self, entry: str, delay_time: Sequence) -> DelayAccess:
288
+ """
289
+ Create a DelayAccess object for a specific delay entry and delay time.
253
290
 
254
- # delay variable
255
- if self.max_length <= delay_step + 1:
256
- self.max_length = delay_step + 1
257
- self.max_time = delay_time
258
- self._registered_entries[entry] = delay_step
259
- return self
291
+ Args:
292
+ entry (str): The name of the delay entry to access.
293
+ delay_time (Sequence): The delay time or parameters associated with the entry.
260
294
 
261
- def access(
262
- self,
263
- entry: str = None,
264
- time: Sequence = None,
265
- ) -> DelayAccess:
266
- return DelayAccess(self, time, delay_entry=entry)
295
+ Returns:
296
+ DelayAccess: An object that provides access to the delay data for the specified entry and time.
297
+ """
298
+ return DelayAccess(self, delay_time, delay_entry=entry)
267
299
 
268
- def at(self, entry: str, *indices) -> ArrayLike:
300
+ def at(self, entry: str) -> ArrayLike:
269
301
  """
270
302
  Get the data at the given entry.
271
303
 
272
304
  Args:
273
305
  entry: str. The entry to access the data.
274
- *indices: The slicing indices. Not include the slice at the batch dimension.
275
306
 
276
307
  Returns:
277
308
  The data.
@@ -282,8 +313,8 @@ class Delay(Module):
282
313
  raise KeyError(f'Does not find delay entry "{entry}".')
283
314
  delay_step = self._registered_entries[entry]
284
315
  if delay_step is None:
285
- delay_step = 0
286
- return self.retrieve_at_step(delay_step, *indices)
316
+ delay_step = (0,)
317
+ return self.retrieve_at_step(*delay_step)
287
318
 
288
319
  def retrieve_at_step(self, delay_step, *indices) -> PyTree:
289
320
  """
@@ -306,8 +337,10 @@ class Delay(Module):
306
337
 
307
338
  if environ.get(environ.JIT_ERROR_CHECK, False):
308
339
  def _check_delay(delay_len):
309
- raise ValueError(f'The request delay length should be less than the '
310
- f'maximum delay {self.max_length - 1}. But we got {delay_len}')
340
+ raise ValueError(
341
+ f'The request delay length should be less than the '
342
+ f'maximum delay {self.max_length - 1}. But we got {delay_len}'
343
+ )
311
344
 
312
345
  jit_error_if(delay_step >= self.max_length, _check_delay, delay_step)
313
346
 
@@ -363,13 +396,17 @@ class Delay(Module):
363
396
 
364
397
  if environ.get(environ.JIT_ERROR_CHECK, False):
365
398
  def _check_delay(t_now, t_delay):
366
- raise ValueError(f'The request delay time should be within '
367
- f'[{t_now - self.max_time - dt}, {t_now}], '
368
- f'but we got {t_delay}')
399
+ raise ValueError(
400
+ f'The request delay time should be within '
401
+ f'[{t_now - self.max_time - dt}, {t_now}], '
402
+ f'but we got {t_delay}'
403
+ )
369
404
 
370
405
  jit_error_if(
371
- jnp.logical_or(delay_time > current_time,
372
- delay_time < current_time - self.max_time - dt),
406
+ jnp.logical_or(
407
+ delay_time > current_time,
408
+ delay_time < current_time - self.max_time - dt
409
+ ),
373
410
  _check_delay,
374
411
  current_time,
375
412
  delay_time
@@ -385,10 +422,7 @@ class Delay(Module):
385
422
  return jax.tree.map(lambda a, b: a * (1 - t_diff) + b * t_diff, data_at_t0, data_at_t1)
386
423
 
387
424
  elif self.interp_method == _INTERP_ROUND: # "round" interpolation
388
- return self.retrieve_at_step(
389
- jnp.asarray(jnp.round(float_time_step), dtype=jnp.int32),
390
- *indices
391
- )
425
+ return self.retrieve_at_step(jnp.asarray(jnp.round(float_time_step), dtype=jnp.int32), *indices)
392
426
 
393
427
  else: # raise error
394
428
  raise ValueError(f'Un-supported interpolation method {self.interp_method}, '
@@ -429,8 +463,6 @@ class Delay(Module):
429
463
  raise ValueError(f'Unknown updating method "{self.delay_method}"')
430
464
 
431
465
 
432
-
433
-
434
466
  class StateWithDelay(Delay):
435
467
  """
436
468
  A ``State`` type that defines the state in a differential equation.
@@ -440,8 +472,8 @@ class StateWithDelay(Delay):
440
472
 
441
473
  state: State # state
442
474
 
443
- def __init__(self, target: Node, item: str):
444
- super().__init__(None)
475
+ def __init__(self, target: Node, item: str, init: Callable = None):
476
+ super().__init__(None, init=init)
445
477
 
446
478
  self._target = target
447
479
  self._target_term = item
@@ -0,0 +1,184 @@
1
+ # Copyright 2025 BDP Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+
17
+ import unittest
18
+
19
+ import jax.numpy as jnp
20
+
21
+ import brainstate
22
+
23
+ brainstate.environ.set(dt=0.1)
24
+
25
+
26
+ class TestDelay(unittest.TestCase):
27
+ def test_delay1(self):
28
+ a = brainstate.State(brainstate.random.random(10, 20))
29
+ delay = brainstate.nn.Delay(a.value)
30
+ delay.register_entry('a', 1.)
31
+ delay.register_entry('b', 2.)
32
+ delay.register_entry('c', None)
33
+
34
+ delay.init_state()
35
+ with self.assertRaises(KeyError):
36
+ delay.register_entry('c', 10.)
37
+
38
+ def test_rotation_delay(self):
39
+ rotation_delay = brainstate.nn.Delay(jnp.ones((1,)))
40
+ t0 = 0.
41
+ t1, n1 = 1., 10
42
+ t2, n2 = 2., 20
43
+
44
+ rotation_delay.register_entry('a', t0)
45
+ rotation_delay.register_entry('b', t1)
46
+ rotation_delay.register_entry('c2', 1.9)
47
+ rotation_delay.register_entry('c', t2)
48
+
49
+ rotation_delay.init_state()
50
+
51
+ print()
52
+ # print(rotation_delay)
53
+ # print(rotation_delay.max_length)
54
+
55
+ for i in range(100):
56
+ brainstate.environ.set(i=i)
57
+ rotation_delay.update(jnp.ones((1,)) * i)
58
+ # print(i, rotation_delay.at('a'), rotation_delay.at('b'), rotation_delay.at('c2'), rotation_delay.at('c'))
59
+ self.assertTrue(jnp.allclose(rotation_delay.at('a'), jnp.ones((1,)) * i))
60
+ self.assertTrue(jnp.allclose(rotation_delay.at('b'), jnp.maximum(jnp.ones((1,)) * i - n1, 0.)))
61
+ self.assertTrue(jnp.allclose(rotation_delay.at('c'), jnp.maximum(jnp.ones((1,)) * i - n2, 0.)))
62
+
63
+ def test_concat_delay(self):
64
+ rotation_delay = brainstate.nn.Delay(jnp.ones([1]), delay_method='concat')
65
+ t0 = 0.
66
+ t1, n1 = 1., 10
67
+ t2, n2 = 2., 20
68
+
69
+ rotation_delay.register_entry('a', t0)
70
+ rotation_delay.register_entry('b', t1)
71
+ rotation_delay.register_entry('c', t2)
72
+
73
+ rotation_delay.init_state()
74
+
75
+ print()
76
+ for i in range(100):
77
+ brainstate.environ.set(i=i)
78
+ rotation_delay.update(jnp.ones((1,)) * i)
79
+ print(i, rotation_delay.at('a'), rotation_delay.at('b'), rotation_delay.at('c'))
80
+ self.assertTrue(jnp.allclose(rotation_delay.at('a'), jnp.ones((1,)) * i))
81
+ self.assertTrue(jnp.allclose(rotation_delay.at('b'), jnp.maximum(jnp.ones((1,)) * i - n1, 0.)))
82
+ self.assertTrue(jnp.allclose(rotation_delay.at('c'), jnp.maximum(jnp.ones((1,)) * i - n2, 0.)))
83
+ # brainstate.util.clear_buffer_memory()
84
+
85
+ def test_jit_erro(self):
86
+ rotation_delay = brainstate.nn.Delay(jnp.ones([1]), time=2., delay_method='concat', interp_method='round')
87
+ rotation_delay.init_state()
88
+
89
+ with brainstate.environ.context(i=0, t=0, jit_error_check=True):
90
+ rotation_delay.retrieve_at_time(-2.0)
91
+ with self.assertRaises(Exception):
92
+ rotation_delay.retrieve_at_time(-2.1)
93
+ rotation_delay.retrieve_at_time(-2.01)
94
+ with self.assertRaises(Exception):
95
+ rotation_delay.retrieve_at_time(-2.09)
96
+ with self.assertRaises(Exception):
97
+ rotation_delay.retrieve_at_time(0.1)
98
+ with self.assertRaises(Exception):
99
+ rotation_delay.retrieve_at_time(0.01)
100
+
101
+ def test_round_interp(self):
102
+ for shape in [(1,), (1, 1), (1, 1, 1)]:
103
+ for delay_method in ['rotation', 'concat']:
104
+ rotation_delay = brainstate.nn.Delay(jnp.ones(shape), time=2., delay_method=delay_method,
105
+ interp_method='round')
106
+ t0, n1 = 0.01, 0
107
+ t1, n1 = 1.04, 10
108
+ t2, n2 = 1.06, 11
109
+ rotation_delay.init_state()
110
+
111
+ @brainstate.compile.jit
112
+ def retrieve(td, i):
113
+ with brainstate.environ.context(i=i, t=i * brainstate.environ.get_dt()):
114
+ return rotation_delay.retrieve_at_time(td)
115
+
116
+ print()
117
+ for i in range(100):
118
+ t = i * brainstate.environ.get_dt()
119
+ with brainstate.environ.context(i=i, t=t):
120
+ rotation_delay.update(jnp.ones(shape) * i)
121
+ print(i,
122
+ retrieve(t - t0, i),
123
+ retrieve(t - t1, i),
124
+ retrieve(t - t2, i))
125
+ self.assertTrue(jnp.allclose(retrieve(t - t0, i), jnp.ones(shape) * i))
126
+ self.assertTrue(jnp.allclose(retrieve(t - t1, i), jnp.maximum(jnp.ones(shape) * i - n1, 0.)))
127
+ self.assertTrue(jnp.allclose(retrieve(t - t2, i), jnp.maximum(jnp.ones(shape) * i - n2, 0.)))
128
+
129
+ def test_linear_interp(self):
130
+ for shape in [(1,), (1, 1), (1, 1, 1)]:
131
+ for delay_method in ['rotation', 'concat']:
132
+ print(shape, delay_method)
133
+
134
+ rotation_delay = brainstate.nn.Delay(jnp.ones(shape), time=2., delay_method=delay_method,
135
+ interp_method='linear_interp')
136
+ t0, n0 = 0.01, 0.1
137
+ t1, n1 = 1.04, 10.4
138
+ t2, n2 = 1.06, 10.6
139
+ rotation_delay.init_state()
140
+
141
+ @brainstate.compile.jit
142
+ def retrieve(td, i):
143
+ with brainstate.environ.context(i=i, t=i * brainstate.environ.get_dt()):
144
+ return rotation_delay.retrieve_at_time(td)
145
+
146
+ print()
147
+ for i in range(100):
148
+ t = i * brainstate.environ.get_dt()
149
+ with brainstate.environ.context(i=i, t=t):
150
+ rotation_delay.update(jnp.ones(shape) * i)
151
+ print(i,
152
+ retrieve(t - t0, i),
153
+ retrieve(t - t1, i),
154
+ retrieve(t - t2, i))
155
+ self.assertTrue(jnp.allclose(retrieve(t - t0, i), jnp.maximum(jnp.ones(shape) * i - n0, 0.)))
156
+ self.assertTrue(jnp.allclose(retrieve(t - t1, i), jnp.maximum(jnp.ones(shape) * i - n1, 0.)))
157
+ self.assertTrue(jnp.allclose(retrieve(t - t2, i), jnp.maximum(jnp.ones(shape) * i - n2, 0.)))
158
+
159
+ def test_rotation_and_concat_delay(self):
160
+ rotation_delay = brainstate.nn.Delay(jnp.ones((1,)))
161
+ concat_delay = brainstate.nn.Delay(jnp.ones([1]), delay_method='concat')
162
+ t0 = 0.
163
+ t1, n1 = 1., 10
164
+ t2, n2 = 2., 20
165
+
166
+ rotation_delay.register_entry('a', t0)
167
+ rotation_delay.register_entry('b', t1)
168
+ rotation_delay.register_entry('c', t2)
169
+ concat_delay.register_entry('a', t0)
170
+ concat_delay.register_entry('b', t1)
171
+ concat_delay.register_entry('c', t2)
172
+
173
+ rotation_delay.init_state()
174
+ concat_delay.init_state()
175
+
176
+ print()
177
+ for i in range(100):
178
+ brainstate.environ.set(i=i)
179
+ new = jnp.ones((1,)) * i
180
+ rotation_delay.update(new)
181
+ concat_delay.update(new)
182
+ self.assertTrue(jnp.allclose(rotation_delay.at('a'), concat_delay.at('a'), ))
183
+ self.assertTrue(jnp.allclose(rotation_delay.at('b'), concat_delay.at('b'), ))
184
+ self.assertTrue(jnp.allclose(rotation_delay.at('c'), concat_delay.at('c'), ))
@@ -33,9 +33,8 @@ For handling the delays:
33
33
 
34
34
  """
35
35
 
36
- from typing import Any, Dict, Callable, Hashable, Optional, Union, TypeVar, TYPE_CHECKING
36
+ from typing import Any, Dict, Callable, Hashable, Optional, Union, TypeVar, TYPE_CHECKING, Tuple
37
37
 
38
- import brainunit as u
39
38
  import jax
40
39
  import numpy as np
41
40
 
@@ -43,7 +42,7 @@ from brainstate import environ
43
42
  from brainstate._state import State
44
43
  from brainstate.graph import Node
45
44
  from brainstate.mixin import ParamDescriber
46
- from brainstate.typing import Size, ArrayLike, PyTree
45
+ from brainstate.typing import Size, ArrayLike
47
46
  from ._delay import StateWithDelay, Delay
48
47
  from ._module import Module
49
48
 
@@ -794,10 +793,7 @@ class Dynamics(Module):
794
793
  """
795
794
  return Prefetch(self, item)
796
795
 
797
- def align_pre(
798
- self,
799
- dyn: Union[ParamDescriber[T], T]
800
- ) -> T:
796
+ def align_pre(self, dyn: Union[ParamDescriber[T], T]) -> T:
801
797
  """
802
798
  Registers a dynamics module to execute after this module.
803
799
 
@@ -844,11 +840,7 @@ class Dynamics(Module):
844
840
  else:
845
841
  raise TypeError(f'The input {dyn} should be an instance of {Dynamics} or a delayed initializer.')
846
842
 
847
- def prefetch_delay(
848
- self,
849
- state: str,
850
- delay: Optional[ArrayLike] = None
851
- ) -> 'PrefetchDelayAt':
843
+ def prefetch_delay(self, state: str, delay_time, init: Callable = None) -> 'PrefetchDelayAt':
852
844
  """
853
845
  Create a reference to a delayed state or variable in the module.
854
846
 
@@ -858,19 +850,17 @@ class Dynamics(Module):
858
850
 
859
851
  Args:
860
852
  state (str): The name of the state or variable to reference.
861
- delay (Optional[ArrayLike]): The amount of time to delay the variable access,
862
- typically in time units (e.g., milliseconds). Defaults to None.
853
+ delay_time (ArrayLike): The amount of time to delay the variable access,
854
+ typically in time units (e.g., milliseconds).
855
+ init (Callable, optional): An optional initialization function to provide
856
+ a default value if the delayed state is not yet available.
863
857
 
864
858
  Returns:
865
859
  PrefetchDelayAt: An object that provides access to the variable at the specified delay time.
866
860
  """
867
- return self.prefetch(state).delay.at(delay)
861
+ return PrefetchDelayAt(self, state, delay_time, init=init)
868
862
 
869
- def output_delay(
870
- self,
871
- delay: Optional[ArrayLike] = None,
872
- variable_like: PyTree = None
873
- ) -> 'OutputDelayAt':
863
+ def output_delay(self, *delay_time) -> 'OutputDelayAt':
874
864
  """
875
865
  Create a reference to the delayed output of the module.
876
866
 
@@ -881,12 +871,11 @@ class Dynamics(Module):
881
871
  Args:
882
872
  delay (Optional[ArrayLike]): The amount of time to delay the output access,
883
873
  typically in time units (e.g., milliseconds). Defaults to None.
884
- variable_like:
885
874
 
886
875
  Returns:
887
876
  OutputDelayAt: An object that provides access to the module's output at the specified delay time.
888
877
  """
889
- return OutputDelayAt(self, delay)
878
+ return OutputDelayAt(self, delay_time)
890
879
 
891
880
 
892
881
  class Prefetch(Node):
@@ -1024,7 +1013,7 @@ class PrefetchDelay(Node):
1024
1013
  self.module = module
1025
1014
  self.item = item
1026
1015
 
1027
- def at(self, time: ArrayLike):
1016
+ def at(self, *delay_time):
1028
1017
  """
1029
1018
  Specifies the delay time for accessing the variable.
1030
1019
 
@@ -1039,7 +1028,7 @@ class PrefetchDelay(Node):
1039
1028
  PrefetchDelayAt
1040
1029
  An object that provides access to the variable at the specified delay time.
1041
1030
  """
1042
- return PrefetchDelayAt(self.module, self.item, time)
1031
+ return PrefetchDelayAt(self.module, self.item, delay_time)
1043
1032
 
1044
1033
 
1045
1034
  class PrefetchDelayAt(Node):
@@ -1075,7 +1064,8 @@ class PrefetchDelayAt(Node):
1075
1064
  self,
1076
1065
  module: Dynamics,
1077
1066
  item: str,
1078
- time: ArrayLike = None,
1067
+ delay_time: Tuple,
1068
+ init: Callable = None
1079
1069
  ):
1080
1070
  """
1081
1071
  Initialize a PrefetchDelayAt object.
@@ -1086,24 +1076,27 @@ class PrefetchDelayAt(Node):
1086
1076
  The dynamics module that contains the referenced state or variable.
1087
1077
  item : str
1088
1078
  The name of the state or variable to access with delay.
1089
- time : ArrayLike
1090
- The amount of time to delay access by, typically in time units.
1079
+ delay_time : Tuple
1080
+ The amount of time to delay access by, typically in time units (e.g., milliseconds).
1091
1081
  """
1092
1082
  super().__init__()
1093
- assert isinstance(module, Dynamics), ''
1083
+ assert isinstance(module, Dynamics), 'The module should be an instance of Dynamics.'
1094
1084
  self.module = module
1095
1085
  self.item = item
1096
- self.time = time
1097
-
1098
- if time is not None:
1099
- self.step = u.math.asarray(time / environ.get_dt(), dtype=environ.ditype())
1100
-
1101
- # register the delay
1086
+ if not isinstance(delay_time, (tuple, list)):
1087
+ delay_time = (delay_time,)
1088
+ self.delay_time = delay_time
1089
+ if len(delay_time) > 0:
1102
1090
  key = _get_prefetch_delay_key(item)
1103
1091
  if not module._has_after_update(key):
1104
- module._add_after_update(key, not_receive_update_output(StateWithDelay(module, item)))
1092
+ module._add_after_update(
1093
+ key,
1094
+ not_receive_update_output(
1095
+ StateWithDelay(module, item, init=init)
1096
+ )
1097
+ )
1105
1098
  self.state_delay: StateWithDelay = module._get_after_update(key)
1106
- self.state_delay.register_delay(time)
1099
+ self.delay_info = self.state_delay.register_delay(*delay_time)
1107
1100
 
1108
1101
  def __call__(self, *args, **kwargs):
1109
1102
  """
@@ -1114,10 +1107,10 @@ class PrefetchDelayAt(Node):
1114
1107
  Any
1115
1108
  The value of the state or variable at the specified delay time.
1116
1109
  """
1117
- if self.time is None:
1110
+ if len(self.delay_time) == 0:
1118
1111
  return _get_prefetch_item(self).value
1119
1112
  else:
1120
- return self.state_delay.retrieve_at_step(self.step)
1113
+ return self.state_delay.retrieve_at_step(*self.delay_info)
1121
1114
 
1122
1115
 
1123
1116
  class OutputDelayAt(Node):
@@ -1150,32 +1143,20 @@ class OutputDelayAt(Node):
1150
1143
  def __init__(
1151
1144
  self,
1152
1145
  module: Dynamics,
1153
- time: Optional[ArrayLike] = None,
1154
- variable_like: Optional[PyTree] = None,
1146
+ delay_time: Tuple,
1155
1147
  ):
1156
1148
  super().__init__()
1157
1149
  assert isinstance(module, Dynamics), 'The module should be an instance of Dynamics.'
1158
1150
  self.module = module
1159
- dt = environ.get_dt()
1160
- if time is None:
1161
- time = u.math.zeros_like(dt)
1162
- self.time = time
1163
- self.step = u.math.asarray(time / dt, dtype=environ.ditype())
1164
-
1165
- # register the delay
1166
1151
  key = _get_output_delay_key()
1167
1152
  if not module._has_after_update(key):
1168
- delay = Delay(
1169
- jax.ShapeDtypeStruct(module.out_size, dtype=environ.dftype()),
1170
- time,
1171
- take_aware_unit=True
1172
- )
1153
+ delay = Delay(jax.ShapeDtypeStruct(module.out_size, dtype=environ.dftype()), take_aware_unit=True)
1173
1154
  module._add_after_update(key, receive_update_output(delay))
1174
1155
  self.out_delay: Delay = module._get_after_update(key)
1175
- self.out_delay.register_delay(time)
1156
+ self.delay_info = self.out_delay.register_delay(*delay_time)
1176
1157
 
1177
1158
  def __call__(self, *args, **kwargs):
1178
- return self.out_delay.retrieve_at_step(self.step)
1159
+ return self.out_delay.retrieve_at_step(*self.delay_info)
1179
1160
 
1180
1161
 
1181
1162
  def _get_prefetch_delay_key(item) -> str:
@@ -1242,8 +1223,9 @@ def maybe_init_prefetch(target, *args, **kwargs):
1242
1223
  _get_prefetch_item_delay(target)
1243
1224
 
1244
1225
  elif isinstance(target, PrefetchDelayAt):
1245
- delay = _get_prefetch_item_delay(target)
1246
- delay.register_delay(target.time)
1226
+ pass
1227
+ # delay = _get_prefetch_item_delay(target)
1228
+ # delay.register_delay(*target.delay_time)
1247
1229
 
1248
1230
 
1249
1231
  class DynamicsGroup(Module):
@@ -19,9 +19,8 @@ from typing import Optional
19
19
 
20
20
  import brainunit as u
21
21
  import jax.numpy as jnp
22
- import jax.typing
23
22
 
24
- from brainstate import random, functional as F
23
+ from brainstate import functional as F
25
24
  from brainstate._state import ParamState
26
25
  from brainstate.typing import ArrayLike
27
26
  from ._module import ElementWiseBlock
@@ -15,11 +15,11 @@
15
15
 
16
16
  from typing import Union, Callable, Optional
17
17
 
18
+ import brainevent
18
19
  import brainunit as u
19
20
  import jax
20
21
 
21
22
  from brainstate import init
22
- import brainevent
23
23
  from brainstate._state import ParamState
24
24
  from brainstate.typing import Size, ArrayLike
25
25
  from ._module import Module
@@ -15,172 +15,9 @@
15
15
 
16
16
  import unittest
17
17
 
18
- import jax.numpy as jnp
19
-
20
18
  import brainstate
21
19
 
22
20
 
23
- class TestDelay(unittest.TestCase):
24
- def test_delay1(self):
25
- a = brainstate.State(brainstate.random.random(10, 20))
26
- delay = brainstate.nn.Delay(a.value)
27
- delay.register_entry('a', 1.)
28
- delay.register_entry('b', 2.)
29
- delay.register_entry('c', None)
30
-
31
- delay.init_state()
32
- with self.assertRaises(KeyError):
33
- delay.register_entry('c', 10.)
34
-
35
- def test_rotation_delay(self):
36
- rotation_delay = brainstate.nn.Delay(jnp.ones((1,)))
37
- t0 = 0.
38
- t1, n1 = 1., 10
39
- t2, n2 = 2., 20
40
-
41
- rotation_delay.register_entry('a', t0)
42
- rotation_delay.register_entry('b', t1)
43
- rotation_delay.register_entry('c2', 1.9)
44
- rotation_delay.register_entry('c', t2)
45
-
46
- rotation_delay.init_state()
47
-
48
- print()
49
- # print(rotation_delay)
50
- # print(rotation_delay.max_length)
51
-
52
- for i in range(100):
53
- brainstate.environ.set(i=i)
54
- rotation_delay.update(jnp.ones((1,)) * i)
55
- # print(i, rotation_delay.at('a'), rotation_delay.at('b'), rotation_delay.at('c2'), rotation_delay.at('c'))
56
- self.assertTrue(jnp.allclose(rotation_delay.at('a'), jnp.ones((1,)) * i))
57
- self.assertTrue(jnp.allclose(rotation_delay.at('b'), jnp.maximum(jnp.ones((1,)) * i - n1, 0.)))
58
- self.assertTrue(jnp.allclose(rotation_delay.at('c'), jnp.maximum(jnp.ones((1,)) * i - n2, 0.)))
59
-
60
- def test_concat_delay(self):
61
- rotation_delay = brainstate.nn.Delay(jnp.ones([1]), delay_method='concat')
62
- t0 = 0.
63
- t1, n1 = 1., 10
64
- t2, n2 = 2., 20
65
-
66
- rotation_delay.register_entry('a', t0)
67
- rotation_delay.register_entry('b', t1)
68
- rotation_delay.register_entry('c', t2)
69
-
70
- rotation_delay.init_state()
71
-
72
- print()
73
- for i in range(100):
74
- brainstate.environ.set(i=i)
75
- rotation_delay.update(jnp.ones((1,)) * i)
76
- print(i, rotation_delay.at('a'), rotation_delay.at('b'), rotation_delay.at('c'))
77
- self.assertTrue(jnp.allclose(rotation_delay.at('a'), jnp.ones((1,)) * i))
78
- self.assertTrue(jnp.allclose(rotation_delay.at('b'), jnp.maximum(jnp.ones((1,)) * i - n1, 0.)))
79
- self.assertTrue(jnp.allclose(rotation_delay.at('c'), jnp.maximum(jnp.ones((1,)) * i - n2, 0.)))
80
- # brainstate.util.clear_buffer_memory()
81
-
82
- def test_jit_erro(self):
83
- rotation_delay = brainstate.nn.Delay(jnp.ones([1]), time=2., delay_method='concat', interp_method='round')
84
- rotation_delay.init_state()
85
-
86
- with brainstate.environ.context(i=0, t=0, jit_error_check=True):
87
- rotation_delay.retrieve_at_time(-2.0)
88
- with self.assertRaises(Exception):
89
- rotation_delay.retrieve_at_time(-2.1)
90
- rotation_delay.retrieve_at_time(-2.01)
91
- with self.assertRaises(Exception):
92
- rotation_delay.retrieve_at_time(-2.09)
93
- with self.assertRaises(Exception):
94
- rotation_delay.retrieve_at_time(0.1)
95
- with self.assertRaises(Exception):
96
- rotation_delay.retrieve_at_time(0.01)
97
-
98
- def test_round_interp(self):
99
- for shape in [(1,), (1, 1), (1, 1, 1)]:
100
- for delay_method in ['rotation', 'concat']:
101
- rotation_delay = brainstate.nn.Delay(jnp.ones(shape), time=2., delay_method=delay_method,
102
- interp_method='round')
103
- t0, n1 = 0.01, 0
104
- t1, n1 = 1.04, 10
105
- t2, n2 = 1.06, 11
106
- rotation_delay.init_state()
107
-
108
- @brainstate.compile.jit
109
- def retrieve(td, i):
110
- with brainstate.environ.context(i=i, t=i * brainstate.environ.get_dt()):
111
- return rotation_delay.retrieve_at_time(td)
112
-
113
- print()
114
- for i in range(100):
115
- t = i * brainstate.environ.get_dt()
116
- with brainstate.environ.context(i=i, t=t):
117
- rotation_delay.update(jnp.ones(shape) * i)
118
- print(i,
119
- retrieve(t - t0, i),
120
- retrieve(t - t1, i),
121
- retrieve(t - t2, i))
122
- self.assertTrue(jnp.allclose(retrieve(t - t0, i), jnp.ones(shape) * i))
123
- self.assertTrue(jnp.allclose(retrieve(t - t1, i), jnp.maximum(jnp.ones(shape) * i - n1, 0.)))
124
- self.assertTrue(jnp.allclose(retrieve(t - t2, i), jnp.maximum(jnp.ones(shape) * i - n2, 0.)))
125
-
126
- def test_linear_interp(self):
127
- for shape in [(1,), (1, 1), (1, 1, 1)]:
128
- for delay_method in ['rotation', 'concat']:
129
- print(shape, delay_method)
130
-
131
- rotation_delay = brainstate.nn.Delay(jnp.ones(shape), time=2., delay_method=delay_method,
132
- interp_method='linear_interp')
133
- t0, n0 = 0.01, 0.1
134
- t1, n1 = 1.04, 10.4
135
- t2, n2 = 1.06, 10.6
136
- rotation_delay.init_state()
137
-
138
- @brainstate.compile.jit
139
- def retrieve(td, i):
140
- with brainstate.environ.context(i=i, t=i * brainstate.environ.get_dt()):
141
- return rotation_delay.retrieve_at_time(td)
142
-
143
- print()
144
- for i in range(100):
145
- t = i * brainstate.environ.get_dt()
146
- with brainstate.environ.context(i=i, t=t):
147
- rotation_delay.update(jnp.ones(shape) * i)
148
- print(i,
149
- retrieve(t - t0, i),
150
- retrieve(t - t1, i),
151
- retrieve(t - t2, i))
152
- self.assertTrue(jnp.allclose(retrieve(t - t0, i), jnp.maximum(jnp.ones(shape) * i - n0, 0.)))
153
- self.assertTrue(jnp.allclose(retrieve(t - t1, i), jnp.maximum(jnp.ones(shape) * i - n1, 0.)))
154
- self.assertTrue(jnp.allclose(retrieve(t - t2, i), jnp.maximum(jnp.ones(shape) * i - n2, 0.)))
155
-
156
- def test_rotation_and_concat_delay(self):
157
- rotation_delay = brainstate.nn.Delay(jnp.ones((1,)))
158
- concat_delay = brainstate.nn.Delay(jnp.ones([1]), delay_method='concat')
159
- t0 = 0.
160
- t1, n1 = 1., 10
161
- t2, n2 = 2., 20
162
-
163
- rotation_delay.register_entry('a', t0)
164
- rotation_delay.register_entry('b', t1)
165
- rotation_delay.register_entry('c', t2)
166
- concat_delay.register_entry('a', t0)
167
- concat_delay.register_entry('b', t1)
168
- concat_delay.register_entry('c', t2)
169
-
170
- rotation_delay.init_state()
171
- concat_delay.init_state()
172
-
173
- print()
174
- for i in range(100):
175
- brainstate.environ.set(i=i)
176
- new = jnp.ones((1,)) * i
177
- rotation_delay.update(new)
178
- concat_delay.update(new)
179
- self.assertTrue(jnp.allclose(rotation_delay.at('a'), concat_delay.at('a'), ))
180
- self.assertTrue(jnp.allclose(rotation_delay.at('b'), concat_delay.at('b'), ))
181
- self.assertTrue(jnp.allclose(rotation_delay.at('c'), concat_delay.at('c'), ))
182
-
183
-
184
21
  class TestModule(unittest.TestCase):
185
22
  def test_states(self):
186
23
  class A(brainstate.nn.Module):
@@ -201,8 +38,3 @@ class TestModule(unittest.TestCase):
201
38
  print(b.states())
202
39
  print(b.states(level=0))
203
40
  print(b.states(level=0))
204
-
205
-
206
- if __name__ == '__main__':
207
- with brainstate.environ.context(dt=0.1):
208
- unittest.main()
brainstate/nn/_synapse.py CHANGED
@@ -23,7 +23,7 @@ import brainunit as u
23
23
  from brainstate import init, environ
24
24
  from brainstate._state import ShortTermState, HiddenState
25
25
  from brainstate.mixin import AlignPost
26
- from brainstate.typing import ArrayLike, Size, PyTree
26
+ from brainstate.typing import ArrayLike, Size
27
27
  from ._dynamics import Dynamics
28
28
  from ._exp_euler import exp_euler_step
29
29
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: brainstate
3
- Version: 0.1.7
3
+ Version: 0.1.8
4
4
  Summary: A ``State``-based Transformation System for Program Compilation and Augmentation.
5
5
  Home-page: https://github.com/chaobrain/brainstate
6
6
  Author: BrainState Developers
@@ -1,4 +1,4 @@
1
- brainstate/__init__.py,sha256=l8qmh_ah9kUi2khglrj4O89ERt0tHdVT8WnVQu1lTcY,1496
1
+ brainstate/__init__.py,sha256=bM6vJtv8DWVcxOUQ0GPpkCzmRkdTREdTutkSVv0OziQ,1496
2
2
  brainstate/_compatible_import.py,sha256=LUSZlA0APWozxM8Kf9pZrM2YbwY7X3jzVVHZInaBL7Y,4630
3
3
  brainstate/_state.py,sha256=PLzoYx13jIgnzyBxnktLTVPCFl6seG__aNIHS9A5Nms,60770
4
4
  brainstate/_state_test.py,sha256=b6uvZdVRyC4n6-fYzmHNry1b-gJ6zE_kRSxGinqiHaw,1638
@@ -31,8 +31,8 @@ brainstate/compile/_loop_collect_return.py,sha256=qOYGoD2TMZd2Px6y241GWNwSjOLyaK
31
31
  brainstate/compile/_loop_collect_return_test.py,sha256=pEJdcOthEM17q5kXYhgR6JfzzqibijD_O1VPXVe_Ml4,1808
32
32
  brainstate/compile/_loop_no_collection.py,sha256=tPkSxee41VexWEALpN1uuT78BDrX3uT1FHv8WZtov4c,7549
33
33
  brainstate/compile/_loop_no_collection_test.py,sha256=ivavF59xep_g9bV1SSdXd5E1f6nhc7EUBfcbgHpxbfg,1419
34
- brainstate/compile/_make_jaxpr.py,sha256=adQwsEOxFZQk8yr0SbB3Tjh5j3bBvZCZSh7JVf6ymCA,37387
35
- brainstate/compile/_make_jaxpr_test.py,sha256=xPusEJikMQRfoOmUFrugGcE5UpRm0giHoL_NPomN5rI,4791
34
+ brainstate/compile/_make_jaxpr.py,sha256=BOJa3H1Fa0Esh3Al3CDaS351w45vk8xPDc9R_nk8jh8,37367
35
+ brainstate/compile/_make_jaxpr_test.py,sha256=_aRwYt932dXlcntFNYlm5asRRE97oGp4F8hrgNO63SY,5164
36
36
  brainstate/compile/_progress_bar.py,sha256=AhAyyI_ckzgaj0PSj1ep1hq8rGzQLSyC2aVCYaT-e-o,7502
37
37
  brainstate/compile/_unvmap.py,sha256=9S42MeTmFJa8nfBI_AjEfrAdUsDmM3KFat59O8zXIEw,4120
38
38
  brainstate/compile/_util.py,sha256=QD7lvS4Zb3P2HPNIm6mG8Rl_Lk2tQj0znOCPcH95XnI,6275
@@ -60,12 +60,13 @@ brainstate/nn/_collective_ops_test.py,sha256=bwq0DApcsk0_2xpxMl0_e2cGKT63g5rSngp
60
60
  brainstate/nn/_common.py,sha256=qHAOID_eeKiPUXk_ION65sbYXyF-ddH5w5BayvH8Thg,6431
61
61
  brainstate/nn/_conv.py,sha256=Zk-yj34n6CkjntcM9xpMGLTxKNfWdIWsTsoGbtdL0yU,18448
62
62
  brainstate/nn/_conv_test.py,sha256=2lcUTG7twkyhuyKwuBux-NgU8NU_W4Cp1-G8EyDJ_uk,8862
63
- brainstate/nn/_delay.py,sha256=l36FBgNhfL64tM3VGOsJNTtKr44HjxxtBWMFFCm3Pks,17361
63
+ brainstate/nn/_delay.py,sha256=_FFKQ09-NHLyYPcfuvoHzpL2zc9Twu_p3k33VDbgxKo,19523
64
+ brainstate/nn/_delay_test.py,sha256=O5eY-LTuzxLi6hjsC-lC12mhYe8xL36cWzvEwnG7h4g,8187
64
65
  brainstate/nn/_dropout.py,sha256=Dq3hQrOBT6gODlDbcoag6zLGXTU_p_2MhnmVsV57Hds,17783
65
66
  brainstate/nn/_dropout_test.py,sha256=L46PvC2OA7EnS4MsRhh_YnvKheHYNafOsKM8uzux_zo,4446
66
- brainstate/nn/_dynamics.py,sha256=PTTGDqe-uRvrTZbhNGy1O4hT_wG4vxbOXyfbOx2itaw,48341
67
+ brainstate/nn/_dynamics.py,sha256=QVCQqdglR1p-vMFzXvvEWfK1ltg4PJhUxSEEbbFP95s,48236
67
68
  brainstate/nn/_dynamics_test.py,sha256=w7AV57LdhbBNYprdFpKq8MFSCbXKVkGgp_NbL3ANX3I,2769
68
- brainstate/nn/_elementwise.py,sha256=l3Wk65UDLFc3xFH1jwTXRwKAVmt3dUeG0iF_Al0w1_0,33487
69
+ brainstate/nn/_elementwise.py,sha256=_9O4S-gBZiTmsx5lJ3WdQVLcACBmXTwPTY8pnUoK7KY,33460
69
70
  brainstate/nn/_elementwise_test.py,sha256=_dd9eX2ZJ7p24ahuoapCaRTZ0g1boufXMyqHFx1d4WY,5688
70
71
  brainstate/nn/_embedding.py,sha256=SaAJbgXmuJ8XlCOX9ob4yvmgh9Fk627wMguRzJMJ1H8,2138
71
72
  brainstate/nn/_exp_euler.py,sha256=WTpZm-XQmsdMLNazY7wIu8eeO6pK0kRzt2lJnhEgMIk,3293
@@ -74,12 +75,12 @@ brainstate/nn/_fixedprob.py,sha256=KGXohiU0wZnFIQDuwiRUTFsbsr8R0p8zgi5UZDuv1Bk,1
74
75
  brainstate/nn/_fixedprob_test.py,sha256=qbRBh-MpMtEOsg492gFu2w9-FOP9z_bXapm-Q0gLLYM,3929
75
76
  brainstate/nn/_inputs.py,sha256=hMPkx9qDBpJWPshZXLF4H1QiYK1-46wntHUIlG7cT7c,20603
76
77
  brainstate/nn/_linear.py,sha256=FnPxATdT66DecjTW0tUTyL6clQmN_cG8kPC3qDWUE6A,14500
77
- brainstate/nn/_linear_mv.py,sha256=6hDXx4yPqRSa7uIsW9f9eJuy23dcXN9Mp2_lSvw8BDA,2635
78
+ brainstate/nn/_linear_mv.py,sha256=DyJ0OOHcvuJrBvBhV0CbmVgw-6BCg_ms-Rso9eB5e1E,2635
78
79
  brainstate/nn/_linear_mv_test.py,sha256=ZCM1Zy6mImQfCfdZOGnTwkiLLPXK5yalv1Ts9sWZuPA,3864
79
80
  brainstate/nn/_linear_test.py,sha256=eIS-VCR3QmXB_byO1Uexg65Pv48CBRUA_Je-UGrFVTY,2925
80
81
  brainstate/nn/_ltp.py,sha256=_najNUyfaFYcOlUTm7ThJopInbos3kwJyrm-QUfI-hc,861
81
82
  brainstate/nn/_module.py,sha256=jlFaoltT2F8a7cxsyu6fXp7mfIkEsk-JD3G1Xh3Ay8I,12783
82
- brainstate/nn/_module_test.py,sha256=sjT7t-N4ZZKPN_MujNv8bcT5uZQD-CmU5nKz9KwlBUc,8963
83
+ brainstate/nn/_module_test.py,sha256=v4ZHwjvrEaqHzCgyR8MKIB-iNhdtaCLOc55EORXTjPg,1451
83
84
  brainstate/nn/_neuron.py,sha256=2walTScvL034LS53pArDASXz6z26SSPbmCvchWWjkUU,27441
84
85
  brainstate/nn/_neuron_test.py,sha256=QF8pixUqA5Oj7MrNi2NR8VAnfGpAvNpwV2mBc3e_pTY,6393
85
86
  brainstate/nn/_normalizations.py,sha256=YSC1W7JaexoZ8tKcy5B-dm-_x8GuPzS_6XBfkaKpdXM,37464
@@ -92,7 +93,7 @@ brainstate/nn/_rate_rnns_test.py,sha256=__hhx7e6LX_1mDLLQyIi4TNCaFAWnOVSTIgwHNjz
92
93
  brainstate/nn/_readout.py,sha256=OJjSba5Wr7dtUXqYhAv1D7BUGOI-lAmg6urxPBrZe3c,7116
93
94
  brainstate/nn/_readout_test.py,sha256=L2T0-SkiACxkY_I5Pbnbmy0Zw3tbpV3l5xVzAw42f2g,2136
94
95
  brainstate/nn/_stp.py,sha256=-ahDEqSp8bQsU_nUK4jks8fjMYKgIbO0v7zpyGVuXtA,8645
95
- brainstate/nn/_synapse.py,sha256=E43DxH0PW3pa6rKn9UWkD_PxIdpMAo04Yy41gaR3KPc,19868
96
+ brainstate/nn/_synapse.py,sha256=I7npO9rg6bvzCiCNLiTqEOxUjGaDVlkOgmkT9M1V4xY,19860
96
97
  brainstate/nn/_synapse_test.py,sha256=xmCWFxZUIM2YtmW5otKnADGCCK__4JpXmSYcZ3wzlQM,4994
97
98
  brainstate/nn/_synaptic_projection.py,sha256=UFgzsMB1VZ9ieumwmaTIC7irLrZ4pmwfiuOYomkwXG4,17917
98
99
  brainstate/nn/_synouts.py,sha256=jWQP1-qXFpdYgyUSJNFD7_bk4_-67ok36br-OzbcSXY,4524
@@ -124,8 +125,8 @@ brainstate/util/pretty_repr.py,sha256=7Xp7IFNUeP7cGlpvwwJyBslbQVnXEqC1I6neV1Jx1S
124
125
  brainstate/util/pretty_table.py,sha256=uJVaamFGQ4nKP8TkEGPWXHpzjMecDo2q1Ah6XtRjdPY,108117
125
126
  brainstate/util/scaling.py,sha256=U6DM-afPrLejiGqo1Nla7z4YbTBVicctsBEweurr_mk,7524
126
127
  brainstate/util/struct.py,sha256=2Y_wuDFQ6ldl_H4_w0IjzAtkbHooVgdsVbnT7Z6_Efc,17528
127
- brainstate-0.1.7.dist-info/LICENSE,sha256=VZe9u1jgUL2eCY6ZPOYgdb8KCblCHt8ECdbtJid6e1s,11550
128
- brainstate-0.1.7.dist-info/METADATA,sha256=PrV7tP9tHHf5d8ZOUoXPa9t5Sm0WBo41thLQVjfgHuA,4122
129
- brainstate-0.1.7.dist-info/WHEEL,sha256=AHX6tWk3qWuce7vKLrj7lnulVHEdWoltgauo8bgCXgU,109
130
- brainstate-0.1.7.dist-info/top_level.txt,sha256=eQbGgKn0ptx7FDWuua0V0wr4K1VHi2iOUCYo3fUQBRA,11
131
- brainstate-0.1.7.dist-info/RECORD,,
128
+ brainstate-0.1.8.dist-info/LICENSE,sha256=VZe9u1jgUL2eCY6ZPOYgdb8KCblCHt8ECdbtJid6e1s,11550
129
+ brainstate-0.1.8.dist-info/METADATA,sha256=85Nudry-TozOxctIhW8zx7DUgOxPVmML3XCq7YqYDFs,4122
130
+ brainstate-0.1.8.dist-info/WHEEL,sha256=AHX6tWk3qWuce7vKLrj7lnulVHEdWoltgauo8bgCXgU,109
131
+ brainstate-0.1.8.dist-info/top_level.txt,sha256=eQbGgKn0ptx7FDWuua0V0wr4K1VHi2iOUCYo3fUQBRA,11
132
+ brainstate-0.1.8.dist-info/RECORD,,