@novastera-oss/llamarn 0.2.7 → 0.3.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/android/src/main/cpp/include/llama.h +8 -3
- package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml-base.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libllama.so +0 -0
- package/android/src/main/jniLibs/x86/libggml-base.so +0 -0
- package/android/src/main/jniLibs/x86/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/x86/libggml.so +0 -0
- package/android/src/main/jniLibs/x86/libllama.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
- package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
- package/cpp/LlamaCppModel.cpp +56 -22
- package/cpp/build-info.cpp +2 -2
- package/cpp/llama.cpp/CMakeLists.txt +1 -2
- package/cpp/llama.cpp/README.md +4 -5
- package/cpp/llama.cpp/build-xcframework.sh +1 -1
- package/cpp/llama.cpp/common/CMakeLists.txt +4 -5
- package/cpp/llama.cpp/common/arg.cpp +24 -0
- package/cpp/llama.cpp/common/chat.cpp +37 -20
- package/cpp/llama.cpp/common/chat.h +2 -0
- package/cpp/llama.cpp/common/common.cpp +3 -0
- package/cpp/llama.cpp/common/common.h +5 -0
- package/cpp/llama.cpp/common/json-schema-to-grammar.cpp +3 -46
- package/cpp/llama.cpp/convert_hf_to_gguf.py +860 -23
- package/cpp/llama.cpp/convert_hf_to_gguf_update.py +9 -0
- package/cpp/llama.cpp/ggml/CMakeLists.txt +8 -2
- package/cpp/llama.cpp/ggml/include/ggml-backend.h +1 -1
- package/cpp/llama.cpp/ggml/include/ggml-cpu.h +2 -0
- package/cpp/llama.cpp/ggml/include/ggml.h +206 -10
- package/cpp/llama.cpp/ggml/src/CMakeLists.txt +17 -1
- package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +0 -8
- package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +36 -18
- package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +68 -5
- package/cpp/llama.cpp/ggml/src/ggml-cann/common.h +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +16 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +37 -3
- package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/mmq.cpp +10 -9
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +109 -108
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +1027 -1038
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +53 -52
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +56 -55
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +42 -41
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/repack.cpp +24 -23
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +29 -28
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/wasm/quants.c +30 -29
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/quants.c +83 -82
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/repack.cpp +20 -19
- package/cpp/llama.cpp/ggml/src/ggml-cpu/common.h +3 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +9 -3
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +111 -103
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +4 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +3 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +1405 -240
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.h +8 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.c +25 -24
- package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +56 -40
- package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +212 -34
- package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +35 -11
- package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +187 -54
- package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +71 -29
- package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-dw.cu +161 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-dw.cuh +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-transpose.cu +91 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-transpose.cuh +4 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cu +22 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cuh +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh +4 -1
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +8 -4
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f16.cu +6 -4
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f32.cu +14 -12
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f16.cuh +5 -3
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f32.cuh +15 -10
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +12 -6
- package/cpp/llama.cpp/ggml/src/ggml-cuda/getrows.cu +8 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +269 -110
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cu +19 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cuh +3 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cuh +2 -8
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmv.cu +257 -87
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmv.cuh +2 -3
- package/cpp/llama.cpp/ggml/src/ggml-cuda/rope.cu +21 -27
- package/cpp/llama.cpp/ggml/src/ggml-cuda/scale.cu +8 -6
- package/cpp/llama.cpp/ggml/src/ggml-cuda/softmax.cu +119 -58
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-conv.cu +10 -2
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +192 -52
- package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cu +5 -18
- package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cuh +0 -1
- package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cu +97 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cuh +11 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/upscale.cu +92 -6
- package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +14 -5
- package/cpp/llama.cpp/ggml/src/ggml-impl.h +125 -183
- package/cpp/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +4 -2
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +51 -9
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +394 -80
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +616 -239
- package/cpp/llama.cpp/ggml/src/ggml-musa/mudnn.cuh +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +3 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +741 -571
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/glu.cl +337 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/set_rows.cl +95 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +24 -11
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +24 -11
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f16.cl +24 -11
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f32.cl +24 -11
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/upscale.cl +2 -3
- package/cpp/llama.cpp/ggml/src/ggml-quants.c +6 -6
- package/cpp/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +5 -6
- package/cpp/llama.cpp/ggml/src/ggml-sycl/common.hpp +1 -24
- package/cpp/llama.cpp/ggml/src/ggml-sycl/concat.cpp +28 -41
- package/cpp/llama.cpp/ggml/src/ggml-sycl/conv.cpp +4 -10
- package/cpp/llama.cpp/ggml/src/ggml-sycl/convert.cpp +99 -166
- package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +94 -72
- package/cpp/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +49 -67
- package/cpp/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +31 -1
- package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +697 -1098
- package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +18 -9
- package/cpp/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +6 -9
- package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +104 -62
- package/cpp/llama.cpp/ggml/src/ggml-sycl/gla.cpp +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +60 -80
- package/cpp/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +132 -201
- package/cpp/llama.cpp/ggml/src/ggml-sycl/norm.cpp +55 -74
- package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +39 -38
- package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.cpp +131 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.hpp +8 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +3 -3
- package/cpp/llama.cpp/ggml/src/ggml-sycl/sycl_hw.cpp +3 -1
- package/cpp/llama.cpp/ggml/src/ggml-sycl/sycl_hw.hpp +3 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +3 -8
- package/cpp/llama.cpp/ggml/src/ggml-sycl/wkv.cpp +12 -16
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +12 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +767 -292
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +58 -7
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +28 -23
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +14 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +38 -32
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +32 -27
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +44 -12
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +15 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +128 -72
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +38 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +12 -3
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +20 -4
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +69 -5
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +23 -3
- package/cpp/llama.cpp/ggml/src/ggml.c +449 -72
- package/cpp/llama.cpp/ggml/src/gguf.cpp +13 -2
- package/cpp/llama.cpp/gguf-py/gguf/constants.py +285 -0
- package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +27 -0
- package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +137 -21
- package/cpp/llama.cpp/gguf-py/gguf/vocab.py +109 -7
- package/cpp/llama.cpp/gguf-py/pyproject.toml +2 -2
- package/cpp/llama.cpp/include/llama.h +8 -43
- package/cpp/llama.cpp/models/templates/Mistral-Small-3.2-24B-Instruct-2506.jinja +124 -0
- package/cpp/llama.cpp/src/llama-arch.cpp +265 -3
- package/cpp/llama.cpp/src/llama-arch.h +36 -1
- package/cpp/llama.cpp/src/llama-batch.cpp +596 -359
- package/cpp/llama.cpp/src/llama-batch.h +105 -70
- package/cpp/llama.cpp/src/llama-chat.cpp +26 -6
- package/cpp/llama.cpp/src/llama-chat.h +1 -0
- package/cpp/llama.cpp/src/llama-context.cpp +101 -107
- package/cpp/llama.cpp/src/llama-context.h +13 -13
- package/cpp/llama.cpp/src/llama-graph.cpp +286 -404
- package/cpp/llama.cpp/src/llama-graph.h +78 -79
- package/cpp/llama.cpp/src/llama-hparams.cpp +11 -1
- package/cpp/llama.cpp/src/llama-hparams.h +11 -0
- package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +74 -66
- package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +23 -26
- package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +312 -157
- package/cpp/llama.cpp/src/llama-kv-cache-unified.h +79 -46
- package/cpp/llama.cpp/src/llama-kv-cells.h +97 -21
- package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +73 -69
- package/cpp/llama.cpp/src/llama-memory-hybrid.h +19 -22
- package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +88 -77
- package/cpp/llama.cpp/src/llama-memory-recurrent.h +15 -20
- package/cpp/llama.cpp/src/llama-memory.cpp +17 -0
- package/cpp/llama.cpp/src/llama-memory.h +21 -22
- package/cpp/llama.cpp/src/llama-model-saver.cpp +1 -0
- package/cpp/llama.cpp/src/llama-model.cpp +5301 -2922
- package/cpp/llama.cpp/src/llama-model.h +40 -0
- package/cpp/llama.cpp/src/llama-quant.cpp +88 -5
- package/cpp/llama.cpp/src/llama-vocab.cpp +37 -3
- package/cpp/llama.cpp/src/llama-vocab.h +42 -0
- package/cpp/rn-utils.h +3 -0
- package/ios/include/chat.h +2 -0
- package/ios/include/common.h +5 -0
- package/ios/include/llama.h +8 -43
- package/ios/libs/llama.xcframework/Info.plist +19 -19
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5059 -4863
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +206 -10
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +8 -43
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5030 -4834
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3889 -3742
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +206 -10
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -43
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5030 -4834
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3891 -3744
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +206 -10
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +8 -43
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +206 -10
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +8 -43
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +206 -10
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +8 -43
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5059 -4863
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +206 -10
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +8 -43
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5030 -4834
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3889 -3742
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +206 -10
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -43
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5095 -4900
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +206 -10
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +8 -43
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5066 -4871
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3919 -3773
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +206 -10
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -43
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/package.json +1 -1
- package/cpp/llama.cpp/ggml/include/ggml-kompute.h +0 -50
- package/cpp/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
- package/cpp/llama.cpp/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
|
@@ -6,6 +6,7 @@
|
|
|
6
6
|
#define GELU_COEF_A 0.044715f
|
|
7
7
|
#define GELU_QUICK_COEF -1.702f
|
|
8
8
|
#define SQRT_2_OVER_PI 0.79788456080286535587989211986876f
|
|
9
|
+
#define SQRT_2_INV 0.70710678118654752440084436210484f
|
|
9
10
|
|
|
10
11
|
kernel void kernel_gelu(
|
|
11
12
|
global float * src0,
|
|
@@ -35,6 +36,32 @@ kernel void kernel_gelu_4(
|
|
|
35
36
|
dst[get_global_id(0)] = 0.5f*x*(1.0f + tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
|
|
36
37
|
}
|
|
37
38
|
|
|
39
|
+
// Element-wise erf-based GELU on floats:
//   dst[i] = 0.5 * x * (1 + erf(x * SQRT_2_INV)),  with x = src0[i], i = get_global_id(0).
// offset0 / offsetd are byte offsets added to src0 / dst (via a char* cast) before indexing.
// NOTE(review): there is no bounds check on get_global_id(0); presumably the host
// enqueues exactly one work-item per element -- confirm against the caller.
kernel void kernel_gelu_erf(
|
|
40
|
+
global float * src0,
|
|
41
|
+
ulong offset0,
|
|
42
|
+
global float * dst,
|
|
43
|
+
ulong offsetd
|
|
44
|
+
) {
|
|
45
|
+
// Rebase both pointers by their byte offsets.
src0 = (global float*)((global char*)src0 + offset0);
|
|
46
|
+
dst = (global float*)((global char*)dst + offsetd);
|
|
47
|
+
|
|
48
|
+
float x = src0[get_global_id(0)];
|
|
49
|
+
dst[get_global_id(0)] = 0.5f*x*(1.0f + erf(x*SQRT_2_INV));
|
|
50
|
+
}
|
|
51
|
+
|
|
52
|
+
// float4 variant of the erf-based GELU: each work-item reads one float4 at
// index get_global_id(0), applies 0.5 * x * (1 + erf(x * SQRT_2_INV)) component-wise,
// and writes one float4 to dst at the same index.
// offset0 / offsetd are byte offsets added to src0 / dst (via a char* cast) before indexing.
// NOTE(review): no bounds check on get_global_id(0); presumably the host launches
// one work-item per float4 -- confirm against the caller.
kernel void kernel_gelu_erf_4(
|
|
53
|
+
global float4 * src0,
|
|
54
|
+
ulong offset0,
|
|
55
|
+
global float4 * dst,
|
|
56
|
+
ulong offsetd
|
|
57
|
+
) {
|
|
58
|
+
// Rebase both pointers by their byte offsets.
src0 = (global float4*)((global char*)src0 + offset0);
|
|
59
|
+
dst = (global float4*)((global char*)dst + offsetd);
|
|
60
|
+
|
|
61
|
+
float4 x = src0[get_global_id(0)];
|
|
62
|
+
dst[get_global_id(0)] = 0.5f*x*(1.0f + erf(x*SQRT_2_INV));
|
|
63
|
+
}
|
|
64
|
+
|
|
38
65
|
kernel void kernel_gelu_quick(
|
|
39
66
|
global float * src0,
|
|
40
67
|
ulong offset0,
|
|
@@ -0,0 +1,337 @@
|
|
|
1
|
+
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
|
2
|
+
|
|
3
|
+
#define GELU_COEF_A 0.044715f
|
|
4
|
+
#define GELU_QUICK_COEF -1.702f
|
|
5
|
+
#define SQRT_2_OVER_PI 0.79788456080286535587989211986876f
|
|
6
|
+
#define SQRT_2_INV 0.70710678118654752440084436210484f
|
|
7
|
+
|
|
8
|
+
//------------------------------------------------------------------------------
|
|
9
|
+
// geglu
|
|
10
|
+
//------------------------------------------------------------------------------
|
|
11
|
+
// GEGLU on float rows: dst_row[i] = gelu_tanh(src0_row[i]) * src1_row[i], where
//   gelu_tanh(x) = 0.5*x*(1 + tanh(SQRT_2_OVER_PI * x * (1 + GELU_COEF_A*x*x))).
// Parameters as used below:
//   offset0 / offset1 / offsetd : byte offsets added to src0 / src1 / dst.
//   nb01 / nb11 / nb1           : byte strides multiplied by get_group_id(0) to select the row.
//   ne00_off / ne10_off         : element offsets added after the cast to float*.
//   ne0                         : number of elements processed per row.
// The row is selected by get_group_id(0); work-items walk it starting at
// get_local_id(0) in steps of get_local_size(0).
kernel void kernel_geglu(
|
|
12
|
+
global char * src0,
|
|
13
|
+
ulong offset0,
|
|
14
|
+
global char * src1,
|
|
15
|
+
ulong offset1,
|
|
16
|
+
global char * dst,
|
|
17
|
+
ulong offsetd,
|
|
18
|
+
ulong nb01,
|
|
19
|
+
ulong nb11,
|
|
20
|
+
int ne0,
|
|
21
|
+
ulong nb1,
|
|
22
|
+
int ne00_off,
|
|
23
|
+
int ne10_off
|
|
24
|
+
) {
|
|
25
|
+
src0 = (global char*)((global char*)src0 + offset0);
|
|
26
|
+
src1 = (global char*)((global char*)src1 + offset1);
|
|
27
|
+
dst = (global char*)((global char*)dst + offsetd);
|
|
28
|
+
|
|
29
|
+
// Row base pointers: byte stride first, then the per-input element offset.
global float * src0_row = (global float *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off;
|
|
30
|
+
global float * src1_row = (global float *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off;
|
|
31
|
+
global float * dst_row = (global float *) ((global char *) dst + get_group_id(0)*nb1);
|
|
32
|
+
|
|
33
|
+
for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
|
|
34
|
+
const float x0 = src0_row[i0];
|
|
35
|
+
const float x1 = src1_row[i0];
|
|
36
|
+
|
|
37
|
+
// tanh approximation of GELU applied to the src0 value only.
const float gelu = 0.5f*x0*(1.0f + tanh(SQRT_2_OVER_PI*x0*(1.0f + GELU_COEF_A*x0*x0)));
|
|
38
|
+
|
|
39
|
+
dst_row[i0] = gelu*x1;
|
|
40
|
+
}
|
|
41
|
+
}
|
|
42
|
+
|
|
43
|
+
// Half-precision GEGLU: dst_row[i] = gelu_tanh(src0_row[i]) * src1_row[i], with rows
// read and written as `half`. Same parameter roles as the float kernel_geglu:
//   offset0 / offset1 / offsetd : byte offsets added to src0 / src1 / dst.
//   nb01 / nb11 / nb1           : byte strides multiplied by get_group_id(0) to select the row.
//   ne00_off / ne10_off         : element offsets added after the cast to half*.
//   ne0                         : number of elements processed per row.
// The row is selected by get_group_id(0); work-items walk it starting at
// get_local_id(0) in steps of get_local_size(0).
kernel void kernel_geglu_f16(
|
|
44
|
+
global char * src0,
|
|
45
|
+
ulong offset0,
|
|
46
|
+
global char * src1,
|
|
47
|
+
ulong offset1,
|
|
48
|
+
global char * dst,
|
|
49
|
+
ulong offsetd,
|
|
50
|
+
ulong nb01,
|
|
51
|
+
ulong nb11,
|
|
52
|
+
int ne0,
|
|
53
|
+
ulong nb1,
|
|
54
|
+
int ne00_off,
|
|
55
|
+
int ne10_off
|
|
56
|
+
) {
|
|
57
|
+
src0 = (global char*)((global char*)src0 + offset0);
|
|
58
|
+
src1 = (global char*)((global char*)src1 + offset1);
|
|
59
|
+
dst = (global char*)((global char*)dst + offsetd);
|
|
60
|
+
|
|
61
|
+
// Row base pointers: byte stride first, then the per-input element offset.
global half * src0_row = (global half *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off;
|
|
62
|
+
global half * src1_row = (global half *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off;
|
|
63
|
+
global half * dst_row = (global half *) ((global char *) dst + get_group_id(0)*nb1);
|
|
64
|
+
|
|
65
|
+
for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
|
|
66
|
+
const half x0 = src0_row[i0];
|
|
67
|
+
const half x1 = src1_row[i0];
|
|
68
|
+
|
|
69
|
+
// The expression mixes half operands with float literals/macros; the result is stored as half.
const half gelu = 0.5f*x0*(1.0f + tanh(SQRT_2_OVER_PI*x0*(1.0f + GELU_COEF_A*x0*x0)));
|
|
70
|
+
|
|
71
|
+
dst_row[i0] = gelu*x1;
|
|
72
|
+
}
|
|
73
|
+
}
|
|
74
|
+
|
|
75
|
+
//------------------------------------------------------------------------------
|
|
76
|
+
// reglu
|
|
77
|
+
//------------------------------------------------------------------------------
|
|
78
|
+
// ReGLU on float rows: dst_row[i] = x0 * x1 * (x0 > 0), i.e. the product of the two
// inputs where src0 is positive and zero otherwise.
// Parameters as used below:
//   offset0 / offset1 / offsetd : byte offsets added to src0 / src1 / dst.
//   nb01 / nb11 / nb1           : byte strides multiplied by get_group_id(0) to select the row.
//   ne00_off / ne10_off         : element offsets added after the cast to float*.
//   ne0                         : number of elements processed per row.
// The row is selected by get_group_id(0); work-items walk it starting at
// get_local_id(0) in steps of get_local_size(0).
kernel void kernel_reglu(
|
|
79
|
+
global char * src0,
|
|
80
|
+
ulong offset0,
|
|
81
|
+
global char * src1,
|
|
82
|
+
ulong offset1,
|
|
83
|
+
global char * dst,
|
|
84
|
+
ulong offsetd,
|
|
85
|
+
ulong nb01,
|
|
86
|
+
ulong nb11,
|
|
87
|
+
int ne0,
|
|
88
|
+
ulong nb1,
|
|
89
|
+
int ne00_off,
|
|
90
|
+
int ne10_off
|
|
91
|
+
) {
|
|
92
|
+
src0 = (global char*)((global char*)src0 + offset0);
|
|
93
|
+
src1 = (global char*)((global char*)src1 + offset1);
|
|
94
|
+
dst = (global char*)((global char*)dst + offsetd);
|
|
95
|
+
|
|
96
|
+
// Row base pointers: byte stride first, then the per-input element offset.
global float * src0_row = (global float *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off;
|
|
97
|
+
global float * src1_row = (global float *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off;
|
|
98
|
+
global float * dst_row = (global float *) ((global char *) dst + get_group_id(0)*nb1);
|
|
99
|
+
|
|
100
|
+
for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
|
|
101
|
+
const float x0 = src0_row[i0];
|
|
102
|
+
const float x1 = src1_row[i0];
|
|
103
|
+
|
|
104
|
+
// The scalar comparison acts as a 0/1 mask, implementing ReLU on x0 without a branch.
dst_row[i0] = x0*x1*(x0 > 0.0f);
|
|
105
|
+
}
|
|
106
|
+
}
|
|
107
|
+
|
|
108
|
+
// Half-precision ReGLU: dst_row[i] = x0 * x1 * (x0 > 0), with rows read and written
// as `half`. Same parameter roles as the float kernel_reglu:
//   offset0 / offset1 / offsetd : byte offsets added to src0 / src1 / dst.
//   nb01 / nb11 / nb1           : byte strides multiplied by get_group_id(0) to select the row.
//   ne00_off / ne10_off         : element offsets added after the cast to half*.
//   ne0                         : number of elements processed per row.
// The row is selected by get_group_id(0); work-items walk it starting at
// get_local_id(0) in steps of get_local_size(0).
kernel void kernel_reglu_f16(
|
|
109
|
+
global char * src0,
|
|
110
|
+
ulong offset0,
|
|
111
|
+
global char * src1,
|
|
112
|
+
ulong offset1,
|
|
113
|
+
global char * dst,
|
|
114
|
+
ulong offsetd,
|
|
115
|
+
ulong nb01,
|
|
116
|
+
ulong nb11,
|
|
117
|
+
int ne0,
|
|
118
|
+
ulong nb1,
|
|
119
|
+
int ne00_off,
|
|
120
|
+
int ne10_off
|
|
121
|
+
) {
|
|
122
|
+
src0 = (global char*)((global char*)src0 + offset0);
|
|
123
|
+
src1 = (global char*)((global char*)src1 + offset1);
|
|
124
|
+
dst = (global char*)((global char*)dst + offsetd);
|
|
125
|
+
|
|
126
|
+
// Row base pointers: byte stride first, then the per-input element offset.
global half * src0_row = (global half *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off;
|
|
127
|
+
global half * src1_row = (global half *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off;
|
|
128
|
+
global half * dst_row = (global half *) ((global char *) dst + get_group_id(0)*nb1);
|
|
129
|
+
|
|
130
|
+
for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
|
|
131
|
+
const half x0 = src0_row[i0];
|
|
132
|
+
const half x1 = src1_row[i0];
|
|
133
|
+
|
|
134
|
+
// The scalar comparison acts as a 0/1 mask, implementing ReLU on x0 without a branch.
dst_row[i0] = x0*x1*(x0 > 0.0f);
|
|
135
|
+
}
|
|
136
|
+
}
|
|
137
|
+
|
|
138
|
+
//------------------------------------------------------------------------------
|
|
139
|
+
// swiglu
|
|
140
|
+
//------------------------------------------------------------------------------
|
|
141
|
+
// SwiGLU on float rows: dst_row[i] = silu(src0_row[i]) * src1_row[i], where
//   silu(x) = x / (1 + exp(-x)).
// Parameters as used below:
//   offset0 / offset1 / offsetd : byte offsets added to src0 / src1 / dst.
//   nb01 / nb11 / nb1           : byte strides multiplied by get_group_id(0) to select the row.
//   ne00_off / ne10_off         : element offsets added after the cast to float*.
//   ne0                         : number of elements processed per row.
// The row is selected by get_group_id(0); work-items walk it starting at
// get_local_id(0) in steps of get_local_size(0).
kernel void kernel_swiglu(
|
|
142
|
+
global char * src0,
|
|
143
|
+
ulong offset0,
|
|
144
|
+
global char * src1,
|
|
145
|
+
ulong offset1,
|
|
146
|
+
global char * dst,
|
|
147
|
+
ulong offsetd,
|
|
148
|
+
ulong nb01,
|
|
149
|
+
ulong nb11,
|
|
150
|
+
int ne0,
|
|
151
|
+
ulong nb1,
|
|
152
|
+
int ne00_off,
|
|
153
|
+
int ne10_off
|
|
154
|
+
) {
|
|
155
|
+
src0 = (global char*)((global char*)src0 + offset0);
|
|
156
|
+
src1 = (global char*)((global char*)src1 + offset1);
|
|
157
|
+
dst = (global char*)((global char*)dst + offsetd);
|
|
158
|
+
|
|
159
|
+
// Row base pointers: byte stride first, then the per-input element offset.
global float * src0_row = (global float *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off;
|
|
160
|
+
global float * src1_row = (global float *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off;
|
|
161
|
+
global float * dst_row = (global float *) ((global char *) dst + get_group_id(0)*nb1);
|
|
162
|
+
|
|
163
|
+
for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
|
|
164
|
+
const float x0 = src0_row[i0];
|
|
165
|
+
const float x1 = src1_row[i0];
|
|
166
|
+
|
|
167
|
+
// SiLU (x * sigmoid(x)) applied to the src0 value only.
const float silu = x0 / (1.0f + exp(-x0));
|
|
168
|
+
|
|
169
|
+
dst_row[i0] = silu*x1;
|
|
170
|
+
}
|
|
171
|
+
}
|
|
172
|
+
|
|
173
|
+
// SwiGLU on f16 rows: dst[i] = silu(src0[i]) * src1[i], silu(x) = x / (1 + exp(-x)).
//
// The work-group id in dimension 0 selects the row; the work-items of the
// group walk that row with a stride of get_local_size(0).
//
// Inputs are read as half but all arithmetic is done in float and rounded to
// half once, on the store. Evaluating exp(-x0) directly on a half operand
// overflows to +inf as soon as x0 drops below about -11.09 (ln(65504)) and
// rounds the intermediate activation to half before the final multiply.
//
//   src0, src1, dst     byte pointers; offset0/offset1/offsetd (bytes) are added first
//   nb01, nb11, nb1     byte stride between consecutive rows of src0, src1, dst
//   ne0                 number of elements produced per row
//   ne00_off, ne10_off  element (half) offsets applied to the src0 / src1 row start
kernel void kernel_swiglu_f16(
    global char * src0,
    ulong offset0,
    global char * src1,
    ulong offset1,
    global char * dst,
    ulong offsetd,
    ulong nb01,
    ulong nb11,
    int ne0,
    ulong nb1,
    int ne00_off,
    int ne10_off
) {
    src0 = (global char*)((global char*)src0 + offset0);
    src1 = (global char*)((global char*)src1 + offset1);
    dst  = (global char*)((global char*)dst  + offsetd);

    global half * src0_row = (global half *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off;
    global half * src1_row = (global half *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off;
    global half * dst_row  = (global half *) ((global char *) dst  + get_group_id(0)*nb1);

    for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
        // Widen to float before any arithmetic.
        const float x0 = src0_row[i0];
        const float x1 = src1_row[i0];

        const float silu = x0 / (1.0f + exp(-x0));

        // Single rounding step back to half.
        dst_row[i0] = (half)(silu*x1);
    }
}
|
|
204
|
+
|
|
205
|
+
//------------------------------------------------------------------------------
|
|
206
|
+
// geglu_erf
|
|
207
|
+
//------------------------------------------------------------------------------
|
|
208
|
+
// GEGLU (erf form) on f32 rows:
//   dst[i] = 0.5 * x * (1 + erf(x * SQRT_2_INV)) * src1[i],  x = src0[i]
// SQRT_2_INV is a macro defined outside this kernel.
//
// The work-group id in dimension 0 selects the row; the work-items of the
// group walk that row with a stride of get_local_size(0).
//
//   src0, src1, dst     byte pointers; offset0/offset1/offsetd (bytes) are added first
//   nb01, nb11, nb1     byte stride between consecutive rows of src0, src1, dst
//   ne0                 number of elements produced per row
//   ne00_off, ne10_off  element (float) offsets applied to the src0 / src1 row start
kernel void kernel_geglu_erf(
    global char * src0,
    ulong offset0,
    global char * src1,
    ulong offset1,
    global char * dst,
    ulong offsetd,
    ulong nb01,
    ulong nb11,
    int ne0,
    ulong nb1,
    int ne00_off,
    int ne10_off
) {
    const size_t row = get_group_id(0);

    global float * act_in = (global float *) (src0 + offset0 + row*nb01) + ne00_off;
    global float * mul_in = (global float *) (src1 + offset1 + row*nb11) + ne10_off;
    global float * out    = (global float *) (dst  + offsetd + row*nb1);

    int col = get_local_id(0);
    while (col < ne0) {
        const float a = act_in[col];
        const float b = mul_in[col];

        const float activated = 0.5f*a*(1.0f + erf(a*SQRT_2_INV));

        out[col] = activated*b;

        col += get_local_size(0);
    }
}
|
|
239
|
+
|
|
240
|
+
// GEGLU (erf form) on f16 rows:
//   dst[i] = 0.5 * x * (1 + erf(x * SQRT_2_INV)) * src1[i],  x = src0[i]
// SQRT_2_INV is a macro defined outside this kernel.
//
// The work-group id in dimension 0 selects the row; the work-items of the
// group walk that row with a stride of get_local_size(0).
//
//   src0, src1, dst     byte pointers; offset0/offset1/offsetd (bytes) are added first
//   nb01, nb11, nb1     byte stride between consecutive rows of src0, src1, dst
//   ne0                 number of elements produced per row
//   ne00_off, ne10_off  element (half) offsets applied to the src0 / src1 row start
kernel void kernel_geglu_erf_f16(
    global char * src0,
    ulong offset0,
    global char * src1,
    ulong offset1,
    global char * dst,
    ulong offsetd,
    ulong nb01,
    ulong nb11,
    int ne0,
    ulong nb1,
    int ne00_off,
    int ne10_off
) {
    const size_t row = get_group_id(0);

    global half * act_in = (global half *) (src0 + offset0 + row*nb01) + ne00_off;
    global half * mul_in = (global half *) (src1 + offset1 + row*nb11) + ne10_off;
    global half * out    = (global half *) (dst  + offsetd + row*nb1);

    int col = get_local_id(0);
    while (col < ne0) {
        const half a = act_in[col];
        const half b = mul_in[col];

        // The activation is stored as half before the final multiply.
        const half activated = 0.5f*a*(1.0f + erf(a*SQRT_2_INV));

        out[col] = activated*b;

        col += get_local_size(0);
    }
}
|
|
271
|
+
|
|
272
|
+
//------------------------------------------------------------------------------
|
|
273
|
+
// geglu_quick
|
|
274
|
+
//------------------------------------------------------------------------------
|
|
275
|
+
// GEGLU (quick form) on f32 rows:
//   dst[i] = x * (1 / (1 + exp(GELU_QUICK_COEF * x))) * src1[i],  x = src0[i]
// GELU_QUICK_COEF is a macro defined outside this kernel.
//
// The work-group id in dimension 0 selects the row; the work-items of the
// group walk that row with a stride of get_local_size(0).
//
//   src0, src1, dst     byte pointers; offset0/offset1/offsetd (bytes) are added first
//   nb01, nb11, nb1     byte stride between consecutive rows of src0, src1, dst
//   ne0                 number of elements produced per row
//   ne00_off, ne10_off  element (float) offsets applied to the src0 / src1 row start
kernel void kernel_geglu_quick(
    global char * src0,
    ulong offset0,
    global char * src1,
    ulong offset1,
    global char * dst,
    ulong offsetd,
    ulong nb01,
    ulong nb11,
    int ne0,
    ulong nb1,
    int ne00_off,
    int ne10_off
) {
    const size_t row = get_group_id(0);

    global float * act_in = (global float *) (src0 + offset0 + row*nb01) + ne00_off;
    global float * mul_in = (global float *) (src1 + offset1 + row*nb11) + ne10_off;
    global float * out    = (global float *) (dst  + offsetd + row*nb1);

    int col = get_local_id(0);
    while (col < ne0) {
        const float a = act_in[col];
        const float b = mul_in[col];

        const float activated = a*(1.0f/(1.0f + exp(GELU_QUICK_COEF*a)));

        out[col] = activated*b;

        col += get_local_size(0);
    }
}
|
|
306
|
+
|
|
307
|
+
// GEGLU (quick form) on f16 rows:
//   dst[i] = x * (1 / (1 + exp(GELU_QUICK_COEF * x))) * src1[i],  x = src0[i]
// GELU_QUICK_COEF is a macro defined outside this kernel.
//
// The work-group id in dimension 0 selects the row; the work-items of the
// group walk that row with a stride of get_local_size(0).
//
//   src0, src1, dst     byte pointers; offset0/offset1/offsetd (bytes) are added first
//   nb01, nb11, nb1     byte stride between consecutive rows of src0, src1, dst
//   ne0                 number of elements produced per row
//   ne00_off, ne10_off  element (half) offsets applied to the src0 / src1 row start
kernel void kernel_geglu_quick_f16(
    global char * src0,
    ulong offset0,
    global char * src1,
    ulong offset1,
    global char * dst,
    ulong offsetd,
    ulong nb01,
    ulong nb11,
    int ne0,
    ulong nb1,
    int ne00_off,
    int ne10_off
) {
    const size_t row = get_group_id(0);

    global half * act_in = (global half *) (src0 + offset0 + row*nb01) + ne00_off;
    global half * mul_in = (global half *) (src1 + offset1 + row*nb11) + ne10_off;
    global half * out    = (global half *) (dst  + offsetd + row*nb1);

    int col = get_local_id(0);
    while (col < ne0) {
        const half a = act_in[col];
        const half b = mul_in[col];

        // The activation is stored as half before the final multiply.
        const half activated = a*(1.0f/(1.0f + exp(GELU_QUICK_COEF*a)));

        out[col] = activated*b;

        col += get_local_size(0);
    }
}
|
|
@@ -0,0 +1,130 @@
|
|
|
1
|
+
#pragma OPENCL EXTENSION cl_khr_fp16 : enable

#if defined(cl_qcom_reqd_sub_group_size)
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
#else
#define REQD_SUBGROUP_SIZE_128
#endif

// Tiling parameters.
#define OPWM 64 // rows of A (M direction) covered by one work-group
#define OPWN 64 // rows of B (N direction) covered by one work-group
#define CPWK 8  // K elements staged in local memory per tile
#define OPTM 4  // M-direction outputs accumulated by one work-item
#define OPTN 8  // N-direction outputs accumulated by one work-item

#define WG_M (OPWM / OPTM) // work-items along M
#define WG_N (OPWN / OPTN) // work-items along N
#define VEC_K (CPWK / 4)   // 4-wide vectors per staged row

// Tiled matrix multiply with f16 A and f32 B, accumulating in f32:
//
//   C[n*M + m] = sum over k of A[m*K + k] * B[n*K + k]
//
//   M, N, K          rows of A, rows of B, shared inner dimension
//   A_void/A_offset  half data, row stride K elements; A_offset is in bytes
//   B_void/B_offset  float data, row stride K elements; B_offset is in bytes
//   C_void/C_offset  float output, stride M elements per B row; C_offset is in bytes
//
// Each work-group produces an OPWM x OPWN tile of outputs; each work-item
// keeps an OPTM x OPTN block of partial sums in private memory. Rows and K
// positions outside the matrices are staged as zeros, and out-of-range
// outputs are not written.
//
// NOTE(review): the staging index math (lid % OPWM, lid / OPWM) fills the
// local tiles exactly once only when the local size is (WG_M, WG_N), i.e.
// OPWM * VEC_K work-items -- confirm against the host-side launch.
REQD_SUBGROUP_SIZE_128
__kernel void mul_mat_f16_f32(
    const int M, const int N, const int K,
    __global const void* A_void, ulong A_offset,
    __global const void* B_void, ulong B_offset,
    __global       void* C_void, ulong C_offset) {

    __global const half*  A = (__global const half* )((__global const char*)A_void + A_offset);
    __global const float* B = (__global const float*)((__global const char*)B_void + B_offset);
    __global       float* C = (__global       float*)((__global char*)C_void + C_offset);

    const int lidm = get_local_id(0);
    const int lidn = get_local_id(1);
    const int lid  = lidn * WG_M + lidm;

    const int offsetM = get_group_id(0) * OPWM;
    const int offsetN = get_group_id(1) * OPWN;

    __local half4  Alocal[OPWM][VEC_K];
    __local float4 Blocal[OPWN][VEC_K];

    float sum[OPTM][OPTN];

    for (int wm = 0; wm < OPTM; wm++) {
        for (int wn = 0; wn < OPTN; wn++) {
            sum[wm][wn] = 0.0f;
        }
    }

    const int numTiles = (K + CPWK - 1) / CPWK;

    // Which element of each local tile this work-item stages.
    const int load_row_a   = lid % OPWM;
    const int load_vec_k_a = lid / OPWM;
    const int global_row_a = offsetM + load_row_a;

    const int load_row_b   = lid % OPWN;
    const int load_vec_k_b = lid / OPWN;
    const int global_row_b = offsetN + load_row_b;

    // Row bases are computed in 64-bit: row * K can exceed the range of a
    // 32-bit int for large matrices. They are only used after the row has
    // been checked against M / N.
    const long a_row_base = (long)global_row_a * K;
    const long b_row_base = (long)global_row_b * K;

    for (int t = 0; t < numTiles; t++) {
        const int k_start = t * CPWK;
        const int k_vec_start_a = k_start + load_vec_k_a * 4;
        const int k_vec_start_b = k_start + load_vec_k_b * 4;

        // Stage one half4 of A; anything out of range stays zero.
        half4 a_stage = (half4)(0.0h);
        if (global_row_a < M && k_vec_start_a < K) {
            __global const half* a_src = A + a_row_base + k_vec_start_a;
            if (k_vec_start_a + 3 < K) {
                a_stage = vload4(0, a_src);
            } else {
                // Partial vector at the end of the row.
                a_stage.s0 = a_src[0];
                if (k_vec_start_a + 1 < K) { a_stage.s1 = a_src[1]; }
                if (k_vec_start_a + 2 < K) { a_stage.s2 = a_src[2]; }
            }
        }
        Alocal[load_row_a][load_vec_k_a] = a_stage;

        // Stage one float4 of B; anything out of range stays zero.
        float4 b_stage = (float4)(0.0f);
        if (global_row_b < N && k_vec_start_b < K) {
            __global const float* b_src = B + b_row_base + k_vec_start_b;
            if (k_vec_start_b + 3 < K) {
                b_stage = vload4(0, b_src);
            } else {
                // Partial vector at the end of the row.
                b_stage.s0 = b_src[0];
                if (k_vec_start_b + 1 < K) { b_stage.s1 = b_src[1]; }
                if (k_vec_start_b + 2 < K) { b_stage.s2 = b_src[2]; }
            }
        }
        Blocal[load_row_b][load_vec_k_b] = b_stage;

        // Both tiles must be fully written before any work-item reads them.
        barrier(CLK_LOCAL_MEM_FENCE);

        #pragma unroll
        for (int k_vec = 0; k_vec < VEC_K; k_vec++) {
            float4 a_fvecs[OPTM];
            int current_row_a = lidm;
            for (int wm = 0; wm < OPTM; wm++) {
                a_fvecs[wm] = convert_float4(Alocal[current_row_a][k_vec]);
                current_row_a += WG_M;
            }

            float4 b_fvecs[OPTN];
            int current_row_b = lidn;
            for (int wn = 0; wn < OPTN; wn++) {
                b_fvecs[wn] = Blocal[current_row_b][k_vec];
                current_row_b += WG_N;
            }

            for (int wm = 0; wm < OPTM; wm++) {
                for (int wn = 0; wn < OPTN; wn++) {
                    sum[wm][wn] += dot(a_fvecs[wm], b_fvecs[wn]);
                }
            }
        }

        // All reads of this tile must finish before the next tile overwrites it.
        barrier(CLK_LOCAL_MEM_FENCE);
    }

    for (int wm = 0; wm < OPTM; wm++) {
        const int globalRow = offsetM + lidm + wm * WG_M;
        if (globalRow < M) {
            for (int wn = 0; wn < OPTN; wn++) {
                const int globalCol = offsetN + lidn + wn * WG_N;
                if (globalCol < N) {
                    // 64-bit flat index: col * M can exceed the int range.
                    C[(long)globalCol * M + globalRow] = sum[wm][wn];
                }
            }
        }
    }
}
|
|
@@ -8,9 +8,10 @@ kernel void kernel_scale(
|
|
|
8
8
|
ulong offset0,
|
|
9
9
|
global float4 * dst,
|
|
10
10
|
ulong offsetd,
|
|
11
|
-
float scale
|
|
11
|
+
float scale,
|
|
12
|
+
float bias
|
|
12
13
|
) {
|
|
13
14
|
src0 = (global float4*)((global char*)src0 + offset0);
|
|
14
15
|
dst = (global float4*)((global char*)dst + offsetd);
|
|
15
|
-
dst[get_global_id(0)] = src0[get_global_id(0)] * scale;
|
|
16
|
+
dst[get_global_id(0)] = src0[get_global_id(0)] * scale + bias;
|
|
16
17
|
}
|
|
@@ -0,0 +1,95 @@
|
|
|
1
|
+
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
|
2
|
+
|
|
3
|
+
// Scatter f32 rows of src0 into dst (f32) at row positions read from src1.
//
// Source row i01 = get_group_id(0)*get_local_size(1) + get_local_id(1), in
// plane get_group_id(1) and batch get_group_id(2). The destination row index
// is a 64-bit integer read from src1 at (i01, plane % ne11, batch % ne12).
// Work-items along local dimension 0 copy the nblk0 elements of the row.
//
//   src0, src1, dst   byte pointers; offset0/offset1/offsetd (bytes) are added first
//   ne01              number of source rows; work-items past it do nothing
//   nb01, nb02, nb03  byte strides of src0 (row, plane, batch)
//   ne11, ne12        extents used to wrap the plane / batch index into src1
//   nb10, nb11, nb12  byte strides of src1
//   nblk0             number of elements copied per row
//   nb1, nb2, nb3     byte strides of dst (row, plane, batch)
kernel void kernel_set_rows_f32(
    global char * src0,
    ulong offset0,
    global char * src1,
    ulong offset1,
    global char * dst,
    ulong offsetd,
    int ne01,
    ulong nb01,
    ulong nb02,
    ulong nb03,
    int ne11,
    int ne12,
    ulong nb10,
    ulong nb11,
    ulong nb12,
    int nblk0,
    ulong nb1,
    ulong nb2,
    ulong nb3
) {
    const int batch   = get_group_id(2);
    const int plane   = get_group_id(1);
    const int src_idx = get_group_id(0)*get_local_size(1) + get_local_id(1);

    if (src_idx < ne01) {
        // src1 is indexed with the plane/batch wrapped to its own extents.
        const int idx_plane = plane%ne11;
        const int idx_batch = batch%ne12;

        const long dst_idx = *(global long *)(src1 + offset1 + src_idx*nb10 + idx_plane*nb11 + idx_batch*nb12);

        global float * out = (global float *) (dst  + offsetd + dst_idx*nb1  + plane*nb2  + batch*nb3);
        global float * in  = (global float *) (src0 + offset0 + src_idx*nb01 + plane*nb02 + batch*nb03);

        int e = get_local_id(0);
        while (e < nblk0) {
            out[e] = (float)in[e];
            e += get_local_size(0);
        }
    }
}
|
|
49
|
+
|
|
50
|
+
// Scatter f32 rows of src0 into dst (f16) at row positions read from src1.
//
// Source row i01 = get_group_id(0)*get_local_size(1) + get_local_id(1), in
// plane get_group_id(1) and batch get_group_id(2). The destination row index
// is a 64-bit integer read from src1 at (i01, plane % ne11, batch % ne12).
// Work-items along local dimension 0 copy the nblk0 elements of the row,
// converting each float to half on the store.
//
//   src0, src1, dst   byte pointers; offset0/offset1/offsetd (bytes) are added first
//   ne01              number of source rows; work-items past it do nothing
//   nb01, nb02, nb03  byte strides of src0 (row, plane, batch)
//   ne11, ne12        extents used to wrap the plane / batch index into src1
//   nb10, nb11, nb12  byte strides of src1
//   nblk0             number of elements copied per row
//   nb1, nb2, nb3     byte strides of dst (row, plane, batch)
kernel void kernel_set_rows_f16(
    global char * src0,
    ulong offset0,
    global char * src1,
    ulong offset1,
    global char * dst,
    ulong offsetd,
    int ne01,
    ulong nb01,
    ulong nb02,
    ulong nb03,
    int ne11,
    int ne12,
    ulong nb10,
    ulong nb11,
    ulong nb12,
    int nblk0,
    ulong nb1,
    ulong nb2,
    ulong nb3
) {
    const int batch   = get_group_id(2);
    const int plane   = get_group_id(1);
    const int src_idx = get_group_id(0)*get_local_size(1) + get_local_id(1);

    if (src_idx < ne01) {
        // src1 is indexed with the plane/batch wrapped to its own extents.
        const int idx_plane = plane%ne11;
        const int idx_batch = batch%ne12;

        const long dst_idx = *(global long *)(src1 + offset1 + src_idx*nb10 + idx_plane*nb11 + idx_batch*nb12);

        global half  * out = (global half  *) (dst  + offsetd + dst_idx*nb1  + plane*nb2  + batch*nb3);
        global float * in  = (global float *) (src0 + offset0 + src_idx*nb01 + plane*nb02 + batch*nb03);

        int e = get_local_id(0);
        while (e < nblk0) {
            out[e] = in[e];
            e += get_local_size(0);
        }
    }
}
|
|
@@ -22,32 +22,45 @@
|
|
|
22
22
|
REQD_SUBGROUP_SIZE_64
|
|
23
23
|
#endif
|
|
24
24
|
kernel void kernel_soft_max_4_f16(
|
|
25
|
-
global
|
|
25
|
+
global char * src0,
|
|
26
26
|
ulong offset0,
|
|
27
|
-
global
|
|
27
|
+
global char * src1,
|
|
28
28
|
ulong offset1,
|
|
29
|
-
global
|
|
29
|
+
global char * dst,
|
|
30
30
|
ulong offsetd,
|
|
31
31
|
int ne00,
|
|
32
|
-
|
|
33
|
-
|
|
32
|
+
ulong nb01,
|
|
33
|
+
ulong nb02,
|
|
34
|
+
ulong nb03,
|
|
35
|
+
int ne12,
|
|
36
|
+
int ne13,
|
|
37
|
+
ulong nb11,
|
|
38
|
+
ulong nb12,
|
|
39
|
+
ulong nb13,
|
|
40
|
+
ulong nb1,
|
|
41
|
+
ulong nb2,
|
|
42
|
+
ulong nb3,
|
|
34
43
|
float scale,
|
|
35
44
|
float max_bias,
|
|
36
45
|
float m0,
|
|
37
46
|
float m1,
|
|
38
47
|
int n_head_log2
|
|
39
48
|
) {
|
|
40
|
-
src0 =
|
|
41
|
-
src1 =
|
|
42
|
-
dst
|
|
49
|
+
src0 = src0 + offset0;
|
|
50
|
+
src1 = src1 + offset1;
|
|
51
|
+
dst = dst + offsetd;
|
|
43
52
|
|
|
44
53
|
int i03 = get_group_id(2);
|
|
45
54
|
int i02 = get_group_id(1);
|
|
46
55
|
int i01 = get_group_id(0);
|
|
47
56
|
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
57
|
+
int i13 = i03%ne13;
|
|
58
|
+
int i12 = i02%ne12;
|
|
59
|
+
int i11 = i01;
|
|
60
|
+
|
|
61
|
+
global float4 * psrc4 = (global float4 *)(src0 + i01*nb01 + i02*nb02 + i03*nb03);
|
|
62
|
+
global half4 * pmask = src1 != src0 ? (global half4 *)(src1 + i11*nb11 + i12*nb12 + i13*nb13) : 0;
|
|
63
|
+
global float4 * pdst4 = (global float4 *)(dst + i01*nb1 + i02*nb2 + i03*nb3);
|
|
51
64
|
|
|
52
65
|
float slope = 1.0f;
|
|
53
66
|
|