whisper.rn 0.4.0-rc.10 → 0.4.0-rc.12

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. package/android/src/main/CMakeLists.txt +9 -3
  2. package/cpp/amx/amx.cpp +220 -0
  3. package/cpp/amx/amx.h +8 -0
  4. package/cpp/amx/common.h +91 -0
  5. package/cpp/amx/mmq.cpp +2511 -0
  6. package/cpp/amx/mmq.h +10 -0
  7. package/cpp/ggml-alloc.c +6 -14
  8. package/cpp/ggml-backend-impl.h +50 -11
  9. package/cpp/ggml-backend-reg.cpp +409 -31
  10. package/cpp/ggml-backend.cpp +9 -3
  11. package/cpp/ggml-backend.h +18 -0
  12. package/cpp/ggml-common.h +41 -43
  13. package/cpp/ggml-cpp.h +1 -0
  14. package/cpp/{ggml-cpu-aarch64.c → ggml-cpu-aarch64.cpp} +941 -254
  15. package/cpp/ggml-cpu-aarch64.h +2 -24
  16. package/cpp/ggml-cpu-impl.h +171 -11
  17. package/cpp/ggml-cpu-quants.c +1812 -389
  18. package/cpp/ggml-cpu-traits.cpp +36 -0
  19. package/cpp/ggml-cpu-traits.h +38 -0
  20. package/cpp/ggml-cpu.c +1432 -610
  21. package/cpp/ggml-cpu.cpp +131 -141
  22. package/cpp/ggml-cpu.h +10 -50
  23. package/cpp/ggml-impl.h +27 -11
  24. package/cpp/ggml-metal-impl.h +39 -0
  25. package/cpp/ggml-metal.h +1 -1
  26. package/cpp/ggml-metal.m +1031 -359
  27. package/cpp/ggml-opt.cpp +854 -0
  28. package/cpp/ggml-opt.h +216 -0
  29. package/cpp/ggml-quants.c +0 -9
  30. package/cpp/ggml-threading.h +4 -2
  31. package/cpp/ggml-whisper.metallib +0 -0
  32. package/cpp/ggml.c +501 -1537
  33. package/cpp/ggml.h +144 -171
  34. package/cpp/gguf.cpp +1329 -0
  35. package/cpp/gguf.h +202 -0
  36. package/cpp/whisper.cpp +254 -114
  37. package/cpp/whisper.h +6 -3
  38. package/lib/commonjs/version.json +1 -1
  39. package/lib/module/version.json +1 -1
  40. package/package.json +2 -1
  41. package/src/version.json +1 -1
  42. package/whisper-rn.podspec +2 -2
  43. package/cpp/README.md +0 -4
  44. package/cpp/ggml-aarch64.c +0 -129
  45. package/cpp/ggml-aarch64.h +0 -19
  46. package/cpp/ggml-backend.cpp.rej +0 -12
@@ -103,10 +103,14 @@ static inline __m256 sum_i16_pairs_float(const __m256i x) {
103
103
  }
104
104
 
105
105
  static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) {
106
- #if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__))
106
+ #if defined(__AVX512VNNI__) && defined(__AVX512VL__)
107
107
  const __m256i zero = _mm256_setzero_si256();
108
108
  const __m256i summed_pairs = _mm256_dpbusd_epi32(zero, ax, sy);
109
109
  return _mm256_cvtepi32_ps(summed_pairs);
110
+ #elif defined(__AVXVNNI__)
111
+ const __m256i zero = _mm256_setzero_si256();
112
+ const __m256i summed_pairs = _mm256_dpbusd_avx_epi32(zero, ax, sy);
113
+ return _mm256_cvtepi32_ps(summed_pairs);
110
114
  #else
111
115
  // Perform multiplication and create 16-bit values
112
116
  const __m256i dot = _mm256_maddubs_epi16(ax, sy);
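For reference only (not part of the package diff): the new #elif branch above selects the AVX-VNNI form of the same dot product. Below is a minimal scalar sketch of what _mm256_dpbusd_epi32 / _mm256_dpbusd_avx_epi32 compute per 32-bit lane, under the assumption that ax holds unsigned bytes and sy signed bytes, as in the surrounding code:

    #include <stdint.h>

    /* Illustrative only: models one dpbusd step on the 8 x 32-bit lanes of a
       256-bit register. Starting from a zero accumulator, each lane receives
       the sum of four u8*s8 products, which is what the intrinsics above
       return before the conversion to float. */
    static void dpbusd_scalar_sketch(int32_t acc[8], const uint8_t ax[32], const int8_t sy[32]) {
        for (int lane = 0; lane < 8; ++lane) {
            int32_t sum = 0;
            for (int k = 0; k < 4; ++k) {
                sum += (int32_t)ax[4*lane + k] * (int32_t)sy[4*lane + k];
            }
            acc[lane] += sum;
        }
    }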
@@ -293,6 +297,90 @@ static const uint64_t table_b2b_0[1 << 8] = { B8(00, 10) }; // ( b) << 4
293
297
  static const uint64_t table_b2b_1[1 << 8] = { B8(10, 00) }; // (!b) << 4
294
298
  #endif
295
299
 
300
+ #if defined(__loongarch_sx)
301
+
302
+ static __m128i lsx_packs_w(__m128i a, __m128i b) {
303
+ __m128i tmp, tmp1;
304
+ tmp = __lsx_vsat_w(a, 15);
305
+ tmp1 = __lsx_vsat_w(b, 15);
306
+ return __lsx_vpickev_h(tmp1, tmp);
307
+ }
308
+
309
+ static __m128i lsx_packs_h(__m128i a, __m128i b) {
310
+ __m128i tmp, tmp1;
311
+ tmp = __lsx_vsat_h(a, 7);
312
+ tmp1 = __lsx_vsat_h(b, 7);
313
+ return __lsx_vpickev_b(tmp1, tmp);
314
+ }
315
+
316
+ static __m128i lsx_packus_h(__m128i a, __m128i b) {
317
+ __m128i tmp, tmp1;
318
+ tmp = __lsx_vsat_hu(a, 7);
319
+ tmp1 = __lsx_vsat_hu(b, 7);
320
+ return __lsx_vpickev_b(tmp1, tmp);
321
+ }
322
+
323
+ static __m128i lsx_maddubs_h(__m128i a, __m128i b) {
324
+ __m128i tmp1, tmp2;
325
+ tmp1 = __lsx_vmulwev_h_b(a, b);
326
+ tmp2 = __lsx_vmulwod_h_b(a, b);
327
+ return __lsx_vsadd_h(tmp1, tmp2);
328
+ }
329
+
330
+ static __m128i lsx_madd_h(__m128i a, __m128i b) {
331
+ __m128i tmp1, tmp2;
332
+ tmp1 = __lsx_vmulwev_w_h(a, b);
333
+ tmp2 = __lsx_vmulwod_w_h(a, b);
334
+ return __lsx_vadd_w(tmp1, tmp2);
335
+ }
336
+
337
+ static __m128i lsx_set_w(int32_t a, int32_t b, int32_t c, int32_t d) {
338
+ v4i32 __ret = {d, c, b, a};
339
+ return (__m128i)__ret;
340
+ }
341
+
342
+ static __m128i lsx_shuffle_b(__m128i a, __m128i b) {
343
+ __m128i mask_f, zero, tmp0, tmp2, mask;
344
+ int f = 0x8f;
345
+ mask_f = __lsx_vreplgr2vr_b(f);
346
+ zero = __lsx_vldi(0);
347
+ tmp0 = __lsx_vand_v(b, mask_f); // get mask with low 4 bit and sign bits
348
+ tmp0 = __lsx_vori_b(tmp0, 0x10); // make each mask or with 0x10 prepare for positive
349
+ mask = __lsx_vsle_b(zero, tmp0); // if mask >= 0, set mask
350
+ tmp2 = __lsx_vand_v(tmp0, mask); // maskout the in2 < ones
351
+ return __lsx_vshuf_b(a, zero, tmp2);
352
+ }
353
+
354
+ static __m128i lsx_hadd_h(__m128i a, __m128i b) {
355
+ __m128i tmp1 = __lsx_vpickev_h(b, a);
356
+ __m128i tmp2 = __lsx_vpickod_h(b, a);
357
+ return __lsx_vadd_h(tmp1, tmp2);
358
+ }
359
+
360
+ static __m128i lsx_hadd_w(__m128i a, __m128i b) {
361
+ __m128i tmp1 = __lsx_vpickev_w(b, a);
362
+ __m128i tmp2 = __lsx_vpickod_w(b, a);
363
+ return __lsx_vadd_w(tmp1, tmp2);
364
+ }
365
+
366
+ static __m128 lsx_hadd_s(__m128 a, __m128 b) {
367
+ __m128 tmp1 = (__m128)__lsx_vpickev_w((__m128i)b, (__m128i)a);
368
+ __m128 tmp2 = (__m128)__lsx_vpickod_w((__m128i)b, (__m128i)a);
369
+
370
+ return __lsx_vfadd_s(tmp1, tmp2);
371
+ }
372
+
373
+ static inline float hsum_float_4x4(const __m128 a, const __m128 b, const __m128 c, const __m128 d) {
374
+ __m128 res_0 =lsx_hadd_s(a, b);
375
+ __m128 res_1 =lsx_hadd_s(c, d);
376
+ __m128 res =lsx_hadd_s(res_0, res_1);
377
+ res =lsx_hadd_s(res, res);
378
+ res =lsx_hadd_s(res, res);
379
+
380
+ return ((v4f32)res)[0];
381
+ }
382
+ #endif
383
+
296
384
  #if defined(__loongarch_asx)
297
385
 
298
386
  #ifdef __clang__
@@ -391,11 +479,6 @@ static __m256i lasx_set_w(int e7, int e6, int e5, int e4, int e3, int e2, int e1
391
479
  return (__m256i)__ret;
392
480
  }
393
481
 
394
- static __m128i lsx_set_w(int32_t a, int32_t b, int32_t c, int32_t d) {
395
- v4i32 __ret = {d, c, b, a};
396
- return (__m128i)__ret;
397
- }
398
-
399
482
  static __m256i lasx_set_d(int64_t a, int64_t b, int64_t c, int64_t d) {
400
483
  v4i64 __ret = {d, c, b, a};
401
484
  return (__m256i)__ret;
@@ -405,18 +488,6 @@ static __m256i lasx_insertf128( __m128i x, __m128i y) {
405
488
  return lasx_set_q(x, y);
406
489
  }
407
490
 
408
- static __m128i lsx_shuffle_b(__m128i a, __m128i b) {
409
- __m128i mask_f, zero, tmp0, tmp2, mask;
410
- int f = 0x8f;
411
- mask_f = __lsx_vreplgr2vr_b(f);
412
- zero = __lsx_vldi(0);
413
- tmp0 = __lsx_vand_v(b, mask_f); // get mask with low 4 bit and sign bits
414
- tmp0 = __lsx_vori_b(tmp0, 0x10); // make each mask or with 0x10 prepare for positive
415
- mask = __lsx_vsle_b(zero, tmp0); // if mask >= 0, set mask
416
- tmp2 = __lsx_vand_v(tmp0, mask); // maskout the in2 < ones
417
- return __lsx_vshuf_b(a, zero, tmp2);
418
- }
419
-
420
491
  static __m256i lasx_shuffle_b(__m256i a, __m256i b) {
421
492
  __m256i mask_f, zero, tmp0, tmp2, mask;
422
493
  int f = 0x8f;
@@ -430,30 +501,15 @@ static __m256i lasx_shuffle_b(__m256i a, __m256i b) {
430
501
  }
431
502
 
432
503
  static __m256i lasx_extu8_16(__m128i a) {
433
- __m128i zero = __lsx_vldi(0);
434
- __m128i vlo = __lsx_vilvl_b(zero, a);
435
- __m128i vhi = __lsx_vilvh_b(zero, a);
436
- return lasx_set_q(vhi, vlo);
504
+ return __lasx_vext2xv_hu_bu(____m256i(a));
437
505
  }
438
506
 
439
507
  static __m256i lasx_ext8_16(__m128i a) {
440
- __m128i sign = __lsx_vslti_b(a, 0);
441
- __m128i vlo = __lsx_vilvl_b(sign, a);
442
- __m128i vhi = __lsx_vilvh_b(sign, a);
443
- return lasx_set_q(vhi, vlo);
508
+ return __lasx_vext2xv_h_b(____m256i(a));
444
509
  }
445
510
 
446
511
  static __m256i lasx_ext16_32(__m128i a) {
447
- __m256i tmp1;
448
- tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 0), 0);
449
- tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 1), 1);
450
- tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 2), 2);
451
- tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 3), 3);
452
- tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 4), 4);
453
- tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 5), 5);
454
- tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 6), 6);
455
- tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 7), 7);
456
- return tmp1;
512
+ return __lasx_vext2xv_w_h(____m256i(a));
457
513
  }
458
514
 
459
515
  static __m128i lasx_extracti128( __m256i a, int pos) {
@@ -478,25 +534,6 @@ static __m128 lasx_extractf128( __m256 a, int pos) {
478
534
  return ret;
479
535
  }
480
536
 
481
- static __m128i lsx_hadd_h(__m128i a, __m128i b) {
482
- __m128i tmp1 = __lsx_vpickev_h(b, a);
483
- __m128i tmp2 = __lsx_vpickod_h(b, a);
484
- return __lsx_vadd_h(tmp1, tmp2);
485
- }
486
-
487
- static __m128i lsx_hadd_w(__m128i a, __m128i b) {
488
- __m128i tmp1 = __lsx_vpickev_w(b, a);
489
- __m128i tmp2 = __lsx_vpickod_w(b, a);
490
- return __lsx_vadd_w(tmp1, tmp2);
491
- }
492
-
493
- static __m128 lsx_hadd_s(__m128 a, __m128 b) {
494
- __m128 tmp1 = (__m128)__lsx_vpickev_w((__m128i)b, (__m128i)a);
495
- __m128 tmp2 = (__m128)__lsx_vpickod_w((__m128i)b, (__m128i)a);
496
-
497
- return __lsx_vfadd_s(tmp1, tmp2);
498
- }
499
-
500
537
  static __m256i lasx_maddubs_h(__m256i a, __m256i b) {
501
538
  __m256i tmp1, tmp2;
502
539
  tmp1 = __lasx_xvmulwev_h_b(a, b);
@@ -525,40 +562,39 @@ static __m256i lasx_packs_h(__m256i a, __m256i b) {
525
562
  return __lasx_xvpickev_b(tmp1, tmp);
526
563
  }
527
564
 
528
- static __m128i lsx_packs_w(__m128i a, __m128i b) {
529
- __m128i tmp, tmp1;
530
- tmp = __lsx_vsat_w(a, 15);
531
- tmp1 = __lsx_vsat_w(b, 15);
532
- return __lsx_vpickev_h(tmp1, tmp);
533
- }
534
-
535
- static __m128i lsx_packs_h(__m128i a, __m128i b) {
536
- __m128i tmp, tmp1;
537
- tmp = __lsx_vsat_h(a, 7);
538
- tmp1 = __lsx_vsat_h(b, 7);
539
- return __lsx_vpickev_b(tmp1, tmp);
540
- }
541
-
542
- static __m128i lsx_packus_h(__m128i a, __m128i b) {
543
- __m128i tmp, tmp1;
544
- tmp = __lsx_vsat_hu(a, 7);
545
- tmp1 = __lsx_vsat_hu(b, 7);
546
- return __lsx_vpickev_b(tmp1, tmp);
565
+ static inline __m256i lasx_madd_h_b(__m256i a, __m256i b) {
566
+ __m256i tmp1, tmp2;
567
+ tmp1 = __lasx_xvmulwev_h_b(a, b);
568
+ tmp2 = __lasx_xvmulwod_h_b(a, b);
569
+ return __lasx_xvadd_h(tmp1, tmp2);
547
570
  }
548
571
 
549
-
550
- static __m128i lsx_maddubs_h(__m128i a, __m128i b) {
551
- __m128i tmp1, tmp2;
552
- tmp1 = __lsx_vmulwev_h_b(a, b);
553
- tmp2 = __lsx_vmulwod_h_b(a, b);
554
- return __lsx_vsadd_h(tmp1, tmp2);
572
+ static inline __m256i lasx_xvrepl128vei_h(__m256i a, const unsigned int b) {
573
+ switch (b) {
574
+ case 0: return __lasx_xvrepl128vei_h(a, 0);
575
+ case 1: return __lasx_xvrepl128vei_h(a, 1);
576
+ case 2: return __lasx_xvrepl128vei_h(a, 2);
577
+ case 3: return __lasx_xvrepl128vei_h(a, 3);
578
+ case 4: return __lasx_xvrepl128vei_h(a, 4);
579
+ case 5: return __lasx_xvrepl128vei_h(a, 5);
580
+ case 6: return __lasx_xvrepl128vei_h(a, 6);
581
+ case 7: return __lasx_xvrepl128vei_h(a, 7);
582
+ default: __builtin_unreachable();
583
+ }
555
584
  }
556
585
 
557
- static __m128i lsx_madd_h(__m128i a, __m128i b) {
558
- __m128i tmp1, tmp2;
559
- tmp1 = __lsx_vmulwev_w_h(a, b);
560
- tmp2 = __lsx_vmulwod_w_h(a, b);
561
- return __lsx_vadd_w(tmp1, tmp2);
586
+ static inline __m256i lasx_xvandi_b_bit(__m256i a, const unsigned int b) {
587
+ switch (b) {
588
+ case 0: return __lasx_xvandi_b(a, 1 << 0);
589
+ case 1: return __lasx_xvandi_b(a, 1 << 1);
590
+ case 2: return __lasx_xvandi_b(a, 1 << 2);
591
+ case 3: return __lasx_xvandi_b(a, 1 << 3);
592
+ case 4: return __lasx_xvandi_b(a, 1 << 4);
593
+ case 5: return __lasx_xvandi_b(a, 1 << 5);
594
+ case 6: return __lasx_xvandi_b(a, 1 << 6);
595
+ case 7: return __lasx_xvandi_b(a, 1 << 7);
596
+ default: __builtin_unreachable();
597
+ }
562
598
  }
563
599
 
564
600
  // multiply int8_t, add results pairwise twice
@@ -576,12 +612,10 @@ static inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) {
576
612
  // horizontally add 8 floats
577
613
  static inline float hsum_float_8(const __m256 x) {
578
614
  __m128 res = lasx_extractf128(x, 1);
579
- ft_union tmp;
580
615
  res = __lsx_vfadd_s(res, lasx_extractf128(x, 0));
581
616
  res = __lsx_vfadd_s(res, (__m128)__lsx_vpickod_d((__m128i)res, (__m128i)res));
582
617
  res = __lsx_vfadd_s(res, (__m128)__lsx_vinsgr2vr_w(__lsx_vldi(0), __lsx_vpickve2gr_w(res, 1), 0));
583
- tmp.i = __lsx_vpickve2gr_w(res, 0);
584
- return tmp.f;
618
+ return ((v4f32)res)[0];
585
619
  }
586
620
 
587
621
  // horizontally add 8 int32_t
@@ -657,13 +691,8 @@ static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy)
657
691
 
658
692
  // multiply int8_t, add results pairwise twice and return as float vector
659
693
  static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
660
-
661
- // Get absolute values of x vectors
662
- const __m256i ax = __lasx_xvsigncov_b(x, x);
663
- // Sign the values of the y vectors
664
- const __m256i sy = __lasx_xvsigncov_b(x, y);
665
-
666
- return mul_sum_us8_pairs_float(ax, sy);
694
+ const __m256i dot = lasx_madd_h_b(x, y);
695
+ return sum_i16_pairs_float(dot);
667
696
  }
668
697
 
669
698
  static inline __m128i packNibbles( __m256i bytes ) {
@@ -743,7 +772,7 @@ void wsp_quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t
743
772
  y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3);
744
773
  }
745
774
  }
746
- #elif defined(__wasm_simd128__)
775
+ #elif defined __wasm_simd128__
747
776
  for (int i = 0; i < nb; i++) {
748
777
  v128_t srcv [8];
749
778
  v128_t asrcv[8];
@@ -923,7 +952,6 @@ void wsp_quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t
923
952
 
924
953
  #elif defined(__loongarch_asx)
925
954
  for (int i = 0; i < nb; i++) {
926
- ft_union fi;
927
955
  __m256 v0 = (__m256)__lasx_xvld( x , 0);
928
956
  __m256 v1 = (__m256)__lasx_xvld( x , 32);
929
957
  __m256 v2 = (__m256)__lasx_xvld( x , 64);
@@ -941,8 +969,7 @@ void wsp_quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t
941
969
  max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vpickod_d((__m128i) max4, (__m128i)max4 ) );
942
970
  __m128 tmp = max4;
943
971
  max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vinsgr2vr_w(tmp, __lsx_vpickve2gr_w( max4, 1 ), 0 ));
944
- fi.i = __lsx_vpickve2gr_w( (__m128i)max4, 0 );
945
- const float max_scalar = fi.f;
972
+ const float max_scalar = ((v4f32)max4)[0];
946
973
 
947
974
  // Quantize these floats
948
975
  const float d = max_scalar / 127.f;
@@ -984,6 +1011,38 @@ void wsp_quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t
984
1011
  __lsx_vst(ni4, (__m128i *)(y[i].qs + 16), 0);
985
1012
 
986
1013
  }
1014
+ #elif defined(__VXE__) || defined(__VXE2__)
1015
+ for (int i = 0; i < nb; i++) {
1016
+ __vector float srcv [8];
1017
+ __vector float asrcv[8];
1018
+ __vector float amaxv[8];
1019
+
1020
+ for (int j = 0; j < 8; j++) srcv[j] = vec_xl(0, x + i*32 + 4*j);
1021
+ for (int j = 0; j < 8; j++) asrcv[j] = vec_abs(srcv[j]);
1022
+ for (int j = 0; j < 4; j++) amaxv[2*j] = vec_max(asrcv[2*j], asrcv[2*j+1]);
1023
+ for (int j = 0; j < 2; j++) amaxv[4*j] = vec_max(amaxv[4*j], amaxv[4*j+2]);
1024
+ for (int j = 0; j < 1; j++) amaxv[8*j] = vec_max(amaxv[8*j], amaxv[8*j+4]);
1025
+
1026
+ const float amax = MAX(MAX(vec_extract(amaxv[0], 0),
1027
+ vec_extract(amaxv[0], 1)),
1028
+ MAX(vec_extract(amaxv[0], 2),
1029
+ vec_extract(amaxv[0], 3)));
1030
+
1031
+ const float d = amax / ((1 << 7) - 1);
1032
+ const float id = d ? 1.0f / d : 0.0f;
1033
+
1034
+ y[i].d = WSP_GGML_FP32_TO_FP16(d);
1035
+
1036
+ for (int j = 0; j < 8; j++) {
1037
+ const __vector float v = vec_mul(srcv[j], vec_splats(id));
1038
+ const __vector int32_t vi = vec_signed(v);
1039
+
1040
+ y[i].qs[4*j + 0] = vec_extract(vi, 0);
1041
+ y[i].qs[4*j + 1] = vec_extract(vi, 1);
1042
+ y[i].qs[4*j + 2] = vec_extract(vi, 2);
1043
+ y[i].qs[4*j + 3] = vec_extract(vi, 3);
1044
+ }
1045
+ }
987
1046
  #else
988
1047
  WSP_GGML_UNUSED(nb);
989
1048
  // scalar
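For orientation (not part of the package diff): the new __VXE__/__VXE2__ branch added above follows the same Q8_0 recipe as the other SIMD paths in this function. A hedged scalar sketch of that recipe, with the block size and field names assumed from the surrounding code; rounding details vary slightly between the vector paths:

    #include <math.h>
    #include <stdint.h>

    typedef struct { float d; int8_t qs[32]; } q8_0_block_sketch;  /* stand-in for the real block_q8_0 */

    static void quantize_q8_0_sketch(const float x[32], q8_0_block_sketch * y) {
        float amax = 0.0f;                       /* absolute max over the 32-element block */
        for (int j = 0; j < 32; ++j) {
            const float ax = fabsf(x[j]);
            if (ax > amax) amax = ax;
        }
        const float d  = amax / 127.0f;          /* (1 << 7) - 1, as in the vector code */
        const float id = d ? 1.0f / d : 0.0f;
        y->d = d;                                /* the real code stores FP16 via WSP_GGML_FP32_TO_FP16 */
        for (int j = 0; j < 32; ++j) {
            y->qs[j] = (int8_t)roundf(x[j] * id);
        }
    }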
@@ -1033,7 +1092,7 @@ void wsp_quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t
1033
1092
 
1034
1093
  y[i].s = WSP_GGML_FP32_TO_FP16(d * vaddvq_s32(accv));
1035
1094
  }
1036
- #elif defined(__wasm_simd128__)
1095
+ #elif defined __wasm_simd128__
1037
1096
  for (int i = 0; i < nb; i++) {
1038
1097
  v128_t srcv [8];
1039
1098
  v128_t asrcv[8];
@@ -1247,7 +1306,6 @@ void wsp_quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t
1247
1306
 
1248
1307
  #elif defined(__loongarch_asx)
1249
1308
  for (int i = 0; i < nb; i++) {
1250
- ft_union ft;
1251
1309
  __m256 v0 = (__m256)__lasx_xvld( x , 0 );
1252
1310
  __m256 v1 = (__m256)__lasx_xvld( x , 32 );
1253
1311
  __m256 v2 = (__m256)__lasx_xvld( x , 64 );
@@ -1265,8 +1323,7 @@ void wsp_quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t
1265
1323
  max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vpickod_d((__m128i) max4, (__m128i)max4 ) );
1266
1324
  __m128 tmp = max4;
1267
1325
  max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vextrins_w((__m128i)tmp, (__m128i)max4, 0x10 ));
1268
- ft.i = __lsx_vpickve2gr_w( (__m128i)max4, 0 );
1269
- const float max_scalar = ft.f;
1326
+ const float max_scalar = ((v4f32)max4)[0];
1270
1327
 
1271
1328
  // Quantize these floats
1272
1329
  const float d = max_scalar / 127.f;
@@ -1312,6 +1369,44 @@ void wsp_quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t
1312
1369
  __lsx_vst(ni0, (__m128i *)(y[i].qs + 0), 0);
1313
1370
  __lsx_vst(ni4, (__m128i *)(y[i].qs + 16), 0);
1314
1371
  }
1372
+ #elif defined(__VXE__) || defined(__VXE2__)
1373
+ for (int i = 0; i < nb; i++) {
1374
+ __vector float srcv [8];
1375
+ __vector float asrcv[8];
1376
+ __vector float amaxv[8];
1377
+
1378
+ for (int j = 0; j < 8; j++) srcv[j] = vec_xl(0, x + i*32 + 4*j);
1379
+ for (int j = 0; j < 8; j++) asrcv[j] = vec_abs(srcv[j]);
1380
+ for (int j = 0; j < 4; j++) amaxv[2*j] = vec_max(asrcv[2*j], asrcv[2*j+1]);
1381
+ for (int j = 0; j < 2; j++) amaxv[4*j] = vec_max(amaxv[4*j], amaxv[4*j+2]);
1382
+ for (int j = 0; j < 1; j++) amaxv[8*j] = vec_max(amaxv[8*j], amaxv[8*j+4]);
1383
+
1384
+ const float amax = MAX(MAX(vec_extract(amaxv[0], 0),
1385
+ vec_extract(amaxv[0], 1)),
1386
+ MAX(vec_extract(amaxv[0], 2),
1387
+ vec_extract(amaxv[0], 3)));
1388
+
1389
+ const float d = amax / ((1 << 7) - 1);
1390
+ const float id = d ? 1.0f / d : 0.0f;
1391
+
1392
+ y[i].d = WSP_GGML_FP32_TO_FP16(d);
1393
+
1394
+ __vector int32_t acc = vec_splats(0);
1395
+
1396
+ for (int j = 0; j < 8; j++) {
1397
+ const __vector float v = vec_mul(srcv[j], vec_splats(id));
1398
+ const __vector int32_t vi = vec_signed(v);
1399
+
1400
+ y[i].qs[4*j + 0] = vec_extract(vi, 0);
1401
+ y[i].qs[4*j + 1] = vec_extract(vi, 1);
1402
+ y[i].qs[4*j + 2] = vec_extract(vi, 2);
1403
+ y[i].qs[4*j + 3] = vec_extract(vi, 3);
1404
+
1405
+ acc = vec_add(acc, vi);
1406
+ }
1407
+
1408
+ y[i].s = WSP_GGML_FP32_TO_FP16(d * (acc[0] + acc[1] + acc[2] + acc[3]));
1409
+ }
1315
1410
  #else
1316
1411
  WSP_GGML_UNUSED(nb);
1317
1412
  // scalar
@@ -1649,7 +1744,87 @@ static const int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -1
1649
1744
  //===================================== Q8_K ==============================================
1650
1745
 
1651
1746
  void wsp_quantize_row_q8_K(const float * restrict x, void * restrict y, int64_t k) {
1747
+ #ifdef __wasm_simd128__
1748
+ assert(k % QK_K == 0);
1749
+ const int64_t nb = k / QK_K;
1750
+ block_q8_K * restrict yc = y; // Cast to proper type
1751
+
1752
+ for (int i = 0; i < nb; i++) {
1753
+ const float * x_block = x + i * QK_K;
1754
+
1755
+ v128_t min_vec = wasm_v128_load(x_block);
1756
+ v128_t max_vec = min_vec;
1757
+
1758
+ for (int j = 4; j < QK_K; j += 4) {
1759
+ v128_t x_vec = wasm_v128_load(x_block + j);
1760
+ max_vec = wasm_f32x4_pmax(max_vec, x_vec);
1761
+ min_vec = wasm_f32x4_pmin(min_vec, x_vec);
1762
+ }
1763
+ max_vec = wasm_f32x4_pmax(max_vec, wasm_i32x4_shuffle(max_vec, max_vec, 2, 3, 0, 1));
1764
+ max_vec = wasm_f32x4_pmax(max_vec, wasm_i32x4_shuffle(max_vec, max_vec, 1, 0, 3, 2));
1765
+ min_vec = wasm_f32x4_pmin(min_vec, wasm_i32x4_shuffle(min_vec, min_vec, 2, 3, 0, 1));
1766
+ min_vec = wasm_f32x4_pmin(min_vec, wasm_i32x4_shuffle(min_vec, min_vec, 1, 0, 3, 2));
1767
+ float max = wasm_f32x4_extract_lane(max_vec, 0);
1768
+ float min = wasm_f32x4_extract_lane(min_vec, 0);
1769
+ float amax = -min > max ? min : max;
1770
+
1771
+ if (amax == 0.0f) {
1772
+ yc[i].d = 0.0f;
1773
+ const v128_t zero = wasm_i8x16_splat(0);
1774
+ for (int j = 0; j < QK_K; j += 16) {
1775
+ wasm_v128_store(yc[i].qs + j, zero);
1776
+ }
1777
+ continue;
1778
+ }
1779
+
1780
+ const float iscale = -127.0f / amax;
1781
+ const v128_t scale_vec = wasm_f32x4_splat(iscale);
1782
+
1783
+ // Process 16 elements per iteration
1784
+ for (int j = 0, jb = 0; j < QK_K; j += 16, jb++) {
1785
+ // Load and quantize 16 floats
1786
+ v128_t x0 = wasm_v128_load(x_block + j);
1787
+ v128_t x1 = wasm_v128_load(x_block + j + 4);
1788
+ v128_t x2 = wasm_v128_load(x_block + j + 8);
1789
+ v128_t x3 = wasm_v128_load(x_block + j + 12);
1790
+
1791
+ v128_t q0 = wasm_f32x4_nearest(wasm_f32x4_mul(x0, scale_vec));
1792
+ v128_t q1 = wasm_f32x4_nearest(wasm_f32x4_mul(x1, scale_vec));
1793
+ v128_t q2 = wasm_f32x4_nearest(wasm_f32x4_mul(x2, scale_vec));
1794
+ v128_t q3 = wasm_f32x4_nearest(wasm_f32x4_mul(x3, scale_vec));
1795
+
1796
+ // Convert to i32 with saturation
1797
+ v128_t i0 = wasm_i32x4_trunc_sat_f32x4(q0);
1798
+ v128_t i1 = wasm_i32x4_trunc_sat_f32x4(q1);
1799
+ v128_t i2 = wasm_i32x4_trunc_sat_f32x4(q2);
1800
+ v128_t i3 = wasm_i32x4_trunc_sat_f32x4(q3);
1801
+
1802
+ // Pack into 16 i8 values
1803
+ v128_t i8 = wasm_i8x16_narrow_i16x8(
1804
+ wasm_i16x8_narrow_i32x4(i0, i1),
1805
+ wasm_i16x8_narrow_i32x4(i2, i3)
1806
+ );
1807
+ wasm_v128_store(yc[i].qs + j, i8);
1808
+
1809
+ // Calculate bsums using SIMD
1810
+ v128_t sum16 = wasm_i16x8_add(
1811
+ wasm_i16x8_extend_low_i8x16(i8),
1812
+ wasm_i16x8_extend_high_i8x16(i8)
1813
+ );
1814
+ v128_t sum32 = wasm_i32x4_add(
1815
+ wasm_i32x4_extend_low_i16x8(sum16),
1816
+ wasm_i32x4_extend_high_i16x8(sum16)
1817
+ );
1818
+ sum32 = wasm_i32x4_add(sum32, wasm_i32x4_shuffle(sum32, sum32, 2, 3, 0, 1));
1819
+ sum32 = wasm_i32x4_add(sum32, wasm_i32x4_shuffle(sum32, sum32, 1, 0, 3, 2));
1820
+ yc[i].bsums[jb] = wasm_i32x4_extract_lane(sum32, 0);
1821
+ }
1822
+
1823
+ yc[i].d = 1.0f / iscale;
1824
+ }
1825
+ #else
1652
1826
  wsp_quantize_row_q8_K_ref(x, y, k);
1827
+ #endif
1653
1828
  }
1654
1829
 
1655
1830
  //===================================== Dot products =================================
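For orientation (not part of the package diff): the new __wasm_simd128__ path added above vectorizes the same scheme as the wsp_quantize_row_q8_K_ref fallback it now bypasses. A hedged scalar sketch of that scheme, assuming QK_K == 256 and the usual block_q8_K layout (d, qs, bsums):

    #include <math.h>
    #include <stdint.h>

    #define QK_K_SKETCH 256      /* assumption: QK_K == 256 in this build */

    typedef struct {
        float   d;
        int8_t  qs[QK_K_SKETCH];
        int16_t bsums[QK_K_SKETCH/16];
    } q8_K_block_sketch;         /* stand-in for the real block_q8_K */

    static void quantize_q8_K_sketch(const float * x, q8_K_block_sketch * y) {
        float max = x[0], min = x[0];
        for (int j = 1; j < QK_K_SKETCH; ++j) {
            if (x[j] > max) max = x[j];
            if (x[j] < min) min = x[j];
        }
        const float amax = -min > max ? min : max;   /* value with the largest magnitude, sign kept */
        if (amax == 0.0f) {                          /* all-zero block: d = 0, quants = 0 */
            y->d = 0.0f;
            for (int j = 0; j < QK_K_SKETCH;    ++j) y->qs[j]    = 0;
            for (int j = 0; j < QK_K_SKETCH/16; ++j) y->bsums[j] = 0;
            return;
        }
        const float iscale = -127.0f / amax;
        y->d = 1.0f / iscale;
        for (int jb = 0; jb < QK_K_SKETCH/16; ++jb) {
            int sum = 0;
            for (int j = 16*jb; j < 16*(jb + 1); ++j) {
                const int q = (int)nearbyintf(x[j] * iscale);  /* magnitude stays within 127 by construction */
                y->qs[j] = (int8_t)q;
                sum += q;
            }
            y->bsums[jb] = (int16_t)sum;             /* per-16-element sums used by the K-quant dot products */
        }
    }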
@@ -1791,11 +1966,12 @@ void wsp_ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void
1791
1966
  const int8x16_t y1_l = vld1q_s8(b_y1->qs);
1792
1967
  const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16);
1793
1968
 
1794
- float32_t _scale[4] = { WSP_GGML_FP16_TO_FP32(b_x0->d)*WSP_GGML_FP16_TO_FP32(b_y0->d),
1795
- WSP_GGML_FP16_TO_FP32(b_x0->d)*WSP_GGML_FP16_TO_FP32(b_y1->d),
1796
- WSP_GGML_FP16_TO_FP32(b_x1->d)*WSP_GGML_FP16_TO_FP32(b_y0->d),
1797
- WSP_GGML_FP16_TO_FP32(b_x1->d)*WSP_GGML_FP16_TO_FP32(b_y1->d)};
1798
-
1969
+ float32_t _scale[4] = {
1970
+ WSP_GGML_FP16_TO_FP32(b_x0->d)*WSP_GGML_FP16_TO_FP32(b_y0->d),
1971
+ WSP_GGML_FP16_TO_FP32(b_x0->d)*WSP_GGML_FP16_TO_FP32(b_y1->d),
1972
+ WSP_GGML_FP16_TO_FP32(b_x1->d)*WSP_GGML_FP16_TO_FP32(b_y0->d),
1973
+ WSP_GGML_FP16_TO_FP32(b_x1->d)*WSP_GGML_FP16_TO_FP32(b_y1->d)
1974
+ };
1799
1975
  float32x4_t scale = vld1q_f32(_scale);
1800
1976
 
1801
1977
  int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
@@ -1811,13 +1987,15 @@ void wsp_ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void
1811
1987
  int8x16_t r3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));
1812
1988
 
1813
1989
  sumv0 = vmlaq_f32(sumv0,(vcvtq_f32_s32(vmmlaq_s32((vmmlaq_s32((vmmlaq_s32((vmmlaq_s32(vdupq_n_s32(0), l0, r0)),
1814
- l1, r1)), l2, r2)), l3, r3))), scale);
1990
+ l1, r1)), l2, r2)), l3, r3))), scale);
1815
1991
  }
1816
- float32x4_t sumv1 = vextq_f32(sumv0, sumv0, 2);
1992
+
1993
+ float32x4_t sumv1 = vextq_f32 (sumv0, sumv0, 2);
1817
1994
  float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1);
1818
1995
 
1819
- vst1_f32(s, vget_low_f32(sumv2));
1996
+ vst1_f32(s, vget_low_f32 (sumv2));
1820
1997
  vst1_f32(s + bs, vget_high_f32(sumv2));
1998
+
1821
1999
  return;
1822
2000
  }
1823
2001
  #endif
@@ -2004,6 +2182,94 @@ void wsp_ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void
2004
2182
  }
2005
2183
 
2006
2184
  sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
2185
+ #elif defined __wasm_simd128__
2186
+ v128_t sumv = wasm_f32x4_splat(0.0f);
2187
+
2188
+ const v128_t m4b = wasm_i8x16_splat(0x0F);
2189
+ const v128_t s8b = wasm_i8x16_splat(0x8);
2190
+
2191
+ for (; ib + 1 < nb; ib += 2) {
2192
+ const block_q4_0 * restrict x0 = &x[ib];
2193
+ const block_q4_0 * restrict x1 = &x[ib + 1];
2194
+ const block_q8_0 * restrict y0 = &y[ib];
2195
+ const block_q8_0 * restrict y1 = &y[ib + 1];
2196
+
2197
+ // Load and process x0
2198
+ v128_t v0_0 = wasm_v128_load(x0->qs);
2199
+ v128_t v0_0l = wasm_v128_and(v0_0, m4b);
2200
+ v128_t v0_0h = wasm_u8x16_shr(v0_0, 4);
2201
+ v128_t v0_0ls = wasm_i8x16_sub(v0_0l, s8b);
2202
+ v128_t v0_0hs = wasm_i8x16_sub(v0_0h, s8b);
2203
+
2204
+ // Load y0 vectors
2205
+ v128_t y0_l = wasm_v128_load(y0->qs);
2206
+ v128_t y0_h = wasm_v128_load(y0->qs + 16);
2207
+
2208
+ // Extend to i16x8 and compute dot products
2209
+ v128_t dx0l = wasm_i16x8_extend_low_i8x16(v0_0ls);
2210
+ v128_t dx0h = wasm_i16x8_extend_high_i8x16(v0_0ls);
2211
+ v128_t dx0hl = wasm_i16x8_extend_low_i8x16(v0_0hs);
2212
+ v128_t dx0hh = wasm_i16x8_extend_high_i8x16(v0_0hs);
2213
+
2214
+ v128_t dy0ll = wasm_i16x8_extend_low_i8x16(y0_l);
2215
+ v128_t dy0lh = wasm_i16x8_extend_high_i8x16(y0_l);
2216
+ v128_t dy0hl = wasm_i16x8_extend_low_i8x16(y0_h);
2217
+ v128_t dy0hh = wasm_i16x8_extend_high_i8x16(y0_h);
2218
+
2219
+ v128_t dp0 = wasm_i32x4_add(
2220
+ wasm_i32x4_add(
2221
+ wasm_i32x4_dot_i16x8(dx0l, dy0ll),
2222
+ wasm_i32x4_dot_i16x8(dx0h, dy0lh)
2223
+ ),
2224
+ wasm_i32x4_add(
2225
+ wasm_i32x4_dot_i16x8(dx0hl, dy0hl),
2226
+ wasm_i32x4_dot_i16x8(dx0hh, dy0hh)
2227
+ )
2228
+ );
2229
+
2230
+ // Load and process x1
2231
+ v128_t v0_1 = wasm_v128_load(x1->qs);
2232
+ v128_t v0_1l = wasm_v128_and(v0_1, m4b);
2233
+ v128_t v0_1h = wasm_u8x16_shr(v0_1, 4);
2234
+ v128_t v0_1ls = wasm_i8x16_sub(v0_1l, s8b);
2235
+ v128_t v0_1hs = wasm_i8x16_sub(v0_1h, s8b);
2236
+
2237
+ // Load y1 vectors
2238
+ v128_t y1_l = wasm_v128_load(y1->qs);
2239
+ v128_t y1_h = wasm_v128_load(y1->qs + 16);
2240
+
2241
+ // Extend to i16x8 and compute dot products
2242
+ v128_t dx1l = wasm_i16x8_extend_low_i8x16(v0_1ls);
2243
+ v128_t dx1h = wasm_i16x8_extend_high_i8x16(v0_1ls);
2244
+ v128_t dx1hl = wasm_i16x8_extend_low_i8x16(v0_1hs);
2245
+ v128_t dx1hh = wasm_i16x8_extend_high_i8x16(v0_1hs);
2246
+
2247
+ v128_t dy1ll = wasm_i16x8_extend_low_i8x16(y1_l);
2248
+ v128_t dy1lh = wasm_i16x8_extend_high_i8x16(y1_l);
2249
+ v128_t dy1hl = wasm_i16x8_extend_low_i8x16(y1_h);
2250
+ v128_t dy1hh = wasm_i16x8_extend_high_i8x16(y1_h);
2251
+
2252
+ v128_t dp1 = wasm_i32x4_add(
2253
+ wasm_i32x4_add(
2254
+ wasm_i32x4_dot_i16x8(dx1l, dy1ll),
2255
+ wasm_i32x4_dot_i16x8(dx1h, dy1lh)
2256
+ ),
2257
+ wasm_i32x4_add(
2258
+ wasm_i32x4_dot_i16x8(dx1hl, dy1hl),
2259
+ wasm_i32x4_dot_i16x8(dx1hh, dy1hh)
2260
+ )
2261
+ );
2262
+
2263
+ // Accumulate results with scaling
2264
+ float scale0 = WSP_GGML_FP16_TO_FP32(x0->d) * WSP_GGML_FP16_TO_FP32(y0->d);
2265
+ float scale1 = WSP_GGML_FP16_TO_FP32(x1->d) * WSP_GGML_FP16_TO_FP32(y1->d);
2266
+
2267
+ sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4(dp0), wasm_f32x4_splat(scale0)));
2268
+ sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4(dp1), wasm_f32x4_splat(scale1)));
2269
+ }
2270
+
2271
+ sumf = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) +
2272
+ wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3);
2007
2273
  #elif defined(__AVX2__)
2008
2274
  // Initialize accumulator with zeros
2009
2275
  __m256 acc = _mm256_setzero_ps();
@@ -2225,21 +2491,22 @@ void wsp_ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void
2225
2491
  }
2226
2492
 
2227
2493
  sumf = hsum_float_8(acc);
2494
+
2228
2495
  #elif defined(__loongarch_sx)
2229
2496
  // set constants
2230
2497
  const __m128i low_mask = __lsx_vreplgr2vr_b(0xF);
2231
2498
  const __m128i off = __lsx_vreplgr2vr_b(8);
2232
2499
 
2233
2500
  // Initialize accumulator with zeros
2234
- __m128 acc_0 = __lsx_vldi(0);
2235
- __m128 acc_1 = __lsx_vldi(0);
2236
- __m128 acc_2 = __lsx_vldi(0);
2237
- __m128 acc_3 = __lsx_vldi(0);
2501
+ __m128 acc_0 = (__m128)__lsx_vldi(0);
2502
+ __m128 acc_1 = (__m128)__lsx_vldi(0);
2503
+ __m128 acc_2 = (__m128)__lsx_vldi(0);
2504
+ __m128 acc_3 = (__m128)__lsx_vldi(0);
2238
2505
 
2239
2506
  for (; ib + 1 < nb; ib += 2) {
2240
2507
 
2241
2508
  // Compute combined scale for the block 0 and 1
2242
- const __m128 d_0_1 = __lsx_vreplgr2vr_w( WSP_GGML_FP16_TO_FP32(x[ib].d) * WSP_GGML_FP16_TO_FP32(y[ib].d) );
2509
+ const __m128 d_0_1 = (__m128)__lsx_vreplgr2vr_w( WSP_GGML_FP16_TO_FP32(x[ib].d) * WSP_GGML_FP16_TO_FP32(y[ib].d) );
2243
2510
 
2244
2511
  const __m128i tmp_0_1 = __lsx_vld((const __m128i *)x[ib].qs, 0);
2245
2512
 
@@ -2257,7 +2524,7 @@ void wsp_ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void
2257
2524
  //_mm_prefetch(&y[ib] + 2 * sizeof(block_q8_0), _MM_HINT_T0);
2258
2525
 
2259
2526
  // Compute combined scale for the block 2 and 3
2260
- const __m128 d_2_3 = __lsx_vreplgr2vr_w( WSP_GGML_FP16_TO_FP32(x[ib + 1].d) * WSP_GGML_FP16_TO_FP32(y[ib + 1].d) );
2527
+ const __m128 d_2_3 = (__m128)__lsx_vreplgr2vr_w( WSP_GGML_FP16_TO_FP32(x[ib + 1].d) * WSP_GGML_FP16_TO_FP32(y[ib + 1].d) );
2261
2528
 
2262
2529
  const __m128i tmp_2_3 = __lsx_vld((const __m128i *)x[ib + 1].qs, 0);
2263
2530
 
@@ -2291,6 +2558,37 @@ void wsp_ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void
2291
2558
  }
2292
2559
 
2293
2560
  sumf = hsum_float_4x4(acc_0, acc_1, acc_2, acc_3);
2561
+ #elif defined(__VXE__) || defined(__VXE2__)
2562
+ __vector float acc = vec_splats(0.0f);
2563
+
2564
+ const __vector uint8_t v_m = vec_splats((const uint8_t)0x0F);
2565
+ const __vector int8_t v_s = vec_splats( (const int8_t)0x08);
2566
+
2567
+ for (; ib < nb; ++ib) {
2568
+ const __vector uint8_t v_x = vec_xl(0, x[ib].qs);
2569
+ const __vector int8_t v_xl = (const __vector int8_t)(v_x & v_m);
2570
+ const __vector int8_t v_xh = (const __vector int8_t)(v_x >> 4);
2571
+
2572
+ const __vector int8_t v_xls = vec_sub(v_xl, v_s);
2573
+ const __vector int8_t v_xhs = vec_sub(v_xh, v_s);
2574
+
2575
+ const __vector int8_t v_yl = vec_xl(0 , y[ib].qs);
2576
+ const __vector int8_t v_yh = vec_xl(QK8_0/2, y[ib].qs);
2577
+
2578
+ const __vector int16_t v_xylso = vec_mulo(v_xls, v_yl);
2579
+ const __vector int16_t v_xylse = vec_mule(v_xls, v_yl);
2580
+ const __vector int16_t v_xyhso = vec_mulo(v_xhs, v_yh);
2581
+ const __vector int16_t v_xyhse = vec_mule(v_xhs, v_yh);
2582
+
2583
+ __vector int16_t v_xy_ = v_xylso + v_xylse + v_xyhso + v_xyhse; v_xy_ += vec_reve(v_xy_);
2584
+
2585
+ const __vector float v_xy = vec_float(vec_unpackh(v_xy_));
2586
+ const __vector float v_d = vec_splats(WSP_GGML_FP16_TO_FP32(x[ib].d) * WSP_GGML_FP16_TO_FP32(y[ib].d));
2587
+
2588
+ acc = vec_madd(v_xy, v_d, acc);
2589
+ }
2590
+
2591
+ sumf = acc[0] + acc[1] + acc[2] + acc[3];
2294
2592
  #endif
2295
2593
  for (; ib < nb; ++ib) {
2296
2594
  int sumi0 = 0;
@@ -2345,10 +2643,12 @@ void wsp_ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void
2345
2643
  const block_q8_1 * restrict b_y0 = &vy0[i];
2346
2644
  const block_q8_1 * restrict b_y1 = &vy1[i];
2347
2645
 
2348
- float32_t summs_t[4] = {WSP_GGML_FP16_TO_FP32(b_x0->m) * WSP_GGML_FP16_TO_FP32(b_y0->s),
2349
- WSP_GGML_FP16_TO_FP32(b_x1->m) * WSP_GGML_FP16_TO_FP32(b_y0->s),
2350
- WSP_GGML_FP16_TO_FP32(b_x0->m) * WSP_GGML_FP16_TO_FP32(b_y1->s),
2351
- WSP_GGML_FP16_TO_FP32(b_x1->m) * WSP_GGML_FP16_TO_FP32(b_y1->s)};
2646
+ float32_t summs_t[4] = {
2647
+ WSP_GGML_FP16_TO_FP32(b_x0->m) * WSP_GGML_FP16_TO_FP32(b_y0->s),
2648
+ WSP_GGML_FP16_TO_FP32(b_x1->m) * WSP_GGML_FP16_TO_FP32(b_y0->s),
2649
+ WSP_GGML_FP16_TO_FP32(b_x0->m) * WSP_GGML_FP16_TO_FP32(b_y1->s),
2650
+ WSP_GGML_FP16_TO_FP32(b_x1->m) * WSP_GGML_FP16_TO_FP32(b_y1->s)
2651
+ };
2352
2652
  summs0 = vaddq_f32(summs0, vld1q_f32(summs_t));
2353
2653
 
2354
2654
  const uint8x16_t m4b = vdupq_n_u8(0x0F);
@@ -2369,10 +2669,12 @@ void wsp_ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void
2369
2669
  const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16);
2370
2670
 
2371
2671
  // mmla into int32x4_t
2372
- float32_t _scale[4] = {WSP_GGML_FP16_TO_FP32(b_x0->d)*b_y0->d,
2373
- WSP_GGML_FP16_TO_FP32(b_x0->d)*b_y1->d,
2374
- WSP_GGML_FP16_TO_FP32(b_x1->d)*b_y0->d,
2375
- WSP_GGML_FP16_TO_FP32(b_x1->d)*b_y1->d};
2672
+ float32_t _scale[4] = {
2673
+ WSP_GGML_FP16_TO_FP32(b_x0->d)*WSP_GGML_FP16_TO_FP32(b_y0->d),
2674
+ WSP_GGML_FP16_TO_FP32(b_x0->d)*WSP_GGML_FP16_TO_FP32(b_y1->d),
2675
+ WSP_GGML_FP16_TO_FP32(b_x1->d)*WSP_GGML_FP16_TO_FP32(b_y0->d),
2676
+ WSP_GGML_FP16_TO_FP32(b_x1->d)*WSP_GGML_FP16_TO_FP32(b_y1->d)
2677
+ };
2376
2678
  float32x4_t scale = vld1q_f32(_scale);
2377
2679
 
2378
2680
  int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
@@ -2387,15 +2689,17 @@ void wsp_ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void
2387
2689
  int8x16_t r2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));
2388
2690
  int8x16_t r3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));
2389
2691
  sumv0 = vmlaq_f32(sumv0,(vcvtq_f32_s32(vmmlaq_s32((vmmlaq_s32((vmmlaq_s32((vmmlaq_s32(vdupq_n_s32(0), l0, r0)),
2390
- l1, r1)), l2, r2)), l3, r3))), scale);
2692
+ l1, r1)), l2, r2)), l3, r3))), scale);
2391
2693
  }
2392
2694
 
2393
- float32x4_t sumv1 = vextq_f32(sumv0, sumv0, 2);
2695
+ float32x4_t sumv1 = vextq_f32 (sumv0, sumv0, 2);
2394
2696
  float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1);
2697
+
2395
2698
  sumv2 = vaddq_f32(sumv2, summs0);
2396
2699
 
2397
2700
  vst1_f32(s, vget_low_f32 (sumv2));
2398
2701
  vst1_f32(s + bs, vget_high_f32(sumv2));
2702
+
2399
2703
  return;
2400
2704
  }
2401
2705
  #endif
@@ -2578,13 +2882,42 @@ void wsp_ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void
2578
2882
  }
2579
2883
 
2580
2884
  sumf = hsum_float_8(acc) + summs;
2581
- #endif
2582
- for (; ib < nb; ++ib) {
2583
- int sumi0 = 0;
2584
- int sumi1 = 0;
2885
+ #elif defined(__VXE__) || defined(__VXE2__)
2886
+ float summs = 0;
2887
+ float32x4_t acc = vec_splats(0.0f);
2585
2888
 
2586
- for (int j = 0; j < qk/2; ++j) {
2587
- const int v0 = (x[ib].qs[j] & 0x0F);
2889
+ const uint8x16_t v_m = vec_splat_u8(0x0F);
2890
+
2891
+ #pragma GCC unroll 4
2892
+ for (; ib < nb; ++ib) {
2893
+ __builtin_prefetch(x[ib].qs, 0, 1);
2894
+ __builtin_prefetch(y[ib].qs, 0, 1);
2895
+
2896
+ summs += WSP_GGML_FP16_TO_FP32(x[ib].m) * WSP_GGML_FP16_TO_FP32(y[ib].s);
2897
+
2898
+ const uint8x16_t v_x = vec_xl(0, x[ib].qs);
2899
+ const int8x16_t v_xl = (const int8x16_t)(v_x & v_m);
2900
+ const int8x16_t v_xh = (const int8x16_t)(v_x >> 4);
2901
+
2902
+ const int8x16_t v_yl = vec_xl(0 , y[ib].qs);
2903
+ const int8x16_t v_yh = vec_xl(QK8_1/2, y[ib].qs);
2904
+
2905
+ const int32x4_t v_xy_ = wsp_ggml_vec_dot(wsp_ggml_vec_dot(vec_splats(0), v_xl, v_yl), v_xh, v_yh);
2906
+ const float32x4_t v_xy = vec_float(v_xy_);
2907
+
2908
+ const float32x4_t v_d = vec_splats(WSP_GGML_FP16_TO_FP32(x[ib].d) * WSP_GGML_FP16_TO_FP32(y[ib].d));
2909
+
2910
+ acc = vec_madd(v_xy, v_d, acc);
2911
+ }
2912
+
2913
+ sumf = acc[0] + acc[1] + acc[2] + acc[3] + summs;
2914
+ #endif
2915
+ for (; ib < nb; ++ib) {
2916
+ int sumi0 = 0;
2917
+ int sumi1 = 0;
2918
+
2919
+ for (int j = 0; j < qk/2; ++j) {
2920
+ const int v0 = (x[ib].qs[j] & 0x0F);
2588
2921
  const int v1 = (x[ib].qs[j] >> 4);
2589
2922
 
2590
2923
  sumi0 += (v0 * y[ib].qs[j]);
@@ -2683,10 +3016,10 @@ void wsp_ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void
2683
3016
  }
2684
3017
 
2685
3018
  sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
2686
- #elif defined(__wasm_simd128__)
3019
+ #elif defined __wasm_simd128__
2687
3020
  v128_t sumv = wasm_f32x4_splat(0.0f);
2688
3021
 
2689
- uint32_t qh;
3022
+ uint32_t qh_;
2690
3023
  uint64_t tmp[4];
2691
3024
 
2692
3025
  // TODO: check if unrolling this is better
@@ -2697,12 +3030,12 @@ void wsp_ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void
2697
3030
  const v128_t m4b = wasm_i8x16_splat(0x0F);
2698
3031
 
2699
3032
  // extract the 5th bit
2700
- memcpy(&qh, x0->qh, sizeof(qh));
3033
+ memcpy(&qh_, x0->qh, sizeof(qh_));
2701
3034
 
2702
- tmp[0] = table_b2b_1[(qh >> 0) & 0xFF];
2703
- tmp[1] = table_b2b_1[(qh >> 8) & 0xFF];
2704
- tmp[2] = table_b2b_1[(qh >> 16) & 0xFF];
2705
- tmp[3] = table_b2b_1[(qh >> 24) ];
3035
+ tmp[0] = table_b2b_1[(qh_ >> 0) & 0xFF];
3036
+ tmp[1] = table_b2b_1[(qh_ >> 8) & 0xFF];
3037
+ tmp[2] = table_b2b_1[(qh_ >> 16) & 0xFF];
3038
+ tmp[3] = table_b2b_1[(qh_ >> 24) ];
2706
3039
 
2707
3040
  const v128_t qhl = wasm_v128_load(tmp + 0);
2708
3041
  const v128_t qhh = wasm_v128_load(tmp + 2);
@@ -3044,12 +3377,12 @@ void wsp_ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void
3044
3377
  }
3045
3378
 
3046
3379
  sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs0 + summs1;
3047
- #elif defined(__wasm_simd128__)
3380
+ #elif defined __wasm_simd128__
3048
3381
  v128_t sumv = wasm_f32x4_splat(0.0f);
3049
3382
 
3050
3383
  float summs = 0.0f;
3051
3384
 
3052
- uint32_t qh;
3385
+ uint32_t qh_;
3053
3386
  uint64_t tmp[4];
3054
3387
 
3055
3388
  // TODO: check if unrolling this is better
@@ -3062,12 +3395,12 @@ void wsp_ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void
3062
3395
  const v128_t m4b = wasm_i8x16_splat(0x0F);
3063
3396
 
3064
3397
  // extract the 5th bit
3065
- memcpy(&qh, x0->qh, sizeof(qh));
3398
+ memcpy(&qh_, x0->qh, sizeof(qh_));
3066
3399
 
3067
- tmp[0] = table_b2b_0[(qh >> 0) & 0xFF];
3068
- tmp[1] = table_b2b_0[(qh >> 8) & 0xFF];
3069
- tmp[2] = table_b2b_0[(qh >> 16) & 0xFF];
3070
- tmp[3] = table_b2b_0[(qh >> 24) ];
3400
+ tmp[0] = table_b2b_0[(qh_ >> 0) & 0xFF];
3401
+ tmp[1] = table_b2b_0[(qh_ >> 8) & 0xFF];
3402
+ tmp[2] = table_b2b_0[(qh_ >> 16) & 0xFF];
3403
+ tmp[3] = table_b2b_0[(qh_ >> 24) ];
3071
3404
 
3072
3405
  const v128_t qhl = wasm_v128_load(tmp + 0);
3073
3406
  const v128_t qhh = wasm_v128_load(tmp + 2);
@@ -3372,10 +3705,12 @@ void wsp_ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void
3372
3705
  const int8x16_t y1_l = vld1q_s8(b_y1->qs);
3373
3706
  const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16);
3374
3707
 
3375
- float32_t _scale[4] = {WSP_GGML_FP16_TO_FP32(b_x0->d)*WSP_GGML_FP16_TO_FP32(b_y0->d),
3376
- WSP_GGML_FP16_TO_FP32(b_x0->d)*WSP_GGML_FP16_TO_FP32(b_y1->d),
3377
- WSP_GGML_FP16_TO_FP32(b_x1->d)*WSP_GGML_FP16_TO_FP32(b_y0->d),
3378
- WSP_GGML_FP16_TO_FP32(b_x1->d)*WSP_GGML_FP16_TO_FP32(b_y1->d)};
3708
+ float32_t _scale[4] = {
3709
+ WSP_GGML_FP16_TO_FP32(b_x0->d)*WSP_GGML_FP16_TO_FP32(b_y0->d),
3710
+ WSP_GGML_FP16_TO_FP32(b_x0->d)*WSP_GGML_FP16_TO_FP32(b_y1->d),
3711
+ WSP_GGML_FP16_TO_FP32(b_x1->d)*WSP_GGML_FP16_TO_FP32(b_y0->d),
3712
+ WSP_GGML_FP16_TO_FP32(b_x1->d)*WSP_GGML_FP16_TO_FP32(b_y1->d)
3713
+ };
3379
3714
  float32x4_t scale = vld1q_f32(_scale);
3380
3715
 
3381
3716
  int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
@@ -3391,13 +3726,15 @@ void wsp_ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void
3391
3726
  int8x16_t r3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));
3392
3727
 
3393
3728
  sumv0 = vmlaq_f32(sumv0,(vcvtq_f32_s32(vmmlaq_s32((vmmlaq_s32((vmmlaq_s32((vmmlaq_s32(vdupq_n_s32(0), l0, r0)),
3394
- l1, r1)), l2, r2)), l3, r3))), scale);
3729
+ l1, r1)), l2, r2)), l3, r3))), scale);
3395
3730
  }
3396
- float32x4_t sumv1 = vextq_f32(sumv0, sumv0, 2);
3731
+
3732
+ float32x4_t sumv1 = vextq_f32 (sumv0, sumv0, 2);
3397
3733
  float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1);
3398
3734
 
3399
- vst1_f32(s, vget_low_f32(sumv2));
3735
+ vst1_f32(s, vget_low_f32 (sumv2));
3400
3736
  vst1_f32(s + bs, vget_high_f32(sumv2));
3737
+
3401
3738
  return;
3402
3739
  }
3403
3740
  #endif
@@ -3556,6 +3893,45 @@ void wsp_ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void
3556
3893
  }
3557
3894
 
3558
3895
  sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
3896
+ #elif defined __wasm_simd128__
3897
+ v128_t sumv = wasm_f32x4_splat(0.0f);
3898
+
3899
+ for (; ib < nb; ++ib) {
3900
+ const block_q8_0 * restrict x0 = &x[ib];
3901
+ const block_q8_0 * restrict y0 = &y[ib];
3902
+
3903
+ const v128_t x0_0 = wasm_v128_load(x0->qs);
3904
+ const v128_t x0_1 = wasm_v128_load(x0->qs + 16);
3905
+ const v128_t y0_0 = wasm_v128_load(y0->qs);
3906
+ const v128_t y0_1 = wasm_v128_load(y0->qs + 16);
3907
+
3908
+ // Extend 8-bit to 16-bit
3909
+ const v128_t x0_0l = wasm_i16x8_extend_low_i8x16(x0_0);
3910
+ const v128_t x0_0h = wasm_i16x8_extend_high_i8x16(x0_0);
3911
+ const v128_t x0_1l = wasm_i16x8_extend_low_i8x16(x0_1);
3912
+ const v128_t x0_1h = wasm_i16x8_extend_high_i8x16(x0_1);
3913
+
3914
+ const v128_t y0_0l = wasm_i16x8_extend_low_i8x16(y0_0);
3915
+ const v128_t y0_0h = wasm_i16x8_extend_high_i8x16(y0_0);
3916
+ const v128_t y0_1l = wasm_i16x8_extend_low_i8x16(y0_1);
3917
+ const v128_t y0_1h = wasm_i16x8_extend_high_i8x16(y0_1);
3918
+
3919
+ // Compute dot products
3920
+ const v128_t dx0_0 = wasm_i32x4_dot_i16x8(x0_0l, y0_0l);
3921
+ const v128_t dx0_1 = wasm_i32x4_dot_i16x8(x0_0h, y0_0h);
3922
+ const v128_t dx1_0 = wasm_i32x4_dot_i16x8(x0_1l, y0_1l);
3923
+ const v128_t dx1_1 = wasm_i32x4_dot_i16x8(x0_1h, y0_1h);
3924
+
3925
+ // Sum all dot products
3926
+ const v128_t sum_dots = wasm_i32x4_add(wasm_i32x4_add(dx0_0, dx0_1), wasm_i32x4_add(dx1_0, dx1_1));
3927
+
3928
+ // Convert to float and accumulate
3929
+ const float scale = WSP_GGML_FP16_TO_FP32(x0->d) * WSP_GGML_FP16_TO_FP32(y0->d);
3930
+ sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4(sum_dots), wasm_f32x4_splat(scale)));
3931
+ }
3932
+
3933
+ sumf = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) +
3934
+ wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3);
3559
3935
  #elif defined(__AVX2__)
3560
3936
  // Initialize accumulator with zeros
3561
3937
  __m256 acc = _mm256_setzero_ps();
@@ -3669,6 +4045,27 @@ void wsp_ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void
3669
4045
  }
3670
4046
 
3671
4047
  sumf = hsum_float_8(acc);
4048
+ #elif defined(__VXE__) || defined(__VXE2__)
4049
+ __vector float acc = vec_splats(0.0f);
4050
+
4051
+ #pragma GCC unroll 8
4052
+ for (; ib < nb; ++ib) {
4053
+ __builtin_prefetch(x[ib].qs, 0, 1);
4054
+ __builtin_prefetch(y[ib].qs, 0, 1);
4055
+
4056
+ const int8x16_t v_xl = vec_xl(0 , x[ib].qs);
4057
+ const int8x16_t v_xh = vec_xl(QK8_0/2, x[ib].qs);
4058
+ const int8x16_t v_yl = vec_xl(0 , y[ib].qs);
4059
+ const int8x16_t v_yh = vec_xl(QK8_0/2, y[ib].qs);
4060
+
4061
+ const int32x4_t v_xy_ = wsp_ggml_vec_dot(wsp_ggml_vec_dot(vec_splats(0), v_xl, v_yl), v_xh, v_yh);
4062
+ const float32x4_t v_xy = vec_float(v_xy_);
4063
+ const float32x4_t v_d = vec_splats(WSP_GGML_FP16_TO_FP32(x[ib].d) * WSP_GGML_FP16_TO_FP32(y[ib].d));
4064
+
4065
+ acc = vec_madd(v_xy, v_d, acc);
4066
+ }
4067
+
4068
+ sumf = acc[0] + acc[1] + acc[2] + acc[3];
3672
4069
  #endif
3673
4070
  for (; ib < nb; ++ib) {
3674
4071
  int sumi = 0;
@@ -4430,6 +4827,106 @@ void wsp_ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void
4430
4827
 
4431
4828
  *s = hsum_float_8(acc);
4432
4829
 
4830
+ #elif defined __wasm_simd128__
4831
+ float sumf = 0;
4832
+
4833
+ for (int i = 0; i < nb; ++i) {
4834
+ const uint8_t * q2 = x[i].qs;
4835
+ const int8_t * q8 = y[i].qs;
4836
+ const uint8_t * sc = x[i].scales;
4837
+
4838
+ // Vectorized summs calculation
4839
+ v128_t summs_vec = wasm_i32x4_splat(0);
4840
+ {
4841
+ v128_t sc_vec = wasm_v128_load(sc);
4842
+ v128_t sc_upper = wasm_u8x16_shr(sc_vec, 4);
4843
+
4844
+ v128_t sc_low = wasm_u16x8_extend_low_u8x16(sc_upper);
4845
+ v128_t sc_high = wasm_u16x8_extend_high_u8x16(sc_upper);
4846
+
4847
+ v128_t bsums1 = wasm_v128_load(&y[i].bsums[0]);
4848
+ v128_t bsums2 = wasm_v128_load(&y[i].bsums[8]);
4849
+
4850
+ summs_vec = wasm_i32x4_add(
4851
+ wasm_i32x4_add(wasm_i32x4_dot_i16x8(sc_low, bsums1),
4852
+ wasm_i32x4_dot_i16x8(sc_high, bsums2)),
4853
+ summs_vec
4854
+ );
4855
+
4856
+ summs_vec = wasm_i32x4_add(summs_vec, wasm_i32x4_shuffle(summs_vec, summs_vec, 2, 3, 0, 1));
4857
+ summs_vec = wasm_i32x4_add(summs_vec, wasm_i32x4_shuffle(summs_vec, summs_vec, 1, 0, 3, 2));
4858
+ }
4859
+ int32_t summs = wasm_i32x4_extract_lane(summs_vec, 0);
4860
+
4861
+ // Vectorized isum calculation
4862
+ int32_t isum = 0;
4863
+ const uint8_t * sc_ptr = sc;
4864
+ const int k_iters = QK_K/128;
4865
+
4866
+ for (int k = 0; k < k_iters; ++k) {
4867
+ v128_t isum_vec = wasm_i32x4_splat(0);
4868
+ int shift = 0;
4869
+
4870
+ for (int j = 0; j < 4; ++j) {
4871
+ const int d0 = (sc_ptr[0] & 0xF);
4872
+ const int d1 = (sc_ptr[1] & 0xF);
4873
+ sc_ptr += 2;
4874
+
4875
+ // Process first 16 elements
4876
+ v128_t q2_0 = wasm_v128_load(q2);
4877
+ v128_t q8_0 = wasm_v128_load(q8);
4878
+ v128_t q2_shift_0 = wasm_u8x16_shr(q2_0, shift);
4879
+ v128_t q2_bits_0 = wasm_v128_and(q2_shift_0, wasm_i8x16_splat(0x03));
4880
+
4881
+ // Process next 16 elements
4882
+ v128_t q2_1 = wasm_v128_load(q2 + 16);
4883
+ v128_t q8_1 = wasm_v128_load(q8 + 16);
4884
+ v128_t q2_shift_1 = wasm_u8x16_shr(q2_1, shift);
4885
+ v128_t q2_bits_1 = wasm_v128_and(q2_shift_1, wasm_i8x16_splat(0x03));
4886
+
4887
+ // Calculate dot products
4888
+ v128_t p0 = wasm_i32x4_dot_i16x8(
4889
+ wasm_i16x8_extend_low_i8x16(q8_0),
4890
+ wasm_i16x8_extend_low_i8x16(q2_bits_0)
4891
+ );
4892
+ v128_t p1 = wasm_i32x4_dot_i16x8(
4893
+ wasm_i16x8_extend_high_i8x16(q8_0),
4894
+ wasm_i16x8_extend_high_i8x16(q2_bits_0)
4895
+ );
4896
+ v128_t p2 = wasm_i32x4_dot_i16x8(
4897
+ wasm_i16x8_extend_low_i8x16(q8_1),
4898
+ wasm_i16x8_extend_low_i8x16(q2_bits_1)
4899
+ );
4900
+ v128_t p3 = wasm_i32x4_dot_i16x8(
4901
+ wasm_i16x8_extend_high_i8x16(q8_1),
4902
+ wasm_i16x8_extend_high_i8x16(q2_bits_1)
4903
+ );
4904
+
4905
+ // Accumulate scaled results
4906
+ v128_t scaled = wasm_i32x4_add(
4907
+ wasm_i32x4_mul(wasm_i32x4_add(p0, p1), wasm_i32x4_splat(d0)),
4908
+ wasm_i32x4_mul(wasm_i32x4_add(p2, p3), wasm_i32x4_splat(d1))
4909
+ );
4910
+
4911
+ isum_vec = wasm_i32x4_add(isum_vec, scaled);
4912
+ q8 += 32;
4913
+ shift += 2;
4914
+ }
4915
+ q2 += 32;
4916
+
4917
+ // Horizontal sum of isum_vec
4918
+ isum_vec = wasm_i32x4_add(isum_vec, wasm_i32x4_shuffle(isum_vec, isum_vec, 2, 3, 0, 1));
4919
+ isum_vec = wasm_i32x4_add(isum_vec, wasm_i32x4_shuffle(isum_vec, isum_vec, 1, 0, 3, 2));
4920
+ isum += wasm_i32x4_extract_lane(isum_vec, 0);
4921
+ }
4922
+
4923
+ const float dall = WSP_GGML_FP16_TO_FP32(x[i].d) * y[i].d;
4924
+ const float dmin = WSP_GGML_FP16_TO_FP32(x[i].dmin) * y[i].d;
4925
+ sumf += dall * isum - dmin * summs;
4926
+ }
4927
+
4928
+ *s = sumf;
4929
+
4433
4930
  #elif defined __riscv_v_intrinsic
4434
4931
 
4435
4932
  float sumf = 0;
@@ -4649,9 +5146,6 @@ void wsp_ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void
4649
5146
 
4650
5147
  #elif defined __loongarch_asx
4651
5148
 
4652
- const __m256i m3 = __lasx_xvreplgr2vr_b(3);
4653
- const __m128i m4 = __lsx_vreplgr2vr_b(0xF);
4654
-
4655
5149
  __m256 acc = (__m256)__lasx_xvldi(0);
4656
5150
 
4657
5151
  for (int i = 0; i < nb; ++i) {
@@ -4662,18 +5156,15 @@ void wsp_ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void
4662
5156
  const uint8_t * restrict q2 = x[i].qs;
4663
5157
  const int8_t * restrict q8 = y[i].qs;
4664
5158
 
4665
- const __m128i mins_and_scales = __lsx_vld((const __m128i*)x[i].scales, 0);
4666
- const __m128i scales8 = __lsx_vand_v(mins_and_scales, m4);
4667
- const __m128i mins8 = __lsx_vand_v(__lsx_vsrli_h(mins_and_scales, 4), m4);
4668
- const __m256i mins = lasx_ext8_16(mins8);
5159
+ const __m128i mins_and_scales128 = __lsx_vld((const __m128i*)x[i].scales, 0);
5160
+ const __m128i scales128 = __lsx_vandi_b(mins_and_scales128, 0xf);
5161
+ const __m256i mins = lasx_ext8_16(__lsx_vsrli_b(mins_and_scales128, 4));
4669
5162
  const __m256i prod = lasx_madd_h(mins, __lasx_xvld((const __m256i*)y[i].bsums, 0));
4670
5163
 
4671
5164
  acc = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(dmin), __lasx_xvffint_s_w(prod), acc);
4672
5165
 
4673
- const __m256i all_scales = lasx_ext8_16(scales8);
4674
- const __m128i l_scales = lasx_extracti128(all_scales, 0);
4675
- const __m128i h_scales = lasx_extracti128(all_scales, 1);
4676
- const __m256i scales[2] = {lasx_insertf128(l_scales, l_scales), lasx_insertf128(h_scales, h_scales)};
5166
+ const v16i8 shuffle_mask = {0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15};
5167
+ const __m256i scales_shuffled = lasx_ext8_16(__lsx_vshuf_b(scales128, scales128, (__m128i)shuffle_mask));
4677
5168
 
4678
5169
  __m256i sumi = __lasx_xvldi(0);
4679
5170
 
@@ -4686,20 +5177,20 @@ void wsp_ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void
4686
5177
  const __m256i q8_2 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
4687
5178
  const __m256i q8_3 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
4688
5179
 
4689
- const __m256i q2_0 = __lasx_xvand_v(q2bits, m3);
4690
- const __m256i q2_1 = __lasx_xvand_v(__lasx_xvsrli_h(q2bits, 2), m3);
4691
- const __m256i q2_2 = __lasx_xvand_v(__lasx_xvsrli_h(q2bits, 4), m3);
4692
- const __m256i q2_3 = __lasx_xvand_v(__lasx_xvsrli_h(q2bits, 6), m3);
5180
+ const __m256i q2_0 = __lasx_xvandi_b(q2bits, 3);
5181
+ const __m256i q2_1 = __lasx_xvandi_b(__lasx_xvsrli_b(q2bits, 2), 3);
5182
+ const __m256i q2_2 = __lasx_xvandi_b(__lasx_xvsrli_b(q2bits, 4), 3);
5183
+ const __m256i q2_3 = __lasx_xvsrli_b(q2bits, 6);
4693
5184
 
4694
- __m256i p0 = lasx_maddubs_h(q2_0, q8_0);
4695
- __m256i p1 = lasx_maddubs_h(q2_1, q8_1);
4696
- __m256i p2 = lasx_maddubs_h(q2_2, q8_2);
4697
- __m256i p3 = lasx_maddubs_h(q2_3, q8_3);
5185
+ __m256i p0 = lasx_madd_h_b(q2_0, q8_0);
5186
+ __m256i p1 = lasx_madd_h_b(q2_1, q8_1);
5187
+ __m256i p2 = lasx_madd_h_b(q2_2, q8_2);
5188
+ __m256i p3 = lasx_madd_h_b(q2_3, q8_3);
4698
5189
 
4699
- p0 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(0)), p0);
4700
- p1 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(1)), p1);
4701
- p2 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(2)), p2);
4702
- p3 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(3)), p3);
5190
+ p0 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 0), p0);
5191
+ p1 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 1), p1);
5192
+ p2 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 2), p2);
5193
+ p3 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 3), p3);
4703
5194
 
4704
5195
  p0 = __lasx_xvadd_w(p0, p1);
4705
5196
  p2 = __lasx_xvadd_w(p2, p3);
@@ -4772,7 +5263,182 @@ void wsp_ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void
4772
5263
 
4773
5264
  const int nb = n / QK_K;
4774
5265
 
4775
- #ifdef __ARM_NEON
5266
+ #if defined(__ARM_FEATURE_SVE)
5267
+
5268
+ uint32_t utmp[4];
5269
+
5270
+ const int8_t m32 = 32;
5271
+ const int vector_length = svcntb()*8;
5272
+ const svuint8_t m3b_sv = svdup_n_u8(0x3);
5273
+ const svint32_t vzero_sv = svdup_n_s32(0);
5274
+
5275
+ const svuint8_t m0_sv = svdup_n_u8(1);
5276
+ const svuint8_t m1_sv = svlsl_n_u8_x(svptrue_b8(), m0_sv, 1);
5277
+ const svuint8_t m2_sv = svlsl_n_u8_x(svptrue_b8(), m0_sv, 2);
5278
+ const svuint8_t m3_sv = svlsl_n_u8_x(svptrue_b8(), m0_sv, 3);
5279
+ svbool_t pred_s32 = svnot_b_z (svptrue_b32(), svptrue_pat_b32(SV_VL4));
5280
+
5281
+ float sum = 0;
5282
+
5283
+ for (int i = 0; i < nb; ++i) {
5284
+
5285
+ const float d = y[i].d * WSP_GGML_FP16_TO_FP32(x[i].d);
5286
+
5287
+ const uint8_t * restrict q3_sv = x[i].qs;
5288
+ const uint8_t * restrict qh_sv = x[i].hmask;
5289
+ const int8_t * restrict q8_sv = y[i].qs;
5290
+
5291
+ // Set up scales
5292
+ uint32_t * aux = &x[i].scales;
5293
+ utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4);
5294
+ utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4);
5295
+ utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4);
5296
+ utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4);
5297
+
5298
+ int8_t * scale = (int8_t *)utmp;
5299
+
5300
+ for (int j = 0; j < 16; ++j) scale[j] -= m32;
5301
+
5302
+ switch (vector_length) {
5303
+ case 128:
5304
+ {
5305
+ svuint8_t qhbits_sv_1 = svld1_u8(svptrue_b8(), qh_sv);
5306
+ svuint8_t qhbits_sv_2 = svld1_u8(svptrue_b8(), qh_sv+16);
5307
+ svuint8_t q3h_sv;
5308
+
5309
+ svint32_t sumi1_1 = svdup_n_s32(0);
5310
+ svint8_t q3bytes_sv;
5311
+
5312
+ for (int j = 0; j < QK_K/128; ++j) {
5313
+
5314
+ const svuint8_t q3bits_sv = svld1_u8(svptrue_b8(), q3_sv); q3_sv += 16;
5315
+ const svuint8_t q3bits_sv_1 = svld1_u8(svptrue_b8(), q3_sv); q3_sv += 16;
5316
+ svint8_t q8bytes_1_sv_1 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
5317
+ svint8_t q8bytes_1_sv_2 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
5318
+
5319
+ q3h_sv = svlsl_n_u8_x(svptrue_b8(), svbic_u8_x(svptrue_b8(), m0_sv, qhbits_sv_1), 2);
5320
+ q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), q3bits_sv, m3b_sv)), svreinterpret_s8_u8(q3h_sv));
5321
+
5322
+ sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_1), svdup_n_s32((int32_t)scale[0]));
5323
+
5324
+ q3h_sv = svlsl_n_u8_x(svptrue_b8(), svbic_u8_x(svptrue_b8(), m0_sv, qhbits_sv_2), 2);
5325
+ q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), q3bits_sv_1, m3b_sv)), svreinterpret_s8_u8(q3h_sv));
5326
+
5327
+ sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_2), svdup_n_s32((int32_t)scale[1]));
5328
+
5329
+ q8bytes_1_sv_1 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
5330
+ q8bytes_1_sv_2 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
5331
+
5332
+ q3h_sv = svlsl_n_u8_x(svptrue_b8(), svbic_u8_x(svptrue_b8(), m1_sv, qhbits_sv_1), 1);
5333
+ q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q3bits_sv, 2), m3b_sv)), svreinterpret_s8_u8(q3h_sv));
5334
+
5335
+ sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_1), svdup_n_s32((int32_t)scale[2]));
5336
+
5337
+ q3h_sv = svlsl_n_u8_x(svptrue_b8(), svbic_u8_x(svptrue_b8(), m1_sv, qhbits_sv_2), 1);
5338
+ q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q3bits_sv_1, 2), m3b_sv)), svreinterpret_s8_u8(q3h_sv));
5339
+
5340
+ sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_2), svdup_n_s32((int32_t)scale[3]));
5341
+
5342
+
5343
+ scale += 4;
5344
+ q8bytes_1_sv_1 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
5345
+ q8bytes_1_sv_2 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
5346
+
5347
+ q3h_sv = svbic_u8_x(svptrue_b8(), m2_sv, qhbits_sv_1);
5348
+ q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q3bits_sv, 4), m3b_sv)), svreinterpret_s8_u8(q3h_sv));
5349
+
5350
+ sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_1), svdup_n_s32((int32_t)scale[0]));
5351
+
5352
+ q3h_sv = svbic_u8_x(svptrue_b8(), m2_sv, qhbits_sv_2);
5353
+ q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q3bits_sv_1, 4), m3b_sv)), svreinterpret_s8_u8(q3h_sv));
5354
+
5355
+ sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_2), svdup_n_s32((int32_t)scale[1]));
5356
+
5357
+
5358
+ q8bytes_1_sv_1 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
5359
+ q8bytes_1_sv_2 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
5360
+
5361
+ q3h_sv = svlsr_n_u8_x(svptrue_b8(), svbic_u8_x(svptrue_b8(), m3_sv, qhbits_sv_1), 1);
5362
+ q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q3bits_sv, 6), m3b_sv)), svreinterpret_s8_u8(q3h_sv));
5363
+
5364
+ sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_1), svdup_n_s32((int32_t)scale[2]));
5365
+
5366
+ q3h_sv = svlsr_n_u8_x(svptrue_b8(), svbic_u8_x(svptrue_b8(), m3_sv, qhbits_sv_2), 1);
5367
+ q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q3bits_sv_1, 6), m3b_sv)), svreinterpret_s8_u8(q3h_sv));
5368
+
5369
+ sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_2), svdup_n_s32((int32_t)scale[3]));
5370
+
5371
+ if (j == 0) {
5372
+ qhbits_sv_1 = svlsr_n_u8_x(svptrue_b8(), qhbits_sv_1, 4);
5373
+ qhbits_sv_2 = svlsr_n_u8_x(svptrue_b8(), qhbits_sv_2, 4);
5374
+ }
5375
+
5376
+ scale += 4;
5377
+ }
5378
+
5379
+ sum += d * (svaddv_s32(svptrue_b32(), sumi1_1));
5380
+ } break;
5381
+ case 256:
5382
+ case 512:
5383
+ {
5384
+ svuint8_t qhbits_sv = svld1_u8(svptrue_pat_b8(SV_VL32), qh_sv);
5385
+ svuint8_t q3h_sv;
5386
+
5387
+ svint32_t sumi1_1 = svdup_n_s32(0);
5388
+ svint8_t q3bytes_sv;
5389
+
5390
+ for (int j = 0; j < QK_K/128; ++j) {
5391
+
5392
+ const svuint8_t q3bits_sv = svld1_u8(svptrue_pat_b8(SV_VL32), q3_sv); q3_sv += 32;
5393
+ svint8_t q8bytes_1_sv_1 = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
5394
+ svint8_t q8bytes_1_sv_2 = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
5395
+
5396
+ q3h_sv = svlsl_n_u8_x(svptrue_pat_b8(SV_VL32), svbic_u8_x(svptrue_pat_b8(SV_VL32), m0_sv, qhbits_sv), 2);
5397
+ q3bytes_sv = svsub_s8_x(svptrue_pat_b8(SV_VL32), svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), q3bits_sv, m3b_sv)), svreinterpret_s8_u8(q3h_sv));
5398
+
5399
+
5400
+ svint32_t scale_1 = svsel_s32(svptrue_pat_b32(SV_VL4), svdup_n_s32((int32_t)scale[0]), svdup_n_s32((int32_t)scale[1]));
5401
+ sumi1_1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_1), scale_1);
5402
+
5403
+ q3h_sv = svlsl_n_u8_x(svptrue_pat_b8(SV_VL32), svbic_u8_x(svptrue_pat_b8(SV_VL32), m1_sv, qhbits_sv), 1);
5404
+ q3bytes_sv = svsub_s8_x(svptrue_pat_b8(SV_VL32), svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q3bits_sv, 2), m3b_sv)), svreinterpret_s8_u8(q3h_sv));
5405
+
5406
+ scale_1 = svsel_s32(svptrue_pat_b32(SV_VL4), svdup_n_s32((int32_t)scale[2]), svdup_n_s32((int32_t)scale[3]));
5407
+ sumi1_1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_2), scale_1);
5408
+
5409
+ scale += 4;
5410
+ q8bytes_1_sv_1 = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
5411
+ q8bytes_1_sv_2 = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
5412
+
5413
+ q3h_sv = svbic_u8_x(svptrue_pat_b8(SV_VL32), m2_sv, qhbits_sv);
5414
+ q3bytes_sv = svsub_s8_x(svptrue_pat_b8(SV_VL32), svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q3bits_sv, 4), m3b_sv)), svreinterpret_s8_u8(q3h_sv));
5415
+
5416
+ scale_1 = svsel_s32(svptrue_pat_b32(SV_VL4), svdup_n_s32((int32_t)scale[0]), svdup_n_s32((int32_t)scale[1]));
5417
+ sumi1_1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_1), scale_1);
5418
+
5419
+ q3h_sv = svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), svbic_u8_x(svptrue_pat_b8(SV_VL32), m3_sv, qhbits_sv), 1);
5420
+ q3bytes_sv = svsub_s8_x(svptrue_pat_b8(SV_VL32), svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q3bits_sv, 6), m3b_sv)), svreinterpret_s8_u8(q3h_sv));
5421
+
5422
+ scale_1 = svsel_s32(svptrue_pat_b32(SV_VL4), svdup_n_s32((int32_t)scale[2]), svdup_n_s32((int32_t)scale[3]));
5423
+ sumi1_1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_2), scale_1);
5424
+
5425
+ if (j == 0) {
5426
+ qhbits_sv = svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), qhbits_sv, 4);
5427
+ }
5428
+
5429
+ scale += 4;
5430
+ }
5431
+
5432
+ sum += d * (svaddv_s32(svptrue_pat_b32(SV_VL8), sumi1_1));
5433
+ } break;
5434
+ default:
5435
+ assert(false && "Unsupported vector length");
5436
+ break;
5437
+ }
5438
+ }
5439
+ *s = sum;
5440
+
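[Editor's note - illustrative sketch, not part of the package diff.] The SVE, NEON, and WASM q3_K paths in this hunk all start from the same scale unpacking: block_q3_K stores sixteen 6-bit scales in 12 bytes, with the low 4 bits in bytes 0..7 and the high 2 bits in bytes 8..11. A scalar reference for that bit-twiddling, assuming kmask1/kmask2 are the usual 0x03030303/0x0f0f0f0f constants defined earlier in this file, could look like:

#include <stdint.h>
#include <string.h>

/* Scalar sketch of the utmp/kmask unpacking used by the vector paths above. */
static void unpack_q3_K_scales(const uint8_t packed[12], int8_t scale[16]) {
    const uint32_t kmask1 = 0x03030303;  /* assumed value, defined earlier in the file */
    const uint32_t kmask2 = 0x0f0f0f0f;  /* assumed value, defined earlier in the file */
    uint32_t aux[3], utmp[4];
    memcpy(aux, packed, 12);
    utmp[0] = ( aux[0]       & kmask2) | (((aux[2] >> 0) & kmask1) << 4);
    utmp[1] = ( aux[1]       & kmask2) | (((aux[2] >> 2) & kmask1) << 4);
    utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4);
    utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4);
    memcpy(scale, utmp, 16);
    for (int j = 0; j < 16; ++j) scale[j] -= 32;  /* re-center 0..63 to signed range */
}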
5441
+ #elif __ARM_NEON
4776
5442
 
4777
5443
  uint32_t aux[3];
4778
5444
  uint32_t utmp[4];
@@ -5112,6 +5778,94 @@ void wsp_ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void
5112
5778
 
5113
5779
  *s = hsum_float_8(acc);
5114
5780
 
5781
+ #elif defined __wasm_simd128__
5782
+ int8_t aux8[QK_K];
5783
+ float sums[8] = {0};
5784
+ uint32_t auxs[4];
5785
+
5786
+ float sumf = 0;
5787
+ for (int i = 0; i < nb; ++i) {
5788
+ const uint8_t * restrict q3 = x[i].qs;
5789
+ const uint8_t * restrict hm = x[i].hmask;
5790
+ const int8_t * restrict q8 = y[i].qs;
5791
+
5792
+ // Process blocks with SIMD
5793
+ int8_t * a = aux8;
5794
+ uint8_t m = 1;
5795
+ for (int j = 0; j < QK_K; j += 128) {
5796
+ for (int shift = 0; shift <= 6; shift += 2) {
5797
+ v128_t v_m = wasm_i8x16_splat(m);
5798
+ for (int l = 0; l < 32; l += 16) {
5799
+ v128_t v_q3 = wasm_v128_load(q3 + l);
5800
+ v128_t v_shift = wasm_i8x16_shr(v_q3, shift);
5801
+ v128_t v_low2 = wasm_v128_and(v_shift, wasm_i8x16_splat(0x03));
5802
+
5803
+ v128_t v_hm = wasm_v128_load(hm + l);
5804
+ v128_t v_mask = wasm_v128_and(v_hm, v_m);
5805
+ v_mask = wasm_i8x16_ne(v_mask, wasm_i8x16_splat(0));
5806
+
5807
+ v_low2 = wasm_i8x16_sub(v_low2, wasm_v128_and(wasm_i8x16_splat(4), wasm_v128_not(v_mask)));
5808
+ wasm_v128_store(a + l, v_low2);
5809
+ }
5810
+ a += 32;
5811
+ m <<= 1;
5812
+ }
5813
+ q3 += 32;
5814
+ }
5815
+
5816
+ // Extract scales
5817
+ memcpy(auxs, x[i].scales, 12);
5818
+ uint32_t tmp = auxs[2];
5819
+ auxs[2] = ((auxs[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);
5820
+ auxs[3] = ((auxs[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4);
5821
+ auxs[0] = (auxs[0] & kmask2) | (((tmp >> 0) & kmask1) << 4);
5822
+ auxs[1] = (auxs[1] & kmask2) | (((tmp >> 2) & kmask1) << 4);
5823
+ const int8_t * scales = (const int8_t *)auxs;
5824
+
5825
+ // SIMD dot product with register accumulators
5826
+ v128_t v_acc0 = wasm_i32x4_splat(0);
5827
+ v128_t v_acc1 = wasm_i32x4_splat(0);
5828
+ a = aux8;
5829
+ for (int j = 0; j < QK_K/16; ++j) {
5830
+ const v128_t v_scale = wasm_i16x8_splat(scales[j] - 32);
5831
+
5832
+ // Process 16 elements per iteration
5833
+ for (int k = 0; k < 2; ++k) {
5834
+ const v128_t v_q8 = wasm_i16x8_load8x8(q8);
5835
+ const v128_t v_a = wasm_i16x8_load8x8(a);
5836
+
5837
+ v128_t v_prod = wasm_i16x8_mul(v_q8, v_a);
5838
+ v_prod = wasm_i16x8_mul(v_prod, v_scale);
5839
+
5840
+ v_acc0 = wasm_i32x4_add(v_acc0, wasm_i32x4_extend_low_i16x8(v_prod));
5841
+ v_acc1 = wasm_i32x4_add(v_acc1, wasm_i32x4_extend_high_i16x8(v_prod));
5842
+
5843
+ q8 += 8;
5844
+ a += 8;
5845
+ }
5846
+ }
5847
+
5848
+ // Accumulate results
5849
+ const float d = WSP_GGML_FP16_TO_FP32(x[i].d) * y[i].d;
5850
+ const v128_t v_d = wasm_f32x4_splat(d);
5851
+ v128_t v_sum = wasm_f32x4_add(
5852
+ wasm_f32x4_mul(wasm_f32x4_convert_i32x4(v_acc0), v_d),
5853
+ wasm_f32x4_mul(wasm_f32x4_convert_i32x4(v_acc1), v_d)
5854
+ );
5855
+
5856
+ // Accumulate into sums vector
5857
+ wasm_v128_store(sums, wasm_f32x4_add(wasm_v128_load(sums), v_sum));
5858
+ }
5859
+
5860
+ // Horizontal sum
5861
+ v128_t v_sum = wasm_f32x4_add(wasm_v128_load(sums), wasm_v128_load(sums + 4));
5862
+ sumf = wasm_f32x4_extract_lane(v_sum, 0) +
5863
+ wasm_f32x4_extract_lane(v_sum, 1) +
5864
+ wasm_f32x4_extract_lane(v_sum, 2) +
5865
+ wasm_f32x4_extract_lane(v_sum, 3);
5866
+
5867
+ *s = sumf;
5868
+
5115
5869
  #elif defined __riscv_v_intrinsic
5116
5870
 
5117
5871
  uint32_t aux[3];
@@ -5367,8 +6121,6 @@ void wsp_ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void
5367
6121
 
5368
6122
  #elif defined __loongarch_asx
5369
6123
 
5370
- const __m256i m3 = __lasx_xvreplgr2vr_b(3);
5371
- const __m256i mone = __lasx_xvreplgr2vr_b(1);
5372
6124
  const __m128i m32 = __lsx_vreplgr2vr_b(32);
5373
6125
 
5374
6126
  __m256 acc = (__m256)__lasx_xvldi(0);
@@ -5388,10 +6140,9 @@ void wsp_ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void
5388
6140
  (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4),
5389
6141
  (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4));
5390
6142
  scales128 = __lsx_vsub_b(scales128, m32);
5391
- const __m256i all_scales = lasx_ext8_16(scales128);
5392
- const __m128i l_scales = lasx_extracti128(all_scales, 0);
5393
- const __m128i h_scales = lasx_extracti128(all_scales, 1);
5394
- const __m256i scales[2] = {lasx_insertf128(l_scales, l_scales), lasx_insertf128(h_scales, h_scales)};
6143
+
6144
+ const v16i8 shuffle_mask = {0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15};
6145
+ const __m256i scales_shuffled = lasx_ext8_16(__lsx_vshuf_b(scales128, scales128, (__m128i)shuffle_mask));
5395
6146
 
5396
6147
  // high bit
5397
6148
  const __m256i hbits = __lasx_xvld((const __m256i*)x[i].hmask, 0);
@@ -5399,35 +6150,23 @@ void wsp_ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void
5399
6150
  // integer accumulator
5400
6151
  __m256i sumi = __lasx_xvldi(0);
5401
6152
 
5402
- int bit = 0;
5403
- int is = 0;
5404
- __m256i xvbit;
5405
-
5406
-
5407
6153
  for (int j = 0; j < QK_K/128; ++j) {
5408
6154
  // load low 2 bits
5409
6155
  const __m256i q3bits = __lasx_xvld((const __m256i*)q3, 0); q3 += 32;
5410
6156
 
5411
- xvbit = __lasx_xvreplgr2vr_h(bit);
5412
6157
  // prepare low and high bits
5413
- const __m256i q3l_0 = __lasx_xvand_v(q3bits, m3);
5414
- const __m256i q3h_0 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvandn_v(hbits, __lasx_xvsll_h(mone, xvbit)), xvbit), 2);
5415
- ++bit;
5416
-
5417
- xvbit = __lasx_xvreplgr2vr_h(bit);
5418
- const __m256i q3l_1 = __lasx_xvand_v(__lasx_xvsrli_h(q3bits, 2), m3);
5419
- const __m256i q3h_1 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvandn_v(hbits, __lasx_xvsll_h(mone, xvbit)), xvbit), 2);
5420
- ++bit;
5421
-
5422
- xvbit = __lasx_xvreplgr2vr_h(bit);
5423
- const __m256i q3l_2 = __lasx_xvand_v(__lasx_xvsrli_h(q3bits, 4), m3);
5424
- const __m256i q3h_2 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvandn_v(hbits, __lasx_xvsll_h(mone, xvbit)), xvbit), 2);
5425
- ++bit;
5426
-
5427
- xvbit = __lasx_xvreplgr2vr_h(bit);
5428
- const __m256i q3l_3 = __lasx_xvand_v(__lasx_xvsrli_h(q3bits, 6), m3);
5429
- const __m256i q3h_3 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvandn_v(hbits, __lasx_xvsll_h(mone, xvbit)), xvbit), 2);
5430
- ++bit;
6158
+ const __m256i q3l_0 = __lasx_xvandi_b(q3bits, 3);
6159
+ const __m256i q3l_1 = __lasx_xvandi_b(__lasx_xvsrli_b(q3bits, 2), 3);
6160
+ const __m256i q3l_2 = __lasx_xvandi_b(__lasx_xvsrli_b(q3bits, 4), 3);
6161
+ const __m256i q3l_3 = __lasx_xvsrli_b(q3bits, 6);
6162
+ const __m256i q3h_0 = __lasx_xvslli_b(__lasx_xvseqi_b(lasx_xvandi_b_bit(hbits, 4 * j + 0), 0), 2);
6163
+ const __m256i q3h_1 = __lasx_xvslli_b(__lasx_xvseqi_b(lasx_xvandi_b_bit(hbits, 4 * j + 1), 0), 2);
6164
+ const __m256i q3h_2 = __lasx_xvslli_b(__lasx_xvseqi_b(lasx_xvandi_b_bit(hbits, 4 * j + 2), 0), 2);
6165
+ const __m256i q3h_3 = __lasx_xvslli_b(__lasx_xvseqi_b(lasx_xvandi_b_bit(hbits, 4 * j + 3), 0), 2);
6166
+ const __m256i q3_0 = __lasx_xvor_v(q3h_0, q3l_0);
6167
+ const __m256i q3_1 = __lasx_xvor_v(q3h_1, q3l_1);
6168
+ const __m256i q3_2 = __lasx_xvor_v(q3h_2, q3l_2);
6169
+ const __m256i q3_3 = __lasx_xvor_v(q3h_3, q3l_3);
5431
6170
 
5432
6171
  // load Q8 quants
5433
6172
  const __m256i q8_0 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
@@ -5435,29 +6174,16 @@ void wsp_ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void
5435
6174
  const __m256i q8_2 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
5436
6175
  const __m256i q8_3 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
5437
6176
 
5438
- // Dot product: we multiply the 2 low bits and 1 high bit part separately, so we can use lasx_maddubs_h,
5439
- // and then subtract. The high bit part has the 2 already subtracted (and so, it is zero if the high bit was not set,
5440
- // and 2 if the high bit was set)
5441
- __m256i q8s_0 = lasx_maddubs_h(q3h_0, q8_0);
5442
- __m256i q8s_1 = lasx_maddubs_h(q3h_1, q8_1);
5443
- __m256i q8s_2 = lasx_maddubs_h(q3h_2, q8_2);
5444
- __m256i q8s_3 = lasx_maddubs_h(q3h_3, q8_3);
5445
-
5446
- __m256i p16_0 = lasx_maddubs_h(q3l_0, q8_0);
5447
- __m256i p16_1 = lasx_maddubs_h(q3l_1, q8_1);
5448
- __m256i p16_2 = lasx_maddubs_h(q3l_2, q8_2);
5449
- __m256i p16_3 = lasx_maddubs_h(q3l_3, q8_3);
5450
-
5451
- p16_0 = __lasx_xvsub_h(p16_0, q8s_0);
5452
- p16_1 = __lasx_xvsub_h(p16_1, q8s_1);
5453
- p16_2 = __lasx_xvsub_h(p16_2, q8s_2);
5454
- p16_3 = __lasx_xvsub_h(p16_3, q8s_3);
6177
+ __m256i p16_0 = lasx_madd_h_b(q8_0, q3_0);
6178
+ __m256i p16_1 = lasx_madd_h_b(q8_1, q3_1);
6179
+ __m256i p16_2 = lasx_madd_h_b(q8_2, q3_2);
6180
+ __m256i p16_3 = lasx_madd_h_b(q8_3, q3_3);
5455
6181
 
5456
6182
  // multiply with scales
5457
- p16_0 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(is + 0)), p16_0);
5458
- p16_1 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(is + 1)), p16_1);
5459
- p16_2 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(is + 2)), p16_2);
5460
- p16_3 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(is + 3)), p16_3);
6183
+ p16_0 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 0), p16_0);
6184
+ p16_1 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 1), p16_1);
6185
+ p16_2 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 2), p16_2);
6186
+ p16_3 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 3), p16_3);
5461
6187
 
5462
6188
  // accumulate
5463
6189
  p16_0 = __lasx_xvadd_w(p16_0, p16_1);
@@ -5465,7 +6191,7 @@ void wsp_ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void
5465
6191
  sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p16_0, p16_2));
5466
6192
  }
5467
6193
  // multiply with block scale and accumulate
5468
- acc = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(sumi), acc);//FIXME
6194
+ acc = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(sumi), acc);
5469
6195
  }
5470
6196
 
5471
6197
  *s = hsum_float_8(acc);
@@ -5556,7 +6282,88 @@ void wsp_ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void
5556
6282
 
5557
6283
  uint32_t utmp[4];
5558
6284
 
5559
- #ifdef __ARM_NEON
6285
+ #ifdef __ARM_FEATURE_SVE
6286
+ float sumf = 0;
6287
+ for (int i = 0; i < nb; ++i) {
6288
+
6289
+ const float d = y[i].d * WSP_GGML_FP16_TO_FP32(x[i].d);
6290
+ const float dmin = y[i].d * WSP_GGML_FP16_TO_FP32(x[i].dmin);
6291
+
6292
+ const int16x8_t q8sums = vpaddq_s16(vld1q_s16(y[i].bsums), vld1q_s16(y[i].bsums + 8));
6293
+
6294
+ memcpy(utmp, x[i].scales, K_SCALE_SIZE);
6295
+
6296
+ uint32x2_t mins8 = { 0 };
6297
+ mins8 = vset_lane_u32(utmp[1] & kmask1, mins8, 0);
6298
+ mins8 = vset_lane_u32(((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4), mins8, 1);
6299
+
6300
+ utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
6301
+ utmp[0] &= kmask1;
6302
+
6303
+ const int16x8_t mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins8)));
6304
+ const int32x4_t prod = vaddq_s32(vmull_s16(vget_low_s16 (q8sums), vget_low_s16 (mins)),
6305
+ vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins)));
6306
+ sumf -= dmin * vaddvq_s32(prod);
6307
+
6308
+ const uint8_t * scales = (const uint8_t *)utmp;
6309
+
6310
+ const uint8_t * restrict q4 = x[i].qs;
6311
+ const int8_t * restrict q8 = y[i].qs;
6312
+
6313
+ const int vector_length = wsp_ggml_cpu_get_sve_cnt()*8;
6314
+ const svuint8_t m4b = svdup_n_u8(0xf);
6315
+ const svint32_t mzero = svdup_n_s32(0);
6316
+ svint32_t sumi1 = svdup_n_s32(0);
6317
+ svint32_t sumi1_1 = svdup_n_s32(0);
6318
+ svint32_t sumi1_2 = svdup_n_s32(0);
6319
+ svint32_t sumi2 = svdup_n_s32(0);
6320
+ svint32_t sumi2_1 = svdup_n_s32(0);
6321
+ svint32_t sumi2_2 = svdup_n_s32(0);
6322
+ switch (vector_length) {
6323
+ case 128:
6324
+ {
6325
+ for (int j = 0; j < QK_K/64; ++j) {
6326
+ svint8_t q4bytes = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4), m4b));
6327
+ svint8_t q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16;
6328
+ sumi1_1 = svmla_n_s32_x(svptrue_b32(), sumi1_1, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+0]);
6329
+ q4bytes = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4+16), m4b));
6330
+ q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16;
6331
+ sumi1_2 = svmla_n_s32_x(svptrue_b32(), sumi1_2, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+0]);
6332
+
6333
+ q4bytes = svreinterpret_s8_u8(svlsr_n_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4), 4));
6334
+ q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16;
6335
+ sumi2_1 = svmla_n_s32_x(svptrue_b32(), sumi2_1, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+1]);
6336
+ q4bytes = svreinterpret_s8_u8(svlsr_n_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4+16), 4));
6337
+ q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16;
6338
+ sumi2_2 = svmla_n_s32_x(svptrue_b32(), sumi2_2, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+1]);
6339
+ q4 += 32;
6340
+ }
6341
+ sumi1 = svadd_s32_x(svptrue_b32(), sumi1_1, sumi1_2);
6342
+ sumi2 = svadd_s32_x(svptrue_b32(), sumi2_1, sumi2_2);
6343
+ sumf += d * (svaddv_s32(svptrue_b32(), svadd_s32_x(svptrue_b32(), sumi1, sumi2)));
6344
+ } break;
6345
+ case 256:
6346
+ case 512:
6347
+ {
6348
+ for (int j = 0; j < QK_K/64; ++j) {
6349
+ const svuint8_t q4bits = svld1_u8(svptrue_pat_b8(SV_VL32), q4); q4 += 32;
6350
+ svint8_t q4bytes = svreinterpret_s8_u8(svand_u8_x(svptrue_pat_b8(SV_VL32), q4bits, m4b));
6351
+ svint8_t q8bytes = svld1_s8(svptrue_pat_b8(SV_VL32), q8); q8 += 32;
6352
+ sumi1 = svmla_n_s32_x(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+0]);
6353
+
6354
+ q4bytes = svreinterpret_s8_u8(svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q4bits, 4));
6355
+ q8bytes = svld1_s8(svptrue_pat_b8(SV_VL32), q8); q8 += 32;
6356
+ sumi2 = svmla_n_s32_x(svptrue_pat_b32(SV_VL8), sumi2, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+1]);
6357
+ }
6358
+ sumf += d * (svaddv_s32(svptrue_pat_b32(SV_VL8), svadd_s32_x(svptrue_pat_b32(SV_VL8), sumi1, sumi2)));
6359
+ } break;
6360
+ default:
6361
+ assert(false && "Unsupported vector length");
6362
+ break;
6363
+ }
6364
+ }
6365
+ *s = sumf;
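[Editor's note - illustrative sketch, not part of the package diff.] The new SVE q3_K and q4_K paths above branch on the runtime vector width; wsp_ggml_cpu_get_sve_cnt() is assumed to report the SVE register size in bytes (the "*8" above converts it to bits). With plain ACLE the same dispatch shape can be sketched as:

#include <arm_sve.h>

/* svcntb() returns the number of bytes per SVE vector on this core. */
static unsigned sve_bits(void) {
    return (unsigned) svcntb() * 8;
}

/* Hypothetical dispatch skeleton mirroring the switch (vector_length) blocks above. */
void q_K_dot_dispatch_sketch(void) {
    switch (sve_bits()) {
        case 128:            /* 16-byte kernels, svptrue_b8() predicates */
            break;
        case 256:
        case 512:            /* 32-byte kernels, svptrue_pat_b8(SV_VL32) caps the
                                active lanes at 32 even on wider registers */
            break;
        default:             /* unsupported width */
            break;
    }
}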
6366
+ #elif defined __ARM_NEON
5560
6367
  const uint8x16_t m4b = vdupq_n_u8(0xf);
5561
6368
  const int32x4_t mzero = vdupq_n_s32(0);
5562
6369
 
@@ -5595,26 +6402,127 @@ void wsp_ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void
5595
6402
  int32_t sumi2 = 0;
5596
6403
 
5597
6404
  for (int j = 0; j < QK_K/64; ++j) {
5598
- const wsp_ggml_uint8x16x2_t q4bits = wsp_ggml_vld1q_u8_x2(q4); q4 += 32;
5599
-
5600
- q8bytes = wsp_ggml_vld1q_s8_x2(q8); q8 += 32;
5601
- q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b));
5602
- q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b));
5603
-
5604
- const int32x4_t p1 = wsp_ggml_vdotq_s32(wsp_ggml_vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
5605
- sumi1 += vaddvq_s32(p1) * scales[2*j+0];
5606
-
5607
- q8bytes = wsp_ggml_vld1q_s8_x2(q8); q8 += 32;
5608
- q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4));
5609
- q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4));
6405
+ const wsp_ggml_uint8x16x2_t q4bits = wsp_ggml_vld1q_u8_x2(q4); q4 += 32;
6406
+
6407
+ q8bytes = wsp_ggml_vld1q_s8_x2(q8); q8 += 32;
6408
+ q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b));
6409
+ q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b));
6410
+
6411
+ const int32x4_t p1 = wsp_ggml_vdotq_s32(wsp_ggml_vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
6412
+ sumi1 += vaddvq_s32(p1) * scales[2*j+0];
6413
+
6414
+ q8bytes = wsp_ggml_vld1q_s8_x2(q8); q8 += 32;
6415
+ q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4));
6416
+ q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4));
6417
+
6418
+ const int32x4_t p2 = wsp_ggml_vdotq_s32(wsp_ggml_vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
6419
+
6420
+ sumi2 += vaddvq_s32(p2) * scales[2*j+1];
6421
+ }
6422
+
6423
+ sumf += d * (sumi1 + sumi2);
6424
+
6425
+ }
6426
+
6427
+ *s = sumf;
6428
+
6429
+ #elif defined __wasm_simd128__
6430
+ const uint8_t * scales = (const uint8_t*)&utmp[0];
6431
+ float sumf = 0;
6432
+
6433
+ for (int i = 0; i < nb; ++i) {
6434
+ const float d = y[i].d * WSP_GGML_FP16_TO_FP32(x[i].d);
6435
+ const float dmin = y[i].d * WSP_GGML_FP16_TO_FP32(x[i].dmin); // Corrected sign
6436
+
6437
+ const uint8_t * restrict q4 = x[i].qs;
6438
+ const int8_t * restrict q8 = y[i].qs;
6439
+
6440
+ // Process scales and mins
6441
+ memcpy(utmp, x[i].scales, 12);
6442
+ utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
6443
+ const uint32_t uaux = utmp[1] & kmask1;
6444
+ utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
6445
+ utmp[2] = uaux;
6446
+ utmp[0] &= kmask1;
6447
+
6448
+ // Sum mins * q8sums
6449
+ int32_t sumi = 0;
6450
+ const int16_t * restrict q8sums = y[i].bsums;
6451
+ const uint8_t * m = (const uint8_t *)&utmp[2];
6452
+ for (int j = 0; j < 16; j += 2) {
6453
+ sumi += (q8sums[j] + q8sums[j+1]) * m[j/2];
6454
+ }
6455
+ sumf -= dmin * sumi;
6456
+
6457
+ int32_t sumi1 = 0;
6458
+ int32_t sumi2 = 0;
6459
+
6460
+ for (int j = 0; j < QK_K/64; ++j) {
6461
+ // Load 64 4-bit weights (32 bytes)
6462
+ const v128_t q4x0 = wasm_v128_load(q4);
6463
+ const v128_t q4x1 = wasm_v128_load(q4 + 16);
6464
+ q4 += 32;
5610
6465
 
5611
- const int32x4_t p2 = wsp_ggml_vdotq_s32(wsp_ggml_vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
6466
+ // Split into low/high nibbles
6467
+ const v128_t q4l0 = wasm_v128_and(q4x0, wasm_i8x16_splat(0x0F));
6468
+ const v128_t q4h0 = wasm_u8x16_shr(q4x0, 4);
6469
+ const v128_t q4l1 = wasm_v128_and(q4x1, wasm_i8x16_splat(0x0F));
6470
+ const v128_t q4h1 = wasm_u8x16_shr(q4x1, 4);
6471
+
6472
+ // Load 64 8-bit values (64 bytes)
6473
+ const v128_t q8x0 = wasm_v128_load(q8);
6474
+ const v128_t q8x1 = wasm_v128_load(q8 + 16);
6475
+ const v128_t q8x2 = wasm_v128_load(q8 + 32);
6476
+ const v128_t q8x3 = wasm_v128_load(q8 + 48);
6477
+ q8 += 64;
5612
6478
 
5613
- sumi2 += vaddvq_s32(p2) * scales[2*j+1];
6479
+ // Low nibble products
6480
+ v128_t vacc1 = wasm_i32x4_dot_i16x8(
6481
+ wasm_i16x8_extend_low_i8x16(q4l0),
6482
+ wasm_i16x8_extend_low_i8x16(q8x0)
6483
+ );
6484
+ vacc1 = wasm_i32x4_add(vacc1, wasm_i32x4_dot_i16x8(
6485
+ wasm_i16x8_extend_high_i8x16(q4l0),
6486
+ wasm_i16x8_extend_high_i8x16(q8x0)
6487
+ ));
6488
+ vacc1 = wasm_i32x4_add(vacc1, wasm_i32x4_dot_i16x8(
6489
+ wasm_i16x8_extend_low_i8x16(q4l1),
6490
+ wasm_i16x8_extend_low_i8x16(q8x1)
6491
+ ));
6492
+ vacc1 = wasm_i32x4_add(vacc1, wasm_i32x4_dot_i16x8(
6493
+ wasm_i16x8_extend_high_i8x16(q4l1),
6494
+ wasm_i16x8_extend_high_i8x16(q8x1)
6495
+ ));
6496
+
6497
+ // High nibble products
6498
+ v128_t vacc2 = wasm_i32x4_dot_i16x8(
6499
+ wasm_i16x8_extend_low_i8x16(q4h0),
6500
+ wasm_i16x8_extend_low_i8x16(q8x2)
6501
+ );
6502
+ vacc2 = wasm_i32x4_add(vacc2, wasm_i32x4_dot_i16x8(
6503
+ wasm_i16x8_extend_high_i8x16(q4h0),
6504
+ wasm_i16x8_extend_high_i8x16(q8x2)
6505
+ ));
6506
+ vacc2 = wasm_i32x4_add(vacc2, wasm_i32x4_dot_i16x8(
6507
+ wasm_i16x8_extend_low_i8x16(q4h1),
6508
+ wasm_i16x8_extend_low_i8x16(q8x3)
6509
+ ));
6510
+ vacc2 = wasm_i32x4_add(vacc2, wasm_i32x4_dot_i16x8(
6511
+ wasm_i16x8_extend_high_i8x16(q4h1),
6512
+ wasm_i16x8_extend_high_i8x16(q8x3)
6513
+ ));
6514
+
6515
+ // Accumulate scaled results
6516
+ int32_t vacc1_sum = wasm_i32x4_extract_lane(vacc1, 0) + wasm_i32x4_extract_lane(vacc1, 1) +
6517
+ wasm_i32x4_extract_lane(vacc1, 2) + wasm_i32x4_extract_lane(vacc1, 3);
6518
+ sumi1 += vacc1_sum * scales[2*j];
6519
+
6520
+ int32_t vacc2_sum = wasm_i32x4_extract_lane(vacc2, 0) + wasm_i32x4_extract_lane(vacc2, 1) +
6521
+ wasm_i32x4_extract_lane(vacc2, 2) + wasm_i32x4_extract_lane(vacc2, 3);
6522
+ sumi2 += vacc2_sum * scales[2*j+1];
5614
6523
  }
5615
6524
 
5616
6525
  sumf += d * (sumi1 + sumi2);
5617
-
5618
6526
  }
5619
6527
 
5620
6528
  *s = sumf;
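[Editor's note - illustrative sketch, not part of the package diff.] The wasm_simd128 paths added in this hunk all reduce signed 8-bit products through the same widen-then-dot idiom; dot_i8x16() below is a minimal, self-contained instance of it (a hypothetical helper, not a function from this file):

#include <stdint.h>
#include <wasm_simd128.h>

/* Dot product of 16 signed int8 lanes, as used per sub-block above. */
static int32_t dot_i8x16(const int8_t * a, const int8_t * b) {
    const v128_t va = wasm_v128_load(a);
    const v128_t vb = wasm_v128_load(b);
    /* widen each half to i16x8, then pairwise multiply-accumulate into i32x4 */
    v128_t acc = wasm_i32x4_dot_i16x8(wasm_i16x8_extend_low_i8x16(va),
                                      wasm_i16x8_extend_low_i8x16(vb));
    acc = wasm_i32x4_add(acc, wasm_i32x4_dot_i16x8(wasm_i16x8_extend_high_i8x16(va),
                                                   wasm_i16x8_extend_high_i8x16(vb)));
    /* horizontal reduction before applying the per-sub-block scale */
    return wasm_i32x4_extract_lane(acc, 0) + wasm_i32x4_extract_lane(acc, 1) +
           wasm_i32x4_extract_lane(acc, 2) + wasm_i32x4_extract_lane(acc, 3);
}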
@@ -5976,11 +6884,6 @@ void wsp_ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void
5976
6884
  *s = vec_extract(vsumf0, 0);
5977
6885
 
5978
6886
  #elif defined __loongarch_asx
5979
- WSP_GGML_UNUSED(kmask1);
5980
- WSP_GGML_UNUSED(kmask2);
5981
- WSP_GGML_UNUSED(kmask3);
5982
-
5983
- const __m256i m4 = __lasx_xvreplgr2vr_b(0xF);
5984
6887
 
5985
6888
  __m256 acc = (__m256)__lasx_xvldi(0);
5986
6889
  __m128 acc_m = (__m128)__lsx_vldi(0);
@@ -6000,33 +6903,34 @@ void wsp_ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void
6000
6903
  const uint8_t * restrict q4 = x[i].qs;
6001
6904
  const int8_t * restrict q8 = y[i].qs;
6002
6905
 
6003
- const __m256i mins_and_scales = lasx_extu8_16(lsx_set_w(utmp[3], utmp[2], utmp[1], utmp[0]));
6906
+ const __m128i mins_and_scales128 = lsx_set_w(utmp[3], utmp[2], utmp[1], utmp[0]);
6907
+ const __m128i mins128 = __lsx_vexth_h_b(mins_and_scales128);
6908
+ const __m128i scales128 = __lsx_vsllwil_h_b(mins_and_scales128, 0);
6004
6909
 
6005
6910
  const __m256i q8sums = __lasx_xvld((const __m256i*)y[i].bsums, 0);
6006
6911
  const __m128i q8s = lsx_hadd_h(lasx_extracti128(q8sums, 0), lasx_extracti128(q8sums, 1));
6007
- const __m128i prod = lsx_madd_h(lasx_extracti128(mins_and_scales, 1), q8s);
6912
+ const __m128i prod = lsx_madd_h(mins128, q8s);
6008
6913
  acc_m = __lsx_vfmadd_s(__lsx_vreplfr2vr_s(dmin), __lsx_vffint_s_w(prod), acc_m);
6009
6914
 
6010
- const __m128i sc128 = lasx_extracti128(mins_and_scales, 0);
6011
- const __m256i scales = lasx_insertf128(sc128, sc128);
6915
+ const __m256i scales = lasx_insertf128(scales128, scales128);
6012
6916
 
6013
6917
  __m256i sumi = __lasx_xvldi(0);
6014
6918
 
6015
6919
  for (int j = 0; j < QK_K/64; ++j) {
6016
6920
 
6017
- const __m256i scale_l = lasx_shuffle_b(scales, get_scale_shuffle_k4(2*j+0));
6018
- const __m256i scale_h = lasx_shuffle_b(scales, get_scale_shuffle_k4(2*j+1));
6921
+ const __m256i scale_l = lasx_xvrepl128vei_h(scales, 2 * j + 0);
6922
+ const __m256i scale_h = lasx_xvrepl128vei_h(scales, 2 * j + 1);
6019
6923
 
6020
6924
  const __m256i q4bits = __lasx_xvld((const __m256i*)q4, 0); q4 += 32;
6021
- const __m256i q4l = __lasx_xvand_v(q4bits, m4);
6022
- const __m256i q4h = __lasx_xvand_v(__lasx_xvsrli_h(q4bits, 4), m4);
6925
+ const __m256i q4l = __lasx_xvandi_b(q4bits, 0xf);
6926
+ const __m256i q4h = __lasx_xvsrli_b(q4bits, 4);
6023
6927
 
6024
6928
  const __m256i q8l = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
6025
- __m256i p16l = lasx_maddubs_h(q4l, q8l);
6929
+ __m256i p16l = lasx_madd_h_b(q4l, q8l);
6026
6930
  p16l = lasx_madd_h(scale_l, p16l);
6027
6931
 
6028
6932
  const __m256i q8h = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
6029
- __m256i p16h = lasx_maddubs_h(q4h, q8h);
6933
+ __m256i p16h = lasx_madd_h_b(q4h, q8h);
6030
6934
  p16h = lasx_madd_h(scale_h, p16h);
6031
6935
  const __m256i sumj = __lasx_xvadd_w(p16l, p16h);
6032
6936
 
@@ -6043,9 +6947,78 @@ void wsp_ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void
6043
6947
  acc_m = __lsx_vfadd_s(acc_m, (__m128)tmp1);
6044
6948
 
6045
6949
 
6046
- ft_union fi;
6047
- fi.i = __lsx_vpickve2gr_w(acc_m, 0);
6048
- *s = hsum_float_8(acc) + fi.f ;
6950
+ *s = hsum_float_8(acc) + ((v4f32)acc_m)[0];
6951
+ #elif defined(__VXE__) || defined(__VXE2__)
6952
+ const uint8x16_t v_lm = vec_splat_u8(0x0F);
6953
+ const int32x4_t v_z = vec_splat_s32(0);
6954
+
6955
+ uint8x16_t v_x[2];
6956
+ int8x16_t v_xl[2];
6957
+ int8x16_t v_y[2];
6958
+
6959
+ float sumf = 0;
6960
+
6961
+ for (int i = 0; i < nb; ++i) {
6962
+ const float d = y[i].d * WSP_GGML_FP16_TO_FP32(x[i].d);
6963
+ const float dmin = y[i].d * WSP_GGML_FP16_TO_FP32(x[i].dmin);
6964
+
6965
+ const int16x8_t v_ysumsl = vec_xl(0 , y[i].bsums);
6966
+ const int16x8_t v_ysumsh = vec_xl(16, y[i].bsums);
6967
+ const int16x8_t v_ysums = vec_padd_s16(v_ysumsl, v_ysumsh);
6968
+
6969
+ memcpy(utmp, x[i].scales, 12);
6970
+
6971
+ uint32x4_t v_mins8 = { 0 };
6972
+ v_mins8 = vec_insert(utmp[1] & kmask1, v_mins8, 0);
6973
+ v_mins8 = vec_insert(((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4), v_mins8, 1);
6974
+
6975
+ utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
6976
+ utmp[0] &= kmask1;
6977
+
6978
+ const int16x8_t v_minsh = (int16x8_t)vec_unpackh((uint8x16_t)v_mins8);
6979
+
6980
+ const int32x4_t v_minso = vec_mulo(v_ysums, v_minsh);
6981
+ const int32x4_t v_minse = vec_mule(v_ysums, v_minsh);
6982
+ const int32x4_t v_mins = v_minso + v_minse;
6983
+ sumf -= dmin * (v_mins[0] + v_mins[1] + v_mins[2] + v_mins[3]);
6984
+
6985
+ const uint8_t * scales = (const uint8_t *)utmp;
6986
+ const uint8_t * restrict x0 = x[i].qs;
6987
+ const int8_t * restrict y0 = y[i].qs;
6988
+
6989
+ int32_t sumi1 = 0;
6990
+ int32_t sumi2 = 0;
6991
+
6992
+ for (int j = 0; j < QK_K/64; ++j) {
6993
+ v_x[0] = vec_xl(0 , x0);
6994
+ v_x[1] = vec_xl(16, x0);
6995
+ x0 += 32;
6996
+
6997
+ v_y[0] = vec_xl(0 , y0);
6998
+ v_y[1] = vec_xl(16, y0);
6999
+ y0 += 32;
7000
+
7001
+ v_xl[0] = (int8x16_t)vec_and(v_x[0], v_lm);
7002
+ v_xl[1] = (int8x16_t)vec_and(v_x[1], v_lm);
7003
+
7004
+ const int32x4_t p1 = wsp_ggml_vec_dot(wsp_ggml_vec_dot(v_z, v_xl[0], v_y[0]), v_xl[1], v_y[1]);
7005
+ sumi1 += (p1[0] + p1[1] + p1[2] + p1[3]) * scales[2*j+0];
7006
+
7007
+ v_y[0] = vec_xl(0 , y0);
7008
+ v_y[1] = vec_xl(16, y0);
7009
+ y0 += 32;
7010
+
7011
+ v_xl[0] = (int8x16_t)vec_sr(v_x[0], 4);
7012
+ v_xl[1] = (int8x16_t)vec_sr(v_x[1], 4);
7013
+
7014
+ const int32x4_t p2 = wsp_ggml_vec_dot(wsp_ggml_vec_dot(v_z, v_xl[0], v_y[0]), v_xl[1], v_y[1]);
7015
+ sumi2 += (p2[0] + p2[1] + p2[2] + p2[3]) * scales[2*j+1];
7016
+ }
7017
+
7018
+ sumf += d * (sumi1 + sumi2);
7019
+ }
7020
+
7021
+ *s = sumf;
6049
7022
  #else
6050
7023
 
6051
7024
  const uint8_t * scales = (const uint8_t*)&utmp[0];
@@ -6371,6 +7344,118 @@ void wsp_ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void
6371
7344
 
6372
7345
  *s = hsum_float_8(acc) + summs;
6373
7346
 
7347
+ #elif defined __wasm_simd128__
7348
+ //const uint8_t * scales = (const uint8_t*)&utmp[0];
7349
+ float sumf = 0;
7350
+
7351
+ for (int i = 0; i < nb; ++i) {
7352
+ const float d = y[i].d * WSP_GGML_FP16_TO_FP32(x[i].d);
7353
+ const float dmin = y[i].d * WSP_GGML_FP16_TO_FP32(x[i].dmin); // Fixed sign
7354
+
7355
+ const uint8_t * restrict q5 = x[i].qs;
7356
+ const uint8_t * restrict qh = x[i].qh;
7357
+ const int8_t * restrict q8 = y[i].qs;
7358
+
7359
+ // Process scales and mins
7360
+ memcpy(utmp, x[i].scales, 12);
7361
+ utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
7362
+ const uint32_t uaux = utmp[1] & kmask1;
7363
+ utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
7364
+ utmp[2] = uaux;
7365
+ utmp[0] &= kmask1;
7366
+
7367
+ // Sum mins * q8sums
7368
+ int32_t sumi_mins = 0;
7369
+ const int16_t * restrict q8sums = y[i].bsums;
7370
+ const uint8_t * m = (const uint8_t *)&utmp[2];
7371
+ for (int j = 0; j < 16; j += 2) {
7372
+ sumi_mins += (q8sums[j] + q8sums[j+1]) * m[j/2];
7373
+ }
7374
+ sumf -= dmin * sumi_mins; // Correct subtraction
7375
+
7376
+ v128_t qh0 = wasm_v128_load(qh);
7377
+ v128_t qh1 = wasm_v128_load(qh + 16);
7378
+ const uint8_t * sc = (const uint8_t *)utmp;
7379
+
7380
+ int32_t sumi = 0;
7381
+
7382
+ for (int j = 0; j < QK_K/64; ++j) {
7383
+ const int shift = j * 2;
7384
+ v128_t qh_shift0 = wasm_u8x16_shr(qh0, shift);
7385
+ v128_t qh_shift1 = wasm_u8x16_shr(qh1, shift);
7386
+
7387
+ v128_t qh_low0 = wasm_i8x16_shl(wasm_v128_and(qh_shift0, wasm_i8x16_splat(0x01)), 4);
7388
+ v128_t qh_high0 = wasm_i8x16_shl(wasm_v128_and(qh_shift0, wasm_i8x16_splat(0x02)), 3);
7389
+ v128_t qh_low1 = wasm_i8x16_shl(wasm_v128_and(qh_shift1, wasm_i8x16_splat(0x01)), 4);
7390
+ v128_t qh_high1 = wasm_i8x16_shl(wasm_v128_and(qh_shift1, wasm_i8x16_splat(0x02)), 3);
7391
+
7392
+ v128_t q5_0 = wasm_v128_load(q5);
7393
+ v128_t q5_1 = wasm_v128_load(q5 + 16);
7394
+ q5 += 32;
7395
+
7396
+ v128_t q5l_0 = wasm_v128_or(wasm_v128_and(q5_0, wasm_i8x16_splat(0x0F)), qh_low0);
7397
+ v128_t q5h_0 = wasm_v128_or(wasm_u8x16_shr(q5_0, 4), qh_high0);
7398
+ v128_t q5l_1 = wasm_v128_or(wasm_v128_and(q5_1, wasm_i8x16_splat(0x0F)), qh_low1);
7399
+ v128_t q5h_1 = wasm_v128_or(wasm_u8x16_shr(q5_1, 4), qh_high1);
7400
+
7401
+ v128_t q8_0 = wasm_v128_load(q8);
7402
+ v128_t q8_1 = wasm_v128_load(q8 + 16);
7403
+ v128_t q8_2 = wasm_v128_load(q8 + 32);
7404
+ v128_t q8_3 = wasm_v128_load(q8 + 48);
7405
+ q8 += 64;
7406
+
7407
+ // Process low quants
7408
+ v128_t pl0 = wasm_i32x4_dot_i16x8(
7409
+ wasm_i16x8_extend_low_i8x16(q5l_0),
7410
+ wasm_i16x8_extend_low_i8x16(q8_0)
7411
+ );
7412
+ pl0 = wasm_i32x4_add(pl0, wasm_i32x4_dot_i16x8(
7413
+ wasm_i16x8_extend_high_i8x16(q5l_0),
7414
+ wasm_i16x8_extend_high_i8x16(q8_0)
7415
+ ));
7416
+ v128_t pl1 = wasm_i32x4_dot_i16x8(
7417
+ wasm_i16x8_extend_low_i8x16(q5l_1),
7418
+ wasm_i16x8_extend_low_i8x16(q8_1)
7419
+ );
7420
+ pl1 = wasm_i32x4_add(pl1, wasm_i32x4_dot_i16x8(
7421
+ wasm_i16x8_extend_high_i8x16(q5l_1),
7422
+ wasm_i16x8_extend_high_i8x16(q8_1)
7423
+ ));
7424
+ v128_t sum_low = wasm_i32x4_add(pl0, pl1);
7425
+
7426
+ // Process high quants
7427
+ v128_t ph0 = wasm_i32x4_dot_i16x8(
7428
+ wasm_i16x8_extend_low_i8x16(q5h_0),
7429
+ wasm_i16x8_extend_low_i8x16(q8_2)
7430
+ );
7431
+ ph0 = wasm_i32x4_add(ph0, wasm_i32x4_dot_i16x8(
7432
+ wasm_i16x8_extend_high_i8x16(q5h_0),
7433
+ wasm_i16x8_extend_high_i8x16(q8_2)
7434
+ ));
7435
+ v128_t ph1 = wasm_i32x4_dot_i16x8(
7436
+ wasm_i16x8_extend_low_i8x16(q5h_1),
7437
+ wasm_i16x8_extend_low_i8x16(q8_3)
7438
+ );
7439
+ ph1 = wasm_i32x4_add(ph1, wasm_i32x4_dot_i16x8(
7440
+ wasm_i16x8_extend_high_i8x16(q5h_1),
7441
+ wasm_i16x8_extend_high_i8x16(q8_3)
7442
+ ));
7443
+ v128_t sum_high = wasm_i32x4_add(ph0, ph1);
7444
+
7445
+ // Accumulate with scale factors
7446
+ int32_t sl = wasm_i32x4_extract_lane(sum_low, 0) + wasm_i32x4_extract_lane(sum_low, 1) +
7447
+ wasm_i32x4_extract_lane(sum_low, 2) + wasm_i32x4_extract_lane(sum_low, 3);
7448
+ int32_t sh = wasm_i32x4_extract_lane(sum_high, 0) + wasm_i32x4_extract_lane(sum_high, 1) +
7449
+ wasm_i32x4_extract_lane(sum_high, 2) + wasm_i32x4_extract_lane(sum_high, 3);
7450
+
7451
+ sumi += sl * sc[2*j] + sh * sc[2*j+1];
7452
+ }
7453
+
7454
+ sumf += d * sumi;
7455
+ }
7456
+
7457
+ *s = sumf;
7458
+
6374
7459
  #elif defined __riscv_v_intrinsic
6375
7460
 
6376
7461
  const uint8_t * scales = (const uint8_t*)&utmp[0];
@@ -6593,19 +7678,11 @@ void wsp_ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void
6593
7678
  *s = vec_extract(vsumf0, 0);
6594
7679
 
6595
7680
  #elif defined __loongarch_asx
6596
- WSP_GGML_UNUSED(kmask1);
6597
- WSP_GGML_UNUSED(kmask2);
6598
- WSP_GGML_UNUSED(kmask3);
6599
-
6600
- const __m256i m4 = __lasx_xvreplgr2vr_b(0xF);
6601
- const __m128i mzero = __lsx_vldi(0);
6602
- const __m256i mone = __lasx_xvreplgr2vr_b(1);
6603
7681
 
6604
7682
  __m256 acc = (__m256)__lasx_xvldi(0);
7683
+ __m128 acc_m = (__m128)__lsx_vldi(0);
6605
7684
 
6606
- float summs = 0.f;
6607
-
6608
- for (int i = 0; i < nb; ++i) {
7685
+ for (int i = 0; i < nb; ++i) {
6609
7686
 
6610
7687
  const uint8_t * restrict q5 = x[i].qs;
6611
7688
  const int8_t * restrict q8 = y[i].qs;
@@ -6620,49 +7697,40 @@ void wsp_ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void
6620
7697
  utmp[2] = uaux;
6621
7698
  utmp[0] &= kmask1;
6622
7699
 
6623
- const __m256i mins_and_scales = lasx_extu8_16(lsx_set_w(utmp[3], utmp[2], utmp[1], utmp[0]));
7700
+ const __m128i mins_and_scales128 = lsx_set_w(utmp[3], utmp[2], utmp[1], utmp[0]);
7701
+ const __m128i mins128 = __lsx_vexth_h_b(mins_and_scales128);
7702
+ const __m128i scales128 = __lsx_vsllwil_h_b(mins_and_scales128, 0);
6624
7703
 
6625
7704
  const __m256i q8sums = __lasx_xvld((const __m256i*)y[i].bsums, 0);
6626
7705
  const __m128i q8s = lsx_hadd_h(lasx_extracti128(q8sums, 0), lasx_extracti128(q8sums, 1));
6627
- const __m128i prod = lsx_madd_h(lasx_extracti128(mins_and_scales, 1), q8s);
6628
- const __m128i hsum = lsx_hadd_w(lsx_hadd_w(prod, mzero), mzero);
6629
- summs += dmin * __lsx_vpickve2gr_w(hsum, 0); //TODO check
7706
+ const __m128i prod = lsx_madd_h(mins128, q8s);
7707
+ acc_m = __lsx_vfmadd_s(__lsx_vreplfr2vr_s(dmin), __lsx_vffint_s_w(prod), acc_m);
6630
7708
 
6631
- const __m128i sc128 = lasx_extracti128(mins_and_scales, 0);
6632
- const __m256i scales = lasx_insertf128(sc128, sc128);
7709
+ const __m256i scales = lasx_insertf128(scales128, scales128);
6633
7710
 
6634
7711
  const __m256i hbits = __lasx_xvld((const __m256i*)x[i].qh, 0);
6635
- __m256i hmask = mone;
6636
7712
 
6637
7713
  __m256i sumi = __lasx_xvldi(0);
6638
7714
 
6639
- int bit = 0;
6640
- __m256i xvbit;
6641
-
6642
7715
  for (int j = 0; j < QK_K/64; ++j) {
6643
7716
 
6644
- const __m256i scale_0 = lasx_shuffle_b(scales, get_scale_shuffle_k4(2*j+0));
6645
- const __m256i scale_1 = lasx_shuffle_b(scales, get_scale_shuffle_k4(2*j+1));
7717
+ const __m256i scale_0 = lasx_xvrepl128vei_h(scales, 2 * j + 0);
7718
+ const __m256i scale_1 = lasx_xvrepl128vei_h(scales, 2 * j + 1);
6646
7719
 
6647
7720
  const __m256i q5bits = __lasx_xvld((const __m256i*)q5, 0); q5 += 32;
6648
7721
 
6649
- xvbit = __lasx_xvreplgr2vr_h(bit++);
6650
- const __m256i q5l_0 = __lasx_xvand_v(q5bits, m4);
6651
- const __m256i q5h_0 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvand_v(hbits, hmask), xvbit), 4);
6652
- const __m256i q5_0 = __lasx_xvadd_b(q5l_0, q5h_0);
6653
- hmask = __lasx_xvslli_h(hmask, 1);
6654
-
6655
- xvbit = __lasx_xvreplgr2vr_h(bit++);
6656
- const __m256i q5l_1 = __lasx_xvand_v(__lasx_xvsrli_h(q5bits, 4), m4);
6657
- const __m256i q5h_1 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvand_v(hbits, hmask), xvbit), 4);
6658
- const __m256i q5_1 = __lasx_xvadd_b(q5l_1, q5h_1);
6659
- hmask = __lasx_xvslli_h(hmask, 1);
7722
+ const __m256i q5l_0 = __lasx_xvandi_b(q5bits, 0xf);
7723
+ const __m256i q5l_1 = __lasx_xvsrli_b(q5bits, 4);
7724
+ const __m256i q5h_0 = __lasx_xvnori_b(__lasx_xvseqi_b(lasx_xvandi_b_bit(hbits, 2 * j + 0), 0), 0xef);
7725
+ const __m256i q5h_1 = __lasx_xvnori_b(__lasx_xvseqi_b(lasx_xvandi_b_bit(hbits, 2 * j + 1), 0), 0xef);
7726
+ const __m256i q5_0 = __lasx_xvor_v(q5l_0, q5h_0);
7727
+ const __m256i q5_1 = __lasx_xvor_v(q5l_1, q5h_1);
6660
7728
 
6661
7729
  const __m256i q8_0 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
6662
7730
  const __m256i q8_1 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
6663
7731
 
6664
- __m256i p16_0 = lasx_maddubs_h(q5_0, q8_0);
6665
- __m256i p16_1 = lasx_maddubs_h(q5_1, q8_1);
7732
+ __m256i p16_0 = lasx_madd_h_b(q5_0, q8_0);
7733
+ __m256i p16_1 = lasx_madd_h_b(q5_1, q8_1);
6666
7734
 
6667
7735
  p16_0 = lasx_madd_h(scale_0, p16_0);
6668
7736
  p16_1 = lasx_madd_h(scale_1, p16_1);
@@ -6676,8 +7744,98 @@ void wsp_ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void
6676
7744
 
6677
7745
  }
6678
7746
 
6679
- *s = hsum_float_8(acc) + summs;
7747
+ acc_m = __lsx_vfadd_s(acc_m, (__m128)__lsx_vbsrl_v(acc_m, 8));
7748
+ acc_m = __lsx_vfadd_s(acc_m, (__m128)__lsx_vbsrl_v(acc_m, 4));
7749
+
7750
+ *s = hsum_float_8(acc) + ((v4f32)acc_m)[0];
7751
+ #elif defined(__VXE__) || defined(__VXE2__)
7752
+ const uint8x16_t v_lm = vec_splat_u8(0x0F);
7753
+ const uint8x16_t v_1m = vec_splat_u8(0x01);
7754
+ const uint8x16_t v_2m = vec_splat_u8(0x02);
7755
+
7756
+ const int32x4_t v_z = vec_splat_s32(0);
7757
+
7758
+ const uchar8x16_t v_minsm = {
7759
+ 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F,
7760
+ 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF
7761
+ };
7762
+
7763
+ int8x16_t q5b[4];
7764
+ uint8x16_t q5h[4];
7765
+
7766
+ uint8x16_t v_xl[2];
7767
+ uint8x16_t v_xh[2];
7768
+ int8x16_t v_y[4];
7769
+
7770
+ float sumf = 0;
7771
+
7772
+ for (int i = 0; i < nb; ++i) {
7773
+ const float d = y[i].d * WSP_GGML_FP16_TO_FP32(x[i].d);
7774
+ const float dmin = y[i].d * WSP_GGML_FP16_TO_FP32(x[i].dmin);
7775
+
7776
+ const int16x8_t v_ysumsl = vec_xl(0 , y[i].bsums);
7777
+ const int16x8_t v_ysumsh = vec_xl(16, y[i].bsums);
7778
+ const int16x8_t v_ysums = vec_padd_s16(v_ysumsl, v_ysumsh);
7779
+
7780
+ memcpy(utmp, x[i].scales, 12);
7781
+ utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
7782
+ const uint32_t uaux = utmp[1] & kmask1;
7783
+ utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
7784
+ utmp[2] = uaux;
7785
+ utmp[0] &= kmask1;
7786
+
7787
+ const uint8x16_t v_mins16 = vec_xl(0, (const uint8_t *)utmp);
7788
+ const uint8x16_t v_mins8 = vec_perm(v_mins16, v_mins16, v_minsm);
7789
+ const int16x8_t v_minsh = (int16x8_t)vec_unpackh(v_mins8);
7790
+
7791
+ const int32x4_t v_minsho = vec_mulo(v_ysums, v_minsh);
7792
+ const int32x4_t v_minshe = vec_mule(v_ysums, v_minsh);
7793
+ const int32x4_t v_mins = vec_add(v_minsho, v_minshe);
7794
+ const int32_t mins = v_mins[0] + v_mins[1] + v_mins[2] + v_mins[3];
7795
+
7796
+ const uint8_t * scales = (const uint8_t *)utmp;
7797
+ const uint8_t * restrict x0l = x[i].qs;
7798
+ const uint8_t * restrict x0h = x[i].qh;
7799
+ const int8_t * restrict y0 = y[i].qs;
7800
+
7801
+ v_xh[0] = vec_xl(0 , x0h);
7802
+ v_xh[1] = vec_xl(16, x0h);
7803
+
7804
+ int32_t sumi = 0;
7805
+ for (int j = 0; j < QK_K/64; ++j) {
7806
+ v_xl[0] = vec_xl(0 , x0l);
7807
+ v_xl[1] = vec_xl(16, x0l);
7808
+ x0l += 32;
7809
+
7810
+ v_y[0] = vec_xl(0 , y0);
7811
+ v_y[1] = vec_xl(16, y0);
7812
+ v_y[2] = vec_xl(32, y0);
7813
+ v_y[3] = vec_xl(48, y0);
7814
+ y0 += 64;
7815
+
7816
+ q5h[0] = vec_sl(vec_and(v_1m, v_xh[0]), 4);
7817
+ q5h[1] = vec_sl(vec_and(v_1m, v_xh[1]), 4);
7818
+ q5h[2] = vec_sl(vec_and(v_2m, v_xh[0]), 3);
7819
+ q5h[3] = vec_sl(vec_and(v_2m, v_xh[1]), 3);
7820
+ v_xh[0] = vec_sr(v_xh[0], 2);
7821
+ v_xh[1] = vec_sr(v_xh[1], 2);
7822
+
7823
+ q5b[0] = (int8x16_t)vec_or(vec_and(v_xl[0], v_lm), q5h[0]);
7824
+ q5b[1] = (int8x16_t)vec_or(vec_and(v_xl[1], v_lm), q5h[1]);
7825
+ q5b[2] = (int8x16_t)vec_or(vec_sr(v_xl[0], 4), q5h[2]);
7826
+ q5b[3] = (int8x16_t)vec_or(vec_sr(v_xl[1], 4), q5h[3]);
7827
+
7828
+ int32x4_t sumi0 = wsp_ggml_vec_dot(wsp_ggml_vec_dot(v_z, q5b[0], v_y[0]), q5b[1], v_y[1]);
7829
+ int32x4_t sumi1 = wsp_ggml_vec_dot(wsp_ggml_vec_dot(v_z, q5b[2], v_y[2]), q5b[3], v_y[3]);
7830
+
7831
+ sumi += (sumi0[0] + sumi0[1] + sumi0[2] + sumi0[3]) * *scales++;
7832
+ sumi += (sumi1[0] + sumi1[1] + sumi1[2] + sumi1[3]) * *scales++;
7833
+ }
7834
+
7835
+ sumf += d * sumi - dmin * mins;
7836
+ }
6680
7837
 
7838
+ *s = sumf;
6681
7839
  #else
6682
7840
 
6683
7841
  const uint8_t * scales = (const uint8_t*)&utmp[0];
@@ -7034,6 +8192,85 @@ void wsp_ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void
7034
8192
 
7035
8193
  *s = hsum_float_8(acc);
7036
8194
 
8195
+ #elif defined __wasm_simd128__
8196
+ int8_t aux8[QK_K] __attribute__((aligned(16)));
8197
+ int32_t aux32[8] __attribute__((aligned(16))) = {0};
8198
+ float sums[8] __attribute__((aligned(16))) = {0};
8199
+
8200
+ for (int i = 0; i < nb; ++i) {
8201
+ // Unpack 6-bit quantized data into aux8 (unchanged)
8202
+ const uint8_t * restrict q4 = x[i].ql;
8203
+ const uint8_t * restrict qh = x[i].qh;
8204
+ int8_t * a = aux8;
8205
+ for (int j = 0; j < QK_K; j += 128) {
8206
+ for (int l = 0; l < 32; ++l) {
8207
+ a[l + 0] = (int8_t)((q4[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
8208
+ a[l + 32] = (int8_t)((q4[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
8209
+ a[l + 64] = (int8_t)((q4[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
8210
+ a[l + 96] = (int8_t)((q4[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
8211
+ }
8212
+ a += 128;
8213
+ q4 += 64;
8214
+ qh += 32;
8215
+ }
8216
+
8217
+ const int8_t * restrict a_ptr = aux8;
8218
+ const int8_t * restrict q8 = y[i].qs;
8219
+ v128_t acc0 = wasm_i32x4_splat(0);
8220
+ v128_t acc1 = wasm_i32x4_splat(0);
8221
+
8222
+ for (int j = 0; j < QK_K/16; ++j) {
8223
+ const int scale = x[i].scales[j];
8224
+ const v128_t vscale = wasm_i32x4_splat(scale);
8225
+
8226
+ // Load 16 elements from a and q8
8227
+ const v128_t a_vec = wasm_v128_load(a_ptr);
8228
+ const v128_t q8_vec = wasm_v128_load(q8);
8229
+
8230
+ // Process low 8 elements
8231
+ v128_t a_low = wasm_i16x8_extend_low_i8x16(a_vec);
8232
+ v128_t q8_low = wasm_i16x8_extend_low_i8x16(q8_vec);
8233
+ v128_t prod_low = wasm_i16x8_mul(a_low, q8_low);
8234
+ v128_t prod_lo_lo = wasm_i32x4_extend_low_i16x8(prod_low);
8235
+ v128_t prod_lo_hi = wasm_i32x4_extend_high_i16x8(prod_low);
8236
+
8237
+ // Process high 8 elements
8238
+ v128_t a_high = wasm_i16x8_extend_high_i8x16(a_vec);
8239
+ v128_t q8_high = wasm_i16x8_extend_high_i8x16(q8_vec);
8240
+ v128_t prod_high = wasm_i16x8_mul(a_high, q8_high);
8241
+ v128_t prod_hi_lo = wasm_i32x4_extend_low_i16x8(prod_high);
8242
+ v128_t prod_hi_hi = wasm_i32x4_extend_high_i16x8(prod_high);
8243
+
8244
+ // Scale and accumulate
8245
+ prod_lo_lo = wasm_i32x4_mul(prod_lo_lo, vscale);
8246
+ prod_lo_hi = wasm_i32x4_mul(prod_lo_hi, vscale);
8247
+ prod_hi_lo = wasm_i32x4_mul(prod_hi_lo, vscale);
8248
+ prod_hi_hi = wasm_i32x4_mul(prod_hi_hi, vscale);
8249
+
8250
+ acc0 = wasm_i32x4_add(acc0, wasm_i32x4_add(prod_lo_lo, prod_hi_lo));
8251
+ acc1 = wasm_i32x4_add(acc1, wasm_i32x4_add(prod_lo_hi, prod_hi_hi));
8252
+
8253
+ a_ptr += 16;
8254
+ q8 += 16;
8255
+ }
8256
+
8257
+ // Store accumulated results
8258
+ wasm_v128_store(&aux32[0], acc0);
8259
+ wasm_v128_store(&aux32[4], acc1);
8260
+
8261
+ const float d = WSP_GGML_FP16_TO_FP32(x[i].d) * y[i].d;
8262
+ for (int l = 0; l < 8; ++l) {
8263
+ sums[l] += d * aux32[l];
8264
+ }
8265
+ }
8266
+
8267
+ // Sum final results
8268
+ float sumf = 0;
8269
+ for (int l = 0; l < 8; ++l) {
8270
+ sumf += sums[l];
8271
+ }
8272
+ *s = sumf;
8273
+
7037
8274
  #elif defined __riscv_v_intrinsic
7038
8275
 
7039
8276
  float sumf = 0;
@@ -7258,8 +8495,6 @@ void wsp_ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void
7258
8495
 
7259
8496
  #elif defined __loongarch_asx
7260
8497
 
7261
- const __m256i m4 = __lasx_xvreplgr2vr_b(0xF);
7262
- const __m256i m2 = __lasx_xvreplgr2vr_b(3);
7263
8498
  const __m256i m32s = __lasx_xvreplgr2vr_b(32);
7264
8499
 
7265
8500
  __m256 acc = (__m256)__lasx_xvldi(0);
@@ -7272,58 +8507,42 @@ void wsp_ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void
7272
8507
  const uint8_t * restrict qh = x[i].qh;
7273
8508
  const int8_t * restrict q8 = y[i].qs;
7274
8509
 
7275
- const __m128i scales = __lsx_vld((const __m128i*)x[i].scales, 0);
8510
+ const __m128i scales128 = __lsx_vld((const __m128i*)x[i].scales, 0);
8511
+ const v16i8 shuffle_mask = {0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15};
8512
+ const __m256i scales_shuffled = lasx_ext8_16(__lsx_vshuf_b(scales128, scales128, (__m128i)shuffle_mask));
7276
8513
 
7277
8514
  __m256i sumi = __lasx_xvldi(0);
7278
8515
 
7279
- int is = 0;
7280
-
7281
8516
  for (int j = 0; j < QK_K/128; ++j) {
7282
8517
 
7283
- const __m128i scale_0 = lsx_shuffle_b(scales, get_scale_shuffle(is + 0));
7284
- const __m128i scale_1 = lsx_shuffle_b(scales, get_scale_shuffle(is + 1));
7285
- const __m128i scale_2 = lsx_shuffle_b(scales, get_scale_shuffle(is + 2));
7286
- const __m128i scale_3 = lsx_shuffle_b(scales, get_scale_shuffle(is + 3));
7287
- is += 4;
7288
-
7289
8518
  const __m256i q4bits1 = __lasx_xvld((const __m256i*)q4, 0); q4 += 32;
7290
8519
  const __m256i q4bits2 = __lasx_xvld((const __m256i*)q4, 0); q4 += 32;
7291
8520
  const __m256i q4bitsH = __lasx_xvld((const __m256i*)qh, 0); qh += 32;
7292
8521
 
7293
- const __m256i q4h_0 = __lasx_xvslli_h(__lasx_xvand_v(q4bitsH, m2), 4);
7294
- const __m256i q4h_1 = __lasx_xvslli_h(__lasx_xvand_v(__lasx_xvsrli_h(q4bitsH, 2), m2), 4);
7295
- const __m256i q4h_2 = __lasx_xvslli_h(__lasx_xvand_v(__lasx_xvsrli_h(q4bitsH, 4), m2), 4);
7296
- const __m256i q4h_3 = __lasx_xvslli_h(__lasx_xvand_v(__lasx_xvsrli_h(q4bitsH, 6), m2), 4);
8522
+ const __m256i q4h_0 = __lasx_xvslli_b(__lasx_xvandi_b(q4bitsH, 3), 4);
8523
+ const __m256i q4h_1 = __lasx_xvslli_b(__lasx_xvandi_b(q4bitsH, 3 << 2), 2);
8524
+ const __m256i q4h_2 = __lasx_xvandi_b(q4bitsH, 3 << 4);
8525
+ const __m256i q4h_3 = __lasx_xvsrli_b(__lasx_xvandi_b(q4bitsH, 3 << 6), 2);
7297
8526
 
7298
- const __m256i q4_0 = __lasx_xvor_v(__lasx_xvand_v(q4bits1, m4), q4h_0);
7299
- const __m256i q4_1 = __lasx_xvor_v(__lasx_xvand_v(q4bits2, m4), q4h_1);
7300
- const __m256i q4_2 = __lasx_xvor_v(__lasx_xvand_v(__lasx_xvsrli_h(q4bits1, 4), m4), q4h_2);
7301
- const __m256i q4_3 = __lasx_xvor_v(__lasx_xvand_v(__lasx_xvsrli_h(q4bits2, 4), m4), q4h_3);
8527
+ const __m256i q4_0 = __lasx_xvor_v(__lasx_xvandi_b(q4bits1, 0xf), q4h_0);
8528
+ const __m256i q4_1 = __lasx_xvor_v(__lasx_xvandi_b(q4bits2, 0xf), q4h_1);
8529
+ const __m256i q4_2 = __lasx_xvor_v(__lasx_xvsrli_b(q4bits1, 4), q4h_2);
8530
+ const __m256i q4_3 = __lasx_xvor_v(__lasx_xvsrli_b(q4bits2, 4), q4h_3);
7302
8531
 
7303
8532
  const __m256i q8_0 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
7304
8533
  const __m256i q8_1 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
7305
8534
  const __m256i q8_2 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
7306
8535
  const __m256i q8_3 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
7307
8536
 
7308
- __m256i q8s_0 = lasx_maddubs_h(m32s, q8_0);
7309
- __m256i q8s_1 = lasx_maddubs_h(m32s, q8_1);
7310
- __m256i q8s_2 = lasx_maddubs_h(m32s, q8_2);
7311
- __m256i q8s_3 = lasx_maddubs_h(m32s, q8_3);
7312
-
7313
- __m256i p16_0 = lasx_maddubs_h(q4_0, q8_0);
7314
- __m256i p16_1 = lasx_maddubs_h(q4_1, q8_1);
7315
- __m256i p16_2 = lasx_maddubs_h(q4_2, q8_2);
7316
- __m256i p16_3 = lasx_maddubs_h(q4_3, q8_3);
7317
-
7318
- p16_0 = __lasx_xvsub_h(p16_0, q8s_0);
7319
- p16_1 = __lasx_xvsub_h(p16_1, q8s_1);
7320
- p16_2 = __lasx_xvsub_h(p16_2, q8s_2);
7321
- p16_3 = __lasx_xvsub_h(p16_3, q8s_3);
8537
+ __m256i p16_0 = lasx_madd_h_b(__lasx_xvsub_b(q4_0, m32s), q8_0);
8538
+ __m256i p16_1 = lasx_madd_h_b(__lasx_xvsub_b(q4_1, m32s), q8_1);
8539
+ __m256i p16_2 = lasx_madd_h_b(__lasx_xvsub_b(q4_2, m32s), q8_2);
8540
+ __m256i p16_3 = lasx_madd_h_b(__lasx_xvsub_b(q4_3, m32s), q8_3);
7322
8541
 
7323
- p16_0 = lasx_madd_h(lasx_ext8_16(scale_0), p16_0);
7324
- p16_1 = lasx_madd_h(lasx_ext8_16(scale_1), p16_1);
7325
- p16_2 = lasx_madd_h(lasx_ext8_16(scale_2), p16_2);
7326
- p16_3 = lasx_madd_h(lasx_ext8_16(scale_3), p16_3);
8542
+ p16_0 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 0), p16_0);
8543
+ p16_1 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 1), p16_1);
8544
+ p16_2 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 2), p16_2);
8545
+ p16_3 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 3), p16_3);
7327
8546
 
7328
8547
  sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p16_0, p16_1));
7329
8548
  sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p16_2, p16_3));
@@ -7333,7 +8552,130 @@ void wsp_ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void
7333
8552
  }
7334
8553
 
7335
8554
  *s = hsum_float_8(acc);
8555
+ #elif defined(__VXE__) || defined(__VXE2__)
8556
+ float sum = 0;
8557
+
8558
+ // Lower 4-bit and upper 2-bit masks
8559
+ const uint8x16_t v_lm = vec_splat_u8(0x0F);
8560
+ const uint8x16_t v_um = vec_splat_u8(0x03);
8561
+
8562
+ const int32x4_t v_z = vec_splat_s32(0);
8563
+
8564
+ int8x16_t q6b[4];
8565
+ uint8x16_t q6h[4];
8566
+
8567
+ uint8x16_t v_xl[4];
8568
+ uint8x16_t v_xh[2];
8569
+ int8x16_t v_y[4];
8570
+
8571
+ for (int i = 0; i < nb; ++i) {
8572
+ const float d_all = WSP_GGML_FP16_TO_FP32(x[i].d);
8573
+
8574
+ const uint8_t * restrict x0l = x[i].ql;
8575
+ const uint8_t * restrict x0h = x[i].qh;
8576
+ const int8_t * restrict y0 = y[i].qs;
8577
+
8578
+ const int8_t * restrict scale = x[i].scales;
8579
+
8580
+ const int16x8_t v_ysumsl = vec_xl(0 , y[i].bsums);
8581
+ const int16x8_t v_ysumsh = vec_xl(16, y[i].bsums);
8582
+
8583
+ const int8x16_t v_scale = vec_xl(0, scale);
8584
+ const int16x8_t v_scalel = vec_unpackh(v_scale);
8585
+ const int16x8_t v_scaleh = vec_unpackl(v_scale);
8586
+
8587
+ const int32x4_t v_minslo = vec_mulo(v_ysumsl, v_scalel);
8588
+ const int32x4_t v_minsle = vec_mule(v_ysumsl, v_scalel);
8589
+ const int32x4_t v_minsho = vec_mulo(v_ysumsh, v_scaleh);
8590
+ const int32x4_t v_minshe = vec_mule(v_ysumsh, v_scaleh);
8591
+ const int32x4_t v_mins = v_minslo + v_minsle + v_minsho + v_minshe;
8592
+
8593
+ const int32_t mins = v_mins[0] + v_mins[1] + v_mins[2] + v_mins[3];
8594
+
8595
+ int32_t isum = 0;
8596
+ for (int j = 0; j < QK_K/128; ++j) {
8597
+ // Load model upper 2 bits
8598
+ v_xh[0] = vec_xl(0 , x0h);
8599
+ v_xh[1] = vec_xl(16, x0h);
8600
+ x0h += 32;
8601
+
8602
+ // Load model lower 4 bits
8603
+ v_xl[0] = vec_xl(0 , x0l);
8604
+ v_xl[1] = vec_xl(16, x0l);
8605
+ v_xl[2] = vec_xl(32, x0l);
8606
+ v_xl[3] = vec_xl(48, x0l);
8607
+ x0l += 64;
8608
+
8609
+ // Load activation quants
8610
+ v_y[0] = vec_xl(0 , y0);
8611
+ v_y[1] = vec_xl(16, y0);
8612
+ v_y[2] = vec_xl(32, y0);
8613
+ v_y[3] = vec_xl(48, y0);
8614
+ y0 += 64;
8615
+
8616
+ q6h[0] = vec_sl(vec_and(v_um, v_xh[0]), 4);
8617
+ q6h[1] = vec_sl(vec_and(v_um, v_xh[1]), 4);
8618
+ uint8x16_t shifted = vec_sr(v_xh[0], 2);
8619
+ q6h[2] = vec_sl(vec_and(v_um, shifted), 4);
8620
+ shifted = vec_sr(v_xh[1], 2);
8621
+ q6h[3] = vec_sl(vec_and(v_um, shifted), 4);
8622
+
8623
+ q6b[0] = (int8x16_t)(vec_or(vec_and(v_xl[0], v_lm), q6h[0]));
8624
+ q6b[1] = (int8x16_t)(vec_or(vec_and(v_xl[1], v_lm), q6h[1]));
8625
+ q6b[2] = (int8x16_t)(vec_or(vec_and(v_xl[2], v_lm), q6h[2]));
8626
+ q6b[3] = (int8x16_t)(vec_or(vec_and(v_xl[3], v_lm), q6h[3]));
8627
+
8628
+ int32x4_t summs0 = wsp_ggml_vec_dot(v_z, q6b[0], v_y[0]);
8629
+ int32x4_t summs1 = wsp_ggml_vec_dot(v_z, q6b[1], v_y[1]);
8630
+ int32x4_t summs2 = wsp_ggml_vec_dot(v_z, q6b[2], v_y[2]);
8631
+ int32x4_t summs3 = wsp_ggml_vec_dot(v_z, q6b[3], v_y[3]);
8632
+
8633
+ isum += (summs0[0] + summs0[1] + summs0[2] + summs0[3]) * scale[0] +
8634
+ (summs1[0] + summs1[1] + summs1[2] + summs1[3]) * scale[1] +
8635
+ (summs2[0] + summs2[1] + summs2[2] + summs2[3]) * scale[2] +
8636
+ (summs3[0] + summs3[1] + summs3[2] + summs3[3]) * scale[3];
8637
+
8638
+ scale += 4;
8639
+
8640
+
8641
+ // Load activation quants
8642
+ v_y[0] = vec_xl(0 , y0);
8643
+ v_y[1] = vec_xl(16, y0);
8644
+ v_y[2] = vec_xl(32, y0);
8645
+ v_y[3] = vec_xl(48, y0);
8646
+ y0 += 64;
8647
+
8648
+ shifted = vec_sr(v_xh[0], 4);
8649
+ q6h[0] = vec_sl(vec_and(v_um, shifted), 4);
8650
+ shifted = vec_sr(v_xh[1], 4);
8651
+ q6h[1] = vec_sl(vec_and(v_um, shifted), 4);
8652
+ shifted = vec_sr(v_xh[0], 6);
8653
+ q6h[2] = vec_sl(vec_and(v_um, shifted), 4);
8654
+ shifted = vec_sr(v_xh[1], 6);
8655
+ q6h[3] = vec_sl(vec_and(v_um, shifted), 4);
8656
+
8657
+ q6b[0] = (int8x16_t)(vec_or(vec_sr(v_xl[0], 4), q6h[0]));
8658
+ q6b[1] = (int8x16_t)(vec_or(vec_sr(v_xl[1], 4), q6h[1]));
8659
+ q6b[2] = (int8x16_t)(vec_or(vec_sr(v_xl[2], 4), q6h[2]));
8660
+ q6b[3] = (int8x16_t)(vec_or(vec_sr(v_xl[3], 4), q6h[3]));
8661
+
8662
+ summs0 = wsp_ggml_vec_dot(v_z, q6b[0], v_y[0]);
8663
+ summs1 = wsp_ggml_vec_dot(v_z, q6b[1], v_y[1]);
8664
+ summs2 = wsp_ggml_vec_dot(v_z, q6b[2], v_y[2]);
8665
+ summs3 = wsp_ggml_vec_dot(v_z, q6b[3], v_y[3]);
8666
+
8667
+ isum += (summs0[0] + summs0[1] + summs0[2] + summs0[3]) * scale[0] +
8668
+ (summs1[0] + summs1[1] + summs1[2] + summs1[3]) * scale[1] +
8669
+ (summs2[0] + summs2[1] + summs2[2] + summs2[3]) * scale[2] +
8670
+ (summs3[0] + summs3[1] + summs3[2] + summs3[3]) * scale[3];
8671
+
8672
+ scale += 4;
8673
+ }
7336
8674
 
8675
+ sum += d_all * y[i].d * (isum - 32 * mins);
8676
+ }
8677
+
8678
+ *s = sum;
7337
8679
  #else
7338
8680
 
7339
8681
  int8_t aux8[QK_K];
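For reference, a minimal scalar sketch (not part of the diff) of how the VXE Q6_K path above reconstructs one weight from its packed low nibble and 2-bit high part; the "- 32" offset is what the `isum - 32 * mins` correction accounts for in the vector code. Names and the assumed <stdint.h> include are illustrative.

    // Hedged scalar helper: ql holds the low 4 bits, qh the byte that carries
    // the 2 high bits at bit position `shift`.
    static inline int32_t q6_k_unpack_one_ref(uint8_t ql, uint8_t qh, int shift) {
        const int32_t lo = ql & 0x0F;              // bits 0..3
        const int32_t hi = (qh >> shift) & 0x03;   // bits 4..5
        return (lo | (hi << 4)) - 32;              // center the 6-bit value on zero
    }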
@@ -7694,7 +9036,57 @@ void wsp_ggml_vec_dot_iq2_xxs_q8_K(int n, float * restrict s, size_t bs, const v
7694
9036
  }
7695
9037
 
7696
9038
  *s = 0.125f * hsum_float_8(accumf);
7697
-
9039
+ //#elif defined(__VXE__) || defined(__VXE2__)
9040
+ // const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
9041
+ //
9042
+ // uint32_t aux32[4];
9043
+ // const uint8_t * aux8 = (const uint8_t *)aux32;
9044
+ //
9045
+ // float sumf = 0;
9046
+ //
9047
+ // for (int i = 0; i < nb; ++i) {
9048
+ // const float d = WSP_GGML_FP16_TO_FP32(x[i].d) * y[i].d;
9049
+ // const uint16_t * restrict q2 = x[i].qs;
9050
+ // const int8_t * restrict q8 = y[i].qs;
9051
+ //
9052
+ // float sumf1 = 0, sumf2 = 0;
9053
+ //
9054
+ // for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
9055
+ // int8x16_t q8b0 = vec_xl( 0, q8);
9056
+ // int8x16_t q8b1 = vec_xl(16, q8);
9057
+ // int8x16_t q8b2 = vec_xl(32, q8);
9058
+ // int8x16_t q8b3 = vec_xl(48, q8);
9059
+ // q8 += 64;
9060
+ //
9061
+ // memcpy(aux32, q2, 4 * sizeof(uint32_t));
9062
+ // q2 += 8;
9063
+ //
9064
+ // int8x16_t q2u0 = { *(const int64_t *)(iq2xxs_grid + aux8[ 0]), *(const int64_t *)(iq2xxs_grid + aux8[ 1]) };
9065
+ // int8x16_t q2u1 = { *(const int64_t *)(iq2xxs_grid + aux8[ 2]), *(const int64_t *)(iq2xxs_grid + aux8[ 3]) };
9066
+ // int8x16_t q2u2 = { *(const int64_t *)(iq2xxs_grid + aux8[ 8]), *(const int64_t *)(iq2xxs_grid + aux8[ 9]) };
9067
+ // int8x16_t q2u3 = { *(const int64_t *)(iq2xxs_grid + aux8[10]), *(const int64_t *)(iq2xxs_grid + aux8[11]) };
9068
+ //
9069
+ // int8x16_t q2s0 = { *(const int64_t *)(signs64 + ((aux32[1] >> 0) & 127)), *(const int64_t *)(signs64 + ((aux32[1] >> 7) & 127)) };
9070
+ // int8x16_t q2s1 = { *(const int64_t *)(signs64 + ((aux32[1] >> 14) & 127)), *(const int64_t *)(signs64 + ((aux32[1] >> 21) & 127)) };
9071
+ // int8x16_t q2s2 = { *(const int64_t *)(signs64 + ((aux32[3] >> 0) & 127)), *(const int64_t *)(signs64 + ((aux32[3] >> 7) & 127)) };
9072
+ // int8x16_t q2s3 = { *(const int64_t *)(signs64 + ((aux32[3] >> 14) & 127)), *(const int64_t *)(signs64 + ((aux32[3] >> 21) & 127)) };
9073
+ //
9074
+ // q2u0 = vec_mul(q2u0, q2s0);
9075
+ // q2u1 = vec_mul(q2u1, q2s1);
9076
+ // q2u2 = vec_mul(q2u2, q2s2);
9077
+ // q2u3 = vec_mul(q2u3, q2s3);
9078
+ //
9079
+ // const int32x4_t p1 = wsp_ggml_vec_dot(wsp_ggml_vec_dot(vec_splat_s32(0), q2u0, q8b0), q2u1, q8b1);
9080
+ // const int32x4_t p2 = wsp_ggml_vec_dot(wsp_ggml_vec_dot(vec_splat_s32(0), q2u2, q8b2), q2u3, q8b3);
9081
+ //
9082
+ // sumf1 += (p1[0] + p1[1] + p1[2] + p1[3]) * (0.5f + (aux32[1] >> 28));
9083
+ // sumf2 += (p2[0] + p2[1] + p2[2] + p2[3]) * (0.5f + (aux32[3] >> 28));
9084
+ // }
9085
+ //
9086
+ // sumf += d * (sumf1 + sumf2);
9087
+ // }
9088
+ //
9089
+ // *s = 0.25f * sumf;
7698
9090
  #else
7699
9091
 
7700
9092
  uint32_t aux32[2];
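As a hedged sketch only (illustrative names, assuming <stdint.h> and <string.h>), this is the per-32-weight metadata split the commented VXE block above performs: one 32-bit word of four grid-index bytes, and one 32-bit word holding four 7-bit sign selectors plus a 4-bit sub-scale.

    static inline void iq2xxs_unpack_meta_ref(const uint16_t * q2,
                                              uint8_t grid_idx[4],
                                              uint8_t sign_idx[4],
                                              float * sub_scale) {
        uint32_t aux32[2];
        memcpy(aux32, q2, sizeof(aux32));                          // 4 x uint16 per 32 weights
        for (int k = 0; k < 4; ++k) {
            grid_idx[k] = (uint8_t)(aux32[0] >> (8 * k));          // row of iq2xxs_grid
            sign_idx[k] = (uint8_t)((aux32[1] >> (7 * k)) & 127);  // row of keven_signs_q2xs
        }
        *sub_scale = 0.5f + (float)(aux32[1] >> 28);               // matches the commented code
    }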
@@ -9648,13 +11040,9 @@ static inline __m256i mul_add_epi8(const __m256i x, const __m256i y) {
9648
11040
  }
9649
11041
  #elif defined(__loongarch_asx)
9650
11042
  static inline __m256i mul_add_epi8(const __m256i x, const __m256i y) {
9651
- const __m256i ax = __lasx_xvsigncov_b(x, x);
9652
- const __m256i sy = __lasx_xvsigncov_b(x, y);
9653
- __m256i tmp1, tmp2, tmp3;
9654
- tmp1 = __lasx_xvmulwev_h_bu_b(ax, sy);
9655
- tmp2 = __lasx_xvmulwod_h_bu_b(ax, sy);
9656
- tmp3 = __lasx_xvadd_h(tmp1, tmp2);
9657
- return __lasx_xvsat_h(tmp3, 15);
11043
+ const __m256i a = __lasx_xvmulwev_h_b(x, y);
11044
+ const __m256i b = __lasx_xvmulwod_h_b(x, y);
11045
+ return __lasx_xvadd_h(a, b);
9658
11046
  }
9659
11047
  #endif
9660
11048
 
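A hedged scalar model of the simplified loongarch `mul_add_epi8` above: `__lasx_xvmulwev_h_b` and `__lasx_xvmulwod_h_b` widen-multiply the even and odd signed bytes, and the add folds each adjacent pair into a 16-bit lane; for the 4-bit lookup values this helper is used with, the pairwise sum stays within int16 range, which is presumably why the old saturation step could be dropped. Illustrative only, assuming <stdint.h>.

    static inline void mul_add_epi8_ref(const int8_t x[32], const int8_t y[32], int16_t out[16]) {
        for (int i = 0; i < 16; ++i) {
            // even-lane product + odd-lane product, widened to 16 bits
            out[i] = (int16_t)(x[2*i] * y[2*i] + x[2*i + 1] * y[2*i + 1]);
        }
    }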
@@ -10459,6 +11847,27 @@ void wsp_ggml_vec_dot_iq4_nl_q8_0(int n, float * restrict s, size_t bs, const vo
10459
11847
 
10460
11848
  sumf = hsum_float_8(__lasx_xvfadd_s(accum1, accum2));
10461
11849
 
11850
+ #elif defined(__VXE__) || defined(__VXE2__)
11851
+ const int8x16_t v_k = vec_xl(0, kvalues_iq4nl);
11852
+ const uint8x16_t v_m = vec_splat_u8(0x0F);
11853
+
11854
+ for (; ib < nb; ++ib) {
11855
+ const block_iq4_nl * restrict x0 = &x[ib];
11856
+ const block_q8_0 * restrict y0 = &y[ib];
11857
+
11858
+ const uint8x16_t v_x = vec_xl(0, x0->qs);
11859
+ int8x16_t v_xl = (int8x16_t)vec_and(v_x, v_m);
11860
+ int8x16_t v_xh = (int8x16_t)vec_sr(v_x, 4);
11861
+
11862
+ v_xl = vec_perm(v_k, v_k, (uchar8x16_t)v_xl);
11863
+ v_xh = vec_perm(v_k, v_k, (uchar8x16_t)v_xh);
11864
+
11865
+ const int8x16_t v_yl = vec_xl(0 , y0->qs);
11866
+ const int8x16_t v_yh = vec_xl(QK8_0/2, y0->qs);
11867
+ const int32x4_t v_xy = wsp_ggml_vec_dot(wsp_ggml_vec_dot(vec_splats(0), v_xl, v_yl), v_xh, v_yh);
11868
+
11869
+ sumf += WSP_GGML_FP16_TO_FP32(x0->d) * WSP_GGML_FP16_TO_FP32(y0->d) * (v_xy[0] + v_xy[1] + v_xy[2] + v_xy[3]);
11870
+ }
10462
11871
  #endif
10463
11872
  for (; ib < nb; ++ib) {
10464
11873
  const float d = WSP_GGML_FP16_TO_FP32(y[ib].d)*WSP_GGML_FP16_TO_FP32(x[ib].d);
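The VXE `iq4_nl` x `q8_0` path added above is equivalent to this scalar sketch (illustrative; it assumes QK4_NL == 32 and the 16-entry `kvalues_iq4nl` table from ggml-common.h): each packed byte holds two 4-bit indices, low nibbles pairing with the first 16 activation bytes and high nibbles with the last 16.

    // Minimal scalar sketch; dx/dy are the FP16-decoded block scales.
    static float iq4_nl_dot_block_ref(const uint8_t qs[16], const int8_t q8[32],
                                      float dx, float dy) {
        int32_t sum = 0;
        for (int j = 0; j < 16; ++j) {
            sum += kvalues_iq4nl[qs[j] & 0x0F] * q8[j];       // low nibble  -> y[0..15]
            sum += kvalues_iq4nl[qs[j] >>   4] * q8[j + 16];  // high nibble -> y[16..31]
        }
        return dx * dy * (float)sum;
    }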
@@ -10704,67 +12113,31 @@ void wsp_ggml_vec_dot_iq4_xs_q8_K(int n, float * restrict s, size_t bs, const vo
10704
12113
  #elif defined(__loongarch_asx)
10705
12114
 
10706
12115
  const __m128i values128 = __lsx_vld((const __m128i*)kvalues_iq4nl, 0);
10707
- const __m128i m4b = __lsx_vreplgr2vr_b(0x0f);
10708
12116
 
10709
12117
  __m256 accum = (__m256)__lasx_xvldi(0);
10710
- __m256i tmp1;
10711
- __m128i tmp0, tmp2, tmp3, tmp4, mask_8f, mask;
10712
12118
 
10713
- mask_8f = __lsx_vreplgr2vr_b(0x8f);
10714
12119
  for (int ibl = 0; ibl < nb; ++ibl) {
10715
12120
  const uint8_t * qs = x[ibl].qs;
10716
12121
  const int8_t * q8 = y[ibl].qs;
10717
12122
  uint16_t sh = x[ibl].scales_h;
10718
12123
  __m256i sumi1 = __lasx_xvldi(0);
10719
12124
  __m256i sumi2 = __lasx_xvldi(0);
10720
- __m128i zero = __lsx_vldi(0);
10721
12125
  for (int ib = 0; ib < QK_K/32; ib += 2) {
10722
- const __m128i q4bits_1 = __lsx_vld((const __m128i*)qs, 0); qs += 16;
10723
- const __m128i q4bits_2 = __lsx_vld((const __m128i*)qs, 0); qs += 16;
12126
+ const __m128i q4bits_1 = __lsx_vld((const __m128i*)qs, 0); qs += 16;
12127
+ const __m128i q4bits_2 = __lsx_vld((const __m128i*)qs, 0); qs += 16;
10724
12128
  const __m256i q8b_1 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32;
10725
12129
  const __m256i q8b_2 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32;
10726
- tmp2 = __lsx_vand_v(__lsx_vand_v(__lsx_vsrli_h(q4bits_1, 4), m4b), mask_8f);
10727
- tmp0 = __lsx_vori_b(tmp2, 0x10);
10728
- mask = __lsx_vsle_b(zero, tmp2);
10729
- tmp3 = __lsx_vand_v(tmp0, mask);
10730
- tmp3 = __lsx_vshuf_b(values128, zero, tmp3);
10731
-
10732
- tmp2 = __lsx_vand_v(__lsx_vand_v(q4bits_1, m4b), mask_8f);
10733
- tmp0 = __lsx_vori_b(tmp2, 0x10);
10734
- mask = __lsx_vsle_b(zero, tmp2);
10735
- tmp4 = __lsx_vand_v(tmp0, mask);
10736
- tmp4 = __lsx_vshuf_b(values128, zero, tmp4);
10737
-
10738
- const __m256i q4b_1 = lasx_insertf128(tmp3, tmp4);
10739
-
10740
- tmp2 = __lsx_vand_v(__lsx_vand_v(__lsx_vsrli_h(q4bits_2, 4), m4b), mask_8f);
10741
- tmp0 = __lsx_vori_b(tmp2, 0x10);
10742
- mask = __lsx_vsle_b(zero, tmp2);
10743
- tmp3 = __lsx_vand_v(tmp0, mask);
10744
- tmp3 = __lsx_vshuf_b(values128, zero, tmp3);
10745
-
10746
- tmp2 = __lsx_vand_v(__lsx_vand_v(q4bits_2, m4b), mask_8f);
10747
- tmp0 = __lsx_vori_b(tmp2, 0x10);
10748
- mask = __lsx_vsle_b(zero, tmp2);
10749
- tmp4 = __lsx_vand_v(tmp0, mask);
10750
- tmp4 = __lsx_vshuf_b(values128, zero, tmp4);
10751
-
10752
- const __m256i q4b_2 = lasx_insertf128(tmp3, tmp4);
10753
-
12130
+ const __m256i q4b_1 = lasx_insertf128(__lsx_vshuf_b(values128, values128, __lsx_vsrli_b(q4bits_1, 4)),
12131
+ __lsx_vshuf_b(values128, values128, __lsx_vandi_b(q4bits_1, 0xf)));
12132
+ const __m256i q4b_2 = lasx_insertf128(__lsx_vshuf_b(values128, values128, __lsx_vsrli_b(q4bits_2, 4)),
12133
+ __lsx_vshuf_b(values128, values128, __lsx_vandi_b(q4bits_2, 0xf)));
10754
12134
  const __m256i p16_1 = mul_add_epi8(q4b_1, q8b_1);
10755
12135
  const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2);
10756
12136
  const int16_t ls1 = ((x[ibl].scales_l[ib/2] & 0xf) | ((sh << 4) & 0x30)) - 32;
10757
12137
  const int16_t ls2 = ((x[ibl].scales_l[ib/2] >> 4) | ((sh << 2) & 0x30)) - 32;
10758
12138
  sh >>= 4;
10759
- __m256i tmp5, tmp6;
10760
- tmp1 = __lasx_xvreplgr2vr_h(ls1);
10761
- tmp5 = __lasx_xvmulwev_w_h(p16_1, tmp1);
10762
- tmp6 = __lasx_xvmulwod_w_h(p16_1, tmp1);
10763
- const __m256i p_1 = __lasx_xvadd_w(tmp5, tmp6);
10764
- tmp1 = __lasx_xvreplgr2vr_h(ls2);
10765
- tmp5 = __lasx_xvmulwev_w_h(p16_2, tmp1);
10766
- tmp6 = __lasx_xvmulwod_w_h(p16_2, tmp1);
10767
- const __m256i p_2 = __lasx_xvadd_w(tmp5, tmp6);
12139
+ const __m256i p_1 = lasx_madd_h(p16_1, __lasx_xvreplgr2vr_h(ls1));
12140
+ const __m256i p_2 = lasx_madd_h(p16_2, __lasx_xvreplgr2vr_h(ls2));
10768
12141
  sumi1 = __lasx_xvadd_w(p_1, sumi1);
10769
12142
  sumi2 = __lasx_xvadd_w(p_2, sumi2);
10770
12143
  }
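For the scale step just above, a hedged scalar model of `lasx_madd_h(p16, __lasx_xvreplgr2vr_h(ls))`: adjacent 16-bit products are scaled and folded into 32-bit lanes, mirroring what the removed xvmulwev/xvmulwod/xvadd sequence did by hand. Illustrative only, assuming <stdint.h>.

    static inline void madd_h_splat_ref(const int16_t p16[16], int16_t ls, int32_t out[8]) {
        for (int i = 0; i < 8; ++i) {
            out[i] = (int32_t)p16[2*i] * ls + (int32_t)p16[2*i + 1] * ls;
        }
    }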
@@ -10773,6 +12146,56 @@ void wsp_ggml_vec_dot_iq4_xs_q8_K(int n, float * restrict s, size_t bs, const vo
10773
12146
  }
10774
12147
 
10775
12148
  *s = hsum_float_8(accum);
12149
+ #elif defined(__VXE__) || defined(__VXE2__)
12150
+ const int8x16_t v_k = vec_xl(0, kvalues_iq4nl);
12151
+ const uint8x16_t v_m = vec_splat_u8(0x0F);
12152
+
12153
+ float sumf = 0;
12154
+
12155
+ for (int ibl = 0; ibl < nb; ++ibl) {
12156
+ const uint8_t * restrict q4 = x[ibl].qs;
12157
+ const int8_t * restrict q8 = y[ibl].qs;
12158
+
12159
+ uint16_t h = x[ibl].scales_h;
12160
+
12161
+ int sumi1 = 0, sumi2 = 0;
12162
+ for (int ib = 0; ib < QK_K/64; ++ib) {
12163
+ const uint8x16_t v_x0 = vec_xl(0 , q4);
12164
+ const uint8x16_t v_x1 = vec_xl(QK4_NL/2, q4);
12165
+ q4 += 32;
12166
+
12167
+ int8x16_t v_x0l = (int8x16_t)vec_and(v_x0, v_m);
12168
+ int8x16_t v_x0h = (int8x16_t)vec_sr(v_x0, 4);
12169
+ int8x16_t v_x1l = (int8x16_t)vec_and(v_x1, v_m);
12170
+ int8x16_t v_x1h = (int8x16_t)vec_sr(v_x1, 4);
12171
+
12172
+ v_x0l = vec_perm(v_k, v_k, (uchar8x16_t)v_x0l);
12173
+ v_x0h = vec_perm(v_k, v_k, (uchar8x16_t)v_x0h);
12174
+ v_x1l = vec_perm(v_k, v_k, (uchar8x16_t)v_x1l);
12175
+ v_x1h = vec_perm(v_k, v_k, (uchar8x16_t)v_x1h);
12176
+
12177
+ const int8x16_t v_y0 = vec_xl( 0, q8);
12178
+ const int8x16_t v_y1 = vec_xl(16, q8);
12179
+ const int8x16_t v_y2 = vec_xl(32, q8);
12180
+ const int8x16_t v_y3 = vec_xl(48, q8);
12181
+ q8 += 64;
12182
+
12183
+ int32x4_t vsumi0 = wsp_ggml_vec_dot(wsp_ggml_vec_dot(vec_splats(0), v_x0l, v_y0), v_x0h, v_y1);
12184
+ int32x4_t vsumi1 = wsp_ggml_vec_dot(wsp_ggml_vec_dot(vec_splats(0), v_x1l, v_y2), v_x1h, v_y3);
12185
+
12186
+ int ls1 = ((x[ibl].scales_l[ib] & 0xF) | ((h << 4) & 0x30)) - 32;
12187
+ int ls2 = ((x[ibl].scales_l[ib] >> 4) | ((h << 2) & 0x30)) - 32;
12188
+
12189
+ h >>= 4;
12190
+
12191
+ sumi1 += (vsumi0[0] + vsumi0[1] + vsumi0[2] + vsumi0[3]) * ls1;
12192
+ sumi2 += (vsumi1[0] + vsumi1[1] + vsumi1[2] + vsumi1[3]) * ls2;
12193
+ }
12194
+
12195
+ sumf += WSP_GGML_FP16_TO_FP32(x[ibl].d) * y[ibl].d * (sumi1 + sumi2);
12196
+ }
12197
+
12198
+ *s = sumf;
10776
12199
 
10777
12200
  #else
10778
12201
  float sumf = 0;