mlx 0.30.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (599)
  1. checksums.yaml +7 -0
  2. data/ext/mlx/extconf.rb +94 -0
  3. data/ext/mlx/native.cpp +8027 -0
  4. data/lib/mlx/core.rb +1678 -0
  5. data/lib/mlx/distributed_utils/common.rb +116 -0
  6. data/lib/mlx/distributed_utils/config.rb +600 -0
  7. data/lib/mlx/distributed_utils/launch.rb +490 -0
  8. data/lib/mlx/extension.rb +24 -0
  9. data/lib/mlx/nn/base.rb +388 -0
  10. data/lib/mlx/nn/init.rb +140 -0
  11. data/lib/mlx/nn/layers/activations.rb +336 -0
  12. data/lib/mlx/nn/layers/base.rb +6 -0
  13. data/lib/mlx/nn/layers/containers.rb +20 -0
  14. data/lib/mlx/nn/layers/convolution.rb +120 -0
  15. data/lib/mlx/nn/layers/convolution_transpose.rb +114 -0
  16. data/lib/mlx/nn/layers/distributed.rb +309 -0
  17. data/lib/mlx/nn/layers/dropout.rb +75 -0
  18. data/lib/mlx/nn/layers/embedding.rb +28 -0
  19. data/lib/mlx/nn/layers/linear.rb +79 -0
  20. data/lib/mlx/nn/layers/normalization.rb +216 -0
  21. data/lib/mlx/nn/layers/pooling.rb +167 -0
  22. data/lib/mlx/nn/layers/positional_encoding.rb +126 -0
  23. data/lib/mlx/nn/layers/quantized.rb +215 -0
  24. data/lib/mlx/nn/layers/recurrent.rb +135 -0
  25. data/lib/mlx/nn/layers/transformer.rb +330 -0
  26. data/lib/mlx/nn/layers/upsample.rb +97 -0
  27. data/lib/mlx/nn/layers.rb +18 -0
  28. data/lib/mlx/nn/losses.rb +251 -0
  29. data/lib/mlx/nn/utils.rb +167 -0
  30. data/lib/mlx/nn.rb +12 -0
  31. data/lib/mlx/optimizers/optimizers.rb +808 -0
  32. data/lib/mlx/optimizers/schedulers.rb +62 -0
  33. data/lib/mlx/optimizers.rb +9 -0
  34. data/lib/mlx/utils.rb +171 -0
  35. data/lib/mlx/version.rb +5 -0
  36. data/lib/mlx.rb +64 -0
  37. data/mlx/CMakeLists.txt +449 -0
  38. data/mlx/cmake/FindCUDNN.cmake +177 -0
  39. data/mlx/cmake/FindNCCL.cmake +54 -0
  40. data/mlx/cmake/Findnvpl.cmake +3 -0
  41. data/mlx/cmake/extension.cmake +50 -0
  42. data/mlx/mlx/3rdparty/.clang-format +2 -0
  43. data/mlx/mlx/3rdparty/pocketfft.h +3581 -0
  44. data/mlx/mlx/CMakeLists.txt +107 -0
  45. data/mlx/mlx/allocator.h +75 -0
  46. data/mlx/mlx/api.h +29 -0
  47. data/mlx/mlx/array.cpp +354 -0
  48. data/mlx/mlx/array.h +647 -0
  49. data/mlx/mlx/backend/common/CMakeLists.txt +9 -0
  50. data/mlx/mlx/backend/common/binary.h +97 -0
  51. data/mlx/mlx/backend/common/broadcasting.cpp +24 -0
  52. data/mlx/mlx/backend/common/broadcasting.h +11 -0
  53. data/mlx/mlx/backend/common/buffer_cache.h +158 -0
  54. data/mlx/mlx/backend/common/common.cpp +305 -0
  55. data/mlx/mlx/backend/common/compiled.cpp +243 -0
  56. data/mlx/mlx/backend/common/compiled.h +77 -0
  57. data/mlx/mlx/backend/common/copy.h +50 -0
  58. data/mlx/mlx/backend/common/hadamard.h +109 -0
  59. data/mlx/mlx/backend/common/load.cpp +57 -0
  60. data/mlx/mlx/backend/common/matmul.h +67 -0
  61. data/mlx/mlx/backend/common/reduce.cpp +154 -0
  62. data/mlx/mlx/backend/common/reduce.h +59 -0
  63. data/mlx/mlx/backend/common/slicing.cpp +71 -0
  64. data/mlx/mlx/backend/common/slicing.h +20 -0
  65. data/mlx/mlx/backend/common/ternary.h +85 -0
  66. data/mlx/mlx/backend/common/unary.h +29 -0
  67. data/mlx/mlx/backend/common/utils.cpp +231 -0
  68. data/mlx/mlx/backend/common/utils.h +205 -0
  69. data/mlx/mlx/backend/cpu/CMakeLists.txt +88 -0
  70. data/mlx/mlx/backend/cpu/arange.h +28 -0
  71. data/mlx/mlx/backend/cpu/arg_reduce.cpp +124 -0
  72. data/mlx/mlx/backend/cpu/binary.cpp +269 -0
  73. data/mlx/mlx/backend/cpu/binary.h +517 -0
  74. data/mlx/mlx/backend/cpu/binary_ops.h +98 -0
  75. data/mlx/mlx/backend/cpu/binary_two.h +166 -0
  76. data/mlx/mlx/backend/cpu/cholesky.cpp +85 -0
  77. data/mlx/mlx/backend/cpu/compiled.cpp +357 -0
  78. data/mlx/mlx/backend/cpu/compiled_preamble.h +12 -0
  79. data/mlx/mlx/backend/cpu/conv.cpp +1351 -0
  80. data/mlx/mlx/backend/cpu/copy.cpp +386 -0
  81. data/mlx/mlx/backend/cpu/copy.h +36 -0
  82. data/mlx/mlx/backend/cpu/device_info.cpp +113 -0
  83. data/mlx/mlx/backend/cpu/device_info.h +28 -0
  84. data/mlx/mlx/backend/cpu/distributed.cpp +103 -0
  85. data/mlx/mlx/backend/cpu/eig.cpp +281 -0
  86. data/mlx/mlx/backend/cpu/eigh.cpp +241 -0
  87. data/mlx/mlx/backend/cpu/encoder.cpp +16 -0
  88. data/mlx/mlx/backend/cpu/encoder.h +67 -0
  89. data/mlx/mlx/backend/cpu/eval.cpp +40 -0
  90. data/mlx/mlx/backend/cpu/eval.h +12 -0
  91. data/mlx/mlx/backend/cpu/fft.cpp +120 -0
  92. data/mlx/mlx/backend/cpu/gemm.h +26 -0
  93. data/mlx/mlx/backend/cpu/gemms/bnns.cpp +214 -0
  94. data/mlx/mlx/backend/cpu/gemms/cblas.cpp +134 -0
  95. data/mlx/mlx/backend/cpu/gemms/simd_bf16.cpp +45 -0
  96. data/mlx/mlx/backend/cpu/gemms/simd_fp16.cpp +45 -0
  97. data/mlx/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
  98. data/mlx/mlx/backend/cpu/hadamard.cpp +121 -0
  99. data/mlx/mlx/backend/cpu/indexing.cpp +854 -0
  100. data/mlx/mlx/backend/cpu/inverse.cpp +160 -0
  101. data/mlx/mlx/backend/cpu/jit_compiler.cpp +166 -0
  102. data/mlx/mlx/backend/cpu/jit_compiler.h +20 -0
  103. data/mlx/mlx/backend/cpu/lapack.h +80 -0
  104. data/mlx/mlx/backend/cpu/logsumexp.cpp +139 -0
  105. data/mlx/mlx/backend/cpu/luf.cpp +120 -0
  106. data/mlx/mlx/backend/cpu/make_compiled_preamble.ps1 +38 -0
  107. data/mlx/mlx/backend/cpu/make_compiled_preamble.sh +41 -0
  108. data/mlx/mlx/backend/cpu/masked_mm.cpp +608 -0
  109. data/mlx/mlx/backend/cpu/matmul.cpp +166 -0
  110. data/mlx/mlx/backend/cpu/primitives.cpp +478 -0
  111. data/mlx/mlx/backend/cpu/qrf.cpp +147 -0
  112. data/mlx/mlx/backend/cpu/quantized.cpp +1370 -0
  113. data/mlx/mlx/backend/cpu/reduce.cpp +587 -0
  114. data/mlx/mlx/backend/cpu/scan.cpp +338 -0
  115. data/mlx/mlx/backend/cpu/select.cpp +95 -0
  116. data/mlx/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
  117. data/mlx/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
  118. data/mlx/mlx/backend/cpu/simd/base_simd.h +319 -0
  119. data/mlx/mlx/backend/cpu/simd/math.h +193 -0
  120. data/mlx/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
  121. data/mlx/mlx/backend/cpu/simd/simd.h +4 -0
  122. data/mlx/mlx/backend/cpu/simd/type.h +11 -0
  123. data/mlx/mlx/backend/cpu/slicing.h +21 -0
  124. data/mlx/mlx/backend/cpu/softmax.cpp +170 -0
  125. data/mlx/mlx/backend/cpu/sort.cpp +481 -0
  126. data/mlx/mlx/backend/cpu/svd.cpp +289 -0
  127. data/mlx/mlx/backend/cpu/ternary.h +154 -0
  128. data/mlx/mlx/backend/cpu/threefry.cpp +31 -0
  129. data/mlx/mlx/backend/cpu/threefry.h +21 -0
  130. data/mlx/mlx/backend/cpu/unary.cpp +238 -0
  131. data/mlx/mlx/backend/cpu/unary.h +281 -0
  132. data/mlx/mlx/backend/cpu/unary_ops.h +175 -0
  133. data/mlx/mlx/backend/cuda/CMakeLists.txt +265 -0
  134. data/mlx/mlx/backend/cuda/allocator.cpp +451 -0
  135. data/mlx/mlx/backend/cuda/allocator.h +94 -0
  136. data/mlx/mlx/backend/cuda/arange.cu +68 -0
  137. data/mlx/mlx/backend/cuda/arg_reduce.cu +189 -0
  138. data/mlx/mlx/backend/cuda/bin2h.cmake +150 -0
  139. data/mlx/mlx/backend/cuda/binary/CMakeLists.txt +21 -0
  140. data/mlx/mlx/backend/cuda/binary/add.cu +7 -0
  141. data/mlx/mlx/backend/cuda/binary/arctan2.cu +7 -0
  142. data/mlx/mlx/backend/cuda/binary/binary.cuh +383 -0
  143. data/mlx/mlx/backend/cuda/binary/bitwise_binary.cu +27 -0
  144. data/mlx/mlx/backend/cuda/binary/divide.cu +7 -0
  145. data/mlx/mlx/backend/cuda/binary/equal.cu +15 -0
  146. data/mlx/mlx/backend/cuda/binary/greater.cu +7 -0
  147. data/mlx/mlx/backend/cuda/binary/greater_equal.cu +7 -0
  148. data/mlx/mlx/backend/cuda/binary/less.cu +7 -0
  149. data/mlx/mlx/backend/cuda/binary/less_equal.cu +7 -0
  150. data/mlx/mlx/backend/cuda/binary/log_add_exp.cu +7 -0
  151. data/mlx/mlx/backend/cuda/binary/logical_and.cu +7 -0
  152. data/mlx/mlx/backend/cuda/binary/logical_or.cu +7 -0
  153. data/mlx/mlx/backend/cuda/binary/maximum.cu +7 -0
  154. data/mlx/mlx/backend/cuda/binary/minimum.cu +7 -0
  155. data/mlx/mlx/backend/cuda/binary/multiply.cu +7 -0
  156. data/mlx/mlx/backend/cuda/binary/not_equal.cu +7 -0
  157. data/mlx/mlx/backend/cuda/binary/power.cu +7 -0
  158. data/mlx/mlx/backend/cuda/binary/remainder.cu +7 -0
  159. data/mlx/mlx/backend/cuda/binary/subtract.cu +7 -0
  160. data/mlx/mlx/backend/cuda/binary_two.cu +412 -0
  161. data/mlx/mlx/backend/cuda/compiled.cpp +357 -0
  162. data/mlx/mlx/backend/cuda/conv/conv.h +126 -0
  163. data/mlx/mlx/backend/cuda/conv/gemm_conv.cu +217 -0
  164. data/mlx/mlx/backend/cuda/conv/gemm_grouped_conv.cu +231 -0
  165. data/mlx/mlx/backend/cuda/conv.cpp +403 -0
  166. data/mlx/mlx/backend/cuda/copy/copy.cuh +55 -0
  167. data/mlx/mlx/backend/cuda/copy/copy_contiguous.cu +88 -0
  168. data/mlx/mlx/backend/cuda/copy/copy_general.cu +171 -0
  169. data/mlx/mlx/backend/cuda/copy/copy_general_dynamic.cu +118 -0
  170. data/mlx/mlx/backend/cuda/copy/copy_general_input.cu +229 -0
  171. data/mlx/mlx/backend/cuda/copy.cu +132 -0
  172. data/mlx/mlx/backend/cuda/cublas_utils.cpp +222 -0
  173. data/mlx/mlx/backend/cuda/cublas_utils.h +95 -0
  174. data/mlx/mlx/backend/cuda/cuda.h +21 -0
  175. data/mlx/mlx/backend/cuda/cuda_utils.h +90 -0
  176. data/mlx/mlx/backend/cuda/cudnn_utils.cpp +133 -0
  177. data/mlx/mlx/backend/cuda/cudnn_utils.h +187 -0
  178. data/mlx/mlx/backend/cuda/custom_kernel.cpp +379 -0
  179. data/mlx/mlx/backend/cuda/cutlass_utils.cuh +46 -0
  180. data/mlx/mlx/backend/cuda/delayload.cpp +80 -0
  181. data/mlx/mlx/backend/cuda/device/atomic_ops.cuh +63 -0
  182. data/mlx/mlx/backend/cuda/device/binary_ops.cuh +300 -0
  183. data/mlx/mlx/backend/cuda/device/cast_op.cuh +118 -0
  184. data/mlx/mlx/backend/cuda/device/complex.cuh +60 -0
  185. data/mlx/mlx/backend/cuda/device/config.h +12 -0
  186. data/mlx/mlx/backend/cuda/device/fp16_math.cuh +96 -0
  187. data/mlx/mlx/backend/cuda/device/gather.cuh +53 -0
  188. data/mlx/mlx/backend/cuda/device/gather_axis.cuh +65 -0
  189. data/mlx/mlx/backend/cuda/device/indexing.cuh +30 -0
  190. data/mlx/mlx/backend/cuda/device/scatter.cuh +68 -0
  191. data/mlx/mlx/backend/cuda/device/scatter_axis.cuh +67 -0
  192. data/mlx/mlx/backend/cuda/device/scatter_ops.cuh +44 -0
  193. data/mlx/mlx/backend/cuda/device/ternary_ops.cuh +13 -0
  194. data/mlx/mlx/backend/cuda/device/unary_ops.cuh +350 -0
  195. data/mlx/mlx/backend/cuda/device/utils.cuh +464 -0
  196. data/mlx/mlx/backend/cuda/device.cpp +522 -0
  197. data/mlx/mlx/backend/cuda/device.h +195 -0
  198. data/mlx/mlx/backend/cuda/device_info.cpp +232 -0
  199. data/mlx/mlx/backend/cuda/distributed.cu +121 -0
  200. data/mlx/mlx/backend/cuda/eval.cpp +66 -0
  201. data/mlx/mlx/backend/cuda/event.cu +415 -0
  202. data/mlx/mlx/backend/cuda/event.h +79 -0
  203. data/mlx/mlx/backend/cuda/fence.cpp +42 -0
  204. data/mlx/mlx/backend/cuda/gemms/cublas_gemm.cpp +233 -0
  205. data/mlx/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
  206. data/mlx/mlx/backend/cuda/gemms/cublas_gemm_batched_12_0.cpp +77 -0
  207. data/mlx/mlx/backend/cuda/gemms/cublas_gemm_batched_12_9.cu +329 -0
  208. data/mlx/mlx/backend/cuda/gemms/gemv.cu +327 -0
  209. data/mlx/mlx/backend/cuda/gemms/gemv.h +34 -0
  210. data/mlx/mlx/backend/cuda/gemms/grouped_gemm.h +25 -0
  211. data/mlx/mlx/backend/cuda/gemms/grouped_gemm_unaligned.cu +358 -0
  212. data/mlx/mlx/backend/cuda/indexing.cpp +434 -0
  213. data/mlx/mlx/backend/cuda/jit_module.cpp +443 -0
  214. data/mlx/mlx/backend/cuda/jit_module.h +120 -0
  215. data/mlx/mlx/backend/cuda/kernel_utils.cu +52 -0
  216. data/mlx/mlx/backend/cuda/kernel_utils.cuh +148 -0
  217. data/mlx/mlx/backend/cuda/layer_norm.cu +417 -0
  218. data/mlx/mlx/backend/cuda/load.cpp +60 -0
  219. data/mlx/mlx/backend/cuda/logsumexp.cu +161 -0
  220. data/mlx/mlx/backend/cuda/lru_cache.h +190 -0
  221. data/mlx/mlx/backend/cuda/matmul.cpp +373 -0
  222. data/mlx/mlx/backend/cuda/no_cuda.cpp +47 -0
  223. data/mlx/mlx/backend/cuda/primitives.cpp +46 -0
  224. data/mlx/mlx/backend/cuda/quantized/affine_quantize.cu +329 -0
  225. data/mlx/mlx/backend/cuda/quantized/convert_fp8.cu +19 -0
  226. data/mlx/mlx/backend/cuda/quantized/cublas_qqmm.cpp +206 -0
  227. data/mlx/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
  228. data/mlx/mlx/backend/cuda/quantized/cuda_fp4.h +100 -0
  229. data/mlx/mlx/backend/cuda/quantized/fp_quantize.cu +496 -0
  230. data/mlx/mlx/backend/cuda/quantized/mxfp8_quantize.cuh +32 -0
  231. data/mlx/mlx/backend/cuda/quantized/no_qqmm_impl.cpp +26 -0
  232. data/mlx/mlx/backend/cuda/quantized/nvfp4_quantize.cuh +334 -0
  233. data/mlx/mlx/backend/cuda/quantized/qmv.cu +304 -0
  234. data/mlx/mlx/backend/cuda/quantized/qmv.h +21 -0
  235. data/mlx/mlx/backend/cuda/quantized/qqmm.cpp +158 -0
  236. data/mlx/mlx/backend/cuda/quantized/qqmm_impl.cpp +50 -0
  237. data/mlx/mlx/backend/cuda/quantized/qqmm_impl.h +26 -0
  238. data/mlx/mlx/backend/cuda/quantized/qqmm_utils.cu +227 -0
  239. data/mlx/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
  240. data/mlx/mlx/backend/cuda/quantized/quantized.cpp +85 -0
  241. data/mlx/mlx/backend/cuda/quantized/quantized.h +53 -0
  242. data/mlx/mlx/backend/cuda/quantized/quantized_utils.cuh +88 -0
  243. data/mlx/mlx/backend/cuda/quantized/quantized_utils.h +50 -0
  244. data/mlx/mlx/backend/cuda/random.cu +202 -0
  245. data/mlx/mlx/backend/cuda/reduce/all_reduce.cu +159 -0
  246. data/mlx/mlx/backend/cuda/reduce/col_reduce.cu +510 -0
  247. data/mlx/mlx/backend/cuda/reduce/init_reduce.cu +50 -0
  248. data/mlx/mlx/backend/cuda/reduce/reduce.cuh +71 -0
  249. data/mlx/mlx/backend/cuda/reduce/reduce_ops.cuh +211 -0
  250. data/mlx/mlx/backend/cuda/reduce/reduce_utils.cuh +145 -0
  251. data/mlx/mlx/backend/cuda/reduce/row_reduce.cu +361 -0
  252. data/mlx/mlx/backend/cuda/reduce.cu +73 -0
  253. data/mlx/mlx/backend/cuda/rms_norm.cu +536 -0
  254. data/mlx/mlx/backend/cuda/rope.cu +429 -0
  255. data/mlx/mlx/backend/cuda/scaled_dot_product_attention.cpp +681 -0
  256. data/mlx/mlx/backend/cuda/scaled_dot_product_attention.cu +796 -0
  257. data/mlx/mlx/backend/cuda/scan.cu +468 -0
  258. data/mlx/mlx/backend/cuda/slicing.cpp +111 -0
  259. data/mlx/mlx/backend/cuda/softmax.cu +162 -0
  260. data/mlx/mlx/backend/cuda/sort.cu +1076 -0
  261. data/mlx/mlx/backend/cuda/steel/defines.cuh +9 -0
  262. data/mlx/mlx/backend/cuda/steel/gemm.cuh +101 -0
  263. data/mlx/mlx/backend/cuda/steel/mma.cuh +117 -0
  264. data/mlx/mlx/backend/cuda/steel/tiles.cuh +450 -0
  265. data/mlx/mlx/backend/cuda/steel/utils.cuh +89 -0
  266. data/mlx/mlx/backend/cuda/ternary.cu +271 -0
  267. data/mlx/mlx/backend/cuda/unary/CMakeLists.txt +34 -0
  268. data/mlx/mlx/backend/cuda/unary/abs.cu +7 -0
  269. data/mlx/mlx/backend/cuda/unary/arccos.cu +7 -0
  270. data/mlx/mlx/backend/cuda/unary/arccosh.cu +7 -0
  271. data/mlx/mlx/backend/cuda/unary/arcsin.cu +7 -0
  272. data/mlx/mlx/backend/cuda/unary/arcsinh.cu +7 -0
  273. data/mlx/mlx/backend/cuda/unary/arctan.cu +7 -0
  274. data/mlx/mlx/backend/cuda/unary/arctanh.cu +7 -0
  275. data/mlx/mlx/backend/cuda/unary/bitwise_invert.cu +7 -0
  276. data/mlx/mlx/backend/cuda/unary/ceil.cu +7 -0
  277. data/mlx/mlx/backend/cuda/unary/conjugate.cu +7 -0
  278. data/mlx/mlx/backend/cuda/unary/cos.cu +7 -0
  279. data/mlx/mlx/backend/cuda/unary/cosh.cu +7 -0
  280. data/mlx/mlx/backend/cuda/unary/erf.cu +7 -0
  281. data/mlx/mlx/backend/cuda/unary/erf_inv.cu +7 -0
  282. data/mlx/mlx/backend/cuda/unary/exp.cu +7 -0
  283. data/mlx/mlx/backend/cuda/unary/expm1.cu +7 -0
  284. data/mlx/mlx/backend/cuda/unary/floor.cu +7 -0
  285. data/mlx/mlx/backend/cuda/unary/imag.cu +7 -0
  286. data/mlx/mlx/backend/cuda/unary/log.cu +21 -0
  287. data/mlx/mlx/backend/cuda/unary/log1p.cu +7 -0
  288. data/mlx/mlx/backend/cuda/unary/logical_not.cu +7 -0
  289. data/mlx/mlx/backend/cuda/unary/negative.cu +7 -0
  290. data/mlx/mlx/backend/cuda/unary/real.cu +7 -0
  291. data/mlx/mlx/backend/cuda/unary/round.cu +18 -0
  292. data/mlx/mlx/backend/cuda/unary/sigmoid.cu +7 -0
  293. data/mlx/mlx/backend/cuda/unary/sign.cu +7 -0
  294. data/mlx/mlx/backend/cuda/unary/sin.cu +7 -0
  295. data/mlx/mlx/backend/cuda/unary/sinh.cu +7 -0
  296. data/mlx/mlx/backend/cuda/unary/sqrt.cu +15 -0
  297. data/mlx/mlx/backend/cuda/unary/square.cu +7 -0
  298. data/mlx/mlx/backend/cuda/unary/tan.cu +7 -0
  299. data/mlx/mlx/backend/cuda/unary/tanh.cu +7 -0
  300. data/mlx/mlx/backend/cuda/unary/unary.cuh +224 -0
  301. data/mlx/mlx/backend/cuda/utils.cpp +116 -0
  302. data/mlx/mlx/backend/cuda/utils.h +49 -0
  303. data/mlx/mlx/backend/cuda/vector_types.cuh +48 -0
  304. data/mlx/mlx/backend/cuda/worker.cpp +79 -0
  305. data/mlx/mlx/backend/cuda/worker.h +55 -0
  306. data/mlx/mlx/backend/gpu/CMakeLists.txt +5 -0
  307. data/mlx/mlx/backend/gpu/copy.cpp +89 -0
  308. data/mlx/mlx/backend/gpu/copy.h +57 -0
  309. data/mlx/mlx/backend/gpu/device_info.h +36 -0
  310. data/mlx/mlx/backend/gpu/eval.h +18 -0
  311. data/mlx/mlx/backend/gpu/primitives.cpp +307 -0
  312. data/mlx/mlx/backend/gpu/slicing.cpp +44 -0
  313. data/mlx/mlx/backend/gpu/slicing.h +36 -0
  314. data/mlx/mlx/backend/metal/CMakeLists.txt +144 -0
  315. data/mlx/mlx/backend/metal/allocator.cpp +279 -0
  316. data/mlx/mlx/backend/metal/allocator.h +79 -0
  317. data/mlx/mlx/backend/metal/binary.cpp +257 -0
  318. data/mlx/mlx/backend/metal/binary.h +33 -0
  319. data/mlx/mlx/backend/metal/compiled.cpp +471 -0
  320. data/mlx/mlx/backend/metal/conv.cpp +1118 -0
  321. data/mlx/mlx/backend/metal/copy.cpp +235 -0
  322. data/mlx/mlx/backend/metal/custom_kernel.cpp +430 -0
  323. data/mlx/mlx/backend/metal/device.cpp +816 -0
  324. data/mlx/mlx/backend/metal/device.h +289 -0
  325. data/mlx/mlx/backend/metal/device_info.cpp +58 -0
  326. data/mlx/mlx/backend/metal/distributed.cpp +38 -0
  327. data/mlx/mlx/backend/metal/eval.cpp +97 -0
  328. data/mlx/mlx/backend/metal/event.cpp +62 -0
  329. data/mlx/mlx/backend/metal/fence.cpp +162 -0
  330. data/mlx/mlx/backend/metal/fft.cpp +807 -0
  331. data/mlx/mlx/backend/metal/hadamard.cpp +198 -0
  332. data/mlx/mlx/backend/metal/indexing.cpp +727 -0
  333. data/mlx/mlx/backend/metal/jit/includes.h +58 -0
  334. data/mlx/mlx/backend/metal/jit/indexing.h +76 -0
  335. data/mlx/mlx/backend/metal/jit_kernels.cpp +1118 -0
  336. data/mlx/mlx/backend/metal/kernels/CMakeLists.txt +193 -0
  337. data/mlx/mlx/backend/metal/kernels/arange.h +9 -0
  338. data/mlx/mlx/backend/metal/kernels/arange.metal +20 -0
  339. data/mlx/mlx/backend/metal/kernels/arg_reduce.metal +182 -0
  340. data/mlx/mlx/backend/metal/kernels/atomic.h +345 -0
  341. data/mlx/mlx/backend/metal/kernels/bf16.h +16 -0
  342. data/mlx/mlx/backend/metal/kernels/bf16_math.h +380 -0
  343. data/mlx/mlx/backend/metal/kernels/binary.h +199 -0
  344. data/mlx/mlx/backend/metal/kernels/binary.metal +109 -0
  345. data/mlx/mlx/backend/metal/kernels/binary_ops.h +330 -0
  346. data/mlx/mlx/backend/metal/kernels/binary_two.h +244 -0
  347. data/mlx/mlx/backend/metal/kernels/binary_two.metal +54 -0
  348. data/mlx/mlx/backend/metal/kernels/cexpf.h +134 -0
  349. data/mlx/mlx/backend/metal/kernels/complex.h +173 -0
  350. data/mlx/mlx/backend/metal/kernels/conv.metal +701 -0
  351. data/mlx/mlx/backend/metal/kernels/copy.h +276 -0
  352. data/mlx/mlx/backend/metal/kernels/copy.metal +75 -0
  353. data/mlx/mlx/backend/metal/kernels/defines.h +24 -0
  354. data/mlx/mlx/backend/metal/kernels/erf.h +69 -0
  355. data/mlx/mlx/backend/metal/kernels/expm1f.h +90 -0
  356. data/mlx/mlx/backend/metal/kernels/fence.metal +52 -0
  357. data/mlx/mlx/backend/metal/kernels/fft/radix.h +328 -0
  358. data/mlx/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
  359. data/mlx/mlx/backend/metal/kernels/fft.h +486 -0
  360. data/mlx/mlx/backend/metal/kernels/fft.metal +67 -0
  361. data/mlx/mlx/backend/metal/kernels/fp4.h +48 -0
  362. data/mlx/mlx/backend/metal/kernels/fp8.h +80 -0
  363. data/mlx/mlx/backend/metal/kernels/fp_quantized.h +1850 -0
  364. data/mlx/mlx/backend/metal/kernels/fp_quantized.metal +153 -0
  365. data/mlx/mlx/backend/metal/kernels/fp_quantized_nax.h +1044 -0
  366. data/mlx/mlx/backend/metal/kernels/fp_quantized_nax.metal +79 -0
  367. data/mlx/mlx/backend/metal/kernels/gemv.metal +868 -0
  368. data/mlx/mlx/backend/metal/kernels/gemv_masked.h +827 -0
  369. data/mlx/mlx/backend/metal/kernels/gemv_masked.metal +76 -0
  370. data/mlx/mlx/backend/metal/kernels/hadamard.h +182 -0
  371. data/mlx/mlx/backend/metal/kernels/indexing/gather.h +51 -0
  372. data/mlx/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
  373. data/mlx/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
  374. data/mlx/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
  375. data/mlx/mlx/backend/metal/kernels/indexing/masked_scatter.h +41 -0
  376. data/mlx/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
  377. data/mlx/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
  378. data/mlx/mlx/backend/metal/kernels/layer_norm.metal +433 -0
  379. data/mlx/mlx/backend/metal/kernels/logging.h +26 -0
  380. data/mlx/mlx/backend/metal/kernels/logsumexp.h +140 -0
  381. data/mlx/mlx/backend/metal/kernels/logsumexp.metal +18 -0
  382. data/mlx/mlx/backend/metal/kernels/quantized.h +2508 -0
  383. data/mlx/mlx/backend/metal/kernels/quantized.metal +144 -0
  384. data/mlx/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
  385. data/mlx/mlx/backend/metal/kernels/quantized_nax.metal +106 -0
  386. data/mlx/mlx/backend/metal/kernels/quantized_utils.h +90 -0
  387. data/mlx/mlx/backend/metal/kernels/random.metal +103 -0
  388. data/mlx/mlx/backend/metal/kernels/reduce.h +5 -0
  389. data/mlx/mlx/backend/metal/kernels/reduce.metal +169 -0
  390. data/mlx/mlx/backend/metal/kernels/reduce_utils.h +6 -0
  391. data/mlx/mlx/backend/metal/kernels/reduction/ops.h +275 -0
  392. data/mlx/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
  393. data/mlx/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
  394. data/mlx/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
  395. data/mlx/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
  396. data/mlx/mlx/backend/metal/kernels/rms_norm.metal +391 -0
  397. data/mlx/mlx/backend/metal/kernels/rope.metal +229 -0
  398. data/mlx/mlx/backend/metal/kernels/scaled_dot_product_attention.metal +44 -0
  399. data/mlx/mlx/backend/metal/kernels/scan.h +514 -0
  400. data/mlx/mlx/backend/metal/kernels/scan.metal +109 -0
  401. data/mlx/mlx/backend/metal/kernels/sdpa_vector.h +394 -0
  402. data/mlx/mlx/backend/metal/kernels/softmax.h +190 -0
  403. data/mlx/mlx/backend/metal/kernels/softmax.metal +24 -0
  404. data/mlx/mlx/backend/metal/kernels/sort.h +719 -0
  405. data/mlx/mlx/backend/metal/kernels/sort.metal +80 -0
  406. data/mlx/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
  407. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +471 -0
  408. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal +27 -0
  409. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
  410. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.metal +28 -0
  411. data/mlx/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
  412. data/mlx/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
  413. data/mlx/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
  414. data/mlx/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
  415. data/mlx/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
  416. data/mlx/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
  417. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
  418. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.metal +56 -0
  419. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
  420. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.metal +47 -0
  421. data/mlx/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
  422. data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
  423. data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
  424. data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
  425. data/mlx/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
  426. data/mlx/mlx/backend/metal/kernels/steel/defines.h +7 -0
  427. data/mlx/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
  428. data/mlx/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +157 -0
  429. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
  430. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.metal +34 -0
  431. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +219 -0
  432. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.metal +30 -0
  433. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
  434. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.metal +59 -0
  435. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +143 -0
  436. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.metal +37 -0
  437. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
  438. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.metal +76 -0
  439. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
  440. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.metal +43 -0
  441. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
  442. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.metal +76 -0
  443. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.h +152 -0
  444. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.metal +30 -0
  445. data/mlx/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
  446. data/mlx/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
  447. data/mlx/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
  448. data/mlx/mlx/backend/metal/kernels/steel/gemm/params.h +65 -0
  449. data/mlx/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
  450. data/mlx/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
  451. data/mlx/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
  452. data/mlx/mlx/backend/metal/kernels/steel/utils.h +42 -0
  453. data/mlx/mlx/backend/metal/kernels/ternary.h +145 -0
  454. data/mlx/mlx/backend/metal/kernels/ternary.metal +48 -0
  455. data/mlx/mlx/backend/metal/kernels/ternary_ops.h +10 -0
  456. data/mlx/mlx/backend/metal/kernels/unary.h +63 -0
  457. data/mlx/mlx/backend/metal/kernels/unary.metal +115 -0
  458. data/mlx/mlx/backend/metal/kernels/unary_ops.h +454 -0
  459. data/mlx/mlx/backend/metal/kernels/utils.h +445 -0
  460. data/mlx/mlx/backend/metal/kernels.h +375 -0
  461. data/mlx/mlx/backend/metal/logsumexp.cpp +95 -0
  462. data/mlx/mlx/backend/metal/make_compiled_preamble.sh +120 -0
  463. data/mlx/mlx/backend/metal/matmul.cpp +2572 -0
  464. data/mlx/mlx/backend/metal/matmul.h +144 -0
  465. data/mlx/mlx/backend/metal/metal.cpp +50 -0
  466. data/mlx/mlx/backend/metal/metal.h +25 -0
  467. data/mlx/mlx/backend/metal/no_metal.cpp +42 -0
  468. data/mlx/mlx/backend/metal/nojit_kernels.cpp +414 -0
  469. data/mlx/mlx/backend/metal/normalization.cpp +433 -0
  470. data/mlx/mlx/backend/metal/primitives.cpp +242 -0
  471. data/mlx/mlx/backend/metal/quantized.cpp +1651 -0
  472. data/mlx/mlx/backend/metal/reduce.cpp +1038 -0
  473. data/mlx/mlx/backend/metal/reduce.h +41 -0
  474. data/mlx/mlx/backend/metal/resident.cpp +100 -0
  475. data/mlx/mlx/backend/metal/resident.h +32 -0
  476. data/mlx/mlx/backend/metal/rope.cpp +165 -0
  477. data/mlx/mlx/backend/metal/scaled_dot_product_attention.cpp +798 -0
  478. data/mlx/mlx/backend/metal/scan.cpp +145 -0
  479. data/mlx/mlx/backend/metal/scan.h +17 -0
  480. data/mlx/mlx/backend/metal/slicing.cpp +99 -0
  481. data/mlx/mlx/backend/metal/softmax.cpp +87 -0
  482. data/mlx/mlx/backend/metal/sort.cpp +368 -0
  483. data/mlx/mlx/backend/metal/ternary.cpp +160 -0
  484. data/mlx/mlx/backend/metal/ternary.h +21 -0
  485. data/mlx/mlx/backend/metal/unary.cpp +161 -0
  486. data/mlx/mlx/backend/metal/unary.h +21 -0
  487. data/mlx/mlx/backend/metal/utils.cpp +77 -0
  488. data/mlx/mlx/backend/metal/utils.h +99 -0
  489. data/mlx/mlx/backend/no_cpu/CMakeLists.txt +7 -0
  490. data/mlx/mlx/backend/no_cpu/compiled.cpp +24 -0
  491. data/mlx/mlx/backend/no_cpu/device_info.cpp +22 -0
  492. data/mlx/mlx/backend/no_cpu/primitives.cpp +146 -0
  493. data/mlx/mlx/backend/no_gpu/CMakeLists.txt +8 -0
  494. data/mlx/mlx/backend/no_gpu/allocator.cpp +134 -0
  495. data/mlx/mlx/backend/no_gpu/apple_memory.h +16 -0
  496. data/mlx/mlx/backend/no_gpu/device_info.cpp +22 -0
  497. data/mlx/mlx/backend/no_gpu/eval.cpp +24 -0
  498. data/mlx/mlx/backend/no_gpu/event.cpp +53 -0
  499. data/mlx/mlx/backend/no_gpu/fence.cpp +54 -0
  500. data/mlx/mlx/backend/no_gpu/linux_memory.h +22 -0
  501. data/mlx/mlx/backend/no_gpu/primitives.cpp +185 -0
  502. data/mlx/mlx/compile.cpp +1243 -0
  503. data/mlx/mlx/compile.h +45 -0
  504. data/mlx/mlx/compile_impl.h +70 -0
  505. data/mlx/mlx/device.cpp +72 -0
  506. data/mlx/mlx/device.h +56 -0
  507. data/mlx/mlx/distributed/CMakeLists.txt +14 -0
  508. data/mlx/mlx/distributed/distributed.cpp +197 -0
  509. data/mlx/mlx/distributed/distributed.h +61 -0
  510. data/mlx/mlx/distributed/distributed_impl.h +59 -0
  511. data/mlx/mlx/distributed/jaccl/CMakeLists.txt +12 -0
  512. data/mlx/mlx/distributed/jaccl/jaccl.cpp +178 -0
  513. data/mlx/mlx/distributed/jaccl/jaccl.h +12 -0
  514. data/mlx/mlx/distributed/jaccl/mesh.cpp +451 -0
  515. data/mlx/mlx/distributed/jaccl/mesh.h +122 -0
  516. data/mlx/mlx/distributed/jaccl/no_jaccl.cpp +20 -0
  517. data/mlx/mlx/distributed/jaccl/ring.cpp +692 -0
  518. data/mlx/mlx/distributed/jaccl/ring.h +178 -0
  519. data/mlx/mlx/distributed/jaccl/utils.cpp +329 -0
  520. data/mlx/mlx/distributed/jaccl/utils.h +342 -0
  521. data/mlx/mlx/distributed/mpi/CMakeLists.txt +5 -0
  522. data/mlx/mlx/distributed/mpi/mpi.cpp +501 -0
  523. data/mlx/mlx/distributed/mpi/mpi.h +12 -0
  524. data/mlx/mlx/distributed/mpi/mpi_declarations.h +28 -0
  525. data/mlx/mlx/distributed/mpi/no_mpi.cpp +20 -0
  526. data/mlx/mlx/distributed/nccl/CMakeLists.txt +26 -0
  527. data/mlx/mlx/distributed/nccl/nccl.cpp +443 -0
  528. data/mlx/mlx/distributed/nccl/nccl.h +12 -0
  529. data/mlx/mlx/distributed/nccl/nccl_stub/CMakeLists.txt +1 -0
  530. data/mlx/mlx/distributed/nccl/nccl_stub/nccl_stubs.cpp +54 -0
  531. data/mlx/mlx/distributed/nccl/no_nccl.cpp +20 -0
  532. data/mlx/mlx/distributed/ops.cpp +186 -0
  533. data/mlx/mlx/distributed/ops.h +57 -0
  534. data/mlx/mlx/distributed/primitives.cpp +95 -0
  535. data/mlx/mlx/distributed/primitives.h +156 -0
  536. data/mlx/mlx/distributed/reduction_ops.h +38 -0
  537. data/mlx/mlx/distributed/ring/CMakeLists.txt +5 -0
  538. data/mlx/mlx/distributed/ring/no_ring.cpp +20 -0
  539. data/mlx/mlx/distributed/ring/ring.cpp +870 -0
  540. data/mlx/mlx/distributed/ring/ring.h +12 -0
  541. data/mlx/mlx/distributed/utils.cpp +206 -0
  542. data/mlx/mlx/distributed/utils.h +67 -0
  543. data/mlx/mlx/dtype.cpp +197 -0
  544. data/mlx/mlx/dtype.h +116 -0
  545. data/mlx/mlx/dtype_utils.cpp +42 -0
  546. data/mlx/mlx/dtype_utils.h +119 -0
  547. data/mlx/mlx/einsum.cpp +941 -0
  548. data/mlx/mlx/einsum.h +23 -0
  549. data/mlx/mlx/event.h +58 -0
  550. data/mlx/mlx/export.cpp +1130 -0
  551. data/mlx/mlx/export.h +137 -0
  552. data/mlx/mlx/export_impl.h +99 -0
  553. data/mlx/mlx/fast.cpp +941 -0
  554. data/mlx/mlx/fast.h +103 -0
  555. data/mlx/mlx/fast_primitives.h +427 -0
  556. data/mlx/mlx/fence.h +39 -0
  557. data/mlx/mlx/fft.cpp +262 -0
  558. data/mlx/mlx/fft.h +159 -0
  559. data/mlx/mlx/graph_utils.cpp +175 -0
  560. data/mlx/mlx/graph_utils.h +67 -0
  561. data/mlx/mlx/io/CMakeLists.txt +25 -0
  562. data/mlx/mlx/io/gguf.cpp +470 -0
  563. data/mlx/mlx/io/gguf.h +20 -0
  564. data/mlx/mlx/io/gguf_quants.cpp +164 -0
  565. data/mlx/mlx/io/load.cpp +397 -0
  566. data/mlx/mlx/io/load.h +175 -0
  567. data/mlx/mlx/io/no_gguf.cpp +20 -0
  568. data/mlx/mlx/io/no_safetensors.cpp +37 -0
  569. data/mlx/mlx/io/safetensors.cpp +234 -0
  570. data/mlx/mlx/io.h +61 -0
  571. data/mlx/mlx/linalg.cpp +708 -0
  572. data/mlx/mlx/linalg.h +115 -0
  573. data/mlx/mlx/memory.h +80 -0
  574. data/mlx/mlx/mlx.h +25 -0
  575. data/mlx/mlx/ops.cpp +6094 -0
  576. data/mlx/mlx/ops.h +1610 -0
  577. data/mlx/mlx/primitives.cpp +5850 -0
  578. data/mlx/mlx/primitives.h +2525 -0
  579. data/mlx/mlx/random.cpp +492 -0
  580. data/mlx/mlx/random.h +283 -0
  581. data/mlx/mlx/scheduler.cpp +73 -0
  582. data/mlx/mlx/scheduler.h +189 -0
  583. data/mlx/mlx/small_vector.h +540 -0
  584. data/mlx/mlx/stream.h +42 -0
  585. data/mlx/mlx/threadpool.h +133 -0
  586. data/mlx/mlx/transforms.cpp +1065 -0
  587. data/mlx/mlx/transforms.h +231 -0
  588. data/mlx/mlx/transforms_impl.h +88 -0
  589. data/mlx/mlx/types/bf16.h +187 -0
  590. data/mlx/mlx/types/complex.h +113 -0
  591. data/mlx/mlx/types/fp16.h +234 -0
  592. data/mlx/mlx/types/half_types.h +58 -0
  593. data/mlx/mlx/types/limits.h +70 -0
  594. data/mlx/mlx/utils.cpp +302 -0
  595. data/mlx/mlx/utils.h +174 -0
  596. data/mlx/mlx/version.cpp +11 -0
  597. data/mlx/mlx/version.h +22 -0
  598. data/mlx/mlx.pc.in +52 -0
  599. metadata +643 -0
data/mlx/mlx/backend/metal/conv.cpp
@@ -0,0 +1,1118 @@
1
+ // Copyright © 2023-2024 Apple Inc.
2
+ #include <algorithm>
3
+ #include <cassert>
4
+ #include <numeric>
5
+
6
+ #include "mlx/backend/gpu/copy.h"
7
+ #include "mlx/backend/metal/device.h"
8
+ #include "mlx/backend/metal/kernels.h"
9
+ #include "mlx/backend/metal/kernels/defines.h"
10
+ #include "mlx/backend/metal/kernels/steel/conv/params.h"
11
+ #include "mlx/backend/metal/matmul.h"
12
+ #include "mlx/backend/metal/utils.h"
13
+ #include "mlx/primitives.h"
14
+ #include "mlx/utils.h"
15
+
16
+ using namespace mlx::steel;
17
+
18
+ namespace mlx::core {
19
+
20
+ namespace {
21
+
22
+ template <int N>
23
+ void explicit_gemm_conv_ND_gpu(
24
+ const Stream& s,
25
+ metal::Device& d,
26
+ const array& in,
27
+ const array& wt,
28
+ array out,
29
+ const MLXConvParams<N>& conv_params) {
30
+ // Get gemm shapes
31
+ int implicit_M = out.size() / conv_params.O;
32
+ int implicit_K = wt.size() / conv_params.O;
33
+ int implicit_N = conv_params.O;
34
+ // Prepare unfolding array
35
+ Shape unfolded_shape{implicit_M, implicit_K};
36
+ array in_unfolded(unfolded_shape, in.dtype(), nullptr, {});
37
+
38
+ in_unfolded.set_data(allocator::malloc(in_unfolded.nbytes()));
39
+
40
+ // Prepare unfolding kernel
41
+ std::string kname;
42
+ kname.reserve(32);
43
+ concatenate(kname, "naive_unfold_nd_", type_to_name(in_unfolded), "_", N);
44
+ auto& compute_encoder = d.get_command_encoder(s.index);
45
+ auto kernel = d.get_kernel(kname);
46
+ compute_encoder.set_compute_pipeline_state(kernel);
47
+
48
+ compute_encoder.set_input_array(in, 0);
49
+ compute_encoder.set_output_array(in_unfolded, 1);
50
+
51
+ compute_encoder.set_bytes(conv_params, 2);
52
+
53
+ // Launch unfolding kernel
54
+ size_t tgp_x = std::min(conv_params.C, 64);
55
+ tgp_x = 32 * ((tgp_x + 32 - 1) / 32);
56
+ size_t tgp_y = 256 / tgp_x;
57
+
58
+ MTL::Size grid_dims = MTL::Size(
59
+ conv_params.C, unfolded_shape[1] / conv_params.C, unfolded_shape[0]);
60
+ MTL::Size group_dims = MTL::Size(
61
+ std::min(tgp_x, grid_dims.width), std::min(tgp_y, grid_dims.height), 1);
62
+
63
+ compute_encoder.dispatch_threads(grid_dims, group_dims);
64
+
65
+ // Reshape weight
66
+ Shape wt_reshape{implicit_K, implicit_N};
67
+ Strides wt_restride{1, implicit_K};
68
+ array wt_reshaped(wt_reshape, wt.dtype(), nullptr, {});
69
+ auto wt_flags = wt.flags();
70
+ wt_flags.row_contiguous = false;
71
+ wt_flags.col_contiguous = true;
72
+ wt_reshaped.copy_shared_buffer(wt, wt_restride, wt_flags, wt.data_size());
73
+
74
+ // Perform gemm
75
+ std::vector<array> copies = {in_unfolded};
76
+ return steel_matmul(
77
+ s,
78
+ d,
79
+ /*a = */ in_unfolded,
80
+ /*b = */ wt_reshaped,
81
+ /*c = */ out,
82
+ /*M = */ implicit_M,
83
+ /*N = */ implicit_N,
84
+ /*K = */ implicit_K,
85
+ /*batch_size_out = */ 1,
86
+ /*a_cols = */ implicit_K,
87
+ /*b_cols = */ implicit_K,
88
+ /*a_transposed = */ false,
89
+ /*b_transposed = */ true,
90
+ /*copies = */ copies);
91
+ }
92
+
93
+ template <int N>
94
+ void explicit_gemm_conv_group_ND_gpu(
95
+ const Stream& s,
96
+ metal::Device& d,
97
+ const array& in,
98
+ const array& wt,
99
+ array out,
100
+ const MLXConvParams<N>& conv_params) {
101
+ const int groups = conv_params.groups;
102
+ const int C_per_group = conv_params.C / conv_params.groups;
103
+ const int O_per_group = conv_params.O / conv_params.groups;
104
+ // Get gemm shapes
105
+ const int implicit_M = out.size() / conv_params.O;
106
+ const int implicit_K = wt.size() / conv_params.O;
107
+ const int implicit_N = O_per_group;
108
+
109
+ int kernel_size = 1;
110
+ for (int i = 0; i < N; ++i) {
111
+ kernel_size *= conv_params.wS[i];
112
+ }
113
+
114
+ // Prepare unfolding array
115
+ Shape unfolded_shape{implicit_M, implicit_K * groups};
116
+ array in_unfolded(unfolded_shape, in.dtype(), nullptr, {});
117
+ in_unfolded.set_data(allocator::malloc(in_unfolded.nbytes()));
118
+
119
+ // Prepare unfolding kernel
120
+ std::string kname;
121
+ kname.reserve(32);
122
+ concatenate(
123
+ kname, "naive_unfold_transpose_nd_", type_to_name(in_unfolded), "_", N);
124
+ auto& compute_encoder = d.get_command_encoder(s.index);
125
+ auto kernel = d.get_kernel(kname);
126
+ compute_encoder.set_compute_pipeline_state(kernel);
127
+
128
+ compute_encoder.set_input_array(in, 0);
129
+ compute_encoder.set_output_array(in_unfolded, 1);
130
+
131
+ compute_encoder.set_bytes(conv_params, 2);
132
+
133
+ // Launch unfolding kernel
134
+ size_t tgp_x = std::min(conv_params.C, 64);
135
+ tgp_x = 32 * ((tgp_x + 32 - 1) / 32);
136
+ size_t tgp_y = 256 / tgp_x;
137
+
138
+ MTL::Size grid_dims = MTL::Size(
139
+ conv_params.C, unfolded_shape[1] / conv_params.C, unfolded_shape[0]);
140
+ MTL::Size group_dims = MTL::Size(
141
+ std::min(tgp_x, grid_dims.width), std::min(tgp_y, grid_dims.height), 1);
142
+
143
+ compute_encoder.dispatch_threads(grid_dims, group_dims);
144
+
145
+ // Transpose kernel weights so that we can slice them by contiguous chunks
146
+ // of channel groups.
147
+ array wt_view(
148
+ {wt.shape(0), C_per_group, kernel_size}, wt.dtype(), nullptr, {});
149
+ wt_view.copy_shared_buffer(
150
+ wt, {wt.strides(0), 1, C_per_group}, wt.flags(), wt.size());
151
+
152
+ // Materialize
153
+ array wt_transpose = contiguous_copy_gpu(wt_view, s);
154
+
155
+ // Perform gemm
156
+ std::vector<array> copies = {in_unfolded, wt_transpose};
157
+ return steel_matmul_regular(
158
+ /* const Stream& s = */ s,
159
+ /* Device& d = */ d,
160
+ /* const array& a = */ in_unfolded,
161
+ /* const array& b = */ wt_transpose,
162
+ /* array& c = */ out,
163
+ /* int M = */ implicit_M,
164
+ /* int N = */ implicit_N,
165
+ /* int K = */ implicit_K,
166
+ /* int batch_size_out = */ groups,
167
+ /* int lda = */ implicit_K * groups,
168
+ /* int ldb = */ implicit_K,
169
+ /* int ldd = */ implicit_N * groups,
170
+ /* bool transpose_a = */ false,
171
+ /* bool transpose_b = */ true,
172
+ /* std::vector<array>& copies = */ copies,
173
+ /* Shape batch_shape = */ {1},
174
+ /* Strides batch_strides = */ {0},
175
+ /* int64_t A_batch_strides = */ int64_t(implicit_K),
176
+ /* int64_t B_batch_strides = */ int64_t(implicit_N) * implicit_K,
177
+ /* int64_t matrix_stride_out = */ int64_t(implicit_N));
178
+ }
179
+
180
+ void implicit_gemm_conv_2D_gpu(
181
+ const Stream& s,
182
+ metal::Device& d,
183
+ const array& in,
184
+ const array& wt,
185
+ array out,
186
+ const MLXConvParams<2>& conv_params) {
187
+ const int groups = conv_params.groups;
188
+ const int C_per_group = conv_params.C / conv_params.groups;
189
+ const int O_per_group = conv_params.O / conv_params.groups;
190
+
191
+ // Deduce implicit gemm size
192
+ const int implicit_M = conv_params.N * conv_params.oS[0] * conv_params.oS[1];
193
+ const int implicit_N = O_per_group;
194
+ const int implicit_K = conv_params.wS[0] * conv_params.wS[1] * C_per_group;
195
+
196
+ // Determine block and warp tiles
197
+ int wm = 2, wn = 2;
198
+
199
+ int bm = implicit_M >= 8192 && C_per_group >= 64 ? 64 : 32;
200
+ int bn = (bm == 64 || implicit_N >= 64) ? 64 : 32;
201
+ int bk = 16;
202
+
203
+ if (implicit_N <= 16) {
204
+ bn = 8;
205
+ wm = 4;
206
+ wn = 1;
207
+ }
208
+
209
+ int tn = (implicit_N + bn - 1) / bn;
210
+ int tm = (implicit_M + bm - 1) / bm;
211
+ int swizzle_log = 0;
212
+
213
+ // Fix small channel specialization
214
+ int n_channel_specialization = 0;
215
+ int channel_k_iters = ((C_per_group + bk - 1) / bk);
216
+ int gemm_k_iters = conv_params.wS[0] * conv_params.wS[1] * channel_k_iters;
217
+
218
+ if (C_per_group <= 2) {
219
+ gemm_k_iters = (implicit_K + bk - 1) / bk;
220
+ n_channel_specialization = C_per_group;
221
+ } else if (C_per_group <= 4) {
222
+ gemm_k_iters = ((conv_params.wS[0] * conv_params.wS[1] * 4) + bk - 1) / bk;
223
+ n_channel_specialization = C_per_group;
224
+ }
225
+
226
+ bool small_filter = (!n_channel_specialization) &&
227
+ (conv_params.wS[0] <= 16 && conv_params.wS[1] <= 16);
228
+
229
+ // Fix host side helper params
230
+ int sign = (conv_params.flip ? -1 : 1);
231
+ int ijw = conv_params.in_strides[2] * conv_params.kdil[1];
232
+ int ijh = conv_params.in_strides[1] * conv_params.kdil[0];
233
+
234
+ int inp_jump_w = sign * ijw;
235
+ int inp_jump_h = sign * (ijh - (conv_params.wS[1] - 1) * ijw);
236
+ int inp_jump_c = bk - sign * (conv_params.wS[0] - 1) * ijh -
237
+ sign * (conv_params.wS[1] - 1) * ijw;
238
+
239
+ // Build implicit gemm params
240
+ ImplicitGemmConv2DParams gemm_params{
241
+ /* const int M = */ implicit_M,
242
+ /* const int N = */ implicit_N,
243
+ /* const int K = */ implicit_K,
244
+
245
+ /* const int gemm_k_iterations = */ gemm_k_iters,
246
+
247
+ /* const int inp_jump_w = */ inp_jump_w,
248
+ /* const int inp_jump_h = */ inp_jump_h,
249
+ /* const int inp_jump_c = */ inp_jump_c,
250
+
251
+ /* const int tiles_n = */ tn,
252
+ /* const int tiles_m = */ tm,
253
+ /* const int swizzle_log = */ swizzle_log};
254
+
255
+ // Determine kernel
256
+ std::string kname;
257
+ kname.reserve(64);
258
+ concatenate(
259
+ kname,
260
+ "implicit_gemm_conv_2d_",
261
+ type_to_name(out),
262
+ "_bm",
263
+ bm,
264
+ "_bn",
265
+ bn,
266
+ "_bk",
267
+ bk,
268
+ "_wm",
269
+ wm,
270
+ "_wn",
271
+ wn,
272
+ "_channel_",
273
+ n_channel_specialization ? std::to_string(n_channel_specialization) : "l",
274
+ "_filter_",
275
+ small_filter ? 's' : 'l');
276
+
277
+ // Encode and dispatch kernel
278
+ auto& compute_encoder = d.get_command_encoder(s.index);
279
+ auto kernel = get_steel_conv_kernel(
280
+ d,
281
+ kname,
282
+ out,
283
+ bm,
284
+ bn,
285
+ bk,
286
+ wm,
287
+ wn,
288
+ n_channel_specialization,
289
+ small_filter);
290
+ compute_encoder.set_compute_pipeline_state(kernel);
291
+
292
+ // Deduce grid launch dimensions
293
+ int tile = 1 << swizzle_log;
294
+ size_t grid_dim_y = (tm + tile - 1) / tile;
295
+ size_t grid_dim_x = tn * tile;
296
+
297
+ MTL::Size group_dims = MTL::Size(32, wn, wm);
298
+ MTL::Size grid_dims = MTL::Size(grid_dim_x, grid_dim_y, groups);
299
+
300
+ // Encode arrays
301
+ compute_encoder.set_input_array(in, 0);
302
+ compute_encoder.set_input_array(wt, 1);
303
+ compute_encoder.set_output_array(out, 2);
304
+
305
+ // Encode params
306
+ compute_encoder.set_bytes(conv_params, 3);
307
+ compute_encoder.set_bytes(gemm_params, 4);
308
+
309
+ // Launch kernel
310
+ compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
311
+ }
312
+
313
+ void implicit_gemm_conv_2D_general_gpu(
314
+ const Stream& s,
315
+ metal::Device& d,
316
+ const array& in,
317
+ const array& wt,
318
+ array out,
319
+ const MLXConvParams<2>& conv_params) {
320
+ // Deduce implicit gemm size
321
+ int implicit_M = conv_params.N * conv_params.oS[0] * conv_params.oS[1];
322
+ int implicit_N = conv_params.O;
323
+ int implicit_K = conv_params.wS[0] * conv_params.wS[1] * conv_params.C;
324
+
325
+ // Determine block and warp tiles
326
+ int wm = 2, wn = 2;
327
+
328
+ // Make jump params
329
+ int f_wgt_jump_h =
330
+ std::lcm(conv_params.idil[0], conv_params.kdil[0]) / conv_params.kdil[0];
331
+ int f_wgt_jump_w =
332
+ std::lcm(conv_params.idil[1], conv_params.kdil[1]) / conv_params.kdil[1];
333
+
334
+ int f_out_jump_h =
335
+ std::lcm(conv_params.idil[0], conv_params.str[0]) / conv_params.str[0];
336
+ int f_out_jump_w =
337
+ std::lcm(conv_params.idil[1], conv_params.str[1]) / conv_params.str[1];
338
+
339
+ int adj_out_h = (conv_params.oS[0] + f_out_jump_h - 1) / f_out_jump_h;
340
+ int adj_out_w = (conv_params.oS[1] + f_out_jump_w - 1) / f_out_jump_w;
341
+ int adj_out_hw = adj_out_h * adj_out_w;
342
+ int adj_implicit_m = conv_params.N * adj_out_hw;
343
+
344
+ Conv2DGeneralJumpParams jump_params{
345
+ /* const int f_wgt_jump_h = */ f_wgt_jump_h,
346
+ /* const int f_wgt_jump_w = */ f_wgt_jump_w,
347
+
348
+ /* const int f_out_jump_h = */ f_out_jump_h,
349
+ /* const int f_out_jump_w = */ f_out_jump_w,
350
+
351
+ /* const int adj_out_h = */ adj_out_h,
352
+ /* const int adj_out_w = */ adj_out_w,
353
+ /* const int adj_out_hw = */ adj_out_hw,
354
+ /* const int adj_implicit_m = */ adj_implicit_m};
355
+
356
+ // Make base info
357
+ std::vector<Conv2DGeneralBaseInfo> base_h(f_out_jump_h);
358
+ std::vector<Conv2DGeneralBaseInfo> base_w(f_out_jump_w);
359
+
360
+ int jump_h = conv_params.flip ? -conv_params.kdil[0] : conv_params.kdil[0];
361
+ int jump_w = conv_params.flip ? -conv_params.kdil[1] : conv_params.kdil[1];
362
+
363
+ int init_h =
364
+ (conv_params.flip ? (conv_params.wS[0] - 1) * conv_params.kdil[0] : 0);
365
+ int init_w =
366
+ (conv_params.flip ? (conv_params.wS[1] - 1) * conv_params.kdil[1] : 0);
367
+
368
+ for (int i = 0; i < f_out_jump_h; ++i) {
369
+ int ih_loop = i * conv_params.str[0] - conv_params.pad[0] + init_h;
370
+
371
+ int wh_base = 0;
372
+ while (wh_base < conv_params.wS[0] && ih_loop % conv_params.idil[0] != 0) {
373
+ wh_base++;
374
+ ih_loop += jump_h;
375
+ }
376
+
377
+ int wh_size =
378
+ ((conv_params.wS[0] - wh_base) + f_wgt_jump_h - 1) / f_wgt_jump_h;
379
+ base_h[i] = {wh_base, wh_size};
380
+ }
381
+
382
+ for (int j = 0; j < f_out_jump_w; ++j) {
383
+ int iw_loop = j * conv_params.str[1] - conv_params.pad[1] + init_w;
384
+
385
+ int ww_base = 0;
386
+ while (ww_base < conv_params.wS[1] && iw_loop % conv_params.idil[1] != 0) {
387
+ ww_base++;
388
+ iw_loop += jump_w;
389
+ }
390
+
391
+ int ww_size =
392
+ ((conv_params.wS[1] - ww_base) + f_wgt_jump_w - 1) / f_wgt_jump_w;
393
+ base_w[j] = {ww_base, ww_size};
394
+ }
395
+
396
+ // Collect block sizes
397
+ int bm = adj_implicit_m >= 8192 && conv_params.C >= 64 ? 64 : 32;
398
+ int bn = (bm == 64 && implicit_N >= 64) ? 64 : 32;
399
+ int bk = 16;
400
+
401
+ int tn = (implicit_N + bn - 1) / bn;
402
+ int tm = (adj_implicit_m + bm - 1) / bm;
403
+ int swizzle_log = 0;
404
+
405
+ // Get channel iteration info
406
+ int channel_k_iters = ((conv_params.C + bk - 1) / bk);
407
+ int gemm_k_iters = channel_k_iters;
408
+ bool align_C = conv_params.C % bk == 0;
409
+
410
+ // Fix host side helper params
411
+ int sign = (conv_params.flip ? -1 : 1);
412
+ int ijw = conv_params.in_strides[2] * conv_params.kdil[1];
413
+ int ijh = conv_params.in_strides[1] * conv_params.kdil[0];
414
+
415
+ int inp_jump_w = sign * ijw;
416
+ int inp_jump_h = sign * (ijh - (conv_params.wS[1] - 1) * ijw);
417
+ int inp_jump_c = bk - sign * (conv_params.wS[0] - 1) * ijh -
418
+ sign * (conv_params.wS[1] - 1) * ijw;
419
+
420
+ // Build implicit gemm params
421
+ ImplicitGemmConv2DParams gemm_params{
422
+ /* const int M = */ implicit_M,
423
+ /* const int N = */ implicit_N,
424
+ /* const int K = */ implicit_K,
425
+
426
+ /* const int gemm_k_iterations = */ gemm_k_iters,
427
+
428
+ /* const int inp_jump_w = */ inp_jump_w,
429
+ /* const int inp_jump_h = */ inp_jump_h,
430
+ /* const int inp_jump_c = */ inp_jump_c,
431
+
432
+ /* const int tiles_n = */ tn,
433
+ /* const int tiles_m = */ tm,
434
+ /* const int swizzle_log = */ swizzle_log};
435
+
436
+ // Determine kernel
437
+ std::string kname;
438
+ kname.reserve(64);
439
+ concatenate(
440
+ kname,
441
+ "implicit_gemm_conv_2d_general_",
442
+ type_to_name(out),
443
+ "_bm",
444
+ bm,
445
+ "_bn",
446
+ bn,
447
+ "_bk",
448
+ bk,
449
+ "_wm",
450
+ wm,
451
+ "_wn",
452
+ wn);
453
+ std::string hash_name;
454
+ hash_name.reserve(64);
455
+ concatenate(hash_name, kname, "_alC_", align_C);
456
+ metal::MTLFCList func_consts = {
457
+ {&align_C, MTL::DataType::DataTypeBool, 200},
458
+ };
459
+
460
+ // Encode and dispatch kernel
461
+ auto& compute_encoder = d.get_command_encoder(s.index);
462
+ auto kernel = get_steel_conv_general_kernel(
463
+ d, kname, hash_name, func_consts, out, bm, bn, bk, wm, wn);
464
+ compute_encoder.set_compute_pipeline_state(kernel);
465
+
466
+ // Deduce grid launch dimensions
467
+ int tile = 1 << swizzle_log;
468
+ size_t grid_dim_y = (tm + tile - 1) / tile;
469
+ size_t grid_dim_x = tn * tile;
470
+ size_t grid_dim_z = f_out_jump_h * f_out_jump_w;
471
+
472
+ MTL::Size group_dims = MTL::Size(32, wn, wm);
473
+ MTL::Size grid_dims = MTL::Size(grid_dim_x, grid_dim_y, grid_dim_z);
474
+
475
+ // Encode arrays
476
+ compute_encoder.set_input_array(in, 0);
477
+ compute_encoder.set_input_array(wt, 1);
478
+ compute_encoder.set_output_array(out, 2);
479
+
480
+ // Encode params
481
+ compute_encoder.set_bytes(conv_params, 3);
482
+ compute_encoder.set_bytes(gemm_params, 4);
483
+ compute_encoder.set_bytes(jump_params, 5);
484
+
485
+ compute_encoder.set_vector_bytes(base_h, 6);
486
+ compute_encoder.set_vector_bytes(base_w, 7);
487
+
488
+ // Launch kernel
489
+ compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
490
+ }
491
+
492
+ void winograd_conv_2D_gpu(
493
+ const Stream& s,
494
+ metal::Device& d,
495
+ const array& in,
496
+ const array& wt,
497
+ array out,
498
+ const MLXConvParams<2>& conv_params,
499
+ std::vector<array>& copies_w) {
500
+ Shape padded_shape = {
501
+ conv_params.N,
502
+ conv_params.iS[0] + 2 * conv_params.pad[0],
503
+ conv_params.iS[1] + 2 * conv_params.pad[1],
504
+ conv_params.C};
505
+
506
+ padded_shape[1] = 6 * ((padded_shape[1] - 2 + 5) / 6) + 2;
507
+ padded_shape[2] = 6 * ((padded_shape[2] - 2 + 5) / 6) + 2;
508
+
509
+ array in_padded(std::move(padded_shape), in.dtype(), nullptr, {});
510
+
511
+ // Fill with zeros
512
+ array zero_arr = array(0, in.dtype());
513
+ fill_gpu(zero_arr, in_padded, s);
514
+ copies_w.push_back(zero_arr);
515
+
516
+ // Pick input slice from padded
517
+ size_t data_offset = conv_params.pad[0] * in_padded.strides()[1] +
518
+ conv_params.pad[1] * in_padded.strides()[2];
519
+ array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {});
520
+ in_padded_slice.copy_shared_buffer(
521
+ in_padded,
522
+ in_padded.strides(),
523
+ in_padded.flags(),
524
+ in_padded_slice.size(),
525
+ data_offset);
526
+
527
+ // Copy input values into the slice
528
+ copy_gpu_inplace(in, in_padded_slice, CopyType::GeneralGeneral, s);
529
+
530
+ copies_w.push_back(in_padded_slice);
531
+ copies_w.push_back(in_padded);
532
+
533
+ MLXConvParams<2> conv_params_updated{
534
+ /* const int N = */ static_cast<int>(in_padded.shape(0)),
535
+ /* const int C = */ static_cast<int>(in_padded.shape(3)),
536
+ /* const int O = */ static_cast<int>(wt.shape(0)),
537
+ /* const int iS[NDIM] = */
538
+ {static_cast<int>(in_padded.shape(1)),
539
+ static_cast<int>(in_padded.shape(2))},
540
+ /* const int wS[NDIM] = */
541
+ {static_cast<int>(wt.shape(1)), static_cast<int>(wt.shape(2))},
542
+ /* const int oS[NDIM] = */
543
+ {static_cast<int>(out.shape(1)), static_cast<int>(out.shape(2))},
544
+ /* const int str[NDIM] = */ {1, 1},
545
+ /* const int pad[NDIM] = */ {0, 0},
546
+ /* const int kdil[NDIM] = */ {1, 1},
547
+ /* const int idil[NDIM] = */ {1, 1},
548
+ /* const size_t in_strides[NDIM + 2] = */
549
+ {in_padded.strides()[0],
550
+ in_padded.strides()[1],
551
+ in_padded.strides()[2],
552
+ in_padded.strides()[3]},
553
+ /* const size_t wt_strides[NDIM + 2] = */
554
+ {wt.strides()[0], wt.strides()[1], wt.strides()[2], wt.strides()[3]},
555
+ /* const size_t out_strides[NDIM + 2] = */
556
+ {out.strides()[0], out.strides()[1], out.strides()[2], out.strides()[3]},
557
+ /* const int groups = */ 1,
558
+ /* const bool flip = */ false,
559
+ };
560
+
561
+ int O_c = conv_params.O;
562
+ int C_c = conv_params.C;
563
+
564
+ int N_tiles_n = conv_params.N;
565
+ int N_tiles_h = (conv_params.oS[0] + 5) / 6;
566
+ int N_tiles_w = (conv_params.oS[1] + 5) / 6;
567
+ int N_tiles = N_tiles_n * N_tiles_h * N_tiles_w;
568
+
569
+ // Do filter transform
570
+ Shape filt_wg_shape = {8 * 8, conv_params.C, conv_params.O};
571
+ array filt_wg(std::move(filt_wg_shape), wt.dtype(), nullptr, {});
572
+ filt_wg.set_data(allocator::malloc(filt_wg.nbytes()));
573
+ copies_w.push_back(filt_wg);
574
+ {
575
+ int bc = 32;
576
+ int bo = 4;
577
+ std::string kname;
578
+ kname.reserve(32);
579
+ concatenate(
580
+ kname,
581
+ "winograd_conv_2d_weight_transform_",
582
+ type_to_name(out),
583
+ "_bc",
584
+ bc);
585
+ auto& compute_encoder = d.get_command_encoder(s.index);
586
+ auto kernel = d.get_kernel(kname);
587
+ compute_encoder.set_compute_pipeline_state(kernel);
588
+
589
+ compute_encoder.set_input_array(wt, 0);
590
+ compute_encoder.set_output_array(filt_wg, 1);
591
+
592
+ compute_encoder.set_bytes(C_c, 2);
593
+ compute_encoder.set_bytes(O_c, 3);
594
+
595
+ MTL::Size group_dims = MTL::Size(32, bo, 1);
596
+ MTL::Size grid_dims = MTL::Size(O_c / bo, 1, 1);
597
+
598
+ compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
599
+ }
600
+
601
+ // Do input transform
602
+ Shape inp_wg_shape = {8 * 8, N_tiles, conv_params.C};
603
+ array inp_wg(std::move(inp_wg_shape), in.dtype(), nullptr, {});
604
+ inp_wg.set_data(allocator::malloc(inp_wg.nbytes()));
605
+ copies_w.push_back(inp_wg);
606
+ {
607
+ int bc = 32;
608
+ int wm = 2;
609
+ int wn = 2;
610
+ std::string kname;
611
+ kname.reserve(32);
612
+ concatenate(
613
+ kname,
614
+ "winograd_conv_2d_input_transform_",
615
+ type_to_name(out),
616
+ "_bc",
617
+ bc);
618
+ auto& compute_encoder = d.get_command_encoder(s.index);
619
+ auto kernel = d.get_kernel(kname);
620
+ compute_encoder.set_compute_pipeline_state(kernel);
621
+
622
+ compute_encoder.set_input_array(in_padded, 0);
623
+ compute_encoder.set_output_array(inp_wg, 1);
624
+
625
+ compute_encoder.set_bytes(conv_params_updated, 2);
626
+
627
+ MTL::Size group_dims = MTL::Size(32, wn, wm);
628
+ MTL::Size grid_dims = MTL::Size(N_tiles_w, N_tiles_h, N_tiles_n);
629
+
630
+ compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
631
+ }
632
+
+   // Do batched gemm
+   Shape out_wg_shape = {8 * 8, N_tiles, conv_params.O};
+   array out_wg(std::move(out_wg_shape), in.dtype(), nullptr, {});
+   out_wg.set_data(allocator::malloc(out_wg.nbytes()));
+   copies_w.push_back(out_wg);
+   {
+     std::vector<array> empty_copies;
+     steel_matmul(
+         s,
+         d,
+         /*a = */ inp_wg,
+         /*b = */ filt_wg,
+         /*c = */ out_wg,
+         /*M = */ N_tiles,
+         /*N = */ conv_params.O,
+         /*K = */ conv_params.C,
+         /*batch_size_out = */ 8 * 8,
+         /*a_cols = */ conv_params.C,
+         /*b_cols = */ conv_params.O,
+         /*a_transposed = */ false,
+         /*b_transposed = */ false,
+         /*copies = */ empty_copies);
+   }
+
+   // Do output transform
+   {
+     int bc = 32;
+     int wm = 2;
+     int wn = 2;
+     std::string kname;
+     kname.reserve(32);
+     concatenate(
+         kname,
+         "winograd_conv_2d_output_transform_",
+         type_to_name(out),
+         "_bo",
+         bc);
+     auto& compute_encoder = d.get_command_encoder(s.index);
+     auto kernel = d.get_kernel(kname);
+     compute_encoder.set_compute_pipeline_state(kernel);
+
+     compute_encoder.set_input_array(out_wg, 0);
+     compute_encoder.set_output_array(out, 1);
+
+     compute_encoder.set_bytes(conv_params_updated, 2);
+
+     MTL::Size group_dims = MTL::Size(32, wn, wm);
+     MTL::Size grid_dims = MTL::Size(N_tiles_w, N_tiles_h, N_tiles_n);
+
+     compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
+   }
+ }
+
+ void depthwise_conv_2D_gpu(
+     const Stream& s,
+     metal::Device& d,
+     const array& in,
+     const array& wt,
+     array out,
+     const MLXConvParams<2>& conv_params) {
+   std::string base_name;
+   base_name.reserve(32);
+   concatenate(base_name, "depthwise_conv_2d_", type_to_name(out));
+
+   const int N = conv_params.N;
+   const int ker_h = conv_params.wS[0];
+   const int ker_w = conv_params.wS[1];
+   const int str_h = conv_params.str[0];
+   const int str_w = conv_params.str[1];
+   const int tc = 8;
+   const int tw = 8;
+   const int th = 4;
+   const bool do_flip = conv_params.flip;
+
+   metal::MTLFCList func_consts = {
+       {&ker_h, MTL::DataType::DataTypeInt, 00},
+       {&ker_w, MTL::DataType::DataTypeInt, 01},
+       {&str_h, MTL::DataType::DataTypeInt, 10},
+       {&str_w, MTL::DataType::DataTypeInt, 11},
+       {&th, MTL::DataType::DataTypeInt, 100},
+       {&tw, MTL::DataType::DataTypeInt, 101},
+       {&do_flip, MTL::DataType::DataTypeBool, 200},
+   };
+
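+   // Editorial note: kernel geometry (filter size, stride, tile shape, flip)
+   // is baked into the pipeline through Metal function constants; hash_name
+   // below keys the cache of these specialized pipelines.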
+   // clang-format off
+   std::string hash_name;
+   hash_name.reserve(64);
+   concatenate(
+       hash_name,
+       base_name,
+       "_ker_h_", ker_h,
+       "_ker_w_", ker_w,
+       "_str_h_", str_h,
+       "_str_w_", str_w,
+       "_tgp_h_", th,
+       "_tgp_w_", tw,
+       "_do_flip_", do_flip ? 't' : 'n'); // clang-format on
+
+   auto& compute_encoder = d.get_command_encoder(s.index);
+   auto kernel = d.get_kernel(base_name, hash_name, func_consts);
+   compute_encoder.set_compute_pipeline_state(kernel);
+
+   compute_encoder.set_input_array(in, 0);
+   compute_encoder.set_input_array(wt, 1);
+   compute_encoder.set_output_array(out, 2);
+
+   compute_encoder.set_bytes(conv_params, 3);
+
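+   // Editorial note: each threadgroup covers a tc x tw x th block of
+   // (channels, output width, output height); the batch dimension is folded
+   // into the z axis of the grid.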
+   MTL::Size group_dims = MTL::Size(tc, tw, th);
+   MTL::Size grid_dims = MTL::Size(
+       conv_params.C / tc, conv_params.oS[1] / tw, (conv_params.oS[0] / th) * N);
+
+   compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
+ }
+
+ void dispatch_conv_2D_gpu(
+     const Stream& s,
+     metal::Device& d,
+     const array& in,
+     const array& wt,
+     array out,
+     const MLXConvParams<2>& conv_params,
+     std::vector<array>& copies) {
+   bool is_stride_one = conv_params.str[0] == 1 && conv_params.str[1] == 1;
+   bool is_kdil_one = conv_params.kdil[0] == 1 && conv_params.kdil[1] == 1;
+   bool is_idil_one = conv_params.idil[0] == 1 && conv_params.idil[1] == 1;
+
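+   // Editorial note on routing below: grouped convs go to the depthwise,
+   // grouped implicit-gemm, or grouped explicit-gemm paths; otherwise 3x3
+   // stride-1 convs with channel counts divisible by 32 use Winograd, channel
+   // counts that fit the implicit-gemm tiles use implicit gemm, and everything
+   // else falls back to explicit gemm.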
+   if (is_idil_one && conv_params.groups > 1) {
+     const int C_per_group = conv_params.C / conv_params.groups;
+     const int O_per_group = conv_params.O / conv_params.groups;
+
+     if (C_per_group == 1 && O_per_group == 1 && is_kdil_one &&
+         conv_params.wS[0] <= 7 && conv_params.wS[1] <= 7 &&
+         conv_params.str[0] <= 2 && conv_params.str[1] <= 2 &&
+         conv_params.oS[0] % 8 == 0 && conv_params.oS[1] % 8 == 0 &&
+         conv_params.wt_strides[1] == conv_params.wS[1] &&
+         conv_params.C % 16 == 0 && conv_params.C == conv_params.O) {
+       return depthwise_conv_2D_gpu(s, d, in, wt, out, conv_params);
+     }
+
+     if ((C_per_group <= 4 || C_per_group % 16 == 0) &&
+         (O_per_group <= 16 || O_per_group % 16 == 0)) {
+       return implicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params);
+     } else {
+       return explicit_gemm_conv_group_ND_gpu(s, d, in, wt, out, conv_params);
+     }
+   }
+
+   // Direct to winograd conv
+   bool inp_large =
+       (conv_params.N * conv_params.iS[0] * conv_params.iS[1]) >= 4096;
+   bool channels_large = (conv_params.C + conv_params.O) >= 256;
+   bool out_large =
+       (conv_params.N * conv_params.oS[0] * conv_params.oS[1]) >= 256;
+   if (!conv_params.flip && is_stride_one && is_kdil_one && is_idil_one &&
+       conv_params.wS[0] == 3 && conv_params.wS[1] == 3 &&
+       conv_params.C % 32 == 0 && conv_params.O % 32 == 0 && inp_large &&
+       channels_large) {
+     return winograd_conv_2D_gpu(s, d, in, wt, out, conv_params, copies);
+   }
+
+   // Direct to implicit gemm conv
+   if (is_idil_one && (conv_params.C <= 4 || conv_params.C % 16 == 0) &&
+       (conv_params.O <= 16 || conv_params.O % 16 == 0)) {
+     return implicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params);
+   }
+
+   else if ((conv_params.C % 16 == 0 && conv_params.O % 16 == 0) || out_large) {
+     return implicit_gemm_conv_2D_general_gpu(s, d, in, wt, out, conv_params);
+   }
+
+   // Direct to explicit gemm conv
+   else {
+     return explicit_gemm_conv_ND_gpu(s, d, in, wt, out, conv_params);
+   }
+ }
+
+ void depthwise_conv_1D_gpu(
+     const Stream& s,
+     metal::Device& d,
+     const array& in,
+     array wt,
+     array out) {
+   bool large = in.size() > INT32_MAX || in.data_size() > INT32_MAX;
+   std::string base_name;
+   base_name.reserve(32);
+   concatenate(
+       base_name,
+       "depthwise_conv_1d_",
+       large ? "_large" : "",
+       type_to_name(out));
+
+   if (!wt.flags().row_contiguous) {
+     wt = contiguous_copy_gpu(wt, s);
+     d.add_temporary(wt, s.index);
+   }
+   auto& compute_encoder = d.get_command_encoder(s.index);
+   auto kernel = d.get_kernel(base_name);
+   compute_encoder.set_compute_pipeline_state(kernel);
+
+   auto B = in.shape(0);
+   auto Tout = out.shape(1);
+   auto D = in.shape(2);
+   auto K = wt.shape(1);
+
+   compute_encoder.set_input_array(in, 0);
+   compute_encoder.set_input_array(wt, 1);
+   compute_encoder.set_output_array(out, 2);
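+   // Editorial note: inputs too large to address with 32-bit offsets select
+   // the "_large" kernel variant above and receive 64-bit strides; otherwise
+   // strides are passed as 32-bit ints.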
+   if (large) {
+     int64_t strides[3] = {in.strides(0), in.strides(1), in.strides(2)};
+     compute_encoder.set_bytes(strides, 3, 3);
+   } else {
+     int strides[3] = {
+         static_cast<int>(in.strides(0)),
+         static_cast<int>(in.strides(1)),
+         static_cast<int>(in.strides(2))};
+     compute_encoder.set_bytes(strides, 3, 3);
+   }
+
+   compute_encoder.set_bytes(K, 4);
+   auto group_dims = get_block_dims(D, Tout, B);
+   MTL::Size grid_dims = MTL::Size(D, Tout, B);
+
+   compute_encoder.dispatch_threads(grid_dims, group_dims);
+ }
+
+ void conv_1D_gpu(
+     const Stream& s,
+     metal::Device& d,
+     const array& in,
+     const array& wt,
+     array out,
+     const std::vector<int>& padding,
+     const std::vector<int>& wt_strides,
+     const std::vector<int>& wt_dilation,
+     const std::vector<int>& in_dilation,
+     int groups,
+     bool flip,
+     std::vector<array>& copies) {
+   bool is_idil_one = in_dilation[0] == 1;
+   int C = in.shape(2);
+   int O = wt.shape(0);
+   // Fast path for fully separable 1D convolution
+   if (is_idil_one && (groups == C) && groups == O && wt_strides[0] == 1 &&
+       wt_dilation[0] == 1 && padding[0] == 0 && !flip) {
+     depthwise_conv_1D_gpu(s, d, in, wt, out);
+     return;
+   }
+
+   const int C_per_group = C / groups;
+   const int O_per_group = O / groups;
+
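+   // Editorial note: when the channel counts fit the implicit-gemm tiles, the
+   // 1D conv is promoted to a 2D conv with a trailing spatial dimension of 1
+   // so it can reuse the 2D dispatch path below.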
+   // Direct to implicit gemm conv
+   if (is_idil_one && (C_per_group <= 4 || C_per_group % 16 == 0) &&
+       (O_per_group <= 16 || O_per_group % 16 == 0)) {
+     MLXConvParams<2> conv_params{
+         /* const int N = */ static_cast<int>(in.shape(0)),
+         /* const int C = */ C,
+         /* const int O = */ O,
+         /* const int iS[NDIM] = */ {static_cast<int>(in.shape(1)), 1},
+         /* const int wS[NDIM] = */ {static_cast<int>(wt.shape(1)), 1},
+         /* const int oS[NDIM] = */ {static_cast<int>(out.shape(1)), 1},
+         /* const int str[NDIM] = */ {wt_strides[0], 1},
+         /* const int pad[NDIM] = */ {padding[0], 0},
+         /* const int kdil[NDIM] = */ {wt_dilation[0], 1},
+         /* const int idil[NDIM] = */ {in_dilation[0], 1},
+         /* const size_t in_strides[NDIM + 2] = */
+         {in.strides()[0], in.strides()[1], 0, in.strides()[2]},
+         /* const size_t wt_strides[NDIM + 2] = */
+         {wt.strides()[0], wt.strides()[1], 0, wt.strides()[2]},
+         /* const size_t out_strides[NDIM + 2] = */
+         {out.strides()[0], out.strides()[1], 0, out.strides()[2]},
+         /* const int groups = */ groups,
+         /* const bool flip = */ flip};
+
+     dispatch_conv_2D_gpu(s, d, in, wt, out, conv_params, copies);
+     return;
+   }
+
+   // Make conv params
+   MLXConvParams<1> conv_params{
+       /* const int N = */ static_cast<int>(in.shape(0)),
+       /* const int C = */ static_cast<int>(in.shape(2)),
+       /* const int O = */ static_cast<int>(wt.shape(0)),
+       /* const int iS[NDIM] = */ {static_cast<int>(in.shape(1))},
+       /* const int wS[NDIM] = */ {static_cast<int>(wt.shape(1))},
+       /* const int oS[NDIM] = */ {static_cast<int>(out.shape(1))},
+       /* const int str[NDIM] = */ {wt_strides[0]},
+       /* const int pad[NDIM] = */ {padding[0]},
+       /* const int kdil[NDIM] = */ {wt_dilation[0]},
+       /* const int idil[NDIM] = */ {in_dilation[0]},
+       /* const size_t in_strides[NDIM + 2] = */
+       {in.strides()[0], in.strides()[1], in.strides()[2]},
+       /* const size_t wt_strides[NDIM + 2] = */
+       {wt.strides()[0], wt.strides()[1], wt.strides()[2]},
+       /* const size_t out_strides[NDIM + 2] = */
+       {out.strides()[0], out.strides()[1], out.strides()[2]},
+       /* const int groups = */ groups,
+       /* const bool flip = */ flip};
+
+   // Direct to explicit gemm conv
+   if (groups > 1) {
+     return explicit_gemm_conv_group_ND_gpu(s, d, in, wt, out, conv_params);
+   } else {
+     return explicit_gemm_conv_ND_gpu(s, d, in, wt, out, conv_params);
+   }
+ }
+
+ void conv_2D_gpu(
+     const Stream& s,
+     metal::Device& d,
+     const array& in,
+     const array& wt,
+     array out,
+     const std::vector<int>& padding,
+     const std::vector<int>& wt_strides,
+     const std::vector<int>& wt_dilation,
+     const std::vector<int>& in_dilation,
+     const int groups,
+     bool flip,
+     std::vector<array>& copies) {
+   // Make conv params
+   MLXConvParams<2> conv_params{
+       /* const int N = */ static_cast<int>(in.shape(0)),
+       /* const int C = */ static_cast<int>(in.shape(3)),
+       /* const int O = */ static_cast<int>(wt.shape(0)),
+       /* const int iS[NDIM] = */
+       {static_cast<int>(in.shape(1)), static_cast<int>(in.shape(2))},
+       /* const int wS[NDIM] = */
+       {static_cast<int>(wt.shape(1)), static_cast<int>(wt.shape(2))},
+       /* const int oS[NDIM] = */
+       {static_cast<int>(out.shape(1)), static_cast<int>(out.shape(2))},
+       /* const int str[NDIM] = */ {wt_strides[0], wt_strides[1]},
+       /* const int pad[NDIM] = */ {padding[0], padding[1]},
+       /* const int kdil[NDIM] = */ {wt_dilation[0], wt_dilation[1]},
+       /* const int idil[NDIM] = */ {in_dilation[0], in_dilation[1]},
+       /* const size_t in_strides[NDIM + 2] = */
+       {in.strides(0), in.strides(1), in.strides(2), in.strides(3)},
+       /* const size_t wt_strides[NDIM + 2] = */
+       {wt.strides(0), wt.strides(1), wt.strides(2), wt.strides(3)},
+       /* const size_t out_strides[NDIM + 2] = */
+       {out.strides(0), out.strides(1), out.strides(2), out.strides(3)},
+       /* const int groups = */ groups,
+       /* const bool flip = */ flip,
+   };
+   dispatch_conv_2D_gpu(s, d, in, wt, out, conv_params, copies);
+ }
+
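+ // Editorial note: 3D convolutions are lowered directly to the explicit-gemm
+ // path; no Winograd or implicit-gemm specialization exists for them in this
+ // file.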
+ void conv_3D_gpu(
+     const Stream& s,
+     metal::Device& d,
+     const array& in,
+     const array& wt,
+     array out,
+     const std::vector<int>& padding,
+     const std::vector<int>& wt_strides,
+     const std::vector<int>& wt_dilation,
+     const std::vector<int>& in_dilation,
+     bool flip,
+     std::vector<array>& copies) {
+   // Make conv params
+   MLXConvParams<3> conv_params{
+       /* const int N = */ static_cast<int>(in.shape(0)),
+       /* const int C = */ static_cast<int>(in.shape(4)),
+       /* const int O = */ static_cast<int>(wt.shape(0)),
+       /* const int iS[NDIM] = */
+       {static_cast<int>(in.shape(1)),
+        static_cast<int>(in.shape(2)),
+        static_cast<int>(in.shape(3))},
+       /* const int wS[NDIM] = */
+       {static_cast<int>(wt.shape(1)),
+        static_cast<int>(wt.shape(2)),
+        static_cast<int>(wt.shape(3))},
+       /* const int oS[NDIM] = */
+       {static_cast<int>(out.shape(1)),
+        static_cast<int>(out.shape(2)),
+        static_cast<int>(out.shape(3))},
+       /* const int str[NDIM] = */ {wt_strides[0], wt_strides[1], wt_strides[2]},
+       /* const int pad[NDIM] = */ {padding[0], padding[1], padding[2]},
+       /* const int kdil[NDIM] = */
+       {wt_dilation[0], wt_dilation[1], wt_dilation[2]},
+       /* const int idil[NDIM] = */
+       {in_dilation[0], in_dilation[1], in_dilation[2]},
+       /* const size_t in_strides[NDIM + 2] = */
+       {in.strides()[0],
+        in.strides()[1],
+        in.strides()[2],
+        in.strides()[3],
+        in.strides()[4]},
+       /* const size_t wt_strides[NDIM + 2] = */
+       {wt.strides()[0],
+        wt.strides()[1],
+        wt.strides()[2],
+        wt.strides()[3],
+        wt.strides()[4]},
+       /* const size_t out_strides[NDIM + 2] = */
+       {out.strides()[0],
+        out.strides()[1],
+        out.strides()[2],
+        out.strides()[3],
+        out.strides()[4]},
+       /* const int groups = */ 1,
+       /* const bool flip = */ flip,
+   };
+   return explicit_gemm_conv_ND_gpu(s, d, in, wt, out, conv_params);
+ }
+
+ } // namespace
+
+ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
+   out.set_data(allocator::malloc(out.nbytes()));
+   auto& s = stream();
+   auto& d = metal::device(s.device);
+
+   // Ensure contiguity
+   std::vector<array> copies;
+   auto in = inputs[0];
+   auto wt = inputs[1];
+   if (!in.flags().row_contiguous) {
+     in = contiguous_copy_gpu(in, s);
+     copies.push_back(in);
+   }
+   if (!wt.flags().row_contiguous) {
+     wt = contiguous_copy_gpu(wt, s);
+     copies.push_back(wt);
+   }
+
+   // 3D conv
+   if (out.ndim() == 5) {
+     conv_3D_gpu(
+         s,
+         d,
+         in,
+         wt,
+         out,
+         padding_lo_,
+         kernel_strides_,
+         kernel_dilation_,
+         input_dilation_,
+         flip_,
+         copies);
+   }
+   // 2D conv
+   else if (out.ndim() == 4) {
+     conv_2D_gpu(
+         s,
+         d,
+         in,
+         wt,
+         out,
+         padding_lo_,
+         kernel_strides_,
+         kernel_dilation_,
+         input_dilation_,
+         groups_,
+         flip_,
+         copies);
+   }
+   // 1D conv
+   else if (out.ndim() == 3) {
+     conv_1D_gpu(
+         s,
+         d,
+         in,
+         wt,
+         out,
+         padding_lo_,
+         kernel_strides_,
+         kernel_dilation_,
+         input_dilation_,
+         groups_,
+         flip_,
+         copies);
+   }
+   // Throw error
+   else {
+     throw std::invalid_argument(
+         "[Convolution::eval_gpu] Only supports 1D, 2D or 3D convolutions.");
+   }
+
+   // Record copies
+   d.add_temporaries(std::move(copies), s.index);
+ }
+
+ } // namespace mlx::core