mlx-cpu 0.30.1__py3-none-manylinux_2_35_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (231) hide show
  1. mlx/__main__.py +27 -0
  2. mlx/_reprlib_fix.py +16 -0
  3. mlx/extension.py +88 -0
  4. mlx/include/mlx/3rdparty/pocketfft.h +3581 -0
  5. mlx/include/mlx/allocator.h +73 -0
  6. mlx/include/mlx/array.h +645 -0
  7. mlx/include/mlx/backend/common/binary.h +97 -0
  8. mlx/include/mlx/backend/common/broadcasting.h +11 -0
  9. mlx/include/mlx/backend/common/buffer_cache.h +157 -0
  10. mlx/include/mlx/backend/common/compiled.h +77 -0
  11. mlx/include/mlx/backend/common/copy.h +50 -0
  12. mlx/include/mlx/backend/common/hadamard.h +109 -0
  13. mlx/include/mlx/backend/common/matmul.h +67 -0
  14. mlx/include/mlx/backend/common/reduce.h +59 -0
  15. mlx/include/mlx/backend/common/slicing.h +20 -0
  16. mlx/include/mlx/backend/common/ternary.h +85 -0
  17. mlx/include/mlx/backend/common/unary.h +29 -0
  18. mlx/include/mlx/backend/common/utils.h +205 -0
  19. mlx/include/mlx/backend/cpu/arange.h +28 -0
  20. mlx/include/mlx/backend/cpu/available.h +9 -0
  21. mlx/include/mlx/backend/cpu/binary.h +517 -0
  22. mlx/include/mlx/backend/cpu/binary_ops.h +98 -0
  23. mlx/include/mlx/backend/cpu/binary_two.h +166 -0
  24. mlx/include/mlx/backend/cpu/compiled_preamble.h +12 -0
  25. mlx/include/mlx/backend/cpu/copy.h +36 -0
  26. mlx/include/mlx/backend/cpu/encoder.h +67 -0
  27. mlx/include/mlx/backend/cpu/eval.h +12 -0
  28. mlx/include/mlx/backend/cpu/gemm.h +26 -0
  29. mlx/include/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
  30. mlx/include/mlx/backend/cpu/jit_compiler.h +20 -0
  31. mlx/include/mlx/backend/cpu/lapack.h +80 -0
  32. mlx/include/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
  33. mlx/include/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
  34. mlx/include/mlx/backend/cpu/simd/base_simd.h +295 -0
  35. mlx/include/mlx/backend/cpu/simd/math.h +193 -0
  36. mlx/include/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
  37. mlx/include/mlx/backend/cpu/simd/simd.h +4 -0
  38. mlx/include/mlx/backend/cpu/simd/type.h +11 -0
  39. mlx/include/mlx/backend/cpu/slicing.h +21 -0
  40. mlx/include/mlx/backend/cpu/ternary.h +154 -0
  41. mlx/include/mlx/backend/cpu/threefry.h +21 -0
  42. mlx/include/mlx/backend/cpu/unary.h +281 -0
  43. mlx/include/mlx/backend/cpu/unary_ops.h +180 -0
  44. mlx/include/mlx/backend/cuda/allocator.h +89 -0
  45. mlx/include/mlx/backend/cuda/conv/conv.h +126 -0
  46. mlx/include/mlx/backend/cuda/cublas_utils.h +96 -0
  47. mlx/include/mlx/backend/cuda/cuda.h +10 -0
  48. mlx/include/mlx/backend/cuda/cuda_utils.h +89 -0
  49. mlx/include/mlx/backend/cuda/cudnn_utils.h +171 -0
  50. mlx/include/mlx/backend/cuda/device/config.h +12 -0
  51. mlx/include/mlx/backend/cuda/device.h +189 -0
  52. mlx/include/mlx/backend/cuda/event.h +78 -0
  53. mlx/include/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
  54. mlx/include/mlx/backend/cuda/gemms/gemv.h +24 -0
  55. mlx/include/mlx/backend/cuda/jit_module.h +119 -0
  56. mlx/include/mlx/backend/cuda/lru_cache.h +189 -0
  57. mlx/include/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
  58. mlx/include/mlx/backend/cuda/quantized/cuda_fp4.h +83 -0
  59. mlx/include/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
  60. mlx/include/mlx/backend/cuda/quantized/quantized.h +45 -0
  61. mlx/include/mlx/backend/cuda/utils.h +46 -0
  62. mlx/include/mlx/backend/cuda/worker.h +55 -0
  63. mlx/include/mlx/backend/gpu/available.h +9 -0
  64. mlx/include/mlx/backend/gpu/copy.h +57 -0
  65. mlx/include/mlx/backend/gpu/eval.h +18 -0
  66. mlx/include/mlx/backend/gpu/slicing.h +36 -0
  67. mlx/include/mlx/backend/metal/allocator.h +79 -0
  68. mlx/include/mlx/backend/metal/binary.h +33 -0
  69. mlx/include/mlx/backend/metal/device.h +283 -0
  70. mlx/include/mlx/backend/metal/jit/includes.h +57 -0
  71. mlx/include/mlx/backend/metal/jit/indexing.h +76 -0
  72. mlx/include/mlx/backend/metal/kernels/arange.h +9 -0
  73. mlx/include/mlx/backend/metal/kernels/atomic.h +345 -0
  74. mlx/include/mlx/backend/metal/kernels/bf16.h +16 -0
  75. mlx/include/mlx/backend/metal/kernels/bf16_math.h +380 -0
  76. mlx/include/mlx/backend/metal/kernels/binary.h +199 -0
  77. mlx/include/mlx/backend/metal/kernels/binary_ops.h +326 -0
  78. mlx/include/mlx/backend/metal/kernels/binary_two.h +244 -0
  79. mlx/include/mlx/backend/metal/kernels/cexpf.h +134 -0
  80. mlx/include/mlx/backend/metal/kernels/complex.h +173 -0
  81. mlx/include/mlx/backend/metal/kernels/copy.h +276 -0
  82. mlx/include/mlx/backend/metal/kernels/defines.h +24 -0
  83. mlx/include/mlx/backend/metal/kernels/erf.h +69 -0
  84. mlx/include/mlx/backend/metal/kernels/expm1f.h +90 -0
  85. mlx/include/mlx/backend/metal/kernels/fft/radix.h +328 -0
  86. mlx/include/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
  87. mlx/include/mlx/backend/metal/kernels/fft.h +486 -0
  88. mlx/include/mlx/backend/metal/kernels/fp4.h +59 -0
  89. mlx/include/mlx/backend/metal/kernels/fp8.h +82 -0
  90. mlx/include/mlx/backend/metal/kernels/fp_quantized.h +1804 -0
  91. mlx/include/mlx/backend/metal/kernels/fp_quantized_nax.h +1059 -0
  92. mlx/include/mlx/backend/metal/kernels/gemv_masked.h +827 -0
  93. mlx/include/mlx/backend/metal/kernels/hadamard.h +182 -0
  94. mlx/include/mlx/backend/metal/kernels/indexing/gather.h +51 -0
  95. mlx/include/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
  96. mlx/include/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
  97. mlx/include/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
  98. mlx/include/mlx/backend/metal/kernels/indexing/masked_scatter.h +38 -0
  99. mlx/include/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
  100. mlx/include/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
  101. mlx/include/mlx/backend/metal/kernels/logsumexp.h +140 -0
  102. mlx/include/mlx/backend/metal/kernels/quantized.h +2502 -0
  103. mlx/include/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
  104. mlx/include/mlx/backend/metal/kernels/quantized_utils.h +90 -0
  105. mlx/include/mlx/backend/metal/kernels/reduce.h +5 -0
  106. mlx/include/mlx/backend/metal/kernels/reduce_utils.h +6 -0
  107. mlx/include/mlx/backend/metal/kernels/reduction/ops.h +275 -0
  108. mlx/include/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
  109. mlx/include/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
  110. mlx/include/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
  111. mlx/include/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
  112. mlx/include/mlx/backend/metal/kernels/scan.h +514 -0
  113. mlx/include/mlx/backend/metal/kernels/sdpa_vector.h +415 -0
  114. mlx/include/mlx/backend/metal/kernels/softmax.h +190 -0
  115. mlx/include/mlx/backend/metal/kernels/sort.h +715 -0
  116. mlx/include/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
  117. mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +476 -0
  118. mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
  119. mlx/include/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
  120. mlx/include/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
  121. mlx/include/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
  122. mlx/include/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
  123. mlx/include/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
  124. mlx/include/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
  125. mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
  126. mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
  127. mlx/include/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
  128. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
  129. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
  130. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
  131. mlx/include/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
  132. mlx/include/mlx/backend/metal/kernels/steel/defines.h +7 -0
  133. mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
  134. mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +156 -0
  135. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
  136. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +207 -0
  137. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
  138. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +132 -0
  139. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
  140. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
  141. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
  142. mlx/include/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
  143. mlx/include/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
  144. mlx/include/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
  145. mlx/include/mlx/backend/metal/kernels/steel/gemm/params.h +64 -0
  146. mlx/include/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
  147. mlx/include/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
  148. mlx/include/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
  149. mlx/include/mlx/backend/metal/kernels/steel/utils.h +42 -0
  150. mlx/include/mlx/backend/metal/kernels/ternary.h +145 -0
  151. mlx/include/mlx/backend/metal/kernels/ternary_ops.h +10 -0
  152. mlx/include/mlx/backend/metal/kernels/unary.h +63 -0
  153. mlx/include/mlx/backend/metal/kernels/unary_ops.h +454 -0
  154. mlx/include/mlx/backend/metal/kernels/utils.h +444 -0
  155. mlx/include/mlx/backend/metal/matmul.h +144 -0
  156. mlx/include/mlx/backend/metal/metal.h +22 -0
  157. mlx/include/mlx/backend/metal/reduce.h +41 -0
  158. mlx/include/mlx/backend/metal/resident.h +32 -0
  159. mlx/include/mlx/backend/metal/scan.h +17 -0
  160. mlx/include/mlx/backend/metal/ternary.h +21 -0
  161. mlx/include/mlx/backend/metal/unary.h +21 -0
  162. mlx/include/mlx/backend/metal/utils.h +84 -0
  163. mlx/include/mlx/backend/no_gpu/apple_memory.h +16 -0
  164. mlx/include/mlx/backend/no_gpu/linux_memory.h +22 -0
  165. mlx/include/mlx/compile.h +44 -0
  166. mlx/include/mlx/compile_impl.h +69 -0
  167. mlx/include/mlx/device.h +31 -0
  168. mlx/include/mlx/distributed/distributed.h +60 -0
  169. mlx/include/mlx/distributed/distributed_impl.h +59 -0
  170. mlx/include/mlx/distributed/jaccl/jaccl.h +12 -0
  171. mlx/include/mlx/distributed/mpi/mpi.h +12 -0
  172. mlx/include/mlx/distributed/mpi/mpi_declarations.h +28 -0
  173. mlx/include/mlx/distributed/nccl/nccl.h +12 -0
  174. mlx/include/mlx/distributed/ops.h +56 -0
  175. mlx/include/mlx/distributed/primitives.h +156 -0
  176. mlx/include/mlx/distributed/reduction_ops.h +38 -0
  177. mlx/include/mlx/distributed/ring/ring.h +12 -0
  178. mlx/include/mlx/distributed/utils.h +67 -0
  179. mlx/include/mlx/dtype.h +115 -0
  180. mlx/include/mlx/dtype_utils.h +119 -0
  181. mlx/include/mlx/einsum.h +22 -0
  182. mlx/include/mlx/event.h +58 -0
  183. mlx/include/mlx/export.h +136 -0
  184. mlx/include/mlx/export_impl.h +98 -0
  185. mlx/include/mlx/fast.h +102 -0
  186. mlx/include/mlx/fast_primitives.h +427 -0
  187. mlx/include/mlx/fence.h +39 -0
  188. mlx/include/mlx/fft.h +167 -0
  189. mlx/include/mlx/graph_utils.h +66 -0
  190. mlx/include/mlx/io/gguf.h +20 -0
  191. mlx/include/mlx/io/load.h +175 -0
  192. mlx/include/mlx/io.h +61 -0
  193. mlx/include/mlx/linalg.h +111 -0
  194. mlx/include/mlx/memory.h +78 -0
  195. mlx/include/mlx/mlx.h +25 -0
  196. mlx/include/mlx/ops.h +1627 -0
  197. mlx/include/mlx/primitives.h +2524 -0
  198. mlx/include/mlx/random.h +282 -0
  199. mlx/include/mlx/scheduler.h +188 -0
  200. mlx/include/mlx/small_vector.h +540 -0
  201. mlx/include/mlx/stream.h +41 -0
  202. mlx/include/mlx/threadpool.h +133 -0
  203. mlx/include/mlx/transforms.h +229 -0
  204. mlx/include/mlx/transforms_impl.h +86 -0
  205. mlx/include/mlx/types/bf16.h +187 -0
  206. mlx/include/mlx/types/complex.h +113 -0
  207. mlx/include/mlx/types/fp16.h +234 -0
  208. mlx/include/mlx/types/half_types.h +58 -0
  209. mlx/include/mlx/types/limits.h +70 -0
  210. mlx/include/mlx/utils.h +175 -0
  211. mlx/include/mlx/version.h +20 -0
  212. mlx/lib/libmlx.so +0 -0
  213. mlx/py.typed +1 -0
  214. mlx/share/cmake/MLX/FindNCCL.cmake +54 -0
  215. mlx/share/cmake/MLX/Findnvpl.cmake +3 -0
  216. mlx/share/cmake/MLX/MLXConfig.cmake +66 -0
  217. mlx/share/cmake/MLX/MLXConfigVersion.cmake +65 -0
  218. mlx/share/cmake/MLX/MLXTargets-release.cmake +19 -0
  219. mlx/share/cmake/MLX/MLXTargets.cmake +106 -0
  220. mlx/share/cmake/MLX/extension.cmake +50 -0
  221. mlx/utils.py +325 -0
  222. mlx_cpu-0.30.1.dist-info/METADATA +142 -0
  223. mlx_cpu-0.30.1.dist-info/RECORD +231 -0
  224. mlx_cpu-0.30.1.dist-info/WHEEL +5 -0
  225. mlx_cpu-0.30.1.dist-info/licenses/LICENSE +21 -0
  226. mlx_cpu-0.30.1.dist-info/sboms/auditwheel.cdx.json +1 -0
  227. mlx_cpu-0.30.1.dist-info/top_level.txt +1 -0
  228. mlx_cpu.libs/libblas-bd8a282c.so.3.10.0 +0 -0
  229. mlx_cpu.libs/libgfortran-3ec47101.so.5.0.0 +0 -0
  230. mlx_cpu.libs/liblapack-86b2c207.so.3.10.0 +0 -0
  231. mlx_cpu.libs/libquadmath-67d31475.so.0.0.0 +0 -0
