local-llm-rn 1.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/cpp/CMakeLists.txt +285 -0
- package/cpp/common/CMakeLists.txt +149 -0
- package/cpp/common/arg.cpp +3799 -0
- package/cpp/common/arg.h +131 -0
- package/cpp/common/base64.hpp +392 -0
- package/cpp/common/build-info.cpp.in +4 -0
- package/cpp/common/chat-parser-xml-toolcall.cpp +879 -0
- package/cpp/common/chat-parser-xml-toolcall.h +45 -0
- package/cpp/common/chat-parser.cpp +1649 -0
- package/cpp/common/chat-parser.h +133 -0
- package/cpp/common/chat-peg-parser.cpp +124 -0
- package/cpp/common/chat-peg-parser.h +105 -0
- package/cpp/common/chat.cpp +3355 -0
- package/cpp/common/chat.h +252 -0
- package/cpp/common/common.cpp +1824 -0
- package/cpp/common/common.h +930 -0
- package/cpp/common/console.cpp +1137 -0
- package/cpp/common/console.h +41 -0
- package/cpp/common/debug.cpp +167 -0
- package/cpp/common/debug.h +43 -0
- package/cpp/common/download.cpp +792 -0
- package/cpp/common/download.h +84 -0
- package/cpp/common/http.h +84 -0
- package/cpp/common/jinja/README.md +88 -0
- package/cpp/common/jinja/caps.cpp +285 -0
- package/cpp/common/jinja/caps.h +30 -0
- package/cpp/common/jinja/lexer.cpp +341 -0
- package/cpp/common/jinja/lexer.h +157 -0
- package/cpp/common/jinja/parser.cpp +591 -0
- package/cpp/common/jinja/parser.h +21 -0
- package/cpp/common/jinja/runtime.cpp +867 -0
- package/cpp/common/jinja/runtime.h +638 -0
- package/cpp/common/jinja/string.cpp +213 -0
- package/cpp/common/jinja/string.h +61 -0
- package/cpp/common/jinja/utils.h +149 -0
- package/cpp/common/jinja/value.cpp +1393 -0
- package/cpp/common/jinja/value.h +756 -0
- package/cpp/common/json-partial.cpp +324 -0
- package/cpp/common/json-partial.h +39 -0
- package/cpp/common/json-schema-to-grammar.cpp +1153 -0
- package/cpp/common/json-schema-to-grammar.h +43 -0
- package/cpp/common/llguidance.cpp +258 -0
- package/cpp/common/log.cpp +446 -0
- package/cpp/common/log.h +119 -0
- package/cpp/common/ngram-cache.cpp +285 -0
- package/cpp/common/ngram-cache.h +101 -0
- package/cpp/common/ngram-map.cpp +530 -0
- package/cpp/common/ngram-map.h +115 -0
- package/cpp/common/ngram-mod.cpp +60 -0
- package/cpp/common/ngram-mod.h +38 -0
- package/cpp/common/peg-parser.cpp +1712 -0
- package/cpp/common/peg-parser.h +459 -0
- package/cpp/common/preset.cpp +483 -0
- package/cpp/common/preset.h +83 -0
- package/cpp/common/regex-partial.cpp +204 -0
- package/cpp/common/regex-partial.h +56 -0
- package/cpp/common/sampling.cpp +745 -0
- package/cpp/common/sampling.h +119 -0
- package/cpp/common/speculative.cpp +1074 -0
- package/cpp/common/speculative.h +41 -0
- package/cpp/common/unicode.cpp +64 -0
- package/cpp/common/unicode.h +22 -0
- package/cpp/ggml/CMakeLists.txt +494 -0
- package/cpp/ggml/cmake/GitVars.cmake +22 -0
- package/cpp/ggml/cmake/common.cmake +50 -0
- package/cpp/ggml/cmake/ggml-config.cmake.in +191 -0
- package/cpp/ggml/include/ggml-alloc.h +85 -0
- package/cpp/ggml/include/ggml-backend.h +373 -0
- package/cpp/ggml/include/ggml-blas.h +25 -0
- package/cpp/ggml/include/ggml-cann.h +123 -0
- package/cpp/ggml/include/ggml-cpp.h +39 -0
- package/cpp/ggml/include/ggml-cpu.h +151 -0
- package/cpp/ggml/include/ggml-cuda.h +47 -0
- package/cpp/ggml/include/ggml-hexagon.h +19 -0
- package/cpp/ggml/include/ggml-metal.h +61 -0
- package/cpp/ggml/include/ggml-opencl.h +26 -0
- package/cpp/ggml/include/ggml-opt.h +256 -0
- package/cpp/ggml/include/ggml-rpc.h +30 -0
- package/cpp/ggml/include/ggml-sycl.h +49 -0
- package/cpp/ggml/include/ggml-virtgpu.h +14 -0
- package/cpp/ggml/include/ggml-vulkan.h +29 -0
- package/cpp/ggml/include/ggml-webgpu.h +19 -0
- package/cpp/ggml/include/ggml-zdnn.h +17 -0
- package/cpp/ggml/include/ggml-zendnn.h +22 -0
- package/cpp/ggml/include/ggml.h +2753 -0
- package/cpp/ggml/include/gguf.h +204 -0
- package/cpp/ggml/src/CMakeLists.txt +492 -0
- package/cpp/ggml/src/ggml-alloc.c +1244 -0
- package/cpp/ggml/src/ggml-backend-dl.cpp +48 -0
- package/cpp/ggml/src/ggml-backend-dl.h +45 -0
- package/cpp/ggml/src/ggml-backend-impl.h +255 -0
- package/cpp/ggml/src/ggml-backend-reg.cpp +566 -0
- package/cpp/ggml/src/ggml-backend.cpp +2270 -0
- package/cpp/ggml/src/ggml-blas/CMakeLists.txt +101 -0
- package/cpp/ggml/src/ggml-blas/ggml-blas.cpp +518 -0
- package/cpp/ggml/src/ggml-common.h +1878 -0
- package/cpp/ggml/src/ggml-cpu/CMakeLists.txt +691 -0
- package/cpp/ggml/src/ggml-cpu/amx/amx.cpp +247 -0
- package/cpp/ggml/src/ggml-cpu/amx/amx.h +8 -0
- package/cpp/ggml/src/ggml-cpu/amx/common.h +91 -0
- package/cpp/ggml/src/ggml-cpu/amx/mmq.cpp +2512 -0
- package/cpp/ggml/src/ggml-cpu/amx/mmq.h +10 -0
- package/cpp/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +98 -0
- package/cpp/ggml/src/ggml-cpu/arch/arm/quants.c +4052 -0
- package/cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +4935 -0
- package/cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +2159 -0
- package/cpp/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
- package/cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +2305 -0
- package/cpp/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp +38 -0
- package/cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +2726 -0
- package/cpp/ggml/src/ggml-cpu/arch/riscv/repack.cpp +342 -0
- package/cpp/ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp +50 -0
- package/cpp/ggml/src/ggml-cpu/arch/s390/quants.c +1468 -0
- package/cpp/ggml/src/ggml-cpu/arch/wasm/quants.c +1221 -0
- package/cpp/ggml/src/ggml-cpu/arch/x86/cpu-feats.cpp +327 -0
- package/cpp/ggml/src/ggml-cpu/arch/x86/quants.c +3820 -0
- package/cpp/ggml/src/ggml-cpu/arch/x86/repack.cpp +6307 -0
- package/cpp/ggml/src/ggml-cpu/arch-fallback.h +313 -0
- package/cpp/ggml/src/ggml-cpu/binary-ops.cpp +154 -0
- package/cpp/ggml/src/ggml-cpu/binary-ops.h +16 -0
- package/cpp/ggml/src/ggml-cpu/cmake/FindSIMD.cmake +100 -0
- package/cpp/ggml/src/ggml-cpu/common.h +95 -0
- package/cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +529 -0
- package/cpp/ggml/src/ggml-cpu/ggml-cpu.c +3734 -0
- package/cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +701 -0
- package/cpp/ggml/src/ggml-cpu/hbm.cpp +55 -0
- package/cpp/ggml/src/ggml-cpu/hbm.h +8 -0
- package/cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +938 -0
- package/cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +90 -0
- package/cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +798 -0
- package/cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.h +17 -0
- package/cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +4033 -0
- package/cpp/ggml/src/ggml-cpu/llamafile/sgemm.h +25 -0
- package/cpp/ggml/src/ggml-cpu/ops.cpp +10978 -0
- package/cpp/ggml/src/ggml-cpu/ops.h +116 -0
- package/cpp/ggml/src/ggml-cpu/quants.c +1193 -0
- package/cpp/ggml/src/ggml-cpu/quants.h +97 -0
- package/cpp/ggml/src/ggml-cpu/repack.cpp +3316 -0
- package/cpp/ggml/src/ggml-cpu/repack.h +173 -0
- package/cpp/ggml/src/ggml-cpu/simd-gemm.h +136 -0
- package/cpp/ggml/src/ggml-cpu/simd-mappings.h +1279 -0
- package/cpp/ggml/src/ggml-cpu/spacemit/ime.cpp +1025 -0
- package/cpp/ggml/src/ggml-cpu/spacemit/ime.h +13 -0
- package/cpp/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +3196 -0
- package/cpp/ggml/src/ggml-cpu/spacemit/ime_kernels.h +26 -0
- package/cpp/ggml/src/ggml-cpu/traits.cpp +36 -0
- package/cpp/ggml/src/ggml-cpu/traits.h +38 -0
- package/cpp/ggml/src/ggml-cpu/unary-ops.cpp +337 -0
- package/cpp/ggml/src/ggml-cpu/unary-ops.h +35 -0
- package/cpp/ggml/src/ggml-cpu/vec.cpp +629 -0
- package/cpp/ggml/src/ggml-cpu/vec.h +1585 -0
- package/cpp/ggml/src/ggml-hexagon/CMakeLists.txt +117 -0
- package/cpp/ggml/src/ggml-hexagon/ggml-hexagon.cpp +3232 -0
- package/cpp/ggml/src/ggml-hexagon/htp/CMakeLists.txt +45 -0
- package/cpp/ggml/src/ggml-hexagon/htp/act-ops.c +815 -0
- package/cpp/ggml/src/ggml-hexagon/htp/argsort-ops.c +281 -0
- package/cpp/ggml/src/ggml-hexagon/htp/binary-ops.c +827 -0
- package/cpp/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +157 -0
- package/cpp/ggml/src/ggml-hexagon/htp/cpy-ops.c +251 -0
- package/cpp/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +666 -0
- package/cpp/ggml/src/ggml-hexagon/htp/get-rows-ops.c +111 -0
- package/cpp/ggml/src/ggml-hexagon/htp/hex-dma.c +63 -0
- package/cpp/ggml/src/ggml-hexagon/htp/hex-dma.h +182 -0
- package/cpp/ggml/src/ggml-hexagon/htp/hex-dump.h +77 -0
- package/cpp/ggml/src/ggml-hexagon/htp/hex-fastdiv.h +37 -0
- package/cpp/ggml/src/ggml-hexagon/htp/hex-utils.h +51 -0
- package/cpp/ggml/src/ggml-hexagon/htp/htp-ctx.h +35 -0
- package/cpp/ggml/src/ggml-hexagon/htp/htp-msg.h +154 -0
- package/cpp/ggml/src/ggml-hexagon/htp/htp-ops.h +65 -0
- package/cpp/ggml/src/ggml-hexagon/htp/htp_iface.idl +16 -0
- package/cpp/ggml/src/ggml-hexagon/htp/hvx-arith.h +470 -0
- package/cpp/ggml/src/ggml-hexagon/htp/hvx-base.h +173 -0
- package/cpp/ggml/src/ggml-hexagon/htp/hvx-copy.h +245 -0
- package/cpp/ggml/src/ggml-hexagon/htp/hvx-div.h +116 -0
- package/cpp/ggml/src/ggml-hexagon/htp/hvx-dump.h +129 -0
- package/cpp/ggml/src/ggml-hexagon/htp/hvx-exp.h +215 -0
- package/cpp/ggml/src/ggml-hexagon/htp/hvx-floor.h +100 -0
- package/cpp/ggml/src/ggml-hexagon/htp/hvx-inverse.h +176 -0
- package/cpp/ggml/src/ggml-hexagon/htp/hvx-reduce.h +266 -0
- package/cpp/ggml/src/ggml-hexagon/htp/hvx-scale.h +133 -0
- package/cpp/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +141 -0
- package/cpp/ggml/src/ggml-hexagon/htp/hvx-sqrt.h +126 -0
- package/cpp/ggml/src/ggml-hexagon/htp/hvx-types.h +36 -0
- package/cpp/ggml/src/ggml-hexagon/htp/hvx-utils.h +18 -0
- package/cpp/ggml/src/ggml-hexagon/htp/main.c +1150 -0
- package/cpp/ggml/src/ggml-hexagon/htp/matmul-ops.c +2595 -0
- package/cpp/ggml/src/ggml-hexagon/htp/rope-ops.c +498 -0
- package/cpp/ggml/src/ggml-hexagon/htp/set-rows-ops.c +167 -0
- package/cpp/ggml/src/ggml-hexagon/htp/softmax-ops.c +421 -0
- package/cpp/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +130 -0
- package/cpp/ggml/src/ggml-hexagon/htp/unary-ops.c +384 -0
- package/cpp/ggml/src/ggml-hexagon/htp/worker-pool.c +293 -0
- package/cpp/ggml/src/ggml-hexagon/htp/worker-pool.h +57 -0
- package/cpp/ggml/src/ggml-hexagon/htp-drv.cpp +418 -0
- package/cpp/ggml/src/ggml-hexagon/htp-drv.h +121 -0
- package/cpp/ggml/src/ggml-hexagon/libdl.h +79 -0
- package/cpp/ggml/src/ggml-hexagon/libggml-htp.inf +38 -0
- package/cpp/ggml/src/ggml-hexagon/op-desc.h +153 -0
- package/cpp/ggml/src/ggml-impl.h +724 -0
- package/cpp/ggml/src/ggml-metal/CMakeLists.txt +124 -0
- package/cpp/ggml/src/ggml-metal/ggml-metal-common.cpp +457 -0
- package/cpp/ggml/src/ggml-metal/ggml-metal-common.h +52 -0
- package/cpp/ggml/src/ggml-metal/ggml-metal-context.h +41 -0
- package/cpp/ggml/src/ggml-metal/ggml-metal-context.m +702 -0
- package/cpp/ggml/src/ggml-metal/ggml-metal-device.cpp +1890 -0
- package/cpp/ggml/src/ggml-metal/ggml-metal-device.h +290 -0
- package/cpp/ggml/src/ggml-metal/ggml-metal-device.m +1749 -0
- package/cpp/ggml/src/ggml-metal/ggml-metal-impl.h +1054 -0
- package/cpp/ggml/src/ggml-metal/ggml-metal-ops.cpp +4370 -0
- package/cpp/ggml/src/ggml-metal/ggml-metal-ops.h +94 -0
- package/cpp/ggml/src/ggml-metal/ggml-metal.cpp +937 -0
- package/cpp/ggml/src/ggml-metal/ggml-metal.metal +9819 -0
- package/cpp/ggml/src/ggml-musa/CMakeLists.txt +125 -0
- package/cpp/ggml/src/ggml-musa/mudnn.cu +112 -0
- package/cpp/ggml/src/ggml-musa/mudnn.cuh +12 -0
- package/cpp/ggml/src/ggml-opencl/CMakeLists.txt +150 -0
- package/cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +11553 -0
- package/cpp/ggml/src/ggml-opencl/kernels/add.cl +190 -0
- package/cpp/ggml/src/ggml-opencl/kernels/add_id.cl +42 -0
- package/cpp/ggml/src/ggml-opencl/kernels/argsort.cl +86 -0
- package/cpp/ggml/src/ggml-opencl/kernels/clamp.cl +20 -0
- package/cpp/ggml/src/ggml-opencl/kernels/concat.cl +51 -0
- package/cpp/ggml/src/ggml-opencl/kernels/conv2d.cl +185 -0
- package/cpp/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl +176 -0
- package/cpp/ggml/src/ggml-opencl/kernels/cpy.cl +184 -0
- package/cpp/ggml/src/ggml-opencl/kernels/cvt.cl +417 -0
- package/cpp/ggml/src/ggml-opencl/kernels/diag_mask_inf.cl +58 -0
- package/cpp/ggml/src/ggml-opencl/kernels/div.cl +138 -0
- package/cpp/ggml/src/ggml-opencl/kernels/embed_kernel.py +26 -0
- package/cpp/ggml/src/ggml-opencl/kernels/expm1.cl +113 -0
- package/cpp/ggml/src/ggml-opencl/kernels/fill.cl +17 -0
- package/cpp/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl +370 -0
- package/cpp/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +371 -0
- package/cpp/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl +373 -0
- package/cpp/ggml/src/ggml-opencl/kernels/gelu.cl +89 -0
- package/cpp/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl +162 -0
- package/cpp/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl +156 -0
- package/cpp/ggml/src/ggml-opencl/kernels/gemv_noshuffle.cl +268 -0
- package/cpp/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general.cl +274 -0
- package/cpp/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general_q8_0_f32.cl +195 -0
- package/cpp/ggml/src/ggml-opencl/kernels/get_rows.cl +187 -0
- package/cpp/ggml/src/ggml-opencl/kernels/glu.cl +378 -0
- package/cpp/ggml/src/ggml-opencl/kernels/group_norm.cl +121 -0
- package/cpp/ggml/src/ggml-opencl/kernels/im2col_f16.cl +57 -0
- package/cpp/ggml/src/ggml-opencl/kernels/im2col_f32.cl +57 -0
- package/cpp/ggml/src/ggml-opencl/kernels/mean.cl +140 -0
- package/cpp/ggml/src/ggml-opencl/kernels/mul.cl +152 -0
- package/cpp/ggml/src/ggml-opencl/kernels/mul_mat_Ab_Bi_8x4.cl +139 -0
- package/cpp/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
- package/cpp/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl +273 -0
- package/cpp/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +146 -0
- package/cpp/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +147 -0
- package/cpp/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl +163 -0
- package/cpp/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl +165 -0
- package/cpp/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl +158 -0
- package/cpp/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_8x4.cl +129 -0
- package/cpp/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl +154 -0
- package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_f16_f16.cl +118 -0
- package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32.cl +118 -0
- package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32_1row.cl +94 -0
- package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32_l4.cl +84 -0
- package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_f32_f32.cl +118 -0
- package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32.cl +189 -0
- package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32_flat.cl +176 -0
- package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl +283 -0
- package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32.cl +140 -0
- package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32_flat.cl +222 -0
- package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32.cl +144 -0
- package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32_flat.cl +167 -0
- package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32.cl +192 -0
- package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_1d_16x_flat.cl +307 -0
- package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_1d_8x_flat.cl +265 -0
- package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_8x_flat.cl +272 -0
- package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_v.cl +254 -0
- package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl +219 -0
- package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl +229 -0
- package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl +180 -0
- package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32.cl +194 -0
- package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +194 -0
- package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32.cl +125 -0
- package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32_flat.cl +202 -0
- package/cpp/ggml/src/ggml-opencl/kernels/norm.cl +161 -0
- package/cpp/ggml/src/ggml-opencl/kernels/pad.cl +39 -0
- package/cpp/ggml/src/ggml-opencl/kernels/relu.cl +16 -0
- package/cpp/ggml/src/ggml-opencl/kernels/repeat.cl +38 -0
- package/cpp/ggml/src/ggml-opencl/kernels/rms_norm.cl +190 -0
- package/cpp/ggml/src/ggml-opencl/kernels/rope.cl +747 -0
- package/cpp/ggml/src/ggml-opencl/kernels/scale.cl +27 -0
- package/cpp/ggml/src/ggml-opencl/kernels/set_rows.cl +208 -0
- package/cpp/ggml/src/ggml-opencl/kernels/sigmoid.cl +29 -0
- package/cpp/ggml/src/ggml-opencl/kernels/silu.cl +30 -0
- package/cpp/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +108 -0
- package/cpp/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +108 -0
- package/cpp/ggml/src/ggml-opencl/kernels/softmax_f16.cl +107 -0
- package/cpp/ggml/src/ggml-opencl/kernels/softmax_f32.cl +107 -0
- package/cpp/ggml/src/ggml-opencl/kernels/softplus.cl +116 -0
- package/cpp/ggml/src/ggml-opencl/kernels/solve_tri.cl +51 -0
- package/cpp/ggml/src/ggml-opencl/kernels/sqr.cl +53 -0
- package/cpp/ggml/src/ggml-opencl/kernels/sqrt.cl +53 -0
- package/cpp/ggml/src/ggml-opencl/kernels/ssm_conv.cl +77 -0
- package/cpp/ggml/src/ggml-opencl/kernels/sub.cl +138 -0
- package/cpp/ggml/src/ggml-opencl/kernels/sum_rows.cl +140 -0
- package/cpp/ggml/src/ggml-opencl/kernels/tanh.cl +109 -0
- package/cpp/ggml/src/ggml-opencl/kernels/transpose.cl +117 -0
- package/cpp/ggml/src/ggml-opencl/kernels/tri.cl +32 -0
- package/cpp/ggml/src/ggml-opencl/kernels/tsembd.cl +48 -0
- package/cpp/ggml/src/ggml-opencl/kernels/upscale.cl +120 -0
- package/cpp/ggml/src/ggml-opt.cpp +1093 -0
- package/cpp/ggml/src/ggml-quants.c +5325 -0
- package/cpp/ggml/src/ggml-quants.h +106 -0
- package/cpp/ggml/src/ggml-rpc/CMakeLists.txt +9 -0
- package/cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +2118 -0
- package/cpp/ggml/src/ggml-threading.cpp +12 -0
- package/cpp/ggml/src/ggml-threading.h +14 -0
- package/cpp/ggml/src/ggml-virtgpu/CMakeLists.txt +70 -0
- package/cpp/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp +87 -0
- package/cpp/ggml/src/ggml-virtgpu/backend/CMakeLists.txt +21 -0
- package/cpp/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp +115 -0
- package/cpp/ggml/src/ggml-virtgpu/backend/backend-convert.h +13 -0
- package/cpp/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp +102 -0
- package/cpp/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp +105 -0
- package/cpp/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp +179 -0
- package/cpp/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp +148 -0
- package/cpp/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp +51 -0
- package/cpp/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h +73 -0
- package/cpp/ggml/src/ggml-virtgpu/backend/backend-dispatched.h +27 -0
- package/cpp/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h +32 -0
- package/cpp/ggml/src/ggml-virtgpu/backend/backend.cpp +144 -0
- package/cpp/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h +95 -0
- package/cpp/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h +94 -0
- package/cpp/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h +50 -0
- package/cpp/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h +378 -0
- package/cpp/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h +232 -0
- package/cpp/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h +58 -0
- package/cpp/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp +81 -0
- package/cpp/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +119 -0
- package/cpp/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +158 -0
- package/cpp/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp +213 -0
- package/cpp/ggml/src/ggml-virtgpu/ggml-backend.cpp +69 -0
- package/cpp/ggml/src/ggml-virtgpu/ggml-remoting.h +71 -0
- package/cpp/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml +166 -0
- package/cpp/ggml/src/ggml-virtgpu/include/apir_hw.h +9 -0
- package/cpp/ggml/src/ggml-virtgpu/regenerate_remoting.py +333 -0
- package/cpp/ggml/src/ggml-virtgpu/virtgpu-apir.h +15 -0
- package/cpp/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp +58 -0
- package/cpp/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp +110 -0
- package/cpp/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp +173 -0
- package/cpp/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp +192 -0
- package/cpp/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h +36 -0
- package/cpp/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h +53 -0
- package/cpp/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +98 -0
- package/cpp/ggml/src/ggml-virtgpu/virtgpu-shm.h +23 -0
- package/cpp/ggml/src/ggml-virtgpu/virtgpu-utils.cpp +179 -0
- package/cpp/ggml/src/ggml-virtgpu/virtgpu-utils.h +86 -0
- package/cpp/ggml/src/ggml-virtgpu/virtgpu.cpp +544 -0
- package/cpp/ggml/src/ggml-virtgpu/virtgpu.h +117 -0
- package/cpp/ggml/src/ggml-webgpu/CMakeLists.txt +80 -0
- package/cpp/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +1231 -0
- package/cpp/ggml/src/ggml-webgpu/ggml-webgpu.cpp +3150 -0
- package/cpp/ggml/src/ggml-webgpu/pre_wgsl.hpp +778 -0
- package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl +72 -0
- package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl +106 -0
- package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl +134 -0
- package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +107 -0
- package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +923 -0
- package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +107 -0
- package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl +66 -0
- package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +182 -0
- package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +636 -0
- package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.wgsl +668 -0
- package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +323 -0
- package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +40 -0
- package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl +713 -0
- package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +103 -0
- package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl +138 -0
- package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl +188 -0
- package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +194 -0
- package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl +86 -0
- package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +123 -0
- package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +295 -0
- package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/scale.wgsl +63 -0
- package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +109 -0
- package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +345 -0
- package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl +55 -0
- package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +193 -0
- package/cpp/ggml/src/ggml-zdnn/CMakeLists.txt +36 -0
- package/cpp/ggml/src/ggml-zdnn/common.hpp +59 -0
- package/cpp/ggml/src/ggml-zdnn/ggml-zdnn.cpp +633 -0
- package/cpp/ggml/src/ggml-zdnn/mmf.cpp +80 -0
- package/cpp/ggml/src/ggml-zdnn/mmf.hpp +12 -0
- package/cpp/ggml/src/ggml-zdnn/utils.cpp +79 -0
- package/cpp/ggml/src/ggml-zdnn/utils.hpp +19 -0
- package/cpp/ggml/src/ggml-zendnn/CMakeLists.txt +92 -0
- package/cpp/ggml/src/ggml-zendnn/ggml-zendnn.cpp +469 -0
- package/cpp/ggml/src/ggml.c +7669 -0
- package/cpp/ggml/src/ggml.cpp +26 -0
- package/cpp/ggml/src/gguf.cpp +1699 -0
- package/cpp/include/llama-cpp.h +32 -0
- package/cpp/include/llama.h +1568 -0
- package/cpp/mtmd/CMakeLists.txt +98 -0
- package/cpp/mtmd/README.md +63 -0
- package/cpp/mtmd/clip-graph.h +117 -0
- package/cpp/mtmd/clip-impl.h +586 -0
- package/cpp/mtmd/clip-model.h +390 -0
- package/cpp/mtmd/clip.cpp +4154 -0
- package/cpp/mtmd/clip.h +121 -0
- package/cpp/mtmd/deprecation-warning.cpp +22 -0
- package/cpp/mtmd/legacy-models/convert_image_encoder_to_gguf.py +412 -0
- package/cpp/mtmd/legacy-models/glmedge-convert-image-encoder-to-gguf.py +280 -0
- package/cpp/mtmd/legacy-models/glmedge-surgery.py +33 -0
- package/cpp/mtmd/legacy-models/llava_surgery.py +38 -0
- package/cpp/mtmd/legacy-models/llava_surgery_v2.py +180 -0
- package/cpp/mtmd/legacy-models/minicpmv-convert-image-encoder-to-gguf.py +892 -0
- package/cpp/mtmd/legacy-models/minicpmv-surgery.py +47 -0
- package/cpp/mtmd/models/cogvlm.cpp +98 -0
- package/cpp/mtmd/models/conformer.cpp +216 -0
- package/cpp/mtmd/models/glm4v.cpp +122 -0
- package/cpp/mtmd/models/internvl.cpp +69 -0
- package/cpp/mtmd/models/kimik25.cpp +101 -0
- package/cpp/mtmd/models/kimivl.cpp +63 -0
- package/cpp/mtmd/models/llama4.cpp +96 -0
- package/cpp/mtmd/models/llava.cpp +374 -0
- package/cpp/mtmd/models/minicpmv.cpp +114 -0
- package/cpp/mtmd/models/mobilenetv5.cpp +451 -0
- package/cpp/mtmd/models/models.h +128 -0
- package/cpp/mtmd/models/nemotron-v2-vl.cpp +35 -0
- package/cpp/mtmd/models/paddleocr.cpp +52 -0
- package/cpp/mtmd/models/pixtral.cpp +86 -0
- package/cpp/mtmd/models/qwen2vl.cpp +183 -0
- package/cpp/mtmd/models/qwen3vl.cpp +193 -0
- package/cpp/mtmd/models/siglip.cpp +86 -0
- package/cpp/mtmd/models/whisper-enc.cpp +115 -0
- package/cpp/mtmd/models/youtuvl.cpp +179 -0
- package/cpp/mtmd/mtmd-audio.cpp +730 -0
- package/cpp/mtmd/mtmd-audio.h +113 -0
- package/cpp/mtmd/mtmd-cli.cpp +437 -0
- package/cpp/mtmd/mtmd-helper.cpp +521 -0
- package/cpp/mtmd/mtmd-helper.h +96 -0
- package/cpp/mtmd/mtmd.cpp +1156 -0
- package/cpp/mtmd/mtmd.h +319 -0
- package/cpp/mtmd/requirements.txt +5 -0
- package/cpp/mtmd/test-1.jpeg +0 -0
- package/cpp/mtmd/test-2.mp3 +0 -0
- package/cpp/mtmd/tests.sh +192 -0
- package/cpp/src/CMakeLists.txt +169 -0
- package/cpp/src/llama-adapter.cpp +488 -0
- package/cpp/src/llama-adapter.h +89 -0
- package/cpp/src/llama-arch.cpp +2855 -0
- package/cpp/src/llama-arch.h +619 -0
- package/cpp/src/llama-batch.cpp +917 -0
- package/cpp/src/llama-batch.h +173 -0
- package/cpp/src/llama-chat.cpp +896 -0
- package/cpp/src/llama-chat.h +71 -0
- package/cpp/src/llama-context.cpp +3512 -0
- package/cpp/src/llama-context.h +359 -0
- package/cpp/src/llama-cparams.cpp +5 -0
- package/cpp/src/llama-cparams.h +44 -0
- package/cpp/src/llama-grammar.cpp +1464 -0
- package/cpp/src/llama-grammar.h +194 -0
- package/cpp/src/llama-graph.cpp +2685 -0
- package/cpp/src/llama-graph.h +1026 -0
- package/cpp/src/llama-hparams.cpp +234 -0
- package/cpp/src/llama-hparams.h +339 -0
- package/cpp/src/llama-impl.cpp +171 -0
- package/cpp/src/llama-impl.h +73 -0
- package/cpp/src/llama-io.cpp +15 -0
- package/cpp/src/llama-io.h +35 -0
- package/cpp/src/llama-kv-cache-iswa.cpp +330 -0
- package/cpp/src/llama-kv-cache-iswa.h +137 -0
- package/cpp/src/llama-kv-cache.cpp +2271 -0
- package/cpp/src/llama-kv-cache.h +388 -0
- package/cpp/src/llama-kv-cells.h +533 -0
- package/cpp/src/llama-memory-hybrid-iswa.cpp +275 -0
- package/cpp/src/llama-memory-hybrid-iswa.h +140 -0
- package/cpp/src/llama-memory-hybrid.cpp +268 -0
- package/cpp/src/llama-memory-hybrid.h +139 -0
- package/cpp/src/llama-memory-recurrent.cpp +1165 -0
- package/cpp/src/llama-memory-recurrent.h +182 -0
- package/cpp/src/llama-memory.cpp +59 -0
- package/cpp/src/llama-memory.h +122 -0
- package/cpp/src/llama-mmap.cpp +785 -0
- package/cpp/src/llama-mmap.h +92 -0
- package/cpp/src/llama-model-loader.cpp +1414 -0
- package/cpp/src/llama-model-loader.h +203 -0
- package/cpp/src/llama-model-saver.cpp +286 -0
- package/cpp/src/llama-model-saver.h +37 -0
- package/cpp/src/llama-model.cpp +9253 -0
- package/cpp/src/llama-model.h +576 -0
- package/cpp/src/llama-quant.cpp +1119 -0
- package/cpp/src/llama-quant.h +1 -0
- package/cpp/src/llama-sampler.cpp +3885 -0
- package/cpp/src/llama-sampler.h +42 -0
- package/cpp/src/llama-vocab.cpp +3970 -0
- package/cpp/src/llama-vocab.h +187 -0
- package/cpp/src/llama.cpp +1313 -0
- package/cpp/src/models/afmoe.cpp +191 -0
- package/cpp/src/models/apertus.cpp +125 -0
- package/cpp/src/models/arcee.cpp +135 -0
- package/cpp/src/models/arctic.cpp +138 -0
- package/cpp/src/models/arwkv7.cpp +86 -0
- package/cpp/src/models/baichuan.cpp +122 -0
- package/cpp/src/models/bailingmoe.cpp +144 -0
- package/cpp/src/models/bailingmoe2.cpp +135 -0
- package/cpp/src/models/bert.cpp +178 -0
- package/cpp/src/models/bitnet.cpp +160 -0
- package/cpp/src/models/bloom.cpp +101 -0
- package/cpp/src/models/chameleon.cpp +178 -0
- package/cpp/src/models/chatglm.cpp +132 -0
- package/cpp/src/models/codeshell.cpp +111 -0
- package/cpp/src/models/cogvlm.cpp +102 -0
- package/cpp/src/models/cohere2-iswa.cpp +134 -0
- package/cpp/src/models/command-r.cpp +122 -0
- package/cpp/src/models/dbrx.cpp +123 -0
- package/cpp/src/models/deci.cpp +135 -0
- package/cpp/src/models/deepseek.cpp +144 -0
- package/cpp/src/models/deepseek2.cpp +262 -0
- package/cpp/src/models/delta-net-base.cpp +376 -0
- package/cpp/src/models/dots1.cpp +134 -0
- package/cpp/src/models/dream.cpp +105 -0
- package/cpp/src/models/ernie4-5-moe.cpp +150 -0
- package/cpp/src/models/ernie4-5.cpp +110 -0
- package/cpp/src/models/eurobert.cpp +97 -0
- package/cpp/src/models/exaone-moe.cpp +146 -0
- package/cpp/src/models/exaone.cpp +114 -0
- package/cpp/src/models/exaone4.cpp +123 -0
- package/cpp/src/models/falcon-h1.cpp +111 -0
- package/cpp/src/models/falcon.cpp +120 -0
- package/cpp/src/models/gemma-embedding.cpp +116 -0
- package/cpp/src/models/gemma.cpp +112 -0
- package/cpp/src/models/gemma2-iswa.cpp +128 -0
- package/cpp/src/models/gemma3.cpp +155 -0
- package/cpp/src/models/gemma3n-iswa.cpp +384 -0
- package/cpp/src/models/glm4-moe.cpp +170 -0
- package/cpp/src/models/glm4.cpp +157 -0
- package/cpp/src/models/gpt2.cpp +105 -0
- package/cpp/src/models/gptneox.cpp +144 -0
- package/cpp/src/models/granite-hybrid.cpp +196 -0
- package/cpp/src/models/granite.cpp +211 -0
- package/cpp/src/models/grok.cpp +159 -0
- package/cpp/src/models/grovemoe.cpp +141 -0
- package/cpp/src/models/hunyuan-dense.cpp +132 -0
- package/cpp/src/models/hunyuan-moe.cpp +154 -0
- package/cpp/src/models/internlm2.cpp +120 -0
- package/cpp/src/models/jais.cpp +86 -0
- package/cpp/src/models/jais2.cpp +123 -0
- package/cpp/src/models/jamba.cpp +106 -0
- package/cpp/src/models/kimi-linear.cpp +392 -0
- package/cpp/src/models/lfm2.cpp +190 -0
- package/cpp/src/models/llada-moe.cpp +122 -0
- package/cpp/src/models/llada.cpp +99 -0
- package/cpp/src/models/llama-iswa.cpp +178 -0
- package/cpp/src/models/llama.cpp +168 -0
- package/cpp/src/models/maincoder.cpp +117 -0
- package/cpp/src/models/mamba-base.cpp +285 -0
- package/cpp/src/models/mamba.cpp +54 -0
- package/cpp/src/models/mimo2-iswa.cpp +123 -0
- package/cpp/src/models/minicpm3.cpp +200 -0
- package/cpp/src/models/minimax-m2.cpp +124 -0
- package/cpp/src/models/mistral3.cpp +160 -0
- package/cpp/src/models/models.h +684 -0
- package/cpp/src/models/modern-bert.cpp +109 -0
- package/cpp/src/models/mpt.cpp +126 -0
- package/cpp/src/models/nemotron-h.cpp +148 -0
- package/cpp/src/models/nemotron.cpp +122 -0
- package/cpp/src/models/neo-bert.cpp +104 -0
- package/cpp/src/models/olmo.cpp +121 -0
- package/cpp/src/models/olmo2.cpp +150 -0
- package/cpp/src/models/olmoe.cpp +124 -0
- package/cpp/src/models/openai-moe-iswa.cpp +127 -0
- package/cpp/src/models/openelm.cpp +124 -0
- package/cpp/src/models/orion.cpp +123 -0
- package/cpp/src/models/paddleocr.cpp +122 -0
- package/cpp/src/models/pangu-embedded.cpp +121 -0
- package/cpp/src/models/phi2.cpp +121 -0
- package/cpp/src/models/phi3.cpp +152 -0
- package/cpp/src/models/plamo.cpp +110 -0
- package/cpp/src/models/plamo2.cpp +318 -0
- package/cpp/src/models/plamo3.cpp +128 -0
- package/cpp/src/models/plm.cpp +169 -0
- package/cpp/src/models/qwen.cpp +108 -0
- package/cpp/src/models/qwen2.cpp +126 -0
- package/cpp/src/models/qwen2moe.cpp +151 -0
- package/cpp/src/models/qwen2vl.cpp +117 -0
- package/cpp/src/models/qwen3.cpp +117 -0
- package/cpp/src/models/qwen35.cpp +386 -0
- package/cpp/src/models/qwen35moe.cpp +420 -0
- package/cpp/src/models/qwen3moe.cpp +124 -0
- package/cpp/src/models/qwen3next.cpp +525 -0
- package/cpp/src/models/qwen3vl-moe.cpp +140 -0
- package/cpp/src/models/qwen3vl.cpp +132 -0
- package/cpp/src/models/refact.cpp +94 -0
- package/cpp/src/models/rnd1.cpp +126 -0
- package/cpp/src/models/rwkv6-base.cpp +164 -0
- package/cpp/src/models/rwkv6.cpp +94 -0
- package/cpp/src/models/rwkv6qwen2.cpp +86 -0
- package/cpp/src/models/rwkv7-base.cpp +137 -0
- package/cpp/src/models/rwkv7.cpp +90 -0
- package/cpp/src/models/seed-oss.cpp +124 -0
- package/cpp/src/models/smallthinker.cpp +126 -0
- package/cpp/src/models/smollm3.cpp +128 -0
- package/cpp/src/models/stablelm.cpp +146 -0
- package/cpp/src/models/starcoder.cpp +100 -0
- package/cpp/src/models/starcoder2.cpp +121 -0
- package/cpp/src/models/step35-iswa.cpp +168 -0
- package/cpp/src/models/t5-dec.cpp +166 -0
- package/cpp/src/models/t5-enc.cpp +96 -0
- package/cpp/src/models/wavtokenizer-dec.cpp +149 -0
- package/cpp/src/models/xverse.cpp +108 -0
- package/cpp/src/unicode-data.cpp +7034 -0
- package/cpp/src/unicode-data.h +20 -0
- package/cpp/src/unicode.cpp +1103 -0
- package/cpp/src/unicode.h +111 -0
- package/cpp/vendor/nlohmann/json.hpp +25526 -0
- package/cpp/vendor/nlohmann/json_fwd.hpp +187 -0
- package/cpp/vendor/stb/stb_image.h +7988 -0
- package/ios/LocalLLM-Bridging-Header.h +2 -0
- package/ios/LocalLLM.h +5 -0
- package/ios/LocalLLM.mm +1267 -0
- package/local-llm-rn.podspec +60 -0
- package/package.json +35 -0
- package/src/NativeLocalLLM.ts +73 -0
- package/src/device.ts +50 -0
- package/src/download-adapter.ts +17 -0
- package/src/index.ts +21 -0
- package/src/native-bridge.ts +142 -0
- package/src/rn-downloader.ts +37 -0
|
@@ -0,0 +1,666 @@
|
|
|
1
|
+
#pragma clang diagnostic ignored "-Wunused-variable"
|
|
2
|
+
#pragma clang diagnostic ignored "-Wunused-function"
|
|
3
|
+
#pragma clang diagnostic ignored "-Wunused-but-set-variable"
|
|
4
|
+
|
|
5
|
+
#include <assert.h>
|
|
6
|
+
#include <HAP_farf.h>
|
|
7
|
+
#include <HAP_perf.h>
|
|
8
|
+
#include <math.h>
|
|
9
|
+
#include <string.h>
|
|
10
|
+
|
|
11
|
+
#include "hex-dma.h"
|
|
12
|
+
#include "hvx-utils.h"
|
|
13
|
+
|
|
14
|
+
#define GGML_COMMON_DECL_C
|
|
15
|
+
#include "ggml-common.h"
|
|
16
|
+
#include "htp-ctx.h"
|
|
17
|
+
#include "htp-msg.h"
|
|
18
|
+
#include "htp-ops.h"
|
|
19
|
+
|
|
20
|
+
// Dot product of two F16 vectors, accumulating to float
|
|
21
|
+
static inline void hvx_dot_f16_f16_aa(float * restrict r, const void * restrict x, const void * restrict y, unsigned int n, float s) {
|
|
22
|
+
const HVX_Vector * restrict vx = (const HVX_Vector * restrict) x; // fp16
|
|
23
|
+
const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp16
|
|
24
|
+
|
|
25
|
+
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
|
|
26
|
+
uint32_t nloe = n % VLEN_FP16; // leftover elements
|
|
27
|
+
|
|
28
|
+
HVX_Vector rsum = Q6_V_vsplat_R(0);
|
|
29
|
+
|
|
30
|
+
uint32_t i = 0;
|
|
31
|
+
|
|
32
|
+
#pragma unroll(4)
|
|
33
|
+
for (i = 0; i < nvec; i++) {
|
|
34
|
+
HVX_Vector y_hf = vy[i];
|
|
35
|
+
HVX_Vector x_hf = vx[i];
|
|
36
|
+
|
|
37
|
+
HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
|
|
38
|
+
|
|
39
|
+
rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)), rsum));
|
|
40
|
+
}
|
|
41
|
+
|
|
42
|
+
if (nloe) {
|
|
43
|
+
// Load x (fp16) and zero-out unused elements
|
|
44
|
+
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
|
|
45
|
+
HVX_Vector y_hf = Q6_V_vand_QV(bmask, vy[i]);
|
|
46
|
+
HVX_Vector x_hf = Q6_V_vand_QV(bmask, vx[i]);
|
|
47
|
+
|
|
48
|
+
HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
|
|
49
|
+
|
|
50
|
+
rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)), rsum));
|
|
51
|
+
}
|
|
52
|
+
|
|
53
|
+
rsum = Q6_Vqf32_vmpy_VsfVsf(hvx_vec_splat_f32(s), hvx_vec_reduce_sum_f32(rsum));
|
|
54
|
+
hvx_vec_store_u(r, 4, Q6_Vsf_equals_Vqf32(rsum));
|
|
55
|
+
}
|
|
56
|
+
|
|
57
|
+
static inline void hvx_dot_f16_f16_aa_rx2(float * restrict r,
|
|
58
|
+
const void * restrict y,
|
|
59
|
+
const void * restrict x0,
|
|
60
|
+
const void * restrict x1,
|
|
61
|
+
unsigned int n,
|
|
62
|
+
float s) {
|
|
63
|
+
const HVX_Vector * restrict vx0 = (const HVX_Vector * restrict) x0; // fp16
|
|
64
|
+
const HVX_Vector * restrict vx1 = (const HVX_Vector * restrict) x1; // fp16
|
|
65
|
+
const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp16
|
|
66
|
+
|
|
67
|
+
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
|
|
68
|
+
uint32_t nloe = n % VLEN_FP16; // leftover elements
|
|
69
|
+
|
|
70
|
+
HVX_Vector rsum0 = Q6_V_vsplat_R(0);
|
|
71
|
+
HVX_Vector rsum1 = Q6_V_vsplat_R(0);
|
|
72
|
+
|
|
73
|
+
uint32_t i = 0;
|
|
74
|
+
|
|
75
|
+
#pragma unroll(4)
|
|
76
|
+
for (i = 0; i < nvec; i++) {
|
|
77
|
+
HVX_Vector y_hf = vy[i];
|
|
78
|
+
HVX_Vector x0_hf = vx0[i];
|
|
79
|
+
HVX_Vector x1_hf = vx1[i];
|
|
80
|
+
|
|
81
|
+
HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0_hf, y_hf);
|
|
82
|
+
HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1_hf, y_hf);
|
|
83
|
+
|
|
84
|
+
rsum0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf)), rsum0));
|
|
85
|
+
rsum1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf)), rsum1));
|
|
86
|
+
}
|
|
87
|
+
|
|
88
|
+
if (nloe) {
|
|
89
|
+
// Load x (fp16) and zero-out unused elements
|
|
90
|
+
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
|
|
91
|
+
HVX_Vector x0_hf = Q6_V_vand_QV(bmask, vx0[i]);
|
|
92
|
+
HVX_Vector x1_hf = Q6_V_vand_QV(bmask, vx1[i]);
|
|
93
|
+
HVX_Vector y_hf = Q6_V_vand_QV(bmask, vy[i]);
|
|
94
|
+
|
|
95
|
+
HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0_hf, y_hf);
|
|
96
|
+
HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1_hf, y_hf);
|
|
97
|
+
|
|
98
|
+
rsum0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf)), rsum0));
|
|
99
|
+
rsum1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf)), rsum1));
|
|
100
|
+
}
|
|
101
|
+
|
|
102
|
+
HVX_Vector rsum = Q6_Vqf32_vmpy_VsfVsf(hvx_vec_splat_f32(s), hvx_vec_reduce_sum_f32x2(rsum0, rsum1));
|
|
103
|
+
hvx_vec_store_u(r, 8, Q6_Vsf_equals_Vqf32(rsum));
|
|
104
|
+
}
|
|
105
|
+
|
|
106
|
+
// MAD: y (F32) += x (F16) * s (F32)
|
|
107
|
+
static inline void hvx_mad_f32_f16_aa(float * restrict y, const void * restrict x, int n, float s) {
|
|
108
|
+
const HVX_Vector * restrict ptr_x = (const HVX_Vector *) x;
|
|
109
|
+
HVX_Vector * restrict ptr_y = (HVX_Vector *) y;
|
|
110
|
+
|
|
111
|
+
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
|
|
112
|
+
uint32_t nloe = n % VLEN_FP16; // leftover elements
|
|
113
|
+
|
|
114
|
+
HVX_Vector S = hvx_vec_splat_f16(s);
|
|
115
|
+
|
|
116
|
+
uint32_t i = 0;
|
|
117
|
+
#pragma unroll(4)
|
|
118
|
+
for (i = 0; i < nvec; ++i) {
|
|
119
|
+
// Multiply x * s -> pair of F32 vectors
|
|
120
|
+
HVX_VectorPair xs_p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(ptr_x[i]), S);
|
|
121
|
+
ptr_y[i*2] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_V_lo_W(xs_p), ptr_y[i*2]));
|
|
122
|
+
ptr_y[i*2+1] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_V_hi_W(xs_p), ptr_y[i*2+1]));
|
|
123
|
+
}
|
|
124
|
+
|
|
125
|
+
if (nloe) {
|
|
126
|
+
HVX_VectorPair xs_p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(ptr_x[i]), S);
|
|
127
|
+
|
|
128
|
+
HVX_Vector xs = Q6_V_lo_W(xs_p);
|
|
129
|
+
i = 2 * i; // index for ptr_y
|
|
130
|
+
|
|
131
|
+
if (nloe >= 32) {
|
|
132
|
+
ptr_y[i] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(xs, ptr_y[i]));
|
|
133
|
+
nloe -= 32; ++i; xs = Q6_V_hi_W(xs_p);
|
|
134
|
+
}
|
|
135
|
+
|
|
136
|
+
if (nloe) {
|
|
137
|
+
HVX_Vector xy = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(xs, ptr_y[i]));
|
|
138
|
+
hvx_vec_store_a(&ptr_y[i], nloe * 4, xy);
|
|
139
|
+
}
|
|
140
|
+
}
|
|
141
|
+
}
|
|
142
|
+
|
|
143
|
+
// MAD: y (F32) += x0 (F16) * s0 (F32) + x1 (F16) * s1 (F32)
|
|
144
|
+
static inline void hvx_mad_f32_f16_aa_rx2(float * restrict y,
|
|
145
|
+
const void * restrict x0,
|
|
146
|
+
const void * restrict x1,
|
|
147
|
+
float s0,
|
|
148
|
+
float s1,
|
|
149
|
+
int n) {
|
|
150
|
+
const HVX_Vector * restrict ptr_x0 = (const HVX_Vector *) x0;
|
|
151
|
+
const HVX_Vector * restrict ptr_x1 = (const HVX_Vector *) x1;
|
|
152
|
+
HVX_Vector * restrict ptr_y = (HVX_Vector *) y;
|
|
153
|
+
|
|
154
|
+
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
|
|
155
|
+
uint32_t nloe = n % VLEN_FP16; // leftover elements
|
|
156
|
+
|
|
157
|
+
HVX_Vector S0 = hvx_vec_splat_f16(s0);
|
|
158
|
+
HVX_Vector S1 = hvx_vec_splat_f16(s1);
|
|
159
|
+
|
|
160
|
+
uint32_t i = 0;
|
|
161
|
+
#pragma unroll(2)
|
|
162
|
+
for (i = 0; i < nvec; ++i) {
|
|
163
|
+
// Multiply x * s -> pair of F32 vectors
|
|
164
|
+
HVX_VectorPair xs0_p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(ptr_x0[i]), S0);
|
|
165
|
+
HVX_VectorPair xs1_p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(ptr_x1[i]), S1);
|
|
166
|
+
|
|
167
|
+
HVX_Vector xs_p_lo = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xs0_p), Q6_V_lo_W(xs1_p));
|
|
168
|
+
HVX_Vector xs_p_hi = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_hi_W(xs0_p), Q6_V_hi_W(xs1_p));
|
|
169
|
+
|
|
170
|
+
ptr_y[i * 2] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(xs_p_lo, ptr_y[i * 2]));
|
|
171
|
+
ptr_y[i * 2 + 1] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(xs_p_hi, ptr_y[i * 2 + 1]));
|
|
172
|
+
}
|
|
173
|
+
|
|
174
|
+
if (nloe) {
|
|
175
|
+
HVX_VectorPair xs0_p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(ptr_x0[i]), S0);
|
|
176
|
+
HVX_VectorPair xs1_p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(ptr_x1[i]), S1);
|
|
177
|
+
|
|
178
|
+
HVX_Vector xs_p_lo = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xs0_p), Q6_V_lo_W(xs1_p));
|
|
179
|
+
HVX_Vector xs = xs_p_lo;
|
|
180
|
+
i = 2 * i; // index for ptr_y
|
|
181
|
+
|
|
182
|
+
if (nloe >= 32) {
|
|
183
|
+
ptr_y[i] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(xs, ptr_y[i]));
|
|
184
|
+
nloe -= 32; ++i;
|
|
185
|
+
xs = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_hi_W(xs0_p), Q6_V_hi_W(xs1_p));
|
|
186
|
+
}
|
|
187
|
+
|
|
188
|
+
if (nloe) {
|
|
189
|
+
HVX_Vector xy = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(xs, ptr_y[i]));
|
|
190
|
+
hvx_vec_store_a(&ptr_y[i], nloe * 4, xy);
|
|
191
|
+
}
|
|
192
|
+
}
|
|
193
|
+
}
|
|
194
|
+
|
|
195
|
+
#define FLASH_ATTN_BLOCK_SIZE 128
|
|
196
|
+
|
|
197
|
+
struct htp_fa_context {
|
|
198
|
+
const struct htp_ops_context * octx;
|
|
199
|
+
|
|
200
|
+
struct fastdiv_values src0_div21;
|
|
201
|
+
struct fastdiv_values src0_div1;
|
|
202
|
+
|
|
203
|
+
struct fastdiv_values broadcast_rk2;
|
|
204
|
+
struct fastdiv_values broadcast_rk3;
|
|
205
|
+
struct fastdiv_values broadcast_rv2;
|
|
206
|
+
struct fastdiv_values broadcast_rv3;
|
|
207
|
+
|
|
208
|
+
struct fastdiv_values src3_div2;
|
|
209
|
+
struct fastdiv_values src3_div3;
|
|
210
|
+
|
|
211
|
+
float scale;
|
|
212
|
+
float max_bias;
|
|
213
|
+
float logit_softcap;
|
|
214
|
+
|
|
215
|
+
uint32_t n_head_log2;
|
|
216
|
+
float m0;
|
|
217
|
+
float m1;
|
|
218
|
+
|
|
219
|
+
uint32_t n_blocks;
|
|
220
|
+
|
|
221
|
+
size_t size_q_row_padded;
|
|
222
|
+
size_t size_k_row_padded;
|
|
223
|
+
size_t size_v_row_padded;
|
|
224
|
+
|
|
225
|
+
size_t size_k_block;
|
|
226
|
+
size_t size_v_block;
|
|
227
|
+
size_t size_m_block;
|
|
228
|
+
|
|
229
|
+
bool is_q_fp32;
|
|
230
|
+
};
|
|
231
|
+
|
|
232
|
+
// dst[i] = src[i] * vs (lane-wise fp32 multiply) for n floats.
// Both buffers must be 128-byte aligned (asserted). The trailing partial vector
// is stored with hvx_vec_store_a, limited to the remaining bytes.
// NOTE(review): callers in this file pass the same buffer as dst and src, even
// though both parameters are restrict-qualified; the access pattern is
// index-for-index, but confirm this is intended.
static inline void hvx_scale_vec_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const int n, HVX_Vector vs) {
    assert((size_t) dst % 128 == 0);
    assert((size_t) src % 128 == 0);

    const HVX_Vector * restrict in  = (const HVX_Vector * restrict) src;
    HVX_Vector * restrict       out = (HVX_Vector * restrict) dst;

    const uint32_t full_vecs = n / VLEN_FP32; // complete fp32 vectors
    const uint32_t tail      = n % VLEN_FP32; // floats in the trailing partial vector

    uint32_t k = 0;
    #pragma unroll(4)
    for (; k < full_vecs; ++k) {
        out[k] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(in[k], vs));
    }

    if (tail != 0) {
        HVX_Vector prod = Q6_Vqf32_vmpy_VsfVsf(in[k], vs);
        hvx_vec_store_a(&out[k], tail * sizeof(float), Q6_Vsf_equals_Vqf32(prod));
    }
}
|
|
252
|
+
|
|
253
|
+
static void flash_attn_ext_f16_thread(unsigned int nth, unsigned int ith, void * data) {
|
|
254
|
+
struct htp_fa_context * factx = (struct htp_fa_context *) data;
|
|
255
|
+
const struct htp_ops_context * octx = factx->octx;
|
|
256
|
+
const struct htp_tensor * q = &octx->src0;
|
|
257
|
+
const struct htp_tensor * k = &octx->src1;
|
|
258
|
+
const struct htp_tensor * v = &octx->src2;
|
|
259
|
+
const struct htp_tensor * mask = (octx->src3.data) ? &octx->src3 : NULL;
|
|
260
|
+
const struct htp_tensor * sinks = (octx->src4.data) ? &octx->src4 : NULL;
|
|
261
|
+
const struct htp_tensor * dst = &octx->dst;
|
|
262
|
+
|
|
263
|
+
const uint32_t neq0 = q->ne[0];
|
|
264
|
+
const uint32_t neq1 = q->ne[1];
|
|
265
|
+
const uint32_t neq2 = q->ne[2];
|
|
266
|
+
const uint32_t neq3 = q->ne[3];
|
|
267
|
+
|
|
268
|
+
const uint32_t nek0 = k->ne[0];
|
|
269
|
+
const uint32_t nek1 = k->ne[1];
|
|
270
|
+
const uint32_t nek2 = k->ne[2];
|
|
271
|
+
const uint32_t nek3 = k->ne[3];
|
|
272
|
+
|
|
273
|
+
const uint32_t nev0 = v->ne[0];
|
|
274
|
+
const uint32_t nev1 = v->ne[1];
|
|
275
|
+
const uint32_t nev2 = v->ne[2];
|
|
276
|
+
const uint32_t nev3 = v->ne[3];
|
|
277
|
+
|
|
278
|
+
const uint32_t nbq1 = q->nb[1];
|
|
279
|
+
const uint32_t nbq2 = q->nb[2];
|
|
280
|
+
const uint32_t nbq3 = q->nb[3];
|
|
281
|
+
|
|
282
|
+
const uint32_t nbk1 = k->nb[1];
|
|
283
|
+
const uint32_t nbk2 = k->nb[2];
|
|
284
|
+
const uint32_t nbk3 = k->nb[3];
|
|
285
|
+
|
|
286
|
+
const uint32_t nbv1 = v->nb[1];
|
|
287
|
+
const uint32_t nbv2 = v->nb[2];
|
|
288
|
+
const uint32_t nbv3 = v->nb[3];
|
|
289
|
+
|
|
290
|
+
const uint32_t ne1 = dst->ne[1];
|
|
291
|
+
const uint32_t ne2 = dst->ne[2];
|
|
292
|
+
const uint32_t ne3 = dst->ne[3];
|
|
293
|
+
|
|
294
|
+
const uint32_t nb1 = dst->nb[1];
|
|
295
|
+
const uint32_t nb2 = dst->nb[2];
|
|
296
|
+
const uint32_t nb3 = dst->nb[3];
|
|
297
|
+
|
|
298
|
+
// total rows in q
|
|
299
|
+
const uint32_t nr = neq1*neq2*neq3;
|
|
300
|
+
|
|
301
|
+
const uint32_t dr = (nr + nth - 1) / nth;
|
|
302
|
+
const uint32_t ir0 = dr * ith;
|
|
303
|
+
const uint32_t ir1 = MIN(ir0 + dr, nr);
|
|
304
|
+
|
|
305
|
+
if (ir0 >= ir1) return;
|
|
306
|
+
|
|
307
|
+
dma_queue * dma = octx->ctx->dma[ith];
|
|
308
|
+
|
|
309
|
+
const uint32_t DK = nek0;
|
|
310
|
+
const uint32_t DV = nev0;
|
|
311
|
+
|
|
312
|
+
const size_t size_q_row = DK * ((q->type == HTP_TYPE_F32) ? 4 : 2);
|
|
313
|
+
const size_t size_k_row = DK * sizeof(__fp16);
|
|
314
|
+
const size_t size_v_row = DV * sizeof(__fp16);
|
|
315
|
+
|
|
316
|
+
// Scratchpad buffers for Q, K, V, Mask, and VKQ32 accumulator
|
|
317
|
+
uint8_t * spad_q = octx->src0_spad.data + octx->src0_spad.size_per_thread * ith;
|
|
318
|
+
uint8_t * spad_k = octx->src1_spad.data + octx->src1_spad.size_per_thread * ith;
|
|
319
|
+
uint8_t * spad_v = octx->src2_spad.data + octx->src2_spad.size_per_thread * ith;
|
|
320
|
+
uint8_t * spad_m = octx->src3_spad.data + octx->src3_spad.size_per_thread * ith;
|
|
321
|
+
uint8_t * spad_a = octx->dst_spad.data + octx->dst_spad.size_per_thread * ith;
|
|
322
|
+
|
|
323
|
+
const HVX_Vector logit_cap = hvx_vec_splat_f32(factx->logit_softcap);
|
|
324
|
+
|
|
325
|
+
for (uint32_t ir = ir0; ir < ir1; ++ir) {
|
|
326
|
+
const uint32_t iq3 = fastdiv(ir, &factx->src0_div21);
|
|
327
|
+
const uint32_t iq2 = fastdiv(ir - iq3*neq2*neq1, &factx->src0_div1);
|
|
328
|
+
const uint32_t iq1 = (ir - iq3*neq2*neq1 - iq2 * neq1);
|
|
329
|
+
|
|
330
|
+
const uint32_t ik3 = fastdiv(iq3, &factx->broadcast_rk3);
|
|
331
|
+
const uint32_t ik2 = fastdiv(iq2, &factx->broadcast_rk2);
|
|
332
|
+
|
|
333
|
+
const uint32_t iv3 = fastdiv(iq3, &factx->broadcast_rv3);
|
|
334
|
+
const uint32_t iv2 = fastdiv(iq2, &factx->broadcast_rv2);
|
|
335
|
+
|
|
336
|
+
// Fetch Q row
|
|
337
|
+
const uint8_t * q_row_ptr = (const uint8_t *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3);
|
|
338
|
+
dma_queue_push(dma, dma_make_ptr(spad_q, q_row_ptr), factx->size_q_row_padded, nbq1, size_q_row, 1);
|
|
339
|
+
|
|
340
|
+
const uint32_t h = iq2; // head index
|
|
341
|
+
const float slope = (factx->max_bias > 0.0f) ? (h < factx->n_head_log2 ? powf(factx->m0, h + 1) : powf(factx->m1, 2*(h - factx->n_head_log2) + 1)) : 1.0f;
|
|
342
|
+
|
|
343
|
+
HVX_Vector S_vec = hvx_vec_splat_f32(0.0f);
|
|
344
|
+
HVX_Vector M_vec = hvx_vec_splat_f32(-INFINITY);
|
|
345
|
+
|
|
346
|
+
// Clear accumulator
|
|
347
|
+
hvx_splat_f32_a(spad_a, 0, DV);
|
|
348
|
+
float * VKQ32 = (float *) spad_a;
|
|
349
|
+
|
|
350
|
+
const __fp16 * mp_base = NULL;
|
|
351
|
+
if (mask) {
|
|
352
|
+
const uint32_t im2 = fastmodulo(iq2, mask->ne[2], &factx->src3_div2);
|
|
353
|
+
const uint32_t im3 = fastmodulo(iq3, mask->ne[3], &factx->src3_div3);
|
|
354
|
+
mp_base = (const __fp16 *) ((const uint8_t *) mask->data + iq1*mask->nb[1] + im2*mask->nb[2] + im3*mask->nb[3]);
|
|
355
|
+
}
|
|
356
|
+
|
|
357
|
+
// Prefetch first two blocks
|
|
358
|
+
for (uint32_t ib = 0; ib < MIN(factx->n_blocks, 2); ++ib) {
|
|
359
|
+
const uint32_t ic_start = ib * FLASH_ATTN_BLOCK_SIZE;
|
|
360
|
+
const uint32_t current_block_size = MIN(FLASH_ATTN_BLOCK_SIZE, nek1 - ic_start);
|
|
361
|
+
|
|
362
|
+
// K
|
|
363
|
+
const uint8_t * k_src = (const uint8_t *) k->data + (ic_start*nbk1 + ik2*nbk2 + ik3*nbk3);
|
|
364
|
+
uint8_t * k_dst = spad_k + (ib % 2) * factx->size_k_block;
|
|
365
|
+
dma_queue_push(dma, dma_make_ptr(k_dst, k_src), factx->size_k_row_padded, nbk1, size_k_row, current_block_size);
|
|
366
|
+
|
|
367
|
+
// V
|
|
368
|
+
const uint8_t * v_src = (const uint8_t *) v->data + (ic_start*nbv1 + iv2*nbv2 + iv3*nbv3);
|
|
369
|
+
uint8_t * v_dst = spad_v + (ib % 2) * factx->size_v_block;
|
|
370
|
+
dma_queue_push(dma, dma_make_ptr(v_dst, v_src), factx->size_v_row_padded, nbv1, size_v_row, current_block_size);
|
|
371
|
+
|
|
372
|
+
// Mask
|
|
373
|
+
if (mask) {
|
|
374
|
+
const uint8_t * m_src = (const uint8_t *) (mp_base + ic_start);
|
|
375
|
+
uint8_t * m_dst = spad_m + (ib % 2) * factx->size_m_block;
|
|
376
|
+
// Mask is 1D contiguous for this row
|
|
377
|
+
dma_queue_push(dma, dma_make_ptr(m_dst, m_src), current_block_size * 2, current_block_size * 2, current_block_size * 2, 1);
|
|
378
|
+
}
|
|
379
|
+
}
|
|
380
|
+
|
|
381
|
+
uint8_t * q_ptr_vtcm = dma_queue_pop(dma).dst;
|
|
382
|
+
if (factx->is_q_fp32) {
|
|
383
|
+
hvx_copy_f16_f32_aa(q_ptr_vtcm, q_ptr_vtcm, DK); // inplace convert f32 to f16
|
|
384
|
+
}
|
|
385
|
+
|
|
386
|
+
const HVX_Vector slope_vec = hvx_vec_splat_f16(slope);
|
|
387
|
+
for (uint32_t ib = 0; ib < factx->n_blocks; ++ib) {
|
|
388
|
+
const uint32_t ic_start = ib * FLASH_ATTN_BLOCK_SIZE;
|
|
389
|
+
const uint32_t current_block_size = MIN(FLASH_ATTN_BLOCK_SIZE, nek1 - ic_start);
|
|
390
|
+
|
|
391
|
+
// Wait for DMA
|
|
392
|
+
uint8_t * k_base = dma_queue_pop(dma).dst; // K
|
|
393
|
+
uint8_t * v_base = dma_queue_pop(dma).dst; // V
|
|
394
|
+
__fp16 * m_base = mask ? dma_queue_pop(dma).dst : NULL; // M
|
|
395
|
+
|
|
396
|
+
// Inner loop processing the block from VTCM
|
|
397
|
+
uint32_t ic = 0;
|
|
398
|
+
|
|
399
|
+
// Process in blocks of 32 (VLEN_FP32)
|
|
400
|
+
static_assert(FLASH_ATTN_BLOCK_SIZE / VLEN_FP32 <= 4, "FLASH_ATTN_BLOCK_SIZE changed, fix HVX_Vector_x4 usage");
|
|
401
|
+
HVX_Vector_x4 scores_x4;
|
|
402
|
+
HVX_Vector v_max = hvx_vec_splat_f32(-INFINITY);
|
|
403
|
+
for (uint32_t iv = 0; ic + VLEN_FP32 <= current_block_size; ic += VLEN_FP32, ++iv) {
|
|
404
|
+
// 1. Compute scores
|
|
405
|
+
float __attribute__((aligned(VLEN))) scores_arr[VLEN_FP32];
|
|
406
|
+
for (uint32_t j = 0; j < VLEN_FP32; j += 2) {
|
|
407
|
+
const uint32_t cur_ic = ic + j;
|
|
408
|
+
const uint8_t * k_ptr = k_base + cur_ic * factx->size_k_row_padded;
|
|
409
|
+
hvx_dot_f16_f16_aa_rx2(&scores_arr[j], q_ptr_vtcm, k_ptr, k_ptr + factx->size_k_row_padded, DK, factx->scale);
|
|
410
|
+
}
|
|
411
|
+
|
|
412
|
+
HVX_Vector scores = *(HVX_Vector *) scores_arr;
|
|
413
|
+
|
|
414
|
+
// 2. Softcap
|
|
415
|
+
if (factx->logit_softcap != 0.0f) {
|
|
416
|
+
scores = hvx_vec_tanh_f32(scores);
|
|
417
|
+
scores = Q6_Vqf32_vmpy_VsfVsf(scores, logit_cap);
|
|
418
|
+
scores = Q6_Vsf_equals_Vqf32(scores);
|
|
419
|
+
}
|
|
420
|
+
|
|
421
|
+
// 3. Mask
|
|
422
|
+
if (mask) {
|
|
423
|
+
const __fp16 * mp = m_base + ic;
|
|
424
|
+
HVX_Vector m_vals_f16 = *(const HVX_UVector *) mp;
|
|
425
|
+
HVX_VectorPair m_vals_f32_pair = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(m_vals_f16), slope_vec);
|
|
426
|
+
HVX_Vector add_val = Q6_V_lo_W(m_vals_f32_pair);
|
|
427
|
+
scores = Q6_Vqf32_vadd_Vqf32Vsf(add_val, scores);
|
|
428
|
+
scores = Q6_Vsf_equals_Vqf32(scores);
|
|
429
|
+
}
|
|
430
|
+
|
|
431
|
+
scores_x4.v[iv] = scores;
|
|
432
|
+
v_max = hvx_vec_reduce_max2_f32(scores, v_max); // All lanes have block max
|
|
433
|
+
}
|
|
434
|
+
|
|
435
|
+
{
|
|
436
|
+
// 4. Online Softmax Update
|
|
437
|
+
HVX_Vector M_new_vec = Q6_Vsf_vmax_VsfVsf(v_max, M_vec);
|
|
438
|
+
HVX_Vector diff_vec = Q6_Vqf32_vsub_VsfVsf(M_vec, M_new_vec);
|
|
439
|
+
HVX_Vector ms_vec = hvx_vec_exp_f32(Q6_Vsf_equals_Vqf32(diff_vec));
|
|
440
|
+
M_vec = M_new_vec;
|
|
441
|
+
|
|
442
|
+
hvx_scale_vec_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms_vec);
|
|
443
|
+
|
|
444
|
+
HVX_Vector p_sum_vec = hvx_vec_splat_f32(0.0f);
|
|
445
|
+
for (uint32_t ic2 = 0, iv = 0; ic2 + VLEN_FP32 <= current_block_size; ic2 += VLEN_FP32, ++iv) {
|
|
446
|
+
HVX_Vector scores = scores_x4.v[iv];
|
|
447
|
+
HVX_Vector scores_shifted = Q6_Vqf32_vsub_VsfVsf(scores, M_vec);
|
|
448
|
+
HVX_Vector P = hvx_vec_exp_f32(Q6_Vsf_equals_Vqf32(scores_shifted));
|
|
449
|
+
|
|
450
|
+
p_sum_vec = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(p_sum_vec, P));
|
|
451
|
+
|
|
452
|
+
// 5. Accumulate V
|
|
453
|
+
float __attribute__((aligned(VLEN))) p_arr[VLEN_FP32];
|
|
454
|
+
*(HVX_Vector *) p_arr = P;
|
|
455
|
+
|
|
456
|
+
for (uint32_t j = 0; j < VLEN_FP32; j += 2) {
|
|
457
|
+
const uint32_t cur_ic = ic2 + j;
|
|
458
|
+
const uint8_t * v_ptr = v_base + cur_ic * factx->size_v_row_padded;
|
|
459
|
+
hvx_mad_f32_f16_aa_rx2(VKQ32, v_ptr, v_ptr + factx->size_v_row_padded, p_arr[j], p_arr[j + 1], DV);
|
|
460
|
+
}
|
|
461
|
+
}
|
|
462
|
+
|
|
463
|
+
p_sum_vec = hvx_vec_reduce_sum_f32(p_sum_vec);
|
|
464
|
+
S_vec = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(S_vec, ms_vec)), p_sum_vec));
|
|
465
|
+
}
|
|
466
|
+
|
|
467
|
+
// Sync scalars for leftover/next block if needed
|
|
468
|
+
float M = hvx_vec_get_f32(M_vec);
|
|
469
|
+
float S = hvx_vec_get_f32(S_vec);
|
|
470
|
+
|
|
471
|
+
// Leftover
|
|
472
|
+
for (; ic < current_block_size; ++ic) {
|
|
473
|
+
float s_val;
|
|
474
|
+
const uint8_t * k_ptr = k_base + ic * factx->size_k_row_padded;
|
|
475
|
+
hvx_dot_f16_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, factx->scale);
|
|
476
|
+
if (factx->logit_softcap != 0.0f) {
|
|
477
|
+
s_val = factx->logit_softcap * tanhf(s_val);
|
|
478
|
+
}
|
|
479
|
+
|
|
480
|
+
if (mask) {
|
|
481
|
+
const float m_val = m_base[ic];
|
|
482
|
+
s_val += slope * m_val;
|
|
483
|
+
}
|
|
484
|
+
|
|
485
|
+
const float Mold = M;
|
|
486
|
+
float vs = 1.0f;
|
|
487
|
+
|
|
488
|
+
if (s_val > M) {
|
|
489
|
+
M = s_val;
|
|
490
|
+
HVX_Vector diff_vec = hvx_vec_splat_f32(Mold - M);
|
|
491
|
+
HVX_Vector ms_vec = hvx_vec_exp_f32(diff_vec);
|
|
492
|
+
hvx_scale_vec_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms_vec);
|
|
493
|
+
|
|
494
|
+
float ms = hvx_vec_get_f32(ms_vec);
|
|
495
|
+
S = S * ms + vs;
|
|
496
|
+
} else {
|
|
497
|
+
HVX_Vector diff_vec = hvx_vec_splat_f32(s_val - M);
|
|
498
|
+
vs = hvx_vec_get_f32(hvx_vec_exp_f32(diff_vec));
|
|
499
|
+
S += vs;
|
|
500
|
+
}
|
|
501
|
+
|
|
502
|
+
const uint8_t * v_ptr = v_base + ic * factx->size_v_row_padded;
|
|
503
|
+
|
|
504
|
+
hvx_mad_f32_f16_aa(VKQ32, v_ptr, DV, vs);
|
|
505
|
+
}
|
|
506
|
+
M_vec = hvx_vec_splat_f32(M);
|
|
507
|
+
S_vec = hvx_vec_splat_f32(S);
|
|
508
|
+
|
|
509
|
+
// Issue DMA for next+1 block (if exists)
|
|
510
|
+
if (ib + 2 < factx->n_blocks) {
|
|
511
|
+
const uint32_t next_ib = ib + 2;
|
|
512
|
+
const uint32_t next_ic_start = next_ib * FLASH_ATTN_BLOCK_SIZE;
|
|
513
|
+
const uint32_t next_block_size = MIN(FLASH_ATTN_BLOCK_SIZE, nek1 - next_ic_start);
|
|
514
|
+
|
|
515
|
+
// K
|
|
516
|
+
const uint8_t * k_src = (const uint8_t *) k->data + (next_ic_start*nbk1 + ik2*nbk2 + ik3*nbk3);
|
|
517
|
+
dma_queue_push(dma, dma_make_ptr(k_base, k_src), factx->size_k_row_padded, nbk1, size_k_row, next_block_size);
|
|
518
|
+
|
|
519
|
+
// V
|
|
520
|
+
const uint8_t * v_src = (const uint8_t *) v->data + (next_ic_start*nbv1 + iv2*nbv2 + iv3*nbv3);
|
|
521
|
+
dma_queue_push(dma, dma_make_ptr(v_base, v_src), factx->size_v_row_padded, nbv1, size_v_row, next_block_size);
|
|
522
|
+
|
|
523
|
+
// Mask
|
|
524
|
+
if (mask) {
|
|
525
|
+
const uint8_t * m_src = (const uint8_t *) (mp_base + next_ic_start);
|
|
526
|
+
dma_queue_push(dma, dma_make_ptr(m_base, m_src), next_block_size * 2, next_block_size * 2, next_block_size * 2, 1);
|
|
527
|
+
}
|
|
528
|
+
}
|
|
529
|
+
}
|
|
530
|
+
|
|
531
|
+
// sinks
|
|
532
|
+
float M = hvx_vec_get_f32(M_vec);
|
|
533
|
+
float S = hvx_vec_get_f32(S_vec);
|
|
534
|
+
|
|
535
|
+
if (sinks) {
|
|
536
|
+
const float s = ((float *)((char *) sinks->data))[h];
|
|
537
|
+
|
|
538
|
+
float vs = 1.0f;
|
|
539
|
+
|
|
540
|
+
if (s > M) {
|
|
541
|
+
HVX_Vector diff_vec = hvx_vec_splat_f32(M - s);
|
|
542
|
+
HVX_Vector ms_vec = hvx_vec_exp_f32(diff_vec);
|
|
543
|
+
hvx_scale_vec_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms_vec);
|
|
544
|
+
|
|
545
|
+
float ms = hvx_vec_get_f32(ms_vec);
|
|
546
|
+
S = S * ms + vs;
|
|
547
|
+
} else {
|
|
548
|
+
HVX_Vector diff_vec = hvx_vec_splat_f32(s - M);
|
|
549
|
+
vs = hvx_vec_get_f32(hvx_vec_exp_f32(diff_vec));
|
|
550
|
+
S += vs;
|
|
551
|
+
}
|
|
552
|
+
}
|
|
553
|
+
|
|
554
|
+
const float S_inv = S == 0.0f ? 0.0f : 1.0f/S;
|
|
555
|
+
hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, S_inv);
|
|
556
|
+
|
|
557
|
+
// Store result
|
|
558
|
+
// dst indices
|
|
559
|
+
const int i1 = iq1;
|
|
560
|
+
const int i2 = iq2;
|
|
561
|
+
const int i3 = iq3;
|
|
562
|
+
|
|
563
|
+
// dst is permuted
|
|
564
|
+
uint8_t * dst_ptr = (uint8_t *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1) * nb1;
|
|
565
|
+
|
|
566
|
+
if (dst->type == HTP_TYPE_F32) {
|
|
567
|
+
hvx_copy_f32_ua(dst_ptr, (uint8_t *) VKQ32, DV);
|
|
568
|
+
} else if (dst->type == HTP_TYPE_F16) {
|
|
569
|
+
hvx_copy_f16_f32_ua(dst_ptr, (uint8_t *) VKQ32, DV);
|
|
570
|
+
}
|
|
571
|
+
}
|
|
572
|
+
}
|
|
573
|
+
|
|
574
|
+
int op_flash_attn_ext(struct htp_ops_context * octx) {
|
|
575
|
+
const struct htp_tensor * q = &octx->src0;
|
|
576
|
+
const struct htp_tensor * k = &octx->src1;
|
|
577
|
+
const struct htp_tensor * v = &octx->src2;
|
|
578
|
+
const struct htp_tensor * mask = (octx->src3.data) ? &octx->src3 : NULL;
|
|
579
|
+
const struct htp_tensor * dst = &octx->dst;
|
|
580
|
+
|
|
581
|
+
// Check support
|
|
582
|
+
if ((q->type != HTP_TYPE_F16 && q->type != HTP_TYPE_F32) || k->type != HTP_TYPE_F16 || v->type != HTP_TYPE_F16) {
|
|
583
|
+
return HTP_STATUS_NO_SUPPORT;
|
|
584
|
+
}
|
|
585
|
+
|
|
586
|
+
struct htp_fa_context factx;
|
|
587
|
+
factx.octx = octx;
|
|
588
|
+
|
|
589
|
+
factx.src0_div21 = init_fastdiv_values(q->ne[2] * q->ne[1]);
|
|
590
|
+
factx.src0_div1 = init_fastdiv_values(q->ne[1]);
|
|
591
|
+
|
|
592
|
+
factx.broadcast_rk2 = init_fastdiv_values(q->ne[2]/k->ne[2]);
|
|
593
|
+
factx.broadcast_rk3 = init_fastdiv_values(q->ne[3]/k->ne[3]);
|
|
594
|
+
factx.broadcast_rv2 = init_fastdiv_values(q->ne[2]/v->ne[2]);
|
|
595
|
+
factx.broadcast_rv3 = init_fastdiv_values(q->ne[3]/v->ne[3]);
|
|
596
|
+
|
|
597
|
+
if (mask) {
|
|
598
|
+
factx.src3_div2 = init_fastdiv_values(mask->ne[2]);
|
|
599
|
+
factx.src3_div3 = init_fastdiv_values(mask->ne[3]);
|
|
600
|
+
}
|
|
601
|
+
|
|
602
|
+
factx.is_q_fp32 = (q->type == HTP_TYPE_F32);
|
|
603
|
+
factx.size_q_row_padded = hex_round_up(q->ne[0] * (factx.is_q_fp32 ? 4 : 2), 128);
|
|
604
|
+
factx.size_k_row_padded = hex_round_up(k->ne[0] * sizeof(__fp16), 128);
|
|
605
|
+
factx.size_v_row_padded = hex_round_up(v->ne[0] * sizeof(__fp16), 128);
|
|
606
|
+
|
|
607
|
+
size_t size_q_block = factx.size_q_row_padded * 1; // single row for now
|
|
608
|
+
factx.size_k_block = factx.size_k_row_padded * FLASH_ATTN_BLOCK_SIZE;
|
|
609
|
+
factx.size_v_block = factx.size_v_row_padded * FLASH_ATTN_BLOCK_SIZE;
|
|
610
|
+
factx.size_m_block = hex_round_up(FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16), 128);
|
|
611
|
+
|
|
612
|
+
factx.n_blocks = (k->ne[1] + FLASH_ATTN_BLOCK_SIZE - 1) / FLASH_ATTN_BLOCK_SIZE;
|
|
613
|
+
|
|
614
|
+
float scale = 1.0f;
|
|
615
|
+
float max_bias = 0.0f;
|
|
616
|
+
float logit_softcap = 0.0f;
|
|
617
|
+
|
|
618
|
+
memcpy(&scale, (float *) octx->op_params + 0, sizeof(float));
|
|
619
|
+
memcpy(&max_bias, (float *) octx->op_params + 1, sizeof(float));
|
|
620
|
+
memcpy(&logit_softcap, (float *) octx->op_params + 2, sizeof(float));
|
|
621
|
+
|
|
622
|
+
if (logit_softcap != 0.0f) {
|
|
623
|
+
scale /= logit_softcap;
|
|
624
|
+
}
|
|
625
|
+
|
|
626
|
+
factx.scale = scale;
|
|
627
|
+
factx.max_bias = max_bias;
|
|
628
|
+
factx.logit_softcap = logit_softcap;
|
|
629
|
+
|
|
630
|
+
uint32_t n_head = q->ne[2];
|
|
631
|
+
factx.n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
|
|
632
|
+
factx.m0 = powf(2.0f, -(max_bias ) / factx.n_head_log2);
|
|
633
|
+
factx.m1 = powf(2.0f, -(max_bias / 2.0f) / factx.n_head_log2);
|
|
634
|
+
|
|
635
|
+
size_t size_vkq_acc = hex_round_up(v->ne[0] * sizeof(float), 128); // VKQ32
|
|
636
|
+
|
|
637
|
+
octx->src0_spad.size_per_thread = size_q_block * 1;
|
|
638
|
+
octx->src1_spad.size_per_thread = factx.size_k_block * 2;
|
|
639
|
+
octx->src2_spad.size_per_thread = factx.size_v_block * 2;
|
|
640
|
+
octx->src3_spad.size_per_thread = mask ? factx.size_m_block * 2 : 0;
|
|
641
|
+
octx->dst_spad.size_per_thread = size_vkq_acc;
|
|
642
|
+
|
|
643
|
+
octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
|
|
644
|
+
octx->src1_spad.size = octx->src1_spad.size_per_thread * octx->n_threads;
|
|
645
|
+
octx->src2_spad.size = octx->src2_spad.size_per_thread * octx->n_threads;
|
|
646
|
+
octx->src3_spad.size = octx->src3_spad.size_per_thread * octx->n_threads;
|
|
647
|
+
octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads;
|
|
648
|
+
|
|
649
|
+
size_t total_spad = octx->src0_spad.size + octx->src1_spad.size + octx->src2_spad.size + octx->src3_spad.size + octx->dst_spad.size;
|
|
650
|
+
|
|
651
|
+
if (octx->ctx->vtcm_size < total_spad) {
|
|
652
|
+
return HTP_STATUS_VTCM_TOO_SMALL;
|
|
653
|
+
}
|
|
654
|
+
|
|
655
|
+
octx->src0_spad.data = octx->ctx->vtcm_base;
|
|
656
|
+
octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size;
|
|
657
|
+
octx->src2_spad.data = octx->src1_spad.data + octx->src1_spad.size;
|
|
658
|
+
octx->src3_spad.data = octx->src2_spad.data + octx->src2_spad.size;
|
|
659
|
+
octx->dst_spad.data = octx->src3_spad.data + octx->src3_spad.size;
|
|
660
|
+
|
|
661
|
+
if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
|
|
662
|
+
worker_pool_run_func(octx->ctx->worker_pool, flash_attn_ext_f16_thread, &factx, octx->n_threads);
|
|
663
|
+
}
|
|
664
|
+
|
|
665
|
+
return HTP_STATUS_OK;
|
|
666
|
+
}
|