local-llm-rn 1.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/cpp/CMakeLists.txt +285 -0
- package/cpp/common/CMakeLists.txt +149 -0
- package/cpp/common/arg.cpp +3799 -0
- package/cpp/common/arg.h +131 -0
- package/cpp/common/base64.hpp +392 -0
- package/cpp/common/build-info.cpp.in +4 -0
- package/cpp/common/chat-parser-xml-toolcall.cpp +879 -0
- package/cpp/common/chat-parser-xml-toolcall.h +45 -0
- package/cpp/common/chat-parser.cpp +1649 -0
- package/cpp/common/chat-parser.h +133 -0
- package/cpp/common/chat-peg-parser.cpp +124 -0
- package/cpp/common/chat-peg-parser.h +105 -0
- package/cpp/common/chat.cpp +3355 -0
- package/cpp/common/chat.h +252 -0
- package/cpp/common/common.cpp +1824 -0
- package/cpp/common/common.h +930 -0
- package/cpp/common/console.cpp +1137 -0
- package/cpp/common/console.h +41 -0
- package/cpp/common/debug.cpp +167 -0
- package/cpp/common/debug.h +43 -0
- package/cpp/common/download.cpp +792 -0
- package/cpp/common/download.h +84 -0
- package/cpp/common/http.h +84 -0
- package/cpp/common/jinja/README.md +88 -0
- package/cpp/common/jinja/caps.cpp +285 -0
- package/cpp/common/jinja/caps.h +30 -0
- package/cpp/common/jinja/lexer.cpp +341 -0
- package/cpp/common/jinja/lexer.h +157 -0
- package/cpp/common/jinja/parser.cpp +591 -0
- package/cpp/common/jinja/parser.h +21 -0
- package/cpp/common/jinja/runtime.cpp +867 -0
- package/cpp/common/jinja/runtime.h +638 -0
- package/cpp/common/jinja/string.cpp +213 -0
- package/cpp/common/jinja/string.h +61 -0
- package/cpp/common/jinja/utils.h +149 -0
- package/cpp/common/jinja/value.cpp +1393 -0
- package/cpp/common/jinja/value.h +756 -0
- package/cpp/common/json-partial.cpp +324 -0
- package/cpp/common/json-partial.h +39 -0
- package/cpp/common/json-schema-to-grammar.cpp +1153 -0
- package/cpp/common/json-schema-to-grammar.h +43 -0
- package/cpp/common/llguidance.cpp +258 -0
- package/cpp/common/log.cpp +446 -0
- package/cpp/common/log.h +119 -0
- package/cpp/common/ngram-cache.cpp +285 -0
- package/cpp/common/ngram-cache.h +101 -0
- package/cpp/common/ngram-map.cpp +530 -0
- package/cpp/common/ngram-map.h +115 -0
- package/cpp/common/ngram-mod.cpp +60 -0
- package/cpp/common/ngram-mod.h +38 -0
- package/cpp/common/peg-parser.cpp +1712 -0
- package/cpp/common/peg-parser.h +459 -0
- package/cpp/common/preset.cpp +483 -0
- package/cpp/common/preset.h +83 -0
- package/cpp/common/regex-partial.cpp +204 -0
- package/cpp/common/regex-partial.h +56 -0
- package/cpp/common/sampling.cpp +745 -0
- package/cpp/common/sampling.h +119 -0
- package/cpp/common/speculative.cpp +1074 -0
- package/cpp/common/speculative.h +41 -0
- package/cpp/common/unicode.cpp +64 -0
- package/cpp/common/unicode.h +22 -0
- package/cpp/ggml/CMakeLists.txt +494 -0
- package/cpp/ggml/cmake/GitVars.cmake +22 -0
- package/cpp/ggml/cmake/common.cmake +50 -0
- package/cpp/ggml/cmake/ggml-config.cmake.in +191 -0
- package/cpp/ggml/include/ggml-alloc.h +85 -0
- package/cpp/ggml/include/ggml-backend.h +373 -0
- package/cpp/ggml/include/ggml-blas.h +25 -0
- package/cpp/ggml/include/ggml-cann.h +123 -0
- package/cpp/ggml/include/ggml-cpp.h +39 -0
- package/cpp/ggml/include/ggml-cpu.h +151 -0
- package/cpp/ggml/include/ggml-cuda.h +47 -0
- package/cpp/ggml/include/ggml-hexagon.h +19 -0
- package/cpp/ggml/include/ggml-metal.h +61 -0
- package/cpp/ggml/include/ggml-opencl.h +26 -0
- package/cpp/ggml/include/ggml-opt.h +256 -0
- package/cpp/ggml/include/ggml-rpc.h +30 -0
- package/cpp/ggml/include/ggml-sycl.h +49 -0
- package/cpp/ggml/include/ggml-virtgpu.h +14 -0
- package/cpp/ggml/include/ggml-vulkan.h +29 -0
- package/cpp/ggml/include/ggml-webgpu.h +19 -0
- package/cpp/ggml/include/ggml-zdnn.h +17 -0
- package/cpp/ggml/include/ggml-zendnn.h +22 -0
- package/cpp/ggml/include/ggml.h +2753 -0
- package/cpp/ggml/include/gguf.h +204 -0
- package/cpp/ggml/src/CMakeLists.txt +492 -0
- package/cpp/ggml/src/ggml-alloc.c +1244 -0
- package/cpp/ggml/src/ggml-backend-dl.cpp +48 -0
- package/cpp/ggml/src/ggml-backend-dl.h +45 -0
- package/cpp/ggml/src/ggml-backend-impl.h +255 -0
- package/cpp/ggml/src/ggml-backend-reg.cpp +566 -0
- package/cpp/ggml/src/ggml-backend.cpp +2270 -0
- package/cpp/ggml/src/ggml-blas/CMakeLists.txt +101 -0
- package/cpp/ggml/src/ggml-blas/ggml-blas.cpp +518 -0
- package/cpp/ggml/src/ggml-common.h +1878 -0
- package/cpp/ggml/src/ggml-cpu/CMakeLists.txt +691 -0
- package/cpp/ggml/src/ggml-cpu/amx/amx.cpp +247 -0
- package/cpp/ggml/src/ggml-cpu/amx/amx.h +8 -0
- package/cpp/ggml/src/ggml-cpu/amx/common.h +91 -0
- package/cpp/ggml/src/ggml-cpu/amx/mmq.cpp +2512 -0
- package/cpp/ggml/src/ggml-cpu/amx/mmq.h +10 -0
- package/cpp/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +98 -0
- package/cpp/ggml/src/ggml-cpu/arch/arm/quants.c +4052 -0
- package/cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +4935 -0
- package/cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +2159 -0
- package/cpp/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
- package/cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +2305 -0
- package/cpp/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp +38 -0
- package/cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +2726 -0
- package/cpp/ggml/src/ggml-cpu/arch/riscv/repack.cpp +342 -0
- package/cpp/ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp +50 -0
- package/cpp/ggml/src/ggml-cpu/arch/s390/quants.c +1468 -0
- package/cpp/ggml/src/ggml-cpu/arch/wasm/quants.c +1221 -0
- package/cpp/ggml/src/ggml-cpu/arch/x86/cpu-feats.cpp +327 -0
- package/cpp/ggml/src/ggml-cpu/arch/x86/quants.c +3820 -0
- package/cpp/ggml/src/ggml-cpu/arch/x86/repack.cpp +6307 -0
- package/cpp/ggml/src/ggml-cpu/arch-fallback.h +313 -0
- package/cpp/ggml/src/ggml-cpu/binary-ops.cpp +154 -0
- package/cpp/ggml/src/ggml-cpu/binary-ops.h +16 -0
- package/cpp/ggml/src/ggml-cpu/cmake/FindSIMD.cmake +100 -0
- package/cpp/ggml/src/ggml-cpu/common.h +95 -0
- package/cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +529 -0
- package/cpp/ggml/src/ggml-cpu/ggml-cpu.c +3734 -0
- package/cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +701 -0
- package/cpp/ggml/src/ggml-cpu/hbm.cpp +55 -0
- package/cpp/ggml/src/ggml-cpu/hbm.h +8 -0
- package/cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +938 -0
- package/cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +90 -0
- package/cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +798 -0
- package/cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.h +17 -0
- package/cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +4033 -0
- package/cpp/ggml/src/ggml-cpu/llamafile/sgemm.h +25 -0
- package/cpp/ggml/src/ggml-cpu/ops.cpp +10978 -0
- package/cpp/ggml/src/ggml-cpu/ops.h +116 -0
- package/cpp/ggml/src/ggml-cpu/quants.c +1193 -0
- package/cpp/ggml/src/ggml-cpu/quants.h +97 -0
- package/cpp/ggml/src/ggml-cpu/repack.cpp +3316 -0
- package/cpp/ggml/src/ggml-cpu/repack.h +173 -0
- package/cpp/ggml/src/ggml-cpu/simd-gemm.h +136 -0
- package/cpp/ggml/src/ggml-cpu/simd-mappings.h +1279 -0
- package/cpp/ggml/src/ggml-cpu/spacemit/ime.cpp +1025 -0
- package/cpp/ggml/src/ggml-cpu/spacemit/ime.h +13 -0
- package/cpp/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +3196 -0
- package/cpp/ggml/src/ggml-cpu/spacemit/ime_kernels.h +26 -0
- package/cpp/ggml/src/ggml-cpu/traits.cpp +36 -0
- package/cpp/ggml/src/ggml-cpu/traits.h +38 -0
- package/cpp/ggml/src/ggml-cpu/unary-ops.cpp +337 -0
- package/cpp/ggml/src/ggml-cpu/unary-ops.h +35 -0
- package/cpp/ggml/src/ggml-cpu/vec.cpp +629 -0
- package/cpp/ggml/src/ggml-cpu/vec.h +1585 -0
- package/cpp/ggml/src/ggml-hexagon/CMakeLists.txt +117 -0
- package/cpp/ggml/src/ggml-hexagon/ggml-hexagon.cpp +3232 -0
- package/cpp/ggml/src/ggml-hexagon/htp/CMakeLists.txt +45 -0
- package/cpp/ggml/src/ggml-hexagon/htp/act-ops.c +815 -0
- package/cpp/ggml/src/ggml-hexagon/htp/argsort-ops.c +281 -0
- package/cpp/ggml/src/ggml-hexagon/htp/binary-ops.c +827 -0
- package/cpp/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +157 -0
- package/cpp/ggml/src/ggml-hexagon/htp/cpy-ops.c +251 -0
- package/cpp/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +666 -0
- package/cpp/ggml/src/ggml-hexagon/htp/get-rows-ops.c +111 -0
- package/cpp/ggml/src/ggml-hexagon/htp/hex-dma.c +63 -0
- package/cpp/ggml/src/ggml-hexagon/htp/hex-dma.h +182 -0
- package/cpp/ggml/src/ggml-hexagon/htp/hex-dump.h +77 -0
- package/cpp/ggml/src/ggml-hexagon/htp/hex-fastdiv.h +37 -0
- package/cpp/ggml/src/ggml-hexagon/htp/hex-utils.h +51 -0
- package/cpp/ggml/src/ggml-hexagon/htp/htp-ctx.h +35 -0
- package/cpp/ggml/src/ggml-hexagon/htp/htp-msg.h +154 -0
- package/cpp/ggml/src/ggml-hexagon/htp/htp-ops.h +65 -0
- package/cpp/ggml/src/ggml-hexagon/htp/htp_iface.idl +16 -0
- package/cpp/ggml/src/ggml-hexagon/htp/hvx-arith.h +470 -0
- package/cpp/ggml/src/ggml-hexagon/htp/hvx-base.h +173 -0
- package/cpp/ggml/src/ggml-hexagon/htp/hvx-copy.h +245 -0
- package/cpp/ggml/src/ggml-hexagon/htp/hvx-div.h +116 -0
- package/cpp/ggml/src/ggml-hexagon/htp/hvx-dump.h +129 -0
- package/cpp/ggml/src/ggml-hexagon/htp/hvx-exp.h +215 -0
- package/cpp/ggml/src/ggml-hexagon/htp/hvx-floor.h +100 -0
- package/cpp/ggml/src/ggml-hexagon/htp/hvx-inverse.h +176 -0
- package/cpp/ggml/src/ggml-hexagon/htp/hvx-reduce.h +266 -0
- package/cpp/ggml/src/ggml-hexagon/htp/hvx-scale.h +133 -0
- package/cpp/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +141 -0
- package/cpp/ggml/src/ggml-hexagon/htp/hvx-sqrt.h +126 -0
- package/cpp/ggml/src/ggml-hexagon/htp/hvx-types.h +36 -0
- package/cpp/ggml/src/ggml-hexagon/htp/hvx-utils.h +18 -0
- package/cpp/ggml/src/ggml-hexagon/htp/main.c +1150 -0
- package/cpp/ggml/src/ggml-hexagon/htp/matmul-ops.c +2595 -0
- package/cpp/ggml/src/ggml-hexagon/htp/rope-ops.c +498 -0
- package/cpp/ggml/src/ggml-hexagon/htp/set-rows-ops.c +167 -0
- package/cpp/ggml/src/ggml-hexagon/htp/softmax-ops.c +421 -0
- package/cpp/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +130 -0
- package/cpp/ggml/src/ggml-hexagon/htp/unary-ops.c +384 -0
- package/cpp/ggml/src/ggml-hexagon/htp/worker-pool.c +293 -0
- package/cpp/ggml/src/ggml-hexagon/htp/worker-pool.h +57 -0
- package/cpp/ggml/src/ggml-hexagon/htp-drv.cpp +418 -0
- package/cpp/ggml/src/ggml-hexagon/htp-drv.h +121 -0
- package/cpp/ggml/src/ggml-hexagon/libdl.h +79 -0
- package/cpp/ggml/src/ggml-hexagon/libggml-htp.inf +38 -0
- package/cpp/ggml/src/ggml-hexagon/op-desc.h +153 -0
- package/cpp/ggml/src/ggml-impl.h +724 -0
- package/cpp/ggml/src/ggml-metal/CMakeLists.txt +124 -0
- package/cpp/ggml/src/ggml-metal/ggml-metal-common.cpp +457 -0
- package/cpp/ggml/src/ggml-metal/ggml-metal-common.h +52 -0
- package/cpp/ggml/src/ggml-metal/ggml-metal-context.h +41 -0
- package/cpp/ggml/src/ggml-metal/ggml-metal-context.m +702 -0
- package/cpp/ggml/src/ggml-metal/ggml-metal-device.cpp +1890 -0
- package/cpp/ggml/src/ggml-metal/ggml-metal-device.h +290 -0
- package/cpp/ggml/src/ggml-metal/ggml-metal-device.m +1749 -0
- package/cpp/ggml/src/ggml-metal/ggml-metal-impl.h +1054 -0
- package/cpp/ggml/src/ggml-metal/ggml-metal-ops.cpp +4370 -0
- package/cpp/ggml/src/ggml-metal/ggml-metal-ops.h +94 -0
- package/cpp/ggml/src/ggml-metal/ggml-metal.cpp +937 -0
- package/cpp/ggml/src/ggml-metal/ggml-metal.metal +9819 -0
- package/cpp/ggml/src/ggml-musa/CMakeLists.txt +125 -0
- package/cpp/ggml/src/ggml-musa/mudnn.cu +112 -0
- package/cpp/ggml/src/ggml-musa/mudnn.cuh +12 -0
- package/cpp/ggml/src/ggml-opencl/CMakeLists.txt +150 -0
- package/cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +11553 -0
- package/cpp/ggml/src/ggml-opencl/kernels/add.cl +190 -0
- package/cpp/ggml/src/ggml-opencl/kernels/add_id.cl +42 -0
- package/cpp/ggml/src/ggml-opencl/kernels/argsort.cl +86 -0
- package/cpp/ggml/src/ggml-opencl/kernels/clamp.cl +20 -0
- package/cpp/ggml/src/ggml-opencl/kernels/concat.cl +51 -0
- package/cpp/ggml/src/ggml-opencl/kernels/conv2d.cl +185 -0
- package/cpp/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl +176 -0
- package/cpp/ggml/src/ggml-opencl/kernels/cpy.cl +184 -0
- package/cpp/ggml/src/ggml-opencl/kernels/cvt.cl +417 -0
- package/cpp/ggml/src/ggml-opencl/kernels/diag_mask_inf.cl +58 -0
- package/cpp/ggml/src/ggml-opencl/kernels/div.cl +138 -0
- package/cpp/ggml/src/ggml-opencl/kernels/embed_kernel.py +26 -0
- package/cpp/ggml/src/ggml-opencl/kernels/expm1.cl +113 -0
- package/cpp/ggml/src/ggml-opencl/kernels/fill.cl +17 -0
- package/cpp/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl +370 -0
- package/cpp/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +371 -0
- package/cpp/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl +373 -0
- package/cpp/ggml/src/ggml-opencl/kernels/gelu.cl +89 -0
- package/cpp/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl +162 -0
- package/cpp/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl +156 -0
- package/cpp/ggml/src/ggml-opencl/kernels/gemv_noshuffle.cl +268 -0
- package/cpp/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general.cl +274 -0
- package/cpp/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general_q8_0_f32.cl +195 -0
- package/cpp/ggml/src/ggml-opencl/kernels/get_rows.cl +187 -0
- package/cpp/ggml/src/ggml-opencl/kernels/glu.cl +378 -0
- package/cpp/ggml/src/ggml-opencl/kernels/group_norm.cl +121 -0
- package/cpp/ggml/src/ggml-opencl/kernels/im2col_f16.cl +57 -0
- package/cpp/ggml/src/ggml-opencl/kernels/im2col_f32.cl +57 -0
- package/cpp/ggml/src/ggml-opencl/kernels/mean.cl +140 -0
- package/cpp/ggml/src/ggml-opencl/kernels/mul.cl +152 -0
- package/cpp/ggml/src/ggml-opencl/kernels/mul_mat_Ab_Bi_8x4.cl +139 -0
- package/cpp/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
- package/cpp/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl +273 -0
- package/cpp/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +146 -0
- package/cpp/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +147 -0
- package/cpp/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl +163 -0
- package/cpp/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl +165 -0
- package/cpp/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl +158 -0
- package/cpp/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_8x4.cl +129 -0
- package/cpp/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl +154 -0
- package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_f16_f16.cl +118 -0
- package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32.cl +118 -0
- package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32_1row.cl +94 -0
- package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32_l4.cl +84 -0
- package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_f32_f32.cl +118 -0
- package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32.cl +189 -0
- package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32_flat.cl +176 -0
- package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl +283 -0
- package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32.cl +140 -0
- package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32_flat.cl +222 -0
- package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32.cl +144 -0
- package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32_flat.cl +167 -0
- package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32.cl +192 -0
- package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_1d_16x_flat.cl +307 -0
- package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_1d_8x_flat.cl +265 -0
- package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_8x_flat.cl +272 -0
- package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_v.cl +254 -0
- package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl +219 -0
- package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl +229 -0
- package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl +180 -0
- package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32.cl +194 -0
- package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +194 -0
- package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32.cl +125 -0
- package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32_flat.cl +202 -0
- package/cpp/ggml/src/ggml-opencl/kernels/norm.cl +161 -0
- package/cpp/ggml/src/ggml-opencl/kernels/pad.cl +39 -0
- package/cpp/ggml/src/ggml-opencl/kernels/relu.cl +16 -0
- package/cpp/ggml/src/ggml-opencl/kernels/repeat.cl +38 -0
- package/cpp/ggml/src/ggml-opencl/kernels/rms_norm.cl +190 -0
- package/cpp/ggml/src/ggml-opencl/kernels/rope.cl +747 -0
- package/cpp/ggml/src/ggml-opencl/kernels/scale.cl +27 -0
- package/cpp/ggml/src/ggml-opencl/kernels/set_rows.cl +208 -0
- package/cpp/ggml/src/ggml-opencl/kernels/sigmoid.cl +29 -0
- package/cpp/ggml/src/ggml-opencl/kernels/silu.cl +30 -0
- package/cpp/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +108 -0
- package/cpp/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +108 -0
- package/cpp/ggml/src/ggml-opencl/kernels/softmax_f16.cl +107 -0
- package/cpp/ggml/src/ggml-opencl/kernels/softmax_f32.cl +107 -0
- package/cpp/ggml/src/ggml-opencl/kernels/softplus.cl +116 -0
- package/cpp/ggml/src/ggml-opencl/kernels/solve_tri.cl +51 -0
- package/cpp/ggml/src/ggml-opencl/kernels/sqr.cl +53 -0
- package/cpp/ggml/src/ggml-opencl/kernels/sqrt.cl +53 -0
- package/cpp/ggml/src/ggml-opencl/kernels/ssm_conv.cl +77 -0
- package/cpp/ggml/src/ggml-opencl/kernels/sub.cl +138 -0
- package/cpp/ggml/src/ggml-opencl/kernels/sum_rows.cl +140 -0
- package/cpp/ggml/src/ggml-opencl/kernels/tanh.cl +109 -0
- package/cpp/ggml/src/ggml-opencl/kernels/transpose.cl +117 -0
- package/cpp/ggml/src/ggml-opencl/kernels/tri.cl +32 -0
- package/cpp/ggml/src/ggml-opencl/kernels/tsembd.cl +48 -0
- package/cpp/ggml/src/ggml-opencl/kernels/upscale.cl +120 -0
- package/cpp/ggml/src/ggml-opt.cpp +1093 -0
- package/cpp/ggml/src/ggml-quants.c +5325 -0
- package/cpp/ggml/src/ggml-quants.h +106 -0
- package/cpp/ggml/src/ggml-rpc/CMakeLists.txt +9 -0
- package/cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +2118 -0
- package/cpp/ggml/src/ggml-threading.cpp +12 -0
- package/cpp/ggml/src/ggml-threading.h +14 -0
- package/cpp/ggml/src/ggml-virtgpu/CMakeLists.txt +70 -0
- package/cpp/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp +87 -0
- package/cpp/ggml/src/ggml-virtgpu/backend/CMakeLists.txt +21 -0
- package/cpp/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp +115 -0
- package/cpp/ggml/src/ggml-virtgpu/backend/backend-convert.h +13 -0
- package/cpp/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp +102 -0
- package/cpp/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp +105 -0
- package/cpp/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp +179 -0
- package/cpp/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp +148 -0
- package/cpp/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp +51 -0
- package/cpp/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h +73 -0
- package/cpp/ggml/src/ggml-virtgpu/backend/backend-dispatched.h +27 -0
- package/cpp/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h +32 -0
- package/cpp/ggml/src/ggml-virtgpu/backend/backend.cpp +144 -0
- package/cpp/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h +95 -0
- package/cpp/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h +94 -0
- package/cpp/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h +50 -0
- package/cpp/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h +378 -0
- package/cpp/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h +232 -0
- package/cpp/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h +58 -0
- package/cpp/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp +81 -0
- package/cpp/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +119 -0
- package/cpp/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +158 -0
- package/cpp/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp +213 -0
- package/cpp/ggml/src/ggml-virtgpu/ggml-backend.cpp +69 -0
- package/cpp/ggml/src/ggml-virtgpu/ggml-remoting.h +71 -0
- package/cpp/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml +166 -0
- package/cpp/ggml/src/ggml-virtgpu/include/apir_hw.h +9 -0
- package/cpp/ggml/src/ggml-virtgpu/regenerate_remoting.py +333 -0
- package/cpp/ggml/src/ggml-virtgpu/virtgpu-apir.h +15 -0
- package/cpp/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp +58 -0
- package/cpp/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp +110 -0
- package/cpp/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp +173 -0
- package/cpp/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp +192 -0
- package/cpp/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h +36 -0
- package/cpp/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h +53 -0
- package/cpp/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +98 -0
- package/cpp/ggml/src/ggml-virtgpu/virtgpu-shm.h +23 -0
- package/cpp/ggml/src/ggml-virtgpu/virtgpu-utils.cpp +179 -0
- package/cpp/ggml/src/ggml-virtgpu/virtgpu-utils.h +86 -0
- package/cpp/ggml/src/ggml-virtgpu/virtgpu.cpp +544 -0
- package/cpp/ggml/src/ggml-virtgpu/virtgpu.h +117 -0
- package/cpp/ggml/src/ggml-webgpu/CMakeLists.txt +80 -0
- package/cpp/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +1231 -0
- package/cpp/ggml/src/ggml-webgpu/ggml-webgpu.cpp +3150 -0
- package/cpp/ggml/src/ggml-webgpu/pre_wgsl.hpp +778 -0
- package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl +72 -0
- package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl +106 -0
- package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl +134 -0
- package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +107 -0
- package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +923 -0
- package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +107 -0
- package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl +66 -0
- package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +182 -0
- package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +636 -0
- package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.wgsl +668 -0
- package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +323 -0
- package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +40 -0
- package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl +713 -0
- package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +103 -0
- package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl +138 -0
- package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl +188 -0
- package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +194 -0
- package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl +86 -0
- package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +123 -0
- package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +295 -0
- package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/scale.wgsl +63 -0
- package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +109 -0
- package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +345 -0
- package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl +55 -0
- package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +193 -0
- package/cpp/ggml/src/ggml-zdnn/CMakeLists.txt +36 -0
- package/cpp/ggml/src/ggml-zdnn/common.hpp +59 -0
- package/cpp/ggml/src/ggml-zdnn/ggml-zdnn.cpp +633 -0
- package/cpp/ggml/src/ggml-zdnn/mmf.cpp +80 -0
- package/cpp/ggml/src/ggml-zdnn/mmf.hpp +12 -0
- package/cpp/ggml/src/ggml-zdnn/utils.cpp +79 -0
- package/cpp/ggml/src/ggml-zdnn/utils.hpp +19 -0
- package/cpp/ggml/src/ggml-zendnn/CMakeLists.txt +92 -0
- package/cpp/ggml/src/ggml-zendnn/ggml-zendnn.cpp +469 -0
- package/cpp/ggml/src/ggml.c +7669 -0
- package/cpp/ggml/src/ggml.cpp +26 -0
- package/cpp/ggml/src/gguf.cpp +1699 -0
- package/cpp/include/llama-cpp.h +32 -0
- package/cpp/include/llama.h +1568 -0
- package/cpp/mtmd/CMakeLists.txt +98 -0
- package/cpp/mtmd/README.md +63 -0
- package/cpp/mtmd/clip-graph.h +117 -0
- package/cpp/mtmd/clip-impl.h +586 -0
- package/cpp/mtmd/clip-model.h +390 -0
- package/cpp/mtmd/clip.cpp +4154 -0
- package/cpp/mtmd/clip.h +121 -0
- package/cpp/mtmd/deprecation-warning.cpp +22 -0
- package/cpp/mtmd/legacy-models/convert_image_encoder_to_gguf.py +412 -0
- package/cpp/mtmd/legacy-models/glmedge-convert-image-encoder-to-gguf.py +280 -0
- package/cpp/mtmd/legacy-models/glmedge-surgery.py +33 -0
- package/cpp/mtmd/legacy-models/llava_surgery.py +38 -0
- package/cpp/mtmd/legacy-models/llava_surgery_v2.py +180 -0
- package/cpp/mtmd/legacy-models/minicpmv-convert-image-encoder-to-gguf.py +892 -0
- package/cpp/mtmd/legacy-models/minicpmv-surgery.py +47 -0
- package/cpp/mtmd/models/cogvlm.cpp +98 -0
- package/cpp/mtmd/models/conformer.cpp +216 -0
- package/cpp/mtmd/models/glm4v.cpp +122 -0
- package/cpp/mtmd/models/internvl.cpp +69 -0
- package/cpp/mtmd/models/kimik25.cpp +101 -0
- package/cpp/mtmd/models/kimivl.cpp +63 -0
- package/cpp/mtmd/models/llama4.cpp +96 -0
- package/cpp/mtmd/models/llava.cpp +374 -0
- package/cpp/mtmd/models/minicpmv.cpp +114 -0
- package/cpp/mtmd/models/mobilenetv5.cpp +451 -0
- package/cpp/mtmd/models/models.h +128 -0
- package/cpp/mtmd/models/nemotron-v2-vl.cpp +35 -0
- package/cpp/mtmd/models/paddleocr.cpp +52 -0
- package/cpp/mtmd/models/pixtral.cpp +86 -0
- package/cpp/mtmd/models/qwen2vl.cpp +183 -0
- package/cpp/mtmd/models/qwen3vl.cpp +193 -0
- package/cpp/mtmd/models/siglip.cpp +86 -0
- package/cpp/mtmd/models/whisper-enc.cpp +115 -0
- package/cpp/mtmd/models/youtuvl.cpp +179 -0
- package/cpp/mtmd/mtmd-audio.cpp +730 -0
- package/cpp/mtmd/mtmd-audio.h +113 -0
- package/cpp/mtmd/mtmd-cli.cpp +437 -0
- package/cpp/mtmd/mtmd-helper.cpp +521 -0
- package/cpp/mtmd/mtmd-helper.h +96 -0
- package/cpp/mtmd/mtmd.cpp +1156 -0
- package/cpp/mtmd/mtmd.h +319 -0
- package/cpp/mtmd/requirements.txt +5 -0
- package/cpp/mtmd/test-1.jpeg +0 -0
- package/cpp/mtmd/test-2.mp3 +0 -0
- package/cpp/mtmd/tests.sh +192 -0
- package/cpp/src/CMakeLists.txt +169 -0
- package/cpp/src/llama-adapter.cpp +488 -0
- package/cpp/src/llama-adapter.h +89 -0
- package/cpp/src/llama-arch.cpp +2855 -0
- package/cpp/src/llama-arch.h +619 -0
- package/cpp/src/llama-batch.cpp +917 -0
- package/cpp/src/llama-batch.h +173 -0
- package/cpp/src/llama-chat.cpp +896 -0
- package/cpp/src/llama-chat.h +71 -0
- package/cpp/src/llama-context.cpp +3512 -0
- package/cpp/src/llama-context.h +359 -0
- package/cpp/src/llama-cparams.cpp +5 -0
- package/cpp/src/llama-cparams.h +44 -0
- package/cpp/src/llama-grammar.cpp +1464 -0
- package/cpp/src/llama-grammar.h +194 -0
- package/cpp/src/llama-graph.cpp +2685 -0
- package/cpp/src/llama-graph.h +1026 -0
- package/cpp/src/llama-hparams.cpp +234 -0
- package/cpp/src/llama-hparams.h +339 -0
- package/cpp/src/llama-impl.cpp +171 -0
- package/cpp/src/llama-impl.h +73 -0
- package/cpp/src/llama-io.cpp +15 -0
- package/cpp/src/llama-io.h +35 -0
- package/cpp/src/llama-kv-cache-iswa.cpp +330 -0
- package/cpp/src/llama-kv-cache-iswa.h +137 -0
- package/cpp/src/llama-kv-cache.cpp +2271 -0
- package/cpp/src/llama-kv-cache.h +388 -0
- package/cpp/src/llama-kv-cells.h +533 -0
- package/cpp/src/llama-memory-hybrid-iswa.cpp +275 -0
- package/cpp/src/llama-memory-hybrid-iswa.h +140 -0
- package/cpp/src/llama-memory-hybrid.cpp +268 -0
- package/cpp/src/llama-memory-hybrid.h +139 -0
- package/cpp/src/llama-memory-recurrent.cpp +1165 -0
- package/cpp/src/llama-memory-recurrent.h +182 -0
- package/cpp/src/llama-memory.cpp +59 -0
- package/cpp/src/llama-memory.h +122 -0
- package/cpp/src/llama-mmap.cpp +785 -0
- package/cpp/src/llama-mmap.h +92 -0
- package/cpp/src/llama-model-loader.cpp +1414 -0
- package/cpp/src/llama-model-loader.h +203 -0
- package/cpp/src/llama-model-saver.cpp +286 -0
- package/cpp/src/llama-model-saver.h +37 -0
- package/cpp/src/llama-model.cpp +9253 -0
- package/cpp/src/llama-model.h +576 -0
- package/cpp/src/llama-quant.cpp +1119 -0
- package/cpp/src/llama-quant.h +1 -0
- package/cpp/src/llama-sampler.cpp +3885 -0
- package/cpp/src/llama-sampler.h +42 -0
- package/cpp/src/llama-vocab.cpp +3970 -0
- package/cpp/src/llama-vocab.h +187 -0
- package/cpp/src/llama.cpp +1313 -0
- package/cpp/src/models/afmoe.cpp +191 -0
- package/cpp/src/models/apertus.cpp +125 -0
- package/cpp/src/models/arcee.cpp +135 -0
- package/cpp/src/models/arctic.cpp +138 -0
- package/cpp/src/models/arwkv7.cpp +86 -0
- package/cpp/src/models/baichuan.cpp +122 -0
- package/cpp/src/models/bailingmoe.cpp +144 -0
- package/cpp/src/models/bailingmoe2.cpp +135 -0
- package/cpp/src/models/bert.cpp +178 -0
- package/cpp/src/models/bitnet.cpp +160 -0
- package/cpp/src/models/bloom.cpp +101 -0
- package/cpp/src/models/chameleon.cpp +178 -0
- package/cpp/src/models/chatglm.cpp +132 -0
- package/cpp/src/models/codeshell.cpp +111 -0
- package/cpp/src/models/cogvlm.cpp +102 -0
- package/cpp/src/models/cohere2-iswa.cpp +134 -0
- package/cpp/src/models/command-r.cpp +122 -0
- package/cpp/src/models/dbrx.cpp +123 -0
- package/cpp/src/models/deci.cpp +135 -0
- package/cpp/src/models/deepseek.cpp +144 -0
- package/cpp/src/models/deepseek2.cpp +262 -0
- package/cpp/src/models/delta-net-base.cpp +376 -0
- package/cpp/src/models/dots1.cpp +134 -0
- package/cpp/src/models/dream.cpp +105 -0
- package/cpp/src/models/ernie4-5-moe.cpp +150 -0
- package/cpp/src/models/ernie4-5.cpp +110 -0
- package/cpp/src/models/eurobert.cpp +97 -0
- package/cpp/src/models/exaone-moe.cpp +146 -0
- package/cpp/src/models/exaone.cpp +114 -0
- package/cpp/src/models/exaone4.cpp +123 -0
- package/cpp/src/models/falcon-h1.cpp +111 -0
- package/cpp/src/models/falcon.cpp +120 -0
- package/cpp/src/models/gemma-embedding.cpp +116 -0
- package/cpp/src/models/gemma.cpp +112 -0
- package/cpp/src/models/gemma2-iswa.cpp +128 -0
- package/cpp/src/models/gemma3.cpp +155 -0
- package/cpp/src/models/gemma3n-iswa.cpp +384 -0
- package/cpp/src/models/glm4-moe.cpp +170 -0
- package/cpp/src/models/glm4.cpp +157 -0
- package/cpp/src/models/gpt2.cpp +105 -0
- package/cpp/src/models/gptneox.cpp +144 -0
- package/cpp/src/models/granite-hybrid.cpp +196 -0
- package/cpp/src/models/granite.cpp +211 -0
- package/cpp/src/models/grok.cpp +159 -0
- package/cpp/src/models/grovemoe.cpp +141 -0
- package/cpp/src/models/hunyuan-dense.cpp +132 -0
- package/cpp/src/models/hunyuan-moe.cpp +154 -0
- package/cpp/src/models/internlm2.cpp +120 -0
- package/cpp/src/models/jais.cpp +86 -0
- package/cpp/src/models/jais2.cpp +123 -0
- package/cpp/src/models/jamba.cpp +106 -0
- package/cpp/src/models/kimi-linear.cpp +392 -0
- package/cpp/src/models/lfm2.cpp +190 -0
- package/cpp/src/models/llada-moe.cpp +122 -0
- package/cpp/src/models/llada.cpp +99 -0
- package/cpp/src/models/llama-iswa.cpp +178 -0
- package/cpp/src/models/llama.cpp +168 -0
- package/cpp/src/models/maincoder.cpp +117 -0
- package/cpp/src/models/mamba-base.cpp +285 -0
- package/cpp/src/models/mamba.cpp +54 -0
- package/cpp/src/models/mimo2-iswa.cpp +123 -0
- package/cpp/src/models/minicpm3.cpp +200 -0
- package/cpp/src/models/minimax-m2.cpp +124 -0
- package/cpp/src/models/mistral3.cpp +160 -0
- package/cpp/src/models/models.h +684 -0
- package/cpp/src/models/modern-bert.cpp +109 -0
- package/cpp/src/models/mpt.cpp +126 -0
- package/cpp/src/models/nemotron-h.cpp +148 -0
- package/cpp/src/models/nemotron.cpp +122 -0
- package/cpp/src/models/neo-bert.cpp +104 -0
- package/cpp/src/models/olmo.cpp +121 -0
- package/cpp/src/models/olmo2.cpp +150 -0
- package/cpp/src/models/olmoe.cpp +124 -0
- package/cpp/src/models/openai-moe-iswa.cpp +127 -0
- package/cpp/src/models/openelm.cpp +124 -0
- package/cpp/src/models/orion.cpp +123 -0
- package/cpp/src/models/paddleocr.cpp +122 -0
- package/cpp/src/models/pangu-embedded.cpp +121 -0
- package/cpp/src/models/phi2.cpp +121 -0
- package/cpp/src/models/phi3.cpp +152 -0
- package/cpp/src/models/plamo.cpp +110 -0
- package/cpp/src/models/plamo2.cpp +318 -0
- package/cpp/src/models/plamo3.cpp +128 -0
- package/cpp/src/models/plm.cpp +169 -0
- package/cpp/src/models/qwen.cpp +108 -0
- package/cpp/src/models/qwen2.cpp +126 -0
- package/cpp/src/models/qwen2moe.cpp +151 -0
- package/cpp/src/models/qwen2vl.cpp +117 -0
- package/cpp/src/models/qwen3.cpp +117 -0
- package/cpp/src/models/qwen35.cpp +386 -0
- package/cpp/src/models/qwen35moe.cpp +420 -0
- package/cpp/src/models/qwen3moe.cpp +124 -0
- package/cpp/src/models/qwen3next.cpp +525 -0
- package/cpp/src/models/qwen3vl-moe.cpp +140 -0
- package/cpp/src/models/qwen3vl.cpp +132 -0
- package/cpp/src/models/refact.cpp +94 -0
- package/cpp/src/models/rnd1.cpp +126 -0
- package/cpp/src/models/rwkv6-base.cpp +164 -0
- package/cpp/src/models/rwkv6.cpp +94 -0
- package/cpp/src/models/rwkv6qwen2.cpp +86 -0
- package/cpp/src/models/rwkv7-base.cpp +137 -0
- package/cpp/src/models/rwkv7.cpp +90 -0
- package/cpp/src/models/seed-oss.cpp +124 -0
- package/cpp/src/models/smallthinker.cpp +126 -0
- package/cpp/src/models/smollm3.cpp +128 -0
- package/cpp/src/models/stablelm.cpp +146 -0
- package/cpp/src/models/starcoder.cpp +100 -0
- package/cpp/src/models/starcoder2.cpp +121 -0
- package/cpp/src/models/step35-iswa.cpp +168 -0
- package/cpp/src/models/t5-dec.cpp +166 -0
- package/cpp/src/models/t5-enc.cpp +96 -0
- package/cpp/src/models/wavtokenizer-dec.cpp +149 -0
- package/cpp/src/models/xverse.cpp +108 -0
- package/cpp/src/unicode-data.cpp +7034 -0
- package/cpp/src/unicode-data.h +20 -0
- package/cpp/src/unicode.cpp +1103 -0
- package/cpp/src/unicode.h +111 -0
- package/cpp/vendor/nlohmann/json.hpp +25526 -0
- package/cpp/vendor/nlohmann/json_fwd.hpp +187 -0
- package/cpp/vendor/stb/stb_image.h +7988 -0
- package/ios/LocalLLM-Bridging-Header.h +2 -0
- package/ios/LocalLLM.h +5 -0
- package/ios/LocalLLM.mm +1267 -0
- package/local-llm-rn.podspec +60 -0
- package/package.json +35 -0
- package/src/NativeLocalLLM.ts +73 -0
- package/src/device.ts +50 -0
- package/src/download-adapter.ts +17 -0
- package/src/index.ts +21 -0
- package/src/native-bridge.ts +142 -0
- package/src/rn-downloader.ts +37 -0
|
@@ -0,0 +1,113 @@
|
|
|
1
|
+
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
|
2
|
+
|
|
3
|
+
//------------------------------------------------------------------------------
|
|
4
|
+
// expm1
|
|
5
|
+
//------------------------------------------------------------------------------
|
|
6
|
+
|
|
7
|
+
kernel void kernel_expm1_f32(
|
|
8
|
+
global const float * src0,
|
|
9
|
+
ulong offset0,
|
|
10
|
+
global float * dst,
|
|
11
|
+
ulong offsetd
|
|
12
|
+
) {
|
|
13
|
+
src0 = (global float*)((global char*)src0 + offset0);
|
|
14
|
+
dst = (global float*)((global char*)dst + offsetd);
|
|
15
|
+
|
|
16
|
+
dst[get_global_id(0)] = exp(src0[get_global_id(0)]) - 1.0f;
|
|
17
|
+
}
|
|
18
|
+
|
|
19
|
+
kernel void kernel_expm1_f32_4(
|
|
20
|
+
global const float4 * src0,
|
|
21
|
+
ulong offset0,
|
|
22
|
+
global float4 * dst,
|
|
23
|
+
ulong offsetd
|
|
24
|
+
) {
|
|
25
|
+
src0 = (global float4*)((global char*)src0 + offset0);
|
|
26
|
+
dst = (global float4*)((global char*)dst + offsetd);
|
|
27
|
+
|
|
28
|
+
dst[get_global_id(0)] = exp(src0[get_global_id(0)]) - 1.0f;
|
|
29
|
+
}
|
|
30
|
+
|
|
31
|
+
kernel void kernel_expm1_f16(
|
|
32
|
+
global const half * src0,
|
|
33
|
+
ulong offset0,
|
|
34
|
+
global half * dst,
|
|
35
|
+
ulong offsetd
|
|
36
|
+
) {
|
|
37
|
+
src0 = (global half*)((global char*)src0 + offset0);
|
|
38
|
+
dst = (global half*)((global char*)dst + offsetd);
|
|
39
|
+
|
|
40
|
+
dst[get_global_id(0)] = exp(src0[get_global_id(0)]) - 1.0h;
|
|
41
|
+
}
|
|
42
|
+
|
|
43
|
+
kernel void kernel_expm1_f16_4(
|
|
44
|
+
global const half4 * src0,
|
|
45
|
+
ulong offset0,
|
|
46
|
+
global half4 * dst,
|
|
47
|
+
ulong offsetd
|
|
48
|
+
) {
|
|
49
|
+
src0 = (global half4*)((global char*)src0 + offset0);
|
|
50
|
+
dst = (global half4*)((global char*)dst + offsetd);
|
|
51
|
+
|
|
52
|
+
dst[get_global_id(0)] = exp(src0[get_global_id(0)]) - 1.0h;
|
|
53
|
+
}
|
|
54
|
+
|
|
55
|
+
kernel void kernel_expm1_f32_nc(
|
|
56
|
+
global const char * src0,
|
|
57
|
+
ulong offset0,
|
|
58
|
+
global char * dst,
|
|
59
|
+
ulong offsetd,
|
|
60
|
+
int ne00,
|
|
61
|
+
ulong nb00,
|
|
62
|
+
ulong nb01,
|
|
63
|
+
ulong nb02,
|
|
64
|
+
ulong nb03,
|
|
65
|
+
ulong nb0,
|
|
66
|
+
ulong nb1,
|
|
67
|
+
ulong nb2,
|
|
68
|
+
ulong nb3
|
|
69
|
+
) {
|
|
70
|
+
src0 = src0 + offset0;
|
|
71
|
+
dst = dst + offsetd;
|
|
72
|
+
|
|
73
|
+
const int i3 = get_group_id(2);
|
|
74
|
+
const int i2 = get_group_id(1);
|
|
75
|
+
const int i1 = get_group_id(0);
|
|
76
|
+
|
|
77
|
+
for (int i0 = get_local_id(0); i0 < ne00; i0 += get_local_size(0)) {
|
|
78
|
+
global const float * x = (global const float *)(src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
|
79
|
+
global float * y = (global float *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
|
80
|
+
|
|
81
|
+
*y = exp(*x) - 1.0f;
|
|
82
|
+
}
|
|
83
|
+
}
|
|
84
|
+
|
|
85
|
+
kernel void kernel_expm1_f16_nc(
|
|
86
|
+
global const char * src0,
|
|
87
|
+
ulong offset0,
|
|
88
|
+
global char * dst,
|
|
89
|
+
ulong offsetd,
|
|
90
|
+
int ne00,
|
|
91
|
+
ulong nb00,
|
|
92
|
+
ulong nb01,
|
|
93
|
+
ulong nb02,
|
|
94
|
+
ulong nb03,
|
|
95
|
+
ulong nb0,
|
|
96
|
+
ulong nb1,
|
|
97
|
+
ulong nb2,
|
|
98
|
+
ulong nb3
|
|
99
|
+
) {
|
|
100
|
+
src0 = src0 + offset0;
|
|
101
|
+
dst = dst + offsetd;
|
|
102
|
+
|
|
103
|
+
const int i3 = get_group_id(2);
|
|
104
|
+
const int i2 = get_group_id(1);
|
|
105
|
+
const int i1 = get_group_id(0);
|
|
106
|
+
|
|
107
|
+
for (int i0 = get_local_id(0); i0 < ne00; i0 += get_local_size(0)) {
|
|
108
|
+
global const half * x = (global const half *)(src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
|
109
|
+
global half * y = (global half *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
|
110
|
+
|
|
111
|
+
*y = exp(*x) - 1.0f;
|
|
112
|
+
}
|
|
113
|
+
}
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
|
2
|
+
|
|
3
|
+
//------------------------------------------------------------------------------
|
|
4
|
+
// fill
|
|
5
|
+
//------------------------------------------------------------------------------
|
|
6
|
+
__kernel void kernel_fill_f32(
|
|
7
|
+
__global float *dst,
|
|
8
|
+
ulong offsetd,
|
|
9
|
+
float v,
|
|
10
|
+
int n
|
|
11
|
+
|
|
12
|
+
) {
|
|
13
|
+
dst = (global float*)((global char*)dst + offsetd);
|
|
14
|
+
if(get_global_id(0) < n){
|
|
15
|
+
dst[get_global_id(0)] = v;
|
|
16
|
+
}
|
|
17
|
+
}
|
|
@@ -0,0 +1,370 @@
|
|
|
1
|
+
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
|
2
|
+
|
|
3
|
+
#define ACC_TYPE float // scalar type of the score / running max / running sum accumulators in the kernels below
|
|
4
|
+
#define ACC_TYPE4 float4 // 4-wide accumulator type (q_priv, o_acc, dot products)
|
|
5
|
+
#define DATA_TYPE half // scalar type the mask is read as
|
|
6
|
+
#define DATA_TYPE4 half4 // 4-wide type the Q/K/V rows are read as and the output rows are written as
|
|
7
|
+
#define CONVERT_ACC4(x) convert_float4(x) // DATA_TYPE4 -> ACC_TYPE4
|
|
8
|
+
#define CONVERT_DATA4(x) convert_half4(x) // ACC_TYPE4 -> DATA_TYPE4
|
|
9
|
+
|
|
10
|
+
#define DK_VEC (DK/4) // 4-wide vectors per Q/K row; DK is not defined in this file (presumably a build option -- confirm)
|
|
11
|
+
#define DV_VEC (DV/4) // 4-wide vectors per V/output row; DV is not defined in this file (presumably a build option -- confirm)
|
|
12
|
+
#define WG_SIZE (BLOCK_M) // stride of the cooperative tile loads in flash_attn_f16; BLOCK_M is not defined in this file
|
|
13
|
+
#define Q1_WG_SIZE 64 // thread stride and reduction-array length used by flash_attn_f16_q1
|
|
14
|
+
|
|
15
|
+
// Returns the multiplier applied to mask values for head h.
//
//   max_bias    : when <= 0 the slope is always 1
//   h           : head index
//   n_head_log2 : heads below this index use base m0, the others use m1
//   m0, m1      : the two bases
//
// Heads below n_head_log2 get m0^(h + 1); the remaining heads get
// m1^(2*(h - n_head_log2) + 1).
inline float get_alibi_slope(
    const float max_bias, const uint h, const uint n_head_log2, const float m0, const float m1
) {
    if (max_bias <= 0.0f) {
        return 1.0f;
    }

    float base;
    int   exponent;
    if (h < n_head_log2) {
        base     = m0;
        exponent = h + 1;
    } else {
        base     = m1;
        exponent = 2*(h - n_head_log2) + 1;
    }

    return pow(base, exponent);
}
|
|
26
|
+
__kernel void flash_attn_f16(
|
|
27
|
+
const global void * q_void, ulong q_offset,
|
|
28
|
+
const global void * k_void, ulong k_offset,
|
|
29
|
+
const global void * v_void, ulong v_offset,
|
|
30
|
+
global void * o_void, ulong o_offset,
|
|
31
|
+
const float scale,
|
|
32
|
+
const int n_q,
|
|
33
|
+
const int n_kv,
|
|
34
|
+
const int is_causal,
|
|
35
|
+
const int n_head,
|
|
36
|
+
const ulong q_nb1, const ulong q_nb2, const ulong q_nb3,
|
|
37
|
+
const ulong k_nb1, const ulong k_nb2, const ulong k_nb3,
|
|
38
|
+
const ulong v_nb1, const ulong v_nb2, const ulong v_nb3,
|
|
39
|
+
const ulong o_nb1, const ulong o_nb2, const ulong o_nb3,
|
|
40
|
+
const float max_bias,
|
|
41
|
+
const float m0,
|
|
42
|
+
const float m1,
|
|
43
|
+
const int n_head_log2,
|
|
44
|
+
const float logit_softcap,
|
|
45
|
+
const int n_head_kv,
|
|
46
|
+
const global void* mask_void,
|
|
47
|
+
const ulong mask_offset,
|
|
48
|
+
const ulong mask_nb1,
|
|
49
|
+
const ulong mask_nb2,
|
|
50
|
+
const ulong mask_nb3,
|
|
51
|
+
const int mask_ne2,
|
|
52
|
+
const int mask_ne3,
|
|
53
|
+
const global void* sinks_void,
|
|
54
|
+
const ulong sinks_offset
|
|
55
|
+
) {
|
|
56
|
+
const int tid = get_local_id(0);
|
|
57
|
+
const int block_q_idx = get_group_id(0);
|
|
58
|
+
const int head_batch_idx = get_global_id(1);
|
|
59
|
+
|
|
60
|
+
const int my_query_row = block_q_idx * BLOCK_M + tid;
|
|
61
|
+
|
|
62
|
+
const int batch_idx = head_batch_idx / n_head;
|
|
63
|
+
const int head_idx = head_batch_idx % n_head;
|
|
64
|
+
|
|
65
|
+
const int gqa_ratio = n_head / n_head_kv;
|
|
66
|
+
const int head_kv_idx = head_idx / gqa_ratio;
|
|
67
|
+
|
|
68
|
+
const global char* q_base = (const global char*)q_void + q_offset;
|
|
69
|
+
const global char* k_base = (const global char*)k_void + k_offset;
|
|
70
|
+
const global char* v_base = (const global char*)v_void + v_offset;
|
|
71
|
+
global char* o_base = (global char*)o_void + o_offset;
|
|
72
|
+
|
|
73
|
+
const global char* mask_base = NULL;
|
|
74
|
+
if (mask_void != NULL) {
|
|
75
|
+
const int mask_head_idx = head_idx % mask_ne2;
|
|
76
|
+
const int mask_batch_idx = batch_idx % mask_ne3;
|
|
77
|
+
mask_base = (const global char*)mask_void + mask_offset + mask_batch_idx * mask_nb3 + mask_head_idx * mask_nb2;
|
|
78
|
+
}
|
|
79
|
+
|
|
80
|
+
ACC_TYPE4 q_priv[DK_VEC];
|
|
81
|
+
if (my_query_row < n_q) {
|
|
82
|
+
const ulong q_row_offset = batch_idx * q_nb3 + head_idx * q_nb2 + my_query_row * q_nb1;
|
|
83
|
+
const global DATA_TYPE4* q_ptr = (const global DATA_TYPE4*)(q_base + q_row_offset);
|
|
84
|
+
#pragma unroll
|
|
85
|
+
for (int i = 0; i < DK_VEC; ++i) {
|
|
86
|
+
q_priv[i] = CONVERT_ACC4(q_ptr[i]);
|
|
87
|
+
}
|
|
88
|
+
}
|
|
89
|
+
|
|
90
|
+
ACC_TYPE4 o_acc[DV_VEC];
|
|
91
|
+
#pragma unroll
|
|
92
|
+
for (int i = 0; i < DV_VEC; ++i) {
|
|
93
|
+
o_acc[i] = (ACC_TYPE4)(0.0f);
|
|
94
|
+
}
|
|
95
|
+
ACC_TYPE m_i = -INFINITY;
|
|
96
|
+
ACC_TYPE l_i = 0.0f;
|
|
97
|
+
|
|
98
|
+
float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1);
|
|
99
|
+
|
|
100
|
+
__local DATA_TYPE4 l_k[BLOCK_N][DK_VEC];
|
|
101
|
+
__local DATA_TYPE4 l_v[BLOCK_N][DV_VEC];
|
|
102
|
+
|
|
103
|
+
for (int k_start = 0; k_start < n_kv; k_start += BLOCK_N) {
|
|
104
|
+
for (int i = tid; i < BLOCK_N * DK_VEC; i += WG_SIZE) {
|
|
105
|
+
const int row = i / DK_VEC;
|
|
106
|
+
const int col = i % DK_VEC;
|
|
107
|
+
const int k_row_idx = k_start + row;
|
|
108
|
+
if (k_row_idx < n_kv) {
|
|
109
|
+
const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_row_idx * k_nb1;
|
|
110
|
+
l_k[row][col] = ((__global DATA_TYPE4*)(k_base + k_row_offset))[col];
|
|
111
|
+
}
|
|
112
|
+
}
|
|
113
|
+
for (int i = tid; i < BLOCK_N * DV_VEC; i += WG_SIZE) {
|
|
114
|
+
const int row = i / DV_VEC;
|
|
115
|
+
const int col = i % DV_VEC;
|
|
116
|
+
const int v_row_idx = k_start + row;
|
|
117
|
+
if (v_row_idx < n_kv) {
|
|
118
|
+
const ulong v_row_offset = batch_idx * v_nb3 + head_kv_idx * v_nb2 + v_row_idx * v_nb1;
|
|
119
|
+
l_v[row][col] = ((__global DATA_TYPE4*)(v_base + v_row_offset))[col];
|
|
120
|
+
}
|
|
121
|
+
}
|
|
122
|
+
barrier(CLK_LOCAL_MEM_FENCE);
|
|
123
|
+
|
|
124
|
+
if (my_query_row >= n_q) {
|
|
125
|
+
continue;
|
|
126
|
+
}
|
|
127
|
+
|
|
128
|
+
for (int j = 0; j < BLOCK_N; j += 2) {
|
|
129
|
+
const int k_row0 = k_start + j;
|
|
130
|
+
const int k_row1 = k_start + j + 1;
|
|
131
|
+
|
|
132
|
+
ACC_TYPE4 dot_acc0 = (ACC_TYPE4)(0.0f);
|
|
133
|
+
ACC_TYPE4 dot_acc1 = (ACC_TYPE4)(0.0f);
|
|
134
|
+
#pragma unroll
|
|
135
|
+
for (int k = 0; k < DK_VEC; k++) {
|
|
136
|
+
dot_acc0 = mad(q_priv[k], CONVERT_ACC4(l_k[j][k]), dot_acc0);
|
|
137
|
+
dot_acc1 = mad(q_priv[k], CONVERT_ACC4(l_k[j+1][k]), dot_acc1);
|
|
138
|
+
}
|
|
139
|
+
ACC_TYPE score0 = (dot_acc0.s0 + dot_acc0.s1 + dot_acc0.s2 + dot_acc0.s3) * scale;
|
|
140
|
+
ACC_TYPE score1 = (dot_acc1.s0 + dot_acc1.s1 + dot_acc1.s2 + dot_acc1.s3) * scale;
|
|
141
|
+
|
|
142
|
+
if (is_causal) {
|
|
143
|
+
if (k_row0 > (n_kv - n_q + my_query_row)) score0 = -INFINITY;
|
|
144
|
+
if (k_row1 > (n_kv - n_q + my_query_row)) score1 = -INFINITY;
|
|
145
|
+
}
|
|
146
|
+
|
|
147
|
+
if (k_row0 >= n_kv) score0 = -INFINITY;
|
|
148
|
+
if (k_row1 >= n_kv) score1 = -INFINITY;
|
|
149
|
+
|
|
150
|
+
if (mask_base != NULL) {
|
|
151
|
+
const global DATA_TYPE* mask_ptr = (const global DATA_TYPE*)(mask_base + my_query_row * mask_nb1);
|
|
152
|
+
if (k_row0 < n_kv) score0 += slope * (ACC_TYPE)mask_ptr[k_row0];
|
|
153
|
+
if (k_row1 < n_kv) score1 += slope * (ACC_TYPE)mask_ptr[k_row1];
|
|
154
|
+
}
|
|
155
|
+
|
|
156
|
+
if (logit_softcap > 0.0f) {
|
|
157
|
+
score0 = logit_softcap * tanh(score0 / logit_softcap);
|
|
158
|
+
score1 = logit_softcap * tanh(score1 / logit_softcap);
|
|
159
|
+
}
|
|
160
|
+
|
|
161
|
+
const ACC_TYPE m_new = max(m_i, max(score0, score1));
|
|
162
|
+
const ACC_TYPE p0 = exp(score0 - m_new);
|
|
163
|
+
const ACC_TYPE p1 = exp(score1 - m_new);
|
|
164
|
+
const ACC_TYPE scale_prev = exp(m_i - m_new);
|
|
165
|
+
|
|
166
|
+
#pragma unroll
|
|
167
|
+
for (int i = 0; i < DV_VEC; ++i) {
|
|
168
|
+
o_acc[i] = o_acc[i] * scale_prev + p0 * CONVERT_ACC4(l_v[j][i]) + p1 * CONVERT_ACC4(l_v[j+1][i]);
|
|
169
|
+
}
|
|
170
|
+
l_i = l_i * scale_prev + p0 + p1;
|
|
171
|
+
m_i = m_new;
|
|
172
|
+
}
|
|
173
|
+
}
|
|
174
|
+
|
|
175
|
+
if (my_query_row < n_q) {
|
|
176
|
+
if (sinks_void != NULL) {
|
|
177
|
+
const global ACC_TYPE* sinks_ptr = (const global ACC_TYPE*)((const global char*)sinks_void + sinks_offset);
|
|
178
|
+
const ACC_TYPE m_sink = sinks_ptr[head_idx];
|
|
179
|
+
const ACC_TYPE m_final = max(m_i, m_sink);
|
|
180
|
+
|
|
181
|
+
const ACC_TYPE scale_o = exp(m_i - m_final);
|
|
182
|
+
#pragma unroll
|
|
183
|
+
for (int i = 0; i < DV_VEC; ++i) {
|
|
184
|
+
o_acc[i] *= scale_o;
|
|
185
|
+
}
|
|
186
|
+
|
|
187
|
+
l_i = l_i * exp(m_i - m_final) + exp(m_sink - m_final);
|
|
188
|
+
}
|
|
189
|
+
|
|
190
|
+
const ulong o_row_offset = batch_idx * o_nb3 + my_query_row * o_nb2 + head_idx * o_nb1;
|
|
191
|
+
global DATA_TYPE4 *o_row = (global DATA_TYPE4 *)(o_base + o_row_offset);
|
|
192
|
+
if (l_i > 0.0f) {
|
|
193
|
+
const ACC_TYPE l_inv = 1.0f / l_i;
|
|
194
|
+
#pragma unroll
|
|
195
|
+
for (int i = 0; i < DV_VEC; ++i) {
|
|
196
|
+
o_row[i] = CONVERT_DATA4(o_acc[i] * l_inv);
|
|
197
|
+
}
|
|
198
|
+
} else {
|
|
199
|
+
#pragma unroll
|
|
200
|
+
for (int i = 0; i < DV_VEC; ++i) {
|
|
201
|
+
o_row[i] = (DATA_TYPE4)(0.0f);
|
|
202
|
+
}
|
|
203
|
+
}
|
|
204
|
+
}
|
|
205
|
+
}
|
|
206
|
+
|
|
207
|
+
__kernel void flash_attn_f16_q1(
|
|
208
|
+
const global void * q_void, ulong q_offset,
|
|
209
|
+
const global void * k_void, ulong k_offset,
|
|
210
|
+
const global void * v_void, ulong v_offset,
|
|
211
|
+
global void * o_void, ulong o_offset,
|
|
212
|
+
const float scale,
|
|
213
|
+
const int n_q,
|
|
214
|
+
const int n_kv,
|
|
215
|
+
const int is_causal,
|
|
216
|
+
const int n_head,
|
|
217
|
+
const ulong q_nb1, const ulong q_nb2, const ulong q_nb3,
|
|
218
|
+
const ulong k_nb1, const ulong k_nb2, const ulong k_nb3,
|
|
219
|
+
const ulong v_nb1, const ulong v_nb2, const ulong v_nb3,
|
|
220
|
+
const ulong o_nb1, const ulong o_nb2, const ulong o_nb3,
|
|
221
|
+
const float max_bias,
|
|
222
|
+
const float m0,
|
|
223
|
+
const float m1,
|
|
224
|
+
const int n_head_log2,
|
|
225
|
+
const float logit_softcap,
|
|
226
|
+
const int n_head_kv,
|
|
227
|
+
const global void* mask_void,
|
|
228
|
+
const ulong mask_offset,
|
|
229
|
+
const ulong mask_nb1,
|
|
230
|
+
const ulong mask_nb2,
|
|
231
|
+
const ulong mask_nb3,
|
|
232
|
+
const int mask_ne2,
|
|
233
|
+
const int mask_ne3,
|
|
234
|
+
const global void* sinks_void,
|
|
235
|
+
const ulong sinks_offset
|
|
236
|
+
) {
|
|
237
|
+
const int tid = get_local_id(0);
|
|
238
|
+
const int head_batch_idx = get_global_id(1);
|
|
239
|
+
|
|
240
|
+
const int batch_idx = head_batch_idx / n_head;
|
|
241
|
+
const int head_idx = head_batch_idx % n_head;
|
|
242
|
+
|
|
243
|
+
const int gqa_ratio = n_head / n_head_kv;
|
|
244
|
+
const int head_kv_idx = head_idx / gqa_ratio;
|
|
245
|
+
|
|
246
|
+
const global char* q_base = (const global char*)q_void + q_offset;
|
|
247
|
+
const global char* k_base = (const global char*)k_void + k_offset;
|
|
248
|
+
const global char* v_base = (const global char*)v_void + v_offset;
|
|
249
|
+
global char* o_base = (global char*)o_void + o_offset;
|
|
250
|
+
|
|
251
|
+
const global char* mask_base = NULL;
|
|
252
|
+
if (mask_void != NULL) {
|
|
253
|
+
const int mask_head_idx = head_idx % mask_ne2;
|
|
254
|
+
const int mask_batch_idx = batch_idx % mask_ne3;
|
|
255
|
+
mask_base = (const global char*)mask_void + mask_offset + mask_batch_idx * mask_nb3 + mask_head_idx * mask_nb2;
|
|
256
|
+
}
|
|
257
|
+
|
|
258
|
+
ACC_TYPE4 q_priv[DK_VEC];
|
|
259
|
+
const ulong q_row_offset = batch_idx * q_nb3 + head_idx * q_nb2;
|
|
260
|
+
const global DATA_TYPE4* q_ptr = (const global DATA_TYPE4*)(q_base + q_row_offset);
|
|
261
|
+
#pragma unroll
|
|
262
|
+
for (int i = 0; i < DK_VEC; ++i) {
|
|
263
|
+
q_priv[i] = CONVERT_ACC4(q_ptr[i]);
|
|
264
|
+
}
|
|
265
|
+
|
|
266
|
+
float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1);
|
|
267
|
+
|
|
268
|
+
const global ACC_TYPE* sinks_ptr = NULL;
|
|
269
|
+
if (sinks_void != NULL) {
|
|
270
|
+
sinks_ptr = (const global ACC_TYPE*)((const global char*)sinks_void + sinks_offset);
|
|
271
|
+
}
|
|
272
|
+
|
|
273
|
+
ACC_TYPE m_i = (sinks_ptr != NULL) ? sinks_ptr[head_idx] : -INFINITY;
|
|
274
|
+
for (int k_idx = tid; k_idx < n_kv; k_idx += Q1_WG_SIZE) {
|
|
275
|
+
const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_idx * k_nb1;
|
|
276
|
+
const global DATA_TYPE4* k_ptr = (const global DATA_TYPE4*)(k_base + k_row_offset);
|
|
277
|
+
ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f);
|
|
278
|
+
#pragma unroll
|
|
279
|
+
for (int k = 0; k < DK_VEC; k++) {
|
|
280
|
+
dot_acc = mad(q_priv[k], CONVERT_ACC4(k_ptr[k]), dot_acc);
|
|
281
|
+
}
|
|
282
|
+
ACC_TYPE score = (dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3) * scale;
|
|
283
|
+
if (mask_base != NULL) {
|
|
284
|
+
const global DATA_TYPE* mask_ptr = (const global DATA_TYPE*)(mask_base);
|
|
285
|
+
score += slope * (ACC_TYPE)mask_ptr[k_idx];
|
|
286
|
+
}
|
|
287
|
+
if (logit_softcap > 0.0f) {
|
|
288
|
+
score = logit_softcap * tanh(score / logit_softcap);
|
|
289
|
+
}
|
|
290
|
+
m_i = max(m_i, score);
|
|
291
|
+
}
|
|
292
|
+
|
|
293
|
+
__local ACC_TYPE local_m[Q1_WG_SIZE];
|
|
294
|
+
local_m[tid] = m_i;
|
|
295
|
+
barrier(CLK_LOCAL_MEM_FENCE);
|
|
296
|
+
#pragma unroll
|
|
297
|
+
for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
|
|
298
|
+
if (tid < s) local_m[tid] = max(local_m[tid], local_m[tid + s]);
|
|
299
|
+
barrier(CLK_LOCAL_MEM_FENCE);
|
|
300
|
+
}
|
|
301
|
+
const ACC_TYPE m_final = local_m[0];
|
|
302
|
+
|
|
303
|
+
ACC_TYPE4 o_acc[DV_VEC];
|
|
304
|
+
#pragma unroll
|
|
305
|
+
for (int i = 0; i < DV_VEC; ++i) o_acc[i] = (ACC_TYPE4)(0.0f);
|
|
306
|
+
ACC_TYPE l_i = 0.0f;
|
|
307
|
+
|
|
308
|
+
for (int k_idx = tid; k_idx < n_kv; k_idx += Q1_WG_SIZE) {
|
|
309
|
+
const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_idx * k_nb1;
|
|
310
|
+
const ulong v_row_offset = batch_idx * v_nb3 + head_kv_idx * v_nb2 + k_idx * v_nb1;
|
|
311
|
+
const global DATA_TYPE4* k_ptr = (const global DATA_TYPE4*)(k_base + k_row_offset);
|
|
312
|
+
const global DATA_TYPE4* v_ptr = (const global DATA_TYPE4*)(v_base + v_row_offset);
|
|
313
|
+
ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f);
|
|
314
|
+
#pragma unroll
|
|
315
|
+
for (int k = 0; k < DK_VEC; k++) {
|
|
316
|
+
dot_acc = mad(q_priv[k], CONVERT_ACC4(k_ptr[k]), dot_acc);
|
|
317
|
+
}
|
|
318
|
+
ACC_TYPE score = (dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3) * scale;
|
|
319
|
+
if (mask_base != NULL) {
|
|
320
|
+
const global DATA_TYPE* mask_ptr = (const global DATA_TYPE*)(mask_base);
|
|
321
|
+
score += slope * (ACC_TYPE)mask_ptr[k_idx];
|
|
322
|
+
}
|
|
323
|
+
if (logit_softcap > 0.0f) {
|
|
324
|
+
score = logit_softcap * tanh(score / logit_softcap);
|
|
325
|
+
}
|
|
326
|
+
const ACC_TYPE p = exp(score - m_final);
|
|
327
|
+
l_i += p;
|
|
328
|
+
#pragma unroll
|
|
329
|
+
for (int i = 0; i < DV_VEC; i++) {
|
|
330
|
+
o_acc[i] = mad(p, CONVERT_ACC4(v_ptr[i]), o_acc[i]);
|
|
331
|
+
}
|
|
332
|
+
}
|
|
333
|
+
|
|
334
|
+
__local ACC_TYPE local_l[Q1_WG_SIZE];
|
|
335
|
+
__local ACC_TYPE4 local_o_comp[Q1_WG_SIZE];
|
|
336
|
+
local_l[tid] = l_i;
|
|
337
|
+
barrier(CLK_LOCAL_MEM_FENCE);
|
|
338
|
+
#pragma unroll
|
|
339
|
+
for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
|
|
340
|
+
if (tid < s) local_l[tid] += local_l[tid + s];
|
|
341
|
+
barrier(CLK_LOCAL_MEM_FENCE);
|
|
342
|
+
}
|
|
343
|
+
|
|
344
|
+
const ulong o_row_offset = batch_idx * o_nb3 + head_idx * o_nb1;
|
|
345
|
+
global DATA_TYPE4 *o_row = (global DATA_TYPE4 *)(o_base + o_row_offset);
|
|
346
|
+
ACC_TYPE l_final = local_l[0];
|
|
347
|
+
|
|
348
|
+
if (sinks_ptr != NULL) {
|
|
349
|
+
l_final += exp(sinks_ptr[head_idx] - m_final);
|
|
350
|
+
}
|
|
351
|
+
|
|
352
|
+
if (l_final > 0.0f) {
|
|
353
|
+
const ACC_TYPE l_inv = 1.0f / l_final;
|
|
354
|
+
for (int i = 0; i < DV_VEC; i++) {
|
|
355
|
+
local_o_comp[tid] = o_acc[i];
|
|
356
|
+
barrier(CLK_LOCAL_MEM_FENCE);
|
|
357
|
+
#pragma unroll
|
|
358
|
+
for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
|
|
359
|
+
if (tid < s) local_o_comp[tid] += local_o_comp[tid + s];
|
|
360
|
+
barrier(CLK_LOCAL_MEM_FENCE);
|
|
361
|
+
}
|
|
362
|
+
if (tid == 0) {
|
|
363
|
+
o_row[i] = CONVERT_DATA4(local_o_comp[0] * l_inv);
|
|
364
|
+
}
|
|
365
|
+
}
|
|
366
|
+
} else if (tid == 0) {
|
|
367
|
+
#pragma unroll
|
|
368
|
+
for (int i = 0; i < DV_VEC; ++i) o_row[i] = (DATA_TYPE4)(0.0f);
|
|
369
|
+
}
|
|
370
|
+
}
|