brainstate 0.1.0.post20250120__py2.py3-none-any.whl → 0.1.0.post20250127__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (29)
  1. brainstate/__init__.py +1 -2
  2. brainstate/augment/__init__.py +10 -20
  3. brainstate/compile/__init__.py +18 -37
  4. brainstate/compile/_make_jaxpr.py +9 -2
  5. brainstate/compile/_make_jaxpr_test.py +10 -6
  6. brainstate/compile/_progress_bar.py +49 -6
  7. brainstate/compile/_unvmap.py +3 -3
  8. brainstate/graph/__init__.py +12 -12
  9. brainstate/nn/_dyn_impl/_inputs.py +4 -2
  10. brainstate/nn/_elementwise/_dropout_test.py +1 -1
  11. {brainstate-0.1.0.post20250120.dist-info → brainstate-0.1.0.post20250127.dist-info}/METADATA +1 -1
  12. {brainstate-0.1.0.post20250120.dist-info → brainstate-0.1.0.post20250127.dist-info}/RECORD +15 -29
  13. brainstate/event/__init__.py +0 -27
  14. brainstate/event/_csr.py +0 -1149
  15. brainstate/event/_csr_benchmark.py +0 -14
  16. brainstate/event/_csr_mv.py +0 -303
  17. brainstate/event/_csr_test.py +0 -277
  18. brainstate/event/_fixedprob_mv.py +0 -730
  19. brainstate/event/_fixedprob_mv_benchmark.py +0 -128
  20. brainstate/event/_fixedprob_mv_test.py +0 -132
  21. brainstate/event/_linear_mv.py +0 -359
  22. brainstate/event/_linear_mv_benckmark.py +0 -82
  23. brainstate/event/_linear_mv_test.py +0 -117
  24. brainstate/event/_misc.py +0 -34
  25. brainstate/event/_xla_custom_op.py +0 -317
  26. brainstate/event/_xla_custom_op_test.py +0 -55
  27. {brainstate-0.1.0.post20250120.dist-info → brainstate-0.1.0.post20250127.dist-info}/LICENSE +0 -0
  28. {brainstate-0.1.0.post20250120.dist-info → brainstate-0.1.0.post20250127.dist-info}/WHEEL +0 -0
  29. {brainstate-0.1.0.post20250120.dist-info → brainstate-0.1.0.post20250127.dist-info}/top_level.txt +0 -0
brainstate/__init__.py CHANGED
@@ -22,7 +22,6 @@ __version__ = "0.1.0"
22
22
  from . import augment
23
23
  from . import compile
24
24
  from . import environ
25
- from . import event
26
25
  from . import functional
27
26
  from . import graph
28
27
  from . import init
@@ -39,7 +38,7 @@ from ._state import __all__ as _state_all
39
38
 
