brainstate 0.1.0.post20250105__py2.py3-none-any.whl → 0.1.0.post20250126__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60)
  1. brainstate/__init__.py +1 -2
  2. brainstate/_state.py +77 -44
  3. brainstate/_state_test.py +0 -17
  4. brainstate/augment/__init__.py +10 -20
  5. brainstate/augment/_eval_shape.py +9 -10
  6. brainstate/augment/_eval_shape_test.py +1 -1
  7. brainstate/augment/_mapping.py +265 -277
  8. brainstate/augment/_mapping_test.py +147 -175
  9. brainstate/compile/__init__.py +18 -37
  10. brainstate/compile/_ad_checkpoint.py +6 -4
  11. brainstate/compile/_jit.py +37 -28
  12. brainstate/compile/_loop_collect_return.py +6 -3
  13. brainstate/compile/_loop_no_collection.py +2 -0
  14. brainstate/compile/_make_jaxpr.py +15 -4
  15. brainstate/compile/_make_jaxpr_test.py +10 -6
  16. brainstate/compile/_progress_bar.py +68 -40
  17. brainstate/compile/_unvmap.py +9 -6
  18. brainstate/graph/__init__.py +12 -16
  19. brainstate/graph/_graph_node.py +1 -23
  20. brainstate/graph/_graph_operation.py +1 -1
  21. brainstate/graph/_graph_operation_test.py +0 -159
  22. brainstate/nn/_dyn_impl/_inputs.py +124 -39
  23. brainstate/nn/_elementwise/_dropout_test.py +1 -1
  24. brainstate/nn/_interaction/_conv.py +4 -2
  25. brainstate/nn/_interaction/_linear.py +84 -10
  26. brainstate/random/_rand_funs.py +9 -2
  27. brainstate/random/_rand_seed.py +12 -2
  28. brainstate/random/_rand_state.py +50 -179
  29. brainstate/surrogate.py +5 -1
  30. brainstate/util/__init__.py +0 -4
  31. brainstate/util/_caller.py +1 -1
  32. brainstate/util/_dict.py +4 -1
  33. brainstate/util/_filter.py +1 -1
  34. brainstate/util/_pretty_repr.py +1 -1
  35. brainstate/util/_struct.py +1 -1
  36. {brainstate-0.1.0.post20250105.dist-info → brainstate-0.1.0.post20250126.dist-info}/METADATA +2 -1
  37. {brainstate-0.1.0.post20250105.dist-info → brainstate-0.1.0.post20250126.dist-info}/RECORD +40 -60
  38. brainstate/event/__init__.py +0 -29
  39. brainstate/event/_csr.py +0 -906
  40. brainstate/event/_csr_mv.py +0 -303
  41. brainstate/event/_csr_mv_benchmark.py +0 -14
  42. brainstate/event/_csr_mv_test.py +0 -118
  43. brainstate/event/_csr_test.py +0 -90
  44. brainstate/event/_fixedprob_mv.py +0 -730
  45. brainstate/event/_fixedprob_mv_benchmark.py +0 -128
  46. brainstate/event/_fixedprob_mv_test.py +0 -132
  47. brainstate/event/_linear_mv.py +0 -359
  48. brainstate/event/_linear_mv_benckmark.py +0 -82
  49. brainstate/event/_linear_mv_test.py +0 -117
  50. brainstate/event/_misc.py +0 -34
  51. brainstate/event/_xla_custom_op.py +0 -313
  52. brainstate/event/_xla_custom_op_test.py +0 -55
  53. brainstate/graph/_graph_context.py +0 -443
  54. brainstate/graph/_graph_context_test.py +0 -65
  55. brainstate/graph/_graph_convert.py +0 -246
  56. brainstate/util/_tracers.py +0 -68
  57. brainstate/util/_visualization.py +0 -47
  58. {brainstate-0.1.0.post20250105.dist-info → brainstate-0.1.0.post20250126.dist-info}/LICENSE +0 -0
  59. {brainstate-0.1.0.post20250105.dist-info → brainstate-0.1.0.post20250126.dist-info}/WHEEL +0 -0
  60. {brainstate-0.1.0.post20250105.dist-info → brainstate-0.1.0.post20250126.dist-info}/top_level.txt +0 -0
brainstate/event/_csr.py DELETED
@@ -1,906 +0,0 @@
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
-
-
- from __future__ import annotations
-
- import operator
- from typing import Callable
-
- import brainunit as u
- import jax
- import jax.numpy as jnp
- import numpy as np
- from brainunit.sparse._csr import (
-     _csr_matvec as csr_matvec,
-     _csr_matmat as csr_matmat,
-     _csr_to_coo as csr_to_coo
- )
- from jax.experimental.sparse import JAXSparse
- from jax.interpreters import ad
-
- from brainstate.typing import Shape
- from ._xla_custom_op import XLACustomOp
-
- __all__ = [
-     'CSR',
-     'CSC',
- ]
-
-
- @jax.tree_util.register_pytree_node_class
- class CSR(u.sparse.SparseMatrix):
-     """
-     Event-driven and Unit-aware CSR matrix.
-     """
-     data: jax.Array | u.Quantity
-     indices: jax.Array
-     indptr: jax.Array
-     shape: tuple[int, int]
-     nse = property(lambda self: self.data.size)
-     dtype = property(lambda self: self.data.dtype)
-     _bufs = property(lambda self: (self.data, self.indices, self.indptr))
-
-     def __init__(self, args, *, shape):
-         self.data, self.indices, self.indptr = map(u.math.asarray, args)
-         super().__init__(args, shape=shape)
-
-     @classmethod
-     def fromdense(cls, mat, *, nse=None, index_dtype=np.int32):
-         if nse is None:
-             nse = (u.get_mantissa(mat) != 0).sum()
-         return u.sparse.csr_fromdense(mat, nse=nse, index_dtype=index_dtype)
-
-     def with_data(self, data: jax.Array | u.Quantity) -> CSR:
-         assert data.shape == self.data.shape
-         assert data.dtype == self.data.dtype
-         assert u.get_unit(data) == u.get_unit(self.data)
-         return CSR((data, self.indices, self.indptr), shape=self.shape)
-
-     def todense(self):
-         return u.sparse.csr_todense(self)
-
-     def transpose(self, axes=None):
-         assert axes is None
-         return CSC((self.data, self.indices, self.indptr), shape=self.shape[::-1])
-
-     def __abs__(self):
-         return CSR((abs(self.data), self.indices, self.indptr), shape=self.shape)
-
-     def __neg__(self):
-         return CSR((-self.data, self.indices, self.indptr), shape=self.shape)
-
-     def __pos__(self):
-         return CSR((self.data.__pos__(), self.indices, self.indptr), shape=self.shape)
-
-     def _binary_op(self, other, op):
-         if isinstance(other, CSR):
-             if id(other.indices) == id(self.indices) and id(other.indptr) == id(self.indptr):
-                 return CSR(
-                     (op(self.data, other.data),
-                      self.indices,
-                      self.indptr),
-                     shape=self.shape
-                 )
-         if isinstance(other, JAXSparse):
-             raise NotImplementedError(f"binary operation {op} between two sparse objects.")
-
-         other = u.math.asarray(other)
-         if other.size == 1:
-             return CSR(
-                 (op(self.data, other), self.indices, self.indptr),
-                 shape=self.shape
-             )
-         elif other.ndim == 2 and other.shape == self.shape:
-             rows, cols = csr_to_coo(self.indices, self.indptr)
-             other = other[rows, cols]
-             return CSR(
-                 (op(self.data, other),
-                  self.indices,
-                  self.indptr),
-                 shape=self.shape
-             )
-         else:
-             raise NotImplementedError(f"mul with object of shape {other.shape}")
-
-     def _binary_rop(self, other, op):
-         if isinstance(other, CSR):
-             if id(other.indices) == id(self.indices) and id(other.indptr) == id(self.indptr):
-                 return CSR(
-                     (op(other.data, self.data),
-                      self.indices,
-                      self.indptr),
-                     shape=self.shape
-                 )
-         if isinstance(other, JAXSparse):
-             raise NotImplementedError(f"binary operation {op} between two sparse objects.")
-
-         other = u.math.asarray(other)
-         if other.size == 1:
-             return CSR(
-                 (op(other, self.data),
-                  self.indices,
-                  self.indptr),
-                 shape=self.shape
-             )
-         elif other.ndim == 2 and other.shape == self.shape:
-             rows, cols = csr_to_coo(self.indices, self.indptr)
-             other = other[rows, cols]
-             return CSR(
-                 (op(other, self.data),
-                  self.indices,
-                  self.indptr),
-                 shape=self.shape
-             )
-         else:
-             raise NotImplementedError(f"mul with object of shape {other.shape}")
-
-     def __mul__(self, other: jax.Array | u.Quantity) -> CSR:
-         return self._binary_op(other, operator.mul)
-
-     def __rmul__(self, other: jax.Array | u.Quantity) -> CSR:
-         return self._binary_rop(other, operator.mul)
-
-     def __div__(self, other: jax.Array | u.Quantity) -> CSR:
-         return self._binary_op(other, operator.truediv)
-
-     def __rdiv__(self, other: jax.Array | u.Quantity) -> CSR:
-         return self._binary_rop(other, operator.truediv)
-
-     def __truediv__(self, other) -> CSR:
-         return self.__div__(other)
-
-     def __rtruediv__(self, other) -> CSR:
-         return self.__rdiv__(other)
-
-     def __add__(self, other) -> CSR:
-         return self._binary_op(other, operator.add)
-
-     def __radd__(self, other) -> CSR:
-         return self._binary_rop(other, operator.add)
-
-     def __sub__(self, other) -> CSR:
-         return self._binary_op(other, operator.sub)
-
-     def __rsub__(self, other) -> CSR:
-         return self._binary_rop(other, operator.sub)
-
-     def __mod__(self, other) -> CSR:
-         return self._binary_op(other, operator.mod)
-
-     def __rmod__(self, other) -> CSR:
-         return self._binary_rop(other, operator.mod)
-
-     def __matmul__(self, other):
-         if isinstance(other, JAXSparse):
-             raise NotImplementedError("matmul between two sparse objects.")
-         other = u.math.asarray(other)
-         data, other = u.math.promote_dtypes(self.data, other)
-         if other.ndim == 1:
-             return _csr_matvec(
-                 data,
-                 self.indices,
-                 self.indptr,
-                 other,
-                 shape=self.shape
-             )
-         elif other.ndim == 2:
-             return _csr_matmat(
-                 data,
-                 self.indices,
-                 self.indptr,
-                 other,
-                 shape=self.shape
-             )
-         else:
-             raise NotImplementedError(f"matmul with object of shape {other.shape}")
-
-     def __rmatmul__(self, other):
-         if isinstance(other, JAXSparse):
-             raise NotImplementedError("matmul between two sparse objects.")
-         other = u.math.asarray(other)
-         data, other = u.math.promote_dtypes(self.data, other)
-         if other.ndim == 1:
-             return _csr_matvec(
-                 data,
-                 self.indices,
-                 self.indptr,
-                 other,
-                 shape=self.shape,
-                 transpose=True
-             )
-         elif other.ndim == 2:
-             other = other.T
-             r = _csr_matmat(
-                 data,
-                 self.indices,
-                 self.indptr,
-                 other,
-                 shape=self.shape,
-                 transpose=True
-             )
-             return r.T
-         else:
-             raise NotImplementedError(f"matmul with object of shape {other.shape}")
-
-     def tree_flatten(self):
-         return (self.data,), {"shape": self.shape, "indices": self.indices, "indptr": self.indptr}
-
-     @classmethod
-     def tree_unflatten(cls, aux_data, children):
-         obj = object.__new__(cls)
-         obj.data, = children
-         if aux_data.keys() != {'shape', 'indices', 'indptr'}:
-             raise ValueError(f"CSR.tree_unflatten: invalid {aux_data=}")
-         obj.__dict__.update(**aux_data)
-         return obj
-
-
- @jax.tree_util.register_pytree_node_class
- class CSC(u.sparse.SparseMatrix):
-     """
-     Event-driven and Unit-aware CSC matrix.
-     """
-     data: jax.Array
-     indices: jax.Array
-     indptr: jax.Array
-     shape: tuple[int, int]
-     nse = property(lambda self: self.data.size)
-     dtype = property(lambda self: self.data.dtype)
-
-     def __init__(self, args, *, shape):
-         self.data, self.indices, self.indptr = map(u.math.asarray, args)
-         super().__init__(args, shape=shape)
-
-     @classmethod
-     def fromdense(cls, mat, *, nse=None, index_dtype=np.int32):
-         if nse is None:
-             nse = (u.get_mantissa(mat) != 0).sum()
-         return u.sparse.csr_fromdense(mat.T, nse=nse, index_dtype=index_dtype).T
-
-     @classmethod
-     def _empty(cls, shape, *, dtype=None, index_dtype='int32'):
-         """Create an empty CSC instance. Public method is sparse.empty()."""
-         shape = tuple(shape)
-         if len(shape) != 2:
-             raise ValueError(f"CSC must have ndim=2; got {shape=}")
-         data = jnp.empty(0, dtype)
-         indices = jnp.empty(0, index_dtype)
-         indptr = jnp.zeros(shape[1] + 1, index_dtype)
-         return cls((data, indices, indptr), shape=shape)
-
-     @classmethod
-     def _eye(cls, N, M, k, *, dtype=None, index_dtype='int32'):
-         return CSR._eye(M, N, -k, dtype=dtype, index_dtype=index_dtype).T
-
-     def with_data(self, data: jax.Array | u.Quantity) -> CSC:
-         assert data.shape == self.data.shape
-         assert data.dtype == self.data.dtype
-         assert u.get_unit(data) == u.get_unit(self.data)
-         return CSC((data, self.indices, self.indptr), shape=self.shape)
-
-     def todense(self):
-         return u.sparse.csr_todense(self.T).T
-
-     def transpose(self, axes=None):
-         assert axes is None
-         return CSR((self.data, self.indices, self.indptr), shape=self.shape[::-1])
-
-     def __abs__(self):
-         return CSC((abs(self.data), self.indices, self.indptr), shape=self.shape)
-
-     def __neg__(self):
-         return CSC((-self.data, self.indices, self.indptr), shape=self.shape)
-
-     def __pos__(self):
-         return CSC((self.data.__pos__(), self.indices, self.indptr), shape=self.shape)
-
-     def _binary_op(self, other, op):
-         if isinstance(other, CSC):
-             if id(other.indices) == id(self.indices) and id(other.indptr) == id(self.indptr):
-                 return CSC(
-                     (op(self.data, other.data),
-                      self.indices,
-                      self.indptr),
-                     shape=self.shape
-                 )
-         if isinstance(other, JAXSparse):
-             raise NotImplementedError(f"binary operation {op} between two sparse objects.")
-
-         other = u.math.asarray(other)
-         if other.size == 1:
-             return CSC(
-                 (op(self.data, other),
-                  self.indices,
-                  self.indptr),
-                 shape=self.shape
-             )
-         elif other.ndim == 2 and other.shape == self.shape:
-             cols, rows = csr_to_coo(self.indices, self.indptr)
-             other = other[rows, cols]
-             return CSC(
-                 (op(self.data, other),
-                  self.indices,
-                  self.indptr),
-                 shape=self.shape
-             )
-         else:
-             raise NotImplementedError(f"mul with object of shape {other.shape}")
-
-     def _binary_rop(self, other, op):
-         if isinstance(other, CSC):
-             if id(other.indices) == id(self.indices) and id(other.indptr) == id(self.indptr):
-                 return CSC(
-                     (op(other.data, self.data),
-                      self.indices,
-                      self.indptr),
-                     shape=self.shape
-                 )
-         if isinstance(other, JAXSparse):
-             raise NotImplementedError(f"binary operation {op} between two sparse objects.")
-
-         other = u.math.asarray(other)
-         if other.size == 1:
-             return CSC(
-                 (op(other, self.data),
-                  self.indices,
-                  self.indptr),
-                 shape=self.shape
-             )
-         elif other.ndim == 2 and other.shape == self.shape:
-             cols, rows = csr_to_coo(self.indices, self.indptr)
-             other = other[rows, cols]
-             return CSC(
-                 (op(other, self.data),
-                  self.indices,
-                  self.indptr),
-                 shape=self.shape
-             )
-         else:
-             raise NotImplementedError(f"mul with object of shape {other.shape}")
-
-     def __mul__(self, other: jax.Array | u.Quantity) -> 'CSC':
-         return self._binary_op(other, operator.mul)
-
-     def __rmul__(self, other: jax.Array | u.Quantity) -> 'CSC':
-         return self._binary_rop(other, operator.mul)
-
-     def __div__(self, other: jax.Array | u.Quantity) -> CSC:
-         return self._binary_op(other, operator.truediv)
-
-     def __rdiv__(self, other: jax.Array | u.Quantity) -> CSC:
-         return self._binary_rop(other, operator.truediv)
-
-     def __truediv__(self, other) -> CSC:
-         return self.__div__(other)
-
-     def __rtruediv__(self, other) -> CSC:
-         return self.__rdiv__(other)
-
-     def __add__(self, other) -> CSC:
-         return self._binary_op(other, operator.add)
-
-     def __radd__(self, other) -> CSC:
-         return self._binary_rop(other, operator.add)
-
-     def __sub__(self, other) -> CSC:
-         return self._binary_op(other, operator.sub)
-
-     def __rsub__(self, other) -> CSC:
-         return self._binary_rop(other, operator.sub)
-
-     def __mod__(self, other) -> CSC:
-         return self._binary_op(other, operator.mod)
-
-     def __rmod__(self, other) -> CSC:
-         return self._binary_rop(other, operator.mod)
-
-     def __matmul__(self, other):
-         if isinstance(other, JAXSparse):
-             raise NotImplementedError("matmul between two sparse objects.")
-         other = u.math.asarray(other)
-         data, other = u.math.promote_dtypes(self.data, other)
-         if other.ndim == 1:
-             return _csr_matvec(
-                 data,
-                 self.indices,
-                 self.indptr,
-                 other,
-                 shape=self.shape[::-1],
-                 transpose=True
-             )
-         elif other.ndim == 2:
-             return _csr_matmat(
-                 data,
-                 self.indices,
-                 self.indptr,
-                 other,
-                 shape=self.shape[::-1],
-                 transpose=True
-             )
-         else:
-             raise NotImplementedError(f"matmul with object of shape {other.shape}")
-
-     def __rmatmul__(self, other):
-         if isinstance(other, JAXSparse):
-             raise NotImplementedError("matmul between two sparse objects.")
-         other = u.math.asarray(other)
-         data, other = u.math.promote_dtypes(self.data, other)
-         if other.ndim == 1:
-             return _csr_matvec(
-                 data,
-                 self.indices,
-                 self.indptr,
-                 other,
-                 shape=self.shape[::-1],
-                 transpose=False
-             )
-         elif other.ndim == 2:
-             other = other.T
-             r = _csr_matmat(
-                 data,
-                 self.indices,
-                 self.indptr, other,
-                 shape=self.shape[::-1],
-                 transpose=False
-             )
-             return r.T
-         else:
-             raise NotImplementedError(f"matmul with object of shape {other.shape}")
-
-     def tree_flatten(self):
-         return (self.data,), {"shape": self.shape, "indices": self.indices, "indptr": self.indptr}
-
-     @classmethod
-     def tree_unflatten(cls, aux_data, children):
-         obj = object.__new__(cls)
-         obj.data, = children
-         if aux_data.keys() != {'shape', 'indices', 'indptr'}:
-             raise ValueError(f"CSR.tree_unflatten: invalid {aux_data=}")
-         obj.__dict__.update(**aux_data)
-         return obj
-
-
- def _csr_matvec(
-     data: jax.Array | u.Quantity,
-     indices: jax.Array,
-     indptr: jax.Array,
-     v: jax.Array | u.Quantity,
-     *,
-     shape: Shape,
-     transpose: bool = False,
-     float_as_event: bool = True,
- ) -> jax.Array | u.Quantity:
-     """Product of CSR sparse matrix and a dense vector.
-
-     Args:
-       data : array of shape ``(nse,)``.
-       indices : array of shape ``(nse,)``
-       indptr : array of shape ``(shape[0] + 1,)`` and dtype ``indices.dtype``
-       v : array of shape ``(shape[0] if transpose else shape[1],)``
-         and dtype ``data.dtype``
-       shape : length-2 tuple representing the matrix shape
-       transpose : boolean specifying whether to transpose the sparse matrix
-         before computing.
-
-     Returns:
-       y : array of shape ``(shape[1] if transpose else shape[0],)`` representing
-         the matrix vector product.
-     """
-     data, unitd = u.split_mantissa_unit(data)
-     v, unitv = u.split_mantissa_unit(v)
-     res = event_csrmv_p_call(
-         data, indices, indptr, v,
-         shape=shape,
-         transpose=transpose,
-         float_as_event=float_as_event
-     )[0]
-     return u.maybe_decimal(res * (unitd * unitv))
-
-
- def _csr_matmat(
-     data: jax.Array | u.Quantity,
-     indices: jax.Array,
-     indptr: jax.Array,
-     B: jax.Array | u.Quantity,
-     *,
-     shape: Shape,
-     transpose: bool = False,
-     float_as_event: bool = True,
- ) -> jax.Array | u.Quantity:
-     """
-     Product of CSR sparse matrix and a dense matrix.
-
-     Args:
-       data : array of shape ``(nse,)``.
-       indices : array of shape ``(nse,)``
-       indptr : array of shape ``(shape[0] + 1,)`` and dtype ``indices.dtype``
-       B : array of shape ``(shape[0] if transpose else shape[1], cols)`` and
-         dtype ``data.dtype``
-       shape : length-2 tuple representing the matrix shape
-       transpose : boolean specifying whether to transpose the sparse matrix
-         before computing.
-
-     Returns:
-       C : array of shape ``(shape[1] if transpose else shape[0], cols)``
-         representing the matrix-matrix product.
-     """
-     data, unitd = u.split_mantissa_unit(data)
-     B, unitb = u.split_mantissa_unit(B)
-     res = event_csrmm_p_call(
-         data,
-         indices,
-         indptr,
-         B,
-         shape=shape,
-         transpose=transpose,
-         float_as_event=float_as_event,
-     )[0]
-     return u.maybe_decimal(res * (unitd * unitb))
-
-
- Kernel = Callable
-
-
- def event_csrmv_cpu_kernel_generator(
-     float_as_event: bool,
-     weight_info: jax.ShapeDtypeStruct,
-     spike_info: jax.ShapeDtypeStruct,
-     transpose: bool,
-     **kwargs
- ) -> Kernel:
-     import numba  # pylint: disable=import-outside-toplevel
-
-     if weight_info.size == 1:
-         if transpose:
-             if spike_info.dtype == jnp.bool_:
-                 @numba.njit
-                 def mv(weights, indices, indptr, v, posts):
-                     posts[:] = 0.
-                     w = weights[()]
-                     for i in range(v.shape[0]):
-                         if v[i]:
-                             for j in range(indptr[i], indptr[i + 1]):
-                                 posts[indices[j]] += w
-
-             elif float_as_event:
-                 @numba.njit
-                 def mv(weights, indices, indptr, v, posts):
-                     posts[:] = 0.
-                     w = weights[()]
-                     for i in range(v.shape[0]):
-                         if v[i] != 0.:
-                             for j in range(indptr[i], indptr[i + 1]):
-                                 posts[indices[j]] += w
-
-             else:
-                 @numba.njit
-                 def mv(weights, indices, indptr, v, posts):
-                     posts[:] = 0.
-                     w = weights[()]
-                     for i in range(v.shape[0]):
-                         sp = v[i]
-                         if sp != 0.:
-                             wsp = w * sp
-                             for j in range(indptr[i], indptr[i + 1]):
-                                 posts[indices[j]] += wsp
-
-         else:
-             if spike_info.dtype == jnp.bool_:
-                 @numba.njit
-                 def mv(weights, indices, indptr, v, posts):
-                     w = weights[()]
-                     for i in range(indptr.shape[0] - 1):
-                         r = 0.
-                         for j in range(indptr[i], indptr[i + 1]):
-                             if v[indices[j]]:
-                                 r += w
-                         posts[i] = r
-
-             elif float_as_event:
-                 @numba.njit
-                 def mv(weights, indices, indptr, v, posts):
-                     w = weights[()]
-                     for i in range(indptr.shape[0] - 1):
-                         r = 0.
-                         for j in range(indptr[i], indptr[i + 1]):
-                             if v[indices[j]] != 0.:
-                                 r += w
-                         posts[i] = r
-
-             else:
-                 @numba.njit
-                 def mv(weights, indices, indptr, v, posts):
-                     w = weights[()]
-                     for i in range(indptr.shape[0] - 1):
-                         r = 0.
-                         for j in range(indptr[i], indptr[i + 1]):
-                             c = v[indices[j]]
-                             if c != 0.:
-                                 r += w * c
-                         posts[i] = r
-
-     else:
-         if transpose:
-             if spike_info.dtype == jnp.bool_:
-                 @numba.njit
-                 def mv(weights, indices, indptr, v, posts):
-                     posts[:] = 0.
-                     for i in range(v.shape[0]):
-                         if v[i]:
-                             for j in range(indptr[i], indptr[i + 1]):
-                                 posts[indices[j]] += weights[j]
-
-             elif float_as_event:
-                 @numba.njit
-                 def mv(weights, indices, indptr, v, posts):
-                     posts[:] = 0.
-                     for i in range(v.shape[0]):
-                         if v[i] != 0.:
-                             for j in range(indptr[i], indptr[i + 1]):
-                                 posts[indices[j]] += weights[j]
-
-             else:
-                 @numba.njit
-                 def mv(weights, indices, indptr, v, posts):
-                     posts[:] = 0.
-                     for i in range(v.shape[0]):
-                         sp = v[i]
-                         if sp != 0.:
-                             for j in range(indptr[i], indptr[i + 1]):
-                                 posts[indices[j]] += weights[j] * sp
-
-         else:
-             if spike_info.dtype == jnp.bool_:
-                 @numba.njit
-                 def mv(weights, indices, indptr, v, posts):
-                     for i in range(indptr.shape[0] - 1):
-                         r = 0.
-                         for j in range(indptr[i], indptr[i + 1]):
-                             if v[indices[j]]:
-                                 r += weights[j]
-                         posts[i] = r
-
-             elif float_as_event:
-                 @numba.njit
-                 def mv(weights, indices, indptr, v, posts):
-                     for i in range(indptr.shape[0] - 1):
-                         r = 0.
-                         for j in range(indptr[i], indptr[i + 1]):
-                             if v[indices[j]] != 0.:
-                                 r += weights[j]
-                         posts[i] = r
-
-             else:
-                 @numba.njit
-                 def mv(weights, indices, indptr, v, posts):
-                     for i in range(indptr.shape[0] - 1):
-                         r = 0.
-                         for j in range(indptr[i], indptr[i + 1]):
-                             c = v[indices[j]]
-                             if c != 0.:
-                                 r += weights[j] * c
-                         posts[i] = r
-
-     return mv
-
-
- def event_csrmv_jvp_v(
-     v_dot,
-     data,
-     indices,
-     indptr,
-     v,
-     *,
-     shape,
-     transpose,
-     **kwargs
- ):
-     return [
-         csr_matvec(
-             data,
-             indices,
-             indptr,
-             v_dot,
-             shape=shape,
-             transpose=transpose
-         )
-     ]
-
-
- def event_csrmv_jvp_weights(
-     data_dot,
-     data,
-     indices,
-     indptr,
-     v,
-     *,
-     shape,
-     transpose,
-     float_as_event,
-     **kwargs
- ):
-     return event_csrmv_p_call(
-         data_dot,
-         indices,
-         indptr,
-         v,
-         shape=shape,
-         transpose=transpose,
-         float_as_event=float_as_event,
-     )
-
-
- def event_csrmv_transpose_rule(
-     ct,
-     data,
-     indices,
-     indptr,
-     events,
-     *,
-     shape,
-     float_as_event,
-     transpose,
-     **kwargs
- ):
-     if ad.is_undefined_primal(indices):
-         raise ValueError("Cannot transpose with respect to sparse indices.")
-
-     ct = ct[0]
-
-     if ad.is_undefined_primal(indices) or ad.is_undefined_primal(indptr):
-         raise ValueError("Cannot transpose with respect to sparse indices.")
-     if ad.is_undefined_primal(events):
-         ct_events = csr_matvec(
-             data,
-             indices,
-             indptr,
-             ct,
-             shape=shape,
-             transpose=not transpose
-         )[0]
-         return data, indices, indptr, (ad.Zero(events) if type(ct) is ad.Zero else ct_events)
-     else:
-         if type(ct[0]) is ad.Zero:
-             ct_values = ad.Zero(data)
-         else:
-             if data.aval.shape[0] == 1:  # scalar
-                 ct_values = event_csrmv_p_call(
-                     jnp.ones(1, dtype=data.dtype),
-                     indices,
-                     indptr,
-                     events,
-                     shape=shape,
-                     transpose=transpose,
-                     float_as_event=float_as_event,
-                 )[0]
-                 ct_values = jnp.inner(ct, ct_values)
-             else:  # heterogeneous values
-                 row, col = csr_to_coo(indices, indptr)
-                 ct_values = events[row] * ct[col] if transpose else events[col] * ct[row]
-         return ct_values, indices, indptr, events
-
-
- def event_csrmv_batching(args, axes, **kwargs):
-     if tuple(axes) == (None, None, None, 0):
-         return 0, event_csrmm_p_call(*args, **kwargs)
-     else:
-         raise NotImplementedError(f"Batching axes {axes} not implemented for event-driven CSR matrix-vector product.")
-
-
- event_csrmv_p = XLACustomOp(
-     'event_csrmv',
-     cpu_kernel_or_generator=event_csrmv_cpu_kernel_generator,
- )
- event_csrmv_p.defjvp(event_csrmv_jvp_weights, None, None, event_csrmv_jvp_v)
- event_csrmv_p.def_transpose_rule(event_csrmv_transpose_rule)
- event_csrmv_p.def_batching_rule(event_csrmv_batching)
-
-
- def event_csrmv_p_call(
-     weights,
-     indices,
-     indptr,
-     v,
-     *,
-     shape: Shape,
-     transpose: bool,
-     float_as_event: bool,
- ):
-     if jax.default_backend() == 'cpu':
-         return event_csrmv_p(
-             weights,
-             indices,
-             indptr,
-             v,
-             outs=[
-                 jax.ShapeDtypeStruct([shape[1]], weights.dtype)
-                 if transpose else
-                 jax.ShapeDtypeStruct([shape[0]], weights.dtype),
-             ],
-             # block_size=block_size,
-             float_as_event=float_as_event,
-             shape=shape,
-             transpose=transpose,
-             weight_info=jax.ShapeDtypeStruct(weights.shape, weights.dtype),
-             spike_info=jax.ShapeDtypeStruct(v.shape, v.dtype),
-         )
-     else:
-         return [
-             csr_matvec(
-                 weights,
-                 indices,
-                 indptr,
-                 v,
-                 shape=shape,
-                 transpose=transpose
-             )
-         ]
-
-
- def event_csrmm_batching(args, axes, **kwargs):
-     if tuple(axes) == (None, None, None, 0):
-         batch_shape = args[3].shape[:-1]
-         B = jnp.reshape(args[3], (-1, args[3].shape[-1:]))
-         r = event_csrmm_p_call(args[0], args[1], args[2], B, **kwargs)
-         return 0, [jnp.reshape(r[0], batch_shape + r.shape[-1:])]
-     else:
-         raise NotImplementedError(f"Batching axes {axes} not implemented for event-driven CSR matrix-vector product.")
-
-
- event_csrmm_p = XLACustomOp(
-     'event_csrmm',
-     cpu_kernel_or_generator=event_csrmv_cpu_kernel_generator,
- )
- event_csrmm_p.def_batching_rule(event_csrmm_batching)
-
-
- def event_csrmm_p_call(
-     weights,
-     indices,
-     indptr,
-     B,
-     *,
-     shape: Shape,
-     transpose: bool,
-     float_as_event: bool,
- ):
-     if jax.default_backend() == 'cpu':
-         return event_csrmm_p(
-             weights,
-             indices,
-             indptr,
-             B,
-             outs=[
-                 jax.ShapeDtypeStruct([shape[0], B.shape[1]], weights.dtype)
-                 if transpose else
-                 jax.ShapeDtypeStruct([shape[1], B.shape[1]], weights.dtype),
-             ],
-             # block_size=block_size,
-             float_as_event=float_as_event,
-             weight_info=jax.ShapeDtypeStruct(weights.shape, weights.dtype),
-             spike_info=jax.ShapeDtypeStruct(B.shape, B.dtype),
-         )
-     else:
-         return [
-             csr_matmat(
-                 weights,
-                 indices,
-                 indptr,
-                 B,
-                 shape=shape,
-                 transpose=transpose
-             )
-         ]
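
For context, the module deleted above implemented event-driven, unit-aware CSR/CSC matrices whose `@` operator dispatched to the custom `event_csrmv`/`event_csrmm` primitives: Numba kernels on CPU, plain `csr_matvec`/`csr_matmat` on other backends. A minimal usage sketch, assuming the previous wheel (brainstate==0.1.0.post20250105) and that `CSR` was re-exported by the likewise-deleted `brainstate/event/__init__.py`:

    import jax.numpy as jnp
    from brainstate.event import CSR  # removed in 0.1.0.post20250126

    # CSR buffers encoding the 2x3 dense matrix [[0., 2., 0.], [1., 0., 3.]]:
    data = jnp.array([2., 1., 3.])    # nonzero values, shape (nse,)
    indices = jnp.array([1, 0, 2])    # column index of each nonzero
    indptr = jnp.array([0, 1, 3])     # row i owns data[indptr[i]:indptr[i + 1]]

    mat = CSR((data, indices, indptr), shape=(2, 3))
    spikes = jnp.array([0., 1., 1.])  # nonzero entries are treated as events
    post = mat @ spikes               # event-driven matvec -> shape (2,)
    pre = jnp.ones(2) @ mat           # transposed product via __rmatmul__ -> shape (3,)

The Numba kernels distinguish three event conventions: boolean spikes, `float_as_event=True` (any nonzero spike contributes the bare weight), and graded floats (the weight is scaled by the spike value). A dense NumPy restatement of the homogeneous-weight, transpose=True branch; `event_matvec_reference` is a hypothetical helper for illustration, not part of the package:

    import numpy as np

    def event_matvec_reference(w, indices, indptr, v, n_post, float_as_event=True):
        # Each active presynaptic index i scatters its contribution onto the
        # postsynaptic targets indices[indptr[i]:indptr[i + 1]].
        posts = np.zeros(n_post)
        for i in range(v.shape[0]):
            sp = v[i]
            if v.dtype == np.bool_:
                if sp:
                    posts[indices[indptr[i]:indptr[i + 1]]] += w
            elif sp != 0.:
                posts[indices[indptr[i]:indptr[i + 1]]] += w if float_as_event else w * sp
        return posts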