@@ -0,0 +1,369 @@
1
+ // Copyright © 2023-2024 Apple Inc.
2
+
3
+ // Row reduction utilities
4
+ // - `per_thread_row_reduce` collaborative partial reduction in the threadgroup
5
+ // - `threadgroup_reduce` collaborative reduction in the threadgroup such that
6
+ // lid.x == 0 holds the reduced value
7
+ // - `thread_reduce` simple loop and reduce the row
8
+
9
/**
 * Collaborative partial row reduction with bounds checking. The threads of a
 * threadgroup stride across N_WRITES rows; after the call each thread holds
 * one partial accumulator per row in `totals`.
 */
template <
    typename T,
    typename U,
    typename Op,
    int N_READS = REDUCE_N_READS,
    int N_WRITES = REDUCE_N_WRITES>
METAL_FUNC void per_thread_row_reduce(
    thread U totals[N_WRITES],
    const device T* inputs[N_WRITES],
    int blocks,
    int extra,
    uint lsize_x,
    uint lid_x) {
  Op op;

  // One accumulator per row, starting from the reduction identity.
  for (int w = 0; w < N_WRITES; w++) {
    totals[w] = Op::init;
  }

  // Full blocks: each thread folds N_READS contiguous elements from every
  // row, then advances that row's pointer by the whole threadgroup's span.
  for (int b = 0; b < blocks; b++) {
    for (int w = 0; w < N_WRITES; w++) {
      for (int r = 0; r < N_READS; r++) {
        totals[w] = op(static_cast<U>(inputs[w][r]), totals[w]);
      }

      inputs[w] += lsize_x * N_READS;
    }
  }

  // Tail: the last, partially filled block is bounds-checked against
  // `extra`, the number of elements remaining in the row.
  int index = lid_x * N_READS;
  if (index + N_READS <= extra) {
    for (int w = 0; w < N_WRITES; w++) {
      for (int r = 0; r < N_READS; r++) {
        totals[w] = op(static_cast<U>(inputs[w][r]), totals[w]);
      }
    }
  } else {
    for (int w = 0; w < N_WRITES; w++) {
      for (int r = 0; index + r < extra; r++) {
        totals[w] = op(static_cast<U>(inputs[w][r]), totals[w]);
      }
    }
  }
}
60
+
61
/**
 * Consecutive rows in a contiguous array.
 *
 * Builds one input pointer per output row — consecutive rows are exactly
 * `reduction_size` elements apart — and forwards to the bounds-checked
 * reducer above.
 */
