mlx 0.30.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (599) hide show
  1. checksums.yaml +7 -0
  2. data/ext/mlx/extconf.rb +94 -0
  3. data/ext/mlx/native.cpp +8027 -0
  4. data/lib/mlx/core.rb +1678 -0
  5. data/lib/mlx/distributed_utils/common.rb +116 -0
  6. data/lib/mlx/distributed_utils/config.rb +600 -0
  7. data/lib/mlx/distributed_utils/launch.rb +490 -0
  8. data/lib/mlx/extension.rb +24 -0
  9. data/lib/mlx/nn/base.rb +388 -0
  10. data/lib/mlx/nn/init.rb +140 -0
  11. data/lib/mlx/nn/layers/activations.rb +336 -0
  12. data/lib/mlx/nn/layers/base.rb +6 -0
  13. data/lib/mlx/nn/layers/containers.rb +20 -0
  14. data/lib/mlx/nn/layers/convolution.rb +120 -0
  15. data/lib/mlx/nn/layers/convolution_transpose.rb +114 -0
  16. data/lib/mlx/nn/layers/distributed.rb +309 -0
  17. data/lib/mlx/nn/layers/dropout.rb +75 -0
  18. data/lib/mlx/nn/layers/embedding.rb +28 -0
  19. data/lib/mlx/nn/layers/linear.rb +79 -0
  20. data/lib/mlx/nn/layers/normalization.rb +216 -0
  21. data/lib/mlx/nn/layers/pooling.rb +167 -0
  22. data/lib/mlx/nn/layers/positional_encoding.rb +126 -0
  23. data/lib/mlx/nn/layers/quantized.rb +215 -0
  24. data/lib/mlx/nn/layers/recurrent.rb +135 -0
  25. data/lib/mlx/nn/layers/transformer.rb +330 -0
  26. data/lib/mlx/nn/layers/upsample.rb +97 -0
  27. data/lib/mlx/nn/layers.rb +18 -0
  28. data/lib/mlx/nn/losses.rb +251 -0
  29. data/lib/mlx/nn/utils.rb +167 -0
  30. data/lib/mlx/nn.rb +12 -0
  31. data/lib/mlx/optimizers/optimizers.rb +808 -0
  32. data/lib/mlx/optimizers/schedulers.rb +62 -0
  33. data/lib/mlx/optimizers.rb +9 -0
  34. data/lib/mlx/utils.rb +171 -0
  35. data/lib/mlx/version.rb +5 -0
  36. data/lib/mlx.rb +64 -0
  37. data/mlx/CMakeLists.txt +449 -0
  38. data/mlx/cmake/FindCUDNN.cmake +177 -0
  39. data/mlx/cmake/FindNCCL.cmake +54 -0
  40. data/mlx/cmake/Findnvpl.cmake +3 -0
  41. data/mlx/cmake/extension.cmake +50 -0
  42. data/mlx/mlx/3rdparty/.clang-format +2 -0
  43. data/mlx/mlx/3rdparty/pocketfft.h +3581 -0
  44. data/mlx/mlx/CMakeLists.txt +107 -0
  45. data/mlx/mlx/allocator.h +75 -0
  46. data/mlx/mlx/api.h +29 -0
  47. data/mlx/mlx/array.cpp +354 -0
  48. data/mlx/mlx/array.h +647 -0
  49. data/mlx/mlx/backend/common/CMakeLists.txt +9 -0
  50. data/mlx/mlx/backend/common/binary.h +97 -0
  51. data/mlx/mlx/backend/common/broadcasting.cpp +24 -0
  52. data/mlx/mlx/backend/common/broadcasting.h +11 -0
  53. data/mlx/mlx/backend/common/buffer_cache.h +158 -0
  54. data/mlx/mlx/backend/common/common.cpp +305 -0
  55. data/mlx/mlx/backend/common/compiled.cpp +243 -0
  56. data/mlx/mlx/backend/common/compiled.h +77 -0
  57. data/mlx/mlx/backend/common/copy.h +50 -0
  58. data/mlx/mlx/backend/common/hadamard.h +109 -0
  59. data/mlx/mlx/backend/common/load.cpp +57 -0
  60. data/mlx/mlx/backend/common/matmul.h +67 -0
  61. data/mlx/mlx/backend/common/reduce.cpp +154 -0
  62. data/mlx/mlx/backend/common/reduce.h +59 -0
  63. data/mlx/mlx/backend/common/slicing.cpp +71 -0
  64. data/mlx/mlx/backend/common/slicing.h +20 -0
  65. data/mlx/mlx/backend/common/ternary.h +85 -0
  66. data/mlx/mlx/backend/common/unary.h +29 -0
  67. data/mlx/mlx/backend/common/utils.cpp +231 -0
  68. data/mlx/mlx/backend/common/utils.h +205 -0
  69. data/mlx/mlx/backend/cpu/CMakeLists.txt +88 -0
  70. data/mlx/mlx/backend/cpu/arange.h +28 -0
  71. data/mlx/mlx/backend/cpu/arg_reduce.cpp +124 -0
  72. data/mlx/mlx/backend/cpu/binary.cpp +269 -0
  73. data/mlx/mlx/backend/cpu/binary.h +517 -0
  74. data/mlx/mlx/backend/cpu/binary_ops.h +98 -0
  75. data/mlx/mlx/backend/cpu/binary_two.h +166 -0
  76. data/mlx/mlx/backend/cpu/cholesky.cpp +85 -0
  77. data/mlx/mlx/backend/cpu/compiled.cpp +357 -0
  78. data/mlx/mlx/backend/cpu/compiled_preamble.h +12 -0
  79. data/mlx/mlx/backend/cpu/conv.cpp +1351 -0
  80. data/mlx/mlx/backend/cpu/copy.cpp +386 -0
  81. data/mlx/mlx/backend/cpu/copy.h +36 -0
  82. data/mlx/mlx/backend/cpu/device_info.cpp +113 -0
  83. data/mlx/mlx/backend/cpu/device_info.h +28 -0
  84. data/mlx/mlx/backend/cpu/distributed.cpp +103 -0
  85. data/mlx/mlx/backend/cpu/eig.cpp +281 -0
  86. data/mlx/mlx/backend/cpu/eigh.cpp +241 -0
  87. data/mlx/mlx/backend/cpu/encoder.cpp +16 -0
  88. data/mlx/mlx/backend/cpu/encoder.h +67 -0
  89. data/mlx/mlx/backend/cpu/eval.cpp +40 -0
  90. data/mlx/mlx/backend/cpu/eval.h +12 -0
  91. data/mlx/mlx/backend/cpu/fft.cpp +120 -0
  92. data/mlx/mlx/backend/cpu/gemm.h +26 -0
  93. data/mlx/mlx/backend/cpu/gemms/bnns.cpp +214 -0
  94. data/mlx/mlx/backend/cpu/gemms/cblas.cpp +134 -0
  95. data/mlx/mlx/backend/cpu/gemms/simd_bf16.cpp +45 -0
  96. data/mlx/mlx/backend/cpu/gemms/simd_fp16.cpp +45 -0
  97. data/mlx/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
  98. data/mlx/mlx/backend/cpu/hadamard.cpp +121 -0
  99. data/mlx/mlx/backend/cpu/indexing.cpp +854 -0
  100. data/mlx/mlx/backend/cpu/inverse.cpp +160 -0
  101. data/mlx/mlx/backend/cpu/jit_compiler.cpp +166 -0
  102. data/mlx/mlx/backend/cpu/jit_compiler.h +20 -0
  103. data/mlx/mlx/backend/cpu/lapack.h +80 -0
  104. data/mlx/mlx/backend/cpu/logsumexp.cpp +139 -0
  105. data/mlx/mlx/backend/cpu/luf.cpp +120 -0
  106. data/mlx/mlx/backend/cpu/make_compiled_preamble.ps1 +38 -0
  107. data/mlx/mlx/backend/cpu/make_compiled_preamble.sh +41 -0
  108. data/mlx/mlx/backend/cpu/masked_mm.cpp +608 -0
  109. data/mlx/mlx/backend/cpu/matmul.cpp +166 -0
  110. data/mlx/mlx/backend/cpu/primitives.cpp +478 -0
  111. data/mlx/mlx/backend/cpu/qrf.cpp +147 -0
  112. data/mlx/mlx/backend/cpu/quantized.cpp +1370 -0
  113. data/mlx/mlx/backend/cpu/reduce.cpp +587 -0
  114. data/mlx/mlx/backend/cpu/scan.cpp +338 -0
  115. data/mlx/mlx/backend/cpu/select.cpp +95 -0
  116. data/mlx/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
  117. data/mlx/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
  118. data/mlx/mlx/backend/cpu/simd/base_simd.h +319 -0
  119. data/mlx/mlx/backend/cpu/simd/math.h +193 -0
  120. data/mlx/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
  121. data/mlx/mlx/backend/cpu/simd/simd.h +4 -0
  122. data/mlx/mlx/backend/cpu/simd/type.h +11 -0
  123. data/mlx/mlx/backend/cpu/slicing.h +21 -0
  124. data/mlx/mlx/backend/cpu/softmax.cpp +170 -0
  125. data/mlx/mlx/backend/cpu/sort.cpp +481 -0
  126. data/mlx/mlx/backend/cpu/svd.cpp +289 -0
  127. data/mlx/mlx/backend/cpu/ternary.h +154 -0
  128. data/mlx/mlx/backend/cpu/threefry.cpp +31 -0
  129. data/mlx/mlx/backend/cpu/threefry.h +21 -0
  130. data/mlx/mlx/backend/cpu/unary.cpp +238 -0
  131. data/mlx/mlx/backend/cpu/unary.h +281 -0
  132. data/mlx/mlx/backend/cpu/unary_ops.h +175 -0
  133. data/mlx/mlx/backend/cuda/CMakeLists.txt +265 -0
  134. data/mlx/mlx/backend/cuda/allocator.cpp +451 -0
  135. data/mlx/mlx/backend/cuda/allocator.h +94 -0
  136. data/mlx/mlx/backend/cuda/arange.cu +68 -0
  137. data/mlx/mlx/backend/cuda/arg_reduce.cu +189 -0
  138. data/mlx/mlx/backend/cuda/bin2h.cmake +150 -0
  139. data/mlx/mlx/backend/cuda/binary/CMakeLists.txt +21 -0
  140. data/mlx/mlx/backend/cuda/binary/add.cu +7 -0
  141. data/mlx/mlx/backend/cuda/binary/arctan2.cu +7 -0
  142. data/mlx/mlx/backend/cuda/binary/binary.cuh +383 -0
  143. data/mlx/mlx/backend/cuda/binary/bitwise_binary.cu +27 -0
  144. data/mlx/mlx/backend/cuda/binary/divide.cu +7 -0
  145. data/mlx/mlx/backend/cuda/binary/equal.cu +15 -0
  146. data/mlx/mlx/backend/cuda/binary/greater.cu +7 -0
  147. data/mlx/mlx/backend/cuda/binary/greater_equal.cu +7 -0
  148. data/mlx/mlx/backend/cuda/binary/less.cu +7 -0
  149. data/mlx/mlx/backend/cuda/binary/less_equal.cu +7 -0
  150. data/mlx/mlx/backend/cuda/binary/log_add_exp.cu +7 -0
  151. data/mlx/mlx/backend/cuda/binary/logical_and.cu +7 -0
  152. data/mlx/mlx/backend/cuda/binary/logical_or.cu +7 -0
  153. data/mlx/mlx/backend/cuda/binary/maximum.cu +7 -0
  154. data/mlx/mlx/backend/cuda/binary/minimum.cu +7 -0
  155. data/mlx/mlx/backend/cuda/binary/multiply.cu +7 -0
  156. data/mlx/mlx/backend/cuda/binary/not_equal.cu +7 -0
  157. data/mlx/mlx/backend/cuda/binary/power.cu +7 -0
  158. data/mlx/mlx/backend/cuda/binary/remainder.cu +7 -0
  159. data/mlx/mlx/backend/cuda/binary/subtract.cu +7 -0
  160. data/mlx/mlx/backend/cuda/binary_two.cu +412 -0
  161. data/mlx/mlx/backend/cuda/compiled.cpp +357 -0
  162. data/mlx/mlx/backend/cuda/conv/conv.h +126 -0
  163. data/mlx/mlx/backend/cuda/conv/gemm_conv.cu +217 -0
  164. data/mlx/mlx/backend/cuda/conv/gemm_grouped_conv.cu +231 -0
  165. data/mlx/mlx/backend/cuda/conv.cpp +403 -0
  166. data/mlx/mlx/backend/cuda/copy/copy.cuh +55 -0
  167. data/mlx/mlx/backend/cuda/copy/copy_contiguous.cu +88 -0
  168. data/mlx/mlx/backend/cuda/copy/copy_general.cu +171 -0
  169. data/mlx/mlx/backend/cuda/copy/copy_general_dynamic.cu +118 -0
  170. data/mlx/mlx/backend/cuda/copy/copy_general_input.cu +229 -0
  171. data/mlx/mlx/backend/cuda/copy.cu +132 -0
  172. data/mlx/mlx/backend/cuda/cublas_utils.cpp +222 -0
  173. data/mlx/mlx/backend/cuda/cublas_utils.h +95 -0
  174. data/mlx/mlx/backend/cuda/cuda.h +21 -0
  175. data/mlx/mlx/backend/cuda/cuda_utils.h +90 -0
  176. data/mlx/mlx/backend/cuda/cudnn_utils.cpp +133 -0
  177. data/mlx/mlx/backend/cuda/cudnn_utils.h +187 -0
  178. data/mlx/mlx/backend/cuda/custom_kernel.cpp +379 -0
  179. data/mlx/mlx/backend/cuda/cutlass_utils.cuh +46 -0
  180. data/mlx/mlx/backend/cuda/delayload.cpp +80 -0
  181. data/mlx/mlx/backend/cuda/device/atomic_ops.cuh +63 -0
  182. data/mlx/mlx/backend/cuda/device/binary_ops.cuh +300 -0
  183. data/mlx/mlx/backend/cuda/device/cast_op.cuh +118 -0
  184. data/mlx/mlx/backend/cuda/device/complex.cuh +60 -0
  185. data/mlx/mlx/backend/cuda/device/config.h +12 -0
  186. data/mlx/mlx/backend/cuda/device/fp16_math.cuh +96 -0
  187. data/mlx/mlx/backend/cuda/device/gather.cuh +53 -0
  188. data/mlx/mlx/backend/cuda/device/gather_axis.cuh +65 -0
  189. data/mlx/mlx/backend/cuda/device/indexing.cuh +30 -0
  190. data/mlx/mlx/backend/cuda/device/scatter.cuh +68 -0
  191. data/mlx/mlx/backend/cuda/device/scatter_axis.cuh +67 -0
  192. data/mlx/mlx/backend/cuda/device/scatter_ops.cuh +44 -0
  193. data/mlx/mlx/backend/cuda/device/ternary_ops.cuh +13 -0
  194. data/mlx/mlx/backend/cuda/device/unary_ops.cuh +350 -0
  195. data/mlx/mlx/backend/cuda/device/utils.cuh +464 -0
  196. data/mlx/mlx/backend/cuda/device.cpp +522 -0
  197. data/mlx/mlx/backend/cuda/device.h +195 -0
  198. data/mlx/mlx/backend/cuda/device_info.cpp +232 -0
  199. data/mlx/mlx/backend/cuda/distributed.cu +121 -0
  200. data/mlx/mlx/backend/cuda/eval.cpp +66 -0
  201. data/mlx/mlx/backend/cuda/event.cu +415 -0
  202. data/mlx/mlx/backend/cuda/event.h +79 -0
  203. data/mlx/mlx/backend/cuda/fence.cpp +42 -0
  204. data/mlx/mlx/backend/cuda/gemms/cublas_gemm.cpp +233 -0
  205. data/mlx/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
  206. data/mlx/mlx/backend/cuda/gemms/cublas_gemm_batched_12_0.cpp +77 -0
  207. data/mlx/mlx/backend/cuda/gemms/cublas_gemm_batched_12_9.cu +329 -0
  208. data/mlx/mlx/backend/cuda/gemms/gemv.cu +327 -0
  209. data/mlx/mlx/backend/cuda/gemms/gemv.h +34 -0
  210. data/mlx/mlx/backend/cuda/gemms/grouped_gemm.h +25 -0
  211. data/mlx/mlx/backend/cuda/gemms/grouped_gemm_unaligned.cu +358 -0
  212. data/mlx/mlx/backend/cuda/indexing.cpp +434 -0
  213. data/mlx/mlx/backend/cuda/jit_module.cpp +443 -0
  214. data/mlx/mlx/backend/cuda/jit_module.h +120 -0
  215. data/mlx/mlx/backend/cuda/kernel_utils.cu +52 -0
  216. data/mlx/mlx/backend/cuda/kernel_utils.cuh +148 -0
  217. data/mlx/mlx/backend/cuda/layer_norm.cu +417 -0
  218. data/mlx/mlx/backend/cuda/load.cpp +60 -0
  219. data/mlx/mlx/backend/cuda/logsumexp.cu +161 -0
  220. data/mlx/mlx/backend/cuda/lru_cache.h +190 -0
  221. data/mlx/mlx/backend/cuda/matmul.cpp +373 -0
  222. data/mlx/mlx/backend/cuda/no_cuda.cpp +47 -0
  223. data/mlx/mlx/backend/cuda/primitives.cpp +46 -0
  224. data/mlx/mlx/backend/cuda/quantized/affine_quantize.cu +329 -0
  225. data/mlx/mlx/backend/cuda/quantized/convert_fp8.cu +19 -0
  226. data/mlx/mlx/backend/cuda/quantized/cublas_qqmm.cpp +206 -0
  227. data/mlx/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
  228. data/mlx/mlx/backend/cuda/quantized/cuda_fp4.h +100 -0
  229. data/mlx/mlx/backend/cuda/quantized/fp_quantize.cu +496 -0
  230. data/mlx/mlx/backend/cuda/quantized/mxfp8_quantize.cuh +32 -0
  231. data/mlx/mlx/backend/cuda/quantized/no_qqmm_impl.cpp +26 -0
  232. data/mlx/mlx/backend/cuda/quantized/nvfp4_quantize.cuh +334 -0
  233. data/mlx/mlx/backend/cuda/quantized/qmv.cu +304 -0
  234. data/mlx/mlx/backend/cuda/quantized/qmv.h +21 -0
  235. data/mlx/mlx/backend/cuda/quantized/qqmm.cpp +158 -0
  236. data/mlx/mlx/backend/cuda/quantized/qqmm_impl.cpp +50 -0
  237. data/mlx/mlx/backend/cuda/quantized/qqmm_impl.h +26 -0
  238. data/mlx/mlx/backend/cuda/quantized/qqmm_utils.cu +227 -0
  239. data/mlx/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
  240. data/mlx/mlx/backend/cuda/quantized/quantized.cpp +85 -0
  241. data/mlx/mlx/backend/cuda/quantized/quantized.h +53 -0
  242. data/mlx/mlx/backend/cuda/quantized/quantized_utils.cuh +88 -0
  243. data/mlx/mlx/backend/cuda/quantized/quantized_utils.h +50 -0
  244. data/mlx/mlx/backend/cuda/random.cu +202 -0
  245. data/mlx/mlx/backend/cuda/reduce/all_reduce.cu +159 -0
  246. data/mlx/mlx/backend/cuda/reduce/col_reduce.cu +510 -0
  247. data/mlx/mlx/backend/cuda/reduce/init_reduce.cu +50 -0
  248. data/mlx/mlx/backend/cuda/reduce/reduce.cuh +71 -0
  249. data/mlx/mlx/backend/cuda/reduce/reduce_ops.cuh +211 -0
  250. data/mlx/mlx/backend/cuda/reduce/reduce_utils.cuh +145 -0
  251. data/mlx/mlx/backend/cuda/reduce/row_reduce.cu +361 -0
  252. data/mlx/mlx/backend/cuda/reduce.cu +73 -0
  253. data/mlx/mlx/backend/cuda/rms_norm.cu +536 -0
  254. data/mlx/mlx/backend/cuda/rope.cu +429 -0
  255. data/mlx/mlx/backend/cuda/scaled_dot_product_attention.cpp +681 -0
  256. data/mlx/mlx/backend/cuda/scaled_dot_product_attention.cu +796 -0
  257. data/mlx/mlx/backend/cuda/scan.cu +468 -0
  258. data/mlx/mlx/backend/cuda/slicing.cpp +111 -0
  259. data/mlx/mlx/backend/cuda/softmax.cu +162 -0
  260. data/mlx/mlx/backend/cuda/sort.cu +1076 -0
  261. data/mlx/mlx/backend/cuda/steel/defines.cuh +9 -0
  262. data/mlx/mlx/backend/cuda/steel/gemm.cuh +101 -0
  263. data/mlx/mlx/backend/cuda/steel/mma.cuh +117 -0
  264. data/mlx/mlx/backend/cuda/steel/tiles.cuh +450 -0
  265. data/mlx/mlx/backend/cuda/steel/utils.cuh +89 -0
  266. data/mlx/mlx/backend/cuda/ternary.cu +271 -0
  267. data/mlx/mlx/backend/cuda/unary/CMakeLists.txt +34 -0
  268. data/mlx/mlx/backend/cuda/unary/abs.cu +7 -0
  269. data/mlx/mlx/backend/cuda/unary/arccos.cu +7 -0
  270. data/mlx/mlx/backend/cuda/unary/arccosh.cu +7 -0
  271. data/mlx/mlx/backend/cuda/unary/arcsin.cu +7 -0
  272. data/mlx/mlx/backend/cuda/unary/arcsinh.cu +7 -0
  273. data/mlx/mlx/backend/cuda/unary/arctan.cu +7 -0
  274. data/mlx/mlx/backend/cuda/unary/arctanh.cu +7 -0
  275. data/mlx/mlx/backend/cuda/unary/bitwise_invert.cu +7 -0
  276. data/mlx/mlx/backend/cuda/unary/ceil.cu +7 -0
  277. data/mlx/mlx/backend/cuda/unary/conjugate.cu +7 -0
  278. data/mlx/mlx/backend/cuda/unary/cos.cu +7 -0
  279. data/mlx/mlx/backend/cuda/unary/cosh.cu +7 -0
  280. data/mlx/mlx/backend/cuda/unary/erf.cu +7 -0
  281. data/mlx/mlx/backend/cuda/unary/erf_inv.cu +7 -0
  282. data/mlx/mlx/backend/cuda/unary/exp.cu +7 -0
  283. data/mlx/mlx/backend/cuda/unary/expm1.cu +7 -0
  284. data/mlx/mlx/backend/cuda/unary/floor.cu +7 -0
  285. data/mlx/mlx/backend/cuda/unary/imag.cu +7 -0
  286. data/mlx/mlx/backend/cuda/unary/log.cu +21 -0
  287. data/mlx/mlx/backend/cuda/unary/log1p.cu +7 -0
  288. data/mlx/mlx/backend/cuda/unary/logical_not.cu +7 -0
  289. data/mlx/mlx/backend/cuda/unary/negative.cu +7 -0
  290. data/mlx/mlx/backend/cuda/unary/real.cu +7 -0
  291. data/mlx/mlx/backend/cuda/unary/round.cu +18 -0
  292. data/mlx/mlx/backend/cuda/unary/sigmoid.cu +7 -0
  293. data/mlx/mlx/backend/cuda/unary/sign.cu +7 -0
  294. data/mlx/mlx/backend/cuda/unary/sin.cu +7 -0
  295. data/mlx/mlx/backend/cuda/unary/sinh.cu +7 -0
  296. data/mlx/mlx/backend/cuda/unary/sqrt.cu +15 -0
  297. data/mlx/mlx/backend/cuda/unary/square.cu +7 -0
  298. data/mlx/mlx/backend/cuda/unary/tan.cu +7 -0
  299. data/mlx/mlx/backend/cuda/unary/tanh.cu +7 -0
  300. data/mlx/mlx/backend/cuda/unary/unary.cuh +224 -0
  301. data/mlx/mlx/backend/cuda/utils.cpp +116 -0
  302. data/mlx/mlx/backend/cuda/utils.h +49 -0
  303. data/mlx/mlx/backend/cuda/vector_types.cuh +48 -0
  304. data/mlx/mlx/backend/cuda/worker.cpp +79 -0
  305. data/mlx/mlx/backend/cuda/worker.h +55 -0
  306. data/mlx/mlx/backend/gpu/CMakeLists.txt +5 -0
  307. data/mlx/mlx/backend/gpu/copy.cpp +89 -0
  308. data/mlx/mlx/backend/gpu/copy.h +57 -0
  309. data/mlx/mlx/backend/gpu/device_info.h +36 -0
  310. data/mlx/mlx/backend/gpu/eval.h +18 -0
  311. data/mlx/mlx/backend/gpu/primitives.cpp +307 -0
  312. data/mlx/mlx/backend/gpu/slicing.cpp +44 -0
  313. data/mlx/mlx/backend/gpu/slicing.h +36 -0
  314. data/mlx/mlx/backend/metal/CMakeLists.txt +144 -0
  315. data/mlx/mlx/backend/metal/allocator.cpp +279 -0
  316. data/mlx/mlx/backend/metal/allocator.h +79 -0
  317. data/mlx/mlx/backend/metal/binary.cpp +257 -0
  318. data/mlx/mlx/backend/metal/binary.h +33 -0
  319. data/mlx/mlx/backend/metal/compiled.cpp +471 -0
  320. data/mlx/mlx/backend/metal/conv.cpp +1118 -0
  321. data/mlx/mlx/backend/metal/copy.cpp +235 -0
  322. data/mlx/mlx/backend/metal/custom_kernel.cpp +430 -0
  323. data/mlx/mlx/backend/metal/device.cpp +816 -0
  324. data/mlx/mlx/backend/metal/device.h +289 -0
  325. data/mlx/mlx/backend/metal/device_info.cpp +58 -0
  326. data/mlx/mlx/backend/metal/distributed.cpp +38 -0
  327. data/mlx/mlx/backend/metal/eval.cpp +97 -0
  328. data/mlx/mlx/backend/metal/event.cpp +62 -0
  329. data/mlx/mlx/backend/metal/fence.cpp +162 -0
  330. data/mlx/mlx/backend/metal/fft.cpp +807 -0
  331. data/mlx/mlx/backend/metal/hadamard.cpp +198 -0
  332. data/mlx/mlx/backend/metal/indexing.cpp +727 -0
  333. data/mlx/mlx/backend/metal/jit/includes.h +58 -0
  334. data/mlx/mlx/backend/metal/jit/indexing.h +76 -0
  335. data/mlx/mlx/backend/metal/jit_kernels.cpp +1118 -0
  336. data/mlx/mlx/backend/metal/kernels/CMakeLists.txt +193 -0
  337. data/mlx/mlx/backend/metal/kernels/arange.h +9 -0
  338. data/mlx/mlx/backend/metal/kernels/arange.metal +20 -0
  339. data/mlx/mlx/backend/metal/kernels/arg_reduce.metal +182 -0
  340. data/mlx/mlx/backend/metal/kernels/atomic.h +345 -0
  341. data/mlx/mlx/backend/metal/kernels/bf16.h +16 -0
  342. data/mlx/mlx/backend/metal/kernels/bf16_math.h +380 -0
  343. data/mlx/mlx/backend/metal/kernels/binary.h +199 -0
  344. data/mlx/mlx/backend/metal/kernels/binary.metal +109 -0
  345. data/mlx/mlx/backend/metal/kernels/binary_ops.h +330 -0
  346. data/mlx/mlx/backend/metal/kernels/binary_two.h +244 -0
  347. data/mlx/mlx/backend/metal/kernels/binary_two.metal +54 -0
  348. data/mlx/mlx/backend/metal/kernels/cexpf.h +134 -0
  349. data/mlx/mlx/backend/metal/kernels/complex.h +173 -0
  350. data/mlx/mlx/backend/metal/kernels/conv.metal +701 -0
  351. data/mlx/mlx/backend/metal/kernels/copy.h +276 -0
  352. data/mlx/mlx/backend/metal/kernels/copy.metal +75 -0
  353. data/mlx/mlx/backend/metal/kernels/defines.h +24 -0
  354. data/mlx/mlx/backend/metal/kernels/erf.h +69 -0
  355. data/mlx/mlx/backend/metal/kernels/expm1f.h +90 -0
  356. data/mlx/mlx/backend/metal/kernels/fence.metal +52 -0
  357. data/mlx/mlx/backend/metal/kernels/fft/radix.h +328 -0
  358. data/mlx/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
  359. data/mlx/mlx/backend/metal/kernels/fft.h +486 -0
  360. data/mlx/mlx/backend/metal/kernels/fft.metal +67 -0
  361. data/mlx/mlx/backend/metal/kernels/fp4.h +48 -0
  362. data/mlx/mlx/backend/metal/kernels/fp8.h +80 -0
  363. data/mlx/mlx/backend/metal/kernels/fp_quantized.h +1850 -0
  364. data/mlx/mlx/backend/metal/kernels/fp_quantized.metal +153 -0
  365. data/mlx/mlx/backend/metal/kernels/fp_quantized_nax.h +1044 -0
  366. data/mlx/mlx/backend/metal/kernels/fp_quantized_nax.metal +79 -0
  367. data/mlx/mlx/backend/metal/kernels/gemv.metal +868 -0
  368. data/mlx/mlx/backend/metal/kernels/gemv_masked.h +827 -0
  369. data/mlx/mlx/backend/metal/kernels/gemv_masked.metal +76 -0
  370. data/mlx/mlx/backend/metal/kernels/hadamard.h +182 -0
  371. data/mlx/mlx/backend/metal/kernels/indexing/gather.h +51 -0
  372. data/mlx/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
  373. data/mlx/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
  374. data/mlx/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
  375. data/mlx/mlx/backend/metal/kernels/indexing/masked_scatter.h +41 -0
  376. data/mlx/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
  377. data/mlx/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
  378. data/mlx/mlx/backend/metal/kernels/layer_norm.metal +433 -0
  379. data/mlx/mlx/backend/metal/kernels/logging.h +26 -0
  380. data/mlx/mlx/backend/metal/kernels/logsumexp.h +140 -0
  381. data/mlx/mlx/backend/metal/kernels/logsumexp.metal +18 -0
  382. data/mlx/mlx/backend/metal/kernels/quantized.h +2508 -0
  383. data/mlx/mlx/backend/metal/kernels/quantized.metal +144 -0
  384. data/mlx/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
  385. data/mlx/mlx/backend/metal/kernels/quantized_nax.metal +106 -0
  386. data/mlx/mlx/backend/metal/kernels/quantized_utils.h +90 -0
  387. data/mlx/mlx/backend/metal/kernels/random.metal +103 -0
  388. data/mlx/mlx/backend/metal/kernels/reduce.h +5 -0
  389. data/mlx/mlx/backend/metal/kernels/reduce.metal +169 -0
  390. data/mlx/mlx/backend/metal/kernels/reduce_utils.h +6 -0
  391. data/mlx/mlx/backend/metal/kernels/reduction/ops.h +275 -0
  392. data/mlx/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
  393. data/mlx/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
  394. data/mlx/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
  395. data/mlx/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
  396. data/mlx/mlx/backend/metal/kernels/rms_norm.metal +391 -0
  397. data/mlx/mlx/backend/metal/kernels/rope.metal +229 -0
  398. data/mlx/mlx/backend/metal/kernels/scaled_dot_product_attention.metal +44 -0
  399. data/mlx/mlx/backend/metal/kernels/scan.h +514 -0
  400. data/mlx/mlx/backend/metal/kernels/scan.metal +109 -0
  401. data/mlx/mlx/backend/metal/kernels/sdpa_vector.h +394 -0
  402. data/mlx/mlx/backend/metal/kernels/softmax.h +190 -0
  403. data/mlx/mlx/backend/metal/kernels/softmax.metal +24 -0
  404. data/mlx/mlx/backend/metal/kernels/sort.h +719 -0
  405. data/mlx/mlx/backend/metal/kernels/sort.metal +80 -0
  406. data/mlx/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
  407. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +471 -0
  408. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal +27 -0
  409. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
  410. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.metal +28 -0
  411. data/mlx/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
  412. data/mlx/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
  413. data/mlx/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
  414. data/mlx/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
  415. data/mlx/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
  416. data/mlx/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
  417. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
  418. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.metal +56 -0
  419. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
  420. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.metal +47 -0
  421. data/mlx/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
  422. data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
  423. data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
  424. data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
  425. data/mlx/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
  426. data/mlx/mlx/backend/metal/kernels/steel/defines.h +7 -0
  427. data/mlx/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
  428. data/mlx/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +157 -0
  429. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
  430. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.metal +34 -0
  431. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +219 -0
  432. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.metal +30 -0
  433. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
  434. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.metal +59 -0
  435. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +143 -0
  436. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.metal +37 -0
  437. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
  438. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.metal +76 -0
  439. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
  440. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.metal +43 -0
  441. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
  442. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.metal +76 -0
  443. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.h +152 -0
  444. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.metal +30 -0
  445. data/mlx/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
  446. data/mlx/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
  447. data/mlx/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
  448. data/mlx/mlx/backend/metal/kernels/steel/gemm/params.h +65 -0
  449. data/mlx/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
  450. data/mlx/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
  451. data/mlx/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
  452. data/mlx/mlx/backend/metal/kernels/steel/utils.h +42 -0
  453. data/mlx/mlx/backend/metal/kernels/ternary.h +145 -0
  454. data/mlx/mlx/backend/metal/kernels/ternary.metal +48 -0
  455. data/mlx/mlx/backend/metal/kernels/ternary_ops.h +10 -0
  456. data/mlx/mlx/backend/metal/kernels/unary.h +63 -0
  457. data/mlx/mlx/backend/metal/kernels/unary.metal +115 -0
  458. data/mlx/mlx/backend/metal/kernels/unary_ops.h +454 -0
  459. data/mlx/mlx/backend/metal/kernels/utils.h +445 -0
  460. data/mlx/mlx/backend/metal/kernels.h +375 -0
  461. data/mlx/mlx/backend/metal/logsumexp.cpp +95 -0
  462. data/mlx/mlx/backend/metal/make_compiled_preamble.sh +120 -0
  463. data/mlx/mlx/backend/metal/matmul.cpp +2572 -0
  464. data/mlx/mlx/backend/metal/matmul.h +144 -0
  465. data/mlx/mlx/backend/metal/metal.cpp +50 -0
  466. data/mlx/mlx/backend/metal/metal.h +25 -0
  467. data/mlx/mlx/backend/metal/no_metal.cpp +42 -0
  468. data/mlx/mlx/backend/metal/nojit_kernels.cpp +414 -0
  469. data/mlx/mlx/backend/metal/normalization.cpp +433 -0
  470. data/mlx/mlx/backend/metal/primitives.cpp +242 -0
  471. data/mlx/mlx/backend/metal/quantized.cpp +1651 -0
  472. data/mlx/mlx/backend/metal/reduce.cpp +1038 -0
  473. data/mlx/mlx/backend/metal/reduce.h +41 -0
  474. data/mlx/mlx/backend/metal/resident.cpp +100 -0
  475. data/mlx/mlx/backend/metal/resident.h +32 -0
  476. data/mlx/mlx/backend/metal/rope.cpp +165 -0
  477. data/mlx/mlx/backend/metal/scaled_dot_product_attention.cpp +798 -0
  478. data/mlx/mlx/backend/metal/scan.cpp +145 -0
  479. data/mlx/mlx/backend/metal/scan.h +17 -0
  480. data/mlx/mlx/backend/metal/slicing.cpp +99 -0
  481. data/mlx/mlx/backend/metal/softmax.cpp +87 -0
  482. data/mlx/mlx/backend/metal/sort.cpp +368 -0
  483. data/mlx/mlx/backend/metal/ternary.cpp +160 -0
  484. data/mlx/mlx/backend/metal/ternary.h +21 -0
  485. data/mlx/mlx/backend/metal/unary.cpp +161 -0
  486. data/mlx/mlx/backend/metal/unary.h +21 -0
  487. data/mlx/mlx/backend/metal/utils.cpp +77 -0
  488. data/mlx/mlx/backend/metal/utils.h +99 -0
  489. data/mlx/mlx/backend/no_cpu/CMakeLists.txt +7 -0
  490. data/mlx/mlx/backend/no_cpu/compiled.cpp +24 -0
  491. data/mlx/mlx/backend/no_cpu/device_info.cpp +22 -0
  492. data/mlx/mlx/backend/no_cpu/primitives.cpp +146 -0
  493. data/mlx/mlx/backend/no_gpu/CMakeLists.txt +8 -0
  494. data/mlx/mlx/backend/no_gpu/allocator.cpp +134 -0
  495. data/mlx/mlx/backend/no_gpu/apple_memory.h +16 -0
  496. data/mlx/mlx/backend/no_gpu/device_info.cpp +22 -0
  497. data/mlx/mlx/backend/no_gpu/eval.cpp +24 -0
  498. data/mlx/mlx/backend/no_gpu/event.cpp +53 -0
  499. data/mlx/mlx/backend/no_gpu/fence.cpp +54 -0
  500. data/mlx/mlx/backend/no_gpu/linux_memory.h +22 -0
  501. data/mlx/mlx/backend/no_gpu/primitives.cpp +185 -0
  502. data/mlx/mlx/compile.cpp +1243 -0
  503. data/mlx/mlx/compile.h +45 -0
  504. data/mlx/mlx/compile_impl.h +70 -0
  505. data/mlx/mlx/device.cpp +72 -0
  506. data/mlx/mlx/device.h +56 -0
  507. data/mlx/mlx/distributed/CMakeLists.txt +14 -0
  508. data/mlx/mlx/distributed/distributed.cpp +197 -0
  509. data/mlx/mlx/distributed/distributed.h +61 -0
  510. data/mlx/mlx/distributed/distributed_impl.h +59 -0
  511. data/mlx/mlx/distributed/jaccl/CMakeLists.txt +12 -0
  512. data/mlx/mlx/distributed/jaccl/jaccl.cpp +178 -0
  513. data/mlx/mlx/distributed/jaccl/jaccl.h +12 -0
  514. data/mlx/mlx/distributed/jaccl/mesh.cpp +451 -0
  515. data/mlx/mlx/distributed/jaccl/mesh.h +122 -0
  516. data/mlx/mlx/distributed/jaccl/no_jaccl.cpp +20 -0
  517. data/mlx/mlx/distributed/jaccl/ring.cpp +692 -0
  518. data/mlx/mlx/distributed/jaccl/ring.h +178 -0
  519. data/mlx/mlx/distributed/jaccl/utils.cpp +329 -0
  520. data/mlx/mlx/distributed/jaccl/utils.h +342 -0
  521. data/mlx/mlx/distributed/mpi/CMakeLists.txt +5 -0
  522. data/mlx/mlx/distributed/mpi/mpi.cpp +501 -0
  523. data/mlx/mlx/distributed/mpi/mpi.h +12 -0
  524. data/mlx/mlx/distributed/mpi/mpi_declarations.h +28 -0
  525. data/mlx/mlx/distributed/mpi/no_mpi.cpp +20 -0
  526. data/mlx/mlx/distributed/nccl/CMakeLists.txt +26 -0
  527. data/mlx/mlx/distributed/nccl/nccl.cpp +443 -0
  528. data/mlx/mlx/distributed/nccl/nccl.h +12 -0
  529. data/mlx/mlx/distributed/nccl/nccl_stub/CMakeLists.txt +1 -0
  530. data/mlx/mlx/distributed/nccl/nccl_stub/nccl_stubs.cpp +54 -0
  531. data/mlx/mlx/distributed/nccl/no_nccl.cpp +20 -0
  532. data/mlx/mlx/distributed/ops.cpp +186 -0
  533. data/mlx/mlx/distributed/ops.h +57 -0
  534. data/mlx/mlx/distributed/primitives.cpp +95 -0
  535. data/mlx/mlx/distributed/primitives.h +156 -0
  536. data/mlx/mlx/distributed/reduction_ops.h +38 -0
  537. data/mlx/mlx/distributed/ring/CMakeLists.txt +5 -0
  538. data/mlx/mlx/distributed/ring/no_ring.cpp +20 -0
  539. data/mlx/mlx/distributed/ring/ring.cpp +870 -0
  540. data/mlx/mlx/distributed/ring/ring.h +12 -0
  541. data/mlx/mlx/distributed/utils.cpp +206 -0
  542. data/mlx/mlx/distributed/utils.h +67 -0
  543. data/mlx/mlx/dtype.cpp +197 -0
  544. data/mlx/mlx/dtype.h +116 -0
  545. data/mlx/mlx/dtype_utils.cpp +42 -0
  546. data/mlx/mlx/dtype_utils.h +119 -0
  547. data/mlx/mlx/einsum.cpp +941 -0
  548. data/mlx/mlx/einsum.h +23 -0
  549. data/mlx/mlx/event.h +58 -0
  550. data/mlx/mlx/export.cpp +1130 -0
  551. data/mlx/mlx/export.h +137 -0
  552. data/mlx/mlx/export_impl.h +99 -0
  553. data/mlx/mlx/fast.cpp +941 -0
  554. data/mlx/mlx/fast.h +103 -0
  555. data/mlx/mlx/fast_primitives.h +427 -0
  556. data/mlx/mlx/fence.h +39 -0
  557. data/mlx/mlx/fft.cpp +262 -0
  558. data/mlx/mlx/fft.h +159 -0
  559. data/mlx/mlx/graph_utils.cpp +175 -0
  560. data/mlx/mlx/graph_utils.h +67 -0
  561. data/mlx/mlx/io/CMakeLists.txt +25 -0
  562. data/mlx/mlx/io/gguf.cpp +470 -0
  563. data/mlx/mlx/io/gguf.h +20 -0
  564. data/mlx/mlx/io/gguf_quants.cpp +164 -0
  565. data/mlx/mlx/io/load.cpp +397 -0
  566. data/mlx/mlx/io/load.h +175 -0
  567. data/mlx/mlx/io/no_gguf.cpp +20 -0
  568. data/mlx/mlx/io/no_safetensors.cpp +37 -0
  569. data/mlx/mlx/io/safetensors.cpp +234 -0
  570. data/mlx/mlx/io.h +61 -0
  571. data/mlx/mlx/linalg.cpp +708 -0
  572. data/mlx/mlx/linalg.h +115 -0
  573. data/mlx/mlx/memory.h +80 -0
  574. data/mlx/mlx/mlx.h +25 -0
  575. data/mlx/mlx/ops.cpp +6094 -0
  576. data/mlx/mlx/ops.h +1610 -0
  577. data/mlx/mlx/primitives.cpp +5850 -0
  578. data/mlx/mlx/primitives.h +2525 -0
  579. data/mlx/mlx/random.cpp +492 -0
  580. data/mlx/mlx/random.h +283 -0
  581. data/mlx/mlx/scheduler.cpp +73 -0
  582. data/mlx/mlx/scheduler.h +189 -0
  583. data/mlx/mlx/small_vector.h +540 -0
  584. data/mlx/mlx/stream.h +42 -0
  585. data/mlx/mlx/threadpool.h +133 -0
  586. data/mlx/mlx/transforms.cpp +1065 -0
  587. data/mlx/mlx/transforms.h +231 -0
  588. data/mlx/mlx/transforms_impl.h +88 -0
  589. data/mlx/mlx/types/bf16.h +187 -0
  590. data/mlx/mlx/types/complex.h +113 -0
  591. data/mlx/mlx/types/fp16.h +234 -0
  592. data/mlx/mlx/types/half_types.h +58 -0
  593. data/mlx/mlx/types/limits.h +70 -0
  594. data/mlx/mlx/utils.cpp +302 -0
  595. data/mlx/mlx/utils.h +174 -0
  596. data/mlx/mlx/version.cpp +11 -0
  597. data/mlx/mlx/version.h +22 -0
  598. data/mlx/mlx.pc.in +52 -0
  599. metadata +643 -0
