brainstate 0.1.0.post20250120__py2.py3-none-any.whl → 0.1.0.post20250127__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +1 -2
- brainstate/augment/__init__.py +10 -20
- brainstate/compile/__init__.py +18 -37
- brainstate/compile/_make_jaxpr.py +9 -2
- brainstate/compile/_make_jaxpr_test.py +10 -6
- brainstate/compile/_progress_bar.py +49 -6
- brainstate/compile/_unvmap.py +3 -3
- brainstate/graph/__init__.py +12 -12
- brainstate/nn/_dyn_impl/_inputs.py +4 -2
- brainstate/nn/_elementwise/_dropout_test.py +1 -1
- {brainstate-0.1.0.post20250120.dist-info → brainstate-0.1.0.post20250127.dist-info}/METADATA +1 -1
- {brainstate-0.1.0.post20250120.dist-info → brainstate-0.1.0.post20250127.dist-info}/RECORD +15 -29
- brainstate/event/__init__.py +0 -27
- brainstate/event/_csr.py +0 -1149
- brainstate/event/_csr_benchmark.py +0 -14
- brainstate/event/_csr_mv.py +0 -303
- brainstate/event/_csr_test.py +0 -277
- brainstate/event/_fixedprob_mv.py +0 -730
- brainstate/event/_fixedprob_mv_benchmark.py +0 -128
- brainstate/event/_fixedprob_mv_test.py +0 -132
- brainstate/event/_linear_mv.py +0 -359
- brainstate/event/_linear_mv_benckmark.py +0 -82
- brainstate/event/_linear_mv_test.py +0 -117
- brainstate/event/_misc.py +0 -34
- brainstate/event/_xla_custom_op.py +0 -317
- brainstate/event/_xla_custom_op_test.py +0 -55
- {brainstate-0.1.0.post20250120.dist-info → brainstate-0.1.0.post20250127.dist-info}/LICENSE +0 -0
- {brainstate-0.1.0.post20250120.dist-info → brainstate-0.1.0.post20250127.dist-info}/WHEEL +0 -0
- {brainstate-0.1.0.post20250120.dist-info → brainstate-0.1.0.post20250127.dist-info}/top_level.txt +0 -0
brainstate/event/_csr.py
DELETED
@@ -1,1149 +0,0 @@
|
|
1
|
-
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
|
16
|
-
|
17
|
-
from __future__ import annotations
|
18
|
-
|
19
|
-
import operator
|
20
|
-
from typing import Callable
|
21
|
-
|
22
|
-
import brainunit as u
|
23
|
-
import jax
|
24
|
-
import jax.numpy as jnp
|
25
|
-
import numpy as np
|
26
|
-
from brainunit.sparse._csr import (
|
27
|
-
_csr_matvec as csr_matvec,
|
28
|
-
_csr_matmat as csr_matmat,
|
29
|
-
_csr_to_coo as csr_to_coo
|
30
|
-
)
|
31
|
-
from jax.experimental.sparse import JAXSparse
|
32
|
-
from jax.interpreters import ad
|
33
|
-
|
34
|
-
from brainstate.typing import Shape
|
35
|
-
from ._xla_custom_op import XLACustomOp
|
36
|
-
|
37
|
-
# Public names exported by this module.
__all__ = ['CSR', 'CSC']
|
41
|
-
|
42
|
-
|
43
|
-
@jax.tree_util.register_pytree_node_class
class CSR(u.sparse.SparseMatrix):
    """
    Event-driven and Unit-aware CSR matrix.

    Storage follows the compressed-sparse-row layout:

    * ``data``    -- the stored values (plain array or ``brainunit`` quantity);
    * ``indices`` -- the column index of every stored value;
    * ``indptr``  -- row pointers: row ``i`` owns positions
      ``indptr[i]:indptr[i + 1]`` of ``data`` and ``indices``.

    Only ``data`` is a pytree child; ``shape``, ``indices`` and ``indptr`` are
    carried as auxiliary data (see :meth:`tree_flatten`).
    """
    data: jax.Array | u.Quantity
    indices: jax.Array
    indptr: jax.Array
    shape: tuple[int, int]
    nse = property(lambda self: self.data.size)
    dtype = property(lambda self: self.data.dtype)
    _bufs = property(lambda self: (self.data, self.indices, self.indptr))

    def __init__(self, args, *, shape):
        """Create the matrix from ``args = (data, indices, indptr)`` and ``shape``."""
        self.data, self.indices, self.indptr = map(u.math.asarray, args)
        super().__init__(args, shape=shape)

    @classmethod
    def fromdense(cls, mat, *, nse=None, index_dtype=np.int32):
        """Convert the dense matrix ``mat`` to CSR form.

        When ``nse`` is ``None`` it defaults to the number of non-zero entries
        of ``mat``.

        NOTE(review): this returns whatever ``u.sparse.csr_fromdense`` builds,
        which may be brainunit's own CSR class rather than ``cls`` -- confirm.
        """
        if nse is None:
            nse = (u.get_mantissa(mat) != 0).sum()
        return u.sparse.csr_fromdense(mat, nse=nse, index_dtype=index_dtype)

    def with_data(self, data: jax.Array | u.Quantity) -> CSR:
        """Return a new CSR with the same structure and replacement values.

        ``data`` must match the current values in shape, dtype and unit.
        """
        assert data.shape == self.data.shape
        assert data.dtype == self.data.dtype
        assert u.get_unit(data) == u.get_unit(self.data)
        return CSR((data, self.indices, self.indptr), shape=self.shape)

    def todense(self):
        """Return the dense equivalent of this matrix."""
        return u.sparse.csr_todense(self)

    def transpose(self, axes=None):
        """Return the transpose as a :class:`CSC` sharing the same buffers.

        The CSR buffers of ``A`` are exactly the CSC buffers of ``A.T``, so no
        data is copied; only the shape is reversed.
        """
        assert axes is None, "transpose does not support axes argument."
        return CSC((self.data, self.indices, self.indptr), shape=self.shape[::-1])

    def __abs__(self):
        return CSR((abs(self.data), self.indices, self.indptr), shape=self.shape)

    def __neg__(self):
        return CSR((-self.data, self.indices, self.indptr), shape=self.shape)

    def __pos__(self):
        return CSR((self.data.__pos__(), self.indices, self.indptr), shape=self.shape)

    def _elementwise_binary(self, other, op, reverse: bool):
        """Apply ``op`` to the stored values of ``self`` and ``other``.

        ``reverse=False`` computes ``op(self, other)``; ``reverse=True``
        computes ``op(other, self)``. The sparsity pattern of ``self`` is kept,
        so only stored entries are combined.

        Supported ``other`` operands:

        * a ``CSR`` whose ``indices``/``indptr`` are the *same objects* as ours;
        * anything with ``size == 1`` after ``u.math.asarray`` (a scalar);
        * a 2-D dense array with ``self.shape``, sampled at the stored positions.

        Raises:
          NotImplementedError: for other sparse operands or dense shapes.
        """

        def apply(mine, theirs):
            return op(theirs, mine) if reverse else op(mine, theirs)

        if isinstance(other, CSR):
            # Identity (not value) comparison: cheap and trace-safe.
            if id(other.indices) == id(self.indices) and id(other.indptr) == id(self.indptr):
                return CSR(
                    (apply(self.data, other.data), self.indices, self.indptr),
                    shape=self.shape
                )
        if isinstance(other, JAXSparse):
            raise NotImplementedError(f"binary operation {op} between two sparse objects.")

        other = u.math.asarray(other)
        if other.size == 1:
            values = other
        elif other.ndim == 2 and other.shape == self.shape:
            # Gather the dense operand at the (row, col) of every stored entry.
            rows, cols = csr_to_coo(self.indices, self.indptr)
            values = other[rows, cols]
        else:
            raise NotImplementedError(f"binary operation {op} with object of shape {other.shape}")
        return CSR((apply(self.data, values), self.indices, self.indptr), shape=self.shape)

    def _binary_op(self, other, op):
        """Compute ``op(self, other)`` on the stored values."""
        return self._elementwise_binary(other, op, reverse=False)

    def _binary_rop(self, other, op):
        """Compute ``op(other, self)`` on the stored values."""
        return self._elementwise_binary(other, op, reverse=True)

    def __mul__(self, other: jax.Array | u.Quantity) -> CSR:
        return self._binary_op(other, operator.mul)

    def __rmul__(self, other: jax.Array | u.Quantity) -> CSR:
        return self._binary_rop(other, operator.mul)

    def __div__(self, other: jax.Array | u.Quantity) -> CSR:
        return self._binary_op(other, operator.truediv)

    def __rdiv__(self, other: jax.Array | u.Quantity) -> CSR:
        return self._binary_rop(other, operator.truediv)

    def __truediv__(self, other) -> CSR:
        return self.__div__(other)

    def __rtruediv__(self, other) -> CSR:
        return self.__rdiv__(other)

    def __add__(self, other) -> CSR:
        return self._binary_op(other, operator.add)

    def __radd__(self, other) -> CSR:
        return self._binary_rop(other, operator.add)

    def __sub__(self, other) -> CSR:
        return self._binary_op(other, operator.sub)

    def __rsub__(self, other) -> CSR:
        return self._binary_rop(other, operator.sub)

    def __mod__(self, other) -> CSR:
        return self._binary_op(other, operator.mod)

    def __rmod__(self, other) -> CSR:
        return self._binary_rop(other, operator.mod)

    def __matmul__(self, other):
        """Compute ``self @ other`` for a dense 1-D or 2-D ``other``.

        Raises:
          NotImplementedError: if ``other`` is sparse or not 1-D/2-D.
        """
        if isinstance(other, JAXSparse):
            raise NotImplementedError("matmul between two sparse objects.")
        other = u.math.asarray(other)
        # NOTE(review): unlike ``CSC.__matmul__``, ``data`` and ``other`` are
        # not dtype-promoted here -- confirm this is intended.
        data = self.data
        if other.ndim == 1:
            return _csr_matvec(data, self.indices, self.indptr, other, shape=self.shape)
        elif other.ndim == 2:
            return _csr_matmat(data, self.indices, self.indptr, other, shape=self.shape)
        else:
            raise NotImplementedError(f"matmul with object of shape {other.shape}")

    def __rmatmul__(self, other):
        """Compute ``other @ self`` for a dense 1-D or 2-D ``other``.

        Raises:
          NotImplementedError: if ``other`` is sparse or not 1-D/2-D.
        """
        if isinstance(other, JAXSparse):
            raise NotImplementedError("matmul between two sparse objects.")
        other = u.math.asarray(other)
        # NOTE(review): no dtype promotion here either (see ``__matmul__``).
        data = self.data
        if other.ndim == 1:
            # v @ A == A.T @ v
            return _csr_matvec(
                data, self.indices, self.indptr, other,
                shape=self.shape, transpose=True
            )
        elif other.ndim == 2:
            # B @ A == (A.T @ B.T).T
            r = _csr_matmat(
                data, self.indices, self.indptr, other.T,
                shape=self.shape, transpose=True
            )
            return r.T
        else:
            raise NotImplementedError(f"matmul with object of shape {other.shape}")

    def tree_flatten(self):
        """Flatten to ``(data,)`` children plus shape/indices/indptr aux data."""
        return (self.data,), {"shape": self.shape, "indices": self.indices, "indptr": self.indptr}

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """Rebuild an instance without running ``__init__``.

        Raises:
          ValueError: if ``aux_data`` does not have exactly the expected keys.
        """
        obj = object.__new__(cls)
        obj.data, = children
        if aux_data.keys() != {'shape', 'indices', 'indptr'}:
            raise ValueError(f"CSR.tree_unflatten: invalid {aux_data=}")
        obj.__dict__.update(**aux_data)
        return obj
