mlx 0.30.7.3 → 0.30.7.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (590)
  1. checksums.yaml +4 -4
  2. data/ext/mlx/extconf.rb +267 -8
  3. data/ext/mlx/native.cpp +104 -56
  4. data/ext/mlx-onnx/native.cpp +1402 -0
  5. data/ext/mlx-onnx/native.hpp +19 -0
  6. data/lib/mlx/core.rb +342 -117
  7. data/lib/mlx/nn/base.rb +4 -0
  8. data/lib/mlx/nn/layers/linear.rb +2 -3
  9. data/lib/mlx/onnx.rb +250 -0
  10. data/lib/mlx/version.rb +1 -1
  11. data/lib/mlx-onnx/webgpu_harness.rb +289 -0
  12. data/{mlx → submodules/mlx}/mlx/backend/cuda/cublas_utils.cpp +0 -7
  13. data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/cublas_gemm.cpp +10 -2
  14. data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/cublas_qqmm.cpp +97 -46
  15. data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/cublas_qqmm.h +25 -13
  16. data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/fp_quantize.cu +101 -38
  17. data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/no_qqmm_impl.cpp +1 -2
  18. data/submodules/mlx/mlx/backend/cuda/quantized/qqmm.cpp +193 -0
  19. data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/qqmm_impl.cpp +15 -8
  20. data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/qqmm_impl.h +14 -3
  21. data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/qqmm_utils.cu +36 -0
  22. data/submodules/mlx/mlx/backend/cuda/quantized/qqmm_utils.h +62 -0
  23. data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/quantized.cpp +12 -3
  24. data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/quantized.h +4 -0
  25. data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/quantized_utils.cuh +1 -1
  26. data/{mlx → submodules/mlx}/mlx/backend/metal/device.cpp +4 -0
  27. data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/conv.metal +3 -2
  28. data/{mlx → submodules/mlx}/mlx/export.cpp +21 -6
  29. data/{mlx → submodules/mlx}/mlx/ops.cpp +144 -13
  30. data/{mlx → submodules/mlx}/mlx/ops.h +12 -2
  31. data/{mlx → submodules/mlx}/mlx/primitives.cpp +22 -5
  32. data/{mlx → submodules/mlx}/mlx/scheduler.cpp +4 -0
  33. data/{mlx → submodules/mlx}/mlx/scheduler.h +3 -0
  34. data/{mlx → submodules/mlx}/mlx/stream.h +5 -0
  35. data/submodules/mlx-onnx/CMakeLists.txt +159 -0
  36. data/submodules/mlx-onnx/LICENSE +21 -0
  37. data/submodules/mlx-onnx/include/mlx/ir.hpp +88 -0
  38. data/submodules/mlx-onnx/src/api.cpp +81 -0
  39. data/submodules/mlx-onnx/src/compat.cpp +111 -0
  40. data/submodules/mlx-onnx/src/detail.hpp +69 -0
  41. data/submodules/mlx-onnx/src/export.cpp +653 -0
  42. data/submodules/mlx-onnx/src/io.cpp +61 -0
  43. data/submodules/mlx-onnx/src/json.hpp +25 -0
  44. data/submodules/mlx-onnx/src/lowering.cpp +6346 -0
  45. data/submodules/mlx-onnx/src/mappings.cpp +201 -0
  46. data/submodules/mlx-onnx/src/mappings.hpp +16 -0
  47. data/submodules/mlx-onnx/src/onnx.cpp +1029 -0
  48. data/submodules/mlx-onnx/src/shared.cpp +206 -0
  49. metadata +609 -563
  50. data/mlx/mlx/backend/cuda/quantized/qqmm.cpp +0 -158
  51. data/mlx/mlx/backend/cuda/quantized/qqmm_utils.h +0 -30
  52. /data/{mlx → submodules/mlx}/CMakeLists.txt +0 -0
  53. /data/{mlx → submodules/mlx}/cmake/FindCUDNN.cmake +0 -0
  54. /data/{mlx → submodules/mlx}/cmake/FindNCCL.cmake +0 -0
  55. /data/{mlx → submodules/mlx}/cmake/Findnvpl.cmake +0 -0
  56. /data/{mlx → submodules/mlx}/cmake/extension.cmake +0 -0
  57. /data/{mlx → submodules/mlx}/mlx/3rdparty/.clang-format +0 -0
  58. /data/{mlx → submodules/mlx}/mlx/3rdparty/pocketfft.h +0 -0
  59. /data/{mlx → submodules/mlx}/mlx/CMakeLists.txt +0 -0
  60. /data/{mlx → submodules/mlx}/mlx/allocator.h +0 -0
  61. /data/{mlx → submodules/mlx}/mlx/api.h +0 -0
  62. /data/{mlx → submodules/mlx}/mlx/array.cpp +0 -0
  63. /data/{mlx → submodules/mlx}/mlx/array.h +0 -0
  64. /data/{mlx → submodules/mlx}/mlx/backend/common/CMakeLists.txt +0 -0
  65. /data/{mlx → submodules/mlx}/mlx/backend/common/binary.h +0 -0
  66. /data/{mlx → submodules/mlx}/mlx/backend/common/broadcasting.cpp +0 -0
  67. /data/{mlx → submodules/mlx}/mlx/backend/common/broadcasting.h +0 -0
  68. /data/{mlx → submodules/mlx}/mlx/backend/common/buffer_cache.h +0 -0
  69. /data/{mlx → submodules/mlx}/mlx/backend/common/common.cpp +0 -0
  70. /data/{mlx → submodules/mlx}/mlx/backend/common/compiled.cpp +0 -0
  71. /data/{mlx → submodules/mlx}/mlx/backend/common/compiled.h +0 -0
  72. /data/{mlx → submodules/mlx}/mlx/backend/common/copy.h +0 -0
  73. /data/{mlx → submodules/mlx}/mlx/backend/common/hadamard.h +0 -0
  74. /data/{mlx → submodules/mlx}/mlx/backend/common/load.cpp +0 -0
  75. /data/{mlx → submodules/mlx}/mlx/backend/common/matmul.h +0 -0
  76. /data/{mlx → submodules/mlx}/mlx/backend/common/reduce.cpp +0 -0
  77. /data/{mlx → submodules/mlx}/mlx/backend/common/reduce.h +0 -0
  78. /data/{mlx → submodules/mlx}/mlx/backend/common/slicing.cpp +0 -0
  79. /data/{mlx → submodules/mlx}/mlx/backend/common/slicing.h +0 -0
  80. /data/{mlx → submodules/mlx}/mlx/backend/common/ternary.h +0 -0
  81. /data/{mlx → submodules/mlx}/mlx/backend/common/unary.h +0 -0
  82. /data/{mlx → submodules/mlx}/mlx/backend/common/utils.cpp +0 -0
  83. /data/{mlx → submodules/mlx}/mlx/backend/common/utils.h +0 -0
  84. /data/{mlx → submodules/mlx}/mlx/backend/cpu/CMakeLists.txt +0 -0
  85. /data/{mlx → submodules/mlx}/mlx/backend/cpu/arange.h +0 -0
  86. /data/{mlx → submodules/mlx}/mlx/backend/cpu/arg_reduce.cpp +0 -0
  87. /data/{mlx → submodules/mlx}/mlx/backend/cpu/binary.cpp +0 -0
  88. /data/{mlx → submodules/mlx}/mlx/backend/cpu/binary.h +0 -0
  89. /data/{mlx → submodules/mlx}/mlx/backend/cpu/binary_ops.h +0 -0
  90. /data/{mlx → submodules/mlx}/mlx/backend/cpu/binary_two.h +0 -0
  91. /data/{mlx → submodules/mlx}/mlx/backend/cpu/cholesky.cpp +0 -0
  92. /data/{mlx → submodules/mlx}/mlx/backend/cpu/compiled.cpp +0 -0
  93. /data/{mlx → submodules/mlx}/mlx/backend/cpu/compiled_preamble.h +0 -0
  94. /data/{mlx → submodules/mlx}/mlx/backend/cpu/conv.cpp +0 -0
  95. /data/{mlx → submodules/mlx}/mlx/backend/cpu/copy.cpp +0 -0
  96. /data/{mlx → submodules/mlx}/mlx/backend/cpu/copy.h +0 -0
  97. /data/{mlx → submodules/mlx}/mlx/backend/cpu/device_info.cpp +0 -0
  98. /data/{mlx → submodules/mlx}/mlx/backend/cpu/device_info.h +0 -0
  99. /data/{mlx → submodules/mlx}/mlx/backend/cpu/distributed.cpp +0 -0
  100. /data/{mlx → submodules/mlx}/mlx/backend/cpu/eig.cpp +0 -0
  101. /data/{mlx → submodules/mlx}/mlx/backend/cpu/eigh.cpp +0 -0
  102. /data/{mlx → submodules/mlx}/mlx/backend/cpu/encoder.cpp +0 -0
  103. /data/{mlx → submodules/mlx}/mlx/backend/cpu/encoder.h +0 -0
  104. /data/{mlx → submodules/mlx}/mlx/backend/cpu/eval.cpp +0 -0
  105. /data/{mlx → submodules/mlx}/mlx/backend/cpu/eval.h +0 -0
  106. /data/{mlx → submodules/mlx}/mlx/backend/cpu/fft.cpp +0 -0
  107. /data/{mlx → submodules/mlx}/mlx/backend/cpu/gemm.h +0 -0
  108. /data/{mlx → submodules/mlx}/mlx/backend/cpu/gemms/bnns.cpp +0 -0
  109. /data/{mlx → submodules/mlx}/mlx/backend/cpu/gemms/cblas.cpp +0 -0
  110. /data/{mlx → submodules/mlx}/mlx/backend/cpu/gemms/simd_bf16.cpp +0 -0
  111. /data/{mlx → submodules/mlx}/mlx/backend/cpu/gemms/simd_fp16.cpp +0 -0
  112. /data/{mlx → submodules/mlx}/mlx/backend/cpu/gemms/simd_gemm.h +0 -0
  113. /data/{mlx → submodules/mlx}/mlx/backend/cpu/hadamard.cpp +0 -0
  114. /data/{mlx → submodules/mlx}/mlx/backend/cpu/indexing.cpp +0 -0
  115. /data/{mlx → submodules/mlx}/mlx/backend/cpu/inverse.cpp +0 -0
  116. /data/{mlx → submodules/mlx}/mlx/backend/cpu/jit_compiler.cpp +0 -0
  117. /data/{mlx → submodules/mlx}/mlx/backend/cpu/jit_compiler.h +0 -0
  118. /data/{mlx → submodules/mlx}/mlx/backend/cpu/lapack.h +0 -0
  119. /data/{mlx → submodules/mlx}/mlx/backend/cpu/logsumexp.cpp +0 -0
  120. /data/{mlx → submodules/mlx}/mlx/backend/cpu/luf.cpp +0 -0
  121. /data/{mlx → submodules/mlx}/mlx/backend/cpu/make_compiled_preamble.ps1 +0 -0
  122. /data/{mlx → submodules/mlx}/mlx/backend/cpu/make_compiled_preamble.sh +0 -0
  123. /data/{mlx → submodules/mlx}/mlx/backend/cpu/masked_mm.cpp +0 -0
  124. /data/{mlx → submodules/mlx}/mlx/backend/cpu/matmul.cpp +0 -0
  125. /data/{mlx → submodules/mlx}/mlx/backend/cpu/primitives.cpp +0 -0
  126. /data/{mlx → submodules/mlx}/mlx/backend/cpu/qrf.cpp +0 -0
  127. /data/{mlx → submodules/mlx}/mlx/backend/cpu/quantized.cpp +0 -0
  128. /data/{mlx → submodules/mlx}/mlx/backend/cpu/reduce.cpp +0 -0
  129. /data/{mlx → submodules/mlx}/mlx/backend/cpu/scan.cpp +0 -0
  130. /data/{mlx → submodules/mlx}/mlx/backend/cpu/select.cpp +0 -0
  131. /data/{mlx → submodules/mlx}/mlx/backend/cpu/simd/accelerate_fp16_simd.h +0 -0
  132. /data/{mlx → submodules/mlx}/mlx/backend/cpu/simd/accelerate_simd.h +0 -0
  133. /data/{mlx → submodules/mlx}/mlx/backend/cpu/simd/base_simd.h +0 -0
  134. /data/{mlx → submodules/mlx}/mlx/backend/cpu/simd/math.h +0 -0
  135. /data/{mlx → submodules/mlx}/mlx/backend/cpu/simd/neon_fp16_simd.h +0 -0
  136. /data/{mlx → submodules/mlx}/mlx/backend/cpu/simd/simd.h +0 -0
  137. /data/{mlx → submodules/mlx}/mlx/backend/cpu/simd/type.h +0 -0
  138. /data/{mlx → submodules/mlx}/mlx/backend/cpu/slicing.h +0 -0
  139. /data/{mlx → submodules/mlx}/mlx/backend/cpu/softmax.cpp +0 -0
  140. /data/{mlx → submodules/mlx}/mlx/backend/cpu/sort.cpp +0 -0
  141. /data/{mlx → submodules/mlx}/mlx/backend/cpu/svd.cpp +0 -0
  142. /data/{mlx → submodules/mlx}/mlx/backend/cpu/ternary.h +0 -0
  143. /data/{mlx → submodules/mlx}/mlx/backend/cpu/threefry.cpp +0 -0
  144. /data/{mlx → submodules/mlx}/mlx/backend/cpu/threefry.h +0 -0
  145. /data/{mlx → submodules/mlx}/mlx/backend/cpu/unary.cpp +0 -0
  146. /data/{mlx → submodules/mlx}/mlx/backend/cpu/unary.h +0 -0
  147. /data/{mlx → submodules/mlx}/mlx/backend/cpu/unary_ops.h +0 -0
  148. /data/{mlx → submodules/mlx}/mlx/backend/cuda/CMakeLists.txt +0 -0
  149. /data/{mlx → submodules/mlx}/mlx/backend/cuda/allocator.cpp +0 -0
  150. /data/{mlx → submodules/mlx}/mlx/backend/cuda/allocator.h +0 -0
  151. /data/{mlx → submodules/mlx}/mlx/backend/cuda/arange.cu +0 -0
  152. /data/{mlx → submodules/mlx}/mlx/backend/cuda/arg_reduce.cu +0 -0
  153. /data/{mlx → submodules/mlx}/mlx/backend/cuda/bin2h.cmake +0 -0
  154. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/CMakeLists.txt +0 -0
  155. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/add.cu +0 -0
  156. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/arctan2.cu +0 -0
  157. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/binary.cuh +0 -0
  158. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/bitwise_binary.cu +0 -0
  159. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/divide.cu +0 -0
  160. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/equal.cu +0 -0
  161. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/greater.cu +0 -0
  162. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/greater_equal.cu +0 -0
  163. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/less.cu +0 -0
  164. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/less_equal.cu +0 -0
  165. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/log_add_exp.cu +0 -0
  166. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/logical_and.cu +0 -0
  167. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/logical_or.cu +0 -0
  168. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/maximum.cu +0 -0
  169. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/minimum.cu +0 -0
  170. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/multiply.cu +0 -0
  171. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/not_equal.cu +0 -0
  172. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/power.cu +0 -0
  173. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/remainder.cu +0 -0
  174. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/subtract.cu +0 -0
  175. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary_two.cu +0 -0
  176. /data/{mlx → submodules/mlx}/mlx/backend/cuda/compiled.cpp +0 -0
  177. /data/{mlx → submodules/mlx}/mlx/backend/cuda/conv/conv.h +0 -0
  178. /data/{mlx → submodules/mlx}/mlx/backend/cuda/conv/gemm_conv.cu +0 -0
  179. /data/{mlx → submodules/mlx}/mlx/backend/cuda/conv/gemm_grouped_conv.cu +0 -0
  180. /data/{mlx → submodules/mlx}/mlx/backend/cuda/conv.cpp +0 -0
  181. /data/{mlx → submodules/mlx}/mlx/backend/cuda/copy/copy.cuh +0 -0
  182. /data/{mlx → submodules/mlx}/mlx/backend/cuda/copy/copy_contiguous.cu +0 -0
  183. /data/{mlx → submodules/mlx}/mlx/backend/cuda/copy/copy_general.cu +0 -0
  184. /data/{mlx → submodules/mlx}/mlx/backend/cuda/copy/copy_general_dynamic.cu +0 -0
  185. /data/{mlx → submodules/mlx}/mlx/backend/cuda/copy/copy_general_input.cu +0 -0
  186. /data/{mlx → submodules/mlx}/mlx/backend/cuda/copy.cu +0 -0
  187. /data/{mlx → submodules/mlx}/mlx/backend/cuda/cublas_utils.h +0 -0
  188. /data/{mlx → submodules/mlx}/mlx/backend/cuda/cuda.h +0 -0
  189. /data/{mlx → submodules/mlx}/mlx/backend/cuda/cuda_utils.h +0 -0
  190. /data/{mlx → submodules/mlx}/mlx/backend/cuda/cudnn_utils.cpp +0 -0
  191. /data/{mlx → submodules/mlx}/mlx/backend/cuda/cudnn_utils.h +0 -0
  192. /data/{mlx → submodules/mlx}/mlx/backend/cuda/custom_kernel.cpp +0 -0
  193. /data/{mlx → submodules/mlx}/mlx/backend/cuda/cutlass_utils.cuh +0 -0
  194. /data/{mlx → submodules/mlx}/mlx/backend/cuda/delayload.cpp +0 -0
  195. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/atomic_ops.cuh +0 -0
  196. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/binary_ops.cuh +0 -0
  197. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/cast_op.cuh +0 -0
  198. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/complex.cuh +0 -0
  199. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/config.h +0 -0
  200. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/fp16_math.cuh +0 -0
  201. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/gather.cuh +0 -0
  202. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/gather_axis.cuh +0 -0
  203. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/indexing.cuh +0 -0
  204. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/scatter.cuh +0 -0
  205. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/scatter_axis.cuh +0 -0
  206. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/scatter_ops.cuh +0 -0
  207. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/ternary_ops.cuh +0 -0
  208. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/unary_ops.cuh +0 -0
  209. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/utils.cuh +0 -0
  210. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device.cpp +0 -0
  211. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device.h +0 -0
  212. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device_info.cpp +0 -0
  213. /data/{mlx → submodules/mlx}/mlx/backend/cuda/distributed.cu +0 -0
  214. /data/{mlx → submodules/mlx}/mlx/backend/cuda/eval.cpp +0 -0
  215. /data/{mlx → submodules/mlx}/mlx/backend/cuda/event.cu +0 -0
  216. /data/{mlx → submodules/mlx}/mlx/backend/cuda/event.h +0 -0
  217. /data/{mlx → submodules/mlx}/mlx/backend/cuda/fence.cpp +0 -0
  218. /data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/cublas_gemm.h +0 -0
  219. /data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/cublas_gemm_batched_12_0.cpp +0 -0
  220. /data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/cublas_gemm_batched_12_9.cu +0 -0
  221. /data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/gemv.cu +0 -0
  222. /data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/gemv.h +0 -0
  223. /data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/grouped_gemm.h +0 -0
  224. /data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/grouped_gemm_unaligned.cu +0 -0
  225. /data/{mlx → submodules/mlx}/mlx/backend/cuda/indexing.cpp +0 -0
  226. /data/{mlx → submodules/mlx}/mlx/backend/cuda/jit_module.cpp +0 -0
  227. /data/{mlx → submodules/mlx}/mlx/backend/cuda/jit_module.h +0 -0
  228. /data/{mlx → submodules/mlx}/mlx/backend/cuda/kernel_utils.cu +0 -0
  229. /data/{mlx → submodules/mlx}/mlx/backend/cuda/kernel_utils.cuh +0 -0
  230. /data/{mlx → submodules/mlx}/mlx/backend/cuda/layer_norm.cu +0 -0
  231. /data/{mlx → submodules/mlx}/mlx/backend/cuda/load.cpp +0 -0
  232. /data/{mlx → submodules/mlx}/mlx/backend/cuda/logsumexp.cu +0 -0
  233. /data/{mlx → submodules/mlx}/mlx/backend/cuda/lru_cache.h +0 -0
  234. /data/{mlx → submodules/mlx}/mlx/backend/cuda/matmul.cpp +0 -0
  235. /data/{mlx → submodules/mlx}/mlx/backend/cuda/no_cuda.cpp +0 -0
  236. /data/{mlx → submodules/mlx}/mlx/backend/cuda/primitives.cpp +0 -0
  237. /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/affine_quantize.cu +0 -0
  238. /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/convert_fp8.cu +0 -0
  239. /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/cuda_fp4.h +0 -0
  240. /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/mxfp8_quantize.cuh +0 -0
  241. /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/nvfp4_quantize.cuh +0 -0
  242. /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/qmv.cu +0 -0
  243. /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/qmv.h +0 -0
  244. /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/quantized_utils.h +0 -0
  245. /data/{mlx → submodules/mlx}/mlx/backend/cuda/random.cu +0 -0
  246. /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce/all_reduce.cu +0 -0
  247. /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce/col_reduce.cu +0 -0
  248. /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce/init_reduce.cu +0 -0
  249. /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce/reduce.cuh +0 -0
  250. /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce/reduce_ops.cuh +0 -0
  251. /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce/reduce_utils.cuh +0 -0
  252. /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce/row_reduce.cu +0 -0
  253. /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce.cu +0 -0
  254. /data/{mlx → submodules/mlx}/mlx/backend/cuda/rms_norm.cu +0 -0
  255. /data/{mlx → submodules/mlx}/mlx/backend/cuda/rope.cu +0 -0
  256. /data/{mlx → submodules/mlx}/mlx/backend/cuda/scaled_dot_product_attention.cpp +0 -0
  257. /data/{mlx → submodules/mlx}/mlx/backend/cuda/scaled_dot_product_attention.cu +0 -0
  258. /data/{mlx → submodules/mlx}/mlx/backend/cuda/scan.cu +0 -0
  259. /data/{mlx → submodules/mlx}/mlx/backend/cuda/slicing.cpp +0 -0
  260. /data/{mlx → submodules/mlx}/mlx/backend/cuda/softmax.cu +0 -0
  261. /data/{mlx → submodules/mlx}/mlx/backend/cuda/sort.cu +0 -0
  262. /data/{mlx → submodules/mlx}/mlx/backend/cuda/steel/defines.cuh +0 -0
  263. /data/{mlx → submodules/mlx}/mlx/backend/cuda/steel/gemm.cuh +0 -0
  264. /data/{mlx → submodules/mlx}/mlx/backend/cuda/steel/mma.cuh +0 -0
  265. /data/{mlx → submodules/mlx}/mlx/backend/cuda/steel/tiles.cuh +0 -0
  266. /data/{mlx → submodules/mlx}/mlx/backend/cuda/steel/utils.cuh +0 -0
  267. /data/{mlx → submodules/mlx}/mlx/backend/cuda/ternary.cu +0 -0
  268. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/CMakeLists.txt +0 -0
  269. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/abs.cu +0 -0
  270. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/arccos.cu +0 -0
  271. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/arccosh.cu +0 -0
  272. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/arcsin.cu +0 -0
  273. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/arcsinh.cu +0 -0
  274. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/arctan.cu +0 -0
  275. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/arctanh.cu +0 -0
  276. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/bitwise_invert.cu +0 -0
  277. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/ceil.cu +0 -0
  278. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/conjugate.cu +0 -0
  279. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/cos.cu +0 -0
  280. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/cosh.cu +0 -0
  281. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/erf.cu +0 -0
  282. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/erf_inv.cu +0 -0
  283. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/exp.cu +0 -0
  284. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/expm1.cu +0 -0
  285. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/floor.cu +0 -0
  286. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/imag.cu +0 -0
  287. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/log.cu +0 -0
  288. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/log1p.cu +0 -0
  289. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/logical_not.cu +0 -0
  290. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/negative.cu +0 -0
  291. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/real.cu +0 -0
  292. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/round.cu +0 -0
  293. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/sigmoid.cu +0 -0
  294. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/sign.cu +0 -0
  295. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/sin.cu +0 -0
  296. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/sinh.cu +0 -0
  297. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/sqrt.cu +0 -0
  298. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/square.cu +0 -0
  299. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/tan.cu +0 -0
  300. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/tanh.cu +0 -0
  301. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/unary.cuh +0 -0
  302. /data/{mlx → submodules/mlx}/mlx/backend/cuda/utils.cpp +0 -0
  303. /data/{mlx → submodules/mlx}/mlx/backend/cuda/utils.h +0 -0
  304. /data/{mlx → submodules/mlx}/mlx/backend/cuda/vector_types.cuh +0 -0
  305. /data/{mlx → submodules/mlx}/mlx/backend/cuda/worker.cpp +0 -0
  306. /data/{mlx → submodules/mlx}/mlx/backend/cuda/worker.h +0 -0
  307. /data/{mlx → submodules/mlx}/mlx/backend/gpu/CMakeLists.txt +0 -0
  308. /data/{mlx → submodules/mlx}/mlx/backend/gpu/copy.cpp +0 -0
  309. /data/{mlx → submodules/mlx}/mlx/backend/gpu/copy.h +0 -0
  310. /data/{mlx → submodules/mlx}/mlx/backend/gpu/device_info.h +0 -0
  311. /data/{mlx → submodules/mlx}/mlx/backend/gpu/eval.h +0 -0
  312. /data/{mlx → submodules/mlx}/mlx/backend/gpu/primitives.cpp +0 -0
  313. /data/{mlx → submodules/mlx}/mlx/backend/gpu/slicing.cpp +0 -0
  314. /data/{mlx → submodules/mlx}/mlx/backend/gpu/slicing.h +0 -0
  315. /data/{mlx → submodules/mlx}/mlx/backend/metal/CMakeLists.txt +0 -0
  316. /data/{mlx → submodules/mlx}/mlx/backend/metal/allocator.cpp +0 -0
  317. /data/{mlx → submodules/mlx}/mlx/backend/metal/allocator.h +0 -0
  318. /data/{mlx → submodules/mlx}/mlx/backend/metal/binary.cpp +0 -0
  319. /data/{mlx → submodules/mlx}/mlx/backend/metal/binary.h +0 -0
  320. /data/{mlx → submodules/mlx}/mlx/backend/metal/compiled.cpp +0 -0
  321. /data/{mlx → submodules/mlx}/mlx/backend/metal/conv.cpp +0 -0
  322. /data/{mlx → submodules/mlx}/mlx/backend/metal/copy.cpp +0 -0
  323. /data/{mlx → submodules/mlx}/mlx/backend/metal/custom_kernel.cpp +0 -0
  324. /data/{mlx → submodules/mlx}/mlx/backend/metal/device.h +0 -0
  325. /data/{mlx → submodules/mlx}/mlx/backend/metal/device_info.cpp +0 -0
  326. /data/{mlx → submodules/mlx}/mlx/backend/metal/distributed.cpp +0 -0
  327. /data/{mlx → submodules/mlx}/mlx/backend/metal/eval.cpp +0 -0
  328. /data/{mlx → submodules/mlx}/mlx/backend/metal/event.cpp +0 -0
  329. /data/{mlx → submodules/mlx}/mlx/backend/metal/fence.cpp +0 -0
  330. /data/{mlx → submodules/mlx}/mlx/backend/metal/fft.cpp +0 -0
  331. /data/{mlx → submodules/mlx}/mlx/backend/metal/hadamard.cpp +0 -0
  332. /data/{mlx → submodules/mlx}/mlx/backend/metal/indexing.cpp +0 -0
  333. /data/{mlx → submodules/mlx}/mlx/backend/metal/jit/includes.h +0 -0
  334. /data/{mlx → submodules/mlx}/mlx/backend/metal/jit/indexing.h +0 -0
  335. /data/{mlx → submodules/mlx}/mlx/backend/metal/jit_kernels.cpp +0 -0
  336. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/CMakeLists.txt +0 -0
  337. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/arange.h +0 -0
  338. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/arange.metal +0 -0
  339. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/arg_reduce.metal +0 -0
  340. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/atomic.h +0 -0
  341. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/bf16.h +0 -0
  342. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/bf16_math.h +0 -0
  343. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/binary.h +0 -0
  344. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/binary.metal +0 -0
  345. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/binary_ops.h +0 -0
  346. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/binary_two.h +0 -0
  347. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/binary_two.metal +0 -0
  348. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/cexpf.h +0 -0
  349. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/complex.h +0 -0
  350. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/copy.h +0 -0
  351. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/copy.metal +0 -0
  352. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/defines.h +0 -0
  353. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/erf.h +0 -0
  354. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/expm1f.h +0 -0
  355. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fence.metal +0 -0
  356. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fft/radix.h +0 -0
  357. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fft/readwrite.h +0 -0
  358. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fft.h +0 -0
  359. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fft.metal +0 -0
  360. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fp4.h +0 -0
  361. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fp8.h +0 -0
  362. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fp_quantized.h +0 -0
  363. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fp_quantized.metal +0 -0
  364. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fp_quantized_nax.h +0 -0
  365. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fp_quantized_nax.metal +0 -0
  366. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/gemv.metal +0 -0
  367. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/gemv_masked.h +0 -0
  368. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/gemv_masked.metal +0 -0
  369. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/hadamard.h +0 -0
  370. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/indexing/gather.h +0 -0
  371. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/indexing/gather_axis.h +0 -0
  372. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/indexing/gather_front.h +0 -0
  373. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/indexing/indexing.h +0 -0
  374. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/indexing/masked_scatter.h +0 -0
  375. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/indexing/scatter.h +0 -0
  376. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/indexing/scatter_axis.h +0 -0
  377. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/layer_norm.metal +0 -0
  378. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/logging.h +0 -0
  379. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/logsumexp.h +0 -0
  380. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/logsumexp.metal +0 -0
  381. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/quantized.h +0 -0
  382. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/quantized.metal +0 -0
  383. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/quantized_nax.h +0 -0
  384. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/quantized_nax.metal +0 -0
  385. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/quantized_utils.h +0 -0
  386. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/random.metal +0 -0
  387. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduce.h +0 -0
  388. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduce.metal +0 -0
  389. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduce_utils.h +0 -0
  390. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduction/ops.h +0 -0
  391. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduction/reduce_all.h +0 -0
  392. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduction/reduce_col.h +0 -0
  393. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduction/reduce_init.h +0 -0
  394. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduction/reduce_row.h +0 -0
  395. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/rms_norm.metal +0 -0
  396. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/rope.metal +0 -0
  397. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/scaled_dot_product_attention.metal +0 -0
  398. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/scan.h +0 -0
  399. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/scan.metal +0 -0
  400. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/sdpa_vector.h +0 -0
  401. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/softmax.h +0 -0
  402. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/softmax.metal +0 -0
  403. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/sort.h +0 -0
  404. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/sort.metal +0 -0
  405. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/attn.h +0 -0
  406. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +0 -0
  407. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal +0 -0
  408. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +0 -0
  409. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.metal +0 -0
  410. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/loader.h +0 -0
  411. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/mma.h +0 -0
  412. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/nax.h +0 -0
  413. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/params.h +0 -0
  414. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/transforms.h +0 -0
  415. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/conv.h +0 -0
  416. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +0 -0
  417. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.metal +0 -0
  418. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +0 -0
  419. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.metal +0 -0
  420. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/loader.h +0 -0
  421. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +0 -0
  422. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +0 -0
  423. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +0 -0
  424. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/params.h +0 -0
  425. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/defines.h +0 -0
  426. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/gemm.h +0 -0
  427. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +0 -0
  428. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +0 -0
  429. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.metal +0 -0
  430. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +0 -0
  431. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.metal +0 -0
  432. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +0 -0
  433. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.metal +0 -0
  434. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +0 -0
  435. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.metal +0 -0
  436. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +0 -0
  437. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.metal +0 -0
  438. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +0 -0
  439. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.metal +0 -0
  440. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +0 -0
  441. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.metal +0 -0
  442. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.h +0 -0
  443. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.metal +0 -0
  444. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/loader.h +0 -0
  445. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/mma.h +0 -0
  446. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/nax.h +0 -0
  447. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/params.h +0 -0
  448. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/transforms.h +0 -0
  449. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/utils/integral_constant.h +0 -0
  450. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/utils/type_traits.h +0 -0
  451. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/utils.h +0 -0
  452. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/ternary.h +0 -0
  453. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/ternary.metal +0 -0
  454. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/ternary_ops.h +0 -0
  455. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/unary.h +0 -0
  456. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/unary.metal +0 -0
  457. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/unary_ops.h +0 -0
  458. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/utils.h +0 -0
  459. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels.h +0 -0
  460. /data/{mlx → submodules/mlx}/mlx/backend/metal/logsumexp.cpp +0 -0
  461. /data/{mlx → submodules/mlx}/mlx/backend/metal/make_compiled_preamble.sh +0 -0
  462. /data/{mlx → submodules/mlx}/mlx/backend/metal/matmul.cpp +0 -0
  463. /data/{mlx → submodules/mlx}/mlx/backend/metal/matmul.h +0 -0
  464. /data/{mlx → submodules/mlx}/mlx/backend/metal/metal.cpp +0 -0
  465. /data/{mlx → submodules/mlx}/mlx/backend/metal/metal.h +0 -0
  466. /data/{mlx → submodules/mlx}/mlx/backend/metal/no_metal.cpp +0 -0
  467. /data/{mlx → submodules/mlx}/mlx/backend/metal/nojit_kernels.cpp +0 -0
  468. /data/{mlx → submodules/mlx}/mlx/backend/metal/normalization.cpp +0 -0
  469. /data/{mlx → submodules/mlx}/mlx/backend/metal/primitives.cpp +0 -0
  470. /data/{mlx → submodules/mlx}/mlx/backend/metal/quantized.cpp +0 -0
  471. /data/{mlx → submodules/mlx}/mlx/backend/metal/reduce.cpp +0 -0
  472. /data/{mlx → submodules/mlx}/mlx/backend/metal/reduce.h +0 -0
  473. /data/{mlx → submodules/mlx}/mlx/backend/metal/resident.cpp +0 -0
  474. /data/{mlx → submodules/mlx}/mlx/backend/metal/resident.h +0 -0
  475. /data/{mlx → submodules/mlx}/mlx/backend/metal/rope.cpp +0 -0
  476. /data/{mlx → submodules/mlx}/mlx/backend/metal/scaled_dot_product_attention.cpp +0 -0
  477. /data/{mlx → submodules/mlx}/mlx/backend/metal/scan.cpp +0 -0
  478. /data/{mlx → submodules/mlx}/mlx/backend/metal/scan.h +0 -0
  479. /data/{mlx → submodules/mlx}/mlx/backend/metal/slicing.cpp +0 -0
  480. /data/{mlx → submodules/mlx}/mlx/backend/metal/softmax.cpp +0 -0
  481. /data/{mlx → submodules/mlx}/mlx/backend/metal/sort.cpp +0 -0
  482. /data/{mlx → submodules/mlx}/mlx/backend/metal/ternary.cpp +0 -0
  483. /data/{mlx → submodules/mlx}/mlx/backend/metal/ternary.h +0 -0
  484. /data/{mlx → submodules/mlx}/mlx/backend/metal/unary.cpp +0 -0
  485. /data/{mlx → submodules/mlx}/mlx/backend/metal/unary.h +0 -0
  486. /data/{mlx → submodules/mlx}/mlx/backend/metal/utils.cpp +0 -0
  487. /data/{mlx → submodules/mlx}/mlx/backend/metal/utils.h +0 -0
  488. /data/{mlx → submodules/mlx}/mlx/backend/no_cpu/CMakeLists.txt +0 -0
  489. /data/{mlx → submodules/mlx}/mlx/backend/no_cpu/compiled.cpp +0 -0
  490. /data/{mlx → submodules/mlx}/mlx/backend/no_cpu/device_info.cpp +0 -0
  491. /data/{mlx → submodules/mlx}/mlx/backend/no_cpu/primitives.cpp +0 -0
  492. /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/CMakeLists.txt +0 -0
  493. /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/allocator.cpp +0 -0
  494. /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/apple_memory.h +0 -0
  495. /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/device_info.cpp +0 -0
  496. /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/eval.cpp +0 -0
  497. /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/event.cpp +0 -0
  498. /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/fence.cpp +0 -0
  499. /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/linux_memory.h +0 -0
  500. /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/primitives.cpp +0 -0
  501. /data/{mlx → submodules/mlx}/mlx/compile.cpp +0 -0
  502. /data/{mlx → submodules/mlx}/mlx/compile.h +0 -0
  503. /data/{mlx → submodules/mlx}/mlx/compile_impl.h +0 -0
  504. /data/{mlx → submodules/mlx}/mlx/device.cpp +0 -0
  505. /data/{mlx → submodules/mlx}/mlx/device.h +0 -0
  506. /data/{mlx → submodules/mlx}/mlx/distributed/CMakeLists.txt +0 -0
  507. /data/{mlx → submodules/mlx}/mlx/distributed/distributed.cpp +0 -0
  508. /data/{mlx → submodules/mlx}/mlx/distributed/distributed.h +0 -0
  509. /data/{mlx → submodules/mlx}/mlx/distributed/distributed_impl.h +0 -0
  510. /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/CMakeLists.txt +0 -0
  511. /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/jaccl.cpp +0 -0
  512. /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/jaccl.h +0 -0
  513. /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/mesh.cpp +0 -0
  514. /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/mesh.h +0 -0
  515. /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/no_jaccl.cpp +0 -0
  516. /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/ring.cpp +0 -0
  517. /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/ring.h +0 -0
  518. /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/utils.cpp +0 -0
  519. /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/utils.h +0 -0
  520. /data/{mlx → submodules/mlx}/mlx/distributed/mpi/CMakeLists.txt +0 -0
  521. /data/{mlx → submodules/mlx}/mlx/distributed/mpi/mpi.cpp +0 -0
  522. /data/{mlx → submodules/mlx}/mlx/distributed/mpi/mpi.h +0 -0
  523. /data/{mlx → submodules/mlx}/mlx/distributed/mpi/mpi_declarations.h +0 -0
  524. /data/{mlx → submodules/mlx}/mlx/distributed/mpi/no_mpi.cpp +0 -0
  525. /data/{mlx → submodules/mlx}/mlx/distributed/nccl/CMakeLists.txt +0 -0
  526. /data/{mlx → submodules/mlx}/mlx/distributed/nccl/nccl.cpp +0 -0
  527. /data/{mlx → submodules/mlx}/mlx/distributed/nccl/nccl.h +0 -0
  528. /data/{mlx → submodules/mlx}/mlx/distributed/nccl/nccl_stub/CMakeLists.txt +0 -0
  529. /data/{mlx → submodules/mlx}/mlx/distributed/nccl/nccl_stub/nccl_stubs.cpp +0 -0
  530. /data/{mlx → submodules/mlx}/mlx/distributed/nccl/no_nccl.cpp +0 -0
  531. /data/{mlx → submodules/mlx}/mlx/distributed/ops.cpp +0 -0
  532. /data/{mlx → submodules/mlx}/mlx/distributed/ops.h +0 -0
  533. /data/{mlx → submodules/mlx}/mlx/distributed/primitives.cpp +0 -0
  534. /data/{mlx → submodules/mlx}/mlx/distributed/primitives.h +0 -0
  535. /data/{mlx → submodules/mlx}/mlx/distributed/reduction_ops.h +0 -0
  536. /data/{mlx → submodules/mlx}/mlx/distributed/ring/CMakeLists.txt +0 -0
  537. /data/{mlx → submodules/mlx}/mlx/distributed/ring/no_ring.cpp +0 -0
  538. /data/{mlx → submodules/mlx}/mlx/distributed/ring/ring.cpp +0 -0
  539. /data/{mlx → submodules/mlx}/mlx/distributed/ring/ring.h +0 -0
  540. /data/{mlx → submodules/mlx}/mlx/distributed/utils.cpp +0 -0
  541. /data/{mlx → submodules/mlx}/mlx/distributed/utils.h +0 -0
  542. /data/{mlx → submodules/mlx}/mlx/dtype.cpp +0 -0
  543. /data/{mlx → submodules/mlx}/mlx/dtype.h +0 -0
  544. /data/{mlx → submodules/mlx}/mlx/dtype_utils.cpp +0 -0
  545. /data/{mlx → submodules/mlx}/mlx/dtype_utils.h +0 -0
  546. /data/{mlx → submodules/mlx}/mlx/einsum.cpp +0 -0
  547. /data/{mlx → submodules/mlx}/mlx/einsum.h +0 -0
  548. /data/{mlx → submodules/mlx}/mlx/event.h +0 -0
  549. /data/{mlx → submodules/mlx}/mlx/export.h +0 -0
  550. /data/{mlx → submodules/mlx}/mlx/export_impl.h +0 -0
  551. /data/{mlx → submodules/mlx}/mlx/fast.cpp +0 -0
  552. /data/{mlx → submodules/mlx}/mlx/fast.h +0 -0
  553. /data/{mlx → submodules/mlx}/mlx/fast_primitives.h +0 -0
  554. /data/{mlx → submodules/mlx}/mlx/fence.h +0 -0
  555. /data/{mlx → submodules/mlx}/mlx/fft.cpp +0 -0
  556. /data/{mlx → submodules/mlx}/mlx/fft.h +0 -0
  557. /data/{mlx → submodules/mlx}/mlx/graph_utils.cpp +0 -0
  558. /data/{mlx → submodules/mlx}/mlx/graph_utils.h +0 -0
  559. /data/{mlx → submodules/mlx}/mlx/io/CMakeLists.txt +0 -0
  560. /data/{mlx → submodules/mlx}/mlx/io/gguf.cpp +0 -0
  561. /data/{mlx → submodules/mlx}/mlx/io/gguf.h +0 -0
  562. /data/{mlx → submodules/mlx}/mlx/io/gguf_quants.cpp +0 -0
  563. /data/{mlx → submodules/mlx}/mlx/io/load.cpp +0 -0
  564. /data/{mlx → submodules/mlx}/mlx/io/load.h +0 -0
  565. /data/{mlx → submodules/mlx}/mlx/io/no_gguf.cpp +0 -0
  566. /data/{mlx → submodules/mlx}/mlx/io/no_safetensors.cpp +0 -0
  567. /data/{mlx → submodules/mlx}/mlx/io/safetensors.cpp +0 -0
  568. /data/{mlx → submodules/mlx}/mlx/io.h +0 -0
  569. /data/{mlx → submodules/mlx}/mlx/linalg.cpp +0 -0
  570. /data/{mlx → submodules/mlx}/mlx/linalg.h +0 -0
  571. /data/{mlx → submodules/mlx}/mlx/memory.h +0 -0
  572. /data/{mlx → submodules/mlx}/mlx/mlx.h +0 -0
  573. /data/{mlx → submodules/mlx}/mlx/primitives.h +0 -0
  574. /data/{mlx → submodules/mlx}/mlx/random.cpp +0 -0
  575. /data/{mlx → submodules/mlx}/mlx/random.h +0 -0
  576. /data/{mlx → submodules/mlx}/mlx/small_vector.h +0 -0
  577. /data/{mlx → submodules/mlx}/mlx/threadpool.h +0 -0
  578. /data/{mlx → submodules/mlx}/mlx/transforms.cpp +0 -0
  579. /data/{mlx → submodules/mlx}/mlx/transforms.h +0 -0
  580. /data/{mlx → submodules/mlx}/mlx/transforms_impl.h +0 -0
  581. /data/{mlx → submodules/mlx}/mlx/types/bf16.h +0 -0
  582. /data/{mlx → submodules/mlx}/mlx/types/complex.h +0 -0
  583. /data/{mlx → submodules/mlx}/mlx/types/fp16.h +0 -0
  584. /data/{mlx → submodules/mlx}/mlx/types/half_types.h +0 -0
  585. /data/{mlx → submodules/mlx}/mlx/types/limits.h +0 -0
  586. /data/{mlx → submodules/mlx}/mlx/utils.cpp +0 -0
  587. /data/{mlx → submodules/mlx}/mlx/utils.h +0 -0
  588. /data/{mlx → submodules/mlx}/mlx/version.cpp +0 -0
  589. /data/{mlx → submodules/mlx}/mlx/version.h +0 -0
  590. /data/{mlx → submodules/mlx}/mlx.pc.in +0 -0
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 25d582e4816d69b27713a4027534b75cd00ca72557e69681daf07146d3e79ef2
4
- data.tar.gz: c010252aa355370a531fa4f3b9bf8cc729876d2f7fb9ae8b8e0d6a1eb6cb57c4
3
+ metadata.gz: d8190b13b012fe0693ef46cc3f20b01b78f6d13cde44674fd01999434e56eda9
4
+ data.tar.gz: '059a7b993b17cae7bd448567ef6a9058d0a575cdc143a3784b411953e1309e7d'
5
5
  SHA512:
