brainstate 0.1.0.post20250120__py2.py3-none-any.whl → 0.1.0.post20250127__py2.py3-none-any.whl

This diff shows the contents of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as published in that registry.
Files changed (29)
  1. brainstate/__init__.py +1 -2
  2. brainstate/augment/__init__.py +10 -20
  3. brainstate/compile/__init__.py +18 -37
  4. brainstate/compile/_make_jaxpr.py +9 -2
  5. brainstate/compile/_make_jaxpr_test.py +10 -6
  6. brainstate/compile/_progress_bar.py +49 -6
  7. brainstate/compile/_unvmap.py +3 -3
  8. brainstate/graph/__init__.py +12 -12
  9. brainstate/nn/_dyn_impl/_inputs.py +4 -2
  10. brainstate/nn/_elementwise/_dropout_test.py +1 -1
  11. {brainstate-0.1.0.post20250120.dist-info → brainstate-0.1.0.post20250127.dist-info}/METADATA +1 -1
  12. {brainstate-0.1.0.post20250120.dist-info → brainstate-0.1.0.post20250127.dist-info}/RECORD +15 -29
  13. brainstate/event/__init__.py +0 -27
  14. brainstate/event/_csr.py +0 -1149
  15. brainstate/event/_csr_benchmark.py +0 -14
  16. brainstate/event/_csr_mv.py +0 -303
  17. brainstate/event/_csr_test.py +0 -277
  18. brainstate/event/_fixedprob_mv.py +0 -730
  19. brainstate/event/_fixedprob_mv_benchmark.py +0 -128
  20. brainstate/event/_fixedprob_mv_test.py +0 -132
  21. brainstate/event/_linear_mv.py +0 -359
  22. brainstate/event/_linear_mv_benckmark.py +0 -82
  23. brainstate/event/_linear_mv_test.py +0 -117
  24. brainstate/event/_misc.py +0 -34
  25. brainstate/event/_xla_custom_op.py +0 -317
  26. brainstate/event/_xla_custom_op_test.py +0 -55
  27. {brainstate-0.1.0.post20250120.dist-info → brainstate-0.1.0.post20250127.dist-info}/LICENSE +0 -0
  28. {brainstate-0.1.0.post20250120.dist-info → brainstate-0.1.0.post20250127.dist-info}/WHEEL +0 -0
  29. {brainstate-0.1.0.post20250120.dist-info → brainstate-0.1.0.post20250127.dist-info}/top_level.txt +0 -0
brainstate/event/_csr.py DELETED
@@ -1,1149 +0,0 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
-
17
- from __future__ import annotations
18
-
19
- import operator
20
- from typing import Callable
21
-
22
- import brainunit as u
23
- import jax
24
- import jax.numpy as jnp
25
- import numpy as np
26
- from brainunit.sparse._csr import (
27
- _csr_matvec as csr_matvec,
28
- _csr_matmat as csr_matmat,
29
- _csr_to_coo as csr_to_coo
30
- )
31
- from jax.experimental.sparse import JAXSparse
32
- from jax.interpreters import ad
33
-
34
- from brainstate.typing import Shape
35
- from ._xla_custom_op import XLACustomOp
36
-
37
- __all__ = [
38
- 'CSR',
39
- 'CSC',
40
- ]
41
-
42
-
43
- @jax.tree_util.register_pytree_node_class
44
- class CSR(u.sparse.SparseMatrix):
45
- """
46
- Event-driven and Unit-aware CSR matrix.
47
- """
48
- data: jax.Array | u.Quantity
49
- indices: jax.Array
50
- indptr: jax.Array
51
- shape: tuple[int, int]
52
- nse = property(lambda self: self.data.size)
53
- dtype = property(lambda self: self.data.dtype)
54
- _bufs = property(lambda self: (self.data, self.indices, self.indptr))
55
-
56
- def __init__(self, args, *, shape):
57
- self.data, self.indices, self.indptr = map(u.math.asarray, args)
58
- super().__init__(args, shape=shape)
59
-
60
- @classmethod
61
- def fromdense(cls, mat, *, nse=None, index_dtype=np.int32):
62
- if nse is None:
63
- nse = (u.get_mantissa(mat) != 0).sum()
64
- return u.sparse.csr_fromdense(mat, nse=nse, index_dtype=index_dtype)
65
-
66
- def with_data(self, data: jax.Array | u.Quantity) -> CSR:
67
- assert data.shape == self.data.shape
68
- assert data.dtype == self.data.dtype
69
- assert u.get_unit(data) == u.get_unit(self.data)
70
- return CSR((data, self.indices, self.indptr), shape=self.shape)
71
-
72
- def todense(self):
73
- return u.sparse.csr_todense(self)
74
-
75
- def transpose(self, axes=None):
76
- assert axes is None, "transpose does not support axes argument."
77
- return CSC((self.data, self.indices, self.indptr), shape=self.shape[::-1])
78
-
79
- def __abs__(self):
80
- return CSR((abs(self.data), self.indices, self.indptr), shape=self.shape)
81
-
82
- def __neg__(self):
83
- return CSR((-self.data, self.indices, self.indptr), shape=self.shape)
84
-
85
- def __pos__(self):
86
- return CSR((self.data.__pos__(), self.indices, self.indptr), shape=self.shape)
87
-
88
- def _binary_op(self, other, op):
89
- if isinstance(other, CSR):
90
- if id(other.indices) == id(self.indices) and id(other.indptr) == id(self.indptr):
91
- return CSR(
92
- (op(self.data, other.data),
93
- self.indices,
94
- self.indptr),
95
- shape=self.shape
96
- )
97
- if isinstance(other, JAXSparse):
98
- raise NotImplementedError(f"binary operation {op} between two sparse objects.")
99
-
100
- other = u.math.asarray(other)
101
- if other.size == 1:
102
- return CSR(
103
- (op(self.data, other), self.indices, self.indptr),
104
- shape=self.shape
105
- )
106
-
107
- elif other.ndim == 2 and other.shape == self.shape:
108
- rows, cols = csr_to_coo(self.indices, self.indptr)
109
- other = other[rows, cols]
110
- return CSR(
111
- (op(self.data, other),
112
- self.indices,
113
- self.indptr),
114
- shape=self.shape
115
- )
116
-
117
- else:
118
- raise NotImplementedError(f"mul with object of shape {other.shape}")
119
-
120
- def _binary_rop(self, other, op):
121
- if isinstance(other, CSR):
122
- if id(other.indices) == id(self.indices) and id(other.indptr) == id(self.indptr):
123
- return CSR(
124
- (op(other.data, self.data),
125
- self.indices,
126
- self.indptr),
127
- shape=self.shape
128
- )
129
- if isinstance(other, JAXSparse):
130
- raise NotImplementedError(f"binary operation {op} between two sparse objects.")
131
-
132
- other = u.math.asarray(other)
133
- if other.size == 1:
134
- return CSR(
135
- (op(other, self.data),
136
- self.indices,
137
- self.indptr),
138
- shape=self.shape
139
- )
140
- elif other.ndim == 2 and other.shape == self.shape:
141
- rows, cols = csr_to_coo(self.indices, self.indptr)
142
- other = other[rows, cols]
143
- return CSR(
144
- (op(other, self.data),
145
- self.indices,
146
- self.indptr),
147
- shape=self.shape
148
- )
149
- else:
150
- raise NotImplementedError(f"mul with object of shape {other.shape}")
151
-
152
- def __mul__(self, other: jax.Array | u.Quantity) -> CSR:
153
- return self._binary_op(other, operator.mul)
154
-
155
- def __rmul__(self, other: jax.Array | u.Quantity) -> CSR:
156
- return self._binary_rop(other, operator.mul)
157
-
158
- def __div__(self, other: jax.Array | u.Quantity) -> CSR:
159
- return self._binary_op(other, operator.truediv)
160
-
161
- def __rdiv__(self, other: jax.Array | u.Quantity) -> CSR:
162
- return self._binary_rop(other, operator.truediv)
163
-
164
- def __truediv__(self, other) -> CSR:
165
- return self.__div__(other)
166
-
167
- def __rtruediv__(self, other) -> CSR:
168
- return self.__rdiv__(other)
169
-
170
- def __add__(self, other) -> CSR:
171
- return self._binary_op(other, operator.add)
172
-
173
- def __radd__(self, other) -> CSR:
174
- return self._binary_rop(other, operator.add)
175
-
176
- def __sub__(self, other) -> CSR:
177
- return self._binary_op(other, operator.sub)
178
-
179
- def __rsub__(self, other) -> CSR:
180
- return self._binary_rop(other, operator.sub)
181
-
182
- def __mod__(self, other) -> CSR:
183
- return self._binary_op(other, operator.mod)
184
-
185
- def __rmod__(self, other) -> CSR:
186
- return self._binary_rop(other, operator.mod)
187
-
188
- def __matmul__(self, other):
189
- # csr @ other
190
- if isinstance(other, JAXSparse):
191
- raise NotImplementedError("matmul between two sparse objects.")
192
- other = u.math.asarray(other)
193
- data = self.data
194
- # data, other = u.math.promote_dtypes(self.data, other)
195
- if other.ndim == 1:
196
- return _csr_matvec(
197
- data,
198
- self.indices,
199
- self.indptr,
200
- other,
201
- shape=self.shape
202
- )
203
- elif other.ndim == 2:
204
- return _csr_matmat(
205
- data,
206
- self.indices,
207
- self.indptr,
208
- other,
209
- shape=self.shape
210
- )
211
- else:
212
- raise NotImplementedError(f"matmul with object of shape {other.shape}")
213
-
214
- def __rmatmul__(self, other):
215
- # other @ csr
216
- if isinstance(other, JAXSparse):
217
- raise NotImplementedError("matmul between two sparse objects.")
218
- other = u.math.asarray(other)
219
- data = self.data
220
- # data, other = u.math.promote_dtypes(self.data, other)
221
- if other.ndim == 1:
222
- return _csr_matvec(
223
- data,
224
- self.indices,
225
- self.indptr,
226
- other,
227
- shape=self.shape,
228
- transpose=True
229
- )
230
- elif other.ndim == 2:
231
- other = other.T
232
- r = _csr_matmat(
233
- data,
234
- self.indices,
235
- self.indptr,
236
- other,
237
- shape=self.shape,
238
- transpose=True
239
- )
240
- return r.T
241
- else:
242
- raise NotImplementedError(f"matmul with object of shape {other.shape}")
243
-
244
- def tree_flatten(self):
245
- return (self.data,), {"shape": self.shape, "indices": self.indices, "indptr": self.indptr}
246
-
247
- @classmethod
248
- def tree_unflatten(cls, aux_data, children):
249
- obj = object.__new__(cls)
250
- obj.data, = children
251
- if aux_data.keys() != {'shape', 'indices', 'indptr'}:
252
- raise ValueError(f"CSR.tree_unflatten: invalid {aux_data=}")
253
- obj.__dict__.update(**aux_data)
254
- return obj
255
-
256
-
257
- @jax.tree_util.register_pytree_node_class
258
- class CSC(u.sparse.SparseMatrix):
259
- """
260
- Event-driven and Unit-aware CSC matrix.
261
- """
262
- data: jax.Array
263
- indices: jax.Array
264
- indptr: jax.Array
265
- shape: tuple[int, int]
266
- nse = property(lambda self: self.data.size)
267
- dtype = property(lambda self: self.data.dtype)
268
-
269
- def __init__(self, args, *, shape):
270
- self.data, self.indices, self.indptr = map(u.math.asarray, args)
271
- super().__init__(args, shape=shape)
272
-
273
- @classmethod
274
- def fromdense(cls, mat, *, nse=None, index_dtype=np.int32):
275
- if nse is None:
276
- nse = (u.get_mantissa(mat) != 0).sum()
277
- return u.sparse.csr_fromdense(mat.T, nse=nse, index_dtype=index_dtype).T
278
-
279
- @classmethod
280
- def _empty(cls, shape, *, dtype=None, index_dtype='int32'):
281
- """Create an empty CSC instance. Public method is sparse.empty()."""
282
- shape = tuple(shape)
283
- if len(shape) != 2:
284
- raise ValueError(f"CSC must have ndim=2; got {shape=}")
285
- data = jnp.empty(0, dtype)
286
- indices = jnp.empty(0, index_dtype)
287
- indptr = jnp.zeros(shape[1] + 1, index_dtype)
288
- return cls((data, indices, indptr), shape=shape)
289
-
290
- @classmethod
291
- def _eye(cls, N, M, k, *, dtype=None, index_dtype='int32'):
292
- return CSR._eye(M, N, -k, dtype=dtype, index_dtype=index_dtype).T
293
-
294
- def with_data(self, data: jax.Array | u.Quantity) -> CSC:
295
- assert data.shape == self.data.shape
296
- assert data.dtype == self.data.dtype
297
- assert u.get_unit(data) == u.get_unit(self.data)
298
- return CSC((data, self.indices, self.indptr), shape=self.shape)
299
-
300
- def todense(self):
301
- return u.sparse.csr_todense(self.T).T
302
-
303
- def transpose(self, axes=None):
304
- assert axes is None
305
- return CSR((self.data, self.indices, self.indptr), shape=self.shape[::-1])
306
-
307
- def __abs__(self):
308
- return CSC((abs(self.data), self.indices, self.indptr), shape=self.shape)
309
-
310
- def __neg__(self):
311
- return CSC((-self.data, self.indices, self.indptr), shape=self.shape)
312
-
313
- def __pos__(self):
314
- return CSC((self.data.__pos__(), self.indices, self.indptr), shape=self.shape)
315
-
316
- def _binary_op(self, other, op):
317
- if isinstance(other, CSC):
318
- if id(other.indices) == id(self.indices) and id(other.indptr) == id(self.indptr):
319
- return CSC(
320
- (op(self.data, other.data),
321
- self.indices,
322
- self.indptr),
323
- shape=self.shape
324
- )
325
- if isinstance(other, JAXSparse):
326
- raise NotImplementedError(f"binary operation {op} between two sparse objects.")
327
-
328
- other = u.math.asarray(other)
329
- if other.size == 1:
330
- return CSC(
331
- (op(self.data, other),
332
- self.indices,
333
- self.indptr),
334
- shape=self.shape
335
- )
336
- elif other.ndim == 2 and other.shape == self.shape:
337
- cols, rows = csr_to_coo(self.indices, self.indptr)
338
- other = other[rows, cols]
339
- return CSC(
340
- (op(self.data, other),
341
- self.indices,
342
- self.indptr),
343
- shape=self.shape
344
- )
345
- else:
346
- raise NotImplementedError(f"mul with object of shape {other.shape}")
347
-
348
- def _binary_rop(self, other, op):
349
- if isinstance(other, CSC):
350
- if id(other.indices) == id(self.indices) and id(other.indptr) == id(self.indptr):
351
- return CSC(
352
- (op(other.data, self.data),
353
- self.indices,
354
- self.indptr),
355
- shape=self.shape
356
- )
357
- if isinstance(other, JAXSparse):
358
- raise NotImplementedError(f"binary operation {op} between two sparse objects.")
359
-
360
- other = u.math.asarray(other)
361
- if other.size == 1:
362
- return CSC(
363
- (op(other, self.data),
364
- self.indices,
365
- self.indptr),
366
- shape=self.shape
367
- )
368
- elif other.ndim == 2 and other.shape == self.shape:
369
- cols, rows = csr_to_coo(self.indices, self.indptr)
370
- other = other[rows, cols]
371
- return CSC(
372
- (op(other, self.data),
373
- self.indices,
374
- self.indptr),
375
- shape=self.shape
376
- )
377
- else:
378
- raise NotImplementedError(f"mul with object of shape {other.shape}")
379
-
380
- def __mul__(self, other: jax.Array | u.Quantity) -> 'CSC':
381
- return self._binary_op(other, operator.mul)
382
-
383
- def __rmul__(self, other: jax.Array | u.Quantity) -> 'CSC':
384
- return self._binary_rop(other, operator.mul)
385
-
386
- def __div__(self, other: jax.Array | u.Quantity) -> CSC:
387
- return self._binary_op(other, operator.truediv)
388
-
389
- def __rdiv__(self, other: jax.Array | u.Quantity) -> CSC:
390
- return self._binary_rop(other, operator.truediv)
391
-
392
- def __truediv__(self, other) -> CSC:
393
- return self.__div__(other)
394
-
395
- def __rtruediv__(self, other) -> CSC:
396
- return self.__rdiv__(other)
397
-
398
- def __add__(self, other) -> CSC:
399
- return self._binary_op(other, operator.add)
400
-
401
- def __radd__(self, other) -> CSC:
402
- return self._binary_rop(other, operator.add)
403
-
404
- def __sub__(self, other) -> CSC:
405
- return self._binary_op(other, operator.sub)
406
-
407
- def __rsub__(self, other) -> CSC:
408
- return self._binary_rop(other, operator.sub)
409
-
410
- def __mod__(self, other) -> CSC:
411
- return self._binary_op(other, operator.mod)
412
-
413
- def __rmod__(self, other) -> CSC:
414
- return self._binary_rop(other, operator.mod)
415
-
416
- def __matmul__(self, other):
417
- if isinstance(other, JAXSparse):
418
- raise NotImplementedError("matmul between two sparse objects.")
419
- other = u.math.asarray(other)
420
- data, other = u.math.promote_dtypes(self.data, other)
421
- if other.ndim == 1:
422
- return _csr_matvec(
423
- data,
424
- self.indices,
425
- self.indptr,
426
- other,
427
- shape=self.shape[::-1],
428
- transpose=True
429
- )
430
- elif other.ndim == 2:
431
- return _csr_matmat(
432
- data,
433
- self.indices,
434
- self.indptr,
435
- other,
436
- shape=self.shape[::-1],
437
- transpose=True
438
- )
439
- else:
440
- raise NotImplementedError(f"matmul with object of shape {other.shape}")
441
-
442
- def __rmatmul__(self, other):
443
- if isinstance(other, JAXSparse):
444
- raise NotImplementedError("matmul between two sparse objects.")
445
- other = u.math.asarray(other)
446
- data, other = u.math.promote_dtypes(self.data, other)
447
- if other.ndim == 1:
448
- return _csr_matvec(
449
- data,
450
- self.indices,
451
- self.indptr,
452
- other,
453
- shape=self.shape[::-1],
454
- transpose=False
455
- )
456
- elif other.ndim == 2:
457
- other = other.T
458
- r = _csr_matmat(
459
- data,
460
- self.indices,
461
- self.indptr, other,
462
- shape=self.shape[::-1],
463
- transpose=False
464
- )
465
- return r.T
466
- else:
467
- raise NotImplementedError(f"matmul with object of shape {other.shape}")
468
-
469
- def tree_flatten(self):
470
- return (self.data,), {"shape": self.shape, "indices": self.indices, "indptr": self.indptr}
471
-
472
- @classmethod
473
- def tree_unflatten(cls, aux_data, children):
474
- obj = object.__new__(cls)
475
- obj.data, = children
476
- if aux_data.keys() != {'shape', 'indices', 'indptr'}:
477
- raise ValueError(f"CSR.tree_unflatten: invalid {aux_data=}")
478
- obj.__dict__.update(**aux_data)
479
- return obj
480
-
481
-
482
- def _csr_matvec(
483
- data: jax.Array | u.Quantity,
484
- indices: jax.Array,
485
- indptr: jax.Array,
486
- v: jax.Array | u.Quantity,
487
- *,
488
- shape: Shape,
489
- transpose: bool = False,
490
- float_as_event: bool = True,
491
- ) -> jax.Array | u.Quantity:
492
- """Product of CSR sparse matrix and a dense vector.
493
-
494
- Args:
495
- data : array of shape ``(nse,)``.
496
- indices : array of shape ``(nse,)``
497
- indptr : array of shape ``(shape[0] + 1,)`` and dtype ``indices.dtype``
498
- v : array of shape ``(shape[0] if transpose else shape[1],)``
499
- and dtype ``data.dtype``
500
- shape : length-2 tuple representing the matrix shape
501
- transpose : boolean specifying whether to transpose the sparse matrix
502
- before computing.
503
-
504
- Returns:
505
- y : array of shape ``(shape[1] if transpose else shape[0],)`` representing
506
- the matrix vector product.
507
- """
508
- data, unitd = u.split_mantissa_unit(data)
509
- v, unitv = u.split_mantissa_unit(v)
510
- res = event_csrmv_p_call(
511
- data, indices, indptr, v,
512
- shape=shape,
513
- transpose=transpose,
514
- float_as_event=float_as_event
515
- )[0]
516
- return u.maybe_decimal(res * (unitd * unitv))
517
-
518
-
519
- def _csr_matmat(
520
- data: jax.Array | u.Quantity,
521
- indices: jax.Array,
522
- indptr: jax.Array,
523
- B: jax.Array | u.Quantity,
524
- *,
525
- shape: Shape,
526
- transpose: bool = False,
527
- float_as_event: bool = True,
528
- ) -> jax.Array | u.Quantity:
529
- """
530
- Product of CSR sparse matrix and a dense matrix.
531
-
532
- Args:
533
- data : array of shape ``(nse,)``.
534
- indices : array of shape ``(nse,)``
535
- indptr : array of shape ``(shape[0] + 1,)`` and dtype ``indices.dtype``
536
- B : array of shape ``(shape[0] if transpose else shape[1], cols)`` and
537
- dtype ``data.dtype``
538
- shape : length-2 tuple representing the matrix shape
539
- transpose : boolean specifying whether to transpose the sparse matrix
540
- before computing.
541
-
542
- Returns:
543
- C : array of shape ``(shape[1] if transpose else shape[0], cols)``
544
- representing the matrix-matrix product.
545
- """
546
- data, unitd = u.split_mantissa_unit(data)
547
- B, unitb = u.split_mantissa_unit(B)
548
- res = event_csrmm_p_call(
549
- data,
550
- indices,
551
- indptr,
552
- B,
553
- shape=shape,
554
- transpose=transpose,
555
- float_as_event=float_as_event,
556
- )[0]
557
- return u.maybe_decimal(res * (unitd * unitb))
558
-
559
-
560
- Kernel = Callable
561
-
562
-
563
- def event_csrmv_cpu_kernel_generator(
564
- float_as_event: bool,
565
- weight_info: jax.ShapeDtypeStruct,
566
- spike_info: jax.ShapeDtypeStruct,
567
- transpose: bool,
568
- **kwargs
569
- ) -> Kernel:
570
- import numba # pylint: disable=import-outside-toplevel
571
-
572
- if weight_info.size == 1:
573
- if transpose:
574
- if spike_info.dtype == jnp.bool_:
575
- @numba.njit(fastmath=True)
576
- def mv(weights, indices, indptr, v, posts):
577
- posts[:] = 0.
578
- w = weights[()]
579
- for i in range(v.shape[0]):
580
- if v[i]:
581
- for j in range(indptr[i], indptr[i + 1]):
582
- posts[indices[j]] += w
583
-
584
- elif float_as_event:
585
- @numba.njit(fastmath=True)
586
- def mv(weights, indices, indptr, v, posts):
587
- posts[:] = 0.
588
- w = weights[()]
589
- for i in range(v.shape[0]):
590
- if v[i] != 0.:
591
- for j in range(indptr[i], indptr[i + 1]):
592
- posts[indices[j]] += w
593
-
594
- else:
595
- @numba.njit(fastmath=True)
596
- def mv(weights, indices, indptr, v, posts):
597
- posts[:] = 0.
598
- w = weights[()]
599
- for i in range(v.shape[0]):
600
- sp = v[i]
601
- if sp != 0.:
602
- wsp = w * sp
603
- for j in range(indptr[i], indptr[i + 1]):
604
- posts[indices[j]] += wsp
605
-
606
- else:
607
- if spike_info.dtype == jnp.bool_:
608
- @numba.njit(fastmath=True)
609
- def mv(weights, indices, indptr, v, posts):
610
- w = weights[()]
611
- for i in range(indptr.shape[0] - 1):
612
- r = 0.
613
- for j in range(indptr[i], indptr[i + 1]):
614
- if v[indices[j]]:
615
- r += w
616
- posts[i] = r
617
-
618
- elif float_as_event:
619
- @numba.njit(fastmath=True)
620
- def mv(weights, indices, indptr, v, posts):
621
- w = weights[()]
622
- for i in range(indptr.shape[0] - 1):
623
- r = 0.
624
- for j in range(indptr[i], indptr[i + 1]):
625
- if v[indices[j]] != 0.:
626
- r += w
627
- posts[i] = r
628
-
629
- else:
630
- @numba.njit(fastmath=True)
631
- def mv(weights, indices, indptr, v, posts):
632
- w = weights[()]
633
- for i in range(indptr.shape[0] - 1):
634
- r = 0.
635
- for j in range(indptr[i], indptr[i + 1]):
636
- c = v[indices[j]]
637
- if c != 0.:
638
- r += w * c
639
- posts[i] = r
640
-
641
- else:
642
- if transpose:
643
- if spike_info.dtype == jnp.bool_:
644
- @numba.njit(fastmath=True)
645
- def mv(weights, indices, indptr, v, posts):
646
- posts[:] = 0.
647
- for i in range(v.shape[0]):
648
- if v[i]:
649
- for j in range(indptr[i], indptr[i + 1]):
650
- posts[indices[j]] += weights[j]
651
-
652
- elif float_as_event:
653
- @numba.njit(fastmath=True)
654
- def mv(weights, indices, indptr, v, posts):
655
- posts[:] = 0.
656
- for i in range(v.shape[0]):
657
- if v[i] != 0.:
658
- for j in range(indptr[i], indptr[i + 1]):
659
- posts[indices[j]] += weights[j]
660
-
661
- else:
662
- @numba.njit(fastmath=True)
663
- def mv(weights, indices, indptr, v, posts):
664
- posts[:] = 0.
665
- for i in range(v.shape[0]):
666
- sp = v[i]
667
- if sp != 0.:
668
- for j in range(indptr[i], indptr[i + 1]):
669
- posts[indices[j]] += weights[j] * sp
670
-
671
- else:
672
- if spike_info.dtype == jnp.bool_:
673
- @numba.njit(fastmath=True)
674
- def mv(weights, indices, indptr, v, posts):
675
- for i in range(indptr.shape[0] - 1):
676
- r = 0.
677
- for j in range(indptr[i], indptr[i + 1]):
678
- if v[indices[j]]:
679
- r += weights[j]
680
- posts[i] = r
681
-
682
- elif float_as_event:
683
- @numba.njit(fastmath=True)
684
- def mv(weights, indices, indptr, v, posts):
685
- for i in range(indptr.shape[0] - 1):
686
- r = 0.
687
- for j in range(indptr[i], indptr[i + 1]):
688
- if v[indices[j]] != 0.:
689
- r += weights[j]
690
- posts[i] = r
691
-
692
- else:
693
- @numba.njit(fastmath=True)
694
- def mv(weights, indices, indptr, v, posts):
695
- for i in range(indptr.shape[0] - 1):
696
- r = 0.
697
- for j in range(indptr[i], indptr[i + 1]):
698
- c = v[indices[j]]
699
- if c != 0.:
700
- r += weights[j] * c
701
- posts[i] = r
702
-
703
- return mv
704
-
705
-
706
- def event_csrmv_jvp_v(
707
- v_dot,
708
- data,
709
- indices,
710
- indptr,
711
- v,
712
- *,
713
- shape,
714
- transpose,
715
- **kwargs
716
- ):
717
- return [
718
- csr_matvec(
719
- data,
720
- indices,
721
- indptr,
722
- v_dot,
723
- shape=shape,
724
- transpose=transpose
725
- )
726
- ]
727
-
728
-
729
- def event_csrmv_jvp_weights(
730
- data_dot,
731
- data,
732
- indices,
733
- indptr,
734
- v,
735
- *,
736
- shape,
737
- transpose,
738
- float_as_event,
739
- **kwargs
740
- ):
741
- return event_csrmv_p_call(
742
- data_dot,
743
- indices,
744
- indptr,
745
- v,
746
- shape=shape,
747
- transpose=transpose,
748
- float_as_event=float_as_event,
749
- )
750
-
751
-
752
- def event_csrmv_transpose_rule(
753
- ct,
754
- data,
755
- indices,
756
- indptr,
757
- events,
758
- *,
759
- shape,
760
- float_as_event,
761
- transpose,
762
- **kwargs
763
- ):
764
- if ad.is_undefined_primal(indices):
765
- raise ValueError("Cannot transpose with respect to sparse indices.")
766
-
767
- ct = ct[0]
768
-
769
- if ad.is_undefined_primal(indices) or ad.is_undefined_primal(indptr):
770
- raise ValueError("Cannot transpose with respect to sparse indices.")
771
- if ad.is_undefined_primal(events):
772
- ct_events = csr_matvec(
773
- data,
774
- indices,
775
- indptr,
776
- ct,
777
- shape=shape,
778
- transpose=not transpose
779
- )[0]
780
- return data, indices, indptr, (ad.Zero(events) if type(ct) is ad.Zero else ct_events)
781
- else:
782
- if type(ct[0]) is ad.Zero:
783
- ct_values = ad.Zero(data)
784
- else:
785
- if data.aval.shape[0] == 1: # scalar
786
- ct_values = event_csrmv_p_call(
787
- jnp.ones(1, dtype=data.dtype),
788
- indices,
789
- indptr,
790
- events,
791
- shape=shape,
792
- transpose=transpose,
793
- float_as_event=float_as_event,
794
- )[0]
795
- ct_values = jnp.inner(ct, ct_values)
796
- else: # heterogeneous values
797
- row, col = csr_to_coo(indices, indptr)
798
- ct_values = events[row] * ct[col] if transpose else events[col] * ct[row]
799
- return ct_values, indices, indptr, events
800
-
801
-
802
- def event_csrmv_batching(args, axes, **kwargs):
803
- if tuple(axes) == (None, None, None, 0):
804
- assert args[3].ndim == 2, 'Batching axis 0 requires 2D input.'
805
- r = event_csrmm_p_call(
806
- args[0],
807
- args[1],
808
- args[2],
809
- args[3].T,
810
- shape=kwargs['shape'],
811
- transpose=kwargs['transpose'],
812
- float_as_event=kwargs['float_as_event']
813
- )
814
- return r, [1]
815
-
816
- elif tuple(axes) == (None, None, None, 1):
817
- assert args[3].ndim == 2, 'Batching axis 0 requires 2D input.'
818
- r = event_csrmm_p_call(
819
- args[0],
820
- args[1],
821
- args[2],
822
- args[3],
823
- shape=kwargs['shape'],
824
- transpose=kwargs['transpose'],
825
- float_as_event=kwargs['float_as_event']
826
- )
827
- return r, [1]
828
-
829
- else:
830
- raise NotImplementedError(f"Batching axes {axes} not implemented for event-driven CSR matrix-vector product.")
831
-
832
-
833
- event_csrmv_p = XLACustomOp(
834
- 'event_csrmv',
835
- cpu_kernel_or_generator=event_csrmv_cpu_kernel_generator,
836
- )
837
- event_csrmv_p.defjvp(event_csrmv_jvp_weights, None, None, event_csrmv_jvp_v)
838
- event_csrmv_p.def_transpose_rule(event_csrmv_transpose_rule)
839
- event_csrmv_p.def_batching_rule(event_csrmv_batching)
840
-
841
-
842
- def event_csrmv_p_call(
843
- weights,
844
- indices,
845
- indptr,
846
- v,
847
- *,
848
- shape: Shape,
849
- transpose: bool,
850
- float_as_event: bool,
851
- ):
852
- if jax.default_backend() == 'cpu':
853
- return event_csrmv_p(
854
- weights,
855
- indices,
856
- indptr,
857
- v,
858
- outs=[
859
- jax.ShapeDtypeStruct([shape[1]], weights.dtype)
860
- if transpose else
861
- jax.ShapeDtypeStruct([shape[0]], weights.dtype),
862
- ],
863
- # block_size=block_size,
864
- float_as_event=float_as_event,
865
- shape=shape,
866
- transpose=transpose,
867
- weight_info=jax.ShapeDtypeStruct(weights.shape, weights.dtype),
868
- spike_info=jax.ShapeDtypeStruct(v.shape, v.dtype),
869
- )
870
- else:
871
- return [
872
- csr_matvec(
873
- weights,
874
- indices,
875
- indptr,
876
- v,
877
- shape=shape,
878
- transpose=transpose
879
- )
880
- ]
881
-
882
-
883
- def event_csrmm_batching(args, axes, **kwargs):
884
- if tuple(axes) == (None, None, None, 0):
885
- assert args[3].ndim == 3, 'Batching axis 0 requires 3D input.'
886
- batch_size, m, n = args[3].shape
887
- B = jnp.transpose(args[3], (1, 0, 2)).reshape(m, batch_size * n)
888
- r = event_csrmm_p_call(
889
- args[0],
890
- args[1],
891
- args[2],
892
- B,
893
- shape=kwargs['shape'],
894
- transpose=kwargs['transpose'],
895
- float_as_event=kwargs['float_as_event']
896
- )
897
- r = jnp.reshape(r[0], [r[0].shape[0], batch_size, n])
898
- return [r], [1]
899
-
900
- elif tuple(axes) == (None, None, None, 1):
901
- assert args[3].ndim == 3, 'Batching axis 0 requires 3D input.'
902
- m, batch_size, n = args[3].shape
903
- B = args[3].reshape(m, batch_size * n)
904
- r = event_csrmm_p_call(
905
- args[0],
906
- args[1],
907
- args[2],
908
- B,
909
- shape=kwargs['shape'],
910
- transpose=kwargs['transpose'],
911
- float_as_event=kwargs['float_as_event']
912
- )
913
- r = jnp.reshape(r[0], [r[0].shape[0], batch_size, n])
914
- return [r], [1]
915
-
916
- elif tuple(axes) == (None, None, None, 2):
917
- assert args[3].ndim == 3, 'Batching axis 0 requires 3D input.'
918
- m, n, batch_size = args[3].shape
919
- B = args[3].reshape(m, batch_size * n)
920
- r = event_csrmm_p_call(
921
- args[0],
922
- args[1],
923
- args[2],
924
- B,
925
- shape=kwargs['shape'],
926
- transpose=kwargs['transpose'],
927
- float_as_event=kwargs['float_as_event']
928
- )
929
- r = jnp.reshape(r[0], [r[0].shape[0], n, batch_size])
930
- return [r], [2]
931
-
932
- else:
933
- raise NotImplementedError(f"Batching axes {axes} not implemented for event-driven CSR matrix-vector product.")
934
-
935
-
936
- def event_csrmm_cpu_kernel_generator(
937
- float_as_event: bool,
938
- weight_info: jax.ShapeDtypeStruct,
939
- spike_info: jax.ShapeDtypeStruct,
940
- transpose: bool,
941
- **kwargs
942
- ) -> Kernel:
943
- import numba # pylint: disable=import-outside-toplevel
944
-
945
- if weight_info.size == 1:
946
- if transpose:
947
- # csr.T @ B
948
-
949
- if spike_info.dtype == jnp.bool_:
950
- @numba.njit(fastmath=True, parallel=False)
951
- def mv(weights, indices, indptr, B, posts):
952
- posts[:] = 0.
953
- w = weights[()]
954
- for k in numba.prange(B.shape[1]):
955
- for i in range(B.shape[0]):
956
- if B[i, k]:
957
- for j in range(indptr[i], indptr[i + 1]):
958
- posts[indices[j], k] += w
959
-
960
- elif float_as_event:
961
- @numba.njit(fastmath=True, parallel=False)
962
- def mv(weights, indices, indptr, B, posts):
963
- posts[:] = 0.
964
- B = B != 0.
965
- w = weights[()]
966
- for k in numba.prange(B.shape[1]):
967
- for i in range(B.shape[0]):
968
- if B[i, k]:
969
- for j in range(indptr[i], indptr[i + 1]):
970
- posts[indices[j], k] += w
971
-
972
- else:
973
- @numba.njit(fastmath=True, parallel=False)
974
- def mv(weights, indices, indptr, B, posts):
975
- posts[:] = 0.
976
- w = weights[()]
977
- for k in numba.prange(B.shape[1]):
978
- for i in range(B.shape[0]):
979
- sp = B[i, k]
980
- if sp != 0.:
981
- wsp = w * sp
982
- for j in range(indptr[i], indptr[i + 1]):
983
- posts[indices[j], k] += wsp
984
-
985
- else:
986
- # csr @ B
987
- if spike_info.dtype == jnp.bool_:
988
- @numba.njit(fastmath=True)
989
- def mv(weights, indices, indptr, B, posts):
990
- w = weights[()]
991
- for i in range(indptr.shape[0] - 1):
992
- r = np.zeros(B.shape[1], dtype=weights.dtype)
993
- for j in range(indptr[i], indptr[i + 1]):
994
- index = indices[j]
995
- for k in range(B.shape[1]):
996
- if B[index, k]:
997
- r[k] += w
998
- posts[i] = r
999
-
1000
- elif float_as_event:
1001
- @numba.njit(fastmath=True)
1002
- def mv(weights, indices, indptr, B, posts):
1003
- w = weights[()]
1004
- B = B != 0.
1005
- for i in range(indptr.shape[0] - 1):
1006
- r = np.zeros(B.shape[1], dtype=weights.dtype)
1007
- for j in range(indptr[i], indptr[i + 1]):
1008
- index = indices[j]
1009
- for k in range(B.shape[1]):
1010
- if B[index, k]:
1011
- r[k] += w
1012
- posts[i] = r
1013
-
1014
- else:
1015
- @numba.njit(fastmath=True)
1016
- def mv(weights, indices, indptr, B, posts):
1017
- w = weights[()]
1018
- for i in range(indptr.shape[0] - 1):
1019
- for k in range(B.shape[1]):
1020
- r = 0.
1021
- for j in range(indptr[i], indptr[i + 1]):
1022
- c = B[indices[j], k]
1023
- if c != 0.:
1024
- r += w * c
1025
- posts[i, k] = r
1026
-
1027
- else:
1028
- if transpose:
1029
- # csr.T @ B
1030
-
1031
- if spike_info.dtype == jnp.bool_:
1032
- @numba.njit(fastmath=True, parallel=False)
1033
- def mv(weights, indices, indptr, B, posts):
1034
- posts[:] = 0.
1035
- for k in numba.prange(B.shape[1]):
1036
- for i in range(B.shape[0]):
1037
- if B[i, k]:
1038
- for j in range(indptr[i], indptr[i + 1]):
1039
- posts[indices[j], k] += weights[j]
1040
-
1041
- elif float_as_event:
1042
- @numba.njit(fastmath=True, parallel=False)
1043
- def mv(weights, indices, indptr, B, posts):
1044
- posts[:] = 0.
1045
- B = B != 0.
1046
- for k in numba.prange(B.shape[1]):
1047
- for i in range(B.shape[0]):
1048
- if B[i, k]:
1049
- for j in range(indptr[i], indptr[i + 1]):
1050
- posts[indices[j], k] += weights[j]
1051
-
1052
- else:
1053
- @numba.njit(fastmath=True, parallel=False)
1054
- def mv(weights, indices, indptr, B, posts):
1055
- posts[:] = 0.
1056
- for k in numba.prange(B.shape[1]):
1057
- for i in range(B.shape[0]):
1058
- sp = B[i, k]
1059
- if sp != 0.:
1060
- for j in range(indptr[i], indptr[i + 1]):
1061
- posts[indices[j], k] += weights[j] * sp
1062
-
1063
- else:
1064
- # csr @ B
1065
-
1066
- if spike_info.dtype == jnp.bool_:
1067
- @numba.njit(fastmath=True)
1068
- def mv(weights, indices, indptr, B, posts):
1069
- for i in range(indptr.shape[0] - 1):
1070
- for k in range(B.shape[1]):
1071
- r = 0.
1072
- for j in range(indptr[i], indptr[i + 1]):
1073
- if B[indices[j], k]:
1074
- r += weights[j]
1075
- posts[i, k] = r
1076
-
1077
- elif float_as_event:
1078
- @numba.njit(fastmath=True)
1079
- def mv(weights, indices, indptr, B, posts):
1080
- B = B != 0.
1081
- for i in range(indptr.shape[0] - 1):
1082
- for k in range(B.shape[1]):
1083
- r = 0.
1084
- for j in range(indptr[i], indptr[i + 1]):
1085
- if B[indices[j], k]:
1086
- r += weights[j]
1087
- posts[i, k] = r
1088
-
1089
- else:
1090
- @numba.njit(fastmath=True)
1091
- def mv(weights, indices, indptr, B, posts):
1092
- for i in range(indptr.shape[0] - 1):
1093
- for k in range(B.shape[1]):
1094
- r = 0.
1095
- for j in range(indptr[i], indptr[i + 1]):
1096
- c = B[indices[j], k]
1097
- if c != 0.:
1098
- r += weights[j] * c
1099
- posts[i, k] = r
1100
-
1101
- return mv
1102
-
1103
-
1104
- event_csrmm_p = XLACustomOp(
1105
- 'event_csrmm',
1106
- cpu_kernel_or_generator=event_csrmm_cpu_kernel_generator,
1107
- )
1108
- event_csrmm_p.def_batching_rule(event_csrmm_batching)
1109
-
1110
-
1111
- def event_csrmm_p_call(
1112
- weights,
1113
- indices,
1114
- indptr,
1115
- B,
1116
- *,
1117
- shape: Shape,
1118
- transpose: bool,
1119
- float_as_event: bool,
1120
- ):
1121
- if jax.default_backend() == 'cpu':
1122
- return event_csrmm_p(
1123
- weights,
1124
- indices,
1125
- indptr,
1126
- B,
1127
- outs=[
1128
- jax.ShapeDtypeStruct([shape[1], B.shape[1]], weights.dtype)
1129
- if transpose else
1130
- jax.ShapeDtypeStruct([shape[0], B.shape[1]], weights.dtype),
1131
- ],
1132
- # block_size=block_size,
1133
- shape=shape,
1134
- transpose=transpose,
1135
- float_as_event=float_as_event,
1136
- weight_info=jax.ShapeDtypeStruct(weights.shape, weights.dtype),
1137
- spike_info=jax.ShapeDtypeStruct(B.shape, B.dtype),
1138
- )
1139
- else:
1140
- return [
1141
- csr_matmat(
1142
- weights,
1143
- indices,
1144
- indptr,
1145
- B,
1146
- shape=shape,
1147
- transpose=transpose
1148
- )
1149
- ]
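
For reference, the deleted `event_csrmv` CPU kernels above compute an event-driven CSR matrix-vector product: entries of the input vector are treated as spike events, and zero (or False) entries are skipped. Below is a minimal NumPy sketch of the transpose=False, one-weight-per-edge branch; the name `event_csr_matvec` and the plain-Python loop are illustrative only and are not part of the package.

    import numpy as np

    def event_csr_matvec(weights, indices, indptr, v):
        # Sketch of the deleted Numba kernel (transpose=False, one weight per edge):
        # posts[i] = sum of weights[j] * v[indices[j]] for j in [indptr[i], indptr[i + 1]),
        # skipping entries where v[indices[j]] == 0 ("event-driven").
        posts = np.zeros(indptr.shape[0] - 1, dtype=weights.dtype)
        for i in range(indptr.shape[0] - 1):
            r = 0.0
            for j in range(indptr[i], indptr[i + 1]):
                c = v[indices[j]]
                if c != 0.0:
                    r += weights[j] * c
            posts[i] = r
        return posts

In the deleted module, `event_csrmv_p_call` dispatched to Numba-compiled variants of this loop on the CPU backend and fell back to `brainunit`'s generic `csr_matvec` on other backends.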