brainstate 0.1.0.post20241129__py2.py3-none-any.whl → 0.1.0.post20241210__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/_state.py +15 -1
- brainstate/augment/_mapping_test.py +84 -0
- brainstate/compile/_conditions.py +5 -7
- brainstate/compile/_jit.py +3 -3
- brainstate/compile/_loop_collect_return.py +10 -7
- brainstate/compile/_loop_no_collection.py +4 -5
- brainstate/compile/_make_jaxpr.py +30 -25
- brainstate/compile/_progress_bar.py +20 -1
- brainstate/functional/_activations.py +4 -12
- brainstate/graph/_graph_operation.py +4 -1
- brainstate/nn/_collective_ops.py +18 -2
- brainstate/nn/_dyn_impl/_dynamics_synapse.py +1 -1
- brainstate/nn/_elementwise/_dropout.py +27 -18
- brainstate/nn/_interaction/_normalizations.py +35 -25
- brainstate/util/_tracers.py +0 -7
- {brainstate-0.1.0.post20241129.dist-info → brainstate-0.1.0.post20241210.dist-info}/METADATA +1 -1
- {brainstate-0.1.0.post20241129.dist-info → brainstate-0.1.0.post20241210.dist-info}/RECORD +20 -22
- {brainstate-0.1.0.post20241129.dist-info → brainstate-0.1.0.post20241210.dist-info}/top_level.txt +0 -1
- benchmark/COBA_2005.py +0 -125
- benchmark/CUBA_2005.py +0 -149
- {brainstate-0.1.0.post20241129.dist-info → brainstate-0.1.0.post20241210.dist-info}/LICENSE +0 -0
- {brainstate-0.1.0.post20241129.dist-info → brainstate-0.1.0.post20241210.dist-info}/WHEEL +0 -0
brainstate/_state.py
CHANGED
@@ -99,6 +99,9 @@ def catch_new_states(tag: str = None) -> List:
|
|
99
99
|
|
100
100
|
|
101
101
|
class Catcher:
|
102
|
+
"""
|
103
|
+
The catcher to catch the new states.
|
104
|
+
"""
|
102
105
|
def __init__(self, tag: str):
|
103
106
|
self.tag = tag
|
104
107
|
self.state_ids = set()
|
@@ -231,6 +234,7 @@ class State(Generic[A], PrettyRepr):
|
|
231
234
|
# avoid using self._setattr to avoid the check
|
232
235
|
vars(self).update(metadata)
|
233
236
|
|
237
|
+
# record the state initialization
|
234
238
|
record_state_init(self)
|
235
239
|
|
236
240
|
if not TYPE_CHECKING:
|
@@ -290,7 +294,6 @@ class State(Generic[A], PrettyRepr):
|
|
290
294
|
v: The value.
|
291
295
|
"""
|
292
296
|
self.write_value(v)
|
293
|
-
self._been_writen = True
|
294
297
|
|
295
298
|
def write_value(self, v) -> None:
|
296
299
|
# value checking
|
@@ -301,6 +304,8 @@ class State(Generic[A], PrettyRepr):
|
|
301
304
|
record_state_value_write(self)
|
302
305
|
# set the value
|
303
306
|
self._value = v
|
307
|
+
# set flag
|
308
|
+
self._been_writen = True
|
304
309
|
|
305
310
|
def restore_value(self, v) -> None:
|
306
311
|
"""
|
@@ -511,6 +516,15 @@ class LongTermState(State):
|
|
511
516
|
__module__ = 'brainstate'
|
512
517
|
|
513
518
|
|
519
|
+
class BatchState(LongTermState):
|
520
|
+
"""
|
521
|
+
The batch state, which is used to store the batch data in the program.
|
522
|
+
"""
|
523
|
+
|
524
|
+
__module__ = 'brainstate'
|
525
|
+
|
526
|
+
|
527
|
+
|
514
528
|
class HiddenState(ShortTermState):
|
515
529
|
"""
|
516
530
|
The hidden state, which is used to store the hidden data in a dynamic model.
|
@@ -204,7 +204,91 @@ class TestVmap(unittest.TestCase):
|
|
204
204
|
self.assertTrue(jnp.allclose(foo.a.value, jnp.arange(4)))
|
205
205
|
self.assertTrue(jnp.allclose(foo.b.value, jnp.arange(4) * jnp.arange(4)))
|
206
206
|
|
207
|
+
write_state_ids = [id(st) for st in trace.get_write_states()]
|
208
|
+
read_state_ids = [id(st) for st in trace.get_read_states()]
|
209
|
+
|
210
|
+
assert id(foo.a) in read_state_ids
|
211
|
+
assert id(foo.b) in write_state_ids
|
212
|
+
|
213
|
+
print(trace.get_write_states())
|
214
|
+
print(trace.get_read_states())
|
215
|
+
|
216
|
+
|
217
|
+
|
218
|
+
def test_vmap_jit(self):
|
219
|
+
class Foo(bst.nn.Module):
|
220
|
+
def __init__(self):
|
221
|
+
super().__init__()
|
222
|
+
self.a = bst.ParamState(jnp.arange(4))
|
223
|
+
self.b = bst.ShortTermState(jnp.arange(4))
|
224
|
+
|
225
|
+
def __call__(self):
|
226
|
+
self.b.value = self.a.value * self.b.value
|
227
|
+
|
228
|
+
@bst.augment.vmap
|
229
|
+
def mul(foo):
|
230
|
+
foo()
|
231
|
+
|
232
|
+
@bst.compile.jit
|
233
|
+
def mul_jit(inp):
|
234
|
+
mul(foo)
|
235
|
+
foo.a.value += inp
|
236
|
+
|
237
|
+
foo = Foo()
|
238
|
+
with bst.StateTraceStack() as trace:
|
239
|
+
mul_jit(1.)
|
240
|
+
|
241
|
+
print(foo.a.value)
|
242
|
+
print(foo.b.value)
|
243
|
+
self.assertTrue(jnp.allclose(foo.a.value, jnp.arange(4) + 1.))
|
244
|
+
self.assertTrue(jnp.allclose(foo.b.value, jnp.arange(4) * jnp.arange(4)))
|
245
|
+
|
246
|
+
write_state_ids = [id(st) for st in trace.get_write_states()]
|
247
|
+
read_state_ids = [id(st) for st in trace.get_read_states()]
|
248
|
+
|
249
|
+
assert id(foo.a) in write_state_ids
|
250
|
+
assert id(foo.b) in write_state_ids
|
251
|
+
|
207
252
|
print(trace.get_write_states())
|
208
253
|
print(trace.get_read_states())
|
209
254
|
|
210
255
|
|
256
|
+
def test_vmap_grad(self):
|
257
|
+
class Foo(bst.nn.Module):
|
258
|
+
def __init__(self):
|
259
|
+
super().__init__()
|
260
|
+
self.a = bst.ParamState(jnp.arange(4.))
|
261
|
+
self.b = bst.ShortTermState(jnp.arange(4.))
|
262
|
+
|
263
|
+
def __call__(self):
|
264
|
+
self.b.value = self.a.value * self.b.value
|
265
|
+
|
266
|
+
@bst.augment.vmap
|
267
|
+
def mul(foo):
|
268
|
+
foo()
|
269
|
+
|
270
|
+
def loss():
|
271
|
+
mul(foo)
|
272
|
+
return jnp.sum(foo.b.value)
|
273
|
+
|
274
|
+
foo = Foo()
|
275
|
+
with bst.StateTraceStack() as trace:
|
276
|
+
grads, loss = bst.augment.grad(loss, foo.states(bst.ParamState), return_value=True)()
|
277
|
+
print(grads)
|
278
|
+
print(loss)
|
279
|
+
|
280
|
+
# print(foo.a.value)
|
281
|
+
# print(foo.b.value)
|
282
|
+
# self.assertTrue(jnp.allclose(foo.a.value, jnp.arange(4) + 1.))
|
283
|
+
# self.assertTrue(jnp.allclose(foo.b.value, jnp.arange(4) * jnp.arange(4)))
|
284
|
+
#
|
285
|
+
# write_state_ids = [id(st) for st in trace.get_write_states()]
|
286
|
+
# read_state_ids = [id(st) for st in trace.get_read_states()]
|
287
|
+
#
|
288
|
+
# assert id(foo.a) in write_state_ids
|
289
|
+
# assert id(foo.b) in write_state_ids
|
290
|
+
#
|
291
|
+
# print(trace.get_write_states())
|
292
|
+
# print(trace.get_read_states())
|
293
|
+
|
294
|
+
|
@@ -94,9 +94,8 @@ def cond(pred, true_fun: Callable, false_fun: Callable, *operands):
|
|
94
94
|
return false_fun(*operands)
|
95
95
|
|
96
96
|
# evaluate jaxpr
|
97
|
-
|
98
|
-
|
99
|
-
stateful_false = StatefulFunction(false_fun).make_jaxpr(*operands)
|
97
|
+
stateful_true = StatefulFunction(true_fun).make_jaxpr(*operands)
|
98
|
+
stateful_false = StatefulFunction(false_fun).make_jaxpr(*operands)
|
100
99
|
|
101
100
|
# state trace and state values
|
102
101
|
state_trace = stateful_true.get_state_trace() + stateful_false.get_state_trace()
|
@@ -175,10 +174,9 @@ def switch(index, branches: Sequence[Callable], *operands):
|
|
175
174
|
return branches[int(index)](*operands)
|
176
175
|
|
177
176
|
# evaluate jaxpr
|
178
|
-
|
179
|
-
|
180
|
-
|
181
|
-
wrapped_branch.make_jaxpr(*operands)
|
177
|
+
wrapped_branches = [StatefulFunction(branch) for branch in branches]
|
178
|
+
for wrapped_branch in wrapped_branches:
|
179
|
+
wrapped_branch.make_jaxpr(*operands)
|
182
180
|
|
183
181
|
# wrap the functions
|
184
182
|
state_trace = wrapped_branches[0].get_state_trace() + wrapped_branches[1].get_state_trace()
|
brainstate/compile/_jit.py
CHANGED
@@ -83,9 +83,9 @@ def _get_jitted_fun(
|
|
83
83
|
return fun.fun(*args, **params)
|
84
84
|
|
85
85
|
# compile the function and get the state trace
|
86
|
-
|
87
|
-
|
88
|
-
|
86
|
+
state_trace = fun.compile_function_and_get_state_trace(*args, **params, return_only_write=True)
|
87
|
+
read_state_vals = state_trace.get_read_state_values(True)
|
88
|
+
|
89
89
|
# call the jitted function
|
90
90
|
write_state_vals, outs = jit_fun(state_trace.get_state_values(), *args, **params)
|
91
91
|
# write the state values back to the states
|
@@ -202,16 +202,19 @@ def scan(
|
|
202
202
|
# ------------------------------ #
|
203
203
|
xs_avals = [jax.core.raise_to_shaped(jax.core.get_aval(x)) for x in xs_flat]
|
204
204
|
x_avals = [jax.core.mapped_aval(length, 0, aval) for aval in xs_avals]
|
205
|
-
|
206
|
-
|
207
|
-
|
208
|
-
|
209
|
-
|
210
|
-
wrapped_f = wrap_single_fun(stateful_fun, state_trace.been_writen, all_read_state_vals)
|
205
|
+
stateful_fun = StatefulFunction(f).make_jaxpr(init, xs_tree.unflatten(x_avals))
|
206
|
+
state_trace = stateful_fun.get_state_trace()
|
207
|
+
all_writen_state_vals = state_trace.get_write_state_values(True)
|
208
|
+
all_read_state_vals = state_trace.get_read_state_values(True)
|
209
|
+
wrapped_f = wrap_single_fun(stateful_fun, state_trace.been_writen, all_read_state_vals)
|
211
210
|
|
212
211
|
# scan
|
213
212
|
init = (all_writen_state_vals, init)
|
214
|
-
(all_writen_state_vals, carry), ys = jax.lax.scan(wrapped_f,
|
213
|
+
(all_writen_state_vals, carry), ys = jax.lax.scan(wrapped_f,
|
214
|
+
init,
|
215
|
+
xs,
|
216
|
+
length=length,
|
217
|
+
reverse=reverse,
|
215
218
|
unroll=unroll)
|
216
219
|
# assign the written state values and restore the read state values
|
217
220
|
write_back_state_values(state_trace, all_read_state_vals, all_writen_state_vals)
|
@@ -103,11 +103,10 @@ def while_loop(
|
|
103
103
|
pass
|
104
104
|
|
105
105
|
# evaluate jaxpr
|
106
|
-
|
107
|
-
|
108
|
-
|
109
|
-
|
110
|
-
raise ValueError("while_loop: cond_fun should not have any write states.")
|
106
|
+
stateful_cond = StatefulFunction(cond_fun).make_jaxpr(init_val)
|
107
|
+
stateful_body = StatefulFunction(body_fun).make_jaxpr(init_val)
|
108
|
+
if len(stateful_cond.get_write_states()) != 0:
|
109
|
+
raise ValueError("while_loop: cond_fun should not have any write states.")
|
111
110
|
|
112
111
|
# state trace and state values
|
113
112
|
state_trace = stateful_cond.get_state_trace() + stateful_body.get_state_trace()
|
@@ -72,7 +72,6 @@ from jax.util import wraps
|
|
72
72
|
from brainstate._state import State, StateTraceStack
|
73
73
|
from brainstate._utils import set_module_as
|
74
74
|
from brainstate.typing import PyTree
|
75
|
-
from brainstate.util._tracers import new_jax_trace
|
76
75
|
|
77
76
|
AxisName = Hashable
|
78
77
|
|
@@ -112,28 +111,27 @@ def _new_arg_fn(frame, trace, aval):
|
|
112
111
|
return tracer
|
113
112
|
|
114
113
|
|
115
|
-
def
|
116
|
-
|
117
|
-
frame
|
114
|
+
def _new_jax_trace():
|
115
|
+
main = jax.core.thread_local_state.trace_state.trace_stack.stack[-1]
|
116
|
+
frame = main.jaxpr_stack[-1]
|
117
|
+
trace = pe.DynamicJaxprTrace(main, jax.core.cur_sublevel())
|
118
|
+
return frame, trace
|
119
|
+
|
120
|
+
|
121
|
+
def _init_state_trace_stack() -> StateTraceStack:
|
118
122
|
state_trace: StateTraceStack = StateTraceStack()
|
119
|
-
# Set the function to transform the new argument to a tracer
|
120
|
-
state_trace.set_new_arg(functools.partial(_new_arg_fn, frame, trace))
|
121
|
-
return state_trace
|
122
123
|
|
124
|
+
if jax.__version_info__ < (0, 4, 36):
|
125
|
+
# Should be within the calling of ``jax.make_jaxpr()``
|
126
|
+
frame, trace = _new_jax_trace()
|
127
|
+
# Set the function to transform the new argument to a tracer
|
128
|
+
state_trace.set_new_arg(functools.partial(_new_arg_fn, frame, trace))
|
129
|
+
return state_trace
|
123
130
|
|
124
|
-
|
125
|
-
|
126
|
-
|
127
|
-
|
128
|
-
# Args:
|
129
|
-
# x: The input.
|
130
|
-
#
|
131
|
-
# Returns:
|
132
|
-
# The abstractified input.
|
133
|
-
# """
|
134
|
-
# if isinstance(x, pe.DynamicJaxprTracer):
|
135
|
-
# return jax.core.ShapedArray(x.aval.shape, x.aval.dtype, weak_type=x.aval.weak_type)
|
136
|
-
# return shaped_abstractify(x)
|
131
|
+
else:
|
132
|
+
trace = jax.core.trace_ctx.trace
|
133
|
+
state_trace.set_new_arg(trace.new_arg)
|
134
|
+
return state_trace
|
137
135
|
|
138
136
|
|
139
137
|
class StatefulFunction(object):
|
@@ -383,12 +381,15 @@ class StatefulFunction(object):
|
|
383
381
|
A tuple of the states that are read and written by the function and the output of the function.
|
384
382
|
"""
|
385
383
|
# state trace
|
386
|
-
state_trace =
|
384
|
+
state_trace = _init_state_trace_stack()
|
387
385
|
self._cached_state_trace[cache_key] = state_trace
|
388
386
|
with state_trace:
|
389
387
|
out = self.fun(*args, **kwargs)
|
390
|
-
state_values =
|
391
|
-
|
388
|
+
state_values = (
|
389
|
+
state_trace.get_write_state_values(True)
|
390
|
+
if return_only_write else
|
391
|
+
state_trace.get_state_values()
|
392
|
+
)
|
392
393
|
state_trace.recovery_original_values()
|
393
394
|
|
394
395
|
# State instance as functional returns is not allowed.
|
@@ -419,17 +420,21 @@ class StatefulFunction(object):
|
|
419
420
|
try:
|
420
421
|
# jaxpr
|
421
422
|
jaxpr, (out_shapes, state_shapes) = _make_jaxpr(
|
422
|
-
functools.partial(
|
423
|
+
functools.partial(
|
424
|
+
self._wrapped_fun_to_eval,
|
425
|
+
cache_key,
|
426
|
+
return_only_write=return_only_write
|
427
|
+
),
|
423
428
|
static_argnums=self.static_argnums,
|
424
429
|
axis_env=self.axis_env,
|
425
430
|
return_shape=True,
|
426
431
|
abstracted_axes=self.abstracted_axes
|
427
432
|
)(*args, **kwargs)
|
428
|
-
|
429
433
|
# returns
|
430
434
|
self._cached_jaxpr_out_tree[cache_key] = jax.tree.structure((out_shapes, state_shapes))
|
431
435
|
self._cached_out_shapes[cache_key] = (out_shapes, state_shapes)
|
432
436
|
self._cached_jaxpr[cache_key] = jaxpr
|
437
|
+
|
433
438
|
except Exception as e:
|
434
439
|
try:
|
435
440
|
self._cached_state_trace.pop(cache_key)
|
@@ -93,14 +93,32 @@ class ProgressBarRunner(object):
|
|
93
93
|
self.tqdm_bars[0].update(self.remainder)
|
94
94
|
self.tqdm_bars[0].close()
|
95
95
|
|
96
|
+
def _tqdm(self, is_init, is_print, is_final):
|
97
|
+
if is_init:
|
98
|
+
self.tqdm_bars[0] = tqdm(range(self.n), **self.kwargs)
|
99
|
+
self.tqdm_bars[0].set_description(self.message, refresh=False)
|
100
|
+
if is_print:
|
101
|
+
self.tqdm_bars[0].update(self.print_freq)
|
102
|
+
if is_final:
|
103
|
+
if self.remainder > 0:
|
104
|
+
self.tqdm_bars[0].update(self.remainder)
|
105
|
+
self.tqdm_bars[0].close()
|
106
|
+
|
96
107
|
def __call__(self, iter_num, *args, **kwargs):
|
108
|
+
# jax.debug.callback(
|
109
|
+
# self._tqdm,
|
110
|
+
# iter_num == 0,
|
111
|
+
# (iter_num + 1) % self.print_freq == 0,
|
112
|
+
# iter_num == self.n - 1
|
113
|
+
# )
|
114
|
+
|
97
115
|
_ = jax.lax.cond(
|
98
116
|
iter_num == 0,
|
99
117
|
lambda: jax.debug.callback(self._define_tqdm),
|
100
118
|
lambda: None,
|
101
119
|
)
|
102
120
|
_ = jax.lax.cond(
|
103
|
-
|
121
|
+
iter_num % self.print_freq == (self.print_freq - 1),
|
104
122
|
lambda: jax.debug.callback(self._update_tqdm),
|
105
123
|
lambda: None,
|
106
124
|
)
|
@@ -109,3 +127,4 @@ class ProgressBarRunner(object):
|
|
109
127
|
lambda: jax.debug.callback(self._close_tqdm),
|
110
128
|
lambda: None,
|
111
129
|
)
|
130
|
+
|
@@ -588,8 +588,7 @@ def glu(x: ArrayLike, axis: int = -1) -> Union[jax.Array, u.Quantity]:
|
|
588
588
|
|
589
589
|
def log_softmax(x: ArrayLike,
|
590
590
|
axis: int | tuple[int, ...] | None = -1,
|
591
|
-
where: ArrayLike | None = None,
|
592
|
-
initial: ArrayLike | None = None) -> Union[jax.Array, u.Quantity]:
|
591
|
+
where: ArrayLike | None = None) -> Union[jax.Array, u.Quantity]:
|
593
592
|
r"""Log-Softmax function.
|
594
593
|
|
595
594
|
Computes the logarithm of the :code:`softmax` function, which rescales
|
@@ -604,8 +603,6 @@ def log_softmax(x: ArrayLike,
|
|
604
603
|
axis: the axis or axes along which the :code:`log_softmax` should be
|
605
604
|
computed. Either an integer or a tuple of integers.
|
606
605
|
where: Elements to include in the :code:`log_softmax`.
|
607
|
-
initial: The minimum value used to shift the input array. Must be present
|
608
|
-
when :code:`where` is not None.
|
609
606
|
|
610
607
|
Returns:
|
611
608
|
An array.
|
@@ -613,15 +610,12 @@ def log_softmax(x: ArrayLike,
|
|
613
610
|
See also:
|
614
611
|
:func:`softmax`
|
615
612
|
"""
|
616
|
-
|
617
|
-
initial = u.Quantity(initial).in_unit(u.get_unit(x)).mantissa
|
618
|
-
return _keep_unit(jax.nn.log_softmax, x, axis=axis, where=where, initial=initial)
|
613
|
+
return _keep_unit(jax.nn.log_softmax, x, axis=axis, where=where)
|
619
614
|
|
620
615
|
|
621
616
|
def softmax(x: ArrayLike,
|
622
617
|
axis: int | tuple[int, ...] | None = -1,
|
623
|
-
where: ArrayLike | None = None,
|
624
|
-
initial: ArrayLike | None = None) -> Union[jax.Array, u.Quantity]:
|
618
|
+
where: ArrayLike | None = None) -> Union[jax.Array, u.Quantity]:
|
625
619
|
r"""Softmax function.
|
626
620
|
|
627
621
|
Computes the function which rescales elements to the range :math:`[0, 1]`
|
@@ -645,9 +639,7 @@ def softmax(x: ArrayLike,
|
|
645
639
|
See also:
|
646
640
|
:func:`log_softmax`
|
647
641
|
"""
|
648
|
-
|
649
|
-
initial = u.Quantity(initial).in_unit(u.get_unit(x)).mantissa
|
650
|
-
return _keep_unit(jax.nn.softmax, x, axis=axis, where=where, initial=initial)
|
642
|
+
return _keep_unit(jax.nn.softmax, x, axis=axis, where=where)
|
651
643
|
|
652
644
|
|
653
645
|
def standardize(x: ArrayLike,
|
@@ -608,7 +608,10 @@ def _get_children(graph_def, state_mapping, index_ref, index_ref_cache):
|
|
608
608
|
if isinstance(value, TreefyState):
|
609
609
|
variable.update_from_ref(value)
|
610
610
|
elif isinstance(value, State):
|
611
|
-
|
611
|
+
if value._been_writen:
|
612
|
+
variable.write_value(value.value)
|
613
|
+
else:
|
614
|
+
variable.restore_value(value.value)
|
612
615
|
else:
|
613
616
|
raise ValueError(f'Expected a State type for {key!r}, but got {type(value)}.')
|
614
617
|
else: # if it doesn't, create a new variable
|
brainstate/nn/_collective_ops.py
CHANGED
@@ -20,8 +20,10 @@ from typing import Dict, Callable, TypeVar
|
|
20
20
|
|
21
21
|
import jax
|
22
22
|
|
23
|
+
from brainstate._state import catch_new_states
|
23
24
|
from brainstate._utils import set_module_as
|
24
25
|
from brainstate.graph import nodes
|
26
|
+
from brainstate.util._filter import Filter
|
25
27
|
from ._module import Module
|
26
28
|
|
27
29
|
# the maximum order
|
@@ -74,16 +76,29 @@ def call_order(level: int = 0, check_order_boundary: bool = True):
|
|
74
76
|
|
75
77
|
|
76
78
|
@set_module_as('brainstate.nn')
|
77
|
-
def init_all_states(
|
79
|
+
def init_all_states(
|
80
|
+
target: T,
|
81
|
+
*args,
|
82
|
+
exclude: Filter = None,
|
83
|
+
**kwargs
|
84
|
+
) -> T:
|
78
85
|
"""
|
79
86
|
Collectively initialize states of all children nodes in the given target.
|
80
87
|
|
81
88
|
Args:
|
82
89
|
target: The target Module.
|
90
|
+
exclude: The filter to exclude some nodes.
|
91
|
+
tag: The tag for the new states.
|
92
|
+
args: The positional arguments for the initialization, which will be passed to the `init_state` method
|
93
|
+
of each node.
|
94
|
+
kwargs: The keyword arguments for the initialization, which will be passed to the `init_state` method
|
95
|
+
of each node.
|
83
96
|
|
84
97
|
Returns:
|
85
98
|
The target Module.
|
86
99
|
"""
|
100
|
+
|
101
|
+
# node that has `call_order` decorated
|
87
102
|
nodes_with_order = []
|
88
103
|
|
89
104
|
nodes_ = nodes(target).filter(Module)
|
@@ -97,7 +112,7 @@ def init_all_states(target: T, *args, exclude=None, **kwargs) -> T:
|
|
97
112
|
else:
|
98
113
|
node.init_state(*args, **kwargs)
|
99
114
|
|
100
|
-
# reset the node's states
|
115
|
+
# reset the node's states with `call_order`
|
101
116
|
for node in sorted(nodes_with_order, key=lambda x: x.init_state.call_order):
|
102
117
|
node.init_state(*args, **kwargs)
|
103
118
|
|
@@ -115,6 +130,7 @@ def reset_all_states(target: Module, *args, **kwargs) -> Module:
|
|
115
130
|
Returns:
|
116
131
|
The target Module.
|
117
132
|
"""
|
133
|
+
|
118
134
|
nodes_with_order = []
|
119
135
|
|
120
136
|
# reset node whose `init_state` has no `call_order`
|
@@ -112,7 +112,7 @@ class STP(Synapse):
|
|
112
112
|
self.u.value = init.param(init.Constant(self.U), self.varshape, batch_size)
|
113
113
|
|
114
114
|
def update(self, pre_spike):
|
115
|
-
u = exp_euler_step(lambda u:
|
115
|
+
u = exp_euler_step(lambda u: - u / self.tau_f, self.u.value)
|
116
116
|
x = exp_euler_step(lambda x: (1 - x) / self.tau_d, self.x.value)
|
117
117
|
|
118
118
|
# --- original code:
|
@@ -17,7 +17,7 @@
|
|
17
17
|
from __future__ import annotations
|
18
18
|
|
19
19
|
from functools import partial
|
20
|
-
from typing import Optional
|
20
|
+
from typing import Optional, Sequence
|
21
21
|
|
22
22
|
import brainunit as u
|
23
23
|
import jax.numpy as jnp
|
@@ -29,7 +29,6 @@ from brainstate.typing import Size
|
|
29
29
|
|
30
30
|
__all__ = [
|
31
31
|
'DropoutFixed', 'Dropout', 'Dropout1d', 'Dropout2d', 'Dropout3d',
|
32
|
-
'AlphaDropout', 'FeatureAlphaDropout',
|
33
32
|
]
|
34
33
|
|
35
34
|
|
@@ -47,9 +46,9 @@ class Dropout(ElementWiseBlock):
|
|
47
46
|
research 15.1 (2014): 1929-1958.
|
48
47
|
|
49
48
|
Args:
|
50
|
-
|
51
|
-
|
52
|
-
|
49
|
+
prob: Probability to keep element of the tensor.
|
50
|
+
broadcast_dims: dimensions that will share the same dropout mask.
|
51
|
+
name: str. The name of the dynamic system.
|
53
52
|
|
54
53
|
"""
|
55
54
|
__module__ = 'brainstate.nn'
|
@@ -57,20 +56,28 @@ class Dropout(ElementWiseBlock):
|
|
57
56
|
def __init__(
|
58
57
|
self,
|
59
58
|
prob: float = 0.5,
|
59
|
+
broadcast_dims: Sequence[int] = (),
|
60
60
|
name: Optional[str] = None
|
61
61
|
) -> None:
|
62
62
|
super().__init__(name=name)
|
63
63
|
assert 0. <= prob <= 1., f"Dropout probability must be in the range [0, 1]. But got {prob}."
|
64
64
|
self.prob = prob
|
65
|
+
self.broadcast_dims = broadcast_dims
|
65
66
|
|
66
67
|
def __call__(self, x):
|
67
68
|
dtype = u.math.get_dtype(x)
|
68
69
|
fit_phase = environ.get('fit', desc='Whether this is a fitting process. Bool.')
|
69
70
|
if fit_phase and self.prob < 1.:
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
|
71
|
+
broadcast_shape = list(x.shape)
|
72
|
+
for dim in self.broadcast_dims:
|
73
|
+
broadcast_shape[dim] = 1
|
74
|
+
keep_mask = random.bernoulli(self.prob, broadcast_shape)
|
75
|
+
keep_mask = jnp.broadcast_to(keep_mask, x.shape)
|
76
|
+
return jnp.where(
|
77
|
+
keep_mask,
|
78
|
+
jnp.asarray(x / self.prob, dtype=dtype),
|
79
|
+
jnp.asarray(0., dtype=dtype)
|
80
|
+
)
|
74
81
|
else:
|
75
82
|
return x
|
76
83
|
|
@@ -93,7 +100,6 @@ class _DropoutNd(ElementWiseBlock):
|
|
93
100
|
self.channel_axis = channel_axis
|
94
101
|
|
95
102
|
def __call__(self, x):
|
96
|
-
|
97
103
|
# check input shape
|
98
104
|
inp_dim = u.math.ndim(x)
|
99
105
|
if inp_dim not in (self.minimal_dim, self.minimal_dim + 1):
|
@@ -114,10 +120,13 @@ class _DropoutNd(ElementWiseBlock):
|
|
114
120
|
# generate mask
|
115
121
|
if fit_phase and self.prob < 1.:
|
116
122
|
dtype = u.math.get_dtype(x)
|
117
|
-
keep_mask =
|
118
|
-
|
119
|
-
|
120
|
-
|
123
|
+
keep_mask = random.bernoulli(self.prob, mask_shape)
|
124
|
+
keep_mask = jnp.broadcast_to(keep_mask, x.shape)
|
125
|
+
return jnp.where(
|
126
|
+
keep_mask,
|
127
|
+
jnp.asarray(x / self.prob, dtype=dtype),
|
128
|
+
jnp.asarray(0., dtype=dtype)
|
129
|
+
)
|
121
130
|
else:
|
122
131
|
return x
|
123
132
|
|
@@ -296,8 +305,8 @@ class AlphaDropout(_DropoutNd):
|
|
296
305
|
"""
|
297
306
|
__module__ = 'brainstate.nn'
|
298
307
|
|
299
|
-
def
|
300
|
-
|
308
|
+
def update(self, *args, **kwargs):
|
309
|
+
raise NotImplementedError("AlphaDropout is not supported in the current version.")
|
301
310
|
|
302
311
|
|
303
312
|
class FeatureAlphaDropout(_DropoutNd):
|
@@ -344,8 +353,8 @@ class FeatureAlphaDropout(_DropoutNd):
|
|
344
353
|
"""
|
345
354
|
__module__ = 'brainstate.nn'
|
346
355
|
|
347
|
-
def
|
348
|
-
|
356
|
+
def update(self, *args, **kwargs):
|
357
|
+
raise NotImplementedError("FeatureAlphaDropout is not supported in the current version.")
|
349
358
|
|
350
359
|
|
351
360
|
class DropoutFixed(ElementWiseBlock):
|
@@ -23,7 +23,7 @@ import jax
|
|
23
23
|
import jax.numpy as jnp
|
24
24
|
|
25
25
|
from brainstate import environ, init
|
26
|
-
from brainstate._state import
|
26
|
+
from brainstate._state import ParamState, BatchState
|
27
27
|
from brainstate.nn._module import Module
|
28
28
|
from brainstate.typing import DTypeLike, ArrayLike, Size, Axes
|
29
29
|
|
@@ -91,6 +91,18 @@ def _abs_sq(x):
|
|
91
91
|
return jax.lax.square(x)
|
92
92
|
|
93
93
|
|
94
|
+
class NormalizationParamState(ParamState):
|
95
|
+
# This is a dummy class to be used as a compatibility
|
96
|
+
# usage of `ETraceParam` for the layers in "brainetrace"
|
97
|
+
def execute(self, x):
|
98
|
+
param = self.value
|
99
|
+
if 'scale' in param:
|
100
|
+
x = x * param['scale']
|
101
|
+
if 'bias' in param:
|
102
|
+
x = x + param['bias']
|
103
|
+
return x
|
104
|
+
|
105
|
+
|
94
106
|
def _compute_stats(
|
95
107
|
x: ArrayLike,
|
96
108
|
axes: Sequence[int],
|
@@ -150,12 +162,17 @@ def _compute_stats(
|
|
150
162
|
# In the distributed case we stack multiple arrays to speed comms.
|
151
163
|
if len(xs) > 1:
|
152
164
|
reduced_mus = jax.lax.pmean(
|
153
|
-
jnp.stack(mus, axis=0),
|
165
|
+
jnp.stack(mus, axis=0),
|
166
|
+
axis_name,
|
154
167
|
axis_index_groups=axis_index_groups,
|
155
168
|
)
|
156
169
|
return tuple(reduced_mus[i] for i in range(len(xs)))
|
157
170
|
else:
|
158
|
-
return jax.lax.pmean(
|
171
|
+
return jax.lax.pmean(
|
172
|
+
mus[0],
|
173
|
+
axis_name,
|
174
|
+
axis_index_groups=axis_index_groups
|
175
|
+
)
|
159
176
|
|
160
177
|
if use_mean:
|
161
178
|
if use_fast_variance:
|
@@ -176,7 +193,7 @@ def _normalize(
|
|
176
193
|
x: ArrayLike,
|
177
194
|
mean: Optional[ArrayLike],
|
178
195
|
var: Optional[ArrayLike],
|
179
|
-
weights: Optional[
|
196
|
+
weights: Optional[NormalizationParamState],
|
180
197
|
reduction_axes: Axes,
|
181
198
|
feature_axes: Axes,
|
182
199
|
dtype: DTypeLike,
|
@@ -212,10 +229,9 @@ def _normalize(
|
|
212
229
|
y = x - mean
|
213
230
|
mul = jax.lax.rsqrt(var + epsilon)
|
214
231
|
y = y * mul
|
215
|
-
args = []
|
216
232
|
if weights is not None:
|
217
|
-
y
|
218
|
-
dtype = canonicalize_dtype(x, *
|
233
|
+
y = weights.execute(y)
|
234
|
+
dtype = canonicalize_dtype(x, *jax.tree.leaves(weights.value), dtype=dtype)
|
219
235
|
else:
|
220
236
|
assert var is None, 'mean and val must be both None or not None.'
|
221
237
|
assert weights is None, 'scale and bias are not supported without mean and val'
|
@@ -223,17 +239,6 @@ def _normalize(
|
|
223
239
|
return jnp.asarray(y, dtype)
|
224
240
|
|
225
241
|
|
226
|
-
def _scale_operation(x: jax.Array, param: Dict):
|
227
|
-
args = []
|
228
|
-
if 'scale' in param:
|
229
|
-
x = x * param['scale']
|
230
|
-
args.append(param['scale'])
|
231
|
-
if 'bias' in param:
|
232
|
-
x = x + param['bias']
|
233
|
-
args.append(param['bias'])
|
234
|
-
return x, args
|
235
|
-
|
236
|
-
|
237
242
|
class _BatchNorm(Module):
|
238
243
|
__module__ = 'brainstate.nn'
|
239
244
|
num_spatial_dims: int
|
@@ -254,6 +259,8 @@ class _BatchNorm(Module):
|
|
254
259
|
use_fast_variance: bool = True,
|
255
260
|
name: Optional[str] = None,
|
256
261
|
dtype: Any = None,
|
262
|
+
param_type: type = NormalizationParamState,
|
263
|
+
mean_type: type = BatchState,
|
257
264
|
):
|
258
265
|
super().__init__(name=name)
|
259
266
|
|
@@ -279,8 +286,8 @@ class _BatchNorm(Module):
|
|
279
286
|
feature_shape = tuple([(ax if i in self.feature_axes else 1)
|
280
287
|
for i, ax in enumerate(self.in_size)])
|
281
288
|
if self.track_running_stats:
|
282
|
-
self.running_mean =
|
283
|
-
self.running_var =
|
289
|
+
self.running_mean = mean_type(jnp.zeros(feature_shape, dtype=self.dtype))
|
290
|
+
self.running_var = mean_type(jnp.ones(feature_shape, dtype=self.dtype))
|
284
291
|
else:
|
285
292
|
self.running_mean = None
|
286
293
|
self.running_var = None
|
@@ -290,7 +297,7 @@ class _BatchNorm(Module):
|
|
290
297
|
assert track_running_stats, "Affine parameters are not needed when track_running_stats is False."
|
291
298
|
bias = init.param(self.bias_initializer, feature_shape)
|
292
299
|
scale = init.param(self.scale_initializer, feature_shape)
|
293
|
-
self.weight =
|
300
|
+
self.weight = param_type(dict(bias=bias, scale=scale))
|
294
301
|
else:
|
295
302
|
self.weight = None
|
296
303
|
|
@@ -531,6 +538,7 @@ class LayerNorm(Module):
|
|
531
538
|
axis_index_groups: Any = None,
|
532
539
|
use_fast_variance: bool = True,
|
533
540
|
dtype: Optional[jax.typing.DTypeLike] = None,
|
541
|
+
param_type: type = NormalizationParamState,
|
534
542
|
):
|
535
543
|
super().__init__()
|
536
544
|
|
@@ -554,7 +562,7 @@ class LayerNorm(Module):
|
|
554
562
|
if use_bias:
|
555
563
|
weights['bias'] = init.param(bias_init, feature_shape)
|
556
564
|
if len(weights):
|
557
|
-
self.weight =
|
565
|
+
self.weight = param_type(weights)
|
558
566
|
else:
|
559
567
|
self.weight = None
|
560
568
|
|
@@ -654,6 +662,7 @@ class RMSNorm(Module):
|
|
654
662
|
axis_name: Optional[str] = None,
|
655
663
|
axis_index_groups: Any = None,
|
656
664
|
use_fast_variance: bool = True,
|
665
|
+
param_type: type = NormalizationParamState,
|
657
666
|
):
|
658
667
|
super().__init__()
|
659
668
|
|
@@ -663,7 +672,7 @@ class RMSNorm(Module):
|
|
663
672
|
# parameters about axis
|
664
673
|
feature_axes = (feature_axes,) if isinstance(feature_axes, int) else feature_axes
|
665
674
|
self.feature_axes = _canonicalize_axes(len(self.in_size), feature_axes)
|
666
|
-
self.reduction_axes = (reduction_axes,
|
675
|
+
self.reduction_axes = (reduction_axes,) if isinstance(reduction_axes, int) else reduction_axes
|
667
676
|
self.axis_name = axis_name
|
668
677
|
self.axis_index_groups = axis_index_groups
|
669
678
|
|
@@ -671,7 +680,7 @@ class RMSNorm(Module):
|
|
671
680
|
feature_shape = tuple([(ax if i in self.feature_axes else 1)
|
672
681
|
for i, ax in enumerate(self.in_size)])
|
673
682
|
if use_scale:
|
674
|
-
self.scale =
|
683
|
+
self.scale = param_type({'scale': init.param(scale_init, feature_shape)})
|
675
684
|
else:
|
676
685
|
self.scale = None
|
677
686
|
|
@@ -795,6 +804,7 @@ class GroupNorm(Module):
|
|
795
804
|
axis_name: Optional[str] = None,
|
796
805
|
axis_index_groups: Any = None,
|
797
806
|
use_fast_variance: bool = True,
|
807
|
+
param_type: type = NormalizationParamState,
|
798
808
|
):
|
799
809
|
super().__init__()
|
800
810
|
|
@@ -848,7 +858,7 @@ class GroupNorm(Module):
|
|
848
858
|
if use_bias:
|
849
859
|
weights['bias'] = init.param(bias_init, feature_shape)
|
850
860
|
if len(weights):
|
851
|
-
self.weight =
|
861
|
+
self.weight = param_type(weights)
|
852
862
|
else:
|
853
863
|
self.weight = None
|
854
864
|
|
brainstate/util/_tracers.py
CHANGED
@@ -16,7 +16,6 @@ from __future__ import annotations
|
|
16
16
|
|
17
17
|
import jax
|
18
18
|
import jax.core
|
19
|
-
from jax.interpreters import partial_eval as pe
|
20
19
|
|
21
20
|
from ._pretty_repr import PrettyRepr, PrettyType, PrettyAttr
|
22
21
|
|
@@ -25,12 +24,6 @@ __all__ = [
|
|
25
24
|
]
|
26
25
|
|
27
26
|
|
28
|
-
def new_jax_trace():
|
29
|
-
main = jax.core.thread_local_state.trace_state.trace_stack.stack[-1]
|
30
|
-
frame = main.jaxpr_stack[-1]
|
31
|
-
trace = pe.DynamicJaxprTrace(main, jax.core.cur_sublevel())
|
32
|
-
return frame, trace
|
33
|
-
|
34
27
|
|
35
28
|
def current_jax_trace():
|
36
29
|
"""Returns the Jax tracing state."""
|
{brainstate-0.1.0.post20241129.dist-info → brainstate-0.1.0.post20241210.dist-info}/METADATA
RENAMED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: brainstate
|
3
|
-
Version: 0.1.0.post20241129
|
3
|
+
Version: 0.1.0.post20241210
|
4
4
|
Summary: A ``State``-based Transformation System for Program Compilation and Augmentation.
|
5
5
|
Home-page: https://github.com/chaobrain/brainstate
|
6
6
|
Author: BrainState Developers
|
@@ -1,7 +1,5 @@
|
|
1
|
-
benchmark/COBA_2005.py,sha256=Q8PsZ0lxu14jsF3bCtlZW35iQB8S2_oFEUYQzK2hPiA,5561
|
2
|
-
benchmark/CUBA_2005.py,sha256=_W94yOMh2ueqblU4ItEPeTLwHF0_lbEWlVNEBy0Tix0,6222
|
3
1
|
brainstate/__init__.py,sha256=r7C3eLTg8LEusoH6PGgBFFt4ZgbketYLoLA0lQhUCsE,2098
|
4
|
-
brainstate/_state.py,sha256=
|
2
|
+
brainstate/_state.py,sha256=4aDpLyHGr1VlPXeLSfM3USQG5K4o7orF7IlaBdYrtfE,29098
|
5
3
|
brainstate/_state_test.py,sha256=1boTp1w8DiCFLsPwNtlLrlIqGRpkasAmLid5bv2fgP4,2223
|
6
4
|
brainstate/_utils.py,sha256=uJ6WWKq3yb05ZdktCQGLWOXsOJveL1H9pR7eev70Jes,1693
|
7
5
|
brainstate/environ.py,sha256=G6r_rqfbofRbjFFalRu_DHaL7ruFTeLRXBQDXM6P-tQ,17477
|
@@ -17,24 +15,24 @@ brainstate/augment/_autograd_test.py,sha256=S2eEgrwTzdSi3u2nKE3u37WSThosLwx1WCP9
|
|
17
15
|
brainstate/augment/_eval_shape.py,sha256=dGlRVHOAZ9LSRZsFi1erxgEWHrnhBO3Kq3WW11-Hvng,3819
|
18
16
|
brainstate/augment/_eval_shape_test.py,sha256=1nnxbU7hPRbZPQWNWbQ518pw-H7FGDKKnQpZGBY9uRI,1390
|
19
17
|
brainstate/augment/_mapping.py,sha256=cpxzVGCEYnP5jPqrowYoPXciw_-QR2F3wggrRj1OCPc,21850
|
20
|
-
brainstate/augment/_mapping_test.py,sha256=
|
18
|
+
brainstate/augment/_mapping_test.py,sha256=TEAecjZmTSDCfxARgrzcDJ2dW1Yz_sCITmFiA9FGrhk,9455
|
21
19
|
brainstate/augment/_random.py,sha256=rkB4w4BkKsz9p8lTk31kVHvlVPJSvtGk8REn936KI_4,3071
|
22
20
|
brainstate/compile/__init__.py,sha256=qZZIYoyEl51IFkFu-Hb-bP3PAEHo94HlTDf57P2ze08,1858
|
23
21
|
brainstate/compile/_ad_checkpoint.py,sha256=5zJ1ENeTU4FzRY_uNpr85NhKfuicMMjcIbhu6-bSM4k,9451
|
24
22
|
brainstate/compile/_ad_checkpoint_test.py,sha256=R1I76nG4zIqb6g3M_VxWts7rUC1OHJCjtQhPkcbXodk,1746
|
25
|
-
brainstate/compile/_conditions.py,sha256=
|
23
|
+
brainstate/compile/_conditions.py,sha256=gApsHKGQrf1QBjoKXDVL7VsoeJ2zFtSc-hFz9nbYcF0,10113
|
26
24
|
brainstate/compile/_conditions_test.py,sha256=s9LF6h9LvigvgxUIugTqvgCHBIU8TXS1Ar1OlIxXfrw,8389
|
27
25
|
brainstate/compile/_error_if.py,sha256=TFvhqITKkRO9m30GdlUP4eEjJvLWQUhjkujXO9zvrWs,2689
|
28
26
|
brainstate/compile/_error_if_test.py,sha256=SJmAfosVoGd4vhfFtb1IvjeFVW914bfTccCg6DoLWYk,1992
|
29
|
-
brainstate/compile/_jit.py,sha256=
|
27
|
+
brainstate/compile/_jit.py,sha256=3mQ-RUFz35wceZyKE_MoR58OBL0RK_i6sHm4rWYzMLs,13698
|
30
28
|
brainstate/compile/_jit_test.py,sha256=zD7kck9SQJGmUDolh9P4luKwQ21fBGje1Z4STTEXIuA,4135
|
31
|
-
brainstate/compile/_loop_collect_return.py,sha256=
|
29
|
+
brainstate/compile/_loop_collect_return.py,sha256=DybSBixeuxleKJV6n9FgVBDsUTmexzS0IdgWYRqp5cU,22940
|
32
30
|
brainstate/compile/_loop_collect_return_test.py,sha256=bA-_11E8A_0jR5umEO3e409y7bb5QYDTgSL-SBaX7kQ,1802
|
33
|
-
brainstate/compile/_loop_no_collection.py,sha256=
|
31
|
+
brainstate/compile/_loop_no_collection.py,sha256=0i31gdQ7sI-d6pvnh08ttUUwdAtpx4uoYhGuf_CyL9s,7343
|
34
32
|
brainstate/compile/_loop_no_collection_test.py,sha256=oStB1CSG_iLp9sHdXd1hJNFvlxbzjck9Iy4sABoJDj4,1419
|
35
|
-
brainstate/compile/_make_jaxpr.py,sha256=
|
33
|
+
brainstate/compile/_make_jaxpr.py,sha256=S5O9KUB3bsxoKcfptlV0MRfKA__Ija37WxkakIRL3z0,33010
|
36
34
|
brainstate/compile/_make_jaxpr_test.py,sha256=qJUtkyj50JQ6f4UJbOLhvRdkbNn3NSKibFL9jESdQkA,4279
|
37
|
-
brainstate/compile/_progress_bar.py,sha256=
|
35
|
+
brainstate/compile/_progress_bar.py,sha256=H544Oh10SiF5ccrKHM9ay7ZHigYIhNhSQGEKbDxRJgg,4485
|
38
36
|
brainstate/compile/_unvmap.py,sha256=ewbLLNXiI_dBsEBaVzSS0BEXNol22sd9gMzk606lSkM,4139
|
39
37
|
brainstate/compile/_util.py,sha256=aCvkTV--g4NsqcodTdBAISt4EwgezCbKzNUV58n-Q_Y,6304
|
40
38
|
brainstate/event/__init__.py,sha256=wOBkq7kDg90M8Y9FuoXRlSEuu1ZzbIhCJ1dHeLqN6_Q,1194
|
@@ -52,7 +50,7 @@ brainstate/event/_misc.py,sha256=8IpPooXjF2m0-tuo3pGHqThq2yLSNmYziy_zdurZ3NI,104
|
|
52
50
|
brainstate/event/_xla_custom_op.py,sha256=QB4jz_fUEPF-efJCVKAxwx8U79AqdcKoEg2QrGwot8I,10864
|
53
51
|
brainstate/event/_xla_custom_op_test.py,sha256=rnkGMleXzLfJj4y5QqwfBvCCLTAHe_uabwBDniY-URM,1745
|
54
52
|
brainstate/functional/__init__.py,sha256=j6-3Er4fgqWpvntzYCZVB3e5hoz-Z3aqvapITCuDri0,1107
|
55
|
-
brainstate/functional/_activations.py,sha256=
|
53
|
+
brainstate/functional/_activations.py,sha256=S0Ok7sq5FTbmJWSejpOCHo1jpKX0gYOLy_TO2IUXM8s,21726
|
56
54
|
brainstate/functional/_activations_test.py,sha256=T___RlSrIfXwlkw8dg5A9EZMTZGDzv3a2evUwq_nYFg,13034
|
57
55
|
brainstate/functional/_normalization.py,sha256=i2EV7hSsqcNdcYRX2wAxjq8doHwyN9eNJTGTaPt03xE,2605
|
58
56
|
brainstate/functional/_others.py,sha256=_u_Ys-LiLzDAP4zJggVwaVvirgoS3jvhXMREoS6JOkM,1737
|
@@ -63,7 +61,7 @@ brainstate/graph/_graph_context_test.py,sha256=IYpjqbXwSFF65XL0ZbdPeC1jYyEHLpQVr
|
|
63
61
|
brainstate/graph/_graph_convert.py,sha256=llSREtGQrIggkD0wmxUbYKuSveLW4ihDZME6Ab-mRTQ,9147
|
64
62
|
brainstate/graph/_graph_node.py,sha256=BTuVlGgA2b82zNudjsN88QXuxfDcMvU2-kB64AkdQnY,8993
|
65
63
|
brainstate/graph/_graph_node_test.py,sha256=BFGfdzZFDHI0XK7hHotSVWKt3em1taGvn8FHF9NCXx8,2702
|
66
|
-
brainstate/graph/_graph_operation.py,sha256=
|
64
|
+
brainstate/graph/_graph_operation.py,sha256=PupZeFWBR-OHbhdJcoqlvy2YqoIS9Ze4q0tz8HRy4f4,64166
|
67
65
|
brainstate/graph/_graph_operation_test.py,sha256=ADyyuMk2xidEkkFNpGvUbvEtRmUj-tqOI4cF3eRuakM,24678
|
68
66
|
brainstate/init/__init__.py,sha256=R1dHgub47o-WJM9QkFLc7x_Q7GsyaKKDtrRHTFPpC5g,1097
|
69
67
|
brainstate/init/_base.py,sha256=B_NLS9aKNrvuj5NAlSgBbQTVev7IRvzcx8vH0J-Gq2w,1671
|
@@ -73,7 +71,7 @@ brainstate/init/_random_inits_test.py,sha256=lBL2RQdBSZ88Zqz4IMdbHJMvDi7ooZq6caC
|
|
73
71
|
brainstate/init/_regular_inits.py,sha256=DmVMajugfyYFNUMzgFdDKMvbBu9hMWxkfDd-50uhoLg,3187
|
74
72
|
brainstate/init/_regular_inits_test.py,sha256=tJl4aOkclllJIfKzJTbc0cfYCw2SoBsx8_G123RnqbU,1842
|
75
73
|
brainstate/nn/__init__.py,sha256=rxURT8J1XfBn3Vh3Dx_WzVADWn9zVriIty5KZEG-x6o,1622
|
76
|
-
brainstate/nn/_collective_ops.py,sha256=
|
74
|
+
brainstate/nn/_collective_ops.py,sha256=sSjIIs1MvZA30XFFmK7iL1D_sCeh7hFd3PanCH6kgZo,6779
|
77
75
|
brainstate/nn/_exp_euler.py,sha256=yjkfSllFxGWKEAlHo5AzBizzkFj6FEVDKmFV6E2g214,3521
|
78
76
|
brainstate/nn/_exp_euler_test.py,sha256=clwRD8QR71k1jn6NrACMDEUcFMh0J9RTosoPnlYWUkw,1242
|
79
77
|
brainstate/nn/_module.py,sha256=HDLPvLfB7jat2VT3gBu0MxA7vfzK7xgowemitHX8Cgo,10835
|
@@ -82,7 +80,7 @@ brainstate/nn/metrics.py,sha256=iupHjSRTHYY-HmEPBC4tXWrZfF4zh1ek2NwSAA0gnwE,1473
|
|
82
80
|
brainstate/nn/_dyn_impl/__init__.py,sha256=Oazar7h89dp1WA2Vx4Tj7gCBhxJKH4LAUEABkBEG7vU,1462
|
83
81
|
brainstate/nn/_dyn_impl/_dynamics_neuron.py,sha256=cTbIn41EPYG0h3ICzKBXxpgB6wwA2K8k5FAcf3Pa5N8,10927
|
84
82
|
brainstate/nn/_dyn_impl/_dynamics_neuron_test.py,sha256=Tfzrzu7udGrLJGnqItiLWe5WT0dgduvYOgzGCnaPJQg,6317
|
85
|
-
brainstate/nn/_dyn_impl/_dynamics_synapse.py,sha256=
|
83
|
+
brainstate/nn/_dyn_impl/_dynamics_synapse.py,sha256=IHy6IsGjWpKZ8NLq4X7PaRwx3tpO2HRZNppCWM2fe4I,11862
|
86
84
|
brainstate/nn/_dyn_impl/_dynamics_synapse_test.py,sha256=t5i-HV0ii9sUNzWTEv04o26QVtQ-mCdMJcFq2MD755A,4981
|
87
85
|
brainstate/nn/_dyn_impl/_inputs.py,sha256=6eZKnkmrM0Gog2fpSKjSnwnQvhbFYhG4q9Vuo-GH2LI,5050
|
88
86
|
brainstate/nn/_dyn_impl/_projection_alignpost.py,sha256=PNC1Tzx_SF2DHAHeJCufXzO_Q4qLoBpWABI45B3GRuc,876
|
@@ -98,7 +96,7 @@ brainstate/nn/_dynamics/_state_delay.py,sha256=nZYGmVKmQvAQu-W4YOUFH1gnr-ZS3rg_G
|
|
98
96
|
brainstate/nn/_dynamics/_synouts.py,sha256=9TGAc-nVa50th7KKn4oKLbro-4W4rwxYvp-eu7ksAIE,4491
|
99
97
|
brainstate/nn/_dynamics/_synouts_test.py,sha256=V_jDswRN4VvEXD-2yJO3VA1TALgX0HK6oPBQiUntOWc,2266
|
100
98
|
brainstate/nn/_elementwise/__init__.py,sha256=PK8oq1K_EG2941AiUyLxCWoRdWvMO3yt8ZJbw3Lkhu8,935
|
101
|
-
brainstate/nn/_elementwise/_dropout.py,sha256=
|
99
|
+
brainstate/nn/_elementwise/_dropout.py,sha256=0Ebo-2y1VswvBqZ7sCA0SEUm37y49EUsef8oiSFpYGk,17759
|
102
100
|
brainstate/nn/_elementwise/_dropout_test.py,sha256=Qn7xqZOyZMPCGF6tFjTiPId0yELOXjSsW5-hgihP3fE,4383
|
103
101
|
brainstate/nn/_elementwise/_elementwise.py,sha256=om-KpwDTk5yFG5KBYXXHquRLV7s28_FJjk-omvyMyvQ,33342
|
104
102
|
brainstate/nn/_elementwise/_elementwise_test.py,sha256=SZI9jB39sZ5SO1dpWGW-PhodthwN0GU9FY1nqf2fWcs,5341
|
@@ -108,7 +106,7 @@ brainstate/nn/_interaction/_conv_test.py,sha256=fHXRFYnDghFiKre63RqMwIE_gbPKdK34
|
|
108
106
|
brainstate/nn/_interaction/_embedding.py,sha256=iK0I1ExKWFa_QzV9UDGj32Ljsmdr1g_LlAtMcusebxU,2187
|
109
107
|
brainstate/nn/_interaction/_linear.py,sha256=bjiWGJCe81ugQQOykwjWlLW5uhe0CHWwkPA20a4n5YQ,21340
|
110
108
|
brainstate/nn/_interaction/_linear_test.py,sha256=KlvFZA0rpyaspf4LT4K7u-RR5jCEB_q1WReqAw9sFcU,1274
|
111
|
-
brainstate/nn/_interaction/_normalizations.py,sha256=
|
109
|
+
brainstate/nn/_interaction/_normalizations.py,sha256=7YDzkmO_iqd70fH_wawb60Bu8eGOdvZq23emP-b68Hc,37440
|
112
110
|
brainstate/nn/_interaction/_normalizations_test.py,sha256=2p1Jf8nA999VYGWbvOZfKYlKk6UmL0vaEB76xkXxkXw,2438
|
113
111
|
brainstate/nn/_interaction/_poolings.py,sha256=LpwuyeNBVCaVFW7zWc7E-vvlYqx54h46Br5XT6zd_94,47020
|
114
112
|
brainstate/nn/_interaction/_poolings_test.py,sha256=wmd5PngZ3E9tNyF3s0xk-DoDR5yFqpTi9A6nbNoIqn4,7429
|
@@ -136,10 +134,10 @@ brainstate/util/_others.py,sha256=jsPZwP-v_5HRV-LB5F0NUsiqr04y8bmGIsu_JMyVcbQ,14
|
|
136
134
|
brainstate/util/_pretty_repr.py,sha256=NYEBCo2iz9Potx-IR7uZZzt2aLQW_94vH79fGusiC2A,5737
|
137
135
|
brainstate/util/_scaling.py,sha256=pc_eM_SZVwkY65I4tJh1ODiHNCoEhsfFXl2zBK0PLAg,7562
|
138
136
|
brainstate/util/_struct.py,sha256=0exv0oOiSt1hmx20Y4J2-pCGtCTx13WcAlEYSBkyung,17640
|
139
|
-
brainstate/util/_tracers.py,sha256
|
137
|
+
brainstate/util/_tracers.py,sha256=0r5T4nhxMzI79NtqroqitsdMT4YfpgV5RdYJLS5uJ0w,2285
|
140
138
|
brainstate/util/_visualization.py,sha256=n4ZVz10z7VBqA0cKO6vyHwEMprWJgPeEqtITzDMai2Y,1519
|
141
|
-
brainstate-0.1.0.
|
142
|
-
brainstate-0.1.0.
|
143
|
-
brainstate-0.1.0.
|
144
|
-
brainstate-0.1.0.
|
145
|
-
brainstate-0.1.0.
|
139
|
+
brainstate-0.1.0.post20241210.dist-info/LICENSE,sha256=VZe9u1jgUL2eCY6ZPOYgdb8KCblCHt8ECdbtJid6e1s,11550
|
140
|
+
brainstate-0.1.0.post20241210.dist-info/METADATA,sha256=E6AATarjpwssXflLfA-OCkxFxqZxqJNxHZteO6UWMhw,3401
|
141
|
+
brainstate-0.1.0.post20241210.dist-info/WHEEL,sha256=bb2Ot9scclHKMOLDEHY6B2sicWOgugjFKaJsT7vwMQo,110
|
142
|
+
brainstate-0.1.0.post20241210.dist-info/top_level.txt,sha256=eQbGgKn0ptx7FDWuua0V0wr4K1VHi2iOUCYo3fUQBRA,11
|
143
|
+
brainstate-0.1.0.post20241210.dist-info/RECORD,,
|
benchmark/COBA_2005.py
DELETED
@@ -1,125 +0,0 @@
|
|
1
|
-
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
|
16
|
-
#
|
17
|
-
# Implementation of the paper:
|
18
|
-
#
|
19
|
-
# - Brette, R., Rudolph, M., Carnevale, T., Hines, M., Beeman, D., Bower, J. M., et al. (2007),
|
20
|
-
# Simulation of networks of spiking neurons: a review of tools and strategies., J. Comput. Neurosci., 23, 3, 349–98
|
21
|
-
#
|
22
|
-
# which is based on the balanced network proposed by:
|
23
|
-
#
|
24
|
-
# - Vogels, T. P. and Abbott, L. F. (2005), Signal propagation and logic gating in networks of integrate-and-fire neurons., J. Neurosci., 25, 46, 10786–95
|
25
|
-
#
|
26
|
-
import os
|
27
|
-
import sys
|
28
|
-
|
29
|
-
sys.path.append('../')
|
30
|
-
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.99'
|
31
|
-
os.environ['JAX_TRACEBACK_FILTERING'] = 'off'
|
32
|
-
|
33
|
-
|
34
|
-
import jax
|
35
|
-
import brainunit as u
|
36
|
-
import time
|
37
|
-
import brainstate as bst
|
38
|
-
|
39
|
-
|
40
|
-
class EINet(bst.nn.DynamicsGroup):
|
41
|
-
def __init__(self, scale):
|
42
|
-
super().__init__()
|
43
|
-
self.n_exc = int(3200 * scale)
|
44
|
-
self.n_inh = int(800 * scale)
|
45
|
-
self.num = self.n_exc + self.n_inh
|
46
|
-
self.N = bst.nn.LIFRef(self.num, V_rest=-60. * u.mV, V_th=-50. * u.mV, V_reset=-60. * u.mV,
|
47
|
-
tau=20. * u.ms, tau_ref=5. * u.ms,
|
48
|
-
V_initializer=bst.init.Normal(-55., 2., unit=u.mV))
|
49
|
-
self.E = bst.nn.AlignPostProj(
|
50
|
-
comm=bst.event.FixedProb(self.n_exc, self.num, prob=80 / self.num, weight=0.6 * u.mS),
|
51
|
-
syn=bst.nn.Expon.desc(self.num, tau=5. * u.ms),
|
52
|
-
out=bst.nn.COBA.desc(E=0. * u.mV),
|
53
|
-
post=self.N
|
54
|
-
)
|
55
|
-
self.I = bst.nn.AlignPostProj(
|
56
|
-
comm=bst.event.FixedProb(self.n_inh, self.num, prob=80 / self.num, weight=6.7 * u.mS),
|
57
|
-
syn=bst.nn.Expon.desc(self.num, tau=10. * u.ms),
|
58
|
-
out=bst.nn.COBA.desc(E=-80. * u.mV),
|
59
|
-
post=self.N
|
60
|
-
)
|
61
|
-
|
62
|
-
def init_state(self, *args, **kwargs):
|
63
|
-
self.rate = bst.ShortTermState(u.math.zeros(self.num))
|
64
|
-
|
65
|
-
def update(self, t, inp):
|
66
|
-
with bst.environ.context(t=t):
|
67
|
-
spk = self.N.get_spike() != 0.
|
68
|
-
self.E(spk[:self.n_exc])
|
69
|
-
self.I(spk[self.n_exc:])
|
70
|
-
self.N(inp)
|
71
|
-
self.rate.value += self.N.get_spike()
|
72
|
-
|
73
|
-
|
74
|
-
@bst.compile.jit(static_argnums=0)
|
75
|
-
def run(scale: float):
|
76
|
-
# network
|
77
|
-
net = EINet(scale)
|
78
|
-
bst.nn.init_all_states(net)
|
79
|
-
|
80
|
-
duration = 1e4 * u.ms
|
81
|
-
# simulation
|
82
|
-
with bst.environ.context(dt=0.1 * u.ms):
|
83
|
-
times = u.math.arange(0. * u.ms, duration, bst.environ.get_dt())
|
84
|
-
bst.compile.for_loop(lambda t: net.update(t, 20. * u.mA), times)
|
85
|
-
|
86
|
-
return net.num, net.rate.value.sum() / net.num / duration.to_decimal(u.second)
|
87
|
-
|
88
|
-
|
89
|
-
for s in [1, 2, 4, 6, 8, 10, 20, 40, 60, 80, 100]:
|
90
|
-
jax.block_until_ready(run(s))
|
91
|
-
|
92
|
-
t0 = time.time()
|
93
|
-
n, rate = jax.block_until_ready(run(s))
|
94
|
-
t1 = time.time()
|
95
|
-
print(f'scale={s}, size={n}, time = {t1 - t0} s, firing rate = {rate} Hz')
|
96
|
-
|
97
|
-
|
98
|
-
# A6000 NVIDIA GPU
|
99
|
-
|
100
|
-
# scale=1, size=4000, time = 2.659956455230713 s, firing rate = 50.62445068359375 Hz
|
101
|
-
# scale=2, size=8000, time = 2.7318649291992188 s, firing rate = 50.613040924072266 Hz
|
102
|
-
# scale=4, size=16000, time = 2.807222604751587 s, firing rate = 50.60573959350586 Hz
|
103
|
-
# scale=6, size=24000, time = 3.026782512664795 s, firing rate = 50.60918045043945 Hz
|
104
|
-
# scale=8, size=32000, time = 3.1258811950683594 s, firing rate = 50.607574462890625 Hz
|
105
|
-
# scale=10, size=40000, time = 3.172346353530884 s, firing rate = 50.60942840576172 Hz
|
106
|
-
# scale=20, size=80000, time = 3.751189947128296 s, firing rate = 50.612369537353516 Hz
|
107
|
-
# scale=40, size=160000, time = 5.0217814445495605 s, firing rate = 50.617958068847656 Hz
|
108
|
-
# scale=60, size=240000, time = 7.002646207809448 s, firing rate = 50.61948776245117 Hz
|
109
|
-
# scale=80, size=320000, time = 9.384576320648193 s, firing rate = 50.618499755859375 Hz
|
110
|
-
# scale=100, size=400000, time = 11.69654369354248 s, firing rate = 50.61605453491211 Hz
|
111
|
-
|
112
|
-
|
113
|
-
# AMD Ryzen 7 7840HS
|
114
|
-
|
115
|
-
# scale=1, size=4000, time = 4.436027526855469 s, firing rate = 50.6119270324707 Hz
|
116
|
-
# scale=2, size=8000, time = 8.349745273590088 s, firing rate = 50.612266540527344 Hz
|
117
|
-
# scale=4, size=16000, time = 16.39163303375244 s, firing rate = 50.61349105834961 Hz
|
118
|
-
# scale=6, size=24000, time = 15.725558042526245 s, firing rate = 50.6125602722168 Hz
|
119
|
-
# scale=8, size=32000, time = 21.31995177268982 s, firing rate = 50.61244583129883 Hz
|
120
|
-
# scale=10, size=40000, time = 27.811061143875122 s, firing rate = 50.61423873901367 Hz
|
121
|
-
# scale=20, size=80000, time = 45.54235219955444 s, firing rate = 50.61320877075195 Hz
|
122
|
-
# scale=40, size=160000, time = 82.22228026390076 s, firing rate = 50.61309814453125 Hz
|
123
|
-
# scale=60, size=240000, time = 125.44037556648254 s, firing rate = 50.613094329833984 Hz
|
124
|
-
# scale=80, size=320000, time = 171.20458459854126 s, firing rate = 50.613365173339844 Hz
|
125
|
-
# scale=100, size=400000, time = 215.4547393321991 s, firing rate = 50.6129150390625 Hz
|
benchmark/CUBA_2005.py
DELETED
@@ -1,149 +0,0 @@
|
|
1
|
-
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
|
16
|
-
#
|
17
|
-
# Implementation of the paper:
|
18
|
-
#
|
19
|
-
# - Brette, R., Rudolph, M., Carnevale, T., Hines, M., Beeman, D., Bower, J. M., et al. (2007),
|
20
|
-
# Simulation of networks of spiking neurons: a review of tools and strategies., J. Comput. Neurosci., 23, 3, 349–98
|
21
|
-
#
|
22
|
-
# which is based on the balanced network proposed by:
|
23
|
-
#
|
24
|
-
# - Vogels, T. P. and Abbott, L. F. (2005), Signal propagation and logic gating in networks of integrate-and-fire neurons., J. Neurosci., 25, 46, 10786–95
|
25
|
-
#
|
26
|
-
|
27
|
-
import os
|
28
|
-
import sys
|
29
|
-
|
30
|
-
sys.path.append('../')
|
31
|
-
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.99'
|
32
|
-
os.environ['JAX_TRACEBACK_FILTERING'] = 'off'
|
33
|
-
|
34
|
-
|
35
|
-
import jax
|
36
|
-
import time
|
37
|
-
|
38
|
-
import brainunit as u
|
39
|
-
|
40
|
-
import brainstate as bst
|
41
|
-
|
42
|
-
|
43
|
-
|
44
|
-
class FixedProb(bst.nn.Module):
|
45
|
-
def __init__(self, n_pre, n_post, prob, weight):
|
46
|
-
super().__init__()
|
47
|
-
self.prob = prob
|
48
|
-
self.weight = weight
|
49
|
-
self.n_pre = n_pre
|
50
|
-
self.n_post = n_post
|
51
|
-
|
52
|
-
self.mask = bst.random.rand(n_pre, n_post) < prob
|
53
|
-
|
54
|
-
def update(self, x):
|
55
|
-
return (x @ self.mask) * self.weight
|
56
|
-
|
57
|
-
|
58
|
-
class EINet(bst.nn.DynamicsGroup):
|
59
|
-
def __init__(self, scale=1.0):
|
60
|
-
super().__init__()
|
61
|
-
self.n_exc = int(3200 * scale)
|
62
|
-
self.n_inh = int(800 * scale)
|
63
|
-
self.num = self.n_exc + self.n_inh
|
64
|
-
self.N = bst.nn.LIFRef(
|
65
|
-
self.num, V_rest=-49. * u.mV, V_th=-50. * u.mV, V_reset=-60. * u.mV,
|
66
|
-
tau=20. * u.ms, tau_ref=5. * u.ms,
|
67
|
-
V_initializer=bst.init.Normal(-55., 2., unit=u.mV)
|
68
|
-
)
|
69
|
-
self.E = bst.nn.AlignPostProj(
|
70
|
-
comm=bst.event.FixedProb(self.n_exc, self.num, prob=80 / self.num, weight=1.62 * u.mS),
|
71
|
-
# comm=FixedProb(self.n_exc, self.num, prob=80 / self.num, weight=1.62 * u.mS),
|
72
|
-
syn=bst.nn.Expon.desc(self.num, tau=5. * u.ms),
|
73
|
-
out=bst.nn.CUBA.desc(scale=u.volt),
|
74
|
-
post=self.N
|
75
|
-
)
|
76
|
-
self.I = bst.nn.AlignPostProj(
|
77
|
-
comm=bst.event.FixedProb(self.n_inh, self.num, prob=80 / self.num, weight=-9.0 * u.mS),
|
78
|
-
# comm=FixedProb(self.n_inh, self.num, prob=80 / self.num, weight=-9.0 * u.mS),
|
79
|
-
syn=bst.nn.Expon.desc(self.num, tau=10. * u.ms),
|
80
|
-
out=bst.nn.CUBA.desc(scale=u.volt),
|
81
|
-
post=self.N
|
82
|
-
)
|
83
|
-
|
84
|
-
def init_state(self, *args, **kwargs):
|
85
|
-
self.rate = bst.ShortTermState(u.math.zeros(self.num))
|
86
|
-
|
87
|
-
def update(self, t, inp):
|
88
|
-
with bst.environ.context(t=t):
|
89
|
-
spk = self.N.get_spike()
|
90
|
-
self.E(spk[:self.n_exc])
|
91
|
-
self.I(spk[self.n_exc:])
|
92
|
-
self.N(inp)
|
93
|
-
self.rate.value += self.N.get_spike()
|
94
|
-
|
95
|
-
|
96
|
-
@bst.compile.jit(static_argnums=0)
|
97
|
-
def run(scale: float):
|
98
|
-
# network
|
99
|
-
net = EINet(scale)
|
100
|
-
bst.nn.init_all_states(net)
|
101
|
-
|
102
|
-
duration = 1e4 * u.ms
|
103
|
-
# simulation
|
104
|
-
with bst.environ.context(dt=0.1 * u.ms):
|
105
|
-
times = u.math.arange(0. * u.ms, duration, bst.environ.get_dt())
|
106
|
-
bst.compile.for_loop(lambda t: net.update(t, 20. * u.mA), times,
|
107
|
-
# pbar=bst.compile.ProgressBar(100)
|
108
|
-
)
|
109
|
-
|
110
|
-
return net.num, net.rate.value.sum() / net.num / duration.to_decimal(u.second)
|
111
|
-
|
112
|
-
|
113
|
-
for s in [1, 2, 4, 6, 8, 10, 20, 40, 60, 80, 100]:
|
114
|
-
jax.block_until_ready(run(s))
|
115
|
-
|
116
|
-
t0 = time.time()
|
117
|
-
n, rate = jax.block_until_ready(run(s))
|
118
|
-
t1 = time.time()
|
119
|
-
print(f'scale={s}, size={n}, time = {t1 - t0} s, firing rate = {rate} Hz')
|
120
|
-
|
121
|
-
|
122
|
-
# A6000 NVIDIA GPU
|
123
|
-
|
124
|
-
# scale=1, size=4000, time = 2.6354849338531494 s, firing rate = 24.982027053833008 Hz
|
125
|
-
# scale=2, size=8000, time = 2.6781561374664307 s, firing rate = 23.719463348388672 Hz
|
126
|
-
# scale=4, size=16000, time = 2.7448785305023193 s, firing rate = 24.592931747436523 Hz
|
127
|
-
# scale=6, size=24000, time = 2.8237478733062744 s, firing rate = 24.159996032714844 Hz
|
128
|
-
# scale=8, size=32000, time = 2.9344418048858643 s, firing rate = 24.956790924072266 Hz
|
129
|
-
# scale=10, size=40000, time = 3.042517900466919 s, firing rate = 23.644424438476562 Hz
|
130
|
-
# scale=20, size=80000, time = 3.6727631092071533 s, firing rate = 24.226743698120117 Hz
|
131
|
-
# scale=40, size=160000, time = 4.857396602630615 s, firing rate = 24.329742431640625 Hz
|
132
|
-
# scale=60, size=240000, time = 6.812030792236328 s, firing rate = 24.370006561279297 Hz
|
133
|
-
# scale=80, size=320000, time = 9.227966547012329 s, firing rate = 24.41067886352539 Hz
|
134
|
-
# scale=100, size=400000, time = 11.405697584152222 s, firing rate = 24.32524871826172 Hz
|
135
|
-
|
136
|
-
|
137
|
-
# AMD Ryzen 7 7840HS
|
138
|
-
|
139
|
-
# scale=1, size=4000, time = 1.1661601066589355 s, firing rate = 22.438201904296875 Hz
|
140
|
-
# scale=2, size=8000, time = 3.3255884647369385 s, firing rate = 23.868364334106445 Hz
|
141
|
-
# scale=4, size=16000, time = 6.950139999389648 s, firing rate = 24.21693229675293 Hz
|
142
|
-
# scale=6, size=24000, time = 10.011993169784546 s, firing rate = 24.240270614624023 Hz
|
143
|
-
# scale=8, size=32000, time = 13.027734518051147 s, firing rate = 24.753198623657227 Hz
|
144
|
-
# scale=10, size=40000, time = 16.449942350387573 s, firing rate = 24.7176570892334 Hz
|
145
|
-
# scale=20, size=80000, time = 30.754598140716553 s, firing rate = 24.119956970214844 Hz
|
146
|
-
# scale=40, size=160000, time = 63.6387836933136 s, firing rate = 24.72784996032715 Hz
|
147
|
-
# scale=60, size=240000, time = 78.58532166481018 s, firing rate = 24.402742385864258 Hz
|
148
|
-
# scale=80, size=320000, time = 102.4250214099884 s, firing rate = 24.59092140197754 Hz
|
149
|
-
# scale=100, size=400000, time = 145.35173273086548 s, firing rate = 24.33751106262207 Hz
|
File without changes
|
File without changes
|