6
- metadata.gz: 53e629e845342f173c04c7c6d9d976a29dd5492ae945239897d3168a586288ec958ba58753345317f301220ac5f4b91a22f97731ab799fcea5d59f3d19e48214
7
- data.tar.gz: 5b04f2e63e3dcdb6a0282184a310600f4fb72e606b45e8e7a27a7b9461abef3a598afe2eb5525e66457da45b8a84a6c2f07c871c955e9cddd4308971496c7fd1
6
+ metadata.gz: 1a90a276ac6b68760bdbe348931e6cfc48ec01662a7870522cf344748a8e5df167f3d4f6b9d28bd768abc3c5dc9a5e020abe2a5fc81edf1ef4056e54abc269b7
7
+ data.tar.gz: 8b17e8c878f32659cc90b365943cf7fce132c03fd7680959c03b615b28819d6a002e6e05236209b0282fdec7b03f06872556f7ac292ac85e00d531fa0bd1c11a
data/ext/mlx/extconf.rb CHANGED
@@ -3,6 +3,9 @@
3
3
  require "etc"
4
4
  require "fileutils"
5
5
  require "mkmf"
6
+ require "open3"
7
+ require "pathname"
8
+ require "rbconfig"
6
9
 
7
10
  def run_or_abort(*cmd, chdir:)
