brainstate 0.1.0.post20241129__py2.py3-none-any.whl → 0.1.0.post20241210__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
brainstate/_state.py CHANGED
@@ -99,6 +99,9 @@ def catch_new_states(tag: str = None) -> List:
 
 
 class Catcher:
+    """
+    The catcher to catch the new states.
+    """
     def __init__(self, tag: str):
         self.tag = tag
         self.state_ids = set()
@@ -231,6 +234,7 @@ class State(Generic[A], PrettyRepr):
         # avoid using self._setattr to avoid the check
         vars(self).update(metadata)
 
+        # record the state initialization
         record_state_init(self)
 
         if not TYPE_CHECKING:
@@ -290,7 +294,6 @@ class State(Generic[A], PrettyRepr):
           v: The value.
         """
         self.write_value(v)
-        self._been_writen = True
 
     def write_value(self, v) -> None:
         # value checking
@@ -301,6 +304,8 @@ class State(Generic[A], PrettyRepr):
         record_state_value_write(self)
         # set the value
         self._value = v
+        # set flag
+        self._been_writen = True
 
     def restore_value(self, v) -> None:
         """
@@ -511,6 +516,15 @@ class LongTermState(State):
     __module__ = 'brainstate'
 
 
+class BatchState(LongTermState):
+    """
+    The batch state, which is used to store the batch data in the program.
+    """
+
+    __module__ = 'brainstate'
+
+
+
 class HiddenState(ShortTermState):
     """
     The hidden state, which is used to store the hidden data in a dynamic model.
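Note: `BatchState` is a new `LongTermState` subclass; later hunks in this diff pass it as the `mean_type` of the normalization layers so that running batch statistics are stored in it. A minimal sketch of how such a state is used, assuming `BatchState` is re-exported at the package top level like the other state classes:

    import jax.numpy as jnp
    import brainstate as bst

    # hypothetical usage: a running mean kept in a BatchState
    running_mean = bst.BatchState(jnp.zeros(4))
    running_mean.value = 0.9 * running_mean.value + 0.1 * jnp.ones(4)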
brainstate/augment/_mapping_test.py CHANGED
@@ -204,7 +204,91 @@ class TestVmap(unittest.TestCase):
         self.assertTrue(jnp.allclose(foo.a.value, jnp.arange(4)))
         self.assertTrue(jnp.allclose(foo.b.value, jnp.arange(4) * jnp.arange(4)))
 
+        write_state_ids = [id(st) for st in trace.get_write_states()]
+        read_state_ids = [id(st) for st in trace.get_read_states()]
+
+        assert id(foo.a) in read_state_ids
+        assert id(foo.b) in write_state_ids
+
+        print(trace.get_write_states())
+        print(trace.get_read_states())
+
+
+
+    def test_vmap_jit(self):
+        class Foo(bst.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.a = bst.ParamState(jnp.arange(4))
+                self.b = bst.ShortTermState(jnp.arange(4))
+
+            def __call__(self):
+                self.b.value = self.a.value * self.b.value
+
+        @bst.augment.vmap
+        def mul(foo):
+            foo()
+
+        @bst.compile.jit
+        def mul_jit(inp):
+            mul(foo)
+            foo.a.value += inp
+
+        foo = Foo()
+        with bst.StateTraceStack() as trace:
+            mul_jit(1.)
+
+        print(foo.a.value)
+        print(foo.b.value)
+        self.assertTrue(jnp.allclose(foo.a.value, jnp.arange(4) + 1.))
+        self.assertTrue(jnp.allclose(foo.b.value, jnp.arange(4) * jnp.arange(4)))
+
+        write_state_ids = [id(st) for st in trace.get_write_states()]
+        read_state_ids = [id(st) for st in trace.get_read_states()]
+
+        assert id(foo.a) in write_state_ids
+        assert id(foo.b) in write_state_ids
+
         print(trace.get_write_states())
         print(trace.get_read_states())
 
 
+    def test_vmap_grad(self):
+        class Foo(bst.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.a = bst.ParamState(jnp.arange(4.))
+                self.b = bst.ShortTermState(jnp.arange(4.))
+
+            def __call__(self):
+                self.b.value = self.a.value * self.b.value
+
+        @bst.augment.vmap
+        def mul(foo):
+            foo()
+
+        def loss():
+            mul(foo)
+            return jnp.sum(foo.b.value)
+
+        foo = Foo()
+        with bst.StateTraceStack() as trace:
+            grads, loss = bst.augment.grad(loss, foo.states(bst.ParamState), return_value=True)()
+            print(grads)
+            print(loss)
+
+        # print(foo.a.value)
+        # print(foo.b.value)
+        # self.assertTrue(jnp.allclose(foo.a.value, jnp.arange(4) + 1.))
+        # self.assertTrue(jnp.allclose(foo.b.value, jnp.arange(4) * jnp.arange(4)))
+        #
+        # write_state_ids = [id(st) for st in trace.get_write_states()]
+        # read_state_ids = [id(st) for st in trace.get_read_states()]
+        #
+        # assert id(foo.a) in write_state_ids
+        # assert id(foo.b) in write_state_ids
+        #
+        # print(trace.get_write_states())
+        # print(trace.get_read_states())
+
+
brainstate/compile/_conditions.py CHANGED
@@ -94,9 +94,8 @@ def cond(pred, true_fun: Callable, false_fun: Callable, *operands):
         return false_fun(*operands)
 
     # evaluate jaxpr
-    with jax.ensure_compile_time_eval():
-        stateful_true = StatefulFunction(true_fun).make_jaxpr(*operands)
-        stateful_false = StatefulFunction(false_fun).make_jaxpr(*operands)
+    stateful_true = StatefulFunction(true_fun).make_jaxpr(*operands)
+    stateful_false = StatefulFunction(false_fun).make_jaxpr(*operands)
 
     # state trace and state values
     state_trace = stateful_true.get_state_trace() + stateful_false.get_state_trace()
@@ -175,10 +174,9 @@ def switch(index, branches: Sequence[Callable], *operands):
         return branches[int(index)](*operands)
 
     # evaluate jaxpr
-    with jax.ensure_compile_time_eval():
-        wrapped_branches = [StatefulFunction(branch) for branch in branches]
-        for wrapped_branch in wrapped_branches:
-            wrapped_branch.make_jaxpr(*operands)
+    wrapped_branches = [StatefulFunction(branch) for branch in branches]
+    for wrapped_branch in wrapped_branches:
+        wrapped_branch.make_jaxpr(*operands)
 
     # wrap the functions
     state_trace = wrapped_branches[0].get_state_trace() + wrapped_branches[1].get_state_trace()
brainstate/compile/_jit.py CHANGED
@@ -83,9 +83,9 @@ def _get_jitted_fun(
         return fun.fun(*args, **params)
 
     # compile the function and get the state trace
-    with jax.ensure_compile_time_eval():
-        state_trace = fun.compile_function_and_get_state_trace(*args, **params, return_only_write=True)
-        read_state_vals = state_trace.get_read_state_values(True)
+    state_trace = fun.compile_function_and_get_state_trace(*args, **params, return_only_write=True)
+    read_state_vals = state_trace.get_read_state_values(True)
+
     # call the jitted function
     write_state_vals, outs = jit_fun(state_trace.get_state_values(), *args, **params)
     # write the state values back to the states
brainstate/compile/_loop_collect_return.py CHANGED
@@ -202,16 +202,19 @@ def scan(
     # ------------------------------ #
     xs_avals = [jax.core.raise_to_shaped(jax.core.get_aval(x)) for x in xs_flat]
     x_avals = [jax.core.mapped_aval(length, 0, aval) for aval in xs_avals]
-    with jax.ensure_compile_time_eval():
-        stateful_fun = StatefulFunction(f).make_jaxpr(init, xs_tree.unflatten(x_avals))
-        state_trace = stateful_fun.get_state_trace()
-        all_writen_state_vals = state_trace.get_write_state_values(True)
-        all_read_state_vals = state_trace.get_read_state_values(True)
-        wrapped_f = wrap_single_fun(stateful_fun, state_trace.been_writen, all_read_state_vals)
+    stateful_fun = StatefulFunction(f).make_jaxpr(init, xs_tree.unflatten(x_avals))
+    state_trace = stateful_fun.get_state_trace()
+    all_writen_state_vals = state_trace.get_write_state_values(True)
+    all_read_state_vals = state_trace.get_read_state_values(True)
+    wrapped_f = wrap_single_fun(stateful_fun, state_trace.been_writen, all_read_state_vals)
 
     # scan
     init = (all_writen_state_vals, init)
-    (all_writen_state_vals, carry), ys = jax.lax.scan(wrapped_f, init, xs, length=length, reverse=reverse,
+    (all_writen_state_vals, carry), ys = jax.lax.scan(wrapped_f,
+                                                      init,
+                                                      xs,
+                                                      length=length,
+                                                      reverse=reverse,
                                                       unroll=unroll)
     # assign the written state values and restore the read state values
     write_back_state_values(state_trace, all_read_state_vals, all_writen_state_vals)
brainstate/compile/_loop_no_collection.py CHANGED
@@ -103,11 +103,10 @@ def while_loop(
         pass
 
     # evaluate jaxpr
-    with jax.ensure_compile_time_eval():
-        stateful_cond = StatefulFunction(cond_fun).make_jaxpr(init_val)
-        stateful_body = StatefulFunction(body_fun).make_jaxpr(init_val)
-        if len(stateful_cond.get_write_states()) != 0:
-            raise ValueError("while_loop: cond_fun should not have any write states.")
+    stateful_cond = StatefulFunction(cond_fun).make_jaxpr(init_val)
+    stateful_body = StatefulFunction(body_fun).make_jaxpr(init_val)
+    if len(stateful_cond.get_write_states()) != 0:
+        raise ValueError("while_loop: cond_fun should not have any write states.")
 
     # state trace and state values
     state_trace = stateful_cond.get_state_trace() + stateful_body.get_state_trace()
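Note: only the `jax.ensure_compile_time_eval()` wrapper around jaxpr evaluation is dropped here; the rule that `cond_fun` must not write any `State` (it raises `ValueError`) is unchanged. A minimal sketch of that rule, assuming `while_loop` is exported as `brainstate.compile.while_loop`:

    import jax.numpy as jnp
    import brainstate as bst

    counter = bst.ShortTermState(jnp.array(0))

    def cond_fun(x):
        return x < 10                        # reading states is fine; writing one here raises ValueError

    def body_fun(x):
        counter.value = counter.value + 1    # writes are allowed in the body
        return x + 1

    out = bst.compile.while_loop(cond_fun, body_fun, jnp.array(0))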
brainstate/compile/_make_jaxpr.py CHANGED
@@ -72,7 +72,6 @@ from jax.util import wraps
 from brainstate._state import State, StateTraceStack
 from brainstate._utils import set_module_as
 from brainstate.typing import PyTree
-from brainstate.util._tracers import new_jax_trace
 
 AxisName = Hashable
 
@@ -112,28 +111,27 @@ def _new_arg_fn(frame, trace, aval):
     return tracer
 
 
-def _init_state_trace() -> StateTraceStack:
-    # Should be within the calling of ``jax.make_jaxpr()``
-    frame, trace = new_jax_trace()
+def _new_jax_trace():
+    main = jax.core.thread_local_state.trace_state.trace_stack.stack[-1]
+    frame = main.jaxpr_stack[-1]
+    trace = pe.DynamicJaxprTrace(main, jax.core.cur_sublevel())
+    return frame, trace
+
+
+def _init_state_trace_stack() -> StateTraceStack:
     state_trace: StateTraceStack = StateTraceStack()
-    # Set the function to transform the new argument to a tracer
-    state_trace.set_new_arg(functools.partial(_new_arg_fn, frame, trace))
-    return state_trace
 
+    if jax.__version_info__ < (0, 4, 36):
+        # Should be within the calling of ``jax.make_jaxpr()``
+        frame, trace = _new_jax_trace()
+        # Set the function to transform the new argument to a tracer
+        state_trace.set_new_arg(functools.partial(_new_arg_fn, frame, trace))
+        return state_trace
 
-# def wrapped_abstractify(x: Any) -> Any:
-#     """
-#     Abstractify the input.
-#
-#     Args:
-#       x: The input.
-#
-#     Returns:
-#       The abstractified input.
-#     """
-#     if isinstance(x, pe.DynamicJaxprTracer):
-#         return jax.core.ShapedArray(x.aval.shape, x.aval.dtype, weak_type=x.aval.weak_type)
-#     return shaped_abstractify(x)
+    else:
+        trace = jax.core.trace_ctx.trace
+        state_trace.set_new_arg(trace.new_arg)
+        return state_trace
 
 
 class StatefulFunction(object):
@@ -383,12 +381,15 @@ class StatefulFunction(object):
           A tuple of the states that are read and written by the function and the output of the function.
         """
         # state trace
-        state_trace = _init_state_trace()
+        state_trace = _init_state_trace_stack()
         self._cached_state_trace[cache_key] = state_trace
         with state_trace:
             out = self.fun(*args, **kwargs)
-            state_values = state_trace.get_write_state_values(
-                True) if return_only_write else state_trace.get_state_values()
+            state_values = (
+                state_trace.get_write_state_values(True)
+                if return_only_write else
+                state_trace.get_state_values()
+            )
         state_trace.recovery_original_values()
 
         # State instance as functional returns is not allowed.
@@ -419,17 +420,21 @@ class StatefulFunction(object):
         try:
             # jaxpr
             jaxpr, (out_shapes, state_shapes) = _make_jaxpr(
-                functools.partial(self._wrapped_fun_to_eval, cache_key, return_only_write=return_only_write),
+                functools.partial(
+                    self._wrapped_fun_to_eval,
+                    cache_key,
+                    return_only_write=return_only_write
+                ),
                 static_argnums=self.static_argnums,
                 axis_env=self.axis_env,
                 return_shape=True,
                 abstracted_axes=self.abstracted_axes
             )(*args, **kwargs)
-
             # returns
             self._cached_jaxpr_out_tree[cache_key] = jax.tree.structure((out_shapes, state_shapes))
             self._cached_out_shapes[cache_key] = (out_shapes, state_shapes)
             self._cached_jaxpr[cache_key] = jaxpr
+
         except Exception as e:
             try:
                 self._cached_state_trace.pop(cache_key)
brainstate/compile/_progress_bar.py CHANGED
@@ -93,14 +93,32 @@ class ProgressBarRunner(object):
             self.tqdm_bars[0].update(self.remainder)
         self.tqdm_bars[0].close()
 
+    def _tqdm(self, is_init, is_print, is_final):
+        if is_init:
+            self.tqdm_bars[0] = tqdm(range(self.n), **self.kwargs)
+            self.tqdm_bars[0].set_description(self.message, refresh=False)
+        if is_print:
+            self.tqdm_bars[0].update(self.print_freq)
+        if is_final:
+            if self.remainder > 0:
+                self.tqdm_bars[0].update(self.remainder)
+            self.tqdm_bars[0].close()
+
     def __call__(self, iter_num, *args, **kwargs):
+        # jax.debug.callback(
+        #     self._tqdm,
+        #     iter_num == 0,
+        #     (iter_num + 1) % self.print_freq == 0,
+        #     iter_num == self.n - 1
+        # )
+
         _ = jax.lax.cond(
             iter_num == 0,
             lambda: jax.debug.callback(self._define_tqdm),
             lambda: None,
         )
         _ = jax.lax.cond(
-            (iter_num + 1) % self.print_freq == 0,
+            iter_num % self.print_freq == (self.print_freq - 1),
             lambda: jax.debug.callback(self._update_tqdm),
             lambda: None,
         )
@@ -109,3 +127,4 @@ class ProgressBarRunner(object):
             lambda: jax.debug.callback(self._close_tqdm),
             lambda: None,
         )
+
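Note: the rewritten update condition is arithmetically equivalent to the old one: for any non-negative integer i and positive f, (i + 1) % f == 0 holds exactly when i % f == f - 1. A quick check of the identity:

    f = 5
    assert all(((i + 1) % f == 0) == (i % f == f - 1) for i in range(100))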
brainstate/functional/_activations.py CHANGED
@@ -588,8 +588,7 @@ def glu(x: ArrayLike, axis: int = -1) -> Union[jax.Array, u.Quantity]:
 
 def log_softmax(x: ArrayLike,
                 axis: int | tuple[int, ...] | None = -1,
-                where: ArrayLike | None = None,
-                initial: ArrayLike | None = None) -> Union[jax.Array, u.Quantity]:
+                where: ArrayLike | None = None) -> Union[jax.Array, u.Quantity]:
     r"""Log-Softmax function.
 
     Computes the logarithm of the :code:`softmax` function, which rescales
@@ -604,8 +603,6 @@ def log_softmax(x: ArrayLike,
       axis: the axis or axes along which the :code:`log_softmax` should be
         computed. Either an integer or a tuple of integers.
      where: Elements to include in the :code:`log_softmax`.
-      initial: The minimum value used to shift the input array. Must be present
-        when :code:`where` is not None.
 
     Returns:
       An array.
@@ -613,15 +610,12 @@ def log_softmax(x: ArrayLike,
     See also:
      :func:`softmax`
     """
-    if initial is not None:
-        initial = u.Quantity(initial).in_unit(u.get_unit(x)).mantissa
-    return _keep_unit(jax.nn.log_softmax, x, axis=axis, where=where, initial=initial)
+    return _keep_unit(jax.nn.log_softmax, x, axis=axis, where=where)
 
 
 def softmax(x: ArrayLike,
             axis: int | tuple[int, ...] | None = -1,
-            where: ArrayLike | None = None,
-            initial: ArrayLike | None = None) -> Union[jax.Array, u.Quantity]:
+            where: ArrayLike | None = None) -> Union[jax.Array, u.Quantity]:
     r"""Softmax function.
 
     Computes the function which rescales elements to the range :math:`[0, 1]`
@@ -645,9 +639,7 @@ def softmax(x: ArrayLike,
     See also:
      :func:`log_softmax`
     """
-    if initial is not None:
-        initial = u.Quantity(initial).in_unit(u.get_unit(x)).mantissa
-    return _keep_unit(jax.nn.softmax, x, axis=axis, where=where, initial=initial)
+    return _keep_unit(jax.nn.softmax, x, axis=axis, where=where)
 
 
 def standardize(x: ArrayLike,
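Note: the `initial` keyword is dropped from these wrappers, so only `axis` and `where` are forwarded to `jax.nn.log_softmax` / `jax.nn.softmax` (newer JAX releases deprecate the `initial` argument of these functions). A minimal call sketch against the updated signature, assuming `softmax` is exposed as `brainstate.functional.softmax`:

    import jax.numpy as jnp
    import brainstate as bst

    x = jnp.array([1.0, 2.0, 3.0, 4.0])
    mask = jnp.array([True, True, True, False])
    p = bst.functional.softmax(x, where=mask)   # masked entries are excluded from the normalization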
brainstate/graph/_graph_operation.py CHANGED
@@ -608,7 +608,10 @@ def _get_children(graph_def, state_mapping, index_ref, index_ref_cache):
                 if isinstance(value, TreefyState):
                     variable.update_from_ref(value)
                 elif isinstance(value, State):
-                    variable.restore_value(value.value)
+                    if value._been_writen:
+                        variable.write_value(value.value)
+                    else:
+                        variable.restore_value(value.value)
                 else:
                     raise ValueError(f'Expected a State type for {key!r}, but got {type(value)}.')
             else:  # if it doesn't, create a new variable
brainstate/nn/_collective_ops.py CHANGED
@@ -20,8 +20,10 @@ from typing import Dict, Callable, TypeVar
 
 import jax
 
+from brainstate._state import catch_new_states
 from brainstate._utils import set_module_as
 from brainstate.graph import nodes
+from brainstate.util._filter import Filter
 from ._module import Module
 
 # the maximum order
@@ -74,16 +76,29 @@ def call_order(level: int = 0, check_order_boundary: bool = True):
 
 
 @set_module_as('brainstate.nn')
-def init_all_states(target: T, *args, exclude=None, **kwargs) -> T:
+def init_all_states(
+    target: T,
+    *args,
+    exclude: Filter = None,
+    **kwargs
+) -> T:
     """
     Collectively initialize states of all children nodes in the given target.
 
     Args:
      target: The target Module.
+      exclude: The filter to exclude some nodes.
+      tag: The tag for the new states.
+      args: The positional arguments for the initialization, which will be passed to the `init_state` method
+        of each node.
+      kwargs: The keyword arguments for the initialization, which will be passed to the `init_state` method
+        of each node.
 
    Returns:
      The target Module.
    """
+
+    # node that has `call_order` decorated
    nodes_with_order = []
 
    nodes_ = nodes(target).filter(Module)
@@ -97,7 +112,7 @@ def init_all_states(target: T, *args, exclude=None, **kwargs) -> T:
        else:
            node.init_state(*args, **kwargs)
 
-    # reset the node's states
+    # reset the node's states with `call_order`
    for node in sorted(nodes_with_order, key=lambda x: x.init_state.call_order):
        node.init_state(*args, **kwargs)
 
@@ -115,6 +130,7 @@ def reset_all_states(target: Module, *args, **kwargs) -> Module:
    Returns:
      The target Module.
    """
+
    nodes_with_order = []
 
    # reset node whose `init_state` has no `call_order`
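Note: `init_all_states` now takes an explicit `exclude: Filter` keyword and documents that extra positional and keyword arguments are forwarded to each child's `init_state`. Basic usage, as in the benchmark scripts removed later in this diff (the `EINet` class comes from those scripts):

    import brainstate as bst

    net = EINet(scale=1.0)        # any bst.nn.Module / DynamicsGroup
    bst.nn.init_all_states(net)   # runs init_state on every child, honoring any `call_order` decoration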
brainstate/nn/_dyn_impl/_dynamics_synapse.py CHANGED
@@ -112,7 +112,7 @@ class STP(Synapse):
         self.u.value = init.param(init.Constant(self.U), self.varshape, batch_size)
 
     def update(self, pre_spike):
-        u = exp_euler_step(lambda u: self.U - u / self.tau_f, self.u.value)
+        u = exp_euler_step(lambda u: - u / self.tau_f, self.u.value)
         x = exp_euler_step(lambda x: (1 - x) / self.tau_d, self.x.value)
 
         # --- original code:
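Note: this changes the facilitation drift from `u' = U - u / tau_f` (as the old lambda evaluated, with `U` added outside the division) to a plain decay `u' = -u / tau_f`, so between spikes `u` now relaxes exponentially toward zero. A sketch of the closed-form step that an exponential-Euler update of the new drift corresponds to (plain NumPy, not brainstate's `exp_euler_step` itself; the values are assumed examples):

    import numpy as np

    tau_f, dt = 100.0, 0.1             # example time constant and step, in ms
    u = 0.8
    u_next = u * np.exp(-dt / tau_f)   # exact solution of du/dt = -u / tau_f over one step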
brainstate/nn/_elementwise/_dropout.py CHANGED
@@ -17,7 +17,7 @@
 from __future__ import annotations
 
 from functools import partial
-from typing import Optional
+from typing import Optional, Sequence
 
 import brainunit as u
 import jax.numpy as jnp
@@ -29,7 +29,6 @@ from brainstate.typing import Size
 
 __all__ = [
     'DropoutFixed', 'Dropout', 'Dropout1d', 'Dropout2d', 'Dropout3d',
-    'AlphaDropout', 'FeatureAlphaDropout',
 ]
 
 
@@ -47,9 +46,9 @@ class Dropout(ElementWiseBlock):
       research 15.1 (2014): 1929-1958.
 
     Args:
-      prob: Probability to keep element of the tensor.
-      mode: Mode. The computation mode of the object.
-      name: str. The name of the dynamic system.
+      prob: Probability to keep element of the tensor.
+      broadcast_dims: dimensions that will share the same dropout mask.
+      name: str. The name of the dynamic system.
 
     """
     __module__ = 'brainstate.nn'
@@ -57,20 +56,28 @@ class Dropout(ElementWiseBlock):
     def __init__(
         self,
         prob: float = 0.5,
+        broadcast_dims: Sequence[int] = (),
         name: Optional[str] = None
     ) -> None:
         super().__init__(name=name)
         assert 0. <= prob <= 1., f"Dropout probability must be in the range [0, 1]. But got {prob}."
         self.prob = prob
+        self.broadcast_dims = broadcast_dims
 
     def __call__(self, x):
         dtype = u.math.get_dtype(x)
         fit_phase = environ.get('fit', desc='Whether this is a fitting process. Bool.')
         if fit_phase and self.prob < 1.:
-            keep_mask = random.bernoulli(self.prob, x.shape)
-            return jnp.where(keep_mask,
-                             jnp.asarray(x / self.prob, dtype=dtype),
-                             jnp.asarray(0., dtype=dtype))
+            broadcast_shape = list(x.shape)
+            for dim in self.broadcast_dims:
+                broadcast_shape[dim] = 1
+            keep_mask = random.bernoulli(self.prob, broadcast_shape)
+            keep_mask = jnp.broadcast_to(keep_mask, x.shape)
+            return jnp.where(
+                keep_mask,
+                jnp.asarray(x / self.prob, dtype=dtype),
+                jnp.asarray(0., dtype=dtype)
+            )
        else:
            return x
 
@@ -93,7 +100,6 @@ class _DropoutNd(ElementWiseBlock):
         self.channel_axis = channel_axis
 
     def __call__(self, x):
-
         # check input shape
         inp_dim = u.math.ndim(x)
         if inp_dim not in (self.minimal_dim, self.minimal_dim + 1):
@@ -114,10 +120,13 @@ class _DropoutNd(ElementWiseBlock):
         # generate mask
         if fit_phase and self.prob < 1.:
             dtype = u.math.get_dtype(x)
-            keep_mask = jnp.broadcast_to(random.bernoulli(self.prob, mask_shape), x.shape)
-            return jnp.where(keep_mask,
-                             jnp.asarray(x / self.prob, dtype=dtype),
-                             jnp.asarray(0., dtype=dtype))
+            keep_mask = random.bernoulli(self.prob, mask_shape)
+            keep_mask = jnp.broadcast_to(keep_mask, x.shape)
+            return jnp.where(
+                keep_mask,
+                jnp.asarray(x / self.prob, dtype=dtype),
+                jnp.asarray(0., dtype=dtype)
+            )
         else:
             return x
 
@@ -296,8 +305,8 @@ class AlphaDropout(_DropoutNd):
     """
     __module__ = 'brainstate.nn'
 
-    def forward(self, x):
-        return F.alpha_dropout(x, self.p, self.training)
+    def update(self, *args, **kwargs):
+        raise NotImplementedError("AlphaDropout is not supported in the current version.")
 
 
 class FeatureAlphaDropout(_DropoutNd):
@@ -344,8 +353,8 @@ class FeatureAlphaDropout(_DropoutNd):
     """
     __module__ = 'brainstate.nn'
 
-    def forward(self, x):
-        return F.feature_alpha_dropout(x, self.p, self.training)
+    def update(self, *args, **kwargs):
+        raise NotImplementedError("FeatureAlphaDropout is not supported in the current version.")
 
 
 class DropoutFixed(ElementWiseBlock):
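Note on the new `Dropout(broadcast_dims=...)` argument shown above: axes listed there get a mask dimension of size 1, so a single Bernoulli mask is drawn over the remaining axes and broadcast across the listed ones. A minimal sketch, assuming the training flag is supplied through `environ.context` with the `fit` key that `__call__` reads:

    import brainstate as bst

    layer = bst.nn.Dropout(prob=0.8, broadcast_dims=(0,))   # prob is the keep probability
    x = bst.random.rand(10, 20)
    with bst.environ.context(fit=True):
        y = layer(x)   # one (1, 20) mask is shared across axis 0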
brainstate/nn/_interaction/_normalizations.py CHANGED
@@ -23,7 +23,7 @@ import jax
 import jax.numpy as jnp
 
 from brainstate import environ, init
-from brainstate._state import LongTermState, ParamState
+from brainstate._state import ParamState, BatchState
 from brainstate.nn._module import Module
 from brainstate.typing import DTypeLike, ArrayLike, Size, Axes
 
@@ -91,6 +91,18 @@ def _abs_sq(x):
     return jax.lax.square(x)
 
 
+class NormalizationParamState(ParamState):
+    # This is a dummy class to be used as a compatibility
+    # usage of `ETraceParam` for the layers in "brainetrace"
+    def execute(self, x):
+        param = self.value
+        if 'scale' in param:
+            x = x * param['scale']
+        if 'bias' in param:
+            x = x + param['bias']
+        return x
+
+
 def _compute_stats(
     x: ArrayLike,
     axes: Sequence[int],
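Note: `NormalizationParamState.execute` applies whichever affine terms are stored in the state's `.value`; `_normalize` now delegates to it instead of the `_scale_operation` helper removed below. The computation it performs:

    import jax.numpy as jnp

    param = {'scale': jnp.full(4, 2.0), 'bias': jnp.zeros(4)}   # the pytree the norm layers store
    x = jnp.arange(4.0)
    y = x * param['scale'] + param['bias']   # same result as execute(x) when both keys are present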
@@ -150,12 +162,17 @@ def _compute_stats(
         # In the distributed case we stack multiple arrays to speed comms.
         if len(xs) > 1:
             reduced_mus = jax.lax.pmean(
-                jnp.stack(mus, axis=0), axis_name,
+                jnp.stack(mus, axis=0),
+                axis_name,
                 axis_index_groups=axis_index_groups,
             )
             return tuple(reduced_mus[i] for i in range(len(xs)))
         else:
-            return jax.lax.pmean(mus[0], axis_name, axis_index_groups=axis_index_groups)
+            return jax.lax.pmean(
+                mus[0],
+                axis_name,
+                axis_index_groups=axis_index_groups
+            )
 
     if use_mean:
         if use_fast_variance:
@@ -176,7 +193,7 @@ def _normalize(
     x: ArrayLike,
     mean: Optional[ArrayLike],
     var: Optional[ArrayLike],
-    weights: Optional[ParamState],
+    weights: Optional[NormalizationParamState],
     reduction_axes: Axes,
     feature_axes: Axes,
     dtype: DTypeLike,
@@ -212,10 +229,9 @@ def _normalize(
         y = x - mean
         mul = jax.lax.rsqrt(var + epsilon)
         y = y * mul
-        args = []
         if weights is not None:
-            y, args = _scale_operation(y, weights.value)
-            dtype = canonicalize_dtype(x, *args, dtype=dtype)
+            y = weights.execute(y)
+            dtype = canonicalize_dtype(x, *jax.tree.leaves(weights.value), dtype=dtype)
     else:
         assert var is None, 'mean and val must be both None or not None.'
         assert weights is None, 'scale and bias are not supported without mean and val'
@@ -223,17 +239,6 @@ def _normalize(
     return jnp.asarray(y, dtype)
 
 
-def _scale_operation(x: jax.Array, param: Dict):
-    args = []
-    if 'scale' in param:
-        x = x * param['scale']
-        args.append(param['scale'])
-    if 'bias' in param:
-        x = x + param['bias']
-        args.append(param['bias'])
-    return x, args
-
-
 class _BatchNorm(Module):
     __module__ = 'brainstate.nn'
     num_spatial_dims: int
@@ -254,6 +259,8 @@ class _BatchNorm(Module):
         use_fast_variance: bool = True,
         name: Optional[str] = None,
         dtype: Any = None,
+        param_type: type = NormalizationParamState,
+        mean_type: type = BatchState,
     ):
         super().__init__(name=name)
 
@@ -279,8 +286,8 @@ class _BatchNorm(Module):
         feature_shape = tuple([(ax if i in self.feature_axes else 1)
                                for i, ax in enumerate(self.in_size)])
         if self.track_running_stats:
-            self.running_mean = LongTermState(jnp.zeros(feature_shape, dtype=self.dtype))
-            self.running_var = LongTermState(jnp.ones(feature_shape, dtype=self.dtype))
+            self.running_mean = mean_type(jnp.zeros(feature_shape, dtype=self.dtype))
+            self.running_var = mean_type(jnp.ones(feature_shape, dtype=self.dtype))
         else:
             self.running_mean = None
             self.running_var = None
@@ -290,7 +297,7 @@ class _BatchNorm(Module):
             assert track_running_stats, "Affine parameters are not needed when track_running_stats is False."
             bias = init.param(self.bias_initializer, feature_shape)
             scale = init.param(self.scale_initializer, feature_shape)
-            self.weight = ParamState(dict(bias=bias, scale=scale))
+            self.weight = param_type(dict(bias=bias, scale=scale))
         else:
             self.weight = None
 
@@ -531,6 +538,7 @@ class LayerNorm(Module):
         axis_index_groups: Any = None,
         use_fast_variance: bool = True,
         dtype: Optional[jax.typing.DTypeLike] = None,
+        param_type: type = NormalizationParamState,
     ):
         super().__init__()
 
@@ -554,7 +562,7 @@ class LayerNorm(Module):
         if use_bias:
             weights['bias'] = init.param(bias_init, feature_shape)
         if len(weights):
-            self.weight = ParamState(weights)
+            self.weight = param_type(weights)
         else:
             self.weight = None
 
@@ -654,6 +662,7 @@ class RMSNorm(Module):
         axis_name: Optional[str] = None,
         axis_index_groups: Any = None,
         use_fast_variance: bool = True,
+        param_type: type = NormalizationParamState,
     ):
         super().__init__()
 
@@ -663,7 +672,7 @@ class RMSNorm(Module):
         # parameters about axis
         feature_axes = (feature_axes,) if isinstance(feature_axes, int) else feature_axes
         self.feature_axes = _canonicalize_axes(len(self.in_size), feature_axes)
-        self.reduction_axes = (reduction_axes, ) if isinstance(reduction_axes, int) else reduction_axes
+        self.reduction_axes = (reduction_axes,) if isinstance(reduction_axes, int) else reduction_axes
         self.axis_name = axis_name
         self.axis_index_groups = axis_index_groups
 
@@ -671,7 +680,7 @@ class RMSNorm(Module):
         feature_shape = tuple([(ax if i in self.feature_axes else 1)
                                for i, ax in enumerate(self.in_size)])
         if use_scale:
-            self.scale = ParamState({'scale': init.param(scale_init, feature_shape)})
+            self.scale = param_type({'scale': init.param(scale_init, feature_shape)})
         else:
             self.scale = None
 
@@ -795,6 +804,7 @@ class GroupNorm(Module):
         axis_name: Optional[str] = None,
         axis_index_groups: Any = None,
         use_fast_variance: bool = True,
+        param_type: type = NormalizationParamState,
     ):
         super().__init__()
 
@@ -848,7 +858,7 @@ class GroupNorm(Module):
         if use_bias:
             weights['bias'] = init.param(bias_init, feature_shape)
         if len(weights):
-            self.weight = ParamState(weights)
+            self.weight = param_type(weights)
         else:
             self.weight = None
 
brainstate/util/_tracers.py CHANGED
@@ -16,7 +16,6 @@ from __future__ import annotations
 
 import jax
 import jax.core
-from jax.interpreters import partial_eval as pe
 
 from ._pretty_repr import PrettyRepr, PrettyType, PrettyAttr
 
@@ -25,12 +24,6 @@ __all__ = [
 ]
 
 
-def new_jax_trace():
-    main = jax.core.thread_local_state.trace_state.trace_stack.stack[-1]
-    frame = main.jaxpr_stack[-1]
-    trace = pe.DynamicJaxprTrace(main, jax.core.cur_sublevel())
-    return frame, trace
-
 
 def current_jax_trace():
     """Returns the Jax tracing state."""
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: brainstate
-Version: 0.1.0.post20241129
+Version: 0.1.0.post20241210
 Summary: A ``State``-based Transformation System for Program Compilation and Augmentation.
 Home-page: https://github.com/chaobrain/brainstate
 Author: BrainState Developers
@@ -1,7 +1,5 @@
-benchmark/COBA_2005.py,sha256=Q8PsZ0lxu14jsF3bCtlZW35iQB8S2_oFEUYQzK2hPiA,5561
-benchmark/CUBA_2005.py,sha256=_W94yOMh2ueqblU4ItEPeTLwHF0_lbEWlVNEBy0Tix0,6222
 brainstate/__init__.py,sha256=r7C3eLTg8LEusoH6PGgBFFt4ZgbketYLoLA0lQhUCsE,2098
-brainstate/_state.py,sha256=0h3B32130Tvv41eeolYTbEGT9FZ9WUZ9yYoGyVnte_c,28808
+brainstate/_state.py,sha256=4aDpLyHGr1VlPXeLSfM3USQG5K4o7orF7IlaBdYrtfE,29098
 brainstate/_state_test.py,sha256=1boTp1w8DiCFLsPwNtlLrlIqGRpkasAmLid5bv2fgP4,2223
 brainstate/_utils.py,sha256=uJ6WWKq3yb05ZdktCQGLWOXsOJveL1H9pR7eev70Jes,1693
 brainstate/environ.py,sha256=G6r_rqfbofRbjFFalRu_DHaL7ruFTeLRXBQDXM6P-tQ,17477
@@ -17,24 +15,24 @@ brainstate/augment/_autograd_test.py,sha256=S2eEgrwTzdSi3u2nKE3u37WSThosLwx1WCP9
 brainstate/augment/_eval_shape.py,sha256=dGlRVHOAZ9LSRZsFi1erxgEWHrnhBO3Kq3WW11-Hvng,3819
 brainstate/augment/_eval_shape_test.py,sha256=1nnxbU7hPRbZPQWNWbQ518pw-H7FGDKKnQpZGBY9uRI,1390
 brainstate/augment/_mapping.py,sha256=cpxzVGCEYnP5jPqrowYoPXciw_-QR2F3wggrRj1OCPc,21850
-brainstate/augment/_mapping_test.py,sha256=bwoRm0KVK53gpB_b2dWwJl4BQplH74RCEU1VyE3DN4M,6762
+brainstate/augment/_mapping_test.py,sha256=TEAecjZmTSDCfxARgrzcDJ2dW1Yz_sCITmFiA9FGrhk,9455
 brainstate/augment/_random.py,sha256=rkB4w4BkKsz9p8lTk31kVHvlVPJSvtGk8REn936KI_4,3071
 brainstate/compile/__init__.py,sha256=qZZIYoyEl51IFkFu-Hb-bP3PAEHo94HlTDf57P2ze08,1858
 brainstate/compile/_ad_checkpoint.py,sha256=5zJ1ENeTU4FzRY_uNpr85NhKfuicMMjcIbhu6-bSM4k,9451
 brainstate/compile/_ad_checkpoint_test.py,sha256=R1I76nG4zIqb6g3M_VxWts7rUC1OHJCjtQhPkcbXodk,1746
-brainstate/compile/_conditions.py,sha256=ocz6sDc7Xzabz2GnRsQmS6GDps-WP-OXUd0EZTTlG0k,10217
+brainstate/compile/_conditions.py,sha256=gApsHKGQrf1QBjoKXDVL7VsoeJ2zFtSc-hFz9nbYcF0,10113
 brainstate/compile/_conditions_test.py,sha256=s9LF6h9LvigvgxUIugTqvgCHBIU8TXS1Ar1OlIxXfrw,8389
 brainstate/compile/_error_if.py,sha256=TFvhqITKkRO9m30GdlUP4eEjJvLWQUhjkujXO9zvrWs,2689
 brainstate/compile/_error_if_test.py,sha256=SJmAfosVoGd4vhfFtb1IvjeFVW914bfTccCg6DoLWYk,1992
-brainstate/compile/_jit.py,sha256=bfEszNttEtE6npqHBam1_DBlRa39fE6qP6lGaWw2amA,13750
+brainstate/compile/_jit.py,sha256=3mQ-RUFz35wceZyKE_MoR58OBL0RK_i6sHm4rWYzMLs,13698
 brainstate/compile/_jit_test.py,sha256=zD7kck9SQJGmUDolh9P4luKwQ21fBGje1Z4STTEXIuA,4135
-brainstate/compile/_loop_collect_return.py,sha256=1IRhsWFad5khM1DIBLOI-kdC4pJYsT3TGqG0KO6ftjQ,22782
+brainstate/compile/_loop_collect_return.py,sha256=DybSBixeuxleKJV6n9FgVBDsUTmexzS0IdgWYRqp5cU,22940
 brainstate/compile/_loop_collect_return_test.py,sha256=bA-_11E8A_0jR5umEO3e409y7bb5QYDTgSL-SBaX7kQ,1802
-brainstate/compile/_loop_no_collection.py,sha256=2rSK20enkBMXPAbsCyb7PCICPNrgaSpl5jfumgWpxA0,7401
+brainstate/compile/_loop_no_collection.py,sha256=0i31gdQ7sI-d6pvnh08ttUUwdAtpx4uoYhGuf_CyL9s,7343
 brainstate/compile/_loop_no_collection_test.py,sha256=oStB1CSG_iLp9sHdXd1hJNFvlxbzjck9Iy4sABoJDj4,1419
-brainstate/compile/_make_jaxpr.py,sha256=9ZWYgW77M0YcrqmViyXSovuvQDoIa8hmSnhE7pBqDfU,32828
+brainstate/compile/_make_jaxpr.py,sha256=S5O9KUB3bsxoKcfptlV0MRfKA__Ija37WxkakIRL3z0,33010
 brainstate/compile/_make_jaxpr_test.py,sha256=qJUtkyj50JQ6f4UJbOLhvRdkbNn3NSKibFL9jESdQkA,4279
-brainstate/compile/_progress_bar.py,sha256=LML4DjrLSIeGYJWLjqy6BnHSz03fu1gnjf-7kljP384,3824
+brainstate/compile/_progress_bar.py,sha256=H544Oh10SiF5ccrKHM9ay7ZHigYIhNhSQGEKbDxRJgg,4485
 brainstate/compile/_unvmap.py,sha256=ewbLLNXiI_dBsEBaVzSS0BEXNol22sd9gMzk606lSkM,4139
 brainstate/compile/_util.py,sha256=aCvkTV--g4NsqcodTdBAISt4EwgezCbKzNUV58n-Q_Y,6304
 brainstate/event/__init__.py,sha256=wOBkq7kDg90M8Y9FuoXRlSEuu1ZzbIhCJ1dHeLqN6_Q,1194
@@ -52,7 +50,7 @@ brainstate/event/_misc.py,sha256=8IpPooXjF2m0-tuo3pGHqThq2yLSNmYziy_zdurZ3NI,104
 brainstate/event/_xla_custom_op.py,sha256=QB4jz_fUEPF-efJCVKAxwx8U79AqdcKoEg2QrGwot8I,10864
 brainstate/event/_xla_custom_op_test.py,sha256=rnkGMleXzLfJj4y5QqwfBvCCLTAHe_uabwBDniY-URM,1745
 brainstate/functional/__init__.py,sha256=j6-3Er4fgqWpvntzYCZVB3e5hoz-Z3aqvapITCuDri0,1107
-brainstate/functional/_activations.py,sha256=tsCGoV_35IyZ0mM6_4fV9v0-Vj3V9Qm55U8wSpZ_E4o,22180
+brainstate/functional/_activations.py,sha256=S0Ok7sq5FTbmJWSejpOCHo1jpKX0gYOLy_TO2IUXM8s,21726
 brainstate/functional/_activations_test.py,sha256=T___RlSrIfXwlkw8dg5A9EZMTZGDzv3a2evUwq_nYFg,13034
 brainstate/functional/_normalization.py,sha256=i2EV7hSsqcNdcYRX2wAxjq8doHwyN9eNJTGTaPt03xE,2605
 brainstate/functional/_others.py,sha256=_u_Ys-LiLzDAP4zJggVwaVvirgoS3jvhXMREoS6JOkM,1737
@@ -63,7 +61,7 @@ brainstate/graph/_graph_context_test.py,sha256=IYpjqbXwSFF65XL0ZbdPeC1jYyEHLpQVr
 brainstate/graph/_graph_convert.py,sha256=llSREtGQrIggkD0wmxUbYKuSveLW4ihDZME6Ab-mRTQ,9147
 brainstate/graph/_graph_node.py,sha256=BTuVlGgA2b82zNudjsN88QXuxfDcMvU2-kB64AkdQnY,8993
 brainstate/graph/_graph_node_test.py,sha256=BFGfdzZFDHI0XK7hHotSVWKt3em1taGvn8FHF9NCXx8,2702
-brainstate/graph/_graph_operation.py,sha256=C_WDUwg_cY7mduW8lb8yRG3trJy9KCaRmMluXf003Ss,64006
+brainstate/graph/_graph_operation.py,sha256=PupZeFWBR-OHbhdJcoqlvy2YqoIS9Ze4q0tz8HRy4f4,64166
 brainstate/graph/_graph_operation_test.py,sha256=ADyyuMk2xidEkkFNpGvUbvEtRmUj-tqOI4cF3eRuakM,24678
 brainstate/init/__init__.py,sha256=R1dHgub47o-WJM9QkFLc7x_Q7GsyaKKDtrRHTFPpC5g,1097
 brainstate/init/_base.py,sha256=B_NLS9aKNrvuj5NAlSgBbQTVev7IRvzcx8vH0J-Gq2w,1671
@@ -73,7 +71,7 @@ brainstate/init/_random_inits_test.py,sha256=lBL2RQdBSZ88Zqz4IMdbHJMvDi7ooZq6caC
 brainstate/init/_regular_inits.py,sha256=DmVMajugfyYFNUMzgFdDKMvbBu9hMWxkfDd-50uhoLg,3187
 brainstate/init/_regular_inits_test.py,sha256=tJl4aOkclllJIfKzJTbc0cfYCw2SoBsx8_G123RnqbU,1842
 brainstate/nn/__init__.py,sha256=rxURT8J1XfBn3Vh3Dx_WzVADWn9zVriIty5KZEG-x6o,1622
-brainstate/nn/_collective_ops.py,sha256=BzKKUxeBPtZojHqvOzgPkz9EFaVJUwQWmqR18mF2l38,6233
+brainstate/nn/_collective_ops.py,sha256=sSjIIs1MvZA30XFFmK7iL1D_sCeh7hFd3PanCH6kgZo,6779
 brainstate/nn/_exp_euler.py,sha256=yjkfSllFxGWKEAlHo5AzBizzkFj6FEVDKmFV6E2g214,3521
 brainstate/nn/_exp_euler_test.py,sha256=clwRD8QR71k1jn6NrACMDEUcFMh0J9RTosoPnlYWUkw,1242
 brainstate/nn/_module.py,sha256=HDLPvLfB7jat2VT3gBu0MxA7vfzK7xgowemitHX8Cgo,10835
@@ -82,7 +80,7 @@ brainstate/nn/metrics.py,sha256=iupHjSRTHYY-HmEPBC4tXWrZfF4zh1ek2NwSAA0gnwE,1473
 brainstate/nn/_dyn_impl/__init__.py,sha256=Oazar7h89dp1WA2Vx4Tj7gCBhxJKH4LAUEABkBEG7vU,1462
 brainstate/nn/_dyn_impl/_dynamics_neuron.py,sha256=cTbIn41EPYG0h3ICzKBXxpgB6wwA2K8k5FAcf3Pa5N8,10927
 brainstate/nn/_dyn_impl/_dynamics_neuron_test.py,sha256=Tfzrzu7udGrLJGnqItiLWe5WT0dgduvYOgzGCnaPJQg,6317
-brainstate/nn/_dyn_impl/_dynamics_synapse.py,sha256=BP-ko0FyjWZopuUhAy3Ot3wWRQlGcpumWJpKrQakqok,11869
+brainstate/nn/_dyn_impl/_dynamics_synapse.py,sha256=IHy6IsGjWpKZ8NLq4X7PaRwx3tpO2HRZNppCWM2fe4I,11862
 brainstate/nn/_dyn_impl/_dynamics_synapse_test.py,sha256=t5i-HV0ii9sUNzWTEv04o26QVtQ-mCdMJcFq2MD755A,4981
 brainstate/nn/_dyn_impl/_inputs.py,sha256=6eZKnkmrM0Gog2fpSKjSnwnQvhbFYhG4q9Vuo-GH2LI,5050
 brainstate/nn/_dyn_impl/_projection_alignpost.py,sha256=PNC1Tzx_SF2DHAHeJCufXzO_Q4qLoBpWABI45B3GRuc,876
@@ -98,7 +96,7 @@ brainstate/nn/_dynamics/_state_delay.py,sha256=nZYGmVKmQvAQu-W4YOUFH1gnr-ZS3rg_G
 brainstate/nn/_dynamics/_synouts.py,sha256=9TGAc-nVa50th7KKn4oKLbro-4W4rwxYvp-eu7ksAIE,4491
 brainstate/nn/_dynamics/_synouts_test.py,sha256=V_jDswRN4VvEXD-2yJO3VA1TALgX0HK6oPBQiUntOWc,2266
 brainstate/nn/_elementwise/__init__.py,sha256=PK8oq1K_EG2941AiUyLxCWoRdWvMO3yt8ZJbw3Lkhu8,935
-brainstate/nn/_elementwise/_dropout.py,sha256=lXy8ki5mnmhU9lacJGesy7mums0qH1pplXD6BGGdw-I,17338
+brainstate/nn/_elementwise/_dropout.py,sha256=0Ebo-2y1VswvBqZ7sCA0SEUm37y49EUsef8oiSFpYGk,17759
 brainstate/nn/_elementwise/_dropout_test.py,sha256=Qn7xqZOyZMPCGF6tFjTiPId0yELOXjSsW5-hgihP3fE,4383
 brainstate/nn/_elementwise/_elementwise.py,sha256=om-KpwDTk5yFG5KBYXXHquRLV7s28_FJjk-omvyMyvQ,33342
 brainstate/nn/_elementwise/_elementwise_test.py,sha256=SZI9jB39sZ5SO1dpWGW-PhodthwN0GU9FY1nqf2fWcs,5341
@@ -108,7 +106,7 @@ brainstate/nn/_interaction/_conv_test.py,sha256=fHXRFYnDghFiKre63RqMwIE_gbPKdK34
 brainstate/nn/_interaction/_embedding.py,sha256=iK0I1ExKWFa_QzV9UDGj32Ljsmdr1g_LlAtMcusebxU,2187
 brainstate/nn/_interaction/_linear.py,sha256=bjiWGJCe81ugQQOykwjWlLW5uhe0CHWwkPA20a4n5YQ,21340
 brainstate/nn/_interaction/_linear_test.py,sha256=KlvFZA0rpyaspf4LT4K7u-RR5jCEB_q1WReqAw9sFcU,1274
-brainstate/nn/_interaction/_normalizations.py,sha256=pdVzuMMmKP49pFbLxZJ6A9m0_nBEMEvOUN4Y1lWnhf0,37004
+brainstate/nn/_interaction/_normalizations.py,sha256=7YDzkmO_iqd70fH_wawb60Bu8eGOdvZq23emP-b68Hc,37440
 brainstate/nn/_interaction/_normalizations_test.py,sha256=2p1Jf8nA999VYGWbvOZfKYlKk6UmL0vaEB76xkXxkXw,2438
 brainstate/nn/_interaction/_poolings.py,sha256=LpwuyeNBVCaVFW7zWc7E-vvlYqx54h46Br5XT6zd_94,47020
 brainstate/nn/_interaction/_poolings_test.py,sha256=wmd5PngZ3E9tNyF3s0xk-DoDR5yFqpTi9A6nbNoIqn4,7429
@@ -136,10 +134,10 @@ brainstate/util/_others.py,sha256=jsPZwP-v_5HRV-LB5F0NUsiqr04y8bmGIsu_JMyVcbQ,14
 brainstate/util/_pretty_repr.py,sha256=NYEBCo2iz9Potx-IR7uZZzt2aLQW_94vH79fGusiC2A,5737
 brainstate/util/_scaling.py,sha256=pc_eM_SZVwkY65I4tJh1ODiHNCoEhsfFXl2zBK0PLAg,7562
 brainstate/util/_struct.py,sha256=0exv0oOiSt1hmx20Y4J2-pCGtCTx13WcAlEYSBkyung,17640
-brainstate/util/_tracers.py,sha256=-sX76GJRThdSpDJBejAIzDdBbVhmH6kb-1WoDJVI7V0,2556
+brainstate/util/_tracers.py,sha256=0r5T4nhxMzI79NtqroqitsdMT4YfpgV5RdYJLS5uJ0w,2285
 brainstate/util/_visualization.py,sha256=n4ZVz10z7VBqA0cKO6vyHwEMprWJgPeEqtITzDMai2Y,1519
-brainstate-0.1.0.post20241129.dist-info/LICENSE,sha256=VZe9u1jgUL2eCY6ZPOYgdb8KCblCHt8ECdbtJid6e1s,11550
-brainstate-0.1.0.post20241129.dist-info/METADATA,sha256=s2ks57prKGe9kPi_goPGwYVGr40TrgrEFjtC0W8ff6Q,3401
-brainstate-0.1.0.post20241129.dist-info/WHEEL,sha256=bb2Ot9scclHKMOLDEHY6B2sicWOgugjFKaJsT7vwMQo,110
-brainstate-0.1.0.post20241129.dist-info/top_level.txt,sha256=MVkn5SZk0qis32_AGxU2Tg1xADfj2IgCNS25CQD7_ng,21
-brainstate-0.1.0.post20241129.dist-info/RECORD,,
+brainstate-0.1.0.post20241210.dist-info/LICENSE,sha256=VZe9u1jgUL2eCY6ZPOYgdb8KCblCHt8ECdbtJid6e1s,11550
+brainstate-0.1.0.post20241210.dist-info/METADATA,sha256=E6AATarjpwssXflLfA-OCkxFxqZxqJNxHZteO6UWMhw,3401
+brainstate-0.1.0.post20241210.dist-info/WHEEL,sha256=bb2Ot9scclHKMOLDEHY6B2sicWOgugjFKaJsT7vwMQo,110
+brainstate-0.1.0.post20241210.dist-info/top_level.txt,sha256=eQbGgKn0ptx7FDWuua0V0wr4K1VHi2iOUCYo3fUQBRA,11
+brainstate-0.1.0.post20241210.dist-info/RECORD,,
benchmark/COBA_2005.py DELETED
@@ -1,125 +0,0 @@
-# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-
-#
-# Implementation of the paper:
-#
-# - Brette, R., Rudolph, M., Carnevale, T., Hines, M., Beeman, D., Bower, J. M., et al. (2007),
-#   Simulation of networks of spiking neurons: a review of tools and strategies., J. Comput. Neurosci., 23, 3, 349–98
-#
-# which is based on the balanced network proposed by:
-#
-# - Vogels, T. P. and Abbott, L. F. (2005), Signal propagation and logic gating in networks of integrate-and-fire neurons., J. Neurosci., 25, 46, 10786–95
-#
-import os
-import sys
-
-sys.path.append('../')
-os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.99'
-os.environ['JAX_TRACEBACK_FILTERING'] = 'off'
-
-
-import jax
-import brainunit as u
-import time
-import brainstate as bst
-
-
-class EINet(bst.nn.DynamicsGroup):
-    def __init__(self, scale):
-        super().__init__()
-        self.n_exc = int(3200 * scale)
-        self.n_inh = int(800 * scale)
-        self.num = self.n_exc + self.n_inh
-        self.N = bst.nn.LIFRef(self.num, V_rest=-60. * u.mV, V_th=-50. * u.mV, V_reset=-60. * u.mV,
-                               tau=20. * u.ms, tau_ref=5. * u.ms,
-                               V_initializer=bst.init.Normal(-55., 2., unit=u.mV))
-        self.E = bst.nn.AlignPostProj(
-            comm=bst.event.FixedProb(self.n_exc, self.num, prob=80 / self.num, weight=0.6 * u.mS),
-            syn=bst.nn.Expon.desc(self.num, tau=5. * u.ms),
-            out=bst.nn.COBA.desc(E=0. * u.mV),
-            post=self.N
-        )
-        self.I = bst.nn.AlignPostProj(
-            comm=bst.event.FixedProb(self.n_inh, self.num, prob=80 / self.num, weight=6.7 * u.mS),
-            syn=bst.nn.Expon.desc(self.num, tau=10. * u.ms),
-            out=bst.nn.COBA.desc(E=-80. * u.mV),
-            post=self.N
-        )
-
-    def init_state(self, *args, **kwargs):
-        self.rate = bst.ShortTermState(u.math.zeros(self.num))
-
-    def update(self, t, inp):
-        with bst.environ.context(t=t):
-            spk = self.N.get_spike() != 0.
-            self.E(spk[:self.n_exc])
-            self.I(spk[self.n_exc:])
-            self.N(inp)
-            self.rate.value += self.N.get_spike()
-
-
-@bst.compile.jit(static_argnums=0)
-def run(scale: float):
-    # network
-    net = EINet(scale)
-    bst.nn.init_all_states(net)
-
-    duration = 1e4 * u.ms
-    # simulation
-    with bst.environ.context(dt=0.1 * u.ms):
-        times = u.math.arange(0. * u.ms, duration, bst.environ.get_dt())
-        bst.compile.for_loop(lambda t: net.update(t, 20. * u.mA), times)
-
-    return net.num, net.rate.value.sum() / net.num / duration.to_decimal(u.second)
-
-
-for s in [1, 2, 4, 6, 8, 10, 20, 40, 60, 80, 100]:
-    jax.block_until_ready(run(s))
-
-    t0 = time.time()
-    n, rate = jax.block_until_ready(run(s))
-    t1 = time.time()
-    print(f'scale={s}, size={n}, time = {t1 - t0} s, firing rate = {rate} Hz')
-
-
-# A6000 NVIDIA GPU
-
-# scale=1, size=4000, time = 2.659956455230713 s, firing rate = 50.62445068359375 Hz
-# scale=2, size=8000, time = 2.7318649291992188 s, firing rate = 50.613040924072266 Hz
-# scale=4, size=16000, time = 2.807222604751587 s, firing rate = 50.60573959350586 Hz
-# scale=6, size=24000, time = 3.026782512664795 s, firing rate = 50.60918045043945 Hz
-# scale=8, size=32000, time = 3.1258811950683594 s, firing rate = 50.607574462890625 Hz
-# scale=10, size=40000, time = 3.172346353530884 s, firing rate = 50.60942840576172 Hz
-# scale=20, size=80000, time = 3.751189947128296 s, firing rate = 50.612369537353516 Hz
-# scale=40, size=160000, time = 5.0217814445495605 s, firing rate = 50.617958068847656 Hz
-# scale=60, size=240000, time = 7.002646207809448 s, firing rate = 50.61948776245117 Hz
-# scale=80, size=320000, time = 9.384576320648193 s, firing rate = 50.618499755859375 Hz
-# scale=100, size=400000, time = 11.69654369354248 s, firing rate = 50.61605453491211 Hz
-
-
-# AMD Ryzen 7 7840HS
-
-# scale=1, size=4000, time = 4.436027526855469 s, firing rate = 50.6119270324707 Hz
-# scale=2, size=8000, time = 8.349745273590088 s, firing rate = 50.612266540527344 Hz
-# scale=4, size=16000, time = 16.39163303375244 s, firing rate = 50.61349105834961 Hz
-# scale=6, size=24000, time = 15.725558042526245 s, firing rate = 50.6125602722168 Hz
-# scale=8, size=32000, time = 21.31995177268982 s, firing rate = 50.61244583129883 Hz
-# scale=10, size=40000, time = 27.811061143875122 s, firing rate = 50.61423873901367 Hz
-# scale=20, size=80000, time = 45.54235219955444 s, firing rate = 50.61320877075195 Hz
-# scale=40, size=160000, time = 82.22228026390076 s, firing rate = 50.61309814453125 Hz
-# scale=60, size=240000, time = 125.44037556648254 s, firing rate = 50.613094329833984 Hz
-# scale=80, size=320000, time = 171.20458459854126 s, firing rate = 50.613365173339844 Hz
-# scale=100, size=400000, time = 215.4547393321991 s, firing rate = 50.6129150390625 Hz
benchmark/CUBA_2005.py DELETED
@@ -1,149 +0,0 @@
-# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-
-#
-# Implementation of the paper:
-#
-# - Brette, R., Rudolph, M., Carnevale, T., Hines, M., Beeman, D., Bower, J. M., et al. (2007),
-#   Simulation of networks of spiking neurons: a review of tools and strategies., J. Comput. Neurosci., 23, 3, 349–98
-#
-# which is based on the balanced network proposed by:
-#
-# - Vogels, T. P. and Abbott, L. F. (2005), Signal propagation and logic gating in networks of integrate-and-fire neurons., J. Neurosci., 25, 46, 10786–95
-#
-
-import os
-import sys
-
-sys.path.append('../')
-os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.99'
-os.environ['JAX_TRACEBACK_FILTERING'] = 'off'
-
-
-import jax
-import time
-
-import brainunit as u
-
-import brainstate as bst
-
-
-
-class FixedProb(bst.nn.Module):
-    def __init__(self, n_pre, n_post, prob, weight):
-        super().__init__()
-        self.prob = prob
-        self.weight = weight
-        self.n_pre = n_pre
-        self.n_post = n_post
-
-        self.mask = bst.random.rand(n_pre, n_post) < prob
-
-    def update(self, x):
-        return (x @ self.mask) * self.weight
-
-
-class EINet(bst.nn.DynamicsGroup):
-    def __init__(self, scale=1.0):
-        super().__init__()
-        self.n_exc = int(3200 * scale)
-        self.n_inh = int(800 * scale)
-        self.num = self.n_exc + self.n_inh
-        self.N = bst.nn.LIFRef(
-            self.num, V_rest=-49. * u.mV, V_th=-50. * u.mV, V_reset=-60. * u.mV,
-            tau=20. * u.ms, tau_ref=5. * u.ms,
-            V_initializer=bst.init.Normal(-55., 2., unit=u.mV)
-        )
-        self.E = bst.nn.AlignPostProj(
-            comm=bst.event.FixedProb(self.n_exc, self.num, prob=80 / self.num, weight=1.62 * u.mS),
-            # comm=FixedProb(self.n_exc, self.num, prob=80 / self.num, weight=1.62 * u.mS),
-            syn=bst.nn.Expon.desc(self.num, tau=5. * u.ms),
-            out=bst.nn.CUBA.desc(scale=u.volt),
-            post=self.N
-        )
-        self.I = bst.nn.AlignPostProj(
-            comm=bst.event.FixedProb(self.n_inh, self.num, prob=80 / self.num, weight=-9.0 * u.mS),
-            # comm=FixedProb(self.n_inh, self.num, prob=80 / self.num, weight=-9.0 * u.mS),
-            syn=bst.nn.Expon.desc(self.num, tau=10. * u.ms),
-            out=bst.nn.CUBA.desc(scale=u.volt),
-            post=self.N
-        )
-
-    def init_state(self, *args, **kwargs):
-        self.rate = bst.ShortTermState(u.math.zeros(self.num))
-
-    def update(self, t, inp):
-        with bst.environ.context(t=t):
-            spk = self.N.get_spike()
-            self.E(spk[:self.n_exc])
-            self.I(spk[self.n_exc:])
-            self.N(inp)
-            self.rate.value += self.N.get_spike()
-
-
-@bst.compile.jit(static_argnums=0)
-def run(scale: float):
-    # network
-    net = EINet(scale)
-    bst.nn.init_all_states(net)
-
-    duration = 1e4 * u.ms
-    # simulation
-    with bst.environ.context(dt=0.1 * u.ms):
-        times = u.math.arange(0. * u.ms, duration, bst.environ.get_dt())
-        bst.compile.for_loop(lambda t: net.update(t, 20. * u.mA), times,
-                             # pbar=bst.compile.ProgressBar(100)
-                             )
-
-    return net.num, net.rate.value.sum() / net.num / duration.to_decimal(u.second)
-
-
-for s in [1, 2, 4, 6, 8, 10, 20, 40, 60, 80, 100]:
-    jax.block_until_ready(run(s))
-
-    t0 = time.time()
-    n, rate = jax.block_until_ready(run(s))
-    t1 = time.time()
-    print(f'scale={s}, size={n}, time = {t1 - t0} s, firing rate = {rate} Hz')
-
-
-# A6000 NVIDIA GPU
-
-# scale=1, size=4000, time = 2.6354849338531494 s, firing rate = 24.982027053833008 Hz
-# scale=2, size=8000, time = 2.6781561374664307 s, firing rate = 23.719463348388672 Hz
-# scale=4, size=16000, time = 2.7448785305023193 s, firing rate = 24.592931747436523 Hz
-# scale=6, size=24000, time = 2.8237478733062744 s, firing rate = 24.159996032714844 Hz
-# scale=8, size=32000, time = 2.9344418048858643 s, firing rate = 24.956790924072266 Hz
-# scale=10, size=40000, time = 3.042517900466919 s, firing rate = 23.644424438476562 Hz
-# scale=20, size=80000, time = 3.6727631092071533 s, firing rate = 24.226743698120117 Hz
-# scale=40, size=160000, time = 4.857396602630615 s, firing rate = 24.329742431640625 Hz
-# scale=60, size=240000, time = 6.812030792236328 s, firing rate = 24.370006561279297 Hz
-# scale=80, size=320000, time = 9.227966547012329 s, firing rate = 24.41067886352539 Hz
-# scale=100, size=400000, time = 11.405697584152222 s, firing rate = 24.32524871826172 Hz
-
-
-# AMD Ryzen 7 7840HS
-
-# scale=1, size=4000, time = 1.1661601066589355 s, firing rate = 22.438201904296875 Hz
-# scale=2, size=8000, time = 3.3255884647369385 s, firing rate = 23.868364334106445 Hz
-# scale=4, size=16000, time = 6.950139999389648 s, firing rate = 24.21693229675293 Hz
-# scale=6, size=24000, time = 10.011993169784546 s, firing rate = 24.240270614624023 Hz
-# scale=8, size=32000, time = 13.027734518051147 s, firing rate = 24.753198623657227 Hz
-# scale=10, size=40000, time = 16.449942350387573 s, firing rate = 24.7176570892334 Hz
-# scale=20, size=80000, time = 30.754598140716553 s, firing rate = 24.119956970214844 Hz
-# scale=40, size=160000, time = 63.6387836933136 s, firing rate = 24.72784996032715 Hz
-# scale=60, size=240000, time = 78.58532166481018 s, firing rate = 24.402742385864258 Hz
-# scale=80, size=320000, time = 102.4250214099884 s, firing rate = 24.59092140197754 Hz
-# scale=100, size=400000, time = 145.35173273086548 s, firing rate = 24.33751106262207 Hz