|
255
|
-
|
256
|
-
|
257
|
-
@jax.tree_util.register_pytree_node_class
class CSC(u.sparse.SparseMatrix):
    """
    Event-driven and Unit-aware CSC matrix.

    The buffers of a CSC matrix ``A`` are the CSR buffers of ``A.T``:
    ``indices`` holds row indices and ``indptr`` holds column pointers.
    Products are therefore computed with the CSR kernels on the reversed
    shape. Only ``data`` is a pytree child; ``shape``, ``indices`` and
    ``indptr`` are carried as auxiliary data.
    """
    data: jax.Array
    indices: jax.Array
    indptr: jax.Array
    shape: tuple[int, int]
    nse = property(lambda self: self.data.size)
    dtype = property(lambda self: self.data.dtype)
    # Same buffer tuple as ``CSR._bufs`` for consistency between the classes.
    _bufs = property(lambda self: (self.data, self.indices, self.indptr))

    def __init__(self, args, *, shape):
        """Create the matrix from ``args = (data, indices, indptr)`` and ``shape``."""
        self.data, self.indices, self.indptr = map(u.math.asarray, args)
        super().__init__(args, shape=shape)

    @classmethod
    def fromdense(cls, mat, *, nse=None, index_dtype=np.int32):
        """Convert the dense matrix ``mat`` to CSC form via CSR of ``mat.T``.

        When ``nse`` is ``None`` it defaults to the number of non-zero entries
        of ``mat``.
        """
        if nse is None:
            nse = (u.get_mantissa(mat) != 0).sum()
        return u.sparse.csr_fromdense(mat.T, nse=nse, index_dtype=index_dtype).T

    @classmethod
    def _empty(cls, shape, *, dtype=None, index_dtype='int32'):
        """Create an empty CSC instance. Public method is sparse.empty()."""
        shape = tuple(shape)
        if len(shape) != 2:
            raise ValueError(f"CSC must have ndim=2; got {shape=}")
        data = jnp.empty(0, dtype)
        indices = jnp.empty(0, index_dtype)
        # One pointer per column plus the end sentinel.
        indptr = jnp.zeros(shape[1] + 1, index_dtype)
        return cls((data, indices, indptr), shape=shape)

    @classmethod
    def _eye(cls, N, M, k, *, dtype=None, index_dtype='int32'):
        """Identity-like CSC built as the transpose of a CSR eye.

        NOTE(review): relies on ``CSR._eye`` being provided (presumably by the
        base class); it is not defined in this module.
        """
        return CSR._eye(M, N, -k, dtype=dtype, index_dtype=index_dtype).T

    def with_data(self, data: jax.Array | u.Quantity) -> CSC:
        """Return a new CSC with the same structure and replacement values.

        ``data`` must match the current values in shape, dtype and unit.
        """
        assert data.shape == self.data.shape
        assert data.dtype == self.data.dtype
        assert u.get_unit(data) == u.get_unit(self.data)
        return CSC((data, self.indices, self.indptr), shape=self.shape)

    def todense(self):
        """Return the dense equivalent of this matrix."""
        return u.sparse.csr_todense(self.T).T

    def transpose(self, axes=None):
        """Return the transpose as a :class:`CSR` sharing the same buffers."""
        assert axes is None, "transpose does not support axes argument."
        return CSR((self.data, self.indices, self.indptr), shape=self.shape[::-1])

    def __abs__(self):
        return CSC((abs(self.data), self.indices, self.indptr), shape=self.shape)

    def __neg__(self):
        return CSC((-self.data, self.indices, self.indptr), shape=self.shape)

    def __pos__(self):
        return CSC((self.data.__pos__(), self.indices, self.indptr), shape=self.shape)

    def _elementwise_binary(self, other, op, reverse: bool):
        """Apply ``op`` to the stored values of ``self`` and ``other``.

        ``reverse=False`` computes ``op(self, other)``; ``reverse=True``
        computes ``op(other, self)``. The sparsity pattern of ``self`` is kept.

        Supported ``other`` operands: a ``CSC`` sharing the same
        ``indices``/``indptr`` objects, a size-1 value, or a 2-D dense array
        with ``self.shape`` (sampled at the stored positions).

        Raises:
          NotImplementedError: for other sparse operands or dense shapes.
        """

        def apply(mine, theirs):
            return op(theirs, mine) if reverse else op(mine, theirs)

        if isinstance(other, CSC):
            if id(other.indices) == id(self.indices) and id(other.indptr) == id(self.indptr):
                return CSC(
                    (apply(self.data, other.data), self.indices, self.indptr),
                    shape=self.shape
                )
        if isinstance(other, JAXSparse):
            raise NotImplementedError(f"binary operation {op} between two sparse objects.")

        other = u.math.asarray(other)
        if other.size == 1:
            values = other
        elif other.ndim == 2 and other.shape == self.shape:
            # The buffers describe A.T in CSR form, so csr_to_coo yields
            # (column, row) pairs of A.
            cols, rows = csr_to_coo(self.indices, self.indptr)
            values = other[rows, cols]
        else:
            raise NotImplementedError(f"binary operation {op} with object of shape {other.shape}")
        return CSC((apply(self.data, values), self.indices, self.indptr), shape=self.shape)

    def _binary_op(self, other, op):
        """Compute ``op(self, other)`` on the stored values."""
        return self._elementwise_binary(other, op, reverse=False)

    def _binary_rop(self, other, op):
        """Compute ``op(other, self)`` on the stored values."""
        return self._elementwise_binary(other, op, reverse=True)

    def __mul__(self, other: jax.Array | u.Quantity) -> 'CSC':
        return self._binary_op(other, operator.mul)

    def __rmul__(self, other: jax.Array | u.Quantity) -> 'CSC':
        return self._binary_rop(other, operator.mul)

    def __div__(self, other: jax.Array | u.Quantity) -> CSC:
        return self._binary_op(other, operator.truediv)

    def __rdiv__(self, other: jax.Array | u.Quantity) -> CSC:
        return self._binary_rop(other, operator.truediv)

    def __truediv__(self, other) -> CSC:
        return self.__div__(other)

    def __rtruediv__(self, other) -> CSC:
        return self.__rdiv__(other)

    def __add__(self, other) -> CSC:
        return self._binary_op(other, operator.add)

    def __radd__(self, other) -> CSC:
        return self._binary_rop(other, operator.add)

    def __sub__(self, other) -> CSC:
        return self._binary_op(other, operator.sub)

    def __rsub__(self, other) -> CSC:
        return self._binary_rop(other, operator.sub)

    def __mod__(self, other) -> CSC:
        return self._binary_op(other, operator.mod)

    def __rmod__(self, other) -> CSC:
        return self._binary_rop(other, operator.mod)

    def __matmul__(self, other):
        """Compute ``self @ other`` as ``(A.T in CSR).T @ other``.

        Raises:
          NotImplementedError: if ``other`` is sparse or not 1-D/2-D.
        """
        if isinstance(other, JAXSparse):
            raise NotImplementedError("matmul between two sparse objects.")
        other = u.math.asarray(other)
        data, other = u.math.promote_dtypes(self.data, other)
        if other.ndim == 1:
            return _csr_matvec(
                data, self.indices, self.indptr, other,
                shape=self.shape[::-1], transpose=True
            )
        elif other.ndim == 2:
            return _csr_matmat(
                data, self.indices, self.indptr, other,
                shape=self.shape[::-1], transpose=True
            )
        else:
            raise NotImplementedError(f"matmul with object of shape {other.shape}")

    def __rmatmul__(self, other):
        """Compute ``other @ self`` as ``(A.T in CSR) @ other`` (transposed back for 2-D).

        Raises:
          NotImplementedError: if ``other`` is sparse or not 1-D/2-D.
        """
        if isinstance(other, JAXSparse):
            raise NotImplementedError("matmul between two sparse objects.")
        other = u.math.asarray(other)
        data, other = u.math.promote_dtypes(self.data, other)
        if other.ndim == 1:
            return _csr_matvec(
                data, self.indices, self.indptr, other,
                shape=self.shape[::-1], transpose=False
            )
        elif other.ndim == 2:
            r = _csr_matmat(
                data, self.indices, self.indptr, other.T,
                shape=self.shape[::-1], transpose=False
            )
            return r.T
        else:
            raise NotImplementedError(f"matmul with object of shape {other.shape}")

    def tree_flatten(self):
        """Flatten to ``(data,)`` children plus shape/indices/indptr aux data."""
        return (self.data,), {"shape": self.shape, "indices": self.indices, "indptr": self.indptr}

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """Rebuild an instance without running ``__init__``.

        Raises:
          ValueError: if ``aux_data`` does not have exactly the expected keys.
        """
        obj = object.__new__(cls)
        obj.data, = children
        if aux_data.keys() != {'shape', 'indices', 'indptr'}:
            raise ValueError(f"CSC.tree_unflatten: invalid {aux_data=}")
        obj.__dict__.update(**aux_data)
        return obj
