brainstate 0.1.0.post20250105__py2.py3-none-any.whl → 0.1.0.post20250120__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/_state.py +77 -44
- brainstate/_state_test.py +0 -17
- brainstate/augment/_eval_shape.py +9 -10
- brainstate/augment/_eval_shape_test.py +1 -1
- brainstate/augment/_mapping.py +265 -277
- brainstate/augment/_mapping_test.py +147 -175
- brainstate/compile/_ad_checkpoint.py +6 -4
- brainstate/compile/_jit.py +37 -28
- brainstate/compile/_loop_collect_return.py +6 -3
- brainstate/compile/_loop_no_collection.py +2 -0
- brainstate/compile/_make_jaxpr.py +7 -3
- brainstate/compile/_progress_bar.py +68 -40
- brainstate/compile/_unvmap.py +6 -3
- brainstate/event/__init__.py +0 -2
- brainstate/event/_csr.py +266 -23
- brainstate/event/_csr_test.py +187 -0
- brainstate/event/_xla_custom_op.py +7 -3
- brainstate/graph/__init__.py +8 -12
- brainstate/graph/_graph_node.py +1 -23
- brainstate/graph/_graph_operation.py +1 -1
- brainstate/graph/_graph_operation_test.py +0 -159
- brainstate/nn/_dyn_impl/_inputs.py +124 -39
- brainstate/nn/_interaction/_conv.py +4 -2
- brainstate/nn/_interaction/_linear.py +84 -10
- brainstate/random/_rand_funs.py +9 -2
- brainstate/random/_rand_seed.py +12 -2
- brainstate/random/_rand_state.py +50 -179
- brainstate/surrogate.py +5 -1
- brainstate/util/__init__.py +0 -4
- brainstate/util/_caller.py +1 -1
- brainstate/util/_dict.py +4 -1
- brainstate/util/_filter.py +1 -1
- brainstate/util/_pretty_repr.py +1 -1
- brainstate/util/_struct.py +1 -1
- {brainstate-0.1.0.post20250105.dist-info → brainstate-0.1.0.post20250120.dist-info}/METADATA +2 -1
- {brainstate-0.1.0.post20250105.dist-info → brainstate-0.1.0.post20250120.dist-info}/RECORD +40 -46
- brainstate/event/_csr_mv_test.py +0 -118
- brainstate/graph/_graph_context.py +0 -443
- brainstate/graph/_graph_context_test.py +0 -65
- brainstate/graph/_graph_convert.py +0 -246
- brainstate/util/_tracers.py +0 -68
- brainstate/util/_visualization.py +0 -47
- /brainstate/event/{_csr_mv_benchmark.py → _csr_benchmark.py} +0 -0
- {brainstate-0.1.0.post20250105.dist-info → brainstate-0.1.0.post20250120.dist-info}/LICENSE +0 -0
- {brainstate-0.1.0.post20250105.dist-info → brainstate-0.1.0.post20250120.dist-info}/WHEEL +0 -0
- {brainstate-0.1.0.post20250105.dist-info → brainstate-0.1.0.post20250120.dist-info}/top_level.txt +0 -0
@@ -16,34 +16,59 @@
 from __future__ import annotations
 
 import copy
-
+import importlib.util
+from typing import Optional, Callable, Any, Tuple, Dict
 
 import jax
 
-try:
-    from tqdm.auto import tqdm
-except (ImportError, ModuleNotFoundError):
-    tqdm = None
+tqdm_installed = importlib.util.find_spec('tqdm') is not None
 
 __all__ = [
     'ProgressBar',
 ]
 
+Index = int
+Carray = Any
+Output = Any
+
 
 class ProgressBar(object):
+    """
+    A progress bar for tracking the progress of a jitted for-loop computation.
+    """
     __module__ = "brainstate.compile"
 
-    def __init__(
+    def __init__(
+        self,
+        freq: Optional[int] = None,
+        count: Optional[int] = None,
+        desc: Optional[Tuple[str, Callable[[Dict], Dict]]] = None,
+        **kwargs
+    ):
+        # print rate
         self.print_freq = freq
         if isinstance(freq, int):
             assert freq > 0, "Print rate should be > 0."
+
+        # print count
         self.print_count = count
         if self.print_freq is not None and self.print_count is not None:
             raise ValueError("Cannot specify both count and freq.")
+
+        # other parameters
         for kwarg in ("total", "mininterval", "maxinterval", "miniters"):
             kwargs.pop(kwarg, None)
         self.kwargs = kwargs
-
+
+        # description
+        if desc is not None:
+            assert isinstance(desc, (tuple, list)), 'Description should be a tuple or list.'
+            assert isinstance(desc[0], str), 'Description should be a string.'
+            assert callable(desc[1]), 'Description should be a callable.'
+        self.desc = desc
+
+        # check if tqdm is installed
+        if not tqdm_installed:
             raise ImportError("tqdm is not installed.")
 
     def init(self, n: int):
@@ -67,15 +92,22 @@ class ProgressBar(object):
             raise ValueError("Print rate should be less than the "
                              f"number of steps {n}, got {freq}")
         remainder = n % freq
-
-        message =
-        return ProgressBarRunner(n,
+
+        message = f"Running for {n:,} iterations" if self.desc is None else self.desc
+        return ProgressBarRunner(n, freq, remainder, message, **kwargs)
 
 
 class ProgressBarRunner(object):
     __module__ = "brainstate.compile"
 
-    def __init__(
+    def __init__(
+        self,
+        n: int,
+        print_freq: int,
+        remainder: int,
+        message: str | Tuple[str, Callable[[Dict], Dict]],
+        **kwargs
+    ):
         self.tqdm_bars = {}
         self.kwargs = kwargs
         self.n = n
@@ -83,50 +115,46 @@ class ProgressBarRunner(object):
         self.remainder = remainder
         self.message = message
 
-    def _define_tqdm(self):
+    def _define_tqdm(self, x: dict):
+        from tqdm.auto import tqdm
         self.tqdm_bars[0] = tqdm(range(self.n), **self.kwargs)
-        self.tqdm_bars[0].set_description(self.message, refresh=False)
+        if isinstance(self.message, str):
+            self.tqdm_bars[0].set_description(self.message, refresh=False)
+        else:
+            self.tqdm_bars[0].set_description(self.message[0].format(**x), refresh=True)
 
-    def _update_tqdm(self):
+    def _update_tqdm(self, x: dict):
         self.tqdm_bars[0].update(self.print_freq)
+        if not isinstance(self.message, str):
+            self.tqdm_bars[0].set_description(self.message[0].format(**x), refresh=True)
 
-    def _close_tqdm(self):
+    def _close_tqdm(self, x: dict):
         if self.remainder > 0:
             self.tqdm_bars[0].update(self.remainder)
+        if not isinstance(self.message, str):
+            self.tqdm_bars[0].set_description(self.message[0].format(**x), refresh=True)
         self.tqdm_bars[0].close()
 
-    def _tqdm(self, is_init, is_print, is_final):
-        if is_init:
-            self._define_tqdm()
-
-        if is_print:
-            self.tqdm_bars[0].update(self.print_freq)
-        if is_final:
-            if self.remainder > 0:
-                self.tqdm_bars[0].update(self.remainder)
-            self.tqdm_bars[0].close()
-
-    def __call__(self, iter_num, *args, **kwargs):
-        # jax.debug.callback(
-        #     self._tqdm,
-        #     iter_num == 0,
-        #     (iter_num + 1) % self.print_freq == 0,
-        #     iter_num == self.n - 1
-        # )
+    def __call__(self, iter_num, **kwargs):
+        data = dict(i=iter_num, **kwargs)
+        data = dict() if isinstance(self.message, str) else self.message[1](data)
+        assert isinstance(data, dict), 'Description function should return a dictionary.'
 
         _ = jax.lax.cond(
             iter_num == 0,
-            lambda: jax.debug.callback(self._define_tqdm, ordered=True),
-            lambda: None,
+            lambda x: jax.debug.callback(self._define_tqdm, x, ordered=True),
+            lambda x: None,
+            data
         )
         _ = jax.lax.cond(
             iter_num % self.print_freq == (self.print_freq - 1),
-            lambda: jax.debug.callback(self._update_tqdm, ordered=True),
-            lambda: None,
+            lambda x: jax.debug.callback(self._update_tqdm, x, ordered=True),
+            lambda x: None,
+            data
        )
        _ = jax.lax.cond(
            iter_num == self.n - 1,
-            lambda: jax.debug.callback(self._close_tqdm, ordered=True),
-            lambda: None,
+            lambda x: jax.debug.callback(self._close_tqdm, x, ordered=True),
+            lambda x: None,
+            data
         )
-
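Note: the new `desc` argument is a `(format_string, transform)` pair. On every refresh the runner builds `dict(i=iter_num, **carry)`, passes it through the transform, and feeds the resulting dict to `str.format`. A minimal usage sketch (the `loss` field is hypothetical, assuming the loop machinery forwards its carry to the runner as keyword arguments):

    import brainstate

    pbar = brainstate.compile.ProgressBar(
        freq=100,
        desc=("step {i}, loss={loss:.3f}",                          # format string
              lambda data: {'i': data['i'], 'loss': data['loss']})  # dict -> dict
    )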
brainstate/compile/_unvmap.py
CHANGED
@@ -16,13 +16,16 @@ from __future__ import annotations
 
 import jax
 import jax.core
-import jax.extend as je
 import jax.interpreters.batching as batching
 import jax.interpreters.mlir as mlir
 import jax.numpy as jnp
-
 from brainstate._utils import set_module_as
 
+if jax.__version_info__ < (0, 4, 38):
+    from jax.core import Primitive
+else:
+    from jax.extend.core import Primitive
+
 __all__ = [
     "unvmap",
 ]
@@ -44,7 +47,7 @@ def unvmap(x, op: str = 'any'):
 
 # unvmap_all
 
-unvmap_all_p = je.core.Primitive("unvmap_all")
+unvmap_all_p = Primitive("unvmap_all")
 
 
 def unvmap_all(x):
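Note: `Primitive` moved from `jax.core` to `jax.extend.core` around jax 0.4.38, which is what the version gate above accounts for. The same shim works for any code defining custom primitives, e.g. (primitive name here is hypothetical):

    import jax

    if jax.__version_info__ < (0, 4, 38):
        from jax.core import Primitive
    else:
        from jax.extend.core import Primitive

    my_p = Primitive('my_op')        # hypothetical primitive
    my_p.def_impl(lambda x: x)
    my_p.def_abstract_eval(lambda x: x)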
brainstate/event/__init__.py
CHANGED
brainstate/event/_csr.py
CHANGED
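Note: throughout this file a sparse matrix is stored in CSR form as three arrays: `data` (nonzero values), `indices` (their column indices), and `indptr` (per-row start offsets into the other two). A concrete mapping, for reference while reading the kernels below:

    import numpy as np

    # dense (3x4)            CSR equivalent
    # [[1. 0. 0. 2.]
    #  [0. 3. 0. 0.]
    #  [0. 0. 4. 0.]]
    data = np.array([1., 2., 3., 4.])
    indices = np.array([0, 3, 1, 2])
    indptr = np.array([0, 2, 3, 4])   # row i spans data[indptr[i]:indptr[i+1]]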
@@ -73,7 +73,7 @@ class CSR(u.sparse.SparseMatrix):
         return u.sparse.csr_todense(self)
 
     def transpose(self, axes=None):
-        assert axes is None
+        assert axes is None, "transpose does not support axes argument."
         return CSC((self.data, self.indices, self.indptr), shape=self.shape[::-1])
 
     def __abs__(self):
@@ -103,6 +103,7 @@ class CSR(u.sparse.SparseMatrix):
                 (op(self.data, other), self.indices, self.indptr),
                 shape=self.shape
             )
+
         elif other.ndim == 2 and other.shape == self.shape:
             rows, cols = csr_to_coo(self.indices, self.indptr)
             other = other[rows, cols]
@@ -112,6 +113,7 @@ class CSR(u.sparse.SparseMatrix):
                  self.indptr),
                 shape=self.shape
             )
+
         else:
             raise NotImplementedError(f"mul with object of shape {other.shape}")
 
@@ -184,10 +186,12 @@ class CSR(u.sparse.SparseMatrix):
         return self._binary_rop(other, operator.mod)
 
     def __matmul__(self, other):
+        # csr @ other
         if isinstance(other, JAXSparse):
             raise NotImplementedError("matmul between two sparse objects.")
         other = u.math.asarray(other)
-        data, other = u.math.promote_dtypes(self.data, other)
+        data = self.data
+        # data, other = u.math.promote_dtypes(self.data, other)
         if other.ndim == 1:
             return _csr_matvec(
                 data,
@@ -208,10 +212,12 @@ class CSR(u.sparse.SparseMatrix):
             raise NotImplementedError(f"matmul with object of shape {other.shape}")
 
     def __rmatmul__(self, other):
+        # other @ csr
         if isinstance(other, JAXSparse):
             raise NotImplementedError("matmul between two sparse objects.")
         other = u.math.asarray(other)
-        data, other = u.math.promote_dtypes(self.data, other)
+        data = self.data
+        # data, other = u.math.promote_dtypes(self.data, other)
         if other.ndim == 1:
             return _csr_matvec(
                 data,
@@ -566,7 +572,7 @@ def event_csrmv_cpu_kernel_generator(
     if weight_info.size == 1:
         if transpose:
             if spike_info.dtype == jnp.bool_:
-                @numba.njit
+                @numba.njit(fastmath=True)
                 def mv(weights, indices, indptr, v, posts):
                     posts[:] = 0.
                     w = weights[()]
@@ -576,7 +582,7 @@ def event_csrmv_cpu_kernel_generator(
                                 posts[indices[j]] += w
 
             elif float_as_event:
-                @numba.njit
+                @numba.njit(fastmath=True)
                 def mv(weights, indices, indptr, v, posts):
                     posts[:] = 0.
                     w = weights[()]
@@ -586,7 +592,7 @@ def event_csrmv_cpu_kernel_generator(
                                 posts[indices[j]] += w
 
             else:
-                @numba.njit
+                @numba.njit(fastmath=True)
                 def mv(weights, indices, indptr, v, posts):
                     posts[:] = 0.
                     w = weights[()]
@@ -599,7 +605,7 @@ def event_csrmv_cpu_kernel_generator(
 
         else:
             if spike_info.dtype == jnp.bool_:
-                @numba.njit
+                @numba.njit(fastmath=True)
                 def mv(weights, indices, indptr, v, posts):
                     w = weights[()]
                     for i in range(indptr.shape[0] - 1):
@@ -610,7 +616,7 @@ def event_csrmv_cpu_kernel_generator(
                     posts[i] = r
 
             elif float_as_event:
-                @numba.njit
+                @numba.njit(fastmath=True)
                 def mv(weights, indices, indptr, v, posts):
                     w = weights[()]
                     for i in range(indptr.shape[0] - 1):
@@ -621,7 +627,7 @@ def event_csrmv_cpu_kernel_generator(
                     posts[i] = r
 
             else:
-                @numba.njit
+                @numba.njit(fastmath=True)
                 def mv(weights, indices, indptr, v, posts):
                     w = weights[()]
                     for i in range(indptr.shape[0] - 1):
@@ -635,7 +641,7 @@ def event_csrmv_cpu_kernel_generator(
     else:
         if transpose:
             if spike_info.dtype == jnp.bool_:
-                @numba.njit
+                @numba.njit(fastmath=True)
                 def mv(weights, indices, indptr, v, posts):
                     posts[:] = 0.
                     for i in range(v.shape[0]):
@@ -644,7 +650,7 @@ def event_csrmv_cpu_kernel_generator(
                             posts[indices[j]] += weights[j]
 
             elif float_as_event:
-                @numba.njit
+                @numba.njit(fastmath=True)
                 def mv(weights, indices, indptr, v, posts):
                     posts[:] = 0.
                     for i in range(v.shape[0]):
@@ -653,7 +659,7 @@ def event_csrmv_cpu_kernel_generator(
                             posts[indices[j]] += weights[j]
 
             else:
-                @numba.njit
+                @numba.njit(fastmath=True)
                 def mv(weights, indices, indptr, v, posts):
                     posts[:] = 0.
                     for i in range(v.shape[0]):
@@ -664,7 +670,7 @@ def event_csrmv_cpu_kernel_generator(
 
         else:
             if spike_info.dtype == jnp.bool_:
-                @numba.njit
+                @numba.njit(fastmath=True)
                 def mv(weights, indices, indptr, v, posts):
                     for i in range(indptr.shape[0] - 1):
                         r = 0.
@@ -674,7 +680,7 @@ def event_csrmv_cpu_kernel_generator(
                     posts[i] = r
 
             elif float_as_event:
-                @numba.njit
+                @numba.njit(fastmath=True)
                 def mv(weights, indices, indptr, v, posts):
                     for i in range(indptr.shape[0] - 1):
                         r = 0.
@@ -684,7 +690,7 @@ def event_csrmv_cpu_kernel_generator(
                     posts[i] = r
 
             else:
-                @numba.njit
+                @numba.njit(fastmath=True)
                 def mv(weights, indices, indptr, v, posts):
                     for i in range(indptr.shape[0] - 1):
                         r = 0.
@@ -795,7 +801,31 @@ def event_csrmv_transpose_rule(
 
 def event_csrmv_batching(args, axes, **kwargs):
     if tuple(axes) == (None, None, None, 0):
-
+        assert args[3].ndim == 2, 'Batching axis 0 requires 2D input.'
+        r = event_csrmm_p_call(
+            args[0],
+            args[1],
+            args[2],
+            args[3].T,
+            shape=kwargs['shape'],
+            transpose=kwargs['transpose'],
+            float_as_event=kwargs['float_as_event']
+        )
+        return r, [1]
+
+    elif tuple(axes) == (None, None, None, 1):
+        assert args[3].ndim == 2, 'Batching axis 0 requires 2D input.'
+        r = event_csrmm_p_call(
+            args[0],
+            args[1],
+            args[2],
+            args[3],
+            shape=kwargs['shape'],
+            transpose=kwargs['transpose'],
+            float_as_event=kwargs['float_as_event']
+        )
+        return r, [1]
+
     else:
         raise NotImplementedError(f"Batching axes {axes} not implemented for event-driven CSR matrix-vector product.")
 
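Note: the rewritten batching rule maps a `jax.vmap` over the dense operand onto a single `event_csrmm` call instead of a loop of matvecs. A sketch of what this enables (shapes illustrative; assumes `CSR` is exported as `brainstate.event.CSR` with the `(data, indices, indptr), shape=...` constructor used elsewhere in this file):

    import jax
    import jax.numpy as jnp
    import brainstate

    data = jnp.array([1.0, 2.0, 3.0])        # 2x3 event-driven CSR matrix
    indices = jnp.array([0, 2, 1])
    indptr = jnp.array([0, 2, 3])
    csr = brainstate.event.CSR((data, indices, indptr), shape=(2, 3))

    spikes = jnp.array([[True, False, True],
                        [False, True, False]])   # batch of event vectors
    out = jax.vmap(lambda v: csr @ v)(spikes)    # lowers to one event_csrmm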
@@ -852,17 +882,228 @@ def event_csrmv_p_call(
 
 def event_csrmm_batching(args, axes, **kwargs):
     if tuple(axes) == (None, None, None, 0):
-
-
-
-
+        assert args[3].ndim == 3, 'Batching axis 0 requires 3D input.'
+        batch_size, m, n = args[3].shape
+        B = jnp.transpose(args[3], (1, 0, 2)).reshape(m, batch_size * n)
+        r = event_csrmm_p_call(
+            args[0],
+            args[1],
+            args[2],
+            B,
+            shape=kwargs['shape'],
+            transpose=kwargs['transpose'],
+            float_as_event=kwargs['float_as_event']
+        )
+        r = jnp.reshape(r[0], [r[0].shape[0], batch_size, n])
+        return [r], [1]
+
+    elif tuple(axes) == (None, None, None, 1):
+        assert args[3].ndim == 3, 'Batching axis 0 requires 3D input.'
+        m, batch_size, n = args[3].shape
+        B = args[3].reshape(m, batch_size * n)
+        r = event_csrmm_p_call(
+            args[0],
+            args[1],
+            args[2],
+            B,
+            shape=kwargs['shape'],
+            transpose=kwargs['transpose'],
+            float_as_event=kwargs['float_as_event']
+        )
+        r = jnp.reshape(r[0], [r[0].shape[0], batch_size, n])
+        return [r], [1]
+
+    elif tuple(axes) == (None, None, None, 2):
+        assert args[3].ndim == 3, 'Batching axis 0 requires 3D input.'
+        m, n, batch_size = args[3].shape
+        B = args[3].reshape(m, batch_size * n)
+        r = event_csrmm_p_call(
+            args[0],
+            args[1],
+            args[2],
+            B,
+            shape=kwargs['shape'],
+            transpose=kwargs['transpose'],
+            float_as_event=kwargs['float_as_event']
+        )
+        r = jnp.reshape(r[0], [r[0].shape[0], n, batch_size])
+        return [r], [2]
+
     else:
         raise NotImplementedError(f"Batching axes {axes} not implemented for event-driven CSR matrix-vector product.")
 
 
+def event_csrmm_cpu_kernel_generator(
+    float_as_event: bool,
+    weight_info: jax.ShapeDtypeStruct,
+    spike_info: jax.ShapeDtypeStruct,
+    transpose: bool,
+    **kwargs
+) -> Kernel:
+    import numba  # pylint: disable=import-outside-toplevel
+
+    if weight_info.size == 1:
+        if transpose:
+            # csr.T @ B
+
+            if spike_info.dtype == jnp.bool_:
+                @numba.njit(fastmath=True, parallel=False)
+                def mv(weights, indices, indptr, B, posts):
+                    posts[:] = 0.
+                    w = weights[()]
+                    for k in numba.prange(B.shape[1]):
+                        for i in range(B.shape[0]):
+                            if B[i, k]:
+                                for j in range(indptr[i], indptr[i + 1]):
+                                    posts[indices[j], k] += w
+
+            elif float_as_event:
+                @numba.njit(fastmath=True, parallel=False)
+                def mv(weights, indices, indptr, B, posts):
+                    posts[:] = 0.
+                    B = B != 0.
+                    w = weights[()]
+                    for k in numba.prange(B.shape[1]):
+                        for i in range(B.shape[0]):
+                            if B[i, k]:
+                                for j in range(indptr[i], indptr[i + 1]):
+                                    posts[indices[j], k] += w
+
+            else:
+                @numba.njit(fastmath=True, parallel=False)
+                def mv(weights, indices, indptr, B, posts):
+                    posts[:] = 0.
+                    w = weights[()]
+                    for k in numba.prange(B.shape[1]):
+                        for i in range(B.shape[0]):
+                            sp = B[i, k]
+                            if sp != 0.:
+                                wsp = w * sp
+                                for j in range(indptr[i], indptr[i + 1]):
+                                    posts[indices[j], k] += wsp
+
+        else:
+            # csr @ B
+            if spike_info.dtype == jnp.bool_:
+                @numba.njit(fastmath=True)
+                def mv(weights, indices, indptr, B, posts):
+                    w = weights[()]
+                    for i in range(indptr.shape[0] - 1):
+                        r = np.zeros(B.shape[1], dtype=weights.dtype)
+                        for j in range(indptr[i], indptr[i + 1]):
+                            index = indices[j]
+                            for k in range(B.shape[1]):
+                                if B[index, k]:
+                                    r[k] += w
+                        posts[i] = r
+
+            elif float_as_event:
+                @numba.njit(fastmath=True)
+                def mv(weights, indices, indptr, B, posts):
+                    w = weights[()]
+                    B = B != 0.
+                    for i in range(indptr.shape[0] - 1):
+                        r = np.zeros(B.shape[1], dtype=weights.dtype)
+                        for j in range(indptr[i], indptr[i + 1]):
+                            index = indices[j]
+                            for k in range(B.shape[1]):
+                                if B[index, k]:
+                                    r[k] += w
+                        posts[i] = r
+
+            else:
+                @numba.njit(fastmath=True)
+                def mv(weights, indices, indptr, B, posts):
+                    w = weights[()]
+                    for i in range(indptr.shape[0] - 1):
+                        for k in range(B.shape[1]):
+                            r = 0.
+                            for j in range(indptr[i], indptr[i + 1]):
+                                c = B[indices[j], k]
+                                if c != 0.:
+                                    r += w * c
+                            posts[i, k] = r
+
+    else:
+        if transpose:
+            # csr.T @ B
+
+            if spike_info.dtype == jnp.bool_:
+                @numba.njit(fastmath=True, parallel=False)
+                def mv(weights, indices, indptr, B, posts):
+                    posts[:] = 0.
+                    for k in numba.prange(B.shape[1]):
+                        for i in range(B.shape[0]):
+                            if B[i, k]:
+                                for j in range(indptr[i], indptr[i + 1]):
+                                    posts[indices[j], k] += weights[j]
+
+            elif float_as_event:
+                @numba.njit(fastmath=True, parallel=False)
+                def mv(weights, indices, indptr, B, posts):
+                    posts[:] = 0.
+                    B = B != 0.
+                    for k in numba.prange(B.shape[1]):
+                        for i in range(B.shape[0]):
+                            if B[i, k]:
+                                for j in range(indptr[i], indptr[i + 1]):
+                                    posts[indices[j], k] += weights[j]
+
+            else:
+                @numba.njit(fastmath=True, parallel=False)
+                def mv(weights, indices, indptr, B, posts):
+                    posts[:] = 0.
+                    for k in numba.prange(B.shape[1]):
+                        for i in range(B.shape[0]):
+                            sp = B[i, k]
+                            if sp != 0.:
+                                for j in range(indptr[i], indptr[i + 1]):
+                                    posts[indices[j], k] += weights[j] * sp
+
+        else:
+            # csr @ B
+
+            if spike_info.dtype == jnp.bool_:
+                @numba.njit(fastmath=True)
+                def mv(weights, indices, indptr, B, posts):
+                    for i in range(indptr.shape[0] - 1):
+                        for k in range(B.shape[1]):
+                            r = 0.
+                            for j in range(indptr[i], indptr[i + 1]):
+                                if B[indices[j], k]:
+                                    r += weights[j]
+                            posts[i, k] = r
+
+            elif float_as_event:
+                @numba.njit(fastmath=True)
+                def mv(weights, indices, indptr, B, posts):
+                    B = B != 0.
+                    for i in range(indptr.shape[0] - 1):
+                        for k in range(B.shape[1]):
+                            r = 0.
+                            for j in range(indptr[i], indptr[i + 1]):
+                                if B[indices[j], k]:
+                                    r += weights[j]
+                            posts[i, k] = r
+
+            else:
+                @numba.njit(fastmath=True)
+                def mv(weights, indices, indptr, B, posts):
+                    for i in range(indptr.shape[0] - 1):
+                        for k in range(B.shape[1]):
+                            r = 0.
+                            for j in range(indptr[i], indptr[i + 1]):
+                                c = B[indices[j], k]
+                                if c != 0.:
+                                    r += weights[j] * c
+                            posts[i, k] = r
+
+    return mv
+
+
 event_csrmm_p = XLACustomOp(
     'event_csrmm',
-    cpu_kernel_or_generator=
+    cpu_kernel_or_generator=event_csrmm_cpu_kernel_generator,
 )
 event_csrmm_p.def_batching_rule(event_csrmm_batching)
 
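Note: every generated kernel scatters into a preallocated `posts` buffer; `fastmath=True` relaxes IEEE semantics for the accumulations, and in the transposed kernels `numba.prange` is a drop-in for `range` while `parallel=False`. A standalone check of the per-edge-weight transposed variant against a dense reference (plain NumPy/Numba, outside brainstate's `XLACustomOp` machinery; `csr_t_matmat` is a local name):

    import numba
    import numpy as np

    @numba.njit(fastmath=True)
    def csr_t_matmat(weights, indices, indptr, B, posts):
        # posts = csr.T @ B for a matrix stored as (weights, indices, indptr)
        posts[:] = 0.
        for k in range(B.shape[1]):              # numba.prange in the diff
            for i in range(B.shape[0]):
                sp = B[i, k]
                if sp != 0.:
                    for j in range(indptr[i], indptr[i + 1]):
                        posts[indices[j], k] += weights[j] * sp

    indptr = np.array([0, 2, 3, 4])              # 3x4 CSR matrix
    indices = np.array([0, 3, 1, 2])
    weights = np.array([1., 2., 3., 4.])
    dense = np.zeros((3, 4))                     # dense reference
    for row in range(3):
        for j in range(indptr[row], indptr[row + 1]):
            dense[row, indices[j]] = weights[j]

    B = np.random.rand(3, 5)
    posts = np.zeros((4, 5))
    csr_t_matmat(weights, indices, indptr, B, posts)
    assert np.allclose(posts, dense.T @ B)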
@@ -884,11 +1125,13 @@ def event_csrmm_p_call(
         indptr,
         B,
         outs=[
-            jax.ShapeDtypeStruct([shape[
+            jax.ShapeDtypeStruct([shape[1], B.shape[1]], weights.dtype)
             if transpose else
-            jax.ShapeDtypeStruct([shape[
+            jax.ShapeDtypeStruct([shape[0], B.shape[1]], weights.dtype),
         ],
         # block_size=block_size,
+        shape=shape,
+        transpose=transpose,
         float_as_event=float_as_event,
         weight_info=jax.ShapeDtypeStruct(weights.shape, weights.dtype),
         spike_info=jax.ShapeDtypeStruct(B.shape, B.dtype),