40
39
  __all__ = (
41
40
  [
42
- 'augment', 'compile', 'environ', 'event', 'functional',
41
+ 'augment', 'compile', 'environ', 'functional',
43
42
  'graph', 'init', 'mixin', 'nn', 'optim', 'random',
44
43
  'surrogate', 'typing', 'util',
45
44
  # deprecated
@@ -17,24 +17,14 @@
17
17
  This module includes transformations for augmenting the functionalities of JAX code.
18
18
  """
19
19
 
20
- from ._autograd import *
21
- from ._autograd import __all__ as _autograd_all
22
- from ._eval_shape import *
23
- from ._eval_shape import __all__ as _eval_shape_all
24
- from ._mapping import *
25
- from ._mapping import __all__ as _mapping_all
26
- from ._random import *
27
- from ._random import __all__ as _random_all
20
+ from ._autograd import GradientTransform, grad, vector_grad, hessian, jacobian, jacrev, jacfwd
21
+ from ._eval_shape import abstract_init
22
+ from ._mapping import vmap, pmap, map
23
+ from ._random import restore_rngs
28
24
 
29
- __all__ = (
30
- _eval_shape_all
31
- + _autograd_all
32
- + _mapping_all
33
- + _random_all
34
- )
35
- del (
36
- _eval_shape_all,
37
- _autograd_all,
38
- _mapping_all,
39
- _random_all
40
- )
25
+ __all__ = [
26
+ 'GradientTransform', 'grad', 'vector_grad', 'hessian', 'jacobian', 'jacrev', 'jacfwd',
27
+ 'abstract_init',
28
+ 'vmap', 'pmap', 'map',
29
+ 'restore_rngs',
30
+ ]
@@ -17,41 +17,22 @@
17
17
  This module contains the functions for the compilation of JAX code.
18
18
  """
19
19
 
20
- from ._ad_checkpoint import *
21
- from ._ad_checkpoint import __all__ as _ad_checkpoint_all
22
- from ._conditions import *
23
- from ._conditions import __all__ as _conditions_all
24
- from ._error_if import *
25
- from ._error_if import __all__ as _jit_error_all
26
- from ._jit import *
27
- from ._jit import __all__ as _jit_all
28
- from ._loop_collect_return import *
29
- from ._loop_collect_return import __all__ as _loops_collection
30
- from ._loop_no_collection import *
31
- from ._loop_no_collection import __all__ as _loops_no_collection
32
- from ._make_jaxpr import *
33
- from ._make_jaxpr import __all__ as _make_jaxpr_all
34
- from ._progress_bar import *
35
- from ._progress_bar import __all__ as _progress_bar_all
20
+ from ._ad_checkpoint import checkpoint, remat
21
+ from ._conditions import cond, switch, ifelse
22
+ from ._error_if import jit_error_if
23
+ from ._jit import jit
24
+ from ._loop_collect_return import scan, checkpointed_scan, for_loop, checkpointed_for_loop
25
+ from ._loop_no_collection import while_loop, bounded_while_loop
26
+ from ._make_jaxpr import StatefulFunction, make_jaxpr
27
+ from ._progress_bar import ProgressBar
36
28
 
37
- __all__ = (
38
- _jit_error_all
39
- + _conditions_all
40
- + _make_jaxpr_all
41
- + _jit_all
42
- + _progress_bar_all
43
- + _loops_collection
44
- + _loops_no_collection
45
- + _ad_checkpoint_all
46
- )
47
-
48
- del (
49
- _jit_error_all,
50
- _conditions_all,
51
- _loops_collection,
52
- _make_jaxpr_all,
53
- _jit_all,
54
- _progress_bar_all,
55
- _loops_no_collection,
56
- _ad_checkpoint_all
57
- )
29
+ __all__ = [
30
+ 'checkpoint', 'remat',
31
+ 'cond', 'switch', 'ifelse',
32
+ 'jit_error_if',
33
+ 'jit',
34
+ 'scan', 'checkpointed_scan', 'for_loop', 'checkpointed_for_loop',
35
+ 'while_loop', 'bounded_while_loop',
36
+ 'StatefulFunction', 'make_jaxpr',
37
+ 'ProgressBar',
38
+ ]
@@ -73,6 +73,13 @@ from brainstate._state import State, StateTraceStack
73
73
  from brainstate._utils import set_module_as
74
74
  from brainstate.typing import PyTree
75
75
 
76
+
77
+ if jax.__version_info__ < (0, 4, 38):
78
+ from jax.core import ClosedJaxpr
79
+ else:
80
+ from jax.extend.core import ClosedJaxpr
81
+
82
+
76
83
  AxisName = Hashable
77
84
 
78
85
  __all__ = [
@@ -660,7 +667,7 @@ def _make_jaxpr(
660
667
  axis_env: Sequence[tuple[AxisName, int]] | None = None,
661
668
  return_shape: bool = False,
662
669
  abstracted_axes: Any | None = None,
663
- ) -> Callable[..., (jax.core.ClosedJaxpr | tuple[jax.core.ClosedJaxpr, Any])]:
670
+ ) -> Callable[..., (ClosedJaxpr | tuple[ClosedJaxpr, Any])]:
664
671
  """Creates a function that produces its jaxpr given example args.
665
672
 
666
673
  Args:
@@ -723,7 +730,7 @@ def _make_jaxpr(
723
730
  def _abstractify(args, kwargs):
724
731
  flat_args, in_tree = jax.tree.flatten((args, kwargs))
725
732
  if abstracted_axes is None:
726
- return map(jax.api_util.shaped_abstractify, flat_args), in_tree, [True] * len(flat_args)
733
+ return map(shaped_abstractify, flat_args), in_tree, [True] * len(flat_args)
727
734
  else:
728
735
  axes_specs = _flat_axes_specs(abstracted_axes, *args, **kwargs)
729
736
  in_type = pe.infer_lambda_input_type(axes_specs, flat_args)
@@ -18,12 +18,16 @@ from __future__ import annotations
18
18
  import unittest
19
19
 
20
20
  import jax
21
- import jax.extend as je
22
21
  import jax.numpy as jnp
23
22
  import pytest
24
23
 
25
24
  import brainstate as bst
26
25
 
26
+ if jax.__version_info__ < (0, 4, 38):
27
+ from jax.core import jaxpr_as_fun
28
+ else:
29
+ from jax.extend.core import jaxpr_as_fun
30
+
27
31
 
28
32
  class TestMakeJaxpr(unittest.TestCase):
29
33
  def test_compar_jax_make_jaxpr(self):
@@ -85,7 +89,7 @@ class TestMakeJaxpr(unittest.TestCase):
85
89
  print(jaxpr)
86
90
  jaxpr, _ = bst.compile.make_jaxpr(f3)(jnp.zeros(1))
87
91
  print(jaxpr)
88
- self.assertTrue(jnp.allclose(je.core.jaxpr_as_fun(jaxpr)(jnp.zeros(1), st1.value)[0],
92
+ self.assertTrue(jnp.allclose(jaxpr_as_fun(jaxpr)(jnp.zeros(1), st1.value)[0],
89
93
  f3(jnp.zeros(1))))
90
94
 
91
95
  def test_compar_jax_make_jaxpr2(self):
@@ -103,10 +107,10 @@ class TestMakeJaxpr(unittest.TestCase):
103
107
  print()
104
108
  print(jaxpr)
105
109
  print(states)
106
- print(jax.core.jaxpr_as_fun(jaxpr)(jnp.zeros(1), st1.value))
110
+ print(jaxpr_as_fun(jaxpr)(jnp.zeros(1), st1.value))
107
111
  jaxpr = jax.make_jaxpr(ffa)(jnp.zeros(1))
108
112
  print(jaxpr)
109
- print(jax.core.jaxpr_as_fun(jaxpr)(jnp.zeros(1)))
113
+ print(jaxpr_as_fun(jaxpr)(jnp.zeros(1)))
110
114
 
111
115
  def test_compar_jax_make_jaxpr3(self):
112
116
  def fa(x):
@@ -116,10 +120,10 @@ class TestMakeJaxpr(unittest.TestCase):
116
120
  print()
117
121
  print(jaxpr)
118
122
  print(states)
119
- # print(jax.core.jaxpr_as_fun(jaxpr)(jnp.zeros(1)))
123
+ # print(jaxpr_as_fun(jaxpr)(jnp.zeros(1)))
120
124
  jaxpr = jax.make_jaxpr(fa)(jnp.zeros(1))
121
125
  print(jaxpr)
122
- # print(jax.core.jaxpr_as_fun(jaxpr)(jnp.zeros(1)))
126
+ # print(jaxpr_as_fun(jaxpr)(jnp.zeros(1)))
123
127
 
124
128
 
125
129
  def test_return_states():
@@ -35,6 +35,47 @@ Output = Any
35
35
  class ProgressBar(object):
36
36
  """
37
37
  A progress bar for tracking the progress of a jitted for-loop computation.
38
+
39
+ It can be used in :py:func:`for_loop`, :py:func:`checkpointed_for_loop`, :py:func:`scan`,
40
+ and :py:func:`checkpointed_scan` functions. Or any other jitted function that uses
41
+ a for-loop.
42
+
43
+ The message displayed in the progress bar can be customized by the following two methods:
44
+
45
+ 1. By passing a string to the `desc` argument. For example:
46
+
47
+ .. code-block:: python
48
+
49
+ ProgressBar(desc="Running 1000 iterations")
50
+
51
+ 2. By passing a tuple with a string and a callable function to the `desc` argument. The callable
52
+ function should take a dictionary as input and return a dictionary. The returned dictionary
53
+ will be used to format the string. For example:
54
+
55
+ .. code-block:: python
56
+
57
+ a = bst.State(1.)
58
+ def loop_fn(x):
59
+ a.value = x.value + 1.
60
+ return jnp.sum(x ** 2)
61
+
62
+ pbar = ProgressBar(desc=("Running {i} iterations, loss = {loss}",
63
+ lambda i_carray_y: {"i": i_carray_y["i"], "loss": i_carray_y["y"]}))
64
+
65
+ bst.compile.for_loop(loop_fn, xs, pbar=pbar)
66
+
67
+ In this example, ``"i"`` denotes the iteration number and ``"loss"`` is computed from the output,
68
+ the ``"carry"`` is the dynamic state in the loop, for example ``a.value`` in this case.
69
+
70
+
71
+ Args:
72
+ freq: The frequency at which to print the progress bar. If not specified, the progress
73
+ bar will be printed every 5% of the total iterations.
74
+ count: The number of times to print the progress bar. If not specified, the progress
75
+ bar will be printed every 5% of the total iterations.
76
+ desc: A description of the progress bar. If not specified, a default message will be
77
+ displayed.
78
+ kwargs: Additional keyword arguments to pass to the progress bar.
38
79
  """
39
80
  __module__ = "brainstate.compile"
40
81
 
@@ -42,7 +83,7 @@ class ProgressBar(object):
42
83
  self,
43
84
  freq: Optional[int] = None,
44
85
  count: Optional[int] = None,
45
- desc: Optional[Tuple[str, Callable[[Dict], Dict]]] = None,
86
+ desc: Optional[Tuple[str, Callable[[Dict], Dict]] | str] = None,
46
87
  **kwargs
47
88
  ):
48
89
  # print rate
@@ -62,9 +103,12 @@ class ProgressBar(object):
62
103
 
63
104
  # description
64
105
  if desc is not None:
65
- assert isinstance(desc, (tuple, list)), 'Description should be a tuple or list.'
66
- assert isinstance(desc[0], str), 'Description should be a string.'
67
- assert callable(desc[1]), 'Description should be a callable.'
106
+ if isinstance(desc, str):
107
+ pass
108
+ else:
109
+ assert isinstance(desc, (tuple, list)), 'Description should be a tuple or list.'
110
+ assert isinstance(desc[0], str), 'Description should be a string.'
111
+ assert callable(desc[1]), 'Description should be a callable.'
68
112
  self.desc = desc
69
113
 
70
114
  # check if tqdm is installed
@@ -136,8 +180,7 @@ class ProgressBarRunner(object):
136
180
  self.tqdm_bars[0].close()
137
181
 
138
182
  def __call__(self, iter_num, **kwargs):
139
- data = dict(i=iter_num, **kwargs)
140
- data = dict() if isinstance(self.message, str) else self.message[1](data)
183
+ data = dict() if isinstance(self.message, str) else self.message[1](dict(i=iter_num, **kwargs))
141
184
  assert isinstance(data, dict), 'Description function should return a dictionary.'
142
185
 
143
186
  _ = jax.lax.cond(
@@ -78,7 +78,7 @@ mlir.register_lowering(
78
78
 
79
79
  # unvmap_any
80
80
 
81
- unvmap_any_p = jax.core.Primitive("unvmap_any")
81
+ unvmap_any_p = Primitive("unvmap_any")
82
82
 
83
83
 
84
84
  def unvmap_any(x):
@@ -109,7 +109,7 @@ mlir.register_lowering(
109
109
 
110
110
  # unvmap_max
111
111
 
112
- unvmap_max_p = jax.core.Primitive("unvmap_max")
112
+ unvmap_max_p = Primitive("unvmap_max")
113
113
 
114
114
 
115
115
  def unvmap_max(x):
@@ -156,7 +156,7 @@ def _without_vmap_batch(x, batch_axes):
156
156
  return _without_vmap(x), batching.not_mapped
157
157
 
158
158
 
159
- _no_vmap_prim = jax.core.Primitive('no_vmap')
159
+ _no_vmap_prim = Primitive('no_vmap')
160
160
  _no_vmap_prim.def_impl(_without_vmap_imp)
161
161
  _no_vmap_prim.def_abstract_eval(_without_vmap_abs)
162
162
  batching.primitive_batchers[_no_vmap_prim] = _without_vmap_batch
@@ -14,16 +14,16 @@
14
14
  # ==============================================================================
15
15
 
16
16
 
17
- from ._graph_node import *
18
- from ._graph_node import __all__ as _graph_node__all__
19
- from ._graph_operation import *
20
- from ._graph_operation import __all__ as _graph_operation__all__
21
-
22
- __all__ = (
23
- _graph_node__all__ +
24
- _graph_operation__all__
25
- )
26
- del (
27
- _graph_node__all__,
28
- _graph_operation__all__
17
+ from ._graph_node import Node, Dict, List, Sequential
18
+ from ._graph_operation import (
19
+ pop_states, nodes, states, treefy_states, update_states, flatten, unflatten,
20
+ treefy_split, treefy_merge, iter_leaf, iter_node, clone, graphdef,
21
+ call, RefMap, GraphDef, NodeRef, NodeDef
29
22
  )
23
+
24
+ __all__ = [
25
+ 'Node', 'Dict', 'List', 'Sequential',
26
+ 'pop_states', 'nodes', 'states', 'treefy_states', 'update_states', 'flatten', 'unflatten',
27
+ 'treefy_split', 'treefy_merge', 'iter_leaf', 'iter_node', 'clone', 'graphdef',
28
+ 'call', 'RefMap', 'GraphDef', 'NodeRef', 'NodeDef',
29
+ ]
@@ -224,7 +224,8 @@ def poisson_input(
224
224
  weight = maybe_state(weight)
225
225
 
226
226
  assert isinstance(target, State), 'The target must be a State.'
227
- p = (freq * environ.get_dt()).to_decimal()
227
+ p = freq * environ.get_dt()
228
+ p = p.to_decimal() if isinstance(p, u.Quantity) else p
228
229
  a = num_input * p
229
230
  b = num_input * (1 - p)
230
231
  tar_val = target.value
@@ -291,7 +292,8 @@ def poisson_input(
291
292
 
292
293
  # update target variable
293
294
  target.value = jax.tree.map(
294
- lambda x: x * weight,
295
+ lambda tar, x: tar + x * weight,
296
+ target.value,
295
297
  inp,
296
298
  is_leaf=u.math.is_quantity
297
299
  )
@@ -28,7 +28,7 @@ class TestDropout(unittest.TestCase):
28
28
  dropout_layer = bst.nn.Dropout(0.5)
29
29
 
30
30
  # Input data
31
- input_data = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
31
+ input_data = np.arange(20)
32
32
 
33
33
  with bst.environ.context(fit=True):
34
34
  # Apply dropout
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: brainstate
3
- Version: 0.1.0.post20250120
3
+ Version: 0.1.0.post20250127
4
4
  Summary: A ``State``-based Transformation System for Program Compilation and Augmentation.
5
5
  Home-page: https://github.com/chaobrain/brainstate
6
6
  Author: BrainState Developers
@@ -1,4 +1,4 @@
1
- brainstate/__init__.py,sha256=A-QKdOvSalsCMxgk80Iz6_xMiUin6con6JaONHfciSY,1526
1
+ brainstate/__init__.py,sha256=AkZyyFkn4fB8g2aT6Rc2MO1xICPpUZuDtdze-eUQNc0,1496
2
2
  brainstate/_state.py,sha256=GZ46liHZSHbAHQEuELvOeoJ27P9xiZDz06G2AASjAjA,29142
3
3
  brainstate/_state_test.py,sha256=rJUFRSXEqrrl4qANRewY9mnDlzSbtHwBIGeZ0ku-8Dg,1650
4
4
  brainstate/_utils.py,sha256=uJ6WWKq3yb05ZdktCQGLWOXsOJveL1H9pR7eev70Jes,1693
@@ -9,7 +9,7 @@ brainstate/mixin_test.py,sha256=Oq_0fwC9vpXDN4t4dTBhWzLdFDNlcYsrcip14F1yECI,3079
9
9
  brainstate/surrogate.py,sha256=t4SzVwUVMAPtC-O1vFbuE9F4265wgAv7ud77ufIJsuk,48464
10
10
  brainstate/transform.py,sha256=cxbymTlJ6uHvJWEEYXzFUkAySs_TbUTHakt0NQgWJ3s,808
11
11
  brainstate/typing.py,sha256=Qh-LBzm6oG4rSXv4V5qB8SNYcoOR7bASoK_iQxnlafk,10467
12
- brainstate/augment/__init__.py,sha256=BtXIBel7GbttmfBX6grxOxl0IiOJxLEa7qCGAXumamE,1286
12
+ brainstate/augment/__init__.py,sha256=zGPq1eTB_56GRCNC9TiPLKTw07PA2O0OCi7bgjYIrY4,1193
13
13
  brainstate/augment/_autograd.py,sha256=o9ivoEY7BmtdM1XmzdMmeRXpj6Tvn5xNB8LSGp2HKC8,25238
14
14
  brainstate/augment/_autograd_test.py,sha256=S2eEgrwTzdSi3u2nKE3u37WSThosLwx1WCP9ptJAGKo,44060
15
15
  brainstate/augment/_eval_shape.py,sha256=ObCgsZ704kLduB1dbjJZh5nVQYEkLR5ebK74V5NV42k,3892
@@ -17,7 +17,7 @@ brainstate/augment/_eval_shape_test.py,sha256=LFOJx7CWltmRLXdGY175UebLwtEMz2CzJ_
17
17
  brainstate/augment/_mapping.py,sha256=nU6Y7fSnYXyQSILXU2QT-O73Fm3pnwOmgUoDaHqjve8,21544
18
18
  brainstate/augment/_mapping_test.py,sha256=_KFhE3CXItwpbZ1gJfrDu3yUtX0YbfPUuHJG_G_BXEs,8963
19
19
  brainstate/augment/_random.py,sha256=rkB4w4BkKsz9p8lTk31kVHvlVPJSvtGk8REn936KI_4,3071
20
- brainstate/compile/__init__.py,sha256=qZZIYoyEl51IFkFu-Hb-bP3PAEHo94HlTDf57P2ze08,1858
20
+ brainstate/compile/__init__.py,sha256=fQtG316MLkeeu1Ssp54Kghw1PwbGK5gNq9yRVJu0wjA,1474
21
21
  brainstate/compile/_ad_checkpoint.py,sha256=K6I4vnznDsqC9cUeCnez9UdV9r_toGA3zHezoHLA6mI,9377
22
22
  brainstate/compile/_ad_checkpoint_test.py,sha256=R1I76nG4zIqb6g3M_VxWts7rUC1OHJCjtQhPkcbXodk,1746
23
23
  brainstate/compile/_conditions.py,sha256=gApsHKGQrf1QBjoKXDVL7VsoeJ2zFtSc-hFz9nbYcF0,10113
@@ -30,32 +30,18 @@ brainstate/compile/_loop_collect_return.py,sha256=TrKBZhtQecTtuiVz_HOeyepde-znzj
30
30
  brainstate/compile/_loop_collect_return_test.py,sha256=bA-_11E8A_0jR5umEO3e409y7bb5QYDTgSL-SBaX7kQ,1802
31
31
  brainstate/compile/_loop_no_collection.py,sha256=qto2__Zt2PJntkjB9AXEgraGLvNUJS483BhCXjJyqv0,7495
32
32
  brainstate/compile/_loop_no_collection_test.py,sha256=oStB1CSG_iLp9sHdXd1hJNFvlxbzjck9Iy4sABoJDj4,1419
33
- brainstate/compile/_make_jaxpr.py,sha256=DQf_80w3p0wi2Gb9P6_tLMJ0Oadgyr_jWkVjus0MSjw,33205
34
- brainstate/compile/_make_jaxpr_test.py,sha256=3XaX8LUuG6UjolcD83qDVo5odf8FCDppdr9Q6V0NBs4,4303
35
- brainstate/compile/_progress_bar.py,sha256=0oVlZ4kW_ZMciJjOR_ebj3PNe_XkCMkoQpv-HUUdoF0,5554
36
- brainstate/compile/_unvmap.py,sha256=EY4rbqCzzPOiaRwpWTiyBwb5dVkYFnacHhBZUZObxPI,4255
33
+ brainstate/compile/_make_jaxpr.py,sha256=J4oWoPBwG-fdJvNhBEtNgmo3rXrIWCoajELhaIumgPU,33309
34
+ brainstate/compile/_make_jaxpr_test.py,sha256=3gwdiutn_PJyiweu3oPEXumxEVHKaE2xDGvkwZy2GEo,4367
35
+ brainstate/compile/_progress_bar.py,sha256=5pCMCEmbTO5XmKtzRUJGA178tuBznWKuh9Kw00wAL1I,7524
36
+ brainstate/compile/_unvmap.py,sha256=CJA6D9lUcBfvdLrpFVvC2AdTJqe9uY0Ht6PltQJyr4U,4228
37
37
  brainstate/compile/_util.py,sha256=aCvkTV--g4NsqcodTdBAISt4EwgezCbKzNUV58n-Q_Y,6304
38
- brainstate/event/__init__.py,sha256=gSEem-1oTHgy99Mjm3uumTXVd93tLVl0c4dUgRpoifk,895
39
- brainstate/event/_csr.py,sha256=PYKw8CGNgQ24MxQDoeBZTrPuC7Z-GetXQld9KiTbNYw,40063
40
- brainstate/event/_csr_benchmark.py,sha256=xrj2DSWzw0pUHAE1jRBeSRhMW7ogXvDHEdeaZGioNE4,702
41
- brainstate/event/_csr_mv.py,sha256=HStHvK3KyEMfLsIUslZjgbdU6OsD1yKGrzQOzBXG36M,10266
42
- brainstate/event/_csr_test.py,sha256=_iXwUFq90GU7npVOUnlI4NA27RJ8zyCZBxe7NDH803o,9533
43
- brainstate/event/_fixedprob_mv.py,sha256=nR3lhd87t1Vge435QHnFuDp-UBbWoW0Qk1kbsjRHQyc,25541
44
- brainstate/event/_fixedprob_mv_benchmark.py,sha256=_F_8fH5MNMJZHeSqnq9DYMI9OgYr6JIxBKjbsgeWRv4,4720
45
- brainstate/event/_fixedprob_mv_test.py,sha256=pVEarvGbqTjnAbxgMVRTAhkyYbvDnlyCJdeOdDD927w,4283
46
- brainstate/event/_linear_mv.py,sha256=O5qbY31GNV1qEDrZ5kvPbA8Ae-bY5JpUgGtqDFNAeV0,11794
47
- brainstate/event/_linear_mv_benckmark.py,sha256=hu0WqYMIa3jMoH7Fq9dgxcBjjXGFhghPx9vztyCo1KY,2411
48
- brainstate/event/_linear_mv_test.py,sha256=V9w41ZP2vu95CyCdCkm-j9Eftqs2kqmeBY809N1-syY,3736
49
- brainstate/event/_misc.py,sha256=8IpPooXjF2m0-tuo3pGHqThq2yLSNmYziy_zdurZ3NI,1040
50
- brainstate/event/_xla_custom_op.py,sha256=wF_nKgLUv1IGd8OY89MYqIvyZITl8UcrVysJWFugJxY,11093
51
- brainstate/event/_xla_custom_op_test.py,sha256=rnkGMleXzLfJj4y5QqwfBvCCLTAHe_uabwBDniY-URM,1745
52
38
  brainstate/functional/__init__.py,sha256=j6-3Er4fgqWpvntzYCZVB3e5hoz-Z3aqvapITCuDri0,1107
53
39
  brainstate/functional/_activations.py,sha256=S0Ok7sq5FTbmJWSejpOCHo1jpKX0gYOLy_TO2IUXM8s,21726
54
40
  brainstate/functional/_activations_test.py,sha256=T___RlSrIfXwlkw8dg5A9EZMTZGDzv3a2evUwq_nYFg,13034
55
41
  brainstate/functional/_normalization.py,sha256=i2EV7hSsqcNdcYRX2wAxjq8doHwyN9eNJTGTaPt03xE,2605
56
42
  brainstate/functional/_others.py,sha256=_u_Ys-LiLzDAP4zJggVwaVvirgoS3jvhXMREoS6JOkM,1737
57
43
  brainstate/functional/_spikes.py,sha256=QY-2ayJkgkGELcq-bftPEaf_hJptVf_SP3fY36QvlZc,2678
58
- brainstate/graph/__init__.py,sha256=fyvQMlAUY3QYTzvDzz5TDoWS2XQwZ6P3ic6BtysZyHM,1026
44
+ brainstate/graph/__init__.py,sha256=noo4TjBg6iEhjjwk0sAGUhR7Ge-z8Vnc2rLYUvnqttw,1295
59
45
  brainstate/graph/_graph_node.py,sha256=swAokZLKswSTaq2WEhyLIs38sy_67C6maHI6T3e1hvY,8339
60
46
  brainstate/graph/_graph_node_test.py,sha256=BFGfdzZFDHI0XK7hHotSVWKt3em1taGvn8FHF9NCXx8,2702
61
47
  brainstate/graph/_graph_operation.py,sha256=cIwGo3ICgtce2fmdn917r81evMFjJIKeW9doaQK4DD8,64111
@@ -79,7 +65,7 @@ brainstate/nn/_dyn_impl/_dynamics_neuron.py,sha256=cTbIn41EPYG0h3ICzKBXxpgB6wwA2
79
65
  brainstate/nn/_dyn_impl/_dynamics_neuron_test.py,sha256=Tfzrzu7udGrLJGnqItiLWe5WT0dgduvYOgzGCnaPJQg,6317
80
66
  brainstate/nn/_dyn_impl/_dynamics_synapse.py,sha256=MsbPyaiDyjungyzuK2b3exRGaMpZgmsmmNHNLjgxQKw,15269
81
67
  brainstate/nn/_dyn_impl/_dynamics_synapse_test.py,sha256=t5i-HV0ii9sUNzWTEv04o26QVtQ-mCdMJcFq2MD755A,4981
82
- brainstate/nn/_dyn_impl/_inputs.py,sha256=UNoGxKIKXwPnhelljDowqAWlV6ds7aBBkEbvdy2oDI4,11302
68
+ brainstate/nn/_dyn_impl/_inputs.py,sha256=QOUpAb2YJOE78uAvIS8Ep6MFcQHV-V6uRwmYvk5p9bk,11385
83
69
  brainstate/nn/_dyn_impl/_projection_alignpost.py,sha256=PNC1Tzx_SF2DHAHeJCufXzO_Q4qLoBpWABI45B3GRuc,876
84
70
  brainstate/nn/_dyn_impl/_rate_rnns.py,sha256=dz_yT_6hJVhKulcjIARbGtmMzZqISws96CtBc6o5GOo,14768
85
71
  brainstate/nn/_dyn_impl/_rate_rnns_test.py,sha256=gNgtr-a4ZiU1XF9wFG1HiJ9fLosfWchVR9Zn1x39xt4,2452
@@ -94,7 +80,7 @@ brainstate/nn/_dynamics/_synouts.py,sha256=9TGAc-nVa50th7KKn4oKLbro-4W4rwxYvp-eu
94
80
  brainstate/nn/_dynamics/_synouts_test.py,sha256=V_jDswRN4VvEXD-2yJO3VA1TALgX0HK6oPBQiUntOWc,2266
95
81
  brainstate/nn/_elementwise/__init__.py,sha256=PK8oq1K_EG2941AiUyLxCWoRdWvMO3yt8ZJbw3Lkhu8,935
96
82
  brainstate/nn/_elementwise/_dropout.py,sha256=0Ebo-2y1VswvBqZ7sCA0SEUm37y49EUsef8oiSFpYGk,17759
97
- brainstate/nn/_elementwise/_dropout_test.py,sha256=ZzNvjFf46NpKWGBIcT6O0lKOBGpxOStOAIGM4cE8LfE,4405
83
+ brainstate/nn/_elementwise/_dropout_test.py,sha256=k6aB5v8RYMoV5w8UV9UNSFhaQTV7woS6jx3SNESuCRs,4383
98
84
  brainstate/nn/_elementwise/_elementwise.py,sha256=om-KpwDTk5yFG5KBYXXHquRLV7s28_FJjk-omvyMyvQ,33342
99
85
  brainstate/nn/_elementwise/_elementwise_test.py,sha256=SZI9jB39sZ5SO1dpWGW-PhodthwN0GU9FY1nqf2fWcs,5341
100
86
  brainstate/nn/_interaction/__init__.py,sha256=TTY_SeNrdx4VnUSw6vdyl02OHdS9Qs15cWBp6kjsyNQ,1289
@@ -131,8 +117,8 @@ brainstate/util/_others.py,sha256=jsPZwP-v_5HRV-LB5F0NUsiqr04y8bmGIsu_JMyVcbQ,14
131
117
  brainstate/util/_pretty_repr.py,sha256=bDpU4gbkS4B8cXBkiN8kBQNmruxiJzDRF-eIqzyeYnM,5716
132
118
  brainstate/util/_scaling.py,sha256=pc_eM_SZVwkY65I4tJh1ODiHNCoEhsfFXl2zBK0PLAg,7562
133
119
  brainstate/util/_struct.py,sha256=KMMHcshOM20gYhSahNzWLxsTt-Rt3AeX3Uz26-rP9vI,17619
134
- brainstate-0.1.0.post20250120.dist-info/LICENSE,sha256=VZe9u1jgUL2eCY6ZPOYgdb8KCblCHt8ECdbtJid6e1s,11550
135
- brainstate-0.1.0.post20250120.dist-info/METADATA,sha256=vUyr4XjiyAW68waFKMray9EEFHTqjqRp5GlqAG8LsKY,3585
136
- brainstate-0.1.0.post20250120.dist-info/WHEEL,sha256=bb2Ot9scclHKMOLDEHY6B2sicWOgugjFKaJsT7vwMQo,110
137
- brainstate-0.1.0.post20250120.dist-info/top_level.txt,sha256=eQbGgKn0ptx7FDWuua0V0wr4K1VHi2iOUCYo3fUQBRA,11
138
- brainstate-0.1.0.post20250120.dist-info/RECORD,,
120
+ brainstate-0.1.0.post20250127.dist-info/LICENSE,sha256=VZe9u1jgUL2eCY6ZPOYgdb8KCblCHt8ECdbtJid6e1s,11550
121
+ brainstate-0.1.0.post20250127.dist-info/METADATA,sha256=-j8gDuJN37nbSkijFC9TtG5UbflMZZWK-sftY7fkeps,3585
122
+ brainstate-0.1.0.post20250127.dist-info/WHEEL,sha256=bb2Ot9scclHKMOLDEHY6B2sicWOgugjFKaJsT7vwMQo,110
123
+ brainstate-0.1.0.post20250127.dist-info/top_level.txt,sha256=eQbGgKn0ptx7FDWuua0V0wr4K1VHi2iOUCYo3fUQBRA,11
124
+ brainstate-0.1.0.post20250127.dist-info/RECORD,,
@@ -1,27 +0,0 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
-
17
- from ._csr import *
18
- from ._fixedprob_mv import *
19
- from ._linear_mv import *
20
- from ._xla_custom_op import *
21
-
22
- __all__ = [
23
- 'FixedProb',
24
- 'XLACustomOp',
25
- 'CSR',
26
- 'CSC',
27
- ]