brainstate 0.1.0.post20250104__py2.py3-none-any.whl → 0.1.0.post20250120__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/_state.py +77 -44
- brainstate/_state_test.py +0 -17
- brainstate/augment/_eval_shape.py +9 -10
- brainstate/augment/_eval_shape_test.py +1 -1
- brainstate/augment/_mapping.py +265 -277
- brainstate/augment/_mapping_test.py +147 -175
- brainstate/compile/_ad_checkpoint.py +6 -4
- brainstate/compile/_error_if_test.py +1 -0
- brainstate/compile/_jit.py +37 -28
- brainstate/compile/_loop_collect_return.py +8 -5
- brainstate/compile/_loop_no_collection.py +2 -0
- brainstate/compile/_make_jaxpr.py +7 -3
- brainstate/compile/_make_jaxpr_test.py +2 -1
- brainstate/compile/_progress_bar.py +68 -40
- brainstate/compile/_unvmap.py +6 -2
- brainstate/environ.py +28 -18
- brainstate/environ_test.py +4 -0
- brainstate/event/__init__.py +0 -2
- brainstate/event/_csr.py +266 -23
- brainstate/event/_csr_test.py +187 -0
- brainstate/event/_fixedprob_mv.py +4 -2
- brainstate/event/_fixedprob_mv_test.py +2 -1
- brainstate/event/_xla_custom_op.py +16 -5
- brainstate/graph/__init__.py +8 -12
- brainstate/graph/_graph_node.py +1 -23
- brainstate/graph/_graph_operation.py +1 -1
- brainstate/graph/_graph_operation_test.py +0 -159
- brainstate/nn/_dyn_impl/_inputs.py +124 -39
- brainstate/nn/_interaction/_conv.py +4 -2
- brainstate/nn/_interaction/_linear.py +84 -10
- brainstate/random/_rand_funs.py +9 -2
- brainstate/random/_rand_seed.py +12 -2
- brainstate/random/_rand_state.py +50 -179
- brainstate/surrogate.py +5 -1
- brainstate/util/__init__.py +0 -4
- brainstate/util/_caller.py +1 -1
- brainstate/util/_dict.py +4 -1
- brainstate/util/_filter.py +1 -1
- brainstate/util/_pretty_repr.py +1 -1
- brainstate/util/_struct.py +1 -1
- {brainstate-0.1.0.post20250104.dist-info → brainstate-0.1.0.post20250120.dist-info}/METADATA +2 -1
- {brainstate-0.1.0.post20250104.dist-info → brainstate-0.1.0.post20250120.dist-info}/RECORD +46 -52
- brainstate/event/_csr_mv_test.py +0 -118
- brainstate/graph/_graph_context.py +0 -443
- brainstate/graph/_graph_context_test.py +0 -65
- brainstate/graph/_graph_convert.py +0 -246
- brainstate/util/_tracers.py +0 -68
- brainstate/util/_visualization.py +0 -47
- /brainstate/event/{_csr_mv_benchmark.py → _csr_benchmark.py} +0 -0
- {brainstate-0.1.0.post20250104.dist-info → brainstate-0.1.0.post20250120.dist-info}/LICENSE +0 -0
- {brainstate-0.1.0.post20250104.dist-info → brainstate-0.1.0.post20250120.dist-info}/WHEEL +0 -0
- {brainstate-0.1.0.post20250104.dist-info → brainstate-0.1.0.post20250120.dist-info}/top_level.txt +0 -0
brainstate/event/_csr.py
CHANGED
```diff
@@ -73,7 +73,7 @@ class CSR(u.sparse.SparseMatrix):
         return u.sparse.csr_todense(self)
 
     def transpose(self, axes=None):
-        assert axes is None
+        assert axes is None, "transpose does not support axes argument."
        return CSC((self.data, self.indices, self.indptr), shape=self.shape[::-1])
 
     def __abs__(self):
@@ -103,6 +103,7 @@ class CSR(u.sparse.SparseMatrix):
                 (op(self.data, other), self.indices, self.indptr),
                 shape=self.shape
             )
+
         elif other.ndim == 2 and other.shape == self.shape:
             rows, cols = csr_to_coo(self.indices, self.indptr)
             other = other[rows, cols]
@@ -112,6 +113,7 @@ class CSR(u.sparse.SparseMatrix):
                  self.indptr),
                 shape=self.shape
             )
+
         else:
             raise NotImplementedError(f"mul with object of shape {other.shape}")
 
```
```diff
@@ -184,10 +186,12 @@ class CSR(u.sparse.SparseMatrix):
         return self._binary_rop(other, operator.mod)
 
     def __matmul__(self, other):
+        # csr @ other
         if isinstance(other, JAXSparse):
             raise NotImplementedError("matmul between two sparse objects.")
         other = u.math.asarray(other)
-        data, other = u.math.promote_dtypes(self.data, other)
+        data = self.data
+        # data, other = u.math.promote_dtypes(self.data, other)
         if other.ndim == 1:
             return _csr_matvec(
                 data,
@@ -208,10 +212,12 @@ class CSR(u.sparse.SparseMatrix):
             raise NotImplementedError(f"matmul with object of shape {other.shape}")
 
     def __rmatmul__(self, other):
+        # other @ csr
         if isinstance(other, JAXSparse):
             raise NotImplementedError("matmul between two sparse objects.")
         other = u.math.asarray(other)
-        data, other = u.math.promote_dtypes(self.data, other)
+        data = self.data
+        # data, other = u.math.promote_dtypes(self.data, other)
         if other.ndim == 1:
             return _csr_matvec(
                 data,
```
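Both operators now bind `data = self.data` directly instead of promoting dtypes against the dense operand (the old promotion call survives as a comment). A minimal usage sketch of the two code paths, mirroring the constructor and operators exercised by the new tests further down; the shapes and connection probability here are illustrative:

```python
import numpy as np
import brainstate as bst

m, n = 20, 40
n_conn = int(n * 0.1)
indptr = np.arange(m + 1) * n_conn                # CSR row pointers
indices = np.random.randint(0, n, (m * n_conn,))  # CSR column indices

csr = bst.event.CSR([1.5, indices, indptr], shape=(m, n))

spikes_pre = bst.random.rand(m) < 0.1   # boolean event vector
out_post = spikes_pre @ csr             # __rmatmul__: vector @ CSR -> shape (n,)

spikes_post = bst.random.rand(n) < 0.1
out_pre = csr @ spikes_post             # __matmul__: CSR @ vector -> shape (m,)
```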
```diff
@@ -566,7 +572,7 @@ def event_csrmv_cpu_kernel_generator(
     if weight_info.size == 1:
         if transpose:
             if spike_info.dtype == jnp.bool_:
-                @numba.njit
+                @numba.njit(fastmath=True)
                 def mv(weights, indices, indptr, v, posts):
                     posts[:] = 0.
                     w = weights[()]
@@ -576,7 +582,7 @@
                                 posts[indices[j]] += w
 
             elif float_as_event:
-                @numba.njit
+                @numba.njit(fastmath=True)
                 def mv(weights, indices, indptr, v, posts):
                     posts[:] = 0.
                     w = weights[()]
@@ -586,7 +592,7 @@
                                 posts[indices[j]] += w
 
             else:
-                @numba.njit
+                @numba.njit(fastmath=True)
                 def mv(weights, indices, indptr, v, posts):
                     posts[:] = 0.
                     w = weights[()]
@@ -599,7 +605,7 @@
 
         else:
             if spike_info.dtype == jnp.bool_:
-                @numba.njit
+                @numba.njit(fastmath=True)
                 def mv(weights, indices, indptr, v, posts):
                     w = weights[()]
                     for i in range(indptr.shape[0] - 1):
@@ -610,7 +616,7 @@
                     posts[i] = r
 
             elif float_as_event:
-                @numba.njit
+                @numba.njit(fastmath=True)
                 def mv(weights, indices, indptr, v, posts):
                     w = weights[()]
                     for i in range(indptr.shape[0] - 1):
@@ -621,7 +627,7 @@
                     posts[i] = r
 
             else:
-                @numba.njit
+                @numba.njit(fastmath=True)
                 def mv(weights, indices, indptr, v, posts):
                     w = weights[()]
                     for i in range(indptr.shape[0] - 1):
@@ -635,7 +641,7 @@
     else:
         if transpose:
             if spike_info.dtype == jnp.bool_:
-                @numba.njit
+                @numba.njit(fastmath=True)
                 def mv(weights, indices, indptr, v, posts):
                     posts[:] = 0.
                     for i in range(v.shape[0]):
@@ -644,7 +650,7 @@
                                 posts[indices[j]] += weights[j]
 
             elif float_as_event:
-                @numba.njit
+                @numba.njit(fastmath=True)
                 def mv(weights, indices, indptr, v, posts):
                     posts[:] = 0.
                     for i in range(v.shape[0]):
@@ -653,7 +659,7 @@
                                 posts[indices[j]] += weights[j]
 
             else:
-                @numba.njit
+                @numba.njit(fastmath=True)
                 def mv(weights, indices, indptr, v, posts):
                     posts[:] = 0.
                     for i in range(v.shape[0]):
@@ -664,7 +670,7 @@
 
         else:
             if spike_info.dtype == jnp.bool_:
-                @numba.njit
+                @numba.njit(fastmath=True)
                 def mv(weights, indices, indptr, v, posts):
                     for i in range(indptr.shape[0] - 1):
                         r = 0.
@@ -674,7 +680,7 @@
                     posts[i] = r
 
             elif float_as_event:
-                @numba.njit
+                @numba.njit(fastmath=True)
                 def mv(weights, indices, indptr, v, posts):
                     for i in range(indptr.shape[0] - 1):
                         r = 0.
@@ -684,7 +690,7 @@
                     posts[i] = r
 
             else:
-                @numba.njit
+                @numba.njit(fastmath=True)
                 def mv(weights, indices, indptr, v, posts):
                     for i in range(indptr.shape[0] - 1):
                         r = 0.
```
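Every CPU kernel in this generator now compiles with `fastmath=True`, which lets LLVM reassociate and vectorize the floating-point accumulations at the cost of strict IEEE-754 ordering; since these event-driven kernels mostly add identical weights, the relaxed semantics affect results at most at rounding level. A standalone sketch of the scalar-weight, transposed, boolean-spike variant shown above; the function name and test data are illustrative, not from the package:

```python
import numba
import numpy as np

@numba.njit(fastmath=True)
def event_csrmv_transposed(w, indices, indptr, spikes, posts):
    # Event-driven csr.T @ spikes with one shared weight w:
    # only rows whose presynaptic neuron spiked are visited.
    posts[:] = 0.
    for i in range(spikes.shape[0]):
        if spikes[i]:
            for j in range(indptr[i], indptr[i + 1]):
                posts[indices[j]] += w

indptr = np.array([0, 2, 4], dtype=np.int64)
indices = np.array([0, 2, 1, 2], dtype=np.int64)
spikes = np.array([True, False])
posts = np.zeros(3)
event_csrmv_transposed(1.5, indices, indptr, spikes, posts)
print(posts)  # [1.5 0.  1.5]
```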
```diff
@@ -795,7 +801,31 @@ def event_csrmv_transpose_rule(
 
 def event_csrmv_batching(args, axes, **kwargs):
     if tuple(axes) == (None, None, None, 0):
-
+        assert args[3].ndim == 2, 'Batching axis 0 requires 2D input.'
+        r = event_csrmm_p_call(
+            args[0],
+            args[1],
+            args[2],
+            args[3].T,
+            shape=kwargs['shape'],
+            transpose=kwargs['transpose'],
+            float_as_event=kwargs['float_as_event']
+        )
+        return r, [1]
+
+    elif tuple(axes) == (None, None, None, 1):
+        assert args[3].ndim == 2, 'Batching axis 1 requires 2D input.'
+        r = event_csrmm_p_call(
+            args[0],
+            args[1],
+            args[2],
+            args[3],
+            shape=kwargs['shape'],
+            transpose=kwargs['transpose'],
+            float_as_event=kwargs['float_as_event']
+        )
+        return r, [1]
+
     else:
         raise NotImplementedError(f"Batching axes {axes} not implemented for event-driven CSR matrix-vector product.")
 
```
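The new rule lowers a vmapped matrix-vector product onto the matrix-matrix primitive (`event_csrmm_p_call`) instead of raising: batching over axis 0 transposes the batch of vectors into columns, batching over axis 1 passes them through unchanged, and the result is reported on output axis 1. In user code this is what makes `jax.vmap` over the spike vector cheap; a sketch mirroring `test_vector_csr_vmap_vector` from the test file below:

```python
import jax
import numpy as np
import brainstate as bst

n_batch, m, n = 10, 20, 40
n_conn = int(n * 0.1)
indptr = np.arange(m + 1) * n_conn
indices = np.random.randint(0, n, (m * n_conn,))
csr = bst.event.CSR([1.5, indices, indptr], shape=(m, n))

xs = bst.random.rand(n_batch, m) < 0.1
# The batching rule turns these n_batch vector products into a single
# event_csrmm call on xs.T; vmap then just relabels the batch axis.
ys = jax.vmap(lambda x: x @ csr)(xs)   # shape (n_batch, n)
```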
```diff
@@ -852,17 +882,228 @@ def event_csrmv_p_call(
 
 def event_csrmm_batching(args, axes, **kwargs):
     if tuple(axes) == (None, None, None, 0):
-
-
-
-
+        assert args[3].ndim == 3, 'Batching axis 0 requires 3D input.'
+        batch_size, m, n = args[3].shape
+        B = jnp.transpose(args[3], (1, 0, 2)).reshape(m, batch_size * n)
+        r = event_csrmm_p_call(
+            args[0],
+            args[1],
+            args[2],
+            B,
+            shape=kwargs['shape'],
+            transpose=kwargs['transpose'],
+            float_as_event=kwargs['float_as_event']
+        )
+        r = jnp.reshape(r[0], [r[0].shape[0], batch_size, n])
+        return [r], [1]
+
+    elif tuple(axes) == (None, None, None, 1):
+        assert args[3].ndim == 3, 'Batching axis 1 requires 3D input.'
+        m, batch_size, n = args[3].shape
+        B = args[3].reshape(m, batch_size * n)
+        r = event_csrmm_p_call(
+            args[0],
+            args[1],
+            args[2],
+            B,
+            shape=kwargs['shape'],
+            transpose=kwargs['transpose'],
+            float_as_event=kwargs['float_as_event']
+        )
+        r = jnp.reshape(r[0], [r[0].shape[0], batch_size, n])
+        return [r], [1]
+
+    elif tuple(axes) == (None, None, None, 2):
+        assert args[3].ndim == 3, 'Batching axis 2 requires 3D input.'
+        m, n, batch_size = args[3].shape
+        B = args[3].reshape(m, batch_size * n)
+        r = event_csrmm_p_call(
+            args[0],
+            args[1],
+            args[2],
+            B,
+            shape=kwargs['shape'],
+            transpose=kwargs['transpose'],
+            float_as_event=kwargs['float_as_event']
+        )
+        r = jnp.reshape(r[0], [r[0].shape[0], n, batch_size])
+        return [r], [2]
+
     else:
         raise NotImplementedError(f"Batching axes {axes} not implemented for event-driven CSR matrix-vector product.")
 
 
+def event_csrmm_cpu_kernel_generator(
+    float_as_event: bool,
+    weight_info: jax.ShapeDtypeStruct,
+    spike_info: jax.ShapeDtypeStruct,
+    transpose: bool,
+    **kwargs
+) -> Kernel:
+    import numba  # pylint: disable=import-outside-toplevel
+
+    if weight_info.size == 1:
+        if transpose:
+            # csr.T @ B
+
+            if spike_info.dtype == jnp.bool_:
+                @numba.njit(fastmath=True, parallel=False)
+                def mv(weights, indices, indptr, B, posts):
+                    posts[:] = 0.
+                    w = weights[()]
+                    for k in numba.prange(B.shape[1]):
+                        for i in range(B.shape[0]):
+                            if B[i, k]:
+                                for j in range(indptr[i], indptr[i + 1]):
+                                    posts[indices[j], k] += w
+
+            elif float_as_event:
+                @numba.njit(fastmath=True, parallel=False)
+                def mv(weights, indices, indptr, B, posts):
+                    posts[:] = 0.
+                    B = B != 0.
+                    w = weights[()]
+                    for k in numba.prange(B.shape[1]):
+                        for i in range(B.shape[0]):
+                            if B[i, k]:
+                                for j in range(indptr[i], indptr[i + 1]):
+                                    posts[indices[j], k] += w
+
+            else:
+                @numba.njit(fastmath=True, parallel=False)
+                def mv(weights, indices, indptr, B, posts):
+                    posts[:] = 0.
+                    w = weights[()]
+                    for k in numba.prange(B.shape[1]):
+                        for i in range(B.shape[0]):
+                            sp = B[i, k]
+                            if sp != 0.:
+                                wsp = w * sp
+                                for j in range(indptr[i], indptr[i + 1]):
+                                    posts[indices[j], k] += wsp
+
+        else:
+            # csr @ B
+            if spike_info.dtype == jnp.bool_:
+                @numba.njit(fastmath=True)
+                def mv(weights, indices, indptr, B, posts):
+                    w = weights[()]
+                    for i in range(indptr.shape[0] - 1):
+                        r = np.zeros(B.shape[1], dtype=weights.dtype)
+                        for j in range(indptr[i], indptr[i + 1]):
+                            index = indices[j]
+                            for k in range(B.shape[1]):
+                                if B[index, k]:
+                                    r[k] += w
+                        posts[i] = r
+
+            elif float_as_event:
+                @numba.njit(fastmath=True)
+                def mv(weights, indices, indptr, B, posts):
+                    w = weights[()]
+                    B = B != 0.
+                    for i in range(indptr.shape[0] - 1):
+                        r = np.zeros(B.shape[1], dtype=weights.dtype)
+                        for j in range(indptr[i], indptr[i + 1]):
+                            index = indices[j]
+                            for k in range(B.shape[1]):
+                                if B[index, k]:
+                                    r[k] += w
+                        posts[i] = r
+
+            else:
+                @numba.njit(fastmath=True)
+                def mv(weights, indices, indptr, B, posts):
+                    w = weights[()]
+                    for i in range(indptr.shape[0] - 1):
+                        for k in range(B.shape[1]):
+                            r = 0.
+                            for j in range(indptr[i], indptr[i + 1]):
+                                c = B[indices[j], k]
+                                if c != 0.:
+                                    r += w * c
+                            posts[i, k] = r
+
+    else:
+        if transpose:
+            # csr.T @ B
+
+            if spike_info.dtype == jnp.bool_:
+                @numba.njit(fastmath=True, parallel=False)
+                def mv(weights, indices, indptr, B, posts):
+                    posts[:] = 0.
+                    for k in numba.prange(B.shape[1]):
+                        for i in range(B.shape[0]):
+                            if B[i, k]:
+                                for j in range(indptr[i], indptr[i + 1]):
+                                    posts[indices[j], k] += weights[j]
+
+            elif float_as_event:
+                @numba.njit(fastmath=True, parallel=False)
+                def mv(weights, indices, indptr, B, posts):
+                    posts[:] = 0.
+                    B = B != 0.
+                    for k in numba.prange(B.shape[1]):
+                        for i in range(B.shape[0]):
+                            if B[i, k]:
+                                for j in range(indptr[i], indptr[i + 1]):
+                                    posts[indices[j], k] += weights[j]
+
+            else:
+                @numba.njit(fastmath=True, parallel=False)
+                def mv(weights, indices, indptr, B, posts):
+                    posts[:] = 0.
+                    for k in numba.prange(B.shape[1]):
+                        for i in range(B.shape[0]):
+                            sp = B[i, k]
+                            if sp != 0.:
+                                for j in range(indptr[i], indptr[i + 1]):
+                                    posts[indices[j], k] += weights[j] * sp
+
+        else:
+            # csr @ B
+
+            if spike_info.dtype == jnp.bool_:
+                @numba.njit(fastmath=True)
+                def mv(weights, indices, indptr, B, posts):
+                    for i in range(indptr.shape[0] - 1):
+                        for k in range(B.shape[1]):
+                            r = 0.
+                            for j in range(indptr[i], indptr[i + 1]):
+                                if B[indices[j], k]:
+                                    r += weights[j]
+                            posts[i, k] = r
+
+            elif float_as_event:
+                @numba.njit(fastmath=True)
+                def mv(weights, indices, indptr, B, posts):
+                    B = B != 0.
+                    for i in range(indptr.shape[0] - 1):
+                        for k in range(B.shape[1]):
+                            r = 0.
+                            for j in range(indptr[i], indptr[i + 1]):
+                                if B[indices[j], k]:
+                                    r += weights[j]
+                            posts[i, k] = r
+
+            else:
+                @numba.njit(fastmath=True)
+                def mv(weights, indices, indptr, B, posts):
+                    for i in range(indptr.shape[0] - 1):
+                        for k in range(B.shape[1]):
+                            r = 0.
+                            for j in range(indptr[i], indptr[i + 1]):
+                                c = B[indices[j], k]
+                                if c != 0.:
+                                    r += weights[j] * c
+                            posts[i, k] = r
+
+    return mv
+
+
 event_csrmm_p = XLACustomOp(
     'event_csrmm',
-    cpu_kernel_or_generator=
+    cpu_kernel_or_generator=event_csrmm_cpu_kernel_generator,
 )
 event_csrmm_p.def_batching_rule(event_csrmm_batching)
 
```
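One detail worth noting in the new generator: the transpose kernels iterate over columns with `numba.prange` but are compiled with `parallel=False`, and under `parallel=False` a `prange` behaves exactly like a plain `range`. The column loops are mutually independent, so the scatter into `posts` could later be flipped to parallel execution without restructuring. A standalone sketch of that kernel shape; the function name and test data are illustrative:

```python
import numba
import numpy as np

# With parallel=False, numba.prange degrades to range, so the scatter-add
# below is race-free; with parallel=True the columns (k) would run
# concurrently, which is safe because each k touches a disjoint slice.
@numba.njit(fastmath=True, parallel=False)
def csrT_matmat(w, indices, indptr, B, posts):
    posts[:] = 0.
    for k in numba.prange(B.shape[1]):      # columns of the dense operand
        for i in range(B.shape[0]):         # presynaptic rows
            if B[i, k]:                     # event-driven: skip zeros
                for j in range(indptr[i], indptr[i + 1]):
                    posts[indices[j], k] += w

indptr = np.array([0, 2, 4], dtype=np.int64)
indices = np.array([0, 2, 1, 2], dtype=np.int64)
B = np.array([[True, False], [True, True]])
posts = np.zeros((3, 2))
csrT_matmat(1.5, indices, indptr, B, posts)
```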
```diff
@@ -884,11 +1125,13 @@ def event_csrmm_p_call(
         indptr,
         B,
         outs=[
-            jax.ShapeDtypeStruct([shape[
+            jax.ShapeDtypeStruct([shape[1], B.shape[1]], weights.dtype)
             if transpose else
-            jax.ShapeDtypeStruct([shape[
+            jax.ShapeDtypeStruct([shape[0], B.shape[1]], weights.dtype),
         ],
         # block_size=block_size,
+        shape=shape,
+        transpose=transpose,
         float_as_event=float_as_event,
         weight_info=jax.ShapeDtypeStruct(weights.shape, weights.dtype),
         spike_info=jax.ShapeDtypeStruct(B.shape, B.dtype),
```
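`event_csrmm_p_call` now forwards `shape` and `transpose` to the kernel generator and picks the output row count from `shape[1]` (transpose) or `shape[0]`. The batching rules above depend on folding a batched right-hand side into extra columns of a single matmul; a dense-numpy sketch of why that is sound (all names here are illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.random((5, 4))       # stands in for the sparse operand
Bs = rng.random((3, 4, 6))   # (batch, m, n) batched right-hand side

# Rule for axis-0 batching: fold the batch into columns, one matmul, unfold.
batch, m, n = Bs.shape
folded = np.transpose(Bs, (1, 0, 2)).reshape(m, batch * n)   # (m, batch*n)
out = (A @ folded).reshape(A.shape[0], batch, n)             # batch lands on axis 1

expected = np.stack([A @ Bs[b] for b in range(batch)], axis=1)
assert np.allclose(out, expected)
```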
brainstate/event/_csr_test.py
CHANGED
```diff
@@ -18,6 +18,9 @@
 import unittest
 
 import brainunit as u
+import jax
+import jax.numpy as jnp
+import numpy as np
 
 import brainstate as bst
 
```
```diff
@@ -88,3 +91,187 @@ class TestCSR(unittest.TestCase):
                 v @ csr
             )
         )
+
+
+def _get_csr(n_pre, n_post, prob):
+    n_conn = int(n_post * prob)
+    indptr = np.arange(n_pre + 1) * n_conn
+    indices = np.random.randint(0, n_post, (n_pre * n_conn,))
+    return indptr, indices
+
+
+def vector_csr(x, w, indices, indptr, shape):
+    homo_w = jnp.size(w) == 1
+    post = jnp.zeros((shape[1],))
+    for i_pre in range(x.shape[0]):
+        ids = indices[indptr[i_pre]: indptr[i_pre + 1]]
+        post = post.at[ids].add(w * x[i_pre] if homo_w else w[indptr[i_pre]: indptr[i_pre + 1]] * x[i_pre])
+    return post
+
+
+def matrix_csr(xs, w, indices, indptr, shape):
+    homo_w = jnp.size(w) == 1
+    post = jnp.zeros((xs.shape[0], shape[1]))
+    for i_pre in range(xs.shape[1]):
+        ids = indices[indptr[i_pre]: indptr[i_pre + 1]]
+        post = post.at[:, ids].add(
+            w * xs[:, i_pre: i_pre + 1]
+            if homo_w else
+            (w[indptr[i_pre]: indptr[i_pre + 1]] * xs[:, i_pre: i_pre + 1])
+        )
+    return post
+
+
+def csr_vector(x, w, indices, indptr, shape):
+    homo_w = jnp.size(w) == 1
+    out = jnp.zeros([shape[0]])
+    for i in range(shape[0]):
+        ids = indices[indptr[i]: indptr[i + 1]]
+        ws = w if homo_w else w[indptr[i]: indptr[i + 1]]
+        out = out.at[i].set(jnp.sum(x[ids] * ws))
+    return out
+
+
+def csr_matrix(xs, w, indices, indptr, shape):
+    # CSR @ matrix
+    homo_w = jnp.size(w) == 1
+    out = jnp.zeros([shape[0], xs.shape[1]])
+    for i in range(shape[0]):
+        ids = indices[indptr[i]: indptr[i + 1]]
+        ws = w if homo_w else jnp.expand_dims(w[indptr[i]: indptr[i + 1]], axis=1)
+        out = out.at[i].set(jnp.sum(xs[ids] * ws, axis=0))
+    return out
+
+
+class TestVectorCSR(unittest.TestCase):
+    def test_vector_csr(self, ):
+        m, n = 20, 40
+        x = bst.random.rand(m) < 0.1
+        indptr, indices = _get_csr(m, n, 0.1)
+
+        for homo_w in [True, False]:
+            print(f'homo_w = {homo_w}')
+            data = 1.5 if homo_w else bst.init.Normal()(indices.shape)
+            csr = bst.event.CSR([data, indices, indptr], shape=(m, n))
+            y = x @ csr
+            y2 = vector_csr(x, csr.data, indices, indptr, [m, n])
+            self.assertTrue(jnp.allclose(y, y2))
+
+    def test_vector_csr_vmap_vector(self):
+        n_batch, m, n = 10, 20, 40
+        xs = bst.random.rand(n_batch, m) < 0.1
+        indptr, indices = _get_csr(m, n, 0.1)
+
+        for homo_w in [True, False]:
+            data = 1.5 if homo_w else bst.init.Normal()(indices.shape)
+            csr = bst.event.CSR([data, indices, indptr], shape=(m, n))
+            y = jax.vmap(lambda x: x @ csr)(xs)
+            y2 = jax.vmap(lambda x: vector_csr(x, csr.data, indices, indptr, [m, n]))(xs)
+            self.assertTrue(jnp.allclose(y, y2))
+
+
+class TestMatrixCSR(unittest.TestCase):
+    def test_matrix_csr(self):
+        k, m, n = 10, 20, 40
+        x = bst.random.rand(k, m) < 0.1
+        indptr, indices = _get_csr(m, n, 0.1)
+
+        for homo_w in [True, False]:
+            data = 1.5 if homo_w else bst.init.Normal()(indices.shape)
+            csr = bst.event.CSR([data, indices, indptr], shape=(m, n))
+            y = x @ csr
+            y2 = matrix_csr(x, csr.data, indices, indptr, [m, n])
+            self.assertTrue(jnp.allclose(y, y2))
+
+
+class TestCSRVector(unittest.TestCase):
+    def test_csr_vector(self):
+        m, n = 20, 40
+        v = bst.random.rand(n) < 0.1
+        indptr, indices = _get_csr(m, n, 0.1)
+
+        for homo_w in [True, False]:
+            data = 1.5 if homo_w else bst.init.Normal()(indices.shape)
+            csr = bst.event.CSR([data, indices, indptr], shape=(m, n))
+            y = csr @ v
+            y2 = csr_vector(v, csr.data, indices, indptr, [m, n])
+            self.assertTrue(jnp.allclose(y, y2))
+
+
+class TestCSRMatrix(unittest.TestCase):
+    def test_csr_matrix(self):
+        m, n, k = 20, 40, 10
+        matrix = bst.random.rand(n, k) < 0.1
+        indptr, indices = _get_csr(m, n, 0.1)
+
+        for homo_w in [True, False]:
+            data = 1.5 if homo_w else bst.init.Normal()(indices.shape)
+            csr = bst.event.CSR([data, indices, indptr], shape=(m, n))
+            y = csr @ matrix
+            y2 = csr_matrix(matrix, csr.data, indices, indptr, [m, n])
+            self.assertTrue(jnp.allclose(y, y2))
+
+    # @parameterized.product(
+    #     bool_x=[True, False],
+    #     homo_w=[True, False]
+    # )
+    # def test_vjp(self, bool_x, homo_w):
+    #     n_in = 20
+    #     n_out = 30
+    #     if bool_x:
+    #         x = jax.numpy.asarray(bst.random.rand(n_in) < 0.3, dtype=float)
+    #     else:
+    #         x = bst.random.rand(n_in)
+    #
+    #     indptr, indices = _get_csr(n_in, n_out, 0.1)
+    #     fn = bst.event.CSRLinear(n_in, n_out, indptr, indices, 1.5 if homo_w else bst.init.Normal())
+    #     w = fn.weight.value
+    #
+    #     def f(x, w):
+    #         fn.weight.value = w
+    #         return fn(x).sum()
+    #
+    #     r = jax.grad(f, argnums=(0, 1))(x, w)
+    #
+    #     # -------------------
+    #     # TRUE gradients
+    #
+    #     def f2(x, w):
+    #         return true_fn(x, w, indices, indptr, n_out).sum()
+    #
+    #     r2 = jax.grad(f2, argnums=(0, 1))(x, w)
+    #     self.assertTrue(jnp.allclose(r[0], r2[0]))
+    #     self.assertTrue(jnp.allclose(r[1], r2[1]))
+    #
+    # @parameterized.product(
+    #     bool_x=[True, False],
+    #     homo_w=[True, False]
+    # )
+    # def test_jvp(self, bool_x, homo_w):
+    #     n_in = 20
+    #     n_out = 30
+    #     if bool_x:
+    #         x = jax.numpy.asarray(bst.random.rand(n_in) < 0.3, dtype=float)
+    #     else:
+    #         x = bst.random.rand(n_in)
+    #
+    #     indptr, indices = _get_csr(n_in, n_out, 0.1)
+    #     fn = bst.event.CSRLinear(n_in, n_out, indptr, indices,
+    #                              1.5 if homo_w else bst.init.Normal(), grad_mode='jvp')
+    #     w = fn.weight.value
+    #
+    #     def f(x, w):
+    #         fn.weight.value = w
+    #         return fn(x)
+    #
+    #     o1, r1 = jax.jvp(f, (x, w), (jnp.ones_like(x), jnp.ones_like(w)))
+    #
+    #     # -------------------
+    #     # TRUE gradients
+    #
+    #     def f2(x, w):
+    #         return true_fn(x, w, indices, indptr, n_out)
+    #
+    #     o2, r2 = jax.jvp(f2, (x, w), (jnp.ones_like(x), jnp.ones_like(w)))
+    #     self.assertTrue(jnp.allclose(r1, r2))
+    #     self.assertTrue(jnp.allclose(o1, o2))
```
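`vector_csr`, `matrix_csr`, `csr_vector`, and `csr_matrix` are deliberately naive pure-JAX references that the event-driven kernels are compared against. A tiny hand-checked instance of the `vector_csr` scatter, with values chosen for illustration:

```python
import numpy as np
import jax.numpy as jnp

# 2 presynaptic, 3 postsynaptic neurons; each row stores 2 connections.
indptr = np.array([0, 2, 4])
indices = np.array([0, 2, 1, 2])
x = jnp.array([True, False])   # only presynaptic neuron 0 spikes
w = 1.5                        # homogeneous weight

post = jnp.zeros(3)
for i in range(2):
    ids = indices[indptr[i]: indptr[i + 1]]
    post = post.at[ids].add(w * x[i])   # boolean spike acts as 0/1
print(post)  # [1.5 0.  1.5] -- row 0 scatters w to its targets {0, 2}
```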
brainstate/event/_fixedprob_mv.py
CHANGED
```diff
@@ -24,6 +24,7 @@ import jax.numpy as jnp
 import numpy as np
 from jax.interpreters import ad
 
+from brainstate import environ
 from brainstate._state import ParamState
 from brainstate.augment import vmap
 from brainstate.init import param
@@ -111,7 +112,7 @@ class FixedProb(Module):
 
     def update(self, spk: jax.Array) -> Union[jax.Array, u.Quantity]:
         if self.n_conn > 1:
-
+            r = event_fixed_prob(
                 spk,
                 self.weight.value,
                 self.indices,
@@ -123,7 +124,8 @@ class FixedProb(Module):
             weight = self.weight.value
             unit = u.get_unit(weight)
             r = jnp.zeros(spk.shape[:-1] + (self.out_size[-1],), dtype=weight.dtype)
-
+            r = u.maybe_decimal(u.Quantity(r, unit=unit))
+            return u.math.asarray(r, dtype=environ.dftype())
 
 
 def event_fixed_prob(
```
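The degenerate `n_conn <= 1` path now re-attaches the weight's unit and casts the result to the default float dtype from `brainstate.environ`. A minimal sketch of what that path computes, under the assumption that the weight carries a `brainunit` unit; the shapes and values are illustrative:

```python
import brainunit as u
import jax.numpy as jnp
import brainstate as bst

weight = jnp.float32(0.5) * u.mV   # hypothetical weight with a unit
spk = jnp.zeros((8, 100))          # (batch, n_pre) spike input
out_size = 200                     # hypothetical layer output size

unit = u.get_unit(weight)
r = jnp.zeros(spk.shape[:-1] + (out_size,), dtype=weight.dtype)
r = u.maybe_decimal(u.Quantity(r, unit=unit))       # re-attach the unit
r = u.math.asarray(r, dtype=bst.environ.dftype())   # cast to default float type
```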
brainstate/event/_fixedprob_mv_test.py
CHANGED
```diff
@@ -128,4 +128,5 @@ class TestFixedProbCSR(parameterized.TestCase):
 
         o2, r2 = jax.jvp(f2, (x, w), (jnp.ones_like(x), jnp.ones_like(w)))
         self.assertTrue(jnp.allclose(o1, o2))
-
+        # assert jnp.allclose(r1, r2), f'r1={r1}, r2={r2}'
+        self.assertTrue(jnp.allclose(r1, r2, rtol=1e-4, atol=1e-4))
```
|