mlx 0.30.7.2 → 0.30.7.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (605):
  1. checksums.yaml +4 -4
  2. data/ext/mlx/extconf.rb +267 -8
  3. data/ext/mlx/native.cpp +112 -58
  4. data/ext/mlx-onnx/native.cpp +1402 -0
  5. data/ext/mlx-onnx/native.hpp +19 -0
  6. data/lib/mlx/core.rb +342 -117
  7. data/lib/mlx/distributed_utils/common.rb +1 -1
  8. data/lib/mlx/distributed_utils/config.rb +7 -4
  9. data/lib/mlx/distributed_utils/launch.rb +2 -0
  10. data/lib/mlx/dsl/attention.rb +132 -0
  11. data/lib/mlx/dsl/builder.rb +8 -0
  12. data/lib/mlx/dsl/config_schema.rb +133 -0
  13. data/lib/mlx/dsl/generate.rb +193 -0
  14. data/lib/mlx/dsl/kv_cache.rb +96 -0
  15. data/lib/mlx/dsl/masks.rb +32 -0
  16. data/lib/mlx/dsl/positions.rb +35 -0
  17. data/lib/mlx/dsl/run_stack.rb +68 -0
  18. data/lib/mlx/dsl/tensor.rb +126 -0
  19. data/lib/mlx/dsl/transformer_block.rb +113 -0
  20. data/lib/mlx/dsl/weight_map.rb +140 -0
  21. data/lib/mlx/dsl.rb +10 -0
  22. data/lib/mlx/nn/base.rb +4 -0
  23. data/lib/mlx/nn/layers/linear.rb +2 -3
  24. data/lib/mlx/onnx.rb +250 -0
  25. data/lib/mlx/version.rb +1 -1
  26. data/lib/mlx-onnx/webgpu_harness.rb +289 -0
  27. data/{mlx → submodules/mlx}/mlx/backend/cuda/cublas_utils.cpp +0 -7
  28. data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/cublas_gemm.cpp +10 -2
  29. data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/cublas_qqmm.cpp +97 -46
  30. data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/cublas_qqmm.h +25 -13
  31. data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/fp_quantize.cu +101 -38
  32. data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/no_qqmm_impl.cpp +1 -2
  33. data/submodules/mlx/mlx/backend/cuda/quantized/qqmm.cpp +193 -0
  34. data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/qqmm_impl.cpp +15 -8
  35. data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/qqmm_impl.h +14 -3
  36. data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/qqmm_utils.cu +36 -0
  37. data/submodules/mlx/mlx/backend/cuda/quantized/qqmm_utils.h +62 -0
  38. data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/quantized.cpp +12 -3
  39. data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/quantized.h +4 -0
  40. data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/quantized_utils.cuh +1 -1
  41. data/{mlx → submodules/mlx}/mlx/backend/metal/device.cpp +4 -0
  42. data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/conv.metal +3 -2
  43. data/{mlx → submodules/mlx}/mlx/export.cpp +21 -6
  44. data/{mlx → submodules/mlx}/mlx/ops.cpp +144 -13
  45. data/{mlx → submodules/mlx}/mlx/ops.h +12 -2
  46. data/{mlx → submodules/mlx}/mlx/primitives.cpp +22 -5
  47. data/{mlx → submodules/mlx}/mlx/scheduler.cpp +4 -0
  48. data/{mlx → submodules/mlx}/mlx/scheduler.h +3 -0
  49. data/{mlx → submodules/mlx}/mlx/stream.h +5 -0
  50. data/submodules/mlx-onnx/CMakeLists.txt +159 -0
  51. data/submodules/mlx-onnx/LICENSE +21 -0
  52. data/submodules/mlx-onnx/include/mlx/ir.hpp +88 -0
  53. data/submodules/mlx-onnx/src/api.cpp +81 -0
  54. data/submodules/mlx-onnx/src/compat.cpp +111 -0
  55. data/submodules/mlx-onnx/src/detail.hpp +69 -0
  56. data/submodules/mlx-onnx/src/export.cpp +653 -0
  57. data/submodules/mlx-onnx/src/io.cpp +61 -0
  58. data/submodules/mlx-onnx/src/json.hpp +25 -0
  59. data/submodules/mlx-onnx/src/lowering.cpp +6346 -0
  60. data/submodules/mlx-onnx/src/mappings.cpp +201 -0
  61. data/submodules/mlx-onnx/src/mappings.hpp +16 -0
  62. data/submodules/mlx-onnx/src/onnx.cpp +1029 -0
  63. data/submodules/mlx-onnx/src/shared.cpp +206 -0
  64. metadata +665 -567
  65. data/mlx/mlx/backend/cuda/quantized/qqmm.cpp +0 -158
  66. data/mlx/mlx/backend/cuda/quantized/qqmm_utils.h +0 -30
  67. /data/{mlx → submodules/mlx}/CMakeLists.txt +0 -0
  68. /data/{mlx → submodules/mlx}/cmake/FindCUDNN.cmake +0 -0
  69. /data/{mlx → submodules/mlx}/cmake/FindNCCL.cmake +0 -0
  70. /data/{mlx → submodules/mlx}/cmake/Findnvpl.cmake +0 -0
  71. /data/{mlx → submodules/mlx}/cmake/extension.cmake +0 -0
  72. /data/{mlx → submodules/mlx}/mlx/3rdparty/.clang-format +0 -0
  73. /data/{mlx → submodules/mlx}/mlx/3rdparty/pocketfft.h +0 -0
  74. /data/{mlx → submodules/mlx}/mlx/CMakeLists.txt +0 -0
  75. /data/{mlx → submodules/mlx}/mlx/allocator.h +0 -0
  76. /data/{mlx → submodules/mlx}/mlx/api.h +0 -0
  77. /data/{mlx → submodules/mlx}/mlx/array.cpp +0 -0
  78. /data/{mlx → submodules/mlx}/mlx/array.h +0 -0
  79. /data/{mlx → submodules/mlx}/mlx/backend/common/CMakeLists.txt +0 -0
  80. /data/{mlx → submodules/mlx}/mlx/backend/common/binary.h +0 -0
  81. /data/{mlx → submodules/mlx}/mlx/backend/common/broadcasting.cpp +0 -0
  82. /data/{mlx → submodules/mlx}/mlx/backend/common/broadcasting.h +0 -0
  83. /data/{mlx → submodules/mlx}/mlx/backend/common/buffer_cache.h +0 -0
  84. /data/{mlx → submodules/mlx}/mlx/backend/common/common.cpp +0 -0
  85. /data/{mlx → submodules/mlx}/mlx/backend/common/compiled.cpp +0 -0
  86. /data/{mlx → submodules/mlx}/mlx/backend/common/compiled.h +0 -0
  87. /data/{mlx → submodules/mlx}/mlx/backend/common/copy.h +0 -0
  88. /data/{mlx → submodules/mlx}/mlx/backend/common/hadamard.h +0 -0
  89. /data/{mlx → submodules/mlx}/mlx/backend/common/load.cpp +0 -0
  90. /data/{mlx → submodules/mlx}/mlx/backend/common/matmul.h +0 -0
  91. /data/{mlx → submodules/mlx}/mlx/backend/common/reduce.cpp +0 -0
  92. /data/{mlx → submodules/mlx}/mlx/backend/common/reduce.h +0 -0
  93. /data/{mlx → submodules/mlx}/mlx/backend/common/slicing.cpp +0 -0
  94. /data/{mlx → submodules/mlx}/mlx/backend/common/slicing.h +0 -0
  95. /data/{mlx → submodules/mlx}/mlx/backend/common/ternary.h +0 -0
  96. /data/{mlx → submodules/mlx}/mlx/backend/common/unary.h +0 -0
  97. /data/{mlx → submodules/mlx}/mlx/backend/common/utils.cpp +0 -0
  98. /data/{mlx → submodules/mlx}/mlx/backend/common/utils.h +0 -0
  99. /data/{mlx → submodules/mlx}/mlx/backend/cpu/CMakeLists.txt +0 -0
  100. /data/{mlx → submodules/mlx}/mlx/backend/cpu/arange.h +0 -0
  101. /data/{mlx → submodules/mlx}/mlx/backend/cpu/arg_reduce.cpp +0 -0
  102. /data/{mlx → submodules/mlx}/mlx/backend/cpu/binary.cpp +0 -0
  103. /data/{mlx → submodules/mlx}/mlx/backend/cpu/binary.h +0 -0
  104. /data/{mlx → submodules/mlx}/mlx/backend/cpu/binary_ops.h +0 -0
  105. /data/{mlx → submodules/mlx}/mlx/backend/cpu/binary_two.h +0 -0
  106. /data/{mlx → submodules/mlx}/mlx/backend/cpu/cholesky.cpp +0 -0
  107. /data/{mlx → submodules/mlx}/mlx/backend/cpu/compiled.cpp +0 -0
  108. /data/{mlx → submodules/mlx}/mlx/backend/cpu/compiled_preamble.h +0 -0
  109. /data/{mlx → submodules/mlx}/mlx/backend/cpu/conv.cpp +0 -0
  110. /data/{mlx → submodules/mlx}/mlx/backend/cpu/copy.cpp +0 -0
  111. /data/{mlx → submodules/mlx}/mlx/backend/cpu/copy.h +0 -0
  112. /data/{mlx → submodules/mlx}/mlx/backend/cpu/device_info.cpp +0 -0
  113. /data/{mlx → submodules/mlx}/mlx/backend/cpu/device_info.h +0 -0
  114. /data/{mlx → submodules/mlx}/mlx/backend/cpu/distributed.cpp +0 -0
  115. /data/{mlx → submodules/mlx}/mlx/backend/cpu/eig.cpp +0 -0
  116. /data/{mlx → submodules/mlx}/mlx/backend/cpu/eigh.cpp +0 -0
  117. /data/{mlx → submodules/mlx}/mlx/backend/cpu/encoder.cpp +0 -0
  118. /data/{mlx → submodules/mlx}/mlx/backend/cpu/encoder.h +0 -0
  119. /data/{mlx → submodules/mlx}/mlx/backend/cpu/eval.cpp +0 -0
  120. /data/{mlx → submodules/mlx}/mlx/backend/cpu/eval.h +0 -0
  121. /data/{mlx → submodules/mlx}/mlx/backend/cpu/fft.cpp +0 -0
  122. /data/{mlx → submodules/mlx}/mlx/backend/cpu/gemm.h +0 -0
  123. /data/{mlx → submodules/mlx}/mlx/backend/cpu/gemms/bnns.cpp +0 -0
  124. /data/{mlx → submodules/mlx}/mlx/backend/cpu/gemms/cblas.cpp +0 -0
  125. /data/{mlx → submodules/mlx}/mlx/backend/cpu/gemms/simd_bf16.cpp +0 -0
  126. /data/{mlx → submodules/mlx}/mlx/backend/cpu/gemms/simd_fp16.cpp +0 -0
  127. /data/{mlx → submodules/mlx}/mlx/backend/cpu/gemms/simd_gemm.h +0 -0
  128. /data/{mlx → submodules/mlx}/mlx/backend/cpu/hadamard.cpp +0 -0
  129. /data/{mlx → submodules/mlx}/mlx/backend/cpu/indexing.cpp +0 -0
  130. /data/{mlx → submodules/mlx}/mlx/backend/cpu/inverse.cpp +0 -0
  131. /data/{mlx → submodules/mlx}/mlx/backend/cpu/jit_compiler.cpp +0 -0
  132. /data/{mlx → submodules/mlx}/mlx/backend/cpu/jit_compiler.h +0 -0
  133. /data/{mlx → submodules/mlx}/mlx/backend/cpu/lapack.h +0 -0
  134. /data/{mlx → submodules/mlx}/mlx/backend/cpu/logsumexp.cpp +0 -0
  135. /data/{mlx → submodules/mlx}/mlx/backend/cpu/luf.cpp +0 -0
  136. /data/{mlx → submodules/mlx}/mlx/backend/cpu/make_compiled_preamble.ps1 +0 -0
  137. /data/{mlx → submodules/mlx}/mlx/backend/cpu/make_compiled_preamble.sh +0 -0
  138. /data/{mlx → submodules/mlx}/mlx/backend/cpu/masked_mm.cpp +0 -0
  139. /data/{mlx → submodules/mlx}/mlx/backend/cpu/matmul.cpp +0 -0
  140. /data/{mlx → submodules/mlx}/mlx/backend/cpu/primitives.cpp +0 -0
  141. /data/{mlx → submodules/mlx}/mlx/backend/cpu/qrf.cpp +0 -0
  142. /data/{mlx → submodules/mlx}/mlx/backend/cpu/quantized.cpp +0 -0
  143. /data/{mlx → submodules/mlx}/mlx/backend/cpu/reduce.cpp +0 -0
  144. /data/{mlx → submodules/mlx}/mlx/backend/cpu/scan.cpp +0 -0
  145. /data/{mlx → submodules/mlx}/mlx/backend/cpu/select.cpp +0 -0
  146. /data/{mlx → submodules/mlx}/mlx/backend/cpu/simd/accelerate_fp16_simd.h +0 -0
  147. /data/{mlx → submodules/mlx}/mlx/backend/cpu/simd/accelerate_simd.h +0 -0
  148. /data/{mlx → submodules/mlx}/mlx/backend/cpu/simd/base_simd.h +0 -0
  149. /data/{mlx → submodules/mlx}/mlx/backend/cpu/simd/math.h +0 -0
  150. /data/{mlx → submodules/mlx}/mlx/backend/cpu/simd/neon_fp16_simd.h +0 -0
  151. /data/{mlx → submodules/mlx}/mlx/backend/cpu/simd/simd.h +0 -0
  152. /data/{mlx → submodules/mlx}/mlx/backend/cpu/simd/type.h +0 -0
  153. /data/{mlx → submodules/mlx}/mlx/backend/cpu/slicing.h +0 -0
  154. /data/{mlx → submodules/mlx}/mlx/backend/cpu/softmax.cpp +0 -0
  155. /data/{mlx → submodules/mlx}/mlx/backend/cpu/sort.cpp +0 -0
  156. /data/{mlx → submodules/mlx}/mlx/backend/cpu/svd.cpp +0 -0
  157. /data/{mlx → submodules/mlx}/mlx/backend/cpu/ternary.h +0 -0
  158. /data/{mlx → submodules/mlx}/mlx/backend/cpu/threefry.cpp +0 -0
  159. /data/{mlx → submodules/mlx}/mlx/backend/cpu/threefry.h +0 -0
  160. /data/{mlx → submodules/mlx}/mlx/backend/cpu/unary.cpp +0 -0
  161. /data/{mlx → submodules/mlx}/mlx/backend/cpu/unary.h +0 -0
  162. /data/{mlx → submodules/mlx}/mlx/backend/cpu/unary_ops.h +0 -0
  163. /data/{mlx → submodules/mlx}/mlx/backend/cuda/CMakeLists.txt +0 -0
  164. /data/{mlx → submodules/mlx}/mlx/backend/cuda/allocator.cpp +0 -0
  165. /data/{mlx → submodules/mlx}/mlx/backend/cuda/allocator.h +0 -0
  166. /data/{mlx → submodules/mlx}/mlx/backend/cuda/arange.cu +0 -0
  167. /data/{mlx → submodules/mlx}/mlx/backend/cuda/arg_reduce.cu +0 -0
  168. /data/{mlx → submodules/mlx}/mlx/backend/cuda/bin2h.cmake +0 -0
  169. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/CMakeLists.txt +0 -0
  170. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/add.cu +0 -0
  171. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/arctan2.cu +0 -0
  172. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/binary.cuh +0 -0
  173. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/bitwise_binary.cu +0 -0
  174. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/divide.cu +0 -0
  175. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/equal.cu +0 -0
  176. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/greater.cu +0 -0
  177. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/greater_equal.cu +0 -0
  178. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/less.cu +0 -0
  179. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/less_equal.cu +0 -0
  180. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/log_add_exp.cu +0 -0
  181. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/logical_and.cu +0 -0
  182. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/logical_or.cu +0 -0
  183. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/maximum.cu +0 -0
  184. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/minimum.cu +0 -0
  185. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/multiply.cu +0 -0
  186. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/not_equal.cu +0 -0
  187. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/power.cu +0 -0
  188. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/remainder.cu +0 -0
  189. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/subtract.cu +0 -0
  190. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary_two.cu +0 -0
  191. /data/{mlx → submodules/mlx}/mlx/backend/cuda/compiled.cpp +0 -0
  192. /data/{mlx → submodules/mlx}/mlx/backend/cuda/conv/conv.h +0 -0
  193. /data/{mlx → submodules/mlx}/mlx/backend/cuda/conv/gemm_conv.cu +0 -0
  194. /data/{mlx → submodules/mlx}/mlx/backend/cuda/conv/gemm_grouped_conv.cu +0 -0
  195. /data/{mlx → submodules/mlx}/mlx/backend/cuda/conv.cpp +0 -0
  196. /data/{mlx → submodules/mlx}/mlx/backend/cuda/copy/copy.cuh +0 -0
  197. /data/{mlx → submodules/mlx}/mlx/backend/cuda/copy/copy_contiguous.cu +0 -0
  198. /data/{mlx → submodules/mlx}/mlx/backend/cuda/copy/copy_general.cu +0 -0
  199. /data/{mlx → submodules/mlx}/mlx/backend/cuda/copy/copy_general_dynamic.cu +0 -0
  200. /data/{mlx → submodules/mlx}/mlx/backend/cuda/copy/copy_general_input.cu +0 -0
  201. /data/{mlx → submodules/mlx}/mlx/backend/cuda/copy.cu +0 -0
  202. /data/{mlx → submodules/mlx}/mlx/backend/cuda/cublas_utils.h +0 -0
  203. /data/{mlx → submodules/mlx}/mlx/backend/cuda/cuda.h +0 -0
  204. /data/{mlx → submodules/mlx}/mlx/backend/cuda/cuda_utils.h +0 -0
  205. /data/{mlx → submodules/mlx}/mlx/backend/cuda/cudnn_utils.cpp +0 -0
  206. /data/{mlx → submodules/mlx}/mlx/backend/cuda/cudnn_utils.h +0 -0
  207. /data/{mlx → submodules/mlx}/mlx/backend/cuda/custom_kernel.cpp +0 -0
  208. /data/{mlx → submodules/mlx}/mlx/backend/cuda/cutlass_utils.cuh +0 -0
  209. /data/{mlx → submodules/mlx}/mlx/backend/cuda/delayload.cpp +0 -0
  210. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/atomic_ops.cuh +0 -0
  211. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/binary_ops.cuh +0 -0
  212. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/cast_op.cuh +0 -0
  213. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/complex.cuh +0 -0
  214. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/config.h +0 -0
  215. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/fp16_math.cuh +0 -0
  216. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/gather.cuh +0 -0
  217. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/gather_axis.cuh +0 -0
  218. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/indexing.cuh +0 -0
  219. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/scatter.cuh +0 -0
  220. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/scatter_axis.cuh +0 -0
  221. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/scatter_ops.cuh +0 -0
  222. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/ternary_ops.cuh +0 -0
  223. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/unary_ops.cuh +0 -0
  224. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/utils.cuh +0 -0
  225. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device.cpp +0 -0
  226. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device.h +0 -0
  227. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device_info.cpp +0 -0
  228. /data/{mlx → submodules/mlx}/mlx/backend/cuda/distributed.cu +0 -0
  229. /data/{mlx → submodules/mlx}/mlx/backend/cuda/eval.cpp +0 -0
  230. /data/{mlx → submodules/mlx}/mlx/backend/cuda/event.cu +0 -0
  231. /data/{mlx → submodules/mlx}/mlx/backend/cuda/event.h +0 -0
  232. /data/{mlx → submodules/mlx}/mlx/backend/cuda/fence.cpp +0 -0
  233. /data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/cublas_gemm.h +0 -0
  234. /data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/cublas_gemm_batched_12_0.cpp +0 -0
  235. /data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/cublas_gemm_batched_12_9.cu +0 -0
  236. /data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/gemv.cu +0 -0
  237. /data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/gemv.h +0 -0
  238. /data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/grouped_gemm.h +0 -0
  239. /data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/grouped_gemm_unaligned.cu +0 -0
  240. /data/{mlx → submodules/mlx}/mlx/backend/cuda/indexing.cpp +0 -0
  241. /data/{mlx → submodules/mlx}/mlx/backend/cuda/jit_module.cpp +0 -0
  242. /data/{mlx → submodules/mlx}/mlx/backend/cuda/jit_module.h +0 -0
  243. /data/{mlx → submodules/mlx}/mlx/backend/cuda/kernel_utils.cu +0 -0
  244. /data/{mlx → submodules/mlx}/mlx/backend/cuda/kernel_utils.cuh +0 -0
  245. /data/{mlx → submodules/mlx}/mlx/backend/cuda/layer_norm.cu +0 -0
  246. /data/{mlx → submodules/mlx}/mlx/backend/cuda/load.cpp +0 -0
  247. /data/{mlx → submodules/mlx}/mlx/backend/cuda/logsumexp.cu +0 -0
  248. /data/{mlx → submodules/mlx}/mlx/backend/cuda/lru_cache.h +0 -0
  249. /data/{mlx → submodules/mlx}/mlx/backend/cuda/matmul.cpp +0 -0
  250. /data/{mlx → submodules/mlx}/mlx/backend/cuda/no_cuda.cpp +0 -0
  251. /data/{mlx → submodules/mlx}/mlx/backend/cuda/primitives.cpp +0 -0
  252. /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/affine_quantize.cu +0 -0
  253. /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/convert_fp8.cu +0 -0
  254. /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/cuda_fp4.h +0 -0
  255. /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/mxfp8_quantize.cuh +0 -0
  256. /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/nvfp4_quantize.cuh +0 -0
  257. /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/qmv.cu +0 -0
  258. /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/qmv.h +0 -0
  259. /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/quantized_utils.h +0 -0
  260. /data/{mlx → submodules/mlx}/mlx/backend/cuda/random.cu +0 -0
  261. /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce/all_reduce.cu +0 -0
  262. /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce/col_reduce.cu +0 -0
  263. /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce/init_reduce.cu +0 -0
  264. /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce/reduce.cuh +0 -0
  265. /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce/reduce_ops.cuh +0 -0
  266. /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce/reduce_utils.cuh +0 -0
  267. /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce/row_reduce.cu +0 -0
  268. /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce.cu +0 -0
  269. /data/{mlx → submodules/mlx}/mlx/backend/cuda/rms_norm.cu +0 -0
  270. /data/{mlx → submodules/mlx}/mlx/backend/cuda/rope.cu +0 -0
  271. /data/{mlx → submodules/mlx}/mlx/backend/cuda/scaled_dot_product_attention.cpp +0 -0
  272. /data/{mlx → submodules/mlx}/mlx/backend/cuda/scaled_dot_product_attention.cu +0 -0
  273. /data/{mlx → submodules/mlx}/mlx/backend/cuda/scan.cu +0 -0
  274. /data/{mlx → submodules/mlx}/mlx/backend/cuda/slicing.cpp +0 -0
  275. /data/{mlx → submodules/mlx}/mlx/backend/cuda/softmax.cu +0 -0
  276. /data/{mlx → submodules/mlx}/mlx/backend/cuda/sort.cu +0 -0
  277. /data/{mlx → submodules/mlx}/mlx/backend/cuda/steel/defines.cuh +0 -0
  278. /data/{mlx → submodules/mlx}/mlx/backend/cuda/steel/gemm.cuh +0 -0
  279. /data/{mlx → submodules/mlx}/mlx/backend/cuda/steel/mma.cuh +0 -0
  280. /data/{mlx → submodules/mlx}/mlx/backend/cuda/steel/tiles.cuh +0 -0
  281. /data/{mlx → submodules/mlx}/mlx/backend/cuda/steel/utils.cuh +0 -0
  282. /data/{mlx → submodules/mlx}/mlx/backend/cuda/ternary.cu +0 -0
  283. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/CMakeLists.txt +0 -0
  284. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/abs.cu +0 -0
  285. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/arccos.cu +0 -0
  286. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/arccosh.cu +0 -0
  287. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/arcsin.cu +0 -0
  288. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/arcsinh.cu +0 -0
  289. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/arctan.cu +0 -0
  290. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/arctanh.cu +0 -0
  291. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/bitwise_invert.cu +0 -0
  292. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/ceil.cu +0 -0
  293. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/conjugate.cu +0 -0
  294. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/cos.cu +0 -0
  295. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/cosh.cu +0 -0
  296. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/erf.cu +0 -0
  297. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/erf_inv.cu +0 -0
  298. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/exp.cu +0 -0
  299. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/expm1.cu +0 -0
  300. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/floor.cu +0 -0
  301. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/imag.cu +0 -0
  302. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/log.cu +0 -0
  303. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/log1p.cu +0 -0
  304. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/logical_not.cu +0 -0
  305. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/negative.cu +0 -0
  306. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/real.cu +0 -0
  307. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/round.cu +0 -0
  308. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/sigmoid.cu +0 -0
  309. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/sign.cu +0 -0
  310. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/sin.cu +0 -0
  311. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/sinh.cu +0 -0
  312. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/sqrt.cu +0 -0
  313. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/square.cu +0 -0
  314. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/tan.cu +0 -0
  315. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/tanh.cu +0 -0
  316. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/unary.cuh +0 -0
  317. /data/{mlx → submodules/mlx}/mlx/backend/cuda/utils.cpp +0 -0
  318. /data/{mlx → submodules/mlx}/mlx/backend/cuda/utils.h +0 -0
  319. /data/{mlx → submodules/mlx}/mlx/backend/cuda/vector_types.cuh +0 -0
  320. /data/{mlx → submodules/mlx}/mlx/backend/cuda/worker.cpp +0 -0
  321. /data/{mlx → submodules/mlx}/mlx/backend/cuda/worker.h +0 -0
  322. /data/{mlx → submodules/mlx}/mlx/backend/gpu/CMakeLists.txt +0 -0
  323. /data/{mlx → submodules/mlx}/mlx/backend/gpu/copy.cpp +0 -0
  324. /data/{mlx → submodules/mlx}/mlx/backend/gpu/copy.h +0 -0
  325. /data/{mlx → submodules/mlx}/mlx/backend/gpu/device_info.h +0 -0
  326. /data/{mlx → submodules/mlx}/mlx/backend/gpu/eval.h +0 -0
  327. /data/{mlx → submodules/mlx}/mlx/backend/gpu/primitives.cpp +0 -0
  328. /data/{mlx → submodules/mlx}/mlx/backend/gpu/slicing.cpp +0 -0
  329. /data/{mlx → submodules/mlx}/mlx/backend/gpu/slicing.h +0 -0
  330. /data/{mlx → submodules/mlx}/mlx/backend/metal/CMakeLists.txt +0 -0
  331. /data/{mlx → submodules/mlx}/mlx/backend/metal/allocator.cpp +0 -0
  332. /data/{mlx → submodules/mlx}/mlx/backend/metal/allocator.h +0 -0
  333. /data/{mlx → submodules/mlx}/mlx/backend/metal/binary.cpp +0 -0
  334. /data/{mlx → submodules/mlx}/mlx/backend/metal/binary.h +0 -0
  335. /data/{mlx → submodules/mlx}/mlx/backend/metal/compiled.cpp +0 -0
  336. /data/{mlx → submodules/mlx}/mlx/backend/metal/conv.cpp +0 -0
  337. /data/{mlx → submodules/mlx}/mlx/backend/metal/copy.cpp +0 -0
  338. /data/{mlx → submodules/mlx}/mlx/backend/metal/custom_kernel.cpp +0 -0
  339. /data/{mlx → submodules/mlx}/mlx/backend/metal/device.h +0 -0
  340. /data/{mlx → submodules/mlx}/mlx/backend/metal/device_info.cpp +0 -0
  341. /data/{mlx → submodules/mlx}/mlx/backend/metal/distributed.cpp +0 -0
  342. /data/{mlx → submodules/mlx}/mlx/backend/metal/eval.cpp +0 -0
  343. /data/{mlx → submodules/mlx}/mlx/backend/metal/event.cpp +0 -0
  344. /data/{mlx → submodules/mlx}/mlx/backend/metal/fence.cpp +0 -0
  345. /data/{mlx → submodules/mlx}/mlx/backend/metal/fft.cpp +0 -0
  346. /data/{mlx → submodules/mlx}/mlx/backend/metal/hadamard.cpp +0 -0
  347. /data/{mlx → submodules/mlx}/mlx/backend/metal/indexing.cpp +0 -0
  348. /data/{mlx → submodules/mlx}/mlx/backend/metal/jit/includes.h +0 -0
  349. /data/{mlx → submodules/mlx}/mlx/backend/metal/jit/indexing.h +0 -0
  350. /data/{mlx → submodules/mlx}/mlx/backend/metal/jit_kernels.cpp +0 -0
  351. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/CMakeLists.txt +0 -0
  352. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/arange.h +0 -0
  353. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/arange.metal +0 -0
  354. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/arg_reduce.metal +0 -0
  355. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/atomic.h +0 -0
  356. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/bf16.h +0 -0
  357. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/bf16_math.h +0 -0
  358. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/binary.h +0 -0
  359. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/binary.metal +0 -0
  360. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/binary_ops.h +0 -0
  361. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/binary_two.h +0 -0
  362. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/binary_two.metal +0 -0
  363. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/cexpf.h +0 -0
  364. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/complex.h +0 -0
  365. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/copy.h +0 -0
  366. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/copy.metal +0 -0
  367. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/defines.h +0 -0
  368. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/erf.h +0 -0
  369. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/expm1f.h +0 -0
  370. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fence.metal +0 -0
  371. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fft/radix.h +0 -0
  372. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fft/readwrite.h +0 -0
  373. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fft.h +0 -0
  374. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fft.metal +0 -0
  375. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fp4.h +0 -0
  376. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fp8.h +0 -0
  377. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fp_quantized.h +0 -0
  378. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fp_quantized.metal +0 -0
  379. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fp_quantized_nax.h +0 -0
  380. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fp_quantized_nax.metal +0 -0
  381. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/gemv.metal +0 -0
  382. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/gemv_masked.h +0 -0
  383. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/gemv_masked.metal +0 -0
  384. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/hadamard.h +0 -0
  385. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/indexing/gather.h +0 -0
  386. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/indexing/gather_axis.h +0 -0
  387. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/indexing/gather_front.h +0 -0
  388. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/indexing/indexing.h +0 -0
  389. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/indexing/masked_scatter.h +0 -0
  390. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/indexing/scatter.h +0 -0
  391. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/indexing/scatter_axis.h +0 -0
  392. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/layer_norm.metal +0 -0
  393. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/logging.h +0 -0
  394. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/logsumexp.h +0 -0
  395. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/logsumexp.metal +0 -0
  396. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/quantized.h +0 -0
  397. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/quantized.metal +0 -0
  398. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/quantized_nax.h +0 -0
  399. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/quantized_nax.metal +0 -0
  400. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/quantized_utils.h +0 -0
  401. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/random.metal +0 -0
  402. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduce.h +0 -0
  403. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduce.metal +0 -0
  404. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduce_utils.h +0 -0
  405. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduction/ops.h +0 -0
  406. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduction/reduce_all.h +0 -0
  407. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduction/reduce_col.h +0 -0
  408. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduction/reduce_init.h +0 -0
  409. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduction/reduce_row.h +0 -0
  410. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/rms_norm.metal +0 -0
  411. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/rope.metal +0 -0
  412. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/scaled_dot_product_attention.metal +0 -0
  413. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/scan.h +0 -0
  414. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/scan.metal +0 -0
  415. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/sdpa_vector.h +0 -0
  416. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/softmax.h +0 -0
  417. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/softmax.metal +0 -0
  418. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/sort.h +0 -0
  419. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/sort.metal +0 -0
  420. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/attn.h +0 -0
  421. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +0 -0
  422. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal +0 -0
  423. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +0 -0
  424. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.metal +0 -0
  425. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/loader.h +0 -0
  426. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/mma.h +0 -0
  427. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/nax.h +0 -0
  428. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/params.h +0 -0
  429. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/transforms.h +0 -0
  430. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/conv.h +0 -0
  431. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +0 -0
  432. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.metal +0 -0
  433. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +0 -0
  434. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.metal +0 -0
  435. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/loader.h +0 -0
  436. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +0 -0
  437. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +0 -0
  438. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +0 -0
  439. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/params.h +0 -0
  440. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/defines.h +0 -0
  441. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/gemm.h +0 -0
  442. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +0 -0
  443. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +0 -0
  444. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.metal +0 -0
  445. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +0 -0
  446. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.metal +0 -0
  447. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +0 -0
  448. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.metal +0 -0
  449. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +0 -0
  450. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.metal +0 -0
  451. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +0 -0
  452. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.metal +0 -0
  453. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +0 -0
  454. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.metal +0 -0
  455. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +0 -0
  456. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.metal +0 -0
  457. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.h +0 -0
  458. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.metal +0 -0
  459. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/loader.h +0 -0
  460. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/mma.h +0 -0
  461. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/nax.h +0 -0
  462. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/params.h +0 -0
  463. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/transforms.h +0 -0
  464. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/utils/integral_constant.h +0 -0
  465. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/utils/type_traits.h +0 -0
  466. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/utils.h +0 -0
  467. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/ternary.h +0 -0
  468. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/ternary.metal +0 -0
  469. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/ternary_ops.h +0 -0
  470. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/unary.h +0 -0
  471. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/unary.metal +0 -0
  472. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/unary_ops.h +0 -0
  473. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/utils.h +0 -0
  474. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels.h +0 -0
  475. /data/{mlx → submodules/mlx}/mlx/backend/metal/logsumexp.cpp +0 -0
  476. /data/{mlx → submodules/mlx}/mlx/backend/metal/make_compiled_preamble.sh +0 -0
  477. /data/{mlx → submodules/mlx}/mlx/backend/metal/matmul.cpp +0 -0
  478. /data/{mlx → submodules/mlx}/mlx/backend/metal/matmul.h +0 -0
  479. /data/{mlx → submodules/mlx}/mlx/backend/metal/metal.cpp +0 -0
  480. /data/{mlx → submodules/mlx}/mlx/backend/metal/metal.h +0 -0
  481. /data/{mlx → submodules/mlx}/mlx/backend/metal/no_metal.cpp +0 -0
  482. /data/{mlx → submodules/mlx}/mlx/backend/metal/nojit_kernels.cpp +0 -0
  483. /data/{mlx → submodules/mlx}/mlx/backend/metal/normalization.cpp +0 -0
  484. /data/{mlx → submodules/mlx}/mlx/backend/metal/primitives.cpp +0 -0
  485. /data/{mlx → submodules/mlx}/mlx/backend/metal/quantized.cpp +0 -0
  486. /data/{mlx → submodules/mlx}/mlx/backend/metal/reduce.cpp +0 -0
  487. /data/{mlx → submodules/mlx}/mlx/backend/metal/reduce.h +0 -0
  488. /data/{mlx → submodules/mlx}/mlx/backend/metal/resident.cpp +0 -0
  489. /data/{mlx → submodules/mlx}/mlx/backend/metal/resident.h +0 -0
  490. /data/{mlx → submodules/mlx}/mlx/backend/metal/rope.cpp +0 -0
  491. /data/{mlx → submodules/mlx}/mlx/backend/metal/scaled_dot_product_attention.cpp +0 -0
  492. /data/{mlx → submodules/mlx}/mlx/backend/metal/scan.cpp +0 -0
  493. /data/{mlx → submodules/mlx}/mlx/backend/metal/scan.h +0 -0
  494. /data/{mlx → submodules/mlx}/mlx/backend/metal/slicing.cpp +0 -0
  495. /data/{mlx → submodules/mlx}/mlx/backend/metal/softmax.cpp +0 -0
  496. /data/{mlx → submodules/mlx}/mlx/backend/metal/sort.cpp +0 -0
  497. /data/{mlx → submodules/mlx}/mlx/backend/metal/ternary.cpp +0 -0
  498. /data/{mlx → submodules/mlx}/mlx/backend/metal/ternary.h +0 -0
  499. /data/{mlx → submodules/mlx}/mlx/backend/metal/unary.cpp +0 -0
  500. /data/{mlx → submodules/mlx}/mlx/backend/metal/unary.h +0 -0
  501. /data/{mlx → submodules/mlx}/mlx/backend/metal/utils.cpp +0 -0
  502. /data/{mlx → submodules/mlx}/mlx/backend/metal/utils.h +0 -0
  503. /data/{mlx → submodules/mlx}/mlx/backend/no_cpu/CMakeLists.txt +0 -0
  504. /data/{mlx → submodules/mlx}/mlx/backend/no_cpu/compiled.cpp +0 -0
  505. /data/{mlx → submodules/mlx}/mlx/backend/no_cpu/device_info.cpp +0 -0
  506. /data/{mlx → submodules/mlx}/mlx/backend/no_cpu/primitives.cpp +0 -0
  507. /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/CMakeLists.txt +0 -0
  508. /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/allocator.cpp +0 -0
  509. /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/apple_memory.h +0 -0
  510. /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/device_info.cpp +0 -0
  511. /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/eval.cpp +0 -0
  512. /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/event.cpp +0 -0
  513. /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/fence.cpp +0 -0
  514. /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/linux_memory.h +0 -0
  515. /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/primitives.cpp +0 -0
  516. /data/{mlx → submodules/mlx}/mlx/compile.cpp +0 -0
  517. /data/{mlx → submodules/mlx}/mlx/compile.h +0 -0
  518. /data/{mlx → submodules/mlx}/mlx/compile_impl.h +0 -0
  519. /data/{mlx → submodules/mlx}/mlx/device.cpp +0 -0
  520. /data/{mlx → submodules/mlx}/mlx/device.h +0 -0
  521. /data/{mlx → submodules/mlx}/mlx/distributed/CMakeLists.txt +0 -0
  522. /data/{mlx → submodules/mlx}/mlx/distributed/distributed.cpp +0 -0
  523. /data/{mlx → submodules/mlx}/mlx/distributed/distributed.h +0 -0
  524. /data/{mlx → submodules/mlx}/mlx/distributed/distributed_impl.h +0 -0
  525. /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/CMakeLists.txt +0 -0
  526. /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/jaccl.cpp +0 -0
  527. /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/jaccl.h +0 -0
  528. /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/mesh.cpp +0 -0
  529. /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/mesh.h +0 -0
  530. /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/no_jaccl.cpp +0 -0
  531. /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/ring.cpp +0 -0
  532. /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/ring.h +0 -0
  533. /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/utils.cpp +0 -0
  534. /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/utils.h +0 -0
  535. /data/{mlx → submodules/mlx}/mlx/distributed/mpi/CMakeLists.txt +0 -0
  536. /data/{mlx → submodules/mlx}/mlx/distributed/mpi/mpi.cpp +0 -0
  537. /data/{mlx → submodules/mlx}/mlx/distributed/mpi/mpi.h +0 -0
  538. /data/{mlx → submodules/mlx}/mlx/distributed/mpi/mpi_declarations.h +0 -0
  539. /data/{mlx → submodules/mlx}/mlx/distributed/mpi/no_mpi.cpp +0 -0
  540. /data/{mlx → submodules/mlx}/mlx/distributed/nccl/CMakeLists.txt +0 -0
  541. /data/{mlx → submodules/mlx}/mlx/distributed/nccl/nccl.cpp +0 -0
  542. /data/{mlx → submodules/mlx}/mlx/distributed/nccl/nccl.h +0 -0
  543. /data/{mlx → submodules/mlx}/mlx/distributed/nccl/nccl_stub/CMakeLists.txt +0 -0
  544. /data/{mlx → submodules/mlx}/mlx/distributed/nccl/nccl_stub/nccl_stubs.cpp +0 -0
  545. /data/{mlx → submodules/mlx}/mlx/distributed/nccl/no_nccl.cpp +0 -0
  546. /data/{mlx → submodules/mlx}/mlx/distributed/ops.cpp +0 -0
  547. /data/{mlx → submodules/mlx}/mlx/distributed/ops.h +0 -0
  548. /data/{mlx → submodules/mlx}/mlx/distributed/primitives.cpp +0 -0
  549. /data/{mlx → submodules/mlx}/mlx/distributed/primitives.h +0 -0
  550. /data/{mlx → submodules/mlx}/mlx/distributed/reduction_ops.h +0 -0
  551. /data/{mlx → submodules/mlx}/mlx/distributed/ring/CMakeLists.txt +0 -0
  552. /data/{mlx → submodules/mlx}/mlx/distributed/ring/no_ring.cpp +0 -0
  553. /data/{mlx → submodules/mlx}/mlx/distributed/ring/ring.cpp +0 -0
  554. /data/{mlx → submodules/mlx}/mlx/distributed/ring/ring.h +0 -0
  555. /data/{mlx → submodules/mlx}/mlx/distributed/utils.cpp +0 -0
  556. /data/{mlx → submodules/mlx}/mlx/distributed/utils.h +0 -0
  557. /data/{mlx → submodules/mlx}/mlx/dtype.cpp +0 -0
  558. /data/{mlx → submodules/mlx}/mlx/dtype.h +0 -0
  559. /data/{mlx → submodules/mlx}/mlx/dtype_utils.cpp +0 -0
  560. /data/{mlx → submodules/mlx}/mlx/dtype_utils.h +0 -0
  561. /data/{mlx → submodules/mlx}/mlx/einsum.cpp +0 -0
  562. /data/{mlx → submodules/mlx}/mlx/einsum.h +0 -0
  563. /data/{mlx → submodules/mlx}/mlx/event.h +0 -0
  564. /data/{mlx → submodules/mlx}/mlx/export.h +0 -0
  565. /data/{mlx → submodules/mlx}/mlx/export_impl.h +0 -0
  566. /data/{mlx → submodules/mlx}/mlx/fast.cpp +0 -0
  567. /data/{mlx → submodules/mlx}/mlx/fast.h +0 -0
  568. /data/{mlx → submodules/mlx}/mlx/fast_primitives.h +0 -0
  569. /data/{mlx → submodules/mlx}/mlx/fence.h +0 -0
  570. /data/{mlx → submodules/mlx}/mlx/fft.cpp +0 -0
  571. /data/{mlx → submodules/mlx}/mlx/fft.h +0 -0
  572. /data/{mlx → submodules/mlx}/mlx/graph_utils.cpp +0 -0
  573. /data/{mlx → submodules/mlx}/mlx/graph_utils.h +0 -0
  574. /data/{mlx → submodules/mlx}/mlx/io/CMakeLists.txt +0 -0
  575. /data/{mlx → submodules/mlx}/mlx/io/gguf.cpp +0 -0
  576. /data/{mlx → submodules/mlx}/mlx/io/gguf.h +0 -0
  577. /data/{mlx → submodules/mlx}/mlx/io/gguf_quants.cpp +0 -0
  578. /data/{mlx → submodules/mlx}/mlx/io/load.cpp +0 -0
  579. /data/{mlx → submodules/mlx}/mlx/io/load.h +0 -0
  580. /data/{mlx → submodules/mlx}/mlx/io/no_gguf.cpp +0 -0
  581. /data/{mlx → submodules/mlx}/mlx/io/no_safetensors.cpp +0 -0
  582. /data/{mlx → submodules/mlx}/mlx/io/safetensors.cpp +0 -0
  583. /data/{mlx → submodules/mlx}/mlx/io.h +0 -0
  584. /data/{mlx → submodules/mlx}/mlx/linalg.cpp +0 -0
  585. /data/{mlx → submodules/mlx}/mlx/linalg.h +0 -0
  586. /data/{mlx → submodules/mlx}/mlx/memory.h +0 -0
  587. /data/{mlx → submodules/mlx}/mlx/mlx.h +0 -0
  588. /data/{mlx → submodules/mlx}/mlx/primitives.h +0 -0
  589. /data/{mlx → submodules/mlx}/mlx/random.cpp +0 -0
  590. /data/{mlx → submodules/mlx}/mlx/random.h +0 -0
  591. /data/{mlx → submodules/mlx}/mlx/small_vector.h +0 -0
  592. /data/{mlx → submodules/mlx}/mlx/threadpool.h +0 -0
  593. /data/{mlx → submodules/mlx}/mlx/transforms.cpp +0 -0
  594. /data/{mlx → submodules/mlx}/mlx/transforms.h +0 -0
  595. /data/{mlx → submodules/mlx}/mlx/transforms_impl.h +0 -0
  596. /data/{mlx → submodules/mlx}/mlx/types/bf16.h +0 -0
  597. /data/{mlx → submodules/mlx}/mlx/types/complex.h +0 -0
  598. /data/{mlx → submodules/mlx}/mlx/types/fp16.h +0 -0
  599. /data/{mlx → submodules/mlx}/mlx/types/half_types.h +0 -0
  600. /data/{mlx → submodules/mlx}/mlx/types/limits.h +0 -0
  601. /data/{mlx → submodules/mlx}/mlx/utils.cpp +0 -0
  602. /data/{mlx → submodules/mlx}/mlx/utils.h +0 -0
  603. /data/{mlx → submodules/mlx}/mlx/version.cpp +0 -0
  604. /data/{mlx → submodules/mlx}/mlx/version.h +0 -0
  605. /data/{mlx → submodules/mlx}/mlx.pc.in +0 -0
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 215a912d2353fd5edaa60e320a5b857aa13009e55f3a190b21b2ffe5735f37af
4
- data.tar.gz: 1c9b4279f8077e3cd067354ea692e1b24248afe7262d88d4569c566c52f5a158
3
+ metadata.gz: d8190b13b012fe0693ef46cc3f20b01b78f6d13cde44674fd01999434e56eda9
4
+ data.tar.gz: '059a7b993b17cae7bd448567ef6a9058d0a575cdc143a3784b411953e1309e7d'
5
5
  SHA512:
6
- metadata.gz: 66abcbd58ccfc04186df11b0d2b6445c7d1e0ab4a36451742755f6fcf41022363403536c3d27640935364c58231c4c4a03e39ceba97617a59f2ad69acf23dc16
7
- data.tar.gz: ba7ad07ccd31e94bdf3fdee73117f069c6ee22c11cdfc3f2470eedee9ee0bc9e976970bab21f156c003b658723a0e3fb2f63f62d0e11d33de3a61fc1d9121711
6
+ metadata.gz: 1a90a276ac6b68760bdbe348931e6cfc48ec01662a7870522cf344748a8e5df167f3d4f6b9d28bd768abc3c5dc9a5e020abe2a5fc81edf1ef4056e54abc269b7
7
+ data.tar.gz: 8b17e8c878f32659cc90b365943cf7fce132c03fd7680959c03b615b28819d6a002e6e05236209b0282fdec7b03f06872556f7ac292ac85e00d531fa0bd1c11a
data/ext/mlx/extconf.rb CHANGED
@@ -3,6 +3,9 @@
3
3
  require "etc"
4
4
  require "fileutils"
5
5
  require "mkmf"
6
+ require "open3"
7
+ require "pathname"
8
+ require "rbconfig"
6
9
 
7
10
  def run_or_abort(*cmd, chdir:)
