mlx 0.30.7.3 → 0.30.7.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (590)
  1. checksums.yaml +4 -4
  2. data/ext/mlx/extconf.rb +267 -8
  3. data/ext/mlx/native.cpp +104 -56
  4. data/ext/mlx-onnx/native.cpp +1402 -0
  5. data/ext/mlx-onnx/native.hpp +19 -0
  6. data/lib/mlx/core.rb +342 -117
  7. data/lib/mlx/nn/base.rb +4 -0
  8. data/lib/mlx/nn/layers/linear.rb +2 -3
  9. data/lib/mlx/onnx.rb +250 -0
  10. data/lib/mlx/version.rb +1 -1
  11. data/lib/mlx-onnx/webgpu_harness.rb +289 -0
  12. data/{mlx → submodules/mlx}/mlx/backend/cuda/cublas_utils.cpp +0 -7
  13. data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/cublas_gemm.cpp +10 -2
  14. data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/cublas_qqmm.cpp +97 -46
  15. data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/cublas_qqmm.h +25 -13
  16. data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/fp_quantize.cu +101 -38
  17. data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/no_qqmm_impl.cpp +1 -2
  18. data/submodules/mlx/mlx/backend/cuda/quantized/qqmm.cpp +193 -0
  19. data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/qqmm_impl.cpp +15 -8
  20. data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/qqmm_impl.h +14 -3
  21. data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/qqmm_utils.cu +36 -0
  22. data/submodules/mlx/mlx/backend/cuda/quantized/qqmm_utils.h +62 -0
  23. data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/quantized.cpp +12 -3
  24. data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/quantized.h +4 -0
  25. data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/quantized_utils.cuh +1 -1
  26. data/{mlx → submodules/mlx}/mlx/backend/metal/device.cpp +4 -0
  27. data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/conv.metal +3 -2
  28. data/{mlx → submodules/mlx}/mlx/export.cpp +21 -6
  29. data/{mlx → submodules/mlx}/mlx/ops.cpp +144 -13
  30. data/{mlx → submodules/mlx}/mlx/ops.h +12 -2
  31. data/{mlx → submodules/mlx}/mlx/primitives.cpp +22 -5
  32. data/{mlx → submodules/mlx}/mlx/scheduler.cpp +4 -0
  33. data/{mlx → submodules/mlx}/mlx/scheduler.h +3 -0
  34. data/{mlx → submodules/mlx}/mlx/stream.h +5 -0
  35. data/submodules/mlx-onnx/CMakeLists.txt +159 -0
  36. data/submodules/mlx-onnx/LICENSE +21 -0
  37. data/submodules/mlx-onnx/include/mlx/ir.hpp +88 -0
  38. data/submodules/mlx-onnx/src/api.cpp +81 -0
  39. data/submodules/mlx-onnx/src/compat.cpp +111 -0
  40. data/submodules/mlx-onnx/src/detail.hpp +69 -0
  41. data/submodules/mlx-onnx/src/export.cpp +653 -0
  42. data/submodules/mlx-onnx/src/io.cpp +61 -0
  43. data/submodules/mlx-onnx/src/json.hpp +25 -0
  44. data/submodules/mlx-onnx/src/lowering.cpp +6346 -0
  45. data/submodules/mlx-onnx/src/mappings.cpp +201 -0
  46. data/submodules/mlx-onnx/src/mappings.hpp +16 -0
  47. data/submodules/mlx-onnx/src/onnx.cpp +1029 -0
  48. data/submodules/mlx-onnx/src/shared.cpp +206 -0
  49. metadata +609 -563
  50. data/mlx/mlx/backend/cuda/quantized/qqmm.cpp +0 -158
  51. data/mlx/mlx/backend/cuda/quantized/qqmm_utils.h +0 -30
  52. /data/{mlx → submodules/mlx}/CMakeLists.txt +0 -0
  53. /data/{mlx → submodules/mlx}/cmake/FindCUDNN.cmake +0 -0
  54. /data/{mlx → submodules/mlx}/cmake/FindNCCL.cmake +0 -0
  55. /data/{mlx → submodules/mlx}/cmake/Findnvpl.cmake +0 -0
  56. /data/{mlx → submodules/mlx}/cmake/extension.cmake +0 -0
  57. /data/{mlx → submodules/mlx}/mlx/3rdparty/.clang-format +0 -0
  58. /data/{mlx → submodules/mlx}/mlx/3rdparty/pocketfft.h +0 -0
  59. /data/{mlx → submodules/mlx}/mlx/CMakeLists.txt +0 -0
  60. /data/{mlx → submodules/mlx}/mlx/allocator.h +0 -0
  61. /data/{mlx → submodules/mlx}/mlx/api.h +0 -0
  62. /data/{mlx → submodules/mlx}/mlx/array.cpp +0 -0
  63. /data/{mlx → submodules/mlx}/mlx/array.h +0 -0
  64. /data/{mlx → submodules/mlx}/mlx/backend/common/CMakeLists.txt +0 -0
  65. /data/{mlx → submodules/mlx}/mlx/backend/common/binary.h +0 -0
  66. /data/{mlx → submodules/mlx}/mlx/backend/common/broadcasting.cpp +0 -0
  67. /data/{mlx → submodules/mlx}/mlx/backend/common/broadcasting.h +0 -0
  68. /data/{mlx → submodules/mlx}/mlx/backend/common/buffer_cache.h +0 -0
  69. /data/{mlx → submodules/mlx}/mlx/backend/common/common.cpp +0 -0
  70. /data/{mlx → submodules/mlx}/mlx/backend/common/compiled.cpp +0 -0
  71. /data/{mlx → submodules/mlx}/mlx/backend/common/compiled.h +0 -0
  72. /data/{mlx → submodules/mlx}/mlx/backend/common/copy.h +0 -0
  73. /data/{mlx → submodules/mlx}/mlx/backend/common/hadamard.h +0 -0
  74. /data/{mlx → submodules/mlx}/mlx/backend/common/load.cpp +0 -0
  75. /data/{mlx → submodules/mlx}/mlx/backend/common/matmul.h +0 -0
  76. /data/{mlx → submodules/mlx}/mlx/backend/common/reduce.cpp +0 -0
  77. /data/{mlx → submodules/mlx}/mlx/backend/common/reduce.h +0 -0
  78. /data/{mlx → submodules/mlx}/mlx/backend/common/slicing.cpp +0 -0
  79. /data/{mlx → submodules/mlx}/mlx/backend/common/slicing.h +0 -0
  80. /data/{mlx → submodules/mlx}/mlx/backend/common/ternary.h +0 -0
  81. /data/{mlx → submodules/mlx}/mlx/backend/common/unary.h +0 -0
  82. /data/{mlx → submodules/mlx}/mlx/backend/common/utils.cpp +0 -0
  83. /data/{mlx → submodules/mlx}/mlx/backend/common/utils.h +0 -0
  84. /data/{mlx → submodules/mlx}/mlx/backend/cpu/CMakeLists.txt +0 -0
  85. /data/{mlx → submodules/mlx}/mlx/backend/cpu/arange.h +0 -0
  86. /data/{mlx → submodules/mlx}/mlx/backend/cpu/arg_reduce.cpp +0 -0
  87. /data/{mlx → submodules/mlx}/mlx/backend/cpu/binary.cpp +0 -0
  88. /data/{mlx → submodules/mlx}/mlx/backend/cpu/binary.h +0 -0
  89. /data/{mlx → submodules/mlx}/mlx/backend/cpu/binary_ops.h +0 -0
  90. /data/{mlx → submodules/mlx}/mlx/backend/cpu/binary_two.h +0 -0
  91. /data/{mlx → submodules/mlx}/mlx/backend/cpu/cholesky.cpp +0 -0
  92. /data/{mlx → submodules/mlx}/mlx/backend/cpu/compiled.cpp +0 -0
  93. /data/{mlx → submodules/mlx}/mlx/backend/cpu/compiled_preamble.h +0 -0
  94. /data/{mlx → submodules/mlx}/mlx/backend/cpu/conv.cpp +0 -0
  95. /data/{mlx → submodules/mlx}/mlx/backend/cpu/copy.cpp +0 -0
  96. /data/{mlx → submodules/mlx}/mlx/backend/cpu/copy.h +0 -0
  97. /data/{mlx → submodules/mlx}/mlx/backend/cpu/device_info.cpp +0 -0
  98. /data/{mlx → submodules/mlx}/mlx/backend/cpu/device_info.h +0 -0
  99. /data/{mlx → submodules/mlx}/mlx/backend/cpu/distributed.cpp +0 -0
  100. /data/{mlx → submodules/mlx}/mlx/backend/cpu/eig.cpp +0 -0
  101. /data/{mlx → submodules/mlx}/mlx/backend/cpu/eigh.cpp +0 -0
  102. /data/{mlx → submodules/mlx}/mlx/backend/cpu/encoder.cpp +0 -0
  103. /data/{mlx → submodules/mlx}/mlx/backend/cpu/encoder.h +0 -0
  104. /data/{mlx → submodules/mlx}/mlx/backend/cpu/eval.cpp +0 -0
  105. /data/{mlx → submodules/mlx}/mlx/backend/cpu/eval.h +0 -0
  106. /data/{mlx → submodules/mlx}/mlx/backend/cpu/fft.cpp +0 -0
  107. /data/{mlx → submodules/mlx}/mlx/backend/cpu/gemm.h +0 -0
  108. /data/{mlx → submodules/mlx}/mlx/backend/cpu/gemms/bnns.cpp +0 -0
  109. /data/{mlx → submodules/mlx}/mlx/backend/cpu/gemms/cblas.cpp +0 -0
  110. /data/{mlx → submodules/mlx}/mlx/backend/cpu/gemms/simd_bf16.cpp +0 -0
  111. /data/{mlx → submodules/mlx}/mlx/backend/cpu/gemms/simd_fp16.cpp +0 -0
  112. /data/{mlx → submodules/mlx}/mlx/backend/cpu/gemms/simd_gemm.h +0 -0
  113. /data/{mlx → submodules/mlx}/mlx/backend/cpu/hadamard.cpp +0 -0
  114. /data/{mlx → submodules/mlx}/mlx/backend/cpu/indexing.cpp +0 -0
  115. /data/{mlx → submodules/mlx}/mlx/backend/cpu/inverse.cpp +0 -0
  116. /data/{mlx → submodules/mlx}/mlx/backend/cpu/jit_compiler.cpp +0 -0
  117. /data/{mlx → submodules/mlx}/mlx/backend/cpu/jit_compiler.h +0 -0
  118. /data/{mlx → submodules/mlx}/mlx/backend/cpu/lapack.h +0 -0
  119. /data/{mlx → submodules/mlx}/mlx/backend/cpu/logsumexp.cpp +0 -0
  120. /data/{mlx → submodules/mlx}/mlx/backend/cpu/luf.cpp +0 -0
  121. /data/{mlx → submodules/mlx}/mlx/backend/cpu/make_compiled_preamble.ps1 +0 -0
  122. /data/{mlx → submodules/mlx}/mlx/backend/cpu/make_compiled_preamble.sh +0 -0
  123. /data/{mlx → submodules/mlx}/mlx/backend/cpu/masked_mm.cpp +0 -0
  124. /data/{mlx → submodules/mlx}/mlx/backend/cpu/matmul.cpp +0 -0
  125. /data/{mlx → submodules/mlx}/mlx/backend/cpu/primitives.cpp +0 -0
  126. /data/{mlx → submodules/mlx}/mlx/backend/cpu/qrf.cpp +0 -0
  127. /data/{mlx → submodules/mlx}/mlx/backend/cpu/quantized.cpp +0 -0
  128. /data/{mlx → submodules/mlx}/mlx/backend/cpu/reduce.cpp +0 -0
  129. /data/{mlx → submodules/mlx}/mlx/backend/cpu/scan.cpp +0 -0
  130. /data/{mlx → submodules/mlx}/mlx/backend/cpu/select.cpp +0 -0
  131. /data/{mlx → submodules/mlx}/mlx/backend/cpu/simd/accelerate_fp16_simd.h +0 -0
  132. /data/{mlx → submodules/mlx}/mlx/backend/cpu/simd/accelerate_simd.h +0 -0
  133. /data/{mlx → submodules/mlx}/mlx/backend/cpu/simd/base_simd.h +0 -0
  134. /data/{mlx → submodules/mlx}/mlx/backend/cpu/simd/math.h +0 -0
  135. /data/{mlx → submodules/mlx}/mlx/backend/cpu/simd/neon_fp16_simd.h +0 -0
  136. /data/{mlx → submodules/mlx}/mlx/backend/cpu/simd/simd.h +0 -0
  137. /data/{mlx → submodules/mlx}/mlx/backend/cpu/simd/type.h +0 -0
  138. /data/{mlx → submodules/mlx}/mlx/backend/cpu/slicing.h +0 -0
  139. /data/{mlx → submodules/mlx}/mlx/backend/cpu/softmax.cpp +0 -0
  140. /data/{mlx → submodules/mlx}/mlx/backend/cpu/sort.cpp +0 -0
  141. /data/{mlx → submodules/mlx}/mlx/backend/cpu/svd.cpp +0 -0
  142. /data/{mlx → submodules/mlx}/mlx/backend/cpu/ternary.h +0 -0
  143. /data/{mlx → submodules/mlx}/mlx/backend/cpu/threefry.cpp +0 -0
  144. /data/{mlx → submodules/mlx}/mlx/backend/cpu/threefry.h +0 -0
  145. /data/{mlx → submodules/mlx}/mlx/backend/cpu/unary.cpp +0 -0
  146. /data/{mlx → submodules/mlx}/mlx/backend/cpu/unary.h +0 -0
  147. /data/{mlx → submodules/mlx}/mlx/backend/cpu/unary_ops.h +0 -0
  148. /data/{mlx → submodules/mlx}/mlx/backend/cuda/CMakeLists.txt +0 -0
  149. /data/{mlx → submodules/mlx}/mlx/backend/cuda/allocator.cpp +0 -0
  150. /data/{mlx → submodules/mlx}/mlx/backend/cuda/allocator.h +0 -0
  151. /data/{mlx → submodules/mlx}/mlx/backend/cuda/arange.cu +0 -0
  152. /data/{mlx → submodules/mlx}/mlx/backend/cuda/arg_reduce.cu +0 -0
  153. /data/{mlx → submodules/mlx}/mlx/backend/cuda/bin2h.cmake +0 -0
  154. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/CMakeLists.txt +0 -0
  155. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/add.cu +0 -0
  156. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/arctan2.cu +0 -0
  157. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/binary.cuh +0 -0
  158. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/bitwise_binary.cu +0 -0
  159. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/divide.cu +0 -0
  160. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/equal.cu +0 -0
  161. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/greater.cu +0 -0
  162. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/greater_equal.cu +0 -0
  163. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/less.cu +0 -0
  164. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/less_equal.cu +0 -0
  165. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/log_add_exp.cu +0 -0
  166. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/logical_and.cu +0 -0
  167. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/logical_or.cu +0 -0
  168. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/maximum.cu +0 -0
  169. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/minimum.cu +0 -0
  170. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/multiply.cu +0 -0
  171. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/not_equal.cu +0 -0
  172. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/power.cu +0 -0
  173. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/remainder.cu +0 -0
  174. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/subtract.cu +0 -0
  175. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary_two.cu +0 -0
  176. /data/{mlx → submodules/mlx}/mlx/backend/cuda/compiled.cpp +0 -0
  177. /data/{mlx → submodules/mlx}/mlx/backend/cuda/conv/conv.h +0 -0
  178. /data/{mlx → submodules/mlx}/mlx/backend/cuda/conv/gemm_conv.cu +0 -0
  179. /data/{mlx → submodules/mlx}/mlx/backend/cuda/conv/gemm_grouped_conv.cu +0 -0
  180. /data/{mlx → submodules/mlx}/mlx/backend/cuda/conv.cpp +0 -0
  181. /data/{mlx → submodules/mlx}/mlx/backend/cuda/copy/copy.cuh +0 -0
  182. /data/{mlx → submodules/mlx}/mlx/backend/cuda/copy/copy_contiguous.cu +0 -0
  183. /data/{mlx → submodules/mlx}/mlx/backend/cuda/copy/copy_general.cu +0 -0
  184. /data/{mlx → submodules/mlx}/mlx/backend/cuda/copy/copy_general_dynamic.cu +0 -0
  185. /data/{mlx → submodules/mlx}/mlx/backend/cuda/copy/copy_general_input.cu +0 -0
  186. /data/{mlx → submodules/mlx}/mlx/backend/cuda/copy.cu +0 -0
  187. /data/{mlx → submodules/mlx}/mlx/backend/cuda/cublas_utils.h +0 -0
  188. /data/{mlx → submodules/mlx}/mlx/backend/cuda/cuda.h +0 -0
  189. /data/{mlx → submodules/mlx}/mlx/backend/cuda/cuda_utils.h +0 -0
  190. /data/{mlx → submodules/mlx}/mlx/backend/cuda/cudnn_utils.cpp +0 -0
  191. /data/{mlx → submodules/mlx}/mlx/backend/cuda/cudnn_utils.h +0 -0
  192. /data/{mlx → submodules/mlx}/mlx/backend/cuda/custom_kernel.cpp +0 -0
  193. /data/{mlx → submodules/mlx}/mlx/backend/cuda/cutlass_utils.cuh +0 -0
  194. /data/{mlx → submodules/mlx}/mlx/backend/cuda/delayload.cpp +0 -0
  195. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/atomic_ops.cuh +0 -0
  196. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/binary_ops.cuh +0 -0
  197. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/cast_op.cuh +0 -0
  198. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/complex.cuh +0 -0
  199. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/config.h +0 -0
  200. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/fp16_math.cuh +0 -0
  201. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/gather.cuh +0 -0
  202. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/gather_axis.cuh +0 -0
  203. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/indexing.cuh +0 -0
  204. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/scatter.cuh +0 -0
  205. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/scatter_axis.cuh +0 -0
  206. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/scatter_ops.cuh +0 -0
  207. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/ternary_ops.cuh +0 -0
  208. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/unary_ops.cuh +0 -0
  209. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/utils.cuh +0 -0
  210. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device.cpp +0 -0
  211. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device.h +0 -0
  212. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device_info.cpp +0 -0
  213. /data/{mlx → submodules/mlx}/mlx/backend/cuda/distributed.cu +0 -0
  214. /data/{mlx → submodules/mlx}/mlx/backend/cuda/eval.cpp +0 -0
  215. /data/{mlx → submodules/mlx}/mlx/backend/cuda/event.cu +0 -0
  216. /data/{mlx → submodules/mlx}/mlx/backend/cuda/event.h +0 -0
  217. /data/{mlx → submodules/mlx}/mlx/backend/cuda/fence.cpp +0 -0
  218. /data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/cublas_gemm.h +0 -0
  219. /data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/cublas_gemm_batched_12_0.cpp +0 -0
  220. /data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/cublas_gemm_batched_12_9.cu +0 -0
  221. /data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/gemv.cu +0 -0
  222. /data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/gemv.h +0 -0
  223. /data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/grouped_gemm.h +0 -0
  224. /data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/grouped_gemm_unaligned.cu +0 -0
  225. /data/{mlx → submodules/mlx}/mlx/backend/cuda/indexing.cpp +0 -0
  226. /data/{mlx → submodules/mlx}/mlx/backend/cuda/jit_module.cpp +0 -0
  227. /data/{mlx → submodules/mlx}/mlx/backend/cuda/jit_module.h +0 -0
  228. /data/{mlx → submodules/mlx}/mlx/backend/cuda/kernel_utils.cu +0 -0
  229. /data/{mlx → submodules/mlx}/mlx/backend/cuda/kernel_utils.cuh +0 -0
  230. /data/{mlx → submodules/mlx}/mlx/backend/cuda/layer_norm.cu +0 -0
  231. /data/{mlx → submodules/mlx}/mlx/backend/cuda/load.cpp +0 -0
  232. /data/{mlx → submodules/mlx}/mlx/backend/cuda/logsumexp.cu +0 -0
  233. /data/{mlx → submodules/mlx}/mlx/backend/cuda/lru_cache.h +0 -0
  234. /data/{mlx → submodules/mlx}/mlx/backend/cuda/matmul.cpp +0 -0
  235. /data/{mlx → submodules/mlx}/mlx/backend/cuda/no_cuda.cpp +0 -0
  236. /data/{mlx → submodules/mlx}/mlx/backend/cuda/primitives.cpp +0 -0
  237. /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/affine_quantize.cu +0 -0
  238. /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/convert_fp8.cu +0 -0
  239. /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/cuda_fp4.h +0 -0
  240. /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/mxfp8_quantize.cuh +0 -0
  241. /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/nvfp4_quantize.cuh +0 -0
  242. /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/qmv.cu +0 -0
  243. /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/qmv.h +0 -0
  244. /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/quantized_utils.h +0 -0
  245. /data/{mlx → submodules/mlx}/mlx/backend/cuda/random.cu +0 -0
  246. /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce/all_reduce.cu +0 -0
  247. /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce/col_reduce.cu +0 -0
  248. /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce/init_reduce.cu +0 -0
  249. /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce/reduce.cuh +0 -0
  250. /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce/reduce_ops.cuh +0 -0
  251. /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce/reduce_utils.cuh +0 -0
  252. /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce/row_reduce.cu +0 -0
  253. /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce.cu +0 -0
  254. /data/{mlx → submodules/mlx}/mlx/backend/cuda/rms_norm.cu +0 -0
  255. /data/{mlx → submodules/mlx}/mlx/backend/cuda/rope.cu +0 -0
  256. /data/{mlx → submodules/mlx}/mlx/backend/cuda/scaled_dot_product_attention.cpp +0 -0
  257. /data/{mlx → submodules/mlx}/mlx/backend/cuda/scaled_dot_product_attention.cu +0 -0
  258. /data/{mlx → submodules/mlx}/mlx/backend/cuda/scan.cu +0 -0
  259. /data/{mlx → submodules/mlx}/mlx/backend/cuda/slicing.cpp +0 -0
  260. /data/{mlx → submodules/mlx}/mlx/backend/cuda/softmax.cu +0 -0
  261. /data/{mlx → submodules/mlx}/mlx/backend/cuda/sort.cu +0 -0
  262. /data/{mlx → submodules/mlx}/mlx/backend/cuda/steel/defines.cuh +0 -0
  263. /data/{mlx → submodules/mlx}/mlx/backend/cuda/steel/gemm.cuh +0 -0
  264. /data/{mlx → submodules/mlx}/mlx/backend/cuda/steel/mma.cuh +0 -0
  265. /data/{mlx → submodules/mlx}/mlx/backend/cuda/steel/tiles.cuh +0 -0
  266. /data/{mlx → submodules/mlx}/mlx/backend/cuda/steel/utils.cuh +0 -0
  267. /data/{mlx → submodules/mlx}/mlx/backend/cuda/ternary.cu +0 -0
  268. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/CMakeLists.txt +0 -0
  269. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/abs.cu +0 -0
  270. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/arccos.cu +0 -0
  271. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/arccosh.cu +0 -0
  272. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/arcsin.cu +0 -0
  273. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/arcsinh.cu +0 -0
  274. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/arctan.cu +0 -0
  275. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/arctanh.cu +0 -0
  276. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/bitwise_invert.cu +0 -0
  277. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/ceil.cu +0 -0
  278. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/conjugate.cu +0 -0
  279. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/cos.cu +0 -0
  280. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/cosh.cu +0 -0
  281. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/erf.cu +0 -0
  282. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/erf_inv.cu +0 -0
  283. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/exp.cu +0 -0
  284. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/expm1.cu +0 -0
  285. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/floor.cu +0 -0
  286. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/imag.cu +0 -0
  287. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/log.cu +0 -0
  288. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/log1p.cu +0 -0
  289. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/logical_not.cu +0 -0
  290. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/negative.cu +0 -0
  291. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/real.cu +0 -0
  292. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/round.cu +0 -0
  293. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/sigmoid.cu +0 -0
  294. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/sign.cu +0 -0
  295. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/sin.cu +0 -0
  296. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/sinh.cu +0 -0
  297. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/sqrt.cu +0 -0
  298. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/square.cu +0 -0
  299. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/tan.cu +0 -0
  300. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/tanh.cu +0 -0
  301. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/unary.cuh +0 -0
  302. /data/{mlx → submodules/mlx}/mlx/backend/cuda/utils.cpp +0 -0
  303. /data/{mlx → submodules/mlx}/mlx/backend/cuda/utils.h +0 -0
  304. /data/{mlx → submodules/mlx}/mlx/backend/cuda/vector_types.cuh +0 -0
  305. /data/{mlx → submodules/mlx}/mlx/backend/cuda/worker.cpp +0 -0
  306. /data/{mlx → submodules/mlx}/mlx/backend/cuda/worker.h +0 -0
  307. /data/{mlx → submodules/mlx}/mlx/backend/gpu/CMakeLists.txt +0 -0
  308. /data/{mlx → submodules/mlx}/mlx/backend/gpu/copy.cpp +0 -0
  309. /data/{mlx → submodules/mlx}/mlx/backend/gpu/copy.h +0 -0
  310. /data/{mlx → submodules/mlx}/mlx/backend/gpu/device_info.h +0 -0
  311. /data/{mlx → submodules/mlx}/mlx/backend/gpu/eval.h +0 -0
  312. /data/{mlx → submodules/mlx}/mlx/backend/gpu/primitives.cpp +0 -0
  313. /data/{mlx → submodules/mlx}/mlx/backend/gpu/slicing.cpp +0 -0
  314. /data/{mlx → submodules/mlx}/mlx/backend/gpu/slicing.h +0 -0
  315. /data/{mlx → submodules/mlx}/mlx/backend/metal/CMakeLists.txt +0 -0
  316. /data/{mlx → submodules/mlx}/mlx/backend/metal/allocator.cpp +0 -0
  317. /data/{mlx → submodules/mlx}/mlx/backend/metal/allocator.h +0 -0
  318. /data/{mlx → submodules/mlx}/mlx/backend/metal/binary.cpp +0 -0
  319. /data/{mlx → submodules/mlx}/mlx/backend/metal/binary.h +0 -0
  320. /data/{mlx → submodules/mlx}/mlx/backend/metal/compiled.cpp +0 -0
  321. /data/{mlx → submodules/mlx}/mlx/backend/metal/conv.cpp +0 -0
  322. /data/{mlx → submodules/mlx}/mlx/backend/metal/copy.cpp +0 -0
  323. /data/{mlx → submodules/mlx}/mlx/backend/metal/custom_kernel.cpp +0 -0
  324. /data/{mlx → submodules/mlx}/mlx/backend/metal/device.h +0 -0
  325. /data/{mlx → submodules/mlx}/mlx/backend/metal/device_info.cpp +0 -0
  326. /data/{mlx → submodules/mlx}/mlx/backend/metal/distributed.cpp +0 -0
  327. /data/{mlx → submodules/mlx}/mlx/backend/metal/eval.cpp +0 -0
  328. /data/{mlx → submodules/mlx}/mlx/backend/metal/event.cpp +0 -0
  329. /data/{mlx → submodules/mlx}/mlx/backend/metal/fence.cpp +0 -0
  330. /data/{mlx → submodules/mlx}/mlx/backend/metal/fft.cpp +0 -0
  331. /data/{mlx → submodules/mlx}/mlx/backend/metal/hadamard.cpp +0 -0
  332. /data/{mlx → submodules/mlx}/mlx/backend/metal/indexing.cpp +0 -0
  333. /data/{mlx → submodules/mlx}/mlx/backend/metal/jit/includes.h +0 -0
  334. /data/{mlx → submodules/mlx}/mlx/backend/metal/jit/indexing.h +0 -0
  335. /data/{mlx → submodules/mlx}/mlx/backend/metal/jit_kernels.cpp +0 -0
  336. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/CMakeLists.txt +0 -0
  337. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/arange.h +0 -0
  338. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/arange.metal +0 -0
  339. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/arg_reduce.metal +0 -0
  340. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/atomic.h +0 -0
  341. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/bf16.h +0 -0
  342. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/bf16_math.h +0 -0
  343. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/binary.h +0 -0
  344. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/binary.metal +0 -0
  345. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/binary_ops.h +0 -0
  346. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/binary_two.h +0 -0
  347. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/binary_two.metal +0 -0
  348. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/cexpf.h +0 -0
  349. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/complex.h +0 -0
  350. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/copy.h +0 -0
  351. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/copy.metal +0 -0
  352. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/defines.h +0 -0
  353. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/erf.h +0 -0
  354. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/expm1f.h +0 -0
  355. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fence.metal +0 -0
  356. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fft/radix.h +0 -0
  357. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fft/readwrite.h +0 -0
  358. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fft.h +0 -0
  359. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fft.metal +0 -0
  360. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fp4.h +0 -0
  361. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fp8.h +0 -0
  362. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fp_quantized.h +0 -0
  363. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fp_quantized.metal +0 -0
  364. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fp_quantized_nax.h +0 -0
  365. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fp_quantized_nax.metal +0 -0
  366. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/gemv.metal +0 -0
  367. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/gemv_masked.h +0 -0
  368. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/gemv_masked.metal +0 -0
  369. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/hadamard.h +0 -0
  370. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/indexing/gather.h +0 -0
  371. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/indexing/gather_axis.h +0 -0
  372. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/indexing/gather_front.h +0 -0
  373. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/indexing/indexing.h +0 -0
  374. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/indexing/masked_scatter.h +0 -0
  375. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/indexing/scatter.h +0 -0
  376. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/indexing/scatter_axis.h +0 -0
  377. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/layer_norm.metal +0 -0
  378. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/logging.h +0 -0
  379. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/logsumexp.h +0 -0
  380. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/logsumexp.metal +0 -0
  381. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/quantized.h +0 -0
  382. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/quantized.metal +0 -0
  383. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/quantized_nax.h +0 -0
  384. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/quantized_nax.metal +0 -0
  385. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/quantized_utils.h +0 -0
  386. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/random.metal +0 -0
  387. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduce.h +0 -0
  388. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduce.metal +0 -0
  389. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduce_utils.h +0 -0
  390. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduction/ops.h +0 -0
  391. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduction/reduce_all.h +0 -0
  392. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduction/reduce_col.h +0 -0
  393. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduction/reduce_init.h +0 -0
  394. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduction/reduce_row.h +0 -0
  395. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/rms_norm.metal +0 -0
  396. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/rope.metal +0 -0
  397. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/scaled_dot_product_attention.metal +0 -0
  398. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/scan.h +0 -0
  399. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/scan.metal +0 -0
  400. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/sdpa_vector.h +0 -0
  401. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/softmax.h +0 -0
  402. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/softmax.metal +0 -0
  403. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/sort.h +0 -0
  404. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/sort.metal +0 -0
  405. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/attn.h +0 -0
  406. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +0 -0
  407. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal +0 -0
  408. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +0 -0
  409. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.metal +0 -0
  410. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/loader.h +0 -0
  411. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/mma.h +0 -0
  412. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/nax.h +0 -0
  413. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/params.h +0 -0
  414. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/transforms.h +0 -0
  415. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/conv.h +0 -0
  416. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +0 -0
  417. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.metal +0 -0
  418. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +0 -0
  419. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.metal +0 -0
  420. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/loader.h +0 -0
  421. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +0 -0
  422. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +0 -0
  423. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +0 -0
  424. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/params.h +0 -0
  425. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/defines.h +0 -0
  426. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/gemm.h +0 -0
  427. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +0 -0
  428. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +0 -0
  429. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.metal +0 -0
  430. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +0 -0
  431. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.metal +0 -0
  432. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +0 -0
  433. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.metal +0 -0
  434. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +0 -0
  435. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.metal +0 -0
  436. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +0 -0
  437. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.metal +0 -0
  438. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +0 -0
  439. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.metal +0 -0
  440. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +0 -0
  441. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.metal +0 -0
  442. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.h +0 -0
  443. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.metal +0 -0
  444. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/loader.h +0 -0
  445. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/mma.h +0 -0
  446. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/nax.h +0 -0
  447. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/params.h +0 -0
  448. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/transforms.h +0 -0
  449. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/utils/integral_constant.h +0 -0
  450. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/utils/type_traits.h +0 -0
  451. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/utils.h +0 -0
  452. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/ternary.h +0 -0
  453. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/ternary.metal +0 -0
  454. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/ternary_ops.h +0 -0
  455. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/unary.h +0 -0
  456. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/unary.metal +0 -0
  457. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/unary_ops.h +0 -0
  458. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/utils.h +0 -0
  459. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels.h +0 -0
  460. /data/{mlx → submodules/mlx}/mlx/backend/metal/logsumexp.cpp +0 -0
  461. /data/{mlx → submodules/mlx}/mlx/backend/metal/make_compiled_preamble.sh +0 -0
  462. /data/{mlx → submodules/mlx}/mlx/backend/metal/matmul.cpp +0 -0
  463. /data/{mlx → submodules/mlx}/mlx/backend/metal/matmul.h +0 -0
  464. /data/{mlx → submodules/mlx}/mlx/backend/metal/metal.cpp +0 -0
  465. /data/{mlx → submodules/mlx}/mlx/backend/metal/metal.h +0 -0
  466. /data/{mlx → submodules/mlx}/mlx/backend/metal/no_metal.cpp +0 -0
  467. /data/{mlx → submodules/mlx}/mlx/backend/metal/nojit_kernels.cpp +0 -0
  468. /data/{mlx → submodules/mlx}/mlx/backend/metal/normalization.cpp +0 -0
  469. /data/{mlx → submodules/mlx}/mlx/backend/metal/primitives.cpp +0 -0
  470. /data/{mlx → submodules/mlx}/mlx/backend/metal/quantized.cpp +0 -0
  471. /data/{mlx → submodules/mlx}/mlx/backend/metal/reduce.cpp +0 -0
  472. /data/{mlx → submodules/mlx}/mlx/backend/metal/reduce.h +0 -0
  473. /data/{mlx → submodules/mlx}/mlx/backend/metal/resident.cpp +0 -0
  474. /data/{mlx → submodules/mlx}/mlx/backend/metal/resident.h +0 -0
  475. /data/{mlx → submodules/mlx}/mlx/backend/metal/rope.cpp +0 -0
  476. /data/{mlx → submodules/mlx}/mlx/backend/metal/scaled_dot_product_attention.cpp +0 -0
  477. /data/{mlx → submodules/mlx}/mlx/backend/metal/scan.cpp +0 -0
  478. /data/{mlx → submodules/mlx}/mlx/backend/metal/scan.h +0 -0
  479. /data/{mlx → submodules/mlx}/mlx/backend/metal/slicing.cpp +0 -0
  480. /data/{mlx → submodules/mlx}/mlx/backend/metal/softmax.cpp +0 -0
  481. /data/{mlx → submodules/mlx}/mlx/backend/metal/sort.cpp +0 -0
  482. /data/{mlx → submodules/mlx}/mlx/backend/metal/ternary.cpp +0 -0
  483. /data/{mlx → submodules/mlx}/mlx/backend/metal/ternary.h +0 -0
  484. /data/{mlx → submodules/mlx}/mlx/backend/metal/unary.cpp +0 -0
  485. /data/{mlx → submodules/mlx}/mlx/backend/metal/unary.h +0 -0
  486. /data/{mlx → submodules/mlx}/mlx/backend/metal/utils.cpp +0 -0
  487. /data/{mlx → submodules/mlx}/mlx/backend/metal/utils.h +0 -0
  488. /data/{mlx → submodules/mlx}/mlx/backend/no_cpu/CMakeLists.txt +0 -0
  489. /data/{mlx → submodules/mlx}/mlx/backend/no_cpu/compiled.cpp +0 -0
  490. /data/{mlx → submodules/mlx}/mlx/backend/no_cpu/device_info.cpp +0 -0
  491. /data/{mlx → submodules/mlx}/mlx/backend/no_cpu/primitives.cpp +0 -0
  492. /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/CMakeLists.txt +0 -0
  493. /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/allocator.cpp +0 -0
  494. /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/apple_memory.h +0 -0
  495. /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/device_info.cpp +0 -0
  496. /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/eval.cpp +0 -0
  497. /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/event.cpp +0 -0
  498. /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/fence.cpp +0 -0
  499. /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/linux_memory.h +0 -0
  500. /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/primitives.cpp +0 -0
  501. /data/{mlx → submodules/mlx}/mlx/compile.cpp +0 -0
  502. /data/{mlx → submodules/mlx}/mlx/compile.h +0 -0
  503. /data/{mlx → submodules/mlx}/mlx/compile_impl.h +0 -0
  504. /data/{mlx → submodules/mlx}/mlx/device.cpp +0 -0
  505. /data/{mlx → submodules/mlx}/mlx/device.h +0 -0
  506. /data/{mlx → submodules/mlx}/mlx/distributed/CMakeLists.txt +0 -0
  507. /data/{mlx → submodules/mlx}/mlx/distributed/distributed.cpp +0 -0
  508. /data/{mlx → submodules/mlx}/mlx/distributed/distributed.h +0 -0
  509. /data/{mlx → submodules/mlx}/mlx/distributed/distributed_impl.h +0 -0
  510. /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/CMakeLists.txt +0 -0
  511. /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/jaccl.cpp +0 -0
  512. /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/jaccl.h +0 -0
  513. /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/mesh.cpp +0 -0
  514. /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/mesh.h +0 -0
  515. /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/no_jaccl.cpp +0 -0
  516. /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/ring.cpp +0 -0
  517. /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/ring.h +0 -0
  518. /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/utils.cpp +0 -0
  519. /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/utils.h +0 -0
  520. /data/{mlx → submodules/mlx}/mlx/distributed/mpi/CMakeLists.txt +0 -0
  521. /data/{mlx → submodules/mlx}/mlx/distributed/mpi/mpi.cpp +0 -0
  522. /data/{mlx → submodules/mlx}/mlx/distributed/mpi/mpi.h +0 -0
  523. /data/{mlx → submodules/mlx}/mlx/distributed/mpi/mpi_declarations.h +0 -0
  524. /data/{mlx → submodules/mlx}/mlx/distributed/mpi/no_mpi.cpp +0 -0
  525. /data/{mlx → submodules/mlx}/mlx/distributed/nccl/CMakeLists.txt +0 -0
  526. /data/{mlx → submodules/mlx}/mlx/distributed/nccl/nccl.cpp +0 -0
  527. /data/{mlx → submodules/mlx}/mlx/distributed/nccl/nccl.h +0 -0
  528. /data/{mlx → submodules/mlx}/mlx/distributed/nccl/nccl_stub/CMakeLists.txt +0 -0
  529. /data/{mlx → submodules/mlx}/mlx/distributed/nccl/nccl_stub/nccl_stubs.cpp +0 -0
  530. /data/{mlx → submodules/mlx}/mlx/distributed/nccl/no_nccl.cpp +0 -0
  531. /data/{mlx → submodules/mlx}/mlx/distributed/ops.cpp +0 -0
  532. /data/{mlx → submodules/mlx}/mlx/distributed/ops.h +0 -0
  533. /data/{mlx → submodules/mlx}/mlx/distributed/primitives.cpp +0 -0
  534. /data/{mlx → submodules/mlx}/mlx/distributed/primitives.h +0 -0
  535. /data/{mlx → submodules/mlx}/mlx/distributed/reduction_ops.h +0 -0
  536. /data/{mlx → submodules/mlx}/mlx/distributed/ring/CMakeLists.txt +0 -0
  537. /data/{mlx → submodules/mlx}/mlx/distributed/ring/no_ring.cpp +0 -0
  538. /data/{mlx → submodules/mlx}/mlx/distributed/ring/ring.cpp +0 -0
  539. /data/{mlx → submodules/mlx}/mlx/distributed/ring/ring.h +0 -0
  540. /data/{mlx → submodules/mlx}/mlx/distributed/utils.cpp +0 -0
  541. /data/{mlx → submodules/mlx}/mlx/distributed/utils.h +0 -0
  542. /data/{mlx → submodules/mlx}/mlx/dtype.cpp +0 -0
  543. /data/{mlx → submodules/mlx}/mlx/dtype.h +0 -0
  544. /data/{mlx → submodules/mlx}/mlx/dtype_utils.cpp +0 -0
  545. /data/{mlx → submodules/mlx}/mlx/dtype_utils.h +0 -0
  546. /data/{mlx → submodules/mlx}/mlx/einsum.cpp +0 -0
  547. /data/{mlx → submodules/mlx}/mlx/einsum.h +0 -0
  548. /data/{mlx → submodules/mlx}/mlx/event.h +0 -0
  549. /data/{mlx → submodules/mlx}/mlx/export.h +0 -0
  550. /data/{mlx → submodules/mlx}/mlx/export_impl.h +0 -0
  551. /data/{mlx → submodules/mlx}/mlx/fast.cpp +0 -0
  552. /data/{mlx → submodules/mlx}/mlx/fast.h +0 -0
  553. /data/{mlx → submodules/mlx}/mlx/fast_primitives.h +0 -0
  554. /data/{mlx → submodules/mlx}/mlx/fence.h +0 -0
  555. /data/{mlx → submodules/mlx}/mlx/fft.cpp +0 -0
  556. /data/{mlx → submodules/mlx}/mlx/fft.h +0 -0
  557. /data/{mlx → submodules/mlx}/mlx/graph_utils.cpp +0 -0
  558. /data/{mlx → submodules/mlx}/mlx/graph_utils.h +0 -0
  559. /data/{mlx → submodules/mlx}/mlx/io/CMakeLists.txt +0 -0
  560. /data/{mlx → submodules/mlx}/mlx/io/gguf.cpp +0 -0
  561. /data/{mlx → submodules/mlx}/mlx/io/gguf.h +0 -0
  562. /data/{mlx → submodules/mlx}/mlx/io/gguf_quants.cpp +0 -0
  563. /data/{mlx → submodules/mlx}/mlx/io/load.cpp +0 -0
  564. /data/{mlx → submodules/mlx}/mlx/io/load.h +0 -0
  565. /data/{mlx → submodules/mlx}/mlx/io/no_gguf.cpp +0 -0
  566. /data/{mlx → submodules/mlx}/mlx/io/no_safetensors.cpp +0 -0
  567. /data/{mlx → submodules/mlx}/mlx/io/safetensors.cpp +0 -0
  568. /data/{mlx → submodules/mlx}/mlx/io.h +0 -0
  569. /data/{mlx → submodules/mlx}/mlx/linalg.cpp +0 -0
  570. /data/{mlx → submodules/mlx}/mlx/linalg.h +0 -0
  571. /data/{mlx → submodules/mlx}/mlx/memory.h +0 -0
  572. /data/{mlx → submodules/mlx}/mlx/mlx.h +0 -0
  573. /data/{mlx → submodules/mlx}/mlx/primitives.h +0 -0
  574. /data/{mlx → submodules/mlx}/mlx/random.cpp +0 -0
  575. /data/{mlx → submodules/mlx}/mlx/random.h +0 -0
  576. /data/{mlx → submodules/mlx}/mlx/small_vector.h +0 -0
  577. /data/{mlx → submodules/mlx}/mlx/threadpool.h +0 -0
  578. /data/{mlx → submodules/mlx}/mlx/transforms.cpp +0 -0
  579. /data/{mlx → submodules/mlx}/mlx/transforms.h +0 -0
  580. /data/{mlx → submodules/mlx}/mlx/transforms_impl.h +0 -0
  581. /data/{mlx → submodules/mlx}/mlx/types/bf16.h +0 -0
  582. /data/{mlx → submodules/mlx}/mlx/types/complex.h +0 -0
  583. /data/{mlx → submodules/mlx}/mlx/types/fp16.h +0 -0
  584. /data/{mlx → submodules/mlx}/mlx/types/half_types.h +0 -0
  585. /data/{mlx → submodules/mlx}/mlx/types/limits.h +0 -0
  586. /data/{mlx → submodules/mlx}/mlx/utils.cpp +0 -0
  587. /data/{mlx → submodules/mlx}/mlx/utils.h +0 -0
  588. /data/{mlx → submodules/mlx}/mlx/version.cpp +0 -0
  589. /data/{mlx → submodules/mlx}/mlx/version.h +0 -0
  590. /data/{mlx → submodules/mlx}/mlx.pc.in +0 -0