template <
    typename T,
    typename U,
    typename Op,
    int N_READS = REDUCE_N_READS,
    int N_WRITES = REDUCE_N_WRITES>
METAL_FUNC void per_thread_row_reduce(
    thread U totals[N_WRITES],
    const device T* in,
    const constant size_t& reduction_size,
    int blocks,
    int extra,
    uint lsize_x,
    uint lid_x) {
  // Set up the input pointers. `inputs` holds N_WRITES entries (one per
  // row), so the fill loop is bounded by N_WRITES — the previous N_READS
  // bound wrote past the end of `inputs` whenever N_READS > N_WRITES.
  const device T* inputs[N_WRITES];
  inputs[0] = in + lid_x * N_READS;
  for (int i = 1; i < N_WRITES; i++) {
    inputs[i] = inputs[i - 1] + reduction_size;
  }

  per_thread_row_reduce<T, U, Op, N_READS, N_WRITES>(
      totals, inputs, blocks, extra, lsize_x, lid_x);
}
88
+
89
/**
 * Consecutive rows in an arbitrarily ordered array.
 *
 * Each row's start is located independently with `elem_to_loc` before
 * forwarding to the bounds-checked reducer above.
 */
template <
    typename T,
    typename U,
    typename Op,
    int N_READS = REDUCE_N_READS,
    int N_WRITES = REDUCE_N_WRITES>