8
11
  puts ">> #{cmd.join(' ')}"
@@ -17,6 +20,76 @@ def run_with_status(*cmd, chdir:)
17
20
  system(*cmd, chdir: chdir)
18
21
  end
19
22
 
23
+ def git_revision(path)
24
+ return nil unless Dir.exist?(path)
25
+
26
+ stdout, status = Open3.capture2("git", "-C", path, "rev-parse", "HEAD")
27
+ return nil unless status.success?
28
+
29
+ rev = stdout.strip
30
+ return rev if rev.match?(/\A[0-9a-f]{40}\z/)
31
+
32
+ nil
33
+ end
34
+
35
+ def mlx_revision_pinned_by_mlx_onnx(mlx_onnx_root)
36
+ return nil unless Dir.exist?(mlx_onnx_root)
37
+
38
+ stdout, status = Open3.capture2("git", "-C", mlx_onnx_root, "submodule", "status", "--", "mlx")
39
+ return nil unless status.success?
40
+
41
+ line = stdout.lines.first.to_s.strip
42
+ return nil if line.empty?
43
+
44
+ token = line.split(/\s+/).first.to_s
45
+ token = token.delete_prefix("-").delete_prefix("+").delete_prefix("U")
46
+ return token if token.match?(/\A[0-9a-f]{40}\z/)
47
+
48
+ nil
49
+ end
50
+
51
+ def enforce_mlx_onnx_compatibility!(mlx_root:, mlx_onnx_root:)
52
+ workspace_mlx_revision =
53
+ ENV.fetch("MLX_EXTCONF_TEST_MLX_REVISION", "").strip
54
+ pinned_mlx_revision =
55
+ ENV.fetch("MLX_EXTCONF_TEST_MLX_ONNX_PINNED_MLX_REVISION", "").strip
56
+
57
+ workspace_mlx_revision = git_revision(mlx_root) if workspace_mlx_revision.empty?
58
+ pinned_mlx_revision = mlx_revision_pinned_by_mlx_onnx(mlx_onnx_root) if pinned_mlx_revision.empty?
59
+
60
+ return if workspace_mlx_revision.nil? || pinned_mlx_revision.nil?
61
+ return if workspace_mlx_revision == pinned_mlx_revision
62
+
63
+ abort(<<~MSG)
64
+ mlx/mlx-onnx revision mismatch detected.
65
+ workspace mlx revision: #{workspace_mlx_revision}
66
+ mlx-onnx pinned mlx revision: #{pinned_mlx_revision}
67
+ Run:
68
+ git submodule update --init --recursive submodules/mlx submodules/mlx-onnx
69
+ MSG
70
+ end
71
+
72
+ def patch_mlx_onnx_gcc_optional_shape_initlist!(mlx_onnx_root)
73
+ lowering_cpp = File.join(mlx_onnx_root, "src", "lowering.cpp")
74
+ return unless File.file?(lowering_cpp)
75
+
76
+ source = File.read(lowering_cpp)
77
+ patched = source
78
+ .gsub(
79
+ "std::optional<Shape>({work_shape[0], 1, seq_len})",
80
+ "std::optional<Shape>(Shape{work_shape[0], 1, seq_len})"
81
+ )
82
+ .gsub(
83
+ "std::optional<Shape>({seq_len})",
84
+ "std::optional<Shape>(Shape{seq_len})"
85
+ )
86
+
87
+ return if patched == source
88
+
89
+ File.write(lowering_cpp, patched)
90
+ puts "patched mlx-onnx lowering.cpp optional Shape initlists for GCC compatibility"
91
+ end
92
+
20
93
  def rpath_flag(path)
