mlx 0.30.7.2 → 0.30.7.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (605)
  1. checksums.yaml +4 -4
  2. data/ext/mlx/extconf.rb +267 -8
  3. data/ext/mlx/native.cpp +112 -58
  4. data/ext/mlx-onnx/native.cpp +1402 -0
  5. data/ext/mlx-onnx/native.hpp +19 -0
  6. data/lib/mlx/core.rb +342 -117
  7. data/lib/mlx/distributed_utils/common.rb +1 -1
  8. data/lib/mlx/distributed_utils/config.rb +7 -4
  9. data/lib/mlx/distributed_utils/launch.rb +2 -0
  10. data/lib/mlx/dsl/attention.rb +132 -0
  11. data/lib/mlx/dsl/builder.rb +8 -0
  12. data/lib/mlx/dsl/config_schema.rb +133 -0
  13. data/lib/mlx/dsl/generate.rb +193 -0
  14. data/lib/mlx/dsl/kv_cache.rb +96 -0
  15. data/lib/mlx/dsl/masks.rb +32 -0
  16. data/lib/mlx/dsl/positions.rb +35 -0
  17. data/lib/mlx/dsl/run_stack.rb +68 -0
  18. data/lib/mlx/dsl/tensor.rb +126 -0
  19. data/lib/mlx/dsl/transformer_block.rb +113 -0
  20. data/lib/mlx/dsl/weight_map.rb +140 -0
  21. data/lib/mlx/dsl.rb +10 -0
  22. data/lib/mlx/nn/base.rb +4 -0
  23. data/lib/mlx/nn/layers/linear.rb +2 -3
  24. data/lib/mlx/onnx.rb +250 -0
  25. data/lib/mlx/version.rb +1 -1
  26. data/lib/mlx-onnx/webgpu_harness.rb +289 -0
  27. data/{mlx → submodules/mlx}/mlx/backend/cuda/cublas_utils.cpp +0 -7
  28. data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/cublas_gemm.cpp +10 -2
  29. data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/cublas_qqmm.cpp +97 -46
  30. data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/cublas_qqmm.h +25 -13
  31. data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/fp_quantize.cu +101 -38
  32. data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/no_qqmm_impl.cpp +1 -2
  33. data/submodules/mlx/mlx/backend/cuda/quantized/qqmm.cpp +193 -0
  34. data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/qqmm_impl.cpp +15 -8
  35. data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/qqmm_impl.h +14 -3
  36. data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/qqmm_utils.cu +36 -0
  37. data/submodules/mlx/mlx/backend/cuda/quantized/qqmm_utils.h +62 -0
  38. data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/quantized.cpp +12 -3
  39. data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/quantized.h +4 -0
  40. data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/quantized_utils.cuh +1 -1
  41. data/{mlx → submodules/mlx}/mlx/backend/metal/device.cpp +4 -0
  42. data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/conv.metal +3 -2
  43. data/{mlx → submodules/mlx}/mlx/export.cpp +21 -6
  44. data/{mlx → submodules/mlx}/mlx/ops.cpp +144 -13
  45. data/{mlx → submodules/mlx}/mlx/ops.h +12 -2
  46. data/{mlx → submodules/mlx}/mlx/primitives.cpp +22 -5
  47. data/{mlx → submodules/mlx}/mlx/scheduler.cpp +4 -0
  48. data/{mlx → submodules/mlx}/mlx/scheduler.h +3 -0
  49. data/{mlx → submodules/mlx}/mlx/stream.h +5 -0
  50. data/submodules/mlx-onnx/CMakeLists.txt +159 -0
  51. data/submodules/mlx-onnx/LICENSE +21 -0
  52. data/submodules/mlx-onnx/include/mlx/ir.hpp +88 -0
  53. data/submodules/mlx-onnx/src/api.cpp +81 -0
  54. data/submodules/mlx-onnx/src/compat.cpp +111 -0
  55. data/submodules/mlx-onnx/src/detail.hpp +69 -0
  56. data/submodules/mlx-onnx/src/export.cpp +653 -0
  57. data/submodules/mlx-onnx/src/io.cpp +61 -0
  58. data/submodules/mlx-onnx/src/json.hpp +25 -0
  59. data/submodules/mlx-onnx/src/lowering.cpp +6346 -0
  60. data/submodules/mlx-onnx/src/mappings.cpp +201 -0
  61. data/submodules/mlx-onnx/src/mappings.hpp +16 -0
  62. data/submodules/mlx-onnx/src/onnx.cpp +1029 -0
  63. data/submodules/mlx-onnx/src/shared.cpp +206 -0
  64. metadata +665 -567
  65. data/mlx/mlx/backend/cuda/quantized/qqmm.cpp +0 -158
  66. data/mlx/mlx/backend/cuda/quantized/qqmm_utils.h +0 -30
  67. /data/{mlx → submodules/mlx}/CMakeLists.txt +0 -0
  68. /data/{mlx → submodules/mlx}/cmake/FindCUDNN.cmake +0 -0
  69. /data/{mlx → submodules/mlx}/cmake/FindNCCL.cmake +0 -0
  70. /data/{mlx → submodules/mlx}/cmake/Findnvpl.cmake +0 -0
  71. /data/{mlx → submodules/mlx}/cmake/extension.cmake +0 -0
  72. /data/{mlx → submodules/mlx}/mlx/3rdparty/.clang-format +0 -0
  73. /data/{mlx → submodules/mlx}/mlx/3rdparty/pocketfft.h +0 -0
  74. /data/{mlx → submodules/mlx}/mlx/CMakeLists.txt +0 -0
  75. /data/{mlx → submodules/mlx}/mlx/allocator.h +0 -0
  76. /data/{mlx → submodules/mlx}/mlx/api.h +0 -0
  77. /data/{mlx → submodules/mlx}/mlx/array.cpp +0 -0
  78. /data/{mlx → submodules/mlx}/mlx/array.h +0 -0
  79. /data/{mlx → submodules/mlx}/mlx/backend/common/CMakeLists.txt +0 -0
  80. /data/{mlx → submodules/mlx}/mlx/backend/common/binary.h +0 -0
  81. /data/{mlx → submodules/mlx}/mlx/backend/common/broadcasting.cpp +0 -0
  82. /data/{mlx → submodules/mlx}/mlx/backend/common/broadcasting.h +0 -0
  83. /data/{mlx → submodules/mlx}/mlx/backend/common/buffer_cache.h +0 -0
  84. /data/{mlx → submodules/mlx}/mlx/backend/common/common.cpp +0 -0
  85. /data/{mlx → submodules/mlx}/mlx/backend/common/compiled.cpp +0 -0
  86. /data/{mlx → submodules/mlx}/mlx/backend/common/compiled.h +0 -0
  87. /data/{mlx → submodules/mlx}/mlx/backend/common/copy.h +0 -0
  88. /data/{mlx → submodules/mlx}/mlx/backend/common/hadamard.h +0 -0
  89. /data/{mlx → submodules/mlx}/mlx/backend/common/load.cpp +0 -0
  90. /data/{mlx → submodules/mlx}/mlx/backend/common/matmul.h +0 -0
  91. /data/{mlx → submodules/mlx}/mlx/backend/common/reduce.cpp +0 -0
  92. /data/{mlx → submodules/mlx}/mlx/backend/common/reduce.h +0 -0
  93. /data/{mlx → submodules/mlx}/mlx/backend/common/slicing.cpp +0 -0
  94. /data/{mlx → submodules/mlx}/mlx/backend/common/slicing.h +0 -0
  95. /data/{mlx → submodules/mlx}/mlx/backend/common/ternary.h +0 -0
  96. /data/{mlx → submodules/mlx}/mlx/backend/common/unary.h +0 -0
  97. /data/{mlx → submodules/mlx}/mlx/backend/common/utils.cpp +0 -0
  98. /data/{mlx → submodules/mlx}/mlx/backend/common/utils.h +0 -0
  99. /data/{mlx → submodules/mlx}/mlx/backend/cpu/CMakeLists.txt +0 -0
  100. /data/{mlx → submodules/mlx}/mlx/backend/cpu/arange.h +0 -0
  101. /data/{mlx → submodules/mlx}/mlx/backend/cpu/arg_reduce.cpp +0 -0
  102. /data/{mlx → submodules/mlx}/mlx/backend/cpu/binary.cpp +0 -0
  103. /data/{mlx → submodules/mlx}/mlx/backend/cpu/binary.h +0 -0
  104. /data/{mlx → submodules/mlx}/mlx/backend/cpu/binary_ops.h +0 -0
  105. /data/{mlx → submodules/mlx}/mlx/backend/cpu/binary_two.h +0 -0
  106. /data/{mlx → submodules/mlx}/mlx/backend/cpu/cholesky.cpp +0 -0
  107. /data/{mlx → submodules/mlx}/mlx/backend/cpu/compiled.cpp +0 -0
  108. /data/{mlx → submodules/mlx}/mlx/backend/cpu/compiled_preamble.h +0 -0
  109. /data/{mlx → submodules/mlx}/mlx/backend/cpu/conv.cpp +0 -0
  110. /data/{mlx → submodules/mlx}/mlx/backend/cpu/copy.cpp +0 -0
  111. /data/{mlx → submodules/mlx}/mlx/backend/cpu/copy.h +0 -0
  112. /data/{mlx → submodules/mlx}/mlx/backend/cpu/device_info.cpp +0 -0
  113. /data/{mlx → submodules/mlx}/mlx/backend/cpu/device_info.h +0 -0
  114. /data/{mlx → submodules/mlx}/mlx/backend/cpu/distributed.cpp +0 -0
  115. /data/{mlx → submodules/mlx}/mlx/backend/cpu/eig.cpp +0 -0
  116. /data/{mlx → submodules/mlx}/mlx/backend/cpu/eigh.cpp +0 -0
  117. /data/{mlx → submodules/mlx}/mlx/backend/cpu/encoder.cpp +0 -0
  118. /data/{mlx → submodules/mlx}/mlx/backend/cpu/encoder.h +0 -0
  119. /data/{mlx → submodules/mlx}/mlx/backend/cpu/eval.cpp +0 -0
  120. /data/{mlx → submodules/mlx}/mlx/backend/cpu/eval.h +0 -0
  121. /data/{mlx → submodules/mlx}/mlx/backend/cpu/fft.cpp +0 -0
  122. /data/{mlx → submodules/mlx}/mlx/backend/cpu/gemm.h +0 -0
  123. /data/{mlx → submodules/mlx}/mlx/backend/cpu/gemms/bnns.cpp +0 -0
  124. /data/{mlx → submodules/mlx}/mlx/backend/cpu/gemms/cblas.cpp +0 -0
  125. /data/{mlx → submodules/mlx}/mlx/backend/cpu/gemms/simd_bf16.cpp +0 -0
  126. /data/{mlx → submodules/mlx}/mlx/backend/cpu/gemms/simd_fp16.cpp +0 -0
  127. /data/{mlx → submodules/mlx}/mlx/backend/cpu/gemms/simd_gemm.h +0 -0
  128. /data/{mlx → submodules/mlx}/mlx/backend/cpu/hadamard.cpp +0 -0
  129. /data/{mlx → submodules/mlx}/mlx/backend/cpu/indexing.cpp +0 -0
  130. /data/{mlx → submodules/mlx}/mlx/backend/cpu/inverse.cpp +0 -0
  131. /data/{mlx → submodules/mlx}/mlx/backend/cpu/jit_compiler.cpp +0 -0
  132. /data/{mlx → submodules/mlx}/mlx/backend/cpu/jit_compiler.h +0 -0
  133. /data/{mlx → submodules/mlx}/mlx/backend/cpu/lapack.h +0 -0
  134. /data/{mlx → submodules/mlx}/mlx/backend/cpu/logsumexp.cpp +0 -0
  135. /data/{mlx → submodules/mlx}/mlx/backend/cpu/luf.cpp +0 -0
  136. /data/{mlx → submodules/mlx}/mlx/backend/cpu/make_compiled_preamble.ps1 +0 -0
  137. /data/{mlx → submodules/mlx}/mlx/backend/cpu/make_compiled_preamble.sh +0 -0
  138. /data/{mlx → submodules/mlx}/mlx/backend/cpu/masked_mm.cpp +0 -0
  139. /data/{mlx → submodules/mlx}/mlx/backend/cpu/matmul.cpp +0 -0
  140. /data/{mlx → submodules/mlx}/mlx/backend/cpu/primitives.cpp +0 -0
  141. /data/{mlx → submodules/mlx}/mlx/backend/cpu/qrf.cpp +0 -0
  142. /data/{mlx → submodules/mlx}/mlx/backend/cpu/quantized.cpp +0 -0
  143. /data/{mlx → submodules/mlx}/mlx/backend/cpu/reduce.cpp +0 -0
  144. /data/{mlx → submodules/mlx}/mlx/backend/cpu/scan.cpp +0 -0
  145. /data/{mlx → submodules/mlx}/mlx/backend/cpu/select.cpp +0 -0
  146. /data/{mlx → submodules/mlx}/mlx/backend/cpu/simd/accelerate_fp16_simd.h +0 -0
  147. /data/{mlx → submodules/mlx}/mlx/backend/cpu/simd/accelerate_simd.h +0 -0
  148. /data/{mlx → submodules/mlx}/mlx/backend/cpu/simd/base_simd.h +0 -0
  149. /data/{mlx → submodules/mlx}/mlx/backend/cpu/simd/math.h +0 -0
  150. /data/{mlx → submodules/mlx}/mlx/backend/cpu/simd/neon_fp16_simd.h +0 -0
  151. /data/{mlx → submodules/mlx}/mlx/backend/cpu/simd/simd.h +0 -0
  152. /data/{mlx → submodules/mlx}/mlx/backend/cpu/simd/type.h +0 -0
  153. /data/{mlx → submodules/mlx}/mlx/backend/cpu/slicing.h +0 -0
  154. /data/{mlx → submodules/mlx}/mlx/backend/cpu/softmax.cpp +0 -0
  155. /data/{mlx → submodules/mlx}/mlx/backend/cpu/sort.cpp +0 -0
  156. /data/{mlx → submodules/mlx}/mlx/backend/cpu/svd.cpp +0 -0
  157. /data/{mlx → submodules/mlx}/mlx/backend/cpu/ternary.h +0 -0
  158. /data/{mlx → submodules/mlx}/mlx/backend/cpu/threefry.cpp +0 -0
  159. /data/{mlx → submodules/mlx}/mlx/backend/cpu/threefry.h +0 -0
  160. /data/{mlx → submodules/mlx}/mlx/backend/cpu/unary.cpp +0 -0
  161. /data/{mlx → submodules/mlx}/mlx/backend/cpu/unary.h +0 -0
  162. /data/{mlx → submodules/mlx}/mlx/backend/cpu/unary_ops.h +0 -0
  163. /data/{mlx → submodules/mlx}/mlx/backend/cuda/CMakeLists.txt +0 -0
  164. /data/{mlx → submodules/mlx}/mlx/backend/cuda/allocator.cpp +0 -0
  165. /data/{mlx → submodules/mlx}/mlx/backend/cuda/allocator.h +0 -0
  166. /data/{mlx → submodules/mlx}/mlx/backend/cuda/arange.cu +0 -0
  167. /data/{mlx → submodules/mlx}/mlx/backend/cuda/arg_reduce.cu +0 -0
  168. /data/{mlx → submodules/mlx}/mlx/backend/cuda/bin2h.cmake +0 -0
  169. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/CMakeLists.txt +0 -0
  170. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/add.cu +0 -0
  171. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/arctan2.cu +0 -0
  172. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/binary.cuh +0 -0
  173. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/bitwise_binary.cu +0 -0
  174. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/divide.cu +0 -0
  175. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/equal.cu +0 -0
  176. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/greater.cu +0 -0
  177. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/greater_equal.cu +0 -0
  178. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/less.cu +0 -0
  179. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/less_equal.cu +0 -0
  180. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/log_add_exp.cu +0 -0
  181. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/logical_and.cu +0 -0
  182. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/logical_or.cu +0 -0
  183. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/maximum.cu +0 -0
  184. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/minimum.cu +0 -0
  185. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/multiply.cu +0 -0
  186. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/not_equal.cu +0 -0
  187. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/power.cu +0 -0
  188. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/remainder.cu +0 -0
  189. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/subtract.cu +0 -0
  190. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary_two.cu +0 -0
  191. /data/{mlx → submodules/mlx}/mlx/backend/cuda/compiled.cpp +0 -0
  192. /data/{mlx → submodules/mlx}/mlx/backend/cuda/conv/conv.h +0 -0
  193. /data/{mlx → submodules/mlx}/mlx/backend/cuda/conv/gemm_conv.cu +0 -0
  194. /data/{mlx → submodules/mlx}/mlx/backend/cuda/conv/gemm_grouped_conv.cu +0 -0
  195. /data/{mlx → submodules/mlx}/mlx/backend/cuda/conv.cpp +0 -0
  196. /data/{mlx → submodules/mlx}/mlx/backend/cuda/copy/copy.cuh +0 -0
  197. /data/{mlx → submodules/mlx}/mlx/backend/cuda/copy/copy_contiguous.cu +0 -0
  198. /data/{mlx → submodules/mlx}/mlx/backend/cuda/copy/copy_general.cu +0 -0
  199. /data/{mlx → submodules/mlx}/mlx/backend/cuda/copy/copy_general_dynamic.cu +0 -0
  200. /data/{mlx → submodules/mlx}/mlx/backend/cuda/copy/copy_general_input.cu +0 -0
  201. /data/{mlx → submodules/mlx}/mlx/backend/cuda/copy.cu +0 -0
  202. /data/{mlx → submodules/mlx}/mlx/backend/cuda/cublas_utils.h +0 -0
  203. /data/{mlx → submodules/mlx}/mlx/backend/cuda/cuda.h +0 -0
  204. /data/{mlx → submodules/mlx}/mlx/backend/cuda/cuda_utils.h +0 -0
  205. /data/{mlx → submodules/mlx}/mlx/backend/cuda/cudnn_utils.cpp +0 -0
  206. /data/{mlx → submodules/mlx}/mlx/backend/cuda/cudnn_utils.h +0 -0
  207. /data/{mlx → submodules/mlx}/mlx/backend/cuda/custom_kernel.cpp +0 -0
  208. /data/{mlx → submodules/mlx}/mlx/backend/cuda/cutlass_utils.cuh +0 -0
  209. /data/{mlx → submodules/mlx}/mlx/backend/cuda/delayload.cpp +0 -0
  210. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/atomic_ops.cuh +0 -0
  211. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/binary_ops.cuh +0 -0
  212. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/cast_op.cuh +0 -0
  213. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/complex.cuh +0 -0
  214. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/config.h +0 -0
  215. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/fp16_math.cuh +0 -0
  216. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/gather.cuh +0 -0
  217. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/gather_axis.cuh +0 -0
  218. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/indexing.cuh +0 -0
  219. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/scatter.cuh +0 -0
  220. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/scatter_axis.cuh +0 -0
  221. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/scatter_ops.cuh +0 -0
  222. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/ternary_ops.cuh +0 -0
  223. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/unary_ops.cuh +0 -0
  224. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/utils.cuh +0 -0
  225. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device.cpp +0 -0
  226. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device.h +0 -0
  227. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device_info.cpp +0 -0
  228. /data/{mlx → submodules/mlx}/mlx/backend/cuda/distributed.cu +0 -0
  229. /data/{mlx → submodules/mlx}/mlx/backend/cuda/eval.cpp +0 -0
  230. /data/{mlx → submodules/mlx}/mlx/backend/cuda/event.cu +0 -0
  231. /data/{mlx → submodules/mlx}/mlx/backend/cuda/event.h +0 -0
  232. /data/{mlx → submodules/mlx}/mlx/backend/cuda/fence.cpp +0 -0
  233. /data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/cublas_gemm.h +0 -0
  234. /data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/cublas_gemm_batched_12_0.cpp +0 -0
  235. /data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/cublas_gemm_batched_12_9.cu +0 -0
  236. /data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/gemv.cu +0 -0
  237. /data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/gemv.h +0 -0
  238. /data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/grouped_gemm.h +0 -0
  239. /data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/grouped_gemm_unaligned.cu +0 -0
  240. /data/{mlx → submodules/mlx}/mlx/backend/cuda/indexing.cpp +0 -0
  241. /data/{mlx → submodules/mlx}/mlx/backend/cuda/jit_module.cpp +0 -0
  242. /data/{mlx → submodules/mlx}/mlx/backend/cuda/jit_module.h +0 -0
  243. /data/{mlx → submodules/mlx}/mlx/backend/cuda/kernel_utils.cu +0 -0
  244. /data/{mlx → submodules/mlx}/mlx/backend/cuda/kernel_utils.cuh +0 -0
  245. /data/{mlx → submodules/mlx}/mlx/backend/cuda/layer_norm.cu +0 -0
  246. /data/{mlx → submodules/mlx}/mlx/backend/cuda/load.cpp +0 -0
  247. /data/{mlx → submodules/mlx}/mlx/backend/cuda/logsumexp.cu +0 -0
  248. /data/{mlx → submodules/mlx}/mlx/backend/cuda/lru_cache.h +0 -0
  249. /data/{mlx → submodules/mlx}/mlx/backend/cuda/matmul.cpp +0 -0
  250. /data/{mlx → submodules/mlx}/mlx/backend/cuda/no_cuda.cpp +0 -0
  251. /data/{mlx → submodules/mlx}/mlx/backend/cuda/primitives.cpp +0 -0
  252. /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/affine_quantize.cu +0 -0
  253. /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/convert_fp8.cu +0 -0
  254. /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/cuda_fp4.h +0 -0
  255. /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/mxfp8_quantize.cuh +0 -0
  256. /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/nvfp4_quantize.cuh +0 -0
  257. /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/qmv.cu +0 -0
  258. /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/qmv.h +0 -0
  259. /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/quantized_utils.h +0 -0
  260. /data/{mlx → submodules/mlx}/mlx/backend/cuda/random.cu +0 -0
  261. /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce/all_reduce.cu +0 -0
  262. /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce/col_reduce.cu +0 -0
  263. /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce/init_reduce.cu +0 -0
  264. /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce/reduce.cuh +0 -0
  265. /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce/reduce_ops.cuh +0 -0
  266. /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce/reduce_utils.cuh +0 -0
  267. /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce/row_reduce.cu +0 -0
  268. /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce.cu +0 -0
  269. /data/{mlx → submodules/mlx}/mlx/backend/cuda/rms_norm.cu +0 -0
  270. /data/{mlx → submodules/mlx}/mlx/backend/cuda/rope.cu +0 -0
  271. /data/{mlx → submodules/mlx}/mlx/backend/cuda/scaled_dot_product_attention.cpp +0 -0
  272. /data/{mlx → submodules/mlx}/mlx/backend/cuda/scaled_dot_product_attention.cu +0 -0
  273. /data/{mlx → submodules/mlx}/mlx/backend/cuda/scan.cu +0 -0
  274. /data/{mlx → submodules/mlx}/mlx/backend/cuda/slicing.cpp +0 -0
  275. /data/{mlx → submodules/mlx}/mlx/backend/cuda/softmax.cu +0 -0
  276. /data/{mlx → submodules/mlx}/mlx/backend/cuda/sort.cu +0 -0
  277. /data/{mlx → submodules/mlx}/mlx/backend/cuda/steel/defines.cuh +0 -0
  278. /data/{mlx → submodules/mlx}/mlx/backend/cuda/steel/gemm.cuh +0 -0
  279. /data/{mlx → submodules/mlx}/mlx/backend/cuda/steel/mma.cuh +0 -0
  280. /data/{mlx → submodules/mlx}/mlx/backend/cuda/steel/tiles.cuh +0 -0
  281. /data/{mlx → submodules/mlx}/mlx/backend/cuda/steel/utils.cuh +0 -0
  282. /data/{mlx → submodules/mlx}/mlx/backend/cuda/ternary.cu +0 -0
  283. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/CMakeLists.txt +0 -0
  284. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/abs.cu +0 -0
  285. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/arccos.cu +0 -0
  286. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/arccosh.cu +0 -0
  287. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/arcsin.cu +0 -0
  288. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/arcsinh.cu +0 -0
  289. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/arctan.cu +0 -0
  290. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/arctanh.cu +0 -0
  291. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/bitwise_invert.cu +0 -0
  292. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/ceil.cu +0 -0
  293. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/conjugate.cu +0 -0
  294. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/cos.cu +0 -0
  295. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/cosh.cu +0 -0
  296. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/erf.cu +0 -0
  297. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/erf_inv.cu +0 -0
  298. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/exp.cu +0 -0
  299. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/expm1.cu +0 -0
  300. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/floor.cu +0 -0
  301. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/imag.cu +0 -0
  302. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/log.cu +0 -0
  303. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/log1p.cu +0 -0
  304. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/logical_not.cu +0 -0
  305. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/negative.cu +0 -0
  306. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/real.cu +0 -0
  307. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/round.cu +0 -0
  308. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/sigmoid.cu +0 -0
  309. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/sign.cu +0 -0
  310. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/sin.cu +0 -0
  311. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/sinh.cu +0 -0
  312. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/sqrt.cu +0 -0
  313. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/square.cu +0 -0
  314. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/tan.cu +0 -0
  315. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/tanh.cu +0 -0
  316. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/unary.cuh +0 -0
  317. /data/{mlx → submodules/mlx}/mlx/backend/cuda/utils.cpp +0 -0
  318. /data/{mlx → submodules/mlx}/mlx/backend/cuda/utils.h +0 -0
  319. /data/{mlx → submodules/mlx}/mlx/backend/cuda/vector_types.cuh +0 -0
  320. /data/{mlx → submodules/mlx}/mlx/backend/cuda/worker.cpp +0 -0
  321. /data/{mlx → submodules/mlx}/mlx/backend/cuda/worker.h +0 -0
  322. /data/{mlx → submodules/mlx}/mlx/backend/gpu/CMakeLists.txt +0 -0
  323. /data/{mlx → submodules/mlx}/mlx/backend/gpu/copy.cpp +0 -0
  324. /data/{mlx → submodules/mlx}/mlx/backend/gpu/copy.h +0 -0
  325. /data/{mlx → submodules/mlx}/mlx/backend/gpu/device_info.h +0 -0
  326. /data/{mlx → submodules/mlx}/mlx/backend/gpu/eval.h +0 -0
  327. /data/{mlx → submodules/mlx}/mlx/backend/gpu/primitives.cpp +0 -0
  328. /data/{mlx → submodules/mlx}/mlx/backend/gpu/slicing.cpp +0 -0
  329. /data/{mlx → submodules/mlx}/mlx/backend/gpu/slicing.h +0 -0
  330. /data/{mlx → submodules/mlx}/mlx/backend/metal/CMakeLists.txt +0 -0
  331. /data/{mlx → submodules/mlx}/mlx/backend/metal/allocator.cpp +0 -0
  332. /data/{mlx → submodules/mlx}/mlx/backend/metal/allocator.h +0 -0
  333. /data/{mlx → submodules/mlx}/mlx/backend/metal/binary.cpp +0 -0
  334. /data/{mlx → submodules/mlx}/mlx/backend/metal/binary.h +0 -0
  335. /data/{mlx → submodules/mlx}/mlx/backend/metal/compiled.cpp +0 -0
  336. /data/{mlx → submodules/mlx}/mlx/backend/metal/conv.cpp +0 -0
  337. /data/{mlx → submodules/mlx}/mlx/backend/metal/copy.cpp +0 -0
  338. /data/{mlx → submodules/mlx}/mlx/backend/metal/custom_kernel.cpp +0 -0
  339. /data/{mlx → submodules/mlx}/mlx/backend/metal/device.h +0 -0
  340. /data/{mlx → submodules/mlx}/mlx/backend/metal/device_info.cpp +0 -0
  341. /data/{mlx → submodules/mlx}/mlx/backend/metal/distributed.cpp +0 -0
  342. /data/{mlx → submodules/mlx}/mlx/backend/metal/eval.cpp +0 -0
  343. /data/{mlx → submodules/mlx}/mlx/backend/metal/event.cpp +0 -0
  344. /data/{mlx → submodules/mlx}/mlx/backend/metal/fence.cpp +0 -0
  345. /data/{mlx → submodules/mlx}/mlx/backend/metal/fft.cpp +0 -0
  346. /data/{mlx → submodules/mlx}/mlx/backend/metal/hadamard.cpp +0 -0
  347. /data/{mlx → submodules/mlx}/mlx/backend/metal/indexing.cpp +0 -0
  348. /data/{mlx → submodules/mlx}/mlx/backend/metal/jit/includes.h +0 -0
  349. /data/{mlx → submodules/mlx}/mlx/backend/metal/jit/indexing.h +0 -0
  350. /data/{mlx → submodules/mlx}/mlx/backend/metal/jit_kernels.cpp +0 -0
  351. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/CMakeLists.txt +0 -0
  352. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/arange.h +0 -0
  353. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/arange.metal +0 -0
  354. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/arg_reduce.metal +0 -0
  355. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/atomic.h +0 -0
  356. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/bf16.h +0 -0
  357. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/bf16_math.h +0 -0
  358. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/binary.h +0 -0
  359. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/binary.metal +0 -0
  360. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/binary_ops.h +0 -0
  361. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/binary_two.h +0 -0
  362. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/binary_two.metal +0 -0
  363. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/cexpf.h +0 -0
  364. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/complex.h +0 -0
  365. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/copy.h +0 -0
  366. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/copy.metal +0 -0
  367. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/defines.h +0 -0
  368. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/erf.h +0 -0
  369. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/expm1f.h +0 -0
  370. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fence.metal +0 -0
  371. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fft/radix.h +0 -0
  372. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fft/readwrite.h +0 -0
  373. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fft.h +0 -0
  374. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fft.metal +0 -0
  375. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fp4.h +0 -0
  376. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fp8.h +0 -0
  377. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fp_quantized.h +0 -0
  378. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fp_quantized.metal +0 -0
  379. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fp_quantized_nax.h +0 -0
  380. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fp_quantized_nax.metal +0 -0
  381. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/gemv.metal +0 -0
  382. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/gemv_masked.h +0 -0
  383. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/gemv_masked.metal +0 -0
  384. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/hadamard.h +0 -0
  385. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/indexing/gather.h +0 -0
  386. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/indexing/gather_axis.h +0 -0
  387. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/indexing/gather_front.h +0 -0
  388. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/indexing/indexing.h +0 -0
  389. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/indexing/masked_scatter.h +0 -0
  390. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/indexing/scatter.h +0 -0
  391. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/indexing/scatter_axis.h +0 -0
  392. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/layer_norm.metal +0 -0
  393. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/logging.h +0 -0
  394. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/logsumexp.h +0 -0
  395. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/logsumexp.metal +0 -0
  396. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/quantized.h +0 -0
  397. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/quantized.metal +0 -0
  398. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/quantized_nax.h +0 -0
  399. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/quantized_nax.metal +0 -0
  400. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/quantized_utils.h +0 -0
  401. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/random.metal +0 -0
  402. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduce.h +0 -0
  403. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduce.metal +0 -0
  404. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduce_utils.h +0 -0
  405. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduction/ops.h +0 -0
  406. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduction/reduce_all.h +0 -0
  407. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduction/reduce_col.h +0 -0
  408. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduction/reduce_init.h +0 -0
  409. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduction/reduce_row.h +0 -0
  410. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/rms_norm.metal +0 -0
  411. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/rope.metal +0 -0
  412. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/scaled_dot_product_attention.metal +0 -0
  413. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/scan.h +0 -0
  414. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/scan.metal +0 -0
  415. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/sdpa_vector.h +0 -0
  416. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/softmax.h +0 -0
  417. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/softmax.metal +0 -0
  418. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/sort.h +0 -0
  419. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/sort.metal +0 -0
  420. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/attn.h +0 -0
  421. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +0 -0
  422. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal +0 -0
  423. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +0 -0
  424. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.metal +0 -0
  425. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/loader.h +0 -0
  426. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/mma.h +0 -0
  427. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/nax.h +0 -0
  428. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/params.h +0 -0
  429. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/transforms.h +0 -0
  430. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/conv.h +0 -0
  431. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +0 -0
  432. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.metal +0 -0
  433. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +0 -0
  434. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.metal +0 -0
  435. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/loader.h +0 -0
  436. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +0 -0
  437. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +0 -0
  438. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +0 -0
  439. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/params.h +0 -0
  440. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/defines.h +0 -0
  441. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/gemm.h +0 -0
  442. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +0 -0
  443. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +0 -0
  444. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.metal +0 -0
  445. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +0 -0
  446. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.metal +0 -0
  447. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +0 -0
  448. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.metal +0 -0
  449. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +0 -0
  450. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.metal +0 -0
  451. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +0 -0
  452. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.metal +0 -0
  453. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +0 -0
  454. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.metal +0 -0
  455. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +0 -0
  456. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.metal +0 -0
  457. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.h +0 -0
  458. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.metal +0 -0
  459. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/loader.h +0 -0
  460. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/mma.h +0 -0
  461. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/nax.h +0 -0
  462. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/params.h +0 -0
  463. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/transforms.h +0 -0
  464. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/utils/integral_constant.h +0 -0
  465. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/utils/type_traits.h +0 -0
  466. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/utils.h +0 -0
  467. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/ternary.h +0 -0
  468. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/ternary.metal +0 -0
  469. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/ternary_ops.h +0 -0
  470. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/unary.h +0 -0
  471. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/unary.metal +0 -0
  472. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/unary_ops.h +0 -0
  473. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/utils.h +0 -0
  474. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels.h +0 -0
  475. /data/{mlx → submodules/mlx}/mlx/backend/metal/logsumexp.cpp +0 -0
  476. /data/{mlx → submodules/mlx}/mlx/backend/metal/make_compiled_preamble.sh +0 -0
  477. /data/{mlx → submodules/mlx}/mlx/backend/metal/matmul.cpp +0 -0
  478. /data/{mlx → submodules/mlx}/mlx/backend/metal/matmul.h +0 -0
  479. /data/{mlx → submodules/mlx}/mlx/backend/metal/metal.cpp +0 -0
  480. /data/{mlx → submodules/mlx}/mlx/backend/metal/metal.h +0 -0
  481. /data/{mlx → submodules/mlx}/mlx/backend/metal/no_metal.cpp +0 -0
  482. /data/{mlx → submodules/mlx}/mlx/backend/metal/nojit_kernels.cpp +0 -0
  483. /data/{mlx → submodules/mlx}/mlx/backend/metal/normalization.cpp +0 -0
  484. /data/{mlx → submodules/mlx}/mlx/backend/metal/primitives.cpp +0 -0
  485. /data/{mlx → submodules/mlx}/mlx/backend/metal/quantized.cpp +0 -0
  486. /data/{mlx → submodules/mlx}/mlx/backend/metal/reduce.cpp +0 -0
  487. /data/{mlx → submodules/mlx}/mlx/backend/metal/reduce.h +0 -0
  488. /data/{mlx → submodules/mlx}/mlx/backend/metal/resident.cpp +0 -0
  489. /data/{mlx → submodules/mlx}/mlx/backend/metal/resident.h +0 -0
  490. /data/{mlx → submodules/mlx}/mlx/backend/metal/rope.cpp +0 -0
  491. /data/{mlx → submodules/mlx}/mlx/backend/metal/scaled_dot_product_attention.cpp +0 -0
  492. /data/{mlx → submodules/mlx}/mlx/backend/metal/scan.cpp +0 -0
  493. /data/{mlx → submodules/mlx}/mlx/backend/metal/scan.h +0 -0
  494. /data/{mlx → submodules/mlx}/mlx/backend/metal/slicing.cpp +0 -0
  495. /data/{mlx → submodules/mlx}/mlx/backend/metal/softmax.cpp +0 -0
  496. /data/{mlx → submodules/mlx}/mlx/backend/metal/sort.cpp +0 -0
  497. /data/{mlx → submodules/mlx}/mlx/backend/metal/ternary.cpp +0 -0
  498. /data/{mlx → submodules/mlx}/mlx/backend/metal/ternary.h +0 -0
  499. /data/{mlx → submodules/mlx}/mlx/backend/metal/unary.cpp +0 -0
  500. /data/{mlx → submodules/mlx}/mlx/backend/metal/unary.h +0 -0
  501. /data/{mlx → submodules/mlx}/mlx/backend/metal/utils.cpp +0 -0
  502. /data/{mlx → submodules/mlx}/mlx/backend/metal/utils.h +0 -0
  503. /data/{mlx → submodules/mlx}/mlx/backend/no_cpu/CMakeLists.txt +0 -0
  504. /data/{mlx → submodules/mlx}/mlx/backend/no_cpu/compiled.cpp +0 -0
  505. /data/{mlx → submodules/mlx}/mlx/backend/no_cpu/device_info.cpp +0 -0
  506. /data/{mlx → submodules/mlx}/mlx/backend/no_cpu/primitives.cpp +0 -0
  507. /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/CMakeLists.txt +0 -0
  508. /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/allocator.cpp +0 -0
  509. /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/apple_memory.h +0 -0
  510. /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/device_info.cpp +0 -0
  511. /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/eval.cpp +0 -0
  512. /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/event.cpp +0 -0
  513. /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/fence.cpp +0 -0
  514. /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/linux_memory.h +0 -0
  515. /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/primitives.cpp +0 -0
  516. /data/{mlx → submodules/mlx}/mlx/compile.cpp +0 -0
  517. /data/{mlx → submodules/mlx}/mlx/compile.h +0 -0
  518. /data/{mlx → submodules/mlx}/mlx/compile_impl.h +0 -0
  519. /data/{mlx → submodules/mlx}/mlx/device.cpp +0 -0
  520. /data/{mlx → submodules/mlx}/mlx/device.h +0 -0
  521. /data/{mlx → submodules/mlx}/mlx/distributed/CMakeLists.txt +0 -0
  522. /data/{mlx → submodules/mlx}/mlx/distributed/distributed.cpp +0 -0
  523. /data/{mlx → submodules/mlx}/mlx/distributed/distributed.h +0 -0
  524. /data/{mlx → submodules/mlx}/mlx/distributed/distributed_impl.h +0 -0
  525. /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/CMakeLists.txt +0 -0
  526. /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/jaccl.cpp +0 -0
  527. /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/jaccl.h +0 -0
  528. /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/mesh.cpp +0 -0
  529. /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/mesh.h +0 -0
  530. /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/no_jaccl.cpp +0 -0
  531. /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/ring.cpp +0 -0
  532. /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/ring.h +0 -0
  533. /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/utils.cpp +0 -0
  534. /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/utils.h +0 -0
  535. /data/{mlx → submodules/mlx}/mlx/distributed/mpi/CMakeLists.txt +0 -0
  536. /data/{mlx → submodules/mlx}/mlx/distributed/mpi/mpi.cpp +0 -0
  537. /data/{mlx → submodules/mlx}/mlx/distributed/mpi/mpi.h +0 -0
  538. /data/{mlx → submodules/mlx}/mlx/distributed/mpi/mpi_declarations.h +0 -0
  539. /data/{mlx → submodules/mlx}/mlx/distributed/mpi/no_mpi.cpp +0 -0
  540. /data/{mlx → submodules/mlx}/mlx/distributed/nccl/CMakeLists.txt +0 -0
  541. /data/{mlx → submodules/mlx}/mlx/distributed/nccl/nccl.cpp +0 -0
  542. /data/{mlx → submodules/mlx}/mlx/distributed/nccl/nccl.h +0 -0
  543. /data/{mlx → submodules/mlx}/mlx/distributed/nccl/nccl_stub/CMakeLists.txt +0 -0
  544. /data/{mlx → submodules/mlx}/mlx/distributed/nccl/nccl_stub/nccl_stubs.cpp +0 -0
  545. /data/{mlx → submodules/mlx}/mlx/distributed/nccl/no_nccl.cpp +0 -0
  546. /data/{mlx → submodules/mlx}/mlx/distributed/ops.cpp +0 -0
  547. /data/{mlx → submodules/mlx}/mlx/distributed/ops.h +0 -0
  548. /data/{mlx → submodules/mlx}/mlx/distributed/primitives.cpp +0 -0
  549. /data/{mlx → submodules/mlx}/mlx/distributed/primitives.h +0 -0
  550. /data/{mlx → submodules/mlx}/mlx/distributed/reduction_ops.h +0 -0
  551. /data/{mlx → submodules/mlx}/mlx/distributed/ring/CMakeLists.txt +0 -0
  552. /data/{mlx → submodules/mlx}/mlx/distributed/ring/no_ring.cpp +0 -0
  553. /data/{mlx → submodules/mlx}/mlx/distributed/ring/ring.cpp +0 -0
  554. /data/{mlx → submodules/mlx}/mlx/distributed/ring/ring.h +0 -0
  555. /data/{mlx → submodules/mlx}/mlx/distributed/utils.cpp +0 -0
  556. /data/{mlx → submodules/mlx}/mlx/distributed/utils.h +0 -0
  557. /data/{mlx → submodules/mlx}/mlx/dtype.cpp +0 -0
  558. /data/{mlx → submodules/mlx}/mlx/dtype.h +0 -0
  559. /data/{mlx → submodules/mlx}/mlx/dtype_utils.cpp +0 -0
  560. /data/{mlx → submodules/mlx}/mlx/dtype_utils.h +0 -0
  561. /data/{mlx → submodules/mlx}/mlx/einsum.cpp +0 -0
  562. /data/{mlx → submodules/mlx}/mlx/einsum.h +0 -0
  563. /data/{mlx → submodules/mlx}/mlx/event.h +0 -0
  564. /data/{mlx → submodules/mlx}/mlx/export.h +0 -0
  565. /data/{mlx → submodules/mlx}/mlx/export_impl.h +0 -0
  566. /data/{mlx → submodules/mlx}/mlx/fast.cpp +0 -0
  567. /data/{mlx → submodules/mlx}/mlx/fast.h +0 -0
  568. /data/{mlx → submodules/mlx}/mlx/fast_primitives.h +0 -0
  569. /data/{mlx → submodules/mlx}/mlx/fence.h +0 -0
  570. /data/{mlx → submodules/mlx}/mlx/fft.cpp +0 -0
  571. /data/{mlx → submodules/mlx}/mlx/fft.h +0 -0
  572. /data/{mlx → submodules/mlx}/mlx/graph_utils.cpp +0 -0
  573. /data/{mlx → submodules/mlx}/mlx/graph_utils.h +0 -0
  574. /data/{mlx → submodules/mlx}/mlx/io/CMakeLists.txt +0 -0
  575. /data/{mlx → submodules/mlx}/mlx/io/gguf.cpp +0 -0
  576. /data/{mlx → submodules/mlx}/mlx/io/gguf.h +0 -0
  577. /data/{mlx → submodules/mlx}/mlx/io/gguf_quants.cpp +0 -0
  578. /data/{mlx → submodules/mlx}/mlx/io/load.cpp +0 -0
  579. /data/{mlx → submodules/mlx}/mlx/io/load.h +0 -0
  580. /data/{mlx → submodules/mlx}/mlx/io/no_gguf.cpp +0 -0
  581. /data/{mlx → submodules/mlx}/mlx/io/no_safetensors.cpp +0 -0
  582. /data/{mlx → submodules/mlx}/mlx/io/safetensors.cpp +0 -0
  583. /data/{mlx → submodules/mlx}/mlx/io.h +0 -0
  584. /data/{mlx → submodules/mlx}/mlx/linalg.cpp +0 -0
  585. /data/{mlx → submodules/mlx}/mlx/linalg.h +0 -0
  586. /data/{mlx → submodules/mlx}/mlx/memory.h +0 -0
  587. /data/{mlx → submodules/mlx}/mlx/mlx.h +0 -0
  588. /data/{mlx → submodules/mlx}/mlx/primitives.h +0 -0
  589. /data/{mlx → submodules/mlx}/mlx/random.cpp +0 -0
  590. /data/{mlx → submodules/mlx}/mlx/random.h +0 -0
  591. /data/{mlx → submodules/mlx}/mlx/small_vector.h +0 -0
  592. /data/{mlx → submodules/mlx}/mlx/threadpool.h +0 -0
  593. /data/{mlx → submodules/mlx}/mlx/transforms.cpp +0 -0
  594. /data/{mlx → submodules/mlx}/mlx/transforms.h +0 -0
  595. /data/{mlx → submodules/mlx}/mlx/transforms_impl.h +0 -0
  596. /data/{mlx → submodules/mlx}/mlx/types/bf16.h +0 -0
  597. /data/{mlx → submodules/mlx}/mlx/types/complex.h +0 -0
  598. /data/{mlx → submodules/mlx}/mlx/types/fp16.h +0 -0
  599. /data/{mlx → submodules/mlx}/mlx/types/half_types.h +0 -0
  600. /data/{mlx → submodules/mlx}/mlx/types/limits.h +0 -0
  601. /data/{mlx → submodules/mlx}/mlx/utils.cpp +0 -0
  602. /data/{mlx → submodules/mlx}/mlx/utils.h +0 -0
  603. /data/{mlx → submodules/mlx}/mlx/version.cpp +0 -0
  604. /data/{mlx → submodules/mlx}/mlx/version.h +0 -0
  605. /data/{mlx → submodules/mlx}/mlx.pc.in +0 -0
