mlx 0.30.7.2 → 0.30.7.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (605)
  1. checksums.yaml +4 -4
  2. data/ext/mlx/extconf.rb +267 -8
  3. data/ext/mlx/native.cpp +112 -58
  4. data/ext/mlx-onnx/native.cpp +1402 -0
  5. data/ext/mlx-onnx/native.hpp +19 -0
  6. data/lib/mlx/core.rb +342 -117
  7. data/lib/mlx/distributed_utils/common.rb +1 -1
  8. data/lib/mlx/distributed_utils/config.rb +7 -4
  9. data/lib/mlx/distributed_utils/launch.rb +2 -0
  10. data/lib/mlx/dsl/attention.rb +132 -0
  11. data/lib/mlx/dsl/builder.rb +8 -0
  12. data/lib/mlx/dsl/config_schema.rb +133 -0
  13. data/lib/mlx/dsl/generate.rb +193 -0
  14. data/lib/mlx/dsl/kv_cache.rb +96 -0
  15. data/lib/mlx/dsl/masks.rb +32 -0
  16. data/lib/mlx/dsl/positions.rb +35 -0
  17. data/lib/mlx/dsl/run_stack.rb +68 -0
  18. data/lib/mlx/dsl/tensor.rb +126 -0
  19. data/lib/mlx/dsl/transformer_block.rb +113 -0
  20. data/lib/mlx/dsl/weight_map.rb +140 -0
  21. data/lib/mlx/dsl.rb +10 -0
  22. data/lib/mlx/nn/base.rb +4 -0
  23. data/lib/mlx/nn/layers/linear.rb +2 -3
  24. data/lib/mlx/onnx.rb +250 -0
  25. data/lib/mlx/version.rb +1 -1
  26. data/lib/mlx-onnx/webgpu_harness.rb +289 -0
  27. data/{mlx → submodules/mlx}/mlx/backend/cuda/cublas_utils.cpp +0 -7
  28. data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/cublas_gemm.cpp +10 -2
  29. data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/cublas_qqmm.cpp +97 -46
  30. data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/cublas_qqmm.h +25 -13
  31. data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/fp_quantize.cu +101 -38
  32. data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/no_qqmm_impl.cpp +1 -2
  33. data/submodules/mlx/mlx/backend/cuda/quantized/qqmm.cpp +193 -0
  34. data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/qqmm_impl.cpp +15 -8
  35. data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/qqmm_impl.h +14 -3
  36. data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/qqmm_utils.cu +36 -0
  37. data/submodules/mlx/mlx/backend/cuda/quantized/qqmm_utils.h +62 -0
  38. data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/quantized.cpp +12 -3
  39. data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/quantized.h +4 -0
  40. data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/quantized_utils.cuh +1 -1
  41. data/{mlx → submodules/mlx}/mlx/backend/metal/device.cpp +4 -0
  42. data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/conv.metal +3 -2
  43. data/{mlx → submodules/mlx}/mlx/export.cpp +21 -6
  44. data/{mlx → submodules/mlx}/mlx/ops.cpp +144 -13
  45. data/{mlx → submodules/mlx}/mlx/ops.h +12 -2
  46. data/{mlx → submodules/mlx}/mlx/primitives.cpp +22 -5
  47. data/{mlx → submodules/mlx}/mlx/scheduler.cpp +4 -0
  48. data/{mlx → submodules/mlx}/mlx/scheduler.h +3 -0
  49. data/{mlx → submodules/mlx}/mlx/stream.h +5 -0
  50. data/submodules/mlx-onnx/CMakeLists.txt +159 -0
  51. data/submodules/mlx-onnx/LICENSE +21 -0
  52. data/submodules/mlx-onnx/include/mlx/ir.hpp +88 -0
  53. data/submodules/mlx-onnx/src/api.cpp +81 -0
  54. data/submodules/mlx-onnx/src/compat.cpp +111 -0
  55. data/submodules/mlx-onnx/src/detail.hpp +69 -0
  56. data/submodules/mlx-onnx/src/export.cpp +653 -0
  57. data/submodules/mlx-onnx/src/io.cpp +61 -0
  58. data/submodules/mlx-onnx/src/json.hpp +25 -0
  59. data/submodules/mlx-onnx/src/lowering.cpp +6346 -0
  60. data/submodules/mlx-onnx/src/mappings.cpp +201 -0
  61. data/submodules/mlx-onnx/src/mappings.hpp +16 -0
  62. data/submodules/mlx-onnx/src/onnx.cpp +1029 -0
  63. data/submodules/mlx-onnx/src/shared.cpp +206 -0
  64. metadata +665 -567
  65. data/mlx/mlx/backend/cuda/quantized/qqmm.cpp +0 -158
  66. data/mlx/mlx/backend/cuda/quantized/qqmm_utils.h +0 -30
  67. /data/{mlx → submodules/mlx}/CMakeLists.txt +0 -0
  68. /data/{mlx → submodules/mlx}/cmake/FindCUDNN.cmake +0 -0
  69. /data/{mlx → submodules/mlx}/cmake/FindNCCL.cmake +0 -0
  70. /data/{mlx → submodules/mlx}/cmake/Findnvpl.cmake +0 -0
  71. /data/{mlx → submodules/mlx}/cmake/extension.cmake +0 -0
  72. /data/{mlx → submodules/mlx}/mlx/3rdparty/.clang-format +0 -0
  73. /data/{mlx → submodules/mlx}/mlx/3rdparty/pocketfft.h +0 -0
  74. /data/{mlx → submodules/mlx}/mlx/CMakeLists.txt +0 -0
  75. /data/{mlx → submodules/mlx}/mlx/allocator.h +0 -0
  76. /data/{mlx → submodules/mlx}/mlx/api.h +0 -0
  77. /data/{mlx → submodules/mlx}/mlx/array.cpp +0 -0
  78. /data/{mlx → submodules/mlx}/mlx/array.h +0 -0
  79. /data/{mlx → submodules/mlx}/mlx/backend/common/CMakeLists.txt +0 -0
  80. /data/{mlx → submodules/mlx}/mlx/backend/common/binary.h +0 -0
  81. /data/{mlx → submodules/mlx}/mlx/backend/common/broadcasting.cpp +0 -0
  82. /data/{mlx → submodules/mlx}/mlx/backend/common/broadcasting.h +0 -0
  83. /data/{mlx → submodules/mlx}/mlx/backend/common/buffer_cache.h +0 -0
  84. /data/{mlx → submodules/mlx}/mlx/backend/common/common.cpp +0 -0
  85. /data/{mlx → submodules/mlx}/mlx/backend/common/compiled.cpp +0 -0
  86. /data/{mlx → submodules/mlx}/mlx/backend/common/compiled.h +0 -0
  87. /data/{mlx → submodules/mlx}/mlx/backend/common/copy.h +0 -0
  88. /data/{mlx → submodules/mlx}/mlx/backend/common/hadamard.h +0 -0
  89. /data/{mlx → submodules/mlx}/mlx/backend/common/load.cpp +0 -0
  90. /data/{mlx → submodules/mlx}/mlx/backend/common/matmul.h +0 -0
  91. /data/{mlx → submodules/mlx}/mlx/backend/common/reduce.cpp +0 -0
  92. /data/{mlx → submodules/mlx}/mlx/backend/common/reduce.h +0 -0
  93. /data/{mlx → submodules/mlx}/mlx/backend/common/slicing.cpp +0 -0
  94. /data/{mlx → submodules/mlx}/mlx/backend/common/slicing.h +0 -0
  95. /data/{mlx → submodules/mlx}/mlx/backend/common/ternary.h +0 -0
  96. /data/{mlx → submodules/mlx}/mlx/backend/common/unary.h +0 -0
  97. /data/{mlx → submodules/mlx}/mlx/backend/common/utils.cpp +0 -0
  98. /data/{mlx → submodules/mlx}/mlx/backend/common/utils.h +0 -0
  99. /data/{mlx → submodules/mlx}/mlx/backend/cpu/CMakeLists.txt +0 -0
  100. /data/{mlx → submodules/mlx}/mlx/backend/cpu/arange.h +0 -0
  101. /data/{mlx → submodules/mlx}/mlx/backend/cpu/arg_reduce.cpp +0 -0
  102. /data/{mlx → submodules/mlx}/mlx/backend/cpu/binary.cpp +0 -0
  103. /data/{mlx → submodules/mlx}/mlx/backend/cpu/binary.h +0 -0
  104. /data/{mlx → submodules/mlx}/mlx/backend/cpu/binary_ops.h +0 -0
  105. /data/{mlx → submodules/mlx}/mlx/backend/cpu/binary_two.h +0 -0
  106. /data/{mlx → submodules/mlx}/mlx/backend/cpu/cholesky.cpp +0 -0
  107. /data/{mlx → submodules/mlx}/mlx/backend/cpu/compiled.cpp +0 -0
  108. /data/{mlx → submodules/mlx}/mlx/backend/cpu/compiled_preamble.h +0 -0
  109. /data/{mlx → submodules/mlx}/mlx/backend/cpu/conv.cpp +0 -0
  110. /data/{mlx → submodules/mlx}/mlx/backend/cpu/copy.cpp +0 -0
  111. /data/{mlx → submodules/mlx}/mlx/backend/cpu/copy.h +0 -0
  112. /data/{mlx → submodules/mlx}/mlx/backend/cpu/device_info.cpp +0 -0
  113. /data/{mlx → submodules/mlx}/mlx/backend/cpu/device_info.h +0 -0
  114. /data/{mlx → submodules/mlx}/mlx/backend/cpu/distributed.cpp +0 -0
  115. /data/{mlx → submodules/mlx}/mlx/backend/cpu/eig.cpp +0 -0
  116. /data/{mlx → submodules/mlx}/mlx/backend/cpu/eigh.cpp +0 -0
  117. /data/{mlx → submodules/mlx}/mlx/backend/cpu/encoder.cpp +0 -0
  118. /data/{mlx → submodules/mlx}/mlx/backend/cpu/encoder.h +0 -0
  119. /data/{mlx → submodules/mlx}/mlx/backend/cpu/eval.cpp +0 -0
  120. /data/{mlx → submodules/mlx}/mlx/backend/cpu/eval.h +0 -0
  121. /data/{mlx → submodules/mlx}/mlx/backend/cpu/fft.cpp +0 -0
  122. /data/{mlx → submodules/mlx}/mlx/backend/cpu/gemm.h +0 -0
  123. /data/{mlx → submodules/mlx}/mlx/backend/cpu/gemms/bnns.cpp +0 -0
  124. /data/{mlx → submodules/mlx}/mlx/backend/cpu/gemms/cblas.cpp +0 -0
  125. /data/{mlx → submodules/mlx}/mlx/backend/cpu/gemms/simd_bf16.cpp +0 -0
  126. /data/{mlx → submodules/mlx}/mlx/backend/cpu/gemms/simd_fp16.cpp +0 -0
  127. /data/{mlx → submodules/mlx}/mlx/backend/cpu/gemms/simd_gemm.h +0 -0
  128. /data/{mlx → submodules/mlx}/mlx/backend/cpu/hadamard.cpp +0 -0
  129. /data/{mlx → submodules/mlx}/mlx/backend/cpu/indexing.cpp +0 -0
  130. /data/{mlx → submodules/mlx}/mlx/backend/cpu/inverse.cpp +0 -0
  131. /data/{mlx → submodules/mlx}/mlx/backend/cpu/jit_compiler.cpp +0 -0
  132. /data/{mlx → submodules/mlx}/mlx/backend/cpu/jit_compiler.h +0 -0
  133. /data/{mlx → submodules/mlx}/mlx/backend/cpu/lapack.h +0 -0
  134. /data/{mlx → submodules/mlx}/mlx/backend/cpu/logsumexp.cpp +0 -0
  135. /data/{mlx → submodules/mlx}/mlx/backend/cpu/luf.cpp +0 -0
  136. /data/{mlx → submodules/mlx}/mlx/backend/cpu/make_compiled_preamble.ps1 +0 -0
  137. /data/{mlx → submodules/mlx}/mlx/backend/cpu/make_compiled_preamble.sh +0 -0
  138. /data/{mlx → submodules/mlx}/mlx/backend/cpu/masked_mm.cpp +0 -0
  139. /data/{mlx → submodules/mlx}/mlx/backend/cpu/matmul.cpp +0 -0
  140. /data/{mlx → submodules/mlx}/mlx/backend/cpu/primitives.cpp +0 -0
  141. /data/{mlx → submodules/mlx}/mlx/backend/cpu/qrf.cpp +0 -0
  142. /data/{mlx → submodules/mlx}/mlx/backend/cpu/quantized.cpp +0 -0
  143. /data/{mlx → submodules/mlx}/mlx/backend/cpu/reduce.cpp +0 -0
  144. /data/{mlx → submodules/mlx}/mlx/backend/cpu/scan.cpp +0 -0
  145. /data/{mlx → submodules/mlx}/mlx/backend/cpu/select.cpp +0 -0
  146. /data/{mlx → submodules/mlx}/mlx/backend/cpu/simd/accelerate_fp16_simd.h +0 -0
  147. /data/{mlx → submodules/mlx}/mlx/backend/cpu/simd/accelerate_simd.h +0 -0
  148. /data/{mlx → submodules/mlx}/mlx/backend/cpu/simd/base_simd.h +0 -0
  149. /data/{mlx → submodules/mlx}/mlx/backend/cpu/simd/math.h +0 -0
  150. /data/{mlx → submodules/mlx}/mlx/backend/cpu/simd/neon_fp16_simd.h +0 -0
  151. /data/{mlx → submodules/mlx}/mlx/backend/cpu/simd/simd.h +0 -0
  152. /data/{mlx → submodules/mlx}/mlx/backend/cpu/simd/type.h +0 -0
  153. /data/{mlx → submodules/mlx}/mlx/backend/cpu/slicing.h +0 -0
  154. /data/{mlx → submodules/mlx}/mlx/backend/cpu/softmax.cpp +0 -0
  155. /data/{mlx → submodules/mlx}/mlx/backend/cpu/sort.cpp +0 -0
  156. /data/{mlx → submodules/mlx}/mlx/backend/cpu/svd.cpp +0 -0
  157. /data/{mlx → submodules/mlx}/mlx/backend/cpu/ternary.h +0 -0
  158. /data/{mlx → submodules/mlx}/mlx/backend/cpu/threefry.cpp +0 -0
  159. /data/{mlx → submodules/mlx}/mlx/backend/cpu/threefry.h +0 -0
  160. /data/{mlx → submodules/mlx}/mlx/backend/cpu/unary.cpp +0 -0
  161. /data/{mlx → submodules/mlx}/mlx/backend/cpu/unary.h +0 -0
  162. /data/{mlx → submodules/mlx}/mlx/backend/cpu/unary_ops.h +0 -0
  163. /data/{mlx → submodules/mlx}/mlx/backend/cuda/CMakeLists.txt +0 -0
  164. /data/{mlx → submodules/mlx}/mlx/backend/cuda/allocator.cpp +0 -0
  165. /data/{mlx → submodules/mlx}/mlx/backend/cuda/allocator.h +0 -0
  166. /data/{mlx → submodules/mlx}/mlx/backend/cuda/arange.cu +0 -0
  167. /data/{mlx → submodules/mlx}/mlx/backend/cuda/arg_reduce.cu +0 -0
  168. /data/{mlx → submodules/mlx}/mlx/backend/cuda/bin2h.cmake +0 -0
  169. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/CMakeLists.txt +0 -0
  170. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/add.cu +0 -0
  171. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/arctan2.cu +0 -0
  172. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/binary.cuh +0 -0
  173. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/bitwise_binary.cu +0 -0
  174. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/divide.cu +0 -0
  175. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/equal.cu +0 -0
  176. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/greater.cu +0 -0
  177. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/greater_equal.cu +0 -0
  178. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/less.cu +0 -0
  179. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/less_equal.cu +0 -0
  180. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/log_add_exp.cu +0 -0
  181. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/logical_and.cu +0 -0
  182. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/logical_or.cu +0 -0
  183. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/maximum.cu +0 -0
  184. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/minimum.cu +0 -0
  185. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/multiply.cu +0 -0
  186. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/not_equal.cu +0 -0
  187. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/power.cu +0 -0
  188. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/remainder.cu +0 -0
  189. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/subtract.cu +0 -0
  190. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary_two.cu +0 -0
  191. /data/{mlx → submodules/mlx}/mlx/backend/cuda/compiled.cpp +0 -0
  192. /data/{mlx → submodules/mlx}/mlx/backend/cuda/conv/conv.h +0 -0
  193. /data/{mlx → submodules/mlx}/mlx/backend/cuda/conv/gemm_conv.cu +0 -0
  194. /data/{mlx → submodules/mlx}/mlx/backend/cuda/conv/gemm_grouped_conv.cu +0 -0
  195. /data/{mlx → submodules/mlx}/mlx/backend/cuda/conv.cpp +0 -0
  196. /data/{mlx → submodules/mlx}/mlx/backend/cuda/copy/copy.cuh +0 -0
  197. /data/{mlx → submodules/mlx}/mlx/backend/cuda/copy/copy_contiguous.cu +0 -0
  198. /data/{mlx → submodules/mlx}/mlx/backend/cuda/copy/copy_general.cu +0 -0
  199. /data/{mlx → submodules/mlx}/mlx/backend/cuda/copy/copy_general_dynamic.cu +0 -0
  200. /data/{mlx → submodules/mlx}/mlx/backend/cuda/copy/copy_general_input.cu +0 -0
  201. /data/{mlx → submodules/mlx}/mlx/backend/cuda/copy.cu +0 -0
  202. /data/{mlx → submodules/mlx}/mlx/backend/cuda/cublas_utils.h +0 -0
  203. /data/{mlx → submodules/mlx}/mlx/backend/cuda/cuda.h +0 -0
  204. /data/{mlx → submodules/mlx}/mlx/backend/cuda/cuda_utils.h +0 -0
  205. /data/{mlx → submodules/mlx}/mlx/backend/cuda/cudnn_utils.cpp +0 -0
  206. /data/{mlx → submodules/mlx}/mlx/backend/cuda/cudnn_utils.h +0 -0
  207. /data/{mlx → submodules/mlx}/mlx/backend/cuda/custom_kernel.cpp +0 -0
  208. /data/{mlx → submodules/mlx}/mlx/backend/cuda/cutlass_utils.cuh +0 -0
  209. /data/{mlx → submodules/mlx}/mlx/backend/cuda/delayload.cpp +0 -0
  210. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/atomic_ops.cuh +0 -0
  211. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/binary_ops.cuh +0 -0
  212. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/cast_op.cuh +0 -0
  213. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/complex.cuh +0 -0
  214. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/config.h +0 -0
  215. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/fp16_math.cuh +0 -0
  216. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/gather.cuh +0 -0
  217. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/gather_axis.cuh +0 -0
  218. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/indexing.cuh +0 -0
  219. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/scatter.cuh +0 -0
  220. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/scatter_axis.cuh +0 -0
  221. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/scatter_ops.cuh +0 -0
  222. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/ternary_ops.cuh +0 -0
  223. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/unary_ops.cuh +0 -0
  224. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/utils.cuh +0 -0
  225. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device.cpp +0 -0
  226. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device.h +0 -0
  227. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device_info.cpp +0 -0
  228. /data/{mlx → submodules/mlx}/mlx/backend/cuda/distributed.cu +0 -0
  229. /data/{mlx → submodules/mlx}/mlx/backend/cuda/eval.cpp +0 -0
  230. /data/{mlx → submodules/mlx}/mlx/backend/cuda/event.cu +0 -0
  231. /data/{mlx → submodules/mlx}/mlx/backend/cuda/event.h +0 -0
  232. /data/{mlx → submodules/mlx}/mlx/backend/cuda/fence.cpp +0 -0
  233. /data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/cublas_gemm.h +0 -0
  234. /data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/cublas_gemm_batched_12_0.cpp +0 -0
  235. /data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/cublas_gemm_batched_12_9.cu +0 -0
  236. /data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/gemv.cu +0 -0
  237. /data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/gemv.h +0 -0
  238. /data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/grouped_gemm.h +0 -0
  239. /data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/grouped_gemm_unaligned.cu +0 -0
  240. /data/{mlx → submodules/mlx}/mlx/backend/cuda/indexing.cpp +0 -0
  241. /data/{mlx → submodules/mlx}/mlx/backend/cuda/jit_module.cpp +0 -0
  242. /data/{mlx → submodules/mlx}/mlx/backend/cuda/jit_module.h +0 -0
  243. /data/{mlx → submodules/mlx}/mlx/backend/cuda/kernel_utils.cu +0 -0
  244. /data/{mlx → submodules/mlx}/mlx/backend/cuda/kernel_utils.cuh +0 -0
  245. /data/{mlx → submodules/mlx}/mlx/backend/cuda/layer_norm.cu +0 -0
  246. /data/{mlx → submodules/mlx}/mlx/backend/cuda/load.cpp +0 -0
  247. /data/{mlx → submodules/mlx}/mlx/backend/cuda/logsumexp.cu +0 -0
  248. /data/{mlx → submodules/mlx}/mlx/backend/cuda/lru_cache.h +0 -0
  249. /data/{mlx → submodules/mlx}/mlx/backend/cuda/matmul.cpp +0 -0
  250. /data/{mlx → submodules/mlx}/mlx/backend/cuda/no_cuda.cpp +0 -0
  251. /data/{mlx → submodules/mlx}/mlx/backend/cuda/primitives.cpp +0 -0
  252. /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/affine_quantize.cu +0 -0
  253. /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/convert_fp8.cu +0 -0
  254. /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/cuda_fp4.h +0 -0
  255. /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/mxfp8_quantize.cuh +0 -0
  256. /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/nvfp4_quantize.cuh +0 -0
  257. /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/qmv.cu +0 -0
  258. /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/qmv.h +0 -0
  259. /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/quantized_utils.h +0 -0
  260. /data/{mlx → submodules/mlx}/mlx/backend/cuda/random.cu +0 -0
  261. /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce/all_reduce.cu +0 -0
  262. /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce/col_reduce.cu +0 -0
  263. /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce/init_reduce.cu +0 -0
  264. /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce/reduce.cuh +0 -0
  265. /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce/reduce_ops.cuh +0 -0
  266. /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce/reduce_utils.cuh +0 -0
  267. /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce/row_reduce.cu +0 -0
  268. /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce.cu +0 -0
  269. /data/{mlx → submodules/mlx}/mlx/backend/cuda/rms_norm.cu +0 -0
  270. /data/{mlx → submodules/mlx}/mlx/backend/cuda/rope.cu +0 -0
  271. /data/{mlx → submodules/mlx}/mlx/backend/cuda/scaled_dot_product_attention.cpp +0 -0
  272. /data/{mlx → submodules/mlx}/mlx/backend/cuda/scaled_dot_product_attention.cu +0 -0
  273. /data/{mlx → submodules/mlx}/mlx/backend/cuda/scan.cu +0 -0
  274. /data/{mlx → submodules/mlx}/mlx/backend/cuda/slicing.cpp +0 -0
  275. /data/{mlx → submodules/mlx}/mlx/backend/cuda/softmax.cu +0 -0
  276. /data/{mlx → submodules/mlx}/mlx/backend/cuda/sort.cu +0 -0
  277. /data/{mlx → submodules/mlx}/mlx/backend/cuda/steel/defines.cuh +0 -0
  278. /data/{mlx → submodules/mlx}/mlx/backend/cuda/steel/gemm.cuh +0 -0
  279. /data/{mlx → submodules/mlx}/mlx/backend/cuda/steel/mma.cuh +0 -0
  280. /data/{mlx → submodules/mlx}/mlx/backend/cuda/steel/tiles.cuh +0 -0
  281. /data/{mlx → submodules/mlx}/mlx/backend/cuda/steel/utils.cuh +0 -0
  282. /data/{mlx → submodules/mlx}/mlx/backend/cuda/ternary.cu +0 -0
  283. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/CMakeLists.txt +0 -0
  284. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/abs.cu +0 -0
  285. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/arccos.cu +0 -0
  286. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/arccosh.cu +0 -0
  287. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/arcsin.cu +0 -0
  288. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/arcsinh.cu +0 -0
  289. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/arctan.cu +0 -0
  290. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/arctanh.cu +0 -0
  291. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/bitwise_invert.cu +0 -0
  292. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/ceil.cu +0 -0
  293. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/conjugate.cu +0 -0
  294. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/cos.cu +0 -0
  295. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/cosh.cu +0 -0
  296. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/erf.cu +0 -0
  297. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/erf_inv.cu +0 -0
  298. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/exp.cu +0 -0
  299. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/expm1.cu +0 -0
  300. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/floor.cu +0 -0
  301. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/imag.cu +0 -0
  302. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/log.cu +0 -0
  303. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/log1p.cu +0 -0
  304. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/logical_not.cu +0 -0
  305. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/negative.cu +0 -0
  306. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/real.cu +0 -0
  307. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/round.cu +0 -0
  308. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/sigmoid.cu +0 -0
  309. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/sign.cu +0 -0
  310. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/sin.cu +0 -0
  311. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/sinh.cu +0 -0
  312. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/sqrt.cu +0 -0
  313. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/square.cu +0 -0
  314. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/tan.cu +0 -0
  315. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/tanh.cu +0 -0
  316. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/unary.cuh +0 -0
  317. /data/{mlx → submodules/mlx}/mlx/backend/cuda/utils.cpp +0 -0
  318. /data/{mlx → submodules/mlx}/mlx/backend/cuda/utils.h +0 -0
  319. /data/{mlx → submodules/mlx}/mlx/backend/cuda/vector_types.cuh +0 -0
  320. /data/{mlx → submodules/mlx}/mlx/backend/cuda/worker.cpp +0 -0
  321. /data/{mlx → submodules/mlx}/mlx/backend/cuda/worker.h +0 -0
  322. /data/{mlx → submodules/mlx}/mlx/backend/gpu/CMakeLists.txt +0 -0
  323. /data/{mlx → submodules/mlx}/mlx/backend/gpu/copy.cpp +0 -0
  324. /data/{mlx → submodules/mlx}/mlx/backend/gpu/copy.h +0 -0
  325. /data/{mlx → submodules/mlx}/mlx/backend/gpu/device_info.h +0 -0
  326. /data/{mlx → submodules/mlx}/mlx/backend/gpu/eval.h +0 -0
  327. /data/{mlx → submodules/mlx}/mlx/backend/gpu/primitives.cpp +0 -0
  328. /data/{mlx → submodules/mlx}/mlx/backend/gpu/slicing.cpp +0 -0
  329. /data/{mlx → submodules/mlx}/mlx/backend/gpu/slicing.h +0 -0
  330. /data/{mlx → submodules/mlx}/mlx/backend/metal/CMakeLists.txt +0 -0
  331. /data/{mlx → submodules/mlx}/mlx/backend/metal/allocator.cpp +0 -0
  332. /data/{mlx → submodules/mlx}/mlx/backend/metal/allocator.h +0 -0
  333. /data/{mlx → submodules/mlx}/mlx/backend/metal/binary.cpp +0 -0
  334. /data/{mlx → submodules/mlx}/mlx/backend/metal/binary.h +0 -0
  335. /data/{mlx → submodules/mlx}/mlx/backend/metal/compiled.cpp +0 -0
  336. /data/{mlx → submodules/mlx}/mlx/backend/metal/conv.cpp +0 -0
  337. /data/{mlx → submodules/mlx}/mlx/backend/metal/copy.cpp +0 -0
  338. /data/{mlx → submodules/mlx}/mlx/backend/metal/custom_kernel.cpp +0 -0
  339. /data/{mlx → submodules/mlx}/mlx/backend/metal/device.h +0 -0
  340. /data/{mlx → submodules/mlx}/mlx/backend/metal/device_info.cpp +0 -0
  341. /data/{mlx → submodules/mlx}/mlx/backend/metal/distributed.cpp +0 -0
  342. /data/{mlx → submodules/mlx}/mlx/backend/metal/eval.cpp +0 -0
  343. /data/{mlx → submodules/mlx}/mlx/backend/metal/event.cpp +0 -0
  344. /data/{mlx → submodules/mlx}/mlx/backend/metal/fence.cpp +0 -0
  345. /data/{mlx → submodules/mlx}/mlx/backend/metal/fft.cpp +0 -0
  346. /data/{mlx → submodules/mlx}/mlx/backend/metal/hadamard.cpp +0 -0
  347. /data/{mlx → submodules/mlx}/mlx/backend/metal/indexing.cpp +0 -0
  348. /data/{mlx → submodules/mlx}/mlx/backend/metal/jit/includes.h +0 -0
  349. /data/{mlx → submodules/mlx}/mlx/backend/metal/jit/indexing.h +0 -0
  350. /data/{mlx → submodules/mlx}/mlx/backend/metal/jit_kernels.cpp +0 -0
  351. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/CMakeLists.txt +0 -0
  352. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/arange.h +0 -0
  353. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/arange.metal +0 -0
  354. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/arg_reduce.metal +0 -0
  355. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/atomic.h +0 -0
  356. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/bf16.h +0 -0
  357. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/bf16_math.h +0 -0
  358. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/binary.h +0 -0
  359. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/binary.metal +0 -0
  360. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/binary_ops.h +0 -0
  361. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/binary_two.h +0 -0
  362. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/binary_two.metal +0 -0
  363. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/cexpf.h +0 -0
  364. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/complex.h +0 -0
  365. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/copy.h +0 -0
  366. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/copy.metal +0 -0
  367. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/defines.h +0 -0
  368. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/erf.h +0 -0
  369. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/expm1f.h +0 -0
  370. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fence.metal +0 -0
  371. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fft/radix.h +0 -0
  372. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fft/readwrite.h +0 -0
  373. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fft.h +0 -0
  374. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fft.metal +0 -0
  375. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fp4.h +0 -0
  376. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fp8.h +0 -0
  377. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fp_quantized.h +0 -0
  378. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fp_quantized.metal +0 -0
  379. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fp_quantized_nax.h +0 -0
  380. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fp_quantized_nax.metal +0 -0
  381. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/gemv.metal +0 -0
  382. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/gemv_masked.h +0 -0
  383. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/gemv_masked.metal +0 -0
  384. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/hadamard.h +0 -0
  385. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/indexing/gather.h +0 -0
  386. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/indexing/gather_axis.h +0 -0
  387. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/indexing/gather_front.h +0 -0
  388. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/indexing/indexing.h +0 -0
  389. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/indexing/masked_scatter.h +0 -0
  390. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/indexing/scatter.h +0 -0
  391. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/indexing/scatter_axis.h +0 -0
  392. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/layer_norm.metal +0 -0
  393. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/logging.h +0 -0
  394. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/logsumexp.h +0 -0
  395. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/logsumexp.metal +0 -0
  396. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/quantized.h +0 -0
  397. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/quantized.metal +0 -0
  398. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/quantized_nax.h +0 -0
  399. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/quantized_nax.metal +0 -0
  400. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/quantized_utils.h +0 -0
  401. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/random.metal +0 -0
  402. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduce.h +0 -0
  403. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduce.metal +0 -0
  404. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduce_utils.h +0 -0
  405. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduction/ops.h +0 -0
  406. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduction/reduce_all.h +0 -0
  407. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduction/reduce_col.h +0 -0
  408. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduction/reduce_init.h +0 -0
  409. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduction/reduce_row.h +0 -0
  410. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/rms_norm.metal +0 -0
  411. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/rope.metal +0 -0
  412. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/scaled_dot_product_attention.metal +0 -0
  413. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/scan.h +0 -0
  414. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/scan.metal +0 -0
  415. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/sdpa_vector.h +0 -0
  416. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/softmax.h +0 -0
  417. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/softmax.metal +0 -0
  418. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/sort.h +0 -0
  419. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/sort.metal +0 -0
  420. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/attn.h +0 -0
  421. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +0 -0
  422. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal +0 -0
  423. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +0 -0
  424. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.metal +0 -0
  425. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/loader.h +0 -0
  426. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/mma.h +0 -0
  427. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/nax.h +0 -0
  428. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/params.h +0 -0
  429. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/transforms.h +0 -0
  430. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/conv.h +0 -0
  431. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +0 -0
  432. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.metal +0 -0
  433. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +0 -0
  434. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.metal +0 -0
  435. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/loader.h +0 -0
  436. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +0 -0
  437. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +0 -0
  438. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +0 -0
  439. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/params.h +0 -0
  440. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/defines.h +0 -0
  441. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/gemm.h +0 -0
  442. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +0 -0
  443. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +0 -0
  444. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.metal +0 -0
  445. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +0 -0
  446. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.metal +0 -0
  447. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +0 -0
  448. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.metal +0 -0
  449. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +0 -0
  450. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.metal +0 -0
  451. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +0 -0
  452. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.metal +0 -0
  453. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +0 -0
  454. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.metal +0 -0
  455. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +0 -0
  456. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.metal +0 -0
  457. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.h +0 -0
  458. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.metal +0 -0
  459. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/loader.h +0 -0
  460. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/mma.h +0 -0
  461. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/nax.h +0 -0
  462. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/params.h +0 -0
  463. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/transforms.h +0 -0
  464. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/utils/integral_constant.h +0 -0
  465. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/utils/type_traits.h +0 -0
  466. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/utils.h +0 -0
  467. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/ternary.h +0 -0
  468. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/ternary.metal +0 -0
  469. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/ternary_ops.h +0 -0
  470. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/unary.h +0 -0
  471. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/unary.metal +0 -0
  472. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/unary_ops.h +0 -0
  473. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/utils.h +0 -0
  474. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels.h +0 -0
  475. /data/{mlx → submodules/mlx}/mlx/backend/metal/logsumexp.cpp +0 -0
  476. /data/{mlx → submodules/mlx}/mlx/backend/metal/make_compiled_preamble.sh +0 -0
  477. /data/{mlx → submodules/mlx}/mlx/backend/metal/matmul.cpp +0 -0
  478. /data/{mlx → submodules/mlx}/mlx/backend/metal/matmul.h +0 -0
  479. /data/{mlx → submodules/mlx}/mlx/backend/metal/metal.cpp +0 -0
  480. /data/{mlx → submodules/mlx}/mlx/backend/metal/metal.h +0 -0
  481. /data/{mlx → submodules/mlx}/mlx/backend/metal/no_metal.cpp +0 -0
  482. /data/{mlx → submodules/mlx}/mlx/backend/metal/nojit_kernels.cpp +0 -0
  483. /data/{mlx → submodules/mlx}/mlx/backend/metal/normalization.cpp +0 -0
  484. /data/{mlx → submodules/mlx}/mlx/backend/metal/primitives.cpp +0 -0
  485. /data/{mlx → submodules/mlx}/mlx/backend/metal/quantized.cpp +0 -0
  486. /data/{mlx → submodules/mlx}/mlx/backend/metal/reduce.cpp +0 -0
  487. /data/{mlx → submodules/mlx}/mlx/backend/metal/reduce.h +0 -0
  488. /data/{mlx → submodules/mlx}/mlx/backend/metal/resident.cpp +0 -0
  489. /data/{mlx → submodules/mlx}/mlx/backend/metal/resident.h +0 -0
  490. /data/{mlx → submodules/mlx}/mlx/backend/metal/rope.cpp +0 -0
  491. /data/{mlx → submodules/mlx}/mlx/backend/metal/scaled_dot_product_attention.cpp +0 -0
  492. /data/{mlx → submodules/mlx}/mlx/backend/metal/scan.cpp +0 -0
  493. /data/{mlx → submodules/mlx}/mlx/backend/metal/scan.h +0 -0
  494. /data/{mlx → submodules/mlx}/mlx/backend/metal/slicing.cpp +0 -0
  495. /data/{mlx → submodules/mlx}/mlx/backend/metal/softmax.cpp +0 -0
  496. /data/{mlx → submodules/mlx}/mlx/backend/metal/sort.cpp +0 -0
  497. /data/{mlx → submodules/mlx}/mlx/backend/metal/ternary.cpp +0 -0
  498. /data/{mlx → submodules/mlx}/mlx/backend/metal/ternary.h +0 -0
  499. /data/{mlx → submodules/mlx}/mlx/backend/metal/unary.cpp +0 -0
  500. /data/{mlx → submodules/mlx}/mlx/backend/metal/unary.h +0 -0
  501. /data/{mlx → submodules/mlx}/mlx/backend/metal/utils.cpp +0 -0
  502. /data/{mlx → submodules/mlx}/mlx/backend/metal/utils.h +0 -0
  503. /data/{mlx → submodules/mlx}/mlx/backend/no_cpu/CMakeLists.txt +0 -0
  504. /data/{mlx → submodules/mlx}/mlx/backend/no_cpu/compiled.cpp +0 -0
  505. /data/{mlx → submodules/mlx}/mlx/backend/no_cpu/device_info.cpp +0 -0
  506. /data/{mlx → submodules/mlx}/mlx/backend/no_cpu/primitives.cpp +0 -0
  507. /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/CMakeLists.txt +0 -0
  508. /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/allocator.cpp +0 -0
  509. /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/apple_memory.h +0 -0
  510. /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/device_info.cpp +0 -0
  511. /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/eval.cpp +0 -0
  512. /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/event.cpp +0 -0
  513. /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/fence.cpp +0 -0
  514. /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/linux_memory.h +0 -0
  515. /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/primitives.cpp +0 -0
  516. /data/{mlx → submodules/mlx}/mlx/compile.cpp +0 -0
  517. /data/{mlx → submodules/mlx}/mlx/compile.h +0 -0
  518. /data/{mlx → submodules/mlx}/mlx/compile_impl.h +0 -0
  519. /data/{mlx → submodules/mlx}/mlx/device.cpp +0 -0
  520. /data/{mlx → submodules/mlx}/mlx/device.h +0 -0
  521. /data/{mlx → submodules/mlx}/mlx/distributed/CMakeLists.txt +0 -0
  522. /data/{mlx → submodules/mlx}/mlx/distributed/distributed.cpp +0 -0
  523. /data/{mlx → submodules/mlx}/mlx/distributed/distributed.h +0 -0
  524. /data/{mlx → submodules/mlx}/mlx/distributed/distributed_impl.h +0 -0
  525. /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/CMakeLists.txt +0 -0
  526. /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/jaccl.cpp +0 -0
  527. /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/jaccl.h +0 -0
  528. /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/mesh.cpp +0 -0
  529. /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/mesh.h +0 -0
  530. /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/no_jaccl.cpp +0 -0
  531. /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/ring.cpp +0 -0
  532. /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/ring.h +0 -0
  533. /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/utils.cpp +0 -0
  534. /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/utils.h +0 -0
  535. /data/{mlx → submodules/mlx}/mlx/distributed/mpi/CMakeLists.txt +0 -0
  536. /data/{mlx → submodules/mlx}/mlx/distributed/mpi/mpi.cpp +0 -0
  537. /data/{mlx → submodules/mlx}/mlx/distributed/mpi/mpi.h +0 -0
  538. /data/{mlx → submodules/mlx}/mlx/distributed/mpi/mpi_declarations.h +0 -0
  539. /data/{mlx → submodules/mlx}/mlx/distributed/mpi/no_mpi.cpp +0 -0
  540. /data/{mlx → submodules/mlx}/mlx/distributed/nccl/CMakeLists.txt +0 -0
  541. /data/{mlx → submodules/mlx}/mlx/distributed/nccl/nccl.cpp +0 -0
  542. /data/{mlx → submodules/mlx}/mlx/distributed/nccl/nccl.h +0 -0
  543. /data/{mlx → submodules/mlx}/mlx/distributed/nccl/nccl_stub/CMakeLists.txt +0 -0
  544. /data/{mlx → submodules/mlx}/mlx/distributed/nccl/nccl_stub/nccl_stubs.cpp +0 -0
  545. /data/{mlx → submodules/mlx}/mlx/distributed/nccl/no_nccl.cpp +0 -0
  546. /data/{mlx → submodules/mlx}/mlx/distributed/ops.cpp +0 -0
  547. /data/{mlx → submodules/mlx}/mlx/distributed/ops.h +0 -0
  548. /data/{mlx → submodules/mlx}/mlx/distributed/primitives.cpp +0 -0
  549. /data/{mlx → submodules/mlx}/mlx/distributed/primitives.h +0 -0
  550. /data/{mlx → submodules/mlx}/mlx/distributed/reduction_ops.h +0 -0
  551. /data/{mlx → submodules/mlx}/mlx/distributed/ring/CMakeLists.txt +0 -0
  552. /data/{mlx → submodules/mlx}/mlx/distributed/ring/no_ring.cpp +0 -0
  553. /data/{mlx → submodules/mlx}/mlx/distributed/ring/ring.cpp +0 -0
  554. /data/{mlx → submodules/mlx}/mlx/distributed/ring/ring.h +0 -0
  555. /data/{mlx → submodules/mlx}/mlx/distributed/utils.cpp +0 -0
  556. /data/{mlx → submodules/mlx}/mlx/distributed/utils.h +0 -0
  557. /data/{mlx → submodules/mlx}/mlx/dtype.cpp +0 -0
  558. /data/{mlx → submodules/mlx}/mlx/dtype.h +0 -0
  559. /data/{mlx → submodules/mlx}/mlx/dtype_utils.cpp +0 -0
  560. /data/{mlx → submodules/mlx}/mlx/dtype_utils.h +0 -0
  561. /data/{mlx → submodules/mlx}/mlx/einsum.cpp +0 -0
  562. /data/{mlx → submodules/mlx}/mlx/einsum.h +0 -0
  563. /data/{mlx → submodules/mlx}/mlx/event.h +0 -0
  564. /data/{mlx → submodules/mlx}/mlx/export.h +0 -0
  565. /data/{mlx → submodules/mlx}/mlx/export_impl.h +0 -0
  566. /data/{mlx → submodules/mlx}/mlx/fast.cpp +0 -0
  567. /data/{mlx → submodules/mlx}/mlx/fast.h +0 -0
  568. /data/{mlx → submodules/mlx}/mlx/fast_primitives.h +0 -0
  569. /data/{mlx → submodules/mlx}/mlx/fence.h +0 -0
  570. /data/{mlx → submodules/mlx}/mlx/fft.cpp +0 -0
  571. /data/{mlx → submodules/mlx}/mlx/fft.h +0 -0
  572. /data/{mlx → submodules/mlx}/mlx/graph_utils.cpp +0 -0
  573. /data/{mlx → submodules/mlx}/mlx/graph_utils.h +0 -0
  574. /data/{mlx → submodules/mlx}/mlx/io/CMakeLists.txt +0 -0
  575. /data/{mlx → submodules/mlx}/mlx/io/gguf.cpp +0 -0
  576. /data/{mlx → submodules/mlx}/mlx/io/gguf.h +0 -0
  577. /data/{mlx → submodules/mlx}/mlx/io/gguf_quants.cpp +0 -0
  578. /data/{mlx → submodules/mlx}/mlx/io/load.cpp +0 -0
  579. /data/{mlx → submodules/mlx}/mlx/io/load.h +0 -0
  580. /data/{mlx → submodules/mlx}/mlx/io/no_gguf.cpp +0 -0
  581. /data/{mlx → submodules/mlx}/mlx/io/no_safetensors.cpp +0 -0
  582. /data/{mlx → submodules/mlx}/mlx/io/safetensors.cpp +0 -0
  583. /data/{mlx → submodules/mlx}/mlx/io.h +0 -0
  584. /data/{mlx → submodules/mlx}/mlx/linalg.cpp +0 -0
  585. /data/{mlx → submodules/mlx}/mlx/linalg.h +0 -0
  586. /data/{mlx → submodules/mlx}/mlx/memory.h +0 -0
  587. /data/{mlx → submodules/mlx}/mlx/mlx.h +0 -0
  588. /data/{mlx → submodules/mlx}/mlx/primitives.h +0 -0
  589. /data/{mlx → submodules/mlx}/mlx/random.cpp +0 -0
  590. /data/{mlx → submodules/mlx}/mlx/random.h +0 -0
  591. /data/{mlx → submodules/mlx}/mlx/small_vector.h +0 -0
  592. /data/{mlx → submodules/mlx}/mlx/threadpool.h +0 -0
  593. /data/{mlx → submodules/mlx}/mlx/transforms.cpp +0 -0
  594. /data/{mlx → submodules/mlx}/mlx/transforms.h +0 -0
  595. /data/{mlx → submodules/mlx}/mlx/transforms_impl.h +0 -0
  596. /data/{mlx → submodules/mlx}/mlx/types/bf16.h +0 -0
  597. /data/{mlx → submodules/mlx}/mlx/types/complex.h +0 -0
  598. /data/{mlx → submodules/mlx}/mlx/types/fp16.h +0 -0
  599. /data/{mlx → submodules/mlx}/mlx/types/half_types.h +0 -0
  600. /data/{mlx → submodules/mlx}/mlx/types/limits.h +0 -0
  601. /data/{mlx → submodules/mlx}/mlx/utils.cpp +0 -0
  602. /data/{mlx → submodules/mlx}/mlx/utils.h +0 -0
  603. /data/{mlx → submodules/mlx}/mlx/version.cpp +0 -0
  604. /data/{mlx → submodules/mlx}/mlx/version.h +0 -0
  605. /data/{mlx → submodules/mlx}/mlx.pc.in +0 -0
