mlx 0.30.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (599)
  1. checksums.yaml +7 -0
  2. data/ext/mlx/extconf.rb +94 -0
  3. data/ext/mlx/native.cpp +8027 -0
  4. data/lib/mlx/core.rb +1678 -0
  5. data/lib/mlx/distributed_utils/common.rb +116 -0
  6. data/lib/mlx/distributed_utils/config.rb +600 -0
  7. data/lib/mlx/distributed_utils/launch.rb +490 -0
  8. data/lib/mlx/extension.rb +24 -0
  9. data/lib/mlx/nn/base.rb +388 -0
  10. data/lib/mlx/nn/init.rb +140 -0
  11. data/lib/mlx/nn/layers/activations.rb +336 -0
  12. data/lib/mlx/nn/layers/base.rb +6 -0
  13. data/lib/mlx/nn/layers/containers.rb +20 -0
  14. data/lib/mlx/nn/layers/convolution.rb +120 -0
  15. data/lib/mlx/nn/layers/convolution_transpose.rb +114 -0
  16. data/lib/mlx/nn/layers/distributed.rb +309 -0
  17. data/lib/mlx/nn/layers/dropout.rb +75 -0
  18. data/lib/mlx/nn/layers/embedding.rb +28 -0
  19. data/lib/mlx/nn/layers/linear.rb +79 -0
  20. data/lib/mlx/nn/layers/normalization.rb +216 -0
  21. data/lib/mlx/nn/layers/pooling.rb +167 -0
  22. data/lib/mlx/nn/layers/positional_encoding.rb +126 -0
  23. data/lib/mlx/nn/layers/quantized.rb +215 -0
  24. data/lib/mlx/nn/layers/recurrent.rb +135 -0
  25. data/lib/mlx/nn/layers/transformer.rb +330 -0
  26. data/lib/mlx/nn/layers/upsample.rb +97 -0
  27. data/lib/mlx/nn/layers.rb +18 -0
  28. data/lib/mlx/nn/losses.rb +251 -0
  29. data/lib/mlx/nn/utils.rb +167 -0
  30. data/lib/mlx/nn.rb +12 -0
  31. data/lib/mlx/optimizers/optimizers.rb +808 -0
  32. data/lib/mlx/optimizers/schedulers.rb +62 -0
  33. data/lib/mlx/optimizers.rb +9 -0
  34. data/lib/mlx/utils.rb +171 -0
  35. data/lib/mlx/version.rb +5 -0
  36. data/lib/mlx.rb +64 -0
  37. data/mlx/CMakeLists.txt +449 -0
  38. data/mlx/cmake/FindCUDNN.cmake +177 -0
  39. data/mlx/cmake/FindNCCL.cmake +54 -0
  40. data/mlx/cmake/Findnvpl.cmake +3 -0
  41. data/mlx/cmake/extension.cmake +50 -0
  42. data/mlx/mlx/3rdparty/.clang-format +2 -0
  43. data/mlx/mlx/3rdparty/pocketfft.h +3581 -0
  44. data/mlx/mlx/CMakeLists.txt +107 -0
  45. data/mlx/mlx/allocator.h +75 -0
  46. data/mlx/mlx/api.h +29 -0
  47. data/mlx/mlx/array.cpp +354 -0
  48. data/mlx/mlx/array.h +647 -0
  49. data/mlx/mlx/backend/common/CMakeLists.txt +9 -0
  50. data/mlx/mlx/backend/common/binary.h +97 -0
  51. data/mlx/mlx/backend/common/broadcasting.cpp +24 -0
  52. data/mlx/mlx/backend/common/broadcasting.h +11 -0
  53. data/mlx/mlx/backend/common/buffer_cache.h +158 -0
  54. data/mlx/mlx/backend/common/common.cpp +305 -0
  55. data/mlx/mlx/backend/common/compiled.cpp +243 -0
  56. data/mlx/mlx/backend/common/compiled.h +77 -0
  57. data/mlx/mlx/backend/common/copy.h +50 -0
  58. data/mlx/mlx/backend/common/hadamard.h +109 -0
  59. data/mlx/mlx/backend/common/load.cpp +57 -0
  60. data/mlx/mlx/backend/common/matmul.h +67 -0
  61. data/mlx/mlx/backend/common/reduce.cpp +154 -0
  62. data/mlx/mlx/backend/common/reduce.h +59 -0
  63. data/mlx/mlx/backend/common/slicing.cpp +71 -0
  64. data/mlx/mlx/backend/common/slicing.h +20 -0
  65. data/mlx/mlx/backend/common/ternary.h +85 -0
  66. data/mlx/mlx/backend/common/unary.h +29 -0
  67. data/mlx/mlx/backend/common/utils.cpp +231 -0
  68. data/mlx/mlx/backend/common/utils.h +205 -0
  69. data/mlx/mlx/backend/cpu/CMakeLists.txt +88 -0
  70. data/mlx/mlx/backend/cpu/arange.h +28 -0
  71. data/mlx/mlx/backend/cpu/arg_reduce.cpp +124 -0
  72. data/mlx/mlx/backend/cpu/binary.cpp +269 -0
  73. data/mlx/mlx/backend/cpu/binary.h +517 -0
  74. data/mlx/mlx/backend/cpu/binary_ops.h +98 -0
  75. data/mlx/mlx/backend/cpu/binary_two.h +166 -0
  76. data/mlx/mlx/backend/cpu/cholesky.cpp +85 -0
  77. data/mlx/mlx/backend/cpu/compiled.cpp +357 -0
  78. data/mlx/mlx/backend/cpu/compiled_preamble.h +12 -0
  79. data/mlx/mlx/backend/cpu/conv.cpp +1351 -0
  80. data/mlx/mlx/backend/cpu/copy.cpp +386 -0
  81. data/mlx/mlx/backend/cpu/copy.h +36 -0
  82. data/mlx/mlx/backend/cpu/device_info.cpp +113 -0
  83. data/mlx/mlx/backend/cpu/device_info.h +28 -0
  84. data/mlx/mlx/backend/cpu/distributed.cpp +103 -0
  85. data/mlx/mlx/backend/cpu/eig.cpp +281 -0
  86. data/mlx/mlx/backend/cpu/eigh.cpp +241 -0
  87. data/mlx/mlx/backend/cpu/encoder.cpp +16 -0
  88. data/mlx/mlx/backend/cpu/encoder.h +67 -0
  89. data/mlx/mlx/backend/cpu/eval.cpp +40 -0
  90. data/mlx/mlx/backend/cpu/eval.h +12 -0
  91. data/mlx/mlx/backend/cpu/fft.cpp +120 -0
  92. data/mlx/mlx/backend/cpu/gemm.h +26 -0
  93. data/mlx/mlx/backend/cpu/gemms/bnns.cpp +214 -0
  94. data/mlx/mlx/backend/cpu/gemms/cblas.cpp +134 -0
  95. data/mlx/mlx/backend/cpu/gemms/simd_bf16.cpp +45 -0
  96. data/mlx/mlx/backend/cpu/gemms/simd_fp16.cpp +45 -0
  97. data/mlx/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
  98. data/mlx/mlx/backend/cpu/hadamard.cpp +121 -0
  99. data/mlx/mlx/backend/cpu/indexing.cpp +854 -0
  100. data/mlx/mlx/backend/cpu/inverse.cpp +160 -0
  101. data/mlx/mlx/backend/cpu/jit_compiler.cpp +166 -0
  102. data/mlx/mlx/backend/cpu/jit_compiler.h +20 -0
  103. data/mlx/mlx/backend/cpu/lapack.h +80 -0
  104. data/mlx/mlx/backend/cpu/logsumexp.cpp +139 -0
  105. data/mlx/mlx/backend/cpu/luf.cpp +120 -0
  106. data/mlx/mlx/backend/cpu/make_compiled_preamble.ps1 +38 -0
  107. data/mlx/mlx/backend/cpu/make_compiled_preamble.sh +41 -0
  108. data/mlx/mlx/backend/cpu/masked_mm.cpp +608 -0
  109. data/mlx/mlx/backend/cpu/matmul.cpp +166 -0
  110. data/mlx/mlx/backend/cpu/primitives.cpp +478 -0
  111. data/mlx/mlx/backend/cpu/qrf.cpp +147 -0
  112. data/mlx/mlx/backend/cpu/quantized.cpp +1370 -0
  113. data/mlx/mlx/backend/cpu/reduce.cpp +587 -0
  114. data/mlx/mlx/backend/cpu/scan.cpp +338 -0
  115. data/mlx/mlx/backend/cpu/select.cpp +95 -0
  116. data/mlx/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
  117. data/mlx/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
  118. data/mlx/mlx/backend/cpu/simd/base_simd.h +319 -0
  119. data/mlx/mlx/backend/cpu/simd/math.h +193 -0
  120. data/mlx/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
  121. data/mlx/mlx/backend/cpu/simd/simd.h +4 -0
  122. data/mlx/mlx/backend/cpu/simd/type.h +11 -0
  123. data/mlx/mlx/backend/cpu/slicing.h +21 -0
  124. data/mlx/mlx/backend/cpu/softmax.cpp +170 -0
  125. data/mlx/mlx/backend/cpu/sort.cpp +481 -0
  126. data/mlx/mlx/backend/cpu/svd.cpp +289 -0
  127. data/mlx/mlx/backend/cpu/ternary.h +154 -0
  128. data/mlx/mlx/backend/cpu/threefry.cpp +31 -0
  129. data/mlx/mlx/backend/cpu/threefry.h +21 -0
  130. data/mlx/mlx/backend/cpu/unary.cpp +238 -0
  131. data/mlx/mlx/backend/cpu/unary.h +281 -0
  132. data/mlx/mlx/backend/cpu/unary_ops.h +175 -0
  133. data/mlx/mlx/backend/cuda/CMakeLists.txt +265 -0
  134. data/mlx/mlx/backend/cuda/allocator.cpp +451 -0
  135. data/mlx/mlx/backend/cuda/allocator.h +94 -0
  136. data/mlx/mlx/backend/cuda/arange.cu +68 -0
  137. data/mlx/mlx/backend/cuda/arg_reduce.cu +189 -0
  138. data/mlx/mlx/backend/cuda/bin2h.cmake +150 -0
  139. data/mlx/mlx/backend/cuda/binary/CMakeLists.txt +21 -0
  140. data/mlx/mlx/backend/cuda/binary/add.cu +7 -0
  141. data/mlx/mlx/backend/cuda/binary/arctan2.cu +7 -0
  142. data/mlx/mlx/backend/cuda/binary/binary.cuh +383 -0
  143. data/mlx/mlx/backend/cuda/binary/bitwise_binary.cu +27 -0
  144. data/mlx/mlx/backend/cuda/binary/divide.cu +7 -0
  145. data/mlx/mlx/backend/cuda/binary/equal.cu +15 -0
  146. data/mlx/mlx/backend/cuda/binary/greater.cu +7 -0
  147. data/mlx/mlx/backend/cuda/binary/greater_equal.cu +7 -0
  148. data/mlx/mlx/backend/cuda/binary/less.cu +7 -0
  149. data/mlx/mlx/backend/cuda/binary/less_equal.cu +7 -0
  150. data/mlx/mlx/backend/cuda/binary/log_add_exp.cu +7 -0
  151. data/mlx/mlx/backend/cuda/binary/logical_and.cu +7 -0
  152. data/mlx/mlx/backend/cuda/binary/logical_or.cu +7 -0
  153. data/mlx/mlx/backend/cuda/binary/maximum.cu +7 -0
  154. data/mlx/mlx/backend/cuda/binary/minimum.cu +7 -0
  155. data/mlx/mlx/backend/cuda/binary/multiply.cu +7 -0
  156. data/mlx/mlx/backend/cuda/binary/not_equal.cu +7 -0
  157. data/mlx/mlx/backend/cuda/binary/power.cu +7 -0
  158. data/mlx/mlx/backend/cuda/binary/remainder.cu +7 -0
  159. data/mlx/mlx/backend/cuda/binary/subtract.cu +7 -0
  160. data/mlx/mlx/backend/cuda/binary_two.cu +412 -0
  161. data/mlx/mlx/backend/cuda/compiled.cpp +357 -0
  162. data/mlx/mlx/backend/cuda/conv/conv.h +126 -0
  163. data/mlx/mlx/backend/cuda/conv/gemm_conv.cu +217 -0
  164. data/mlx/mlx/backend/cuda/conv/gemm_grouped_conv.cu +231 -0
  165. data/mlx/mlx/backend/cuda/conv.cpp +403 -0
  166. data/mlx/mlx/backend/cuda/copy/copy.cuh +55 -0
  167. data/mlx/mlx/backend/cuda/copy/copy_contiguous.cu +88 -0
  168. data/mlx/mlx/backend/cuda/copy/copy_general.cu +171 -0
  169. data/mlx/mlx/backend/cuda/copy/copy_general_dynamic.cu +118 -0
  170. data/mlx/mlx/backend/cuda/copy/copy_general_input.cu +229 -0
  171. data/mlx/mlx/backend/cuda/copy.cu +132 -0
  172. data/mlx/mlx/backend/cuda/cublas_utils.cpp +222 -0
  173. data/mlx/mlx/backend/cuda/cublas_utils.h +95 -0
  174. data/mlx/mlx/backend/cuda/cuda.h +21 -0
  175. data/mlx/mlx/backend/cuda/cuda_utils.h +90 -0
  176. data/mlx/mlx/backend/cuda/cudnn_utils.cpp +133 -0
  177. data/mlx/mlx/backend/cuda/cudnn_utils.h +187 -0
  178. data/mlx/mlx/backend/cuda/custom_kernel.cpp +379 -0
  179. data/mlx/mlx/backend/cuda/cutlass_utils.cuh +46 -0
  180. data/mlx/mlx/backend/cuda/delayload.cpp +80 -0
  181. data/mlx/mlx/backend/cuda/device/atomic_ops.cuh +63 -0
  182. data/mlx/mlx/backend/cuda/device/binary_ops.cuh +300 -0
  183. data/mlx/mlx/backend/cuda/device/cast_op.cuh +118 -0
  184. data/mlx/mlx/backend/cuda/device/complex.cuh +60 -0
  185. data/mlx/mlx/backend/cuda/device/config.h +12 -0
  186. data/mlx/mlx/backend/cuda/device/fp16_math.cuh +96 -0
  187. data/mlx/mlx/backend/cuda/device/gather.cuh +53 -0
  188. data/mlx/mlx/backend/cuda/device/gather_axis.cuh +65 -0
  189. data/mlx/mlx/backend/cuda/device/indexing.cuh +30 -0
  190. data/mlx/mlx/backend/cuda/device/scatter.cuh +68 -0
  191. data/mlx/mlx/backend/cuda/device/scatter_axis.cuh +67 -0
  192. data/mlx/mlx/backend/cuda/device/scatter_ops.cuh +44 -0
  193. data/mlx/mlx/backend/cuda/device/ternary_ops.cuh +13 -0
  194. data/mlx/mlx/backend/cuda/device/unary_ops.cuh +350 -0
  195. data/mlx/mlx/backend/cuda/device/utils.cuh +464 -0
  196. data/mlx/mlx/backend/cuda/device.cpp +522 -0
  197. data/mlx/mlx/backend/cuda/device.h +195 -0
  198. data/mlx/mlx/backend/cuda/device_info.cpp +232 -0
  199. data/mlx/mlx/backend/cuda/distributed.cu +121 -0
  200. data/mlx/mlx/backend/cuda/eval.cpp +66 -0
  201. data/mlx/mlx/backend/cuda/event.cu +415 -0
  202. data/mlx/mlx/backend/cuda/event.h +79 -0
  203. data/mlx/mlx/backend/cuda/fence.cpp +42 -0
  204. data/mlx/mlx/backend/cuda/gemms/cublas_gemm.cpp +233 -0
  205. data/mlx/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
  206. data/mlx/mlx/backend/cuda/gemms/cublas_gemm_batched_12_0.cpp +77 -0
  207. data/mlx/mlx/backend/cuda/gemms/cublas_gemm_batched_12_9.cu +329 -0
  208. data/mlx/mlx/backend/cuda/gemms/gemv.cu +327 -0
  209. data/mlx/mlx/backend/cuda/gemms/gemv.h +34 -0
  210. data/mlx/mlx/backend/cuda/gemms/grouped_gemm.h +25 -0
  211. data/mlx/mlx/backend/cuda/gemms/grouped_gemm_unaligned.cu +358 -0
  212. data/mlx/mlx/backend/cuda/indexing.cpp +434 -0
  213. data/mlx/mlx/backend/cuda/jit_module.cpp +443 -0
  214. data/mlx/mlx/backend/cuda/jit_module.h +120 -0
  215. data/mlx/mlx/backend/cuda/kernel_utils.cu +52 -0
  216. data/mlx/mlx/backend/cuda/kernel_utils.cuh +148 -0
  217. data/mlx/mlx/backend/cuda/layer_norm.cu +417 -0
  218. data/mlx/mlx/backend/cuda/load.cpp +60 -0
  219. data/mlx/mlx/backend/cuda/logsumexp.cu +161 -0
  220. data/mlx/mlx/backend/cuda/lru_cache.h +190 -0
  221. data/mlx/mlx/backend/cuda/matmul.cpp +373 -0
  222. data/mlx/mlx/backend/cuda/no_cuda.cpp +47 -0
  223. data/mlx/mlx/backend/cuda/primitives.cpp +46 -0
  224. data/mlx/mlx/backend/cuda/quantized/affine_quantize.cu +329 -0
  225. data/mlx/mlx/backend/cuda/quantized/convert_fp8.cu +19 -0
  226. data/mlx/mlx/backend/cuda/quantized/cublas_qqmm.cpp +206 -0
  227. data/mlx/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
  228. data/mlx/mlx/backend/cuda/quantized/cuda_fp4.h +100 -0
  229. data/mlx/mlx/backend/cuda/quantized/fp_quantize.cu +496 -0
  230. data/mlx/mlx/backend/cuda/quantized/mxfp8_quantize.cuh +32 -0
  231. data/mlx/mlx/backend/cuda/quantized/no_qqmm_impl.cpp +26 -0
  232. data/mlx/mlx/backend/cuda/quantized/nvfp4_quantize.cuh +334 -0
  233. data/mlx/mlx/backend/cuda/quantized/qmv.cu +304 -0
  234. data/mlx/mlx/backend/cuda/quantized/qmv.h +21 -0
  235. data/mlx/mlx/backend/cuda/quantized/qqmm.cpp +158 -0
  236. data/mlx/mlx/backend/cuda/quantized/qqmm_impl.cpp +50 -0
  237. data/mlx/mlx/backend/cuda/quantized/qqmm_impl.h +26 -0
  238. data/mlx/mlx/backend/cuda/quantized/qqmm_utils.cu +227 -0
  239. data/mlx/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
  240. data/mlx/mlx/backend/cuda/quantized/quantized.cpp +85 -0
  241. data/mlx/mlx/backend/cuda/quantized/quantized.h +53 -0
  242. data/mlx/mlx/backend/cuda/quantized/quantized_utils.cuh +88 -0
  243. data/mlx/mlx/backend/cuda/quantized/quantized_utils.h +50 -0
  244. data/mlx/mlx/backend/cuda/random.cu +202 -0
  245. data/mlx/mlx/backend/cuda/reduce/all_reduce.cu +159 -0
  246. data/mlx/mlx/backend/cuda/reduce/col_reduce.cu +510 -0
  247. data/mlx/mlx/backend/cuda/reduce/init_reduce.cu +50 -0
  248. data/mlx/mlx/backend/cuda/reduce/reduce.cuh +71 -0
  249. data/mlx/mlx/backend/cuda/reduce/reduce_ops.cuh +211 -0
  250. data/mlx/mlx/backend/cuda/reduce/reduce_utils.cuh +145 -0
  251. data/mlx/mlx/backend/cuda/reduce/row_reduce.cu +361 -0
  252. data/mlx/mlx/backend/cuda/reduce.cu +73 -0
  253. data/mlx/mlx/backend/cuda/rms_norm.cu +536 -0
  254. data/mlx/mlx/backend/cuda/rope.cu +429 -0
  255. data/mlx/mlx/backend/cuda/scaled_dot_product_attention.cpp +681 -0
  256. data/mlx/mlx/backend/cuda/scaled_dot_product_attention.cu +796 -0
  257. data/mlx/mlx/backend/cuda/scan.cu +468 -0
  258. data/mlx/mlx/backend/cuda/slicing.cpp +111 -0
  259. data/mlx/mlx/backend/cuda/softmax.cu +162 -0
  260. data/mlx/mlx/backend/cuda/sort.cu +1076 -0
  261. data/mlx/mlx/backend/cuda/steel/defines.cuh +9 -0
  262. data/mlx/mlx/backend/cuda/steel/gemm.cuh +101 -0
  263. data/mlx/mlx/backend/cuda/steel/mma.cuh +117 -0
  264. data/mlx/mlx/backend/cuda/steel/tiles.cuh +450 -0
  265. data/mlx/mlx/backend/cuda/steel/utils.cuh +89 -0
  266. data/mlx/mlx/backend/cuda/ternary.cu +271 -0
  267. data/mlx/mlx/backend/cuda/unary/CMakeLists.txt +34 -0
  268. data/mlx/mlx/backend/cuda/unary/abs.cu +7 -0
  269. data/mlx/mlx/backend/cuda/unary/arccos.cu +7 -0
  270. data/mlx/mlx/backend/cuda/unary/arccosh.cu +7 -0
  271. data/mlx/mlx/backend/cuda/unary/arcsin.cu +7 -0
  272. data/mlx/mlx/backend/cuda/unary/arcsinh.cu +7 -0
  273. data/mlx/mlx/backend/cuda/unary/arctan.cu +7 -0
  274. data/mlx/mlx/backend/cuda/unary/arctanh.cu +7 -0
  275. data/mlx/mlx/backend/cuda/unary/bitwise_invert.cu +7 -0
  276. data/mlx/mlx/backend/cuda/unary/ceil.cu +7 -0
  277. data/mlx/mlx/backend/cuda/unary/conjugate.cu +7 -0
  278. data/mlx/mlx/backend/cuda/unary/cos.cu +7 -0
  279. data/mlx/mlx/backend/cuda/unary/cosh.cu +7 -0
  280. data/mlx/mlx/backend/cuda/unary/erf.cu +7 -0
  281. data/mlx/mlx/backend/cuda/unary/erf_inv.cu +7 -0
  282. data/mlx/mlx/backend/cuda/unary/exp.cu +7 -0
  283. data/mlx/mlx/backend/cuda/unary/expm1.cu +7 -0
  284. data/mlx/mlx/backend/cuda/unary/floor.cu +7 -0
  285. data/mlx/mlx/backend/cuda/unary/imag.cu +7 -0
  286. data/mlx/mlx/backend/cuda/unary/log.cu +21 -0
  287. data/mlx/mlx/backend/cuda/unary/log1p.cu +7 -0
  288. data/mlx/mlx/backend/cuda/unary/logical_not.cu +7 -0
  289. data/mlx/mlx/backend/cuda/unary/negative.cu +7 -0
  290. data/mlx/mlx/backend/cuda/unary/real.cu +7 -0
  291. data/mlx/mlx/backend/cuda/unary/round.cu +18 -0
  292. data/mlx/mlx/backend/cuda/unary/sigmoid.cu +7 -0
  293. data/mlx/mlx/backend/cuda/unary/sign.cu +7 -0
  294. data/mlx/mlx/backend/cuda/unary/sin.cu +7 -0
  295. data/mlx/mlx/backend/cuda/unary/sinh.cu +7 -0
  296. data/mlx/mlx/backend/cuda/unary/sqrt.cu +15 -0
  297. data/mlx/mlx/backend/cuda/unary/square.cu +7 -0
  298. data/mlx/mlx/backend/cuda/unary/tan.cu +7 -0
  299. data/mlx/mlx/backend/cuda/unary/tanh.cu +7 -0
  300. data/mlx/mlx/backend/cuda/unary/unary.cuh +224 -0
  301. data/mlx/mlx/backend/cuda/utils.cpp +116 -0
  302. data/mlx/mlx/backend/cuda/utils.h +49 -0
  303. data/mlx/mlx/backend/cuda/vector_types.cuh +48 -0
  304. data/mlx/mlx/backend/cuda/worker.cpp +79 -0
  305. data/mlx/mlx/backend/cuda/worker.h +55 -0
  306. data/mlx/mlx/backend/gpu/CMakeLists.txt +5 -0
  307. data/mlx/mlx/backend/gpu/copy.cpp +89 -0
  308. data/mlx/mlx/backend/gpu/copy.h +57 -0
  309. data/mlx/mlx/backend/gpu/device_info.h +36 -0
  310. data/mlx/mlx/backend/gpu/eval.h +18 -0
  311. data/mlx/mlx/backend/gpu/primitives.cpp +307 -0
  312. data/mlx/mlx/backend/gpu/slicing.cpp +44 -0
  313. data/mlx/mlx/backend/gpu/slicing.h +36 -0
  314. data/mlx/mlx/backend/metal/CMakeLists.txt +144 -0
  315. data/mlx/mlx/backend/metal/allocator.cpp +279 -0
  316. data/mlx/mlx/backend/metal/allocator.h +79 -0
  317. data/mlx/mlx/backend/metal/binary.cpp +257 -0
  318. data/mlx/mlx/backend/metal/binary.h +33 -0
  319. data/mlx/mlx/backend/metal/compiled.cpp +471 -0
  320. data/mlx/mlx/backend/metal/conv.cpp +1118 -0
  321. data/mlx/mlx/backend/metal/copy.cpp +235 -0
  322. data/mlx/mlx/backend/metal/custom_kernel.cpp +430 -0
  323. data/mlx/mlx/backend/metal/device.cpp +816 -0
  324. data/mlx/mlx/backend/metal/device.h +289 -0
  325. data/mlx/mlx/backend/metal/device_info.cpp +58 -0
  326. data/mlx/mlx/backend/metal/distributed.cpp +38 -0
  327. data/mlx/mlx/backend/metal/eval.cpp +97 -0
  328. data/mlx/mlx/backend/metal/event.cpp +62 -0
  329. data/mlx/mlx/backend/metal/fence.cpp +162 -0
  330. data/mlx/mlx/backend/metal/fft.cpp +807 -0
  331. data/mlx/mlx/backend/metal/hadamard.cpp +198 -0
  332. data/mlx/mlx/backend/metal/indexing.cpp +727 -0
  333. data/mlx/mlx/backend/metal/jit/includes.h +58 -0
  334. data/mlx/mlx/backend/metal/jit/indexing.h +76 -0
  335. data/mlx/mlx/backend/metal/jit_kernels.cpp +1118 -0
  336. data/mlx/mlx/backend/metal/kernels/CMakeLists.txt +193 -0
  337. data/mlx/mlx/backend/metal/kernels/arange.h +9 -0
  338. data/mlx/mlx/backend/metal/kernels/arange.metal +20 -0
  339. data/mlx/mlx/backend/metal/kernels/arg_reduce.metal +182 -0
  340. data/mlx/mlx/backend/metal/kernels/atomic.h +345 -0
  341. data/mlx/mlx/backend/metal/kernels/bf16.h +16 -0
  342. data/mlx/mlx/backend/metal/kernels/bf16_math.h +380 -0
  343. data/mlx/mlx/backend/metal/kernels/binary.h +199 -0
  344. data/mlx/mlx/backend/metal/kernels/binary.metal +109 -0
  345. data/mlx/mlx/backend/metal/kernels/binary_ops.h +330 -0
  346. data/mlx/mlx/backend/metal/kernels/binary_two.h +244 -0
  347. data/mlx/mlx/backend/metal/kernels/binary_two.metal +54 -0
  348. data/mlx/mlx/backend/metal/kernels/cexpf.h +134 -0
  349. data/mlx/mlx/backend/metal/kernels/complex.h +173 -0
  350. data/mlx/mlx/backend/metal/kernels/conv.metal +701 -0
  351. data/mlx/mlx/backend/metal/kernels/copy.h +276 -0
  352. data/mlx/mlx/backend/metal/kernels/copy.metal +75 -0
  353. data/mlx/mlx/backend/metal/kernels/defines.h +24 -0
  354. data/mlx/mlx/backend/metal/kernels/erf.h +69 -0
  355. data/mlx/mlx/backend/metal/kernels/expm1f.h +90 -0
  356. data/mlx/mlx/backend/metal/kernels/fence.metal +52 -0
  357. data/mlx/mlx/backend/metal/kernels/fft/radix.h +328 -0
  358. data/mlx/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
  359. data/mlx/mlx/backend/metal/kernels/fft.h +486 -0
  360. data/mlx/mlx/backend/metal/kernels/fft.metal +67 -0
  361. data/mlx/mlx/backend/metal/kernels/fp4.h +48 -0
  362. data/mlx/mlx/backend/metal/kernels/fp8.h +80 -0
  363. data/mlx/mlx/backend/metal/kernels/fp_quantized.h +1850 -0
  364. data/mlx/mlx/backend/metal/kernels/fp_quantized.metal +153 -0
  365. data/mlx/mlx/backend/metal/kernels/fp_quantized_nax.h +1044 -0
  366. data/mlx/mlx/backend/metal/kernels/fp_quantized_nax.metal +79 -0
  367. data/mlx/mlx/backend/metal/kernels/gemv.metal +868 -0
  368. data/mlx/mlx/backend/metal/kernels/gemv_masked.h +827 -0
  369. data/mlx/mlx/backend/metal/kernels/gemv_masked.metal +76 -0
  370. data/mlx/mlx/backend/metal/kernels/hadamard.h +182 -0
  371. data/mlx/mlx/backend/metal/kernels/indexing/gather.h +51 -0
  372. data/mlx/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
  373. data/mlx/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
  374. data/mlx/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
  375. data/mlx/mlx/backend/metal/kernels/indexing/masked_scatter.h +41 -0
  376. data/mlx/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
  377. data/mlx/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
  378. data/mlx/mlx/backend/metal/kernels/layer_norm.metal +433 -0
  379. data/mlx/mlx/backend/metal/kernels/logging.h +26 -0
  380. data/mlx/mlx/backend/metal/kernels/logsumexp.h +140 -0
  381. data/mlx/mlx/backend/metal/kernels/logsumexp.metal +18 -0
  382. data/mlx/mlx/backend/metal/kernels/quantized.h +2508 -0
  383. data/mlx/mlx/backend/metal/kernels/quantized.metal +144 -0
  384. data/mlx/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
  385. data/mlx/mlx/backend/metal/kernels/quantized_nax.metal +106 -0
  386. data/mlx/mlx/backend/metal/kernels/quantized_utils.h +90 -0
  387. data/mlx/mlx/backend/metal/kernels/random.metal +103 -0
  388. data/mlx/mlx/backend/metal/kernels/reduce.h +5 -0
  389. data/mlx/mlx/backend/metal/kernels/reduce.metal +169 -0
  390. data/mlx/mlx/backend/metal/kernels/reduce_utils.h +6 -0
  391. data/mlx/mlx/backend/metal/kernels/reduction/ops.h +275 -0
  392. data/mlx/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
  393. data/mlx/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
  394. data/mlx/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
  395. data/mlx/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
  396. data/mlx/mlx/backend/metal/kernels/rms_norm.metal +391 -0
  397. data/mlx/mlx/backend/metal/kernels/rope.metal +229 -0
  398. data/mlx/mlx/backend/metal/kernels/scaled_dot_product_attention.metal +44 -0
  399. data/mlx/mlx/backend/metal/kernels/scan.h +514 -0
  400. data/mlx/mlx/backend/metal/kernels/scan.metal +109 -0
  401. data/mlx/mlx/backend/metal/kernels/sdpa_vector.h +394 -0
  402. data/mlx/mlx/backend/metal/kernels/softmax.h +190 -0
  403. data/mlx/mlx/backend/metal/kernels/softmax.metal +24 -0
  404. data/mlx/mlx/backend/metal/kernels/sort.h +719 -0
  405. data/mlx/mlx/backend/metal/kernels/sort.metal +80 -0
  406. data/mlx/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
  407. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +471 -0
  408. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal +27 -0
  409. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
  410. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.metal +28 -0
  411. data/mlx/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
  412. data/mlx/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
  413. data/mlx/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
  414. data/mlx/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
  415. data/mlx/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
  416. data/mlx/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
  417. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
  418. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.metal +56 -0
  419. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
  420. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.metal +47 -0
  421. data/mlx/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
  422. data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
  423. data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
  424. data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
  425. data/mlx/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
  426. data/mlx/mlx/backend/metal/kernels/steel/defines.h +7 -0
  427. data/mlx/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
  428. data/mlx/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +157 -0
  429. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
  430. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.metal +34 -0
  431. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +219 -0
  432. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.metal +30 -0
  433. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
  434. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.metal +59 -0
  435. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +143 -0
  436. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.metal +37 -0
  437. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
  438. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.metal +76 -0
  439. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
  440. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.metal +43 -0
  441. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
  442. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.metal +76 -0
  443. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.h +152 -0
  444. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.metal +30 -0
  445. data/mlx/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
  446. data/mlx/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
  447. data/mlx/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
  448. data/mlx/mlx/backend/metal/kernels/steel/gemm/params.h +65 -0
  449. data/mlx/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
  450. data/mlx/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
  451. data/mlx/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
  452. data/mlx/mlx/backend/metal/kernels/steel/utils.h +42 -0
  453. data/mlx/mlx/backend/metal/kernels/ternary.h +145 -0
  454. data/mlx/mlx/backend/metal/kernels/ternary.metal +48 -0
  455. data/mlx/mlx/backend/metal/kernels/ternary_ops.h +10 -0
  456. data/mlx/mlx/backend/metal/kernels/unary.h +63 -0
  457. data/mlx/mlx/backend/metal/kernels/unary.metal +115 -0
  458. data/mlx/mlx/backend/metal/kernels/unary_ops.h +454 -0
  459. data/mlx/mlx/backend/metal/kernels/utils.h +445 -0
  460. data/mlx/mlx/backend/metal/kernels.h +375 -0
  461. data/mlx/mlx/backend/metal/logsumexp.cpp +95 -0
  462. data/mlx/mlx/backend/metal/make_compiled_preamble.sh +120 -0
  463. data/mlx/mlx/backend/metal/matmul.cpp +2572 -0
  464. data/mlx/mlx/backend/metal/matmul.h +144 -0
  465. data/mlx/mlx/backend/metal/metal.cpp +50 -0
  466. data/mlx/mlx/backend/metal/metal.h +25 -0
  467. data/mlx/mlx/backend/metal/no_metal.cpp +42 -0
  468. data/mlx/mlx/backend/metal/nojit_kernels.cpp +414 -0
  469. data/mlx/mlx/backend/metal/normalization.cpp +433 -0
  470. data/mlx/mlx/backend/metal/primitives.cpp +242 -0
  471. data/mlx/mlx/backend/metal/quantized.cpp +1651 -0
  472. data/mlx/mlx/backend/metal/reduce.cpp +1038 -0
  473. data/mlx/mlx/backend/metal/reduce.h +41 -0
  474. data/mlx/mlx/backend/metal/resident.cpp +100 -0
  475. data/mlx/mlx/backend/metal/resident.h +32 -0
  476. data/mlx/mlx/backend/metal/rope.cpp +165 -0
  477. data/mlx/mlx/backend/metal/scaled_dot_product_attention.cpp +798 -0
  478. data/mlx/mlx/backend/metal/scan.cpp +145 -0
  479. data/mlx/mlx/backend/metal/scan.h +17 -0
  480. data/mlx/mlx/backend/metal/slicing.cpp +99 -0
  481. data/mlx/mlx/backend/metal/softmax.cpp +87 -0
  482. data/mlx/mlx/backend/metal/sort.cpp +368 -0
  483. data/mlx/mlx/backend/metal/ternary.cpp +160 -0
  484. data/mlx/mlx/backend/metal/ternary.h +21 -0
  485. data/mlx/mlx/backend/metal/unary.cpp +161 -0
  486. data/mlx/mlx/backend/metal/unary.h +21 -0
  487. data/mlx/mlx/backend/metal/utils.cpp +77 -0
  488. data/mlx/mlx/backend/metal/utils.h +99 -0
  489. data/mlx/mlx/backend/no_cpu/CMakeLists.txt +7 -0
  490. data/mlx/mlx/backend/no_cpu/compiled.cpp +24 -0
  491. data/mlx/mlx/backend/no_cpu/device_info.cpp +22 -0
  492. data/mlx/mlx/backend/no_cpu/primitives.cpp +146 -0
  493. data/mlx/mlx/backend/no_gpu/CMakeLists.txt +8 -0
  494. data/mlx/mlx/backend/no_gpu/allocator.cpp +134 -0
  495. data/mlx/mlx/backend/no_gpu/apple_memory.h +16 -0
  496. data/mlx/mlx/backend/no_gpu/device_info.cpp +22 -0
  497. data/mlx/mlx/backend/no_gpu/eval.cpp +24 -0
  498. data/mlx/mlx/backend/no_gpu/event.cpp +53 -0
  499. data/mlx/mlx/backend/no_gpu/fence.cpp +54 -0
  500. data/mlx/mlx/backend/no_gpu/linux_memory.h +22 -0
  501. data/mlx/mlx/backend/no_gpu/primitives.cpp +185 -0
  502. data/mlx/mlx/compile.cpp +1243 -0
  503. data/mlx/mlx/compile.h +45 -0
  504. data/mlx/mlx/compile_impl.h +70 -0
  505. data/mlx/mlx/device.cpp +72 -0
  506. data/mlx/mlx/device.h +56 -0
  507. data/mlx/mlx/distributed/CMakeLists.txt +14 -0
  508. data/mlx/mlx/distributed/distributed.cpp +197 -0
  509. data/mlx/mlx/distributed/distributed.h +61 -0
  510. data/mlx/mlx/distributed/distributed_impl.h +59 -0
  511. data/mlx/mlx/distributed/jaccl/CMakeLists.txt +12 -0
  512. data/mlx/mlx/distributed/jaccl/jaccl.cpp +178 -0
  513. data/mlx/mlx/distributed/jaccl/jaccl.h +12 -0
  514. data/mlx/mlx/distributed/jaccl/mesh.cpp +451 -0
  515. data/mlx/mlx/distributed/jaccl/mesh.h +122 -0
  516. data/mlx/mlx/distributed/jaccl/no_jaccl.cpp +20 -0
  517. data/mlx/mlx/distributed/jaccl/ring.cpp +692 -0
  518. data/mlx/mlx/distributed/jaccl/ring.h +178 -0
  519. data/mlx/mlx/distributed/jaccl/utils.cpp +329 -0
  520. data/mlx/mlx/distributed/jaccl/utils.h +342 -0
  521. data/mlx/mlx/distributed/mpi/CMakeLists.txt +5 -0
  522. data/mlx/mlx/distributed/mpi/mpi.cpp +501 -0
  523. data/mlx/mlx/distributed/mpi/mpi.h +12 -0
  524. data/mlx/mlx/distributed/mpi/mpi_declarations.h +28 -0
  525. data/mlx/mlx/distributed/mpi/no_mpi.cpp +20 -0
  526. data/mlx/mlx/distributed/nccl/CMakeLists.txt +26 -0
  527. data/mlx/mlx/distributed/nccl/nccl.cpp +443 -0
  528. data/mlx/mlx/distributed/nccl/nccl.h +12 -0
  529. data/mlx/mlx/distributed/nccl/nccl_stub/CMakeLists.txt +1 -0
  530. data/mlx/mlx/distributed/nccl/nccl_stub/nccl_stubs.cpp +54 -0
  531. data/mlx/mlx/distributed/nccl/no_nccl.cpp +20 -0
  532. data/mlx/mlx/distributed/ops.cpp +186 -0
  533. data/mlx/mlx/distributed/ops.h +57 -0
  534. data/mlx/mlx/distributed/primitives.cpp +95 -0
  535. data/mlx/mlx/distributed/primitives.h +156 -0
  536. data/mlx/mlx/distributed/reduction_ops.h +38 -0
  537. data/mlx/mlx/distributed/ring/CMakeLists.txt +5 -0
  538. data/mlx/mlx/distributed/ring/no_ring.cpp +20 -0
  539. data/mlx/mlx/distributed/ring/ring.cpp +870 -0
  540. data/mlx/mlx/distributed/ring/ring.h +12 -0
  541. data/mlx/mlx/distributed/utils.cpp +206 -0
  542. data/mlx/mlx/distributed/utils.h +67 -0
  543. data/mlx/mlx/dtype.cpp +197 -0
  544. data/mlx/mlx/dtype.h +116 -0
  545. data/mlx/mlx/dtype_utils.cpp +42 -0
  546. data/mlx/mlx/dtype_utils.h +119 -0
  547. data/mlx/mlx/einsum.cpp +941 -0
  548. data/mlx/mlx/einsum.h +23 -0
  549. data/mlx/mlx/event.h +58 -0
  550. data/mlx/mlx/export.cpp +1130 -0
  551. data/mlx/mlx/export.h +137 -0
  552. data/mlx/mlx/export_impl.h +99 -0
  553. data/mlx/mlx/fast.cpp +941 -0
  554. data/mlx/mlx/fast.h +103 -0
  555. data/mlx/mlx/fast_primitives.h +427 -0
  556. data/mlx/mlx/fence.h +39 -0
  557. data/mlx/mlx/fft.cpp +262 -0
  558. data/mlx/mlx/fft.h +159 -0
  559. data/mlx/mlx/graph_utils.cpp +175 -0
  560. data/mlx/mlx/graph_utils.h +67 -0
  561. data/mlx/mlx/io/CMakeLists.txt +25 -0
  562. data/mlx/mlx/io/gguf.cpp +470 -0
  563. data/mlx/mlx/io/gguf.h +20 -0
  564. data/mlx/mlx/io/gguf_quants.cpp +164 -0
  565. data/mlx/mlx/io/load.cpp +397 -0
  566. data/mlx/mlx/io/load.h +175 -0
  567. data/mlx/mlx/io/no_gguf.cpp +20 -0
  568. data/mlx/mlx/io/no_safetensors.cpp +37 -0
  569. data/mlx/mlx/io/safetensors.cpp +234 -0
  570. data/mlx/mlx/io.h +61 -0
  571. data/mlx/mlx/linalg.cpp +708 -0
  572. data/mlx/mlx/linalg.h +115 -0
  573. data/mlx/mlx/memory.h +80 -0
  574. data/mlx/mlx/mlx.h +25 -0
  575. data/mlx/mlx/ops.cpp +6094 -0
  576. data/mlx/mlx/ops.h +1610 -0
  577. data/mlx/mlx/primitives.cpp +5850 -0
  578. data/mlx/mlx/primitives.h +2525 -0
  579. data/mlx/mlx/random.cpp +492 -0
  580. data/mlx/mlx/random.h +283 -0
  581. data/mlx/mlx/scheduler.cpp +73 -0
  582. data/mlx/mlx/scheduler.h +189 -0
  583. data/mlx/mlx/small_vector.h +540 -0
  584. data/mlx/mlx/stream.h +42 -0
  585. data/mlx/mlx/threadpool.h +133 -0
  586. data/mlx/mlx/transforms.cpp +1065 -0
  587. data/mlx/mlx/transforms.h +231 -0
  588. data/mlx/mlx/transforms_impl.h +88 -0
  589. data/mlx/mlx/types/bf16.h +187 -0
  590. data/mlx/mlx/types/complex.h +113 -0
  591. data/mlx/mlx/types/fp16.h +234 -0
  592. data/mlx/mlx/types/half_types.h +58 -0
  593. data/mlx/mlx/types/limits.h +70 -0
  594. data/mlx/mlx/utils.cpp +302 -0
  595. data/mlx/mlx/utils.h +174 -0
  596. data/mlx/mlx/version.cpp +11 -0
  597. data/mlx/mlx/version.h +22 -0
  598. data/mlx/mlx.pc.in +52 -0
  599. metadata +643 -0
data/mlx/mlx/backend/metal/fft.cpp
@@ -0,0 +1,807 @@
+ // Copyright © 2024 Apple Inc.
+ #include <cassert>
+ #include <complex>
+ #include <map>
+ #include <numeric>
+ #include <set>
+
+ #include "mlx/3rdparty/pocketfft.h"
+ #include "mlx/backend/common/utils.h"
+ #include "mlx/backend/gpu/copy.h"
+ #include "mlx/backend/gpu/slicing.h"
+ #include "mlx/backend/metal/binary.h"
+ #include "mlx/backend/metal/kernels.h"
+ #include "mlx/backend/metal/unary.h"
+ #include "mlx/backend/metal/utils.h"
+ #include "mlx/utils.h"
+
+ namespace mlx::core {
+
+ using MTLFC = std::tuple<const void*, MTL::DataType, NS::UInteger>;
+
+ #define MAX_STOCKHAM_FFT_SIZE 4096
+ #define MAX_RADER_FFT_SIZE 2048
+ #define MAX_BLUESTEIN_FFT_SIZE 2048
+ // Threadgroup memory batching improves throughput for small n
+ #define MIN_THREADGROUP_MEM_SIZE 256
+ // For strided reads/writes, coalesce at least this many complex64s
+ #define MIN_COALESCE_WIDTH 4
+
+ inline const std::vector<int> supported_radices() {
+   // Ordered by preference in decomposition.
+   return {13, 11, 8, 7, 6, 5, 4, 3, 2};
+ }
+
+ std::vector<int> prime_factors(int n) {
+   int z = 2;
+   std::vector<int> factors;
+   while (z * z <= n) {
+     if (n % z == 0) {
+       factors.push_back(z);
+       n /= z;
+     } else {
+       z++;
+     }
+   }
+   if (n > 1) {
+     factors.push_back(n);
+   }
+   return factors;
+ }
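+ // [Illustration, not in the original source] prime_factors(120) returns
+ // {2, 2, 2, 3, 5}: trial division by increasing z emits factors in
+ // non-decreasing order, with any leftover n > 1 appended as the last prime.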
+
+ struct FourStepParams {
+   bool required = false;
+   bool first_step = true;
+   int n1 = 0;
+   int n2 = 0;
+ };
+
+ // Forward Declaration
+ void fft_op(
+     const array& in,
+     array& out,
+     size_t axis,
+     bool inverse,
+     bool real,
+     const FourStepParams four_step_params,
+     bool inplace,
+     const Stream& s);
+
+ struct FFTPlan {
+   int n = 0;
+   // Number of steps for each radix in the Stockham decomposition
+   std::vector<int> stockham;
+   // Number of steps for each radix in the Rader decomposition
+   std::vector<int> rader;
+   // Rader factor, 1 if no rader factors
+   int rader_n = 1;
+   int bluestein_n = -1;
+   // Four step FFT
+   bool four_step = false;
+   int n1 = 0;
+   int n2 = 0;
+ };
+
+ int next_fast_n(int n) {
+   return next_power_of_2(n);
+ }
+
+ std::vector<int> plan_stockham_fft(int n) {
+   auto radices = supported_radices();
+   std::vector<int> plan(radices.size(), 0);
+   int orig_n = n;
+   if (n == 1) {
+     return plan;
+   }
+   for (int i = 0; i < radices.size(); i++) {
+     int radix = radices[i];
+     // Manually tuned radices for powers of 2
+     if (is_power_of_2(orig_n) && orig_n < 512 && radix > 4) {
+       continue;
+     }
+     while (n % radix == 0) {
+       plan[i] += 1;
+       n /= radix;
+       if (n == 1) {
+         return plan;
+       }
+     }
+   }
+   throw std::runtime_error("Unplannable");
+ }
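+ // [Illustration, not in the original source] plan_stockham_fft(120) tries the
+ // radices in the order above: 120 = 8 * 5 * 3, so the plan records one step
+ // each for radices 8, 5 and 3. For 64 (a power of 2 below 512) the radices
+ // above 4 are skipped and the plan is three radix-4 steps.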
+
+ FFTPlan plan_fft(int n) {
+   auto radices = supported_radices();
+   std::set<int> radices_set(radices.begin(), radices.end());
+
+   FFTPlan plan;
+   plan.n = n;
+   plan.rader = std::vector<int>(radices.size(), 0);
+   auto factors = prime_factors(n);
+   int remaining_n = n;
+
+   // Four Step FFT when N is too large for shared mem.
+   if (n > MAX_STOCKHAM_FFT_SIZE && is_power_of_2(n)) {
+     // For powers of two we have a fast, no-transpose four step implementation.
+     plan.four_step = true;
+     // Rough heuristic for choosing faster powers of two when we can
+     plan.n2 = n > 65536 ? 1024 : 64;
+     plan.n1 = n / plan.n2;
+     return plan;
+   } else if (n > MAX_STOCKHAM_FFT_SIZE) {
+     // Otherwise we use a multi-upload Bluestein's
+     plan.four_step = true;
+     plan.bluestein_n = next_fast_n(2 * n - 1);
+     return plan;
+   }
+
+   for (int factor : factors) {
+     // Make sure the factor is a supported radix
+     if (radices_set.find(factor) == radices_set.end()) {
+       // We only support a single Rader factor currently
+       // TODO(alexbarron) investigate weirdness with large
+       // Rader sizes -- possibly a compiler issue?
+       if (plan.rader_n > 1 || n > MAX_RADER_FFT_SIZE) {
+         plan.four_step = n > MAX_BLUESTEIN_FFT_SIZE;
+         plan.bluestein_n = next_fast_n(2 * n - 1);
+         plan.stockham = plan_stockham_fft(plan.bluestein_n);
+         plan.rader = std::vector<int>(radices.size(), 0);
+         return plan;
+       }
+       // See if we can use Rader's algorithm to Stockham decompose factor - 1
+       auto rader_factors = prime_factors(factor - 1);
+       for (int rf : rader_factors) {
+         // We don't nest Rader's algorithm so if `factor - 1`
+         // isn't Stockham decomposable we give up and do Bluestein's.
+         if (radices_set.find(rf) == radices_set.end()) {
+           plan.four_step = n > MAX_BLUESTEIN_FFT_SIZE;
+           plan.bluestein_n = next_fast_n(2 * n - 1);
+           plan.stockham = plan_stockham_fft(plan.bluestein_n);
+           plan.rader = std::vector<int>(radices.size(), 0);
+           return plan;
+         }
+       }
+       plan.rader = plan_stockham_fft(factor - 1);
+       plan.rader_n = factor;
+       remaining_n /= factor;
+     }
+   }
+
+   plan.stockham = plan_stockham_fft(remaining_n);
+   return plan;
+ }
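+ // [Illustration, not in the original source] plan_fft(17): 17 is prime and
+ // not a supported radix, but 17 - 1 = 16 = 4 * 4 is Stockham decomposable,
+ // so the plan gets rader_n = 17 with two radix-4 Rader steps and an
+ // otherwise empty Stockham part (remaining_n == 1).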
+
+ int compute_elems_per_thread(FFTPlan plan) {
+   // Heuristics for selecting an efficient number
+   // of threads to use for a particular mixed-radix FFT.
+   auto n = plan.n;
+
+   std::vector<int> steps;
+   auto radices = supported_radices();
+   steps.insert(steps.end(), plan.stockham.begin(), plan.stockham.end());
+   steps.insert(steps.end(), plan.rader.begin(), plan.rader.end());
+   std::set<int> used_radices;
+   for (int i = 0; i < steps.size(); i++) {
+     int radix = radices[i % radices.size()];
+     if (steps[i] > 0) {
+       used_radices.insert(radix);
+     }
+   }
+
+   // Manual tuning for 7/11/13
+   if (used_radices.find(7) != used_radices.end() &&
+       (used_radices.find(11) != used_radices.end() ||
+        used_radices.find(13) != used_radices.end())) {
+     return 7;
+   } else if (
+       used_radices.find(11) != used_radices.end() &&
+       used_radices.find(13) != used_radices.end()) {
+     return 11;
+   }
+
+   // TODO(alexbarron) Some really weird stuff is going on
+   // for certain `elems_per_thread` on large composite n.
+   // Possibly a compiler issue?
+   if (n == 3159)
+     return 13;
+   if (n == 3645)
+     return 5;
+   if (n == 3969)
+     return 7;
+   if (n == 1982)
+     return 5;
+
+   if (used_radices.size() == 1) {
+     return *(used_radices.begin());
+   }
+   if (used_radices.size() == 2) {
+     if (used_radices.find(11) != used_radices.end() ||
+         used_radices.find(13) != used_radices.end()) {
+       return std::accumulate(used_radices.begin(), used_radices.end(), 0) / 2;
+     }
+     std::vector<int> radix_vec(used_radices.begin(), used_radices.end());
+     return radix_vec[1];
+   }
+   // In all other cases use the second smallest radix.
+   std::vector<int> radix_vec(used_radices.begin(), used_radices.end());
+   return radix_vec[1];
+ }
+
+ // Rader
+ int mod_exp(int x, int y, int n) {
+   int out = 1;
+   while (y) {
+     if (y & 1) {
+       out = out * x % n;
+     }
+     y >>= 1;
+     x = x * x % n;
+   }
+   return out;
+ }
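+ // [Illustration, not in the original source] square-and-multiply, e.g.
+ // mod_exp(3, 4, 5): repeated squaring gives 3^2 = 4 and 3^4 = 1 (mod 5),
+ // and only the 4-bit of the exponent is set, so the result is 1.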
+
+ int primitive_root(int n) {
+   auto factors = prime_factors(n - 1);
+
+   for (int r = 2; r < n - 1; r++) {
+     bool found = true;
+     for (int factor : factors) {
+       if (mod_exp(r, (n - 1) / factor, n) == 1) {
+         found = false;
+         break;
+       }
+     }
+     if (found) {
+       return r;
+     }
+   }
+   return -1;
+ }
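+ // [Illustration, not in the original source] primitive_root(7) == 3: the
+ // prime factors of 6 are {2, 3}, and 3^3 = 6, 3^2 = 2 (mod 7), neither of
+ // which is 1, so 3 generates the whole multiplicative group mod 7
+ // (2 fails the test because 2^3 = 1 mod 7).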
+
+ std::tuple<array, array, array> compute_raders_constants(
+     int rader_n,
+     const Stream& s) {
+   int proot = primitive_root(rader_n);
+   // Fermat's little theorem
+   int inv = mod_exp(proot, rader_n - 2, rader_n);
+   std::vector<short> g_q(rader_n - 1);
+   std::vector<short> g_minus_q(rader_n - 1);
+   for (int i = 0; i < rader_n - 1; i++) {
+     g_q[i] = mod_exp(proot, i, rader_n);
+     g_minus_q[i] = mod_exp(inv, i, rader_n);
+   }
+   array g_q_arr(g_q.begin(), {rader_n - 1});
+   array g_minus_q_arr(g_minus_q.begin(), {rader_n - 1});
+
+   std::vector<std::complex<float>> b_q(rader_n - 1);
+   for (int i = 0; i < rader_n - 1; i++) {
+     float pi_i = (float)g_minus_q[i] * -2.0 * M_PI / rader_n;
+     b_q[i] = std::exp(std::complex<float>(0, pi_i));
+   }
+
+   array b_q_fft({rader_n - 1}, complex64, nullptr, {});
+   b_q_fft.set_data(allocator::malloc(b_q_fft.nbytes()));
+   auto b_q_fft_ptr =
+       reinterpret_cast<std::complex<float>*>(b_q_fft.data<complex64_t>());
+   std::ptrdiff_t item_size = b_q_fft.itemsize();
+   size_t fft_size = rader_n - 1;
+   // This FFT is always small (<4096, batch 1) so save some overhead
+   // and do it on the CPU
+   pocketfft::c2c(
+       /* shape= */ {fft_size},
+       /* stride_in= */ {item_size},
+       /* stride_out= */ {item_size},
+       /* axes= */ {0},
+       /* forward= */ true,
+       /* data_in= */ b_q.data(),
+       /* data_out= */ b_q_fft_ptr,
+       /* scale= */ 1.0f);
+   return std::make_tuple(b_q_fft, g_q_arr, g_minus_q_arr);
+ }
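+ // [Illustration, not in the original source] for rader_n = 5: proot = 2 and
+ // inv = 2^3 = 3 (mod 5), giving the index permutations g_q = {1, 2, 4, 3}
+ // and g_minus_q = {1, 3, 4, 2}. Reordering by powers of the primitive root
+ // turns the length-(rader_n - 1) part of the DFT into a cyclic convolution,
+ // which the kernel evaluates against the precomputed FFT of b_q.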
+
+ // Bluestein
+ std::pair<array, array> compute_bluestein_constants(int n, int bluestein_n) {
+   // We need to calculate the Bluestein twiddle factors
+   // in double precision for the overall numerical stability
+   // of Bluestein's FFT algorithm to be acceptable.
+   //
+   // Metal doesn't support float64, so instead we
+   // manually implement the required operations on cpu.
+   //
+   // In numpy:
+   // w_k = np.exp(-1j * np.pi / N * (np.arange(-N + 1, N) ** 2))
+   // w_q = np.fft.fft(1/w_k)
+   // return w_k, w_q
+   std::vector<std::complex<float>> w_k_vec(n);
+   std::vector<std::complex<float>> w_q_vec(bluestein_n, 0);
+
+   for (int i = -n + 1; i < n; i++) {
+     double theta = pow(i, 2) * M_PI / (double)n;
+     w_q_vec[i + n - 1] = std::exp(std::complex<double>(0, theta));
+     if (i >= 0) {
+       w_k_vec[i] = std::exp(std::complex<double>(0, -theta));
+     }
+   }
+
+   array w_k({n}, complex64, nullptr, {});
+   w_k.set_data(allocator::malloc(w_k.nbytes()));
+   std::copy(w_k_vec.begin(), w_k_vec.end(), w_k.data<complex64_t>());
+
+   array w_q({bluestein_n}, complex64, nullptr, {});
+   w_q.set_data(allocator::malloc(w_q.nbytes()));
+   auto w_q_ptr =
+       reinterpret_cast<std::complex<float>*>(w_q.data<complex64_t>());
+
+   std::ptrdiff_t item_size = w_q.itemsize();
+   size_t fft_size = bluestein_n;
+   pocketfft::c2c(
+       /* shape= */ {fft_size},
+       /* stride_in= */ {item_size},
+       /* stride_out= */ {item_size},
+       /* axes= */ {0},
+       /* forward= */ true,
+       /* data_in= */ w_q_vec.data(),
+       /* data_out= */ w_q_ptr,
+       /* scale= */ 1.0f);
+   return std::make_tuple(w_k, w_q);
+ }
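+ // [Illustration, not in the original source] for n = 5 the caller passes
+ // bluestein_n = next_fast_n(2 * 5 - 1) = 16, the next power of 2 above 9;
+ // w_k then holds the 5 chirp factors and w_q the length-16 FFT of the
+ // zero-padded reciprocal chirp.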
+
+ void multi_upload_bluestein_fft(
+     const array& in,
+     array& out,
+     size_t axis,
+     bool inverse,
+     bool real,
+     FFTPlan& plan,
+     std::vector<array>& copies,
+     const Stream& s) {
+   // TODO(alexbarron) Implement fused kernels for multi-upload Bluestein's
+   // algorithm
+   int n = inverse ? out.shape(axis) : in.shape(axis);
+   auto [w_k, w_q] = compute_bluestein_constants(n, plan.bluestein_n);
+   copies.push_back(w_k);
+   copies.push_back(w_q);
+
+   auto temp_shape = inverse ? out.shape() : in.shape();
+   array temp(temp_shape, complex64, nullptr, {});
+   array temp1(temp_shape, complex64, nullptr, {});
+
+   if (real && !inverse) {
+     // Convert float32->complex64
+     copy_gpu(in, temp, CopyType::General, s);
+     copies.push_back(temp);
+   } else if (real && inverse) {
+     int back_offset = n % 2 == 0 ? 2 : 1;
+     auto slice_shape = in.shape();
+     slice_shape[axis] -= back_offset;
+     array slice_temp(slice_shape, complex64, nullptr, {});
+     array conj_temp(in.shape(), complex64, nullptr, {});
+     copies.push_back(conj_temp);
+
+     Shape rstarts(in.ndim(), 0);
+     Shape rstrides(in.ndim(), 1);
+     rstarts[axis] = in.shape(axis) - back_offset;
+     rstrides[axis] = -1;
+     unary_op_gpu({in}, conj_temp, "Conjugate", s);
+     slice_gpu(in, slice_temp, rstarts, rstrides, s);
+     concatenate_gpu({conj_temp, slice_temp}, temp, (int)axis, s);
+     copies.push_back(temp);
+   } else if (inverse) {
+     unary_op_gpu({in}, temp, "Conjugate", s);
+     copies.push_back(temp);
+   } else {
+     temp.copy_shared_buffer(in);
+   }
+
+   Strides b_strides(in.ndim(), 0);
+   b_strides[axis] = 1;
+   array w_k_broadcast(temp.shape(), complex64, nullptr, {});
+   w_k_broadcast.copy_shared_buffer(w_k, b_strides, {}, w_k.data_size());
+   binary_op_gpu({temp, w_k_broadcast}, temp1, "Multiply", s);
+
+   std::vector<std::pair<int, int>> pads;
+   auto padded_shape = out.shape();
+   padded_shape[axis] = plan.bluestein_n;
+   array pad_temp(padded_shape, complex64, nullptr, {});
+   auto zero = array(complex64_t{0.0f, 0.0f});
+   copies.push_back(zero);
+   pad_gpu(temp1, zero, pad_temp, {(int)axis}, {0}, s);
+   copies.push_back(pad_temp);
+
+   array pad_temp1(padded_shape, complex64, nullptr, {});
+   fft_op(
+       pad_temp,
+       pad_temp1,
+       axis,
+       /*inverse=*/false,
+       /*real=*/false,
+       FourStepParams(),
+       /*inplace=*/false,
+       s);
+   copies.push_back(pad_temp1);
+
+   array w_q_broadcast(pad_temp1.shape(), complex64, nullptr, {});
+   w_q_broadcast.copy_shared_buffer(w_q, b_strides, {}, w_q.data_size());
+   binary_op_gpu_inplace({pad_temp1, w_q_broadcast}, pad_temp, "Multiply", s);
+
+   fft_op(
+       pad_temp,
+       pad_temp1,
+       axis,
+       /* inverse= */ true,
+       /* real= */ false,
+       FourStepParams(),
+       /*inplace=*/true,
+       s);
+
+   int offset = plan.bluestein_n - (2 * n - 1);
+   Shape starts(in.ndim(), 0);
+   Shape strides(in.ndim(), 1);
+   starts[axis] = plan.bluestein_n - offset - n;
+
+   array temp2(temp_shape, complex64, nullptr, {});
+   slice_gpu(pad_temp1, temp2, starts, strides, s);
+
+   binary_op_gpu_inplace({temp2, w_k_broadcast}, temp1, "Multiply", s);
+
+   if (real && !inverse) {
+     Shape rstarts(in.ndim(), 0);
+     Shape rstrides(in.ndim(), 1);
+     slice_gpu(temp1, out, rstarts, strides, s);
+   } else if (real && inverse) {
+     Strides b_strides(in.ndim(), 0);
+     auto inv_n = array({1.0f / n}, {1}, float32);
+     array temp_float(out.shape(), out.dtype(), nullptr, {});
+     copies.push_back(temp_float);
+     copies.push_back(inv_n);
+     copies.push_back(temp1);
+
+     copy_gpu(temp1, temp_float, CopyType::General, s);
+     binary_op_gpu({temp_float, inv_n}, out, "Multiply", s);
+   } else if (inverse) {
+     auto inv_n = array({1.0f / n}, {1}, complex64);
+     array temp3(temp_shape, complex64, nullptr, {});
+     unary_op_gpu({temp1}, temp3, "Conjugate", s);
+     binary_op_gpu({temp3, inv_n}, out, "Multiply", s);
+     copies.push_back(inv_n);
+     copies.push_back(temp1);
+     copies.push_back(temp3);
+   } else {
+     out.copy_shared_buffer(temp1);
+   }
+ }
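+ // [Illustration, not in the original source] the Conjugate/Multiply branches
+ // above rely on the identity ifft(x) = conj(fft(conj(x))) / n, so one
+ // forward Bluestein pipeline serves both directions, with the 1/n scale
+ // applied in a final Multiply.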
+
+ void four_step_fft(
+     const array& in,
+     array& out,
+     size_t axis,
+     bool inverse,
+     bool real,
+     FFTPlan& plan,
+     std::vector<array>& copies,
+     const Stream& s,
+     bool in_place) {
+   if (plan.bluestein_n == -1) {
+     // Fast no transpose implementation for powers of 2.
+     FourStepParams four_step_params = {
+         /* required= */ true, /* first_step= */ true, plan.n1, plan.n2};
+     auto temp_shape = (real && inverse) ? out.shape() : in.shape();
+     array temp(temp_shape, complex64, nullptr, {});
+     fft_op(
+         in, temp, axis, inverse, real, four_step_params, /*inplace=*/false, s);
+     four_step_params.first_step = false;
+     fft_op(
+         temp,
+         out,
+         axis,
+         inverse,
+         real,
+         four_step_params,
+         /*inplace=*/in_place,
+         s);
+     copies.push_back(temp);
+   } else {
+     multi_upload_bluestein_fft(in, out, axis, inverse, real, plan, copies, s);
+   }
+ }
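+ // [Illustration, not in the original source] for n = 1 << 20, plan_fft picks
+ // n2 = 1024 and n1 = 1024, so the call above runs two passes of 1024-point
+ // FFTs that fit in threadgroup memory rather than one 1,048,576-point FFT
+ // that would not.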
+
+ void fft_op(
+     const array& in,
+     array& out,
+     size_t axis,
+     bool inverse,
+     bool real,
+     const FourStepParams four_step_params,
+     bool inplace,
+     const Stream& s) {
+   auto& d = metal::device(s.device);
+
+   size_t n = out.dtype() == float32 ? out.shape(axis) : in.shape(axis);
+   if (n == 1) {
+     out.copy_shared_buffer(in);
+     return;
+   }
+
+   if (four_step_params.required) {
+     // Four Step FFT decomposes into two FFTs: n1 on columns, n2 on rows
+     n = four_step_params.first_step ? four_step_params.n1 : four_step_params.n2;
+   }
+
+   // Make sure that the array is contiguous and has stride 1 in the FFT dim
+   std::vector<array> copies;
+   auto check_input = [&axis, &copies, &s](const array& x) {
+     // TODO: Pass the strides to the kernel so
+     // we can avoid the copy when x is not contiguous.
+     bool no_copy = x.strides()[axis] == 1 &&
+         (x.flags().row_contiguous || x.flags().col_contiguous);
+     if (no_copy) {
+       return x;
+     } else {
+       array x_copy(x.shape(), x.dtype(), nullptr, {});
+       Strides strides;
+       int64_t cur_stride = x.shape(axis);
+       for (int a = 0; a < x.ndim(); a++) {
+         if (a == axis) {
+           strides.push_back(1);
+         } else {
+           strides.push_back(cur_stride);
+           cur_stride *= x.shape(a);
+         }
+       }
+
+       auto flags = x.flags();
+       auto [data_size, is_row_contiguous, is_col_contiguous] =
+           check_contiguity(x.shape(), strides);
+
+       flags.col_contiguous = is_col_contiguous;
+       flags.row_contiguous = is_row_contiguous;
+       flags.contiguous = data_size == x_copy.size();
+
+       x_copy.set_data(allocator::malloc(x.nbytes()), data_size, strides, flags);
+       copy_gpu_inplace(x, x_copy, CopyType::GeneralGeneral, s);
+       copies.push_back(x_copy);
+       return x_copy;
+     }
+   };
+   const array& in_contiguous = check_input(in);
+
+   // real to complex: n -> (n/2)+1
+   // complex to real: (n/2)+1 -> n
+   auto out_strides = in_contiguous.strides();
+   size_t out_data_size = in_contiguous.data_size();
+   if (in.shape(axis) != out.shape(axis)) {
+     for (int i = 0; i < out_strides.size(); i++) {
+       if (out_strides[i] != 1) {
+         out_strides[i] = out_strides[i] / in.shape(axis) * out.shape(axis);
+       }
+     }
+     out_data_size = out_data_size / in.shape(axis) * out.shape(axis);
+   }
+
+   auto plan = plan_fft(n);
+   if (plan.four_step) {
+     four_step_fft(in, out, axis, inverse, real, plan, copies, s, inplace);
+     d.add_temporaries(std::move(copies), s.index);
+     return;
+   }
+
+   // TODO: allow donation here
+   if (!inplace) {
+     out.set_data(
+         allocator::malloc(out.nbytes()),
+         out_data_size,
+         out_strides,
+         in_contiguous.flags());
+   }
+
+   auto radices = supported_radices();
+   int fft_size = plan.bluestein_n > 0 ? plan.bluestein_n : n;
+
+   // Setup function constants
+   bool power_of_2 = is_power_of_2(fft_size);
+
+   auto make_int = [](int* a, int i) {
+     return std::make_tuple(a, MTL::DataType::DataTypeInt, i);
+   };
+   auto make_bool = [](bool* a, int i) {
+     return std::make_tuple(a, MTL::DataType::DataTypeBool, i);
+   };
+
+   std::vector<MTLFC> func_consts = {
+       make_bool(&inverse, 0), make_bool(&power_of_2, 1)};
+
+   // Start of radix/rader step constants
+   int index = 4;
+   for (int i = 0; i < plan.stockham.size(); i++) {
+     func_consts.push_back(make_int(&plan.stockham[i], index));
+     index += 1;
+   }
+   for (int i = 0; i < plan.rader.size(); i++) {
+     func_consts.push_back(make_int(&plan.rader[i], index));
+     index += 1;
+   }
+   int elems_per_thread = compute_elems_per_thread(plan);
+   func_consts.push_back(make_int(&elems_per_thread, 2));
+
+   int rader_m = n / plan.rader_n;
+   func_consts.push_back(make_int(&rader_m, 3));
+
+   // The overall number of FFTs we're going to compute for this input
+   size_t size = out.dtype() == float32 ? out.size() : in.size();
+   if (real && inverse && four_step_params.required) {
+     size = out.size();
+   }
+   int total_batch_size = size / n;
+   int threads_per_fft = (fft_size + elems_per_thread - 1) / elems_per_thread;
+
+   // We batch among threadgroups for improved efficiency when n is small
+   int threadgroup_batch_size = std::max(MIN_THREADGROUP_MEM_SIZE / fft_size, 1);
+   if (four_step_params.required) {
+     // Require a threadgroup batch size of at least 4 for four step FFT
+     // so we can coalesce the memory accesses.
+     threadgroup_batch_size =
+         std::max(threadgroup_batch_size, MIN_COALESCE_WIDTH);
+   }
+   int threadgroup_mem_size = next_power_of_2(threadgroup_batch_size * fft_size);
+   // FFTs up to 2^20 are currently supported
+   assert(threadgroup_mem_size <= MAX_STOCKHAM_FFT_SIZE);
+
+   // ceil divide
+   int batch_size =
+       (total_batch_size + threadgroup_batch_size - 1) / threadgroup_batch_size;
+
+   if (real && !four_step_params.required) {
+     // We can perform 2 RFFTs at once so the batch size is halved.
+     batch_size = (batch_size + 2 - 1) / 2;
+   }
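+   // [Illustration, not in the original source] for n = 64 over 4096 rows:
+   // plan_fft(64) is three radix-4 steps, so elems_per_thread = 4 and
+   // threads_per_fft = 16; threadgroup_batch_size = max(256 / 64, 1) = 4;
+   // threadgroup_mem_size = next_power_of_2(4 * 64) = 256; and batch_size =
+   // ceil(4096 / 4) = 1024 threadgroups (512 for a real transform, since two
+   // RFFTs share one kernel invocation).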
+   auto& compute_encoder = d.get_command_encoder(s.index);
+   auto in_type_str = in.dtype() == float32 ? "float" : "float2";
+   auto out_type_str = out.dtype() == float32 ? "float" : "float2";
+   // Only required by four step
+   int step = -1;
+   {
+     std::ostringstream kname;
+     std::string inv_string = inverse ? "true" : "false";
+     std::string real_string = real ? "true" : "false";
+     std::string func_name;
+     if (plan.bluestein_n > 0) {
+       kname << "bluestein_fft_mem_" << threadgroup_mem_size << "_"
+             << in_type_str << "_" << out_type_str;
+       func_name = "bluestein_fft";
+     } else if (plan.rader_n > 1) {
+       kname << "rader_fft_mem_" << threadgroup_mem_size << "_" << in_type_str
+             << "_" << out_type_str;
+       func_name = "rader_fft";
+     } else if (four_step_params.required) {
+       step = four_step_params.first_step ? 0 : 1;
+       kname << "four_step_mem_" << threadgroup_mem_size << "_" << in_type_str
+             << "_" << out_type_str << "_" << step << "_" << real_string;
+       func_name = "four_step_fft";
+     } else {
+       kname << "fft_mem_" << threadgroup_mem_size << "_" << in_type_str << "_"
+             << out_type_str;
+       func_name = "fft";
+     }
+     std::string base_name = kname.str();
+     // We use a specialized kernel for each FFT size
+     kname << "_n" << fft_size << "_inv_" << inverse;
+     std::string hash_name = kname.str();
+     auto template_def = func_name == "four_step_fft" ? get_template_definition(
+                                                            base_name,
+                                                            func_name,
+                                                            threadgroup_mem_size,
+                                                            in_type_str,
+                                                            out_type_str,
+                                                            step,
+                                                            real)
+                                                      : get_template_definition(
+                                                            base_name,
+                                                            func_name,
+                                                            threadgroup_mem_size,
+                                                            in_type_str,
+                                                            out_type_str);
+     auto kernel =
+         get_fft_kernel(d, base_name, hash_name, func_consts, template_def);
+
+     compute_encoder.set_compute_pipeline_state(kernel);
+     compute_encoder.set_input_array(in_contiguous, 0);
+     compute_encoder.set_output_array(out, 1);
+
+     if (plan.bluestein_n > 0) {
+       // Precomputed twiddle factors for Bluestein's
+       auto [w_k, w_q] = compute_bluestein_constants(n, plan.bluestein_n);
+       copies.push_back(w_q);
+       copies.push_back(w_k);
+
+       compute_encoder.set_input_array(w_q, 2); // w_q
+       compute_encoder.set_input_array(w_k, 3); // w_k
+       compute_encoder.set_bytes(n, 4);
+       compute_encoder.set_bytes(plan.bluestein_n, 5);
+       compute_encoder.set_bytes(total_batch_size, 6);
+     } else if (plan.rader_n > 1) {
+       auto [b_q, g_q, g_minus_q] = compute_raders_constants(plan.rader_n, s);
+       copies.push_back(b_q);
+       copies.push_back(g_q);
+       copies.push_back(g_minus_q);
+
+       compute_encoder.set_input_array(b_q, 2);
+       compute_encoder.set_input_array(g_q, 3);
+       compute_encoder.set_input_array(g_minus_q, 4);
+       compute_encoder.set_bytes(n, 5);
+       compute_encoder.set_bytes(total_batch_size, 6);
+       compute_encoder.set_bytes(plan.rader_n, 7);
+     } else if (four_step_params.required) {
+       compute_encoder.set_bytes(four_step_params.n1, 2);
+       compute_encoder.set_bytes(four_step_params.n2, 3);
+       compute_encoder.set_bytes(total_batch_size, 4);
+     } else {
+       compute_encoder.set_bytes(n, 2);
+       compute_encoder.set_bytes(total_batch_size, 3);
+     }
+
+     auto group_dims = MTL::Size(1, threadgroup_batch_size, threads_per_fft);
+     auto grid_dims =
+         MTL::Size(batch_size, threadgroup_batch_size, threads_per_fft);
+     compute_encoder.dispatch_threads(grid_dims, group_dims);
+   }
+
+   d.add_temporaries(std::move(copies), s.index);
+ }
+
+ void fft_op(
+     const array& in,
+     array& out,
+     size_t axis,
+     bool inverse,
+     bool real,
+     bool inplace,
+     const Stream& s) {
+   fft_op(in, out, axis, inverse, real, FourStepParams(), inplace, s);
+ }
+
+ void nd_fft_op(
+     const array& in,
+     array& out,
+     const std::vector<size_t>& axes,
+     bool inverse,
+     bool real,
+     const Stream& s) {
+   // Perform ND FFT on GPU as a series of 1D FFTs
+   auto temp_shape = inverse ? in.shape() : out.shape();
+   std::vector<array> temp_arrs;
+   temp_arrs.emplace_back(temp_shape, complex64, nullptr, std::vector<array>{});
+   if (axes.size() > 2) {
+     temp_arrs.emplace_back(
+         temp_shape, complex64, nullptr, std::vector<array>{});
+   }
+   for (int i = axes.size() - 1; i >= 0; i--) {
+     int reverse_index = axes.size() - i - 1;
+     // For 5D and above, we don't want to reallocate our two temporary arrays
+     bool inplace = reverse_index >= 3 && i != 0;
+     // Opposite order for fft vs ifft
+     int index = inverse ? reverse_index : i;
+     size_t axis = axes[index];
+     // Mirror np.fft.(i)rfftn and perform a real transform
+     // only on the final axis.
+     bool step_real = (real && index == axes.size() - 1);
+     const array& in_arr = i == axes.size() - 1 ? in : temp_arrs[i % 2];
+     array& out_arr = i == 0 ? out : temp_arrs[1 - i % 2];
+     fft_op(in_arr, out_arr, axis, inverse, step_real, inplace, s);
+   }
+
+   auto& d = metal::device(s.device);
+   d.add_temporaries(std::move(temp_arrs), s.index);
+ }
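+ // [Illustration, not in the original source] for a forward 3D FFT over axes
+ // {0, 1, 2} the loop runs i = 2, 1, 0 and chains
+ //   in -> temp_arrs[1] (axis 2, the real step for rfftn)
+ //   temp_arrs[1] -> temp_arrs[0] (axis 1)
+ //   temp_arrs[0] -> out (axis 0)
+ // ping-ponging between the two temporaries so no further buffers are needed.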
+
+ void FFT::eval_gpu(const std::vector<array>& inputs, array& out) {
+   auto& s = stream();
+   auto& in = inputs[0];
+
+   if (axes_.size() > 1) {
+     nd_fft_op(in, out, axes_, inverse_, real_, s);
+   } else {
+     fft_op(in, out, axes_[0], inverse_, real_, /*inplace=*/false, s);
+   }
+ }
+
+ } // namespace mlx::core