mlx-cpu 0.30.1__py3-none-manylinux_2_35_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (231)
  1. mlx/__main__.py +27 -0
  2. mlx/_reprlib_fix.py +16 -0
  3. mlx/extension.py +88 -0
  4. mlx/include/mlx/3rdparty/pocketfft.h +3581 -0
  5. mlx/include/mlx/allocator.h +73 -0
  6. mlx/include/mlx/array.h +645 -0
  7. mlx/include/mlx/backend/common/binary.h +97 -0
  8. mlx/include/mlx/backend/common/broadcasting.h +11 -0
  9. mlx/include/mlx/backend/common/buffer_cache.h +157 -0
  10. mlx/include/mlx/backend/common/compiled.h +77 -0
  11. mlx/include/mlx/backend/common/copy.h +50 -0
  12. mlx/include/mlx/backend/common/hadamard.h +109 -0
  13. mlx/include/mlx/backend/common/matmul.h +67 -0
  14. mlx/include/mlx/backend/common/reduce.h +59 -0
  15. mlx/include/mlx/backend/common/slicing.h +20 -0
  16. mlx/include/mlx/backend/common/ternary.h +85 -0
  17. mlx/include/mlx/backend/common/unary.h +29 -0
  18. mlx/include/mlx/backend/common/utils.h +205 -0
  19. mlx/include/mlx/backend/cpu/arange.h +28 -0
  20. mlx/include/mlx/backend/cpu/available.h +9 -0
  21. mlx/include/mlx/backend/cpu/binary.h +517 -0
  22. mlx/include/mlx/backend/cpu/binary_ops.h +98 -0
  23. mlx/include/mlx/backend/cpu/binary_two.h +166 -0
  24. mlx/include/mlx/backend/cpu/compiled_preamble.h +12 -0
  25. mlx/include/mlx/backend/cpu/copy.h +36 -0
  26. mlx/include/mlx/backend/cpu/encoder.h +67 -0
  27. mlx/include/mlx/backend/cpu/eval.h +12 -0
  28. mlx/include/mlx/backend/cpu/gemm.h +26 -0
  29. mlx/include/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
  30. mlx/include/mlx/backend/cpu/jit_compiler.h +20 -0
  31. mlx/include/mlx/backend/cpu/lapack.h +80 -0
  32. mlx/include/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
  33. mlx/include/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
  34. mlx/include/mlx/backend/cpu/simd/base_simd.h +295 -0
  35. mlx/include/mlx/backend/cpu/simd/math.h +193 -0
  36. mlx/include/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
  37. mlx/include/mlx/backend/cpu/simd/simd.h +4 -0
  38. mlx/include/mlx/backend/cpu/simd/type.h +11 -0
  39. mlx/include/mlx/backend/cpu/slicing.h +21 -0
  40. mlx/include/mlx/backend/cpu/ternary.h +154 -0
  41. mlx/include/mlx/backend/cpu/threefry.h +21 -0
  42. mlx/include/mlx/backend/cpu/unary.h +281 -0
  43. mlx/include/mlx/backend/cpu/unary_ops.h +180 -0
  44. mlx/include/mlx/backend/cuda/allocator.h +89 -0
  45. mlx/include/mlx/backend/cuda/conv/conv.h +126 -0
  46. mlx/include/mlx/backend/cuda/cublas_utils.h +96 -0
  47. mlx/include/mlx/backend/cuda/cuda.h +10 -0
  48. mlx/include/mlx/backend/cuda/cuda_utils.h +89 -0
  49. mlx/include/mlx/backend/cuda/cudnn_utils.h +171 -0
  50. mlx/include/mlx/backend/cuda/device/config.h +12 -0
  51. mlx/include/mlx/backend/cuda/device.h +189 -0
  52. mlx/include/mlx/backend/cuda/event.h +78 -0
  53. mlx/include/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
  54. mlx/include/mlx/backend/cuda/gemms/gemv.h +24 -0
  55. mlx/include/mlx/backend/cuda/jit_module.h +119 -0
  56. mlx/include/mlx/backend/cuda/lru_cache.h +189 -0
  57. mlx/include/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
  58. mlx/include/mlx/backend/cuda/quantized/cuda_fp4.h +83 -0
  59. mlx/include/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
  60. mlx/include/mlx/backend/cuda/quantized/quantized.h +45 -0
  61. mlx/include/mlx/backend/cuda/utils.h +46 -0
  62. mlx/include/mlx/backend/cuda/worker.h +55 -0
  63. mlx/include/mlx/backend/gpu/available.h +9 -0
  64. mlx/include/mlx/backend/gpu/copy.h +57 -0
  65. mlx/include/mlx/backend/gpu/eval.h +18 -0
  66. mlx/include/mlx/backend/gpu/slicing.h +36 -0
  67. mlx/include/mlx/backend/metal/allocator.h +79 -0
  68. mlx/include/mlx/backend/metal/binary.h +33 -0
  69. mlx/include/mlx/backend/metal/device.h +283 -0
  70. mlx/include/mlx/backend/metal/jit/includes.h +57 -0
  71. mlx/include/mlx/backend/metal/jit/indexing.h +76 -0
  72. mlx/include/mlx/backend/metal/kernels/arange.h +9 -0
  73. mlx/include/mlx/backend/metal/kernels/atomic.h +345 -0
  74. mlx/include/mlx/backend/metal/kernels/bf16.h +16 -0
  75. mlx/include/mlx/backend/metal/kernels/bf16_math.h +380 -0
  76. mlx/include/mlx/backend/metal/kernels/binary.h +199 -0
  77. mlx/include/mlx/backend/metal/kernels/binary_ops.h +326 -0
  78. mlx/include/mlx/backend/metal/kernels/binary_two.h +244 -0
  79. mlx/include/mlx/backend/metal/kernels/cexpf.h +134 -0
  80. mlx/include/mlx/backend/metal/kernels/complex.h +173 -0
  81. mlx/include/mlx/backend/metal/kernels/copy.h +276 -0
  82. mlx/include/mlx/backend/metal/kernels/defines.h +24 -0
  83. mlx/include/mlx/backend/metal/kernels/erf.h +69 -0
  84. mlx/include/mlx/backend/metal/kernels/expm1f.h +90 -0
  85. mlx/include/mlx/backend/metal/kernels/fft/radix.h +328 -0
  86. mlx/include/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
  87. mlx/include/mlx/backend/metal/kernels/fft.h +486 -0
  88. mlx/include/mlx/backend/metal/kernels/fp4.h +59 -0
  89. mlx/include/mlx/backend/metal/kernels/fp8.h +82 -0
  90. mlx/include/mlx/backend/metal/kernels/fp_quantized.h +1804 -0
  91. mlx/include/mlx/backend/metal/kernels/fp_quantized_nax.h +1059 -0
  92. mlx/include/mlx/backend/metal/kernels/gemv_masked.h +827 -0
  93. mlx/include/mlx/backend/metal/kernels/hadamard.h +182 -0
  94. mlx/include/mlx/backend/metal/kernels/indexing/gather.h +51 -0
  95. mlx/include/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
  96. mlx/include/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
  97. mlx/include/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
  98. mlx/include/mlx/backend/metal/kernels/indexing/masked_scatter.h +38 -0
  99. mlx/include/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
  100. mlx/include/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
  101. mlx/include/mlx/backend/metal/kernels/logsumexp.h +140 -0
  102. mlx/include/mlx/backend/metal/kernels/quantized.h +2502 -0
  103. mlx/include/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
  104. mlx/include/mlx/backend/metal/kernels/quantized_utils.h +90 -0
  105. mlx/include/mlx/backend/metal/kernels/reduce.h +5 -0
  106. mlx/include/mlx/backend/metal/kernels/reduce_utils.h +6 -0
  107. mlx/include/mlx/backend/metal/kernels/reduction/ops.h +275 -0
  108. mlx/include/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
  109. mlx/include/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
  110. mlx/include/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
  111. mlx/include/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
  112. mlx/include/mlx/backend/metal/kernels/scan.h +514 -0
  113. mlx/include/mlx/backend/metal/kernels/sdpa_vector.h +415 -0
  114. mlx/include/mlx/backend/metal/kernels/softmax.h +190 -0
  115. mlx/include/mlx/backend/metal/kernels/sort.h +715 -0
  116. mlx/include/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
  117. mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +476 -0
  118. mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
  119. mlx/include/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
  120. mlx/include/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
  121. mlx/include/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
  122. mlx/include/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
  123. mlx/include/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
  124. mlx/include/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
  125. mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
  126. mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
  127. mlx/include/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
  128. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
  129. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
  130. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
  131. mlx/include/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
  132. mlx/include/mlx/backend/metal/kernels/steel/defines.h +7 -0
  133. mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
  134. mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +156 -0
  135. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
  136. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +207 -0
  137. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
  138. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +132 -0
  139. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
  140. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
  141. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
  142. mlx/include/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
  143. mlx/include/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
  144. mlx/include/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
  145. mlx/include/mlx/backend/metal/kernels/steel/gemm/params.h +64 -0
  146. mlx/include/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
  147. mlx/include/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
  148. mlx/include/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
  149. mlx/include/mlx/backend/metal/kernels/steel/utils.h +42 -0
  150. mlx/include/mlx/backend/metal/kernels/ternary.h +145 -0
  151. mlx/include/mlx/backend/metal/kernels/ternary_ops.h +10 -0
  152. mlx/include/mlx/backend/metal/kernels/unary.h +63 -0
  153. mlx/include/mlx/backend/metal/kernels/unary_ops.h +454 -0
  154. mlx/include/mlx/backend/metal/kernels/utils.h +444 -0
  155. mlx/include/mlx/backend/metal/matmul.h +144 -0
  156. mlx/include/mlx/backend/metal/metal.h +22 -0
  157. mlx/include/mlx/backend/metal/reduce.h +41 -0
  158. mlx/include/mlx/backend/metal/resident.h +32 -0
  159. mlx/include/mlx/backend/metal/scan.h +17 -0
  160. mlx/include/mlx/backend/metal/ternary.h +21 -0
  161. mlx/include/mlx/backend/metal/unary.h +21 -0
  162. mlx/include/mlx/backend/metal/utils.h +84 -0
  163. mlx/include/mlx/backend/no_gpu/apple_memory.h +16 -0
  164. mlx/include/mlx/backend/no_gpu/linux_memory.h +22 -0
  165. mlx/include/mlx/compile.h +44 -0
  166. mlx/include/mlx/compile_impl.h +69 -0
  167. mlx/include/mlx/device.h +31 -0
  168. mlx/include/mlx/distributed/distributed.h +60 -0
  169. mlx/include/mlx/distributed/distributed_impl.h +59 -0
  170. mlx/include/mlx/distributed/jaccl/jaccl.h +12 -0
  171. mlx/include/mlx/distributed/mpi/mpi.h +12 -0
  172. mlx/include/mlx/distributed/mpi/mpi_declarations.h +28 -0
  173. mlx/include/mlx/distributed/nccl/nccl.h +12 -0
  174. mlx/include/mlx/distributed/ops.h +56 -0
  175. mlx/include/mlx/distributed/primitives.h +156 -0
  176. mlx/include/mlx/distributed/reduction_ops.h +38 -0
  177. mlx/include/mlx/distributed/ring/ring.h +12 -0
  178. mlx/include/mlx/distributed/utils.h +67 -0
  179. mlx/include/mlx/dtype.h +115 -0
  180. mlx/include/mlx/dtype_utils.h +119 -0
  181. mlx/include/mlx/einsum.h +22 -0
  182. mlx/include/mlx/event.h +58 -0
  183. mlx/include/mlx/export.h +136 -0
  184. mlx/include/mlx/export_impl.h +98 -0
  185. mlx/include/mlx/fast.h +102 -0
  186. mlx/include/mlx/fast_primitives.h +427 -0
  187. mlx/include/mlx/fence.h +39 -0
  188. mlx/include/mlx/fft.h +167 -0
  189. mlx/include/mlx/graph_utils.h +66 -0
  190. mlx/include/mlx/io/gguf.h +20 -0
  191. mlx/include/mlx/io/load.h +175 -0
  192. mlx/include/mlx/io.h +61 -0
  193. mlx/include/mlx/linalg.h +111 -0
  194. mlx/include/mlx/memory.h +78 -0
  195. mlx/include/mlx/mlx.h +25 -0
  196. mlx/include/mlx/ops.h +1627 -0
  197. mlx/include/mlx/primitives.h +2524 -0
  198. mlx/include/mlx/random.h +282 -0
  199. mlx/include/mlx/scheduler.h +188 -0
  200. mlx/include/mlx/small_vector.h +540 -0
  201. mlx/include/mlx/stream.h +41 -0
  202. mlx/include/mlx/threadpool.h +133 -0
  203. mlx/include/mlx/transforms.h +229 -0
  204. mlx/include/mlx/transforms_impl.h +86 -0
  205. mlx/include/mlx/types/bf16.h +187 -0
  206. mlx/include/mlx/types/complex.h +113 -0
  207. mlx/include/mlx/types/fp16.h +234 -0
  208. mlx/include/mlx/types/half_types.h +58 -0
  209. mlx/include/mlx/types/limits.h +70 -0
  210. mlx/include/mlx/utils.h +175 -0
  211. mlx/include/mlx/version.h +20 -0
  212. mlx/lib/libmlx.so +0 -0
  213. mlx/py.typed +1 -0
  214. mlx/share/cmake/MLX/FindNCCL.cmake +54 -0
  215. mlx/share/cmake/MLX/Findnvpl.cmake +3 -0
  216. mlx/share/cmake/MLX/MLXConfig.cmake +66 -0
  217. mlx/share/cmake/MLX/MLXConfigVersion.cmake +65 -0
  218. mlx/share/cmake/MLX/MLXTargets-release.cmake +19 -0
  219. mlx/share/cmake/MLX/MLXTargets.cmake +106 -0
  220. mlx/share/cmake/MLX/extension.cmake +50 -0
  221. mlx/utils.py +325 -0
  222. mlx_cpu-0.30.1.dist-info/METADATA +142 -0
  223. mlx_cpu-0.30.1.dist-info/RECORD +231 -0
  224. mlx_cpu-0.30.1.dist-info/WHEEL +5 -0
  225. mlx_cpu-0.30.1.dist-info/licenses/LICENSE +21 -0
  226. mlx_cpu-0.30.1.dist-info/sboms/auditwheel.cdx.json +1 -0
  227. mlx_cpu-0.30.1.dist-info/top_level.txt +1 -0
  228. mlx_cpu.libs/libblas-bd8a282c.so.3.10.0 +0 -0
  229. mlx_cpu.libs/libgfortran-3ec47101.so.5.0.0 +0 -0
  230. mlx_cpu.libs/liblapack-86b2c207.so.3.10.0 +0 -0
  231. mlx_cpu.libs/libquadmath-67d31475.so.0.0.0 +0 -0