8
11
  puts ">> #{cmd.join(' ')}"
@@ -17,6 +20,76 @@ def run_with_status(*cmd, chdir:)
17
20
  system(*cmd, chdir: chdir)
18
21
  end
19
22
 
23
+ def git_revision(path)
24
+ return nil unless Dir.exist?(path)
25
+
26
+ stdout, status = Open3.capture2("git", "-C", path, "rev-parse", "HEAD")
27
+ return nil unless status.success?
28
+
29
+ rev = stdout.strip
30
+ return rev if rev.match?(/\A[0-9a-f]{40}\z/)
31
+
32
+ nil
33
+ end
34
+
35
+ def mlx_revision_pinned_by_mlx_onnx(mlx_onnx_root)
36
+ return nil unless Dir.exist?(mlx_onnx_root)
37
+
38
+ stdout, status = Open3.capture2("git", "-C", mlx_onnx_root, "submodule", "status", "--", "mlx")
39
+ return nil unless status.success?
40
+
41
+ line = stdout.lines.first.to_s.strip
42
+ return nil if line.empty?
43
+
44
+ token = line.split(/\s+/).first.to_s
45
+ token = token.delete_prefix("-").delete_prefix("+").delete_prefix("U")
46
+ return token if token.match?(/\A[0-9a-f]{40}\z/)
47
+
48
+ nil
49
+ end
50
+
51
+ def enforce_mlx_onnx_compatibility!(mlx_root:, mlx_onnx_root:)
52
+ workspace_mlx_revision =
53
+ ENV.fetch("MLX_EXTCONF_TEST_MLX_REVISION", "").strip
54
+ pinned_mlx_revision =
55
+ ENV.fetch("MLX_EXTCONF_TEST_MLX_ONNX_PINNED_MLX_REVISION", "").strip
56
+
57
+ workspace_mlx_revision = git_revision(mlx_root) if workspace_mlx_revision.empty?
58
+ pinned_mlx_revision = mlx_revision_pinned_by_mlx_onnx(mlx_onnx_root) if pinned_mlx_revision.empty?
59
+
60
+ return if workspace_mlx_revision.nil? || pinned_mlx_revision.nil?
61
+ return if workspace_mlx_revision == pinned_mlx_revision
62
+
63
+ abort(<<~MSG)
64
+ mlx/mlx-onnx revision mismatch detected.
65
+ workspace mlx revision: #{workspace_mlx_revision}
66
+ mlx-onnx pinned mlx revision: #{pinned_mlx_revision}
67
+ Run:
68
+ git submodule update --init --recursive submodules/mlx submodules/mlx-onnx
69
+ MSG
70
+ end
71
+
72
+ def patch_mlx_onnx_gcc_optional_shape_initlist!(mlx_onnx_root)
73
+ lowering_cpp = File.join(mlx_onnx_root, "src", "lowering.cpp")
74
+ return unless File.file?(lowering_cpp)
75
+
76
+ source = File.read(lowering_cpp)
77
+ patched = source
78
+ .gsub(
79
+ "std::optional<Shape>({work_shape[0], 1, seq_len})",
80
+ "std::optional<Shape>(Shape{work_shape[0], 1, seq_len})"
81
+ )
82
+ .gsub(
83
+ "std::optional<Shape>({seq_len})",
84
+ "std::optional<Shape>(Shape{seq_len})"
85
+ )
86
+
87
+ return if patched == source
88
+
89
+ File.write(lowering_cpp, patched)
90
+ puts "patched mlx-onnx lowering.cpp optional Shape initlists for GCC compatibility"
91
+ end
92
+
20
93
  def rpath_flag(path)
