@fugood/llama.node 1.0.0-beta.5 → 1.0.0-beta.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (113):
  1. package/lib/binding.ts +3 -1
  2. package/lib/index.js +2 -0
  3. package/lib/index.ts +3 -1
  4. package/package.json +14 -14
  5. package/scripts/llama.cpp.patch +27 -26
  6. package/src/EmbeddingWorker.cpp +1 -1
  7. package/src/LlamaCompletionWorker.cpp +28 -7
  8. package/src/LlamaCompletionWorker.h +4 -0
  9. package/src/LlamaContext.cpp +14 -17
  10. package/src/common.hpp +7 -6
  11. package/src/llama.cpp/CMakeLists.txt +15 -4
  12. package/src/llama.cpp/common/CMakeLists.txt +15 -24
  13. package/src/llama.cpp/common/arg.cpp +172 -110
  14. package/src/llama.cpp/common/chat-parser.cpp +385 -0
  15. package/src/llama.cpp/common/chat-parser.h +120 -0
  16. package/src/llama.cpp/common/chat.cpp +726 -596
  17. package/src/llama.cpp/common/chat.h +74 -8
  18. package/src/llama.cpp/common/common.cpp +56 -38
  19. package/src/llama.cpp/common/common.h +9 -3
  20. package/src/llama.cpp/common/json-partial.cpp +256 -0
  21. package/src/llama.cpp/common/json-partial.h +38 -0
  22. package/src/llama.cpp/common/json-schema-to-grammar.cpp +2 -1
  23. package/src/llama.cpp/common/json-schema-to-grammar.h +4 -4
  24. package/src/llama.cpp/common/sampling.cpp +7 -8
  25. package/src/llama.cpp/common/speculative.cpp +6 -4
  26. package/src/llama.cpp/ggml/CMakeLists.txt +48 -3
  27. package/src/llama.cpp/ggml/include/ggml.h +22 -3
  28. package/src/llama.cpp/ggml/src/CMakeLists.txt +81 -22
  29. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +131 -49
  30. package/src/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +1 -1
  31. package/src/llama.cpp/ggml/src/ggml-cpu/amx/mmq.cpp +1 -1
  32. package/src/llama.cpp/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
  33. package/src/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +4113 -0
  34. package/src/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +2162 -0
  35. package/src/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +2638 -0
  36. package/src/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
  37. package/src/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +2731 -0
  38. package/src/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +2068 -0
  39. package/src/llama.cpp/ggml/src/ggml-cpu/arch/riscv/repack.cpp +396 -0
  40. package/src/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +1299 -0
  41. package/src/llama.cpp/ggml/src/ggml-cpu/arch/wasm/quants.c +1480 -0
  42. package/src/llama.cpp/ggml/src/ggml-cpu/arch/x86/quants.c +4310 -0
  43. package/src/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-aarch64.cpp → arch/x86/repack.cpp} +59 -3206
  44. package/src/llama.cpp/ggml/src/ggml-cpu/arch-fallback.h +184 -0
  45. package/src/llama.cpp/ggml/src/ggml-cpu/common.h +1 -1
  46. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +12 -13
  47. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +64 -88
  48. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +8 -8
  49. package/src/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-hbm.cpp → hbm.cpp} +1 -1
  50. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +1 -1
  51. package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +56 -7
  52. package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.h +5 -0
  53. package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +282 -100
  54. package/src/llama.cpp/ggml/src/ggml-cpu/ops.h +1 -0
  55. package/src/llama.cpp/ggml/src/ggml-cpu/quants.c +1157 -0
  56. package/src/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-quants.h → quants.h} +26 -0
  57. package/src/llama.cpp/ggml/src/ggml-cpu/repack.cpp +1570 -0
  58. package/src/llama.cpp/ggml/src/ggml-cpu/repack.h +98 -0
  59. package/src/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +119 -5
  60. package/src/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-traits.cpp → traits.cpp} +1 -1
  61. package/src/llama.cpp/ggml/src/ggml-cpu/vec.cpp +85 -16
  62. package/src/llama.cpp/ggml/src/ggml-cpu/vec.h +204 -49
  63. package/src/llama.cpp/include/llama.h +145 -40
  64. package/src/llama.cpp/src/CMakeLists.txt +5 -1
  65. package/src/llama.cpp/src/llama-arch.cpp +99 -3
  66. package/src/llama.cpp/src/llama-arch.h +10 -1
  67. package/src/llama.cpp/src/llama-batch.cpp +728 -272
  68. package/src/llama.cpp/src/llama-batch.h +112 -54
  69. package/src/llama.cpp/src/llama-chat.cpp +19 -2
  70. package/src/llama.cpp/src/llama-chat.h +1 -0
  71. package/src/llama.cpp/src/llama-context.cpp +525 -339
  72. package/src/llama.cpp/src/llama-context.h +38 -17
  73. package/src/llama.cpp/src/llama-cparams.cpp +4 -0
  74. package/src/llama.cpp/src/llama-cparams.h +2 -0
  75. package/src/llama.cpp/src/llama-grammar.cpp +12 -2
  76. package/src/llama.cpp/src/llama-graph.cpp +413 -353
  77. package/src/llama.cpp/src/llama-graph.h +112 -56
  78. package/src/llama.cpp/src/llama-hparams.cpp +10 -2
  79. package/src/llama.cpp/src/llama-hparams.h +13 -2
  80. package/src/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +279 -0
  81. package/src/llama.cpp/src/llama-kv-cache-unified-iswa.h +128 -0
  82. package/src/llama.cpp/src/llama-kv-cache-unified.cpp +1815 -0
  83. package/src/llama.cpp/src/llama-kv-cache-unified.h +303 -0
  84. package/src/llama.cpp/src/llama-kv-cells.h +415 -0
  85. package/src/llama.cpp/src/llama-memory-hybrid.cpp +246 -0
  86. package/src/llama.cpp/src/llama-memory-hybrid.h +138 -0
  87. package/src/llama.cpp/src/llama-memory-recurrent.cpp +1112 -0
  88. package/src/llama.cpp/src/llama-memory-recurrent.h +183 -0
  89. package/src/llama.cpp/src/llama-memory.cpp +41 -0
  90. package/src/llama.cpp/src/llama-memory.h +86 -5
  91. package/src/llama.cpp/src/llama-mmap.cpp +1 -1
  92. package/src/llama.cpp/src/llama-model-loader.cpp +42 -17
  93. package/src/llama.cpp/src/llama-model-saver.cpp +1 -0
  94. package/src/llama.cpp/src/llama-model.cpp +1137 -528
  95. package/src/llama.cpp/src/llama-model.h +4 -0
  96. package/src/llama.cpp/src/llama-quant.cpp +2 -1
  97. package/src/llama.cpp/src/llama-sampling.cpp +2 -2
  98. package/src/llama.cpp/src/llama-vocab.cpp +69 -32
  99. package/src/llama.cpp/src/llama-vocab.h +1 -0
  100. package/src/llama.cpp/src/llama.cpp +11 -7
  101. package/src/llama.cpp/src/unicode.cpp +5 -0
  102. package/src/tts_utils.h +1 -1
  103. package/src/llama.cpp/common/json.hpp +0 -24766
  104. package/src/llama.cpp/common/minja/chat-template.hpp +0 -541
  105. package/src/llama.cpp/common/minja/minja.hpp +0 -2974
  106. package/src/llama.cpp/common/stb_image.h +0 -7988
  107. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +0 -8
  108. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +0 -13326
  109. package/src/llama.cpp/src/llama-kv-cache.cpp +0 -2827
  110. package/src/llama.cpp/src/llama-kv-cache.h +0 -515
  111. /package/src/llama.cpp/ggml/src/ggml-cpu/{cpu-feats-x86.cpp → arch/x86/cpu-feats.cpp} +0 -0
  112. /package/src/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-hbm.h → hbm.h} +0 -0
  113. /package/src/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-traits.h → traits.h} +0 -0
