whispercpp 1.3.2 → 1.3.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/.gitignore +6 -3
- data/README.md +71 -14
- data/Rakefile +20 -7
- data/ext/.gitignore +4 -6
- data/ext/dependencies.rb +36 -24
- data/ext/extconf.rb +1 -1
- data/ext/options.rb +48 -184
- data/ext/ruby_whisper.c +18 -0
- data/ext/ruby_whisper_context.c +43 -12
- data/ext/ruby_whisper_model.c +1 -1
- data/ext/ruby_whisper_params.c +4 -2
- data/ext/ruby_whisper_segment.c +81 -4
- data/ext/ruby_whisper_transcribe.cpp +13 -7
- data/ext/ruby_whisper_vad_params.c +1 -1
- data/ext/sources/CMakeLists.txt +5 -1
- data/ext/sources/bindings/javascript/package.json +1 -1
- data/ext/sources/examples/addon.node/__test__/whisper.spec.js +120 -24
- data/ext/sources/examples/addon.node/addon.cpp +150 -31
- data/ext/sources/examples/addon.node/index.js +3 -0
- data/ext/sources/examples/addon.node/vad-example.js +132 -0
- data/ext/sources/examples/bench/bench.cpp +3 -2
- data/ext/sources/examples/cli/cli.cpp +3 -2
- data/ext/sources/examples/command/command.cpp +32 -8
- data/ext/sources/examples/common-whisper.cpp +14 -7
- data/ext/sources/examples/lsp/lsp.cpp +2 -0
- data/ext/sources/examples/quantize/quantize.cpp +3 -0
- data/ext/sources/examples/server/CMakeLists.txt +3 -0
- data/ext/sources/examples/server/server.cpp +169 -22
- data/ext/sources/examples/stream/stream.cpp +6 -0
- data/ext/sources/examples/talk-llama/CMakeLists.txt +4 -1
- data/ext/sources/examples/talk-llama/llama-arch.cpp +171 -3
- data/ext/sources/examples/talk-llama/llama-arch.h +28 -1
- data/ext/sources/examples/talk-llama/llama-batch.cpp +741 -272
- data/ext/sources/examples/talk-llama/llama-batch.h +112 -54
- data/ext/sources/examples/talk-llama/llama-chat.cpp +30 -8
- data/ext/sources/examples/talk-llama/llama-chat.h +1 -0
- data/ext/sources/examples/talk-llama/llama-context.cpp +520 -351
- data/ext/sources/examples/talk-llama/llama-context.h +38 -17
- data/ext/sources/examples/talk-llama/llama-cparams.cpp +1 -1
- data/ext/sources/examples/talk-llama/llama-cparams.h +1 -1
- data/ext/sources/examples/talk-llama/llama-graph.cpp +447 -372
- data/ext/sources/examples/talk-llama/llama-graph.h +128 -58
- data/ext/sources/examples/talk-llama/llama-hparams.cpp +10 -2
- data/ext/sources/examples/talk-llama/llama-hparams.h +19 -2
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.cpp +279 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.h +128 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified.cpp +1841 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache-unified.h +303 -0
- data/ext/sources/examples/talk-llama/llama-kv-cache.h +14 -472
- data/ext/sources/examples/talk-llama/llama-kv-cells.h +86 -26
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +246 -0
- data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +138 -0
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +1125 -0
- data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +183 -0
- data/ext/sources/examples/talk-llama/llama-memory.cpp +58 -0
- data/ext/sources/examples/talk-llama/llama-memory.h +88 -4
- data/ext/sources/examples/talk-llama/llama-mmap.cpp +1 -1
- data/ext/sources/examples/talk-llama/llama-model-loader.cpp +42 -17
- data/ext/sources/examples/talk-llama/llama-model-saver.cpp +1 -0
- data/ext/sources/examples/talk-llama/llama-model.cpp +1863 -563
- data/ext/sources/examples/talk-llama/llama-model.h +27 -0
- data/ext/sources/examples/talk-llama/llama-quant.cpp +89 -6
- data/ext/sources/examples/talk-llama/llama-vocab.cpp +65 -28
- data/ext/sources/examples/talk-llama/llama-vocab.h +1 -0
- data/ext/sources/examples/talk-llama/llama.cpp +11 -7
- data/ext/sources/examples/talk-llama/llama.h +147 -40
- data/ext/sources/examples/talk-llama/talk-llama.cpp +2 -0
- data/ext/sources/examples/talk-llama/unicode.cpp +5 -0
- data/ext/sources/examples/vad-speech-segments/speech.cpp +6 -0
- data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +2 -0
- data/ext/sources/ggml/CMakeLists.txt +48 -3
- data/ext/sources/ggml/cmake/common.cmake +24 -0
- data/ext/sources/ggml/include/ggml-backend.h +1 -1
- data/ext/sources/ggml/include/ggml-cpu.h +2 -0
- data/ext/sources/ggml/include/ggml.h +144 -5
- data/ext/sources/ggml/src/CMakeLists.txt +82 -24
- data/ext/sources/ggml/src/ggml-backend-reg.cpp +5 -0
- data/ext/sources/ggml/src/ggml-backend.cpp +46 -23
- data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +3 -3
- data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +1 -0
- data/ext/sources/ggml/src/ggml-cann/common.h +6 -1
- data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +33 -9
- data/ext/sources/ggml/src/ggml-common.h +4 -0
- data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +133 -40
- data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +1 -1
- data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +11 -10
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +4114 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +2163 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +2639 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +2732 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +2069 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +397 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +1300 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +1481 -0
- data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +4311 -0
- data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-aarch64.cpp → arch/x86/repack.cpp} +79 -3225
- data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +184 -0
- data/ext/sources/ggml/src/ggml-cpu/common.h +4 -3
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +16 -7
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +146 -105
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +12 -8
- data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-hbm.cpp → hbm.cpp} +1 -1
- data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +1 -1
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +58 -8
- data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +5 -0
- data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1057 -174
- data/ext/sources/ggml/src/ggml-cpu/ops.h +8 -0
- data/ext/sources/ggml/src/ggml-cpu/quants.c +1158 -0
- data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-quants.h → quants.h} +26 -0
- data/ext/sources/ggml/src/ggml-cpu/repack.cpp +1571 -0
- data/ext/sources/ggml/src/ggml-cpu/repack.h +98 -0
- data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +330 -38
- data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-traits.cpp → traits.cpp} +1 -1
- data/ext/sources/ggml/src/ggml-cpu/vec.cpp +111 -18
- data/ext/sources/ggml/src/ggml-cpu/vec.h +303 -94
- data/ext/sources/ggml/src/ggml-cuda/common.cuh +60 -37
- data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cu +161 -0
- data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cu +91 -0
- data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cuh +4 -0
- data/ext/sources/ggml/src/ggml-cuda/convert.cu +22 -0
- data/ext/sources/ggml/src/ggml-cuda/convert.cuh +5 -0
- data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +2 -2
- data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +5 -2
- data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +4 -0
- data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +265 -123
- data/ext/sources/ggml/src/ggml-cuda/mean.cu +19 -0
- data/ext/sources/ggml/src/ggml-cuda/mean.cuh +3 -0
- data/ext/sources/ggml/src/ggml-cuda/mmv.cu +257 -87
- data/ext/sources/ggml/src/ggml-cuda/mmv.cuh +2 -3
- data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +6 -4
- data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +5 -18
- data/ext/sources/ggml/src/ggml-cuda/sumrows.cuh +0 -1
- data/ext/sources/ggml/src/ggml-cuda/unary.cu +89 -0
- data/ext/sources/ggml/src/ggml-cuda/unary.cuh +7 -0
- data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +4 -0
- data/ext/sources/ggml/src/ggml-impl.h +127 -183
- data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +11 -10
- data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +27 -0
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.m +331 -49
- data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +564 -282
- data/ext/sources/ggml/src/ggml-musa/mudnn.cuh +2 -2
- data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +14 -0
- data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +1859 -489
- data/ext/sources/ggml/src/ggml-opencl/kernels/argsort.cl +86 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +109 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/div.cl +72 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/glu.cl +201 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/group_norm.cl +72 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl +283 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +30 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +39 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sigmoid.cl +29 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sub.cl +72 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +39 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +63 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/tsembd.cl +48 -0
- data/ext/sources/ggml/src/ggml-opencl/kernels/upscale.cl +121 -0
- data/ext/sources/ggml/src/ggml-quants.c +6 -8
- data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +18 -15
- data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +3 -3
- data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +5 -6
- data/ext/sources/ggml/src/ggml-sycl/common.hpp +20 -48
- data/ext/sources/ggml/src/ggml-sycl/concat.cpp +28 -41
- data/ext/sources/ggml/src/ggml-sycl/conv.cpp +4 -10
- data/ext/sources/ggml/src/ggml-sycl/convert.cpp +117 -165
- data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +192 -53
- data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +32 -0
- data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +49 -67
- data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +31 -1
- data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +648 -1039
- data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +18 -9
- data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +3 -0
- data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +8 -105
- data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +238 -100
- data/ext/sources/ggml/src/ggml-sycl/gla.cpp +2 -2
- data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +1 -1
- data/ext/sources/ggml/src/ggml-sycl/mmq.cpp +60 -80
- data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +158 -203
- data/ext/sources/ggml/src/ggml-sycl/norm.cpp +55 -74
- data/ext/sources/ggml/src/ggml-sycl/quants.hpp +38 -10
- data/ext/sources/ggml/src/ggml-sycl/rope.cpp +138 -27
- data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +3 -3
- data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +3 -1
- data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +3 -0
- data/ext/sources/ggml/src/ggml-sycl/tsembd.cpp +3 -8
- data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +108 -16
- data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +12 -16
- data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +36 -32
- data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +726 -282
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -12
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +98 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +15 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +12 -3
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
- data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +10 -1
- data/ext/sources/ggml/src/ggml.c +328 -48
- data/ext/sources/ggml/src/ggml.cpp +26 -0
- data/ext/sources/ggml/src/gguf.cpp +24 -3
- data/ext/sources/include/whisper.h +2 -0
- data/ext/sources/src/CMakeLists.txt +2 -0
- data/ext/sources/src/coreml/whisper-compat.h +10 -0
- data/ext/sources/src/coreml/whisper-compat.m +35 -0
- data/ext/sources/src/coreml/whisper-decoder-impl.m +1 -0
- data/ext/sources/src/coreml/whisper-encoder-impl.m +1 -0
- data/ext/sources/src/whisper.cpp +218 -169
- data/extsources.rb +15 -9
- data/lib/whisper/context.rb +15 -0
- data/lib/whisper/model/uri.rb +56 -1
- data/lib/whisper/segment.rb +58 -0
- data/sig/whisper.rbs +68 -38
- data/{tests → test}/helper.rb +1 -12
- data/{tests → test}/test_model.rb +9 -0
- data/test/test_package.rb +51 -0
- data/test/test_segment.rb +146 -0
- data/{tests → test}/test_whisper.rb +70 -0
- data/whispercpp.gemspec +2 -3
- metadata +91 -43
- data/ext/sources/.dockerignore +0 -3
- data/ext/sources/.github/workflows/bindings-ruby.yml +0 -21
- data/ext/sources/ci/run.sh +0 -336
- data/ext/sources/close-issue.yml +0 -28
- data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +0 -2739
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +0 -8
- data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-quants.c +0 -13747
- data/tests/test_package.rb +0 -46
- data/tests/test_segment.rb +0 -74
- /data/ext/sources/ggml/src/ggml-cpu/{cpu-feats-x86.cpp → arch/x86/cpu-feats.cpp} +0 -0
- /data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-hbm.h → hbm.h} +0 -0
- /data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-traits.h → traits.h} +0 -0
- /data/{tests → test}/jfk_reader/.gitignore +0 -0
- /data/{tests → test}/jfk_reader/extconf.rb +0 -0
- /data/{tests → test}/jfk_reader/jfk_reader.c +0 -0
- /data/{tests → test}/test_callback.rb +0 -0
- /data/{tests → test}/test_error.rb +0 -0
- /data/{tests → test}/test_params.rb +0 -0
- /data/{tests → test}/test_vad.rb +0 -0
- /data/{tests → test}/test_vad_params.rb +0 -0
@@ -0,0 +1,1481 @@
#define GGML_COMMON_IMPL_C
#include "ggml-common.h"
#include "ggml-quants.h"
#include "ggml-impl.h"
#include "ggml-cpu.h"
#include "simd-mappings.h"