21
94
  case RUBY_PLATFORM
22
95
  when /darwin/
@@ -28,15 +101,120 @@ def rpath_flag(path)
28
101
  end
29
102
  end
30
103
 
104
+ def cmake_compilers_from_cache(cache_path)
105
+ return [nil, nil] unless File.file?(cache_path)
106
+
107
+ cc = nil
108
+ cxx = nil
109
+ File.foreach(cache_path) do |line|
110
+ cc = Regexp.last_match(1).strip if line =~ /^CMAKE_C_COMPILER:(?:STRING|FILEPATH)=(.+)$/
111
+ cxx = Regexp.last_match(1).strip if line =~ /^CMAKE_CXX_COMPILER:(?:STRING|FILEPATH)=(.+)$/
112
+ end
113
+ [cc, cxx]
114
+ end
115
+
116
+ def normalize_compiler_for_mkmf(path, kind)
117
+ return path if path.nil? || path.empty?
118
+
119
+ case kind
120
+ when :cc
121
+ return "/usr/bin/clang" if path.end_with?("/usr/bin/cc")
122
+ when :cxx
123
+ return "/usr/bin/clang++" if path.end_with?("/usr/bin/c++")
124
+ end
125
+ path
126
+ end
127
+
128
+ def force_mkmf_compilers!(cc, cxx)
129
+ return if cc.nil? || cc.empty? || cxx.nil? || cxx.empty?
130
+
131
+ cc = normalize_compiler_for_mkmf(cc, :cc)
132
+ cxx = normalize_compiler_for_mkmf(cxx, :cxx)
133
+
134
+ [RbConfig::CONFIG, RbConfig::MAKEFILE_CONFIG].each do |cfg|
135
+ cfg["CC"] = cc if cfg.key?("CC")
136
+ cfg["CXX"] = cxx if cfg.key?("CXX")
137
+
138
+ if cfg.key?("LDSHARED") && cfg["LDSHARED"]
139
+ cfg["LDSHARED"] = cfg["LDSHARED"].sub(/\A\S+/, cc)
140
+ end
141
+ if cfg.key?("LDSHAREDXX") && cfg["LDSHAREDXX"]
142
+ cfg["LDSHAREDXX"] = cfg["LDSHAREDXX"].sub(/\A\S+/, cxx)
143
+ end
144
+ end
145
+
146
+ $CC = cc if defined?($CC)
147
+ $CXX = cxx if defined?($CXX)
148
+ $LDSHARED = $LDSHARED.sub(/\A\S+/, cc) if defined?($LDSHARED) && $LDSHARED
149
+ $LDSHAREDXX = $LDSHAREDXX.sub(/\A\S+/, cxx) if defined?($LDSHAREDXX) && $LDSHAREDXX
150
+ end
151
+
152
+ def patch_makefile_compilers!(makefile_path, cc, cxx)
153
+ return if cc.nil? || cc.empty? || cxx.nil? || cxx.empty?
154
+ return unless File.file?(makefile_path)
155
+
156
+ cc = normalize_compiler_for_mkmf(cc, :cc)
157
+ cxx = normalize_compiler_for_mkmf(cxx, :cxx)
158
+
159
+ text = File.read(makefile_path)
160
+ text = text.gsub(/^CC = .+$/, "CC = #{cc}")
161
+ text = text.gsub(/^CXX = .+$/, "CXX = #{cxx}")
162
+ File.write(makefile_path, text)
163
+ end
164
+
165
+ def patch_makefile_sources!(makefile_path, source_files, header_files)
166
+ return unless File.file?(makefile_path)
167
+
168
+ objects = source_files.map do |path|
169
+ stem = path.sub(/\.[^.]+\z/, "")
170
+ "#{stem}.o"
171
+ end
172
+ text = File.read(makefile_path)
173
+ text = text.gsub(/^ORIG_SRCS = .+$/, "ORIG_SRCS = #{source_files.join(' ')}")
174
+ text = text.gsub(/^OBJS = .+$/, "OBJS = #{objects.join(' ')}")
175
+ text = text.gsub(/^HDRS = .+$/, "HDRS = #{header_files.map { |path| '$(srcdir)/' + path }.join(' ')}")
176
+ File.write(makefile_path, text)
177
+ end
178
+
179
+ def patch_makefile_include_dirs!(makefile_path, include_dirs)
180
+ return unless File.file?(makefile_path)
181
+
182
+ text = File.read(makefile_path)
183
+ match = text.match(/^CPPFLAGS = (.+)$/)
184
+ return if match.nil?
185
+
186
+ cppflags = match[1]
187
+ include_dirs.each do |path|
188
+ flag = "-I#{path}"
189
+ cppflags = "#{cppflags} #{flag}" unless cppflags.include?(flag)
190
+ end
191
+ text = text.sub(/^CPPFLAGS = .+$/, "CPPFLAGS = #{cppflags}")
192
+ File.write(makefile_path, text)
193
+ end
194
+
31
195
  repo_root = File.expand_path("../..", __dir__)