@@ -0,0 +1,514 @@
1
+ // Copyright © 2023-2024 Apple Inc.
2
+
3
+ #pragma once
4
+
5
+ #include "mlx/backend/metal/kernels/binary_ops.h"
6
+
7
// Injects a `simd_scan` member (inclusive scan across a 32-lane simdgroup
// using this struct's operator()) into a scan-op struct.
//
// - For element types narrower than 8 bytes, defer to the struct's
//   `simd_scan_impl`, which wraps a hardware prefix intrinsic.
// - For 8-byte types (not covered by the Metal prefix intrinsics), build the
//   scan out of log2(32) = 5 shuffle-up-and-combine steps; lanes shifted in
//   from below are filled with the operator's identity `init`, so low lanes
//   combine with a no-op.
//
// Requires the enclosing struct to define `init`, `operator()`, and
// `simd_scan_impl`.
#define DEFINE_SIMD_SCAN()                                               \
  template <typename T, metal::enable_if_t<sizeof(T) < 8, bool> = true>  \
  T simd_scan(T val) {                                                   \
    return simd_scan_impl(val);                                          \
  }                                                                      \
                                                                         \
  template <typename T, metal::enable_if_t<sizeof(T) == 8, bool> = true> \
  T simd_scan(T val) {                                                   \
    for (int i = 1; i <= 16; i *= 2) {                                   \
      val = operator()(val, simd_shuffle_and_fill_up(val, init, i));     \
    }                                                                    \
    return val;                                                          \
  }
