mlx 1.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of mlx might be problematic.

Files changed (914)
  1. checksums.yaml +7 -0
  2. data/ext/mlx/CMakeLists.txt +7 -0
  3. data/ext/mlx/Makefile +273 -0
  4. data/ext/mlx/extconf.rb +94 -0
  5. data/ext/mlx/mkmf.log +44 -0
  6. data/ext/mlx/native.bundle +0 -0
  7. data/ext/mlx/native.bundle.dSYM/Contents/Info.plist +20 -0
  8. data/ext/mlx/native.bundle.dSYM/Contents/Resources/DWARF/native.bundle +0 -0
  9. data/ext/mlx/native.bundle.dSYM/Contents/Resources/Relocations/aarch64/native.bundle.yml +5 -0
  10. data/ext/mlx/native.cpp +8027 -0
  11. data/ext/mlx/native.o +0 -0
  12. data/lib/mlx/core.rb +1678 -0
  13. data/lib/mlx/distributed_utils/common.rb +116 -0
  14. data/lib/mlx/distributed_utils/config.rb +600 -0
  15. data/lib/mlx/distributed_utils/launch.rb +490 -0
  16. data/lib/mlx/extension.rb +24 -0
  17. data/lib/mlx/nn/base.rb +388 -0
  18. data/lib/mlx/nn/init.rb +140 -0
  19. data/lib/mlx/nn/layers/activations.rb +336 -0
  20. data/lib/mlx/nn/layers/base.rb +6 -0
  21. data/lib/mlx/nn/layers/containers.rb +20 -0
  22. data/lib/mlx/nn/layers/convolution.rb +120 -0
  23. data/lib/mlx/nn/layers/convolution_transpose.rb +114 -0
  24. data/lib/mlx/nn/layers/distributed.rb +309 -0
  25. data/lib/mlx/nn/layers/dropout.rb +75 -0
  26. data/lib/mlx/nn/layers/embedding.rb +28 -0
  27. data/lib/mlx/nn/layers/linear.rb +79 -0
  28. data/lib/mlx/nn/layers/normalization.rb +216 -0
  29. data/lib/mlx/nn/layers/pooling.rb +167 -0
  30. data/lib/mlx/nn/layers/positional_encoding.rb +126 -0
  31. data/lib/mlx/nn/layers/quantized.rb +215 -0
  32. data/lib/mlx/nn/layers/recurrent.rb +135 -0
  33. data/lib/mlx/nn/layers/transformer.rb +330 -0
  34. data/lib/mlx/nn/layers/upsample.rb +97 -0
  35. data/lib/mlx/nn/layers.rb +18 -0
  36. data/lib/mlx/nn/losses.rb +251 -0
  37. data/lib/mlx/nn/utils.rb +167 -0
  38. data/lib/mlx/nn.rb +12 -0
  39. data/lib/mlx/optimizers/optimizers.rb +808 -0
  40. data/lib/mlx/optimizers/schedulers.rb +62 -0
  41. data/lib/mlx/optimizers.rb +9 -0
  42. data/lib/mlx/utils.rb +171 -0
  43. data/lib/mlx/version +1 -0
  44. data/lib/mlx/version.rb +5 -0
  45. data/lib/mlx.rb +64 -0
  46. data/mlx/.clang-format +87 -0
  47. data/mlx/.git +1 -0
  48. data/mlx/.github/ISSUE_TEMPLATE/bug_report.md +28 -0
  49. data/mlx/.github/actions/build-cuda-release/action.yml +31 -0
  50. data/mlx/.github/actions/build-docs/action.yml +38 -0
  51. data/mlx/.github/actions/build-linux/action.yml +38 -0
  52. data/mlx/.github/actions/build-linux-release/action.yml +42 -0
  53. data/mlx/.github/actions/build-macos/action.yml +80 -0
  54. data/mlx/.github/actions/build-macos-release/action.yml +36 -0
  55. data/mlx/.github/actions/build-windows/action.yml +26 -0
  56. data/mlx/.github/actions/setup-linux/action.yml +93 -0
  57. data/mlx/.github/actions/setup-macos/action.yml +24 -0
  58. data/mlx/.github/actions/setup-windows/action.yml +42 -0
  59. data/mlx/.github/actions/test-linux/action.yml +69 -0
  60. data/mlx/.github/actions/test-windows/action.yml +20 -0
  61. data/mlx/.github/dependabot.yml +6 -0
  62. data/mlx/.github/pull_request_template.md +12 -0
  63. data/mlx/.github/scripts/build-sanitizer-tests.sh +48 -0
  64. data/mlx/.github/scripts/setup+build-cpp-linux-fedora-container.sh +27 -0
  65. data/mlx/.github/workflows/build_and_test.yml +152 -0
  66. data/mlx/.github/workflows/documentation.yml +28 -0
  67. data/mlx/.github/workflows/nightly.yml +104 -0
  68. data/mlx/.github/workflows/release.yml +256 -0
  69. data/mlx/.gitignore +81 -0
  70. data/mlx/.pre-commit-config.yaml +27 -0
  71. data/mlx/ACKNOWLEDGMENTS.md +268 -0
  72. data/mlx/CITATION.cff +24 -0
  73. data/mlx/CMakeLists.txt +437 -0
  74. data/mlx/CODE_OF_CONDUCT.md +132 -0
  75. data/mlx/CONTRIBUTING.md +38 -0
  76. data/mlx/LICENSE +21 -0
  77. data/mlx/MANIFEST.in +6 -0
  78. data/mlx/README.md +121 -0
  79. data/mlx/benchmarks/cpp/CMakeLists.txt +11 -0
  80. data/mlx/benchmarks/cpp/autograd.cpp +39 -0
  81. data/mlx/benchmarks/cpp/compare_devices.cpp +27 -0
  82. data/mlx/benchmarks/cpp/irregular_strides.cpp +201 -0
  83. data/mlx/benchmarks/cpp/single_ops.cpp +288 -0
  84. data/mlx/benchmarks/cpp/time_utils.h +39 -0
  85. data/mlx/benchmarks/numpy/single_ops.py +39 -0
  86. data/mlx/benchmarks/numpy/time_utils.py +20 -0
  87. data/mlx/benchmarks/python/batch_matmul_bench.py +62 -0
  88. data/mlx/benchmarks/python/blas/bench_gemm.py +191 -0
  89. data/mlx/benchmarks/python/blas/bench_gemv.py +220 -0
  90. data/mlx/benchmarks/python/comparative/README.md +15 -0
  91. data/mlx/benchmarks/python/comparative/bench_mlx.py +519 -0
  92. data/mlx/benchmarks/python/comparative/bench_torch.py +482 -0
  93. data/mlx/benchmarks/python/comparative/compare.py +284 -0
  94. data/mlx/benchmarks/python/compile_bench.py +107 -0
  95. data/mlx/benchmarks/python/conv1d_bench.py +123 -0
  96. data/mlx/benchmarks/python/conv2d_bench_cpu.py +127 -0
  97. data/mlx/benchmarks/python/conv2d_train_bench_cpu.py +143 -0
  98. data/mlx/benchmarks/python/conv2d_transpose_bench_cpu.py +129 -0
  99. data/mlx/benchmarks/python/conv3d_bench_cpu.py +110 -0
  100. data/mlx/benchmarks/python/conv3d_train_bench_cpu.py +143 -0
  101. data/mlx/benchmarks/python/conv3d_transpose_bench_cpu.py +116 -0
  102. data/mlx/benchmarks/python/conv_bench.py +135 -0
  103. data/mlx/benchmarks/python/conv_transpose_bench.py +135 -0
  104. data/mlx/benchmarks/python/conv_unaligned_bench.py +107 -0
  105. data/mlx/benchmarks/python/distributed_bench.py +66 -0
  106. data/mlx/benchmarks/python/einsum_bench.py +84 -0
  107. data/mlx/benchmarks/python/fft_bench.py +118 -0
  108. data/mlx/benchmarks/python/gather_bench.py +52 -0
  109. data/mlx/benchmarks/python/gather_mm_bench.py +74 -0
  110. data/mlx/benchmarks/python/gather_qmm_bench.py +84 -0
  111. data/mlx/benchmarks/python/hadamard_bench.py +70 -0
  112. data/mlx/benchmarks/python/large_gemm_bench.py +119 -0
  113. data/mlx/benchmarks/python/layer_norm_bench.py +82 -0
  114. data/mlx/benchmarks/python/masked_scatter.py +212 -0
  115. data/mlx/benchmarks/python/rms_norm_bench.py +63 -0
  116. data/mlx/benchmarks/python/rope_bench.py +35 -0
  117. data/mlx/benchmarks/python/scatter_bench.py +96 -0
  118. data/mlx/benchmarks/python/sdpa_bench.py +223 -0
  119. data/mlx/benchmarks/python/sdpa_vector_bench.py +95 -0
  120. data/mlx/benchmarks/python/single_ops.py +132 -0
  121. data/mlx/benchmarks/python/synchronize_bench.py +55 -0
  122. data/mlx/benchmarks/python/time_utils.py +38 -0
  123. data/mlx/cmake/FindCUDNN.cmake +177 -0
  124. data/mlx/cmake/FindNCCL.cmake +54 -0
  125. data/mlx/cmake/Findnvpl.cmake +3 -0
  126. data/mlx/cmake/extension.cmake +50 -0
  127. data/mlx/docs/.clang-format +2 -0
  128. data/mlx/docs/.gitignore +3 -0
  129. data/mlx/docs/.nojekyll +0 -0
  130. data/mlx/docs/Doxyfile +51 -0
  131. data/mlx/docs/Makefile +18 -0
  132. data/mlx/docs/README.md +54 -0
  133. data/mlx/docs/index.html +1 -0
  134. data/mlx/docs/requirements.txt +5 -0
  135. data/mlx/docs/src/_static/distributed/m3-ultra-mesh-broken.png +0 -0
  136. data/mlx/docs/src/_static/distributed/m3-ultra-mesh.png +0 -0
  137. data/mlx/docs/src/_static/metal_debugger/capture.png +0 -0
  138. data/mlx/docs/src/_static/metal_debugger/schema.png +0 -0
  139. data/mlx/docs/src/_static/mlx_logo.png +0 -0
  140. data/mlx/docs/src/_static/mlx_logo_dark.png +0 -0
  141. data/mlx/docs/src/_static/tp_inference/all-to-sharded-linear.png +0 -0
  142. data/mlx/docs/src/_static/tp_inference/column-row-tp.png +0 -0
  143. data/mlx/docs/src/_static/tp_inference/llama-transformer.png +0 -0
  144. data/mlx/docs/src/_static/tp_inference/sharded-to-all-linear.png +0 -0
  145. data/mlx/docs/src/_templates/module-base-class.rst +33 -0
  146. data/mlx/docs/src/_templates/nn-module-template.rst +20 -0
  147. data/mlx/docs/src/_templates/optimizers-template.rst +20 -0
  148. data/mlx/docs/src/conf.py +99 -0
  149. data/mlx/docs/src/cpp/ops.rst +7 -0
  150. data/mlx/docs/src/dev/custom_metal_kernels.rst +445 -0
  151. data/mlx/docs/src/dev/extensions.rst +811 -0
  152. data/mlx/docs/src/dev/metal_debugger.rst +68 -0
  153. data/mlx/docs/src/dev/metal_logging.rst +40 -0
  154. data/mlx/docs/src/dev/mlx_in_cpp.rst +121 -0
  155. data/mlx/docs/src/examples/data_parallelism.rst +91 -0
  156. data/mlx/docs/src/examples/linear_regression.rst +77 -0
  157. data/mlx/docs/src/examples/llama-inference.rst +382 -0
  158. data/mlx/docs/src/examples/mlp.rst +134 -0
  159. data/mlx/docs/src/examples/tensor_parallelism.rst +239 -0
  160. data/mlx/docs/src/index.rst +96 -0
  161. data/mlx/docs/src/install.rst +340 -0
  162. data/mlx/docs/src/python/array.rst +65 -0
  163. data/mlx/docs/src/python/cuda.rst +9 -0
  164. data/mlx/docs/src/python/data_types.rst +78 -0
  165. data/mlx/docs/src/python/devices_and_streams.rst +21 -0
  166. data/mlx/docs/src/python/distributed.rst +22 -0
  167. data/mlx/docs/src/python/export.rst +14 -0
  168. data/mlx/docs/src/python/fast.rst +16 -0
  169. data/mlx/docs/src/python/fft.rst +24 -0
  170. data/mlx/docs/src/python/linalg.rst +27 -0
  171. data/mlx/docs/src/python/memory_management.rst +16 -0
  172. data/mlx/docs/src/python/metal.rst +12 -0
  173. data/mlx/docs/src/python/nn/distributed.rst +30 -0
  174. data/mlx/docs/src/python/nn/functions.rst +40 -0
  175. data/mlx/docs/src/python/nn/init.rst +45 -0
  176. data/mlx/docs/src/python/nn/layers.rst +74 -0
  177. data/mlx/docs/src/python/nn/losses.rst +25 -0
  178. data/mlx/docs/src/python/nn/module.rst +38 -0
  179. data/mlx/docs/src/python/nn.rst +186 -0
  180. data/mlx/docs/src/python/ops.rst +184 -0
  181. data/mlx/docs/src/python/optimizers/common_optimizers.rst +22 -0
  182. data/mlx/docs/src/python/optimizers/optimizer.rst +23 -0
  183. data/mlx/docs/src/python/optimizers/schedulers.rst +15 -0
  184. data/mlx/docs/src/python/optimizers.rst +78 -0
  185. data/mlx/docs/src/python/random.rst +48 -0
  186. data/mlx/docs/src/python/transforms.rst +22 -0
  187. data/mlx/docs/src/python/tree_utils.rst +23 -0
  188. data/mlx/docs/src/usage/compile.rst +516 -0
  189. data/mlx/docs/src/usage/distributed.rst +572 -0
  190. data/mlx/docs/src/usage/export.rst +288 -0
  191. data/mlx/docs/src/usage/function_transforms.rst +191 -0
  192. data/mlx/docs/src/usage/indexing.rst +194 -0
  193. data/mlx/docs/src/usage/launching_distributed.rst +234 -0
  194. data/mlx/docs/src/usage/lazy_evaluation.rst +144 -0
  195. data/mlx/docs/src/usage/numpy.rst +124 -0
  196. data/mlx/docs/src/usage/quick_start.rst +67 -0
  197. data/mlx/docs/src/usage/saving_and_loading.rst +81 -0
  198. data/mlx/docs/src/usage/unified_memory.rst +78 -0
  199. data/mlx/docs/src/usage/using_streams.rst +18 -0
  200. data/mlx/examples/cmake_project/CMakeLists.txt +22 -0
  201. data/mlx/examples/cmake_project/README.md +26 -0
  202. data/mlx/examples/cmake_project/example.cpp +14 -0
  203. data/mlx/examples/cpp/CMakeLists.txt +12 -0
  204. data/mlx/examples/cpp/distributed.cpp +22 -0
  205. data/mlx/examples/cpp/linear_regression.cpp +54 -0
  206. data/mlx/examples/cpp/logistic_regression.cpp +54 -0
  207. data/mlx/examples/cpp/metal_capture.cpp +31 -0
  208. data/mlx/examples/cpp/timer.h +20 -0
  209. data/mlx/examples/cpp/tutorial.cpp +99 -0
  210. data/mlx/examples/export/CMakeLists.txt +22 -0
  211. data/mlx/examples/export/README.md +49 -0
  212. data/mlx/examples/export/eval_mlp.cpp +25 -0
  213. data/mlx/examples/export/eval_mlp.py +52 -0
  214. data/mlx/examples/export/train_mlp.cpp +35 -0
  215. data/mlx/examples/export/train_mlp.py +76 -0
  216. data/mlx/examples/extensions/CMakeLists.txt +78 -0
  217. data/mlx/examples/extensions/README.md +24 -0
  218. data/mlx/examples/extensions/axpby/axpby.cpp +306 -0
  219. data/mlx/examples/extensions/axpby/axpby.h +90 -0
  220. data/mlx/examples/extensions/axpby/axpby.metal +47 -0
  221. data/mlx/examples/extensions/bindings.cpp +39 -0
  222. data/mlx/examples/extensions/mlx_sample_extensions/__init__.py +5 -0
  223. data/mlx/examples/extensions/pyproject.toml +8 -0
  224. data/mlx/examples/extensions/requirements.txt +4 -0
  225. data/mlx/examples/extensions/setup.py +18 -0
  226. data/mlx/examples/extensions/test.py +12 -0
  227. data/mlx/examples/python/linear_regression.py +46 -0
  228. data/mlx/examples/python/logistic_regression.py +49 -0
  229. data/mlx/examples/python/qqmm.py +117 -0
  230. data/mlx/mlx/3rdparty/.clang-format +2 -0
  231. data/mlx/mlx/3rdparty/pocketfft.h +3581 -0
  232. data/mlx/mlx/CMakeLists.txt +107 -0
  233. data/mlx/mlx/allocator.h +75 -0
  234. data/mlx/mlx/api.h +29 -0
  235. data/mlx/mlx/array.cpp +354 -0
  236. data/mlx/mlx/array.h +647 -0
  237. data/mlx/mlx/backend/common/CMakeLists.txt +9 -0
  238. data/mlx/mlx/backend/common/binary.h +97 -0
  239. data/mlx/mlx/backend/common/broadcasting.cpp +24 -0
  240. data/mlx/mlx/backend/common/broadcasting.h +11 -0
  241. data/mlx/mlx/backend/common/buffer_cache.h +158 -0
  242. data/mlx/mlx/backend/common/common.cpp +305 -0
  243. data/mlx/mlx/backend/common/compiled.cpp +243 -0
  244. data/mlx/mlx/backend/common/compiled.h +77 -0
  245. data/mlx/mlx/backend/common/copy.h +50 -0
  246. data/mlx/mlx/backend/common/hadamard.h +109 -0
  247. data/mlx/mlx/backend/common/load.cpp +57 -0
  248. data/mlx/mlx/backend/common/matmul.h +67 -0
  249. data/mlx/mlx/backend/common/reduce.cpp +154 -0
  250. data/mlx/mlx/backend/common/reduce.h +59 -0
  251. data/mlx/mlx/backend/common/slicing.cpp +71 -0
  252. data/mlx/mlx/backend/common/slicing.h +20 -0
  253. data/mlx/mlx/backend/common/ternary.h +85 -0
  254. data/mlx/mlx/backend/common/unary.h +29 -0
  255. data/mlx/mlx/backend/common/utils.cpp +231 -0
  256. data/mlx/mlx/backend/common/utils.h +205 -0
  257. data/mlx/mlx/backend/cpu/CMakeLists.txt +88 -0
  258. data/mlx/mlx/backend/cpu/arange.h +28 -0
  259. data/mlx/mlx/backend/cpu/arg_reduce.cpp +124 -0
  260. data/mlx/mlx/backend/cpu/binary.cpp +269 -0
  261. data/mlx/mlx/backend/cpu/binary.h +517 -0
  262. data/mlx/mlx/backend/cpu/binary_ops.h +98 -0
  263. data/mlx/mlx/backend/cpu/binary_two.h +166 -0
  264. data/mlx/mlx/backend/cpu/cholesky.cpp +85 -0
  265. data/mlx/mlx/backend/cpu/compiled.cpp +357 -0
  266. data/mlx/mlx/backend/cpu/compiled_preamble.h +12 -0
  267. data/mlx/mlx/backend/cpu/conv.cpp +1351 -0
  268. data/mlx/mlx/backend/cpu/copy.cpp +386 -0
  269. data/mlx/mlx/backend/cpu/copy.h +36 -0
  270. data/mlx/mlx/backend/cpu/device_info.cpp +113 -0
  271. data/mlx/mlx/backend/cpu/device_info.h +28 -0
  272. data/mlx/mlx/backend/cpu/distributed.cpp +103 -0
  273. data/mlx/mlx/backend/cpu/eig.cpp +281 -0
  274. data/mlx/mlx/backend/cpu/eigh.cpp +241 -0
  275. data/mlx/mlx/backend/cpu/encoder.cpp +16 -0
  276. data/mlx/mlx/backend/cpu/encoder.h +67 -0
  277. data/mlx/mlx/backend/cpu/eval.cpp +40 -0
  278. data/mlx/mlx/backend/cpu/eval.h +12 -0
  279. data/mlx/mlx/backend/cpu/fft.cpp +120 -0
  280. data/mlx/mlx/backend/cpu/gemm.h +26 -0
  281. data/mlx/mlx/backend/cpu/gemms/bnns.cpp +214 -0
  282. data/mlx/mlx/backend/cpu/gemms/cblas.cpp +134 -0
  283. data/mlx/mlx/backend/cpu/gemms/simd_bf16.cpp +45 -0
  284. data/mlx/mlx/backend/cpu/gemms/simd_fp16.cpp +45 -0
  285. data/mlx/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
  286. data/mlx/mlx/backend/cpu/hadamard.cpp +121 -0
  287. data/mlx/mlx/backend/cpu/indexing.cpp +854 -0
  288. data/mlx/mlx/backend/cpu/inverse.cpp +160 -0
  289. data/mlx/mlx/backend/cpu/jit_compiler.cpp +166 -0
  290. data/mlx/mlx/backend/cpu/jit_compiler.h +20 -0
  291. data/mlx/mlx/backend/cpu/lapack.h +80 -0
  292. data/mlx/mlx/backend/cpu/logsumexp.cpp +139 -0
  293. data/mlx/mlx/backend/cpu/luf.cpp +120 -0
  294. data/mlx/mlx/backend/cpu/make_compiled_preamble.ps1 +38 -0
  295. data/mlx/mlx/backend/cpu/make_compiled_preamble.sh +41 -0
  296. data/mlx/mlx/backend/cpu/masked_mm.cpp +608 -0
  297. data/mlx/mlx/backend/cpu/matmul.cpp +166 -0
  298. data/mlx/mlx/backend/cpu/primitives.cpp +478 -0
  299. data/mlx/mlx/backend/cpu/qrf.cpp +147 -0
  300. data/mlx/mlx/backend/cpu/quantized.cpp +1370 -0
  301. data/mlx/mlx/backend/cpu/reduce.cpp +587 -0
  302. data/mlx/mlx/backend/cpu/scan.cpp +338 -0
  303. data/mlx/mlx/backend/cpu/select.cpp +95 -0
  304. data/mlx/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
  305. data/mlx/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
  306. data/mlx/mlx/backend/cpu/simd/base_simd.h +319 -0
  307. data/mlx/mlx/backend/cpu/simd/math.h +193 -0
  308. data/mlx/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
  309. data/mlx/mlx/backend/cpu/simd/simd.h +4 -0
  310. data/mlx/mlx/backend/cpu/simd/type.h +11 -0
  311. data/mlx/mlx/backend/cpu/slicing.h +21 -0
  312. data/mlx/mlx/backend/cpu/softmax.cpp +170 -0
  313. data/mlx/mlx/backend/cpu/sort.cpp +481 -0
  314. data/mlx/mlx/backend/cpu/svd.cpp +289 -0
  315. data/mlx/mlx/backend/cpu/ternary.h +154 -0
  316. data/mlx/mlx/backend/cpu/threefry.cpp +31 -0
  317. data/mlx/mlx/backend/cpu/threefry.h +21 -0
  318. data/mlx/mlx/backend/cpu/unary.cpp +238 -0
  319. data/mlx/mlx/backend/cpu/unary.h +281 -0
  320. data/mlx/mlx/backend/cpu/unary_ops.h +175 -0
  321. data/mlx/mlx/backend/cuda/CMakeLists.txt +265 -0
  322. data/mlx/mlx/backend/cuda/allocator.cpp +451 -0
  323. data/mlx/mlx/backend/cuda/allocator.h +94 -0
  324. data/mlx/mlx/backend/cuda/arange.cu +68 -0
  325. data/mlx/mlx/backend/cuda/arg_reduce.cu +189 -0
  326. data/mlx/mlx/backend/cuda/bin2h.cmake +150 -0
  327. data/mlx/mlx/backend/cuda/binary/CMakeLists.txt +21 -0
  328. data/mlx/mlx/backend/cuda/binary/add.cu +7 -0
  329. data/mlx/mlx/backend/cuda/binary/arctan2.cu +7 -0
  330. data/mlx/mlx/backend/cuda/binary/binary.cuh +383 -0
  331. data/mlx/mlx/backend/cuda/binary/bitwise_binary.cu +27 -0
  332. data/mlx/mlx/backend/cuda/binary/divide.cu +7 -0
  333. data/mlx/mlx/backend/cuda/binary/equal.cu +15 -0
  334. data/mlx/mlx/backend/cuda/binary/greater.cu +7 -0
  335. data/mlx/mlx/backend/cuda/binary/greater_equal.cu +7 -0
  336. data/mlx/mlx/backend/cuda/binary/less.cu +7 -0
  337. data/mlx/mlx/backend/cuda/binary/less_equal.cu +7 -0
  338. data/mlx/mlx/backend/cuda/binary/log_add_exp.cu +7 -0
  339. data/mlx/mlx/backend/cuda/binary/logical_and.cu +7 -0
  340. data/mlx/mlx/backend/cuda/binary/logical_or.cu +7 -0
  341. data/mlx/mlx/backend/cuda/binary/maximum.cu +7 -0
  342. data/mlx/mlx/backend/cuda/binary/minimum.cu +7 -0
  343. data/mlx/mlx/backend/cuda/binary/multiply.cu +7 -0
  344. data/mlx/mlx/backend/cuda/binary/not_equal.cu +7 -0
  345. data/mlx/mlx/backend/cuda/binary/power.cu +7 -0
  346. data/mlx/mlx/backend/cuda/binary/remainder.cu +7 -0
  347. data/mlx/mlx/backend/cuda/binary/subtract.cu +7 -0
  348. data/mlx/mlx/backend/cuda/binary_two.cu +412 -0
  349. data/mlx/mlx/backend/cuda/compiled.cpp +357 -0
  350. data/mlx/mlx/backend/cuda/conv/conv.h +126 -0
  351. data/mlx/mlx/backend/cuda/conv/gemm_conv.cu +217 -0
  352. data/mlx/mlx/backend/cuda/conv/gemm_grouped_conv.cu +231 -0
  353. data/mlx/mlx/backend/cuda/conv.cpp +403 -0
  354. data/mlx/mlx/backend/cuda/copy/copy.cuh +55 -0
  355. data/mlx/mlx/backend/cuda/copy/copy_contiguous.cu +88 -0
  356. data/mlx/mlx/backend/cuda/copy/copy_general.cu +171 -0
  357. data/mlx/mlx/backend/cuda/copy/copy_general_dynamic.cu +118 -0
  358. data/mlx/mlx/backend/cuda/copy/copy_general_input.cu +229 -0
  359. data/mlx/mlx/backend/cuda/copy.cu +132 -0
  360. data/mlx/mlx/backend/cuda/cublas_utils.cpp +222 -0
  361. data/mlx/mlx/backend/cuda/cublas_utils.h +95 -0
  362. data/mlx/mlx/backend/cuda/cuda.h +21 -0
  363. data/mlx/mlx/backend/cuda/cuda_utils.h +90 -0
  364. data/mlx/mlx/backend/cuda/cudnn_utils.cpp +133 -0
  365. data/mlx/mlx/backend/cuda/cudnn_utils.h +187 -0
  366. data/mlx/mlx/backend/cuda/custom_kernel.cpp +379 -0
  367. data/mlx/mlx/backend/cuda/cutlass_utils.cuh +46 -0
  368. data/mlx/mlx/backend/cuda/delayload.cpp +80 -0
  369. data/mlx/mlx/backend/cuda/device/atomic_ops.cuh +63 -0
  370. data/mlx/mlx/backend/cuda/device/binary_ops.cuh +300 -0
  371. data/mlx/mlx/backend/cuda/device/cast_op.cuh +118 -0
  372. data/mlx/mlx/backend/cuda/device/complex.cuh +60 -0
  373. data/mlx/mlx/backend/cuda/device/config.h +12 -0
  374. data/mlx/mlx/backend/cuda/device/fp16_math.cuh +96 -0
  375. data/mlx/mlx/backend/cuda/device/gather.cuh +53 -0
  376. data/mlx/mlx/backend/cuda/device/gather_axis.cuh +65 -0
  377. data/mlx/mlx/backend/cuda/device/indexing.cuh +30 -0
  378. data/mlx/mlx/backend/cuda/device/scatter.cuh +68 -0
  379. data/mlx/mlx/backend/cuda/device/scatter_axis.cuh +67 -0
  380. data/mlx/mlx/backend/cuda/device/scatter_ops.cuh +44 -0
  381. data/mlx/mlx/backend/cuda/device/ternary_ops.cuh +13 -0
  382. data/mlx/mlx/backend/cuda/device/unary_ops.cuh +350 -0
  383. data/mlx/mlx/backend/cuda/device/utils.cuh +464 -0
  384. data/mlx/mlx/backend/cuda/device.cpp +522 -0
  385. data/mlx/mlx/backend/cuda/device.h +195 -0
  386. data/mlx/mlx/backend/cuda/device_info.cpp +232 -0
  387. data/mlx/mlx/backend/cuda/distributed.cu +121 -0
  388. data/mlx/mlx/backend/cuda/eval.cpp +66 -0
  389. data/mlx/mlx/backend/cuda/event.cu +415 -0
  390. data/mlx/mlx/backend/cuda/event.h +79 -0
  391. data/mlx/mlx/backend/cuda/fence.cpp +42 -0
  392. data/mlx/mlx/backend/cuda/gemms/cublas_gemm.cpp +233 -0
  393. data/mlx/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
  394. data/mlx/mlx/backend/cuda/gemms/cublas_gemm_batched_12_0.cpp +77 -0
  395. data/mlx/mlx/backend/cuda/gemms/cublas_gemm_batched_12_9.cu +329 -0
  396. data/mlx/mlx/backend/cuda/gemms/gemv.cu +327 -0
  397. data/mlx/mlx/backend/cuda/gemms/gemv.h +34 -0
  398. data/mlx/mlx/backend/cuda/gemms/grouped_gemm.h +25 -0
  399. data/mlx/mlx/backend/cuda/gemms/grouped_gemm_unaligned.cu +358 -0
  400. data/mlx/mlx/backend/cuda/indexing.cpp +434 -0
  401. data/mlx/mlx/backend/cuda/jit_module.cpp +443 -0
  402. data/mlx/mlx/backend/cuda/jit_module.h +120 -0
  403. data/mlx/mlx/backend/cuda/kernel_utils.cu +52 -0
  404. data/mlx/mlx/backend/cuda/kernel_utils.cuh +148 -0
  405. data/mlx/mlx/backend/cuda/layer_norm.cu +417 -0
  406. data/mlx/mlx/backend/cuda/load.cpp +60 -0
  407. data/mlx/mlx/backend/cuda/logsumexp.cu +161 -0
  408. data/mlx/mlx/backend/cuda/lru_cache.h +190 -0
  409. data/mlx/mlx/backend/cuda/matmul.cpp +373 -0
  410. data/mlx/mlx/backend/cuda/no_cuda.cpp +47 -0
  411. data/mlx/mlx/backend/cuda/primitives.cpp +46 -0
  412. data/mlx/mlx/backend/cuda/quantized/affine_quantize.cu +329 -0
  413. data/mlx/mlx/backend/cuda/quantized/convert_fp8.cu +19 -0
  414. data/mlx/mlx/backend/cuda/quantized/cublas_qqmm.cpp +206 -0
  415. data/mlx/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
  416. data/mlx/mlx/backend/cuda/quantized/cuda_fp4.h +100 -0
  417. data/mlx/mlx/backend/cuda/quantized/fp_quantize.cu +496 -0
  418. data/mlx/mlx/backend/cuda/quantized/mxfp8_quantize.cuh +32 -0
  419. data/mlx/mlx/backend/cuda/quantized/no_qqmm_impl.cpp +26 -0
  420. data/mlx/mlx/backend/cuda/quantized/nvfp4_quantize.cuh +334 -0
  421. data/mlx/mlx/backend/cuda/quantized/qmv.cu +304 -0
  422. data/mlx/mlx/backend/cuda/quantized/qmv.h +21 -0
  423. data/mlx/mlx/backend/cuda/quantized/qqmm.cpp +158 -0
  424. data/mlx/mlx/backend/cuda/quantized/qqmm_impl.cpp +50 -0
  425. data/mlx/mlx/backend/cuda/quantized/qqmm_impl.h +26 -0
  426. data/mlx/mlx/backend/cuda/quantized/qqmm_utils.cu +227 -0
  427. data/mlx/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
  428. data/mlx/mlx/backend/cuda/quantized/quantized.cpp +85 -0
  429. data/mlx/mlx/backend/cuda/quantized/quantized.h +53 -0
  430. data/mlx/mlx/backend/cuda/quantized/quantized_utils.cuh +88 -0
  431. data/mlx/mlx/backend/cuda/quantized/quantized_utils.h +50 -0
  432. data/mlx/mlx/backend/cuda/random.cu +202 -0
  433. data/mlx/mlx/backend/cuda/reduce/all_reduce.cu +159 -0
  434. data/mlx/mlx/backend/cuda/reduce/col_reduce.cu +510 -0
  435. data/mlx/mlx/backend/cuda/reduce/init_reduce.cu +50 -0
  436. data/mlx/mlx/backend/cuda/reduce/reduce.cuh +71 -0
  437. data/mlx/mlx/backend/cuda/reduce/reduce_ops.cuh +211 -0
  438. data/mlx/mlx/backend/cuda/reduce/reduce_utils.cuh +145 -0
  439. data/mlx/mlx/backend/cuda/reduce/row_reduce.cu +361 -0
  440. data/mlx/mlx/backend/cuda/reduce.cu +73 -0
  441. data/mlx/mlx/backend/cuda/rms_norm.cu +536 -0
  442. data/mlx/mlx/backend/cuda/rope.cu +429 -0
  443. data/mlx/mlx/backend/cuda/scaled_dot_product_attention.cpp +681 -0
  444. data/mlx/mlx/backend/cuda/scaled_dot_product_attention.cu +796 -0
  445. data/mlx/mlx/backend/cuda/scan.cu +468 -0
  446. data/mlx/mlx/backend/cuda/slicing.cpp +111 -0
  447. data/mlx/mlx/backend/cuda/softmax.cu +162 -0
  448. data/mlx/mlx/backend/cuda/sort.cu +1076 -0
  449. data/mlx/mlx/backend/cuda/steel/defines.cuh +9 -0
  450. data/mlx/mlx/backend/cuda/steel/gemm.cuh +101 -0
  451. data/mlx/mlx/backend/cuda/steel/mma.cuh +117 -0
  452. data/mlx/mlx/backend/cuda/steel/tiles.cuh +450 -0
  453. data/mlx/mlx/backend/cuda/steel/utils.cuh +89 -0
  454. data/mlx/mlx/backend/cuda/ternary.cu +271 -0
  455. data/mlx/mlx/backend/cuda/unary/CMakeLists.txt +34 -0
  456. data/mlx/mlx/backend/cuda/unary/abs.cu +7 -0
  457. data/mlx/mlx/backend/cuda/unary/arccos.cu +7 -0
  458. data/mlx/mlx/backend/cuda/unary/arccosh.cu +7 -0
  459. data/mlx/mlx/backend/cuda/unary/arcsin.cu +7 -0
  460. data/mlx/mlx/backend/cuda/unary/arcsinh.cu +7 -0
  461. data/mlx/mlx/backend/cuda/unary/arctan.cu +7 -0
  462. data/mlx/mlx/backend/cuda/unary/arctanh.cu +7 -0
  463. data/mlx/mlx/backend/cuda/unary/bitwise_invert.cu +7 -0
  464. data/mlx/mlx/backend/cuda/unary/ceil.cu +7 -0
  465. data/mlx/mlx/backend/cuda/unary/conjugate.cu +7 -0
  466. data/mlx/mlx/backend/cuda/unary/cos.cu +7 -0
  467. data/mlx/mlx/backend/cuda/unary/cosh.cu +7 -0
  468. data/mlx/mlx/backend/cuda/unary/erf.cu +7 -0
  469. data/mlx/mlx/backend/cuda/unary/erf_inv.cu +7 -0
  470. data/mlx/mlx/backend/cuda/unary/exp.cu +7 -0
  471. data/mlx/mlx/backend/cuda/unary/expm1.cu +7 -0
  472. data/mlx/mlx/backend/cuda/unary/floor.cu +7 -0
  473. data/mlx/mlx/backend/cuda/unary/imag.cu +7 -0
  474. data/mlx/mlx/backend/cuda/unary/log.cu +21 -0
  475. data/mlx/mlx/backend/cuda/unary/log1p.cu +7 -0
  476. data/mlx/mlx/backend/cuda/unary/logical_not.cu +7 -0
  477. data/mlx/mlx/backend/cuda/unary/negative.cu +7 -0
  478. data/mlx/mlx/backend/cuda/unary/real.cu +7 -0
  479. data/mlx/mlx/backend/cuda/unary/round.cu +18 -0
  480. data/mlx/mlx/backend/cuda/unary/sigmoid.cu +7 -0
  481. data/mlx/mlx/backend/cuda/unary/sign.cu +7 -0
  482. data/mlx/mlx/backend/cuda/unary/sin.cu +7 -0
  483. data/mlx/mlx/backend/cuda/unary/sinh.cu +7 -0
  484. data/mlx/mlx/backend/cuda/unary/sqrt.cu +15 -0
  485. data/mlx/mlx/backend/cuda/unary/square.cu +7 -0
  486. data/mlx/mlx/backend/cuda/unary/tan.cu +7 -0
  487. data/mlx/mlx/backend/cuda/unary/tanh.cu +7 -0
  488. data/mlx/mlx/backend/cuda/unary/unary.cuh +224 -0
  489. data/mlx/mlx/backend/cuda/utils.cpp +116 -0
  490. data/mlx/mlx/backend/cuda/utils.h +49 -0
  491. data/mlx/mlx/backend/cuda/vector_types.cuh +48 -0
  492. data/mlx/mlx/backend/cuda/worker.cpp +79 -0
  493. data/mlx/mlx/backend/cuda/worker.h +55 -0
  494. data/mlx/mlx/backend/gpu/CMakeLists.txt +5 -0
  495. data/mlx/mlx/backend/gpu/copy.cpp +89 -0
  496. data/mlx/mlx/backend/gpu/copy.h +57 -0
  497. data/mlx/mlx/backend/gpu/device_info.h +36 -0
  498. data/mlx/mlx/backend/gpu/eval.h +18 -0
  499. data/mlx/mlx/backend/gpu/primitives.cpp +307 -0
  500. data/mlx/mlx/backend/gpu/slicing.cpp +44 -0
  501. data/mlx/mlx/backend/gpu/slicing.h +36 -0
  502. data/mlx/mlx/backend/metal/CMakeLists.txt +144 -0
  503. data/mlx/mlx/backend/metal/allocator.cpp +279 -0
  504. data/mlx/mlx/backend/metal/allocator.h +79 -0
  505. data/mlx/mlx/backend/metal/binary.cpp +257 -0
  506. data/mlx/mlx/backend/metal/binary.h +33 -0
  507. data/mlx/mlx/backend/metal/compiled.cpp +471 -0
  508. data/mlx/mlx/backend/metal/conv.cpp +1118 -0
  509. data/mlx/mlx/backend/metal/copy.cpp +235 -0
  510. data/mlx/mlx/backend/metal/custom_kernel.cpp +430 -0
  511. data/mlx/mlx/backend/metal/device.cpp +816 -0
  512. data/mlx/mlx/backend/metal/device.h +289 -0
  513. data/mlx/mlx/backend/metal/device_info.cpp +58 -0
  514. data/mlx/mlx/backend/metal/distributed.cpp +38 -0
  515. data/mlx/mlx/backend/metal/eval.cpp +97 -0
  516. data/mlx/mlx/backend/metal/event.cpp +62 -0
  517. data/mlx/mlx/backend/metal/fence.cpp +162 -0
  518. data/mlx/mlx/backend/metal/fft.cpp +807 -0
  519. data/mlx/mlx/backend/metal/hadamard.cpp +198 -0
  520. data/mlx/mlx/backend/metal/indexing.cpp +727 -0
  521. data/mlx/mlx/backend/metal/jit/includes.h +58 -0
  522. data/mlx/mlx/backend/metal/jit/indexing.h +76 -0
  523. data/mlx/mlx/backend/metal/jit_kernels.cpp +1118 -0
  524. data/mlx/mlx/backend/metal/kernels/CMakeLists.txt +193 -0
  525. data/mlx/mlx/backend/metal/kernels/arange.h +9 -0
  526. data/mlx/mlx/backend/metal/kernels/arange.metal +20 -0
  527. data/mlx/mlx/backend/metal/kernels/arg_reduce.metal +182 -0
  528. data/mlx/mlx/backend/metal/kernels/atomic.h +345 -0
  529. data/mlx/mlx/backend/metal/kernels/bf16.h +16 -0
  530. data/mlx/mlx/backend/metal/kernels/bf16_math.h +380 -0
  531. data/mlx/mlx/backend/metal/kernels/binary.h +199 -0
  532. data/mlx/mlx/backend/metal/kernels/binary.metal +109 -0
  533. data/mlx/mlx/backend/metal/kernels/binary_ops.h +330 -0
  534. data/mlx/mlx/backend/metal/kernels/binary_two.h +244 -0
  535. data/mlx/mlx/backend/metal/kernels/binary_two.metal +54 -0
  536. data/mlx/mlx/backend/metal/kernels/cexpf.h +134 -0
  537. data/mlx/mlx/backend/metal/kernels/complex.h +173 -0
  538. data/mlx/mlx/backend/metal/kernels/conv.metal +701 -0
  539. data/mlx/mlx/backend/metal/kernels/copy.h +276 -0
  540. data/mlx/mlx/backend/metal/kernels/copy.metal +75 -0
  541. data/mlx/mlx/backend/metal/kernels/defines.h +24 -0
  542. data/mlx/mlx/backend/metal/kernels/erf.h +69 -0
  543. data/mlx/mlx/backend/metal/kernels/expm1f.h +90 -0
  544. data/mlx/mlx/backend/metal/kernels/fence.metal +52 -0
  545. data/mlx/mlx/backend/metal/kernels/fft/radix.h +328 -0
  546. data/mlx/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
  547. data/mlx/mlx/backend/metal/kernels/fft.h +486 -0
  548. data/mlx/mlx/backend/metal/kernels/fft.metal +67 -0
  549. data/mlx/mlx/backend/metal/kernels/fp4.h +48 -0
  550. data/mlx/mlx/backend/metal/kernels/fp8.h +80 -0
  551. data/mlx/mlx/backend/metal/kernels/fp_quantized.h +1850 -0
  552. data/mlx/mlx/backend/metal/kernels/fp_quantized.metal +153 -0
  553. data/mlx/mlx/backend/metal/kernels/fp_quantized_nax.h +1044 -0
  554. data/mlx/mlx/backend/metal/kernels/fp_quantized_nax.metal +79 -0
  555. data/mlx/mlx/backend/metal/kernels/gemv.metal +868 -0
  556. data/mlx/mlx/backend/metal/kernels/gemv_masked.h +827 -0
  557. data/mlx/mlx/backend/metal/kernels/gemv_masked.metal +76 -0
  558. data/mlx/mlx/backend/metal/kernels/hadamard.h +182 -0
  559. data/mlx/mlx/backend/metal/kernels/indexing/gather.h +51 -0
  560. data/mlx/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
  561. data/mlx/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
  562. data/mlx/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
  563. data/mlx/mlx/backend/metal/kernels/indexing/masked_scatter.h +41 -0
  564. data/mlx/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
  565. data/mlx/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
  566. data/mlx/mlx/backend/metal/kernels/layer_norm.metal +433 -0
  567. data/mlx/mlx/backend/metal/kernels/logging.h +26 -0
  568. data/mlx/mlx/backend/metal/kernels/logsumexp.h +140 -0
  569. data/mlx/mlx/backend/metal/kernels/logsumexp.metal +18 -0
  570. data/mlx/mlx/backend/metal/kernels/quantized.h +2508 -0
  571. data/mlx/mlx/backend/metal/kernels/quantized.metal +144 -0
  572. data/mlx/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
  573. data/mlx/mlx/backend/metal/kernels/quantized_nax.metal +106 -0
  574. data/mlx/mlx/backend/metal/kernels/quantized_utils.h +90 -0
  575. data/mlx/mlx/backend/metal/kernels/random.metal +103 -0
  576. data/mlx/mlx/backend/metal/kernels/reduce.h +5 -0
  577. data/mlx/mlx/backend/metal/kernels/reduce.metal +169 -0
  578. data/mlx/mlx/backend/metal/kernels/reduce_utils.h +6 -0
  579. data/mlx/mlx/backend/metal/kernels/reduction/ops.h +275 -0
  580. data/mlx/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
  581. data/mlx/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
  582. data/mlx/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
  583. data/mlx/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
  584. data/mlx/mlx/backend/metal/kernels/rms_norm.metal +391 -0
  585. data/mlx/mlx/backend/metal/kernels/rope.metal +229 -0
  586. data/mlx/mlx/backend/metal/kernels/scaled_dot_product_attention.metal +44 -0
  587. data/mlx/mlx/backend/metal/kernels/scan.h +514 -0
  588. data/mlx/mlx/backend/metal/kernels/scan.metal +109 -0
  589. data/mlx/mlx/backend/metal/kernels/sdpa_vector.h +394 -0
  590. data/mlx/mlx/backend/metal/kernels/softmax.h +190 -0
  591. data/mlx/mlx/backend/metal/kernels/softmax.metal +24 -0
  592. data/mlx/mlx/backend/metal/kernels/sort.h +719 -0
  593. data/mlx/mlx/backend/metal/kernels/sort.metal +80 -0
  594. data/mlx/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
  595. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +471 -0
  596. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal +27 -0
  597. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
  598. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.metal +28 -0
  599. data/mlx/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
  600. data/mlx/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
  601. data/mlx/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
  602. data/mlx/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
  603. data/mlx/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
  604. data/mlx/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
  605. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
  606. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.metal +56 -0
  607. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
  608. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.metal +47 -0
  609. data/mlx/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
  610. data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
  611. data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
  612. data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
  613. data/mlx/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
  614. data/mlx/mlx/backend/metal/kernels/steel/defines.h +7 -0
  615. data/mlx/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
  616. data/mlx/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +157 -0
  617. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
  618. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.metal +34 -0
  619. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +219 -0
  620. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.metal +30 -0
  621. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
  622. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.metal +59 -0
  623. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +143 -0
  624. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.metal +37 -0
  625. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
  626. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.metal +76 -0
  627. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
  628. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.metal +43 -0
  629. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
  630. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.metal +76 -0
  631. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.h +152 -0
  632. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.metal +30 -0
  633. data/mlx/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
  634. data/mlx/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
  635. data/mlx/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
  636. data/mlx/mlx/backend/metal/kernels/steel/gemm/params.h +65 -0
  637. data/mlx/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
  638. data/mlx/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
  639. data/mlx/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
  640. data/mlx/mlx/backend/metal/kernels/steel/utils.h +42 -0
  641. data/mlx/mlx/backend/metal/kernels/ternary.h +145 -0
  642. data/mlx/mlx/backend/metal/kernels/ternary.metal +48 -0
  643. data/mlx/mlx/backend/metal/kernels/ternary_ops.h +10 -0
  644. data/mlx/mlx/backend/metal/kernels/unary.h +63 -0
  645. data/mlx/mlx/backend/metal/kernels/unary.metal +115 -0
  646. data/mlx/mlx/backend/metal/kernels/unary_ops.h +454 -0
  647. data/mlx/mlx/backend/metal/kernels/utils.h +445 -0
  648. data/mlx/mlx/backend/metal/kernels.h +375 -0
  649. data/mlx/mlx/backend/metal/logsumexp.cpp +95 -0
  650. data/mlx/mlx/backend/metal/make_compiled_preamble.sh +120 -0
  651. data/mlx/mlx/backend/metal/matmul.cpp +2572 -0
  652. data/mlx/mlx/backend/metal/matmul.h +144 -0
  653. data/mlx/mlx/backend/metal/metal.cpp +50 -0
  654. data/mlx/mlx/backend/metal/metal.h +25 -0
  655. data/mlx/mlx/backend/metal/no_metal.cpp +42 -0
  656. data/mlx/mlx/backend/metal/nojit_kernels.cpp +414 -0
  657. data/mlx/mlx/backend/metal/normalization.cpp +433 -0
  658. data/mlx/mlx/backend/metal/primitives.cpp +242 -0
  659. data/mlx/mlx/backend/metal/quantized.cpp +1651 -0
  660. data/mlx/mlx/backend/metal/reduce.cpp +1038 -0
  661. data/mlx/mlx/backend/metal/reduce.h +41 -0
  662. data/mlx/mlx/backend/metal/resident.cpp +100 -0
  663. data/mlx/mlx/backend/metal/resident.h +32 -0
  664. data/mlx/mlx/backend/metal/rope.cpp +165 -0
  665. data/mlx/mlx/backend/metal/scaled_dot_product_attention.cpp +798 -0
  666. data/mlx/mlx/backend/metal/scan.cpp +145 -0
  667. data/mlx/mlx/backend/metal/scan.h +17 -0
  668. data/mlx/mlx/backend/metal/slicing.cpp +99 -0
  669. data/mlx/mlx/backend/metal/softmax.cpp +87 -0
  670. data/mlx/mlx/backend/metal/sort.cpp +368 -0
  671. data/mlx/mlx/backend/metal/ternary.cpp +160 -0
  672. data/mlx/mlx/backend/metal/ternary.h +21 -0
  673. data/mlx/mlx/backend/metal/unary.cpp +161 -0
  674. data/mlx/mlx/backend/metal/unary.h +21 -0
  675. data/mlx/mlx/backend/metal/utils.cpp +77 -0
  676. data/mlx/mlx/backend/metal/utils.h +99 -0
  677. data/mlx/mlx/backend/no_cpu/CMakeLists.txt +7 -0
  678. data/mlx/mlx/backend/no_cpu/compiled.cpp +24 -0
  679. data/mlx/mlx/backend/no_cpu/device_info.cpp +22 -0
  680. data/mlx/mlx/backend/no_cpu/primitives.cpp +146 -0
  681. data/mlx/mlx/backend/no_gpu/CMakeLists.txt +8 -0
  682. data/mlx/mlx/backend/no_gpu/allocator.cpp +134 -0
  683. data/mlx/mlx/backend/no_gpu/apple_memory.h +16 -0
  684. data/mlx/mlx/backend/no_gpu/device_info.cpp +22 -0
  685. data/mlx/mlx/backend/no_gpu/eval.cpp +24 -0
  686. data/mlx/mlx/backend/no_gpu/event.cpp +53 -0
  687. data/mlx/mlx/backend/no_gpu/fence.cpp +54 -0
  688. data/mlx/mlx/backend/no_gpu/linux_memory.h +22 -0
  689. data/mlx/mlx/backend/no_gpu/primitives.cpp +185 -0
  690. data/mlx/mlx/compile.cpp +1243 -0
  691. data/mlx/mlx/compile.h +45 -0
  692. data/mlx/mlx/compile_impl.h +70 -0
  693. data/mlx/mlx/device.cpp +72 -0
  694. data/mlx/mlx/device.h +56 -0
  695. data/mlx/mlx/distributed/CMakeLists.txt +14 -0
  696. data/mlx/mlx/distributed/distributed.cpp +197 -0
  697. data/mlx/mlx/distributed/distributed.h +61 -0
  698. data/mlx/mlx/distributed/distributed_impl.h +59 -0
  699. data/mlx/mlx/distributed/jaccl/CMakeLists.txt +12 -0
  700. data/mlx/mlx/distributed/jaccl/jaccl.cpp +178 -0
  701. data/mlx/mlx/distributed/jaccl/jaccl.h +12 -0
  702. data/mlx/mlx/distributed/jaccl/mesh.cpp +451 -0
  703. data/mlx/mlx/distributed/jaccl/mesh.h +122 -0
  704. data/mlx/mlx/distributed/jaccl/no_jaccl.cpp +20 -0
  705. data/mlx/mlx/distributed/jaccl/ring.cpp +692 -0
  706. data/mlx/mlx/distributed/jaccl/ring.h +178 -0
  707. data/mlx/mlx/distributed/jaccl/utils.cpp +329 -0
  708. data/mlx/mlx/distributed/jaccl/utils.h +342 -0
  709. data/mlx/mlx/distributed/mpi/CMakeLists.txt +5 -0
  710. data/mlx/mlx/distributed/mpi/mpi.cpp +501 -0
  711. data/mlx/mlx/distributed/mpi/mpi.h +12 -0
  712. data/mlx/mlx/distributed/mpi/mpi_declarations.h +28 -0
  713. data/mlx/mlx/distributed/mpi/no_mpi.cpp +20 -0
  714. data/mlx/mlx/distributed/nccl/CMakeLists.txt +26 -0
  715. data/mlx/mlx/distributed/nccl/nccl.cpp +443 -0
  716. data/mlx/mlx/distributed/nccl/nccl.h +12 -0
  717. data/mlx/mlx/distributed/nccl/nccl_stub/CMakeLists.txt +1 -0
  718. data/mlx/mlx/distributed/nccl/nccl_stub/nccl_stubs.cpp +54 -0
  719. data/mlx/mlx/distributed/nccl/no_nccl.cpp +20 -0
  720. data/mlx/mlx/distributed/ops.cpp +186 -0
  721. data/mlx/mlx/distributed/ops.h +57 -0
  722. data/mlx/mlx/distributed/primitives.cpp +95 -0
  723. data/mlx/mlx/distributed/primitives.h +156 -0
  724. data/mlx/mlx/distributed/reduction_ops.h +38 -0
  725. data/mlx/mlx/distributed/ring/CMakeLists.txt +5 -0
  726. data/mlx/mlx/distributed/ring/no_ring.cpp +20 -0
  727. data/mlx/mlx/distributed/ring/ring.cpp +870 -0
  728. data/mlx/mlx/distributed/ring/ring.h +12 -0
  729. data/mlx/mlx/distributed/utils.cpp +206 -0
  730. data/mlx/mlx/distributed/utils.h +67 -0
  731. data/mlx/mlx/dtype.cpp +197 -0
  732. data/mlx/mlx/dtype.h +116 -0
  733. data/mlx/mlx/dtype_utils.cpp +42 -0
  734. data/mlx/mlx/dtype_utils.h +119 -0
  735. data/mlx/mlx/einsum.cpp +941 -0
  736. data/mlx/mlx/einsum.h +23 -0
  737. data/mlx/mlx/event.h +58 -0
  738. data/mlx/mlx/export.cpp +1130 -0
  739. data/mlx/mlx/export.h +137 -0
  740. data/mlx/mlx/export_impl.h +99 -0
  741. data/mlx/mlx/fast.cpp +941 -0
  742. data/mlx/mlx/fast.h +103 -0
  743. data/mlx/mlx/fast_primitives.h +427 -0
  744. data/mlx/mlx/fence.h +39 -0
  745. data/mlx/mlx/fft.cpp +262 -0
  746. data/mlx/mlx/fft.h +159 -0
  747. data/mlx/mlx/graph_utils.cpp +175 -0
  748. data/mlx/mlx/graph_utils.h +67 -0
  749. data/mlx/mlx/io/CMakeLists.txt +25 -0
  750. data/mlx/mlx/io/gguf.cpp +470 -0
  751. data/mlx/mlx/io/gguf.h +20 -0
  752. data/mlx/mlx/io/gguf_quants.cpp +164 -0
  753. data/mlx/mlx/io/load.cpp +397 -0
  754. data/mlx/mlx/io/load.h +175 -0
  755. data/mlx/mlx/io/no_gguf.cpp +20 -0
  756. data/mlx/mlx/io/no_safetensors.cpp +37 -0
  757. data/mlx/mlx/io/safetensors.cpp +234 -0
  758. data/mlx/mlx/io.h +61 -0
  759. data/mlx/mlx/linalg.cpp +708 -0
  760. data/mlx/mlx/linalg.h +115 -0
  761. data/mlx/mlx/memory.h +80 -0
  762. data/mlx/mlx/mlx.h +25 -0
  763. data/mlx/mlx/ops.cpp +6094 -0
  764. data/mlx/mlx/ops.h +1610 -0
  765. data/mlx/mlx/primitives.cpp +5850 -0
  766. data/mlx/mlx/primitives.h +2525 -0
  767. data/mlx/mlx/random.cpp +492 -0
  768. data/mlx/mlx/random.h +283 -0
  769. data/mlx/mlx/scheduler.cpp +73 -0
  770. data/mlx/mlx/scheduler.h +189 -0
  771. data/mlx/mlx/small_vector.h +540 -0
  772. data/mlx/mlx/stream.h +42 -0
  773. data/mlx/mlx/threadpool.h +133 -0
  774. data/mlx/mlx/transforms.cpp +1065 -0
  775. data/mlx/mlx/transforms.h +231 -0
  776. data/mlx/mlx/transforms_impl.h +88 -0
  777. data/mlx/mlx/types/bf16.h +187 -0
  778. data/mlx/mlx/types/complex.h +113 -0
  779. data/mlx/mlx/types/fp16.h +234 -0
  780. data/mlx/mlx/types/half_types.h +58 -0
  781. data/mlx/mlx/types/limits.h +70 -0
  782. data/mlx/mlx/utils.cpp +302 -0
  783. data/mlx/mlx/utils.h +174 -0
  784. data/mlx/mlx/version.cpp +11 -0
  785. data/mlx/mlx/version.h +22 -0
  786. data/mlx/mlx.pc.in +52 -0
  787. data/mlx/pyproject.toml +7 -0
  788. data/mlx/python/mlx/__main__.py +27 -0
  789. data/mlx/python/mlx/_distributed_utils/common.py +135 -0
  790. data/mlx/python/mlx/_distributed_utils/config.py +631 -0
  791. data/mlx/python/mlx/_distributed_utils/launch.py +570 -0
  792. data/mlx/python/mlx/_reprlib_fix.py +16 -0
  793. data/mlx/python/mlx/_stub_patterns.txt +36 -0
  794. data/mlx/python/mlx/extension.py +88 -0
  795. data/mlx/python/mlx/nn/__init__.py +5 -0
  796. data/mlx/python/mlx/nn/init.py +441 -0
  797. data/mlx/python/mlx/nn/layers/__init__.py +105 -0
  798. data/mlx/python/mlx/nn/layers/activations.py +661 -0
  799. data/mlx/python/mlx/nn/layers/base.py +675 -0
  800. data/mlx/python/mlx/nn/layers/containers.py +24 -0
  801. data/mlx/python/mlx/nn/layers/convolution.py +232 -0
  802. data/mlx/python/mlx/nn/layers/convolution_transpose.py +242 -0
  803. data/mlx/python/mlx/nn/layers/distributed.py +601 -0
  804. data/mlx/python/mlx/nn/layers/dropout.py +137 -0
  805. data/mlx/python/mlx/nn/layers/embedding.py +53 -0
  806. data/mlx/python/mlx/nn/layers/linear.py +180 -0
  807. data/mlx/python/mlx/nn/layers/normalization.py +363 -0
  808. data/mlx/python/mlx/nn/layers/pooling.py +398 -0
  809. data/mlx/python/mlx/nn/layers/positional_encoding.py +162 -0
  810. data/mlx/python/mlx/nn/layers/quantized.py +426 -0
  811. data/mlx/python/mlx/nn/layers/recurrent.py +289 -0
  812. data/mlx/python/mlx/nn/layers/transformer.py +354 -0
  813. data/mlx/python/mlx/nn/layers/upsample.py +277 -0
  814. data/mlx/python/mlx/nn/losses.py +610 -0
  815. data/mlx/python/mlx/nn/utils.py +165 -0
  816. data/mlx/python/mlx/optimizers/__init__.py +4 -0
  817. data/mlx/python/mlx/optimizers/optimizers.py +976 -0
  818. data/mlx/python/mlx/optimizers/schedulers.py +158 -0
  819. data/mlx/python/mlx/py.typed +1 -0
  820. data/mlx/python/mlx/utils.py +325 -0
  821. data/mlx/python/src/CMakeLists.txt +96 -0
  822. data/mlx/python/src/array.cpp +1525 -0
  823. data/mlx/python/src/buffer.h +124 -0
  824. data/mlx/python/src/constants.cpp +15 -0
  825. data/mlx/python/src/convert.cpp +504 -0
  826. data/mlx/python/src/convert.h +50 -0
  827. data/mlx/python/src/cuda.cpp +19 -0
  828. data/mlx/python/src/device.cpp +98 -0
  829. data/mlx/python/src/distributed.cpp +352 -0
  830. data/mlx/python/src/export.cpp +356 -0
  831. data/mlx/python/src/fast.cpp +627 -0
  832. data/mlx/python/src/fft.cpp +514 -0
  833. data/mlx/python/src/indexing.cpp +1016 -0
  834. data/mlx/python/src/indexing.h +41 -0
  835. data/mlx/python/src/linalg.cpp +663 -0
  836. data/mlx/python/src/load.cpp +531 -0
  837. data/mlx/python/src/load.h +51 -0
  838. data/mlx/python/src/memory.cpp +125 -0
  839. data/mlx/python/src/metal.cpp +98 -0
  840. data/mlx/python/src/mlx.cpp +51 -0
  841. data/mlx/python/src/mlx_func.cpp +116 -0
  842. data/mlx/python/src/mlx_func.h +31 -0
  843. data/mlx/python/src/ops.cpp +5545 -0
  844. data/mlx/python/src/random.cpp +516 -0
  845. data/mlx/python/src/small_vector.h +76 -0
  846. data/mlx/python/src/stream.cpp +147 -0
  847. data/mlx/python/src/transforms.cpp +1542 -0
  848. data/mlx/python/src/trees.cpp +311 -0
  849. data/mlx/python/src/trees.h +62 -0
  850. data/mlx/python/src/utils.cpp +98 -0
  851. data/mlx/python/src/utils.h +78 -0
  852. data/mlx/python/tests/__main__.py +5 -0
  853. data/mlx/python/tests/cuda_skip.py +62 -0
  854. data/mlx/python/tests/mlx_distributed_tests.py +314 -0
  855. data/mlx/python/tests/mlx_tests.py +116 -0
  856. data/mlx/python/tests/mpi_test_distributed.py +142 -0
  857. data/mlx/python/tests/nccl_test_distributed.py +52 -0
  858. data/mlx/python/tests/ring_test_distributed.py +131 -0
  859. data/mlx/python/tests/test_array.py +2139 -0
  860. data/mlx/python/tests/test_autograd.py +880 -0
  861. data/mlx/python/tests/test_bf16.py +196 -0
  862. data/mlx/python/tests/test_blas.py +1429 -0
  863. data/mlx/python/tests/test_compile.py +1277 -0
  864. data/mlx/python/tests/test_constants.py +41 -0
  865. data/mlx/python/tests/test_conv.py +1198 -0
  866. data/mlx/python/tests/test_conv_transpose.py +810 -0
  867. data/mlx/python/tests/test_device.py +150 -0
  868. data/mlx/python/tests/test_double.py +306 -0
  869. data/mlx/python/tests/test_einsum.py +363 -0
  870. data/mlx/python/tests/test_eval.py +200 -0
  871. data/mlx/python/tests/test_export_import.py +614 -0
  872. data/mlx/python/tests/test_fast.py +923 -0
  873. data/mlx/python/tests/test_fast_sdpa.py +647 -0
  874. data/mlx/python/tests/test_fft.py +323 -0
  875. data/mlx/python/tests/test_graph.py +37 -0
  876. data/mlx/python/tests/test_init.py +139 -0
  877. data/mlx/python/tests/test_linalg.py +621 -0
  878. data/mlx/python/tests/test_load.py +447 -0
  879. data/mlx/python/tests/test_losses.py +427 -0
  880. data/mlx/python/tests/test_memory.py +77 -0
  881. data/mlx/python/tests/test_nn.py +1986 -0
  882. data/mlx/python/tests/test_ops.py +3261 -0
  883. data/mlx/python/tests/test_optimizers.py +584 -0
  884. data/mlx/python/tests/test_quantized.py +1160 -0
  885. data/mlx/python/tests/test_random.py +392 -0
  886. data/mlx/python/tests/test_reduce.py +223 -0
  887. data/mlx/python/tests/test_tree.py +96 -0
  888. data/mlx/python/tests/test_upsample.py +100 -0
  889. data/mlx/python/tests/test_vmap.py +860 -0
  890. data/mlx/setup.py +315 -0
  891. data/mlx/tests/CMakeLists.txt +44 -0
  892. data/mlx/tests/allocator_tests.cpp +41 -0
  893. data/mlx/tests/arg_reduce_tests.cpp +204 -0
  894. data/mlx/tests/array_tests.cpp +663 -0
  895. data/mlx/tests/autograd_tests.cpp +1399 -0
  896. data/mlx/tests/blas_tests.cpp +110 -0
  897. data/mlx/tests/compile_tests.cpp +818 -0
  898. data/mlx/tests/creations_tests.cpp +239 -0
  899. data/mlx/tests/custom_vjp_tests.cpp +55 -0
  900. data/mlx/tests/device_tests.cpp +35 -0
  901. data/mlx/tests/einsum_tests.cpp +85 -0
  902. data/mlx/tests/eval_tests.cpp +93 -0
  903. data/mlx/tests/export_import_tests.cpp +164 -0
  904. data/mlx/tests/fft_tests.cpp +366 -0
  905. data/mlx/tests/gpu_tests.cpp +523 -0
  906. data/mlx/tests/linalg_tests.cpp +639 -0
  907. data/mlx/tests/load_tests.cpp +270 -0
  908. data/mlx/tests/ops_tests.cpp +4159 -0
  909. data/mlx/tests/random_tests.cpp +716 -0
  910. data/mlx/tests/scheduler_tests.cpp +121 -0
  911. data/mlx/tests/tests.cpp +26 -0
  912. data/mlx/tests/utils_tests.cpp +67 -0
  913. data/mlx/tests/vmap_tests.cpp +547 -0
  914. metadata +958 -0
data/mlx/mlx/backend/metal/kernels/quantized_nax.h
@@ -0,0 +1,1705 @@
+ // Copyright © 2023-2024 Apple Inc.
+
+ #include <metal_simdgroup>
+ #include <metal_stdlib>
+
+ using namespace metal;
+ using namespace mlx::steel;
+
+ constant bool align_M [[function_constant(200)]];
+ constant bool align_N [[function_constant(201)]];
+ constant bool align_K [[function_constant(202)]];
+
+ using namespace metal;
+
+ #define MLX_MTL_CONST static constant constexpr const
+
+ MLX_MTL_CONST int SIMD_SIZE = 32;
+ MLX_MTL_CONST int QUAD_SIZE = 4;
+
+ template <int bits, int wsize = 8>
+ inline constexpr short get_pack_factor() {
+   return (bits == 3 || bits == 5) ? 8 : (bits == 6 ? 4 : wsize / bits);
+ }
+
+ template <int bits, int wsize = 8>
+ inline constexpr short get_bytes_per_pack() {
+   constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
+   return power_of_2_bits ? (wsize / 8) : (bits == 5 ? 5 : 3);
+ }
+
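The two constexpr helpers above define the packing geometry: a "pack" is the smallest run of quantized values that fills a whole number of bytes (eight 3-bit values in 3 bytes, eight 5-bit values in 5 bytes, four 6-bit values in 3 bytes, and so on). A minimal host-side C++ sketch, mirroring the helpers above under that assumed invariant (not part of the package), checks that pack_factor * bits == bytes_per_pack * 8 for every supported width:

#include <cstdio>

// Host-side mirrors of the Metal helpers above (assumed semantics).
constexpr int pack_factor(int bits, int wsize = 8) {
  return (bits == 3 || bits == 5) ? 8 : (bits == 6 ? 4 : wsize / bits);
}
constexpr int bytes_per_pack(int bits, int wsize = 8) {
  return ((bits & (bits - 1)) == 0) ? (wsize / 8) : (bits == 5 ? 5 : 3);
}

int main() {
  for (int bits : {2, 3, 4, 5, 6, 8}) {
    // Invariant: each pack covers a whole number of bytes.
    std::printf(
        "bits=%d values/pack=%d bytes/pack=%d %s\n",
        bits,
        pack_factor(bits),
        bytes_per_pack(bits),
        pack_factor(bits) * bits == bytes_per_pack(bits) * 8 ? "ok" : "BAD");
  }
  return 0;
}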
+ template <typename T, typename U, int values_per_thread, int bits>
+ inline U load_vector(const device T* x, thread U* x_thread) {
+   static_assert(
+       bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 ||
+           bits == 8,
+       "Template undefined for bits not in {2, 3, 4, 5, 6, 8}");
+
+   U sum = 0;
+
+   if (bits == 2) {
+     for (int i = 0; i < values_per_thread; i += 4) {
+       sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
+       x_thread[i] = x[i];
+       x_thread[i + 1] = x[i + 1] / 4.0f;
+       x_thread[i + 2] = x[i + 2] / 16.0f;
+       x_thread[i + 3] = x[i + 3] / 64.0f;
+     }
+   }
+
+   else if (bits == 3) {
+     for (int i = 0; i < values_per_thread; i += 8) {
+       sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] +
+           x[i + 6] + x[i + 7];
+       x_thread[i] = x[i];
+       x_thread[i + 1] = x[i + 1] / 8.0f;
+       x_thread[i + 2] = x[i + 2] / 64.0f;
+       x_thread[i + 3] = x[i + 3] / 2.0f;
+       x_thread[i + 4] = x[i + 4] / 16.0f;
+       x_thread[i + 5] = x[i + 5] / 128.0f;
+       x_thread[i + 6] = x[i + 6] / 4.0f;
+       x_thread[i + 7] = x[i + 7] / 32.0f;
+     }
+   }
+
+   else if (bits == 4) {
+     for (int i = 0; i < values_per_thread; i += 4) {
+       sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
+       x_thread[i] = x[i];
+       x_thread[i + 1] = x[i + 1] / 16.0f;
+       x_thread[i + 2] = x[i + 2] / 256.0f;
+       x_thread[i + 3] = x[i + 3] / 4096.0f;
+     }
+   }
+
+   else if (bits == 5) {
+     for (int i = 0; i < values_per_thread; i += 8) {
+       sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] +
+           x[i + 6] + x[i + 7];
+       x_thread[i] = x[i];
+       x_thread[i + 1] = x[i + 1] / 32.0f;
+       x_thread[i + 2] = x[i + 2] / 4.0f;
+       x_thread[i + 3] = x[i + 3] / 128.0f;
+       x_thread[i + 4] = x[i + 4] / 16.0f;
+       x_thread[i + 5] = x[i + 5] / 2.0f;
+       x_thread[i + 6] = x[i + 6] / 64.0f;
+       x_thread[i + 7] = x[i + 7] / 8.0f;
+     }
+   }
+
+   else if (bits == 6) {
+     for (int i = 0; i < values_per_thread; i += 4) {
+       sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
+       x_thread[i] = x[i];
+       x_thread[i + 1] = x[i + 1] / 64.0f;
+       x_thread[i + 2] = x[i + 2] / 16.0f;
+       x_thread[i + 3] = x[i + 3] / 4.0f;
+     }
+   }
+
+   else if (bits == 8) {
+     for (int i = 0; i < values_per_thread; i++) {
+       sum += x[i];
+       x_thread[i] = x[i];
+     }
+   }
+
+   return sum;
+ }
+
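Note the per-position divisors in load_vector: each activation is pre-divided by 2 raised to the bit offset its packed weight occupies within its byte (1, 4, 16, 64 for 2-bit; 1, 8, 64, 2, 16, 128, 4, 32 for 3-bit, following the 3-bit layout across byte boundaries). This lets qdot below multiply by the masked but unshifted bit-fields of the packed word, and the returned sum of the raw activations lets the affine bias be folded in once at the end; a host-side sketch after qdot works this through for the 2-bit case.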
+ template <typename T, typename U, int values_per_thread, int bits>
+ inline U load_vector_safe(const device T* x, thread U* x_thread, int N) {
+   static_assert(
+       bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 ||
+           bits == 8,
+       "Template undefined for bits not in {2, 3, 4, 5, 6, 8}");
+
+   U sum = 0;
+
+   if (bits == 2) {
+     for (int i = 0; i < N; i += 4) {
+       sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
+       x_thread[i] = x[i];
+       x_thread[i + 1] = x[i + 1] / 4.0f;
+       x_thread[i + 2] = x[i + 2] / 16.0f;
+       x_thread[i + 3] = x[i + 3] / 64.0f;
+     }
+   }
+
+   else if (bits == 3) {
+     for (int i = 0; i < N; i += 8) {
+       sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] +
+           x[i + 6] + x[i + 7];
+
+       x_thread[i] = x[i];
+       x_thread[i + 1] = x[i + 1] / 8.0f;
+       x_thread[i + 2] = x[i + 2] / 64.0f;
+       x_thread[i + 3] = x[i + 3] / 2.0f;
+       x_thread[i + 4] = x[i + 4] / 16.0f;
+       x_thread[i + 5] = x[i + 5] / 128.0f;
+       x_thread[i + 6] = x[i + 6] / 4.0f;
+       x_thread[i + 7] = x[i + 7] / 32.0f;
+     }
+   }
+
+   else if (bits == 4) {
+     for (int i = 0; i < N; i += 4) {
+       sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
+       x_thread[i] = x[i];
+       x_thread[i + 1] = x[i + 1] / 16.0f;
+       x_thread[i + 2] = x[i + 2] / 256.0f;
+       x_thread[i + 3] = x[i + 3] / 4096.0f;
+     }
+   }
+
+   else if (bits == 5) {
+     for (int i = 0; i < N; i += 8) {
+       sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] +
+           x[i + 6] + x[i + 7];
+       x_thread[i] = x[i];
+       x_thread[i + 1] = x[i + 1] / 32.0f;
+       x_thread[i + 2] = x[i + 2] / 4.0f;
+       x_thread[i + 3] = x[i + 3] / 128.0f;
+       x_thread[i + 4] = x[i + 4] / 16.0f;
+       x_thread[i + 5] = x[i + 5] / 2.0f;
+       x_thread[i + 6] = x[i + 6] / 64.0f;
+       x_thread[i + 7] = x[i + 7] / 8.0f;
+     }
+   }
+
+   else if (bits == 6) {
+     for (int i = 0; i < N; i += 4) {
+       sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
+       x_thread[i] = x[i];
+       x_thread[i + 1] = x[i + 1] / 64.0f;
+       x_thread[i + 2] = x[i + 2] / 16.0f;
+       x_thread[i + 3] = x[i + 3] / 4.0f;
+     }
+   }
+
+   else if (bits == 8) {
+     for (int i = 0; i < N; i++) {
+       sum += x[i];
+       x_thread[i] = x[i];
+     }
+   }
+
+   for (int i = N; i < values_per_thread; i++) {
+     x_thread[i] = 0;
+   }
+
+   return sum;
+ }
+
194
+ template <typename U, int values_per_thread, int bits>
195
+ inline U qdot(
196
+ const device uint8_t* w,
197
+ const thread U* x_thread,
198
+ U scale,
199
+ U bias,
200
+ U sum) {
201
+ static_assert(
202
+ bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 ||
203
+ bits == 8,
204
+ "Template undefined for bits not in {2, 3, 4, 5, 6, 8}");
205
+
206
+ U accum = 0;
207
+
208
+ if (bits == 2) {
209
+ for (int i = 0; i < (values_per_thread / 4); i++) {
210
+ accum +=
211
+ (x_thread[4 * i] * (w[i] & 0x03) +
212
+ x_thread[4 * i + 1] * (w[i] & 0x0c) +
213
+ x_thread[4 * i + 2] * (w[i] & 0x30) +
214
+ x_thread[4 * i + 3] * (w[i] & 0xc0));
215
+ }
216
+ }
217
+
218
+ else if (bits == 3) {
219
+ for (int i = 0; i < (values_per_thread / 8); i++) {
220
+ x_thread += 8 * i;
221
+ w += 3 * i;
222
+
223
+ accum += (w[0] & 0x07) * x_thread[0];
224
+ accum += (w[0] & 0x38) * x_thread[1];
225
+ accum += (w[0] & 0xc0) * x_thread[2];
226
+ accum += (w[1] & 0x01) * (x_thread[2] * 256.0f);
227
+
228
+ accum += (w[1] & 0x0e) * x_thread[3];
229
+ accum += (w[1] & 0x70) * x_thread[4];
230
+ accum += (w[1] & 0x80) * x_thread[5];
231
+ accum += (w[2] & 0x03) * (x_thread[5] * 256.0f);
232
+
233
+ accum += (w[2] & 0x1c) * x_thread[6];
234
+ accum += (w[2] & 0xe0) * x_thread[7];
235
+ }
236
+ }
237
+
238
+ else if (bits == 4) {
239
+ const device uint16_t* ws = (const device uint16_t*)w;
240
+ for (int i = 0; i < (values_per_thread / 4); i++) {
241
+ accum +=
242
+ (x_thread[4 * i] * (ws[i] & 0x000f) +
243
+ x_thread[4 * i + 1] * (ws[i] & 0x00f0) +
244
+ x_thread[4 * i + 2] * (ws[i] & 0x0f00) +
245
+ x_thread[4 * i + 3] * (ws[i] & 0xf000));
246
+ }
247
+ }
248
+
249
+ else if (bits == 5) {
250
+ for (int i = 0; i < (values_per_thread / 8); i++) {
251
+ x_thread += 8 * i;
252
+ w += 5 * i;
253
+
254
+ accum += (w[0] & 0x1f) * x_thread[0];
255
+ accum += (w[0] & 0xe0) * x_thread[1];
256
+ accum += (w[1] & 0x3) * (x_thread[1] * 256.0f);
257
+ accum += (w[1] & 0x7c) * x_thread[2];
258
+ accum += (w[1] & 0x80) * x_thread[3];
259
+ accum += (w[2] & 0xf) * (x_thread[3] * 256.0f);
260
+ accum += (w[2] & 0xf0) * x_thread[4];
261
+ accum += (w[3] & 0x1) * (x_thread[4] * 256.0f);
262
+ accum += (w[3] & 0x3e) * x_thread[5];
263
+ accum += (w[3] & 0xc0) * x_thread[6];
264
+ accum += (w[4] & 0x7) * (x_thread[6] * 256.0f);
265
+ accum += (w[4] & 0xf8) * x_thread[7];
266
+ }
267
+ }
268
+
269
+ else if (bits == 6) {
270
+ for (int i = 0; i < (values_per_thread / 4); i++) {
271
+ x_thread += 4 * i;
272
+ w += 3 * i;
273
+
274
+ accum += (w[0] & 0x3f) * x_thread[0];
275
+
276
+ accum += (w[0] & 0xc0) * x_thread[1];
277
+ accum += (w[1] & 0x0f) * (x_thread[1] * 256.0f);
278
+
279
+ accum += (w[1] & 0xf0) * x_thread[2];
280
+ accum += (w[2] & 0x03) * (x_thread[2] * 256.0f);
281
+
282
+ accum += (w[2] & 0xfc) * x_thread[3];
283
+ }
284
+ }
285
+
286
+ else if (bits == 8) {
287
+ for (int i = 0; i < values_per_thread; i++) {
288
+ accum += x_thread[i] * w[i];
289
+ }
290
+ }
291
+
292
+ return scale * accum + sum * bias;
293
+ }
294
+
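qdot never materializes dequantized weights. It accumulates accum = sum_j x_thread[j] * (masked bits) = sum_j x[j] * q[j], with the `* 256.0f` terms re-basing bits that spill into the next byte, and applies the affine correction once per group: sum_j x_j * (s * q_j + b) = s * sum_j x_j q_j + b * sum_j x_j. That identity is why load_vector also returns the activation sum and qdot takes it as `sum`. A plain-C++ reference check for the 4-bit path (a sketch; the packing helper and test values are illustrative, not the MLX host API):

#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstring>

int main() {
  // 8 activations and 8 4-bit quantized weights in one scale/bias group.
  float x[8] = {0.5f, -1.0f, 2.0f, 0.25f, 1.5f, -0.75f, 3.0f, -2.0f};
  uint8_t q[8] = {1, 15, 7, 0, 12, 3, 9, 5};
  float scale = 0.1f, bias = -0.8f;

  // Pack two nibbles per byte, low nibble first (the layout qdot expects).
  uint8_t w[4];
  for (int i = 0; i < 4; i++)
    w[i] = uint8_t(q[2 * i] | (q[2 * i + 1] << 4));

  // Reference: dequantize, then dot.
  float ref = 0.f;
  for (int j = 0; j < 8; j++)
    ref += x[j] * (scale * q[j] + bias);

  // Kernel-style: pre-scaled activations, unshifted masks over uint16 words,
  // then a single affine correction using sum = sum_j x[j].
  float xt[8], sum = 0.f;
  for (int j = 0; j < 8; j++) {
    xt[j] = x[j] / float(1u << ((4 * j) % 16));
    sum += x[j];
  }
  uint16_t ws[2];
  std::memcpy(ws, w, sizeof(ws)); // assumes little-endian, as on Apple GPUs
  float accum = 0.f;
  for (int i = 0; i < 2; i++) {
    accum += xt[4 * i] * (ws[i] & 0x000f);
    accum += xt[4 * i + 1] * (ws[i] & 0x00f0);
    accum += xt[4 * i + 2] * (ws[i] & 0x0f00);
    accum += xt[4 * i + 3] * (ws[i] & 0xf000);
  }
  float out = scale * accum + sum * bias;

  assert(std::fabs(out - ref) < 1e-4f);
  printf("qdot-style result %.6f matches reference %.6f\n", out, ref);
}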
+template <typename U, int values_per_thread, int bits>
+inline U qdot_safe(
+    const device uint8_t* w,
+    const thread U* x_thread,
+    U scale,
+    U bias,
+    U sum,
+    int N) {
+  static_assert(
+      bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 ||
+          bits == 8,
+      "Template undefined for bits not in {2, 3, 4, 5, 6, 8}");
+
+  U accum = 0;
+
+  if (bits == 2) {
+    for (int i = 0; i < (N / 4); i++) {
+      accum +=
+          (x_thread[4 * i] * (w[i] & 0x03) +
+           x_thread[4 * i + 1] * (w[i] & 0x0c) +
+           x_thread[4 * i + 2] * (w[i] & 0x30) +
+           x_thread[4 * i + 3] * (w[i] & 0xc0));
+    }
+  }
+
+  else if (bits == 3) {
+    for (int i = 0; i < (N / 8); i++) {
+      x_thread += 8 * i;
+      w += 3 * i;
+
+      accum += (w[0] & 0x07) * x_thread[0];
+      accum += (w[0] & 0x38) * x_thread[1];
+      accum += (w[0] & 0xc0) * x_thread[2];
+      accum += (w[1] & 0x01) * (x_thread[2] * 256.0f);
+
+      accum += (w[1] & 0x0e) * x_thread[3];
+      accum += (w[1] & 0x70) * x_thread[4];
+      accum += (w[1] & 0x80) * x_thread[5];
+      accum += (w[2] & 0x03) * (x_thread[5] * 256.0f);
+
+      accum += (w[2] & 0x1c) * x_thread[6];
+      accum += (w[2] & 0xe0) * x_thread[7];
+    }
+  }
+
+  else if (bits == 4) {
+    const device uint16_t* ws = (const device uint16_t*)w;
+    for (int i = 0; i < (N / 4); i++) {
+      accum +=
+          (x_thread[4 * i] * (ws[i] & 0x000f) +
+           x_thread[4 * i + 1] * (ws[i] & 0x00f0) +
+           x_thread[4 * i + 2] * (ws[i] & 0x0f00) +
+           x_thread[4 * i + 3] * (ws[i] & 0xf000));
+    }
+  }
+
+  else if (bits == 5) {
+    for (int i = 0; i < (N / 8); i++) {
+      x_thread += 8 * i;
+      w += 5 * i;
+
+      accum += (w[0] & 0x1f) * x_thread[0];
+      accum += (w[0] & 0xe0) * x_thread[1];
+      accum += (w[1] & 0x3) * (x_thread[1] * 256.0f);
+      accum += (w[1] & 0x7c) * x_thread[2];
+      accum += (w[1] & 0x80) * x_thread[3];
+      accum += (w[2] & 0xf) * (x_thread[3] * 256.0f);
+      accum += (w[2] & 0xf0) * x_thread[4];
+      accum += (w[3] & 0x1) * (x_thread[4] * 256.0f);
+      accum += (w[3] & 0x3e) * x_thread[5];
+      accum += (w[3] & 0xc0) * x_thread[6];
+      accum += (w[4] & 0x7) * (x_thread[6] * 256.0f);
+      accum += (w[4] & 0xf8) * x_thread[7];
+    }
+  }
+
+  else if (bits == 6) {
+    for (int i = 0; i < (N / 4); i++) {
+      x_thread += 4 * i;
+      w += 3 * i;
+
+      accum += (w[0] & 0x3f) * x_thread[0];
+
+      accum += (w[0] & 0xc0) * x_thread[1];
+      accum += (w[1] & 0x0f) * (x_thread[1] * 256.0f);
+
+      accum += (w[1] & 0xf0) * x_thread[2];
+      accum += (w[2] & 0x03) * (x_thread[2] * 256.0f);
+
+      accum += (w[2] & 0xfc) * x_thread[3];
+    }
+  }
+
+  else if (bits == 8) {
+    for (int i = 0; i < N; i++) {
+      accum += x_thread[i] * w[i];
+    }
+  }
+
+  return scale * accum + sum * bias;
+}
+
+template <typename U, int values_per_thread, int bits>
+inline void
+qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* result) {
+  static_assert(
+      bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 ||
+          bits == 8,
+      "Template undefined for bits not in {2, 3, 4, 5, 6, 8}");
+
+  if (bits == 2) {
+    U s[4] = {scale, scale / 4.0f, scale / 16.0f, scale / 64.0f};
+    for (int i = 0; i < (values_per_thread / 4); i++) {
+      result[4 * i] += x * (s[0] * (w[i] & 0x03) + bias);
+      result[4 * i + 1] += x * (s[1] * (w[i] & 0x0c) + bias);
+      result[4 * i + 2] += x * (s[2] * (w[i] & 0x30) + bias);
+      result[4 * i + 3] += x * (s[3] * (w[i] & 0xc0) + bias);
+    }
+  }
+
+  else if (bits == 3) {
+    for (int i = 0; i < (values_per_thread / 8); i++) {
+      uint8_t w0 = w[3 * i];
+      uint8_t w1 = w[3 * i + 1];
+      uint8_t w2 = w[3 * i + 2];
+
+      result[8 * i] += x * ((w0 & 0x7) * scale + bias);
+      result[8 * i + 1] += x * (((w0 & 0x38) >> 3) * scale + bias);
+      result[8 * i + 2] +=
+          x * ((((w0 & 0xc0) >> 6) + ((w1 & 0x1) << 2)) * scale + bias);
+      result[8 * i + 3] += x * (((w1 & 0xe) >> 1) * scale + bias);
+      result[8 * i + 4] += x * (((w1 & 0x70) >> 4) * scale + bias);
+      result[8 * i + 5] +=
+          x * ((((w1 & 0x80) >> 7) + ((w2 & 0x3) << 1)) * scale + bias);
+      result[8 * i + 6] += x * (((w2 & 0x1c) >> 2) * scale + bias);
+      result[8 * i + 7] += x * (((w2 & 0xe0) >> 5) * scale + bias);
+    }
+  }
+
+  else if (bits == 4) {
+    U s[2] = {scale, scale / 16.0f};
+    for (int i = 0; i < (values_per_thread / 2); i++) {
+      result[2 * i] += x * (s[0] * (w[i] & 0x0f) + bias);
+      result[2 * i + 1] += x * (s[1] * (w[i] & 0xf0) + bias);
+    }
+  }
+
+  else if (bits == 5) {
+    for (int i = 0; i < (values_per_thread / 8); i++) {
+      uint8_t w0 = w[5 * i];
+      uint8_t w1 = w[5 * i + 1];
+      uint8_t w2 = w[5 * i + 2];
+      uint8_t w3 = w[5 * i + 3];
+      uint8_t w4 = w[5 * i + 4];
+      result[8 * i] += x * ((w0 & 0x1f) * scale + bias);
+      result[8 * i + 1] +=
+          x * ((((w0 & 0xe0) >> 5) + ((w1 & 0x3) << 3)) * scale + bias);
+      result[8 * i + 2] += x * (((w1 & 0x7c) >> 2) * scale + bias);
+      result[8 * i + 3] +=
+          x * ((((w1 & 0x80) >> 7) + ((w2 & 0xf) << 1)) * scale + bias);
+      result[8 * i + 4] +=
+          x * ((((w2 & 0xf0) >> 4) + ((w3 & 0x1) << 4)) * scale + bias);
+      result[8 * i + 5] += x * (((w3 & 0x3e) >> 1) * scale + bias);
+      result[8 * i + 6] +=
+          x * ((((w3 & 0xc0) >> 6) + ((w4 & 0x7) << 2)) * scale + bias);
+      result[8 * i + 7] += x * (((w4 & 0xf8) >> 3) * scale + bias);
+    }
+  }
+
+  else if (bits == 6) {
+    for (int i = 0; i < (values_per_thread / 4); i++) {
+      uint8_t w0 = w[3 * i];
+      uint8_t w1 = w[3 * i + 1];
+      uint8_t w2 = w[3 * i + 2];
+
+      result[4 * i] += x * ((w0 & 0x3f) * scale + bias);
+      result[4 * i + 1] +=
+          x * ((((w0 >> 6) & 0x03) + ((w1 & 0x0f) << 2)) * scale + bias);
+      result[4 * i + 2] +=
+          x * ((((w1 >> 4) & 0x0f) + ((w2 & 0x03) << 4)) * scale + bias);
+      result[4 * i + 3] += x * (((w2 >> 2) & 0x3f) * scale + bias);
+    }
+  }
+
+  else if (bits == 8) {
+    for (int i = 0; i < values_per_thread; i++) {
+      result[i] += x * (scale * w[i] + bias);
+    }
+  }
+}
+
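qouter (the vector-matrix side) shifts the bits back down instead of pre-scaling the activation, so its expressions double as documentation of the packing itself: 8 three-bit values occupy 3 bytes, and values 2 and 5 straddle a byte boundary, which is where the paired "(>> 6) + (<< 2)"-style terms come from. A round-trip sketch of that layout in plain C++ (the `pack3` helper is illustrative, not part of MLX):

#include <cassert>
#include <cstdint>
#include <cstdio>

// Pack 8 3-bit values into 3 bytes as one 24-bit little-endian bitstream.
void pack3(const uint8_t v[8], uint8_t w[3]) {
  uint32_t stream = 0;
  for (int j = 7; j >= 0; j--)
    stream = (stream << 3) | (v[j] & 0x7);
  w[0] = stream & 0xff;
  w[1] = (stream >> 8) & 0xff;
  w[2] = (stream >> 16) & 0xff;
}

int main() {
  uint8_t v[8] = {5, 3, 7, 0, 6, 1, 4, 2}, w[3], u[8];
  pack3(v, w);

  // Unpack with exactly the mask/shift pairs used by qouter's bits == 3 path.
  u[0] = w[0] & 0x7;
  u[1] = (w[0] & 0x38) >> 3;
  u[2] = ((w[0] & 0xc0) >> 6) + ((w[1] & 0x1) << 2); // straddles w[0]/w[1]
  u[3] = (w[1] & 0xe) >> 1;
  u[4] = (w[1] & 0x70) >> 4;
  u[5] = ((w[1] & 0x80) >> 7) + ((w[2] & 0x3) << 1); // straddles w[1]/w[2]
  u[6] = (w[2] & 0x1c) >> 2;
  u[7] = (w[2] & 0xe0) >> 5;

  for (int j = 0; j < 8; j++)
    assert(u[j] == v[j]);
  puts("3-bit pack/unpack round-trip OK");
}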
+template <typename U, int N, int bits>
+inline void
+dequantize(const device uint8_t* w, U scale, U bias, threadgroup U* w_local) {
+  static_assert(
+      bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 ||
+          bits == 8,
+      "Template undefined for bits not in {2, 3, 4, 5, 6, 8}");
+
+  if (bits == 2) {
+    U s[4] = {
+        scale,
+        scale / static_cast<U>(4.0f),
+        scale / static_cast<U>(16.0f),
+        scale / static_cast<U>(64.0f)};
+    for (int i = 0; i < (N / 4); i++) {
+      w_local[4 * i] = s[0] * (w[i] & 0x03) + bias;
+      w_local[4 * i + 1] = s[1] * (w[i] & 0x0c) + bias;
+      w_local[4 * i + 2] = s[2] * (w[i] & 0x30) + bias;
+      w_local[4 * i + 3] = s[3] * (w[i] & 0xc0) + bias;
+    }
+  }
+
+  else if (bits == 3) {
+    for (int i = 0; i < (N / 8); i++) {
+      w_local += 8 * i;
+      w += 3 * i;
+
+      w_local[0] = (w[0] & 0x7) * scale + bias;
+      w_local[1] = ((w[0] & 0x38) >> 3) * scale + bias;
+      w_local[2] = (((w[0] & 0xc0) >> 6) + ((w[1] & 0x1) << 2)) * scale + bias;
+      w_local[3] = ((w[1] & 0xe) >> 1) * scale + bias;
+      w_local[4] = ((w[1] & 0x70) >> 4) * scale + bias;
+      w_local[5] = (((w[1] & 0x80) >> 7) + ((w[2] & 0x3) << 1)) * scale + bias;
+      w_local[6] = ((w[2] & 0x1c) >> 2) * scale + bias;
+      w_local[7] = ((w[2] & 0xe0) >> 5) * scale + bias;
+    }
+  }
+
+  else if (bits == 4) {
+    U s[2] = {scale, scale / static_cast<U>(16.0f)};
+    for (int i = 0; i < (N / 2); i++) {
+      w_local[2 * i] = s[0] * (w[i] & 0x0f) + bias;
+      w_local[2 * i + 1] = s[1] * (w[i] & 0xf0) + bias;
+    }
+  }
+
+  else if (bits == 5) {
+    for (int i = 0; i < (N / 8); i++) {
+      w_local += 8 * i;
+      w += 5 * i;
+
+      w_local[0] = (w[0] & 0x1f) * scale + bias;
+      w_local[1] = (((w[0] & 0xe0) >> 5) + ((w[1] & 0x3) << 3)) * scale + bias;
+      w_local[2] = ((w[1] & 0x7c) >> 2) * scale + bias;
+      w_local[3] = (((w[1] & 0x80) >> 7) + ((w[2] & 0xf) << 1)) * scale + bias;
+      w_local[4] = (((w[2] & 0xf0) >> 4) + ((w[3] & 0x1) << 4)) * scale + bias;
+      w_local[5] = ((w[3] & 0x3e) >> 1) * scale + bias;
+      w_local[6] = (((w[3] & 0xc0) >> 6) + ((w[4] & 0x7) << 2)) * scale + bias;
+      w_local[7] = ((w[4] & 0xf8) >> 3) * scale + bias;
+    }
+  }
+
+  else if (bits == 6) {
+    for (int i = 0; i < (N / 4); i++) {
+      w_local += 4 * i;
+      w += 3 * i;
+      w_local[0] = (w[0] & 0x3f) * scale + bias;
+      w_local[1] = (((w[0] >> 6) & 0x03) + ((w[1] & 0x0f) << 2)) * scale + bias;
+      w_local[2] = (((w[1] >> 4) & 0x0f) + ((w[2] & 0x03) << 4)) * scale + bias;
+      w_local[3] = ((w[2] >> 2) & 0x3f) * scale + bias;
+    }
+  }
+
+  else if (bits == 8) {
+    for (int i = 0; i < N; i++) {
+      w_local[i] = scale * w[i] + bias;
+    }
+  }
+}
+
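dequantize writes a full pack of affine-dequantized values (w = scale * q + bias) into threadgroup memory, and its bits == 5 branch is the authority on the 5-bit layout: 8 values in 5 bytes, with values 1, 3, 4, and 6 split across byte boundaries. The same round-trip exercise for 5 bits (again a sketch; `pack5` is illustrative only):

#include <cassert>
#include <cstdint>
#include <cstdio>

// Pack 8 5-bit values into 5 bytes as one 40-bit little-endian bitstream.
void pack5(const uint8_t v[8], uint8_t w[5]) {
  uint64_t stream = 0;
  for (int j = 7; j >= 0; j--)
    stream = (stream << 5) | (v[j] & 0x1f);
  for (int k = 0; k < 5; k++)
    w[k] = (stream >> (8 * k)) & 0xff;
}

int main() {
  uint8_t v[8] = {17, 5, 31, 0, 22, 9, 14, 28}, w[5], u[8];
  pack5(v, w);

  // The mask/shift pairs from dequantize's bits == 5 branch.
  u[0] = w[0] & 0x1f;
  u[1] = ((w[0] & 0xe0) >> 5) + ((w[1] & 0x3) << 3);
  u[2] = (w[1] & 0x7c) >> 2;
  u[3] = ((w[1] & 0x80) >> 7) + ((w[2] & 0xf) << 1);
  u[4] = ((w[2] & 0xf0) >> 4) + ((w[3] & 0x1) << 4);
  u[5] = (w[3] & 0x3e) >> 1;
  u[6] = ((w[3] & 0xc0) >> 6) + ((w[4] & 0x7) << 2);
  u[7] = (w[4] & 0xf8) >> 3;

  for (int j = 0; j < 8; j++)
    assert(u[j] == v[j]);
  puts("5-bit pack/unpack round-trip OK");
}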
+template <
+    typename T,
+    short BROWS,
+    short BCOLS,
+    short dst_ld,
+    short reduction_dim,
+    short tgp_size,
+    short group_size,
+    short bits>
+struct QuantizedBlockLoader {
+  static_assert(
+      BCOLS <= group_size,
+      "The group size should be larger than the columns");
+  static_assert(
+      group_size % BCOLS == 0,
+      "The group size should be divisible by the columns");
+  static_assert(
+      bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 ||
+          bits == 8,
+      "Template undefined for bits not in {2, 3, 4, 5, 6, 8}");
+
+  MLX_MTL_CONST short pack_factor = get_pack_factor<bits, 8>();
+  MLX_MTL_CONST short bytes_per_pack = get_bytes_per_pack<bits>();
+  MLX_MTL_CONST short BCOLS_PACKED = BCOLS / pack_factor;
+  MLX_MTL_CONST short n_reads =
+      (BCOLS_PACKED * BROWS < tgp_size) ? 1 : (BCOLS_PACKED * BROWS) / tgp_size;
+  MLX_MTL_CONST short group_steps = group_size / BCOLS;
+
+  const int src_ld;
+  const int tile_stride;
+  short group_step_cnt;
+  const int group_stride;
+
+  const short thread_idx;
+  const short bi;
+  const short bj;
+
+  threadgroup T* dst;
+  const device uint8_t* src;
+  const device T* scales;
+  const device T* biases;
+
+  QuantizedBlockLoader(
+      const device uint8_t* src_,
+      const device T* scales_,
+      const device T* biases_,
+      const int src_ld_,
+      threadgroup T* dst_,
+      ushort simd_group_id [[simdgroup_index_in_threadgroup]],
+      ushort simd_lane_id [[thread_index_in_simdgroup]])
+      : src_ld(src_ld_),
+        tile_stride(
+            reduction_dim ? BCOLS_PACKED * bytes_per_pack
+                          : BROWS * src_ld * bytes_per_pack / pack_factor),
+        group_step_cnt(0),
+        group_stride(BROWS * src_ld / group_size),
+        thread_idx(simd_group_id * 32 + simd_lane_id),
+        bi(n_reads * thread_idx / BCOLS_PACKED),
+        bj((n_reads * thread_idx) % BCOLS_PACKED),
+        dst(dst_ + bi * dst_ld + bj * pack_factor),
+        src(src_ + bi * src_ld * bytes_per_pack / pack_factor +
+            bj * bytes_per_pack),
+        scales(scales_ + bi * src_ld / group_size),
+        biases(biases_ + bi * src_ld / group_size) {}
+
+  void load_unsafe() const {
+    if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) {
+      return;
+    }
+
+    T scale = *scales;
+    T bias = *biases;
+    for (int i = 0; i < n_reads; i++) {
+      dequantize<T, pack_factor, bits>(
+          src + i * bytes_per_pack, scale, bias, dst + i * pack_factor);
+    }
+  }
+
+  void load_safe(short2 src_tile_dim) const {
+    if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) {
+      return;
+    }
+
+    if (reduction_dim == 1 && bi >= src_tile_dim.x) {
+      for (int i = 0; i < n_reads * pack_factor; i++) {
+        dst[i] = T(0);
+      }
+      return;
+    }
+
+    if (reduction_dim == 0 && bi >= src_tile_dim.y) {
+      for (int i = 0; i < n_reads * pack_factor; i++) {
+        dst[i] = T(0);
+      }
+      return;
+    }
+
+    T scale = *scales;
+    T bias = *biases;
+    for (int i = 0; i < n_reads; i++) {
+      dequantize<T, pack_factor, bits>(
+          (device uint8_t*)(src + i * bytes_per_pack),
+          scale,
+          bias,
+          dst + i * pack_factor);
+    }
+  }
+
+  void next() {
+    src += tile_stride;
+    if (reduction_dim == 1) {
+      if (group_steps > 1) {
+        group_step_cnt++;
+        if (group_step_cnt == group_steps) {
+          group_step_cnt = 0;
+          scales++;
+          biases++;
+        }
+      } else {
+        scales++;
+        biases++;
+      }
+    } else {
+      scales += group_stride;
+      biases += group_stride;
+    }
+  }
+};
+
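The loader's bookkeeping hinges on get_pack_factor and get_bytes_per_pack, which are defined earlier in this file. The invariant it relies on is bytes_per_pack * 8 == pack_factor * bits: a pack is the smallest byte-aligned repeating unit, one byte for 2/4/8 bits, 3 bytes for the 3- and 6-bit layouts, 5 bytes for 5-bit. A constexpr sketch of that relationship (these are assumed semantics, inferred from how the template is used above, not the MLX definitions themselves):

#include <cstdio>
#include <initializer_list>

// Assumed semantics of the helpers used by QuantizedBlockLoader: values per
// pack and bytes per pack, linked by bytes_per_pack * 8 == pack_factor * bits.
constexpr int pack_factor(int bits) {
  return (bits == 3 || bits == 5) ? 8 : (bits == 6 ? 4 : 8 / bits);
}
constexpr int bytes_per_pack(int bits) {
  return (bits == 3 || bits == 6) ? 3 : (bits == 5 ? 5 : 1);
}

static_assert(bytes_per_pack(2) * 8 == pack_factor(2) * 2, "2-bit pack");
static_assert(bytes_per_pack(3) * 8 == pack_factor(3) * 3, "3-bit pack");
static_assert(bytes_per_pack(4) * 8 == pack_factor(4) * 4, "4-bit pack");
static_assert(bytes_per_pack(5) * 8 == pack_factor(5) * 5, "5-bit pack");
static_assert(bytes_per_pack(6) * 8 == pack_factor(6) * 6, "6-bit pack");
static_assert(bytes_per_pack(8) * 8 == pack_factor(8) * 8, "8-bit pack");

int main() {
  for (int b : {2, 3, 4, 5, 6, 8})
    printf("bits=%d  values/pack=%d  bytes/pack=%d\n",
           b, pack_factor(b), bytes_per_pack(b));
}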
+template <
+    typename T,
+    short BROWS,
+    short BCOLS,
+    short dst_ld,
+    short reduction_dim,
+    short tgp_size,
+    short bits>
+struct QuantizedBlockLoader<
+    T,
+    BROWS,
+    BCOLS,
+    dst_ld,
+    reduction_dim,
+    tgp_size,
+    32,
+    bits> {
+  MLX_MTL_CONST short group_size = 32;
+
+  static_assert(
+      BCOLS % group_size == 0,
+      "The columns should be divisible by the group size");
+  static_assert(
+      bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 ||
+          bits == 8,
+      "Template undefined for bits not in {2, 3, 4, 5, 6, 8}");
+
+  MLX_MTL_CONST short pack_factor = get_pack_factor<bits, 8>();
+  MLX_MTL_CONST short bytes_per_pack = get_bytes_per_pack<bits>();
+  MLX_MTL_CONST short BCOLS_PACKED = BCOLS / pack_factor;
+  MLX_MTL_CONST short n_reads =
+      (BCOLS_PACKED * BROWS < tgp_size) ? 1 : (BCOLS_PACKED * BROWS) / tgp_size;
+  MLX_MTL_CONST short n_groups = BCOLS / group_size;
+
+  static_assert(
+      (BCOLS_PACKED / n_reads) == n_groups,
+      "Other configurations are not yet supported");
+
+  const int src_ld;
+  const int tile_stride;
+  const int group_stride;
+
+  const short thread_idx;
+  const short bi;
+  const short bj;
+
+  const short group_id;
+
+  threadgroup T* dst;
+  const device uint8_t* src;
+  const device T* scales;
+  const device T* biases;
+
+  QuantizedBlockLoader(
+      const device uint8_t* src_,
+      const device T* scales_,
+      const device T* biases_,
+      const int src_ld_,
+      threadgroup T* dst_,
+      ushort simd_group_id [[simdgroup_index_in_threadgroup]],
+      ushort simd_lane_id [[thread_index_in_simdgroup]])
+      : src_ld(src_ld_),
+        tile_stride(
+            reduction_dim ? BCOLS_PACKED * bytes_per_pack
+                          : BROWS * src_ld * bytes_per_pack / pack_factor),
+        group_stride(BROWS * src_ld / group_size),
+        thread_idx(simd_group_id * 32 + simd_lane_id),
+        bi(n_reads * thread_idx / BCOLS_PACKED),
+        bj((n_reads * thread_idx) % BCOLS_PACKED),
+        group_id((bj * pack_factor) / group_size),
+        dst(dst_ + bi * dst_ld + bj * pack_factor),
+        src(src_ + bi * src_ld * bytes_per_pack / pack_factor +
+            bj * bytes_per_pack),
+        scales(scales_ + bi * src_ld / group_size + group_id),
+        biases(biases_ + bi * src_ld / group_size + group_id) {}
+
+  void load_unsafe() const {
+    if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) {
+      return;
+    }
+
+    T scale = *scales;
+    T bias = *biases;
+    for (int i = 0; i < n_reads; i++) {
+      dequantize<T, pack_factor, bits>(
+          src + i * bytes_per_pack, scale, bias, dst + i * pack_factor);
+    }
+  }
+
+  void load_safe(short2 src_tile_dim) const {
+    if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) {
+      return;
+    }
+
+    if (reduction_dim == 1 && bi >= src_tile_dim.x) {
+      for (int i = 0; i < n_reads * pack_factor; i++) {
+        dst[i] = T(0);
+      }
+      return;
+    }
+
+    if (reduction_dim == 0 && bi >= src_tile_dim.y) {
+      for (int i = 0; i < n_reads * pack_factor; i++) {
+        dst[i] = T(0);
+      }
+      return;
+    }
+
+    T scale = *scales;
+    T bias = *biases;
+    for (int i = 0; i < n_reads; i++) {
+      dequantize<T, pack_factor, bits>(
+          (device uint8_t*)(src + i * bytes_per_pack),
+          scale,
+          bias,
+          dst + i * pack_factor);
+    }
+  }
+
+  void next() {
+    src += tile_stride;
+    if (reduction_dim == 1) {
+      // if (group_steps > 1) {
+      //   group_step_cnt++;
+      //   if (group_step_cnt == group_steps) {
+      //     group_step_cnt = 0;
+      //     scales++;
+      //     biases++;
+      //   }
+      // } else {
+      scales += n_groups;
+      biases += n_groups;
+      // }
+    } else {
+      scales += n_groups * group_stride;
+      biases += n_groups * group_stride;
+    }
+  }
+};
+
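The group_size == 32 specialization inverts the generic loader's assumption: here each tile row spans several quantization groups (n_groups = BCOLS / 32) rather than one group spanning several tiles. The "(BCOLS_PACKED / n_reads) == n_groups" assert is what keeps the single *scales read in load_unsafe valid, since it forces each thread's n_reads packs to cover exactly one group: n_reads * pack_factor = (BCOLS_PACKED / n_groups) * pack_factor = BCOLS / n_groups = group_size. A compile-time restatement under one assumed configuration (a sketch, not a kernel config the library necessarily ships):

// A compile-time restatement of the specialized loader's invariant, using an
// assumed configuration: 8-bit weights, 64x64 tile, 128-thread threadgroup.
constexpr int pack_factor = 1; // 8-bit: one value per byte
constexpr int BROWS = 64, BCOLS = 64;
constexpr int tgp_size = 128, group_size = 32;

constexpr int BCOLS_PACKED = BCOLS / pack_factor;          // 64 packs per row
constexpr int n_reads = (BCOLS_PACKED * BROWS) / tgp_size; // 32 packs/thread
constexpr int n_groups = BCOLS / group_size;               // 2 groups per row

// The static_assert from the struct above...
static_assert(BCOLS_PACKED / n_reads == n_groups, "unsupported config");
// ...implies every thread's reads stay inside a single quantization group,
// so one scale/bias fetch per load is enough:
static_assert(n_reads * pack_factor == group_size, "one group per thread");

int main() {}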
+template <typename T>
+METAL_FUNC void adjust_matrix_offsets(
+    const device T*& x,
+    const device uint32_t*& w,
+    const device T*& scales,
+    const device T*& biases,
+    device T*& y,
+    int output_stride,
+    const constant int& x_batch_ndims,
+    const constant int* x_shape,
+    const constant int64_t* x_strides,
+    const constant int& w_batch_ndims,
+    const constant int* w_shape,
+    const constant int64_t* w_strides,
+    const constant int64_t* s_strides,
+    const constant int64_t* b_strides,
+    uint3 tid [[threadgroup_position_in_grid]]) {
+  // Set the input/output matrices
+  uint32_t x_idx = tid.z;
+  uint32_t w_idx = tid.z;
+  if (x_batch_ndims == 1) {
+    x += x_idx * x_strides[0];
+  } else {
+    x += elem_to_loc(x_idx, x_shape, x_strides, x_batch_ndims);
+  }
+  if (w_batch_ndims == 1) {
+    w += w_idx * w_strides[0];
+    scales += w_idx * s_strides[0];
+    biases += w_idx * b_strides[0];
+  } else {
+    ulong3 idx = elem_to_loc_broadcast(
+        w_idx, w_shape, w_strides, s_strides, b_strides, w_batch_ndims);
+    w += idx.x;
+    scales += idx.y;
+    biases += idx.z;
+  }
+  y += tid.z * output_stride;
+}
+
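elem_to_loc and elem_to_loc_broadcast come from MLX's shared kernel utilities: they turn a flat batch index into a strided memory offset so arbitrarily strided (and broadcast) batch dimensions work. A plausible host-side rendering of the single-array version (a sketch; the Metal original lives elsewhere in this package):

#include <cassert>
#include <cstdint>

// Convert a flat row-major element index into an element offset under an
// arbitrary shape/stride pair, peeling dimensions from the last one.
int64_t elem_to_loc(
    uint32_t elem,
    const int* shape,
    const int64_t* strides,
    int ndim) {
  int64_t loc = 0;
  for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
    loc += (elem % shape[i]) * strides[i];
    elem /= shape[i];
  }
  return loc;
}

int main() {
  // A (2, 3) batch where the first axis is broadcast (stride 0): both rows
  // of axis 0 map to the same memory, as adjust_matrix_offsets expects.
  int shape[2] = {2, 3};
  int64_t strides[2] = {0, 5};
  assert(elem_to_loc(1, shape, strides, 2) == 5);  // (0, 1) -> 5
  assert(elem_to_loc(4, shape, strides, 2) == 5);  // (1, 1) -> 5, broadcast
  assert(elem_to_loc(5, shape, strides, 2) == 10); // (1, 2) -> 10
}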
+template <typename T>
+METAL_FUNC void adjust_matrix_offsets(
+    const device T*& x,
+    const device uint32_t*& w,
+    const device T*& scales,
+    const device T*& biases,
+    const device uint32_t* lhs_indices,
+    const device uint32_t* rhs_indices,
+    device T*& y,
+    int output_stride,
+    const constant int& batch_ndims,
+    const constant int* batch_shape,
+    const constant int64_t* lhs_strides,
+    const constant int64_t* rhs_strides,
+    const constant int& x_batch_ndims,
+    const constant int* x_shape,
+    const constant int64_t* x_strides,
+    const constant int& w_batch_ndims,
+    const constant int* w_shape,
+    const constant int64_t* w_strides,
+    const constant int64_t* s_strides,
+    const constant int64_t* b_strides,
+    uint3 tid [[threadgroup_position_in_grid]]) {
+  // Set the input/output matrices
+  uint32_t x_idx;
+  uint32_t w_idx;
+  if (batch_ndims == 1) {
+    x_idx = lhs_indices[tid.z * lhs_strides[0]];
+    w_idx = rhs_indices[tid.z * rhs_strides[0]];
+  } else {
+    ulong2 idx = elem_to_loc_broadcast(
+        tid.z, batch_shape, lhs_strides, rhs_strides, batch_ndims);
+    x_idx = lhs_indices[idx.x];
+    w_idx = rhs_indices[idx.y];
+  }
+  if (x_batch_ndims == 1) {
+    x += x_idx * x_strides[0];
+  } else {
+    x += elem_to_loc(x_idx, x_shape, x_strides, x_batch_ndims);
+  }
+  if (w_batch_ndims == 1) {
+    w += w_idx * w_strides[0];
+    scales += w_idx * s_strides[0];
+    biases += w_idx * b_strides[0];
+  } else {
+    ulong3 idx = elem_to_loc_broadcast(
+        w_idx, w_shape, w_strides, s_strides, b_strides, w_batch_ndims);
+    w += idx.x;
+    scales += idx.y;
+    biases += idx.z;
+  }
+  y += tid.z * output_stride;
+}
+
+template <
+    typename T,
+    const int group_size,
+    const int bits,
+    const bool aligned_N,
+    const int BM = 64,
+    const int BK = 64,
+    const int BN = 64,
+    const int WM = 2,
+    const int WN = 2>
+METAL_FUNC void qmm_t_nax_tgp_impl(
+    const device uint32_t* w,
+    const device T* scales,
+    const device T* biases,
+    const device T* x,
+    device T* y,
+    threadgroup T* Ws,
+    const constant int& K,
+    const constant int& N,
+    const constant int& M,
+    uint3 tid [[threadgroup_position_in_grid]],
+    uint lid [[thread_index_in_threadgroup]],
+    uint simd_gid [[simdgroup_index_in_threadgroup]],
+    uint simd_lid [[thread_index_in_simdgroup]]) {
+  static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE");
+  static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE");
+
+  (void)lid;
+
+  constexpr int pack_factor = get_pack_factor<bits, 8>();
+  constexpr int bytes_per_pack = get_bytes_per_pack<bits>();
+
+  constexpr int BK_padded = (BK + 16 / sizeof(T));
+
+  using loader_w_t = QuantizedBlockLoader<
+      T,
+      BN,
+      BK,
+      BK_padded,
+      1,
+      WM * WN * SIMD_SIZE,
+      group_size,
+      bits>;
+
+  // Set the block
+  const int K_w = K * bytes_per_pack / pack_factor;
+  const int K_g = K / group_size;
+  const int y_row = tid.y * BM;
+  const int y_col = tid.x * BN;
+
+  auto wl = (const device uint8_t*)w;
+
+  x += y_row * static_cast<int64_t>(K);
+  wl += y_col * K_w;
+  scales += y_col * K_g;
+  biases += y_col * K_g;
+  y += y_row * static_cast<int64_t>(N) + y_col;
+
+  // Make the weight loader
+  loader_w_t loader_w(wl, scales, biases, K, Ws, simd_gid, simd_lid);
+
+  constexpr short UM = 16;
+  constexpr short UN = 32;
+  constexpr short UK = 16;
+  constexpr short SM = BM / WM;
+  constexpr short SN = BN / WN;
+  constexpr short SK = 32;
+
+  constexpr short TM = SM / UM;
+  constexpr short TN = SN / UN;
+  constexpr short TK = SK / UK;
+
+  const short tm = SM * (simd_gid / WN);
+  const short tn = SN * (simd_gid % WN);
+
+  constexpr bool transpose_a = false;
+  constexpr bool transpose_b = true;
+
+  const short sgp_sm = min(SM, short(M - (y_row + tm)));
+  const bool is_unaligned_sm = (sgp_sm != SM);
+
+  const short sgp_sn = aligned_N ? SN : min(SN, short(N - (y_col + tn)));
+
+  const short tgp_bn = aligned_N ? BN : min(BN, int(N - (y_col)));
+  const bool is_unaligned_bn = aligned_N ? false : (tgp_bn != BN);
+
+  using AccumType = float;
+
+  using ASubTile = NAXSubTile<T, UM, UK>;
+  using BSubTile = NAXSubTile<T, UN, UK>;
+  using DSubTile = NAXSubTile<AccumType, UM, UN>;
+
+  NAXTile<AccumType, TM, TN, DSubTile> Dtile;
+
+  Dtile.clear();
+
+  x += tm * K;
+
+  dispatch_bool(!is_unaligned_sm, [&](auto kAlignedM) {
+    dispatch_bool(aligned_N || !is_unaligned_bn, [&](auto kAlignedN) {
+      for (int k = 0; k < K; k += BK) {
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        if constexpr (kAlignedN.value) {
+          loader_w.load_unsafe();
+        } else {
+          loader_w.load_safe(short2(BK, tgp_bn));
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        STEEL_PRAGMA_NO_UNROLL
+        for (int kk1 = 0; kk1 < BK; kk1 += SK) {
+          NAXTile<T, TM, TK, ASubTile> Atile;
+          NAXTile<T, TN, TK, BSubTile> Btile;
+
+          volatile int compiler_barrier;
+
+          if constexpr (kAlignedM.value) {
+            Atile.load(x + kk1, K);
+          } else {
+            Atile.load_safe(x + kk1, K, short2(SK, sgp_sm));
+          }
+
+          Btile.template load<T, BK_padded, 1>(Ws + tn * BK_padded + kk1);
+
+          tile_matmad_nax(
+              Dtile,
+              Atile,
+              metal::bool_constant<transpose_a>{},
+              Btile,
+              metal::bool_constant<transpose_b>{});
+
+          (void)compiler_barrier;
+        }
+
+        x += BK;
+        loader_w.next();
+      }
+
+      // Store results to device memory
+      threadgroup_barrier(mem_flags::mem_threadgroup);
+
+      if constexpr (kAlignedM.value && kAlignedN.value) {
+        Dtile.store(y + tm * N + tn, N);
+      } else if (kAlignedM.value && sgp_sn == SN) {
+        Dtile.store(y + tm * N + tn, N);
+      } else {
+        Dtile.store_safe(y + tm * N + tn, N, short2(sgp_sn, sgp_sm));
+      }
+    });
+  });
+}
+
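The tiling hierarchy in qmm_t_nax_tgp_impl: each threadgroup owns a BM x BN output block; its WM x WN simdgroups each own an SM x SN sub-block; and each simdgroup accumulates TM x TN NAX subtiles of UM x UN (16 x 32 here), marching over K in BK-sized threadgroup steps and SK = 32 register steps. A host-side sketch of the same loop nest on plain arrays, with the NAX subtile multiply collapsed to a scalar triple loop (bounds checks stand in for the *_safe paths):

#include <algorithm>
#include <vector>

// CPU rendering of qmm_t's blocking: C(M,N) += A(M,K) * B(N,K)^T with
// BM/BN/BK matching the kernel defaults.
void qmm_t_blocked(
    const std::vector<float>& A, // M x K, row-major
    const std::vector<float>& B, // N x K, row-major (the "transposed" side)
    std::vector<float>& C,       // M x N, row-major
    int M, int N, int K) {
  constexpr int BM = 64, BN = 64, BK = 64;
  for (int m0 = 0; m0 < M; m0 += BM)       // threadgroup rows (tid.y)
    for (int n0 = 0; n0 < N; n0 += BN)     // threadgroup cols (tid.x)
      for (int k0 = 0; k0 < K; k0 += BK)   // K steps (loader_w.next())
        for (int m = m0; m < std::min(m0 + BM, M); m++)
          for (int n = n0; n < std::min(n0 + BN, N); n++) {
            float acc = 0.f;
            for (int k = k0; k < std::min(k0 + BK, K); k++)
              acc += A[size_t(m) * K + k] * B[size_t(n) * K + k];
            C[size_t(m) * N + n] += acc;
          }
}

On the GPU the two inner spatial loops are distributed across simdgroups and NAX tiles instead of iterated, and the weight operand is dequantized into threadgroup memory (Ws) once per BK step rather than read directly.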
+template <
+    typename T,
+    const int group_size,
+    const int bits,
+    const int BM = 64,
+    const int BK = 64,
+    const int BN = 64,
+    const int WM = 2,
+    const int WN = 2>
+METAL_FUNC void qmm_n_nax_tgp_impl(
+    const device uint32_t* w,
+    const device T* scales,
+    const device T* biases,
+    const device T* x,
+    device T* y,
+    threadgroup T* Ws,
+    const constant int& K,
+    const constant int& N,
+    const constant int& M,
+    uint3 tid [[threadgroup_position_in_grid]],
+    uint lid [[thread_index_in_threadgroup]],
+    uint simd_gid [[simdgroup_index_in_threadgroup]],
+    uint simd_lid [[thread_index_in_simdgroup]]) {
+  (void)lid;
+  (void)M;
+
+  static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE");
+  static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE");
+
+  constexpr int pack_factor = get_pack_factor<bits, 8>();
+  constexpr int bytes_per_pack = get_bytes_per_pack<bits>();
+
+  constexpr int BN_padded = (BN + 16 / sizeof(T));
+
+  using loader_w_t = QuantizedBlockLoader<
+      T,
+      BK,
+      BN,
+      BN_padded,
+      0,
+      WM * WN * SIMD_SIZE,
+      group_size,
+      bits>;
+
+  // Set the block
+  const int K_w = K * bytes_per_pack / pack_factor;
+  const int K_g = K / group_size;
+  const int y_row = tid.y * BM;
+  const int y_col = tid.x * BN;
+
+  auto wl = (const device uint8_t*)w;
+
+  x += y_row * static_cast<int64_t>(K);
+  wl += y_col * K_w;
+  scales += y_col * K_g;
+  biases += y_col * K_g;
+  y += y_row * static_cast<int64_t>(N) + y_col;
+
+  // Make the x loader and mma operation
+  // const short num_els = min(BM, M - y_row);
+  // const short num_outs = min(BN, N - y_col);
+  loader_w_t loader_w(wl, scales, biases, K, Ws, simd_gid, simd_lid);
+
+  constexpr short UM = 16;
+  constexpr short UN = 32;
+  constexpr short UK = 16;
+  constexpr short SM = BM / WM;
+  constexpr short SN = BN / WN;
+  constexpr short SK = 32;
+
+  constexpr short TM = SM / UM;
+  constexpr short TN = SN / UN;
+  constexpr short TK = SK / UK;
+
+  const short tm = SM * (simd_gid / WN);
+  const short tn = SN * (simd_gid % WN);
+
+  const short ldb_tgp = BN_padded;
+
+  constexpr bool transpose_a = false;
+  constexpr bool transpose_b = false;
+
+  using AccumType = float;
+
+  using ASubTile = NAXSubTile<T, UM, UK>;
+  using BSubTile = NAXSubTile<T, UK, UN>;
+  using DSubTile = NAXSubTile<AccumType, UM, UN>;
+
+  NAXTile<AccumType, TM, TN, DSubTile> Dtile;
+
+  Dtile.clear();
+
+  x += tm * K;
+
+  for (int k = 0; k < K; k += BK) {
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    loader_w.load_unsafe();
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    STEEL_PRAGMA_NO_UNROLL
+    for (int kk1 = 0; kk1 < BK; kk1 += SK) {
+      NAXTile<T, TM, TK, ASubTile> Atile;
+      NAXTile<T, TK, TN, BSubTile> Btile;
+
+      volatile int compiler_barrier;
+
+      Atile.load(x + kk1, K);
+      Btile.template load<T, BN_padded, 1>(Ws + tn + kk1 * ldb_tgp);
+
+      tile_matmad_nax(
+          Dtile,
+          Atile,
+          metal::bool_constant<transpose_a>{},
+          Btile,
+          metal::bool_constant<transpose_b>{});
+
+      (void)compiler_barrier;
+    }
+
+    x += BK;
+    loader_w.next();
+  }
+
+  // Store results to device memory
+  threadgroup_barrier(mem_flags::mem_threadgroup);
+
+  Dtile.store(y + tm * N + tn, N);
+}
+
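Both impls pad the leading dimension of the threadgroup tile by 16 / sizeof(T) elements (BK_padded, BN_padded), i.e. by a fixed 16 bytes per row. Padding an otherwise power-of-two row pitch is the usual way to keep consecutive rows out of the same shared-memory banks and keep the tile loads conflict-free; that intent is an inference here, only the constant itself comes from the code. The arithmetic:

#include <cstdio>

// Row pitch of the threadgroup tile for BK = 64, padded by 16 bytes.
template <typename T>
constexpr int padded_ld(int bk) {
  return bk + 16 / int(sizeof(T));
}

int main() {
  // 2-byte elements (half): 64 + 8 = 72 elements, 144 bytes per row.
  // 4-byte elements (float): 64 + 4 = 68 elements, 272 bytes per row.
  printf("2-byte T: %d elements per row\n", padded_ld<short>(64));
  printf("4-byte T: %d elements per row\n", padded_ld<float>(64));
}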
+template <
+    typename T,
+    const int group_size,
+    const int bits,
+    const bool aligned_N,
+    const bool batched,
+    const int BM = 64,
+    const int BK = 32,
+    const int BN = 64,
+    const int WM = 2,
+    const int WN = 2>
+[[kernel]] void affine_qmm_t_nax(
+    const device uint32_t* w [[buffer(0)]],
+    const device T* scales [[buffer(1)]],
+    const device T* biases [[buffer(2)]],
+    const device T* x [[buffer(3)]],
+    device T* y [[buffer(4)]],
+    const constant int& K [[buffer(5)]],
+    const constant int& N [[buffer(6)]],
+    const constant int& M [[buffer(7)]],
+    const constant int& x_batch_ndims [[buffer(8)]],
+    const constant int* x_shape [[buffer(9)]],
+    const constant int64_t* x_strides [[buffer(10)]],
+    const constant int& w_batch_ndims [[buffer(11)]],
+    const constant int* w_shape [[buffer(12)]],
+    const constant int64_t* w_strides [[buffer(13)]],
+    const constant int64_t* s_strides [[buffer(14)]],
+    const constant int64_t* b_strides [[buffer(15)]],
+    uint3 tid [[threadgroup_position_in_grid]],
+    uint lid [[thread_index_in_threadgroup]],
+    uint simd_gid [[simdgroup_index_in_threadgroup]],
+    uint simd_lid [[thread_index_in_simdgroup]]) {
+  (void)lid;
+
+  constexpr int BK_padded = (BK + 16 / sizeof(T));
+
+  threadgroup T Ws[BN * BK_padded];
+
+  if (batched) {
+    adjust_matrix_offsets<T>(
+        x,
+        w,
+        scales,
+        biases,
+        y,
+        M * N,
+        x_batch_ndims,
+        x_shape,
+        x_strides,
+        w_batch_ndims,
+        w_shape,
+        w_strides,
+        s_strides,
+        b_strides,
+        tid);
+  }
+  qmm_t_nax_tgp_impl<T, group_size, bits, aligned_N, BM, BK, BN, WM, WN>(
+      w, scales, biases, x, y, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
+}
+
+template <
+    typename T,
+    const int group_size,
+    const int bits,
+    const bool batched,
+    const int BM = 64,
+    const int BK = 64,
+    const int BN = 64,
+    const int WM = 2,
+    const int WN = 2>
+[[kernel]] void affine_qmm_n_nax(
+    const device uint32_t* w [[buffer(0)]],
+    const device T* scales [[buffer(1)]],
+    const device T* biases [[buffer(2)]],
+    const device T* x [[buffer(3)]],
+    device T* y [[buffer(4)]],
+    const constant int& K [[buffer(5)]],
+    const constant int& N [[buffer(6)]],
+    const constant int& M [[buffer(7)]],
+    const constant int& x_batch_ndims [[buffer(8)]],
+    const constant int* x_shape [[buffer(9)]],
+    const constant int64_t* x_strides [[buffer(10)]],
+    const constant int& w_batch_ndims [[buffer(11)]],
+    const constant int* w_shape [[buffer(12)]],
+    const constant int64_t* w_strides [[buffer(13)]],
+    const constant int64_t* s_strides [[buffer(14)]],
+    const constant int64_t* b_strides [[buffer(15)]],
+    uint3 tid [[threadgroup_position_in_grid]],
+    uint lid [[thread_index_in_threadgroup]],
+    uint simd_gid [[simdgroup_index_in_threadgroup]],
+    uint simd_lid [[thread_index_in_simdgroup]]) {
+  (void)lid;
+
+  constexpr int BN_padded = (BN + 16 / sizeof(T));
+
+  threadgroup T Ws[BK * BN_padded];
+
+  if (batched) {
+    adjust_matrix_offsets<T>(
+        x,
+        w,
+        scales,
+        biases,
+        y,
+        M * N,
+        x_batch_ndims,
+        x_shape,
+        x_strides,
+        w_batch_ndims,
+        w_shape,
+        w_strides,
+        s_strides,
+        b_strides,
+        tid);
+  }
+
+  qmm_n_nax_tgp_impl<T, group_size, bits, BM, BK, BN, WM, WN>(
+      w, scales, biases, x, y, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
+}
+
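End to end, affine_qmm_t_nax computes y = x * dequant(w)^T with w stored N x K in packed form and one (scale, bias) pair per group_size weights along K; the batched flag only offsets the pointers per grid z before running the single-matrix impl. A scalar reference of the whole operation (a sketch; the quantized weights are kept unpacked here to stay layout-agnostic):

#include <cstddef>
#include <vector>

// Reference for y(M,N) = x(M,K) * dequant(w)(N,K)^T where weight (n, k) has
// quantized value w_q[n*K + k] and per-group affine parameters
// scales/biases[n*(K/group_size) + k/group_size].
std::vector<float> qmm_t_reference(
    const std::vector<float>& x,
    const std::vector<int>& w_q,
    const std::vector<float>& scales,
    const std::vector<float>& biases,
    int M, int N, int K, int group_size) {
  std::vector<float> y(size_t(M) * N, 0.f);
  const int K_g = K / group_size;
  for (int m = 0; m < M; m++)
    for (int n = 0; n < N; n++) {
      float acc = 0.f;
      for (int k = 0; k < K; k++) {
        const int g = n * K_g + k / group_size;
        acc += x[size_t(m) * K + k] * (scales[g] * w_q[size_t(n) * K + k] + biases[g]);
      }
      y[size_t(m) * N + n] = acc;
    }
  return y;
}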
+template <
+    typename T,
+    const int group_size,
+    const int bits,
+    const bool aligned_N,
+    const int BM = 64,
+    const int BK = 64,
+    const int BN = 64,
+    const int WM = 2,
+    const int WN = 2>
+[[kernel]] void affine_gather_qmm_t_nax(
+    const device uint32_t* w [[buffer(0)]],
+    const device T* scales [[buffer(1)]],
+    const device T* biases [[buffer(2)]],
+    const device T* x [[buffer(3)]],
+    const device uint32_t* lhs_indices [[buffer(4)]],
+    const device uint32_t* rhs_indices [[buffer(5)]],
+    device T* y [[buffer(6)]],
+    const constant int& K [[buffer(7)]],
+    const constant int& N [[buffer(8)]],
+    const constant int& M [[buffer(9)]],
+    const constant int& x_batch_ndims [[buffer(10)]],
+    const constant int* x_shape [[buffer(11)]],
+    const constant int64_t* x_strides [[buffer(12)]],
+    const constant int& w_batch_ndims [[buffer(13)]],
+    const constant int* w_shape [[buffer(14)]],
+    const constant int64_t* w_strides [[buffer(15)]],
+    const constant int64_t* s_strides [[buffer(16)]],
+    const constant int64_t* b_strides [[buffer(17)]],
+    const constant int& batch_ndims [[buffer(18)]],
+    const constant int* batch_shape [[buffer(19)]],
+    const constant int64_t* lhs_strides [[buffer(20)]],
+    const constant int64_t* rhs_strides [[buffer(21)]],
+    uint3 tid [[threadgroup_position_in_grid]],
+    uint lid [[thread_index_in_threadgroup]],
+    uint simd_gid [[simdgroup_index_in_threadgroup]],
+    uint simd_lid [[thread_index_in_simdgroup]]) {
+  (void)lid;
+
+  constexpr int BK_padded = (BK + 16 / sizeof(T));
+
+  threadgroup T Ws[BN * BK_padded];
+
+  adjust_matrix_offsets<T>(
+      x,
+      w,
+      scales,
+      biases,
+      lhs_indices,
+      rhs_indices,
+      y,
+      M * N,
+      batch_ndims,
+      batch_shape,
+      lhs_strides,
+      rhs_strides,
+      x_batch_ndims,
+      x_shape,
+      x_strides,
+      w_batch_ndims,
+      w_shape,
+      w_strides,
+      s_strides,
+      b_strides,
+      tid);
+  qmm_t_nax_tgp_impl<T, group_size, bits, aligned_N, BM, BK, BN, WM, WN>(
+      w, scales, biases, x, y, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
+}
+
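The gather variants add one indirection: output batch element i multiplies x batch lhs_indices[i] against weight set rhs_indices[i], the mixture-of-experts access pattern. Stripped of tiling and quantization, the semantics reduce to the following sketch (dequantized weights are assumed precomputed for clarity):

#include <cstdint>
#include <vector>

// Semantics of affine_gather_qmm_t_nax with tiling stripped away:
// y[i] = x[lhs[i]] * dequant(W[rhs[i]])^T for each output batch element i.
void gather_qmm_t(
    const std::vector<std::vector<float>>& x_batches, // each M x K
    const std::vector<std::vector<float>>& w_dequant, // each N x K
    const std::vector<uint32_t>& lhs_indices,
    const std::vector<uint32_t>& rhs_indices,
    std::vector<std::vector<float>>& y,               // each M x N
    int M, int N, int K) {
  for (size_t i = 0; i < y.size(); i++) {
    const auto& xb = x_batches[lhs_indices[i]];
    const auto& wb = w_dequant[rhs_indices[i]];
    for (int m = 0; m < M; m++)
      for (int n = 0; n < N; n++) {
        float acc = 0.f;
        for (int k = 0; k < K; k++)
          acc += xb[size_t(m) * K + k] * wb[size_t(n) * K + k];
        y[i][size_t(m) * N + n] = acc;
      }
  }
}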
+template <
+    typename T,
+    const int group_size,
+    const int bits,
+    const int BM = 64,
+    const int BK = 64,
+    const int BN = 64,
+    const int WM = 2,
+    const int WN = 2>
+[[kernel]] void affine_gather_qmm_n_nax(
+    const device uint32_t* w [[buffer(0)]],
+    const device T* scales [[buffer(1)]],
+    const device T* biases [[buffer(2)]],
+    const device T* x [[buffer(3)]],
+    const device uint32_t* lhs_indices [[buffer(4)]],
+    const device uint32_t* rhs_indices [[buffer(5)]],
+    device T* y [[buffer(6)]],
+    const constant int& K [[buffer(7)]],
+    const constant int& N [[buffer(8)]],
+    const constant int& M [[buffer(9)]],
+    const constant int& x_batch_ndims [[buffer(10)]],
+    const constant int* x_shape [[buffer(11)]],
+    const constant int64_t* x_strides [[buffer(12)]],
+    const constant int& w_batch_ndims [[buffer(13)]],
+    const constant int* w_shape [[buffer(14)]],
+    const constant int64_t* w_strides [[buffer(15)]],
+    const constant int64_t* s_strides [[buffer(16)]],
+    const constant int64_t* b_strides [[buffer(17)]],
+    const constant int& batch_ndims [[buffer(18)]],
+    const constant int* batch_shape [[buffer(19)]],
+    const constant int64_t* lhs_strides [[buffer(20)]],
+    const constant int64_t* rhs_strides [[buffer(21)]],
+    uint3 tid [[threadgroup_position_in_grid]],
+    uint lid [[thread_index_in_threadgroup]],
+    uint simd_gid [[simdgroup_index_in_threadgroup]],
+    uint simd_lid [[thread_index_in_simdgroup]]) {
+  (void)lid;
+
+  constexpr int BN_padded = (BN + 16 / sizeof(T));
+
+  threadgroup T Ws[BK * BN_padded];
+
+  adjust_matrix_offsets<T>(
+      x,
+      w,
+      scales,
+      biases,
+      lhs_indices,
+      rhs_indices,
+      y,
+      M * N,
+      batch_ndims,
+      batch_shape,
+      lhs_strides,
+      rhs_strides,
+      x_batch_ndims,
+      x_shape,
+      x_strides,
+      w_batch_ndims,
+      w_shape,
+      w_strides,
+      s_strides,
+      b_strides,
+      tid);
+  qmm_n_nax_tgp_impl<T, group_size, bits, BM, BK, BN, WM, WN>(
+      w, scales, biases, x, y, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
+}
+
+template <
+    typename T,
+    int group_size,
+    int bits,
+    int BM,
+    int BN,
+    int BK,
+    int WM,
+    int WN,
+    bool transpose>
+[[kernel]] void affine_gather_qmm_rhs_nax(
+    const device T* x [[buffer(0)]],
+    const device uint32_t* w [[buffer(1)]],
+    const device T* scales [[buffer(2)]],
+    const device T* biases [[buffer(3)]],
+    const device uint32_t* indices [[buffer(4)]],
+    device T* y [[buffer(5)]],
+    const constant int& M [[buffer(6)]],
+    const constant int& N [[buffer(7)]],
+    const constant int& K [[buffer(8)]],
+    uint3 tid [[threadgroup_position_in_grid]],
+    uint simd_group_id [[simdgroup_index_in_threadgroup]],
+    uint simd_lane_id [[thread_index_in_simdgroup]]) {
+  constexpr int pack_factor = get_pack_factor<bits, 8>();
+  constexpr int bytes_per_pack = get_bytes_per_pack<bits>();
+  constexpr int BK_padded = (BK + 16 / sizeof(T));
+  constexpr int BN_padded = (BN + 16 / sizeof(T));
+
+  using loader_w_t = QuantizedBlockLoader<
+      T,
+      transpose ? BN : BK,
+      transpose ? BK : BN,
+      transpose ? BK_padded : BN_padded,
+      transpose,
+      WM * WN * SIMD_SIZE,
+      group_size,
+      bits>;
+
+  threadgroup T Ws[transpose ? BN * BK_padded : BK * BN_padded];
+
+  // Compute the block
+  const int K_w = K * bytes_per_pack / pack_factor;
+  const int K_g = K / group_size;
+  const int N_w = N * bytes_per_pack / pack_factor;
+  const int N_g = N / group_size;
+  const int K_it = K / BK;
+  const size_t stride_w = transpose ? N * K_w : K * N_w;
+  const size_t stride_s = transpose ? N * K_g : K * N_g;
+  const int y_row = tid.y * BM;
+  const int y_col = tid.x * BN;
+  const size_t y_row_long = size_t(y_row);
+  const size_t y_col_long = size_t(y_col);
+
+  // Prepare threadgroup bounds
+  const short tgp_bm = align_M ? BM : short(min(BM, M - y_row));
+  const short tgp_bn = align_N ? BN : short(min(BN, N - y_col));
+
+  // Calculate the final tiles in the case that K is not aligned
+  const int k_remain = K - K_it * BK;
+  const short2 tile_w =
+      transpose ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);
+
+  // Move x and output to the correct block
+  auto wl = (const device uint8_t*)w;
+  x += y_row_long * K;
+  y += y_row_long * N + y_col_long;
+  wl += transpose ? y_col_long * K_w : y_col * bytes_per_pack / pack_factor;
+  scales += transpose ? y_col_long * K_g : y_col / group_size;
+  biases += transpose ? y_col_long * K_g : y_col / group_size;
+
+  constexpr short UM = 16;
+  constexpr short UN = 32;
+  constexpr short UK = 16;
+  constexpr short SM = BM / WM;
+  constexpr short SN = BN / WN;
+  constexpr short SK = 32;
+
+  constexpr short TM = SM / UM;
+  constexpr short TN = SN / UN;
+  constexpr short TK = SK / UK;
+
+  const short tm = SM * (simd_group_id / WN);
+  const short tn = SN * (simd_group_id % WN);
+
+  const short sgp_sm =
+      align_M ? SM : min(SM, short(max(0, (M - (y_row + tm)))));
+  const short sgp_sn =
+      align_N ? SN : min(SN, short(max(0, (N - (y_col + tn)))));
+
+  const bool is_unaligned_sm = align_M ? false : (sgp_sm != SM);
+  const bool is_unaligned_bn = align_N ? false : (tgp_bn != BN);
+
+  constexpr short BR = transpose ? TN : TK;
+  constexpr short BC = transpose ? TK : TN;
+
+  using AccumType = float;
+
+  using ASubTile = NAXSubTile<T, UM, UK>;
+  using BSubTile = NAXSubTile<T, transpose ? UN : UK, transpose ? UK : UN>;
+  using DSubTile = NAXSubTile<AccumType, UM, UN>;
+
+  // Do as many matmuls as necessary
+  uint32_t index;
+  short offset;
+  uint32_t index_next = indices[y_row];
+  short offset_next = 0;
+  int n = 0;
+  while (n < tgp_bm) {
+    n++;
+    offset = offset_next;
+    index = index_next;
+    offset_next = tgp_bm;
+    for (; n < tgp_bm; n++) {
+      if (indices[y_row + n] != index) {
+        offset_next = n;
+        index_next = indices[y_row + n];
+        break;
+      }
+    }
+    threadgroup_barrier(mem_flags::mem_none);
+
+    NAXTile<AccumType, TM, TN, DSubTile> Dtile;
+
+    Dtile.clear();
+
+    const device T* xn = x + tm * K;
+
+    // Prepare threadgroup loading operations
+    thread loader_w_t loader_w(
+        wl + index * stride_w,
+        scales + index * stride_s,
+        biases + index * stride_s,
+        transpose ? K : N,
+        Ws,
+        simd_group_id,
+        simd_lane_id);
+
+    dispatch_bool(align_M || !is_unaligned_sm, [&](auto kAlignedM) {
+      dispatch_bool(align_N || !is_unaligned_bn, [&](auto kAlignedN) {
+        for (int k = 0; k < K_it; k++) {
+          threadgroup_barrier(mem_flags::mem_threadgroup);
+          if constexpr (kAlignedN.value) {
+            loader_w.load_unsafe();
+          } else {
+            loader_w.load_safe(
+                transpose ? short2(BK, tgp_bn) : short2(tgp_bn, BK));
+          }
+
+          threadgroup_barrier(mem_flags::mem_threadgroup);
+
+          STEEL_PRAGMA_NO_UNROLL
+          for (int kk1 = 0; kk1 < BK; kk1 += SK) {
+            NAXTile<T, TM, TK, ASubTile> Atile;
+            NAXTile<T, BR, BC, BSubTile> Btile;
+
+            volatile int compiler_barrier;
+
+            if constexpr (kAlignedM.value) {
+              Atile.load(xn + kk1, K);
+            } else {
+              Atile.load_safe(xn + kk1, K, short2(SK, sgp_sm));
+            }
+
+            if constexpr (transpose) {
+              Btile.template load<T, BK_padded, 1>(Ws + tn * BK_padded + kk1);
+            } else {
+              Btile.template load<T, BN_padded, 1>(Ws + tn + kk1 * BN_padded);
+            }
+
+            tile_matmad_nax(
+                Dtile,
+                Atile,
+                metal::bool_constant<false>{},
+                Btile,
+                metal::bool_constant<transpose>{});
+
+            (void)compiler_barrier;
+          }
+
+          xn += BK;
+          loader_w.next();
+        }
+
+        if (!align_K) {
+          threadgroup_barrier(mem_flags::mem_threadgroup);
+          loader_w.load_safe(tile_w);
+          threadgroup_barrier(mem_flags::mem_threadgroup);
+
+          STEEL_PRAGMA_NO_UNROLL
+          for (int kk1 = 0; kk1 < BK; kk1 += SK) {
+            NAXTile<T, TM, TK, ASubTile> Atile;
+            NAXTile<T, BR, BC, BSubTile> Btile;
+
+            volatile int compiler_barrier;
+
+            const short psk = min(int(SK), max(0, (BK - kk1)));
+            Atile.load_safe(xn + kk1, K, short2(psk, sgp_sm));
+
+            if constexpr (transpose) {
+              Btile.template load<T, BK_padded, 1>(Ws + tn * BK_padded + kk1);
+            } else {
+              Btile.template load<T, BN_padded, 1>(Ws + tn + kk1 * BN_padded);
+            }
+
+            tile_matmad_nax(
+                Dtile,
+                Atile,
+                metal::bool_constant<false>{},
+                Btile,
+                metal::bool_constant<transpose>{});
+
+            (void)compiler_barrier;
+          }
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        const short m_lo_lim = min(int(sgp_sm), max(0, offset - tm));
+        const short m_hi_lim = min(int(sgp_sm), max(0, offset_next - tm));
+
+        // Store results to device memory
+        if constexpr (kAlignedN.value) {
+          if (m_lo_lim == 0 && m_hi_lim == SM) {
+            Dtile.store(y + tm * N + tn, N);
+          } else {
+            Dtile.store_slice(
+                y + tm * N + tn, N, short2(0, m_lo_lim), short2(SN, m_hi_lim));
+          }
+        } else {
+          Dtile.store_slice(
+              y + tm * N + tn,
+              N,
+              short2(0, m_lo_lim),
+              short2(sgp_sn, m_hi_lim));
+        }
+      });
+    });
+  }
+}
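affine_gather_qmm_rhs_nax assumes the rows of x arrive pre-sorted by their expert index, so within one BM-row block it walks `indices` once, splitting the block into maximal runs that share a weight matrix; each run re-points the loader at that expert's weights, and the m_lo_lim/m_hi_lim clamps confine the store to the run's rows. The run-splitting logic in isolation, transliterated to host C++ (a sketch; the printf stands in for the per-run matmul and masked store):

#include <cstdint>
#include <cstdio>
#include <vector>

// The row-run walk from the kernel's while loop: given a sorted slice of
// `indices` covering tgp_bm rows, emit (expert, row_begin, row_end) runs.
void split_runs(const std::vector<uint32_t>& indices, int y_row, int tgp_bm) {
  int n = 0;
  uint32_t index_next = indices[y_row];
  int offset_next = 0;
  while (n < tgp_bm) {
    n++;
    const int offset = offset_next;      // first row of this run
    const uint32_t index = index_next;   // expert id shared by the run
    offset_next = tgp_bm;                // assume the run fills the block...
    for (; n < tgp_bm; n++) {
      if (indices[y_row + n] != index) { // ...until a new expert id appears
        offset_next = n;
        index_next = indices[y_row + n];
        break;
      }
    }
    // In the kernel: re-point loader_w at expert `index`, run the matmul,
    // and store only rows [offset, offset_next) of the BM x BN block.
    printf("expert %u -> rows [%d, %d)\n", index, offset, offset_next);
  }
}

int main() {
  std::vector<uint32_t> idx = {0, 0, 0, 2, 2, 5, 5, 5};
  split_runs(idx, 0, int(idx.size())); // runs: [0,3), [3,5), [5,8)
}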