@novastera-oss/llamarn 0.2.7 → 0.3.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/android/src/main/cpp/include/llama.h +8 -3
- package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml-base.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libllama.so +0 -0
- package/android/src/main/jniLibs/x86/libggml-base.so +0 -0
- package/android/src/main/jniLibs/x86/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/x86/libggml.so +0 -0
- package/android/src/main/jniLibs/x86/libllama.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
- package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
- package/cpp/LlamaCppModel.cpp +56 -22
- package/cpp/build-info.cpp +2 -2
- package/cpp/llama.cpp/CMakeLists.txt +1 -2
- package/cpp/llama.cpp/README.md +4 -5
- package/cpp/llama.cpp/build-xcframework.sh +1 -1
- package/cpp/llama.cpp/common/CMakeLists.txt +4 -5
- package/cpp/llama.cpp/common/arg.cpp +24 -0
- package/cpp/llama.cpp/common/chat.cpp +37 -20
- package/cpp/llama.cpp/common/chat.h +2 -0
- package/cpp/llama.cpp/common/common.cpp +3 -0
- package/cpp/llama.cpp/common/common.h +5 -0
- package/cpp/llama.cpp/common/json-schema-to-grammar.cpp +3 -46
- package/cpp/llama.cpp/convert_hf_to_gguf.py +860 -23
- package/cpp/llama.cpp/convert_hf_to_gguf_update.py +9 -0
- package/cpp/llama.cpp/ggml/CMakeLists.txt +8 -2
- package/cpp/llama.cpp/ggml/include/ggml-backend.h +1 -1
- package/cpp/llama.cpp/ggml/include/ggml-cpu.h +2 -0
- package/cpp/llama.cpp/ggml/include/ggml.h +206 -10
- package/cpp/llama.cpp/ggml/src/CMakeLists.txt +17 -1
- package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +0 -8
- package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +36 -18
- package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +68 -5
- package/cpp/llama.cpp/ggml/src/ggml-cann/common.h +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +16 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +37 -3
- package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/mmq.cpp +10 -9
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +109 -108
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +1027 -1038
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +53 -52
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +56 -55
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +42 -41
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/repack.cpp +24 -23
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +29 -28
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/wasm/quants.c +30 -29
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/quants.c +83 -82
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/repack.cpp +20 -19
- package/cpp/llama.cpp/ggml/src/ggml-cpu/common.h +3 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +9 -3
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +111 -103
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +4 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +3 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +1405 -240
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.h +8 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.c +25 -24
- package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +56 -40
- package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +212 -34
- package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +35 -11
- package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +187 -54
- package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +71 -29
- package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-dw.cu +161 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-dw.cuh +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-transpose.cu +91 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-transpose.cuh +4 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cu +22 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cuh +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh +4 -1
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +8 -4
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f16.cu +6 -4
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f32.cu +14 -12
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f16.cuh +5 -3
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f32.cuh +15 -10
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +12 -6
- package/cpp/llama.cpp/ggml/src/ggml-cuda/getrows.cu +8 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +269 -110
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cu +19 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cuh +3 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cuh +2 -8
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmv.cu +257 -87
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmv.cuh +2 -3
- package/cpp/llama.cpp/ggml/src/ggml-cuda/rope.cu +21 -27
- package/cpp/llama.cpp/ggml/src/ggml-cuda/scale.cu +8 -6
- package/cpp/llama.cpp/ggml/src/ggml-cuda/softmax.cu +119 -58
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-conv.cu +10 -2
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +192 -52
- package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cu +5 -18
- package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cuh +0 -1
- package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cu +97 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cuh +11 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/upscale.cu +92 -6
- package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +14 -5
- package/cpp/llama.cpp/ggml/src/ggml-impl.h +125 -183
- package/cpp/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +4 -2
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +51 -9
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +394 -80
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +616 -239
- package/cpp/llama.cpp/ggml/src/ggml-musa/mudnn.cuh +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +3 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +741 -571
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/glu.cl +337 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/set_rows.cl +95 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +24 -11
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +24 -11
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f16.cl +24 -11
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f32.cl +24 -11
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/upscale.cl +2 -3
- package/cpp/llama.cpp/ggml/src/ggml-quants.c +6 -6
- package/cpp/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +5 -6
- package/cpp/llama.cpp/ggml/src/ggml-sycl/common.hpp +1 -24
- package/cpp/llama.cpp/ggml/src/ggml-sycl/concat.cpp +28 -41
- package/cpp/llama.cpp/ggml/src/ggml-sycl/conv.cpp +4 -10
- package/cpp/llama.cpp/ggml/src/ggml-sycl/convert.cpp +99 -166
- package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +94 -72
- package/cpp/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +49 -67
- package/cpp/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +31 -1
- package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +697 -1098
- package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +18 -9
- package/cpp/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +6 -9
- package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +104 -62
- package/cpp/llama.cpp/ggml/src/ggml-sycl/gla.cpp +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +60 -80
- package/cpp/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +132 -201
- package/cpp/llama.cpp/ggml/src/ggml-sycl/norm.cpp +55 -74
- package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +39 -38
- package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.cpp +131 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.hpp +8 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +3 -3
- package/cpp/llama.cpp/ggml/src/ggml-sycl/sycl_hw.cpp +3 -1
- package/cpp/llama.cpp/ggml/src/ggml-sycl/sycl_hw.hpp +3 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +3 -8
- package/cpp/llama.cpp/ggml/src/ggml-sycl/wkv.cpp +12 -16
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +12 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +767 -292
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +58 -7
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +28 -23
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +14 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +38 -32
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +32 -27
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +44 -12
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +15 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +128 -72
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +38 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +12 -3
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +20 -4
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +69 -5
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +23 -3
- package/cpp/llama.cpp/ggml/src/ggml.c +449 -72
- package/cpp/llama.cpp/ggml/src/gguf.cpp +13 -2
- package/cpp/llama.cpp/gguf-py/gguf/constants.py +285 -0
- package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +27 -0
- package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +137 -21
- package/cpp/llama.cpp/gguf-py/gguf/vocab.py +109 -7
- package/cpp/llama.cpp/gguf-py/pyproject.toml +2 -2
- package/cpp/llama.cpp/include/llama.h +8 -43
- package/cpp/llama.cpp/models/templates/Mistral-Small-3.2-24B-Instruct-2506.jinja +124 -0
- package/cpp/llama.cpp/src/llama-arch.cpp +265 -3
- package/cpp/llama.cpp/src/llama-arch.h +36 -1
- package/cpp/llama.cpp/src/llama-batch.cpp +596 -359
- package/cpp/llama.cpp/src/llama-batch.h +105 -70
- package/cpp/llama.cpp/src/llama-chat.cpp +26 -6
- package/cpp/llama.cpp/src/llama-chat.h +1 -0
- package/cpp/llama.cpp/src/llama-context.cpp +101 -107
- package/cpp/llama.cpp/src/llama-context.h +13 -13
- package/cpp/llama.cpp/src/llama-graph.cpp +286 -404
- package/cpp/llama.cpp/src/llama-graph.h +78 -79
- package/cpp/llama.cpp/src/llama-hparams.cpp +11 -1
- package/cpp/llama.cpp/src/llama-hparams.h +11 -0
- package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +74 -66
- package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +23 -26
- package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +312 -157
- package/cpp/llama.cpp/src/llama-kv-cache-unified.h +79 -46
- package/cpp/llama.cpp/src/llama-kv-cells.h +97 -21
- package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +73 -69
- package/cpp/llama.cpp/src/llama-memory-hybrid.h +19 -22
- package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +88 -77
- package/cpp/llama.cpp/src/llama-memory-recurrent.h +15 -20
- package/cpp/llama.cpp/src/llama-memory.cpp +17 -0
- package/cpp/llama.cpp/src/llama-memory.h +21 -22
- package/cpp/llama.cpp/src/llama-model-saver.cpp +1 -0
- package/cpp/llama.cpp/src/llama-model.cpp +5301 -2922
- package/cpp/llama.cpp/src/llama-model.h +40 -0
- package/cpp/llama.cpp/src/llama-quant.cpp +88 -5
- package/cpp/llama.cpp/src/llama-vocab.cpp +37 -3
- package/cpp/llama.cpp/src/llama-vocab.h +42 -0
- package/cpp/rn-utils.h +3 -0
- package/ios/include/chat.h +2 -0
- package/ios/include/common.h +5 -0
- package/ios/include/llama.h +8 -43
- package/ios/libs/llama.xcframework/Info.plist +19 -19
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5059 -4863
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +206 -10
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +8 -43
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5030 -4834
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3889 -3742
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +206 -10
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -43
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5030 -4834
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3891 -3744
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +206 -10
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +8 -43
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +206 -10
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +8 -43
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +206 -10
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +8 -43
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5059 -4863
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +206 -10
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +8 -43
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5030 -4834
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3889 -3742
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +206 -10
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -43
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5095 -4900
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +206 -10
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +8 -43
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5066 -4871
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3919 -3773
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +206 -10
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -43
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/package.json +1 -1
- package/cpp/llama.cpp/ggml/include/ggml-kompute.h +0 -50
- package/cpp/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
- package/cpp/llama.cpp/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
|
@@ -18,6 +18,7 @@
|
|
|
18
18
|
#extension GL_KHR_cooperative_matrix : enable
|
|
19
19
|
#extension GL_KHR_memory_scope_semantics : enable
|
|
20
20
|
#extension GL_KHR_shader_subgroup_basic : enable
|
|
21
|
+
#extension GL_KHR_shader_subgroup_ballot : enable
|
|
21
22
|
#endif
|
|
22
23
|
|
|
23
24
|
#ifdef MUL_MAT_ID
|
|
@@ -104,6 +105,10 @@ shared FLOAT_TYPE buf_b[BN * SHMEM_STRIDE];
|
|
|
104
105
|
|
|
105
106
|
#ifdef MUL_MAT_ID
|
|
106
107
|
shared u16vec2 row_ids[4096];
|
|
108
|
+
uint _ne1;
|
|
109
|
+
#ifdef COOPMAT
|
|
110
|
+
shared uint _ne1_sh;
|
|
111
|
+
#endif
|
|
107
112
|
#endif // MUL_MAT_ID
|
|
108
113
|
|
|
109
114
|
#define NUM_WARPS (BLOCK_SIZE / WARP)
|
|
@@ -172,7 +177,47 @@ void main() {
|
|
|
172
177
|
const uint loadstride_b = gl_WorkGroupSize.x * LOAD_VEC_B / BK;
|
|
173
178
|
|
|
174
179
|
#ifdef MUL_MAT_ID
|
|
175
|
-
|
|
180
|
+
#ifdef COOPMAT
|
|
181
|
+
// Spread the search across all elements in the first subgroup
|
|
182
|
+
if (gl_SubgroupID == 0) {
|
|
183
|
+
_ne1 = 0;
|
|
184
|
+
uint num_elements = p.nei1 * p.nei0;
|
|
185
|
+
|
|
186
|
+
uint ids[16];
|
|
187
|
+
uint iter = 0;
|
|
188
|
+
|
|
189
|
+
for (uint j = 0; j < num_elements; j += gl_SubgroupSize) {
|
|
190
|
+
// prefetch up to 16 elements
|
|
191
|
+
if (iter == 0) {
|
|
192
|
+
[[unroll]] for (uint k = 0; k < 16; ++k) {
|
|
193
|
+
uint i = j + gl_SubgroupInvocationID + k*gl_SubgroupSize;
|
|
194
|
+
bool in_range = i < num_elements;
|
|
195
|
+
uint ii1 = i / p.nei0;
|
|
196
|
+
uint ii0 = i % p.nei0;
|
|
197
|
+
ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0;
|
|
198
|
+
}
|
|
199
|
+
}
|
|
200
|
+
uint i = j + gl_SubgroupInvocationID;
|
|
201
|
+
bool in_range = i < num_elements;
|
|
202
|
+
uint ii1 = i / p.nei0;
|
|
203
|
+
uint ii0 = i % p.nei0;
|
|
204
|
+
uint id = ids[iter++];
|
|
205
|
+
uvec4 ballot = subgroupBallot(in_range && id == expert_idx);
|
|
206
|
+
uint idx = subgroupBallotExclusiveBitCount(ballot);
|
|
207
|
+
if (in_range && id == expert_idx) {
|
|
208
|
+
row_ids[_ne1 + idx] = u16vec2(ii0, ii1);
|
|
209
|
+
}
|
|
210
|
+
_ne1 += subgroupBallotBitCount(ballot);
|
|
211
|
+
iter &= 15;
|
|
212
|
+
}
|
|
213
|
+
_ne1_sh = _ne1;
|
|
214
|
+
}
|
|
215
|
+
|
|
216
|
+
barrier();
|
|
217
|
+
|
|
218
|
+
_ne1 = _ne1_sh;
|
|
219
|
+
#else
|
|
220
|
+
_ne1 = 0;
|
|
176
221
|
for (uint ii1 = 0; ii1 < p.nei1; ii1++) {
|
|
177
222
|
for (uint ii0 = 0; ii0 < p.nei0; ii0++) {
|
|
178
223
|
if (data_ids[ii1*p.nbi1 + ii0] == expert_idx) {
|
|
@@ -183,6 +228,7 @@ void main() {
|
|
|
183
228
|
}
|
|
184
229
|
|
|
185
230
|
barrier();
|
|
231
|
+
#endif
|
|
186
232
|
|
|
187
233
|
// Workgroup has no work
|
|
188
234
|
if (ic * BN >= _ne1) return;
|
|
@@ -500,10 +546,9 @@ void main() {
|
|
|
500
546
|
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
|
|
501
547
|
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
|
|
502
548
|
|
|
503
|
-
const uint ib = idx /
|
|
504
|
-
const uint ib32 = (idx %
|
|
505
|
-
const uint ib8 =
|
|
506
|
-
const int i8 = 2 * int(idx % 4);
|
|
549
|
+
const uint ib = idx / 32; // 8 values per idx
|
|
550
|
+
const uint ib32 = (idx % 32) / 4; // 0..7
|
|
551
|
+
const uint ib8 = idx % 32;
|
|
507
552
|
|
|
508
553
|
const float d = float(data_a[ib].d);
|
|
509
554
|
const uint qh = data_a[ib].qh[ib32];
|
|
@@ -512,22 +557,16 @@ void main() {
|
|
|
512
557
|
const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA;
|
|
513
558
|
const int16_t grid = int16_t(iq1s_grid[qs | (bitfieldExtract(qh, 3 * int(ib8 & 3), 3) << 8)]);
|
|
514
559
|
|
|
515
|
-
|
|
516
|
-
|
|
517
|
-
|
|
518
|
-
);
|
|
519
|
-
const vec2 v = dl * (vec2(gvec) + delta);
|
|
520
|
-
|
|
521
|
-
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
|
|
522
|
-
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
|
|
560
|
+
[[unroll]] for (int k = 0; k < 8; ++k) {
|
|
561
|
+
buf_a[buf_idx + k] = FLOAT_TYPE(dl * (bitfieldExtract(grid, 2 * k, 2) + delta));
|
|
562
|
+
}
|
|
523
563
|
#elif defined(DATA_A_IQ1_M)
|
|
524
564
|
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
|
|
525
565
|
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
|
|
526
566
|
|
|
527
|
-
const uint ib = idx /
|
|
528
|
-
const uint ib8 =
|
|
567
|
+
const uint ib = idx / 32; // 8 values per idx
|
|
568
|
+
const uint ib8 = idx % 32;
|
|
529
569
|
const uint ib16 = ib8 / 2;
|
|
530
|
-
const int i8 = 2 * int(idx % 4);
|
|
531
570
|
|
|
532
571
|
const uint16_t[4] scales = data_a[ib].scales;
|
|
533
572
|
const u16vec4 s = u16vec4(scales[0], scales[1], scales[2], scales[3]) >> 12;
|
|
@@ -538,21 +577,17 @@ void main() {
|
|
|
538
577
|
const float dl = d * (2 * bitfieldExtract(sc, 3 * int(ib16 & 3), 3) + 1);
|
|
539
578
|
const float delta = ((qh & 8) != 0) ? -IQ1M_DELTA : IQ1M_DELTA;
|
|
540
579
|
const int16_t grid = int16_t(iq1s_grid[qs | ((qh & 7) << 8)]);
|
|
541
|
-
const ivec2 gvec = ivec2(
|
|
542
|
-
bitfieldExtract(grid, 2 * (i8), 2),
|
|
543
|
-
bitfieldExtract(grid, 2 * (i8 + 1), 2)
|
|
544
|
-
);
|
|
545
|
-
const vec2 v = dl * (vec2(gvec) + delta);
|
|
546
580
|
|
|
547
|
-
|
|
548
|
-
|
|
581
|
+
[[unroll]] for (int k = 0; k < 8; ++k) {
|
|
582
|
+
buf_a[buf_idx + k] = FLOAT_TYPE(dl * (bitfieldExtract(grid, 2 * k, 2) + delta));
|
|
583
|
+
}
|
|
549
584
|
#elif defined(DATA_A_IQ2_XXS)
|
|
550
585
|
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
|
|
551
586
|
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
|
|
552
587
|
|
|
553
|
-
const uint ib = idx /
|
|
554
|
-
const uint ib32 = (idx %
|
|
555
|
-
const uint ib8 =
|
|
588
|
+
const uint ib = idx / 32; // 8 values per idx
|
|
589
|
+
const uint ib32 = (idx % 32) / 4; // 0..7
|
|
590
|
+
const uint ib8 = idx % 4;
|
|
556
591
|
|
|
557
592
|
const float d = float(data_a[ib].d);
|
|
558
593
|
const uint qs = data_a[ib].qs[8 * ib32 + ib8];
|
|
@@ -562,63 +597,81 @@ void main() {
|
|
|
562
597
|
data_a[ib].qs[8*ib32 + 6],
|
|
563
598
|
data_a[ib].qs[8*ib32 + 7]
|
|
564
599
|
));
|
|
565
|
-
const
|
|
600
|
+
const FLOAT_TYPE db = FLOAT_TYPE(d * 0.25 * (0.5 + (signs >> 28)));
|
|
566
601
|
const uint32_t sign7 = bitfieldExtract(signs, 7 * int(ib8), 7);
|
|
567
|
-
const uint sign =
|
|
568
|
-
const
|
|
569
|
-
const
|
|
570
|
-
const
|
|
571
|
-
|
|
572
|
-
buf_a[buf_idx ] = FLOAT_TYPE(
|
|
573
|
-
buf_a[buf_idx + 1] = FLOAT_TYPE(
|
|
602
|
+
const uint sign = sign7 | (bitCount(sign7) << 7);
|
|
603
|
+
const uvec2 grid = iq2xxs_grid[qs];
|
|
604
|
+
const vec4 grid0 = vec4(unpack8(grid.x));
|
|
605
|
+
const vec4 grid1 = vec4(unpack8(grid.y));
|
|
606
|
+
|
|
607
|
+
buf_a[buf_idx ] = db * FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x);
|
|
608
|
+
buf_a[buf_idx + 1] = db * FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y);
|
|
609
|
+
buf_a[buf_idx + 2] = db * FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z);
|
|
610
|
+
buf_a[buf_idx + 3] = db * FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w);
|
|
611
|
+
buf_a[buf_idx + 4] = db * FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x);
|
|
612
|
+
buf_a[buf_idx + 5] = db * FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y);
|
|
613
|
+
buf_a[buf_idx + 6] = db * FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z);
|
|
614
|
+
buf_a[buf_idx + 7] = db * FLOAT_TYPE((sign & 128) != 0 ? -grid1.w : grid1.w);
|
|
574
615
|
#elif defined(DATA_A_IQ2_XS)
|
|
575
616
|
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
|
|
576
617
|
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
|
|
577
618
|
|
|
578
|
-
const uint ib = idx /
|
|
579
|
-
const uint ib32 = (idx %
|
|
580
|
-
const uint ib8 =
|
|
619
|
+
const uint ib = idx / 32; // 8 values per idx
|
|
620
|
+
const uint ib32 = (idx % 32) / 4; // 0..7
|
|
621
|
+
const uint ib8 = idx % 4; // 0..3
|
|
581
622
|
|
|
582
623
|
const float d = float(data_a[ib].d);
|
|
583
624
|
const uint scale = (data_a[ib].scales[ib32] >> (2 * (ib8 & 2))) & 0xf;
|
|
584
|
-
const
|
|
625
|
+
const FLOAT_TYPE db = FLOAT_TYPE(d * 0.25 * (0.5 + scale));
|
|
585
626
|
const uint qs = data_a[ib].qs[4 * ib32 + ib8];
|
|
586
627
|
const uint sign7 = qs >> 9;
|
|
587
|
-
const uint sign =
|
|
588
|
-
const
|
|
589
|
-
const
|
|
590
|
-
const
|
|
591
|
-
|
|
592
|
-
buf_a[buf_idx ] = FLOAT_TYPE(
|
|
593
|
-
buf_a[buf_idx + 1] = FLOAT_TYPE(
|
|
628
|
+
const uint sign = sign7 | (bitCount(sign7) << 7);
|
|
629
|
+
const uvec2 grid = iq2xs_grid[qs & 511];
|
|
630
|
+
const vec4 grid0 = vec4(unpack8(grid.x));
|
|
631
|
+
const vec4 grid1 = vec4(unpack8(grid.y));
|
|
632
|
+
|
|
633
|
+
buf_a[buf_idx ] = db * FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x);
|
|
634
|
+
buf_a[buf_idx + 1] = db * FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y);
|
|
635
|
+
buf_a[buf_idx + 2] = db * FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z);
|
|
636
|
+
buf_a[buf_idx + 3] = db * FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w);
|
|
637
|
+
buf_a[buf_idx + 4] = db * FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x);
|
|
638
|
+
buf_a[buf_idx + 5] = db * FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y);
|
|
639
|
+
buf_a[buf_idx + 6] = db * FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z);
|
|
640
|
+
buf_a[buf_idx + 7] = db * FLOAT_TYPE((sign & 128) != 0 ? -grid1.w : grid1.w);
|
|
594
641
|
#elif defined(DATA_A_IQ2_S)
|
|
595
642
|
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
|
|
596
643
|
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
|
|
597
644
|
|
|
598
|
-
const uint ib = idx /
|
|
599
|
-
const uint ib8 =
|
|
600
|
-
const uint ib32 = ib8 / 4;
|
|
645
|
+
const uint ib = idx / 32; // 8 values per idx
|
|
646
|
+
const uint ib8 = idx % 32; // 0..31
|
|
647
|
+
const uint ib32 = ib8 / 4; // 0..7
|
|
601
648
|
|
|
602
649
|
const uint scale = (data_a[ib].scales[ib32] >> (2 * (ib8 & 2))) & 0xf;
|
|
603
650
|
const uint qs = data_a[ib].qs[ib8];
|
|
604
651
|
const uint qh = data_a[ib].qh[ib32];
|
|
605
652
|
const uint qhshift = 2 * (ib8 % 4);
|
|
606
|
-
const uint sign = data_a[ib].qs[QUANT_K / 8 + ib8]
|
|
653
|
+
const uint sign = data_a[ib].qs[QUANT_K / 8 + ib8];
|
|
607
654
|
|
|
608
655
|
const float d = float(data_a[ib].d);
|
|
609
|
-
const
|
|
610
|
-
const
|
|
611
|
-
const
|
|
612
|
-
const
|
|
613
|
-
|
|
614
|
-
buf_a[buf_idx ] = FLOAT_TYPE(
|
|
615
|
-
buf_a[buf_idx + 1] = FLOAT_TYPE(
|
|
656
|
+
const FLOAT_TYPE db = FLOAT_TYPE(d * 0.25 * (0.5 + scale));
|
|
657
|
+
const uvec2 grid = iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)];
|
|
658
|
+
const vec4 grid0 = vec4(unpack8(grid.x));
|
|
659
|
+
const vec4 grid1 = vec4(unpack8(grid.y));
|
|
660
|
+
|
|
661
|
+
buf_a[buf_idx ] = db * FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x);
|
|
662
|
+
buf_a[buf_idx + 1] = db * FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y);
|
|
663
|
+
buf_a[buf_idx + 2] = db * FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z);
|
|
664
|
+
buf_a[buf_idx + 3] = db * FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w);
|
|
665
|
+
buf_a[buf_idx + 4] = db * FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x);
|
|
666
|
+
buf_a[buf_idx + 5] = db * FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y);
|
|
667
|
+
buf_a[buf_idx + 6] = db * FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z);
|
|
668
|
+
buf_a[buf_idx + 7] = db * FLOAT_TYPE((sign & 128) != 0 ? -grid1.w : grid1.w);
|
|
616
669
|
#elif defined(DATA_A_IQ3_XXS)
|
|
617
670
|
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
|
|
618
671
|
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
|
|
619
672
|
|
|
620
|
-
const uint ib = idx /
|
|
621
|
-
const uint iqs =
|
|
673
|
+
const uint ib = idx / 64; // 4 values per idx
|
|
674
|
+
const uint iqs = idx % 64; // 0..63
|
|
622
675
|
const uint is = QUANT_K / 4 + 4 * (iqs / 8); // 8 values
|
|
623
676
|
|
|
624
677
|
const float d = float(data_a[ib].d);
|
|
@@ -631,33 +684,36 @@ void main() {
|
|
|
631
684
|
));
|
|
632
685
|
const float db = d * 0.5 * (0.5 + (signs >> 28));
|
|
633
686
|
const uint32_t sign7 = bitfieldExtract(signs, 7 * (int(iqs / 2) % 4), 7);
|
|
634
|
-
const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (
|
|
635
|
-
const
|
|
636
|
-
const
|
|
637
|
-
|
|
638
|
-
|
|
639
|
-
buf_a[buf_idx
|
|
640
|
-
buf_a[buf_idx +
|
|
687
|
+
const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (4 * (idx % 2));
|
|
688
|
+
const uint grid = iq3xxs_grid[qs];
|
|
689
|
+
const vec4 v = db * vec4(unpack8(grid));
|
|
690
|
+
|
|
691
|
+
buf_a[buf_idx ] = FLOAT_TYPE((sign & 1) != 0 ? -v.x : v.x);
|
|
692
|
+
buf_a[buf_idx + 1] = FLOAT_TYPE((sign & 2) != 0 ? -v.y : v.y);
|
|
693
|
+
buf_a[buf_idx + 2] = FLOAT_TYPE((sign & 4) != 0 ? -v.z : v.z);
|
|
694
|
+
buf_a[buf_idx + 3] = FLOAT_TYPE((sign & 8) != 0 ? -v.w : v.w);
|
|
641
695
|
#elif defined(DATA_A_IQ3_S)
|
|
642
696
|
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
|
|
643
697
|
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
|
|
644
698
|
|
|
645
|
-
const uint ib = idx /
|
|
646
|
-
const uint iqs =
|
|
699
|
+
const uint ib = idx / 64; // 4 values per idx
|
|
700
|
+
const uint iqs = idx % 64; // 0..63
|
|
647
701
|
const uint iqh = iqs / 8;
|
|
648
702
|
|
|
649
703
|
const float d = float(data_a[ib].d);
|
|
650
704
|
const uint qs = data_a[ib].qs[iqs];
|
|
651
705
|
const uint qh = data_a[ib].qh[iqh];
|
|
652
|
-
const int8_t sign = int8_t(data_a[ib].signs[iqs / 2] >> (
|
|
706
|
+
const int8_t sign = int8_t(data_a[ib].signs[iqs / 2] >> (4 * (idx % 2)));
|
|
653
707
|
const uint scale = data_a[ib].scales[iqs / 16];
|
|
654
708
|
const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(sign << 1, sign)));
|
|
655
709
|
const float db = d * (1 + 2 * ((scale >> (4 * (iqh & 1))) & 0xf));
|
|
656
|
-
const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)]
|
|
657
|
-
const
|
|
710
|
+
const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)];
|
|
711
|
+
const vec4 v = db * vec4(unpack8(grid));
|
|
658
712
|
|
|
659
|
-
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
|
|
660
|
-
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
|
|
713
|
+
buf_a[buf_idx ] = FLOAT_TYPE((sign & 1) != 0 ? -v.x : v.x);
|
|
714
|
+
buf_a[buf_idx + 1] = FLOAT_TYPE((sign & 2) != 0 ? -v.y : v.y);
|
|
715
|
+
buf_a[buf_idx + 2] = FLOAT_TYPE((sign & 4) != 0 ? -v.z : v.z);
|
|
716
|
+
buf_a[buf_idx + 3] = FLOAT_TYPE((sign & 8) != 0 ? -v.w : v.w);
|
|
661
717
|
#elif defined(DATA_A_IQ4_XS)
|
|
662
718
|
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
|
|
663
719
|
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
|
|
@@ -162,17 +162,32 @@ void main() {
|
|
|
162
162
|
_ne1 = 0;
|
|
163
163
|
uint num_elements = p.nei1 * p.nei0;
|
|
164
164
|
|
|
165
|
-
|
|
165
|
+
uint ids[16];
|
|
166
|
+
uint iter = 0;
|
|
167
|
+
|
|
168
|
+
for (uint j = 0; j < num_elements; j += gl_SubgroupSize) {
|
|
169
|
+
// prefetch up to 16 elements
|
|
170
|
+
if (iter == 0) {
|
|
171
|
+
[[unroll]] for (uint k = 0; k < 16; ++k) {
|
|
172
|
+
uint i = j + gl_SubgroupInvocationID + k*gl_SubgroupSize;
|
|
173
|
+
bool in_range = i < num_elements;
|
|
174
|
+
uint ii1 = i / p.nei0;
|
|
175
|
+
uint ii0 = i % p.nei0;
|
|
176
|
+
ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0;
|
|
177
|
+
}
|
|
178
|
+
}
|
|
179
|
+
uint i = j + gl_SubgroupInvocationID;
|
|
166
180
|
bool in_range = i < num_elements;
|
|
167
|
-
uint ii0 = i % p.nei0;
|
|
168
181
|
uint ii1 = i / p.nei0;
|
|
169
|
-
uint
|
|
182
|
+
uint ii0 = i % p.nei0;
|
|
183
|
+
uint id = ids[iter++];
|
|
170
184
|
uvec4 ballot = subgroupBallot(in_range && id == expert_idx);
|
|
171
185
|
uint idx = subgroupBallotExclusiveBitCount(ballot);
|
|
172
186
|
if (in_range && id == expert_idx) {
|
|
173
187
|
row_ids[_ne1 + idx] = u16vec4(ii0 % p.ne11, ii1, ii0, 0);
|
|
174
188
|
}
|
|
175
189
|
_ne1 += subgroupBallotBitCount(ballot);
|
|
190
|
+
iter &= 15;
|
|
176
191
|
}
|
|
177
192
|
_ne1_sh = _ne1;
|
|
178
193
|
}
|
|
@@ -414,17 +429,31 @@ void main() {
|
|
|
414
429
|
fetch_scales(ir * BM, pos_a, stride_a, block_k + BK, tid, false);
|
|
415
430
|
}
|
|
416
431
|
|
|
417
|
-
|
|
418
|
-
|
|
432
|
+
if ((ir + 1) * BM <= p.M && block_k + BK <= end_k) {
|
|
433
|
+
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
|
|
434
|
+
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
|
|
419
435
|
|
|
420
|
-
|
|
436
|
+
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
|
|
421
437
|
#ifdef MUL_MAT_ID
|
|
422
|
-
|
|
438
|
+
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB);
|
|
423
439
|
#else
|
|
424
|
-
|
|
440
|
+
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose);
|
|
425
441
|
#endif
|
|
426
442
|
|
|
427
|
-
|
|
443
|
+
sum = coopMatMulAdd(mat_a, mat_b, sum);
|
|
444
|
+
} else {
|
|
445
|
+
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
|
|
446
|
+
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
|
|
447
|
+
|
|
448
|
+
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
|
|
449
|
+
#ifdef MUL_MAT_ID
|
|
450
|
+
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB);
|
|
451
|
+
#else
|
|
452
|
+
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose);
|
|
453
|
+
#endif
|
|
454
|
+
|
|
455
|
+
sum = coopMatMulAdd(mat_a, mat_b, sum);
|
|
456
|
+
}
|
|
428
457
|
}
|
|
429
458
|
|
|
430
459
|
// Convert from ACC_TYPE to D_TYPE
|
|
@@ -1,11 +1,13 @@
|
|
|
1
1
|
#version 450
|
|
2
2
|
|
|
3
|
-
#include "
|
|
3
|
+
#include "generic_binary_head.comp"
|
|
4
4
|
#include "types.comp"
|
|
5
5
|
|
|
6
6
|
#extension GL_EXT_control_flow_attributes : enable
|
|
7
7
|
#define BLOCK_SIZE 512
|
|
8
8
|
|
|
9
|
+
layout (constant_id = 1) const bool do_multiply = false;
|
|
10
|
+
|
|
9
11
|
layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
|
|
10
12
|
|
|
11
13
|
shared FLOAT_TYPE sum[BLOCK_SIZE];
|
|
@@ -25,6 +27,7 @@ void main() {
|
|
|
25
27
|
const uint stride_sample = p.nb03;
|
|
26
28
|
|
|
27
29
|
uint32_t a_offset = samp*stride_sample + channel*stride_channel + row*stride_row + get_aoffset();
|
|
30
|
+
uint32_t b_offset = src1_idx(0, row, channel, samp) + get_boffset();
|
|
28
31
|
uint32_t d_offset = ((samp*nchannels + channel)*nrows + row)*ncols + get_doffset();
|
|
29
32
|
|
|
30
33
|
sum[tid] = FLOAT_TYPE(0.0f); // partial sum for thread in warp
|
|
@@ -46,7 +49,13 @@ void main() {
|
|
|
46
49
|
const FLOAT_TYPE mean = sum[0] / FLOAT_TYPE(ncols);
|
|
47
50
|
const FLOAT_TYPE scale = inversesqrt(mean + FLOAT_TYPE(p.param1));
|
|
48
51
|
|
|
49
|
-
|
|
50
|
-
|
|
52
|
+
if (do_multiply) {
|
|
53
|
+
[[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) {
|
|
54
|
+
data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]) * FLOAT_TYPE(data_b[b_offset + col]));
|
|
55
|
+
}
|
|
56
|
+
} else {
|
|
57
|
+
[[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) {
|
|
58
|
+
data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]));
|
|
59
|
+
}
|
|
51
60
|
}
|
|
52
61
|
}
|
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
#version 450
|
|
2
|
+
|
|
3
|
+
#include "types.comp"
|
|
4
|
+
#include "generic_unary_head.comp"
|
|
5
|
+
|
|
6
|
+
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
|
|
7
|
+
|
|
8
|
+
uint wrap_idx(int i, uint ne) {
|
|
9
|
+
if (i < 0) {
|
|
10
|
+
return i + ne;
|
|
11
|
+
} else if (i >= ne) {
|
|
12
|
+
return i - ne;
|
|
13
|
+
}
|
|
14
|
+
return i;
|
|
15
|
+
}
|
|
16
|
+
|
|
17
|
+
void main() {
|
|
18
|
+
const uint idx = get_idx();
|
|
19
|
+
if (idx >= p.ne) {
|
|
20
|
+
return;
|
|
21
|
+
}
|
|
22
|
+
|
|
23
|
+
const uint i3 = fastdiv(idx, p.ne1_012mp, p.ne1_012L);
|
|
24
|
+
const uint i3_offset = i3 * p.ne12*p.ne11*p.ne10;
|
|
25
|
+
const uint i2 = fastdiv(idx - i3_offset, p.ne1_01mp, p.ne1_01L);
|
|
26
|
+
const uint i2_offset = i2*p.ne11*p.ne10;
|
|
27
|
+
const uint i1 = fastdiv(idx - i3_offset - i2_offset, p.ne1_0mp, p.ne1_0L);
|
|
28
|
+
const uint i0 = idx - i3_offset - i2_offset - i1*p.ne10;
|
|
29
|
+
|
|
30
|
+
const uint p1 = floatBitsToUint(p.param1);
|
|
31
|
+
const uint p2 = floatBitsToUint(p.param2);
|
|
32
|
+
const int s0 = int(p1 >> 16) - 0x8000;
|
|
33
|
+
const int s1 = int(p1 & 0xFFFF) - 0x8000;
|
|
34
|
+
const int s2 = int(p2 >> 16) - 0x8000;
|
|
35
|
+
const int s3 = int(p2 & 0xFFFF) - 0x8000;
|
|
36
|
+
|
|
37
|
+
const uint i00 = wrap_idx(int(i0) - s0, p.ne10);
|
|
38
|
+
const uint i01 = wrap_idx(int(i1) - s1, p.ne11);
|
|
39
|
+
const uint i02 = wrap_idx(int(i2) - s2, p.ne12);
|
|
40
|
+
const uint i03 = wrap_idx(int(i3) - s3, p.ne13);
|
|
41
|
+
|
|
42
|
+
const uint a_idx = i03*p.nb03 + i02*p.nb02 + i01*p.nb01 + i00*p.nb00;
|
|
43
|
+
const uint d_idx = i3 *p.nb13 + i2 *p.nb12 + i1 *p.nb11 + i0 *p.nb10;
|
|
44
|
+
|
|
45
|
+
data_d[get_doffset() + d_idx] = D_TYPE(data_a[get_aoffset() + a_idx]);
|
|
46
|
+
}
|
|
@@ -14,21 +14,19 @@ void main() {
|
|
|
14
14
|
|
|
15
15
|
const uint row_dst = gl_GlobalInvocationID.x;
|
|
16
16
|
|
|
17
|
-
if (i0 >= p.n_dims) {
|
|
18
|
-
const uint i = row_dst*ne0 + i0;
|
|
19
|
-
|
|
20
|
-
data_d[i + 0] = data_a[i + 0];
|
|
21
|
-
data_d[i + 1] = data_a[i + 1];
|
|
22
|
-
|
|
23
|
-
return;
|
|
24
|
-
}
|
|
25
|
-
|
|
26
17
|
const uint row_x = row_dst % ne1;
|
|
27
18
|
const uint channel_x = row_dst / ne1;
|
|
28
19
|
|
|
29
20
|
const uint idst = row_dst*ne0 + i0/2;
|
|
30
21
|
const uint ix = channel_x*p.s2 + row_x*p.s1 + i0/2;
|
|
31
22
|
|
|
23
|
+
if (i0 >= p.n_dims) {
|
|
24
|
+
data_d[idst + i0/2 + 0] = data_a[ix + i0/2 + 0];
|
|
25
|
+
data_d[idst + i0/2 + 1] = data_a[ix + i0/2 + 1];
|
|
26
|
+
|
|
27
|
+
return;
|
|
28
|
+
}
|
|
29
|
+
|
|
32
30
|
const int sect_dims = p.sections[0] + p.sections[1] + p.sections[2] + p.sections[3];
|
|
33
31
|
const int sec_w = p.sections[1] + p.sections[0];
|
|
34
32
|
const uint sector = (i0 / 2) % sect_dims;
|
|
@@ -13,21 +13,19 @@ void main() {
|
|
|
13
13
|
|
|
14
14
|
const uint row_dst = gl_GlobalInvocationID.x;
|
|
15
15
|
|
|
16
|
-
if (i0 >= p.n_dims) {
|
|
17
|
-
const uint i = row_dst*ne0 + i0;
|
|
18
|
-
|
|
19
|
-
data_d[i + 0] = data_a[i + 0];
|
|
20
|
-
data_d[i + 1] = data_a[i + 1];
|
|
21
|
-
|
|
22
|
-
return;
|
|
23
|
-
}
|
|
24
|
-
|
|
25
16
|
const uint row_x = row_dst % ne1;
|
|
26
17
|
const uint channel_x = row_dst / ne1;
|
|
27
18
|
|
|
28
19
|
const uint idst = row_dst*ne0 + i0/2;
|
|
29
20
|
const uint ix = channel_x*p.s2 + row_x*p.s1 + i0/2;
|
|
30
21
|
|
|
22
|
+
if (i0 >= p.n_dims) {
|
|
23
|
+
data_d[idst + i0/2 + 0] = data_a[ix + i0/2 + 0];
|
|
24
|
+
data_d[idst + i0/2 + 1] = data_a[ix + i0/2 + 1];
|
|
25
|
+
|
|
26
|
+
return;
|
|
27
|
+
}
|
|
28
|
+
|
|
31
29
|
const float theta_base = data_pos[channel_x] * pow(p.theta_scale, i0/2.0f);
|
|
32
30
|
|
|
33
31
|
const float freq_factor = p.has_ff != 0 ? data_ff[i0/2] : 1.0f;
|
|
@@ -13,21 +13,19 @@ void main() {
|
|
|
13
13
|
|
|
14
14
|
const uint row_dst = gl_GlobalInvocationID.x;
|
|
15
15
|
|
|
16
|
-
if (i0 >= p.n_dims) {
|
|
17
|
-
const uint i = row_dst*ne0 + i0;
|
|
18
|
-
|
|
19
|
-
data_d[i + 0] = data_a[i + 0];
|
|
20
|
-
data_d[i + 1] = data_a[i + 1];
|
|
21
|
-
|
|
22
|
-
return;
|
|
23
|
-
}
|
|
24
|
-
|
|
25
16
|
const uint row_x = row_dst % ne1;
|
|
26
17
|
const uint channel_x = row_dst / ne1;
|
|
27
18
|
|
|
28
19
|
const uint idst = row_dst*ne0 + i0;
|
|
29
20
|
const uint ix = channel_x*p.s2 + row_x*p.s1 + i0;
|
|
30
21
|
|
|
22
|
+
if (i0 >= p.n_dims) {
|
|
23
|
+
data_d[idst + 0] = data_a[ix + 0];
|
|
24
|
+
data_d[idst + 1] = data_a[ix + 1];
|
|
25
|
+
|
|
26
|
+
return;
|
|
27
|
+
}
|
|
28
|
+
|
|
31
29
|
const float theta_base = data_pos[channel_x] * pow(p.theta_scale, i0/2.0f);
|
|
32
30
|
|
|
33
31
|
const float freq_factor = p.has_ff != 0 ? data_ff[i0/2] : 1.0f;
|
|
@@ -18,7 +18,7 @@ void main() {
|
|
|
18
18
|
continue;
|
|
19
19
|
}
|
|
20
20
|
|
|
21
|
-
data_d[get_doffset() + idx] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + idx]) * FLOAT_TYPE(p.param1));
|
|
21
|
+
data_d[get_doffset() + idx] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + idx]) * FLOAT_TYPE(p.param1) + FLOAT_TYPE(p.param2));
|
|
22
22
|
idx += num_threads;
|
|
23
23
|
}
|
|
24
24
|
}
|
|
@@ -6,6 +6,14 @@ layout (push_constant) uniform parameter
|
|
|
6
6
|
{
|
|
7
7
|
uint KX;
|
|
8
8
|
uint KY;
|
|
9
|
+
uint ne00;
|
|
10
|
+
uint ne01;
|
|
11
|
+
uint ne02;
|
|
12
|
+
uint ne12;
|
|
13
|
+
uint ne13;
|
|
14
|
+
uint nb11;
|
|
15
|
+
uint nb12;
|
|
16
|
+
uint nb13;
|
|
9
17
|
float scale;
|
|
10
18
|
float max_bias;
|
|
11
19
|
float m0;
|
|
@@ -31,7 +39,15 @@ shared FLOAT_TYPE vals[BLOCK_SIZE];
|
|
|
31
39
|
void soft_max(uint num_iters) {
|
|
32
40
|
const uint tid = gl_LocalInvocationID.x;
|
|
33
41
|
const uint rowx = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
|
|
34
|
-
|
|
42
|
+
|
|
43
|
+
const uint32_t i03 = rowx / (p.ne01 * p.ne02);
|
|
44
|
+
const uint32_t i02 = (rowx - i03 * p.ne01 * p.ne02) / p.ne01;
|
|
45
|
+
const uint32_t i01 = rowx % p.ne01;
|
|
46
|
+
|
|
47
|
+
uint rowy_start = 0;
|
|
48
|
+
if (p.KY > 0) {
|
|
49
|
+
rowy_start = i01 * p.nb11 + (i02 % p.ne12) * p.nb12 + (i03 % p.ne13) * p.nb13;
|
|
50
|
+
}
|
|
35
51
|
|
|
36
52
|
if (rowx >= p.nrows_x) {
|
|
37
53
|
return;
|
|
@@ -41,7 +57,7 @@ void soft_max(uint num_iters) {
|
|
|
41
57
|
|
|
42
58
|
// ALiBi
|
|
43
59
|
if (p.max_bias > 0.0f) {
|
|
44
|
-
const uint h = rowx/p.
|
|
60
|
+
const uint h = (rowx / p.ne01) % p.ne02; // head index
|
|
45
61
|
|
|
46
62
|
const float base = h < p.n_head_log2 ? p.m0 : p.m1;
|
|
47
63
|
const uint exp = h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1;
|
|
@@ -67,7 +83,7 @@ void soft_max(uint num_iters) {
|
|
|
67
83
|
|
|
68
84
|
FLOAT_TYPE b = FLOAT_TYPE(0);
|
|
69
85
|
if (p.KY > 0 && col < p.KX) {
|
|
70
|
-
b = data_b[
|
|
86
|
+
b = data_b[rowy_start + col];
|
|
71
87
|
}
|
|
72
88
|
|
|
73
89
|
FLOAT_TYPE v = a * p.scale + slope * b;
|
|
@@ -111,7 +127,7 @@ void soft_max(uint num_iters) {
|
|
|
111
127
|
if (idx < DATA_CACHE_SIZE) {
|
|
112
128
|
val = exp(data_cache[idx] - max_val);
|
|
113
129
|
} else {
|
|
114
|
-
val = exp(FLOAT_TYPE(data_a[i]) * p.scale + (p.KY > 0 ? slope * FLOAT_TYPE(data_b[
|
|
130
|
+
val = exp(FLOAT_TYPE(data_a[i]) * p.scale + (p.KY > 0 ? slope * FLOAT_TYPE(data_b[rowy_start + col]) : FLOAT_TYPE(0.0f)) - max_val);
|
|
115
131
|
}
|
|
116
132
|
sum += val;
|
|
117
133
|
if (idx < DATA_CACHE_SIZE) {
|