@fugood/llama.node 0.3.12 → 0.3.13

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (75)
  1. package/bin/darwin/arm64/llama-node.node +0 -0
  2. package/bin/darwin/x64/llama-node.node +0 -0
  3. package/bin/linux/arm64/llama-node.node +0 -0
  4. package/bin/linux/x64/llama-node.node +0 -0
  5. package/bin/linux-cuda/arm64/llama-node.node +0 -0
  6. package/bin/linux-cuda/x64/llama-node.node +0 -0
  7. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  8. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  9. package/bin/win32/arm64/llama-node.node +0 -0
  10. package/bin/win32/arm64/node.lib +0 -0
  11. package/bin/win32/x64/llama-node.node +0 -0
  12. package/bin/win32/x64/node.lib +0 -0
  13. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  14. package/bin/win32-vulkan/arm64/node.lib +0 -0
  15. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  16. package/bin/win32-vulkan/x64/node.lib +0 -0
  17. package/lib/binding.ts +1 -0
  18. package/package.json +1 -1
  19. package/src/LlamaCompletionWorker.cpp +14 -0
  20. package/src/LlamaContext.cpp +13 -4
  21. package/src/llama.cpp/.github/workflows/build.yml +35 -3
  22. package/src/llama.cpp/.github/workflows/docker.yml +2 -0
  23. package/src/llama.cpp/.github/workflows/labeler.yml +1 -1
  24. package/src/llama.cpp/common/CMakeLists.txt +20 -3
  25. package/src/llama.cpp/common/arg.cpp +180 -3
  26. package/src/llama.cpp/common/chat-template.hpp +21 -7
  27. package/src/llama.cpp/common/chat.cpp +220 -101
  28. package/src/llama.cpp/common/chat.hpp +3 -0
  29. package/src/llama.cpp/common/common.h +15 -7
  30. package/src/llama.cpp/common/llguidance.cpp +3 -3
  31. package/src/llama.cpp/common/log.cpp +1 -0
  32. package/src/llama.cpp/common/log.h +2 -1
  33. package/src/llama.cpp/common/minja.hpp +24 -9
  34. package/src/llama.cpp/common/sampling.cpp +52 -46
  35. package/src/llama.cpp/common/speculative.h +1 -1
  36. package/src/llama.cpp/docs/build.md +2 -2
  37. package/src/llama.cpp/examples/imatrix/imatrix.cpp +2 -1
  38. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +6 -5
  39. package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/CMakeLists.txt +1 -1
  40. package/src/llama.cpp/examples/perplexity/perplexity.cpp +1 -0
  41. package/src/llama.cpp/examples/run/run.cpp +5 -12
  42. package/src/llama.cpp/examples/server/CMakeLists.txt +1 -1
  43. package/src/llama.cpp/examples/server/httplib.h +381 -292
  44. package/src/llama.cpp/examples/server/server.cpp +58 -47
  45. package/src/llama.cpp/examples/server/utils.hpp +7 -5
  46. package/src/llama.cpp/ggml/include/ggml-cpu.h +1 -1
  47. package/src/llama.cpp/ggml/include/ggml-metal.h +1 -1
  48. package/src/llama.cpp/ggml/include/ggml-vulkan.h +0 -2
  49. package/src/llama.cpp/ggml/include/ggml.h +1 -1
  50. package/src/llama.cpp/ggml/src/ggml-common.h +0 -2
  51. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +6 -12
  52. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +852 -268
  53. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +200 -107
  54. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +2 -5
  55. package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +9 -8
  56. package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +2 -2
  57. package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +26 -4
  58. package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +6 -7
  59. package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +812 -569
  60. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +25 -1
  61. package/src/llama.cpp/ggml/src/ggml.c +1 -1
  62. package/src/llama.cpp/include/llama.h +14 -10
  63. package/src/llama.cpp/src/llama-grammar.cpp +1 -1
  64. package/src/llama.cpp/src/llama-grammar.h +1 -1
  65. package/src/llama.cpp/src/llama-impl.h +6 -6
  66. package/src/llama.cpp/src/llama-kv-cache.h +1 -1
  67. package/src/llama.cpp/src/llama-mmap.h +1 -0
  68. package/src/llama.cpp/src/llama-model.cpp +1 -1
  69. package/src/llama.cpp/src/llama-sampling.cpp +131 -57
  70. package/src/llama.cpp/src/llama.cpp +7 -5
  71. package/src/llama.cpp/src/unicode.cpp +9 -2
  72. package/src/llama.cpp/tests/test-backend-ops.cpp +5 -5
  73. package/src/llama.cpp/tests/test-chat.cpp +237 -69
  74. package/src/llama.cpp/tests/test-gguf.cpp +4 -4
  75. package/src/llama.cpp/tests/test-sampling.cpp +15 -0
@@ -501,30 +501,15 @@ static __m256i lasx_shuffle_b(__m256i a, __m256i b) {
  }

  static __m256i lasx_extu8_16(__m128i a) {
- __m128i zero = __lsx_vldi(0);
- __m128i vlo = __lsx_vilvl_b(zero, a);
- __m128i vhi = __lsx_vilvh_b(zero, a);
- return lasx_set_q(vhi, vlo);
+ return __lasx_vext2xv_hu_bu(____m256i(a));
  }

  static __m256i lasx_ext8_16(__m128i a) {
- __m128i sign = __lsx_vslti_b(a, 0);
- __m128i vlo = __lsx_vilvl_b(sign, a);
- __m128i vhi = __lsx_vilvh_b(sign, a);
- return lasx_set_q(vhi, vlo);
+ return __lasx_vext2xv_h_b(____m256i(a));
  }

  static __m256i lasx_ext16_32(__m128i a) {
- __m256i tmp1;
- tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 0), 0);
- tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 1), 1);
- tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 2), 2);
- tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 3), 3);
- tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 4), 4);
- tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 5), 5);
- tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 6), 6);
- tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 7), 7);
- return tmp1;
+ return __lasx_vext2xv_w_h(____m256i(a));
  }

  static __m128i lasx_extracti128( __m256i a, int pos) {
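Note: the three helpers above are widening conversions; this release maps each one to a single LASX vext2xv instruction instead of emulating it with interleaves or per-lane inserts. A scalar sketch of what they compute, for reference only (not part of the package):

    // Reference semantics of the LASX widening helpers (scalar sketch).
    // lasx_extu8_16 : zero-extend 16 x uint8  -> 16 x uint16
    // lasx_ext8_16  : sign-extend 16 x int8   -> 16 x int16
    // lasx_ext16_32 : sign-extend  8 x int16  ->  8 x int32
    #include <stdint.h>
    static inline void ext8_16_ref(const int8_t src[16], int16_t dst[16]) {
        for (int i = 0; i < 16; ++i) {
            dst[i] = (int16_t) src[i];   // sign extension, one lane at a time
        }
    }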
@@ -577,6 +562,41 @@ static __m256i lasx_packs_h(__m256i a, __m256i b) {
  return __lasx_xvpickev_b(tmp1, tmp);
  }

+ static inline __m256i lasx_madd_h_b(__m256i a, __m256i b) {
+ __m256i tmp1, tmp2;
+ tmp1 = __lasx_xvmulwev_h_b(a, b);
+ tmp2 = __lasx_xvmulwod_h_b(a, b);
+ return __lasx_xvadd_h(tmp1, tmp2);
+ }
+
+ static inline __m256i lasx_xvrepl128vei_h(__m256i a, const unsigned int b) {
+ switch (b) {
+ case 0: return __lasx_xvrepl128vei_h(a, 0);
+ case 1: return __lasx_xvrepl128vei_h(a, 1);
+ case 2: return __lasx_xvrepl128vei_h(a, 2);
+ case 3: return __lasx_xvrepl128vei_h(a, 3);
+ case 4: return __lasx_xvrepl128vei_h(a, 4);
+ case 5: return __lasx_xvrepl128vei_h(a, 5);
+ case 6: return __lasx_xvrepl128vei_h(a, 6);
+ case 7: return __lasx_xvrepl128vei_h(a, 7);
+ default: __builtin_unreachable();
+ }
+ }
+
+ static inline __m256i lasx_xvandi_b_bit(__m256i a, const unsigned int b) {
+ switch (b) {
+ case 0: return __lasx_xvandi_b(a, 1 << 0);
+ case 1: return __lasx_xvandi_b(a, 1 << 1);
+ case 2: return __lasx_xvandi_b(a, 1 << 2);
+ case 3: return __lasx_xvandi_b(a, 1 << 3);
+ case 4: return __lasx_xvandi_b(a, 1 << 4);
+ case 5: return __lasx_xvandi_b(a, 1 << 5);
+ case 6: return __lasx_xvandi_b(a, 1 << 6);
+ case 7: return __lasx_xvandi_b(a, 1 << 7);
+ default: __builtin_unreachable();
+ }
+ }
+
  // multiply int8_t, add results pairwise twice
  static inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) {
  // Get absolute values of x vectors
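Note: the new lasx_madd_h_b multiplies even and odd signed-byte lanes separately (widening to 16 bits) and adds the two products into one 16-bit lane; it is similar in shape to _mm256_maddubs_epi16 but with both operands signed. The lasx_xvrepl128vei_h / lasx_xvandi_b_bit wrappers dispatch a runtime index through a switch because the LASX intrinsics they wrap take immediate operands. A scalar sketch of the multiply-add, for reference only:

    #include <stdint.h>
    // Scalar sketch of lasx_madd_h_b: for 32 signed bytes in a and b,
    // each 16-bit output lane k is a[2k]*b[2k] + a[2k+1]*b[2k+1].
    static void madd_h_b_ref(const int8_t a[32], const int8_t b[32], int16_t out[16]) {
        for (int k = 0; k < 16; ++k) {
            out[k] = (int16_t)(a[2*k] * b[2*k] + a[2*k + 1] * b[2*k + 1]);
        }
    }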
@@ -592,12 +612,10 @@ static inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) {
  // horizontally add 8 floats
  static inline float hsum_float_8(const __m256 x) {
  __m128 res = lasx_extractf128(x, 1);
- ft_union tmp;
  res = __lsx_vfadd_s(res, lasx_extractf128(x, 0));
  res = __lsx_vfadd_s(res, (__m128)__lsx_vpickod_d((__m128i)res, (__m128i)res));
  res = __lsx_vfadd_s(res, (__m128)__lsx_vinsgr2vr_w(__lsx_vldi(0), __lsx_vpickve2gr_w(res, 1), 0));
- tmp.i = __lsx_vpickve2gr_w(res, 0);
- return tmp.f;
+ return ((v4f32)res)[0];
  }

  // horizontally add 8 int32_t
@@ -673,13 +691,8 @@ static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy)

  // multiply int8_t, add results pairwise twice and return as float vector
  static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
-
- // Get absolute values of x vectors
- const __m256i ax = __lasx_xvsigncov_b(x, x);
- // Sign the values of the y vectors
- const __m256i sy = __lasx_xvsigncov_b(x, y);
-
- return mul_sum_us8_pairs_float(ax, sy);
+ const __m256i dot = lasx_madd_h_b(x, y);
+ return sum_i16_pairs_float(dot);
  }

  static inline __m128i packNibbles( __m256i bytes ) {
@@ -759,7 +772,7 @@ void quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t k)
  y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3);
  }
  }
- #elif defined(__wasm_simd128__)
+ #elif defined __wasm_simd128__
  for (int i = 0; i < nb; i++) {
  v128_t srcv [8];
  v128_t asrcv[8];
@@ -939,7 +952,6 @@ void quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t k)

  #elif defined(__loongarch_asx)
  for (int i = 0; i < nb; i++) {
- ft_union fi;
  __m256 v0 = (__m256)__lasx_xvld( x , 0);
  __m256 v1 = (__m256)__lasx_xvld( x , 32);
  __m256 v2 = (__m256)__lasx_xvld( x , 64);
@@ -957,8 +969,7 @@ void quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t k)
  max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vpickod_d((__m128i) max4, (__m128i)max4 ) );
  __m128 tmp = max4;
  max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vinsgr2vr_w(tmp, __lsx_vpickve2gr_w( max4, 1 ), 0 ));
- fi.i = __lsx_vpickve2gr_w( (__m128i)max4, 0 );
- const float max_scalar = fi.f;
+ const float max_scalar = ((v4f32)max4)[0];

  // Quantize these floats
  const float d = max_scalar / 127.f;
@@ -1049,7 +1060,7 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k)

  y[i].s = GGML_FP32_TO_FP16(d * vaddvq_s32(accv));
  }
- #elif defined(__wasm_simd128__)
+ #elif defined __wasm_simd128__
  for (int i = 0; i < nb; i++) {
  v128_t srcv [8];
  v128_t asrcv[8];
@@ -1263,7 +1274,6 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k)

  #elif defined(__loongarch_asx)
  for (int i = 0; i < nb; i++) {
- ft_union ft;
  __m256 v0 = (__m256)__lasx_xvld( x , 0 );
  __m256 v1 = (__m256)__lasx_xvld( x , 32 );
  __m256 v2 = (__m256)__lasx_xvld( x , 64 );
@@ -1281,8 +1291,7 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k)
  max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vpickod_d((__m128i) max4, (__m128i)max4 ) );
  __m128 tmp = max4;
  max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vextrins_w((__m128i)tmp, (__m128i)max4, 0x10 ));
- ft.i = __lsx_vpickve2gr_w( (__m128i)max4, 0 );
- const float max_scalar = ft.f;
+ const float max_scalar = ((v4f32)max4)[0];

  // Quantize these floats
  const float d = max_scalar / 127.f;
@@ -1665,7 +1674,87 @@ static const int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -1
1665
1674
  //===================================== Q8_K ==============================================
1666
1675
 
1667
1676
  void quantize_row_q8_K(const float * restrict x, void * restrict y, int64_t k) {
1677
+ #ifdef __wasm_simd128__
1678
+ assert(k % QK_K == 0);
1679
+ const int64_t nb = k / QK_K;
1680
+ block_q8_K * restrict yc = y; // Cast to proper type
1681
+
1682
+ for (int i = 0; i < nb; i++) {
1683
+ const float * x_block = x + i * QK_K;
1684
+
1685
+ v128_t min_vec = wasm_v128_load(x_block);
1686
+ v128_t max_vec = min_vec;
1687
+
1688
+ for (int j = 4; j < QK_K; j += 4) {
1689
+ v128_t x_vec = wasm_v128_load(x_block + j);
1690
+ max_vec = wasm_f32x4_pmax(max_vec, x_vec);
1691
+ min_vec = wasm_f32x4_pmin(min_vec, x_vec);
1692
+ }
1693
+ max_vec = wasm_f32x4_pmax(max_vec, wasm_i32x4_shuffle(max_vec, max_vec, 2, 3, 0, 1));
1694
+ max_vec = wasm_f32x4_pmax(max_vec, wasm_i32x4_shuffle(max_vec, max_vec, 1, 0, 3, 2));
1695
+ min_vec = wasm_f32x4_pmin(min_vec, wasm_i32x4_shuffle(min_vec, min_vec, 2, 3, 0, 1));
1696
+ min_vec = wasm_f32x4_pmin(min_vec, wasm_i32x4_shuffle(min_vec, min_vec, 1, 0, 3, 2));
1697
+ float max = wasm_f32x4_extract_lane(max_vec, 0);
1698
+ float min = wasm_f32x4_extract_lane(min_vec, 0);
1699
+ float amax = -min > max ? min : max;
1700
+
1701
+ if (amax == 0.0f) {
1702
+ yc[i].d = 0.0f;
1703
+ const v128_t zero = wasm_i8x16_splat(0);
1704
+ for (int j = 0; j < QK_K; j += 16) {
1705
+ wasm_v128_store(yc[i].qs + j, zero);
1706
+ }
1707
+ continue;
1708
+ }
1709
+
1710
+ const float iscale = -127.0f / amax;
1711
+ const v128_t scale_vec = wasm_f32x4_splat(iscale);
1712
+
1713
+ // Process 16 elements per iteration
1714
+ for (int j = 0, jb = 0; j < QK_K; j += 16, jb++) {
1715
+ // Load and quantize 16 floats
1716
+ v128_t x0 = wasm_v128_load(x_block + j);
1717
+ v128_t x1 = wasm_v128_load(x_block + j + 4);
1718
+ v128_t x2 = wasm_v128_load(x_block + j + 8);
1719
+ v128_t x3 = wasm_v128_load(x_block + j + 12);
1720
+
1721
+ v128_t q0 = wasm_f32x4_nearest(wasm_f32x4_mul(x0, scale_vec));
1722
+ v128_t q1 = wasm_f32x4_nearest(wasm_f32x4_mul(x1, scale_vec));
1723
+ v128_t q2 = wasm_f32x4_nearest(wasm_f32x4_mul(x2, scale_vec));
1724
+ v128_t q3 = wasm_f32x4_nearest(wasm_f32x4_mul(x3, scale_vec));
1725
+
1726
+ // Convert to i32 with saturation
1727
+ v128_t i0 = wasm_i32x4_trunc_sat_f32x4(q0);
1728
+ v128_t i1 = wasm_i32x4_trunc_sat_f32x4(q1);
1729
+ v128_t i2 = wasm_i32x4_trunc_sat_f32x4(q2);
1730
+ v128_t i3 = wasm_i32x4_trunc_sat_f32x4(q3);
1731
+
1732
+ // Pack into 16 i8 values
1733
+ v128_t i8 = wasm_i8x16_narrow_i16x8(
1734
+ wasm_i16x8_narrow_i32x4(i0, i1),
1735
+ wasm_i16x8_narrow_i32x4(i2, i3)
1736
+ );
1737
+ wasm_v128_store(yc[i].qs + j, i8);
1738
+
1739
+ // Calculate bsums using SIMD
1740
+ v128_t sum16 = wasm_i16x8_add(
1741
+ wasm_i16x8_extend_low_i8x16(i8),
1742
+ wasm_i16x8_extend_high_i8x16(i8)
1743
+ );
1744
+ v128_t sum32 = wasm_i32x4_add(
1745
+ wasm_i32x4_extend_low_i16x8(sum16),
1746
+ wasm_i32x4_extend_high_i16x8(sum16)
1747
+ );
1748
+ sum32 = wasm_i32x4_add(sum32, wasm_i32x4_shuffle(sum32, sum32, 2, 3, 0, 1));
1749
+ sum32 = wasm_i32x4_add(sum32, wasm_i32x4_shuffle(sum32, sum32, 1, 0, 3, 2));
1750
+ yc[i].bsums[jb] = wasm_i32x4_extract_lane(sum32, 0);
1751
+ }
1752
+
1753
+ yc[i].d = 1.0f / iscale;
1754
+ }
1755
+ #else
1668
1756
  quantize_row_q8_K_ref(x, y, k);
1757
+ #endif
1669
1758
  }
1670
1759
 
1671
1760
  //===================================== Dot products =================================
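Note: the new __wasm_simd128__ branch above vectorizes quantize_row_q8_K instead of falling back to the reference implementation. Per 256-value super-block it finds the signed value of largest magnitude, scales by -127 over that value, rounds to int8, and records per-16 sums in bsums. A scalar sketch of that flow; the struct and names here are illustrative stand-ins, not the package's headers:

    #include <math.h>
    #include <stdint.h>

    #define QK_K 256

    // Illustrative stand-in for block_q8_K; fields follow the diff above.
    typedef struct {
        float   d;              // de-quantization scale
        int8_t  qs[QK_K];       // quantized values
        int16_t bsums[QK_K/16]; // sums of groups of 16 quants
    } q8_K_block_ref;

    // Scalar sketch of one super-block, matching the WASM path's math.
    static void quantize_q8_K_block_ref(const float *x, q8_K_block_ref *y) {
        float amax = 0.0f, maxv = 0.0f;
        for (int j = 0; j < QK_K; ++j) {
            if (fabsf(x[j]) > amax) { amax = fabsf(x[j]); maxv = x[j]; }
        }
        if (amax == 0.0f) {
            y->d = 0.0f;
            for (int j = 0; j < QK_K; ++j) y->qs[j] = 0;
            for (int j = 0; j < QK_K/16; ++j) y->bsums[j] = 0;  // cleared for a well-defined reference
            return;
        }
        const float iscale = -127.0f / maxv;  // sign folded into the scale
        for (int j = 0; j < QK_K; ++j) {
            y->qs[j] = (int8_t) nearbyintf(iscale * x[j]);  // round to nearest, like f32x4_nearest
        }
        for (int j = 0; j < QK_K/16; ++j) {
            int sum = 0;
            for (int l = 0; l < 16; ++l) sum += y->qs[16*j + l];
            y->bsums[j] = (int16_t) sum;
        }
        y->d = 1.0f / iscale;
    }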
@@ -2023,6 +2112,94 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
2023
2112
  }
2024
2113
 
2025
2114
  sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
2115
+ #elif defined __wasm_simd128__
2116
+ v128_t sumv = wasm_f32x4_splat(0.0f);
2117
+
2118
+ const v128_t m4b = wasm_i8x16_splat(0x0F);
2119
+ const v128_t s8b = wasm_i8x16_splat(0x8);
2120
+
2121
+ for (; ib + 1 < nb; ib += 2) {
2122
+ const block_q4_0 * restrict x0 = &x[ib];
2123
+ const block_q4_0 * restrict x1 = &x[ib + 1];
2124
+ const block_q8_0 * restrict y0 = &y[ib];
2125
+ const block_q8_0 * restrict y1 = &y[ib + 1];
2126
+
2127
+ // Load and process x0
2128
+ v128_t v0_0 = wasm_v128_load(x0->qs);
2129
+ v128_t v0_0l = wasm_v128_and(v0_0, m4b);
2130
+ v128_t v0_0h = wasm_u8x16_shr(v0_0, 4);
2131
+ v128_t v0_0ls = wasm_i8x16_sub(v0_0l, s8b);
2132
+ v128_t v0_0hs = wasm_i8x16_sub(v0_0h, s8b);
2133
+
2134
+ // Load y0 vectors
2135
+ v128_t y0_l = wasm_v128_load(y0->qs);
2136
+ v128_t y0_h = wasm_v128_load(y0->qs + 16);
2137
+
2138
+ // Extend to i16x8 and compute dot products
2139
+ v128_t dx0l = wasm_i16x8_extend_low_i8x16(v0_0ls);
2140
+ v128_t dx0h = wasm_i16x8_extend_high_i8x16(v0_0ls);
2141
+ v128_t dx0hl = wasm_i16x8_extend_low_i8x16(v0_0hs);
2142
+ v128_t dx0hh = wasm_i16x8_extend_high_i8x16(v0_0hs);
2143
+
2144
+ v128_t dy0ll = wasm_i16x8_extend_low_i8x16(y0_l);
2145
+ v128_t dy0lh = wasm_i16x8_extend_high_i8x16(y0_l);
2146
+ v128_t dy0hl = wasm_i16x8_extend_low_i8x16(y0_h);
2147
+ v128_t dy0hh = wasm_i16x8_extend_high_i8x16(y0_h);
2148
+
2149
+ v128_t dp0 = wasm_i32x4_add(
2150
+ wasm_i32x4_add(
2151
+ wasm_i32x4_dot_i16x8(dx0l, dy0ll),
2152
+ wasm_i32x4_dot_i16x8(dx0h, dy0lh)
2153
+ ),
2154
+ wasm_i32x4_add(
2155
+ wasm_i32x4_dot_i16x8(dx0hl, dy0hl),
2156
+ wasm_i32x4_dot_i16x8(dx0hh, dy0hh)
2157
+ )
2158
+ );
2159
+
2160
+ // Load and process x1
2161
+ v128_t v0_1 = wasm_v128_load(x1->qs);
2162
+ v128_t v0_1l = wasm_v128_and(v0_1, m4b);
2163
+ v128_t v0_1h = wasm_u8x16_shr(v0_1, 4);
2164
+ v128_t v0_1ls = wasm_i8x16_sub(v0_1l, s8b);
2165
+ v128_t v0_1hs = wasm_i8x16_sub(v0_1h, s8b);
2166
+
2167
+ // Load y1 vectors
2168
+ v128_t y1_l = wasm_v128_load(y1->qs);
2169
+ v128_t y1_h = wasm_v128_load(y1->qs + 16);
2170
+
2171
+ // Extend to i16x8 and compute dot products
2172
+ v128_t dx1l = wasm_i16x8_extend_low_i8x16(v0_1ls);
2173
+ v128_t dx1h = wasm_i16x8_extend_high_i8x16(v0_1ls);
2174
+ v128_t dx1hl = wasm_i16x8_extend_low_i8x16(v0_1hs);
2175
+ v128_t dx1hh = wasm_i16x8_extend_high_i8x16(v0_1hs);
2176
+
2177
+ v128_t dy1ll = wasm_i16x8_extend_low_i8x16(y1_l);
2178
+ v128_t dy1lh = wasm_i16x8_extend_high_i8x16(y1_l);
2179
+ v128_t dy1hl = wasm_i16x8_extend_low_i8x16(y1_h);
2180
+ v128_t dy1hh = wasm_i16x8_extend_high_i8x16(y1_h);
2181
+
2182
+ v128_t dp1 = wasm_i32x4_add(
2183
+ wasm_i32x4_add(
2184
+ wasm_i32x4_dot_i16x8(dx1l, dy1ll),
2185
+ wasm_i32x4_dot_i16x8(dx1h, dy1lh)
2186
+ ),
2187
+ wasm_i32x4_add(
2188
+ wasm_i32x4_dot_i16x8(dx1hl, dy1hl),
2189
+ wasm_i32x4_dot_i16x8(dx1hh, dy1hh)
2190
+ )
2191
+ );
2192
+
2193
+ // Accumulate results with scaling
2194
+ float scale0 = GGML_FP16_TO_FP32(x0->d) * GGML_FP16_TO_FP32(y0->d);
2195
+ float scale1 = GGML_FP16_TO_FP32(x1->d) * GGML_FP16_TO_FP32(y1->d);
2196
+
2197
+ sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4(dp0), wasm_f32x4_splat(scale0)));
2198
+ sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4(dp1), wasm_f32x4_splat(scale1)));
2199
+ }
2200
+
2201
+ sumf = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) +
2202
+ wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3);
2026
2203
  #elif defined(__AVX2__)
2027
2204
  // Initialize accumulator with zeros
2028
2205
  __m256 acc = _mm256_setzero_ps();
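Note: the added __wasm_simd128__ branch above computes the Q4_0 x Q8_0 dot product two blocks at a time: unpack the 4-bit nibbles, subtract the offset 8, widen to 16 bits, take integer dot products against the Q8_0 values, and scale each block's integer sum by the product of the two FP16 block scales. A scalar reference of the per-block math; types and names here are illustrative, not the package's headers:

    #include <stdint.h>

    #define QK8_0 32

    // One Q4_0 x Q8_0 block: xqs holds 16 bytes = 32 packed 4-bit values,
    // dx/dy are the block scales already converted from FP16 to float.
    static float dot_q4_0_q8_0_block_ref(const uint8_t xqs[QK8_0/2],
                                         const int8_t  yqs[QK8_0],
                                         float dx, float dy) {
        int32_t sumi = 0;
        for (int j = 0; j < QK8_0/2; ++j) {
            const int x0 = (xqs[j] & 0x0F) - 8;  // low nibble, offset removed
            const int x1 = (xqs[j] >>   4) - 8;  // high nibble, offset removed
            sumi += x0 * yqs[j] + x1 * yqs[j + QK8_0/2];
        }
        return dx * dy * (float) sumi;
    }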
@@ -2709,10 +2886,10 @@ void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * r
  }

  sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
- #elif defined(__wasm_simd128__)
+ #elif defined __wasm_simd128__
  v128_t sumv = wasm_f32x4_splat(0.0f);

- uint32_t qh;
+ uint32_t qh_;
  uint64_t tmp[4];

  // TODO: check if unrolling this is better
@@ -2723,12 +2900,12 @@ void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * r
  const v128_t m4b = wasm_i8x16_splat(0x0F);

  // extract the 5th bit
- memcpy(&qh, x0->qh, sizeof(qh));
+ memcpy(&qh_, x0->qh, sizeof(qh_));

- tmp[0] = table_b2b_1[(qh >> 0) & 0xFF];
- tmp[1] = table_b2b_1[(qh >> 8) & 0xFF];
- tmp[2] = table_b2b_1[(qh >> 16) & 0xFF];
- tmp[3] = table_b2b_1[(qh >> 24) ];
+ tmp[0] = table_b2b_1[(qh_ >> 0) & 0xFF];
+ tmp[1] = table_b2b_1[(qh_ >> 8) & 0xFF];
+ tmp[2] = table_b2b_1[(qh_ >> 16) & 0xFF];
+ tmp[3] = table_b2b_1[(qh_ >> 24) ];

  const v128_t qhl = wasm_v128_load(tmp + 0);
  const v128_t qhh = wasm_v128_load(tmp + 2);
@@ -3070,12 +3247,12 @@ void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * r
  }

  sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs0 + summs1;
- #elif defined(__wasm_simd128__)
+ #elif defined __wasm_simd128__
  v128_t sumv = wasm_f32x4_splat(0.0f);

  float summs = 0.0f;

- uint32_t qh;
+ uint32_t qh_;
  uint64_t tmp[4];

  // TODO: check if unrolling this is better
@@ -3088,12 +3265,12 @@ void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * r
  const v128_t m4b = wasm_i8x16_splat(0x0F);

  // extract the 5th bit
- memcpy(&qh, x0->qh, sizeof(qh));
+ memcpy(&qh_, x0->qh, sizeof(qh_));

- tmp[0] = table_b2b_0[(qh >> 0) & 0xFF];
- tmp[1] = table_b2b_0[(qh >> 8) & 0xFF];
- tmp[2] = table_b2b_0[(qh >> 16) & 0xFF];
- tmp[3] = table_b2b_0[(qh >> 24) ];
+ tmp[0] = table_b2b_0[(qh_ >> 0) & 0xFF];
+ tmp[1] = table_b2b_0[(qh_ >> 8) & 0xFF];
+ tmp[2] = table_b2b_0[(qh_ >> 16) & 0xFF];
+ tmp[3] = table_b2b_0[(qh_ >> 24) ];

  const v128_t qhl = wasm_v128_load(tmp + 0);
  const v128_t qhh = wasm_v128_load(tmp + 2);
@@ -3586,6 +3763,45 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r
3586
3763
  }
3587
3764
 
3588
3765
  sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
3766
+ #elif defined __wasm_simd128__
3767
+ v128_t sumv = wasm_f32x4_splat(0.0f);
3768
+
3769
+ for (; ib < nb; ++ib) {
3770
+ const block_q8_0 * restrict x0 = &x[ib];
3771
+ const block_q8_0 * restrict y0 = &y[ib];
3772
+
3773
+ const v128_t x0_0 = wasm_v128_load(x0->qs);
3774
+ const v128_t x0_1 = wasm_v128_load(x0->qs + 16);
3775
+ const v128_t y0_0 = wasm_v128_load(y0->qs);
3776
+ const v128_t y0_1 = wasm_v128_load(y0->qs + 16);
3777
+
3778
+ // Extend 8-bit to 16-bit
3779
+ const v128_t x0_0l = wasm_i16x8_extend_low_i8x16(x0_0);
3780
+ const v128_t x0_0h = wasm_i16x8_extend_high_i8x16(x0_0);
3781
+ const v128_t x0_1l = wasm_i16x8_extend_low_i8x16(x0_1);
3782
+ const v128_t x0_1h = wasm_i16x8_extend_high_i8x16(x0_1);
3783
+
3784
+ const v128_t y0_0l = wasm_i16x8_extend_low_i8x16(y0_0);
3785
+ const v128_t y0_0h = wasm_i16x8_extend_high_i8x16(y0_0);
3786
+ const v128_t y0_1l = wasm_i16x8_extend_low_i8x16(y0_1);
3787
+ const v128_t y0_1h = wasm_i16x8_extend_high_i8x16(y0_1);
3788
+
3789
+ // Compute dot products
3790
+ const v128_t dx0_0 = wasm_i32x4_dot_i16x8(x0_0l, y0_0l);
3791
+ const v128_t dx0_1 = wasm_i32x4_dot_i16x8(x0_0h, y0_0h);
3792
+ const v128_t dx1_0 = wasm_i32x4_dot_i16x8(x0_1l, y0_1l);
3793
+ const v128_t dx1_1 = wasm_i32x4_dot_i16x8(x0_1h, y0_1h);
3794
+
3795
+ // Sum all dot products
3796
+ const v128_t sum_dots = wasm_i32x4_add(wasm_i32x4_add(dx0_0, dx0_1), wasm_i32x4_add(dx1_0, dx1_1));
3797
+
3798
+ // Convert to float and accumulate
3799
+ const float scale = GGML_FP16_TO_FP32(x0->d) * GGML_FP16_TO_FP32(y0->d);
3800
+ sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4(sum_dots), wasm_f32x4_splat(scale)));
3801
+ }
3802
+
3803
+ sumf = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) +
3804
+ wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3);
3589
3805
  #elif defined(__AVX2__)
3590
3806
  // Initialize accumulator with zeros
3591
3807
  __m256 acc = _mm256_setzero_ps();
@@ -4460,6 +4676,106 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * r
4460
4676
 
4461
4677
  *s = hsum_float_8(acc);
4462
4678
 
4679
+ #elif defined __wasm_simd128__
4680
+ float sumf = 0;
4681
+
4682
+ for (int i = 0; i < nb; ++i) {
4683
+ const uint8_t * q2 = x[i].qs;
4684
+ const int8_t * q8 = y[i].qs;
4685
+ const uint8_t * sc = x[i].scales;
4686
+
4687
+ // Vectorized summs calculation
4688
+ v128_t summs_vec = wasm_i32x4_splat(0);
4689
+ {
4690
+ v128_t sc_vec = wasm_v128_load(sc);
4691
+ v128_t sc_upper = wasm_u8x16_shr(sc_vec, 4);
4692
+
4693
+ v128_t sc_low = wasm_u16x8_extend_low_u8x16(sc_upper);
4694
+ v128_t sc_high = wasm_u16x8_extend_high_u8x16(sc_upper);
4695
+
4696
+ v128_t bsums1 = wasm_v128_load(&y[i].bsums[0]);
4697
+ v128_t bsums2 = wasm_v128_load(&y[i].bsums[8]);
4698
+
4699
+ summs_vec = wasm_i32x4_add(
4700
+ wasm_i32x4_add(wasm_i32x4_dot_i16x8(sc_low, bsums1),
4701
+ wasm_i32x4_dot_i16x8(sc_high, bsums2)),
4702
+ summs_vec
4703
+ );
4704
+
4705
+ summs_vec = wasm_i32x4_add(summs_vec, wasm_i32x4_shuffle(summs_vec, summs_vec, 2, 3, 0, 1));
4706
+ summs_vec = wasm_i32x4_add(summs_vec, wasm_i32x4_shuffle(summs_vec, summs_vec, 1, 0, 3, 2));
4707
+ }
4708
+ int32_t summs = wasm_i32x4_extract_lane(summs_vec, 0);
4709
+
4710
+ // Vectorized isum calculation
4711
+ int32_t isum = 0;
4712
+ const uint8_t * sc_ptr = sc;
4713
+ const int k_iters = QK_K/128;
4714
+
4715
+ for (int k = 0; k < k_iters; ++k) {
4716
+ v128_t isum_vec = wasm_i32x4_splat(0);
4717
+ int shift = 0;
4718
+
4719
+ for (int j = 0; j < 4; ++j) {
4720
+ const int d0 = (sc_ptr[0] & 0xF);
4721
+ const int d1 = (sc_ptr[1] & 0xF);
4722
+ sc_ptr += 2;
4723
+
4724
+ // Process first 16 elements
4725
+ v128_t q2_0 = wasm_v128_load(q2);
4726
+ v128_t q8_0 = wasm_v128_load(q8);
4727
+ v128_t q2_shift_0 = wasm_u8x16_shr(q2_0, shift);
4728
+ v128_t q2_bits_0 = wasm_v128_and(q2_shift_0, wasm_i8x16_splat(0x03));
4729
+
4730
+ // Process next 16 elements
4731
+ v128_t q2_1 = wasm_v128_load(q2 + 16);
4732
+ v128_t q8_1 = wasm_v128_load(q8 + 16);
4733
+ v128_t q2_shift_1 = wasm_u8x16_shr(q2_1, shift);
4734
+ v128_t q2_bits_1 = wasm_v128_and(q2_shift_1, wasm_i8x16_splat(0x03));
4735
+
4736
+ // Calculate dot products
4737
+ v128_t p0 = wasm_i32x4_dot_i16x8(
4738
+ wasm_i16x8_extend_low_i8x16(q8_0),
4739
+ wasm_i16x8_extend_low_i8x16(q2_bits_0)
4740
+ );
4741
+ v128_t p1 = wasm_i32x4_dot_i16x8(
4742
+ wasm_i16x8_extend_high_i8x16(q8_0),
4743
+ wasm_i16x8_extend_high_i8x16(q2_bits_0)
4744
+ );
4745
+ v128_t p2 = wasm_i32x4_dot_i16x8(
4746
+ wasm_i16x8_extend_low_i8x16(q8_1),
4747
+ wasm_i16x8_extend_low_i8x16(q2_bits_1)
4748
+ );
4749
+ v128_t p3 = wasm_i32x4_dot_i16x8(
4750
+ wasm_i16x8_extend_high_i8x16(q8_1),
4751
+ wasm_i16x8_extend_high_i8x16(q2_bits_1)
4752
+ );
4753
+
4754
+ // Accumulate scaled results
4755
+ v128_t scaled = wasm_i32x4_add(
4756
+ wasm_i32x4_mul(wasm_i32x4_add(p0, p1), wasm_i32x4_splat(d0)),
4757
+ wasm_i32x4_mul(wasm_i32x4_add(p2, p3), wasm_i32x4_splat(d1))
4758
+ );
4759
+
4760
+ isum_vec = wasm_i32x4_add(isum_vec, scaled);
4761
+ q8 += 32;
4762
+ shift += 2;
4763
+ }
4764
+ q2 += 32;
4765
+
4766
+ // Horizontal sum of isum_vec
4767
+ isum_vec = wasm_i32x4_add(isum_vec, wasm_i32x4_shuffle(isum_vec, isum_vec, 2, 3, 0, 1));
4768
+ isum_vec = wasm_i32x4_add(isum_vec, wasm_i32x4_shuffle(isum_vec, isum_vec, 1, 0, 3, 2));
4769
+ isum += wasm_i32x4_extract_lane(isum_vec, 0);
4770
+ }
4771
+
4772
+ const float dall = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
4773
+ const float dmin = GGML_FP16_TO_FP32(x[i].dmin) * y[i].d;
4774
+ sumf += dall * isum - dmin * summs;
4775
+ }
4776
+
4777
+ *s = sumf;
4778
+
4463
4779
  #elif defined __riscv_v_intrinsic
4464
4780
 
4465
4781
  float sumf = 0;
@@ -4679,9 +4995,6 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * r

  #elif defined __loongarch_asx

- const __m256i m3 = __lasx_xvreplgr2vr_b(3);
- const __m128i m4 = __lsx_vreplgr2vr_b(0xF);
-
  __m256 acc = (__m256)__lasx_xvldi(0);

  for (int i = 0; i < nb; ++i) {
@@ -4692,18 +5005,15 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * r
  const uint8_t * restrict q2 = x[i].qs;
  const int8_t * restrict q8 = y[i].qs;

- const __m128i mins_and_scales = __lsx_vld((const __m128i*)x[i].scales, 0);
- const __m128i scales8 = __lsx_vand_v(mins_and_scales, m4);
- const __m128i mins8 = __lsx_vand_v(__lsx_vsrli_h(mins_and_scales, 4), m4);
- const __m256i mins = lasx_ext8_16(mins8);
+ const __m128i mins_and_scales128 = __lsx_vld((const __m128i*)x[i].scales, 0);
+ const __m128i scales128 = __lsx_vandi_b(mins_and_scales128, 0xf);
+ const __m256i mins = lasx_ext8_16(__lsx_vsrli_b(mins_and_scales128, 4));
  const __m256i prod = lasx_madd_h(mins, __lasx_xvld((const __m256i*)y[i].bsums, 0));

  acc = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(dmin), __lasx_xvffint_s_w(prod), acc);

- const __m256i all_scales = lasx_ext8_16(scales8);
- const __m128i l_scales = lasx_extracti128(all_scales, 0);
- const __m128i h_scales = lasx_extracti128(all_scales, 1);
- const __m256i scales[2] = {lasx_insertf128(l_scales, l_scales), lasx_insertf128(h_scales, h_scales)};
+ const v16i8 shuffle_mask = {0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15};
+ const __m256i scales_shuffled = lasx_ext8_16(__lsx_vshuf_b(scales128, scales128, (__m128i)shuffle_mask));

  __m256i sumi = __lasx_xvldi(0);

@@ -4716,20 +5026,20 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * r
  const __m256i q8_2 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
  const __m256i q8_3 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;

- const __m256i q2_0 = __lasx_xvand_v(q2bits, m3);
- const __m256i q2_1 = __lasx_xvand_v(__lasx_xvsrli_h(q2bits, 2), m3);
- const __m256i q2_2 = __lasx_xvand_v(__lasx_xvsrli_h(q2bits, 4), m3);
- const __m256i q2_3 = __lasx_xvand_v(__lasx_xvsrli_h(q2bits, 6), m3);
+ const __m256i q2_0 = __lasx_xvandi_b(q2bits, 3);
+ const __m256i q2_1 = __lasx_xvandi_b(__lasx_xvsrli_b(q2bits, 2), 3);
+ const __m256i q2_2 = __lasx_xvandi_b(__lasx_xvsrli_b(q2bits, 4), 3);
+ const __m256i q2_3 = __lasx_xvsrli_b(q2bits, 6);

- __m256i p0 = lasx_maddubs_h(q2_0, q8_0);
- __m256i p1 = lasx_maddubs_h(q2_1, q8_1);
- __m256i p2 = lasx_maddubs_h(q2_2, q8_2);
- __m256i p3 = lasx_maddubs_h(q2_3, q8_3);
+ __m256i p0 = lasx_madd_h_b(q2_0, q8_0);
+ __m256i p1 = lasx_madd_h_b(q2_1, q8_1);
+ __m256i p2 = lasx_madd_h_b(q2_2, q8_2);
+ __m256i p3 = lasx_madd_h_b(q2_3, q8_3);

- p0 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(0)), p0);
- p1 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(1)), p1);
- p2 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(2)), p2);
- p3 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(3)), p3);
+ p0 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 0), p0);
+ p1 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 1), p1);
+ p2 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 2), p2);
+ p3 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 3), p3);

  p0 = __lasx_xvadd_w(p0, p1);
  p2 = __lasx_xvadd_w(p2, p3);
@@ -5142,6 +5452,94 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * r
5142
5452
 
5143
5453
  *s = hsum_float_8(acc);
5144
5454
 
5455
+ #elif defined __wasm_simd128__
5456
+ int8_t aux8[QK_K];
5457
+ float sums[8] = {0};
5458
+ uint32_t auxs[4];
5459
+
5460
+ float sumf = 0;
5461
+ for (int i = 0; i < nb; ++i) {
5462
+ const uint8_t * restrict q3 = x[i].qs;
5463
+ const uint8_t * restrict hm = x[i].hmask;
5464
+ const int8_t * restrict q8 = y[i].qs;
5465
+
5466
+ // Process blocks with SIMD
5467
+ int8_t * a = aux8;
5468
+ uint8_t m = 1;
5469
+ for (int j = 0; j < QK_K; j += 128) {
5470
+ for (int shift = 0; shift <= 6; shift += 2) {
5471
+ v128_t v_m = wasm_i8x16_splat(m);
5472
+ for (int l = 0; l < 32; l += 16) {
5473
+ v128_t v_q3 = wasm_v128_load(q3 + l);
5474
+ v128_t v_shift = wasm_i8x16_shr(v_q3, shift);
5475
+ v128_t v_low2 = wasm_v128_and(v_shift, wasm_i8x16_splat(0x03));
5476
+
5477
+ v128_t v_hm = wasm_v128_load(hm + l);
5478
+ v128_t v_mask = wasm_v128_and(v_hm, v_m);
5479
+ v_mask = wasm_i8x16_ne(v_mask, wasm_i8x16_splat(0));
5480
+
5481
+ v_low2 = wasm_i8x16_sub(v_low2, wasm_v128_and(wasm_i8x16_splat(4), wasm_v128_not(v_mask)));
5482
+ wasm_v128_store(a + l, v_low2);
5483
+ }
5484
+ a += 32;
5485
+ m <<= 1;
5486
+ }
5487
+ q3 += 32;
5488
+ }
5489
+
5490
+ // Extract scales
5491
+ memcpy(auxs, x[i].scales, 12);
5492
+ uint32_t tmp = auxs[2];
5493
+ auxs[2] = ((auxs[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);
5494
+ auxs[3] = ((auxs[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4);
5495
+ auxs[0] = (auxs[0] & kmask2) | (((tmp >> 0) & kmask1) << 4);
5496
+ auxs[1] = (auxs[1] & kmask2) | (((tmp >> 2) & kmask1) << 4);
5497
+ const int8_t * scales = (const int8_t *)auxs;
5498
+
5499
+ // SIMD dot product with register accumulators
5500
+ v128_t v_acc0 = wasm_i32x4_splat(0);
5501
+ v128_t v_acc1 = wasm_i32x4_splat(0);
5502
+ a = aux8;
5503
+ for (int j = 0; j < QK_K/16; ++j) {
5504
+ const v128_t v_scale = wasm_i16x8_splat(scales[j] - 32);
5505
+
5506
+ // Process 16 elements per iteration
5507
+ for (int k = 0; k < 2; ++k) {
5508
+ const v128_t v_q8 = wasm_i16x8_load8x8(q8);
5509
+ const v128_t v_a = wasm_i16x8_load8x8(a);
5510
+
5511
+ v128_t v_prod = wasm_i16x8_mul(v_q8, v_a);
5512
+ v_prod = wasm_i16x8_mul(v_prod, v_scale);
5513
+
5514
+ v_acc0 = wasm_i32x4_add(v_acc0, wasm_i32x4_extend_low_i16x8(v_prod));
5515
+ v_acc1 = wasm_i32x4_add(v_acc1, wasm_i32x4_extend_high_i16x8(v_prod));
5516
+
5517
+ q8 += 8;
5518
+ a += 8;
5519
+ }
5520
+ }
5521
+
5522
+ // Accumulate results
5523
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
5524
+ const v128_t v_d = wasm_f32x4_splat(d);
5525
+ v128_t v_sum = wasm_f32x4_add(
5526
+ wasm_f32x4_mul(wasm_f32x4_convert_i32x4(v_acc0), v_d),
5527
+ wasm_f32x4_mul(wasm_f32x4_convert_i32x4(v_acc1), v_d)
5528
+ );
5529
+
5530
+ // Accumulate into sums vector
5531
+ wasm_v128_store(sums, wasm_f32x4_add(wasm_v128_load(sums), v_sum));
5532
+ }
5533
+
5534
+ // Horizontal sum
5535
+ v128_t v_sum = wasm_f32x4_add(wasm_v128_load(sums), wasm_v128_load(sums + 4));
5536
+ sumf = wasm_f32x4_extract_lane(v_sum, 0) +
5537
+ wasm_f32x4_extract_lane(v_sum, 1) +
5538
+ wasm_f32x4_extract_lane(v_sum, 2) +
5539
+ wasm_f32x4_extract_lane(v_sum, 3);
5540
+
5541
+ *s = sumf;
5542
+
5145
5543
  #elif defined __riscv_v_intrinsic
5146
5544
 
5147
5545
  uint32_t aux[3];
@@ -5397,8 +5795,6 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * r
5397
5795
 
5398
5796
  #elif defined __loongarch_asx
5399
5797
 
5400
- const __m256i m3 = __lasx_xvreplgr2vr_b(3);
5401
- const __m256i mone = __lasx_xvreplgr2vr_b(1);
5402
5798
  const __m128i m32 = __lsx_vreplgr2vr_b(32);
5403
5799
 
5404
5800
  __m256 acc = (__m256)__lasx_xvldi(0);
@@ -5418,10 +5814,9 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * r
5418
5814
  (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4),
5419
5815
  (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4));
5420
5816
  scales128 = __lsx_vsub_b(scales128, m32);
5421
- const __m256i all_scales = lasx_ext8_16(scales128);
5422
- const __m128i l_scales = lasx_extracti128(all_scales, 0);
5423
- const __m128i h_scales = lasx_extracti128(all_scales, 1);
5424
- const __m256i scales[2] = {lasx_insertf128(l_scales, l_scales), lasx_insertf128(h_scales, h_scales)};
5817
+
5818
+ const v16i8 shuffle_mask = {0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15};
5819
+ const __m256i scales_shuffled = lasx_ext8_16(__lsx_vshuf_b(scales128, scales128, (__m128i)shuffle_mask));
5425
5820
 
5426
5821
  // high bit
5427
5822
  const __m256i hbits = __lasx_xvld((const __m256i*)x[i].hmask, 0);
@@ -5429,35 +5824,23 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * r
5429
5824
  // integer accumulator
5430
5825
  __m256i sumi = __lasx_xvldi(0);
5431
5826
 
5432
- int bit = 0;
5433
- int is = 0;
5434
- __m256i xvbit;
5435
-
5436
-
5437
5827
  for (int j = 0; j < QK_K/128; ++j) {
5438
5828
  // load low 2 bits
5439
5829
  const __m256i q3bits = __lasx_xvld((const __m256i*)q3, 0); q3 += 32;
5440
5830
 
5441
- xvbit = __lasx_xvreplgr2vr_h(bit);
5442
5831
  // prepare low and high bits
5443
- const __m256i q3l_0 = __lasx_xvand_v(q3bits, m3);
5444
- const __m256i q3h_0 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvandn_v(hbits, __lasx_xvsll_h(mone, xvbit)), xvbit), 2);
5445
- ++bit;
5446
-
5447
- xvbit = __lasx_xvreplgr2vr_h(bit);
5448
- const __m256i q3l_1 = __lasx_xvand_v(__lasx_xvsrli_h(q3bits, 2), m3);
5449
- const __m256i q3h_1 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvandn_v(hbits, __lasx_xvsll_h(mone, xvbit)), xvbit), 2);
5450
- ++bit;
5451
-
5452
- xvbit = __lasx_xvreplgr2vr_h(bit);
5453
- const __m256i q3l_2 = __lasx_xvand_v(__lasx_xvsrli_h(q3bits, 4), m3);
5454
- const __m256i q3h_2 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvandn_v(hbits, __lasx_xvsll_h(mone, xvbit)), xvbit), 2);
5455
- ++bit;
5456
-
5457
- xvbit = __lasx_xvreplgr2vr_h(bit);
5458
- const __m256i q3l_3 = __lasx_xvand_v(__lasx_xvsrli_h(q3bits, 6), m3);
5459
- const __m256i q3h_3 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvandn_v(hbits, __lasx_xvsll_h(mone, xvbit)), xvbit), 2);
5460
- ++bit;
5832
+ const __m256i q3l_0 = __lasx_xvandi_b(q3bits, 3);
5833
+ const __m256i q3l_1 = __lasx_xvandi_b(__lasx_xvsrli_b(q3bits, 2), 3);
5834
+ const __m256i q3l_2 = __lasx_xvandi_b(__lasx_xvsrli_b(q3bits, 4), 3);
5835
+ const __m256i q3l_3 = __lasx_xvsrli_b(q3bits, 6);
5836
+ const __m256i q3h_0 = __lasx_xvslli_b(__lasx_xvseqi_b(lasx_xvandi_b_bit(hbits, 4 * j + 0), 0), 2);
5837
+ const __m256i q3h_1 = __lasx_xvslli_b(__lasx_xvseqi_b(lasx_xvandi_b_bit(hbits, 4 * j + 1), 0), 2);
5838
+ const __m256i q3h_2 = __lasx_xvslli_b(__lasx_xvseqi_b(lasx_xvandi_b_bit(hbits, 4 * j + 2), 0), 2);
5839
+ const __m256i q3h_3 = __lasx_xvslli_b(__lasx_xvseqi_b(lasx_xvandi_b_bit(hbits, 4 * j + 3), 0), 2);
5840
+ const __m256i q3_0 = __lasx_xvor_v(q3h_0, q3l_0);
5841
+ const __m256i q3_1 = __lasx_xvor_v(q3h_1, q3l_1);
5842
+ const __m256i q3_2 = __lasx_xvor_v(q3h_2, q3l_2);
5843
+ const __m256i q3_3 = __lasx_xvor_v(q3h_3, q3l_3);
5461
5844
 
5462
5845
  // load Q8 quants
5463
5846
  const __m256i q8_0 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
@@ -5465,29 +5848,16 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * r
5465
5848
  const __m256i q8_2 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
5466
5849
  const __m256i q8_3 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
5467
5850
 
5468
- // Dot product: we multiply the 2 low bits and 1 high bit part separately, so we can use lasx_maddubs_h,
5469
- // and then subtract. The high bit part has the 2 already subtracted (and so, it is zero if the high bit was not set,
5470
- // and 2 if the high bit was set)
5471
- __m256i q8s_0 = lasx_maddubs_h(q3h_0, q8_0);
5472
- __m256i q8s_1 = lasx_maddubs_h(q3h_1, q8_1);
5473
- __m256i q8s_2 = lasx_maddubs_h(q3h_2, q8_2);
5474
- __m256i q8s_3 = lasx_maddubs_h(q3h_3, q8_3);
5475
-
5476
- __m256i p16_0 = lasx_maddubs_h(q3l_0, q8_0);
5477
- __m256i p16_1 = lasx_maddubs_h(q3l_1, q8_1);
5478
- __m256i p16_2 = lasx_maddubs_h(q3l_2, q8_2);
5479
- __m256i p16_3 = lasx_maddubs_h(q3l_3, q8_3);
5480
-
5481
- p16_0 = __lasx_xvsub_h(p16_0, q8s_0);
5482
- p16_1 = __lasx_xvsub_h(p16_1, q8s_1);
5483
- p16_2 = __lasx_xvsub_h(p16_2, q8s_2);
5484
- p16_3 = __lasx_xvsub_h(p16_3, q8s_3);
5851
+ __m256i p16_0 = lasx_madd_h_b(q8_0, q3_0);
5852
+ __m256i p16_1 = lasx_madd_h_b(q8_1, q3_1);
5853
+ __m256i p16_2 = lasx_madd_h_b(q8_2, q3_2);
5854
+ __m256i p16_3 = lasx_madd_h_b(q8_3, q3_3);
5485
5855
 
5486
5856
  // multiply with scales
5487
- p16_0 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(is + 0)), p16_0);
5488
- p16_1 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(is + 1)), p16_1);
5489
- p16_2 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(is + 2)), p16_2);
5490
- p16_3 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(is + 3)), p16_3);
5857
+ p16_0 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 0), p16_0);
5858
+ p16_1 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 1), p16_1);
5859
+ p16_2 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 2), p16_2);
5860
+ p16_3 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 3), p16_3);
5491
5861
 
5492
5862
  // accumulate
5493
5863
  p16_0 = __lasx_xvadd_w(p16_0, p16_1);
@@ -5495,7 +5865,7 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * r
  sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p16_0, p16_2));
  }
  // multiply with block scale and accumulate
- acc = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(sumi), acc);//FIXME
+ acc = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(sumi), acc);
  }

  *s = hsum_float_8(acc);
@@ -5667,7 +6037,7 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * r
  }
  }
  *s = sumf;
- #elif __ARM_NEON
+ #elif defined __ARM_NEON
  const uint8x16_t m4b = vdupq_n_u8(0xf);
  const int32x4_t mzero = vdupq_n_s32(0);

@@ -5730,6 +6100,107 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * r
5730
6100
 
5731
6101
  *s = sumf;
5732
6102
 
6103
+ #elif defined __wasm_simd128__
6104
+ const uint8_t * scales = (const uint8_t*)&utmp[0];
6105
+ float sumf = 0;
6106
+
6107
+ for (int i = 0; i < nb; ++i) {
6108
+ const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
6109
+ const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin); // Corrected sign
6110
+
6111
+ const uint8_t * restrict q4 = x[i].qs;
6112
+ const int8_t * restrict q8 = y[i].qs;
6113
+
6114
+ // Process scales and mins
6115
+ memcpy(utmp, x[i].scales, 12);
6116
+ utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
6117
+ const uint32_t uaux = utmp[1] & kmask1;
6118
+ utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
6119
+ utmp[2] = uaux;
6120
+ utmp[0] &= kmask1;
6121
+
6122
+ // Sum mins * q8sums
6123
+ int32_t sumi = 0;
6124
+ const int16_t * restrict q8sums = y[i].bsums;
6125
+ const uint8_t * m = (const uint8_t *)&utmp[2];
6126
+ for (int j = 0; j < 16; j += 2) {
6127
+ sumi += (q8sums[j] + q8sums[j+1]) * m[j/2];
6128
+ }
6129
+ sumf -= dmin * sumi;
6130
+
6131
+ int32_t sumi1 = 0;
6132
+ int32_t sumi2 = 0;
6133
+
6134
+ for (int j = 0; j < QK_K/64; ++j) {
6135
+ // Load 64 4-bit weights (32 bytes)
6136
+ const v128_t q4x0 = wasm_v128_load(q4);
6137
+ const v128_t q4x1 = wasm_v128_load(q4 + 16);
6138
+ q4 += 32;
6139
+
6140
+ // Split into low/high nibbles
6141
+ const v128_t q4l0 = wasm_v128_and(q4x0, wasm_i8x16_splat(0x0F));
6142
+ const v128_t q4h0 = wasm_u8x16_shr(q4x0, 4);
6143
+ const v128_t q4l1 = wasm_v128_and(q4x1, wasm_i8x16_splat(0x0F));
6144
+ const v128_t q4h1 = wasm_u8x16_shr(q4x1, 4);
6145
+
6146
+ // Load 64 8-bit values (64 bytes)
6147
+ const v128_t q8x0 = wasm_v128_load(q8);
6148
+ const v128_t q8x1 = wasm_v128_load(q8 + 16);
6149
+ const v128_t q8x2 = wasm_v128_load(q8 + 32);
6150
+ const v128_t q8x3 = wasm_v128_load(q8 + 48);
6151
+ q8 += 64;
6152
+
6153
+ // Low nibble products
6154
+ v128_t vacc1 = wasm_i32x4_dot_i16x8(
6155
+ wasm_i16x8_extend_low_i8x16(q4l0),
6156
+ wasm_i16x8_extend_low_i8x16(q8x0)
6157
+ );
6158
+ vacc1 = wasm_i32x4_add(vacc1, wasm_i32x4_dot_i16x8(
6159
+ wasm_i16x8_extend_high_i8x16(q4l0),
6160
+ wasm_i16x8_extend_high_i8x16(q8x0)
6161
+ ));
6162
+ vacc1 = wasm_i32x4_add(vacc1, wasm_i32x4_dot_i16x8(
6163
+ wasm_i16x8_extend_low_i8x16(q4l1),
6164
+ wasm_i16x8_extend_low_i8x16(q8x1)
6165
+ ));
6166
+ vacc1 = wasm_i32x4_add(vacc1, wasm_i32x4_dot_i16x8(
6167
+ wasm_i16x8_extend_high_i8x16(q4l1),
6168
+ wasm_i16x8_extend_high_i8x16(q8x1)
6169
+ ));
6170
+
6171
+ // High nibble products
6172
+ v128_t vacc2 = wasm_i32x4_dot_i16x8(
6173
+ wasm_i16x8_extend_low_i8x16(q4h0),
6174
+ wasm_i16x8_extend_low_i8x16(q8x2)
6175
+ );
6176
+ vacc2 = wasm_i32x4_add(vacc2, wasm_i32x4_dot_i16x8(
6177
+ wasm_i16x8_extend_high_i8x16(q4h0),
6178
+ wasm_i16x8_extend_high_i8x16(q8x2)
6179
+ ));
6180
+ vacc2 = wasm_i32x4_add(vacc2, wasm_i32x4_dot_i16x8(
6181
+ wasm_i16x8_extend_low_i8x16(q4h1),
6182
+ wasm_i16x8_extend_low_i8x16(q8x3)
6183
+ ));
6184
+ vacc2 = wasm_i32x4_add(vacc2, wasm_i32x4_dot_i16x8(
6185
+ wasm_i16x8_extend_high_i8x16(q4h1),
6186
+ wasm_i16x8_extend_high_i8x16(q8x3)
6187
+ ));
6188
+
6189
+ // Accumulate scaled results
6190
+ int32_t vacc1_sum = wasm_i32x4_extract_lane(vacc1, 0) + wasm_i32x4_extract_lane(vacc1, 1) +
6191
+ wasm_i32x4_extract_lane(vacc1, 2) + wasm_i32x4_extract_lane(vacc1, 3);
6192
+ sumi1 += vacc1_sum * scales[2*j];
6193
+
6194
+ int32_t vacc2_sum = wasm_i32x4_extract_lane(vacc2, 0) + wasm_i32x4_extract_lane(vacc2, 1) +
6195
+ wasm_i32x4_extract_lane(vacc2, 2) + wasm_i32x4_extract_lane(vacc2, 3);
6196
+ sumi2 += vacc2_sum * scales[2*j+1];
6197
+ }
6198
+
6199
+ sumf += d * (sumi1 + sumi2);
6200
+ }
6201
+
6202
+ *s = sumf;
6203
+
5733
6204
  #elif defined __AVX2__
5734
6205
 
5735
6206
  const __m256i m4 = _mm256_set1_epi8(0xF);
@@ -6087,11 +6558,6 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * r
  *s = vec_extract(vsumf0, 0);

  #elif defined __loongarch_asx
- GGML_UNUSED(kmask1);
- GGML_UNUSED(kmask2);
- GGML_UNUSED(kmask3);
-
- const __m256i m4 = __lasx_xvreplgr2vr_b(0xF);

  __m256 acc = (__m256)__lasx_xvldi(0);
  __m128 acc_m = (__m128)__lsx_vldi(0);
@@ -6111,33 +6577,34 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * r
6111
6577
  const uint8_t * restrict q4 = x[i].qs;
6112
6578
  const int8_t * restrict q8 = y[i].qs;
6113
6579
 
6114
- const __m256i mins_and_scales = lasx_extu8_16(lsx_set_w(utmp[3], utmp[2], utmp[1], utmp[0]));
6580
+ const __m128i mins_and_scales128 = lsx_set_w(utmp[3], utmp[2], utmp[1], utmp[0]);
6581
+ const __m128i mins128 = __lsx_vexth_h_b(mins_and_scales128);
6582
+ const __m128i scales128 = __lsx_vsllwil_h_b(mins_and_scales128, 0);
6115
6583
 
6116
6584
  const __m256i q8sums = __lasx_xvld((const __m256i*)y[i].bsums, 0);
6117
6585
  const __m128i q8s = lsx_hadd_h(lasx_extracti128(q8sums, 0), lasx_extracti128(q8sums, 1));
6118
- const __m128i prod = lsx_madd_h(lasx_extracti128(mins_and_scales, 1), q8s);
6586
+ const __m128i prod = lsx_madd_h(mins128, q8s);
6119
6587
  acc_m = __lsx_vfmadd_s(__lsx_vreplfr2vr_s(dmin), __lsx_vffint_s_w(prod), acc_m);
6120
6588
 
6121
- const __m128i sc128 = lasx_extracti128(mins_and_scales, 0);
6122
- const __m256i scales = lasx_insertf128(sc128, sc128);
6589
+ const __m256i scales = lasx_insertf128(scales128, scales128);
6123
6590
 
6124
6591
  __m256i sumi = __lasx_xvldi(0);
6125
6592
 
6126
6593
  for (int j = 0; j < QK_K/64; ++j) {
6127
6594
 
6128
- const __m256i scale_l = lasx_shuffle_b(scales, get_scale_shuffle_k4(2*j+0));
6129
- const __m256i scale_h = lasx_shuffle_b(scales, get_scale_shuffle_k4(2*j+1));
6595
+ const __m256i scale_l = lasx_xvrepl128vei_h(scales, 2 * j + 0);
6596
+ const __m256i scale_h = lasx_xvrepl128vei_h(scales, 2 * j + 1);
6130
6597
 
6131
6598
  const __m256i q4bits = __lasx_xvld((const __m256i*)q4, 0); q4 += 32;
6132
- const __m256i q4l = __lasx_xvand_v(q4bits, m4);
6133
- const __m256i q4h = __lasx_xvand_v(__lasx_xvsrli_h(q4bits, 4), m4);
6599
+ const __m256i q4l = __lasx_xvandi_b(q4bits, 0xf);
6600
+ const __m256i q4h = __lasx_xvsrli_b(q4bits, 4);
6134
6601
 
6135
6602
  const __m256i q8l = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
6136
- __m256i p16l = lasx_maddubs_h(q4l, q8l);
6603
+ __m256i p16l = lasx_madd_h_b(q4l, q8l);
6137
6604
  p16l = lasx_madd_h(scale_l, p16l);
6138
6605
 
6139
6606
  const __m256i q8h = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
6140
- __m256i p16h = lasx_maddubs_h(q4h, q8h);
6607
+ __m256i p16h = lasx_madd_h_b(q4h, q8h);
6141
6608
  p16h = lasx_madd_h(scale_h, p16h);
6142
6609
  const __m256i sumj = __lasx_xvadd_w(p16l, p16h);
6143
6610
 
@@ -6154,9 +6621,7 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * r
  acc_m = __lsx_vfadd_s(acc_m, (__m128)tmp1);


- ft_union fi;
- fi.i = __lsx_vpickve2gr_w(acc_m, 0);
- *s = hsum_float_8(acc) + fi.f ;
+ *s = hsum_float_8(acc) + ((v4f32)acc_m)[0];
  #else

  const uint8_t * scales = (const uint8_t*)&utmp[0];
@@ -6482,6 +6947,118 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r
6482
6947
 
6483
6948
  *s = hsum_float_8(acc) + summs;
6484
6949
 
6950
+ #elif defined __wasm_simd128__
6951
+ //const uint8_t * scales = (const uint8_t*)&utmp[0];
6952
+ float sumf = 0;
6953
+
6954
+ for (int i = 0; i < nb; ++i) {
6955
+ const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
6956
+ const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin); // Fixed sign
6957
+
6958
+ const uint8_t * restrict q5 = x[i].qs;
6959
+ const uint8_t * restrict qh = x[i].qh;
6960
+ const int8_t * restrict q8 = y[i].qs;
6961
+
6962
+ // Process scales and mins
6963
+ memcpy(utmp, x[i].scales, 12);
6964
+ utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
6965
+ const uint32_t uaux = utmp[1] & kmask1;
6966
+ utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
6967
+ utmp[2] = uaux;
6968
+ utmp[0] &= kmask1;
6969
+
6970
+ // Sum mins * q8sums
6971
+ int32_t sumi_mins = 0;
6972
+ const int16_t * restrict q8sums = y[i].bsums;
6973
+ const uint8_t * m = (const uint8_t *)&utmp[2];
6974
+ for (int j = 0; j < 16; j += 2) {
6975
+ sumi_mins += (q8sums[j] + q8sums[j+1]) * m[j/2];
6976
+ }
6977
+ sumf -= dmin * sumi_mins; // Correct subtraction
6978
+
6979
+ v128_t qh0 = wasm_v128_load(qh);
6980
+ v128_t qh1 = wasm_v128_load(qh + 16);
6981
+ const uint8_t * sc = (const uint8_t *)utmp;
6982
+
6983
+ int32_t sumi = 0;
6984
+
6985
+ for (int j = 0; j < QK_K/64; ++j) {
6986
+ const int shift = j * 2;
6987
+ v128_t qh_shift0 = wasm_u8x16_shr(qh0, shift);
6988
+ v128_t qh_shift1 = wasm_u8x16_shr(qh1, shift);
6989
+
6990
+ v128_t qh_low0 = wasm_i8x16_shl(wasm_v128_and(qh_shift0, wasm_i8x16_splat(0x01)), 4);
6991
+ v128_t qh_high0 = wasm_i8x16_shl(wasm_v128_and(qh_shift0, wasm_i8x16_splat(0x02)), 3);
6992
+ v128_t qh_low1 = wasm_i8x16_shl(wasm_v128_and(qh_shift1, wasm_i8x16_splat(0x01)), 4);
6993
+ v128_t qh_high1 = wasm_i8x16_shl(wasm_v128_and(qh_shift1, wasm_i8x16_splat(0x02)), 3);
6994
+
6995
+ v128_t q5_0 = wasm_v128_load(q5);
6996
+ v128_t q5_1 = wasm_v128_load(q5 + 16);
6997
+ q5 += 32;
6998
+
6999
+ v128_t q5l_0 = wasm_v128_or(wasm_v128_and(q5_0, wasm_i8x16_splat(0x0F)), qh_low0);
7000
+ v128_t q5h_0 = wasm_v128_or(wasm_u8x16_shr(q5_0, 4), qh_high0);
7001
+ v128_t q5l_1 = wasm_v128_or(wasm_v128_and(q5_1, wasm_i8x16_splat(0x0F)), qh_low1);
7002
+ v128_t q5h_1 = wasm_v128_or(wasm_u8x16_shr(q5_1, 4), qh_high1);
7003
+
7004
+ v128_t q8_0 = wasm_v128_load(q8);
7005
+ v128_t q8_1 = wasm_v128_load(q8 + 16);
7006
+ v128_t q8_2 = wasm_v128_load(q8 + 32);
7007
+ v128_t q8_3 = wasm_v128_load(q8 + 48);
7008
+ q8 += 64;
7009
+
7010
+ // Process low quants
7011
+ v128_t pl0 = wasm_i32x4_dot_i16x8(
7012
+ wasm_i16x8_extend_low_i8x16(q5l_0),
7013
+ wasm_i16x8_extend_low_i8x16(q8_0)
7014
+ );
7015
+ pl0 = wasm_i32x4_add(pl0, wasm_i32x4_dot_i16x8(
7016
+ wasm_i16x8_extend_high_i8x16(q5l_0),
7017
+ wasm_i16x8_extend_high_i8x16(q8_0)
7018
+ ));
7019
+ v128_t pl1 = wasm_i32x4_dot_i16x8(
7020
+ wasm_i16x8_extend_low_i8x16(q5l_1),
7021
+ wasm_i16x8_extend_low_i8x16(q8_1)
7022
+ );
7023
+ pl1 = wasm_i32x4_add(pl1, wasm_i32x4_dot_i16x8(
7024
+ wasm_i16x8_extend_high_i8x16(q5l_1),
7025
+ wasm_i16x8_extend_high_i8x16(q8_1)
7026
+ ));
7027
+ v128_t sum_low = wasm_i32x4_add(pl0, pl1);
7028
+
7029
+ // Process high quants
7030
+ v128_t ph0 = wasm_i32x4_dot_i16x8(
7031
+ wasm_i16x8_extend_low_i8x16(q5h_0),
7032
+ wasm_i16x8_extend_low_i8x16(q8_2)
7033
+ );
7034
+ ph0 = wasm_i32x4_add(ph0, wasm_i32x4_dot_i16x8(
7035
+ wasm_i16x8_extend_high_i8x16(q5h_0),
7036
+ wasm_i16x8_extend_high_i8x16(q8_2)
7037
+ ));
7038
+ v128_t ph1 = wasm_i32x4_dot_i16x8(
7039
+ wasm_i16x8_extend_low_i8x16(q5h_1),
7040
+ wasm_i16x8_extend_low_i8x16(q8_3)
7041
+ );
7042
+ ph1 = wasm_i32x4_add(ph1, wasm_i32x4_dot_i16x8(
7043
+ wasm_i16x8_extend_high_i8x16(q5h_1),
7044
+ wasm_i16x8_extend_high_i8x16(q8_3)
7045
+ ));
7046
+ v128_t sum_high = wasm_i32x4_add(ph0, ph1);
7047
+
7048
+ // Accumulate with scale factors
7049
+ int32_t sl = wasm_i32x4_extract_lane(sum_low, 0) + wasm_i32x4_extract_lane(sum_low, 1) +
7050
+ wasm_i32x4_extract_lane(sum_low, 2) + wasm_i32x4_extract_lane(sum_low, 3);
7051
+ int32_t sh = wasm_i32x4_extract_lane(sum_high, 0) + wasm_i32x4_extract_lane(sum_high, 1) +
7052
+ wasm_i32x4_extract_lane(sum_high, 2) + wasm_i32x4_extract_lane(sum_high, 3);
7053
+
7054
+ sumi += sl * sc[2*j] + sh * sc[2*j+1];
7055
+ }
7056
+
7057
+ sumf += d * sumi;
7058
+ }
7059
+
7060
+ *s = sumf;
7061
+
6485
7062
  #elif defined __riscv_v_intrinsic
6486
7063
 
6487
7064
  const uint8_t * scales = (const uint8_t*)&utmp[0];
@@ -6704,19 +7281,11 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r
6704
7281
  *s = vec_extract(vsumf0, 0);
6705
7282
 
6706
7283
  #elif defined __loongarch_asx
6707
- GGML_UNUSED(kmask1);
6708
- GGML_UNUSED(kmask2);
6709
- GGML_UNUSED(kmask3);
6710
-
6711
- const __m256i m4 = __lasx_xvreplgr2vr_b(0xF);
6712
- const __m128i mzero = __lsx_vldi(0);
6713
- const __m256i mone = __lasx_xvreplgr2vr_b(1);
6714
7284
 
6715
7285
  __m256 acc = (__m256)__lasx_xvldi(0);
7286
+ __m128 acc_m = (__m128)__lsx_vldi(0);
6716
7287
 
6717
- float summs = 0.f;
6718
-
6719
- for (int i = 0; i < nb; ++i) {
7288
+ for (int i = 0; i < nb; ++i) {
6720
7289
 
6721
7290
  const uint8_t * restrict q5 = x[i].qs;
6722
7291
  const int8_t * restrict q8 = y[i].qs;
@@ -6731,49 +7300,40 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r
         utmp[2] = uaux;
         utmp[0] &= kmask1;
 
-        const __m256i mins_and_scales = lasx_extu8_16(lsx_set_w(utmp[3], utmp[2], utmp[1], utmp[0]));
+        const __m128i mins_and_scales128 = lsx_set_w(utmp[3], utmp[2], utmp[1], utmp[0]);
+        const __m128i mins128 = __lsx_vexth_h_b(mins_and_scales128);
+        const __m128i scales128 = __lsx_vsllwil_h_b(mins_and_scales128, 0);
 
         const __m256i q8sums = __lasx_xvld((const __m256i*)y[i].bsums, 0);
         const __m128i q8s = lsx_hadd_h(lasx_extracti128(q8sums, 0), lasx_extracti128(q8sums, 1));
-        const __m128i prod = lsx_madd_h(lasx_extracti128(mins_and_scales, 1), q8s);
-        const __m128i hsum = lsx_hadd_w(lsx_hadd_w(prod, mzero), mzero);
-        summs += dmin * __lsx_vpickve2gr_w(hsum, 0); //TODO check
+        const __m128i prod = lsx_madd_h(mins128, q8s);
+        acc_m = __lsx_vfmadd_s(__lsx_vreplfr2vr_s(dmin), __lsx_vffint_s_w(prod), acc_m);
 
-        const __m128i sc128 = lasx_extracti128(mins_and_scales, 0);
-        const __m256i scales = lasx_insertf128(sc128, sc128);
+        const __m256i scales = lasx_insertf128(scales128, scales128);
 
         const __m256i hbits = __lasx_xvld((const __m256i*)x[i].qh, 0);
-        __m256i hmask = mone;
 
         __m256i sumi = __lasx_xvldi(0);
 
-        int bit = 0;
-        __m256i xvbit;
-
         for (int j = 0; j < QK_K/64; ++j) {
 
-            const __m256i scale_0 = lasx_shuffle_b(scales, get_scale_shuffle_k4(2*j+0));
-            const __m256i scale_1 = lasx_shuffle_b(scales, get_scale_shuffle_k4(2*j+1));
+            const __m256i scale_0 = lasx_xvrepl128vei_h(scales, 2 * j + 0);
+            const __m256i scale_1 = lasx_xvrepl128vei_h(scales, 2 * j + 1);
 
             const __m256i q5bits = __lasx_xvld((const __m256i*)q5, 0); q5 += 32;
 
-            xvbit = __lasx_xvreplgr2vr_h(bit++);
-            const __m256i q5l_0 = __lasx_xvand_v(q5bits, m4);
-            const __m256i q5h_0 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvand_v(hbits, hmask), xvbit), 4);
-            const __m256i q5_0 = __lasx_xvadd_b(q5l_0, q5h_0);
-            hmask = __lasx_xvslli_h(hmask, 1);
-
-            xvbit = __lasx_xvreplgr2vr_h(bit++);
-            const __m256i q5l_1 = __lasx_xvand_v(__lasx_xvsrli_h(q5bits, 4), m4);
-            const __m256i q5h_1 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvand_v(hbits, hmask), xvbit), 4);
-            const __m256i q5_1 = __lasx_xvadd_b(q5l_1, q5h_1);
-            hmask = __lasx_xvslli_h(hmask, 1);
+            const __m256i q5l_0 = __lasx_xvandi_b(q5bits, 0xf);
+            const __m256i q5l_1 = __lasx_xvsrli_b(q5bits, 4);
+            const __m256i q5h_0 = __lasx_xvnori_b(__lasx_xvseqi_b(lasx_xvandi_b_bit(hbits, 2 * j + 0), 0), 0xef);
+            const __m256i q5h_1 = __lasx_xvnori_b(__lasx_xvseqi_b(lasx_xvandi_b_bit(hbits, 2 * j + 1), 0), 0xef);
+            const __m256i q5_0 = __lasx_xvor_v(q5l_0, q5h_0);
+            const __m256i q5_1 = __lasx_xvor_v(q5l_1, q5h_1);
 
             const __m256i q8_0 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
             const __m256i q8_1 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
 
-            __m256i p16_0 = lasx_maddubs_h(q5_0, q8_0);
-            __m256i p16_1 = lasx_maddubs_h(q5_1, q8_1);
+            __m256i p16_0 = lasx_madd_h_b(q5_0, q8_0);
+            __m256i p16_1 = lasx_madd_h_b(q5_1, q8_1);
 
             p16_0 = lasx_madd_h(scale_0, p16_0);
             p16_1 = lasx_madd_h(scale_1, p16_1);
@@ -6787,7 +7347,10 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r
 
     }
 
-    *s = hsum_float_8(acc) + summs;
+    acc_m = __lsx_vfadd_s(acc_m, (__m128)__lsx_vbsrl_v(acc_m, 8));
+    acc_m = __lsx_vfadd_s(acc_m, (__m128)__lsx_vbsrl_v(acc_m, 4));
+
+    *s = hsum_float_8(acc) + ((v4f32)acc_m)[0];
 
 #else
 
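The LoongArch q5_K change above replaces the scalar `summs` accumulator with a 4-float vector `acc_m` that is folded down only once at the end. A plain-C sketch of that final byte-shift reduction (hypothetical helper, assuming the four partial sums are already in memory):

    #include <string.h>

    // Scalar equivalent of the two __lsx_vbsrl_v/__lsx_vfadd_s steps above:
    // shift by 8 bytes (two floats) and add, then by 4 bytes (one float) and
    // add, leaving the total of all four lanes in lane 0.
    static float hsum_f32x4_sketch(const float v[4]) {
        float a[4];
        memcpy(a, v, sizeof(a));
        a[0] += a[2]; a[1] += a[3];  // fold upper half onto lower half
        a[0] += a[1];                // fold remaining neighbour
        return a[0];
    }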
@@ -7145,6 +7708,85 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r
 
     *s = hsum_float_8(acc);
 
+#elif defined __wasm_simd128__
+    int8_t aux8[QK_K] __attribute__((aligned(16)));
+    int32_t aux32[8] __attribute__((aligned(16))) = {0};
+    float sums[8] __attribute__((aligned(16))) = {0};
+
+    for (int i = 0; i < nb; ++i) {
+        // Unpack 6-bit quantized data into aux8 (unchanged)
+        const uint8_t * restrict q4 = x[i].ql;
+        const uint8_t * restrict qh = x[i].qh;
+        int8_t * a = aux8;
+        for (int j = 0; j < QK_K; j += 128) {
+            for (int l = 0; l < 32; ++l) {
+                a[l + 0] = (int8_t)((q4[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
+                a[l + 32] = (int8_t)((q4[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
+                a[l + 64] = (int8_t)((q4[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
+                a[l + 96] = (int8_t)((q4[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
+            }
+            a += 128;
+            q4 += 64;
+            qh += 32;
+        }
+
+        const int8_t * restrict a_ptr = aux8;
+        const int8_t * restrict q8 = y[i].qs;
+        v128_t acc0 = wasm_i32x4_splat(0);
+        v128_t acc1 = wasm_i32x4_splat(0);
+
+        for (int j = 0; j < QK_K/16; ++j) {
+            const int scale = x[i].scales[j];
+            const v128_t vscale = wasm_i32x4_splat(scale);
+
+            // Load 16 elements from a and q8
+            const v128_t a_vec = wasm_v128_load(a_ptr);
+            const v128_t q8_vec = wasm_v128_load(q8);
+
+            // Process low 8 elements
+            v128_t a_low = wasm_i16x8_extend_low_i8x16(a_vec);
+            v128_t q8_low = wasm_i16x8_extend_low_i8x16(q8_vec);
+            v128_t prod_low = wasm_i16x8_mul(a_low, q8_low);
+            v128_t prod_lo_lo = wasm_i32x4_extend_low_i16x8(prod_low);
+            v128_t prod_lo_hi = wasm_i32x4_extend_high_i16x8(prod_low);
+
+            // Process high 8 elements
+            v128_t a_high = wasm_i16x8_extend_high_i8x16(a_vec);
+            v128_t q8_high = wasm_i16x8_extend_high_i8x16(q8_vec);
+            v128_t prod_high = wasm_i16x8_mul(a_high, q8_high);
+            v128_t prod_hi_lo = wasm_i32x4_extend_low_i16x8(prod_high);
+            v128_t prod_hi_hi = wasm_i32x4_extend_high_i16x8(prod_high);
+
+            // Scale and accumulate
+            prod_lo_lo = wasm_i32x4_mul(prod_lo_lo, vscale);
+            prod_lo_hi = wasm_i32x4_mul(prod_lo_hi, vscale);
+            prod_hi_lo = wasm_i32x4_mul(prod_hi_lo, vscale);
+            prod_hi_hi = wasm_i32x4_mul(prod_hi_hi, vscale);
+
+            acc0 = wasm_i32x4_add(acc0, wasm_i32x4_add(prod_lo_lo, prod_hi_lo));
+            acc1 = wasm_i32x4_add(acc1, wasm_i32x4_add(prod_lo_hi, prod_hi_hi));
+
+            a_ptr += 16;
+            q8 += 16;
+        }
+
+        // Store accumulated results
+        wasm_v128_store(&aux32[0], acc0);
+        wasm_v128_store(&aux32[4], acc1);
+
+        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+        for (int l = 0; l < 8; ++l) {
+            sums[l] += d * aux32[l];
+        }
+    }
+
+    // Sum final results
+    float sumf = 0;
+    for (int l = 0; l < 8; ++l) {
+        sumf += sums[l];
+    }
+    *s = sumf;
+
 #elif defined __riscv_v_intrinsic
 
     float sumf = 0;
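As a cross-check on the new __wasm_simd128__ q6_K branch above, its vector loop is numerically equivalent to the following scalar per-super-block reference (a hypothetical helper; it assumes `aux8` already holds the unpacked 6-bit values minus 32, exactly as produced by the unpack loop in the diff):

    #include <stdint.h>

    #define QK_K 256  // llama.cpp super-block size assumed here

    // Scalar reference for one q6_K super-block: every group of 16 products
    // is weighted by its signed 8-bit scale, then the whole block by d.
    static float q6k_block_dot_sketch(const int8_t aux8[QK_K], const int8_t q8[QK_K],
                                      const int8_t scales[QK_K/16], float d) {
        int32_t sumi = 0;
        for (int j = 0; j < QK_K/16; ++j) {
            int32_t group = 0;
            for (int l = 0; l < 16; ++l) {
                group += (int32_t)aux8[j*16 + l] * (int32_t)q8[j*16 + l];
            }
            sumi += group * scales[j];
        }
        return d * (float)sumi;  // d = GGML_FP16_TO_FP32(x[i].d) * y[i].d
    }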
@@ -7369,8 +8011,6 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r
 
 #elif defined __loongarch_asx
 
-    const __m256i m4 = __lasx_xvreplgr2vr_b(0xF);
-    const __m256i m2 = __lasx_xvreplgr2vr_b(3);
     const __m256i m32s = __lasx_xvreplgr2vr_b(32);
 
     __m256 acc = (__m256)__lasx_xvldi(0);
@@ -7383,58 +8023,42 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r
         const uint8_t * restrict qh = x[i].qh;
         const int8_t * restrict q8 = y[i].qs;
 
-        const __m128i scales = __lsx_vld((const __m128i*)x[i].scales, 0);
+        const __m128i scales128 = __lsx_vld((const __m128i*)x[i].scales, 0);
+        const v16i8 shuffle_mask = {0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15};
+        const __m256i scales_shuffled = lasx_ext8_16(__lsx_vshuf_b(scales128, scales128, (__m128i)shuffle_mask));
 
         __m256i sumi = __lasx_xvldi(0);
 
-        int is = 0;
-
         for (int j = 0; j < QK_K/128; ++j) {
 
-            const __m128i scale_0 = lsx_shuffle_b(scales, get_scale_shuffle(is + 0));
-            const __m128i scale_1 = lsx_shuffle_b(scales, get_scale_shuffle(is + 1));
-            const __m128i scale_2 = lsx_shuffle_b(scales, get_scale_shuffle(is + 2));
-            const __m128i scale_3 = lsx_shuffle_b(scales, get_scale_shuffle(is + 3));
-            is += 4;
-
             const __m256i q4bits1 = __lasx_xvld((const __m256i*)q4, 0); q4 += 32;
             const __m256i q4bits2 = __lasx_xvld((const __m256i*)q4, 0); q4 += 32;
             const __m256i q4bitsH = __lasx_xvld((const __m256i*)qh, 0); qh += 32;
 
-            const __m256i q4h_0 = __lasx_xvslli_h(__lasx_xvand_v(q4bitsH, m2), 4);
-            const __m256i q4h_1 = __lasx_xvslli_h(__lasx_xvand_v(__lasx_xvsrli_h(q4bitsH, 2), m2), 4);
-            const __m256i q4h_2 = __lasx_xvslli_h(__lasx_xvand_v(__lasx_xvsrli_h(q4bitsH, 4), m2), 4);
-            const __m256i q4h_3 = __lasx_xvslli_h(__lasx_xvand_v(__lasx_xvsrli_h(q4bitsH, 6), m2), 4);
+            const __m256i q4h_0 = __lasx_xvslli_b(__lasx_xvandi_b(q4bitsH, 3), 4);
+            const __m256i q4h_1 = __lasx_xvslli_b(__lasx_xvandi_b(q4bitsH, 3 << 2), 2);
+            const __m256i q4h_2 = __lasx_xvandi_b(q4bitsH, 3 << 4);
+            const __m256i q4h_3 = __lasx_xvsrli_b(__lasx_xvandi_b(q4bitsH, 3 << 6), 2);
 
-            const __m256i q4_0 = __lasx_xvor_v(__lasx_xvand_v(q4bits1, m4), q4h_0);
-            const __m256i q4_1 = __lasx_xvor_v(__lasx_xvand_v(q4bits2, m4), q4h_1);
-            const __m256i q4_2 = __lasx_xvor_v(__lasx_xvand_v(__lasx_xvsrli_h(q4bits1, 4), m4), q4h_2);
-            const __m256i q4_3 = __lasx_xvor_v(__lasx_xvand_v(__lasx_xvsrli_h(q4bits2, 4), m4), q4h_3);
+            const __m256i q4_0 = __lasx_xvor_v(__lasx_xvandi_b(q4bits1, 0xf), q4h_0);
+            const __m256i q4_1 = __lasx_xvor_v(__lasx_xvandi_b(q4bits2, 0xf), q4h_1);
+            const __m256i q4_2 = __lasx_xvor_v(__lasx_xvsrli_b(q4bits1, 4), q4h_2);
+            const __m256i q4_3 = __lasx_xvor_v(__lasx_xvsrli_b(q4bits2, 4), q4h_3);
 
             const __m256i q8_0 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
             const __m256i q8_1 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
             const __m256i q8_2 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
             const __m256i q8_3 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
 
-            __m256i q8s_0 = lasx_maddubs_h(m32s, q8_0);
-            __m256i q8s_1 = lasx_maddubs_h(m32s, q8_1);
-            __m256i q8s_2 = lasx_maddubs_h(m32s, q8_2);
-            __m256i q8s_3 = lasx_maddubs_h(m32s, q8_3);
-
-            __m256i p16_0 = lasx_maddubs_h(q4_0, q8_0);
-            __m256i p16_1 = lasx_maddubs_h(q4_1, q8_1);
-            __m256i p16_2 = lasx_maddubs_h(q4_2, q8_2);
-            __m256i p16_3 = lasx_maddubs_h(q4_3, q8_3);
-
-            p16_0 = __lasx_xvsub_h(p16_0, q8s_0);
-            p16_1 = __lasx_xvsub_h(p16_1, q8s_1);
-            p16_2 = __lasx_xvsub_h(p16_2, q8s_2);
-            p16_3 = __lasx_xvsub_h(p16_3, q8s_3);
+            __m256i p16_0 = lasx_madd_h_b(__lasx_xvsub_b(q4_0, m32s), q8_0);
+            __m256i p16_1 = lasx_madd_h_b(__lasx_xvsub_b(q4_1, m32s), q8_1);
+            __m256i p16_2 = lasx_madd_h_b(__lasx_xvsub_b(q4_2, m32s), q8_2);
+            __m256i p16_3 = lasx_madd_h_b(__lasx_xvsub_b(q4_3, m32s), q8_3);
 
-            p16_0 = lasx_madd_h(lasx_ext8_16(scale_0), p16_0);
-            p16_1 = lasx_madd_h(lasx_ext8_16(scale_1), p16_1);
-            p16_2 = lasx_madd_h(lasx_ext8_16(scale_2), p16_2);
-            p16_3 = lasx_madd_h(lasx_ext8_16(scale_3), p16_3);
+            p16_0 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 0), p16_0);
+            p16_1 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 1), p16_1);
+            p16_2 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 2), p16_2);
+            p16_3 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 3), p16_3);
 
             sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p16_0, p16_1));
             sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p16_2, p16_3));
@@ -9759,13 +10383,9 @@ static inline __m256i mul_add_epi8(const __m256i x, const __m256i y) {
 }
 #elif defined(__loongarch_asx)
 static inline __m256i mul_add_epi8(const __m256i x, const __m256i y) {
-    const __m256i ax = __lasx_xvsigncov_b(x, x);
-    const __m256i sy = __lasx_xvsigncov_b(x, y);
-    __m256i tmp1, tmp2, tmp3;
-    tmp1 = __lasx_xvmulwev_h_bu_b(ax, sy);
-    tmp2 = __lasx_xvmulwod_h_bu_b(ax, sy);
-    tmp3 = __lasx_xvadd_h(tmp1, tmp2);
-    return __lasx_xvsat_h(tmp3, 15);
+    const __m256i a = __lasx_xvmulwev_h_b(x, y);
+    const __m256i b = __lasx_xvmulwod_h_b(x, y);
+    return __lasx_xvadd_h(a, b);
 }
 #endif
 
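The simplified LoongArch `mul_add_epi8` above relies on the signed even/odd widening multiplies: adding the results of `xvmulwev_h_b` and `xvmulwod_h_b` yields, per 16-bit lane, the sum of two adjacent signed 8-bit products, so the earlier sign-transfer and saturation steps are no longer needed. A scalar sketch of that per-lane identity (illustrative only; the intermediate sum is computed in 32 bits and truncated, matching the non-saturating xvadd_h):

    #include <stdint.h>

    // For each 16-bit output lane k of a 256-bit vector:
    //   out[k] = x[2k]*y[2k] + x[2k+1]*y[2k+1]   (signed 8-bit inputs)
    // which is what the xvmulwev/xvmulwod pair plus xvadd_h computes.
    static void mul_add_epi8_sketch(const int8_t x[32], const int8_t y[32], int16_t out[16]) {
        for (int k = 0; k < 16; ++k) {
            out[k] = (int16_t)((int32_t)x[2*k] * y[2*k] + (int32_t)x[2*k+1] * y[2*k+1]);
        }
    }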
@@ -10815,67 +11435,31 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * restrict s, size_t bs, const void *
 #elif defined(__loongarch_asx)
 
     const __m128i values128 = __lsx_vld((const __m128i*)kvalues_iq4nl, 0);
-    const __m128i m4b = __lsx_vreplgr2vr_b(0x0f);
 
     __m256 accum = (__m256)__lasx_xvldi(0);
-    __m256i tmp1;
-    __m128i tmp0, tmp2, tmp3, tmp4, mask_8f, mask;
 
-    mask_8f = __lsx_vreplgr2vr_b(0x8f);
     for (int ibl = 0; ibl < nb; ++ibl) {
         const uint8_t * qs = x[ibl].qs;
         const int8_t * q8 = y[ibl].qs;
         uint16_t sh = x[ibl].scales_h;
         __m256i sumi1 = __lasx_xvldi(0);
         __m256i sumi2 = __lasx_xvldi(0);
-        __m128i zero = __lsx_vldi(0);
         for (int ib = 0; ib < QK_K/32; ib += 2) {
-            const __m128i q4bits_1 = __lsx_vld((const __m128i*)qs, 0); qs += 16;
-            const __m128i q4bits_2 = __lsx_vld((const __m128i*)qs, 0); qs += 16;
+            const __m128i q4bits_1 = __lsx_vld((const __m128i*)qs, 0); qs += 16;
+            const __m128i q4bits_2 = __lsx_vld((const __m128i*)qs, 0); qs += 16;
             const __m256i q8b_1 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32;
             const __m256i q8b_2 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32;
-            tmp2 = __lsx_vand_v(__lsx_vand_v(__lsx_vsrli_h(q4bits_1, 4), m4b), mask_8f);
-            tmp0 = __lsx_vori_b(tmp2, 0x10);
-            mask = __lsx_vsle_b(zero, tmp2);
-            tmp3 = __lsx_vand_v(tmp0, mask);
-            tmp3 = __lsx_vshuf_b(values128, zero, tmp3);
-
-            tmp2 = __lsx_vand_v(__lsx_vand_v(q4bits_1, m4b), mask_8f);
-            tmp0 = __lsx_vori_b(tmp2, 0x10);
-            mask = __lsx_vsle_b(zero, tmp2);
-            tmp4 = __lsx_vand_v(tmp0, mask);
-            tmp4 = __lsx_vshuf_b(values128, zero, tmp4);
-
-            const __m256i q4b_1 = lasx_insertf128(tmp3, tmp4);
-
-            tmp2 = __lsx_vand_v(__lsx_vand_v(__lsx_vsrli_h(q4bits_2, 4), m4b), mask_8f);
-            tmp0 = __lsx_vori_b(tmp2, 0x10);
-            mask = __lsx_vsle_b(zero, tmp2);
-            tmp3 = __lsx_vand_v(tmp0, mask);
-            tmp3 = __lsx_vshuf_b(values128, zero, tmp3);
-
-            tmp2 = __lsx_vand_v(__lsx_vand_v(q4bits_2, m4b), mask_8f);
-            tmp0 = __lsx_vori_b(tmp2, 0x10);
-            mask = __lsx_vsle_b(zero, tmp2);
-            tmp4 = __lsx_vand_v(tmp0, mask);
-            tmp4 = __lsx_vshuf_b(values128, zero, tmp4);
-
-            const __m256i q4b_2 = lasx_insertf128(tmp3, tmp4);
-
+            const __m256i q4b_1 = lasx_insertf128(__lsx_vshuf_b(values128, values128, __lsx_vsrli_b(q4bits_1, 4)),
+                                                  __lsx_vshuf_b(values128, values128, __lsx_vandi_b(q4bits_1, 0xf)));
+            const __m256i q4b_2 = lasx_insertf128(__lsx_vshuf_b(values128, values128, __lsx_vsrli_b(q4bits_2, 4)),
+                                                  __lsx_vshuf_b(values128, values128, __lsx_vandi_b(q4bits_2, 0xf)));
             const __m256i p16_1 = mul_add_epi8(q4b_1, q8b_1);
             const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2);
             const int16_t ls1 = ((x[ibl].scales_l[ib/2] & 0xf) | ((sh << 4) & 0x30)) - 32;
             const int16_t ls2 = ((x[ibl].scales_l[ib/2] >> 4) | ((sh << 2) & 0x30)) - 32;
             sh >>= 4;
-            __m256i tmp5, tmp6;
-            tmp1 = __lasx_xvreplgr2vr_h(ls1);
-            tmp5 = __lasx_xvmulwev_w_h(p16_1, tmp1);
-            tmp6 = __lasx_xvmulwod_w_h(p16_1, tmp1);
-            const __m256i p_1 = __lasx_xvadd_w(tmp5, tmp6);
-            tmp1 = __lasx_xvreplgr2vr_h(ls2);
-            tmp5 = __lasx_xvmulwev_w_h(p16_2, tmp1);
-            tmp6 = __lasx_xvmulwod_w_h(p16_2, tmp1);
-            const __m256i p_2 = __lasx_xvadd_w(tmp5, tmp6);
+            const __m256i p_1 = lasx_madd_h(p16_1, __lasx_xvreplgr2vr_h(ls1));
+            const __m256i p_2 = lasx_madd_h(p16_2, __lasx_xvreplgr2vr_h(ls2));
             sumi1 = __lasx_xvadd_w(p_1, sumi1);
             sumi2 = __lasx_xvadd_w(p_2, sumi2);
         }