data/lib/mlx/onnx.rb ADDED
@@ -0,0 +1,250 @@
1
+ # frozen_string_literal: true
2
+
3
+ require "json"
4
+ require_relative "../mlx-onnx/webgpu_harness"
5
+
6
+ module MLX
7
+ module ONNX
8
+ ONNX_VERSION = 1
9
+
10
+ module_function
11
+
12
+ def export_onnx(
13
+ target_path,
14
+ fun,
15
+ *extras,
16
+ shapeless: false,
17
+ opset: 18,
18
+ model_name: "mlx_graph",
19
+ external_data: false,
20
+ external_data_file: nil,
21
+ external_data_size_threshold: 1024,
22
+ **trace_kwargs
23
+ )
24
+ ensure_native_graph_ir!
25
+ assert_boolean!(shapeless, "shapeless")
26
+ assert_boolean!(external_data, "external_data")
27
+ native_target_path = normalize_binary_target_path!(target_path, "export_onnx")
28
+ native_fun, native_extras, _positional_count, _keyword_names =
29
+ prepare_trace_invocation_without_native_kwargs(fun, extras, trace_kwargs)
30
+ translate_native_unsupported do
31
+ written = MLX::ONNX::Native.export_onnx(
32
+ native_target_path,
33
+ native_fun,
34
+ native_extras,
35
+ {},
36
+ shapeless,
37
+ opset,
38
+ model_name,
39
+ external_data,
40
+ external_data_file,
41
+ external_data_size_threshold
42
+ )
43
+ unless written.is_a?(String)
44
+ raise TypeError, "MLX::ONNX::Native.export_onnx must return String path"
45
+ end
46
+ written
47
+ end
48
+ end
49
+
50
+ def export_onnx_json(
51
+ fun,
52
+ *extras,
53
+ shapeless: false,
54
+ opset: 18,
55
+ model_name: "mlx_graph",
56
+ **trace_kwargs
57
+ )
58
+ ensure_native_graph_ir!
59
+ assert_boolean!(shapeless, "shapeless")
60
+ native_fun, native_extras, _positional_count, _keyword_names =
61
+ prepare_trace_invocation_without_native_kwargs(fun, extras, trace_kwargs)
62
+ translate_native_unsupported do
63
+ onnx_json = MLX::ONNX::Native.export_onnx_json(
64
+ native_fun,
65
+ native_extras,
66
+ {},
67
+ shapeless,
68
+ opset,
69
+ model_name
70
+ )
71
+ unless onnx_json.is_a?(String)
72
+ raise TypeError, "MLX::ONNX::Native.export_onnx_json must return String JSON"
73
+ end
74
+ onnx_json
75
+ end
76
+ end
77
+
78
+ def export_onnx_compatibility_report(fun, *extras, shapeless: false, **trace_kwargs)
79
+ ensure_native_graph_ir!
80
+ assert_boolean!(shapeless, "shapeless")
81
+ native_fun, native_extras, _positional_count, _keyword_names =
82
+ prepare_trace_invocation_without_native_kwargs(fun, extras, trace_kwargs)
83
+ report = MLX::ONNX::Native.export_onnx_compatibility_report(
84
+ native_fun,
85
+ native_extras,
86
+ {},
87
+ shapeless
88
+ )
89
+ unless report.is_a?(Hash)
90
+ raise TypeError, "MLX::ONNX::Native.export_onnx_compatibility_report must return Hash payload"
91
+ end
92
+
93
+ report
94
+ end
95
+
96
+ def export_graph_ir(fun, *extras, shapeless: false, **trace_kwargs)
97
+ ensure_native_graph_ir!
98
+ assert_boolean!(shapeless, "shapeless")
99
+ native_fun, native_extras, positional_count, keyword_names =
100
+ prepare_trace_invocation_without_native_kwargs(fun, extras, trace_kwargs)
101
+ payload = MLX::ONNX::Native.export_graph_ir(native_fun, native_extras, {}, shapeless)
102
+ unless payload.is_a?(Hash)
103
+ raise TypeError, "MLX::ONNX::Native.export_graph_ir must return Hash payload"
104
+ end
105
+ inject_keyword_inputs!(payload, positional_count, keyword_names)
106
+ payload
107
+ end
108
+
109
+ def export_graph_ir_json(fun, *extras, shapeless: false, **trace_kwargs)
110
+ ensure_native_graph_ir!
111
+ assert_boolean!(shapeless, "shapeless")
112
+ native_fun, native_extras, positional_count, keyword_names =
113
+ prepare_trace_invocation_without_native_kwargs(fun, extras, trace_kwargs)
114
+ content = MLX::ONNX::Native.export_graph_ir_json(native_fun, native_extras, {}, shapeless)
115
+ unless content.is_a?(String)
116
+ raise TypeError, "MLX::ONNX::Native.export_graph_ir_json must return String JSON"
117
+ end
118
+ return content if keyword_names.empty?
119
+
120
+ payload = JSON.parse(content)
121
+ inject_keyword_inputs!(payload, positional_count, keyword_names)
122
+ content = JSON.generate(payload)
123
+ content
124
+ end
125
+
126
+ def graph_ir_to_onnx(
127
+ target_path,
128
+ ir_source,
129
+ opset: 18,
130
+ model_name: "mlx_graph",
131
+ external_data: false,
132
+ external_data_file: nil,
133
+ external_data_size_threshold: 1024
134
+ )
135
+ ensure_native_graph_ir!
136
+ assert_boolean!(external_data, "external_data")
137
+ native_target_path = normalize_binary_target_path!(target_path, "graph_ir_to_onnx")
138
+ translate_native_unsupported do
139
+ written = MLX::ONNX::Native.graph_ir_to_onnx(
140
+ native_target_path,
141
+ ir_source,
142
+ opset,
143
+ model_name,
144
+ external_data,
145
+ external_data_file,
146
+ external_data_size_threshold
147
+ )
148
+ unless written.is_a?(String)
149
+ raise TypeError, "MLX::ONNX::Native.graph_ir_to_onnx must return String path"
150
+ end
151
+ written
152
+ end
153
+ end
154
+
155
+ def graph_ir_to_onnx_json(ir_source, opset: 18, model_name: "mlx_graph")
156
+ ensure_native_graph_ir!
157
+ translate_native_unsupported do
158
+ onnx_json = MLX::ONNX::Native.graph_ir_to_onnx_json(ir_source, opset, model_name)
159
+ unless onnx_json.is_a?(String)
160
+ raise TypeError, "MLX::ONNX::Native.graph_ir_to_onnx_json must return String JSON"
161
+ end
162
+ onnx_json
163
+ end
164
+ end
165
+
166
+ def ensure_native_graph_ir!
167
+ MLX::Core.ensure_native!
168
+ unless defined?(MLX::ONNX::Native)
169
+ raise RuntimeError, "MLX::ONNX::Native is unavailable"
170
+ end
171
+ end
172
+ private_class_method :ensure_native_graph_ir!
173
+
174
+ def assert_boolean!(value, label)
175
+ unless value == true || value == false
176
+ raise TypeError, "#{label} must be true or false"
177
+ end
178
+ end
179
+ private_class_method :assert_boolean!
180
+
181
+ def translate_native_unsupported
182
+ yield
183
+ rescue MLX::ONNX::Native::UnsupportedError => e
184
+ raise NotImplementedError, e.message
185
+ end
186
+ private_class_method :translate_native_unsupported
187
+
188
+ def normalize_binary_target_path!(target_path, method_name)
189
+ if target_path.respond_to?(:write) && !target_path.respond_to?(:to_path)
190
+ raise ArgumentError, "#{method_name} requires a path-like target, not an IO-like target"
191
+ end
192
+
193
+ path = if target_path.respond_to?(:to_path)
194
+ target_path.to_path
195
+ else
196
+ target_path
197
+ end
198
+
199
+ unless path.is_a?(String)
200
+ raise TypeError, "target_path must be a String or respond to #to_path"
201
+ end
202
+
203
+ path
204
+ end
205
+ private_class_method :normalize_binary_target_path!
206
+
207
+ def prepare_trace_invocation_without_native_kwargs(fun, extras, trace_kwargs)
208
+ positional_inputs = extras.dup
209
+ return [fun, positional_inputs, positional_inputs.length, []] if trace_kwargs.empty?
210
+
211
+ keyword_entries = trace_kwargs.to_a
212
+ positional_count = positional_inputs.length
213
+ native_fun = lambda do |*all_args|
214
+ positional_args = all_args.first(positional_count)
215
+ keyword_args = {}
216
+ keyword_entries.each_with_index do |(name, _value), idx|
217
+ keyword_args[name] = all_args[positional_count + idx]
218
+ end
219
+ fun.call(*positional_args, **keyword_args)
220
+ end
221
+
222
+ keyword_values = keyword_entries.map { |(_name, value)| value }
223
+ keyword_names = keyword_entries.map { |(name, _value)| name.to_s }
224
+ [native_fun, positional_inputs + keyword_values, positional_count, keyword_names]
225
+ end
226
+ private_class_method :prepare_trace_invocation_without_native_kwargs
227
+
228
+ def inject_keyword_inputs!(payload, positional_count, keyword_names)
229
+ return payload if keyword_names.empty?
230
+ return payload unless payload.is_a?(Hash)
231
+
232
+ inputs = payload["inputs"]
233
+ return payload unless inputs.is_a?(Array)
234
+
235
+ keyword_inputs = []
236
+ keyword_names.each_with_index do |name, idx|
237
+ input_entry = inputs[positional_count + idx]
238
+ next unless input_entry.is_a?(Hash)
239
+
240
+ tensor = input_entry["name"]
241
+ next unless tensor.is_a?(String) && !tensor.empty?
242
+
243
+ keyword_inputs << { "name" => name, "tensor" => tensor }
244
+ end
245
+ payload["keyword_inputs"] = keyword_inputs
246
+ payload
247
+ end
248
+ private_class_method :inject_keyword_inputs!
249
+ end
250
+ end
data/lib/mlx/version.rb CHANGED
@@ -1,5 +1,5 @@
1
1
  # frozen_string_literal: true
