mlx 0.30.7.2 → 0.30.7.6
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/ext/mlx/extconf.rb +267 -8
- data/ext/mlx/native.cpp +112 -58
- data/ext/mlx-onnx/native.cpp +1402 -0
- data/ext/mlx-onnx/native.hpp +19 -0
- data/lib/mlx/core.rb +342 -117
- data/lib/mlx/distributed_utils/common.rb +1 -1
- data/lib/mlx/distributed_utils/config.rb +7 -4
- data/lib/mlx/distributed_utils/launch.rb +2 -0
- data/lib/mlx/dsl/attention.rb +132 -0
- data/lib/mlx/dsl/builder.rb +8 -0
- data/lib/mlx/dsl/config_schema.rb +133 -0
- data/lib/mlx/dsl/generate.rb +193 -0
- data/lib/mlx/dsl/kv_cache.rb +96 -0
- data/lib/mlx/dsl/masks.rb +32 -0
- data/lib/mlx/dsl/positions.rb +35 -0
- data/lib/mlx/dsl/run_stack.rb +68 -0
- data/lib/mlx/dsl/tensor.rb +126 -0
- data/lib/mlx/dsl/transformer_block.rb +113 -0
- data/lib/mlx/dsl/weight_map.rb +140 -0
- data/lib/mlx/dsl.rb +10 -0
- data/lib/mlx/nn/base.rb +4 -0
- data/lib/mlx/nn/layers/linear.rb +2 -3
- data/lib/mlx/onnx.rb +250 -0
- data/lib/mlx/version.rb +1 -1
- data/lib/mlx-onnx/webgpu_harness.rb +289 -0
- data/{mlx → submodules/mlx}/mlx/backend/cuda/cublas_utils.cpp +0 -7
- data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/cublas_gemm.cpp +10 -2
- data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/cublas_qqmm.cpp +97 -46
- data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/cublas_qqmm.h +25 -13
- data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/fp_quantize.cu +101 -38
- data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/no_qqmm_impl.cpp +1 -2
- data/submodules/mlx/mlx/backend/cuda/quantized/qqmm.cpp +193 -0
- data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/qqmm_impl.cpp +15 -8
- data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/qqmm_impl.h +14 -3
- data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/qqmm_utils.cu +36 -0
- data/submodules/mlx/mlx/backend/cuda/quantized/qqmm_utils.h +62 -0
- data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/quantized.cpp +12 -3
- data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/quantized.h +4 -0
- data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/quantized_utils.cuh +1 -1
- data/{mlx → submodules/mlx}/mlx/backend/metal/device.cpp +4 -0
- data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/conv.metal +3 -2
- data/{mlx → submodules/mlx}/mlx/export.cpp +21 -6
- data/{mlx → submodules/mlx}/mlx/ops.cpp +144 -13
- data/{mlx → submodules/mlx}/mlx/ops.h +12 -2
- data/{mlx → submodules/mlx}/mlx/primitives.cpp +22 -5
- data/{mlx → submodules/mlx}/mlx/scheduler.cpp +4 -0
- data/{mlx → submodules/mlx}/mlx/scheduler.h +3 -0
- data/{mlx → submodules/mlx}/mlx/stream.h +5 -0
- data/submodules/mlx-onnx/CMakeLists.txt +159 -0
- data/submodules/mlx-onnx/LICENSE +21 -0
- data/submodules/mlx-onnx/include/mlx/ir.hpp +88 -0
- data/submodules/mlx-onnx/src/api.cpp +81 -0
- data/submodules/mlx-onnx/src/compat.cpp +111 -0
- data/submodules/mlx-onnx/src/detail.hpp +69 -0
- data/submodules/mlx-onnx/src/export.cpp +653 -0
- data/submodules/mlx-onnx/src/io.cpp +61 -0
- data/submodules/mlx-onnx/src/json.hpp +25 -0
- data/submodules/mlx-onnx/src/lowering.cpp +6346 -0
- data/submodules/mlx-onnx/src/mappings.cpp +201 -0
- data/submodules/mlx-onnx/src/mappings.hpp +16 -0
- data/submodules/mlx-onnx/src/onnx.cpp +1029 -0
- data/submodules/mlx-onnx/src/shared.cpp +206 -0
- metadata +665 -567
- data/mlx/mlx/backend/cuda/quantized/qqmm.cpp +0 -158
- data/mlx/mlx/backend/cuda/quantized/qqmm_utils.h +0 -30
- /data/{mlx → submodules/mlx}/CMakeLists.txt +0 -0
- /data/{mlx → submodules/mlx}/cmake/FindCUDNN.cmake +0 -0
- /data/{mlx → submodules/mlx}/cmake/FindNCCL.cmake +0 -0
- /data/{mlx → submodules/mlx}/cmake/Findnvpl.cmake +0 -0
- /data/{mlx → submodules/mlx}/cmake/extension.cmake +0 -0
- /data/{mlx → submodules/mlx}/mlx/3rdparty/.clang-format +0 -0
- /data/{mlx → submodules/mlx}/mlx/3rdparty/pocketfft.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/CMakeLists.txt +0 -0
- /data/{mlx → submodules/mlx}/mlx/allocator.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/api.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/array.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/array.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/common/CMakeLists.txt +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/common/binary.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/common/broadcasting.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/common/broadcasting.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/common/buffer_cache.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/common/common.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/common/compiled.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/common/compiled.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/common/copy.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/common/hadamard.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/common/load.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/common/matmul.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/common/reduce.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/common/reduce.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/common/slicing.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/common/slicing.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/common/ternary.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/common/unary.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/common/utils.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/common/utils.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/CMakeLists.txt +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/arange.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/arg_reduce.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/binary.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/binary.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/binary_ops.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/binary_two.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/cholesky.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/compiled.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/compiled_preamble.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/conv.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/copy.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/copy.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/device_info.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/device_info.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/distributed.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/eig.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/eigh.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/encoder.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/encoder.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/eval.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/eval.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/fft.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/gemm.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/gemms/bnns.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/gemms/cblas.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/gemms/simd_bf16.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/gemms/simd_fp16.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/gemms/simd_gemm.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/hadamard.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/indexing.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/inverse.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/jit_compiler.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/jit_compiler.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/lapack.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/logsumexp.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/luf.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/make_compiled_preamble.ps1 +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/make_compiled_preamble.sh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/masked_mm.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/matmul.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/primitives.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/qrf.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/quantized.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/reduce.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/scan.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/select.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/simd/accelerate_fp16_simd.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/simd/accelerate_simd.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/simd/base_simd.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/simd/math.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/simd/neon_fp16_simd.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/simd/simd.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/simd/type.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/slicing.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/softmax.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/sort.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/svd.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/ternary.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/threefry.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/threefry.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/unary.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/unary.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cpu/unary_ops.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/CMakeLists.txt +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/allocator.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/allocator.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/arange.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/arg_reduce.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/bin2h.cmake +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/CMakeLists.txt +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/add.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/arctan2.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/binary.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/bitwise_binary.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/divide.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/equal.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/greater.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/greater_equal.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/less.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/less_equal.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/log_add_exp.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/logical_and.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/logical_or.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/maximum.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/minimum.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/multiply.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/not_equal.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/power.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/remainder.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/subtract.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary_two.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/compiled.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/conv/conv.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/conv/gemm_conv.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/conv/gemm_grouped_conv.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/conv.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/copy/copy.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/copy/copy_contiguous.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/copy/copy_general.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/copy/copy_general_dynamic.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/copy/copy_general_input.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/copy.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/cublas_utils.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/cuda.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/cuda_utils.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/cudnn_utils.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/cudnn_utils.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/custom_kernel.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/cutlass_utils.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/delayload.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/atomic_ops.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/binary_ops.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/cast_op.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/complex.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/config.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/fp16_math.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/gather.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/gather_axis.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/indexing.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/scatter.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/scatter_axis.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/scatter_ops.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/ternary_ops.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/unary_ops.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/utils.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/device.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/device.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/device_info.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/distributed.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/eval.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/event.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/event.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/fence.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/cublas_gemm.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/cublas_gemm_batched_12_0.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/cublas_gemm_batched_12_9.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/gemv.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/gemv.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/grouped_gemm.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/grouped_gemm_unaligned.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/indexing.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/jit_module.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/jit_module.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/kernel_utils.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/kernel_utils.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/layer_norm.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/load.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/logsumexp.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/lru_cache.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/matmul.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/no_cuda.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/primitives.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/affine_quantize.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/convert_fp8.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/cuda_fp4.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/mxfp8_quantize.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/nvfp4_quantize.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/qmv.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/qmv.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/quantized_utils.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/random.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce/all_reduce.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce/col_reduce.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce/init_reduce.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce/reduce.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce/reduce_ops.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce/reduce_utils.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce/row_reduce.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/rms_norm.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/rope.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/scaled_dot_product_attention.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/scaled_dot_product_attention.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/scan.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/slicing.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/softmax.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/sort.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/steel/defines.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/steel/gemm.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/steel/mma.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/steel/tiles.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/steel/utils.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/ternary.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/CMakeLists.txt +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/abs.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/arccos.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/arccosh.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/arcsin.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/arcsinh.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/arctan.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/arctanh.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/bitwise_invert.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/ceil.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/conjugate.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/cos.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/cosh.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/erf.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/erf_inv.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/exp.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/expm1.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/floor.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/imag.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/log.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/log1p.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/logical_not.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/negative.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/real.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/round.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/sigmoid.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/sign.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/sin.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/sinh.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/sqrt.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/square.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/tan.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/tanh.cu +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/unary.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/utils.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/utils.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/vector_types.cuh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/worker.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/cuda/worker.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/gpu/CMakeLists.txt +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/gpu/copy.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/gpu/copy.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/gpu/device_info.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/gpu/eval.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/gpu/primitives.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/gpu/slicing.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/gpu/slicing.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/CMakeLists.txt +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/allocator.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/allocator.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/binary.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/binary.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/compiled.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/conv.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/copy.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/custom_kernel.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/device.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/device_info.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/distributed.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/eval.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/event.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/fence.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/fft.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/hadamard.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/indexing.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/jit/includes.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/jit/indexing.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/jit_kernels.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/CMakeLists.txt +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/arange.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/arange.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/arg_reduce.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/atomic.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/bf16.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/bf16_math.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/binary.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/binary.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/binary_ops.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/binary_two.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/binary_two.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/cexpf.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/complex.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/copy.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/copy.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/defines.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/erf.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/expm1f.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fence.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fft/radix.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fft/readwrite.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fft.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fft.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fp4.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fp8.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fp_quantized.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fp_quantized.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fp_quantized_nax.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fp_quantized_nax.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/gemv.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/gemv_masked.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/gemv_masked.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/hadamard.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/indexing/gather.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/indexing/gather_axis.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/indexing/gather_front.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/indexing/indexing.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/indexing/masked_scatter.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/indexing/scatter.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/indexing/scatter_axis.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/layer_norm.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/logging.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/logsumexp.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/logsumexp.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/quantized.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/quantized.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/quantized_nax.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/quantized_nax.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/quantized_utils.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/random.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduce.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduce.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduce_utils.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduction/ops.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduction/reduce_all.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduction/reduce_col.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduction/reduce_init.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduction/reduce_row.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/rms_norm.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/rope.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/scaled_dot_product_attention.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/scan.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/scan.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/sdpa_vector.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/softmax.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/softmax.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/sort.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/sort.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/attn.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/loader.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/mma.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/nax.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/params.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/transforms.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/conv.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/loader.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/params.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/defines.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/gemm.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/loader.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/mma.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/nax.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/params.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/transforms.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/utils/integral_constant.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/utils/type_traits.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/utils.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/ternary.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/ternary.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/ternary_ops.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/unary.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/unary.metal +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/unary_ops.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/utils.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/logsumexp.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/make_compiled_preamble.sh +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/matmul.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/matmul.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/metal.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/metal.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/no_metal.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/nojit_kernels.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/normalization.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/primitives.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/quantized.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/reduce.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/reduce.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/resident.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/resident.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/rope.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/scaled_dot_product_attention.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/scan.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/scan.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/slicing.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/softmax.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/sort.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/ternary.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/ternary.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/unary.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/unary.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/utils.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/metal/utils.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/no_cpu/CMakeLists.txt +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/no_cpu/compiled.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/no_cpu/device_info.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/no_cpu/primitives.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/CMakeLists.txt +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/allocator.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/apple_memory.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/device_info.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/eval.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/event.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/fence.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/linux_memory.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/primitives.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/compile.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/compile.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/compile_impl.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/device.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/device.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/CMakeLists.txt +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/distributed.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/distributed.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/distributed_impl.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/CMakeLists.txt +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/jaccl.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/jaccl.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/mesh.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/mesh.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/no_jaccl.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/ring.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/ring.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/utils.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/utils.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/mpi/CMakeLists.txt +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/mpi/mpi.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/mpi/mpi.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/mpi/mpi_declarations.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/mpi/no_mpi.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/nccl/CMakeLists.txt +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/nccl/nccl.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/nccl/nccl.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/nccl/nccl_stub/CMakeLists.txt +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/nccl/nccl_stub/nccl_stubs.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/nccl/no_nccl.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/ops.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/ops.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/primitives.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/primitives.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/reduction_ops.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/ring/CMakeLists.txt +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/ring/no_ring.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/ring/ring.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/ring/ring.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/utils.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/distributed/utils.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/dtype.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/dtype.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/dtype_utils.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/dtype_utils.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/einsum.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/einsum.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/event.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/export.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/export_impl.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/fast.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/fast.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/fast_primitives.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/fence.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/fft.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/fft.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/graph_utils.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/graph_utils.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/io/CMakeLists.txt +0 -0
- /data/{mlx → submodules/mlx}/mlx/io/gguf.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/io/gguf.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/io/gguf_quants.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/io/load.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/io/load.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/io/no_gguf.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/io/no_safetensors.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/io/safetensors.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/io.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/linalg.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/linalg.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/memory.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/mlx.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/primitives.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/random.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/random.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/small_vector.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/threadpool.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/transforms.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/transforms.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/transforms_impl.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/types/bf16.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/types/complex.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/types/fp16.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/types/half_types.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/types/limits.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/utils.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/utils.h +0 -0
- /data/{mlx → submodules/mlx}/mlx/version.cpp +0 -0
- /data/{mlx → submodules/mlx}/mlx/version.h +0 -0
- /data/{mlx → submodules/mlx}/mlx.pc.in +0 -0
data/submodules/mlx/mlx/backend/cuda/quantized/qqmm_utils.h (new file)

@@ -0,0 +1,62 @@
+// Copyright © 2025 Apple Inc.
+
+#pragma once
+
+#include "mlx/array.h"
+#include "mlx/backend/cuda/device.h"
+
+namespace mlx::core {
+
+// Compute padded dimensions for tiled layout
+// Tiles are 128 rows × 4 columns, must allocate full tiles
+inline std::pair<int, int> get_padded_scale_dims(int num_rows, int num_cols) {
+  constexpr int rows_per_tile = 128;
+  constexpr int cols_per_tile = 4;
+
+  int padded_rows =
+      ((num_rows + rows_per_tile - 1) / rows_per_tile) * rows_per_tile;
+  int padded_cols =
+      ((num_cols + cols_per_tile - 1) / cols_per_tile) * cols_per_tile;
+
+  return {padded_rows, padded_cols};
+}
+
+void swizzle_scales(
+    const array& scales,
+    array& scales_tiled,
+    cu::CommandEncoder& enc,
+    const Stream& s);
+
+inline array pad_and_swizzle_scales(
+    const array& scale,
+    cu::CommandEncoder& encoder,
+    const Stream& s) {
+  // Compute padded dimensions for full tiles (128 rows × 4 cols)
+  auto [pad_outer, pad_inner] =
+      get_padded_scale_dims(scale.shape(-2), scale.shape(-1));
+  // cuBLAS requirements for scale factor layout:
+  // 1. Dimensions must be padded to full tiles (128 rows × 4 cols)
+  // 2. Out-of-bounds values must be filled with zeros
+  // 3. Starting addresses must be 16-byte aligned
+  // https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout
+  // Note: cu::malloc_async already provides 256-byte alignment
+  array scale_tiled(
+      cu::malloc_async(pad_outer * pad_inner, encoder),
+      Shape{pad_outer, pad_inner},
+      scale.dtype());
+  swizzle_scales(scale, scale_tiled, encoder, s);
+
+  encoder.add_temporary(scale_tiled);
+  return scale_tiled;
+}
+
+// Compute alpha = tensor_amax_x * tensor_amax_w / (448 * 6)^2
+// Allocate beta zero on device as well
+void compute_qqmm_pointers(
+    array& alpha_out,
+    array& beta_out,
+    const array& tensor_amax_x,
+    const array& tensor_amax_w,
+    cu::CommandEncoder& enc);
+
+} // namespace mlx::core
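As a worked example of the ceiling-division arithmetic in get_padded_scale_dims (a standalone sketch; the function body is reproduced from the header above, and the driver values are illustrative only):

    #include <cstdio>
    #include <utility>

    // Same rounding as get_padded_scale_dims above: pad each dimension up
    // to a whole number of 128-row x 4-column scale tiles.
    std::pair<int, int> get_padded_scale_dims(int num_rows, int num_cols) {
      constexpr int rows_per_tile = 128;
      constexpr int cols_per_tile = 4;
      int padded_rows =
          ((num_rows + rows_per_tile - 1) / rows_per_tile) * rows_per_tile;
      int padded_cols =
          ((num_cols + cols_per_tile - 1) / cols_per_tile) * cols_per_tile;
      return {padded_rows, padded_cols};
    }

    int main() {
      // A 300 x 5 scale matrix needs ceil(300/128) = 3 row tiles and
      // ceil(5/4) = 2 column tiles, so it is padded to 384 x 8.
      auto [rows, cols] = get_padded_scale_dims(300, 5);
      std::printf("%d x %d\n", rows, cols); // prints "384 x 8"
    }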
data/submodules/mlx/mlx/backend/cuda/quantized/quantized.cpp

@@ -51,7 +51,6 @@ void fast::Quantize::eval_gpu(
   auto& s = stream();
   auto& d = cu::device(s.device);
   auto& enc = d.get_command_encoder(s);
-
   if (dequantize_) {
     auto wq = ensure_row_contiguous(inputs[0], enc, s);
     auto scales = ensure_row_contiguous(inputs[1], enc, s);
@@ -63,7 +62,12 @@ void fast::Quantize::eval_gpu(
       auto biases = ensure_row_contiguous(inputs[2], enc, s);
       affine_dequantize(wq, scales, biases, w, group_size_, bits_, enc, s);
     } else {
-      fp_dequantize(wq, scales, w, group_size_, bits_, enc, s);
+      // 0 -- xq, 1 -- scales, 2 -- could be global scale for nvfp4
+      bool use_global_scale =
+          mode_ == QuantizationMode::Nvfp4 && inputs.size() > 2;
+      std::optional<array> global_scale =
+          use_global_scale ? std::make_optional(inputs[2]) : std::nullopt;
+      fp_dequantize(wq, scales, w, group_size_, bits_, global_scale, enc, s);
     }
   } else {
     auto w = ensure_contiguous(inputs[0], enc, s);
@@ -72,12 +76,17 @@ void fast::Quantize::eval_gpu(

     wq.set_data(cu::malloc_async(wq.nbytes(), enc));
     scales.set_data(cu::malloc_async(scales.nbytes(), enc));
+
     if (mode_ == QuantizationMode::Affine) {
       auto& biases = outputs[2];
       biases.set_data(cu::malloc_async(biases.nbytes(), enc));
       affine_quantize(w, wq, scales, biases, group_size_, bits_, enc, s);
     } else {
-      fp_quantize(w, wq, scales, group_size_, bits_, enc, s);
+      bool use_global_scale =
+          mode_ == QuantizationMode::Nvfp4 && inputs.size() > 1;
+      std::optional<array> global_scale =
+          use_global_scale ? std::make_optional(inputs[1]) : std::nullopt;
+      fp_quantize(w, wq, scales, group_size_, bits_, global_scale, enc, s);
     }
   }
 }
data/submodules/mlx/mlx/backend/cuda/quantized/quantized.h

@@ -1,5 +1,6 @@
 // Copyright © 2025 Apple Inc.

+#include <optional>
 #include "mlx/backend/cuda/device.h"

 namespace mlx::core {
@@ -30,6 +31,7 @@ void fp_quantize(
     array& scales,
     int group_size,
     int bits,
+    const std::optional<array>& global_scale,
     cu::CommandEncoder& enc,
     const Stream& s);

@@ -39,6 +41,7 @@ void fp_dequantize(
     array& w,
     int group_size,
     int bits,
+    const std::optional<array>& global_scale,
     cu::CommandEncoder& enc,
     const Stream& s);

@@ -47,6 +50,7 @@ void fp_quantize_dequantize(
     array& what,
     int group_size,
     int bits,
+    const std::optional<array>& global_scale,
     cu::CommandEncoder& enc,
     const Stream& s);

data/submodules/mlx/mlx/backend/cuda/quantized/quantized_utils.cuh

@@ -29,7 +29,7 @@ inline constexpr __device__ short get_bytes_per_pack() {
 }

 template <typename T>
-__device__ __forceinline__ void
+__device__ __forceinline__ void absmax_x2(T& out, const T& x1, const T& x2) {
   if constexpr (
       (std::is_same<T, __nv_bfloat162>::value) ||
       (std::is_same<T, __half2>::value)) {
data/submodules/mlx/mlx/backend/metal/device.cpp

@@ -247,6 +247,10 @@ void CommandEncoder::set_buffer(
     const MTL::Buffer* buf,
     int idx,
     int64_t offset /* = 0 */) {
+  // Record as both input and output to ensure synchronization between command
+  // buffers
+  all_inputs_.insert((void*)buf);
+  all_outputs_.insert((void*)buf);
   enc_->setBuffer(buf, offset, idx);
 }

data/submodules/mlx/mlx/backend/metal/kernels/conv.metal

@@ -30,7 +30,7 @@ template <typename T, int N>
     out_pixels *= params->oS[i];

   // Set out
-  out += gid.z * filter_size + gid.y * (params->C);
+  out += (size_t)gid.z * filter_size + (size_t)gid.y * (params->C);

   // Coordinates in input
   int is[N] = {0};
@@ -93,7 +93,8 @@ template <typename T, int N>
     out_pixels *= params->oS[i];

   // Set out
-  out += gid.z * filter_size + gid.x * (filter_size / params->C);
+  out +=
+      (size_t)gid.z * filter_size + (size_t)gid.x * (filter_size / params->C);

   // Coordinates in input
   int is[N] = {0};
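The added (size_t) casts matter because the gid components and shape fields are 32-bit, so the old offset product was computed modulo 2^32 for large outputs before being added to the pointer. A minimal C++ illustration of the wraparound (the values are hypothetical, chosen only so the product exceeds 2^32):

    #include <cstdint>
    #include <cstdio>

    int main() {
      uint32_t gid_z = 70000;       // hypothetical grid index
      uint32_t filter_size = 70000; // hypothetical filter volume
      // 32-bit product wraps: 70000 * 70000 = 4.9e9 > 2^32.
      uint32_t wrapped = gid_z * filter_size;
      // Widening one operand first keeps the full value, as in the fix.
      size_t correct = (size_t)gid_z * filter_size;
      std::printf("wrapped: %u, correct: %zu\n", wrapped, correct);
      // wrapped: 605032704, correct: 4900000000
    }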
data/submodules/mlx/mlx/export.cpp

@@ -279,6 +279,8 @@ void extract_state(const T state, std::vector<StateT>& unpacked_state) {
     unpacked_state.push_back(state);
   } else if constexpr (std::is_enum_v<T>) {
     unpacked_state.push_back(static_cast<int>(state));
+  } else if constexpr (std::is_same_v<T, Dtype>) {
+    unpacked_state.push_back(state);
   } else if constexpr (is_iterable<T>) {
     unpacked_state.push_back(state);
   } else if constexpr (is_pair<T> || is_tuple<T>) {
@@ -446,6 +448,7 @@ struct PrimitiveFactory {
       SERIALIZE_PRIMITIVE(ScaledDotProductAttention),
       SERIALIZE_PRIMITIVE(CustomKernel)};
   std::unordered_map<std::string, std::string> name_remap;
+  std::unordered_map<int, Stream> stream_map;

   PrimitiveFactory() {
     for (auto& [n, f] : factory) {
@@ -471,13 +474,25 @@ struct PrimitiveFactory {
     }
   };

-
-  auto
-
-
-
-
+  Stream resolve_stream(const Stream& stream) {
+    if (auto it = stream_map.find(stream.index); it != stream_map.end()) {
+      return it->second;
+    }
+    // Try to find an existing stream on the same device
+    for (auto& s : get_streams()) {
+      if (s.device == stream.device) {
+        stream_map.emplace(stream.index, s);
+        return s;
+      }
     }
+    // No stream on that device, make a new one
+    Stream s = new_stream(stream.device);
+    stream_map.emplace(stream.index, s);
+    return s;
+  }
+
+  std::shared_ptr<Primitive> load(Reader& is) {
+    auto stream = resolve_stream(deserialize<Stream>(is));
     auto name = deserialize<std::string>(is);
     if (auto it = factory.find(name); it != factory.end()) {
       return it->second.deserialize(is, stream);
data/submodules/mlx/mlx/ops.cpp

@@ -10,6 +10,7 @@
 #include <sstream>

 #include "mlx/backend/cuda/cuda.h"
+#include "mlx/backend/metal/metal.h"
 #include "mlx/fast_primitives.h"
 #include "mlx/ops.h"
 #include "mlx/primitives.h"
@@ -2311,6 +2312,40 @@ array argmax(
   return out;
 }

+array hanning(int M, StreamOrDevice s /* = {} */) {
+  if (M < 1) {
+    return array({});
+  }
+  if (M == 1) {
+    return ones({1}, float32, s);
+  }
+
+  auto n = arange(0, M, float32, s);
+  array factor(M_PI / (M - 1), float32);
+  return square(sin(multiply(factor, n, s), s), s);
+}
+
+array hamming(int M, StreamOrDevice s /* = {} */) {
+  if (M < 1) {
+    return array({});
+  }
+  if (M == 1) {
+    return ones({1}, float32, s);
+  }
+
+  auto n = arange(0, M, float32, s);
+  float factor_val = (2.0 * M_PI) / (M - 1);
+  auto factor = array(factor_val, float32);
+
+  auto arg = multiply(factor, n, s);
+  auto cos_vals = cos(arg, s);
+
+  auto left_coef = array(0.54f, float32);
+  auto right_coef = array(0.46f, float32);
+
+  return subtract(left_coef, multiply(right_coef, cos_vals, s), s);
+}
+
 /** Returns a sorted copy of the flattened array. */
 array sort(const array& a, StreamOrDevice s /* = {} */) {
   int size = a.size();
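A note on the math behind the two new windows: both are standard generalized-cosine forms, and hanning uses the algebraically equivalent sine-squared form via the identity 1/2 - (1/2)cos(2θ) = sin²(θ):

    w_{\text{hann}}[n] = \tfrac{1}{2} - \tfrac{1}{2}\cos\!\left(\frac{2\pi n}{M-1}\right)
                       = \sin^2\!\left(\frac{\pi n}{M-1}\right),
    \qquad
    w_{\text{hamming}}[n] = 0.54 - 0.46\,\cos\!\left(\frac{2\pi n}{M-1}\right),
    \qquad n = 0,\dots,M-1.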
@@ -4209,6 +4244,34 @@ std::pair<Dtype, QuantizationMode> validate_mode_with_type(
   }
 }

+void validate_global_scale(
+    std::string_view tag,
+    QuantizationMode qmode,
+    const std::optional<array>& global_scale) {
+  if (global_scale.has_value()) {
+    if (qmode != QuantizationMode::Nvfp4) {
+      std::ostringstream msg;
+      msg << "[" << tag << "] Global scale is only supported for 'nvfp4' "
+          << "quantization mode.";
+      throw std::invalid_argument(msg.str());
+    } else {
+      if (global_scale->size() != 1) {
+        std::ostringstream msg;
+        msg << "[" << tag << "] Global scale must be a scalar but got shape "
+            << global_scale->shape() << ".";
+        throw std::invalid_argument(msg.str());
+      }
+      // TODO: not sure if type should be restricted to float32
+      if (global_scale->dtype() != float32) {
+        std::ostringstream msg;
+        msg << "[" << tag << "] Global scale must have dtype float32 but got "
+            << global_scale->dtype() << ".";
+        throw std::invalid_argument(msg.str());
+      }
+    }
+  }
+}
+
 array quantized_matmul(
     array x,
     array w,
@@ -4251,7 +4314,6 @@ array quantized_matmul(
   if (x.ndim() > 2 && w.ndim() > 2) {
     inputs = broadcast_arrays(inputs, {-2, -1}, s);
   }
-
   auto out_shape = inputs[0].shape();
   out_shape.back() = w_outer_dims;
   return array(
@@ -4267,7 +4329,10 @@ void validate_qqmm_inputs(
     array w,
     std::optional<array> scales_w,
     int group_size,
-    int bits) {
+    int bits,
+    std::optional<array> global_scale_x,
+    std::optional<array> global_scale_w,
+    QuantizationMode qmode) {
   // check 2D (for now)
   if (x.ndim() > 2 || w.ndim() > 2) {
     std::ostringstream msg;
@@ -4304,6 +4369,19 @@ void validate_qqmm_inputs(
         << "first argument dtype == " << x.dtype() << ".";
     throw std::invalid_argument(msg.str());
   }
+  // validate global scales
+  validate_global_scale("qqmm", qmode, global_scale_x);
+  validate_global_scale("qqmm", qmode, global_scale_w);
+  // For nvfp4 mode, both global scales must be provided together or neither
+  if (qmode == QuantizationMode::Nvfp4) {
+    bool has_x = global_scale_x.has_value();
+    bool has_w = global_scale_w.has_value();
+    if (has_x != has_w) {
+      throw std::invalid_argument(
+          "[qqmm] For nvfp4 mode, either both global_scale_x and "
+          "global_scale_w must be provided, or neither.");
+    }
+  }
 }

 std::pair<int, int> extract_qqmm_dims(
@@ -4343,6 +4421,8 @@ array qqmm(
     std::optional<int> group_size_ /* = std::nullopt */,
     std::optional<int> bits_ /* = std::nullopt */,
     const std::string& mode /* = "nvfp4" */,
+    const std::optional<array> global_scale_x /* = std::nullopt */,
+    const std::optional<array> global_scale_w /* = std::nullopt */,
     StreamOrDevice s /* = {} */) {
   auto stream = to_stream(s);
   auto qmode = string_to_quantization_mode(mode, "qqmm");
@@ -4369,7 +4449,8 @@ array qqmm(
   }

   // validate inputs
-  validate_qqmm_inputs(x, w, scales_w, group_size, bits);
+  validate_qqmm_inputs(
+      x, w, scales_w, group_size, bits, global_scale_x, global_scale_w, qmode);
   // validate and extract shapes
   auto [w_inner_dims, w_outer_dims] =
       extract_qqmm_dims(x, w, scales_w, group_size, bits);
@@ -4380,6 +4461,11 @@ array qqmm(
   if (scales_w.has_value()) {
     inputs.push_back(*scales_w);
   }
+  if (global_scale_x.has_value() && global_scale_w.has_value()) {
+    inputs.push_back(*global_scale_x);
+    inputs.push_back(*global_scale_w);
+  }
+
   auto out_shape = inputs[0].shape();
   out_shape.back() = w_outer_dims;
   auto out = array(
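Taken together, the qqmm changes above mean an nvfp4 caller now threads two scalar float32 arrays through the call. A minimal usage sketch, assuming non-quantized operands and amax-derived scales; the helper name and the choice of amax are illustrative, not part of this diff:

    #include "mlx/mlx.h"

    using namespace mlx::core;

    // Hypothetical caller of the extended qqmm() (signature as declared in
    // the ops.h hunk further below). A tensor-wide amax is one plausible
    // global scale; the diff only requires a float32 scalar and that both
    // scales be passed together in nvfp4 mode.
    array nvfp4_matmul(const array& x, const array& w) {
      auto gs_x = astype(max(abs(x)), float32);
      auto gs_w = astype(max(abs(w)), float32);
      return qqmm(
          x, w,
          std::nullopt, // w_scales: w is not pre-quantized here
          std::nullopt, // group_size: resolves to 16 for nvfp4
          std::nullopt, // bits: resolves to 4
          "nvfp4", gs_x, gs_w);
    }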
@@ -4515,6 +4601,7 @@ std::vector<array> fp_quantize(
     int group_size,
     int bits,
     QuantizationMode mode,
+    const std::optional<array>& global_scale /* = std::nullopt */,
     Stream s) {
   int expected_gs = mode == QuantizationMode::Nvfp4 ? 16 : 32;
   int expected_bits = mode == QuantizationMode::Mxfp8 ? 8 : 4;
@@ -4532,6 +4619,12 @@ std::vector<array> fp_quantize(
        << bits << ".";
     throw std::invalid_argument(msg.str());
   }
+
+  auto inputs = std::vector<array>{w};
+  if (global_scale.has_value()) {
+    inputs.push_back(global_scale.value());
+  }
+
   auto fallback = [bits = bits, group_size = group_size, s](
       const std::vector<array>& inputs) -> std::vector<array> {
     auto& w = inputs[0];
@@ -4543,8 +4636,13 @@ std::vector<array> fp_quantize(
         divide(max(abs(wq, s), -1, true, s), array(maxval, w.dtype()), s);
     if (group_size == 16) {
       // convert to e4m3
+      auto scale_encode = inputs.size() > 1
+          ? divide(array(448.0f * 6.0f, float32), inputs[1], s)
+          : array(1.0f, float32);
+      scales = multiply(scales, scale_encode, s);
       scales = to_fp8(scales, s);
-      wq = divide(wq, from_fp8(scales, w.dtype(), s), s);
+      wq = multiply(
+          divide(wq, from_fp8(scales, w.dtype(), s), s), scale_encode, s);
     } else {
       // convert to e8m0
       auto z = array(0, scales.dtype());
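The 448.0f * 6.0f constant packs two range limits: 448 is the largest finite e4m3 value (the storage format of the group scales) and 6 is the largest e2m1/fp4 magnitude. A worked sketch of the encoding arithmetic, with the example numbers invented for illustration:

    // Factor applied to per-group scales before e4m3 rounding, mirroring
    // divide(array(448.0f * 6.0f, float32), inputs[1], s) in the fallback.
    float encode_factor(float global_scale) {
      return (448.0f * 6.0f) / global_scale; // 2688 / global_scale
    }

    // Example: with global_scale = amax(w) = 13.44f,
    //   encode_factor(13.44f) == 2688.0f / 13.44f == 200.0f,
    // so a raw group scale s is stored as e4m3(s * 200), keeping the
    // encoded scales inside the representable e4m3 range.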
@@ -4600,9 +4698,9 @@ std::vector<array> fp_quantize(
         {uint32, uint8},
         std::make_shared<fast::Quantize>(
             s, fallback, group_size, bits, mode, false),
-        {w});
+        inputs);
   }
-  return fallback({w});
+  return fallback(inputs);
 }

 std::vector<array> quantize(
@@ -4610,6 +4708,7 @@ std::vector<array> quantize(
     std::optional<int> group_size_ /* = std::nullopt */,
     std::optional<int> bits_ /* = std::nullopt */,
     const std::string& mode /* = "affine" */,
+    const std::optional<array>& global_scale /* = std::nullopt */,
     StreamOrDevice s /* = {} */) {
   auto qmode = string_to_quantization_mode(mode, "quantize");
   auto [group_size, bits] =
@@ -4636,11 +4735,17 @@ std::vector<array> quantize(
        << " matrix has shape " << w.shape();
     throw std::invalid_argument(msg.str());
   }
-
+  if (to_stream(s).device == Device::gpu && metal::is_available() &&
+      global_scale.has_value()) {
+    std::ostringstream msg;
+    msg << "[quantize] Global scale is not supported on the Metal backend.";
+    throw std::invalid_argument(msg.str());
+  }
+  validate_global_scale("quantize", qmode, global_scale);
   if (qmode == QuantizationMode::Affine) {
     return affine_quantize(w, group_size, bits, s);
   } else {
-    return fp_quantize(w, group_size, bits, qmode, to_stream(s));
+    return fp_quantize(w, group_size, bits, qmode, global_scale, to_stream(s));
   }
 }

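After the guard, global_scale is rejected whenever the stream resolves to a Metal GPU, so a caller must pin the work elsewhere. A minimal sketch of the extended quantize() on an explicit CPU stream; deriving the scale from amax is an assumed caller convention, not something the diff enforces:

    #include "mlx/mlx.h"

    using namespace mlx::core;

    std::vector<array> quantize_nvfp4(const array& w) {
      // validate_global_scale requires a float32 scalar.
      auto gs = astype(max(abs(w)), float32);
      // Device::cpu sidesteps the Metal guard above; group_size/bits
      // resolve to 16/4 for nvfp4. Returns {packed codes, e4m3 scales}.
      return quantize(w, std::nullopt, std::nullopt, "nvfp4", gs, Device::cpu);
    }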
@@ -4745,6 +4850,7 @@ array fp_dequantize(
     int bits,
     Dtype out_type,
     QuantizationMode mode,
+    const std::optional<array>& global_scale /* = std::nullopt */,
     Stream s) {
   int expected_gs = mode == QuantizationMode::Nvfp4 ? 16 : 32;
   int expected_bits = mode == QuantizationMode::Mxfp8 ? 8 : 4;
@@ -4789,6 +4895,11 @@ array fp_dequantize(
     throw std::invalid_argument(msg.str());
   }

+  auto inputs = std::vector<array>{w, scales};
+  if (global_scale.has_value()) {
+    inputs.push_back(global_scale.value());
+  }
+
   auto fallback =
       [wshape = std::move(wshape),
        sshape = std::move(sshape),
@@ -4831,13 +4942,17 @@ array fp_dequantize(
     out = reshape(out, {-1, group_size}, s);
     scales = reshape(scales, {-1, 1}, s);
     if (group_size == 16) {
-      scales = from_fp8(scales, out_type, s);
+      array inv_scale_enc = inputs.size() > 2
+          ? divide(inputs[2], array(448.0f * 6.0f, out_type), s)
+          : array(1.0f, out_type);
+      scales = multiply(from_fp8(scales, out_type, s), inv_scale_enc, s);
     } else {
       scales = subtract(astype(scales, out_type, s), array(127, out_type), s);
       scales = power(array(2.0f, out_type), scales, s);
     }
     return {reshape(multiply(out, scales, s), wshape, s)};
   };
+
   if (s.device == Device::gpu) {
     auto out_shape = w.shape();
     out_shape.back() = out_size;
@@ -4846,9 +4961,9 @@ array fp_dequantize(
         out_type,
         std::make_shared<fast::Quantize>(
             s, fallback, group_size, bits, mode, true),
-        {w, scales});
+        inputs);
   }
-  return fallback({w, scales})[0];
+  return fallback(inputs)[0];
 }

 array dequantize(
@@ -4858,6 +4973,7 @@ array dequantize(
     std::optional<int> group_size_ /* = std::nullopt */,
     std::optional<int> bits_ /* = std::nullopt */,
     const std::string& mode /* = "affine" */,
+    const std::optional<array>& global_scale /* = std::nullopt */,
     std::optional<Dtype> dtype /* = std::nullopt */,
     StreamOrDevice s /* = {} */) {
   auto [out_type, qmode] =
@@ -4884,6 +5000,14 @@ array dequantize(
        << "but it has only " << w.ndim() << ".";
     throw std::invalid_argument(msg.str());
   }
+  if (global_scale.has_value()) {
+    if (to_stream(s).device == Device::gpu && metal::is_available()) {
+      std::ostringstream msg;
+      msg << "[dequantize] Global scale is not supported on the Metal backend.";
+      throw std::invalid_argument(msg.str());
+    }
+  }
+  validate_global_scale("dequantize", qmode, global_scale);

   if (qmode == QuantizationMode::Affine) {
     return astype(
@@ -4892,7 +5016,14 @@ array dequantize(
         s);
   } else {
     return fp_dequantize(
-        w, scales, group_size, bits, out_type, qmode, to_stream(s));
+        w,
+        scales,
+        group_size,
+        bits,
+        out_type,
+        qmode,
+        global_scale,
+        to_stream(s));
   }
 }

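The dequantize side applies the exact inverse of the encoding introduced in fp_quantize, which is why the same global scale must reach both entry points. A small sketch of the cancellation:

    // Inverse of the quantize-side factor: fp_dequantize multiplies decoded
    // e4m3 scales by inputs[2] / (448 * 6), i.e. global_scale / 2688.
    float decode_factor(float global_scale) {
      return global_scale / (448.0f * 6.0f);
    }

    // For any gs > 0, encode_factor(gs) * decode_factor(gs) == 1.0f up to
    // float rounding, so quantize()/dequantize() round-trip to the original
    // group-scale magnitudes only when given the same global_scale.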
@@ -6091,4 +6222,4 @@ array contiguous(
       {a});
 }

-} // namespace mlx::core
+} // namespace mlx::core
data/submodules/mlx/mlx/ops.h:

@@ -666,6 +666,12 @@ min(const array& a,
 MLX_API array
 min(const array& a, int axis, bool keepdims = false, StreamOrDevice s = {});

+/** Returns the Hanning window of size M. */
+MLX_API array hanning(int M, StreamOrDevice s = {});
+
+/** Returns the Hamming window of size M. */
+MLX_API array hamming(int M, StreamOrDevice s = {});
+
 /** Returns the index of the minimum value in the array. */
 MLX_API array argmin(const array& a, bool keepdims, StreamOrDevice s = {});
 inline array argmin(const array& a, StreamOrDevice s = {}) {
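The two declarations follow the textbook window definitions; the ops.cpp bodies are outside this section, so the sketch below is the standard symmetric formula rather than a quote of the implementation:

    #include <cmath>
    #include <vector>

    // Reference Hanning curve: w[n] = 0.5 - 0.5 * cos(2*pi*n / (M - 1)).
    // Hamming has the same shape with coefficients 0.54 and 0.46.
    std::vector<float> hanning_reference(int M) {
      if (M == 1) {
        return {1.0f}; // degenerate window, matching the NumPy convention
      }
      const float pi = 3.14159265358979f;
      std::vector<float> w(M);
      for (int n = 0; n < M; ++n) {
        w[n] = 0.5f - 0.5f * std::cos(2.0f * pi * n / (M - 1));
      }
      return w;
    }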
@@ -1391,6 +1397,7 @@ MLX_API std::vector<array> quantize(
     std::optional<int> group_size = std::nullopt,
     std::optional<int> bits = std::nullopt,
     const std::string& mode = "affine",
+    const std::optional<array>& global_scale = std::nullopt,
     StreamOrDevice s = {});

 /** Dequantize a matrix produced by quantize() */
@@ -1401,17 +1408,20 @@ MLX_API array dequantize(
     std::optional<int> group_size = std::nullopt,
     std::optional<int> bits = std::nullopt,
     const std::string& mode = "affine",
+    const std::optional<array>& global_scale = std::nullopt,
     std::optional<Dtype> dtype = std::nullopt,
     StreamOrDevice s = {});

 MLX_API array qqmm(
     array x, // input activations
     array w, // maybe quantized weights
-    std::optional<array> w_scales = std::nullopt, // optional scales if w
-                                                  // is quantized
+    const std::optional<array> w_scales = std::nullopt, // optional scales if w
+                                                        // is quantized
     std::optional<int> group_size = std::nullopt,
     std::optional<int> bits = std::nullopt,
     const std::string& mode = "nvfp4",
+    const std::optional<array> global_scale_x = std::nullopt,
+    const std::optional<array> global_scale_w = std::nullopt,
     StreamOrDevice s = {});

 /** Convert an E4M3 float8 to the given floating point dtype. */
data/submodules/mlx/mlx/primitives.cpp:

@@ -3424,6 +3424,7 @@ std::vector<array> QuantizedMatmul::vjp(
         group_size_,
         bits_,
         quantization_mode_to_string(mode_),
+        {}, // placeholder for amax
         std::nullopt,
         stream());
     wq = unflatten(wq, -1, {-1, group_size_}, stream());
@@ -3484,14 +3485,14 @@ std::vector<Shape> QQMatmul::output_shapes(const std::vector<array>& inputs) {
 }

 std::vector<array> QQMatmul::vjp(
-    const std::vector<array>& primals, // non quantized x, non quantized w
+    const std::vector<array>& primals, // non quantized x, non quantized w, if
+                                       // nvfp4 global_scale_x, global_scale_w
     const std::vector<array>& cotangents, // non quantized upstream grads
     const std::vector<int>& argnums,
     const std::vector<array>&) {
-
-
-
-}
+  bool is_nvfp4 = mode_ == QuantizationMode::Nvfp4;
+  assert(primals.size() == 2 || (is_nvfp4 && primals.size() == 4));
+
   std::vector<array> vjps;
   auto& cotan = cotangents[0];
   auto& s = stream();
@@ -3499,6 +3500,15 @@ std::vector<array> QQMatmul::vjp(
   // primal[0] -- non quantized activations (M, K)
   // cotan -- non quantized grads (M, N)
   auto qmode = quantization_mode_to_string(mode_);
+  std::optional<array> cotan_amax = (primals.size() == 4)
+      ? std::make_optional(astype(max(abs(cotan, s), s), float32, s))
+      : std::nullopt;
+
+  auto get_primal_scale = [&](int idx) {
+    return (primals.size() == 4) ? std::make_optional(primals[idx])
+                                 : std::nullopt;
+  };
+
   for (auto arg : argnums) {
     if (arg == 0) { // gradient wrt to x
       // We transpose weights -> quantize along N
@@ -3509,6 +3519,8 @@ std::vector<array> QQMatmul::vjp(
           group_size_,
           bits_,
           qmode,
+          cotan_amax,
+          get_primal_scale(3), // global_scale_w (for w.T)
           s));
     } else if (arg == 1) { // gradient wrt to weights
       vjps.push_back(qqmm(
@@ -3518,7 +3530,11 @@ std::vector<array> QQMatmul::vjp(
           group_size_,
           bits_,
           qmode,
+          cotan_amax,
+          get_primal_scale(2), // global_scale_x (for x.T)
           s));
+    } else {
+      vjps.push_back(zeros_like(primals[arg], s));
     }
   }
   return vjps;
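Reading the vjp as a whole: with y = qqmm(x, w), the adjoints are the usual matmul pair dx = dy * w^T and dw = x^T * dy, and both products are themselves qqmm calls. Each recursive call quantizes the fresh activation dy, so it needs dy's own tensor-wide amax (computed once as cotan_amax above) together with the stored global scale of the primal it multiplies against: the arg 0 branch passes get_primal_scale(3) (w's scale, which remains valid for w^T because transposition preserves the absolute maximum), and the arg 1 branch passes get_primal_scale(2) (x's scale). The new else branch returns zero cotangents for the global-scale primals themselves, which are treated as non-differentiable.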
@@ -3643,6 +3659,7 @@ std::vector<array> GatherQMM::vjp(
           bits_,
           quantization_mode_to_string(mode_),
           std::nullopt,
+          std::nullopt, // amax placeholder
           stream()),
       -1,
       {-1, group_size_},
data/submodules/mlx/mlx/scheduler.cpp:

@@ -26,6 +26,10 @@ Stream get_stream(int index) {
   return scheduler::scheduler().get_stream(index);
 }

+std::vector<Stream> get_streams() {
+  return scheduler::scheduler().get_streams();
+}
+
 Stream new_stream(Device d) {
   if (!gpu::is_available() && d == Device::gpu) {
     throw std::invalid_argument(
data/submodules/mlx/mlx/scheduler.h:

@@ -99,6 +99,9 @@ class Scheduler {
   Stream get_stream(int index) const {
     return streams_.at(index);
   }
+  std::vector<Stream> get_streams() const {
+    return streams_;
+  }

   void set_default_stream(const Stream& s) {
     default_streams_.at(s.device.type) = s;
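The new accessor pair exposes the scheduler's full stream list for the first time. A hypothetical consumer, assuming get_streams() is exported as a free function next to get_stream (the stream.h header also grows by five lines in this release) and using MLX's existing per-stream synchronize(Stream):

    #include "mlx/mlx.h"

    using namespace mlx::core;

    // Drain every stream the scheduler has created, not just the default
    // stream of each device.
    void synchronize_all_streams() {
      for (const auto& s : get_streams()) {
        synchronize(s);
      }
    }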