#include "../../quants.h"
#include "../../ggml-cpu-impl.h"

#include <math.h>
#include <string.h>
#include <assert.h>
#include <float.h>
#include <stdlib.h> // for qsort
#include <stdio.h>  // for GGML_ASSERT

#define GROUP_MAX_EPS 1e-15f
#define GROUP_MAX_EPS_IQ3_XXS 1e-8f
#define GROUP_MAX_EPS_IQ2_S 1e-8f
#define GROUP_MAX_EPS_IQ1_M 1e-7f
#define GROUP_MAX_EPS_IQ1_S 1e-12f

#define UNUSED GGML_UNUSED

#if defined(__wasm_simd128__)
#define B1(c,s,n)  0x ## n ## c , 0x ## n ## s
#define B2(c,s,n) B1(c,s,n ## c), B1(c,s,n ## s)
#define B3(c,s,n) B2(c,s,n ## c), B2(c,s,n ## s)
#define B4(c,s,n) B3(c,s,n ## c), B3(c,s,n ## s)
#define B5(c,s,n) B4(c,s,n ## c), B4(c,s,n ## s)
#define B6(c,s,n) B5(c,s,n ## c), B5(c,s,n ## s)
#define B7(c,s,n) B6(c,s,n ## c), B6(c,s,n ## s)
#define B8(c,s  ) B7(c,s,     c), B7(c,s,     s)

// precomputed tables for expanding 8bits to 8 bytes:
static const uint64_t table_b2b_0[1 << 8] = { B8(00, 10) }; // ( b) << 4
static const uint64_t table_b2b_1[1 << 8] = { B8(10, 00) }; // (!b) << 4
#endif

void quantize_row_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
    assert(QK8_0 == 32);
    assert(k % QK8_0 == 0);
    const int nb = k / QK8_0;

    block_q8_0 * GGML_RESTRICT y = vy;

#if defined __wasm_simd128__
    for (int i = 0; i < nb; i++) {
        v128_t srcv [8];
        v128_t asrcv[8];
        v128_t amaxv[8];

        for (int j = 0; j < 8; j++) srcv[j]  = wasm_v128_load(x + i*32 + 4*j);
        for (int j = 0; j < 8; j++) asrcv[j] = wasm_f32x4_abs(srcv[j]);

        for (int j = 0; j < 4; j++) amaxv[2*j] = wasm_f32x4_max(asrcv[2*j], asrcv[2*j+1]);
        for (int j = 0; j < 2; j++) amaxv[4*j] = wasm_f32x4_max(amaxv[4*j], amaxv[4*j+2]);
        for (int j = 0; j < 1; j++) amaxv[8*j] = wasm_f32x4_max(amaxv[8*j], amaxv[8*j+4]);

        const float amax = MAX(MAX(wasm_f32x4_extract_lane(amaxv[0], 0),
                                   wasm_f32x4_extract_lane(amaxv[0], 1)),
                               MAX(wasm_f32x4_extract_lane(amaxv[0], 2),
                                   wasm_f32x4_extract_lane(amaxv[0], 3)));

        const float d = amax / ((1 << 7) - 1);
        const float id = d ? 1.0f/d : 0.0f;

        y[i].d = GGML_CPU_FP32_TO_FP16(d);

        for (int j = 0; j < 8; j++) {
            const v128_t v  = wasm_f32x4_mul(srcv[j], wasm_f32x4_splat(id));
            const v128_t vi = wasm_i32x4_trunc_sat_f32x4(v);

            y[i].qs[4*j + 0] = wasm_i32x4_extract_lane(vi, 0);
            y[i].qs[4*j + 1] = wasm_i32x4_extract_lane(vi, 1);
            y[i].qs[4*j + 2] = wasm_i32x4_extract_lane(vi, 2);
            y[i].qs[4*j + 3] = wasm_i32x4_extract_lane(vi, 3);
        }
    }
#else
    GGML_UNUSED(nb);
    // scalar
    quantize_row_q8_0_ref(x, y, k);
#endif
}

void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
    assert(k % QK8_1 == 0);
    const int nb = k / QK8_1;

    block_q8_1 * GGML_RESTRICT y = vy;
#if defined __wasm_simd128__
    for (int i = 0; i < nb; i++) {
        v128_t srcv [8];
        v128_t asrcv[8];
        v128_t amaxv[8];

        for (int j = 0; j < 8; j++) srcv[j]  = wasm_v128_load(x + i*32 + 4*j);
        for (int j = 0; j < 8; j++) asrcv[j] = wasm_f32x4_abs(srcv[j]);

        for (int j = 0; j < 4; j++) amaxv[2*j] = wasm_f32x4_max(asrcv[2*j], asrcv[2*j+1]);
        for (int j = 0; j < 2; j++) amaxv[4*j] = wasm_f32x4_max(amaxv[4*j], amaxv[4*j+2]);
        for (int j = 0; j < 1; j++) amaxv[8*j] = wasm_f32x4_max(amaxv[8*j], amaxv[8*j+4]);

        const float amax = MAX(MAX(wasm_f32x4_extract_lane(amaxv[0], 0),
                                   wasm_f32x4_extract_lane(amaxv[0], 1)),
                               MAX(wasm_f32x4_extract_lane(amaxv[0], 2),
                                   wasm_f32x4_extract_lane(amaxv[0], 3)));

        const float d = amax / ((1 << 7) - 1);
        const float id = d ? 1.0f/d : 0.0f;

        y[i].d = GGML_CPU_FP32_TO_FP16(d);

        v128_t accv = wasm_i32x4_splat(0);

        for (int j = 0; j < 8; j++) {
            const v128_t v  = wasm_f32x4_mul(srcv[j], wasm_f32x4_splat(id));
            const v128_t vi = wasm_i32x4_trunc_sat_f32x4(v);

            y[i].qs[4*j + 0] = wasm_i32x4_extract_lane(vi, 0);
            y[i].qs[4*j + 1] = wasm_i32x4_extract_lane(vi, 1);
            y[i].qs[4*j + 2] = wasm_i32x4_extract_lane(vi, 2);
            y[i].qs[4*j + 3] = wasm_i32x4_extract_lane(vi, 3);

            accv = wasm_i32x4_add(accv, vi);
        }

        y[i].s = GGML_CPU_FP32_TO_FP16(
                d * (wasm_i32x4_extract_lane(accv, 0) +
                     wasm_i32x4_extract_lane(accv, 1) +
                     wasm_i32x4_extract_lane(accv, 2) +
                     wasm_i32x4_extract_lane(accv, 3)));
    }