21
94
  case RUBY_PLATFORM
22
95
  when /darwin/
@@ -28,15 +101,120 @@ def rpath_flag(path)
28
101
  end
29
102
  end
30
103
 
104
+ def cmake_compilers_from_cache(cache_path)
105
+ return [nil, nil] unless File.file?(cache_path)
106
+
107
+ cc = nil
108
+ cxx = nil
109
+ File.foreach(cache_path) do |line|
110
+ cc = Regexp.last_match(1).strip if line =~ /^CMAKE_C_COMPILER:(?:STRING|FILEPATH)=(.+)$/
111
+ cxx = Regexp.last_match(1).strip if line =~ /^CMAKE_CXX_COMPILER:(?:STRING|FILEPATH)=(.+)$/
112
+ end
113
+ [cc, cxx]
114
+ end
115
+
116
+ def normalize_compiler_for_mkmf(path, kind)
117
+ return path if path.nil? || path.empty?
118
+
119
+ case kind
120
+ when :cc
121
+ return "/usr/bin/clang" if path.end_with?("/usr/bin/cc")
122
+ when :cxx
123
+ return "/usr/bin/clang++" if path.end_with?("/usr/bin/c++")
124
+ end
125
+ path
126
+ end
127
+
128
+ def force_mkmf_compilers!(cc, cxx)
129
+ return if cc.nil? || cc.empty? || cxx.nil? || cxx.empty?
130
+
131
+ cc = normalize_compiler_for_mkmf(cc, :cc)
132
+ cxx = normalize_compiler_for_mkmf(cxx, :cxx)
133
+
134
+ [RbConfig::CONFIG, RbConfig::MAKEFILE_CONFIG].each do |cfg|
135
+ cfg["CC"] = cc if cfg.key?("CC")
136
+ cfg["CXX"] = cxx if cfg.key?("CXX")
137
+
138
+ if cfg.key?("LDSHARED") && cfg["LDSHARED"]
139
+ cfg["LDSHARED"] = cfg["LDSHARED"].sub(/\A\S+/, cc)
140
+ end
141
+ if cfg.key?("LDSHAREDXX") && cfg["LDSHAREDXX"]
142
+ cfg["LDSHAREDXX"] = cfg["LDSHAREDXX"].sub(/\A\S+/, cxx)
143
+ end
144
+ end
145
+
146
+ $CC = cc if defined?($CC)
147
+ $CXX = cxx if defined?($CXX)
148
+ $LDSHARED = $LDSHARED.sub(/\A\S+/, cc) if defined?($LDSHARED) && $LDSHARED
149
+ $LDSHAREDXX = $LDSHAREDXX.sub(/\A\S+/, cxx) if defined?($LDSHAREDXX) && $LDSHAREDXX
150
+ end
151
+
152
+ def patch_makefile_compilers!(makefile_path, cc, cxx)
153
+ return if cc.nil? || cc.empty? || cxx.nil? || cxx.empty?
154
+ return unless File.file?(makefile_path)
155
+
156
+ cc = normalize_compiler_for_mkmf(cc, :cc)
157
+ cxx = normalize_compiler_for_mkmf(cxx, :cxx)
158
+
159
+ text = File.read(makefile_path)
160
+ text = text.gsub(/^CC = .+$/, "CC = #{cc}")
161
+ text = text.gsub(/^CXX = .+$/, "CXX = #{cxx}")
162
+ File.write(makefile_path, text)
163
+ end
164
+
165
+ def patch_makefile_sources!(makefile_path, source_files, header_files)
166
+ return unless File.file?(makefile_path)
167
+
168
+ objects = source_files.map do |path|
169
+ stem = path.sub(/\.[^.]+\z/, "")
170
+ "#{stem}.o"
171
+ end
172
+ text = File.read(makefile_path)
173
+ text = text.gsub(/^ORIG_SRCS = .+$/, "ORIG_SRCS = #{source_files.join(' ')}")
174
+ text = text.gsub(/^OBJS = .+$/, "OBJS = #{objects.join(' ')}")
175
+ text = text.gsub(/^HDRS = .+$/, "HDRS = #{header_files.map { |path| '$(srcdir)/' + path }.join(' ')}")
176
+ File.write(makefile_path, text)
177
+ end
178
+
179
+ def patch_makefile_include_dirs!(makefile_path, include_dirs)
180
+ return unless File.file?(makefile_path)
181
+
182
+ text = File.read(makefile_path)
183
+ match = text.match(/^CPPFLAGS = (.+)$/)
184
+ return if match.nil?
185
+
186
+ cppflags = match[1]
187
+ include_dirs.each do |path|
188
+ flag = "-I#{path}"
189
+ cppflags = "#{cppflags} #{flag}" unless cppflags.include?(flag)
190
+ end
191
+ text = text.sub(/^CPPFLAGS = .+$/, "CPPFLAGS = #{cppflags}")
192
+ File.write(makefile_path, text)
193
+ end
194
+
31
195
  repo_root = File.expand_path("../..", __dir__)
