whisper.rn 0.4.0-rc.10 → 0.4.0-rc.12
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/android/src/main/CMakeLists.txt +9 -3
- package/cpp/amx/amx.cpp +220 -0
- package/cpp/amx/amx.h +8 -0
- package/cpp/amx/common.h +91 -0
- package/cpp/amx/mmq.cpp +2511 -0
- package/cpp/amx/mmq.h +10 -0
- package/cpp/ggml-alloc.c +6 -14
- package/cpp/ggml-backend-impl.h +50 -11
- package/cpp/ggml-backend-reg.cpp +409 -31
- package/cpp/ggml-backend.cpp +9 -3
- package/cpp/ggml-backend.h +18 -0
- package/cpp/ggml-common.h +41 -43
- package/cpp/ggml-cpp.h +1 -0
- package/cpp/{ggml-cpu-aarch64.c → ggml-cpu-aarch64.cpp} +941 -254
- package/cpp/ggml-cpu-aarch64.h +2 -24
- package/cpp/ggml-cpu-impl.h +171 -11
- package/cpp/ggml-cpu-quants.c +1812 -389
- package/cpp/ggml-cpu-traits.cpp +36 -0
- package/cpp/ggml-cpu-traits.h +38 -0
- package/cpp/ggml-cpu.c +1432 -610
- package/cpp/ggml-cpu.cpp +131 -141
- package/cpp/ggml-cpu.h +10 -50
- package/cpp/ggml-impl.h +27 -11
- package/cpp/ggml-metal-impl.h +39 -0
- package/cpp/ggml-metal.h +1 -1
- package/cpp/ggml-metal.m +1031 -359
- package/cpp/ggml-opt.cpp +854 -0
- package/cpp/ggml-opt.h +216 -0
- package/cpp/ggml-quants.c +0 -9
- package/cpp/ggml-threading.h +4 -2
- package/cpp/ggml-whisper.metallib +0 -0
- package/cpp/ggml.c +501 -1537
- package/cpp/ggml.h +144 -171
- package/cpp/gguf.cpp +1329 -0
- package/cpp/gguf.h +202 -0
- package/cpp/whisper.cpp +254 -114
- package/cpp/whisper.h +6 -3
- package/lib/commonjs/version.json +1 -1
- package/lib/module/version.json +1 -1
- package/package.json +2 -1
- package/src/version.json +1 -1
- package/whisper-rn.podspec +2 -2
- package/cpp/README.md +0 -4
- package/cpp/ggml-aarch64.c +0 -129
- package/cpp/ggml-aarch64.h +0 -19
- package/cpp/ggml-backend.cpp.rej +0 -12
package/cpp/ggml-cpu-quants.c
CHANGED

@@ -103,10 +103,14 @@ static inline __m256 sum_i16_pairs_float(const __m256i x) {
- mul_sum_us8_pairs_float: the first guard now reads #if defined(__AVX512VNNI__) && defined(__AVX512VL__), and a new #elif defined(__AVXVNNI__) branch computes the summed pairs with _mm256_dpbusd_avx_epi32(zero, ax, sy) ahead of the existing _mm256_maddubs_epi16 fallback.

@@ -293,6 +297,90 @@ static const uint64_t table_b2b_0[1 << 8] = { B8(00, 10) }; // ( b) << 4
- Adds a #if defined(__loongarch_sx) block with the 128-bit LSX helpers lsx_packs_w, lsx_packs_h, lsx_packus_h, lsx_maddubs_h, lsx_madd_h, lsx_set_w, lsx_shuffle_b, lsx_hadd_h, lsx_hadd_w, lsx_hadd_s and hsum_float_4x4, so they are available independently of the LASX section.

@@ -391,11 +479,6 @@ static __m256i lasx_set_w(int e7, int e6, int e5, int e4, int e3, int e2, int e1
- Removes the lsx_set_w definition from the __loongarch_asx section (now provided by the new __loongarch_sx block).

@@ -405,18 +488,6 @@ static __m256i lasx_insertf128( __m128i x, __m128i y) {
- Removes the duplicate lsx_shuffle_b definition (moved to the __loongarch_sx block).

@@ -430,30 +501,15 @@ static __m256i lasx_shuffle_b(__m256i a, __m256i b) {
- lasx_extu8_16, lasx_ext8_16 and lasx_ext16_32 become single-intrinsic wrappers around __lasx_vext2xv_hu_bu, __lasx_vext2xv_h_b and __lasx_vext2xv_w_h, replacing the manual vilvl/vilvh and xvinsgr2vr element-by-element widenings.

@@ -478,25 +534,6 @@ static __m128 lasx_extractf128( __m256 a, int pos) {
- Removes the duplicate lsx_hadd_h, lsx_hadd_w and lsx_hadd_s definitions (moved to the __loongarch_sx block).

@@ -525,40 +562,39 @@ static __m256i lasx_packs_h(__m256i a, __m256i b) {
- Removes the remaining duplicated 128-bit helpers (lsx_packs_h, lsx_packus_h and their neighbours) and adds three LASX helpers: lasx_madd_h_b (widening int8 pair products via __lasx_xvmulwev_h_b/__lasx_xvmulwod_h_b summed with __lasx_xvadd_h), lasx_xvrepl128vei_h (a switch over the constant lane index for __lasx_xvrepl128vei_h, with __builtin_unreachable() as the default) and lasx_xvandi_b_bit (a switch over the constant bit index for __lasx_xvandi_b).

@@ -576,12 +612,10 @@ static inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) {
- hsum_float_8 no longer routes the result through an ft_union; it returns ((v4f32)res)[0] directly.

@@ -657,13 +691,8 @@ static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy)
- mul_sum_i8_pairs_float drops the __lasx_xvsigncov_b sign trick and the call to mul_sum_us8_pairs_float; it now computes const __m256i dot = lasx_madd_h_b(x, y) and returns sum_i16_pairs_float(dot).

@@ -743,7 +772,7 @@ void wsp_quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t
- The WASM branch guard is spelled #elif defined __wasm_simd128__.

@@ -923,7 +952,6 @@ void wsp_quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t
- The LoongArch ASX path drops the local ft_union fi.

@@ -941,8 +969,7 @@ void wsp_quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t
- max_scalar is read as ((v4f32)max4)[0] instead of through the union.

@@ -984,6 +1011,38 @@ void wsp_quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t
- Adds an #elif defined(__VXE__) || defined(__VXE2__) path (s390x vector extensions): per block, vec_xl loads and vec_abs/vec_max reductions give amax, d = amax / ((1 << 7) - 1), the values are scaled with vec_mul, converted with vec_signed and the four lanes of each vector stored with vec_extract.

@@ -1033,7 +1092,7 @@ void wsp_quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t
- The WASM branch guard is spelled #elif defined __wasm_simd128__.

@@ -1247,7 +1306,6 @@ void wsp_quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t
- The LoongArch ASX path drops the local ft_union ft.

@@ -1265,8 +1323,7 @@ void wsp_quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t
- max_scalar is read as ((v4f32)max4)[0] instead of through the union.

@@ -1312,6 +1369,44 @@ void wsp_quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t
- Adds the same __VXE__/__VXE2__ path to wsp_quantize_row_q8_1, additionally accumulating the quantized lanes into a vector acc and storing y[i].s = WSP_GGML_FP32_TO_FP16(d * (acc[0] + acc[1] + acc[2] + acc[3])).

@@ -1649,7 +1744,87 @@ static const int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -1
- wsp_quantize_row_q8_K gains a __wasm_simd128__ implementation: per QK_K block it finds min/max with wasm_f32x4_pmin/pmax reductions, zero-fills the block when amax == 0.0f, derives iscale = -127.0f / amax, quantizes 16 floats at a time with wasm_f32x4_nearest and wasm_i32x4_trunc_sat_f32x4, packs to int8 with wasm_i16x8_narrow_i32x4/wasm_i8x16_narrow_i16x8, computes bsums with widening adds and shuffles, and sets d = 1.0f / iscale. The previous wsp_quantize_row_q8_K_ref call is kept under #else.

@@ -1791,11 +1966,12 @@ void wsp_ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void
- In the NEON i8mm block, the _scale[4] initializer lists the four products WSP_GGML_FP16_TO_FP32(b_x0->d or b_x1->d) * WSP_GGML_FP16_TO_FP32(b_y0->d or b_y1->d).

@@ -1811,13 +1987,15 @@ void wsp_ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void
- The chained vmmlaq_s32 accumulation is re-wrapped onto a continuation line and the spacing around vextq_f32/vst1_f32 is adjusted.

@@ -2004,6 +2182,94 @@ void wsp_ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void
- Adds a __wasm_simd128__ path: two blocks per iteration, nibbles unpacked with wasm_v128_and/wasm_u8x16_shr and re-centred by subtracting 8, operands widened to i16 and reduced with wasm_i32x4_dot_i16x8, the two dot products scaled by the FP16 d products and accumulated; sumf is the sum of the four lanes of sumv.

@@ -2225,21 +2491,22 @@ void wsp_ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void
- In the LoongArch SX path, acc_0..acc_3 and the combined scale d_0_1 gain explicit (__m128) casts on __lsx_vldi(0) and __lsx_vreplgr2vr_w(...).

@@ -2257,7 +2524,7 @@ void wsp_ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void
- The same (__m128) cast is applied to d_2_3.

@@ -2291,6 +2558,37 @@ void wsp_ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void
- Adds a __VXE__/__VXE2__ path: the nibbles are masked with 0x0F or shifted by 4 and re-centred by subtracting 8, products are formed with vec_mulo/vec_mule, folded with vec_reve, converted with vec_float(vec_unpackh(...)) and accumulated with vec_madd; sumf is the sum of the accumulator lanes.

@@ -2345,10 +2643,12 @@ void wsp_ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void
- The summs_t[4] initializer lists the four products WSP_GGML_FP16_TO_FP32(b_x0->m or b_x1->m) * WSP_GGML_FP16_TO_FP32(b_y0->s or b_y1->s).

@@ -2369,10 +2669,12 @@ void wsp_ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void
- The _scale[4] initializer lists the four FP16 d products, as in the q4_0 case.

@@ -2387,15 +2689,17 @@ void wsp_ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void
- The chained vmmlaq_s32 accumulation is re-wrapped and the spacing around vextq_f32/vst1_f32 is adjusted.

@@ -2578,13 +2882,42 @@ void wsp_ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void
- Adds a __VXE__/__VXE2__ path: per block, summs accumulates WSP_GGML_FP16_TO_FP32(x[ib].m) * WSP_GGML_FP16_TO_FP32(y[ib].s), the nibbles are dotted against y with wsp_ggml_vec_dot, scaled by the FP16 d product and accumulated with vec_madd, with __builtin_prefetch and #pragma GCC unroll 4; the scalar fallback loop over sumi0/sumi1 follows the new #endif.

@@ -2683,10 +3016,10 @@ void wsp_ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void
- The WASM guard is spelled #elif defined __wasm_simd128__ and the 32-bit high-bit temporary is declared as uint32_t qh_.

@@ -2697,12 +3030,12 @@ void wsp_ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void
- The fifth-bit extraction uses that temporary: memcpy(&qh_, x0->qh, sizeof(qh_)) followed by table_b2b_1[(qh_ >> 0/8/16/24) & 0xFF] lookups.

@@ -3044,12 +3377,12 @@ void wsp_ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void
- The same guard spelling and uint32_t qh_ declaration as in the q5_0 case.

@@ -3062,12 +3395,12 @@ void wsp_ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void
- The fifth-bit extraction uses memcpy(&qh_, x0->qh, sizeof(qh_)) and table_b2b_0[(qh_ >> 0/8/16/24) & 0xFF] lookups.

@@ -3372,10 +3705,12 @@ void wsp_ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void
- The NEON i8mm _scale[4] initializer lists the four FP16 d products.

@@ -3391,13 +3726,15 @@ void wsp_ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void
- The chained vmmlaq_s32 accumulation is re-wrapped and the vst1_f32(s, vget_low_f32 (sumv2)) / vst1_f32(s + bs, vget_high_f32(sumv2)) spacing is adjusted.

@@ -3556,6 +3893,45 @@ void wsp_ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void
- Adds a __wasm_simd128__ path: both operands widened to i16, four wasm_i32x4_dot_i16x8 products summed, converted to float, scaled by WSP_GGML_FP16_TO_FP32(x0->d) * WSP_GGML_FP16_TO_FP32(y0->d) and accumulated; sumf is the sum of the four lanes of sumv.

@@ -3669,6 +4045,27 @@ void wsp_ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void
- Adds a __VXE__/__VXE2__ path using wsp_ggml_vec_dot on the low/high halves, vec_float and vec_madd, with __builtin_prefetch and #pragma GCC unroll 8.

@@ -4430,6 +4827,106 @@ void wsp_ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void
- Adds a __wasm_simd128__ path: summs is computed by dotting the upper scale nibbles against y[i].bsums with wasm_i32x4_dot_i16x8, then each of the QK_K/128 chunks extracts the 2-bit quants with wasm_u8x16_shr and a 0x03 mask, forms dot products against q8 and scales them by the low scale nibbles d0/d1 to build isum; the block contributes dall * isum - dmin * summs to sumf.

@@ -4649,9 +5146,6 @@ void wsp_ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void
- The LoongArch ASX path drops the m3 (__lasx_xvreplgr2vr_b(3)) and m4 (__lsx_vreplgr2vr_b(0xF)) constants.

@@ -4662,18 +5156,15 @@ void wsp_ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void
- Scales and mins are derived from a single __lsx_vld of x[i].scales: scales128 = __lsx_vandi_b(mins_and_scales128, 0xf) and mins = lasx_ext8_16(__lsx_vsrli_b(mins_and_scales128, 4)); the per-subblock scales are prepared once as scales_shuffled = lasx_ext8_16(__lsx_vshuf_b(scales128, scales128, shuffle_mask)) with the v16i8 mask {0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15}, replacing the l_scales/h_scales lasx_insertf128 pair.

@@ -4686,20 +5177,20 @@ void wsp_ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void
- The 2-bit quants q2_0..q2_3 are extracted with __lasx_xvandi_b and __lasx_xvsrli_b, the q2*q8 products p0..p3 are formed with lasx_madd_h_b, and each is scaled with lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4*j + k), pk).

@@ -4772,7 +5263,182 @@ void wsp_ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void
- Adds an __ARM_FEATURE_SVE implementation: the 16 scales are unpacked from x[i].scales via the kmask constants and offset by 32, and for vector_length == 128 each 128-element chunk reconstructs the 3-bit quants from q3 and the hmask bits (svand_u8_m, svlsr_n_u8_x, svbic_u8_x, svlsl_n_u8_x), dots them against q8 with svdot_s32 and accumulates per scale with svmla_s32_m.
q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q3bits_sv, 4), m3b_sv)), svreinterpret_s8_u8(q3h_sv));
|
|
5349
|
+
|
|
5350
|
+
sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_1), svdup_n_s32((int32_t)scale[0]));
|
|
5351
|
+
|
|
5352
|
+
q3h_sv = svbic_u8_x(svptrue_b8(), m2_sv, qhbits_sv_2);
|
|
5353
|
+
q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q3bits_sv_1, 4), m3b_sv)), svreinterpret_s8_u8(q3h_sv));
|
|
5354
|
+
|
|
5355
|
+
sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_2), svdup_n_s32((int32_t)scale[1]));
|
|
5356
|
+
|
|
5357
|
+
|
|
5358
|
+
q8bytes_1_sv_1 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
|
|
5359
|
+
q8bytes_1_sv_2 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
|
|
5360
|
+
|
|
5361
|
+
q3h_sv = svlsr_n_u8_x(svptrue_b8(), svbic_u8_x(svptrue_b8(), m3_sv, qhbits_sv_1), 1);
|
|
5362
|
+
q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q3bits_sv, 6), m3b_sv)), svreinterpret_s8_u8(q3h_sv));
|
|
5363
|
+
|
|
5364
|
+
sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_1), svdup_n_s32((int32_t)scale[2]));
|
|
5365
|
+
|
|
5366
|
+
q3h_sv = svlsr_n_u8_x(svptrue_b8(), svbic_u8_x(svptrue_b8(), m3_sv, qhbits_sv_2), 1);
|
|
5367
|
+
q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q3bits_sv_1, 6), m3b_sv)), svreinterpret_s8_u8(q3h_sv));
|
|
5368
|
+
|
|
5369
|
+
sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_2), svdup_n_s32((int32_t)scale[3]));
|
|
5370
|
+
|
|
5371
|
+
if (j == 0) {
|
|
5372
|
+
qhbits_sv_1 = svlsr_n_u8_x(svptrue_b8(), qhbits_sv_1, 4);
|
|
5373
|
+
qhbits_sv_2 = svlsr_n_u8_x(svptrue_b8(), qhbits_sv_2, 4);
|
|
5374
|
+
}
|
|
5375
|
+
|
|
5376
|
+
scale += 4;
|
|
5377
|
+
}
|
|
5378
|
+
|
|
5379
|
+
sum += d * (svaddv_s32(svptrue_b32(), sumi1_1));
|
|
5380
|
+
} break;
|
|
5381
|
+
case 256:
|
|
5382
|
+
case 512:
|
|
5383
|
+
{
|
|
5384
|
+
svuint8_t qhbits_sv = svld1_u8(svptrue_pat_b8(SV_VL32), qh_sv);
|
|
5385
|
+
svuint8_t q3h_sv;
|
|
5386
|
+
|
|
5387
|
+
svint32_t sumi1_1 = svdup_n_s32(0);
|
|
5388
|
+
svint8_t q3bytes_sv;
|
|
5389
|
+
|
|
5390
|
+
for (int j = 0; j < QK_K/128; ++j) {
|
|
5391
|
+
|
|
5392
|
+
const svuint8_t q3bits_sv = svld1_u8(svptrue_pat_b8(SV_VL32), q3_sv); q3_sv += 32;
|
|
5393
|
+
svint8_t q8bytes_1_sv_1 = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
|
|
5394
|
+
svint8_t q8bytes_1_sv_2 = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
|
|
5395
|
+
|
|
5396
|
+
q3h_sv = svlsl_n_u8_x(svptrue_pat_b8(SV_VL32), svbic_u8_x(svptrue_pat_b8(SV_VL32), m0_sv, qhbits_sv), 2);
|
|
5397
|
+
q3bytes_sv = svsub_s8_x(svptrue_pat_b8(SV_VL32), svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), q3bits_sv, m3b_sv)), svreinterpret_s8_u8(q3h_sv));
|
|
5398
|
+
|
|
5399
|
+
|
|
5400
|
+
svint32_t scale_1 = svsel_s32(svptrue_pat_b32(SV_VL4), svdup_n_s32((int32_t)scale[0]), svdup_n_s32((int32_t)scale[1]));
|
|
5401
|
+
sumi1_1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_1), scale_1);
|
|
5402
|
+
|
|
5403
|
+
q3h_sv = svlsl_n_u8_x(svptrue_pat_b8(SV_VL32), svbic_u8_x(svptrue_pat_b8(SV_VL32), m1_sv, qhbits_sv), 1);
|
|
5404
|
+
q3bytes_sv = svsub_s8_x(svptrue_pat_b8(SV_VL32), svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q3bits_sv, 2), m3b_sv)), svreinterpret_s8_u8(q3h_sv));
|
|
5405
|
+
|
|
5406
|
+
scale_1 = svsel_s32(svptrue_pat_b32(SV_VL4), svdup_n_s32((int32_t)scale[2]), svdup_n_s32((int32_t)scale[3]));
|
|
5407
|
+
sumi1_1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_2), scale_1);
|
|
5408
|
+
|
|
5409
|
+
scale += 4;
|
|
5410
|
+
q8bytes_1_sv_1 = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
|
|
5411
|
+
q8bytes_1_sv_2 = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
|
|
5412
|
+
|
|
5413
|
+
q3h_sv = svbic_u8_x(svptrue_pat_b8(SV_VL32), m2_sv, qhbits_sv);
|
|
5414
|
+
q3bytes_sv = svsub_s8_x(svptrue_pat_b8(SV_VL32), svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q3bits_sv, 4), m3b_sv)), svreinterpret_s8_u8(q3h_sv));
|
|
5415
|
+
|
|
5416
|
+
scale_1 = svsel_s32(svptrue_pat_b32(SV_VL4), svdup_n_s32((int32_t)scale[0]), svdup_n_s32((int32_t)scale[1]));
|
|
5417
|
+
sumi1_1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_1), scale_1);
|
|
5418
|
+
|
|
5419
|
+
q3h_sv = svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), svbic_u8_x(svptrue_pat_b8(SV_VL32), m3_sv, qhbits_sv), 1);
|
|
5420
|
+
q3bytes_sv = svsub_s8_x(svptrue_pat_b8(SV_VL32), svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q3bits_sv, 6), m3b_sv)), svreinterpret_s8_u8(q3h_sv));
|
|
5421
|
+
|
|
5422
|
+
scale_1 = svsel_s32(svptrue_pat_b32(SV_VL4), svdup_n_s32((int32_t)scale[2]), svdup_n_s32((int32_t)scale[3]));
|
|
5423
|
+
sumi1_1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_2), scale_1);
|
|
5424
|
+
|
|
5425
|
+
if (j == 0) {
|
|
5426
|
+
qhbits_sv = svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), qhbits_sv, 4);
|
|
5427
|
+
}
|
|
5428
|
+
|
|
5429
|
+
scale += 4;
|
|
5430
|
+
}
|
|
5431
|
+
|
|
5432
|
+
sum += d * (svaddv_s32(svptrue_pat_b32(SV_VL8), sumi1_1));
|
|
5433
|
+
} break;
|
|
5434
|
+
default:
|
|
5435
|
+
assert(false && "Unsupported vector length");
|
|
5436
|
+
break;
|
|
5437
|
+
}
|
|
5438
|
+
}
|
|
5439
|
+
*s = sum;
|
|
5440
|
+
|
|
5441
|
+
#elif __ARM_NEON
|
|
4776
5442
|
|
|
4777
5443
|
uint32_t aux[3];
|
|
4778
5444
|
uint32_t utmp[4];
|
|
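
Note: the new __ARM_FEATURE_SVE q3_K branch above selects a kernel from the runtime register width (svcntb()*8): a 16-byte-per-load variant for 128-bit implementations and a predicated 32-byte variant shared by the 256/512-bit cases. A plain-C sketch of that dispatch shape, with vec_len_bits() standing in for the svcntb() query (illustrative only, not part of the diff):

    #include <stddef.h>
    #include <stdio.h>

    /* Stand-in for svcntb()*8; a real SVE build would query the hardware. */
    static size_t vec_len_bits(void) { return 256; }

    int main(void) {
        switch (vec_len_bits()) {
            case 128: puts("use the 16-byte-per-iteration kernel"); break;
            case 256:
            case 512: puts("use the 32-byte kernel predicated with SV_VL32"); break;
            default:  puts("unsupported vector length"); break;
        }
        return 0;
    }
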
@@ -5112,6 +5778,94 @@ void wsp_ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void
|
|
|
5112
5778
|
|
|
5113
5779
|
*s = hsum_float_8(acc);
|
|
5114
5780
|
|
|
5781
|
+
#elif defined __wasm_simd128__
|
|
5782
|
+
int8_t aux8[QK_K];
|
|
5783
|
+
float sums[8] = {0};
|
|
5784
|
+
uint32_t auxs[4];
|
|
5785
|
+
|
|
5786
|
+
float sumf = 0;
|
|
5787
|
+
for (int i = 0; i < nb; ++i) {
|
|
5788
|
+
const uint8_t * restrict q3 = x[i].qs;
|
|
5789
|
+
const uint8_t * restrict hm = x[i].hmask;
|
|
5790
|
+
const int8_t * restrict q8 = y[i].qs;
|
|
5791
|
+
|
|
5792
|
+
// Process blocks with SIMD
|
|
5793
|
+
int8_t * a = aux8;
|
|
5794
|
+
uint8_t m = 1;
|
|
5795
|
+
for (int j = 0; j < QK_K; j += 128) {
|
|
5796
|
+
for (int shift = 0; shift <= 6; shift += 2) {
|
|
5797
|
+
v128_t v_m = wasm_i8x16_splat(m);
|
|
5798
|
+
for (int l = 0; l < 32; l += 16) {
|
|
5799
|
+
v128_t v_q3 = wasm_v128_load(q3 + l);
|
|
5800
|
+
v128_t v_shift = wasm_i8x16_shr(v_q3, shift);
|
|
5801
|
+
v128_t v_low2 = wasm_v128_and(v_shift, wasm_i8x16_splat(0x03));
|
|
5802
|
+
|
|
5803
|
+
v128_t v_hm = wasm_v128_load(hm + l);
|
|
5804
|
+
v128_t v_mask = wasm_v128_and(v_hm, v_m);
|
|
5805
|
+
v_mask = wasm_i8x16_ne(v_mask, wasm_i8x16_splat(0));
|
|
5806
|
+
|
|
5807
|
+
v_low2 = wasm_i8x16_sub(v_low2, wasm_v128_and(wasm_i8x16_splat(4), wasm_v128_not(v_mask)));
|
|
5808
|
+
wasm_v128_store(a + l, v_low2);
|
|
5809
|
+
}
|
|
5810
|
+
a += 32;
|
|
5811
|
+
m <<= 1;
|
|
5812
|
+
}
|
|
5813
|
+
q3 += 32;
|
|
5814
|
+
}
|
|
5815
|
+
|
|
5816
|
+
// Extract scales
|
|
5817
|
+
memcpy(auxs, x[i].scales, 12);
|
|
5818
|
+
uint32_t tmp = auxs[2];
|
|
5819
|
+
auxs[2] = ((auxs[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);
|
|
5820
|
+
auxs[3] = ((auxs[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4);
|
|
5821
|
+
auxs[0] = (auxs[0] & kmask2) | (((tmp >> 0) & kmask1) << 4);
|
|
5822
|
+
auxs[1] = (auxs[1] & kmask2) | (((tmp >> 2) & kmask1) << 4);
|
|
5823
|
+
const int8_t * scales = (const int8_t *)auxs;
|
|
5824
|
+
|
|
5825
|
+
// SIMD dot product with register accumulators
|
|
5826
|
+
v128_t v_acc0 = wasm_i32x4_splat(0);
|
|
5827
|
+
v128_t v_acc1 = wasm_i32x4_splat(0);
|
|
5828
|
+
a = aux8;
|
|
5829
|
+
for (int j = 0; j < QK_K/16; ++j) {
|
|
5830
|
+
const v128_t v_scale = wasm_i16x8_splat(scales[j] - 32);
|
|
5831
|
+
|
|
5832
|
+
// Process 16 elements per iteration
|
|
5833
|
+
for (int k = 0; k < 2; ++k) {
|
|
5834
|
+
const v128_t v_q8 = wasm_i16x8_load8x8(q8);
|
|
5835
|
+
const v128_t v_a = wasm_i16x8_load8x8(a);
|
|
5836
|
+
|
|
5837
|
+
v128_t v_prod = wasm_i16x8_mul(v_q8, v_a);
|
|
5838
|
+
v_prod = wasm_i16x8_mul(v_prod, v_scale);
|
|
5839
|
+
|
|
5840
|
+
v_acc0 = wasm_i32x4_add(v_acc0, wasm_i32x4_extend_low_i16x8(v_prod));
|
|
5841
|
+
v_acc1 = wasm_i32x4_add(v_acc1, wasm_i32x4_extend_high_i16x8(v_prod));
|
|
5842
|
+
|
|
5843
|
+
q8 += 8;
|
|
5844
|
+
a += 8;
|
|
5845
|
+
}
|
|
5846
|
+
}
|
|
5847
|
+
|
|
5848
|
+
// Accumulate results
|
|
5849
|
+
const float d = WSP_GGML_FP16_TO_FP32(x[i].d) * y[i].d;
|
|
5850
|
+
const v128_t v_d = wasm_f32x4_splat(d);
|
|
5851
|
+
v128_t v_sum = wasm_f32x4_add(
|
|
5852
|
+
wasm_f32x4_mul(wasm_f32x4_convert_i32x4(v_acc0), v_d),
|
|
5853
|
+
wasm_f32x4_mul(wasm_f32x4_convert_i32x4(v_acc1), v_d)
|
|
5854
|
+
);
|
|
5855
|
+
|
|
5856
|
+
// Accumulate into sums vector
|
|
5857
|
+
wasm_v128_store(sums, wasm_f32x4_add(wasm_v128_load(sums), v_sum));
|
|
5858
|
+
}
|
|
5859
|
+
|
|
5860
|
+
// Horizontal sum
|
|
5861
|
+
v128_t v_sum = wasm_f32x4_add(wasm_v128_load(sums), wasm_v128_load(sums + 4));
|
|
5862
|
+
sumf = wasm_f32x4_extract_lane(v_sum, 0) +
|
|
5863
|
+
wasm_f32x4_extract_lane(v_sum, 1) +
|
|
5864
|
+
wasm_f32x4_extract_lane(v_sum, 2) +
|
|
5865
|
+
wasm_f32x4_extract_lane(v_sum, 3);
|
|
5866
|
+
|
|
5867
|
+
*s = sumf;
|
|
5868
|
+
|
|
5115
5869
|
#elif defined __riscv_v_intrinsic
|
|
5116
5870
|
|
|
5117
5871
|
uint32_t aux[3];
|
|
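
Note: the __wasm_simd128__ q3_K path above rebuilds each signed 3-bit weight from two low bits in qs plus one hmask bit, subtracting 4 whenever that bit is clear. A standalone check that this matches the direct (high << 2 | low2) - 4 formula (values are illustrative):

    #include <stdint.h>
    #include <stdio.h>

    int main(void) {
        const uint8_t low2 = 0x2;                         /* two low bits of one weight */
        for (int high = 0; high <= 1; ++high) {
            const int8_t v     = (int8_t)(((high << 2) | low2) - 4);
            const int8_t v_alt = (int8_t)(low2 - (high ? 0 : 4));  /* what the SIMD path computes */
            printf("hmask bit %d -> %d (same as %d)\n", high, v, v_alt);
        }
        return 0;
    }
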
@@ -5367,8 +6121,6 @@ void wsp_ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void
|
|
|
5367
6121
|
|
|
5368
6122
|
#elif defined __loongarch_asx
|
|
5369
6123
|
|
|
5370
|
-
const __m256i m3 = __lasx_xvreplgr2vr_b(3);
|
|
5371
|
-
const __m256i mone = __lasx_xvreplgr2vr_b(1);
|
|
5372
6124
|
const __m128i m32 = __lsx_vreplgr2vr_b(32);
|
|
5373
6125
|
|
|
5374
6126
|
__m256 acc = (__m256)__lasx_xvldi(0);
|
|
@@ -5388,10 +6140,9 @@ void wsp_ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void
|
|
|
5388
6140
|
(aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4),
|
|
5389
6141
|
(aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4));
|
|
5390
6142
|
scales128 = __lsx_vsub_b(scales128, m32);
|
|
5391
|
-
|
|
5392
|
-
const
|
|
5393
|
-
const
|
|
5394
|
-
const __m256i scales[2] = {lasx_insertf128(l_scales, l_scales), lasx_insertf128(h_scales, h_scales)};
|
|
6143
|
+
|
|
6144
|
+
const v16i8 shuffle_mask = {0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15};
|
|
6145
|
+
const __m256i scales_shuffled = lasx_ext8_16(__lsx_vshuf_b(scales128, scales128, (__m128i)shuffle_mask));
|
|
5395
6146
|
|
|
5396
6147
|
// high bit
|
|
5397
6148
|
const __m256i hbits = __lasx_xvld((const __m256i*)x[i].hmask, 0);
|
|
@@ -5399,35 +6150,23 @@ void wsp_ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void
|
|
|
5399
6150
|
// integer accumulator
|
|
5400
6151
|
__m256i sumi = __lasx_xvldi(0);
|
|
5401
6152
|
|
|
5402
|
-
int bit = 0;
|
|
5403
|
-
int is = 0;
|
|
5404
|
-
__m256i xvbit;
|
|
5405
|
-
|
|
5406
|
-
|
|
5407
6153
|
for (int j = 0; j < QK_K/128; ++j) {
|
|
5408
6154
|
// load low 2 bits
|
|
5409
6155
|
const __m256i q3bits = __lasx_xvld((const __m256i*)q3, 0); q3 += 32;
|
|
5410
6156
|
|
|
5411
|
-
xvbit = __lasx_xvreplgr2vr_h(bit);
|
|
5412
6157
|
// prepare low and high bits
|
|
5413
|
-
const __m256i q3l_0 =
|
|
5414
|
-
const __m256i
|
|
5415
|
-
|
|
5416
|
-
|
|
5417
|
-
|
|
5418
|
-
const __m256i
|
|
5419
|
-
const __m256i
|
|
5420
|
-
|
|
5421
|
-
|
|
5422
|
-
|
|
5423
|
-
const __m256i
|
|
5424
|
-
const __m256i
|
|
5425
|
-
++bit;
|
|
5426
|
-
|
|
5427
|
-
xvbit = __lasx_xvreplgr2vr_h(bit);
|
|
5428
|
-
const __m256i q3l_3 = __lasx_xvand_v(__lasx_xvsrli_h(q3bits, 6), m3);
|
|
5429
|
-
const __m256i q3h_3 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvandn_v(hbits, __lasx_xvsll_h(mone, xvbit)), xvbit), 2);
|
|
5430
|
-
++bit;
|
|
6158
|
+
const __m256i q3l_0 = __lasx_xvandi_b(q3bits, 3);
|
|
6159
|
+
const __m256i q3l_1 = __lasx_xvandi_b(__lasx_xvsrli_b(q3bits, 2), 3);
|
|
6160
|
+
const __m256i q3l_2 = __lasx_xvandi_b(__lasx_xvsrli_b(q3bits, 4), 3);
|
|
6161
|
+
const __m256i q3l_3 = __lasx_xvsrli_b(q3bits, 6);
|
|
6162
|
+
const __m256i q3h_0 = __lasx_xvslli_b(__lasx_xvseqi_b(lasx_xvandi_b_bit(hbits, 4 * j + 0), 0), 2);
|
|
6163
|
+
const __m256i q3h_1 = __lasx_xvslli_b(__lasx_xvseqi_b(lasx_xvandi_b_bit(hbits, 4 * j + 1), 0), 2);
|
|
6164
|
+
const __m256i q3h_2 = __lasx_xvslli_b(__lasx_xvseqi_b(lasx_xvandi_b_bit(hbits, 4 * j + 2), 0), 2);
|
|
6165
|
+
const __m256i q3h_3 = __lasx_xvslli_b(__lasx_xvseqi_b(lasx_xvandi_b_bit(hbits, 4 * j + 3), 0), 2);
|
|
6166
|
+
const __m256i q3_0 = __lasx_xvor_v(q3h_0, q3l_0);
|
|
6167
|
+
const __m256i q3_1 = __lasx_xvor_v(q3h_1, q3l_1);
|
|
6168
|
+
const __m256i q3_2 = __lasx_xvor_v(q3h_2, q3l_2);
|
|
6169
|
+
const __m256i q3_3 = __lasx_xvor_v(q3h_3, q3l_3);
|
|
5431
6170
|
|
|
5432
6171
|
// load Q8 quants
|
|
5433
6172
|
const __m256i q8_0 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
|
|
@@ -5435,29 +6174,16 @@ void wsp_ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void
|
|
|
5435
6174
|
const __m256i q8_2 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
|
|
5436
6175
|
const __m256i q8_3 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
|
|
5437
6176
|
|
|
5438
|
-
|
|
5439
|
-
|
|
5440
|
-
|
|
5441
|
-
__m256i
|
|
5442
|
-
__m256i q8s_1 = lasx_maddubs_h(q3h_1, q8_1);
|
|
5443
|
-
__m256i q8s_2 = lasx_maddubs_h(q3h_2, q8_2);
|
|
5444
|
-
__m256i q8s_3 = lasx_maddubs_h(q3h_3, q8_3);
|
|
5445
|
-
|
|
5446
|
-
__m256i p16_0 = lasx_maddubs_h(q3l_0, q8_0);
|
|
5447
|
-
__m256i p16_1 = lasx_maddubs_h(q3l_1, q8_1);
|
|
5448
|
-
__m256i p16_2 = lasx_maddubs_h(q3l_2, q8_2);
|
|
5449
|
-
__m256i p16_3 = lasx_maddubs_h(q3l_3, q8_3);
|
|
5450
|
-
|
|
5451
|
-
p16_0 = __lasx_xvsub_h(p16_0, q8s_0);
|
|
5452
|
-
p16_1 = __lasx_xvsub_h(p16_1, q8s_1);
|
|
5453
|
-
p16_2 = __lasx_xvsub_h(p16_2, q8s_2);
|
|
5454
|
-
p16_3 = __lasx_xvsub_h(p16_3, q8s_3);
|
|
6177
|
+
__m256i p16_0 = lasx_madd_h_b(q8_0, q3_0);
|
|
6178
|
+
__m256i p16_1 = lasx_madd_h_b(q8_1, q3_1);
|
|
6179
|
+
__m256i p16_2 = lasx_madd_h_b(q8_2, q3_2);
|
|
6180
|
+
__m256i p16_3 = lasx_madd_h_b(q8_3, q3_3);
|
|
5455
6181
|
|
|
5456
6182
|
// multiply with scales
|
|
5457
|
-
p16_0 = lasx_madd_h(
|
|
5458
|
-
p16_1 = lasx_madd_h(
|
|
5459
|
-
p16_2 = lasx_madd_h(
|
|
5460
|
-
p16_3 = lasx_madd_h(
|
|
6183
|
+
p16_0 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 0), p16_0);
|
|
6184
|
+
p16_1 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 1), p16_1);
|
|
6185
|
+
p16_2 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 2), p16_2);
|
|
6186
|
+
p16_3 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 3), p16_3);
|
|
5461
6187
|
|
|
5462
6188
|
// accumulate
|
|
5463
6189
|
p16_0 = __lasx_xvadd_w(p16_0, p16_1);
|
|
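
Note: the hunk above drops the old unsigned-times-signed multiply (lasx_maddubs_h) plus its separate q8s correction and instead feeds the reconstructed signed byte values straight into lasx_madd_h_b. A scalar sketch of the pairwise signed multiply-add each resulting 16-bit lane holds (inputs are made up):

    #include <stdint.h>
    #include <stdio.h>

    int main(void) {
        const int8_t a[4] = {-3, 2,  5, -1};
        const int8_t b[4] = { 4, 7, -2,  6};
        for (int k = 0; k < 2; ++k) {
            /* each 16-bit lane = a[2k]*b[2k] + a[2k+1]*b[2k+1] */
            const int16_t lane = (int16_t)(a[2*k] * b[2*k] + a[2*k+1] * b[2*k+1]);
            printf("lane %d = %d\n", k, lane);
        }
        return 0;
    }
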
@@ -5465,7 +6191,7 @@ void wsp_ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void
|
|
|
5465
6191
|
sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p16_0, p16_2));
|
|
5466
6192
|
}
|
|
5467
6193
|
// multiply with block scale and accumulate
|
|
5468
|
-
acc = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(sumi), acc)
|
|
6194
|
+
acc = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(sumi), acc);
|
|
5469
6195
|
}
|
|
5470
6196
|
|
|
5471
6197
|
*s = hsum_float_8(acc);
|
|
@@ -5556,7 +6282,88 @@ void wsp_ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void
|
|
|
5556
6282
|
|
|
5557
6283
|
uint32_t utmp[4];
|
|
5558
6284
|
|
|
5559
|
-
#ifdef
|
|
6285
|
+
#ifdef __ARM_FEATURE_SVE
|
|
6286
|
+
float sumf = 0;
|
|
6287
|
+
for (int i = 0; i < nb; ++i) {
|
|
6288
|
+
|
|
6289
|
+
const float d = y[i].d * WSP_GGML_FP16_TO_FP32(x[i].d);
|
|
6290
|
+
const float dmin = y[i].d * WSP_GGML_FP16_TO_FP32(x[i].dmin);
|
|
6291
|
+
|
|
6292
|
+
const int16x8_t q8sums = vpaddq_s16(vld1q_s16(y[i].bsums), vld1q_s16(y[i].bsums + 8));
|
|
6293
|
+
|
|
6294
|
+
memcpy(utmp, x[i].scales, K_SCALE_SIZE);
|
|
6295
|
+
|
|
6296
|
+
uint32x2_t mins8 = { 0 };
|
|
6297
|
+
mins8 = vset_lane_u32(utmp[1] & kmask1, mins8, 0);
|
|
6298
|
+
mins8 = vset_lane_u32(((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4), mins8, 1);
|
|
6299
|
+
|
|
6300
|
+
utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
|
|
6301
|
+
utmp[0] &= kmask1;
|
|
6302
|
+
|
|
6303
|
+
const int16x8_t mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins8)));
|
|
6304
|
+
const int32x4_t prod = vaddq_s32(vmull_s16(vget_low_s16 (q8sums), vget_low_s16 (mins)),
|
|
6305
|
+
vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins)));
|
|
6306
|
+
sumf -= dmin * vaddvq_s32(prod);
|
|
6307
|
+
|
|
6308
|
+
const uint8_t * scales = (const uint8_t *)utmp;
|
|
6309
|
+
|
|
6310
|
+
const uint8_t * restrict q4 = x[i].qs;
|
|
6311
|
+
const int8_t * restrict q8 = y[i].qs;
|
|
6312
|
+
|
|
6313
|
+
const int vector_length = wsp_ggml_cpu_get_sve_cnt()*8;
|
|
6314
|
+
const svuint8_t m4b = svdup_n_u8(0xf);
|
|
6315
|
+
const svint32_t mzero = svdup_n_s32(0);
|
|
6316
|
+
svint32_t sumi1 = svdup_n_s32(0);
|
|
6317
|
+
svint32_t sumi1_1 = svdup_n_s32(0);
|
|
6318
|
+
svint32_t sumi1_2 = svdup_n_s32(0);
|
|
6319
|
+
svint32_t sumi2 = svdup_n_s32(0);
|
|
6320
|
+
svint32_t sumi2_1 = svdup_n_s32(0);
|
|
6321
|
+
svint32_t sumi2_2 = svdup_n_s32(0);
|
|
6322
|
+
switch (vector_length) {
|
|
6323
|
+
case 128:
|
|
6324
|
+
{
|
|
6325
|
+
for (int j = 0; j < QK_K/64; ++j) {
|
|
6326
|
+
svint8_t q4bytes = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4), m4b));
|
|
6327
|
+
svint8_t q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16;
|
|
6328
|
+
sumi1_1 = svmla_n_s32_x(svptrue_b32(), sumi1_1, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+0]);
|
|
6329
|
+
q4bytes = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4+16), m4b));
|
|
6330
|
+
q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16;
|
|
6331
|
+
sumi1_2 = svmla_n_s32_x(svptrue_b32(), sumi1_2, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+0]);
|
|
6332
|
+
|
|
6333
|
+
q4bytes = svreinterpret_s8_u8(svlsr_n_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4), 4));
|
|
6334
|
+
q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16;
|
|
6335
|
+
sumi2_1 = svmla_n_s32_x(svptrue_b32(), sumi2_1, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+1]);
|
|
6336
|
+
q4bytes = svreinterpret_s8_u8(svlsr_n_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4+16), 4));
|
|
6337
|
+
q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16;
|
|
6338
|
+
sumi2_2 = svmla_n_s32_x(svptrue_b32(), sumi2_2, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+1]);
|
|
6339
|
+
q4 += 32;
|
|
6340
|
+
}
|
|
6341
|
+
sumi1 = svadd_s32_x(svptrue_b32(), sumi1_1, sumi1_2);
|
|
6342
|
+
sumi2 = svadd_s32_x(svptrue_b32(), sumi2_1, sumi2_2);
|
|
6343
|
+
sumf += d * (svaddv_s32(svptrue_b32(), svadd_s32_x(svptrue_b32(), sumi1, sumi2)));
|
|
6344
|
+
} break;
|
|
6345
|
+
case 256:
|
|
6346
|
+
case 512:
|
|
6347
|
+
{
|
|
6348
|
+
for (int j = 0; j < QK_K/64; ++j) {
|
|
6349
|
+
const svuint8_t q4bits = svld1_u8(svptrue_pat_b8(SV_VL32), q4); q4 += 32;
|
|
6350
|
+
svint8_t q4bytes = svreinterpret_s8_u8(svand_u8_x(svptrue_pat_b8(SV_VL32), q4bits, m4b));
|
|
6351
|
+
svint8_t q8bytes = svld1_s8(svptrue_pat_b8(SV_VL32), q8); q8 += 32;
|
|
6352
|
+
sumi1 = svmla_n_s32_x(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+0]);
|
|
6353
|
+
|
|
6354
|
+
q4bytes = svreinterpret_s8_u8(svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q4bits, 4));
|
|
6355
|
+
q8bytes = svld1_s8(svptrue_pat_b8(SV_VL32), q8); q8 += 32;
|
|
6356
|
+
sumi2 = svmla_n_s32_x(svptrue_pat_b32(SV_VL8), sumi2, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+1]);
|
|
6357
|
+
}
|
|
6358
|
+
sumf += d * (svaddv_s32(svptrue_pat_b32(SV_VL8), svadd_s32_x(svptrue_pat_b32(SV_VL8), sumi1, sumi2)));
|
|
6359
|
+
} break;
|
|
6360
|
+
default:
|
|
6361
|
+
assert(false && "Unsupported vector length");
|
|
6362
|
+
break;
|
|
6363
|
+
}
|
|
6364
|
+
}
|
|
6365
|
+
*s = sumf;
|
|
6366
|
+
#elif defined __ARM_NEON
|
|
5560
6367
|
const uint8x16_t m4b = vdupq_n_u8(0xf);
|
|
5561
6368
|
const int32x4_t mzero = vdupq_n_s32(0);
|
|
5562
6369
|
|
|
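
Note: the new SVE q4_K kernel above splits every byte of q4 into its low and high nibble (svand with 0xF, svlsr by 4) and accumulates each half against its own sub-block scale. Per byte, that unpacking is simply the following (input value is made up):

    #include <stdint.h>
    #include <stdio.h>

    int main(void) {
        /* One q4_K byte packs two 4-bit weights: the low nibble belongs to the first
           half of the sub-block, the high nibble to the second half. */
        const uint8_t b  = 0x9C;
        const uint8_t lo = b & 0x0F;
        const uint8_t hi = b >> 4;
        printf("packed 0x%02X -> lo=%d hi=%d\n", (unsigned)b, lo, hi);
        return 0;
    }
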
@@ -5595,26 +6402,127 @@ void wsp_ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void
|
|
|
5595
6402
|
int32_t sumi2 = 0;
|
|
5596
6403
|
|
|
5597
6404
|
for (int j = 0; j < QK_K/64; ++j) {
|
|
5598
|
-
const wsp_ggml_uint8x16x2_t q4bits = wsp_ggml_vld1q_u8_x2(q4); q4 += 32;
|
|
5599
|
-
|
|
5600
|
-
q8bytes = wsp_ggml_vld1q_s8_x2(q8); q8 += 32;
|
|
5601
|
-
q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b));
|
|
5602
|
-
q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b));
|
|
5603
|
-
|
|
5604
|
-
const int32x4_t p1 = wsp_ggml_vdotq_s32(wsp_ggml_vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
|
|
5605
|
-
sumi1 += vaddvq_s32(p1) * scales[2*j+0];
|
|
5606
|
-
|
|
5607
|
-
q8bytes = wsp_ggml_vld1q_s8_x2(q8); q8 += 32;
|
|
5608
|
-
q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4));
|
|
5609
|
-
q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4));
|
|
6405
|
+
const wsp_ggml_uint8x16x2_t q4bits = wsp_ggml_vld1q_u8_x2(q4); q4 += 32;
|
|
6406
|
+
|
|
6407
|
+
q8bytes = wsp_ggml_vld1q_s8_x2(q8); q8 += 32;
|
|
6408
|
+
q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b));
|
|
6409
|
+
q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b));
|
|
6410
|
+
|
|
6411
|
+
const int32x4_t p1 = wsp_ggml_vdotq_s32(wsp_ggml_vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
|
|
6412
|
+
sumi1 += vaddvq_s32(p1) * scales[2*j+0];
|
|
6413
|
+
|
|
6414
|
+
q8bytes = wsp_ggml_vld1q_s8_x2(q8); q8 += 32;
|
|
6415
|
+
q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4));
|
|
6416
|
+
q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4));
|
|
6417
|
+
|
|
6418
|
+
const int32x4_t p2 = wsp_ggml_vdotq_s32(wsp_ggml_vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
|
|
6419
|
+
|
|
6420
|
+
sumi2 += vaddvq_s32(p2) * scales[2*j+1];
|
|
6421
|
+
}
|
|
6422
|
+
|
|
6423
|
+
sumf += d * (sumi1 + sumi2);
|
|
6424
|
+
|
|
6425
|
+
}
|
|
6426
|
+
|
|
6427
|
+
*s = sumf;
|
|
6428
|
+
|
|
6429
|
+
#elif defined __wasm_simd128__
|
|
6430
|
+
const uint8_t * scales = (const uint8_t*)&utmp[0];
|
|
6431
|
+
float sumf = 0;
|
|
6432
|
+
|
|
6433
|
+
for (int i = 0; i < nb; ++i) {
|
|
6434
|
+
const float d = y[i].d * WSP_GGML_FP16_TO_FP32(x[i].d);
|
|
6435
|
+
const float dmin = y[i].d * WSP_GGML_FP16_TO_FP32(x[i].dmin); // Corrected sign
|
|
6436
|
+
|
|
6437
|
+
const uint8_t * restrict q4 = x[i].qs;
|
|
6438
|
+
const int8_t * restrict q8 = y[i].qs;
|
|
6439
|
+
|
|
6440
|
+
// Process scales and mins
|
|
6441
|
+
memcpy(utmp, x[i].scales, 12);
|
|
6442
|
+
utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
|
|
6443
|
+
const uint32_t uaux = utmp[1] & kmask1;
|
|
6444
|
+
utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
|
|
6445
|
+
utmp[2] = uaux;
|
|
6446
|
+
utmp[0] &= kmask1;
|
|
6447
|
+
|
|
6448
|
+
// Sum mins * q8sums
|
|
6449
|
+
int32_t sumi = 0;
|
|
6450
|
+
const int16_t * restrict q8sums = y[i].bsums;
|
|
6451
|
+
const uint8_t * m = (const uint8_t *)&utmp[2];
|
|
6452
|
+
for (int j = 0; j < 16; j += 2) {
|
|
6453
|
+
sumi += (q8sums[j] + q8sums[j+1]) * m[j/2];
|
|
6454
|
+
}
|
|
6455
|
+
sumf -= dmin * sumi;
|
|
6456
|
+
|
|
6457
|
+
int32_t sumi1 = 0;
|
|
6458
|
+
int32_t sumi2 = 0;
|
|
6459
|
+
|
|
6460
|
+
for (int j = 0; j < QK_K/64; ++j) {
|
|
6461
|
+
// Load 64 4-bit weights (32 bytes)
|
|
6462
|
+
const v128_t q4x0 = wasm_v128_load(q4);
|
|
6463
|
+
const v128_t q4x1 = wasm_v128_load(q4 + 16);
|
|
6464
|
+
q4 += 32;
|
|
5610
6465
|
|
|
5611
|
-
|
|
6466
|
+
// Split into low/high nibbles
|
|
6467
|
+
const v128_t q4l0 = wasm_v128_and(q4x0, wasm_i8x16_splat(0x0F));
|
|
6468
|
+
const v128_t q4h0 = wasm_u8x16_shr(q4x0, 4);
|
|
6469
|
+
const v128_t q4l1 = wasm_v128_and(q4x1, wasm_i8x16_splat(0x0F));
|
|
6470
|
+
const v128_t q4h1 = wasm_u8x16_shr(q4x1, 4);
|
|
6471
|
+
|
|
6472
|
+
// Load 64 8-bit values (64 bytes)
|
|
6473
|
+
const v128_t q8x0 = wasm_v128_load(q8);
|
|
6474
|
+
const v128_t q8x1 = wasm_v128_load(q8 + 16);
|
|
6475
|
+
const v128_t q8x2 = wasm_v128_load(q8 + 32);
|
|
6476
|
+
const v128_t q8x3 = wasm_v128_load(q8 + 48);
|
|
6477
|
+
q8 += 64;
|
|
5612
6478
|
|
|
5613
|
-
|
|
6479
|
+
// Low nibble products
|
|
6480
|
+
v128_t vacc1 = wasm_i32x4_dot_i16x8(
|
|
6481
|
+
wasm_i16x8_extend_low_i8x16(q4l0),
|
|
6482
|
+
wasm_i16x8_extend_low_i8x16(q8x0)
|
|
6483
|
+
);
|
|
6484
|
+
vacc1 = wasm_i32x4_add(vacc1, wasm_i32x4_dot_i16x8(
|
|
6485
|
+
wasm_i16x8_extend_high_i8x16(q4l0),
|
|
6486
|
+
wasm_i16x8_extend_high_i8x16(q8x0)
|
|
6487
|
+
));
|
|
6488
|
+
vacc1 = wasm_i32x4_add(vacc1, wasm_i32x4_dot_i16x8(
|
|
6489
|
+
wasm_i16x8_extend_low_i8x16(q4l1),
|
|
6490
|
+
wasm_i16x8_extend_low_i8x16(q8x1)
|
|
6491
|
+
));
|
|
6492
|
+
vacc1 = wasm_i32x4_add(vacc1, wasm_i32x4_dot_i16x8(
|
|
6493
|
+
wasm_i16x8_extend_high_i8x16(q4l1),
|
|
6494
|
+
wasm_i16x8_extend_high_i8x16(q8x1)
|
|
6495
|
+
));
|
|
6496
|
+
|
|
6497
|
+
// High nibble products
|
|
6498
|
+
v128_t vacc2 = wasm_i32x4_dot_i16x8(
|
|
6499
|
+
wasm_i16x8_extend_low_i8x16(q4h0),
|
|
6500
|
+
wasm_i16x8_extend_low_i8x16(q8x2)
|
|
6501
|
+
);
|
|
6502
|
+
vacc2 = wasm_i32x4_add(vacc2, wasm_i32x4_dot_i16x8(
|
|
6503
|
+
wasm_i16x8_extend_high_i8x16(q4h0),
|
|
6504
|
+
wasm_i16x8_extend_high_i8x16(q8x2)
|
|
6505
|
+
));
|
|
6506
|
+
vacc2 = wasm_i32x4_add(vacc2, wasm_i32x4_dot_i16x8(
|
|
6507
|
+
wasm_i16x8_extend_low_i8x16(q4h1),
|
|
6508
|
+
wasm_i16x8_extend_low_i8x16(q8x3)
|
|
6509
|
+
));
|
|
6510
|
+
vacc2 = wasm_i32x4_add(vacc2, wasm_i32x4_dot_i16x8(
|
|
6511
|
+
wasm_i16x8_extend_high_i8x16(q4h1),
|
|
6512
|
+
wasm_i16x8_extend_high_i8x16(q8x3)
|
|
6513
|
+
));
|
|
6514
|
+
|
|
6515
|
+
// Accumulate scaled results
|
|
6516
|
+
int32_t vacc1_sum = wasm_i32x4_extract_lane(vacc1, 0) + wasm_i32x4_extract_lane(vacc1, 1) +
|
|
6517
|
+
wasm_i32x4_extract_lane(vacc1, 2) + wasm_i32x4_extract_lane(vacc1, 3);
|
|
6518
|
+
sumi1 += vacc1_sum * scales[2*j];
|
|
6519
|
+
|
|
6520
|
+
int32_t vacc2_sum = wasm_i32x4_extract_lane(vacc2, 0) + wasm_i32x4_extract_lane(vacc2, 1) +
|
|
6521
|
+
wasm_i32x4_extract_lane(vacc2, 2) + wasm_i32x4_extract_lane(vacc2, 3);
|
|
6522
|
+
sumi2 += vacc2_sum * scales[2*j+1];
|
|
5614
6523
|
}
|
|
5615
6524
|
|
|
5616
6525
|
sumf += d * (sumi1 + sumi2);
|
|
5617
|
-
|
|
5618
6526
|
}
|
|
5619
6527
|
|
|
5620
6528
|
*s = sumf;
|
|
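
Note: both the SVE header and the new WASM q4_K path above start by folding the per-sub-block minimums into the result: the block sum is reduced by dmin times the dot product of the q8 column sums (bsums) with the unpacked mins. A standalone scalar version of that correction, mirroring the (q8sums[j] + q8sums[j+1]) * m[j/2] loop in the diff (numbers are made up):

    #include <stdint.h>
    #include <stdio.h>

    int main(void) {
        const int16_t bsums[16] = {10, -3, 7, 0, 5, 5, -2, 8, 1, 1, 4, -6, 3, 2, 0, 9};
        const uint8_t mins[8]   = {7, 3, 12, 5, 9, 1, 4, 6};
        const float   dmin      = 0.25f;
        int32_t acc = 0;
        for (int j = 0; j < 16; j += 2)
            acc += (bsums[j] + bsums[j + 1]) * mins[j / 2];   /* pair of bsums per min */
        printf("correction = %g\n", (double)(dmin * acc));    /* subtracted from the block sum */
        return 0;
    }
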
@@ -5976,11 +6884,6 @@ void wsp_ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void
|
|
|
5976
6884
|
*s = vec_extract(vsumf0, 0);
|
|
5977
6885
|
|
|
5978
6886
|
#elif defined __loongarch_asx
|
|
5979
|
-
WSP_GGML_UNUSED(kmask1);
|
|
5980
|
-
WSP_GGML_UNUSED(kmask2);
|
|
5981
|
-
WSP_GGML_UNUSED(kmask3);
|
|
5982
|
-
|
|
5983
|
-
const __m256i m4 = __lasx_xvreplgr2vr_b(0xF);
|
|
5984
6887
|
|
|
5985
6888
|
__m256 acc = (__m256)__lasx_xvldi(0);
|
|
5986
6889
|
__m128 acc_m = (__m128)__lsx_vldi(0);
|
|
@@ -6000,33 +6903,34 @@ void wsp_ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void
|
|
|
6000
6903
|
const uint8_t * restrict q4 = x[i].qs;
|
|
6001
6904
|
const int8_t * restrict q8 = y[i].qs;
|
|
6002
6905
|
|
|
6003
|
-
const
|
|
6906
|
+
const __m128i mins_and_scales128 = lsx_set_w(utmp[3], utmp[2], utmp[1], utmp[0]);
|
|
6907
|
+
const __m128i mins128 = __lsx_vexth_h_b(mins_and_scales128);
|
|
6908
|
+
const __m128i scales128 = __lsx_vsllwil_h_b(mins_and_scales128, 0);
|
|
6004
6909
|
|
|
6005
6910
|
const __m256i q8sums = __lasx_xvld((const __m256i*)y[i].bsums, 0);
|
|
6006
6911
|
const __m128i q8s = lsx_hadd_h(lasx_extracti128(q8sums, 0), lasx_extracti128(q8sums, 1));
|
|
6007
|
-
const __m128i prod = lsx_madd_h(
|
|
6912
|
+
const __m128i prod = lsx_madd_h(mins128, q8s);
|
|
6008
6913
|
acc_m = __lsx_vfmadd_s(__lsx_vreplfr2vr_s(dmin), __lsx_vffint_s_w(prod), acc_m);
|
|
6009
6914
|
|
|
6010
|
-
const
|
|
6011
|
-
const __m256i scales = lasx_insertf128(sc128, sc128);
|
|
6915
|
+
const __m256i scales = lasx_insertf128(scales128, scales128);
|
|
6012
6916
|
|
|
6013
6917
|
__m256i sumi = __lasx_xvldi(0);
|
|
6014
6918
|
|
|
6015
6919
|
for (int j = 0; j < QK_K/64; ++j) {
|
|
6016
6920
|
|
|
6017
|
-
const __m256i scale_l =
|
|
6018
|
-
const __m256i scale_h =
|
|
6921
|
+
const __m256i scale_l = lasx_xvrepl128vei_h(scales, 2 * j + 0);
|
|
6922
|
+
const __m256i scale_h = lasx_xvrepl128vei_h(scales, 2 * j + 1);
|
|
6019
6923
|
|
|
6020
6924
|
const __m256i q4bits = __lasx_xvld((const __m256i*)q4, 0); q4 += 32;
|
|
6021
|
-
const __m256i q4l =
|
|
6022
|
-
const __m256i q4h =
|
|
6925
|
+
const __m256i q4l = __lasx_xvandi_b(q4bits, 0xf);
|
|
6926
|
+
const __m256i q4h = __lasx_xvsrli_b(q4bits, 4);
|
|
6023
6927
|
|
|
6024
6928
|
const __m256i q8l = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
|
|
6025
|
-
__m256i p16l =
|
|
6929
|
+
__m256i p16l = lasx_madd_h_b(q4l, q8l);
|
|
6026
6930
|
p16l = lasx_madd_h(scale_l, p16l);
|
|
6027
6931
|
|
|
6028
6932
|
const __m256i q8h = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
|
|
6029
|
-
__m256i p16h =
|
|
6933
|
+
__m256i p16h = lasx_madd_h_b(q4h, q8h);
|
|
6030
6934
|
p16h = lasx_madd_h(scale_h, p16h);
|
|
6031
6935
|
const __m256i sumj = __lasx_xvadd_w(p16l, p16h);
|
|
6032
6936
|
|
|
@@ -6043,9 +6947,78 @@ void wsp_ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void
|
|
|
6043
6947
|
acc_m = __lsx_vfadd_s(acc_m, (__m128)tmp1);
|
|
6044
6948
|
|
|
6045
6949
|
|
|
6046
|
-
|
|
6047
|
-
|
|
6048
|
-
|
|
6950
|
+
*s = hsum_float_8(acc) + ((v4f32)acc_m)[0];
|
|
6951
|
+
#elif defined(__VXE__) || defined(__VXE2__)
|
|
6952
|
+
const uint8x16_t v_lm = vec_splat_u8(0x0F);
|
|
6953
|
+
const int32x4_t v_z = vec_splat_s32(0);
|
|
6954
|
+
|
|
6955
|
+
uint8x16_t v_x[2];
|
|
6956
|
+
int8x16_t v_xl[2];
|
|
6957
|
+
int8x16_t v_y[2];
|
|
6958
|
+
|
|
6959
|
+
float sumf = 0;
|
|
6960
|
+
|
|
6961
|
+
for (int i = 0; i < nb; ++i) {
|
|
6962
|
+
const float d = y[i].d * WSP_GGML_FP16_TO_FP32(x[i].d);
|
|
6963
|
+
const float dmin = y[i].d * WSP_GGML_FP16_TO_FP32(x[i].dmin);
|
|
6964
|
+
|
|
6965
|
+
const int16x8_t v_ysumsl = vec_xl(0 , y[i].bsums);
|
|
6966
|
+
const int16x8_t v_ysumsh = vec_xl(16, y[i].bsums);
|
|
6967
|
+
const int16x8_t v_ysums = vec_padd_s16(v_ysumsl, v_ysumsh);
|
|
6968
|
+
|
|
6969
|
+
memcpy(utmp, x[i].scales, 12);
|
|
6970
|
+
|
|
6971
|
+
uint32x4_t v_mins8 = { 0 };
|
|
6972
|
+
v_mins8 = vec_insert(utmp[1] & kmask1, v_mins8, 0);
|
|
6973
|
+
v_mins8 = vec_insert(((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4), v_mins8, 1);
|
|
6974
|
+
|
|
6975
|
+
utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
|
|
6976
|
+
utmp[0] &= kmask1;
|
|
6977
|
+
|
|
6978
|
+
const int16x8_t v_minsh = (int16x8_t)vec_unpackh((uint8x16_t)v_mins8);
|
|
6979
|
+
|
|
6980
|
+
const int32x4_t v_minso = vec_mulo(v_ysums, v_minsh);
|
|
6981
|
+
const int32x4_t v_minse = vec_mule(v_ysums, v_minsh);
|
|
6982
|
+
const int32x4_t v_mins = v_minso + v_minse;
|
|
6983
|
+
sumf -= dmin * (v_mins[0] + v_mins[1] + v_mins[2] + v_mins[3]);
|
|
6984
|
+
|
|
6985
|
+
const uint8_t * scales = (const uint8_t *)utmp;
|
|
6986
|
+
const uint8_t * restrict x0 = x[i].qs;
|
|
6987
|
+
const int8_t * restrict y0 = y[i].qs;
|
|
6988
|
+
|
|
6989
|
+
int32_t sumi1 = 0;
|
|
6990
|
+
int32_t sumi2 = 0;
|
|
6991
|
+
|
|
6992
|
+
for (int j = 0; j < QK_K/64; ++j) {
|
|
6993
|
+
v_x[0] = vec_xl(0 , x0);
|
|
6994
|
+
v_x[1] = vec_xl(16, x0);
|
|
6995
|
+
x0 += 32;
|
|
6996
|
+
|
|
6997
|
+
v_y[0] = vec_xl(0 , y0);
|
|
6998
|
+
v_y[1] = vec_xl(16, y0);
|
|
6999
|
+
y0 += 32;
|
|
7000
|
+
|
|
7001
|
+
v_xl[0] = (int8x16_t)vec_and(v_x[0], v_lm);
|
|
7002
|
+
v_xl[1] = (int8x16_t)vec_and(v_x[1], v_lm);
|
|
7003
|
+
|
|
7004
|
+
const int32x4_t p1 = wsp_ggml_vec_dot(wsp_ggml_vec_dot(v_z, v_xl[0], v_y[0]), v_xl[1], v_y[1]);
|
|
7005
|
+
sumi1 += (p1[0] + p1[1] + p1[2] + p1[3]) * scales[2*j+0];
|
|
7006
|
+
|
|
7007
|
+
v_y[0] = vec_xl(0 , y0);
|
|
7008
|
+
v_y[1] = vec_xl(16, y0);
|
|
7009
|
+
y0 += 32;
|
|
7010
|
+
|
|
7011
|
+
v_xl[0] = (int8x16_t)vec_sr(v_x[0], 4);
|
|
7012
|
+
v_xl[1] = (int8x16_t)vec_sr(v_x[1], 4);
|
|
7013
|
+
|
|
7014
|
+
const int32x4_t p2 = wsp_ggml_vec_dot(wsp_ggml_vec_dot(v_z, v_xl[0], v_y[0]), v_xl[1], v_y[1]);
|
|
7015
|
+
sumi2 += (p2[0] + p2[1] + p2[2] + p2[3]) * scales[2*j+1];
|
|
7016
|
+
}
|
|
7017
|
+
|
|
7018
|
+
sumf += d * (sumi1 + sumi2);
|
|
7019
|
+
}
|
|
7020
|
+
|
|
7021
|
+
*s = sumf;
|
|
6049
7022
|
#else
|
|
6050
7023
|
|
|
6051
7024
|
const uint8_t * scales = (const uint8_t*)&utmp[0];
|
|
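
Note: the s390x (__VXE__) q4_K path above computes the same mins correction with vec_mulo/vec_mule, which multiply the odd- and even-indexed 16-bit lanes into separate 32-bit vectors that are then added. A scalar sketch of that even/odd split (inputs are made up):

    #include <stdint.h>
    #include <stdio.h>

    int main(void) {
        const int16_t ys[8]   = {12, -4, 9, 3, 0, 7, -2, 5};
        const int16_t mins[8] = { 3,  1, 8, 2, 6, 4,  9, 1};
        int32_t even = 0, odd = 0;
        for (int k = 0; k < 8; k += 2) {
            even += ys[k]     * mins[k];       /* vec_mule lanes */
            odd  += ys[k + 1] * mins[k + 1];   /* vec_mulo lanes */
        }
        printf("even=%d odd=%d total=%d\n", even, odd, even + odd);
        return 0;
    }
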
@@ -6371,6 +7344,118 @@ void wsp_ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void
|
|
|
6371
7344
|
|
|
6372
7345
|
*s = hsum_float_8(acc) + summs;
|
|
6373
7346
|
|
|
7347
|
+
#elif defined __wasm_simd128__
|
|
7348
|
+
//const uint8_t * scales = (const uint8_t*)&utmp[0];
|
|
7349
|
+
float sumf = 0;
|
|
7350
|
+
|
|
7351
|
+
for (int i = 0; i < nb; ++i) {
|
|
7352
|
+
const float d = y[i].d * WSP_GGML_FP16_TO_FP32(x[i].d);
|
|
7353
|
+
const float dmin = y[i].d * WSP_GGML_FP16_TO_FP32(x[i].dmin); // Fixed sign
|
|
7354
|
+
|
|
7355
|
+
const uint8_t * restrict q5 = x[i].qs;
|
|
7356
|
+
const uint8_t * restrict qh = x[i].qh;
|
|
7357
|
+
const int8_t * restrict q8 = y[i].qs;
|
|
7358
|
+
|
|
7359
|
+
// Process scales and mins
|
|
7360
|
+
memcpy(utmp, x[i].scales, 12);
|
|
7361
|
+
utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
|
|
7362
|
+
const uint32_t uaux = utmp[1] & kmask1;
|
|
7363
|
+
utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
|
|
7364
|
+
utmp[2] = uaux;
|
|
7365
|
+
utmp[0] &= kmask1;
|
|
7366
|
+
|
|
7367
|
+
// Sum mins * q8sums
|
|
7368
|
+
int32_t sumi_mins = 0;
|
|
7369
|
+
const int16_t * restrict q8sums = y[i].bsums;
|
|
7370
|
+
const uint8_t * m = (const uint8_t *)&utmp[2];
|
|
7371
|
+
for (int j = 0; j < 16; j += 2) {
|
|
7372
|
+
sumi_mins += (q8sums[j] + q8sums[j+1]) * m[j/2];
|
|
7373
|
+
}
|
|
7374
|
+
sumf -= dmin * sumi_mins; // Correct subtraction
|
|
7375
|
+
|
|
7376
|
+
v128_t qh0 = wasm_v128_load(qh);
|
|
7377
|
+
v128_t qh1 = wasm_v128_load(qh + 16);
|
|
7378
|
+
const uint8_t * sc = (const uint8_t *)utmp;
|
|
7379
|
+
|
|
7380
|
+
int32_t sumi = 0;
|
|
7381
|
+
|
|
7382
|
+
for (int j = 0; j < QK_K/64; ++j) {
|
|
7383
|
+
const int shift = j * 2;
|
|
7384
|
+
v128_t qh_shift0 = wasm_u8x16_shr(qh0, shift);
|
|
7385
|
+
v128_t qh_shift1 = wasm_u8x16_shr(qh1, shift);
|
|
7386
|
+
|
|
7387
|
+
v128_t qh_low0 = wasm_i8x16_shl(wasm_v128_and(qh_shift0, wasm_i8x16_splat(0x01)), 4);
|
|
7388
|
+
v128_t qh_high0 = wasm_i8x16_shl(wasm_v128_and(qh_shift0, wasm_i8x16_splat(0x02)), 3);
|
|
7389
|
+
v128_t qh_low1 = wasm_i8x16_shl(wasm_v128_and(qh_shift1, wasm_i8x16_splat(0x01)), 4);
|
|
7390
|
+
v128_t qh_high1 = wasm_i8x16_shl(wasm_v128_and(qh_shift1, wasm_i8x16_splat(0x02)), 3);
|
|
7391
|
+
|
|
7392
|
+
v128_t q5_0 = wasm_v128_load(q5);
|
|
7393
|
+
v128_t q5_1 = wasm_v128_load(q5 + 16);
|
|
7394
|
+
q5 += 32;
|
|
7395
|
+
|
|
7396
|
+
v128_t q5l_0 = wasm_v128_or(wasm_v128_and(q5_0, wasm_i8x16_splat(0x0F)), qh_low0);
|
|
7397
|
+
v128_t q5h_0 = wasm_v128_or(wasm_u8x16_shr(q5_0, 4), qh_high0);
|
|
7398
|
+
v128_t q5l_1 = wasm_v128_or(wasm_v128_and(q5_1, wasm_i8x16_splat(0x0F)), qh_low1);
|
|
7399
|
+
v128_t q5h_1 = wasm_v128_or(wasm_u8x16_shr(q5_1, 4), qh_high1);
|
|
7400
|
+
|
|
7401
|
+
v128_t q8_0 = wasm_v128_load(q8);
|
|
7402
|
+
v128_t q8_1 = wasm_v128_load(q8 + 16);
|
|
7403
|
+
v128_t q8_2 = wasm_v128_load(q8 + 32);
|
|
7404
|
+
v128_t q8_3 = wasm_v128_load(q8 + 48);
|
|
7405
|
+
q8 += 64;
|
|
7406
|
+
|
|
7407
|
+
// Process low quants
|
|
7408
|
+
v128_t pl0 = wasm_i32x4_dot_i16x8(
|
|
7409
|
+
wasm_i16x8_extend_low_i8x16(q5l_0),
|
|
7410
|
+
wasm_i16x8_extend_low_i8x16(q8_0)
|
|
7411
|
+
);
|
|
7412
|
+
pl0 = wasm_i32x4_add(pl0, wasm_i32x4_dot_i16x8(
|
|
7413
|
+
wasm_i16x8_extend_high_i8x16(q5l_0),
|
|
7414
|
+
wasm_i16x8_extend_high_i8x16(q8_0)
|
|
7415
|
+
));
|
|
7416
|
+
v128_t pl1 = wasm_i32x4_dot_i16x8(
|
|
7417
|
+
wasm_i16x8_extend_low_i8x16(q5l_1),
|
|
7418
|
+
wasm_i16x8_extend_low_i8x16(q8_1)
|
|
7419
|
+
);
|
|
7420
|
+
pl1 = wasm_i32x4_add(pl1, wasm_i32x4_dot_i16x8(
|
|
7421
|
+
wasm_i16x8_extend_high_i8x16(q5l_1),
|
|
7422
|
+
wasm_i16x8_extend_high_i8x16(q8_1)
|
|
7423
|
+
));
|
|
7424
|
+
v128_t sum_low = wasm_i32x4_add(pl0, pl1);
|
|
7425
|
+
|
|
7426
|
+
// Process high quants
|
|
7427
|
+
v128_t ph0 = wasm_i32x4_dot_i16x8(
|
|
7428
|
+
wasm_i16x8_extend_low_i8x16(q5h_0),
|
|
7429
|
+
wasm_i16x8_extend_low_i8x16(q8_2)
|
|
7430
|
+
);
|
|
7431
|
+
ph0 = wasm_i32x4_add(ph0, wasm_i32x4_dot_i16x8(
|
|
7432
|
+
wasm_i16x8_extend_high_i8x16(q5h_0),
|
|
7433
|
+
wasm_i16x8_extend_high_i8x16(q8_2)
|
|
7434
|
+
));
|
|
7435
|
+
v128_t ph1 = wasm_i32x4_dot_i16x8(
|
|
7436
|
+
wasm_i16x8_extend_low_i8x16(q5h_1),
|
|
7437
|
+
wasm_i16x8_extend_low_i8x16(q8_3)
|
|
7438
|
+
);
|
|
7439
|
+
ph1 = wasm_i32x4_add(ph1, wasm_i32x4_dot_i16x8(
|
|
7440
|
+
wasm_i16x8_extend_high_i8x16(q5h_1),
|
|
7441
|
+
wasm_i16x8_extend_high_i8x16(q8_3)
|
|
7442
|
+
));
|
|
7443
|
+
v128_t sum_high = wasm_i32x4_add(ph0, ph1);
|
|
7444
|
+
|
|
7445
|
+
// Accumulate with scale factors
|
|
7446
|
+
int32_t sl = wasm_i32x4_extract_lane(sum_low, 0) + wasm_i32x4_extract_lane(sum_low, 1) +
|
|
7447
|
+
wasm_i32x4_extract_lane(sum_low, 2) + wasm_i32x4_extract_lane(sum_low, 3);
|
|
7448
|
+
int32_t sh = wasm_i32x4_extract_lane(sum_high, 0) + wasm_i32x4_extract_lane(sum_high, 1) +
|
|
7449
|
+
wasm_i32x4_extract_lane(sum_high, 2) + wasm_i32x4_extract_lane(sum_high, 3);
|
|
7450
|
+
|
|
7451
|
+
sumi += sl * sc[2*j] + sh * sc[2*j+1];
|
|
7452
|
+
}
|
|
7453
|
+
|
|
7454
|
+
sumf += d * sumi;
|
|
7455
|
+
}
|
|
7456
|
+
|
|
7457
|
+
*s = sumf;
|
|
7458
|
+
|
|
6374
7459
|
#elif defined __riscv_v_intrinsic
|
|
6375
7460
|
|
|
6376
7461
|
const uint8_t * scales = (const uint8_t*)&utmp[0];
|
|
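
Note: the __wasm_simd128__ q5_K path above widens each 4-bit weight with one bit taken from qh, shifted into bit position 4 and OR-ed onto the low nibble (the qh_low/qh_high vectors feeding q5l/q5h). The per-weight arithmetic is simply the following (values are illustrative):

    #include <stdint.h>
    #include <stdio.h>

    int main(void) {
        const uint8_t low4 = 0x0B;                        /* 4-bit part from qs */
        for (int high = 0; high <= 1; ++high) {
            const uint8_t v = (uint8_t)(low4 | (high << 4));   /* fifth bit from qh */
            printf("qh bit %d -> %d\n", high, v);
        }
        return 0;
    }
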
@@ -6593,19 +7678,11 @@ void wsp_ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void
|
|
|
6593
7678
|
*s = vec_extract(vsumf0, 0);
|
|
6594
7679
|
|
|
6595
7680
|
#elif defined __loongarch_asx
|
|
6596
|
-
WSP_GGML_UNUSED(kmask1);
|
|
6597
|
-
WSP_GGML_UNUSED(kmask2);
|
|
6598
|
-
WSP_GGML_UNUSED(kmask3);
|
|
6599
|
-
|
|
6600
|
-
const __m256i m4 = __lasx_xvreplgr2vr_b(0xF);
|
|
6601
|
-
const __m128i mzero = __lsx_vldi(0);
|
|
6602
|
-
const __m256i mone = __lasx_xvreplgr2vr_b(1);
|
|
6603
7681
|
|
|
6604
7682
|
__m256 acc = (__m256)__lasx_xvldi(0);
|
|
7683
|
+
__m128 acc_m = (__m128)__lsx_vldi(0);
|
|
6605
7684
|
|
|
6606
|
-
|
|
6607
|
-
|
|
6608
|
-
for (int i = 0; i < nb; ++i) {
|
|
7685
|
+
for (int i = 0; i < nb; ++i) {
|
|
6609
7686
|
|
|
6610
7687
|
const uint8_t * restrict q5 = x[i].qs;
|
|
6611
7688
|
const int8_t * restrict q8 = y[i].qs;
|
|
@@ -6620,49 +7697,40 @@ void wsp_ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void
|
|
|
6620
7697
|
utmp[2] = uaux;
|
|
6621
7698
|
utmp[0] &= kmask1;
|
|
6622
7699
|
|
|
6623
|
-
const
|
|
7700
|
+
const __m128i mins_and_scales128 = lsx_set_w(utmp[3], utmp[2], utmp[1], utmp[0]);
|
|
7701
|
+
const __m128i mins128 = __lsx_vexth_h_b(mins_and_scales128);
|
|
7702
|
+
const __m128i scales128 = __lsx_vsllwil_h_b(mins_and_scales128, 0);
|
|
6624
7703
|
|
|
6625
7704
|
const __m256i q8sums = __lasx_xvld((const __m256i*)y[i].bsums, 0);
|
|
6626
7705
|
const __m128i q8s = lsx_hadd_h(lasx_extracti128(q8sums, 0), lasx_extracti128(q8sums, 1));
|
|
6627
|
-
const __m128i prod = lsx_madd_h(
|
|
6628
|
-
|
|
6629
|
-
summs += dmin * __lsx_vpickve2gr_w(hsum, 0); //TODO check
|
|
7706
|
+
const __m128i prod = lsx_madd_h(mins128, q8s);
|
|
7707
|
+
acc_m = __lsx_vfmadd_s(__lsx_vreplfr2vr_s(dmin), __lsx_vffint_s_w(prod), acc_m);
|
|
6630
7708
|
|
|
6631
|
-
const
|
|
6632
|
-
const __m256i scales = lasx_insertf128(sc128, sc128);
|
|
7709
|
+
const __m256i scales = lasx_insertf128(scales128, scales128);
|
|
6633
7710
|
|
|
6634
7711
|
const __m256i hbits = __lasx_xvld((const __m256i*)x[i].qh, 0);
|
|
6635
|
-
__m256i hmask = mone;
|
|
6636
7712
|
|
|
6637
7713
|
__m256i sumi = __lasx_xvldi(0);
|
|
6638
7714
|
|
|
6639
|
-
int bit = 0;
|
|
6640
|
-
__m256i xvbit;
|
|
6641
|
-
|
|
6642
7715
|
for (int j = 0; j < QK_K/64; ++j) {
|
|
6643
7716
|
|
|
6644
|
-
const __m256i scale_0 =
|
|
6645
|
-
const __m256i scale_1 =
|
|
7717
|
+
const __m256i scale_0 = lasx_xvrepl128vei_h(scales, 2 * j + 0);
|
|
7718
|
+
const __m256i scale_1 = lasx_xvrepl128vei_h(scales, 2 * j + 1);
|
|
6646
7719
|
|
|
6647
7720
|
const __m256i q5bits = __lasx_xvld((const __m256i*)q5, 0); q5 += 32;
|
|
6648
7721
|
|
|
6649
|
-
|
|
6650
|
-
const __m256i
|
|
6651
|
-
const __m256i q5h_0 =
|
|
6652
|
-
const __m256i
|
|
6653
|
-
|
|
6654
|
-
|
|
6655
|
-
xvbit = __lasx_xvreplgr2vr_h(bit++);
|
|
6656
|
-
const __m256i q5l_1 = __lasx_xvand_v(__lasx_xvsrli_h(q5bits, 4), m4);
|
|
6657
|
-
const __m256i q5h_1 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvand_v(hbits, hmask), xvbit), 4);
|
|
6658
|
-
const __m256i q5_1 = __lasx_xvadd_b(q5l_1, q5h_1);
|
|
6659
|
-
hmask = __lasx_xvslli_h(hmask, 1);
|
|
7722
|
+
const __m256i q5l_0 = __lasx_xvandi_b(q5bits, 0xf);
|
|
7723
|
+
const __m256i q5l_1 = __lasx_xvsrli_b(q5bits, 4);
|
|
7724
|
+
const __m256i q5h_0 = __lasx_xvnori_b(__lasx_xvseqi_b(lasx_xvandi_b_bit(hbits, 2 * j + 0), 0), 0xef);
|
|
7725
|
+
const __m256i q5h_1 = __lasx_xvnori_b(__lasx_xvseqi_b(lasx_xvandi_b_bit(hbits, 2 * j + 1), 0), 0xef);
|
|
7726
|
+
const __m256i q5_0 = __lasx_xvor_v(q5l_0, q5h_0);
|
|
7727
|
+
const __m256i q5_1 = __lasx_xvor_v(q5l_1, q5h_1);
|
|
6660
7728
|
|
|
6661
7729
|
const __m256i q8_0 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
|
|
6662
7730
|
const __m256i q8_1 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
|
|
6663
7731
|
|
|
6664
|
-
__m256i p16_0 =
|
|
6665
|
-
__m256i p16_1 =
|
|
7732
|
+
__m256i p16_0 = lasx_madd_h_b(q5_0, q8_0);
|
|
7733
|
+
__m256i p16_1 = lasx_madd_h_b(q5_1, q8_1);
|
|
6666
7734
|
|
|
6667
7735
|
p16_0 = lasx_madd_h(scale_0, p16_0);
|
|
6668
7736
|
p16_1 = lasx_madd_h(scale_1, p16_1);
|
|
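
Note: in the rewritten LASX q5_K loop above, the fifth bit is produced without per-iteration shifts: xvseqi against zero yields 0xFF for a cleared hmask bit, and xvnori with 0xEF collapses that to either 0x10 or 0x00 so it can be OR-ed onto the low nibble. A standalone check of that identity:

    #include <stdint.h>
    #include <stdio.h>

    int main(void) {
        for (int bit = 0; bit <= 1; ++bit) {
            const uint8_t seq = bit ? 0x00 : 0xFF;        /* __lasx_xvseqi_b(x, 0) per lane */
            const uint8_t nor = (uint8_t)~(seq | 0xEF);   /* __lasx_xvnori_b(seq, 0xef) per lane */
            printf("bit=%d -> 0x%02X (expected 0x%02X)\n",
                   bit, (unsigned)nor, (unsigned)(bit ? 0x10 : 0x00));
        }
        return 0;
    }
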
@@ -6676,8 +7744,98 @@ void wsp_ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void
|
|
|
6676
7744
|
|
|
6677
7745
|
}
|
|
6678
7746
|
|
|
6679
|
-
|
|
7747
|
+
acc_m = __lsx_vfadd_s(acc_m, (__m128)__lsx_vbsrl_v(acc_m, 8));
|
|
7748
|
+
acc_m = __lsx_vfadd_s(acc_m, (__m128)__lsx_vbsrl_v(acc_m, 4));
|
|
7749
|
+
|
|
7750
|
+
*s = hsum_float_8(acc) + ((v4f32)acc_m)[0];
|
|
7751
|
+
#elif defined(__VXE__) || defined(__VXE2__)
|
|
7752
|
+
const uint8x16_t v_lm = vec_splat_u8(0x0F);
|
|
7753
|
+
const uint8x16_t v_1m = vec_splat_u8(0x01);
|
|
7754
|
+
const uint8x16_t v_2m = vec_splat_u8(0x02);
|
|
7755
|
+
|
|
7756
|
+
const int32x4_t v_z = vec_splat_s32(0);
|
|
7757
|
+
|
|
7758
|
+
const uchar8x16_t v_minsm = {
|
|
7759
|
+
0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F,
|
|
7760
|
+
0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF
|
|
7761
|
+
};
|
|
7762
|
+
|
|
7763
|
+
int8x16_t q5b[4];
|
|
7764
|
+
uint8x16_t q5h[4];
|
|
7765
|
+
|
|
7766
|
+
uint8x16_t v_xl[2];
|
|
7767
|
+
uint8x16_t v_xh[2];
|
|
7768
|
+
int8x16_t v_y[4];
|
|
7769
|
+
|
|
7770
|
+
float sumf = 0;
|
|
7771
|
+
|
|
7772
|
+
for (int i = 0; i < nb; ++i) {
|
|
7773
|
+
const float d = y[i].d * WSP_GGML_FP16_TO_FP32(x[i].d);
|
|
7774
|
+
const float dmin = y[i].d * WSP_GGML_FP16_TO_FP32(x[i].dmin);
|
|
7775
|
+
|
|
7776
|
+
const int16x8_t v_ysumsl = vec_xl(0 , y[i].bsums);
|
|
7777
|
+
const int16x8_t v_ysumsh = vec_xl(16, y[i].bsums);
|
|
7778
|
+
const int16x8_t v_ysums = vec_padd_s16(v_ysumsl, v_ysumsh);
|
|
7779
|
+
|
|
7780
|
+
memcpy(utmp, x[i].scales, 12);
|
|
7781
|
+
utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
|
|
7782
|
+
const uint32_t uaux = utmp[1] & kmask1;
|
|
7783
|
+
utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
|
|
7784
|
+
utmp[2] = uaux;
|
|
7785
|
+
utmp[0] &= kmask1;
|
|
7786
|
+
|
|
7787
|
+
const uint8x16_t v_mins16 = vec_xl(0, (const uint8_t *)utmp);
|
|
7788
|
+
const uint8x16_t v_mins8 = vec_perm(v_mins16, v_mins16, v_minsm);
|
|
7789
|
+
const int16x8_t v_minsh = (int16x8_t)vec_unpackh(v_mins8);
|
|
7790
|
+
|
|
7791
|
+
const int32x4_t v_minsho = vec_mulo(v_ysums, v_minsh);
|
|
7792
|
+
const int32x4_t v_minshe = vec_mule(v_ysums, v_minsh);
|
|
7793
|
+
const int32x4_t v_mins = vec_add(v_minsho, v_minshe);
|
|
7794
|
+
const int32_t mins = v_mins[0] + v_mins[1] + v_mins[2] + v_mins[3];
|
|
7795
|
+
|
|
7796
|
+
const uint8_t * scales = (const uint8_t *)utmp;
|
|
7797
|
+
const uint8_t * restrict x0l = x[i].qs;
|
|
7798
|
+
const uint8_t * restrict x0h = x[i].qh;
|
|
7799
|
+
const int8_t * restrict y0 = y[i].qs;
|
|
7800
|
+
|
|
7801
|
+
v_xh[0] = vec_xl(0 , x0h);
|
|
7802
|
+
v_xh[1] = vec_xl(16, x0h);
|
|
7803
|
+
|
|
7804
|
+
int32_t sumi = 0;
|
|
7805
|
+
for (int j = 0; j < QK_K/64; ++j) {
|
|
7806
|
+
v_xl[0] = vec_xl(0 , x0l);
|
|
7807
|
+
v_xl[1] = vec_xl(16, x0l);
|
|
7808
|
+
x0l += 32;
|
|
7809
|
+
|
|
7810
|
+
v_y[0] = vec_xl(0 , y0);
|
|
7811
|
+
v_y[1] = vec_xl(16, y0);
|
|
7812
|
+
v_y[2] = vec_xl(32, y0);
|
|
7813
|
+
v_y[3] = vec_xl(48, y0);
|
|
7814
|
+
y0 += 64;
|
|
7815
|
+
|
|
7816
|
+
q5h[0] = vec_sl(vec_and(v_1m, v_xh[0]), 4);
|
|
7817
|
+
q5h[1] = vec_sl(vec_and(v_1m, v_xh[1]), 4);
|
|
7818
|
+
q5h[2] = vec_sl(vec_and(v_2m, v_xh[0]), 3);
|
|
7819
|
+
q5h[3] = vec_sl(vec_and(v_2m, v_xh[1]), 3);
|
|
7820
|
+
v_xh[0] = vec_sr(v_xh[0], 2);
|
|
7821
|
+
v_xh[1] = vec_sr(v_xh[1], 2);
|
|
7822
|
+
|
|
7823
|
+
q5b[0] = (int8x16_t)vec_or(vec_and(v_xl[0], v_lm), q5h[0]);
|
|
7824
|
+
q5b[1] = (int8x16_t)vec_or(vec_and(v_xl[1], v_lm), q5h[1]);
|
|
7825
|
+
q5b[2] = (int8x16_t)vec_or(vec_sr(v_xl[0], 4), q5h[2]);
|
|
7826
|
+
q5b[3] = (int8x16_t)vec_or(vec_sr(v_xl[1], 4), q5h[3]);
|
|
7827
|
+
|
|
7828
|
+
int32x4_t sumi0 = wsp_ggml_vec_dot(wsp_ggml_vec_dot(v_z, q5b[0], v_y[0]), q5b[1], v_y[1]);
|
|
7829
|
+
int32x4_t sumi1 = wsp_ggml_vec_dot(wsp_ggml_vec_dot(v_z, q5b[2], v_y[2]), q5b[3], v_y[3]);
|
|
7830
|
+
|
|
7831
|
+
sumi += (sumi0[0] + sumi0[1] + sumi0[2] + sumi0[3]) * *scales++;
|
|
7832
|
+
sumi += (sumi1[0] + sumi1[1] + sumi1[2] + sumi1[3]) * *scales++;
|
|
7833
|
+
}
|
|
7834
|
+
|
|
7835
|
+
sumf += d * sumi - dmin * mins;
|
|
7836
|
+
}
|
|
6680
7837
|
|
|
7838
|
+
*s = sumf;
|
|
6681
7839
|
#else
|
|
6682
7840
|
|
|
6683
7841
|
const uint8_t * scales = (const uint8_t*)&utmp[0];
|
|
@@ -7034,6 +8192,85 @@ void wsp_ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void
|
|
|
7034
8192
|
|
|
7035
8193
|
*s = hsum_float_8(acc);
|
|
7036
8194
|
|
|
8195
|
+
#elif defined __wasm_simd128__
|
|
8196
|
+
int8_t aux8[QK_K] __attribute__((aligned(16)));
|
|
8197
|
+
int32_t aux32[8] __attribute__((aligned(16))) = {0};
|
|
8198
|
+
float sums[8] __attribute__((aligned(16))) = {0};
|
|
8199
|
+
|
|
8200
|
+
for (int i = 0; i < nb; ++i) {
|
|
8201
|
+
// Unpack 6-bit quantized data into aux8 (unchanged)
|
|
8202
|
+
const uint8_t * restrict q4 = x[i].ql;
|
|
8203
|
+
const uint8_t * restrict qh = x[i].qh;
|
|
8204
|
+
int8_t * a = aux8;
|
|
8205
|
+
for (int j = 0; j < QK_K; j += 128) {
|
|
8206
|
+
for (int l = 0; l < 32; ++l) {
|
|
8207
|
+
a[l + 0] = (int8_t)((q4[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
|
|
8208
|
+
a[l + 32] = (int8_t)((q4[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
|
|
8209
|
+
a[l + 64] = (int8_t)((q4[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
|
|
8210
|
+
a[l + 96] = (int8_t)((q4[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
|
|
8211
|
+
}
|
|
8212
|
+
a += 128;
|
|
8213
|
+
q4 += 64;
|
|
8214
|
+
qh += 32;
|
|
8215
|
+
}
|
|
8216
|
+
|
|
8217
|
+
const int8_t * restrict a_ptr = aux8;
|
|
8218
|
+
const int8_t * restrict q8 = y[i].qs;
|
|
8219
|
+
v128_t acc0 = wasm_i32x4_splat(0);
|
|
8220
|
+
v128_t acc1 = wasm_i32x4_splat(0);
|
|
8221
|
+
|
|
8222
|
+
for (int j = 0; j < QK_K/16; ++j) {
|
|
8223
|
+
const int scale = x[i].scales[j];
|
|
8224
|
+
const v128_t vscale = wasm_i32x4_splat(scale);
|
|
8225
|
+
|
|
8226
|
+
// Load 16 elements from a and q8
|
|
8227
|
+
const v128_t a_vec = wasm_v128_load(a_ptr);
|
|
8228
|
+
const v128_t q8_vec = wasm_v128_load(q8);
|
|
8229
|
+
|
|
8230
|
+
// Process low 8 elements
|
|
8231
|
+
v128_t a_low = wasm_i16x8_extend_low_i8x16(a_vec);
|
|
8232
|
+
v128_t q8_low = wasm_i16x8_extend_low_i8x16(q8_vec);
|
|
8233
|
+
v128_t prod_low = wasm_i16x8_mul(a_low, q8_low);
|
|
8234
|
+
v128_t prod_lo_lo = wasm_i32x4_extend_low_i16x8(prod_low);
|
|
8235
|
+
v128_t prod_lo_hi = wasm_i32x4_extend_high_i16x8(prod_low);
|
|
8236
|
+
|
|
8237
|
+
// Process high 8 elements
|
|
8238
|
+
v128_t a_high = wasm_i16x8_extend_high_i8x16(a_vec);
|
|
8239
|
+
v128_t q8_high = wasm_i16x8_extend_high_i8x16(q8_vec);
|
|
8240
|
+
v128_t prod_high = wasm_i16x8_mul(a_high, q8_high);
|
|
8241
|
+
v128_t prod_hi_lo = wasm_i32x4_extend_low_i16x8(prod_high);
|
|
8242
|
+
v128_t prod_hi_hi = wasm_i32x4_extend_high_i16x8(prod_high);
|
|
8243
|
+
|
|
8244
|
+
// Scale and accumulate
|
|
8245
|
+
prod_lo_lo = wasm_i32x4_mul(prod_lo_lo, vscale);
|
|
8246
|
+
prod_lo_hi = wasm_i32x4_mul(prod_lo_hi, vscale);
|
|
8247
|
+
prod_hi_lo = wasm_i32x4_mul(prod_hi_lo, vscale);
|
|
8248
|
+
prod_hi_hi = wasm_i32x4_mul(prod_hi_hi, vscale);
|
|
8249
|
+
|
|
8250
|
+
acc0 = wasm_i32x4_add(acc0, wasm_i32x4_add(prod_lo_lo, prod_hi_lo));
|
|
8251
|
+
acc1 = wasm_i32x4_add(acc1, wasm_i32x4_add(prod_lo_hi, prod_hi_hi));
|
|
8252
|
+
|
|
8253
|
+
a_ptr += 16;
|
|
8254
|
+
q8 += 16;
|
|
8255
|
+
}
|
|
8256
|
+
|
|
8257
|
+
// Store accumulated results
|
|
8258
|
+
wasm_v128_store(&aux32[0], acc0);
|
|
8259
|
+
wasm_v128_store(&aux32[4], acc1);
|
|
8260
|
+
|
|
8261
|
+
const float d = WSP_GGML_FP16_TO_FP32(x[i].d) * y[i].d;
|
|
8262
|
+
for (int l = 0; l < 8; ++l) {
|
|
8263
|
+
sums[l] += d * aux32[l];
|
|
8264
|
+
}
|
|
8265
|
+
}
|
|
8266
|
+
|
|
8267
|
+
// Sum final results
|
|
8268
|
+
float sumf = 0;
|
|
8269
|
+
for (int l = 0; l < 8; ++l) {
|
|
8270
|
+
sumf += sums[l];
|
|
8271
|
+
}
|
|
8272
|
+
*s = sumf;
|
|
8273
|
+
|
|
7037
8274
|
#elif defined __riscv_v_intrinsic
|
|
7038
8275
|
|
|
7039
8276
|
float sumf = 0;
|
|
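
Note: the WASM q6_K kernel above first expands each 128-weight chunk into aux8: four 32-byte groups take their low 4 bits from the low or high nibble of ql (named q4 in the diff) and their top 2 bits from successive bit pairs of qh, minus the 32 bias. A scalar rendering of that indexing for a single column, with made-up inputs standing in for q4[l], q4[l+32] and qh[l]:

    #include <stdint.h>
    #include <stdio.h>

    int main(void) {
        const uint8_t ql0 = 0x5A, ql32 = 0xC3, qh0 = 0xB1;   /* illustrative inputs only */
        const int8_t a0 = (int8_t)((ql0  & 0xF) | (((qh0 >> 0) & 3) << 4)) - 32;
        const int8_t a1 = (int8_t)((ql32 & 0xF) | (((qh0 >> 2) & 3) << 4)) - 32;
        const int8_t a2 = (int8_t)((ql0  >>  4) | (((qh0 >> 4) & 3) << 4)) - 32;
        const int8_t a3 = (int8_t)((ql32 >>  4) | (((qh0 >> 6) & 3) << 4)) - 32;
        printf("%d %d %d %d\n", a0, a1, a2, a3);
        return 0;
    }
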
@@ -7258,8 +8495,6 @@ void wsp_ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void
|
|
|
7258
8495
|
|
|
7259
8496
|
#elif defined __loongarch_asx
|
|
7260
8497
|
|
|
7261
|
-
const __m256i m4 = __lasx_xvreplgr2vr_b(0xF);
|
|
7262
|
-
const __m256i m2 = __lasx_xvreplgr2vr_b(3);
|
|
7263
8498
|
const __m256i m32s = __lasx_xvreplgr2vr_b(32);
|
|
7264
8499
|
|
|
7265
8500
|
__m256 acc = (__m256)__lasx_xvldi(0);
|
|
@@ -7272,58 +8507,42 @@ void wsp_ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void
|
|
|
7272
8507
|
const uint8_t * restrict qh = x[i].qh;
|
|
7273
8508
|
const int8_t * restrict q8 = y[i].qs;
|
|
7274
8509
|
|
|
7275
|
-
const __m128i
|
|
8510
|
+
const __m128i scales128 = __lsx_vld((const __m128i*)x[i].scales, 0);
|
|
8511
|
+
const v16i8 shuffle_mask = {0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15};
|
|
8512
|
+
const __m256i scales_shuffled = lasx_ext8_16(__lsx_vshuf_b(scales128, scales128, (__m128i)shuffle_mask));
|
|
7276
8513
|
|
|
7277
8514
|
__m256i sumi = __lasx_xvldi(0);
|
|
7278
8515
|
|
|
7279
|
-
int is = 0;
|
|
7280
|
-
|
|
7281
8516
|
for (int j = 0; j < QK_K/128; ++j) {
|
|
7282
8517
|
|
|
7283
|
-
const __m128i scale_0 = lsx_shuffle_b(scales, get_scale_shuffle(is + 0));
|
|
7284
|
-
const __m128i scale_1 = lsx_shuffle_b(scales, get_scale_shuffle(is + 1));
|
|
7285
|
-
const __m128i scale_2 = lsx_shuffle_b(scales, get_scale_shuffle(is + 2));
|
|
7286
|
-
const __m128i scale_3 = lsx_shuffle_b(scales, get_scale_shuffle(is + 3));
|
|
7287
|
-
is += 4;
|
|
7288
|
-
|
|
7289
8518
|
const __m256i q4bits1 = __lasx_xvld((const __m256i*)q4, 0); q4 += 32;
|
|
7290
8519
|
const __m256i q4bits2 = __lasx_xvld((const __m256i*)q4, 0); q4 += 32;
|
|
7291
8520
|
const __m256i q4bitsH = __lasx_xvld((const __m256i*)qh, 0); qh += 32;
|
|
7292
8521
|
|
|
7293
|
-
const __m256i q4h_0 =
|
|
7294
|
-
const __m256i q4h_1 =
|
|
7295
|
-
const __m256i q4h_2 =
|
|
7296
|
-
const __m256i q4h_3 =
|
|
8522
|
+
const __m256i q4h_0 = __lasx_xvslli_b(__lasx_xvandi_b(q4bitsH, 3), 4);
|
|
8523
|
+
const __m256i q4h_1 = __lasx_xvslli_b(__lasx_xvandi_b(q4bitsH, 3 << 2), 2);
|
|
8524
|
+
const __m256i q4h_2 = __lasx_xvandi_b(q4bitsH, 3 << 4);
|
|
8525
|
+
const __m256i q4h_3 = __lasx_xvsrli_b(__lasx_xvandi_b(q4bitsH, 3 << 6), 2);
|
|
7297
8526
|
|
|
7298
|
-
const __m256i q4_0 = __lasx_xvor_v(
|
|
7299
|
-
const __m256i q4_1 = __lasx_xvor_v(
|
|
7300
|
-
const __m256i q4_2 = __lasx_xvor_v(
|
|
7301
|
-
const __m256i q4_3 = __lasx_xvor_v(
|
|
8527
|
+
const __m256i q4_0 = __lasx_xvor_v(__lasx_xvandi_b(q4bits1, 0xf), q4h_0);
|
|
8528
|
+
const __m256i q4_1 = __lasx_xvor_v(__lasx_xvandi_b(q4bits2, 0xf), q4h_1);
|
|
8529
|
+
const __m256i q4_2 = __lasx_xvor_v(__lasx_xvsrli_b(q4bits1, 4), q4h_2);
|
|
8530
|
+
const __m256i q4_3 = __lasx_xvor_v(__lasx_xvsrli_b(q4bits2, 4), q4h_3);
|
|
7302
8531
|
|
|
7303
8532
|
const __m256i q8_0 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
|
|
7304
8533
|
const __m256i q8_1 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
|
|
7305
8534
|
const __m256i q8_2 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
|
|
7306
8535
|
const __m256i q8_3 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
|
|
7307
8536
|
|
|
7308
|
-
__m256i
|
|
7309
|
-
__m256i
|
|
7310
|
-
__m256i
|
|
7311
|
-
__m256i
|
|
7312
|
-
|
|
7313
|
-
__m256i p16_0 = lasx_maddubs_h(q4_0, q8_0);
|
|
7314
|
-
__m256i p16_1 = lasx_maddubs_h(q4_1, q8_1);
|
|
7315
|
-
__m256i p16_2 = lasx_maddubs_h(q4_2, q8_2);
|
|
7316
|
-
__m256i p16_3 = lasx_maddubs_h(q4_3, q8_3);
|
|
7317
|
-
|
|
7318
|
-
p16_0 = __lasx_xvsub_h(p16_0, q8s_0);
|
|
7319
|
-
p16_1 = __lasx_xvsub_h(p16_1, q8s_1);
|
|
7320
|
-
p16_2 = __lasx_xvsub_h(p16_2, q8s_2);
|
|
7321
|
-
p16_3 = __lasx_xvsub_h(p16_3, q8s_3);
|
|
8537
|
+
__m256i p16_0 = lasx_madd_h_b(__lasx_xvsub_b(q4_0, m32s), q8_0);
|
|
8538
|
+
__m256i p16_1 = lasx_madd_h_b(__lasx_xvsub_b(q4_1, m32s), q8_1);
|
|
8539
|
+
__m256i p16_2 = lasx_madd_h_b(__lasx_xvsub_b(q4_2, m32s), q8_2);
|
|
8540
|
+
__m256i p16_3 = lasx_madd_h_b(__lasx_xvsub_b(q4_3, m32s), q8_3);
|
|
7322
8541
|
|
|
7323
|
-
p16_0 = lasx_madd_h(
|
|
7324
|
-
p16_1 = lasx_madd_h(
|
|
7325
|
-
p16_2 = lasx_madd_h(
|
|
7326
|
-
p16_3 = lasx_madd_h(
|
|
8542
|
+
p16_0 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 0), p16_0);
|
|
8543
|
+
p16_1 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 1), p16_1);
|
|
8544
|
+
p16_2 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 2), p16_2);
|
|
8545
|
+
p16_3 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 3), p16_3);
|
|
7327
8546
|
|
|
7328
8547
|
sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p16_0, p16_1));
|
|
7329
8548
|
sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p16_2, p16_3));
|
|
@@ -7333,7 +8552,130 @@ void wsp_ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void
|
|
|
7333
8552
|
}
|
|
7334
8553
|
|
|
7335
8554
|
*s = hsum_float_8(acc);
|
|
8555
|
+
#elif defined(__VXE__) || defined(__VXE2__)
|
|
8556
|
+
float sum = 0;
|
|
8557
|
+
|
|
8558
|
+
// Lower 4-bit and upper 2-bit masks
|
|
8559
|
+
const uint8x16_t v_lm = vec_splat_u8(0x0F);
|
|
8560
|
+
const uint8x16_t v_um = vec_splat_u8(0x03);
|
|
8561
|
+
|
|
8562
|
+
const int32x4_t v_z = vec_splat_s32(0);
|
|
8563
|
+
|
|
8564
|
+
int8x16_t q6b[4];
|
|
8565
|
+
uint8x16_t q6h[4];
|
|
8566
|
+
|
|
8567
|
+
uint8x16_t v_xl[4];
|
|
8568
|
+
uint8x16_t v_xh[2];
|
|
8569
|
+
int8x16_t v_y[4];
|
|
8570
|
+
|
|
8571
|
+
for (int i = 0; i < nb; ++i) {
|
|
8572
|
+
const float d_all = WSP_GGML_FP16_TO_FP32(x[i].d);
|
|
8573
|
+
|
|
8574
|
+
const uint8_t * restrict x0l = x[i].ql;
|
|
8575
|
+
const uint8_t * restrict x0h = x[i].qh;
|
|
8576
|
+
const int8_t * restrict y0 = y[i].qs;
|
|
8577
|
+
|
|
8578
|
+
const int8_t * restrict scale = x[i].scales;
|
|
8579
|
+
|
|
8580
|
+
const int16x8_t v_ysumsl = vec_xl(0 , y[i].bsums);
|
|
8581
|
+
const int16x8_t v_ysumsh = vec_xl(16, y[i].bsums);
|
|
8582
|
+
|
|
8583
|
+
const int8x16_t v_scale = vec_xl(0, scale);
|
|
8584
|
+
const int16x8_t v_scalel = vec_unpackh(v_scale);
|
|
8585
|
+
const int16x8_t v_scaleh = vec_unpackl(v_scale);
|
|
8586
|
+
|
|
8587
|
+
const int32x4_t v_minslo = vec_mulo(v_ysumsl, v_scalel);
|
|
8588
|
+
const int32x4_t v_minsle = vec_mule(v_ysumsl, v_scalel);
|
|
8589
|
+
const int32x4_t v_minsho = vec_mulo(v_ysumsh, v_scaleh);
|
|
8590
|
+
const int32x4_t v_minshe = vec_mule(v_ysumsh, v_scaleh);
|
|
8591
|
+
const int32x4_t v_mins = v_minslo + v_minsle + v_minsho + v_minshe;
|
|
8592
|
+
|
|
8593
|
+
const int32_t mins = v_mins[0] + v_mins[1] + v_mins[2] + v_mins[3];
|
|
8594
|
+
|
|
8595
|
+
int32_t isum = 0;
|
|
8596
|
+
for (int j = 0; j < QK_K/128; ++j) {
|
|
8597
|
+
// Load model upper 2 bits
|
|
8598
|
+
v_xh[0] = vec_xl(0 , x0h);
|
|
8599
|
+
v_xh[1] = vec_xl(16, x0h);
|
|
8600
|
+
x0h += 32;
|
|
8601
|
+
|
|
8602
|
+
// Load model lower 4 bits
|
|
8603
|
+
v_xl[0] = vec_xl(0 , x0l);
|
|
8604
|
+
v_xl[1] = vec_xl(16, x0l);
|
|
8605
|
+
v_xl[2] = vec_xl(32, x0l);
|
|
8606
|
+
v_xl[3] = vec_xl(48, x0l);
|
|
8607
|
+
x0l += 64;
|
|
8608
|
+
|
|
8609
|
+
// Load activation quants
|
|
8610
|
+
v_y[0] = vec_xl(0 , y0);
|
|
8611
|
+
v_y[1] = vec_xl(16, y0);
|
|
8612
|
+
v_y[2] = vec_xl(32, y0);
|
|
8613
|
+
v_y[3] = vec_xl(48, y0);
|
|
8614
|
+
y0 += 64;
|
|
8615
|
+
|
|
8616
|
+
q6h[0] = vec_sl(vec_and(v_um, v_xh[0]), 4);
|
|
8617
|
+
q6h[1] = vec_sl(vec_and(v_um, v_xh[1]), 4);
|
|
8618
|
+
uint8x16_t shifted = vec_sr(v_xh[0], 2);
|
|
8619
|
+
q6h[2] = vec_sl(vec_and(v_um, shifted), 4);
|
|
8620
|
+
shifted = vec_sr(v_xh[1], 2);
|
|
8621
|
+
q6h[3] = vec_sl(vec_and(v_um, shifted), 4);
|
|
8622
|
+
|
|
8623
|
+
q6b[0] = (int8x16_t)(vec_or(vec_and(v_xl[0], v_lm), q6h[0]));
|
|
8624
|
+
q6b[1] = (int8x16_t)(vec_or(vec_and(v_xl[1], v_lm), q6h[1]));
|
|
8625
|
+
q6b[2] = (int8x16_t)(vec_or(vec_and(v_xl[2], v_lm), q6h[2]));
|
|
8626
|
+
q6b[3] = (int8x16_t)(vec_or(vec_and(v_xl[3], v_lm), q6h[3]));
|
|
8627
|
+
|
|
8628
|
+
int32x4_t summs0 = wsp_ggml_vec_dot(v_z, q6b[0], v_y[0]);
|
|
8629
|
+
int32x4_t summs1 = wsp_ggml_vec_dot(v_z, q6b[1], v_y[1]);
|
|
8630
|
+
int32x4_t summs2 = wsp_ggml_vec_dot(v_z, q6b[2], v_y[2]);
|
|
8631
|
+
int32x4_t summs3 = wsp_ggml_vec_dot(v_z, q6b[3], v_y[3]);
|
|
8632
|
+
|
|
8633
|
+
isum += (summs0[0] + summs0[1] + summs0[2] + summs0[3]) * scale[0] +
|
|
8634
|
+
(summs1[0] + summs1[1] + summs1[2] + summs1[3]) * scale[1] +
|
|
8635
|
+
(summs2[0] + summs2[1] + summs2[2] + summs2[3]) * scale[2] +
|
|
8636
|
+
(summs3[0] + summs3[1] + summs3[2] + summs3[3]) * scale[3];
|
|
8637
|
+
|
|
8638
|
+
scale += 4;
|
|
8639
|
+
|
|
8640
|
+
|
|
8641
|
+
// Load activation quants
|
|
8642
|
+
v_y[0] = vec_xl(0 , y0);
|
|
8643
|
+
v_y[1] = vec_xl(16, y0);
|
|
8644
|
+
v_y[2] = vec_xl(32, y0);
|
|
8645
|
+
v_y[3] = vec_xl(48, y0);
|
|
8646
|
+
y0 += 64;
|
|
8647
|
+
|
|
8648
|
+
shifted = vec_sr(v_xh[0], 4);
|
|
8649
|
+
q6h[0] = vec_sl(vec_and(v_um, shifted), 4);
|
|
8650
|
+
shifted = vec_sr(v_xh[1], 4);
|
|
8651
|
+
q6h[1] = vec_sl(vec_and(v_um, shifted), 4);
|
|
8652
|
+
shifted = vec_sr(v_xh[0], 6);
|
|
8653
|
+
q6h[2] = vec_sl(vec_and(v_um, shifted), 4);
|
|
8654
|
+
shifted = vec_sr(v_xh[1], 6);
|
|
8655
|
+
q6h[3] = vec_sl(vec_and(v_um, shifted), 4);
|
|
8656
|
+
|
|
8657
|
+
q6b[0] = (int8x16_t)(vec_or(vec_sr(v_xl[0], 4), q6h[0]));
|
|
8658
|
+
q6b[1] = (int8x16_t)(vec_or(vec_sr(v_xl[1], 4), q6h[1]));
|
|
8659
|
+
q6b[2] = (int8x16_t)(vec_or(vec_sr(v_xl[2], 4), q6h[2]));
|
|
8660
|
+
q6b[3] = (int8x16_t)(vec_or(vec_sr(v_xl[3], 4), q6h[3]));
|
|
8661
|
+
|
|
8662
|
+
summs0 = wsp_ggml_vec_dot(v_z, q6b[0], v_y[0]);
|
|
8663
|
+
summs1 = wsp_ggml_vec_dot(v_z, q6b[1], v_y[1]);
|
|
8664
|
+
summs2 = wsp_ggml_vec_dot(v_z, q6b[2], v_y[2]);
|
|
8665
|
+
summs3 = wsp_ggml_vec_dot(v_z, q6b[3], v_y[3]);
|
|
8666
|
+
|
|
8667
|
+
isum += (summs0[0] + summs0[1] + summs0[2] + summs0[3]) * scale[0] +
|
|
8668
|
+
(summs1[0] + summs1[1] + summs1[2] + summs1[3]) * scale[1] +
|
|
8669
|
+
(summs2[0] + summs2[1] + summs2[2] + summs2[3]) * scale[2] +
|
|
8670
|
+
(summs3[0] + summs3[1] + summs3[2] + summs3[3]) * scale[3];
|
|
8671
|
+
|
|
8672
|
+
scale += 4;
|
|
8673
|
+
}
|
|
7336
8674
|
|
|
8675
|
+
sum += d_all * y[i].d * (isum - 32 * mins);
|
|
8676
|
+
}
|
|
8677
|
+
|
|
8678
|
+
*s = sum;
|
|
7337
8679
|
#else
|
|
7338
8680
|
|
|
7339
8681
|
int8_t aux8[QK_K];
|
|
@@ -7694,7 +9036,57 @@ void wsp_ggml_vec_dot_iq2_xxs_q8_K(int n, float * restrict s, size_t bs, const v
|
|
|
7694
9036
|
}
|
|
7695
9037
|
|
|
7696
9038
|
*s = 0.125f * hsum_float_8(accumf);
|
|
7697
|
-
|
|
9039
|
+
//#elif defined(__VXE__) || defined(__VXE2__)
|
|
9040
|
+
// const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
|
|
9041
|
+
//
|
|
9042
|
+
// uint32_t aux32[4];
|
|
9043
|
+
// const uint8_t * aux8 = (const uint8_t *)aux32;
|
|
9044
|
+
//
|
|
9045
|
+
// float sumf = 0;
|
|
9046
|
+
//
|
|
9047
|
+
// for (int i = 0; i < nb; ++i) {
|
|
9048
|
+
// const float d = WSP_GGML_FP16_TO_FP32(x[i].d) * y[i].d;
|
|
9049
|
+
// const uint16_t * restrict q2 = x[i].qs;
|
|
9050
|
+
// const int8_t * restrict q8 = y[i].qs;
|
|
9051
|
+
//
|
|
9052
|
+
// float sumf1 = 0, sumf2 = 0;
|
|
9053
|
+
//
|
|
9054
|
+
// for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
|
|
9055
|
+
// int8x16_t q8b0 = vec_xl( 0, q8);
|
|
9056
|
+
// int8x16_t q8b1 = vec_xl(16, q8);
|
|
9057
|
+
// int8x16_t q8b2 = vec_xl(32, q8);
|
|
9058
|
+
// int8x16_t q8b3 = vec_xl(48, q8);
|
|
9059
|
+
// q8 += 64;
|
|
9060
|
+
//
|
|
9061
|
+
// memcpy(aux32, q2, 4 * sizeof(uint32_t));
|
|
9062
|
+
// q2 += 8;
|
|
9063
|
+
//
|
|
9064
|
+
// int8x16_t q2u0 = { *(const int64_t *)(iq2xxs_grid + aux8[ 0]), *(const int64_t *)(iq2xxs_grid + aux8[ 1]) };
|
|
9065
|
+
// int8x16_t q2u1 = { *(const int64_t *)(iq2xxs_grid + aux8[ 2]), *(const int64_t *)(iq2xxs_grid + aux8[ 3]) };
|
|
9066
|
+
// int8x16_t q2u2 = { *(const int64_t *)(iq2xxs_grid + aux8[ 8]), *(const int64_t *)(iq2xxs_grid + aux8[ 9]) };
|
|
9067
|
+
// int8x16_t q2u3 = { *(const int64_t *)(iq2xxs_grid + aux8[10]), *(const int64_t *)(iq2xxs_grid + aux8[11]) };
|
|
9068
|
+
//
|
|
9069
|
+
// int8x16_t q2s0 = { *(const int64_t *)(signs64 + ((aux32[1] >> 0) & 127)), *(const int64_t *)(signs64 + ((aux32[1] >> 7) & 127)) };
|
|
9070
|
+
// int8x16_t q2s1 = { *(const int64_t *)(signs64 + ((aux32[1] >> 14) & 127)), *(const int64_t *)(signs64 + ((aux32[1] >> 21) & 127)) };
|
|
9071
|
+
// int8x16_t q2s2 = { *(const int64_t *)(signs64 + ((aux32[3] >> 0) & 127)), *(const int64_t *)(signs64 + ((aux32[3] >> 7) & 127)) };
|
|
9072
|
+
// int8x16_t q2s3 = { *(const int64_t *)(signs64 + ((aux32[3] >> 14) & 127)), *(const int64_t *)(signs64 + ((aux32[3] >> 21) & 127)) };
|
|
9073
|
+
//
|
|
9074
|
+
// q2u0 = vec_mul(q2u0, q2s0);
|
|
9075
|
+
// q2u1 = vec_mul(q2u1, q2s1);
|
|
9076
|
+
// q2u2 = vec_mul(q2u2, q2s2);
|
|
9077
|
+
// q2u3 = vec_mul(q2u3, q2s3);
|
|
9078
|
+
//
|
|
9079
|
+
// const int32x4_t p1 = wsp_ggml_vec_dot(wsp_ggml_vec_dot(vec_splat_s32(0), q2u0, q8b0), q2u1, q8b1);
|
|
9080
|
+
// const int32x4_t p2 = wsp_ggml_vec_dot(wsp_ggml_vec_dot(vec_splat_s32(0), q2u2, q8b2), q2u3, q8b3);
|
|
9081
|
+
//
|
|
9082
|
+
// sumf1 += (p1[0] + p1[1] + p1[2] + p1[3]) * (0.5f + (aux32[1] >> 28));
|
|
9083
|
+
// sumf2 += (p2[0] + p2[1] + p2[2] + p2[3]) * (0.5f + (aux32[3] >> 28));
|
|
9084
|
+
// }
|
|
9085
|
+
//
|
|
9086
|
+
// sumf += d * (sumf1 + sumf2);
|
|
9087
|
+
// }
|
|
9088
|
+
//
|
|
9089
|
+
// *s = 0.25f * sumf;
|
|
7698
9090
|
#else
|
|
7699
9091
|
|
|
7700
9092
|
uint32_t aux32[2];
|
|
@@ -9648,13 +11040,9 @@ static inline __m256i mul_add_epi8(const __m256i x, const __m256i y) {
|
|
|
9648
11040
|
}
|
|
9649
11041
|
#elif defined(__loongarch_asx)
|
|
9650
11042
|
static inline __m256i mul_add_epi8(const __m256i x, const __m256i y) {
|
|
9651
|
-
const __m256i
|
|
9652
|
-
const __m256i
|
|
9653
|
-
|
|
9654
|
-
tmp1 = __lasx_xvmulwev_h_bu_b(ax, sy);
|
|
9655
|
-
tmp2 = __lasx_xvmulwod_h_bu_b(ax, sy);
|
|
9656
|
-
tmp3 = __lasx_xvadd_h(tmp1, tmp2);
|
|
9657
|
-
return __lasx_xvsat_h(tmp3, 15);
|
|
11043
|
+
const __m256i a = __lasx_xvmulwev_h_b(x, y);
|
|
11044
|
+
const __m256i b = __lasx_xvmulwod_h_b(x, y);
|
|
11045
|
+
return __lasx_xvadd_h(a, b);
|
|
9658
11046
|
}
|
|
9659
11047
|
#endif
|
|
9660
11048
|
|
|
@@ -10459,6 +11847,27 @@ void wsp_ggml_vec_dot_iq4_nl_q8_0(int n, float * restrict s, size_t bs, const vo
|
|
|
10459
11847
|
|
|
10460
11848
|
sumf = hsum_float_8(__lasx_xvfadd_s(accum1, accum2));
|
|
10461
11849
|
|
|
11850
|
+
#elif defined(__VXE__) || defined(__VXE2__)
|
|
11851
|
+
const int8x16_t v_k = vec_xl(0, kvalues_iq4nl);
|
|
11852
|
+
const uint8x16_t v_m = vec_splat_u8(0x0F);
|
|
11853
|
+
|
|
11854
|
+
for (; ib < nb; ++ib) {
|
|
11855
|
+
const block_iq4_nl * restrict x0 = &x[ib];
|
|
11856
|
+
const block_q8_0 * restrict y0 = &y[ib];
|
|
11857
|
+
|
|
11858
|
+
const uint8x16_t v_x = vec_xl(0, x0->qs);
|
|
11859
|
+
int8x16_t v_xl = (int8x16_t)vec_and(v_x, v_m);
|
|
11860
|
+
int8x16_t v_xh = (int8x16_t)vec_sr(v_x, 4);
|
|
11861
|
+
|
|
11862
|
+
v_xl = vec_perm(v_k, v_k, (uchar8x16_t)v_xl);
|
|
11863
|
+
v_xh = vec_perm(v_k, v_k, (uchar8x16_t)v_xh);
|
|
11864
|
+
|
|
11865
|
+
const int8x16_t v_yl = vec_xl(0 , y0->qs);
|
|
11866
|
+
const int8x16_t v_yh = vec_xl(QK8_0/2, y0->qs);
|
|
11867
|
+
const int32x4_t v_xy = wsp_ggml_vec_dot(wsp_ggml_vec_dot(vec_splats(0), v_xl, v_yl), v_xh, v_yh);
|
|
11868
|
+
|
|
11869
|
+
sumf += WSP_GGML_FP16_TO_FP32(x0->d) * WSP_GGML_FP16_TO_FP32(y0->d) * (v_xy[0] + v_xy[1] + v_xy[2] + v_xy[3]);
|
|
11870
|
+
}
|
|
10462
11871
|
#endif
|
|
10463
11872
|
for (; ib < nb; ++ib) {
|
|
10464
11873
|
const float d = WSP_GGML_FP16_TO_FP32(y[ib].d)*WSP_GGML_FP16_TO_FP32(x[ib].d);
|
|
@@ -10704,67 +12113,31 @@ void wsp_ggml_vec_dot_iq4_xs_q8_K(int n, float * restrict s, size_t bs, const vo
|
|
|
10704
12113
|
#elif defined(__loongarch_asx)
|
|
10705
12114
|
|
|
10706
12115
|
const __m128i values128 = __lsx_vld((const __m128i*)kvalues_iq4nl, 0);
|
|
10707
|
-
const __m128i m4b = __lsx_vreplgr2vr_b(0x0f);
|
|
10708
12116
|
|
|
10709
12117
|
__m256 accum = (__m256)__lasx_xvldi(0);
|
|
10710
|
-
__m256i tmp1;
|
|
10711
|
-
__m128i tmp0, tmp2, tmp3, tmp4, mask_8f, mask;
|
|
10712
12118
|
|
|
10713
|
-
mask_8f = __lsx_vreplgr2vr_b(0x8f);
|
|
10714
12119
|
for (int ibl = 0; ibl < nb; ++ibl) {
|
|
10715
12120
|
const uint8_t * qs = x[ibl].qs;
|
|
10716
12121
|
const int8_t * q8 = y[ibl].qs;
|
|
10717
12122
|
uint16_t sh = x[ibl].scales_h;
|
|
10718
12123
|
__m256i sumi1 = __lasx_xvldi(0);
|
|
10719
12124
|
__m256i sumi2 = __lasx_xvldi(0);
|
|
10720
|
-
__m128i zero = __lsx_vldi(0);
|
|
10721
12125
|
for (int ib = 0; ib < QK_K/32; ib += 2) {
|
|
10722
|
-
const __m128i q4bits_1 = __lsx_vld((const __m128i*)qs, 0);
|
|
10723
|
-
const __m128i q4bits_2 = __lsx_vld((const __m128i*)qs, 0);
|
|
12126
|
+
const __m128i q4bits_1 = __lsx_vld((const __m128i*)qs, 0); qs += 16;
|
|
12127
|
+
const __m128i q4bits_2 = __lsx_vld((const __m128i*)qs, 0); qs += 16;
|
|
10724
12128
|
const __m256i q8b_1 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32;
|
|
10725
12129
|
const __m256i q8b_2 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32;
|
|
10726
|
-
|
|
10727
|
-
|
|
10728
|
-
|
|
10729
|
-
|
|
10730
|
-
tmp3 = __lsx_vshuf_b(values128, zero, tmp3);
|
|
10731
|
-
|
|
10732
|
-
tmp2 = __lsx_vand_v(__lsx_vand_v(q4bits_1, m4b), mask_8f);
|
|
10733
|
-
tmp0 = __lsx_vori_b(tmp2, 0x10);
|
|
10734
|
-
mask = __lsx_vsle_b(zero, tmp2);
|
|
10735
|
-
tmp4 = __lsx_vand_v(tmp0, mask);
|
|
10736
|
-
tmp4 = __lsx_vshuf_b(values128, zero, tmp4);
|
|
10737
|
-
|
|
10738
|
-
const __m256i q4b_1 = lasx_insertf128(tmp3, tmp4);
|
|
10739
|
-
|
|
10740
|
-
tmp2 = __lsx_vand_v(__lsx_vand_v(__lsx_vsrli_h(q4bits_2, 4), m4b), mask_8f);
|
|
10741
|
-
tmp0 = __lsx_vori_b(tmp2, 0x10);
|
|
10742
|
-
mask = __lsx_vsle_b(zero, tmp2);
|
|
10743
|
-
tmp3 = __lsx_vand_v(tmp0, mask);
|
|
10744
|
-
tmp3 = __lsx_vshuf_b(values128, zero, tmp3);
|
|
10745
|
-
|
|
10746
|
-
tmp2 = __lsx_vand_v(__lsx_vand_v(q4bits_2, m4b), mask_8f);
|
|
10747
|
-
tmp0 = __lsx_vori_b(tmp2, 0x10);
|
|
10748
|
-
mask = __lsx_vsle_b(zero, tmp2);
|
|
10749
|
-
tmp4 = __lsx_vand_v(tmp0, mask);
|
|
10750
|
-
tmp4 = __lsx_vshuf_b(values128, zero, tmp4);
|
|
10751
|
-
|
|
10752
|
-
const __m256i q4b_2 = lasx_insertf128(tmp3, tmp4);
|
|
10753
|
-
|
|
12130
|
+
const __m256i q4b_1 = lasx_insertf128(__lsx_vshuf_b(values128, values128, __lsx_vsrli_b(q4bits_1, 4)),
|
|
12131
|
+
__lsx_vshuf_b(values128, values128, __lsx_vandi_b(q4bits_1, 0xf)));
|
|
12132
|
+
const __m256i q4b_2 = lasx_insertf128(__lsx_vshuf_b(values128, values128, __lsx_vsrli_b(q4bits_2, 4)),
|
|
12133
|
+
__lsx_vshuf_b(values128, values128, __lsx_vandi_b(q4bits_2, 0xf)));
|
|
10754
12134
|
const __m256i p16_1 = mul_add_epi8(q4b_1, q8b_1);
|
|
10755
12135
|
const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2);
|
|
10756
12136
|
const int16_t ls1 = ((x[ibl].scales_l[ib/2] & 0xf) | ((sh << 4) & 0x30)) - 32;
|
|
10757
12137
|
const int16_t ls2 = ((x[ibl].scales_l[ib/2] >> 4) | ((sh << 2) & 0x30)) - 32;
|
|
10758
12138
|
sh >>= 4;
|
|
10759
|
-
__m256i
|
|
10760
|
-
|
|
10761
|
-
tmp5 = __lasx_xvmulwev_w_h(p16_1, tmp1);
|
|
10762
|
-
tmp6 = __lasx_xvmulwod_w_h(p16_1, tmp1);
|
|
10763
|
-
const __m256i p_1 = __lasx_xvadd_w(tmp5, tmp6);
|
|
10764
|
-
tmp1 = __lasx_xvreplgr2vr_h(ls2);
|
|
10765
|
-
tmp5 = __lasx_xvmulwev_w_h(p16_2, tmp1);
|
|
10766
|
-
tmp6 = __lasx_xvmulwod_w_h(p16_2, tmp1);
|
|
10767
|
-
const __m256i p_2 = __lasx_xvadd_w(tmp5, tmp6);
|
|
12139
|
+
const __m256i p_1 = lasx_madd_h(p16_1, __lasx_xvreplgr2vr_h(ls1));
|
|
12140
|
+
const __m256i p_2 = lasx_madd_h(p16_2, __lasx_xvreplgr2vr_h(ls2));
|
|
10768
12141
|
sumi1 = __lasx_xvadd_w(p_1, sumi1);
|
|
10769
12142
|
sumi2 = __lasx_xvadd_w(p_2, sumi2);
|
|
10770
12143
|
}
|
|
@@ -10773,6 +12146,56 @@ void wsp_ggml_vec_dot_iq4_xs_q8_K(int n, float * restrict s, size_t bs, const vo
|
|
|
10773
12146
|
}
|
|
10774
12147
|
|
|
10775
12148
|
*s = hsum_float_8(accum);
|
|
12149
|
+
#elif defined(__VXE__) || defined(__VXE2__)
|
|
12150
|
+
const int8x16_t v_k = vec_xl(0, kvalues_iq4nl);
|
|
12151
|
+
const uint8x16_t v_m = vec_splat_u8(0x0F);
|
|
12152
|
+
|
|
12153
|
+
float sumf = 0;
|
|
12154
|
+
|
|
12155
|
+
for (int ibl = 0; ibl < nb; ++ibl) {
|
|
12156
|
+
const uint8_t * restrict q4 = x[ibl].qs;
|
|
12157
|
+
const int8_t * restrict q8 = y[ibl].qs;
|
|
12158
|
+
|
|
12159
|
+
uint16_t h = x[ibl].scales_h;
|
|
12160
|
+
|
|
12161
|
+
int sumi1 = 0, sumi2 = 0;
|
|
12162
|
+
for (int ib = 0; ib < QK_K/64; ++ib) {
|
|
12163
|
+
const uint8x16_t v_x0 = vec_xl(0 , q4);
|
|
12164
|
+
const uint8x16_t v_x1 = vec_xl(QK4_NL/2, q4);
|
|
12165
|
+
q4 += 32;
|
|
12166
|
+
|
|
12167
|
+
int8x16_t v_x0l = (int8x16_t)vec_and(v_x0, v_m);
|
|
12168
|
+
int8x16_t v_x0h = (int8x16_t)vec_sr(v_x0, 4);
|
|
12169
|
+
int8x16_t v_x1l = (int8x16_t)vec_and(v_x1, v_m);
|
|
12170
|
+
int8x16_t v_x1h = (int8x16_t)vec_sr(v_x1, 4);
|
|
12171
|
+
|
|
12172
|
+
v_x0l = vec_perm(v_k, v_k, (uchar8x16_t)v_x0l);
|
|
12173
|
+
v_x0h = vec_perm(v_k, v_k, (uchar8x16_t)v_x0h);
|
|
12174
|
+
v_x1l = vec_perm(v_k, v_k, (uchar8x16_t)v_x1l);
|
|
12175
|
+
v_x1h = vec_perm(v_k, v_k, (uchar8x16_t)v_x1h);
|
|
12176
|
+
|
|
12177
|
+
const int8x16_t v_y0 = vec_xl( 0, q8);
|
|
12178
|
+
const int8x16_t v_y1 = vec_xl(16, q8);
|
|
12179
|
+
const int8x16_t v_y2 = vec_xl(32, q8);
|
|
12180
|
+
const int8x16_t v_y3 = vec_xl(48, q8);
|
|
12181
|
+
q8 += 64;
|
|
12182
|
+
|
|
12183
|
+
int32x4_t vsumi0 = wsp_ggml_vec_dot(wsp_ggml_vec_dot(vec_splats(0), v_x0l, v_y0), v_x0h, v_y1);
|
|
12184
|
+
int32x4_t vsumi1 = wsp_ggml_vec_dot(wsp_ggml_vec_dot(vec_splats(0), v_x1l, v_y2), v_x1h, v_y3);
|
|
12185
|
+
|
|
12186
|
+
int ls1 = ((x[ibl].scales_l[ib] & 0xF) | ((h << 4) & 0x30)) - 32;
|
|
12187
|
+
int ls2 = ((x[ibl].scales_l[ib] >> 4) | ((h << 2) & 0x30)) - 32;
|
|
12188
|
+
|
|
12189
|
+
h >>= 4;
|
|
12190
|
+
|
|
12191
|
+
sumi1 += (vsumi0[0] + vsumi0[1] + vsumi0[2] + vsumi0[3]) * ls1;
|
|
12192
|
+
sumi2 += (vsumi1[0] + vsumi1[1] + vsumi1[2] + vsumi1[3]) * ls2;
|
|
12193
|
+
}
|
|
12194
|
+
|
|
12195
|
+
sumf += WSP_GGML_FP16_TO_FP32(x[ibl].d) * y[ibl].d * (sumi1 + sumi2);
|
|
12196
|
+
}
|
|
12197
|
+
|
|
12198
|
+
*s = sumf;
|
|
10776
12199
|
|
|
10777
12200
|
#else
|
|
10778
12201
|
float sumf = 0;
|