@@ -0,0 +1,653 @@
1
+ #include "mlx/ir.hpp"
2
+
3
+ #include "detail.hpp"
4
+ #include "json.hpp"
5
+
6
+ #include <cmath>
7
+ #include <cstdint>
8
+ #include <optional>
9
+ #include <sstream>
10
+ #include <stdexcept>
11
+ #include <string>
12
+ #include <tuple>
13
+ #include <utility>
14
+ #include <variant>
15
+ #include <vector>
16
+
17
+ #include "mlx/ops.h"
18
+
19
+ namespace mx = mlx::core;
20
+
21
+ namespace mlx::onnx {
22
+ namespace {
23
+
24
// Tuple describing one graph tensor. The helpers below read element 0 as the
// tensor name, element 1 as its shape and element 2 as its dtype.
+ using GraphTensorInfo = std::tuple<std::string, mx::Shape, mx::Dtype>;
25
+
26
+ std::string dtype_to_string(mx::Dtype dtype) {
27
+ std::ostringstream out;
28
+ out << dtype;
29
+ return out.str();
30
+ }
31
+
32
+ template <typename ValueAt>
33
+ OrderedJson capture_build_nested_json_array(
34
+ const mx::Shape& shape,
35
+ size_t dim,
36
+ size_t& flat_index,
37
+ ValueAt value_at) {
38
+ if (dim == shape.size()) {
39
+ return value_at(flat_index++);
40
+ }
41
+
42
+ OrderedJson out = OrderedJson::array();
43
+ for (size_t i = 0; i < shape[dim]; ++i) {
44
+ out.push_back(
45
+ capture_build_nested_json_array(shape, dim + 1, flat_index, value_at));
46
+ }
47
+ return out;
48
+ }
49
+
50
+ template <typename ValueAt>
51
+ OrderedJson capture_build_flat_json_array(size_t size, ValueAt value_at) {
52
+ OrderedJson out = OrderedJson::array();
53
+ for (size_t i = 0; i < size; ++i) {
54
+ out.push_back(value_at(i));
55
+ }
56
+ return out;
57
+ }
58
+
59
+ OrderedJson capture_json_shape_from_mx_shape(const mx::Shape& shape) {
60
+ OrderedJson out = OrderedJson::array();
61
+ for (size_t dim : shape) {
62
+ out.push_back(dim);
63
+ }
64
+ return out;
65
+ }
66
+
67
+ OrderedJson capture_json_scalar_from_array(const mx::array& array) {
68
+ switch (array.dtype()) {
69
+ case mx::bool_:
70
+ return OrderedJson(array.item<bool>());
71
+ case mx::uint8:
72
+ return OrderedJson(array.item<uint8_t>());
73
+ case mx::uint16:
74
+ return OrderedJson(array.item<uint16_t>());
75
+ case mx::uint32:
76
+ return OrderedJson(array.item<uint32_t>());
77
+ case mx::uint64:
78
+ return OrderedJson(array.item<uint64_t>());
79
+ case mx::int8:
80
+ return OrderedJson(array.item<int8_t>());
81
+ case mx::int16:
82
+ return OrderedJson(array.item<int16_t>());
83
+ case mx::int32:
84
+ return OrderedJson(array.item<int32_t>());
85
+ case mx::int64:
86
+ return OrderedJson(array.item<int64_t>());
87
+ case mx::float16:
88
+ return OrderedJson(static_cast<double>(array.item<mx::float16_t>()));
89
+ case mx::bfloat16:
90
+ return OrderedJson(static_cast<double>(array.item<mx::bfloat16_t>()));
91
+ case mx::float32:
92
+ return OrderedJson(static_cast<double>(array.item<float>()));
93
+ case mx::float64:
94
+ return OrderedJson(array.item<double>());
95
+ default:
96
+ throw std::runtime_error("unsupported dtype for graph ir constant conversion");
97
+ }
98
+ }
99
+
100
+ OrderedJson capture_json_values_from_array(const mx::array& source) {
101
+ mx::array array = source;
102
+ if (array.ndim() == 0) {
103
+ array.eval();
104
+ return capture_json_scalar_from_array(array);
105
+ }
106
+
107
+ if (array.ndim() == 1) {
108
+ array.eval();
109
+ const size_t size = array.size();
110
+ switch (array.dtype()) {
111
+ case mx::bool_: {
112
+ const bool* data = array.data<bool>();
113
+ return capture_build_flat_json_array(
114
+ size, [&](size_t i) { return OrderedJson(data[i]); });
115
+ }
116
+ case mx::uint8: {
117
+ const uint8_t* data = array.data<uint8_t>();
118
+ return capture_build_flat_json_array(
119
+ size, [&](size_t i) { return OrderedJson(data[i]); });
120
+ }
121
+ case mx::uint16: {
122
+ const uint16_t* data = array.data<uint16_t>();
123
+ return capture_build_flat_json_array(
124
+ size, [&](size_t i) { return OrderedJson(data[i]); });
125
+ }
126
+ case mx::uint32: {
127
+ const uint32_t* data = array.data<uint32_t>();
128
+ return capture_build_flat_json_array(
129
+ size, [&](size_t i) { return OrderedJson(data[i]); });
130
+ }
131
+ case mx::uint64: {
132
+ const uint64_t* data = array.data<uint64_t>();
133
+ return capture_build_flat_json_array(
134
+ size, [&](size_t i) { return OrderedJson(data[i]); });
135
+ }
136
+ case mx::int8: {
137
+ const int8_t* data = array.data<int8_t>();
138
+ return capture_build_flat_json_array(
139
+ size, [&](size_t i) { return OrderedJson(data[i]); });
140
+ }
141
+ case mx::int16: {
142
+ const int16_t* data = array.data<int16_t>();
143
+ return capture_build_flat_json_array(
144
+ size, [&](size_t i) { return OrderedJson(data[i]); });
145
+ }
146
+ case mx::int32: {
147
+ const int32_t* data = array.data<int32_t>();
148
+ return capture_build_flat_json_array(
149
+ size, [&](size_t i) { return OrderedJson(data[i]); });
150
+ }
151
+ case mx::int64: {
152
+ const int64_t* data = array.data<int64_t>();
153
+ return capture_build_flat_json_array(
154
+ size, [&](size_t i) { return OrderedJson(data[i]); });
155
+ }
156
+ case mx::float16: {
157
+ const mx::float16_t* data = array.data<mx::float16_t>();
158
+ return capture_build_flat_json_array(
159
+ size,
160
+ [&](size_t i) { return OrderedJson(static_cast<double>(data[i])); });
161
+ }
162
+ case mx::bfloat16: {
163
+ const mx::bfloat16_t* data = array.data<mx::bfloat16_t>();
164
+ return capture_build_flat_json_array(
165
+ size,
166
+ [&](size_t i) { return OrderedJson(static_cast<double>(data[i])); });
167
+ }
168
+ case mx::float32: {
169
+ const float* data = array.data<float>();
170
+ return capture_build_flat_json_array(
171
+ size,
172
+ [&](size_t i) { return OrderedJson(static_cast<double>(data[i])); });
173
+ }
174
+ case mx::float64: {
175
+ const double* data = array.data<double>();
176
+ return capture_build_flat_json_array(
177
+ size, [&](size_t i) { return OrderedJson(data[i]); });
178
+ }
179
+ default:
180
+ throw std::runtime_error(
181
+ "unsupported dtype for graph ir constant conversion");
182
+ }
183
+ }
184
+
185
+ const mx::Shape shape = array.shape();
186
+ mx::array flat =
187
+ mx::reshape(array, mx::Shape{static_cast<mx::ShapeElem>(array.size())});
188
+ flat.eval();
189
+
190
+ size_t idx = 0;
191
+ switch (flat.dtype()) {
192
+ case mx::bool_: {
193
+ const bool* data = flat.data<bool>();
194
+ return capture_build_nested_json_array(
195
+ shape, 0, idx, [&](size_t i) { return OrderedJson(data[i]); });
196
+ }
197
+ case mx::uint8: {
198
+ const uint8_t* data = flat.data<uint8_t>();
199
+ return capture_build_nested_json_array(
200
+ shape, 0, idx, [&](size_t i) { return OrderedJson(data[i]); });
201
+ }
202
+ case mx::uint16: {
203
+ const uint16_t* data = flat.data<uint16_t>();
204
+ return capture_build_nested_json_array(
205
+ shape, 0, idx, [&](size_t i) { return OrderedJson(data[i]); });
206
+ }
207
+ case mx::uint32: {
208
+ const uint32_t* data = flat.data<uint32_t>();
209
+ return capture_build_nested_json_array(
210
+ shape, 0, idx, [&](size_t i) { return OrderedJson(data[i]); });
211
+ }
212
+ case mx::uint64: {
213
+ const uint64_t* data = flat.data<uint64_t>();
214
+ return capture_build_nested_json_array(
215
+ shape, 0, idx, [&](size_t i) { return OrderedJson(data[i]); });
216
+ }
217
+ case mx::int8: {
218
+ const int8_t* data = flat.data<int8_t>();
219
+ return capture_build_nested_json_array(
220
+ shape, 0, idx, [&](size_t i) { return OrderedJson(data[i]); });
221
+ }
222
+ case mx::int16: {
223
+ const int16_t* data = flat.data<int16_t>();
224
+ return capture_build_nested_json_array(
225
+ shape, 0, idx, [&](size_t i) { return OrderedJson(data[i]); });
226
+ }
227
+ case mx::int32: {
228
+ const int32_t* data = flat.data<int32_t>();
229
+ return capture_build_nested_json_array(
230
+ shape, 0, idx, [&](size_t i) { return OrderedJson(data[i]); });
231
+ }
232
+ case mx::int64: {
233
+ const int64_t* data = flat.data<int64_t>();
234
+ return capture_build_nested_json_array(
235
+ shape, 0, idx, [&](size_t i) { return OrderedJson(data[i]); });
236
+ }
237
+ case mx::float16: {
238
+ const mx::float16_t* data = flat.data<mx::float16_t>();
239
+ return capture_build_nested_json_array(
240
+ shape,
241
+ 0,
242
+ idx,
243
+ [&](size_t i) { return OrderedJson(static_cast<double>(data[i])); });
244
+ }
245
+ case mx::bfloat16: {
246
+ const mx::bfloat16_t* data = flat.data<mx::bfloat16_t>();
247
+ return capture_build_nested_json_array(
248
+ shape,
249
+ 0,
250
+ idx,
251
+ [&](size_t i) { return OrderedJson(static_cast<double>(data[i])); });
252
+ }
253
+ case mx::float32: {
254
+ const float* data = flat.data<float>();
255
+ return capture_build_nested_json_array(
256
+ shape,
257
+ 0,
258
+ idx,
259
+ [&](size_t i) { return OrderedJson(static_cast<double>(data[i])); });
260
+ }
261
+ case mx::float64: {
262
+ const double* data = flat.data<double>();
263
+ return capture_build_nested_json_array(
264
+ shape, 0, idx, [&](size_t i) { return OrderedJson(data[i]); });
265
+ }
266
+ default:
267
+ throw std::runtime_error(
268
+ "unsupported dtype for graph ir constant conversion");
269
+ }
270
+ }
271
+
272
+ OrderedJson capture_json_tensor_info_from_graph_tensor(const GraphTensorInfo& info) {
273
+ OrderedJson out = OrderedJson::object();
274
+ out["name"] = std::get<0>(info);
275
+ out["shape"] = capture_json_shape_from_mx_shape(std::get<1>(info));
276
+ out["dtype"] = dtype_to_string(std::get<2>(info));
277
+ return out;
278
+ }
279
+
280
+ OrderedJson capture_json_tensor_infos_from_graph_tensors(
281
+ const std::vector<GraphTensorInfo>& infos) {
282
+ OrderedJson out = OrderedJson::array();
283
+ for (const auto& info : infos) {
284
+ out.push_back(capture_json_tensor_info_from_graph_tensor(info));
285
+ }
286
+ return out;
287
+ }
288
+
289
+ OrderedJson capture_json_tensor_names_from_graph_tensors(
290
+ const std::vector<GraphTensorInfo>& infos) {
291
+ OrderedJson out = OrderedJson::array();
292
+ for (const auto& info : infos) {
293
+ out.push_back(std::get<0>(info));
294
+ }
295
+ return out;
296
+ }
297
+
298
+ OrderedJson capture_json_state_value_from_mx_state(const mx::StateT& value) {
299
+ if (std::holds_alternative<bool>(value)) {
300
+ return OrderedJson(std::get<bool>(value));
301
+ }
302
+ if (std::holds_alternative<int>(value)) {
303
+ return OrderedJson(std::get<int>(value));
304
+ }
305
+ if (std::holds_alternative<size_t>(value)) {
306
+ return OrderedJson(std::get<size_t>(value));
307
+ }
308
+ if (std::holds_alternative<float>(value)) {
309
+ return OrderedJson(static_cast<double>(std::get<float>(value)));
310
+ }
311
+ if (std::holds_alternative<double>(value)) {
312
+ return OrderedJson(std::get<double>(value));
313
+ }
314
+ if (std::holds_alternative<mx::Dtype>(value)) {
315
+ return OrderedJson(dtype_to_string(std::get<mx::Dtype>(value)));
316
+ }
317
+ if (std::holds_alternative<mx::Shape>(value)) {
318
+ return capture_json_shape_from_mx_shape(std::get<mx::Shape>(value));
319
+ }
320
+ if (std::holds_alternative<mx::Strides>(value)) {
321
+ OrderedJson out = OrderedJson::array();
322
+ const auto& strides = std::get<mx::Strides>(value);
323
+ for (auto stride : strides) {
324
+ out.push_back(static_cast<long long>(stride));
325
+ }
326
+ return out;
327
+ }
328
+ if (std::holds_alternative<std::vector<int>>(value)) {
329
+ OrderedJson out = OrderedJson::array();
330
+ const auto& values = std::get<std::vector<int>>(value);
331
+ for (int item : values) {
332
+ out.push_back(item);
333
+ }
334
+ return out;
335
+ }
336
+ if (std::holds_alternative<std::vector<size_t>>(value)) {
337
+ OrderedJson out = OrderedJson::array();
338
+ const auto& values = std::get<std::vector<size_t>>(value);
339
+ for (size_t item : values) {
340
+ out.push_back(item);
341
+ }
342
+ return out;
343
+ }
344
+ if (std::holds_alternative<std::vector<std::tuple<bool, bool, bool>>>(value)) {
345
+ OrderedJson out = OrderedJson::array();
346
+ const auto& tuples = std::get<std::vector<std::tuple<bool, bool, bool>>>(value);
347
+ for (const auto& item : tuples) {
348
+ out.push_back(
349
+ OrderedJson::array({std::get<0>(item), std::get<1>(item), std::get<2>(item)}));
350
+ }
351
+ return out;
352
+ }
353
+ if (std::holds_alternative<std::vector<std::variant<bool, int, float>>>(value)) {
354
+ OrderedJson out = OrderedJson::array();
355
+ const auto& vars = std::get<std::vector<std::variant<bool, int, float>>>(value);
356
+ for (const auto& item : vars) {
357
+ if (std::holds_alternative<bool>(item)) {
358
+ out.push_back(std::get<bool>(item));
359
+ } else if (std::holds_alternative<int>(item)) {
360
+ out.push_back(std::get<int>(item));
361
+ } else {
362
+ out.push_back(static_cast<double>(std::get<float>(item)));
363
+ }
364
+ }
365
+ return out;
366
+ }
367
+ if (std::holds_alternative<std::optional<float>>(value)) {
368
+ const auto& opt = std::get<std::optional<float>>(value);
369
+ if (!opt.has_value()) {
370
+ return nullptr;
371
+ }
372
+ return OrderedJson(static_cast<double>(opt.value()));
373
+ }
374
+ return OrderedJson(std::get<std::string>(value));
375
+ }
376
+
377
+ OrderedJson capture_json_state_values_from_mx_states(
378
+ const std::vector<mx::StateT>& values) {
379
+ OrderedJson out = OrderedJson::array();
380
+ for (const auto& value : values) {
381
+ out.push_back(capture_json_state_value_from_mx_state(value));
382
+ }
383
+ return out;
384
+ }
385
+
386
+ template <typename T>
387
+ const T* export_callback_field(
388
+ const mx::ExportCallbackInput& data,
389
+ const std::string& key) {
390
+ for (const auto& [candidate_key, candidate_value] : data) {
391
+ if (candidate_key == key && std::holds_alternative<T>(candidate_value)) {
392
+ return &std::get<T>(candidate_value);
393
+ }
394
+ }
395
+ return nullptr;
396
+ }
397
+
398
+ OrderedJson export_ir_payload(
399
+ const IrCaptureFunction& fun,
400
+ const mx::Args& args,
401
+ const mx::Kwargs& kwargs,
402
+ bool shapeless) {
403
+ OrderedJson graph_inputs = OrderedJson::array();
404
+ OrderedJson keyword_inputs = OrderedJson::array();
405
+ OrderedJson graph_outputs = OrderedJson::array();
406
+ OrderedJson graph_constants = OrderedJson::array();
407
+ OrderedJson graph_nodes = OrderedJson::array();
408
+
409
+ mx::export_function(
410
+ [&](const mx::ExportCallbackInput& data) {
411
+ const auto* record_type = export_callback_field<std::string>(data, "type");
412
+ if (record_type == nullptr) {
413
+ return;
414
+ }
415
+
416
+ if (*record_type == "inputs") {
417
+ const auto* inputs =
418
+ export_callback_field<std::vector<GraphTensorInfo>>(data, "inputs");
419
+ if (inputs != nullptr) {
420
+ graph_inputs = capture_json_tensor_infos_from_graph_tensors(*inputs);
421
+ }
422
+ return;
423
+ }
424
+
425
+ if (*record_type == "keyword_inputs") {
426
+ const auto* keywords =
427
+ export_callback_field<std::vector<std::pair<std::string, std::string>>>(
428
+ data,
429
+ "keywords");
430
+ if (keywords != nullptr) {
431
+ keyword_inputs = OrderedJson::array();
432
+ for (const auto& [name, tensor] : *keywords) {
433
+ OrderedJson entry = OrderedJson::object();
434
+ entry["name"] = name;
435
+ entry["tensor"] = tensor;
436
+ keyword_inputs.push_back(std::move(entry));
437
+ }
438
+ }
439
+ return;
440
+ }
441
+
442
+ if (*record_type == "outputs") {
443
+ const auto* outputs =
444
+ export_callback_field<std::vector<GraphTensorInfo>>(data, "outputs");
445
+ if (outputs != nullptr) {
446
+ graph_outputs = capture_json_tensor_infos_from_graph_tensors(*outputs);
447
+ }
448
+ return;
449
+ }
450
+
451
+ if (*record_type == "constants") {
452
+ const auto* constants =
453
+ export_callback_field<std::vector<std::pair<std::string, mx::array>>>(
454
+ data,
455
+ "constants");
456
+ if (constants != nullptr) {
457
+ graph_constants = OrderedJson::array();
458
+ for (const auto& [name, arr] : *constants) {
459
+ OrderedJson entry = OrderedJson::object();
460
+ entry["name"] = name;
461
+ entry["shape"] = capture_json_shape_from_mx_shape(arr.shape());
462
+ entry["dtype"] = dtype_to_string(arr.dtype());
463
+ entry["values"] = capture_json_values_from_array(arr);
464
+ graph_constants.push_back(std::move(entry));
465
+ }
466
+ }
467
+ return;
468
+ }
469
+
470
+ if (*record_type != "primitive") {
471
+ return;
472
+ }
473
+
474
+ const auto* op_name = export_callback_field<std::string>(data, "name");
475
+ if (op_name == nullptr) {
476
+ return;
477
+ }
478
+
479
+ OrderedJson node = OrderedJson::object();
480
+ node["op"] = *op_name;
481
+
482
+ OrderedJson node_inputs = OrderedJson::array();
483
+ const auto* node_input_infos =
484
+ export_callback_field<std::vector<GraphTensorInfo>>(data, "inputs");
485
+ if (node_input_infos != nullptr) {
486
+ node_inputs = capture_json_tensor_names_from_graph_tensors(*node_input_infos);
487
+ }
488
+ node["inputs"] = std::move(node_inputs);
489
+
490
+ OrderedJson node_outputs = OrderedJson::array();
491
+ const auto* node_output_infos =
492
+ export_callback_field<std::vector<GraphTensorInfo>>(data, "outputs");
493
+ if (node_output_infos != nullptr) {
494
+ node_outputs = capture_json_tensor_names_from_graph_tensors(*node_output_infos);
495
+ }
496
+ node["outputs"] = std::move(node_outputs);
497
+
498
+ OrderedJson node_arguments = OrderedJson::array();
499
+ const auto* arguments =
500
+ export_callback_field<std::vector<mx::StateT>>(data, "arguments");
501
+ if (arguments != nullptr) {
502
+ node_arguments = capture_json_state_values_from_mx_states(*arguments);
503
+ }
504
+ node["arguments"] = std::move(node_arguments);
505
+
506
+ graph_nodes.push_back(std::move(node));
507
+ },
508
+ fun,
509
+ args,
510
+ kwargs,
511
+ shapeless);
512
+
513
+ OrderedJson payload = OrderedJson::object();
514
+ payload["ir_version"] = kGraphIrVersion;
515
+ payload["shapeless"] = shapeless;
516
+ payload["inputs"] = std::move(graph_inputs);
517
+ payload["keyword_inputs"] = std::move(keyword_inputs);
518
+ payload["outputs"] = std::move(graph_outputs);
519
+ payload["constants"] = std::move(graph_constants);
520
+ payload["nodes"] = std::move(graph_nodes);
521
+ return payload;
522
+ }
523
+
524
+ int64_t normalize_positive_integer(int64_t value, const char* label) {
525
+ if (value <= 0) {
526
+ std::ostringstream out;
527
+ out << label << " must be a positive integer";
528
+ throw std::invalid_argument(detail::tagged_error_message("ir.api", out.str()));
529
+ }
530
+ return value;
531
+ }
532
+
533
+ std::string non_empty_model_name(std::string value) {
534
+ if (value.empty()) {
535
+ throw std::invalid_argument(
536
+ detail::tagged_error_message("ir.api", "model_name must not be empty"));
537
+ }
538
+ return value;
539
+ }
540
+
541
+ void validate_onnx_binary_write_options(const OnnxBinaryWriteOptions& options) {
542
+ if (!options.external_data) {
543
+ return;
544
+ }
545
+
546
+ if (options.external_data_size_threshold < 0) {
547
+ throw std::invalid_argument(
548
+ detail::tagged_error_message(
549
+ "ir.api", "external_data_size_threshold must be non-negative"));
550
+ }
551
+
552
+ if (options.external_data_file.empty()) {
553
+ throw std::invalid_argument(
554
+ detail::tagged_error_message(
555
+ "ir.api", "external_data_file must not be empty when provided"));
556
+ }
557
+ }
558
+
559
+ template <typename Result, typename Callable>
560
+ Result with_ir_api_error_tag(Callable&& callable) {
561
+ try {
562
+ return callable();
563
+ } catch (const std::exception& error) {
564
+ if (ir_is_unsupported_error_message(error.what())) {
565
+ throw;
566
+ }
567
+ throw std::runtime_error(
568
+ detail::tagged_error_message("ir.api", error.what()));
569
+ }
570
+ }
571
+
572
+ } // namespace
573
+
574
+ std::string export_ir_json(
575
+ const IrCaptureFunction& fun,
576
+ const mx::Args& args,
577
+ const mx::Kwargs& kwargs,
578
+ bool shapeless) {
579
+ return with_ir_api_error_tag<std::string>(
580
+ [&]() { return export_ir_payload(fun, args, kwargs, shapeless).dump(); });
581
+ }
582
+
583
+ std::string export_onnx_compatibility_report_json(
584
+ const IrCaptureFunction& fun,
585
+ const mx::Args& args,
586
+ const mx::Kwargs& kwargs,
587
+ bool shapeless) {
588
+ return with_ir_api_error_tag<std::string>([&]() {
589
+ const auto payload = export_ir_payload(fun, args, kwargs, shapeless);
590
+ return ir_compatibility_report_payload(payload).dump();
591
+ });
592
+ }
593
+
594
+ std::string export_onnx_json(
595
+ const IrCaptureFunction& fun,
596
+ const mx::Args& args,
597
+ const mx::Kwargs& kwargs,
598
+ bool shapeless,
599
+ int64_t opset,
600
+ const std::string& model_name) {
601
+ return with_ir_api_error_tag<std::string>([&]() {
602
+ const auto payload = export_ir_payload(fun, args, kwargs, shapeless);
603
+ const auto onnx = ir_to_onnx_json_payload(
604
+ payload,
605
+ normalize_positive_integer(opset, "opset"),
606
+ non_empty_model_name(model_name));
607
+ return onnx.dump();
608
+ });
609
+ }
610
+
611
+ std::string export_onnx(
612
+ const std::string& target_path,
613
+ const IrCaptureFunction& fun,
614
+ const mx::Args& args,
615
+ const mx::Kwargs& kwargs,
616
+ bool shapeless,
617
+ int64_t opset,
618
+ const std::string& model_name,
619
+ const OnnxBinaryWriteOptions& options) {
620
+ return with_ir_api_error_tag<std::string>([&]() {
621
+ validate_onnx_binary_write_options(options);
622
+ const auto onnx_json = export_onnx_json(
623
+ fun,
624
+ args,
625
+ kwargs,
626
+ shapeless,
627
+ normalize_positive_integer(opset, "opset"),
628
+ non_empty_model_name(model_name));
629
+ const auto artifact =
630
+ build_onnx_binary_artifact_from_onnx_json(onnx_json, options);
631
+ return write_onnx_binary_artifact_to_path(target_path, artifact, options);
632
+ });
633
+ }
634
+
635
+ std::string ir_to_onnx(
636
+ const std::string& target_path,
637
+ const std::string& ir_json,
638
+ int64_t opset,
639
+ const std::string& model_name,
640
+ const OnnxBinaryWriteOptions& options) {
641
+ return with_ir_api_error_tag<std::string>([&]() {
642
+ validate_onnx_binary_write_options(options);
643
+ const auto onnx_json = ir_to_onnx_json(
644
+ ir_json,
645
+ normalize_positive_integer(opset, "opset"),
646
+ non_empty_model_name(model_name));
647
+ const auto artifact =
648
+ build_onnx_binary_artifact_from_onnx_json(onnx_json, options);
649
+ return write_onnx_binary_artifact_to_path(target_path, artifact, options);
650
+ });
651
+ }
652
+
653
+ } // namespace mlx::onnx
@@ -0,0 +1,61 @@
1
+ #include "mlx/ir.hpp"
2
+ #include "detail.hpp"
3
+
4
+ #include <filesystem>
5
+ #include <sstream>
6
+ #include <stdexcept>
7
+
8
+ #include "mlx/io/load.h"
9
+
10
+ namespace mlx::onnx {
11
+ namespace {
12
+
13
// Tag handed to detail::tagged_error_message when opening an output file
// fails.
+ constexpr const char* kGraphIrIoTag = "ir.io";
14
// Marker substrings that ir_is_unsupported_error_message searches for in an
// error message to recognise an "unsupported" failure.
+ constexpr const char* kGraphIrLoweringUnsupportedPrefix =
15
+ "[ir.lowering] unsupported";
16
+ constexpr const char* kGraphIrLegacyUnsupportedPrefix =
17
+ "[ir_to_onnx_stub] unsupported";
18
+
19
+ void write_binary_file(
20
+ const std::filesystem::path& path,
21
+ const std::string& bytes) {
22
+ mlx::core::io::FileWriter output(path.string());
23
+ if (!output.good()) {
24
+ std::ostringstream out;
25
+ out << "failed to open file for write: " << path.string();
26
+ throw std::runtime_error(
27
+ detail::tagged_error_message(kGraphIrIoTag, out.str()));
28
+ }
29
+ output.write(bytes.data(), bytes.size());
30
+ }
31
+
32
+ } // namespace
33
+
34
+ std::string write_onnx_binary_artifact_to_path(
35
+ const std::string& target_path,
36
+ const OnnxBinaryArtifact& artifact,
37
+ const OnnxBinaryWriteOptions& options) {
38
+ std::filesystem::path path(target_path);
39
+ if (!path.has_parent_path()) {
40
+ path = std::filesystem::absolute(path);
41
+ }
42
+ const auto parent = path.parent_path();
43
+ if (!parent.empty()) {
44
+ std::filesystem::create_directories(parent);
45
+ }
46
+
47
+ write_binary_file(path, artifact.model_bytes);
48
+ if (options.external_data && artifact.has_external_data) {
49
+ write_binary_file(
50
+ parent / options.external_data_file, artifact.external_data_bytes);
51
+ }
52
+
53
+ return path.string();
54
+ }
55
+
56
+ bool ir_is_unsupported_error_message(const std::string& message) {
57
+ return message.find(kGraphIrLoweringUnsupportedPrefix) != std::string::npos ||
58
+ message.find(kGraphIrLegacyUnsupportedPrefix) != std::string::npos;
59
+ }
60
+
61
+ } // namespace mlx::onnx