32
- mlx_root = File.join(repo_root, "mlx")
33
- mlx_include_dir = mlx_root
196
+ submodules_root = File.join(repo_root, "submodules")
197
+ mlx_root = File.join(submodules_root, "mlx")
198
+ mlx_onnx_root = File.join(submodules_root, "mlx-onnx")
199
+ mlx_public_include_dir = mlx_root
200
+ mlx_internal_include_dir = File.join(mlx_root, "mlx")
201
+ mlx_onnx_public_include_dir = File.join(mlx_onnx_root, "include")
202
+ mlx_onnx_internal_include_dir = File.join(mlx_onnx_root, "src")
34
203
  ext_root = File.expand_path(__dir__)
204
+ ext_onnx_root = File.join(repo_root, "ext", "mlx-onnx")
35
205
  build_root = File.join(ext_root, "build")
36
206
  mlx_build_dir = File.join(build_root, "mlx")
207
+ mlx_onnx_build_dir = File.join(build_root, "mlx-onnx")
37
208
  mlx_install_dir = File.join(build_root, "install")
38
209
  jobs = [Etc.nprocessors, 1].max
39
210
 
211
+ enforce_mlx_onnx_compatibility!(mlx_root: mlx_root, mlx_onnx_root: mlx_onnx_root)
212
+ patch_mlx_onnx_gcc_optional_shape_initlist!(mlx_onnx_root)
213
+ if ENV["MLX_EXTCONF_VALIDATE_ONLY"] == "1"
214
+ puts "mlx-onnx compatibility check passed"
215
+ exit 0
216
+ end
217
+
40
218
  FileUtils.mkdir_p(mlx_build_dir)