@@ -0,0 +1,816 @@
1
+ // Copyright © 2023-2024 Apple Inc.
2
+
3
+ #include <cstdlib>
4
+ #include <sstream>
5
+
6
+ #define NS_PRIVATE_IMPLEMENTATION
7
+ #define CA_PRIVATE_IMPLEMENTATION
8
+ #define MTL_PRIVATE_IMPLEMENTATION
9
+
10
+ #include "mlx/backend/common/utils.h"
11
+ #include "mlx/backend/metal/device.h"
12
+ #include "mlx/backend/metal/metal.h"
13
+ #include "mlx/backend/metal/utils.h"
14
+ #include "mlx/utils.h"
15
+
16
+ namespace mlx::core::metal {
17
+
18
+ namespace {
19
+
20
+ constexpr const char* default_mtllib_path = METAL_PATH;
21
+
22
+ auto get_metal_version() {
23
+ auto get_metal_version_ = []() {
24
+ if (__builtin_available(macOS 26, iOS 26, tvOS 26, visionOS 26, *)) {
25
+ return MTL::LanguageVersion4_0;
26
+ } else if (__builtin_available(macOS 15, iOS 18, tvOS 18, visionOS 2, *)) {
27
+ return MTL::LanguageVersion3_2;
28
+ } else {
29
+ return MTL::LanguageVersion3_1;
30
+ }
31
+ };
32
+ static auto metal_version_ = get_metal_version_();
33
+ return metal_version_;
34
+ }
35
+
36
+ auto load_device() {
37
+ auto devices = MTL::CopyAllDevices();
38
+ auto device = static_cast<MTL::Device*>(devices->object(0))
39
+ ?: MTL::CreateSystemDefaultDevice();
40
+ if (!device) {
41
+ throw std::runtime_error("Failed to load device");
42
+ }
43
+ return device;
44
+ }
45
+ std::pair<MTL::Library*, NS::Error*> load_library_from_path(
46
+ MTL::Device* device,
47
+ const char* path) {
48
+ auto library = NS::String::string(path, NS::UTF8StringEncoding);
49
+ NS::Error* error;
50
+ auto lib = device->newLibrary(library, &error);
51
+
52
+ return std::make_pair(lib, error);
53
+ }
54
+
55
+ #ifdef SWIFTPM_BUNDLE
56
+ MTL::Library* try_load_bundle(
57
+ MTL::Device* device,
58
+ NS::URL* url,
59
+ const std::string& lib_name) {
60
+ std::string bundle_path = std::string(url->fileSystemRepresentation()) + "/" +
61
+ SWIFTPM_BUNDLE + ".bundle";
62
+ auto bundle = NS::Bundle::alloc()->init(
63
+ NS::String::string(bundle_path.c_str(), NS::UTF8StringEncoding));
64
+ if (bundle != nullptr) {
65
+ std::string resource_path =
66
+ std::string(bundle->resourceURL()->fileSystemRepresentation()) + "/" +
67
+ lib_name + ".metallib";
68
+ auto [lib, error] = load_library_from_path(device, resource_path.c_str());
69
+ if (lib) {
70
+ return lib;
71
+ }
72
+ }
73
+ return nullptr;
74
+ }
75
+
76
+ MTL::Library* try_load_framework(
77
+ MTL::Device* device,
78
+ NS::URL* url,
79
+ const std::string& lib_name) {
80
+ std::string resource_path = std::string(url->fileSystemRepresentation()) +
81
+ "/" + lib_name + ".metallib";
82
+ auto [lib, error] = load_library_from_path(device, resource_path.c_str());
83
+ if (lib) {
84
+ return lib;
85
+ }
86
+ return nullptr;
87
+ }
88
+ #endif
89
+
90
+ // Firstly, search for the metallib in the same path as this binary
91
+ std::pair<MTL::Library*, NS::Error*> load_colocated_library(
92
+ MTL::Device* device,
93
+ const std::string& relative_path) {
94
+ auto path = current_binary_dir() / relative_path;
95
+ if (!path.has_extension()) {
96
+ path.replace_extension(".metallib");
97
+ }
98
+
99
+ return load_library_from_path(device, path.c_str());
100
+ }
101
+
102
+ std::pair<MTL::Library*, NS::Error*> load_swiftpm_library(
103
+ MTL::Device* device,
104
+ const std::string& lib_name) {
105
+ #ifdef SWIFTPM_BUNDLE
106
+ MTL::Library* library =
107
+ try_load_bundle(device, NS::Bundle::mainBundle()->bundleURL(), lib_name);
108
+ if (library != nullptr) {
109
+ return {library, nullptr};
110
+ }
111
+ auto bundles = NS::Bundle::allBundles();
112
+ for (int i = 0, c = (int)bundles->count(); i < c; i++) {
113
+ auto bundle = reinterpret_cast<NS::Bundle*>(bundles->object(i));
114
+ library = try_load_bundle(device, bundle->resourceURL(), lib_name);
115
+ if (library != nullptr) {
116
+ return {library, nullptr};
117
+ }
118
+ }
119
+ // if SWIFTPM_BUNDLE is a framework identifier, try loading from that
120
+ auto frameworks = NS::Bundle::allFrameworks();
121
+ for (int i = 0, c = (int)frameworks->count(); i < c; i++) {
122
+ const auto bundle = reinterpret_cast<NS::Bundle*>(frameworks->object(i));
123
+ const auto identifier = bundle->bundleIdentifier();
124
+ if (identifier != nullptr &&
125
+ !strcmp(identifier->utf8String(), SWIFTPM_BUNDLE)) {
126
+ library = try_load_framework(device, bundle->resourceURL(), lib_name);
127
+ if (library != nullptr) {
128
+ return {library, nullptr};
129
+ }
130
+ }
131
+ }
132
+ #endif
133
+ return {nullptr, nullptr};
134
+ }
135
+
136
+ MTL::Library* load_default_library(MTL::Device* device) {
137
+ NS::Error* error[5];
138
+ MTL::Library* lib;
139
+ // First try the colocated mlx.metallib
140
+ std::tie(lib, error[0]) = load_colocated_library(device, "mlx");
141
+ if (lib) {
142
+ return lib;
143
+ }
144
+
145
+ std::tie(lib, error[1]) = load_colocated_library(device, "Resources/mlx");
146
+ if (lib) {
147
+ return lib;
148
+ }
149
+
150
+ // Then try default.metallib in a SwiftPM bundle if we have one
151
+ std::tie(lib, error[2]) = load_swiftpm_library(device, "default");
152
+ if (lib) {
153
+ return lib;
154
+ }
155
+
156
+ // Try lo load resources from Framework resources if SwiftPM wrapped as a
157
+ // dynamic framework.
158
+ std::tie(lib, error[3]) = load_colocated_library(device, "Resources/default");
159
+ if (lib) {
160
+ return lib;
161
+ }
162
+
163
+ // Finally try default_mtllib_path
164
+ std::tie(lib, error[4]) = load_library_from_path(device, default_mtllib_path);
165
+ if (!lib) {
166
+ std::ostringstream msg;
167
+ msg << "Failed to load the default metallib. ";
168
+ for (int i = 0; i < 5; i++) {
169
+ if (error[i] != nullptr) {
170
+ msg << error[i]->localizedDescription()->utf8String() << " ";
171
+ }
172
+ }
173
+ throw std::runtime_error(msg.str());
174
+ }
175
+ return lib;
176
+ }
177
+
178
+ MTL::Library* load_library(
179
+ MTL::Device* device,
180
+ const std::string& lib_name,
181
+ const std::string& lib_path) {
182
+ // We have been given a path that ends in metallib so try to load it
183
+ if (lib_path.size() > 9 &&
184
+ std::equal(lib_path.end() - 9, lib_path.end(), ".metallib")) {
185
+ auto [lib, error] = load_library_from_path(device, lib_path.c_str());
186
+ if (!lib) {
187
+ std::ostringstream msg;
188
+ msg << "Failed to load the metallib from <" << lib_path << "> with error "
189
+ << error->localizedDescription()->utf8String();
190
+ throw std::runtime_error(msg.str());
191
+ }
192
+ return lib;
193
+ }
194
+
195
+ // We have been given a path so try to load from lib_path / lib_name.metallib
196
+ if (lib_path.size() > 0) {
197
+ std::string full_path = lib_path + "/" + lib_name + ".metallib";
198
+ auto [lib, error] = load_library_from_path(device, full_path.c_str());
199
+ if (!lib) {
200
+ std::ostringstream msg;
201
+ msg << "Failed to load the metallib from <" << full_path
202
+ << "> with error " << error->localizedDescription()->utf8String();
203
+ throw std::runtime_error(msg.str());
204
+ }
205
+ return lib;
206
+ }
207
+
208
+ // Try to load the colocated library
209
+ {
210
+ auto [lib, error] = load_colocated_library(device, lib_name);
211
+ if (lib) {
212
+ return lib;
213
+ }
214
+ }
215
+
216
+ // Try to load the library from swiftpm
217
+ {
218
+ auto [lib, error] = load_swiftpm_library(device, lib_name);
219
+ if (lib) {
220
+ return lib;
221
+ }
222
+ }
223
+
224
+ std::ostringstream msg;
225
+ msg << "Failed to load the metallib " << lib_name << ".metallib. "
226
+ << "We attempted to load it from <" << current_binary_dir() << "/"
227
+ << lib_name << ".metallib>";
228
+ #ifdef SWIFTPM_BUNDLE
229
+ msg << " and from the Swift PM bundle.";
230
+ #endif
231
+ throw std::runtime_error(msg.str());
232
+ }
233
+
234
+ } // namespace
235
+
236
+ CommandEncoder::CommandEncoder(DeviceStream& stream) : stream_(stream) {
237
+ enc_ = stream_.buffer->computeCommandEncoder(MTL::DispatchTypeConcurrent);
238
+ enc_->retain();
239
+ }
240
+
241
+ CommandEncoder::~CommandEncoder() {
242
+ enc_->endEncoding();
243
+ enc_->release();
244
+ }
245
+
246
+ void CommandEncoder::set_buffer(
247
+ const MTL::Buffer* buf,
248
+ int idx,
249
+ int64_t offset /* = 0 */) {
250
+ enc_->setBuffer(buf, offset, idx);
251
+ }
252
+
253
+ void CommandEncoder::set_input_array(
254
+ const array& a,
255
+ int idx,
256
+ int64_t offset /* = 0 */) {
257
+ if (all_inputs_.insert(a.buffer().ptr()).second) {
258
+ stream_.buffer_sizes += a.data_size();
259
+ }
260
+ auto r_buf = static_cast<MTL::Resource*>(const_cast<void*>(a.buffer().ptr()));
261
+ needs_barrier_ =
262
+ needs_barrier_ | (prev_outputs_.find(r_buf) != prev_outputs_.end());
263
+ auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
264
+ enc_->setBuffer(a_buf, a.offset() + offset, idx);
265
+ }
266
+
267
+ void CommandEncoder::set_output_array(
268
+ array& a,
269
+ int idx,
270
+ int64_t offset /* = 0 */) {
271
+ // Add barriers before adding the output to the output set
272
+ set_input_array(a, idx, offset);
273
+ register_output_array(a);
274
+ }
275
+
276
+ void CommandEncoder::register_output_array(const array& a) {
277
+ all_outputs_.insert(a.buffer().ptr());
278
+
279
+ auto buf = static_cast<MTL::Resource*>(const_cast<void*>(a.buffer().ptr()));
280
+ if (concurrent_) {
281
+ concurrent_outputs_.insert(buf);
282
+ } else {
283
+ next_outputs_.insert(buf);
284
+ }
285
+ }
286
+
287
+ void CommandEncoder::maybeInsertBarrier() {
288
+ if (needs_barrier_) {
289
+ enc_->memoryBarrier(MTL::BarrierScopeBuffers);
290
+ needs_barrier_ = false;
291
+ prev_outputs_ = std::move(next_outputs_);
292
+ } else {
293
+ prev_outputs_.insert(next_outputs_.begin(), next_outputs_.end());
294
+ }
295
+ next_outputs_.clear();
296
+ }
297
+
298
+ void CommandEncoder::dispatch_threadgroups(
299
+ MTL::Size grid_dims,
300
+ MTL::Size group_dims) {
301
+ maybeInsertBarrier();
302
+ stream_.buffer_ops++;
303
+ enc_->dispatchThreadgroups(grid_dims, group_dims);
304
+ }
305
+
306
+ void CommandEncoder::dispatch_threads(
307
+ MTL::Size grid_dims,
308
+ MTL::Size group_dims) {
309
+ maybeInsertBarrier();
310
+ stream_.buffer_ops++;
311
+ enc_->dispatchThreads(grid_dims, group_dims);
312
+ }
313
+
314
+ void CommandEncoder::barrier() {
315
+ enc_->memoryBarrier(MTL::BarrierScopeBuffers);
316
+ }
317
+
318
+ Device::Device() {
319
+ auto pool = new_scoped_memory_pool();
320
+ device_ = load_device();
321
+ default_library_ = load_default_library(device_);
322
+ arch_ = std::string(device_->architecture()->name()->utf8String());
323
+ int ag_tens = arch_[arch_.size() - 3] - '0';
324
+ int ag_ones = arch_[arch_.size() - 2] - '0';
325
+ arch_gen_ = ag_tens * 10 + ag_ones;
326
+ auto arch = arch_.back();
327
+ switch (arch) {
328
+ case 'p': // phone
329
+ max_ops_per_buffer_ = 20;
330
+ max_mb_per_buffer_ = 40;
331
+ break;
332
+ case 'g': // base, pro
333
+ max_ops_per_buffer_ = 40;
334
+ max_mb_per_buffer_ = 40;
335
+ break;
336
+ case 's': // max
337
+ max_ops_per_buffer_ = 50;
338
+ max_mb_per_buffer_ = 50;
339
+ break;
340
+ case 'd': // ultra
341
+ max_ops_per_buffer_ = 50;
342
+ max_mb_per_buffer_ = 50;
343
+ break;
344
+ default: // default to medium
345
+ max_ops_per_buffer_ = 40;
346
+ max_mb_per_buffer_ = 40;
347
+ break;
348
+ }
349
+ max_ops_per_buffer_ = env::max_ops_per_buffer(max_ops_per_buffer_);
350
+ max_mb_per_buffer_ = env::max_mb_per_buffer(max_mb_per_buffer_);
351
+ }
352
+
353
+ Device::~Device() {
354
+ auto pool = new_scoped_memory_pool();
355
+ for (auto& [l, kernel_map] : library_kernels_) {
356
+ l->release();
357
+ for (auto& [_, k] : kernel_map) {
358
+ k->release();
359
+ }
360
+ }
361
+ stream_map_.clear();
362
+ device_->release();
363
+ }
364
+
365
+ void Device::new_queue(int index) {
366
+ auto thread_pool = metal::new_scoped_memory_pool();
367
+ auto q = device_->newCommandQueue();
368
+ debug_set_stream_queue_label(q, index);
369
+ if (!q) {
370
+ throw std::runtime_error(
371
+ "[metal::Device] Failed to make new command queue.");
372
+ }
373
+ stream_map_.emplace(index, q);
374
+ if (residency_set_ != nullptr) {
375
+ q->addResidencySet(residency_set_);
376
+ }
377
+ }
378
+
379
+ MTL::CommandQueue* Device::get_queue(Stream stream) {
380
+ return get_stream_(stream.index).queue;
381
+ }
382
+
383
+ bool Device::command_buffer_needs_commit(int index) {
384
+ auto& stream = get_stream_(index);
385
+ return (stream.buffer_ops > max_ops_per_buffer_) ||
386
+ ((stream.buffer_sizes >> 20) > max_mb_per_buffer_);
387
+ }
388
+
389
+ MTL::CommandBuffer* Device::get_command_buffer(int index) {
390
+ auto& stream = get_stream_(index);
391
+ if (stream.buffer == nullptr) {
392
+ stream.buffer = stream.queue->commandBufferWithUnretainedReferences();
393
+ if (!stream.buffer) {
394
+ throw std::runtime_error(
395
+ "[metal::Device] Unable to create new command buffer");
396
+ }
397
+ // Increment ref count so the buffer is not garbage collected
398
+ stream.buffer->retain();
399
+ }
400
+ return stream.buffer;
401
+ }
402
+
403
+ void Device::commit_command_buffer(int index) {
404
+ auto& stream = get_stream_(index);
405
+ stream.buffer->commit();
406
+ stream.buffer->release();
407
+ stream.buffer = nullptr;
408
+ stream.buffer_ops = 0;
409
+ stream.buffer_sizes = 0;
410
+ }
411
+
412
+ void Device::add_temporary(array arr, int index) {
413
+ get_stream_(index).temporaries.push_back(std::move(arr));
414
+ }
415
+
416
+ void Device::add_temporaries(std::vector<array> arrays, int index) {
417
+ if (arrays.empty()) {
418
+ return;
419
+ }
420
+ auto& stream = get_stream_(index);
421
+ stream.temporaries.insert(
422
+ stream.temporaries.end(),
423
+ std::make_move_iterator(arrays.begin()),
424
+ std::make_move_iterator(arrays.end()));
425
+ }
426
+
427
+ void Device::end_encoding(int index) {
428
+ auto& stream = get_stream_(index);
429
+ if (stream.encoder != nullptr) {
430
+ // Each command encoder has a unique fence. We also store a map of
431
+ // all previous outputs of command encoders to their corresponding fence.
432
+ // - The command encoder records its inputs and outputs.
433
+ // - Wait on a fence if any inputs in the encoder are outputs of a previous
434
+ // encoder.
435
+ // - Update the map of outputs to include this command encoder's outputs.
436
+ // - Always signal this command encoders fence.
437
+ // - Add a completion handler for this command encoder that removes outputs
438
+ // from the map to limit the growth of the map and avoid unnecessary waits
439
+ // - Temporaries are a special case as they do not cross command encoder
440
+ // boundaries. These can be removed early from the encoders inputs and
441
+ // outputs since they don't need synchronization.
442
+ auto& enc = *stream.encoder;
443
+ // Remove temporaries from inputs and outputs
444
+ for (auto& t : stream.temporaries) {
445
+ enc.outputs().erase(t.buffer().ptr());
446
+ enc.inputs().erase(t.buffer().ptr());
447
+ }
448
+
449
+ // Keep references to the fences we waited on and put them
450
+ // in the completion handler so they are not prematurely released
451
+ std::unordered_set<std::shared_ptr<Fence>> waiting_on;
452
+ {
453
+ std::lock_guard<std::mutex> lk(stream.fence_mtx);
454
+ for (auto in : enc.inputs()) {
455
+ if (auto it = stream.outputs.find(in); it != stream.outputs.end()) {
456
+ // If we've already waited on a fence, don't wait on it again.
457
+ if (waiting_on.find(it->second) == waiting_on.end()) {
458
+ enc.wait_for_fence(it->second->fence);
459
+ waiting_on.insert(it->second);
460
+ }
461
+ }
462
+ }
463
+ for (auto out : enc.outputs()) {
464
+ stream.outputs[out] = stream.fence;
465
+ }
466
+ }
467
+ enc.update_fence(stream.fence->fence);
468
+ stream.buffer->addCompletedHandler(
469
+ [&stream,
470
+ waiting_on = std::move(waiting_on),
471
+ fence = std::move(stream.fence),
472
+ outputs = std::move(enc.outputs()),
473
+ temporaries =
474
+ std::move(stream.temporaries)](MTL::CommandBuffer*) mutable {
475
+ temporaries.clear();
476
+ std::lock_guard<std::mutex> lk(stream.fence_mtx);
477
+ for (auto o : outputs) {
478
+ if (auto it = stream.outputs.find(o); it != stream.outputs.end()) {
479
+ if (it->second == fence) {
480
+ stream.outputs.erase(it);
481
+ }
482
+ }
483
+ }
484
+ });
485
+ }
486
+ stream.encoder = nullptr;
487
+ }
488
+
489
+ CommandEncoder& Device::get_command_encoder(int index) {
490
+ auto& stream = get_stream_(index);
491
+ if (stream.encoder == nullptr) {
492
+ // Ensure there is an active command buffer
493
+ if (stream.buffer == nullptr) {
494
+ get_command_buffer(index);
495
+ }
496
+ stream.encoder = std::make_unique<CommandEncoder>(stream);
497
+ stream.fence = std::make_shared<Fence>(device_->newFence());
498
+ }
499
+ return *stream.encoder;
500
+ }
501
+
502
+ MTL::Library* Device::get_library(
503
+ const std::string& name,
504
+ const std::string& path /* = "" */) {
505
+ {
506
+ std::shared_lock rlock(library_mtx_);
507
+ if (auto it = library_map_.find(name); it != library_map_.end()) {
508
+ return it->second;
509
+ }
510
+ }
511
+
512
+ std::unique_lock wlock(library_mtx_);
513
+ if (auto it = library_map_.find(name); it != library_map_.end()) {
514
+ return it->second;
515
+ }
516
+
517
+ auto new_lib = load_library(device_, name, path.c_str());
518
+ library_map_.insert({name, new_lib});
519
+ return new_lib;
520
+ }
521
+
522
+ MTL::Library* Device::build_library_(const std::string& source_string) {
523
+ auto pool = new_scoped_memory_pool();
524
+
525
+ auto ns_code =
526
+ NS::String::string(source_string.c_str(), NS::ASCIIStringEncoding);
527
+
528
+ NS::Error* error = nullptr;
529
+ auto options = MTL::CompileOptions::alloc()->init();
530
+ options->setFastMathEnabled(false);
531
+ options->setLanguageVersion(get_metal_version());
532
+ #ifndef NDEBUG
533
+ if (options->languageVersion() >= MTL::LanguageVersion3_2) {
534
+ options->setEnableLogging(true);
535
+ }
536
+ #endif
537
+ auto mtl_lib = device_->newLibrary(ns_code, options, &error);
538
+ options->release();
539
+
540
+ // Throw error if unable to compile library
541
+ if (!mtl_lib) {
542
+ std::ostringstream msg;
543
+ msg << "[metal::Device] Unable to build metal library from source\n";
544
+ if (error) {
545
+ msg << error->localizedDescription()->utf8String() << "\n";
546
+ }
547
+ throw std::runtime_error(msg.str());
548
+ }
549
+
550
+ return mtl_lib;
551
+ }
552
+
553
+ MTL::Function* Device::get_function_(
554
+ const std::string& name,
555
+ MTL::Library* mtl_lib) {
556
+ // Pull kernel from library
557
+ auto ns_name = NS::String::string(name.c_str(), NS::ASCIIStringEncoding);
558
+ auto mtl_function = mtl_lib->newFunction(ns_name);
559
+
560
+ return mtl_function;
561
+ }
562
+
563
+ MTL::Function* Device::get_function_(
564
+ const std::string& name,
565
+ const std::string& specialized_name,
566
+ const MTLFCList& func_consts,
567
+ MTL::Library* mtl_lib) {
568
+ if (func_consts.empty() && (specialized_name == name)) {
569
+ return get_function_(name, mtl_lib);
570
+ }
571
+
572
+ // Prepare function constants
573
+ auto mtl_func_consts = MTL::FunctionConstantValues::alloc()->init();
574
+
575
+ for (auto [value, type, index] : func_consts) {
576
+ mtl_func_consts->setConstantValue(value, type, index);
577
+ }
578
+
579
+ // Prepare function desc
580
+ auto desc = MTL::FunctionDescriptor::functionDescriptor();
581
+ desc->setName(NS::String::string(name.c_str(), NS::ASCIIStringEncoding));
582
+ desc->setSpecializedName(
583
+ NS::String::string(specialized_name.c_str(), NS::ASCIIStringEncoding));
584
+ desc->setConstantValues(mtl_func_consts);
585
+
586
+ // Pull kernel from library
587
+ NS::Error* error = nullptr;
588
+ auto mtl_function = mtl_lib->newFunction(desc, &error);
589
+
590
+ // Throw error if unable to build metal function
591
+ if (!mtl_function) {
592
+ std::ostringstream msg;
593
+ msg << "[metal::Device] Unable to load function " << name << "\n";
594
+ if (error) {
595
+ msg << error->localizedDescription()->utf8String() << "\n";
596
+ }
597
+ throw std::runtime_error(msg.str());
598
+ }
599
+
600
+ mtl_func_consts->release();
601
+
602
+ return mtl_function;
603
+ }
604
+
605
+ MTL::ComputePipelineState* Device::get_kernel_(
606
+ const std::string& name,
607
+ const MTL::Function* mtl_function) {
608
+ // Compile kernel to compute pipeline
609
+ NS::Error* error = nullptr;
610
+ MTL::ComputePipelineState* kernel;
611
+
612
+ if (mtl_function) {
613
+ kernel = device_->newComputePipelineState(mtl_function, &error);
614
+ }
615
+
616
+ // Throw error if unable to compile metal function
617
+ if (!mtl_function || !kernel) {
618
+ std::ostringstream msg;
619
+ msg << "[metal::Device] Unable to load kernel " << name << "\n";
620
+ if (error) {
621
+ msg << error->localizedDescription()->utf8String() << "\n";
622
+ }
623
+ throw std::runtime_error(msg.str());
624
+ }
625
+
626
+ return kernel;
627
+ }
628
+
629
+ MTL::ComputePipelineState* Device::get_kernel_(
630
+ const std::string& name,
631
+ const MTL::Function* mtl_function,
632
+ const MTL::LinkedFunctions* linked_functions) {
633
+ // Check inputs
634
+ if (!linked_functions) {
635
+ return get_kernel_(name, mtl_function);
636
+ }
637
+
638
+ if (!mtl_function) {
639
+ std::ostringstream msg;
640
+ msg << "[metal::Device] Unable to load kernel " << name << "\n";
641
+ throw std::runtime_error(msg.str());
642
+ }
643
+
644
+ // Prepare compute pipeline state descriptor
645
+ auto desc = MTL::ComputePipelineDescriptor::alloc()->init();
646
+ desc->setComputeFunction(mtl_function);
647
+ desc->setLinkedFunctions(linked_functions);
648
+
649
+ // Compile kernel to compute pipeline
650
+ NS::Error* error = nullptr;
651
+ auto kernel = device_->newComputePipelineState(
652
+ desc, MTL::PipelineOptionNone, nullptr, &error);
653
+
654
+ // Throw error if unable to compile metal function
655
+ if (!kernel) {
656
+ std::ostringstream msg;
657
+ msg << "[metal::Device] Unable to load kernel " << name << "\n";
658
+ if (error) {
659
+ msg << error->localizedDescription()->utf8String() << "\n";
660
+ }
661
+ throw std::runtime_error(msg.str());
662
+ }
663
+
664
+ return kernel;
665
+ }
666
+
667
+ MTL::Library* Device::get_library_(const std::string& name) {
668
+ std::shared_lock lock(library_mtx_);
669
+ auto it = library_map_.find(name);
670
+ return (it != library_map_.end()) ? it->second : nullptr;
671
+ }
672
+
673
+ MTL::Library* Device::get_library(
674
+ const std::string& name,
675
+ const std::function<std::string(void)>& builder) {
676
+ {
677
+ std::shared_lock rlock(library_mtx_);
678
+ if (auto it = library_map_.find(name); it != library_map_.end()) {
679
+ return it->second;
680
+ }
681
+ }
682
+
683
+ std::unique_lock wlock(library_mtx_);
684
+ if (auto it = library_map_.find(name); it != library_map_.end()) {
685
+ return it->second;
686
+ }
687
+
688
+ auto mtl_lib = build_library_(builder());
689
+ library_map_.insert({name, mtl_lib});
690
+ return mtl_lib;
691
+ }
692
+
693
+ void Device::clear_library(const std::string& name) {
694
+ std::unique_lock wlock(library_mtx_);
695
+ if (auto it = library_map_.find(name); it != library_map_.end()) {
696
+ auto kernel_map_it = library_kernels_.find(it->second);
697
+ for (auto& [_, kernel] : kernel_map_it->second) {
698
+ kernel->release();
699
+ }
700
+ library_kernels_.erase(kernel_map_it);
701
+ it->second->release();
702
+ library_map_.erase(it);
703
+ }
704
+ }
705
+
706
+ MTL::LinkedFunctions* Device::get_linked_functions_(
707
+ const std::vector<MTL::Function*>& funcs) {
708
+ if (funcs.empty()) {
709
+ return nullptr;
710
+ }
711
+
712
+ auto lfuncs = MTL::LinkedFunctions::linkedFunctions();
713
+
714
+ std::vector<NS::Object*> objs(funcs.size());
715
+ for (int i = 0; i < funcs.size(); i++) {
716
+ objs[i] = funcs[i];
717
+ }
718
+
719
+ NS::Array* funcs_arr = NS::Array::array(objs.data(), funcs.size());
720
+
721
+ lfuncs->setPrivateFunctions(funcs_arr);
722
+
723
+ return lfuncs;
724
+ }
725
+
726
+ MTL::ComputePipelineState* Device::get_kernel_(
727
+ const std::string& base_name,
728
+ MTL::Library* mtl_lib,
729
+ const std::string& hash_name,
730
+ const MTLFCList& func_consts /* = {} */,
731
+ const std::vector<MTL::Function*>& linked_functions /* = {} */) {
732
+ // Single writer allowed
733
+ std::unique_lock wlock(kernel_mtx_);
734
+
735
+ // Try loading again to avoid loading twice
736
+ auto& kernel_map_ = library_kernels_[mtl_lib];
737
+ if (auto it = kernel_map_.find(hash_name); it != kernel_map_.end()) {
738
+ return it->second;
739
+ }
740
+
741
+ auto pool = new_scoped_memory_pool();
742
+
743
+ // Pull kernel from library
744
+ auto mtl_function = get_function_(base_name, hash_name, func_consts, mtl_lib);
745
+
746
+ // Compile kernel to compute pipeline
747
+ auto mtl_linked_funcs = get_linked_functions_(linked_functions);
748
+ auto kernel = get_kernel_(hash_name, mtl_function, mtl_linked_funcs);
749
+
750
+ mtl_function->release();
751
+ mtl_linked_funcs->release();
752
+
753
+ // Add kernel to cache
754
+ kernel_map_.insert({hash_name, kernel});
755
+
756
+ return kernel;
757
+ }
758
+
759
+ MTL::ComputePipelineState* Device::get_kernel(
760
+ const std::string& base_name,
761
+ MTL::Library* mtl_lib,
762
+ const std::string& hash_name /* = "" */,
763
+ const MTLFCList& func_consts /* = {} */,
764
+ const std::vector<MTL::Function*>& linked_functions /* = {} */) {
765
+ const auto& kname = hash_name.empty() ? base_name : hash_name;
766
+ {
767
+ // Multiple readers allowed
768
+ std::shared_lock lock(kernel_mtx_);
769
+
770
+ // Look for cached kernel
771
+ auto& kernel_map_ = library_kernels_[mtl_lib];
772
+ if (auto it = kernel_map_.find(kname); it != kernel_map_.end()) {
773
+ return it->second;
774
+ }
775
+ }
776
+ return get_kernel_(base_name, mtl_lib, kname, func_consts, linked_functions);
777
+ }
778
+
779
+ MTL::ComputePipelineState* Device::get_kernel(
780
+ const std::string& base_name,
781
+ const std::string& hash_name /* = "" */,
782
+ const MTLFCList& func_consts /* = {} */,
783
+ const std::vector<MTL::Function*>& linked_functions /* = {} */) {
784
+ return get_kernel(
785
+ base_name, default_library_, hash_name, func_consts, linked_functions);
786
+ }
787
+
788
+ void Device::set_residency_set(const MTL::ResidencySet* residency_set) {
789
+ if (residency_set_ != nullptr) {
790
+ throw std::runtime_error(
791
+ "[Device::set_residency_set] Can only be set once.");
792
+ }
793
+ if (residency_set == nullptr) {
794
+ return;
795
+ }
796
+ residency_set_ = residency_set;
797
+ // Attach residency set to existing command queues
798
+ for (auto& [_, stream] : stream_map_) {
799
+ stream.queue->addResidencySet(residency_set_);
800
+ }
801
+ }
802
+
803
+ Device& device(mlx::core::Device) {
804
+ static Device metal_device;
805
+ return metal_device;
806
+ }
807
+
808
+ std::unique_ptr<void, std::function<void(void*)>> new_scoped_memory_pool() {
809
+ auto dtor = [](void* ptr) {
810
+ static_cast<NS::AutoreleasePool*>(ptr)->release();
811
+ };
812
+ return std::unique_ptr<void, std::function<void(void*)>>(
813
+ NS::AutoreleasePool::alloc()->init(), dtor);
814
+ }
815
+
816
+ } // namespace mlx::core::metal