brainstate 0.1.0.post20250104__py2.py3-none-any.whl → 0.1.0.post20250120__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. brainstate/_state.py +77 -44
  2. brainstate/_state_test.py +0 -17
  3. brainstate/augment/_eval_shape.py +9 -10
  4. brainstate/augment/_eval_shape_test.py +1 -1
  5. brainstate/augment/_mapping.py +265 -277
  6. brainstate/augment/_mapping_test.py +147 -175
  7. brainstate/compile/_ad_checkpoint.py +6 -4
  8. brainstate/compile/_error_if_test.py +1 -0
  9. brainstate/compile/_jit.py +37 -28
  10. brainstate/compile/_loop_collect_return.py +8 -5
  11. brainstate/compile/_loop_no_collection.py +2 -0
  12. brainstate/compile/_make_jaxpr.py +7 -3
  13. brainstate/compile/_make_jaxpr_test.py +2 -1
  14. brainstate/compile/_progress_bar.py +68 -40
  15. brainstate/compile/_unvmap.py +6 -2
  16. brainstate/environ.py +28 -18
  17. brainstate/environ_test.py +4 -0
  18. brainstate/event/__init__.py +0 -2
  19. brainstate/event/_csr.py +266 -23
  20. brainstate/event/_csr_test.py +187 -0
  21. brainstate/event/_fixedprob_mv.py +4 -2
  22. brainstate/event/_fixedprob_mv_test.py +2 -1
  23. brainstate/event/_xla_custom_op.py +16 -5
  24. brainstate/graph/__init__.py +8 -12
  25. brainstate/graph/_graph_node.py +1 -23
  26. brainstate/graph/_graph_operation.py +1 -1
  27. brainstate/graph/_graph_operation_test.py +0 -159
  28. brainstate/nn/_dyn_impl/_inputs.py +124 -39
  29. brainstate/nn/_interaction/_conv.py +4 -2
  30. brainstate/nn/_interaction/_linear.py +84 -10
  31. brainstate/random/_rand_funs.py +9 -2
  32. brainstate/random/_rand_seed.py +12 -2
  33. brainstate/random/_rand_state.py +50 -179
  34. brainstate/surrogate.py +5 -1
  35. brainstate/util/__init__.py +0 -4
  36. brainstate/util/_caller.py +1 -1
  37. brainstate/util/_dict.py +4 -1
  38. brainstate/util/_filter.py +1 -1
  39. brainstate/util/_pretty_repr.py +1 -1
  40. brainstate/util/_struct.py +1 -1
  41. {brainstate-0.1.0.post20250104.dist-info → brainstate-0.1.0.post20250120.dist-info}/METADATA +2 -1
  42. {brainstate-0.1.0.post20250104.dist-info → brainstate-0.1.0.post20250120.dist-info}/RECORD +46 -52
  43. brainstate/event/_csr_mv_test.py +0 -118
  44. brainstate/graph/_graph_context.py +0 -443
  45. brainstate/graph/_graph_context_test.py +0 -65
  46. brainstate/graph/_graph_convert.py +0 -246
  47. brainstate/util/_tracers.py +0 -68
  48. brainstate/util/_visualization.py +0 -47
  49. /brainstate/event/{_csr_mv_benchmark.py → _csr_benchmark.py} +0 -0
  50. {brainstate-0.1.0.post20250104.dist-info → brainstate-0.1.0.post20250120.dist-info}/LICENSE +0 -0
  51. {brainstate-0.1.0.post20250104.dist-info → brainstate-0.1.0.post20250120.dist-info}/WHEEL +0 -0
  52. {brainstate-0.1.0.post20250104.dist-info → brainstate-0.1.0.post20250120.dist-info}/top_level.txt +0 -0
brainstate/event/_csr.py CHANGED
@@ -73,7 +73,7 @@ class CSR(u.sparse.SparseMatrix):
         return u.sparse.csr_todense(self)
 
     def transpose(self, axes=None):
-        assert axes is None
+        assert axes is None, "transpose does not support axes argument."
         return CSC((self.data, self.indices, self.indptr), shape=self.shape[::-1])
 
     def __abs__(self):
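Note: `transpose` relies on the standard identity that a CSR matrix's buffers, reinterpreted as CSC with the shape reversed, are exactly its transpose. A minimal dense-analogue check of that identity (illustrative only, using scipy rather than brainstate):

import numpy as np
import scipy.sparse as sp

A = sp.random(4, 6, density=0.4, format='csr')
# Reuse the CSR buffers as CSC with reversed shape -- same trick as CSR.transpose above.
At = sp.csc_matrix((A.data, A.indices, A.indptr), shape=A.shape[::-1])
print(np.allclose(At.toarray(), A.toarray().T))  # True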
@@ -103,6 +103,7 @@ class CSR(u.sparse.SparseMatrix):
                 (op(self.data, other), self.indices, self.indptr),
                 shape=self.shape
             )
+
         elif other.ndim == 2 and other.shape == self.shape:
             rows, cols = csr_to_coo(self.indices, self.indptr)
             other = other[rows, cols]
@@ -112,6 +113,7 @@ class CSR(u.sparse.SparseMatrix):
                  self.indptr),
                 shape=self.shape
             )
+
         else:
             raise NotImplementedError(f"mul with object of shape {other.shape}")
 
@@ -184,10 +186,12 @@ class CSR(u.sparse.SparseMatrix):
         return self._binary_rop(other, operator.mod)
 
     def __matmul__(self, other):
+        # csr @ other
         if isinstance(other, JAXSparse):
             raise NotImplementedError("matmul between two sparse objects.")
         other = u.math.asarray(other)
-        data, other = u.math.promote_dtypes(self.data, other)
+        data = self.data
+        # data, other = u.math.promote_dtypes(self.data, other)
         if other.ndim == 1:
             return _csr_matvec(
                 data,
@@ -208,10 +212,12 @@ class CSR(u.sparse.SparseMatrix):
             raise NotImplementedError(f"matmul with object of shape {other.shape}")
 
     def __rmatmul__(self, other):
+        # other @ csr
        if isinstance(other, JAXSparse):
            raise NotImplementedError("matmul between two sparse objects.")
        other = u.math.asarray(other)
-        data, other = u.math.promote_dtypes(self.data, other)
+        data = self.data
+        # data, other = u.math.promote_dtypes(self.data, other)
        if other.ndim == 1:
            return _csr_matvec(
                data,
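Note: dropping `u.math.promote_dtypes` here looks deliberate for event-driven kernels: the right-hand side is often a boolean spike vector, and dtype promotion would cast it to the weight dtype, bypassing the `spike_info.dtype == jnp.bool_` fast paths in the kernel generators below. A one-line illustration of the promotion rule (plain jax.numpy, not brainstate code):

import jax.numpy as jnp

spk = jnp.array([True, False, True])              # boolean spike events
# Promoting (float32, bool) yields float32, i.e. the spikes would be upcast
# to floats and the boolean kernel branch would never be taken.
print(jnp.promote_types(jnp.float32, spk.dtype))  # float32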
@@ -566,7 +572,7 @@ def event_csrmv_cpu_kernel_generator(
     if weight_info.size == 1:
         if transpose:
             if spike_info.dtype == jnp.bool_:
-                @numba.njit
+                @numba.njit(fastmath=True)
                 def mv(weights, indices, indptr, v, posts):
                     posts[:] = 0.
                     w = weights[()]
@@ -576,7 +582,7 @@ def event_csrmv_cpu_kernel_generator(
                                posts[indices[j]] += w
 
             elif float_as_event:
-                @numba.njit
+                @numba.njit(fastmath=True)
                 def mv(weights, indices, indptr, v, posts):
                     posts[:] = 0.
                     w = weights[()]
@@ -586,7 +592,7 @@ def event_csrmv_cpu_kernel_generator(
                                posts[indices[j]] += w
 
             else:
-                @numba.njit
+                @numba.njit(fastmath=True)
                 def mv(weights, indices, indptr, v, posts):
                     posts[:] = 0.
                     w = weights[()]
@@ -599,7 +605,7 @@ def event_csrmv_cpu_kernel_generator(
 
         else:
             if spike_info.dtype == jnp.bool_:
-                @numba.njit
+                @numba.njit(fastmath=True)
                 def mv(weights, indices, indptr, v, posts):
                     w = weights[()]
                     for i in range(indptr.shape[0] - 1):
@@ -610,7 +616,7 @@ def event_csrmv_cpu_kernel_generator(
                         posts[i] = r
 
             elif float_as_event:
-                @numba.njit
+                @numba.njit(fastmath=True)
                 def mv(weights, indices, indptr, v, posts):
                     w = weights[()]
                     for i in range(indptr.shape[0] - 1):
@@ -621,7 +627,7 @@ def event_csrmv_cpu_kernel_generator(
                         posts[i] = r
 
             else:
-                @numba.njit
+                @numba.njit(fastmath=True)
                 def mv(weights, indices, indptr, v, posts):
                     w = weights[()]
                     for i in range(indptr.shape[0] - 1):
@@ -635,7 +641,7 @@ def event_csrmv_cpu_kernel_generator(
     else:
         if transpose:
             if spike_info.dtype == jnp.bool_:
-                @numba.njit
+                @numba.njit(fastmath=True)
                 def mv(weights, indices, indptr, v, posts):
                     posts[:] = 0.
                     for i in range(v.shape[0]):
@@ -644,7 +650,7 @@ def event_csrmv_cpu_kernel_generator(
                                posts[indices[j]] += weights[j]
 
             elif float_as_event:
-                @numba.njit
+                @numba.njit(fastmath=True)
                 def mv(weights, indices, indptr, v, posts):
                     posts[:] = 0.
                     for i in range(v.shape[0]):
@@ -653,7 +659,7 @@ def event_csrmv_cpu_kernel_generator(
                                posts[indices[j]] += weights[j]
 
             else:
-                @numba.njit
+                @numba.njit(fastmath=True)
                 def mv(weights, indices, indptr, v, posts):
                     posts[:] = 0.
                     for i in range(v.shape[0]):
@@ -664,7 +670,7 @@ def event_csrmv_cpu_kernel_generator(
 
         else:
             if spike_info.dtype == jnp.bool_:
-                @numba.njit
+                @numba.njit(fastmath=True)
                 def mv(weights, indices, indptr, v, posts):
                     for i in range(indptr.shape[0] - 1):
                         r = 0.
@@ -674,7 +680,7 @@ def event_csrmv_cpu_kernel_generator(
                         posts[i] = r
 
             elif float_as_event:
-                @numba.njit
+                @numba.njit(fastmath=True)
                 def mv(weights, indices, indptr, v, posts):
                     for i in range(indptr.shape[0] - 1):
                         r = 0.
@@ -684,7 +690,7 @@ def event_csrmv_cpu_kernel_generator(
                         posts[i] = r
 
             else:
-                @numba.njit
+                @numba.njit(fastmath=True)
                 def mv(weights, indices, indptr, v, posts):
                     for i in range(indptr.shape[0] - 1):
                         r = 0.
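Note: `fastmath=True` lets Numba/LLVM relax IEEE-754 ordering, so the scalar accumulations in these kernels can be reassociated and vectorized; results may differ in the last few bits. A standalone sketch of the loop pattern being compiled (illustrative, not brainstate code):

import numba
import numpy as np

@numba.njit(fastmath=True)
def masked_dot(w, v):
    # Same loop shape as the kernels above: accumulate only where v is nonzero.
    r = 0.
    for i in range(v.shape[0]):
        if v[i] != 0.:
            r += w[i] * v[i]
    return r

print(masked_dot(np.ones(8), np.arange(8.0)))  # 28.0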
@@ -795,7 +801,31 @@ def event_csrmv_transpose_rule(
 
 def event_csrmv_batching(args, axes, **kwargs):
     if tuple(axes) == (None, None, None, 0):
-        return 0, event_csrmm_p_call(*args, **kwargs)
+        assert args[3].ndim == 2, 'Batching axis 0 requires 2D input.'
+        r = event_csrmm_p_call(
+            args[0],
+            args[1],
+            args[2],
+            args[3].T,
+            shape=kwargs['shape'],
+            transpose=kwargs['transpose'],
+            float_as_event=kwargs['float_as_event']
+        )
+        return r, [1]
+
+    elif tuple(axes) == (None, None, None, 1):
+        assert args[3].ndim == 2, 'Batching axis 1 requires 2D input.'
+        r = event_csrmm_p_call(
+            args[0],
+            args[1],
+            args[2],
+            args[3],
+            shape=kwargs['shape'],
+            transpose=kwargs['transpose'],
+            float_as_event=kwargs['float_as_event']
+        )
+        return r, [1]
+
     else:
         raise NotImplementedError(f"Batching axes {axes} not implemented for event-driven CSR matrix-vector product.")
 
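Note: the rewritten batching rule lowers a `vmap` over the spike argument to a single `event_csrmm` call instead of a Python loop; an axis-0 batch is transposed into column layout, so the batch dimension of the result sits on axis 1 (hence `return r, [1]`). The dense analogue of the identity being exploited (illustrative):

import jax
import jax.numpy as jnp

A = jnp.arange(12.).reshape(3, 4)
V = jnp.ones((5, 3))                          # five batched vectors
per_sample = jax.vmap(lambda v: A.T @ v)(V)   # shape (5, 4)
batched = A.T @ V.T                           # shape (4, 5): batch on axis 1
print(jnp.allclose(per_sample, batched.T))    # True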
@@ -852,17 +882,228 @@ def event_csrmv_p_call(
 
 def event_csrmm_batching(args, axes, **kwargs):
     if tuple(axes) == (None, None, None, 0):
-        batch_shape = args[3].shape[:-1]
-        B = jnp.reshape(args[3], (-1, args[3].shape[-1:]))
-        r = event_csrmm_p_call(args[0], args[1], args[2], B, **kwargs)
-        return 0, [jnp.reshape(r[0], batch_shape + r.shape[-1:])]
+        assert args[3].ndim == 3, 'Batching axis 0 requires 3D input.'
+        batch_size, m, n = args[3].shape
+        B = jnp.transpose(args[3], (1, 0, 2)).reshape(m, batch_size * n)
+        r = event_csrmm_p_call(
+            args[0],
+            args[1],
+            args[2],
+            B,
+            shape=kwargs['shape'],
+            transpose=kwargs['transpose'],
+            float_as_event=kwargs['float_as_event']
+        )
+        r = jnp.reshape(r[0], [r[0].shape[0], batch_size, n])
+        return [r], [1]
+
+    elif tuple(axes) == (None, None, None, 1):
+        assert args[3].ndim == 3, 'Batching axis 1 requires 3D input.'
+        m, batch_size, n = args[3].shape
+        B = args[3].reshape(m, batch_size * n)
+        r = event_csrmm_p_call(
+            args[0],
+            args[1],
+            args[2],
+            B,
+            shape=kwargs['shape'],
+            transpose=kwargs['transpose'],
+            float_as_event=kwargs['float_as_event']
+        )
+        r = jnp.reshape(r[0], [r[0].shape[0], batch_size, n])
+        return [r], [1]
+
+    elif tuple(axes) == (None, None, None, 2):
+        assert args[3].ndim == 3, 'Batching axis 2 requires 3D input.'
+        m, n, batch_size = args[3].shape
+        B = args[3].reshape(m, batch_size * n)
+        r = event_csrmm_p_call(
+            args[0],
+            args[1],
+            args[2],
+            B,
+            shape=kwargs['shape'],
+            transpose=kwargs['transpose'],
+            float_as_event=kwargs['float_as_event']
+        )
+        r = jnp.reshape(r[0], [r[0].shape[0], n, batch_size])
+        return [r], [2]
+
     else:
         raise NotImplementedError(f"Batching axes {axes} not implemented for event-driven CSR matrix-matrix product.")
 
 
+def event_csrmm_cpu_kernel_generator(
+    float_as_event: bool,
+    weight_info: jax.ShapeDtypeStruct,
+    spike_info: jax.ShapeDtypeStruct,
+    transpose: bool,
+    **kwargs
+) -> Kernel:
+    import numba  # pylint: disable=import-outside-toplevel
+
+    if weight_info.size == 1:
+        if transpose:
+            # csr.T @ B
+
+            if spike_info.dtype == jnp.bool_:
+                @numba.njit(fastmath=True, parallel=False)
+                def mv(weights, indices, indptr, B, posts):
+                    posts[:] = 0.
+                    w = weights[()]
+                    for k in numba.prange(B.shape[1]):
+                        for i in range(B.shape[0]):
+                            if B[i, k]:
+                                for j in range(indptr[i], indptr[i + 1]):
+                                    posts[indices[j], k] += w
+
+            elif float_as_event:
+                @numba.njit(fastmath=True, parallel=False)
+                def mv(weights, indices, indptr, B, posts):
+                    posts[:] = 0.
+                    B = B != 0.
+                    w = weights[()]
+                    for k in numba.prange(B.shape[1]):
+                        for i in range(B.shape[0]):
+                            if B[i, k]:
+                                for j in range(indptr[i], indptr[i + 1]):
+                                    posts[indices[j], k] += w
+
+            else:
+                @numba.njit(fastmath=True, parallel=False)
+                def mv(weights, indices, indptr, B, posts):
+                    posts[:] = 0.
+                    w = weights[()]
+                    for k in numba.prange(B.shape[1]):
+                        for i in range(B.shape[0]):
+                            sp = B[i, k]
+                            if sp != 0.:
+                                wsp = w * sp
+                                for j in range(indptr[i], indptr[i + 1]):
+                                    posts[indices[j], k] += wsp
+
+        else:
+            # csr @ B
+            if spike_info.dtype == jnp.bool_:
+                @numba.njit(fastmath=True)
+                def mv(weights, indices, indptr, B, posts):
+                    w = weights[()]
+                    for i in range(indptr.shape[0] - 1):
+                        r = np.zeros(B.shape[1], dtype=weights.dtype)
+                        for j in range(indptr[i], indptr[i + 1]):
+                            index = indices[j]
+                            for k in range(B.shape[1]):
+                                if B[index, k]:
+                                    r[k] += w
+                        posts[i] = r
+
+            elif float_as_event:
+                @numba.njit(fastmath=True)
+                def mv(weights, indices, indptr, B, posts):
+                    w = weights[()]
+                    B = B != 0.
+                    for i in range(indptr.shape[0] - 1):
+                        r = np.zeros(B.shape[1], dtype=weights.dtype)
+                        for j in range(indptr[i], indptr[i + 1]):
+                            index = indices[j]
+                            for k in range(B.shape[1]):
+                                if B[index, k]:
+                                    r[k] += w
+                        posts[i] = r
+
+            else:
+                @numba.njit(fastmath=True)
+                def mv(weights, indices, indptr, B, posts):
+                    w = weights[()]
+                    for i in range(indptr.shape[0] - 1):
+                        for k in range(B.shape[1]):
+                            r = 0.
+                            for j in range(indptr[i], indptr[i + 1]):
+                                c = B[indices[j], k]
+                                if c != 0.:
+                                    r += w * c
+                            posts[i, k] = r
+
+    else:
+        if transpose:
+            # csr.T @ B
+
+            if spike_info.dtype == jnp.bool_:
+                @numba.njit(fastmath=True, parallel=False)
+                def mv(weights, indices, indptr, B, posts):
+                    posts[:] = 0.
+                    for k in numba.prange(B.shape[1]):
+                        for i in range(B.shape[0]):
+                            if B[i, k]:
+                                for j in range(indptr[i], indptr[i + 1]):
+                                    posts[indices[j], k] += weights[j]
+
+            elif float_as_event:
+                @numba.njit(fastmath=True, parallel=False)
+                def mv(weights, indices, indptr, B, posts):
+                    posts[:] = 0.
+                    B = B != 0.
+                    for k in numba.prange(B.shape[1]):
+                        for i in range(B.shape[0]):
+                            if B[i, k]:
+                                for j in range(indptr[i], indptr[i + 1]):
+                                    posts[indices[j], k] += weights[j]
+
+            else:
+                @numba.njit(fastmath=True, parallel=False)
+                def mv(weights, indices, indptr, B, posts):
+                    posts[:] = 0.
+                    for k in numba.prange(B.shape[1]):
+                        for i in range(B.shape[0]):
+                            sp = B[i, k]
+                            if sp != 0.:
+                                for j in range(indptr[i], indptr[i + 1]):
+                                    posts[indices[j], k] += weights[j] * sp
+
+        else:
+            # csr @ B
+
+            if spike_info.dtype == jnp.bool_:
+                @numba.njit(fastmath=True)
+                def mv(weights, indices, indptr, B, posts):
+                    for i in range(indptr.shape[0] - 1):
+                        for k in range(B.shape[1]):
+                            r = 0.
+                            for j in range(indptr[i], indptr[i + 1]):
+                                if B[indices[j], k]:
+                                    r += weights[j]
+                            posts[i, k] = r
+
+            elif float_as_event:
+                @numba.njit(fastmath=True)
+                def mv(weights, indices, indptr, B, posts):
+                    B = B != 0.
+                    for i in range(indptr.shape[0] - 1):
+                        for k in range(B.shape[1]):
+                            r = 0.
+                            for j in range(indptr[i], indptr[i + 1]):
+                                if B[indices[j], k]:
+                                    r += weights[j]
+                            posts[i, k] = r
+
+            else:
+                @numba.njit(fastmath=True)
+                def mv(weights, indices, indptr, B, posts):
+                    for i in range(indptr.shape[0] - 1):
+                        for k in range(B.shape[1]):
+                            r = 0.
+                            for j in range(indptr[i], indptr[i + 1]):
+                                c = B[indices[j], k]
+                                if c != 0.:
+                                    r += weights[j] * c
+                            posts[i, k] = r
+
+    return mv
+
+
 event_csrmm_p = XLACustomOp(
     'event_csrmm',
-    cpu_kernel_or_generator=event_csrmv_cpu_kernel_generator,
+    cpu_kernel_or_generator=event_csrmm_cpu_kernel_generator,
 )
 event_csrmm_p.def_batching_rule(event_csrmm_batching)
 
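Note: `event_csrmm_p` previously reused the matrix-vector kernel generator; the new `event_csrmm_cpu_kernel_generator` emits kernels that also iterate over the columns of B (with `numba.prange` available in the transposed case). A quick dense reference for what the transposed kernels should compute, handy when checking the loops above (illustrative, scipy-based):

import numpy as np
import scipy.sparse as sp

m, n, k = 6, 8, 3
A = sp.random(m, n, density=0.3, format='csr')
B = (np.random.rand(m, k) > 0.5).astype(float)  # event-like input matrix
ref = A.toarray().T @ B                          # target of the csr.T @ B kernels
print(ref.shape)                                 # (n, k) == (8, 3)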
@@ -884,11 +1125,13 @@ def event_csrmm_p_call(
         indptr,
         B,
         outs=[
-            jax.ShapeDtypeStruct([shape[0], B.shape[1]], weights.dtype)
+            jax.ShapeDtypeStruct([shape[1], B.shape[1]], weights.dtype)
             if transpose else
-            jax.ShapeDtypeStruct([shape[1], B.shape[1]], weights.dtype),
+            jax.ShapeDtypeStruct([shape[0], B.shape[1]], weights.dtype),
         ],
         # block_size=block_size,
+        shape=shape,
+        transpose=transpose,
         float_as_event=float_as_event,
         weight_info=jax.ShapeDtypeStruct(weights.shape, weights.dtype),
         spike_info=jax.ShapeDtypeStruct(B.shape, B.dtype),
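Note: this hunk fixes output shapes that were swapped: for a CSR matrix of shape (m, n), `csr.T @ B` has n rows, while `csr @ B` has m rows. A tiny helper mirroring the corrected `outs=[...]` selection (`out_shape` is a hypothetical name, illustrative only):

def out_shape(shape, n_cols, transpose):
    # Mirrors the corrected outs=[...] logic in event_csrmm_p_call.
    return (shape[1] if transpose else shape[0], n_cols)

print(out_shape((20, 40), 8, transpose=False))  # (20, 8)  csr @ B
print(out_shape((20, 40), 8, transpose=True))   # (40, 8)  csr.T @ B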
brainstate/event/_csr_test.py CHANGED
@@ -18,6 +18,9 @@
 import unittest
 
 import brainunit as u
+import jax
+import jax.numpy as jnp
+import numpy as np
 
 import brainstate as bst
 
@@ -88,3 +91,187 @@ class TestCSR(unittest.TestCase):
                 v @ csr
             )
         )
+
+
+def _get_csr(n_pre, n_post, prob):
+    n_conn = int(n_post * prob)
+    indptr = np.arange(n_pre + 1) * n_conn
+    indices = np.random.randint(0, n_post, (n_pre * n_conn,))
+    return indptr, indices
+
+
+def vector_csr(x, w, indices, indptr, shape):
+    homo_w = jnp.size(w) == 1
+    post = jnp.zeros((shape[1],))
+    for i_pre in range(x.shape[0]):
+        ids = indices[indptr[i_pre]: indptr[i_pre + 1]]
+        post = post.at[ids].add(w * x[i_pre] if homo_w else w[indptr[i_pre]: indptr[i_pre + 1]] * x[i_pre])
+    return post
+
+
+def matrix_csr(xs, w, indices, indptr, shape):
+    homo_w = jnp.size(w) == 1
+    post = jnp.zeros((xs.shape[0], shape[1]))
+    for i_pre in range(xs.shape[1]):
+        ids = indices[indptr[i_pre]: indptr[i_pre + 1]]
+        post = post.at[:, ids].add(
+            w * xs[:, i_pre: i_pre + 1]
+            if homo_w else
+            (w[indptr[i_pre]: indptr[i_pre + 1]] * xs[:, i_pre: i_pre + 1])
+        )
+    return post
+
+
+def csr_vector(x, w, indices, indptr, shape):
+    homo_w = jnp.size(w) == 1
+    out = jnp.zeros([shape[0]])
+    for i in range(shape[0]):
+        ids = indices[indptr[i]: indptr[i + 1]]
+        ws = w if homo_w else w[indptr[i]: indptr[i + 1]]
+        out = out.at[i].set(jnp.sum(x[ids] * ws))
+    return out
+
+
+def csr_matrix(xs, w, indices, indptr, shape):
+    # CSR @ matrix
+    homo_w = jnp.size(w) == 1
+    out = jnp.zeros([shape[0], xs.shape[1]])
+    for i in range(shape[0]):
+        ids = indices[indptr[i]: indptr[i + 1]]
+        ws = w if homo_w else jnp.expand_dims(w[indptr[i]: indptr[i + 1]], axis=1)
+        out = out.at[i].set(jnp.sum(xs[ids] * ws, axis=0))
+    return out
+
+
+class TestVectorCSR(unittest.TestCase):
+    def test_vector_csr(self):
+        m, n = 20, 40
+        x = bst.random.rand(m) < 0.1
+        indptr, indices = _get_csr(m, n, 0.1)
+
+        for homo_w in [True, False]:
+            print(f'homo_w = {homo_w}')
+            data = 1.5 if homo_w else bst.init.Normal()(indices.shape)
+            csr = bst.event.CSR([data, indices, indptr], shape=(m, n))
+            y = x @ csr
+            y2 = vector_csr(x, csr.data, indices, indptr, [m, n])
+            self.assertTrue(jnp.allclose(y, y2))
+
+    def test_vector_csr_vmap_vector(self):
+        n_batch, m, n = 10, 20, 40
+        xs = bst.random.rand(n_batch, m) < 0.1
+        indptr, indices = _get_csr(m, n, 0.1)
+
+        for homo_w in [True, False]:
+            data = 1.5 if homo_w else bst.init.Normal()(indices.shape)
+            csr = bst.event.CSR([data, indices, indptr], shape=(m, n))
+            y = jax.vmap(lambda x: x @ csr)(xs)
+            y2 = jax.vmap(lambda x: vector_csr(x, csr.data, indices, indptr, [m, n]))(xs)
+            self.assertTrue(jnp.allclose(y, y2))
+
+
+class TestMatrixCSR(unittest.TestCase):
+    def test_matrix_csr(self):
+        k, m, n = 10, 20, 40
+        x = bst.random.rand(k, m) < 0.1
+        indptr, indices = _get_csr(m, n, 0.1)
+
+        for homo_w in [True, False]:
+            data = 1.5 if homo_w else bst.init.Normal()(indices.shape)
+            csr = bst.event.CSR([data, indices, indptr], shape=(m, n))
+            y = x @ csr
+            y2 = matrix_csr(x, csr.data, indices, indptr, [m, n])
+            self.assertTrue(jnp.allclose(y, y2))
+
+
+class TestCSRVector(unittest.TestCase):
+    def test_csr_vector(self):
+        m, n = 20, 40
+        v = bst.random.rand(n) < 0.1
+        indptr, indices = _get_csr(m, n, 0.1)
+
+        for homo_w in [True, False]:
+            data = 1.5 if homo_w else bst.init.Normal()(indices.shape)
+            csr = bst.event.CSR([data, indices, indptr], shape=(m, n))
+            y = csr @ v
+            y2 = csr_vector(v, csr.data, indices, indptr, [m, n])
+            self.assertTrue(jnp.allclose(y, y2))
+
+
+class TestCSRMatrix(unittest.TestCase):
+    def test_csr_matrix(self):
+        m, n, k = 20, 40, 10
+        matrix = bst.random.rand(n, k) < 0.1
+        indptr, indices = _get_csr(m, n, 0.1)
+
+        for homo_w in [True, False]:
+            data = 1.5 if homo_w else bst.init.Normal()(indices.shape)
+            csr = bst.event.CSR([data, indices, indptr], shape=(m, n))
+            y = csr @ matrix
+            y2 = csr_matrix(matrix, csr.data, indices, indptr, [m, n])
+            self.assertTrue(jnp.allclose(y, y2))
+
+    # @parameterized.product(
+    #     bool_x=[True, False],
+    #     homo_w=[True, False]
+    # )
+    # def test_vjp(self, bool_x, homo_w):
+    #     n_in = 20
+    #     n_out = 30
+    #     if bool_x:
+    #         x = jax.numpy.asarray(bst.random.rand(n_in) < 0.3, dtype=float)
+    #     else:
+    #         x = bst.random.rand(n_in)
+    #
+    #     indptr, indices = _get_csr(n_in, n_out, 0.1)
+    #     fn = bst.event.CSRLinear(n_in, n_out, indptr, indices, 1.5 if homo_w else bst.init.Normal())
+    #     w = fn.weight.value
+    #
+    #     def f(x, w):
+    #         fn.weight.value = w
+    #         return fn(x).sum()
+    #
+    #     r = jax.grad(f, argnums=(0, 1))(x, w)
+    #
+    #     # -------------------
+    #     # TRUE gradients
+    #
+    #     def f2(x, w):
+    #         return true_fn(x, w, indices, indptr, n_out).sum()
+    #
+    #     r2 = jax.grad(f2, argnums=(0, 1))(x, w)
+    #     self.assertTrue(jnp.allclose(r[0], r2[0]))
+    #     self.assertTrue(jnp.allclose(r[1], r2[1]))
+    #
+    # @parameterized.product(
+    #     bool_x=[True, False],
+    #     homo_w=[True, False]
+    # )
+    # def test_jvp(self, bool_x, homo_w):
+    #     n_in = 20
+    #     n_out = 30
+    #     if bool_x:
+    #         x = jax.numpy.asarray(bst.random.rand(n_in) < 0.3, dtype=float)
+    #     else:
+    #         x = bst.random.rand(n_in)
+    #
+    #     indptr, indices = _get_csr(n_in, n_out, 0.1)
+    #     fn = bst.event.CSRLinear(n_in, n_out, indptr, indices,
+    #                              1.5 if homo_w else bst.init.Normal(), grad_mode='jvp')
+    #     w = fn.weight.value
+    #
+    #     def f(x, w):
+    #         fn.weight.value = w
+    #         return fn(x)
+    #
+    #     o1, r1 = jax.jvp(f, (x, w), (jnp.ones_like(x), jnp.ones_like(w)))
+    #
+    #     # -------------------
+    #     # TRUE gradients
+    #
+    #     def f2(x, w):
+    #         return true_fn(x, w, indices, indptr, n_out)
+    #
+    #     o2, r2 = jax.jvp(f2, (x, w), (jnp.ones_like(x), jnp.ones_like(w)))
+    #     self.assertTrue(jnp.allclose(r1, r2))
+    #     self.assertTrue(jnp.allclose(o1, o2))
brainstate/event/_fixedprob_mv.py CHANGED
@@ -24,6 +24,7 @@ import jax.numpy as jnp
 import numpy as np
 from jax.interpreters import ad
 
+from brainstate import environ
 from brainstate._state import ParamState
 from brainstate.augment import vmap
 from brainstate.init import param
@@ -111,7 +112,7 @@ class FixedProb(Module):
 
     def update(self, spk: jax.Array) -> Union[jax.Array, u.Quantity]:
         if self.n_conn > 1:
-            return event_fixed_prob(
+            r = event_fixed_prob(
                 spk,
                 self.weight.value,
                 self.indices,
@@ -123,7 +124,8 @@ class FixedProb(Module):
             weight = self.weight.value
             unit = u.get_unit(weight)
             r = jnp.zeros(spk.shape[:-1] + (self.out_size[-1],), dtype=weight.dtype)
-            return u.maybe_decimal(u.Quantity(r, unit=unit))
+            r = u.maybe_decimal(u.Quantity(r, unit=unit))
+        return u.math.asarray(r, dtype=environ.dftype())
 
 
 def event_fixed_prob(
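Note: `update()` now funnels both branches through one final cast, so callers always receive the package's default float dtype regardless of which branch produced the result. A minimal sketch of that normalization step in isolation (values and unit are hypothetical; only calls that appear in the diff itself are used):

import brainunit as u
import jax.numpy as jnp
from brainstate import environ

r = jnp.zeros((4,), dtype=jnp.float16)            # branch result with a narrow dtype
r = u.maybe_decimal(u.Quantity(r, unit=u.mV))     # re-attach the weight's unit
out = u.math.asarray(r, dtype=environ.dftype())   # unified floating-point dtype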
brainstate/event/_fixedprob_mv_test.py CHANGED
@@ -128,4 +128,5 @@ class TestFixedProbCSR(parameterized.TestCase):
 
         o2, r2 = jax.jvp(f2, (x, w), (jnp.ones_like(x), jnp.ones_like(w)))
         self.assertTrue(jnp.allclose(o1, o2))
-        self.assertTrue(jnp.allclose(r1, r2))
+        # assert jnp.allclose(r1, r2), f'r1={r1}, r2={r2}'
+        self.assertTrue(jnp.allclose(r1, r2, rtol=1e-4, atol=1e-4))