20
+
21
// Injects a `simd_exclusive_scan` member (exclusive scan across a 32-lane
// simdgroup) into a scan-op struct.
//
// - For element types narrower than 8 bytes, defer to the struct's
//   `simd_exclusive_scan_impl` (hardware exclusive-prefix intrinsic).
// - For 8-byte types, compute the inclusive scan (via the companion
//   DEFINE_SIMD_SCAN above) and shift it up one lane, filling lane 0 with
//   the identity `init` — the standard inclusive-to-exclusive conversion.
//
// Requires the enclosing struct to define `init`, `simd_scan`, and
// `simd_exclusive_scan_impl`.
#define DEFINE_SIMD_EXCLUSIVE_SCAN()                                     \
  template <typename T, metal::enable_if_t<sizeof(T) < 8, bool> = true>  \
  T simd_exclusive_scan(T val) {                                         \
    return simd_exclusive_scan_impl(val);                                \
  }                                                                      \
                                                                         \
  template <typename T, metal::enable_if_t<sizeof(T) == 8, bool> = true> \
  T simd_exclusive_scan(T val) {                                         \
    val = simd_scan(val);                                                \
    return simd_shuffle_and_fill_up(val, init, 1);                       \
  }
32
+
33
// Cumulative-sum scan operator: combines with `+`, identity 0.
// U is the accumulator type; operator() accepts any addable element type T.
template <typename U>
struct CumSum {
  DEFINE_SIMD_SCAN()
  DEFINE_SIMD_EXCLUSIVE_SCAN()