@@ -0,0 +1,1480 @@
1
+ #define GGML_COMMON_IMPL_C
2
+ #include "ggml-common.h"
3
+ #include "ggml-quants.h"
4
+ #include "ggml-impl.h"
5
+ #include "ggml-cpu.h"
6
+
7
+ #include "../../quants.h"
8
+ #include "../../ggml-cpu-impl.h"
9
+
10
+ #include <math.h>
11
+ #include <string.h>
12
+ #include <assert.h>
13
+ #include <float.h>
14
+ #include <stdlib.h> // for qsort
15
+ #include <stdio.h> // for GGML_ASSERT
16
+
17
+ #define GROUP_MAX_EPS 1e-15f
18
+ #define GROUP_MAX_EPS_IQ3_XXS 1e-8f
19
+ #define GROUP_MAX_EPS_IQ2_S 1e-8f
20
+ #define GROUP_MAX_EPS_IQ1_M 1e-7f
21
+ #define GROUP_MAX_EPS_IQ1_S 1e-12f
22
+
23
+ #define UNUSED GGML_UNUSED
24
+
25
+ #if defined(__wasm_simd128__)
26
+ #define B1(c,s,n) 0x ## n ## c , 0x ## n ## s
27
+ #define B2(c,s,n) B1(c,s,n ## c), B1(c,s,n ## s)
28
+ #define B3(c,s,n) B2(c,s,n ## c), B2(c,s,n ## s)
29
+ #define B4(c,s,n) B3(c,s,n ## c), B3(c,s,n ## s)
30
+ #define B5(c,s,n) B4(c,s,n ## c), B4(c,s,n ## s)
31
+ #define B6(c,s,n) B5(c,s,n ## c), B5(c,s,n ## s)
32
+ #define B7(c,s,n) B6(c,s,n ## c), B6(c,s,n ## s)
33
+ #define B8(c,s ) B7(c,s, c), B7(c,s, s)
34
+
35
+ // precomputed tables for expanding 8bits to 8 bytes:
36
+ static const uint64_t table_b2b_0[1 << 8] = { B8(00, 10) }; // ( b) << 4
37
+ static const uint64_t table_b2b_1[1 << 8] = { B8(10, 00) }; // (!b) << 4
38
+ #endif
39
+
40
+ void quantize_row_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
41
+ assert(QK8_0 == 32);
42
+ assert(k % QK8_0 == 0);
43
+ const int nb = k / QK8_0;
44
+
45
+ block_q8_0 * GGML_RESTRICT y = vy;
46
+
47
+ #if defined __wasm_simd128__
48
+ for (int i = 0; i < nb; i++) {
49
+ v128_t srcv [8];
50
+ v128_t asrcv[8];
51
+ v128_t amaxv[8];
52
+
53
+ for (int j = 0; j < 8; j++) srcv[j] = wasm_v128_load(x + i*32 + 4*j);
54
+ for (int j = 0; j < 8; j++) asrcv[j] = wasm_f32x4_abs(srcv[j]);
55
+
56
+ for (int j = 0; j < 4; j++) amaxv[2*j] = wasm_f32x4_max(asrcv[2*j], asrcv[2*j+1]);
57
+ for (int j = 0; j < 2; j++) amaxv[4*j] = wasm_f32x4_max(amaxv[4*j], amaxv[4*j+2]);
58
+ for (int j = 0; j < 1; j++) amaxv[8*j] = wasm_f32x4_max(amaxv[8*j], amaxv[8*j+4]);
59
+
60
+ const float amax = MAX(MAX(wasm_f32x4_extract_lane(amaxv[0], 0),
61
+ wasm_f32x4_extract_lane(amaxv[0], 1)),
62
+ MAX(wasm_f32x4_extract_lane(amaxv[0], 2),
63
+ wasm_f32x4_extract_lane(amaxv[0], 3)));
64
+
65
+ const float d = amax / ((1 << 7) - 1);
66
+ const float id = d ? 1.0f/d : 0.0f;
67
+
68
+ y[i].d = GGML_FP32_TO_FP16(d);
69
+
70
+ for (int j = 0; j < 8; j++) {
71
+ const v128_t v = wasm_f32x4_mul(srcv[j], wasm_f32x4_splat(id));
72
+ const v128_t vi = wasm_i32x4_trunc_sat_f32x4(v);
73
+
74
+ y[i].qs[4*j + 0] = wasm_i32x4_extract_lane(vi, 0);
75
+ y[i].qs[4*j + 1] = wasm_i32x4_extract_lane(vi, 1);
76
+ y[i].qs[4*j + 2] = wasm_i32x4_extract_lane(vi, 2);
77
+ y[i].qs[4*j + 3] = wasm_i32x4_extract_lane(vi, 3);
78
+ }
79
+ }
80
+ #else
81
+ GGML_UNUSED(nb);
82
+ // scalar
83
+ quantize_row_q8_0_ref(x, y, k);
84
+ #endif
85
+ }
86
+
87
+ void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
88
+ assert(k % QK8_1 == 0);
89
+ const int nb = k / QK8_1;
90
+
91
+ block_q8_1 * GGML_RESTRICT y = vy;
92
+ #if defined __wasm_simd128__
93
+ for (int i = 0; i < nb; i++) {
94
+ v128_t srcv [8];
95
+ v128_t asrcv[8];
96
+ v128_t amaxv[8];
97
+
98
+ for (int j = 0; j < 8; j++) srcv[j] = wasm_v128_load(x + i*32 + 4*j);
99
+ for (int j = 0; j < 8; j++) asrcv[j] = wasm_f32x4_abs(srcv[j]);
100
+
101
+ for (int j = 0; j < 4; j++) amaxv[2*j] = wasm_f32x4_max(asrcv[2*j], asrcv[2*j+1]);
102
+ for (int j = 0; j < 2; j++) amaxv[4*j] = wasm_f32x4_max(amaxv[4*j], amaxv[4*j+2]);
103
+ for (int j = 0; j < 1; j++) amaxv[8*j] = wasm_f32x4_max(amaxv[8*j], amaxv[8*j+4]);
104
+
105
+ const float amax = MAX(MAX(wasm_f32x4_extract_lane(amaxv[0], 0),
106
+ wasm_f32x4_extract_lane(amaxv[0], 1)),
107
+ MAX(wasm_f32x4_extract_lane(amaxv[0], 2),
108
+ wasm_f32x4_extract_lane(amaxv[0], 3)));
109
+
110
+ const float d = amax / ((1 << 7) - 1);
111
+ const float id = d ? 1.0f/d : 0.0f;
112
+
113
+ y[i].d = GGML_FP32_TO_FP16(d);
114
+
115
+ v128_t accv = wasm_i32x4_splat(0);
116
+
117
+ for (int j = 0; j < 8; j++) {
118
+ const v128_t v = wasm_f32x4_mul(srcv[j], wasm_f32x4_splat(id));
119
+ const v128_t vi = wasm_i32x4_trunc_sat_f32x4(v);
120
+
121
+ y[i].qs[4*j + 0] = wasm_i32x4_extract_lane(vi, 0);
122
+ y[i].qs[4*j + 1] = wasm_i32x4_extract_lane(vi, 1);
123
+ y[i].qs[4*j + 2] = wasm_i32x4_extract_lane(vi, 2);
124
+ y[i].qs[4*j + 3] = wasm_i32x4_extract_lane(vi, 3);
125
+
126
+ accv = wasm_i32x4_add(accv, vi);
127
+ }
128
+
129
+ y[i].s = GGML_FP32_TO_FP16(
130
+ d * (wasm_i32x4_extract_lane(accv, 0) +
131
+ wasm_i32x4_extract_lane(accv, 1) +
132
+ wasm_i32x4_extract_lane(accv, 2) +
133
+ wasm_i32x4_extract_lane(accv, 3)));
134
+ }
135
+ #else
136
+ GGML_UNUSED(nb);
137
+ // scalar
138
+ quantize_row_q8_1_ref(x, y, k);
139
+ #endif
140
+ }
141
+
142
+ //===================================== Q8_K ==============================================
143
+
144
+ void quantize_row_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) {
145
+ #ifdef __wasm_simd128__
146
+ assert(k % QK_K == 0);
147
+ const int64_t nb = k / QK_K;
148
+ block_q8_K * GGML_RESTRICT yc = y; // Cast to proper type
149
+
150
+ for (int i = 0; i < nb; i++) {
151
+ const float * x_block = x + i * QK_K;
152
+
153
+ v128_t min_vec = wasm_v128_load(x_block);
154
+ v128_t max_vec = min_vec;
155
+
156
+ for (int j = 4; j < QK_K; j += 4) {
157
+ v128_t x_vec = wasm_v128_load(x_block + j);
158
+ max_vec = wasm_f32x4_pmax(max_vec, x_vec);
159
+ min_vec = wasm_f32x4_pmin(min_vec, x_vec);
160
+ }
161
+ max_vec = wasm_f32x4_pmax(max_vec, wasm_i32x4_shuffle(max_vec, max_vec, 2, 3, 0, 1));
162
+ max_vec = wasm_f32x4_pmax(max_vec, wasm_i32x4_shuffle(max_vec, max_vec, 1, 0, 3, 2));
163
+ min_vec = wasm_f32x4_pmin(min_vec, wasm_i32x4_shuffle(min_vec, min_vec, 2, 3, 0, 1));
164
+ min_vec = wasm_f32x4_pmin(min_vec, wasm_i32x4_shuffle(min_vec, min_vec, 1, 0, 3, 2));
165
+ float max = wasm_f32x4_extract_lane(max_vec, 0);
166
+ float min = wasm_f32x4_extract_lane(min_vec, 0);
167
+ float amax = -min > max ? min : max;
168
+
169
+ if (amax == 0.0f) {
170
+ yc[i].d = 0.0f;
171
+ const v128_t zero = wasm_i8x16_splat(0);
172
+ for (int j = 0; j < QK_K; j += 16) {
173
+ wasm_v128_store(yc[i].qs + j, zero);
174
+ }
175
+ continue;
176
+ }
177
+
178
+ const float iscale = -127.0f / amax;
179
+ const v128_t scale_vec = wasm_f32x4_splat(iscale);
180
+
181
+ // Process 16 elements per iteration
182
+ for (int j = 0, jb = 0; j < QK_K; j += 16, jb++) {
183
+ // Load and quantize 16 floats
184
+ v128_t x0 = wasm_v128_load(x_block + j);
185
+ v128_t x1 = wasm_v128_load(x_block + j + 4);
186
+ v128_t x2 = wasm_v128_load(x_block + j + 8);
187
+ v128_t x3 = wasm_v128_load(x_block + j + 12);
188
+
189
+ v128_t q0 = wasm_f32x4_nearest(wasm_f32x4_mul(x0, scale_vec));
190
+ v128_t q1 = wasm_f32x4_nearest(wasm_f32x4_mul(x1, scale_vec));
191
+ v128_t q2 = wasm_f32x4_nearest(wasm_f32x4_mul(x2, scale_vec));
192
+ v128_t q3 = wasm_f32x4_nearest(wasm_f32x4_mul(x3, scale_vec));
193
+
194
+ // Convert to i32 with saturation
195
+ v128_t i0 = wasm_i32x4_trunc_sat_f32x4(q0);
196
+ v128_t i1 = wasm_i32x4_trunc_sat_f32x4(q1);
197
+ v128_t i2 = wasm_i32x4_trunc_sat_f32x4(q2);
198
+ v128_t i3 = wasm_i32x4_trunc_sat_f32x4(q3);
199
+
200
+ // Pack into 16 i8 values
201
+ v128_t i8 = wasm_i8x16_narrow_i16x8(
202
+ wasm_i16x8_narrow_i32x4(i0, i1),
203
+ wasm_i16x8_narrow_i32x4(i2, i3)
204
+ );
205
+ wasm_v128_store(yc[i].qs + j, i8);
206
+
207
+ // Calculate bsums using SIMD
208
+ v128_t sum16 = wasm_i16x8_add(
209
+ wasm_i16x8_extend_low_i8x16(i8),
210
+ wasm_i16x8_extend_high_i8x16(i8)
211
+ );
212
+ v128_t sum32 = wasm_i32x4_add(
213
+ wasm_i32x4_extend_low_i16x8(sum16),
214
+ wasm_i32x4_extend_high_i16x8(sum16)
215
+ );
216
+ sum32 = wasm_i32x4_add(sum32, wasm_i32x4_shuffle(sum32, sum32, 2, 3, 0, 1));
217
+ sum32 = wasm_i32x4_add(sum32, wasm_i32x4_shuffle(sum32, sum32, 1, 0, 3, 2));
218
+ yc[i].bsums[jb] = wasm_i32x4_extract_lane(sum32, 0);
219
+ }
220
+
221
+ yc[i].d = 1.0f / iscale;
222
+ }
223
+ #else
224
+ quantize_row_q8_K_ref(x, y, k);
225
+ #endif
226
+ }
227
+
228
+
229
+ //===================================== Dot products =================================
230
+
231
+ void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
232
+ const int qk = QK8_0;
233
+ const int nb = n / qk;
234
+
235
+ assert(n % qk == 0);
236
+ assert(nrc == 1);
237
+ UNUSED(nrc);
238
+ UNUSED(bx);
239
+ UNUSED(by);
240
+ UNUSED(bs);
241
+
242
+ const block_q4_0 * GGML_RESTRICT x = vx;
243
+ const block_q8_0 * GGML_RESTRICT y = vy;
244
+
245
+ int ib = 0;
246
+ float sumf = 0;
247
+
248
+ #if defined __wasm_simd128__
249
+ v128_t sumv = wasm_f32x4_splat(0.0f);
250
+
251
+ const v128_t m4b = wasm_i8x16_splat(0x0F);
252
+ const v128_t s8b = wasm_i8x16_splat(0x8);
253
+
254
+ for (; ib + 1 < nb; ib += 2) {
255
+ const block_q4_0 * GGML_RESTRICT x0 = &x[ib];
256
+ const block_q4_0 * GGML_RESTRICT x1 = &x[ib + 1];
257
+ const block_q8_0 * GGML_RESTRICT y0 = &y[ib];
258
+ const block_q8_0 * GGML_RESTRICT y1 = &y[ib + 1];
259
+
260
+ // Load and process x0
261
+ v128_t v0_0 = wasm_v128_load(x0->qs);
262
+ v128_t v0_0l = wasm_v128_and(v0_0, m4b);
263
+ v128_t v0_0h = wasm_u8x16_shr(v0_0, 4);
264
+ v128_t v0_0ls = wasm_i8x16_sub(v0_0l, s8b);
265
+ v128_t v0_0hs = wasm_i8x16_sub(v0_0h, s8b);
266
+
267
+ // Load y0 vectors
268
+ v128_t y0_l = wasm_v128_load(y0->qs);
269
+ v128_t y0_h = wasm_v128_load(y0->qs + 16);
270
+
271
+ // Extend to i16x8 and compute dot products
272
+ v128_t dx0l = wasm_i16x8_extend_low_i8x16(v0_0ls);
273
+ v128_t dx0h = wasm_i16x8_extend_high_i8x16(v0_0ls);
274
+ v128_t dx0hl = wasm_i16x8_extend_low_i8x16(v0_0hs);
275
+ v128_t dx0hh = wasm_i16x8_extend_high_i8x16(v0_0hs);
276
+
277
+ v128_t dy0ll = wasm_i16x8_extend_low_i8x16(y0_l);
278
+ v128_t dy0lh = wasm_i16x8_extend_high_i8x16(y0_l);
279
+ v128_t dy0hl = wasm_i16x8_extend_low_i8x16(y0_h);
280
+ v128_t dy0hh = wasm_i16x8_extend_high_i8x16(y0_h);
281
+
282
+ v128_t dp0 = wasm_i32x4_add(
283
+ wasm_i32x4_add(
284
+ wasm_i32x4_dot_i16x8(dx0l, dy0ll),
285
+ wasm_i32x4_dot_i16x8(dx0h, dy0lh)
286
+ ),
287
+ wasm_i32x4_add(
288
+ wasm_i32x4_dot_i16x8(dx0hl, dy0hl),
289
+ wasm_i32x4_dot_i16x8(dx0hh, dy0hh)
290
+ )
291
+ );
292
+
293
+ // Load and process x1
294
+ v128_t v0_1 = wasm_v128_load(x1->qs);
295
+ v128_t v0_1l = wasm_v128_and(v0_1, m4b);
296
+ v128_t v0_1h = wasm_u8x16_shr(v0_1, 4);
297
+ v128_t v0_1ls = wasm_i8x16_sub(v0_1l, s8b);
298
+ v128_t v0_1hs = wasm_i8x16_sub(v0_1h, s8b);
299
+
300
+ // Load y1 vectors
301
+ v128_t y1_l = wasm_v128_load(y1->qs);
302
+ v128_t y1_h = wasm_v128_load(y1->qs + 16);
303
+
304
+ // Extend to i16x8 and compute dot products
305
+ v128_t dx1l = wasm_i16x8_extend_low_i8x16(v0_1ls);
306
+ v128_t dx1h = wasm_i16x8_extend_high_i8x16(v0_1ls);
307
+ v128_t dx1hl = wasm_i16x8_extend_low_i8x16(v0_1hs);
308
+ v128_t dx1hh = wasm_i16x8_extend_high_i8x16(v0_1hs);
309
+
310
+ v128_t dy1ll = wasm_i16x8_extend_low_i8x16(y1_l);
311
+ v128_t dy1lh = wasm_i16x8_extend_high_i8x16(y1_l);
312
+ v128_t dy1hl = wasm_i16x8_extend_low_i8x16(y1_h);
313
+ v128_t dy1hh = wasm_i16x8_extend_high_i8x16(y1_h);
314
+
315
+ v128_t dp1 = wasm_i32x4_add(
316
+ wasm_i32x4_add(
317
+ wasm_i32x4_dot_i16x8(dx1l, dy1ll),
318
+ wasm_i32x4_dot_i16x8(dx1h, dy1lh)
319
+ ),
320
+ wasm_i32x4_add(
321
+ wasm_i32x4_dot_i16x8(dx1hl, dy1hl),
322
+ wasm_i32x4_dot_i16x8(dx1hh, dy1hh)
323
+ )
324
+ );
325
+
326
+ // Accumulate results with scaling
327
+ float scale0 = GGML_FP16_TO_FP32(x0->d) * GGML_FP16_TO_FP32(y0->d);
328
+ float scale1 = GGML_FP16_TO_FP32(x1->d) * GGML_FP16_TO_FP32(y1->d);
329
+
330
+ sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4(dp0), wasm_f32x4_splat(scale0)));
331
+ sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4(dp1), wasm_f32x4_splat(scale1)));
332
+ }
333
+
334
+ sumf = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) +
335
+ wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3);
336
+
337
+ #endif
338
+ for (; ib < nb; ++ib) {
339
+ int sumi0 = 0;
340
+ int sumi1 = 0;
341
+
342
+ for (int j = 0; j < qk/2; ++j) {
343
+ const int v0 = (x[ib].qs[j] & 0x0F) - 8;
344
+ const int v1 = (x[ib].qs[j] >> 4) - 8;
345
+
346
+ sumi0 += (v0 * y[ib].qs[j]);
347
+ sumi1 += (v1 * y[ib].qs[j + qk/2]);
348
+ }
349
+
350
+ int sumi = sumi0 + sumi1;
351
+ sumf += sumi*GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d);
352
+ }
353
+
354
+ *s = sumf;
355
+ }
356
+
357
+ void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
358
+ const int qk = QK8_0;
359
+ const int nb = n / qk;
360
+
361
+ int ib = 0;
362
+ float sumf = 0;
363
+
364
+ assert(n % qk == 0);
365
+ assert(qk == QK5_0);
366
+ assert(nrc == 1);
367
+ UNUSED(nrc);
368
+ UNUSED(bx);
369
+ UNUSED(by);
370
+ UNUSED(bs);
371
+
372
+ const block_q5_0 * GGML_RESTRICT x = vx;
373
+ const block_q8_0 * GGML_RESTRICT y = vy;
374
+
375
+ #if defined __wasm_simd128__
376
+ v128_t sumv = wasm_f32x4_splat(0.0f);
377
+
378
+ uint32_t qh_;
379
+ uint64_t tmp[4];
380
+
381
+ // TODO: check if unrolling this is better
382
+ for (; ib < nb; ++ib) {
383
+ const block_q5_0 * GGML_RESTRICT x0 = &x[ib];
384
+ const block_q8_0 * GGML_RESTRICT y0 = &y[ib];
385
+
386
+ const v128_t m4b = wasm_i8x16_splat(0x0F);
387
+
388
+ // extract the 5th bit
389
+ memcpy(&qh_, x0->qh, sizeof(qh_));
390
+
391
+ tmp[0] = table_b2b_1[(qh_ >> 0) & 0xFF];
392
+ tmp[1] = table_b2b_1[(qh_ >> 8) & 0xFF];
393
+ tmp[2] = table_b2b_1[(qh_ >> 16) & 0xFF];
394
+ tmp[3] = table_b2b_1[(qh_ >> 24) ];
395
+
396
+ const v128_t qhl = wasm_v128_load(tmp + 0);
397
+ const v128_t qhh = wasm_v128_load(tmp + 2);
398
+
399
+ const v128_t v0 = wasm_v128_load(x0->qs);
400
+
401
+ // 4-bit -> 8-bit
402
+ const v128_t v0l = wasm_v128_and (v0, m4b);
403
+ const v128_t v0h = wasm_u8x16_shr(v0, 4);
404
+
405
+ // add high bit and sub 16 (equivalent to sub 0x10 when bit is zero)
406
+ const v128_t v0lf = wasm_i8x16_sub(v0l, qhl);
407
+ const v128_t v0hf = wasm_i8x16_sub(v0h, qhh);
408
+
409
+ // load y
410
+ const v128_t v1l = wasm_v128_load(y0->qs);
411
+ const v128_t v1h = wasm_v128_load(y0->qs + 16);
412
+
413
+ // int8x16 -> int16x8
414
+ const v128_t v0lfl = wasm_i16x8_extend_low_i8x16 (v0lf);
415
+ const v128_t v0lfh = wasm_i16x8_extend_high_i8x16(v0lf);
416
+ const v128_t v0hfl = wasm_i16x8_extend_low_i8x16 (v0hf);
417
+ const v128_t v0hfh = wasm_i16x8_extend_high_i8x16(v0hf);
418
+
419
+ const v128_t v1ll = wasm_i16x8_extend_low_i8x16 (v1l);
420
+ const v128_t v1lh = wasm_i16x8_extend_high_i8x16(v1l);
421
+ const v128_t v1hl = wasm_i16x8_extend_low_i8x16 (v1h);
422
+ const v128_t v1hh = wasm_i16x8_extend_high_i8x16(v1h);
423
+
424
+ // dot product
425
+ sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4(
426
+ wasm_i32x4_add(
427
+ wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0lfl, v1ll),
428
+ wasm_i32x4_dot_i16x8(v0lfh, v1lh)),
429
+ wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0hfl, v1hl),
430
+ wasm_i32x4_dot_i16x8(v0hfh, v1hh)))),
431
+ wasm_f32x4_splat(GGML_FP16_TO_FP32(x0->d) * GGML_FP16_TO_FP32(y0->d))));
432
+ }
433
+
434
+ sumf = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) +
435
+ wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3);
436
+
437
+ #endif
438
+ for (; ib < nb; ++ib) {
439
+ uint32_t qh;
440
+ memcpy(&qh, x[ib].qh, sizeof(qh));
441
+
442
+ int sumi0 = 0;
443
+ int sumi1 = 0;
444
+
445
+ for (int j = 0; j < qk/2; ++j) {
446
+ const uint8_t xh_0 = ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4;
447
+ const uint8_t xh_1 = ((qh & (1u << (j + 16))) >> (j + 12));
448
+
449
+ const int32_t x0 = (int8_t)(((x[ib].qs[j] & 0x0F) | xh_0) - 16);
450
+ const int32_t x1 = (int8_t)(((x[ib].qs[j] >> 4) | xh_1) - 16);
451
+
452
+ sumi0 += (x0 * y[ib].qs[j]);
453
+ sumi1 += (x1 * y[ib].qs[j + qk/2]);
454
+ }
455
+
456
+ int sumi = sumi0 + sumi1;
457
+ sumf += (GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d)) * sumi;
458
+ }
459
+
460
+ *s = sumf;
461
+ }
462
+
463
+ void ggml_vec_dot_q5_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
464
+ const int qk = QK8_1;
465
+ const int nb = n / qk;
466
+
467
+ int ib = 0;
468
+ float sumf = 0;
469
+
470
+ assert(n % qk == 0);
471
+ assert(qk == QK5_1);
472
+ assert(nrc == 1);
473
+ UNUSED(nrc);
474
+ UNUSED(bx);
475
+ UNUSED(by);
476
+ UNUSED(bs);
477
+
478
+ const block_q5_1 * GGML_RESTRICT x = vx;
479
+ const block_q8_1 * GGML_RESTRICT y = vy;
480
+
481
+ #if defined __wasm_simd128__
482
+ v128_t sumv = wasm_f32x4_splat(0.0f);
483
+
484
+ float summs = 0.0f;
485
+
486
+ uint32_t qh_;
487
+ uint64_t tmp[4];
488
+
489
+ // TODO: check if unrolling this is better
490
+ for (; ib < nb; ++ib) {
491
+ const block_q5_1 * GGML_RESTRICT x0 = &x[ib];
492
+ const block_q8_1 * GGML_RESTRICT y0 = &y[ib];
493
+
494
+ summs += GGML_FP16_TO_FP32(x0->m) * GGML_FP16_TO_FP32(y0->s);
495
+
496
+ const v128_t m4b = wasm_i8x16_splat(0x0F);
497
+
498
+ // extract the 5th bit
499
+ memcpy(&qh_, x0->qh, sizeof(qh_));
500
+
501
+ tmp[0] = table_b2b_0[(qh_ >> 0) & 0xFF];
502
+ tmp[1] = table_b2b_0[(qh_ >> 8) & 0xFF];
503
+ tmp[2] = table_b2b_0[(qh_ >> 16) & 0xFF];
504
+ tmp[3] = table_b2b_0[(qh_ >> 24) ];
505
+
506
+ const v128_t qhl = wasm_v128_load(tmp + 0);
507
+ const v128_t qhh = wasm_v128_load(tmp + 2);
508
+
509
+ const v128_t v0 = wasm_v128_load(x0->qs);
510
+
511
+ // 4-bit -> 8-bit
512
+ const v128_t v0l = wasm_v128_and (v0, m4b);
513
+ const v128_t v0h = wasm_u8x16_shr(v0, 4);
514
+
515
+ // add high bit
516
+ const v128_t v0lf = wasm_v128_or(v0l, qhl);
517
+ const v128_t v0hf = wasm_v128_or(v0h, qhh);
518
+
519
+ // load y
520
+ const v128_t v1l = wasm_v128_load(y0->qs);
521
+ const v128_t v1h = wasm_v128_load(y0->qs + 16);
522
+
523
+ // int8x16 -> int16x8
524
+ const v128_t v0lfl = wasm_i16x8_extend_low_i8x16 (v0lf);
525
+ const v128_t v0lfh = wasm_i16x8_extend_high_i8x16(v0lf);
526
+ const v128_t v0hfl = wasm_i16x8_extend_low_i8x16 (v0hf);
527
+ const v128_t v0hfh = wasm_i16x8_extend_high_i8x16(v0hf);
528
+
529
+ const v128_t v1ll = wasm_i16x8_extend_low_i8x16 (v1l);
530
+ const v128_t v1lh = wasm_i16x8_extend_high_i8x16(v1l);
531
+ const v128_t v1hl = wasm_i16x8_extend_low_i8x16 (v1h);
532
+ const v128_t v1hh = wasm_i16x8_extend_high_i8x16(v1h);
533
+
534
+ // dot product
535
+ sumv = wasm_f32x4_add(sumv,
536
+ wasm_f32x4_mul(wasm_f32x4_convert_i32x4(wasm_i32x4_add(
537
+ wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0lfl, v1ll),
538
+ wasm_i32x4_dot_i16x8(v0lfh, v1lh)),
539
+ wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0hfl, v1hl),
540
+ wasm_i32x4_dot_i16x8(v0hfh, v1hh)))),
541
+ wasm_f32x4_splat(GGML_FP16_TO_FP32(x0->d) * GGML_FP16_TO_FP32(y0->d))));
542
+ }
543
+
544
+ sumf = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) +
545
+ wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3) + summs;
546
+
547
+ #endif
548
+ for (; ib < nb; ++ib) {
549
+ uint32_t qh;
550
+ memcpy(&qh, x[ib].qh, sizeof(qh));
551
+
552
+ int sumi0 = 0;
553
+ int sumi1 = 0;
554
+
555
+ for (int j = 0; j < qk/2; ++j) {
556
+ const uint8_t xh_0 = ((qh >> (j + 0)) << 4) & 0x10;
557
+ const uint8_t xh_1 = ((qh >> (j + 12)) ) & 0x10;
558
+
559
+ const int32_t x0 = (x[ib].qs[j] & 0xF) | xh_0;
560
+ const int32_t x1 = (x[ib].qs[j] >> 4) | xh_1;
561
+
562
+ sumi0 += (x0 * y[ib].qs[j]);
563
+ sumi1 += (x1 * y[ib].qs[j + qk/2]);
564
+ }
565
+
566
+ int sumi = sumi0 + sumi1;
567
+ sumf += (GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d))*sumi + GGML_FP16_TO_FP32(x[ib].m)*GGML_FP16_TO_FP32(y[ib].s);
568
+ }
569
+
570
+ *s = sumf;
571
+ }
572
+
573
+ void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
574
+ const int qk = QK8_0;
575
+ const int nb = n / qk;
576
+
577
+ assert(n % qk == 0);
578
+ assert(nrc == 1);
579
+ UNUSED(nrc);
580
+ UNUSED(bx);
581
+ UNUSED(by);
582
+ UNUSED(bs);
583
+
584
+ const block_q8_0 * GGML_RESTRICT x = vx;
585
+ const block_q8_0 * GGML_RESTRICT y = vy;
586
+
587
+ int ib = 0;
588
+ float sumf = 0;
589
+
590
+ #if defined __wasm_simd128__
591
+ v128_t sumv = wasm_f32x4_splat(0.0f);
592
+
593
+ for (; ib < nb; ++ib) {
594
+ const block_q8_0 * GGML_RESTRICT x0 = &x[ib];
595
+ const block_q8_0 * GGML_RESTRICT y0 = &y[ib];
596
+
597
+ const v128_t x0_0 = wasm_v128_load(x0->qs);
598
+ const v128_t x0_1 = wasm_v128_load(x0->qs + 16);
599
+ const v128_t y0_0 = wasm_v128_load(y0->qs);
600
+ const v128_t y0_1 = wasm_v128_load(y0->qs + 16);
601
+
602
+ // Extend 8-bit to 16-bit
603
+ const v128_t x0_0l = wasm_i16x8_extend_low_i8x16(x0_0);
604
+ const v128_t x0_0h = wasm_i16x8_extend_high_i8x16(x0_0);
605
+ const v128_t x0_1l = wasm_i16x8_extend_low_i8x16(x0_1);
606
+ const v128_t x0_1h = wasm_i16x8_extend_high_i8x16(x0_1);
607
+
608
+ const v128_t y0_0l = wasm_i16x8_extend_low_i8x16(y0_0);
609
+ const v128_t y0_0h = wasm_i16x8_extend_high_i8x16(y0_0);
610
+ const v128_t y0_1l = wasm_i16x8_extend_low_i8x16(y0_1);
611
+ const v128_t y0_1h = wasm_i16x8_extend_high_i8x16(y0_1);
612
+
613
+ // Compute dot products
614
+ const v128_t dx0_0 = wasm_i32x4_dot_i16x8(x0_0l, y0_0l);
615
+ const v128_t dx0_1 = wasm_i32x4_dot_i16x8(x0_0h, y0_0h);
616
+ const v128_t dx1_0 = wasm_i32x4_dot_i16x8(x0_1l, y0_1l);
617
+ const v128_t dx1_1 = wasm_i32x4_dot_i16x8(x0_1h, y0_1h);
618
+
619
+ // Sum all dot products
620
+ const v128_t sum_dots = wasm_i32x4_add(wasm_i32x4_add(dx0_0, dx0_1), wasm_i32x4_add(dx1_0, dx1_1));
621
+
622
+ // Convert to float and accumulate
623
+ const float scale = GGML_FP16_TO_FP32(x0->d) * GGML_FP16_TO_FP32(y0->d);
624
+ sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4(sum_dots), wasm_f32x4_splat(scale)));
625
+ }
626
+
627
+ sumf = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) +
628
+ wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3);
629
+
630
+ #endif
631
+ for (; ib < nb; ++ib) {
632
+ int sumi = 0;
633
+
634
+ for (int j = 0; j < qk; j++) {
635
+ sumi += x[ib].qs[j]*y[ib].qs[j];
636
+ }
637
+
638
+ sumf += sumi*(GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d));
639
+ }
640
+
641
+ *s = sumf;
642
+ }
643
+
644
+ void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
645
+ assert(nrc == 1);
646
+ UNUSED(nrc);
647
+ UNUSED(bx);
648
+ UNUSED(by);
649
+ UNUSED(bs);
650
+
651
+ const block_q2_K * GGML_RESTRICT x = vx;
652
+ const block_q8_K * GGML_RESTRICT y = vy;
653
+
654
+ const int nb = n / QK_K;
655
+
656
+ #if defined __wasm_simd128__
657
+ float sumf = 0;
658
+
659
+ for (int i = 0; i < nb; ++i) {
660
+ const uint8_t * q2 = x[i].qs;
661
+ const int8_t * q8 = y[i].qs;
662
+ const uint8_t * sc = x[i].scales;
663
+
664
+ // Vectorized summs calculation
665
+ v128_t summs_vec = wasm_i32x4_splat(0);
666
+ {
667
+ v128_t sc_vec = wasm_v128_load(sc);
668
+ v128_t sc_upper = wasm_u8x16_shr(sc_vec, 4);
669
+
670
+ v128_t sc_low = wasm_u16x8_extend_low_u8x16(sc_upper);
671
+ v128_t sc_high = wasm_u16x8_extend_high_u8x16(sc_upper);
672
+
673
+ v128_t bsums1 = wasm_v128_load(&y[i].bsums[0]);
674
+ v128_t bsums2 = wasm_v128_load(&y[i].bsums[8]);
675
+
676
+ summs_vec = wasm_i32x4_add(
677
+ wasm_i32x4_add(wasm_i32x4_dot_i16x8(sc_low, bsums1),
678
+ wasm_i32x4_dot_i16x8(sc_high, bsums2)),
679
+ summs_vec
680
+ );
681
+
682
+ summs_vec = wasm_i32x4_add(summs_vec, wasm_i32x4_shuffle(summs_vec, summs_vec, 2, 3, 0, 1));
683
+ summs_vec = wasm_i32x4_add(summs_vec, wasm_i32x4_shuffle(summs_vec, summs_vec, 1, 0, 3, 2));
684
+ }
685
+ int32_t summs = wasm_i32x4_extract_lane(summs_vec, 0);
686
+
687
+ // Vectorized isum calculation
688
+ int32_t isum = 0;
689
+ const uint8_t * sc_ptr = sc;
690
+ const int k_iters = QK_K/128;
691
+
692
+ for (int k = 0; k < k_iters; ++k) {
693
+ v128_t isum_vec = wasm_i32x4_splat(0);
694
+ int shift = 0;
695
+
696
+ for (int j = 0; j < 4; ++j) {
697
+ const int d0 = (sc_ptr[0] & 0xF);
698
+ const int d1 = (sc_ptr[1] & 0xF);
699
+ sc_ptr += 2;
700
+
701
+ // Process first 16 elements
702
+ v128_t q2_0 = wasm_v128_load(q2);
703
+ v128_t q8_0 = wasm_v128_load(q8);
704
+ v128_t q2_shift_0 = wasm_u8x16_shr(q2_0, shift);
705
+ v128_t q2_bits_0 = wasm_v128_and(q2_shift_0, wasm_i8x16_splat(0x03));
706
+
707
+ // Process next 16 elements
708
+ v128_t q2_1 = wasm_v128_load(q2 + 16);
709
+ v128_t q8_1 = wasm_v128_load(q8 + 16);
710
+ v128_t q2_shift_1 = wasm_u8x16_shr(q2_1, shift);
711
+ v128_t q2_bits_1 = wasm_v128_and(q2_shift_1, wasm_i8x16_splat(0x03));
712
+
713
+ // Calculate dot products
714
+ v128_t p0 = wasm_i32x4_dot_i16x8(
715
+ wasm_i16x8_extend_low_i8x16(q8_0),
716
+ wasm_i16x8_extend_low_i8x16(q2_bits_0)
717
+ );
718
+ v128_t p1 = wasm_i32x4_dot_i16x8(
719
+ wasm_i16x8_extend_high_i8x16(q8_0),
720
+ wasm_i16x8_extend_high_i8x16(q2_bits_0)
721
+ );
722
+ v128_t p2 = wasm_i32x4_dot_i16x8(
723
+ wasm_i16x8_extend_low_i8x16(q8_1),
724
+ wasm_i16x8_extend_low_i8x16(q2_bits_1)
725
+ );
726
+ v128_t p3 = wasm_i32x4_dot_i16x8(
727
+ wasm_i16x8_extend_high_i8x16(q8_1),
728
+ wasm_i16x8_extend_high_i8x16(q2_bits_1)
729
+ );
730
+
731
+ // Accumulate scaled results
732
+ v128_t scaled = wasm_i32x4_add(
733
+ wasm_i32x4_mul(wasm_i32x4_add(p0, p1), wasm_i32x4_splat(d0)),
734
+ wasm_i32x4_mul(wasm_i32x4_add(p2, p3), wasm_i32x4_splat(d1))
735
+ );
736
+
737
+ isum_vec = wasm_i32x4_add(isum_vec, scaled);
738
+ q8 += 32;
739
+ shift += 2;
740
+ }
741
+ q2 += 32;
742
+
743
+ // Horizontal sum of isum_vec
744
+ isum_vec = wasm_i32x4_add(isum_vec, wasm_i32x4_shuffle(isum_vec, isum_vec, 2, 3, 0, 1));
745
+ isum_vec = wasm_i32x4_add(isum_vec, wasm_i32x4_shuffle(isum_vec, isum_vec, 1, 0, 3, 2));
746
+ isum += wasm_i32x4_extract_lane(isum_vec, 0);
747
+ }
748
+
749
+ const float dall = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
750
+ const float dmin = GGML_FP16_TO_FP32(x[i].dmin) * y[i].d;
751
+ sumf += dall * isum - dmin * summs;
752
+ }
753
+
754
+ *s = sumf;
755
+
756
+ #else
757
+
758
+ float sumf = 0;
759
+
760
+ for (int i = 0; i < nb; ++i) {
761
+
762
+ const uint8_t * q2 = x[i].qs;
763
+ const int8_t * q8 = y[i].qs;
764
+ const uint8_t * sc = x[i].scales;
765
+
766
+ int summs = 0;
767
+ for (int j = 0; j < 16; ++j) {
768
+ summs += y[i].bsums[j] * (sc[j] >> 4);
769
+ }
770
+
771
+ const float dall = y[i].d * GGML_FP16_TO_FP32(x[i].d);
772
+ const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
773
+
774
+ int isum = 0;
775
+ int is = 0;
776
+ int d;
777
+ for (int k = 0; k < QK_K/128; ++k) {
778
+ int shift = 0;
779
+ for (int j = 0; j < 4; ++j) {
780
+ d = sc[is++] & 0xF;
781
+ int isuml = 0;
782
+ for (int l = 0; l < 16; ++l) isuml += q8[l] * ((q2[l] >> shift) & 3);
783
+ isum += d * isuml;
784
+ d = sc[is++] & 0xF;
785
+ isuml = 0;
786
+ for (int l = 16; l < 32; ++l) isuml += q8[l] * ((q2[l] >> shift) & 3);
787
+ isum += d * isuml;
788
+ shift += 2;
789
+ q8 += 32;
790
+ }
791
+ q2 += 32;
792
+ }
793
+ sumf += dall * isum - dmin * summs;
794
+ }
795
+ *s = sumf;
796
+ #endif
797
+ }
798
+
799
// Dot product of n quantized values: x is a row of block_q3_K, y a row of block_q8_K.
// Result is written to *s. Only the single-row case is supported (nrc == 1); the stride
// arguments bs/bx/by are unused. Both paths first unpack the 3-bit quants (2 low bits
// from qs plus a sign/offset bit from hmask, giving values in [-4, 3]) into aux8, then
// accumulate scaled products with q8.
void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    assert(n % QK_K == 0);
    assert(nrc == 1);
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);

    // Bit masks used to reassemble the 6-bit scales packed into x[i].scales (12 bytes).
    const uint32_t kmask1 = 0x03030303;
    const uint32_t kmask2 = 0x0f0f0f0f;

    const block_q3_K * GGML_RESTRICT x = vx;
    const block_q8_K * GGML_RESTRICT y = vy;

    const int nb = n / QK_K;  // number of super-blocks