|
480
|
-
|
481
|
-
|
482
|
-
def _csr_matvec(
    data: jax.Array | u.Quantity,
    indices: jax.Array,
    indptr: jax.Array,
    v: jax.Array | u.Quantity,
    *,
    shape: Shape,
    transpose: bool = False,
    float_as_event: bool = True,
) -> jax.Array | u.Quantity:
    """Multiply an event-driven CSR matrix by a dense vector.

    Units are split off ``data`` and ``v``, the unit-free product is
    computed by ``event_csrmv_p_call``, and the product of the two units is
    re-attached to the result.

    Args:
      data: stored values, shape ``(nse,)``.
      indices: column index of each stored value, shape ``(nse,)``.
      indptr: row pointers, shape ``(shape[0] + 1,)``.
      v: dense vector of length ``shape[0]`` if ``transpose`` else ``shape[1]``.
      shape: ``(rows, cols)`` of the sparse matrix.
      transpose: compute ``A.T @ v`` instead of ``A @ v``.
      float_as_event: treat non-zero float entries of ``v`` as unit events.

    Returns:
      Vector of length ``shape[1]`` if ``transpose`` else ``shape[0]``.
    """
    weight_mantissa, weight_unit = u.split_mantissa_unit(data)
    vector_mantissa, vector_unit = u.split_mantissa_unit(v)
    outputs = event_csrmv_p_call(
        weight_mantissa,
        indices,
        indptr,
        vector_mantissa,
        shape=shape,
        transpose=transpose,
        float_as_event=float_as_event,
    )
    product = outputs[0]
    return u.maybe_decimal(product * (weight_unit * vector_unit))