32
- mlx_root = File.join(repo_root, "mlx")
33
- mlx_include_dir = mlx_root
196
+ submodules_root = File.join(repo_root, "submodules")
197
+ mlx_root = File.join(submodules_root, "mlx")
198
+ mlx_onnx_root = File.join(submodules_root, "mlx-onnx")
199
+ mlx_public_include_dir = mlx_root
200
+ mlx_internal_include_dir = File.join(mlx_root, "mlx")
201
+ mlx_onnx_public_include_dir = File.join(mlx_onnx_root, "include")
202
+ mlx_onnx_internal_include_dir = File.join(mlx_onnx_root, "src")
34
203
  ext_root = File.expand_path(__dir__)
204
+ ext_onnx_root = File.join(repo_root, "ext", "mlx-onnx")
35
205
  build_root = File.join(ext_root, "build")
36
206
  mlx_build_dir = File.join(build_root, "mlx")
207
+ mlx_onnx_build_dir = File.join(build_root, "mlx-onnx")
37
208
  mlx_install_dir = File.join(build_root, "install")
38
209
  jobs = [Etc.nprocessors, 1].max
39
210
 
211
+ enforce_mlx_onnx_compatibility!(mlx_root: mlx_root, mlx_onnx_root: mlx_onnx_root)
212
+ patch_mlx_onnx_gcc_optional_shape_initlist!(mlx_onnx_root)
213
+ if ENV["MLX_EXTCONF_VALIDATE_ONLY"] == "1"
214
+ puts "mlx-onnx compatibility check passed"
215
+ exit 0
216
+ end
217
+
40
218
  FileUtils.mkdir_p(mlx_build_dir)