#if defined __wasm_simd128__
    int8_t aux8[QK_K];       // unpacked signed 3-bit quants for one super-block
    float  sums[8] = {0};
    uint32_t auxs[4];        // scratch for scale reassembly

    float sumf = 0;
    for (int i = 0; i < nb; ++i) {
        const uint8_t * GGML_RESTRICT q3 = x[i].qs;
        const uint8_t * GGML_RESTRICT hm = x[i].hmask;
        const int8_t  * GGML_RESTRICT q8 = y[i].qs;

        // Process blocks with SIMD: extract each 2-bit plane, subtract 4 wherever the
        // corresponding hmask bit is clear (so values land in [-4, 3]).
        int8_t * a = aux8;
        uint8_t m = 1;  // current hmask bit
        for (int j = 0; j < QK_K; j += 128) {
            for (int shift = 0; shift <= 6; shift += 2) {
                v128_t v_m = wasm_i8x16_splat(m);
                for (int l = 0; l < 32; l += 16) {
                    v128_t v_q3    = wasm_v128_load(q3 + l);
                    v128_t v_shift = wasm_i8x16_shr(v_q3, shift);
                    v128_t v_low2  = wasm_v128_and(v_shift, wasm_i8x16_splat(0x03));

                    v128_t v_hm   = wasm_v128_load(hm + l);
                    v128_t v_mask = wasm_v128_and(v_hm, v_m);
                    v_mask = wasm_i8x16_ne(v_mask, wasm_i8x16_splat(0));

                    // subtract 4 only where the high bit is NOT set
                    v_low2 = wasm_i8x16_sub(v_low2, wasm_v128_and(wasm_i8x16_splat(4), wasm_v128_not(v_mask)));
                    wasm_v128_store(a + l, v_low2);
                }
                a += 32;
                m <<= 1;
            }
            q3 += 32;
        }

        // Extract scales: rebuild sixteen 6-bit scales from the packed 12-byte layout.
        memcpy(auxs, x[i].scales, 12);
        uint32_t tmp = auxs[2];
        auxs[2] = ((auxs[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);
        auxs[3] = ((auxs[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4);
        auxs[0] = (auxs[0] & kmask2) | (((tmp >> 0) & kmask1) << 4);
        auxs[1] = (auxs[1] & kmask2) | (((tmp >> 2) & kmask1) << 4);
        const int8_t * scales = (const int8_t *)auxs;

        // SIMD dot product with register accumulators; scales carry a -32 bias.
        v128_t v_acc0 = wasm_i32x4_splat(0);
        v128_t v_acc1 = wasm_i32x4_splat(0);
        a = aux8;
        for (int j = 0; j < QK_K/16; ++j) {
            const v128_t v_scale = wasm_i16x8_splat(scales[j] - 32);

            // Process 16 elements per iteration
            for (int k = 0; k < 2; ++k) {
                const v128_t v_q8 = wasm_i16x8_load8x8(q8);
                const v128_t v_a  = wasm_i16x8_load8x8(a);

                v128_t v_prod = wasm_i16x8_mul(v_q8, v_a);
                v_prod = wasm_i16x8_mul(v_prod, v_scale);

                v_acc0 = wasm_i32x4_add(v_acc0, wasm_i32x4_extend_low_i16x8(v_prod));
                v_acc1 = wasm_i32x4_add(v_acc1, wasm_i32x4_extend_high_i16x8(v_prod));

                q8 += 8;
                a  += 8;
            }
        }

        // Accumulate results, scaled by the combined fp16*fp32 block scale.
        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
        const v128_t v_d = wasm_f32x4_splat(d);
        v128_t v_sum = wasm_f32x4_add(
            wasm_f32x4_mul(wasm_f32x4_convert_i32x4(v_acc0), v_d),
            wasm_f32x4_mul(wasm_f32x4_convert_i32x4(v_acc1), v_d)
        );

        // Accumulate into sums vector
        wasm_v128_store(sums, wasm_f32x4_add(wasm_v128_load(sums), v_sum));
    }

    // Horizontal sum
    v128_t v_sum = wasm_f32x4_add(wasm_v128_load(sums), wasm_v128_load(sums + 4));
    sumf = wasm_f32x4_extract_lane(v_sum, 0) +
           wasm_f32x4_extract_lane(v_sum, 1) +
           wasm_f32x4_extract_lane(v_sum, 2) +
           wasm_f32x4_extract_lane(v_sum, 3);

    *s = sumf;

#else
    // scalar version
    // This function is written like this so the compiler can manage to vectorize most of it
    // Using -Ofast, GCC and clang manage to produce code that is within a factor of 2 or so from the
    // manually vectorized version above. Every other version I tried would run at least 4 times slower.
    // The ideal situation would be if we could just write the code once, and the compiler would
    // automatically produce the best possible set of machine instructions, instead of us having to manually
    // write vectorized versions for AVX, ARM_NEON, etc.

    int8_t  aux8[QK_K];
    int16_t aux16[8];
    float   sums [8];
    int32_t aux32[8];
    memset(sums, 0, 8*sizeof(float));

    uint32_t auxs[4];
    const int8_t * scales = (const int8_t*)auxs;

    float sumf = 0;
    for (int i = 0; i < nb; ++i) {
        const uint8_t * GGML_RESTRICT q3 = x[i].qs;
        const uint8_t * GGML_RESTRICT hm = x[i].hmask;
        const int8_t  * GGML_RESTRICT q8 = y[i].qs;
        memset(aux32, 0, 8*sizeof(int32_t));
        int8_t * GGML_RESTRICT a = aux8;
        uint8_t m = 1;
        // Unpack four 2-bit planes per 128 values; hmask bit clear means subtract 4.
        for (int j = 0; j < QK_K; j += 128) {
            for (int l = 0; l < 32; ++l) a[l] = q3[l] & 3;
            for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
            a += 32; m <<= 1;
            for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 2) & 3;
            for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
            a += 32; m <<= 1;
            for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 4) & 3;
            for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
            a += 32; m <<= 1;
            for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 6) & 3;
            for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
            a += 32; m <<= 1;
            q3 += 32;
        }
        a = aux8;

        // Rebuild the sixteen 6-bit scales from the packed 12-byte layout.
        memcpy(auxs, x[i].scales, 12);
        uint32_t tmp = auxs[2];
        auxs[2] = ((auxs[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);
        auxs[3] = ((auxs[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4);
        auxs[0] = (auxs[0] & kmask2) | (((tmp >> 0) & kmask1) << 4);
        auxs[1] = (auxs[1] & kmask2) | (((tmp >> 2) & kmask1) << 4);
        for (int j = 0; j < QK_K/16; ++j) {
            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
            for (int l = 0; l < 8; ++l) aux32[l] += (scales[j] - 32) * aux16[l];
            q8 += 8; a += 8;
            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
            for (int l = 0; l < 8; ++l) aux32[l] += (scales[j] - 32) * aux16[l];
            q8 += 8; a += 8;
        }
        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
        for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
    }
    for (int l = 0; l < 8; ++l) sumf += sums[l];
    *s = sumf;

#endif

}
// Dot product of n quantized values: x is a row of block_q4_K, y a row of block_q8_K.
// Result is written to *s. Only the single-row case is supported (nrc == 1); the stride
// arguments bs/bx/by are unused. Each super-block stores 4-bit quants (two per byte)
// plus packed 6-bit scales and mins; the result per block is
// d * sum(scale_j * <q4, q8>) - dmin * sum(min_j * bsums).
void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    assert(n % QK_K == 0);
    assert(nrc == 1);
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);

    const block_q4_K * GGML_RESTRICT x = vx;
    const block_q8_K * GGML_RESTRICT y = vy;

    const int nb = n / QK_K;  // number of super-blocks

    // Masks for unpacking the 6-bit scales/mins from the 12-byte packed layout.
    static const uint32_t kmask1 = 0x3f3f3f3f;
    static const uint32_t kmask2 = 0x0f0f0f0f;
    static const uint32_t kmask3 = 0x03030303;

    uint32_t utmp[4];  // utmp[0..1] hold scales, utmp[2..3] hold mins after repacking

#if defined __wasm_simd128__
    const uint8_t * scales = (const uint8_t*)&utmp[0];
    float sumf = 0;

    for (int i = 0; i < nb; ++i) {
        const float d    = y[i].d * GGML_FP16_TO_FP32(x[i].d);
        const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin); // Corrected sign

        const uint8_t * GGML_RESTRICT q4 = x[i].qs;
        const int8_t  * GGML_RESTRICT q8 = y[i].qs;

        // Process scales and mins: repack the 6-bit fields into utmp.
        memcpy(utmp, x[i].scales, 12);
        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
        const uint32_t uaux = utmp[1] & kmask1;
        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
        utmp[2] = uaux;
        utmp[0] &= kmask1;

        // Sum mins * q8sums (pairs of bsums share one min)
        int32_t sumi = 0;
        const int16_t * GGML_RESTRICT q8sums = y[i].bsums;
        const uint8_t * m = (const uint8_t *)&utmp[2];
        for (int j = 0; j < 16; j += 2) {
            sumi += (q8sums[j] + q8sums[j+1]) * m[j/2];
        }
        sumf -= dmin * sumi;

        int32_t sumi1 = 0;
        int32_t sumi2 = 0;

        for (int j = 0; j < QK_K/64; ++j) {
            // Load 64 4-bit weights (32 bytes)
            const v128_t q4x0 = wasm_v128_load(q4);
            const v128_t q4x1 = wasm_v128_load(q4 + 16);
            q4 += 32;

            // Split into low/high nibbles
            const v128_t q4l0 = wasm_v128_and(q4x0, wasm_i8x16_splat(0x0F));
            const v128_t q4h0 = wasm_u8x16_shr(q4x0, 4);
            const v128_t q4l1 = wasm_v128_and(q4x1, wasm_i8x16_splat(0x0F));
            const v128_t q4h1 = wasm_u8x16_shr(q4x1, 4);

            // Load 64 8-bit values (64 bytes)
            const v128_t q8x0 = wasm_v128_load(q8);
            const v128_t q8x1 = wasm_v128_load(q8 + 16);
            const v128_t q8x2 = wasm_v128_load(q8 + 32);
            const v128_t q8x3 = wasm_v128_load(q8 + 48);
            q8 += 64;

            // Low nibble products
            v128_t vacc1 = wasm_i32x4_dot_i16x8(
                wasm_i16x8_extend_low_i8x16(q4l0),
                wasm_i16x8_extend_low_i8x16(q8x0)
            );
            vacc1 = wasm_i32x4_add(vacc1, wasm_i32x4_dot_i16x8(
                wasm_i16x8_extend_high_i8x16(q4l0),
                wasm_i16x8_extend_high_i8x16(q8x0)
            ));
            vacc1 = wasm_i32x4_add(vacc1, wasm_i32x4_dot_i16x8(
                wasm_i16x8_extend_low_i8x16(q4l1),
                wasm_i16x8_extend_low_i8x16(q8x1)
            ));
            vacc1 = wasm_i32x4_add(vacc1, wasm_i32x4_dot_i16x8(
                wasm_i16x8_extend_high_i8x16(q4l1),
                wasm_i16x8_extend_high_i8x16(q8x1)
            ));

            // High nibble products
            v128_t vacc2 = wasm_i32x4_dot_i16x8(
                wasm_i16x8_extend_low_i8x16(q4h0),
                wasm_i16x8_extend_low_i8x16(q8x2)
            );
            vacc2 = wasm_i32x4_add(vacc2, wasm_i32x4_dot_i16x8(
                wasm_i16x8_extend_high_i8x16(q4h0),
                wasm_i16x8_extend_high_i8x16(q8x2)
            ));
            vacc2 = wasm_i32x4_add(vacc2, wasm_i32x4_dot_i16x8(
                wasm_i16x8_extend_low_i8x16(q4h1),
                wasm_i16x8_extend_low_i8x16(q8x3)
            ));
            vacc2 = wasm_i32x4_add(vacc2, wasm_i32x4_dot_i16x8(
                wasm_i16x8_extend_high_i8x16(q4h1),
                wasm_i16x8_extend_high_i8x16(q8x3)
            ));

            // Accumulate scaled results (horizontal lane sums, then apply 6-bit scales)
            int32_t vacc1_sum = wasm_i32x4_extract_lane(vacc1, 0) + wasm_i32x4_extract_lane(vacc1, 1) +
                                wasm_i32x4_extract_lane(vacc1, 2) + wasm_i32x4_extract_lane(vacc1, 3);
            sumi1 += vacc1_sum * scales[2*j];

            int32_t vacc2_sum = wasm_i32x4_extract_lane(vacc2, 0) + wasm_i32x4_extract_lane(vacc2, 1) +
                                wasm_i32x4_extract_lane(vacc2, 2) + wasm_i32x4_extract_lane(vacc2, 3);
            sumi2 += vacc2_sum * scales[2*j+1];
        }

        sumf += d * (sumi1 + sumi2);
    }

    *s = sumf;

#else
    // Scalar fallback: unpack all nibbles into aux8, then accumulate in 8-lane strips
    // so the compiler can auto-vectorize.
    const uint8_t * scales = (const uint8_t*)&utmp[0];
    const uint8_t * mins   = (const uint8_t*)&utmp[2];

    int8_t  aux8[QK_K];
    int16_t aux16[8];
    float   sums [8];
    int32_t aux32[8];
    memset(sums, 0, 8*sizeof(float));

    float sumf = 0;
    for (int i = 0; i < nb; ++i) {
        const uint8_t * GGML_RESTRICT q4 = x[i].qs;
        const int8_t  * GGML_RESTRICT q8 = y[i].qs;
        memset(aux32, 0, 8*sizeof(int32_t));
        int8_t * GGML_RESTRICT a = aux8;
        for (int j = 0; j < QK_K/64; ++j) {
            for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] & 0xF);
            a += 32;
            for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] >> 4);
            a += 32; q4 += 32;
        }
        // Repack the 6-bit scales/mins (same bit surgery as the SIMD path).
        memcpy(utmp, x[i].scales, 12);
        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
        const uint32_t uaux = utmp[1] & kmask1;
        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
        utmp[2] = uaux;
        utmp[0] &= kmask1;

        int sumi = 0;
        for (int j = 0; j < QK_K/16; ++j) sumi += y[i].bsums[j] * mins[j/2];
        a = aux8;
        int is = 0;
        for (int j = 0; j < QK_K/32; ++j) {
            int32_t scale = scales[is++];
            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
            q8 += 8; a += 8;
            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
            q8 += 8; a += 8;
            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
            q8 += 8; a += 8;
            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
            q8 += 8; a += 8;
        }
        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
        for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
        const float dmin = GGML_FP16_TO_FP32(x[i].dmin) * y[i].d;
        sumf -= dmin * sumi;
    }
    for (int l = 0; l < 8; ++l) sumf += sums[l];
    *s = sumf;
