mlx 0.30.7.3 → 0.30.7.6
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/ext/mlx/extconf.rb +267 -8
- data/ext/mlx/native.cpp +104 -56
- data/ext/mlx-onnx/native.cpp +1402 -0
- data/ext/mlx-onnx/native.hpp +19 -0
- data/lib/mlx/core.rb +342 -117
- data/lib/mlx/nn/base.rb +4 -0
- data/lib/mlx/nn/layers/linear.rb +2 -3
- data/lib/mlx/onnx.rb +250 -0
- data/lib/mlx/version.rb +1 -1
- data/lib/mlx-onnx/webgpu_harness.rb +289 -0
- data/{mlx → submodules/mlx}/mlx/backend/cuda/cublas_utils.cpp +0 -7
- data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/cublas_gemm.cpp +10 -2
- data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/cublas_qqmm.cpp +97 -46
- data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/cublas_qqmm.h +25 -13
- data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/fp_quantize.cu +101 -38
- data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/no_qqmm_impl.cpp +1 -2
- data/submodules/mlx/mlx/backend/cuda/quantized/qqmm.cpp +193 -0
- data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/qqmm_impl.cpp +15 -8
- data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/qqmm_impl.h +14 -3
- data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/qqmm_utils.cu +36 -0
- data/submodules/mlx/mlx/backend/cuda/quantized/qqmm_utils.h +62 -0
- data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/quantized.cpp +12 -3
- data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/quantized.h +4 -0
- data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/quantized_utils.cuh +1 -1
- data/{mlx → submodules/mlx}/mlx/backend/metal/device.cpp +4 -0
- data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/conv.metal +3 -2
- data/{mlx → submodules/mlx}/mlx/export.cpp +21 -6
- data/{mlx → submodules/mlx}/mlx/ops.cpp +144 -13
- data/{mlx → submodules/mlx}/mlx/ops.h +12 -2
- data/{mlx → submodules/mlx}/mlx/primitives.cpp +22 -5
- data/{mlx → submodules/mlx}/mlx/scheduler.cpp +4 -0
- data/{mlx → submodules/mlx}/mlx/scheduler.h +3 -0
- data/{mlx → submodules/mlx}/mlx/stream.h +5 -0
- data/submodules/mlx-onnx/CMakeLists.txt +159 -0
- data/submodules/mlx-onnx/LICENSE +21 -0
- data/submodules/mlx-onnx/include/mlx/ir.hpp +88 -0
- data/submodules/mlx-onnx/src/api.cpp +81 -0
- data/submodules/mlx-onnx/src/compat.cpp +111 -0
- data/submodules/mlx-onnx/src/detail.hpp +69 -0
- data/submodules/mlx-onnx/src/export.cpp +653 -0
- data/submodules/mlx-onnx/src/io.cpp +61 -0
- data/submodules/mlx-onnx/src/json.hpp +25 -0
- data/submodules/mlx-onnx/src/lowering.cpp +6346 -0
- data/submodules/mlx-onnx/src/mappings.cpp +201 -0
- data/submodules/mlx-onnx/src/mappings.hpp +16 -0
- data/submodules/mlx-onnx/src/onnx.cpp +1029 -0
- data/submodules/mlx-onnx/src/shared.cpp +206 -0
- metadata +609 -563
- data/mlx/mlx/backend/cuda/quantized/qqmm.cpp +0 -158
- data/mlx/mlx/backend/cuda/quantized/qqmm_utils.h +0 -30
- /data/{mlx → submodules/mlx}/CMakeLists.txt +0 -0
- /data/{mlx → submodules/mlx}/cmake/FindCUDNN.cmake +0 -0
- /data/{mlx → submodules/mlx}/cmake/FindNCCL.cmake +0 -0
- /data/{mlx → submodules/mlx}/cmake/Findnvpl.cmake +0 -0
- /data/{mlx → submodules/mlx}/cmake/extension.cmake +0 -0
- /data/{mlx → submodules/mlx}/mlx/3rdparty/.clang-format +0 -0
- /data/{mlx → submodules/mlx}/mlx/3rdparty/pocketfft.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/CMakeLists.txt +0 -0
- /data/{mlx → submodules/mlx}/mlx/allocator.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/api.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/array.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/array.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/common/CMakeLists.txt +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/common/binary.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/common/broadcasting.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/common/broadcasting.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/common/buffer_cache.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/common/common.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/common/compiled.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/common/compiled.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/common/copy.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/common/hadamard.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/common/load.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/common/matmul.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/common/reduce.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/common/reduce.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/common/slicing.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/common/slicing.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/common/ternary.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/common/unary.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/common/utils.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/common/utils.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/CMakeLists.txt +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/arange.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/arg_reduce.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/binary.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/binary.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/binary_ops.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/binary_two.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/cholesky.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/compiled.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/compiled_preamble.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/conv.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/copy.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/copy.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/device_info.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/device_info.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/distributed.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/eig.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/eigh.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/encoder.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/encoder.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/eval.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/eval.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/fft.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/gemm.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/gemms/bnns.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/gemms/cblas.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/gemms/simd_bf16.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/gemms/simd_fp16.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/gemms/simd_gemm.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/hadamard.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/indexing.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/inverse.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/jit_compiler.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/jit_compiler.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/lapack.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/logsumexp.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/luf.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/make_compiled_preamble.ps1 +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/make_compiled_preamble.sh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/masked_mm.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/matmul.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/primitives.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/qrf.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/quantized.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/reduce.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/scan.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/select.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/simd/accelerate_fp16_simd.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/simd/accelerate_simd.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/simd/base_simd.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/simd/math.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/simd/neon_fp16_simd.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/simd/simd.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/simd/type.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/slicing.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/softmax.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/sort.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/svd.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/ternary.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/threefry.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/threefry.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/unary.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/unary.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/unary_ops.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/CMakeLists.txt +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/allocator.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/allocator.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/arange.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/arg_reduce.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/bin2h.cmake +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/CMakeLists.txt +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/add.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/arctan2.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/binary.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/bitwise_binary.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/divide.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/equal.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/greater.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/greater_equal.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/less.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/less_equal.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/log_add_exp.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/logical_and.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/logical_or.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/maximum.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/minimum.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/multiply.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/not_equal.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/power.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/remainder.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/subtract.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary_two.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/compiled.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/conv/conv.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/conv/gemm_conv.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/conv/gemm_grouped_conv.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/conv.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/copy/copy.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/copy/copy_contiguous.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/copy/copy_general.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/copy/copy_general_dynamic.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/copy/copy_general_input.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/copy.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/cublas_utils.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/cuda.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/cuda_utils.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/cudnn_utils.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/cudnn_utils.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/custom_kernel.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/cutlass_utils.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/delayload.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/atomic_ops.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/binary_ops.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/cast_op.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/complex.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/config.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/fp16_math.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/gather.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/gather_axis.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/indexing.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/scatter.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/scatter_axis.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/scatter_ops.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/ternary_ops.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/unary_ops.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/utils.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/device.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/device.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/device_info.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/distributed.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/eval.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/event.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/event.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/fence.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/cublas_gemm.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/cublas_gemm_batched_12_0.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/cublas_gemm_batched_12_9.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/gemv.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/gemv.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/grouped_gemm.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/grouped_gemm_unaligned.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/indexing.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/jit_module.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/jit_module.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/kernel_utils.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/kernel_utils.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/layer_norm.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/load.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/logsumexp.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/lru_cache.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/matmul.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/no_cuda.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/primitives.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/affine_quantize.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/convert_fp8.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/cuda_fp4.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/mxfp8_quantize.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/nvfp4_quantize.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/qmv.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/qmv.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/quantized_utils.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/random.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce/all_reduce.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce/col_reduce.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce/init_reduce.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce/reduce.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce/reduce_ops.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce/reduce_utils.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce/row_reduce.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/rms_norm.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/rope.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/scaled_dot_product_attention.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/scaled_dot_product_attention.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/scan.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/slicing.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/softmax.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/sort.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/steel/defines.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/steel/gemm.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/steel/mma.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/steel/tiles.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/steel/utils.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/ternary.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/CMakeLists.txt +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/abs.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/arccos.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/arccosh.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/arcsin.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/arcsinh.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/arctan.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/arctanh.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/bitwise_invert.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/ceil.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/conjugate.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/cos.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/cosh.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/erf.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/erf_inv.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/exp.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/expm1.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/floor.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/imag.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/log.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/log1p.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/logical_not.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/negative.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/real.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/round.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/sigmoid.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/sign.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/sin.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/sinh.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/sqrt.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/square.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/tan.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/tanh.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/unary.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/utils.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/utils.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/vector_types.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/worker.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/worker.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/gpu/CMakeLists.txt +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/gpu/copy.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/gpu/copy.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/gpu/device_info.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/gpu/eval.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/gpu/primitives.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/gpu/slicing.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/gpu/slicing.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/CMakeLists.txt +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/allocator.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/allocator.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/binary.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/binary.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/compiled.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/conv.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/copy.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/custom_kernel.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/device.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/device_info.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/distributed.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/eval.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/event.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/fence.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/fft.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/hadamard.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/indexing.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/jit/includes.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/jit/indexing.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/jit_kernels.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/CMakeLists.txt +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/arange.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/arange.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/arg_reduce.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/atomic.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/bf16.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/bf16_math.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/binary.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/binary.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/binary_ops.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/binary_two.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/binary_two.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/cexpf.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/complex.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/copy.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/copy.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/defines.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/erf.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/expm1f.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fence.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fft/radix.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fft/readwrite.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fft.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fft.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fp4.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fp8.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fp_quantized.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fp_quantized.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fp_quantized_nax.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fp_quantized_nax.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/gemv.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/gemv_masked.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/gemv_masked.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/hadamard.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/indexing/gather.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/indexing/gather_axis.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/indexing/gather_front.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/indexing/indexing.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/indexing/masked_scatter.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/indexing/scatter.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/indexing/scatter_axis.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/layer_norm.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/logging.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/logsumexp.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/logsumexp.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/quantized.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/quantized.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/quantized_nax.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/quantized_nax.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/quantized_utils.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/random.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduce.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduce.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduce_utils.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduction/ops.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduction/reduce_all.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduction/reduce_col.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduction/reduce_init.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduction/reduce_row.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/rms_norm.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/rope.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/scaled_dot_product_attention.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/scan.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/scan.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/sdpa_vector.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/softmax.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/softmax.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/sort.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/sort.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/attn.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/loader.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/mma.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/nax.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/params.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/transforms.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/conv.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/loader.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/params.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/defines.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/gemm.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/loader.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/mma.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/nax.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/params.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/transforms.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/utils/integral_constant.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/utils/type_traits.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/utils.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/ternary.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/ternary.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/ternary_ops.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/unary.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/unary.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/unary_ops.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/utils.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/logsumexp.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/make_compiled_preamble.sh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/matmul.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/matmul.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/metal.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/metal.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/no_metal.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/nojit_kernels.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/normalization.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/primitives.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/quantized.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/reduce.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/reduce.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/resident.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/resident.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/rope.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/scaled_dot_product_attention.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/scan.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/scan.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/slicing.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/softmax.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/sort.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/ternary.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/ternary.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/unary.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/unary.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/utils.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/utils.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/no_cpu/CMakeLists.txt +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/no_cpu/compiled.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/no_cpu/device_info.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/no_cpu/primitives.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/CMakeLists.txt +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/allocator.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/apple_memory.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/device_info.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/eval.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/event.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/fence.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/linux_memory.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/primitives.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/compile.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/compile.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/compile_impl.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/device.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/device.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/CMakeLists.txt +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/distributed.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/distributed.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/distributed_impl.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/CMakeLists.txt +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/jaccl.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/jaccl.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/mesh.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/mesh.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/no_jaccl.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/ring.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/ring.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/utils.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/utils.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/mpi/CMakeLists.txt +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/mpi/mpi.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/mpi/mpi.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/mpi/mpi_declarations.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/mpi/no_mpi.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/nccl/CMakeLists.txt +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/nccl/nccl.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/nccl/nccl.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/nccl/nccl_stub/CMakeLists.txt +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/nccl/nccl_stub/nccl_stubs.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/nccl/no_nccl.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/ops.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/ops.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/primitives.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/primitives.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/reduction_ops.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/ring/CMakeLists.txt +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/ring/no_ring.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/ring/ring.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/ring/ring.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/utils.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/utils.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/dtype.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/dtype.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/dtype_utils.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/dtype_utils.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/einsum.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/einsum.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/event.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/export.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/export_impl.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/fast.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/fast.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/fast_primitives.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/fence.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/fft.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/fft.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/graph_utils.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/graph_utils.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/io/CMakeLists.txt +0 -0
- /data/{mlx → submodules/mlx}/mlx/io/gguf.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/io/gguf.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/io/gguf_quants.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/io/load.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/io/load.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/io/no_gguf.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/io/no_safetensors.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/io/safetensors.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/io.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/linalg.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/linalg.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/memory.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/mlx.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/primitives.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/random.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/random.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/small_vector.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/threadpool.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/transforms.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/transforms.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/transforms_impl.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/types/bf16.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/types/complex.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/types/fp16.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/types/half_types.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/types/limits.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/utils.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/utils.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/version.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/version.h +0 -0
- /data/{mlx → submodules/mlx}/mlx.pc.in +0 -0
|
@@ -13,39 +13,26 @@ namespace mlx::core {
|
|
|
13
13
|
|
|
14
14
|
namespace {
|
|
15
15
|
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
cudaDataType_t
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
} else if (mode == "nvfp4") {
|
|
22
|
-
return CUDA_R_8F_UE4M3;
|
|
23
|
-
} else {
|
|
24
|
-
throw std::runtime_error(
|
|
25
|
-
fmt::format("Unsupported quantization mode in CublasQQMM: {}.", mode));
|
|
26
|
-
}
|
|
27
|
-
}
|
|
28
|
-
|
|
29
|
-
cudaDataType_t qmode_to_cublas_dtype(std::string mode) {
|
|
30
|
-
if (mode == "mxfp8") {
|
|
31
|
-
return CUDA_R_8F_E4M3;
|
|
32
|
-
} else if (mode == "nvfp4") {
|
|
33
|
-
return CUDA_R_4F_E2M1;
|
|
34
|
-
} else {
|
|
35
|
-
throw std::runtime_error(
|
|
36
|
-
fmt::format("Unsupported quantization mode in CublasQQMM: {}.", mode));
|
|
37
|
-
}
|
|
38
|
-
}
|
|
16
|
+
struct QuantModeConfig {
|
|
17
|
+
cudaDataType_t data_type;
|
|
18
|
+
cudaDataType_t scale_dtype;
|
|
19
|
+
cublasLtMatmulMatrixScale_t scale_mode;
|
|
20
|
+
};
|
|
39
21
|
|
|
40
|
-
|
|
22
|
+
QuantModeConfig get_quant_mode_config(const std::string& mode) {
|
|
41
23
|
if (mode == "mxfp8") {
|
|
42
|
-
return
|
|
24
|
+
return {
|
|
25
|
+
CUDA_R_8F_E4M3,
|
|
26
|
+
CUDA_R_8F_UE8M0,
|
|
27
|
+
CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0};
|
|
43
28
|
} else if (mode == "nvfp4") {
|
|
44
|
-
return
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
29
|
+
return {
|
|
30
|
+
CUDA_R_4F_E2M1,
|
|
31
|
+
CUDA_R_8F_UE4M3,
|
|
32
|
+
CUBLASLT_MATMUL_MATRIX_SCALE_VEC16_UE4M3};
|
|
48
33
|
}
|
|
34
|
+
throw std::runtime_error(
|
|
35
|
+
fmt::format("Unsupported quantization mode in CublasQQMM: {}.", mode));
|
|
49
36
|
}
|
|
50
37
|
|
|
51
38
|
} // namespace
|
|
@@ -64,21 +51,21 @@ CublasQQMM::CublasQQMM(
|
|
|
64
51
|
int64_t a_batch_stride,
|
|
65
52
|
int64_t b_batch_stride,
|
|
66
53
|
Dtype out_dtype,
|
|
67
|
-
std::string qmode) {
|
|
54
|
+
const std::string& qmode) {
|
|
55
|
+
auto config = get_quant_mode_config(qmode);
|
|
56
|
+
|
|
68
57
|
// The compute type must be CUBLAS_COMPUTE_32F.
|
|
69
58
|
// The scale type must be CUDA_R_32F.
|
|
70
59
|
cudaDataType_t scale_type = CUDA_R_32F;
|
|
71
60
|
cublasComputeType_t gemm_compute_type = CUBLAS_COMPUTE_32F;
|
|
72
61
|
cudaDataType_t output_type =
|
|
73
62
|
cublas_utils::dtype_to_cublas_type(out_dtype, "CublasQQMM");
|
|
74
|
-
cudaDataType_t data_type = qmode_to_cublas_dtype(qmode);
|
|
75
|
-
quantization_mode_ = std::string(qmode);
|
|
76
63
|
|
|
77
64
|
init_base(
|
|
78
65
|
device,
|
|
79
66
|
scale_type,
|
|
80
67
|
gemm_compute_type,
|
|
81
|
-
data_type,
|
|
68
|
+
config.data_type,
|
|
82
69
|
output_type,
|
|
83
70
|
a_transposed,
|
|
84
71
|
a_rows,
|
|
@@ -92,8 +79,8 @@ CublasQQMM::CublasQQMM(
|
|
|
92
79
|
a_batch_stride,
|
|
93
80
|
b_batch_stride);
|
|
94
81
|
|
|
95
|
-
a_scale_mode_ =
|
|
96
|
-
b_scale_mode_ =
|
|
82
|
+
a_scale_mode_ = config.scale_mode;
|
|
83
|
+
b_scale_mode_ = config.scale_mode;
|
|
97
84
|
|
|
98
85
|
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
|
|
99
86
|
matmul_desc_,
|
|
@@ -123,7 +110,7 @@ CublasQQMM::CublasQQMM(
|
|
|
123
110
|
int64_t b_batch_stride,
|
|
124
111
|
int64_t c_batch_stride,
|
|
125
112
|
Dtype out_dtype,
|
|
126
|
-
std::string qmode)
|
|
113
|
+
const std::string& qmode)
|
|
127
114
|
: CublasQQMM(
|
|
128
115
|
device,
|
|
129
116
|
a_transposed,
|
|
@@ -158,11 +145,14 @@ void CublasQQMM::run(
|
|
|
158
145
|
const array& b,
|
|
159
146
|
const array& a_scale,
|
|
160
147
|
const array& b_scale,
|
|
161
|
-
|
|
148
|
+
const array& alpha,
|
|
149
|
+
const array& beta) {
|
|
162
150
|
encoder.set_input_array(a);
|
|
163
151
|
encoder.set_input_array(b);
|
|
164
152
|
encoder.set_input_array(a_scale);
|
|
165
153
|
encoder.set_input_array(b_scale);
|
|
154
|
+
encoder.set_input_array(alpha);
|
|
155
|
+
encoder.set_input_array(beta);
|
|
166
156
|
encoder.set_output_array(out);
|
|
167
157
|
|
|
168
158
|
execute(
|
|
@@ -173,19 +163,37 @@ void CublasQQMM::run(
|
|
|
173
163
|
gpu_ptr<void>(a_scale),
|
|
174
164
|
gpu_ptr<void>(b_scale),
|
|
175
165
|
nullptr,
|
|
176
|
-
alpha)
|
|
166
|
+
gpu_ptr<void>(alpha),
|
|
167
|
+
gpu_ptr<void>(beta));
|
|
177
168
|
}
|
|
178
169
|
|
|
179
|
-
void CublasQQMM::
|
|
170
|
+
void CublasQQMM::run(
|
|
171
|
+
cu::CommandEncoder& encoder,
|
|
172
|
+
array& out,
|
|
173
|
+
const array& a,
|
|
174
|
+
const array& b,
|
|
175
|
+
const array& a_scale,
|
|
176
|
+
const array& b_scale) {
|
|
177
|
+
encoder.set_input_array(a);
|
|
178
|
+
encoder.set_input_array(b);
|
|
179
|
+
encoder.set_input_array(a_scale);
|
|
180
|
+
encoder.set_input_array(b_scale);
|
|
181
|
+
encoder.set_output_array(out);
|
|
182
|
+
|
|
183
|
+
execute(
|
|
184
|
+
encoder,
|
|
185
|
+
gpu_ptr<void>(out),
|
|
186
|
+
gpu_ptr<void>(a),
|
|
187
|
+
gpu_ptr<void>(b),
|
|
188
|
+
gpu_ptr<void>(a_scale),
|
|
189
|
+
gpu_ptr<void>(b_scale),
|
|
190
|
+
nullptr);
|
|
191
|
+
}
|
|
192
|
+
|
|
193
|
+
void CublasQQMM::set_scales_ptrs(
|
|
180
194
|
cu::CommandEncoder& encoder,
|
|
181
|
-
void* out,
|
|
182
|
-
const void* a,
|
|
183
|
-
const void* b,
|
|
184
195
|
const void* a_scale,
|
|
185
|
-
const void* b_scale
|
|
186
|
-
const void* c,
|
|
187
|
-
float alpha /* = 1 */,
|
|
188
|
-
float beta /* = 0 */) {
|
|
196
|
+
const void* b_scale) {
|
|
189
197
|
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
|
|
190
198
|
matmul_desc_,
|
|
191
199
|
CUBLASLT_MATMUL_DESC_A_SCALE_POINTER,
|
|
@@ -196,6 +204,49 @@ void CublasQQMM::execute(
|
|
|
196
204
|
CUBLASLT_MATMUL_DESC_B_SCALE_POINTER,
|
|
197
205
|
&a_scale,
|
|
198
206
|
sizeof(a_scale)));
|
|
207
|
+
}
|
|
208
|
+
|
|
209
|
+
void CublasQQMM::execute(
|
|
210
|
+
cu::CommandEncoder& encoder,
|
|
211
|
+
void* out,
|
|
212
|
+
const void* a,
|
|
213
|
+
const void* b,
|
|
214
|
+
const void* a_scale,
|
|
215
|
+
const void* b_scale,
|
|
216
|
+
const void* c,
|
|
217
|
+
const void* alpha,
|
|
218
|
+
const void* beta) {
|
|
219
|
+
set_scales_ptrs(encoder, a_scale, b_scale);
|
|
220
|
+
// alpha and beta are both should be device pointers for nvfp4
|
|
221
|
+
// by default cublas uses host pointers
|
|
222
|
+
// https://docs.nvidia.com/cuda/cublas/#cublasltpointermode-t
|
|
223
|
+
cublasLtPointerMode_t pointer_mode = CUBLASLT_POINTER_MODE_DEVICE;
|
|
224
|
+
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
|
|
225
|
+
matmul_desc_,
|
|
226
|
+
CUBLASLT_MATMUL_DESC_POINTER_MODE,
|
|
227
|
+
&pointer_mode,
|
|
228
|
+
sizeof(pointer_mode)));
|
|
229
|
+
execute_matmul(encoder, out, a, b, c, alpha, beta);
|
|
230
|
+
}
|
|
231
|
+
|
|
232
|
+
void CublasQQMM::execute(
|
|
233
|
+
cu::CommandEncoder& encoder,
|
|
234
|
+
void* out,
|
|
235
|
+
const void* a,
|
|
236
|
+
const void* b,
|
|
237
|
+
const void* a_scale,
|
|
238
|
+
const void* b_scale,
|
|
239
|
+
const void* c,
|
|
240
|
+
const float alpha /* = 1 */,
|
|
241
|
+
const float beta /* = 0 */) {
|
|
242
|
+
set_scales_ptrs(encoder, a_scale, b_scale);
|
|
243
|
+
// alpha and beta are both should be host pointers
|
|
244
|
+
cublasLtPointerMode_t pointer_mode = CUBLASLT_POINTER_MODE_HOST;
|
|
245
|
+
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
|
|
246
|
+
matmul_desc_,
|
|
247
|
+
CUBLASLT_MATMUL_DESC_POINTER_MODE,
|
|
248
|
+
&pointer_mode,
|
|
249
|
+
sizeof(pointer_mode)));
|
|
199
250
|
|
|
200
251
|
const void* alpha_ptr = α
|
|
201
252
|
const void* beta_ptr = β
|
|
@@ -25,7 +25,7 @@ class CublasQQMM : public CublasMatmulBase {
|
|
|
25
25
|
int64_t a_batch_stride,
|
|
26
26
|
int64_t b_batch_stride,
|
|
27
27
|
Dtype out_dtype,
|
|
28
|
-
std::string quantization_mode);
|
|
28
|
+
const std::string& quantization_mode);
|
|
29
29
|
|
|
30
30
|
CublasQQMM(
|
|
31
31
|
cu::Device& device,
|
|
@@ -43,7 +43,7 @@ class CublasQQMM : public CublasMatmulBase {
|
|
|
43
43
|
int64_t b_batch_stride,
|
|
44
44
|
int64_t c_batch_stride,
|
|
45
45
|
Dtype out_dtype,
|
|
46
|
-
std::string quantization_mode);
|
|
46
|
+
const std::string& quantization_mode);
|
|
47
47
|
|
|
48
48
|
void run(
|
|
49
49
|
cu::CommandEncoder& encoder,
|
|
@@ -52,20 +52,33 @@ class CublasQQMM : public CublasMatmulBase {
|
|
|
52
52
|
const array& b,
|
|
53
53
|
const array& a_scale,
|
|
54
54
|
const array& b_scale,
|
|
55
|
-
|
|
55
|
+
const array& alpha,
|
|
56
|
+
const array& beta);
|
|
56
57
|
|
|
57
|
-
|
|
58
|
-
void run_batched(
|
|
58
|
+
void run(
|
|
59
59
|
cu::CommandEncoder& encoder,
|
|
60
60
|
array& out,
|
|
61
61
|
const array& a,
|
|
62
62
|
const array& b,
|
|
63
63
|
const array& a_scale,
|
|
64
|
-
const array& b_scale
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
64
|
+
const array& b_scale);
|
|
65
|
+
|
|
66
|
+
private:
|
|
67
|
+
void set_scales_ptrs(
|
|
68
|
+
cu::CommandEncoder& encoder,
|
|
69
|
+
const void* a_scale,
|
|
70
|
+
const void* b_scale);
|
|
71
|
+
|
|
72
|
+
void execute(
|
|
73
|
+
cu::CommandEncoder& encoder,
|
|
74
|
+
void* out,
|
|
75
|
+
const void* a,
|
|
76
|
+
const void* b,
|
|
77
|
+
const void* a_scale,
|
|
78
|
+
const void* b_scale,
|
|
79
|
+
const void* c,
|
|
80
|
+
const void* alpha,
|
|
81
|
+
const void* beta);
|
|
69
82
|
|
|
70
83
|
void execute(
|
|
71
84
|
cu::CommandEncoder& encoder,
|
|
@@ -75,10 +88,9 @@ class CublasQQMM : public CublasMatmulBase {
|
|
|
75
88
|
const void* a_scale,
|
|
76
89
|
const void* b_scale,
|
|
77
90
|
const void* c,
|
|
78
|
-
float alpha = 1,
|
|
79
|
-
float beta = 0);
|
|
91
|
+
const float alpha = 1.0f,
|
|
92
|
+
const float beta = 0.0f);
|
|
80
93
|
|
|
81
|
-
std::string quantization_mode_;
|
|
82
94
|
cublasLtMatmulMatrixScale_t a_scale_mode_;
|
|
83
95
|
cublasLtMatmulMatrixScale_t b_scale_mode_;
|
|
84
96
|
cublasLtMatmulMatrixScale_t c_scale_mode_;
|
|
@@ -11,6 +11,11 @@
|
|
|
11
11
|
|
|
12
12
|
#include <cooperative_groups.h>
|
|
13
13
|
#include <cooperative_groups/reduce.h>
|
|
14
|
+
#include <cuda_fp4.h>
|
|
15
|
+
#include <cuda_fp8.h>
|
|
16
|
+
|
|
17
|
+
constexpr float F8E4M3_MAX = 448.0f;
|
|
18
|
+
constexpr float F4E2M1_MAX = 6.0f;
|
|
14
19
|
|
|
15
20
|
namespace mlx::core {
|
|
16
21
|
namespace cu {
|
|
@@ -29,7 +34,16 @@ struct Dequantize {
|
|
|
29
34
|
namespace cg = cooperative_groups;
|
|
30
35
|
|
|
31
36
|
template <typename T, int group_size, int bits, bool use_mx_scale, bool USE_SR>
|
|
32
|
-
__global__ void fp_quantize_dequantize(
|
|
37
|
+
__global__ void fp_quantize_dequantize(
|
|
38
|
+
T* w,
|
|
39
|
+
T* out,
|
|
40
|
+
size_t size,
|
|
41
|
+
float* global_scale = nullptr) {
|
|
42
|
+
const bool use_global_scale = global_scale != nullptr;
|
|
43
|
+
const float scale_enc =
|
|
44
|
+
use_global_scale ? (F8E4M3_MAX * F4E2M1_MAX) / *global_scale : 1.0f;
|
|
45
|
+
const float inv_scale_enc = use_global_scale ? 1.0f / scale_enc : 1.0f;
|
|
46
|
+
|
|
33
47
|
using Tx2 = Vector2_t<T>;
|
|
34
48
|
using Tx4 = Vector4_t<T>;
|
|
35
49
|
uint32_t rbits = 0; // reserved bits for future use
|
|
@@ -48,26 +62,28 @@ __global__ void fp_quantize_dequantize(T* w, T* out, size_t size) {
|
|
|
48
62
|
}
|
|
49
63
|
|
|
50
64
|
auto w_tile = load_vector<group_size, T>(w, thread_idx);
|
|
51
|
-
float
|
|
65
|
+
float scale_dec_b = 0.0f;
|
|
52
66
|
|
|
53
67
|
Tx2 amax_2x = Tx2{0.0f, 0.0f};
|
|
54
68
|
|
|
55
69
|
#pragma unroll
|
|
56
70
|
for (int i = 0; i < group_size; i += 2) {
|
|
57
71
|
auto pair = Tx2{w_tile[i], w_tile[i + 1]};
|
|
58
|
-
|
|
72
|
+
absmax_x2<Tx2>(amax_2x, amax_2x, pair);
|
|
59
73
|
}
|
|
60
74
|
|
|
61
|
-
|
|
75
|
+
scale_dec_b = static_cast<float>(
|
|
62
76
|
max(fabsf(static_cast<float>(amax_2x.x)),
|
|
63
77
|
fabsf(static_cast<float>(amax_2x.y))));
|
|
64
78
|
|
|
65
|
-
|
|
79
|
+
scale_dec_b /= bits == 4 ? F4E2M1_MAX : F8E4M3_MAX;
|
|
80
|
+
scale_dec_b *= scale_enc;
|
|
66
81
|
// Convert to mx scale or nv scale
|
|
67
82
|
using ScaleType =
|
|
68
83
|
std::conditional_t<use_mx_scale, __nv_fp8_e8m0, __nv_fp8_e4m3>;
|
|
69
|
-
auto s = ScaleType(
|
|
70
|
-
|
|
84
|
+
auto s = ScaleType(scale_dec_b);
|
|
85
|
+
float scale_enc_b = scale_enc / float(s);
|
|
86
|
+
float scale_dec = float(s) * inv_scale_enc;
|
|
71
87
|
AlignedVector<T, group_size> w_hat;
|
|
72
88
|
|
|
73
89
|
#pragma unroll
|
|
@@ -76,24 +92,36 @@ __global__ void fp_quantize_dequantize(T* w, T* out, size_t size) {
|
|
|
76
92
|
float4 dq;
|
|
77
93
|
if constexpr (bits == 8) {
|
|
78
94
|
uint32_t quantized_val =
|
|
79
|
-
scale_cvt_Tx4_to_fp8x4<T, USE_SR>(w_Tx4,
|
|
95
|
+
scale_cvt_Tx4_to_fp8x4<T, USE_SR>(w_Tx4, scale_enc_b, rbits);
|
|
80
96
|
dq = dequant_fp8(quantized_val);
|
|
81
97
|
} else {
|
|
82
98
|
uint16_t quantized_val =
|
|
83
|
-
scale_cvt_Tx4_to_fp4x4<T, USE_SR>(w_Tx4,
|
|
99
|
+
scale_cvt_Tx4_to_fp4x4<T, USE_SR>(w_Tx4, scale_enc_b, rbits);
|
|
84
100
|
dq = dequant_fp4(quantized_val);
|
|
85
101
|
}
|
|
86
|
-
w_hat[i * 4] = static_cast<T>(dq.x *
|
|
87
|
-
w_hat[i * 4 + 1] = static_cast<T>(dq.y *
|
|
88
|
-
w_hat[i * 4 + 2] = static_cast<T>(dq.z *
|
|
89
|
-
w_hat[i * 4 + 3] = static_cast<T>(dq.w *
|
|
102
|
+
w_hat[i * 4] = static_cast<T>(dq.x * scale_dec);
|
|
103
|
+
w_hat[i * 4 + 1] = static_cast<T>(dq.y * scale_dec);
|
|
104
|
+
w_hat[i * 4 + 2] = static_cast<T>(dq.z * scale_dec);
|
|
105
|
+
w_hat[i * 4 + 3] = static_cast<T>(dq.w * scale_dec);
|
|
90
106
|
}
|
|
91
107
|
store_vector<group_size>(out, thread_idx, w_hat);
|
|
92
108
|
}
|
|
93
109
|
|
|
94
110
|
template <typename T, int group_size, int bits, bool use_mx_scale, bool USE_SR>
|
|
95
|
-
__global__ void
|
|
96
|
-
|
|
111
|
+
__global__ void fp_quantize_rowwise(
|
|
112
|
+
T* w,
|
|
113
|
+
uint8_t* out,
|
|
114
|
+
uint8_t* scales,
|
|
115
|
+
size_t size,
|
|
116
|
+
float* global_scale = nullptr) {
|
|
117
|
+
// NVFP4 conversion:
|
|
118
|
+
// Global encode scale: (448 × 6) / *global_scale
|
|
119
|
+
// Per-block decode scale: S_dec_b = (block_amax / 6) × S_enc → stored as FP8
|
|
120
|
+
// E4M3 Per-block encode scale: S_enc_b = S_enc / S_dec_b
|
|
121
|
+
const bool use_global_scale = global_scale != nullptr;
|
|
122
|
+
const float scale_enc =
|
|
123
|
+
use_global_scale ? (F8E4M3_MAX * F4E2M1_MAX) / *global_scale : 1.0f;
|
|
124
|
+
|
|
97
125
|
using Tx2 = Vector2_t<T>;
|
|
98
126
|
using Tx4 = Vector4_t<T>;
|
|
99
127
|
uint32_t rbits = 0; // reserved bits for future use
|
|
@@ -112,27 +140,28 @@ fp_quantize_rowwise(T* w, uint8_t* out, uint8_t* scales, size_t size) {
|
|
|
112
140
|
}
|
|
113
141
|
|
|
114
142
|
auto w_tile = load_vector<group_size, T>(w, thread_idx);
|
|
115
|
-
float
|
|
143
|
+
float scale_dec_b = 0.0f;
|
|
116
144
|
|
|
117
145
|
Tx2 amax_2x = Tx2{0.0f, 0.0f};
|
|
118
146
|
|
|
119
147
|
#pragma unroll
|
|
120
148
|
for (int i = 0; i < group_size; i += 2) {
|
|
121
149
|
auto pair = Tx2{w_tile[i], w_tile[i + 1]};
|
|
122
|
-
|
|
150
|
+
absmax_x2<Tx2>(amax_2x, amax_2x, pair);
|
|
123
151
|
}
|
|
124
152
|
|
|
125
|
-
|
|
153
|
+
scale_dec_b = static_cast<float>(
|
|
126
154
|
max(fabsf(static_cast<float>(amax_2x.x)),
|
|
127
155
|
fabsf(static_cast<float>(amax_2x.y))));
|
|
128
156
|
|
|
129
|
-
|
|
157
|
+
scale_dec_b /= bits == 4 ? F4E2M1_MAX : F8E4M3_MAX;
|
|
158
|
+
scale_dec_b *= scale_enc;
|
|
130
159
|
// Convert to mx scale or nv scale
|
|
131
160
|
using ScaleType =
|
|
132
161
|
std::conditional_t<use_mx_scale, __nv_fp8_e8m0, __nv_fp8_e4m3>;
|
|
133
|
-
auto s = ScaleType(
|
|
162
|
+
auto s = ScaleType(scale_dec_b);
|
|
134
163
|
uint8_t q_scale = s.__x;
|
|
135
|
-
|
|
164
|
+
float scale_enc_b = scale_enc / float(s);
|
|
136
165
|
|
|
137
166
|
scales[thread_idx] = q_scale;
|
|
138
167
|
constexpr int elem_per_byte = bits == 8 ? 1 : 2;
|
|
@@ -143,11 +172,11 @@ fp_quantize_rowwise(T* w, uint8_t* out, uint8_t* scales, size_t size) {
|
|
|
143
172
|
Tx4 w_Tx4 = *reinterpret_cast<Tx4*>(&w_tile[i * 4]);
|
|
144
173
|
if constexpr (bits == 8) {
|
|
145
174
|
uint32_t quantized_val =
|
|
146
|
-
scale_cvt_Tx4_to_fp8x4<T, USE_SR>(w_Tx4,
|
|
175
|
+
scale_cvt_Tx4_to_fp8x4<T, USE_SR>(w_Tx4, scale_enc_b, rbits);
|
|
147
176
|
*reinterpret_cast<uint32_t*>(&quantized[i * 4]) = quantized_val;
|
|
148
177
|
} else {
|
|
149
178
|
uint16_t quantized_val =
|
|
150
|
-
scale_cvt_Tx4_to_fp4x4<T, USE_SR>(w_Tx4,
|
|
179
|
+
scale_cvt_Tx4_to_fp4x4<T, USE_SR>(w_Tx4, scale_enc_b, rbits);
|
|
151
180
|
*reinterpret_cast<uint16_t*>(&quantized[i * 2]) = quantized_val;
|
|
152
181
|
}
|
|
153
182
|
}
|
|
@@ -161,11 +190,15 @@ __global__ void fp_quantize_columnwise(
|
|
|
161
190
|
uint8_t* scales,
|
|
162
191
|
size_t size,
|
|
163
192
|
int M,
|
|
164
|
-
int K
|
|
193
|
+
int K,
|
|
194
|
+
float* global_scale = nullptr) {
|
|
165
195
|
// Input: [M, K] with strides [1, M] (M-major)
|
|
166
196
|
// Quantized output: [M, K/elem_per_byte] row-major (K-major)
|
|
167
197
|
// Scales: [M, K/group_size] row-major (K-major)
|
|
168
198
|
// Quantize along K (last dimension, groups of group_size elements)
|
|
199
|
+
const bool use_global_scale = global_scale != nullptr;
|
|
200
|
+
const float scale_enc =
|
|
201
|
+
use_global_scale ? (F8E4M3_MAX * F4E2M1_MAX) / *global_scale : 1.0f;
|
|
169
202
|
|
|
170
203
|
using Tx2 = Vector2_t<T>;
|
|
171
204
|
using Tx4 = Vector4_t<T>;
|
|
@@ -215,16 +248,18 @@ __global__ void fp_quantize_columnwise(
|
|
|
215
248
|
#pragma unroll
|
|
216
249
|
for (int r = 0; r < group_size; r += 2) {
|
|
217
250
|
auto pair = Tx2{thread_data[r], thread_data[r + 1]};
|
|
218
|
-
|
|
251
|
+
absmax_x2<Tx2>(amax_2x, amax_2x, pair);
|
|
219
252
|
}
|
|
220
|
-
float
|
|
253
|
+
float scale_dec_b =
|
|
221
254
|
max(fabsf(static_cast<float>(amax_2x.x)),
|
|
222
255
|
fabsf(static_cast<float>(amax_2x.y)));
|
|
223
|
-
|
|
256
|
+
scale_dec_b /= bits == 4 ? F4E2M1_MAX : F8E4M3_MAX;
|
|
257
|
+
scale_dec_b *= scale_enc;
|
|
258
|
+
// Convert to mx scale or nv scale
|
|
224
259
|
using ScaleType =
|
|
225
260
|
std::conditional_t<use_mx_scale, __nv_fp8_e8m0, __nv_fp8_e4m3>;
|
|
226
|
-
auto s = ScaleType(
|
|
227
|
-
|
|
261
|
+
auto s = ScaleType(scale_dec_b);
|
|
262
|
+
float scale_enc_b = scale_enc / float(s);
|
|
228
263
|
scales_smem[tidx][tidy] = s.__x;
|
|
229
264
|
|
|
230
265
|
int shared_idx = tidx * padded_local_cols + tidy * bytes_per_group;
|
|
@@ -234,12 +269,12 @@ __global__ void fp_quantize_columnwise(
|
|
|
234
269
|
Tx4 w_Tx4 = *reinterpret_cast<Tx4*>(&thread_data[j * 4]);
|
|
235
270
|
if constexpr (bits == 8) {
|
|
236
271
|
uint32_t quantized_val =
|
|
237
|
-
scale_cvt_Tx4_to_fp8x4<T, USE_SR>(w_Tx4,
|
|
272
|
+
scale_cvt_Tx4_to_fp8x4<T, USE_SR>(w_Tx4, scale_enc_b, rbits);
|
|
238
273
|
*reinterpret_cast<uint32_t*>(&quantized_smem[shared_idx + j * 4]) =
|
|
239
274
|
quantized_val;
|
|
240
275
|
} else {
|
|
241
276
|
uint16_t quantized_val =
|
|
242
|
-
scale_cvt_Tx4_to_fp4x4<T, USE_SR>(w_Tx4,
|
|
277
|
+
scale_cvt_Tx4_to_fp4x4<T, USE_SR>(w_Tx4, scale_enc_b, rbits);
|
|
243
278
|
*reinterpret_cast<uint16_t*>(&quantized_smem[shared_idx + j * 2]) =
|
|
244
279
|
quantized_val;
|
|
245
280
|
}
|
|
@@ -282,8 +317,12 @@ __global__ void fp_quantize_columnwise(
|
|
|
282
317
|
}
|
|
283
318
|
|
|
284
319
|
template <typename T, int group_size, int bits, bool use_mx_scale>
|
|
285
|
-
__global__ void
|
|
286
|
-
|
|
320
|
+
__global__ void fp_dequantize(
|
|
321
|
+
const uint8_t* w,
|
|
322
|
+
const uint8_t* scales,
|
|
323
|
+
T* out,
|
|
324
|
+
size_t size,
|
|
325
|
+
float* global_scale = nullptr) {
|
|
287
326
|
auto block_size = cg::this_thread_block().dim_threads();
|
|
288
327
|
auto block_idx = cg::this_thread_block().group_index();
|
|
289
328
|
auto idx_in_block = cg::this_thread_block().thread_index();
|
|
@@ -294,6 +333,10 @@ fp_dequantize(const uint8_t* w, const uint8_t* scales, T* out, size_t size) {
|
|
|
294
333
|
auto grid_dim_x = cg::this_grid().dim_blocks().x * block_size.x;
|
|
295
334
|
|
|
296
335
|
constexpr int pack_factor = bits == 8 ? 1 : 2;
|
|
336
|
+
const bool use_global_scale = global_scale != nullptr;
|
|
337
|
+
const float inv_scale_enc = use_mx_scale
|
|
338
|
+
? 1.0f
|
|
339
|
+
: (use_global_scale ? (*global_scale) / (F8E4M3_MAX * F4E2M1_MAX) : 1.0f);
|
|
297
340
|
size_t offset = tidx + grid_dim_x * size_t(tidy);
|
|
298
341
|
size_t oindex = offset * pack_factor;
|
|
299
342
|
|
|
@@ -304,7 +347,7 @@ fp_dequantize(const uint8_t* w, const uint8_t* scales, T* out, size_t size) {
|
|
|
304
347
|
size_t gindex = oindex / group_size;
|
|
305
348
|
using ScaleType =
|
|
306
349
|
std::conditional_t<use_mx_scale, __nv_fp8_e8m0, __nv_fp8_e4m3>;
|
|
307
|
-
auto scale = float(((ScaleType*)(scales))[gindex]);
|
|
350
|
+
auto scale = float(((ScaleType*)(scales))[gindex]) * inv_scale_enc;
|
|
308
351
|
|
|
309
352
|
out += oindex;
|
|
310
353
|
|
|
@@ -346,9 +389,13 @@ void fp_quantize_dequantize(
|
|
|
346
389
|
array& what,
|
|
347
390
|
int group_size,
|
|
348
391
|
int bits,
|
|
392
|
+
const std::optional<array>& global_scale /* = std::nullopt */,
|
|
349
393
|
cu::CommandEncoder& enc,
|
|
350
394
|
const Stream& s) {
|
|
351
395
|
enc.set_input_array(w);
|
|
396
|
+
if (global_scale.has_value()) {
|
|
397
|
+
enc.set_input_array(global_scale.value());
|
|
398
|
+
}
|
|
352
399
|
enc.set_output_array(what);
|
|
353
400
|
dispatch_float_types(w.dtype(), "fp_quantize_dequantize", [&](auto type_tag) {
|
|
354
401
|
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
|
@@ -370,7 +417,9 @@ void fp_quantize_dequantize(
|
|
|
370
417
|
0,
|
|
371
418
|
gpu_ptr<T>(w),
|
|
372
419
|
gpu_ptr<T>(what),
|
|
373
|
-
w.size()
|
|
420
|
+
w.size(),
|
|
421
|
+
global_scale.has_value() ? gpu_ptr<float>(global_scale.value())
|
|
422
|
+
: nullptr);
|
|
374
423
|
}
|
|
375
424
|
});
|
|
376
425
|
}
|
|
@@ -381,9 +430,13 @@ void fp_quantize(
|
|
|
381
430
|
array& scales,
|
|
382
431
|
int group_size,
|
|
383
432
|
int bits,
|
|
433
|
+
const std::optional<array>& global_scale /* = std::nullopt */,
|
|
384
434
|
cu::CommandEncoder& enc,
|
|
385
435
|
const Stream& s) {
|
|
386
436
|
enc.set_input_array(w);
|
|
437
|
+
if (global_scale.has_value()) {
|
|
438
|
+
enc.set_input_array(global_scale.value());
|
|
439
|
+
}
|
|
387
440
|
enc.set_output_array(wq);
|
|
388
441
|
enc.set_output_array(scales);
|
|
389
442
|
if (w.strides().back() != 1) {
|
|
@@ -410,7 +463,9 @@ void fp_quantize(
|
|
|
410
463
|
gpu_ptr<uint8_t>(scales),
|
|
411
464
|
w.size(),
|
|
412
465
|
M,
|
|
413
|
-
K
|
|
466
|
+
K,
|
|
467
|
+
global_scale.has_value() ? gpu_ptr<float>(global_scale.value())
|
|
468
|
+
: nullptr);
|
|
414
469
|
} else {
|
|
415
470
|
throw std::runtime_error(
|
|
416
471
|
"[Quantize::eval_gpu] Can not quantize input with type float64.");
|
|
@@ -438,7 +493,9 @@ void fp_quantize(
|
|
|
438
493
|
gpu_ptr<T>(w),
|
|
439
494
|
gpu_ptr<uint8_t>(wq),
|
|
440
495
|
gpu_ptr<uint8_t>(scales),
|
|
441
|
-
w.size()
|
|
496
|
+
w.size(),
|
|
497
|
+
global_scale.has_value() ? gpu_ptr<float>(global_scale.value())
|
|
498
|
+
: nullptr);
|
|
442
499
|
} else {
|
|
443
500
|
throw std::runtime_error(
|
|
444
501
|
"[Quantize::eval_gpu] Can not quantize input with type float64.");
|
|
@@ -453,6 +510,7 @@ void fp_dequantize(
|
|
|
453
510
|
array& w,
|
|
454
511
|
int group_size,
|
|
455
512
|
int bits,
|
|
513
|
+
const std::optional<array>& global_scale /* = std::nullopt */,
|
|
456
514
|
cu::CommandEncoder& enc,
|
|
457
515
|
const Stream& s) {
|
|
458
516
|
constexpr int uint8_per_uint32 = 4;
|
|
@@ -465,6 +523,9 @@ void fp_dequantize(
|
|
|
465
523
|
|
|
466
524
|
enc.set_input_array(wq);
|
|
467
525
|
enc.set_input_array(scales);
|
|
526
|
+
if (global_scale.has_value()) {
|
|
527
|
+
enc.set_input_array(global_scale.value());
|
|
528
|
+
}
|
|
468
529
|
enc.set_output_array(w);
|
|
469
530
|
dispatch_float_types(w.dtype(), "fp_dequantize", [&](auto type_tag) {
|
|
470
531
|
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
|
@@ -485,7 +546,9 @@ void fp_dequantize(
|
|
|
485
546
|
gpu_ptr<uint8_t>(wq),
|
|
486
547
|
gpu_ptr<uint8_t>(scales),
|
|
487
548
|
gpu_ptr<T>(w),
|
|
488
|
-
w.size()
|
|
549
|
+
w.size(),
|
|
550
|
+
global_scale.has_value() ? gpu_ptr<float>(global_scale.value())
|
|
551
|
+
: nullptr);
|
|
489
552
|
} else {
|
|
490
553
|
throw std::runtime_error(
|
|
491
554
|
"[Quantize::eval_gpu] Can not dequantize to output with type float64.");
|