#else
    GGML_UNUSED(nb);
    // scalar
    quantize_row_q8_1_ref(x, y, k);
#endif
}

//===================================== Q8_K ==============================================

void quantize_row_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) {
#ifdef __wasm_simd128__
    assert(k % QK_K == 0);
    const int64_t nb = k / QK_K;
    block_q8_K * GGML_RESTRICT yc = y; // Cast to proper type

    for (int i = 0; i < nb; i++) {
        const float * x_block = x + i * QK_K;

        v128_t min_vec = wasm_v128_load(x_block);
        v128_t max_vec = min_vec;

        for (int j = 4; j < QK_K; j += 4) {
            v128_t x_vec = wasm_v128_load(x_block + j);
            max_vec = wasm_f32x4_pmax(max_vec, x_vec);
            min_vec = wasm_f32x4_pmin(min_vec, x_vec);
        }
        max_vec = wasm_f32x4_pmax(max_vec, wasm_i32x4_shuffle(max_vec, max_vec, 2, 3, 0, 1));
        max_vec = wasm_f32x4_pmax(max_vec, wasm_i32x4_shuffle(max_vec, max_vec, 1, 0, 3, 2));
        min_vec = wasm_f32x4_pmin(min_vec, wasm_i32x4_shuffle(min_vec, min_vec, 2, 3, 0, 1));
        min_vec = wasm_f32x4_pmin(min_vec, wasm_i32x4_shuffle(min_vec, min_vec, 1, 0, 3, 2));
        float max = wasm_f32x4_extract_lane(max_vec, 0);
        float min = wasm_f32x4_extract_lane(min_vec, 0);
        float amax = -min > max ? min : max;

        if (amax == 0.0f) {
            yc[i].d = 0.0f;
            const v128_t zero = wasm_i8x16_splat(0);
            for (int j = 0; j < QK_K; j += 16) {
                wasm_v128_store(yc[i].qs + j, zero);
            }
            continue;
        }

        const float iscale = -127.0f / amax;
        const v128_t scale_vec = wasm_f32x4_splat(iscale);

        // Process 16 elements per iteration
        for (int j = 0, jb = 0; j < QK_K; j += 16, jb++) {
            // Load and quantize 16 floats
            v128_t x0 = wasm_v128_load(x_block + j);
            v128_t x1 = wasm_v128_load(x_block + j + 4);
            v128_t x2 = wasm_v128_load(x_block + j + 8);
            v128_t x3 = wasm_v128_load(x_block + j + 12);

            v128_t q0 = wasm_f32x4_nearest(wasm_f32x4_mul(x0, scale_vec));
            v128_t q1 = wasm_f32x4_nearest(wasm_f32x4_mul(x1, scale_vec));
            v128_t q2 = wasm_f32x4_nearest(wasm_f32x4_mul(x2, scale_vec));
            v128_t q3 = wasm_f32x4_nearest(wasm_f32x4_mul(x3, scale_vec));

            // Convert to i32 with saturation
            v128_t i0 = wasm_i32x4_trunc_sat_f32x4(q0);
            v128_t i1 = wasm_i32x4_trunc_sat_f32x4(q1);
            v128_t i2 = wasm_i32x4_trunc_sat_f32x4(q2);
            v128_t i3 = wasm_i32x4_trunc_sat_f32x4(q3);

            // Pack into 16 i8 values
            v128_t i8 = wasm_i8x16_narrow_i16x8(
                wasm_i16x8_narrow_i32x4(i0, i1),
                wasm_i16x8_narrow_i32x4(i2, i3)
            );
            wasm_v128_store(yc[i].qs + j, i8);

            // Calculate bsums using SIMD
            v128_t sum16 = wasm_i16x8_add(
                wasm_i16x8_extend_low_i8x16(i8),
                wasm_i16x8_extend_high_i8x16(i8)
            );
            v128_t sum32 = wasm_i32x4_add(
                wasm_i32x4_extend_low_i16x8(sum16),
                wasm_i32x4_extend_high_i16x8(sum16)
            );
            sum32 = wasm_i32x4_add(sum32, wasm_i32x4_shuffle(sum32, sum32, 2, 3, 0, 1));
            sum32 = wasm_i32x4_add(sum32, wasm_i32x4_shuffle(sum32, sum32, 1, 0, 3, 2));
            yc[i].bsums[jb] = wasm_i32x4_extract_lane(sum32, 0);
        }

        yc[i].d = 1.0f / iscale;
    }
#else
    quantize_row_q8_K_ref(x, y, k);
#endif
}


//===================================== Dot products =================================

void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    const int qk = QK8_0;
    const int nb = n / qk;

    assert(n % qk == 0);
    assert(nrc == 1);
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);

    const block_q4_0 * GGML_RESTRICT x = vx;
    const block_q8_0 * GGML_RESTRICT y = vy;

    int ib = 0;
    float sumf = 0;

#if defined __wasm_simd128__
    v128_t sumv = wasm_f32x4_splat(0.0f);

    const v128_t m4b = wasm_i8x16_splat(0x0F);
    const v128_t s8b = wasm_i8x16_splat(0x8);

    for (; ib + 1 < nb; ib += 2) {
        const block_q4_0 * GGML_RESTRICT x0 = &x[ib];
        const block_q4_0 * GGML_RESTRICT x1 = &x[ib + 1];
        const block_q8_0 * GGML_RESTRICT y0 = &y[ib];
        const block_q8_0 * GGML_RESTRICT y1 = &y[ib + 1];

        // Load and process x0
        v128_t v0_0 = wasm_v128_load(x0->qs);
        v128_t v0_0l = wasm_v128_and(v0_0, m4b);
        v128_t v0_0h = wasm_u8x16_shr(v0_0, 4);
        v128_t v0_0ls = wasm_i8x16_sub(v0_0l, s8b);
        v128_t v0_0hs = wasm_i8x16_sub(v0_0h, s8b);

        // Load y0 vectors
        v128_t y0_l = wasm_v128_load(y0->qs);
        v128_t y0_h = wasm_v128_load(y0->qs + 16);

        // Extend to i16x8 and compute dot products
        v128_t dx0l = wasm_i16x8_extend_low_i8x16(v0_0ls);
        v128_t dx0h = wasm_i16x8_extend_high_i8x16(v0_0ls);
        v128_t dx0hl = wasm_i16x8_extend_low_i8x16(v0_0hs);
        v128_t dx0hh = wasm_i16x8_extend_high_i8x16(v0_0hs);

        v128_t dy0ll = wasm_i16x8_extend_low_i8x16(y0_l);
        v128_t dy0lh = wasm_i16x8_extend_high_i8x16(y0_l);
        v128_t dy0hl = wasm_i16x8_extend_low_i8x16(y0_h);
        v128_t dy0hh = wasm_i16x8_extend_high_i8x16(y0_h);

        v128_t dp0 = wasm_i32x4_add(
            wasm_i32x4_add(
                wasm_i32x4_dot_i16x8(dx0l, dy0ll),
                wasm_i32x4_dot_i16x8(dx0h, dy0lh)
            ),
            wasm_i32x4_add(
                wasm_i32x4_dot_i16x8(dx0hl, dy0hl),
                wasm_i32x4_dot_i16x8(dx0hh, dy0hh)
            )
        );

        // Load and process x1
        v128_t v0_1 = wasm_v128_load(x1->qs);
        v128_t v0_1l = wasm_v128_and(v0_1, m4b);
        v128_t v0_1h = wasm_u8x16_shr(v0_1, 4);
        v128_t v0_1ls = wasm_i8x16_sub(v0_1l, s8b);
        v128_t v0_1hs = wasm_i8x16_sub(v0_1h, s8b);

        // Load y1 vectors
        v128_t y1_l = wasm_v128_load(y1->qs);
        v128_t y1_h = wasm_v128_load(y1->qs + 16);

        // Extend to i16x8 and compute dot products
        v128_t dx1l = wasm_i16x8_extend_low_i8x16(v0_1ls);
        v128_t dx1h = wasm_i16x8_extend_high_i8x16(v0_1ls);
        v128_t dx1hl = wasm_i16x8_extend_low_i8x16(v0_1hs);
        v128_t dx1hh = wasm_i16x8_extend_high_i8x16(v0_1hs);

        v128_t dy1ll = wasm_i16x8_extend_low_i8x16(y1_l);
        v128_t dy1lh = wasm_i16x8_extend_high_i8x16(y1_l);
        v128_t dy1hl = wasm_i16x8_extend_low_i8x16(y1_h);
        v128_t dy1hh = wasm_i16x8_extend_high_i8x16(y1_h);

        v128_t dp1 = wasm_i32x4_add(
            wasm_i32x4_add(
                wasm_i32x4_dot_i16x8(dx1l, dy1ll),
                wasm_i32x4_dot_i16x8(dx1h, dy1lh)
            ),
            wasm_i32x4_add(
                wasm_i32x4_dot_i16x8(dx1hl, dy1hl),
                wasm_i32x4_dot_i16x8(dx1hh, dy1hh)
            )
        );

        // Accumulate results with scaling
        float scale0 = GGML_CPU_FP16_TO_FP32(x0->d) * GGML_CPU_FP16_TO_FP32(y0->d);
        float scale1 = GGML_CPU_FP16_TO_FP32(x1->d) * GGML_CPU_FP16_TO_FP32(y1->d);

        sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4(dp0), wasm_f32x4_splat(scale0)));
        sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4(dp1), wasm_f32x4_splat(scale1)));
    }

    sumf = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) +
           wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3);

