brainstate 0.1.10__py2.py3-none-any.whl → 0.2.0__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +130 -19
- brainstate/_compatible_import.py +201 -9
- brainstate/_compatible_import_test.py +681 -0
- brainstate/_deprecation.py +210 -0
- brainstate/_deprecation_test.py +2319 -0
- brainstate/{util/error.py → _error.py} +10 -20
- brainstate/_state.py +94 -47
- brainstate/_state_test.py +1 -1
- brainstate/_utils.py +1 -1
- brainstate/environ.py +1279 -347
- brainstate/environ_test.py +1187 -26
- brainstate/graph/__init__.py +6 -13
- brainstate/graph/_node.py +240 -0
- brainstate/graph/_node_test.py +589 -0
- brainstate/graph/{_graph_operation.py → _operation.py} +632 -746
- brainstate/graph/_operation_test.py +1147 -0
- brainstate/mixin.py +1209 -141
- brainstate/mixin_test.py +991 -51
- brainstate/nn/__init__.py +74 -72
- brainstate/nn/_activations.py +587 -295
- brainstate/nn/_activations_test.py +109 -86
- brainstate/nn/_collective_ops.py +393 -274
- brainstate/nn/_collective_ops_test.py +746 -15
- brainstate/nn/_common.py +114 -66
- brainstate/nn/_common_test.py +154 -0
- brainstate/nn/_conv.py +1652 -143
- brainstate/nn/_conv_test.py +838 -227
- brainstate/nn/_delay.py +15 -28
- brainstate/nn/_delay_test.py +25 -20
- brainstate/nn/_dropout.py +359 -167
- brainstate/nn/_dropout_test.py +429 -52
- brainstate/nn/_dynamics.py +14 -90
- brainstate/nn/_dynamics_test.py +1 -12
- brainstate/nn/_elementwise.py +492 -313
- brainstate/nn/_elementwise_test.py +806 -145
- brainstate/nn/_embedding.py +369 -19
- brainstate/nn/_embedding_test.py +156 -0
- brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +10 -16
- brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +6 -5
- brainstate/nn/{_linear_mv.py → _event_linear.py} +2 -2
- brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +6 -5
- brainstate/nn/_exp_euler.py +200 -38
- brainstate/nn/_exp_euler_test.py +350 -8
- brainstate/nn/_linear.py +391 -71
- brainstate/nn/_linear_test.py +427 -59
- brainstate/nn/_metrics.py +1070 -0
- brainstate/nn/_metrics_test.py +611 -0
- brainstate/nn/_module.py +10 -3
- brainstate/nn/_module_test.py +1 -1
- brainstate/nn/_normalizations.py +688 -329
- brainstate/nn/_normalizations_test.py +663 -37
- brainstate/nn/_paddings.py +1020 -0
- brainstate/nn/_paddings_test.py +723 -0
- brainstate/nn/_poolings.py +1404 -342
- brainstate/nn/_poolings_test.py +828 -92
- brainstate/nn/{_rate_rnns.py → _rnns.py} +446 -54
- brainstate/nn/_rnns_test.py +593 -0
- brainstate/nn/_utils.py +132 -5
- brainstate/nn/_utils_test.py +402 -0
- brainstate/{init/_random_inits.py → nn/init.py} +301 -45
- brainstate/{init/_random_inits_test.py → nn/init_test.py} +51 -20
- brainstate/random/__init__.py +247 -1
- brainstate/random/_rand_funs.py +668 -346
- brainstate/random/_rand_funs_test.py +74 -1
- brainstate/random/_rand_seed.py +541 -76
- brainstate/random/_rand_seed_test.py +1 -1
- brainstate/random/_rand_state.py +601 -393
- brainstate/random/_rand_state_test.py +551 -0
- brainstate/transform/__init__.py +59 -0
- brainstate/transform/_ad_checkpoint.py +176 -0
- brainstate/{compile → transform}/_ad_checkpoint_test.py +1 -1
- brainstate/{augment → transform}/_autograd.py +360 -113
- brainstate/{augment → transform}/_autograd_test.py +2 -2
- brainstate/transform/_conditions.py +316 -0
- brainstate/{compile → transform}/_conditions_test.py +11 -11
- brainstate/{compile → transform}/_error_if.py +22 -20
- brainstate/{compile → transform}/_error_if_test.py +1 -1
- brainstate/transform/_eval_shape.py +145 -0
- brainstate/{augment → transform}/_eval_shape_test.py +1 -1
- brainstate/{compile → transform}/_jit.py +99 -46
- brainstate/{compile → transform}/_jit_test.py +3 -3
- brainstate/{compile → transform}/_loop_collect_return.py +219 -80
- brainstate/{compile → transform}/_loop_collect_return_test.py +1 -1
- brainstate/{compile → transform}/_loop_no_collection.py +133 -34
- brainstate/{compile → transform}/_loop_no_collection_test.py +2 -2
- brainstate/transform/_make_jaxpr.py +2016 -0
- brainstate/transform/_make_jaxpr_test.py +1510 -0
- brainstate/transform/_mapping.py +529 -0
- brainstate/transform/_mapping_test.py +194 -0
- brainstate/{compile → transform}/_progress_bar.py +78 -25
- brainstate/{augment → transform}/_random.py +65 -45
- brainstate/{compile → transform}/_unvmap.py +102 -5
- brainstate/transform/_util.py +286 -0
- brainstate/typing.py +594 -61
- brainstate/typing_test.py +780 -0
- brainstate/util/__init__.py +9 -32
- brainstate/util/_others.py +1025 -0
- brainstate/util/_others_test.py +962 -0
- brainstate/util/_pretty_pytree.py +1301 -0
- brainstate/util/_pretty_pytree_test.py +675 -0
- brainstate/util/{pretty_repr.py → _pretty_repr.py} +161 -27
- brainstate/util/_pretty_repr_test.py +696 -0
- brainstate/util/filter.py +557 -81
- brainstate/util/filter_test.py +912 -0
- brainstate/util/struct.py +769 -382
- brainstate/util/struct_test.py +602 -0
- {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/METADATA +34 -17
- brainstate-0.2.0.dist-info/RECORD +111 -0
- brainstate/augment/__init__.py +0 -30
- brainstate/augment/_eval_shape.py +0 -99
- brainstate/augment/_mapping.py +0 -1060
- brainstate/augment/_mapping_test.py +0 -597
- brainstate/compile/__init__.py +0 -38
- brainstate/compile/_ad_checkpoint.py +0 -204
- brainstate/compile/_conditions.py +0 -256
- brainstate/compile/_make_jaxpr.py +0 -888
- brainstate/compile/_make_jaxpr_test.py +0 -156
- brainstate/compile/_util.py +0 -147
- brainstate/functional/__init__.py +0 -27
- brainstate/graph/_graph_node.py +0 -244
- brainstate/graph/_graph_node_test.py +0 -73
- brainstate/graph/_graph_operation_test.py +0 -563
- brainstate/init/__init__.py +0 -26
- brainstate/init/_base.py +0 -52
- brainstate/init/_generic.py +0 -244
- brainstate/init/_regular_inits.py +0 -105
- brainstate/init/_regular_inits_test.py +0 -50
- brainstate/nn/_inputs.py +0 -608
- brainstate/nn/_ltp.py +0 -28
- brainstate/nn/_neuron.py +0 -705
- brainstate/nn/_neuron_test.py +0 -161
- brainstate/nn/_others.py +0 -46
- brainstate/nn/_projection.py +0 -486
- brainstate/nn/_rate_rnns_test.py +0 -63
- brainstate/nn/_readout.py +0 -209
- brainstate/nn/_readout_test.py +0 -53
- brainstate/nn/_stp.py +0 -236
- brainstate/nn/_synapse.py +0 -505
- brainstate/nn/_synapse_test.py +0 -131
- brainstate/nn/_synaptic_projection.py +0 -423
- brainstate/nn/_synouts.py +0 -162
- brainstate/nn/_synouts_test.py +0 -57
- brainstate/nn/metrics.py +0 -388
- brainstate/optim/__init__.py +0 -38
- brainstate/optim/_base.py +0 -64
- brainstate/optim/_lr_scheduler.py +0 -448
- brainstate/optim/_lr_scheduler_test.py +0 -50
- brainstate/optim/_optax_optimizer.py +0 -152
- brainstate/optim/_optax_optimizer_test.py +0 -53
- brainstate/optim/_sgd_optimizer.py +0 -1104
- brainstate/random/_random_for_unit.py +0 -52
- brainstate/surrogate.py +0 -1957
- brainstate/transform.py +0 -23
- brainstate/util/caller.py +0 -98
- brainstate/util/others.py +0 -540
- brainstate/util/pretty_pytree.py +0 -945
- brainstate/util/pretty_pytree_test.py +0 -159
- brainstate/util/pretty_table.py +0 -2954
- brainstate/util/scaling.py +0 -258
- brainstate-0.1.10.dist-info/RECORD +0 -130
- {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/WHEEL +0 -0
- {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/top_level.txt +0 -0
brainstate/nn/_delay.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright 2024
|
1
|
+
# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
|
2
2
|
#
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -24,8 +24,8 @@ import numpy as np
|
|
24
24
|
|
25
25
|
from brainstate import environ
|
26
26
|
from brainstate._state import ShortTermState, State
|
27
|
-
from brainstate.compile import jit_error_if
|
28
27
|
from brainstate.graph import Node
|
28
|
+
from brainstate.transform import jit_error_if
|
29
29
|
from brainstate.typing import ArrayLike, PyTree
|
30
30
|
from ._collective_ops import call_order
|
31
31
|
from ._module import Module
|
@@ -59,8 +59,8 @@ class DelayAccess(Node):
|
|
59
59
|
|
60
60
|
Args:
|
61
61
|
delay: The delay instance.
|
62
|
-
time: The delay time.
|
63
|
-
|
62
|
+
*time: The delay time.
|
63
|
+
entry: The delay entry.
|
64
64
|
"""
|
65
65
|
|
66
66
|
__module__ = 'brainstate.nn'
|
@@ -68,22 +68,22 @@ class DelayAccess(Node):
|
|
68
68
|
def __init__(
|
69
69
|
self,
|
70
70
|
delay: 'Delay',
|
71
|
-
time
|
72
|
-
|
71
|
+
*time,
|
72
|
+
entry: str,
|
73
73
|
):
|
74
74
|
super().__init__()
|
75
|
-
self.
|
75
|
+
self.delay = delay
|
76
76
|
assert isinstance(delay, Delay), 'The input delay should be an instance of Delay.'
|
77
|
-
self._delay_entry =
|
78
|
-
self.delay_info = delay.register_entry(self._delay_entry, time)
|
77
|
+
self._delay_entry = entry
|
78
|
+
self.delay_info = delay.register_entry(self._delay_entry, *time)
|
79
79
|
|
80
80
|
def update(self):
|
81
|
-
return self.
|
81
|
+
return self.delay.at(self._delay_entry)
|
82
82
|
|
83
83
|
|
84
84
|
class Delay(Module):
|
85
85
|
"""
|
86
|
-
|
86
|
+
Delay variable for storing short-term history data.
|
87
87
|
|
88
88
|
The data in this delay variable is arranged as::
|
89
89
|
|
@@ -240,21 +240,8 @@ class Delay(Module):
|
|
240
240
|
>>> delay_obj.register_delay(jnp.array([2.0, 3.0]), 0, 1) # Vector delay with indices
|
241
241
|
"""
|
242
242
|
assert len(delay_time) >= 1, 'You should provide at least one delay time.'
|
243
|
-
delay_size = u.math.size(delay_time[0])
|
244
243
|
for dt in delay_time[1:]:
|
245
244
|
assert jnp.issubdtype(u.math.get_dtype(dt), jnp.integer), f'The index should be integer. But got {dt}.'
|
246
|
-
# delay_size = u.math.size(delay_time[0])
|
247
|
-
# for dt in delay_time:
|
248
|
-
# if u.math.ndim(dt) == 0:
|
249
|
-
# pass
|
250
|
-
# elif u.math.ndim(dt) == 1:
|
251
|
-
# if u.math.size(dt) != delay_size:
|
252
|
-
# raise ValueError(
|
253
|
-
# f'The delay time should be a scalar or a vector with the same size. '
|
254
|
-
# f'But got {delay_time}. The delay time {dt} has size {u.math.size(dt)}'
|
255
|
-
# )
|
256
|
-
# else:
|
257
|
-
# raise ValueError(f'The delay time should be a scalar/vector. But got {dt}.')
|
258
245
|
if delay_time[0] is None:
|
259
246
|
return None
|
260
247
|
with jax.ensure_compile_time_eval():
|
@@ -287,7 +274,7 @@ class Delay(Module):
|
|
287
274
|
self._registered_entries[entry] = delay_info
|
288
275
|
return delay_info
|
289
276
|
|
290
|
-
def access(self, entry: str, delay_time
|
277
|
+
def access(self, entry: str, *delay_time) -> DelayAccess:
|
291
278
|
"""
|
292
279
|
Create a DelayAccess object for a specific delay entry and delay time.
|
293
280
|
|
@@ -298,7 +285,7 @@ class Delay(Module):
|
|
298
285
|
Returns:
|
299
286
|
DelayAccess: An object that provides access to the delay data for the specified entry and time.
|
300
287
|
"""
|
301
|
-
return DelayAccess(self, delay_time,
|
288
|
+
return DelayAccess(self, delay_time, entry=entry)
|
302
289
|
|
303
290
|
def at(self, entry: str) -> ArrayLike:
|
304
291
|
"""
|
@@ -532,8 +519,8 @@ class StateWithDelay(Delay):
|
|
532
519
|
Access a neuron's membrane potential 5 ms in the past:
|
533
520
|
|
534
521
|
>>> import brainunit as u
|
535
|
-
>>> import brainstate as
|
536
|
-
>>> lif =
|
522
|
+
>>> import brainstate as brainstate
|
523
|
+
>>> lif = brainstate.nn.LIF(100)
|
537
524
|
>>> # Create a delayed accessor to V(t-5ms)
|
538
525
|
>>> v_delay = lif.prefetch_delay('V', 5.0 * u.ms)
|
539
526
|
>>> # Inside another module's update you can read the delayed value
|
brainstate/nn/_delay_test.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright 2025
|
1
|
+
# Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved.
|
2
2
|
#
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -20,10 +20,14 @@ import jax.numpy as jnp
|
|
20
20
|
|
21
21
|
import brainstate
|
22
22
|
|
23
|
-
brainstate.environ.set(dt=0.1)
|
24
|
-
|
25
23
|
|
26
24
|
class TestDelay(unittest.TestCase):
|
25
|
+
def setUp(self):
|
26
|
+
brainstate.environ.set(dt=0.1)
|
27
|
+
|
28
|
+
def tearDown(self):
|
29
|
+
brainstate.environ.pop('dt')
|
30
|
+
|
27
31
|
def test_delay1(self):
|
28
32
|
a = brainstate.State(brainstate.random.random(10, 20))
|
29
33
|
delay = brainstate.nn.Delay(a.value)
|
@@ -61,26 +65,27 @@ class TestDelay(unittest.TestCase):
|
|
61
65
|
self.assertTrue(jnp.allclose(rotation_delay.at('c'), jnp.maximum(jnp.ones((1,)) * i - n2, 0.)))
|
62
66
|
|
63
67
|
def test_concat_delay(self):
|
64
|
-
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
+
with brainstate.environ.context(dt=0.1) as env:
|
69
|
+
rotation_delay = brainstate.nn.Delay(jnp.ones([1]), delay_method='concat')
|
70
|
+
t0 = 0.
|
71
|
+
t1, n1 = 1., 10
|
72
|
+
t2, n2 = 2., 20
|
68
73
|
|
69
|
-
|
70
|
-
|
71
|
-
|
74
|
+
rotation_delay.register_entry('a', t0)
|
75
|
+
rotation_delay.register_entry('b', t1)
|
76
|
+
rotation_delay.register_entry('c', t2)
|
72
77
|
|
73
|
-
|
78
|
+
rotation_delay.init_state()
|
74
79
|
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
|
80
|
+
print()
|
81
|
+
for i in range(100):
|
82
|
+
brainstate.environ.set(i=i)
|
83
|
+
rotation_delay.update(jnp.ones((1,)) * i)
|
84
|
+
print(i, rotation_delay.at('a'), rotation_delay.at('b'), rotation_delay.at('c'))
|
85
|
+
self.assertTrue(jnp.allclose(rotation_delay.at('a'), jnp.ones((1,)) * i))
|
86
|
+
self.assertTrue(jnp.allclose(rotation_delay.at('b'), jnp.maximum(jnp.ones((1,)) * i - n1, 0.)))
|
87
|
+
self.assertTrue(jnp.allclose(rotation_delay.at('c'), jnp.maximum(jnp.ones((1,)) * i - n2, 0.)))
|
88
|
+
# brainstate.util.clear_buffer_memory()
|
84
89
|
|
85
90
|
def test_jit_erro(self):
|
86
91
|
rotation_delay = brainstate.nn.Delay(jnp.ones([1]), time=2., delay_method='concat', interp_method='round')
|