brainstate 0.1.0.post20250120__py2.py3-none-any.whl → 0.1.0.post20250127__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +1 -2
- brainstate/augment/__init__.py +10 -20
- brainstate/compile/__init__.py +18 -37
- brainstate/compile/_make_jaxpr.py +9 -2
- brainstate/compile/_make_jaxpr_test.py +10 -6
- brainstate/compile/_progress_bar.py +49 -6
- brainstate/compile/_unvmap.py +3 -3
- brainstate/graph/__init__.py +12 -12
- brainstate/nn/_dyn_impl/_inputs.py +4 -2
- brainstate/nn/_elementwise/_dropout_test.py +1 -1
- {brainstate-0.1.0.post20250120.dist-info → brainstate-0.1.0.post20250127.dist-info}/METADATA +1 -1
- {brainstate-0.1.0.post20250120.dist-info → brainstate-0.1.0.post20250127.dist-info}/RECORD +15 -29
- brainstate/event/__init__.py +0 -27
- brainstate/event/_csr.py +0 -1149
- brainstate/event/_csr_benchmark.py +0 -14
- brainstate/event/_csr_mv.py +0 -303
- brainstate/event/_csr_test.py +0 -277
- brainstate/event/_fixedprob_mv.py +0 -730
- brainstate/event/_fixedprob_mv_benchmark.py +0 -128
- brainstate/event/_fixedprob_mv_test.py +0 -132
- brainstate/event/_linear_mv.py +0 -359
- brainstate/event/_linear_mv_benckmark.py +0 -82
- brainstate/event/_linear_mv_test.py +0 -117
- brainstate/event/_misc.py +0 -34
- brainstate/event/_xla_custom_op.py +0 -317
- brainstate/event/_xla_custom_op_test.py +0 -55
- {brainstate-0.1.0.post20250120.dist-info → brainstate-0.1.0.post20250127.dist-info}/LICENSE +0 -0
- {brainstate-0.1.0.post20250120.dist-info → brainstate-0.1.0.post20250127.dist-info}/WHEEL +0 -0
- {brainstate-0.1.0.post20250120.dist-info → brainstate-0.1.0.post20250127.dist-info}/top_level.txt +0 -0
brainstate/__init__.py
CHANGED
@@ -22,7 +22,6 @@ __version__ = "0.1.0"
|
|
22
22
|
from . import augment
|
23
23
|
from . import compile
|
24
24
|
from . import environ
|
25
|
-
from . import event
|
26
25
|
from . import functional
|
27
26
|
from . import graph
|
28
27
|
from . import init
|
@@ -39,7 +38,7 @@ from ._state import __all__ as _state_all
|
|
39
38
|
|
40
39
|
__all__ = (
|
41
40
|
[
|
42
|
-
'augment', 'compile', 'environ', '
|
41
|
+
'augment', 'compile', 'environ', 'functional',
|
43
42
|
'graph', 'init', 'mixin', 'nn', 'optim', 'random',
|
44
43
|
'surrogate', 'typing', 'util',
|
45
44
|
# deprecated
|
brainstate/augment/__init__.py
CHANGED
@@ -17,24 +17,14 @@
|
|
17
17
|
This module includes transformations for augmenting the functionalities of JAX code.
|
18
18
|
"""
|
19
19
|
|
20
|
-
from ._autograd import
|
21
|
-
from .
|
22
|
-
from .
|
23
|
-
from .
|
24
|
-
from ._mapping import *
|
25
|
-
from ._mapping import __all__ as _mapping_all
|
26
|
-
from ._random import *
|
27
|
-
from ._random import __all__ as _random_all
|
20
|
+
from ._autograd import GradientTransform, grad, vector_grad, hessian, jacobian, jacrev, jacfwd
|
21
|
+
from ._eval_shape import abstract_init
|
22
|
+
from ._mapping import vmap, pmap, map
|
23
|
+
from ._random import restore_rngs
|
28
24
|
|
29
|
-
__all__ =
|
30
|
-
|
31
|
-
|
32
|
-
|
33
|
-
|
34
|
-
|
35
|
-
del (
|
36
|
-
_eval_shape_all,
|
37
|
-
_autograd_all,
|
38
|
-
_mapping_all,
|
39
|
-
_random_all
|
40
|
-
)
|
25
|
+
__all__ = [
|
26
|
+
'GradientTransform', 'grad', 'vector_grad', 'hessian', 'jacobian', 'jacrev', 'jacfwd',
|
27
|
+
'abstract_init',
|
28
|
+
'vmap', 'pmap', 'map',
|
29
|
+
'restore_rngs',
|
30
|
+
]
|
brainstate/compile/__init__.py
CHANGED
@@ -17,41 +17,22 @@
|
|
17
17
|
This module contains the functions for the compilation of JAX code.
|
18
18
|
"""
|
19
19
|
|
20
|
-
from ._ad_checkpoint import
|
21
|
-
from .
|
22
|
-
from .
|
23
|
-
from .
|
24
|
-
from .
|
25
|
-
from .
|
26
|
-
from .
|
27
|
-
from .
|
28
|
-
from ._loop_collect_return import *
|
29
|
-
from ._loop_collect_return import __all__ as _loops_collection
|
30
|
-
from ._loop_no_collection import *
|
31
|
-
from ._loop_no_collection import __all__ as _loops_no_collection
|
32
|
-
from ._make_jaxpr import *
|
33
|
-
from ._make_jaxpr import __all__ as _make_jaxpr_all
|
34
|
-
from ._progress_bar import *
|
35
|
-
from ._progress_bar import __all__ as _progress_bar_all
|
20
|
+
from ._ad_checkpoint import checkpoint, remat
|
21
|
+
from ._conditions import cond, switch, ifelse
|
22
|
+
from ._error_if import jit_error_if
|
23
|
+
from ._jit import jit
|
24
|
+
from ._loop_collect_return import scan, checkpointed_scan, for_loop, checkpointed_for_loop
|
25
|
+
from ._loop_no_collection import while_loop, bounded_while_loop
|
26
|
+
from ._make_jaxpr import StatefulFunction, make_jaxpr
|
27
|
+
from ._progress_bar import ProgressBar
|
36
28
|
|
37
|
-
__all__ =
|
38
|
-
|
39
|
-
|
40
|
-
|
41
|
-
|
42
|
-
|
43
|
-
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
del (
|
49
|
-
_jit_error_all,
|
50
|
-
_conditions_all,
|
51
|
-
_loops_collection,
|
52
|
-
_make_jaxpr_all,
|
53
|
-
_jit_all,
|
54
|
-
_progress_bar_all,
|
55
|
-
_loops_no_collection,
|
56
|
-
_ad_checkpoint_all
|
57
|
-
)
|
29
|
+
__all__ = [
|
30
|
+
'checkpoint', 'remat',
|
31
|
+
'cond', 'switch', 'ifelse',
|
32
|
+
'jit_error_if',
|
33
|
+
'jit',
|
34
|
+
'scan', 'checkpointed_scan', 'for_loop', 'checkpointed_for_loop',
|
35
|
+
'while_loop', 'bounded_while_loop',
|
36
|
+
'StatefulFunction', 'make_jaxpr',
|
37
|
+
'ProgressBar',
|
38
|
+
]
|
@@ -73,6 +73,13 @@ from brainstate._state import State, StateTraceStack
|
|
73
73
|
from brainstate._utils import set_module_as
|
74
74
|
from brainstate.typing import PyTree
|
75
75
|
|
76
|
+
|
77
|
+
if jax.__version_info__ < (0, 4, 38):
|
78
|
+
from jax.core import ClosedJaxpr
|
79
|
+
else:
|
80
|
+
from jax.extend.core import ClosedJaxpr
|
81
|
+
|
82
|
+
|
76
83
|
AxisName = Hashable
|
77
84
|
|
78
85
|
__all__ = [
|
@@ -660,7 +667,7 @@ def _make_jaxpr(
|
|
660
667
|
axis_env: Sequence[tuple[AxisName, int]] | None = None,
|
661
668
|
return_shape: bool = False,
|
662
669
|
abstracted_axes: Any | None = None,
|
663
|
-
) -> Callable[..., (
|
670
|
+
) -> Callable[..., (ClosedJaxpr | tuple[ClosedJaxpr, Any])]:
|
664
671
|
"""Creates a function that produces its jaxpr given example args.
|
665
672
|
|
666
673
|
Args:
|
@@ -723,7 +730,7 @@ def _make_jaxpr(
|
|
723
730
|
def _abstractify(args, kwargs):
|
724
731
|
flat_args, in_tree = jax.tree.flatten((args, kwargs))
|
725
732
|
if abstracted_axes is None:
|
726
|
-
return map(
|
733
|
+
return map(shaped_abstractify, flat_args), in_tree, [True] * len(flat_args)
|
727
734
|
else:
|
728
735
|
axes_specs = _flat_axes_specs(abstracted_axes, *args, **kwargs)
|
729
736
|
in_type = pe.infer_lambda_input_type(axes_specs, flat_args)
|
@@ -18,12 +18,16 @@ from __future__ import annotations
|
|
18
18
|
import unittest
|
19
19
|
|
20
20
|
import jax
|
21
|
-
import jax.extend as je
|
22
21
|
import jax.numpy as jnp
|
23
22
|
import pytest
|
24
23
|
|
25
24
|
import brainstate as bst
|
26
25
|
|
26
|
+
if jax.__version_info__ < (0, 4, 38):
|
27
|
+
from jax.core import jaxpr_as_fun
|
28
|
+
else:
|
29
|
+
from jax.extend.core import jaxpr_as_fun
|
30
|
+
|
27
31
|
|
28
32
|
class TestMakeJaxpr(unittest.TestCase):
|
29
33
|
def test_compar_jax_make_jaxpr(self):
|
@@ -85,7 +89,7 @@ class TestMakeJaxpr(unittest.TestCase):
|
|
85
89
|
print(jaxpr)
|
86
90
|
jaxpr, _ = bst.compile.make_jaxpr(f3)(jnp.zeros(1))
|
87
91
|
print(jaxpr)
|
88
|
-
self.assertTrue(jnp.allclose(
|
92
|
+
self.assertTrue(jnp.allclose(jaxpr_as_fun(jaxpr)(jnp.zeros(1), st1.value)[0],
|
89
93
|
f3(jnp.zeros(1))))
|
90
94
|
|
91
95
|
def test_compar_jax_make_jaxpr2(self):
|
@@ -103,10 +107,10 @@ class TestMakeJaxpr(unittest.TestCase):
|
|
103
107
|
print()
|
104
108
|
print(jaxpr)
|
105
109
|
print(states)
|
106
|
-
print(
|
110
|
+
print(jaxpr_as_fun(jaxpr)(jnp.zeros(1), st1.value))
|
107
111
|
jaxpr = jax.make_jaxpr(ffa)(jnp.zeros(1))
|
108
112
|
print(jaxpr)
|
109
|
-
print(
|
113
|
+
print(jaxpr_as_fun(jaxpr)(jnp.zeros(1)))
|
110
114
|
|
111
115
|
def test_compar_jax_make_jaxpr3(self):
|
112
116
|
def fa(x):
|
@@ -116,10 +120,10 @@ class TestMakeJaxpr(unittest.TestCase):
|
|
116
120
|
print()
|
117
121
|
print(jaxpr)
|
118
122
|
print(states)
|
119
|
-
# print(
|
123
|
+
# print(jaxpr_as_fun(jaxpr)(jnp.zeros(1)))
|
120
124
|
jaxpr = jax.make_jaxpr(fa)(jnp.zeros(1))
|
121
125
|
print(jaxpr)
|
122
|
-
# print(
|
126
|
+
# print(jaxpr_as_fun(jaxpr)(jnp.zeros(1)))
|
123
127
|
|
124
128
|
|
125
129
|
def test_return_states():
|
@@ -35,6 +35,47 @@ Output = Any
|
|
35
35
|
class ProgressBar(object):
|
36
36
|
"""
|
37
37
|
A progress bar for tracking the progress of a jitted for-loop computation.
|
38
|
+
|
39
|
+
It can be used in :py:func:`for_loop`, :py:func:`checkpointed_for_loop`, :py:func:`scan`,
|
40
|
+
and :py:func:`checkpointed_scan` functions. Or any other jitted function that uses
|
41
|
+
a for-loop.
|
42
|
+
|
43
|
+
The message displayed in the progress bar can be customized by the following two methods:
|
44
|
+
|
45
|
+
1. By passing a string to the `desc` argument. For example:
|
46
|
+
|
47
|
+
.. code-block:: python
|
48
|
+
|
49
|
+
ProgressBar(desc="Running 1000 iterations")
|
50
|
+
|
51
|
+
2. By passing a tuple with a string and a callable function to the `desc` argument. The callable
|
52
|
+
function should take a dictionary as input and return a dictionary. The returned dictionary
|
53
|
+
will be used to format the string. For example:
|
54
|
+
|
55
|
+
.. code-block:: python
|
56
|
+
|
57
|
+
a = bst.State(1.)
|
58
|
+
def loop_fn(x):
|
59
|
+
a.value = x.value + 1.
|
60
|
+
return jnp.sum(x ** 2)
|
61
|
+
|
62
|
+
pbar = ProgressBar(desc=("Running {i} iterations, loss = {loss}",
|
63
|
+
lambda i_carray_y: {"i": i_carray_y["i"], "loss": i_carray_y["y"]}))
|
64
|
+
|
65
|
+
bst.compile.for_loop(loop_fn, xs, pbar=pbar)
|
66
|
+
|
67
|
+
In this example, ``"i"`` denotes the iteration number and ``"loss"`` is computed from the output,
|
68
|
+
the ``"carry"`` is the dynamic state in the loop, for example ``a.value`` in this case.
|
69
|
+
|
70
|
+
|
71
|
+
Args:
|
72
|
+
freq: The frequency at which to print the progress bar. If not specified, the progress
|
73
|
+
bar will be printed every 5% of the total iterations.
|
74
|
+
count: The number of times to print the progress bar. If not specified, the progress
|
75
|
+
bar will be printed every 5% of the total iterations.
|
76
|
+
desc: A description of the progress bar. If not specified, a default message will be
|
77
|
+
displayed.
|
78
|
+
kwargs: Additional keyword arguments to pass to the progress bar.
|
38
79
|
"""
|
39
80
|
__module__ = "brainstate.compile"
|
40
81
|
|
@@ -42,7 +83,7 @@ class ProgressBar(object):
|
|
42
83
|
self,
|
43
84
|
freq: Optional[int] = None,
|
44
85
|
count: Optional[int] = None,
|
45
|
-
desc: Optional[Tuple[str, Callable[[Dict], Dict]]] = None,
|
86
|
+
desc: Optional[Tuple[str, Callable[[Dict], Dict]] | str] = None,
|
46
87
|
**kwargs
|
47
88
|
):
|
48
89
|
# print rate
|
@@ -62,9 +103,12 @@ class ProgressBar(object):
|
|
62
103
|
|
63
104
|
# description
|
64
105
|
if desc is not None:
|
65
|
-
|
66
|
-
|
67
|
-
|
106
|
+
if isinstance(desc, str):
|
107
|
+
pass
|
108
|
+
else:
|
109
|
+
assert isinstance(desc, (tuple, list)), 'Description should be a tuple or list.'
|
110
|
+
assert isinstance(desc[0], str), 'Description should be a string.'
|
111
|
+
assert callable(desc[1]), 'Description should be a callable.'
|
68
112
|
self.desc = desc
|
69
113
|
|
70
114
|
# check if tqdm is installed
|
@@ -136,8 +180,7 @@ class ProgressBarRunner(object):
|
|
136
180
|
self.tqdm_bars[0].close()
|
137
181
|
|
138
182
|
def __call__(self, iter_num, **kwargs):
|
139
|
-
data = dict(i=iter_num, **kwargs)
|
140
|
-
data = dict() if isinstance(self.message, str) else self.message[1](data)
|
183
|
+
data = dict() if isinstance(self.message, str) else self.message[1](dict(i=iter_num, **kwargs))
|
141
184
|
assert isinstance(data, dict), 'Description function should return a dictionary.'
|
142
185
|
|
143
186
|
_ = jax.lax.cond(
|
brainstate/compile/_unvmap.py
CHANGED
@@ -78,7 +78,7 @@ mlir.register_lowering(
|
|
78
78
|
|
79
79
|
# unvmap_any
|
80
80
|
|
81
|
-
unvmap_any_p =
|
81
|
+
unvmap_any_p = Primitive("unvmap_any")
|
82
82
|
|
83
83
|
|
84
84
|
def unvmap_any(x):
|
@@ -109,7 +109,7 @@ mlir.register_lowering(
|
|
109
109
|
|
110
110
|
# unvmap_max
|
111
111
|
|
112
|
-
unvmap_max_p =
|
112
|
+
unvmap_max_p = Primitive("unvmap_max")
|
113
113
|
|
114
114
|
|
115
115
|
def unvmap_max(x):
|
@@ -156,7 +156,7 @@ def _without_vmap_batch(x, batch_axes):
|
|
156
156
|
return _without_vmap(x), batching.not_mapped
|
157
157
|
|
158
158
|
|
159
|
-
_no_vmap_prim =
|
159
|
+
_no_vmap_prim = Primitive('no_vmap')
|
160
160
|
_no_vmap_prim.def_impl(_without_vmap_imp)
|
161
161
|
_no_vmap_prim.def_abstract_eval(_without_vmap_abs)
|
162
162
|
batching.primitive_batchers[_no_vmap_prim] = _without_vmap_batch
|
brainstate/graph/__init__.py
CHANGED
@@ -14,16 +14,16 @@
|
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
16
|
|
17
|
-
from ._graph_node import
|
18
|
-
from .
|
19
|
-
|
20
|
-
|
21
|
-
|
22
|
-
__all__ = (
|
23
|
-
_graph_node__all__ +
|
24
|
-
_graph_operation__all__
|
25
|
-
)
|
26
|
-
del (
|
27
|
-
_graph_node__all__,
|
28
|
-
_graph_operation__all__
|
17
|
+
from ._graph_node import Node, Dict, List, Sequential
|
18
|
+
from ._graph_operation import (
|
19
|
+
pop_states, nodes, states, treefy_states, update_states, flatten, unflatten,
|
20
|
+
treefy_split, treefy_merge, iter_leaf, iter_node, clone, graphdef,
|
21
|
+
call, RefMap, GraphDef, NodeRef, NodeDef
|
29
22
|
)
|
23
|
+
|
24
|
+
__all__ = [
|
25
|
+
'Node', 'Dict', 'List', 'Sequential',
|
26
|
+
'pop_states', 'nodes', 'states', 'treefy_states', 'update_states', 'flatten', 'unflatten',
|
27
|
+
'treefy_split', 'treefy_merge', 'iter_leaf', 'iter_node', 'clone', 'graphdef',
|
28
|
+
'call', 'RefMap', 'GraphDef', 'NodeRef', 'NodeDef',
|
29
|
+
]
|
@@ -224,7 +224,8 @@ def poisson_input(
|
|
224
224
|
weight = maybe_state(weight)
|
225
225
|
|
226
226
|
assert isinstance(target, State), 'The target must be a State.'
|
227
|
-
p =
|
227
|
+
p = freq * environ.get_dt()
|
228
|
+
p = p.to_decimal() if isinstance(p, u.Quantity) else p
|
228
229
|
a = num_input * p
|
229
230
|
b = num_input * (1 - p)
|
230
231
|
tar_val = target.value
|
@@ -291,7 +292,8 @@ def poisson_input(
|
|
291
292
|
|
292
293
|
# update target variable
|
293
294
|
target.value = jax.tree.map(
|
294
|
-
lambda x: x * weight,
|
295
|
+
lambda tar, x: tar + x * weight,
|
296
|
+
target.value,
|
295
297
|
inp,
|
296
298
|
is_leaf=u.math.is_quantity
|
297
299
|
)
|
{brainstate-0.1.0.post20250120.dist-info → brainstate-0.1.0.post20250127.dist-info}/METADATA
RENAMED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: brainstate
|
3
|
-
Version: 0.1.0.
|
3
|
+
Version: 0.1.0.post20250127
|
4
4
|
Summary: A ``State``-based Transformation System for Program Compilation and Augmentation.
|
5
5
|
Home-page: https://github.com/chaobrain/brainstate
|
6
6
|
Author: BrainState Developers
|
@@ -1,4 +1,4 @@
|
|
1
|
-
brainstate/__init__.py,sha256=
|
1
|
+
brainstate/__init__.py,sha256=AkZyyFkn4fB8g2aT6Rc2MO1xICPpUZuDtdze-eUQNc0,1496
|
2
2
|
brainstate/_state.py,sha256=GZ46liHZSHbAHQEuELvOeoJ27P9xiZDz06G2AASjAjA,29142
|
3
3
|
brainstate/_state_test.py,sha256=rJUFRSXEqrrl4qANRewY9mnDlzSbtHwBIGeZ0ku-8Dg,1650
|
4
4
|
brainstate/_utils.py,sha256=uJ6WWKq3yb05ZdktCQGLWOXsOJveL1H9pR7eev70Jes,1693
|
@@ -9,7 +9,7 @@ brainstate/mixin_test.py,sha256=Oq_0fwC9vpXDN4t4dTBhWzLdFDNlcYsrcip14F1yECI,3079
|
|
9
9
|
brainstate/surrogate.py,sha256=t4SzVwUVMAPtC-O1vFbuE9F4265wgAv7ud77ufIJsuk,48464
|
10
10
|
brainstate/transform.py,sha256=cxbymTlJ6uHvJWEEYXzFUkAySs_TbUTHakt0NQgWJ3s,808
|
11
11
|
brainstate/typing.py,sha256=Qh-LBzm6oG4rSXv4V5qB8SNYcoOR7bASoK_iQxnlafk,10467
|
12
|
-
brainstate/augment/__init__.py,sha256=
|
12
|
+
brainstate/augment/__init__.py,sha256=zGPq1eTB_56GRCNC9TiPLKTw07PA2O0OCi7bgjYIrY4,1193
|
13
13
|
brainstate/augment/_autograd.py,sha256=o9ivoEY7BmtdM1XmzdMmeRXpj6Tvn5xNB8LSGp2HKC8,25238
|
14
14
|
brainstate/augment/_autograd_test.py,sha256=S2eEgrwTzdSi3u2nKE3u37WSThosLwx1WCP9ptJAGKo,44060
|
15
15
|
brainstate/augment/_eval_shape.py,sha256=ObCgsZ704kLduB1dbjJZh5nVQYEkLR5ebK74V5NV42k,3892
|
@@ -17,7 +17,7 @@ brainstate/augment/_eval_shape_test.py,sha256=LFOJx7CWltmRLXdGY175UebLwtEMz2CzJ_
|
|
17
17
|
brainstate/augment/_mapping.py,sha256=nU6Y7fSnYXyQSILXU2QT-O73Fm3pnwOmgUoDaHqjve8,21544
|
18
18
|
brainstate/augment/_mapping_test.py,sha256=_KFhE3CXItwpbZ1gJfrDu3yUtX0YbfPUuHJG_G_BXEs,8963
|
19
19
|
brainstate/augment/_random.py,sha256=rkB4w4BkKsz9p8lTk31kVHvlVPJSvtGk8REn936KI_4,3071
|
20
|
-
brainstate/compile/__init__.py,sha256=
|
20
|
+
brainstate/compile/__init__.py,sha256=fQtG316MLkeeu1Ssp54Kghw1PwbGK5gNq9yRVJu0wjA,1474
|
21
21
|
brainstate/compile/_ad_checkpoint.py,sha256=K6I4vnznDsqC9cUeCnez9UdV9r_toGA3zHezoHLA6mI,9377
|
22
22
|
brainstate/compile/_ad_checkpoint_test.py,sha256=R1I76nG4zIqb6g3M_VxWts7rUC1OHJCjtQhPkcbXodk,1746
|
23
23
|
brainstate/compile/_conditions.py,sha256=gApsHKGQrf1QBjoKXDVL7VsoeJ2zFtSc-hFz9nbYcF0,10113
|
@@ -30,32 +30,18 @@ brainstate/compile/_loop_collect_return.py,sha256=TrKBZhtQecTtuiVz_HOeyepde-znzj
|
|
30
30
|
brainstate/compile/_loop_collect_return_test.py,sha256=bA-_11E8A_0jR5umEO3e409y7bb5QYDTgSL-SBaX7kQ,1802
|
31
31
|
brainstate/compile/_loop_no_collection.py,sha256=qto2__Zt2PJntkjB9AXEgraGLvNUJS483BhCXjJyqv0,7495
|
32
32
|
brainstate/compile/_loop_no_collection_test.py,sha256=oStB1CSG_iLp9sHdXd1hJNFvlxbzjck9Iy4sABoJDj4,1419
|
33
|
-
brainstate/compile/_make_jaxpr.py,sha256=
|
34
|
-
brainstate/compile/_make_jaxpr_test.py,sha256=
|
35
|
-
brainstate/compile/_progress_bar.py,sha256=
|
36
|
-
brainstate/compile/_unvmap.py,sha256=
|
33
|
+
brainstate/compile/_make_jaxpr.py,sha256=J4oWoPBwG-fdJvNhBEtNgmo3rXrIWCoajELhaIumgPU,33309
|
34
|
+
brainstate/compile/_make_jaxpr_test.py,sha256=3gwdiutn_PJyiweu3oPEXumxEVHKaE2xDGvkwZy2GEo,4367
|
35
|
+
brainstate/compile/_progress_bar.py,sha256=5pCMCEmbTO5XmKtzRUJGA178tuBznWKuh9Kw00wAL1I,7524
|
36
|
+
brainstate/compile/_unvmap.py,sha256=CJA6D9lUcBfvdLrpFVvC2AdTJqe9uY0Ht6PltQJyr4U,4228
|
37
37
|
brainstate/compile/_util.py,sha256=aCvkTV--g4NsqcodTdBAISt4EwgezCbKzNUV58n-Q_Y,6304
|
38
|
-
brainstate/event/__init__.py,sha256=gSEem-1oTHgy99Mjm3uumTXVd93tLVl0c4dUgRpoifk,895
|
39
|
-
brainstate/event/_csr.py,sha256=PYKw8CGNgQ24MxQDoeBZTrPuC7Z-GetXQld9KiTbNYw,40063
|
40
|
-
brainstate/event/_csr_benchmark.py,sha256=xrj2DSWzw0pUHAE1jRBeSRhMW7ogXvDHEdeaZGioNE4,702
|
41
|
-
brainstate/event/_csr_mv.py,sha256=HStHvK3KyEMfLsIUslZjgbdU6OsD1yKGrzQOzBXG36M,10266
|
42
|
-
brainstate/event/_csr_test.py,sha256=_iXwUFq90GU7npVOUnlI4NA27RJ8zyCZBxe7NDH803o,9533
|
43
|
-
brainstate/event/_fixedprob_mv.py,sha256=nR3lhd87t1Vge435QHnFuDp-UBbWoW0Qk1kbsjRHQyc,25541
|
44
|
-
brainstate/event/_fixedprob_mv_benchmark.py,sha256=_F_8fH5MNMJZHeSqnq9DYMI9OgYr6JIxBKjbsgeWRv4,4720
|
45
|
-
brainstate/event/_fixedprob_mv_test.py,sha256=pVEarvGbqTjnAbxgMVRTAhkyYbvDnlyCJdeOdDD927w,4283
|
46
|
-
brainstate/event/_linear_mv.py,sha256=O5qbY31GNV1qEDrZ5kvPbA8Ae-bY5JpUgGtqDFNAeV0,11794
|
47
|
-
brainstate/event/_linear_mv_benckmark.py,sha256=hu0WqYMIa3jMoH7Fq9dgxcBjjXGFhghPx9vztyCo1KY,2411
|
48
|
-
brainstate/event/_linear_mv_test.py,sha256=V9w41ZP2vu95CyCdCkm-j9Eftqs2kqmeBY809N1-syY,3736
|
49
|
-
brainstate/event/_misc.py,sha256=8IpPooXjF2m0-tuo3pGHqThq2yLSNmYziy_zdurZ3NI,1040
|
50
|
-
brainstate/event/_xla_custom_op.py,sha256=wF_nKgLUv1IGd8OY89MYqIvyZITl8UcrVysJWFugJxY,11093
|
51
|
-
brainstate/event/_xla_custom_op_test.py,sha256=rnkGMleXzLfJj4y5QqwfBvCCLTAHe_uabwBDniY-URM,1745
|
52
38
|
brainstate/functional/__init__.py,sha256=j6-3Er4fgqWpvntzYCZVB3e5hoz-Z3aqvapITCuDri0,1107
|
53
39
|
brainstate/functional/_activations.py,sha256=S0Ok7sq5FTbmJWSejpOCHo1jpKX0gYOLy_TO2IUXM8s,21726
|
54
40
|
brainstate/functional/_activations_test.py,sha256=T___RlSrIfXwlkw8dg5A9EZMTZGDzv3a2evUwq_nYFg,13034
|
55
41
|
brainstate/functional/_normalization.py,sha256=i2EV7hSsqcNdcYRX2wAxjq8doHwyN9eNJTGTaPt03xE,2605
|
56
42
|
brainstate/functional/_others.py,sha256=_u_Ys-LiLzDAP4zJggVwaVvirgoS3jvhXMREoS6JOkM,1737
|
57
43
|
brainstate/functional/_spikes.py,sha256=QY-2ayJkgkGELcq-bftPEaf_hJptVf_SP3fY36QvlZc,2678
|
58
|
-
brainstate/graph/__init__.py,sha256=
|
44
|
+
brainstate/graph/__init__.py,sha256=noo4TjBg6iEhjjwk0sAGUhR7Ge-z8Vnc2rLYUvnqttw,1295
|
59
45
|
brainstate/graph/_graph_node.py,sha256=swAokZLKswSTaq2WEhyLIs38sy_67C6maHI6T3e1hvY,8339
|
60
46
|
brainstate/graph/_graph_node_test.py,sha256=BFGfdzZFDHI0XK7hHotSVWKt3em1taGvn8FHF9NCXx8,2702
|
61
47
|
brainstate/graph/_graph_operation.py,sha256=cIwGo3ICgtce2fmdn917r81evMFjJIKeW9doaQK4DD8,64111
|
@@ -79,7 +65,7 @@ brainstate/nn/_dyn_impl/_dynamics_neuron.py,sha256=cTbIn41EPYG0h3ICzKBXxpgB6wwA2
|
|
79
65
|
brainstate/nn/_dyn_impl/_dynamics_neuron_test.py,sha256=Tfzrzu7udGrLJGnqItiLWe5WT0dgduvYOgzGCnaPJQg,6317
|
80
66
|
brainstate/nn/_dyn_impl/_dynamics_synapse.py,sha256=MsbPyaiDyjungyzuK2b3exRGaMpZgmsmmNHNLjgxQKw,15269
|
81
67
|
brainstate/nn/_dyn_impl/_dynamics_synapse_test.py,sha256=t5i-HV0ii9sUNzWTEv04o26QVtQ-mCdMJcFq2MD755A,4981
|
82
|
-
brainstate/nn/_dyn_impl/_inputs.py,sha256=
|
68
|
+
brainstate/nn/_dyn_impl/_inputs.py,sha256=QOUpAb2YJOE78uAvIS8Ep6MFcQHV-V6uRwmYvk5p9bk,11385
|
83
69
|
brainstate/nn/_dyn_impl/_projection_alignpost.py,sha256=PNC1Tzx_SF2DHAHeJCufXzO_Q4qLoBpWABI45B3GRuc,876
|
84
70
|
brainstate/nn/_dyn_impl/_rate_rnns.py,sha256=dz_yT_6hJVhKulcjIARbGtmMzZqISws96CtBc6o5GOo,14768
|
85
71
|
brainstate/nn/_dyn_impl/_rate_rnns_test.py,sha256=gNgtr-a4ZiU1XF9wFG1HiJ9fLosfWchVR9Zn1x39xt4,2452
|
@@ -94,7 +80,7 @@ brainstate/nn/_dynamics/_synouts.py,sha256=9TGAc-nVa50th7KKn4oKLbro-4W4rwxYvp-eu
|
|
94
80
|
brainstate/nn/_dynamics/_synouts_test.py,sha256=V_jDswRN4VvEXD-2yJO3VA1TALgX0HK6oPBQiUntOWc,2266
|
95
81
|
brainstate/nn/_elementwise/__init__.py,sha256=PK8oq1K_EG2941AiUyLxCWoRdWvMO3yt8ZJbw3Lkhu8,935
|
96
82
|
brainstate/nn/_elementwise/_dropout.py,sha256=0Ebo-2y1VswvBqZ7sCA0SEUm37y49EUsef8oiSFpYGk,17759
|
97
|
-
brainstate/nn/_elementwise/_dropout_test.py,sha256=
|
83
|
+
brainstate/nn/_elementwise/_dropout_test.py,sha256=k6aB5v8RYMoV5w8UV9UNSFhaQTV7woS6jx3SNESuCRs,4383
|
98
84
|
brainstate/nn/_elementwise/_elementwise.py,sha256=om-KpwDTk5yFG5KBYXXHquRLV7s28_FJjk-omvyMyvQ,33342
|
99
85
|
brainstate/nn/_elementwise/_elementwise_test.py,sha256=SZI9jB39sZ5SO1dpWGW-PhodthwN0GU9FY1nqf2fWcs,5341
|
100
86
|
brainstate/nn/_interaction/__init__.py,sha256=TTY_SeNrdx4VnUSw6vdyl02OHdS9Qs15cWBp6kjsyNQ,1289
|
@@ -131,8 +117,8 @@ brainstate/util/_others.py,sha256=jsPZwP-v_5HRV-LB5F0NUsiqr04y8bmGIsu_JMyVcbQ,14
|
|
131
117
|
brainstate/util/_pretty_repr.py,sha256=bDpU4gbkS4B8cXBkiN8kBQNmruxiJzDRF-eIqzyeYnM,5716
|
132
118
|
brainstate/util/_scaling.py,sha256=pc_eM_SZVwkY65I4tJh1ODiHNCoEhsfFXl2zBK0PLAg,7562
|
133
119
|
brainstate/util/_struct.py,sha256=KMMHcshOM20gYhSahNzWLxsTt-Rt3AeX3Uz26-rP9vI,17619
|
134
|
-
brainstate-0.1.0.
|
135
|
-
brainstate-0.1.0.
|
136
|
-
brainstate-0.1.0.
|
137
|
-
brainstate-0.1.0.
|
138
|
-
brainstate-0.1.0.
|
120
|
+
brainstate-0.1.0.post20250127.dist-info/LICENSE,sha256=VZe9u1jgUL2eCY6ZPOYgdb8KCblCHt8ECdbtJid6e1s,11550
|
121
|
+
brainstate-0.1.0.post20250127.dist-info/METADATA,sha256=-j8gDuJN37nbSkijFC9TtG5UbflMZZWK-sftY7fkeps,3585
|
122
|
+
brainstate-0.1.0.post20250127.dist-info/WHEEL,sha256=bb2Ot9scclHKMOLDEHY6B2sicWOgugjFKaJsT7vwMQo,110
|
123
|
+
brainstate-0.1.0.post20250127.dist-info/top_level.txt,sha256=eQbGgKn0ptx7FDWuua0V0wr4K1VHi2iOUCYo3fUQBRA,11
|
124
|
+
brainstate-0.1.0.post20250127.dist-info/RECORD,,
|
brainstate/event/__init__.py
DELETED
@@ -1,27 +0,0 @@
|
|
1
|
-
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
|
16
|
-
|
17
|
-
from ._csr import *
|
18
|
-
from ._fixedprob_mv import *
|
19
|
-
from ._linear_mv import *
|
20
|
-
from ._xla_custom_op import *
|
21
|
-
|
22
|
-
__all__ = [
|
23
|
-
'FixedProb',
|
24
|
-
'XLACustomOp',
|
25
|
-
'CSR',
|
26
|
-
'CSC',
|
27
|
-
]
|