@novastera-oss/llamarn 0.2.9 → 0.3.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/android/build.gradle +2 -1
- package/android/proguard-rules.pro +12 -0
- package/android/src/main/cpp/include/llama.h +15 -47
- package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml-base.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libllama.so +0 -0
- package/android/src/main/jniLibs/x86/libggml-base.so +0 -0
- package/android/src/main/jniLibs/x86/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/x86/libggml.so +0 -0
- package/android/src/main/jniLibs/x86/libllama.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
- package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
- package/cpp/build-info.cpp +2 -2
- package/cpp/llama.cpp/CMakeLists.txt +0 -1
- package/cpp/llama.cpp/CMakePresets.json +11 -0
- package/cpp/llama.cpp/CODEOWNERS +1 -0
- package/cpp/llama.cpp/README.md +8 -8
- package/cpp/llama.cpp/build-xcframework.sh +1 -1
- package/cpp/llama.cpp/common/CMakeLists.txt +4 -5
- package/cpp/llama.cpp/common/arg.cpp +62 -1
- package/cpp/llama.cpp/common/chat.cpp +37 -20
- package/cpp/llama.cpp/common/chat.h +2 -0
- package/cpp/llama.cpp/common/common.cpp +22 -6
- package/cpp/llama.cpp/common/common.h +22 -4
- package/cpp/llama.cpp/convert_hf_to_gguf.py +1250 -43
- package/cpp/llama.cpp/convert_hf_to_gguf_update.py +21 -13
- package/cpp/llama.cpp/ggml/CMakeLists.txt +13 -3
- package/cpp/llama.cpp/ggml/cmake/ggml-config.cmake.in +85 -47
- package/cpp/llama.cpp/ggml/include/ggml-backend.h +1 -1
- package/cpp/llama.cpp/ggml/include/ggml-webgpu.h +19 -0
- package/cpp/llama.cpp/ggml/include/ggml.h +173 -10
- package/cpp/llama.cpp/ggml/src/CMakeLists.txt +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-alloc.c +0 -15
- package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +7 -8
- package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +44 -38
- package/cpp/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +3 -1
- package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +126 -8
- package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +130 -22
- package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +138 -18
- package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +11 -3
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +28 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +109 -12
- package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +3 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +88 -10
- package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +343 -1094
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +1206 -163
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.h +6 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +0 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +36 -9
- package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +142 -9
- package/cpp/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +3 -3
- package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +31 -4
- package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cu +86 -17
- package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cuh +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy-utils.cuh +225 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy.cu +41 -301
- package/cpp/llama.cpp/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh +85 -64
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +47 -60
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f16.cu +29 -42
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f32.cu +46 -59
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f16.cuh +36 -45
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f32.cuh +38 -45
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +23 -36
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn.cu +3 -13
- package/cpp/llama.cpp/ggml/src/ggml-cuda/getrows.cu +8 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +255 -99
- package/cpp/llama.cpp/ggml/src/ggml-cuda/im2col.cu +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mma.cuh +111 -3
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cu +6 -4
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cuh +1152 -695
- package/cpp/llama.cpp/ggml/src/ggml-cuda/norm.cu +92 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/norm.cuh +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/rope.cu +21 -27
- package/cpp/llama.cpp/ggml/src/ggml-cuda/scale.cu +8 -6
- package/cpp/llama.cpp/ggml/src/ggml-cuda/set-rows.cu +275 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/set-rows.cuh +7 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/softmax.cu +119 -58
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-conv.cu +10 -2
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +192 -52
- package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cu +104 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cuh +13 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/upscale.cu +92 -6
- package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +27 -6
- package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-impl.h +80 -0
- package/cpp/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +4 -2
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +48 -12
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +572 -106
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +599 -105
- package/cpp/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +18 -4
- package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +800 -42
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/conv2d.cl +185 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl +176 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/glu.cl +337 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/im2col_f16.cl +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/im2col_f32.cl +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/rms_norm.cl +79 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/set_rows.cl +95 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +24 -11
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +24 -11
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f16.cl +24 -11
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f32.cl +24 -11
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/upscale.cl +2 -3
- package/cpp/llama.cpp/ggml/src/ggml-quants.c +6 -6
- package/cpp/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +4 -4
- package/cpp/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +693 -1034
- package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +18 -9
- package/cpp/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +14 -26
- package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +191 -55
- package/cpp/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-sycl/quants.hpp +8 -9
- package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +15 -18
- package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.cpp +131 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.hpp +8 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +2 -6
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +991 -307
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +265 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +59 -12
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +28 -23
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +14 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +38 -32
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +32 -27
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +44 -12
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +17 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +3 -8
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +128 -72
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +38 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +18 -3
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +1 -4
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rte.comp +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +20 -4
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +69 -5
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +84 -9
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/CMakeLists.txt +54 -0
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/ggml-webgpu.cpp +907 -0
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +60 -0
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +35 -0
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +40 -0
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl +56 -0
- package/cpp/llama.cpp/ggml/src/ggml.c +386 -67
- package/cpp/llama.cpp/ggml/src/gguf.cpp +8 -1
- package/cpp/llama.cpp/gguf-py/gguf/constants.py +307 -0
- package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +8 -2
- package/cpp/llama.cpp/gguf-py/gguf/metadata.py +4 -0
- package/cpp/llama.cpp/gguf-py/gguf/scripts/gguf_dump.py +24 -1
- package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +122 -47
- package/cpp/llama.cpp/gguf-py/gguf/vocab.py +12 -3
- package/cpp/llama.cpp/include/llama.h +15 -47
- package/cpp/llama.cpp/models/templates/llama-cpp-rwkv-world.jinja +34 -0
- package/cpp/llama.cpp/models/templates/moonshotai-Kimi-K2.jinja +43 -0
- package/cpp/llama.cpp/requirements/requirements-all.txt +1 -0
- package/cpp/llama.cpp/requirements/requirements-server-bench.txt +5 -0
- package/cpp/llama.cpp/src/llama-arch.cpp +316 -3
- package/cpp/llama.cpp/src/llama-arch.h +23 -1
- package/cpp/llama.cpp/src/llama-batch.cpp +103 -71
- package/cpp/llama.cpp/src/llama-batch.h +31 -18
- package/cpp/llama.cpp/src/llama-chat.cpp +58 -1
- package/cpp/llama.cpp/src/llama-chat.h +3 -0
- package/cpp/llama.cpp/src/llama-context.cpp +180 -106
- package/cpp/llama.cpp/src/llama-context.h +26 -16
- package/cpp/llama.cpp/src/llama-cparams.h +3 -2
- package/cpp/llama.cpp/src/llama-graph.cpp +310 -211
- package/cpp/llama.cpp/src/llama-graph.h +184 -122
- package/cpp/llama.cpp/src/llama-hparams.cpp +47 -1
- package/cpp/llama.cpp/src/llama-hparams.h +13 -2
- package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +38 -22
- package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +7 -2
- package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +849 -304
- package/cpp/llama.cpp/src/llama-kv-cache-unified.h +143 -47
- package/cpp/llama.cpp/src/llama-kv-cells.h +62 -10
- package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +10 -4
- package/cpp/llama.cpp/src/llama-memory-hybrid.h +3 -1
- package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +36 -11
- package/cpp/llama.cpp/src/llama-memory.cpp +17 -0
- package/cpp/llama.cpp/src/llama-memory.h +3 -0
- package/cpp/llama.cpp/src/llama-model.cpp +3545 -719
- package/cpp/llama.cpp/src/llama-model.h +21 -4
- package/cpp/llama.cpp/src/llama-quant.cpp +2 -2
- package/cpp/llama.cpp/src/llama-vocab.cpp +376 -10
- package/cpp/llama.cpp/src/llama-vocab.h +43 -0
- package/cpp/llama.cpp/src/unicode.cpp +207 -0
- package/cpp/llama.cpp/src/unicode.h +2 -0
- package/ios/include/chat.h +2 -0
- package/ios/include/common.h +22 -4
- package/ios/include/llama.h +15 -47
- package/ios/libs/llama.xcframework/Info.plist +13 -13
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5267 -4890
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +15 -47
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5238 -4861
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4014 -3764
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +15 -47
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5238 -4861
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4016 -3766
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +15 -47
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +15 -47
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +15 -47
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5267 -4890
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +15 -47
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5238 -4861
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4014 -3764
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +15 -47
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5303 -4926
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +15 -47
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5274 -4897
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4044 -3794
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +15 -47
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/package.json +4 -4
- package/cpp/llama.cpp/ggml/include/ggml-kompute.h +0 -50
- package/cpp/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
- package/cpp/llama.cpp/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cuh

@@ -90,7 +90,7 @@ struct tile_x_sizes {
 };

 static int get_mmq_x_max_host(const int cc) {
-    return new_mma_available(cc) ? 128 :
+    return (amd_mfma_available(cc) || new_mma_available(cc)) ? 128 :
         GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA ?
 #ifdef GGML_CUDA_FORCE_MMQ
             128 : 64;
@@ -100,12 +100,12 @@ static int get_mmq_x_max_host(const int cc) {
 }

 static constexpr __device__ int get_mmq_x_max_device() {
-#
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
     return 128;
-#else // NEW_MMA_AVAILABLE
+#else // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)

 #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
-    return
+    return 64;
 #else // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)

 #if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
@@ -115,12 +115,11 @@ static constexpr __device__ int get_mmq_x_max_device() {
     return MMQ_DP4A_MAX_BATCH_SIZE;
 #endif // GGML_CUDA_FORCE_MMQ
 #else // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
-
     return 64;
 #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA

 #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
-#endif // NEW_MMA_AVAILABLE
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
 }

 static int get_mmq_y_host(const int cc) {
@@ -144,16 +143,25 @@ static constexpr __device__ int get_mmq_y_device() {
 #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
 }

-
-
-
-
-
-
-
-#define
-
-#define
+// Decouple shared memory tile sizes from WARP_SIZE to allow for different warp sizes.
+// The K dimension of the tiles has either,
+// 1*MMQ_TILE_NE_K==32 (always for TILE_Y_K) or 2*MMQ_TILE_NE_K==64 (typically for TILE_X_K),
+// 32 bit elements for the quantized data (does not include scales).
+// In other words, the size of the quantized data in the K dimension is a multiple of MMQ_TILE_NE_K.
+// The final tile size in K direction is padded to avoid shared memory bank conflicts,
+// in terms of 32 bit elements that means K % 2 == 1 for dp4a or K % 8 == 4 for mma.
+#define MMQ_TILE_NE_K 32
+
+#define MMQ_DP4A_TXS_Q4_0 tile_x_sizes{mmq_y*MMQ_TILE_NE_K + mmq_y, mmq_y*MMQ_TILE_NE_K/QI4_0 + mmq_y/QI4_0, 0}
+#define MMQ_DP4A_TXS_Q4_1 tile_x_sizes{mmq_y*MMQ_TILE_NE_K + mmq_y, mmq_y*MMQ_TILE_NE_K/QI4_1 + mmq_y/QI4_1, 0}
+#define MMQ_DP4A_TXS_Q8_0 tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K*2/QI8_0 + mmq_y/(QI8_0/2), 0}
+#define MMQ_DP4A_TXS_Q8_0_16 tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K*4/QI8_0 + mmq_y/(QI8_0/4), 0}
+#define MMQ_DP4A_TXS_Q8_1 tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K*2/QI8_1 + mmq_y/(QI8_1/2), 0}
+#define MMQ_DP4A_TXS_Q2_K tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K + mmq_y, 0}
+#define MMQ_DP4A_TXS_Q3_K tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y, mmq_y*MMQ_TILE_NE_K/8 + mmq_y/8}
+#define MMQ_DP4A_TXS_Q4_K tile_x_sizes{mmq_y*MMQ_TILE_NE_K + mmq_y, mmq_y*MMQ_TILE_NE_K/QI4_K, mmq_y*MMQ_TILE_NE_K/8 + mmq_y/8}
+#define MMQ_DP4A_TXS_Q5_K tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K/QI5_K + mmq_y/QI5_K, mmq_y*MMQ_TILE_NE_K/8 + mmq_y/8}
+#define MMQ_DP4A_TXS_Q6_K tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K/QI6_K + mmq_y/QI6_K, mmq_y*MMQ_TILE_NE_K/8 + mmq_y/8}

 static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml_type type, int mmq_y) {
     switch (type) {
@@ -179,11 +187,11 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml
     }
 }

-#define MMQ_MMA_TILE_X_K_Q8_0 (2*
-#define MMQ_MMA_TILE_X_K_Q8_1 (2*
-#define MMQ_MMA_TILE_X_K_Q2_K (2*
-#define MMQ_MMA_TILE_X_K_Q3_K (2*
-#define MMQ_MMA_TILE_X_K_Q6_K (2*
+#define MMQ_MMA_TILE_X_K_Q8_0 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4)
+#define MMQ_MMA_TILE_X_K_Q8_1 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4)
+#define MMQ_MMA_TILE_X_K_Q2_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K + 4)
+#define MMQ_MMA_TILE_X_K_Q3_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/2 + 4)
+#define MMQ_MMA_TILE_X_K_Q6_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/QI6_K + MMQ_TILE_NE_K/8 + 7)

 static_assert(MMQ_MMA_TILE_X_K_Q8_0 % 8 == 4, "Wrong padding.");
 static_assert(MMQ_MMA_TILE_X_K_Q8_1 % 8 == 4, "Wrong padding.");
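Quick sanity check of the padding rule stated in the comment above (an aside from the editor, not part of the upstream diff; it assumes the usual ggml block constants QI8_0 == 8 and QI6_K == 32):

    // MMQ_MMA_TILE_X_K_Q8_0 = 2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4
    //                       = 2*32 + 2*32/8 + 4 = 76, and 76 % 8 == 4
    // MMQ_MMA_TILE_X_K_Q6_K = 2*32 + 32/32 + 32/8 + 7 = 76, and 76 % 8 == 4
    static_assert(2*32 + 2*32/8 + 4 == 76 && 76 % 8 == 4, "mma tile K padding lands on K % 8 == 4");

which is exactly what the static_assert lines in the hunk require for the mma path.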
@@ -215,42 +223,80 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
     }
 }

-
+// block_q8_1_mmq has (128 8-bit ints == 32 32-bit ints + 4 32-bit scales)
+#define MMQ_TILE_Y_K (MMQ_TILE_NE_K + MMQ_TILE_NE_K/QI8_1)

 static int mmq_get_granularity_host(const int mmq_x, const int cc) {
-
+    if (amd_mfma_available(cc)) {
+        return mmq_x >= 128 ? 32 : 16;
+    } else if (new_mma_available(cc) && mmq_x >= 48) {
+        return 16;
+    } else {
+        return 8;
+    }
 }

-#
+#if defined(AMD_MFMA_AVAILABLE)
+static constexpr __device__ int mmq_get_granularity_device(const int mmq_x) {
+    return mmq_x >= 128 ? 32 : 16;
+}
+#elif defined(NEW_MMA_AVAILABLE)
 static constexpr __device__ int mmq_get_granularity_device(const int mmq_x) {
     return mmq_x >= 48 ? 16 : 8;
 }
 #else
-static constexpr __device__ int mmq_get_granularity_device(const int /*
+static constexpr __device__ int mmq_get_granularity_device(const int /*mmq_x*/) {
     return 8;
 }
-#endif //
+#endif // AMD_MFMA_AVAILABLE
+
+#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
+static int mmq_get_nwarps_host(const int cc) {
+    return amd_mfma_available(cc) ? 8 : 4;
+}
+#else
+static int mmq_get_nwarps_host(const int /*cc*/) {
+    return 8;
+}
+#endif // (GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
+
+static constexpr __device__ int mmq_get_nwarps_device() {
+#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
+#if defined(AMD_MFMA_AVAILABLE)
+    return 8;
+#else
+    return 4;
+#endif // AMD_MFMA_AVAILABLE
+#else
+    return 8;
+#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
+}

 // ------------------------------------------------------------

-template <int mmq_y,
+template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q4_0(
     const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+    constexpr int nwarps = mmq_get_nwarps_device();
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();

-#
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
     int * x_qs = (int *) x_tile;
-    float * x_df = (float *) (x_qs + 2*
+    float * x_df = (float *) (x_qs + 2*MMQ_TILE_NE_K);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y);
     int * x_qs = (int *) x_tile;
     float * x_df = (float *) (x_qs + txs.qs);
-#endif // NEW_MMA_AVAILABLE
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)

-
-
+    constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_0);
+    constexpr int nrows = warp_size / threads_per_row;
+    const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
+    const int kbx = txi / QI4_0;
+    const int kqsx = txi % QI4_0;

 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
-        int i = i0 + threadIdx.y;
+    for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
+        int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);

         if (need_check) {
             i = min(i, i_max);
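For orientation (an editor's sketch, not part of the upstream diff): the new host-side helpers above reduce to a small decision table. A self-contained C++ version of the same selection, with the booleans standing in for amd_mfma_available(cc), new_mma_available(cc) and the HIP/AMD build check:

    // Mirrors mmq_get_granularity_host() and mmq_get_nwarps_host() from the hunk above.
    struct mmq_launch_cfg { int granularity; int nwarps; };

    static mmq_launch_cfg pick_mmq_launch_cfg(bool amd_mfma, bool new_mma, bool hip_amd_build, int mmq_x) {
        mmq_launch_cfg cfg{};
        if (amd_mfma) {
            cfg.granularity = mmq_x >= 128 ? 32 : 16;   // MFMA path: wider granularity for large batches
        } else if (new_mma && mmq_x >= 48) {
            cfg.granularity = 16;
        } else {
            cfg.granularity = 8;
        }
        cfg.nwarps = hip_amd_build ? (amd_mfma ? 8 : 4) : 8;  // HIP/AMD without MFMA keeps 4 warps
        return cfg;
    }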
@@ -259,20 +305,21 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
         const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbx;
         const int qs0 = get_int_b2(bxi->qs, kqsx);

-#
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI4_0) + kqsx + 0] = __vsubss4((qs0 >> 0) & 0x0F0F0F0F, 0x08080808);
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI4_0) + kqsx + QI4_0] = __vsubss4((qs0 >> 4) & 0x0F0F0F0F, 0x08080808);
 #else
-        x_qs[i*(
-#endif // NEW_MMA_AVAILABLE
+        x_qs[i*(MMQ_TILE_NE_K + 1) + txi] = qs0;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
     }

-
+    constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI4_0;
+    constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row;
     const int kbxd = threadIdx.x % blocks_per_tile_x_row;

 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps *
-        int i = i0 + threadIdx.y *
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) {
+        int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row;

         if (need_check) {
             i = min(i, i_max);
@@ -280,17 +327,19 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin

         const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbxd;

-#
-        x_df[i*MMQ_MMA_TILE_X_K_Q8_0
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+        x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d;
 #else
-        x_df[i*(
-#endif // NEW_MMA_AVAILABLE
+        x_df[i*(MMQ_TILE_NE_K/QI4_0) + i/QI4_0 + kbxd] = bxi->d;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
     }
 }

-template <int mmq_x, int mmq_y
+template <int mmq_x, int mmq_y>
 static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
+    constexpr int nwarps = mmq_get_nwarps_device();
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();

     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y);
     const int * x_qs = (const int *) x;
@@ -299,7 +348,7 @@ static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a(
     const half2 * y_ds = (const half2 *) y;

     // #pragma unroll
-    for (int k01 = 0; k01 <
+    for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR4_0*VDR_Q4_0_Q8_1_MMQ) {
         const int k0 = k00 + k01;

 #pragma unroll
@@ -307,7 +356,7 @@ static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a(
             const int j = j0 + threadIdx.y;

 #pragma unroll
-            for (int i0 = 0; i0 < mmq_y; i0 +=
+            for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
                 const int i = i0 + threadIdx.x;

                 const int kyqs = QI8_1 * ((k01/2) / (QI8_1/2)) + (k01/2) % (QI8_1/2);
@@ -320,32 +369,37 @@ static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a(
                     u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + kyqs + (l + QI4_0)];
                 }

-                sum[j0/nwarps*mmq_y/
-                    (&x_qs[i*(
-                    x_df[i*(
+                sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMQ>
+                    (&x_qs[i*(MMQ_TILE_NE_K + 1) + k0/QR4_0], u,
+                     x_df[i*(MMQ_TILE_NE_K/QI4_0) + i/QI4_0 + k0/(QR4_0*QI4_0)], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
             }
         }
     }
 }

-template <int mmq_y,
+template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q4_1(
     const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+    constexpr int nwarps = mmq_get_nwarps_device();
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();

-#
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
     int * x_qs = (int *) x_tile;
-    half2 * x_dm = (half2 *) (x_qs + 2*
+    half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y);
     int * x_qs = (int *) x_tile;
     half2 * x_dm = (half2 *) (x_qs + txs.qs);
-#endif // NEW_MMA_AVAILABLE
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)

-
-
+    constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_1);
+    constexpr int nrows = warp_size / threads_per_row;
+    const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
+    const int kbx = txi / QI4_1;
+    const int kqsx = txi % QI4_1;

 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
-        int i = i0 + threadIdx.y;
+    for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
+        int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);

         if (need_check) {
             i = min(i, i_max);
@@ -354,20 +408,21 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
         const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbx;
         const int qs0 = get_int_b4(bxi->qs, kqsx);

-#
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI4_1) + kqsx + 0] = (qs0 >> 0) & 0x0F0F0F0F;
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI4_1) + kqsx + QI4_1] = (qs0 >> 4) & 0x0F0F0F0F;
 #else
-        x_qs[i*(
-#endif // NEW_MMA_AVAILABLE
+        x_qs[i*(MMQ_TILE_NE_K + 1) + txi] = qs0;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
     }

-
+    constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI4_1;
+    constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row;
     const int kbxd = threadIdx.x % blocks_per_tile_x_row;

 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps *
-        int i = i0 + threadIdx.y *
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) {
+        int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row;

         if (need_check) {
             i = min(i, i_max);
@@ -375,17 +430,19 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin

         const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbxd;

-#
-        x_dm[i*MMQ_MMA_TILE_X_K_Q8_1
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+        x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm;
 #else
-        x_dm[i*(
-#endif // NEW_MMA_AVAILABLE
+        x_dm[i*(MMQ_TILE_NE_K/QI4_1) + i/QI4_1 + kbxd] = bxi->dm;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
     }
 }

-template <int mmq_x, int mmq_y
+template <int mmq_x, int mmq_y>
 static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
+    constexpr int nwarps = mmq_get_nwarps_device();
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();

     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y);
     const int * x_qs = (const int *) x;
@@ -394,7 +451,7 @@ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a(
     const half2 * y_ds = (const half2 *) y;

     // #pragma unroll
-    for (int k01 = 0; k01 <
+    for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR4_1*VDR_Q4_1_Q8_1_MMQ) {
         const int k0 = k00 + k01;

 #pragma unroll
@@ -402,7 +459,7 @@ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a(
             const int j = j0 + threadIdx.y;

 #pragma unroll
-            for (int i0 = 0; i0 < mmq_y; i0 +=
+            for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
                 const int i = i0 + threadIdx.x;

                 const int kyqs = QI8_1 * ((k01/2) / (QI8_1/2)) + (k01/2) % (QI8_1/2);
@@ -415,32 +472,37 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
                     u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + kyqs + (l + QI4_1)];
                 }

-                sum[j0/nwarps*mmq_y/
-                    (&x_qs[i*(
-                    x_dm[i*(
+                sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q4_1_q8_1_impl<VDR_Q4_1_Q8_1_MMQ>
+                    (&x_qs[i*(MMQ_TILE_NE_K + 1) + k0/QR4_1], u,
+                     x_dm[i*(MMQ_TILE_NE_K/QI4_1) + i/QI4_1 + k0/(QR4_1*QI4_1)], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
             }
         }
     }
 }

-template <int mmq_y,
+template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q5_0(
     const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+    constexpr int nwarps = mmq_get_nwarps_device();
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();

-#
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
     int * x_qs = (int *) x_tile;
-    float * x_df = (float *) (x_qs +
+    float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_0, mmq_y);
     int * x_qs = (int *) x_tile;
     float * x_df = (float *) (x_qs + txs.qs);
-#endif // NEW_MMA_AVAILABLE
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)

-
-
+    constexpr int threads_per_row = MMQ_ITER_K / (4 * QR5_0);
+    constexpr int nrows = warp_size / threads_per_row;
+    const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
+    const int kbx = txi / QI5_0;
+    const int kqsx = txi % QI5_0;

 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
-        int i = i0 + threadIdx.y;
+    for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
+        int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);

         if (need_check) {
             i = min(i, i_max);
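An editor's note on the recurring load_tiles_* change above (not part of the upstream diff): the old one-row-per-warp mapping is replaced by a warp-size-aware one. A small sketch of the index math, assuming MMQ_ITER_K == 256 and QR4_0 == 2 so that threads_per_row == 32; a 64-lane AMD warp then covers two tile rows per iteration, while a 32-lane NVIDIA warp keeps covering one:

    struct lane_map { int row_in_tile; int txi; };

    // Mirrors: txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
    //          i   = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);
    static lane_map map_lane(int warp_size, int thread_x, int thread_y) {
        const int threads_per_row = 256 / (4 * 2);   // assumed MMQ_ITER_K / (4 * QR4_0)
        const int nrows = warp_size / threads_per_row;
        const int txi   = warp_size > threads_per_row ? thread_x % threads_per_row : thread_x;
        const int row   = nrows == 1 ? thread_y : thread_y*nrows + thread_x/threads_per_row;
        return {row, txi};
    }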
@@ -449,7 +511,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
        const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbx;

         const int ql = get_int_b2(bxi->qs, kqsx);
-        const int qh = get_int_b2(bxi->qh, 0) >> (4 *
+        const int qh = get_int_b2(bxi->qh, 0) >> (4 * kqsx);

         int qs0 = (ql >> 0) & 0x0F0F0F0F;
         qs0 |= (qh << 4) & 0x00000010; // 0 -> 4
@@ -465,21 +527,22 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
         qs1 |= (qh << 9) & 0x10000000; // 19 -> 28
         qs1 = __vsubss4(qs1, 0x10101010); // subtract 16

-#
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + 0] = qs0;
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1;
 #else
-        x_qs[i*(2*
-        x_qs[i*(2*
-#endif // NEW_MMA_AVAILABLE
+        x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_0) + kqsx + 0] = qs0;
+        x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
     }

-
+    constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI5_0;
+    constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row;
     const int kbxd = threadIdx.x % blocks_per_tile_x_row;

 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps *
-        int i = i0 + threadIdx.y *
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) {
+        int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row;

         if (need_check) {
             i = min(i, i_max);
@@ -487,32 +550,37 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin

         const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbxd;

-#
-        x_df[i*MMQ_MMA_TILE_X_K_Q8_0
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+        x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d;
 #else
-        x_df[i*(
-#endif // NEW_MMA_AVAILABLE
+        x_df[i*(MMQ_TILE_NE_K/QI5_0) + i/QI5_0 + kbxd] = bxi->d;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
     }
 }

-template <int mmq_y,
+template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q5_1(
     const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+    constexpr int nwarps = mmq_get_nwarps_device();
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();

-#
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
     int * x_qs = (int *) x_tile;
-    half2 * x_dm = (half2 *) (x_qs + 2*
+    half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y);
     int * x_qs = (int *) x_tile;
     half2 * x_dm = (half2 *) (x_qs + txs.qs);
-#endif // NEW_MMA_AVAILABLE
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)

-
-
+    constexpr int threads_per_row = MMQ_ITER_K / (4 * QR5_1);
+    constexpr int nrows = warp_size / threads_per_row;
+    const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
+    const int kbx = txi / QI5_1;
+    const int kqsx = txi % QI5_1;

 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
-        int i = i0 + threadIdx.y;
+    for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
+        int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);

         if (need_check) {
             i = min(i, i_max);
@@ -521,7 +589,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
         const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbx;

         const int ql = get_int_b4(bxi->qs, kqsx);
-        const int qh = get_int_b4(bxi->qh, 0) >> (4 *
+        const int qh = get_int_b4(bxi->qh, 0) >> (4 * kqsx);

         int qs0 = (ql >> 0) & 0x0F0F0F0F;
         qs0 |= (qh << 4) & 0x00000010; // 0 -> 4
@@ -535,21 +603,22 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
         qs1 |= (qh << 2) & 0x00100000; // 18 -> 20
         qs1 |= (qh << 9) & 0x10000000; // 19 -> 28

-#
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + 0] = qs0;
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1;
 #else
-        x_qs[i*(2*
-        x_qs[i*(2*
-#endif // NEW_MMA_AVAILABLE
+        x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_1) + kqsx + 0] = qs0;
+        x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
     }

-
+    constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI5_1;
+    constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row;
     const int kbxd = threadIdx.x % blocks_per_tile_x_row;

 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps *
-        int i = i0 + threadIdx.y *
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) {
+        int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row;

         if (need_check) {
             i = min(i, i_max);
@@ -557,32 +626,38 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin

         const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbxd;

-#
-        x_dm[i*MMQ_MMA_TILE_X_K_Q8_1
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+        x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm;
 #else
-        x_dm[i*(
-#endif // NEW_MMA_AVAILABLE
+        x_dm[i*(MMQ_TILE_NE_K/QI5_1) + i/QI5_1 + kbxd] = bxi->dm;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
     }
 }

-template <int mmq_y,
+template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q8_0(
     const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+    constexpr int nwarps = mmq_get_nwarps_device();
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();

-#
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
     int * x_qs = (int *) x_tile;
-    float * x_df = (float *) (x_tile + 2*
+    float * x_df = (float *) (x_tile + 2*MMQ_TILE_NE_K);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y);
     int * x_qs = (int *) x_tile;
     float * x_df = (float *) (x_qs + txs.qs);
-#endif // NEW_MMA_AVAILABLE
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)

-
-
+    // MMQ_ITER_K / (4 * QR8_0) == 64 required. but NV has only 32 threads per warp
+    constexpr int threads_per_row = 32;
+    constexpr int nrows = warp_size / threads_per_row;
+    const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
+    const int kbx = txi / QI8_0;
+    const int kqsx = txi % QI8_0;

 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
-        int i = i0 + threadIdx.y;
+    for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
+        int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);

         if (need_check) {
             i = min(i, i_max);
@@ -590,21 +665,22 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin

         const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbx;

-#
-        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 0
-        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 +
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 0 + txi] = get_int_b2(bxi[0].qs, kqsx);
+        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + MMQ_TILE_NE_K + txi] = get_int_b2(bxi[MMQ_TILE_NE_K/QI8_0].qs, kqsx);
 #else
-        x_qs[i*(2*
-        x_qs[i*(2*
-#endif // NEW_MMA_AVAILABLE
+        x_qs[i*(2*MMQ_TILE_NE_K + 1) + 0 + txi] = get_int_b2(bxi[0].qs, kqsx);
+        x_qs[i*(2*MMQ_TILE_NE_K + 1) + MMQ_TILE_NE_K + txi] = get_int_b2(bxi[MMQ_TILE_NE_K/QI8_0].qs, kqsx);
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
     }

-
+    constexpr int blocks_per_tile_x_row = 2*MMQ_TILE_NE_K / QI8_0;
+    constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row;
     const int kbxd = threadIdx.x % blocks_per_tile_x_row;

 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps *
-        int i = i0 + threadIdx.y *
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) {
+        int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row;

         if (need_check) {
             i = min(i, i_max);
@@ -612,17 +688,19 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin

         const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbxd;

-#
-        x_df[i*MMQ_MMA_TILE_X_K_Q8_0
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+        x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d;
 #else
-        x_df[i*(2*
-#endif // NEW_MMA_AVAILABLE
+        x_df[i*(2*MMQ_TILE_NE_K/QI8_0) + i/(QI8_0/2) + kbxd] = bxi->d;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
     }
 }

-template <int mmq_x, int mmq_y
+template <int mmq_x, int mmq_y>
 static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
+    constexpr int nwarps = mmq_get_nwarps_device();
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();

     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y);
     const int * x_qs = (const int *) x;
@@ -631,7 +709,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a(
     const float * y_df = (const float *) y;

     // #pragma unroll
-    for (int k01 = 0; k01 <
+    for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += VDR_Q8_0_Q8_1_MMQ) {
         const int k0 = k00 + k01;

 #pragma unroll
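A brief editor's aside on the Q8_0 special case above (not part of the upstream diff; it assumes MMQ_ITER_K == 256 and QR8_0 == 1): the formula used for the other types would give 64 threads per row, which does not fit a 32-lane NVIDIA warp, so threads_per_row is pinned to 32 and each thread stores two 32-bit values per row instead:

    // threads_per_row for Q8_0 (assumed constants):
    //   MMQ_ITER_K / (4 * QR8_0) = 256 / 4 = 64  > 32 lanes on NVIDIA  ->  hard-coded 32
    // hence the paired stores in load_tiles_q8_0:
    //   x_qs[... + 0             + txi] = get_int_b2(bxi[0].qs, kqsx);
    //   x_qs[... + MMQ_TILE_NE_K + txi] = get_int_b2(bxi[MMQ_TILE_NE_K/QI8_0].qs, kqsx);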
@@ -639,21 +717,76 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a(
             const int j = j0 + threadIdx.y;

 #pragma unroll
-            for (int i0 = 0; i0 < mmq_y; i0 +=
+            for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
                 const int i = i0 + threadIdx.x;

-                sum[j0/nwarps*mmq_y/
-                    (&x_qs[i*(2*
-                    x_df[i*(2*
+                sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q8_0_q8_1_impl<float, VDR_Q8_0_Q8_1_MMQ>
+                    (&x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k0 % MMQ_TILE_NE_K],
+                     x_df[i*(2*MMQ_TILE_NE_K/QI8_0) + i/(QI8_0/2) + k0/QI8_0], y_df[j*MMQ_TILE_Y_K + (k0/QI8_1) % (MMQ_TILE_NE_K/QI8_1)]);
             }
         }
     }
 }

-template <int mmq_x, int mmq_y,
+template <int mmq_x, int mmq_y, mmq_q8_1_ds_layout ds_layout>
 static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
+#if defined(AMD_MFMA_AVAILABLE)
+    typedef tile<16, 8, int> tile_A;
+    typedef tile<16, 8, int> tile_B;
+    typedef tile<16, 16, int> tile_C;
+
+    constexpr int granularity = mmq_get_granularity_device(mmq_x);
+    constexpr int rows_per_warp = granularity;
+    constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.

+    y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
+
+    const int * x_qs = (const int *) x;
+    const float * x_df = (const float *) x_qs + 2*MMQ_TILE_NE_K;
+    const int * y_qs = (const int *) y + 4;
+    const float * y_df = (const float *) y;
+    const half2 * y_ds = (const half2 *) y;
+
+    const int i0 = (threadIdx.y / ntx) * rows_per_warp;
+
+    for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_0) {
+        const int k0 = k00 + k01;
+
+        tile_A A[ntx];
+#pragma unroll
+        for (int n = 0; n < ntx; ++n) {
+            load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0);
+        }
+
+#pragma unroll
+        for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
+            tile_B B;
+            load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
+
+            float dB;
+            const int j = j0 + tile_C::get_j(0);
+            if (ds_layout == MMQ_Q8_1_DS_LAYOUT_D4) {
+                dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
+            } else {
+                dB = __low2float(y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
+            }
+
+#pragma unroll
+            for (int n = 0; n < ntx; ++n) {
+                tile_C C;
+                mma(C, A[n], B);
+
+#pragma unroll
+                for (int l = 0; l < tile_C::ne; ++l) {
+                    const int i = i0 + n*tile_A::I + tile_C::get_i(l);
+                    const float dA = x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + k0/QI8_0];
+                    sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l]*dA*dB;
+                }
+            }
+        }
+    }
+#else
     typedef tile<16, 8, int> tile_A;
     typedef tile< 8, 8, int> tile_B;
     typedef tile<16, 8, int> tile_C;
@@ -662,23 +795,23 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
     constexpr int rows_per_warp = 2 * granularity;
     constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
|
|
|
665
|
-
y += (threadIdx.y % ntx) * (
|
|
798
|
+
y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
|
|
666
799
|
|
|
667
800
|
const int * x_qs = (const int *) x;
|
|
668
|
-
const float * x_df = (const float *) x_qs + 2*
|
|
801
|
+
const float * x_df = (const float *) x_qs + 2*MMQ_TILE_NE_K;
|
|
669
802
|
const int * y_qs = (const int *) y + 4;
|
|
670
803
|
const float * y_df = (const float *) y;
|
|
671
804
|
const half2 * y_ds = (const half2 *) y;
|
|
672
805
|
|
|
673
|
-
tile_A A[ntx][
|
|
674
|
-
float dA[ntx][tile_C::ne/2][
|
|
806
|
+
tile_A A[ntx][MMQ_TILE_NE_K/QI8_0];
|
|
807
|
+
float dA[ntx][tile_C::ne/2][MMQ_TILE_NE_K/QI8_0];
|
|
675
808
|
|
|
676
809
|
const int i0 = (threadIdx.y/ntx)*rows_per_warp;
|
|
677
810
|
|
|
678
811
|
#pragma unroll
|
|
679
812
|
for (int n = 0; n < ntx; ++n) {
|
|
680
813
|
#pragma unroll
|
|
681
|
-
for (int k01 = 0; k01 <
|
|
814
|
+
for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_0) {
|
|
682
815
|
const int k0 = k00 + k01;
|
|
683
816
|
|
|
684
817
|
load_ldmatrix(A[n][k01/QI8_0], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0);
|
|
@@ -689,7 +822,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
|
|
|
689
822
|
const int i = i0 + n*tile_A::I + tile_C::get_i(2*l);
|
|
690
823
|
|
|
691
824
|
#pragma unroll
|
|
692
|
-
for (int k01 = 0; k01 <
|
|
825
|
+
for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_0) {
|
|
693
826
|
const int k0 = k00 + k01;
|
|
694
827
|
|
|
695
828
|
dA[n][l][k01/QI8_0] = x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + k0/QI8_0];
|
|
@@ -700,7 +833,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
|
|
|
700
833
|
#pragma unroll
|
|
701
834
|
for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
|
|
702
835
|
#pragma unroll
|
|
703
|
-
for (int k01 = 0; k01 <
|
|
836
|
+
for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_0) {
|
|
704
837
|
tile_B B;
|
|
705
838
|
float dB[tile_C::ne/2];
|
|
706
839
|
|
|
@@ -729,11 +862,14 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
|
|
|
729
862
|
}
|
|
730
863
|
}
|
|
731
864
|
}
|
|
865
|
+
#endif // defined(AMD_MFMA_AVAILABLE)
|
|
732
866
|
}
|
|
733
867
|
|
|
734
|
-
template <int mmq_x, int mmq_y
|
|
868
|
+
template <int mmq_x, int mmq_y>
|
|
735
869
|
static __device__ __forceinline__ void vec_dot_q8_1_q8_1_dp4a(
|
|
736
870
|
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
|
|
871
|
+
constexpr int nwarps = mmq_get_nwarps_device();
|
|
872
|
+
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
|
737
873
|
|
|
738
874
|
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y);
|
|
739
875
|
const int * x_qs = (const int *) x;
|
|
@@ -742,7 +878,7 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_dp4a(
|
|
|
742
878
|
const half2 * y_ds = (const half2 *) y;
|
|
743
879
|
|
|
744
880
|
// #pragma unroll
|
|
745
|
-
for (int k01 = 0; k01 <
|
|
881
|
+
for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += VDR_Q8_0_Q8_1_MMQ) {
|
|
746
882
|
const int k0 = k00 + k01;
|
|
747
883
|
|
|
748
884
|
#pragma unroll
|
|
@@ -750,45 +886,95 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_dp4a(
|
|
|
750
886
|
const int j = j0 + threadIdx.y;
|
|
751
887
|
|
|
752
888
|
#pragma unroll
|
|
753
|
-
for (int i0 = 0; i0 < mmq_y; i0 +=
|
|
889
|
+
for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
|
|
754
890
|
const int i = i0 + threadIdx.x;
|
|
755
891
|
|
|
756
|
-
sum[j0/nwarps*mmq_y/
|
|
757
|
-
(&x_qs[i*(2*
|
|
758
|
-
x_dm[i*(
|
|
892
|
+
sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q8_1_q8_1_impl<QR5_1*VDR_Q5_1_Q8_1_MMQ>
|
|
893
|
+
(&x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01],
|
|
894
|
+
x_dm[i*(MMQ_TILE_NE_K/QI5_1) + i/QI5_1 + k0/QI8_1], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
|
|
759
895
|
}
|
|
760
896
|
}
|
|
761
897
|
}
|
|
762
898
|
}
|
|
763
899
|
|
|
764
|
-
template <int mmq_x, int mmq_y
|
|
900
|
+
template <int mmq_x, int mmq_y>
|
|
765
901
|
static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
|
|
766
902
|
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
|
|
903
|
+
#if defined(AMD_MFMA_AVAILABLE)
|
|
904
|
+
typedef tile<16, 8, int> tile_A;
|
|
905
|
+
typedef tile<16, 8, int> tile_B;
|
|
906
|
+
typedef tile<16, 16, int> tile_C;
|
|
767
907
|
|
|
768
|
-
|
|
769
|
-
|
|
770
|
-
|
|
908
|
+
constexpr int granularity = mmq_get_granularity_device(mmq_x);
|
|
909
|
+
constexpr int rows_per_warp = granularity;
|
|
910
|
+
constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
|
|
911
|
+
|
|
912
|
+
y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
|
|
913
|
+
|
|
914
|
+
const int * x_qs = (const int *) x;
|
|
915
|
+
const half2 * x_dm = (const half2 *) x_qs + 2*MMQ_TILE_NE_K;
|
|
916
|
+
const int * y_qs = (const int *) y + 4;
|
|
917
|
+
const half2 * y_dm = (const half2 *) y;
|
|
918
|
+
|
|
919
|
+
const int i0 = (threadIdx.y / ntx) * rows_per_warp;
|
|
920
|
+
|
|
921
|
+
for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) {
|
|
922
|
+
const int k0 = k00 + k01;
|
|
923
|
+
|
|
924
|
+
tile_A A[ntx];
|
|
925
|
+
#pragma unroll
|
|
926
|
+
for (int n = 0; n < ntx; ++n) {
|
|
927
|
+
load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1);
|
|
928
|
+
}
|
|
929
|
+
|
|
930
|
+
#pragma unroll
|
|
931
|
+
for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
|
|
932
|
+
tile_B B;
|
|
933
|
+
load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
|
|
934
|
+
|
|
935
|
+
const int j = j0 + tile_C::get_j(0);
|
|
936
|
+
const float2 dsB = __half22float2(y_dm[j*MMQ_TILE_Y_K + k01/QI8_1]);
|
|
937
|
+
|
|
938
|
+
#pragma unroll
|
|
939
|
+
for (int n = 0; n < ntx; ++n) {
|
|
940
|
+
tile_C C;
|
|
941
|
+
mma(C, A[n], B);
|
|
942
|
+
|
|
943
|
+
#pragma unroll
|
|
944
|
+
for (int l = 0; l < tile_C::ne; ++l) {
|
|
945
|
+
const int i = i0 + n*tile_A::I + tile_C::get_i(l);
|
|
946
|
+
float2 dmA = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + k0/QI8_1]);
|
|
947
|
+
sum[(j0/tile_C::J + n)*tile_C::ne + l] += dmA.x*dsB.x*C.x[l];
|
|
948
|
+
sum[(j0/tile_C::J + n)*tile_C::ne + l] += dmA.y*dsB.y;
|
|
949
|
+
}
|
|
950
|
+
}
|
|
951
|
+
}
|
|
952
|
+
}
|
|
953
|
+
#else
|
|
954
|
+
typedef tile<16, 8, int> tile_A;
|
|
955
|
+
typedef tile< 8, 8, int> tile_B;
|
|
956
|
+
typedef tile<16, 8, int> tile_C;
|
|
771
957
|
|
|
772
958
|
constexpr int granularity = mmq_get_granularity_device(mmq_x);
|
|
773
959
|
constexpr int rows_per_warp = 2 * granularity;
|
|
774
960
|
constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
|
|
775
961
|
|
|
776
|
-
y += (threadIdx.y % ntx) * (
|
|
962
|
+
y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
|
|
777
963
|
|
|
778
964
|
const int * x_qs = (const int *) x;
|
|
779
|
-
const half2 * x_dm = (const half2 *) x_qs + 2*
|
|
965
|
+
const half2 * x_dm = (const half2 *) x_qs + 2*MMQ_TILE_NE_K;
|
|
780
966
|
const int * y_qs = (const int *) y + 4;
|
|
781
967
|
const half2 * y_dm = (const half2 *) y;
|
|
782
968
|
|
|
783
|
-
tile_A A[ntx][
|
|
784
|
-
float2 dmA[ntx][tile_C::ne/2][
|
|
969
|
+
tile_A A[ntx][MMQ_TILE_NE_K/QI8_1];
|
|
970
|
+
float2 dmA[ntx][tile_C::ne/2][MMQ_TILE_NE_K/QI8_1];
|
|
785
971
|
|
|
786
972
|
const int i0 = (threadIdx.y/ntx)*rows_per_warp;
|
|
787
973
|
|
|
788
974
|
#pragma unroll
|
|
789
975
|
for (int n = 0; n < ntx; ++n) {
|
|
790
976
|
#pragma unroll
|
|
791
|
-
for (int k01 = 0; k01 <
|
|
977
|
+
for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) {
|
|
792
978
|
const int k0 = k00 + k01;
|
|
793
979
|
|
|
794
980
|
load_ldmatrix(A[n][k01/QI8_1], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1);
|
|
@@ -799,7 +985,7 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
|
|
|
799
985
|
const int i = i0 + n*tile_A::I + tile_C::get_i(2*l);
|
|
800
986
|
|
|
801
987
|
#pragma unroll
|
|
802
|
-
for (int k01 = 0; k01 <
|
|
988
|
+
for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) {
|
|
803
989
|
const int k0 = k00 + k01;
|
|
804
990
|
|
|
805
991
|
dmA[n][l][k01/QI8_1] = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + k0/QI8_1]);
|
|
@@ -810,7 +996,7 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
|
|
|
810
996
|
#pragma unroll
|
|
811
997
|
for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
|
|
812
998
|
#pragma unroll
|
|
813
|
-
for (int k01 = 0; k01 <
|
|
999
|
+
for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) {
|
|
814
1000
|
tile_B B;
|
|
815
1001
|
float2 dsB[tile_C::ne/2];
|
|
816
1002
|
|
|
@@ -836,11 +1022,15 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
|
|
|
836
1022
|
}
|
|
837
1023
|
}
|
|
838
1024
|
}
|
|
1025
|
+
#endif // defined(AMD_MFMA_AVAILABLE)
|
|
839
1026
|
}
|
|
840
1027
|
|
|
841
|
-
|
|
1028
|
+
// Used for Q3_K, IQ2_S, and IQ2_XS
|
|
1029
|
+
template <int mmq_x, int mmq_y>
|
|
842
1030
|
static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_dp4a(
|
|
843
1031
|
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
|
|
1032
|
+
constexpr int nwarps = mmq_get_nwarps_device();
|
|
1033
|
+
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
|
844
1034
|
|
|
845
1035
|
constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16;
|
|
846
1036
|
const int * x_qs = (const int *) x;
|
|
@@ -849,7 +1039,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_dp4a(
|
|
|
849
1039
|
const float * y_df = (const float *) y;
|
|
850
1040
|
|
|
851
1041
|
// #pragma unroll
|
|
852
|
-
for (int k01 = 0; k01 <
|
|
1042
|
+
for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_0) {
|
|
853
1043
|
const int k0 = k00 + k01;
|
|
854
1044
|
|
|
855
1045
|
#pragma unroll
|
|
@@ -857,23 +1047,73 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_dp4a(
|
|
|
857
1047
|
const int j = j0 + threadIdx.y;
|
|
858
1048
|
|
|
859
1049
|
#pragma unroll
|
|
860
|
-
for (int i0 = 0; i0 < mmq_y; i0 +=
|
|
1050
|
+
for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
|
|
861
1051
|
const int i = i0 + threadIdx.x;
|
|
862
1052
|
|
|
863
|
-
sum[j0/nwarps*mmq_y/
|
|
864
|
-
&x_qs[i*(2*
|
|
1053
|
+
sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q8_0_16_q8_1_impl<QI8_0>(
|
|
1054
|
+
&x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0],
|
|
865
1055
|
&y_qs[j*MMQ_TILE_Y_K + k01],
|
|
866
|
-
&x_df[i*(2*
|
|
1056
|
+
&x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + k0/(QI8_0/2)],
|
|
867
1057
|
y_df[j*MMQ_TILE_Y_K + k01/QI8_1]);
|
|
868
1058
|
}
|
|
869
1059
|
}
|
|
870
1060
|
}
|
|
871
1061
|
}
|
|
872
1062
|
|
|
873
|
-
|
|
1063
|
+
// Used for Q3_K, IQ2_S, and IQ2_XS:
|
|
1064
|
+
template <int mmq_x, int mmq_y>
|
|
874
1065
|
static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
|
|
875
1066
|
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
|
|
876
|
-
#
|
|
1067
|
+
#if defined(AMD_MFMA_AVAILABLE)
|
|
1068
|
+
typedef tile<16, 8, int> tile_A;
|
|
1069
|
+
typedef tile<16, 8, int> tile_B;
|
|
1070
|
+
typedef tile<16, 16, int> tile_C;
|
|
1071
|
+
typedef tile<64, 2, int> tile_load;
|
|
1072
|
+
|
|
1073
|
+
constexpr int granularity = mmq_get_granularity_device(mmq_x);
|
|
1074
|
+
constexpr int rows_per_warp = granularity;
|
|
1075
|
+
constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
|
|
1076
|
+
|
|
1077
|
+
y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
|
|
1078
|
+
|
|
1079
|
+
const int * x_qs = (const int *) x;
|
|
1080
|
+
const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2;
|
|
1081
|
+
const int * y_qs = (const int *) y + 4;
|
|
1082
|
+
const float * y_df = (const float *) y;
|
|
1083
|
+
|
|
1084
|
+
const int i0 = (threadIdx.y / ntx) * rows_per_warp;
|
|
1085
|
+
|
|
1086
|
+
for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) {
|
|
1087
|
+
const int k0 = k00 + k01;
|
|
1088
|
+
|
|
1089
|
+
tile_A A[ntx];
|
|
1090
|
+
#pragma unroll
|
|
1091
|
+
for (int n = 0; n < ntx; ++n) {
|
|
1092
|
+
load_generic(((tile_load *) A)[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K);
|
|
1093
|
+
}
|
|
1094
|
+
|
|
1095
|
+
#pragma unroll
|
|
1096
|
+
for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
|
|
1097
|
+
tile_B B[1];
|
|
1098
|
+
load_generic(((tile_load *) B)[0], y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
|
|
1099
|
+
|
|
1100
|
+
const int j = j0 + tile_C::get_j(0);
|
|
1101
|
+
const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1] / 2;
|
|
1102
|
+
|
|
1103
|
+
#pragma unroll
|
|
1104
|
+
for (int n = 0; n < ntx; ++n) {
|
|
1105
|
+
tile_C C;
|
|
1106
|
+
mma(C, A[n], B[0]);
|
|
1107
|
+
|
|
1108
|
+
#pragma unroll
|
|
1109
|
+
for (int l = 0; l < tile_C::ne; ++l) {
|
|
1110
|
+
const int i = i0 + n*tile_C::I + tile_C::get_i(l);
|
|
1111
|
+
sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l] * x_df[i*MMQ_MMA_TILE_X_K_Q3_K + k0/4] * dB;
|
|
1112
|
+
}
|
|
1113
|
+
}
|
|
1114
|
+
}
|
|
1115
|
+
}
|
|
1116
|
+
#elif defined(NEW_MMA_AVAILABLE)
|
|
877
1117
|
|
|
878
1118
|
typedef tile<16, 4, int> tile_A;
|
|
879
1119
|
typedef tile<16, 8, int> tile_A_8;
|
|
@@ -884,10 +1124,10 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
|
|
|
884
1124
|
constexpr int rows_per_warp = 2 * granularity;
|
|
885
1125
|
constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
|
|
886
1126
|
|
|
887
|
-
y += (threadIdx.y % ntx) * (
|
|
1127
|
+
y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
|
|
888
1128
|
|
|
889
1129
|
const int * x_qs = (const int *) x;
|
|
890
|
-
const float * x_df = (const float *) x_qs +
|
|
1130
|
+
const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2;
|
|
891
1131
|
const int * y_qs = (const int *) y + 4;
|
|
892
1132
|
const float * y_df = (const float *) y;
|
|
893
1133
|
|
|
@@ -899,7 +1139,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
|
|
|
899
1139
|
#pragma unroll
|
|
900
1140
|
for (int n = 0; n < ntx; ++n) {
|
|
901
1141
|
#pragma unroll
|
|
902
|
-
for (int k01 = 0; k01 <
|
|
1142
|
+
for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 8) {
|
|
903
1143
|
const int k0 = k00 + k01;
|
|
904
1144
|
|
|
905
1145
|
load_ldmatrix(((tile_A_8 *) A[n])[k01/8], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K);
|
|
@@ -910,7 +1150,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
|
|
|
910
1150
|
const int i = i0 + n*tile_C::I + tile_C::get_i(2*l);
|
|
911
1151
|
|
|
912
1152
|
#pragma unroll
|
|
913
|
-
for (int k01 = 0; k01 <
|
|
1153
|
+
for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) {
|
|
914
1154
|
const int k0 = k00 + k01;
|
|
915
1155
|
|
|
916
1156
|
dA[n][l][k01/4] = x_df[i*MMQ_MMA_TILE_X_K_Q3_K + k0/4];
|
|
@@ -921,7 +1161,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
|
|
|
921
1161
|
#pragma unroll
|
|
922
1162
|
for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
|
|
923
1163
|
#pragma unroll
|
|
924
|
-
for (int k01 = 0; k01 <
|
|
1164
|
+
for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR3_K*VDR_Q3_K_Q8_1_MMQ) {
|
|
925
1165
|
tile_B B[2];
|
|
926
1166
|
float dB[tile_C::ne/2];
|
|
927
1167
|
|
|
@@ -952,26 +1192,29 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
|
|
|
952
1192
|
#else
|
|
953
1193
|
GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k00);
|
|
954
1194
|
NO_DEVICE_CODE;
|
|
955
|
-
#endif //
|
|
1195
|
+
#endif // AMD_MFMA_AVAILABLE
|
|
956
1196
|
}
|
|
957
1197
|
|
|
958
|
-
template <int mmq_y,
|
|
1198
|
+
template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q2_K(
|
|
959
1199
|
const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
|
|
1200
|
+
constexpr int nwarps = mmq_get_nwarps_device();
|
|
960
1201
|
|
|
961
|
-
#
|
|
1202
|
+
#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
|
|
962
1203
|
int * x_qs = (int *) x_tile;
|
|
963
|
-
half2 * x_dm = (half2 *) (x_qs + 2*
|
|
1204
|
+
half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K);
|
|
964
1205
|
#else
|
|
965
1206
|
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y);
|
|
966
1207
|
int * x_qs = (int *) x_tile;
|
|
967
1208
|
half2 * x_dm = (half2 *) (x_qs + txs.qs);
|
|
968
|
-
#endif // NEW_MMA_AVAILABLE
|
|
1209
|
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
|
|
969
1210
|
|
|
970
|
-
|
|
1211
|
+
constexpr int threads_per_row = MMQ_ITER_K / (4 * QR2_K);
|
|
1212
|
+
constexpr int nrows = ggml_cuda_get_physical_warp_size() / threads_per_row;
|
|
1213
|
+
const int kqsx = threadIdx.x % threads_per_row;
|
|
971
1214
|
|
|
972
1215
|
#pragma unroll
|
|
973
|
-
for (int i0 = 0; i0 < mmq_y; i0 += nwarps
|
|
974
|
-
int i = i0 + threadIdx.y*
|
|
1216
|
+
for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
|
|
1217
|
+
int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;
|
|
975
1218
|
|
|
976
1219
|
if (need_check) {
|
|
977
1220
|
i = min(i, i_max);
|
|
@@ -987,11 +1230,11 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
|
|
987
1230
|
|
|
988
1231
|
const int x_qs_k = (x_ql_0 >> (2*l)) & 0x03030303;
|
|
989
1232
|
|
|
990
|
-
#
|
|
1233
|
+
#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
|
|
991
1234
|
x_qs[i*MMQ_MMA_TILE_X_K_Q2_K + k] = x_qs_k;
|
|
992
1235
|
#else
|
|
993
|
-
x_qs[i*(2*
|
|
994
|
-
#endif // NEW_MMA_AVAILABLE
|
|
1236
|
+
x_qs[i*(2*MMQ_TILE_NE_K + 1) + k] = x_qs_k;
|
|
1237
|
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
|
|
995
1238
|
}
|
|
996
1239
|
|
|
997
1240
|
const int sc_m = bxi->scales[kqsx];
|
|
@@ -1002,17 +1245,19 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
|
|
1002
1245
|
const half2 x_dm_ik = make_half2(bxi_dmf.x*(sc_m & 0x0F), bxi_dmf.y*(sc_m >> 4));
|
|
1003
1246
|
#endif // FAST_FP16_AVAILABLE
|
|
1004
1247
|
|
|
1005
|
-
#
|
|
1248
|
+
#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
|
|
1006
1249
|
x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + kqsx] = x_dm_ik;
|
|
1007
1250
|
#else
|
|
1008
|
-
x_dm[i*(
|
|
1009
|
-
#endif // NEW_MMA_AVAILABLE
|
|
1251
|
+
x_dm[i*(MMQ_TILE_NE_K + 1) + kqsx] = x_dm_ik;
|
|
1252
|
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
|
|
1010
1253
|
}
|
|
1011
1254
|
}
|
|
1012
1255
|
|
|
1013
|
-
template <int mmq_x, int mmq_y
|
|
1256
|
+
template <int mmq_x, int mmq_y>
|
|
1014
1257
|
static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a(
|
|
1015
1258
|
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
|
|
1259
|
+
constexpr int nwarps = mmq_get_nwarps_device();
|
|
1260
|
+
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
|
1016
1261
|
|
|
1017
1262
|
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y);
|
|
1018
1263
|
const int * x_qs = (const int *) x;
|
|
@@ -1029,7 +1274,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a(
|
|
|
1029
1274
|
}
|
|
1030
1275
|
|
|
1031
1276
|
#pragma unroll
|
|
1032
|
-
for (int k01 = 0; k01 <
|
|
1277
|
+
for (int k01 = 0; k01 < MMQ_TILE_NE_K/2; k01 += QR2_K*VDR_Q2_K_Q8_1_MMQ) {
|
|
1033
1278
|
const int k0 = k00 + k01;
|
|
1034
1279
|
|
|
1035
1280
|
#pragma unroll
|
|
@@ -1037,13 +1282,13 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a(
|
|
|
1037
1282
|
const int j = j0 + threadIdx.y;
|
|
1038
1283
|
|
|
1039
1284
|
#pragma unroll
|
|
1040
|
-
for (int i0 = 0; i0 < mmq_y; i0 +=
|
|
1285
|
+
for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
|
|
1041
1286
|
const int i = i0 + threadIdx.x;
|
|
1042
1287
|
|
|
1043
1288
|
constexpr int ns = 2;
|
|
1044
|
-
sum[j0/nwarps*mmq_y/
|
|
1045
|
-
&x_qs[i*(2*
|
|
1046
|
-
&x_dm[i*(
|
|
1289
|
+
sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q2_K_q8_1_impl_mmq<ns>(
|
|
1290
|
+
&x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01],
|
|
1291
|
+
&x_dm[i*(MMQ_TILE_NE_K + 1) + k0/4], k01 < MMQ_TILE_NE_K/2 ? y_df[j0/nwarps].x : y_df[j0/nwarps].y,
|
|
1047
1292
|
&y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]);
|
|
1048
1293
|
}
|
|
1049
1294
|
}
|
|
@@ -1052,7 +1297,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a(
|
|
|
1052
1297
|
// Some compilers fail to unroll the loop over k01 if there is a conditional statement for ns in the inner loop.
|
|
1053
1298
|
// As a workaround 2 separate loops are used instead.
|
|
1054
1299
|
#pragma unroll
|
|
1055
|
-
for (int k01 =
|
|
1300
|
+
for (int k01 = MMQ_TILE_NE_K/2; k01 < MMQ_TILE_NE_K; k01 += QR2_K*VDR_Q2_K_Q8_1_MMQ) {
|
|
1056
1301
|
const int k0 = k00 + k01;
|
|
1057
1302
|
|
|
1058
1303
|
#pragma unroll
|
|
@@ -1060,23 +1305,89 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a(
|
|
|
1060
1305
|
const int j = j0 + threadIdx.y;
|
|
1061
1306
|
|
|
1062
1307
|
#pragma unroll
|
|
1063
|
-
for (int i0 = 0; i0 < mmq_y; i0 +=
|
|
1308
|
+
for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
|
|
1064
1309
|
const int i = i0 + threadIdx.x;
|
|
1065
1310
|
|
|
1066
1311
|
constexpr int ns = 1;
|
|
1067
|
-
sum[j0/nwarps*mmq_y/
|
|
1068
|
-
&x_qs[i*(2*
|
|
1069
|
-
&x_dm[i*(
|
|
1312
|
+
sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q2_K_q8_1_impl_mmq<ns>(
|
|
1313
|
+
&x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01],
|
|
1314
|
+
&x_dm[i*(MMQ_TILE_NE_K + 1) + k0/4], k01 < MMQ_TILE_NE_K/2 ? y_df[j0/nwarps].x : y_df[j0/nwarps].y,
|
|
1070
1315
|
&y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]);
|
|
1071
1316
|
}
|
|
1072
1317
|
}
|
|
1073
1318
|
}
|
|
1074
1319
|
}
|
|
1075
1320
|
|
|
1076
|
-
template <int mmq_x, int mmq_y
|
|
1321
|
+
template <int mmq_x, int mmq_y>
|
|
1077
1322
|
static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
|
|
1078
1323
|
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
|
|
1079
|
-
#
|
|
1324
|
+
#if defined(AMD_MFMA_AVAILABLE)
|
|
1325
|
+
typedef tile<16, 8, int> tile_A;
|
|
1326
|
+
typedef tile<16, 8, int> tile_B;
|
|
1327
|
+
typedef tile<16, 16, int> tile_C;
|
|
1328
|
+
typedef tile<64, 2, int> tile_load;
|
|
1329
|
+
|
|
1330
|
+
constexpr int granularity = mmq_get_granularity_device(mmq_x);
|
|
1331
|
+
constexpr int rows_per_warp = granularity;
|
|
1332
|
+
constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
|
|
1333
|
+
|
|
1334
|
+
y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
|
|
1335
|
+
|
|
1336
|
+
const int * x_qs = (const int *) x;
|
|
1337
|
+
const half2 * x_dm = (const half2 *) x_qs + MMQ_TILE_NE_K*2;
|
|
1338
|
+
const int * y_qs = (const int *) y + 4;
|
|
1339
|
+
const half2 * y_ds = (const half2 *) y;
|
|
1340
|
+
|
|
1341
|
+
const int i0 = (threadIdx.y / ntx) * rows_per_warp;
|
|
1342
|
+
|
|
1343
|
+
for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) {
|
|
1344
|
+
const int k0 = k00 + k01;
|
|
1345
|
+
|
|
1346
|
+
tile_A A[ntx];
|
|
1347
|
+
#pragma unroll
|
|
1348
|
+
for (int n = 0; n < ntx; ++n) {
|
|
1349
|
+
load_generic(((tile_load *) A)[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K);
|
|
1350
|
+
}
|
|
1351
|
+
|
|
1352
|
+
#pragma unroll
|
|
1353
|
+
for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
|
|
1354
|
+
tile_B B[1];
|
|
1355
|
+
load_generic(((tile_load *) B)[0], y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
|
|
1356
|
+
|
|
1357
|
+
const int j = j0 + tile_C::get_j(0);
|
|
1358
|
+
const float dB = (k01 < MMQ_TILE_NE_K/2) ? __half22float2(y_ds[j*MMQ_TILE_Y_K]).x/2 : __half22float2(y_ds[j*MMQ_TILE_Y_K]).y/2;
|
|
1359
|
+
const float sB = (k01 >= MMQ_TILE_NE_K * 3/4) ? 0
|
|
1360
|
+
: (((k01/4)%2) ? __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]).y
|
|
1361
|
+
: __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]).x);
|
|
1362
|
+
|
|
1363
|
+
tile_C Cm;
|
|
1364
|
+
if (k01 >= MMQ_TILE_NE_K * 3/4) {
|
|
1365
|
+
tile_A A1;
|
|
1366
|
+
A1.x[0] = 0x01010101;
|
|
1367
|
+
A1.x[1] = 0x01010101;
|
|
1368
|
+
mma(Cm, A1, B[0]);
|
|
1369
|
+
}
|
|
1370
|
+
|
|
1371
|
+
#pragma unroll
|
|
1372
|
+
for (int n = 0; n < ntx; ++n) {
|
|
1373
|
+
tile_C Cd;
|
|
1374
|
+
mma(Cd, A[n], B[0]);
|
|
1375
|
+
|
|
1376
|
+
#pragma unroll
|
|
1377
|
+
for (int l = 0; l < tile_C::ne; ++l) {
|
|
1378
|
+
const int i = i0 + n*tile_C::I + tile_C::get_i(l);
|
|
1379
|
+
const float2 dm = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + k0/4]);
|
|
1380
|
+
float tmp = Cd.x[l]*dm.x;
|
|
1381
|
+
if (k01 >= MMQ_TILE_NE_K * 3/4) {
|
|
1382
|
+
tmp -= Cm.x[l]*dm.y;
|
|
1383
|
+
}
|
|
1384
|
+
sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp*dB;
|
|
1385
|
+
sum[(j0/tile_C::J + n)*tile_C::ne + l] -= dm.y*sB;
|
|
1386
|
+
}
|
|
1387
|
+
}
|
|
1388
|
+
}
|
|
1389
|
+
}
|
|
1390
|
+
#elif defined(NEW_MMA_AVAILABLE)
|
|
1080
1391
|
|
|
1081
1392
|
typedef tile<16, 4, int> tile_A;
|
|
1082
1393
|
typedef tile<16, 8, int> tile_A_8;
|
|
@@ -1087,10 +1398,10 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
|
|
|
1087
1398
|
constexpr int rows_per_warp = 2 * granularity;
|
|
1088
1399
|
constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
|
|
1089
1400
|
|
|
1090
|
-
y += (threadIdx.y % ntx) * (
|
|
1401
|
+
y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
|
|
1091
1402
|
|
|
1092
1403
|
const int * x_qs = (const int *) x;
|
|
1093
|
-
const half2 * x_dm = (const half2 *) x_qs +
|
|
1404
|
+
const half2 * x_dm = (const half2 *) x_qs + MMQ_TILE_NE_K*2;
|
|
1094
1405
|
const int * y_qs = (const int *) y + 4;
|
|
1095
1406
|
const half2 * y_ds = (const half2 *) y;
|
|
1096
1407
|
|
|
@@ -1103,7 +1414,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
|
|
|
1103
1414
|
#pragma unroll
|
|
1104
1415
|
for (int n = 0; n < ntx; ++n) {
|
|
1105
1416
|
#pragma unroll
|
|
1106
|
-
for (int k01 = 0; k01 <
|
|
1417
|
+
for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) {
|
|
1107
1418
|
const int k0 = k00 + k01;
|
|
1108
1419
|
|
|
1109
1420
|
load_ldmatrix(((tile_A_8 *) A[n])[k01/QI8_1], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K);
|
|
@@ -1117,7 +1428,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
|
|
|
1117
1428
|
const int i = i0 + n*tile_C::I + tile_C::get_i(2*l);
|
|
1118
1429
|
|
|
1119
1430
|
#pragma unroll
|
|
1120
|
-
for (int k01 = 0; k01 <
|
|
1431
|
+
for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1/2) {
|
|
1121
1432
|
const int k0 = k00 + k01;
|
|
1122
1433
|
|
|
1123
1434
|
const float2 dm = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + k0/(QI8_1/2)]);
|
|
@@ -1140,7 +1451,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
|
|
|
1140
1451
|
}
|
|
1141
1452
|
|
|
1142
1453
|
#pragma unroll
|
|
1143
|
-
for (int k01 = 0; k01 <
|
|
1454
|
+
for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) {
|
|
1144
1455
|
tile_B B[2];
|
|
1145
1456
|
|
|
1146
1457
|
// Here load_generic is faster than load_ldmatrix.
|
|
@@ -1148,7 +1459,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
|
|
|
1148
1459
|
load_generic(B[1], y_qs + j0*MMQ_TILE_Y_K + (k01 + tile_B::J), MMQ_TILE_Y_K);
|
|
1149
1460
|
|
|
1150
1461
|
tile_C Cm[2];
|
|
1151
|
-
if (k01 >=
|
|
1462
|
+
if (k01 >= MMQ_TILE_NE_K * 3/4) {
|
|
1152
1463
|
tile_A A1;
|
|
1153
1464
|
A1.x[0] = 0x01010101;
|
|
1154
1465
|
A1.x[1] = 0x01010101;
|
|
@@ -1166,16 +1477,16 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
|
|
|
1166
1477
|
#pragma unroll
|
|
1167
1478
|
for (int l = 0; l < tile_C::ne; ++l) {
|
|
1168
1479
|
float tmp = Cd[0].x[l]*dA[n][l/2][k01/4 + 0] + Cd[1].x[l]*dA[n][l/2][k01/4 + 1];
|
|
1169
|
-
if (k01 >=
|
|
1480
|
+
if (k01 >= MMQ_TILE_NE_K * 3/4) {
|
|
1170
1481
|
tmp -= Cm[0].x[l]*mA[n][l/2][k01/4 + 0] + Cm[1].x[l]*mA[n][l/2][k01/4 + 1];
|
|
1171
1482
|
}
|
|
1172
|
-
sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp*(k01 <
|
|
1483
|
+
sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp*(k01 < MMQ_TILE_NE_K/2 ? dB[l%2].x : dB[l%2].y);
|
|
1173
1484
|
}
|
|
1174
1485
|
}
|
|
1175
1486
|
}
|
|
1176
1487
|
|
|
1177
1488
|
#pragma unroll
|
|
1178
|
-
for (int k01 = 0; k01 <
|
|
1489
|
+
for (int k01 = 0; k01 < MMQ_TILE_NE_K * 3/4; k01 += QI8_1) {
|
|
1179
1490
|
float2 sB[tile_C::ne/2];
|
|
1180
1491
|
|
|
1181
1492
|
#pragma unroll
|
|
@@ -1198,27 +1509,31 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
|
|
|
1198
1509
|
#else
|
|
1199
1510
|
GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k00);
|
|
1200
1511
|
NO_DEVICE_CODE;
|
|
1201
|
-
#endif //
|
|
1512
|
+
#endif // AMD_MFMA_AVAILABLE
|
|
1202
1513
|
}
|
|
1203
1514
|
|
|
1204
|
-
template <int mmq_y,
|
|
1515
|
+
template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q3_K(
|
|
1205
1516
|
const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
|
|
1517
|
+
constexpr int nwarps = mmq_get_nwarps_device();
|
|
1518
|
+
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
|
1206
1519
|
|
|
1207
|
-
#
|
|
1520
|
+
#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
|
|
1208
1521
|
int * x_qs = (int *) x_tile;
|
|
1209
|
-
float * x_df = (float *) (x_qs +
|
|
1522
|
+
float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
|
|
1210
1523
|
#else
|
|
1211
1524
|
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q3_K, mmq_y);
|
|
1212
1525
|
int * x_qs = (int *) x_tile;
|
|
1213
1526
|
float * x_df = (float *) (x_qs + txs.qs);
|
|
1214
1527
|
int * x_sc = (int *) (x_df + txs.dm);
|
|
1215
|
-
#endif // NEW_MMA_AVAILABLE
|
|
1528
|
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
|
|
1216
1529
|
|
|
1217
|
-
|
|
1530
|
+
constexpr int threads_per_row = MMQ_ITER_K / (4 * QR3_K);
|
|
1531
|
+
constexpr int nrows = warp_size / threads_per_row;
|
|
1532
|
+
const int kqsx = threadIdx.x % threads_per_row;
|
|
1218
1533
|
|
|
1219
1534
|
#pragma unroll
|
|
1220
|
-
for (int i0 = 0; i0 < mmq_y; i0 += nwarps
|
|
1221
|
-
int i = i0 + threadIdx.y
|
|
1535
|
+
for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
|
|
1536
|
+
int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;
|
|
1222
1537
|
|
|
1223
1538
|
if (need_check) {
|
|
1224
1539
|
i = min(i, i_max);
|
|
@@ -1238,17 +1553,18 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
|
|
1238
1553
|
|
|
1239
1554
|
const int x_qs_k = __vsubss4(x_ql_k | x_qh_k, 0x04040404);
|
|
1240
1555
|
|
|
1241
|
-
#
|
|
1556
|
+
#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
|
|
1242
1557
|
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + k] = x_qs_k;
|
|
1243
1558
|
#else
|
|
1244
|
-
x_qs[i*(2*
|
|
1245
|
-
#endif // NEW_MMA_AVAILABLE
|
|
1559
|
+
x_qs[i*(2*MMQ_TILE_NE_K + 1) + k] = x_qs_k;
|
|
1560
|
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
|
|
1246
1561
|
}
|
|
1247
1562
|
}
|
|
1248
1563
|
|
|
1564
|
+
constexpr int rows_per_warp = warp_size / 4;
|
|
1249
1565
|
#pragma unroll
|
|
1250
|
-
for (int i0 = 0; i0 < mmq_y; i0 += nwarps*
|
|
1251
|
-
int i = i0 + threadIdx.y*
|
|
1566
|
+
for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) {
|
|
1567
|
+
int i = i0 + threadIdx.y*rows_per_warp + threadIdx.x/4;
|
|
1252
1568
|
|
|
1253
1569
|
if (need_check) {
|
|
1254
1570
|
i = min(i, i_max);
|
|
@@ -1256,7 +1572,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
|
|
1256
1572
|
|
|
1257
1573
|
const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride;
|
|
1258
1574
|
|
|
1259
|
-
const int ksc = threadIdx.x %
|
|
1575
|
+
const int ksc = threadIdx.x % 4;
|
|
1260
1576
|
|
|
1261
1577
|
const int ksc_low = ksc % (QI3_K/8);
|
|
1262
1578
|
const int shift_low = 4 * (ksc / (QI3_K/8));
|
|
@@ -1268,23 +1584,23 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
|
|
1268
1584
|
|
|
1269
1585
|
const int sc = __vsubss4(sc_low | sc_high, 0x20202020);
|
|
1270
1586
|
|
|
1271
|
-
#
|
|
1587
|
+
#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
|
|
1272
1588
|
const int8_t * sc8 = (const int8_t *) ≻
|
|
1273
1589
|
const float d = bxi->d;
|
|
1274
1590
|
|
|
1275
1591
|
#pragma unroll
|
|
1276
1592
|
for (int l = 0; l < int(sizeof(int)); ++l) {
|
|
1277
|
-
x_df[i*MMQ_MMA_TILE_X_K_Q3_K + sizeof(int)*
|
|
1593
|
+
x_df[i*MMQ_MMA_TILE_X_K_Q3_K + sizeof(int)*ksc + l] = d*sc8[l];
|
|
1278
1594
|
}
|
|
1279
1595
|
#else
|
|
1280
|
-
x_sc[i*(
|
|
1281
|
-
#endif // NEW_MMA_AVAILABLE
|
|
1596
|
+
x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + ksc] = sc;
|
|
1597
|
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
|
|
1282
1598
|
}
|
|
1283
1599
|
|
|
1284
|
-
#
|
|
1600
|
+
#if !(defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE))
|
|
1285
1601
|
#pragma unroll
|
|
1286
|
-
for (int i0 = 0; i0 < mmq_y; i0 += nwarps*
|
|
1287
|
-
int i = (i0 + threadIdx.y*
|
|
1602
|
+
for (int i0 = 0; i0 < mmq_y; i0 += nwarps*warp_size) {
|
|
1603
|
+
int i = (i0 + threadIdx.y*warp_size + threadIdx.x) % mmq_y;
|
|
1288
1604
|
|
|
1289
1605
|
if (need_check) {
|
|
1290
1606
|
i = min(i, i_max);
|
|
@@ -1294,12 +1610,14 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
|
|
1294
1610
|
|
|
1295
1611
|
x_df[i] = bxi->d;
|
|
1296
1612
|
}
|
|
1297
|
-
#endif // NEW_MMA_AVAILABLE
|
|
1613
|
+
#endif // !(defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE))
|
|
1298
1614
|
}
|
|
1299
1615
|
|
|
1300
|
-
template <int mmq_x, int mmq_y
|
|
1616
|
+
template <int mmq_x, int mmq_y>
|
|
1301
1617
|
static __device__ __forceinline__ void vec_dot_q3_K_q8_1_dp4a(
|
|
1302
1618
|
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
|
|
1619
|
+
constexpr int nwarps = mmq_get_nwarps_device();
|
|
1620
|
+
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
|
1303
1621
|
|
|
1304
1622
|
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q3_K, mmq_y);
|
|
1305
1623
|
const int * x_qs = (const int *) x;
|
|
@@ -1309,7 +1627,7 @@ static __device__ __forceinline__ void vec_dot_q3_K_q8_1_dp4a(
|
|
|
1309
1627
|
const float * y_df = (const float *) y;
|
|
1310
1628
|
|
|
1311
1629
|
// #pragma unroll
|
|
1312
|
-
for (int k01 = 0; k01 <
|
|
1630
|
+
for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR3_K*VDR_Q3_K_Q8_1_MMQ) {
|
|
1313
1631
|
const int k0 = k00 + k01;
|
|
1314
1632
|
|
|
1315
1633
|
#pragma unroll
|
|
@@ -1317,13 +1635,13 @@ static __device__ __forceinline__ void vec_dot_q3_K_q8_1_dp4a(
|
|
|
1317
1635
|
const int j = j0 + threadIdx.y;
|
|
1318
1636
|
|
|
1319
1637
|
#pragma unroll
|
|
1320
|
-
for (int i0 = 0; i0 < mmq_y; i0 +=
|
|
1638
|
+
for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
|
|
1321
1639
|
const int i = i0 + threadIdx.x;
|
|
1322
1640
|
|
|
1323
|
-
const int8_t * scales = ((const int8_t *) (x_sc + i*(
|
|
1641
|
+
const int8_t * scales = ((const int8_t *) (x_sc + i*(MMQ_TILE_NE_K/8) + i/8)) + k0/4;
|
|
1324
1642
|
|
|
1325
|
-
sum[j0/nwarps*mmq_y/
|
|
1326
|
-
&x_qs[i*(2*
|
|
1643
|
+
sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q3_K_q8_1_impl_mmq(
|
|
1644
|
+
&x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], scales,
|
|
1327
1645
|
x_df[i], y_df[j*MMQ_TILE_Y_K + k01/QI8_1]);
|
|
1328
1646
|
}
|
|
1329
1647
|
}
|
|
@@ -1340,72 +1658,85 @@ static __device__ __forceinline__ int unpack_scales_q45_K(const int * scales, co
|
|
|
1340
1658
|
((scales[ksc/2] >> (2 * (ksc % 2))) & 0x30303030); // upper 2 bits
|
|
1341
1659
|
}
|
|
1342
1660
|
|
|
1343
|
-
template <int mmq_y,
|
|
1661
|
+
template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q4_K(
|
|
1344
1662
|
const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
|
|
1663
|
+
constexpr int nwarps = mmq_get_nwarps_device();
|
|
1664
|
+
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
|
1345
1665
|
|
|
1346
|
-
#
|
|
1666
|
+
#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
|
|
1347
1667
|
int * x_qs = (int *) x_tile;
|
|
1348
|
-
half2 * x_dm = (half2 *) (x_qs + 2*
|
|
1668
|
+
half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K);
|
|
1349
1669
|
#else
|
|
1350
1670
|
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_K, mmq_y);
|
|
1351
1671
|
int * x_qs = (int *) x_tile;
|
|
1352
1672
|
half2 * x_dm = (half2 *) (x_qs + txs.qs);
|
|
1353
1673
|
int * x_sc = (int *) (x_dm + txs.dm);
|
|
1354
|
-
#endif // NEW_MMA_AVAILABLE
|
|
1674
|
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
|
|
1675
|
+
|
|
1676
|
+
constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_K);
|
|
1677
|
+
constexpr int nrows = warp_size / threads_per_row;
|
|
1678
|
+
const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
|
|
1355
1679
|
|
|
1356
1680
|
#pragma unroll
|
|
1357
|
-
for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
|
|
1358
|
-
int i = i0 + threadIdx.y;
|
|
1681
|
+
for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
|
|
1682
|
+
int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);
|
|
1359
1683
|
|
|
1360
1684
|
if (need_check) {
|
|
1361
1685
|
i = min(i, i_max);
|
|
1362
1686
|
}
|
|
1363
1687
|
|
|
1364
1688
|
const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride;
|
|
1365
|
-
const int qs0 = get_int_b4(bxi->qs,
|
|
1689
|
+
const int qs0 = get_int_b4(bxi->qs, txi);
|
|
1366
1690
|
|
|
1367
|
-
#
|
|
1368
|
-
x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(
|
|
1369
|
-
x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(
|
|
1691
|
+
#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
|
|
1692
|
+
x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(txi/8) + txi % 8 + 0] = (qs0 >> 0) & 0x0F0F0F0F;
|
|
1693
|
+
x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(txi/8) + txi % 8 + 8] = (qs0 >> 4) & 0x0F0F0F0F;
|
|
1370
1694
|
#else
|
|
1371
|
-
x_qs[i*(
|
|
1372
|
-
#endif // NEW_MMA_AVAILABLE
|
|
1695
|
+
x_qs[i*(MMQ_TILE_NE_K + 1) + txi] = qs0;
|
|
1696
|
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
|
|
1373
1697
|
}
|
|
1374
1698
|
|
|
1375
|
-
#
|
|
1376
|
-
|
|
1699
|
+
#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
|
|
1700
|
+
constexpr int rows_per_warp = warp_size / 2;
|
|
1377
1701
|
#pragma unroll
|
|
1378
|
-
for (int i0 = 0; i0 < mmq_y; i0 += nwarps*
|
|
1379
|
-
|
|
1380
|
-
|
|
1381
|
-
|
|
1382
|
-
|
|
1383
|
-
|
|
1702
|
+
for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) {
|
|
1703
|
+
#if defined(AMD_MFMA_AVAILABLE)
|
|
1704
|
+
// Need if on AMD instead of % because warp_size == 64
|
|
1705
|
+
// This causes double work and throughput loss (MI300X)
|
|
1706
|
+
// H100 loses about 100 t/s with 'if' condition over '%'
|
|
1707
|
+
int i = i0 + threadIdx.y*rows_per_warp + threadIdx.x/2;
|
|
1708
|
+
if (i < mmq_y) {
|
|
1709
|
+
#else
|
|
1710
|
+
int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/2) % mmq_y;
|
|
1711
|
+
{
|
|
1712
|
+
#endif // defined(AMD_MFMA_AVAILABLE)
|
|
1713
|
+
if (need_check) {
|
|
1714
|
+
i = min(i, i_max);
|
|
1715
|
+
}
|
|
1384
1716
|
|
|
1385
|
-
|
|
1717
|
+
const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride;
|
|
1386
1718
|
|
|
1387
|
-
|
|
1388
|
-
|
|
1719
|
+
const int * scales = (const int *) bxi->scales;
|
|
1720
|
+
const int ksc = threadIdx.x % 2;
|
|
1389
1721
|
|
|
1390
|
-
|
|
1391
|
-
|
|
1722
|
+
const int sc32 = unpack_scales_q45_K(scales, ksc + 0);
|
|
1723
|
+
const int m32 = unpack_scales_q45_K(scales, ksc + 2);
|
|
1392
1724
|
|
|
1393
|
-
|
|
1394
|
-
|
|
1725
|
+
const uint8_t * sc8 = (const uint8_t *) &sc32;
|
|
1726
|
+
const uint8_t * m8 = (const uint8_t *) &m32;
|
|
1395
1727
|
|
|
1396
|
-
|
|
1728
|
+
const half2 dm = bxi->dm * make_half2(1.0f, -1.0f);
|
|
1397
1729
|
|
|
1398
|
-
#pragma unroll
|
|
1399
|
-
|
|
1400
|
-
|
|
1730
|
+
#pragma unroll
|
|
1731
|
+
for (int l = 0; l < sizeof(int); ++l) {
|
|
1732
|
+
x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + sizeof(int)*ksc + l] = dm*make_half2(sc8[l], m8[l]);
|
|
1733
|
+
}
|
|
1401
1734
|
}
|
|
1402
1735
|
}
|
|
1403
|
-
|
|
1404
1736
|
#else
|
|
1405
|
-
|
|
1406
1737
|
#pragma unroll
|
|
1407
|
-
for (int i0 = 0; i0 < mmq_y; i0 += nwarps*
|
|
1408
|
-
int i = (i0 + threadIdx.y*
|
|
1738
|
+
for (int i0 = 0; i0 < mmq_y; i0 += nwarps*warp_size) {
|
|
1739
|
+
int i = (i0 + threadIdx.y*warp_size + threadIdx.x) % mmq_y;
|
|
1409
1740
|
|
|
1410
1741
|
if (need_check) {
|
|
1411
1742
|
i = min(i, i_max);
|
|
@@ -1415,30 +1746,32 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
|
|
1415
1746
|
|
|
1416
1747
|
x_dm[i] = bxi->dm;
|
|
1417
1748
|
}
|
|
1418
|
-
|
|
1749
|
+
constexpr int rows_per_warp = warp_size / 4;
|
|
1419
1750
|
#pragma unroll
|
|
1420
|
-
for (int i0 = 0; i0 < mmq_y; i0 += nwarps
|
|
1421
|
-
int i = (i0 + threadIdx.y
|
|
1751
|
+
for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) {
|
|
1752
|
+
int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/(MMQ_TILE_NE_K/8)) % mmq_y;
|
|
1422
1753
|
|
|
1423
1754
|
if (need_check) {
|
|
1424
1755
|
i = min(i, i_max);
|
|
1425
1756
|
}
|
|
1426
1757
|
|
|
1427
|
-
const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride + (threadIdx.x % (
|
|
1758
|
+
const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride + (threadIdx.x % (MMQ_TILE_NE_K/8)) / (QI4_K/8);
|
|
1428
1759
|
|
|
1429
1760
|
const int * scales = (const int *) bxi->scales;
|
|
1430
1761
|
|
|
1431
|
-
const int ksc = threadIdx.x % (
|
|
1762
|
+
const int ksc = threadIdx.x % (MMQ_TILE_NE_K/8);
|
|
1432
1763
|
const int scales8 = unpack_scales_q45_K(scales, ksc);
|
|
1433
1764
|
|
|
1434
|
-
x_sc[i*(
|
|
1765
|
+
x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + ksc] = scales8;
|
|
1435
1766
|
}
|
|
1436
|
-
#endif // NEW_MMA_AVAILABLE
|
|
1767
|
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
|
|
1437
1768
|
}
|
|
1438
1769
|
|
|
1439
|
-
template <int mmq_x, int mmq_y
|
|
1770
|
+
template <int mmq_x, int mmq_y>
|
|
1440
1771
|
static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a(
|
|
1441
1772
|
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
|
|
1773
|
+
constexpr int nwarps = mmq_get_nwarps_device();
|
|
1774
|
+
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
|
1442
1775
|
|
|
1443
1776
|
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_K, mmq_y);
|
|
1444
1777
|
const int * x_qs = (const int *) x;
|
|
@@ -1448,7 +1781,7 @@ static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a(
|
|
|
1448
1781
|
const half2 * y_ds = (const half2 *) y;
|
|
1449
1782
|
|
|
1450
1783
|
// #pragma unroll
|
|
1451
|
-
for (int k01 = 0; k01 <
|
|
1784
|
+
for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR4_K*VDR_Q4_K_Q8_1_MMQ) {
|
|
1452
1785
|
const int k0 = k00 + k01;
|
|
1453
1786
|
|
|
1454
1787
|
#pragma unroll
|
|
@@ -1456,97 +1789,110 @@ static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a(
|
|
|
1456
1789
|
const int j = j0 + threadIdx.y;
|
|
1457
1790
|
|
|
1458
1791
|
#pragma unroll
|
|
1459
|
-
for (int i0 = 0; i0 < mmq_y; i0 +=
|
|
1792
|
+
for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
|
|
1460
1793
|
const int i = i0 + threadIdx.x;
|
|
1461
1794
|
|
|
1462
|
-
const uint8_t * sc = (const uint8_t *) &x_sc[i * (
|
|
1795
|
+
const uint8_t * sc = (const uint8_t *) &x_sc[i * (MMQ_TILE_NE_K/8) + i/8 + k0/32] + 2*(k01/16);
|
|
1463
1796
|
|
|
1464
|
-
sum[j0/nwarps*mmq_y/
|
|
1465
|
-
&x_qs[i*(
|
|
1797
|
+
sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q4_K_q8_1_impl_mmq(
|
|
1798
|
+
&x_qs[i*(MMQ_TILE_NE_K + 1) + k0/2], &y_qs[j*MMQ_TILE_Y_K + k01], sc, sc+8,
|
|
1466
1799
|
x_dm[i], &y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
|
|
1467
1800
|
}
|
|
1468
1801
|
}
|
|
1469
1802
|
}
|
|
1470
1803
|
}
|
|
1471
1804
|
|
|
1472
|
-
template <int mmq_y,
|
|
1805
|
+
template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q5_K(
|
|
1473
1806
|
const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
|
|
1807
|
+
constexpr int nwarps = mmq_get_nwarps_device();
|
|
1808
|
+
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
|
1474
1809
|
|
|
1475
|
-
#
|
|
1810
|
+
#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
|
|
1476
1811
|
int * x_qs = (int *) x_tile;
|
|
1477
|
-
half2 * x_dm = (half2 *) (x_qs +
|
|
1812
|
+
half2 * x_dm = (half2 *) (x_qs + MMQ_TILE_NE_K*2);
|
|
1478
1813
|
#else
|
|
1479
1814
|
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_K, mmq_y);
|
|
1480
1815
|
int * x_qs = (int *) x_tile;
|
|
1481
1816
|
half2 * x_dm = (half2 *) (x_qs + txs.qs);
|
|
1482
1817
|
int * x_sc = (int *) (x_dm + txs.dm);
|
|
1483
|
-
#endif // NEW_MMA_AVAILABLE
|
|
1818
|
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
|
|
1819
|
+
|
|
1820
|
+
constexpr int threads_per_row = MMQ_ITER_K / (4 * QR5_K);
|
|
1821
|
+
constexpr int nrows = warp_size / threads_per_row;
|
|
1822
|
+
const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
|
|
1484
1823
|
|
|
1485
1824
|
#pragma unroll
|
|
1486
|
-
for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
|
|
1487
|
-
int i = i0 + threadIdx.y;
|
|
1825
|
+
for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
|
|
1826
|
+
int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);
|
|
1488
1827
|
|
|
1489
1828
|
if (need_check) {
|
|
1490
1829
|
i = min(i, i_max);
|
|
1491
1830
|
}
|
|
1492
1831
|
|
|
1493
1832
|
const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride;
|
|
1494
|
-
const int ky = QR5_K*
|
|
1833
|
+
const int ky = QR5_K*txi;
|
|
1495
1834
|
|
|
1496
|
-
const int ql = get_int_b4(bxi->qs,
|
|
1835
|
+
const int ql = get_int_b4(bxi->qs, txi);
|
|
1497
1836
|
const int ql0 = (ql >> 0) & 0x0F0F0F0F;
|
|
1498
1837
|
const int ql1 = (ql >> 4) & 0x0F0F0F0F;
|
|
1499
1838
|
|
|
1500
|
-
const int qh = get_int_b4(bxi->qh,
|
|
1501
|
-
const int qh0 = ((qh >> (2 * (
|
|
1502
|
-
const int qh1 = ((qh >> (2 * (
|
|
1839
|
+
const int qh = get_int_b4(bxi->qh, txi % (QI5_K/4));
|
|
1840
|
+
const int qh0 = ((qh >> (2 * (txi / (QI5_K/4)) + 0)) << 4) & 0x10101010;
|
|
1841
|
+
const int qh1 = ((qh >> (2 * (txi / (QI5_K/4)) + 1)) << 4) & 0x10101010;
|
|
1503
1842
|
|
|
1504
|
-
const int kq0 = ky - ky % (QI5_K/2) +
|
|
1505
|
-
const int kq1 = ky - ky % (QI5_K/2) +
|
|
1843
|
+
const int kq0 = ky - ky % (QI5_K/2) + txi % (QI5_K/4) + 0;
|
|
1844
|
+
const int kq1 = ky - ky % (QI5_K/2) + txi % (QI5_K/4) + QI5_K/4;
|
|
1506
1845
|
|
|
1507
|
-
#
|
|
1846
|
+
#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
|
|
1508
1847
|
x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kq0] = ql0 | qh0;
|
|
1509
1848
|
x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kq1] = ql1 | qh1;
|
|
1510
1849
|
#else
|
|
1511
|
-
x_qs[i*(2*
|
|
1512
|
-
x_qs[i*(2*
|
|
1513
|
-
#endif // NEW_MMA_AVAILABLE
|
|
1850
|
+
x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq0] = ql0 | qh0;
|
|
1851
|
+
x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq1] = ql1 | qh1;
|
|
1852
|
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
|
|
1514
1853
|
}
|
|
1515
1854
|
|
|
1516
|
-
#
|
|
1517
|
-
|
|
1855
|
+
#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
|
|
1856
|
+
constexpr int rows_per_warp = warp_size / 2;
|
|
1518
1857
|
#pragma unroll
|
|
1519
|
-
for (int i0 = 0; i0 < mmq_y; i0 += nwarps*
|
|
1520
|
-
|
|
1521
|
-
|
|
1522
|
-
|
|
1523
|
-
|
|
1524
|
-
|
|
1858
|
+
for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) {
|
|
1859
|
+
#if defined(AMD_MFMA_AVAILABLE)
|
|
1860
|
+
// Need if on AMD instead of % because warp_size == 64
|
|
1861
|
+
// This causes double work and throughput loss (MI300X)
|
|
1862
|
+
// H100 loses about 100 t/s with 'if' condition over '%'
|
|
1863
|
+
int i = i0 + threadIdx.y*rows_per_warp + threadIdx.x/2;
|
|
1864
|
+
if (i < mmq_y) {
|
|
1865
|
+
#else
|
|
1866
|
+
int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/2) % mmq_y;
|
|
1867
|
+
{
|
|
1868
|
+
#endif // defined(AMD_MFMA_AVAILABLE)
|
|
1869
|
+
if (need_check) {
|
|
1870
|
+
i = min(i, i_max);
|
|
1871
|
+
}
|
|
1525
1872
|
|
|
1526
|
-
|
|
1873
|
+
const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride;
|
|
1527
1874
|
|
|
1528
|
-
|
|
1529
|
-
|
|
1875
|
+
const int * scales = (const int *) bxi->scales;
|
|
1876
|
+
const int ksc = threadIdx.x % 2;
|
|
1530
1877
|
|
|
1531
|
-
|
|
1532
|
-
|
|
1878
|
+
const int sc32 = unpack_scales_q45_K(scales, ksc + 0);
|
|
1879
|
+
const int m32 = unpack_scales_q45_K(scales, ksc + 2);
|
|
1533
1880
|
|
|
1534
|
-
|
|
1535
|
-
|
|
1881
|
+
const uint8_t * sc8 = (const uint8_t *) &sc32;
|
|
1882
|
+
const uint8_t * m8 = (const uint8_t *) &m32;
|
|
1536
1883
|
|
|
1537
|
-
|
|
1884
|
+
const half2 dm = bxi->dm * make_half2(1.0f, -1.0f);
|
|
1538
1885
|
|
|
1539
1886
|
#pragma unroll
|
|
1540
|
-
|
|
1541
|
-
|
|
1887
|
+
for (int l = 0; l < int(sizeof(int)); ++l) {
|
|
1888
|
+
x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + sizeof(int)*ksc + l] = dm*make_half2(sc8[l], m8[l]);
|
|
1889
|
+
}
|
|
1542
1890
|
}
|
|
1543
1891
|
}
|
|
1544
|
-
|
|
1545
1892
|
#else
|
|
1546
|
-
|
|
1547
1893
|
#pragma unroll
|
|
1548
|
-
for (int i0 = 0; i0 < mmq_y; i0 += nwarps*
|
|
1549
|
-
int i = (i0 + threadIdx.y*
|
|
1894
|
+
for (int i0 = 0; i0 < mmq_y; i0 += nwarps*warp_size) {
|
|
1895
|
+
int i = (i0 + threadIdx.y*warp_size + threadIdx.x) % mmq_y;
|
|
1550
1896
|
|
|
1551
1897
|
if (need_check) {
|
|
1552
1898
|
i = min(i, i_max);
|
|
@@ -1557,9 +1903,10 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
|
|
1557
1903
|
x_dm[i] = bxi->dm;
|
|
1558
1904
|
}
|
|
1559
1905
|
|
|
1906
|
+
constexpr int rows_per_warp = warp_size / 4;
|
|
1560
1907
|
#pragma unroll
|
|
1561
|
-
for (int i0 = 0; i0 < mmq_y; i0 += nwarps*
|
|
1562
|
-
int i = (i0 + threadIdx.y*
|
|
1908
|
+
for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) {
|
|
1909
|
+
int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/(MMQ_TILE_NE_K/8)) % mmq_y;
|
|
1563
1910
|
|
|
1564
1911
|
if (need_check) {
|
|
1565
1912
|
i = min(i, i_max);
|
|
@@ -1569,17 +1916,19 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
|
|
1569
1916
|
|
|
1570
1917
|
const int * scales = (const int *) bxi->scales;
|
|
1571
1918
|
|
|
1572
|
-
const int ksc = threadIdx.x % (
|
|
1919
|
+
const int ksc = threadIdx.x % (MMQ_TILE_NE_K/8);
|
|
1573
1920
|
const int scales8 = unpack_scales_q45_K(scales, ksc);
|
|
1574
1921
|
|
|
1575
|
-
x_sc[i*(
|
|
1922
|
+
x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + ksc] = scales8;
|
|
1576
1923
|
}
|
|
1577
|
-
#endif // NEW_MMA_AVAILABLE
|
|
1924
|
+
#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
|
|
1578
1925
|
}
|
|
1579
1926
|
|
|
1580
|
-
template <int mmq_x, int mmq_y
|
|
1927
|
+
template <int mmq_x, int mmq_y>
|
|
1581
1928
|
static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a(
|
|
1582
1929
|
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
|
|
1930
|
+
constexpr int nwarps = mmq_get_nwarps_device();
|
|
1931
|
+
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
|
1583
1932
|
|
|
1584
1933
|
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_K, mmq_y);
|
|
1585
1934
|
const int * x_qs = (const int *) x;
|
|
@@ -1589,7 +1938,7 @@ static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a(
|
|
|
1589
1938
|
const half2 * y_ds = (const half2 *) y;
|
|
1590
1939
|
|
|
1591
1940
|
// #pragma unroll
|
|
1592
|
-
for (int k01 = 0; k01 <
|
|
1941
|
+
for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR5_K*VDR_Q5_K_Q8_1_MMQ) {
|
|
1593
1942
|
const int k0 = k00 + k01;
|
|
1594
1943
|
|
|
1595
1944
|
#pragma unroll
|
|
@@ -1597,36 +1946,42 @@ static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a(
             const int j = j0 + threadIdx.y;
 
 #pragma unroll
-            for (int i0 = 0; i0 < mmq_y; i0 +=
+            for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
                 const int i = i0 + threadIdx.x;
 
-                const uint8_t * sc = ((const uint8_t *) &x_sc[i * (
+                const uint8_t * sc = ((const uint8_t *) &x_sc[i * (MMQ_TILE_NE_K/8) + i/8 + k00/32]) + 2*(k01/16);
 
-                sum[j0/nwarps*mmq_y/
-                    &x_qs[i*(QR5_K*
+                sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q5_K_q8_1_impl_mmq(
+                    &x_qs[i*(QR5_K*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], sc, sc+8,
                     x_dm[i], &y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
             }
         }
     }
 }
 
-template <int mmq_y,
+template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q6_K(
     const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+    constexpr int nwarps = mmq_get_nwarps_device();
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
-#
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
     int * x_qs = (int *) x_tile;
-    float * x_df = (float *) (x_qs +
-    int * x_sc = (int *) (x_df +
+    float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
+    int * x_sc = (int *) (x_df + MMQ_TILE_NE_K/QI6_K);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q6_K, mmq_y);
     int * x_qs = (int *) x_tile;
     float * x_df = (float *) (x_qs + txs.qs);
     int * x_sc = (int *) (x_df + txs.dm);
-#endif // NEW_MMA_AVAILABLE
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+
+    constexpr int threads_per_row = MMQ_ITER_K / (4 * QR6_K);
+    constexpr int nrows = warp_size / threads_per_row;
+    const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
 
 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
-        int i = i0 + threadIdx.y;
+    for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
+        int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);
 
         if (need_check) {
             i = min(i, i_max);
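Editorial note, not part of the diff: the rewritten load-tile loops above stop hard-coding a 32-wide warp and instead derive `threads_per_row`, `nrows` and `txi` from the physical warp size, so one warp covers several rows per load step. A minimal standalone sketch of that index mapping, using assumed example sizes only:

```cpp
// Standalone sketch (assumptions only): simulate the threads_per_row / nrows
// decomposition used by the new load_tiles_* loops. warp_size = 64 and
// threads_per_row = 16 are hypothetical values picked for illustration.
#include <cstdio>

int main() {
    const int warp_size       = 64;                          // physical warp size (assumed)
    const int threads_per_row = 16;                          // MMQ_ITER_K / (4*QR) for some type (assumed)
    const int nrows           = warp_size / threads_per_row; // rows handled per warp per step

    for (int x = 0; x < warp_size; ++x) {
        const int txi = warp_size > threads_per_row ? x % threads_per_row : x; // column within the row
        const int row = x / threads_per_row;                 // added to i0 + threadIdx.y*nrows
        printf("lane %2d -> row offset %d, txi %2d (nrows = %d)\n", x, row, txi, nrows);
    }
    return 0;
}
```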
@@ -1634,67 +1989,67 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
         const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride;
 
-        const int ql = get_int_b2(bxi->ql,
+        const int ql = get_int_b2(bxi->ql, txi);
         const int ql0 = (ql >> 0) & 0x0F0F0F0F;
         const int ql1 = (ql >> 4) & 0x0F0F0F0F;
 
-        const int qh = get_int_b2(bxi->qh, (QI6_K/4) * (
-        const int qh0 = ((qh >> ((
-        const int qh1 = (qh >> ((
+        const int qh = get_int_b2(bxi->qh, (QI6_K/4) * (txi / (QI6_K/2)) + txi % (QI6_K/4));
+        const int qh0 = ((qh >> ((txi & 0x08) >> 2)) << 4) & 0x30303030;
+        const int qh1 = (qh >> ((txi & 0x08) >> 2)) & 0x30303030;
 
-        const int kq0 = 2*
-        const int kq1 = 2*
+        const int kq0 = 2*txi - txi % (QI6_K/2) + 0;
+        const int kq1 = 2*txi - txi % (QI6_K/2) + QI6_K/2;
 
-#
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
         x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq0] = __vsubss4(ql0 | qh0, 0x20202020);
         x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq1] = __vsubss4(ql1 | qh1, 0x20202020);
 #else
-        x_qs[i*(2*
-        x_qs[i*(2*
-#endif // NEW_MMA_AVAILABLE
+        x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq0] = __vsubss4(ql0 | qh0, 0x20202020);
+        x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq1] = __vsubss4(ql1 | qh1, 0x20202020);
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
     }
 
-    const int blocks_per_tile_x_row = WARP_SIZE / QI6_K; // == 1 if QK_K == 256
-    const int kbxd = threadIdx.x % blocks_per_tile_x_row; // == 0 if QK_K == 256
-
 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps
-        int i = (i0 + threadIdx.y
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps*warp_size) {
+        int i = (i0 + threadIdx.y*warp_size + threadIdx.x) % mmq_y;
 
         if (need_check) {
             i = min(i, i_max);
         }
 
-        const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride
+        const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride;
 
-#
-        x_df[i*MMQ_MMA_TILE_X_K_Q6_K
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+        x_df[i*MMQ_MMA_TILE_X_K_Q6_K] = bxi->d;
 #else
-        x_df[i*(
-#endif // NEW_MMA_AVAILABLE
+        x_df[i*(MMQ_TILE_NE_K/QI6_K) + i/QI6_K] = bxi->d;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
     }
 
+    constexpr int rows_per_warp = warp_size / 4;
 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps
-        int i = (i0 + threadIdx.y
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) {
+        int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/(MMQ_TILE_NE_K/8)) % mmq_y;
 
         if (need_check) {
             i = min(i, i_max);
         }
 
-        const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + (threadIdx.x % (
+        const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + (threadIdx.x % (MMQ_TILE_NE_K/8)) / 4;
 
-#
-        x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + threadIdx.x
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+        x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + threadIdx.x%4] = get_int_b2(bxi->scales, threadIdx.x % (MMQ_TILE_NE_K/8));
 #else
-        x_sc[i*(
-#endif // NEW_MMA_AVAILABLE
+        x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + threadIdx.x%(MMQ_TILE_NE_K/8)] = get_int_b2(bxi->scales, threadIdx.x%(QI6_K/8));
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
     }
 }
 
-template <int mmq_x, int mmq_y
+template <int mmq_x, int mmq_y>
 static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
+    constexpr int nwarps = mmq_get_nwarps_device();
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q6_K, mmq_y);
     const int * x_qs = (const int *) x;
@@ -1704,7 +2059,7 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a(
     const float * y_df = (const float *) y;
 
 // #pragma unroll
-    for (int k01 = 0; k01 <
+    for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR6_K*VDR_Q6_K_Q8_1_MMQ) {
         const int k0 = k00 + k01;
 
 #pragma unroll
@@ -1712,23 +2067,74 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a(
             const int j = j0 + threadIdx.y;
 
 #pragma unroll
-            for (int i0 = 0; i0 < mmq_y; i0 +=
+            for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
                 const int i = i0 + threadIdx.x;
 
-                const int8_t * sc = ((const int8_t *) &x_sc[i * (
+                const int8_t * sc = ((const int8_t *) &x_sc[i * (MMQ_TILE_NE_K/8) + i/8 + k0/16]);
 
-                sum[j0/nwarps*mmq_y/
-                    &x_qs[i*(QR6_K*
-                    x_df[i*(
+                sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q6_K_q8_1_impl_mmq(
+                    &x_qs[i*(QR6_K*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], sc,
+                    x_df[i*(MMQ_TILE_NE_K/QI6_K) + i/QI6_K], &y_df[j*MMQ_TILE_Y_K + k01/QI8_1]);
             }
         }
     }
 }
 
-template <int mmq_x, int mmq_y
+template <int mmq_x, int mmq_y>
 static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
-#
+#if defined(AMD_MFMA_AVAILABLE)
+    typedef tile<16, 8, int> tile_A;
+    typedef tile<16, 8, int> tile_B;
+    typedef tile<16, 16, int> tile_C;
+    typedef tile<64, 2, int> tile_load;
+
+    constexpr int granularity = mmq_get_granularity_device(mmq_x);
+    constexpr int rows_per_warp = granularity;
+    constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
+
+    y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
+
+    const int * x_qs = (const int *) x;
+    const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2;
+    const int * x_sc = (const int *) x_df + MMQ_TILE_NE_K/QI6_K;
+    const int * y_qs = (const int *) y + 4;
+    const float * y_df = (const float *) y;
+
+    const int i0 = (threadIdx.y / ntx) * rows_per_warp;
+
+    for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) {
+        const int k0 = k00 + k01;
+
+        tile_A A[ntx];
+#pragma unroll
+        for (int n = 0; n < ntx; ++n) {
+            load_generic(((tile_load *) A)[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + k0, MMQ_MMA_TILE_X_K_Q6_K);
+        }
+
+#pragma unroll
+        for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
+            tile_B B[1];
+            load_generic(((tile_load *) B)[0], y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
+
+            const int j = j0 + tile_C::get_j(0);
+            const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1] / 2;
+
+#pragma unroll
+            for (int n = 0; n < ntx; ++n) {
+                tile_C C;
+                mma(C, A[n], B[0]);
+
+#pragma unroll
+                for (int l = 0; l < tile_C::ne; ++l) {
+                    const int i = i0 + n*tile_C::I + tile_C::get_i(l);
+                    const int8_t * sc = (const int8_t *) (x_sc + i*MMQ_MMA_TILE_X_K_Q6_K + k00/16);
+                    sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l] * sc[k01/4] * x_df[i*MMQ_MMA_TILE_X_K_Q6_K] * dB;
+                }
+            }
+        }
+    }
+#elif defined(NEW_MMA_AVAILABLE)
 
     typedef tile<16, 4, int> tile_A;
     typedef tile< 8, 4, int> tile_B;
@@ -1738,11 +2144,11 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
     constexpr int rows_per_warp = 2 * granularity;
     constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
 
-    y += (threadIdx.y % ntx) * (
+    y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
 
     const int * x_qs = (const int *) x;
-    const float * x_df = (const float *) x_qs +
-    const int * x_sc = (const int *) x_df +
+    const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2;
+    const int * x_sc = (const int *) x_df + MMQ_TILE_NE_K/QI6_K;
     const int * y_qs = (const int *) y + 4;
     const float * y_df = (const float *) y;
 
@@ -1755,7 +2161,7 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
 #pragma unroll
     for (int n = 0; n < ntx; ++n) {
 #pragma unroll
-        for (int k01 = 0; k01 <
+        for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 8) {
             const int k0 = k00 + k01;
 
             load_ldmatrix(A[n][k01/4 + 0], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + 0), MMQ_MMA_TILE_X_K_Q6_K);
@@ -1763,7 +2169,7 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
     }
 
 #pragma unroll
-    for (int k01 = 0; k01 <
+    for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 16) {
         const int k0 = k00 + k01;
 
 #pragma unroll
@@ -1793,7 +2199,7 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
         float tmp[ntx][tile_C::ne] = {{0.0f}};
 
 #pragma unroll
-        for (int k01 = 0; k01 <
+        for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 8) {
             tile_B B[2];
             float dB[tile_C::ne/2];
 
@@ -1832,27 +2238,32 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
 #else
     GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k00);
     NO_DEVICE_CODE;
-#endif //
+#endif // AMD_MFMA_AVAILABLE
 }
 
-template <int mmq_y,
+template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_nl(
     const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+    constexpr int nwarps = mmq_get_nwarps_device();
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
-#
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
    int * x_qs = (int *) x_tile;
-    float * x_df = (float *) (x_qs +
+    float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_NL, mmq_y);
     int * x_qs = (int *) x_tile;
     float * x_df = (float *) (x_qs + txs.qs);
-#endif // NEW_MMA_AVAILABLE
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
 
-
-
+    constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_NL);
+    constexpr int nrows = warp_size / threads_per_row;
+    const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
+    const int kbx = txi / QI4_NL;
+    const int kqsx = txi % QI4_NL;
 
 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
-        int i = i0 + threadIdx.y;
+    for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
+        int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);
 
         if (need_check) {
             i = min(i, i_max);
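Editorial note, not part of the diff: in the rewritten `load_tiles_iq4_nl`, each thread's in-row index `txi` is split into a block index `kbx` and an intra-block offset `kqsx`, and the two unpacked nibble halves of each int are stored `QI4_NL` ints apart from `k0 = kbx*(2*QI4_NL) + kqsx`. A small sketch of that addressing, where the constant values are assumptions for illustration only:

```cpp
// Standalone sketch (assumed constants): where the low and high nibble halves
// of an iq4_nl quant int land in the x_qs tile after the refactor.
#include <cstdio>

int main() {
    const int QI4_NL = 4;           // ints of packed quants per block (assumed value)
    const int threads_per_row = 8;  // MMQ_ITER_K / (4 * QR4_NL) for some config (assumed)

    for (int txi = 0; txi < threads_per_row; ++txi) {
        const int kbx  = txi / QI4_NL;              // which block of the row
        const int kqsx = txi % QI4_NL;              // which int inside the block
        const int k0   = kbx * (2 * QI4_NL) + kqsx; // base column in the tile
        printf("txi=%d -> kbx=%d kqsx=%d: v.x at k0=%d, v.y at k0+QI4_NL=%d\n",
               txi, kbx, kqsx, k0, k0 + QI4_NL);
    }
    return 0;
}
```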
@@ -1862,22 +2273,24 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
         const int aux_q4 = get_int_b2(bxi->qs, kqsx);
         const int2 v = get_int_from_table_16(aux_q4);
-        const int k0 =
-
-
-        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 +
+        const int k0 = kbx * (2 * QI4_NL) + kqsx;
+
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x;
+        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + QI4_NL] = v.y;
 #else
-        x_qs[i*(2*
-        x_qs[i*(2*
-#endif // NEW_MMA_AVAILABLE
+        x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 0] = v.x;
+        x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + QI4_NL] = v.y;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
     }
 
-
+    constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI4_NL;
+    constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row;
     const int kbxd = threadIdx.x % blocks_per_tile_x_row;
 
 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps *
-        int i = i0 + threadIdx.y *
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) {
+        int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row;
 
         if (need_check) {
             i = min(i, i_max);
@@ -1885,31 +2298,35 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
         const block_iq4_nl * bxi = (const block_iq4_nl *) x + kbx0 + i*stride + kbxd;
 
-#
-        x_df[i*MMQ_MMA_TILE_X_K_Q8_0
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+        x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = __half2float(bxi->d);
 #else
-        x_df[i*(
-#endif // NEW_MMA_AVAILABLE
+        x_df[i*(MMQ_TILE_NE_K/QI4_NL) + i/QI4_NL + kbxd] = __half2float(bxi->d);
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
     }
 }
 
-template <int mmq_y,
+template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_xxs(
     const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+    constexpr int nwarps = mmq_get_nwarps_device();
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
-#
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
     int * x_qs = (int *) x_tile;
-    float * x_df = (float *) (x_qs +
+    float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_XXS, mmq_y);
     int * x_qs = (int *) x_tile;
     float * x_df = (float *) (x_qs + txs.qs);
-#endif // NEW_MMA_AVAILABLE
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
 
-
+    constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_XXS)) / 2;
+    constexpr int nrows = warp_size / threads_per_row;
+    const int kqsx = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
 
 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps *
-        int i = i0 + threadIdx.y*
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) {
+        int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;
 
         if (need_check) {
             i = min(i, i_max);
@@ -1932,42 +2349,46 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
             const int signs1 = __vcmpne4(((signs_packed & 0x30) << 3) | ((signs_packed & 0xC0) << 17), 0x00000000);
             const int grid1 = __vsub4(grid_pos[1] ^ signs1, signs1);
 
-#
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
             x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid0;
             x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 1)] = grid1;
 #else
-            x_qs[i*(2*
-            x_qs[i*(2*
-#endif // NEW_MMA_AVAILABLE
+            x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid0;
+            x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid1;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
         }
 
         const int ls = aux32 >> 28;
         const float d = bxi->d;
-#
-        x_df[i*MMQ_MMA_TILE_X_K_Q8_0
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+        x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/4;
 #else
-        x_df[i*(
-#endif // NEW_MMA_AVAILABLE
+        x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = (ls*d + d/2)/4;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
     }
 }
 
-template <int mmq_y,
+template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_xs(
     const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+    constexpr int nwarps = mmq_get_nwarps_device();
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
-#
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
     int * x_qs = (int *) x_tile;
-    float * x_df = (float *) (x_qs +
+    float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
 #else
     constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16;
     int * x_qs = (int *) x_tile;
     float * x_df = (float *) (x_qs + txs.qs);
-#endif // NEW_MMA_AVAILABLE
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
 
-
+    constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_XS)) / 2;
+    constexpr int nrows = warp_size / threads_per_row;
+    const int kqsx = threadIdx.x % threads_per_row;
 
 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps *
-        int i = i0 + threadIdx.y*
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) {
+        int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;
 
         if (need_check) {
             i = min(i, i_max);
@@ -1986,44 +2407,48 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
             const int grid_l = __vsub4(grid_pos[0] ^ signs[0], signs[0]);
             const int grid_h = __vsub4(grid_pos[1] ^ signs[1], signs[1]);
 
-#
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
             x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l;
             x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 1)] = grid_h;
 #else
-            x_qs[i*(2*
-            x_qs[i*(2*
-#endif // NEW_MMA_AVAILABLE
+            x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid_l;
+            x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid_h;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
         }
 
         const int ls = bxi->scales[kqsx];
         const float d = bxi->d;
-#
-        x_df[i*MMQ_MMA_TILE_X_K_Q3_K
-        x_df[i*MMQ_MMA_TILE_X_K_Q3_K
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+        x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
+        x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
 #else
-        x_df[i*(2*
-        x_df[i*(2*
-#endif // NEW_MMA_AVAILABLE
+        x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
+        x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
     }
 }
 
-template <int mmq_y,
+template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_s(
     const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+    constexpr int nwarps = mmq_get_nwarps_device();
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
-#
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
     int * x_qs = (int *) x_tile;
-    float * x_df = (float *) (x_qs +
+    float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_S, mmq_y);
     int * x_qs = (int *) x_tile;
     float * x_df = (float *) (x_qs + txs.qs);
-#endif // NEW_MMA_AVAILABLE
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
 
-
+    constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_S)) / 2;
+    constexpr int nrows = warp_size / threads_per_row;
+    const int kqsx = threadIdx.x % threads_per_row;
 
 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps *
-        int i = i0 + threadIdx.y*
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) {
+        int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;
 
         if (need_check) {
             i = min(i, i_max);
@@ -2049,44 +2474,48 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
             const int grid_l = __vsub4(grid_pos[0] ^ signs0, signs0);
             const int grid_h = __vsub4(grid_pos[1] ^ signs1, signs1);
 
-#
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
             x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l;
             x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 1)] = grid_h;
 #else
-            x_qs[i*(2*
-            x_qs[i*(2*
-#endif // NEW_MMA_AVAILABLE
+            x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid_l;
+            x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid_h;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
         }
 
         const int ls = bxi->scales[kqsx];
        const float d = bxi->d;
-#
-        x_df[i*MMQ_MMA_TILE_X_K_Q3_K
-        x_df[i*MMQ_MMA_TILE_X_K_Q3_K
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+        x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
+        x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
 #else
-        x_df[i*(2*
-        x_df[i*(2*
-#endif // NEW_MMA_AVAILABLE
+        x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
+        x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
     }
 }
 
-template <int mmq_y,
+template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq3_xxs(
     const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+    constexpr int nwarps = mmq_get_nwarps_device();
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
-#
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
     int * x_qs = (int *) x_tile;
-    float * x_df = (float *) (x_qs +
+    float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_XXS, mmq_y);
     int * x_qs = (int *) x_tile;
     float * x_df = (float *) (x_qs + txs.qs);
-#endif // NEW_MMA_AVAILABLE
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
 
-
+    constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR3_XXS)) / 2;
+    constexpr int nrows = warp_size / threads_per_row;
+    const int kqsx = threadIdx.x % threads_per_row;
 
 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps *
-        int i = i0 + threadIdx.y*
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) {
+        int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;
 
         if (need_check) {
             i = min(i, i_max);
@@ -2107,42 +2536,46 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
             const int grid_l = __vsub4(grid_pos.x ^ signs[0], signs[0]);
             const int grid_h = __vsub4(grid_pos.y ^ signs[1], signs[1]);
 
-#
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
             x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid_l;
             x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 1)] = grid_h;
 #else
-            x_qs[i*(2*
-            x_qs[i*(2*
-#endif // NEW_MMA_AVAILABLE
+            x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid_l;
+            x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid_h;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
         }
 
         const int ls = aux32 >> 28;
         const float d = bxi->d;
-#
-        x_df[i*MMQ_MMA_TILE_X_K_Q8_0
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+        x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/2;
 #else
-        x_df[i*(
-#endif // NEW_MMA_AVAILABLE
+        x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = (ls*d + d/2)/2;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
     }
 }
 
-template <int mmq_y,
+template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq3_s(
     const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+    constexpr int nwarps = mmq_get_nwarps_device();
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
-#
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
     int * x_qs = (int *) x_tile;
-    float * x_df = (float *) (x_qs +
+    float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_S, mmq_y);
     int * x_qs = (int *) x_tile;
     float * x_df = (float *) (x_qs + txs.qs);
-#endif // NEW_MMA_AVAILABLE
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
 
-
+    constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR3_S)) / 2;
+    constexpr int nrows = warp_size / threads_per_row;
+    const int kqsx = threadIdx.x % threads_per_row;
 
 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps *
-        int i = i0 + threadIdx.y*
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) {
+        int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;
 
         if (need_check) {
             i = min(i, i_max);
@@ -2170,42 +2603,46 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
             const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0);
             const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1);
 
-#
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
             x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l+0)] = grid_l;
             x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l+1)] = grid_h;
 #else
-            x_qs[i*(2*
-            x_qs[i*(2*
-#endif // NEW_MMA_AVAILABLE
+            x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+0)] = grid_l;
+            x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+1)] = grid_h;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
         }
 
         const int ls = 1 + 2*((bxi->scales[kqsx/2] >> (((2*kqsx) << 1) & 0x04)) & 0x0F);
         const float d = bxi->d;
-#
-        x_df[i*MMQ_MMA_TILE_X_K_Q8_0
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+        x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = ls*d;
 #else
-        x_df[i*(
-#endif // NEW_MMA_AVAILABLE
+        x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = ls*d;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
     }
 }
 
-template <int mmq_y,
+template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq1_s(
     const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+    constexpr int nwarps = mmq_get_nwarps_device();
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
-#
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
     int * x_qs = (int *) x_tile;
-    half2 * x_ds = (half2 *) (x_qs +
+    half2 * x_ds = (half2 *) (x_qs + MMQ_TILE_NE_K*2);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_S, mmq_y);
     int * x_qs = (int *) x_tile;
     half2 * x_ds = (half2 *) (x_qs + txs.qs);
-#endif // NEW_MMA_AVAILABLE
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
 
-
+    constexpr int threads_per_row = MMQ_ITER_K / (4 * QR1_S);
+    constexpr int nrows = warp_size / threads_per_row;
+    const int kqsx = threadIdx.x % threads_per_row;
 
 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps *
-        int i = i0 + threadIdx.y*
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) {
+        int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;
 
         if (need_check) {
             i = min(i, i_max);
@@ -2225,66 +2662,71 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
             const int grid0 = (grid >> 0) & 0x0F0F0F0F;
             const int grid1 = (grid >> 4) & 0x0F0F0F0F;
 
-#
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
             x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 8*kqsx + (2*l+0)] = grid0;
             x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 8*kqsx + (2*l+1)] = grid1;
 #else
-            x_qs[i*(2*
-            x_qs[i*(2*
-#endif // NEW_MMA_AVAILABLE
+            x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+0)] = grid0;
+            x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+1)] = grid1;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
         }
 
         const float d1q = __half2float(bxi->d) * (((qh >> 11) & 0x0E) + 1);
         const float delta = -1.0f + IQ1S_DELTA - (qh & 0x8000) * (2.0f*IQ1S_DELTA/0x8000);
 
-#
-        x_ds[i*MMQ_MMA_TILE_X_K_Q8_1
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+        x_ds[i*MMQ_MMA_TILE_X_K_Q8_1 + kqsx] = make_half2(d1q, d1q*delta);
 #else
-        x_ds[i*(
-#endif // NEW_MMA_AVAILABLE
+        x_ds[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = make_half2(d1q, d1q*delta);
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
     }
 }
 
-template <int mmq_y,
+template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_xs(
     const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+    constexpr int nwarps = mmq_get_nwarps_device();
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
-#
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
     int * x_qs = (int *) x_tile;
-    float * x_df = (float *) (x_qs +
+    float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_XS, mmq_y);
     int * x_qs = (int *) x_tile;
     float * x_df = (float *) (x_qs + txs.qs);
-#endif // NEW_MMA_AVAILABLE
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
 
-
-
+    constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_XS);
+    constexpr int nrows = warp_size / threads_per_row;
+    const int kqsx = threadIdx.x % threads_per_row;
 
 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
-        int i = i0 + threadIdx.y;
+    for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
+        int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);
 
         if (need_check) {
             i = min(i, i_max);
         }
 
-        const block_iq4_xs * bxi = (const block_iq4_xs *) x + kbx0 + i*stride
+        const block_iq4_xs * bxi = (const block_iq4_xs *) x + kbx0 + i*stride;
 
         const int aux_q4 = get_int_b4(bxi->qs, kqsx);
         const int2 v = get_int_from_table_16(aux_q4);
-        const int k0 = 8 * (
-
+        const int k0 = 8 * (kqsx / 4) + kqsx % 4;
+
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x;
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = v.y;
 #else
-        x_qs[i*(2*
-        x_qs[i*(2*
-#endif // NEW_MMA_AVAILABLE
+        x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 0] = v.x;
+        x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 4] = v.y;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
     }
 
+    constexpr int rows_per_warp = warp_size / 8;
 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps *
-        int i = i0 + threadIdx.y *
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) {
+        int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / (MMQ_TILE_NE_K/4);
 
         if (need_check) {
             i = min(i, i_max);
@@ -2297,18 +2739,21 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
         const int ls = ((bxi->scales_l[(threadIdx.x % 8)/2] >> (4*(threadIdx.x % 2))) & 0x0F)
             | (((bxi->scales_h >> (2*(threadIdx.x % 8))) & 0x03) << 4);
 
-#
-        x_df[i*MMQ_MMA_TILE_X_K_Q8_0
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+        x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x % 8] = d * (ls - 32);
 #else
-        x_df[i*(
-#endif // NEW_MMA_AVAILABLE
+        x_df[i*(MMQ_TILE_NE_K/4) + i/4 + threadIdx.x % 8] = d * (ls - 32);
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
     }
 }
 
-template<int mmq_x, int mmq_y,
+template<int mmq_x, int mmq_y, bool need_check>
 static __device__ __forceinline__ void mmq_write_back_dp4a(
     const float * __restrict__ sum, const int32_t * __restrict__ ids_dst, float * __restrict__ dst,
     const int stride, const int i_max, const int j_max) {
+    constexpr int nwarps = mmq_get_nwarps_device();
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+
 #pragma unroll
     for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
         const int j = j0 + threadIdx.y;
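Editorial note, not part of the diff: with the write-back generalized over `warp_size`, each thread owns `(mmq_x/nwarps) * (mmq_y/warp_size)` partial sums, addressed as `(j0/nwarps)*(mmq_y/warp_size) + i0/warp_size`. A short host-side sketch that enumerates these indices under assumed tile and launch sizes:

```cpp
// Standalone sketch (assumed sizes only): enumerate the per-thread accumulator
// indices used by mmq_write_back_dp4a after the warp-size generalization.
#include <cstdio>

int main() {
    const int mmq_x = 64, mmq_y = 128;    // tile sizes (assumed)
    const int nwarps = 8, warp_size = 64; // launch shape (assumed, e.g. a 64-wide warp)

    for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {       // column handled by threadIdx.y
        for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { // row handled by threadIdx.x
            const int idx = (j0/nwarps) * (mmq_y/warp_size) + i0/warp_size;
            printf("j0=%2d i0=%3d -> sum[%d]\n", j0, i0, idx);
        }
    }
    return 0;
}
```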
@@ -2318,32 +2763,40 @@ static __device__ __forceinline__ void mmq_write_back_dp4a(
         }
 
 #pragma unroll
-        for (int i0 = 0; i0 < mmq_y; i0 +=
+        for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
             const int i = i0 + threadIdx.x;
 
             if (need_check && i > i_max) {
                 continue;
             }
 
-            dst[ids_dst[j]*stride + i] = sum[(j0/nwarps) * (mmq_y/
+            dst[ids_dst[j]*stride + i] = sum[(j0/nwarps) * (mmq_y/warp_size) + i0/warp_size];
         }
     }
 }
 
-template<
+template<ggml_type type, int mmq_x, int mmq_y, bool need_check>
 static __device__ __forceinline__ void mmq_write_back_mma(
     const float * __restrict__ sum, const int * __restrict__ ids_dst, float * __restrict__ dst,
     const int stride, const int i_max, const int j_max) {
-    typedef tile<16, 8, int> tile_C;
 
     constexpr int granularity = mmq_get_granularity_device(mmq_x);
+    constexpr int nwarps = mmq_get_nwarps_device();
+
+#if defined(AMD_MFMA_AVAILABLE)
+    constexpr int tileC_IJ = mmq_get_granularity_device(0);
+    typedef tile<tileC_IJ, tileC_IJ, int> tile_C;
+    constexpr int rows_per_warp = granularity;
+#else
+    typedef tile<16, 8, int> tile_C;
     constexpr int rows_per_warp = 2 * granularity;
+#endif
     constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
 
     const int i0 = (threadIdx.y / ntx) * (ntx*tile_C::I);
-#
+#if defined(NEW_MMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
     static_assert(nwarps*tile_C::I == mmq_y, "nwarps*tile_C::I != mmq_y");
-#endif // NEW_MMA_AVAILABLE
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
 
 #pragma unroll
     for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
@@ -2371,179 +2824,181 @@ static __device__ __forceinline__ void mmq_write_back_mma(
|
|
|
2371
2824
|
|
|
2372
2825
|
// -------------------------------------------------------------------------------------------------------------------------------------
|
|
2373
2826
|
|
|
2374
|
-
template <int mmq_x, int mmq_y,
|
|
2827
|
+
template <int mmq_x, int mmq_y, bool need_check, ggml_type type>
|
|
2375
2828
|
struct mmq_type_traits;
|
|
2376
2829
|
|
|
2377
|
-
template <int mmq_x, int mmq_y,
|
|
2378
|
-
struct mmq_type_traits<mmq_x, mmq_y,
|
|
2830
|
+
template <int mmq_x, int mmq_y, bool need_check>
|
|
2831
|
+
struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q4_0> {
|
|
2379
2832
|
static constexpr int vdr = VDR_Q4_0_Q8_1_MMQ;
|
|
2380
|
-
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_0<mmq_y,
|
|
2381
|
-
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y,
|
|
2382
|
-
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_0_q8_1_dp4a<mmq_x, mmq_y
|
|
2833
|
+
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_0<mmq_y, need_check>;
|
|
2834
|
+
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_DS4>;
|
|
2835
|
+
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_0_q8_1_dp4a<mmq_x, mmq_y>;
|
|
2383
2836
|
};
|
|
2384
2837
|
|
|
2385
|
-
template <int mmq_x, int mmq_y,
|
|
2386
|
-
struct mmq_type_traits<mmq_x, mmq_y,
|
|
2838
|
+
template <int mmq_x, int mmq_y, bool need_check>
|
|
2839
|
+
struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q4_1> {
|
|
2387
2840
|
static constexpr int vdr = VDR_Q4_1_Q8_1_MMQ;
|
|
2388
|
-
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_1<mmq_y,
|
|
2389
|
-
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y
|
|
2390
|
-
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_1_q8_1_dp4a<mmq_x, mmq_y
|
|
2841
|
+
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_1<mmq_y, need_check>;
|
|
2842
|
+
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y>;
|
|
2843
|
+
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_1_q8_1_dp4a<mmq_x, mmq_y>;
|
|
2391
2844
|
};
|
|
2392
2845
|
|
|
2393
|
-
template <int mmq_x, int mmq_y,
|
|
2394
|
-
struct mmq_type_traits<mmq_x, mmq_y,
|
|
2846
|
+
template <int mmq_x, int mmq_y, bool need_check>
|
|
2847
|
+
struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q5_0> {
|
|
2395
2848
|
static constexpr int vdr = VDR_Q5_0_Q8_1_MMQ;
|
|
2396
|
-
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_0<mmq_y,
|
|
2397
|
-
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y,
|
|
2398
|
-
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y
|
|
2849
|
+
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_0<mmq_y, need_check>;
|
|
2850
|
+
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
|
|
2851
|
+
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
|
|
2399
2852
|
};
|
|
2400
2853
|
|
|
2401
|
-
template <int mmq_x, int mmq_y,
|
|
2402
|
-
struct mmq_type_traits<mmq_x, mmq_y,
|
|
2854
|
+
template <int mmq_x, int mmq_y, bool need_check>
|
|
2855
|
+
struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q5_1> {
|
|
2403
2856
|
static constexpr int vdr = VDR_Q5_1_Q8_1_MMQ;
|
|
2404
|
-
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_1<mmq_y,
|
|
2405
|
-
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y
|
|
2406
|
-
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_1_q8_1_dp4a<mmq_x, mmq_y
|
|
2857
|
+
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_1<mmq_y, need_check>;
|
|
2858
|
+
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y>;
|
|
2859
|
+
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_1_q8_1_dp4a<mmq_x, mmq_y>;
|
|
2407
2860
|
};
|
|
2408
2861
|
|
|
2409
|
-
template <int mmq_x, int mmq_y,
|
|
2410
|
-
struct mmq_type_traits<mmq_x, mmq_y,
|
|
2862
|
+
template <int mmq_x, int mmq_y, bool need_check>
|
|
2863
|
+
struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q8_0> {
|
|
2411
2864
|
static constexpr int vdr = VDR_Q8_0_Q8_1_MMQ;
|
|
2412
|
-
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q8_0<mmq_y,
|
|
2413
|
-
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y,
|
|
2414
|
-
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y
|
|
2865
|
+
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q8_0<mmq_y, need_check>;
|
|
2866
|
+
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
|
|
2867
|
+
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
|
|
2415
2868
|
};
|
|
2416
2869
|
|
|
2417
|
-
template <int mmq_x, int mmq_y,
|
|
2418
|
-
struct mmq_type_traits<mmq_x, mmq_y,
|
|
2870
|
+
template <int mmq_x, int mmq_y, bool need_check>
|
|
2871
|
+
struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q2_K> {
|
|
2419
2872
|
static constexpr int vdr = VDR_Q2_K_Q8_1_MMQ;
|
|
2420
|
-
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q2_K<mmq_y,
|
|
2421
|
-
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q2_K_q8_1_mma<mmq_x, mmq_y
|
|
2422
|
-
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q2_K_q8_1_dp4a<mmq_x, mmq_y
|
|
2873
|
+
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q2_K<mmq_y, need_check>;
|
|
2874
|
+
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q2_K_q8_1_mma<mmq_x, mmq_y>;
|
|
2875
|
+
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q2_K_q8_1_dp4a<mmq_x, mmq_y>;
|
|
2423
2876
|
};
|
|
2424
2877
|
|
|
2425
|
-
template <int mmq_x, int mmq_y,
|
|
2426
|
-
struct mmq_type_traits<mmq_x, mmq_y,
|
|
2878
|
+
template <int mmq_x, int mmq_y, bool need_check>
|
|
2879
|
+
struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q3_K> {
|
|
2427
2880
|
static constexpr int vdr = VDR_Q3_K_Q8_1_MMQ;
|
|
2428
|
-
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q3_K<mmq_y,
|
|
2429
|
-
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y
|
|
2430
|
-
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q3_K_q8_1_dp4a<mmq_x, mmq_y
|
|
2881
|
+
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q3_K<mmq_y, need_check>;
|
|
2882
|
+
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y>;
|
|
2883
|
+
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q3_K_q8_1_dp4a<mmq_x, mmq_y>;
|
|
2431
2884
|
};
|
|
2432
2885
|
|
|
2433
|
-
template <int mmq_x, int mmq_y,
|
|
2434
|
-
struct mmq_type_traits<mmq_x, mmq_y,
|
|
2886
|
+
template <int mmq_x, int mmq_y, bool need_check>
|
|
2887
|
+
struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q4_K> {
|
|
2435
2888
|
static constexpr int vdr = VDR_Q4_K_Q8_1_MMQ;
|
|
2436
|
-
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_K<mmq_y,
|
|
2437
|
-
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y
|
|
2438
|
-
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_K_q8_1_dp4a<mmq_x, mmq_y
|
|
2889
|
+
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_K<mmq_y, need_check>;
|
|
2890
|
+
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y>;
|
|
2891
|
+
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_K_q8_1_dp4a<mmq_x, mmq_y>;
|
|
2439
2892
|
};
|
|
2440
2893
|
|
|
2441
|
-
template <int mmq_x, int mmq_y,
|
|
2442
|
-
struct mmq_type_traits<mmq_x, mmq_y,
|
|
2894
|
+
template <int mmq_x, int mmq_y, bool need_check>
|
|
2895
|
+
struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q5_K> {
|
|
2443
2896
|
static constexpr int vdr = VDR_Q5_K_Q8_1_MMQ;
|
|
2444
|
-
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_K<mmq_y,
|
|
2445
|
-
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y
|
|
2446
|
-
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q5_K_q8_1_dp4a<mmq_x, mmq_y
|
|
2897
|
+
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_K<mmq_y, need_check>;
|
|
2898
|
+
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y>;
|
|
2899
|
+
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q5_K_q8_1_dp4a<mmq_x, mmq_y>;
|
|
2447
2900
|
};
|
|
2448
2901
|
|
|
2449
|
-
template <int mmq_x, int mmq_y,
|
|
2450
|
-
struct mmq_type_traits<mmq_x, mmq_y,
|
|
2902
|
+
template <int mmq_x, int mmq_y, bool need_check>
|
|
2903
|
+
struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q6_K> {
|
|
2451
2904
|
static constexpr int vdr = VDR_Q6_K_Q8_1_MMQ;
|
|
2452
|
-
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q6_K<mmq_y,
|
|
2453
|
-
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q6_K_q8_1_mma<mmq_x, mmq_y
|
|
2454
|
-
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q6_K_q8_1_dp4a<mmq_x, mmq_y
|
|
2905
|
+
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q6_K<mmq_y, need_check>;
|
|
2906
|
+
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q6_K_q8_1_mma<mmq_x, mmq_y>;
|
|
2907
|
+
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q6_K_q8_1_dp4a<mmq_x, mmq_y>;
|
|
2455
2908
|
};
|
|
2456
2909
|
|
|
2457
|
-
template <int mmq_x, int mmq_y,
|
|
2458
|
-
struct mmq_type_traits<mmq_x, mmq_y,
|
|
2910
|
+
template <int mmq_x, int mmq_y, bool need_check>
|
|
2911
|
+
struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_IQ2_XXS> {
|
|
2459
2912
|
static constexpr int vdr = VDR_IQ2_XXS_Q8_1_MMQ;
|
|
2460
|
-
static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_xxs<mmq_y,
|
|
2461
|
-
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y,
|
|
2462
|
-
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y
|
|
2913
|
+
static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_xxs<mmq_y, need_check>;
|
|
2914
|
+
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
|
|
2915
|
+
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
|
|
2463
2916
|
};
|
|
2464
2917
|
|
|
2465
|
-
template <int mmq_x, int mmq_y,
|
|
2466
|
-
struct mmq_type_traits<mmq_x, mmq_y,
|
|
2918
|
+
template <int mmq_x, int mmq_y, bool need_check>
|
|
2919
|
+
struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_IQ2_XS> {
static constexpr int vdr = VDR_IQ2_XS_Q8_1_MMQ;
- static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_xs<mmq_y,
- static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y
- static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a<mmq_x, mmq_y
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_xs<mmq_y, need_check>;
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y>;
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a<mmq_x, mmq_y>;
};

- template <int mmq_x, int mmq_y,
- struct mmq_type_traits<mmq_x, mmq_y,
+ template <int mmq_x, int mmq_y, bool need_check>
+ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_IQ2_S> {
static constexpr int vdr = VDR_IQ2_S_Q8_1_MMQ;
- static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_s<mmq_y,
- static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y
- static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a<mmq_x, mmq_y
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_s<mmq_y, need_check>;
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y>;
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a<mmq_x, mmq_y>;
};

- template <int mmq_x, int mmq_y,
- struct mmq_type_traits<mmq_x, mmq_y,
+ template <int mmq_x, int mmq_y, bool need_check>
+ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_IQ3_XXS> {
static constexpr int vdr = VDR_IQ3_XXS_Q8_1_MMQ;
- static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq3_xxs<mmq_y,
- static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y,
- static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq3_xxs<mmq_y, need_check>;
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
};

- template <int mmq_x, int mmq_y,
- struct mmq_type_traits<mmq_x, mmq_y,
+ template <int mmq_x, int mmq_y, bool need_check>
+ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_IQ3_S> {
static constexpr int vdr = VDR_IQ3_S_Q8_1_MMQ;
- static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq3_s<mmq_y,
- static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y,
- static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq3_s<mmq_y, need_check>;
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
};

- template <int mmq_x, int mmq_y,
- struct mmq_type_traits<mmq_x, mmq_y,
+ template <int mmq_x, int mmq_y, bool need_check>
+ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_IQ1_S> {
static constexpr int vdr = VDR_IQ1_S_Q8_1_MMQ;
- static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq1_s<mmq_y,
- static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y
- static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_1_q8_1_dp4a<mmq_x, mmq_y
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq1_s<mmq_y, need_check>;
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y>;
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_1_q8_1_dp4a<mmq_x, mmq_y>;
};

- template <int mmq_x, int mmq_y,
- struct mmq_type_traits<mmq_x, mmq_y,
+ template <int mmq_x, int mmq_y, bool need_check>
+ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_IQ4_NL> {
static constexpr int vdr = VDR_IQ4_NL_Q8_1_MMQ;
- static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_nl<mmq_y,
- static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y,
- static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_nl<mmq_y, need_check>;
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
};

- template <int mmq_x, int mmq_y,
- struct mmq_type_traits<mmq_x, mmq_y,
+ template <int mmq_x, int mmq_y, bool need_check>
+ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_IQ4_XS> {
static constexpr int vdr = VDR_IQ4_XS_Q8_1_MMQ;
- static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_xs<mmq_y,
- static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y,
- static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_xs<mmq_y, need_check>;
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
};

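The specializations above act as a compile-time dispatch table: for each quantized type, `mmq_type_traits` now carries `need_check` (and the type itself) as template parameters and binds the matching tile loader and dot-product kernels as `constexpr` members, so the shared kernel body selects them without any run-time branching. A minimal host-side sketch of the same pattern, using hypothetical stand-in functions rather than the real CUDA device helpers:

```cpp
#include <cstdio>

// Hypothetical stand-ins for the per-type device functions.
using load_tiles_t = void (*)(int);
using vec_dot_t    = int  (*)(int, int);

static void load_a(int n) { std::printf("load_a(%d)\n", n); }
static void load_b(int n) { std::printf("load_b(%d)\n", n); }
static int  dot_a(int x, int y) { return x * y; }
static int  dot_b(int x, int y) { return x + y; }

enum class qtype { TYPE_A, TYPE_B };

// Primary template is left undefined; only the specializations below are usable.
template <int mmq_x, int mmq_y, bool need_check, qtype type>
struct type_traits;

template <int mmq_x, int mmq_y, bool need_check>
struct type_traits<mmq_x, mmq_y, need_check, qtype::TYPE_A> {
    static constexpr load_tiles_t load_tiles = load_a;
    static constexpr vec_dot_t    vec_dot    = dot_a;
};

template <int mmq_x, int mmq_y, bool need_check>
struct type_traits<mmq_x, mmq_y, need_check, qtype::TYPE_B> {
    static constexpr load_tiles_t load_tiles = load_b;
    static constexpr vec_dot_t    vec_dot    = dot_b;
};

// The generic kernel body picks its helpers once, at compile time.
template <qtype type, int mmq_x, bool need_check>
static void process_tile(int x, int y) {
    constexpr int mmq_y = 64; // fixed for the sketch
    constexpr load_tiles_t load_tiles = type_traits<mmq_x, mmq_y, need_check, type>::load_tiles;
    constexpr vec_dot_t    vec_dot    = type_traits<mmq_x, mmq_y, need_check, type>::vec_dot;
    load_tiles(x);
    std::printf("dot = %d\n", vec_dot(x, y));
}

int main() {
    process_tile<qtype::TYPE_A, 8, false>(3, 4);
    process_tile<qtype::TYPE_B, 8, true>(3, 4);
}
```

Because every member is `constexpr`, each kernel instantiation can inline the chosen helpers directly.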
- template <ggml_type type, int mmq_x,
+ template <ggml_type type, int mmq_x, bool need_check, bool fixup>
static __device__ __forceinline__ void mul_mat_q_process_tile(
const char * __restrict__ x, const int offset_x, const int * __restrict__ y,
const int * __restrict__ ids_dst, float * __restrict__ dst, float * __restrict__ tmp_fixup,
const int stride_row_x, const int ncols_y, const int stride_col_dst,
const int tile_x_max_i, const int tile_y_max_j, const int kb0_start, const int kb0_stop) {

+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+ constexpr int nwarps = mmq_get_nwarps_device();
constexpr int qk = ggml_cuda_type_traits<type>::qk;
constexpr int mmq_y = get_mmq_y_device();
- constexpr load_tiles_mmq_t load_tiles = mmq_type_traits<mmq_x, mmq_y,
+ constexpr load_tiles_mmq_t load_tiles = mmq_type_traits<mmq_x, mmq_y, need_check, type>::load_tiles;

extern __shared__ int data_mul_mat_q[];
int * tile_y = data_mul_mat_q + mmq_x;
- int * tile_x = tile_y + GGML_PAD(mmq_x*
+ int * tile_x = tile_y + GGML_PAD(mmq_x*MMQ_TILE_Y_K, nwarps*warp_size);

- #
- constexpr vec_dot_mmq_t vec_dot = mmq_type_traits<mmq_x, mmq_y,
- constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y,
+ #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+ constexpr vec_dot_mmq_t vec_dot = mmq_type_traits<mmq_x, mmq_y, need_check, type>::vec_dot_mma;
+ constexpr mmq_write_back_t write_back = mmq_write_back_mma<type, mmq_x, mmq_y, need_check>;
#else
- constexpr vec_dot_mmq_t vec_dot = mmq_type_traits<mmq_x, mmq_y,
- constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y,
- #endif // NEW_MMA_AVAILABLE
+ constexpr vec_dot_mmq_t vec_dot = mmq_type_traits<mmq_x, mmq_y, need_check, type>::vec_dot_dp4a;
+ constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, need_check>;
+ #endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)

constexpr int blocks_per_iter = MMQ_ITER_K / qk;

- float sum[mmq_x*mmq_y / (nwarps*
+ float sum[mmq_x*mmq_y / (nwarps*warp_size)] = {0.0f};

for (int kb0 = kb0_start; kb0 < kb0_stop; kb0 += blocks_per_iter) {
load_tiles(x, tile_x, offset_x + kb0, tile_x_max_i, stride_row_x);
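The hunk above also shows how the kernel carves its single dynamic shared-memory allocation into the id table, the quantized y tile, and the x tile, with `GGML_PAD` rounding the y region up to the copy stride of `nwarps*warp_size` ints. A small stand-alone sketch of that layout; the concrete sizes below are assumptions, not the kernel's real constants:

```cpp
#include <cstdio>
#include <vector>

// Round x up to the next multiple of n -- an illustrative stand-in for ggml's GGML_PAD.
constexpr int pad_to(int x, int n) { return ((x + n - 1) / n) * n; }

int main() {
    // All numbers below are illustrative assumptions, not the kernel's real constants.
    constexpr int mmq_x = 8, tile_y_k = 36, tile_x_k = 64, mmq_y = 64;
    constexpr int nwarps = 8, warp_size = 32;

    // One dynamic shared-memory allocation is carved into three regions:
    //   [ ids (mmq_x ints) | tile_y (padded) | tile_x ]
    // tile_y is padded to a multiple of nwarps*warp_size so the chunked copies of
    // nwarps*warp_size ints per iteration never spill into tile_x.
    const int y_padded = pad_to(mmq_x * tile_y_k, nwarps * warp_size);
    std::vector<int> smem(mmq_x + y_padded + mmq_y * tile_x_k, 0);

    int * ids    = smem.data();
    int * tile_y = ids + mmq_x;
    int * tile_x = tile_y + y_padded;

    std::printf("ids at +0, tile_y at +%td, tile_x at +%td ints\n",
                tile_y - ids, tile_x - ids);
    return 0;
}
```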
@@ -2551,8 +3006,8 @@ static __device__ __forceinline__ void mul_mat_q_process_tile(
{
const int * by0 = y + ncols_y*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + 0*sizeof(block_q8_1_mmq)/sizeof(int));
#pragma unroll
- for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*
- int l = l0 + threadIdx.y*
+ for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*warp_size) {
+ int l = l0 + threadIdx.y*warp_size + threadIdx.x;

tile_y[l] = by0[l];
}

@@ -2567,8 +3022,8 @@ static __device__ __forceinline__ void mul_mat_q_process_tile(
{
const int * by0 = y + ncols_y*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + 1*sizeof(block_q8_1_mmq)/sizeof(int));
#pragma unroll
- for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*
- int l = l0 + threadIdx.y*
+ for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*warp_size) {
+ int l = l0 + threadIdx.y*warp_size + threadIdx.x;

tile_y[l] = by0[l];
}

@@ -2576,7 +3031,7 @@ static __device__ __forceinline__ void mul_mat_q_process_tile(

__syncthreads();

- vec_dot(tile_x, tile_y, sum,
+ vec_dot(tile_x, tile_y, sum, MMQ_TILE_NE_K);

__syncthreads();
}

@@ -2591,16 +3046,16 @@ static __device__ __forceinline__ void mul_mat_q_process_tile(

// The mul_mat_q kernel implements "stream-k" work partitioning as described in https://arxiv.org/abs/2301.03598

- template <ggml_type type, int mmq_x,
+ template <ggml_type type, int mmq_x, bool need_check>
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
#if defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
- __launch_bounds__(
+ __launch_bounds__(ggml_cuda_get_physical_warp_size()*mmq_get_nwarps_device(), 2)
#endif // defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
#else
#if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
- __launch_bounds__(
+ __launch_bounds__(ggml_cuda_get_physical_warp_size()*mmq_get_nwarps_device(), 1)
#else
- __launch_bounds__(
+ __launch_bounds__(ggml_cuda_get_physical_warp_size()*mmq_get_nwarps_device(), 2)
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
static __global__ void mul_mat_q(
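As the comment above notes, `mul_mat_q` partitions work in the "stream-k" style (https://arxiv.org/abs/2301.03598): a fixed grid of blocks walks a flattened list of (output tile, k-chunk) work items, and tiles whose k-range is split across blocks are completed by the separate fixup kernel further down. A CPU-side sketch of that idea under simplified assumptions; the real kernel's ownership rule and indexing differ in detail:

```cpp
#include <algorithm>
#include <cstdio>
#include <vector>

// CPU sketch of stream-k work partitioning (illustrative, not the kernel's actual code).
// Work is the flattened list of (tile, k-chunk) items; a fixed number of "blocks" each
// take a contiguous slice of that list. A tile whose k-chunks are split across blocks
// only gets the slice that starts the tile written directly; the remaining slices go
// to a fixup buffer and are added in a second pass, mirroring mul_mat_q_stream_k_fixup.
int main() {
    const int ntiles = 5, k_chunks = 4, nblocks = 3;
    const int total = ntiles * k_chunks;

    std::vector<int> dst(ntiles, 0), fixup(ntiles, 0);

    for (int b = 0; b < nblocks; ++b) {
        const int start = (total * b) / nblocks;
        const int stop  = (total * (b + 1)) / nblocks;
        for (int w = start; w < stop; ) {
            const int tile = w / k_chunks;
            const int end  = std::min(stop, (tile + 1) * k_chunks); // stay inside this tile
            int partial = 0;
            for (; w < end; ++w) partial += 1; // each k-chunk contributes 1 in this toy example
            const bool owns_tile_start = (w - partial) % k_chunks == 0;
            if (owns_tile_start) dst[tile]   += partial; // first slice writes dst directly
            else                 fixup[tile] += partial; // later slices go to the fixup buffer
        }
    }
    for (int t = 0; t < ntiles; ++t) dst[t] += fixup[t]; // fixup pass

    for (int t = 0; t < ntiles; ++t) std::printf("tile %d: %d\n", t, dst[t]); // all print 4
    return 0;
}
```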
@@ -2616,6 +3071,9 @@ static __global__ void mul_mat_q(
return;
}

+ constexpr int nwarps = mmq_get_nwarps_device();
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+
constexpr int qk = ggml_cuda_type_traits<type>::qk;
constexpr int mmq_y = get_mmq_y_device();

@@ -2627,10 +3085,10 @@ static __global__ void mul_mat_q(
// For MoE the correct indices are loaded from ids_dst.
extern __shared__ int ids_dst_shared[]; // Stored at beginning of shared memory.
#pragma unroll
- for (int j0 = 0; j0 < mmq_x; j0 += nwarps*
- const int j = j0 + threadIdx.y*
+ for (int j0 = 0; j0 < mmq_x; j0 += nwarps*warp_size) {
+ const int j = j0 + threadIdx.y*warp_size + threadIdx.x;

- if (j0 + nwarps*
+ if (j0 + nwarps*warp_size > mmq_x && j >= mmq_x) {
break;
}

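The loop above illustrates the indexing convention used throughout the kernel after this change: each thread's flat lane is `threadIdx.y*warp_size + threadIdx.x`, the loop strides by `nwarps*warp_size`, and only a final partial chunk needs the bounds check. A small CPU sketch of that access pattern with assumed launch dimensions:

```cpp
#include <cstdio>
#include <vector>

// CPU sketch of the flattened-thread-index loop used throughout the kernel (illustrative).
// Every thread (tx, ty) gets the flat lane j = ty*warp_size + tx and strides over the
// mmq_x entries in steps of nwarps*warp_size; only the final, partial chunk needs the
// bounds check, which is what the `j0 + nwarps*warp_size > mmq_x` guard expresses.
int main() {
    constexpr int warp_size = 32, nwarps = 8; // assumed launch shape
    constexpr int mmq_x = 80;                 // not a multiple of nwarps*warp_size

    std::vector<int> hits(mmq_x, 0);

    for (int ty = 0; ty < nwarps; ++ty) {
        for (int tx = 0; tx < warp_size; ++tx) {
            for (int j0 = 0; j0 < mmq_x; j0 += nwarps*warp_size) {
                const int j = j0 + ty*warp_size + tx;
                if (j0 + nwarps*warp_size > mmq_x && j >= mmq_x) {
                    break; // partial chunk: threads past the end do nothing
                }
                hits[j] += 1;
            }
        }
    }

    int covered = 0;
    for (int h : hits) covered += (h == 1);
    std::printf("%d of %d entries written exactly once\n", covered, mmq_x);
    return 0;
}
```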
@@ -2639,7 +3097,7 @@ static __global__ void mul_mat_q(
__syncthreads();

// On AMD or old CUDA the performance with stream-k was worse, use conventional tiling instead:
- #if (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA
+ #if (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(CDNA3)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA
{
const int wt = blockIdx.z / nchannels_y;
const int zt = blockIdx.z - wt*nchannels_y;

@@ -2667,10 +3125,10 @@ static __global__ void mul_mat_q(

// __syncthreads(); // There is no previous tile that could cause a race condition.
#pragma unroll
- for (int j0 = 0; j0 < mmq_x; j0 += nwarps*
- const int j = j0 + threadIdx.y*
+ for (int j0 = 0; j0 < mmq_x; j0 += nwarps*warp_size) {
+ const int j = j0 + threadIdx.y*warp_size + threadIdx.x;

- if (j0 + nwarps*
+ if (j0 + nwarps*warp_size > mmq_x && j >= mmq_x) {
break;
}

@@ -2688,12 +3146,12 @@ static __global__ void mul_mat_q(
const int offset_x = (wt/sample_ratio)*stride_sample_x + (zt/channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x;

constexpr bool fixup = false;
- mul_mat_q_process_tile<type, mmq_x,
+ mul_mat_q_process_tile<type, mmq_x, need_check, fixup>
(x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, stride_row_x, ncols_y, stride_col_dst,
tile_x_max_i, tile_y_max_j, 0, ncols_x/qk);
return;
}
- #endif // (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA
+ #endif // (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(CDNA3)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA

const int64_t blocks_per_ne00 = ncols_x / qk;
constexpr int blocks_per_iter = MMQ_ITER_K / qk;

@@ -2745,10 +3203,10 @@ static __global__ void mul_mat_q(

__syncthreads();
#pragma unroll
- for (int j0 = 0; j0 < mmq_x; j0 += nwarps*
- const int j = j0 + threadIdx.y*
+ for (int j0 = 0; j0 < mmq_x; j0 += nwarps*warp_size) {
+ const int j = j0 + threadIdx.y*warp_size + threadIdx.x;

- if (j0 + nwarps*
+ if (j0 + nwarps*warp_size > mmq_x && j >= mmq_x) {
break;
}

@@ -2766,7 +3224,7 @@ static __global__ void mul_mat_q(
const int offset_x = (wt/sample_ratio)*stride_sample_x + (zt/channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x;

constexpr bool fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
- mul_mat_q_process_tile<type, mmq_x,
+ mul_mat_q_process_tile<type, mmq_x, need_check, fixup>
(x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, stride_row_x, ncols_y, stride_col_dst,
tile_x_max_i, tile_y_max_j, kb0_start, kb0_stop);

@@ -2812,10 +3270,10 @@ static __global__ void mul_mat_q(
// The memory layout for the fixup buffer is always contiguous, therefore reset ids:
__syncthreads();
#pragma unroll
- for (int j0 = 0; j0 < mmq_x; j0 += nwarps*
- const int j = j0 + threadIdx.y*
+ for (int j0 = 0; j0 < mmq_x; j0 += nwarps*warp_size) {
+ const int j = j0 + threadIdx.y*warp_size + threadIdx.x;

- if (j0 + nwarps*
+ if (j0 + nwarps*warp_size > mmq_x && j >= mmq_x) {
break;
}

@@ -2833,13 +3291,13 @@ static __global__ void mul_mat_q(
const int offset_x = (wt/sample_ratio)*stride_sample_x + (zt/channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x;

constexpr bool fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
- mul_mat_q_process_tile<type, mmq_x,
+ mul_mat_q_process_tile<type, mmq_x, need_check, fixup>
(x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, stride_row_x, ncols_y, stride_col_dst,
tile_x_max_i, tile_y_max_j, kb0_start, kb0_stop);
}


- template <ggml_type type, int mmq_x,
+ template <ggml_type type, int mmq_x, bool need_check>
static __global__ void mul_mat_q_stream_k_fixup(
const int32_t * ids_dst, const int32_t * expert_bounds, float * __restrict__ dst, const float * __restrict__ tmp_last_tile,
const int ncols_x, const int nrows_x, const int ncols_dst, const int stride_col_dst,

@@ -2849,7 +3307,10 @@ static __global__ void mul_mat_q_stream_k_fixup(
constexpr int blocks_per_iter = MMQ_ITER_K / qk;
const int64_t blocks_per_ne00 = ncols_x / qk;

-
+ constexpr int nwarps = mmq_get_nwarps_device();
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+
+ float sum[mmq_x*mmq_y / (nwarps*warp_size)] = {0.0f};

const int ntx = (ncols_dst + mmq_x - 1) / mmq_x;
const int nty = (nrows_x + mmq_y - 1) / mmq_y;

@@ -2893,10 +3354,10 @@ static __global__ void mul_mat_q_stream_k_fixup(
const int j = j0 + threadIdx.y;

#pragma unroll
- for (int i0 = 0; i0 < mmq_y; i0 +=
+ for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
const int i = i0 + threadIdx.x;

- sum[(j0/nwarps) * (mmq_y/
+ sum[(j0/nwarps) * (mmq_y/warp_size) + i0/warp_size] += tmp_last_tile[bidx*(mmq_x*mmq_y) + j*mmq_y + i];
}
}

@@ -2937,14 +3398,14 @@ static __global__ void mul_mat_q_stream_k_fixup(
}

#pragma unroll
- for (int i0 = 0; i0 < mmq_y; i0 +=
+ for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
const int i = i0 + threadIdx.x;

if (need_check && i > i_max) {
continue;
}

- dst[j*stride_col_dst + i] += sum[(j0/nwarps) * (mmq_y/
+ dst[j*stride_col_dst + i] += sum[(j0/nwarps) * (mmq_y/warp_size) + i0/warp_size];
}
}
return;

@@ -2955,7 +3416,7 @@ static __global__ void mul_mat_q_stream_k_fixup(
const int col_high = expert_bounds[zt + 1];
const int col_diff = col_high - col_low;

- for (int j = threadIdx.y*
+ for (int j = threadIdx.y*warp_size + threadIdx.x; j < mmq_x; j += nwarps*warp_size) {
ids_dst_shared[j] = ids_dst[col_low + j];
}
__syncthreads();

@@ -2975,14 +3436,14 @@ static __global__ void mul_mat_q_stream_k_fixup(
}

#pragma unroll
- for (int i0 = 0; i0 < mmq_y; i0 +=
+ for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
const int i = i0 + threadIdx.x;

if (need_check && i > i_max) {
continue;
}

- dst[ids_dst_shared[j]*stride_col_dst + i] += sum[(j0/nwarps) * (mmq_y/
+ dst[ids_dst_shared[j]*stride_col_dst + i] += sum[(j0/nwarps) * (mmq_y/warp_size) + i0/warp_size];
}
}
}

@@ -2996,13 +3457,13 @@ struct mmq_args {
};

template<ggml_type type>
- static size_t mmq_get_nbytes_shared(const int mmq_x, const int mmq_y, const int cc) {
+ static size_t mmq_get_nbytes_shared(const int mmq_x, const int mmq_y, const int cc, const int warp_size, const int nwarps) {
const tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(type, mmq_y);
const int mmq_tile_x_k = mmq_get_mma_tile_x_k(type);
const size_t nbs_ids = mmq_x*sizeof(int);
- const size_t nbs_x = new_mma_available(cc) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int);
+ const size_t nbs_x = (new_mma_available(cc) || amd_mfma_available(cc)) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int);
const size_t nbs_y = mmq_x*sizeof(block_q8_1_mmq);
- return nbs_ids + nbs_x + GGML_PAD(nbs_y,
+ return nbs_ids + nbs_x + GGML_PAD(nbs_y, nwarps*warp_size*sizeof(int));
}

template <ggml_type type, int mmq_x>
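`mmq_get_nbytes_shared` now takes the warp size and warp count explicitly so the padding term matches the launch configuration. A rough recomputation of the budget it returns, with illustrative (assumed) tile constants:

```cpp
#include <cstdio>

// Illustrative recomputation of the shared-memory budget that mmq_get_nbytes_shared
// now derives from warp_size and nwarps (all concrete sizes below are assumptions).
constexpr size_t pad_to(size_t x, size_t n) { return ((x + n - 1) / n) * n; }

int main() {
    const int mmq_x = 64, mmq_y = 64;
    const int warp_size = 32, nwarps = 8;

    const size_t mmq_tile_x_k      = 72;  // per-row ints for the x tile (assumed)
    const size_t sizeof_block_q8_1 = 144; // bytes per quantized y block (assumed)

    const size_t nbs_ids = mmq_x * sizeof(int);                // MoE destination ids
    const size_t nbs_x   = mmq_y * mmq_tile_x_k * sizeof(int); // quantized x tile
    const size_t nbs_y   = mmq_x * sizeof_block_q8_1;          // quantized y tile
    // y is padded to the copy stride of nwarps*warp_size ints, as in the diff above.
    const size_t total   = nbs_ids + nbs_x + pad_to(nbs_y, nwarps*warp_size*sizeof(int));

    std::printf("dynamic shared memory needed: %zu bytes (%.1f KiB)\n",
                total, total / 1024.0);
    return 0;
}
```

Because this total can exceed the default per-block dynamic shared-memory limit on some GPUs, the launch path in the next hunk raises the per-kernel limit before launching, as the old `cudaFuncSetAttribute` calls did.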
@@ -3010,20 +3471,16 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
const int id = ggml_cuda_get_device();
const int cc = ggml_cuda_info().devices[id].cc;
const int nsm = ggml_cuda_info().devices[id].nsm;
+ const int warp_size = ggml_cuda_info().devices[id].warp_size;
+ const int nwarps = mmq_get_nwarps_host(cc);
const int mmq_y = get_mmq_y_host(cc);

- const dim3 block_dims(
+ const dim3 block_dims(warp_size, nwarps, 1);

- const int nbytes_shared = mmq_get_nbytes_shared<type>(mmq_x, mmq_y, cc);
+ const int nbytes_shared = mmq_get_nbytes_shared<type>(mmq_x, mmq_y, cc, warp_size, nwarps);

-
-
- if (!shared_memory_limit_raised[id]) {
- CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, MMQ_NWARPS, false>, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared));
- CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, MMQ_NWARPS, true>, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared));
- shared_memory_limit_raised[id] = true;
- }
- #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
+ CUDA_SET_SHARED_MEMORY_LIMIT((mul_mat_q<type, mmq_x, false>), nbytes_shared);
+ CUDA_SET_SHARED_MEMORY_LIMIT((mul_mat_q<type, mmq_x, true>), nbytes_shared);

const int nty = (args.nrows_x + mmq_y - 1) / mmq_y;
const int ntx = (args.ncols_dst + mmq_x - 1) / mmq_x;

@@ -3038,14 +3495,14 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
if (!args.use_stream_k) {
if (args.nrows_x % mmq_y == 0) {
constexpr bool need_check = false;
- mul_mat_q<type, mmq_x,
+ mul_mat_q<type, mmq_x, need_check><<<block_nums_xy_tiling, block_dims, nbytes_shared, stream>>>
(args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr,
args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst);
} else {
constexpr bool need_check = true;
- mul_mat_q<type, mmq_x,
+ mul_mat_q<type, mmq_x, need_check><<<block_nums_xy_tiling, block_dims, nbytes_shared, stream>>>
(args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr,
args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,

@@ -3065,8 +3522,7 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a

if (args.nrows_x % mmq_y == 0) {
constexpr bool need_check = false;
-
- mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_stream_k, block_dims, nbytes_shared, stream>>>
+ mul_mat_q<type, mmq_x, need_check><<<block_nums_stream_k, block_dims, nbytes_shared, stream>>>
(args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr,
args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,

@@ -3076,13 +3532,12 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
return;
}

- mul_mat_q_stream_k_fixup<type, mmq_x,
+ mul_mat_q_stream_k_fixup<type, mmq_x, need_check><<<block_nums_stream_k, block_dims, 0, stream>>>
(args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst,
args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst);
} else {
constexpr bool need_check = true;
-
- mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_stream_k, block_dims, nbytes_shared, stream>>>
+ mul_mat_q<type, mmq_x, need_check><<<block_nums_stream_k, block_dims, nbytes_shared, stream>>>
(args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr,
args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,

@@ -3092,7 +3547,7 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
return;
}

- mul_mat_q_stream_k_fixup<type, mmq_x,
+ mul_mat_q_stream_k_fixup<type, mmq_x, need_check><<<block_nums_stream_k, block_dims, 0, stream>>>
(args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst,
args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst);
}

@@ -3100,9 +3555,11 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a

template <ggml_type type>
void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) {
- const int id
- const int cc
- const size_t smpbo
+ const int id = ggml_cuda_get_device();
+ const int cc = ggml_cuda_info().devices[id].cc;
+ const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
+ const int warp_size = ggml_cuda_info().devices[id].warp_size;
+ const int nwarps = mmq_get_nwarps_host(cc);

const int mmq_x_max = get_mmq_x_max_host(cc);
const int mmq_y = get_mmq_y_host(cc);

@@ -3113,7 +3570,7 @@ void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cuda
for (int mmq_x = 8; mmq_x <= mmq_x_max && ntiles_x_best > 1; mmq_x += 8) {
const int granularity = mmq_get_granularity_host(mmq_x, cc);

- if (mmq_x % granularity != 0 || mmq_get_nbytes_shared<type>(mmq_x, mmq_y, cc) > smpbo) {
+ if (mmq_x % granularity != 0 || mmq_get_nbytes_shared<type>(mmq_x, mmq_y, cc, warp_size, nwarps) > smpbo) {
continue;
}
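The final hunk is the tile-width search in `mul_mat_q_case`: candidate widths are tried in steps of 8, widths that violate the granularity or exceed the per-block shared-memory budget `smpbo` are skipped, and the width that needs the fewest column tiles wins. A stand-alone sketch of that selection with a toy cost model (all constants are assumptions):

```cpp
#include <cstdio>

// Illustrative sketch of the tile-width search in mul_mat_q_case: try widths in steps
// of 8, skip those that break the granularity requirement or exceed the per-block
// shared-memory budget, and keep the one that needs the fewest column tiles.
// The byte model and constants are assumptions, not the real kernel's numbers.
int main() {
    const int ncols_dst   = 200;    // columns of the destination (assumed)
    const int mmq_x_max   = 128;
    const int granularity = 16;
    const size_t smpbo    = 48 * 1024; // per-block shared-memory budget (assumed)

    auto nbytes_shared = [](int mmq_x) -> size_t {
        return 1024 + (size_t) mmq_x * 512; // toy cost model: grows with tile width
    };

    int mmq_x_best  = 0;
    int ntiles_best = (ncols_dst + 8 - 1) / 8 + 1; // worse than any candidate

    for (int mmq_x = 8; mmq_x <= mmq_x_max && ntiles_best > 1; mmq_x += 8) {
        if (mmq_x % granularity != 0 || nbytes_shared(mmq_x) > smpbo) {
            continue;
        }
        const int ntiles = (ncols_dst + mmq_x - 1) / mmq_x;
        if (ntiles < ntiles_best) {
            mmq_x_best  = mmq_x;
            ntiles_best = ntiles;
        }
    }
    std::printf("chose mmq_x = %d (%d column tiles)\n", mmq_x_best, ntiles_best);
    return 0;
}
```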