|
517
|
-
|
518
|
-
|
519
|
-
def _csr_matmat(
    data: jax.Array | u.Quantity,
    indices: jax.Array,
    indptr: jax.Array,
    B: jax.Array | u.Quantity,
    *,
    shape: Shape,
    transpose: bool = False,
    float_as_event: bool = True,
) -> jax.Array | u.Quantity:
    """Multiply an event-driven CSR matrix by a dense matrix.

    Units are split off ``data`` and ``B``, the unit-free product is
    computed by ``event_csrmm_p_call``, and the product of the two units is
    re-attached to the result.

    Args:
      data: stored values, shape ``(nse,)``.
      indices: column index of each stored value, shape ``(nse,)``.
      indptr: row pointers, shape ``(shape[0] + 1,)``.
      B: dense matrix with ``shape[0] if transpose else shape[1]`` rows.
      shape: ``(rows, cols)`` of the sparse matrix.
      transpose: compute ``A.T @ B`` instead of ``A @ B``.
      float_as_event: treat non-zero float entries of ``B`` as unit events.

    Returns:
      Matrix with ``shape[1] if transpose else shape[0]`` rows and as many
      columns as ``B``.
    """
    weight_mantissa, weight_unit = u.split_mantissa_unit(data)
    dense_mantissa, dense_unit = u.split_mantissa_unit(B)
    outputs = event_csrmm_p_call(
        weight_mantissa,
        indices,
        indptr,
        dense_mantissa,
        shape=shape,
        transpose=transpose,
        float_as_event=float_as_event,
    )
    product = outputs[0]
    return u.maybe_decimal(product * (weight_unit * dense_unit))
