mlx-cpu 0.30.1__py3-none-manylinux_2_35_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (231)
  1. mlx/__main__.py +27 -0
  2. mlx/_reprlib_fix.py +16 -0
  3. mlx/extension.py +88 -0
  4. mlx/include/mlx/3rdparty/pocketfft.h +3581 -0
  5. mlx/include/mlx/allocator.h +73 -0
  6. mlx/include/mlx/array.h +645 -0
  7. mlx/include/mlx/backend/common/binary.h +97 -0
  8. mlx/include/mlx/backend/common/broadcasting.h +11 -0
  9. mlx/include/mlx/backend/common/buffer_cache.h +157 -0
  10. mlx/include/mlx/backend/common/compiled.h +77 -0
  11. mlx/include/mlx/backend/common/copy.h +50 -0
  12. mlx/include/mlx/backend/common/hadamard.h +109 -0
  13. mlx/include/mlx/backend/common/matmul.h +67 -0
  14. mlx/include/mlx/backend/common/reduce.h +59 -0
  15. mlx/include/mlx/backend/common/slicing.h +20 -0
  16. mlx/include/mlx/backend/common/ternary.h +85 -0
  17. mlx/include/mlx/backend/common/unary.h +29 -0
  18. mlx/include/mlx/backend/common/utils.h +205 -0
  19. mlx/include/mlx/backend/cpu/arange.h +28 -0
  20. mlx/include/mlx/backend/cpu/available.h +9 -0
  21. mlx/include/mlx/backend/cpu/binary.h +517 -0
  22. mlx/include/mlx/backend/cpu/binary_ops.h +98 -0
  23. mlx/include/mlx/backend/cpu/binary_two.h +166 -0
  24. mlx/include/mlx/backend/cpu/compiled_preamble.h +12 -0
  25. mlx/include/mlx/backend/cpu/copy.h +36 -0
  26. mlx/include/mlx/backend/cpu/encoder.h +67 -0
  27. mlx/include/mlx/backend/cpu/eval.h +12 -0
  28. mlx/include/mlx/backend/cpu/gemm.h +26 -0
  29. mlx/include/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
  30. mlx/include/mlx/backend/cpu/jit_compiler.h +20 -0
  31. mlx/include/mlx/backend/cpu/lapack.h +80 -0
  32. mlx/include/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
  33. mlx/include/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
  34. mlx/include/mlx/backend/cpu/simd/base_simd.h +295 -0
  35. mlx/include/mlx/backend/cpu/simd/math.h +193 -0
  36. mlx/include/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
  37. mlx/include/mlx/backend/cpu/simd/simd.h +4 -0
  38. mlx/include/mlx/backend/cpu/simd/type.h +11 -0
  39. mlx/include/mlx/backend/cpu/slicing.h +21 -0
  40. mlx/include/mlx/backend/cpu/ternary.h +154 -0
  41. mlx/include/mlx/backend/cpu/threefry.h +21 -0
  42. mlx/include/mlx/backend/cpu/unary.h +281 -0
  43. mlx/include/mlx/backend/cpu/unary_ops.h +180 -0
  44. mlx/include/mlx/backend/cuda/allocator.h +89 -0
  45. mlx/include/mlx/backend/cuda/conv/conv.h +126 -0
  46. mlx/include/mlx/backend/cuda/cublas_utils.h +96 -0
  47. mlx/include/mlx/backend/cuda/cuda.h +10 -0
  48. mlx/include/mlx/backend/cuda/cuda_utils.h +89 -0
  49. mlx/include/mlx/backend/cuda/cudnn_utils.h +171 -0
  50. mlx/include/mlx/backend/cuda/device/config.h +12 -0
  51. mlx/include/mlx/backend/cuda/device.h +189 -0
  52. mlx/include/mlx/backend/cuda/event.h +78 -0
  53. mlx/include/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
  54. mlx/include/mlx/backend/cuda/gemms/gemv.h +24 -0
  55. mlx/include/mlx/backend/cuda/jit_module.h +119 -0
  56. mlx/include/mlx/backend/cuda/lru_cache.h +189 -0
  57. mlx/include/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
  58. mlx/include/mlx/backend/cuda/quantized/cuda_fp4.h +83 -0
  59. mlx/include/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
  60. mlx/include/mlx/backend/cuda/quantized/quantized.h +45 -0
  61. mlx/include/mlx/backend/cuda/utils.h +46 -0
  62. mlx/include/mlx/backend/cuda/worker.h +55 -0
  63. mlx/include/mlx/backend/gpu/available.h +9 -0
  64. mlx/include/mlx/backend/gpu/copy.h +57 -0
  65. mlx/include/mlx/backend/gpu/eval.h +18 -0
  66. mlx/include/mlx/backend/gpu/slicing.h +36 -0
  67. mlx/include/mlx/backend/metal/allocator.h +79 -0
  68. mlx/include/mlx/backend/metal/binary.h +33 -0
  69. mlx/include/mlx/backend/metal/device.h +283 -0
  70. mlx/include/mlx/backend/metal/jit/includes.h +57 -0
  71. mlx/include/mlx/backend/metal/jit/indexing.h +76 -0
  72. mlx/include/mlx/backend/metal/kernels/arange.h +9 -0
  73. mlx/include/mlx/backend/metal/kernels/atomic.h +345 -0
  74. mlx/include/mlx/backend/metal/kernels/bf16.h +16 -0
  75. mlx/include/mlx/backend/metal/kernels/bf16_math.h +380 -0
  76. mlx/include/mlx/backend/metal/kernels/binary.h +199 -0
  77. mlx/include/mlx/backend/metal/kernels/binary_ops.h +326 -0
  78. mlx/include/mlx/backend/metal/kernels/binary_two.h +244 -0
  79. mlx/include/mlx/backend/metal/kernels/cexpf.h +134 -0
  80. mlx/include/mlx/backend/metal/kernels/complex.h +173 -0
  81. mlx/include/mlx/backend/metal/kernels/copy.h +276 -0
  82. mlx/include/mlx/backend/metal/kernels/defines.h +24 -0
  83. mlx/include/mlx/backend/metal/kernels/erf.h +69 -0
  84. mlx/include/mlx/backend/metal/kernels/expm1f.h +90 -0
  85. mlx/include/mlx/backend/metal/kernels/fft/radix.h +328 -0
  86. mlx/include/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
  87. mlx/include/mlx/backend/metal/kernels/fft.h +486 -0
  88. mlx/include/mlx/backend/metal/kernels/fp4.h +59 -0
  89. mlx/include/mlx/backend/metal/kernels/fp8.h +82 -0
  90. mlx/include/mlx/backend/metal/kernels/fp_quantized.h +1804 -0
  91. mlx/include/mlx/backend/metal/kernels/fp_quantized_nax.h +1059 -0
  92. mlx/include/mlx/backend/metal/kernels/gemv_masked.h +827 -0
  93. mlx/include/mlx/backend/metal/kernels/hadamard.h +182 -0
  94. mlx/include/mlx/backend/metal/kernels/indexing/gather.h +51 -0
  95. mlx/include/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
  96. mlx/include/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
  97. mlx/include/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
  98. mlx/include/mlx/backend/metal/kernels/indexing/masked_scatter.h +38 -0
  99. mlx/include/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
  100. mlx/include/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
  101. mlx/include/mlx/backend/metal/kernels/logsumexp.h +140 -0
  102. mlx/include/mlx/backend/metal/kernels/quantized.h +2502 -0
  103. mlx/include/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
  104. mlx/include/mlx/backend/metal/kernels/quantized_utils.h +90 -0
  105. mlx/include/mlx/backend/metal/kernels/reduce.h +5 -0
  106. mlx/include/mlx/backend/metal/kernels/reduce_utils.h +6 -0
  107. mlx/include/mlx/backend/metal/kernels/reduction/ops.h +275 -0
  108. mlx/include/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
  109. mlx/include/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
  110. mlx/include/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
  111. mlx/include/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
  112. mlx/include/mlx/backend/metal/kernels/scan.h +514 -0
  113. mlx/include/mlx/backend/metal/kernels/sdpa_vector.h +415 -0
  114. mlx/include/mlx/backend/metal/kernels/softmax.h +190 -0
  115. mlx/include/mlx/backend/metal/kernels/sort.h +715 -0
  116. mlx/include/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
  117. mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +476 -0
  118. mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
  119. mlx/include/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
  120. mlx/include/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
  121. mlx/include/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
  122. mlx/include/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
  123. mlx/include/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
  124. mlx/include/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
  125. mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
  126. mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
  127. mlx/include/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
  128. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
  129. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
  130. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
  131. mlx/include/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
  132. mlx/include/mlx/backend/metal/kernels/steel/defines.h +7 -0
  133. mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
  134. mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +156 -0
  135. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
  136. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +207 -0
  137. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
  138. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +132 -0
  139. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
  140. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
  141. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
  142. mlx/include/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
  143. mlx/include/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
  144. mlx/include/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
  145. mlx/include/mlx/backend/metal/kernels/steel/gemm/params.h +64 -0
  146. mlx/include/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
  147. mlx/include/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
  148. mlx/include/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
  149. mlx/include/mlx/backend/metal/kernels/steel/utils.h +42 -0
  150. mlx/include/mlx/backend/metal/kernels/ternary.h +145 -0
  151. mlx/include/mlx/backend/metal/kernels/ternary_ops.h +10 -0
  152. mlx/include/mlx/backend/metal/kernels/unary.h +63 -0
  153. mlx/include/mlx/backend/metal/kernels/unary_ops.h +454 -0
  154. mlx/include/mlx/backend/metal/kernels/utils.h +444 -0
  155. mlx/include/mlx/backend/metal/matmul.h +144 -0
  156. mlx/include/mlx/backend/metal/metal.h +22 -0
  157. mlx/include/mlx/backend/metal/reduce.h +41 -0
  158. mlx/include/mlx/backend/metal/resident.h +32 -0
  159. mlx/include/mlx/backend/metal/scan.h +17 -0
  160. mlx/include/mlx/backend/metal/ternary.h +21 -0
  161. mlx/include/mlx/backend/metal/unary.h +21 -0
  162. mlx/include/mlx/backend/metal/utils.h +84 -0
  163. mlx/include/mlx/backend/no_gpu/apple_memory.h +16 -0
  164. mlx/include/mlx/backend/no_gpu/linux_memory.h +22 -0
  165. mlx/include/mlx/compile.h +44 -0
  166. mlx/include/mlx/compile_impl.h +69 -0
  167. mlx/include/mlx/device.h +31 -0
  168. mlx/include/mlx/distributed/distributed.h +60 -0
  169. mlx/include/mlx/distributed/distributed_impl.h +59 -0
  170. mlx/include/mlx/distributed/jaccl/jaccl.h +12 -0
  171. mlx/include/mlx/distributed/mpi/mpi.h +12 -0
  172. mlx/include/mlx/distributed/mpi/mpi_declarations.h +28 -0
  173. mlx/include/mlx/distributed/nccl/nccl.h +12 -0
  174. mlx/include/mlx/distributed/ops.h +56 -0
  175. mlx/include/mlx/distributed/primitives.h +156 -0
  176. mlx/include/mlx/distributed/reduction_ops.h +38 -0
  177. mlx/include/mlx/distributed/ring/ring.h +12 -0
  178. mlx/include/mlx/distributed/utils.h +67 -0
  179. mlx/include/mlx/dtype.h +115 -0
  180. mlx/include/mlx/dtype_utils.h +119 -0
  181. mlx/include/mlx/einsum.h +22 -0
  182. mlx/include/mlx/event.h +58 -0
  183. mlx/include/mlx/export.h +136 -0
  184. mlx/include/mlx/export_impl.h +98 -0
  185. mlx/include/mlx/fast.h +102 -0
  186. mlx/include/mlx/fast_primitives.h +427 -0
  187. mlx/include/mlx/fence.h +39 -0
  188. mlx/include/mlx/fft.h +167 -0
  189. mlx/include/mlx/graph_utils.h +66 -0
  190. mlx/include/mlx/io/gguf.h +20 -0
  191. mlx/include/mlx/io/load.h +175 -0
  192. mlx/include/mlx/io.h +61 -0
  193. mlx/include/mlx/linalg.h +111 -0
  194. mlx/include/mlx/memory.h +78 -0
  195. mlx/include/mlx/mlx.h +25 -0
  196. mlx/include/mlx/ops.h +1627 -0
  197. mlx/include/mlx/primitives.h +2524 -0
  198. mlx/include/mlx/random.h +282 -0
  199. mlx/include/mlx/scheduler.h +188 -0
  200. mlx/include/mlx/small_vector.h +540 -0
  201. mlx/include/mlx/stream.h +41 -0
  202. mlx/include/mlx/threadpool.h +133 -0
  203. mlx/include/mlx/transforms.h +229 -0
  204. mlx/include/mlx/transforms_impl.h +86 -0
  205. mlx/include/mlx/types/bf16.h +187 -0
  206. mlx/include/mlx/types/complex.h +113 -0
  207. mlx/include/mlx/types/fp16.h +234 -0
  208. mlx/include/mlx/types/half_types.h +58 -0
  209. mlx/include/mlx/types/limits.h +70 -0
  210. mlx/include/mlx/utils.h +175 -0
  211. mlx/include/mlx/version.h +20 -0
  212. mlx/lib/libmlx.so +0 -0
  213. mlx/py.typed +1 -0
  214. mlx/share/cmake/MLX/FindNCCL.cmake +54 -0
  215. mlx/share/cmake/MLX/Findnvpl.cmake +3 -0
  216. mlx/share/cmake/MLX/MLXConfig.cmake +66 -0
  217. mlx/share/cmake/MLX/MLXConfigVersion.cmake +65 -0
  218. mlx/share/cmake/MLX/MLXTargets-release.cmake +19 -0
  219. mlx/share/cmake/MLX/MLXTargets.cmake +106 -0
  220. mlx/share/cmake/MLX/extension.cmake +50 -0
  221. mlx/utils.py +325 -0
  222. mlx_cpu-0.30.1.dist-info/METADATA +142 -0
  223. mlx_cpu-0.30.1.dist-info/RECORD +231 -0
  224. mlx_cpu-0.30.1.dist-info/WHEEL +5 -0
  225. mlx_cpu-0.30.1.dist-info/licenses/LICENSE +21 -0
  226. mlx_cpu-0.30.1.dist-info/sboms/auditwheel.cdx.json +1 -0
  227. mlx_cpu-0.30.1.dist-info/top_level.txt +1 -0
  228. mlx_cpu.libs/libblas-bd8a282c.so.3.10.0 +0 -0
  229. mlx_cpu.libs/libgfortran-3ec47101.so.5.0.0 +0 -0
  230. mlx_cpu.libs/liblapack-86b2c207.so.3.10.0 +0 -0
  231. mlx_cpu.libs/libquadmath-67d31475.so.0.0.0 +0 -0
