brainstate 0.1.9__py2.py3-none-any.whl → 0.2.0__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +130 -19
- brainstate/_compatible_import.py +201 -9
- brainstate/_compatible_import_test.py +681 -0
- brainstate/_deprecation.py +210 -0
- brainstate/_deprecation_test.py +2319 -0
- brainstate/{util/error.py → _error.py} +10 -20
- brainstate/_state.py +94 -47
- brainstate/_state_test.py +1 -1
- brainstate/_utils.py +1 -1
- brainstate/environ.py +1279 -347
- brainstate/environ_test.py +1187 -26
- brainstate/graph/__init__.py +6 -13
- brainstate/graph/_node.py +240 -0
- brainstate/graph/_node_test.py +589 -0
- brainstate/graph/{_graph_operation.py → _operation.py} +632 -746
- brainstate/graph/_operation_test.py +1147 -0
- brainstate/mixin.py +1209 -141
- brainstate/mixin_test.py +991 -51
- brainstate/nn/__init__.py +74 -72
- brainstate/nn/_activations.py +587 -295
- brainstate/nn/_activations_test.py +109 -86
- brainstate/nn/_collective_ops.py +393 -274
- brainstate/nn/_collective_ops_test.py +746 -15
- brainstate/nn/_common.py +114 -66
- brainstate/nn/_common_test.py +154 -0
- brainstate/nn/_conv.py +1652 -143
- brainstate/nn/_conv_test.py +838 -227
- brainstate/nn/_delay.py +95 -29
- brainstate/nn/_delay_test.py +25 -20
- brainstate/nn/_dropout.py +359 -167
- brainstate/nn/_dropout_test.py +429 -52
- brainstate/nn/_dynamics.py +14 -90
- brainstate/nn/_dynamics_test.py +1 -12
- brainstate/nn/_elementwise.py +492 -313
- brainstate/nn/_elementwise_test.py +806 -145
- brainstate/nn/_embedding.py +369 -19
- brainstate/nn/_embedding_test.py +156 -0
- brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +10 -16
- brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +6 -5
- brainstate/nn/{_linear_mv.py → _event_linear.py} +2 -2
- brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +6 -5
- brainstate/nn/_exp_euler.py +200 -38
- brainstate/nn/_exp_euler_test.py +350 -8
- brainstate/nn/_linear.py +391 -71
- brainstate/nn/_linear_test.py +427 -59
- brainstate/nn/_metrics.py +1070 -0
- brainstate/nn/_metrics_test.py +611 -0
- brainstate/nn/_module.py +10 -3
- brainstate/nn/_module_test.py +1 -1
- brainstate/nn/_normalizations.py +688 -329
- brainstate/nn/_normalizations_test.py +663 -37
- brainstate/nn/_paddings.py +1020 -0
- brainstate/nn/_paddings_test.py +723 -0
- brainstate/nn/_poolings.py +1404 -342
- brainstate/nn/_poolings_test.py +828 -92
- brainstate/nn/{_rate_rnns.py → _rnns.py} +446 -54
- brainstate/nn/_rnns_test.py +593 -0
- brainstate/nn/_utils.py +132 -5
- brainstate/nn/_utils_test.py +402 -0
- brainstate/{init/_random_inits.py → nn/init.py} +301 -45
- brainstate/{init/_random_inits_test.py → nn/init_test.py} +51 -20
- brainstate/random/__init__.py +247 -1
- brainstate/random/_rand_funs.py +668 -346
- brainstate/random/_rand_funs_test.py +74 -1
- brainstate/random/_rand_seed.py +541 -76
- brainstate/random/_rand_seed_test.py +1 -1
- brainstate/random/_rand_state.py +601 -393
- brainstate/random/_rand_state_test.py +551 -0
- brainstate/transform/__init__.py +59 -0
- brainstate/transform/_ad_checkpoint.py +176 -0
- brainstate/{compile → transform}/_ad_checkpoint_test.py +1 -1
- brainstate/{augment → transform}/_autograd.py +360 -113
- brainstate/{augment → transform}/_autograd_test.py +2 -2
- brainstate/transform/_conditions.py +316 -0
- brainstate/{compile → transform}/_conditions_test.py +11 -11
- brainstate/{compile → transform}/_error_if.py +22 -20
- brainstate/{compile → transform}/_error_if_test.py +1 -1
- brainstate/transform/_eval_shape.py +145 -0
- brainstate/{augment → transform}/_eval_shape_test.py +1 -1
- brainstate/{compile → transform}/_jit.py +99 -46
- brainstate/{compile → transform}/_jit_test.py +3 -3
- brainstate/{compile → transform}/_loop_collect_return.py +219 -80
- brainstate/{compile → transform}/_loop_collect_return_test.py +1 -1
- brainstate/{compile → transform}/_loop_no_collection.py +133 -34
- brainstate/{compile → transform}/_loop_no_collection_test.py +2 -2
- brainstate/transform/_make_jaxpr.py +2016 -0
- brainstate/transform/_make_jaxpr_test.py +1510 -0
- brainstate/transform/_mapping.py +529 -0
- brainstate/transform/_mapping_test.py +194 -0
- brainstate/{compile → transform}/_progress_bar.py +78 -25
- brainstate/{augment → transform}/_random.py +65 -45
- brainstate/{compile → transform}/_unvmap.py +102 -5
- brainstate/transform/_util.py +286 -0
- brainstate/typing.py +594 -61
- brainstate/typing_test.py +780 -0
- brainstate/util/__init__.py +9 -32
- brainstate/util/_others.py +1025 -0
- brainstate/util/_others_test.py +962 -0
- brainstate/util/_pretty_pytree.py +1301 -0
- brainstate/util/_pretty_pytree_test.py +675 -0
- brainstate/util/{pretty_repr.py → _pretty_repr.py} +161 -27
- brainstate/util/_pretty_repr_test.py +696 -0
- brainstate/util/filter.py +557 -81
- brainstate/util/filter_test.py +912 -0
- brainstate/util/struct.py +769 -382
- brainstate/util/struct_test.py +602 -0
- {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/METADATA +34 -17
- brainstate-0.2.0.dist-info/RECORD +111 -0
- brainstate/augment/__init__.py +0 -30
- brainstate/augment/_eval_shape.py +0 -99
- brainstate/augment/_mapping.py +0 -1060
- brainstate/augment/_mapping_test.py +0 -597
- brainstate/compile/__init__.py +0 -38
- brainstate/compile/_ad_checkpoint.py +0 -204
- brainstate/compile/_conditions.py +0 -256
- brainstate/compile/_make_jaxpr.py +0 -888
- brainstate/compile/_make_jaxpr_test.py +0 -156
- brainstate/compile/_util.py +0 -147
- brainstate/functional/__init__.py +0 -27
- brainstate/graph/_graph_node.py +0 -244
- brainstate/graph/_graph_node_test.py +0 -73
- brainstate/graph/_graph_operation_test.py +0 -563
- brainstate/init/__init__.py +0 -26
- brainstate/init/_base.py +0 -52
- brainstate/init/_generic.py +0 -244
- brainstate/init/_regular_inits.py +0 -105
- brainstate/init/_regular_inits_test.py +0 -50
- brainstate/nn/_inputs.py +0 -608
- brainstate/nn/_ltp.py +0 -28
- brainstate/nn/_neuron.py +0 -705
- brainstate/nn/_neuron_test.py +0 -161
- brainstate/nn/_others.py +0 -46
- brainstate/nn/_projection.py +0 -486
- brainstate/nn/_rate_rnns_test.py +0 -63
- brainstate/nn/_readout.py +0 -209
- brainstate/nn/_readout_test.py +0 -53
- brainstate/nn/_stp.py +0 -236
- brainstate/nn/_synapse.py +0 -505
- brainstate/nn/_synapse_test.py +0 -131
- brainstate/nn/_synaptic_projection.py +0 -423
- brainstate/nn/_synouts.py +0 -162
- brainstate/nn/_synouts_test.py +0 -57
- brainstate/nn/metrics.py +0 -388
- brainstate/optim/__init__.py +0 -38
- brainstate/optim/_base.py +0 -64
- brainstate/optim/_lr_scheduler.py +0 -448
- brainstate/optim/_lr_scheduler_test.py +0 -50
- brainstate/optim/_optax_optimizer.py +0 -152
- brainstate/optim/_optax_optimizer_test.py +0 -53
- brainstate/optim/_sgd_optimizer.py +0 -1104
- brainstate/random/_random_for_unit.py +0 -52
- brainstate/surrogate.py +0 -1957
- brainstate/transform.py +0 -23
- brainstate/util/caller.py +0 -98
- brainstate/util/others.py +0 -540
- brainstate/util/pretty_pytree.py +0 -945
- brainstate/util/pretty_pytree_test.py +0 -159
- brainstate/util/pretty_table.py +0 -2954
- brainstate/util/scaling.py +0 -258
- brainstate-0.1.9.dist-info/RECORD +0 -130
- {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/WHEEL +0 -0
- {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/top_level.txt +0 -0
brainstate/nn/_delay.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright 2024
|
1
|
+
# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
|
2
2
|
#
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -24,8 +24,8 @@ import numpy as np
|
|
24
24
|
|
25
25
|
from brainstate import environ
|
26
26
|
from brainstate._state import ShortTermState, State
|
27
|
-
from brainstate.compile import jit_error_if
|
28
27
|
from brainstate.graph import Node
|
28
|
+
from brainstate.transform import jit_error_if
|
29
29
|
from brainstate.typing import ArrayLike, PyTree
|
30
30
|
from ._collective_ops import call_order
|
31
31
|
from ._module import Module
|
@@ -59,8 +59,8 @@ class DelayAccess(Node):
|
|
59
59
|
|
60
60
|
Args:
|
61
61
|
delay: The delay instance.
|
62
|
-
time: The delay time.
|
63
|
-
|
62
|
+
*time: The delay time.
|
63
|
+
entry: The delay entry.
|
64
64
|
"""
|
65
65
|
|
66
66
|
__module__ = 'brainstate.nn'
|
@@ -68,22 +68,22 @@ class DelayAccess(Node):
|
|
68
68
|
def __init__(
|
69
69
|
self,
|
70
70
|
delay: 'Delay',
|
71
|
-
time
|
72
|
-
|
71
|
+
*time,
|
72
|
+
entry: str,
|
73
73
|
):
|
74
74
|
super().__init__()
|
75
|
-
self.
|
75
|
+
self.delay = delay
|
76
76
|
assert isinstance(delay, Delay), 'The input delay should be an instance of Delay.'
|
77
|
-
self._delay_entry =
|
78
|
-
self.delay_info = delay.register_entry(self._delay_entry, time)
|
77
|
+
self._delay_entry = entry
|
78
|
+
self.delay_info = delay.register_entry(self._delay_entry, *time)
|
79
79
|
|
80
80
|
def update(self):
|
81
|
-
return self.
|
81
|
+
return self.delay.at(self._delay_entry)
|
82
82
|
|
83
83
|
|
84
84
|
class Delay(Module):
|
85
85
|
"""
|
86
|
-
|
86
|
+
Delay variable for storing short-term history data.
|
87
87
|
|
88
88
|
The data in this delay variable is arranged as::
|
89
89
|
|
@@ -240,21 +240,8 @@ class Delay(Module):
|
|
240
240
|
>>> delay_obj.register_delay(jnp.array([2.0, 3.0]), 0, 1) # Vector delay with indices
|
241
241
|
"""
|
242
242
|
assert len(delay_time) >= 1, 'You should provide at least one delay time.'
|
243
|
-
delay_size = u.math.size(delay_time[0])
|
244
243
|
for dt in delay_time[1:]:
|
245
244
|
assert jnp.issubdtype(u.math.get_dtype(dt), jnp.integer), f'The index should be integer. But got {dt}.'
|
246
|
-
# delay_size = u.math.size(delay_time[0])
|
247
|
-
# for dt in delay_time:
|
248
|
-
# if u.math.ndim(dt) == 0:
|
249
|
-
# pass
|
250
|
-
# elif u.math.ndim(dt) == 1:
|
251
|
-
# if u.math.size(dt) != delay_size:
|
252
|
-
# raise ValueError(
|
253
|
-
# f'The delay time should be a scalar or a vector with the same size. '
|
254
|
-
# f'But got {delay_time}. The delay time {dt} has size {u.math.size(dt)}'
|
255
|
-
# )
|
256
|
-
# else:
|
257
|
-
# raise ValueError(f'The delay time should be a scalar/vector. But got {dt}.')
|
258
245
|
if delay_time[0] is None:
|
259
246
|
return None
|
260
247
|
with jax.ensure_compile_time_eval():
|
@@ -287,7 +274,7 @@ class Delay(Module):
|
|
287
274
|
self._registered_entries[entry] = delay_info
|
288
275
|
return delay_info
|
289
276
|
|
290
|
-
def access(self, entry: str, delay_time
|
277
|
+
def access(self, entry: str, *delay_time) -> DelayAccess:
|
291
278
|
"""
|
292
279
|
Create a DelayAccess object for a specific delay entry and delay time.
|
293
280
|
|
@@ -298,7 +285,7 @@ class Delay(Module):
|
|
298
285
|
Returns:
|
299
286
|
DelayAccess: An object that provides access to the delay data for the specified entry and time.
|
300
287
|
"""
|
301
|
-
return DelayAccess(self, delay_time,
|
288
|
+
return DelayAccess(self, delay_time, entry=entry)
|
302
289
|
|
303
290
|
def at(self, entry: str) -> ArrayLike:
|
304
291
|
"""
|
@@ -472,15 +459,94 @@ class Delay(Module):
|
|
472
459
|
|
473
460
|
class StateWithDelay(Delay):
|
474
461
|
"""
|
475
|
-
|
462
|
+
Delayed history buffer bound to a module state.
|
463
|
+
|
464
|
+
StateWithDelay is a specialized :py:class:`~.Delay` that attaches to a
|
465
|
+
concrete :py:class:`~brainstate._state.State` living on a target module
|
466
|
+
(for example a membrane potential ``V`` on a neuron). It automatically
|
467
|
+
maintains a rolling history of that state and exposes convenient helpers to
|
468
|
+
retrieve the value at a given delay either by step or by time.
|
469
|
+
|
470
|
+
In normal usage you rarely instantiate this class directly. It is created
|
471
|
+
implicitly when using the prefetch-delay helpers on a Dynamics module, e.g.:
|
472
|
+
|
473
|
+
- ``module.prefetch('V').delay.at(5.0 * u.ms)``
|
474
|
+
- ``module.prefetch_delay('V', 5.0 * u.ms)``
|
475
|
+
|
476
|
+
Both will construct a StateWithDelay bound to ``module.V`` under the hood
|
477
|
+
and register the requested delay, so you can retrieve the delayed value
|
478
|
+
inside your update rules.
|
479
|
+
|
480
|
+
Parameters
|
481
|
+
----------
|
482
|
+
target : :py:class:`~brainstate.graph.Node`
|
483
|
+
The module object that owns the state to track.
|
484
|
+
item : str
|
485
|
+
The attribute name of the target state on ``target`` (must be a
|
486
|
+
:py:class:`~brainstate._state.State`).
|
487
|
+
init : Callable, optional
|
488
|
+
Optional initializer used to fill the history buffer before ``t0``
|
489
|
+
when delays request values from the past that hasn't been simulated yet.
|
490
|
+
The callable receives ``(shape, dtype)`` and must return an array.
|
491
|
+
If not provided, zeros are used. You may also pass a scalar/array
|
492
|
+
literal via the underlying Delay API when constructing manually.
|
493
|
+
delay_method : {"rotation", "concat"}, default "rotation"
|
494
|
+
Internal buffering strategy (inherits behavior from :py:class:`~.Delay`).
|
495
|
+
"rotation" keeps a ring buffer; "concat" shifts by concatenation.
|
496
|
+
|
497
|
+
Attributes
|
498
|
+
----------
|
499
|
+
state : :py:class:`~brainstate._state.State`
|
500
|
+
The concrete state object being tracked.
|
501
|
+
history : :py:class:`~brainstate._state.ShortTermState`
|
502
|
+
Rolling time axis buffer with shape ``[length, *state.shape]``.
|
503
|
+
max_time : float
|
504
|
+
Maximum time span currently supported by the buffer.
|
505
|
+
max_length : int
|
506
|
+
Buffer length in steps (``ceil(max_time/dt)+1``).
|
507
|
+
|
508
|
+
Notes
|
509
|
+
-----
|
510
|
+
- This class inherits all retrieval utilities from :py:class:`~.Delay`:
|
511
|
+
use :py:meth:`retrieve_at_step` when you know the integer delay steps,
|
512
|
+
or :py:meth:`retrieve_at_time` for continuous-time queries with optional
|
513
|
+
linear/round interpolation.
|
514
|
+
- It is registered as an "after-update" hook on the owning Dynamics so the
|
515
|
+
buffer is updated automatically after each simulation step.
|
516
|
+
|
517
|
+
Examples
|
518
|
+
--------
|
519
|
+
Access a neuron's membrane potential 5 ms in the past:
|
520
|
+
|
521
|
+
>>> import brainunit as u
|
522
|
+
>>> import brainstate as brainstate
|
523
|
+
>>> lif = brainstate.nn.LIF(100)
|
524
|
+
>>> # Create a delayed accessor to V(t-5ms)
|
525
|
+
>>> v_delay = lif.prefetch_delay('V', 5.0 * u.ms)
|
526
|
+
>>> # Inside another module's update you can read the delayed value
|
527
|
+
>>> v_t_minus_5ms = v_delay()
|
528
|
+
|
529
|
+
Register multiple delay taps and index-specific delays:
|
530
|
+
|
531
|
+
>>> # Under the hood, a StateWithDelay is created and you can register
|
532
|
+
>>> # additional taps (in steps or time) via its Delay interface
|
533
|
+
>>> _ = lif.prefetch('V').delay.at(2.0 * u.ms) # additional delay
|
534
|
+
>>> # Direct access to buffer by steps (advanced)
|
535
|
+
>>> # lif._get_after_update('V-prefetch-delay').retrieve_at_step(3)
|
476
536
|
"""
|
477
537
|
|
478
538
|
__module__ = 'brainstate.nn'
|
479
539
|
|
480
540
|
state: State # state
|
481
541
|
|
482
|
-
def __init__(
|
483
|
-
|
542
|
+
def __init__(
|
543
|
+
self,
|
544
|
+
target: Node,
|
545
|
+
item: str,
|
546
|
+
init: Callable = None,
|
547
|
+
delay_method: Optional[str] = _DELAY_ROTATE,
|
548
|
+
):
|
549
|
+
super().__init__(None, init=init, delay_method=delay_method)
|
484
550
|
|
485
551
|
self._target = target
|
486
552
|
self._target_term = item
|
brainstate/nn/_delay_test.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright 2025
|
1
|
+
# Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved.
|
2
2
|
#
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -20,10 +20,14 @@ import jax.numpy as jnp
|
|
20
20
|
|
21
21
|
import brainstate
|
22
22
|
|
23
|
-
brainstate.environ.set(dt=0.1)
|
24
|
-
|
25
23
|
|
26
24
|
class TestDelay(unittest.TestCase):
|
25
|
+
def setUp(self):
|
26
|
+
brainstate.environ.set(dt=0.1)
|
27
|
+
|
28
|
+
def tearDown(self):
|
29
|
+
brainstate.environ.pop('dt')
|
30
|
+
|
27
31
|
def test_delay1(self):
|
28
32
|
a = brainstate.State(brainstate.random.random(10, 20))
|
29
33
|
delay = brainstate.nn.Delay(a.value)
|
@@ -61,26 +65,27 @@ class TestDelay(unittest.TestCase):
|
|
61
65
|
self.assertTrue(jnp.allclose(rotation_delay.at('c'), jnp.maximum(jnp.ones((1,)) * i - n2, 0.)))
|
62
66
|
|
63
67
|
def test_concat_delay(self):
|
64
|
-
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
+
with brainstate.environ.context(dt=0.1) as env:
|
69
|
+
rotation_delay = brainstate.nn.Delay(jnp.ones([1]), delay_method='concat')
|
70
|
+
t0 = 0.
|
71
|
+
t1, n1 = 1., 10
|
72
|
+
t2, n2 = 2., 20
|
68
73
|
|
69
|
-
|
70
|
-
|
71
|
-
|
74
|
+
rotation_delay.register_entry('a', t0)
|
75
|
+
rotation_delay.register_entry('b', t1)
|
76
|
+
rotation_delay.register_entry('c', t2)
|
72
77
|
|
73
|
-
|
78
|
+
rotation_delay.init_state()
|
74
79
|
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
|
80
|
+
print()
|
81
|
+
for i in range(100):
|
82
|
+
brainstate.environ.set(i=i)
|
83
|
+
rotation_delay.update(jnp.ones((1,)) * i)
|
84
|
+
print(i, rotation_delay.at('a'), rotation_delay.at('b'), rotation_delay.at('c'))
|
85
|
+
self.assertTrue(jnp.allclose(rotation_delay.at('a'), jnp.ones((1,)) * i))
|
86
|
+
self.assertTrue(jnp.allclose(rotation_delay.at('b'), jnp.maximum(jnp.ones((1,)) * i - n1, 0.)))
|
87
|
+
self.assertTrue(jnp.allclose(rotation_delay.at('c'), jnp.maximum(jnp.ones((1,)) * i - n2, 0.)))
|
88
|
+
# brainstate.util.clear_buffer_memory()
|
84
89
|
|
85
90
|
def test_jit_erro(self):
|
86
91
|
rotation_delay = brainstate.nn.Delay(jnp.ones([1]), time=2., delay_method='concat', interp_method='round')
|