|
558
|
-
|
559
|
-
|
560
|
-
# Type alias for the compiled kernels returned by the kernel generators below.
Kernel = Callable
|
561
|
-
|
562
|
-
|
563
|
-
def event_csrmv_cpu_kernel_generator(
    float_as_event: bool,
    weight_info: jax.ShapeDtypeStruct,
    spike_info: jax.ShapeDtypeStruct,
    transpose: bool,
    **kwargs
) -> Kernel:
    """Build the Numba CPU kernel for the event-driven CSR matrix-vector product.

    One of twelve specialised kernels (2 x 2 x 3) is chosen from:

    * ``weight_info.size == 1``: a single weight shared by every stored entry
      (read once as ``weights[()]``) vs. one weight per entry (``weights[j]``);
    * ``transpose``: scatter each active input row into ``posts``
      (computes ``A.T @ v``) vs. gather along each row (computes ``A @ v``);
    * input kind: boolean spikes (``if v[i]``); floats treated as binary
      events when ``float_as_event`` is true (only ``v[i] != 0`` matters, the
      value is ignored); or floats whose values scale the weights.

    Every kernel has the signature ``mv(weights, indices, indptr, v, posts)``
    and writes its result into ``posts`` in place.

    Args:
      float_as_event: treat non-zero float inputs as unit events.
      weight_info: shape/dtype of the weight buffer.
      spike_info: shape/dtype of the input vector ``v``.
      transpose: whether to compute the transposed product.
      **kwargs: other op parameters; unused here.

    Returns:
      A ``numba.njit``-compiled kernel.
    """
    import numba  # pylint: disable=import-outside-toplevel

    # NOTE(review): in NumPy, ``a[()]`` on a 1-element 1-D array returns the
    # whole array, not a scalar. If a shape-(1,) weight buffer reaches the
    # shared-weight kernels below, ``weights[0]`` may be needed -- confirm.
    if weight_info.size == 1:
        # ---- single shared weight ------------------------------------------
        if transpose:
            # Scatter: every active input i adds the weight to its row's columns.
            if spike_info.dtype == jnp.bool_:
                @numba.njit(fastmath=True)
                def mv(weights, indices, indptr, v, posts):
                    posts[:] = 0.
                    w = weights[()]
                    for i in range(v.shape[0]):
                        if v[i]:
                            for j in range(indptr[i], indptr[i + 1]):
                                posts[indices[j]] += w

            elif float_as_event:
                @numba.njit(fastmath=True)
                def mv(weights, indices, indptr, v, posts):
                    posts[:] = 0.
                    w = weights[()]
                    for i in range(v.shape[0]):
                        if v[i] != 0.:
                            for j in range(indptr[i], indptr[i + 1]):
                                posts[indices[j]] += w

            else:
                @numba.njit(fastmath=True)
                def mv(weights, indices, indptr, v, posts):
                    posts[:] = 0.
                    w = weights[()]
                    for i in range(v.shape[0]):
                        sp = v[i]
                        if sp != 0.:
                            # Scale once per active input, not once per entry.
                            wsp = w * sp
                            for j in range(indptr[i], indptr[i + 1]):
                                posts[indices[j]] += wsp

        else:
            # Gather: each output row i sums the weight over its active columns.
            if spike_info.dtype == jnp.bool_:
                @numba.njit(fastmath=True)
                def mv(weights, indices, indptr, v, posts):
                    w = weights[()]
                    for i in range(indptr.shape[0] - 1):
                        r = 0.
                        for j in range(indptr[i], indptr[i + 1]):
                            if v[indices[j]]:
                                r += w
                        posts[i] = r

            elif float_as_event:
                @numba.njit(fastmath=True)
                def mv(weights, indices, indptr, v, posts):
                    w = weights[()]
                    for i in range(indptr.shape[0] - 1):
                        r = 0.
                        for j in range(indptr[i], indptr[i + 1]):
                            if v[indices[j]] != 0.:
                                r += w
                        posts[i] = r

            else:
                @numba.njit(fastmath=True)
                def mv(weights, indices, indptr, v, posts):
                    w = weights[()]
                    for i in range(indptr.shape[0] - 1):
                        r = 0.
                        for j in range(indptr[i], indptr[i + 1]):
                            c = v[indices[j]]
                            if c != 0.:
                                r += w * c
                        posts[i] = r

    else:
        # ---- one weight per stored entry ------------------------------------
        if transpose:
            # Scatter: posts[col] += A[i, col] for each active input i.
            if spike_info.dtype == jnp.bool_:
                @numba.njit(fastmath=True)
                def mv(weights, indices, indptr, v, posts):
                    posts[:] = 0.
                    for i in range(v.shape[0]):
                        if v[i]:
                            for j in range(indptr[i], indptr[i + 1]):
                                posts[indices[j]] += weights[j]

            elif float_as_event:
                @numba.njit(fastmath=True)
                def mv(weights, indices, indptr, v, posts):
                    posts[:] = 0.
                    for i in range(v.shape[0]):
                        if v[i] != 0.:
                            for j in range(indptr[i], indptr[i + 1]):
                                posts[indices[j]] += weights[j]

            else:
                @numba.njit(fastmath=True)
                def mv(weights, indices, indptr, v, posts):
                    posts[:] = 0.
                    for i in range(v.shape[0]):
                        sp = v[i]
                        if sp != 0.:
                            for j in range(indptr[i], indptr[i + 1]):
                                posts[indices[j]] += weights[j] * sp

        else:
            # Gather: posts[i] = sum of A[i, col] over active columns.
            if spike_info.dtype == jnp.bool_:
                @numba.njit(fastmath=True)
                def mv(weights, indices, indptr, v, posts):
                    for i in range(indptr.shape[0] - 1):
                        r = 0.
                        for j in range(indptr[i], indptr[i + 1]):
                            if v[indices[j]]:
                                r += weights[j]
                        posts[i] = r

            elif float_as_event:
                @numba.njit(fastmath=True)
                def mv(weights, indices, indptr, v, posts):
                    for i in range(indptr.shape[0] - 1):
                        r = 0.
                        for j in range(indptr[i], indptr[i + 1]):
                            if v[indices[j]] != 0.:
                                r += weights[j]
                        posts[i] = r

            else:
                @numba.njit(fastmath=True)
                def mv(weights, indices, indptr, v, posts):
                    for i in range(indptr.shape[0] - 1):
                        r = 0.
                        for j in range(indptr[i], indptr[i + 1]):
                            c = v[indices[j]]
                            if c != 0.:
                                r += weights[j] * c
                        posts[i] = r

    return mv
|
704
|
-
|
705
|
-
|
706
|
-
def event_csrmv_jvp_v(
    v_dot,
    data,
    indices,
    indptr,
    v,
    *,
    shape,
    transpose,
    **kwargs
):
    """JVP of ``event_csrmv_p`` with respect to the input vector.

    The tangent is ``A @ v_dot`` (or ``A.T @ v_dot``), computed with the plain
    (non-event) CSR product. Note that with boolean inputs, or with
    ``float_as_event=True``, the forward CPU kernels ignore the values of
    ``v``, so this tangent acts as a surrogate rather than the exact
    derivative.

    Returns:
      A one-element list holding the output tangent.
    """
    weights = data
    if jnp.size(data) == 1:
        # A single shared weight (supported by the CPU kernels) must be
        # expanded to one value per stored entry for the plain CSR product.
        weights = jnp.broadcast_to(data, indices.shape)
    return [
        csr_matvec(
            weights,
            indices,
            indptr,
            v_dot,
            shape=shape,
            transpose=transpose
        )
    ]
|
727
|
-
|
728
|
-
|
729
|
-
def event_csrmv_jvp_weights(
    data_dot,
    data,
    indices,
    indptr,
    v,
    *,
    shape,
    transpose,
    float_as_event,
    **kwargs
):
    """JVP of ``event_csrmv_p`` with respect to the weights.

    The product is linear in the weights, so the tangent is the same
    event-driven product evaluated with ``data_dot`` in place of ``data``.
    """
    static_params = dict(
        shape=shape,
        transpose=transpose,
        float_as_event=float_as_event,
    )
    return event_csrmv_p_call(data_dot, indices, indptr, v, **static_params)
