mlx 0.30.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (599) hide show
  1. checksums.yaml +7 -0
  2. data/ext/mlx/extconf.rb +94 -0
  3. data/ext/mlx/native.cpp +8027 -0
  4. data/lib/mlx/core.rb +1678 -0
  5. data/lib/mlx/distributed_utils/common.rb +116 -0
  6. data/lib/mlx/distributed_utils/config.rb +600 -0
  7. data/lib/mlx/distributed_utils/launch.rb +490 -0
  8. data/lib/mlx/extension.rb +24 -0
  9. data/lib/mlx/nn/base.rb +388 -0
  10. data/lib/mlx/nn/init.rb +140 -0
  11. data/lib/mlx/nn/layers/activations.rb +336 -0
  12. data/lib/mlx/nn/layers/base.rb +6 -0
  13. data/lib/mlx/nn/layers/containers.rb +20 -0
  14. data/lib/mlx/nn/layers/convolution.rb +120 -0
  15. data/lib/mlx/nn/layers/convolution_transpose.rb +114 -0
  16. data/lib/mlx/nn/layers/distributed.rb +309 -0
  17. data/lib/mlx/nn/layers/dropout.rb +75 -0
  18. data/lib/mlx/nn/layers/embedding.rb +28 -0
  19. data/lib/mlx/nn/layers/linear.rb +79 -0
  20. data/lib/mlx/nn/layers/normalization.rb +216 -0
  21. data/lib/mlx/nn/layers/pooling.rb +167 -0
  22. data/lib/mlx/nn/layers/positional_encoding.rb +126 -0
  23. data/lib/mlx/nn/layers/quantized.rb +215 -0
  24. data/lib/mlx/nn/layers/recurrent.rb +135 -0
  25. data/lib/mlx/nn/layers/transformer.rb +330 -0
  26. data/lib/mlx/nn/layers/upsample.rb +97 -0
  27. data/lib/mlx/nn/layers.rb +18 -0
  28. data/lib/mlx/nn/losses.rb +251 -0
  29. data/lib/mlx/nn/utils.rb +167 -0
  30. data/lib/mlx/nn.rb +12 -0
  31. data/lib/mlx/optimizers/optimizers.rb +808 -0
  32. data/lib/mlx/optimizers/schedulers.rb +62 -0
  33. data/lib/mlx/optimizers.rb +9 -0
  34. data/lib/mlx/utils.rb +171 -0
  35. data/lib/mlx/version.rb +5 -0
  36. data/lib/mlx.rb +64 -0
  37. data/mlx/CMakeLists.txt +449 -0
  38. data/mlx/cmake/FindCUDNN.cmake +177 -0
  39. data/mlx/cmake/FindNCCL.cmake +54 -0
  40. data/mlx/cmake/Findnvpl.cmake +3 -0
  41. data/mlx/cmake/extension.cmake +50 -0
  42. data/mlx/mlx/3rdparty/.clang-format +2 -0
  43. data/mlx/mlx/3rdparty/pocketfft.h +3581 -0
  44. data/mlx/mlx/CMakeLists.txt +107 -0
  45. data/mlx/mlx/allocator.h +75 -0
  46. data/mlx/mlx/api.h +29 -0
  47. data/mlx/mlx/array.cpp +354 -0
  48. data/mlx/mlx/array.h +647 -0
  49. data/mlx/mlx/backend/common/CMakeLists.txt +9 -0
  50. data/mlx/mlx/backend/common/binary.h +97 -0
  51. data/mlx/mlx/backend/common/broadcasting.cpp +24 -0
  52. data/mlx/mlx/backend/common/broadcasting.h +11 -0
  53. data/mlx/mlx/backend/common/buffer_cache.h +158 -0
  54. data/mlx/mlx/backend/common/common.cpp +305 -0
  55. data/mlx/mlx/backend/common/compiled.cpp +243 -0
  56. data/mlx/mlx/backend/common/compiled.h +77 -0
  57. data/mlx/mlx/backend/common/copy.h +50 -0
  58. data/mlx/mlx/backend/common/hadamard.h +109 -0
  59. data/mlx/mlx/backend/common/load.cpp +57 -0
  60. data/mlx/mlx/backend/common/matmul.h +67 -0
  61. data/mlx/mlx/backend/common/reduce.cpp +154 -0
  62. data/mlx/mlx/backend/common/reduce.h +59 -0
  63. data/mlx/mlx/backend/common/slicing.cpp +71 -0
  64. data/mlx/mlx/backend/common/slicing.h +20 -0
  65. data/mlx/mlx/backend/common/ternary.h +85 -0
  66. data/mlx/mlx/backend/common/unary.h +29 -0
  67. data/mlx/mlx/backend/common/utils.cpp +231 -0
  68. data/mlx/mlx/backend/common/utils.h +205 -0
  69. data/mlx/mlx/backend/cpu/CMakeLists.txt +88 -0
  70. data/mlx/mlx/backend/cpu/arange.h +28 -0
  71. data/mlx/mlx/backend/cpu/arg_reduce.cpp +124 -0
  72. data/mlx/mlx/backend/cpu/binary.cpp +269 -0
  73. data/mlx/mlx/backend/cpu/binary.h +517 -0
  74. data/mlx/mlx/backend/cpu/binary_ops.h +98 -0
  75. data/mlx/mlx/backend/cpu/binary_two.h +166 -0
  76. data/mlx/mlx/backend/cpu/cholesky.cpp +85 -0
  77. data/mlx/mlx/backend/cpu/compiled.cpp +357 -0
  78. data/mlx/mlx/backend/cpu/compiled_preamble.h +12 -0
  79. data/mlx/mlx/backend/cpu/conv.cpp +1351 -0
  80. data/mlx/mlx/backend/cpu/copy.cpp +386 -0
  81. data/mlx/mlx/backend/cpu/copy.h +36 -0
  82. data/mlx/mlx/backend/cpu/device_info.cpp +113 -0
  83. data/mlx/mlx/backend/cpu/device_info.h +28 -0
  84. data/mlx/mlx/backend/cpu/distributed.cpp +103 -0
  85. data/mlx/mlx/backend/cpu/eig.cpp +281 -0
  86. data/mlx/mlx/backend/cpu/eigh.cpp +241 -0
  87. data/mlx/mlx/backend/cpu/encoder.cpp +16 -0
  88. data/mlx/mlx/backend/cpu/encoder.h +67 -0
  89. data/mlx/mlx/backend/cpu/eval.cpp +40 -0
  90. data/mlx/mlx/backend/cpu/eval.h +12 -0
  91. data/mlx/mlx/backend/cpu/fft.cpp +120 -0
  92. data/mlx/mlx/backend/cpu/gemm.h +26 -0
  93. data/mlx/mlx/backend/cpu/gemms/bnns.cpp +214 -0
  94. data/mlx/mlx/backend/cpu/gemms/cblas.cpp +134 -0
  95. data/mlx/mlx/backend/cpu/gemms/simd_bf16.cpp +45 -0
  96. data/mlx/mlx/backend/cpu/gemms/simd_fp16.cpp +45 -0
  97. data/mlx/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
  98. data/mlx/mlx/backend/cpu/hadamard.cpp +121 -0
  99. data/mlx/mlx/backend/cpu/indexing.cpp +854 -0
  100. data/mlx/mlx/backend/cpu/inverse.cpp +160 -0
  101. data/mlx/mlx/backend/cpu/jit_compiler.cpp +166 -0
  102. data/mlx/mlx/backend/cpu/jit_compiler.h +20 -0
  103. data/mlx/mlx/backend/cpu/lapack.h +80 -0
  104. data/mlx/mlx/backend/cpu/logsumexp.cpp +139 -0
  105. data/mlx/mlx/backend/cpu/luf.cpp +120 -0
  106. data/mlx/mlx/backend/cpu/make_compiled_preamble.ps1 +38 -0
  107. data/mlx/mlx/backend/cpu/make_compiled_preamble.sh +41 -0
  108. data/mlx/mlx/backend/cpu/masked_mm.cpp +608 -0
  109. data/mlx/mlx/backend/cpu/matmul.cpp +166 -0
  110. data/mlx/mlx/backend/cpu/primitives.cpp +478 -0
  111. data/mlx/mlx/backend/cpu/qrf.cpp +147 -0
  112. data/mlx/mlx/backend/cpu/quantized.cpp +1370 -0
  113. data/mlx/mlx/backend/cpu/reduce.cpp +587 -0
  114. data/mlx/mlx/backend/cpu/scan.cpp +338 -0
  115. data/mlx/mlx/backend/cpu/select.cpp +95 -0
  116. data/mlx/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
  117. data/mlx/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
  118. data/mlx/mlx/backend/cpu/simd/base_simd.h +319 -0
  119. data/mlx/mlx/backend/cpu/simd/math.h +193 -0
  120. data/mlx/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
  121. data/mlx/mlx/backend/cpu/simd/simd.h +4 -0
  122. data/mlx/mlx/backend/cpu/simd/type.h +11 -0
  123. data/mlx/mlx/backend/cpu/slicing.h +21 -0
  124. data/mlx/mlx/backend/cpu/softmax.cpp +170 -0
  125. data/mlx/mlx/backend/cpu/sort.cpp +481 -0
  126. data/mlx/mlx/backend/cpu/svd.cpp +289 -0
  127. data/mlx/mlx/backend/cpu/ternary.h +154 -0
  128. data/mlx/mlx/backend/cpu/threefry.cpp +31 -0
  129. data/mlx/mlx/backend/cpu/threefry.h +21 -0
  130. data/mlx/mlx/backend/cpu/unary.cpp +238 -0
  131. data/mlx/mlx/backend/cpu/unary.h +281 -0
  132. data/mlx/mlx/backend/cpu/unary_ops.h +175 -0
  133. data/mlx/mlx/backend/cuda/CMakeLists.txt +265 -0
  134. data/mlx/mlx/backend/cuda/allocator.cpp +451 -0
  135. data/mlx/mlx/backend/cuda/allocator.h +94 -0
  136. data/mlx/mlx/backend/cuda/arange.cu +68 -0
  137. data/mlx/mlx/backend/cuda/arg_reduce.cu +189 -0
  138. data/mlx/mlx/backend/cuda/bin2h.cmake +150 -0
  139. data/mlx/mlx/backend/cuda/binary/CMakeLists.txt +21 -0
  140. data/mlx/mlx/backend/cuda/binary/add.cu +7 -0
  141. data/mlx/mlx/backend/cuda/binary/arctan2.cu +7 -0
  142. data/mlx/mlx/backend/cuda/binary/binary.cuh +383 -0
  143. data/mlx/mlx/backend/cuda/binary/bitwise_binary.cu +27 -0
  144. data/mlx/mlx/backend/cuda/binary/divide.cu +7 -0
  145. data/mlx/mlx/backend/cuda/binary/equal.cu +15 -0
  146. data/mlx/mlx/backend/cuda/binary/greater.cu +7 -0
  147. data/mlx/mlx/backend/cuda/binary/greater_equal.cu +7 -0
  148. data/mlx/mlx/backend/cuda/binary/less.cu +7 -0
  149. data/mlx/mlx/backend/cuda/binary/less_equal.cu +7 -0
  150. data/mlx/mlx/backend/cuda/binary/log_add_exp.cu +7 -0
  151. data/mlx/mlx/backend/cuda/binary/logical_and.cu +7 -0
  152. data/mlx/mlx/backend/cuda/binary/logical_or.cu +7 -0
  153. data/mlx/mlx/backend/cuda/binary/maximum.cu +7 -0
  154. data/mlx/mlx/backend/cuda/binary/minimum.cu +7 -0
  155. data/mlx/mlx/backend/cuda/binary/multiply.cu +7 -0
  156. data/mlx/mlx/backend/cuda/binary/not_equal.cu +7 -0
  157. data/mlx/mlx/backend/cuda/binary/power.cu +7 -0
  158. data/mlx/mlx/backend/cuda/binary/remainder.cu +7 -0
  159. data/mlx/mlx/backend/cuda/binary/subtract.cu +7 -0
  160. data/mlx/mlx/backend/cuda/binary_two.cu +412 -0
  161. data/mlx/mlx/backend/cuda/compiled.cpp +357 -0
  162. data/mlx/mlx/backend/cuda/conv/conv.h +126 -0
  163. data/mlx/mlx/backend/cuda/conv/gemm_conv.cu +217 -0
  164. data/mlx/mlx/backend/cuda/conv/gemm_grouped_conv.cu +231 -0
  165. data/mlx/mlx/backend/cuda/conv.cpp +403 -0
  166. data/mlx/mlx/backend/cuda/copy/copy.cuh +55 -0
  167. data/mlx/mlx/backend/cuda/copy/copy_contiguous.cu +88 -0
  168. data/mlx/mlx/backend/cuda/copy/copy_general.cu +171 -0
  169. data/mlx/mlx/backend/cuda/copy/copy_general_dynamic.cu +118 -0
  170. data/mlx/mlx/backend/cuda/copy/copy_general_input.cu +229 -0
  171. data/mlx/mlx/backend/cuda/copy.cu +132 -0
  172. data/mlx/mlx/backend/cuda/cublas_utils.cpp +222 -0
  173. data/mlx/mlx/backend/cuda/cublas_utils.h +95 -0
  174. data/mlx/mlx/backend/cuda/cuda.h +21 -0
  175. data/mlx/mlx/backend/cuda/cuda_utils.h +90 -0
  176. data/mlx/mlx/backend/cuda/cudnn_utils.cpp +133 -0
  177. data/mlx/mlx/backend/cuda/cudnn_utils.h +187 -0
  178. data/mlx/mlx/backend/cuda/custom_kernel.cpp +379 -0
  179. data/mlx/mlx/backend/cuda/cutlass_utils.cuh +46 -0
  180. data/mlx/mlx/backend/cuda/delayload.cpp +80 -0
  181. data/mlx/mlx/backend/cuda/device/atomic_ops.cuh +63 -0
  182. data/mlx/mlx/backend/cuda/device/binary_ops.cuh +300 -0
  183. data/mlx/mlx/backend/cuda/device/cast_op.cuh +118 -0
  184. data/mlx/mlx/backend/cuda/device/complex.cuh +60 -0
  185. data/mlx/mlx/backend/cuda/device/config.h +12 -0
  186. data/mlx/mlx/backend/cuda/device/fp16_math.cuh +96 -0
  187. data/mlx/mlx/backend/cuda/device/gather.cuh +53 -0
  188. data/mlx/mlx/backend/cuda/device/gather_axis.cuh +65 -0
  189. data/mlx/mlx/backend/cuda/device/indexing.cuh +30 -0
  190. data/mlx/mlx/backend/cuda/device/scatter.cuh +68 -0
  191. data/mlx/mlx/backend/cuda/device/scatter_axis.cuh +67 -0
  192. data/mlx/mlx/backend/cuda/device/scatter_ops.cuh +44 -0
  193. data/mlx/mlx/backend/cuda/device/ternary_ops.cuh +13 -0
  194. data/mlx/mlx/backend/cuda/device/unary_ops.cuh +350 -0
  195. data/mlx/mlx/backend/cuda/device/utils.cuh +464 -0
  196. data/mlx/mlx/backend/cuda/device.cpp +522 -0
  197. data/mlx/mlx/backend/cuda/device.h +195 -0
  198. data/mlx/mlx/backend/cuda/device_info.cpp +232 -0
  199. data/mlx/mlx/backend/cuda/distributed.cu +121 -0
  200. data/mlx/mlx/backend/cuda/eval.cpp +66 -0
  201. data/mlx/mlx/backend/cuda/event.cu +415 -0
  202. data/mlx/mlx/backend/cuda/event.h +79 -0
  203. data/mlx/mlx/backend/cuda/fence.cpp +42 -0
  204. data/mlx/mlx/backend/cuda/gemms/cublas_gemm.cpp +233 -0
  205. data/mlx/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
  206. data/mlx/mlx/backend/cuda/gemms/cublas_gemm_batched_12_0.cpp +77 -0
  207. data/mlx/mlx/backend/cuda/gemms/cublas_gemm_batched_12_9.cu +329 -0
  208. data/mlx/mlx/backend/cuda/gemms/gemv.cu +327 -0
  209. data/mlx/mlx/backend/cuda/gemms/gemv.h +34 -0
  210. data/mlx/mlx/backend/cuda/gemms/grouped_gemm.h +25 -0
  211. data/mlx/mlx/backend/cuda/gemms/grouped_gemm_unaligned.cu +358 -0
  212. data/mlx/mlx/backend/cuda/indexing.cpp +434 -0
  213. data/mlx/mlx/backend/cuda/jit_module.cpp +443 -0
  214. data/mlx/mlx/backend/cuda/jit_module.h +120 -0
  215. data/mlx/mlx/backend/cuda/kernel_utils.cu +52 -0
  216. data/mlx/mlx/backend/cuda/kernel_utils.cuh +148 -0
  217. data/mlx/mlx/backend/cuda/layer_norm.cu +417 -0
  218. data/mlx/mlx/backend/cuda/load.cpp +60 -0
  219. data/mlx/mlx/backend/cuda/logsumexp.cu +161 -0
  220. data/mlx/mlx/backend/cuda/lru_cache.h +190 -0
  221. data/mlx/mlx/backend/cuda/matmul.cpp +373 -0
  222. data/mlx/mlx/backend/cuda/no_cuda.cpp +47 -0
  223. data/mlx/mlx/backend/cuda/primitives.cpp +46 -0
  224. data/mlx/mlx/backend/cuda/quantized/affine_quantize.cu +329 -0
  225. data/mlx/mlx/backend/cuda/quantized/convert_fp8.cu +19 -0
  226. data/mlx/mlx/backend/cuda/quantized/cublas_qqmm.cpp +206 -0
  227. data/mlx/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
  228. data/mlx/mlx/backend/cuda/quantized/cuda_fp4.h +100 -0
  229. data/mlx/mlx/backend/cuda/quantized/fp_quantize.cu +496 -0
  230. data/mlx/mlx/backend/cuda/quantized/mxfp8_quantize.cuh +32 -0
  231. data/mlx/mlx/backend/cuda/quantized/no_qqmm_impl.cpp +26 -0
  232. data/mlx/mlx/backend/cuda/quantized/nvfp4_quantize.cuh +334 -0
  233. data/mlx/mlx/backend/cuda/quantized/qmv.cu +304 -0
  234. data/mlx/mlx/backend/cuda/quantized/qmv.h +21 -0
  235. data/mlx/mlx/backend/cuda/quantized/qqmm.cpp +158 -0
  236. data/mlx/mlx/backend/cuda/quantized/qqmm_impl.cpp +50 -0
  237. data/mlx/mlx/backend/cuda/quantized/qqmm_impl.h +26 -0
  238. data/mlx/mlx/backend/cuda/quantized/qqmm_utils.cu +227 -0
  239. data/mlx/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
  240. data/mlx/mlx/backend/cuda/quantized/quantized.cpp +85 -0
  241. data/mlx/mlx/backend/cuda/quantized/quantized.h +53 -0
  242. data/mlx/mlx/backend/cuda/quantized/quantized_utils.cuh +88 -0
  243. data/mlx/mlx/backend/cuda/quantized/quantized_utils.h +50 -0
  244. data/mlx/mlx/backend/cuda/random.cu +202 -0
  245. data/mlx/mlx/backend/cuda/reduce/all_reduce.cu +159 -0
  246. data/mlx/mlx/backend/cuda/reduce/col_reduce.cu +510 -0
  247. data/mlx/mlx/backend/cuda/reduce/init_reduce.cu +50 -0
  248. data/mlx/mlx/backend/cuda/reduce/reduce.cuh +71 -0
  249. data/mlx/mlx/backend/cuda/reduce/reduce_ops.cuh +211 -0
  250. data/mlx/mlx/backend/cuda/reduce/reduce_utils.cuh +145 -0
  251. data/mlx/mlx/backend/cuda/reduce/row_reduce.cu +361 -0
  252. data/mlx/mlx/backend/cuda/reduce.cu +73 -0
  253. data/mlx/mlx/backend/cuda/rms_norm.cu +536 -0
  254. data/mlx/mlx/backend/cuda/rope.cu +429 -0
  255. data/mlx/mlx/backend/cuda/scaled_dot_product_attention.cpp +681 -0
  256. data/mlx/mlx/backend/cuda/scaled_dot_product_attention.cu +796 -0
  257. data/mlx/mlx/backend/cuda/scan.cu +468 -0
  258. data/mlx/mlx/backend/cuda/slicing.cpp +111 -0
  259. data/mlx/mlx/backend/cuda/softmax.cu +162 -0
  260. data/mlx/mlx/backend/cuda/sort.cu +1076 -0
  261. data/mlx/mlx/backend/cuda/steel/defines.cuh +9 -0
  262. data/mlx/mlx/backend/cuda/steel/gemm.cuh +101 -0
  263. data/mlx/mlx/backend/cuda/steel/mma.cuh +117 -0
  264. data/mlx/mlx/backend/cuda/steel/tiles.cuh +450 -0
  265. data/mlx/mlx/backend/cuda/steel/utils.cuh +89 -0
  266. data/mlx/mlx/backend/cuda/ternary.cu +271 -0
  267. data/mlx/mlx/backend/cuda/unary/CMakeLists.txt +34 -0
  268. data/mlx/mlx/backend/cuda/unary/abs.cu +7 -0
  269. data/mlx/mlx/backend/cuda/unary/arccos.cu +7 -0
  270. data/mlx/mlx/backend/cuda/unary/arccosh.cu +7 -0
  271. data/mlx/mlx/backend/cuda/unary/arcsin.cu +7 -0
  272. data/mlx/mlx/backend/cuda/unary/arcsinh.cu +7 -0
  273. data/mlx/mlx/backend/cuda/unary/arctan.cu +7 -0
  274. data/mlx/mlx/backend/cuda/unary/arctanh.cu +7 -0
  275. data/mlx/mlx/backend/cuda/unary/bitwise_invert.cu +7 -0
  276. data/mlx/mlx/backend/cuda/unary/ceil.cu +7 -0
  277. data/mlx/mlx/backend/cuda/unary/conjugate.cu +7 -0
  278. data/mlx/mlx/backend/cuda/unary/cos.cu +7 -0
  279. data/mlx/mlx/backend/cuda/unary/cosh.cu +7 -0
  280. data/mlx/mlx/backend/cuda/unary/erf.cu +7 -0
  281. data/mlx/mlx/backend/cuda/unary/erf_inv.cu +7 -0
  282. data/mlx/mlx/backend/cuda/unary/exp.cu +7 -0
  283. data/mlx/mlx/backend/cuda/unary/expm1.cu +7 -0
  284. data/mlx/mlx/backend/cuda/unary/floor.cu +7 -0
  285. data/mlx/mlx/backend/cuda/unary/imag.cu +7 -0
  286. data/mlx/mlx/backend/cuda/unary/log.cu +21 -0
  287. data/mlx/mlx/backend/cuda/unary/log1p.cu +7 -0
  288. data/mlx/mlx/backend/cuda/unary/logical_not.cu +7 -0
  289. data/mlx/mlx/backend/cuda/unary/negative.cu +7 -0
  290. data/mlx/mlx/backend/cuda/unary/real.cu +7 -0
  291. data/mlx/mlx/backend/cuda/unary/round.cu +18 -0
  292. data/mlx/mlx/backend/cuda/unary/sigmoid.cu +7 -0
  293. data/mlx/mlx/backend/cuda/unary/sign.cu +7 -0
  294. data/mlx/mlx/backend/cuda/unary/sin.cu +7 -0
  295. data/mlx/mlx/backend/cuda/unary/sinh.cu +7 -0
  296. data/mlx/mlx/backend/cuda/unary/sqrt.cu +15 -0
  297. data/mlx/mlx/backend/cuda/unary/square.cu +7 -0
  298. data/mlx/mlx/backend/cuda/unary/tan.cu +7 -0
  299. data/mlx/mlx/backend/cuda/unary/tanh.cu +7 -0
  300. data/mlx/mlx/backend/cuda/unary/unary.cuh +224 -0
  301. data/mlx/mlx/backend/cuda/utils.cpp +116 -0
  302. data/mlx/mlx/backend/cuda/utils.h +49 -0
  303. data/mlx/mlx/backend/cuda/vector_types.cuh +48 -0
  304. data/mlx/mlx/backend/cuda/worker.cpp +79 -0
  305. data/mlx/mlx/backend/cuda/worker.h +55 -0
  306. data/mlx/mlx/backend/gpu/CMakeLists.txt +5 -0
  307. data/mlx/mlx/backend/gpu/copy.cpp +89 -0
  308. data/mlx/mlx/backend/gpu/copy.h +57 -0
  309. data/mlx/mlx/backend/gpu/device_info.h +36 -0
  310. data/mlx/mlx/backend/gpu/eval.h +18 -0
  311. data/mlx/mlx/backend/gpu/primitives.cpp +307 -0
  312. data/mlx/mlx/backend/gpu/slicing.cpp +44 -0
  313. data/mlx/mlx/backend/gpu/slicing.h +36 -0
  314. data/mlx/mlx/backend/metal/CMakeLists.txt +144 -0
  315. data/mlx/mlx/backend/metal/allocator.cpp +279 -0
  316. data/mlx/mlx/backend/metal/allocator.h +79 -0
  317. data/mlx/mlx/backend/metal/binary.cpp +257 -0
  318. data/mlx/mlx/backend/metal/binary.h +33 -0
  319. data/mlx/mlx/backend/metal/compiled.cpp +471 -0
  320. data/mlx/mlx/backend/metal/conv.cpp +1118 -0
  321. data/mlx/mlx/backend/metal/copy.cpp +235 -0
  322. data/mlx/mlx/backend/metal/custom_kernel.cpp +430 -0
  323. data/mlx/mlx/backend/metal/device.cpp +816 -0
  324. data/mlx/mlx/backend/metal/device.h +289 -0
  325. data/mlx/mlx/backend/metal/device_info.cpp +58 -0
  326. data/mlx/mlx/backend/metal/distributed.cpp +38 -0
  327. data/mlx/mlx/backend/metal/eval.cpp +97 -0
  328. data/mlx/mlx/backend/metal/event.cpp +62 -0
  329. data/mlx/mlx/backend/metal/fence.cpp +162 -0
  330. data/mlx/mlx/backend/metal/fft.cpp +807 -0
  331. data/mlx/mlx/backend/metal/hadamard.cpp +198 -0
  332. data/mlx/mlx/backend/metal/indexing.cpp +727 -0
  333. data/mlx/mlx/backend/metal/jit/includes.h +58 -0
  334. data/mlx/mlx/backend/metal/jit/indexing.h +76 -0
  335. data/mlx/mlx/backend/metal/jit_kernels.cpp +1118 -0
  336. data/mlx/mlx/backend/metal/kernels/CMakeLists.txt +193 -0
  337. data/mlx/mlx/backend/metal/kernels/arange.h +9 -0
  338. data/mlx/mlx/backend/metal/kernels/arange.metal +20 -0
  339. data/mlx/mlx/backend/metal/kernels/arg_reduce.metal +182 -0
  340. data/mlx/mlx/backend/metal/kernels/atomic.h +345 -0
  341. data/mlx/mlx/backend/metal/kernels/bf16.h +16 -0
  342. data/mlx/mlx/backend/metal/kernels/bf16_math.h +380 -0
  343. data/mlx/mlx/backend/metal/kernels/binary.h +199 -0
  344. data/mlx/mlx/backend/metal/kernels/binary.metal +109 -0
  345. data/mlx/mlx/backend/metal/kernels/binary_ops.h +330 -0
  346. data/mlx/mlx/backend/metal/kernels/binary_two.h +244 -0
  347. data/mlx/mlx/backend/metal/kernels/binary_two.metal +54 -0
  348. data/mlx/mlx/backend/metal/kernels/cexpf.h +134 -0
  349. data/mlx/mlx/backend/metal/kernels/complex.h +173 -0
  350. data/mlx/mlx/backend/metal/kernels/conv.metal +701 -0
  351. data/mlx/mlx/backend/metal/kernels/copy.h +276 -0
  352. data/mlx/mlx/backend/metal/kernels/copy.metal +75 -0
  353. data/mlx/mlx/backend/metal/kernels/defines.h +24 -0
  354. data/mlx/mlx/backend/metal/kernels/erf.h +69 -0
  355. data/mlx/mlx/backend/metal/kernels/expm1f.h +90 -0
  356. data/mlx/mlx/backend/metal/kernels/fence.metal +52 -0
  357. data/mlx/mlx/backend/metal/kernels/fft/radix.h +328 -0
  358. data/mlx/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
  359. data/mlx/mlx/backend/metal/kernels/fft.h +486 -0
  360. data/mlx/mlx/backend/metal/kernels/fft.metal +67 -0
  361. data/mlx/mlx/backend/metal/kernels/fp4.h +48 -0
  362. data/mlx/mlx/backend/metal/kernels/fp8.h +80 -0
  363. data/mlx/mlx/backend/metal/kernels/fp_quantized.h +1850 -0
  364. data/mlx/mlx/backend/metal/kernels/fp_quantized.metal +153 -0
  365. data/mlx/mlx/backend/metal/kernels/fp_quantized_nax.h +1044 -0
  366. data/mlx/mlx/backend/metal/kernels/fp_quantized_nax.metal +79 -0
  367. data/mlx/mlx/backend/metal/kernels/gemv.metal +868 -0
  368. data/mlx/mlx/backend/metal/kernels/gemv_masked.h +827 -0
  369. data/mlx/mlx/backend/metal/kernels/gemv_masked.metal +76 -0
  370. data/mlx/mlx/backend/metal/kernels/hadamard.h +182 -0
  371. data/mlx/mlx/backend/metal/kernels/indexing/gather.h +51 -0
  372. data/mlx/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
  373. data/mlx/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
  374. data/mlx/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
  375. data/mlx/mlx/backend/metal/kernels/indexing/masked_scatter.h +41 -0
  376. data/mlx/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
  377. data/mlx/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
  378. data/mlx/mlx/backend/metal/kernels/layer_norm.metal +433 -0
  379. data/mlx/mlx/backend/metal/kernels/logging.h +26 -0
  380. data/mlx/mlx/backend/metal/kernels/logsumexp.h +140 -0
  381. data/mlx/mlx/backend/metal/kernels/logsumexp.metal +18 -0
  382. data/mlx/mlx/backend/metal/kernels/quantized.h +2508 -0
  383. data/mlx/mlx/backend/metal/kernels/quantized.metal +144 -0
  384. data/mlx/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
  385. data/mlx/mlx/backend/metal/kernels/quantized_nax.metal +106 -0
  386. data/mlx/mlx/backend/metal/kernels/quantized_utils.h +90 -0
  387. data/mlx/mlx/backend/metal/kernels/random.metal +103 -0
  388. data/mlx/mlx/backend/metal/kernels/reduce.h +5 -0
  389. data/mlx/mlx/backend/metal/kernels/reduce.metal +169 -0
  390. data/mlx/mlx/backend/metal/kernels/reduce_utils.h +6 -0
  391. data/mlx/mlx/backend/metal/kernels/reduction/ops.h +275 -0
  392. data/mlx/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
  393. data/mlx/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
  394. data/mlx/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
  395. data/mlx/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
  396. data/mlx/mlx/backend/metal/kernels/rms_norm.metal +391 -0
  397. data/mlx/mlx/backend/metal/kernels/rope.metal +229 -0
  398. data/mlx/mlx/backend/metal/kernels/scaled_dot_product_attention.metal +44 -0
  399. data/mlx/mlx/backend/metal/kernels/scan.h +514 -0
  400. data/mlx/mlx/backend/metal/kernels/scan.metal +109 -0
  401. data/mlx/mlx/backend/metal/kernels/sdpa_vector.h +394 -0
  402. data/mlx/mlx/backend/metal/kernels/softmax.h +190 -0
  403. data/mlx/mlx/backend/metal/kernels/softmax.metal +24 -0
  404. data/mlx/mlx/backend/metal/kernels/sort.h +719 -0
  405. data/mlx/mlx/backend/metal/kernels/sort.metal +80 -0
  406. data/mlx/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
  407. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +471 -0
  408. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal +27 -0
  409. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
  410. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.metal +28 -0
  411. data/mlx/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
  412. data/mlx/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
  413. data/mlx/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
  414. data/mlx/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
  415. data/mlx/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
  416. data/mlx/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
  417. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
  418. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.metal +56 -0
  419. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
  420. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.metal +47 -0
  421. data/mlx/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
  422. data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
  423. data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
  424. data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
  425. data/mlx/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
  426. data/mlx/mlx/backend/metal/kernels/steel/defines.h +7 -0
  427. data/mlx/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
  428. data/mlx/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +157 -0
  429. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
  430. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.metal +34 -0
  431. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +219 -0
  432. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.metal +30 -0
  433. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
  434. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.metal +59 -0
  435. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +143 -0
  436. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.metal +37 -0
  437. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
  438. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.metal +76 -0
  439. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
  440. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.metal +43 -0
  441. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
  442. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.metal +76 -0
  443. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.h +152 -0
  444. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.metal +30 -0
  445. data/mlx/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
  446. data/mlx/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
  447. data/mlx/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
  448. data/mlx/mlx/backend/metal/kernels/steel/gemm/params.h +65 -0
  449. data/mlx/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
  450. data/mlx/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
  451. data/mlx/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
  452. data/mlx/mlx/backend/metal/kernels/steel/utils.h +42 -0
  453. data/mlx/mlx/backend/metal/kernels/ternary.h +145 -0
  454. data/mlx/mlx/backend/metal/kernels/ternary.metal +48 -0
  455. data/mlx/mlx/backend/metal/kernels/ternary_ops.h +10 -0
  456. data/mlx/mlx/backend/metal/kernels/unary.h +63 -0
  457. data/mlx/mlx/backend/metal/kernels/unary.metal +115 -0
  458. data/mlx/mlx/backend/metal/kernels/unary_ops.h +454 -0
  459. data/mlx/mlx/backend/metal/kernels/utils.h +445 -0
  460. data/mlx/mlx/backend/metal/kernels.h +375 -0
  461. data/mlx/mlx/backend/metal/logsumexp.cpp +95 -0
  462. data/mlx/mlx/backend/metal/make_compiled_preamble.sh +120 -0
  463. data/mlx/mlx/backend/metal/matmul.cpp +2572 -0
  464. data/mlx/mlx/backend/metal/matmul.h +144 -0
  465. data/mlx/mlx/backend/metal/metal.cpp +50 -0
  466. data/mlx/mlx/backend/metal/metal.h +25 -0
  467. data/mlx/mlx/backend/metal/no_metal.cpp +42 -0
  468. data/mlx/mlx/backend/metal/nojit_kernels.cpp +414 -0
  469. data/mlx/mlx/backend/metal/normalization.cpp +433 -0
  470. data/mlx/mlx/backend/metal/primitives.cpp +242 -0
  471. data/mlx/mlx/backend/metal/quantized.cpp +1651 -0
  472. data/mlx/mlx/backend/metal/reduce.cpp +1038 -0
  473. data/mlx/mlx/backend/metal/reduce.h +41 -0
  474. data/mlx/mlx/backend/metal/resident.cpp +100 -0
  475. data/mlx/mlx/backend/metal/resident.h +32 -0
  476. data/mlx/mlx/backend/metal/rope.cpp +165 -0
  477. data/mlx/mlx/backend/metal/scaled_dot_product_attention.cpp +798 -0
  478. data/mlx/mlx/backend/metal/scan.cpp +145 -0
  479. data/mlx/mlx/backend/metal/scan.h +17 -0
  480. data/mlx/mlx/backend/metal/slicing.cpp +99 -0
  481. data/mlx/mlx/backend/metal/softmax.cpp +87 -0
  482. data/mlx/mlx/backend/metal/sort.cpp +368 -0
  483. data/mlx/mlx/backend/metal/ternary.cpp +160 -0
  484. data/mlx/mlx/backend/metal/ternary.h +21 -0
  485. data/mlx/mlx/backend/metal/unary.cpp +161 -0
  486. data/mlx/mlx/backend/metal/unary.h +21 -0
  487. data/mlx/mlx/backend/metal/utils.cpp +77 -0
  488. data/mlx/mlx/backend/metal/utils.h +99 -0
  489. data/mlx/mlx/backend/no_cpu/CMakeLists.txt +7 -0
  490. data/mlx/mlx/backend/no_cpu/compiled.cpp +24 -0
  491. data/mlx/mlx/backend/no_cpu/device_info.cpp +22 -0
  492. data/mlx/mlx/backend/no_cpu/primitives.cpp +146 -0
  493. data/mlx/mlx/backend/no_gpu/CMakeLists.txt +8 -0
  494. data/mlx/mlx/backend/no_gpu/allocator.cpp +134 -0
  495. data/mlx/mlx/backend/no_gpu/apple_memory.h +16 -0
  496. data/mlx/mlx/backend/no_gpu/device_info.cpp +22 -0
  497. data/mlx/mlx/backend/no_gpu/eval.cpp +24 -0
  498. data/mlx/mlx/backend/no_gpu/event.cpp +53 -0
  499. data/mlx/mlx/backend/no_gpu/fence.cpp +54 -0
  500. data/mlx/mlx/backend/no_gpu/linux_memory.h +22 -0
  501. data/mlx/mlx/backend/no_gpu/primitives.cpp +185 -0
  502. data/mlx/mlx/compile.cpp +1243 -0
  503. data/mlx/mlx/compile.h +45 -0
  504. data/mlx/mlx/compile_impl.h +70 -0
  505. data/mlx/mlx/device.cpp +72 -0
  506. data/mlx/mlx/device.h +56 -0
  507. data/mlx/mlx/distributed/CMakeLists.txt +14 -0
  508. data/mlx/mlx/distributed/distributed.cpp +197 -0
  509. data/mlx/mlx/distributed/distributed.h +61 -0
  510. data/mlx/mlx/distributed/distributed_impl.h +59 -0
  511. data/mlx/mlx/distributed/jaccl/CMakeLists.txt +12 -0
  512. data/mlx/mlx/distributed/jaccl/jaccl.cpp +178 -0
  513. data/mlx/mlx/distributed/jaccl/jaccl.h +12 -0
  514. data/mlx/mlx/distributed/jaccl/mesh.cpp +451 -0
  515. data/mlx/mlx/distributed/jaccl/mesh.h +122 -0
  516. data/mlx/mlx/distributed/jaccl/no_jaccl.cpp +20 -0
  517. data/mlx/mlx/distributed/jaccl/ring.cpp +692 -0
  518. data/mlx/mlx/distributed/jaccl/ring.h +178 -0
  519. data/mlx/mlx/distributed/jaccl/utils.cpp +329 -0
  520. data/mlx/mlx/distributed/jaccl/utils.h +342 -0
  521. data/mlx/mlx/distributed/mpi/CMakeLists.txt +5 -0
  522. data/mlx/mlx/distributed/mpi/mpi.cpp +501 -0
  523. data/mlx/mlx/distributed/mpi/mpi.h +12 -0
  524. data/mlx/mlx/distributed/mpi/mpi_declarations.h +28 -0
  525. data/mlx/mlx/distributed/mpi/no_mpi.cpp +20 -0
  526. data/mlx/mlx/distributed/nccl/CMakeLists.txt +26 -0
  527. data/mlx/mlx/distributed/nccl/nccl.cpp +443 -0
  528. data/mlx/mlx/distributed/nccl/nccl.h +12 -0
  529. data/mlx/mlx/distributed/nccl/nccl_stub/CMakeLists.txt +1 -0
  530. data/mlx/mlx/distributed/nccl/nccl_stub/nccl_stubs.cpp +54 -0
  531. data/mlx/mlx/distributed/nccl/no_nccl.cpp +20 -0
  532. data/mlx/mlx/distributed/ops.cpp +186 -0
  533. data/mlx/mlx/distributed/ops.h +57 -0
  534. data/mlx/mlx/distributed/primitives.cpp +95 -0
  535. data/mlx/mlx/distributed/primitives.h +156 -0
  536. data/mlx/mlx/distributed/reduction_ops.h +38 -0
  537. data/mlx/mlx/distributed/ring/CMakeLists.txt +5 -0
  538. data/mlx/mlx/distributed/ring/no_ring.cpp +20 -0
  539. data/mlx/mlx/distributed/ring/ring.cpp +870 -0
  540. data/mlx/mlx/distributed/ring/ring.h +12 -0
  541. data/mlx/mlx/distributed/utils.cpp +206 -0
  542. data/mlx/mlx/distributed/utils.h +67 -0
  543. data/mlx/mlx/dtype.cpp +197 -0
  544. data/mlx/mlx/dtype.h +116 -0
  545. data/mlx/mlx/dtype_utils.cpp +42 -0
  546. data/mlx/mlx/dtype_utils.h +119 -0
  547. data/mlx/mlx/einsum.cpp +941 -0
  548. data/mlx/mlx/einsum.h +23 -0
  549. data/mlx/mlx/event.h +58 -0
  550. data/mlx/mlx/export.cpp +1130 -0
  551. data/mlx/mlx/export.h +137 -0
  552. data/mlx/mlx/export_impl.h +99 -0
  553. data/mlx/mlx/fast.cpp +941 -0
  554. data/mlx/mlx/fast.h +103 -0
  555. data/mlx/mlx/fast_primitives.h +427 -0
  556. data/mlx/mlx/fence.h +39 -0
  557. data/mlx/mlx/fft.cpp +262 -0
  558. data/mlx/mlx/fft.h +159 -0
  559. data/mlx/mlx/graph_utils.cpp +175 -0
  560. data/mlx/mlx/graph_utils.h +67 -0
  561. data/mlx/mlx/io/CMakeLists.txt +25 -0
  562. data/mlx/mlx/io/gguf.cpp +470 -0
  563. data/mlx/mlx/io/gguf.h +20 -0
  564. data/mlx/mlx/io/gguf_quants.cpp +164 -0
  565. data/mlx/mlx/io/load.cpp +397 -0
  566. data/mlx/mlx/io/load.h +175 -0
  567. data/mlx/mlx/io/no_gguf.cpp +20 -0
  568. data/mlx/mlx/io/no_safetensors.cpp +37 -0
  569. data/mlx/mlx/io/safetensors.cpp +234 -0
  570. data/mlx/mlx/io.h +61 -0
  571. data/mlx/mlx/linalg.cpp +708 -0
  572. data/mlx/mlx/linalg.h +115 -0
  573. data/mlx/mlx/memory.h +80 -0
  574. data/mlx/mlx/mlx.h +25 -0
  575. data/mlx/mlx/ops.cpp +6094 -0
  576. data/mlx/mlx/ops.h +1610 -0
  577. data/mlx/mlx/primitives.cpp +5850 -0
  578. data/mlx/mlx/primitives.h +2525 -0
  579. data/mlx/mlx/random.cpp +492 -0
  580. data/mlx/mlx/random.h +283 -0
  581. data/mlx/mlx/scheduler.cpp +73 -0
  582. data/mlx/mlx/scheduler.h +189 -0
  583. data/mlx/mlx/small_vector.h +540 -0
  584. data/mlx/mlx/stream.h +42 -0
  585. data/mlx/mlx/threadpool.h +133 -0
  586. data/mlx/mlx/transforms.cpp +1065 -0
  587. data/mlx/mlx/transforms.h +231 -0
  588. data/mlx/mlx/transforms_impl.h +88 -0
  589. data/mlx/mlx/types/bf16.h +187 -0
  590. data/mlx/mlx/types/complex.h +113 -0
  591. data/mlx/mlx/types/fp16.h +234 -0
  592. data/mlx/mlx/types/half_types.h +58 -0
  593. data/mlx/mlx/types/limits.h +70 -0
  594. data/mlx/mlx/utils.cpp +302 -0
  595. data/mlx/mlx/utils.h +174 -0
  596. data/mlx/mlx/version.cpp +11 -0
  597. data/mlx/mlx/version.h +22 -0
  598. data/mlx/mlx.pc.in +52 -0
  599. metadata +643 -0