@@ -0,0 +1,62 @@
1
+ // Copyright © 2025 Apple Inc.
2
+
3
+ #pragma once
4
+
5
+ #include "mlx/array.h"
6
+ #include "mlx/backend/cuda/device.h"
7
+
8
+ namespace mlx::core {
9
+
10
+ // Compute padded dimensions for tiled layout
11
+ // Tiles are 128 rows × 4 columns, must allocate full tiles
12
+ inline std::pair<int, int> get_padded_scale_dims(int num_rows, int num_cols) {
13
+ constexpr int rows_per_tile = 128;
14
+ constexpr int cols_per_tile = 4;
15
+
16
+ int padded_rows =
17
+ ((num_rows + rows_per_tile - 1) / rows_per_tile) * rows_per_tile;
18
+ int padded_cols =
19
+ ((num_cols + cols_per_tile - 1) / cols_per_tile) * cols_per_tile;
20
+
21
+ return {padded_rows, padded_cols};
22
+ }
23
+
24
+ void swizzle_scales(
25
+ const array& scales,
26
+ array& scales_tiled,
27
+ cu::CommandEncoder& enc,
28
+ const Stream& s);
29
+
30
+ inline array pad_and_swizzle_scales(
31
+ const array& scale,
32
+ cu::CommandEncoder& encoder,
33
+ const Stream& s) {
34
+ // Compute padded dimensions for full tiles (128 rows × 4 cols)
35
+ auto [pad_outer, pad_inner] =
36
+ get_padded_scale_dims(scale.shape(-2), scale.shape(-1));
37
+ // cuBLAS requirements for scale factor layout:
38
+ // 1. Dimensions must be padded to full tiles (128 rows × 4 cols)
39
+ // 2. Out-of-bounds values must be filled with zeros
40
+ // 3. Starting addresses must be 16-byte aligned
41
+ // https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout
42
+ // Note: cu::malloc_async already provides 256-byte alignment
43
+ array scale_tiled(
44
+ cu::malloc_async(pad_outer * pad_inner, encoder),
45
+ Shape{pad_outer, pad_inner},
46
+ scale.dtype());
47
+ swizzle_scales(scale, scale_tiled, encoder, s);
48
+
49
+ encoder.add_temporary(scale_tiled);
50
+ return scale_tiled;
51
+ }
52
+
53
+ // Compute alpha = tensor_amax_x * tensor_amax_w / (448 * 6)^2
54
+ // Allocate beta zero on device as well
55
+ void compute_qqmm_pointers(
56
+ array& alpha_out,
57
+ array& beta_out,
58
+ const array& tensor_amax_x,
59
+ const array& tensor_amax_w,
60
+ cu::CommandEncoder& enc);
61
+
62
+ } // namespace mlx::core
@@ -51,7 +51,6 @@ void fast::Quantize::eval_gpu(
51
51
  auto& s = stream();
52
52
  auto& d = cu::device(s.device);
53
53
  auto& enc = d.get_command_encoder(s);
54
-
55
54
  if (dequantize_) {
56
55
  auto wq = ensure_row_contiguous(inputs[0], enc, s);
57
56
  auto scales = ensure_row_contiguous(inputs[1], enc, s);
@@ -63,7 +62,12 @@ void fast::Quantize::eval_gpu(
63
62
  auto biases = ensure_row_contiguous(inputs[2], enc, s);
64
63
  affine_dequantize(wq, scales, biases, w, group_size_, bits_, enc, s);
65
64
  } else {
66
- fp_dequantize(wq, scales, w, group_size_, bits_, enc, s);
65
+ // 0 -- xq, 1 -- scales, 2 -- could be global scale for nvfp4
66
+ bool use_global_scale =
67
+ mode_ == QuantizationMode::Nvfp4 && inputs.size() > 2;
68
+ std::optional<array> global_scale =
69
+ use_global_scale ? std::make_optional(inputs[2]) : std::nullopt;
70
+ fp_dequantize(wq, scales, w, group_size_, bits_, global_scale, enc, s);
67
71
  }
68
72
  } else {
69
73
  auto w = ensure_contiguous(inputs[0], enc, s);
@@ -72,12 +76,17 @@ void fast::Quantize::eval_gpu(
72
76
 
73
77
  wq.set_data(cu::malloc_async(wq.nbytes(), enc));
74
78
  scales.set_data(cu::malloc_async(scales.nbytes(), enc));
79
+
75
80
  if (mode_ == QuantizationMode::Affine) {
76
81
  auto& biases = outputs[2];
77
82
  biases.set_data(cu::malloc_async(biases.nbytes(), enc));
78
83
  affine_quantize(w, wq, scales, biases, group_size_, bits_, enc, s);
79
84
  } else {
80
- fp_quantize(w, wq, scales, group_size_, bits_, enc, s);
85
+ bool use_global_scale =
86
+ mode_ == QuantizationMode::Nvfp4 && inputs.size() > 1;
87
+ std::optional<array> global_scale =
88
+ use_global_scale ? std::make_optional(inputs[1]) : std::nullopt;
89
+ fp_quantize(w, wq, scales, group_size_, bits_, global_scale, enc, s);
81
90
  }
82
91
  }
83
92
  }
@@ -1,5 +1,6 @@
1
1
  // Copyright © 2025 Apple Inc.
2
2
 
3
+ #include <optional>
3
4
  #include "mlx/backend/cuda/device.h"
4
5
 
5
6
  namespace mlx::core {
@@ -30,6 +31,7 @@ void fp_quantize(
30
31
  array& scales,
31
32
  int group_size,
32
33
  int bits,
34
+ const std::optional<array>& global_scale,
33
35
  cu::CommandEncoder& enc,
34
36
  const Stream& s);
35
37
 
@@ -39,6 +41,7 @@ void fp_dequantize(
39
41
  array& w,
40
42
  int group_size,
41
43
  int bits,
44
+ const std::optional<array>& global_scale,
42
45
  cu::CommandEncoder& enc,
43
46
  const Stream& s);
44
47
 
@@ -47,6 +50,7 @@ void fp_quantize_dequantize(
47
50
  array& what,
48
51
  int group_size,
49
52
  int bits,
53
+ const std::optional<array>& global_scale,
50
54
  cu::CommandEncoder& enc,
51
55
  const Stream& s);
52
56
 
@@ -29,7 +29,7 @@ inline constexpr __device__ short get_bytes_per_pack() {
29
29
  }
30
30
 
31
31
  template <typename T>
32
- __device__ __forceinline__ void abs_max_x2(T& out, const T& x1, const T& x2) {
32
+ __device__ __forceinline__ void absmax_x2(T& out, const T& x1, const T& x2) {
33
33
  if constexpr (
34
34
  (std::is_same<T, __nv_bfloat162>::value) ||
35
35
  (std::is_same<T, __half2>::value)) {
@@ -247,6 +247,10 @@ void CommandEncoder::set_buffer(
247
247
  const MTL::Buffer* buf,
248
248
  int idx,
249
249
  int64_t offset /* = 0 */) {
250
+ // Record as both input and output to ensure synchronization between command
251
+ // buffers
252
+ all_inputs_.insert((void*)buf);
253
+ all_outputs_.insert((void*)buf);
250
254
  enc_->setBuffer(buf, offset, idx);
251
255
  }
252
256
 
@@ -30,7 +30,7 @@ template <typename T, int N>
30
30
  out_pixels *= params->oS[i];
31
31
 
32
32
  // Set out
33
- out += gid.z * filter_size + gid.y * (params->C);
33
+ out += (size_t)gid.z * filter_size + (size_t)gid.y * (params->C);
34
34
 
35
35
  // Coordinates in input
36
36
  int is[N] = {0};
@@ -93,7 +93,8 @@ template <typename T, int N>
93
93
  out_pixels *= params->oS[i];
94
94
 
95
95
  // Set out
96
- out += gid.z * filter_size + gid.x * (filter_size / params->C);
96
+ out +=
97
+ (size_t)gid.z * filter_size + (size_t)gid.x * (filter_size / params->C);
97
98
 
98
99
  // Coordinates in input
99
100
  int is[N] = {0};
@@ -279,6 +279,8 @@ void extract_state(const T state, std::vector<StateT>& unpacked_state) {
279
279
  unpacked_state.push_back(state);
280
280
  } else if constexpr (std::is_enum_v<T>) {
281
281
  unpacked_state.push_back(static_cast<int>(state));
282
+ } else if constexpr (std::is_same_v<T, Dtype>) {
283
+ unpacked_state.push_back(state);
282
284
  } else if constexpr (is_iterable<T>) {
283
285
  unpacked_state.push_back(state);
284
286
  } else if constexpr (is_pair<T> || is_tuple<T>) {
@@ -446,6 +448,7 @@ struct PrimitiveFactory {
446
448
  SERIALIZE_PRIMITIVE(ScaledDotProductAttention),
447
449
  SERIALIZE_PRIMITIVE(CustomKernel)};
448
450
  std::unordered_map<std::string, std::string> name_remap;
451
+ std::unordered_map<int, Stream> stream_map;
449
452
 
450
453
  PrimitiveFactory() {
451
454
  for (auto& [n, f] : factory) {
@@ -471,13 +474,25 @@ struct PrimitiveFactory {
471
474
  }
472
475
  };
473
476
 
474
- std::shared_ptr<Primitive> load(Reader& is) {
475
- auto stream = deserialize<Stream>(is);
476
- if (get_stream(stream.index) != stream) {
477
- std::ostringstream msg;
478
- msg << "[import_function] Invalid stream encountered " << stream << ".";
479
- throw std::invalid_argument(msg.str());
477
+ Stream resolve_stream(const Stream& stream) {
478
+ if (auto it = stream_map.find(stream.index); it != stream_map.end()) {
479
+ return it->second;
480
+ }
481
+ // Try to find an existing stream on the same device
482
+ for (auto& s : get_streams()) {
483
+ if (s.device == stream.device) {
484
+ stream_map.emplace(stream.index, s);
485
+ return s;
486
+ }
480
487
  }
488
+ // No stream on that device, make a new one
489
+ Stream s = new_stream(stream.device);
490
+ stream_map.emplace(stream.index, s);
491
+ return s;
492
+ }
493
+
494
+ std::shared_ptr<Primitive> load(Reader& is) {
495
+ auto stream = resolve_stream(deserialize<Stream>(is));
481
496
  auto name = deserialize<std::string>(is);
482
497
  if (auto it = factory.find(name); it != factory.end()) {
483
498
  return it->second.deserialize(is, stream);
@@ -10,6 +10,7 @@
10
10
  #include <sstream>
11
11
 
12
12
  #include "mlx/backend/cuda/cuda.h"
13
+ #include "mlx/backend/metal/metal.h"
13
14
  #include "mlx/fast_primitives.h"
14
15
  #include "mlx/ops.h"
15
16
  #include "mlx/primitives.h"
@@ -2311,6 +2312,40 @@ array argmax(
2311
2312
  return out;
2312
2313
  }
2313
2314
 
2315
+ array hanning(int M, StreamOrDevice s /* = {} */) {
2316
+ if (M < 1) {
2317
+ return array({});
2318
+ }
2319
+ if (M == 1) {
2320
+ return ones({1}, float32, s);
2321
+ }
2322
+
2323
+ auto n = arange(0, M, float32, s);
2324
+ array factor(M_PI / (M - 1), float32);
2325
+ return square(sin(multiply(factor, n, s), s), s);
2326
+ }
2327
+
2328
+ array hamming(int M, StreamOrDevice s /* = {} */) {
2329
+ if (M < 1) {
2330
+ return array({});
2331
+ }
2332
+ if (M == 1) {
2333
+ return ones({1}, float32, s);
2334
+ }
2335
+
2336
+ auto n = arange(0, M, float32, s);
2337
+ float factor_val = (2.0 * M_PI) / (M - 1);
2338
+ auto factor = array(factor_val, float32);
2339
+
2340
+ auto arg = multiply(factor, n, s);
2341
+ auto cos_vals = cos(arg, s);
2342
+
2343
+ auto left_coef = array(0.54f, float32);
2344
+ auto right_coef = array(0.46f, float32);
2345
+
2346
+ return subtract(left_coef, multiply(right_coef, cos_vals, s), s);
2347
+ }
2348
+
2314
2349
  /** Returns a sorted copy of the flattened array. */
2315
2350
  array sort(const array& a, StreamOrDevice s /* = {} */) {
2316
2351
  int size = a.size();
@@ -4209,6 +4244,34 @@ std::pair<Dtype, QuantizationMode> validate_mode_with_type(
4209
4244
  }
4210
4245
  }
4211
4246
 
4247
+ void validate_global_scale(
4248
+ std::string_view tag,
4249
+ QuantizationMode qmode,
4250
+ const std::optional<array>& global_scale) {
4251
+ if (global_scale.has_value()) {
4252
+ if (qmode != QuantizationMode::Nvfp4) {
4253
+ std::ostringstream msg;
4254
+ msg << "[" << tag << "] Global scale is only supported for 'nvfp4' "
4255
+ << "quantization mode.";
4256
+ throw std::invalid_argument(msg.str());
4257
+ } else {
4258
+ if (global_scale->size() != 1) {
4259
+ std::ostringstream msg;
4260
+ msg << "[" << tag << "] Global scale must be a scalar but got shape "
4261
+ << global_scale->shape() << ".";
4262
+ throw std::invalid_argument(msg.str());
4263
+ }
4264
+ // TODO: not sure if type should be restricted to float32
4265
+ if (global_scale->dtype() != float32) {
4266
+ std::ostringstream msg;
4267
+ msg << "[" << tag << "] Global scale must have dtype float32 but got "
4268
+ << global_scale->dtype() << ".";
4269
+ throw std::invalid_argument(msg.str());
4270
+ }
4271
+ }
4272
+ }
4273
+ }
4274
+
4212
4275
  array quantized_matmul(
4213
4276
  array x,
4214
4277
  array w,
@@ -4251,7 +4314,6 @@ array quantized_matmul(
4251
4314
  if (x.ndim() > 2 && w.ndim() > 2) {
4252
4315
  inputs = broadcast_arrays(inputs, {-2, -1}, s);
4253
4316
  }
4254
-
4255
4317
  auto out_shape = inputs[0].shape();
4256
4318
  out_shape.back() = w_outer_dims;
4257
4319
  return array(
@@ -4267,7 +4329,10 @@ void validate_qqmm_inputs(
4267
4329
  array w,
4268
4330
  std::optional<array> scales_w,
4269
4331
  int group_size,
4270
- int bits) {
4332
+ int bits,
4333
+ std::optional<array> global_scale_x,
4334
+ std::optional<array> global_scale_w,
4335
+ QuantizationMode qmode) {
4271
4336
  // check 2D (for now)
4272
4337
  if (x.ndim() > 2 || w.ndim() > 2) {
4273
4338
  std::ostringstream msg;
@@ -4304,6 +4369,19 @@ void validate_qqmm_inputs(
4304
4369
  << "first argument dtype == " << x.dtype() << ".";
4305
4370
  throw std::invalid_argument(msg.str());
4306
4371
  }
4372
+ // validate global scales
4373
+ validate_global_scale("qqmm", qmode, global_scale_x);
4374
+ validate_global_scale("qqmm", qmode, global_scale_w);
4375
+ // For nvfp4 mode, both global scales must be provided together or neither
4376
+ if (qmode == QuantizationMode::Nvfp4) {
4377
+ bool has_x = global_scale_x.has_value();
4378
+ bool has_w = global_scale_w.has_value();
4379
+ if (has_x != has_w) {
4380
+ throw std::invalid_argument(
4381
+ "[qqmm] For nvfp4 mode, either both global_scale_x and "
4382
+ "global_scale_w must be provided, or neither.");
4383
+ }
4384
+ }
4307
4385
  }
4308
4386
 
4309
4387
  std::pair<int, int> extract_qqmm_dims(
@@ -4343,6 +4421,8 @@ array qqmm(
4343
4421
  std::optional<int> group_size_ /* = std::nullopt */,
4344
4422
  std::optional<int> bits_ /* = std::nullopt */,
4345
4423
  const std::string& mode /* = "nvfp4" */,
4424
+ const std::optional<array> global_scale_x /* = std::nullopt */,
4425
+ const std::optional<array> global_scale_w /* = std::nullopt */,
4346
4426
  StreamOrDevice s /* = {} */) {
4347
4427
  auto stream = to_stream(s);
4348
4428
  auto qmode = string_to_quantization_mode(mode, "qqmm");
@@ -4369,7 +4449,8 @@ array qqmm(
4369
4449
  }
4370
4450
 
4371
4451
  // validate inputs
4372
- validate_qqmm_inputs(x, w, scales_w, group_size, bits);
4452
+ validate_qqmm_inputs(
4453
+ x, w, scales_w, group_size, bits, global_scale_x, global_scale_w, qmode);
4373
4454
  // validate and extract shapes
4374
4455
  auto [w_inner_dims, w_outer_dims] =
4375
4456
  extract_qqmm_dims(x, w, scales_w, group_size, bits);
@@ -4380,6 +4461,11 @@ array qqmm(
4380
4461
  if (scales_w.has_value()) {
4381
4462
  inputs.push_back(*scales_w);
4382
4463
  }
4464
+ if (global_scale_x.has_value() && global_scale_w.has_value()) {
4465
+ inputs.push_back(*global_scale_x);
4466
+ inputs.push_back(*global_scale_w);
4467
+ }
4468
+
4383
4469
  auto out_shape = inputs[0].shape();
4384
4470
  out_shape.back() = w_outer_dims;
4385
4471
  auto out = array(
@@ -4515,6 +4601,7 @@ std::vector<array> fp_quantize(
4515
4601
  int group_size,
4516
4602
  int bits,
4517
4603
  QuantizationMode mode,
4604
+ const std::optional<array>& global_scale /* = std::nullopt */,
4518
4605
  Stream s) {
4519
4606
  int expected_gs = mode == QuantizationMode::Nvfp4 ? 16 : 32;
4520
4607
  int expected_bits = mode == QuantizationMode::Mxfp8 ? 8 : 4;
@@ -4532,6 +4619,12 @@ std::vector<array> fp_quantize(
4532
4619
  << bits << ".";
4533
4620
  throw std::invalid_argument(msg.str());
4534
4621
  }
4622
+
4623
+ auto inputs = std::vector<array>{w};
4624
+ if (global_scale.has_value()) {
4625
+ inputs.push_back(global_scale.value());
4626
+ }
4627
+
4535
4628
  auto fallback = [bits = bits, group_size = group_size, s](
4536
4629
  const std::vector<array>& inputs) -> std::vector<array> {
4537
4630
  auto& w = inputs[0];
@@ -4543,8 +4636,13 @@ std::vector<array> fp_quantize(
4543
4636
  divide(max(abs(wq, s), -1, true, s), array(maxval, w.dtype()), s);
4544
4637
  if (group_size == 16) {
4545
4638
  // convert to e4m3
4639
+ auto scale_encode = inputs.size() > 1
4640
+ ? divide(array(448.0f * 6.0f, float32), inputs[1], s)
4641
+ : array(1.0f, float32);
4642
+ scales = multiply(scales, scale_encode, s);
4546
4643
  scales = to_fp8(scales, s);
4547
- wq = divide(wq, from_fp8(scales, w.dtype(), s), s);
4644
+ wq = multiply(
4645
+ divide(wq, from_fp8(scales, w.dtype(), s), s), scale_encode, s);
4548
4646
  } else {
4549
4647
  // convert to e8m0
4550
4648
  auto z = array(0, scales.dtype());
@@ -4600,9 +4698,9 @@ std::vector<array> fp_quantize(
4600
4698
  {uint32, uint8},
4601
4699
  std::make_shared<fast::Quantize>(
4602
4700
  s, fallback, group_size, bits, mode, false),
4603
- {w});
4701
+ inputs);
4604
4702
  }
4605
- return fallback({w});
4703
+ return fallback(inputs);
4606
4704
  }
4607
4705
 
4608
4706
  std::vector<array> quantize(
@@ -4610,6 +4708,7 @@ std::vector<array> quantize(
4610
4708
  std::optional<int> group_size_ /* = std::nullopt */,
4611
4709
  std::optional<int> bits_ /* = std::nullopt */,
4612
4710
  const std::string& mode /* = "affine" */,
4711
+ const std::optional<array>& global_scale /* = std::nullopt */,
4613
4712
  StreamOrDevice s /* = {} */) {
4614
4713
  auto qmode = string_to_quantization_mode(mode, "quantize");
4615
4714
  auto [group_size, bits] =
@@ -4636,11 +4735,17 @@ std::vector<array> quantize(
4636
4735
  << " matrix has shape " << w.shape();
4637
4736
  throw std::invalid_argument(msg.str());
4638
4737
  }
4639
-
4738
+ if (to_stream(s).device == Device::gpu && metal::is_available() &&
4739
+ global_scale.has_value()) {
4740
+ std::ostringstream msg;
4741
+ msg << "[quantize] Global scale is not supported on the Metal backend.";
4742
+ throw std::invalid_argument(msg.str());
4743
+ }
4744
+ validate_global_scale("quantize", qmode, global_scale);
4640
4745
  if (qmode == QuantizationMode::Affine) {
4641
4746
  return affine_quantize(w, group_size, bits, s);
4642
4747
  } else {
4643
- return fp_quantize(w, group_size, bits, qmode, to_stream(s));
4748
+ return fp_quantize(w, group_size, bits, qmode, global_scale, to_stream(s));
4644
4749
  }
4645
4750
  }
4646
4751
 
@@ -4745,6 +4850,7 @@ array fp_dequantize(
4745
4850
  int bits,
4746
4851
  Dtype out_type,
4747
4852
  QuantizationMode mode,
4853
+ const std::optional<array>& global_scale /* = std::nullopt */,
4748
4854
  Stream s) {
4749
4855
  int expected_gs = mode == QuantizationMode::Nvfp4 ? 16 : 32;
4750
4856
  int expected_bits = mode == QuantizationMode::Mxfp8 ? 8 : 4;
@@ -4789,6 +4895,11 @@ array fp_dequantize(
4789
4895
  throw std::invalid_argument(msg.str());
4790
4896
  }
4791
4897
 
4898
+ auto inputs = std::vector<array>{w, scales};
4899
+ if (global_scale.has_value()) {
4900
+ inputs.push_back(global_scale.value());
4901
+ }
4902
+
4792
4903
  auto fallback =
4793
4904
  [wshape = std::move(wshape),
4794
4905
  sshape = std::move(sshape),
@@ -4831,13 +4942,17 @@ array fp_dequantize(
4831
4942
  out = reshape(out, {-1, group_size}, s);
4832
4943
  scales = reshape(scales, {-1, 1}, s);
4833
4944
  if (group_size == 16) {
4834
- scales = from_fp8(scales, out_type, s);
4945
+ array inv_scale_enc = inputs.size() > 2
4946
+ ? divide(inputs[2], array(448.0f * 6.0f, out_type), s)
4947
+ : array(1.0f, out_type);
4948
+ scales = multiply(from_fp8(scales, out_type, s), inv_scale_enc, s);
4835
4949
  } else {
4836
4950
  scales = subtract(astype(scales, out_type, s), array(127, out_type), s);
4837
4951
  scales = power(array(2.0f, out_type), scales, s);
4838
4952
  }
4839
4953
  return {reshape(multiply(out, scales, s), wshape, s)};
4840
4954
  };
4955
+
4841
4956
  if (s.device == Device::gpu) {
4842
4957
  auto out_shape = w.shape();
4843
4958
  out_shape.back() = out_size;
@@ -4846,9 +4961,9 @@ array fp_dequantize(
4846
4961
  out_type,
4847
4962
  std::make_shared<fast::Quantize>(
4848
4963
  s, fallback, group_size, bits, mode, true),
4849
- {w, scales});
4964
+ inputs);
4850
4965
  }
4851
- return fallback({w, scales})[0];
4966
+ return fallback(inputs)[0];
4852
4967
  }
4853
4968
 
4854
4969
  array dequantize(
@@ -4858,6 +4973,7 @@ array dequantize(
4858
4973
  std::optional<int> group_size_ /* = std::nullopt */,
4859
4974
  std::optional<int> bits_ /* = std::nullopt */,
4860
4975
  const std::string& mode /* = "affine" */,
4976
+ const std::optional<array>& global_scale /* = std::nullopt */,
4861
4977
  std::optional<Dtype> dtype /* = std::nullopt */,
4862
4978
  StreamOrDevice s /* = {} */) {
4863
4979
  auto [out_type, qmode] =
@@ -4884,6 +5000,14 @@ array dequantize(
4884
5000
  << "but it has only " << w.ndim() << ".";
4885
5001
  throw std::invalid_argument(msg.str());
4886
5002
  }
5003
+ if (global_scale.has_value()) {
5004
+ if (to_stream(s).device == Device::gpu && metal::is_available()) {
5005
+ std::ostringstream msg;
5006
+ msg << "[dequantize] Global scale is not supported on the Metal backend.";
5007
+ throw std::invalid_argument(msg.str());
5008
+ }
5009
+ }
5010
+ validate_global_scale("dequantize", qmode, global_scale);
4887
5011
 
4888
5012
  if (qmode == QuantizationMode::Affine) {
4889
5013
  return astype(
@@ -4892,7 +5016,14 @@ array dequantize(
4892
5016
  s);
4893
5017
  } else {
4894
5018
  return fp_dequantize(
4895
- w, scales, group_size, bits, out_type, qmode, to_stream(s));
5019
+ w,
5020
+ scales,
5021
+ group_size,
5022
+ bits,
5023
+ out_type,
5024
+ qmode,
5025
+ global_scale,
5026
+ to_stream(s));
4896
5027
  }
4897
5028
  }
4898
5029
 
@@ -6091,4 +6222,4 @@ array contiguous(
6091
6222
  {a});
6092
6223
  }
6093
6224
 
6094
- } // namespace mlx::core
6225
+ } // namespace mlx::core
@@ -666,6 +666,12 @@ min(const array& a,
666
666
  MLX_API array
667
667
  min(const array& a, int axis, bool keepdims = false, StreamOrDevice s = {});
668
668
 
669
+ /** Returns the Hanning window of size M. */
670
+ MLX_API array hanning(int M, StreamOrDevice s = {});
671
+
672
+ /** Returns the Hamming window of size M. */
673
+ MLX_API array hamming(int M, StreamOrDevice s = {});
674
+
669
675
  /** Returns the index of the minimum value in the array. */
670
676
  MLX_API array argmin(const array& a, bool keepdims, StreamOrDevice s = {});
671
677
  inline array argmin(const array& a, StreamOrDevice s = {}) {
@@ -1391,6 +1397,7 @@ MLX_API std::vector<array> quantize(
1391
1397
  std::optional<int> group_size = std::nullopt,
1392
1398
  std::optional<int> bits = std::nullopt,
1393
1399
  const std::string& mode = "affine",
1400
+ const std::optional<array>& global_scale = std::nullopt,
1394
1401
  StreamOrDevice s = {});
1395
1402
 
1396
1403
  /** Dequantize a matrix produced by quantize() */
@@ -1401,17 +1408,20 @@ MLX_API array dequantize(
1401
1408
  std::optional<int> group_size = std::nullopt,
1402
1409
  std::optional<int> bits = std::nullopt,
1403
1410
  const std::string& mode = "affine",
1411
+ const std::optional<array>& global_scale = std::nullopt,
1404
1412
  std::optional<Dtype> dtype = std::nullopt,
1405
1413
  StreamOrDevice s = {});
1406
1414
 
1407
1415
  MLX_API array qqmm(
1408
1416
  array x, // input activations
1409
1417
  array w, // maybe quantized weights
1410
- std::optional<array> w_scales = std::nullopt, // optional scales if w is
1411
- // quantized
1418
+ const std::optional<array> w_scales = std::nullopt, // optional scales if w
1419
+ // is quantized
1412
1420
  std::optional<int> group_size = std::nullopt,
1413
1421
  std::optional<int> bits = std::nullopt,
1414
1422
  const std::string& mode = "nvfp4",
1423
+ const std::optional<array> global_scale_x = std::nullopt,
1424
+ const std::optional<array> global_scale_w = std::nullopt,
1415
1425
  StreamOrDevice s = {});
1416
1426
 
1417
1427
  /** Convert an E4M3 float8 to the given floating point dtype. */
@@ -3424,6 +3424,7 @@ std::vector<array> QuantizedMatmul::vjp(
3424
3424
  group_size_,
3425
3425
  bits_,
3426
3426
  quantization_mode_to_string(mode_),
3427
+ {}, // placeholder for amax
3427
3428
  std::nullopt,
3428
3429
  stream());
3429
3430
  wq = unflatten(wq, -1, {-1, group_size_}, stream());
@@ -3484,14 +3485,14 @@ std::vector<Shape> QQMatmul::output_shapes(const std::vector<array>& inputs) {
3484
3485
  }
3485
3486
 
3486
3487
  std::vector<array> QQMatmul::vjp(
3487
- const std::vector<array>& primals, // non quantized x, non quantized w
3488
+ const std::vector<array>& primals, // non quantized x, non quantized w, if
3489
+ // nvfp4 global_scale_x, global_scale_w
3488
3490
  const std::vector<array>& cotangents, // non quantized upstream grads
3489
3491
  const std::vector<int>& argnums,
3490
3492
  const std::vector<array>&) {
3491
- if (primals.size() != 2) {
3492
- throw std::runtime_error(
3493
- "[QQMatmul::vjp] Expected exactly 2 non-quantized primal inputs (x, w).");
3494
- }
3493
+ bool is_nvfp4 = mode_ == QuantizationMode::Nvfp4;
3494
+ assert(primals.size() == 2 || (is_nvfp4 && primals.size() == 4));
3495
+
3495
3496
  std::vector<array> vjps;
3496
3497
  auto& cotan = cotangents[0];
3497
3498
  auto& s = stream();
@@ -3499,6 +3500,15 @@ std::vector<array> QQMatmul::vjp(
3499
3500
  // primal[0] -- non quantized activations (M, K)
3500
3501
  // cotan -- non quantized grads (M, N)
3501
3502
  auto qmode = quantization_mode_to_string(mode_);
3503
+ std::optional<array> cotan_amax = (primals.size() == 4)
3504
+ ? std::make_optional(astype(max(abs(cotan, s), s), float32, s))
3505
+ : std::nullopt;
3506
+
3507
+ auto get_primal_scale = [&](int idx) {
3508
+ return (primals.size() == 4) ? std::make_optional(primals[idx])
3509
+ : std::nullopt;
3510
+ };
3511
+
3502
3512
  for (auto arg : argnums) {
3503
3513
  if (arg == 0) { // gradient wrt to x
3504
3514
  // We transpose weights -> quantize along N
@@ -3509,6 +3519,8 @@ std::vector<array> QQMatmul::vjp(
3509
3519
  group_size_,
3510
3520
  bits_,
3511
3521
  qmode,
3522
+ cotan_amax,
3523
+ get_primal_scale(3), // global_scale_w (for w.T)
3512
3524
  s));
3513
3525
  } else if (arg == 1) { // gradient wrt to weights
3514
3526
  vjps.push_back(qqmm(
@@ -3518,7 +3530,11 @@ std::vector<array> QQMatmul::vjp(
3518
3530
  group_size_,
3519
3531
  bits_,
3520
3532
  qmode,
3533
+ cotan_amax,
3534
+ get_primal_scale(2), // global_scale_x (for x.T)
3521
3535
  s));
3536
+ } else {
3537
+ vjps.push_back(zeros_like(primals[arg], s));
3522
3538
  }
3523
3539
  }
3524
3540
  return vjps;
@@ -3643,6 +3659,7 @@ std::vector<array> GatherQMM::vjp(
3643
3659
  bits_,
3644
3660
  quantization_mode_to_string(mode_),
3645
3661
  std::nullopt,
3662
+ std::nullopt, // amax placeholder
3646
3663
  stream()),
3647
3664
  -1,
3648
3665
  {-1, group_size_},
@@ -26,6 +26,10 @@ Stream get_stream(int index) {
26
26
  return scheduler::scheduler().get_stream(index);
27
27
  }
28
28
 
29
+ std::vector<Stream> get_streams() {
30
+ return scheduler::scheduler().get_streams();
31
+ }
32
+
29
33
  Stream new_stream(Device d) {
30
34
  if (!gpu::is_available() && d == Device::gpu) {
31
35
  throw std::invalid_argument(
@@ -99,6 +99,9 @@ class Scheduler {
99
99
  Stream get_stream(int index) const {
100
100
  return streams_.at(index);
101
101
  }
102
+ std::vector<Stream> get_streams() const {
103
+ return streams_;
104
+ }
102
105
 
103
106
  void set_default_stream(const Stream& s) {
104
107
  default_streams_.at(s.device.type) = s;