mlx 0.30.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (599)
  1. checksums.yaml +7 -0
  2. data/ext/mlx/extconf.rb +94 -0
  3. data/ext/mlx/native.cpp +8027 -0
  4. data/lib/mlx/core.rb +1678 -0
  5. data/lib/mlx/distributed_utils/common.rb +116 -0
  6. data/lib/mlx/distributed_utils/config.rb +600 -0
  7. data/lib/mlx/distributed_utils/launch.rb +490 -0
  8. data/lib/mlx/extension.rb +24 -0
  9. data/lib/mlx/nn/base.rb +388 -0
  10. data/lib/mlx/nn/init.rb +140 -0
  11. data/lib/mlx/nn/layers/activations.rb +336 -0
  12. data/lib/mlx/nn/layers/base.rb +6 -0
  13. data/lib/mlx/nn/layers/containers.rb +20 -0
  14. data/lib/mlx/nn/layers/convolution.rb +120 -0
  15. data/lib/mlx/nn/layers/convolution_transpose.rb +114 -0
  16. data/lib/mlx/nn/layers/distributed.rb +309 -0
  17. data/lib/mlx/nn/layers/dropout.rb +75 -0
  18. data/lib/mlx/nn/layers/embedding.rb +28 -0
  19. data/lib/mlx/nn/layers/linear.rb +79 -0
  20. data/lib/mlx/nn/layers/normalization.rb +216 -0
  21. data/lib/mlx/nn/layers/pooling.rb +167 -0
  22. data/lib/mlx/nn/layers/positional_encoding.rb +126 -0
  23. data/lib/mlx/nn/layers/quantized.rb +215 -0
  24. data/lib/mlx/nn/layers/recurrent.rb +135 -0
  25. data/lib/mlx/nn/layers/transformer.rb +330 -0
  26. data/lib/mlx/nn/layers/upsample.rb +97 -0
  27. data/lib/mlx/nn/layers.rb +18 -0
  28. data/lib/mlx/nn/losses.rb +251 -0
  29. data/lib/mlx/nn/utils.rb +167 -0
  30. data/lib/mlx/nn.rb +12 -0
  31. data/lib/mlx/optimizers/optimizers.rb +808 -0
  32. data/lib/mlx/optimizers/schedulers.rb +62 -0
  33. data/lib/mlx/optimizers.rb +9 -0
  34. data/lib/mlx/utils.rb +171 -0
  35. data/lib/mlx/version.rb +5 -0
  36. data/lib/mlx.rb +64 -0
  37. data/mlx/CMakeLists.txt +449 -0
  38. data/mlx/cmake/FindCUDNN.cmake +177 -0
  39. data/mlx/cmake/FindNCCL.cmake +54 -0
  40. data/mlx/cmake/Findnvpl.cmake +3 -0
  41. data/mlx/cmake/extension.cmake +50 -0
  42. data/mlx/mlx/3rdparty/.clang-format +2 -0
  43. data/mlx/mlx/3rdparty/pocketfft.h +3581 -0
  44. data/mlx/mlx/CMakeLists.txt +107 -0
  45. data/mlx/mlx/allocator.h +75 -0
  46. data/mlx/mlx/api.h +29 -0
  47. data/mlx/mlx/array.cpp +354 -0
  48. data/mlx/mlx/array.h +647 -0
  49. data/mlx/mlx/backend/common/CMakeLists.txt +9 -0
  50. data/mlx/mlx/backend/common/binary.h +97 -0
  51. data/mlx/mlx/backend/common/broadcasting.cpp +24 -0
  52. data/mlx/mlx/backend/common/broadcasting.h +11 -0
  53. data/mlx/mlx/backend/common/buffer_cache.h +158 -0
  54. data/mlx/mlx/backend/common/common.cpp +305 -0
  55. data/mlx/mlx/backend/common/compiled.cpp +243 -0
  56. data/mlx/mlx/backend/common/compiled.h +77 -0
  57. data/mlx/mlx/backend/common/copy.h +50 -0
  58. data/mlx/mlx/backend/common/hadamard.h +109 -0
  59. data/mlx/mlx/backend/common/load.cpp +57 -0
  60. data/mlx/mlx/backend/common/matmul.h +67 -0
  61. data/mlx/mlx/backend/common/reduce.cpp +154 -0
  62. data/mlx/mlx/backend/common/reduce.h +59 -0
  63. data/mlx/mlx/backend/common/slicing.cpp +71 -0
  64. data/mlx/mlx/backend/common/slicing.h +20 -0
  65. data/mlx/mlx/backend/common/ternary.h +85 -0
  66. data/mlx/mlx/backend/common/unary.h +29 -0
  67. data/mlx/mlx/backend/common/utils.cpp +231 -0
  68. data/mlx/mlx/backend/common/utils.h +205 -0
  69. data/mlx/mlx/backend/cpu/CMakeLists.txt +88 -0
  70. data/mlx/mlx/backend/cpu/arange.h +28 -0
  71. data/mlx/mlx/backend/cpu/arg_reduce.cpp +124 -0
  72. data/mlx/mlx/backend/cpu/binary.cpp +269 -0
  73. data/mlx/mlx/backend/cpu/binary.h +517 -0
  74. data/mlx/mlx/backend/cpu/binary_ops.h +98 -0
  75. data/mlx/mlx/backend/cpu/binary_two.h +166 -0
  76. data/mlx/mlx/backend/cpu/cholesky.cpp +85 -0
  77. data/mlx/mlx/backend/cpu/compiled.cpp +357 -0
  78. data/mlx/mlx/backend/cpu/compiled_preamble.h +12 -0
  79. data/mlx/mlx/backend/cpu/conv.cpp +1351 -0
  80. data/mlx/mlx/backend/cpu/copy.cpp +386 -0
  81. data/mlx/mlx/backend/cpu/copy.h +36 -0
  82. data/mlx/mlx/backend/cpu/device_info.cpp +113 -0
  83. data/mlx/mlx/backend/cpu/device_info.h +28 -0
  84. data/mlx/mlx/backend/cpu/distributed.cpp +103 -0
  85. data/mlx/mlx/backend/cpu/eig.cpp +281 -0
  86. data/mlx/mlx/backend/cpu/eigh.cpp +241 -0
  87. data/mlx/mlx/backend/cpu/encoder.cpp +16 -0
  88. data/mlx/mlx/backend/cpu/encoder.h +67 -0
  89. data/mlx/mlx/backend/cpu/eval.cpp +40 -0
  90. data/mlx/mlx/backend/cpu/eval.h +12 -0
  91. data/mlx/mlx/backend/cpu/fft.cpp +120 -0
  92. data/mlx/mlx/backend/cpu/gemm.h +26 -0
  93. data/mlx/mlx/backend/cpu/gemms/bnns.cpp +214 -0
  94. data/mlx/mlx/backend/cpu/gemms/cblas.cpp +134 -0
  95. data/mlx/mlx/backend/cpu/gemms/simd_bf16.cpp +45 -0
  96. data/mlx/mlx/backend/cpu/gemms/simd_fp16.cpp +45 -0
  97. data/mlx/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
  98. data/mlx/mlx/backend/cpu/hadamard.cpp +121 -0
  99. data/mlx/mlx/backend/cpu/indexing.cpp +854 -0
  100. data/mlx/mlx/backend/cpu/inverse.cpp +160 -0
  101. data/mlx/mlx/backend/cpu/jit_compiler.cpp +166 -0
  102. data/mlx/mlx/backend/cpu/jit_compiler.h +20 -0
  103. data/mlx/mlx/backend/cpu/lapack.h +80 -0
  104. data/mlx/mlx/backend/cpu/logsumexp.cpp +139 -0
  105. data/mlx/mlx/backend/cpu/luf.cpp +120 -0
  106. data/mlx/mlx/backend/cpu/make_compiled_preamble.ps1 +38 -0
  107. data/mlx/mlx/backend/cpu/make_compiled_preamble.sh +41 -0
  108. data/mlx/mlx/backend/cpu/masked_mm.cpp +608 -0
  109. data/mlx/mlx/backend/cpu/matmul.cpp +166 -0
  110. data/mlx/mlx/backend/cpu/primitives.cpp +478 -0
  111. data/mlx/mlx/backend/cpu/qrf.cpp +147 -0
  112. data/mlx/mlx/backend/cpu/quantized.cpp +1370 -0
  113. data/mlx/mlx/backend/cpu/reduce.cpp +587 -0
  114. data/mlx/mlx/backend/cpu/scan.cpp +338 -0
  115. data/mlx/mlx/backend/cpu/select.cpp +95 -0
  116. data/mlx/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
  117. data/mlx/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
  118. data/mlx/mlx/backend/cpu/simd/base_simd.h +319 -0
  119. data/mlx/mlx/backend/cpu/simd/math.h +193 -0
  120. data/mlx/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
  121. data/mlx/mlx/backend/cpu/simd/simd.h +4 -0
  122. data/mlx/mlx/backend/cpu/simd/type.h +11 -0
  123. data/mlx/mlx/backend/cpu/slicing.h +21 -0
  124. data/mlx/mlx/backend/cpu/softmax.cpp +170 -0
  125. data/mlx/mlx/backend/cpu/sort.cpp +481 -0
  126. data/mlx/mlx/backend/cpu/svd.cpp +289 -0
  127. data/mlx/mlx/backend/cpu/ternary.h +154 -0
  128. data/mlx/mlx/backend/cpu/threefry.cpp +31 -0
  129. data/mlx/mlx/backend/cpu/threefry.h +21 -0
  130. data/mlx/mlx/backend/cpu/unary.cpp +238 -0
  131. data/mlx/mlx/backend/cpu/unary.h +281 -0
  132. data/mlx/mlx/backend/cpu/unary_ops.h +175 -0
  133. data/mlx/mlx/backend/cuda/CMakeLists.txt +265 -0
  134. data/mlx/mlx/backend/cuda/allocator.cpp +451 -0
  135. data/mlx/mlx/backend/cuda/allocator.h +94 -0
  136. data/mlx/mlx/backend/cuda/arange.cu +68 -0
  137. data/mlx/mlx/backend/cuda/arg_reduce.cu +189 -0
  138. data/mlx/mlx/backend/cuda/bin2h.cmake +150 -0
  139. data/mlx/mlx/backend/cuda/binary/CMakeLists.txt +21 -0
  140. data/mlx/mlx/backend/cuda/binary/add.cu +7 -0
  141. data/mlx/mlx/backend/cuda/binary/arctan2.cu +7 -0
  142. data/mlx/mlx/backend/cuda/binary/binary.cuh +383 -0
  143. data/mlx/mlx/backend/cuda/binary/bitwise_binary.cu +27 -0
  144. data/mlx/mlx/backend/cuda/binary/divide.cu +7 -0
  145. data/mlx/mlx/backend/cuda/binary/equal.cu +15 -0
  146. data/mlx/mlx/backend/cuda/binary/greater.cu +7 -0
  147. data/mlx/mlx/backend/cuda/binary/greater_equal.cu +7 -0
  148. data/mlx/mlx/backend/cuda/binary/less.cu +7 -0
  149. data/mlx/mlx/backend/cuda/binary/less_equal.cu +7 -0
  150. data/mlx/mlx/backend/cuda/binary/log_add_exp.cu +7 -0
  151. data/mlx/mlx/backend/cuda/binary/logical_and.cu +7 -0
  152. data/mlx/mlx/backend/cuda/binary/logical_or.cu +7 -0
  153. data/mlx/mlx/backend/cuda/binary/maximum.cu +7 -0
  154. data/mlx/mlx/backend/cuda/binary/minimum.cu +7 -0
  155. data/mlx/mlx/backend/cuda/binary/multiply.cu +7 -0
  156. data/mlx/mlx/backend/cuda/binary/not_equal.cu +7 -0
  157. data/mlx/mlx/backend/cuda/binary/power.cu +7 -0
  158. data/mlx/mlx/backend/cuda/binary/remainder.cu +7 -0
  159. data/mlx/mlx/backend/cuda/binary/subtract.cu +7 -0
  160. data/mlx/mlx/backend/cuda/binary_two.cu +412 -0
  161. data/mlx/mlx/backend/cuda/compiled.cpp +357 -0
  162. data/mlx/mlx/backend/cuda/conv/conv.h +126 -0
  163. data/mlx/mlx/backend/cuda/conv/gemm_conv.cu +217 -0
  164. data/mlx/mlx/backend/cuda/conv/gemm_grouped_conv.cu +231 -0
  165. data/mlx/mlx/backend/cuda/conv.cpp +403 -0
  166. data/mlx/mlx/backend/cuda/copy/copy.cuh +55 -0
  167. data/mlx/mlx/backend/cuda/copy/copy_contiguous.cu +88 -0
  168. data/mlx/mlx/backend/cuda/copy/copy_general.cu +171 -0
  169. data/mlx/mlx/backend/cuda/copy/copy_general_dynamic.cu +118 -0
  170. data/mlx/mlx/backend/cuda/copy/copy_general_input.cu +229 -0
  171. data/mlx/mlx/backend/cuda/copy.cu +132 -0
  172. data/mlx/mlx/backend/cuda/cublas_utils.cpp +222 -0
  173. data/mlx/mlx/backend/cuda/cublas_utils.h +95 -0
  174. data/mlx/mlx/backend/cuda/cuda.h +21 -0
  175. data/mlx/mlx/backend/cuda/cuda_utils.h +90 -0
  176. data/mlx/mlx/backend/cuda/cudnn_utils.cpp +133 -0
  177. data/mlx/mlx/backend/cuda/cudnn_utils.h +187 -0
  178. data/mlx/mlx/backend/cuda/custom_kernel.cpp +379 -0
  179. data/mlx/mlx/backend/cuda/cutlass_utils.cuh +46 -0
  180. data/mlx/mlx/backend/cuda/delayload.cpp +80 -0
  181. data/mlx/mlx/backend/cuda/device/atomic_ops.cuh +63 -0
  182. data/mlx/mlx/backend/cuda/device/binary_ops.cuh +300 -0
  183. data/mlx/mlx/backend/cuda/device/cast_op.cuh +118 -0
  184. data/mlx/mlx/backend/cuda/device/complex.cuh +60 -0
  185. data/mlx/mlx/backend/cuda/device/config.h +12 -0
  186. data/mlx/mlx/backend/cuda/device/fp16_math.cuh +96 -0
  187. data/mlx/mlx/backend/cuda/device/gather.cuh +53 -0
  188. data/mlx/mlx/backend/cuda/device/gather_axis.cuh +65 -0
  189. data/mlx/mlx/backend/cuda/device/indexing.cuh +30 -0
  190. data/mlx/mlx/backend/cuda/device/scatter.cuh +68 -0
  191. data/mlx/mlx/backend/cuda/device/scatter_axis.cuh +67 -0
  192. data/mlx/mlx/backend/cuda/device/scatter_ops.cuh +44 -0
  193. data/mlx/mlx/backend/cuda/device/ternary_ops.cuh +13 -0
  194. data/mlx/mlx/backend/cuda/device/unary_ops.cuh +350 -0
  195. data/mlx/mlx/backend/cuda/device/utils.cuh +464 -0
  196. data/mlx/mlx/backend/cuda/device.cpp +522 -0
  197. data/mlx/mlx/backend/cuda/device.h +195 -0
  198. data/mlx/mlx/backend/cuda/device_info.cpp +232 -0
  199. data/mlx/mlx/backend/cuda/distributed.cu +121 -0
  200. data/mlx/mlx/backend/cuda/eval.cpp +66 -0
  201. data/mlx/mlx/backend/cuda/event.cu +415 -0
  202. data/mlx/mlx/backend/cuda/event.h +79 -0
  203. data/mlx/mlx/backend/cuda/fence.cpp +42 -0
  204. data/mlx/mlx/backend/cuda/gemms/cublas_gemm.cpp +233 -0
  205. data/mlx/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
  206. data/mlx/mlx/backend/cuda/gemms/cublas_gemm_batched_12_0.cpp +77 -0
  207. data/mlx/mlx/backend/cuda/gemms/cublas_gemm_batched_12_9.cu +329 -0
  208. data/mlx/mlx/backend/cuda/gemms/gemv.cu +327 -0
  209. data/mlx/mlx/backend/cuda/gemms/gemv.h +34 -0
  210. data/mlx/mlx/backend/cuda/gemms/grouped_gemm.h +25 -0
  211. data/mlx/mlx/backend/cuda/gemms/grouped_gemm_unaligned.cu +358 -0
  212. data/mlx/mlx/backend/cuda/indexing.cpp +434 -0
  213. data/mlx/mlx/backend/cuda/jit_module.cpp +443 -0
  214. data/mlx/mlx/backend/cuda/jit_module.h +120 -0
  215. data/mlx/mlx/backend/cuda/kernel_utils.cu +52 -0
  216. data/mlx/mlx/backend/cuda/kernel_utils.cuh +148 -0
  217. data/mlx/mlx/backend/cuda/layer_norm.cu +417 -0
  218. data/mlx/mlx/backend/cuda/load.cpp +60 -0
  219. data/mlx/mlx/backend/cuda/logsumexp.cu +161 -0
  220. data/mlx/mlx/backend/cuda/lru_cache.h +190 -0
  221. data/mlx/mlx/backend/cuda/matmul.cpp +373 -0
  222. data/mlx/mlx/backend/cuda/no_cuda.cpp +47 -0
  223. data/mlx/mlx/backend/cuda/primitives.cpp +46 -0
  224. data/mlx/mlx/backend/cuda/quantized/affine_quantize.cu +329 -0
  225. data/mlx/mlx/backend/cuda/quantized/convert_fp8.cu +19 -0
  226. data/mlx/mlx/backend/cuda/quantized/cublas_qqmm.cpp +206 -0
  227. data/mlx/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
  228. data/mlx/mlx/backend/cuda/quantized/cuda_fp4.h +100 -0
  229. data/mlx/mlx/backend/cuda/quantized/fp_quantize.cu +496 -0
  230. data/mlx/mlx/backend/cuda/quantized/mxfp8_quantize.cuh +32 -0
  231. data/mlx/mlx/backend/cuda/quantized/no_qqmm_impl.cpp +26 -0
  232. data/mlx/mlx/backend/cuda/quantized/nvfp4_quantize.cuh +334 -0
  233. data/mlx/mlx/backend/cuda/quantized/qmv.cu +304 -0
  234. data/mlx/mlx/backend/cuda/quantized/qmv.h +21 -0
  235. data/mlx/mlx/backend/cuda/quantized/qqmm.cpp +158 -0
  236. data/mlx/mlx/backend/cuda/quantized/qqmm_impl.cpp +50 -0
  237. data/mlx/mlx/backend/cuda/quantized/qqmm_impl.h +26 -0
  238. data/mlx/mlx/backend/cuda/quantized/qqmm_utils.cu +227 -0
  239. data/mlx/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
  240. data/mlx/mlx/backend/cuda/quantized/quantized.cpp +85 -0
  241. data/mlx/mlx/backend/cuda/quantized/quantized.h +53 -0
  242. data/mlx/mlx/backend/cuda/quantized/quantized_utils.cuh +88 -0
  243. data/mlx/mlx/backend/cuda/quantized/quantized_utils.h +50 -0
  244. data/mlx/mlx/backend/cuda/random.cu +202 -0
  245. data/mlx/mlx/backend/cuda/reduce/all_reduce.cu +159 -0
  246. data/mlx/mlx/backend/cuda/reduce/col_reduce.cu +510 -0
  247. data/mlx/mlx/backend/cuda/reduce/init_reduce.cu +50 -0
  248. data/mlx/mlx/backend/cuda/reduce/reduce.cuh +71 -0
  249. data/mlx/mlx/backend/cuda/reduce/reduce_ops.cuh +211 -0
  250. data/mlx/mlx/backend/cuda/reduce/reduce_utils.cuh +145 -0
  251. data/mlx/mlx/backend/cuda/reduce/row_reduce.cu +361 -0
  252. data/mlx/mlx/backend/cuda/reduce.cu +73 -0
  253. data/mlx/mlx/backend/cuda/rms_norm.cu +536 -0
  254. data/mlx/mlx/backend/cuda/rope.cu +429 -0
  255. data/mlx/mlx/backend/cuda/scaled_dot_product_attention.cpp +681 -0
  256. data/mlx/mlx/backend/cuda/scaled_dot_product_attention.cu +796 -0
  257. data/mlx/mlx/backend/cuda/scan.cu +468 -0
  258. data/mlx/mlx/backend/cuda/slicing.cpp +111 -0
  259. data/mlx/mlx/backend/cuda/softmax.cu +162 -0
  260. data/mlx/mlx/backend/cuda/sort.cu +1076 -0
  261. data/mlx/mlx/backend/cuda/steel/defines.cuh +9 -0
  262. data/mlx/mlx/backend/cuda/steel/gemm.cuh +101 -0
  263. data/mlx/mlx/backend/cuda/steel/mma.cuh +117 -0
  264. data/mlx/mlx/backend/cuda/steel/tiles.cuh +450 -0
  265. data/mlx/mlx/backend/cuda/steel/utils.cuh +89 -0
  266. data/mlx/mlx/backend/cuda/ternary.cu +271 -0
  267. data/mlx/mlx/backend/cuda/unary/CMakeLists.txt +34 -0
  268. data/mlx/mlx/backend/cuda/unary/abs.cu +7 -0
  269. data/mlx/mlx/backend/cuda/unary/arccos.cu +7 -0
  270. data/mlx/mlx/backend/cuda/unary/arccosh.cu +7 -0
  271. data/mlx/mlx/backend/cuda/unary/arcsin.cu +7 -0
  272. data/mlx/mlx/backend/cuda/unary/arcsinh.cu +7 -0
  273. data/mlx/mlx/backend/cuda/unary/arctan.cu +7 -0
  274. data/mlx/mlx/backend/cuda/unary/arctanh.cu +7 -0
  275. data/mlx/mlx/backend/cuda/unary/bitwise_invert.cu +7 -0
  276. data/mlx/mlx/backend/cuda/unary/ceil.cu +7 -0
  277. data/mlx/mlx/backend/cuda/unary/conjugate.cu +7 -0
  278. data/mlx/mlx/backend/cuda/unary/cos.cu +7 -0
  279. data/mlx/mlx/backend/cuda/unary/cosh.cu +7 -0
  280. data/mlx/mlx/backend/cuda/unary/erf.cu +7 -0
  281. data/mlx/mlx/backend/cuda/unary/erf_inv.cu +7 -0
  282. data/mlx/mlx/backend/cuda/unary/exp.cu +7 -0
  283. data/mlx/mlx/backend/cuda/unary/expm1.cu +7 -0
  284. data/mlx/mlx/backend/cuda/unary/floor.cu +7 -0
  285. data/mlx/mlx/backend/cuda/unary/imag.cu +7 -0
  286. data/mlx/mlx/backend/cuda/unary/log.cu +21 -0
  287. data/mlx/mlx/backend/cuda/unary/log1p.cu +7 -0
  288. data/mlx/mlx/backend/cuda/unary/logical_not.cu +7 -0
  289. data/mlx/mlx/backend/cuda/unary/negative.cu +7 -0
  290. data/mlx/mlx/backend/cuda/unary/real.cu +7 -0
  291. data/mlx/mlx/backend/cuda/unary/round.cu +18 -0
  292. data/mlx/mlx/backend/cuda/unary/sigmoid.cu +7 -0
  293. data/mlx/mlx/backend/cuda/unary/sign.cu +7 -0
  294. data/mlx/mlx/backend/cuda/unary/sin.cu +7 -0
  295. data/mlx/mlx/backend/cuda/unary/sinh.cu +7 -0
  296. data/mlx/mlx/backend/cuda/unary/sqrt.cu +15 -0
  297. data/mlx/mlx/backend/cuda/unary/square.cu +7 -0
  298. data/mlx/mlx/backend/cuda/unary/tan.cu +7 -0
  299. data/mlx/mlx/backend/cuda/unary/tanh.cu +7 -0
  300. data/mlx/mlx/backend/cuda/unary/unary.cuh +224 -0
  301. data/mlx/mlx/backend/cuda/utils.cpp +116 -0
  302. data/mlx/mlx/backend/cuda/utils.h +49 -0
  303. data/mlx/mlx/backend/cuda/vector_types.cuh +48 -0
  304. data/mlx/mlx/backend/cuda/worker.cpp +79 -0
  305. data/mlx/mlx/backend/cuda/worker.h +55 -0
  306. data/mlx/mlx/backend/gpu/CMakeLists.txt +5 -0
  307. data/mlx/mlx/backend/gpu/copy.cpp +89 -0
  308. data/mlx/mlx/backend/gpu/copy.h +57 -0
  309. data/mlx/mlx/backend/gpu/device_info.h +36 -0
  310. data/mlx/mlx/backend/gpu/eval.h +18 -0
  311. data/mlx/mlx/backend/gpu/primitives.cpp +307 -0
  312. data/mlx/mlx/backend/gpu/slicing.cpp +44 -0
  313. data/mlx/mlx/backend/gpu/slicing.h +36 -0
  314. data/mlx/mlx/backend/metal/CMakeLists.txt +144 -0
  315. data/mlx/mlx/backend/metal/allocator.cpp +279 -0
  316. data/mlx/mlx/backend/metal/allocator.h +79 -0
  317. data/mlx/mlx/backend/metal/binary.cpp +257 -0
  318. data/mlx/mlx/backend/metal/binary.h +33 -0
  319. data/mlx/mlx/backend/metal/compiled.cpp +471 -0
  320. data/mlx/mlx/backend/metal/conv.cpp +1118 -0
  321. data/mlx/mlx/backend/metal/copy.cpp +235 -0
  322. data/mlx/mlx/backend/metal/custom_kernel.cpp +430 -0
  323. data/mlx/mlx/backend/metal/device.cpp +816 -0
  324. data/mlx/mlx/backend/metal/device.h +289 -0
  325. data/mlx/mlx/backend/metal/device_info.cpp +58 -0
  326. data/mlx/mlx/backend/metal/distributed.cpp +38 -0
  327. data/mlx/mlx/backend/metal/eval.cpp +97 -0
  328. data/mlx/mlx/backend/metal/event.cpp +62 -0
  329. data/mlx/mlx/backend/metal/fence.cpp +162 -0
  330. data/mlx/mlx/backend/metal/fft.cpp +807 -0
  331. data/mlx/mlx/backend/metal/hadamard.cpp +198 -0
  332. data/mlx/mlx/backend/metal/indexing.cpp +727 -0
  333. data/mlx/mlx/backend/metal/jit/includes.h +58 -0
  334. data/mlx/mlx/backend/metal/jit/indexing.h +76 -0
  335. data/mlx/mlx/backend/metal/jit_kernels.cpp +1118 -0
  336. data/mlx/mlx/backend/metal/kernels/CMakeLists.txt +193 -0
  337. data/mlx/mlx/backend/metal/kernels/arange.h +9 -0
  338. data/mlx/mlx/backend/metal/kernels/arange.metal +20 -0
  339. data/mlx/mlx/backend/metal/kernels/arg_reduce.metal +182 -0
  340. data/mlx/mlx/backend/metal/kernels/atomic.h +345 -0
  341. data/mlx/mlx/backend/metal/kernels/bf16.h +16 -0
  342. data/mlx/mlx/backend/metal/kernels/bf16_math.h +380 -0
  343. data/mlx/mlx/backend/metal/kernels/binary.h +199 -0
  344. data/mlx/mlx/backend/metal/kernels/binary.metal +109 -0
  345. data/mlx/mlx/backend/metal/kernels/binary_ops.h +330 -0
  346. data/mlx/mlx/backend/metal/kernels/binary_two.h +244 -0
  347. data/mlx/mlx/backend/metal/kernels/binary_two.metal +54 -0
  348. data/mlx/mlx/backend/metal/kernels/cexpf.h +134 -0
  349. data/mlx/mlx/backend/metal/kernels/complex.h +173 -0
  350. data/mlx/mlx/backend/metal/kernels/conv.metal +701 -0
  351. data/mlx/mlx/backend/metal/kernels/copy.h +276 -0
  352. data/mlx/mlx/backend/metal/kernels/copy.metal +75 -0
  353. data/mlx/mlx/backend/metal/kernels/defines.h +24 -0
  354. data/mlx/mlx/backend/metal/kernels/erf.h +69 -0
  355. data/mlx/mlx/backend/metal/kernels/expm1f.h +90 -0
  356. data/mlx/mlx/backend/metal/kernels/fence.metal +52 -0
  357. data/mlx/mlx/backend/metal/kernels/fft/radix.h +328 -0
  358. data/mlx/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
  359. data/mlx/mlx/backend/metal/kernels/fft.h +486 -0
  360. data/mlx/mlx/backend/metal/kernels/fft.metal +67 -0
  361. data/mlx/mlx/backend/metal/kernels/fp4.h +48 -0
  362. data/mlx/mlx/backend/metal/kernels/fp8.h +80 -0
  363. data/mlx/mlx/backend/metal/kernels/fp_quantized.h +1850 -0
  364. data/mlx/mlx/backend/metal/kernels/fp_quantized.metal +153 -0
  365. data/mlx/mlx/backend/metal/kernels/fp_quantized_nax.h +1044 -0
  366. data/mlx/mlx/backend/metal/kernels/fp_quantized_nax.metal +79 -0
  367. data/mlx/mlx/backend/metal/kernels/gemv.metal +868 -0
  368. data/mlx/mlx/backend/metal/kernels/gemv_masked.h +827 -0
  369. data/mlx/mlx/backend/metal/kernels/gemv_masked.metal +76 -0
  370. data/mlx/mlx/backend/metal/kernels/hadamard.h +182 -0
  371. data/mlx/mlx/backend/metal/kernels/indexing/gather.h +51 -0
  372. data/mlx/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
  373. data/mlx/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
  374. data/mlx/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
  375. data/mlx/mlx/backend/metal/kernels/indexing/masked_scatter.h +41 -0
  376. data/mlx/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
  377. data/mlx/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
  378. data/mlx/mlx/backend/metal/kernels/layer_norm.metal +433 -0
  379. data/mlx/mlx/backend/metal/kernels/logging.h +26 -0
  380. data/mlx/mlx/backend/metal/kernels/logsumexp.h +140 -0
  381. data/mlx/mlx/backend/metal/kernels/logsumexp.metal +18 -0
  382. data/mlx/mlx/backend/metal/kernels/quantized.h +2508 -0
  383. data/mlx/mlx/backend/metal/kernels/quantized.metal +144 -0
  384. data/mlx/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
  385. data/mlx/mlx/backend/metal/kernels/quantized_nax.metal +106 -0
  386. data/mlx/mlx/backend/metal/kernels/quantized_utils.h +90 -0
  387. data/mlx/mlx/backend/metal/kernels/random.metal +103 -0
  388. data/mlx/mlx/backend/metal/kernels/reduce.h +5 -0
  389. data/mlx/mlx/backend/metal/kernels/reduce.metal +169 -0
  390. data/mlx/mlx/backend/metal/kernels/reduce_utils.h +6 -0
  391. data/mlx/mlx/backend/metal/kernels/reduction/ops.h +275 -0
  392. data/mlx/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
  393. data/mlx/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
  394. data/mlx/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
  395. data/mlx/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
  396. data/mlx/mlx/backend/metal/kernels/rms_norm.metal +391 -0
  397. data/mlx/mlx/backend/metal/kernels/rope.metal +229 -0
  398. data/mlx/mlx/backend/metal/kernels/scaled_dot_product_attention.metal +44 -0
  399. data/mlx/mlx/backend/metal/kernels/scan.h +514 -0
  400. data/mlx/mlx/backend/metal/kernels/scan.metal +109 -0
  401. data/mlx/mlx/backend/metal/kernels/sdpa_vector.h +394 -0
  402. data/mlx/mlx/backend/metal/kernels/softmax.h +190 -0
  403. data/mlx/mlx/backend/metal/kernels/softmax.metal +24 -0
  404. data/mlx/mlx/backend/metal/kernels/sort.h +719 -0
  405. data/mlx/mlx/backend/metal/kernels/sort.metal +80 -0
  406. data/mlx/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
  407. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +471 -0
  408. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal +27 -0
  409. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
  410. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.metal +28 -0
  411. data/mlx/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
  412. data/mlx/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
  413. data/mlx/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
  414. data/mlx/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
  415. data/mlx/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
  416. data/mlx/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
  417. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
  418. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.metal +56 -0
  419. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
  420. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.metal +47 -0
  421. data/mlx/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
  422. data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
  423. data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
  424. data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
  425. data/mlx/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
  426. data/mlx/mlx/backend/metal/kernels/steel/defines.h +7 -0
  427. data/mlx/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
  428. data/mlx/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +157 -0
  429. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
  430. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.metal +34 -0
  431. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +219 -0
  432. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.metal +30 -0
  433. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
  434. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.metal +59 -0
  435. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +143 -0
  436. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.metal +37 -0
  437. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
  438. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.metal +76 -0
  439. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
  440. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.metal +43 -0
  441. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
  442. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.metal +76 -0
  443. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.h +152 -0
  444. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.metal +30 -0
  445. data/mlx/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
  446. data/mlx/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
  447. data/mlx/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
  448. data/mlx/mlx/backend/metal/kernels/steel/gemm/params.h +65 -0
  449. data/mlx/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
  450. data/mlx/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
  451. data/mlx/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
  452. data/mlx/mlx/backend/metal/kernels/steel/utils.h +42 -0
  453. data/mlx/mlx/backend/metal/kernels/ternary.h +145 -0
  454. data/mlx/mlx/backend/metal/kernels/ternary.metal +48 -0
  455. data/mlx/mlx/backend/metal/kernels/ternary_ops.h +10 -0
  456. data/mlx/mlx/backend/metal/kernels/unary.h +63 -0
  457. data/mlx/mlx/backend/metal/kernels/unary.metal +115 -0
  458. data/mlx/mlx/backend/metal/kernels/unary_ops.h +454 -0
  459. data/mlx/mlx/backend/metal/kernels/utils.h +445 -0
  460. data/mlx/mlx/backend/metal/kernels.h +375 -0
  461. data/mlx/mlx/backend/metal/logsumexp.cpp +95 -0
  462. data/mlx/mlx/backend/metal/make_compiled_preamble.sh +120 -0
  463. data/mlx/mlx/backend/metal/matmul.cpp +2572 -0
  464. data/mlx/mlx/backend/metal/matmul.h +144 -0
  465. data/mlx/mlx/backend/metal/metal.cpp +50 -0
  466. data/mlx/mlx/backend/metal/metal.h +25 -0
  467. data/mlx/mlx/backend/metal/no_metal.cpp +42 -0
  468. data/mlx/mlx/backend/metal/nojit_kernels.cpp +414 -0
  469. data/mlx/mlx/backend/metal/normalization.cpp +433 -0
  470. data/mlx/mlx/backend/metal/primitives.cpp +242 -0
  471. data/mlx/mlx/backend/metal/quantized.cpp +1651 -0
  472. data/mlx/mlx/backend/metal/reduce.cpp +1038 -0
  473. data/mlx/mlx/backend/metal/reduce.h +41 -0
  474. data/mlx/mlx/backend/metal/resident.cpp +100 -0
  475. data/mlx/mlx/backend/metal/resident.h +32 -0
  476. data/mlx/mlx/backend/metal/rope.cpp +165 -0
  477. data/mlx/mlx/backend/metal/scaled_dot_product_attention.cpp +798 -0
  478. data/mlx/mlx/backend/metal/scan.cpp +145 -0
  479. data/mlx/mlx/backend/metal/scan.h +17 -0
  480. data/mlx/mlx/backend/metal/slicing.cpp +99 -0
  481. data/mlx/mlx/backend/metal/softmax.cpp +87 -0
  482. data/mlx/mlx/backend/metal/sort.cpp +368 -0
  483. data/mlx/mlx/backend/metal/ternary.cpp +160 -0
  484. data/mlx/mlx/backend/metal/ternary.h +21 -0
  485. data/mlx/mlx/backend/metal/unary.cpp +161 -0
  486. data/mlx/mlx/backend/metal/unary.h +21 -0
  487. data/mlx/mlx/backend/metal/utils.cpp +77 -0
  488. data/mlx/mlx/backend/metal/utils.h +99 -0
  489. data/mlx/mlx/backend/no_cpu/CMakeLists.txt +7 -0
  490. data/mlx/mlx/backend/no_cpu/compiled.cpp +24 -0
  491. data/mlx/mlx/backend/no_cpu/device_info.cpp +22 -0
  492. data/mlx/mlx/backend/no_cpu/primitives.cpp +146 -0
  493. data/mlx/mlx/backend/no_gpu/CMakeLists.txt +8 -0
  494. data/mlx/mlx/backend/no_gpu/allocator.cpp +134 -0
  495. data/mlx/mlx/backend/no_gpu/apple_memory.h +16 -0
  496. data/mlx/mlx/backend/no_gpu/device_info.cpp +22 -0
  497. data/mlx/mlx/backend/no_gpu/eval.cpp +24 -0
  498. data/mlx/mlx/backend/no_gpu/event.cpp +53 -0
  499. data/mlx/mlx/backend/no_gpu/fence.cpp +54 -0
  500. data/mlx/mlx/backend/no_gpu/linux_memory.h +22 -0
  501. data/mlx/mlx/backend/no_gpu/primitives.cpp +185 -0
  502. data/mlx/mlx/compile.cpp +1243 -0
  503. data/mlx/mlx/compile.h +45 -0
  504. data/mlx/mlx/compile_impl.h +70 -0
  505. data/mlx/mlx/device.cpp +72 -0
  506. data/mlx/mlx/device.h +56 -0
  507. data/mlx/mlx/distributed/CMakeLists.txt +14 -0
  508. data/mlx/mlx/distributed/distributed.cpp +197 -0
  509. data/mlx/mlx/distributed/distributed.h +61 -0
  510. data/mlx/mlx/distributed/distributed_impl.h +59 -0
  511. data/mlx/mlx/distributed/jaccl/CMakeLists.txt +12 -0
  512. data/mlx/mlx/distributed/jaccl/jaccl.cpp +178 -0
  513. data/mlx/mlx/distributed/jaccl/jaccl.h +12 -0
  514. data/mlx/mlx/distributed/jaccl/mesh.cpp +451 -0
  515. data/mlx/mlx/distributed/jaccl/mesh.h +122 -0
  516. data/mlx/mlx/distributed/jaccl/no_jaccl.cpp +20 -0
  517. data/mlx/mlx/distributed/jaccl/ring.cpp +692 -0
  518. data/mlx/mlx/distributed/jaccl/ring.h +178 -0
  519. data/mlx/mlx/distributed/jaccl/utils.cpp +329 -0
  520. data/mlx/mlx/distributed/jaccl/utils.h +342 -0
  521. data/mlx/mlx/distributed/mpi/CMakeLists.txt +5 -0
  522. data/mlx/mlx/distributed/mpi/mpi.cpp +501 -0
  523. data/mlx/mlx/distributed/mpi/mpi.h +12 -0
  524. data/mlx/mlx/distributed/mpi/mpi_declarations.h +28 -0
  525. data/mlx/mlx/distributed/mpi/no_mpi.cpp +20 -0
  526. data/mlx/mlx/distributed/nccl/CMakeLists.txt +26 -0
  527. data/mlx/mlx/distributed/nccl/nccl.cpp +443 -0
  528. data/mlx/mlx/distributed/nccl/nccl.h +12 -0
  529. data/mlx/mlx/distributed/nccl/nccl_stub/CMakeLists.txt +1 -0
  530. data/mlx/mlx/distributed/nccl/nccl_stub/nccl_stubs.cpp +54 -0
  531. data/mlx/mlx/distributed/nccl/no_nccl.cpp +20 -0
  532. data/mlx/mlx/distributed/ops.cpp +186 -0
  533. data/mlx/mlx/distributed/ops.h +57 -0
  534. data/mlx/mlx/distributed/primitives.cpp +95 -0
  535. data/mlx/mlx/distributed/primitives.h +156 -0
  536. data/mlx/mlx/distributed/reduction_ops.h +38 -0
  537. data/mlx/mlx/distributed/ring/CMakeLists.txt +5 -0
  538. data/mlx/mlx/distributed/ring/no_ring.cpp +20 -0
  539. data/mlx/mlx/distributed/ring/ring.cpp +870 -0
  540. data/mlx/mlx/distributed/ring/ring.h +12 -0
  541. data/mlx/mlx/distributed/utils.cpp +206 -0
  542. data/mlx/mlx/distributed/utils.h +67 -0
  543. data/mlx/mlx/dtype.cpp +197 -0
  544. data/mlx/mlx/dtype.h +116 -0
  545. data/mlx/mlx/dtype_utils.cpp +42 -0
  546. data/mlx/mlx/dtype_utils.h +119 -0
  547. data/mlx/mlx/einsum.cpp +941 -0
  548. data/mlx/mlx/einsum.h +23 -0
  549. data/mlx/mlx/event.h +58 -0
  550. data/mlx/mlx/export.cpp +1130 -0
  551. data/mlx/mlx/export.h +137 -0
  552. data/mlx/mlx/export_impl.h +99 -0
  553. data/mlx/mlx/fast.cpp +941 -0
  554. data/mlx/mlx/fast.h +103 -0
  555. data/mlx/mlx/fast_primitives.h +427 -0
  556. data/mlx/mlx/fence.h +39 -0
  557. data/mlx/mlx/fft.cpp +262 -0
  558. data/mlx/mlx/fft.h +159 -0
  559. data/mlx/mlx/graph_utils.cpp +175 -0
  560. data/mlx/mlx/graph_utils.h +67 -0
  561. data/mlx/mlx/io/CMakeLists.txt +25 -0
  562. data/mlx/mlx/io/gguf.cpp +470 -0
  563. data/mlx/mlx/io/gguf.h +20 -0
  564. data/mlx/mlx/io/gguf_quants.cpp +164 -0
  565. data/mlx/mlx/io/load.cpp +397 -0
  566. data/mlx/mlx/io/load.h +175 -0
  567. data/mlx/mlx/io/no_gguf.cpp +20 -0
  568. data/mlx/mlx/io/no_safetensors.cpp +37 -0
  569. data/mlx/mlx/io/safetensors.cpp +234 -0
  570. data/mlx/mlx/io.h +61 -0
  571. data/mlx/mlx/linalg.cpp +708 -0
  572. data/mlx/mlx/linalg.h +115 -0
  573. data/mlx/mlx/memory.h +80 -0
  574. data/mlx/mlx/mlx.h +25 -0
  575. data/mlx/mlx/ops.cpp +6094 -0
  576. data/mlx/mlx/ops.h +1610 -0
  577. data/mlx/mlx/primitives.cpp +5850 -0
  578. data/mlx/mlx/primitives.h +2525 -0
  579. data/mlx/mlx/random.cpp +492 -0
  580. data/mlx/mlx/random.h +283 -0
  581. data/mlx/mlx/scheduler.cpp +73 -0
  582. data/mlx/mlx/scheduler.h +189 -0
  583. data/mlx/mlx/small_vector.h +540 -0
  584. data/mlx/mlx/stream.h +42 -0
  585. data/mlx/mlx/threadpool.h +133 -0
  586. data/mlx/mlx/transforms.cpp +1065 -0
  587. data/mlx/mlx/transforms.h +231 -0
  588. data/mlx/mlx/transforms_impl.h +88 -0
  589. data/mlx/mlx/types/bf16.h +187 -0
  590. data/mlx/mlx/types/complex.h +113 -0
  591. data/mlx/mlx/types/fp16.h +234 -0
  592. data/mlx/mlx/types/half_types.h +58 -0
  593. data/mlx/mlx/types/limits.h +70 -0
  594. data/mlx/mlx/utils.cpp +302 -0
  595. data/mlx/mlx/utils.h +174 -0
  596. data/mlx/mlx/version.cpp +11 -0
  597. data/mlx/mlx/version.h +22 -0
  598. data/mlx/mlx.pc.in +52 -0
  599. metadata +643 -0