#endif
    for (; ib < nb; ++ib) {
        int sumi0 = 0;
        int sumi1 = 0;

        for (int j = 0; j < qk/2; ++j) {
            const int v0 = (x[ib].qs[j] & 0x0F) - 8;
            const int v1 = (x[ib].qs[j] >>   4) - 8;

            sumi0 += (v0 * y[ib].qs[j]);
            sumi1 += (v1 * y[ib].qs[j + qk/2]);
        }

        int sumi = sumi0 + sumi1;
        sumf += sumi*GGML_CPU_FP16_TO_FP32(x[ib].d)*GGML_CPU_FP16_TO_FP32(y[ib].d);
    }

    *s = sumf;
}

void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    const int qk = QK8_0;
    const int nb = n / qk;

    int ib = 0;
    float sumf = 0;

    assert(n % qk == 0);
    assert(qk == QK5_0);
    assert(nrc == 1);
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);

    const block_q5_0 * GGML_RESTRICT x = vx;
    const block_q8_0 * GGML_RESTRICT y = vy;

#if defined __wasm_simd128__
    v128_t sumv = wasm_f32x4_splat(0.0f);

    uint32_t qh_;
    uint64_t tmp[4];

    // TODO: check if unrolling this is better
    for (; ib < nb; ++ib) {
        const block_q5_0 * GGML_RESTRICT x0 = &x[ib];
        const block_q8_0 * GGML_RESTRICT y0 = &y[ib];

        const v128_t m4b = wasm_i8x16_splat(0x0F);

        // extract the 5th bit
        memcpy(&qh_, x0->qh, sizeof(qh_));

        tmp[0] = table_b2b_1[(qh_ >>  0) & 0xFF];
        tmp[1] = table_b2b_1[(qh_ >>  8) & 0xFF];
        tmp[2] = table_b2b_1[(qh_ >> 16) & 0xFF];
        tmp[3] = table_b2b_1[(qh_ >> 24)       ];

        const v128_t qhl = wasm_v128_load(tmp + 0);
        const v128_t qhh = wasm_v128_load(tmp + 2);

        const v128_t v0 = wasm_v128_load(x0->qs);

        // 4-bit -> 8-bit
        const v128_t v0l = wasm_v128_and (v0, m4b);
        const v128_t v0h = wasm_u8x16_shr(v0, 4);

        // add high bit and sub 16 (equivalent to sub 0x10 when bit is zero)
        const v128_t v0lf = wasm_i8x16_sub(v0l, qhl);
        const v128_t v0hf = wasm_i8x16_sub(v0h, qhh);

        // load y
        const v128_t v1l = wasm_v128_load(y0->qs);
        const v128_t v1h = wasm_v128_load(y0->qs + 16);

        // int8x16 -> int16x8
        const v128_t v0lfl = wasm_i16x8_extend_low_i8x16 (v0lf);
        const v128_t v0lfh = wasm_i16x8_extend_high_i8x16(v0lf);
        const v128_t v0hfl = wasm_i16x8_extend_low_i8x16 (v0hf);
        const v128_t v0hfh = wasm_i16x8_extend_high_i8x16(v0hf);

        const v128_t v1ll = wasm_i16x8_extend_low_i8x16 (v1l);
        const v128_t v1lh = wasm_i16x8_extend_high_i8x16(v1l);
        const v128_t v1hl = wasm_i16x8_extend_low_i8x16 (v1h);
        const v128_t v1hh = wasm_i16x8_extend_high_i8x16(v1h);

        // dot product
        sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4(
                        wasm_i32x4_add(
                            wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0lfl, v1ll),
                                           wasm_i32x4_dot_i16x8(v0lfh, v1lh)),
                            wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0hfl, v1hl),
                                           wasm_i32x4_dot_i16x8(v0hfh, v1hh)))),
                    wasm_f32x4_splat(GGML_CPU_FP16_TO_FP32(x0->d) * GGML_CPU_FP16_TO_FP32(y0->d))));
    }

    sumf = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) +
           wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3);

#endif
    for (; ib < nb; ++ib) {
        uint32_t qh;
        memcpy(&qh, x[ib].qh, sizeof(qh));

        int sumi0 = 0;
        int sumi1 = 0;

        for (int j = 0; j < qk/2; ++j) {
            const uint8_t xh_0 = ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4;
            const uint8_t xh_1 = ((qh & (1u << (j + 16))) >> (j + 12));

            const int32_t x0 = (int8_t)(((x[ib].qs[j] & 0x0F) | xh_0) - 16);
            const int32_t x1 = (int8_t)(((x[ib].qs[j] >>  4) | xh_1) - 16);

            sumi0 += (x0 * y[ib].qs[j]);
            sumi1 += (x1 * y[ib].qs[j + qk/2]);
        }

        int sumi = sumi0 + sumi1;
        sumf += (GGML_CPU_FP16_TO_FP32(x[ib].d)*GGML_CPU_FP16_TO_FP32(y[ib].d)) * sumi;
    }

    *s = sumf;
}

void ggml_vec_dot_q5_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    const int qk = QK8_1;
    const int nb = n / qk;

    int ib = 0;
    float sumf = 0;

    assert(n % qk == 0);
    assert(qk == QK5_1);
    assert(nrc == 1);
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);

    const block_q5_1 * GGML_RESTRICT x = vx;
    const block_q8_1 * GGML_RESTRICT y = vy;

#if defined __wasm_simd128__
    v128_t sumv = wasm_f32x4_splat(0.0f);

    float summs = 0.0f;

    uint32_t qh_;
    uint64_t tmp[4];

    // TODO: check if unrolling this is better
    for (; ib < nb; ++ib) {
        const block_q5_1 * GGML_RESTRICT x0 = &x[ib];
        const block_q8_1 * GGML_RESTRICT y0 = &y[ib];

        summs += GGML_CPU_FP16_TO_FP32(x0->m) * GGML_CPU_FP16_TO_FP32(y0->s);

        const v128_t m4b = wasm_i8x16_splat(0x0F);

        // extract the 5th bit
        memcpy(&qh_, x0->qh, sizeof(qh_));

        tmp[0] = table_b2b_0[(qh_ >>  0) & 0xFF];
        tmp[1] = table_b2b_0[(qh_ >>  8) & 0xFF];
        tmp[2] = table_b2b_0[(qh_ >> 16) & 0xFF];
        tmp[3] = table_b2b_0[(qh_ >> 24)       ];

        const v128_t qhl = wasm_v128_load(tmp + 0);
        const v128_t qhh = wasm_v128_load(tmp + 2);

        const v128_t v0 = wasm_v128_load(x0->qs);

        // 4-bit -> 8-bit
        const v128_t v0l = wasm_v128_and (v0, m4b);
        const v128_t v0h = wasm_u8x16_shr(v0, 4);

        // add high bit
        const v128_t v0lf = wasm_v128_or(v0l, qhl);
        const v128_t v0hf = wasm_v128_or(v0h, qhh);

        // load y
        const v128_t v1l = wasm_v128_load(y0->qs);
        const v128_t v1h = wasm_v128_load(y0->qs + 16);

        // int8x16 -> int16x8
        const v128_t v0lfl = wasm_i16x8_extend_low_i8x16 (v0lf);
        const v128_t v0lfh = wasm_i16x8_extend_high_i8x16(v0lf);
        const v128_t v0hfl = wasm_i16x8_extend_low_i8x16 (v0hf);
        const v128_t v0hfh = wasm_i16x8_extend_high_i8x16(v0hf);

        const v128_t v1ll = wasm_i16x8_extend_low_i8x16 (v1l);
        const v128_t v1lh = wasm_i16x8_extend_high_i8x16(v1l);
        const v128_t v1hl = wasm_i16x8_extend_low_i8x16 (v1h);
        const v128_t v1hh = wasm_i16x8_extend_high_i8x16(v1h);

        // dot product
        sumv = wasm_f32x4_add(sumv,
            wasm_f32x4_mul(wasm_f32x4_convert_i32x4(wasm_i32x4_add(
                    wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0lfl, v1ll),
                                   wasm_i32x4_dot_i16x8(v0lfh, v1lh)),
                    wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0hfl, v1hl),
                                   wasm_i32x4_dot_i16x8(v0hfh, v1hh)))),
                wasm_f32x4_splat(GGML_CPU_FP16_TO_FP32(x0->d) * GGML_CPU_FP16_TO_FP32(y0->d))));
    }

    sumf = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) +
           wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3) + summs;

#endif
    for (; ib < nb; ++ib) {
        uint32_t qh;
        memcpy(&qh, x[ib].qh, sizeof(qh));

        int sumi0 = 0;
        int sumi1 = 0;

        for (int j = 0; j < qk/2; ++j) {
            const uint8_t xh_0 = ((qh >> (j +  0)) << 4) & 0x10;
            const uint8_t xh_1 = ((qh >> (j + 12))     ) & 0x10;

            const int32_t x0 = (x[ib].qs[j] & 0xF) | xh_0;
            const int32_t x1 = (x[ib].qs[j] >>  4) | xh_1;

            sumi0 += (x0 * y[ib].qs[j]);
            sumi1 += (x1 * y[ib].qs[j + qk/2]);
        }

        int sumi = sumi0 + sumi1;
        sumf += (GGML_CPU_FP16_TO_FP32(x[ib].d)*GGML_CPU_FP16_TO_FP32(y[ib].d))*sumi + GGML_CPU_FP16_TO_FP32(x[ib].m)*GGML_CPU_FP16_TO_FP32(y[ib].s);
    }

    *s = sumf;
}