  // Identity element for addition.
  static constexpr constant U init = static_cast<U>(0);

  // Fold one element into the running sum.
  template <typename T>
  U operator()(U a, T b) {
    return a + b;
  }

  // Hardware inclusive prefix sum across the simdgroup (sub-8-byte types).
  U simd_scan_impl(U x) {
    return simd_prefix_inclusive_sum(x);
  }

  // Hardware exclusive prefix sum across the simdgroup (sub-8-byte types).
  U simd_exclusive_scan_impl(U x) {
    return simd_prefix_exclusive_sum(x);
  }
};
53
+
54
// Cumulative-product scan operator: combines with `*`, identity 1.
template <typename U>
struct CumProd {
  DEFINE_SIMD_SCAN()
  DEFINE_SIMD_EXCLUSIVE_SCAN()

  // Identity element for multiplication.
  static constexpr constant U init = static_cast<U>(1.0f);

  // Fold one element into the running product.
  template <typename T>
  U operator()(U a, T b) {
    return a * b;
  }

  // Hardware inclusive prefix product across the simdgroup.
  U simd_scan_impl(U x) {
    return simd_prefix_inclusive_product(x);
  }

  // Hardware exclusive prefix product across the simdgroup.
  U simd_exclusive_scan_impl(U x) {
    return simd_prefix_exclusive_product(x);
  }
};
74
+
75
// Boolean specialization of CumProd: a running product of bools is a
// running logical AND, so the shuffle-based scans are written directly
// instead of using the prefix-product intrinsics.
template <>
struct CumProd<bool> {
  // Identity for AND.
  static constexpr constant bool init = true;