METAL_FUNC void per_thread_row_reduce(
    thread U totals[N_WRITES],
    const device T* in,
    const int64_t row_idx,
    int blocks,
    int extra,
    const constant int* shape,
    const constant int64_t* strides,
    const constant int& ndim,
    uint lsize_x,
    uint lid_x) {
  // Set up the input pointers. `inputs` holds N_WRITES entries (one per
  // row), so the fill loop is bounded by N_WRITES — the previous N_READS
  // bound wrote past the end of `inputs` whenever N_READS > N_WRITES.
  const device T* inputs[N_WRITES];
  in += lid_x * N_READS;
  for (int i = 0; i < N_WRITES; i++) {
    inputs[i] = in + elem_to_loc(row_idx + i, shape, strides, ndim);
  }

  per_thread_row_reduce<T, U, Op, N_READS, N_WRITES>(
      totals, inputs, blocks, extra, lsize_x, lid_x);
}
119
+
120
/**
 * Combine the per-thread partials across the threadgroup so that after the
 * call the thread with lid.x == 0 holds the reduced values.
 */
template <
    typename T,
    typename U,
    typename Op,
    int N_READS = REDUCE_N_READS,
    int N_WRITES = REDUCE_N_WRITES>
METAL_FUNC void threadgroup_reduce(
    thread U totals[N_WRITES],
    threadgroup U* shared_vals,
    uint3 lid [[thread_position_in_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_per_group [[simdgroups_per_threadgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  Op op;

  // Stage 1: reduce within each simdgroup.
  for (int w = 0; w < N_WRITES; w++) {
    totals[w] = op.simd_reduce(totals[w]);
  }

  // Stage 2: combine the simdgroup results through threadgroup memory.
  if (simd_per_group > 1) {
    // Lane 0 of every simdgroup publishes its partials.
    if (simd_lane_id == 0) {
      for (int w = 0; w < N_WRITES; w++) {
        shared_vals[simd_group_id * N_WRITES + w] = totals[w];
      }
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // The first simd_per_group lanes each pick up one published partial;
    // every other lane contributes the identity so the final simd_reduce
    // is unaffected.
    U vals[N_WRITES];
    for (int w = 0; w < N_WRITES; w++) {
      vals[w] = (lid.x < simd_per_group) ? shared_vals[lid.x * N_WRITES + w]
                                         : op.init;
    }

    for (int w = 0; w < N_WRITES; w++) {
      totals[w] = op.simd_reduce(vals[w]);
    }
  }
}
163
+
164
/**
 * Serially reduce a single row in one thread: N_READS-wide unrolled blocks
 * followed by an element-by-element tail of `extra` leftovers.
 */
template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
METAL_FUNC void
thread_reduce(thread U& total, const device T* row, int blocks, int extra) {
  Op op;
  for (int b = 0; b < blocks; b++) {
    // Load the block into registers first, then fold it, so the reads can
    // be issued ahead of the reduction.
    U vals[N_READS];
    for (int r = 0; r < N_READS; r++) {
      vals[r] = row[r];
    }
    for (int r = 0; r < N_READS; r++) {
      total = op(vals[r], total);
    }
    row += N_READS;
  }
  for (int r = 0; r < extra; r++) {
    total = op(*row++, total);
  }
}
182
+
183
+ // Reduction kernels
184
+ // - `row_reduce_small` depending on the non-row reductions and row size it
185
+ // either just loops over everything or a simd collaboratively reduces the
186
+ // non_row reductions. In the first case one thread is responsible for one
187
+ // output on the 2nd one simd is responsible for one output.
188
+ // - `row_reduce_simple` simple contiguous row reduction
189
+ // - `row_reduce_looped` simply loop and reduce each row for each non-row
190
+ // reduction. One threadgroup is responsible for one output.
191
+
192
/**
 * Row reduction for small problems. Depending on the size, either one
 * thread walks every non-row reduction by itself, or one simdgroup shares
 * the non-row reductions for a single output.
 */
template <
    typename T,
    typename U,
    typename Op,
    typename IdxT,
    int NDIMS,
    int N_READS = REDUCE_N_READS>
[[kernel]] void row_reduce_small(
    const device T* in [[buffer(0)]],
    device U* out [[buffer(1)]],
    const constant int64_t& row_size [[buffer(2)]],
    const constant int64_t& non_row_reductions [[buffer(3)]],
    const constant int* shape [[buffer(4)]],
    const constant int64_t* strides [[buffer(5)]],
    const constant int& ndim [[buffer(6)]],
    const constant int* reduce_shape [[buffer(7)]],
    const constant int64_t* reduce_strides [[buffer(8)]],
    const constant int& reduce_ndim [[buffer(9)]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint3 gid [[threadgroup_position_in_grid]],
    uint3 gsize [[threadgroups_per_grid]],
    uint3 tid [[thread_position_in_grid]],
    uint3 tsize [[threads_per_grid]]) {
  Op op;

  U total_val = Op::init;
  LoopedElemToLoc<NDIMS, IdxT, (NDIMS > 2)> loop(reduce_ndim);

  // Row geometry shared by both strategies below.
  const device T* row;
  int blocks = IdxT(row_size) / N_READS;
  int extra = IdxT(row_size) % N_READS;

  if ((non_row_reductions < 32 && row_size <= 8) || non_row_reductions <= 8) {
    // Tiny problem: a single thread loops over all non-row reductions and
    // owns one output by itself.
    IdxT out_idx = tid.x + tsize.x * IdxT(tid.y);
    in += elem_to_loc<IdxT>(out_idx, shape, strides, ndim);

    for (uint r = 0; r < non_row_reductions; r++) {
      row = in + loop.location();
      thread_reduce<T, U, Op, N_READS>(total_val, row, blocks, extra);
      loop.next(reduce_shape, reduce_strides);
    }

    out[out_idx] = total_val;
  } else {
    // Larger problem: a whole simdgroup owns one output. Each lane handles
    // every simd_size-th row and the partials are combined with a single
    // simd_reduce at the end.
    IdxT out_idx = gid.y + gsize.y * IdxT(gid.z);
    in += elem_to_loc<IdxT>(out_idx, shape, strides, ndim);

    loop.next(simd_lane_id, reduce_shape, reduce_strides);

    for (uint r = simd_lane_id; r < non_row_reductions; r += simd_size) {
      row = in + loop.location();
      thread_reduce<T, U, Op, N_READS>(total_val, row, blocks, extra);
      loop.next(simd_size, reduce_shape, reduce_strides);
    }

    total_val = op.simd_reduce(total_val);

    if (simd_lane_id == 0) {
      out[out_idx] = total_val;
    }
  }
}
258
+
259
/**
 * Simple contiguous row reduction: one threadgroup reduces N_WRITES
 * consecutive rows and writes N_WRITES outputs.
 */
template <
    typename T,
    typename U,
    typename Op,
    typename IdxT = int64_t,
    int N_READS = REDUCE_N_READS,
    int N_WRITES = REDUCE_N_WRITES>
[[kernel]] void row_reduce_simple(
    const device T* in [[buffer(0)]],
    device U* out [[buffer(1)]],
    const constant size_t& reduction_size [[buffer(2)]],
    const constant int64_t& out_size [[buffer(3)]],
    uint3 gid [[threadgroup_position_in_grid]],
    uint3 gsize [[threadgroups_per_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint3 lsize [[threads_per_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_per_group [[simdgroups_per_threadgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  threadgroup U shared_vals[simd_size * N_WRITES];
  U totals[N_WRITES];

  // Move to this threadgroup's first row. The last group is clamped back so
  // it still has N_WRITES full rows (some rows may be reduced twice, which
  // is harmless since both groups write the same results).
  IdxT out_idx = N_WRITES * (gid.y + gsize.y * IdxT(gid.z));
  if (out_idx + N_WRITES > out_size) {
    out_idx = out_size - N_WRITES;
  }
  in += out_idx * IdxT(reduction_size);
  out += out_idx;

  // Per-thread partial reduction across the rows.
  int blocks = IdxT(reduction_size) / (lsize.x * N_READS);
  int extra = reduction_size - blocks * (lsize.x * N_READS);
  per_thread_row_reduce<T, U, Op, N_READS, N_WRITES>(
      totals, in, reduction_size, blocks, extra, lsize.x, lid.x);

  // Combine the partials across the threadgroup.
  threadgroup_reduce<T, U, Op, N_READS, N_WRITES>(
      totals, shared_vals, lid, simd_lane_id, simd_per_group, simd_group_id);

  // Thread 0 writes the N_WRITES results.
  if (lid.x == 0) {
    for (int w = 0; w < N_WRITES; w++) {
      out[w] = totals[w];
    }
  }
}
306
+
307
/**
 * Looped row reduction: one threadgroup is responsible for one output and
 * simply loops over every non-row reduction, reducing each row in turn.
 */
template <
    typename T,
    typename U,
    typename Op,
    typename IdxT,
    int NDIMS,
    int N_READS = REDUCE_N_READS>
[[kernel]] void row_reduce_looped(
    const device T* in [[buffer(0)]],
    device U* out [[buffer(1)]],
    const constant int64_t& row_size [[buffer(2)]],
    const constant int64_t& non_row_reductions [[buffer(3)]],
    const constant int* shape [[buffer(4)]],
    const constant int64_t* strides [[buffer(5)]],
    const constant int& ndim [[buffer(6)]],
    const constant int* reduce_shape [[buffer(7)]],
    const constant int64_t* reduce_strides [[buffer(8)]],
    const constant int& reduce_ndim [[buffer(9)]],
    uint3 gid [[threadgroup_position_in_grid]],
    uint3 gsize [[threadgroups_per_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint3 lsize [[threads_per_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_per_group [[simdgroups_per_threadgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  Op op;
  threadgroup U shared_vals[simd_size];
  U total = Op::init;

  IdxT out_idx = gid.y + gsize.y * IdxT(gid.z);

  // Pre-offsetting by lid.x * N_READS bends the per_thread_row_reduce
  // interface a little; a small refactor could make this cleaner.
  in += elem_to_loc<IdxT>(out_idx, shape, strides, ndim) + lid.x * N_READS;

  LoopedElemToLoc<NDIMS, IdxT, (NDIMS > 2)> loop(reduce_ndim);
  const device T* row;
  int blocks = IdxT(row_size) / (lsize.x * N_READS);
  int extra = row_size - blocks * (lsize.x * N_READS);

  for (IdxT i = 0; i < non_row_reductions; i++) {
    row = in + loop.location();

    // Per-thread partial reduction of this row.
    U row_total;
    per_thread_row_reduce<T, U, Op, N_READS, 1>(
        &row_total, &row, blocks, extra, lsize.x, lid.x);

    // Fold this row's partial into the running total.
    total = op(total, row_total);

    loop.next(reduce_shape, reduce_strides);
  }

  // Combine the partials across the threadgroup.
  threadgroup_reduce<T, U, Op, N_READS, 1>(
      &total, shared_vals, lid, simd_lane_id, simd_per_group, simd_group_id);

  // Thread 0 writes the single result.
  if (lid.x == 0) {
    out[out_idx] = total;
  }
}