@novastera-oss/llamarn 0.2.9 → 0.3.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/android/build.gradle +2 -1
- package/android/proguard-rules.pro +12 -0
- package/android/src/main/cpp/include/llama.h +15 -47
- package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml-base.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libllama.so +0 -0
- package/android/src/main/jniLibs/x86/libggml-base.so +0 -0
- package/android/src/main/jniLibs/x86/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/x86/libggml.so +0 -0
- package/android/src/main/jniLibs/x86/libllama.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
- package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
- package/cpp/build-info.cpp +2 -2
- package/cpp/llama.cpp/CMakeLists.txt +0 -1
- package/cpp/llama.cpp/CMakePresets.json +11 -0
- package/cpp/llama.cpp/CODEOWNERS +1 -0
- package/cpp/llama.cpp/README.md +8 -8
- package/cpp/llama.cpp/build-xcframework.sh +1 -1
- package/cpp/llama.cpp/common/CMakeLists.txt +4 -5
- package/cpp/llama.cpp/common/arg.cpp +62 -1
- package/cpp/llama.cpp/common/chat.cpp +37 -20
- package/cpp/llama.cpp/common/chat.h +2 -0
- package/cpp/llama.cpp/common/common.cpp +22 -6
- package/cpp/llama.cpp/common/common.h +22 -4
- package/cpp/llama.cpp/convert_hf_to_gguf.py +1250 -43
- package/cpp/llama.cpp/convert_hf_to_gguf_update.py +21 -13
- package/cpp/llama.cpp/ggml/CMakeLists.txt +13 -3
- package/cpp/llama.cpp/ggml/cmake/ggml-config.cmake.in +85 -47
- package/cpp/llama.cpp/ggml/include/ggml-backend.h +1 -1
- package/cpp/llama.cpp/ggml/include/ggml-webgpu.h +19 -0
- package/cpp/llama.cpp/ggml/include/ggml.h +173 -10
- package/cpp/llama.cpp/ggml/src/CMakeLists.txt +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-alloc.c +0 -15
- package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +7 -8
- package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +44 -38
- package/cpp/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +3 -1
- package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +126 -8
- package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +130 -22
- package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +138 -18
- package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +11 -3
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +28 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +109 -12
- package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +3 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +88 -10
- package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +343 -1094
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +1206 -163
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.h +6 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +0 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +36 -9
- package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +142 -9
- package/cpp/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +3 -3
- package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +31 -4
- package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cu +86 -17
- package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cuh +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy-utils.cuh +225 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy.cu +41 -301
- package/cpp/llama.cpp/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh +85 -64
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +47 -60
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f16.cu +29 -42
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f32.cu +46 -59
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f16.cuh +36 -45
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f32.cuh +38 -45
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +23 -36
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn.cu +3 -13
- package/cpp/llama.cpp/ggml/src/ggml-cuda/getrows.cu +8 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +255 -99
- package/cpp/llama.cpp/ggml/src/ggml-cuda/im2col.cu +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mma.cuh +111 -3
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cu +6 -4
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cuh +1152 -695
- package/cpp/llama.cpp/ggml/src/ggml-cuda/norm.cu +92 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/norm.cuh +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/rope.cu +21 -27
- package/cpp/llama.cpp/ggml/src/ggml-cuda/scale.cu +8 -6
- package/cpp/llama.cpp/ggml/src/ggml-cuda/set-rows.cu +275 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/set-rows.cuh +7 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/softmax.cu +119 -58
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-conv.cu +10 -2
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +192 -52
- package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cu +104 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cuh +13 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/upscale.cu +92 -6
- package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +27 -6
- package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-impl.h +80 -0
- package/cpp/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +4 -2
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +48 -12
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +572 -106
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +599 -105
- package/cpp/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +18 -4
- package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +800 -42
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/conv2d.cl +185 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl +176 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/glu.cl +337 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/im2col_f16.cl +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/im2col_f32.cl +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/rms_norm.cl +79 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/set_rows.cl +95 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +24 -11
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +24 -11
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f16.cl +24 -11
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f32.cl +24 -11
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/upscale.cl +2 -3
- package/cpp/llama.cpp/ggml/src/ggml-quants.c +6 -6
- package/cpp/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +4 -4
- package/cpp/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +693 -1034
- package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +18 -9
- package/cpp/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +14 -26
- package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +191 -55
- package/cpp/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-sycl/quants.hpp +8 -9
- package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +15 -18
- package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.cpp +131 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.hpp +8 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +2 -6
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +991 -307
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +265 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +59 -12
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +28 -23
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +14 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +38 -32
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +32 -27
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +44 -12
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +17 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +3 -8
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +128 -72
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +38 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +18 -3
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +1 -4
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rte.comp +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +20 -4
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +69 -5
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +84 -9
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/CMakeLists.txt +54 -0
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/ggml-webgpu.cpp +907 -0
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +60 -0
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +35 -0
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +40 -0
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl +56 -0
- package/cpp/llama.cpp/ggml/src/ggml.c +386 -67
- package/cpp/llama.cpp/ggml/src/gguf.cpp +8 -1
- package/cpp/llama.cpp/gguf-py/gguf/constants.py +307 -0
- package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +8 -2
- package/cpp/llama.cpp/gguf-py/gguf/metadata.py +4 -0
- package/cpp/llama.cpp/gguf-py/gguf/scripts/gguf_dump.py +24 -1
- package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +122 -47
- package/cpp/llama.cpp/gguf-py/gguf/vocab.py +12 -3
- package/cpp/llama.cpp/include/llama.h +15 -47
- package/cpp/llama.cpp/models/templates/llama-cpp-rwkv-world.jinja +34 -0
- package/cpp/llama.cpp/models/templates/moonshotai-Kimi-K2.jinja +43 -0
- package/cpp/llama.cpp/requirements/requirements-all.txt +1 -0
- package/cpp/llama.cpp/requirements/requirements-server-bench.txt +5 -0
- package/cpp/llama.cpp/src/llama-arch.cpp +316 -3
- package/cpp/llama.cpp/src/llama-arch.h +23 -1
- package/cpp/llama.cpp/src/llama-batch.cpp +103 -71
- package/cpp/llama.cpp/src/llama-batch.h +31 -18
- package/cpp/llama.cpp/src/llama-chat.cpp +58 -1
- package/cpp/llama.cpp/src/llama-chat.h +3 -0
- package/cpp/llama.cpp/src/llama-context.cpp +180 -106
- package/cpp/llama.cpp/src/llama-context.h +26 -16
- package/cpp/llama.cpp/src/llama-cparams.h +3 -2
- package/cpp/llama.cpp/src/llama-graph.cpp +310 -211
- package/cpp/llama.cpp/src/llama-graph.h +184 -122
- package/cpp/llama.cpp/src/llama-hparams.cpp +47 -1
- package/cpp/llama.cpp/src/llama-hparams.h +13 -2
- package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +38 -22
- package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +7 -2
- package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +849 -304
- package/cpp/llama.cpp/src/llama-kv-cache-unified.h +143 -47
- package/cpp/llama.cpp/src/llama-kv-cells.h +62 -10
- package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +10 -4
- package/cpp/llama.cpp/src/llama-memory-hybrid.h +3 -1
- package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +36 -11
- package/cpp/llama.cpp/src/llama-memory.cpp +17 -0
- package/cpp/llama.cpp/src/llama-memory.h +3 -0
- package/cpp/llama.cpp/src/llama-model.cpp +3545 -719
- package/cpp/llama.cpp/src/llama-model.h +21 -4
- package/cpp/llama.cpp/src/llama-quant.cpp +2 -2
- package/cpp/llama.cpp/src/llama-vocab.cpp +376 -10
- package/cpp/llama.cpp/src/llama-vocab.h +43 -0
- package/cpp/llama.cpp/src/unicode.cpp +207 -0
- package/cpp/llama.cpp/src/unicode.h +2 -0
- package/ios/include/chat.h +2 -0
- package/ios/include/common.h +22 -4
- package/ios/include/llama.h +15 -47
- package/ios/libs/llama.xcframework/Info.plist +13 -13
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5267 -4890
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +15 -47
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5238 -4861
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4014 -3764
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +15 -47
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5238 -4861
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4016 -3766
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +15 -47
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +15 -47
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +15 -47
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5267 -4890
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +15 -47
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5238 -4861
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4014 -3764
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +15 -47
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5303 -4926
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +15 -47
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5274 -4897
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4044 -3794
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +15 -47
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/package.json +4 -4
- package/cpp/llama.cpp/ggml/include/ggml-kompute.h +0 -50
- package/cpp/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
- package/cpp/llama.cpp/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
|
@@ -13,7 +13,9 @@
|
|
|
13
13
|
#include "types.comp"
|
|
14
14
|
#include "flash_attn_base.comp"
|
|
15
15
|
|
|
16
|
-
const uint32_t
|
|
16
|
+
const uint32_t HSK_per_thread = HSK / D_split;
|
|
17
|
+
const uint32_t HSV_per_thread = HSV / D_split;
|
|
18
|
+
|
|
17
19
|
const uint32_t row_split = 4;
|
|
18
20
|
const uint32_t rows_per_thread = Br / row_split;
|
|
19
21
|
const uint32_t cols_per_iter = gl_WorkGroupSize.x / D_split / row_split;
|
|
@@ -32,7 +34,7 @@ layout (binding = 3) readonly buffer M {float16_t data_m[];};
|
|
|
32
34
|
// Rows index by Q's dimension 2, and the first N rows are valid.
|
|
33
35
|
D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
|
|
34
36
|
{
|
|
35
|
-
uint32_t offset = (iq2 + r) *
|
|
37
|
+
uint32_t offset = (iq2 + r) * HSV + c;
|
|
36
38
|
data_o[o_offset + offset] = D_TYPE(elem);
|
|
37
39
|
return elem;
|
|
38
40
|
}
|
|
@@ -44,14 +46,14 @@ const uint32_t MatBc = 16;
|
|
|
44
46
|
shared FLOAT_TYPE tmpsh[gl_WorkGroupSize.x];
|
|
45
47
|
shared ACC_TYPEV4 tmpshv4[gl_WorkGroupSize.x];
|
|
46
48
|
|
|
47
|
-
const uint32_t qstride =
|
|
49
|
+
const uint32_t qstride = HSK / 4 + 2; // in units of f16vec4
|
|
48
50
|
shared f16vec4 Qf[Br * qstride];
|
|
49
51
|
|
|
50
|
-
// Avoid padding for
|
|
51
|
-
const uint32_t sfshstride = (
|
|
52
|
+
// Avoid padding for hsk==256 to make it fit in 48KB shmem.
|
|
53
|
+
const uint32_t sfshstride = (HSK <= 128) ? (Br + 8) : Br;
|
|
52
54
|
shared ACC_TYPE sfsh[Bc * sfshstride];
|
|
53
55
|
|
|
54
|
-
const uint32_t kshstride =
|
|
56
|
+
const uint32_t kshstride = HSK / 4 + 2; // in units of f16vec4
|
|
55
57
|
shared f16vec4 ksh[Bc * kshstride];
|
|
56
58
|
|
|
57
59
|
shared float slope[Br];
|
|
@@ -74,18 +76,18 @@ void main() {
|
|
|
74
76
|
|
|
75
77
|
uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4;
|
|
76
78
|
|
|
77
|
-
[[unroll]] for (uint32_t idx = 0; idx < Br *
|
|
78
|
-
uint32_t d = (idx + tid) % (
|
|
79
|
-
uint32_t r = (idx + tid) / (
|
|
80
|
-
if (r < Br && d <
|
|
79
|
+
[[unroll]] for (uint32_t idx = 0; idx < Br * HSK / 4; idx += gl_WorkGroupSize.x) {
|
|
80
|
+
uint32_t d = (idx + tid) % (HSK / 4);
|
|
81
|
+
uint32_t r = (idx + tid) / (HSK / 4);
|
|
82
|
+
if (r < Br && d < HSK / 4 &&
|
|
81
83
|
i * Br + r < N) {
|
|
82
84
|
Qf[r * qstride + d] = f16vec4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d] * p.scale);
|
|
83
85
|
}
|
|
84
86
|
}
|
|
85
87
|
barrier();
|
|
86
88
|
|
|
87
|
-
ACC_TYPEV4 Of[rows_per_thread][
|
|
88
|
-
[[unroll]] for (uint32_t d = 0; d <
|
|
89
|
+
ACC_TYPEV4 Of[rows_per_thread][HSV_per_thread / 4];
|
|
90
|
+
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
|
|
89
91
|
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
|
90
92
|
Of[r][d] = ACC_TYPEV4(0.0);
|
|
91
93
|
}
|
|
@@ -123,14 +125,18 @@ void main() {
|
|
|
123
125
|
uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2;
|
|
124
126
|
uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2;
|
|
125
127
|
#endif
|
|
128
|
+
uint32_t m_offset = 0;
|
|
129
|
+
if (p.nem2 != 1 || p.nem3 != 1) {
|
|
130
|
+
m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV;
|
|
131
|
+
}
|
|
126
132
|
|
|
127
133
|
[[dont_unroll]]
|
|
128
134
|
for (uint32_t j = start_j; j < end_j; ++j) {
|
|
129
135
|
|
|
130
|
-
[[unroll]] for (uint32_t idx = 0; idx < Bc *
|
|
131
|
-
uint32_t d = (idx + tid) % (
|
|
132
|
-
uint32_t c = (idx + tid) / (
|
|
133
|
-
if (c < Bc && d <
|
|
136
|
+
[[unroll]] for (uint32_t idx = 0; idx < Bc * HSK / 4; idx += gl_WorkGroupSize.x) {
|
|
137
|
+
uint32_t d = (idx + tid) % (HSK / 4);
|
|
138
|
+
uint32_t c = (idx + tid) / (HSK / 4);
|
|
139
|
+
if (c < Bc && d < HSK / 4) {
|
|
134
140
|
#if BLOCK_SIZE > 1
|
|
135
141
|
uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d;
|
|
136
142
|
uint ib = coord / BLOCK_SIZE;
|
|
@@ -145,14 +151,14 @@ void main() {
|
|
|
145
151
|
}
|
|
146
152
|
barrier();
|
|
147
153
|
|
|
148
|
-
// K * Q^T -> S^T: Bc x
|
|
149
|
-
// Bc split across workgroup (four subgroups), loop over
|
|
154
|
+
// K * Q^T -> S^T: Bc x HSK * HSK x Br -> Bc x Br
|
|
155
|
+
// Bc split across workgroup (four subgroups), loop over HSK in chunks of 16: 16 x 16 * 16 x 16 -> 16 x 16
|
|
150
156
|
// This is written transposed in order to allow for N being 8 if implementations need it
|
|
151
157
|
coopmat<ACC_TYPE, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator> SfMat = coopmat<ACC_TYPE, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator>(0);
|
|
152
158
|
coopmat<float16_t, gl_ScopeSubgroup, MatBc, 16, gl_MatrixUseA> KMat;
|
|
153
159
|
coopmat<float16_t, gl_ScopeSubgroup, 16, MatBr, gl_MatrixUseB> QMat;
|
|
154
160
|
|
|
155
|
-
for (uint32_t d = 0; d <
|
|
161
|
+
for (uint32_t d = 0; d < HSK / 16; ++d) {
|
|
156
162
|
coopMatLoad(QMat, Qf, d * 16 / 4, qstride, gl_CooperativeMatrixLayoutColumnMajor);
|
|
157
163
|
|
|
158
164
|
uint coord = (gl_SubgroupID * MatBc) * kshstride + d * 16 / 4;
|
|
@@ -176,12 +182,12 @@ void main() {
|
|
|
176
182
|
barrier();
|
|
177
183
|
}
|
|
178
184
|
|
|
179
|
-
if (p.
|
|
185
|
+
if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
|
|
180
186
|
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
|
|
181
187
|
uint32_t c = (idx + tid) % Bc;
|
|
182
188
|
uint32_t r = (idx + tid) / Bc;
|
|
183
189
|
if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) {
|
|
184
|
-
sfsh[c * sfshstride + r] += ACC_TYPE(slope[r] * float(data_m[(i * Br + r) * m_stride + (j * Bc + c)]));
|
|
190
|
+
sfsh[c * sfshstride + r] += ACC_TYPE(slope[r] * float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]));
|
|
185
191
|
}
|
|
186
192
|
}
|
|
187
193
|
barrier();
|
|
@@ -202,7 +208,7 @@ void main() {
|
|
|
202
208
|
eMf[r] = exp(Moldf - Mf[r]);
|
|
203
209
|
}
|
|
204
210
|
|
|
205
|
-
[[unroll]] for (uint32_t d = 0; d <
|
|
211
|
+
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
|
|
206
212
|
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
|
207
213
|
Of[r][d] = float16_t(eMf[r]) * Of[r][d];
|
|
208
214
|
}
|
|
@@ -217,7 +223,7 @@ void main() {
|
|
|
217
223
|
Pf[r] = exp(sfsh[tile_row(r) + (c * cols_per_iter + col_tid) * sfshstride] - Mf[r]);
|
|
218
224
|
Lf[r] += Pf[r];
|
|
219
225
|
}
|
|
220
|
-
[[unroll]] for (uint32_t d = 0; d <
|
|
226
|
+
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
|
|
221
227
|
#if BLOCK_SIZE > 1
|
|
222
228
|
uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
|
|
223
229
|
uint ib = coord / BLOCK_SIZE;
|
|
@@ -280,7 +286,7 @@ void main() {
|
|
|
280
286
|
}
|
|
281
287
|
|
|
282
288
|
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
|
283
|
-
[[unroll]] for (uint32_t d = 0; d <
|
|
289
|
+
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
|
|
284
290
|
|
|
285
291
|
Of[r][d] = float16_t(eMf[r]) * Of[r][d];
|
|
286
292
|
tmpshv4[tid] = Of[r][d];
|
|
@@ -300,11 +306,11 @@ void main() {
|
|
|
300
306
|
// If there is split_k, then the split_k resolve shader does the final
|
|
301
307
|
// division by L. Store the intermediate O value and per-row m and L values.
|
|
302
308
|
if (p.k_num > 1) {
|
|
303
|
-
uint32_t o_offset =
|
|
309
|
+
uint32_t o_offset = HSV * p.ne1 * (split_k_index + iq3 * p.k_num);
|
|
304
310
|
|
|
305
311
|
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
|
306
312
|
if (tile_row(r) < N) {
|
|
307
|
-
[[unroll]] for (uint32_t d = 0; d <
|
|
313
|
+
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
|
|
308
314
|
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
|
|
309
315
|
perElemOpGqaStore(tile_row(r), 4*(d * D_split + d_tid) + comp, float(Of[r][d][comp]), o_offset, iq2, N);
|
|
310
316
|
}
|
|
@@ -312,7 +318,7 @@ void main() {
|
|
|
312
318
|
}
|
|
313
319
|
}
|
|
314
320
|
|
|
315
|
-
o_offset =
|
|
321
|
+
o_offset = HSV * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2;
|
|
316
322
|
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
|
317
323
|
if (tile_row(r) < N) {
|
|
318
324
|
perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N);
|
|
@@ -328,18 +334,18 @@ void main() {
|
|
|
328
334
|
Lfrcp[r] = 1.0 / Lf[r];
|
|
329
335
|
}
|
|
330
336
|
|
|
331
|
-
[[unroll]] for (uint32_t d = 0; d <
|
|
337
|
+
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
|
|
332
338
|
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
|
333
339
|
Of[r][d] *= float16_t(Lfrcp[r]);
|
|
334
340
|
}
|
|
335
341
|
}
|
|
336
342
|
|
|
337
|
-
uint32_t o_offset = iq3*p.ne2*p.ne1;
|
|
343
|
+
uint32_t o_offset = iq3*p.ne2*p.ne1*HSV;
|
|
338
344
|
|
|
339
345
|
if (p.gqa_ratio > 1) {
|
|
340
346
|
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
|
341
347
|
if (tile_row(r) < N) {
|
|
342
|
-
[[unroll]] for (uint32_t d = 0; d <
|
|
348
|
+
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
|
|
343
349
|
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
|
|
344
350
|
perElemOpGqaStore(tile_row(r), 4*(d * D_split + d_tid) + comp, float(Of[r][d][comp]), o_offset, iq2, N);
|
|
345
351
|
}
|
|
@@ -349,9 +355,9 @@ void main() {
|
|
|
349
355
|
} else {
|
|
350
356
|
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
|
351
357
|
if (i * Br + tile_row(r) < N) {
|
|
352
|
-
[[unroll]] for (uint32_t d = 0; d <
|
|
358
|
+
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
|
|
353
359
|
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
|
|
354
|
-
data_o[o_offset + iq2 *
|
|
360
|
+
data_o[o_offset + iq2 * HSV + (i * Br + tile_row(r)) * p.ne1 * HSV + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]);
|
|
355
361
|
}
|
|
356
362
|
}
|
|
357
363
|
}
|
|
@@ -61,8 +61,8 @@ ACC_TYPE Max(const in uint32_t row, const in uint32_t col, const in ACC_TYPE ele
|
|
|
61
61
|
// Rows index by Q's dimension 2, and the first N rows are valid.
|
|
62
62
|
D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
|
|
63
63
|
{
|
|
64
|
-
if (r < N && c <
|
|
65
|
-
uint32_t offset = (iq2 + r) *
|
|
64
|
+
if (r < N && c < HSV) {
|
|
65
|
+
uint32_t offset = (iq2 + r) * HSV + c;
|
|
66
66
|
data_o[o_offset + offset] = D_TYPE(elem);
|
|
67
67
|
}
|
|
68
68
|
return elem;
|
|
@@ -86,9 +86,9 @@ void main() {
|
|
|
86
86
|
tensorLayoutV = setTensorLayoutBlockSizeNV(tensorLayoutV, 1, BLOCK_SIZE);
|
|
87
87
|
#endif
|
|
88
88
|
|
|
89
|
-
tensorLayoutQ = setTensorLayoutDimensionNV(tensorLayoutQ, N,
|
|
90
|
-
tensorLayoutK = setTensorLayoutDimensionNV(tensorLayoutK, KV,
|
|
91
|
-
tensorLayoutV = setTensorLayoutDimensionNV(tensorLayoutV, KV,
|
|
89
|
+
tensorLayoutQ = setTensorLayoutDimensionNV(tensorLayoutQ, N, HSK);
|
|
90
|
+
tensorLayoutK = setTensorLayoutDimensionNV(tensorLayoutK, KV, HSK);
|
|
91
|
+
tensorLayoutV = setTensorLayoutDimensionNV(tensorLayoutV, KV, HSV);
|
|
92
92
|
|
|
93
93
|
// hint to the compiler that strides are aligned for the aligned variant of the shader
|
|
94
94
|
if (Clamp != gl_CooperativeMatrixClampModeConstantNV)
|
|
@@ -104,16 +104,16 @@ void main() {
|
|
|
104
104
|
tensorLayoutK = setTensorLayoutStrideNV(tensorLayoutK, k_stride, 1);
|
|
105
105
|
tensorLayoutV = setTensorLayoutStrideNV(tensorLayoutV, v_stride, 1);
|
|
106
106
|
|
|
107
|
-
coopmat<Q_TYPE, gl_ScopeWorkgroup, Br,
|
|
108
|
-
coopmat<float16_t, gl_ScopeWorkgroup, Br,
|
|
107
|
+
coopmat<Q_TYPE, gl_ScopeWorkgroup, Br, HSK, gl_MatrixUseAccumulator> Q;
|
|
108
|
+
coopmat<float16_t, gl_ScopeWorkgroup, Br, HSK, gl_MatrixUseA> Qf16;
|
|
109
109
|
|
|
110
110
|
uint32_t q_offset = iq2*p.nb02+iq3*p.nb03;
|
|
111
|
-
coopMatLoadTensorNV(Q, data_q, q_offset, sliceTensorLayoutNV(tensorLayoutQ, i * Br, Br, 0,
|
|
111
|
+
coopMatLoadTensorNV(Q, data_q, q_offset, sliceTensorLayoutNV(tensorLayoutQ, i * Br, Br, 0, HSK));
|
|
112
112
|
|
|
113
|
-
Qf16 = coopmat<float16_t, gl_ScopeWorkgroup, Br,
|
|
113
|
+
Qf16 = coopmat<float16_t, gl_ScopeWorkgroup, Br, HSK, gl_MatrixUseA>(Q);
|
|
114
114
|
Qf16 *= float16_t(p.scale);
|
|
115
115
|
|
|
116
|
-
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br,
|
|
116
|
+
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> O = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator>(0);
|
|
117
117
|
|
|
118
118
|
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> L, M;
|
|
119
119
|
|
|
@@ -130,15 +130,20 @@ void main() {
|
|
|
130
130
|
coopMatPerElementNV(slopeMat, slopeMat, perElemOpComputeSlope, iq2);
|
|
131
131
|
}
|
|
132
132
|
|
|
133
|
+
uint32_t m_offset = 0;
|
|
134
|
+
if (p.nem2 != 1 || p.nem3 != 1) {
|
|
135
|
+
m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV * 2 /*sizeof(float16_t)*/;
|
|
136
|
+
}
|
|
137
|
+
|
|
133
138
|
[[dont_unroll]]
|
|
134
139
|
for (uint32_t j = start_j; j < end_j; ++j) {
|
|
135
140
|
|
|
136
141
|
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> S = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);
|
|
137
142
|
|
|
138
|
-
coopmat<float16_t, gl_ScopeWorkgroup,
|
|
143
|
+
coopmat<float16_t, gl_ScopeWorkgroup, HSK, Bc, gl_MatrixUseB> K_T;
|
|
139
144
|
|
|
140
145
|
uint32_t k_offset = ik2*p.nb12 + ik3*p.nb13;
|
|
141
|
-
coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0,
|
|
146
|
+
coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK), tensorViewTranspose DECODEFUNC);
|
|
142
147
|
S = coopMatMulAdd(Qf16, K_T, S);
|
|
143
148
|
|
|
144
149
|
if (p.logit_softcap != 0.0f) {
|
|
@@ -148,14 +153,14 @@ void main() {
|
|
|
148
153
|
}
|
|
149
154
|
}
|
|
150
155
|
|
|
151
|
-
if (p.
|
|
156
|
+
if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
|
|
152
157
|
tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp);
|
|
153
158
|
tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV);
|
|
154
159
|
tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
|
|
155
160
|
|
|
156
161
|
coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv;
|
|
157
162
|
|
|
158
|
-
coopMatLoadTensorNV(mv, data_m,
|
|
163
|
+
coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
|
|
159
164
|
|
|
160
165
|
S += slopeMat*coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(mv);
|
|
161
166
|
}
|
|
@@ -203,42 +208,42 @@ void main() {
|
|
|
203
208
|
rowsum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0.0);
|
|
204
209
|
rowsum = coopMatMulAdd(P_A, One, rowsum);
|
|
205
210
|
|
|
206
|
-
coopmat<float16_t, gl_ScopeWorkgroup, Bc,
|
|
211
|
+
coopmat<float16_t, gl_ScopeWorkgroup, Bc, HSV, gl_MatrixUseB> V;
|
|
207
212
|
uint32_t v_offset = iv2*p.nb22 + iv3*p.nb23;
|
|
208
|
-
coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0,
|
|
213
|
+
coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, HSV) DECODEFUNC);
|
|
209
214
|
|
|
210
215
|
L = eM*L + rowsum;
|
|
211
216
|
|
|
212
217
|
// This is the "diagonal" matrix in the paper, but since we do componentwise
|
|
213
218
|
// multiply rather than matrix multiply it has the diagonal element smeared
|
|
214
219
|
// across the row
|
|
215
|
-
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br,
|
|
220
|
+
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> eMdiag;
|
|
216
221
|
|
|
217
222
|
// resize eM by using smear/reduce
|
|
218
223
|
coopMatReduceNV(eMdiag, eM, gl_CooperativeMatrixReduceRowNV, smearReduce);
|
|
219
224
|
|
|
220
225
|
// multiply with fp16 accumulation, then add to O.
|
|
221
|
-
coopmat<float16_t, gl_ScopeWorkgroup, Br,
|
|
226
|
+
coopmat<float16_t, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> PV = coopmat<float16_t, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator>(0);
|
|
222
227
|
PV = coopMatMulAdd(P_A, V, PV);
|
|
223
228
|
|
|
224
|
-
O = eMdiag * O + coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br,
|
|
229
|
+
O = eMdiag * O + coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator>(PV);
|
|
225
230
|
}
|
|
226
231
|
|
|
227
232
|
// If there is split_k, then the split_k resolve shader does the final
|
|
228
233
|
// division by L. Store the intermediate O value and per-row m and L values.
|
|
229
234
|
if (p.k_num > 1) {
|
|
230
|
-
coopmat<D_TYPE, gl_ScopeWorkgroup, Br,
|
|
235
|
+
coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator>(O);
|
|
231
236
|
|
|
232
|
-
uint32_t o_offset =
|
|
237
|
+
uint32_t o_offset = HSV * p.ne1 * (split_k_index + iq3 * p.k_num);
|
|
233
238
|
coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N);
|
|
234
239
|
|
|
235
|
-
o_offset =
|
|
240
|
+
o_offset = HSV * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2;
|
|
236
241
|
coopMatPerElementNV(L, L, perElemOpStoreCol0, o_offset, iq2, N);
|
|
237
242
|
coopMatPerElementNV(M, M, perElemOpStoreCol0, o_offset + p.ne1, iq2, N);
|
|
238
243
|
return;
|
|
239
244
|
}
|
|
240
245
|
|
|
241
|
-
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br,
|
|
246
|
+
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> Ldiag;
|
|
242
247
|
|
|
243
248
|
// resize L by using smear/reduce
|
|
244
249
|
coopMatReduceNV(Ldiag, L, gl_CooperativeMatrixReduceRowNV, smearReduce);
|
|
@@ -250,18 +255,18 @@ void main() {
|
|
|
250
255
|
|
|
251
256
|
O = Ldiag*O;
|
|
252
257
|
|
|
253
|
-
uint32_t o_offset = iq3*p.ne2*p.ne1;
|
|
258
|
+
uint32_t o_offset = iq3*p.ne2*p.ne1*HSV;
|
|
254
259
|
|
|
255
|
-
coopmat<D_TYPE, gl_ScopeWorkgroup, Br,
|
|
260
|
+
coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator>(O);
|
|
256
261
|
if (p.gqa_ratio > 1) {
|
|
257
262
|
coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N);
|
|
258
263
|
} else {
|
|
259
264
|
tensorLayoutNV<3, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutD = createTensorLayoutNV(3, gl_CooperativeMatrixClampModeConstantNV);
|
|
260
|
-
tensorLayoutD = setTensorLayoutDimensionNV(tensorLayoutD, p.ne2, p.ne1,
|
|
265
|
+
tensorLayoutD = setTensorLayoutDimensionNV(tensorLayoutD, p.ne2, p.ne1, HSV);
|
|
261
266
|
|
|
262
267
|
// permute dimensions
|
|
263
268
|
tensorViewNV<3, false, 1, 0, 2> tensorViewPermute = createTensorViewNV(3, false, 1, 0, 2);
|
|
264
269
|
|
|
265
|
-
coopMatStoreTensorNV(O_D, data_o, o_offset, sliceTensorLayoutNV(tensorLayoutD, i * Br, Br, iq2, N, 0,
|
|
270
|
+
coopMatStoreTensorNV(O_D, data_o, o_offset, sliceTensorLayoutNV(tensorLayoutD, i * Br, Br, iq2, N, 0, HSV), tensorViewPermute);
|
|
266
271
|
}
|
|
267
272
|
}
|
|
@@ -2,9 +2,9 @@
|
|
|
2
2
|
|
|
3
3
|
#extension GL_EXT_control_flow_attributes : enable
|
|
4
4
|
|
|
5
|
-
|
|
5
|
+
layout(constant_id = 0) const uint BLOCK_SIZE = 32;
|
|
6
6
|
|
|
7
|
-
layout(
|
|
7
|
+
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
|
8
8
|
|
|
9
9
|
layout (binding = 0) readonly buffer A {float data_a[];};
|
|
10
10
|
layout (binding = 1) writeonly buffer D {float data_d[];};
|
|
@@ -12,48 +12,80 @@ layout (binding = 1) writeonly buffer D {float data_d[];};
|
|
|
12
12
|
layout (push_constant) uniform parameter {
|
|
13
13
|
uint D;
|
|
14
14
|
uint N;
|
|
15
|
+
uint ne3;
|
|
15
16
|
uint k_num;
|
|
16
17
|
} p;
|
|
17
18
|
|
|
19
|
+
shared float tmpsh[BLOCK_SIZE];
|
|
20
|
+
|
|
18
21
|
void main() {
|
|
19
22
|
// Each workgroup handles a row
|
|
20
23
|
const uint n = gl_WorkGroupID.x;
|
|
21
24
|
const uint tid = gl_LocalInvocationID.x;
|
|
25
|
+
const uint iq3 = gl_WorkGroupID.z;
|
|
22
26
|
|
|
23
27
|
uint D = p.D;
|
|
24
28
|
uint N = p.N;
|
|
25
29
|
uint k_num = p.k_num;
|
|
26
30
|
|
|
27
|
-
uint l_offset = D * N * k_num + n;
|
|
28
|
-
uint m_offset = D * N * k_num + N + n;
|
|
31
|
+
uint l_offset = D * N * p.ne3 * k_num + N * iq3 * k_num * 2 + n;
|
|
32
|
+
uint m_offset = D * N * p.ne3 * k_num + N * iq3 * k_num * 2 + N + n;
|
|
29
33
|
uint lm_stride = N * 2;
|
|
30
34
|
|
|
31
35
|
// Compute the max m value for the row
|
|
32
36
|
float m_max = -1.0/0.0;
|
|
33
|
-
|
|
34
|
-
float m = data_a[m_offset + k * lm_stride];
|
|
37
|
+
for (uint k = 0; k + tid < k_num; k += BLOCK_SIZE) {
|
|
38
|
+
float m = data_a[m_offset + (k + tid) * lm_stride];
|
|
35
39
|
m_max = max(m_max, m);
|
|
36
40
|
}
|
|
37
41
|
|
|
42
|
+
// reduce across the workgroup
|
|
43
|
+
tmpsh[tid] = m_max;
|
|
44
|
+
barrier();
|
|
45
|
+
[[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) {
|
|
46
|
+
if (tid < s) {
|
|
47
|
+
m_max = max(m_max, tmpsh[tid + s]);
|
|
48
|
+
tmpsh[tid] = m_max;
|
|
49
|
+
}
|
|
50
|
+
barrier();
|
|
51
|
+
}
|
|
52
|
+
m_max = tmpsh[0];
|
|
53
|
+
|
|
54
|
+
barrier();
|
|
55
|
+
|
|
38
56
|
// Compute L based on m_max
|
|
39
57
|
float L = 0;
|
|
40
|
-
|
|
41
|
-
float l = data_a[l_offset + k * lm_stride];
|
|
42
|
-
float m = data_a[m_offset + k * lm_stride];
|
|
58
|
+
for (uint k = 0; k + tid < k_num; k += BLOCK_SIZE) {
|
|
59
|
+
float l = data_a[l_offset + (k + tid) * lm_stride];
|
|
60
|
+
float m = data_a[m_offset + (k + tid) * lm_stride];
|
|
43
61
|
L += exp(m - m_max) * l;
|
|
44
62
|
}
|
|
45
63
|
|
|
64
|
+
// reduce across the workgroup
|
|
65
|
+
tmpsh[tid] = L;
|
|
66
|
+
barrier();
|
|
67
|
+
[[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) {
|
|
68
|
+
if (tid < s) {
|
|
69
|
+
L += tmpsh[tid + s];
|
|
70
|
+
tmpsh[tid] = L;
|
|
71
|
+
}
|
|
72
|
+
barrier();
|
|
73
|
+
}
|
|
74
|
+
L = tmpsh[0];
|
|
75
|
+
|
|
46
76
|
L = 1.0 / L;
|
|
47
77
|
|
|
78
|
+
// D dimension is split across workgroups in the y dimension
|
|
79
|
+
uint d = tid + gl_WorkGroupID.y * BLOCK_SIZE;
|
|
48
80
|
// Scale and sum the O contributions based on m_max and store the result to memory
|
|
49
|
-
|
|
81
|
+
if (d < D) {
|
|
50
82
|
float O = 0.0;
|
|
51
83
|
[[unroll]] for (uint k = 0; k < k_num; ++k) {
|
|
52
|
-
uint o_offset = D * N * k + D * n + d;
|
|
84
|
+
uint o_offset = D * N * (k + iq3 * k_num) + D * n + d;
|
|
53
85
|
float m = data_a[m_offset + k * lm_stride];
|
|
54
86
|
O += exp(m - m_max) * data_a[o_offset];
|
|
55
87
|
}
|
|
56
88
|
O *= L;
|
|
57
|
-
data_d[D * n + d] = O;
|
|
89
|
+
data_d[iq3 * D * N + D * n + d] = O;
|
|
58
90
|
}
|
|
59
91
|
}
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
#version 450
|
|
2
|
+
|
|
3
|
+
#include "glu_head.comp"
|
|
4
|
+
|
|
5
|
+
const float GELU_COEF_A = 0.044715f;
|
|
6
|
+
const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
|
|
7
|
+
|
|
8
|
+
float op(float a, float b) {
|
|
9
|
+
const float val = SQRT_2_OVER_PI*a*(1.0f + GELU_COEF_A*a*a);
|
|
10
|
+
return 0.5f*a*(2.0f - 2.0f / (exp(2 * val) + 1)) * b;
|
|
11
|
+
}
|
|
12
|
+
|
|
13
|
+
#include "glu_main.comp"
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
#version 450
|
|
2
|
+
|
|
3
|
+
#include "glu_head.comp"
|
|
4
|
+
|
|
5
|
+
// based on Abramowitz and Stegun formula 7.1.26 or similar Hastings' approximation
|
|
6
|
+
// ref: https://www.johndcook.com/blog/python_erf/
|
|
7
|
+
const float p_erf = 0.3275911f;
|
|
8
|
+
const float a1_erf = 0.254829592f;
|
|
9
|
+
const float a2_erf = -0.284496736f;
|
|
10
|
+
const float a3_erf = 1.421413741f;
|
|
11
|
+
const float a4_erf = -1.453152027f;
|
|
12
|
+
const float a5_erf = 1.061405429f;
|
|
13
|
+
|
|
14
|
+
const float SQRT_2_INV = 0.70710678118654752440084436210484f;
|
|
15
|
+
|
|
16
|
+
float op(float a, float b) {
|
|
17
|
+
const float a_div_sqr2 = a * SQRT_2_INV;
|
|
18
|
+
const float sign_x = sign(a_div_sqr2);
|
|
19
|
+
const float x = abs(a_div_sqr2);
|
|
20
|
+
const float t = 1.0f / (1.0f + p_erf * x);
|
|
21
|
+
const float y = 1.0f - (((((a5_erf * t + a4_erf) * t) + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x);
|
|
22
|
+
const float erf_approx = sign_x * y;
|
|
23
|
+
|
|
24
|
+
return 0.5f * a * (1.0f + erf_approx) * b;
|
|
25
|
+
}
|
|
26
|
+
|
|
27
|
+
#include "glu_main.comp"
|
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
#version 450
|
|
2
|
+
|
|
3
|
+
#include "generic_head.comp"
|
|
4
|
+
#include "types.comp"
|
|
5
|
+
|
|
6
|
+
#extension GL_EXT_control_flow_attributes : enable
|
|
7
|
+
|
|
8
|
+
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
|
|
9
|
+
|
|
10
|
+
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
|
|
11
|
+
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
|
|
12
|
+
|
|
13
|
+
void main() {
|
|
14
|
+
// based on Abramowitz and Stegun formula 7.1.26 or similar Hastings' approximation
|
|
15
|
+
// ref: https://www.johndcook.com/blog/python_erf/
|
|
16
|
+
const float p_erf = 0.3275911f;
|
|
17
|
+
const float a1_erf = 0.254829592f;
|
|
18
|
+
const float a2_erf = -0.284496736f;
|
|
19
|
+
const float a3_erf = 1.421413741f;
|
|
20
|
+
const float a4_erf = -1.453152027f;
|
|
21
|
+
const float a5_erf = 1.061405429f;
|
|
22
|
+
|
|
23
|
+
const float SQRT_2_INV = 0.70710678118654752440084436210484f;
|
|
24
|
+
const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
|
|
25
|
+
|
|
26
|
+
if (i >= p.KX) {
|
|
27
|
+
return;
|
|
28
|
+
}
|
|
29
|
+
|
|
30
|
+
const float a = float(data_a[i]);
|
|
31
|
+
const float a_div_sqr2 = a * SQRT_2_INV;
|
|
32
|
+
const float sign_x = sign(a_div_sqr2);
|
|
33
|
+
const float x = abs(a_div_sqr2);
|
|
34
|
+
const float t = 1.0f / (1.0f + p_erf * x);
|
|
35
|
+
const float y = 1.0f - (((((a5_erf * t + a4_erf) * t) + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x);
|
|
36
|
+
const float erf_approx = sign_x * y;
|
|
37
|
+
|
|
38
|
+
data_d[i] = D_TYPE(0.5f * a * (1.0f + erf_approx));
|
|
39
|
+
}
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
#extension GL_EXT_shader_16bit_storage : require
|
|
2
|
+
|
|
3
|
+
#include "rte.comp"
|
|
4
|
+
|
|
5
|
+
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
|
|
6
|
+
|
|
7
|
+
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
|
|
8
|
+
layout (binding = 1) readonly buffer B {A_TYPE data_b[];};
|
|
9
|
+
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
|
|
10
|
+
|
|
11
|
+
layout (push_constant) uniform parameter
|
|
12
|
+
{
|
|
13
|
+
uint N;
|
|
14
|
+
uint ne00;
|
|
15
|
+
uint ne20;
|
|
16
|
+
uint mode;
|
|
17
|
+
} p;
|
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
void main() {
|
|
2
|
+
const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
|
|
3
|
+
|
|
4
|
+
if (i >= p.N) {
|
|
5
|
+
return;
|
|
6
|
+
}
|
|
7
|
+
|
|
8
|
+
const uint row = i / p.ne20;
|
|
9
|
+
const uint col = i - row * p.ne20;
|
|
10
|
+
|
|
11
|
+
if (p.mode == 0) {
|
|
12
|
+
// Default
|
|
13
|
+
const uint offset = p.ne00 / 2;
|
|
14
|
+
const uint idx = row * p.ne00 + col;
|
|
15
|
+
|
|
16
|
+
data_d[row * offset + col] = D_TYPE(op(float(data_a[idx]), float(data_a[idx + offset])));
|
|
17
|
+
} else if (p.mode == 1) {
|
|
18
|
+
// Swapped
|
|
19
|
+
const uint offset = p.ne00 / 2;
|
|
20
|
+
const uint idx = row * p.ne00 + col;
|
|
21
|
+
|
|
22
|
+
data_d[row * offset + col] = D_TYPE(op(float(data_a[idx + offset]), float(data_a[idx])));
|
|
23
|
+
} else {
|
|
24
|
+
// Split
|
|
25
|
+
const uint idx = row * p.ne00 + col;
|
|
26
|
+
|
|
27
|
+
data_d[idx] = D_TYPE(op(float(data_a[idx]), float(data_b[idx])));
|
|
28
|
+
}
|
|
29
|
+
}
|
|
@@ -1,12 +1,9 @@
|
|
|
1
1
|
#version 450
|
|
2
2
|
|
|
3
3
|
#extension GL_EXT_shader_16bit_storage : require
|
|
4
|
-
#extension GL_EXT_spirv_intrinsics: enable
|
|
5
4
|
#extension GL_EXT_control_flow_attributes : require
|
|
6
5
|
|
|
7
|
-
#
|
|
8
|
-
spirv_execution_mode(capabilities = [4467], 4462, 16); // RoundingModeRTE, 16 bits
|
|
9
|
-
#endif
|
|
6
|
+
#include "rte.comp"
|
|
10
7
|
|
|
11
8
|
layout (push_constant) uniform parameter
|
|
12
9
|
{
|
|
@@ -43,12 +40,10 @@ void main() {
|
|
|
43
40
|
const uint src_base = ic * p.offset_delta + batch * p.batch_offset;
|
|
44
41
|
const uint dst_base = ((batch * p.OH + oh) * p.OW) * p.CHW + ic * (p.KW * p.KH);
|
|
45
42
|
const int oh_s1 = int(oh) * p.s1;
|
|
46
|
-
const uint ksize = p.OW *
|
|
43
|
+
const uint ksize = p.OW * p.KH;
|
|
47
44
|
|
|
48
45
|
const uint base_linear_idx = gidx * NUM_ITER;
|
|
49
46
|
|
|
50
|
-
const uint max_ky = ksize / p.OW;
|
|
51
|
-
|
|
52
47
|
uint current_kx = base_linear_idx / ksize;
|
|
53
48
|
const uint rem = base_linear_idx - (current_kx * ksize);
|
|
54
49
|
uint current_ky = rem / p.OW;
|
|
@@ -79,7 +74,7 @@ void main() {
|
|
|
79
74
|
|
|
80
75
|
if (++current_ix == p.OW) {
|
|
81
76
|
current_ix = 0;
|
|
82
|
-
if (++current_ky ==
|
|
77
|
+
if (++current_ky == p.KH) {
|
|
83
78
|
current_ky = 0;
|
|
84
79
|
current_kx++;
|
|
85
80
|
}
|