brainstate 0.1.0.post20250105__py2.py3-none-any.whl → 0.1.0.post20250120__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. brainstate/_state.py +77 -44
  2. brainstate/_state_test.py +0 -17
  3. brainstate/augment/_eval_shape.py +9 -10
  4. brainstate/augment/_eval_shape_test.py +1 -1
  5. brainstate/augment/_mapping.py +265 -277
  6. brainstate/augment/_mapping_test.py +147 -175
  7. brainstate/compile/_ad_checkpoint.py +6 -4
  8. brainstate/compile/_jit.py +37 -28
  9. brainstate/compile/_loop_collect_return.py +6 -3
  10. brainstate/compile/_loop_no_collection.py +2 -0
  11. brainstate/compile/_make_jaxpr.py +7 -3
  12. brainstate/compile/_progress_bar.py +68 -40
  13. brainstate/compile/_unvmap.py +6 -3
  14. brainstate/event/__init__.py +0 -2
  15. brainstate/event/_csr.py +266 -23
  16. brainstate/event/_csr_test.py +187 -0
  17. brainstate/event/_xla_custom_op.py +7 -3
  18. brainstate/graph/__init__.py +8 -12
  19. brainstate/graph/_graph_node.py +1 -23
  20. brainstate/graph/_graph_operation.py +1 -1
  21. brainstate/graph/_graph_operation_test.py +0 -159
  22. brainstate/nn/_dyn_impl/_inputs.py +124 -39
  23. brainstate/nn/_interaction/_conv.py +4 -2
  24. brainstate/nn/_interaction/_linear.py +84 -10
  25. brainstate/random/_rand_funs.py +9 -2
  26. brainstate/random/_rand_seed.py +12 -2
  27. brainstate/random/_rand_state.py +50 -179
  28. brainstate/surrogate.py +5 -1
  29. brainstate/util/__init__.py +0 -4
  30. brainstate/util/_caller.py +1 -1
  31. brainstate/util/_dict.py +4 -1
  32. brainstate/util/_filter.py +1 -1
  33. brainstate/util/_pretty_repr.py +1 -1
  34. brainstate/util/_struct.py +1 -1
  35. {brainstate-0.1.0.post20250105.dist-info → brainstate-0.1.0.post20250120.dist-info}/METADATA +2 -1
  36. {brainstate-0.1.0.post20250105.dist-info → brainstate-0.1.0.post20250120.dist-info}/RECORD +40 -46
  37. brainstate/event/_csr_mv_test.py +0 -118
  38. brainstate/graph/_graph_context.py +0 -443
  39. brainstate/graph/_graph_context_test.py +0 -65
  40. brainstate/graph/_graph_convert.py +0 -246
  41. brainstate/util/_tracers.py +0 -68
  42. brainstate/util/_visualization.py +0 -47
  43. /brainstate/event/{_csr_mv_benchmark.py → _csr_benchmark.py} +0 -0
  44. {brainstate-0.1.0.post20250105.dist-info → brainstate-0.1.0.post20250120.dist-info}/LICENSE +0 -0
  45. {brainstate-0.1.0.post20250105.dist-info → brainstate-0.1.0.post20250120.dist-info}/WHEEL +0 -0
  46. {brainstate-0.1.0.post20250105.dist-info → brainstate-0.1.0.post20250120.dist-info}/top_level.txt +0 -0
brainstate/compile/_progress_bar.py CHANGED
@@ -16,34 +16,59 @@
 from __future__ import annotations
 
 import copy
-from typing import Optional
+import importlib.util
+from typing import Optional, Callable, Any, Tuple, Dict
 
 import jax
 
-try:
-    from tqdm.auto import tqdm
-except (ImportError, ModuleNotFoundError):
-    tqdm = None
+tqdm_installed = importlib.util.find_spec('tqdm') is not None
 
 __all__ = [
     'ProgressBar',
 ]
 
+Index = int
+Carray = Any
+Output = Any
+
 
 class ProgressBar(object):
+    """
+    A progress bar for tracking the progress of a jitted for-loop computation.
+    """
     __module__ = "brainstate.compile"
 
-    def __init__(self, freq: Optional[int] = None, count: Optional[int] = None, **kwargs):
+    def __init__(
+        self,
+        freq: Optional[int] = None,
+        count: Optional[int] = None,
+        desc: Optional[Tuple[str, Callable[[Dict], Dict]]] = None,
+        **kwargs
+    ):
+        # print rate
         self.print_freq = freq
         if isinstance(freq, int):
            assert freq > 0, "Print rate should be > 0."
+
+        # print count
         self.print_count = count
         if self.print_freq is not None and self.print_count is not None:
            raise ValueError("Cannot specify both count and freq.")
+
+        # other parameters
         for kwarg in ("total", "mininterval", "maxinterval", "miniters"):
            kwargs.pop(kwarg, None)
         self.kwargs = kwargs
-        if tqdm is None:
+
+        # description
+        if desc is not None:
+            assert isinstance(desc, (tuple, list)), 'Description should be a tuple or list.'
+            assert isinstance(desc[0], str), 'Description should be a string.'
+            assert callable(desc[1]), 'Description should be a callable.'
+        self.desc = desc
+
+        # check if tqdm is installed
+        if not tqdm_installed:
            raise ImportError("tqdm is not installed.")
 
     def init(self, n: int):
@@ -67,15 +92,22 @@ class ProgressBar(object):
            raise ValueError("Print rate should be less than the "
                             f"number of steps {n}, got {freq}")
         remainder = n % freq
-        desc = kwargs.pop("desc", f"Running for {n:,} iterations")
-        message = kwargs.pop("message", desc)
-        return ProgressBarRunner(n, message, freq, remainder, **kwargs)
+
+        message = f"Running for {n:,} iterations" if self.desc is None else self.desc
+        return ProgressBarRunner(n, freq, remainder, message, **kwargs)
 
 
 class ProgressBarRunner(object):
     __module__ = "brainstate.compile"
 
-    def __init__(self, n: int, message, print_freq: int, remainder: int, **kwargs):
+    def __init__(
+        self,
+        n: int,
+        print_freq: int,
+        remainder: int,
+        message: str | Tuple[str, Callable[[Dict], Dict]],
+        **kwargs
+    ):
         self.tqdm_bars = {}
         self.kwargs = kwargs
         self.n = n
@@ -83,50 +115,46 @@ class ProgressBarRunner(object):
         self.remainder = remainder
         self.message = message
 
-    def _define_tqdm(self):
+    def _define_tqdm(self, x: dict):
+        from tqdm.auto import tqdm
         self.tqdm_bars[0] = tqdm(range(self.n), **self.kwargs)
-        self.tqdm_bars[0].set_description(self.message, refresh=False)
+        if isinstance(self.message, str):
+            self.tqdm_bars[0].set_description(self.message, refresh=False)
+        else:
+            self.tqdm_bars[0].set_description(self.message[0].format(**x), refresh=True)
 
-    def _update_tqdm(self):
+    def _update_tqdm(self, x: dict):
         self.tqdm_bars[0].update(self.print_freq)
+        if not isinstance(self.message, str):
+            self.tqdm_bars[0].set_description(self.message[0].format(**x), refresh=True)
 
-    def _close_tqdm(self):
+    def _close_tqdm(self, x: dict):
         if self.remainder > 0:
            self.tqdm_bars[0].update(self.remainder)
+        if not isinstance(self.message, str):
+            self.tqdm_bars[0].set_description(self.message[0].format(**x), refresh=True)
         self.tqdm_bars[0].close()
 
-    def _tqdm(self, is_init, is_print, is_final):
-        if is_init:
-            self.tqdm_bars[0] = tqdm(range(self.n), **self.kwargs)
-            self.tqdm_bars[0].set_description(self.message, refresh=False)
-        if is_print:
-            self.tqdm_bars[0].update(self.print_freq)
-        if is_final:
-            if self.remainder > 0:
-                self.tqdm_bars[0].update(self.remainder)
-            self.tqdm_bars[0].close()
-
-    def __call__(self, iter_num, *args, **kwargs):
-        # jax.debug.callback(
-        #     self._tqdm,
-        #     iter_num == 0,
-        #     (iter_num + 1) % self.print_freq == 0,
-        #     iter_num == self.n - 1
-        # )
+    def __call__(self, iter_num, **kwargs):
+        data = dict(i=iter_num, **kwargs)
+        data = dict() if isinstance(self.message, str) else self.message[1](data)
+        assert isinstance(data, dict), 'Description function should return a dictionary.'
 
         _ = jax.lax.cond(
            iter_num == 0,
-            lambda: jax.debug.callback(self._define_tqdm, ordered=True),
-            lambda: None,
+            lambda x: jax.debug.callback(self._define_tqdm, x, ordered=True),
+            lambda x: None,
+            data
         )
         _ = jax.lax.cond(
            iter_num % self.print_freq == (self.print_freq - 1),
-            lambda: jax.debug.callback(self._update_tqdm, ordered=True),
-            lambda: None,
+            lambda x: jax.debug.callback(self._update_tqdm, x, ordered=True),
+            lambda x: None,
+            data
         )
         _ = jax.lax.cond(
            iter_num == self.n - 1,
-            lambda: jax.debug.callback(self._close_tqdm, ordered=True),
-            lambda: None,
+            lambda x: jax.debug.callback(self._close_tqdm, x, ordered=True),
+            lambda x: None,
+            data
         )
-
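
The hunks above add a templated description: `desc` is a `(format_string, transform)` pair, where the transform receives the runner's data dict (at least `i=iter_num`, plus whatever the loop passes as keyword arguments) and returns the dict that fills the template. A minimal usage sketch, assuming the bar is wired into a jitted loop as before (the loop entry point and reported keys are not shown in this diff):

    import brainstate

    # static description, refreshed every 100 steps
    pbar = brainstate.compile.ProgressBar(freq=100)

    # templated description: the callable maps dict(i=iter_num, **kwargs)
    # to the keyword arguments of the format string
    pbar = brainstate.compile.ProgressBar(
        freq=100,
        desc=("step {i}", lambda data: {'i': data['i']}),
    )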
brainstate/compile/_unvmap.py CHANGED
@@ -16,13 +16,16 @@ from __future__ import annotations
 
 import jax
 import jax.core
-import jax.extend as je
 import jax.interpreters.batching as batching
 import jax.interpreters.mlir as mlir
 import jax.numpy as jnp
-
 from brainstate._utils import set_module_as
 
+if jax.__version_info__ < (0, 4, 38):
+    from jax.core import Primitive
+else:
+    from jax.extend.core import Primitive
+
 __all__ = [
     "unvmap",
 ]
@@ -44,7 +47,7 @@ def unvmap(x, op: str = 'any'):
 
 # unvmap_all
 
-unvmap_all_p = je.core.Primitive("unvmap_all")
+unvmap_all_p = Primitive("unvmap_all")
 
 
 def unvmap_all(x):
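
The version guard above tracks the move of `Primitive` from `jax.core` to `jax.extend.core` around JAX 0.4.38. Judging by the `'any'`/`'all'` reduction names, `unvmap` collapses a batched boolean into a scalar that can steer control flow or host callbacks under `jax.vmap`. A hedged usage sketch (the import path follows `__all__` and `set_module_as`; the exact semantics are inferred, not shown in this hunk):

    import jax
    import jax.numpy as jnp
    from brainstate.compile import unvmap  # assumed export path

    def f(x):
        flag = unvmap(x > 0., op='any')  # reduces across the batch to a scalar
        return jnp.where(flag, x, -x)

    jax.vmap(f)(jnp.arange(-2., 3.))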
brainstate/event/__init__.py CHANGED
@@ -15,13 +15,11 @@
 
 
 from ._csr import *
-from ._csr_mv import *
 from ._fixedprob_mv import *
 from ._linear_mv import *
 from ._xla_custom_op import *
 
 __all__ = [
-    'CSRLinear',
     'FixedProb',
     'XLACustomOp',
     'CSR',
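
`CSRLinear` and the `_csr_mv` module are dropped from the public API, leaving the `CSR` class below as the event-driven CSR entry point. A hedged migration sketch (shapes and values illustrative; the `(data, indices, indptr)` constructor mirrors the `CSC(...)` call visible in the `_csr.py` hunks below):

    import jax.numpy as jnp
    from brainstate.event import CSR

    data    = jnp.array([1., 2., 3.])
    indices = jnp.array([0, 2, 1])   # column indices per row
    indptr  = jnp.array([0, 2, 3])   # row pointers: row 0 -> [0, 2), row 1 -> [2, 3)
    csr = CSR((data, indices, indptr), shape=(2, 3))
    y = csr @ jnp.ones(3)            # matvec via __matmul__ below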
brainstate/event/_csr.py CHANGED
@@ -73,7 +73,7 @@ class CSR(u.sparse.SparseMatrix):
         return u.sparse.csr_todense(self)
 
     def transpose(self, axes=None):
-        assert axes is None
+        assert axes is None, "transpose does not support axes argument."
         return CSC((self.data, self.indices, self.indptr), shape=self.shape[::-1])
 
     def __abs__(self):
@@ -103,6 +103,7 @@ class CSR(u.sparse.SparseMatrix):
                (op(self.data, other), self.indices, self.indptr),
                shape=self.shape
            )
+
        elif other.ndim == 2 and other.shape == self.shape:
            rows, cols = csr_to_coo(self.indices, self.indptr)
            other = other[rows, cols]
@@ -112,6 +113,7 @@ class CSR(u.sparse.SparseMatrix):
                 self.indptr),
                shape=self.shape
            )
+
        else:
            raise NotImplementedError(f"mul with object of shape {other.shape}")
 
@@ -184,10 +186,12 @@ class CSR(u.sparse.SparseMatrix):
         return self._binary_rop(other, operator.mod)
 
     def __matmul__(self, other):
+        # csr @ other
         if isinstance(other, JAXSparse):
            raise NotImplementedError("matmul between two sparse objects.")
         other = u.math.asarray(other)
-        data, other = u.math.promote_dtypes(self.data, other)
+        data = self.data
+        # data, other = u.math.promote_dtypes(self.data, other)
         if other.ndim == 1:
            return _csr_matvec(
                data,
@@ -208,10 +212,12 @@ class CSR(u.sparse.SparseMatrix):
            raise NotImplementedError(f"matmul with object of shape {other.shape}")
 
     def __rmatmul__(self, other):
+        # other @ csr
         if isinstance(other, JAXSparse):
            raise NotImplementedError("matmul between two sparse objects.")
         other = u.math.asarray(other)
-        data, other = u.math.promote_dtypes(self.data, other)
+        data = self.data
+        # data, other = u.math.promote_dtypes(self.data, other)
         if other.ndim == 1:
            return _csr_matvec(
                data,
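
Leaving `promote_dtypes` commented out keeps operand dtypes intact through dispatch, which is what lets the event kernels below specialise on `spike_info.dtype == jnp.bool_`; promoting a boolean spike vector to the weight dtype would erase that event structure. A small illustration of what the removed promotion did:

    import jax.numpy as jnp

    weights = jnp.asarray([0.5, 1.5], dtype=jnp.float32)
    spikes  = jnp.asarray([True, False])

    # the removed call would have promoted both operands to float32:
    assert jnp.promote_types(weights.dtype, spikes.dtype) == jnp.float32
    # without promotion, the bool dtype survives and selects the event kernels
    assert spikes.dtype == jnp.bool_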
@@ -566,7 +572,7 @@ def event_csrmv_cpu_kernel_generator(
     if weight_info.size == 1:
         if transpose:
            if spike_info.dtype == jnp.bool_:
-                @numba.njit
+                @numba.njit(fastmath=True)
                def mv(weights, indices, indptr, v, posts):
                    posts[:] = 0.
                    w = weights[()]
@@ -576,7 +582,7 @@ def event_csrmv_cpu_kernel_generator(
                                posts[indices[j]] += w
 
            elif float_as_event:
-                @numba.njit
+                @numba.njit(fastmath=True)
                def mv(weights, indices, indptr, v, posts):
                    posts[:] = 0.
                    w = weights[()]
@@ -586,7 +592,7 @@ def event_csrmv_cpu_kernel_generator(
                                posts[indices[j]] += w
 
            else:
-                @numba.njit
+                @numba.njit(fastmath=True)
                def mv(weights, indices, indptr, v, posts):
                    posts[:] = 0.
                    w = weights[()]
@@ -599,7 +605,7 @@ def event_csrmv_cpu_kernel_generator(
 
         else:
            if spike_info.dtype == jnp.bool_:
-                @numba.njit
+                @numba.njit(fastmath=True)
                def mv(weights, indices, indptr, v, posts):
                    w = weights[()]
                    for i in range(indptr.shape[0] - 1):
@@ -610,7 +616,7 @@ def event_csrmv_cpu_kernel_generator(
                        posts[i] = r
 
            elif float_as_event:
-                @numba.njit
+                @numba.njit(fastmath=True)
                def mv(weights, indices, indptr, v, posts):
                    w = weights[()]
                    for i in range(indptr.shape[0] - 1):
@@ -621,7 +627,7 @@ def event_csrmv_cpu_kernel_generator(
                        posts[i] = r
 
            else:
-                @numba.njit
+                @numba.njit(fastmath=True)
                def mv(weights, indices, indptr, v, posts):
                    w = weights[()]
                    for i in range(indptr.shape[0] - 1):
@@ -635,7 +641,7 @@ def event_csrmv_cpu_kernel_generator(
     else:
         if transpose:
            if spike_info.dtype == jnp.bool_:
-                @numba.njit
+                @numba.njit(fastmath=True)
                def mv(weights, indices, indptr, v, posts):
                    posts[:] = 0.
                    for i in range(v.shape[0]):
@@ -644,7 +650,7 @@ def event_csrmv_cpu_kernel_generator(
                                posts[indices[j]] += weights[j]
 
            elif float_as_event:
-                @numba.njit
+                @numba.njit(fastmath=True)
                def mv(weights, indices, indptr, v, posts):
                    posts[:] = 0.
                    for i in range(v.shape[0]):
@@ -653,7 +659,7 @@ def event_csrmv_cpu_kernel_generator(
                                posts[indices[j]] += weights[j]
 
            else:
-                @numba.njit
+                @numba.njit(fastmath=True)
                def mv(weights, indices, indptr, v, posts):
                    posts[:] = 0.
                    for i in range(v.shape[0]):
@@ -664,7 +670,7 @@ def event_csrmv_cpu_kernel_generator(
 
         else:
            if spike_info.dtype == jnp.bool_:
-                @numba.njit
+                @numba.njit(fastmath=True)
                def mv(weights, indices, indptr, v, posts):
                    for i in range(indptr.shape[0] - 1):
                        r = 0.
@@ -674,7 +680,7 @@ def event_csrmv_cpu_kernel_generator(
                        posts[i] = r
 
            elif float_as_event:
-                @numba.njit
+                @numba.njit(fastmath=True)
                def mv(weights, indices, indptr, v, posts):
                    for i in range(indptr.shape[0] - 1):
                        r = 0.
@@ -684,7 +690,7 @@ def event_csrmv_cpu_kernel_generator(
                        posts[i] = r
 
            else:
-                @numba.njit
+                @numba.njit(fastmath=True)
                def mv(weights, indices, indptr, v, posts):
                    for i in range(indptr.shape[0] - 1):
                        r = 0.
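
Every CPU kernel above now passes `fastmath=True` to `numba.njit`, which permits LLVM to reassociate the floating-point accumulations (trading strict IEEE ordering for SIMD-friendly reductions). A standalone sketch of the same decorator on a reduction loop:

    import numba
    import numpy as np

    @numba.njit(fastmath=True)  # allows FP reassociation, so the loop can vectorise
    def total(w):
        r = 0.
        for i in range(w.shape[0]):
            r += w[i]
        return r

    total(np.ones(1024))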
@@ -795,7 +801,31 @@ def event_csrmv_transpose_rule(
 
 def event_csrmv_batching(args, axes, **kwargs):
     if tuple(axes) == (None, None, None, 0):
-        return 0, event_csrmm_p_call(*args, **kwargs)
+        assert args[3].ndim == 2, 'Batching axis 0 requires 2D input.'
+        r = event_csrmm_p_call(
+            args[0],
+            args[1],
+            args[2],
+            args[3].T,
+            shape=kwargs['shape'],
+            transpose=kwargs['transpose'],
+            float_as_event=kwargs['float_as_event']
+        )
+        return r, [1]
+
+    elif tuple(axes) == (None, None, None, 1):
+        assert args[3].ndim == 2, 'Batching axis 1 requires 2D input.'
+        r = event_csrmm_p_call(
+            args[0],
+            args[1],
+            args[2],
+            args[3],
+            shape=kwargs['shape'],
+            transpose=kwargs['transpose'],
+            float_as_event=kwargs['float_as_event']
+        )
+        return r, [1]
+
     else:
        raise NotImplementedError(f"Batching axes {axes} not implemented for event-driven CSR matrix-vector product.")
 
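
The rewritten rule maps a vmapped matvec onto a single matmul: batching over axis 0 of the vector uses the dense identity `stack([A @ v for v in V]) == (A @ V.T).T`, which is why the rule transposes the batched operand and reports the result batch on axis 1. A dense sanity check of that identity:

    import jax.numpy as jnp

    A = jnp.arange(6.).reshape(2, 3)
    V = jnp.ones((5, 3))                    # five vectors, batch on axis 0
    looped = jnp.stack([A @ v for v in V])  # (5, 2)
    fused  = A @ V.T                        # (2, 5): batch lands on axis 1
    assert jnp.allclose(looped.T, fused)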
@@ -852,17 +882,228 @@ def event_csrmv_p_call(
 
 def event_csrmm_batching(args, axes, **kwargs):
     if tuple(axes) == (None, None, None, 0):
-        batch_shape = args[3].shape[:-1]
-        B = jnp.reshape(args[3], (-1, args[3].shape[-1:]))
-        r = event_csrmm_p_call(args[0], args[1], args[2], B, **kwargs)
-        return 0, [jnp.reshape(r[0], batch_shape + r.shape[-1:])]
+        assert args[3].ndim == 3, 'Batching axis 0 requires 3D input.'
+        batch_size, m, n = args[3].shape
+        B = jnp.transpose(args[3], (1, 0, 2)).reshape(m, batch_size * n)
+        r = event_csrmm_p_call(
+            args[0],
+            args[1],
+            args[2],
+            B,
+            shape=kwargs['shape'],
+            transpose=kwargs['transpose'],
+            float_as_event=kwargs['float_as_event']
+        )
+        r = jnp.reshape(r[0], [r[0].shape[0], batch_size, n])
+        return [r], [1]
+
+    elif tuple(axes) == (None, None, None, 1):
+        assert args[3].ndim == 3, 'Batching axis 1 requires 3D input.'
+        m, batch_size, n = args[3].shape
+        B = args[3].reshape(m, batch_size * n)
+        r = event_csrmm_p_call(
+            args[0],
+            args[1],
+            args[2],
+            B,
+            shape=kwargs['shape'],
+            transpose=kwargs['transpose'],
+            float_as_event=kwargs['float_as_event']
+        )
+        r = jnp.reshape(r[0], [r[0].shape[0], batch_size, n])
+        return [r], [1]
+
+    elif tuple(axes) == (None, None, None, 2):
+        assert args[3].ndim == 3, 'Batching axis 2 requires 3D input.'
+        m, n, batch_size = args[3].shape
+        B = args[3].reshape(m, batch_size * n)
+        r = event_csrmm_p_call(
+            args[0],
+            args[1],
+            args[2],
+            B,
+            shape=kwargs['shape'],
+            transpose=kwargs['transpose'],
+            float_as_event=kwargs['float_as_event']
+        )
+        r = jnp.reshape(r[0], [r[0].shape[0], n, batch_size])
+        return [r], [2]
+
     else:
        raise NotImplementedError(f"Batching axes {axes} not implemented for event-driven CSR matrix-vector product.")
 
 
+def event_csrmm_cpu_kernel_generator(
+    float_as_event: bool,
+    weight_info: jax.ShapeDtypeStruct,
+    spike_info: jax.ShapeDtypeStruct,
+    transpose: bool,
+    **kwargs
+) -> Kernel:
+    import numba  # pylint: disable=import-outside-toplevel
+
+    if weight_info.size == 1:
+        if transpose:
+            # csr.T @ B
+
+            if spike_info.dtype == jnp.bool_:
+                @numba.njit(fastmath=True, parallel=False)
+                def mv(weights, indices, indptr, B, posts):
+                    posts[:] = 0.
+                    w = weights[()]
+                    for k in numba.prange(B.shape[1]):
+                        for i in range(B.shape[0]):
+                            if B[i, k]:
+                                for j in range(indptr[i], indptr[i + 1]):
+                                    posts[indices[j], k] += w
+
+            elif float_as_event:
+                @numba.njit(fastmath=True, parallel=False)
+                def mv(weights, indices, indptr, B, posts):
+                    posts[:] = 0.
+                    B = B != 0.
+                    w = weights[()]
+                    for k in numba.prange(B.shape[1]):
+                        for i in range(B.shape[0]):
+                            if B[i, k]:
+                                for j in range(indptr[i], indptr[i + 1]):
+                                    posts[indices[j], k] += w
+
+            else:
+                @numba.njit(fastmath=True, parallel=False)
+                def mv(weights, indices, indptr, B, posts):
+                    posts[:] = 0.
+                    w = weights[()]
+                    for k in numba.prange(B.shape[1]):
+                        for i in range(B.shape[0]):
+                            sp = B[i, k]
+                            if sp != 0.:
+                                wsp = w * sp
+                                for j in range(indptr[i], indptr[i + 1]):
+                                    posts[indices[j], k] += wsp
+
+        else:
+            # csr @ B
+            if spike_info.dtype == jnp.bool_:
+                @numba.njit(fastmath=True)
+                def mv(weights, indices, indptr, B, posts):
+                    w = weights[()]
+                    for i in range(indptr.shape[0] - 1):
+                        r = np.zeros(B.shape[1], dtype=weights.dtype)
+                        for j in range(indptr[i], indptr[i + 1]):
+                            index = indices[j]
+                            for k in range(B.shape[1]):
+                                if B[index, k]:
+                                    r[k] += w
+                        posts[i] = r
+
+            elif float_as_event:
+                @numba.njit(fastmath=True)
+                def mv(weights, indices, indptr, B, posts):
+                    w = weights[()]
+                    B = B != 0.
+                    for i in range(indptr.shape[0] - 1):
+                        r = np.zeros(B.shape[1], dtype=weights.dtype)
+                        for j in range(indptr[i], indptr[i + 1]):
+                            index = indices[j]
+                            for k in range(B.shape[1]):
+                                if B[index, k]:
+                                    r[k] += w
+                        posts[i] = r
+
+            else:
+                @numba.njit(fastmath=True)
+                def mv(weights, indices, indptr, B, posts):
+                    w = weights[()]
+                    for i in range(indptr.shape[0] - 1):
+                        for k in range(B.shape[1]):
+                            r = 0.
+                            for j in range(indptr[i], indptr[i + 1]):
+                                c = B[indices[j], k]
+                                if c != 0.:
+                                    r += w * c
+                            posts[i, k] = r
+
+    else:
+        if transpose:
+            # csr.T @ B
+
+            if spike_info.dtype == jnp.bool_:
+                @numba.njit(fastmath=True, parallel=False)
+                def mv(weights, indices, indptr, B, posts):
+                    posts[:] = 0.
+                    for k in numba.prange(B.shape[1]):
+                        for i in range(B.shape[0]):
+                            if B[i, k]:
+                                for j in range(indptr[i], indptr[i + 1]):
+                                    posts[indices[j], k] += weights[j]
+
+            elif float_as_event:
+                @numba.njit(fastmath=True, parallel=False)
+                def mv(weights, indices, indptr, B, posts):
+                    posts[:] = 0.
+                    B = B != 0.
+                    for k in numba.prange(B.shape[1]):
+                        for i in range(B.shape[0]):
+                            if B[i, k]:
+                                for j in range(indptr[i], indptr[i + 1]):
+                                    posts[indices[j], k] += weights[j]
+
+            else:
+                @numba.njit(fastmath=True, parallel=False)
+                def mv(weights, indices, indptr, B, posts):
+                    posts[:] = 0.
+                    for k in numba.prange(B.shape[1]):
+                        for i in range(B.shape[0]):
+                            sp = B[i, k]
+                            if sp != 0.:
+                                for j in range(indptr[i], indptr[i + 1]):
+                                    posts[indices[j], k] += weights[j] * sp
+
+        else:
+            # csr @ B
+
+            if spike_info.dtype == jnp.bool_:
+                @numba.njit(fastmath=True)
+                def mv(weights, indices, indptr, B, posts):
+                    for i in range(indptr.shape[0] - 1):
+                        for k in range(B.shape[1]):
+                            r = 0.
+                            for j in range(indptr[i], indptr[i + 1]):
+                                if B[indices[j], k]:
+                                    r += weights[j]
+                            posts[i, k] = r
+
+            elif float_as_event:
+                @numba.njit(fastmath=True)
+                def mv(weights, indices, indptr, B, posts):
+                    B = B != 0.
+                    for i in range(indptr.shape[0] - 1):
+                        for k in range(B.shape[1]):
+                            r = 0.
+                            for j in range(indptr[i], indptr[i + 1]):
+                                if B[indices[j], k]:
+                                    r += weights[j]
+                            posts[i, k] = r
+
+            else:
+                @numba.njit(fastmath=True)
+                def mv(weights, indices, indptr, B, posts):
+                    for i in range(indptr.shape[0] - 1):
+                        for k in range(B.shape[1]):
+                            r = 0.
+                            for j in range(indptr[i], indptr[i + 1]):
+                                c = B[indices[j], k]
+                                if c != 0.:
+                                    r += weights[j] * c
+                            posts[i, k] = r
+
+    return mv
+
+
 event_csrmm_p = XLACustomOp(
     'event_csrmm',
-    cpu_kernel_or_generator=event_csrmv_cpu_kernel_generator,
+    cpu_kernel_or_generator=event_csrmm_cpu_kernel_generator,
 )
 event_csrmm_p.def_batching_rule(event_csrmm_batching)
 
@@ -884,11 +1125,13 @@ def event_csrmm_p_call(
         indptr,
         B,
         outs=[
-            jax.ShapeDtypeStruct([shape[0], B.shape[1]], weights.dtype)
+            jax.ShapeDtypeStruct([shape[1], B.shape[1]], weights.dtype)
            if transpose else
-            jax.ShapeDtypeStruct([shape[1], B.shape[1]], weights.dtype),
+            jax.ShapeDtypeStruct([shape[0], B.shape[1]], weights.dtype),
         ],
         # block_size=block_size,
+        shape=shape,
+        transpose=transpose,
         float_as_event=float_as_event,
         weight_info=jax.ShapeDtypeStruct(weights.shape, weights.dtype),
         spike_info=jax.ShapeDtypeStruct(B.shape, B.dtype),
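
The swapped `outs` entries fix the declared result shape: for a CSR operand of logical shape `(m, n) == shape`, `transpose=True` computes `csr.T @ B` and so must produce `shape[1]` rows, while `transpose=False` produces `shape[0]` rows. A dense shape check:

    import jax.numpy as jnp

    m, n, k = 2, 3, 4
    A = jnp.ones((m, n))  # dense stand-in for a CSR of shape (m, n)
    assert (A @ jnp.ones((n, k))).shape == (m, k)    # transpose=False -> (shape[0], k)
    assert (A.T @ jnp.ones((m, k))).shape == (n, k)  # transpose=True  -> (shape[1], k)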