mlx 0.30.7
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/ext/mlx/extconf.rb +94 -0
- data/ext/mlx/native.cpp +8027 -0
- data/lib/mlx/core.rb +1678 -0
- data/lib/mlx/distributed_utils/common.rb +116 -0
- data/lib/mlx/distributed_utils/config.rb +600 -0
- data/lib/mlx/distributed_utils/launch.rb +490 -0
- data/lib/mlx/extension.rb +24 -0
- data/lib/mlx/nn/base.rb +388 -0
- data/lib/mlx/nn/init.rb +140 -0
- data/lib/mlx/nn/layers/activations.rb +336 -0
- data/lib/mlx/nn/layers/base.rb +6 -0
- data/lib/mlx/nn/layers/containers.rb +20 -0
- data/lib/mlx/nn/layers/convolution.rb +120 -0
- data/lib/mlx/nn/layers/convolution_transpose.rb +114 -0
- data/lib/mlx/nn/layers/distributed.rb +309 -0
- data/lib/mlx/nn/layers/dropout.rb +75 -0
- data/lib/mlx/nn/layers/embedding.rb +28 -0
- data/lib/mlx/nn/layers/linear.rb +79 -0
- data/lib/mlx/nn/layers/normalization.rb +216 -0
- data/lib/mlx/nn/layers/pooling.rb +167 -0
- data/lib/mlx/nn/layers/positional_encoding.rb +126 -0
- data/lib/mlx/nn/layers/quantized.rb +215 -0
- data/lib/mlx/nn/layers/recurrent.rb +135 -0
- data/lib/mlx/nn/layers/transformer.rb +330 -0
- data/lib/mlx/nn/layers/upsample.rb +97 -0
- data/lib/mlx/nn/layers.rb +18 -0
- data/lib/mlx/nn/losses.rb +251 -0
- data/lib/mlx/nn/utils.rb +167 -0
- data/lib/mlx/nn.rb +12 -0
- data/lib/mlx/optimizers/optimizers.rb +808 -0
- data/lib/mlx/optimizers/schedulers.rb +62 -0
- data/lib/mlx/optimizers.rb +9 -0
- data/lib/mlx/utils.rb +171 -0
- data/lib/mlx/version.rb +5 -0
- data/lib/mlx.rb +64 -0
- data/mlx/CMakeLists.txt +449 -0
- data/mlx/cmake/FindCUDNN.cmake +177 -0
- data/mlx/cmake/FindNCCL.cmake +54 -0
- data/mlx/cmake/Findnvpl.cmake +3 -0
- data/mlx/cmake/extension.cmake +50 -0
- data/mlx/mlx/3rdparty/.clang-format +2 -0
- data/mlx/mlx/3rdparty/pocketfft.h +3581 -0
- data/mlx/mlx/CMakeLists.txt +107 -0
- data/mlx/mlx/allocator.h +75 -0
- data/mlx/mlx/api.h +29 -0
- data/mlx/mlx/array.cpp +354 -0
- data/mlx/mlx/array.h +647 -0
- data/mlx/mlx/backend/common/CMakeLists.txt +9 -0
- data/mlx/mlx/backend/common/binary.h +97 -0
- data/mlx/mlx/backend/common/broadcasting.cpp +24 -0
- data/mlx/mlx/backend/common/broadcasting.h +11 -0
- data/mlx/mlx/backend/common/buffer_cache.h +158 -0
- data/mlx/mlx/backend/common/common.cpp +305 -0
- data/mlx/mlx/backend/common/compiled.cpp +243 -0
- data/mlx/mlx/backend/common/compiled.h +77 -0
- data/mlx/mlx/backend/common/copy.h +50 -0
- data/mlx/mlx/backend/common/hadamard.h +109 -0
- data/mlx/mlx/backend/common/load.cpp +57 -0
- data/mlx/mlx/backend/common/matmul.h +67 -0
- data/mlx/mlx/backend/common/reduce.cpp +154 -0
- data/mlx/mlx/backend/common/reduce.h +59 -0
- data/mlx/mlx/backend/common/slicing.cpp +71 -0
- data/mlx/mlx/backend/common/slicing.h +20 -0
- data/mlx/mlx/backend/common/ternary.h +85 -0
- data/mlx/mlx/backend/common/unary.h +29 -0
- data/mlx/mlx/backend/common/utils.cpp +231 -0
- data/mlx/mlx/backend/common/utils.h +205 -0
- data/mlx/mlx/backend/cpu/CMakeLists.txt +88 -0
- data/mlx/mlx/backend/cpu/arange.h +28 -0
- data/mlx/mlx/backend/cpu/arg_reduce.cpp +124 -0
- data/mlx/mlx/backend/cpu/binary.cpp +269 -0
- data/mlx/mlx/backend/cpu/binary.h +517 -0
- data/mlx/mlx/backend/cpu/binary_ops.h +98 -0
- data/mlx/mlx/backend/cpu/binary_two.h +166 -0
- data/mlx/mlx/backend/cpu/cholesky.cpp +85 -0
- data/mlx/mlx/backend/cpu/compiled.cpp +357 -0
- data/mlx/mlx/backend/cpu/compiled_preamble.h +12 -0
- data/mlx/mlx/backend/cpu/conv.cpp +1351 -0
- data/mlx/mlx/backend/cpu/copy.cpp +386 -0
- data/mlx/mlx/backend/cpu/copy.h +36 -0
- data/mlx/mlx/backend/cpu/device_info.cpp +113 -0
- data/mlx/mlx/backend/cpu/device_info.h +28 -0
- data/mlx/mlx/backend/cpu/distributed.cpp +103 -0
- data/mlx/mlx/backend/cpu/eig.cpp +281 -0
- data/mlx/mlx/backend/cpu/eigh.cpp +241 -0
- data/mlx/mlx/backend/cpu/encoder.cpp +16 -0
- data/mlx/mlx/backend/cpu/encoder.h +67 -0
- data/mlx/mlx/backend/cpu/eval.cpp +40 -0
- data/mlx/mlx/backend/cpu/eval.h +12 -0
- data/mlx/mlx/backend/cpu/fft.cpp +120 -0
- data/mlx/mlx/backend/cpu/gemm.h +26 -0
- data/mlx/mlx/backend/cpu/gemms/bnns.cpp +214 -0
- data/mlx/mlx/backend/cpu/gemms/cblas.cpp +134 -0
- data/mlx/mlx/backend/cpu/gemms/simd_bf16.cpp +45 -0
- data/mlx/mlx/backend/cpu/gemms/simd_fp16.cpp +45 -0
- data/mlx/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
- data/mlx/mlx/backend/cpu/hadamard.cpp +121 -0
- data/mlx/mlx/backend/cpu/indexing.cpp +854 -0
- data/mlx/mlx/backend/cpu/inverse.cpp +160 -0
- data/mlx/mlx/backend/cpu/jit_compiler.cpp +166 -0
- data/mlx/mlx/backend/cpu/jit_compiler.h +20 -0
- data/mlx/mlx/backend/cpu/lapack.h +80 -0
- data/mlx/mlx/backend/cpu/logsumexp.cpp +139 -0
- data/mlx/mlx/backend/cpu/luf.cpp +120 -0
- data/mlx/mlx/backend/cpu/make_compiled_preamble.ps1 +38 -0
- data/mlx/mlx/backend/cpu/make_compiled_preamble.sh +41 -0
- data/mlx/mlx/backend/cpu/masked_mm.cpp +608 -0
- data/mlx/mlx/backend/cpu/matmul.cpp +166 -0
- data/mlx/mlx/backend/cpu/primitives.cpp +478 -0
- data/mlx/mlx/backend/cpu/qrf.cpp +147 -0
- data/mlx/mlx/backend/cpu/quantized.cpp +1370 -0
- data/mlx/mlx/backend/cpu/reduce.cpp +587 -0
- data/mlx/mlx/backend/cpu/scan.cpp +338 -0
- data/mlx/mlx/backend/cpu/select.cpp +95 -0
- data/mlx/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
- data/mlx/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
- data/mlx/mlx/backend/cpu/simd/base_simd.h +319 -0
- data/mlx/mlx/backend/cpu/simd/math.h +193 -0
- data/mlx/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
- data/mlx/mlx/backend/cpu/simd/simd.h +4 -0
- data/mlx/mlx/backend/cpu/simd/type.h +11 -0
- data/mlx/mlx/backend/cpu/slicing.h +21 -0
- data/mlx/mlx/backend/cpu/softmax.cpp +170 -0
- data/mlx/mlx/backend/cpu/sort.cpp +481 -0
- data/mlx/mlx/backend/cpu/svd.cpp +289 -0
- data/mlx/mlx/backend/cpu/ternary.h +154 -0
- data/mlx/mlx/backend/cpu/threefry.cpp +31 -0
- data/mlx/mlx/backend/cpu/threefry.h +21 -0
- data/mlx/mlx/backend/cpu/unary.cpp +238 -0
- data/mlx/mlx/backend/cpu/unary.h +281 -0
- data/mlx/mlx/backend/cpu/unary_ops.h +175 -0
- data/mlx/mlx/backend/cuda/CMakeLists.txt +265 -0
- data/mlx/mlx/backend/cuda/allocator.cpp +451 -0
- data/mlx/mlx/backend/cuda/allocator.h +94 -0
- data/mlx/mlx/backend/cuda/arange.cu +68 -0
- data/mlx/mlx/backend/cuda/arg_reduce.cu +189 -0
- data/mlx/mlx/backend/cuda/bin2h.cmake +150 -0
- data/mlx/mlx/backend/cuda/binary/CMakeLists.txt +21 -0
- data/mlx/mlx/backend/cuda/binary/add.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/arctan2.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/binary.cuh +383 -0
- data/mlx/mlx/backend/cuda/binary/bitwise_binary.cu +27 -0
- data/mlx/mlx/backend/cuda/binary/divide.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/equal.cu +15 -0
- data/mlx/mlx/backend/cuda/binary/greater.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/greater_equal.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/less.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/less_equal.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/log_add_exp.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/logical_and.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/logical_or.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/maximum.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/minimum.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/multiply.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/not_equal.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/power.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/remainder.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/subtract.cu +7 -0
- data/mlx/mlx/backend/cuda/binary_two.cu +412 -0
- data/mlx/mlx/backend/cuda/compiled.cpp +357 -0
- data/mlx/mlx/backend/cuda/conv/conv.h +126 -0
- data/mlx/mlx/backend/cuda/conv/gemm_conv.cu +217 -0
- data/mlx/mlx/backend/cuda/conv/gemm_grouped_conv.cu +231 -0
- data/mlx/mlx/backend/cuda/conv.cpp +403 -0
- data/mlx/mlx/backend/cuda/copy/copy.cuh +55 -0
- data/mlx/mlx/backend/cuda/copy/copy_contiguous.cu +88 -0
- data/mlx/mlx/backend/cuda/copy/copy_general.cu +171 -0
- data/mlx/mlx/backend/cuda/copy/copy_general_dynamic.cu +118 -0
- data/mlx/mlx/backend/cuda/copy/copy_general_input.cu +229 -0
- data/mlx/mlx/backend/cuda/copy.cu +132 -0
- data/mlx/mlx/backend/cuda/cublas_utils.cpp +222 -0
- data/mlx/mlx/backend/cuda/cublas_utils.h +95 -0
- data/mlx/mlx/backend/cuda/cuda.h +21 -0
- data/mlx/mlx/backend/cuda/cuda_utils.h +90 -0
- data/mlx/mlx/backend/cuda/cudnn_utils.cpp +133 -0
- data/mlx/mlx/backend/cuda/cudnn_utils.h +187 -0
- data/mlx/mlx/backend/cuda/custom_kernel.cpp +379 -0
- data/mlx/mlx/backend/cuda/cutlass_utils.cuh +46 -0
- data/mlx/mlx/backend/cuda/delayload.cpp +80 -0
- data/mlx/mlx/backend/cuda/device/atomic_ops.cuh +63 -0
- data/mlx/mlx/backend/cuda/device/binary_ops.cuh +300 -0
- data/mlx/mlx/backend/cuda/device/cast_op.cuh +118 -0
- data/mlx/mlx/backend/cuda/device/complex.cuh +60 -0
- data/mlx/mlx/backend/cuda/device/config.h +12 -0
- data/mlx/mlx/backend/cuda/device/fp16_math.cuh +96 -0
- data/mlx/mlx/backend/cuda/device/gather.cuh +53 -0
- data/mlx/mlx/backend/cuda/device/gather_axis.cuh +65 -0
- data/mlx/mlx/backend/cuda/device/indexing.cuh +30 -0
- data/mlx/mlx/backend/cuda/device/scatter.cuh +68 -0
- data/mlx/mlx/backend/cuda/device/scatter_axis.cuh +67 -0
- data/mlx/mlx/backend/cuda/device/scatter_ops.cuh +44 -0
- data/mlx/mlx/backend/cuda/device/ternary_ops.cuh +13 -0
- data/mlx/mlx/backend/cuda/device/unary_ops.cuh +350 -0
- data/mlx/mlx/backend/cuda/device/utils.cuh +464 -0
- data/mlx/mlx/backend/cuda/device.cpp +522 -0
- data/mlx/mlx/backend/cuda/device.h +195 -0
- data/mlx/mlx/backend/cuda/device_info.cpp +232 -0
- data/mlx/mlx/backend/cuda/distributed.cu +121 -0
- data/mlx/mlx/backend/cuda/eval.cpp +66 -0
- data/mlx/mlx/backend/cuda/event.cu +415 -0
- data/mlx/mlx/backend/cuda/event.h +79 -0
- data/mlx/mlx/backend/cuda/fence.cpp +42 -0
- data/mlx/mlx/backend/cuda/gemms/cublas_gemm.cpp +233 -0
- data/mlx/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
- data/mlx/mlx/backend/cuda/gemms/cublas_gemm_batched_12_0.cpp +77 -0
- data/mlx/mlx/backend/cuda/gemms/cublas_gemm_batched_12_9.cu +329 -0
- data/mlx/mlx/backend/cuda/gemms/gemv.cu +327 -0
- data/mlx/mlx/backend/cuda/gemms/gemv.h +34 -0
- data/mlx/mlx/backend/cuda/gemms/grouped_gemm.h +25 -0
- data/mlx/mlx/backend/cuda/gemms/grouped_gemm_unaligned.cu +358 -0
- data/mlx/mlx/backend/cuda/indexing.cpp +434 -0
- data/mlx/mlx/backend/cuda/jit_module.cpp +443 -0
- data/mlx/mlx/backend/cuda/jit_module.h +120 -0
- data/mlx/mlx/backend/cuda/kernel_utils.cu +52 -0
- data/mlx/mlx/backend/cuda/kernel_utils.cuh +148 -0
- data/mlx/mlx/backend/cuda/layer_norm.cu +417 -0
- data/mlx/mlx/backend/cuda/load.cpp +60 -0
- data/mlx/mlx/backend/cuda/logsumexp.cu +161 -0
- data/mlx/mlx/backend/cuda/lru_cache.h +190 -0
- data/mlx/mlx/backend/cuda/matmul.cpp +373 -0
- data/mlx/mlx/backend/cuda/no_cuda.cpp +47 -0
- data/mlx/mlx/backend/cuda/primitives.cpp +46 -0
- data/mlx/mlx/backend/cuda/quantized/affine_quantize.cu +329 -0
- data/mlx/mlx/backend/cuda/quantized/convert_fp8.cu +19 -0
- data/mlx/mlx/backend/cuda/quantized/cublas_qqmm.cpp +206 -0
- data/mlx/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
- data/mlx/mlx/backend/cuda/quantized/cuda_fp4.h +100 -0
- data/mlx/mlx/backend/cuda/quantized/fp_quantize.cu +496 -0
- data/mlx/mlx/backend/cuda/quantized/mxfp8_quantize.cuh +32 -0
- data/mlx/mlx/backend/cuda/quantized/no_qqmm_impl.cpp +26 -0
- data/mlx/mlx/backend/cuda/quantized/nvfp4_quantize.cuh +334 -0
- data/mlx/mlx/backend/cuda/quantized/qmv.cu +304 -0
- data/mlx/mlx/backend/cuda/quantized/qmv.h +21 -0
- data/mlx/mlx/backend/cuda/quantized/qqmm.cpp +158 -0
- data/mlx/mlx/backend/cuda/quantized/qqmm_impl.cpp +50 -0
- data/mlx/mlx/backend/cuda/quantized/qqmm_impl.h +26 -0
- data/mlx/mlx/backend/cuda/quantized/qqmm_utils.cu +227 -0
- data/mlx/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
- data/mlx/mlx/backend/cuda/quantized/quantized.cpp +85 -0
- data/mlx/mlx/backend/cuda/quantized/quantized.h +53 -0
- data/mlx/mlx/backend/cuda/quantized/quantized_utils.cuh +88 -0
- data/mlx/mlx/backend/cuda/quantized/quantized_utils.h +50 -0
- data/mlx/mlx/backend/cuda/random.cu +202 -0
- data/mlx/mlx/backend/cuda/reduce/all_reduce.cu +159 -0
- data/mlx/mlx/backend/cuda/reduce/col_reduce.cu +510 -0
- data/mlx/mlx/backend/cuda/reduce/init_reduce.cu +50 -0
- data/mlx/mlx/backend/cuda/reduce/reduce.cuh +71 -0
- data/mlx/mlx/backend/cuda/reduce/reduce_ops.cuh +211 -0
- data/mlx/mlx/backend/cuda/reduce/reduce_utils.cuh +145 -0
- data/mlx/mlx/backend/cuda/reduce/row_reduce.cu +361 -0
- data/mlx/mlx/backend/cuda/reduce.cu +73 -0
- data/mlx/mlx/backend/cuda/rms_norm.cu +536 -0
- data/mlx/mlx/backend/cuda/rope.cu +429 -0
- data/mlx/mlx/backend/cuda/scaled_dot_product_attention.cpp +681 -0
- data/mlx/mlx/backend/cuda/scaled_dot_product_attention.cu +796 -0
- data/mlx/mlx/backend/cuda/scan.cu +468 -0
- data/mlx/mlx/backend/cuda/slicing.cpp +111 -0
- data/mlx/mlx/backend/cuda/softmax.cu +162 -0
- data/mlx/mlx/backend/cuda/sort.cu +1076 -0
- data/mlx/mlx/backend/cuda/steel/defines.cuh +9 -0
- data/mlx/mlx/backend/cuda/steel/gemm.cuh +101 -0
- data/mlx/mlx/backend/cuda/steel/mma.cuh +117 -0
- data/mlx/mlx/backend/cuda/steel/tiles.cuh +450 -0
- data/mlx/mlx/backend/cuda/steel/utils.cuh +89 -0
- data/mlx/mlx/backend/cuda/ternary.cu +271 -0
- data/mlx/mlx/backend/cuda/unary/CMakeLists.txt +34 -0
- data/mlx/mlx/backend/cuda/unary/abs.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arccos.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arccosh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arcsin.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arcsinh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arctan.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arctanh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/bitwise_invert.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/ceil.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/conjugate.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/cos.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/cosh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/erf.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/erf_inv.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/exp.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/expm1.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/floor.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/imag.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/log.cu +21 -0
- data/mlx/mlx/backend/cuda/unary/log1p.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/logical_not.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/negative.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/real.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/round.cu +18 -0
- data/mlx/mlx/backend/cuda/unary/sigmoid.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/sign.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/sin.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/sinh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/sqrt.cu +15 -0
- data/mlx/mlx/backend/cuda/unary/square.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/tan.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/tanh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/unary.cuh +224 -0
- data/mlx/mlx/backend/cuda/utils.cpp +116 -0
- data/mlx/mlx/backend/cuda/utils.h +49 -0
- data/mlx/mlx/backend/cuda/vector_types.cuh +48 -0
- data/mlx/mlx/backend/cuda/worker.cpp +79 -0
- data/mlx/mlx/backend/cuda/worker.h +55 -0
- data/mlx/mlx/backend/gpu/CMakeLists.txt +5 -0
- data/mlx/mlx/backend/gpu/copy.cpp +89 -0
- data/mlx/mlx/backend/gpu/copy.h +57 -0
- data/mlx/mlx/backend/gpu/device_info.h +36 -0
- data/mlx/mlx/backend/gpu/eval.h +18 -0
- data/mlx/mlx/backend/gpu/primitives.cpp +307 -0
- data/mlx/mlx/backend/gpu/slicing.cpp +44 -0
- data/mlx/mlx/backend/gpu/slicing.h +36 -0
- data/mlx/mlx/backend/metal/CMakeLists.txt +144 -0
- data/mlx/mlx/backend/metal/allocator.cpp +279 -0
- data/mlx/mlx/backend/metal/allocator.h +79 -0
- data/mlx/mlx/backend/metal/binary.cpp +257 -0
- data/mlx/mlx/backend/metal/binary.h +33 -0
- data/mlx/mlx/backend/metal/compiled.cpp +471 -0
- data/mlx/mlx/backend/metal/conv.cpp +1118 -0
- data/mlx/mlx/backend/metal/copy.cpp +235 -0
- data/mlx/mlx/backend/metal/custom_kernel.cpp +430 -0
- data/mlx/mlx/backend/metal/device.cpp +816 -0
- data/mlx/mlx/backend/metal/device.h +289 -0
- data/mlx/mlx/backend/metal/device_info.cpp +58 -0
- data/mlx/mlx/backend/metal/distributed.cpp +38 -0
- data/mlx/mlx/backend/metal/eval.cpp +97 -0
- data/mlx/mlx/backend/metal/event.cpp +62 -0
- data/mlx/mlx/backend/metal/fence.cpp +162 -0
- data/mlx/mlx/backend/metal/fft.cpp +807 -0
- data/mlx/mlx/backend/metal/hadamard.cpp +198 -0
- data/mlx/mlx/backend/metal/indexing.cpp +727 -0
- data/mlx/mlx/backend/metal/jit/includes.h +58 -0
- data/mlx/mlx/backend/metal/jit/indexing.h +76 -0
- data/mlx/mlx/backend/metal/jit_kernels.cpp +1118 -0
- data/mlx/mlx/backend/metal/kernels/CMakeLists.txt +193 -0
- data/mlx/mlx/backend/metal/kernels/arange.h +9 -0
- data/mlx/mlx/backend/metal/kernels/arange.metal +20 -0
- data/mlx/mlx/backend/metal/kernels/arg_reduce.metal +182 -0
- data/mlx/mlx/backend/metal/kernels/atomic.h +345 -0
- data/mlx/mlx/backend/metal/kernels/bf16.h +16 -0
- data/mlx/mlx/backend/metal/kernels/bf16_math.h +380 -0
- data/mlx/mlx/backend/metal/kernels/binary.h +199 -0
- data/mlx/mlx/backend/metal/kernels/binary.metal +109 -0
- data/mlx/mlx/backend/metal/kernels/binary_ops.h +330 -0
- data/mlx/mlx/backend/metal/kernels/binary_two.h +244 -0
- data/mlx/mlx/backend/metal/kernels/binary_two.metal +54 -0
- data/mlx/mlx/backend/metal/kernels/cexpf.h +134 -0
- data/mlx/mlx/backend/metal/kernels/complex.h +173 -0
- data/mlx/mlx/backend/metal/kernels/conv.metal +701 -0
- data/mlx/mlx/backend/metal/kernels/copy.h +276 -0
- data/mlx/mlx/backend/metal/kernels/copy.metal +75 -0
- data/mlx/mlx/backend/metal/kernels/defines.h +24 -0
- data/mlx/mlx/backend/metal/kernels/erf.h +69 -0
- data/mlx/mlx/backend/metal/kernels/expm1f.h +90 -0
- data/mlx/mlx/backend/metal/kernels/fence.metal +52 -0
- data/mlx/mlx/backend/metal/kernels/fft/radix.h +328 -0
- data/mlx/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
- data/mlx/mlx/backend/metal/kernels/fft.h +486 -0
- data/mlx/mlx/backend/metal/kernels/fft.metal +67 -0
- data/mlx/mlx/backend/metal/kernels/fp4.h +48 -0
- data/mlx/mlx/backend/metal/kernels/fp8.h +80 -0
- data/mlx/mlx/backend/metal/kernels/fp_quantized.h +1850 -0
- data/mlx/mlx/backend/metal/kernels/fp_quantized.metal +153 -0
- data/mlx/mlx/backend/metal/kernels/fp_quantized_nax.h +1044 -0
- data/mlx/mlx/backend/metal/kernels/fp_quantized_nax.metal +79 -0
- data/mlx/mlx/backend/metal/kernels/gemv.metal +868 -0
- data/mlx/mlx/backend/metal/kernels/gemv_masked.h +827 -0
- data/mlx/mlx/backend/metal/kernels/gemv_masked.metal +76 -0
- data/mlx/mlx/backend/metal/kernels/hadamard.h +182 -0
- data/mlx/mlx/backend/metal/kernels/indexing/gather.h +51 -0
- data/mlx/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
- data/mlx/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
- data/mlx/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
- data/mlx/mlx/backend/metal/kernels/indexing/masked_scatter.h +41 -0
- data/mlx/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
- data/mlx/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
- data/mlx/mlx/backend/metal/kernels/layer_norm.metal +433 -0
- data/mlx/mlx/backend/metal/kernels/logging.h +26 -0
- data/mlx/mlx/backend/metal/kernels/logsumexp.h +140 -0
- data/mlx/mlx/backend/metal/kernels/logsumexp.metal +18 -0
- data/mlx/mlx/backend/metal/kernels/quantized.h +2508 -0
- data/mlx/mlx/backend/metal/kernels/quantized.metal +144 -0
- data/mlx/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
- data/mlx/mlx/backend/metal/kernels/quantized_nax.metal +106 -0
- data/mlx/mlx/backend/metal/kernels/quantized_utils.h +90 -0
- data/mlx/mlx/backend/metal/kernels/random.metal +103 -0
- data/mlx/mlx/backend/metal/kernels/reduce.h +5 -0
- data/mlx/mlx/backend/metal/kernels/reduce.metal +169 -0
- data/mlx/mlx/backend/metal/kernels/reduce_utils.h +6 -0
- data/mlx/mlx/backend/metal/kernels/reduction/ops.h +275 -0
- data/mlx/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
- data/mlx/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
- data/mlx/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
- data/mlx/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
- data/mlx/mlx/backend/metal/kernels/rms_norm.metal +391 -0
- data/mlx/mlx/backend/metal/kernels/rope.metal +229 -0
- data/mlx/mlx/backend/metal/kernels/scaled_dot_product_attention.metal +44 -0
- data/mlx/mlx/backend/metal/kernels/scan.h +514 -0
- data/mlx/mlx/backend/metal/kernels/scan.metal +109 -0
- data/mlx/mlx/backend/metal/kernels/sdpa_vector.h +394 -0
- data/mlx/mlx/backend/metal/kernels/softmax.h +190 -0
- data/mlx/mlx/backend/metal/kernels/softmax.metal +24 -0
- data/mlx/mlx/backend/metal/kernels/sort.h +719 -0
- data/mlx/mlx/backend/metal/kernels/sort.metal +80 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +471 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal +27 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.metal +28 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.metal +56 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.metal +47 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
- data/mlx/mlx/backend/metal/kernels/steel/defines.h +7 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +157 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.metal +34 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +219 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.metal +30 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.metal +59 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +143 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.metal +37 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.metal +76 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.metal +43 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.metal +76 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.h +152 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.metal +30 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/params.h +65 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
- data/mlx/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
- data/mlx/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
- data/mlx/mlx/backend/metal/kernels/steel/utils.h +42 -0
- data/mlx/mlx/backend/metal/kernels/ternary.h +145 -0
- data/mlx/mlx/backend/metal/kernels/ternary.metal +48 -0
- data/mlx/mlx/backend/metal/kernels/ternary_ops.h +10 -0
- data/mlx/mlx/backend/metal/kernels/unary.h +63 -0
- data/mlx/mlx/backend/metal/kernels/unary.metal +115 -0
- data/mlx/mlx/backend/metal/kernels/unary_ops.h +454 -0
- data/mlx/mlx/backend/metal/kernels/utils.h +445 -0
- data/mlx/mlx/backend/metal/kernels.h +375 -0
- data/mlx/mlx/backend/metal/logsumexp.cpp +95 -0
- data/mlx/mlx/backend/metal/make_compiled_preamble.sh +120 -0
- data/mlx/mlx/backend/metal/matmul.cpp +2572 -0
- data/mlx/mlx/backend/metal/matmul.h +144 -0
- data/mlx/mlx/backend/metal/metal.cpp +50 -0
- data/mlx/mlx/backend/metal/metal.h +25 -0
- data/mlx/mlx/backend/metal/no_metal.cpp +42 -0
- data/mlx/mlx/backend/metal/nojit_kernels.cpp +414 -0
- data/mlx/mlx/backend/metal/normalization.cpp +433 -0
- data/mlx/mlx/backend/metal/primitives.cpp +242 -0
- data/mlx/mlx/backend/metal/quantized.cpp +1651 -0
- data/mlx/mlx/backend/metal/reduce.cpp +1038 -0
- data/mlx/mlx/backend/metal/reduce.h +41 -0
- data/mlx/mlx/backend/metal/resident.cpp +100 -0
- data/mlx/mlx/backend/metal/resident.h +32 -0
- data/mlx/mlx/backend/metal/rope.cpp +165 -0
- data/mlx/mlx/backend/metal/scaled_dot_product_attention.cpp +798 -0
- data/mlx/mlx/backend/metal/scan.cpp +145 -0
- data/mlx/mlx/backend/metal/scan.h +17 -0
- data/mlx/mlx/backend/metal/slicing.cpp +99 -0
- data/mlx/mlx/backend/metal/softmax.cpp +87 -0
- data/mlx/mlx/backend/metal/sort.cpp +368 -0
- data/mlx/mlx/backend/metal/ternary.cpp +160 -0
- data/mlx/mlx/backend/metal/ternary.h +21 -0
- data/mlx/mlx/backend/metal/unary.cpp +161 -0
- data/mlx/mlx/backend/metal/unary.h +21 -0
- data/mlx/mlx/backend/metal/utils.cpp +77 -0
- data/mlx/mlx/backend/metal/utils.h +99 -0
- data/mlx/mlx/backend/no_cpu/CMakeLists.txt +7 -0
- data/mlx/mlx/backend/no_cpu/compiled.cpp +24 -0
- data/mlx/mlx/backend/no_cpu/device_info.cpp +22 -0
- data/mlx/mlx/backend/no_cpu/primitives.cpp +146 -0
- data/mlx/mlx/backend/no_gpu/CMakeLists.txt +8 -0
- data/mlx/mlx/backend/no_gpu/allocator.cpp +134 -0
- data/mlx/mlx/backend/no_gpu/apple_memory.h +16 -0
- data/mlx/mlx/backend/no_gpu/device_info.cpp +22 -0
- data/mlx/mlx/backend/no_gpu/eval.cpp +24 -0
- data/mlx/mlx/backend/no_gpu/event.cpp +53 -0
- data/mlx/mlx/backend/no_gpu/fence.cpp +54 -0
- data/mlx/mlx/backend/no_gpu/linux_memory.h +22 -0
- data/mlx/mlx/backend/no_gpu/primitives.cpp +185 -0
- data/mlx/mlx/compile.cpp +1243 -0
- data/mlx/mlx/compile.h +45 -0
- data/mlx/mlx/compile_impl.h +70 -0
- data/mlx/mlx/device.cpp +72 -0
- data/mlx/mlx/device.h +56 -0
- data/mlx/mlx/distributed/CMakeLists.txt +14 -0
- data/mlx/mlx/distributed/distributed.cpp +197 -0
- data/mlx/mlx/distributed/distributed.h +61 -0
- data/mlx/mlx/distributed/distributed_impl.h +59 -0
- data/mlx/mlx/distributed/jaccl/CMakeLists.txt +12 -0
- data/mlx/mlx/distributed/jaccl/jaccl.cpp +178 -0
- data/mlx/mlx/distributed/jaccl/jaccl.h +12 -0
- data/mlx/mlx/distributed/jaccl/mesh.cpp +451 -0
- data/mlx/mlx/distributed/jaccl/mesh.h +122 -0
- data/mlx/mlx/distributed/jaccl/no_jaccl.cpp +20 -0
- data/mlx/mlx/distributed/jaccl/ring.cpp +692 -0
- data/mlx/mlx/distributed/jaccl/ring.h +178 -0
- data/mlx/mlx/distributed/jaccl/utils.cpp +329 -0
- data/mlx/mlx/distributed/jaccl/utils.h +342 -0
- data/mlx/mlx/distributed/mpi/CMakeLists.txt +5 -0
- data/mlx/mlx/distributed/mpi/mpi.cpp +501 -0
- data/mlx/mlx/distributed/mpi/mpi.h +12 -0
- data/mlx/mlx/distributed/mpi/mpi_declarations.h +28 -0
- data/mlx/mlx/distributed/mpi/no_mpi.cpp +20 -0
- data/mlx/mlx/distributed/nccl/CMakeLists.txt +26 -0
- data/mlx/mlx/distributed/nccl/nccl.cpp +443 -0
- data/mlx/mlx/distributed/nccl/nccl.h +12 -0
- data/mlx/mlx/distributed/nccl/nccl_stub/CMakeLists.txt +1 -0
- data/mlx/mlx/distributed/nccl/nccl_stub/nccl_stubs.cpp +54 -0
- data/mlx/mlx/distributed/nccl/no_nccl.cpp +20 -0
- data/mlx/mlx/distributed/ops.cpp +186 -0
- data/mlx/mlx/distributed/ops.h +57 -0
- data/mlx/mlx/distributed/primitives.cpp +95 -0
- data/mlx/mlx/distributed/primitives.h +156 -0
- data/mlx/mlx/distributed/reduction_ops.h +38 -0
- data/mlx/mlx/distributed/ring/CMakeLists.txt +5 -0
- data/mlx/mlx/distributed/ring/no_ring.cpp +20 -0
- data/mlx/mlx/distributed/ring/ring.cpp +870 -0
- data/mlx/mlx/distributed/ring/ring.h +12 -0
- data/mlx/mlx/distributed/utils.cpp +206 -0
- data/mlx/mlx/distributed/utils.h +67 -0
- data/mlx/mlx/dtype.cpp +197 -0
- data/mlx/mlx/dtype.h +116 -0
- data/mlx/mlx/dtype_utils.cpp +42 -0
- data/mlx/mlx/dtype_utils.h +119 -0
- data/mlx/mlx/einsum.cpp +941 -0
- data/mlx/mlx/einsum.h +23 -0
- data/mlx/mlx/event.h +58 -0
- data/mlx/mlx/export.cpp +1130 -0
- data/mlx/mlx/export.h +137 -0
- data/mlx/mlx/export_impl.h +99 -0
- data/mlx/mlx/fast.cpp +941 -0
- data/mlx/mlx/fast.h +103 -0
- data/mlx/mlx/fast_primitives.h +427 -0
- data/mlx/mlx/fence.h +39 -0
- data/mlx/mlx/fft.cpp +262 -0
- data/mlx/mlx/fft.h +159 -0
- data/mlx/mlx/graph_utils.cpp +175 -0
- data/mlx/mlx/graph_utils.h +67 -0
- data/mlx/mlx/io/CMakeLists.txt +25 -0
- data/mlx/mlx/io/gguf.cpp +470 -0
- data/mlx/mlx/io/gguf.h +20 -0
- data/mlx/mlx/io/gguf_quants.cpp +164 -0
- data/mlx/mlx/io/load.cpp +397 -0
- data/mlx/mlx/io/load.h +175 -0
- data/mlx/mlx/io/no_gguf.cpp +20 -0
- data/mlx/mlx/io/no_safetensors.cpp +37 -0
- data/mlx/mlx/io/safetensors.cpp +234 -0
- data/mlx/mlx/io.h +61 -0
- data/mlx/mlx/linalg.cpp +708 -0
- data/mlx/mlx/linalg.h +115 -0
- data/mlx/mlx/memory.h +80 -0
- data/mlx/mlx/mlx.h +25 -0
- data/mlx/mlx/ops.cpp +6094 -0
- data/mlx/mlx/ops.h +1610 -0
- data/mlx/mlx/primitives.cpp +5850 -0
- data/mlx/mlx/primitives.h +2525 -0
- data/mlx/mlx/random.cpp +492 -0
- data/mlx/mlx/random.h +283 -0
- data/mlx/mlx/scheduler.cpp +73 -0
- data/mlx/mlx/scheduler.h +189 -0
- data/mlx/mlx/small_vector.h +540 -0
- data/mlx/mlx/stream.h +42 -0
- data/mlx/mlx/threadpool.h +133 -0
- data/mlx/mlx/transforms.cpp +1065 -0
- data/mlx/mlx/transforms.h +231 -0
- data/mlx/mlx/transforms_impl.h +88 -0
- data/mlx/mlx/types/bf16.h +187 -0
- data/mlx/mlx/types/complex.h +113 -0
- data/mlx/mlx/types/fp16.h +234 -0
- data/mlx/mlx/types/half_types.h +58 -0
- data/mlx/mlx/types/limits.h +70 -0
- data/mlx/mlx/utils.cpp +302 -0
- data/mlx/mlx/utils.h +174 -0
- data/mlx/mlx/version.cpp +11 -0
- data/mlx/mlx/version.h +22 -0
- data/mlx/mlx.pc.in +52 -0
- metadata +643 -0
|
@@ -0,0 +1,334 @@
|
|
|
1
|
+
#pragma once
|
|
2
|
+
|
|
3
|
+
#include <cuda.h>
|
|
4
|
+
#include <cuda_fp4.h>
|
|
5
|
+
#include <cuda_runtime.h>
|
|
6
|
+
#include "mlx/backend/cuda/vector_types.cuh"
|
|
7
|
+
|
|
8
|
+
namespace mlx::core::cu {
|
|
9
|
+
|
|
10
|
+
// Four-lane vector aliases built on Vector4_t (expected to come from the
// vector_types.cuh include). Lanes are accessed as .x/.y/.z/.w.
using f32x4 = Vector4_t<float>;          // 4 x float
using fp16x4 = Vector4_t<__half>;        // 4 x IEEE half
using bf16x4 = Vector4_t<__nv_bfloat16>; // 4 x bfloat16
|
|
13
|
+
|
|
14
|
+
template <typename T>
|
|
15
|
+
__device__ __forceinline__ uint16_t
|
|
16
|
+
scale_cvt_Tx4_to_fp4x4_fallback(const Vector4_t<T> input, const float scale) {
|
|
17
|
+
// Fallback implementation for architectures that do not support cvt
|
|
18
|
+
// instructions or for cuda versions with no fp4 support (< 12.8) -> scalar
|
|
19
|
+
uint16_t out_fp4x4 = 0;
|
|
20
|
+
fp32x4 scaled;
|
|
21
|
+
scaled.x = static_cast<float>(input.x) * scale;
|
|
22
|
+
scaled.y = static_cast<float>(input.y) * scale;
|
|
23
|
+
scaled.z = static_cast<float>(input.z) * scale;
|
|
24
|
+
scaled.w = static_cast<float>(input.w) * scale;
|
|
25
|
+
uint8_t q0 = __nv_fp4_e2m1(scaled.x).__x;
|
|
26
|
+
uint8_t q1 = __nv_fp4_e2m1(scaled.y).__x;
|
|
27
|
+
uint8_t q2 = __nv_fp4_e2m1(scaled.z).__x;
|
|
28
|
+
uint8_t q3 = __nv_fp4_e2m1(scaled.w).__x;
|
|
29
|
+
out_fp4x4 = (static_cast<uint16_t>(q3) << 12) |
|
|
30
|
+
(static_cast<uint16_t>(q2) << 8) | (static_cast<uint16_t>(q1) << 4) |
|
|
31
|
+
static_cast<uint16_t>(q0);
|
|
32
|
+
return out_fp4x4;
|
|
33
|
+
}
|
|
34
|
+
|
|
35
|
+
#if (CUDART_VERSION >= 12080) && (__CUDA_ARCH__ >= 1000) && \
|
|
36
|
+
defined(__CUDA_ARCH_SPECIFIC__)
|
|
37
|
+
|
|
38
|
+
// Scales four bf16 values by `scale` (two lanes per mul.f32x2) and converts
// them to e2m1 (fp4) with round-to-nearest, saturating to finite values.
// cvt...e2m1x2 stores its first source in the high nibble, so converting
// (x1, x0) / (x3, x2) and packing {q0, q1} leaves lane x in the low nibble and
// lane w in the high nibble of the returned word.
__device__ __forceinline__ uint16_t
scale_cvt_bf16x4_to_fp4x4_rn(const bf16x4 input_bf16x4, const float2 scale) {
  // Inline-asm operands must be scalars, so both 8-byte arguments are handed
  // over as their raw 64-bit patterns.
  const uint64_t packed_in = reinterpret_cast<const uint64_t&>(input_bf16x4);
  const uint64_t packed_scale = reinterpret_cast<const uint64_t&>(scale);
  uint16_t result = 0;
  asm volatile(
      "{\n"
      // Scratch registers: raw bf16 lanes, their fp32 widening, the
      // paired-lane 64-bit views for mul.f32x2, and two fp4x2 output bytes.
      ".reg.b16 x0_bf16; \n\t"
      ".reg.b16 x1_bf16; \n\t"
      ".reg.b16 x2_bf16; \n\t"
      ".reg.b16 x3_bf16; \n\t"
      ".reg.b32 x0; \n\t"
      ".reg.b32 x1; \n\t"
      ".reg.b32 x2; \n\t"
      ".reg.b32 x3; \n\t"
      ".reg.b64 x01; \n\t"
      ".reg.b64 x23; \n\t"
      ".reg.b8 q0; \n\t"
      ".reg.b8 q1; \n\t"
      // Unpack the four bf16 lanes and widen them to fp32.
      "mov.b64 {x0_bf16, x1_bf16, x2_bf16, x3_bf16} , %1; \n\t"
      "cvt.f32.bf16 x0, x0_bf16; \n\t"
      "cvt.f32.bf16 x1, x1_bf16; \n\t"
      "cvt.f32.bf16 x2, x2_bf16; \n\t"
      "cvt.f32.bf16 x3, x3_bf16; \n\t"
      // Apply the scale two lanes at a time.
      "mov.b64 x01, {x0, x1}; \n\t"
      "mul.f32x2 x01, x01, %2; \n\t"
      "mov.b64 x23, {x2, x3}; \n\t"
      "mul.f32x2 x23, x23, %2; \n\t"
      "mov.b64 {x0, x1}, x01; \n\t"
      "mov.b64 {x2, x3}, x23; \n\t"
      // Round to fp4 pairwise and pack both bytes into the result.
      "cvt.rn.satfinite.e2m1x2.f32 q0, x1, x0; \n\t"
      "cvt.rn.satfinite.e2m1x2.f32 q1, x3, x2; \n\t"
      "mov.b16 %0, {q0, q1}; \n\t"
      "}"
      : "=h"(result)
      : "l"(packed_in), "l"(packed_scale));
  return result;
}
|
|
79
|
+
|
|
80
|
+
// Scales four bf16 values by `scale` (two lanes per mul.f32x2) and converts
// them to four e2m1 (fp4) codes with stochastic rounding (cvt.rs), using
// `rbits` as the random bits, saturating to finite values. Sources are passed
// as {x3, x2, x1, x0}; assuming cvt.rs...e2m1x4 places its first source in
// the most significant nibble, lane x lands in the low nibble as in the _rn
// variant.
__device__ __forceinline__ uint16_t scale_cvt_bf16x4_to_fp4x4_rs(
    const bf16x4 input_bf16x4,
    const float2 scale,
    uint32_t rbits) {
  uint16_t out_fp4x4 = 0;
  asm volatile(
      "{\n"
      ".reg.b16 x0_bf16; \n\t"
      ".reg.b16 x1_bf16; \n\t"
      ".reg.b16 x2_bf16; \n\t"
      ".reg.b16 x3_bf16; \n\t"
      ".reg.b32 x0; \n\t"
      ".reg.b32 x1; \n\t"
      ".reg.b32 x2; \n\t"
      ".reg.b32 x3; \n\t"
      ".reg.b64 x01; \n\t"
      ".reg.b64 x23; \n\t"
      ".reg.b16 q0; \n\t"
      // Unpack and widen the bf16 lanes to fp32.
      "mov.b64 {x0_bf16, x1_bf16, x2_bf16, x3_bf16} , %1; \n\t"
      "cvt.f32.bf16 x0, x0_bf16; \n\t"
      "cvt.f32.bf16 x1, x1_bf16; \n\t"
      "cvt.f32.bf16 x2, x2_bf16; \n\t"
      "cvt.f32.bf16 x3, x3_bf16; \n\t"
      // Scale two lanes at a time.
      "mov.b64 x01, {x0, x1}; \n\t"
      "mul.f32x2 x01, x01, %2; \n\t"
      "mov.b64 x23, {x2, x3}; \n\t"
      "mul.f32x2 x23, x23, %2; \n\t"
      "mov.b64 {x0, x1}, x01; \n\t"
      "mov.b64 {x2, x3}, x23; \n\t"
      "cvt.rs.satfinite.e2m1x4.f32 q0, {x3, x2, x1, x0}, %3; \n\t"
      // Copy the packed result into the output operand. Without this move
      // %0 is never written and the function returns an indeterminate value.
      "mov.b16 %0, q0; \n\t"
      "}"
      : "=h"(out_fp4x4)
      // Inline-asm operands must be scalars, hence the 64-bit reinterpretation.
      : "l"(reinterpret_cast<const uint64_t&>(input_bf16x4)),
        "l"(reinterpret_cast<const uint64_t&>(scale)),
        "r"(rbits));
  return out_fp4x4;
}
|
|
117
|
+
|
|
118
|
+
// Scales four fp32 values (given as two float2 halves) by `scale` and
// converts them to e2m1 (fp4) with round-to-nearest, saturating to finite
// values. Lane input_fp32x2_0.x ends up in the low nibble of the result and
// input_fp32x2_1.y in the high nibble.
__device__ __forceinline__ uint16_t scale_cvt_fp32x4_to_fp4x4_rn(
    const float2 input_fp32x2_0,
    const float2 input_fp32x2_1,
    const float2 scale) {
  // The float2 scale is passed as its raw 64-bit pattern for mul.f32x2.
  const uint64_t scale_bits = reinterpret_cast<const uint64_t&>(scale);
  uint16_t result = 0;
  asm volatile(
      "{\n"
      ".reg.b32 x0; \n\t"
      ".reg.b32 x1; \n\t"
      ".reg.b32 x2; \n\t"
      ".reg.b32 x3; \n\t"
      ".reg.b64 x01; \n\t"
      ".reg.b64 x23; \n\t"
      ".reg.b8 q0; \n\t"
      ".reg.b8 q1; \n\t"
      // Build the lane pairs and scale each pair with one instruction.
      "mov.b64 x01, {%1, %2}; \n\t"
      "mul.f32x2 x01, x01, %5; \n\t"
      "mov.b64 x23, {%3, %4}; \n\t"
      "mul.f32x2 x23, x23, %5; \n\t"
      "mov.b64 {x0, x1}, x01; \n\t"
      "mov.b64 {x2, x3}, x23; \n\t"
      // Round pairwise to fp4 and pack both bytes.
      "cvt.rn.satfinite.e2m1x2.f32 q0, x1, x0; \n\t"
      "cvt.rn.satfinite.e2m1x2.f32 q1, x3, x2; \n\t"
      "mov.b16 %0, {q0, q1}; \n\t"
      "}"
      : "=h"(result)
      : "f"(input_fp32x2_0.x),
        "f"(input_fp32x2_0.y),
        "f"(input_fp32x2_1.x),
        "f"(input_fp32x2_1.y),
        "l"(scale_bits));
  return result;
}
|
|
151
|
+
|
|
152
|
+
// Scales four fp32 values (given as two float2 halves) by `scale` and
// converts them to e2m1 (fp4) with stochastic rounding (cvt.rs), using
// `rbits` as the random bits, saturating to finite values.
__device__ __forceinline__ uint16_t scale_cvt_fp32x4_to_fp4x4_rs(
    const float2 input_fp32x2_0,
    const float2 input_fp32x2_1,
    const float2 scale,
    uint32_t rbits) {
  uint16_t out_fp4x4 = 0;
  asm volatile(
      "{\n"
      ".reg.b32 x0; \n\t"
      ".reg.b32 x1; \n\t"
      ".reg.b32 x2; \n\t"
      ".reg.b32 x3; \n\t"
      ".reg.b64 x01; \n\t"
      ".reg.b64 x23; \n\t"
      ".reg.b16 q0; \n\t"
      // Scale the lanes two at a time.
      "mov.b64 x01, {%1, %2}; \n\t"
      "mul.f32x2 x01, x01, %5; \n\t"
      "mov.b64 x23, {%3, %4}; \n\t"
      "mul.f32x2 x23, x23, %5; \n\t"
      "mov.b64 {x0, x1}, x01; \n\t"
      "mov.b64 {x2, x3}, x23; \n\t"
      "cvt.rs.satfinite.e2m1x4.f32 q0, {x3, x2, x1, x0}, %6; \n\t"
      // Copy the packed result into the output operand. Without this move
      // %0 is never written and the function returns an indeterminate value.
      "mov.b16 %0, q0; \n\t"
      "}"
      : "=h"(out_fp4x4)
      : "f"(input_fp32x2_0.x),
        "f"(input_fp32x2_0.y),
        "f"(input_fp32x2_1.x),
        "f"(input_fp32x2_1.y),
        "l"(reinterpret_cast<const uint64_t&>(scale)),
        "r"(rbits));
  return out_fp4x4;
}
|
|
184
|
+
|
|
185
|
+
// Scales four fp16 values by `scale` (two lanes per mul.f32x2) and converts
// them to e2m1 (fp4) with round-to-nearest, saturating to finite values.
// cvt...e2m1x2 stores its first source in the high nibble, so converting
// (x1, x0) / (x3, x2) and packing {q0, q1} leaves lane x in the low nibble and
// lane w in the high nibble of the returned word.
__device__ __forceinline__ uint16_t
scale_cvt_fp16x4_to_fp4x4_rn(const fp16x4 input_fp16x4, const float2 scale) {
  // Inline-asm operands must be scalars, so both 8-byte arguments are handed
  // over as their raw 64-bit patterns.
  const uint64_t packed_in = reinterpret_cast<const uint64_t&>(input_fp16x4);
  const uint64_t packed_scale = reinterpret_cast<const uint64_t&>(scale);
  uint16_t result = 0;
  asm volatile(
      "{\n"
      ".reg.b16 x0_fp16; \n\t"
      ".reg.b16 x1_fp16; \n\t"
      ".reg.b16 x2_fp16; \n\t"
      ".reg.b16 x3_fp16; \n\t"
      ".reg.b32 x0; \n\t"
      ".reg.b32 x1; \n\t"
      ".reg.b32 x2; \n\t"
      ".reg.b32 x3; \n\t"
      ".reg.b64 x01; \n\t"
      ".reg.b64 x23; \n\t"
      ".reg.b8 q0; \n\t"
      ".reg.b8 q1; \n\t"
      // Unpack the fp16 lanes and widen them to fp32.
      "mov.b64 {x0_fp16, x1_fp16, x2_fp16, x3_fp16} , %1; \n\t"
      "cvt.f32.f16 x0, x0_fp16; \n\t"
      "cvt.f32.f16 x1, x1_fp16; \n\t"
      "cvt.f32.f16 x2, x2_fp16; \n\t"
      "cvt.f32.f16 x3, x3_fp16; \n\t"
      // Apply the scale two lanes at a time.
      "mov.b64 x01, {x0, x1}; \n\t"
      "mul.f32x2 x01, x01, %2; \n\t"
      "mov.b64 x23, {x2, x3}; \n\t"
      "mul.f32x2 x23, x23, %2; \n\t"
      "mov.b64 {x0, x1}, x01; \n\t"
      "mov.b64 {x2, x3}, x23; \n\t"
      // Round pairwise to fp4 and pack both bytes.
      "cvt.rn.satfinite.e2m1x2.f32 q0, x1, x0; \n\t"
      "cvt.rn.satfinite.e2m1x2.f32 q1, x3, x2; \n\t"
      "mov.b16 %0, {q0, q1}; \n\t"
      "}"
      : "=h"(result)
      : "l"(packed_in), "l"(packed_scale));
  return result;
}
|
|
222
|
+
|
|
223
|
+
// Scales four fp16 values by `scale` (two lanes per mul.f32x2) and converts
// them to four e2m1 (fp4) codes with stochastic rounding (cvt.rs), using
// `rbits` as the random bits, saturating to finite values.
__device__ __forceinline__ uint16_t scale_cvt_fp16x4_to_fp4x4_rs(
    const fp16x4 input_fp16x4,
    const float2 scale,
    uint32_t rbits) {
  uint16_t out_fp4x4 = 0;
  asm volatile(
      "{\n"
      ".reg.b16 x0_fp16; \n\t"
      ".reg.b16 x1_fp16; \n\t"
      ".reg.b16 x2_fp16; \n\t"
      ".reg.b16 x3_fp16; \n\t"
      ".reg.b32 x0; \n\t"
      ".reg.b32 x1; \n\t"
      ".reg.b32 x2; \n\t"
      ".reg.b32 x3; \n\t"
      ".reg.b64 x01; \n\t"
      ".reg.b64 x23; \n\t"
      ".reg.b16 q0; \n\t"
      // Unpack and widen the fp16 lanes to fp32.
      "mov.b64 {x0_fp16, x1_fp16, x2_fp16, x3_fp16} , %1; \n\t"
      "cvt.f32.f16 x0, x0_fp16; \n\t"
      "cvt.f32.f16 x1, x1_fp16; \n\t"
      "cvt.f32.f16 x2, x2_fp16; \n\t"
      "cvt.f32.f16 x3, x3_fp16; \n\t"
      // Scale two lanes at a time.
      "mov.b64 x01, {x0, x1}; \n\t"
      "mul.f32x2 x01, x01, %2; \n\t"
      "mov.b64 x23, {x2, x3}; \n\t"
      "mul.f32x2 x23, x23, %2; \n\t"
      "mov.b64 {x0, x1}, x01; \n\t"
      "mov.b64 {x2, x3}, x23; \n\t"
      "cvt.rs.satfinite.e2m1x4.f32 q0, {x3, x2, x1, x0}, %3; \n\t"
      // Copy the packed result into the output operand. Without this move
      // %0 is never written and the function returns an indeterminate value.
      "mov.b16 %0, q0; \n\t"
      "}"
      : "=h"(out_fp4x4)
      // Inline-asm operands must be scalars, hence the 64-bit reinterpretation.
      : "l"(reinterpret_cast<const uint64_t&>(input_fp16x4)),
        "l"(reinterpret_cast<const uint64_t&>(scale)),
        "r"(rbits));
  return out_fp4x4;
}
|
|
260
|
+
|
|
261
|
+
template <bool USE_SR>
|
|
262
|
+
__device__ __forceinline__ uint16_t scale_cvt_bf16x4_to_fp4x4(
|
|
263
|
+
const bf16x4 input,
|
|
264
|
+
const float scale,
|
|
265
|
+
uint32_t rbits) {
|
|
266
|
+
float2 scale_fp32x2 = make_float2(scale, scale);
|
|
267
|
+
if constexpr (USE_SR) {
|
|
268
|
+
return scale_cvt_bf16x4_to_fp4x4_rs(input, scale_fp32x2, rbits);
|
|
269
|
+
} else {
|
|
270
|
+
return scale_cvt_bf16x4_to_fp4x4_rn(input, scale_fp32x2);
|
|
271
|
+
}
|
|
272
|
+
}
|
|
273
|
+
|
|
274
|
+
template <bool USE_SR>
|
|
275
|
+
__device__ __forceinline__ uint16_t scale_cvt_fp16x4_to_fp4x4(
|
|
276
|
+
const fp16x4 input,
|
|
277
|
+
const float scale,
|
|
278
|
+
uint32_t rbits) {
|
|
279
|
+
float2 scale_fp32x2 = make_float2(scale, scale);
|
|
280
|
+
if constexpr (USE_SR) {
|
|
281
|
+
return scale_cvt_fp16x4_to_fp4x4_rs(input, scale_fp32x2, rbits);
|
|
282
|
+
} else {
|
|
283
|
+
return scale_cvt_fp16x4_to_fp4x4_rn(input, scale_fp32x2);
|
|
284
|
+
}
|
|
285
|
+
}
|
|
286
|
+
|
|
287
|
+
template <bool USE_SR>
|
|
288
|
+
__device__ __forceinline__ uint16_t
|
|
289
|
+
scale_cvt_f32x4_to_fp4x4(const f32x4 input, const float scale, uint32_t rbits) {
|
|
290
|
+
float2 scale_fp32x2 = make_float2(scale, scale);
|
|
291
|
+
float2 input_fp32x2_0 = make_float2(input.x, input.y);
|
|
292
|
+
float2 input_fp32x2_1 = make_float2(input.z, input.w);
|
|
293
|
+
|
|
294
|
+
if constexpr (USE_SR) {
|
|
295
|
+
return scale_cvt_fp32x4_to_fp4x4_rs(
|
|
296
|
+
input_fp32x2_0, input_fp32x2_1, scale_fp32x2, rbits);
|
|
297
|
+
} else {
|
|
298
|
+
return scale_cvt_fp32x4_to_fp4x4_rn(
|
|
299
|
+
input_fp32x2_0, input_fp32x2_1, scale_fp32x2);
|
|
300
|
+
}
|
|
301
|
+
}
|
|
302
|
+
|
|
303
|
+
template <typename T, bool USE_SR>
|
|
304
|
+
__device__ __forceinline__ uint16_t scale_cvt_Tx4_to_fp4x4_fast(
|
|
305
|
+
const Vector4_t<T> input,
|
|
306
|
+
const float scale,
|
|
307
|
+
uint32_t rbits) {
|
|
308
|
+
if constexpr (std::is_same<T, __nv_bfloat16>::value) {
|
|
309
|
+
return scale_cvt_bf16x4_to_fp4x4<USE_SR>(input, scale, rbits);
|
|
310
|
+
} else if constexpr (std::is_same<T, __half>::value) {
|
|
311
|
+
return scale_cvt_fp16x4_to_fp4x4<USE_SR>(input, scale, rbits);
|
|
312
|
+
} else {
|
|
313
|
+
return scale_cvt_f32x4_to_fp4x4<USE_SR>(input, scale, rbits);
|
|
314
|
+
}
|
|
315
|
+
}
|
|
316
|
+
#endif // (CUDART_VERSION >= 12080) && (__CUDA_ARCH__ >= 1000) &&
|
|
317
|
+
// (__CUDA_ARCH_FAMILY_SPECIFIC__ >= 1000)
|
|
318
|
+
|
|
319
|
+
template <typename T, bool USE_SR>
|
|
320
|
+
__device__ __forceinline__ uint16_t scale_cvt_Tx4_to_fp4x4(
|
|
321
|
+
const Vector4_t<T> input,
|
|
322
|
+
const float scale,
|
|
323
|
+
uint32_t rbits) {
|
|
324
|
+
#if (CUDART_VERSION >= 12080) && (__CUDA_ARCH__ >= 1000) && \
|
|
325
|
+
(__CUDA_ARCH_FAMILY_SPECIFIC__ >= 1000)
|
|
326
|
+
return scale_cvt_Tx4_to_fp4x4_fast<T, USE_SR>(input, scale, rbits);
|
|
327
|
+
#else
|
|
328
|
+
static_assert(
|
|
329
|
+
!USE_SR,
|
|
330
|
+
"Stochastic rounding (USE_SR=true) requires CUDA >= 12.8 and compute capability >= 1000.");
|
|
331
|
+
return scale_cvt_Tx4_to_fp4x4_fallback(input, scale);
|
|
332
|
+
#endif
|
|
333
|
+
}
|
|
334
|
+
} // namespace mlx::core::cu
|
|
@@ -0,0 +1,304 @@
|
|
|
1
|
+
// Copyright © 2025 Apple Inc.
|
|
2
|
+
|
|
3
|
+
#include "mlx/backend/cuda/device/utils.cuh"
|
|
4
|
+
#include "mlx/backend/cuda/kernel_utils.cuh"
|
|
5
|
+
#include "mlx/backend/cuda/quantized/qmv.h"
|
|
6
|
+
#include "mlx/backend/cuda/quantized/quantized_utils.cuh"
|
|
7
|
+
#include "mlx/dtype_utils.h"
|
|
8
|
+
|
|
9
|
+
#include <cooperative_groups.h>
|
|
10
|
+
#include <cooperative_groups/reduce.h>
|
|
11
|
+
|
|
12
|
+
namespace mlx::core::cu {
|
|
13
|
+
|
|
14
|
+
namespace cg = cooperative_groups;
|
|
15
|
+
|
|
16
|
+
// Number of output rows (one warp each) handled by a thread block.
constexpr int rows_per_block = 8;
|
|
17
|
+
|
|
18
|
+
template <typename T>
|
|
19
|
+
__device__ void adjust_matrix_offsets(
|
|
20
|
+
const T*& x,
|
|
21
|
+
const uint32_t*& w,
|
|
22
|
+
const uint8_t*& scales,
|
|
23
|
+
T*& y,
|
|
24
|
+
int output_stride,
|
|
25
|
+
const int& x_batch_ndims,
|
|
26
|
+
const Shape x_shape,
|
|
27
|
+
const Strides x_strides,
|
|
28
|
+
const int& w_batch_ndims,
|
|
29
|
+
const Shape w_shape,
|
|
30
|
+
const Strides w_strides,
|
|
31
|
+
const Strides s_strides) {
|
|
32
|
+
uint32_t idx = cg::this_grid().block_index().z;
|
|
33
|
+
if (x_batch_ndims == 1) {
|
|
34
|
+
x += idx * x_strides[0];
|
|
35
|
+
} else {
|
|
36
|
+
x += elem_to_loc(idx, x_shape.data(), x_strides.data(), x_batch_ndims);
|
|
37
|
+
}
|
|
38
|
+
if (w_batch_ndims == 1) {
|
|
39
|
+
w += idx * w_strides[0];
|
|
40
|
+
scales += idx * s_strides[0];
|
|
41
|
+
} else {
|
|
42
|
+
auto [w_idx, s_idx] = elem_to_loc(
|
|
43
|
+
idx, w_shape.data(), w_strides.data(), s_strides.data(), w_batch_ndims);
|
|
44
|
+
w += w_idx;
|
|
45
|
+
scales += s_idx;
|
|
46
|
+
}
|
|
47
|
+
y += idx * output_stride;
|
|
48
|
+
}
|
|
49
|
+
|
|
50
|
+
template <
|
|
51
|
+
typename T,
|
|
52
|
+
int rows_per_block,
|
|
53
|
+
int n_per_thread,
|
|
54
|
+
int bits,
|
|
55
|
+
int group_size,
|
|
56
|
+
bool use_mx_scale>
|
|
57
|
+
__device__ void fp_qmv_impl(
|
|
58
|
+
const uint32_t* mat,
|
|
59
|
+
const uint8_t* scales_,
|
|
60
|
+
const T* vec,
|
|
61
|
+
T* out,
|
|
62
|
+
int rows,
|
|
63
|
+
int cols) {
|
|
64
|
+
auto block = cg::this_thread_block();
|
|
65
|
+
auto warp = cg::tiled_partition<WARP_SIZE>(block);
|
|
66
|
+
|
|
67
|
+
constexpr int vals_per_item = bits == 8 ? 4 : 8;
|
|
68
|
+
constexpr int nv_per_thread = vals_per_item * n_per_thread;
|
|
69
|
+
auto g_idx = block.group_index();
|
|
70
|
+
auto t_idx = block.thread_index();
|
|
71
|
+
int row = g_idx.y * rows_per_block + t_idx.y;
|
|
72
|
+
|
|
73
|
+
vec += g_idx.x * cols;
|
|
74
|
+
out += g_idx.x * rows;
|
|
75
|
+
|
|
76
|
+
using ScaleType =
|
|
77
|
+
std::conditional_t<use_mx_scale, __nv_fp8_e8m0, __nv_fp8_e4m3>;
|
|
78
|
+
auto scales = (ScaleType*)(scales_);
|
|
79
|
+
auto packed_cols = cols / vals_per_item;
|
|
80
|
+
|
|
81
|
+
if (row < rows) {
|
|
82
|
+
constexpr int scales_per_step = std::max(nv_per_thread / group_size, 1);
|
|
83
|
+
constexpr int scale_step = (WARP_SIZE * nv_per_thread) / group_size;
|
|
84
|
+
constexpr int n_per_step = n_per_thread / scales_per_step;
|
|
85
|
+
// Offset scales to correct row
|
|
86
|
+
scales += row * (cols / group_size) +
|
|
87
|
+
(warp.thread_rank() * nv_per_thread) / group_size;
|
|
88
|
+
float sum = 0.0f;
|
|
89
|
+
for (int col = n_per_thread * warp.thread_rank(); col < packed_cols;
|
|
90
|
+
col += (WARP_SIZE * n_per_thread)) {
|
|
91
|
+
auto local_vec =
|
|
92
|
+
unsafe_load_vector<nv_per_thread>(vec + vals_per_item * col, 0);
|
|
93
|
+
auto local_mat =
|
|
94
|
+
unsafe_load_vector<n_per_thread>(mat + row * packed_cols + col, 0);
|
|
95
|
+
#pragma unroll
|
|
96
|
+
for (int i = 0; i < scales_per_step; ++i) {
|
|
97
|
+
float2 local_sum = {0.0f, 0.0f};
|
|
98
|
+
#pragma unroll
|
|
99
|
+
for (int j = 0; j < n_per_step; ++j) {
|
|
100
|
+
int k = n_per_step * i + j;
|
|
101
|
+
if constexpr (bits == 8) {
|
|
102
|
+
auto v = dequant_fp8(local_mat[k]);
|
|
103
|
+
local_sum.x +=
|
|
104
|
+
v.x * static_cast<float>(local_vec[vals_per_item * k]);
|
|
105
|
+
local_sum.x +=
|
|
106
|
+
v.y * static_cast<float>(local_vec[vals_per_item * k + 1]);
|
|
107
|
+
local_sum.y +=
|
|
108
|
+
v.z * static_cast<float>(local_vec[vals_per_item * k + 2]);
|
|
109
|
+
local_sum.y +=
|
|
110
|
+
v.w * static_cast<float>(local_vec[vals_per_item * k + 3]);
|
|
111
|
+
} else {
|
|
112
|
+
auto v = dequant_fp4(local_mat[k]);
|
|
113
|
+
local_sum.x +=
|
|
114
|
+
v.x * static_cast<float>(local_vec[vals_per_item * k]);
|
|
115
|
+
local_sum.y +=
|
|
116
|
+
v.y * static_cast<float>(local_vec[vals_per_item * k + 1]);
|
|
117
|
+
local_sum.x +=
|
|
118
|
+
v.z * static_cast<float>(local_vec[vals_per_item * k + 2]);
|
|
119
|
+
local_sum.y +=
|
|
120
|
+
v.w * static_cast<float>(local_vec[vals_per_item * k + 3]);
|
|
121
|
+
|
|
122
|
+
v = dequant_fp4(local_mat[k] >> 16);
|
|
123
|
+
local_sum.x +=
|
|
124
|
+
v.x * static_cast<float>(local_vec[vals_per_item * k + 4]);
|
|
125
|
+
local_sum.y +=
|
|
126
|
+
v.y * static_cast<float>(local_vec[vals_per_item * k + 5]);
|
|
127
|
+
local_sum.x +=
|
|
128
|
+
v.z * static_cast<float>(local_vec[vals_per_item * k + 6]);
|
|
129
|
+
local_sum.y +=
|
|
130
|
+
v.w * static_cast<float>(local_vec[vals_per_item * k + 7]);
|
|
131
|
+
}
|
|
132
|
+
}
|
|
133
|
+
sum += (local_sum.x + local_sum.y) * float(scales[i]);
|
|
134
|
+
}
|
|
135
|
+
scales += scale_step;
|
|
136
|
+
}
|
|
137
|
+
|
|
138
|
+
sum = cg::reduce(warp, sum, cg::plus<float>{});
|
|
139
|
+
if (warp.thread_rank() == 0) {
|
|
140
|
+
out[row] = static_cast<T>(sum);
|
|
141
|
+
}
|
|
142
|
+
}
|
|
143
|
+
}
|
|
144
|
+
|
|
145
|
+
template <
|
|
146
|
+
typename T,
|
|
147
|
+
int rows_per_block,
|
|
148
|
+
int n_per_thread,
|
|
149
|
+
int bits,
|
|
150
|
+
int group_size,
|
|
151
|
+
bool use_mx_scale>
|
|
152
|
+
__global__ void fp_qmv_single(
|
|
153
|
+
const uint32_t* mat,
|
|
154
|
+
const uint8_t* scales,
|
|
155
|
+
const T* vec,
|
|
156
|
+
T* out,
|
|
157
|
+
int rows,
|
|
158
|
+
int cols) {
|
|
159
|
+
fp_qmv_impl<T, rows_per_block, n_per_thread, bits, group_size, use_mx_scale>(
|
|
160
|
+
mat, scales, vec, out, rows, cols);
|
|
161
|
+
}
|
|
162
|
+
|
|
163
|
+
template <
|
|
164
|
+
typename T,
|
|
165
|
+
int rows_per_block,
|
|
166
|
+
int n_per_thread,
|
|
167
|
+
int bits,
|
|
168
|
+
int group_size,
|
|
169
|
+
bool use_mx_scale>
|
|
170
|
+
__global__ void fp_qmv_batched(
|
|
171
|
+
const uint32_t* mat,
|
|
172
|
+
const uint8_t* scales,
|
|
173
|
+
const T* vec,
|
|
174
|
+
T* out,
|
|
175
|
+
int rows,
|
|
176
|
+
int cols,
|
|
177
|
+
int vec_batch_ndims,
|
|
178
|
+
const __grid_constant__ Shape vec_shape,
|
|
179
|
+
const __grid_constant__ Strides vec_strides,
|
|
180
|
+
int mat_batch_ndims,
|
|
181
|
+
const __grid_constant__ Shape mat_shape,
|
|
182
|
+
const __grid_constant__ Strides mat_strides,
|
|
183
|
+
const __grid_constant__ Strides scales_strides) {
|
|
184
|
+
adjust_matrix_offsets<T>(
|
|
185
|
+
vec,
|
|
186
|
+
mat,
|
|
187
|
+
scales,
|
|
188
|
+
out,
|
|
189
|
+
rows * vec_shape[vec_batch_ndims],
|
|
190
|
+
vec_batch_ndims,
|
|
191
|
+
vec_shape,
|
|
192
|
+
vec_strides,
|
|
193
|
+
mat_batch_ndims,
|
|
194
|
+
mat_shape,
|
|
195
|
+
mat_strides,
|
|
196
|
+
scales_strides);
|
|
197
|
+
fp_qmv_impl<T, rows_per_block, n_per_thread, bits, group_size, use_mx_scale>(
|
|
198
|
+
mat, scales, vec, out, rows, cols);
|
|
199
|
+
}
|
|
200
|
+
|
|
201
|
+
// Lifts a runtime n in {1, 2, 4} to std::integral_constant<int, n> so the
// callee can use it as a template argument. Any other n is silently ignored.
template <typename F>
void dispatch_1_2_4(int n, F&& f) {
  if (n == 1) {
    f(std::integral_constant<int, 1>{});
  } else if (n == 2) {
    f(std::integral_constant<int, 2>{});
  } else if (n == 4) {
    f(std::integral_constant<int, 4>{});
  }
}
|
|
215
|
+
|
|
216
|
+
void fp_qmv(
|
|
217
|
+
const array& mat,
|
|
218
|
+
const array& scales,
|
|
219
|
+
const array& vec,
|
|
220
|
+
array& out,
|
|
221
|
+
int bits,
|
|
222
|
+
int group_size,
|
|
223
|
+
int M,
|
|
224
|
+
int N,
|
|
225
|
+
int K,
|
|
226
|
+
CommandEncoder& encoder) {
|
|
227
|
+
encoder.set_input_array(mat);
|
|
228
|
+
encoder.set_input_array(scales);
|
|
229
|
+
encoder.set_input_array(vec);
|
|
230
|
+
encoder.set_output_array(out);
|
|
231
|
+
dispatch_float_types(out.dtype(), "qmv", [&](auto type_tag) {
|
|
232
|
+
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
|
|
233
|
+
if constexpr (!std::is_same_v<T, double>) {
|
|
234
|
+
dim3 block_dims{WARP_SIZE, rows_per_block};
|
|
235
|
+
uint32_t B = out.size() / (M * N);
|
|
236
|
+
uint32_t blocks_y = (N + rows_per_block - 1) / rows_per_block;
|
|
237
|
+
const uint32_t* mat_ptr = gpu_ptr<uint32_t>(mat);
|
|
238
|
+
const T* vec_ptr = gpu_ptr<T>(vec);
|
|
239
|
+
int n = 1;
|
|
240
|
+
if (K % 32 == 0 && cu::is_aligned<4>(mat_ptr) &&
|
|
241
|
+
((bits == 4 && cu::is_aligned<8>(vec_ptr)) ||
|
|
242
|
+
cu::is_aligned<4>(vec_ptr))) {
|
|
243
|
+
n = 4;
|
|
244
|
+
} else if (
|
|
245
|
+
cu::is_aligned<2>(mat_ptr) &&
|
|
246
|
+
((bits == 4 && cu::is_aligned<4>(vec_ptr)) ||
|
|
247
|
+
cu::is_aligned<2>(vec_ptr))) {
|
|
248
|
+
n = 2;
|
|
249
|
+
}
|
|
250
|
+
dispatch_1_2_4(n, [&](auto n) {
|
|
251
|
+
dispatch_bool(B > 1, [&](auto batched) {
|
|
252
|
+
if (!batched.value) {
|
|
253
|
+
auto kernel =
|
|
254
|
+
fp_qmv_single<T, rows_per_block, n.value, 4, 32, true>;
|
|
255
|
+
if (bits == 8) {
|
|
256
|
+
kernel = fp_qmv_single<T, rows_per_block, n.value, 8, 32, true>;
|
|
257
|
+
} else if (group_size == 16) {
|
|
258
|
+
kernel = fp_qmv_single<T, rows_per_block, n.value, 4, 16, false>;
|
|
259
|
+
}
|
|
260
|
+
encoder.add_kernel_node(
|
|
261
|
+
kernel,
|
|
262
|
+
{static_cast<uint32_t>(M), blocks_y},
|
|
263
|
+
block_dims,
|
|
264
|
+
0,
|
|
265
|
+
mat_ptr,
|
|
266
|
+
gpu_ptr<uint8_t>(scales),
|
|
267
|
+
vec_ptr,
|
|
268
|
+
gpu_ptr<T>(out),
|
|
269
|
+
N,
|
|
270
|
+
K);
|
|
271
|
+
} else {
|
|
272
|
+
auto kernel =
|
|
273
|
+
fp_qmv_batched<T, rows_per_block, n.value, 4, 32, true>;
|
|
274
|
+
if (bits == 8) {
|
|
275
|
+
kernel = fp_qmv_batched<T, rows_per_block, n.value, 8, 32, true>;
|
|
276
|
+
} else if (group_size == 16) {
|
|
277
|
+
kernel = fp_qmv_batched<T, rows_per_block, n.value, 4, 16, false>;
|
|
278
|
+
}
|
|
279
|
+
encoder.add_kernel_node(
|
|
280
|
+
kernel,
|
|
281
|
+
{static_cast<uint32_t>(M), blocks_y, B},
|
|
282
|
+
block_dims,
|
|
283
|
+
0,
|
|
284
|
+
mat_ptr,
|
|
285
|
+
gpu_ptr<uint8_t>(scales),
|
|
286
|
+
vec_ptr,
|
|
287
|
+
gpu_ptr<T>(out),
|
|
288
|
+
N,
|
|
289
|
+
K,
|
|
290
|
+
vec.ndim() - 2,
|
|
291
|
+
const_param(vec.shape()),
|
|
292
|
+
const_param(vec.strides()),
|
|
293
|
+
mat.ndim() - 2,
|
|
294
|
+
const_param(mat.shape()),
|
|
295
|
+
const_param(mat.strides()),
|
|
296
|
+
const_param(scales.strides()));
|
|
297
|
+
}
|
|
298
|
+
});
|
|
299
|
+
});
|
|
300
|
+
}
|
|
301
|
+
});
|
|
302
|
+
}
|
|
303
|
+
|
|
304
|
+
} // namespace mlx::core::cu
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
// Copyright © 2025 Apple Inc.
|
|
2
|
+
|
|
3
|
+
#pragma once
|
|
4
|
+
|
|
5
|
+
#include "mlx/backend/cuda/device.h"
|
|
6
|
+
|
|
7
|
+
namespace mlx::core::cu {
|
|
8
|
+
|
|
9
|
+
void fp_qmv(
|
|
10
|
+
const array& w,
|
|
11
|
+
const array& scales,
|
|
12
|
+
const array& vec,
|
|
13
|
+
array& out,
|
|
14
|
+
int bits,
|
|
15
|
+
int group_size,
|
|
16
|
+
int M,
|
|
17
|
+
int N,
|
|
18
|
+
int K,
|
|
19
|
+
CommandEncoder& encoder);
|
|
20
|
+
|
|
21
|
+
} // namespace mlx::core::cu
|