41
219
 
42
220
  cmake_configure = [
@@ -54,7 +232,7 @@ cmake_configure = [
54
232
  "-DMLX_BUILD_PYTHON_STUBS=OFF",
55
233
  "-DMLX_BUILD_METAL=ON",
56
234
  "-DMLX_BUILD_GGUF=OFF",
57
- "-DMLX_BUILD_SAFETENSORS=OFF",
235
+ "-DMLX_BUILD_SAFETENSORS=ON",
58
236
  "-DBUILD_SHARED_LIBS=ON"
59
237
  ]
60
238
 
@@ -78,17 +256,98 @@ unless configured
78
256
  end
79
257
  run_or_abort(*cmake_build, chdir: ext_root)
80
258
 
81
- include_dir = mlx_include_dir
259
+ cmake_cache_path = File.join(mlx_build_dir, "CMakeCache.txt")
260
+ cmake_cc, cmake_cxx = cmake_compilers_from_cache(cmake_cache_path)
261
+ force_mkmf_compilers!(cmake_cc, cmake_cxx)
262
+
263
+ abort("missing MLX include dir: #{mlx_public_include_dir}") unless Dir.exist?(mlx_public_include_dir)
264
+ abort("missing MLX internal include dir: #{mlx_internal_include_dir}") unless Dir.exist?(mlx_internal_include_dir)
265
+ abort("missing mlx-onnx include dir: #{mlx_onnx_public_include_dir}") unless Dir.exist?(mlx_onnx_public_include_dir)
266
+ abort("missing mlx-onnx internal include dir: #{mlx_onnx_internal_include_dir}") unless Dir.exist?(mlx_onnx_internal_include_dir)
267
+ include_dirs = [
268
+ mlx_public_include_dir,
269
+ mlx_internal_include_dir,
270
+ mlx_onnx_public_include_dir,
271
+ mlx_onnx_internal_include_dir
272
+ ]
82
273
  lib_dir = File.join(mlx_install_dir, "lib")
83
274
 
84
- abort("missing MLX include dir: #{include_dir}") unless Dir.exist?(include_dir)
275
+ cmake_onnx_configure = [
276
+ "cmake",
277
+ "-S",
278
+ mlx_onnx_root,
279
+ "-B",
280
+ mlx_onnx_build_dir,
281
+ "-DCMAKE_BUILD_TYPE=Release",
282
+ "-DCMAKE_INSTALL_PREFIX=#{mlx_install_dir}",
283
+ "-DMLX_ONNX_USE_EXTERNAL_MLX=ON",
284
+ "-DMLX_ONNX_EXTERNAL_MLX_INCLUDE_DIR=#{mlx_public_include_dir}",
285
+ "-DMLX_ONNX_EXTERNAL_MLX_LIB_DIR=#{lib_dir}",
286
+ "-DMLX_ONNX_BUILD_PYTHON_BINDINGS=OFF"
287
+ ]
288
+
289
+ cmake_onnx_build = [
290
+ "cmake",
291
+ "--build",
292
+ mlx_onnx_build_dir,
293
+ "--target",
294
+ "install",
295
+ "--config",
296
+ "Release",
297
+ "-j#{jobs}"
298
+ ]
299
+
300
+ run_or_abort(*cmake_onnx_configure, chdir: ext_root)
301
+ run_or_abort(*cmake_onnx_build, chdir: ext_root)
302
+
303
+ json_include_dirs = [
304
+ File.join(mlx_build_dir, "_deps", "json-src", "include"),
305
+ File.join(mlx_build_dir, "_deps", "json-src", "single_include"),
306
+ File.join(mlx_build_dir, "_deps", "json-src", "single_include", "nlohmann"),
307
+ File.join(mlx_onnx_build_dir, "_deps", "nlohmann_json-src", "include")
308
+ ].select { |path| Dir.exist?(path) }
309
+
85
310
  abort("missing MLX lib dir: #{lib_dir}") unless Dir.exist?(lib_dir)
86
311
 
87
- dir_config("mlx", include_dir, lib_dir)
312
+ dir_config("mlx", include_dirs.first, lib_dir)
88
313
 
89
314
  $CXXFLAGS = "#{$CXXFLAGS} -std=c++20"
90
- $CPPFLAGS = "#{$CPPFLAGS} -I#{include_dir}"
315
+ include_dirs.each do |path|
316
+ $CPPFLAGS = "#{$CPPFLAGS} -I#{path}"
317
+ end
318
+ json_include_dirs.each do |path|
319
+ $CPPFLAGS = "#{$CPPFLAGS} -I#{path}"
320
+ end
91
321
  $LDFLAGS = "#{$LDFLAGS} -L#{lib_dir} #{rpath_flag(lib_dir)}"
92
- $libs = "-lmlx #{$libs}"
322
+ $libs = "-lmlx_onnx -lmlx #{$libs}"
323
+
324
+ source_files = (
325
+ Dir.glob(File.join(ext_root, "**", "*.{c,cc,cpp,cxx}")) +
326
+ Dir.glob(File.join(ext_onnx_root, "**", "*.{c,cc,cpp,cxx}"))
327
+ ).reject { |path| path.start_with?(File.join(build_root, "")) }
328
+ .map { |path| Pathname(path).relative_path_from(Pathname(ext_root)).to_s }
329
+ .sort
330
+ header_files = (
331
+ Dir.glob(File.join(ext_root, "**", "*.{h,hpp,hh}")) +
332
+ Dir.glob(File.join(ext_onnx_root, "**", "*.{h,hpp,hh}"))
333
+ ).reject { |path| path.start_with?(File.join(build_root, "")) }
334
+ .map { |path| Pathname(path).relative_path_from(Pathname(ext_root)).to_s }
335
+ .sort
336
+ $srcs = source_files
337
+ $objs = source_files.map { |path| "#{path.sub(/\.[^.]+\z/, "")}.o" }
338
+ $hdrs = header_files
93
339
 
94
340
  create_makefile("mlx/native")
341
+ makefile_path = File.join(ext_root, "Makefile")
342
+ patch_makefile_compilers!(makefile_path, cmake_cc, cmake_cxx)
343
+ patch_makefile_sources!(makefile_path, source_files, header_files)
344
+ patch_makefile_include_dirs!(
345
+ makefile_path,
346
+ [
347
+ mlx_internal_include_dir,
348
+ mlx_onnx_public_include_dir,
349
+ mlx_onnx_internal_include_dir,
350
+ File.join(mlx_build_dir, "_deps", "json-src", "single_include", "nlohmann"),
351
+ File.join(mlx_onnx_build_dir, "_deps", "nlohmann_json-src", "include")
352
+ ]
353
+ )
data/ext/mlx/native.cpp CHANGED
@@ -10,12 +10,15 @@
10
10
  #include <limits>
11
11
  #include <optional>
12
12
  #include <sstream>
13
+ #include <stdexcept>
13
14
  #include <string>
14
15
  #include <type_traits>
15
16
  #include <unordered_map>
16
17
  #include <variant>
17
18
  #include <vector>
18
19
 
20
+ #include <nlohmann/json.hpp>
21
+
19
22
  #include "mlx/array.h"
20
23
  #include "mlx/backend/metal/metal.h"
21
24
  #include "mlx/compile.h"
@@ -37,6 +40,7 @@
37
40
  #include "mlx/transforms.h"
38
41
  #include "mlx/utils.h"
39
42
  #include "mlx/version.h"
43
+ #include "../mlx-onnx/native.hpp"
40
44
 
41
45
  namespace mx = mlx::core;
42
46
  namespace mxfft = mlx::core::fft;
@@ -44,6 +48,7 @@ namespace mxfast = mlx::core::fast;
44
48
  namespace mxlinalg = mlx::core::linalg;
45
49
  namespace mxmetal = mlx::core::metal;
46
50
  namespace mxdist = mlx::core::distributed;
51
+ using OrderedJson = nlohmann::ordered_json;
47
52
 
48
53
  static VALUE mMLX;
49
54
  static VALUE mNative;
@@ -339,14 +344,7 @@ static VALUE id_to_symbol(ID id) {
339
344
  }
340
345
 
341
346
  static ID cached_intern_id(const char* name) {
342
- static std::unordered_map<std::string, ID> cache;
343
- auto it = cache.find(name);
344
- if (it != cache.end()) {
345
- return it->second;
346
- }
347
- const ID id = rb_intern(name);
348
- cache.emplace(name, id);
349
- return id;
347
+ return rb_intern(name);
350
348
  }
351
349
 
352
350
  static mx::Device::DeviceType device_type_from_value(VALUE value) {
@@ -928,12 +926,6 @@ struct ArrayCollector {
928
926
 
929
927
  static void collect_arrays_from_tree(VALUE value, std::vector<mx::array>& arrays);
930
928
 
931
- static int hash_collect_arrays_iter(VALUE, VALUE value, VALUE arg) {
932
- auto* collector = reinterpret_cast<ArrayCollector*>(arg);
933
- collect_arrays_from_tree(value, *collector->arrays);
934
- return ST_CONTINUE;
935
- }
936
-
937
929
  static void collect_arrays_from_tree(VALUE value, std::vector<mx::array>& arrays) {
938
930
  if (rb_obj_is_kind_of(value, cArray)) {
939
931
  arrays.push_back(array_unwrap(value));
@@ -947,28 +939,37 @@ static void collect_arrays_from_tree(VALUE value, std::vector<mx::array>& arrays
947
939
  return;
948
940
  }
949
941
  if (RB_TYPE_P(value, T_HASH)) {
950
- ArrayCollector collector{&arrays};
951
- rb_hash_foreach(value, hash_collect_arrays_iter, reinterpret_cast<VALUE>(&collector));
942
+ VALUE keys = rb_funcall(value, cached_intern_id("keys"), 0);
943
+ const long len = RARRAY_LEN(keys);
944
+ for (long i = 0; i < len; ++i) {
945
+ VALUE key = rb_ary_entry(keys, i);
946
+ VALUE hash_value = rb_hash_lookup2(value, key, Qundef);
947
+ if (hash_value == Qundef) {
948
+ continue;
949
+ }
950
+ collect_arrays_from_tree(hash_value, arrays);
951
+ }
952
952
  }
953
953
  }
954
954
 
955
- struct ArrayMapBuilder {
956
- std::unordered_map<std::string, mx::array> map;
957
- };
958
-
959
- static int hash_to_array_map_iter(VALUE key, VALUE value, VALUE arg) {
960
- auto* builder = reinterpret_cast<ArrayMapBuilder*>(arg);
961
- builder->map.insert_or_assign(string_from_ruby(key), array_unwrap(value));
962
- return ST_CONTINUE;
963
- }
964
-
965
955
  static std::unordered_map<std::string, mx::array> array_map_from_ruby_hash(VALUE value) {
966
956
  if (!RB_TYPE_P(value, T_HASH)) {
967
957
  rb_raise(rb_eTypeError, "expected Hash mapping String/Symbol keys to MLX::Core::Array");
968
958
  }
969
- ArrayMapBuilder builder;
970
- rb_hash_foreach(value, hash_to_array_map_iter, reinterpret_cast<VALUE>(&builder));
971
- return builder.map;
959
+ VALUE keys = rb_funcall(value, cached_intern_id("keys"), 0);
960
+ const long len = RARRAY_LEN(keys);
961
+ std::unordered_map<std::string, mx::array> out;
962
+ out.reserve(static_cast<size_t>(len));
963
+ for (long i = 0; i < len; ++i) {
964
+ VALUE key = rb_ary_entry(keys, i);
965
+ VALUE hash_value = rb_hash_lookup2(value, key, Qundef);
966
+ if (hash_value == Qundef) {
967
+ continue;
968
+ }
969
+ std::string ruby_key = string_from_ruby(key);
970
+ out.insert_or_assign(ruby_key, array_unwrap(hash_value));
971
+ }
972
+ return out;
972
973
  }
973
974
 
974
975
  static VALUE ruby_hash_of_arrays(const std::unordered_map<std::string, mx::array>& map) {
@@ -990,16 +991,6 @@ static VALUE ruby_hash_of_strings(const std::unordered_map<std::string, std::str
990
991
  return out;
991
992
  }
992
993
 
993
- struct StringMapBuilder {
994
- std::unordered_map<std::string, std::string> map;
995
- };
996
-
997
- static int hash_to_string_map_iter(VALUE key, VALUE value, VALUE arg) {
998
- auto* builder = reinterpret_cast<StringMapBuilder*>(arg);
999
- builder->map.insert_or_assign(string_from_ruby(key), string_from_ruby(value));
1000
- return ST_CONTINUE;
1001
- }
1002
-
1003
994
  static std::unordered_map<std::string, std::string> string_map_from_ruby_hash(VALUE value) {
1004
995
  if (NIL_P(value)) {
1005
996
  return {};
@@ -1007,9 +998,20 @@ static std::unordered_map<std::string, std::string> string_map_from_ruby_hash(VA
1007
998
  if (!RB_TYPE_P(value, T_HASH)) {
1008
999
  rb_raise(rb_eTypeError, "expected Hash mapping String/Symbol keys to String values");
1009
1000
  }
1010
- StringMapBuilder builder;
1011
- rb_hash_foreach(value, hash_to_string_map_iter, reinterpret_cast<VALUE>(&builder));
1012
- return builder.map;
1001
+ VALUE keys = rb_funcall(value, cached_intern_id("keys"), 0);
1002
+ const long len = RARRAY_LEN(keys);
1003
+ std::unordered_map<std::string, std::string> out;
1004
+ out.reserve(static_cast<size_t>(len));
1005
+ for (long i = 0; i < len; ++i) {
1006
+ VALUE key = rb_ary_entry(keys, i);
1007
+ VALUE hash_value = rb_hash_lookup2(value, key, Qundef);
1008
+ if (hash_value == Qundef) {
1009
+ continue;
1010
+ }
1011
+ std::string ruby_key = string_from_ruby(key);
1012
+ out.insert_or_assign(ruby_key, string_from_ruby(hash_value));
1013
+ }
1014
+ return out;
1013
1015
  }
1014
1016
 
1015
1017
  static mx::GGUFMetaData gguf_metadata_from_ruby(VALUE value) {
@@ -1038,16 +1040,6 @@ static mx::GGUFMetaData gguf_metadata_from_ruby(VALUE value) {
1038
1040
  return std::monostate{};
1039
1041
  }
1040
1042
 
1041
- struct GGUFMetaMapBuilder {
1042
- std::unordered_map<std::string, mx::GGUFMetaData> map;
1043
- };
1044
-
1045
- static int hash_to_gguf_meta_map_iter(VALUE key, VALUE value, VALUE arg) {
1046
- auto* builder = reinterpret_cast<GGUFMetaMapBuilder*>(arg);
1047
- builder->map.insert_or_assign(string_from_ruby(key), gguf_metadata_from_ruby(value));
1048
- return ST_CONTINUE;
1049
- }
1050
-
1051
1043
  static std::unordered_map<std::string, mx::GGUFMetaData> gguf_meta_map_from_ruby_hash(VALUE value) {
1052
1044
  if (NIL_P(value)) {
1053
1045
  return {};
@@ -1055,9 +1047,20 @@ static std::unordered_map<std::string, mx::GGUFMetaData> gguf_meta_map_from_ruby
1055
1047
  if (!RB_TYPE_P(value, T_HASH)) {
1056
1048
  rb_raise(rb_eTypeError, "expected Hash for GGUF metadata");
1057
1049
  }
1058
- GGUFMetaMapBuilder builder;
1059
- rb_hash_foreach(value, hash_to_gguf_meta_map_iter, reinterpret_cast<VALUE>(&builder));
1060
- return builder.map;
1050
+ VALUE keys = rb_funcall(value, cached_intern_id("keys"), 0);
1051
+ const long len = RARRAY_LEN(keys);
1052
+ std::unordered_map<std::string, mx::GGUFMetaData> out;
1053
+ out.reserve(static_cast<size_t>(len));
1054
+ for (long i = 0; i < len; ++i) {
1055
+ VALUE key = rb_ary_entry(keys, i);
1056
+ VALUE hash_value = rb_hash_lookup2(value, key, Qundef);
1057
+ if (hash_value == Qundef) {
1058
+ continue;
1059
+ }
1060
+ std::string ruby_key = string_from_ruby(key);
1061
+ out.insert_or_assign(ruby_key, gguf_metadata_from_ruby(hash_value));
1062
+ }
1063
+ return out;
1061
1064
  }
1062
1065
 
1063
1066
  static VALUE gguf_metadata_to_ruby(const mx::GGUFMetaData& value) {
@@ -1501,6 +1504,23 @@ args_kwargs_function_from_callable(VALUE callable) {
1501
1504
  };
1502
1505
  }
1503
1506
 
1507
+ mx::array onnx_array_from_ruby(VALUE value) {
1508
+ return array_from_ruby(value, std::nullopt);
1509
+ }
1510
+
1511
+ std::vector<mx::array> onnx_array_vector_from_ruby(VALUE value) {
1512
+ return array_vector_from_ruby(value);
1513
+ }
1514
+
1515
+ std::unordered_map<std::string, mx::array> onnx_array_map_from_ruby_hash(VALUE value) {
1516
+ return array_map_from_ruby_hash(value);
1517
+ }
1518
+
1519
+ std::function<std::vector<mx::array>(const mx::Args&, const mx::Kwargs&)>
1520
+ onnx_args_kwargs_function_from_callable(VALUE callable) {
1521
+ return args_kwargs_function_from_callable(callable);
1522
+ }
1523
+
1504
1524
  static std::vector<int> argnums_from_value(VALUE value) {
1505
1525
  if (NIL_P(value)) {
1506
1526
  return {0};
@@ -3495,7 +3515,7 @@ static VALUE core_dequantize(int argc, VALUE* argv, VALUE) {
3495
3515
  dtype = optional_dtype_from_value(argv[6]);
3496
3516
  }
3497
3517
 
3498
- return array_wrap(mx::dequantize(w, scales, biases, group_size, bits, mode, dtype));
3518
+ return array_wrap(mx::dequantize(w, scales, biases, group_size, bits, mode, std::nullopt, dtype));
3499
3519
  } catch (const std::exception& error) {
3500
3520
  raise_std_exception(error);
3501
3521
  return Qnil;
@@ -6459,6 +6479,30 @@ static VALUE core_identity(int argc, VALUE* argv, VALUE) {
6459
6479
  }
6460
6480
  }
6461
6481
 
6482
+ static VALUE core_hanning(int argc, VALUE* argv, VALUE) {
6483
+ try {
6484
+ VALUE m;
6485
+ VALUE stream;
6486
+ rb_scan_args(argc, argv, "11", &m, &stream);
6487
+ return array_wrap(mx::hanning(NUM2INT(m), stream_or_device_from_value(stream)));
6488
+ } catch (const std::exception& error) {
6489
+ raise_std_exception(error);
6490
+ return Qnil;
6491
+ }
6492
+ }
6493
+
6494
+ static VALUE core_hamming(int argc, VALUE* argv, VALUE) {
6495
+ try {
6496
+ VALUE m;
6497
+ VALUE stream;
6498
+ rb_scan_args(argc, argv, "11", &m, &stream);
6499
+ return array_wrap(mx::hamming(NUM2INT(m), stream_or_device_from_value(stream)));
6500
+ } catch (const std::exception& error) {
6501
+ raise_std_exception(error);
6502
+ return Qnil;
6503
+ }
6504
+ }
6505
+
6462
6506
  static VALUE core_tri(int argc, VALUE* argv, VALUE) {
6463
6507
  try {
6464
6508
  if (argc < 1 || argc > 4) {
@@ -7622,6 +7666,7 @@ extern "C" void Init_native(void) {
7622
7666
  rb_define_singleton_method(mNative, "loaded?", RUBY_METHOD_FUNC(native_loaded_p), 0);
7623
7667
 
7624
7668
  mCore = rb_define_module_under(mMLX, "Core");
7669
+ init_onnx_native_bindings(mMLX);
7625
7670
  rb_define_singleton_method(mCore, "version", RUBY_METHOD_FUNC(core_version), 0);
7626
7671
 
7627
7672
  rb_define_singleton_method(mCore, "get_active_memory", RUBY_METHOD_FUNC(core_get_active_memory), 0);
@@ -7899,6 +7944,8 @@ extern "C" void Init_native(void) {
7899
7944
  rb_define_singleton_method(mCore, "ones_like", RUBY_METHOD_FUNC(core_ones_like), 1);
7900
7945
  rb_define_singleton_method(mCore, "eye", RUBY_METHOD_FUNC(core_eye), -1);
7901
7946
  rb_define_singleton_method(mCore, "identity", RUBY_METHOD_FUNC(core_identity), -1);
7947
+ rb_define_singleton_method(mCore, "hanning", RUBY_METHOD_FUNC(core_hanning), -1);
7948
+ rb_define_singleton_method(mCore, "hamming", RUBY_METHOD_FUNC(core_hamming), -1);
7902
7949
  rb_define_singleton_method(mCore, "tri", RUBY_METHOD_FUNC(core_tri), -1);
7903
7950
  rb_define_singleton_method(mCore, "tril", RUBY_METHOD_FUNC(core_tril), -1);
7904
7951
  rb_define_singleton_method(mCore, "triu", RUBY_METHOD_FUNC(core_triu), -1);
@@ -8026,4 +8073,5 @@ extern "C" void Init_native(void) {
8026
8073
  "precompiled_cuda_kernel",
8027
8074
  RUBY_METHOD_FUNC(core_precompiled_cuda_kernel),
8028
8075
  -1);
8076
+
8029
8077
  }