@novastera-oss/llamarn 0.2.9 → 0.3.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/android/build.gradle +2 -1
- package/android/proguard-rules.pro +12 -0
- package/android/src/main/cpp/include/llama.h +15 -47
- package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml-base.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libllama.so +0 -0
- package/android/src/main/jniLibs/x86/libggml-base.so +0 -0
- package/android/src/main/jniLibs/x86/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/x86/libggml.so +0 -0
- package/android/src/main/jniLibs/x86/libllama.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
- package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
- package/cpp/build-info.cpp +2 -2
- package/cpp/llama.cpp/CMakeLists.txt +0 -1
- package/cpp/llama.cpp/CMakePresets.json +11 -0
- package/cpp/llama.cpp/CODEOWNERS +1 -0
- package/cpp/llama.cpp/README.md +8 -8
- package/cpp/llama.cpp/build-xcframework.sh +1 -1
- package/cpp/llama.cpp/common/CMakeLists.txt +4 -5
- package/cpp/llama.cpp/common/arg.cpp +62 -1
- package/cpp/llama.cpp/common/chat.cpp +37 -20
- package/cpp/llama.cpp/common/chat.h +2 -0
- package/cpp/llama.cpp/common/common.cpp +22 -6
- package/cpp/llama.cpp/common/common.h +22 -4
- package/cpp/llama.cpp/convert_hf_to_gguf.py +1250 -43
- package/cpp/llama.cpp/convert_hf_to_gguf_update.py +21 -13
- package/cpp/llama.cpp/ggml/CMakeLists.txt +13 -3
- package/cpp/llama.cpp/ggml/cmake/ggml-config.cmake.in +85 -47
- package/cpp/llama.cpp/ggml/include/ggml-backend.h +1 -1
- package/cpp/llama.cpp/ggml/include/ggml-webgpu.h +19 -0
- package/cpp/llama.cpp/ggml/include/ggml.h +173 -10
- package/cpp/llama.cpp/ggml/src/CMakeLists.txt +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-alloc.c +0 -15
- package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +7 -8
- package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +44 -38
- package/cpp/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +3 -1
- package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +126 -8
- package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +130 -22
- package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +138 -18
- package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +11 -3
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +28 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +109 -12
- package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +3 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +88 -10
- package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +343 -1094
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +1206 -163
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.h +6 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +0 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +36 -9
- package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +142 -9
- package/cpp/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +3 -3
- package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +31 -4
- package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cu +86 -17
- package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cuh +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy-utils.cuh +225 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy.cu +41 -301
- package/cpp/llama.cpp/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh +85 -64
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +47 -60
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f16.cu +29 -42
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f32.cu +46 -59
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f16.cuh +36 -45
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f32.cuh +38 -45
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +23 -36
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn.cu +3 -13
- package/cpp/llama.cpp/ggml/src/ggml-cuda/getrows.cu +8 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +255 -99
- package/cpp/llama.cpp/ggml/src/ggml-cuda/im2col.cu +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mma.cuh +111 -3
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cu +6 -4
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cuh +1152 -695
- package/cpp/llama.cpp/ggml/src/ggml-cuda/norm.cu +92 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/norm.cuh +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/rope.cu +21 -27
- package/cpp/llama.cpp/ggml/src/ggml-cuda/scale.cu +8 -6
- package/cpp/llama.cpp/ggml/src/ggml-cuda/set-rows.cu +275 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/set-rows.cuh +7 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/softmax.cu +119 -58
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-conv.cu +10 -2
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +192 -52
- package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cu +104 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cuh +13 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/upscale.cu +92 -6
- package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +27 -6
- package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-impl.h +80 -0
- package/cpp/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +4 -2
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +48 -12
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +572 -106
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +599 -105
- package/cpp/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +18 -4
- package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +800 -42
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/conv2d.cl +185 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl +176 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/glu.cl +337 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/im2col_f16.cl +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/im2col_f32.cl +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/rms_norm.cl +79 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/set_rows.cl +95 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +24 -11
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +24 -11
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f16.cl +24 -11
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f32.cl +24 -11
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/upscale.cl +2 -3
- package/cpp/llama.cpp/ggml/src/ggml-quants.c +6 -6
- package/cpp/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +4 -4
- package/cpp/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +693 -1034
- package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +18 -9
- package/cpp/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +14 -26
- package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +191 -55
- package/cpp/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-sycl/quants.hpp +8 -9
- package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +15 -18
- package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.cpp +131 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.hpp +8 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +2 -6
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +991 -307
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +265 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +59 -12
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +28 -23
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +14 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +38 -32
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +32 -27
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +44 -12
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +17 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +3 -8
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +128 -72
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +38 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +18 -3
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +1 -4
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rte.comp +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +20 -4
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +69 -5
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +84 -9
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/CMakeLists.txt +54 -0
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/ggml-webgpu.cpp +907 -0
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +60 -0
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +35 -0
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +40 -0
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl +56 -0
- package/cpp/llama.cpp/ggml/src/ggml.c +386 -67
- package/cpp/llama.cpp/ggml/src/gguf.cpp +8 -1
- package/cpp/llama.cpp/gguf-py/gguf/constants.py +307 -0
- package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +8 -2
- package/cpp/llama.cpp/gguf-py/gguf/metadata.py +4 -0
- package/cpp/llama.cpp/gguf-py/gguf/scripts/gguf_dump.py +24 -1
- package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +122 -47
- package/cpp/llama.cpp/gguf-py/gguf/vocab.py +12 -3
- package/cpp/llama.cpp/include/llama.h +15 -47
- package/cpp/llama.cpp/models/templates/llama-cpp-rwkv-world.jinja +34 -0
- package/cpp/llama.cpp/models/templates/moonshotai-Kimi-K2.jinja +43 -0
- package/cpp/llama.cpp/requirements/requirements-all.txt +1 -0
- package/cpp/llama.cpp/requirements/requirements-server-bench.txt +5 -0
- package/cpp/llama.cpp/src/llama-arch.cpp +316 -3
- package/cpp/llama.cpp/src/llama-arch.h +23 -1
- package/cpp/llama.cpp/src/llama-batch.cpp +103 -71
- package/cpp/llama.cpp/src/llama-batch.h +31 -18
- package/cpp/llama.cpp/src/llama-chat.cpp +58 -1
- package/cpp/llama.cpp/src/llama-chat.h +3 -0
- package/cpp/llama.cpp/src/llama-context.cpp +180 -106
- package/cpp/llama.cpp/src/llama-context.h +26 -16
- package/cpp/llama.cpp/src/llama-cparams.h +3 -2
- package/cpp/llama.cpp/src/llama-graph.cpp +310 -211
- package/cpp/llama.cpp/src/llama-graph.h +184 -122
- package/cpp/llama.cpp/src/llama-hparams.cpp +47 -1
- package/cpp/llama.cpp/src/llama-hparams.h +13 -2
- package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +38 -22
- package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +7 -2
- package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +849 -304
- package/cpp/llama.cpp/src/llama-kv-cache-unified.h +143 -47
- package/cpp/llama.cpp/src/llama-kv-cells.h +62 -10
- package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +10 -4
- package/cpp/llama.cpp/src/llama-memory-hybrid.h +3 -1
- package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +36 -11
- package/cpp/llama.cpp/src/llama-memory.cpp +17 -0
- package/cpp/llama.cpp/src/llama-memory.h +3 -0
- package/cpp/llama.cpp/src/llama-model.cpp +3545 -719
- package/cpp/llama.cpp/src/llama-model.h +21 -4
- package/cpp/llama.cpp/src/llama-quant.cpp +2 -2
- package/cpp/llama.cpp/src/llama-vocab.cpp +376 -10
- package/cpp/llama.cpp/src/llama-vocab.h +43 -0
- package/cpp/llama.cpp/src/unicode.cpp +207 -0
- package/cpp/llama.cpp/src/unicode.h +2 -0
- package/ios/include/chat.h +2 -0
- package/ios/include/common.h +22 -4
- package/ios/include/llama.h +15 -47
- package/ios/libs/llama.xcframework/Info.plist +13 -13
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5267 -4890
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +15 -47
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5238 -4861
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4014 -3764
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +15 -47
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5238 -4861
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4016 -3766
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +15 -47
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +15 -47
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +15 -47
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5267 -4890
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +15 -47
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5238 -4861
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4014 -3764
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +15 -47
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5303 -4926
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +15 -47
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5274 -4897
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4044 -3794
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +15 -47
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/package.json +4 -4
- package/cpp/llama.cpp/ggml/include/ggml-kompute.h +0 -50
- package/cpp/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
- package/cpp/llama.cpp/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
|
@@ -0,0 +1,265 @@
|
|
|
1
|
+
#version 450
|
|
2
|
+
|
|
3
|
+
#ifdef USE_COLLECTIVES
|
|
4
|
+
# extension GL_KHR_shader_subgroup_shuffle : enable
|
|
5
|
+
#endif
|
|
6
|
+
|
|
7
|
+
#include "types.comp"
|
|
8
|
+
|
|
9
|
+
// TODO: make SHMEM_PAD a specialization constant instead of a preprocessor define
|
|
10
|
+
#define SHMEM_PAD 0
|
|
11
|
+
|
|
12
|
+
// shape notation: [dim(N), ..., dim(0)] -- stride(dim(j)) >= stride(dim(i)) if i > j
|
|
13
|
+
layout(binding = 0) readonly buffer A {
|
|
14
|
+
A_TYPE knl_data[];
|
|
15
|
+
}; // src0 - kernel: [KW, KH, Cin, Cout]
|
|
16
|
+
|
|
17
|
+
layout(binding = 1) readonly buffer B {
|
|
18
|
+
B_TYPE src_data[];
|
|
19
|
+
}; // src1 - input: [W, H, Cin, N] -- channel_first format
|
|
20
|
+
|
|
21
|
+
layout(binding = 2) writeonly buffer D {
|
|
22
|
+
D_TYPE dst_data[];
|
|
23
|
+
}; // dst - result: [OW, OH, Cout, N]
|
|
24
|
+
|
|
25
|
+
layout(push_constant) uniform parameter {
|
|
26
|
+
// I/O channels, batch size
|
|
27
|
+
uint32_t Cout;
|
|
28
|
+
uint32_t Cin;
|
|
29
|
+
uint32_t N;
|
|
30
|
+
|
|
31
|
+
// Tensor spatial sizes: kernel, input, output
|
|
32
|
+
uint32_t KW;
|
|
33
|
+
uint32_t KH;
|
|
34
|
+
uint32_t W;
|
|
35
|
+
uint32_t H;
|
|
36
|
+
uint32_t OW;
|
|
37
|
+
uint32_t OH;
|
|
38
|
+
|
|
39
|
+
// Parameters: stride, padding, dilation - 0=x (width), 1=y (height)
|
|
40
|
+
uint32_t s0;
|
|
41
|
+
uint32_t s1;
|
|
42
|
+
uint32_t p0;
|
|
43
|
+
uint32_t p1;
|
|
44
|
+
uint32_t d0;
|
|
45
|
+
uint32_t d1;
|
|
46
|
+
|
|
47
|
+
// Strides in elements
|
|
48
|
+
uint32_t nb01;
|
|
49
|
+
uint32_t nb02;
|
|
50
|
+
uint32_t nb03;
|
|
51
|
+
|
|
52
|
+
uint32_t nb11;
|
|
53
|
+
uint32_t nb12;
|
|
54
|
+
uint32_t nb13;
|
|
55
|
+
|
|
56
|
+
uint32_t nb1;
|
|
57
|
+
uint32_t nb2;
|
|
58
|
+
uint32_t nb3;
|
|
59
|
+
}
|
|
60
|
+
|
|
61
|
+
p;
|
|
62
|
+
|
|
63
|
+
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
|
64
|
+
// Blocktile sizes
|
|
65
|
+
layout(constant_id = 1) const uint BS_K = 128;
|
|
66
|
+
layout(constant_id = 2) const uint BS_CRS = 16;
|
|
67
|
+
layout(constant_id = 3) const uint BS_NPQ = 128;
|
|
68
|
+
// Thread-tile sizes
|
|
69
|
+
layout(constant_id = 4) const uint TS_K = 8;
|
|
70
|
+
layout(constant_id = 5) const uint use_collectives = 1;
|
|
71
|
+
|
|
72
|
+
uint32_t tid = gl_LocalInvocationID.x;
|
|
73
|
+
const uint32_t WG_SIZE = gl_WorkGroupSize.x;
|
|
74
|
+
|
|
75
|
+
// Number of blocks of `block_size` elements needed to cover `work_size`
// elements, i.e. ceil(work_size / block_size).
//
// Computed as quotient plus a remainder check rather than
// (work_size + block_size - 1) / block_size, so the intermediate sum cannot
// wrap around in 32-bit unsigned arithmetic for large work sizes.
// `block_size` must be non-zero (it is used as a divisor).
uint splitWork(uint work_size, uint block_size) {
    uint full_blocks = work_size / block_size;
    uint leftover    = work_size % block_size;
    // A partially filled trailing block still counts as one block.
    return (leftover != 0u) ? (full_blocks + 1u) : full_blocks;
}
|
|
78
|
+
|
|
79
|
+
uint32_t K = p.Cout;
|
|
80
|
+
uint32_t CRS = p.Cin * p.KH * p.KW;
|
|
81
|
+
uint32_t NPQ = p.N * p.OH * p.OW;
|
|
82
|
+
|
|
83
|
+
uint32_t n_elems_out = K * NPQ;
|
|
84
|
+
|
|
85
|
+
// Number of blocktiles per input
|
|
86
|
+
uint32_t NB_CRS = splitWork(CRS, BS_CRS);
|
|
87
|
+
|
|
88
|
+
const uint32_t Ash_stride = BS_CRS + SHMEM_PAD;
|
|
89
|
+
const uint32_t Bsh_stride = BS_NPQ + SHMEM_PAD;
|
|
90
|
+
|
|
91
|
+
const uint32_t Ash_numel = BS_K * BS_CRS;
|
|
92
|
+
const uint32_t Bsh_numel = BS_CRS * BS_NPQ;
|
|
93
|
+
|
|
94
|
+
const uint32_t Ash_len = BS_K * Ash_stride;
|
|
95
|
+
const uint32_t Bsh_len = BS_CRS * Bsh_stride;
|
|
96
|
+
|
|
97
|
+
shared float Ash[Ash_len]; // K x CRS
|
|
98
|
+
shared float Bsh[Bsh_len]; // CRS x NPQ
|
|
99
|
+
|
|
100
|
+
// Threadtile sizes
|
|
101
|
+
const uint32_t TS_NPQ = BS_K * BS_NPQ / WG_SIZE / TS_K;
|
|
102
|
+
|
|
103
|
+
// Number of threadtiles per blocktile
|
|
104
|
+
const uint32_t NT_K = BS_K / TS_K;
|
|
105
|
+
const uint32_t NT_NPQ = BS_NPQ / TS_NPQ;
|
|
106
|
+
|
|
107
|
+
float regA[TS_K];
|
|
108
|
+
float regB[TS_NPQ];
|
|
109
|
+
float regC[TS_K][TS_NPQ];
|
|
110
|
+
|
|
111
|
+
/*
|
|
112
|
+
Compute
|
|
113
|
+
KxCRS @ CRSxNPQ = K x NPQ
|
|
114
|
+
K=Cout
|
|
115
|
+
C=Cin
|
|
116
|
+
R,S=KH,KW
|
|
117
|
+
P,Q=OH,OW
|
|
118
|
+
*/
|
|
119
|
+
|
|
120
|
+
uint32_t B_idx_K = gl_WorkGroupID.x;
|
|
121
|
+
uint32_t B_idx_NPQ = gl_WorkGroupID.y;
|
|
122
|
+
|
|
123
|
+
uint32_t T_y = tid / NT_NPQ;
|
|
124
|
+
uint32_t T_x = tid % NT_NPQ;
|
|
125
|
+
|
|
126
|
+
uint32_t Ar = tid / BS_CRS;
|
|
127
|
+
uint32_t Ac = tid % BS_CRS;
|
|
128
|
+
const uint32_t ArpWg = WG_SIZE / BS_CRS;
|
|
129
|
+
|
|
130
|
+
uint32_t Br = tid / BS_NPQ;
|
|
131
|
+
uint32_t Bc = tid % BS_NPQ;
|
|
132
|
+
const uint32_t BrpWg = WG_SIZE / BS_NPQ;
|
|
133
|
+
|
|
134
|
+
void main() {
|
|
135
|
+
for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
|
|
136
|
+
for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
|
|
137
|
+
regC[T_ly][T_lx] = 0.0;
|
|
138
|
+
}
|
|
139
|
+
}
|
|
140
|
+
/* Advance block in CRS dim */
|
|
141
|
+
for (uint32_t B_idx_CRS = 0; B_idx_CRS < NB_CRS; B_idx_CRS++) {
|
|
142
|
+
uint32_t CRS_idx_a;
|
|
143
|
+
uint32_t Cin_idx_a;
|
|
144
|
+
uint32_t KH_idx_a;
|
|
145
|
+
uint32_t KW_idx_a;
|
|
146
|
+
|
|
147
|
+
#ifdef USE_COLLECTIVES
|
|
148
|
+
uint32_t cached_CRS_idx;
|
|
149
|
+
uint32_t cached_Cin_idx;
|
|
150
|
+
uint32_t cached_KH_idx;
|
|
151
|
+
uint32_t cached_KW_idx;
|
|
152
|
+
if (use_collectives == 1) {
|
|
153
|
+
cached_CRS_idx = B_idx_CRS * BS_CRS + gl_SubgroupInvocationID;
|
|
154
|
+
cached_Cin_idx = cached_CRS_idx / (p.KW * p.KH);
|
|
155
|
+
uint32_t cached_CRS_remainder = (cached_CRS_idx - cached_Cin_idx * p.KW * p.KH);
|
|
156
|
+
cached_KH_idx = cached_CRS_remainder / p.KW;
|
|
157
|
+
cached_KW_idx = cached_CRS_remainder - cached_KH_idx * p.KW;
|
|
158
|
+
|
|
159
|
+
CRS_idx_a = subgroupShuffle(cached_CRS_idx, Ac);
|
|
160
|
+
Cin_idx_a = subgroupShuffle(cached_Cin_idx, Ac);
|
|
161
|
+
KH_idx_a = subgroupShuffle(cached_KH_idx, Ac);
|
|
162
|
+
KW_idx_a = subgroupShuffle(cached_KW_idx, Ac);
|
|
163
|
+
} else {
|
|
164
|
+
CRS_idx_a = B_idx_CRS * BS_CRS + Ac; // Global CRS_idx_a (column index of A)
|
|
165
|
+
Cin_idx_a = CRS_idx_a / (p.KW * p.KH);
|
|
166
|
+
uint32_t CRS_remainder = CRS_idx_a - Cin_idx_a * p.KW * p.KH;
|
|
167
|
+
KH_idx_a = CRS_remainder / p.KW;
|
|
168
|
+
KW_idx_a = CRS_remainder - KH_idx_a * p.KW;
|
|
169
|
+
}
|
|
170
|
+
#else
|
|
171
|
+
CRS_idx_a = B_idx_CRS * BS_CRS + Ac; // Global CRS_idx_a (column index of A)
|
|
172
|
+
Cin_idx_a = CRS_idx_a / (p.KW * p.KH);
|
|
173
|
+
CRS_remainder = CRS_idx_a - Cin_idx_a * p.KW * p.KH;
|
|
174
|
+
KH_idx_a = CRS_remainder / p.KW;
|
|
175
|
+
KW_idx_a = CRS_remainder - KH_idx_a * p.KW;
|
|
176
|
+
#endif
|
|
177
|
+
|
|
178
|
+
/* Load kernel to A_block: (BS_K x BS_CRS)*/
|
|
179
|
+
for (uint32_t r_offset = 0; r_offset < BS_K; r_offset += ArpWg) {
|
|
180
|
+
uint32_t B_ly = r_offset + Ar;
|
|
181
|
+
uint32_t B_lx = Ac;
|
|
182
|
+
uint32_t K_idx = B_idx_K * BS_K + B_ly; /* Global K_idx (row index of A)*/
|
|
183
|
+
uint32_t knl_idx = min(KW_idx_a + KH_idx_a * p.nb01 + Cin_idx_a * p.nb02 + K_idx * p.nb03, K * CRS - 1);
|
|
184
|
+
float val = knl_data[knl_idx];
|
|
185
|
+
if (K_idx >= K || CRS_idx_a >= CRS) {
|
|
186
|
+
val = 0.0;
|
|
187
|
+
}
|
|
188
|
+
Ash[B_ly * Ash_stride + B_lx] = val;
|
|
189
|
+
}
|
|
190
|
+
/* Load input to B_block: (BS_CRS x BS_NPQ) */
|
|
191
|
+
for (uint32_t r_offset = 0; r_offset < BS_CRS; r_offset += BrpWg) {
|
|
192
|
+
uint32_t B_ly = r_offset + Br; /* Row index of B block */
|
|
193
|
+
uint32_t B_lx = Bc;
|
|
194
|
+
uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + B_lx; /* Global NPQ index (column index of B) */
|
|
195
|
+
uint32_t N_idx = NPQ_idx / (p.OH * p.OW);
|
|
196
|
+
uint32_t NPQ_remainder = NPQ_idx - N_idx * p.OH * p.OW;
|
|
197
|
+
uint32_t OH_idx = NPQ_remainder / p.OW;
|
|
198
|
+
uint32_t OW_idx = NPQ_remainder - OH_idx * p.OW;
|
|
199
|
+
|
|
200
|
+
uint32_t CRS_idx_b;
|
|
201
|
+
uint32_t Cin_idx_b;
|
|
202
|
+
uint32_t KH_idx_b;
|
|
203
|
+
uint32_t KW_idx_b;
|
|
204
|
+
#ifdef USE_COLLECTIVES
|
|
205
|
+
if (use_collectives == 1) {
|
|
206
|
+
CRS_idx_b = subgroupShuffle(cached_CRS_idx, r_offset + Br);
|
|
207
|
+
Cin_idx_b = subgroupShuffle(cached_Cin_idx, r_offset + Br);
|
|
208
|
+
KH_idx_b = subgroupShuffle(cached_KH_idx, r_offset + Br);
|
|
209
|
+
KW_idx_b = subgroupShuffle(cached_KW_idx, r_offset + Br);
|
|
210
|
+
} else {
|
|
211
|
+
CRS_idx_b = B_idx_CRS * BS_CRS + B_ly; /* Global CRS index (row index of B) */
|
|
212
|
+
Cin_idx_b = CRS_idx_b / (p.KW * p.KH);
|
|
213
|
+
uint32_t CRS_remainder = CRS_idx_b - Cin_idx_b * p.KW * p.KH;
|
|
214
|
+
KH_idx_b = CRS_remainder / p.KW;
|
|
215
|
+
KW_idx_b = CRS_remainder - KH_idx_b * p.KW;
|
|
216
|
+
}
|
|
217
|
+
#else
|
|
218
|
+
CRS_idx_b = B_idx_CRS * BS_CRS + B_ly; /* Global CRS index (row index of B) */
|
|
219
|
+
Cin_idx_b = CRS_idx_b / (p.KW * p.KH);
|
|
220
|
+
uint32_t CRS_remainder = CRS_idx_b - Cin_idx_b * p.KW * p.KH;
|
|
221
|
+
KH_idx_b = CRS_remainder / p.KW;
|
|
222
|
+
KW_idx_b = CRS_remainder - KH_idx_b * p.KW;
|
|
223
|
+
#endif
|
|
224
|
+
|
|
225
|
+
uint32_t H_idx = OH_idx * p.s1 + KH_idx_b * p.d1 - p.p1;
|
|
226
|
+
uint32_t W_idx = OW_idx * p.s0 + KW_idx_b * p.d0 - p.p0;
|
|
227
|
+
uint32_t src_idx =
|
|
228
|
+
min(max(W_idx + H_idx * p.nb11 + Cin_idx_b * p.nb12 + N_idx * p.nb13, 0), p.Cin * p.N * p.W * p.H - 1);
|
|
229
|
+
float val = src_data[src_idx];
|
|
230
|
+
if (CRS_idx_b >= CRS || NPQ_idx >= NPQ || H_idx < 0 || H_idx >= p.H || W_idx < 0 || W_idx >= p.W) {
|
|
231
|
+
val = 0.0;
|
|
232
|
+
}
|
|
233
|
+
Bsh[B_ly * Bsh_stride + B_lx] = val;
|
|
234
|
+
}
|
|
235
|
+
barrier();
|
|
236
|
+
for (uint32_t CRS_lidx = 0; CRS_lidx < BS_CRS; CRS_lidx++) {
|
|
237
|
+
for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
|
|
238
|
+
regA[T_ly] = Ash[(T_y * TS_K + T_ly) * Ash_stride + CRS_lidx];
|
|
239
|
+
}
|
|
240
|
+
for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
|
|
241
|
+
regB[T_lx] = Bsh[CRS_lidx * Bsh_stride + T_x * TS_NPQ + T_lx];
|
|
242
|
+
}
|
|
243
|
+
for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
|
|
244
|
+
for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
|
|
245
|
+
regC[T_ly][T_lx] = fma(regA[T_ly], regB[T_lx], regC[T_ly][T_lx]);
|
|
246
|
+
}
|
|
247
|
+
}
|
|
248
|
+
}
|
|
249
|
+
barrier();
|
|
250
|
+
}
|
|
251
|
+
/* Save C* */
|
|
252
|
+
for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
|
|
253
|
+
for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
|
|
254
|
+
uint32_t K_idx = B_idx_K * BS_K + T_y * TS_K + T_ly;
|
|
255
|
+
uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + T_x * TS_NPQ + T_lx;
|
|
256
|
+
uint32_t N_idx = NPQ_idx / (p.OH * p.OW);
|
|
257
|
+
uint32_t OH_idx = (NPQ_idx - N_idx * p.OH * p.OW) / p.OW;
|
|
258
|
+
uint32_t OW_idx = NPQ_idx - N_idx * p.OH * p.OW - OH_idx * p.OW;
|
|
259
|
+
uint32_t dst_idx = OW_idx + OH_idx * p.nb1 + K_idx * p.nb2 + N_idx * p.nb3;
|
|
260
|
+
if (K_idx < K && NPQ_idx < NPQ) {
|
|
261
|
+
dst_data[dst_idx] = regC[T_ly][T_lx];
|
|
262
|
+
}
|
|
263
|
+
}
|
|
264
|
+
}
|
|
265
|
+
}
|
|
@@ -1,22 +1,26 @@
|
|
|
1
1
|
#version 450
|
|
2
2
|
|
|
3
|
-
#
|
|
4
|
-
#extension GL_EXT_spirv_intrinsics : enable
|
|
5
|
-
spirv_execution_mode(capabilities = [4467], 4462, 16); // RoundingModeRTE, 16 bits
|
|
6
|
-
#endif // RTE16
|
|
7
|
-
|
|
3
|
+
#include "rte.comp"
|
|
8
4
|
#include "types.comp"
|
|
9
|
-
#include "generic_unary_head.comp"
|
|
10
5
|
|
|
11
|
-
#if defined(
|
|
12
|
-
|
|
13
|
-
|
|
6
|
+
#if defined(SET_ROWS) && QUANT_K == 1
|
|
7
|
+
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
|
|
8
|
+
const uint BLOCK_SIZE = 512;
|
|
14
9
|
#else
|
|
15
|
-
layout(local_size_x =
|
|
10
|
+
layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
|
|
11
|
+
const uint BLOCK_SIZE = 32;
|
|
16
12
|
#endif
|
|
17
13
|
|
|
18
14
|
layout (binding = 0) readonly buffer S {float data_s[];};
|
|
15
|
+
|
|
16
|
+
#if defined(SET_ROWS)
|
|
17
|
+
#include "generic_binary_head.comp"
|
|
18
|
+
layout (binding = 1) readonly buffer C {uvec2 data_i[];};
|
|
19
|
+
layout (binding = 2) writeonly buffer Q {A_TYPE data_q[];};
|
|
20
|
+
#else
|
|
21
|
+
#include "generic_unary_head.comp"
|
|
19
22
|
layout (binding = 1) writeonly buffer Q {A_TYPE data_q[];};
|
|
23
|
+
#endif
|
|
20
24
|
|
|
21
25
|
#if defined(DATA_A_Q4_0)
|
|
22
26
|
void quantize(uint dst_idx, uint src_idx)
|
|
@@ -221,15 +225,56 @@ void quantize(uint dst_idx, uint src_idx)
|
|
|
221
225
|
}
|
|
222
226
|
#endif
|
|
223
227
|
|
|
228
|
+
#if defined(DATA_A_F32) || defined(DATA_A_F16)
|
|
229
|
+
void quantize(uint dst_idx, uint src_idx)
|
|
230
|
+
{
|
|
231
|
+
data_q[dst_idx] = A_TYPE(data_s[src_idx]);
|
|
232
|
+
}
|
|
233
|
+
#endif
|
|
234
|
+
|
|
235
|
+
#if defined(DATA_A_BF16)
|
|
236
|
+
void quantize(uint dst_idx, uint src_idx)
|
|
237
|
+
{
|
|
238
|
+
data_q[dst_idx] = A_TYPE(fp32_to_bf16(data_s[src_idx]));
|
|
239
|
+
}
|
|
240
|
+
#endif
|
|
241
|
+
|
|
242
|
+
#if defined(SET_ROWS)
|
|
243
|
+
|
|
224
244
|
void main() {
|
|
225
245
|
#ifdef NEEDS_INIT_IQ_SHMEM
|
|
226
246
|
init_iq_shmem(gl_WorkGroupSize);
|
|
227
|
-
|
|
247
|
+
#endif
|
|
248
|
+
|
|
249
|
+
const uint idx = ((gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x) * BLOCK_SIZE + gl_LocalInvocationID.x) * QUANT_K;
|
|
250
|
+
|
|
251
|
+
if (idx >= p.ne) {
|
|
228
252
|
return;
|
|
229
253
|
}
|
|
254
|
+
|
|
255
|
+
uint i00, i01, i02, i03;
|
|
256
|
+
get_indices(idx, i00, i01, i02, i03);
|
|
257
|
+
|
|
258
|
+
uint i12 = fastmod(i03, p.ne12);
|
|
259
|
+
uint i11 = fastmod(i02, p.ne11);
|
|
260
|
+
uint i10 = i01;
|
|
261
|
+
|
|
262
|
+
uint i1 = data_i[src1_idx(i10, i11, i12, 0) + get_boffset()].x;
|
|
263
|
+
|
|
264
|
+
uint src0_idx = src0_idx(i00, i01, i02, i03) + get_aoffset();
|
|
265
|
+
uint dst_idx = dst_idx(i00 / QUANT_K, i1, i02, i03) + get_doffset();
|
|
266
|
+
|
|
267
|
+
quantize(dst_idx, src0_idx);
|
|
268
|
+
}
|
|
269
|
+
|
|
270
|
+
#else
|
|
271
|
+
|
|
272
|
+
void main() {
|
|
273
|
+
#ifdef NEEDS_INIT_IQ_SHMEM
|
|
274
|
+
init_iq_shmem(gl_WorkGroupSize);
|
|
230
275
|
#endif
|
|
231
276
|
|
|
232
|
-
const uint idx = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x * QUANT_K;
|
|
277
|
+
const uint idx = (gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x) * QUANT_K;
|
|
233
278
|
|
|
234
279
|
if (idx >= p.ne) {
|
|
235
280
|
return;
|
|
@@ -240,3 +285,5 @@ void main() {
|
|
|
240
285
|
|
|
241
286
|
quantize(dst_idx, src_idx);
|
|
242
287
|
}
|
|
288
|
+
|
|
289
|
+
#endif
|
|
@@ -10,7 +10,7 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
|
|
|
10
10
|
void main() {
|
|
11
11
|
[[unroll]] for (uint wgy = 0; wgy < 256; wgy++) {
|
|
12
12
|
const uint i = gl_WorkGroupID.x * 256 + wgy;
|
|
13
|
-
if (i >= p.
|
|
13
|
+
if (i >= p.nel / QUANT_K) {
|
|
14
14
|
return;
|
|
15
15
|
}
|
|
16
16
|
|
|
@@ -10,7 +10,7 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
|
|
|
10
10
|
void main() {
|
|
11
11
|
[[unroll]] for (uint wgy = 0; wgy < 256; wgy++) {
|
|
12
12
|
const uint i = uint(gl_WorkGroupID.x * 256 + wgy);
|
|
13
|
-
if (i >= p.
|
|
13
|
+
if (i >= p.nel / QUANT_K) {
|
|
14
14
|
return;
|
|
15
15
|
}
|
|
16
16
|
|
|
@@ -10,7 +10,7 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
|
|
|
10
10
|
void main() {
|
|
11
11
|
[[unroll]] for (uint wgy = 0; wgy < 256; wgy++) {
|
|
12
12
|
const uint ib = gl_WorkGroupID.x * 256 + wgy;
|
|
13
|
-
if (ib >= p.
|
|
13
|
+
if (ib >= p.nel / QUANT_K) {
|
|
14
14
|
return;
|
|
15
15
|
}
|
|
16
16
|
|
|
@@ -10,7 +10,7 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
|
|
|
10
10
|
void main() {
|
|
11
11
|
[[unroll]] for (uint wgy = 0; wgy < 256; wgy++) {
|
|
12
12
|
const uint ib = gl_WorkGroupID.x * 256 + wgy;
|
|
13
|
-
if (ib >= p.
|
|
13
|
+
if (ib >= p.nel / QUANT_K) {
|
|
14
14
|
return;
|
|
15
15
|
}
|
|
16
16
|
|
|
@@ -10,7 +10,7 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
|
|
|
10
10
|
void main() {
|
|
11
11
|
[[unroll]] for (uint wgy = 0; wgy < 256; wgy++) {
|
|
12
12
|
const uint i = gl_WorkGroupID.x * 256 + wgy;
|
|
13
|
-
if (i >= p.
|
|
13
|
+
if (i >= p.nel / QUANT_K) {
|
|
14
14
|
return;
|
|
15
15
|
}
|
|
16
16
|
const uint tid = gl_LocalInvocationID.x;
|
|
@@ -11,7 +11,8 @@
|
|
|
11
11
|
#include "types.comp"
|
|
12
12
|
#include "flash_attn_base.comp"
|
|
13
13
|
|
|
14
|
-
const uint32_t
|
|
14
|
+
const uint32_t HSK_per_thread = HSK / D_split;
|
|
15
|
+
const uint32_t HSV_per_thread = HSV / D_split;
|
|
15
16
|
|
|
16
17
|
const uint32_t cols_per_iter = WorkGroupSize / D_split;
|
|
17
18
|
const uint32_t cols_per_thread = Bc / cols_per_iter;
|
|
@@ -29,7 +30,7 @@ layout (binding = 3) readonly buffer M {float16_t data_m[];};
|
|
|
29
30
|
// Rows index by Q's dimension 2, and the first N rows are valid.
|
|
30
31
|
D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
|
|
31
32
|
{
|
|
32
|
-
uint32_t offset = (iq2 + r) *
|
|
33
|
+
uint32_t offset = (iq2 + r) * HSV + c;
|
|
33
34
|
data_o[o_offset + offset] = D_TYPE(elem);
|
|
34
35
|
return elem;
|
|
35
36
|
}
|
|
@@ -38,7 +39,7 @@ shared FLOAT_TYPE tmpsh[WorkGroupSize];
|
|
|
38
39
|
shared vec4 tmpshv4[WorkGroupSize];
|
|
39
40
|
|
|
40
41
|
shared float masksh[Bc][Br];
|
|
41
|
-
shared vec4 Qf[Br][
|
|
42
|
+
shared vec4 Qf[Br][HSK / 4];
|
|
42
43
|
|
|
43
44
|
void main() {
|
|
44
45
|
#ifdef NEEDS_INIT_IQ_SHMEM
|
|
@@ -53,18 +54,18 @@ void main() {
|
|
|
53
54
|
|
|
54
55
|
uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4;
|
|
55
56
|
|
|
56
|
-
[[unroll]] for (uint32_t idx = 0; idx < Br *
|
|
57
|
-
uint32_t d = (idx + tid) % (
|
|
58
|
-
uint32_t r = (idx + tid) / (
|
|
59
|
-
if (r < Br && d <
|
|
57
|
+
[[unroll]] for (uint32_t idx = 0; idx < Br * HSK / 4; idx += gl_WorkGroupSize.x) {
|
|
58
|
+
uint32_t d = (idx + tid) % (HSK / 4);
|
|
59
|
+
uint32_t r = (idx + tid) / (HSK / 4);
|
|
60
|
+
if (r < Br && d < HSK / 4 &&
|
|
60
61
|
i * Br + r < N) {
|
|
61
62
|
Qf[r][d] = vec4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d]) * p.scale;
|
|
62
63
|
}
|
|
63
64
|
}
|
|
64
65
|
barrier();
|
|
65
66
|
|
|
66
|
-
vec4 Of[Br][
|
|
67
|
-
[[unroll]] for (uint32_t d = 0; d <
|
|
67
|
+
vec4 Of[Br][HSV_per_thread / 4];
|
|
68
|
+
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
|
|
68
69
|
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
|
|
69
70
|
Of[r][d] = vec4(0.0);
|
|
70
71
|
}
|
|
@@ -99,6 +100,10 @@ void main() {
|
|
|
99
100
|
uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2;
|
|
100
101
|
uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2;
|
|
101
102
|
#endif
|
|
103
|
+
uint32_t m_offset = 0;
|
|
104
|
+
if (p.nem2 != 1 || p.nem3 != 1) {
|
|
105
|
+
m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV;
|
|
106
|
+
}
|
|
102
107
|
|
|
103
108
|
[[dont_unroll]]
|
|
104
109
|
for (uint32_t j = start_j; j < end_j; ++j) {
|
|
@@ -112,7 +117,7 @@ void main() {
|
|
|
112
117
|
|
|
113
118
|
|
|
114
119
|
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
|
|
115
|
-
[[unroll]] for (uint32_t d = 0; d <
|
|
120
|
+
[[unroll]] for (uint32_t d = 0; d < HSK_per_thread / 4; ++d) {
|
|
116
121
|
#if BLOCK_SIZE > 1
|
|
117
122
|
uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
|
|
118
123
|
uint ib = coord / BLOCK_SIZE;
|
|
@@ -144,13 +149,13 @@ void main() {
|
|
|
144
149
|
}
|
|
145
150
|
}
|
|
146
151
|
|
|
147
|
-
if (p.
|
|
152
|
+
if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
|
|
148
153
|
|
|
149
154
|
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
|
|
150
155
|
uint32_t c = (idx + tid) % Bc;
|
|
151
156
|
uint32_t r = (idx + tid) / Bc;
|
|
152
157
|
if (idx + tid < Bc * Br) {
|
|
153
|
-
masksh[c][r] = float(data_m[(i * Br + r) * m_stride + (j * Bc + c)]);
|
|
158
|
+
masksh[c][r] = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]);
|
|
154
159
|
}
|
|
155
160
|
}
|
|
156
161
|
barrier();
|
|
@@ -191,14 +196,14 @@ void main() {
|
|
|
191
196
|
Lf[r] = eMf[r]*Lf[r] + rowsumf[r];
|
|
192
197
|
}
|
|
193
198
|
|
|
194
|
-
[[unroll]] for (uint32_t d = 0; d <
|
|
199
|
+
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
|
|
195
200
|
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
|
|
196
201
|
Of[r][d] = eMf[r] * Of[r][d];
|
|
197
202
|
}
|
|
198
203
|
}
|
|
199
204
|
|
|
200
205
|
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
|
|
201
|
-
[[unroll]] for (uint32_t d = 0; d <
|
|
206
|
+
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
|
|
202
207
|
#if BLOCK_SIZE > 1
|
|
203
208
|
uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
|
|
204
209
|
uint ib = coord / BLOCK_SIZE;
|
|
@@ -255,7 +260,7 @@ void main() {
|
|
|
255
260
|
Lf[r] = tmpsh[d_tid];
|
|
256
261
|
barrier();
|
|
257
262
|
|
|
258
|
-
[[unroll]] for (uint32_t d = 0; d <
|
|
263
|
+
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
|
|
259
264
|
|
|
260
265
|
Of[r][d] = eMf * Of[r][d];
|
|
261
266
|
tmpshv4[tid] = Of[r][d];
|
|
@@ -277,11 +282,11 @@ void main() {
|
|
|
277
282
|
// If there is split_k, then the split_k resolve shader does the final
|
|
278
283
|
// division by L. Store the intermediate O value and per-row m and L values.
|
|
279
284
|
if (p.k_num > 1) {
|
|
280
|
-
uint32_t o_offset =
|
|
285
|
+
uint32_t o_offset = HSV * p.ne1 * (split_k_index + iq3 * p.k_num);
|
|
281
286
|
|
|
282
287
|
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
|
|
283
288
|
if (r < N) {
|
|
284
|
-
[[unroll]] for (uint32_t d = 0; d <
|
|
289
|
+
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
|
|
285
290
|
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
|
|
286
291
|
perElemOpGqaStore(r, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N);
|
|
287
292
|
}
|
|
@@ -289,7 +294,7 @@ void main() {
|
|
|
289
294
|
}
|
|
290
295
|
}
|
|
291
296
|
|
|
292
|
-
o_offset =
|
|
297
|
+
o_offset = HSV * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2;
|
|
293
298
|
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
|
|
294
299
|
if (r < N) {
|
|
295
300
|
perElemOpStoreCol0(r, 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N);
|
|
@@ -305,18 +310,18 @@ void main() {
|
|
|
305
310
|
Lfrcp[r] = 1.0 / Lf[r];
|
|
306
311
|
}
|
|
307
312
|
|
|
308
|
-
[[unroll]] for (uint32_t d = 0; d <
|
|
313
|
+
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
|
|
309
314
|
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
|
|
310
315
|
Of[r][d] *= Lfrcp[r];
|
|
311
316
|
}
|
|
312
317
|
}
|
|
313
318
|
|
|
314
|
-
uint32_t o_offset = iq3*p.ne2*p.ne1;
|
|
319
|
+
uint32_t o_offset = iq3*p.ne2*p.ne1*HSV;
|
|
315
320
|
|
|
316
321
|
if (p.gqa_ratio > 1) {
|
|
317
322
|
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
|
|
318
323
|
if (r < N) {
|
|
319
|
-
[[unroll]] for (uint32_t d = 0; d <
|
|
324
|
+
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
|
|
320
325
|
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
|
|
321
326
|
perElemOpGqaStore(r, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N);
|
|
322
327
|
}
|
|
@@ -326,9 +331,9 @@ void main() {
|
|
|
326
331
|
} else {
|
|
327
332
|
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
|
|
328
333
|
if (i * Br + r < N) {
|
|
329
|
-
[[unroll]] for (uint32_t d = 0; d <
|
|
334
|
+
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
|
|
330
335
|
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
|
|
331
|
-
data_o[o_offset + iq2 *
|
|
336
|
+
data_o[o_offset + iq2 * HSV + (i * Br + r) * p.ne1 * HSV + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]);
|
|
332
337
|
}
|
|
333
338
|
}
|
|
334
339
|
}
|
|
@@ -4,10 +4,10 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
|
|
4
4
|
layout (constant_id = 0) const uint32_t WorkGroupSize = 128;
|
|
5
5
|
layout (constant_id = 1) const uint32_t Br = 1;
|
|
6
6
|
layout (constant_id = 2) const uint32_t Bc = 32;
|
|
7
|
-
layout (constant_id = 3) const uint32_t
|
|
8
|
-
layout (constant_id = 4) const uint32_t
|
|
9
|
-
layout (constant_id = 5) const uint32_t
|
|
10
|
-
|
|
7
|
+
layout (constant_id = 3) const uint32_t HSK = 32;
|
|
8
|
+
layout (constant_id = 4) const uint32_t HSV = 32;
|
|
9
|
+
layout (constant_id = 5) const uint32_t Clamp = 0;
|
|
10
|
+
layout (constant_id = 6) const uint32_t D_split = 16;
|
|
11
11
|
|
|
12
12
|
layout (push_constant) uniform parameter {
|
|
13
13
|
uint32_t N;
|
|
@@ -24,6 +24,8 @@ layout (push_constant) uniform parameter {
|
|
|
24
24
|
uint32_t nev2;
|
|
25
25
|
uint32_t nev3;
|
|
26
26
|
uint32_t nem1;
|
|
27
|
+
uint32_t nem2;
|
|
28
|
+
uint32_t nem3;
|
|
27
29
|
|
|
28
30
|
uint32_t nb01;
|
|
29
31
|
uint32_t nb02;
|
|
@@ -34,14 +36,12 @@ layout (push_constant) uniform parameter {
|
|
|
34
36
|
uint32_t nb21;
|
|
35
37
|
uint32_t nb22;
|
|
36
38
|
uint32_t nb23;
|
|
37
|
-
uint32_t nb31;
|
|
38
39
|
|
|
39
40
|
float scale;
|
|
40
41
|
float max_bias;
|
|
41
42
|
float logit_softcap;
|
|
42
43
|
|
|
43
|
-
uint32_t
|
|
44
|
-
uint32_t n_head_log2;
|
|
44
|
+
uint32_t mask_n_head_log2;
|
|
45
45
|
float m0;
|
|
46
46
|
float m1;
|
|
47
47
|
|
|
@@ -50,6 +50,9 @@ layout (push_constant) uniform parameter {
|
|
|
50
50
|
uint32_t k_num;
|
|
51
51
|
} p;
|
|
52
52
|
|
|
53
|
+
#define MASK_ENABLE_BIT (1<<16)
|
|
54
|
+
#define N_LOG2_MASK 0xFFFF
|
|
55
|
+
|
|
53
56
|
layout (binding = 4) writeonly buffer O {D_TYPE data_o[];};
|
|
54
57
|
|
|
55
58
|
#if defined(A_TYPE_PACKED16)
|
|
@@ -100,8 +103,10 @@ ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const i
|
|
|
100
103
|
{
|
|
101
104
|
const uint32_t h = iq2 + (r % p.gqa_ratio);
|
|
102
105
|
|
|
103
|
-
|
|
104
|
-
|
|
106
|
+
uint32_t n_head_log2 = p.mask_n_head_log2 & N_LOG2_MASK;
|
|
107
|
+
|
|
108
|
+
const ACC_TYPE base = ACC_TYPE(h < n_head_log2 ? p.m0 : p.m1);
|
|
109
|
+
const int exph = int(h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1);
|
|
105
110
|
|
|
106
111
|
return ACC_TYPE(pow(base, ACC_TYPE(exph)));
|
|
107
112
|
}
|