  // Fold one element in; `&` (non-short-circuiting) keeps this branch-free.
  template <typename T>
  bool operator()(bool a, T b) {
    return a & static_cast<bool>(b);
  }

  // Inclusive AND-scan across the simdgroup via 5 shuffle-up steps;
  // shifted-in lanes are filled with `init` (true), a no-op for AND.
  bool simd_scan(bool x) {
    for (int i = 1; i <= 16; i *= 2) {
      bool other = simd_shuffle_and_fill_up(x, init, i);
      x &= other;
    }
    return x;
  }

  // Exclusive scan = inclusive scan shifted up one lane, lane 0 = init.
  bool simd_exclusive_scan(bool x) {
    x = simd_scan(x);
    return simd_shuffle_and_fill_up(x, init, 1);
  }
};
97
+
98
// Cumulative-maximum scan operator. The identity is Limits<U>::min
// (declared elsewhere in the MLX headers; presumably the lowest value of U,
// so combining with it is a no-op).
template <typename U>
struct CumMax {
  static constexpr constant U init = Limits<U>::min;

  // Keep the larger of the two; `>=` keeps `a` on ties.
  template <typename T>
  U operator()(U a, T b) {
    return (a >= b) ? a : b;
  }

  // Inclusive max-scan across the simdgroup via 5 shuffle-up steps.
  U simd_scan(U x) {
    for (int i = 1; i <= 16; i *= 2) {
      U other = simd_shuffle_and_fill_up(x, init, i);
      x = (x >= other) ? x : other;
    }
    return x;
  }

  // Exclusive scan = inclusive scan shifted up one lane, lane 0 = init.
  U simd_exclusive_scan(U x) {
    x = simd_scan(x);
    return simd_shuffle_and_fill_up(x, init, 1);
  }
};
120
+
121
// Cumulative-minimum scan operator; mirror image of CumMax. The identity is
// Limits<U>::max (declared elsewhere in the MLX headers).
template <typename U>
struct CumMin {
  static constexpr constant U init = Limits<U>::max;

  // Keep the smaller of the two; `<=` keeps `a` on ties.
  template <typename T>
  U operator()(U a, T b) {
    return (a <= b) ? a : b;
  }

  // Inclusive min-scan across the simdgroup via 5 shuffle-up steps.
  U simd_scan(U x) {
    for (int i = 1; i <= 16; i *= 2) {
      U other = simd_shuffle_and_fill_up(x, init, i);
      x = (x <= other) ? x : other;
    }
    return x;
  }

  // Exclusive scan = inclusive scan shifted up one lane, lane 0 = init.
  U simd_exclusive_scan(U x) {
    x = simd_scan(x);
    return simd_shuffle_and_fill_up(x, init, 1);
  }
};
143
+
144
// Cumulative log-add-exp scan operator: combines with LogAddExp (defined in
// the included binary_ops.h). Identity is Limits<U>::min — presumably the
// most negative value of U, which acts as log(0) for logaddexp.
template <typename U>
struct CumLogaddexp {
  static constexpr constant U init = Limits<U>::min;

  // Fold one element in: logaddexp(a, b) = log(exp(a) + exp(b)),
  // delegated to the shared LogAddExp binary op.
  template <typename T>
  U operator()(U a, T b) {
    return LogAddExp{}(a, static_cast<U>(b));
  }

  // Inclusive logaddexp-scan across the simdgroup via 5 shuffle-up steps.
  U simd_scan(U x) {
    for (int i = 1; i <= 16; i *= 2) {
      U other = simd_shuffle_and_fill_up(x, init, i);
      x = LogAddExp{}(x, other);
    }
    return x;
  }