void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    const int qk = QK8_0;
    const int nb = n / qk;

    assert(n % qk == 0);
    assert(nrc == 1);
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);

    const block_q8_0 * GGML_RESTRICT x = vx;
    const block_q8_0 * GGML_RESTRICT y = vy;

    int ib = 0;
    float sumf = 0;

#if defined __wasm_simd128__
    v128_t sumv = wasm_f32x4_splat(0.0f);

    for (; ib < nb; ++ib) {
        const block_q8_0 * GGML_RESTRICT x0 = &x[ib];
        const block_q8_0 * GGML_RESTRICT y0 = &y[ib];

        const v128_t x0_0 = wasm_v128_load(x0->qs);
        const v128_t x0_1 = wasm_v128_load(x0->qs + 16);
        const v128_t y0_0 = wasm_v128_load(y0->qs);
        const v128_t y0_1 = wasm_v128_load(y0->qs + 16);

        // Extend 8-bit to 16-bit
        const v128_t x0_0l = wasm_i16x8_extend_low_i8x16(x0_0);
        const v128_t x0_0h = wasm_i16x8_extend_high_i8x16(x0_0);
        const v128_t x0_1l = wasm_i16x8_extend_low_i8x16(x0_1);
        const v128_t x0_1h = wasm_i16x8_extend_high_i8x16(x0_1);

        const v128_t y0_0l = wasm_i16x8_extend_low_i8x16(y0_0);
        const v128_t y0_0h = wasm_i16x8_extend_high_i8x16(y0_0);
        const v128_t y0_1l = wasm_i16x8_extend_low_i8x16(y0_1);
        const v128_t y0_1h = wasm_i16x8_extend_high_i8x16(y0_1);

        // Compute dot products
        const v128_t dx0_0 = wasm_i32x4_dot_i16x8(x0_0l, y0_0l);
        const v128_t dx0_1 = wasm_i32x4_dot_i16x8(x0_0h, y0_0h);
        const v128_t dx1_0 = wasm_i32x4_dot_i16x8(x0_1l, y0_1l);
        const v128_t dx1_1 = wasm_i32x4_dot_i16x8(x0_1h, y0_1h);

        // Sum all dot products
        const v128_t sum_dots = wasm_i32x4_add(wasm_i32x4_add(dx0_0, dx0_1), wasm_i32x4_add(dx1_0, dx1_1));

        // Convert to float and accumulate
        const float scale = GGML_CPU_FP16_TO_FP32(x0->d) * GGML_CPU_FP16_TO_FP32(y0->d);
        sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4(sum_dots), wasm_f32x4_splat(scale)));
    }

    sumf = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) +
           wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3);

#endif
    for (; ib < nb; ++ib) {
        int sumi = 0;

        for (int j = 0; j < qk; j++) {
            sumi += x[ib].qs[j]*y[ib].qs[j];
        }

        sumf += sumi*(GGML_CPU_FP16_TO_FP32(x[ib].d)*GGML_CPU_FP16_TO_FP32(y[ib].d));
    }

    *s = sumf;
}

void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    assert(nrc == 1);
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);

    const block_q2_K * GGML_RESTRICT x = vx;
    const block_q8_K * GGML_RESTRICT y = vy;

    const int nb = n / QK_K;

#if defined __wasm_simd128__
    float sumf = 0;

    for (int i = 0; i < nb; ++i) {
        const uint8_t * q2 = x[i].qs;
        const int8_t  * q8 = y[i].qs;
        const uint8_t * sc = x[i].scales;

        // Vectorized summs calculation
        v128_t summs_vec = wasm_i32x4_splat(0);
        {
            v128_t sc_vec = wasm_v128_load(sc);
            v128_t sc_upper = wasm_u8x16_shr(sc_vec, 4);

            v128_t sc_low = wasm_u16x8_extend_low_u8x16(sc_upper);
            v128_t sc_high = wasm_u16x8_extend_high_u8x16(sc_upper);

            v128_t bsums1 = wasm_v128_load(&y[i].bsums[0]);
            v128_t bsums2 = wasm_v128_load(&y[i].bsums[8]);

            summs_vec = wasm_i32x4_add(
                wasm_i32x4_add(wasm_i32x4_dot_i16x8(sc_low, bsums1),
                               wasm_i32x4_dot_i16x8(sc_high, bsums2)),
                summs_vec
            );

            summs_vec = wasm_i32x4_add(summs_vec, wasm_i32x4_shuffle(summs_vec, summs_vec, 2, 3, 0, 1));
            summs_vec = wasm_i32x4_add(summs_vec, wasm_i32x4_shuffle(summs_vec, summs_vec, 1, 0, 3, 2));
        }
        int32_t summs = wasm_i32x4_extract_lane(summs_vec, 0);

        // Vectorized isum calculation
        int32_t isum = 0;
        const uint8_t * sc_ptr = sc;
        const int k_iters = QK_K/128;

        for (int k = 0; k < k_iters; ++k) {
            v128_t isum_vec = wasm_i32x4_splat(0);
            int shift = 0;

            for (int j = 0; j < 4; ++j) {
                const int d0 = (sc_ptr[0] & 0xF);
                const int d1 = (sc_ptr[1] & 0xF);
                sc_ptr += 2;

                // Process first 16 elements
                v128_t q2_0 = wasm_v128_load(q2);
                v128_t q8_0 = wasm_v128_load(q8);
                v128_t q2_shift_0 = wasm_u8x16_shr(q2_0, shift);
                v128_t q2_bits_0 = wasm_v128_and(q2_shift_0, wasm_i8x16_splat(0x03));

                // Process next 16 elements
                v128_t q2_1 = wasm_v128_load(q2 + 16);
                v128_t q8_1 = wasm_v128_load(q8 + 16);
                v128_t q2_shift_1 = wasm_u8x16_shr(q2_1, shift);
                v128_t q2_bits_1 = wasm_v128_and(q2_shift_1, wasm_i8x16_splat(0x03));

                // Calculate dot products
                v128_t p0 = wasm_i32x4_dot_i16x8(
                    wasm_i16x8_extend_low_i8x16(q8_0),
                    wasm_i16x8_extend_low_i8x16(q2_bits_0)
                );
                v128_t p1 = wasm_i32x4_dot_i16x8(
                    wasm_i16x8_extend_high_i8x16(q8_0),
                    wasm_i16x8_extend_high_i8x16(q2_bits_0)
                );
                v128_t p2 = wasm_i32x4_dot_i16x8(
                    wasm_i16x8_extend_low_i8x16(q8_1),
                    wasm_i16x8_extend_low_i8x16(q2_bits_1)
                );
                v128_t p3 = wasm_i32x4_dot_i16x8(
                    wasm_i16x8_extend_high_i8x16(q8_1),
                    wasm_i16x8_extend_high_i8x16(q2_bits_1)
                );

                // Accumulate scaled results
                v128_t scaled = wasm_i32x4_add(
                    wasm_i32x4_mul(wasm_i32x4_add(p0, p1), wasm_i32x4_splat(d0)),
                    wasm_i32x4_mul(wasm_i32x4_add(p2, p3), wasm_i32x4_splat(d1))
                );

                isum_vec = wasm_i32x4_add(isum_vec, scaled);
                q8 += 32;
                shift += 2;
            }
            q2 += 32;

            // Horizontal sum of isum_vec
            isum_vec = wasm_i32x4_add(isum_vec, wasm_i32x4_shuffle(isum_vec, isum_vec, 2, 3, 0, 1));
            isum_vec = wasm_i32x4_add(isum_vec, wasm_i32x4_shuffle(isum_vec, isum_vec, 1, 0, 3, 2));
            isum += wasm_i32x4_extract_lane(isum_vec, 0);
        }

        const float dall = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
        const float dmin = GGML_CPU_FP16_TO_FP32(x[i].dmin) * y[i].d;
        sumf += dall * isum - dmin * summs;
    }

    *s = sumf;

#else

    float sumf = 0;

    for (int i = 0; i < nb; ++i) {

        const uint8_t * q2 = x[i].qs;
        const int8_t  * q8 = y[i].qs;
        const uint8_t * sc = x[i].scales;

        int summs = 0;
        for (int j = 0; j < 16; ++j) {
            summs += y[i].bsums[j] * (sc[j] >> 4);
        }

        const float dall = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
        const float dmin = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin);

        int isum = 0;
        int is = 0;
        int d;
        for (int k = 0; k < QK_K/128; ++k) {
            int shift = 0;
            for (int j = 0; j < 4; ++j) {
                d = sc[is++] & 0xF;
                int isuml = 0;
                for (int l = 0; l < 16; ++l) isuml += q8[l] * ((q2[l] >> shift) & 3);
                isum += d * isuml;
                d = sc[is++] & 0xF;
                isuml = 0;
                for (int l = 16; l < 32; ++l) isuml += q8[l] * ((q2[l] >> shift) & 3);
                isum += d * isuml;
                shift += 2;
                q8 += 32;
            }
            q2 += 32;
        }
        sumf += dall * isum - dmin * summs;
    }
    *s = sumf;
#endif
}

void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    assert(n % QK_K == 0);
    assert(nrc == 1);
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);

    const uint32_t kmask1 = 0x03030303;
    const uint32_t kmask2 = 0x0f0f0f0f;

    const block_q3_K * GGML_RESTRICT x = vx;
    const block_q8_K * GGML_RESTRICT y = vy;

    const int nb = n / QK_K;

