mlx 0.30.7.2 → 0.30.7.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (605)
  1. checksums.yaml +4 -4
  2. data/ext/mlx/extconf.rb +267 -8
  3. data/ext/mlx/native.cpp +112 -58
  4. data/ext/mlx-onnx/native.cpp +1402 -0
  5. data/ext/mlx-onnx/native.hpp +19 -0
  6. data/lib/mlx/core.rb +342 -117
  7. data/lib/mlx/distributed_utils/common.rb +1 -1
  8. data/lib/mlx/distributed_utils/config.rb +7 -4
  9. data/lib/mlx/distributed_utils/launch.rb +2 -0
  10. data/lib/mlx/dsl/attention.rb +132 -0
  11. data/lib/mlx/dsl/builder.rb +8 -0
  12. data/lib/mlx/dsl/config_schema.rb +133 -0
  13. data/lib/mlx/dsl/generate.rb +193 -0
  14. data/lib/mlx/dsl/kv_cache.rb +96 -0
  15. data/lib/mlx/dsl/masks.rb +32 -0
  16. data/lib/mlx/dsl/positions.rb +35 -0
  17. data/lib/mlx/dsl/run_stack.rb +68 -0
  18. data/lib/mlx/dsl/tensor.rb +126 -0
  19. data/lib/mlx/dsl/transformer_block.rb +113 -0
  20. data/lib/mlx/dsl/weight_map.rb +140 -0
  21. data/lib/mlx/dsl.rb +10 -0
  22. data/lib/mlx/nn/base.rb +4 -0
  23. data/lib/mlx/nn/layers/linear.rb +2 -3
  24. data/lib/mlx/onnx.rb +250 -0
  25. data/lib/mlx/version.rb +1 -1
  26. data/lib/mlx-onnx/webgpu_harness.rb +289 -0
  27. data/{mlx → submodules/mlx}/mlx/backend/cuda/cublas_utils.cpp +0 -7
  28. data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/cublas_gemm.cpp +10 -2
  29. data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/cublas_qqmm.cpp +97 -46
  30. data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/cublas_qqmm.h +25 -13
  31. data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/fp_quantize.cu +101 -38
  32. data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/no_qqmm_impl.cpp +1 -2
  33. data/submodules/mlx/mlx/backend/cuda/quantized/qqmm.cpp +193 -0
  34. data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/qqmm_impl.cpp +15 -8
  35. data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/qqmm_impl.h +14 -3
  36. data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/qqmm_utils.cu +36 -0
  37. data/submodules/mlx/mlx/backend/cuda/quantized/qqmm_utils.h +62 -0
  38. data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/quantized.cpp +12 -3
  39. data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/quantized.h +4 -0
  40. data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/quantized_utils.cuh +1 -1
  41. data/{mlx → submodules/mlx}/mlx/backend/metal/device.cpp +4 -0
  42. data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/conv.metal +3 -2
  43. data/{mlx → submodules/mlx}/mlx/export.cpp +21 -6
  44. data/{mlx → submodules/mlx}/mlx/ops.cpp +144 -13
  45. data/{mlx → submodules/mlx}/mlx/ops.h +12 -2
  46. data/{mlx → submodules/mlx}/mlx/primitives.cpp +22 -5
  47. data/{mlx → submodules/mlx}/mlx/scheduler.cpp +4 -0
  48. data/{mlx → submodules/mlx}/mlx/scheduler.h +3 -0
  49. data/{mlx → submodules/mlx}/mlx/stream.h +5 -0
  50. data/submodules/mlx-onnx/CMakeLists.txt +159 -0
  51. data/submodules/mlx-onnx/LICENSE +21 -0
  52. data/submodules/mlx-onnx/include/mlx/ir.hpp +88 -0
  53. data/submodules/mlx-onnx/src/api.cpp +81 -0
  54. data/submodules/mlx-onnx/src/compat.cpp +111 -0
  55. data/submodules/mlx-onnx/src/detail.hpp +69 -0
  56. data/submodules/mlx-onnx/src/export.cpp +653 -0
  57. data/submodules/mlx-onnx/src/io.cpp +61 -0
  58. data/submodules/mlx-onnx/src/json.hpp +25 -0
  59. data/submodules/mlx-onnx/src/lowering.cpp +6346 -0
  60. data/submodules/mlx-onnx/src/mappings.cpp +201 -0
  61. data/submodules/mlx-onnx/src/mappings.hpp +16 -0
  62. data/submodules/mlx-onnx/src/onnx.cpp +1029 -0
  63. data/submodules/mlx-onnx/src/shared.cpp +206 -0
  64. metadata +665 -567
  65. data/mlx/mlx/backend/cuda/quantized/qqmm.cpp +0 -158
  66. data/mlx/mlx/backend/cuda/quantized/qqmm_utils.h +0 -30
  67. /data/{mlx → submodules/mlx}/CMakeLists.txt +0 -0
  68. /data/{mlx → submodules/mlx}/cmake/FindCUDNN.cmake +0 -0
  69. /data/{mlx → submodules/mlx}/cmake/FindNCCL.cmake +0 -0
  70. /data/{mlx → submodules/mlx}/cmake/Findnvpl.cmake +0 -0
  71. /data/{mlx → submodules/mlx}/cmake/extension.cmake +0 -0
  72. /data/{mlx → submodules/mlx}/mlx/3rdparty/.clang-format +0 -0
  73. /data/{mlx → submodules/mlx}/mlx/3rdparty/pocketfft.h +0 -0
  74. /data/{mlx → submodules/mlx}/mlx/CMakeLists.txt +0 -0
  75. /data/{mlx → submodules/mlx}/mlx/allocator.h +0 -0
  76. /data/{mlx → submodules/mlx}/mlx/api.h +0 -0
  77. /data/{mlx → submodules/mlx}/mlx/array.cpp +0 -0
  78. /data/{mlx → submodules/mlx}/mlx/array.h +0 -0
  79. /data/{mlx → submodules/mlx}/mlx/backend/common/CMakeLists.txt +0 -0
  80. /data/{mlx → submodules/mlx}/mlx/backend/common/binary.h +0 -0
  81. /data/{mlx → submodules/mlx}/mlx/backend/common/broadcasting.cpp +0 -0
  82. /data/{mlx → submodules/mlx}/mlx/backend/common/broadcasting.h +0 -0
  83. /data/{mlx → submodules/mlx}/mlx/backend/common/buffer_cache.h +0 -0
  84. /data/{mlx → submodules/mlx}/mlx/backend/common/common.cpp +0 -0
  85. /data/{mlx → submodules/mlx}/mlx/backend/common/compiled.cpp +0 -0
  86. /data/{mlx → submodules/mlx}/mlx/backend/common/compiled.h +0 -0
  87. /data/{mlx → submodules/mlx}/mlx/backend/common/copy.h +0 -0
  88. /data/{mlx → submodules/mlx}/mlx/backend/common/hadamard.h +0 -0
  89. /data/{mlx → submodules/mlx}/mlx/backend/common/load.cpp +0 -0
  90. /data/{mlx → submodules/mlx}/mlx/backend/common/matmul.h +0 -0
  91. /data/{mlx → submodules/mlx}/mlx/backend/common/reduce.cpp +0 -0
  92. /data/{mlx → submodules/mlx}/mlx/backend/common/reduce.h +0 -0
  93. /data/{mlx → submodules/mlx}/mlx/backend/common/slicing.cpp +0 -0
  94. /data/{mlx → submodules/mlx}/mlx/backend/common/slicing.h +0 -0
  95. /data/{mlx → submodules/mlx}/mlx/backend/common/ternary.h +0 -0
  96. /data/{mlx → submodules/mlx}/mlx/backend/common/unary.h +0 -0
  97. /data/{mlx → submodules/mlx}/mlx/backend/common/utils.cpp +0 -0
  98. /data/{mlx → submodules/mlx}/mlx/backend/common/utils.h +0 -0
  99. /data/{mlx → submodules/mlx}/mlx/backend/cpu/CMakeLists.txt +0 -0
  100. /data/{mlx → submodules/mlx}/mlx/backend/cpu/arange.h +0 -0
  101. /data/{mlx → submodules/mlx}/mlx/backend/cpu/arg_reduce.cpp +0 -0
  102. /data/{mlx → submodules/mlx}/mlx/backend/cpu/binary.cpp +0 -0
  103. /data/{mlx → submodules/mlx}/mlx/backend/cpu/binary.h +0 -0
  104. /data/{mlx → submodules/mlx}/mlx/backend/cpu/binary_ops.h +0 -0
  105. /data/{mlx → submodules/mlx}/mlx/backend/cpu/binary_two.h +0 -0
  106. /data/{mlx → submodules/mlx}/mlx/backend/cpu/cholesky.cpp +0 -0
  107. /data/{mlx → submodules/mlx}/mlx/backend/cpu/compiled.cpp +0 -0
  108. /data/{mlx → submodules/mlx}/mlx/backend/cpu/compiled_preamble.h +0 -0
  109. /data/{mlx → submodules/mlx}/mlx/backend/cpu/conv.cpp +0 -0
  110. /data/{mlx → submodules/mlx}/mlx/backend/cpu/copy.cpp +0 -0
  111. /data/{mlx → submodules/mlx}/mlx/backend/cpu/copy.h +0 -0
  112. /data/{mlx → submodules/mlx}/mlx/backend/cpu/device_info.cpp +0 -0
  113. /data/{mlx → submodules/mlx}/mlx/backend/cpu/device_info.h +0 -0
  114. /data/{mlx → submodules/mlx}/mlx/backend/cpu/distributed.cpp +0 -0
  115. /data/{mlx → submodules/mlx}/mlx/backend/cpu/eig.cpp +0 -0
  116. /data/{mlx → submodules/mlx}/mlx/backend/cpu/eigh.cpp +0 -0
  117. /data/{mlx → submodules/mlx}/mlx/backend/cpu/encoder.cpp +0 -0
  118. /data/{mlx → submodules/mlx}/mlx/backend/cpu/encoder.h +0 -0
  119. /data/{mlx → submodules/mlx}/mlx/backend/cpu/eval.cpp +0 -0
  120. /data/{mlx → submodules/mlx}/mlx/backend/cpu/eval.h +0 -0
  121. /data/{mlx → submodules/mlx}/mlx/backend/cpu/fft.cpp +0 -0
  122. /data/{mlx → submodules/mlx}/mlx/backend/cpu/gemm.h +0 -0
  123. /data/{mlx → submodules/mlx}/mlx/backend/cpu/gemms/bnns.cpp +0 -0
  124. /data/{mlx → submodules/mlx}/mlx/backend/cpu/gemms/cblas.cpp +0 -0
  125. /data/{mlx → submodules/mlx}/mlx/backend/cpu/gemms/simd_bf16.cpp +0 -0
  126. /data/{mlx → submodules/mlx}/mlx/backend/cpu/gemms/simd_fp16.cpp +0 -0
  127. /data/{mlx → submodules/mlx}/mlx/backend/cpu/gemms/simd_gemm.h +0 -0
  128. /data/{mlx → submodules/mlx}/mlx/backend/cpu/hadamard.cpp +0 -0
  129. /data/{mlx → submodules/mlx}/mlx/backend/cpu/indexing.cpp +0 -0
  130. /data/{mlx → submodules/mlx}/mlx/backend/cpu/inverse.cpp +0 -0
  131. /data/{mlx → submodules/mlx}/mlx/backend/cpu/jit_compiler.cpp +0 -0
  132. /data/{mlx → submodules/mlx}/mlx/backend/cpu/jit_compiler.h +0 -0
  133. /data/{mlx → submodules/mlx}/mlx/backend/cpu/lapack.h +0 -0
  134. /data/{mlx → submodules/mlx}/mlx/backend/cpu/logsumexp.cpp +0 -0
  135. /data/{mlx → submodules/mlx}/mlx/backend/cpu/luf.cpp +0 -0
  136. /data/{mlx → submodules/mlx}/mlx/backend/cpu/make_compiled_preamble.ps1 +0 -0
  137. /data/{mlx → submodules/mlx}/mlx/backend/cpu/make_compiled_preamble.sh +0 -0
  138. /data/{mlx → submodules/mlx}/mlx/backend/cpu/masked_mm.cpp +0 -0
  139. /data/{mlx → submodules/mlx}/mlx/backend/cpu/matmul.cpp +0 -0
  140. /data/{mlx → submodules/mlx}/mlx/backend/cpu/primitives.cpp +0 -0
  141. /data/{mlx → submodules/mlx}/mlx/backend/cpu/qrf.cpp +0 -0
  142. /data/{mlx → submodules/mlx}/mlx/backend/cpu/quantized.cpp +0 -0
  143. /data/{mlx → submodules/mlx}/mlx/backend/cpu/reduce.cpp +0 -0
  144. /data/{mlx → submodules/mlx}/mlx/backend/cpu/scan.cpp +0 -0
  145. /data/{mlx → submodules/mlx}/mlx/backend/cpu/select.cpp +0 -0
  146. /data/{mlx → submodules/mlx}/mlx/backend/cpu/simd/accelerate_fp16_simd.h +0 -0
  147. /data/{mlx → submodules/mlx}/mlx/backend/cpu/simd/accelerate_simd.h +0 -0
  148. /data/{mlx → submodules/mlx}/mlx/backend/cpu/simd/base_simd.h +0 -0
  149. /data/{mlx → submodules/mlx}/mlx/backend/cpu/simd/math.h +0 -0
  150. /data/{mlx → submodules/mlx}/mlx/backend/cpu/simd/neon_fp16_simd.h +0 -0
  151. /data/{mlx → submodules/mlx}/mlx/backend/cpu/simd/simd.h +0 -0
  152. /data/{mlx → submodules/mlx}/mlx/backend/cpu/simd/type.h +0 -0
  153. /data/{mlx → submodules/mlx}/mlx/backend/cpu/slicing.h +0 -0
  154. /data/{mlx → submodules/mlx}/mlx/backend/cpu/softmax.cpp +0 -0
  155. /data/{mlx → submodules/mlx}/mlx/backend/cpu/sort.cpp +0 -0
  156. /data/{mlx → submodules/mlx}/mlx/backend/cpu/svd.cpp +0 -0
  157. /data/{mlx → submodules/mlx}/mlx/backend/cpu/ternary.h +0 -0
  158. /data/{mlx → submodules/mlx}/mlx/backend/cpu/threefry.cpp +0 -0
  159. /data/{mlx → submodules/mlx}/mlx/backend/cpu/threefry.h +0 -0
  160. /data/{mlx → submodules/mlx}/mlx/backend/cpu/unary.cpp +0 -0
  161. /data/{mlx → submodules/mlx}/mlx/backend/cpu/unary.h +0 -0
  162. /data/{mlx → submodules/mlx}/mlx/backend/cpu/unary_ops.h +0 -0
  163. /data/{mlx → submodules/mlx}/mlx/backend/cuda/CMakeLists.txt +0 -0
  164. /data/{mlx → submodules/mlx}/mlx/backend/cuda/allocator.cpp +0 -0
  165. /data/{mlx → submodules/mlx}/mlx/backend/cuda/allocator.h +0 -0
  166. /data/{mlx → submodules/mlx}/mlx/backend/cuda/arange.cu +0 -0
  167. /data/{mlx → submodules/mlx}/mlx/backend/cuda/arg_reduce.cu +0 -0
  168. /data/{mlx → submodules/mlx}/mlx/backend/cuda/bin2h.cmake +0 -0
  169. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/CMakeLists.txt +0 -0
  170. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/add.cu +0 -0
  171. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/arctan2.cu +0 -0
  172. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/binary.cuh +0 -0
  173. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/bitwise_binary.cu +0 -0
  174. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/divide.cu +0 -0
  175. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/equal.cu +0 -0
  176. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/greater.cu +0 -0
  177. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/greater_equal.cu +0 -0
  178. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/less.cu +0 -0
  179. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/less_equal.cu +0 -0
  180. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/log_add_exp.cu +0 -0
  181. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/logical_and.cu +0 -0
  182. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/logical_or.cu +0 -0
  183. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/maximum.cu +0 -0
  184. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/minimum.cu +0 -0
  185. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/multiply.cu +0 -0
  186. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/not_equal.cu +0 -0
  187. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/power.cu +0 -0
  188. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/remainder.cu +0 -0
  189. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/subtract.cu +0 -0
  190. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary_two.cu +0 -0
  191. /data/{mlx → submodules/mlx}/mlx/backend/cuda/compiled.cpp +0 -0
  192. /data/{mlx → submodules/mlx}/mlx/backend/cuda/conv/conv.h +0 -0
  193. /data/{mlx → submodules/mlx}/mlx/backend/cuda/conv/gemm_conv.cu +0 -0
  194. /data/{mlx → submodules/mlx}/mlx/backend/cuda/conv/gemm_grouped_conv.cu +0 -0
  195. /data/{mlx → submodules/mlx}/mlx/backend/cuda/conv.cpp +0 -0
  196. /data/{mlx → submodules/mlx}/mlx/backend/cuda/copy/copy.cuh +0 -0
  197. /data/{mlx → submodules/mlx}/mlx/backend/cuda/copy/copy_contiguous.cu +0 -0
  198. /data/{mlx → submodules/mlx}/mlx/backend/cuda/copy/copy_general.cu +0 -0
  199. /data/{mlx → submodules/mlx}/mlx/backend/cuda/copy/copy_general_dynamic.cu +0 -0
  200. /data/{mlx → submodules/mlx}/mlx/backend/cuda/copy/copy_general_input.cu +0 -0
  201. /data/{mlx → submodules/mlx}/mlx/backend/cuda/copy.cu +0 -0
  202. /data/{mlx → submodules/mlx}/mlx/backend/cuda/cublas_utils.h +0 -0
  203. /data/{mlx → submodules/mlx}/mlx/backend/cuda/cuda.h +0 -0
  204. /data/{mlx → submodules/mlx}/mlx/backend/cuda/cuda_utils.h +0 -0
  205. /data/{mlx → submodules/mlx}/mlx/backend/cuda/cudnn_utils.cpp +0 -0
  206. /data/{mlx → submodules/mlx}/mlx/backend/cuda/cudnn_utils.h +0 -0
  207. /data/{mlx → submodules/mlx}/mlx/backend/cuda/custom_kernel.cpp +0 -0
  208. /data/{mlx → submodules/mlx}/mlx/backend/cuda/cutlass_utils.cuh +0 -0
  209. /data/{mlx → submodules/mlx}/mlx/backend/cuda/delayload.cpp +0 -0
  210. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/atomic_ops.cuh +0 -0
  211. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/binary_ops.cuh +0 -0
  212. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/cast_op.cuh +0 -0
  213. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/complex.cuh +0 -0
  214. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/config.h +0 -0
  215. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/fp16_math.cuh +0 -0
  216. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/gather.cuh +0 -0
  217. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/gather_axis.cuh +0 -0
  218. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/indexing.cuh +0 -0
  219. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/scatter.cuh +0 -0
  220. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/scatter_axis.cuh +0 -0
  221. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/scatter_ops.cuh +0 -0
  222. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/ternary_ops.cuh +0 -0
  223. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/unary_ops.cuh +0 -0
  224. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/utils.cuh +0 -0
  225. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device.cpp +0 -0
  226. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device.h +0 -0
  227. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device_info.cpp +0 -0
  228. /data/{mlx → submodules/mlx}/mlx/backend/cuda/distributed.cu +0 -0
  229. /data/{mlx → submodules/mlx}/mlx/backend/cuda/eval.cpp +0 -0
  230. /data/{mlx → submodules/mlx}/mlx/backend/cuda/event.cu +0 -0
  231. /data/{mlx → submodules/mlx}/mlx/backend/cuda/event.h +0 -0
  232. /data/{mlx → submodules/mlx}/mlx/backend/cuda/fence.cpp +0 -0
  233. /data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/cublas_gemm.h +0 -0
  234. /data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/cublas_gemm_batched_12_0.cpp +0 -0
  235. /data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/cublas_gemm_batched_12_9.cu +0 -0
  236. /data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/gemv.cu +0 -0
  237. /data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/gemv.h +0 -0
  238. /data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/grouped_gemm.h +0 -0
  239. /data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/grouped_gemm_unaligned.cu +0 -0
  240. /data/{mlx → submodules/mlx}/mlx/backend/cuda/indexing.cpp +0 -0
  241. /data/{mlx → submodules/mlx}/mlx/backend/cuda/jit_module.cpp +0 -0
  242. /data/{mlx → submodules/mlx}/mlx/backend/cuda/jit_module.h +0 -0
  243. /data/{mlx → submodules/mlx}/mlx/backend/cuda/kernel_utils.cu +0 -0
  244. /data/{mlx → submodules/mlx}/mlx/backend/cuda/kernel_utils.cuh +0 -0
  245. /data/{mlx → submodules/mlx}/mlx/backend/cuda/layer_norm.cu +0 -0
  246. /data/{mlx → submodules/mlx}/mlx/backend/cuda/load.cpp +0 -0
  247. /data/{mlx → submodules/mlx}/mlx/backend/cuda/logsumexp.cu +0 -0
  248. /data/{mlx → submodules/mlx}/mlx/backend/cuda/lru_cache.h +0 -0
  249. /data/{mlx → submodules/mlx}/mlx/backend/cuda/matmul.cpp +0 -0
  250. /data/{mlx → submodules/mlx}/mlx/backend/cuda/no_cuda.cpp +0 -0
  251. /data/{mlx → submodules/mlx}/mlx/backend/cuda/primitives.cpp +0 -0
  252. /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/affine_quantize.cu +0 -0
  253. /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/convert_fp8.cu +0 -0
  254. /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/cuda_fp4.h +0 -0
  255. /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/mxfp8_quantize.cuh +0 -0
  256. /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/nvfp4_quantize.cuh +0 -0
  257. /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/qmv.cu +0 -0
  258. /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/qmv.h +0 -0
  259. /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/quantized_utils.h +0 -0
  260. /data/{mlx → submodules/mlx}/mlx/backend/cuda/random.cu +0 -0
  261. /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce/all_reduce.cu +0 -0
  262. /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce/col_reduce.cu +0 -0
  263. /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce/init_reduce.cu +0 -0
  264. /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce/reduce.cuh +0 -0
  265. /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce/reduce_ops.cuh +0 -0
  266. /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce/reduce_utils.cuh +0 -0
  267. /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce/row_reduce.cu +0 -0
  268. /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce.cu +0 -0
  269. /data/{mlx → submodules/mlx}/mlx/backend/cuda/rms_norm.cu +0 -0
  270. /data/{mlx → submodules/mlx}/mlx/backend/cuda/rope.cu +0 -0
  271. /data/{mlx → submodules/mlx}/mlx/backend/cuda/scaled_dot_product_attention.cpp +0 -0
  272. /data/{mlx → submodules/mlx}/mlx/backend/cuda/scaled_dot_product_attention.cu +0 -0
  273. /data/{mlx → submodules/mlx}/mlx/backend/cuda/scan.cu +0 -0
  274. /data/{mlx → submodules/mlx}/mlx/backend/cuda/slicing.cpp +0 -0
  275. /data/{mlx → submodules/mlx}/mlx/backend/cuda/softmax.cu +0 -0
  276. /data/{mlx → submodules/mlx}/mlx/backend/cuda/sort.cu +0 -0
  277. /data/{mlx → submodules/mlx}/mlx/backend/cuda/steel/defines.cuh +0 -0
  278. /data/{mlx → submodules/mlx}/mlx/backend/cuda/steel/gemm.cuh +0 -0
  279. /data/{mlx → submodules/mlx}/mlx/backend/cuda/steel/mma.cuh +0 -0
  280. /data/{mlx → submodules/mlx}/mlx/backend/cuda/steel/tiles.cuh +0 -0
  281. /data/{mlx → submodules/mlx}/mlx/backend/cuda/steel/utils.cuh +0 -0
  282. /data/{mlx → submodules/mlx}/mlx/backend/cuda/ternary.cu +0 -0
  283. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/CMakeLists.txt +0 -0
  284. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/abs.cu +0 -0
  285. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/arccos.cu +0 -0
  286. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/arccosh.cu +0 -0
  287. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/arcsin.cu +0 -0
  288. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/arcsinh.cu +0 -0
  289. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/arctan.cu +0 -0
  290. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/arctanh.cu +0 -0
  291. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/bitwise_invert.cu +0 -0
  292. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/ceil.cu +0 -0
  293. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/conjugate.cu +0 -0
  294. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/cos.cu +0 -0
  295. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/cosh.cu +0 -0
  296. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/erf.cu +0 -0
  297. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/erf_inv.cu +0 -0
  298. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/exp.cu +0 -0
  299. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/expm1.cu +0 -0
  300. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/floor.cu +0 -0
  301. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/imag.cu +0 -0
  302. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/log.cu +0 -0
  303. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/log1p.cu +0 -0
  304. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/logical_not.cu +0 -0
  305. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/negative.cu +0 -0
  306. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/real.cu +0 -0
  307. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/round.cu +0 -0
  308. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/sigmoid.cu +0 -0
  309. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/sign.cu +0 -0
  310. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/sin.cu +0 -0
  311. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/sinh.cu +0 -0
  312. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/sqrt.cu +0 -0
  313. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/square.cu +0 -0
  314. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/tan.cu +0 -0
  315. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/tanh.cu +0 -0
  316. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/unary.cuh +0 -0
  317. /data/{mlx → submodules/mlx}/mlx/backend/cuda/utils.cpp +0 -0
  318. /data/{mlx → submodules/mlx}/mlx/backend/cuda/utils.h +0 -0
  319. /data/{mlx → submodules/mlx}/mlx/backend/cuda/vector_types.cuh +0 -0
  320. /data/{mlx → submodules/mlx}/mlx/backend/cuda/worker.cpp +0 -0
  321. /data/{mlx → submodules/mlx}/mlx/backend/cuda/worker.h +0 -0
  322. /data/{mlx → submodules/mlx}/mlx/backend/gpu/CMakeLists.txt +0 -0
  323. /data/{mlx → submodules/mlx}/mlx/backend/gpu/copy.cpp +0 -0
  324. /data/{mlx → submodules/mlx}/mlx/backend/gpu/copy.h +0 -0
  325. /data/{mlx → submodules/mlx}/mlx/backend/gpu/device_info.h +0 -0
  326. /data/{mlx → submodules/mlx}/mlx/backend/gpu/eval.h +0 -0
  327. /data/{mlx → submodules/mlx}/mlx/backend/gpu/primitives.cpp +0 -0
  328. /data/{mlx → submodules/mlx}/mlx/backend/gpu/slicing.cpp +0 -0
  329. /data/{mlx → submodules/mlx}/mlx/backend/gpu/slicing.h +0 -0
  330. /data/{mlx → submodules/mlx}/mlx/backend/metal/CMakeLists.txt +0 -0
  331. /data/{mlx → submodules/mlx}/mlx/backend/metal/allocator.cpp +0 -0
  332. /data/{mlx → submodules/mlx}/mlx/backend/metal/allocator.h +0 -0
  333. /data/{mlx → submodules/mlx}/mlx/backend/metal/binary.cpp +0 -0
  334. /data/{mlx → submodules/mlx}/mlx/backend/metal/binary.h +0 -0
  335. /data/{mlx → submodules/mlx}/mlx/backend/metal/compiled.cpp +0 -0
  336. /data/{mlx → submodules/mlx}/mlx/backend/metal/conv.cpp +0 -0
  337. /data/{mlx → submodules/mlx}/mlx/backend/metal/copy.cpp +0 -0
  338. /data/{mlx → submodules/mlx}/mlx/backend/metal/custom_kernel.cpp +0 -0
  339. /data/{mlx → submodules/mlx}/mlx/backend/metal/device.h +0 -0
  340. /data/{mlx → submodules/mlx}/mlx/backend/metal/device_info.cpp +0 -0
  341. /data/{mlx → submodules/mlx}/mlx/backend/metal/distributed.cpp +0 -0
  342. /data/{mlx → submodules/mlx}/mlx/backend/metal/eval.cpp +0 -0
  343. /data/{mlx → submodules/mlx}/mlx/backend/metal/event.cpp +0 -0
  344. /data/{mlx → submodules/mlx}/mlx/backend/metal/fence.cpp +0 -0
  345. /data/{mlx → submodules/mlx}/mlx/backend/metal/fft.cpp +0 -0
  346. /data/{mlx → submodules/mlx}/mlx/backend/metal/hadamard.cpp +0 -0
  347. /data/{mlx → submodules/mlx}/mlx/backend/metal/indexing.cpp +0 -0
  348. /data/{mlx → submodules/mlx}/mlx/backend/metal/jit/includes.h +0 -0
  349. /data/{mlx → submodules/mlx}/mlx/backend/metal/jit/indexing.h +0 -0
  350. /data/{mlx → submodules/mlx}/mlx/backend/metal/jit_kernels.cpp +0 -0
  351. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/CMakeLists.txt +0 -0
  352. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/arange.h +0 -0
  353. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/arange.metal +0 -0
  354. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/arg_reduce.metal +0 -0
  355. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/atomic.h +0 -0
  356. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/bf16.h +0 -0
  357. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/bf16_math.h +0 -0
  358. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/binary.h +0 -0
  359. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/binary.metal +0 -0
  360. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/binary_ops.h +0 -0
  361. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/binary_two.h +0 -0
  362. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/binary_two.metal +0 -0
  363. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/cexpf.h +0 -0
  364. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/complex.h +0 -0
  365. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/copy.h +0 -0
  366. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/copy.metal +0 -0
  367. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/defines.h +0 -0
  368. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/erf.h +0 -0
  369. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/expm1f.h +0 -0
  370. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fence.metal +0 -0
  371. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fft/radix.h +0 -0
  372. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fft/readwrite.h +0 -0
  373. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fft.h +0 -0
  374. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fft.metal +0 -0
  375. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fp4.h +0 -0
  376. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fp8.h +0 -0
  377. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fp_quantized.h +0 -0
  378. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fp_quantized.metal +0 -0
  379. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fp_quantized_nax.h +0 -0
  380. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fp_quantized_nax.metal +0 -0
  381. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/gemv.metal +0 -0
  382. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/gemv_masked.h +0 -0
  383. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/gemv_masked.metal +0 -0
  384. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/hadamard.h +0 -0
  385. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/indexing/gather.h +0 -0
  386. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/indexing/gather_axis.h +0 -0
  387. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/indexing/gather_front.h +0 -0
  388. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/indexing/indexing.h +0 -0
  389. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/indexing/masked_scatter.h +0 -0
  390. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/indexing/scatter.h +0 -0
  391. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/indexing/scatter_axis.h +0 -0
  392. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/layer_norm.metal +0 -0
  393. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/logging.h +0 -0
  394. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/logsumexp.h +0 -0
  395. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/logsumexp.metal +0 -0
  396. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/quantized.h +0 -0
  397. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/quantized.metal +0 -0
  398. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/quantized_nax.h +0 -0
  399. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/quantized_nax.metal +0 -0
  400. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/quantized_utils.h +0 -0
  401. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/random.metal +0 -0
  402. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduce.h +0 -0
  403. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduce.metal +0 -0
  404. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduce_utils.h +0 -0
  405. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduction/ops.h +0 -0
  406. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduction/reduce_all.h +0 -0
  407. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduction/reduce_col.h +0 -0
  408. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduction/reduce_init.h +0 -0
  409. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduction/reduce_row.h +0 -0
  410. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/rms_norm.metal +0 -0
  411. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/rope.metal +0 -0
  412. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/scaled_dot_product_attention.metal +0 -0
  413. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/scan.h +0 -0
  414. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/scan.metal +0 -0
  415. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/sdpa_vector.h +0 -0
  416. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/softmax.h +0 -0
  417. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/softmax.metal +0 -0
  418. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/sort.h +0 -0
  419. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/sort.metal +0 -0
  420. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/attn.h +0 -0
  421. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +0 -0
  422. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal +0 -0
  423. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +0 -0
  424. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.metal +0 -0
  425. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/loader.h +0 -0
  426. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/mma.h +0 -0
  427. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/nax.h +0 -0
  428. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/params.h +0 -0
  429. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/transforms.h +0 -0
  430. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/conv.h +0 -0
  431. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +0 -0
  432. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.metal +0 -0
  433. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +0 -0
  434. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.metal +0 -0
  435. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/loader.h +0 -0
  436. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +0 -0
  437. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +0 -0
  438. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +0 -0
  439. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/params.h +0 -0
  440. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/defines.h +0 -0
  441. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/gemm.h +0 -0
  442. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +0 -0
  443. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +0 -0
  444. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.metal +0 -0
  445. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +0 -0
  446. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.metal +0 -0
  447. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +0 -0
  448. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.metal +0 -0
  449. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +0 -0
  450. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.metal +0 -0
  451. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +0 -0
  452. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.metal +0 -0
  453. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +0 -0
  454. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.metal +0 -0
  455. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +0 -0
  456. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.metal +0 -0
  457. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.h +0 -0
  458. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.metal +0 -0
  459. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/loader.h +0 -0
  460. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/mma.h +0 -0
  461. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/nax.h +0 -0
  462. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/params.h +0 -0
  463. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/transforms.h +0 -0
  464. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/utils/integral_constant.h +0 -0
  465. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/utils/type_traits.h +0 -0
  466. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/utils.h +0 -0
  467. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/ternary.h +0 -0
  468. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/ternary.metal +0 -0
  469. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/ternary_ops.h +0 -0
  470. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/unary.h +0 -0
  471. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/unary.metal +0 -0
  472. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/unary_ops.h +0 -0
  473. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/utils.h +0 -0
  474. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels.h +0 -0
  475. /data/{mlx → submodules/mlx}/mlx/backend/metal/logsumexp.cpp +0 -0
  476. /data/{mlx → submodules/mlx}/mlx/backend/metal/make_compiled_preamble.sh +0 -0
  477. /data/{mlx → submodules/mlx}/mlx/backend/metal/matmul.cpp +0 -0
  478. /data/{mlx → submodules/mlx}/mlx/backend/metal/matmul.h +0 -0
  479. /data/{mlx → submodules/mlx}/mlx/backend/metal/metal.cpp +0 -0
  480. /data/{mlx → submodules/mlx}/mlx/backend/metal/metal.h +0 -0
  481. /data/{mlx → submodules/mlx}/mlx/backend/metal/no_metal.cpp +0 -0
  482. /data/{mlx → submodules/mlx}/mlx/backend/metal/nojit_kernels.cpp +0 -0
  483. /data/{mlx → submodules/mlx}/mlx/backend/metal/normalization.cpp +0 -0
  484. /data/{mlx → submodules/mlx}/mlx/backend/metal/primitives.cpp +0 -0
  485. /data/{mlx → submodules/mlx}/mlx/backend/metal/quantized.cpp +0 -0
  486. /data/{mlx → submodules/mlx}/mlx/backend/metal/reduce.cpp +0 -0
  487. /data/{mlx → submodules/mlx}/mlx/backend/metal/reduce.h +0 -0
  488. /data/{mlx → submodules/mlx}/mlx/backend/metal/resident.cpp +0 -0
  489. /data/{mlx → submodules/mlx}/mlx/backend/metal/resident.h +0 -0
  490. /data/{mlx → submodules/mlx}/mlx/backend/metal/rope.cpp +0 -0
  491. /data/{mlx → submodules/mlx}/mlx/backend/metal/scaled_dot_product_attention.cpp +0 -0
  492. /data/{mlx → submodules/mlx}/mlx/backend/metal/scan.cpp +0 -0
  493. /data/{mlx → submodules/mlx}/mlx/backend/metal/scan.h +0 -0
  494. /data/{mlx → submodules/mlx}/mlx/backend/metal/slicing.cpp +0 -0
  495. /data/{mlx → submodules/mlx}/mlx/backend/metal/softmax.cpp +0 -0
  496. /data/{mlx → submodules/mlx}/mlx/backend/metal/sort.cpp +0 -0
  497. /data/{mlx → submodules/mlx}/mlx/backend/metal/ternary.cpp +0 -0
  498. /data/{mlx → submodules/mlx}/mlx/backend/metal/ternary.h +0 -0
  499. /data/{mlx → submodules/mlx}/mlx/backend/metal/unary.cpp +0 -0
  500. /data/{mlx → submodules/mlx}/mlx/backend/metal/unary.h +0 -0
  501. /data/{mlx → submodules/mlx}/mlx/backend/metal/utils.cpp +0 -0
  502. /data/{mlx → submodules/mlx}/mlx/backend/metal/utils.h +0 -0
  503. /data/{mlx → submodules/mlx}/mlx/backend/no_cpu/CMakeLists.txt +0 -0
  504. /data/{mlx → submodules/mlx}/mlx/backend/no_cpu/compiled.cpp +0 -0
  505. /data/{mlx → submodules/mlx}/mlx/backend/no_cpu/device_info.cpp +0 -0
  506. /data/{mlx → submodules/mlx}/mlx/backend/no_cpu/primitives.cpp +0 -0
  507. /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/CMakeLists.txt +0 -0
  508. /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/allocator.cpp +0 -0
  509. /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/apple_memory.h +0 -0
  510. /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/device_info.cpp +0 -0
  511. /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/eval.cpp +0 -0
  512. /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/event.cpp +0 -0
  513. /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/fence.cpp +0 -0
  514. /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/linux_memory.h +0 -0
  515. /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/primitives.cpp +0 -0
  516. /data/{mlx → submodules/mlx}/mlx/compile.cpp +0 -0
  517. /data/{mlx → submodules/mlx}/mlx/compile.h +0 -0
  518. /data/{mlx → submodules/mlx}/mlx/compile_impl.h +0 -0
  519. /data/{mlx → submodules/mlx}/mlx/device.cpp +0 -0
  520. /data/{mlx → submodules/mlx}/mlx/device.h +0 -0
  521. /data/{mlx → submodules/mlx}/mlx/distributed/CMakeLists.txt +0 -0
  522. /data/{mlx → submodules/mlx}/mlx/distributed/distributed.cpp +0 -0
  523. /data/{mlx → submodules/mlx}/mlx/distributed/distributed.h +0 -0
  524. /data/{mlx → submodules/mlx}/mlx/distributed/distributed_impl.h +0 -0
  525. /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/CMakeLists.txt +0 -0
  526. /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/jaccl.cpp +0 -0
  527. /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/jaccl.h +0 -0
  528. /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/mesh.cpp +0 -0
  529. /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/mesh.h +0 -0
  530. /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/no_jaccl.cpp +0 -0
  531. /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/ring.cpp +0 -0
  532. /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/ring.h +0 -0
  533. /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/utils.cpp +0 -0
  534. /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/utils.h +0 -0
  535. /data/{mlx → submodules/mlx}/mlx/distributed/mpi/CMakeLists.txt +0 -0
  536. /data/{mlx → submodules/mlx}/mlx/distributed/mpi/mpi.cpp +0 -0
  537. /data/{mlx → submodules/mlx}/mlx/distributed/mpi/mpi.h +0 -0
  538. /data/{mlx → submodules/mlx}/mlx/distributed/mpi/mpi_declarations.h +0 -0
  539. /data/{mlx → submodules/mlx}/mlx/distributed/mpi/no_mpi.cpp +0 -0
  540. /data/{mlx → submodules/mlx}/mlx/distributed/nccl/CMakeLists.txt +0 -0
  541. /data/{mlx → submodules/mlx}/mlx/distributed/nccl/nccl.cpp +0 -0
  542. /data/{mlx → submodules/mlx}/mlx/distributed/nccl/nccl.h +0 -0
  543. /data/{mlx → submodules/mlx}/mlx/distributed/nccl/nccl_stub/CMakeLists.txt +0 -0
  544. /data/{mlx → submodules/mlx}/mlx/distributed/nccl/nccl_stub/nccl_stubs.cpp +0 -0
  545. /data/{mlx → submodules/mlx}/mlx/distributed/nccl/no_nccl.cpp +0 -0
  546. /data/{mlx → submodules/mlx}/mlx/distributed/ops.cpp +0 -0
  547. /data/{mlx → submodules/mlx}/mlx/distributed/ops.h +0 -0
  548. /data/{mlx → submodules/mlx}/mlx/distributed/primitives.cpp +0 -0
  549. /data/{mlx → submodules/mlx}/mlx/distributed/primitives.h +0 -0
  550. /data/{mlx → submodules/mlx}/mlx/distributed/reduction_ops.h +0 -0
  551. /data/{mlx → submodules/mlx}/mlx/distributed/ring/CMakeLists.txt +0 -0
  552. /data/{mlx → submodules/mlx}/mlx/distributed/ring/no_ring.cpp +0 -0
  553. /data/{mlx → submodules/mlx}/mlx/distributed/ring/ring.cpp +0 -0
  554. /data/{mlx → submodules/mlx}/mlx/distributed/ring/ring.h +0 -0
  555. /data/{mlx → submodules/mlx}/mlx/distributed/utils.cpp +0 -0
  556. /data/{mlx → submodules/mlx}/mlx/distributed/utils.h +0 -0
  557. /data/{mlx → submodules/mlx}/mlx/dtype.cpp +0 -0
  558. /data/{mlx → submodules/mlx}/mlx/dtype.h +0 -0
  559. /data/{mlx → submodules/mlx}/mlx/dtype_utils.cpp +0 -0
  560. /data/{mlx → submodules/mlx}/mlx/dtype_utils.h +0 -0
  561. /data/{mlx → submodules/mlx}/mlx/einsum.cpp +0 -0
  562. /data/{mlx → submodules/mlx}/mlx/einsum.h +0 -0
  563. /data/{mlx → submodules/mlx}/mlx/event.h +0 -0
  564. /data/{mlx → submodules/mlx}/mlx/export.h +0 -0
  565. /data/{mlx → submodules/mlx}/mlx/export_impl.h +0 -0
  566. /data/{mlx → submodules/mlx}/mlx/fast.cpp +0 -0
  567. /data/{mlx → submodules/mlx}/mlx/fast.h +0 -0
  568. /data/{mlx → submodules/mlx}/mlx/fast_primitives.h +0 -0
  569. /data/{mlx → submodules/mlx}/mlx/fence.h +0 -0
  570. /data/{mlx → submodules/mlx}/mlx/fft.cpp +0 -0
  571. /data/{mlx → submodules/mlx}/mlx/fft.h +0 -0
  572. /data/{mlx → submodules/mlx}/mlx/graph_utils.cpp +0 -0
  573. /data/{mlx → submodules/mlx}/mlx/graph_utils.h +0 -0
  574. /data/{mlx → submodules/mlx}/mlx/io/CMakeLists.txt +0 -0
  575. /data/{mlx → submodules/mlx}/mlx/io/gguf.cpp +0 -0
  576. /data/{mlx → submodules/mlx}/mlx/io/gguf.h +0 -0
  577. /data/{mlx → submodules/mlx}/mlx/io/gguf_quants.cpp +0 -0
  578. /data/{mlx → submodules/mlx}/mlx/io/load.cpp +0 -0
  579. /data/{mlx → submodules/mlx}/mlx/io/load.h +0 -0
  580. /data/{mlx → submodules/mlx}/mlx/io/no_gguf.cpp +0 -0
  581. /data/{mlx → submodules/mlx}/mlx/io/no_safetensors.cpp +0 -0
  582. /data/{mlx → submodules/mlx}/mlx/io/safetensors.cpp +0 -0
  583. /data/{mlx → submodules/mlx}/mlx/io.h +0 -0
  584. /data/{mlx → submodules/mlx}/mlx/linalg.cpp +0 -0
  585. /data/{mlx → submodules/mlx}/mlx/linalg.h +0 -0
  586. /data/{mlx → submodules/mlx}/mlx/memory.h +0 -0
  587. /data/{mlx → submodules/mlx}/mlx/mlx.h +0 -0
  588. /data/{mlx → submodules/mlx}/mlx/primitives.h +0 -0
  589. /data/{mlx → submodules/mlx}/mlx/random.cpp +0 -0
  590. /data/{mlx → submodules/mlx}/mlx/random.h +0 -0
  591. /data/{mlx → submodules/mlx}/mlx/small_vector.h +0 -0
  592. /data/{mlx → submodules/mlx}/mlx/threadpool.h +0 -0
  593. /data/{mlx → submodules/mlx}/mlx/transforms.cpp +0 -0
  594. /data/{mlx → submodules/mlx}/mlx/transforms.h +0 -0
  595. /data/{mlx → submodules/mlx}/mlx/transforms_impl.h +0 -0
  596. /data/{mlx → submodules/mlx}/mlx/types/bf16.h +0 -0
  597. /data/{mlx → submodules/mlx}/mlx/types/complex.h +0 -0
  598. /data/{mlx → submodules/mlx}/mlx/types/fp16.h +0 -0
  599. /data/{mlx → submodules/mlx}/mlx/types/half_types.h +0 -0
  600. /data/{mlx → submodules/mlx}/mlx/types/limits.h +0 -0
  601. /data/{mlx → submodules/mlx}/mlx/utils.cpp +0 -0
  602. /data/{mlx → submodules/mlx}/mlx/utils.h +0 -0
  603. /data/{mlx → submodules/mlx}/mlx/version.cpp +0 -0
  604. /data/{mlx → submodules/mlx}/mlx/version.h +0 -0
  605. /data/{mlx → submodules/mlx}/mlx.pc.in +0 -0