  // Exclusive scan = inclusive scan shifted up one lane, lane 0 = init.
  U simd_exclusive_scan(U x) {
    x = simd_scan(x);
    return simd_shuffle_and_fill_up(x, init, 1);
  }
};
166
+
167
+ template <typename T, typename U, int N_READS, bool reverse>
168
+ inline void load_unsafe(U values[N_READS], const device T* input) {
169
+ if (reverse) {
170
+ for (int i = 0; i < N_READS; i++) {
171
+ values[N_READS - i - 1] = input[i];
172
+ }
173
+ } else {
174
+ for (int i = 0; i < N_READS; i++) {
175
+ values[i] = input[i];
176
+ }
177
+ }
178
+ }
179
+
180
+ template <typename T, typename U, int N_READS, bool reverse>
181
+ inline void load_safe(
182
+ U values[N_READS],
183
+ const device T* input,
184
+ int start,
185
+ int total,
186
+ U init) {
187
+ if (reverse) {
188
+ for (int i = 0; i < N_READS; i++) {
189
+ values[N_READS - i - 1] =
190
+ (start + N_READS - i - 1 < total) ? input[i] : init;
191
+ }
192
+ } else {
193
+ for (int i = 0; i < N_READS; i++) {
194
+ values[i] = (start + i < total) ? input[i] : init;
195
+ }
196
+ }
197
+ }
198
+
199
+ template <typename U, int N_READS, bool reverse>
200
+ inline void write_unsafe(U values[N_READS], device U* out) {
201
+ if (reverse) {
202
+ for (int i = 0; i < N_READS; i++) {
203
+ out[i] = values[N_READS - i - 1];
204
+ }
205
+ } else {
206
+ for (int i = 0; i < N_READS; i++) {
207
+ out[i] = values[i];
208
+ }
209
+ }
210
+ }
211
+
212
+ template <typename U, int N_READS, bool reverse>
213
+ inline void write_safe(U values[N_READS], device U* out, int start, int total) {
214
+ if (reverse) {
215
+ for (int i = 0; i < N_READS; i++) {
216
+ if (start + N_READS - i - 1 < total) {
217
+ out[i] = values[N_READS - i - 1];
218
+ }
219
+ }
220
+ } else {
221
+ for (int i = 0; i < N_READS; i++) {
222
+ if (start + i < total) {
223
+ out[i] = values[i];
224
+ }
225
+ }
226
+ }
227
+ }
228
+
229
// Scan (cumsum/cumprod/cummax/...) along a contiguous innermost axis.
//
// One threadgroup processes one row of `axis_size` elements (gid.y/gid.z
// select the row). The row is consumed in tiles of lsize.x * N_READS
// elements; the running total of all earlier tiles is carried in `prefix`.
//
// Template parameters:
//   T, U       input / accumulator element types
//   Op         scan operator (CumSum, CumMax, ...) providing `init`,
//              `operator()`, `simd_scan`, `simd_exclusive_scan`
//   N_READS    elements handled per thread per tile
//   inclusive  whether each element is included in its own output
//   reverse    scan from the end of the axis toward the start
template <
    typename T,
    typename U,
    typename Op,
    int N_READS,
    bool inclusive,
    bool reverse>