2
2
 
3
3
  module MLX
4
- VERSION = "0.30.7.3"
4
+ VERSION = "0.30.7.6"
5
5
  end
@@ -0,0 +1,289 @@
1
+ # frozen_string_literal: true
2
+
3
+ require "json"
4
+ require "open3"
5
+ require "fileutils"
6
+
7
+ module MLX
8
+ module ONNX
9
+ module WebGPUHarness
10
+ module_function
11
+
12
+ def export_onnx_webgpu_harness(
13
+ target_dir,
14
+ payload_or_source,
15
+ opset: 18,
16
+ model_name: "mlx_graph",
17
+ execution_providers: %w[webgpu wasm],
18
+ benchmark_warmup_runs: 2,
19
+ benchmark_measure_runs: 10,
20
+ external_data: false,
21
+ external_data_size_threshold: 1024,
22
+ external_data_file: nil
23
+ )
24
+ output_dir = file_path(target_dir)
25
+ raise ArgumentError, "target_dir must not be empty" if output_dir.empty?
26
+
27
+ providers = normalize_web_execution_providers(execution_providers)
28
+ warmup_runs = normalize_non_negative_integer(
29
+ benchmark_warmup_runs,
30
+ "benchmark_warmup_runs"
31
+ )
32
+ measure_runs = normalize_positive_integer(
33
+ benchmark_measure_runs,
34
+ "benchmark_measure_runs"
35
+ )
36
+
37
+ FileUtils.mkdir_p(output_dir)
38
+ model_filename = "model.onnx"
39
+ model_path = File.join(output_dir, model_filename)
40
+ onnx_json = MLX::ONNX.graph_ir_to_onnx_json(
41
+ payload_or_source,
42
+ opset: opset,
43
+ model_name: model_name
44
+ )
45
+ MLX::ONNX.graph_ir_to_onnx(
46
+ model_path,
47
+ payload_or_source,
48
+ opset: opset,
49
+ model_name: model_name,
50
+ external_data: external_data,
51
+ external_data_size_threshold: external_data_size_threshold,
52
+ external_data_file: external_data_file
53
+ )
54
+
55
+ stub = JSON.parse(onnx_json)
56
+ input_specs = stub.fetch("graph").fetch("inputs")
57
+ input_examples = build_input_examples(input_specs)
58
+
59
+ manifest = {
60
+ "format" => "onnx_webgpu_harness_v1",
61
+ "model" => model_filename,
62
+ "execution_providers" => providers,
63
+ "benchmark" => {
64
+ "warmup_runs" => warmup_runs,
65
+ "measure_runs" => measure_runs
66
+ },
67
+ "inputs" => input_specs.map do |spec|
68
+ {
69
+ "name" => spec.fetch("name"),
70
+ "shape" => spec.fetch("shape"),
71
+ "dtype" => spec.fetch("dtype")
72
+ }
73
+ end
74
+ }
75
+ if external_data
76
+ manifest["external_data"] = [
77
+ external_data_file.nil? ? "model.data" : external_data_file.to_s
78
+ ]
79
+ end
80
+
81
+ File.binwrite(
82
+ File.join(output_dir, "harness.manifest.json"),
83
+ JSON.pretty_generate(manifest)
84
+ )
85
+ File.binwrite(
86
+ File.join(output_dir, "inputs.example.json"),
87
+ JSON.pretty_generate(input_examples)
88
+ )
89
+ copy_assets!(output_dir)
90
+
91
+ manifest
92
+ end
93
+
94
+ def smoke_test_onnx_webgpu_harness(
95
+ harness_dir,
96
+ timeout_seconds: 30,
97
+ mock_ort: false,
98
+ local_ort: true,
99
+ node_bin: ENV.fetch("NODE", "node")
100
+ )
101
+ directory = file_path(harness_dir)
102
+ raise ArgumentError, "harness_dir must not be empty" if directory.empty?
103
+
104
+ directory = File.expand_path(directory)
105
+ unless Dir.exist?(directory)
106
+ raise ArgumentError, "harness_dir does not exist: #{directory}"
107
+ end
108
+
109
+ timeout = normalize_positive_integer(timeout_seconds, "timeout_seconds")
110
+ mock = normalize_boolean(mock_ort, "mock_ort")
111
+ local = normalize_boolean(local_ort, "local_ort")
112
+ node = node_bin.to_s
113
+ raise ArgumentError, "node_bin must not be empty" if node.empty?
114
+
115
+ smoke_script = web_harness_smoke_script_path
116
+ unless File.file?(smoke_script)
117
+ raise RuntimeError, "missing web harness smoke script: #{smoke_script}"
118
+ end
119
+
120
+ argv = [
121
+ node,
122
+ smoke_script,
123
+ "--harness-dir",
124
+ directory,
125
+ "--timeout-seconds",
126
+ timeout.to_s
127
+ ]
128
+ argv << "--mock-ort" if mock
129
+ argv << (local ? "--local-ort" : "--no-local-ort")
130
+
131
+ stdout, stderr, status = Open3.capture3(*argv, chdir: web_root_dir)
132
+ unless status.success?
133
+ raise RuntimeError, <<~MSG
134
+ web harness smoke test failed: #{argv.join(" ")}
135
+ stdout:
136
+ #{stdout}
137
+ stderr:
138
+ #{stderr}
139
+ MSG
140
+ end
141
+
142
+ telemetry = begin
143
+ JSON.parse(stdout)
144
+ rescue JSON::ParserError => e
145
+ raise RuntimeError, <<~MSG
146
+ web harness smoke test produced invalid JSON: #{e.message}
147
+ stdout:
148
+ #{stdout}
149
+ stderr:
150
+ #{stderr}
151
+ MSG
152
+ end
153
+
154
+ unless telemetry.is_a?(Hash)
155
+ raise RuntimeError, "web harness smoke test produced non-object telemetry"
156
+ end
157
+ unless telemetry.fetch("format", nil) == "onnx_webgpu_telemetry_v1"
158
+ raise RuntimeError, "unexpected web harness telemetry format: #{telemetry.fetch('format', nil).inspect}"
159
+ end
160
+
161
+ telemetry
162
+ end
163
+
164
+ def file_path(file)
165
+ if file.respond_to?(:to_path)
166
+ file.to_path.to_s
167
+ else
168
+ file.to_s
169
+ end
170
+ end
171
+ private_class_method :file_path
172
+
173
+ def normalize_web_execution_providers(value)
174
+ providers = if value.is_a?(::Array)
175
+ value
176
+ else
177
+ [value]
178
+ end
179
+ providers = providers.map(&:to_s)
180
+ raise ArgumentError, "execution_providers must contain at least one provider" if providers.empty?
181
+
182
+ allowed = %w[webgpu wasm]
183
+ providers.each do |provider|
184
+ unless allowed.include?(provider)
185
+ raise ArgumentError, "execution_providers contains unsupported provider #{provider.inspect}"
186
+ end
187
+ end
188
+ providers.uniq
189
+ end
190
+ private_class_method :normalize_web_execution_providers
191
+
192
+ def normalize_non_negative_integer(value, label)
193
+ integer = begin
194
+ Integer(value)
195
+ rescue ArgumentError, TypeError
196
+ raise ArgumentError, "#{label} must be a non-negative Integer"
197
+ end
198
+ raise ArgumentError, "#{label} must be a non-negative Integer" if integer.negative?
199
+
200
+ integer
201
+ end
202
+ private_class_method :normalize_non_negative_integer
203
+
204
+ def normalize_positive_integer(value, label)
205
+ integer = begin
206
+ Integer(value)
207
+ rescue ArgumentError, TypeError
208
+ raise ArgumentError, "#{label} must be a positive Integer"
209
+ end
210
+ raise ArgumentError, "#{label} must be a positive Integer" unless integer.positive?
211
+
212
+ integer
213
+ end
214
+ private_class_method :normalize_positive_integer
215
+
216
+ def normalize_boolean(value, label)
217
+ unless value == true || value == false
218
+ raise ArgumentError, "#{label} must be true or false"
219
+ end
220
+
221
+ value
222
+ end
223
+ private_class_method :normalize_boolean
224
+
225
+ def build_input_examples(input_specs)
226
+ input_specs.each_with_object({}) do |spec, out|
227
+ out[spec.fetch("name")] = build_zero_tensor_values(
228
+ spec.fetch("shape"),
229
+ spec.fetch("dtype")
230
+ )
231
+ end
232
+ end
233
+ private_class_method :build_input_examples
234
+
235
+ def build_zero_tensor_values(shape, dtype)
236
+ if shape.empty?
237
+ zero_leaf_value_for_dtype(dtype)
238
+ else
239
+ ::Array.new(shape.first) { build_zero_tensor_values(shape[1..], dtype) }
240
+ end
241
+ end
242
+ private_class_method :build_zero_tensor_values
243
+
244
+ def zero_leaf_value_for_dtype(dtype)
245
+ if dtype == "bool" || dtype == "bool_"
246
+ false
247
+ elsif dtype == "complex64"
248
+ { "__mlx_complex__" => [0.0, 0.0] }
249
+ elsif dtype.start_with?("float") || dtype == "bfloat16"
250
+ 0.0
251
+ else
252
+ 0
253
+ end
254
+ end
255
+ private_class_method :zero_leaf_value_for_dtype
256
+
257
+ def copy_assets!(output_dir)
258
+ template_dir = web_harness_template_dir
259
+ unless Dir.exist?(template_dir)
260
+ raise RuntimeError, "missing web harness template directory: #{template_dir}"
261
+ end
262
+
263
+ %w[index.html harness.js].each do |file_name|
264
+ source = File.join(template_dir, file_name)
265
+ unless File.file?(source)
266
+ raise RuntimeError, "missing web harness template file: #{source}"
267
+ end
268
+ FileUtils.cp(source, File.join(output_dir, file_name))
269
+ end
270
+ end
271
+ private_class_method :copy_assets!
272
+
273
+ def web_harness_template_dir
274
+ File.expand_path("../../web/onnx_webgpu_harness", __dir__)
275
+ end
276
+ private_class_method :web_harness_template_dir
277
+
278
+ def web_harness_smoke_script_path
279
+ File.join(web_harness_template_dir, "browser_smoke.mjs")
280
+ end
281
+ private_class_method :web_harness_smoke_script_path
282
+
283
+ def web_root_dir
284
+ File.expand_path("../../web", __dir__)
285
+ end
286
+ private_class_method :web_root_dir
287
+ end
288
+ end
289
+ end
@@ -105,13 +105,6 @@ void CublasMatmulBase::init_base(
105
105
  CHECK_CUBLAS_ERROR(
106
106
  cublasLtMatmulDescCreate(&matmul_desc_, compute_type, scale_type));
107
107
 
108
- int32_t pointer_mode = CUBLASLT_POINTER_MODE_HOST;
109
- CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
110
- matmul_desc_,
111
- CUBLASLT_MATMUL_DESC_POINTER_MODE,
112
- &pointer_mode,
113
- sizeof(int32_t)));
114
-
115
108
  // In cublasLt matrices use column-major layout, while it is possible to use
