local-llm-rn 1.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/cpp/CMakeLists.txt +285 -0
- package/cpp/common/CMakeLists.txt +149 -0
- package/cpp/common/arg.cpp +3799 -0
- package/cpp/common/arg.h +131 -0
- package/cpp/common/base64.hpp +392 -0
- package/cpp/common/build-info.cpp.in +4 -0
- package/cpp/common/chat-parser-xml-toolcall.cpp +879 -0
- package/cpp/common/chat-parser-xml-toolcall.h +45 -0
- package/cpp/common/chat-parser.cpp +1649 -0
- package/cpp/common/chat-parser.h +133 -0
- package/cpp/common/chat-peg-parser.cpp +124 -0
- package/cpp/common/chat-peg-parser.h +105 -0
- package/cpp/common/chat.cpp +3355 -0
- package/cpp/common/chat.h +252 -0
- package/cpp/common/common.cpp +1824 -0
- package/cpp/common/common.h +930 -0
- package/cpp/common/console.cpp +1137 -0
- package/cpp/common/console.h +41 -0
- package/cpp/common/debug.cpp +167 -0
- package/cpp/common/debug.h +43 -0
- package/cpp/common/download.cpp +792 -0
- package/cpp/common/download.h +84 -0
- package/cpp/common/http.h +84 -0
- package/cpp/common/jinja/README.md +88 -0
- package/cpp/common/jinja/caps.cpp +285 -0
- package/cpp/common/jinja/caps.h +30 -0
- package/cpp/common/jinja/lexer.cpp +341 -0
- package/cpp/common/jinja/lexer.h +157 -0
- package/cpp/common/jinja/parser.cpp +591 -0
- package/cpp/common/jinja/parser.h +21 -0
- package/cpp/common/jinja/runtime.cpp +867 -0
- package/cpp/common/jinja/runtime.h +638 -0
- package/cpp/common/jinja/string.cpp +213 -0
- package/cpp/common/jinja/string.h +61 -0
- package/cpp/common/jinja/utils.h +149 -0
- package/cpp/common/jinja/value.cpp +1393 -0
- package/cpp/common/jinja/value.h +756 -0
- package/cpp/common/json-partial.cpp +324 -0
- package/cpp/common/json-partial.h +39 -0
- package/cpp/common/json-schema-to-grammar.cpp +1153 -0
- package/cpp/common/json-schema-to-grammar.h +43 -0
- package/cpp/common/llguidance.cpp +258 -0
- package/cpp/common/log.cpp +446 -0
- package/cpp/common/log.h +119 -0
- package/cpp/common/ngram-cache.cpp +285 -0
- package/cpp/common/ngram-cache.h +101 -0
- package/cpp/common/ngram-map.cpp +530 -0
- package/cpp/common/ngram-map.h +115 -0
- package/cpp/common/ngram-mod.cpp +60 -0
- package/cpp/common/ngram-mod.h +38 -0
- package/cpp/common/peg-parser.cpp +1712 -0
- package/cpp/common/peg-parser.h +459 -0
- package/cpp/common/preset.cpp +483 -0
- package/cpp/common/preset.h +83 -0
- package/cpp/common/regex-partial.cpp +204 -0
- package/cpp/common/regex-partial.h +56 -0
- package/cpp/common/sampling.cpp +745 -0
- package/cpp/common/sampling.h +119 -0
- package/cpp/common/speculative.cpp +1074 -0
- package/cpp/common/speculative.h +41 -0
- package/cpp/common/unicode.cpp +64 -0
- package/cpp/common/unicode.h +22 -0
- package/cpp/ggml/CMakeLists.txt +494 -0
- package/cpp/ggml/cmake/GitVars.cmake +22 -0
- package/cpp/ggml/cmake/common.cmake +50 -0
- package/cpp/ggml/cmake/ggml-config.cmake.in +191 -0
- package/cpp/ggml/include/ggml-alloc.h +85 -0
- package/cpp/ggml/include/ggml-backend.h +373 -0
- package/cpp/ggml/include/ggml-blas.h +25 -0
- package/cpp/ggml/include/ggml-cann.h +123 -0
- package/cpp/ggml/include/ggml-cpp.h +39 -0
- package/cpp/ggml/include/ggml-cpu.h +151 -0
- package/cpp/ggml/include/ggml-cuda.h +47 -0
- package/cpp/ggml/include/ggml-hexagon.h +19 -0
- package/cpp/ggml/include/ggml-metal.h +61 -0
- package/cpp/ggml/include/ggml-opencl.h +26 -0
- package/cpp/ggml/include/ggml-opt.h +256 -0
- package/cpp/ggml/include/ggml-rpc.h +30 -0
- package/cpp/ggml/include/ggml-sycl.h +49 -0
- package/cpp/ggml/include/ggml-virtgpu.h +14 -0
- package/cpp/ggml/include/ggml-vulkan.h +29 -0
- package/cpp/ggml/include/ggml-webgpu.h +19 -0
- package/cpp/ggml/include/ggml-zdnn.h +17 -0
- package/cpp/ggml/include/ggml-zendnn.h +22 -0
- package/cpp/ggml/include/ggml.h +2753 -0
- package/cpp/ggml/include/gguf.h +204 -0
- package/cpp/ggml/src/CMakeLists.txt +492 -0
- package/cpp/ggml/src/ggml-alloc.c +1244 -0
- package/cpp/ggml/src/ggml-backend-dl.cpp +48 -0
- package/cpp/ggml/src/ggml-backend-dl.h +45 -0
- package/cpp/ggml/src/ggml-backend-impl.h +255 -0
- package/cpp/ggml/src/ggml-backend-reg.cpp +566 -0
- package/cpp/ggml/src/ggml-backend.cpp +2270 -0
- package/cpp/ggml/src/ggml-blas/CMakeLists.txt +101 -0
- package/cpp/ggml/src/ggml-blas/ggml-blas.cpp +518 -0
- package/cpp/ggml/src/ggml-common.h +1878 -0
- package/cpp/ggml/src/ggml-cpu/CMakeLists.txt +691 -0
- package/cpp/ggml/src/ggml-cpu/amx/amx.cpp +247 -0
- package/cpp/ggml/src/ggml-cpu/amx/amx.h +8 -0
- package/cpp/ggml/src/ggml-cpu/amx/common.h +91 -0
- package/cpp/ggml/src/ggml-cpu/amx/mmq.cpp +2512 -0
- package/cpp/ggml/src/ggml-cpu/amx/mmq.h +10 -0
- package/cpp/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +98 -0
- package/cpp/ggml/src/ggml-cpu/arch/arm/quants.c +4052 -0
- package/cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +4935 -0
- package/cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +2159 -0
- package/cpp/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
- package/cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +2305 -0
- package/cpp/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp +38 -0
- package/cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +2726 -0
- package/cpp/ggml/src/ggml-cpu/arch/riscv/repack.cpp +342 -0
- package/cpp/ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp +50 -0
- package/cpp/ggml/src/ggml-cpu/arch/s390/quants.c +1468 -0
- package/cpp/ggml/src/ggml-cpu/arch/wasm/quants.c +1221 -0
- package/cpp/ggml/src/ggml-cpu/arch/x86/cpu-feats.cpp +327 -0
- package/cpp/ggml/src/ggml-cpu/arch/x86/quants.c +3820 -0
- package/cpp/ggml/src/ggml-cpu/arch/x86/repack.cpp +6307 -0
- package/cpp/ggml/src/ggml-cpu/arch-fallback.h +313 -0
- package/cpp/ggml/src/ggml-cpu/binary-ops.cpp +154 -0
- package/cpp/ggml/src/ggml-cpu/binary-ops.h +16 -0
- package/cpp/ggml/src/ggml-cpu/cmake/FindSIMD.cmake +100 -0
- package/cpp/ggml/src/ggml-cpu/common.h +95 -0
- package/cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +529 -0
- package/cpp/ggml/src/ggml-cpu/ggml-cpu.c +3734 -0
- package/cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +701 -0
- package/cpp/ggml/src/ggml-cpu/hbm.cpp +55 -0
- package/cpp/ggml/src/ggml-cpu/hbm.h +8 -0
- package/cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +938 -0
- package/cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +90 -0
- package/cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +798 -0
- package/cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.h +17 -0
- package/cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +4033 -0
- package/cpp/ggml/src/ggml-cpu/llamafile/sgemm.h +25 -0
- package/cpp/ggml/src/ggml-cpu/ops.cpp +10978 -0
- package/cpp/ggml/src/ggml-cpu/ops.h +116 -0
- package/cpp/ggml/src/ggml-cpu/quants.c +1193 -0
- package/cpp/ggml/src/ggml-cpu/quants.h +97 -0
- package/cpp/ggml/src/ggml-cpu/repack.cpp +3316 -0
- package/cpp/ggml/src/ggml-cpu/repack.h +173 -0
- package/cpp/ggml/src/ggml-cpu/simd-gemm.h +136 -0
- package/cpp/ggml/src/ggml-cpu/simd-mappings.h +1279 -0
- package/cpp/ggml/src/ggml-cpu/spacemit/ime.cpp +1025 -0
- package/cpp/ggml/src/ggml-cpu/spacemit/ime.h +13 -0
- package/cpp/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +3196 -0
- package/cpp/ggml/src/ggml-cpu/spacemit/ime_kernels.h +26 -0
- package/cpp/ggml/src/ggml-cpu/traits.cpp +36 -0
- package/cpp/ggml/src/ggml-cpu/traits.h +38 -0
- package/cpp/ggml/src/ggml-cpu/unary-ops.cpp +337 -0
- package/cpp/ggml/src/ggml-cpu/unary-ops.h +35 -0
- package/cpp/ggml/src/ggml-cpu/vec.cpp +629 -0
- package/cpp/ggml/src/ggml-cpu/vec.h +1585 -0
- package/cpp/ggml/src/ggml-hexagon/CMakeLists.txt +117 -0
- package/cpp/ggml/src/ggml-hexagon/ggml-hexagon.cpp +3232 -0
- package/cpp/ggml/src/ggml-hexagon/htp/CMakeLists.txt +45 -0
- package/cpp/ggml/src/ggml-hexagon/htp/act-ops.c +815 -0
- package/cpp/ggml/src/ggml-hexagon/htp/argsort-ops.c +281 -0
- package/cpp/ggml/src/ggml-hexagon/htp/binary-ops.c +827 -0
- package/cpp/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +157 -0
- package/cpp/ggml/src/ggml-hexagon/htp/cpy-ops.c +251 -0
- package/cpp/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +666 -0
- package/cpp/ggml/src/ggml-hexagon/htp/get-rows-ops.c +111 -0
- package/cpp/ggml/src/ggml-hexagon/htp/hex-dma.c +63 -0
- package/cpp/ggml/src/ggml-hexagon/htp/hex-dma.h +182 -0
- package/cpp/ggml/src/ggml-hexagon/htp/hex-dump.h +77 -0
- package/cpp/ggml/src/ggml-hexagon/htp/hex-fastdiv.h +37 -0
- package/cpp/ggml/src/ggml-hexagon/htp/hex-utils.h +51 -0
- package/cpp/ggml/src/ggml-hexagon/htp/htp-ctx.h +35 -0
- package/cpp/ggml/src/ggml-hexagon/htp/htp-msg.h +154 -0
- package/cpp/ggml/src/ggml-hexagon/htp/htp-ops.h +65 -0
- package/cpp/ggml/src/ggml-hexagon/htp/htp_iface.idl +16 -0
- package/cpp/ggml/src/ggml-hexagon/htp/hvx-arith.h +470 -0
- package/cpp/ggml/src/ggml-hexagon/htp/hvx-base.h +173 -0
- package/cpp/ggml/src/ggml-hexagon/htp/hvx-copy.h +245 -0
- package/cpp/ggml/src/ggml-hexagon/htp/hvx-div.h +116 -0
- package/cpp/ggml/src/ggml-hexagon/htp/hvx-dump.h +129 -0
- package/cpp/ggml/src/ggml-hexagon/htp/hvx-exp.h +215 -0
- package/cpp/ggml/src/ggml-hexagon/htp/hvx-floor.h +100 -0
- package/cpp/ggml/src/ggml-hexagon/htp/hvx-inverse.h +176 -0
- package/cpp/ggml/src/ggml-hexagon/htp/hvx-reduce.h +266 -0
- package/cpp/ggml/src/ggml-hexagon/htp/hvx-scale.h +133 -0
- package/cpp/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +141 -0
- package/cpp/ggml/src/ggml-hexagon/htp/hvx-sqrt.h +126 -0
- package/cpp/ggml/src/ggml-hexagon/htp/hvx-types.h +36 -0
- package/cpp/ggml/src/ggml-hexagon/htp/hvx-utils.h +18 -0
- package/cpp/ggml/src/ggml-hexagon/htp/main.c +1150 -0
- package/cpp/ggml/src/ggml-hexagon/htp/matmul-ops.c +2595 -0
- package/cpp/ggml/src/ggml-hexagon/htp/rope-ops.c +498 -0
- package/cpp/ggml/src/ggml-hexagon/htp/set-rows-ops.c +167 -0
- package/cpp/ggml/src/ggml-hexagon/htp/softmax-ops.c +421 -0
- package/cpp/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +130 -0
- package/cpp/ggml/src/ggml-hexagon/htp/unary-ops.c +384 -0
- package/cpp/ggml/src/ggml-hexagon/htp/worker-pool.c +293 -0
- package/cpp/ggml/src/ggml-hexagon/htp/worker-pool.h +57 -0
- package/cpp/ggml/src/ggml-hexagon/htp-drv.cpp +418 -0
- package/cpp/ggml/src/ggml-hexagon/htp-drv.h +121 -0
- package/cpp/ggml/src/ggml-hexagon/libdl.h +79 -0
- package/cpp/ggml/src/ggml-hexagon/libggml-htp.inf +38 -0
- package/cpp/ggml/src/ggml-hexagon/op-desc.h +153 -0
- package/cpp/ggml/src/ggml-impl.h +724 -0
- package/cpp/ggml/src/ggml-metal/CMakeLists.txt +124 -0
- package/cpp/ggml/src/ggml-metal/ggml-metal-common.cpp +457 -0
- package/cpp/ggml/src/ggml-metal/ggml-metal-common.h +52 -0
- package/cpp/ggml/src/ggml-metal/ggml-metal-context.h +41 -0
- package/cpp/ggml/src/ggml-metal/ggml-metal-context.m +702 -0
- package/cpp/ggml/src/ggml-metal/ggml-metal-device.cpp +1890 -0
- package/cpp/ggml/src/ggml-metal/ggml-metal-device.h +290 -0
- package/cpp/ggml/src/ggml-metal/ggml-metal-device.m +1749 -0
- package/cpp/ggml/src/ggml-metal/ggml-metal-impl.h +1054 -0
- package/cpp/ggml/src/ggml-metal/ggml-metal-ops.cpp +4370 -0
- package/cpp/ggml/src/ggml-metal/ggml-metal-ops.h +94 -0
- package/cpp/ggml/src/ggml-metal/ggml-metal.cpp +937 -0
- package/cpp/ggml/src/ggml-metal/ggml-metal.metal +9819 -0
- package/cpp/ggml/src/ggml-musa/CMakeLists.txt +125 -0
- package/cpp/ggml/src/ggml-musa/mudnn.cu +112 -0
- package/cpp/ggml/src/ggml-musa/mudnn.cuh +12 -0
- package/cpp/ggml/src/ggml-opencl/CMakeLists.txt +150 -0
- package/cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +11553 -0
- package/cpp/ggml/src/ggml-opencl/kernels/add.cl +190 -0
- package/cpp/ggml/src/ggml-opencl/kernels/add_id.cl +42 -0
- package/cpp/ggml/src/ggml-opencl/kernels/argsort.cl +86 -0
- package/cpp/ggml/src/ggml-opencl/kernels/clamp.cl +20 -0
- package/cpp/ggml/src/ggml-opencl/kernels/concat.cl +51 -0
- package/cpp/ggml/src/ggml-opencl/kernels/conv2d.cl +185 -0
- package/cpp/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl +176 -0
- package/cpp/ggml/src/ggml-opencl/kernels/cpy.cl +184 -0
- package/cpp/ggml/src/ggml-opencl/kernels/cvt.cl +417 -0
- package/cpp/ggml/src/ggml-opencl/kernels/diag_mask_inf.cl +58 -0
- package/cpp/ggml/src/ggml-opencl/kernels/div.cl +138 -0
- package/cpp/ggml/src/ggml-opencl/kernels/embed_kernel.py +26 -0
- package/cpp/ggml/src/ggml-opencl/kernels/expm1.cl +113 -0
- package/cpp/ggml/src/ggml-opencl/kernels/fill.cl +17 -0
- package/cpp/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl +370 -0
- package/cpp/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +371 -0
- package/cpp/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl +373 -0
- package/cpp/ggml/src/ggml-opencl/kernels/gelu.cl +89 -0
- package/cpp/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl +162 -0
- package/cpp/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl +156 -0
- package/cpp/ggml/src/ggml-opencl/kernels/gemv_noshuffle.cl +268 -0
- package/cpp/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general.cl +274 -0
- package/cpp/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general_q8_0_f32.cl +195 -0
- package/cpp/ggml/src/ggml-opencl/kernels/get_rows.cl +187 -0
- package/cpp/ggml/src/ggml-opencl/kernels/glu.cl +378 -0
- package/cpp/ggml/src/ggml-opencl/kernels/group_norm.cl +121 -0
- package/cpp/ggml/src/ggml-opencl/kernels/im2col_f16.cl +57 -0
- package/cpp/ggml/src/ggml-opencl/kernels/im2col_f32.cl +57 -0
- package/cpp/ggml/src/ggml-opencl/kernels/mean.cl +140 -0
- package/cpp/ggml/src/ggml-opencl/kernels/mul.cl +152 -0
- package/cpp/ggml/src/ggml-opencl/kernels/mul_mat_Ab_Bi_8x4.cl +139 -0
- package/cpp/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
- package/cpp/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl +273 -0
- package/cpp/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +146 -0
- package/cpp/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +147 -0
- package/cpp/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl +163 -0
- package/cpp/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl +165 -0
- package/cpp/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl +158 -0
- package/cpp/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_8x4.cl +129 -0
- package/cpp/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl +154 -0
- package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_f16_f16.cl +118 -0
- package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32.cl +118 -0
- package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32_1row.cl +94 -0
- package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32_l4.cl +84 -0
- package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_f32_f32.cl +118 -0
- package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32.cl +189 -0
- package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32_flat.cl +176 -0
- package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl +283 -0
- package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32.cl +140 -0
- package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32_flat.cl +222 -0
- package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32.cl +144 -0
- package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32_flat.cl +167 -0
- package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32.cl +192 -0
- package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_1d_16x_flat.cl +307 -0
- package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_1d_8x_flat.cl +265 -0
- package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_8x_flat.cl +272 -0
- package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_v.cl +254 -0
- package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl +219 -0
- package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl +229 -0
- package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl +180 -0
- package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32.cl +194 -0
- package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +194 -0
- package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32.cl +125 -0
- package/cpp/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32_flat.cl +202 -0
- package/cpp/ggml/src/ggml-opencl/kernels/norm.cl +161 -0
- package/cpp/ggml/src/ggml-opencl/kernels/pad.cl +39 -0
- package/cpp/ggml/src/ggml-opencl/kernels/relu.cl +16 -0
- package/cpp/ggml/src/ggml-opencl/kernels/repeat.cl +38 -0
- package/cpp/ggml/src/ggml-opencl/kernels/rms_norm.cl +190 -0
- package/cpp/ggml/src/ggml-opencl/kernels/rope.cl +747 -0
- package/cpp/ggml/src/ggml-opencl/kernels/scale.cl +27 -0
- package/cpp/ggml/src/ggml-opencl/kernels/set_rows.cl +208 -0
- package/cpp/ggml/src/ggml-opencl/kernels/sigmoid.cl +29 -0
- package/cpp/ggml/src/ggml-opencl/kernels/silu.cl +30 -0
- package/cpp/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +108 -0
- package/cpp/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +108 -0
- package/cpp/ggml/src/ggml-opencl/kernels/softmax_f16.cl +107 -0
- package/cpp/ggml/src/ggml-opencl/kernels/softmax_f32.cl +107 -0
- package/cpp/ggml/src/ggml-opencl/kernels/softplus.cl +116 -0
- package/cpp/ggml/src/ggml-opencl/kernels/solve_tri.cl +51 -0
- package/cpp/ggml/src/ggml-opencl/kernels/sqr.cl +53 -0
- package/cpp/ggml/src/ggml-opencl/kernels/sqrt.cl +53 -0
- package/cpp/ggml/src/ggml-opencl/kernels/ssm_conv.cl +77 -0
- package/cpp/ggml/src/ggml-opencl/kernels/sub.cl +138 -0
- package/cpp/ggml/src/ggml-opencl/kernels/sum_rows.cl +140 -0
- package/cpp/ggml/src/ggml-opencl/kernels/tanh.cl +109 -0
- package/cpp/ggml/src/ggml-opencl/kernels/transpose.cl +117 -0
- package/cpp/ggml/src/ggml-opencl/kernels/tri.cl +32 -0
- package/cpp/ggml/src/ggml-opencl/kernels/tsembd.cl +48 -0
- package/cpp/ggml/src/ggml-opencl/kernels/upscale.cl +120 -0
- package/cpp/ggml/src/ggml-opt.cpp +1093 -0
- package/cpp/ggml/src/ggml-quants.c +5325 -0
- package/cpp/ggml/src/ggml-quants.h +106 -0
- package/cpp/ggml/src/ggml-rpc/CMakeLists.txt +9 -0
- package/cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +2118 -0
- package/cpp/ggml/src/ggml-threading.cpp +12 -0
- package/cpp/ggml/src/ggml-threading.h +14 -0
- package/cpp/ggml/src/ggml-virtgpu/CMakeLists.txt +70 -0
- package/cpp/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp +87 -0
- package/cpp/ggml/src/ggml-virtgpu/backend/CMakeLists.txt +21 -0
- package/cpp/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp +115 -0
- package/cpp/ggml/src/ggml-virtgpu/backend/backend-convert.h +13 -0
- package/cpp/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp +102 -0
- package/cpp/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp +105 -0
- package/cpp/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp +179 -0
- package/cpp/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp +148 -0
- package/cpp/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp +51 -0
- package/cpp/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h +73 -0
- package/cpp/ggml/src/ggml-virtgpu/backend/backend-dispatched.h +27 -0
- package/cpp/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h +32 -0
- package/cpp/ggml/src/ggml-virtgpu/backend/backend.cpp +144 -0
- package/cpp/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h +95 -0
- package/cpp/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h +94 -0
- package/cpp/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h +50 -0
- package/cpp/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h +378 -0
- package/cpp/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h +232 -0
- package/cpp/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h +58 -0
- package/cpp/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp +81 -0
- package/cpp/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +119 -0
- package/cpp/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +158 -0
- package/cpp/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp +213 -0
- package/cpp/ggml/src/ggml-virtgpu/ggml-backend.cpp +69 -0
- package/cpp/ggml/src/ggml-virtgpu/ggml-remoting.h +71 -0
- package/cpp/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml +166 -0
- package/cpp/ggml/src/ggml-virtgpu/include/apir_hw.h +9 -0
- package/cpp/ggml/src/ggml-virtgpu/regenerate_remoting.py +333 -0
- package/cpp/ggml/src/ggml-virtgpu/virtgpu-apir.h +15 -0
- package/cpp/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp +58 -0
- package/cpp/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp +110 -0
- package/cpp/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp +173 -0
- package/cpp/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp +192 -0
- package/cpp/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h +36 -0
- package/cpp/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h +53 -0
- package/cpp/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +98 -0
- package/cpp/ggml/src/ggml-virtgpu/virtgpu-shm.h +23 -0
- package/cpp/ggml/src/ggml-virtgpu/virtgpu-utils.cpp +179 -0
- package/cpp/ggml/src/ggml-virtgpu/virtgpu-utils.h +86 -0
- package/cpp/ggml/src/ggml-virtgpu/virtgpu.cpp +544 -0
- package/cpp/ggml/src/ggml-virtgpu/virtgpu.h +117 -0
- package/cpp/ggml/src/ggml-webgpu/CMakeLists.txt +80 -0
- package/cpp/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +1231 -0
- package/cpp/ggml/src/ggml-webgpu/ggml-webgpu.cpp +3150 -0
- package/cpp/ggml/src/ggml-webgpu/pre_wgsl.hpp +778 -0
- package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl +72 -0
- package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl +106 -0
- package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl +134 -0
- package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +107 -0
- package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +923 -0
- package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +107 -0
- package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl +66 -0
- package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +182 -0
- package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +636 -0
- package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.wgsl +668 -0
- package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +323 -0
- package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +40 -0
- package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl +713 -0
- package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +103 -0
- package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl +138 -0
- package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl +188 -0
- package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +194 -0
- package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl +86 -0
- package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +123 -0
- package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +295 -0
- package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/scale.wgsl +63 -0
- package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +109 -0
- package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +345 -0
- package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl +55 -0
- package/cpp/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +193 -0
- package/cpp/ggml/src/ggml-zdnn/CMakeLists.txt +36 -0
- package/cpp/ggml/src/ggml-zdnn/common.hpp +59 -0
- package/cpp/ggml/src/ggml-zdnn/ggml-zdnn.cpp +633 -0
- package/cpp/ggml/src/ggml-zdnn/mmf.cpp +80 -0
- package/cpp/ggml/src/ggml-zdnn/mmf.hpp +12 -0
- package/cpp/ggml/src/ggml-zdnn/utils.cpp +79 -0
- package/cpp/ggml/src/ggml-zdnn/utils.hpp +19 -0
- package/cpp/ggml/src/ggml-zendnn/CMakeLists.txt +92 -0
- package/cpp/ggml/src/ggml-zendnn/ggml-zendnn.cpp +469 -0
- package/cpp/ggml/src/ggml.c +7669 -0
- package/cpp/ggml/src/ggml.cpp +26 -0
- package/cpp/ggml/src/gguf.cpp +1699 -0
- package/cpp/include/llama-cpp.h +32 -0
- package/cpp/include/llama.h +1568 -0
- package/cpp/mtmd/CMakeLists.txt +98 -0
- package/cpp/mtmd/README.md +63 -0
- package/cpp/mtmd/clip-graph.h +117 -0
- package/cpp/mtmd/clip-impl.h +586 -0
- package/cpp/mtmd/clip-model.h +390 -0
- package/cpp/mtmd/clip.cpp +4154 -0
- package/cpp/mtmd/clip.h +121 -0
- package/cpp/mtmd/deprecation-warning.cpp +22 -0
- package/cpp/mtmd/legacy-models/convert_image_encoder_to_gguf.py +412 -0
- package/cpp/mtmd/legacy-models/glmedge-convert-image-encoder-to-gguf.py +280 -0
- package/cpp/mtmd/legacy-models/glmedge-surgery.py +33 -0
- package/cpp/mtmd/legacy-models/llava_surgery.py +38 -0
- package/cpp/mtmd/legacy-models/llava_surgery_v2.py +180 -0
- package/cpp/mtmd/legacy-models/minicpmv-convert-image-encoder-to-gguf.py +892 -0
- package/cpp/mtmd/legacy-models/minicpmv-surgery.py +47 -0
- package/cpp/mtmd/models/cogvlm.cpp +98 -0
- package/cpp/mtmd/models/conformer.cpp +216 -0
- package/cpp/mtmd/models/glm4v.cpp +122 -0
- package/cpp/mtmd/models/internvl.cpp +69 -0
- package/cpp/mtmd/models/kimik25.cpp +101 -0
- package/cpp/mtmd/models/kimivl.cpp +63 -0
- package/cpp/mtmd/models/llama4.cpp +96 -0
- package/cpp/mtmd/models/llava.cpp +374 -0
- package/cpp/mtmd/models/minicpmv.cpp +114 -0
- package/cpp/mtmd/models/mobilenetv5.cpp +451 -0
- package/cpp/mtmd/models/models.h +128 -0
- package/cpp/mtmd/models/nemotron-v2-vl.cpp +35 -0
- package/cpp/mtmd/models/paddleocr.cpp +52 -0
- package/cpp/mtmd/models/pixtral.cpp +86 -0
- package/cpp/mtmd/models/qwen2vl.cpp +183 -0
- package/cpp/mtmd/models/qwen3vl.cpp +193 -0
- package/cpp/mtmd/models/siglip.cpp +86 -0
- package/cpp/mtmd/models/whisper-enc.cpp +115 -0
- package/cpp/mtmd/models/youtuvl.cpp +179 -0
- package/cpp/mtmd/mtmd-audio.cpp +730 -0
- package/cpp/mtmd/mtmd-audio.h +113 -0
- package/cpp/mtmd/mtmd-cli.cpp +437 -0
- package/cpp/mtmd/mtmd-helper.cpp +521 -0
- package/cpp/mtmd/mtmd-helper.h +96 -0
- package/cpp/mtmd/mtmd.cpp +1156 -0
- package/cpp/mtmd/mtmd.h +319 -0
- package/cpp/mtmd/requirements.txt +5 -0
- package/cpp/mtmd/test-1.jpeg +0 -0
- package/cpp/mtmd/test-2.mp3 +0 -0
- package/cpp/mtmd/tests.sh +192 -0
- package/cpp/src/CMakeLists.txt +169 -0
- package/cpp/src/llama-adapter.cpp +488 -0
- package/cpp/src/llama-adapter.h +89 -0
- package/cpp/src/llama-arch.cpp +2855 -0
- package/cpp/src/llama-arch.h +619 -0
- package/cpp/src/llama-batch.cpp +917 -0
- package/cpp/src/llama-batch.h +173 -0
- package/cpp/src/llama-chat.cpp +896 -0
- package/cpp/src/llama-chat.h +71 -0
- package/cpp/src/llama-context.cpp +3512 -0
- package/cpp/src/llama-context.h +359 -0
- package/cpp/src/llama-cparams.cpp +5 -0
- package/cpp/src/llama-cparams.h +44 -0
- package/cpp/src/llama-grammar.cpp +1464 -0
- package/cpp/src/llama-grammar.h +194 -0
- package/cpp/src/llama-graph.cpp +2685 -0
- package/cpp/src/llama-graph.h +1026 -0
- package/cpp/src/llama-hparams.cpp +234 -0
- package/cpp/src/llama-hparams.h +339 -0
- package/cpp/src/llama-impl.cpp +171 -0
- package/cpp/src/llama-impl.h +73 -0
- package/cpp/src/llama-io.cpp +15 -0
- package/cpp/src/llama-io.h +35 -0
- package/cpp/src/llama-kv-cache-iswa.cpp +330 -0
- package/cpp/src/llama-kv-cache-iswa.h +137 -0
- package/cpp/src/llama-kv-cache.cpp +2271 -0
- package/cpp/src/llama-kv-cache.h +388 -0
- package/cpp/src/llama-kv-cells.h +533 -0
- package/cpp/src/llama-memory-hybrid-iswa.cpp +275 -0
- package/cpp/src/llama-memory-hybrid-iswa.h +140 -0
- package/cpp/src/llama-memory-hybrid.cpp +268 -0
- package/cpp/src/llama-memory-hybrid.h +139 -0
- package/cpp/src/llama-memory-recurrent.cpp +1165 -0
- package/cpp/src/llama-memory-recurrent.h +182 -0
- package/cpp/src/llama-memory.cpp +59 -0
- package/cpp/src/llama-memory.h +122 -0
- package/cpp/src/llama-mmap.cpp +785 -0
- package/cpp/src/llama-mmap.h +92 -0
- package/cpp/src/llama-model-loader.cpp +1414 -0
- package/cpp/src/llama-model-loader.h +203 -0
- package/cpp/src/llama-model-saver.cpp +286 -0
- package/cpp/src/llama-model-saver.h +37 -0
- package/cpp/src/llama-model.cpp +9253 -0
- package/cpp/src/llama-model.h +576 -0
- package/cpp/src/llama-quant.cpp +1119 -0
- package/cpp/src/llama-quant.h +1 -0
- package/cpp/src/llama-sampler.cpp +3885 -0
- package/cpp/src/llama-sampler.h +42 -0
- package/cpp/src/llama-vocab.cpp +3970 -0
- package/cpp/src/llama-vocab.h +187 -0
- package/cpp/src/llama.cpp +1313 -0
- package/cpp/src/models/afmoe.cpp +191 -0
- package/cpp/src/models/apertus.cpp +125 -0
- package/cpp/src/models/arcee.cpp +135 -0
- package/cpp/src/models/arctic.cpp +138 -0
- package/cpp/src/models/arwkv7.cpp +86 -0
- package/cpp/src/models/baichuan.cpp +122 -0
- package/cpp/src/models/bailingmoe.cpp +144 -0
- package/cpp/src/models/bailingmoe2.cpp +135 -0
- package/cpp/src/models/bert.cpp +178 -0
- package/cpp/src/models/bitnet.cpp +160 -0
- package/cpp/src/models/bloom.cpp +101 -0
- package/cpp/src/models/chameleon.cpp +178 -0
- package/cpp/src/models/chatglm.cpp +132 -0
- package/cpp/src/models/codeshell.cpp +111 -0
- package/cpp/src/models/cogvlm.cpp +102 -0
- package/cpp/src/models/cohere2-iswa.cpp +134 -0
- package/cpp/src/models/command-r.cpp +122 -0
- package/cpp/src/models/dbrx.cpp +123 -0
- package/cpp/src/models/deci.cpp +135 -0
- package/cpp/src/models/deepseek.cpp +144 -0
- package/cpp/src/models/deepseek2.cpp +262 -0
- package/cpp/src/models/delta-net-base.cpp +376 -0
- package/cpp/src/models/dots1.cpp +134 -0
- package/cpp/src/models/dream.cpp +105 -0
- package/cpp/src/models/ernie4-5-moe.cpp +150 -0
- package/cpp/src/models/ernie4-5.cpp +110 -0
- package/cpp/src/models/eurobert.cpp +97 -0
- package/cpp/src/models/exaone-moe.cpp +146 -0
- package/cpp/src/models/exaone.cpp +114 -0
- package/cpp/src/models/exaone4.cpp +123 -0
- package/cpp/src/models/falcon-h1.cpp +111 -0
- package/cpp/src/models/falcon.cpp +120 -0
- package/cpp/src/models/gemma-embedding.cpp +116 -0
- package/cpp/src/models/gemma.cpp +112 -0
- package/cpp/src/models/gemma2-iswa.cpp +128 -0
- package/cpp/src/models/gemma3.cpp +155 -0
- package/cpp/src/models/gemma3n-iswa.cpp +384 -0
- package/cpp/src/models/glm4-moe.cpp +170 -0
- package/cpp/src/models/glm4.cpp +157 -0
- package/cpp/src/models/gpt2.cpp +105 -0
- package/cpp/src/models/gptneox.cpp +144 -0
- package/cpp/src/models/granite-hybrid.cpp +196 -0
- package/cpp/src/models/granite.cpp +211 -0
- package/cpp/src/models/grok.cpp +159 -0
- package/cpp/src/models/grovemoe.cpp +141 -0
- package/cpp/src/models/hunyuan-dense.cpp +132 -0
- package/cpp/src/models/hunyuan-moe.cpp +154 -0
- package/cpp/src/models/internlm2.cpp +120 -0
- package/cpp/src/models/jais.cpp +86 -0
- package/cpp/src/models/jais2.cpp +123 -0
- package/cpp/src/models/jamba.cpp +106 -0
- package/cpp/src/models/kimi-linear.cpp +392 -0
- package/cpp/src/models/lfm2.cpp +190 -0
- package/cpp/src/models/llada-moe.cpp +122 -0
- package/cpp/src/models/llada.cpp +99 -0
- package/cpp/src/models/llama-iswa.cpp +178 -0
- package/cpp/src/models/llama.cpp +168 -0
- package/cpp/src/models/maincoder.cpp +117 -0
- package/cpp/src/models/mamba-base.cpp +285 -0
- package/cpp/src/models/mamba.cpp +54 -0
- package/cpp/src/models/mimo2-iswa.cpp +123 -0
- package/cpp/src/models/minicpm3.cpp +200 -0
- package/cpp/src/models/minimax-m2.cpp +124 -0
- package/cpp/src/models/mistral3.cpp +160 -0
- package/cpp/src/models/models.h +684 -0
- package/cpp/src/models/modern-bert.cpp +109 -0
- package/cpp/src/models/mpt.cpp +126 -0
- package/cpp/src/models/nemotron-h.cpp +148 -0
- package/cpp/src/models/nemotron.cpp +122 -0
- package/cpp/src/models/neo-bert.cpp +104 -0
- package/cpp/src/models/olmo.cpp +121 -0
- package/cpp/src/models/olmo2.cpp +150 -0
- package/cpp/src/models/olmoe.cpp +124 -0
- package/cpp/src/models/openai-moe-iswa.cpp +127 -0
- package/cpp/src/models/openelm.cpp +124 -0
- package/cpp/src/models/orion.cpp +123 -0
- package/cpp/src/models/paddleocr.cpp +122 -0
- package/cpp/src/models/pangu-embedded.cpp +121 -0
- package/cpp/src/models/phi2.cpp +121 -0
- package/cpp/src/models/phi3.cpp +152 -0
- package/cpp/src/models/plamo.cpp +110 -0
- package/cpp/src/models/plamo2.cpp +318 -0
- package/cpp/src/models/plamo3.cpp +128 -0
- package/cpp/src/models/plm.cpp +169 -0
- package/cpp/src/models/qwen.cpp +108 -0
- package/cpp/src/models/qwen2.cpp +126 -0
- package/cpp/src/models/qwen2moe.cpp +151 -0
- package/cpp/src/models/qwen2vl.cpp +117 -0
- package/cpp/src/models/qwen3.cpp +117 -0
- package/cpp/src/models/qwen35.cpp +386 -0
- package/cpp/src/models/qwen35moe.cpp +420 -0
- package/cpp/src/models/qwen3moe.cpp +124 -0
- package/cpp/src/models/qwen3next.cpp +525 -0
- package/cpp/src/models/qwen3vl-moe.cpp +140 -0
- package/cpp/src/models/qwen3vl.cpp +132 -0
- package/cpp/src/models/refact.cpp +94 -0
- package/cpp/src/models/rnd1.cpp +126 -0
- package/cpp/src/models/rwkv6-base.cpp +164 -0
- package/cpp/src/models/rwkv6.cpp +94 -0
- package/cpp/src/models/rwkv6qwen2.cpp +86 -0
- package/cpp/src/models/rwkv7-base.cpp +137 -0
- package/cpp/src/models/rwkv7.cpp +90 -0
- package/cpp/src/models/seed-oss.cpp +124 -0
- package/cpp/src/models/smallthinker.cpp +126 -0
- package/cpp/src/models/smollm3.cpp +128 -0
- package/cpp/src/models/stablelm.cpp +146 -0
- package/cpp/src/models/starcoder.cpp +100 -0
- package/cpp/src/models/starcoder2.cpp +121 -0
- package/cpp/src/models/step35-iswa.cpp +168 -0
- package/cpp/src/models/t5-dec.cpp +166 -0
- package/cpp/src/models/t5-enc.cpp +96 -0
- package/cpp/src/models/wavtokenizer-dec.cpp +149 -0
- package/cpp/src/models/xverse.cpp +108 -0
- package/cpp/src/unicode-data.cpp +7034 -0
- package/cpp/src/unicode-data.h +20 -0
- package/cpp/src/unicode.cpp +1103 -0
- package/cpp/src/unicode.h +111 -0
- package/cpp/vendor/nlohmann/json.hpp +25526 -0
- package/cpp/vendor/nlohmann/json_fwd.hpp +187 -0
- package/cpp/vendor/stb/stb_image.h +7988 -0
- package/ios/LocalLLM-Bridging-Header.h +2 -0
- package/ios/LocalLLM.h +5 -0
- package/ios/LocalLLM.mm +1267 -0
- package/local-llm-rn.podspec +60 -0
- package/package.json +35 -0
- package/src/NativeLocalLLM.ts +73 -0
- package/src/device.ts +50 -0
- package/src/download-adapter.ts +17 -0
- package/src/index.ts +21 -0
- package/src/native-bridge.ts +142 -0
- package/src/rn-downloader.ts +37 -0
|
@@ -0,0 +1,373 @@
|
|
|
1
|
+
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
|
2
|
+
|
|
3
|
+
#define ACC_TYPE float
|
|
4
|
+
#define ACC_TYPE4 float4
|
|
5
|
+
#define Q_DATA_TYPE4 float4
|
|
6
|
+
#define KV_DATA_TYPE4 half4
|
|
7
|
+
#define O_DATA_TYPE4 float4
|
|
8
|
+
#define MASK_DATA_TYPE half
|
|
9
|
+
#define CONVERT_Q_ACC4(x) (x)
|
|
10
|
+
#define CONVERT_KV_ACC4(x) convert_float4(x)
|
|
11
|
+
#define CONVERT_O_DATA4(x) (x)
|
|
12
|
+
|
|
13
|
+
#define DK_VEC (DK/4)
|
|
14
|
+
#define DV_VEC (DV/4)
|
|
15
|
+
#define WG_SIZE (BLOCK_M)
|
|
16
|
+
#define Q1_WG_SIZE 64
|
|
17
|
+
|
|
18
|
+
// Returns the multiplier applied to mask values for attention head `h`.
//
//   max_bias    - when <= 0 the function returns 1.0f (mask used unscaled)
//   h           - head index
//   n_head_log2 - heads below this index use base m0, the rest use base m1
//   m0, m1      - the two slope bases
//
// Heads below n_head_log2 get m0^(h + 1); the remaining heads get
// m1^(2*(h - n_head_log2) + 1).
inline float get_alibi_slope(
    const float max_bias, const uint h, const uint n_head_log2, const float m0, const float m1
) {
    if (max_bias <= 0.0f) {
        return 1.0f;
    }

    float slope_base;
    int   slope_exp;
    if (h < n_head_log2) {
        slope_base = m0;
        slope_exp  = h + 1;
    } else {
        slope_base = m1;
        slope_exp  = 2*(h - n_head_log2) + 1;
    }

    return pow(slope_base, slope_exp);
}
|
|
29
|
+
// Tiled attention forward pass with an online softmax: Q is read as float4,
// K/V as half4, the output is written as float4 (see the *_DATA_TYPE4 macros).
//
// Work mapping, as used by this kernel:
//   - get_group_id(0) selects a block of BLOCK_M query rows; each work-item
//     owns one query row (block_q_idx * BLOCK_M + local id).
//   - get_global_id(1) enumerates (batch, head) pairs as batch * n_head + head.
//   - K/V are staged BLOCK_N rows at a time in local memory and shared by the
//     whole work-group.
//
// BLOCK_M, BLOCK_N, DK and DV are not defined in this file section; presumably
// they are supplied as build-time defines. The tile-loading loops stride by
// WG_SIZE (== BLOCK_M), so the local size in dimension 0 is assumed to equal
// BLOCK_M, and the inner loop consumes two K/V rows per step, so BLOCK_N is
// assumed to be even -- TODO confirm both against the host code.
//
// Parameters:
//   *_void / *_offset   - buffer base pointers and byte offsets into them
//   scale               - multiplier applied to every Q.K dot product
//   n_q, n_kv           - number of query rows / key-value rows
//   is_causal           - non-zero: key k is visible to query row r only if
//                         k <= n_kv - n_q + r
//   n_head, n_head_kv   - query heads and K/V heads (head / (n_head/n_head_kv)
//                         selects the K/V head)
//   *_nb1..*_nb3        - byte strides
//   max_bias, m0, m1, n_head_log2 - inputs of get_alibi_slope()
//   logit_softcap       - when > 0, scores are soft-capped with tanh
//   mask_*              - optional additive mask (half), broadcast over
//                         mask_ne2 heads and mask_ne3 batches; may be NULL
//   sinks_*             - optional per-head sink logit (float); may be NULL
//
// Differences from the previous revision of this kernel:
//   - a second barrier protects the local K/V tile from being overwritten
//     while slower work-items are still reading it;
//   - rows of the tile that lie past n_kv are zero-filled, so a zero
//     probability is never multiplied by uninitialised local memory;
//   - the soft-cap is applied to the raw score before masking, so a masked
//     (-INFINITY) score stays masked instead of becoming -logit_softcap;
//   - a pair of fully masked scores with no running maximum yet is skipped,
//     avoiding exp(-inf - -inf) = NaN.
__kernel void flash_attn_f32_f16(
    const global void * q_void, ulong q_offset,
    const global void * k_void, ulong k_offset,
    const global void * v_void, ulong v_offset,
    global void * o_void, ulong o_offset,
    const float scale,
    const int n_q,
    const int n_kv,
    const int is_causal,
    const int n_head,
    const ulong q_nb1, const ulong q_nb2, const ulong q_nb3,
    const ulong k_nb1, const ulong k_nb2, const ulong k_nb3,
    const ulong v_nb1, const ulong v_nb2, const ulong v_nb3,
    const ulong o_nb1, const ulong o_nb2, const ulong o_nb3,
    const float max_bias,
    const float m0,
    const float m1,
    const int n_head_log2,
    const float logit_softcap,
    const int n_head_kv,
    const global void* mask_void,
    const ulong mask_offset,
    const ulong mask_nb1,
    const ulong mask_nb2,
    const ulong mask_nb3,
    const int mask_ne2,
    const int mask_ne3,
    const global void* sinks_void,
    const ulong sinks_offset
) {
    const int tid = get_local_id(0);
    const int block_q_idx = get_group_id(0);
    const int head_batch_idx = get_global_id(1);

    const int my_query_row = block_q_idx * BLOCK_M + tid;
    // Work-items past the last query row still help load tiles and must still
    // reach every barrier; they only skip the per-row math and the store.
    const bool row_active = my_query_row < n_q;

    const int batch_idx = head_batch_idx / n_head;
    const int head_idx = head_batch_idx % n_head;

    const int gqa_ratio = n_head / n_head_kv;
    const int head_kv_idx = head_idx / gqa_ratio;

    const global char* q_base = (const global char*)q_void + q_offset;
    const global char* k_base = (const global char*)k_void + k_offset;
    const global char* v_base = (const global char*)v_void + v_offset;
    global char* o_base = (global char*)o_void + o_offset;

    const global char* mask_base = NULL;
    if (mask_void != NULL) {
        const int mask_head_idx = head_idx % mask_ne2;
        const int mask_batch_idx = batch_idx % mask_ne3;
        mask_base = (const global char*)mask_void + mask_offset + mask_batch_idx * mask_nb3 + mask_head_idx * mask_nb2;
    }

    // Mask row of this work-item's query; only dereferenced when row_active.
    const global MASK_DATA_TYPE* mask_ptr = NULL;
    if (mask_base != NULL && row_active) {
        mask_ptr = (const global MASK_DATA_TYPE*)(mask_base + my_query_row * mask_nb1);
    }

    // Private copy of the query row (left unset for inactive rows, never read).
    ACC_TYPE4 q_priv[DK_VEC];
    if (row_active) {
        const ulong q_row_offset = batch_idx * q_nb3 + head_idx * q_nb2 + my_query_row * q_nb1;
        const global Q_DATA_TYPE4* q_ptr = (const global Q_DATA_TYPE4*)(q_base + q_row_offset);
        #pragma unroll
        for (int i = 0; i < DK_VEC; ++i) {
            q_priv[i] = CONVERT_Q_ACC4(q_ptr[i]);
        }
    }

    // Online-softmax state: un-normalised output, running max, running sum.
    ACC_TYPE4 o_acc[DV_VEC];
    #pragma unroll
    for (int i = 0; i < DV_VEC; ++i) {
        o_acc[i] = (ACC_TYPE4)(0.0f);
    }
    ACC_TYPE m_i = -INFINITY;
    ACC_TYPE l_i = 0.0f;

    float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1);

    // Last key index visible to this query row under causal masking.
    const int causal_limit = n_kv - n_q + my_query_row;

    __local KV_DATA_TYPE4 l_k[BLOCK_N][DK_VEC];
    __local KV_DATA_TYPE4 l_v[BLOCK_N][DV_VEC];

    for (int k_start = 0; k_start < n_kv; k_start += BLOCK_N) {
        // Cooperative load of the K tile; rows past n_kv are zero-filled.
        for (int i = tid; i < BLOCK_N * DK_VEC; i += WG_SIZE) {
            const int row = i / DK_VEC;
            const int col = i % DK_VEC;
            const int k_row_idx = k_start + row;
            KV_DATA_TYPE4 k_val = (KV_DATA_TYPE4)(0.0f);
            if (k_row_idx < n_kv) {
                const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_row_idx * k_nb1;
                k_val = ((const global KV_DATA_TYPE4*)(k_base + k_row_offset))[col];
            }
            l_k[row][col] = k_val;
        }
        // Cooperative load of the V tile; rows past n_kv are zero-filled.
        for (int i = tid; i < BLOCK_N * DV_VEC; i += WG_SIZE) {
            const int row = i / DV_VEC;
            const int col = i % DV_VEC;
            const int v_row_idx = k_start + row;
            KV_DATA_TYPE4 v_val = (KV_DATA_TYPE4)(0.0f);
            if (v_row_idx < n_kv) {
                const ulong v_row_offset = batch_idx * v_nb3 + head_kv_idx * v_nb2 + v_row_idx * v_nb1;
                v_val = ((const global KV_DATA_TYPE4*)(v_base + v_row_offset))[col];
            }
            l_v[row][col] = v_val;
        }
        barrier(CLK_LOCAL_MEM_FENCE);

        if (row_active) {
            // Two K/V rows per step.
            for (int j = 0; j < BLOCK_N; j += 2) {
                const int k_row0 = k_start + j;
                const int k_row1 = k_start + j + 1;

                ACC_TYPE4 dot_acc0 = (ACC_TYPE4)(0.0f);
                ACC_TYPE4 dot_acc1 = (ACC_TYPE4)(0.0f);
                #pragma unroll
                for (int k = 0; k < DK_VEC; k++) {
                    dot_acc0 = mad(q_priv[k], CONVERT_KV_ACC4(l_k[j][k]), dot_acc0);
                    dot_acc1 = mad(q_priv[k], CONVERT_KV_ACC4(l_k[j+1][k]), dot_acc1);
                }
                ACC_TYPE score0 = (dot_acc0.s0 + dot_acc0.s1 + dot_acc0.s2 + dot_acc0.s3) * scale;
                ACC_TYPE score1 = (dot_acc1.s0 + dot_acc1.s1 + dot_acc1.s2 + dot_acc1.s3) * scale;

                // Soft-cap the raw logits first: tanh(-inf) is -1, so capping
                // after masking would turn a masked score into -logit_softcap.
                if (logit_softcap > 0.0f) {
                    score0 = logit_softcap * tanh(score0 / logit_softcap);
                    score1 = logit_softcap * tanh(score1 / logit_softcap);
                }

                if (is_causal) {
                    if (k_row0 > causal_limit) score0 = -INFINITY;
                    if (k_row1 > causal_limit) score1 = -INFINITY;
                }

                // Rows past the end of K/V never contribute.
                if (k_row0 >= n_kv) score0 = -INFINITY;
                if (k_row1 >= n_kv) score1 = -INFINITY;

                if (mask_ptr != NULL) {
                    if (k_row0 < n_kv) score0 += slope * (ACC_TYPE)mask_ptr[k_row0];
                    if (k_row1 < n_kv) score1 += slope * (ACC_TYPE)mask_ptr[k_row1];
                }

                const ACC_TYPE m_new = max(m_i, max(score0, score1));
                // Nothing visible so far and both keys masked: the state is
                // unchanged, and exp(-inf - -inf) would be NaN.
                if (m_new == -INFINITY) {
                    continue;
                }

                const ACC_TYPE p0 = exp(score0 - m_new);
                const ACC_TYPE p1 = exp(score1 - m_new);
                const ACC_TYPE scale_prev = exp(m_i - m_new);

                #pragma unroll
                for (int i = 0; i < DV_VEC; ++i) {
                    o_acc[i] = o_acc[i] * scale_prev + p0 * CONVERT_KV_ACC4(l_v[j][i]) + p1 * CONVERT_KV_ACC4(l_v[j+1][i]);
                }
                l_i = l_i * scale_prev + p0 + p1;
                m_i = m_new;
            }
        }

        // Every work-item must be done with this tile before the next
        // iteration overwrites l_k / l_v.
        barrier(CLK_LOCAL_MEM_FENCE);
    }

    if (row_active) {
        if (sinks_void != NULL) {
            // Fold the per-head sink logit into the softmax denominator.
            const global ACC_TYPE* sinks_ptr = (const global ACC_TYPE*)((const global char*)sinks_void + sinks_offset);
            const ACC_TYPE m_sink = sinks_ptr[head_idx];
            const ACC_TYPE m_final = max(m_i, m_sink);

            const ACC_TYPE scale_o = exp(m_i - m_final);
            #pragma unroll
            for (int i = 0; i < DV_VEC; ++i) {
                o_acc[i] *= scale_o;
            }

            l_i = l_i * exp(m_i - m_final) + exp(m_sink - m_final);
        }

        // Note the output stride order: query row uses o_nb2, head uses o_nb1.
        const ulong o_row_offset = batch_idx * o_nb3 + my_query_row * o_nb2 + head_idx * o_nb1;
        global O_DATA_TYPE4 *o_row = (global O_DATA_TYPE4 *)(o_base + o_row_offset);
        if (l_i > 0.0f) {
            const ACC_TYPE l_inv = 1.0f / l_i;
            #pragma unroll
            for (int i = 0; i < DV_VEC; ++i) {
                o_row[i] = CONVERT_O_DATA4(o_acc[i] * l_inv);
            }
        } else {
            // Every key was masked: emit zeros rather than dividing by zero.
            #pragma unroll
            for (int i = 0; i < DV_VEC; ++i) {
                o_row[i] = (O_DATA_TYPE4)(0.0f);
            }
        }
    }
}
|
|
209
|
+
|
|
210
|
+
// Attention forward pass specialised for a single query row per (batch, head).
//
// One work-group handles one (batch, head) pair (get_global_id(1) is
// batch * n_head + head); its work-items stride over the n_kv keys with step
// Q1_WG_SIZE. The local arrays are sized Q1_WG_SIZE and indexed by the local
// id, so the local size in dimension 0 is assumed to equal Q1_WG_SIZE -- TODO
// confirm against the host code.
//
// The kernel makes two passes over K:
//   1. every work-item finds the maximum score of its keys; a tree reduction
//      in local memory yields the global maximum m_final (seeded with the
//      sink logit when sinks are present);
//   2. every work-item accumulates exp(score - m_final) and the weighted V
//      rows; tree reductions combine the partial sums and work-item 0 writes
//      the normalised output.
//
// The parameter list mirrors flash_attn_f32_f16. n_q, is_causal, q_nb1,
// o_nb2 and mask_nb1 are not read by this kernel: only query row 0 and mask
// row 0 are addressed.
//
// Difference from the previous revision: the logit soft-cap is applied to the
// raw score before the mask is added. Applying it afterwards maps a masked
// (-INFINITY) score to the finite value -logit_softcap via tanh(-inf) = -1.
__kernel void flash_attn_f32_f16_q1(
    const global void * q_void, ulong q_offset,
    const global void * k_void, ulong k_offset,
    const global void * v_void, ulong v_offset,
    global void * o_void, ulong o_offset,
    const float scale,
    const int n_q,
    const int n_kv,
    const int is_causal,
    const int n_head,
    const ulong q_nb1, const ulong q_nb2, const ulong q_nb3,
    const ulong k_nb1, const ulong k_nb2, const ulong k_nb3,
    const ulong v_nb1, const ulong v_nb2, const ulong v_nb3,
    const ulong o_nb1, const ulong o_nb2, const ulong o_nb3,
    const float max_bias,
    const float m0,
    const float m1,
    const int n_head_log2,
    const float logit_softcap,
    const int n_head_kv,
    const global void* mask_void,
    const ulong mask_offset,
    const ulong mask_nb1,
    const ulong mask_nb2,
    const ulong mask_nb3,
    const int mask_ne2,
    const int mask_ne3,
    const global void* sinks_void,
    const ulong sinks_offset
) {
    const int tid = get_local_id(0);
    const int head_batch_idx = get_global_id(1);

    const int batch_idx = head_batch_idx / n_head;
    const int head_idx = head_batch_idx % n_head;

    const int gqa_ratio = n_head / n_head_kv;
    const int head_kv_idx = head_idx / gqa_ratio;

    const global char* q_base = (const global char*)q_void + q_offset;
    const global char* k_base = (const global char*)k_void + k_offset;
    const global char* v_base = (const global char*)v_void + v_offset;
    global char* o_base = (global char*)o_void + o_offset;

    const global char* mask_base = NULL;
    if (mask_void != NULL) {
        const int mask_head_idx = head_idx % mask_ne2;
        const int mask_batch_idx = batch_idx % mask_ne3;
        mask_base = (const global char*)mask_void + mask_offset + mask_batch_idx * mask_nb3 + mask_head_idx * mask_nb2;
    }
    // Mask row 0 (the only query row); NULL when no mask is supplied.
    const global MASK_DATA_TYPE* mask_ptr = (const global MASK_DATA_TYPE*)(mask_base);

    // Private copy of the single query row.
    ACC_TYPE4 q_priv[DK_VEC];
    const ulong q_row_offset = batch_idx * q_nb3 + head_idx * q_nb2;
    const global Q_DATA_TYPE4* q_ptr = (const global Q_DATA_TYPE4*)(q_base + q_row_offset);
    #pragma unroll
    for (int i = 0; i < DK_VEC; ++i) {
        q_priv[i] = CONVERT_Q_ACC4(q_ptr[i]);
    }

    float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1);

    const global ACC_TYPE* sinks_ptr = NULL;
    if (sinks_void != NULL) {
        sinks_ptr = (const global ACC_TYPE*)((const global char*)sinks_void + sinks_offset);
    }

    // Pass 1: per-work-item maximum score, seeded with the sink logit if any.
    ACC_TYPE m_i = (sinks_ptr != NULL) ? sinks_ptr[head_idx] : -INFINITY;
    for (int k_idx = tid; k_idx < n_kv; k_idx += Q1_WG_SIZE) {
        const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_idx * k_nb1;
        const global KV_DATA_TYPE4* k_ptr = (const global KV_DATA_TYPE4*)(k_base + k_row_offset);
        ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f);
        #pragma unroll
        for (int k = 0; k < DK_VEC; k++) {
            dot_acc = mad(q_priv[k], CONVERT_KV_ACC4(k_ptr[k]), dot_acc);
        }
        ACC_TYPE score = (dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3) * scale;
        // Soft-cap before masking so that -INFINITY mask entries survive.
        if (logit_softcap > 0.0f) {
            score = logit_softcap * tanh(score / logit_softcap);
        }
        if (mask_ptr != NULL) {
            score += slope * (ACC_TYPE)mask_ptr[k_idx];
        }
        m_i = max(m_i, score);
    }

    // Work-group max reduction.
    __local ACC_TYPE local_m[Q1_WG_SIZE];
    local_m[tid] = m_i;
    barrier(CLK_LOCAL_MEM_FENCE);
    #pragma unroll
    for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
        if (tid < s) local_m[tid] = max(local_m[tid], local_m[tid + s]);
        barrier(CLK_LOCAL_MEM_FENCE);
    }
    const ACC_TYPE m_final = local_m[0];

    // Pass 2: recompute the scores and accumulate softmax-weighted V rows.
    ACC_TYPE4 o_acc[DV_VEC];
    #pragma unroll
    for (int i = 0; i < DV_VEC; ++i) o_acc[i] = (ACC_TYPE4)(0.0f);
    ACC_TYPE l_i = 0.0f;

    for (int k_idx = tid; k_idx < n_kv; k_idx += Q1_WG_SIZE) {
        const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_idx * k_nb1;
        const ulong v_row_offset = batch_idx * v_nb3 + head_kv_idx * v_nb2 + k_idx * v_nb1;
        const global KV_DATA_TYPE4* k_ptr = (const global KV_DATA_TYPE4*)(k_base + k_row_offset);
        const global KV_DATA_TYPE4* v_ptr = (const global KV_DATA_TYPE4*)(v_base + v_row_offset);
        ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f);
        #pragma unroll
        for (int k = 0; k < DK_VEC; k++) {
            dot_acc = mad(q_priv[k], CONVERT_KV_ACC4(k_ptr[k]), dot_acc);
        }
        ACC_TYPE score = (dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3) * scale;
        // Must match pass 1 exactly: soft-cap first, then mask.
        if (logit_softcap > 0.0f) {
            score = logit_softcap * tanh(score / logit_softcap);
        }
        if (mask_ptr != NULL) {
            score += slope * (ACC_TYPE)mask_ptr[k_idx];
        }
        const ACC_TYPE p = exp(score - m_final);
        l_i += p;
        #pragma unroll
        for (int i = 0; i < DV_VEC; i++) {
            o_acc[i] = mad(p, CONVERT_KV_ACC4(v_ptr[i]), o_acc[i]);
        }
    }

    // Work-group sum reduction of the softmax denominator.
    __local ACC_TYPE local_l[Q1_WG_SIZE];
    __local ACC_TYPE4 local_o_comp[Q1_WG_SIZE];
    local_l[tid] = l_i;
    barrier(CLK_LOCAL_MEM_FENCE);
    #pragma unroll
    for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
        if (tid < s) local_l[tid] += local_l[tid + s];
        barrier(CLK_LOCAL_MEM_FENCE);
    }

    const ulong o_row_offset = batch_idx * o_nb3 + head_idx * o_nb1;
    global O_DATA_TYPE4 *o_row = (global O_DATA_TYPE4 *)(o_base + o_row_offset);
    ACC_TYPE l_final = local_l[0];

    if (sinks_ptr != NULL) {
        // The sink contributes to the denominator only, not to the output.
        l_final += exp(sinks_ptr[head_idx] - m_final);
    }

    // l_final is identical in every work-item, so the barriers inside this
    // branch are reached by the whole work-group or by none of it.
    if (l_final > 0.0f) {
        const ACC_TYPE l_inv = 1.0f / l_final;
        // Reduce the output one float4 component at a time through local memory.
        for (int i = 0; i < DV_VEC; i++) {
            local_o_comp[tid] = o_acc[i];
            barrier(CLK_LOCAL_MEM_FENCE);
            #pragma unroll
            for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
                if (tid < s) local_o_comp[tid] += local_o_comp[tid + s];
                barrier(CLK_LOCAL_MEM_FENCE);
            }
            if (tid == 0) {
                o_row[i] = CONVERT_O_DATA4(local_o_comp[0] * l_inv);
            }
        }
    } else if (tid == 0) {
        // No usable probability mass: write zeros.
        #pragma unroll
        for (int i = 0; i < DV_VEC; ++i) o_row[i] = (O_DATA_TYPE4)(0.0f);
    }
}
|
|
@@ -0,0 +1,89 @@
|
|
|
1
|
+
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
|
2
|
+
|
|
3
|
+
//------------------------------------------------------------------------------
|
|
4
|
+
// gelu
|
|
5
|
+
//------------------------------------------------------------------------------
|
|
6
|
+
#define GELU_COEF_A 0.044715f
|
|
7
|
+
#define GELU_QUICK_COEF -1.702f
|
|
8
|
+
#define SQRT_2_OVER_PI 0.79788456080286535587989211986876f
|
|
9
|
+
#define SQRT_2_INV 0.70710678118654752440084436210484f
|
|
10
|
+
|
|
11
|
+
// Element-wise tanh-approximated GELU on floats.
// Each work-item processes the element at get_global_id(0).
//   src0 / offset0 - input buffer and byte offset into it
//   dst  / offsetd - output buffer and byte offset into it
kernel void kernel_gelu(
    global float * src0,
    ulong offset0,
    global float * dst,
    ulong offsetd
) {
    global float * in  = (global float*)((global char*)src0 + offset0);
    global float * out = (global float*)((global char*)dst + offsetd);

    const size_t gid = get_global_id(0);
    const float x = in[gid];

    const float t = tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x));
    out[gid] = 0.5f*x*(1.0f + t);
}
|
|
24
|
+
|
|
25
|
+
// float4 variant of kernel_gelu: each work-item processes the float4 at
// get_global_id(0) with the tanh-approximated GELU.
//   src0 / offset0 - input buffer and byte offset into it
//   dst  / offsetd - output buffer and byte offset into it
kernel void kernel_gelu_4(
    global float4 * src0,
    ulong offset0,
    global float4 * dst,
    ulong offsetd
) {
    global float4 * in  = (global float4*)((global char*)src0 + offset0);
    global float4 * out = (global float4*)((global char*)dst + offsetd);

    const size_t gid = get_global_id(0);
    const float4 x = in[gid];

    const float4 t = tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x));
    out[gid] = 0.5f*x*(1.0f + t);
}
|
|
38
|
+
|
|
39
|
+
// Element-wise erf-based GELU on floats: 0.5 * x * (1 + erf(x / sqrt(2))).
// Each work-item processes the element at get_global_id(0).
//   src0 / offset0 - input buffer and byte offset into it
//   dst  / offsetd - output buffer and byte offset into it
kernel void kernel_gelu_erf(
    global float * src0,
    ulong offset0,
    global float * dst,
    ulong offsetd
) {
    global float * in  = (global float*)((global char*)src0 + offset0);
    global float * out = (global float*)((global char*)dst + offsetd);

    const size_t gid = get_global_id(0);
    const float x = in[gid];

    const float e = erf(x*SQRT_2_INV);
    out[gid] = 0.5f*x*(1.0f + e);
}
|
|
51
|
+
|
|
52
|
+
// float4 variant of kernel_gelu_erf: each work-item processes the float4 at
// get_global_id(0) with 0.5 * x * (1 + erf(x / sqrt(2))).
//   src0 / offset0 - input buffer and byte offset into it
//   dst  / offsetd - output buffer and byte offset into it
kernel void kernel_gelu_erf_4(
    global float4 * src0,
    ulong offset0,
    global float4 * dst,
    ulong offsetd
) {
    global float4 * in  = (global float4*)((global char*)src0 + offset0);
    global float4 * out = (global float4*)((global char*)dst + offsetd);

    const size_t gid = get_global_id(0);
    const float4 x = in[gid];

    const float4 e = erf(x*SQRT_2_INV);
    out[gid] = 0.5f*x*(1.0f + e);
}
|
|
64
|
+
|
|
65
|
+
// Element-wise "quick" GELU on floats: x * sigmoid(1.702 * x), written as
// x / (1 + exp(GELU_QUICK_COEF * x)) with GELU_QUICK_COEF = -1.702.
// Each work-item processes the element at get_global_id(0).
//   src0 / offset0 - input buffer and byte offset into it
//   dst  / offsetd - output buffer and byte offset into it
kernel void kernel_gelu_quick(
    global float * src0,
    ulong offset0,
    global float * dst,
    ulong offsetd
) {
    global float * in  = (global float*)((global char*)src0 + offset0);
    global float * out = (global float*)((global char*)dst + offsetd);

    const size_t gid = get_global_id(0);
    const float x = in[gid];

    const float gate = 1.0f/(1.0f+exp(GELU_QUICK_COEF*x));
    out[gid] = x*gate;
}
|
|
77
|
+
|
|
78
|
+
// float4 variant of kernel_gelu_quick: each work-item processes the float4 at
// get_global_id(0) with x / (1 + exp(GELU_QUICK_COEF * x)).
//   src0 / offset0 - input buffer and byte offset into it
//   dst  / offsetd - output buffer and byte offset into it
kernel void kernel_gelu_quick_4(
    global float4 * src0,
    ulong offset0,
    global float4 * dst,
    ulong offsetd
) {
    global float4 * in  = (global float4*)((global char*)src0 + offset0);
    global float4 * out = (global float4*)((global char*)dst + offsetd);

    const size_t gid = get_global_id(0);
    const float4 x = in[gid];

    const float4 gate = 1.0f/(1.0f+exp(GELU_QUICK_COEF*x));
    out[gid] = x*gate;
}
|
|
@@ -0,0 +1,162 @@
|
|
|
1
|
+
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
|
2
|
+
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
|
|
3
|
+
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
|
|
4
|
+
|
|
5
|
+
#define QK_MXFP4 32
|
|
6
|
+
#define N_SIMDGROUP 2
|
|
7
|
+
#define SIMDGROUP_WIDTH 64
|
|
8
|
+
|
|
9
|
+
// Expands the four 4-bit values packed in `packed` (lowest nibble first) into
// four fp16 bit patterns.
//
// Per nibble (bit 3 = sign, bits 2..0 = magnitude code):
//   - the magnitude code is moved to fp16 bits 11..9 (mask 0x0E00);
//   - any non-zero code gets 0x3800 added to it;
//   - the code 0x0200 is cleared after the bias decision, so it encodes as
//     exactly 0x3800 (fp16 0.5) and code 0 stays 0;
//   - the sign bit is moved to fp16 bit 15.
static inline ushort4 mxfp4x4_to_fp16_bits(ushort packed) {
    ushort4 mag;
    mag.s0 = (packed << 9) & 0x0E00;
    mag.s1 = (packed << 5) & 0x0E00;
    mag.s2 = (packed << 1) & 0x0E00;
    mag.s3 = (packed >> 3) & 0x0E00;

    // The bias is decided on the un-cleared code, so 0x0200 still gets it.
    ushort4 bias;
    bias.s0 = (mag.s0 != 0) ? 0x3800 : 0x0;
    bias.s1 = (mag.s1 != 0) ? 0x3800 : 0x0;
    bias.s2 = (mag.s2 != 0) ? 0x3800 : 0x0;
    bias.s3 = (mag.s3 != 0) ? 0x3800 : 0x0;

    mag.s0 = (mag.s0 != 0x0200) ? mag.s0 : 0x0;
    mag.s1 = (mag.s1 != 0x0200) ? mag.s1 : 0x0;
    mag.s2 = (mag.s2 != 0x0200) ? mag.s2 : 0x0;
    mag.s3 = (mag.s3 != 0x0200) ? mag.s3 : 0x0;

    ushort4 sign;
    sign.s0 = (packed << 12) & 0x8000;
    sign.s1 = (packed << 8) & 0x8000;
    sign.s2 = (packed << 4) & 0x8000;
    sign.s3 = packed & 0x8000;

    // The three fields occupy disjoint bits, so the sum acts as a bitwise OR.
    return sign + bias + mag;
}