@@ -0,0 +1,1351 @@
1
+ // Copyright © 2023-2024 Apple Inc.
2
+
3
+ #include <cassert>
4
+ #include <numeric>
5
+
6
+ #include "mlx/backend/cpu/copy.h"
7
+ #include "mlx/backend/cpu/encoder.h"
8
+ #include "mlx/backend/cpu/lapack.h"
9
+ #include "mlx/primitives.h"
10
+ #include "mlx/utils.h"
11
+
12
+ namespace mlx::core {
13
+
14
+ namespace {
15
+
16
+ ///////////////////////////////////////////////////////////////////////////////
17
+ // Naive reference conv
18
+ ///////////////////////////////////////////////////////////////////////////////
19
+
20
// Naive reference 1D convolution over NHC-layout tensors.
// For every (batch n, output position oh, output channel o) accumulates
// in[n, ih, c] * wt[o, wh, c % C_per_group] in float (regardless of T) and
// casts back to T on store. Supports grouped convolution, weight dilation,
// input dilation (transposed-conv style), stride, asymmetric padding, and
// optional spatial kernel flipping (flip=true gives true convolution,
// flip=false gives cross-correlation). The work is captured by value and
// dispatched onto the CPU command encoder for `stream`, so it may run after
// this function returns.
template <typename T>
void slow_conv_1D(
    const array& in, // (N, H, C) input
    const array& wt, // (O, wH, C / groups) weights
    array out, // (N, oH, O) preallocated output
    const std::vector<int>& padding_lo, // leading padding, one entry per spatial dim
    const std::vector<int>& padding_hi, // trailing padding, one entry per spatial dim
    const std::vector<int>& wt_strides, // convolution strides
    const std::vector<int>& wt_dilation, // kernel (weight) dilation
    const std::vector<int>& in_dilation, // input dilation
    bool flip, // flip the kernel spatially
    Stream stream) {
  auto& encoder = cpu::get_command_encoder(stream);
  encoder.set_input_array(in);
  encoder.set_input_array(wt);
  encoder.set_output_array(out);

  // Capture raw pointers, extents and strides by value so the closure is
  // self-contained when executed by the encoder.
  encoder.dispatch([start_wt_ptr = wt.data<T>(),
                    in_ptr = in.data<T>(),
                    out_ptr = out.data<T>(),

                    N = in.shape(
                        0), // Batch size, should be the same as out.shape(0)
                    iH = 1 +
                        in_dilation[0] * (in.shape(1) - 1), // Input spatial dim
                    oH = out.shape(1), // Output spatial dim
                    wH = wt.shape(1), // Weight spatial dim
                    groups = in.shape(2) / wt.shape(2),
                    O = wt.shape(0), // Out channels
                    C_per_group = wt.shape(2),

                    in_stride_N = in.strides()[0],
                    in_stride_H = in.strides()[1],
                    in_stride_C = in.strides()[2],

                    wt_stride_O = wt.strides()[0],
                    wt_stride_H = wt.strides()[1],
                    wt_stride_C = wt.strides()[2],

                    out_stride_N = out.strides()[0],
                    out_stride_H = out.strides()[1],
                    out_stride_O = out.strides()[2],

                    flip,
                    padding_lo = padding_lo[0],
                    padding_hi = padding_hi[0],
                    wt_stride = wt_strides[0],
                    wt_dilation = wt_dilation[0],
                    in_dilation = in_dilation[0]]() mutable {
    auto O_per_group = O / groups;

    for (int n = 0; n < N; ++n) {
      for (int oh = 0; oh < oH; ++oh) {
        for (int g = 0; g < groups; ++g) {
          // Output channels of group g only see input channels of group g.
          for (int o = g * O_per_group; o < (g + 1) * O_per_group; ++o) {
            const T* filter_wt_ptr = start_wt_ptr + o * wt_stride_O;
            // Accumulate in float to limit rounding error for fp16/bf16 T.
            float r = 0.;

            for (int wh = 0; wh < wH; ++wh) {
              const T* wt_ptr = filter_wt_ptr + wh * wt_stride_H;

              int wh_flip = flip ? (wH - wh - 1) : wh;
              // Coordinate in the *virtually dilated* input (iH above is the
              // dilated extent).
              int ih = oh * wt_stride - padding_lo + wh_flip * wt_dilation;

              // Map back to a stored input row; a tap contributes only when
              // ih lands exactly on a real (non-inserted) row, i.e. rem == 0.
              auto ih_div = std::div(ih, in_dilation);

              if (ih >= 0 && ih < iH && ih_div.rem == 0) {
                for (int c = g * C_per_group; c < (g + 1) * C_per_group; ++c) {
                  r +=
                      static_cast<float>(
                          in_ptr[ih_div.quot * in_stride_H + c * in_stride_C]) *
                      static_cast<float>(
                          wt_ptr[(c % C_per_group) * wt_stride_C]);
                } // c

              } // ih check
            } // wh

            out_ptr[oh * out_stride_H + o * out_stride_O] = static_cast<T>(r);
          } // o
        } // g
      } // oh

      // Advance base pointers to the next batch element.
      in_ptr += in_stride_N;
      out_ptr += out_stride_N;
    } // n
  });
}
108
+
109
// Naive reference 2D convolution over NHWC-layout tensors.
// Same semantics as slow_conv_1D extended to two spatial dims (H, W):
// grouped convolution, weight dilation, input dilation, stride, asymmetric
// padding and optional kernel flipping, accumulating in float. For speed the
// output plane is split into border regions (which need full bounds checks)
// and an interior region (which needs none); see the oH_border_*/oW_border_*
// computation below. Work is dispatched onto the CPU encoder for `stream`.
template <typename T>
void slow_conv_2D(
    const array& in, // (N, H, W, C) input
    const array& wt, // (O, wH, wW, C / groups) weights
    array out, // (N, oH, oW, O) preallocated output
    const std::vector<int>& padding_lo,
    const std::vector<int>& padding_hi,
    const std::vector<int>& wt_strides,
    const std::vector<int>& wt_dilation,
    const std::vector<int>& in_dilation,
    bool flip,
    Stream stream) {
  auto& encoder = cpu::get_command_encoder(stream);
  encoder.set_input_array(in);
  encoder.set_input_array(wt);
  encoder.set_output_array(out);

  encoder.dispatch(
      [st_wt_ptr = wt.data<T>(),
       st_in_ptr = in.data<T>(),
       st_out_ptr = out.data<T>(),

       N = in.shape(0), // Batch size, should be the same as out.shape(0)
       iH = 1 + in_dilation[0] * (in.shape(1) - 1), // Input spatial dim
       iW = 1 + in_dilation[1] * (in.shape(2) - 1), // Input spatial dim
       C = in.shape(3), // In channels
       oH = out.shape(1), // Output spatial dim
       oW = out.shape(2), // Output spatial dim
       O = wt.shape(0), // Out channels
       wH = wt.shape(1), // Weight spatial dim
       wW = wt.shape(2), // Weight spatial dim

       groups = in.shape(3) / wt.shape(3),
       C_per_group = wt.shape(3),

       in_stride_N = in.strides()[0],
       in_stride_H = in.strides()[1],
       in_stride_W = in.strides()[2],
       in_stride_C = in.strides()[3],

       wt_stride_O = wt.strides()[0],
       wt_stride_H = wt.strides()[1],
       wt_stride_W = wt.strides()[2],
       wt_stride_C = wt.strides()[3],

       out_stride_N = out.strides()[0],
       out_stride_H = out.strides()[1],
       out_stride_W = out.strides()[2],
       out_stride_O = out.strides()[3],

       padding_lo,
       padding_hi,
       wt_strides,
       wt_dilation,
       in_dilation,
       flip]() mutable {
        bool is_idil_one = in_dilation[0] == 1 && in_dilation[1] == 1;

        const int O_per_group = O / groups;
        // Fast path for interior output points: every kernel tap is known to
        // be in bounds, so no boundary or dilation-remainder checks.
        auto pt_conv_no_checks =
            [&](const T* in_ptr, const T* wt_ptr, T* out_ptr, int oh, int ow) {
              out_ptr += oh * out_stride_H + ow * out_stride_W;
              int ih_base = oh * wt_strides[0] - padding_lo[0];
              int iw_base = ow * wt_strides[1] - padding_lo[1];

              for (int g = 0; g < groups; ++g) {
                for (int o = g * O_per_group; o < (g + 1) * O_per_group; ++o) {
                  float r = 0.;

                  for (int wh = 0; wh < wH; ++wh) {
                    for (int ww = 0; ww < wW; ++ww) {
                      int wh_flip = flip ? wH - wh - 1 : wh;
                      int ww_flip = flip ? wW - ww - 1 : ww;
                      int ih = ih_base + wh_flip * wt_dilation[0];
                      int iw = iw_base + ww_flip * wt_dilation[1];

                      const T* wt_ptr_pt =
                          wt_ptr + wh * wt_stride_H + ww * wt_stride_W;
                      const T* in_ptr_pt =
                          in_ptr + ih * in_stride_H + iw * in_stride_W;

                      for (int c = g * C_per_group; c < (g + 1) * C_per_group;
                           ++c) {
                        r += static_cast<float>(in_ptr_pt[c * in_stride_C]) *
                            static_cast<float>(
                                wt_ptr_pt[(c % C_per_group) * wt_stride_C]);
                      } // c
                    } // ww
                  } // wh

                  out_ptr[0] = static_cast<T>(r);
                  out_ptr += out_stride_O;
                  wt_ptr += wt_stride_O;
                } // o
              } // g
            };

        // Direction each kernel step moves the input coordinate (negative
        // when the kernel is flipped).
        int jump_h = flip ? -wt_dilation[0] : wt_dilation[0];
        int jump_w = flip ? -wt_dilation[1] : wt_dilation[1];

        int init_h = (flip ? (wH - 1) * wt_dilation[0] : 0);
        int init_w = (flip ? (wW - 1) * wt_dilation[1] : 0);

        // Under input dilation only every f_wgt_jump-th kernel tap can hit a
        // real input row/col, and the set of valid taps repeats every
        // f_out_jump output positions; base_h/base_w cache the first valid
        // tap per residue class.
        int f_wgt_jump_h =
            std::lcm(in_dilation[0], wt_dilation[0]) / wt_dilation[0];
        int f_wgt_jump_w =
            std::lcm(in_dilation[1], wt_dilation[1]) / wt_dilation[1];

        int f_out_jump_h =
            std::lcm(in_dilation[0], wt_strides[0]) / wt_strides[0];
        int f_out_jump_w =
            std::lcm(in_dilation[1], wt_strides[1]) / wt_strides[1];

        std::vector<int> base_h(f_out_jump_h);
        std::vector<int> base_w(f_out_jump_w);

        for (int i = 0; i < f_out_jump_h; ++i) {
          int ih_loop = i * wt_strides[0] - padding_lo[0] + init_h;

          int wh_base = 0;
          while (wh_base < wH && ih_loop % in_dilation[0] != 0) {
            wh_base++;
            ih_loop += jump_h;
          }

          base_h[i] = wh_base;
        }

        for (int j = 0; j < f_out_jump_w; ++j) {
          int iw_loop = j * wt_strides[1] - padding_lo[1] + init_w;

          int ww_base = 0;
          while (ww_base < wW && iw_loop % in_dilation[1] != 0) {
            ww_base++;
            iw_loop += jump_w;
          }

          base_w[j] = ww_base;
        }

        // Slow path for border output points: bounds-check every tap and map
        // dilated coordinates back to stored rows/cols.
        auto pt_conv_all_checks =
            [&](const T* in_ptr, const T* wt_ptr, T* out_ptr, int oh, int ow) {
              out_ptr += oh * out_stride_H + ow * out_stride_W;

              int ih_base = oh * wt_strides[0] - padding_lo[0];
              int iw_base = ow * wt_strides[1] - padding_lo[1];

              int wh_base = base_h[oh % f_out_jump_h];
              int ww_base = base_w[ow % f_out_jump_w];

              for (int g = 0; g < groups; ++g) {
                for (int o = g * O_per_group; o < (g + 1) * O_per_group; ++o) {
                  float r = 0.;

                  for (int wh = wh_base; wh < wH; wh += f_wgt_jump_h) {
                    for (int ww = ww_base; ww < wW; ww += f_wgt_jump_w) {
                      int wh_flip = flip ? wH - wh - 1 : wh;
                      int ww_flip = flip ? wW - ww - 1 : ww;
                      int ih = ih_base + wh_flip * wt_dilation[0];
                      int iw = iw_base + ww_flip * wt_dilation[1];

                      if (ih >= 0 && ih < iH && iw >= 0 && iw < iW) {
                        const T* wt_ptr_pt =
                            wt_ptr + wh * wt_stride_H + ww * wt_stride_W;

                        int ih_dil = !is_idil_one ? (ih / in_dilation[0]) : ih;
                        int iw_dil = !is_idil_one ? (iw / in_dilation[1]) : iw;

                        const T* in_ptr_pt = in_ptr + ih_dil * in_stride_H +
                            iw_dil * in_stride_W;

                        for (int c = g * C_per_group; c < (g + 1) * C_per_group;
                             ++c) {
                          r += static_cast<float>(in_ptr_pt[c * in_stride_C]) *
                              static_cast<float>(
                                  wt_ptr_pt[(c % C_per_group) * wt_stride_C]);
                        } // c

                      } // ih, iw check
                    } // ww
                  } // wh

                  out_ptr[0] = static_cast<T>(r);
                  out_ptr += out_stride_O;
                  wt_ptr += wt_stride_O;
                } // o
              } // g
            };

        // Split the output plane into [0, border_1) (checked), [border_1,
        // border_2) (interior), [border_2, border_3) (checked) per axis.
        // With input dilation the interior is empty (border_1 == extent).
        int oH_border_0 = 0;
        int oH_border_1 = is_idil_one
            ? ((padding_lo[0] + wt_strides[0] - 1) / wt_strides[0])
            : oH;
        int oH_border_2 = std::max(
            oH_border_1,
            (iH + padding_lo[0] - wH * wt_dilation[0]) / wt_strides[0]);
        int oH_border_3 = oH;

        int oW_border_0 = 0;
        int oW_border_1 = is_idil_one
            ? ((padding_lo[1] + wt_strides[1] - 1) / wt_strides[1])
            : oW;
        int oW_border_2 = std::max(
            oW_border_1,
            (iW + padding_lo[1] - wW * wt_dilation[1]) / wt_strides[1]);
        int oW_border_3 = oW;

        for (int n = 0; n < N; ++n) {
          // Case 1: oh might put us out of bounds
          for (int oh = oH_border_0; oh < oH_border_1; ++oh) {
            for (int ow = 0; ow < oW; ++ow) {
              pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow);
            } // ow
          } // oh

          // Case 2: oh in bounds
          for (int oh = oH_border_1; oh < oH_border_2; ++oh) {
            // Case a: ow might put us out of bounds
            for (int ow = oW_border_0; ow < oW_border_1; ++ow) {
              pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow);
            } // ow

            // Case b: ow in bounds
            for (int ow = oW_border_1; ow < oW_border_2; ++ow) {
              pt_conv_no_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow);
            } // ow

            // Case c: ow might put us out of bounds
            for (int ow = oW_border_2; ow < oW_border_3; ++ow) {
              pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow);
            } // ow

          } // oh

          // Case 3: oh might put us out of bounds
          for (int oh = oH_border_2; oh < oH_border_3; ++oh) {
            for (int ow = 0; ow < oW; ++ow) {
              pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow);
            } // ow
          } // oh

          // Advance base pointers to the next batch element.
          st_in_ptr += in_stride_N;
          st_out_ptr += out_stride_N;

        } // n
      });
}
356
+
357
// Naive reference 3D convolution over NDHWC-layout tensors.
// Same scheme as slow_conv_2D extended to three spatial dims (D, H, W):
// weight dilation, input dilation, stride, asymmetric padding and optional
// kernel flipping, accumulating in float. Unlike the 1D/2D kernels this one
// has no grouped-convolution path (C is taken directly from wt.shape(4)).
// The output volume is split into fully-checked border slabs and an
// unchecked interior via the o{D,H,W}_border_* bounds below. Work is
// dispatched onto the CPU encoder for `stream`.
template <typename T>
void slow_conv_3D(
    const array& in, // (N, D, H, W, C) input
    const array& wt, // (O, wD, wH, wW, C) weights
    array out, // (N, oD, oH, oW, O) preallocated output
    const std::vector<int>& padding_lo,
    const std::vector<int>& padding_hi,
    const std::vector<int>& wt_strides,
    const std::vector<int>& wt_dilation,
    const std::vector<int>& in_dilation,
    bool flip,
    Stream stream) {
  auto& encoder = cpu::get_command_encoder(stream);
  encoder.set_input_array(in);
  encoder.set_input_array(wt);
  encoder.set_output_array(out);

  encoder.dispatch([st_wt_ptr = wt.data<T>(),
                    st_in_ptr = in.data<T>(),
                    st_out_ptr = out.data<T>(),

                    N = in.shape(
                        0), // Batch size, should be the same as out.shape(0)
                    iD = 1 +
                        in_dilation[0] * (in.shape(1) - 1), // Input spatial dim
                    iH = 1 +
                        in_dilation[1] * (in.shape(2) - 1), // Input spatial dim
                    iW = 1 +
                        in_dilation[2] * (in.shape(3) - 1), // Input spatial dim
                    oD = out.shape(1), // Output spatial dim
                    oH = out.shape(2), // Output spatial dim
                    oW = out.shape(3), // Output spatial dim
                    O = wt.shape(0), // Out channels
                    C = wt.shape(4), // In channels
                    wD = wt.shape(1), // Weight spatial dim
                    wH = wt.shape(2), // Weight spatial dim
                    wW = wt.shape(3), // Weight spatial dim

                    in_stride_N = in.strides()[0],
                    in_stride_D = in.strides()[1],
                    in_stride_H = in.strides()[2],
                    in_stride_W = in.strides()[3],
                    in_stride_C = in.strides()[4],

                    wt_stride_O = wt.strides()[0],
                    wt_stride_D = wt.strides()[1],
                    wt_stride_H = wt.strides()[2],
                    wt_stride_W = wt.strides()[3],
                    wt_stride_C = wt.strides()[4],

                    out_stride_N = out.strides()[0],
                    out_stride_D = out.strides()[1],
                    out_stride_H = out.strides()[2],
                    out_stride_W = out.strides()[3],
                    out_stride_O = out.strides()[4],
                    padding_lo,
                    padding_hi,
                    wt_strides,
                    wt_dilation,
                    in_dilation,
                    flip]() mutable {
    bool is_idil_one =
        in_dilation[0] == 1 && in_dilation[1] == 1 && in_dilation[2] == 1;

    // Fast path for interior output points: no bounds checks needed.
    auto pt_conv_no_checks = [&](const T* in_ptr,
                                 const T* wt_ptr,
                                 T* out_ptr,
                                 int od,
                                 int oh,
                                 int ow) {
      out_ptr += od * out_stride_D + oh * out_stride_H + ow * out_stride_W;
      int id_base = od * wt_strides[0] - padding_lo[0];
      int ih_base = oh * wt_strides[1] - padding_lo[1];
      int iw_base = ow * wt_strides[2] - padding_lo[2];

      for (int o = 0; o < O; ++o) {
        // Accumulate in float to limit rounding error for fp16/bf16 T.
        float r = 0.;

        for (int wd = 0; wd < wD; ++wd) {
          for (int wh = 0; wh < wH; ++wh) {
            for (int ww = 0; ww < wW; ++ww) {
              int wd_flip = flip ? wD - wd - 1 : wd;
              int wh_flip = flip ? wH - wh - 1 : wh;
              int ww_flip = flip ? wW - ww - 1 : ww;
              int id = id_base + wd_flip * wt_dilation[0];
              int ih = ih_base + wh_flip * wt_dilation[1];
              int iw = iw_base + ww_flip * wt_dilation[2];

              const T* wt_ptr_pt = wt_ptr + wd * wt_stride_D +
                  wh * wt_stride_H + ww * wt_stride_W;
              const T* in_ptr_pt = in_ptr + id * in_stride_D +
                  ih * in_stride_H + iw * in_stride_W;

              for (int c = 0; c < C; ++c) {
                r += static_cast<float>(in_ptr_pt[0]) *
                    static_cast<float>(wt_ptr_pt[0]);
                in_ptr_pt += in_stride_C;
                wt_ptr_pt += wt_stride_C;
              } // c

            } // ww
          } // wh
        } // wd

        out_ptr[0] = static_cast<T>(r);
        out_ptr += out_stride_O;
        wt_ptr += wt_stride_O;
      } // o
    };

    // Direction each kernel step moves the input coordinate (negative when
    // the kernel is flipped).
    int jump_d = flip ? -wt_dilation[0] : wt_dilation[0];
    int jump_h = flip ? -wt_dilation[1] : wt_dilation[1];
    int jump_w = flip ? -wt_dilation[2] : wt_dilation[2];

    int init_d = (flip ? (wD - 1) * wt_dilation[0] : 0);
    int init_h = (flip ? (wH - 1) * wt_dilation[1] : 0);
    int init_w = (flip ? (wW - 1) * wt_dilation[2] : 0);

    // Under input dilation only every f_wgt_jump-th kernel tap can land on a
    // real input sample, and the pattern repeats every f_out_jump output
    // positions; base_d/base_h/base_w cache the first valid tap per residue.
    int f_wgt_jump_d =
        std::lcm(in_dilation[0], wt_dilation[0]) / wt_dilation[0];
    int f_wgt_jump_h =
        std::lcm(in_dilation[1], wt_dilation[1]) / wt_dilation[1];
    int f_wgt_jump_w =
        std::lcm(in_dilation[2], wt_dilation[2]) / wt_dilation[2];

    int f_out_jump_d = std::lcm(in_dilation[0], wt_strides[0]) / wt_strides[0];
    int f_out_jump_h = std::lcm(in_dilation[1], wt_strides[1]) / wt_strides[1];
    int f_out_jump_w = std::lcm(in_dilation[2], wt_strides[2]) / wt_strides[2];

    std::vector<int> base_d(f_out_jump_d);
    std::vector<int> base_h(f_out_jump_h);
    std::vector<int> base_w(f_out_jump_w);

    for (int i = 0; i < f_out_jump_d; ++i) {
      int id_loop = i * wt_strides[0] - padding_lo[0] + init_d;

      int wd_base = 0;
      while (wd_base < wD && id_loop % in_dilation[0] != 0) {
        wd_base++;
        id_loop += jump_d;
      }

      base_d[i] = wd_base;
    }

    for (int i = 0; i < f_out_jump_h; ++i) {
      int ih_loop = i * wt_strides[1] - padding_lo[1] + init_h;

      int wh_base = 0;
      while (wh_base < wH && ih_loop % in_dilation[1] != 0) {
        wh_base++;
        ih_loop += jump_h;
      }

      base_h[i] = wh_base;
    }

    for (int j = 0; j < f_out_jump_w; ++j) {
      int iw_loop = j * wt_strides[2] - padding_lo[2] + init_w;

      int ww_base = 0;
      while (ww_base < wW && iw_loop % in_dilation[2] != 0) {
        ww_base++;
        iw_loop += jump_w;
      }

      base_w[j] = ww_base;
    }

    // Slow path for border output points: bounds-check every tap and map
    // dilated coordinates back to stored samples.
    auto pt_conv_all_checks = [&](const T* in_ptr,
                                  const T* wt_ptr,
                                  T* out_ptr,
                                  int od,
                                  int oh,
                                  int ow) {
      out_ptr += od * out_stride_D + oh * out_stride_H + ow * out_stride_W;

      int id_base = od * wt_strides[0] - padding_lo[0];
      int ih_base = oh * wt_strides[1] - padding_lo[1];
      int iw_base = ow * wt_strides[2] - padding_lo[2];

      int wd_base = base_d[od % f_out_jump_d];
      int wh_base = base_h[oh % f_out_jump_h];
      int ww_base = base_w[ow % f_out_jump_w];

      for (int o = 0; o < O; ++o) {
        float r = 0.;

        for (int wd = wd_base; wd < wD; wd += f_wgt_jump_d) {
          for (int wh = wh_base; wh < wH; wh += f_wgt_jump_h) {
            for (int ww = ww_base; ww < wW; ww += f_wgt_jump_w) {
              int wd_flip = flip ? wD - wd - 1 : wd;
              int wh_flip = flip ? wH - wh - 1 : wh;
              int ww_flip = flip ? wW - ww - 1 : ww;
              int id = id_base + wd_flip * wt_dilation[0];
              int ih = ih_base + wh_flip * wt_dilation[1];
              int iw = iw_base + ww_flip * wt_dilation[2];

              if (id >= 0 && id < iD && ih >= 0 && ih < iH && iw >= 0 &&
                  iw < iW) {
                const T* wt_ptr_pt = wt_ptr + wd * wt_stride_D +
                    wh * wt_stride_H + ww * wt_stride_W;

                int id_dil = !is_idil_one ? (id / in_dilation[0]) : id;
                int ih_dil = !is_idil_one ? (ih / in_dilation[1]) : ih;
                int iw_dil = !is_idil_one ? (iw / in_dilation[2]) : iw;

                const T* in_ptr_pt = in_ptr + id_dil * in_stride_D +
                    ih_dil * in_stride_H + iw_dil * in_stride_W;

                for (int c = 0; c < C; ++c) {
                  r += static_cast<float>(in_ptr_pt[0]) *
                      static_cast<float>(wt_ptr_pt[0]);
                  in_ptr_pt += in_stride_C;
                  wt_ptr_pt += wt_stride_C;
                } // c

              } // iD, ih, iw check
            } // ww
          } // wh
        } // wd

        out_ptr[0] = static_cast<T>(r);
        out_ptr += out_stride_O;
        wt_ptr += wt_stride_O;
      } // o
    };

    // Per-axis split of the output into checked borders and an unchecked
    // interior; with input dilation the interior is empty.
    int oD_border_0 = 0;
    int oD_border_1 = is_idil_one
        ? ((padding_lo[0] + wt_strides[0] - 1) / wt_strides[0])
        : oD;
    int oD_border_2 = std::max(
        oD_border_1,
        (iD + padding_lo[0] - wD * wt_dilation[0]) / wt_strides[0]);
    int oD_border_3 = oD;

    int oH_border_0 = 0;
    int oH_border_1 = is_idil_one
        ? ((padding_lo[1] + wt_strides[1] - 1) / wt_strides[1])
        : oH;
    int oH_border_2 = std::max(
        oH_border_1,
        (iH + padding_lo[1] - wH * wt_dilation[1]) / wt_strides[1]);
    int oH_border_3 = oH;

    int oW_border_0 = 0;
    int oW_border_1 = is_idil_one
        ? ((padding_lo[2] + wt_strides[2] - 1) / wt_strides[2])
        : oW;
    int oW_border_2 = std::max(
        oW_border_1,
        (iW + padding_lo[2] - wW * wt_dilation[2]) / wt_strides[2]);
    int oW_border_3 = oW;

    for (int n = 0; n < N; ++n) {
      // Case 1: od might put us out of bounds
      for (int od = oD_border_0; od < oD_border_1; ++od) {
        for (int oh = 0; oh < oH; ++oh) {
          for (int ow = 0; ow < oW; ++ow) {
            pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, od, oh, ow);
          } // ow
        } // oh
      } // od

      // Case 2: od in bounds
      for (int od = oD_border_1; od < oD_border_2; ++od) {
        // Case 2.1: oh might put us out of bounds
        for (int oh = oH_border_0; oh < oH_border_1; ++oh) {
          for (int ow = 0; ow < oW; ++ow) {
            pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, od, oh, ow);
          } // ow
        } // oh

        // Case 2.2: oh in bounds
        for (int oh = oH_border_1; oh < oH_border_2; ++oh) {
          // Case 2.2.1: ow might put us out of bounds
          for (int ow = oW_border_0; ow < oW_border_1; ++ow) {
            pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, od, oh, ow);
          } // ow

          // Case 2.2.2: ow in bounds
          for (int ow = oW_border_1; ow < oW_border_2; ++ow) {
            pt_conv_no_checks(st_in_ptr, st_wt_ptr, st_out_ptr, od, oh, ow);
          } // ow

          // Case 2.2.3: ow might put us out of bounds
          for (int ow = oW_border_2; ow < oW_border_3; ++ow) {
            pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, od, oh, ow);
          } // ow
        } // oh

        // Case 2.3: oh might put us out of bounds
        for (int oh = oH_border_2; oh < oH_border_3; ++oh) {
          for (int ow = 0; ow < oW; ++ow) {
            pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, od, oh, ow);
          } // ow
        } // oh
      } // od

      // Case 3: od might put us out of bounds
      for (int od = oD_border_2; od < oD_border_3; ++od) {
        for (int oh = 0; oh < oH; ++oh) {
          for (int ow = 0; ow < oW; ++ow) {
            pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, od, oh, ow);
          } // ow
        } // oh
      } // od

      // Advance base pointers to the next batch element.
      st_in_ptr += in_stride_N;
      st_out_ptr += out_stride_N;

    } // n
  });
}
672
+
673
+ void dispatch_slow_conv_1D(
674
+ const array& in,
675
+ const array& wt,
676
+ array out,
677
+ const std::vector<int>& padding_lo,
678
+ const std::vector<int>& padding_hi,
679
+ const std::vector<int>& wt_strides,
680
+ const std::vector<int>& wt_dilation,
681
+ const std::vector<int>& in_dilation,
682
+ bool flip,
683
+ Stream stream) {
684
+ if (in.dtype() == float32) {
685
+ return slow_conv_1D<float>(
686
+ in,
687
+ wt,
688
+ out,
689
+ padding_lo,
690
+ padding_hi,
691
+ wt_strides,
692
+ wt_dilation,
693
+ in_dilation,
694
+ flip,
695
+ stream);
696
+ } else if (in.dtype() == float16) {
697
+ return slow_conv_1D<float16_t>(
698
+ in,
699
+ wt,
700
+ out,
701
+ padding_lo,
702
+ padding_hi,
703
+ wt_strides,
704
+ wt_dilation,
705
+ in_dilation,
706
+ flip,
707
+ stream);
708
+ } else if (in.dtype() == bfloat16) {
709
+ return slow_conv_1D<bfloat16_t>(
710
+ in,
711
+ wt,
712
+ out,
713
+ padding_lo,
714
+ padding_hi,
715
+ wt_strides,
716
+ wt_dilation,
717
+ in_dilation,
718
+ flip,
719
+ stream);
720
+ } else {
721
+ throw std::invalid_argument(
722
+ "[Convolution::eval] got unsupported data type.");
723
+ }
724
+ }
725
+
726
+ void dispatch_slow_conv_2D(
727
+ const array& in,
728
+ const array& wt,
729
+ array out,
730
+ const std::vector<int>& padding_lo,
731
+ const std::vector<int>& padding_hi,
732
+ const std::vector<int>& wt_strides,
733
+ const std::vector<int>& wt_dilation,
734
+ const std::vector<int>& in_dilation,
735
+ bool flip,
736
+ Stream stream) {
737
+ if (in.dtype() == float32) {
738
+ return slow_conv_2D<float>(
739
+ in,
740
+ wt,
741
+ out,
742
+ padding_lo,
743
+ padding_hi,
744
+ wt_strides,
745
+ wt_dilation,
746
+ in_dilation,
747
+ flip,
748
+ stream);
749
+ } else if (in.dtype() == float16) {
750
+ return slow_conv_2D<float16_t>(
751
+ in,
752
+ wt,
753
+ out,
754
+ padding_lo,
755
+ padding_hi,
756
+ wt_strides,
757
+ wt_dilation,
758
+ in_dilation,
759
+ flip,
760
+ stream);
761
+ } else if (in.dtype() == bfloat16) {
762
+ return slow_conv_2D<bfloat16_t>(
763
+ in,
764
+ wt,
765
+ out,
766
+ padding_lo,
767
+ padding_hi,
768
+ wt_strides,
769
+ wt_dilation,
770
+ in_dilation,
771
+ flip,
772
+ stream);
773
+ } else {
774
+ throw std::invalid_argument(
775
+ "[Convolution::eval] got unsupported data type.");
776
+ }
777
+ }
778
+
779
+ void dispatch_slow_conv_3D(
780
+ const array& in,
781
+ const array& wt,
782
+ array out,
783
+ const std::vector<int>& padding_lo,
784
+ const std::vector<int>& padding_hi,
785
+ const std::vector<int>& wt_strides,
786
+ const std::vector<int>& wt_dilation,
787
+ const std::vector<int>& in_dilation,
788
+ bool flip,
789
+ Stream stream) {
790
+ if (in.dtype() == float32) {
791
+ return slow_conv_3D<float>(
792
+ in,
793
+ wt,
794
+ out,
795
+ padding_lo,
796
+ padding_hi,
797
+ wt_strides,
798
+ wt_dilation,
799
+ in_dilation,
800
+ flip,
801
+ stream);
802
+ } else if (in.dtype() == float16) {
803
+ return slow_conv_3D<float16_t>(
804
+ in,
805
+ wt,
806
+ out,
807
+ padding_lo,
808
+ padding_hi,
809
+ wt_strides,
810
+ wt_dilation,
811
+ in_dilation,
812
+ flip,
813
+ stream);
814
+ } else if (in.dtype() == bfloat16) {
815
+ return slow_conv_3D<bfloat16_t>(
816
+ in,
817
+ wt,
818
+ out,
819
+ padding_lo,
820
+ padding_hi,
821
+ wt_strides,
822
+ wt_dilation,
823
+ in_dilation,
824
+ flip,
825
+ stream);
826
+ } else {
827
+ throw std::invalid_argument(
828
+ "[Convolution::eval] got unsupported data type.");
829
+ }
830
+ }
831
+
832
+ ///////////////////////////////////////////////////////////////////////////////
833
+ // Explicit gemm conv
834
+ ///////////////////////////////////////////////////////////////////////////////
835
+
836
// Reverse the spatial order of a weight tensor in place. `x` points at
// contiguous data laid out as (out_channels, spatial_size, in_channels);
// for every output channel the spatial rows (each `in_channels` elements
// long) are reversed, leaving channel order within a row untouched.
template <typename T>
void flip_spatial_dims_inplace(
    T* x,
    size_t in_channels,
    size_t out_channels,
    size_t spatial_size) {
  const size_t slab_size = spatial_size * in_channels;
  for (size_t o = 0; o < out_channels; ++o) {
    T* slab = x + o * slab_size;
    // Swap row j with its mirror row until the two indices meet.
    for (size_t j = 0; 2 * j + 1 < spatial_size; ++j) {
      T* front_row = slab + j * in_channels;
      T* back_row = slab + (spatial_size - 1 - j) * in_channels;
      for (size_t k = 0; k < in_channels; ++k) {
        std::swap(front_row[k], back_row[k]);
      }
    }
  }
}
855
+
856
// 1D convolution via explicit im2col + GEMM.
// Strategy: zero-pad the input into a float32 scratch buffer, build a
// strided "window" view of shape (N, oH, wH, C) over it, materialize that
// view contiguously, then compute each group's output with one cblas_sgemm
// call of (N*oH, wH*C_per_group) x (wH*C_per_group, O_per_group)^T.
// Non-float32 inputs/weights/outputs are up/down-converted around the GEMM.
// All copies and the GEMM are enqueued on the CPU encoder for `stream`;
// scratch arrays are handed to the encoder as temporaries so they live
// until the enqueued work completes.
// NOTE(review): this path assumes wt_dilation/in_dilation of 1 — presumably
// the caller routes dilated cases to the slow kernels; confirm at call site.
void explicit_gemm_conv_1D_cpu(
    const array& in, // (N, H, C) input
    const array& wt, // (O, wH, C / groups) weights
    array out, // (N, oH, O) preallocated output
    const std::vector<int>& padding_lo,
    const std::vector<int>& padding_hi,
    const std::vector<int>& wt_strides,
    const std::vector<int>& wt_dilation,
    Stream stream) {
  const int N = in.shape(0); // Batch size, should be the same as out.shape(0)
  const int iH = in.shape(1); // Input spatial dim
  const int C = in.shape(2); // Input channels
  const int oH = out.shape(1); // Output spatial dim
  const int O = wt.shape(0); // Out channels
  const int wH = wt.shape(1); // Weight spatial dim

  const int groups = C / wt.shape(2);
  const int C_per_group = wt.shape(2);
  const int O_per_group = O / groups;

  // GEMM runs in float32 regardless of the tensor dtypes.
  auto conv_dtype = float32;
  auto& encoder = cpu::get_command_encoder(stream);

  // Pad input
  Shape padded_shape = {N, iH + padding_lo[0] + padding_hi[0], C};
  array in_padded(padded_shape, conv_dtype, nullptr, {});

  // Fill with zeros
  std::vector<array> temps;
  temps.push_back(array(0, conv_dtype));
  copy_cpu(temps.back(), in_padded, CopyType::Scalar, stream);

  // Pick input slice from padded
  // (offset past the leading padding so the copy lands in the interior).
  size_t data_offset = padding_lo[0] * in_padded.strides()[1];
  array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {});
  in_padded_slice.copy_shared_buffer(
      in_padded,
      in_padded.strides(),
      in_padded.flags(),
      in_padded_slice.size(),
      data_offset);
  // Copy input values into the slice
  copy_cpu_inplace(in, in_padded_slice, CopyType::GeneralGeneral, stream);
  temps.push_back(in_padded_slice);

  // Make strided view: (N, oH, wH, C) where axes 1 and 2 alias overlapping
  // windows of the padded input (the im2col trick, no data yet moved).
  Shape strided_shape = {N, oH, wH, C};

  Strides strided_strides = {
      in_padded.strides()[0],
      in_padded.strides()[1] * wt_strides[0],
      in_padded.strides()[1],
      in_padded.strides()[2]};
  auto flags = in_padded.flags();
  if (groups > 1) {
    // Transpose the last two dimensions for grouped convolutions
    // so each group's channels are contiguous in the GEMM K dimension.
    std::swap(strided_shape[2], strided_shape[3]);
    std::swap(strided_strides[2], strided_strides[3]);
  }

  array in_strided_view(strided_shape, in_padded.dtype(), nullptr, {});
  in_strided_view.copy_shared_buffer(
      in_padded, strided_strides, flags, in_strided_view.size(), 0);

  // Materialize strided view into a contiguous (N*oH, wH*C) matrix.
  Shape strided_reshape = {N * oH, wH * C};
  array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {});
  copy_cpu(in_strided_view, in_strided, CopyType::General, stream);
  temps.push_back(in_strided);

  // Check wt dtype and prepare
  auto gemm_wt = wt;
  auto gemm_out = out;

  if (groups > 1) {
    // Transpose the last two dimensions for grouped convolutions
    // (mirrors the input-view transpose above), converting to float32.
    array wt_transpose(
        {wt.shape(0), wt.shape(2), wt.shape(1)}, wt.dtype(), nullptr, {});
    wt_transpose.copy_shared_buffer(
        wt,
        {wt.strides(0), wt.strides(2), wt.strides(1)},
        wt.flags(),
        wt.size(),
        0);
    gemm_wt = array(wt_transpose.shape(), float32, nullptr, {});
    copy_cpu(wt_transpose, gemm_wt, CopyType::General, stream);
    temps.push_back(gemm_wt);
  } else if (wt.dtype() != float32 || !wt.flags().row_contiguous) {
    // Ungrouped: only convert/compact when the weights are not already
    // contiguous float32.
    auto ctype =
        wt.flags().row_contiguous ? CopyType::Vector : CopyType::General;
    gemm_wt = array(wt.shape(), float32, nullptr, {});
    copy_cpu(wt, gemm_wt, ctype, stream);
    temps.push_back(gemm_wt);
  }

  if (out.dtype() != float32) {
    // GEMM writes float32 into a scratch output; converted back below.
    gemm_out = array(out.shape(), float32, nullptr, {});
    gemm_out.set_data(allocator::malloc(gemm_out.nbytes()));
    temps.push_back(gemm_out);
  }

  encoder.set_input_array(in_strided);
  encoder.set_input_array(gemm_wt);
  encoder.set_output_array(gemm_out);

  encoder.dispatch([in_strided_ptr = in_strided.data<float>(),
                    gemm_wt_ptr = gemm_wt.data<float>(),
                    gemm_out_ptr = gemm_out.data<float>(),
                    groups,
                    strided_reshape = strided_reshape[0],
                    O,
                    C,
                    wH,
                    O_per_group,
                    C_per_group]() {
    // One GEMM per group; pointer offsets select the group's slice of the
    // im2col matrix (A), the weights (B) and the output columns (C).
    for (int g = 0; g < groups; ++g) {
      // Perform gemm
      cblas_sgemm(
          CblasRowMajor,
          CblasNoTrans, // no trans A
          CblasTrans, // transB
          strided_reshape, // M
          O_per_group, // N
          C_per_group * wH, // K
          1.0f, // alpha
          in_strided_ptr + g * C_per_group * wH, // A
          wH * C, // lda
          gemm_wt_ptr + g * O_per_group * C_per_group * wH, // B
          wH * C_per_group, // ldb
          0.0f, // beta
          gemm_out_ptr + g * O_per_group, // C
          O // ldc
      );
    }
  });

  // Copy results if needed
  if (out.dtype() != float32) {
    copy_cpu_inplace(gemm_out, out, CopyType::Vector, stream);
  }
  encoder.add_temporaries(std::move(temps));
}
998
+
999
// Explicit-GEMM N-D convolution on CPU (single group).
//
// Strategy: zero-pad the input, build a strided (im2col-style) view whose
// trailing axes enumerate each weight-window element, materialize that view
// into a [N*prod(oDim), C*prod(wDim)] matrix, then compute the convolution
// as a single sgemm against the [O, C*prod(wDim)] weight matrix.
//
// Inputs are NHWC-style (batch first, channels last); `out` must already be
// allocated by the caller. Non-float32 inputs/weights/outputs are converted
// through float32 temporaries. `flip` reverses the weight's spatial axes
// in-place on a copy before the GEMM.
void explicit_gemm_conv_ND_cpu(
    const array& in,
    const array& wt,
    array out,
    const std::vector<int>& padding_lo,
    const std::vector<int>& padding_hi,
    const std::vector<int>& wt_strides,
    const std::vector<int>& wt_dilation,
    const bool flip,
    Stream stream) {
  const int N = in.shape(0); // Batch size, should be the same as out.shape(0)
  const auto iDim =
      Shape(in.shape().begin() + 1, in.shape().end() - 1); // Input spatial dim
  const auto oDim = Shape(
      out.shape().begin() + 1, out.shape().end() - 1); // Output spatial dim
  const int O = wt.shape(0); // Out channels
  const int C = wt.shape(-1); // In channels
  const auto wDim =
      Shape(wt.shape().begin() + 1, wt.shape().end() - 1); // Weight spatial dim

  // All GEMM work happens in float32 regardless of the IO dtypes.
  auto conv_dtype = float32;

  auto& encoder = cpu::get_command_encoder(stream);

  // Pad input: padded spatial extent is iDim + padding_lo + padding_hi.
  Shape padded_shape(in.shape().size());
  padded_shape.front() = N;
  for (size_t i = 0; i < iDim.size(); i++) {
    padded_shape[i + 1] = iDim[i] + padding_lo[i] + padding_hi[i];
  }
  padded_shape.back() = C;
  array in_padded(padded_shape, conv_dtype, nullptr, {});

  // Fill with zeros (scalar broadcast copy); the zero scalar is kept alive
  // as a temporary until the encoder retires it.
  std::vector<array> temps = {array(0, conv_dtype)};
  copy_cpu(temps.back(), in_padded, CopyType::Scalar, stream);

  // Pick input slice from padded: offset of the first non-padding element.
  size_t data_offset = 0;
  for (size_t i = 0; i < padding_lo.size(); i++) {
    data_offset += padding_lo[i] * in_padded.strides()[i + 1];
  }

  // A shared-buffer view into the interior of the padded array, shaped like
  // the original input.
  array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {});
  in_padded_slice.copy_shared_buffer(
      in_padded,
      in_padded.strides(),
      in_padded.flags(),
      in_padded_slice.size(),
      data_offset);

  // Copy input values into the slice (also performs the dtype conversion
  // to float32 if needed).
  copy_cpu_inplace(in, in_padded_slice, CopyType::GeneralGeneral, stream);
  temps.push_back(in_padded_slice);

  // Make strided view: shape [N, oDim..., wDim..., C]. The wDim axes reuse
  // the padded input's strides so each output position sees its full
  // receptive field without copying.
  Shape strided_shape(oDim.size() + wDim.size() + 2);
  strided_shape.front() = N;
  for (size_t i = 0; i < oDim.size(); i++) {
    strided_shape[i + 1] = oDim[i];
  }
  for (size_t i = 0; i < wDim.size(); i++) {
    strided_shape[i + 1 + oDim.size()] = wDim[i];
  }
  strided_shape.back() = C;

  // Output axes step by (input stride * conv stride); window axes step by
  // the plain input strides. Note the last window axis doubles as the
  // channel axis (stride of the channel dim).
  Strides strided_strides(in.shape().size() * 2 - 2);
  strided_strides[0] = in_padded.strides()[0];
  for (size_t i = 0; i < wt_strides.size(); i++) {
    strided_strides[i + 1] = in_padded.strides()[i + 1] * wt_strides[i];
  }
  for (size_t i = 1; i < in_padded.strides().size(); i++) {
    strided_strides[i + wt_strides.size()] = in_padded.strides()[i];
  }

  auto flags = in_padded.flags();

  array in_strided_view(strided_shape, in_padded.dtype(), nullptr, {});
  in_strided_view.copy_shared_buffer(
      in_padded, strided_strides, flags, in_strided_view.size(), 0);

  // Materialize strided view as a contiguous 2-D matrix:
  // [N * prod(oDim), C * prod(wDim)].
  Shape strided_reshape = {N, C};
  for (const auto& o : oDim) {
    strided_reshape[0] *= o;
  }
  for (const auto& w : wDim) {
    strided_reshape[1] *= w;
  }

  array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {});
  copy_cpu(in_strided_view, in_strided, CopyType::General, stream);
  temps.push_back(in_strided);

  // Check wt dtype and prepare: the GEMM needs row-contiguous float32
  // weights, so convert/compact if the input weights are neither.
  auto gemm_wt = wt;
  auto gemm_out = out;

  if (wt.dtype() != float32 || !wt.flags().row_contiguous) {
    auto ctype =
        wt.flags().row_contiguous ? CopyType::Vector : CopyType::General;
    gemm_wt = array(wt.shape(), float32, nullptr, {});
    copy_cpu(wt, gemm_wt, ctype, stream);
    temps.push_back(gemm_wt);
  }

  if (flip) {
    // Flip works on a private copy so the caller's weights are untouched.
    auto gemm_wt_ = array(gemm_wt.shape(), float32, nullptr, {});
    copy_cpu(gemm_wt, gemm_wt_, CopyType::Vector, stream);
    temps.push_back(gemm_wt_);

    // Calculate the total size of the spatial dimensions
    int spatial_size = 1;
    for (int d = 1; d < gemm_wt.ndim() - 1; ++d) {
      spatial_size *= gemm_wt.shape(d);
    }
    encoder.set_output_array(gemm_wt_);
    encoder.dispatch([gemm_wt_ptr = gemm_wt_.data<float>(),
                      out_channels = gemm_wt.shape(0),
                      in_channels = gemm_wt.shape(-1),
                      spatial_size]() {
      flip_spatial_dims_inplace<float>(
          gemm_wt_ptr, in_channels, out_channels, spatial_size);
    });
    gemm_wt = gemm_wt_;
  }

  // Non-float32 output: accumulate into a float32 scratch buffer and
  // convert at the end.
  if (out.dtype() != float32) {
    gemm_out = array(out.shape(), float32, nullptr, {});
    gemm_out.set_data(allocator::malloc(gemm_out.nbytes()));
    temps.push_back(gemm_out);
  }

  encoder.set_input_array(in_strided);
  encoder.set_input_array(gemm_wt);
  encoder.set_output_array(gemm_out);

  encoder.dispatch([in_strided_ptr = in_strided.data<float>(),
                    gemm_wt_ptr = gemm_wt.data<float>(),
                    gemm_out_ptr = gemm_out.data<float>(),
                    strided_reshape = std::move(strided_reshape),
                    O]() {
    // Perform gemm: out[M x N] = in_strided[M x K] * wt[N x K]^T
    cblas_sgemm(
        CblasRowMajor,
        CblasNoTrans, // no trans A
        CblasTrans, // transB
        strided_reshape[0], // M
        O, // N
        strided_reshape[1], // K
        1.0f, // alpha
        in_strided_ptr,
        strided_reshape[1], // lda
        gemm_wt_ptr,
        strided_reshape[1], // ldb
        0.0f, // beta
        gemm_out_ptr,
        O // ldc
    );
  });

  // Copy results if needed
  if (out.dtype() != float32) {
    copy_cpu_inplace(gemm_out, out, CopyType::Vector, stream);
  }
  encoder.add_temporaries(std::move(temps));
}
1166
///////////////////////////////////////////////////////////////////////////////
// Conv routing
///////////////////////////////////////////////////////////////////////////////
1171
+ void conv_1D_cpu(
1172
+ const array& in,
1173
+ const array& wt,
1174
+ array out,
1175
+ const std::vector<int>& padding_lo,
1176
+ const std::vector<int>& padding_hi,
1177
+ const std::vector<int>& wt_strides,
1178
+ const std::vector<int>& wt_dilation,
1179
+ const std::vector<int>& in_dilation,
1180
+ bool flip,
1181
+ Stream stream) {
1182
+ const int groups = in.shape().back() / wt.shape().back();
1183
+ if (wt_dilation[0] == 1 && in_dilation[0] == 1 && !flip) {
1184
+ return explicit_gemm_conv_1D_cpu(
1185
+ in, wt, out, padding_lo, padding_hi, wt_strides, wt_dilation, stream);
1186
+ }
1187
+ if (wt_dilation[0] == 1 && in_dilation[0] == 1 && groups == 1) {
1188
+ return explicit_gemm_conv_ND_cpu(
1189
+ in,
1190
+ wt,
1191
+ out,
1192
+ padding_lo,
1193
+ padding_hi,
1194
+ wt_strides,
1195
+ wt_dilation,
1196
+ flip,
1197
+ stream);
1198
+ }
1199
+
1200
+ return dispatch_slow_conv_1D(
1201
+ in,
1202
+ wt,
1203
+ out,
1204
+ padding_lo,
1205
+ padding_hi,
1206
+ wt_strides,
1207
+ wt_dilation,
1208
+ in_dilation,
1209
+ flip,
1210
+ stream);
1211
+ }
1212
+
1213
+ void conv_2D_cpu(
1214
+ const array& in,
1215
+ const array& wt,
1216
+ array out,
1217
+ const std::vector<int>& padding_lo,
1218
+ const std::vector<int>& padding_hi,
1219
+ const std::vector<int>& wt_strides,
1220
+ const std::vector<int>& wt_dilation,
1221
+ const std::vector<int>& in_dilation,
1222
+ bool flip,
1223
+ Stream stream) {
1224
+ const int groups = in.shape().back() / wt.shape().back();
1225
+ if (wt_dilation[0] == 1 && wt_dilation[1] == 1 && in_dilation[0] == 1 &&
1226
+ in_dilation[1] == 1 && groups == 1) {
1227
+ return explicit_gemm_conv_ND_cpu(
1228
+ in,
1229
+ wt,
1230
+ out,
1231
+ padding_lo,
1232
+ padding_hi,
1233
+ wt_strides,
1234
+ wt_dilation,
1235
+ flip,
1236
+ stream);
1237
+ }
1238
+ return dispatch_slow_conv_2D(
1239
+ in,
1240
+ wt,
1241
+ out,
1242
+ padding_lo,
1243
+ padding_hi,
1244
+ wt_strides,
1245
+ wt_dilation,
1246
+ in_dilation,
1247
+ flip,
1248
+ stream);
1249
+ }
1250
+
1251
+ void conv_3D_cpu(
1252
+ const array& in,
1253
+ const array& wt,
1254
+ array out,
1255
+ const std::vector<int>& padding_lo,
1256
+ const std::vector<int>& padding_hi,
1257
+ const std::vector<int>& wt_strides,
1258
+ const std::vector<int>& wt_dilation,
1259
+ const std::vector<int>& in_dilation,
1260
+ bool flip,
1261
+ Stream stream) {
1262
+ const int groups = in.shape().back() / wt.shape().back();
1263
+ if (wt_dilation[0] == 1 && wt_dilation[1] == 1 && wt_dilation[2] == 1 &&
1264
+ in_dilation[0] == 1 && in_dilation[1] == 1 && in_dilation[2] == 1 &&
1265
+ groups == 1) {
1266
+ return explicit_gemm_conv_ND_cpu(
1267
+ in,
1268
+ wt,
1269
+ out,
1270
+ padding_lo,
1271
+ padding_hi,
1272
+ wt_strides,
1273
+ wt_dilation,
1274
+ flip,
1275
+ stream);
1276
+ }
1277
+
1278
+ return dispatch_slow_conv_3D(
1279
+ in,
1280
+ wt,
1281
+ out,
1282
+ padding_lo,
1283
+ padding_hi,
1284
+ wt_strides,
1285
+ wt_dilation,
1286
+ in_dilation,
1287
+ flip,
1288
+ stream);
1289
+ }
1290
+
1291
+ } // namespace
1292
+
1293
+ void Convolution::eval_cpu(const std::vector<array>& inputs, array& out) {
1294
+ out.set_data(allocator::malloc(out.nbytes()));
1295
+
1296
+ auto& in = inputs[0];
1297
+ auto& wt = inputs[1];
1298
+
1299
+ // 3D convolution
1300
+ if (in.ndim() == (3 + 2)) {
1301
+ return conv_3D_cpu(
1302
+ in,
1303
+ wt,
1304
+ out,
1305
+ padding_lo_,
1306
+ padding_hi_,
1307
+ kernel_strides_,
1308
+ kernel_dilation_,
1309
+ input_dilation_,
1310
+ flip_,
1311
+ stream());
1312
+ }
1313
+ // 2D convolution
1314
+ else if (in.ndim() == (2 + 2)) {
1315
+ return conv_2D_cpu(
1316
+ in,
1317
+ wt,
1318
+ out,
1319
+ padding_lo_,
1320
+ padding_hi_,
1321
+ kernel_strides_,
1322
+ kernel_dilation_,
1323
+ input_dilation_,
1324
+ flip_,
1325
+ stream());
1326
+ }
1327
+ // 1D convolution
1328
+ else if (in.ndim() == (1 + 2)) {
1329
+ return conv_1D_cpu(
1330
+ in,
1331
+ wt,
1332
+ out,
1333
+ padding_lo_,
1334
+ padding_hi_,
1335
+ kernel_strides_,
1336
+ kernel_dilation_,
1337
+ input_dilation_,
1338
+ flip_,
1339
+ stream());
1340
+ }
1341
+ // Throw error
1342
+ else {
1343
+ std::ostringstream msg;
1344
+ msg << "[Convolution::eval] Convolution currently only supports"
1345
+ << " 1D, 2D and 3D convolutions. Got inputs with " << in.ndim() - 2
1346
+ << " spatial dimensions";
1347
+ throw std::invalid_argument(msg.str());
1348
+ }
1349
+ }
1350
+
1351
+ } // namespace mlx::core