116
109
  // the CUBLASLT_ORDER_ROW option to switch to row-major layout, the bias
117
110
  // epilogue does not work with the option. So instead we swap A and B to make
@@ -73,6 +73,14 @@ CublasGemm::CublasGemm(
73
73
  batch_count,
74
74
  a_batch_stride,
75
75
  b_batch_stride);
76
+
77
+ // alpha and beta are both host pointers
78
+ cublasLtPointerMode_t pointer_mode = CUBLASLT_POINTER_MODE_HOST;
79
+ CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
80
+ matmul_desc_,
81
+ CUBLASLT_MATMUL_DESC_POINTER_MODE,
82
+ &pointer_mode,
83
+ sizeof(pointer_mode)));
76
84
  }
77
85
 
78
86
  CublasGemm::CublasGemm(
@@ -215,8 +223,8 @@ void CublasGemm::execute(
215
223
  const void* a,
216
224
  const void* b,
217
225
  const void* c,
218
- float alpha /* = 1 */,
219
- float beta /* = 0 */) {
226
+ const float alpha /* = 1 */,
227
+ const float beta /* = 0 */) {
220
228
  const void* alpha_ptr = &alpha;
221
229
  const void* beta_ptr = &beta;
222
230
  complex64_t alpha_c, beta_c;