mlx 1.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (914)
  1. checksums.yaml +7 -0
  2. data/ext/mlx/CMakeLists.txt +7 -0
  3. data/ext/mlx/Makefile +273 -0
  4. data/ext/mlx/extconf.rb +94 -0
  5. data/ext/mlx/mkmf.log +44 -0
  6. data/ext/mlx/native.bundle +0 -0
  7. data/ext/mlx/native.bundle.dSYM/Contents/Info.plist +20 -0
  8. data/ext/mlx/native.bundle.dSYM/Contents/Resources/DWARF/native.bundle +0 -0
  9. data/ext/mlx/native.bundle.dSYM/Contents/Resources/Relocations/aarch64/native.bundle.yml +5 -0
  10. data/ext/mlx/native.cpp +8027 -0
  11. data/ext/mlx/native.o +0 -0
  12. data/lib/mlx/core.rb +1678 -0
  13. data/lib/mlx/distributed_utils/common.rb +116 -0
  14. data/lib/mlx/distributed_utils/config.rb +600 -0
  15. data/lib/mlx/distributed_utils/launch.rb +490 -0
  16. data/lib/mlx/extension.rb +24 -0
  17. data/lib/mlx/nn/base.rb +388 -0
  18. data/lib/mlx/nn/init.rb +140 -0
  19. data/lib/mlx/nn/layers/activations.rb +336 -0
  20. data/lib/mlx/nn/layers/base.rb +6 -0
  21. data/lib/mlx/nn/layers/containers.rb +20 -0
  22. data/lib/mlx/nn/layers/convolution.rb +120 -0
  23. data/lib/mlx/nn/layers/convolution_transpose.rb +114 -0
  24. data/lib/mlx/nn/layers/distributed.rb +309 -0
  25. data/lib/mlx/nn/layers/dropout.rb +75 -0
  26. data/lib/mlx/nn/layers/embedding.rb +28 -0
  27. data/lib/mlx/nn/layers/linear.rb +79 -0
  28. data/lib/mlx/nn/layers/normalization.rb +216 -0
  29. data/lib/mlx/nn/layers/pooling.rb +167 -0
  30. data/lib/mlx/nn/layers/positional_encoding.rb +126 -0
  31. data/lib/mlx/nn/layers/quantized.rb +215 -0
  32. data/lib/mlx/nn/layers/recurrent.rb +135 -0
  33. data/lib/mlx/nn/layers/transformer.rb +330 -0
  34. data/lib/mlx/nn/layers/upsample.rb +97 -0
  35. data/lib/mlx/nn/layers.rb +18 -0
  36. data/lib/mlx/nn/losses.rb +251 -0
  37. data/lib/mlx/nn/utils.rb +167 -0
  38. data/lib/mlx/nn.rb +12 -0
  39. data/lib/mlx/optimizers/optimizers.rb +808 -0
  40. data/lib/mlx/optimizers/schedulers.rb +62 -0
  41. data/lib/mlx/optimizers.rb +9 -0
  42. data/lib/mlx/utils.rb +171 -0
  43. data/lib/mlx/version +1 -0
  44. data/lib/mlx/version.rb +5 -0
  45. data/lib/mlx.rb +64 -0
  46. data/mlx/.clang-format +87 -0
  47. data/mlx/.git +1 -0
  48. data/mlx/.github/ISSUE_TEMPLATE/bug_report.md +28 -0
  49. data/mlx/.github/actions/build-cuda-release/action.yml +31 -0
  50. data/mlx/.github/actions/build-docs/action.yml +38 -0
  51. data/mlx/.github/actions/build-linux/action.yml +38 -0
  52. data/mlx/.github/actions/build-linux-release/action.yml +42 -0
  53. data/mlx/.github/actions/build-macos/action.yml +80 -0
  54. data/mlx/.github/actions/build-macos-release/action.yml +36 -0
  55. data/mlx/.github/actions/build-windows/action.yml +26 -0
  56. data/mlx/.github/actions/setup-linux/action.yml +93 -0
  57. data/mlx/.github/actions/setup-macos/action.yml +24 -0
  58. data/mlx/.github/actions/setup-windows/action.yml +42 -0
  59. data/mlx/.github/actions/test-linux/action.yml +69 -0
  60. data/mlx/.github/actions/test-windows/action.yml +20 -0
  61. data/mlx/.github/dependabot.yml +6 -0
  62. data/mlx/.github/pull_request_template.md +12 -0
  63. data/mlx/.github/scripts/build-sanitizer-tests.sh +48 -0
  64. data/mlx/.github/scripts/setup+build-cpp-linux-fedora-container.sh +27 -0
  65. data/mlx/.github/workflows/build_and_test.yml +152 -0
  66. data/mlx/.github/workflows/documentation.yml +28 -0
  67. data/mlx/.github/workflows/nightly.yml +104 -0
  68. data/mlx/.github/workflows/release.yml +256 -0
  69. data/mlx/.gitignore +81 -0
  70. data/mlx/.pre-commit-config.yaml +27 -0
  71. data/mlx/ACKNOWLEDGMENTS.md +268 -0
  72. data/mlx/CITATION.cff +24 -0
  73. data/mlx/CMakeLists.txt +437 -0
  74. data/mlx/CODE_OF_CONDUCT.md +132 -0
  75. data/mlx/CONTRIBUTING.md +38 -0
  76. data/mlx/LICENSE +21 -0
  77. data/mlx/MANIFEST.in +6 -0
  78. data/mlx/README.md +121 -0
  79. data/mlx/benchmarks/cpp/CMakeLists.txt +11 -0
  80. data/mlx/benchmarks/cpp/autograd.cpp +39 -0
  81. data/mlx/benchmarks/cpp/compare_devices.cpp +27 -0
  82. data/mlx/benchmarks/cpp/irregular_strides.cpp +201 -0
  83. data/mlx/benchmarks/cpp/single_ops.cpp +288 -0
  84. data/mlx/benchmarks/cpp/time_utils.h +39 -0
  85. data/mlx/benchmarks/numpy/single_ops.py +39 -0
  86. data/mlx/benchmarks/numpy/time_utils.py +20 -0
  87. data/mlx/benchmarks/python/batch_matmul_bench.py +62 -0
  88. data/mlx/benchmarks/python/blas/bench_gemm.py +191 -0
  89. data/mlx/benchmarks/python/blas/bench_gemv.py +220 -0
  90. data/mlx/benchmarks/python/comparative/README.md +15 -0
  91. data/mlx/benchmarks/python/comparative/bench_mlx.py +519 -0
  92. data/mlx/benchmarks/python/comparative/bench_torch.py +482 -0
  93. data/mlx/benchmarks/python/comparative/compare.py +284 -0
  94. data/mlx/benchmarks/python/compile_bench.py +107 -0
  95. data/mlx/benchmarks/python/conv1d_bench.py +123 -0
  96. data/mlx/benchmarks/python/conv2d_bench_cpu.py +127 -0
  97. data/mlx/benchmarks/python/conv2d_train_bench_cpu.py +143 -0
  98. data/mlx/benchmarks/python/conv2d_transpose_bench_cpu.py +129 -0
  99. data/mlx/benchmarks/python/conv3d_bench_cpu.py +110 -0
  100. data/mlx/benchmarks/python/conv3d_train_bench_cpu.py +143 -0
  101. data/mlx/benchmarks/python/conv3d_transpose_bench_cpu.py +116 -0
  102. data/mlx/benchmarks/python/conv_bench.py +135 -0
  103. data/mlx/benchmarks/python/conv_transpose_bench.py +135 -0
  104. data/mlx/benchmarks/python/conv_unaligned_bench.py +107 -0
  105. data/mlx/benchmarks/python/distributed_bench.py +66 -0
  106. data/mlx/benchmarks/python/einsum_bench.py +84 -0
  107. data/mlx/benchmarks/python/fft_bench.py +118 -0
  108. data/mlx/benchmarks/python/gather_bench.py +52 -0
  109. data/mlx/benchmarks/python/gather_mm_bench.py +74 -0
  110. data/mlx/benchmarks/python/gather_qmm_bench.py +84 -0
  111. data/mlx/benchmarks/python/hadamard_bench.py +70 -0
  112. data/mlx/benchmarks/python/large_gemm_bench.py +119 -0
  113. data/mlx/benchmarks/python/layer_norm_bench.py +82 -0
  114. data/mlx/benchmarks/python/masked_scatter.py +212 -0
  115. data/mlx/benchmarks/python/rms_norm_bench.py +63 -0
  116. data/mlx/benchmarks/python/rope_bench.py +35 -0
  117. data/mlx/benchmarks/python/scatter_bench.py +96 -0
  118. data/mlx/benchmarks/python/sdpa_bench.py +223 -0
  119. data/mlx/benchmarks/python/sdpa_vector_bench.py +95 -0
  120. data/mlx/benchmarks/python/single_ops.py +132 -0
  121. data/mlx/benchmarks/python/synchronize_bench.py +55 -0
  122. data/mlx/benchmarks/python/time_utils.py +38 -0
  123. data/mlx/cmake/FindCUDNN.cmake +177 -0
  124. data/mlx/cmake/FindNCCL.cmake +54 -0
  125. data/mlx/cmake/Findnvpl.cmake +3 -0
  126. data/mlx/cmake/extension.cmake +50 -0
  127. data/mlx/docs/.clang-format +2 -0
  128. data/mlx/docs/.gitignore +3 -0
  129. data/mlx/docs/.nojekyll +0 -0
  130. data/mlx/docs/Doxyfile +51 -0
  131. data/mlx/docs/Makefile +18 -0
  132. data/mlx/docs/README.md +54 -0
  133. data/mlx/docs/index.html +1 -0
  134. data/mlx/docs/requirements.txt +5 -0
  135. data/mlx/docs/src/_static/distributed/m3-ultra-mesh-broken.png +0 -0
  136. data/mlx/docs/src/_static/distributed/m3-ultra-mesh.png +0 -0
  137. data/mlx/docs/src/_static/metal_debugger/capture.png +0 -0
  138. data/mlx/docs/src/_static/metal_debugger/schema.png +0 -0
  139. data/mlx/docs/src/_static/mlx_logo.png +0 -0
  140. data/mlx/docs/src/_static/mlx_logo_dark.png +0 -0
  141. data/mlx/docs/src/_static/tp_inference/all-to-sharded-linear.png +0 -0
  142. data/mlx/docs/src/_static/tp_inference/column-row-tp.png +0 -0
  143. data/mlx/docs/src/_static/tp_inference/llama-transformer.png +0 -0
  144. data/mlx/docs/src/_static/tp_inference/sharded-to-all-linear.png +0 -0
  145. data/mlx/docs/src/_templates/module-base-class.rst +33 -0
  146. data/mlx/docs/src/_templates/nn-module-template.rst +20 -0
  147. data/mlx/docs/src/_templates/optimizers-template.rst +20 -0
  148. data/mlx/docs/src/conf.py +99 -0
  149. data/mlx/docs/src/cpp/ops.rst +7 -0
  150. data/mlx/docs/src/dev/custom_metal_kernels.rst +445 -0
  151. data/mlx/docs/src/dev/extensions.rst +811 -0
  152. data/mlx/docs/src/dev/metal_debugger.rst +68 -0
  153. data/mlx/docs/src/dev/metal_logging.rst +40 -0
  154. data/mlx/docs/src/dev/mlx_in_cpp.rst +121 -0
  155. data/mlx/docs/src/examples/data_parallelism.rst +91 -0
  156. data/mlx/docs/src/examples/linear_regression.rst +77 -0
  157. data/mlx/docs/src/examples/llama-inference.rst +382 -0
  158. data/mlx/docs/src/examples/mlp.rst +134 -0
  159. data/mlx/docs/src/examples/tensor_parallelism.rst +239 -0
  160. data/mlx/docs/src/index.rst +96 -0
  161. data/mlx/docs/src/install.rst +340 -0
  162. data/mlx/docs/src/python/array.rst +65 -0
  163. data/mlx/docs/src/python/cuda.rst +9 -0
  164. data/mlx/docs/src/python/data_types.rst +78 -0
  165. data/mlx/docs/src/python/devices_and_streams.rst +21 -0
  166. data/mlx/docs/src/python/distributed.rst +22 -0
  167. data/mlx/docs/src/python/export.rst +14 -0
  168. data/mlx/docs/src/python/fast.rst +16 -0
  169. data/mlx/docs/src/python/fft.rst +24 -0
  170. data/mlx/docs/src/python/linalg.rst +27 -0
  171. data/mlx/docs/src/python/memory_management.rst +16 -0
  172. data/mlx/docs/src/python/metal.rst +12 -0
  173. data/mlx/docs/src/python/nn/distributed.rst +30 -0
  174. data/mlx/docs/src/python/nn/functions.rst +40 -0
  175. data/mlx/docs/src/python/nn/init.rst +45 -0
  176. data/mlx/docs/src/python/nn/layers.rst +74 -0
  177. data/mlx/docs/src/python/nn/losses.rst +25 -0
  178. data/mlx/docs/src/python/nn/module.rst +38 -0
  179. data/mlx/docs/src/python/nn.rst +186 -0
  180. data/mlx/docs/src/python/ops.rst +184 -0
  181. data/mlx/docs/src/python/optimizers/common_optimizers.rst +22 -0
  182. data/mlx/docs/src/python/optimizers/optimizer.rst +23 -0
  183. data/mlx/docs/src/python/optimizers/schedulers.rst +15 -0
  184. data/mlx/docs/src/python/optimizers.rst +78 -0
  185. data/mlx/docs/src/python/random.rst +48 -0
  186. data/mlx/docs/src/python/transforms.rst +22 -0
  187. data/mlx/docs/src/python/tree_utils.rst +23 -0
  188. data/mlx/docs/src/usage/compile.rst +516 -0
  189. data/mlx/docs/src/usage/distributed.rst +572 -0
  190. data/mlx/docs/src/usage/export.rst +288 -0
  191. data/mlx/docs/src/usage/function_transforms.rst +191 -0
  192. data/mlx/docs/src/usage/indexing.rst +194 -0
  193. data/mlx/docs/src/usage/launching_distributed.rst +234 -0
  194. data/mlx/docs/src/usage/lazy_evaluation.rst +144 -0
  195. data/mlx/docs/src/usage/numpy.rst +124 -0
  196. data/mlx/docs/src/usage/quick_start.rst +67 -0
  197. data/mlx/docs/src/usage/saving_and_loading.rst +81 -0
  198. data/mlx/docs/src/usage/unified_memory.rst +78 -0
  199. data/mlx/docs/src/usage/using_streams.rst +18 -0
  200. data/mlx/examples/cmake_project/CMakeLists.txt +22 -0
  201. data/mlx/examples/cmake_project/README.md +26 -0
  202. data/mlx/examples/cmake_project/example.cpp +14 -0
  203. data/mlx/examples/cpp/CMakeLists.txt +12 -0
  204. data/mlx/examples/cpp/distributed.cpp +22 -0
  205. data/mlx/examples/cpp/linear_regression.cpp +54 -0
  206. data/mlx/examples/cpp/logistic_regression.cpp +54 -0
  207. data/mlx/examples/cpp/metal_capture.cpp +31 -0
  208. data/mlx/examples/cpp/timer.h +20 -0
  209. data/mlx/examples/cpp/tutorial.cpp +99 -0
  210. data/mlx/examples/export/CMakeLists.txt +22 -0
  211. data/mlx/examples/export/README.md +49 -0
  212. data/mlx/examples/export/eval_mlp.cpp +25 -0
  213. data/mlx/examples/export/eval_mlp.py +52 -0
  214. data/mlx/examples/export/train_mlp.cpp +35 -0
  215. data/mlx/examples/export/train_mlp.py +76 -0
  216. data/mlx/examples/extensions/CMakeLists.txt +78 -0
  217. data/mlx/examples/extensions/README.md +24 -0
  218. data/mlx/examples/extensions/axpby/axpby.cpp +306 -0
  219. data/mlx/examples/extensions/axpby/axpby.h +90 -0
  220. data/mlx/examples/extensions/axpby/axpby.metal +47 -0
  221. data/mlx/examples/extensions/bindings.cpp +39 -0
  222. data/mlx/examples/extensions/mlx_sample_extensions/__init__.py +5 -0
  223. data/mlx/examples/extensions/pyproject.toml +8 -0
  224. data/mlx/examples/extensions/requirements.txt +4 -0
  225. data/mlx/examples/extensions/setup.py +18 -0
  226. data/mlx/examples/extensions/test.py +12 -0
  227. data/mlx/examples/python/linear_regression.py +46 -0
  228. data/mlx/examples/python/logistic_regression.py +49 -0
  229. data/mlx/examples/python/qqmm.py +117 -0
  230. data/mlx/mlx/3rdparty/.clang-format +2 -0
  231. data/mlx/mlx/3rdparty/pocketfft.h +3581 -0
  232. data/mlx/mlx/CMakeLists.txt +107 -0
  233. data/mlx/mlx/allocator.h +75 -0
  234. data/mlx/mlx/api.h +29 -0
  235. data/mlx/mlx/array.cpp +354 -0
  236. data/mlx/mlx/array.h +647 -0
  237. data/mlx/mlx/backend/common/CMakeLists.txt +9 -0
  238. data/mlx/mlx/backend/common/binary.h +97 -0
  239. data/mlx/mlx/backend/common/broadcasting.cpp +24 -0
  240. data/mlx/mlx/backend/common/broadcasting.h +11 -0
  241. data/mlx/mlx/backend/common/buffer_cache.h +158 -0
  242. data/mlx/mlx/backend/common/common.cpp +305 -0
  243. data/mlx/mlx/backend/common/compiled.cpp +243 -0
  244. data/mlx/mlx/backend/common/compiled.h +77 -0
  245. data/mlx/mlx/backend/common/copy.h +50 -0
  246. data/mlx/mlx/backend/common/hadamard.h +109 -0
  247. data/mlx/mlx/backend/common/load.cpp +57 -0
  248. data/mlx/mlx/backend/common/matmul.h +67 -0
  249. data/mlx/mlx/backend/common/reduce.cpp +154 -0
  250. data/mlx/mlx/backend/common/reduce.h +59 -0
  251. data/mlx/mlx/backend/common/slicing.cpp +71 -0
  252. data/mlx/mlx/backend/common/slicing.h +20 -0
  253. data/mlx/mlx/backend/common/ternary.h +85 -0
  254. data/mlx/mlx/backend/common/unary.h +29 -0
  255. data/mlx/mlx/backend/common/utils.cpp +231 -0
  256. data/mlx/mlx/backend/common/utils.h +205 -0
  257. data/mlx/mlx/backend/cpu/CMakeLists.txt +88 -0
  258. data/mlx/mlx/backend/cpu/arange.h +28 -0
  259. data/mlx/mlx/backend/cpu/arg_reduce.cpp +124 -0
  260. data/mlx/mlx/backend/cpu/binary.cpp +269 -0
  261. data/mlx/mlx/backend/cpu/binary.h +517 -0
  262. data/mlx/mlx/backend/cpu/binary_ops.h +98 -0
  263. data/mlx/mlx/backend/cpu/binary_two.h +166 -0
  264. data/mlx/mlx/backend/cpu/cholesky.cpp +85 -0
  265. data/mlx/mlx/backend/cpu/compiled.cpp +357 -0
  266. data/mlx/mlx/backend/cpu/compiled_preamble.h +12 -0
  267. data/mlx/mlx/backend/cpu/conv.cpp +1351 -0
  268. data/mlx/mlx/backend/cpu/copy.cpp +386 -0
  269. data/mlx/mlx/backend/cpu/copy.h +36 -0
  270. data/mlx/mlx/backend/cpu/device_info.cpp +113 -0
  271. data/mlx/mlx/backend/cpu/device_info.h +28 -0
  272. data/mlx/mlx/backend/cpu/distributed.cpp +103 -0
  273. data/mlx/mlx/backend/cpu/eig.cpp +281 -0
  274. data/mlx/mlx/backend/cpu/eigh.cpp +241 -0
  275. data/mlx/mlx/backend/cpu/encoder.cpp +16 -0
  276. data/mlx/mlx/backend/cpu/encoder.h +67 -0
  277. data/mlx/mlx/backend/cpu/eval.cpp +40 -0
  278. data/mlx/mlx/backend/cpu/eval.h +12 -0
  279. data/mlx/mlx/backend/cpu/fft.cpp +120 -0
  280. data/mlx/mlx/backend/cpu/gemm.h +26 -0
  281. data/mlx/mlx/backend/cpu/gemms/bnns.cpp +214 -0
  282. data/mlx/mlx/backend/cpu/gemms/cblas.cpp +134 -0
  283. data/mlx/mlx/backend/cpu/gemms/simd_bf16.cpp +45 -0
  284. data/mlx/mlx/backend/cpu/gemms/simd_fp16.cpp +45 -0
  285. data/mlx/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
  286. data/mlx/mlx/backend/cpu/hadamard.cpp +121 -0
  287. data/mlx/mlx/backend/cpu/indexing.cpp +854 -0
  288. data/mlx/mlx/backend/cpu/inverse.cpp +160 -0
  289. data/mlx/mlx/backend/cpu/jit_compiler.cpp +166 -0
  290. data/mlx/mlx/backend/cpu/jit_compiler.h +20 -0
  291. data/mlx/mlx/backend/cpu/lapack.h +80 -0
  292. data/mlx/mlx/backend/cpu/logsumexp.cpp +139 -0
  293. data/mlx/mlx/backend/cpu/luf.cpp +120 -0
  294. data/mlx/mlx/backend/cpu/make_compiled_preamble.ps1 +38 -0
  295. data/mlx/mlx/backend/cpu/make_compiled_preamble.sh +41 -0
  296. data/mlx/mlx/backend/cpu/masked_mm.cpp +608 -0
  297. data/mlx/mlx/backend/cpu/matmul.cpp +166 -0
  298. data/mlx/mlx/backend/cpu/primitives.cpp +478 -0
  299. data/mlx/mlx/backend/cpu/qrf.cpp +147 -0
  300. data/mlx/mlx/backend/cpu/quantized.cpp +1370 -0
  301. data/mlx/mlx/backend/cpu/reduce.cpp +587 -0
  302. data/mlx/mlx/backend/cpu/scan.cpp +338 -0
  303. data/mlx/mlx/backend/cpu/select.cpp +95 -0
  304. data/mlx/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
  305. data/mlx/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
  306. data/mlx/mlx/backend/cpu/simd/base_simd.h +319 -0
  307. data/mlx/mlx/backend/cpu/simd/math.h +193 -0
  308. data/mlx/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
  309. data/mlx/mlx/backend/cpu/simd/simd.h +4 -0
  310. data/mlx/mlx/backend/cpu/simd/type.h +11 -0
  311. data/mlx/mlx/backend/cpu/slicing.h +21 -0
  312. data/mlx/mlx/backend/cpu/softmax.cpp +170 -0
  313. data/mlx/mlx/backend/cpu/sort.cpp +481 -0
  314. data/mlx/mlx/backend/cpu/svd.cpp +289 -0
  315. data/mlx/mlx/backend/cpu/ternary.h +154 -0
  316. data/mlx/mlx/backend/cpu/threefry.cpp +31 -0
  317. data/mlx/mlx/backend/cpu/threefry.h +21 -0
  318. data/mlx/mlx/backend/cpu/unary.cpp +238 -0
  319. data/mlx/mlx/backend/cpu/unary.h +281 -0
  320. data/mlx/mlx/backend/cpu/unary_ops.h +175 -0
  321. data/mlx/mlx/backend/cuda/CMakeLists.txt +265 -0
  322. data/mlx/mlx/backend/cuda/allocator.cpp +451 -0
  323. data/mlx/mlx/backend/cuda/allocator.h +94 -0
  324. data/mlx/mlx/backend/cuda/arange.cu +68 -0
  325. data/mlx/mlx/backend/cuda/arg_reduce.cu +189 -0
  326. data/mlx/mlx/backend/cuda/bin2h.cmake +150 -0
  327. data/mlx/mlx/backend/cuda/binary/CMakeLists.txt +21 -0
  328. data/mlx/mlx/backend/cuda/binary/add.cu +7 -0
  329. data/mlx/mlx/backend/cuda/binary/arctan2.cu +7 -0
  330. data/mlx/mlx/backend/cuda/binary/binary.cuh +383 -0
  331. data/mlx/mlx/backend/cuda/binary/bitwise_binary.cu +27 -0
  332. data/mlx/mlx/backend/cuda/binary/divide.cu +7 -0
  333. data/mlx/mlx/backend/cuda/binary/equal.cu +15 -0
  334. data/mlx/mlx/backend/cuda/binary/greater.cu +7 -0
  335. data/mlx/mlx/backend/cuda/binary/greater_equal.cu +7 -0
  336. data/mlx/mlx/backend/cuda/binary/less.cu +7 -0
  337. data/mlx/mlx/backend/cuda/binary/less_equal.cu +7 -0
  338. data/mlx/mlx/backend/cuda/binary/log_add_exp.cu +7 -0
  339. data/mlx/mlx/backend/cuda/binary/logical_and.cu +7 -0
  340. data/mlx/mlx/backend/cuda/binary/logical_or.cu +7 -0
  341. data/mlx/mlx/backend/cuda/binary/maximum.cu +7 -0
  342. data/mlx/mlx/backend/cuda/binary/minimum.cu +7 -0
  343. data/mlx/mlx/backend/cuda/binary/multiply.cu +7 -0
  344. data/mlx/mlx/backend/cuda/binary/not_equal.cu +7 -0
  345. data/mlx/mlx/backend/cuda/binary/power.cu +7 -0
  346. data/mlx/mlx/backend/cuda/binary/remainder.cu +7 -0
  347. data/mlx/mlx/backend/cuda/binary/subtract.cu +7 -0
  348. data/mlx/mlx/backend/cuda/binary_two.cu +412 -0
  349. data/mlx/mlx/backend/cuda/compiled.cpp +357 -0
  350. data/mlx/mlx/backend/cuda/conv/conv.h +126 -0
  351. data/mlx/mlx/backend/cuda/conv/gemm_conv.cu +217 -0
  352. data/mlx/mlx/backend/cuda/conv/gemm_grouped_conv.cu +231 -0
  353. data/mlx/mlx/backend/cuda/conv.cpp +403 -0
  354. data/mlx/mlx/backend/cuda/copy/copy.cuh +55 -0
  355. data/mlx/mlx/backend/cuda/copy/copy_contiguous.cu +88 -0
  356. data/mlx/mlx/backend/cuda/copy/copy_general.cu +171 -0
  357. data/mlx/mlx/backend/cuda/copy/copy_general_dynamic.cu +118 -0
  358. data/mlx/mlx/backend/cuda/copy/copy_general_input.cu +229 -0
  359. data/mlx/mlx/backend/cuda/copy.cu +132 -0
  360. data/mlx/mlx/backend/cuda/cublas_utils.cpp +222 -0
  361. data/mlx/mlx/backend/cuda/cublas_utils.h +95 -0
  362. data/mlx/mlx/backend/cuda/cuda.h +21 -0
  363. data/mlx/mlx/backend/cuda/cuda_utils.h +90 -0
  364. data/mlx/mlx/backend/cuda/cudnn_utils.cpp +133 -0
  365. data/mlx/mlx/backend/cuda/cudnn_utils.h +187 -0
  366. data/mlx/mlx/backend/cuda/custom_kernel.cpp +379 -0
  367. data/mlx/mlx/backend/cuda/cutlass_utils.cuh +46 -0
  368. data/mlx/mlx/backend/cuda/delayload.cpp +80 -0
  369. data/mlx/mlx/backend/cuda/device/atomic_ops.cuh +63 -0
  370. data/mlx/mlx/backend/cuda/device/binary_ops.cuh +300 -0
  371. data/mlx/mlx/backend/cuda/device/cast_op.cuh +118 -0
  372. data/mlx/mlx/backend/cuda/device/complex.cuh +60 -0
  373. data/mlx/mlx/backend/cuda/device/config.h +12 -0
  374. data/mlx/mlx/backend/cuda/device/fp16_math.cuh +96 -0
  375. data/mlx/mlx/backend/cuda/device/gather.cuh +53 -0
  376. data/mlx/mlx/backend/cuda/device/gather_axis.cuh +65 -0
  377. data/mlx/mlx/backend/cuda/device/indexing.cuh +30 -0
  378. data/mlx/mlx/backend/cuda/device/scatter.cuh +68 -0
  379. data/mlx/mlx/backend/cuda/device/scatter_axis.cuh +67 -0
  380. data/mlx/mlx/backend/cuda/device/scatter_ops.cuh +44 -0
  381. data/mlx/mlx/backend/cuda/device/ternary_ops.cuh +13 -0
  382. data/mlx/mlx/backend/cuda/device/unary_ops.cuh +350 -0
  383. data/mlx/mlx/backend/cuda/device/utils.cuh +464 -0
  384. data/mlx/mlx/backend/cuda/device.cpp +522 -0
  385. data/mlx/mlx/backend/cuda/device.h +195 -0
  386. data/mlx/mlx/backend/cuda/device_info.cpp +232 -0
  387. data/mlx/mlx/backend/cuda/distributed.cu +121 -0
  388. data/mlx/mlx/backend/cuda/eval.cpp +66 -0
  389. data/mlx/mlx/backend/cuda/event.cu +415 -0
  390. data/mlx/mlx/backend/cuda/event.h +79 -0
  391. data/mlx/mlx/backend/cuda/fence.cpp +42 -0
  392. data/mlx/mlx/backend/cuda/gemms/cublas_gemm.cpp +233 -0
  393. data/mlx/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
  394. data/mlx/mlx/backend/cuda/gemms/cublas_gemm_batched_12_0.cpp +77 -0
  395. data/mlx/mlx/backend/cuda/gemms/cublas_gemm_batched_12_9.cu +329 -0
  396. data/mlx/mlx/backend/cuda/gemms/gemv.cu +327 -0
  397. data/mlx/mlx/backend/cuda/gemms/gemv.h +34 -0
  398. data/mlx/mlx/backend/cuda/gemms/grouped_gemm.h +25 -0
  399. data/mlx/mlx/backend/cuda/gemms/grouped_gemm_unaligned.cu +358 -0
  400. data/mlx/mlx/backend/cuda/indexing.cpp +434 -0
  401. data/mlx/mlx/backend/cuda/jit_module.cpp +443 -0
  402. data/mlx/mlx/backend/cuda/jit_module.h +120 -0
  403. data/mlx/mlx/backend/cuda/kernel_utils.cu +52 -0
  404. data/mlx/mlx/backend/cuda/kernel_utils.cuh +148 -0
  405. data/mlx/mlx/backend/cuda/layer_norm.cu +417 -0
  406. data/mlx/mlx/backend/cuda/load.cpp +60 -0
  407. data/mlx/mlx/backend/cuda/logsumexp.cu +161 -0
  408. data/mlx/mlx/backend/cuda/lru_cache.h +190 -0
  409. data/mlx/mlx/backend/cuda/matmul.cpp +373 -0
  410. data/mlx/mlx/backend/cuda/no_cuda.cpp +47 -0
  411. data/mlx/mlx/backend/cuda/primitives.cpp +46 -0
  412. data/mlx/mlx/backend/cuda/quantized/affine_quantize.cu +329 -0
  413. data/mlx/mlx/backend/cuda/quantized/convert_fp8.cu +19 -0
  414. data/mlx/mlx/backend/cuda/quantized/cublas_qqmm.cpp +206 -0
  415. data/mlx/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
  416. data/mlx/mlx/backend/cuda/quantized/cuda_fp4.h +100 -0
  417. data/mlx/mlx/backend/cuda/quantized/fp_quantize.cu +496 -0
  418. data/mlx/mlx/backend/cuda/quantized/mxfp8_quantize.cuh +32 -0
  419. data/mlx/mlx/backend/cuda/quantized/no_qqmm_impl.cpp +26 -0
  420. data/mlx/mlx/backend/cuda/quantized/nvfp4_quantize.cuh +334 -0
  421. data/mlx/mlx/backend/cuda/quantized/qmv.cu +304 -0
  422. data/mlx/mlx/backend/cuda/quantized/qmv.h +21 -0
  423. data/mlx/mlx/backend/cuda/quantized/qqmm.cpp +158 -0
  424. data/mlx/mlx/backend/cuda/quantized/qqmm_impl.cpp +50 -0
  425. data/mlx/mlx/backend/cuda/quantized/qqmm_impl.h +26 -0
  426. data/mlx/mlx/backend/cuda/quantized/qqmm_utils.cu +227 -0
  427. data/mlx/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
  428. data/mlx/mlx/backend/cuda/quantized/quantized.cpp +85 -0
  429. data/mlx/mlx/backend/cuda/quantized/quantized.h +53 -0
  430. data/mlx/mlx/backend/cuda/quantized/quantized_utils.cuh +88 -0
  431. data/mlx/mlx/backend/cuda/quantized/quantized_utils.h +50 -0
  432. data/mlx/mlx/backend/cuda/random.cu +202 -0
  433. data/mlx/mlx/backend/cuda/reduce/all_reduce.cu +159 -0
  434. data/mlx/mlx/backend/cuda/reduce/col_reduce.cu +510 -0
  435. data/mlx/mlx/backend/cuda/reduce/init_reduce.cu +50 -0
  436. data/mlx/mlx/backend/cuda/reduce/reduce.cuh +71 -0
  437. data/mlx/mlx/backend/cuda/reduce/reduce_ops.cuh +211 -0
  438. data/mlx/mlx/backend/cuda/reduce/reduce_utils.cuh +145 -0
  439. data/mlx/mlx/backend/cuda/reduce/row_reduce.cu +361 -0
  440. data/mlx/mlx/backend/cuda/reduce.cu +73 -0
  441. data/mlx/mlx/backend/cuda/rms_norm.cu +536 -0
  442. data/mlx/mlx/backend/cuda/rope.cu +429 -0
  443. data/mlx/mlx/backend/cuda/scaled_dot_product_attention.cpp +681 -0
  444. data/mlx/mlx/backend/cuda/scaled_dot_product_attention.cu +796 -0
  445. data/mlx/mlx/backend/cuda/scan.cu +468 -0
  446. data/mlx/mlx/backend/cuda/slicing.cpp +111 -0
  447. data/mlx/mlx/backend/cuda/softmax.cu +162 -0
  448. data/mlx/mlx/backend/cuda/sort.cu +1076 -0
  449. data/mlx/mlx/backend/cuda/steel/defines.cuh +9 -0
  450. data/mlx/mlx/backend/cuda/steel/gemm.cuh +101 -0
  451. data/mlx/mlx/backend/cuda/steel/mma.cuh +117 -0
  452. data/mlx/mlx/backend/cuda/steel/tiles.cuh +450 -0
  453. data/mlx/mlx/backend/cuda/steel/utils.cuh +89 -0
  454. data/mlx/mlx/backend/cuda/ternary.cu +271 -0
  455. data/mlx/mlx/backend/cuda/unary/CMakeLists.txt +34 -0
  456. data/mlx/mlx/backend/cuda/unary/abs.cu +7 -0
  457. data/mlx/mlx/backend/cuda/unary/arccos.cu +7 -0
  458. data/mlx/mlx/backend/cuda/unary/arccosh.cu +7 -0
  459. data/mlx/mlx/backend/cuda/unary/arcsin.cu +7 -0
  460. data/mlx/mlx/backend/cuda/unary/arcsinh.cu +7 -0
  461. data/mlx/mlx/backend/cuda/unary/arctan.cu +7 -0
  462. data/mlx/mlx/backend/cuda/unary/arctanh.cu +7 -0
  463. data/mlx/mlx/backend/cuda/unary/bitwise_invert.cu +7 -0
  464. data/mlx/mlx/backend/cuda/unary/ceil.cu +7 -0
  465. data/mlx/mlx/backend/cuda/unary/conjugate.cu +7 -0
  466. data/mlx/mlx/backend/cuda/unary/cos.cu +7 -0
  467. data/mlx/mlx/backend/cuda/unary/cosh.cu +7 -0
  468. data/mlx/mlx/backend/cuda/unary/erf.cu +7 -0
  469. data/mlx/mlx/backend/cuda/unary/erf_inv.cu +7 -0
  470. data/mlx/mlx/backend/cuda/unary/exp.cu +7 -0
  471. data/mlx/mlx/backend/cuda/unary/expm1.cu +7 -0
  472. data/mlx/mlx/backend/cuda/unary/floor.cu +7 -0
  473. data/mlx/mlx/backend/cuda/unary/imag.cu +7 -0
  474. data/mlx/mlx/backend/cuda/unary/log.cu +21 -0
  475. data/mlx/mlx/backend/cuda/unary/log1p.cu +7 -0
  476. data/mlx/mlx/backend/cuda/unary/logical_not.cu +7 -0
  477. data/mlx/mlx/backend/cuda/unary/negative.cu +7 -0
  478. data/mlx/mlx/backend/cuda/unary/real.cu +7 -0
  479. data/mlx/mlx/backend/cuda/unary/round.cu +18 -0
  480. data/mlx/mlx/backend/cuda/unary/sigmoid.cu +7 -0
  481. data/mlx/mlx/backend/cuda/unary/sign.cu +7 -0
  482. data/mlx/mlx/backend/cuda/unary/sin.cu +7 -0
  483. data/mlx/mlx/backend/cuda/unary/sinh.cu +7 -0
  484. data/mlx/mlx/backend/cuda/unary/sqrt.cu +15 -0
  485. data/mlx/mlx/backend/cuda/unary/square.cu +7 -0
  486. data/mlx/mlx/backend/cuda/unary/tan.cu +7 -0
  487. data/mlx/mlx/backend/cuda/unary/tanh.cu +7 -0
  488. data/mlx/mlx/backend/cuda/unary/unary.cuh +224 -0
  489. data/mlx/mlx/backend/cuda/utils.cpp +116 -0
  490. data/mlx/mlx/backend/cuda/utils.h +49 -0
  491. data/mlx/mlx/backend/cuda/vector_types.cuh +48 -0
  492. data/mlx/mlx/backend/cuda/worker.cpp +79 -0
  493. data/mlx/mlx/backend/cuda/worker.h +55 -0
  494. data/mlx/mlx/backend/gpu/CMakeLists.txt +5 -0
  495. data/mlx/mlx/backend/gpu/copy.cpp +89 -0
  496. data/mlx/mlx/backend/gpu/copy.h +57 -0
  497. data/mlx/mlx/backend/gpu/device_info.h +36 -0
  498. data/mlx/mlx/backend/gpu/eval.h +18 -0
  499. data/mlx/mlx/backend/gpu/primitives.cpp +307 -0
  500. data/mlx/mlx/backend/gpu/slicing.cpp +44 -0
  501. data/mlx/mlx/backend/gpu/slicing.h +36 -0
  502. data/mlx/mlx/backend/metal/CMakeLists.txt +144 -0
  503. data/mlx/mlx/backend/metal/allocator.cpp +279 -0
  504. data/mlx/mlx/backend/metal/allocator.h +79 -0
  505. data/mlx/mlx/backend/metal/binary.cpp +257 -0
  506. data/mlx/mlx/backend/metal/binary.h +33 -0
  507. data/mlx/mlx/backend/metal/compiled.cpp +471 -0
  508. data/mlx/mlx/backend/metal/conv.cpp +1118 -0
  509. data/mlx/mlx/backend/metal/copy.cpp +235 -0
  510. data/mlx/mlx/backend/metal/custom_kernel.cpp +430 -0
  511. data/mlx/mlx/backend/metal/device.cpp +816 -0
  512. data/mlx/mlx/backend/metal/device.h +289 -0
  513. data/mlx/mlx/backend/metal/device_info.cpp +58 -0
  514. data/mlx/mlx/backend/metal/distributed.cpp +38 -0
  515. data/mlx/mlx/backend/metal/eval.cpp +97 -0
  516. data/mlx/mlx/backend/metal/event.cpp +62 -0
  517. data/mlx/mlx/backend/metal/fence.cpp +162 -0
  518. data/mlx/mlx/backend/metal/fft.cpp +807 -0
  519. data/mlx/mlx/backend/metal/hadamard.cpp +198 -0
  520. data/mlx/mlx/backend/metal/indexing.cpp +727 -0
  521. data/mlx/mlx/backend/metal/jit/includes.h +58 -0
  522. data/mlx/mlx/backend/metal/jit/indexing.h +76 -0
  523. data/mlx/mlx/backend/metal/jit_kernels.cpp +1118 -0
  524. data/mlx/mlx/backend/metal/kernels/CMakeLists.txt +193 -0
  525. data/mlx/mlx/backend/metal/kernels/arange.h +9 -0
  526. data/mlx/mlx/backend/metal/kernels/arange.metal +20 -0
  527. data/mlx/mlx/backend/metal/kernels/arg_reduce.metal +182 -0
  528. data/mlx/mlx/backend/metal/kernels/atomic.h +345 -0
  529. data/mlx/mlx/backend/metal/kernels/bf16.h +16 -0
  530. data/mlx/mlx/backend/metal/kernels/bf16_math.h +380 -0
  531. data/mlx/mlx/backend/metal/kernels/binary.h +199 -0
  532. data/mlx/mlx/backend/metal/kernels/binary.metal +109 -0
  533. data/mlx/mlx/backend/metal/kernels/binary_ops.h +330 -0
  534. data/mlx/mlx/backend/metal/kernels/binary_two.h +244 -0
  535. data/mlx/mlx/backend/metal/kernels/binary_two.metal +54 -0
  536. data/mlx/mlx/backend/metal/kernels/cexpf.h +134 -0
  537. data/mlx/mlx/backend/metal/kernels/complex.h +173 -0
  538. data/mlx/mlx/backend/metal/kernels/conv.metal +701 -0
  539. data/mlx/mlx/backend/metal/kernels/copy.h +276 -0
  540. data/mlx/mlx/backend/metal/kernels/copy.metal +75 -0
  541. data/mlx/mlx/backend/metal/kernels/defines.h +24 -0
  542. data/mlx/mlx/backend/metal/kernels/erf.h +69 -0
  543. data/mlx/mlx/backend/metal/kernels/expm1f.h +90 -0
  544. data/mlx/mlx/backend/metal/kernels/fence.metal +52 -0
  545. data/mlx/mlx/backend/metal/kernels/fft/radix.h +328 -0
  546. data/mlx/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
  547. data/mlx/mlx/backend/metal/kernels/fft.h +486 -0
  548. data/mlx/mlx/backend/metal/kernels/fft.metal +67 -0
  549. data/mlx/mlx/backend/metal/kernels/fp4.h +48 -0
  550. data/mlx/mlx/backend/metal/kernels/fp8.h +80 -0
  551. data/mlx/mlx/backend/metal/kernels/fp_quantized.h +1850 -0
  552. data/mlx/mlx/backend/metal/kernels/fp_quantized.metal +153 -0
  553. data/mlx/mlx/backend/metal/kernels/fp_quantized_nax.h +1044 -0
  554. data/mlx/mlx/backend/metal/kernels/fp_quantized_nax.metal +79 -0
  555. data/mlx/mlx/backend/metal/kernels/gemv.metal +868 -0
  556. data/mlx/mlx/backend/metal/kernels/gemv_masked.h +827 -0
  557. data/mlx/mlx/backend/metal/kernels/gemv_masked.metal +76 -0
  558. data/mlx/mlx/backend/metal/kernels/hadamard.h +182 -0
  559. data/mlx/mlx/backend/metal/kernels/indexing/gather.h +51 -0
  560. data/mlx/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
  561. data/mlx/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
  562. data/mlx/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
  563. data/mlx/mlx/backend/metal/kernels/indexing/masked_scatter.h +41 -0
  564. data/mlx/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
  565. data/mlx/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
  566. data/mlx/mlx/backend/metal/kernels/layer_norm.metal +433 -0
  567. data/mlx/mlx/backend/metal/kernels/logging.h +26 -0
  568. data/mlx/mlx/backend/metal/kernels/logsumexp.h +140 -0
  569. data/mlx/mlx/backend/metal/kernels/logsumexp.metal +18 -0
  570. data/mlx/mlx/backend/metal/kernels/quantized.h +2508 -0
  571. data/mlx/mlx/backend/metal/kernels/quantized.metal +144 -0
  572. data/mlx/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
  573. data/mlx/mlx/backend/metal/kernels/quantized_nax.metal +106 -0
  574. data/mlx/mlx/backend/metal/kernels/quantized_utils.h +90 -0
  575. data/mlx/mlx/backend/metal/kernels/random.metal +103 -0
  576. data/mlx/mlx/backend/metal/kernels/reduce.h +5 -0
  577. data/mlx/mlx/backend/metal/kernels/reduce.metal +169 -0
  578. data/mlx/mlx/backend/metal/kernels/reduce_utils.h +6 -0
  579. data/mlx/mlx/backend/metal/kernels/reduction/ops.h +275 -0
  580. data/mlx/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
  581. data/mlx/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
  582. data/mlx/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
  583. data/mlx/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
  584. data/mlx/mlx/backend/metal/kernels/rms_norm.metal +391 -0
  585. data/mlx/mlx/backend/metal/kernels/rope.metal +229 -0
  586. data/mlx/mlx/backend/metal/kernels/scaled_dot_product_attention.metal +44 -0
  587. data/mlx/mlx/backend/metal/kernels/scan.h +514 -0
  588. data/mlx/mlx/backend/metal/kernels/scan.metal +109 -0
  589. data/mlx/mlx/backend/metal/kernels/sdpa_vector.h +394 -0
  590. data/mlx/mlx/backend/metal/kernels/softmax.h +190 -0
  591. data/mlx/mlx/backend/metal/kernels/softmax.metal +24 -0
  592. data/mlx/mlx/backend/metal/kernels/sort.h +719 -0
  593. data/mlx/mlx/backend/metal/kernels/sort.metal +80 -0
  594. data/mlx/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
  595. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +471 -0
  596. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal +27 -0
  597. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
  598. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.metal +28 -0
  599. data/mlx/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
  600. data/mlx/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
  601. data/mlx/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
  602. data/mlx/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
  603. data/mlx/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
  604. data/mlx/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
  605. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
  606. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.metal +56 -0
  607. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
  608. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.metal +47 -0
  609. data/mlx/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
  610. data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
  611. data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
  612. data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
  613. data/mlx/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
  614. data/mlx/mlx/backend/metal/kernels/steel/defines.h +7 -0
  615. data/mlx/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
  616. data/mlx/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +157 -0
  617. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
  618. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.metal +34 -0
  619. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +219 -0
  620. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.metal +30 -0
  621. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
  622. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.metal +59 -0
  623. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +143 -0
  624. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.metal +37 -0
  625. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
  626. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.metal +76 -0
  627. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
  628. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.metal +43 -0
  629. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
  630. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.metal +76 -0
  631. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.h +152 -0
  632. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.metal +30 -0
  633. data/mlx/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
  634. data/mlx/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
  635. data/mlx/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
  636. data/mlx/mlx/backend/metal/kernels/steel/gemm/params.h +65 -0
  637. data/mlx/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
  638. data/mlx/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
  639. data/mlx/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
  640. data/mlx/mlx/backend/metal/kernels/steel/utils.h +42 -0
  641. data/mlx/mlx/backend/metal/kernels/ternary.h +145 -0
  642. data/mlx/mlx/backend/metal/kernels/ternary.metal +48 -0
  643. data/mlx/mlx/backend/metal/kernels/ternary_ops.h +10 -0
  644. data/mlx/mlx/backend/metal/kernels/unary.h +63 -0
  645. data/mlx/mlx/backend/metal/kernels/unary.metal +115 -0
  646. data/mlx/mlx/backend/metal/kernels/unary_ops.h +454 -0
  647. data/mlx/mlx/backend/metal/kernels/utils.h +445 -0
  648. data/mlx/mlx/backend/metal/kernels.h +375 -0
  649. data/mlx/mlx/backend/metal/logsumexp.cpp +95 -0
  650. data/mlx/mlx/backend/metal/make_compiled_preamble.sh +120 -0
  651. data/mlx/mlx/backend/metal/matmul.cpp +2572 -0
  652. data/mlx/mlx/backend/metal/matmul.h +144 -0
  653. data/mlx/mlx/backend/metal/metal.cpp +50 -0
  654. data/mlx/mlx/backend/metal/metal.h +25 -0
  655. data/mlx/mlx/backend/metal/no_metal.cpp +42 -0
  656. data/mlx/mlx/backend/metal/nojit_kernels.cpp +414 -0
  657. data/mlx/mlx/backend/metal/normalization.cpp +433 -0
  658. data/mlx/mlx/backend/metal/primitives.cpp +242 -0
  659. data/mlx/mlx/backend/metal/quantized.cpp +1651 -0
  660. data/mlx/mlx/backend/metal/reduce.cpp +1038 -0
  661. data/mlx/mlx/backend/metal/reduce.h +41 -0
  662. data/mlx/mlx/backend/metal/resident.cpp +100 -0
  663. data/mlx/mlx/backend/metal/resident.h +32 -0
  664. data/mlx/mlx/backend/metal/rope.cpp +165 -0
  665. data/mlx/mlx/backend/metal/scaled_dot_product_attention.cpp +798 -0
  666. data/mlx/mlx/backend/metal/scan.cpp +145 -0
  667. data/mlx/mlx/backend/metal/scan.h +17 -0
  668. data/mlx/mlx/backend/metal/slicing.cpp +99 -0
  669. data/mlx/mlx/backend/metal/softmax.cpp +87 -0
  670. data/mlx/mlx/backend/metal/sort.cpp +368 -0
  671. data/mlx/mlx/backend/metal/ternary.cpp +160 -0
  672. data/mlx/mlx/backend/metal/ternary.h +21 -0
  673. data/mlx/mlx/backend/metal/unary.cpp +161 -0
  674. data/mlx/mlx/backend/metal/unary.h +21 -0
  675. data/mlx/mlx/backend/metal/utils.cpp +77 -0
  676. data/mlx/mlx/backend/metal/utils.h +99 -0
  677. data/mlx/mlx/backend/no_cpu/CMakeLists.txt +7 -0
  678. data/mlx/mlx/backend/no_cpu/compiled.cpp +24 -0
  679. data/mlx/mlx/backend/no_cpu/device_info.cpp +22 -0
  680. data/mlx/mlx/backend/no_cpu/primitives.cpp +146 -0
  681. data/mlx/mlx/backend/no_gpu/CMakeLists.txt +8 -0
  682. data/mlx/mlx/backend/no_gpu/allocator.cpp +134 -0
  683. data/mlx/mlx/backend/no_gpu/apple_memory.h +16 -0
  684. data/mlx/mlx/backend/no_gpu/device_info.cpp +22 -0
  685. data/mlx/mlx/backend/no_gpu/eval.cpp +24 -0
  686. data/mlx/mlx/backend/no_gpu/event.cpp +53 -0
  687. data/mlx/mlx/backend/no_gpu/fence.cpp +54 -0
  688. data/mlx/mlx/backend/no_gpu/linux_memory.h +22 -0
  689. data/mlx/mlx/backend/no_gpu/primitives.cpp +185 -0
  690. data/mlx/mlx/compile.cpp +1243 -0
  691. data/mlx/mlx/compile.h +45 -0
  692. data/mlx/mlx/compile_impl.h +70 -0
  693. data/mlx/mlx/device.cpp +72 -0
  694. data/mlx/mlx/device.h +56 -0
  695. data/mlx/mlx/distributed/CMakeLists.txt +14 -0
  696. data/mlx/mlx/distributed/distributed.cpp +197 -0
  697. data/mlx/mlx/distributed/distributed.h +61 -0
  698. data/mlx/mlx/distributed/distributed_impl.h +59 -0
  699. data/mlx/mlx/distributed/jaccl/CMakeLists.txt +12 -0
  700. data/mlx/mlx/distributed/jaccl/jaccl.cpp +178 -0
  701. data/mlx/mlx/distributed/jaccl/jaccl.h +12 -0
  702. data/mlx/mlx/distributed/jaccl/mesh.cpp +451 -0
  703. data/mlx/mlx/distributed/jaccl/mesh.h +122 -0
  704. data/mlx/mlx/distributed/jaccl/no_jaccl.cpp +20 -0
  705. data/mlx/mlx/distributed/jaccl/ring.cpp +692 -0
  706. data/mlx/mlx/distributed/jaccl/ring.h +178 -0
  707. data/mlx/mlx/distributed/jaccl/utils.cpp +329 -0
  708. data/mlx/mlx/distributed/jaccl/utils.h +342 -0
  709. data/mlx/mlx/distributed/mpi/CMakeLists.txt +5 -0
  710. data/mlx/mlx/distributed/mpi/mpi.cpp +501 -0
  711. data/mlx/mlx/distributed/mpi/mpi.h +12 -0
  712. data/mlx/mlx/distributed/mpi/mpi_declarations.h +28 -0
  713. data/mlx/mlx/distributed/mpi/no_mpi.cpp +20 -0
  714. data/mlx/mlx/distributed/nccl/CMakeLists.txt +26 -0
  715. data/mlx/mlx/distributed/nccl/nccl.cpp +443 -0
  716. data/mlx/mlx/distributed/nccl/nccl.h +12 -0
  717. data/mlx/mlx/distributed/nccl/nccl_stub/CMakeLists.txt +1 -0
  718. data/mlx/mlx/distributed/nccl/nccl_stub/nccl_stubs.cpp +54 -0
  719. data/mlx/mlx/distributed/nccl/no_nccl.cpp +20 -0
  720. data/mlx/mlx/distributed/ops.cpp +186 -0
  721. data/mlx/mlx/distributed/ops.h +57 -0
  722. data/mlx/mlx/distributed/primitives.cpp +95 -0
  723. data/mlx/mlx/distributed/primitives.h +156 -0
  724. data/mlx/mlx/distributed/reduction_ops.h +38 -0
  725. data/mlx/mlx/distributed/ring/CMakeLists.txt +5 -0
  726. data/mlx/mlx/distributed/ring/no_ring.cpp +20 -0
  727. data/mlx/mlx/distributed/ring/ring.cpp +870 -0
  728. data/mlx/mlx/distributed/ring/ring.h +12 -0
  729. data/mlx/mlx/distributed/utils.cpp +206 -0
  730. data/mlx/mlx/distributed/utils.h +67 -0
  731. data/mlx/mlx/dtype.cpp +197 -0
  732. data/mlx/mlx/dtype.h +116 -0
  733. data/mlx/mlx/dtype_utils.cpp +42 -0
  734. data/mlx/mlx/dtype_utils.h +119 -0
  735. data/mlx/mlx/einsum.cpp +941 -0
  736. data/mlx/mlx/einsum.h +23 -0
  737. data/mlx/mlx/event.h +58 -0
  738. data/mlx/mlx/export.cpp +1130 -0
  739. data/mlx/mlx/export.h +137 -0
  740. data/mlx/mlx/export_impl.h +99 -0
  741. data/mlx/mlx/fast.cpp +941 -0
  742. data/mlx/mlx/fast.h +103 -0
  743. data/mlx/mlx/fast_primitives.h +427 -0
  744. data/mlx/mlx/fence.h +39 -0
  745. data/mlx/mlx/fft.cpp +262 -0
  746. data/mlx/mlx/fft.h +159 -0
  747. data/mlx/mlx/graph_utils.cpp +175 -0
  748. data/mlx/mlx/graph_utils.h +67 -0
  749. data/mlx/mlx/io/CMakeLists.txt +25 -0
  750. data/mlx/mlx/io/gguf.cpp +470 -0
  751. data/mlx/mlx/io/gguf.h +20 -0
  752. data/mlx/mlx/io/gguf_quants.cpp +164 -0
  753. data/mlx/mlx/io/load.cpp +397 -0
  754. data/mlx/mlx/io/load.h +175 -0
  755. data/mlx/mlx/io/no_gguf.cpp +20 -0
  756. data/mlx/mlx/io/no_safetensors.cpp +37 -0
  757. data/mlx/mlx/io/safetensors.cpp +234 -0
  758. data/mlx/mlx/io.h +61 -0
  759. data/mlx/mlx/linalg.cpp +708 -0
  760. data/mlx/mlx/linalg.h +115 -0
  761. data/mlx/mlx/memory.h +80 -0
  762. data/mlx/mlx/mlx.h +25 -0
  763. data/mlx/mlx/ops.cpp +6094 -0
  764. data/mlx/mlx/ops.h +1610 -0
  765. data/mlx/mlx/primitives.cpp +5850 -0
  766. data/mlx/mlx/primitives.h +2525 -0
  767. data/mlx/mlx/random.cpp +492 -0
  768. data/mlx/mlx/random.h +283 -0
  769. data/mlx/mlx/scheduler.cpp +73 -0
  770. data/mlx/mlx/scheduler.h +189 -0
  771. data/mlx/mlx/small_vector.h +540 -0
  772. data/mlx/mlx/stream.h +42 -0
  773. data/mlx/mlx/threadpool.h +133 -0
  774. data/mlx/mlx/transforms.cpp +1065 -0
  775. data/mlx/mlx/transforms.h +231 -0
  776. data/mlx/mlx/transforms_impl.h +88 -0
  777. data/mlx/mlx/types/bf16.h +187 -0
  778. data/mlx/mlx/types/complex.h +113 -0
  779. data/mlx/mlx/types/fp16.h +234 -0
  780. data/mlx/mlx/types/half_types.h +58 -0
  781. data/mlx/mlx/types/limits.h +70 -0
  782. data/mlx/mlx/utils.cpp +302 -0
  783. data/mlx/mlx/utils.h +174 -0
  784. data/mlx/mlx/version.cpp +11 -0
  785. data/mlx/mlx/version.h +22 -0
  786. data/mlx/mlx.pc.in +52 -0
  787. data/mlx/pyproject.toml +7 -0
  788. data/mlx/python/mlx/__main__.py +27 -0
  789. data/mlx/python/mlx/_distributed_utils/common.py +135 -0
  790. data/mlx/python/mlx/_distributed_utils/config.py +631 -0
  791. data/mlx/python/mlx/_distributed_utils/launch.py +570 -0
  792. data/mlx/python/mlx/_reprlib_fix.py +16 -0
  793. data/mlx/python/mlx/_stub_patterns.txt +36 -0
  794. data/mlx/python/mlx/extension.py +88 -0
  795. data/mlx/python/mlx/nn/__init__.py +5 -0
  796. data/mlx/python/mlx/nn/init.py +441 -0
  797. data/mlx/python/mlx/nn/layers/__init__.py +105 -0
  798. data/mlx/python/mlx/nn/layers/activations.py +661 -0
  799. data/mlx/python/mlx/nn/layers/base.py +675 -0
  800. data/mlx/python/mlx/nn/layers/containers.py +24 -0
  801. data/mlx/python/mlx/nn/layers/convolution.py +232 -0
  802. data/mlx/python/mlx/nn/layers/convolution_transpose.py +242 -0
  803. data/mlx/python/mlx/nn/layers/distributed.py +601 -0
  804. data/mlx/python/mlx/nn/layers/dropout.py +137 -0
  805. data/mlx/python/mlx/nn/layers/embedding.py +53 -0
  806. data/mlx/python/mlx/nn/layers/linear.py +180 -0
  807. data/mlx/python/mlx/nn/layers/normalization.py +363 -0
  808. data/mlx/python/mlx/nn/layers/pooling.py +398 -0
  809. data/mlx/python/mlx/nn/layers/positional_encoding.py +162 -0
  810. data/mlx/python/mlx/nn/layers/quantized.py +426 -0
  811. data/mlx/python/mlx/nn/layers/recurrent.py +289 -0
  812. data/mlx/python/mlx/nn/layers/transformer.py +354 -0
  813. data/mlx/python/mlx/nn/layers/upsample.py +277 -0
  814. data/mlx/python/mlx/nn/losses.py +610 -0
  815. data/mlx/python/mlx/nn/utils.py +165 -0
  816. data/mlx/python/mlx/optimizers/__init__.py +4 -0
  817. data/mlx/python/mlx/optimizers/optimizers.py +976 -0
  818. data/mlx/python/mlx/optimizers/schedulers.py +158 -0
  819. data/mlx/python/mlx/py.typed +1 -0
  820. data/mlx/python/mlx/utils.py +325 -0
  821. data/mlx/python/src/CMakeLists.txt +96 -0
  822. data/mlx/python/src/array.cpp +1525 -0
  823. data/mlx/python/src/buffer.h +124 -0
  824. data/mlx/python/src/constants.cpp +15 -0
  825. data/mlx/python/src/convert.cpp +504 -0
  826. data/mlx/python/src/convert.h +50 -0
  827. data/mlx/python/src/cuda.cpp +19 -0
  828. data/mlx/python/src/device.cpp +98 -0
  829. data/mlx/python/src/distributed.cpp +352 -0
  830. data/mlx/python/src/export.cpp +356 -0
  831. data/mlx/python/src/fast.cpp +627 -0
  832. data/mlx/python/src/fft.cpp +514 -0
  833. data/mlx/python/src/indexing.cpp +1016 -0
  834. data/mlx/python/src/indexing.h +41 -0
  835. data/mlx/python/src/linalg.cpp +663 -0
  836. data/mlx/python/src/load.cpp +531 -0
  837. data/mlx/python/src/load.h +51 -0
  838. data/mlx/python/src/memory.cpp +125 -0
  839. data/mlx/python/src/metal.cpp +98 -0
  840. data/mlx/python/src/mlx.cpp +51 -0
  841. data/mlx/python/src/mlx_func.cpp +116 -0
  842. data/mlx/python/src/mlx_func.h +31 -0
  843. data/mlx/python/src/ops.cpp +5545 -0
  844. data/mlx/python/src/random.cpp +516 -0
  845. data/mlx/python/src/small_vector.h +76 -0
  846. data/mlx/python/src/stream.cpp +147 -0
  847. data/mlx/python/src/transforms.cpp +1542 -0
  848. data/mlx/python/src/trees.cpp +311 -0
  849. data/mlx/python/src/trees.h +62 -0
  850. data/mlx/python/src/utils.cpp +98 -0
  851. data/mlx/python/src/utils.h +78 -0
  852. data/mlx/python/tests/__main__.py +5 -0
  853. data/mlx/python/tests/cuda_skip.py +62 -0
  854. data/mlx/python/tests/mlx_distributed_tests.py +314 -0
  855. data/mlx/python/tests/mlx_tests.py +116 -0
  856. data/mlx/python/tests/mpi_test_distributed.py +142 -0
  857. data/mlx/python/tests/nccl_test_distributed.py +52 -0
  858. data/mlx/python/tests/ring_test_distributed.py +131 -0
  859. data/mlx/python/tests/test_array.py +2139 -0
  860. data/mlx/python/tests/test_autograd.py +880 -0
  861. data/mlx/python/tests/test_bf16.py +196 -0
  862. data/mlx/python/tests/test_blas.py +1429 -0
  863. data/mlx/python/tests/test_compile.py +1277 -0
  864. data/mlx/python/tests/test_constants.py +41 -0
  865. data/mlx/python/tests/test_conv.py +1198 -0
  866. data/mlx/python/tests/test_conv_transpose.py +810 -0
  867. data/mlx/python/tests/test_device.py +150 -0
  868. data/mlx/python/tests/test_double.py +306 -0
  869. data/mlx/python/tests/test_einsum.py +363 -0
  870. data/mlx/python/tests/test_eval.py +200 -0
  871. data/mlx/python/tests/test_export_import.py +614 -0
  872. data/mlx/python/tests/test_fast.py +923 -0
  873. data/mlx/python/tests/test_fast_sdpa.py +647 -0
  874. data/mlx/python/tests/test_fft.py +323 -0
  875. data/mlx/python/tests/test_graph.py +37 -0
  876. data/mlx/python/tests/test_init.py +139 -0
  877. data/mlx/python/tests/test_linalg.py +621 -0
  878. data/mlx/python/tests/test_load.py +447 -0
  879. data/mlx/python/tests/test_losses.py +427 -0
  880. data/mlx/python/tests/test_memory.py +77 -0
  881. data/mlx/python/tests/test_nn.py +1986 -0
  882. data/mlx/python/tests/test_ops.py +3261 -0
  883. data/mlx/python/tests/test_optimizers.py +584 -0
  884. data/mlx/python/tests/test_quantized.py +1160 -0
  885. data/mlx/python/tests/test_random.py +392 -0
  886. data/mlx/python/tests/test_reduce.py +223 -0
  887. data/mlx/python/tests/test_tree.py +96 -0
  888. data/mlx/python/tests/test_upsample.py +100 -0
  889. data/mlx/python/tests/test_vmap.py +860 -0
  890. data/mlx/setup.py +315 -0
  891. data/mlx/tests/CMakeLists.txt +44 -0
  892. data/mlx/tests/allocator_tests.cpp +41 -0
  893. data/mlx/tests/arg_reduce_tests.cpp +204 -0
  894. data/mlx/tests/array_tests.cpp +663 -0
  895. data/mlx/tests/autograd_tests.cpp +1399 -0
  896. data/mlx/tests/blas_tests.cpp +110 -0
  897. data/mlx/tests/compile_tests.cpp +818 -0
  898. data/mlx/tests/creations_tests.cpp +239 -0
  899. data/mlx/tests/custom_vjp_tests.cpp +55 -0
  900. data/mlx/tests/device_tests.cpp +35 -0
  901. data/mlx/tests/einsum_tests.cpp +85 -0
  902. data/mlx/tests/eval_tests.cpp +93 -0
  903. data/mlx/tests/export_import_tests.cpp +164 -0
  904. data/mlx/tests/fft_tests.cpp +366 -0
  905. data/mlx/tests/gpu_tests.cpp +523 -0
  906. data/mlx/tests/linalg_tests.cpp +639 -0
  907. data/mlx/tests/load_tests.cpp +270 -0
  908. data/mlx/tests/ops_tests.cpp +4159 -0
  909. data/mlx/tests/random_tests.cpp +716 -0
  910. data/mlx/tests/scheduler_tests.cpp +121 -0
  911. data/mlx/tests/tests.cpp +26 -0
  912. data/mlx/tests/utils_tests.cpp +67 -0
  913. data/mlx/tests/vmap_tests.cpp +547 -0
  914. metadata +958 -0
data/mlx/mlx/backend/metal/matmul.cpp
@@ -0,0 +1,2572 @@
1
+ // Copyright © 2023-2024 Apple Inc.
2
+
3
+ #include <algorithm>
4
+ #include <cassert>
5
+ #include <numeric>
6
+ #include <sstream>
7
+
8
+ #include "mlx/backend/common/broadcasting.h"
9
+ #include "mlx/backend/common/matmul.h"
10
+ #include "mlx/backend/gpu/copy.h"
11
+ #include "mlx/backend/metal/binary.h"
12
+ #include "mlx/backend/metal/device.h"
13
+ #include "mlx/backend/metal/kernels.h"
14
+ #include "mlx/backend/metal/kernels/defines.h"
15
+ #include "mlx/backend/metal/kernels/steel/gemm/params.h"
16
+ #include "mlx/backend/metal/matmul.h"
17
+ #include "mlx/backend/metal/utils.h"
18
+ #include "mlx/primitives.h"
19
+ #include "mlx/utils.h"
20
+
21
+ namespace mlx::core {
22
+
23
+ namespace {
24
+
25
+ std::tuple<bool, int64_t, array> check_transpose(
26
+ std::vector<array>& copies,
27
+ const Stream& s,
28
+ const array& arr,
29
+ bool is_vector) {
30
+ auto stx = arr.strides()[arr.ndim() - 2];
31
+ auto sty = arr.strides()[arr.ndim() - 1];
32
+ if (sty == 1 && (!is_vector || stx == arr.shape(-1))) {
33
+ return std::make_tuple(false, stx, arr);
34
+ } else if (stx == 1 && (!is_vector || sty == arr.shape(-2))) {
35
+ return std::make_tuple(true, sty, arr);
36
+ } else {
37
+ array arr_copy = contiguous_copy_gpu(arr, s);
38
+ copies.push_back(arr_copy);
39
+ return std::make_tuple(false, arr.shape(-1), arr_copy);
40
+ }
41
+ };
42
+
43
+ inline array
44
+ ensure_row_contiguous(const array& x, metal::Device& d, const Stream& s) {
45
+ if (!x.flags().row_contiguous) {
46
+ array x_copy = contiguous_copy_gpu(x, s);
47
+ d.add_temporary(x_copy, s.index);
48
+ return x_copy;
49
+ } else {
50
+ return x;
51
+ }
52
+ }
53
+
54
+ inline std::tuple<bool, int64_t, array>
55
+ ensure_batch_contiguous(const array& x, metal::Device& d, const Stream& s) {
56
+ if (x.flags().row_contiguous) {
57
+ return std::make_tuple(false, x.strides()[x.ndim() - 2], x);
58
+ }
59
+
60
+ bool rc = true;
61
+ for (int i = 0; i < x.ndim() - 3; i++) {
62
+ rc &= x.strides()[i + 1] * x.shape(i) == x.strides()[i];
63
+ }
64
+ if (rc) {
65
+ auto stx = x.strides()[x.ndim() - 2];
66
+ auto sty = x.strides()[x.ndim() - 1];
67
+ auto K = x.shape(-2);
68
+ auto N = x.shape(-1);
69
+ if (sty == 1 && (N != 1 || stx == N)) {
70
+ return std::make_tuple(false, stx, x);
71
+ }
72
+ if (stx == 1 && (N != 1 || sty == K)) {
73
+ return std::make_tuple(true, sty, x);
74
+ }
75
+ }
76
+
77
+ array x_copy = contiguous_copy_gpu(x, s);
78
+ d.add_temporary(x_copy, s.index);
79
+ return std::make_tuple(false, x_copy.strides()[x_copy.ndim() - 2], x_copy);
80
+ }
81
+
82
+ } // namespace
83
+
84
+ ///////////////////////////////////////////////////////////////////////////////
85
+ // Steel matmul fallback
86
+ ///////////////////////////////////////////////////////////////////////////////
87
+
88
+ #define GEMM_TPARAM_MACRO(devc) \
89
+ if (devc == 'g' || devc == 'p') { /* Small device */ \
90
+ if (out.dtype() == complex64) { \
91
+ bm = 64; \
92
+ bn = 32; \
93
+ bk = 8; \
94
+ wm = 4; \
95
+ wn = 1; \
96
+ } else if (!transpose_a && transpose_b) { /* nt */ \
97
+ bm = 64; \
98
+ bn = 32; \
99
+ bk = 32; \
100
+ wm = 2; \
101
+ wn = 2; \
102
+ } else if (out.dtype() != float32) { /* half and bfloat */ \
103
+ bm = 64; \
104
+ bn = 64; \
105
+ bk = 16; \
106
+ wm = 1; \
107
+ wn = 2; \
108
+ } \
109
+ } else if (devc == 'd') { /* Large device */ \
110
+ if ((size_t)batch_size_out * M * N >= 1ul << 20) { /* large matmul */ \
111
+ if (out.dtype() != float32) { /* half and bfloat */ \
112
+ if (2 * std::max(M, N) > K) { /* Reasonable K */ \
113
+ bm = 64; \
114
+ bn = 64; \
115
+ bk = 16; \
116
+ wm = 1; \
117
+ wn = 2; \
118
+ } else if (!transpose_a && transpose_b) { /* nt with large k */ \
119
+ bm = 64; \
120
+ bn = 32; \
121
+ bk = 32; \
122
+ wm = 2; \
123
+ wn = 2; \
124
+ } else { /* nn with large K */ \
125
+ bm = 32; \
126
+ bn = 64; \
127
+ bk = 16; \
128
+ wm = 1; \
129
+ wn = 2; \
130
+ } \
131
+ } /* float takes default */ \
132
+ } else { /* smaller matmul */ \
133
+ if (out.dtype() != float32) { /* half and bfloat */ \
134
+ if (!transpose_a && transpose_b) { /* nt */ \
135
+ bm = 64; \
136
+ bn = 32; \
137
+ bk = 32; \
138
+ wm = 2; \
139
+ wn = 2; \
140
+ } else { /* nn */ \
141
+ bm = 64; \
142
+ bn = 64; \
143
+ bk = 16; \
144
+ wm = 1; \
145
+ wn = 2; \
146
+ } \
147
+ } else { /* floats */ \
148
+ if (!transpose_a && transpose_b) { /* nt */ \
149
+ bm = 32; \
150
+ bn = 64; \
151
+ bk = 16; \
152
+ wm = 1; \
153
+ wn = 2; \
154
+ } else { /* nn */ \
155
+ bm = 64; \
156
+ bn = 32; \
157
+ bk = 32; \
158
+ wm = 2; \
159
+ wn = 2; \
160
+ } \
161
+ } \
162
+ } \
163
+ } else { /* Medium device */ \
164
+ bm = 64; \
165
+ bn = 64; \
166
+ bk = 16; \
167
+ wm = 2; \
168
+ wn = 2; \
169
+ }
170
+
171
+ ///////////////////////////////////////////////////////////////////////////////
172
+ // Regular steel matmul dispatch
173
+ ///////////////////////////////////////////////////////////////////////////////
174
+
175
+ template <bool CHECK_AB>
176
+ void steel_matmul_regular_axpby_nax(
177
+ const Stream& s,
178
+ metal::Device& d,
179
+ const array& a,
180
+ const array& b,
181
+ const array& c,
182
+ array& out,
183
+ int M,
184
+ int N,
185
+ int K,
186
+ int batch_size_out,
187
+ int lda,
188
+ int ldb,
189
+ int ldd,
190
+ bool transpose_a,
191
+ bool transpose_b,
192
+ std::vector<array>& copies,
193
+ Shape batch_shape,
194
+ Strides batch_strides,
195
+ int64_t A_batch_stride,
196
+ int64_t B_batch_stride,
197
+ int64_t matrix_stride_out,
198
+ int64_t C_batch_stride /* = 0*/,
199
+ float alpha /* = 1.0f */,
200
+ float beta /* = 0.0f */) {
201
+ using namespace mlx::steel;
202
+
203
+ // Determine dispatch kernel
204
+ int bm = 128, bn = 128, bk = 512;
205
+ int wm = 4, wn = 4;
206
+
207
+ // Prepare kernel name
208
+ std::ostringstream kname;
209
+
210
+ // clang-format off
211
+ kname << "steel_gemm_fused_nax_"
212
+ << (transpose_a ? 't' : 'n')
213
+ << (transpose_b ? 't' : 'n')
214
+ << "_" << type_to_name(a)
215
+ << "_" << type_to_name(out)
216
+ << "_bm" << bm << "_bn" << bn << "_bk" << bk
217
+ << "_wm" << wm << "_wn" << wn; // clang-format on
218
+
219
+ std::string base_name = kname.str();
220
+
221
+ const bool has_batch = (batch_shape.size() > 1);
222
+ const bool use_out_source = CHECK_AB && (alpha != 0.0f || beta != 1.0f);
223
+ const bool do_axpby = use_out_source && (alpha != 1.0f || beta != 1.0f);
224
+ const bool align_M = (M % bm) == 0;
225
+ const bool align_N = (N % bn) == 0;
226
+ const bool align_K = (K % bk) == 0;
227
+
228
+ metal::MTLFCList func_consts = {
229
+ {&has_batch, MTL::DataType::DataTypeBool, 10},
230
+ {&use_out_source, MTL::DataType::DataTypeBool, 100},
231
+ {&do_axpby, MTL::DataType::DataTypeBool, 110},
232
+ {&align_M, MTL::DataType::DataTypeBool, 200},
233
+ {&align_N, MTL::DataType::DataTypeBool, 201},
234
+ {&align_K, MTL::DataType::DataTypeBool, 202},
235
+ };
236
+
237
+ // clang-format off
238
+ kname << "_has_batch_" << (has_batch ? 't' : 'n')
239
+ << "_use_out_source_" << (use_out_source ? 't' : 'n')
240
+ << "_do_axpby_" << (do_axpby ? 't' : 'n')
241
+ << "_align_M_" << (align_M ? 't' : 'n')
242
+ << "_align_N_" << (align_N ? 't' : 'n')
243
+ << "_align_K_" << (align_K ? 't' : 'n'); // clang-format on
244
+
245
+ std::string hash_name = kname.str();
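+ // Note (editorial, hedged): base_name identifies the kernel template, while
+ // hash_name also encodes the function-constant values appended above, presumably
+ // so each specialization is looked up and cached separately.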
246
+
247
+ // Encode and dispatch kernel
248
+ auto& compute_encoder = d.get_command_encoder(s.index);
249
+ auto kernel = get_steel_gemm_fused_nax_kernel(
250
+ /* metal::Device& d = */ d,
251
+ /* const std::string& kernel_name = */ base_name,
252
+ /* const std::string& hash_name = */ hash_name,
253
+ /* const metal::MTLFCList& func_consts = */ func_consts,
254
+ /* const array& out = */ out,
255
+ /* bool transpose_a = */ transpose_a,
256
+ /* bool transpose_b = */ transpose_b,
257
+ /* int bm = */ bm,
258
+ /* int bn = */ bn,
259
+ /* int bk = */ bk,
260
+ /* int wm = */ wm,
261
+ /* int wn = */ wn);
262
+
263
+ compute_encoder.set_compute_pipeline_state(kernel);
264
+
265
+ // Use problem size to determine threadblock swizzle
266
+ int tn = (N + bn - 1) / bn;
267
+ int tm = (M + bm - 1) / bm;
268
+
269
+ // TODO: Explore device-based tuning for swizzle
270
+ int swizzle_log = tm <= 3 ? 0 : 1;
271
+
272
+ // Prepare steel matmul params
273
+ GEMMParams params{/* const int M = */ M,
274
+ /* const int N = */ N,
275
+ /* const int K = */ K,
276
+ /* const int lda = */ lda,
277
+ /* const int ldb = */ ldb,
278
+ /* const int ldd = */ ldd,
279
+ /* const int tiles_n = */ tn,
280
+ /* const int tiles_m = */ tm,
281
+ /* const int64_t batch_stride_a = */ A_batch_stride,
282
+ /* const int64_t batch_stride_b = */ B_batch_stride,
283
+ /* const int64_t batch_stride_d = */ matrix_stride_out,
284
+ /* const int swizzle_log = */ swizzle_log,
285
+ /* const int gemm_k_iterations_aligned = */ (K / bk),
286
+ /* const int batch_ndim = */ int(batch_shape.size())};
287
+
288
+ // Prepare launch grid params
289
+ int tile = 1 << swizzle_log;
290
+ tm = (tm + tile - 1) / tile;
291
+ tn = tn * tile;
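+ // Illustrative numbers (assumed sizes, not from the source): with M = N = 1024
+ // and bm = bn = 128, tm = tn = 8, so swizzle_log = 1 and tile = 2; the grid
+ // then becomes tn = 16 by tm = 4 threadgroups per batch entry.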
292
+
293
+ MTL::Size group_dims = MTL::Size(32, wn, wm);
294
+ MTL::Size grid_dims = MTL::Size(tn, tm, batch_size_out);
295
+
296
+ // Launch kernel
297
+ compute_encoder.set_input_array(a, 0);
298
+ compute_encoder.set_input_array(b, 1);
299
+ compute_encoder.set_output_array(out, 3);
300
+
301
+ compute_encoder.set_bytes(params, 4);
302
+
303
+ if (has_batch) {
304
+ compute_encoder.set_vector_bytes(batch_shape, 6);
305
+ compute_encoder.set_vector_bytes(batch_strides, 7);
306
+ }
307
+
308
+ if (use_out_source) {
309
+ int ldc = c.strides()[c.ndim() - 2];
310
+ int fdc = c.strides()[c.ndim() - 1];
311
+
312
+ GEMMAddMMParams params{/* const int ldc = */ ldc,
313
+ /* const int fdc = */ fdc,
314
+ /* const int64_t batch_stride_c = */ C_batch_stride,
315
+ /* const float alpha = */ alpha,
316
+ /* const float beta = */ beta};
317
+
318
+ compute_encoder.set_input_array(c, 2);
319
+ compute_encoder.set_bytes(params, 5);
320
+ }
321
+
322
+ compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
323
+
324
+ // Record copies
325
+ d.add_temporaries(std::move(copies), s.index);
326
+ }
327
+
328
+ template <bool CHECK_AB>
329
+ void steel_matmul_regular_axpby(
330
+ const Stream& s,
331
+ metal::Device& d,
332
+ const array& a,
333
+ const array& b,
334
+ const array& c,
335
+ array& out,
336
+ int M,
337
+ int N,
338
+ int K,
339
+ int batch_size_out,
340
+ int lda,
341
+ int ldb,
342
+ int ldd,
343
+ bool transpose_a,
344
+ bool transpose_b,
345
+ std::vector<array>& copies,
346
+ Shape batch_shape,
347
+ Strides batch_strides,
348
+ int64_t A_batch_stride,
349
+ int64_t B_batch_stride,
350
+ int64_t matrix_stride_out,
351
+ int64_t C_batch_stride /* = 0*/,
352
+ float alpha /* = 1.0f */,
353
+ float beta /* = 0.0f */) {
354
+ if (metal::is_nax_available() && !issubdtype(a.dtype(), complexfloating) &&
355
+ (env::enable_tf32() || a.dtype() != float32)) {
356
+ return steel_matmul_regular_axpby_nax<CHECK_AB>(
357
+ /* const Stream& s = */ s,
358
+ /* metal::Device& d = */ d,
359
+ /* const array& a = */ a,
360
+ /* const array& b = */ b,
361
+ /* const array& c = */ c,
362
+ /* array& out = */ out,
363
+ /* int M = */ M,
364
+ /* int N = */ N,
365
+ /* int K = */ K,
366
+ /* int batch_size_out = */ batch_size_out,
367
+ /* int lda = */ lda,
368
+ /* int ldb = */ ldb,
369
+ /* int ldd = */ ldd,
370
+ /* bool transpose_a = */ transpose_a,
371
+ /* bool transpose_b = */ transpose_b,
372
+ /* std::vector<array>& copies = */ copies,
373
+ /* Shape batch_shape = */ batch_shape,
374
+ /* Strides batch_strides = */ batch_strides,
375
+ /* int64_t A_batch_stride = */ A_batch_stride,
376
+ /* int64_t B_batch_stride = */ B_batch_stride,
377
+ /* int64_t matrix_stride_out = */ matrix_stride_out,
378
+ /* int64_t C_batch_stride = */ C_batch_stride,
379
+ /* float alpha = */ alpha,
380
+ /* float beta = */ beta);
381
+ }
382
+
383
+ using namespace mlx::steel;
384
+
385
+ // Determine dispatch kernel
386
+ int bm = 64, bn = 64, bk = 16;
387
+ int wm = 2, wn = 2;
388
+
389
+ char devc = d.get_architecture().back();
390
+ GEMM_TPARAM_MACRO(devc)
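+ // For example (assumed shapes): a half/bfloat nt GEMM on a 'd' device with
+ // batch_size_out * M * N >= 1 << 20 and 2 * max(M, N) > K resolves to
+ // bm = 64, bn = 64, bk = 16, wm = 1, wn = 2; cases the macro does not match
+ // keep the defaults set above.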
391
+
392
+ // Prepare kernel name
393
+ std::ostringstream kname;
394
+
395
+ // clang-format off
396
+ kname << "steel_gemm_fused_"
397
+ << (transpose_a ? 't' : 'n')
398
+ << (transpose_b ? 't' : 'n')
399
+ << "_" << type_to_name(a)
400
+ << "_" << type_to_name(out)
401
+ << "_bm" << bm << "_bn" << bn << "_bk" << bk
402
+ << "_wm" << wm << "_wn" << wn; // clang-format on
403
+
404
+ std::string base_name = kname.str();
405
+
406
+ const bool has_batch = (batch_shape.size() > 1);
407
+ const bool use_out_source = CHECK_AB && (alpha != 0.0f || beta != 1.0f);
408
+ const bool do_axpby = use_out_source && (alpha != 1.0f || beta != 1.0f);
409
+ const bool align_M = (M % bm) == 0;
410
+ const bool align_N = (N % bn) == 0;
411
+ const bool align_K = (K % bk) == 0;
412
+
413
+ metal::MTLFCList func_consts = {
414
+ {&has_batch, MTL::DataType::DataTypeBool, 10},
415
+ {&use_out_source, MTL::DataType::DataTypeBool, 100},
416
+ {&do_axpby, MTL::DataType::DataTypeBool, 110},
417
+ {&align_M, MTL::DataType::DataTypeBool, 200},
418
+ {&align_N, MTL::DataType::DataTypeBool, 201},
419
+ {&align_K, MTL::DataType::DataTypeBool, 202},
420
+ };
421
+
422
+ // clang-format off
423
+ kname << "_has_batch_" << (has_batch ? 't' : 'n')
424
+ << "_use_out_source_" << (use_out_source ? 't' : 'n')
425
+ << "_do_axpby_" << (do_axpby ? 't' : 'n')
426
+ << "_align_M_" << (align_M ? 't' : 'n')
427
+ << "_align_N_" << (align_N ? 't' : 'n')
428
+ << "_align_K_" << (align_K ? 't' : 'n'); // clang-format on
429
+
430
+ std::string hash_name = kname.str();
431
+
432
+ // Encode and dispatch kernel
433
+ auto& compute_encoder = d.get_command_encoder(s.index);
434
+ auto kernel = get_steel_gemm_fused_kernel(
435
+ /* metal::Device& d = */ d,
436
+ /* const std::string& kernel_name = */ base_name,
437
+ /* const std::string& hash_name = */ hash_name,
438
+ /* const metal::MTLFCList& func_consts = */ func_consts,
439
+ /* const array& out = */ out,
440
+ /* bool transpose_a = */ transpose_a,
441
+ /* bool transpose_b = */ transpose_b,
442
+ /* int bm = */ bm,
443
+ /* int bn = */ bn,
444
+ /* int bk = */ bk,
445
+ /* int wm = */ wm,
446
+ /* int wn = */ wn);
447
+
448
+ compute_encoder.set_compute_pipeline_state(kernel);
449
+
450
+ // Use problem size to determine threadblock swizzle
451
+ int tn = (N + bn - 1) / bn;
452
+ int tm = (M + bm - 1) / bm;
453
+
454
+ // TODO: Explore device-based tuning for swizzle
455
+ int swizzle_log = 0; // tm >= 6 ? 3 : (tm <= 3 ? 0 : 2);
456
+
457
+ // Prepare steel matmul params
458
+ GEMMParams params{/* const int M = */ M,
459
+ /* const int N = */ N,
460
+ /* const int K = */ K,
461
+ /* const int lda = */ lda,
462
+ /* const int ldb = */ ldb,
463
+ /* const int ldd = */ ldd,
464
+ /* const int tiles_n = */ tn,
465
+ /* const int tiles_m = */ tm,
466
+ /* const int64_t batch_stride_a = */ A_batch_stride,
467
+ /* const int64_t batch_stride_b = */ B_batch_stride,
468
+ /* const int64_t batch_stride_d = */ matrix_stride_out,
469
+ /* const int swizzle_log = */ swizzle_log,
470
+ /* const int gemm_k_iterations_aligned = */ (K / bk),
471
+ /* const int batch_ndim = */ int(batch_shape.size())};
472
+
473
+ // Prepare launch grid params
474
+ int tile = 1 << swizzle_log;
475
+ tm = (tm + tile - 1) / tile;
476
+ tn = tn * tile;
477
+
478
+ MTL::Size group_dims = MTL::Size(32, wn, wm);
479
+ MTL::Size grid_dims = MTL::Size(tn, tm, batch_size_out);
480
+
481
+ // Launch kernel
482
+ compute_encoder.set_input_array(a, 0);
483
+ compute_encoder.set_input_array(b, 1);
484
+ compute_encoder.set_output_array(out, 3);
485
+
486
+ compute_encoder.set_bytes(params, 4);
487
+
488
+ if (has_batch) {
489
+ compute_encoder.set_vector_bytes(batch_shape, 6);
490
+ compute_encoder.set_vector_bytes(batch_strides, 7);
491
+ }
492
+
493
+ if (use_out_source) {
494
+ int ldc = c.strides()[c.ndim() - 2];
495
+ int fdc = c.strides()[c.ndim() - 1];
496
+
497
+ GEMMAddMMParams params{/* const int ldc = */ ldc,
498
+ /* const int fdc = */ fdc,
499
+ /* const int64_t batch_stride_c = */ C_batch_stride,
500
+ /* const float alpha = */ alpha,
501
+ /* const float beta = */ beta};
502
+
503
+ compute_encoder.set_input_array(c, 2);
504
+ compute_encoder.set_bytes(params, 5);
505
+ }
506
+
507
+ compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
508
+
509
+ // Record copies
510
+ d.add_temporaries(std::move(copies), s.index);
511
+ }
512
+
513
+ ///////////////////////////////////////////////////////////////////////////////
514
+ // Split k steel matmul
515
+ ///////////////////////////////////////////////////////////////////////////////
516
+
517
+ template <bool CHECK_AB = true>
518
+ void steel_gemm_splitk_axpby(
519
+ const Stream& s,
520
+ metal::Device& d,
521
+ const array& a,
522
+ const array& b,
523
+ const array& c,
524
+ array& out,
525
+ int M,
526
+ int N,
527
+ int K,
528
+ int batch_size_out,
529
+ int lda,
530
+ int ldb,
531
+ bool transpose_a,
532
+ bool transpose_b,
533
+ std::vector<array>& copies,
534
+ float alpha = 1.0f,
535
+ float beta = 0.0f) {
536
+ using namespace mlx::steel;
537
+
538
+ int _tm = (M + 32 - 1) / 32;
539
+ int _tn = (N + 32 - 1) / 32;
540
+ int _tk = K / 16;
541
+
542
+ int bm = M < 40 ? 16 : 32;
543
+ int bn = N < 40 ? 16 : 32;
544
+ int bk = 16;
545
+ int wm = 2, wn = 2;
546
+
547
+ // As _tk grows use more partitions, as _tm * _tn grow use fewer partitions
548
+ int split_k_partitions =
549
+ std::min(std::max(2, next_power_of_2(_tk / (_tm * _tn))), 32);
550
+ int split_k_partition_stride = M * N;
551
+ int gemm_k_iterations = (K / bk) / split_k_partitions;
552
+ int split_k_partition_size = gemm_k_iterations * bk;
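+ // Illustrative numbers (assumed sizes): M = N = 64, K = 4096 gives
+ // _tm = _tn = 2 and _tk = 256, so split_k_partitions clamps to 32 and each
+ // partition covers split_k_partition_size = 8 * bk = 128 of the K dimension.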
553
+
554
+ array C_split(
555
+ {split_k_partitions, M, N},
556
+ issubdtype(out.dtype(), complexfloating) ? complex64 : float32,
557
+ nullptr,
558
+ {});
559
+ C_split.set_data(allocator::malloc(C_split.nbytes()));
560
+ copies.push_back(C_split);
561
+
562
+ bool mn_aligned = M % bm == 0 && N % bn == 0;
563
+ bool k_aligned = K % bk == 0;
564
+ std::ostringstream kname;
565
+
566
+ // clang-format off
567
+ kname << "steel_gemm_splitk_"
568
+ << (transpose_a ? 't' : 'n')
569
+ << (transpose_b ? 't' : 'n')
570
+ << "_" << type_to_name(a)
571
+ << "_" << type_to_name(C_split)
572
+ << "_bm" << bm << "_bn" << bn << "_bk" << bk
573
+ << "_wm" << wm << "_wn" << wn
574
+ << "_MN_" << (mn_aligned ? "t" : "n") << "aligned"
575
+ << "_K_" << (k_aligned ? "t" : "n") << "aligned"; // clang-format on
576
+
577
+ // Encode and dispatch gemm kernel
578
+ auto& compute_encoder = d.get_command_encoder(s.index);
579
+ auto kernel = get_steel_gemm_splitk_kernel(
580
+ /* metal::Device& d = */ d,
581
+ /* const std::string& kernel_name = */ kname.str(),
582
+ /* const array& in = */ a,
583
+ /* const array& out = */ C_split,
584
+ /* bool transpose_a = */ transpose_a,
585
+ /* bool transpose_b = */ transpose_b,
586
+ /* int bm = */ bm,
587
+ /* int bn = */ bn,
588
+ /* int bk = */ bk,
589
+ /* int wm = */ wm,
590
+ /* int wn = */ wn,
591
+ /* bool mn_aligned = */ mn_aligned,
592
+ /* bool k_aligned = */ k_aligned);
593
+
594
+ compute_encoder.set_compute_pipeline_state(kernel);
595
+
596
+ int tn = (N + bn - 1) / bn;
597
+ int tm = (M + bm - 1) / bm;
598
+
599
+ GEMMSpiltKParams params{
600
+ /* const int M = */ M,
601
+ /* const int N = */ N,
602
+ /* const int K = */ K,
603
+ /* const int lda = */ lda,
604
+ /* const int ldb = */ ldb,
605
+ /* const int ldc = */ N,
606
+ /* const int tiles_n = */ tn,
607
+ /* const int tiles_m = */ tm,
608
+ /* const int split_k_partitions = */ split_k_partitions,
609
+ /* const int split_k_partition_stride = */ split_k_partition_stride,
610
+ /* const int split_k_partition_size = */ split_k_partition_size,
611
+ /* const int swizzle_log = */ 0, // no swizzle
612
+ /* const int gemm_k_iterations_aligned = */ gemm_k_iterations};
613
+
614
+ MTL::Size group_dims = MTL::Size(32, wn, wm);
615
+ MTL::Size grid_dims = MTL::Size(tn, tm, split_k_partitions);
616
+
617
+ compute_encoder.set_input_array(a, 0);
618
+ compute_encoder.set_input_array(b, 1);
619
+ compute_encoder.set_output_array(C_split, 2);
620
+
621
+ compute_encoder.set_bytes(params, 3);
622
+ compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
623
+
624
+ // Do accum kernel
625
+ {
626
+ const bool do_axpby = CHECK_AB && (alpha != 1.0f || beta != 0.0f);
627
+
628
+ auto kernel_name = "steel_gemm_splitk_accum_" + type_to_name(out) + "_" +
629
+ type_to_name(C_split);
630
+
631
+ if (do_axpby) {
632
+ kernel_name = kernel_name + "_axbpy";
633
+ }
634
+
635
+ auto kernel = get_steel_gemm_splitk_accum_kernel(
636
+ /* metal::Device& d = */ d,
637
+ /* const std::string& kernel_name = */ kernel_name,
638
+ /* const array& in = */ C_split,
639
+ /* const array& out = */ out,
640
+ /* bool axbpy = */ do_axpby);
641
+ compute_encoder.set_compute_pipeline_state(kernel);
642
+
643
+ // Set the arguments for the kernel
644
+ compute_encoder.set_input_array(C_split, 0);
645
+ compute_encoder.set_output_array(out, 1);
646
+ compute_encoder.set_bytes(split_k_partitions, 2);
647
+ compute_encoder.set_bytes(split_k_partition_stride, 3);
648
+ compute_encoder.set_bytes(N, 4);
649
+
650
+ if (do_axpby) {
651
+ int ldc = c.strides()[c.ndim() - 2];
652
+ int fdc = c.strides()[c.ndim() - 1];
653
+
654
+ compute_encoder.set_input_array(c, 5);
655
+ compute_encoder.set_bytes(ldc, 6);
656
+ compute_encoder.set_bytes(fdc, 7);
657
+ compute_encoder.set_bytes(alpha, 8);
658
+ compute_encoder.set_bytes(beta, 9);
659
+ }
660
+
661
+ // Launch enough thread groups for each output
662
+ MTL::Size grid_dims = MTL::Size(N, M, 1);
663
+ auto group_dims = get_block_dims(N, M, 1);
664
+ compute_encoder.dispatch_threads(grid_dims, group_dims);
665
+ }
666
+
667
+ d.add_temporaries(std::move(copies), s.index);
668
+ }
669
+
670
+ ///////////////////////////////////////////////////////////////////////////////
671
+ // NAX Split k steel matmul
672
+ ///////////////////////////////////////////////////////////////////////////////
673
+
674
+ template <bool CHECK_AB = true>
675
+ void steel_gemm_splitk_axpby_nax(
676
+ const Stream& s,
677
+ metal::Device& d,
678
+ const array& a,
679
+ const array& b,
680
+ const array& c,
681
+ array& out,
682
+ int M,
683
+ int N,
684
+ int K,
685
+ int batch_size_out,
686
+ int lda,
687
+ int ldb,
688
+ bool transpose_a,
689
+ bool transpose_b,
690
+ std::vector<array>& copies,
691
+ float alpha = 1.0f,
692
+ float beta = 0.0f) {
693
+ using namespace mlx::steel;
694
+
695
+ constexpr int bm = 128, bn = 128, bk = 512;
696
+ constexpr int wm = 4, wn = 4;
697
+
698
+ // Determine how many partitions to split K into
699
+ constexpr int split_k_partition_size = 3072;
700
+ int split_k_partitions =
701
+ (K + split_k_partition_size - 1) / split_k_partition_size;
702
+
703
+ const int bk_iters_per_partition = split_k_partition_size / bk;
704
+ const int split_k_partition_stride = M * N;
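+ // Illustrative numbers (assumed K): K = 16384 gives split_k_partitions = 6,
+ // with bk_iters_per_partition = 3072 / 512 = 6 bk-sized iterations per partition.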
705
+
706
+ array C_split({split_k_partitions, M, N}, float32, nullptr, {});
707
+ C_split.set_data(allocator::malloc(C_split.nbytes()));
708
+ copies.push_back(C_split);
709
+
710
+ const bool align_M = (M % bm) == 0;
711
+ const bool align_N = (N % bn) == 0;
712
+ const bool align_K = (K % bk) == 0;
713
+
714
+ // Per-tile align_K is checked at runtime; only the last tile can be unaligned
715
+ metal::MTLFCList func_consts = {
716
+ {&align_M, MTL::DataType::DataTypeBool, 200},
717
+ {&align_N, MTL::DataType::DataTypeBool, 201}};
718
+
719
+ std::ostringstream kname;
720
+
721
+ // clang-format off
722
+ kname << "steel_gemm_splitk_nax_"
723
+ << (transpose_a ? 't' : 'n')
724
+ << (transpose_b ? 't' : 'n')
725
+ << "_" << type_to_name(a)
726
+ << "_" << type_to_name(C_split)
727
+ << "_bm" << bm << "_bn" << bn << "_bk" << bk
728
+ << "_wm" << wm << "_wn" << wn; // clang-format on
729
+
730
+ std::string base_name = kname.str();
731
+
732
+ // clang-format off
733
+ kname << "_align_M_" << (align_M ? 't' : 'n')
734
+ << "_align_N_" << (align_N ? 't' : 'n')
735
+ << "_align_K_" << (align_K ? 't' : 'n'); // clang-format on
736
+
737
+ std::string hash_name = kname.str();
738
+
739
+ auto& compute_encoder = d.get_command_encoder(s.index);
740
+ auto kernel = get_steel_gemm_splitk_nax_kernel(
741
+ /* metal::Device& d = */ d,
742
+ /* const std::string& kernel_name = */ base_name,
743
+ /* const std::string& hash_name = */ hash_name,
744
+ /* const metal::MTLFCList& func_consts = */ func_consts,
745
+ /* const array& out = */ C_split,
746
+ /* bool transpose_a = */ transpose_a,
747
+ /* bool transpose_b = */ transpose_b,
748
+ /* int bm = */ bm,
749
+ /* int bn = */ bn,
750
+ /* int bk = */ bk,
751
+ /* int wm = */ wm,
752
+ /* int wn = */ wn);
753
+
754
+ compute_encoder.set_compute_pipeline_state(kernel);
755
+
756
+ int tn = (N + bn - 1) / bn;
757
+ int tm = (M + bm - 1) / bm;
758
+
759
+ int swizzle_log = tm <= 3 ? 0 : 1;
760
+
761
+ // Compute swizzled tile counts
762
+ int tile = 1 << swizzle_log;
763
+ int tm_swizzled = (tm + tile - 1) / tile;
764
+ int tn_swizzled = tn * tile;
765
+
766
+ GEMMSpiltKParams params{
767
+ /* const int M = */ M,
768
+ /* const int N = */ N,
769
+ /* const int K = */ K,
770
+ /* const int lda = */ lda,
771
+ /* const int ldb = */ ldb,
772
+ /* const int ldc = */ N,
773
+ /* const int tiles_n = */ tn,
774
+ /* const int tiles_m = */ tm,
775
+ /* const int split_k_partitions = */ split_k_partitions,
776
+ /* const int split_k_partition_stride = */ split_k_partition_stride,
777
+ /* const int split_k_partition_size = */ split_k_partition_size,
778
+ /* const int swizzle_log = */ swizzle_log,
779
+ /* const int gemm_k_iterations_aligned = */ bk_iters_per_partition};
780
+
781
+ MTL::Size group_dims = MTL::Size(32, wn, wm);
782
+ // Use 1D grid with K-partition-major layout: [Partition0: M×N
783
+ // tiles][Partition1: M×N tiles]... Grid size is 1D to prevent driver/HW from
784
+ // using its own heuristic to exploit 2D locality by launching threadgroups in
785
+ // a non-linear order
786
+ MTL::Size grid_dims =
787
+ MTL::Size(tn_swizzled * tm_swizzled * split_k_partitions, 1, 1);
788
+
789
+ compute_encoder.set_input_array(a, 0);
790
+ compute_encoder.set_input_array(b, 1);
791
+ compute_encoder.set_output_array(C_split, 2);
792
+
793
+ compute_encoder.set_bytes(params, 3);
794
+ compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
795
+
796
+ // Do accum kernel
797
+ {
798
+ const bool do_axpby = CHECK_AB && (alpha != 1.0f || beta != 0.0f);
799
+
800
+ auto kernel_name = "steel_gemm_splitk_accum_" + type_to_name(out) + "_" +
801
+ type_to_name(C_split);
802
+
803
+ if (do_axpby) {
804
+ kernel_name = kernel_name + "_axbpy";
805
+ }
806
+
807
+ auto kernel = get_steel_gemm_splitk_accum_kernel(
808
+ /* metal::Device& d = */ d,
809
+ /* const std::string& kernel_name = */ kernel_name,
810
+ /* const array& in = */ C_split,
811
+ /* const array& out = */ out,
812
+ /* bool axbpy = */ do_axpby);
813
+ compute_encoder.set_compute_pipeline_state(kernel);
814
+
815
+ // Set the arguments for the kernel
816
+ compute_encoder.set_input_array(C_split, 0);
817
+ compute_encoder.set_output_array(out, 1);
818
+ compute_encoder.set_bytes(split_k_partitions, 2);
819
+ compute_encoder.set_bytes(split_k_partition_stride, 3);
820
+ compute_encoder.set_bytes(N, 4);
821
+
822
+ if (do_axpby) {
823
+ int ldc = c.strides()[c.ndim() - 2];
824
+ int fdc = c.strides()[c.ndim() - 1];
825
+
826
+ compute_encoder.set_input_array(c, 5);
827
+ compute_encoder.set_bytes(ldc, 6);
828
+ compute_encoder.set_bytes(fdc, 7);
829
+ compute_encoder.set_bytes(alpha, 8);
830
+ compute_encoder.set_bytes(beta, 9);
831
+ }
832
+
833
+ // Launch enough thread groups for each output
834
+ MTL::Size grid_dims = MTL::Size(N, M, 1);
835
+ auto group_dims = get_block_dims(N, M, 1);
836
+ compute_encoder.dispatch_threads(grid_dims, group_dims);
837
+ }
838
+
839
+ d.add_temporaries(std::move(copies), s.index);
840
+ }
841
+
842
+ ///////////////////////////////////////////////////////////////////////////////
843
+ // Split matmul routing
844
+ ///////////////////////////////////////////////////////////////////////////////
845
+
846
+ template <bool CHECK_AB>
847
+ void steel_matmul_axpby(
848
+ const Stream& s,
849
+ metal::Device& d,
850
+ const array& a,
851
+ const array& b,
852
+ const array& c,
853
+ array& out,
854
+ int M,
855
+ int N,
856
+ int K,
857
+ int batch_size_out,
858
+ int lda,
859
+ int ldb,
860
+ bool transpose_a,
861
+ bool transpose_b,
862
+ std::vector<array>& copies,
863
+ Shape batch_shape /* = {} */,
864
+ Strides A_batch_stride /* = {} */,
865
+ Strides B_batch_stride /* = {} */,
866
+ Strides C_batch_stride /* = {} */,
867
+ float alpha /* = 1.0f */,
868
+ float beta /* = 0.0f */) {
869
+ if (batch_shape.empty()) {
870
+ /////////////////////////////////////////////////////////////////////////////
871
+ // Check and collapse batch dimensions
872
+ if constexpr (CHECK_AB) {
873
+ auto [batch_shape_, A_bstride_, B_bstride_, C_bstride_] =
874
+ collapse_batches(a, b, c);
875
+
876
+ batch_shape = batch_shape_;
877
+ A_batch_stride = A_bstride_;
878
+ B_batch_stride = B_bstride_;
879
+ C_batch_stride = C_bstride_;
880
+ // Collapse batches into M if needed
881
+ if (batch_size_out > 1 && !transpose_a && batch_shape.size() == 1 &&
882
+ a.strides()[a.ndim() - 2] == K && A_batch_stride.back() == M * K &&
883
+ C_batch_stride.back() == M * c.strides()[c.ndim() - 2] &&
884
+ B_batch_stride.back() == 0) {
885
+ M *= batch_shape.back();
886
+ batch_size_out = 1;
887
+
888
+ A_batch_stride = {0};
889
+ B_batch_stride = {0};
890
+ C_batch_stride = {0};
891
+ batch_shape = {1};
892
+ }
893
+ } else {
894
+ auto [batch_shape_, A_bstride_, B_bstride_] = collapse_batches(a, b);
895
+
896
+ batch_shape = batch_shape_;
897
+ A_batch_stride = A_bstride_;
898
+ B_batch_stride = B_bstride_;
899
+ // Collapse batches into M if needed
900
+ if (batch_size_out > 1 && !transpose_a && batch_shape.size() == 1 &&
901
+ a.strides()[a.ndim() - 2] == K && A_batch_stride.back() == M * K &&
902
+ B_batch_stride.back() == 0) {
903
+ M *= batch_shape.back();
904
+ batch_size_out = 1;
905
+
906
+ A_batch_stride = {0};
907
+ B_batch_stride = {0};
908
+ batch_shape = {1};
909
+ }
910
+ }
911
+ }
912
+
913
+ /////////////////////////////////////////////////////////////////////////////
914
+ // Split K specialization
915
+
916
+ int _tm = (M + 16 - 1) / 16;
917
+ int _tn = (N + 16 - 1) / 16;
918
+ int _tk = K / 16;
919
+
920
+ // Case 1: Small M×N with large K, use SIMD split-K
921
+ char devc = d.get_architecture().back();
922
+ // Max and Ultra dispatch larger sizes to splitk
923
+ int min_tmn_threshold = (devc == 's' || devc == 'd') ? 2048 : 1024;
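+ // Illustrative case (assumed sizes): a single 512 x 8192 by 8192 x 512 matmul
+ // on an 's'/'d' device gives _tm * _tn = 1024 <= 2048 and _tk = 512 >= 8 with
+ // K >= max(M, N), so it takes the split-K path below.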
924
+ if (batch_size_out == 1 && (_tm * _tn) <= min_tmn_threshold && _tk >= 8 &&
925
+ K >= std::max(M, N)) {
926
+ return steel_gemm_splitk_axpby<CHECK_AB>(
927
+ /* const Stream& s = */ s,
928
+ /* metal::Device& d = */ d,
929
+ /* const array& a = */ a,
930
+ /* const array& b = */ b,
931
+ /* const array& c = */ c,
932
+ /* array& out = */ out,
933
+ /* int M = */ M,
934
+ /* int N = */ N,
935
+ /* int K = */ K,
936
+ /* int batch_size_out = */ batch_size_out,
937
+ /* int lda = */ lda,
938
+ /* int ldb = */ ldb,
939
+ /* bool transpose_a = */ transpose_a,
940
+ /* bool transpose_b = */ transpose_b,
941
+ /* std::vector<array>& copies = */ copies,
942
+ /* float alpha = */ alpha,
943
+ /* float beta = */ beta);
944
+ }
945
+
946
+ // Case 2: Large K with sufficiently large M and N, and NAX available: use NAX split-K
947
+ // TODO: Add device-specific tuning for more NAX GPUs in the future
948
+ constexpr int min_mn_threshold = 2048 * 2048;
949
+ constexpr int min_k_threshold = 10240;
950
+ if (batch_size_out == 1 && metal::is_nax_available() &&
951
+ !issubdtype(a.dtype(), complexfloating) &&
952
+ (env::enable_tf32() || a.dtype() != float32) &&
953
+ int64_t(M) * N >= min_mn_threshold && K >= min_k_threshold &&
954
+ K >= (3 * std::max(M, N))) {
955
+ return steel_gemm_splitk_axpby_nax<CHECK_AB>(
956
+ /* const Stream& s = */ s,
957
+ /* metal::Device& d = */ d,
958
+ /* const array& a = */ a,
959
+ /* const array& b = */ b,
960
+ /* const array& c = */ c,
961
+ /* array& out = */ out,
962
+ /* int M = */ M,
963
+ /* int N = */ N,
964
+ /* int K = */ K,
965
+ /* int batch_size_out = */ batch_size_out,
966
+ /* int lda = */ lda,
967
+ /* int ldb = */ ldb,
968
+ /* bool transpose_a = */ transpose_a,
969
+ /* bool transpose_b = */ transpose_b,
970
+ /* std::vector<array>& copies = */ copies,
971
+ /* float alpha = */ alpha,
972
+ /* float beta = */ beta);
973
+ }
974
+
975
+ /////////////////////////////////////////////////////////////////////////////
976
+ // Regular kernel dispatch
977
+ auto batch_strides = A_batch_stride;
978
+ batch_strides.insert(
979
+ batch_strides.end(), B_batch_stride.begin(), B_batch_stride.end());
980
+ if (CHECK_AB && !C_batch_stride.empty()) {
981
+ batch_strides.insert(
982
+ batch_strides.end(), C_batch_stride.begin(), C_batch_stride.end());
983
+ }
984
+
985
+ int64_t A_batch_stride_ = A_batch_stride.empty() ? 0 : A_batch_stride.back();
986
+ int64_t B_batch_stride_ = B_batch_stride.empty() ? 0 : B_batch_stride.back();
987
+ int64_t C_batch_stride_ = C_batch_stride.empty() ? 0 : C_batch_stride.back();
988
+
989
+ return steel_matmul_regular_axpby<CHECK_AB>(
990
+ /* const Stream& s = */ s,
991
+ /* metal::Device& d = */ d,
992
+ /* const array& a = */ a,
993
+ /* const array& b = */ b,
994
+ /* const array& c = */ c,
995
+ /* array& out = */ out,
996
+ /* int M = */ M,
997
+ /* int N = */ N,
998
+ /* int K = */ K,
999
+ /* int batch_size_out = */ batch_size_out,
1000
+ /* int lda = */ lda,
1001
+ /* int ldb = */ ldb,
1002
+ /* int ldd = */ N,
1003
+ /* bool transpose_a = */ transpose_a,
1004
+ /* bool transpose_b = */ transpose_b,
1005
+ /* std::vector<array>& copies = */ copies,
1006
+ /* Shape batch_shape = */ std::move(batch_shape),
1007
+ /* Strides batch_strides = */ std::move(batch_strides),
1008
+ /* int64_t A_batch_stride = */ A_batch_stride_,
1009
+ /* int64_t B_batch_stride = */ B_batch_stride_,
1010
+ /* int64_t matrix_stride_out = */ int64_t(M) * N,
1011
+ /* int64_t C_batch_stride = */ C_batch_stride_,
1012
+ /* float alpha = */ alpha,
1013
+ /* float beta = */ beta);
1014
+ }
1015
+
1016
+ ///////////////////////////////////////////////////////////////////////////////
1017
+ // GEMV dispatch
1018
+ ///////////////////////////////////////////////////////////////////////////////
1019
+
1020
+ template <bool CHECK_AB = true>
1021
+ void gemv_axbpy(
1022
+ const Stream& s,
1023
+ metal::Device& d,
1024
+ const array& a,
1025
+ const array& b,
1026
+ const array& c,
1027
+ array& out,
1028
+ int M,
1029
+ int N,
1030
+ int K,
1031
+ int batch_size_out,
1032
+ int lda,
1033
+ int ldb,
1034
+ bool transpose_a,
1035
+ bool transpose_b,
1036
+ std::vector<array>& copies,
1037
+ Shape batch_shape = {},
1038
+ Strides A_batch_stride = {},
1039
+ Strides B_batch_stride = {},
1040
+ Strides C_batch_stride = {},
1041
+ float alpha = 1.0f,
1042
+ float beta = 0.0f) {
1043
+ // Collect problem info
1044
+ bool is_b_matrix = N != 1;
1045
+
1046
+ auto& mat = is_b_matrix ? b : a;
1047
+ auto& vec = is_b_matrix ? a : b;
1048
+ bool transpose_mat = is_b_matrix ? !transpose_b : transpose_a;
1049
+ int in_vector_len = K;
1050
+ int out_vector_len = is_b_matrix ? N : M;
1051
+
1052
+ int mat_ld = is_b_matrix ? ldb : lda;
1053
+
1054
+ auto batch_strides_mat = is_b_matrix ? B_batch_stride : A_batch_stride;
1055
+ auto batch_strides_vec = is_b_matrix ? A_batch_stride : B_batch_stride;
1056
+
1057
+ // Determine if inputs have simple batching / broadcasting
1058
+ bool contiguous_kernel = (batch_shape.size() == 1);
1059
+
1060
+ int batch_ndim = batch_shape.size();
1061
+
1062
+ // Determine dispatch kernel
1063
+ int tm = 4, tn = 4;
1064
+ int sm = 1, sn = 32;
1065
+ int bm = 1, bn = 1;
1066
+ int n_out_per_tgp;
1067
+ std::ostringstream kname;
1068
+
1069
+ if (transpose_mat) {
1070
+ if (in_vector_len >= 8192 && out_vector_len >= 2048) {
1071
+ sm = 4;
1072
+ sn = 8;
1073
+ } else {
1074
+ sm = 8;
1075
+ sn = 4;
1076
+ }
1077
+
1078
+ if (out_vector_len >= 2048) {
1079
+ bn = 16;
1080
+ } else if (out_vector_len >= 512) {
1081
+ bn = 4;
1082
+ } else {
1083
+ bn = 2;
1084
+ }
1085
+
1086
+ // Specialized kernel for very small outputs
1087
+ tn = out_vector_len < tn ? 1 : tn;
1088
+
1089
+ n_out_per_tgp = bn * sn * tn;
1090
+ kname << "gemv_t_" << type_to_name(out);
1091
+
1092
+ } else {
1093
+ bm = out_vector_len >= 4096 ? 8 : 4;
1094
+ sn = 32;
1095
+
1096
+ if (K <= 64) {
1097
+ bm = 1;
1098
+ sm = 8;
1099
+ sn = 4;
1100
+ } else if (K >= 16 * out_vector_len) {
1101
+ bm = 1;
1102
+ bn = 8;
1103
+ }
1104
+
1105
+ // Specialized kernel for very small outputs
1106
+ tm = out_vector_len < tm ? 1 : tm;
1107
+
1108
+ n_out_per_tgp = bm * sm * tm;
1109
+ kname << "gemv_" << type_to_name(out);
1110
+ }
1111
+
1112
+ const bool do_axpby = CHECK_AB && (alpha != 1.0f || beta != 0.0f);
1113
+
1114
+ // clang-format off
1115
+ kname << "_bm" << bm << "_bn" << bn
1116
+ << "_sm" << sm << "_sn" << sn
1117
+ << "_tm" << tm << "_tn" << tn
1118
+ << "_nc" << !contiguous_kernel
1119
+ << "_axpby" << do_axpby; // clang-format on
1120
+
1121
+ // Encode and dispatch kernel
1122
+ auto& compute_encoder = d.get_command_encoder(s.index);
1123
+ auto kernel = d.get_kernel(kname.str());
1124
+ compute_encoder.set_compute_pipeline_state(kernel);
1125
+
1126
+ int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp;
1127
+ MTL::Size group_dims = MTL::Size(32, bn, bm);
1128
+ MTL::Size grid_dims = MTL::Size(n_tgp, 1, batch_size_out);
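+ // Illustrative numbers (assumed sizes): a non-transposed matrix with
+ // out_vector_len = 4096 and K = 1024 selects bm = 8, sm = 1, tm = 4, so
+ // n_out_per_tgp = 32 rows per threadgroup and n_tgp = 128 threadgroups of
+ // 32 x 1 x 8 threads each.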
1129
+
1130
+ compute_encoder.set_input_array(mat, 0);
1131
+ compute_encoder.set_input_array(vec, 1);
1132
+ compute_encoder.set_output_array(out, 3);
1133
+
1134
+ compute_encoder.set_bytes(in_vector_len, 4);
1135
+ compute_encoder.set_bytes(out_vector_len, 5);
1136
+ compute_encoder.set_bytes(mat_ld, 6);
1137
+
1138
+ compute_encoder.set_bytes(batch_ndim, 9);
1139
+ compute_encoder.set_vector_bytes(batch_shape, 10);
1140
+ compute_encoder.set_vector_bytes(batch_strides_vec, 11);
1141
+ compute_encoder.set_vector_bytes(batch_strides_mat, 12);
1142
+
1143
+ if (do_axpby) {
1144
+ compute_encoder.set_input_array(c, 2);
1145
+
1146
+ compute_encoder.set_bytes(alpha, 7);
1147
+ compute_encoder.set_bytes(beta, 8);
1148
+
1149
+ compute_encoder.set_vector_bytes(C_batch_stride, 13);
1150
+
1151
+ int bias_stride = c.strides()[c.ndim() - 1];
1152
+ compute_encoder.set_bytes(bias_stride, 14);
1153
+ }
1154
+
1155
+ compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
1156
+
1157
+ d.add_temporaries(std::move(copies), s.index);
1158
+ }
1159
+
1160
+ inline void gemv(
1161
+ const Stream& s,
1162
+ metal::Device& d,
1163
+ const array& a,
1164
+ const array& b,
1165
+ array& out,
1166
+ int M,
1167
+ int N,
1168
+ int K,
1169
+ int batch_size_out,
1170
+ int lda,
1171
+ int ldb,
1172
+ bool transpose_a,
1173
+ bool transpose_b,
1174
+ std::vector<array>& copies,
1175
+ Shape batch_shape = {},
1176
+ Strides A_batch_stride = {},
1177
+ Strides B_batch_stride = {}) {
1178
+ return gemv_axbpy<false>(
1179
+ /* const Stream& s = */ s,
1180
+ /* metal::Device& d = */ d,
1181
+ /* const array& a = */ a,
1182
+ /* const array& b = */ b,
1183
+ /* const array& c = */ b, // c is unused since CHECK_AB is false
1184
+ /* array& out = */ out,
1185
+ /* int M = */ M,
1186
+ /* int N = */ N,
1187
+ /* int K = */ K,
1188
+ /* int batch_size_out = */ batch_size_out,
1189
+ /* int lda = */ lda,
1190
+ /* int ldb = */ ldb,
1191
+ /* bool transpose_a = */ transpose_a,
1192
+ /* bool transpose_b = */ transpose_b,
1193
+ /* std::vector<array>& copies = */ copies,
1194
+ /* Shape batch_shape = */ batch_shape,
1195
+ /* Strides A_batch_stride = */ A_batch_stride,
1196
+ /* Strides B_batch_stride = */ B_batch_stride);
1197
+ }
1198
+
1199
+ ///////////////////////////////////////////////////////////////////////////////
1200
+ // Matmul implementation
1201
+ ///////////////////////////////////////////////////////////////////////////////
1202
+
1203
+ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
1204
+ assert(inputs.size() == 2);
1205
+ if (!issubdtype(out.dtype(), inexact)) {
1206
+ throw std::runtime_error("[matmul] dtype must be inexact.");
1207
+ }
1208
+ auto& s = stream();
1209
+ auto& d = metal::device(s.device);
1210
+
1211
+ auto& a_pre = inputs[0];
1212
+ auto& b_pre = inputs[1];
1213
+ // Return 0s if either input is empty
1214
+ if (a_pre.size() == 0 || b_pre.size() == 0) {
1215
+ array zero = array(0, a_pre.dtype());
1216
+ fill_gpu(zero, out, s);
1217
+ d.add_temporary(std::move(zero), s.index);
1218
+ return;
1219
+ }
1220
+
1221
+ out.set_data(allocator::malloc(out.nbytes()));
1222
+
1223
+ /////////////////////////////////////////////////////////////////////////////
1224
+ // Init checks and prep
1225
+
1226
+ int M = a_pre.shape(-2);
1227
+ int N = b_pre.shape(-1);
1228
+ int K = a_pre.shape(-1);
1229
+
1230
+ // Keep a vector with copies to be cleared in the completed buffer to release
1231
+ // the arrays
1232
+ std::vector<array> copies;
1233
+ auto [a_transposed, a_cols, a] = check_transpose(copies, s, a_pre, M == 1);
1234
+ auto [b_transposed, b_cols, b] = check_transpose(copies, s, b_pre, N == 1);
1235
+
1236
+ /////////////////////////////////////////////////////////////////////////////
1237
+ // Check and collapse batch dimensions
1238
+
1239
+ auto [batch_shape, A_batch_stride, B_batch_stride] = collapse_batches(a, b);
1240
+
1241
+ auto batch_size_out = out.size() / (size_t(M) * size_t(N));
1242
+
1243
+ // Collapse batches into M if needed
1244
+ if (batch_size_out > 1 && !a_transposed && batch_shape.size() == 1 &&
1245
+ a.strides()[a.ndim() - 2] == K && A_batch_stride.back() == M * K &&
1246
+ B_batch_stride.back() == 0) {
1247
+ M *= batch_shape.back();
1248
+ batch_size_out = 1;
1249
+
1250
+ A_batch_stride = {0};
1251
+ B_batch_stride = {0};
1252
+ batch_shape = {1};
1253
+ }
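+ // Illustrative case (assumed shapes): a row-contiguous (8, 32, 64) a multiplied
+ // by a broadcast (64, 128) b satisfies the checks above, so the kernel sees a
+ // single 256 x 64 by 64 x 128 GEMM instead of 8 batched 32 x 64 GEMMs.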
1254
+
1255
+ /////////////////////////////////////////////////////////////////////////////
1256
+ // Gemv specialization
1257
+
1258
+ // Route to gemv if needed
1259
+ if (std::min(M, N) == 1) {
1260
+ return gemv(
1261
+ /* const Stream& s = */ s,
1262
+ /* metal::Device& d = */ d,
1263
+ /* const array& a = */ a,
1264
+ /* const array& b = */ b,
1265
+ /* array& out = */ out,
1266
+ /* int M = */ M,
1267
+ /* int N = */ N,
1268
+ /* int K = */ K,
1269
+ /* int batch_size_out = */ batch_size_out,
1270
+ /* int lda = */ a_cols,
1271
+ /* int ldb = */ b_cols,
1272
+ /* bool transpose_a = */ a_transposed,
1273
+ /* bool transpose_b = */ b_transposed,
1274
+ /* std::vector<array>& copies = */ copies,
1275
+ /* Shape batch_shape = */ std::move(batch_shape),
1276
+ /* Strides A_batch_stride = */ std::move(A_batch_stride),
1277
+ /* Strides B_batch_stride = */ std::move(B_batch_stride));
1278
+ }
1279
+
1280
+ /////////////////////////////////////////////////////////////////////////////
1281
+ // Gemm specialization
1282
+
1283
+ return steel_matmul(
1284
+ /* const Stream& s = */ s,
1285
+ /* metal::Device& d = */ d,
1286
+ /* const array& a = */ a,
1287
+ /* const array& b = */ b,
1288
+ /* array& out = */ out,
1289
+ /* int M = */ M,
1290
+ /* int N = */ N,
1291
+ /* int K = */ K,
1292
+ /* int batch_size_out = */ batch_size_out,
1293
+ /* int lda = */ a_cols,
1294
+ /* int ldb = */ b_cols,
1295
+ /* bool transpose_a = */ a_transposed,
1296
+ /* bool transpose_b = */ b_transposed,
1297
+ /* std::vector<array>& copies = */ copies,
1298
+ /* Shape batch_shape = */ std::move(batch_shape),
1299
+ /* Strides A_batch_stride = */ std::move(A_batch_stride),
1300
+ /* Strides B_batch_stride = */ std::move(B_batch_stride));
1301
+ }
1302
+
1303
+ ///////////////////////////////////////////////////////////////////////////////
1304
+ // AddMM implementation
1305
+ ///////////////////////////////////////////////////////////////////////////////
1306
+
1307
+ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
1308
+ assert(inputs.size() == 3);
1309
+ if (!issubdtype(out.dtype(), floating)) {
1310
+ throw std::runtime_error(
1311
+ "[matmul] Does not yet support non-floating point types.");
1312
+ }
1313
+
1314
+ // Return early if the output is empty
1315
+ if (out.size() == 0) {
1316
+ out.set_data(allocator::malloc(out.nbytes()));
1317
+ return;
1318
+ }
1319
+
1320
+ auto& s = stream();
1321
+ auto& d = metal::device(s.device);
1322
+
1323
+ // Handle empty matrix case (K=0)
1324
+ if (inputs[0].shape(-1) == 0) {
1325
+ auto& c = inputs[2];
1326
+ if (beta_ == 1.0f) {
1327
+ copy_gpu(
1328
+ c,
1329
+ out,
1330
+ c.flags().row_contiguous ? CopyType::Vector : CopyType::General,
1331
+ s);
1332
+ } else {
1333
+ array beta_scalar = array(beta_, c.dtype());
1334
+ binary_op_gpu({c, beta_scalar}, out, "Multiply", s);
1335
+ d.add_temporary(std::move(beta_scalar), s.index);
1336
+ }
1337
+ return;
1338
+ }
1339
+
1340
+ out.set_data(allocator::malloc(out.nbytes()));
1341
+
1342
+ auto& a_pre = inputs[0];
1343
+ auto& b_pre = inputs[1];
1344
+ auto& c_pre = inputs[2];
1345
+
1346
+ /////////////////////////////////////////////////////////////////////////////
1347
+ // Init checks and prep
1348
+
1349
+ int M = a_pre.shape(-2);
1350
+ int N = b_pre.shape(-1);
1351
+ int K = a_pre.shape(-1);
1352
+
1353
+ // Keep a vector with copies to be cleared in the completed buffer to release
1354
+ // the arrays
1355
+ std::vector<array> copies;
1356
+ auto [transpose_a, a_cols, a] = check_transpose(copies, s, a_pre, M == 1);
1357
+ auto [transpose_b, b_cols, b] = check_transpose(copies, s, b_pre, N == 1);
1358
+
1359
+ array c = c_pre;
1360
+
1361
+ int lda = a_cols;
1362
+ int ldb = b_cols;
1363
+
1364
+ /////////////////////////////////////////////////////////////////////////////
1365
+ // Check and collapse batch dimensions
1366
+ auto [batch_shape, A_batch_stride, B_batch_stride, C_batch_stride] =
1367
+ collapse_batches(a, b, c);
1368
+
1369
+ int64_t matrix_stride_out = M * static_cast<int64_t>(N);
1370
+ auto batch_size_out = out.size() / (matrix_stride_out);
1371
+
1372
+ // Collapse batches into M if needed
1373
+ if (batch_size_out > 1 && !transpose_a && batch_shape.size() == 1 &&
1374
+ a.strides()[a.ndim() - 2] == K && A_batch_stride.back() == M * K &&
1375
+ C_batch_stride.back() == M * c.strides()[c.ndim() - 2] &&
1376
+ B_batch_stride.back() == 0) {
1377
+ M *= batch_shape.back();
1378
+ batch_size_out = 1;
1379
+
1380
+ A_batch_stride = {0};
1381
+ B_batch_stride = {0};
1382
+ C_batch_stride = {0};
1383
+ batch_shape = {1};
1384
+ }
1385
+
1386
+ /////////////////////////////////////////////////////////////////////////////
1387
+ // Gemv specialization
1388
+
1389
+ // Route to gemv if needed
1390
+ if (std::min(M, N) == 1) {
1391
+ return gemv_axbpy(
1392
+ /* const Stream& s = */ s,
1393
+ /* metal::Device& d = */ d,
1394
+ /* const array& a = */ a,
1395
+ /* const array& b = */ b,
1396
+ /* const array& c = */ c,
1397
+ /* array& out = */ out,
1398
+ /* int M = */ M,
1399
+ /* int N = */ N,
1400
+ /* int K = */ K,
1401
+ /* int batch_size_out = */ batch_size_out,
1402
+ /* int lda = */ lda,
1403
+ /* int ldb = */ ldb,
1404
+ /* bool transpose_a = */ transpose_a,
1405
+ /* bool transpose_b = */ transpose_b,
1406
+ /* std::vector<array>& copies = */ copies,
1407
+ /* Shape batch_shape = */ batch_shape,
1408
+ /* Strides A_batch_stride = */ A_batch_stride,
1409
+ /* Strides B_batch_stride = */ B_batch_stride,
1410
+ /* Strides C_batch_stride = */ C_batch_stride,
1411
+ /* float alpha = */ alpha_,
1412
+ /* float beta = */ beta_);
1413
+ }
1414
+
1415
+ /////////////////////////////////////////////////////////////////////////////
1416
+ // Regular addmm dispatch
1417
+
1418
+ return steel_matmul_axpby(
1419
+ /* const Stream& s = */ s,
1420
+ /* metal::Device& d = */ d,
1421
+ /* const array& a = */ a,
1422
+ /* const array& b = */ b,
1423
+ /* const array& c = */ c,
1424
+ /* array& out = */ out,
1425
+ /* int M = */ M,
1426
+ /* int N = */ N,
1427
+ /* int K = */ K,
1428
+ /* int batch_size_out = */ batch_size_out,
1429
+ /* int lda = */ lda,
1430
+ /* int ldb = */ ldb,
1431
+ /* bool transpose_a = */ transpose_a,
1432
+ /* bool transpose_b = */ transpose_b,
1433
+ /* std::vector<array>& copies = */ copies,
1434
+ /* Shape batch_shape = */ batch_shape,
1435
+ /* Strides A_batch_stride = */ A_batch_stride,
1436
+ /* Strides B_batch_stride = */ B_batch_stride,
1437
+ /* Strides C_batch_stride = */ C_batch_stride,
1438
+ /* float alpha = */ alpha_,
1439
+ /* float beta = */ beta_);
1440
+ }
1441
+
1442
+ ///////////////////////////////////////////////////////////////////////////////
1443
+ // BlockMaskedMM implementation
1444
+ ///////////////////////////////////////////////////////////////////////////////
1445
+
1446
+ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
1447
+ using namespace mlx::steel;
1448
+ // assert(inputs.size() == 2);
1449
+ if (!issubdtype(out.dtype(), floating)) {
1450
+ throw std::runtime_error(
1451
+ "[matmul] Does not yet support non-floating point types.");
1452
+ }
1453
+ auto& s = stream();
1454
+ auto& d = metal::device(s.device);
1455
+
1456
+ auto& a_pre = inputs[0];
1457
+ auto& b_pre = inputs[1];
1458
+ // Return 0s if either input is empty
1459
+ if (a_pre.size() == 0 || b_pre.size() == 0) {
1460
+ array zero = array(0, a_pre.dtype());
1461
+ fill_gpu(zero, out, s);
1462
+ d.add_temporary(std::move(zero), s.index);
1463
+ return;
1464
+ }
1465
+
1466
+ out.set_data(allocator::malloc(out.nbytes()));
1467
+
1468
+ /////////////////////////////////////////////////////////////////////////////
1469
+ // Init checks and prep
1470
+
1471
+ int M = a_pre.shape(-2);
1472
+ int N = b_pre.shape(-1);
1473
+ int K = a_pre.shape(-1);
1474
+
1475
+ // Keep a vector with copies to be cleared in the completed buffer to release
1476
+ // the arrays
1477
+ std::vector<array> copies;
1478
+ auto [transpose_a, a_cols, a] = check_transpose(copies, s, a_pre, M == 1);
1479
+ auto [transpose_b, b_cols, b] = check_transpose(copies, s, b_pre, N == 1);
1480
+
1481
+ int lda = a_cols;
1482
+ int ldb = b_cols;
1483
+
1484
+ /////////////////////////////////////////////////////////////////////////////
1485
+ // Check and collapse batch dimensions
1486
+
1487
+ bool has_op_mask = inputs.size() > 3;
1488
+ bool has_out_mask = inputs.size() == 3 || inputs.size() == 5;
1489
+
1490
+ // Prepare kernel name
1491
+ std::string out_mask_nm = has_out_mask ? type_to_name(inputs[2]) : "nomask";
1492
+ std::string op_mask_nm = has_op_mask ? type_to_name(inputs.back()) : "nomask";
1493
+
1494
+ Shape batch_shape{1};
1495
+ Strides A_batch_stride{0};
1496
+ Strides B_batch_stride{0};
1497
+ Strides outmask_bstride{0};
1498
+ Strides Amask_bstride{0};
1499
+ Strides Bmask_bstride{0};
1500
+ int64_t A_batch_str = 0;
1501
+ int64_t B_batch_str = 0;
1502
+
1503
+ Strides batch_strides;
1504
+
1505
+ if (out.ndim() > 2) {
1506
+ Shape bshape{out.shape().begin(), out.shape().end() - 2};
1507
+ std::vector<Strides> bstrides;
1508
+
1509
+ for (auto& arr : inputs) {
1510
+ bstrides.emplace_back(arr.strides().begin(), arr.strides().end() - 2);
1511
+ }
1512
+
1513
+ // auto [bshape_c, bstrides_c] = collapse_contiguous_dims(bshape, bstrides);
1514
+ batch_shape = bshape;
1515
+ A_batch_str = bstrides[0].back();
1516
+ B_batch_str = bstrides[1].back();
1517
+
1518
+ for (auto& bstr : bstrides) {
1519
+ batch_strides.insert(batch_strides.end(), bstr.begin(), bstr.end());
1520
+ }
1521
+
1522
+ A_batch_stride = bstrides[0];
1523
+ B_batch_stride = bstrides[1];
1524
+
1525
+ if (has_out_mask) {
1526
+ outmask_bstride = bstrides[2];
1527
+ }
1528
+ if (has_op_mask) {
1529
+ Amask_bstride = bstrides[has_out_mask + 2];
1530
+ Bmask_bstride = bstrides[has_out_mask + 3];
1531
+ }
1532
+
1533
+ } else {
1534
+ batch_strides = Strides(inputs.size(), 0);
1535
+ }
1536
+
1537
+ int64_t matrix_stride_out = static_cast<int64_t>(M) * N;
1538
+ size_t batch_size_out = out.size() / (matrix_stride_out);
1539
+
1540
+ /////////////////////////////////////////////////////////////////////////////
1541
+ // Gemv specialization
1542
+
1543
+ // Route to gemv if needed
1544
+ if (std::min(M, N) == 1) {
1545
+ // Collect problem info
1546
+ bool is_b_matrix = N != 1;
1547
+
1548
+ auto& mat = is_b_matrix ? b : a;
1549
+ auto& vec = is_b_matrix ? a : b;
1550
+ bool transpose_mat = is_b_matrix ? !transpose_b : transpose_a;
1551
+ int in_vector_len = K;
1552
+ int out_vector_len = is_b_matrix ? N : M;
1553
+
1554
+ int mat_ld = is_b_matrix ? b_cols : a_cols;
1555
+
1556
+ auto batch_strides_mat = is_b_matrix ? B_batch_stride : A_batch_stride;
1557
+ auto batch_strides_vec = is_b_matrix ? A_batch_stride : B_batch_stride;
1558
+
1559
+ auto mask_bstrides_mat = is_b_matrix ? Bmask_bstride : Amask_bstride;
1560
+ auto mask_bstrides_vec = is_b_matrix ? Amask_bstride : Bmask_bstride;
1561
+
1562
+ auto mat_mask_idx = int(has_out_mask) + (is_b_matrix ? 3 : 2);
1563
+ auto vec_mask_idx = int(has_out_mask) + (is_b_matrix ? 2 : 3);
1564
+
1565
+ // Determine if inputs have simple batching / broadcasting
1566
+ bool contiguous_kernel = (batch_shape.size() == 1);
1567
+
1568
+ int batch_ndim = batch_shape.size();
1569
+
1570
+ // Determine dispatch kernel
1571
+ int tm = 4, tn = 4;
1572
+ int sm = 1, sn = 32;
1573
+ int bm = 1, bn = 1;
1574
+ int n_out_per_tgp;
1575
+ std::ostringstream kname;
1576
+
1577
+ if (transpose_mat) {
1578
+ sm = 8;
1579
+ sn = 4;
1580
+ bm = 1;
1581
+ bn = (block_size_ == 64 && out_vector_len >= 2048) ? 4 : 2;
1582
+ tm = block_size_ == 32 ? 4 : 8;
1583
+ tn = 4;
1584
+
1585
+ // Specialized kernel for very small outputs
1586
+ tn = out_vector_len < tn ? 1 : tn;
1587
+
1588
+ n_out_per_tgp = bn * sn * tn;
1589
+ kname << "gemv_t";
1590
+
1591
+ } else {
1592
+ if (block_size_ == 32) {
1593
+ sm = 4;
1594
+ sn = 8;
1595
+ bm = 2;
1596
+ } else {
1597
+ sm = 2;
1598
+ sn = 16;
1599
+ bm = out_vector_len >= 512 ? 4 : 2;
1600
+ }
1601
+
1602
+ // Specialized kernel for very small outputs
1603
+ tm = out_vector_len < tm ? 1 : tm;
1604
+
1605
+ n_out_per_tgp = bm * sm * tm;
1606
+ kname << "gemv";
1607
+ }
1608
+
1609
+ kname << "_outmask_" << out_mask_nm;
1610
+ kname << "_opmask_" << op_mask_nm;
1611
+ kname << "_" << type_to_name(out);
1612
+ kname << "_bm" << bm << "_bn" << bn;
1613
+ kname << "_sm" << sm << "_sn" << sn;
1614
+ kname << "_tm" << tm << "_tn" << tn;
1615
+ kname << "_nc" << !contiguous_kernel;
1616
+
1617
+ // Encode and dispatch kernel
1618
+ auto kernel = get_gemv_masked_kernel(
1619
+ d,
1620
+ kname.str(),
1621
+ out,
1622
+ has_out_mask ? std::optional<array>{inputs[2]} : std::nullopt,
1623
+ has_op_mask ? std::optional<array>{inputs.back()} : std::nullopt,
1624
+ transpose_mat,
1625
+ bm,
1626
+ bn,
1627
+ sm,
1628
+ sn,
1629
+ tm,
1630
+ tn,
1631
+ contiguous_kernel);
1632
+
1633
+ auto& compute_encoder = d.get_command_encoder(s.index);
1634
+ compute_encoder.set_compute_pipeline_state(kernel);
1635
+
1636
+ int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp;
1637
+ MTL::Size group_dims = MTL::Size(32, bn, bm);
1638
+ MTL::Size grid_dims = MTL::Size(n_tgp, 1, batch_size_out);
1639
+
1640
+ // Get mask params
1641
+ std::vector<int> mask_strides;
1642
+ Strides mask_batch_strides;
1643
+ if (has_out_mask) {
1644
+ auto& out_mask = inputs[2];
1645
+
1646
+ if (transpose_mat) {
1647
+ mask_strides.push_back(out_mask.strides(out.shape(-2) == 1 ? -1 : -2));
1648
+ mask_strides.push_back(out_mask.strides(out.shape(-2) == 1 ? -2 : -1));
1649
+ } else {
1650
+ mask_strides.push_back(out_mask.strides(out.shape(-1) == 1 ? -1 : -2));
1651
+ mask_strides.push_back(out_mask.strides(out.shape(-1) == 1 ? -2 : -1));
1652
+ }
1653
+
1654
+ mask_batch_strides.insert(
1655
+ mask_batch_strides.end(),
1656
+ outmask_bstride.begin(),
1657
+ outmask_bstride.end());
1658
+
1659
+ compute_encoder.set_input_array(out_mask, 20);
1660
+ }
1661
+
1662
+ if (has_op_mask) {
1663
+ auto& mat_mask = inputs[mat_mask_idx];
1664
+
1665
+ if (transpose_mat) {
1666
+ mask_strides.push_back(mat_mask.strides(!is_b_matrix ? -2 : -1));
1667
+ mask_strides.push_back(mat_mask.strides(!is_b_matrix ? -1 : -2));
1668
+ } else {
1669
+ mask_strides.push_back(mat_mask.strides(is_b_matrix ? -2 : -1));
1670
+ mask_strides.push_back(mat_mask.strides(is_b_matrix ? -1 : -2));
1671
+ }
1672
+
1673
+ mask_batch_strides.insert(
1674
+ mask_batch_strides.end(),
1675
+ mask_bstrides_mat.begin(),
1676
+ mask_bstrides_mat.end());
1677
+
1678
+ compute_encoder.set_input_array(mat_mask, 21);
1679
+
1680
+ auto& vec_mask = inputs[vec_mask_idx];
1681
+ if (transpose_mat) {
1682
+ mask_strides.push_back(vec_mask.strides(vec.shape(-2) == 1 ? -1 : -2));
1683
+ mask_strides.push_back(vec_mask.strides(vec.shape(-2) == 1 ? -2 : -1));
1684
+ } else {
1685
+ mask_strides.push_back(vec_mask.strides(vec.shape(-1) == 1 ? -1 : -2));
1686
+ mask_strides.push_back(vec_mask.strides(vec.shape(-1) == 1 ? -2 : -1));
1687
+ }
1688
+
1689
+ mask_batch_strides.insert(
1690
+ mask_batch_strides.end(),
1691
+ mask_bstrides_vec.begin(),
1692
+ mask_bstrides_vec.end());
1693
+
1694
+ compute_encoder.set_input_array(vec_mask, 22);
1695
+ }
1696
+
1697
+ // Get gemv params
1698
+ compute_encoder.set_input_array(mat, 0);
1699
+ compute_encoder.set_input_array(vec, 1);
1700
+ compute_encoder.set_output_array(out, 3);
1701
+
1702
+ compute_encoder.set_bytes(in_vector_len, 4);
1703
+ compute_encoder.set_bytes(out_vector_len, 5);
1704
+ compute_encoder.set_bytes(mat_ld, 6);
1705
+ compute_encoder.set_bytes(batch_ndim, 9);
1706
+ compute_encoder.set_vector_bytes(batch_shape, 10);
1707
+ compute_encoder.set_vector_bytes(batch_strides_vec, 11);
1708
+ compute_encoder.set_vector_bytes(batch_strides_mat, 12);
1709
+
1710
+ compute_encoder.set_vector_bytes(mask_strides, 23);
1711
+ compute_encoder.set_vector_bytes(mask_batch_strides, 24);
1712
+
1713
+ compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
1714
+
1715
+ d.add_temporaries(std::move(copies), s.index);
1716
+ return;
1717
+ }
1718
+
1719
+ /////////////////////////////////////////////////////////////////////////////
1720
+ // Regular kernel dispatch
1721
+
1722
+ // Determine dispatch kernel
1723
+ int bm = block_size_, bn = block_size_, bk = 16;
1724
+ int wm = 2, wn = 2;
1725
+ bool mn_aligned = M % bm == 0 && N % bn == 0;
1726
+ bool k_aligned = K % bk == 0;
1727
+
1728
+ std::ostringstream kname;
1729
+ kname << "steel_gemm_block_outmask_" << out_mask_nm << "_opmask_"
1730
+ << op_mask_nm << "_" << (transpose_a ? 't' : 'n')
1731
+ << (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_"
1732
+ << type_to_name(out) << "_bm" << bm << "_bn" << bn << "_bk" << bk
1733
+ << "_wm" << wm << "_wn" << wn << "_MN_" << (mn_aligned ? "t" : "n")
1734
+ << "aligned"
1735
+ << "_K_" << (k_aligned ? "t" : "n") << "aligned";
1736
+
1737
+ // Encode and dispatch kernel
1738
+ auto& compute_encoder = d.get_command_encoder(s.index);
1739
+ auto kernel = get_steel_gemm_masked_kernel(
1740
+ d,
1741
+ kname.str(),
1742
+ out,
1743
+ has_out_mask ? std::optional<array>{inputs[2]} : std::nullopt,
1744
+ has_op_mask ? std::optional<array>{inputs.back()} : std::nullopt,
1745
+ transpose_a,
1746
+ transpose_b,
1747
+ bm,
1748
+ bn,
1749
+ bk,
1750
+ wm,
1751
+ wn,
1752
+ mn_aligned,
1753
+ k_aligned);
1754
+ compute_encoder.set_compute_pipeline_state(kernel);
1755
+
1756
+ // Use problem size to determine threadblock swizzle
1757
+ int tn = (N + bn - 1) / bn;
1758
+ int tm = (M + bm - 1) / bm;
1759
+
1760
+ // TODO: Explore device-based tuning for swizzle
1761
+ int swizzle_log = 0; // tm >= 6 ? 3 : (tm <= 3 ? 0 : 2);
1762
+
1763
+ // Prepare steel matmul params
1764
+ GEMMParams params{/* const int M = */ M,
1765
+ /* const int N = */ N,
1766
+ /* const int K = */ K,
1767
+ /* const int lda = */ lda,
1768
+ /* const int ldb = */ ldb,
1769
+ /* const int ldd = */ N,
1770
+ /* const int tiles_n = */ tn,
1771
+ /* const int tiles_m = */ tm,
1772
+ /* const int64_t batch_stride_a = */ A_batch_str,
1773
+ /* const int64_t batch_stride_b = */ B_batch_str,
1774
+ /* const int64_t batch_stride_d = */ matrix_stride_out,
1775
+ /* const int swizzle_log = */ swizzle_log,
1776
+ /* const int gemm_k_iterations_aligned = */ (K / bk),
1777
+ /* const int batch_ndim = */ int(batch_shape.size())};
1778
+
1779
+ // Prepare launch grid params
1780
+ int tile = 1 << swizzle_log;
1781
+ tm = (tm + tile - 1) / tile;
1782
+ tn = tn * tile;
1783
+
1784
+ MTL::Size group_dims = MTL::Size(32, wn, wm);
1785
+ MTL::Size grid_dims = MTL::Size(tn, tm, batch_size_out);
1786
+
1787
+ std::vector<int> mask_strides;
1788
+
1789
+ if (has_out_mask) {
1790
+ auto& out_mask = inputs[2];
1791
+ mask_strides.push_back(*(out_mask.strides().end() - 1));
1792
+ mask_strides.push_back(*(out_mask.strides().end() - 2));
1793
+
1794
+ compute_encoder.set_input_array(out_mask, 10);
1795
+ }
1796
+
1797
+ if (has_op_mask) {
1798
+ auto& lhs_mask = inputs[2 + has_out_mask];
1799
+ mask_strides.push_back(*(lhs_mask.strides().end() - 1));
1800
+ mask_strides.push_back(*(lhs_mask.strides().end() - 2));
1801
+
1802
+ compute_encoder.set_input_array(lhs_mask, 11);
1803
+
1804
+ auto& rhs_mask = inputs[3 + has_out_mask];
1805
+ mask_strides.push_back(*(rhs_mask.strides().end() - 1));
1806
+ mask_strides.push_back(*(rhs_mask.strides().end() - 2));
1807
+
1808
+ compute_encoder.set_input_array(rhs_mask, 12);
1809
+ }
1810
+
1811
+ // Launch kernel
1812
+ compute_encoder.set_input_array(a, 0);
1813
+ compute_encoder.set_input_array(b, 1);
1814
+ compute_encoder.set_output_array(out, 3);
1815
+
1816
+ compute_encoder.set_bytes(params, 4);
1817
+
1818
+ compute_encoder.set_vector_bytes(batch_shape, 6);
1819
+ compute_encoder.set_vector_bytes(batch_strides, 7);
1820
+
1821
+ compute_encoder.set_vector_bytes(mask_strides, 13);
1822
+
1823
+ compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
1824
+
1825
+ d.add_temporaries(std::move(copies), s.index);
1826
+ }
1827
+
1828
+ ///////////////////////////////////////////////////////////////////////////////
1829
+ // GatherMM implementation
1830
+ ///////////////////////////////////////////////////////////////////////////////
1831
+
1832
+ void gather_mm_rhs(
+ const array& a_,
+ const array& b_,
+ const array& indices_,
+ array& out,
+ metal::Device& d,
+ const Stream& s) {
+ array indices = ensure_row_contiguous(indices_, d, s);
+ auto [transpose_b, ldb, b] = ensure_batch_contiguous(b_, d, s);
+
+ // Broadcast a with indices. If we are here that means lhs_indices were not
+ // provided so the lhs_indices are implied to be the shape of a broadcasted
+ // with rhs_indices. We need only broadcast a and copy it as if applying the
+ // lhs_indices.
+ auto broadcast_with_indices = [&d, &s, &indices](const array& x) {
+ if (x.size() / x.shape(-2) / x.shape(-1) == indices.size()) {
+ return ensure_row_contiguous(x, d, s);
+ }
+
+ auto x_shape = indices.shape();
+ x_shape.push_back(x.shape(-2));
+ x_shape.push_back(x.shape(-1));
+ array new_x(std::move(x_shape), x.dtype(), nullptr, {});
+ broadcast(x, new_x);
+ return ensure_row_contiguous(new_x, d, s);
+ };
+ array a = broadcast_with_indices(a_);
+
+ // Extract the matmul shapes
+ int K = a.shape(-1);
+ int M = a.size() / K;
+ int N = b.shape(-1);
+ int lda = a.strides()[a.ndim() - 2]; // should be K
+
+ // Define the dispatch blocks
+ int bm = 16, bn = 64, bk = 16;
+ int wm = 1, wn = 2;
+
+ const bool align_M = (M % bm) == 0;
+ const bool align_N = (N % bn) == 0;
+ const bool align_K = (K % bk) == 0;
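+ // The alignment flags become Metal function constants (indices 200-202
+ // below), so a separate pipeline variant is compiled per alignment combination.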
+
+ // Define the kernel name
+ std::string base_name;
+ base_name.reserve(64);
+ concatenate(
+ base_name,
+ "steel_gather_mm_rhs_n",
+ transpose_b ? 't' : 'n',
+ '_',
+ type_to_name(a),
+ '_',
+ type_to_name(out),
+ "_bm",
+ bm,
+ "_bn",
+ bn,
+ "_bk",
+ bk,
+ "_wm",
+ wm,
+ "_wn",
+ wn);
+
+ metal::MTLFCList func_consts = {
+ {&align_M, MTL::DataType::DataTypeBool, 200},
+ {&align_N, MTL::DataType::DataTypeBool, 201},
+ {&align_K, MTL::DataType::DataTypeBool, 202},
+ };
+
+ // And the kernel hash that includes the function constants
+ std::string hash_name;
+ hash_name.reserve(128);
+ concatenate(
+ hash_name,
+ base_name,
+ "_align_M_",
+ align_M ? 't' : 'n',
+ "_align_N_",
+ align_N ? 't' : 'n',
+ "_align_K_",
+ align_K ? 't' : 'n');
+
+ // Get and set the kernel
+ auto& compute_encoder = d.get_command_encoder(s.index);
+ auto kernel = get_steel_gemm_gather_kernel(
+ d,
+ base_name,
+ hash_name,
+ func_consts,
+ out,
+ false,
+ transpose_b,
+ bm,
+ bn,
+ bk,
+ wm,
+ wn,
+ true);
+ compute_encoder.set_compute_pipeline_state(kernel);
+
+ // Prepare the matmul params
+ auto batch_stride_b = b.ndim() > 2 ? b.strides()[b.ndim() - 3] : b.size();
+ steel::GEMMParams params{
+ /* const int M = */ M,
+ /* const int N = */ N,
+ /* const int K = */ K,
+ /* const int lda = */ lda,
+ /* const int ldb = */ static_cast<int>(ldb),
+ /* const int ldd = */ N,
+ /* const int tiles_n = */ (N + bn - 1) / bn,
+ /* const int tiles_m = */ (M + bm - 1) / bm,
+ /* const int64_t batch_stride_a = */ 0,
+ /* const int64_t batch_stride_b = */ static_cast<int64_t>(batch_stride_b),
+ /* const int64_t batch_stride_d = */ 0,
+ /* const int swizzle_log = */ 0,
+ /* const int gemm_k_iterations_aligned = */ (K / bk),
+ /* const int batch_ndim = */ 0};
+
+ // Prepare the grid
+ MTL::Size group_dims = MTL::Size(32, wn, wm);
+ MTL::Size grid_dims = MTL::Size(params.tiles_n, params.tiles_m, 1);
+
+ // Launch kernel
+ compute_encoder.set_input_array(a, 0);
+ compute_encoder.set_input_array(b, 1);
+ compute_encoder.set_input_array(indices, 2);
+ compute_encoder.set_output_array(out, 3);
+ compute_encoder.set_bytes(params, 4);
+
+ compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
+ }
+
+ void gather_mm_rhs_nax(
+ const array& a_,
+ const array& b_,
+ const array& indices_,
+ array& out,
+ metal::Device& d,
+ const Stream& s) {
+ array indices = ensure_row_contiguous(indices_, d, s);
+ auto [transpose_b, ldb, b] = ensure_batch_contiguous(b_, d, s);
+
+ // Broadcast a with indices. If we are here that means lhs_indices were not
+ // provided so the lhs_indices are implied to be the shape of a broadcasted
+ // with rhs_indices. We need only broadcast a and copy it as if applying the
+ // lhs_indices.
+ auto broadcast_with_indices = [&d, &s, &indices](const array& x) {
+ if (x.size() / x.shape(-2) / x.shape(-1) == indices.size()) {
+ return ensure_row_contiguous(x, d, s);
+ }
+
+ auto x_shape = indices.shape();
+ x_shape.push_back(x.shape(-2));
+ x_shape.push_back(x.shape(-1));
+ array new_x(std::move(x_shape), x.dtype(), nullptr, {});
+ broadcast(x, new_x);
+ return ensure_row_contiguous(new_x, d, s);
+ };
+ array a = broadcast_with_indices(a_);
+
+ // Extract the matmul shapes
+ int K = a.shape(-1);
+ int M = a.size() / K;
+ int N = b.shape(-1);
+ int lda = a.strides()[a.ndim() - 2]; // should be K
+ int E = b.shape(0);
+
+ // Define the dispatch blocks
+ int bm, bn = 128, bk = 128, wm, wn = 4;
+ if (M / E > 48) {
+ bm = 64;
+ wm = 2;
+ } else if (M / E > 24) {
+ bm = 32l;
+ wm = 1;
+ } else {
+ bm = 16;
+ wm = 1;
+ }
+
+ const bool align_M = (M % bm) == 0;
+ const bool align_N = (N % bn) == 0;
+ const bool align_K = (K % bk) == 0;
+
+ // Define the kernel name
+ std::string base_name;
+ base_name.reserve(64);
+ concatenate(
+ base_name,
+ "steel_gather_mm_rhs_nax_n",
+ transpose_b ? 't' : 'n',
+ '_',
+ type_to_name(a),
+ '_',
+ type_to_name(out),
+ "_bm",
+ bm,
+ "_bn",
+ bn,
+ "_bk",
+ bk,
+ "_wm",
+ wm,
+ "_wn",
+ wn);
+
+ metal::MTLFCList func_consts = {
+ {&align_M, MTL::DataType::DataTypeBool, 200},
+ {&align_N, MTL::DataType::DataTypeBool, 201},
+ {&align_K, MTL::DataType::DataTypeBool, 202},
+ };
+
+ // And the kernel hash that includes the function constants
+ std::string hash_name;
+ hash_name.reserve(128);
+ concatenate(
+ hash_name,
+ base_name,
+ "_align_M_",
+ align_M ? 't' : 'n',
+ "_align_N_",
+ align_N ? 't' : 'n',
+ "_align_K_",
+ align_K ? 't' : 'n');
+
+ // Get and set the kernel
+ auto& compute_encoder = d.get_command_encoder(s.index);
+ auto kernel = get_steel_gemm_gather_nax_kernel(
+ d,
+ base_name,
+ hash_name,
+ func_consts,
+ out,
+ false,
+ transpose_b,
+ bm,
+ bn,
+ bk,
+ wm,
+ wn,
+ true);
+ compute_encoder.set_compute_pipeline_state(kernel);
+
+ // Prepare the matmul params
+ auto batch_stride_b = b.ndim() > 2 ? b.strides()[b.ndim() - 3] : b.size();
+ steel::GEMMParams params{
+ /* const int M = */ M,
+ /* const int N = */ N,
+ /* const int K = */ K,
+ /* const int lda = */ lda,
+ /* const int ldb = */ static_cast<int>(ldb),
+ /* const int ldd = */ N,
+ /* const int tiles_n = */ (N + bn - 1) / bn,
+ /* const int tiles_m = */ (M + bm - 1) / bm,
+ /* const int64_t batch_stride_a = */ 0,
+ /* const int64_t batch_stride_b = */ static_cast<int64_t>(batch_stride_b),
+ /* const int64_t batch_stride_d = */ 0,
+ /* const int swizzle_log = */ 0,
+ /* const int gemm_k_iterations_aligned = */ (K / bk),
+ /* const int batch_ndim = */ 0};
+
+ // Prepare the grid
+ MTL::Size group_dims = MTL::Size(32, wn, wm);
+ MTL::Size grid_dims = MTL::Size(params.tiles_n, params.tiles_m, 1);
+
+ // Launch kernel
+ compute_encoder.set_input_array(a, 0);
+ compute_encoder.set_input_array(b, 1);
+ compute_encoder.set_input_array(indices, 2);
+ compute_encoder.set_output_array(out, 3);
+ compute_encoder.set_bytes(params, 4);
+
+ compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
+ }
+
+ void gather_mv(
+ const array& mat_,
+ const array& vec_,
+ const array& mat_indices_,
+ const array& vec_indices_,
+ array& out,
+ int N,
+ int K,
+ bool is_mv,
+ metal::Device& d,
+ const Stream& s) {
+ // Copy if needed
+ std::vector<array> copies;
+ auto [transpose_mat, mat_cols, mat] =
+ check_transpose(copies, s, mat_, N == 1);
+ auto [transpose_vec, vec_cols, vec] = check_transpose(copies, s, vec_, true);
+ d.add_temporaries(std::move(copies), s.index);
+
+ // If we are doing vector matrix instead of matrix vector we need to flip the
+ // matrix transposition. Basically m @ v = v @ m.T assuming that v is treated
+ // as a one dimensional array.
+ transpose_mat = (!is_mv) ^ transpose_mat;
+
+ // Define some shapes
+ int in_vector_len = K;
+ int out_vector_len = N;
+ int mat_ld = mat_cols;
+
+ int batch_size_out = out.size() / N;
+ int batch_ndim = out.ndim() - 2;
+ int batch_ndim_mat = mat.ndim() - 2;
+ int batch_ndim_vec = vec.ndim() - 2;
+ Strides index_strides = vec_indices_.strides();
+ index_strides.insert(
+ index_strides.end(),
+ mat_indices_.strides().begin(),
+ mat_indices_.strides().end());
+
+ // Determine dispatch kernel
+ int tm = 4, tn = 4;
+ int sm = 1, sn = 32;
+ int bm = 1, bn = 1;
+ int n_out_per_tgp;
+ std::ostringstream kname;
+
+ if (transpose_mat) {
+ if (in_vector_len >= 8192 && out_vector_len >= 2048) {
+ sm = 4;
+ sn = 8;
+ } else {
+ sm = 8;
+ sn = 4;
+ }
+
+ if (out_vector_len >= 2048) {
+ bn = 16;
+ } else if (out_vector_len >= 512) {
+ bn = 4;
+ } else {
+ bn = 2;
+ }
+
+ // Specialized kernel for very small outputs
+ tn = out_vector_len < tn ? 1 : tn;
+
+ n_out_per_tgp = bn * sn * tn;
+ kname << "gemv_t_gather_" << type_to_name(out);
+
+ } else {
+ bm = out_vector_len >= 4096 ? 8 : 4;
+ sn = 32;
+
+ // Specialized kernel for very small outputs
+ tm = out_vector_len < tm ? 1 : tm;
+
+ n_out_per_tgp = bm * sm * tm;
+ kname << "gemv_gather_" << type_to_name(out);
+ }
+
+ kname << "_bm" << bm << "_bn" << bn << "_sm" << sm << "_sn" << sn << "_tm"
+ << tm << "_tn" << tn;
+
+ // Encode and dispatch kernel
+ auto& compute_encoder = d.get_command_encoder(s.index);
+ auto kernel = d.get_kernel(kname.str());
+ compute_encoder.set_compute_pipeline_state(kernel);
+
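+ // Each threadgroup produces n_out_per_tgp output elements, so the grid needs
+ // ceil(out_vector_len / n_out_per_tgp) threadgroups along x.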
+ int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp;
+ MTL::Size group_dims = MTL::Size(32, bn, bm);
+ MTL::Size grid_dims = MTL::Size(n_tgp, 1, batch_size_out);
+
+ compute_encoder.set_input_array(mat, 0);
+ compute_encoder.set_input_array(vec, 1);
+ compute_encoder.set_output_array(out, 3);
+
+ compute_encoder.set_bytes(in_vector_len, 4);
+ compute_encoder.set_bytes(out_vector_len, 5);
+ compute_encoder.set_bytes(mat_ld, 6);
+
+ compute_encoder.set_bytes(batch_ndim, 9);
+ compute_encoder.set_vector_bytes(out.shape(), 10);
+ compute_encoder.set_vector_bytes(index_strides, 11);
+
+ compute_encoder.set_bytes(batch_ndim_vec, 12);
+ compute_encoder.set_vector_bytes(vec.shape(), 13);
+ compute_encoder.set_vector_bytes(vec.strides(), 14);
+
+ compute_encoder.set_bytes(batch_ndim_mat, 15);
+ compute_encoder.set_vector_bytes(mat.shape(), 16);
+ compute_encoder.set_vector_bytes(mat.strides(), 17);
+
+ compute_encoder.set_input_array(vec_indices_, 18);
+ compute_encoder.set_input_array(mat_indices_, 19);
+
+ compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
+ }
+
+ void gather_mm(
+ const array& a_,
+ const array& b_,
+ const array& lhs_indices,
+ const array& rhs_indices,
+ array& out,
+ int M,
+ int N,
+ int K,
+ metal::Device& d,
+ const Stream& s) {
+ // Copy if needed
+ std::vector<array> copies;
+ auto [transpose_a, lda, a] = check_transpose(copies, s, a_, false);
+ auto [transpose_b, ldb, b] = check_transpose(copies, s, b_, false);
+ d.add_temporaries(std::move(copies), s.index);
+
+ // Determine dispatch kernel
+ int bm = 64, bn = 64, bk = 16;
+ int wm = 2, wn = 2;
+ size_t batch_size_out = out.size() / M / N;
+ int batch_ndim = out.ndim() - 2;
+ int batch_ndim_a = a.ndim() - 2;
+ int batch_ndim_b = b.ndim() - 2;
+
+ char devc = d.get_architecture().back();
+ GEMM_TPARAM_MACRO(devc)
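+ // GEMM_TPARAM_MACRO presumably retunes bm/bn/bk/wm/wn for the GPU generation;
+ // devc is the last character of the architecture name.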
+
+ const bool has_batch = batch_ndim > 1;
+ const bool align_M = (M % bm) == 0;
+ const bool align_N = (N % bn) == 0;
+ const bool align_K = (K % bk) == 0;
+
+ // Define the kernel name
+ std::string base_name;
+ base_name.reserve(128);
+ concatenate(
+ base_name,
+ "steel_gather_mm_",
+ transpose_a ? 't' : 'n',
+ transpose_b ? 't' : 'n',
+ "_",
+ type_to_name(a),
+ "_",
+ type_to_name(out),
+ "_bm",
+ bm,
+ "_bn",
+ bn,
+ "_bk",
+ bk,
+ "_wm",
+ wm,
+ "_wn",
+ wn);
+
+ metal::MTLFCList func_consts = {
+ {&has_batch, MTL::DataType::DataTypeBool, 10},
+ {&align_M, MTL::DataType::DataTypeBool, 200},
+ {&align_N, MTL::DataType::DataTypeBool, 201},
+ {&align_K, MTL::DataType::DataTypeBool, 202},
+ };
+
+ // And the kernel hash that includes the function constants
+ std::string hash_name;
+ hash_name.reserve(128);
+ concatenate(
+ hash_name,
+ base_name,
+ "_has_batch_",
+ has_batch ? 't' : 'n',
+ "_align_M_",
+ align_M ? 't' : 'n',
+ "_align_N_",
+ align_N ? 't' : 'n',
+ "_align_K_",
+ align_K ? 't' : 'n');
+
+ // Get and set the kernel
+ auto& compute_encoder = d.get_command_encoder(s.index);
+ auto kernel = get_steel_gemm_gather_kernel(
+ d,
+ base_name,
+ hash_name,
+ func_consts,
+ out,
+ transpose_a,
+ transpose_b,
+ bm,
+ bn,
+ bk,
+ wm,
+ wn,
+ false);
+ compute_encoder.set_compute_pipeline_state(kernel);
+
+ // Prepare the matmul params
+ steel::GEMMParams params{/* const int M = */ M,
+ /* const int N = */ N,
+ /* const int K = */ K,
+ /* const int lda = */ static_cast<int>(lda),
+ /* const int ldb = */ static_cast<int>(ldb),
+ /* const int ldd = */ N,
+ /* const int tiles_n = */ (N + bn - 1) / bn,
+ /* const int tiles_m = */ (M + bm - 1) / bm,
+ /* const int64_t batch_stride_a = */
+ (batch_ndim > 0) ? lhs_indices.strides()[0] : 0,
+ /* const int64_t batch_stride_b = */
+ (batch_ndim > 0) ? rhs_indices.strides()[0] : 0,
+ /* const int64_t batch_stride_d = */ M * N,
+ /* const int swizzle_log = */ 0,
+ /* const int gemm_k_iterations_aligned = */ (K / bk),
+ /* const int batch_ndim = */ batch_ndim};
+
+ // Prepare the grid
+ MTL::Size group_dims = MTL::Size(32, wn, wm);
+ MTL::Size grid_dims =
+ MTL::Size(params.tiles_n, params.tiles_m, batch_size_out);
+
+ // Launch kernel
+ compute_encoder.set_input_array(a, 0);
+ compute_encoder.set_input_array(b, 1);
+ compute_encoder.set_input_array(lhs_indices, 2);
+ compute_encoder.set_input_array(rhs_indices, 3);
+ compute_encoder.set_output_array(out, 4);
+ compute_encoder.set_bytes(params, 5);
+ compute_encoder.set_vector_bytes(lhs_indices.shape(), 6);
+ compute_encoder.set_vector_bytes(lhs_indices.strides(), 7);
+ compute_encoder.set_vector_bytes(rhs_indices.strides(), 8);
+ compute_encoder.set_bytes(batch_ndim_a, 9);
+ compute_encoder.set_vector_bytes(a.shape(), 10);
+ compute_encoder.set_vector_bytes(a.strides(), 11);
+ compute_encoder.set_bytes(batch_ndim_b, 12);
+ compute_encoder.set_vector_bytes(b.shape(), 13);
+ compute_encoder.set_vector_bytes(b.strides(), 14);
+ compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
+ }
+
+ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
+ auto& s = stream();
+ auto& d = metal::device(s.device);
+
+ auto& a = inputs[0];
+ auto& b = inputs[1];
+ auto& lhs_indices = inputs[2];
+ auto& rhs_indices = inputs[3];
+
+ // Return 0s if either input is empty
+ if (a.size() == 0 || b.size() == 0) {
+ array zero = array(0, a.dtype());
+ fill_gpu(zero, out, s);
+ d.add_temporary(std::move(zero), s.index);
+ return;
+ }
+
+ out.set_data(allocator::malloc(out.nbytes()));
+
+ // Extract shapes from inputs.
+ int M = a.shape(-2);
+ int N = b.shape(-1);
+ int K = a.shape(-1);
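+ // Dispatch: a sorted rhs-only gather with single-row `a` goes to the
+ // specialized gather_mm_rhs path, vector-shaped operands go through the
+ // gather gemv kernels, and everything else uses the general tiled kernel.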
+
+ // We are walking a in order and b is also in order so we can batch up the
+ // matmuls and reuse reading a and b.
+ if (M == 1 && right_sorted_ == true) {
+ if (metal::is_nax_available() &&
+ (env::enable_tf32() || a.dtype() != float32)) {
+ return gather_mm_rhs_nax(a, b, rhs_indices, out, d, s);
+ }
+ gather_mm_rhs(a, b, rhs_indices, out, d, s);
+ return;
+ }
+
+ // Route to gather gemv if any of a or b are vectors
+ if (M == 1) {
+ gather_mv(b, a, rhs_indices, lhs_indices, out, N, K, false, d, s);
+ return;
+ }
+ if (N == 1) {
+ gather_mv(a, b, lhs_indices, rhs_indices, out, M, K, true, d, s);
+ return;
+ }
+
+ // Route to non specialized gather mm
+ gather_mm(a, b, lhs_indices, rhs_indices, out, M, N, K, d, s);
+ }
+
+ void segmented_mm(
+ const array& a_,
+ const array& b_,
+ const array& segments_,
+ array& out,
+ int M,
+ int N,
+ int K,
+ metal::Device& d,
+ const Stream& s) {
+ auto check_segments_layout = [&d, &s](const array& x) {
+ // Contiguous so return early
+ if (x.flags().row_contiguous) {
+ return std::make_tuple(true, x);
+ }
+
+ bool rc = true;
+ for (int i = 0; i < x.ndim() - 2; i++) {
+ rc &=
+ (x.strides(i + 1) * x.shape(i) == x.strides(i)) || (x.shape(i) == 1);
+ }
+ rc &= x.strides(x.ndim() - 1) == 1;
+ if (x.ndim() > 1) {
+ rc &= x.strides(x.ndim() - 2) == 1;
+ }
+
+ if (rc) {
+ return std::make_tuple(false, x);
+ }
+
+ array x_copy = contiguous_copy_gpu(x, s);
+ d.add_temporary(x_copy, s.index);
+ return std::make_tuple(true, x_copy);
+ };
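+ // Whether the segments can be read with plain row-contiguous indexing is fed
+ // to the kernel as function constant 199 below.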
+
+ // Copy if needed
+ std::vector<array> copies;
+ auto [transpose_a, lda, a] = check_transpose(copies, s, a_, false);
+ auto [transpose_b, ldb, b] = check_transpose(copies, s, b_, false);
+ auto [segments_contiguous, segments] = check_segments_layout(segments_);
+ d.add_temporaries(std::move(copies), s.index);
+
+ // Determine dispatch kernel
+ int bm = 64, bn = 64, bk = 16;
+ int wm = 2, wn = 2;
+ size_t batch_size_out = out.size() / M / N;
+
+ char devc = d.get_architecture().back();
+ GEMM_TPARAM_MACRO(devc)
+
+ const bool align_M = (M % bm) == 0;
+ const bool align_N = (N % bn) == 0;
+
+ // Define the kernel name
+ std::string base_name;
+ base_name.reserve(128);
+ concatenate(
+ base_name,
+ "steel_segmented_mm_",
+ transpose_a ? 't' : 'n',
+ transpose_b ? 't' : 'n',
+ "_",
+ type_to_name(a),
+ "_",
+ type_to_name(out),
+ "_bm",
+ bm,
+ "_bn",
+ bn,
+ "_bk",
+ bk,
+ "_wm",
+ wm,
+ "_wn",
+ wn);
+
+ metal::MTLFCList func_consts = {
+ {&segments_contiguous, MTL::DataType::DataTypeBool, 199},
+ {&align_M, MTL::DataType::DataTypeBool, 200},
+ {&align_N, MTL::DataType::DataTypeBool, 201},
+ };
+
+ // And the kernel hash that includes the function constants
+ std::string hash_name;
+ hash_name.reserve(128);
+ concatenate(
+ hash_name,
+ base_name,
+ "_segments_contiguous_",
+ segments_contiguous ? 't' : 'n',
+ "_align_M_",
+ align_M ? 't' : 'n',
+ "_align_N_",
+ align_N ? 't' : 'n');
+
+ // Get and set the kernel
+ auto& compute_encoder = d.get_command_encoder(s.index);
+ auto kernel = get_steel_gemm_segmented_kernel(
+ d,
+ base_name,
+ hash_name,
+ func_consts,
+ out,
+ transpose_a,
+ transpose_b,
+ bm,
+ bn,
+ bk,
+ wm,
+ wn);
+ compute_encoder.set_compute_pipeline_state(kernel);
+
+ // Prepare the matmul params
+ steel::GEMMParams params{/* const int M = */ M,
+ /* const int N = */ N,
+ /* const int K = */ K,
+ /* const int lda = */ static_cast<int>(lda),
+ /* const int ldb = */ static_cast<int>(ldb),
+ /* const int ldd = */ N,
+ /* const int tiles_n = */ (N + bn - 1) / bn,
+ /* const int tiles_m = */ (M + bm - 1) / bm,
+ /* const int64_t batch_stride_a = */ 0,
+ /* const int64_t batch_stride_b = */ 0,
+ /* const int64_t batch_stride_d = */ M * N,
+ /* const int swizzle_log = */ 0,
+ /* const int gemm_k_iterations_aligned = */ 0,
+ /* const int batch_ndim = */ 0};
+
+ // Prepare the grid
+ MTL::Size group_dims = MTL::Size(32, wn, wm);
+ MTL::Size grid_dims =
+ MTL::Size(params.tiles_n, params.tiles_m, batch_size_out);
+
+ // Launch kernel
+ compute_encoder.set_input_array(a, 0);
+ compute_encoder.set_input_array(b, 1);
+ compute_encoder.set_input_array(segments, 2);
+ compute_encoder.set_output_array(out, 3);
+ compute_encoder.set_bytes(params, 4);
+ compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
+ }
+
+ void SegmentedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
+ auto& s = stream();
+ auto& d = metal::device(s.device);
+
+ auto& a = inputs[0];
+ auto& b = inputs[1];
+ auto& segments = inputs[2];
+
+ out.set_data(allocator::malloc(out.nbytes()));
+
+ // Extract shapes from inputs.
+ int M = a.shape(-2);
+ int N = b.shape(-1);
+ int K = a.shape(-1);
+
+ segmented_mm(a, b, segments, out, M, N, K, d, s);
+ }
+
+ } // namespace mlx::core