@fugood/llama.node 0.3.0 → 0.3.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CMakeLists.txt +1 -10
- package/bin/darwin/arm64/llama-node.node +0 -0
- package/bin/darwin/x64/llama-node.node +0 -0
- package/bin/linux/arm64/llama-node.node +0 -0
- package/bin/linux/x64/llama-node.node +0 -0
- package/bin/linux-vulkan/arm64/llama-node.node +0 -0
- package/bin/linux-vulkan/x64/llama-node.node +0 -0
- package/bin/win32/arm64/llama-node.node +0 -0
- package/bin/win32/arm64/node.lib +0 -0
- package/bin/win32/x64/llama-node.node +0 -0
- package/bin/win32/x64/node.lib +0 -0
- package/bin/win32-vulkan/arm64/llama-node.node +0 -0
- package/bin/win32-vulkan/arm64/node.lib +0 -0
- package/bin/win32-vulkan/x64/llama-node.node +0 -0
- package/bin/win32-vulkan/x64/node.lib +0 -0
- package/package.json +6 -4
- package/src/LlamaCompletionWorker.cpp +6 -6
- package/src/LlamaContext.cpp +7 -9
- package/src/common.hpp +2 -1
- package/src/llama.cpp/.github/workflows/build.yml +98 -24
- package/src/llama.cpp/.github/workflows/close-issue.yml +5 -0
- package/src/llama.cpp/.github/workflows/docker.yml +43 -34
- package/src/llama.cpp/.github/workflows/nix-ci-aarch64.yml +7 -0
- package/src/llama.cpp/.github/workflows/nix-ci.yml +7 -0
- package/src/llama.cpp/.github/workflows/python-check-requirements.yml +2 -4
- package/src/llama.cpp/.github/workflows/python-type-check.yml +3 -1
- package/src/llama.cpp/.github/workflows/server.yml +7 -0
- package/src/llama.cpp/CMakeLists.txt +20 -8
- package/src/llama.cpp/common/CMakeLists.txt +12 -10
- package/src/llama.cpp/common/arg.cpp +2006 -0
- package/src/llama.cpp/common/arg.h +77 -0
- package/src/llama.cpp/common/common.cpp +496 -1632
- package/src/llama.cpp/common/common.h +161 -63
- package/src/llama.cpp/common/console.cpp +3 -0
- package/src/llama.cpp/common/log.cpp +401 -0
- package/src/llama.cpp/common/log.h +66 -698
- package/src/llama.cpp/common/ngram-cache.cpp +3 -0
- package/src/llama.cpp/common/sampling.cpp +348 -350
- package/src/llama.cpp/common/sampling.h +62 -139
- package/src/llama.cpp/common/stb_image.h +5990 -6398
- package/src/llama.cpp/common/train.cpp +2 -0
- package/src/llama.cpp/docs/build.md +36 -1
- package/src/llama.cpp/examples/CMakeLists.txt +0 -1
- package/src/llama.cpp/examples/baby-llama/baby-llama.cpp +1 -2
- package/src/llama.cpp/examples/batched/batched.cpp +39 -55
- package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +34 -44
- package/src/llama.cpp/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +55 -52
- package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +15 -15
- package/src/llama.cpp/examples/cvector-generator/pca.hpp +3 -13
- package/src/llama.cpp/examples/embedding/embedding.cpp +143 -87
- package/src/llama.cpp/examples/eval-callback/eval-callback.cpp +33 -33
- package/src/llama.cpp/examples/export-lora/export-lora.cpp +36 -35
- package/src/llama.cpp/examples/gbnf-validator/gbnf-validator.cpp +14 -39
- package/src/llama.cpp/examples/gen-docs/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/gen-docs/gen-docs.cpp +83 -0
- package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +58 -39
- package/src/llama.cpp/examples/gritlm/gritlm.cpp +34 -27
- package/src/llama.cpp/examples/imatrix/imatrix.cpp +59 -62
- package/src/llama.cpp/examples/infill/infill.cpp +117 -132
- package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +265 -58
- package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +29 -22
- package/src/llama.cpp/examples/llava/CMakeLists.txt +7 -0
- package/src/llama.cpp/examples/llava/clip.cpp +685 -150
- package/src/llama.cpp/examples/llava/clip.h +11 -2
- package/src/llama.cpp/examples/llava/llava-cli.cpp +47 -58
- package/src/llama.cpp/examples/llava/llava.cpp +110 -24
- package/src/llama.cpp/examples/llava/llava.h +2 -3
- package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +323 -0
- package/src/llama.cpp/examples/llava/requirements.txt +1 -0
- package/src/llama.cpp/examples/lookahead/lookahead.cpp +42 -43
- package/src/llama.cpp/examples/lookup/lookup-create.cpp +10 -8
- package/src/llama.cpp/examples/lookup/lookup-stats.cpp +23 -22
- package/src/llama.cpp/examples/lookup/lookup.cpp +40 -43
- package/src/llama.cpp/examples/main/main.cpp +210 -262
- package/src/llama.cpp/examples/parallel/parallel.cpp +49 -49
- package/src/llama.cpp/examples/passkey/passkey.cpp +42 -50
- package/src/llama.cpp/examples/perplexity/perplexity.cpp +187 -200
- package/src/llama.cpp/examples/quantize/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/quantize/quantize.cpp +27 -9
- package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +2 -3
- package/src/llama.cpp/examples/retrieval/retrieval.cpp +49 -44
- package/src/llama.cpp/examples/rpc/rpc-server.cpp +24 -1
- package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +32 -35
- package/src/llama.cpp/examples/server/CMakeLists.txt +3 -5
- package/src/llama.cpp/examples/server/server.cpp +1027 -1073
- package/src/llama.cpp/examples/server/tests/requirements.txt +2 -1
- package/src/llama.cpp/examples/server/utils.hpp +107 -105
- package/src/llama.cpp/examples/simple/simple.cpp +35 -41
- package/src/llama.cpp/examples/speculative/speculative.cpp +129 -103
- package/src/llama.cpp/examples/sycl/run-llama2.sh +10 -19
- package/src/llama.cpp/examples/sycl/win-run-llama2.bat +1 -1
- package/src/llama.cpp/examples/tokenize/tokenize.cpp +25 -27
- package/src/llama.cpp/ggml/CMakeLists.txt +14 -3
- package/src/llama.cpp/ggml/include/ggml-alloc.h +3 -3
- package/src/llama.cpp/ggml/include/ggml-backend.h +145 -60
- package/src/llama.cpp/ggml/include/ggml-blas.h +3 -3
- package/src/llama.cpp/ggml/include/ggml-cann.h +15 -19
- package/src/llama.cpp/ggml/include/ggml-cuda.h +16 -16
- package/src/llama.cpp/ggml/include/ggml-metal.h +5 -8
- package/src/llama.cpp/ggml/include/ggml-rpc.h +5 -5
- package/src/llama.cpp/ggml/include/ggml-sycl.h +8 -8
- package/src/llama.cpp/ggml/include/ggml-vulkan.h +7 -7
- package/src/llama.cpp/ggml/include/ggml.h +293 -186
- package/src/llama.cpp/ggml/src/CMakeLists.txt +86 -44
- package/src/llama.cpp/ggml/src/ggml-aarch64.c +2135 -1119
- package/src/llama.cpp/ggml/src/ggml-alloc.c +6 -0
- package/src/llama.cpp/ggml/src/ggml-backend-impl.h +152 -70
- package/src/llama.cpp/ggml/src/{ggml-backend.c → ggml-backend.cpp} +606 -286
- package/src/llama.cpp/ggml/src/ggml-blas.cpp +9 -10
- package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +4 -27
- package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.h +32 -4
- package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +179 -41
- package/src/llama.cpp/ggml/src/ggml-cann/common.h +1 -0
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +2 -1
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/ascendc_kernels.h +2 -0
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +278 -0
- package/src/llama.cpp/ggml/src/ggml-cann.cpp +215 -216
- package/src/llama.cpp/ggml/src/ggml-common.h +20 -0
- package/src/llama.cpp/ggml/src/ggml-cpu-impl.h +614 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/cuda.h +14 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +178 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +134 -0
- package/src/llama.cpp/ggml/src/ggml-impl.h +49 -603
- package/src/llama.cpp/ggml/src/ggml-kompute.cpp +4 -24
- package/src/llama.cpp/ggml/src/ggml-quants.c +972 -92
- package/src/llama.cpp/ggml/src/ggml-quants.h +15 -0
- package/src/llama.cpp/ggml/src/ggml-rpc.cpp +116 -66
- package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +3 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +11 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +52 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/conv.cpp +99 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/conv.hpp +21 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +57 -57
- package/src/llama.cpp/ggml/src/ggml-sycl/convert.hpp +1 -1
- package/src/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +106 -106
- package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +4 -4
- package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +16 -3
- package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +101 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +125 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/im2col.hpp +23 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +1 -1
- package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +6 -3
- package/src/llama.cpp/ggml/src/ggml-sycl/presets.hpp +2 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +1 -1
- package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +71 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.hpp +21 -0
- package/src/llama.cpp/ggml/src/ggml-sycl.cpp +97 -169
- package/src/llama.cpp/ggml/src/ggml-vulkan.cpp +1508 -1124
- package/src/llama.cpp/ggml/src/ggml.c +3001 -1647
- package/src/llama.cpp/ggml/src/llamafile/sgemm.cpp +192 -0
- package/src/llama.cpp/ggml/src/vulkan-shaders/CMakeLists.txt +2 -0
- package/src/llama.cpp/ggml/src/vulkan-shaders/vulkan-shaders-gen.cpp +88 -40
- package/src/llama.cpp/include/llama.h +241 -264
- package/src/llama.cpp/models/ggml-vocab-chameleon.gguf.inp +112 -0
- package/src/llama.cpp/models/ggml-vocab-chameleon.gguf.out +46 -0
- package/src/llama.cpp/requirements/requirements-convert_legacy_llama.txt +1 -1
- package/src/llama.cpp/src/llama-grammar.cpp +721 -122
- package/src/llama.cpp/src/llama-grammar.h +120 -15
- package/src/llama.cpp/src/llama-impl.h +156 -1
- package/src/llama.cpp/src/llama-sampling.cpp +1375 -303
- package/src/llama.cpp/src/llama-sampling.h +20 -47
- package/src/llama.cpp/src/llama-vocab.cpp +343 -120
- package/src/llama.cpp/src/llama-vocab.h +33 -17
- package/src/llama.cpp/src/llama.cpp +4247 -1525
- package/src/llama.cpp/src/unicode-data.cpp +6 -4
- package/src/llama.cpp/src/unicode-data.h +4 -4
- package/src/llama.cpp/src/unicode.cpp +15 -7
- package/src/llama.cpp/tests/CMakeLists.txt +3 -0
- package/src/llama.cpp/tests/test-arg-parser.cpp +131 -0
- package/src/llama.cpp/tests/test-backend-ops.cpp +1592 -289
- package/src/llama.cpp/tests/test-barrier.cpp +93 -0
- package/src/llama.cpp/tests/test-grad0.cpp +187 -70
- package/src/llama.cpp/tests/test-grammar-integration.cpp +23 -38
- package/src/llama.cpp/tests/test-grammar-parser.cpp +6 -4
- package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +6 -4
- package/src/llama.cpp/tests/test-llama-grammar.cpp +9 -8
- package/src/llama.cpp/tests/test-log.cpp +39 -0
- package/src/llama.cpp/tests/test-quantize-fns.cpp +6 -0
- package/src/llama.cpp/tests/test-rope.cpp +1 -1
- package/src/llama.cpp/tests/test-sampling.cpp +157 -98
- package/src/llama.cpp/tests/test-tokenizer-0.cpp +55 -35
- package/patches/llama.patch +0 -22
- package/src/llama.cpp/.github/workflows/bench.yml +0 -310
- package/src/llama.cpp/common/grammar-parser.cpp +0 -536
- package/src/llama.cpp/common/grammar-parser.h +0 -29
- package/src/llama.cpp/examples/benchmark/CMakeLists.txt +0 -6
- package/src/llama.cpp/examples/benchmark/benchmark-matmult.cpp +0 -275

@@ -3,6 +3,7 @@
 
 #include "ggml-quants.h"
 #include "ggml-impl.h"
+#include "ggml-cpu-impl.h"
 
 
 #include <math.h>
@@ -230,6 +231,12 @@ static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
 
     return _mm_packus_epi16( bytes1, bytes2);
 }
+
+static inline __m128i mul_add_epi8_sse(const __m128i x, const __m128i y) {
+    const __m128i ax = _mm_sign_epi8(x, x);
+    const __m128i sy = _mm_sign_epi8(y, x);
+    return _mm_maddubs_epi16(ax, sy);
+}
 #endif
 #elif defined(__SSSE3__)
 // horizontally add 4x4 floats
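A note on the helper added above: _mm_maddubs_epi16 multiplies unsigned bytes by signed bytes, so mul_add_epi8_sse rewrites the signed product x*y as |x| times (y carrying x's sign) via _mm_sign_epi8. A minimal scalar model of what each 16-bit lane ends up holding (illustration only, not part of the diff; exact here because the 4-bit operands keep the products far from the intrinsic's saturation limit):

#include <stdint.h>

// Scalar reference for mul_add_epi8_sse: lane k of the result is
// x[2k]*y[2k] + x[2k+1]*y[2k+1], accumulated as a 16-bit integer.
static void mul_add_epi8_ref(const int8_t x[16], const int8_t y[16], int16_t out[8]) {
    for (int k = 0; k < 8; ++k) {
        out[k] = (int16_t)(x[2*k] * y[2*k] + x[2*k + 1] * y[2*k + 1]);
    }
}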
@@ -1630,7 +1637,7 @@ void dequantize_row_q8_0(const block_q8_0 * restrict x, float * restrict y, int6
 // ===================== Helper functions
 //
 static inline int nearest_int(float fval) {
-    assert(fval <= 4194303.f);
+    assert(fabsf(fval) <= 4194303.f);
     float val = fval + 12582912.f;
     int i; memcpy(&i, &val, sizeof(int));
     return (i & 0x007fffff) - 0x00400000;
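Why the assert changes: nearest_int rounds by adding 1.5 * 2^23 (12582912.f), which forces the float's exponent to 23 so the low 23 mantissa bits hold 0x400000 + round(fval); masking and subtracting recovers the rounded integer. The trick is only valid while |fval| <= 4194303 (2^22 - 1), and the old assert only bounded the positive side. A small standalone check of the same trick (illustration only, not part of the diff):

#include <assert.h>
#include <math.h>
#include <stdio.h>
#include <string.h>

// Same magic-constant rounding as nearest_int above, with the corrected bound.
static int nearest_int_ref(float fval) {
    assert(fabsf(fval) <= 4194303.f);
    float val = fval + 12582912.f;      // 1.5 * 2^23: result lands in [2^23, 2^24)
    int i; memcpy(&i, &val, sizeof(int));
    return (i & 0x007fffff) - 0x00400000;
}

int main(void) {
    // prints "2 -2 -4": ties round to even because the float addition does
    printf("%d %d %d\n", nearest_int_ref(2.5f), nearest_int_ref(-2.5f), nearest_int_ref(-3.7f));
    return 0;
}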
@@ -3306,6 +3313,191 @@ size_t quantize_q8_0(const float * restrict src, void * restrict dst, int64_t nr
     return nrow * row_size;
 }
 
+// ====================== Ternary (de)-quantization (BitNet b1.58 and TriLMs)
+
+void quantize_row_tq1_0_ref(const float * restrict x, block_tq1_0 * restrict y, int64_t k) {
+    assert(k % QK_K == 0);
+    const int64_t nb = k / QK_K;
+
+    for (int64_t i = 0; i < nb; i++) {
+        float amax = 0.0f; // absolute max
+
+        for (int j = 0; j < QK_K; j++) {
+            const float v = x[j];
+            amax = MAX(amax, fabsf(v));
+        }
+
+        const float d = amax;
+        const float id = d ? 1.0f/d : 0.0f;
+
+        y[i].d = GGML_FP32_TO_FP16(d);
+
+        // 5 elements per byte, along 32 bytes
+        for (size_t j = 0; j < sizeof(y->qs) - sizeof(y->qs) % 32; j += 32) {
+            for (size_t m = 0; m < 32; ++m) {
+                uint8_t q = 0;
+                for (size_t n = 0; n < 5; ++n) {
+                    int xi = lroundf(x[m + n*32] * id) + 1; // -1, 0, 1 -> 0, 1, 2
+                    q *= 3;
+                    q += xi;
+                }
+                // ceiling division (243 == pow(3, 5))
+                q = ((uint16_t)q * 256 + (243 - 1)) / 243;
+                y[i].qs[j + m] = q;
+            }
+            x += 5*32;
+        }
+        // along 16 bytes
+        for (size_t j = sizeof(y->qs) - sizeof(y->qs) % 32; j < sizeof(y->qs); j += 16) {
+            for (size_t m = 0; m < 16; ++m) {
+                uint8_t q = 0;
+                for (size_t n = 0; n < 5; ++n) {
+                    int xi = lroundf(x[m + n*16] * id) + 1; // -1, 0, 1 -> 0, 1, 2
+                    q *= 3;
+                    q += xi;
+                }
+                // ceiling division (243 == pow(3, 5))
+                q = ((uint16_t)q * 256 + (243 - 1)) / 243;
+                y[i].qs[j + m] = q;
+            }
+            x += 5*16;
+        }
+        // 4 elements per byte
+        for (size_t j = 0; j < sizeof(y->qh); ++j) {
+            uint8_t q = 0;
+            for (size_t m = 0; m < 4; ++m) {
+                // -1, 0, 1 -> 0, 1, 2
+                int xi = lroundf(x[j + m*sizeof(y->qh)] * id) + 1;
+                q *= 3;
+                q += xi;
+            }
+            // shift the first value to the most significant trit
+            q *= 3;
+            // ceiling division (243 == pow(3, 5))
+            q = ((uint16_t)q * 256 + (243 - 1)) / 243;
+            y[i].qh[j] = q;
+        }
+        x += 4*sizeof(y->qh);
+    }
+}
+
+void quantize_row_tq2_0_ref(const float * restrict x, block_tq2_0 * restrict y, int64_t k) {
+    assert(k % QK_K == 0);
+    const int64_t nb = k / QK_K;
+
+    for (int64_t i = 0; i < nb; i++) {
+        float amax = 0.0f; // absolute max
+
+        for (int j = 0; j < QK_K; j++) {
+            const float v = x[j];
+            amax = MAX(amax, fabsf(v));
+        }
+
+        const float d = amax;
+        const float id = d ? 1.0f/d : 0.0f;
+
+        y[i].d = GGML_FP32_TO_FP16(d);
+
+        for (size_t j = 0; j < sizeof(y->qs); j += 32) {
+            for (size_t m = 0; m < 32; ++m) {
+                uint8_t q = 0;
+                for (size_t n = 0; n < 4; ++n) {
+                    // -1, 0, 1 -> 0, 1, 2
+                    int xi = lroundf(x[m + n*32] * id) + 1;
+                    q += (xi & 3) << (2*n);
+                }
+                y[i].qs[j + m] = q;
+            }
+            x += 4*32;
+        }
+    }
+}
+
+void quantize_row_tq1_0(const float * restrict x, void * restrict vy, int64_t k) {
+    assert(k % QK_K == 0);
+    block_tq1_0 * restrict y = vy;
+    quantize_row_tq1_0_ref(x, y, k);
+}
+
+void quantize_row_tq2_0(const float * restrict x, void * restrict vy, int64_t k) {
+    assert(k % QK_K == 0);
+    block_tq2_0 * restrict y = vy;
+    quantize_row_tq2_0_ref(x, y, k);
+}
+
+size_t quantize_tq1_0(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
+    (void)quant_weights; // not used
+    const size_t row_size = ggml_row_size(GGML_TYPE_TQ1_0, n_per_row);
+    quantize_row_tq1_0(src, dst, (int64_t)nrow*n_per_row);
+    return nrow * row_size;
+}
+
+size_t quantize_tq2_0(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
+    (void)quant_weights; // not used
+    const size_t row_size = ggml_row_size(GGML_TYPE_TQ2_0, n_per_row);
+    quantize_row_tq2_0(src, dst, (int64_t)nrow*n_per_row);
+    return nrow * row_size;
+}
+
+
+void dequantize_row_tq1_0(const block_tq1_0 * restrict x, float * restrict y, int64_t k) {
+    assert(k % QK_K == 0);
+    const int64_t nb = k / QK_K;
+
+    const uint8_t pow3[6] = {1, 3, 9, 27, 81, 243};
+
+    for (int64_t i = 0; i < nb; ++i) {
+
+        const float d = GGML_FP16_TO_FP32(x[i].d);
+
+        for (size_t j = 0; j < sizeof(x->qs) - sizeof(x->qs) % 32; j += 32) {
+            for (size_t n = 0; n < 5; ++n) {
+                for (size_t m = 0; m < 32; ++m) {
+                    uint8_t q = x[i].qs[j + m] * pow3[n];
+                    int16_t xi = ((uint16_t) q * 3) >> 8;
+                    *y++ = (float) (xi - 1) * d;
+                }
+            }
+        }
+        for (size_t j = sizeof(x->qs) - sizeof(x->qs) % 32; j < sizeof(x->qs); j += 16) {
+            for (size_t n = 0; n < 5; ++n) {
+                for (size_t m = 0; m < 16; ++m) {
+                    uint8_t q = x[i].qs[j + m] * pow3[n];
+                    int16_t xi = ((uint16_t) q * 3) >> 8;
+                    *y++ = (float) (xi - 1) * d;
+                }
+            }
+        }
+
+        for (size_t n = 0; n < 4; ++n) {
+            for (size_t j = 0; j < sizeof(x->qh); ++j) {
+                uint8_t q = x[i].qh[j] * pow3[n];
+                int16_t xi = ((uint16_t) q * 3) >> 8;
+                *y++ = (float) (xi - 1) * d;
+            }
+        }
+    }
+}
+
+void dequantize_row_tq2_0(const block_tq2_0 * restrict x, float * restrict y, int64_t k) {
+    assert(k % QK_K == 0);
+    const int64_t nb = k / QK_K;
+
+    for (int64_t i = 0; i < nb; ++i) {
+
+        const float d = GGML_FP16_TO_FP32(x[i].d);
+
+        for (size_t j = 0; j < sizeof(x->qs); j += 32) {
+            for (size_t l = 0; l < 4; ++l) {
+                for (size_t m = 0; m < 32; ++m) {
+                    int8_t q = (x[i].qs[j + m] >> (l*2)) & 3;
+                    *y++ = (float) (q - 1) * d;
+                }
+            }
+        }
+    }
+}
+
 // ====================== "True" 2-bit (de)-quantization
 
 void dequantize_row_iq2_xxs(const block_iq2_xxs * restrict x, float * restrict y, int64_t k) {
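The TQ1_0 code above packs five ternary digits (-1/0/1 mapped to 0/1/2) into one byte as a base-3 number and then rescales the byte by 256/243 with a ceiling division; dequantization reads digit n back as ((uint8_t)(q * 3^n) * 3) >> 8. A scalar round-trip check of just that byte packing (illustration only, not part of the diff):

#include <assert.h>
#include <stdint.h>

int main(void) {
    static const uint8_t pow3[5] = {1, 3, 9, 27, 81};
    for (int v = 0; v < 243; ++v) {              // every possible 5-trit value
        int digits[5];
        int t = v;
        for (int n = 4; n >= 0; --n) { digits[n] = t % 3; t /= 3; }

        // pack: base-3 value rescaled to [0, 255] with a ceiling division by 243
        uint8_t q = (uint8_t)(((uint16_t)v * 256 + (243 - 1)) / 243);

        // unpack digit n: multiplying by 3^n rotates it to the top of the byte,
        // then *3 >> 8 extracts that top trit
        for (int n = 0; n < 5; ++n) {
            uint8_t m = (uint8_t)(q * pow3[n]);
            int xi = ((uint16_t)m * 3) >> 8;
            assert(xi == digits[n]);
        }
    }
    return 0;
}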
@@ -3644,7 +3836,7 @@ void quantize_row_q8_K(const float * restrict x, void * restrict y, int64_t k) {
     quantize_row_q8_K_ref(x, y, k);
 }
 
-//===================================== Dot
+//===================================== Dot products =================================
 
 //
 // Helper functions
@@ -3818,42 +4010,141 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
     float sumf = 0;
 
 #if defined(__ARM_FEATURE_SVE)
-
-
-    const svbool_t ptruel = svnot_b_z(svptrue_b8(), ptrueh);
-
-    svfloat32_t sumv0 = svdup_n_f32(0.0f);
-    svfloat32_t sumv1 = svdup_n_f32(0.0f);
-
-    for (; ib + 1 < nb; ib += 2) {
-        const block_q4_0 * restrict x0 = &x[ib + 0];
-        const block_q4_0 * restrict x1 = &x[ib + 1];
-        const block_q8_0 * restrict y0 = &y[ib + 0];
-        const block_q8_0 * restrict y1 = &y[ib + 1];
-
-        // load x
-        const svuint8_t qx0r = svld1rq_u8(svptrue_b8(), x0->qs);
-        const svuint8_t qx1r = svld1rq_u8(svptrue_b8(), x1->qs);
+    svfloat32_t sumv0 = svdup_n_f32(0.0f);
+    svfloat32_t sumv1 = svdup_n_f32(0.0f);
 
-
-        const svint8_t qx0 = svreinterpret_s8_u8(svlsr_n_u8_m(ptruel, svand_n_u8_m(ptrueh, qx0r, 0x0F), 0x04));
-        const svint8_t qx1 = svreinterpret_s8_u8(svlsr_n_u8_m(ptruel, svand_n_u8_m(ptrueh, qx1r, 0x0F), 0x04));
+    const int vector_length = ggml_cpu_get_sve_cnt()*8;
 
-
-
-
+    // VLA Implementation using switch case
+    switch (vector_length) {
+        case 128:
+            {
+                // predicate for activating higher lanes for 4 float32 elements
+                const svbool_t ph4 = svptrue_pat_b32(SV_VL4);
+
+                for (; ib + 1 < nb; ib += 2) {
+                    const block_q4_0 * restrict x0 = &x[ib + 0];
+                    const block_q4_0 * restrict x1 = &x[ib + 1];
+                    const block_q8_0 * restrict y0 = &y[ib + 0];
+                    const block_q8_0 * restrict y1 = &y[ib + 1];
+
+                    // load x
+                    const svuint8_t qx0r = svld1rq_u8(svptrue_b8(), x0->qs);
+                    const svuint8_t qx1r = svld1rq_u8(svptrue_b8(), x1->qs);
+
+                    // 4-bit -> 8-bit
+                    const svint8_t qx0l = svreinterpret_s8_u8(svand_n_u8_m(svptrue_b8(), qx0r, 0x0F));
+                    const svint8_t qx0h = svreinterpret_s8_u8(svlsr_n_u8_m(svptrue_b8(), qx0r, 0x04));
+                    const svint8_t qx1l = svreinterpret_s8_u8(svand_n_u8_m(svptrue_b8(), qx1r, 0x0F));
+                    const svint8_t qx1h = svreinterpret_s8_u8(svlsr_n_u8_m(svptrue_b8(), qx1r, 0x04));
+
+                    // sub 8
+                    const svint8_t qx0ls = svsub_n_s8_x(svptrue_b8(), qx0h, 8);
+                    const svint8_t qx0hs = svsub_n_s8_x(svptrue_b8(), qx0l, 8);
+                    const svint8_t qx1ls = svsub_n_s8_x(svptrue_b8(), qx1h, 8);
+                    const svint8_t qx1hs = svsub_n_s8_x(svptrue_b8(), qx1l, 8);
+
+                    // load y
+                    const svint8_t qy0h = svld1_s8(svptrue_b8(), y0->qs);
+                    const svint8_t qy0l = svld1_s8(svptrue_b8(), y0->qs + 16);
+                    const svint8_t qy1h = svld1_s8(svptrue_b8(), y1->qs);
+                    const svint8_t qy1l = svld1_s8(svptrue_b8(), y1->qs + 16);
+
+                    // dot product
+                    sumv0 = svmla_n_f32_x(ph4, sumv0, svcvt_f32_s32_x(ph4, svadd_x(ph4,
+                                    svdot_s32(svdup_n_s32(0), qx0ls, qy0l),
+                                    svdot_s32(svdup_n_s32(0), qx0hs, qy0h))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
+                    sumv1 = svmla_n_f32_x(ph4, sumv1, svcvt_f32_s32_x(ph4, svadd_x(ph4,
+                                    svdot_s32(svdup_n_s32(0), qx1ls, qy1l),
+                                    svdot_s32(svdup_n_s32(0), qx1hs, qy1h))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
+                }
 
-
-
-
+                sumf = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1));
+            } break;
+        case 256:
+            {
+                // predicate for activating higher lanes for 16 int8 elements
+                const svbool_t ph16 = svptrue_pat_b8(SV_VL16);
+                // predicate for activating lower lanes for 16 int8 elements
+                const svbool_t pl16 = svnot_b_z(svptrue_b8(), ph16);
+
+                for (; ib + 1 < nb; ib += 2) {
+                    const block_q4_0 * restrict x0 = &x[ib + 0];
+                    const block_q4_0 * restrict x1 = &x[ib + 1];
+                    const block_q8_0 * restrict y0 = &y[ib + 0];
+                    const block_q8_0 * restrict y1 = &y[ib + 1];
+
+                    // load x
+                    const svuint8_t qx0r = svld1rq_u8(svptrue_b8(), x0->qs);
+                    const svuint8_t qx1r = svld1rq_u8(svptrue_b8(), x1->qs);
+
+                    // 4-bit -> 8-bit
+                    const svint8_t qx0 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_n_u8_m(ph16, qx0r, 0x0F), 0x04));
+                    const svint8_t qx1 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_n_u8_m(ph16, qx1r, 0x0F), 0x04));
+
+                    // sub 8
+                    const svint8_t qx0s = svsub_n_s8_x(svptrue_b8(), qx0, 8);
+                    const svint8_t qx1s = svsub_n_s8_x(svptrue_b8(), qx1, 8);
+
+                    // load y
+                    const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs);
+                    const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs);
+
+                    // dot product
+                    sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(),
+                                    svdot_s32(svdup_n_s32(0), qx0s, qy0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
+                    sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(),
+                                    svdot_s32(svdup_n_s32(0), qx1s, qy1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
+                }
 
-
-
-
-
+                sumf = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1));
+            } break;
+        case 512:
+            {
+                // predicate for activating higher lanes for 32 int8 elements
+                const svbool_t ph32 = svptrue_pat_b8(SV_VL32);
+
+                // predicate for activating higher lanes for 16 int8 elements
+                const svbool_t ph16 = svptrue_pat_b8(SV_VL16);
+                // predicate for activating lower lanes for 16 int8 elements from first 32 int8 activated lanes
+                const svbool_t pl16 = svnot_b_z(ph32, ph16);
+
+                for (; ib + 1 < nb; ib += 2) {
+                    const block_q4_0 * restrict x0 = &x[ib + 0];
+                    const block_q4_0 * restrict x1 = &x[ib + 1];
+                    const block_q8_0 * restrict y0 = &y[ib + 0];
+                    const block_q8_0 * restrict y1 = &y[ib + 1];
+
+                    // load x
+                    const svuint8_t qx0r = svld1rq_u8(ph32, x0->qs);
+                    const svuint8_t qx1r = svld1rq_u8(ph32, x1->qs);
+
+                    // 4-bit -> 8-bit
+                    const svint8_t qx0 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_n_u8_m(ph16, qx0r, 0x0F), 0x04));
+                    const svint8_t qx1 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_n_u8_m(ph16, qx1r, 0x0F), 0x04));
+
+                    // sub 8
+                    const svint8_t qx0s = svsub_n_s8_x(ph32, qx0, 8);
+                    const svint8_t qx1s = svsub_n_s8_x(ph32, qx1, 8);
+
+                    // load y
+                    const svint8_t qy0 = svld1_s8(ph32, y0->qs);
+                    const svint8_t qy1 = svld1_s8(ph32, y1->qs);
+
+                    // dot product
+                    sumv0 = svmla_n_f32_x(ph32, sumv0, svcvt_f32_s32_x(ph32,
+                                    svdot_s32(svdup_n_s32(0), qx0s, qy0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
+                    sumv1 = svmla_n_f32_x(ph32, sumv1, svcvt_f32_s32_x(ph32,
+                                    svdot_s32(svdup_n_s32(0), qx1s, qy1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
+                }
 
-
+                sumf = svaddv_f32(ph32, svadd_f32_x(ph32, sumv0, sumv1));
+            } break;
+        default:
+            assert(false && "Unsupported vector length");
+            break;
     }
+
 #elif defined(__ARM_NEON)
     float32x4_t sumv0 = vdupq_n_f32(0.0f);
     float32x4_t sumv1 = vdupq_n_f32(0.0f);
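The rewritten SVE path above is vector-length agnostic only in the sense that it queries the hardware width at run time (ggml_cpu_get_sve_cnt()*8 bits in the diff) and then picks a kernel specialised for 128-, 256- or 512-bit registers. A hedged sketch of that dispatch shape using the standard ACLE query svcntb() (bytes per SVE register); whether ggml_cpu_get_sve_cnt() wraps exactly this intrinsic is an assumption here, not something the diff states:

#include <arm_sve.h>
#include <assert.h>

// Illustration only: SVE register width in bits, a multiple of 128 at run time.
static int sve_register_bits(void) {
    return (int) svcntb() * 8;
}

static void dispatch_by_vl(void) {
    switch (sve_register_bits()) {
        case 128: /* kernel tuned for 16-byte vectors */ break;
        case 256: /* kernel tuned for 32-byte vectors */ break;
        case 512: /* kernel tuned for 64-byte vectors */ break;
        default:  assert(0 && "Unsupported vector length"); break;
    }
}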
@@ -3922,37 +4213,37 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
 
     sumf = hsum_float_8(acc);
 #elif defined(__AVX__)
-
-    __m256 acc = _mm256_setzero_ps();
-
-    // Main loop
-    for (; ib < nb; ++ib) {
-        // Compute combined scale for the block
-        const __m256 d = _mm256_set1_ps( GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d) );
-
-        const __m128i lowMask = _mm_set1_epi8(0xF);
-        const __m128i off = _mm_set1_epi8(8);
-
-        const __m128i tmp = _mm_loadu_si128((const __m128i *)x[ib].qs);
-
-        __m128i bx_0 = _mm_and_si128(lowMask, tmp);
-        __m128i by_0 = _mm_loadu_si128((const __m128i *)y[ib].qs);
-        bx_0 = _mm_sub_epi8(bx_0, off);
-        const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0);
-
-        bx_0 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp, 4));
-        by_0 = _mm_loadu_si128((const __m128i *)(y[ib].qs + 16));
-        bx_0 = _mm_sub_epi8(bx_0, off);
-        const __m128i i32_1 = mul_sum_i8_pairs(bx_0, by_0);
+    const __m128i mone = _mm_set1_epi16(1);
 
-
-
+    __m256 accum1 = _mm256_setzero_ps();
+    __m256 accum2 = _mm256_setzero_ps();
+    for (; ib + 1 < nb; ib += 2) {
+        const __m128i q4bits_1 = _mm_loadu_si128((const __m128i *)x[ib + 0].qs);
+        const __m128i q4bits_2 = _mm_loadu_si128((const __m128i *)x[ib + 1].qs);
+        const __m128i q8b_1_0 = _mm_loadu_si128((const __m128i *)y[ib + 0].qs);
+        const __m128i q8b_1_1 = _mm_loadu_si128((const __m128i *)y[ib + 0].qs + 1);
+        const __m128i q8b_2_0 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs);
+        const __m128i q8b_2_1 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs + 1);
 
-
-
+        const __m128i q4b_1_0 = _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), q4bits_1), _mm_set1_epi8(8));
+        const __m128i q4b_1_1 = _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(q4bits_1, 4)), _mm_set1_epi8(8));
+        const __m128i q4b_2_0 = _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), q4bits_2), _mm_set1_epi8(8));
+        const __m128i q4b_2_1 = _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(q4bits_2, 4)), _mm_set1_epi8(8));
+        const __m128i p16_1_0 = mul_add_epi8_sse(q4b_1_0, q8b_1_0);
+        const __m128i p16_1_1 = mul_add_epi8_sse(q4b_1_1, q8b_1_1);
+        const __m128i p16_2_0 = mul_add_epi8_sse(q4b_2_0, q8b_2_0);
+        const __m128i p16_2_1 = mul_add_epi8_sse(q4b_2_1, q8b_2_1);
+        const __m128i p_1_0 = _mm_madd_epi16(p16_1_0, mone);
+        const __m128i p_1_1 = _mm_madd_epi16(p16_1_1, mone);
+        const __m128i p_2_0 = _mm_madd_epi16(p16_2_0, mone);
+        const __m128i p_2_1 = _mm_madd_epi16(p16_2_1, mone);
+        accum1 = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(y[ib + 0].d)*GGML_FP16_TO_FP32(x[ib + 0].d)),
+                _mm256_cvtepi32_ps(MM256_SET_M128I(p_1_1, p_1_0))), accum1);
+        accum2 = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(y[ib + 1].d)*GGML_FP16_TO_FP32(x[ib + 1].d)),
+                _mm256_cvtepi32_ps(MM256_SET_M128I(p_2_1, p_2_0))), accum2);
     }
 
-    sumf = hsum_float_8(
+    sumf = hsum_float_8(_mm256_add_ps(accum1, accum2));
 #elif defined(__SSSE3__)
     // set constants
     const __m128i lowMask = _mm_set1_epi8(0xF);
@@ -5303,29 +5594,124 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r
     float sumf = 0;
 
 #if defined(__ARM_FEATURE_SVE)
-
-
-    svfloat32_t sumv1 = svdup_n_f32(0.0f);
+    svfloat32_t sumv0 = svdup_n_f32(0.0f);
+    svfloat32_t sumv1 = svdup_n_f32(0.0f);
 
-
-        const block_q8_0 * restrict x0 = &x[ib + 0];
-        const block_q8_0 * restrict x1 = &x[ib + 1];
-        const block_q8_0 * restrict y0 = &y[ib + 0];
-        const block_q8_0 * restrict y1 = &y[ib + 1];
+    const int vector_length = ggml_cpu_get_sve_cnt()*8;
 
-
-
-
+    //VLA Implemenation for SVE
+    switch (vector_length) {
+        case 128:
+            {
+                // predicate for activating lanes for 16 Int8 elements
+                const svbool_t ph16 = svptrue_pat_b8 (SV_VL16);
+                const svbool_t pl16 = svptrue_pat_b32(SV_VL4);
+
+                for (; ib + 1 < nb; ib += 2) {
+                    const block_q8_0 * restrict x0 = &x[ib + 0];
+                    const block_q8_0 * restrict x1 = &x[ib + 1];
+                    const block_q8_0 * restrict y0 = &y[ib + 0];
+                    const block_q8_0 * restrict y1 = &y[ib + 1];
+
+                    // load x
+                    const svint8_t qx0_0 = svld1_s8(ph16, x0->qs);
+                    const svint8_t qx0_1 = svld1_s8(ph16, x0->qs+16);
+                    const svint8_t qx1_0 = svld1_s8(ph16, x1->qs);
+                    const svint8_t qx1_1 = svld1_s8(ph16, x1->qs+16);
+
+                    // load y
+                    const svint8_t qy0_0 = svld1_s8(ph16, y0->qs);
+                    const svint8_t qy0_1 = svld1_s8(ph16, y0->qs+16);
+                    const svint8_t qy1_0 = svld1_s8(ph16, y1->qs);
+                    const svint8_t qy1_1 = svld1_s8(ph16, y1->qs+16);
+
+                    sumv0 = svmla_n_f32_x(pl16, sumv0, svcvt_f32_s32_x(pl16, svadd_x(pl16,
+                                    svdot_s32(svdup_n_s32(0), qx0_0, qy0_0),
+                                    svdot_s32(svdup_n_s32(0), qx0_1, qy0_1))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
+                    sumv1 = svmla_n_f32_x(pl16, sumv1, svcvt_f32_s32_x(pl16, svadd_x(pl16,
+                                    svdot_s32(svdup_n_s32(0), qx1_0, qy1_0),
+                                    svdot_s32(svdup_n_s32(0), qx1_1, qy1_1))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
+                }
 
-
-
-
+                sumf = svaddv_f32(pl16, svadd_f32_x(pl16, sumv0, sumv1));
+            } break;
+        case 256:
+            {
+                //printf("sve256");
+                for (; ib + 1 < nb; ib += 2) {
+                    const block_q8_0 * restrict x0 = &x[ib + 0];
+                    const block_q8_0 * restrict x1 = &x[ib + 1];
+                    const block_q8_0 * restrict y0 = &y[ib + 0];
+                    const block_q8_0 * restrict y1 = &y[ib + 1];
+
+                    // load x
+                    const svint8_t qx0 = svld1_s8(svptrue_b8(), x0->qs);
+                    const svint8_t qx1 = svld1_s8(svptrue_b8(), x1->qs);
+
+                    // load y
+                    const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs);
+                    const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs);
+
+                    sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(),
+                                    svdot_s32(svdup_n_s32(0), qx0, qy0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
+                    sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(),
+                                    svdot_s32(svdup_n_s32(0), qx1, qy1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
+                }
 
-
-
-
+                sumf = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1));
+            } break;
+        case 512:
+            {
+                // predicate for activating high 256 bit
+                const svbool_t ph32 = svptrue_pat_b8(SV_VL32);
+                // predicate for activating low 256 bit
+                const svbool_t pl32 = svnot_b_z(svptrue_b8(), ph32);
 
-
+                // predicate for activating high lanes for 8 float32 elements
+                const svbool_t ph8 = svptrue_pat_b32(SV_VL8);
+                // predicate for activating low lanes for 8 float32 elements
+                const svbool_t pl8 = svnot_b_z(svptrue_b32(), ph8);
+
+                svfloat32_t sumv00 = svdup_n_f32(0.0f);
+
+                for (; ib + 1 < nb; ib += 2) {
+                    const block_q8_0 * restrict x0 = &x[ib + 0];
+                    const block_q8_0 * restrict x1 = &x[ib + 1];
+                    const block_q8_0 * restrict y0 = &y[ib + 0];
+                    const block_q8_0 * restrict y1 = &y[ib + 1];
+
+                    //load 32 int8_t in first half of vector and put another 32 int8_t in second vector lower bits
+                    // and add them to make one 64 element vector
+                    // load x
+                    const svint8_t qx_32 = svld1_s8(ph32, x0->qs);
+                    svint8_t qx_64 = svld1_s8(pl32, x0->qs + 2);
+
+                    qx_64 = svadd_s8_x(svptrue_b8(), qx_32, qx_64);
+
+                    // load y
+                    const svint8_t qy_32 = svld1_s8(ph32, y0->qs);
+                    svint8_t qy_64 = svld1_s8(pl32, y0->qs + 2);
+
+                    qy_64 = svadd_s8_x(svptrue_b8(), qy_32, qy_64);
+
+                    // scale creation
+                    const float32_t deq1 = GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d);
+                    const float32_t deq2 = GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d);
+
+                    // duplicate deq1 in first half of vector and deq2 in second half of vector
+                    const svfloat32_t temp = svdup_f32_m(svdup_f32_z(ph8, deq1), pl8, deq2);
+
+                    const svfloat32_t sumvt = svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx_64, qy_64));
+
+                    sumv00 = svmla_f32_m(svptrue_b32(), sumv00, sumvt, temp);
+                }
+
+                sumf = svaddv_f32(svptrue_b32(), sumv00);
+                break;
+            }
+        default:
+            assert(false && "Unsupported vector length");
+            break;
     }
 #elif defined(__ARM_NEON)
     float32x4_t sumv0 = vdupq_n_f32(0.0f);
@@ -5470,6 +5856,501 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r
|
|
|
5470
5856
|
*s = sumf;
|
|
5471
5857
|
}
|
|
5472
5858
|
|
|
5859
|
+
void ggml_vec_dot_tq1_0_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
|
|
5860
|
+
assert(nrc == 1);
|
|
5861
|
+
UNUSED(nrc);
|
|
5862
|
+
UNUSED(bx);
|
|
5863
|
+
UNUSED(by);
|
|
5864
|
+
UNUSED(bs);
|
|
5865
|
+
|
|
5866
|
+
const block_tq1_0 * restrict x = vx;
|
|
5867
|
+
const block_q8_K * restrict y = vy;
|
|
5868
|
+
|
|
5869
|
+
const int nb = n / QK_K;
|
|
5870
|
+
|
|
5871
|
+
#if defined(__ARM_NEON)
|
|
5872
|
+
float sumf = 0.0f;
|
|
5873
|
+
|
|
5874
|
+
uint8_t k_shift[16] = {1, 1, 1, 1, 3, 3, 3, 3, 9, 9, 9, 9, 27, 27, 27, 27};
|
|
5875
|
+
|
|
5876
|
+
const uint8x16_t shift = vld1q_u8(k_shift);
|
|
5877
|
+
|
|
5878
|
+
for (int i = 0; i < nb; ++i) {
|
|
5879
|
+
#if defined(__ARM_FEATURE_DOTPROD)
|
|
5880
|
+
int32x4_t sumi0 = vdupq_n_s32(0);
|
|
5881
|
+
int32x4_t sumi1 = vdupq_n_s32(0);
|
|
5882
|
+
#else
|
|
5883
|
+
int16x8_t sumi0 = vdupq_n_s16(0);
|
|
5884
|
+
int16x8_t sumi1 = vdupq_n_s16(0);
|
|
5885
|
+
#endif
|
|
5886
|
+
|
|
5887
|
+
// first 32 bytes of 5 elements
|
|
5888
|
+
{
|
|
5889
|
+
uint8x16_t qx0 = vld1q_u8(x[i].qs + 0);
|
|
5890
|
+
uint8x16_t qx1 = vld1q_u8(x[i].qs + 16);
|
|
5891
|
+
uint8x16_t qx2 = vmulq_u8(qx0, vdupq_n_u8(3));
|
|
5892
|
+
uint8x16_t qx3 = vmulq_u8(qx1, vdupq_n_u8(3));
|
|
5893
|
+
uint8x16_t qx4 = vmulq_u8(qx0, vdupq_n_u8(9));
|
|
5894
|
+
uint8x16_t qx5 = vmulq_u8(qx1, vdupq_n_u8(9));
|
|
5895
|
+
uint8x16_t qx6 = vmulq_u8(qx0, vdupq_n_u8(27));
|
|
5896
|
+
uint8x16_t qx7 = vmulq_u8(qx1, vdupq_n_u8(27));
|
|
5897
|
+
uint8x16_t qx8 = vmulq_u8(qx0, vdupq_n_u8(81));
|
|
5898
|
+
uint8x16_t qx9 = vmulq_u8(qx1, vdupq_n_u8(81));
|
|
5899
|
+
|
|
5900
|
+
// multiply by 3 and keep the 2 bits above 8 bits
|
|
5901
|
+
int8x16_t sqx0 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx0, vshrq_n_u8(qx0, 1)), 6));
|
|
5902
|
+
int8x16_t sqx1 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx1, vshrq_n_u8(qx1, 1)), 6));
|
|
5903
|
+
int8x16_t sqx2 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx2, vshrq_n_u8(qx2, 1)), 6));
|
|
5904
|
+
int8x16_t sqx3 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx3, vshrq_n_u8(qx3, 1)), 6));
|
|
5905
|
+
int8x16_t sqx4 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx4, vshrq_n_u8(qx4, 1)), 6));
|
|
5906
|
+
int8x16_t sqx5 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx5, vshrq_n_u8(qx5, 1)), 6));
|
|
5907
|
+
int8x16_t sqx6 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx6, vshrq_n_u8(qx6, 1)), 6));
|
|
5908
|
+
int8x16_t sqx7 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx7, vshrq_n_u8(qx7, 1)), 6));
|
|
5909
|
+
int8x16_t sqx8 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx8, vshrq_n_u8(qx8, 1)), 6));
|
|
5910
|
+
int8x16_t sqx9 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx9, vshrq_n_u8(qx9, 1)), 6));
|
|
5911
|
+
|
|
5912
|
+
const int8x16_t qy0 = vld1q_s8(y[i].qs + 0);
|
|
5913
|
+
const int8x16_t qy1 = vld1q_s8(y[i].qs + 16);
|
|
5914
|
+
const int8x16_t qy2 = vld1q_s8(y[i].qs + 32);
|
|
5915
|
+
const int8x16_t qy3 = vld1q_s8(y[i].qs + 48);
|
|
5916
|
+
const int8x16_t qy4 = vld1q_s8(y[i].qs + 64);
|
|
5917
|
+
const int8x16_t qy5 = vld1q_s8(y[i].qs + 80);
|
|
5918
|
+
const int8x16_t qy6 = vld1q_s8(y[i].qs + 96);
|
|
5919
|
+
const int8x16_t qy7 = vld1q_s8(y[i].qs + 112);
|
|
5920
|
+
const int8x16_t qy8 = vld1q_s8(y[i].qs + 128);
|
|
5921
|
+
const int8x16_t qy9 = vld1q_s8(y[i].qs + 144);
|
|
5922
|
+
|
|
5923
|
+
#if defined(__ARM_FEATURE_DOTPROD)
|
|
5924
|
+
sumi0 = vdotq_s32(sumi0, sqx0, qy0);
|
|
5925
|
+
sumi1 = vdotq_s32(sumi1, sqx1, qy1);
|
|
5926
|
+
sumi0 = vdotq_s32(sumi0, sqx2, qy2);
|
|
5927
|
+
sumi1 = vdotq_s32(sumi1, sqx3, qy3);
|
|
5928
|
+
sumi0 = vdotq_s32(sumi0, sqx4, qy4);
|
|
5929
|
+
sumi1 = vdotq_s32(sumi1, sqx5, qy5);
|
|
5930
|
+
sumi0 = vdotq_s32(sumi0, sqx6, qy6);
|
|
5931
|
+
sumi1 = vdotq_s32(sumi1, sqx7, qy7);
|
|
5932
|
+
sumi0 = vdotq_s32(sumi0, sqx8, qy8);
|
|
5933
|
+
sumi1 = vdotq_s32(sumi1, sqx9, qy9);
|
|
5934
|
+
#else
|
|
5935
|
+
sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx0), vget_low_s8(qy0));
|
|
5936
|
+
sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx0), vget_high_s8(qy0));
|
|
5937
|
+
sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx1), vget_low_s8(qy1));
|
|
5938
|
+
sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx1), vget_high_s8(qy1));
|
|
5939
|
+
sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx2), vget_low_s8(qy2));
|
|
5940
|
+
sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx2), vget_high_s8(qy2));
|
|
5941
|
+
sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx3), vget_low_s8(qy3));
|
|
5942
|
+
sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx3), vget_high_s8(qy3));
|
|
5943
|
+
sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx4), vget_low_s8(qy4));
|
|
5944
|
+
sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx4), vget_high_s8(qy4));
|
|
5945
|
+
sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx5), vget_low_s8(qy5));
|
|
5946
|
+
sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx5), vget_high_s8(qy5));
|
|
5947
|
+
sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx6), vget_low_s8(qy6));
|
|
5948
|
+
sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx6), vget_high_s8(qy6));
|
|
5949
|
+
sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx7), vget_low_s8(qy7));
|
|
5950
|
+
sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx7), vget_high_s8(qy7));
|
|
5951
|
+
sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx8), vget_low_s8(qy8));
|
|
5952
|
+
sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx8), vget_high_s8(qy8));
|
|
5953
|
+
sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx9), vget_low_s8(qy9));
|
|
5954
|
+
sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx9), vget_high_s8(qy9));
|
|
5955
|
+
#endif
|
|
5956
|
+
}
|
|
5957
|
+
|
|
5958
|
+
// last 16 bytes of 5-element, along with the 4 bytes of 4 elements
|
|
5959
|
+
{
|
|
5960
|
+
uint8x16_t qx0 = vld1q_u8(x[i].qs + 32);
|
|
5961
|
+
uint8x16_t qx1 = vmulq_u8(qx0, vdupq_n_u8(3));
|
|
5962
|
+
uint8x16_t qx2 = vmulq_u8(qx0, vdupq_n_u8(9));
|
|
5963
|
+
uint8x16_t qx3 = vmulq_u8(qx0, vdupq_n_u8(27));
|
|
5964
|
+
uint8x16_t qx4 = vmulq_u8(qx0, vdupq_n_u8(81));
|
|
5965
|
+
uint32_t qh;
|
|
5966
|
+
memcpy(&qh, x[i].qh, sizeof(qh)); // potentially unaligned
|
|
5967
|
+
uint8x16_t qx5 = vreinterpretq_u8_u32(vdupq_n_u32(qh));
|
|
5968
|
+
qx5 = vmulq_u8(qx5, shift);
|
|
5969
|
+
|
|
5970
|
+
// multiply by 3 and keep the 2 bits above 8 bits
|
|
5971
|
+
int8x16_t sqx0 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx0, vshrq_n_u8(qx0, 1)), 6));
|
|
5972
|
+
int8x16_t sqx1 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx1, vshrq_n_u8(qx1, 1)), 6));
|
|
5973
|
+
int8x16_t sqx2 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx2, vshrq_n_u8(qx2, 1)), 6));
|
|
5974
|
+
int8x16_t sqx3 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx3, vshrq_n_u8(qx3, 1)), 6));
|
|
5975
|
+
int8x16_t sqx4 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx4, vshrq_n_u8(qx4, 1)), 6));
|
|
5976
|
+
int8x16_t sqx5 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx5, vshrq_n_u8(qx5, 1)), 6));
|
|
5977
|
+
|
|
5978
|
+
const int8x16_t qy0 = vld1q_s8(y[i].qs + 160);
|
|
5979
|
+
const int8x16_t qy1 = vld1q_s8(y[i].qs + 176);
|
|
5980
|
+
const int8x16_t qy2 = vld1q_s8(y[i].qs + 192);
|
|
5981
|
+
const int8x16_t qy3 = vld1q_s8(y[i].qs + 208);
|
|
5982
|
+
const int8x16_t qy4 = vld1q_s8(y[i].qs + 224);
|
|
5983
|
+
const int8x16_t qy5 = vld1q_s8(y[i].qs + 240);
|
|
5984
|
+
|
|
5985
|
+
#if defined(__ARM_FEATURE_DOTPROD)
|
|
5986
|
+
sumi0 = vdotq_s32(sumi0, sqx0, qy0);
|
|
5987
|
+
sumi1 = vdotq_s32(sumi1, sqx1, qy1);
|
|
5988
|
+
sumi0 = vdotq_s32(sumi0, sqx2, qy2);
|
|
5989
|
+
sumi1 = vdotq_s32(sumi1, sqx3, qy3);
|
|
5990
|
+
sumi0 = vdotq_s32(sumi0, sqx4, qy4);
|
|
5991
|
+
sumi1 = vdotq_s32(sumi1, sqx5, qy5);
|
|
5992
|
+
#else
|
|
5993
|
+
sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx0), vget_low_s8(qy0));
|
|
5994
|
+
sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx0), vget_high_s8(qy0));
|
|
5995
|
+
sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx1), vget_low_s8(qy1));
|
|
5996
|
+
sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx1), vget_high_s8(qy1));
|
|
5997
|
+
sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx2), vget_low_s8(qy2));
|
|
5998
|
+
sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx2), vget_high_s8(qy2));
|
|
5999
|
+
sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx3), vget_low_s8(qy3));
|
|
6000
|
+
sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx3), vget_high_s8(qy3));
|
|
6001
|
+
sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx4), vget_low_s8(qy4));
|
|
6002
|
+
sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx4), vget_high_s8(qy4));
|
|
6003
|
+
sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx5), vget_low_s8(qy5));
|
|
6004
|
+
sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx5), vget_high_s8(qy5));
|
|
6005
|
+
#endif
|
|
6006
|
+
}
|
|
6007
|
+
|
|
6008
|
+
const int16x8_t ysum0 = vld1q_s16(y[i].bsums);
|
|
6009
|
+
const int16x8_t ysum1 = vld1q_s16(y[i].bsums + 8);
|
|
6010
|
+
|
|
6011
|
+
const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
|
|
6012
|
+
|
|
6013
|
+
#if defined(__ARM_FEATURE_DOTPROD)
|
|
6014
|
+
sumi0 = vaddq_s32(sumi0, sumi1);
|
|
6015
|
+
sumi0 = vsubq_s32(sumi0, vpaddlq_s16(vaddq_s16(ysum0, ysum1)));
|
|
6016
|
+
|
|
6017
|
+
sumf += d * (float) vaddvq_s32(sumi0);
|
|
6018
|
+
#else
|
|
6019
|
+
sumi0 = vaddq_s16(sumi0, sumi1);
|
|
6020
|
+
sumi0 = vsubq_s16(sumi0, vaddq_s16(ysum0, ysum1));
|
|
6021
|
+
|
|
6022
|
+
sumf += d * (float) vaddlvq_s16(sumi0);
|
|
6023
|
+
#endif
|
|
6024
|
+
}
|
|
6025
|
+
|
|
6026
|
+
*s = sumf;
|
|
6027
|
+
|
|
6028
|
+
#elif defined(__AVX2__)
|
|
6029
|
+
__m256 sumf = _mm256_setzero_ps();
|
|
6030
|
+
|
|
6031
|
+
for (int i = 0; i < nb; ++i) {
|
|
6032
|
+
// 16-bit sums
|
|
6033
|
+
__m256i sumi0 = _mm256_setzero_si256();
|
|
6034
|
+
__m256i sumi1 = _mm256_setzero_si256();
|
|
6035
|
+
__m256i sumi2 = _mm256_setzero_si256();
|
|
6036
|
+
|
|
6037
|
+
// first 32 bytes of 5 elements
|
|
6038
|
+
{
|
|
6039
|
+
__m256i qx0 = _mm256_loadu_si256((const __m256i *) (x[i].qs));
|
|
6040
|
+
// 8-bit multiplies with shifts, masks and adds
|
|
6041
|
+
__m256i qx1 = _mm256_add_epi8(qx0, _mm256_add_epi8(qx0, qx0)); // 1 * 3
|
|
6042
|
+
__m256i qx2 = _mm256_add_epi8(_mm256_and_si256(_mm256_slli_epi16(qx0, 3), _mm256_set1_epi8(-8)), qx0); // 1 * 9
|
|
6043
|
+
__m256i qx3 = _mm256_add_epi8(_mm256_and_si256(_mm256_slli_epi16(qx1, 3), _mm256_set1_epi8(-8)), qx1); // 3 * 9
|
|
6044
|
+
__m256i qx4 = _mm256_add_epi8(_mm256_and_si256(_mm256_slli_epi16(qx2, 3), _mm256_set1_epi8(-8)), qx2); // 9 * 9
|
|
6045
|
+
|
|
6046
|
+
// TODO: can _mm256_mulhi_epu16 be faster even if 16-bits?
|
|
6047
|
+
|
|
6048
|
+
// Cancel the +1 from avg so that it behaves like a halving add
|
|
6049
|
+
qx0 = _mm256_subs_epu8(qx0, _mm256_set1_epi8(1));
|
|
6050
|
+
qx1 = _mm256_subs_epu8(qx1, _mm256_set1_epi8(1));
|
|
6051
|
+
qx2 = _mm256_subs_epu8(qx2, _mm256_set1_epi8(1));
|
|
6052
|
+
qx3 = _mm256_subs_epu8(qx3, _mm256_set1_epi8(1));
|
|
6053
|
+
qx4 = _mm256_subs_epu8(qx4, _mm256_set1_epi8(1));
|
|
6054
|
+
// Multiply by 3 and get the top 2 bits
|
|
6055
|
+
qx0 = _mm256_avg_epu8(qx0, _mm256_avg_epu8(qx0, _mm256_setzero_si256()));
|
|
6056
|
+
qx1 = _mm256_avg_epu8(qx1, _mm256_avg_epu8(qx1, _mm256_setzero_si256()));
|
|
6057
|
+
qx2 = _mm256_avg_epu8(qx2, _mm256_avg_epu8(qx2, _mm256_setzero_si256()));
|
|
6058
|
+
qx3 = _mm256_avg_epu8(qx3, _mm256_avg_epu8(qx3, _mm256_setzero_si256()));
|
|
6059
|
+
qx4 = _mm256_avg_epu8(qx4, _mm256_avg_epu8(qx4, _mm256_setzero_si256()));
|
|
6060
|
+
qx0 = _mm256_and_si256(_mm256_srli_epi16(qx0, 6), _mm256_set1_epi8(3));
|
|
6061
|
+
qx1 = _mm256_and_si256(_mm256_srli_epi16(qx1, 6), _mm256_set1_epi8(3));
|
|
6062
|
+
qx2 = _mm256_and_si256(_mm256_srli_epi16(qx2, 6), _mm256_set1_epi8(3));
|
|
6063
|
+
qx3 = _mm256_and_si256(_mm256_srli_epi16(qx3, 6), _mm256_set1_epi8(3));
|
|
6064
|
+
qx4 = _mm256_and_si256(_mm256_srli_epi16(qx4, 6), _mm256_set1_epi8(3));
|
|
6065
|
+
|
|
6066
|
+
const __m256i qy0 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 0));
|
|
6067
|
+
const __m256i qy1 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 32));
|
|
6068
|
+
const __m256i qy2 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 64));
|
|
6069
|
+
const __m256i qy3 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 96));
|
|
6070
|
+
const __m256i qy4 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 128));
|
|
6071
|
+
|
|
6072
|
+
qx0 = _mm256_maddubs_epi16(qx0, qy0);
|
|
6073
|
+
qx1 = _mm256_maddubs_epi16(qx1, qy1);
|
|
6074
|
+
qx2 = _mm256_maddubs_epi16(qx2, qy2);
|
|
6075
|
+
qx3 = _mm256_maddubs_epi16(qx3, qy3);
|
|
6076
|
+
qx4 = _mm256_maddubs_epi16(qx4, qy4);
|
|
6077
|
+
|
|
6078
|
+
sumi0 = _mm256_add_epi16(sumi0, _mm256_add_epi16(qx0, qx1));
|
|
6079
|
+
sumi1 = _mm256_add_epi16(sumi1, _mm256_add_epi16(qx2, qx3));
|
|
6080
|
+
sumi2 = _mm256_add_epi16(sumi2, qx4);
|
|
6081
|
+
}
|
|
6082
|
+
|
|
6083
|
+
// last 16 bytes of 5-element, along with the 4 bytes of 4 elements
|
|
6084
|
+
{
|
|
6085
|
+
__m128i qx0 = _mm_loadu_si128((const __m128i *) (x[i].qs + 32));
|
|
6086
|
+
uint32_t qh;
|
|
6087
|
+
memcpy(&qh, x[i].qh, sizeof(qh)); // potentially unaligned
|
|
6088
|
+
__m256i qx5_l = _mm256_cvtepu8_epi16(_mm_set1_epi32(qh));
|
|
6089
|
+
__m128i qx1 = _mm_add_epi8(qx0, _mm_add_epi8(qx0, qx0)); // 1 * 3
|
|
6090
|
+
__m128i qx2 = _mm_add_epi8(_mm_and_si128(_mm_slli_epi16(qx0, 3), _mm_set1_epi8(-8)), qx0); // 1 * 9
|
|
6091
|
+
__m128i qx3 = _mm_add_epi8(_mm_and_si128(_mm_slli_epi16(qx1, 3), _mm_set1_epi8(-8)), qx1); // 3 * 9
|
|
6092
|
+
__m128i qx4 = _mm_add_epi8(_mm_and_si128(_mm_slli_epi16(qx2, 3), _mm_set1_epi8(-8)), qx2); // 9 * 9
|
|
6093
|
+
__m256i qx01 = MM256_SET_M128I(qx1, qx0);
|
|
6094
|
+
__m256i qx23 = MM256_SET_M128I(qx3, qx2);
|
|
6095
|
+
|
|
6096
|
+
// avx2 does not have 8-bit multiplies, so 16-bit it is.
|
|
6097
|
+
qx5_l = _mm256_mullo_epi16(qx5_l, _mm256_set_epi16(27, 27, 27, 27, 9, 9, 9, 9, 3, 3, 3, 3, 1, 1, 1, 1));
|
|
6098
|
+
qx5_l = _mm256_and_si256(qx5_l, _mm256_set1_epi16(0xFF));
|
|
6099
|
+
__m128i qx5 = _mm_packus_epi16(_mm256_castsi256_si128(qx5_l), _mm256_extracti128_si256(qx5_l, 1));
|
|
6100
|
+
|
|
6101
|
+
__m256i qx45 = MM256_SET_M128I(qx5, qx4);
|
|
6102
|
+
|
|
6103
|
+
// Cancel the +1 from avg so that it behaves like a halving add
|
|
6104
|
+
qx01 = _mm256_subs_epu8(qx01, _mm256_set1_epi8(1));
|
|
6105
|
+
qx23 = _mm256_subs_epu8(qx23, _mm256_set1_epi8(1));
|
|
6106
|
+
qx45 = _mm256_subs_epu8(qx45, _mm256_set1_epi8(1));
|
|
6107
|
+
// Multiply by 3 and get the top 2 bits
|
|
6108
|
+
qx01 = _mm256_avg_epu8(qx01, _mm256_avg_epu8(qx01, _mm256_setzero_si256()));
|
|
6109
|
+
qx23 = _mm256_avg_epu8(qx23, _mm256_avg_epu8(qx23, _mm256_setzero_si256()));
|
|
6110
|
+
qx45 = _mm256_avg_epu8(qx45, _mm256_avg_epu8(qx45, _mm256_setzero_si256()));
|
|
6111
|
+
qx01 = _mm256_and_si256(_mm256_srli_epi16(qx01, 6), _mm256_set1_epi8(3));
|
|
6112
|
+
qx23 = _mm256_and_si256(_mm256_srli_epi16(qx23, 6), _mm256_set1_epi8(3));
|
|
6113
|
+
qx45 = _mm256_and_si256(_mm256_srli_epi16(qx45, 6), _mm256_set1_epi8(3));
|
|
6114
|
+
|
|
6115
|
+
const __m256i qy01 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 160));
|
|
6116
|
+
const __m256i qy23 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 192));
|
|
6117
|
+
const __m256i qy45 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 224));
|
|
6118
|
+
|
|
6119
|
+
qx01 = _mm256_maddubs_epi16(qx01, qy01);
|
|
6120
|
+
qx23 = _mm256_maddubs_epi16(qx23, qy23);
|
|
6121
|
+
qx45 = _mm256_maddubs_epi16(qx45, qy45);
|
|
6122
|
+
|
|
6123
|
+
sumi0 = _mm256_add_epi16(sumi0, qx01);
|
|
6124
|
+
sumi1 = _mm256_add_epi16(sumi1, qx23);
|
|
6125
|
+
sumi2 = _mm256_add_epi16(sumi2, qx45);
|
|
6126
|
+
}
|
|
6127
|
+
|
|
6128
|
+
const __m256i ysum = _mm256_loadu_si256((const __m256i *) y[i].bsums);
|
|
6129
|
+
const __m256 d = _mm256_set1_ps(y[i].d * GGML_FP16_TO_FP32(x[i].d));
|
|
6130
|
+
|
|
6131
|
+
sumi0 = _mm256_sub_epi16(sumi0, ysum);
|
|
6132
|
+
sumi0 = _mm256_add_epi16(sumi0, _mm256_add_epi16(sumi1, sumi2));
|
|
6133
|
+
sumi0 = _mm256_madd_epi16(sumi0, _mm256_set1_epi16(1));
|
|
6134
|
+
|
|
6135
|
+
sumf = _mm256_add_ps(_mm256_mul_ps(_mm256_cvtepi32_ps(sumi0), d), sumf);
|
|
6136
|
+
}
|
|
6137
|
+
|
|
6138
|
+
*s = hsum_float_8(sumf);
|
|
6139
|
+
|
|
6140
|
+
#else
|
|
6141
|
+
const uint8_t pow3[6] = {1, 3, 9, 27, 81, 243};
|
|
6142
|
+
|
|
6143
|
+
float sumf = 0.0f;
|
|
6144
|
+
|
|
6145
|
+
for (int i = 0; i < nb; ++i) {
|
|
6146
|
+
int sum = 0;
|
|
6147
|
+
|
|
6148
|
+
for (size_t j = 0; j < sizeof(x->qs) - sizeof(x->qs) % 32; j += 32) {
|
|
6149
|
+
for (size_t l = 0; l < 5; ++l) {
|
|
6150
|
+
for (size_t m = 0; m < 32; ++m) {
|
|
6151
|
+
uint8_t q = x[i].qs[j + m] * pow3[l];
|
|
6152
|
+
uint16_t xi = ((uint16_t) q * 3) >> 8;
|
|
6153
|
+
sum += (xi - 1) * y[i].qs[j*5 + l*32 + m];
|
|
6154
|
+
}
|
|
6155
|
+
}
|
|
6156
|
+
}
|
|
6157
|
+
for (size_t j = sizeof(x->qs) - sizeof(x->qs) % 32; j < sizeof(x->qs); j += 16) {
|
|
6158
|
+
for (size_t l = 0; l < 5; ++l) {
|
|
6159
|
+
for (size_t m = 0; m < 16; ++m) {
|
|
6160
|
+
uint8_t q = x[i].qs[j + m] * pow3[l];
|
|
6161
|
+
uint16_t xi = ((uint16_t) q * 3) >> 8;
|
|
6162
|
+
sum += (xi - 1) * y[i].qs[j*5 + l*16 + m];
|
|
6163
|
+
}
|
|
6164
|
+
}
|
|
6165
|
+
}
|
|
6166
|
+
|
|
6167
|
+
for (size_t l = 0; l < 4; ++l) {
|
|
6168
|
+
for (size_t j = 0; j < sizeof(x->qh); ++j) {
|
|
6169
|
+
uint8_t q = x[i].qh[j] * pow3[l];
|
|
6170
|
+
uint16_t xi = ((uint16_t) q * 3) >> 8;
|
|
6171
|
+
sum += (xi - 1) * y[i].qs[sizeof(x->qs)*5 + l*sizeof(x->qh) + j];
|
|
6172
|
+
}
|
|
6173
|
+
}
|
|
6174
|
+
|
|
6175
|
+
sumf += (float) sum * (GGML_FP16_TO_FP32(x[i].d) * y[i].d);
|
|
6176
|
+
}
|
|
6177
|
+
|
|
6178
|
+
*s = sumf;
|
|
6179
|
+
#endif
|
|
6180
|
+
}
|
|
6181
|
+
+
+void ggml_vec_dot_tq2_0_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+
+    const block_tq2_0 * restrict x = vx;
+    const block_q8_K * restrict y = vy;
+
+    const int nb = n / QK_K;
+
+#if defined(__ARM_NEON)
+    float sumf = 0.0f;
+
+    const uint8x16_t m3 = vdupq_n_u8(3);
+
+    for (int i = 0; i < nb; ++i) {
+#if defined(__ARM_FEATURE_DOTPROD)
+        int32x4_t sumi0 = vdupq_n_s32(0);
+        int32x4_t sumi1 = vdupq_n_s32(0);
+#else
+        int16x8_t sumi0 = vdupq_n_s16(0);
+        int16x8_t sumi1 = vdupq_n_s16(0);
+#endif
+
+        for (size_t j = 0; j < sizeof(x->qs); j += 32) {
+            uint8x16_t qx0 = vld1q_u8(x[i].qs + j);
+            uint8x16_t qx1 = vld1q_u8(x[i].qs + j + 16);
+            uint8x16_t qx2 = vshrq_n_u8(qx0, 2);
+            uint8x16_t qx3 = vshrq_n_u8(qx1, 2);
+            uint8x16_t qx4 = vshrq_n_u8(qx0, 4);
+            uint8x16_t qx5 = vshrq_n_u8(qx1, 4);
+            uint8x16_t qx6 = vshrq_n_u8(qx0, 6);
+            uint8x16_t qx7 = vshrq_n_u8(qx1, 6);
+
+            int8x16_t sqx0 = vreinterpretq_s8_u8(vandq_u8(qx0, m3));
+            int8x16_t sqx1 = vreinterpretq_s8_u8(vandq_u8(qx1, m3));
+            int8x16_t sqx2 = vreinterpretq_s8_u8(vandq_u8(qx2, m3));
+            int8x16_t sqx3 = vreinterpretq_s8_u8(vandq_u8(qx3, m3));
+            int8x16_t sqx4 = vreinterpretq_s8_u8(vandq_u8(qx4, m3));
+            int8x16_t sqx5 = vreinterpretq_s8_u8(vandq_u8(qx5, m3));
+            int8x16_t sqx6 = vreinterpretq_s8_u8(vandq_u8(qx6, m3));
+            int8x16_t sqx7 = vreinterpretq_s8_u8(vandq_u8(qx7, m3));
+
+            const int8x16_t qy0 = vld1q_s8(y[i].qs + j*4 + 0);
+            const int8x16_t qy1 = vld1q_s8(y[i].qs + j*4 + 16);
+            const int8x16_t qy2 = vld1q_s8(y[i].qs + j*4 + 32);
+            const int8x16_t qy3 = vld1q_s8(y[i].qs + j*4 + 48);
+            const int8x16_t qy4 = vld1q_s8(y[i].qs + j*4 + 64);
+            const int8x16_t qy5 = vld1q_s8(y[i].qs + j*4 + 80);
+            const int8x16_t qy6 = vld1q_s8(y[i].qs + j*4 + 96);
+            const int8x16_t qy7 = vld1q_s8(y[i].qs + j*4 + 112);
+
+#if defined(__ARM_FEATURE_DOTPROD)
+            sumi0 = vdotq_s32(sumi0, sqx0, qy0);
+            sumi1 = vdotq_s32(sumi1, sqx1, qy1);
+            sumi0 = vdotq_s32(sumi0, sqx2, qy2);
+            sumi1 = vdotq_s32(sumi1, sqx3, qy3);
+            sumi0 = vdotq_s32(sumi0, sqx4, qy4);
+            sumi1 = vdotq_s32(sumi1, sqx5, qy5);
+            sumi0 = vdotq_s32(sumi0, sqx6, qy6);
+            sumi1 = vdotq_s32(sumi1, sqx7, qy7);
+#else
+            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx0), vget_low_s8(qy0));
+            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx0), vget_high_s8(qy0));
+            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx1), vget_low_s8(qy1));
+            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx1), vget_high_s8(qy1));
+            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx2), vget_low_s8(qy2));
+            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx2), vget_high_s8(qy2));
+            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx3), vget_low_s8(qy3));
+            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx3), vget_high_s8(qy3));
+            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx4), vget_low_s8(qy4));
+            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx4), vget_high_s8(qy4));
+            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx5), vget_low_s8(qy5));
+            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx5), vget_high_s8(qy5));
+            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx6), vget_low_s8(qy6));
+            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx6), vget_high_s8(qy6));
+            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx7), vget_low_s8(qy7));
+            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx7), vget_high_s8(qy7));
+#endif
+        }
+
+        const int16x8_t ysum0 = vld1q_s16(y[i].bsums);
+        const int16x8_t ysum1 = vld1q_s16(y[i].bsums + 8);
+
+        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+
+#if defined(__ARM_FEATURE_DOTPROD)
+        sumi0 = vaddq_s32(sumi0, sumi1);
+        sumi0 = vsubq_s32(sumi0, vpaddlq_s16(vaddq_s16(ysum0, ysum1)));
+
+        sumf += d * (float) vaddvq_s32(sumi0);
+#else
+        sumi0 = vaddq_s16(sumi0, sumi1);
+        sumi0 = vsubq_s16(sumi0, vaddq_s16(ysum0, ysum1));
+
+        sumf += d * (float) vaddlvq_s16(sumi0);
+#endif
+    }
+
+    *s = sumf;
+
+#elif defined(__AVX2__)
+    __m256 sumf = _mm256_setzero_ps();
+
+    for (int i = 0; i < nb; ++i) {
+        // 16-bit sums, because 256*127 still fits
+        __m256i sumi0 = _mm256_setzero_si256();
+        __m256i sumi1 = _mm256_setzero_si256();
+
+        for (size_t j = 0; j < sizeof(x->qs); j += 32) {
+            __m256i qx0 = _mm256_loadu_si256((const __m256i *) (x[i].qs + j));
+            __m256i qx1 = _mm256_srli_epi16(qx0, 2);
+            __m256i qx2 = _mm256_srli_epi16(qx0, 4);
+            __m256i qx3 = _mm256_srli_epi16(qx0, 6);
+
+            // 0, 1, 2 (should not be 3)
+            qx0 = _mm256_and_si256(qx0, _mm256_set1_epi8(3));
+            qx1 = _mm256_and_si256(qx1, _mm256_set1_epi8(3));
+            qx2 = _mm256_and_si256(qx2, _mm256_set1_epi8(3));
+            qx3 = _mm256_and_si256(qx3, _mm256_set1_epi8(3));
+
+            const __m256i qy0 = _mm256_loadu_si256((const __m256i *) (y[i].qs + j*4 + 0));
+            const __m256i qy1 = _mm256_loadu_si256((const __m256i *) (y[i].qs + j*4 + 32));
+            const __m256i qy2 = _mm256_loadu_si256((const __m256i *) (y[i].qs + j*4 + 64));
+            const __m256i qy3 = _mm256_loadu_si256((const __m256i *) (y[i].qs + j*4 + 96));
+
+            qx0 = _mm256_maddubs_epi16(qx0, qy0);
+            qx1 = _mm256_maddubs_epi16(qx1, qy1);
+            qx2 = _mm256_maddubs_epi16(qx2, qy2);
+            qx3 = _mm256_maddubs_epi16(qx3, qy3);
+
+            sumi0 = _mm256_add_epi16(sumi0, _mm256_add_epi16(qx0, qx1));
+            sumi1 = _mm256_add_epi16(sumi1, _mm256_add_epi16(qx2, qx3));
+        }
+
+        const __m256i ysum = _mm256_loadu_si256((const __m256i *) y[i].bsums);
+        const __m256 d = _mm256_set1_ps(y[i].d * GGML_FP16_TO_FP32(x[i].d));
+
+        sumi0 = _mm256_add_epi16(sumi0, sumi1);
+        sumi0 = _mm256_sub_epi16(sumi0, ysum);
+        sumi0 = _mm256_madd_epi16(sumi0, _mm256_set1_epi16(1));
+
+        sumf = _mm256_add_ps(_mm256_mul_ps(_mm256_cvtepi32_ps(sumi0), d), sumf);
+    }
+
+    *s = hsum_float_8(sumf);
+
+#else
+    float sumf = 0.0f;
+
+    for (int i = 0; i < nb; ++i) {
+        int32_t sumi = 0;
+
+        for (size_t j = 0; j < sizeof(x->qs); j += 32) {
+            for (size_t l = 0; l < 4; ++l) {
+                for (size_t k = 0; k < 32; ++k) {
+                    sumi += y[i].qs[j*4 + l*32 + k] * (((x[i].qs[j + k] >> (l*2)) & 3) - 1);
+                }
+            }
+        }
+
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+
+        sumf += (float) sumi * d;
+    }
+
+    *s = sumf;
+#endif
+}
+
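Both the NEON and AVX2 paths above accumulate the dot product with the unsigned 2-bit values {0, 1, 2} and only afterwards subtract y's precomputed per-block sums (bsums). That works because sum((xi - 1) * yi) = sum(xi * yi) - sum(yi). A small self-contained check of the identity, with illustrative values that are not from the diff:

    #include <assert.h>
    #include <stdint.h>

    int main(void) {
        const uint8_t xi[8] = {0, 2, 1, 2, 0, 1, 2, 1};   // unsigned 2-bit values; weights are xi - 1
        const int8_t  yi[8] = {3, -5, 7, 1, -2, 4, -6, 8};

        int signed_dot = 0, unsigned_dot = 0, ysum = 0;
        for (int k = 0; k < 8; ++k) {
            signed_dot   += (xi[k] - 1) * yi[k];
            unsigned_dot += xi[k] * yi[k];
            ysum         += yi[k];
        }
        // the kernels compute unsigned_dot and subtract ysum (y[i].bsums) at the end
        assert(signed_dot == unsigned_dot - ysum);
        return 0;
    }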
 void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
     assert(nrc == 1);
     UNUSED(nrc);
@@ -6449,22 +7330,22 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * r
             // compute mask for subtraction
             vuint8m1_t qh_m0 = __riscv_vand_vx_u8m1(vqh, m, vl);
             vbool8_t vmask_0 = __riscv_vmseq_vx_u8m1_b8(qh_m0, 0, vl);
-            vint8m1_t q3_m0 =
+            vint8m1_t q3_m0 = __riscv_vsub_vx_i8m1_mu(vmask_0, q3_0, q3_0, 0x4, vl);
             m <<= 1;

             vuint8m1_t qh_m1 = __riscv_vand_vx_u8m1(vqh, m, vl);
             vbool8_t vmask_1 = __riscv_vmseq_vx_u8m1_b8(qh_m1, 0, vl);
-            vint8m1_t q3_m1 =
+            vint8m1_t q3_m1 = __riscv_vsub_vx_i8m1_mu(vmask_1, q3_1, q3_1, 0x4, vl);
             m <<= 1;

             vuint8m1_t qh_m2 = __riscv_vand_vx_u8m1(vqh, m, vl);
             vbool8_t vmask_2 = __riscv_vmseq_vx_u8m1_b8(qh_m2, 0, vl);
-            vint8m1_t q3_m2 =
+            vint8m1_t q3_m2 = __riscv_vsub_vx_i8m1_mu(vmask_2, q3_2, q3_2, 0x4, vl);
             m <<= 1;

             vuint8m1_t qh_m3 = __riscv_vand_vx_u8m1(vqh, m, vl);
             vbool8_t vmask_3 = __riscv_vmseq_vx_u8m1_b8(qh_m3, 0, vl);
-            vint8m1_t q3_m3 =
+            vint8m1_t q3_m3 = __riscv_vsub_vx_i8m1_mu(vmask_3, q3_3, q3_3, 0x4, vl);
             m <<= 1;

             // load Q8 and take product with Q3
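In the RISC-V hunk above, the `_mu` (masked, undisturbed) intrinsic variants apply the subtract only in lanes where the mask is set and carry the remaining lanes over from the pass-through operand (here the same register, q3_0). A rough scalar picture of the first masked subtract, with made-up array names used purely for illustration:

    #include <stddef.h>
    #include <stdint.h>

    // Lanes whose qh bit is clear have 4 subtracted; the rest pass through unchanged,
    // mirroring __riscv_vsub_vx_i8m1_mu(vmask_0, q3_0, q3_0, 0x4, vl).
    static void q3_apply_high_bit(int8_t * out, const int8_t * q3,
                                  const uint8_t * qh, uint8_t m, size_t vl) {
        for (size_t k = 0; k < vl; ++k) {
            out[k] = ((qh[k] & m) == 0) ? (int8_t) (q3[k] - 4) : q3[k];
        }
    }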
@@ -7720,13 +8601,13 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r
             vint8m1_t q5_a = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q5_x, 0x0F, vl));
             vuint8m1_t qh_m1 = __riscv_vand_vx_u8m1(vqh, m, vl);
             vbool8_t vmask_1 = __riscv_vmsne_vx_u8m1_b8(qh_m1, 0, vl);
-            vint8m1_t q5_m1 =
+            vint8m1_t q5_m1 = __riscv_vadd_vx_i8m1_mu(vmask_1, q5_a, q5_a, 16, vl);
             m <<= 1;

             vint8m1_t q5_l = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(q5_x, 0x04, vl));
             vuint8m1_t qh_m2 = __riscv_vand_vx_u8m1(vqh, m, vl);
             vbool8_t vmask_2 = __riscv_vmsne_vx_u8m1_b8(qh_m2, 0, vl);
-            vint8m1_t q5_m2 =
+            vint8m1_t q5_m2 = __riscv_vadd_vx_i8m1_mu(vmask_2, q5_l, q5_l, 16, vl);
             m <<= 1;

             vint16m2_t v0 = __riscv_vwmul_vv_i16m2(q5_m1, q8_y1, vl);
@@ -10945,15 +11826,6 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * restrict s, size_t bs, const void *
 #endif
 }

-
-#if defined(__AVX__)
-static inline __m128i mul_add_epi8_sse(const __m128i x, const __m128i y) {
-    const __m128i ax = _mm_sign_epi8(x, x);
-    const __m128i sy = _mm_sign_epi8(y, x);
-    return _mm_maddubs_epi16(ax, sy);
-}
-#endif
-
 #if defined(__AVX2__)
 static inline __m256i mul_add_epi8(const __m256i x, const __m256i y) {
     const __m256i ax = _mm256_sign_epi8(x, x);
@@ -14800,6 +15672,14 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte
                 }
             }
         } break;
+        case GGML_TYPE_TQ1_0:
+            {
+                VALIDATE_ROW_DATA_D_F16_IMPL(block_tq1_0, data, nb);
+            } break;
+        case GGML_TYPE_TQ2_0:
+            {
+                VALIDATE_ROW_DATA_D_F16_IMPL(block_tq2_0, data, nb);
+            } break;
         case GGML_TYPE_IQ1_S:
             {
                 VALIDATE_ROW_DATA_D_F16_IMPL(block_iq1_s, data, nb);