#endif
}
// Dot product of n quantized values: x is a row of block_q5_K, y a row of block_q8_K.
// Result is written to *s. Only the single-row case is supported (nrc == 1); the stride
// arguments bs/bx/by are unused. q5_K stores 4 low bits per value in qs plus 1 high bit
// per value in qh; scales/mins use the same packed 6-bit layout as q4_K.
void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    assert(n % QK_K == 0);
    assert(nrc == 1);
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);

    const block_q5_K * GGML_RESTRICT x = vx;
    const block_q8_K * GGML_RESTRICT y = vy;

    const int nb = n / QK_K;  // number of super-blocks

    // Masks for unpacking the 6-bit scales/mins from the 12-byte packed layout.
    static const uint32_t kmask1 = 0x3f3f3f3f;
    static const uint32_t kmask2 = 0x0f0f0f0f;
    static const uint32_t kmask3 = 0x03030303;

    uint32_t utmp[4];  // utmp[0..1] hold scales, utmp[2..3] hold mins after repacking

#if defined __wasm_simd128__
    //const uint8_t * scales = (const uint8_t*)&utmp[0];
    float sumf = 0;

    for (int i = 0; i < nb; ++i) {
        const float d    = y[i].d * GGML_FP16_TO_FP32(x[i].d);
        const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin); // Fixed sign

        const uint8_t * GGML_RESTRICT q5 = x[i].qs;
        const uint8_t * GGML_RESTRICT qh = x[i].qh;
        const int8_t  * GGML_RESTRICT q8 = y[i].qs;

        // Process scales and mins: repack the 6-bit fields into utmp.
        memcpy(utmp, x[i].scales, 12);
        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
        const uint32_t uaux = utmp[1] & kmask1;
        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
        utmp[2] = uaux;
        utmp[0] &= kmask1;

        // Sum mins * q8sums (pairs of bsums share one min)
        int32_t sumi_mins = 0;
        const int16_t * GGML_RESTRICT q8sums = y[i].bsums;
        const uint8_t * m = (const uint8_t *)&utmp[2];
        for (int j = 0; j < 16; j += 2) {
            sumi_mins += (q8sums[j] + q8sums[j+1]) * m[j/2];
        }
        sumf -= dmin * sumi_mins; // Correct subtraction

        v128_t qh0 = wasm_v128_load(qh);
        v128_t qh1 = wasm_v128_load(qh + 16);
        const uint8_t * sc = (const uint8_t *)utmp;

        int32_t sumi = 0;

        for (int j = 0; j < QK_K/64; ++j) {
            // Select this iteration's pair of high bits from qh and position them as bit 4.
            const int shift = j * 2;
            v128_t qh_shift0 = wasm_u8x16_shr(qh0, shift);
            v128_t qh_shift1 = wasm_u8x16_shr(qh1, shift);

            v128_t qh_low0  = wasm_i8x16_shl(wasm_v128_and(qh_shift0, wasm_i8x16_splat(0x01)), 4);
            v128_t qh_high0 = wasm_i8x16_shl(wasm_v128_and(qh_shift0, wasm_i8x16_splat(0x02)), 3);
            v128_t qh_low1  = wasm_i8x16_shl(wasm_v128_and(qh_shift1, wasm_i8x16_splat(0x01)), 4);
            v128_t qh_high1 = wasm_i8x16_shl(wasm_v128_and(qh_shift1, wasm_i8x16_splat(0x02)), 3);

            v128_t q5_0 = wasm_v128_load(q5);
            v128_t q5_1 = wasm_v128_load(q5 + 16);
            q5 += 32;

            // Combine the 4 low bits with the injected high bit -> 5-bit quants.
            v128_t q5l_0 = wasm_v128_or(wasm_v128_and(q5_0, wasm_i8x16_splat(0x0F)), qh_low0);
            v128_t q5h_0 = wasm_v128_or(wasm_u8x16_shr(q5_0, 4), qh_high0);
            v128_t q5l_1 = wasm_v128_or(wasm_v128_and(q5_1, wasm_i8x16_splat(0x0F)), qh_low1);
            v128_t q5h_1 = wasm_v128_or(wasm_u8x16_shr(q5_1, 4), qh_high1);

            v128_t q8_0 = wasm_v128_load(q8);
            v128_t q8_1 = wasm_v128_load(q8 + 16);
            v128_t q8_2 = wasm_v128_load(q8 + 32);
            v128_t q8_3 = wasm_v128_load(q8 + 48);
            q8 += 64;

            // Process low quants
            v128_t pl0 = wasm_i32x4_dot_i16x8(
                wasm_i16x8_extend_low_i8x16(q5l_0),
                wasm_i16x8_extend_low_i8x16(q8_0)
            );
            pl0 = wasm_i32x4_add(pl0, wasm_i32x4_dot_i16x8(
                wasm_i16x8_extend_high_i8x16(q5l_0),
                wasm_i16x8_extend_high_i8x16(q8_0)
            ));
            v128_t pl1 = wasm_i32x4_dot_i16x8(
                wasm_i16x8_extend_low_i8x16(q5l_1),
                wasm_i16x8_extend_low_i8x16(q8_1)
            );
            pl1 = wasm_i32x4_add(pl1, wasm_i32x4_dot_i16x8(
                wasm_i16x8_extend_high_i8x16(q5l_1),
                wasm_i16x8_extend_high_i8x16(q8_1)
            ));
            v128_t sum_low = wasm_i32x4_add(pl0, pl1);

            // Process high quants
            v128_t ph0 = wasm_i32x4_dot_i16x8(
                wasm_i16x8_extend_low_i8x16(q5h_0),
                wasm_i16x8_extend_low_i8x16(q8_2)
            );
            ph0 = wasm_i32x4_add(ph0, wasm_i32x4_dot_i16x8(
                wasm_i16x8_extend_high_i8x16(q5h_0),
                wasm_i16x8_extend_high_i8x16(q8_2)
            ));
            v128_t ph1 = wasm_i32x4_dot_i16x8(
                wasm_i16x8_extend_low_i8x16(q5h_1),
                wasm_i16x8_extend_low_i8x16(q8_3)
            );
            ph1 = wasm_i32x4_add(ph1, wasm_i32x4_dot_i16x8(
                wasm_i16x8_extend_high_i8x16(q5h_1),
                wasm_i16x8_extend_high_i8x16(q8_3)
            ));
            v128_t sum_high = wasm_i32x4_add(ph0, ph1);

            // Accumulate with scale factors
            int32_t sl = wasm_i32x4_extract_lane(sum_low, 0) + wasm_i32x4_extract_lane(sum_low, 1) +
                         wasm_i32x4_extract_lane(sum_low, 2) + wasm_i32x4_extract_lane(sum_low, 3);
            int32_t sh = wasm_i32x4_extract_lane(sum_high, 0) + wasm_i32x4_extract_lane(sum_high, 1) +
                         wasm_i32x4_extract_lane(sum_high, 2) + wasm_i32x4_extract_lane(sum_high, 3);

            sumi += sl * sc[2*j] + sh * sc[2*j+1];
        }

        sumf += d * sumi;
    }

    *s = sumf;