[[kernel]] void contiguous_scan(
    const device T* in [[buffer(0)]],
    device U* out [[buffer(1)]],
    const constant size_t& axis_size [[buffer(2)]],
    uint3 gid [[threadgroup_position_in_grid]],
    uint3 gsize [[threadgroups_per_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint3 lsize [[threads_per_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  constexpr int simd_size = 32;
  Op op;

  // Position the pointers at the start of this threadgroup's row.
  size_t offset = (gid.y + gsize.y * size_t(gid.z)) * axis_size;
  in += offset;
  out += offset;

  // Compute the number of simd_groups.
  uint simd_groups = lsize.x / simd_size;

  // Allocate memory. `prefix` carries the combined total of all previous
  // tiles; `simdgroup_sums` holds one partial per simdgroup (fixed size 32
  // — assumes at most 32 simdgroups, i.e. lsize.x <= 1024).
  U prefix = Op::init;
  U values[N_READS];
  threadgroup U simdgroup_sums[32];

  // Loop over the reduced axis in blocks of size ceildiv(axis_size,
  // N_READS*lsize):
  //   Read block
  //   Compute inclusive scan of the block
  //     Compute inclusive scan per thread
  //     Compute exclusive scan of thread sums in simdgroup
  //     Write simdgroup sums in SM
  //     Compute exclusive scan of simdgroup sums
  //     Compute the output by scanning prefix, prev_simdgroup, prev_thread,
  //     value
  //   Write block
  for (uint r = 0; r < ceildiv(axis_size, N_READS * lsize.x); r++) {
    // Compute the block offset — this thread's element offset within the
    // row (shadows the row offset above).
    uint offset = r * lsize.x * N_READS + lid.x * N_READS;

    // Read the values. For reverse scans the chunk is read from the mirror
    // position at the end of the row; load_safe fills out-of-range slots
    // with Op::init so they are no-ops in the combine.
    if (reverse) {
      if ((offset + N_READS) < axis_size) {
        load_unsafe<T, U, N_READS, reverse>(
            values, in + axis_size - offset - N_READS);
      } else {
        load_safe<T, U, N_READS, reverse>(
            values,
            in + axis_size - offset - N_READS,
            offset,
            axis_size,
            Op::init);
      }
    } else {
      if ((offset + N_READS) < axis_size) {
        load_unsafe<T, U, N_READS, reverse>(values, in + offset);
      } else {
        load_safe<T, U, N_READS, reverse>(
            values, in + offset, offset, axis_size, Op::init);
      }
    }

    // Compute an inclusive scan per thread over its N_READS registers.
    for (int i = 1; i < N_READS; i++) {
      values[i] = op(values[i], values[i - 1]);
    }

    // Compute exclusive scan of thread sums: combined total of all
    // lower-lane threads in this simdgroup.
    U prev_thread = op.simd_exclusive_scan(values[N_READS - 1]);

    // Write simdgroup_sums to SM: lane 31 knows the whole simdgroup's
    // total (its exclusive prefix combined with its own last value).
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (simd_lane_id == simd_size - 1) {
      simdgroup_sums[simd_group_id] = op(prev_thread, values[N_READS - 1]);
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Compute exclusive scan of simdgroup_sums (a single simdgroup
    // suffices since there are at most 32 entries).
    if (simd_group_id == 0) {
      U prev_simdgroup = op.simd_exclusive_scan(simdgroup_sums[simd_lane_id]);
      simdgroup_sums[simd_lane_id] = prev_simdgroup;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Compute the output: fold in the inter-tile prefix, the preceding
    // simdgroups' total, and the preceding threads' total.
    for (int i = 0; i < N_READS; i++) {
      values[i] = op(values[i], prefix);
      values[i] = op(values[i], simdgroup_sums[simd_group_id]);
      values[i] = op(values[i], prev_thread);
    }

    // Write the values. For an exclusive scan every write is shifted one
    // element toward the open end of the axis, and the boundary element
    // (last for reverse, first otherwise) is set to Op::init once, by the
    // thread handling offset 0.
    if (reverse) {
      if (inclusive) {
        if ((offset + N_READS) < axis_size) {
          write_unsafe<U, N_READS, reverse>(
              values, out + axis_size - offset - N_READS);
        } else {
          write_safe<U, N_READS, reverse>(
              values, out + axis_size - offset - N_READS, offset, axis_size);
        }
      } else {
        if (lid.x == 0 && offset == 0) {
          out[axis_size - 1] = Op::init;
        }
        if ((offset + N_READS + 1) < axis_size) {
          write_unsafe<U, N_READS, reverse>(
              values, out + axis_size - offset - 1 - N_READS);
        } else {
          write_safe<U, N_READS, reverse>(
              values,
              out + axis_size - offset - 1 - N_READS,
              offset + 1,
              axis_size);
        }
      }
    } else {
      if (inclusive) {
        if ((offset + N_READS) < axis_size) {
          write_unsafe<U, N_READS, reverse>(values, out + offset);
        } else {
          write_safe<U, N_READS, reverse>(
              values, out + offset, offset, axis_size);
        }
      } else {
        if (lid.x == 0 && offset == 0) {
          out[0] = Op::init;
        }
        if ((offset + N_READS + 1) < axis_size) {
          write_unsafe<U, N_READS, reverse>(values, out + offset + 1);
        } else {
          write_safe<U, N_READS, reverse>(
              values, out + offset + 1, offset + 1, axis_size);
        }
      }
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Share the prefix: the last lane of the last simdgroup holds the
    // tile's grand total; broadcast it through shared memory.
    if (simd_group_id == simd_groups - 1 && simd_lane_id == simd_size - 1) {
      simdgroup_sums[0] = values[N_READS - 1];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    prefix = simdgroup_sums[0];
  }
}
384
+
385
// Scan along a non-contiguous axis with element stride `stride`.
//
// The `stride` dimension is split into blocks of BN columns
// (`stride_blocks` per scanned row group); each threadgroup owns one
// column block of one row group and walks the scan axis in tiles of BM
// rows. Each tile is staged in threadgroup memory (rows padded to BN_pad —
// presumably to avoid shared-memory bank conflicts), re-read so that each
// simdgroup lane holds one row of `n_scans` columns, scanned down the 32
// rows with op.simd_scan, and written back. Per-column running totals are
// carried across tiles in `prefix` via a lane-31 broadcast shuffle.
template <
    typename T,
    typename U,
    typename Op,
    int N_READS,
    bool inclusive,
    bool reverse>
[[kernel]] void strided_scan(
    const device T* in [[buffer(0)]],
    device U* out [[buffer(1)]],
    const constant size_t& axis_size [[buffer(2)]],
    const constant size_t& stride [[buffer(3)]],
    const constant size_t& stride_blocks [[buffer(4)]],
    uint3 gid [[threadgroup_position_in_grid]],
    uint3 gsize [[threadgroups_per_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  constexpr int simd_size = 32;
  constexpr int BM = 32; // tile rows (along the scan axis)
  constexpr int BN = 32; // tile columns (along the stride)
  // Row pitch in shared memory; the extra 16/sizeof(U) elements skew rows.
  constexpr int BN_pad = 32 + 16 / sizeof(U);
  constexpr int n_simds = BN / N_READS;
  constexpr int n_scans = BN / n_simds; // columns scanned per thread
  Op op;

  threadgroup U read_buffer[BM * BN_pad];
  U values[n_scans];
  U prefix[n_scans];
  for (int i = 0; i < n_scans; i++) {
    prefix[i] = Op::init;
  }

  // Compute offsets: which row group and which BN-column block this
  // threadgroup handles.
  size_t full_gid = gid.y + gsize.y * size_t(gid.z);
  size_t offset = full_gid / stride_blocks * axis_size * stride;
  size_t global_index_x = full_gid % stride_blocks * BN;
  // Coordinates used when staging the tile into shared memory ...
  uint read_offset_y = (lid.x * N_READS) / BN;
  uint read_offset_x = (lid.x * N_READS) % BN;
  // ... and when re-reading it for the scan (lane = row, simdgroup picks
  // its n_scans columns).
  uint scan_offset_y = simd_lane_id;
  uint scan_offset_x = simd_group_id * n_scans;

  // Number of valid columns left in this block (guards the last, possibly
  // partial, column block).
  uint stride_limit = stride - global_index_x;
  in += offset + global_index_x + read_offset_x;
  out += offset + global_index_x + read_offset_x;
  threadgroup U* read_into =
      read_buffer + read_offset_y * BN_pad + read_offset_x;
  threadgroup U* read_from =
      read_buffer + scan_offset_y * BN_pad + scan_offset_x;

  for (uint j = 0; j < axis_size; j += BM) {
    // Calculate the indices for the current thread: `check_index_y` is the
    // logical scan position (for bounds checks), `index_y` the physical row
    // (mirrored when scanning in reverse).
    uint index_y = j + read_offset_y;
    uint check_index_y = index_y;
    if (reverse) {
      index_y = axis_size - 1 - index_y;
    }

    // Read the tile into SM; out-of-range slots get Op::init so they are
    // no-ops in the combine.
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (check_index_y < axis_size && (read_offset_x + N_READS) < stride_limit) {
      for (int i = 0; i < N_READS; i++) {
        read_into[i] = in[index_y * stride + i];
      }
    } else {
      for (int i = 0; i < N_READS; i++) {
        if (check_index_y < axis_size && (read_offset_x + i) < stride_limit) {
          read_into[i] = in[index_y * stride + i];
        } else {
          read_into[i] = Op::init;
        }
      }
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Read strided into registers: lane i of each simdgroup takes row i of
    // its column slice.
    for (int i = 0; i < n_scans; i++) {
      values[i] = read_from[i];
    }
    simdgroup_barrier(mem_flags::mem_threadgroup);

    // Perform the scan: inclusive simd scan down the 32 rows, fold in the
    // carry from previous tiles, then broadcast lane 31's total as the next
    // tile's carry.
    for (int i = 0; i < n_scans; i++) {
      values[i] = op.simd_scan(values[i]);
      values[i] = op(values[i], prefix[i]);
      prefix[i] = simd_shuffle(values[i], simd_size - 1);
    }

    // Write scanned values back to SM (same slots they were read from, so
    // after the barrier `read_into` holds this thread's scanned outputs).
    for (int i = 0; i < n_scans; i++) {
      read_from[i] = values[i];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Write to device memory. For an exclusive scan, the boundary row
    // (logical position 0) is written as Op::init and every other write is
    // shifted one row toward the open end of the axis.
    if (!inclusive) {
      if (check_index_y == 0) {
        if ((read_offset_x + N_READS) < stride_limit) {
          for (int i = 0; i < N_READS; i++) {
            out[index_y * stride + i] = Op::init;
          }
        } else {
          for (int i = 0; i < N_READS; i++) {
            if ((read_offset_x + i) < stride_limit) {
              out[index_y * stride + i] = Op::init;
            }
          }
        }
      }
      // Shift the write position by one row. A uint underflow of index_y
      // here (reverse, index_y == 0) is harmless: the matching
      // check_index_y becomes axis_size and fails the bound check below.
      if (reverse) {
        index_y -= 1;
        check_index_y += 1;
      } else {
        index_y += 1;
        check_index_y += 1;
      }
    }
    if (check_index_y < axis_size && (read_offset_x + N_READS) < stride_limit) {
      for (int i = 0; i < N_READS; i++) {
        out[index_y * stride + i] = read_into[i];
      }
    } else {
      for (int i = 0; i < N_READS; i++) {
        if (check_index_y < axis_size && (read_offset_x + i) < stride_limit) {
          out[index_y * stride + i] = read_into[i];
        }
      }
    }
  }
}