#if defined __wasm_simd128__
    int8_t aux8[QK_K];
    float sums[8] = {0};
    uint32_t auxs[4];

    float sumf = 0;
    for (int i = 0; i < nb; ++i) {
        const uint8_t * GGML_RESTRICT q3 = x[i].qs;
        const uint8_t * GGML_RESTRICT hm = x[i].hmask;
        const int8_t  * GGML_RESTRICT q8 = y[i].qs;

        // Process blocks with SIMD
        int8_t * a = aux8;
        uint8_t m = 1;
        for (int j = 0; j < QK_K; j += 128) {
            for (int shift = 0; shift <= 6; shift += 2) {
                v128_t v_m = wasm_i8x16_splat(m);
                for (int l = 0; l < 32; l += 16) {
                    v128_t v_q3 = wasm_v128_load(q3 + l);
                    v128_t v_shift = wasm_i8x16_shr(v_q3, shift);
                    v128_t v_low2 = wasm_v128_and(v_shift, wasm_i8x16_splat(0x03));

                    v128_t v_hm = wasm_v128_load(hm + l);
                    v128_t v_mask = wasm_v128_and(v_hm, v_m);
                    v_mask = wasm_i8x16_ne(v_mask, wasm_i8x16_splat(0));

                    v_low2 = wasm_i8x16_sub(v_low2, wasm_v128_and(wasm_i8x16_splat(4), wasm_v128_not(v_mask)));
                    wasm_v128_store(a + l, v_low2);
                }
                a += 32;
                m <<= 1;
            }
            q3 += 32;
        }

        // Extract scales
        memcpy(auxs, x[i].scales, 12);
        uint32_t tmp = auxs[2];
        auxs[2] = ((auxs[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);
        auxs[3] = ((auxs[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4);
        auxs[0] = (auxs[0] & kmask2) | (((tmp >> 0) & kmask1) << 4);
        auxs[1] = (auxs[1] & kmask2) | (((tmp >> 2) & kmask1) << 4);
        const int8_t * scales = (const int8_t *)auxs;

        // SIMD dot product with register accumulators
        v128_t v_acc0 = wasm_i32x4_splat(0);
        v128_t v_acc1 = wasm_i32x4_splat(0);
        a = aux8;
        for (int j = 0; j < QK_K/16; ++j) {
            const v128_t v_scale = wasm_i16x8_splat(scales[j] - 32);

            // Process 16 elements per iteration
            for (int k = 0; k < 2; ++k) {
                const v128_t v_q8 = wasm_i16x8_load8x8(q8);
                const v128_t v_a  = wasm_i16x8_load8x8(a);

                v128_t v_prod = wasm_i16x8_mul(v_q8, v_a);
                v_prod = wasm_i16x8_mul(v_prod, v_scale);

                v_acc0 = wasm_i32x4_add(v_acc0, wasm_i32x4_extend_low_i16x8(v_prod));
                v_acc1 = wasm_i32x4_add(v_acc1, wasm_i32x4_extend_high_i16x8(v_prod));

                q8 += 8;
                a += 8;
            }
        }

        // Accumulate results
        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
        const v128_t v_d = wasm_f32x4_splat(d);
        v128_t v_sum = wasm_f32x4_add(
            wasm_f32x4_mul(wasm_f32x4_convert_i32x4(v_acc0), v_d),
            wasm_f32x4_mul(wasm_f32x4_convert_i32x4(v_acc1), v_d)
        );

        // Accumulate into sums vector
        wasm_v128_store(sums, wasm_f32x4_add(wasm_v128_load(sums), v_sum));
    }

    // Horizontal sum
    v128_t v_sum = wasm_f32x4_add(wasm_v128_load(sums), wasm_v128_load(sums + 4));
    sumf = wasm_f32x4_extract_lane(v_sum, 0) +
           wasm_f32x4_extract_lane(v_sum, 1) +
           wasm_f32x4_extract_lane(v_sum, 2) +
           wasm_f32x4_extract_lane(v_sum, 3);

    *s = sumf;

#else
    // scalar version
    // This function is written like this so the compiler can manage to vectorize most of it
    // Using -Ofast, GCC and clang manage to produce code that is within a factor of 2 or so from the
    // manually vectorized version above. Every other version I tried would run at least 4 times slower.
    // The ideal situation would be if we could just write the code once, and the compiler would
    // automatically produce the best possible set of machine instructions, instead of us having to manually
    // write vectorized versions for AVX, ARM_NEON, etc.

    int8_t  aux8[QK_K];
    int16_t aux16[8];
    float   sums [8];
    int32_t aux32[8];
    memset(sums, 0, 8*sizeof(float));

    uint32_t auxs[4];
    const int8_t * scales = (const int8_t*)auxs;

    float sumf = 0;
    for (int i = 0; i < nb; ++i) {
        const uint8_t * GGML_RESTRICT q3 = x[i].qs;
        const uint8_t * GGML_RESTRICT hm = x[i].hmask;
        const int8_t  * GGML_RESTRICT q8 = y[i].qs;
        memset(aux32, 0, 8*sizeof(int32_t));
        int8_t * GGML_RESTRICT a = aux8;
        uint8_t m = 1;
        for (int j = 0; j < QK_K; j += 128) {
            for (int l = 0; l < 32; ++l) a[l] = q3[l] & 3;
            for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
            a += 32; m <<= 1;
            for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 2) & 3;
            for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
            a += 32; m <<= 1;
            for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 4) & 3;
            for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
            a += 32; m <<= 1;
            for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 6) & 3;
            for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
            a += 32; m <<= 1;
            q3 += 32;
        }
        a = aux8;

        memcpy(auxs, x[i].scales, 12);
        uint32_t tmp = auxs[2];
        auxs[2] = ((auxs[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);
        auxs[3] = ((auxs[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4);
        auxs[0] = (auxs[0] & kmask2) | (((tmp >> 0) & kmask1) << 4);
        auxs[1] = (auxs[1] & kmask2) | (((tmp >> 2) & kmask1) << 4);
        for (int j = 0; j < QK_K/16; ++j) {
            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
            for (int l = 0; l < 8; ++l) aux32[l] += (scales[j] - 32) * aux16[l];
            q8 += 8; a += 8;
            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
            for (int l = 0; l < 8; ++l) aux32[l] += (scales[j] - 32) * aux16[l];
            q8 += 8; a += 8;
        }
        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
        for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
    }
    for (int l = 0; l < 8; ++l) sumf += sums[l];
    *s = sumf;

#endif

}

void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
972
|
+
assert(n % QK_K == 0);
|
973
|
+
assert(nrc == 1);
|
974
|
+
UNUSED(nrc);
|
975
|
+
UNUSED(bx);
|
976
|
+
UNUSED(by);
|
977
|
+
UNUSED(bs);
|
978
|
+
|
979
|
+
const block_q4_K * GGML_RESTRICT x = vx;
|
980
|
+
const block_q8_K * GGML_RESTRICT y = vy;
|
981
|
+
|
982
|
+
const int nb = n / QK_K;
|
983
|
+
|
984
|
+
static const uint32_t kmask1 = 0x3f3f3f3f;
|
985
|
+
static const uint32_t kmask2 = 0x0f0f0f0f;
|
986
|
+
static const uint32_t kmask3 = 0x03030303;
|
987
|
+
|
988
|
+
uint32_t utmp[4];
|
989
|
+
|
990
|
+
#if defined __wasm_simd128__
|
991
|
+
const uint8_t * scales = (const uint8_t*)&utmp[0];
|
992
|
+
float sumf = 0;
|
993
|
+
|
994
|
+
for (int i = 0; i < nb; ++i) {
|
995
|
+
const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
|
996
|
+
const float dmin = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin); // Corrected sign
|
997
|
+
|
998
|
+
const uint8_t * GGML_RESTRICT q4 = x[i].qs;
|
999
|
+
const int8_t * GGML_RESTRICT q8 = y[i].qs;
+
+        // Process scales and mins
+        memcpy(utmp, x[i].scales, 12);
+        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
+        const uint32_t uaux = utmp[1] & kmask1;
+        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+        utmp[2] = uaux;
+        utmp[0] &= kmask1;
+
+        // Sum mins * q8sums
+        int32_t sumi = 0;
+        const int16_t * GGML_RESTRICT q8sums = y[i].bsums;
+        const uint8_t * m = (const uint8_t *)&utmp[2];
+        for (int j = 0; j < 16; j += 2) {
+            sumi += (q8sums[j] + q8sums[j+1]) * m[j/2];
+        }
+        sumf -= dmin * sumi;
+
+        int32_t sumi1 = 0;
+        int32_t sumi2 = 0;
+
+        for (int j = 0; j < QK_K/64; ++j) {
+            // Load 64 4-bit weights (32 bytes)
+            const v128_t q4x0 = wasm_v128_load(q4);
+            const v128_t q4x1 = wasm_v128_load(q4 + 16);
+            q4 += 32;
+
+            // Split into low/high nibbles
+            const v128_t q4l0 = wasm_v128_and(q4x0, wasm_i8x16_splat(0x0F));
+            const v128_t q4h0 = wasm_u8x16_shr(q4x0, 4);
+            const v128_t q4l1 = wasm_v128_and(q4x1, wasm_i8x16_splat(0x0F));
+            const v128_t q4h1 = wasm_u8x16_shr(q4x1, 4);
+
+            // Load 64 8-bit values (64 bytes)
+            const v128_t q8x0 = wasm_v128_load(q8);
+            const v128_t q8x1 = wasm_v128_load(q8 + 16);
+            const v128_t q8x2 = wasm_v128_load(q8 + 32);
+            const v128_t q8x3 = wasm_v128_load(q8 + 48);
+            q8 += 64;
+
+            // Low nibble products
+            v128_t vacc1 = wasm_i32x4_dot_i16x8(
+                wasm_i16x8_extend_low_i8x16(q4l0),
+                wasm_i16x8_extend_low_i8x16(q8x0)
+            );
+            vacc1 = wasm_i32x4_add(vacc1, wasm_i32x4_dot_i16x8(
+                wasm_i16x8_extend_high_i8x16(q4l0),
+                wasm_i16x8_extend_high_i8x16(q8x0)
+            ));
+            vacc1 = wasm_i32x4_add(vacc1, wasm_i32x4_dot_i16x8(
+                wasm_i16x8_extend_low_i8x16(q4l1),
+                wasm_i16x8_extend_low_i8x16(q8x1)
+            ));
+            vacc1 = wasm_i32x4_add(vacc1, wasm_i32x4_dot_i16x8(
+                wasm_i16x8_extend_high_i8x16(q4l1),
+                wasm_i16x8_extend_high_i8x16(q8x1)
+            ));
+
+            // High nibble products
+            v128_t vacc2 = wasm_i32x4_dot_i16x8(
+                wasm_i16x8_extend_low_i8x16(q4h0),
+                wasm_i16x8_extend_low_i8x16(q8x2)
+            );
+            vacc2 = wasm_i32x4_add(vacc2, wasm_i32x4_dot_i16x8(
+                wasm_i16x8_extend_high_i8x16(q4h0),
+                wasm_i16x8_extend_high_i8x16(q8x2)
+            ));
+            vacc2 = wasm_i32x4_add(vacc2, wasm_i32x4_dot_i16x8(
+                wasm_i16x8_extend_low_i8x16(q4h1),
+                wasm_i16x8_extend_low_i8x16(q8x3)
+            ));
+            vacc2 = wasm_i32x4_add(vacc2, wasm_i32x4_dot_i16x8(
+                wasm_i16x8_extend_high_i8x16(q4h1),
+                wasm_i16x8_extend_high_i8x16(q8x3)
+            ));
+
+            // Accumulate scaled results
+            int32_t vacc1_sum = wasm_i32x4_extract_lane(vacc1, 0) + wasm_i32x4_extract_lane(vacc1, 1) +
+                                wasm_i32x4_extract_lane(vacc1, 2) + wasm_i32x4_extract_lane(vacc1, 3);
+            sumi1 += vacc1_sum * scales[2*j];
+
+            int32_t vacc2_sum = wasm_i32x4_extract_lane(vacc2, 0) + wasm_i32x4_extract_lane(vacc2, 1) +
+                                wasm_i32x4_extract_lane(vacc2, 2) + wasm_i32x4_extract_lane(vacc2, 3);
+            sumi2 += vacc2_sum * scales[2*j+1];
+        }
+
+        sumf += d * (sumi1 + sumi2);
+    }
+
+    *s = sumf;
+
+#else
+
+    const uint8_t * scales = (const uint8_t*)&utmp[0];
+    const uint8_t * mins = (const uint8_t*)&utmp[2];
+
+    int8_t aux8[QK_K];
+    int16_t aux16[8];
+    float sums [8];
+    int32_t aux32[8];
+    memset(sums, 0, 8*sizeof(float));
+
+    float sumf = 0;
+    for (int i = 0; i < nb; ++i) {
+        const uint8_t * GGML_RESTRICT q4 = x[i].qs;
+        const int8_t * GGML_RESTRICT q8 = y[i].qs;
+        memset(aux32, 0, 8*sizeof(int32_t));
+        int8_t * GGML_RESTRICT a = aux8;
+        for (int j = 0; j < QK_K/64; ++j) {
+            for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] & 0xF);
+            a += 32;
+            for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] >> 4);
+            a += 32; q4 += 32;
+        }
+        memcpy(utmp, x[i].scales, 12);
+        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
+        const uint32_t uaux = utmp[1] & kmask1;
+        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+        utmp[2] = uaux;
+        utmp[0] &= kmask1;
+
+        int sumi = 0;
+        for (int j = 0; j < QK_K/16; ++j) sumi += y[i].bsums[j] * mins[j/2];
+        a = aux8;
+        int is = 0;
+        for (int j = 0; j < QK_K/32; ++j) {
+            int32_t scale = scales[is++];
+            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
+            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
+            q8 += 8; a += 8;
+            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
+            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
+            q8 += 8; a += 8;
+            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
+            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
+            q8 += 8; a += 8;
+            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
+            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
+            q8 += 8; a += 8;
+        }
+        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
+        for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
+        const float dmin = GGML_CPU_FP16_TO_FP32(x[i].dmin) * y[i].d;
+        sumf -= dmin * sumi;
+    }
+    for (int l = 0; l < 8; ++l) sumf += sums[l];
+    *s = sumf;
+#endif
+}
+
+void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
+    assert(n % QK_K == 0);
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+
+    const block_q5_K * GGML_RESTRICT x = vx;
+    const block_q8_K * GGML_RESTRICT y = vy;
+
+    const int nb = n / QK_K;
+
+    static const uint32_t kmask1 = 0x3f3f3f3f;
+    static const uint32_t kmask2 = 0x0f0f0f0f;
+    static const uint32_t kmask3 = 0x03030303;
+
+    uint32_t utmp[4];
+
+#if defined __wasm_simd128__
+    //const uint8_t * scales = (const uint8_t*)&utmp[0];
+    float sumf = 0;
+
+    for (int i = 0; i < nb; ++i) {
+        const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
+        const float dmin = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin); // Fixed sign
+
+        const uint8_t * GGML_RESTRICT q5 = x[i].qs;
+        const uint8_t * GGML_RESTRICT qh = x[i].qh;
+        const int8_t * GGML_RESTRICT q8 = y[i].qs;
+
+        // Process scales and mins
+        memcpy(utmp, x[i].scales, 12);
+        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
+        const uint32_t uaux = utmp[1] & kmask1;
+        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+        utmp[2] = uaux;
+        utmp[0] &= kmask1;
+
+        // Sum mins * q8sums
+        int32_t sumi_mins = 0;
+        const int16_t * GGML_RESTRICT q8sums = y[i].bsums;
+        const uint8_t * m = (const uint8_t *)&utmp[2];
+        for (int j = 0; j < 16; j += 2) {
+            sumi_mins += (q8sums[j] + q8sums[j+1]) * m[j/2];
+        }
+        sumf -= dmin * sumi_mins; // Correct subtraction
+
+        v128_t qh0 = wasm_v128_load(qh);
+        v128_t qh1 = wasm_v128_load(qh + 16);
+        const uint8_t * sc = (const uint8_t *)utmp;
+
+        int32_t sumi = 0;
+
+        for (int j = 0; j < QK_K/64; ++j) {
+            const int shift = j * 2;
+            v128_t qh_shift0 = wasm_u8x16_shr(qh0, shift);
+            v128_t qh_shift1 = wasm_u8x16_shr(qh1, shift);
+
+            v128_t qh_low0 = wasm_i8x16_shl(wasm_v128_and(qh_shift0, wasm_i8x16_splat(0x01)), 4);
+            v128_t qh_high0 = wasm_i8x16_shl(wasm_v128_and(qh_shift0, wasm_i8x16_splat(0x02)), 3);
+            v128_t qh_low1 = wasm_i8x16_shl(wasm_v128_and(qh_shift1, wasm_i8x16_splat(0x01)), 4);
+            v128_t qh_high1 = wasm_i8x16_shl(wasm_v128_and(qh_shift1, wasm_i8x16_splat(0x02)), 3);
+
+            v128_t q5_0 = wasm_v128_load(q5);
+            v128_t q5_1 = wasm_v128_load(q5 + 16);
+            q5 += 32;
+
+            v128_t q5l_0 = wasm_v128_or(wasm_v128_and(q5_0, wasm_i8x16_splat(0x0F)), qh_low0);
+            v128_t q5h_0 = wasm_v128_or(wasm_u8x16_shr(q5_0, 4), qh_high0);
+            v128_t q5l_1 = wasm_v128_or(wasm_v128_and(q5_1, wasm_i8x16_splat(0x0F)), qh_low1);
+            v128_t q5h_1 = wasm_v128_or(wasm_u8x16_shr(q5_1, 4), qh_high1);
+
+            v128_t q8_0 = wasm_v128_load(q8);
+            v128_t q8_1 = wasm_v128_load(q8 + 16);
+            v128_t q8_2 = wasm_v128_load(q8 + 32);
+            v128_t q8_3 = wasm_v128_load(q8 + 48);
+            q8 += 64;
+
+            // Process low quants
+            v128_t pl0 = wasm_i32x4_dot_i16x8(
+                wasm_i16x8_extend_low_i8x16(q5l_0),
+                wasm_i16x8_extend_low_i8x16(q8_0)
+            );
+            pl0 = wasm_i32x4_add(pl0, wasm_i32x4_dot_i16x8(
+                wasm_i16x8_extend_high_i8x16(q5l_0),
+                wasm_i16x8_extend_high_i8x16(q8_0)
+            ));
+            v128_t pl1 = wasm_i32x4_dot_i16x8(
+                wasm_i16x8_extend_low_i8x16(q5l_1),
+                wasm_i16x8_extend_low_i8x16(q8_1)
+            );
+            pl1 = wasm_i32x4_add(pl1, wasm_i32x4_dot_i16x8(
+                wasm_i16x8_extend_high_i8x16(q5l_1),
+                wasm_i16x8_extend_high_i8x16(q8_1)
+            ));
+            v128_t sum_low = wasm_i32x4_add(pl0, pl1);
+
+            // Process high quants
+            v128_t ph0 = wasm_i32x4_dot_i16x8(
+                wasm_i16x8_extend_low_i8x16(q5h_0),
+                wasm_i16x8_extend_low_i8x16(q8_2)
+            );
+            ph0 = wasm_i32x4_add(ph0, wasm_i32x4_dot_i16x8(
+                wasm_i16x8_extend_high_i8x16(q5h_0),
+                wasm_i16x8_extend_high_i8x16(q8_2)
+            ));
+            v128_t ph1 = wasm_i32x4_dot_i16x8(
+                wasm_i16x8_extend_low_i8x16(q5h_1),
+                wasm_i16x8_extend_low_i8x16(q8_3)
+            );
+            ph1 = wasm_i32x4_add(ph1, wasm_i32x4_dot_i16x8(
+                wasm_i16x8_extend_high_i8x16(q5h_1),
+                wasm_i16x8_extend_high_i8x16(q8_3)
+            ));
+            v128_t sum_high = wasm_i32x4_add(ph0, ph1);
+
+            // Accumulate with scale factors
+            int32_t sl = wasm_i32x4_extract_lane(sum_low, 0) + wasm_i32x4_extract_lane(sum_low, 1) +
+                         wasm_i32x4_extract_lane(sum_low, 2) + wasm_i32x4_extract_lane(sum_low, 3);
+            int32_t sh = wasm_i32x4_extract_lane(sum_high, 0) + wasm_i32x4_extract_lane(sum_high, 1) +
+                         wasm_i32x4_extract_lane(sum_high, 2) + wasm_i32x4_extract_lane(sum_high, 3);
+
+            sumi += sl * sc[2*j] + sh * sc[2*j+1];
+        }
+
+        sumf += d * sumi;
+    }
+
+    *s = sumf;
+
+#else
+
+    const uint8_t * scales = (const uint8_t*)&utmp[0];
+    const uint8_t * mins = (const uint8_t*)&utmp[2];
+
+    int8_t aux8[QK_K];
+    int16_t aux16[8];
+    float sums [8];
+    int32_t aux32[8];
+    memset(sums, 0, 8*sizeof(float));
+
+    float sumf = 0;
+    for (int i = 0; i < nb; ++i) {
+        const uint8_t * GGML_RESTRICT q4 = x[i].qs;
+        const uint8_t * GGML_RESTRICT hm = x[i].qh;
+        const int8_t * GGML_RESTRICT q8 = y[i].qs;
+        memset(aux32, 0, 8*sizeof(int32_t));
+        int8_t * GGML_RESTRICT a = aux8;
+        uint8_t m = 1;
+        for (int j = 0; j < QK_K/64; ++j) {
+            for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] & 0xF);
+            for (int l = 0; l < 32; ++l) a[l] += (hm[l] & m ? 16 : 0);
+            a += 32; m <<= 1;
+            for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] >> 4);
+            for (int l = 0; l < 32; ++l) a[l] += (hm[l] & m ? 16 : 0);
+            a += 32; m <<= 1;
+            q4 += 32;
+        }
+        memcpy(utmp, x[i].scales, 12);
+        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
+        const uint32_t uaux = utmp[1] & kmask1;
+        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+        utmp[2] = uaux;
+        utmp[0] &= kmask1;
+
+        int sumi = 0;
+        for (int j = 0; j < QK_K/16; ++j) sumi += y[i].bsums[j] * mins[j/2];
+        a = aux8;
+        int is = 0;
+        for (int j = 0; j < QK_K/32; ++j) {
+            int32_t scale = scales[is++];
+            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
+            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
+            q8 += 8; a += 8;
+            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
+            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
+            q8 += 8; a += 8;
+            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
+            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
+            q8 += 8; a += 8;
+            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
+            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
+            q8 += 8; a += 8;
+        }
+        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
+        for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
+        const float dmin = GGML_CPU_FP16_TO_FP32(x[i].dmin) * y[i].d;
+        sumf -= dmin * sumi;
+    }
+    for (int l = 0; l < 8; ++l) sumf += sums[l];
+    *s = sumf;
+#endif
+}
+
+void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
+    assert(n % QK_K == 0);
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+
+    const block_q6_K * GGML_RESTRICT x = vx;
+    const block_q8_K * GGML_RESTRICT y = vy;
+
+    const int nb = n / QK_K;
+
+#if defined __wasm_simd128__
+    int8_t aux8[QK_K] __attribute__((aligned(16)));
+    int32_t aux32[8] __attribute__((aligned(16))) = {0};
+    float sums[8] __attribute__((aligned(16))) = {0};
+
+    for (int i = 0; i < nb; ++i) {
+        // Unpack 6-bit quantized data into aux8 (unchanged)
+        const uint8_t * GGML_RESTRICT q4 = x[i].ql;
+        const uint8_t * GGML_RESTRICT qh = x[i].qh;
+        int8_t * a = aux8;
+        for (int j = 0; j < QK_K; j += 128) {
+            for (int l = 0; l < 32; ++l) {
+                a[l + 0] = (int8_t)((q4[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
+                a[l + 32] = (int8_t)((q4[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
+                a[l + 64] = (int8_t)((q4[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
+                a[l + 96] = (int8_t)((q4[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
+            }
+            a += 128;
+            q4 += 64;
+            qh += 32;
+        }
+
+        const int8_t * GGML_RESTRICT a_ptr = aux8;
+        const int8_t * GGML_RESTRICT q8 = y[i].qs;
+        v128_t acc0 = wasm_i32x4_splat(0);
+        v128_t acc1 = wasm_i32x4_splat(0);
+
+        for (int j = 0; j < QK_K/16; ++j) {
+            const int scale = x[i].scales[j];
+            const v128_t vscale = wasm_i32x4_splat(scale);
+
+            // Load 16 elements from a and q8
+            const v128_t a_vec = wasm_v128_load(a_ptr);
+            const v128_t q8_vec = wasm_v128_load(q8);
+
+            // Process low 8 elements
+            v128_t a_low = wasm_i16x8_extend_low_i8x16(a_vec);
+            v128_t q8_low = wasm_i16x8_extend_low_i8x16(q8_vec);
+            v128_t prod_low = wasm_i16x8_mul(a_low, q8_low);
+            v128_t prod_lo_lo = wasm_i32x4_extend_low_i16x8(prod_low);
+            v128_t prod_lo_hi = wasm_i32x4_extend_high_i16x8(prod_low);
+
+            // Process high 8 elements
+            v128_t a_high = wasm_i16x8_extend_high_i8x16(a_vec);
+            v128_t q8_high = wasm_i16x8_extend_high_i8x16(q8_vec);
+            v128_t prod_high = wasm_i16x8_mul(a_high, q8_high);
+            v128_t prod_hi_lo = wasm_i32x4_extend_low_i16x8(prod_high);
+            v128_t prod_hi_hi = wasm_i32x4_extend_high_i16x8(prod_high);
+
+            // Scale and accumulate
+            prod_lo_lo = wasm_i32x4_mul(prod_lo_lo, vscale);
+            prod_lo_hi = wasm_i32x4_mul(prod_lo_hi, vscale);
+            prod_hi_lo = wasm_i32x4_mul(prod_hi_lo, vscale);
+            prod_hi_hi = wasm_i32x4_mul(prod_hi_hi, vscale);
+
+            acc0 = wasm_i32x4_add(acc0, wasm_i32x4_add(prod_lo_lo, prod_hi_lo));
+            acc1 = wasm_i32x4_add(acc1, wasm_i32x4_add(prod_lo_hi, prod_hi_hi));
+
+            a_ptr += 16;
+            q8 += 16;
+        }
+
+        // Store accumulated results
+        wasm_v128_store(&aux32[0], acc0);
+        wasm_v128_store(&aux32[4], acc1);
+
+        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
+        for (int l = 0; l < 8; ++l) {
+            sums[l] += d * aux32[l];
+        }
+    }
+
+    // Sum final results
+    float sumf = 0;
+    for (int l = 0; l < 8; ++l) {
+        sumf += sums[l];
+    }
+    *s = sumf;
+
+#else
+
+    int8_t aux8[QK_K];
+    int16_t aux16[8];
+    float sums [8];
+    int32_t aux32[8];
+    memset(sums, 0, 8*sizeof(float));
+
+    float sumf = 0;
+    for (int i = 0; i < nb; ++i) {
+        const uint8_t * GGML_RESTRICT q4 = x[i].ql;
+        const uint8_t * GGML_RESTRICT qh = x[i].qh;
+        const int8_t * GGML_RESTRICT q8 = y[i].qs;
+        memset(aux32, 0, 8*sizeof(int32_t));
+        int8_t * GGML_RESTRICT a = aux8;
+        for (int j = 0; j < QK_K; j += 128) {
+            for (int l = 0; l < 32; ++l) {
+                a[l + 0] = (int8_t)((q4[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
+                a[l + 32] = (int8_t)((q4[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
+                a[l + 64] = (int8_t)((q4[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
+                a[l + 96] = (int8_t)((q4[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
+            }
+            a += 128;
+            q4 += 64;
+            qh += 32;
+        }
+        a = aux8;
+        int is = 0;
+        for (int j = 0; j < QK_K/16; ++j) {
+            int scale = x[i].scales[is++];
+            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
+            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
+            q8 += 8; a += 8;
+            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
+            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
+            q8 += 8; a += 8;
+        }
+        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
+        for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
+    }
+    for (int l = 0; l < 8; ++l) sumf += sums[l];
+    *s = sumf;
+#endif
+}
+