#else
    // Scalar fallback: add 16 wherever the qh bit is set to restore the 5th bit,
    // then accumulate scaled products in 8-lane strips.
    const uint8_t * scales = (const uint8_t*)&utmp[0];
    const uint8_t * mins   = (const uint8_t*)&utmp[2];

    int8_t  aux8[QK_K];
    int16_t aux16[8];
    float   sums [8];
    int32_t aux32[8];
    memset(sums, 0, 8*sizeof(float));

    float sumf = 0;
    for (int i = 0; i < nb; ++i) {
        const uint8_t * GGML_RESTRICT q4 = x[i].qs;
        const uint8_t * GGML_RESTRICT hm = x[i].qh;
        const int8_t  * GGML_RESTRICT q8 = y[i].qs;
        memset(aux32, 0, 8*sizeof(int32_t));
        int8_t * GGML_RESTRICT a = aux8;
        uint8_t m = 1;
        for (int j = 0; j < QK_K/64; ++j) {
            for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] & 0xF);
            for (int l = 0; l < 32; ++l) a[l] += (hm[l] & m ? 16 : 0);
            a += 32; m <<= 1;
            for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] >> 4);
            for (int l = 0; l < 32; ++l) a[l] += (hm[l] & m ? 16 : 0);
            a += 32; m <<= 1;
            q4 += 32;
        }
        // Repack the 6-bit scales/mins (same bit surgery as the SIMD path).
        memcpy(utmp, x[i].scales, 12);
        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
        const uint32_t uaux = utmp[1] & kmask1;
        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
        utmp[2] = uaux;
        utmp[0] &= kmask1;

        int sumi = 0;
        for (int j = 0; j < QK_K/16; ++j) sumi += y[i].bsums[j] * mins[j/2];
        a = aux8;
        int is = 0;
        for (int j = 0; j < QK_K/32; ++j) {
            int32_t scale = scales[is++];
            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
            q8 += 8; a += 8;
            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
            q8 += 8; a += 8;
            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
            q8 += 8; a += 8;
            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
            q8 += 8; a += 8;
        }
        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
        for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
        const float dmin = GGML_FP16_TO_FP32(x[i].dmin) * y[i].d;
        sumf -= dmin * sumi;
    }
    for (int l = 0; l < 8; ++l) sumf += sums[l];
    *s = sumf;