// Converts eight packed 4-bit values (two ushorts, four nibbles each) into a
// half8. Element order: nibbles 0..3 of fp4x8.s0, then nibbles 0..3 of
// fp4x8.s1, lowest nibble first.
static inline half8 mxfp4_to_fp16_packed8(ushort2 fp4x8) {
    const ushort4 first  = mxfp4x4_to_fp16_bits(fp4x8.s0);
    const ushort4 second = mxfp4x4_to_fp16_bits(fp4x8.s1);

    return as_half8((ushort8)(first, second));
}
|
|
60
|
+
|
|
61
|
+
// Reinterprets an 8-bit exponent byte as a float: x is placed in the float
// exponent field (x << 23). x == 0 maps to the bit pattern 0x00400000 rather
// than to an all-zero pattern.
static inline float e8m0_to_fp32(uchar x) {
    if (x == 0) {
        return as_float((int) 0x00400000);
    }

    const int exponent_bits = (uint) x << 23;
    return as_float(exponent_bits);
}
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
// Mixture-of-experts GEMM tile kernel: 4-bit quantised weights times f32
// activations, f32 output.
//
// Parameters:
//   src0_q    - quantised weights, one uint4 (32 nibbles) per block
//   src0_e    - per-block exponent bytes, indexed like src0_q
//   src1      - activations, read as float4 texels through an image buffer
//   src2      - routing table; entry get_global_id(2) is
//               (expert_id, i11 = activation row, i1 = output row, tile_id)
//   dst       - output buffer; offsetd is a byte offset into it
//   ne00      - reduction length (a whole number of QK_MXFP4-sized blocks)
//   ne01      - number of output rows per expert
//   tile_size - rows per tile; row = tile_id * tile_size + get_global_id(0)
//
// Work split: get_local_id(1) selects one of N_SIMDGROUP sub-groups; the
// sub-groups interleave over the blocks along ne00 and their partial sums are
// combined through local memory, after which sub-group 0 writes the result.
// reduceLM is indexed by get_sub_group_local_id(), so the sub-group size is
// assumed not to exceed SIMDGROUP_WIDTH -- TODO confirm for the target device.
//
// Differences from the previous revision:
//   - rows past ne01 no longer `return` before the barrier; they stay in the
//     kernel with an in-range flag so that every work-item of the work-group
//     reaches barrier(), as OpenCL requires;
//   - the expert offset is computed in 64-bit to avoid signed 32-bit overflow
//     of expert_id * ne00 * ne01;
//   - comments now match N_SIMDGROUP == 2.
__attribute__((qcom_reqd_sub_group_size("half")))
__kernel void kernel_gemm_moe_mxfp4_f32(
    __global uint4 * src0_q,
    __global uchar * src0_e,
    __read_only image1d_buffer_t src1,
    __global ushort4 * src2,
    __global float * dst,
    ulong offsetd,
    int ne00,
    int ne01,
    int tile_size
) {
    uint i01  = get_global_id(0);
    uint i20  = get_global_id(2);
    uint sgid = get_local_id(1);
    uint slid = get_sub_group_local_id();

    ushort4 router = src2[i20];
    ushort expert_id = router.x;
    ushort i11       = router.y;
    ushort i1        = router.z;
    ushort tile_id   = router.w;

    // Row inside the expert's weight matrix. Rows past ne01 occur when ne01 is
    // not a multiple of tile_size; they do no work but still hit the barrier.
    const uint row = tile_id * tile_size + i01;
    const bool in_range = row < (uint) ne01;

    // 64-bit product: expert_id * ne00 * ne01 can exceed the int range.
    const uint expert_offset = (uint) (((ulong) expert_id * (ulong) ne00 * (ulong) ne01) / 32);
    const uint tile_offset = expert_offset + row;

    __private float sum = 0.0f;     // each thread calculate partial sum of one output

    if (in_range) {
        // loop along ne00 in block granularity; each sub-group takes every
        // N_SIMDGROUP-th block
        for (uint ib00 = sgid; ib00 < (ne00 / QK_MXFP4); ib00 += N_SIMDGROUP) {
            // load one block of q
            uint4 regQ = src0_q[tile_offset + ib00 * ne01];

            // Eight float4 texels cover the 32 activations of this block.
            uint offset = i11 * ne00 / 4 + ib00 * 8;
            float4 shared_y4;

            // Each 32-bit word holds 8 values; even elements pair with texel
            // (offset + w), odd elements with texel (offset + w + 4).
            half8 fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s0));

            shared_y4 = read_imagef(src1, (offset + 0));
            float4 acc = shared_y4 * (float4)(fp16x8.s0, fp16x8.s2, fp16x8.s4, fp16x8.s6);

            shared_y4 = read_imagef(src1, (offset + 4));
            acc += shared_y4 * (float4)(fp16x8.s1, fp16x8.s3, fp16x8.s5, fp16x8.s7);

            fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s1));

            shared_y4 = read_imagef(src1, (offset + 1));
            acc += shared_y4 * (float4)(fp16x8.s0, fp16x8.s2, fp16x8.s4, fp16x8.s6);

            shared_y4 = read_imagef(src1, (offset + 5));
            acc += shared_y4 * (float4)(fp16x8.s1, fp16x8.s3, fp16x8.s5, fp16x8.s7);

            fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s2));

            shared_y4 = read_imagef(src1, (offset + 2));
            acc += shared_y4 * (float4)(fp16x8.s0, fp16x8.s2, fp16x8.s4, fp16x8.s6);

            shared_y4 = read_imagef(src1, (offset + 6));
            acc += shared_y4 * (float4)(fp16x8.s1, fp16x8.s3, fp16x8.s5, fp16x8.s7);

            fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s3));

            shared_y4 = read_imagef(src1, (offset + 3));
            acc += shared_y4 * (float4)(fp16x8.s0, fp16x8.s2, fp16x8.s4, fp16x8.s6);

            shared_y4 = read_imagef(src1, (offset + 7));
            acc += shared_y4 * (float4)(fp16x8.s1, fp16x8.s3, fp16x8.s5, fp16x8.s7);

            // Apply the block's shared scale once to the block's dot product.
            uchar regE = src0_e[tile_offset + ib00 * ne01];
            sum += e8m0_to_fp32(regE) * ((acc.s0 + acc.s1) + (acc.s2 + acc.s3));
        }
    }

    // Reduction in local memory. This is written for N_SIMDGROUP == 2:
    // sub-group 1 publishes its partial sums and sub-group 0 adds them.
    // Supporting more sub-groups needs one more store/add pair per sub-group.
    __local float reduceLM[SIMDGROUP_WIDTH * (N_SIMDGROUP - 1)];
    if (sgid == 1) reduceLM[SIMDGROUP_WIDTH * 0 + slid] = sum;
    barrier(CLK_LOCAL_MEM_FENCE);
    if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 0 + slid];

    // 1 output per thread in subgroup 0
    if (sgid == 0 && in_range) {
        dst = dst + (offsetd >> 2);     // byte offset -> float index
        dst[i01 + tile_id * tile_size + i1 * ne01] = sum;
    }
}
|