@@ -0,0 +1,44 @@
1
+ // Copyright © 2024 Apple Inc.
2
+
3
+ #pragma once
4
+
5
+ ///////////////////////////////////////////////////////////////////////////////
6
+ // Attn param classes
7
+ ///////////////////////////////////////////////////////////////////////////////
8
+
9
+ namespace mlx {
10
+ namespace steel {
11
+
12
/// Launch parameters for a steel attention kernel: problem shape
/// (batch / heads / head dim / sequence lengths), block counts for the
/// tiled traversal, and per-tensor strides. The "(B, H, L, D = 1)" notes
/// on the stride arrays indicate that only the batch, head, and sequence
/// strides are stored; the innermost head-dim stride is implicitly 1.
struct AttnParams {
  int B; ///< Batch Size
  int H; ///< Heads
  int D; ///< Head Dim

  int qL; ///< Query Sequence Length
  int kL; ///< Key Sequence Length

  int gqa_factor; ///< Group Query factor
  float scale; ///< Attention scale

  int NQ; ///< Number of query blocks
  int NK; ///< Number of key/value blocks

  int NQ_aligned; ///< Number of full query blocks
  int NK_aligned; ///< Number of full key/value blocks

  int qL_rem; ///< Remainder in last query block
  int kL_rem; ///< Remainder in last key/value block
  int qL_off; ///< Offset in query sequence start

  int64_t Q_strides[3]; ///< Query strides (B, H, L, D = 1)
  int64_t K_strides[3]; ///< Key strides (B, H, L, D = 1)
  int64_t V_strides[3]; ///< Value strides (B, H, L, D = 1)
  int64_t O_strides[3]; ///< Output strides (B, H, L, D = 1)
};
38
+
39
/// Strides for an optional attention-mask tensor. As with the stride
/// arrays in AttnParams, only the three outer strides are stored; the
/// innermost (kL) stride is implicitly 1.
struct AttnMaskParams {
  int64_t M_strides[3]; ///< Mask strides (B, H, qL, kL = 1)
};
42
+
43
+ } // namespace steel
44
+ } // namespace mlx
@@ -0,0 +1,71 @@
1
+ // Copyright © 2024 Apple Inc.
2
+
3
+ #pragma once
4
+
5
+ #include "mlx/backend/metal/kernels/steel/utils.h"
6
+
7
+ ///////////////////////////////////////////////////////////////////////////////
8
+ // Transforms and Epilogues
9
+ ///////////////////////////////////////////////////////////////////////////////
10
+
11
+ namespace mlx {
12
+ namespace steel {
13
+
14
+ template <typename OutT, typename InT>
15
+ struct TransformNone {
16
+ static METAL_FUNC OutT apply(InT x) {
17
+ return static_cast<OutT>(x);
18
+ }
19
+
20
+ static METAL_FUNC OutT apply(InT x, OutT) {
21
+ return static_cast<OutT>(x);
22
+ }
23
+ };
24
+
25
+ template <typename OutT, typename InT>
26
+ struct TransformAdd {
27
+ TransformAdd(const float, const float) {}
28
+
29
+ static METAL_FUNC OutT apply(InT x) {
30
+ return static_cast<OutT>(x);
31
+ }
32
+
33
+ static METAL_FUNC OutT apply(InT x, OutT c) {
34
+ return static_cast<OutT>(x) + c;
35
+ }
36
+ };
37
+
38
+ template <typename OutT, typename InT>
39
+ struct TransformAxpby {
40
+ const float alpha;
41
+ const float beta;
42
+
43
+ TransformAxpby(const float alpha_, const float beta_)
44
+ : alpha(alpha_), beta(beta_) {}
45
+
46
+ static METAL_FUNC OutT apply(InT x) {
47
+ return static_cast<OutT>(x);
48
+ }
49
+
50
+ METAL_FUNC OutT apply(InT x, OutT c) const {
51
+ return static_cast<OutT>(x * alpha + (beta * c));
52
+ }
53
+ };
54
+
55
/// Maps a storage type T to the type used for MMA accumulation.
/// Every storage type accumulates in float by default; specialize this
/// template to override the accumulator for a particular T.
template <typename T>
struct AccumHelper {
  // Modernized from `typedef float accum_type;` — type alias is the
  // idiomatic C++11+ spelling; the exposed name is unchanged.
  using accum_type = float;
};
59
+
60
+ struct BlockSwizzle {
61
+ static METAL_FUNC int2
62
+ swizzle(uint3 tid [[threadgroup_position_in_grid]], const int swizzle_log) {
63
+ const int tid_x = (tid.x) >> swizzle_log;
64
+ const int tid_y =
65
+ ((tid.y) << swizzle_log) + ((tid.x) & ((1 << swizzle_log) - 1));
66
+ return int2(tid_x, tid_y);
67
+ }
68
+ };
69
+
70
+ } // namespace steel
71
+ } // namespace mlx
@@ -0,0 +1,13 @@
1
+ // Copyright © 2024 Apple Inc.
2
+
3
+ #pragma once
4
+
5
+ #include "mlx/backend/metal/kernels/steel/defines.h"
6
+ #include "mlx/backend/metal/kernels/steel/utils.h"
7
+
8
+ #include "mlx/backend/metal/kernels/steel/conv/loader.h"
9
+ #include "mlx/backend/metal/kernels/steel/conv/params.h"
10
+ #include "mlx/backend/metal/kernels/steel/gemm/mma.h"
11
+
12
+ using namespace metal;
13
+ using namespace mlx::steel;
@@ -0,0 +1,176 @@
1
+ // Copyright © 2024 Apple Inc.
2
+
3
+ #include <metal_stdlib>
4
+
5
+ using namespace metal;
6
+
7
/// Implicit-GEMM 2D convolution kernel.
///
/// Computes the convolution as a tiled GEMM without materializing the
/// im2col matrix: loader_a streams input patches (matrix A, read as-is),
/// loader_b streams weights (matrix B, consumed transposed), and a
/// BlockMMA accumulates the output tile, which is written to C.
///
/// Template parameters:
///   T            — element type
///   BM, BN, BK   — threadgroup tile sizes (M, N, K)
///   WM, WN       — simdgroup grid within a threadgroup tile
///   N_CHANNELS   — when 1..4, selects small-channel specialized loaders
///   SMALL_FILTER — when true, selects the small-filter input loader
template <
    typename T,
    int BM,
    int BN,
    int BK,
    int WM,
    int WN,
    int N_CHANNELS = 0,
    bool SMALL_FILTER = false>
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void
implicit_gemm_conv_2d(
    const device T* A [[buffer(0)]],
    const device T* B [[buffer(1)]],
    device T* C [[buffer(2)]],
    const constant MLXConvParams<2>* params [[buffer(3)]],
    const constant ImplicitGemmConv2DParams* gemm_params [[buffer(4)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  using namespace mlx::steel;

  (void)lid;

  // A is consumed as-is (M x K); B is consumed transposed (N x K).
  constexpr bool transpose_a = false;
  constexpr bool transpose_b = true;
  // 16 bytes' worth of padding elements added to each tile's leading
  // dimension (NOTE(review): presumably to avoid bank conflicts — confirm).
  constexpr short tgp_padding_a = 16 / sizeof(T);
  constexpr short tgp_padding_b = 16 / sizeof(T);

  // Threadgroup tile shapes (rows x padded cols) and memory footprints.
  constexpr short shape_a_cols = (transpose_a ? BM : BK) + tgp_padding_a;
  constexpr short shape_b_cols = (transpose_b ? BK : BN) + tgp_padding_b;
  constexpr short shape_a_rows = (transpose_a ? BK : BM);
  constexpr short shape_b_rows = (transpose_b ? BN : BK);
  constexpr short tgp_mem_size_a = shape_a_cols * shape_a_rows;
  constexpr short tgp_mem_size_b = shape_b_cols * shape_b_rows;

  // Threads per threadgroup (32 threads per simdgroup).
  constexpr short tgp_size = WM * WN * 32;

  // Input loader, chosen at compile time.
  using loader_a_t = typename metal::conditional_t<
      // Check for small channel specialization
      N_CHANNELS != 0 && N_CHANNELS <= 4,

      // Go to small channel specialization
      Conv2DInputBlockLoaderSmallChannels<
          T,
          BM,
          BN,
          BK,
          tgp_size,
          N_CHANNELS,
          tgp_padding_a>,

      // Else go to general loader
      typename metal::conditional_t<
          // Check if filter size is small enough
          SMALL_FILTER,

          // Go to small filter specialization
          Conv2DInputBlockLoaderSmallFilter<
              T,
              BM,
              BN,
              BK,
              tgp_size,
              tgp_padding_a>,

          // Else go to large filter generalization
          Conv2DInputBlockLoaderLargeFilter<
              T,
              BM,
              BN,
              BK,
              tgp_size,
              tgp_padding_a>>>;

  // Weight loader
  using loader_b_t = typename metal::conditional_t<
      // Check for small channel specialization
      N_CHANNELS != 0 && N_CHANNELS <= 4,

      // Go to small channel specialization
      Conv2DWeightBlockLoaderSmallChannels<
          T,
          BM,
          BN,
          BK,
          tgp_size,
          N_CHANNELS,
          tgp_padding_b>,

      // Else go to general loader
      Conv2DWeightBlockLoader<T, BM, BN, BK, tgp_size, tgp_padding_b>>;

  using mma_t = BlockMMA<
      T,
      T,
      BM,
      BN,
      BK,
      WM,
      WN,
      transpose_a,
      transpose_b,
      shape_a_cols,
      shape_b_cols>;

  // Staging buffers for the current A and B tiles.
  threadgroup T As[tgp_mem_size_a];
  threadgroup T Bs[tgp_mem_size_b];

  // Undo the tile swizzle encoded in the launch grid: the low swizzle_log
  // bits of tid.x carry tile-row bits (inverse of BlockSwizzle::swizzle).
  const int tid_y = ((tid.y) << gemm_params->swizzle_log) +
      ((tid.x) & ((1 << gemm_params->swizzle_log) - 1));
  const int tid_x = (tid.x) >> gemm_params->swizzle_log;

  // Drop threadgroups that fall outside the tile grid.
  if (gemm_params->tiles_n <= tid_x || gemm_params->tiles_m <= tid_y) {
    return;
  }

  // Top-left coordinates of this threadgroup's output tile.
  const int c_row = tid_y * BM;
  const int c_col = tid_x * BN;
  const int K = gemm_params->K;
  const int N = gemm_params->N;
  const int C_per_group = params->C / params->groups;

  // Groups: tid.z selects the convolution group.
  A += tid.z * C_per_group;
  B += tid.z * N * K;
  C += tid.z * N;

  // Advance to this tile's slice of B and C.
  B += c_col * K;
  C += c_row * (N * params->groups) + c_col;

  const int2 offsets_a(0, c_row);
  const int2 offsets_b(0, c_col);

  // Prepare threadgroup loading operations
  loader_a_t loader_a(
      A, As, offsets_a, params, gemm_params, simd_gid, simd_lid);
  loader_b_t loader_b(
      B, Bs, offsets_b, params, gemm_params, simd_gid, simd_lid);

  // Prepare threadgroup mma operation
  mma_t mma_op(simd_gid, simd_lid);

  // Main K loop: stage a tile of A and B, multiply-accumulate, advance.
  int gemm_k_iterations = gemm_params->gemm_k_iterations;
  for (int k = 0; k < gemm_k_iterations; k++) {
    threadgroup_barrier(mem_flags::mem_threadgroup);
    // Load elements into threadgroup
    loader_a.load_unsafe();
    loader_b.load_unsafe();

    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Multiply and accumulate threadgroup elements
    mma_op.mma(As, Bs);

    // Prepare for next iteration
    loader_a.next();
    loader_b.next();
  }

  // Execution-only barrier (mem_none requests no memory ordering) before
  // the epilogue.
  threadgroup_barrier(mem_flags::mem_none);

  // Store results to device memory, clipped to the valid output region.
  short tgp_bm = min(BM, gemm_params->M - c_row);
  short tgp_bn = min(BN, gemm_params->N - c_col);
  const int ldc = N * params->groups;
  mma_op.store_result_safe(C, ldc, short2(tgp_bn, tgp_bm));
}
@@ -0,0 +1,225 @@
1
+ // Copyright © 2024 Apple Inc.
2
+
3
+ #include "mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h"
4
+
5
// Metal function constant (id 200), bound when the pipeline is specialized.
// When true, the kernel below runs every K-iteration with fast unchecked
// loads; when false, it finishes with bounds-checked loads covering the
// `params->C % BK` remainder.
constant bool align_C [[function_constant(200)]];
6
+
7
/// Implicit-GEMM 2D convolution — general variant.
///
/// tid.z enumerates output "phases": it is split into (base_oh, base_ow)
/// using jump_params->f_out_jump_w, and the base_h / base_w tables supply
/// each phase's weight window (base offset and size). The align_C
/// function constant picks between an all-unchecked-load K loop and one
/// that finishes with bounds-checked loads for the params->C % BK
/// remainder. Results are written through the Epilogue transform
/// (default TransformNone: a plain cast).
template <
    typename T,
    int BM,
    int BN,
    int BK,
    int WM,
    int WN,
    typename AccumType = float,
    typename Epilogue = TransformNone<T, AccumType>>
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void
implicit_gemm_conv_2d_general(
    const device T* A [[buffer(0)]],
    const device T* B [[buffer(1)]],
    device T* C [[buffer(2)]],
    const constant MLXConvParams<2>* params [[buffer(3)]],
    const constant ImplicitGemmConv2DParams* gemm_params [[buffer(4)]],
    const constant Conv2DGeneralJumpParams* jump_params [[buffer(5)]],
    const constant Conv2DGeneralBaseInfo* base_h [[buffer(6)]],
    const constant Conv2DGeneralBaseInfo* base_w [[buffer(7)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  (void)lid;

  // A is consumed as-is (M x K); B is consumed transposed (N x K).
  constexpr bool transpose_a = false;
  constexpr bool transpose_b = true;
  // 16 bytes' worth of padding elements added to each tile's leading
  // dimension (NOTE(review): presumably to avoid bank conflicts — confirm).
  constexpr short tgp_padding_a = 16 / sizeof(T);
  constexpr short tgp_padding_b = 16 / sizeof(T);

  // Threadgroup tile shapes (rows x padded cols) and memory footprints.
  constexpr short shape_a_cols = (transpose_a ? BM : BK) + tgp_padding_a;
  constexpr short shape_b_cols = (transpose_b ? BK : BN) + tgp_padding_b;
  constexpr short shape_a_rows = (transpose_a ? BK : BM);
  constexpr short shape_b_rows = (transpose_b ? BN : BK);
  constexpr short tgp_mem_size_a = shape_a_cols * shape_a_rows;
  constexpr short tgp_mem_size_b = shape_b_cols * shape_b_rows;

  // Threads per threadgroup (32 threads per simdgroup).
  constexpr short tgp_size = WM * WN * 32;

  // Input loader
  using loader_a_t =
      Conv2DInputBlockLoaderGeneral<T, BM, BN, BK, tgp_size, tgp_padding_a>;

  // Weight loader
  using loader_b_t =
      Conv2DWeightBlockLoaderGeneral<T, BM, BN, BK, tgp_size, tgp_padding_b>;

  using mma_t = BlockMMA<
      T,
      T,
      BM,
      BN,
      BK,
      WM,
      WN,
      transpose_a,
      transpose_b,
      shape_a_cols,
      shape_b_cols>;

  // Staging buffers for the current A and B tiles.
  threadgroup T As[tgp_mem_size_a];
  threadgroup T Bs[tgp_mem_size_b];

  // Undo the tile swizzle encoded in the launch grid (inverse of
  // BlockSwizzle::swizzle).
  const int tid_y = ((tid.y) << gemm_params->swizzle_log) +
      ((tid.x) & ((1 << gemm_params->swizzle_log) - 1));
  const int tid_x = (tid.x) >> gemm_params->swizzle_log;

  // Drop threadgroups that fall outside the tile grid.
  if (gemm_params->tiles_n <= tid_x || gemm_params->tiles_m <= tid_y) {
    return;
  }

  const int tid_z = tid.z;

  // Decompose the phase index into its (height, width) components.
  const int base_oh = tid_z / jump_params->f_out_jump_w;
  const int base_ow = tid_z % jump_params->f_out_jump_w;

  // Weight window for this phase: starting offset and extent per axis.
  const int base_wh = base_h[base_oh].weight_base;
  const int base_ww = base_w[base_ow].weight_base;

  const int base_wh_size = base_h[base_oh].weight_size;
  const int base_ww_size = base_w[base_ow].weight_size;

  // Top-left coordinates of this threadgroup's output tile.
  const int c_row = tid_y * BM;
  const int c_col = tid_x * BN;
  const int K = gemm_params->K;

  // Advance to this tile's slice of B (B is read transposed: N x K).
  B += c_col * K;

  const int4 offsets_a(0, c_row, base_oh, base_ow);
  const int2 offsets_b(0, c_col);

  // Prepare threadgroup loading operations
  loader_a_t loader_a(
      A,
      As,
      offsets_a,
      params,
      jump_params,
      base_wh,
      base_ww,
      simd_gid,
      simd_lid);
  loader_b_t loader_b(
      B,
      Bs,
      offsets_b,
      params,
      jump_params,
      base_wh,
      base_ww,
      simd_gid,
      simd_lid);

  // Prepare threadgroup mma operation
  mma_t mma_op(simd_gid, simd_lid);

  if (align_C) {
    // Channels divide evenly into BK: every iteration over the weight
    // window and K blocks can use the fast unchecked loads.
    int gemm_k_iterations =
        base_wh_size * base_ww_size * gemm_params->gemm_k_iterations;

    for (int k = 0; k < gemm_k_iterations; k++) {
      threadgroup_barrier(mem_flags::mem_threadgroup);
      // Load elements into threadgroup
      loader_a.load_unsafe();
      loader_b.load_unsafe();

      threadgroup_barrier(mem_flags::mem_threadgroup);

      // Multiply and accumulate threadgroup elements
      mma_op.mma(As, Bs);

      // Prepare for next iteration
      loader_a.next();
      loader_b.next();
    }
  }

  else {
    // Run all but the final K block with unchecked loads; the loop starts
    // at 1 to reserve the last block for the remainder handling below.
    for (int k = 1; k < gemm_params->gemm_k_iterations; k++) {
      for (int j = 0; j < base_wh_size * base_ww_size; j++) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        // Load elements into threadgroup
        loader_a.load_unsafe();
        loader_b.load_unsafe();

        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Multiply and accumulate threadgroup elements
        mma_op.mma(As, Bs);

        // Prepare for next iteration
        loader_a.next();
        loader_b.next();
      }
    }
    // Final K block: only `remaining_k` channels are valid, so use the
    // bounds-checked loads.
    const short remaining_k = params->C % BK;
    for (int j = 0; j < base_wh_size * base_ww_size; j++) {
      // Load elements into threadgroup
      threadgroup_barrier(mem_flags::mem_threadgroup);
      loader_a.load_safe(remaining_k);
      loader_b.load_safe(remaining_k);
      threadgroup_barrier(mem_flags::mem_threadgroup);
      // Multiply and accumulate threadgroup elements
      mma_op.mma(As, Bs);
      // Prepare for next iteration
      loader_a.next();
      loader_b.next();
    }
  }

  // Execution-only barrier (mem_none requests no memory ordering) before
  // the epilogue.
  threadgroup_barrier(mem_flags::mem_none);

  // Store results to device memory
  {
    // Adjust for simdgroup and thread location
    int offset_m = c_row + mma_op.sm;
    int offset_n = c_col + mma_op.sn;
    C += offset_n;

    // This thread's columns start past the valid output width.
    if (offset_n >= gemm_params->N)
      return;

    // Number of valid columns left for this thread.
    short diff = gemm_params->N - offset_n;

    STEEL_PRAGMA_UNROLL
    for (int i = 0; i < mma_t::TM; i++) {
      int cm = offset_m + i * mma_t::TM_stride;

      // Map the accumulator row back to (image n, output oh, ow),
      // accounting for this phase's (base_oh, base_ow) and the jump
      // factors used to interleave phases.
      int n = cm / jump_params->adj_out_hw;
      int hw = cm % jump_params->adj_out_hw;
      int oh =
          (hw / jump_params->adj_out_w) * jump_params->f_out_jump_h + base_oh;
      int ow =
          (hw % jump_params->adj_out_w) * jump_params->f_out_jump_w + base_ow;

      if (n < params->N && oh < params->oS[0] && ow < params->oS[1]) {
        int offset_cm = n * params->out_strides[0] +
            oh * params->out_strides[1] + ow * params->out_strides[2];

        STEEL_PRAGMA_UNROLL
        for (int j = 0; j < mma_t::TN; j++) {
          // Get accumulated result and associated offset in C
          thread const auto& accum = mma_op.Ctile.frag_at(i, j);
          int offset = offset_cm + (j * mma_t::TN_stride);

          constexpr short kelems = decltype(mma_op.Ctile)::kElemsPerFrag;

          // Apply epilogue and output C
          STEEL_PRAGMA_UNROLL
          for (short k = 0; k < kelems; k++) {
            if ((j * mma_t::TN_stride + k) < diff) {
              C[offset + k] = Epilogue::apply(accum[k]);
            }
          }
        }
      }
    }
  }
}
@@ -0,0 +1,6 @@
1
+ // Copyright © 2024 Apple Inc.
2
+
3
+ #pragma once
4
+
5
+ #include "mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h"
6
+ #include "mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h"