#endif
}
// Dot product of n quantized values: x is a row of block_q6_K, y a row of block_q8_K.
// Result is written to *s. Only the single-row case is supported (nrc == 1); the stride
// arguments bs/bx/by are unused. q6_K stores 4 low bits in ql and 2 high bits in qh per
// value; unpacked values are recentred by -32 and scaled by signed 8-bit block scales.
void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    assert(n % QK_K == 0);
    assert(nrc == 1);
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);

    const block_q6_K * GGML_RESTRICT x = vx;
    const block_q8_K * GGML_RESTRICT y = vy;

    const int nb = n / QK_K;  // number of super-blocks

#if defined __wasm_simd128__
    // 16-byte alignment so wasm_v128_load/store on these buffers is aligned.
    int8_t  aux8[QK_K] __attribute__((aligned(16)));
    int32_t aux32[8]   __attribute__((aligned(16))) = {0};
    float   sums[8]    __attribute__((aligned(16))) = {0};

    for (int i = 0; i < nb; ++i) {
        // Unpack 6-bit quantized data into aux8 (unchanged)
        const uint8_t * GGML_RESTRICT q4 = x[i].ql;
        const uint8_t * GGML_RESTRICT qh = x[i].qh;
        int8_t * a = aux8;
        for (int j = 0; j < QK_K; j += 128) {
            for (int l = 0; l < 32; ++l) {
                a[l +  0] = (int8_t)((q4[l +  0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
                a[l + 32] = (int8_t)((q4[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
                a[l + 64] = (int8_t)((q4[l +  0] >>  4) | (((qh[l] >> 4) & 3) << 4)) - 32;
                a[l + 96] = (int8_t)((q4[l + 32] >>  4) | (((qh[l] >> 6) & 3) << 4)) - 32;
            }
            a  += 128;
            q4 += 64;
            qh += 32;
        }

        const int8_t * GGML_RESTRICT a_ptr = aux8;
        const int8_t * GGML_RESTRICT q8 = y[i].qs;
        v128_t acc0 = wasm_i32x4_splat(0);
        v128_t acc1 = wasm_i32x4_splat(0);

        for (int j = 0; j < QK_K/16; ++j) {
            const int scale = x[i].scales[j];
            const v128_t vscale = wasm_i32x4_splat(scale);

            // Load 16 elements from a and q8
            const v128_t a_vec  = wasm_v128_load(a_ptr);
            const v128_t q8_vec = wasm_v128_load(q8);

            // Process low 8 elements
            v128_t a_low      = wasm_i16x8_extend_low_i8x16(a_vec);
            v128_t q8_low     = wasm_i16x8_extend_low_i8x16(q8_vec);
            v128_t prod_low   = wasm_i16x8_mul(a_low, q8_low);
            v128_t prod_lo_lo = wasm_i32x4_extend_low_i16x8(prod_low);
            v128_t prod_lo_hi = wasm_i32x4_extend_high_i16x8(prod_low);

            // Process high 8 elements
            v128_t a_high     = wasm_i16x8_extend_high_i8x16(a_vec);
            v128_t q8_high    = wasm_i16x8_extend_high_i8x16(q8_vec);
            v128_t prod_high  = wasm_i16x8_mul(a_high, q8_high);
            v128_t prod_hi_lo = wasm_i32x4_extend_low_i16x8(prod_high);
            v128_t prod_hi_hi = wasm_i32x4_extend_high_i16x8(prod_high);

            // Scale and accumulate
            prod_lo_lo = wasm_i32x4_mul(prod_lo_lo, vscale);
            prod_lo_hi = wasm_i32x4_mul(prod_lo_hi, vscale);
            prod_hi_lo = wasm_i32x4_mul(prod_hi_lo, vscale);
            prod_hi_hi = wasm_i32x4_mul(prod_hi_hi, vscale);

            acc0 = wasm_i32x4_add(acc0, wasm_i32x4_add(prod_lo_lo, prod_hi_lo));
            acc1 = wasm_i32x4_add(acc1, wasm_i32x4_add(prod_lo_hi, prod_hi_hi));

            a_ptr += 16;
            q8    += 16;
        }

        // Store accumulated results
        wasm_v128_store(&aux32[0], acc0);
        wasm_v128_store(&aux32[4], acc1);

        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
        for (int l = 0; l < 8; ++l) {
            sums[l] += d * aux32[l];
        }
    }

    // Sum final results
    float sumf = 0;
    for (int l = 0; l < 8; ++l) {
        sumf += sums[l];
    }
    *s = sumf;

#else
    // Scalar fallback: identical unpacking, then 8-lane accumulation strips the
    // compiler can auto-vectorize.
    int8_t  aux8[QK_K];
    int16_t aux16[8];
    float   sums [8];
    int32_t aux32[8];
    memset(sums, 0, 8*sizeof(float));

    float sumf = 0;
    for (int i = 0; i < nb; ++i) {
        const uint8_t * GGML_RESTRICT q4 = x[i].ql;
        const uint8_t * GGML_RESTRICT qh = x[i].qh;
        const int8_t  * GGML_RESTRICT q8 = y[i].qs;
        memset(aux32, 0, 8*sizeof(int32_t));
        int8_t * GGML_RESTRICT a = aux8;
        for (int j = 0; j < QK_K; j += 128) {
            for (int l = 0; l < 32; ++l) {
                a[l +  0] = (int8_t)((q4[l +  0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
                a[l + 32] = (int8_t)((q4[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
                a[l + 64] = (int8_t)((q4[l +  0] >>  4) | (((qh[l] >> 4) & 3) << 4)) - 32;
                a[l + 96] = (int8_t)((q4[l + 32] >>  4) | (((qh[l] >> 6) & 3) << 4)) - 32;
            }
            a  += 128;
            q4 += 64;
            qh += 32;
        }
        a = aux8;
        int is = 0;
        for (int j = 0; j < QK_K/16; ++j) {
            int scale = x[i].scales[is++];
            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
            q8 += 8; a += 8;
            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
            q8 += 8; a += 8;
        }
        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
        for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
    }
    for (int l = 0; l < 8; ++l) sumf += sums[l];
    *s = sumf;
#endif
}