41
219
 
42
220
  cmake_configure = [
@@ -54,7 +232,7 @@ cmake_configure = [
54
232
  "-DMLX_BUILD_PYTHON_STUBS=OFF",
55
233
  "-DMLX_BUILD_METAL=ON",
56
234
  "-DMLX_BUILD_GGUF=OFF",
57
- "-DMLX_BUILD_SAFETENSORS=OFF",
235
+ "-DMLX_BUILD_SAFETENSORS=ON",
58
236
  "-DBUILD_SHARED_LIBS=ON"
59
237
  ]
60
238
 
@@ -78,17 +256,98 @@ unless configured
78
256
  end
79
257
  run_or_abort(*cmake_build, chdir: ext_root)
80
258
 
81
- include_dir = mlx_include_dir
259
+ cmake_cache_path = File.join(mlx_build_dir, "CMakeCache.txt")
260
+ cmake_cc, cmake_cxx = cmake_compilers_from_cache(cmake_cache_path)
261
+ force_mkmf_compilers!(cmake_cc, cmake_cxx)
262
+
263
+ abort("missing MLX include dir: #{mlx_public_include_dir}") unless Dir.exist?(mlx_public_include_dir)
264
+ abort("missing MLX internal include dir: #{mlx_internal_include_dir}") unless Dir.exist?(mlx_internal_include_dir)
265
+ abort("missing mlx-onnx include dir: #{mlx_onnx_public_include_dir}") unless Dir.exist?(mlx_onnx_public_include_dir)
266
+ abort("missing mlx-onnx internal include dir: #{mlx_onnx_internal_include_dir}") unless Dir.exist?(mlx_onnx_internal_include_dir)
267
+ include_dirs = [
268
+ mlx_public_include_dir,
269
+ mlx_internal_include_dir,
270
+ mlx_onnx_public_include_dir,
271
+ mlx_onnx_internal_include_dir
272
+ ]
82
273
  lib_dir = File.join(mlx_install_dir, "lib")
83
274
 
84
- abort("missing MLX include dir: #{include_dir}") unless Dir.exist?(include_dir)
275
+ cmake_onnx_configure = [
276
+ "cmake",
277
+ "-S",
278
+ mlx_onnx_root,
279
+ "-B",
280
+ mlx_onnx_build_dir,
281
+ "-DCMAKE_BUILD_TYPE=Release",
282
+ "-DCMAKE_INSTALL_PREFIX=#{mlx_install_dir}",
283
+ "-DMLX_ONNX_USE_EXTERNAL_MLX=ON",
284
+ "-DMLX_ONNX_EXTERNAL_MLX_INCLUDE_DIR=#{mlx_public_include_dir}",
285
+ "-DMLX_ONNX_EXTERNAL_MLX_LIB_DIR=#{lib_dir}",
286
+ "-DMLX_ONNX_BUILD_PYTHON_BINDINGS=OFF"
287
+ ]
288
+
289
+ cmake_onnx_build = [
290
+ "cmake",
291
+ "--build",
292
+ mlx_onnx_build_dir,
293
+ "--target",
294
+ "install",
295
+ "--config",
296
+ "Release",
297
+ "-j#{jobs}"
298
+ ]
299
+
300
+ run_or_abort(*cmake_onnx_configure, chdir: ext_root)
301
+ run_or_abort(*cmake_onnx_build, chdir: ext_root)
302
+
303
+ json_include_dirs = [
304
+ File.join(mlx_build_dir, "_deps", "json-src", "include"),
305
+ File.join(mlx_build_dir, "_deps", "json-src", "single_include"),
306
+ File.join(mlx_build_dir, "_deps", "json-src", "single_include", "nlohmann"),
307
+ File.join(mlx_onnx_build_dir, "_deps", "nlohmann_json-src", "include")
308
+ ].select { |path| Dir.exist?(path) }
309
+
85
310
  abort("missing MLX lib dir: #{lib_dir}") unless Dir.exist?(lib_dir)
86
311
 
87
- dir_config("mlx", include_dir, lib_dir)
312
+ dir_config("mlx", include_dirs.first, lib_dir)
88
313
 
89
314
  $CXXFLAGS = "#{$CXXFLAGS} -std=c++20"
90
- $CPPFLAGS = "#{$CPPFLAGS} -I#{include_dir}"
315
+ include_dirs.each do |path|
316
+ $CPPFLAGS = "#{$CPPFLAGS} -I#{path}"
317
+ end
318
+ json_include_dirs.each do |path|
319
+ $CPPFLAGS = "#{$CPPFLAGS} -I#{path}"
320
+ end
91
321
  $LDFLAGS = "#{$LDFLAGS} -L#{lib_dir} #{rpath_flag(lib_dir)}"
92
- $libs = "-lmlx #{$libs}"
322
+ $libs = "-lmlx_onnx -lmlx #{$libs}"
323
+
324
+ source_files = (
325
+ Dir.glob(File.join(ext_root, "**", "*.{c,cc,cpp,cxx}")) +
326
+ Dir.glob(File.join(ext_onnx_root, "**", "*.{c,cc,cpp,cxx}"))
327
+ ).reject { |path| path.start_with?(File.join(build_root, "")) }
328
+ .map { |path| Pathname(path).relative_path_from(Pathname(ext_root)).to_s }
329
+ .sort
330
+ header_files = (
331
+ Dir.glob(File.join(ext_root, "**", "*.{h,hpp,hh}")) +
332
+ Dir.glob(File.join(ext_onnx_root, "**", "*.{h,hpp,hh}"))
333
+ ).reject { |path| path.start_with?(File.join(build_root, "")) }
334
+ .map { |path| Pathname(path).relative_path_from(Pathname(ext_root)).to_s }
335
+ .sort
336
+ $srcs = source_files
337
+ $objs = source_files.map { |path| "#{path.sub(/\.[^.]+\z/, "")}.o" }
338
+ $hdrs = header_files
93
339
 
94
340
  create_makefile("mlx/native")
341
+ makefile_path = File.join(ext_root, "Makefile")
342
+ patch_makefile_compilers!(makefile_path, cmake_cc, cmake_cxx)
343
+ patch_makefile_sources!(makefile_path, source_files, header_files)
344
+ patch_makefile_include_dirs!(
345
+ makefile_path,
346
+ [
347
+ mlx_internal_include_dir,
348
+ mlx_onnx_public_include_dir,
349
+ mlx_onnx_internal_include_dir,
350
+ File.join(mlx_build_dir, "_deps", "json-src", "single_include", "nlohmann"),
351
+ File.join(mlx_onnx_build_dir, "_deps", "nlohmann_json-src", "include")
352
+ ]
353
+ )
data/ext/mlx/native.cpp CHANGED
@@ -10,12 +10,15 @@
10
10
  #include <limits>
11
11
  #include <optional>
12
12
  #include <sstream>
13
+ #include <stdexcept>
13
14
  #include <string>
14
15
  #include <type_traits>
15
16
  #include <unordered_map>
16
17
  #include <variant>
17
18
  #include <vector>
18
19
 
20
+ #include <nlohmann/json.hpp>
21
+
19
22
  #include "mlx/array.h"
20
23
  #include "mlx/backend/metal/metal.h"
21
24
  #include "mlx/compile.h"
@@ -37,6 +40,7 @@
37
40
  #include "mlx/transforms.h"
38
41
  #include "mlx/utils.h"
39
42
  #include "mlx/version.h"
43
+ #include "../mlx-onnx/native.hpp"
40
44
 
41
45
  namespace mx = mlx::core;
42
46
  namespace mxfft = mlx::core::fft;
@@ -44,6 +48,7 @@ namespace mxfast = mlx::core::fast;
44
48
  namespace mxlinalg = mlx::core::linalg;
45
49
  namespace mxmetal = mlx::core::metal;
46
50
  namespace mxdist = mlx::core::distributed;
51
+ using OrderedJson = nlohmann::ordered_json;
47
52
 
48
53
  static VALUE mMLX;
49
54
  static VALUE mNative;
@@ -339,14 +344,7 @@ static VALUE id_to_symbol(ID id) {
339
344
  }
340
345
 
341
346
  static ID cached_intern_id(const char* name) {
342
- static std::unordered_map<std::string, ID> cache;
343
- auto it = cache.find(name);
344
- if (it != cache.end()) {
345
- return it->second;
346
- }
347
- const ID id = rb_intern(name);
348
- cache.emplace(name, id);
349
- return id;
347
+ return rb_intern(name);
350
348
  }
351
349
 
352
350
  static mx::Device::DeviceType device_type_from_value(VALUE value) {
@@ -928,12 +926,6 @@ struct ArrayCollector {
928
926
 
929
927
  static void collect_arrays_from_tree(VALUE value, std::vector<mx::array>& arrays);
930
928
 
931
- static int hash_collect_arrays_iter(VALUE, VALUE value, VALUE arg) {
932
- auto* collector = reinterpret_cast<ArrayCollector*>(arg);
933
- collect_arrays_from_tree(value, *collector->arrays);
934
- return ST_CONTINUE;
935
- }
936
-
937
929
  static void collect_arrays_from_tree(VALUE value, std::vector<mx::array>& arrays) {
938
930
  if (rb_obj_is_kind_of(value, cArray)) {
939
931
  arrays.push_back(array_unwrap(value));
@@ -947,28 +939,37 @@ static void collect_arrays_from_tree(VALUE value, std::vector<mx::array>& arrays
947
939
  return;
948
940
  }
949
941
  if (RB_TYPE_P(value, T_HASH)) {
950
- ArrayCollector collector{&arrays};
951
- rb_hash_foreach(value, hash_collect_arrays_iter, reinterpret_cast<VALUE>(&collector));
942
+ VALUE keys = rb_funcall(value, cached_intern_id("keys"), 0);
943
+ const long len = RARRAY_LEN(keys);
944
+ for (long i = 0; i < len; ++i) {
945
+ VALUE key = rb_ary_entry(keys, i);
946
+ VALUE hash_value = rb_hash_lookup2(value, key, Qundef);
947
+ if (hash_value == Qundef) {
948
+ continue;
949
+ }
950
+ collect_arrays_from_tree(hash_value, arrays);
951
+ }
952
952
  }
953
953
  }
954
954
 
955
- struct ArrayMapBuilder {
956
- std::unordered_map<std::string, mx::array> map;
957
- };
958
-
959
- static int hash_to_array_map_iter(VALUE key, VALUE value, VALUE arg) {
960
- auto* builder = reinterpret_cast<ArrayMapBuilder*>(arg);
961
- builder->map.insert_or_assign(string_from_ruby(key), array_unwrap(value));
962
- return ST_CONTINUE;
963
- }
964
-
965
955
  static std::unordered_map<std::string, mx::array> array_map_from_ruby_hash(VALUE value) {
966
956
  if (!RB_TYPE_P(value, T_HASH)) {
967
957
  rb_raise(rb_eTypeError, "expected Hash mapping String/Symbol keys to MLX::Core::Array");
968
958
  }
969
- ArrayMapBuilder builder;
970
- rb_hash_foreach(value, hash_to_array_map_iter, reinterpret_cast<VALUE>(&builder));
971
- return builder.map;
959
+ VALUE keys = rb_funcall(value, cached_intern_id("keys"), 0);
960
+ const long len = RARRAY_LEN(keys);
961
+ std::unordered_map<std::string, mx::array> out;
962
+ out.reserve(static_cast<size_t>(len));
963
+ for (long i = 0; i < len; ++i) {
964
+ VALUE key = rb_ary_entry(keys, i);
965
+ VALUE hash_value = rb_hash_lookup2(value, key, Qundef);
966
+ if (hash_value == Qundef) {
967
+ continue;
968
+ }
969
+ std::string ruby_key = string_from_ruby(key);
970
+ out.insert_or_assign(ruby_key, array_unwrap(hash_value));
971
+ }
972
+ return out;
972
973
  }
973
974
 
974
975
  static VALUE ruby_hash_of_arrays(const std::unordered_map<std::string, mx::array>& map) {
@@ -990,16 +991,6 @@ static VALUE ruby_hash_of_strings(const std::unordered_map<std::string, std::str
990
991
  return out;
991
992
  }
992
993
 
993
- struct StringMapBuilder {
994
- std::unordered_map<std::string, std::string> map;
995
- };
996
-
997
- static int hash_to_string_map_iter(VALUE key, VALUE value, VALUE arg) {
998
- auto* builder = reinterpret_cast<StringMapBuilder*>(arg);
999
- builder->map.insert_or_assign(string_from_ruby(key), string_from_ruby(value));
1000
- return ST_CONTINUE;
1001
- }
1002
-
1003
994
  static std::unordered_map<std::string, std::string> string_map_from_ruby_hash(VALUE value) {
1004
995
  if (NIL_P(value)) {
1005
996
  return {};
@@ -1007,9 +998,20 @@ static std::unordered_map<std::string, std::string> string_map_from_ruby_hash(VA
1007
998
  if (!RB_TYPE_P(value, T_HASH)) {
1008
999
  rb_raise(rb_eTypeError, "expected Hash mapping String/Symbol keys to String values");
1009
1000
  }
1010
- StringMapBuilder builder;
1011
- rb_hash_foreach(value, hash_to_string_map_iter, reinterpret_cast<VALUE>(&builder));
1012
- return builder.map;
1001
+ VALUE keys = rb_funcall(value, cached_intern_id("keys"), 0);
1002
+ const long len = RARRAY_LEN(keys);
1003
+ std::unordered_map<std::string, std::string> out;
1004
+ out.reserve(static_cast<size_t>(len));
1005
+ for (long i = 0; i < len; ++i) {
1006
+ VALUE key = rb_ary_entry(keys, i);
1007
+ VALUE hash_value = rb_hash_lookup2(value, key, Qundef);
1008
+ if (hash_value == Qundef) {
1009
+ continue;
1010
+ }
1011
+ std::string ruby_key = string_from_ruby(key);
1012
+ out.insert_or_assign(ruby_key, string_from_ruby(hash_value));
1013
+ }
1014
+ return out;
1013
1015
  }
1014
1016
 
1015
1017
  static mx::GGUFMetaData gguf_metadata_from_ruby(VALUE value) {
@@ -1038,16 +1040,6 @@ static mx::GGUFMetaData gguf_metadata_from_ruby(VALUE value) {
1038
1040
  return std::monostate{};
1039
1041
  }
1040
1042
 
1041
- struct GGUFMetaMapBuilder {
1042
- std::unordered_map<std::string, mx::GGUFMetaData> map;
1043
- };
1044
-
1045
- static int hash_to_gguf_meta_map_iter(VALUE key, VALUE value, VALUE arg) {
1046
- auto* builder = reinterpret_cast<GGUFMetaMapBuilder*>(arg);
1047
- builder->map.insert_or_assign(string_from_ruby(key), gguf_metadata_from_ruby(value));
1048
- return ST_CONTINUE;
1049
- }
1050
-
1051
1043
  static std::unordered_map<std::string, mx::GGUFMetaData> gguf_meta_map_from_ruby_hash(VALUE value) {
1052
1044
  if (NIL_P(value)) {
1053
1045
  return {};
@@ -1055,9 +1047,20 @@ static std::unordered_map<std::string, mx::GGUFMetaData> gguf_meta_map_from_ruby
1055
1047
  if (!RB_TYPE_P(value, T_HASH)) {
1056
1048
  rb_raise(rb_eTypeError, "expected Hash for GGUF metadata");
1057
1049
  }
1058
- GGUFMetaMapBuilder builder;
1059
- rb_hash_foreach(value, hash_to_gguf_meta_map_iter, reinterpret_cast<VALUE>(&builder));
1060
- return builder.map;
1050
+ VALUE keys = rb_funcall(value, cached_intern_id("keys"), 0);
1051
+ const long len = RARRAY_LEN(keys);
1052
+ std::unordered_map<std::string, mx::GGUFMetaData> out;
1053
+ out.reserve(static_cast<size_t>(len));
1054
+ for (long i = 0; i < len; ++i) {
1055
+ VALUE key = rb_ary_entry(keys, i);
1056
+ VALUE hash_value = rb_hash_lookup2(value, key, Qundef);
1057
+ if (hash_value == Qundef) {
1058
+ continue;
1059
+ }
1060
+ std::string ruby_key = string_from_ruby(key);
1061
+ out.insert_or_assign(ruby_key, gguf_metadata_from_ruby(hash_value));
1062
+ }
1063
+ return out;
1061
1064
  }
1062
1065
 
1063
1066
  static VALUE gguf_metadata_to_ruby(const mx::GGUFMetaData& value) {
@@ -1501,6 +1504,23 @@ args_kwargs_function_from_callable(VALUE callable) {
1501
1504
  };
1502
1505
  }
1503
1506
 
1507
+ mx::array onnx_array_from_ruby(VALUE value) {
1508
+ return array_from_ruby(value, std::nullopt);
1509
+ }
1510
+
1511
+ std::vector<mx::array> onnx_array_vector_from_ruby(VALUE value) {
1512
+ return array_vector_from_ruby(value);
1513
+ }
1514
+
1515
+ std::unordered_map<std::string, mx::array> onnx_array_map_from_ruby_hash(VALUE value) {
1516
+ return array_map_from_ruby_hash(value);
1517
+ }
1518
+
1519
+ std::function<std::vector<mx::array>(const mx::Args&, const mx::Kwargs&)>
1520
+ onnx_args_kwargs_function_from_callable(VALUE callable) {
1521
+ return args_kwargs_function_from_callable(callable);
1522
+ }
1523
+
1504
1524
  static std::vector<int> argnums_from_value(VALUE value) {
1505
1525
  if (NIL_P(value)) {
1506
1526
  return {0};
@@ -3495,7 +3515,7 @@ static VALUE core_dequantize(int argc, VALUE* argv, VALUE) {
3495
3515
  dtype = optional_dtype_from_value(argv[6]);
3496
3516
  }
3497
3517
 
3498
- return array_wrap(mx::dequantize(w, scales, biases, group_size, bits, mode, dtype));
3518
+ return array_wrap(mx::dequantize(w, scales, biases, group_size, bits, mode, std::nullopt, dtype));
3499
3519
  } catch (const std::exception& error) {
3500
3520
  raise_std_exception(error);
3501
3521
  return Qnil;
@@ -6459,6 +6479,30 @@ static VALUE core_identity(int argc, VALUE* argv, VALUE) {
6459
6479
  }
6460
6480
  }
6461
6481
 
6482
+ static VALUE core_hanning(int argc, VALUE* argv, VALUE) {
6483
+ try {
6484
+ VALUE m;
6485
+ VALUE stream;
6486
+ rb_scan_args(argc, argv, "11", &m, &stream);
6487
+ return array_wrap(mx::hanning(NUM2INT(m), stream_or_device_from_value(stream)));
6488
+ } catch (const std::exception& error) {
6489
+ raise_std_exception(error);
6490
+ return Qnil;
6491
+ }
6492
+ }
6493
+
6494
+ static VALUE core_hamming(int argc, VALUE* argv, VALUE) {
6495
+ try {
6496
+ VALUE m;
6497
+ VALUE stream;
6498
+ rb_scan_args(argc, argv, "11", &m, &stream);
6499
+ return array_wrap(mx::hamming(NUM2INT(m), stream_or_device_from_value(stream)));
6500
+ } catch (const std::exception& error) {
6501
+ raise_std_exception(error);
6502
+ return Qnil;
6503
+ }
6504
+ }
6505
+
6462
6506
  static VALUE core_tri(int argc, VALUE* argv, VALUE) {
6463
6507
  try {
6464
6508
  if (argc < 1 || argc > 4) {
@@ -6625,7 +6669,8 @@ static VALUE core_clear_cache(VALUE) {
6625
6669
 
6626
6670
  static VALUE core_metal_is_available(VALUE) {
6627
6671
  try {
6628
- return mxmetal::is_available() ? Qtrue : Qfalse;
6672
+ const mx::Device gpu_device(mx::Device::gpu, 0);
6673
+ return mx::is_available(gpu_device) ? Qtrue : Qfalse;
6629
6674
  } catch (const std::exception& error) {
6630
6675
  raise_std_exception(error);
6631
6676
  return Qnil;
@@ -6654,7 +6699,12 @@ static VALUE core_metal_stop_capture(VALUE) {
6654
6699
 
6655
6700
  static VALUE core_metal_device_info(VALUE) {
6656
6701
  try {
6657
- const auto& info = mxmetal::device_info();
6702
+ const mx::Device gpu_device(mx::Device::gpu, 0);
6703
+ if (!mx::is_available(gpu_device)) {
6704
+ rb_raise(rb_eRuntimeError, "[metal_device_info] Metal GPU device is not available");
6705
+ }
6706
+
6707
+ const auto& info = mx::device_info(gpu_device);
6658
6708
  VALUE hash = rb_hash_new();
6659
6709
  for (const auto& [key, value] : info) {
6660
6710
  VALUE ruby_key = rb_utf8_str_new(key.c_str(), static_cast<long>(key.size()));
@@ -7616,6 +7666,7 @@ extern "C" void Init_native(void) {
7616
7666
  rb_define_singleton_method(mNative, "loaded?", RUBY_METHOD_FUNC(native_loaded_p), 0);
7617
7667
 
7618
7668
  mCore = rb_define_module_under(mMLX, "Core");
7669
+ init_onnx_native_bindings(mMLX);
7619
7670
  rb_define_singleton_method(mCore, "version", RUBY_METHOD_FUNC(core_version), 0);
7620
7671
 
7621
7672
  rb_define_singleton_method(mCore, "get_active_memory", RUBY_METHOD_FUNC(core_get_active_memory), 0);
@@ -7893,6 +7944,8 @@ extern "C" void Init_native(void) {
7893
7944
  rb_define_singleton_method(mCore, "ones_like", RUBY_METHOD_FUNC(core_ones_like), 1);
7894
7945
  rb_define_singleton_method(mCore, "eye", RUBY_METHOD_FUNC(core_eye), -1);
7895
7946
  rb_define_singleton_method(mCore, "identity", RUBY_METHOD_FUNC(core_identity), -1);
7947
+ rb_define_singleton_method(mCore, "hanning", RUBY_METHOD_FUNC(core_hanning), -1);
7948
+ rb_define_singleton_method(mCore, "hamming", RUBY_METHOD_FUNC(core_hamming), -1);
7896
7949
  rb_define_singleton_method(mCore, "tri", RUBY_METHOD_FUNC(core_tri), -1);
7897
7950
  rb_define_singleton_method(mCore, "tril", RUBY_METHOD_FUNC(core_tril), -1);
7898
7951
  rb_define_singleton_method(mCore, "triu", RUBY_METHOD_FUNC(core_triu), -1);
@@ -8020,4 +8073,5 @@ extern "C" void Init_native(void) {
8020
8073
  "precompiled_cuda_kernel",
8021
8074
  RUBY_METHOD_FUNC(core_precompiled_cuda_kernel),
8022
8075
  -1);
8076
+
8023
8077
  }