|
750
|
-
|
751
|
-
|
752
|
-
def event_csrmv_transpose_rule(
    ct,
    data,
    indices,
    indptr,
    events,
    *,
    shape,
    float_as_event,
    transpose,
    **kwargs
):
    """Transpose (VJP) rule for ``event_csrmv_p``.

    The product is linear in whichever of ``data`` / ``events`` is the
    undefined primal; that operand receives the cotangent and the other
    operands are returned unchanged.

    Args:
      ct: list holding the single output cotangent (may be ``ad.Zero``).
      data: weights (shared scalar or one value per stored entry).
      indices, indptr: CSR structure; must be known (not differentiated).
      events: the input vector.
      shape, float_as_event, transpose: static parameters of the primitive.

    Returns:
      ``(data, indices, indptr, events)`` with the undefined operand replaced
      by its cotangent.

    Raises:
      ValueError: if asked to transpose with respect to ``indices``/``indptr``.
    """
    if ad.is_undefined_primal(indices) or ad.is_undefined_primal(indptr):
        raise ValueError("Cannot transpose with respect to sparse indices.")

    # The primitive has multiple results; only the first one exists.
    ct = ct[0]

    if ad.is_undefined_primal(events):
        if type(ct) is ad.Zero:
            return data, indices, indptr, ad.Zero(events.aval)
        weights = data
        if jnp.size(data) == 1:
            # Shared scalar weight -> one value per stored entry.
            weights = jnp.broadcast_to(data, indices.shape)
        # d(A @ v)/dv contracted with ct is A.T @ ct (and vice versa).
        ct_events = csr_matvec(
            weights,
            indices,
            indptr,
            ct,
            shape=shape,
            transpose=not transpose
        )
        return data, indices, indptr, ct_events

    if type(ct) is ad.Zero:
        return ad.Zero(data.aval), indices, indptr, events

    if int(np.prod(data.aval.shape)) == 1:  # scalar weight
        # The output is linear in the shared weight: y = w * (A_ones @ v).
        unit_response = event_csrmv_p_call(
            jnp.ones(1, dtype=data.aval.dtype),
            indices,
            indptr,
            events,
            shape=shape,
            transpose=transpose,
            float_as_event=float_as_event,
        )[0]
        ct_values = jnp.reshape(jnp.inner(ct, unit_response), data.aval.shape)
    else:  # heterogeneous values
        row, col = csr_to_coo(indices, indptr)
        if float_as_event:
            # The forward kernels only test ``events != 0`` in this mode, so
            # the weight gradient must use the same binary event values.
            event_values = (events != 0).astype(ct.dtype)
        else:
            event_values = events
        if transpose:
            ct_values = event_values[row] * ct[col]
        else:
            ct_values = event_values[col] * ct[row]
    return ct_values, indices, indptr, events
|
800
|
-
|
801
|
-
|
802
|
-
def event_csrmv_batching(args, axes, **kwargs):
    """Batching rule for ``event_csrmv_p``.

    Only the input vector (operand 3) may be batched, along axis 0 or 1 of a
    2-D input. The batch is evaluated as one event-driven matrix-matrix
    product whose columns are the individual vectors, so the result is
    batched along axis 1.

    Raises:
      NotImplementedError: for any other batching configuration.
    """
    batch_axes = tuple(axes)
    if batch_axes not in ((None, None, None, 0), (None, None, None, 1)):
        raise NotImplementedError(f"Batching axes {axes} not implemented for event-driven CSR matrix-vector product.")

    batch_axis = batch_axes[3]
    assert args[3].ndim == 2, f'Batching axis {batch_axis} requires 2D input.'
    # Put the batch dimension last so each column is one input vector.
    dense = args[3].T if batch_axis == 0 else args[3]
    r = event_csrmm_p_call(
        args[0],
        args[1],
        args[2],
        dense,
        shape=kwargs['shape'],
        transpose=kwargs['transpose'],
        float_as_event=kwargs['float_as_event']
    )
    return r, [1]
|
831
|
-
|
832
|
-
|
833
|
-
# Event-driven CSR matrix-vector primitive. Its CPU implementation is built by
# ``event_csrmv_cpu_kernel_generator``; ``event_csrmv_p_call`` only binds it
# when the default JAX backend is CPU.
event_csrmv_p = XLACustomOp(
    'event_csrmv',
    cpu_kernel_or_generator=event_csrmv_cpu_kernel_generator,
)
# JVP rules appear to be given per operand in the order
# (data, indices, indptr, v); ``None`` presumably marks the integer index
# buffers as non-differentiable.
event_csrmv_p.defjvp(event_csrmv_jvp_weights, None, None, event_csrmv_jvp_v)
event_csrmv_p.def_transpose_rule(event_csrmv_transpose_rule)
event_csrmv_p.def_batching_rule(event_csrmv_batching)
|
840
|
-
|
841
|
-
|
842
|
-
def event_csrmv_p_call(
    weights,
    indices,
    indptr,
    v,
    *,
    shape: Shape,
    transpose: bool,
    float_as_event: bool,
):
    """Dispatch the event-driven CSR matrix-vector product.

    On the CPU backend this binds ``event_csrmv_p`` (Numba kernels). On other
    backends it falls back to the plain CSR product.

    Args:
      weights: unit-free stored values (or a single shared value).
      indices, indptr: CSR structure.
      v: unit-free input vector.
      shape: ``(rows, cols)`` of the sparse matrix.
      transpose: compute ``A.T @ v`` instead of ``A @ v``.
      float_as_event: treat non-zero float entries of ``v`` as unit events.

    Returns:
      A one-element list holding the output vector of length ``shape[1]`` if
      ``transpose`` else ``shape[0]``.
    """
    if jax.default_backend() != 'cpu':
        # NOTE(review): this fallback multiplies by the values of ``v`` and
        # ignores ``float_as_event``, whereas the CPU kernels treat boolean
        # inputs (and floats when ``float_as_event`` is true) as binary
        # events, so results can differ between backends -- confirm intent.
        if jnp.size(weights) == 1:
            # The CPU kernels accept a single shared weight; the plain CSR
            # product needs one value per stored entry.
            weights = jnp.broadcast_to(weights, indices.shape)
        return [
            csr_matvec(
                weights,
                indices,
                indptr,
                v,
                shape=shape,
                transpose=transpose
            )
        ]

    out_length = shape[1] if transpose else shape[0]
    return event_csrmv_p(
        weights,
        indices,
        indptr,
        v,
        outs=[jax.ShapeDtypeStruct([out_length], weights.dtype)],
        float_as_event=float_as_event,
        shape=shape,
        transpose=transpose,
        weight_info=jax.ShapeDtypeStruct(weights.shape, weights.dtype),
        spike_info=jax.ShapeDtypeStruct(v.shape, v.dtype),
    )