@@ -0,0 +1,449 @@
1
+ cmake_minimum_required(VERSION 3.25)
2
+
3
+ if(NOT MLX_VERSION) # No version supplied: derive it from the MLX_VERSION_* macros in mlx/version.h
4
+ file(STRINGS "mlx/version.h" _mlx_h_version REGEX "^#define MLX_VERSION_.*$")
5
+ string(REGEX MATCH "#define MLX_VERSION_MAJOR ([0-9]+)" _ "${_mlx_h_version}")
6
+ set(_major ${CMAKE_MATCH_1})
7
+ string(REGEX MATCH "#define MLX_VERSION_MINOR ([0-9]+)" _ "${_mlx_h_version}")
8
+ set(_minor ${CMAKE_MATCH_1})
9
+ string(REGEX MATCH "#define MLX_VERSION_PATCH ([0-9]+)" _ "${_mlx_h_version}")
10
+ set(_patch ${CMAKE_MATCH_1})
11
+ set(MLX_PROJECT_VERSION "${_major}.${_minor}.${_patch}")
12
+ set(MLX_VERSION ${MLX_PROJECT_VERSION})
13
+ else() # MLX_VERSION supplied: keep only its leading major.minor.patch for project(); "\\." reaches the regex engine as a literal dot
14
+ string(REGEX REPLACE "^([0-9]+\\.[0-9]+\\.[0-9]+).*" "\\1" MLX_PROJECT_VERSION
15
+ "${MLX_VERSION}")
16
+ endif()
17
+
18
+ project(
19
+ mlx
20
+ LANGUAGES C CXX
21
+ VERSION ${MLX_PROJECT_VERSION})
22
+
23
+ # ----------------------------- Setup -----------------------------
24
+ set(CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake")
25
+ set(CMAKE_CXX_STANDARD 20)
26
+ set(CMAKE_CXX_STANDARD_REQUIRED ON)
27
+ set(CMAKE_POSITION_INDEPENDENT_CODE ON)
28
+ set(CMAKE_INSTALL_MESSAGE NEVER)
29
+ set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
30
+
31
+ # ----------------------------- Configuration -----------------------------
32
+ option(MLX_BUILD_TESTS "Build tests for mlx" ON)
33
+ option(MLX_BUILD_EXAMPLES "Build examples for mlx" ON)
34
+ option(MLX_BUILD_BENCHMARKS "Build benchmarks for mlx" OFF)
35
+ option(MLX_BUILD_PYTHON_BINDINGS "Build python bindings for mlx" OFF)
36
+ option(MLX_BUILD_METAL "Build metal backend" ON)
37
+ option(MLX_BUILD_CPU "Build cpu backend" ON)
38
+ option(MLX_BUILD_CUDA "Build cuda backend" OFF)
39
+ option(MLX_METAL_DEBUG "Enhance metal debug workflow" OFF)
40
+ option(MLX_ENABLE_X64_MAC "Enable building for x64 macOS" OFF)
41
+ option(MLX_BUILD_GGUF "Include support for GGUF format" ON)
42
+ option(MLX_BUILD_SAFETENSORS "Include support for safetensors format" ON)
43
+ option(MLX_BUILD_PYTHON_STUBS "Build stub files for python bindings" ON)
44
+ option(MLX_METAL_JIT "Use JIT compilation for Metal kernels" OFF)
45
+ option(MLX_USE_CCACHE "Use CCache for compilation cache when available" ON)
46
+ option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
47
+ option(USE_SYSTEM_FMT "Use system's provided fmt library" OFF)
48
+ option(USE_ASAN "Enable AddressSanitizer (ASan)" OFF)
49
+ option(USE_UBSAN "Enable UndefinedBehaviorSanitizer (UBSan)" OFF)
50
+ option(USE_TSAN "Enable ThreadSanitizer (TSan)" OFF)
51
+
52
+ # --------------------- Processor tests -------------------------
53
+ message(
54
+ STATUS
55
+ "Building MLX for ${CMAKE_SYSTEM_PROCESSOR} processor on ${CMAKE_SYSTEM_NAME}"
56
+ )
57
+
58
+ if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
59
+ if(${CMAKE_SYSTEM_PROCESSOR} MATCHES "x86_64")
60
+ if(NOT MLX_ENABLE_X64_MAC)
61
+ message(
62
+ FATAL_ERROR
63
+ "Building for x86_64 on macOS is not supported."
64
+ " If you are on an Apple silicon system, check the build"
65
+ " documentation for possible fixes: "
66
+ "https://ml-explore.github.io/mlx/build/html/install.html#build-from-source"
67
+ )
68
+ else()
69
+ set(MLX_BUILD_METAL OFF)
70
+ message(WARNING "Building for x86_64 arch is not officially supported.")
71
+ endif()
72
+ endif()
73
+ else()
74
+ set(MLX_BUILD_METAL OFF)
75
+ endif()
76
+
77
+ if(MLX_USE_CCACHE)
78
+ find_program(CCACHE_PROGRAM ccache)
79
+ if(CCACHE_PROGRAM)
80
+ message(STATUS "Found CCache: ${CCACHE_PROGRAM}")
81
+ set(CMAKE_C_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")
82
+ set(CMAKE_CXX_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")
83
+ set(CMAKE_CUDA_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")
84
+ endif()
85
+ endif()
86
+
87
+ if(USE_ASAN AND USE_TSAN)
88
+ message(
89
+ FATAL_ERROR
90
+ "AddressSanitizer (ASan) and ThreadSanitizer (TSan) are mutually exclusive and cannot be enabled at the same time."
91
+ )
92
+ endif()
93
+
94
+ set(SANITIZER_COMPILE_FLAGS "")
95
+ set(SANITIZER_LINK_FLAGS "")
96
+
97
+ if(USE_ASAN)
98
+ if(WIN32 AND MSVC)
99
+ list(APPEND SANITIZER_COMPILE_FLAGS /fsanitize=address)
100
+ list(APPEND SANITIZER_LINK_FLAGS /fsanitize=address)
101
+ else()
102
+ list(APPEND SANITIZER_COMPILE_FLAGS -fsanitize=address)
103
+ list(APPEND SANITIZER_LINK_FLAGS -fsanitize=address)
104
+ if(CMAKE_SYSTEM_NAME STREQUAL "Linux")
105
+ list(APPEND SANITIZER_LINK_FLAGS -lpthread)
106
+ endif()
107
+ endif()
108
+ endif()
109
+
110
+ if(USE_UBSAN)
111
+ if(WIN32 AND MSVC)
112
+ if(CMAKE_CXX_COMPILER_ID STREQUAL "Clang")
113
+ list(APPEND SANITIZER_COMPILE_FLAGS -fsanitize=undefined)
114
+ list(APPEND SANITIZER_LINK_FLAGS -fsanitize=undefined)
115
+ else()
116
+ message(
117
+ WARNING
118
+ "UndefinedBehaviorSanitizer (UBSan) is not directly supported via a simple flag in MSVC."
119
+ )
120
+ endif()
121
+ else()
122
+ list(APPEND SANITIZER_COMPILE_FLAGS -fsanitize=undefined)
123
+ list(APPEND SANITIZER_LINK_FLAGS -fsanitize=undefined)
124
+ endif()
125
+ endif()
126
+
127
+ if(USE_TSAN)
128
+ if(WIN32 AND MSVC)
129
+ message(
130
+ FATAL_ERROR
131
+ "ThreadSanitizer (TSan) is not supported by the MSVC compiler. Please use Clang or GCC."
132
+ )
133
+ elseif(CMAKE_SYSTEM_NAME STREQUAL "Darwin")
134
+ message(FATAL_ERROR "ThreadSanitizer (TSan) is not supported on macOS.")
135
+ else()
136
+ list(APPEND SANITIZER_COMPILE_FLAGS -fsanitize=thread)
137
+ list(APPEND SANITIZER_LINK_FLAGS -fsanitize=thread)
138
+ if(CMAKE_SYSTEM_NAME STREQUAL "Linux")
139
+ list(APPEND SANITIZER_LINK_FLAGS -lpthread)
140
+ endif()
141
+ endif()
142
+ endif()
143
+
144
+ # ----------------------------- Lib -----------------------------
145
+
146
+ include(FetchContent)
147
+ # Avoid warning about DOWNLOAD_EXTRACT_TIMESTAMP in CMake 3.24:
148
+ cmake_policy(SET CMP0135 NEW)
149
+
150
+ add_library(mlx)
151
+
152
+ target_compile_options(mlx PUBLIC ${SANITIZER_COMPILE_FLAGS})
153
+ target_link_options(mlx PUBLIC ${SANITIZER_LINK_FLAGS})
154
+
155
+ if(MLX_BUILD_CUDA)
156
+ enable_language(CUDA)
157
+ find_package(CUDAToolkit REQUIRED)
158
+ find_package(CUDNN REQUIRED)
159
+ endif()
160
+
161
+ if(MLX_BUILD_METAL) # Metal backend: locate the Apple frameworks, validate SDK / deployment target, fetch metal-cpp
162
+ find_library(METAL_LIB Metal)
163
+ find_library(FOUNDATION_LIB Foundation)
164
+ find_library(QUARTZ_LIB QuartzCore)
165
+ if(METAL_LIB)
166
+ message(STATUS "Metal found ${METAL_LIB}")
167
+ else()
168
+ message(
169
+ FATAL_ERROR
170
+ "Metal not found. Set MLX_BUILD_METAL=OFF to build without GPU")
171
+ endif()
172
+
173
+ if(MLX_METAL_DEBUG)
174
+ add_compile_definitions(MLX_METAL_DEBUG)
175
+ endif()
176
+
177
+ # Throw an error if xcrun not found
178
+ execute_process(
179
+ COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version"
180
+ OUTPUT_VARIABLE MACOS_SDK_VERSION
181
+ OUTPUT_STRIP_TRAILING_WHITESPACE COMMAND_ERROR_IS_FATAL ANY)
182
+
183
+ if("${MACOS_SDK_VERSION}" VERSION_LESS 14.0) # component-wise compare, so values like 13.3.1 are handled too
184
+ message(
185
+ FATAL_ERROR
186
+ "MLX requires macOS SDK >= 14.0 to be built with MLX_BUILD_METAL=ON")
187
+ endif()
188
+ message(STATUS "Building with macOS SDK version ${MACOS_SDK_VERSION}")
189
+
190
+ set(METAL_CPP_URL
191
+ https://developer.apple.com/metal/cpp/files/metal-cpp_26.zip)
192
+
193
+ if(NOT CMAKE_OSX_DEPLOYMENT_TARGET STREQUAL "")
194
+ if("${CMAKE_OSX_DEPLOYMENT_TARGET}" VERSION_LESS 14.0) # numeric LESS would silently skip three-component targets
195
+ message(FATAL_ERROR "MLX requires macOS >= 14.0")
196
+ endif()
197
+ set(XCRUN_FLAGS "-mmacosx-version-min=${CMAKE_OSX_DEPLOYMENT_TARGET}")
198
+ endif()
199
+ execute_process(
200
+ COMMAND
201
+ zsh "-c"
202
+ "echo \"__METAL_VERSION__\" | xcrun -sdk macosx metal ${XCRUN_FLAGS} -E -x metal -P - | tail -1 | tr -d '\n'"
203
+ OUTPUT_VARIABLE MLX_METAL_VERSION COMMAND_ERROR_IS_FATAL ANY)
204
+ FetchContent_Declare(metal_cpp URL ${METAL_CPP_URL})
205
+ FetchContent_MakeAvailable(metal_cpp)
206
+ target_include_directories(
207
+ mlx PUBLIC $<BUILD_INTERFACE:${metal_cpp_SOURCE_DIR}>
208
+ $<INSTALL_INTERFACE:include/metal_cpp>)
209
+ target_link_libraries(mlx PUBLIC ${METAL_LIB} ${FOUNDATION_LIB} ${QUARTZ_LIB})
210
+ endif()
211
+
212
+ if(CMAKE_SYSTEM_NAME STREQUAL "Linux")
213
+ # With newer clang/gcc versions following libs are implicitly linked, but when
214
+ # building on old distributions they need to be explicitly listed.
215
+ target_link_libraries(mlx PRIVATE dl pthread)
216
+ endif()
217
+
218
+ if(WIN32)
219
+ if(MSVC)
220
+ # GGUF does not build with MSVC.
221
+ set(MLX_BUILD_GGUF OFF)
222
+ endif()
223
+ # Generate DLL and EXE in the same dir, otherwise EXE will not be able to run.
224
+ # This is only done when MLX is built as the top project.
225
+ if(CMAKE_CURRENT_SOURCE_DIR STREQUAL CMAKE_SOURCE_DIR)
226
+ set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR})
227
+ endif()
228
+ # Windows implementation of dlfcn.h APIs.
229
+ FetchContent_Declare(
230
+ dlfcn-win32
231
+ GIT_REPOSITORY https://github.com/dlfcn-win32/dlfcn-win32.git
232
+ GIT_TAG v1.4.2
233
+ EXCLUDE_FROM_ALL)
234
+ block()
235
+ set(BUILD_SHARED_LIBS OFF)
236
+ FetchContent_MakeAvailable(dlfcn-win32)
237
+ endblock()
238
+ target_include_directories(mlx PRIVATE "${dlfcn-win32_SOURCE_DIR}/src")
239
+ target_link_libraries(mlx PRIVATE dl)
240
+ endif()
241
+
242
+ if(MLX_BUILD_CPU)
243
+ find_library(ACCELERATE_LIBRARY Accelerate)
244
+ if(ACCELERATE_LIBRARY)
245
+ message(STATUS "Accelerate found ${ACCELERATE_LIBRARY}")
246
+ set(MLX_BUILD_ACCELERATE ON)
247
+ else()
248
+ message(STATUS "Accelerate not found, using default backend.")
249
+ set(MLX_BUILD_ACCELERATE OFF)
250
+ endif()
251
+
252
+ if(MLX_BUILD_ACCELERATE)
253
+ target_link_libraries(mlx PUBLIC ${ACCELERATE_LIBRARY})
254
+ add_compile_definitions(MLX_USE_ACCELERATE)
255
+ add_compile_definitions(ACCELERATE_NEW_LAPACK)
256
+ elseif(WIN32)
257
+ # Download and link prebuilt binaries of OpenBLAS. Note that we can only
258
+ # link with the dynamic library, the prebuilt binaries were built with MinGW
259
+ # so static-linking would require linking with MinGW's runtime.
260
+ FetchContent_Declare(
261
+ openblas
262
+ URL "https://github.com/OpenMathLib/OpenBLAS/releases/download/v0.3.31/OpenBLAS-0.3.31-x64.zip"
263
+ )
264
+ FetchContent_MakeAvailable(openblas)
265
+ target_link_libraries(mlx
266
+ PRIVATE "${openblas_SOURCE_DIR}/lib/libopenblas.lib")
267
+ target_include_directories(mlx PRIVATE "${openblas_SOURCE_DIR}/include")
268
+ # Make sure the DLL file is placed in the same dir with executables.
269
+ set(OPENBLAS_DLL_FILE "${openblas_SOURCE_DIR}/bin/libopenblas.dll")
270
+ add_custom_command(
271
+ TARGET mlx
272
+ POST_BUILD
273
+ COMMAND ${CMAKE_COMMAND} -E copy_if_different ${OPENBLAS_DLL_FILE}
274
+ ${CMAKE_BINARY_DIR})
275
+ else()
276
+ if(${CMAKE_HOST_APPLE})
277
+ # The blas shipped in macOS SDK is not supported, search homebrew for
278
+ # openblas instead.
279
+ set(BLA_VENDOR OpenBLAS)
280
+ set(LAPACK_ROOT
281
+ "${LAPACK_ROOT};$ENV{LAPACK_ROOT};/usr/local/opt/openblas")
282
+ endif()
283
+ # Search and link with lapack.
284
+ find_package(LAPACK REQUIRED)
285
+ if(NOT LAPACK_FOUND)
286
+ message(FATAL_ERROR "Must have LAPACK installed")
287
+ endif()
288
+ find_path(
289
+ LAPACK_INCLUDE_DIRS lapacke.h
290
+ /usr/include
291
+ /usr/include/lapacke
292
+ /usr/include/x86_64-linux-gnu
293
+ /usr/include/x86_64-linux-gnu/lapacke
294
+ /usr/local/include
295
+ /usr/local/include/lapacke
296
+ /usr/local/opt/openblas/include)
297
+ message(STATUS "Lapack lib " ${LAPACK_LIBRARIES})
298
+ message(STATUS "Lapack include " ${LAPACK_INCLUDE_DIRS})
299
+ target_include_directories(mlx PRIVATE ${LAPACK_INCLUDE_DIRS})
300
+ target_link_libraries(mlx PRIVATE ${LAPACK_LIBRARIES})
301
+ # List blas after lapack otherwise we may accidentally include an old
302
+ # version of lapack.h from the include dirs of blas.
303
+ find_package(BLAS REQUIRED)
304
+ if(NOT BLAS_FOUND)
305
+ message(FATAL_ERROR "Must have BLAS installed")
306
+ endif()
307
+ # TODO find a cleaner way to do this
308
+ find_path(
309
+ BLAS_INCLUDE_DIRS cblas.h
310
+ /usr/include
311
+ /usr/include/x86_64-linux-gnu
312
+ /usr/local/include
313
+ /usr/local/include/openblas
314
+ $ENV{BLAS_HOME}/include)
315
+ message(STATUS "Blas lib " ${BLAS_LIBRARIES})
316
+ message(STATUS "Blas include " ${BLAS_INCLUDE_DIRS})
317
+ target_include_directories(mlx PRIVATE ${BLAS_INCLUDE_DIRS})
318
+ target_link_libraries(mlx PRIVATE ${BLAS_LIBRARIES})
319
+ endif()
320
+ else()
321
+ set(MLX_BUILD_ACCELERATE OFF)
322
+ endif()
323
+
324
+ message(STATUS "Downloading json")
325
+ FetchContent_Declare(
326
+ json
327
+ URL https://github.com/nlohmann/json/releases/download/v3.11.3/json.tar.xz)
328
+ FetchContent_MakeAvailable(json)
329
+ target_include_directories(
330
+ mlx PRIVATE $<BUILD_INTERFACE:${json_SOURCE_DIR}/single_include/nlohmann>)
331
+
332
+ add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/mlx)
333
+
334
+ target_include_directories(
335
+ mlx PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_LIST_DIR}>
336
+ $<INSTALL_INTERFACE:include>)
337
+
338
+ if(USE_SYSTEM_FMT)
339
+ find_package(fmt REQUIRED)
340
+ else()
341
+ FetchContent_Declare(
342
+ fmt
343
+ GIT_REPOSITORY https://github.com/fmtlib/fmt.git
344
+ GIT_TAG 12.1.0
345
+ EXCLUDE_FROM_ALL)
346
+ FetchContent_MakeAvailable(fmt)
347
+ endif()
348
+ target_link_libraries(mlx PRIVATE $<BUILD_INTERFACE:fmt::fmt-header-only>)
349
+
350
+ if(MLX_BUILD_PYTHON_BINDINGS)
351
+ message(STATUS "Building Python bindings.")
352
+ find_package(
353
+ Python 3.10
354
+ COMPONENTS Interpreter Development.Module
355
+ REQUIRED)
356
+ FetchContent_Declare(
357
+ nanobind
358
+ GIT_REPOSITORY https://github.com/wjakob/nanobind.git
359
+ GIT_TAG v2.10.2
360
+ GIT_SHALLOW TRUE
361
+ EXCLUDE_FROM_ALL)
362
+ FetchContent_MakeAvailable(nanobind)
363
+ add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/python/src)
364
+ endif()
365
+
366
+ if(MLX_BUILD_TESTS)
367
+ include(CTest)
368
+ add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/tests)
369
+ endif()
370
+
371
+ if(MLX_BUILD_EXAMPLES)
372
+ add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/examples/cpp)
373
+ endif()
374
+
375
+ if(MLX_BUILD_BENCHMARKS)
376
+ add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/benchmarks/cpp)
377
+ endif()
378
+
379
+ # ----------------------------- Installation -----------------------------
380
+ include(GNUInstallDirs)
381
+
382
+ if(WIN32)
383
+ # Install DLLs to the same dir with extension file (core.pyd) on Windows.
384
+ set(CMAKE_INSTALL_BINDIR ".")
385
+ if(MLX_BUILD_CPU)
386
+ # Install OpenBLAS.
387
+ install(FILES ${OPENBLAS_DLL_FILE} TYPE BIN)
388
+ endif()
389
+ endif()
390
+
391
+ # Install library
392
+ install(
393
+ TARGETS mlx
394
+ EXPORT MLXTargets
395
+ LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
396
+ ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}
397
+ RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}
398
+ INCLUDES
399
+ DESTINATION ${CMAKE_INSTALL_INCLUDEDIR})
400
+
401
+ # Install headers
402
+ install(
403
+ DIRECTORY ${CMAKE_CURRENT_LIST_DIR}/mlx
404
+ DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}
405
+ COMPONENT headers
406
+ FILES_MATCHING
407
+ PATTERN "*.h"
408
+ PATTERN "backend/metal/kernels.h" EXCLUDE)
409
+
410
+ # Install metal dependencies
411
+ if(MLX_BUILD_METAL)
412
+
413
+ # Install metal cpp
414
+ install(
415
+ DIRECTORY ${metal_cpp_SOURCE_DIR}/
416
+ DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/metal_cpp
417
+ COMPONENT metal_cpp_source)
418
+
419
+ endif()
420
+
421
+ # Install cmake config
422
+ set(MLX_CMAKE_BUILD_CONFIG ${CMAKE_BINARY_DIR}/MLXConfig.cmake)
423
+ set(MLX_CMAKE_BUILD_VERSION_CONFIG ${CMAKE_BINARY_DIR}/MLXConfigVersion.cmake)
424
+ set(MLX_CMAKE_INSTALL_MODULE_DIR share/cmake/MLX)
425
+
426
+ install(
427
+ EXPORT MLXTargets
428
+ FILE MLXTargets.cmake
429
+ DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR})
430
+
431
+ include(CMakePackageConfigHelpers)
432
+
433
+ write_basic_package_version_file(
434
+ ${MLX_CMAKE_BUILD_VERSION_CONFIG}
435
+ COMPATIBILITY SameMajorVersion
436
+ VERSION ${MLX_VERSION})
437
+
438
+ configure_package_config_file(
439
+ ${CMAKE_CURRENT_LIST_DIR}/mlx.pc.in ${MLX_CMAKE_BUILD_CONFIG}
440
+ INSTALL_DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR}
441
+ NO_CHECK_REQUIRED_COMPONENTS_MACRO
442
+ PATH_VARS CMAKE_INSTALL_LIBDIR CMAKE_INSTALL_INCLUDEDIR
443
+ MLX_CMAKE_INSTALL_MODULE_DIR)
444
+
445
+ install(FILES ${MLX_CMAKE_BUILD_CONFIG} ${MLX_CMAKE_BUILD_VERSION_CONFIG}
446
+ DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR})
447
+
448
+ install(DIRECTORY ${CMAKE_MODULE_PATH}/
449
+ DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR})
@@ -0,0 +1,177 @@
1
+ # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
4
+ # of this software and associated documentation files (the "Software"), to deal
5
+ # in the Software without restriction, including without limitation the rights
6
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7
+ # copies of the Software, and to permit persons to whom the Software is
8
+ # furnished to do so, subject to the following conditions:
9
+ #
10
+ # The above copyright notice and this permission notice shall be included in all
11
+ # copies or substantial portions of the Software.
12
+ #
13
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
19
+ # SOFTWARE.
20
+
21
+ # Modified from
22
+ # https://github.com/NVIDIA/cudnn-frontend/blob/main/cmake/cuDNN.cmake
23
+
24
+ # Return in VAR the last file matching the glob PATTERN; VAR is left untouched when nothing matches.
25
+ function(find_file_glob VAR PATTERN)
26
+ file(GLOB _RESULT "${PATTERN}")
27
+ if(_RESULT)
28
+ list(LENGTH _RESULT _RESULT_LENGTH) # list() expects the list variable's name, not its expansion
29
+ if(_RESULT_LENGTH GREATER 0)
30
+ list(GET _RESULT -1 _RESULT) # index -1 selects the final match
31
+ endif()
32
+ set(${VAR}
33
+ "${_RESULT}"
34
+ PARENT_SCOPE)
35
+ endif()
36
+ endfunction()
37
+
38
+ # Find the dir including the "cudnn.h" file.
39
+ find_path(
40
+ CUDNN_INCLUDE_DIR cudnn.h
41
+ HINTS ${CUDNN_INCLUDE_PATH} ${CUDAToolkit_INCLUDE_DIRS}
42
+ PATH_SUFFIXES include OPTIONAL)
43
+
44
+ # Glob searching "cudnn.h" for Windows.
45
+ if(WIN32 AND NOT CUDNN_INCLUDE_DIR)
46
+ find_file_glob(
47
+ CUDNN_H_PATH
48
+ "C:/Program Files/NVIDIA/CUDNN/*/include/${CUDAToolkit_VERSION_MAJOR}.*/cudnn.h"
49
+ )
50
+ if(CUDNN_H_PATH)
51
+ get_filename_component(CUDNN_INCLUDE_DIR "${CUDNN_H_PATH}" DIRECTORY)
52
+ endif()
53
+ endif()
54
+
55
+ if(NOT CUDNN_INCLUDE_DIR)
56
+ message(
57
+ FATAL_ERROR
58
+ "Unable to find cudnn.h, please make sure cuDNN is installed and pass CUDNN_INCLUDE_PATH to cmake."
59
+ )
60
+ endif()
61
+
62
+ # Get the cuDNN major version from cudnn_version.h ([0-9]+ so that a 0 digit, e.g. in 10, is not dropped).
63
+ file(READ "${CUDNN_INCLUDE_DIR}/cudnn_version.h" cudnn_version_header)
64
+ string(REGEX MATCH "#define CUDNN_MAJOR [0-9]+" macrodef
65
+ "${cudnn_version_header}")
66
+ string(REGEX MATCH "[0-9]+" CUDNN_MAJOR_VERSION "${macrodef}")
67
+
68
+ # Locate the cuDNN library NAME, create the imported target CUDNN::NAME for it and export NAME_LIBRARY to the caller; a second argument of OPTIONAL makes a miss non-fatal.
69
+ function(find_cudnn_library NAME)
70
+ if(NOT "${ARGV1}" STREQUAL "OPTIONAL") # anything other than a literal OPTIONAL second argument means "required"
71
+ set(_CUDNN_REQUIRED TRUE)
72
+ else()
73
+ set(_CUDNN_REQUIRED FALSE)
74
+ endif()
75
+
76
+ find_library(
77
+ ${NAME}_LIBRARY
78
+ NAMES ${NAME} "lib${NAME}.so.${CUDNN_MAJOR_VERSION}" NAMES_PER_DIR
79
+ HINTS ${CUDNN_LIBRARY_PATH} ${CUDAToolkit_LIBRARY_DIR}
80
+ PATH_SUFFIXES lib64 lib/x64 lib OPTIONAL)
81
+
82
+ if(WIN32 AND NOT ${NAME}_LIBRARY) # Windows fallback: glob under C:/Program Files/NVIDIA/CUDNN
83
+ find_file_glob(
84
+ ${NAME}_LIBRARY
85
+ "C:/Program Files/NVIDIA/CUDNN/*/lib/${CUDAToolkit_VERSION_MAJOR}.*/x64/${NAME}.lib"
86
+ )
87
+ endif()
88
+
89
+ if(NOT ${NAME}_LIBRARY AND ${_CUDNN_REQUIRED}) # a required library that is still missing aborts configuration
90
+ message(
91
+ FATAL_ERROR
92
+ "Unable to find ${NAME}, please make sure cuDNN is installed and pass CUDNN_LIBRARY_PATH to cmake."
93
+ )
94
+ endif()
95
+
96
+ if(${NAME}_LIBRARY)
97
+ add_library(CUDNN::${NAME} UNKNOWN IMPORTED)
98
+ set_target_properties(
99
+ CUDNN::${NAME}
100
+ PROPERTIES INTERFACE_INCLUDE_DIRECTORIES ${CUDNN_INCLUDE_DIR}
101
+ IMPORTED_LOCATION ${${NAME}_LIBRARY})
102
+ set(${NAME}_LIBRARY
103
+ "${${NAME}_LIBRARY}"
104
+ PARENT_SCOPE) # make the resolved path visible in the caller's scope
105
+ else()
106
+ message(STATUS "${NAME} not found.")
107
+ endif()
108
+ endfunction()
109
+
110
+ # Search for the main cudnn library.
111
+ find_cudnn_library(cudnn)
112
+
113
+ include(FindPackageHandleStandardArgs)
114
+ find_package_handle_standard_args(CUDNN REQUIRED_VARS CUDNN_INCLUDE_DIR
115
+ cudnn_LIBRARY)
116
+
117
+ if(CUDNN_INCLUDE_DIR AND cudnn_LIBRARY)
118
+ set(CUDNN_FOUND
119
+ ON
120
+ CACHE INTERNAL "cuDNN Library Found")
121
+ else()
122
+ set(CUDNN_FOUND
123
+ OFF
124
+ CACHE INTERNAL "cuDNN Library Not Found")
125
+ endif()
126
+
127
+ # Find out all the DLL files for Windows.
128
+ if(WIN32 AND cudnn_LIBRARY)
129
+ get_filename_component(CUDNN_BIN_DIR "${cudnn_LIBRARY}" DIRECTORY)
130
+ string(REPLACE "/lib/" "/bin/" CUDNN_BIN_DIR "${CUDNN_BIN_DIR}")
131
+ file(
132
+ GLOB CUDNN_DLL_NAMES
133
+ RELATIVE "${CUDNN_BIN_DIR}"
134
+ "${CUDNN_BIN_DIR}/*.dll")
135
+ endif()
136
+
137
+ # Create an interface library that users can link with.
138
+ add_library(CUDNN::cudnn_all INTERFACE IMPORTED)
139
+ target_link_libraries(CUDNN::cudnn_all INTERFACE CUDNN::cudnn)
140
+ target_include_directories(
141
+ CUDNN::cudnn_all INTERFACE $<INSTALL_INTERFACE:include>
142
+ $<BUILD_INTERFACE:${CUDNN_INCLUDE_DIR}>)
143
+
144
+ # Add other components of cudnn.
145
+ if(CUDNN_MAJOR_VERSION EQUAL 8)
146
+ find_cudnn_library(cudnn_adv_infer)
147
+ find_cudnn_library(cudnn_adv_train)
148
+ find_cudnn_library(cudnn_cnn_infer)
149
+ find_cudnn_library(cudnn_cnn_train)
150
+ find_cudnn_library(cudnn_ops_infer)
151
+ find_cudnn_library(cudnn_ops_train)
152
+
153
+ target_link_libraries(
154
+ CUDNN::cudnn_all
155
+ INTERFACE CUDNN::cudnn_adv_train CUDNN::cudnn_ops_train
156
+ CUDNN::cudnn_cnn_train CUDNN::cudnn_adv_infer
157
+ CUDNN::cudnn_cnn_infer CUDNN::cudnn_ops_infer)
158
+
159
+ elseif(CUDNN_MAJOR_VERSION EQUAL 9)
160
+ find_cudnn_library(cudnn_graph)
161
+ find_cudnn_library(cudnn_engines_runtime_compiled)
162
+ find_cudnn_library(cudnn_ops OPTIONAL)
163
+ find_cudnn_library(cudnn_cnn OPTIONAL)
164
+ find_cudnn_library(cudnn_adv OPTIONAL)
165
+ find_cudnn_library(cudnn_engines_precompiled OPTIONAL)
166
+ find_cudnn_library(cudnn_heuristic OPTIONAL)
167
+
168
+ target_link_libraries(
169
+ CUDNN::cudnn_all
170
+ INTERFACE CUDNN::cudnn_graph
171
+ CUDNN::cudnn_engines_runtime_compiled
172
+ CUDNN::cudnn_ops
173
+ CUDNN::cudnn_cnn
174
+ CUDNN::cudnn_adv
175
+ CUDNN::cudnn_engines_precompiled
176
+ CUDNN::cudnn_heuristic)
177
+ endif()
@@ -0,0 +1,54 @@
1
+ # FindNCCL.cmake This module finds the NVIDIA NCCL library and its include
2
+ # directories.
3
+
4
+ set(NCCL_ROOT_DIR
5
+ $ENV{NCCL_ROOT_DIR}
6
+ CACHE PATH "Folder contains NVIDIA NCCL")
7
+
8
+ find_path(
9
+ NCCL_INCLUDE_DIRS
10
+ NAMES nccl.h
11
+ HINTS ${NCCL_INCLUDE_DIR} ${NCCL_ROOT_DIR} ${NCCL_ROOT_DIR}/include
12
+ ${CUDA_TOOLKIT_ROOT_DIR}/include)
13
+
14
+ if($ENV{USE_STATIC_NCCL})
15
+ message(
16
+ STATUS "USE_STATIC_NCCL detected. Linking against static NCCL library")
17
+ set(NCCL_LIBNAME "libnccl_static.a")
18
+ else()
19
+ set(NCCL_LIBNAME "nccl")
20
+ endif()
21
+
22
+ find_library(
23
+ NCCL_LIBRARIES
24
+ NAMES ${NCCL_LIBNAME}
25
+ HINTS ${NCCL_LIB_DIR}
26
+ ${NCCL_ROOT_DIR}
27
+ ${NCCL_ROOT_DIR}/lib
28
+ ${NCCL_ROOT_DIR}/lib/x86_64-linux-gnu
29
+ ${NCCL_ROOT_DIR}/lib64
30
+ ${CUDA_TOOLKIT_ROOT_DIR}/lib
31
+ ${CUDA_TOOLKIT_ROOT_DIR}/lib64)
32
+
33
+ include(FindPackageHandleStandardArgs)
34
+ find_package_handle_standard_args(NCCL DEFAULT_MSG NCCL_INCLUDE_DIRS
35
+ NCCL_LIBRARIES)
36
+
37
+ if(NCCL_FOUND) # NCCL located: read its major version from nccl.h and report what was found
38
+ set(NCCL_HEADER_FILE "${NCCL_INCLUDE_DIRS}/nccl.h")
39
+ message(
40
+ STATUS "Determining NCCL version from the header file: ${NCCL_HEADER_FILE}")
41
+ file(
42
+ STRINGS "${NCCL_HEADER_FILE}" NCCL_MAJOR_VERSION_DEFINED
43
+ REGEX "^[ \t]*#define[ \t]+NCCL_MAJOR[ \t]+[0-9]+.*$"
44
+ LIMIT_COUNT 1)
45
+ if(NCCL_MAJOR_VERSION_DEFINED) # keep only the digits; anything trailing the number on the #define line is dropped
46
+ string(REGEX REPLACE "^[ \t]*#define[ \t]+NCCL_MAJOR[ \t]+([0-9]+).*$" "\\1"
47
+ NCCL_MAJOR_VERSION "${NCCL_MAJOR_VERSION_DEFINED}")
48
+ message(STATUS "NCCL_MAJOR_VERSION: ${NCCL_MAJOR_VERSION}")
49
+ endif()
50
+ message(
51
+ STATUS
52
+ "Found NCCL (include: ${NCCL_INCLUDE_DIRS}, library: ${NCCL_LIBRARIES})")
53
+ mark_as_advanced(NCCL_ROOT_DIR NCCL_INCLUDE_DIRS NCCL_LIBRARIES)
54
+ endif()
@@ -0,0 +1,3 @@
1
+ # This file does nothing but to suppress the cmake warning: "By not providing
2
+ # Findnvpl.cmake in CMAKE_MODULE_PATH...", which is caused by the
3
+ # find_package(nvpl) from cmake's builtin FindLAPACK.cmake module.