@@ -0,0 +1,1402 @@
1
+ #include "native.hpp"
2
+
3
+ #include <algorithm>
4
+ #include <array>
5
+ #include <bit>
6
+ #include <chrono>
7
+ #include <cctype>
8
+ #include <cmath>
9
+ #include <cstring>
10
+ #include <cstdio>
11
+ #include <cstdint>
12
+ #include <cerrno>
13
+ #include <cstdlib>
14
+ #include <exception>
15
+ #include <filesystem>
16
+ #include <functional>
17
+ #include <fstream>
18
+ #include <iomanip>
19
+ #include <limits>
20
+ #include <map>
21
+ #include <numeric>
22
+ #include <optional>
23
+ #include <set>
24
+ #include <sstream>
25
+ #include <stdexcept>
26
+ #include <string>
27
+ #include <string_view>
28
+ #include <tuple>
29
+ #include <unordered_map>
30
+ #include <unordered_set>
31
+ #include <utility>
32
+ #include <variant>
33
+ #include <vector>
34
+
35
+ #ifdef snprintf
36
+ #undef snprintf
37
+ #endif
38
+
39
+ #include <nlohmann/json.hpp>
40
+
41
+ #include "json.hpp"
42
+ #include "mlx/export.h"
43
+ #include "mlx/ops.h"
44
+
45
+ namespace mx = mlx::core;
46
+
47
+ using OrderedJson = nlohmann::ordered_json;
48
+
49
+ namespace {
50
+
51
+ // Ruby binding/front-end for IR capture and argument/source normalization.
52
+ //
53
+ // Heavyweight ONNX lowering, compatibility probing, and protobuf encoding live
54
+ // in ir_core.{hpp,cpp}. This file keeps Ruby VALUE conversion, tracing
55
+ // invocation decoding, and exception translation.
56
+
57
+ // ============================================================================
58
+ // Section: Binding State and Capture Types
59
+ // ============================================================================
60
+
61
+ static VALUE mONNX;
62
+ static VALUE mONNXNative;
63
+ static VALUE eOnnxNativeUnsupportedError = Qnil;
64
+
65
+ constexpr int64_t kGraphIrVersion = mlx::onnx::kGraphIrVersion;
66
+
67
+ using GraphTensorInfo = std::tuple<std::string, mx::Shape, mx::Dtype>;
68
+
69
+ // Decoded Ruby invocation used by export entry points.
70
+ struct GraphIrExportInvocation {
71
+ VALUE fun;
72
+ mx::Args args;
73
+ mx::Kwargs kwargs;
74
+ bool shapeless;
75
+ };
76
+
77
// Per-phase timing and size counters collected during IR export; every field
// starts at zero so partially-populated stats are still printable.
struct GraphIrExportTimingStats {
  // Wall time spent inside the tracing export call, in milliseconds.
  double export_function_ms{0.0};
  // Wall time spent converting captured constants to JSON, in milliseconds.
  double constants_capture_ms{0.0};
  // Number of captured constants.
  size_t constants_count{0};
  // Total element count across captured constants.
  size_t constant_elements{0};
};
83
+
84
+ // ============================================================================
85
+ // Section: Timing and Diagnostics Helpers
86
+ // ============================================================================
87
+
88
// Reports whether native timing lines should be written to stderr.
//
// Controlled by the MLX_IR_NATIVE_TIMING environment variable:
//   - unset or empty              -> disabled
//   - "0", "false", "off", "no"   -> disabled (case-insensitive)
//   - any other value             -> enabled
static bool onnx_native_timing_enabled() {
  const char* setting = std::getenv("MLX_IR_NATIVE_TIMING");
  if (setting == nullptr || *setting == '\0') {
    return false;
  }

  // Lower-case through unsigned char so std::tolower never sees a negative.
  std::string lowered;
  for (const char* cursor = setting; *cursor != '\0'; ++cursor) {
    lowered.push_back(
        static_cast<char>(std::tolower(static_cast<unsigned char>(*cursor))));
  }

  for (const char* disabled_word : {"0", "false", "off", "no"}) {
    if (lowered == disabled_word) {
      return false;
    }
  }
  return true;
}
107
+
108
// Fractional milliseconds elapsed on the monotonic steady clock since
// `started_at`.
static double elapsed_millis(std::chrono::steady_clock::time_point started_at) {
  using FractionalMillis = std::chrono::duration<double, std::milli>;
  const FractionalMillis elapsed = std::chrono::steady_clock::now() - started_at;
  return elapsed.count();
}
112
+
113
// Writes one timing line (plus a trailing newline) to stderr and flushes
// immediately, so lines are not lost or interleaved late if the process dies.
static void emit_onnx_native_timing_line(const std::string& line) {
  std::fputs(line.c_str(), stderr);
  std::fputc('\n', stderr);
  std::fflush(stderr);
}
117
+
118
+ static void emit_export_onnx_json_timing_line(
119
+ const GraphIrExportInvocation& invocation,
120
+ int64_t opset,
121
+ const std::string& model_name,
122
+ const GraphIrExportTimingStats& export_stats,
123
+ double args_decode_ms,
124
+ double export_ir_ms,
125
+ double lower_onnx_ms,
126
+ double dump_json_ms,
127
+ double total_ms,
128
+ size_t onnx_json_bytes) {
129
+ // Key=value single-line logs are easier to parse in CI logs.
130
+ std::ostringstream out;
131
+ out << std::fixed << std::setprecision(3);
132
+ out << "[mlx.onnx.native.timing] export_onnx_json";
133
+ out << " total_ms=" << total_ms;
134
+ out << " args_decode_ms=" << args_decode_ms;
135
+ out << " export_ir_ms=" << export_ir_ms;
136
+ out << " trace_export_ms=" << export_stats.export_function_ms;
137
+ out << " constants_capture_ms=" << export_stats.constants_capture_ms;
138
+ out << " constants_count=" << export_stats.constants_count;
139
+ out << " constant_elements=" << export_stats.constant_elements;
140
+ out << " lower_onnx_ms=" << lower_onnx_ms;
141
+ out << " json_dump_ms=" << dump_json_ms;
142
+ out << " onnx_json_bytes=" << onnx_json_bytes;
143
+ out << " shapeless=" << (invocation.shapeless ? "true" : "false");
144
+ out << " opset=" << opset;
145
+ out << " model_name=" << model_name;
146
+ emit_onnx_native_timing_line(out.str());
147
+ }
148
+
149
+ static void emit_graph_ir_to_onnx_json_timing_line(
150
+ int64_t opset,
151
+ const std::string& model_name,
152
+ double parse_json_ms,
153
+ double lower_onnx_ms,
154
+ double dump_json_ms,
155
+ double total_ms,
156
+ size_t onnx_json_bytes) {
157
+ std::ostringstream out;
158
+ out << std::fixed << std::setprecision(3);
159
+ out << "[mlx.onnx.native.timing] graph_ir_to_onnx_json";
160
+ out << " total_ms=" << total_ms;
161
+ out << " parse_json_ms=" << parse_json_ms;
162
+ out << " lower_onnx_ms=" << lower_onnx_ms;
163
+ out << " json_dump_ms=" << dump_json_ms;
164
+ out << " onnx_json_bytes=" << onnx_json_bytes;
165
+ out << " opset=" << opset;
166
+ out << " model_name=" << model_name;
167
+ emit_onnx_native_timing_line(out.str());
168
+ }
169
+
170
+ static std::string dtype_to_string(mx::Dtype dtype) {
171
+ std::ostringstream out;
172
+ out << dtype;
173
+ return out.str();
174
+ }
175
+
176
+ static void ruby_hash_set_cstr(VALUE hash, const char* key, VALUE value) {
177
+ rb_hash_aset(hash, rb_str_new_cstr(key), value);
178
+ }
179
+
180
+ // ============================================================================
181
+ // Section: IR Export Capture and Trace Conversion
182
+ // ============================================================================
183
+
184
+ [[noreturn]] static void raise_onnx_native_exception(const std::exception& error);
185
+
186
+ template <typename ValueAt>
187
+ static OrderedJson capture_build_nested_json_array(
188
+ const mx::Shape& shape,
189
+ size_t dim,
190
+ size_t& flat_index,
191
+ ValueAt value_at) {
192
+ // Rebuild N-D JSON nesting from a flat buffer cursor.
193
+ if (dim == shape.size()) {
194
+ return value_at(flat_index++);
195
+ }
196
+
197
+ OrderedJson out = OrderedJson::array();
198
+ for (size_t i = 0; i < shape[dim]; ++i) {
199
+ out.push_back(capture_build_nested_json_array(shape, dim + 1, flat_index, value_at));
200
+ }
201
+ return out;
202
+ }
203
+
204
+ template <typename ValueAt>
205
+ static OrderedJson capture_build_flat_json_array(size_t size, ValueAt value_at) {
206
+ OrderedJson out = OrderedJson::array();
207
+ for (size_t i = 0; i < size; ++i) {
208
+ out.push_back(value_at(i));
209
+ }
210
+ return out;
211
+ }
212
+
213
+ static OrderedJson capture_json_shape_from_mx_shape(const mx::Shape& shape) {
214
+ OrderedJson out = OrderedJson::array();
215
+ for (size_t dim : shape) {
216
+ out.push_back(dim);
217
+ }
218
+ return out;
219
+ }
220
+
221
+ static OrderedJson capture_json_scalar_from_array(const mx::array& array) {
222
+ // Scalar conversion preserves integer types and normalizes floats to JSON
223
+ // numbers (double) for stable serialization.
224
+ switch (array.dtype()) {
225
+ case mx::bool_:
226
+ return OrderedJson(array.item<bool>());
227
+ case mx::uint8:
228
+ return OrderedJson(array.item<uint8_t>());
229
+ case mx::uint16:
230
+ return OrderedJson(array.item<uint16_t>());
231
+ case mx::uint32:
232
+ return OrderedJson(array.item<uint32_t>());
233
+ case mx::uint64:
234
+ return OrderedJson(array.item<uint64_t>());
235
+ case mx::int8:
236
+ return OrderedJson(array.item<int8_t>());
237
+ case mx::int16:
238
+ return OrderedJson(array.item<int16_t>());
239
+ case mx::int32:
240
+ return OrderedJson(array.item<int32_t>());
241
+ case mx::int64:
242
+ return OrderedJson(array.item<int64_t>());
243
+ case mx::float16:
244
+ return OrderedJson(static_cast<double>(array.item<mx::float16_t>()));
245
+ case mx::bfloat16:
246
+ return OrderedJson(static_cast<double>(array.item<mx::bfloat16_t>()));
247
+ case mx::float32:
248
+ return OrderedJson(static_cast<double>(array.item<float>()));
249
+ case mx::float64:
250
+ return OrderedJson(array.item<double>());
251
+ default:
252
+ throw std::runtime_error("unsupported dtype for graph ir constant conversion");
253
+ }
254
+ }
255
+
256
+ static OrderedJson capture_json_values_from_array(const mx::array& source) {
257
+ // Constants are eagerly materialized so later lowering/encoding has a single
258
+ // JSON representation independent of backend/device buffers.
259
+ mx::array array = source;
260
+ if (array.ndim() == 0) {
261
+ array.eval();
262
+ return capture_json_scalar_from_array(array);
263
+ }
264
+
265
+ if (array.ndim() == 1) {
266
+ array.eval();
267
+ const size_t size = array.size();
268
+ switch (array.dtype()) {
269
+ case mx::bool_: {
270
+ const bool* data = array.data<bool>();
271
+ return capture_build_flat_json_array(size, [&](size_t i) { return OrderedJson(data[i]); });
272
+ }
273
+ case mx::uint8: {
274
+ const uint8_t* data = array.data<uint8_t>();
275
+ return capture_build_flat_json_array(size, [&](size_t i) { return OrderedJson(data[i]); });
276
+ }
277
+ case mx::uint16: {
278
+ const uint16_t* data = array.data<uint16_t>();
279
+ return capture_build_flat_json_array(size, [&](size_t i) { return OrderedJson(data[i]); });
280
+ }
281
+ case mx::uint32: {
282
+ const uint32_t* data = array.data<uint32_t>();
283
+ return capture_build_flat_json_array(size, [&](size_t i) { return OrderedJson(data[i]); });
284
+ }
285
+ case mx::uint64: {
286
+ const uint64_t* data = array.data<uint64_t>();
287
+ return capture_build_flat_json_array(size, [&](size_t i) { return OrderedJson(data[i]); });
288
+ }
289
+ case mx::int8: {
290
+ const int8_t* data = array.data<int8_t>();
291
+ return capture_build_flat_json_array(size, [&](size_t i) { return OrderedJson(data[i]); });
292
+ }
293
+ case mx::int16: {
294
+ const int16_t* data = array.data<int16_t>();
295
+ return capture_build_flat_json_array(size, [&](size_t i) { return OrderedJson(data[i]); });
296
+ }
297
+ case mx::int32: {
298
+ const int32_t* data = array.data<int32_t>();
299
+ return capture_build_flat_json_array(size, [&](size_t i) { return OrderedJson(data[i]); });
300
+ }
301
+ case mx::int64: {
302
+ const int64_t* data = array.data<int64_t>();
303
+ return capture_build_flat_json_array(size, [&](size_t i) { return OrderedJson(data[i]); });
304
+ }
305
+ case mx::float16: {
306
+ const mx::float16_t* data = array.data<mx::float16_t>();
307
+ return capture_build_flat_json_array(
308
+ size,
309
+ [&](size_t i) { return OrderedJson(static_cast<double>(data[i])); });
310
+ }
311
+ case mx::bfloat16: {
312
+ const mx::bfloat16_t* data = array.data<mx::bfloat16_t>();
313
+ return capture_build_flat_json_array(
314
+ size,
315
+ [&](size_t i) { return OrderedJson(static_cast<double>(data[i])); });
316
+ }
317
+ case mx::float32: {
318
+ const float* data = array.data<float>();
319
+ return capture_build_flat_json_array(
320
+ size,
321
+ [&](size_t i) { return OrderedJson(static_cast<double>(data[i])); });
322
+ }
323
+ case mx::float64: {
324
+ const double* data = array.data<double>();
325
+ return capture_build_flat_json_array(size, [&](size_t i) { return OrderedJson(data[i]); });
326
+ }
327
+ default:
328
+ throw std::runtime_error("unsupported dtype for graph ir constant conversion");
329
+ }
330
+ }
331
+
332
+ const mx::Shape shape = array.shape();
333
+ mx::array flat = mx::reshape(array, mx::Shape{static_cast<mx::ShapeElem>(array.size())});
334
+ flat.eval();
335
+
336
+ size_t idx = 0;
337
+ switch (flat.dtype()) {
338
+ case mx::bool_: {
339
+ const bool* data = flat.data<bool>();
340
+ return capture_build_nested_json_array(shape, 0, idx, [&](size_t i) { return OrderedJson(data[i]); });
341
+ }
342
+ case mx::uint8: {
343
+ const uint8_t* data = flat.data<uint8_t>();
344
+ return capture_build_nested_json_array(shape, 0, idx, [&](size_t i) { return OrderedJson(data[i]); });
345
+ }
346
+ case mx::uint16: {
347
+ const uint16_t* data = flat.data<uint16_t>();
348
+ return capture_build_nested_json_array(shape, 0, idx, [&](size_t i) { return OrderedJson(data[i]); });
349
+ }
350
+ case mx::uint32: {
351
+ const uint32_t* data = flat.data<uint32_t>();
352
+ return capture_build_nested_json_array(shape, 0, idx, [&](size_t i) { return OrderedJson(data[i]); });
353
+ }
354
+ case mx::uint64: {
355
+ const uint64_t* data = flat.data<uint64_t>();
356
+ return capture_build_nested_json_array(shape, 0, idx, [&](size_t i) { return OrderedJson(data[i]); });
357
+ }
358
+ case mx::int8: {
359
+ const int8_t* data = flat.data<int8_t>();
360
+ return capture_build_nested_json_array(shape, 0, idx, [&](size_t i) { return OrderedJson(data[i]); });
361
+ }
362
+ case mx::int16: {
363
+ const int16_t* data = flat.data<int16_t>();
364
+ return capture_build_nested_json_array(shape, 0, idx, [&](size_t i) { return OrderedJson(data[i]); });
365
+ }
366
+ case mx::int32: {
367
+ const int32_t* data = flat.data<int32_t>();
368
+ return capture_build_nested_json_array(shape, 0, idx, [&](size_t i) { return OrderedJson(data[i]); });
369
+ }
370
+ case mx::int64: {
371
+ const int64_t* data = flat.data<int64_t>();
372
+ return capture_build_nested_json_array(shape, 0, idx, [&](size_t i) { return OrderedJson(data[i]); });
373
+ }
374
+ case mx::float16: {
375
+ const mx::float16_t* data = flat.data<mx::float16_t>();
376
+ return capture_build_nested_json_array(
377
+ shape,
378
+ 0,
379
+ idx,
380
+ [&](size_t i) { return OrderedJson(static_cast<double>(data[i])); });
381
+ }
382
+ case mx::bfloat16: {
383
+ const mx::bfloat16_t* data = flat.data<mx::bfloat16_t>();
384
+ return capture_build_nested_json_array(
385
+ shape,
386
+ 0,
387
+ idx,
388
+ [&](size_t i) { return OrderedJson(static_cast<double>(data[i])); });
389
+ }
390
+ case mx::float32: {
391
+ const float* data = flat.data<float>();
392
+ return capture_build_nested_json_array(
393
+ shape,
394
+ 0,
395
+ idx,
396
+ [&](size_t i) { return OrderedJson(static_cast<double>(data[i])); });
397
+ }
398
+ case mx::float64: {
399
+ const double* data = flat.data<double>();
400
+ return capture_build_nested_json_array(shape, 0, idx, [&](size_t i) { return OrderedJson(data[i]); });
401
+ }
402
+ default:
403
+ throw std::runtime_error("unsupported dtype for graph ir constant conversion");
404
+ }
405
+ }
406
+
407
+ static OrderedJson capture_json_tensor_info_from_graph_tensor(const GraphTensorInfo& info) {
408
+ OrderedJson out = OrderedJson::object();
409
+ out["name"] = std::get<0>(info);
410
+ out["shape"] = capture_json_shape_from_mx_shape(std::get<1>(info));
411
+ out["dtype"] = dtype_to_string(std::get<2>(info));
412
+ return out;
413
+ }
414
+
415
+ static OrderedJson capture_json_tensor_infos_from_graph_tensors(const std::vector<GraphTensorInfo>& infos) {
416
+ OrderedJson out = OrderedJson::array();
417
+ for (const auto& info : infos) {
418
+ out.push_back(capture_json_tensor_info_from_graph_tensor(info));
419
+ }
420
+ return out;
421
+ }
422
+
423
+ static OrderedJson capture_json_tensor_names_from_graph_tensors(const std::vector<GraphTensorInfo>& infos) {
424
+ OrderedJson out = OrderedJson::array();
425
+ for (const auto& info : infos) {
426
+ out.push_back(std::get<0>(info));
427
+ }
428
+ return out;
429
+ }
430
+
431
+ static OrderedJson capture_json_state_value_from_mx_state(const mx::StateT& value) {
432
+ // Export callback state arguments use a tagged variant; serialize every
433
+ // supported alternative into JSON values consumed by lowering.
434
+ if (std::holds_alternative<bool>(value)) {
435
+ return OrderedJson(std::get<bool>(value));
436
+ }
437
+ if (std::holds_alternative<int>(value)) {
438
+ return OrderedJson(std::get<int>(value));
439
+ }
440
+ if (std::holds_alternative<size_t>(value)) {
441
+ return OrderedJson(std::get<size_t>(value));
442
+ }
443
+ if (std::holds_alternative<float>(value)) {
444
+ return OrderedJson(static_cast<double>(std::get<float>(value)));
445
+ }
446
+ if (std::holds_alternative<double>(value)) {
447
+ return OrderedJson(std::get<double>(value));
448
+ }
449
+ if (std::holds_alternative<mx::Dtype>(value)) {
450
+ return OrderedJson(dtype_to_string(std::get<mx::Dtype>(value)));
451
+ }
452
+ if (std::holds_alternative<mx::Shape>(value)) {
453
+ return capture_json_shape_from_mx_shape(std::get<mx::Shape>(value));
454
+ }
455
+ if (std::holds_alternative<mx::Strides>(value)) {
456
+ OrderedJson out = OrderedJson::array();
457
+ const auto& strides = std::get<mx::Strides>(value);
458
+ for (auto stride : strides) {
459
+ out.push_back(static_cast<long long>(stride));
460
+ }
461
+ return out;
462
+ }
463
+ if (std::holds_alternative<std::vector<int>>(value)) {
464
+ OrderedJson out = OrderedJson::array();
465
+ const auto& values = std::get<std::vector<int>>(value);
466
+ for (int item : values) {
467
+ out.push_back(item);
468
+ }
469
+ return out;
470
+ }
471
+ if (std::holds_alternative<std::vector<size_t>>(value)) {
472
+ OrderedJson out = OrderedJson::array();
473
+ const auto& values = std::get<std::vector<size_t>>(value);
474
+ for (size_t item : values) {
475
+ out.push_back(item);
476
+ }
477
+ return out;
478
+ }
479
+ if (std::holds_alternative<std::vector<std::tuple<bool, bool, bool>>>(value)) {
480
+ OrderedJson out = OrderedJson::array();
481
+ const auto& tuples = std::get<std::vector<std::tuple<bool, bool, bool>>>(value);
482
+ for (const auto& item : tuples) {
483
+ out.push_back(OrderedJson::array({std::get<0>(item), std::get<1>(item), std::get<2>(item)}));
484
+ }
485
+ return out;
486
+ }
487
+ if (std::holds_alternative<std::vector<std::variant<bool, int, float>>>(value)) {
488
+ OrderedJson out = OrderedJson::array();
489
+ const auto& vars = std::get<std::vector<std::variant<bool, int, float>>>(value);
490
+ for (const auto& item : vars) {
491
+ if (std::holds_alternative<bool>(item)) {
492
+ out.push_back(std::get<bool>(item));
493
+ } else if (std::holds_alternative<int>(item)) {
494
+ out.push_back(std::get<int>(item));
495
+ } else {
496
+ out.push_back(static_cast<double>(std::get<float>(item)));
497
+ }
498
+ }
499
+ return out;
500
+ }
501
+ if (std::holds_alternative<std::optional<float>>(value)) {
502
+ const auto& opt = std::get<std::optional<float>>(value);
503
+ if (!opt.has_value()) {
504
+ return nullptr;
505
+ }
506
+ return OrderedJson(static_cast<double>(opt.value()));
507
+ }
508
+ return OrderedJson(std::get<std::string>(value));
509
+ }
510
+
511
+ static OrderedJson capture_json_state_values_from_mx_states(const std::vector<mx::StateT>& values) {
512
+ OrderedJson out = OrderedJson::array();
513
+ for (const auto& value : values) {
514
+ out.push_back(capture_json_state_value_from_mx_state(value));
515
+ }
516
+ return out;
517
+ }
518
+
519
+ template <typename T>
520
+ static const T* export_callback_field(
521
+ const mx::ExportCallbackInput& data,
522
+ const std::string& key) {
523
+ // Callback records are heterogenous key/value pairs; this helper combines
524
+ // key lookup and variant type-check in one place.
525
+ for (const auto& [candidate_key, candidate_value] : data) {
526
+ if (candidate_key == key && std::holds_alternative<T>(candidate_value)) {
527
+ return &std::get<T>(candidate_value);
528
+ }
529
+ }
530
+ return nullptr;
531
+ }
532
+
533
// Traces `invocation.fun` with mx::export_function and converts the callback
// records it emits into the graph IR JSON payload.
//
// Parameters:
//   invocation   - callable, positional/keyword arrays and the shapeless flag
//                  forwarded to mx::export_function.
//   timing_stats - optional. When non-null it accumulates:
//                  - the number of constants and their element counts,
//                  - the time spent serializing constant values,
//                  - the total time spent inside mx::export_function.
//
// Returns an object with keys, in order:
//   "ir_version" (always 1), "shapeless", "inputs", "keyword_inputs",
//   "outputs", "constants" and "nodes".
// Records whose "type" field is missing or unrecognized are ignored.
static OrderedJson export_ir_payload(
    const GraphIrExportInvocation& invocation,
    GraphIrExportTimingStats* timing_stats = nullptr) {
  // Single capture pass: collect graph metadata and primitive nodes from
  // export_function callback records and normalize them into IR JSON.
  OrderedJson graph_inputs = OrderedJson::array();
  OrderedJson keyword_inputs = OrderedJson::array();
  OrderedJson graph_outputs = OrderedJson::array();
  OrderedJson graph_constants = OrderedJson::array();
  OrderedJson graph_nodes = OrderedJson::array();

  const auto trace_started_at = std::chrono::steady_clock::now();
  mx::export_function(
      // The callback only runs synchronously inside export_function, so
      // capturing the locals above by reference is safe.
      [&graph_inputs, &keyword_inputs, &graph_outputs, &graph_constants, &graph_nodes, timing_stats](
          const mx::ExportCallbackInput& data) {
        // Record schema comes from mlx::export_function "type" discriminator.
        const auto* record_type = export_callback_field<std::string>(data, "type");
        if (record_type == nullptr) {
          return;
        }

        // Graph-level positional inputs: a later "inputs" record replaces the
        // previously captured list.
        if (*record_type == "inputs") {
          const auto* inputs = export_callback_field<std::vector<GraphTensorInfo>>(data, "inputs");
          if (inputs != nullptr) {
            graph_inputs = capture_json_tensor_infos_from_graph_tensors(*inputs);
          }
          return;
        }

        // Keyword inputs arrive as (keyword name, tensor name) pairs and are
        // emitted as {"name": ..., "tensor": ...} objects.
        if (*record_type == "keyword_inputs") {
          const auto* keywords =
              export_callback_field<std::vector<std::pair<std::string, std::string>>>(
                  data,
                  "keywords");
          if (keywords != nullptr) {
            keyword_inputs = OrderedJson::array();
            for (const auto& [name, tensor] : *keywords) {
              OrderedJson entry = OrderedJson::object();
              entry["name"] = name;
              entry["tensor"] = tensor;
              keyword_inputs.push_back(std::move(entry));
            }
          }
          return;
        }

        // Graph-level outputs: a later "outputs" record replaces the list.
        if (*record_type == "outputs") {
          const auto* outputs = export_callback_field<std::vector<GraphTensorInfo>>(data, "outputs");
          if (outputs != nullptr) {
            graph_outputs = capture_json_tensor_infos_from_graph_tensors(*outputs);
          }
          return;
        }

        // Constants: each "constants" record rebuilds the list from scratch.
        // Entries carry name, shape, dtype and the serialized values.
        if (*record_type == "constants") {
          const auto* constants =
              export_callback_field<std::vector<std::pair<std::string, mx::array>>>(
                  data,
                  "constants");
          if (constants != nullptr) {
            graph_constants = OrderedJson::array();
            if (timing_stats != nullptr) {
              timing_stats->constants_count += constants->size();
            }
            for (const auto& [name, arr] : *constants) {
              // Constant payloads are embedded by value so that ONNX export can
              // be run later without re-tracing.
              OrderedJson entry = OrderedJson::object();
              entry["name"] = name;
              entry["shape"] = capture_json_shape_from_mx_shape(arr.shape());
              entry["dtype"] = dtype_to_string(arr.dtype());
              if (timing_stats != nullptr) {
                // Timed variant: also tracks element counts and capture time.
                timing_stats->constant_elements += static_cast<size_t>(arr.size());
                const auto capture_started_at = std::chrono::steady_clock::now();
                entry["values"] = capture_json_values_from_array(arr);
                timing_stats->constants_capture_ms += elapsed_millis(capture_started_at);
              } else {
                entry["values"] = capture_json_values_from_array(arr);
              }
              graph_constants.push_back(std::move(entry));
            }
          }
          return;
        }

        // Everything below handles "primitive" records (one per graph op);
        // any other record type is ignored.
        if (*record_type != "primitive") {
          return;
        }

        // A primitive without a name cannot be lowered; drop it.
        const auto* op_name = export_callback_field<std::string>(data, "name");
        if (op_name == nullptr) {
          return;
        }

        // Node layout: {"op", "inputs", "outputs", "arguments"}. Missing
        // fields become empty arrays.
        OrderedJson node = OrderedJson::object();
        node["op"] = *op_name;

        OrderedJson node_inputs = OrderedJson::array();
        const auto* node_input_infos = export_callback_field<std::vector<GraphTensorInfo>>(data, "inputs");
        if (node_input_infos != nullptr) {
          node_inputs = capture_json_tensor_names_from_graph_tensors(*node_input_infos);
        }
        node["inputs"] = std::move(node_inputs);

        OrderedJson node_outputs = OrderedJson::array();
        const auto* node_output_infos =
            export_callback_field<std::vector<GraphTensorInfo>>(data, "outputs");
        if (node_output_infos != nullptr) {
          node_outputs = capture_json_tensor_names_from_graph_tensors(*node_output_infos);
        }
        node["outputs"] = std::move(node_outputs);

        OrderedJson node_arguments = OrderedJson::array();
        const auto* arguments = export_callback_field<std::vector<mx::StateT>>(data, "arguments");
        if (arguments != nullptr) {
          node_arguments = capture_json_state_values_from_mx_states(*arguments);
        }
        node["arguments"] = std::move(node_arguments);

        // Nodes are appended in the order export_function reports them.
        graph_nodes.push_back(std::move(node));
      },
      onnx_args_kwargs_function_from_callable(invocation.fun),
      invocation.args,
      invocation.kwargs,
      invocation.shapeless);
  // Measured around the whole trace, so it includes callback work such as
  // constant serialization.
  if (timing_stats != nullptr) {
    timing_stats->export_function_ms += elapsed_millis(trace_started_at);
  }

  OrderedJson payload = OrderedJson::object();
  payload["ir_version"] = 1;
  payload["shapeless"] = invocation.shapeless;
  payload["inputs"] = std::move(graph_inputs);
  payload["keyword_inputs"] = std::move(keyword_inputs);
  payload["outputs"] = std::move(graph_outputs);
  payload["constants"] = std::move(graph_constants);
  payload["nodes"] = std::move(graph_nodes);
  return payload;
}
672
+
673
+ // ============================================================================
674
+ // Section: Ruby <-> OrderedJson Conversion and Source Parsing
675
+ // ============================================================================
676
+
677
+ static std::string std_string_from_ruby(VALUE value) {
678
+ VALUE str = rb_obj_as_string(value);
679
+ return std::string(RSTRING_PTR(str), static_cast<size_t>(RSTRING_LEN(str)));
680
+ }
681
+
682
+ static VALUE ruby_string_from_std(const std::string& value) {
683
+ return rb_str_new(value.data(), static_cast<long>(value.size()));
684
+ }
685
+
686
+ static OrderedJson ordered_json_from_ruby(VALUE value);
687
+ static OrderedJson ordered_json_complex_from_ruby(VALUE value);
688
+
689
+ static OrderedJson ordered_json_integer_from_ruby(VALUE value) {
690
+ // Preserve Integer magnitude when possible (int64/uint64), then degrade to
691
+ // double only when outside 64-bit integer JSON range.
692
+ VALUE text_value = rb_funcall(value, rb_intern("to_s"), 0);
693
+ const std::string text =
694
+ std::string(RSTRING_PTR(text_value), static_cast<size_t>(RSTRING_LEN(text_value)));
695
+ if (text.empty()) {
696
+ throw std::invalid_argument("failed to convert Integer to JSON number");
697
+ }
698
+
699
+ const bool negative = text.front() == '-';
700
+ try {
701
+ if (negative) {
702
+ return static_cast<int64_t>(std::stoll(text));
703
+ }
704
+
705
+ const auto raw = std::stoull(text);
706
+ if (raw <= static_cast<unsigned long long>(std::numeric_limits<int64_t>::max())) {
707
+ return static_cast<int64_t>(raw);
708
+ }
709
+ return static_cast<uint64_t>(raw);
710
+ } catch (const std::out_of_range&) {
711
+ try {
712
+ return static_cast<double>(std::stold(text));
713
+ } catch (const std::exception&) {
714
+ throw std::invalid_argument("Integer is too large to convert into JSON numeric range");
715
+ }
716
+ } catch (const std::invalid_argument&) {
717
+ throw std::invalid_argument("failed to parse Integer while converting to JSON");
718
+ }
719
+ }
720
+
721
+ static OrderedJson ordered_json_object_from_ruby_hash(VALUE hash) {
722
+ OrderedJson out = OrderedJson::object();
723
+ VALUE keys = rb_funcall(hash, rb_intern("keys"), 0);
724
+ const long len = RARRAY_LEN(keys);
725
+ for (long i = 0; i < len; ++i) {
726
+ VALUE key = rb_ary_entry(keys, i);
727
+ VALUE item = rb_hash_aref(hash, key);
728
+ VALUE key_str = rb_obj_as_string(key);
729
+ out[std::string(RSTRING_PTR(key_str), static_cast<size_t>(RSTRING_LEN(key_str)))] =
730
+ ordered_json_from_ruby(item);
731
+ }
732
+ return out;
733
+ }
734
+
735
+ static OrderedJson ordered_json_complex_from_ruby(VALUE value) {
736
+ // Complex values round-trip via an explicit marker object.
737
+ VALUE real_value = rb_funcall(value, rb_intern("real"), 0);
738
+ VALUE imag_value = rb_funcall(value, rb_intern("imag"), 0);
739
+ double real = NUM2DBL(real_value);
740
+ double imag = NUM2DBL(imag_value);
741
+ OrderedJson pair = OrderedJson::array();
742
+ pair.push_back(real);
743
+ pair.push_back(imag);
744
+ OrderedJson out = OrderedJson::object();
745
+ out["__mlx_complex__"] = std::move(pair);
746
+ return out;
747
+ }
748
+
749
+ static OrderedJson ordered_json_from_ruby(VALUE value) {
750
+ // This conversion intentionally accepts more than strict JSON scalar classes
751
+ // to keep Ruby-side ergonomics predictable.
752
+ if (NIL_P(value)) {
753
+ return nullptr;
754
+ }
755
+ if (value == Qtrue || value == Qfalse) {
756
+ return value == Qtrue;
757
+ }
758
+ if (rb_obj_is_kind_of(value, rb_cComplex)) {
759
+ return ordered_json_complex_from_ruby(value);
760
+ }
761
+ if (RB_INTEGER_TYPE_P(value)) {
762
+ return ordered_json_integer_from_ruby(value);
763
+ }
764
+ if (RB_FLOAT_TYPE_P(value)) {
765
+ return NUM2DBL(value);
766
+ }
767
+ if (RB_TYPE_P(value, T_STRING)) {
768
+ return std::string(RSTRING_PTR(value), static_cast<size_t>(RSTRING_LEN(value)));
769
+ }
770
+ if (RB_TYPE_P(value, T_ARRAY)) {
771
+ OrderedJson out = OrderedJson::array();
772
+ const long len = RARRAY_LEN(value);
773
+ for (long i = 0; i < len; ++i) {
774
+ out.push_back(ordered_json_from_ruby(rb_ary_entry(value, i)));
775
+ }
776
+ return out;
777
+ }
778
+ if (RB_TYPE_P(value, T_HASH)) {
779
+ return ordered_json_object_from_ruby_hash(value);
780
+ }
781
+ if (SYMBOL_P(value)) {
782
+ VALUE symbol_str = rb_sym2str(value);
783
+ return std::string(RSTRING_PTR(symbol_str), static_cast<size_t>(RSTRING_LEN(symbol_str)));
784
+ }
785
+
786
+ VALUE converted = rb_obj_as_string(value);
787
+ return std::string(RSTRING_PTR(converted), static_cast<size_t>(RSTRING_LEN(converted)));
788
+ }
789
+
790
+ static VALUE ruby_value_from_ordered_json(const OrderedJson& value) {
791
+ // Reverse bridge for native methods that return structured JSON payloads.
792
+ if (value.is_null()) {
793
+ return Qnil;
794
+ }
795
+ if (value.is_boolean()) {
796
+ return value.get<bool>() ? Qtrue : Qfalse;
797
+ }
798
+ if (value.is_number_integer()) {
799
+ return LL2NUM(value.get<int64_t>());
800
+ }
801
+ if (value.is_number_unsigned()) {
802
+ return ULL2NUM(value.get<uint64_t>());
803
+ }
804
+ if (value.is_number_float()) {
805
+ return rb_float_new(value.get<double>());
806
+ }
807
+ if (value.is_string()) {
808
+ const auto& text = value.get_ref<const std::string&>();
809
+ return rb_str_new(text.data(), static_cast<long>(text.size()));
810
+ }
811
+ if (value.is_array()) {
812
+ VALUE out = rb_ary_new_capa(static_cast<long>(value.size()));
813
+ for (const auto& item : value) {
814
+ rb_ary_push(out, ruby_value_from_ordered_json(item));
815
+ }
816
+ return out;
817
+ }
818
+ if (value.is_object()) {
819
+ VALUE out = rb_hash_new();
820
+ for (auto it = value.begin(); it != value.end(); ++it) {
821
+ const auto& key = it.key();
822
+ rb_hash_aset(
823
+ out,
824
+ rb_str_new(key.data(), static_cast<long>(key.size())),
825
+ ruby_value_from_ordered_json(it.value()));
826
+ }
827
+ return out;
828
+ }
829
+
830
+ rb_raise(rb_eTypeError, "unsupported JSON value type");
831
+ return Qnil;
832
+ }
833
+
834
// Prefixes `message` with the "[ir.api] " subsystem tag, unless the message
// already starts with a bracketed tag (its first character is '[').
static std::string tagged_ir_api_error(const std::string& message) {
  const bool already_tagged = !message.empty() && message[0] == '[';
  return already_tagged ? message : "[ir.api] " + message;
}
840
+
841
+ static std::string read_file_to_string(const std::string& path) {
842
+ std::ifstream input(path, std::ios::binary);
843
+ if (!input.good()) {
844
+ std::ostringstream out;
845
+ out << "failed to read file: " << path;
846
+ throw std::runtime_error(tagged_ir_api_error(out.str()));
847
+ }
848
+ std::ostringstream buffer;
849
+ buffer << input.rdbuf();
850
+ return buffer.str();
851
+ }
852
+
853
+ static OrderedJson parse_json_payload_from_string(const std::string& raw, const std::string& label) {
854
+ try {
855
+ return OrderedJson::parse(raw);
856
+ } catch (const std::exception& error) {
857
+ std::ostringstream out;
858
+ out << "failed to parse " << label << ": " << error.what();
859
+ throw std::invalid_argument(tagged_ir_api_error(out.str()));
860
+ }
861
+ }
862
+
863
+ static OrderedJson parse_ir_source_payload(VALUE source) {
864
+ // Unified source parser for Hash / JSON string / file path / IO-like object.
865
+ // Heuristic for String:
866
+ // - existing regular file path -> read and parse file contents
867
+ // - otherwise -> parse as JSON literal string
868
+ if (RB_TYPE_P(source, T_HASH)) {
869
+ return ordered_json_from_ruby(source);
870
+ }
871
+
872
+ if (RB_TYPE_P(source, T_STRING)) {
873
+ const std::string raw = std_string_from_ruby(source);
874
+ bool treat_as_file = false;
875
+ try {
876
+ treat_as_file = std::filesystem::is_regular_file(raw);
877
+ } catch (const std::filesystem::filesystem_error&) {
878
+ treat_as_file = false;
879
+ }
880
+ if (treat_as_file) {
881
+ return parse_json_payload_from_string(read_file_to_string(raw), "graph ir file");
882
+ }
883
+ return parse_json_payload_from_string(raw, "graph ir string");
884
+ }
885
+
886
+ if (rb_respond_to(source, rb_intern("to_path"))) {
887
+ VALUE path_value = rb_funcall(source, rb_intern("to_path"), 0);
888
+ const std::string path = std_string_from_ruby(path_value);
889
+ if (path.empty()) {
890
+ throw std::invalid_argument("[ir.api] graph ir path-like source must be non-empty");
891
+ }
892
+ if (!std::filesystem::is_regular_file(path)) {
893
+ std::ostringstream out;
894
+ out << "graph ir path does not exist: " << path;
895
+ throw std::invalid_argument(std::string("[ir.api] ") + out.str());
896
+ }
897
+ return parse_json_payload_from_string(read_file_to_string(path), "graph ir file");
898
+ }
899
+
900
+ if (rb_respond_to(source, rb_intern("read"))) {
901
+ VALUE io_raw = rb_funcall(source, rb_intern("read"), 0);
902
+ return parse_json_payload_from_string(std_string_from_ruby(io_raw), "graph ir IO");
903
+ }
904
+
905
+ throw std::invalid_argument(
906
+ "[ir.api] graph ir source must be a Hash, JSON String, file path, or IO-like object");
907
+ }
908
+
909
+ static std::string ruby_path_string(VALUE value, const char* label) {
910
+ // Binary export paths are path-like only; IO-like objects are rejected to
911
+ // keep external-data colocated path handling deterministic.
912
+ if (rb_respond_to(value, rb_intern("write"))) {
913
+ std::ostringstream out;
914
+ out << label << " requires a path-like target, not an IO-like target";
915
+ throw std::invalid_argument(out.str());
916
+ }
917
+
918
+ VALUE raw = value;
919
+ if (rb_respond_to(value, rb_intern("to_path"))) {
920
+ raw = rb_funcall(value, rb_intern("to_path"), 0);
921
+ }
922
+ const std::string path = std_string_from_ruby(raw);
923
+ if (path.empty()) {
924
+ std::ostringstream out;
925
+ out << label << " target must be a non-empty path-like value";
926
+ throw std::invalid_argument(out.str());
927
+ }
928
+ return path;
929
+ }
930
+
931
+
932
+ static int64_t normalize_positive_integer(VALUE value, const char* label) {
933
+ VALUE integer = rb_Integer(value);
934
+ const int64_t out = NUM2LL(integer);
935
+ if (out <= 0) {
936
+ std::ostringstream msg;
937
+ msg << label << " must be a positive Integer";
938
+ throw std::invalid_argument(msg.str());
939
+ }
940
+ return out;
941
+ }
942
+
943
+ static std::string non_empty_model_name(VALUE value) {
944
+ std::string out = std_string_from_ruby(value);
945
+ if (out.empty()) {
946
+ throw std::invalid_argument("model_name must not be empty");
947
+ }
948
+ return out;
949
+ }
950
+
951
+ static mlx::onnx::OnnxBinaryWriteOptions normalize_onnx_binary_write_options(
952
+ const std::string& target_path,
953
+ VALUE external_data,
954
+ VALUE external_data_file,
955
+ VALUE external_data_size_threshold) {
956
+ // Normalize Ruby kwargs into strict binary writer options.
957
+ if (!(external_data == Qtrue || external_data == Qfalse)) {
958
+ throw std::invalid_argument("external_data must be true or false");
959
+ }
960
+
961
+ mlx::onnx::OnnxBinaryWriteOptions options;
962
+ options.external_data = (external_data == Qtrue);
963
+ options.external_data_size_threshold = 1024;
964
+ options.external_data_file = "weights.bin";
965
+
966
+ if (!options.external_data) {
967
+ return options;
968
+ }
969
+
970
+ const int64_t threshold = NUM2LL(external_data_size_threshold);
971
+ if (threshold < 0) {
972
+ throw std::invalid_argument("external_data_size_threshold must be a non-negative Integer");
973
+ }
974
+ options.external_data_size_threshold = threshold;
975
+
976
+ std::string location;
977
+ if (NIL_P(external_data_file)) {
978
+ std::filesystem::path path(target_path);
979
+ std::string base = path.stem().string();
980
+ if (base.empty()) {
981
+ base = "weights";
982
+ }
983
+ location = base + ".data";
984
+ } else {
985
+ location = std_string_from_ruby(external_data_file);
986
+ }
987
+ if (location.empty()) {
988
+ throw std::invalid_argument("external_data_file must be a non-empty filename");
989
+ }
990
+ std::filesystem::path location_path(location);
991
+ if (location_path.has_parent_path() || location.find('/') != std::string::npos ||
992
+ location.find('\\') != std::string::npos) {
993
+ throw std::invalid_argument("external_data_file must be a filename without path separators");
994
+ }
995
+ options.external_data_file = location;
996
+ return options;
997
+ }
998
+
999
+ [[noreturn]] static void raise_onnx_native_exception(const std::exception& error) {
1000
+ // Promote lowering "unsupported" errors to typed Ruby exception so callers
1001
+ // can distinguish unsupported coverage from generic runtime failures.
1002
+ const std::string message(error.what());
1003
+ if (!NIL_P(eOnnxNativeUnsupportedError) &&
1004
+ mlx::onnx::ir_is_unsupported_error_message(message)) {
1005
+ VALUE exc = rb_exc_new_str(
1006
+ eOnnxNativeUnsupportedError,
1007
+ rb_str_new(message.data(), static_cast<long>(message.size())));
1008
+ rb_exc_raise(exc);
1009
+ }
1010
+
1011
+ rb_raise(rb_eRuntimeError, "%s", message.c_str());
1012
+ }
1013
+
1014
+ // ============================================================================
1015
+ // Section: Ruby-Callable Native Entry Helpers
1016
+ // ============================================================================
1017
+
1018
+ static VALUE graph_ir_to_onnx_json_from_source(VALUE ir_source, VALUE opset, VALUE model_name) {
1019
+ // Direct source->ONNX JSON entry used by Ruby API and tests.
1020
+ const bool timing_enabled = onnx_native_timing_enabled();
1021
+ const auto started_at = std::chrono::steady_clock::now();
1022
+ const auto opset_int = normalize_positive_integer(opset, "opset");
1023
+ const auto model_name_str = non_empty_model_name(model_name);
1024
+ const auto parse_started_at = std::chrono::steady_clock::now();
1025
+ const auto payload = parse_ir_source_payload(ir_source);
1026
+ const double parse_json_ms = elapsed_millis(parse_started_at);
1027
+
1028
+ const auto lower_started_at = std::chrono::steady_clock::now();
1029
+ const auto onnx_payload =
1030
+ mlx::onnx::ir_to_onnx_json_payload(payload, opset_int, model_name_str);
1031
+ const double lower_onnx_ms = elapsed_millis(lower_started_at);
1032
+ const auto dump_started_at = std::chrono::steady_clock::now();
1033
+ const auto content = onnx_payload.dump();
1034
+ const double dump_json_ms = elapsed_millis(dump_started_at);
1035
+ if (timing_enabled) {
1036
+ emit_graph_ir_to_onnx_json_timing_line(
1037
+ opset_int,
1038
+ model_name_str,
1039
+ parse_json_ms,
1040
+ lower_onnx_ms,
1041
+ dump_json_ms,
1042
+ elapsed_millis(started_at),
1043
+ content.size());
1044
+ }
1045
+ return ruby_string_from_std(content);
1046
+ }
1047
+
1048
+ static VALUE graph_ir_compatibility_report_json_from_source(VALUE ir_source) {
1049
+ const auto payload = parse_ir_source_payload(ir_source);
1050
+ const auto report = mlx::onnx::ir_compatibility_report_payload(payload);
1051
+ return ruby_string_from_std(report.dump());
1052
+ }
1053
+
1054
+ static GraphIrExportInvocation parse_ir_export_invocation_from_structured_args(
1055
+ VALUE fun,
1056
+ VALUE args_array,
1057
+ VALUE kwargs_hash,
1058
+ VALUE shapeless,
1059
+ const char* method_name) {
1060
+ // Structured parser used by public singleton methods that pass args/kwargs
1061
+ // explicitly instead of variadic flattening.
1062
+ if (!RB_TYPE_P(args_array, T_ARRAY)) {
1063
+ std::ostringstream out;
1064
+ out << method_name << " args_array must be an Array";
1065
+ throw std::invalid_argument(out.str());
1066
+ }
1067
+ if (!(NIL_P(kwargs_hash) || RB_TYPE_P(kwargs_hash, T_HASH))) {
1068
+ std::ostringstream out;
1069
+ out << method_name << " kwargs_hash must be a Hash or nil";
1070
+ throw std::invalid_argument(out.str());
1071
+ }
1072
+ if (!(shapeless == Qtrue || shapeless == Qfalse)) {
1073
+ std::ostringstream out;
1074
+ out << method_name << " shapeless must be true or false";
1075
+ throw std::invalid_argument(out.str());
1076
+ }
1077
+
1078
+ mx::Args args;
1079
+ const long args_len = RARRAY_LEN(args_array);
1080
+ args.reserve(static_cast<size_t>(args_len));
1081
+ for (long i = 0; i < args_len; ++i) {
1082
+ args.push_back(onnx_array_from_ruby(rb_ary_entry(args_array, i)));
1083
+ }
1084
+
1085
+ mx::Kwargs kwargs = NIL_P(kwargs_hash) ? mx::Kwargs{} : onnx_array_map_from_ruby_hash(kwargs_hash);
1086
+ if (args.empty() && kwargs.empty()) {
1087
+ std::ostringstream out;
1088
+ out << "[" << method_name << "] Inputs must include at least one positional or keyword array";
1089
+ throw std::invalid_argument(out.str());
1090
+ }
1091
+
1092
+ GraphIrExportInvocation invocation;
1093
+ invocation.fun = fun;
1094
+ invocation.args = std::move(args);
1095
+ invocation.kwargs = std::move(kwargs);
1096
+ invocation.shapeless = RTEST(shapeless);
1097
+ return invocation;
1098
+ }
1099
+
1100
+ struct ParsedExportPayload {
1101
+ GraphIrExportInvocation invocation;
1102
+ OrderedJson payload;
1103
+ };
1104
+
1105
+ static ParsedExportPayload parse_and_export_payload_from_structured_args(
1106
+ VALUE fun,
1107
+ VALUE args_array,
1108
+ VALUE kwargs_hash,
1109
+ VALUE shapeless,
1110
+ const char* method_name,
1111
+ GraphIrExportTimingStats* timing_stats = nullptr) {
1112
+ ParsedExportPayload out;
1113
+ out.invocation = parse_ir_export_invocation_from_structured_args(
1114
+ fun,
1115
+ args_array,
1116
+ kwargs_hash,
1117
+ shapeless,
1118
+ method_name);
1119
+ out.payload = export_ir_payload(out.invocation, timing_stats);
1120
+ return out;
1121
+ }
1122
+
1123
+ static OrderedJson build_onnx_stub_payload(
1124
+ const OrderedJson& payload,
1125
+ int64_t opset,
1126
+ const std::string& model_name) {
1127
+ return mlx::onnx::ir_to_onnx_json_payload(
1128
+ payload, opset, model_name);
1129
+ }
1130
+
1131
+ static std::string write_onnx_binary_from_payload(
1132
+ const std::string& target,
1133
+ const OrderedJson& payload,
1134
+ int64_t opset,
1135
+ const std::string& model_name,
1136
+ const mlx::onnx::OnnxBinaryWriteOptions& options) {
1137
+ const auto onnx_payload = build_onnx_stub_payload(payload, opset, model_name);
1138
+ const auto artifact =
1139
+ mlx::onnx::build_onnx_binary_artifact_from_stub(
1140
+ onnx_payload, options);
1141
+ return mlx::onnx::write_onnx_binary_artifact_to_path(
1142
+ target, artifact, options);
1143
+ }
1144
+
1145
+ // ============================================================================
1146
+ // Section: Ruby Singleton Method Entry Points
1147
+ // ============================================================================
1148
+
1149
+ static VALUE onnx_native_export_graph_ir(
1150
+ VALUE,
1151
+ VALUE fun,
1152
+ VALUE args_array,
1153
+ VALUE kwargs_hash,
1154
+ VALUE shapeless) {
1155
+ // Entry points are thin wrappers so exception translation remains uniform.
1156
+ try {
1157
+ auto exported = parse_and_export_payload_from_structured_args(
1158
+ fun,
1159
+ args_array,
1160
+ kwargs_hash,
1161
+ shapeless,
1162
+ "native_export_ir");
1163
+ return ruby_value_from_ordered_json(exported.payload);
1164
+ } catch (const std::exception& error) {
1165
+ raise_onnx_native_exception(error);
1166
+ return Qnil;
1167
+ }
1168
+ }
1169
+
1170
+ static VALUE onnx_native_export_graph_ir_json(
1171
+ VALUE,
1172
+ VALUE fun,
1173
+ VALUE args_array,
1174
+ VALUE kwargs_hash,
1175
+ VALUE shapeless) {
1176
+ try {
1177
+ auto exported = parse_and_export_payload_from_structured_args(
1178
+ fun,
1179
+ args_array,
1180
+ kwargs_hash,
1181
+ shapeless,
1182
+ "native_export_ir_json");
1183
+ return ruby_string_from_std(exported.payload.dump());
1184
+ } catch (const std::exception& error) {
1185
+ raise_onnx_native_exception(error);
1186
+ return Qnil;
1187
+ }
1188
+ }
1189
+
1190
+ static VALUE onnx_native_graph_ir_to_onnx_json(VALUE, VALUE ir_source, VALUE opset, VALUE model_name) {
1191
+ try {
1192
+ return graph_ir_to_onnx_json_from_source(ir_source, opset, model_name);
1193
+ } catch (const std::exception& error) {
1194
+ raise_onnx_native_exception(error);
1195
+ return Qnil;
1196
+ }
1197
+ }
1198
+
1199
+ static VALUE onnx_native_export_onnx_json(
1200
+ VALUE,
1201
+ VALUE fun,
1202
+ VALUE args_array,
1203
+ VALUE kwargs_hash,
1204
+ VALUE shapeless,
1205
+ VALUE opset,
1206
+ VALUE model_name) {
1207
+ try {
1208
+ const bool timing_enabled = onnx_native_timing_enabled();
1209
+ const auto started_at = std::chrono::steady_clock::now();
1210
+ const auto decode_started_at = std::chrono::steady_clock::now();
1211
+ auto invocation = parse_ir_export_invocation_from_structured_args(
1212
+ fun,
1213
+ args_array,
1214
+ kwargs_hash,
1215
+ shapeless,
1216
+ "native_export_onnx_json");
1217
+ const double args_decode_ms = elapsed_millis(decode_started_at);
1218
+
1219
+ const auto opset_int = normalize_positive_integer(opset, "opset");
1220
+ const auto model_name_str = non_empty_model_name(model_name);
1221
+ GraphIrExportTimingStats export_stats;
1222
+ const auto export_started_at = std::chrono::steady_clock::now();
1223
+ const auto payload = export_ir_payload(
1224
+ invocation, timing_enabled ? &export_stats : nullptr);
1225
+ const double export_ir_ms = elapsed_millis(export_started_at);
1226
+ const auto lower_started_at = std::chrono::steady_clock::now();
1227
+ const auto onnx_payload = build_onnx_stub_payload(
1228
+ payload, opset_int, model_name_str);
1229
+ const double lower_onnx_ms = elapsed_millis(lower_started_at);
1230
+ const auto dump_started_at = std::chrono::steady_clock::now();
1231
+ const auto content = onnx_payload.dump();
1232
+ const double dump_json_ms = elapsed_millis(dump_started_at);
1233
+
1234
+ if (timing_enabled) {
1235
+ emit_export_onnx_json_timing_line(
1236
+ invocation,
1237
+ opset_int,
1238
+ model_name_str,
1239
+ export_stats,
1240
+ args_decode_ms,
1241
+ export_ir_ms,
1242
+ lower_onnx_ms,
1243
+ dump_json_ms,
1244
+ elapsed_millis(started_at),
1245
+ content.size());
1246
+ }
1247
+
1248
+ return ruby_string_from_std(content);
1249
+ } catch (const std::exception& error) {
1250
+ raise_onnx_native_exception(error);
1251
+ return Qnil;
1252
+ }
1253
+ }
1254
+
1255
+ static VALUE onnx_native_export_onnx_compatibility_report(
1256
+ VALUE,
1257
+ VALUE fun,
1258
+ VALUE args_array,
1259
+ VALUE kwargs_hash,
1260
+ VALUE shapeless) {
1261
+ try {
1262
+ auto exported = parse_and_export_payload_from_structured_args(
1263
+ fun,
1264
+ args_array,
1265
+ kwargs_hash,
1266
+ shapeless,
1267
+ "native_export_onnx_compatibility_report");
1268
+ const auto report =
1269
+ mlx::onnx::ir_compatibility_report_payload(
1270
+ exported.payload);
1271
+ return ruby_value_from_ordered_json(report);
1272
+ } catch (const std::exception& error) {
1273
+ raise_onnx_native_exception(error);
1274
+ return Qnil;
1275
+ }
1276
+ }
1277
+
1278
+ static VALUE onnx_native_export_onnx(
1279
+ VALUE,
1280
+ VALUE target_path,
1281
+ VALUE fun,
1282
+ VALUE args_array,
1283
+ VALUE kwargs_hash,
1284
+ VALUE shapeless,
1285
+ VALUE opset,
1286
+ VALUE model_name,
1287
+ VALUE external_data,
1288
+ VALUE external_data_file,
1289
+ VALUE external_data_size_threshold) {
1290
+ try {
1291
+ const auto target = ruby_path_string(target_path, "export_onnx");
1292
+ const auto options = normalize_onnx_binary_write_options(
1293
+ target,
1294
+ external_data,
1295
+ external_data_file,
1296
+ external_data_size_threshold);
1297
+ auto exported = parse_and_export_payload_from_structured_args(
1298
+ fun,
1299
+ args_array,
1300
+ kwargs_hash,
1301
+ shapeless,
1302
+ "native_export_onnx");
1303
+ const auto opset_int = normalize_positive_integer(opset, "opset");
1304
+ const auto model_name_str = non_empty_model_name(model_name);
1305
+
1306
+ return ruby_string_from_std(write_onnx_binary_from_payload(
1307
+ target, exported.payload, opset_int, model_name_str, options));
1308
+ } catch (const std::exception& error) {
1309
+ raise_onnx_native_exception(error);
1310
+ return Qnil;
1311
+ }
1312
+ }
1313
+
1314
+ static VALUE onnx_native_graph_ir_to_onnx(
1315
+ VALUE,
1316
+ VALUE target_path,
1317
+ VALUE ir_source,
1318
+ VALUE opset,
1319
+ VALUE model_name,
1320
+ VALUE external_data,
1321
+ VALUE external_data_file,
1322
+ VALUE external_data_size_threshold) {
1323
+ try {
1324
+ const auto target = ruby_path_string(target_path, "graph_ir_to_onnx");
1325
+ const auto options = normalize_onnx_binary_write_options(
1326
+ target,
1327
+ external_data,
1328
+ external_data_file,
1329
+ external_data_size_threshold);
1330
+ const auto opset_int = normalize_positive_integer(opset, "opset");
1331
+ const auto model_name_str = non_empty_model_name(model_name);
1332
+ const auto payload = parse_ir_source_payload(ir_source);
1333
+ return ruby_string_from_std(write_onnx_binary_from_payload(
1334
+ target, payload, opset_int, model_name_str, options));
1335
+ } catch (const std::exception& error) {
1336
+ raise_onnx_native_exception(error);
1337
+ return Qnil;
1338
+ }
1339
+ }
1340
+
1341
+ static VALUE onnx_native_graph_ir_compatibility_report_json(VALUE, VALUE ir_source) {
1342
+ try {
1343
+ return graph_ir_compatibility_report_json_from_source(ir_source);
1344
+ } catch (const std::exception& error) {
1345
+ raise_onnx_native_exception(error);
1346
+ return Qnil;
1347
+ }
1348
+ }
1349
+
1350
+ // ============================================================================
1351
+ // Section: Ruby Method Binding Registration
1352
+ // ============================================================================
1353
+
1354
+ } // namespace
1355
+
1356
// Defines MLX::ONNX and MLX::ONNX::Native under `mMLX` and registers the
// native singleton methods.
//
// Also defines MLX::ONNX::Native::UnsupportedError (a RuntimeError subclass),
// which raise_onnx_native_exception uses for unsupported-lowering failures.
// Each arity below equals the number of Ruby-visible parameters of the C
// entry point; the leading receiver argument is not counted.
extern "C" void init_onnx_native_bindings(VALUE mMLX) {
  mONNX = rb_define_module_under(mMLX, "ONNX");
  mONNXNative = rb_define_module_under(mONNX, "Native");
  eOnnxNativeUnsupportedError =
      rb_define_class_under(mONNXNative, "UnsupportedError", rb_eRuntimeError);

  // (fun, args_array, kwargs_hash, shapeless) -> Hash
  rb_define_singleton_method(
      mONNXNative,
      "export_graph_ir",
      RUBY_METHOD_FUNC(onnx_native_export_graph_ir),
      4);
  // (fun, args_array, kwargs_hash, shapeless) -> JSON String
  rb_define_singleton_method(
      mONNXNative,
      "export_graph_ir_json",
      RUBY_METHOD_FUNC(onnx_native_export_graph_ir_json),
      4);
  // (ir_source, opset, model_name) -> ONNX JSON String
  rb_define_singleton_method(
      mONNXNative,
      "graph_ir_to_onnx_json",
      RUBY_METHOD_FUNC(onnx_native_graph_ir_to_onnx_json),
      3);
  // (target_path, ir_source, opset, model_name, external_data,
  //  external_data_file, external_data_size_threshold) -> String
  rb_define_singleton_method(
      mONNXNative,
      "graph_ir_to_onnx",
      RUBY_METHOD_FUNC(onnx_native_graph_ir_to_onnx),
      7);
  // (fun, args_array, kwargs_hash, shapeless, opset, model_name) -> ONNX JSON String
  rb_define_singleton_method(
      mONNXNative,
      "export_onnx_json",
      RUBY_METHOD_FUNC(onnx_native_export_onnx_json),
      6);
  // (fun, args_array, kwargs_hash, shapeless) -> Hash report
  rb_define_singleton_method(
      mONNXNative,
      "export_onnx_compatibility_report",
      RUBY_METHOD_FUNC(onnx_native_export_onnx_compatibility_report),
      4);
  // (target_path, fun, args_array, kwargs_hash, shapeless, opset, model_name,
  //  external_data, external_data_file, external_data_size_threshold) -> String
  rb_define_singleton_method(
      mONNXNative,
      "export_onnx",
      RUBY_METHOD_FUNC(onnx_native_export_onnx),
      10);
  // (ir_source) -> JSON String report
  rb_define_singleton_method(
      mONNXNative,
      "ir_compatibility_report_json",
      RUBY_METHOD_FUNC(onnx_native_graph_ir_compatibility_report_json),
      1);
}