|
881
|
-
|
882
|
-
|
883
|
-
def event_csrmm_batching(args, axes, **kwargs):
    """Batching (vmap) rule for the event-driven CSR matrix-matrix primitive.

    Only the dense right-hand operand ``args[3]`` (the matrix ``B``) may carry a
    batch dimension; the CSR operands ``args[0..2]`` must be unbatched. The 3-D
    batched ``B`` is folded into a single 2-D matrix whose columns enumerate
    (batch, column) pairs. One unbatched ``event_csrmm_p_call`` is then issued,
    and its 2-D result is unfolded back to 3-D.

    Args:
      args: Operands ``(weights, indices, indptr, B)``.
      axes: Batch axis of each operand. Must be ``(None, None, None, d)`` with
        ``d`` in ``{0, 1, 2}``.
      **kwargs: Primitive parameters; ``shape``, ``transpose`` and
        ``float_as_event`` are forwarded to ``event_csrmm_p_call``.

    Returns:
      ``([out], [out_axis])``, where ``out`` is 3-D. ``out_axis`` is the batch
      dimension's position in ``out``: 1 when ``d`` is 0 or 1, and 2 when ``d``
      is 2.

    Raises:
      NotImplementedError: For any other combination of batch axes.
      ValueError: If the batched ``B`` is not 3-D.
    """
    supported = {(None, None, None, d): d for d in (0, 1, 2)}
    batch_axis = supported.get(tuple(axes))
    if batch_axis is None:
        raise NotImplementedError(
            f"Batching axes {axes} not implemented for event-driven CSR matrix-matrix product."
        )

    mat = args[3]
    # Explicit check rather than `assert`, so it still fires under `python -O`.
    if mat.ndim != 3:
        raise ValueError(f'Batching axis {batch_axis} requires 3D input.')

    # Fold the batch dimension into the column dimension. The rows of the
    # folded matrix must stay the contracted dimension of the sparse product.
    if batch_axis == 0:
        batch_size, m, n = mat.shape
        # (batch, m, n) -> (m, batch, n) -> (m, batch * n)
        folded = jnp.transpose(mat, (1, 0, 2)).reshape(m, batch_size * n)
    elif batch_axis == 1:
        m, batch_size, n = mat.shape
        folded = mat.reshape(m, batch_size * n)
    else:
        m, n, batch_size = mat.shape
        # Columns are ordered (n, batch) after this reshape.
        folded = mat.reshape(m, batch_size * n)

    r = event_csrmm_p_call(
        args[0],
        args[1],
        args[2],
        folded,
        shape=kwargs['shape'],
        transpose=kwargs['transpose'],
        float_as_event=kwargs['float_as_event'],
    )
    out = r[0]
    if batch_axis == 2:
        # Undo the (n, batch) column ordering: batch is the trailing axis.
        return [jnp.reshape(out, [out.shape[0], n, batch_size])], [2]
    # Columns were ordered (batch, n): batch ends up at axis 1 of the output.
    return [jnp.reshape(out, [out.shape[0], batch_size, n])], [1]
|
934
|
-
|
935
|
-
|
936
|
-
def event_csrmm_cpu_kernel_generator(
    float_as_event: bool,
    weight_info: jax.ShapeDtypeStruct,
    spike_info: jax.ShapeDtypeStruct,
    transpose: bool,
    **kwargs
) -> Kernel:
    """Build a numba CPU kernel for the event-driven CSR matrix-matrix product.

    The returned kernel has the signature
    ``mv(weights, indices, indptr, B, posts)`` and writes its result into
    ``posts`` in place. One of twelve variants is selected from three static
    choices:

    * Weights: homogeneous (``weight_info.size == 1``; every stored entry uses
      the single value ``weights[()]``) or per-entry (``weights[j]``).
    * Direction: ``transpose`` gives ``posts = csr.T @ B``, computed by
      scattering over the rows of ``B``. Otherwise ``posts = csr @ B``,
      computed by gathering per CSR row.
    * Interpretation of ``B``:
      - boolean ``B``: an entry contributes the weight when True;
      - ``float_as_event``: any non-zero entry counts as True;
      - otherwise: an entry contributes ``weight * B[i, k]``.

    Args:
      float_as_event: Treat non-zero float entries of ``B`` as unit events.
        Ignored when ``B`` is boolean, because that branch is checked first.
      weight_info: Shape/dtype of the ``weights`` operand; only ``size`` is read.
      spike_info: Shape/dtype of ``B``; only ``dtype`` is read.
      transpose: Whether to multiply by the transposed CSR matrix.
      **kwargs: Any further primitive parameters; unused here.

    Returns:
      The jitted kernel ``mv``.
    """
    import numba  # pylint: disable=import-outside-toplevel

    if weight_info.size == 1:
        # Homogeneous weight: one value shared by all stored entries.
        # NOTE(review): `weights[()]` yields a scalar only for a 0-d array; a
        # shape-(1,) weight array would come back as an array — confirm
        # homogeneous weights are always 0-d.
        if transpose:
            # csr.T @ B
            # Transposed kernels scatter-accumulate with `+=`, so `posts` is
            # zeroed first. With `parallel=False`, `numba.prange` behaves like
            # a plain serial `range`.

            if spike_info.dtype == jnp.bool_:
                @numba.njit(fastmath=True, parallel=False)
                def mv(weights, indices, indptr, B, posts):
                    posts[:] = 0.
                    w = weights[()]
                    for k in numba.prange(B.shape[1]):
                        for i in range(B.shape[0]):
                            if B[i, k]:
                                # Row i of the CSR matrix scatters into its columns.
                                for j in range(indptr[i], indptr[i + 1]):
                                    posts[indices[j], k] += w

            elif float_as_event:
                @numba.njit(fastmath=True, parallel=False)
                def mv(weights, indices, indptr, B, posts):
                    posts[:] = 0.
                    # Rebind B to a boolean mask of its non-zero entries.
                    B = B != 0.
                    w = weights[()]
                    for k in numba.prange(B.shape[1]):
                        for i in range(B.shape[0]):
                            if B[i, k]:
                                for j in range(indptr[i], indptr[i + 1]):
                                    posts[indices[j], k] += w

            else:
                @numba.njit(fastmath=True, parallel=False)
                def mv(weights, indices, indptr, B, posts):
                    posts[:] = 0.
                    w = weights[()]
                    for k in numba.prange(B.shape[1]):
                        for i in range(B.shape[0]):
                            sp = B[i, k]
                            # Zero entries are skipped (event-driven sparsity).
                            if sp != 0.:
                                wsp = w * sp
                                for j in range(indptr[i], indptr[i + 1]):
                                    posts[indices[j], k] += wsp

        else:
            # csr @ B
            # Gather kernels assign every row of `posts`, so no zeroing is needed.
            if spike_info.dtype == jnp.bool_:
                @numba.njit(fastmath=True)
                def mv(weights, indices, indptr, B, posts):
                    w = weights[()]
                    for i in range(indptr.shape[0] - 1):
                        # Per-row accumulator in the weights' dtype.
                        r = np.zeros(B.shape[1], dtype=weights.dtype)
                        for j in range(indptr[i], indptr[i + 1]):
                            index = indices[j]
                            for k in range(B.shape[1]):
                                if B[index, k]:
                                    r[k] += w
                        posts[i] = r

            elif float_as_event:
                @numba.njit(fastmath=True)
                def mv(weights, indices, indptr, B, posts):
                    w = weights[()]
                    # Rebind B to a boolean mask of its non-zero entries.
                    B = B != 0.
                    for i in range(indptr.shape[0] - 1):
                        r = np.zeros(B.shape[1], dtype=weights.dtype)
                        for j in range(indptr[i], indptr[i + 1]):
                            index = indices[j]
                            for k in range(B.shape[1]):
                                if B[index, k]:
                                    r[k] += w
                        posts[i] = r

            else:
                @numba.njit(fastmath=True)
                def mv(weights, indices, indptr, B, posts):
                    w = weights[()]
                    for i in range(indptr.shape[0] - 1):
                        for k in range(B.shape[1]):
                            # Scalar float accumulator; cast on store into posts.
                            r = 0.
                            for j in range(indptr[i], indptr[i + 1]):
                                c = B[indices[j], k]
                                if c != 0.:
                                    r += w * c
                            posts[i, k] = r

    else:
        # Heterogeneous weights: entry j of the CSR matrix uses weights[j].
        if transpose:
            # csr.T @ B
            # Scatter-accumulate kernels: `posts` is zeroed first.

            if spike_info.dtype == jnp.bool_:
                @numba.njit(fastmath=True, parallel=False)
                def mv(weights, indices, indptr, B, posts):
                    posts[:] = 0.
                    for k in numba.prange(B.shape[1]):
                        for i in range(B.shape[0]):
                            if B[i, k]:
                                for j in range(indptr[i], indptr[i + 1]):
                                    posts[indices[j], k] += weights[j]

            elif float_as_event:
                @numba.njit(fastmath=True, parallel=False)
                def mv(weights, indices, indptr, B, posts):
                    posts[:] = 0.
                    # Rebind B to a boolean mask of its non-zero entries.
                    B = B != 0.
                    for k in numba.prange(B.shape[1]):
                        for i in range(B.shape[0]):
                            if B[i, k]:
                                for j in range(indptr[i], indptr[i + 1]):
                                    posts[indices[j], k] += weights[j]

            else:
                @numba.njit(fastmath=True, parallel=False)
                def mv(weights, indices, indptr, B, posts):
                    posts[:] = 0.
                    for k in numba.prange(B.shape[1]):
                        for i in range(B.shape[0]):
                            sp = B[i, k]
                            if sp != 0.:
                                for j in range(indptr[i], indptr[i + 1]):
                                    posts[indices[j], k] += weights[j] * sp

        else:
            # csr @ B
            # Gather kernels assign every (row, column) of `posts`.

            if spike_info.dtype == jnp.bool_:
                @numba.njit(fastmath=True)
                def mv(weights, indices, indptr, B, posts):
                    for i in range(indptr.shape[0] - 1):
                        for k in range(B.shape[1]):
                            r = 0.
                            for j in range(indptr[i], indptr[i + 1]):
                                if B[indices[j], k]:
                                    r += weights[j]
                            posts[i, k] = r

            elif float_as_event:
                @numba.njit(fastmath=True)
                def mv(weights, indices, indptr, B, posts):
                    # Rebind B to a boolean mask of its non-zero entries.
                    B = B != 0.
                    for i in range(indptr.shape[0] - 1):
                        for k in range(B.shape[1]):
                            r = 0.
                            for j in range(indptr[i], indptr[i + 1]):
                                if B[indices[j], k]:
                                    r += weights[j]
                            posts[i, k] = r

            else:
                @numba.njit(fastmath=True)
                def mv(weights, indices, indptr, B, posts):
                    for i in range(indptr.shape[0] - 1):
                        for k in range(B.shape[1]):
                            r = 0.
                            for j in range(indptr[i], indptr[i + 1]):
                                c = B[indices[j], k]
                                if c != 0.:
                                    r += weights[j] * c
                            posts[i, k] = r

    return mv
|
1102
|
-
|
1103
|
-
|
1104
|
-
# Custom XLA primitive for the event-driven CSR matrix-matrix product, backed
# by a numba CPU kernel generator. Only a batching (vmap) rule is registered
# here; no JVP or transpose rule is registered for this primitive alongside it.
# NOTE(review): differentiating through this primitive may therefore be
# unsupported — confirm.
event_csrmm_p = XLACustomOp(
    'event_csrmm',
    cpu_kernel_or_generator=event_csrmm_cpu_kernel_generator,
)
event_csrmm_p.def_batching_rule(event_csrmm_batching)
|
1109
|
-
|
1110
|
-
|
1111
|
-
def event_csrmm_p_call(
    weights,
    indices,
    indptr,
    B,
    *,
    shape: Shape,
    transpose: bool,
    float_as_event: bool,
):
    """Compute an event-driven CSR matrix-matrix product on the active backend.

    On the CPU backend the custom ``event_csrmm_p`` primitive is used. On any
    other backend the call falls back to ``csr_matmat``, which does not take
    ``float_as_event``.

    Args:
      weights: Stored values of the CSR matrix.
      indices: CSR column indices.
      indptr: CSR row pointers.
      B: Dense 2-D right-hand matrix (boolean or float events).
      shape: ``(n_rows, n_cols)`` of the sparse matrix.
      transpose: If True, compute ``csr.T @ B`` (``shape[1]`` output rows).
        Otherwise compute ``csr @ B`` (``shape[0]`` output rows).
      float_as_event: Forwarded to the CPU primitive only.

    Returns:
      Off-CPU, a one-element list holding the ``(rows, B.shape[1])`` result. On
      CPU, whatever ``event_csrmm_p`` returns for its single declared output.
    """
    if jax.default_backend() != 'cpu':
        # No custom kernel off-CPU: use the generic CSR mat-mat product instead.
        return [csr_matmat(weights, indices, indptr, B, shape=shape, transpose=transpose)]

    n_out_rows = shape[1] if transpose else shape[0]
    result_spec = jax.ShapeDtypeStruct([n_out_rows, B.shape[1]], weights.dtype)
    return event_csrmm_p(
        weights,
        indices,
        indptr,
        B,
        outs=[result_spec],
        shape=shape,
        transpose=transpose,
        float_as_event=float_as_event,
        weight_info=jax.ShapeDtypeStruct(weights.shape, weights.dtype),
        spike_info=jax.ShapeDtypeStruct(B.shape, B.dtype),
    )
|