@fugood/llama.node 1.0.0-beta.4 → 1.0.0-beta.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (111) hide show
  1. package/CMakeLists.txt +7 -4
  2. package/lib/binding.ts +1 -1
  3. package/package.json +14 -14
  4. package/scripts/llama.cpp.patch +27 -26
  5. package/src/LlamaCompletionWorker.cpp +21 -4
  6. package/src/LlamaCompletionWorker.h +2 -0
  7. package/src/LlamaContext.cpp +3 -12
  8. package/src/common.hpp +6 -5
  9. package/src/llama.cpp/CMakeLists.txt +15 -4
  10. package/src/llama.cpp/common/CMakeLists.txt +15 -24
  11. package/src/llama.cpp/common/arg.cpp +172 -110
  12. package/src/llama.cpp/common/chat-parser.cpp +385 -0
  13. package/src/llama.cpp/common/chat-parser.h +120 -0
  14. package/src/llama.cpp/common/chat.cpp +726 -596
  15. package/src/llama.cpp/common/chat.h +74 -8
  16. package/src/llama.cpp/common/common.cpp +56 -38
  17. package/src/llama.cpp/common/common.h +9 -3
  18. package/src/llama.cpp/common/json-partial.cpp +256 -0
  19. package/src/llama.cpp/common/json-partial.h +38 -0
  20. package/src/llama.cpp/common/json-schema-to-grammar.cpp +2 -1
  21. package/src/llama.cpp/common/json-schema-to-grammar.h +4 -4
  22. package/src/llama.cpp/common/sampling.cpp +7 -8
  23. package/src/llama.cpp/common/speculative.cpp +6 -4
  24. package/src/llama.cpp/ggml/CMakeLists.txt +48 -3
  25. package/src/llama.cpp/ggml/include/ggml.h +22 -3
  26. package/src/llama.cpp/ggml/src/CMakeLists.txt +81 -22
  27. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +131 -49
  28. package/src/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +1 -1
  29. package/src/llama.cpp/ggml/src/ggml-cpu/amx/mmq.cpp +1 -1
  30. package/src/llama.cpp/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
  31. package/src/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +4113 -0
  32. package/src/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +2162 -0
  33. package/src/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +2638 -0
  34. package/src/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
  35. package/src/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +2731 -0
  36. package/src/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +2068 -0
  37. package/src/llama.cpp/ggml/src/ggml-cpu/arch/riscv/repack.cpp +396 -0
  38. package/src/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +1299 -0
  39. package/src/llama.cpp/ggml/src/ggml-cpu/arch/wasm/quants.c +1480 -0
  40. package/src/llama.cpp/ggml/src/ggml-cpu/arch/x86/quants.c +4310 -0
  41. package/src/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-aarch64.cpp → arch/x86/repack.cpp} +59 -3206
  42. package/src/llama.cpp/ggml/src/ggml-cpu/arch-fallback.h +184 -0
  43. package/src/llama.cpp/ggml/src/ggml-cpu/common.h +1 -1
  44. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +12 -13
  45. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +64 -88
  46. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +8 -8
  47. package/src/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-hbm.cpp → hbm.cpp} +1 -1
  48. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +1 -1
  49. package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +56 -7
  50. package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.h +5 -0
  51. package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +282 -100
  52. package/src/llama.cpp/ggml/src/ggml-cpu/ops.h +1 -0
  53. package/src/llama.cpp/ggml/src/ggml-cpu/quants.c +1157 -0
  54. package/src/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-quants.h → quants.h} +26 -0
  55. package/src/llama.cpp/ggml/src/ggml-cpu/repack.cpp +1570 -0
  56. package/src/llama.cpp/ggml/src/ggml-cpu/repack.h +98 -0
  57. package/src/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +119 -5
  58. package/src/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-traits.cpp → traits.cpp} +1 -1
  59. package/src/llama.cpp/ggml/src/ggml-cpu/vec.cpp +85 -16
  60. package/src/llama.cpp/ggml/src/ggml-cpu/vec.h +204 -49
  61. package/src/llama.cpp/include/llama.h +145 -40
  62. package/src/llama.cpp/src/CMakeLists.txt +5 -1
  63. package/src/llama.cpp/src/llama-arch.cpp +99 -3
  64. package/src/llama.cpp/src/llama-arch.h +10 -1
  65. package/src/llama.cpp/src/llama-batch.cpp +728 -272
  66. package/src/llama.cpp/src/llama-batch.h +112 -54
  67. package/src/llama.cpp/src/llama-chat.cpp +19 -2
  68. package/src/llama.cpp/src/llama-chat.h +1 -0
  69. package/src/llama.cpp/src/llama-context.cpp +525 -339
  70. package/src/llama.cpp/src/llama-context.h +38 -17
  71. package/src/llama.cpp/src/llama-cparams.cpp +4 -0
  72. package/src/llama.cpp/src/llama-cparams.h +2 -0
  73. package/src/llama.cpp/src/llama-grammar.cpp +12 -2
  74. package/src/llama.cpp/src/llama-graph.cpp +413 -353
  75. package/src/llama.cpp/src/llama-graph.h +112 -56
  76. package/src/llama.cpp/src/llama-hparams.cpp +10 -2
  77. package/src/llama.cpp/src/llama-hparams.h +13 -2
  78. package/src/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +279 -0
  79. package/src/llama.cpp/src/llama-kv-cache-unified-iswa.h +128 -0
  80. package/src/llama.cpp/src/llama-kv-cache-unified.cpp +1815 -0
  81. package/src/llama.cpp/src/llama-kv-cache-unified.h +303 -0
  82. package/src/llama.cpp/src/llama-kv-cells.h +415 -0
  83. package/src/llama.cpp/src/llama-memory-hybrid.cpp +246 -0
  84. package/src/llama.cpp/src/llama-memory-hybrid.h +138 -0
  85. package/src/llama.cpp/src/llama-memory-recurrent.cpp +1112 -0
  86. package/src/llama.cpp/src/llama-memory-recurrent.h +183 -0
  87. package/src/llama.cpp/src/llama-memory.cpp +41 -0
  88. package/src/llama.cpp/src/llama-memory.h +86 -5
  89. package/src/llama.cpp/src/llama-mmap.cpp +1 -1
  90. package/src/llama.cpp/src/llama-model-loader.cpp +42 -17
  91. package/src/llama.cpp/src/llama-model-saver.cpp +1 -0
  92. package/src/llama.cpp/src/llama-model.cpp +1137 -528
  93. package/src/llama.cpp/src/llama-model.h +4 -0
  94. package/src/llama.cpp/src/llama-quant.cpp +2 -1
  95. package/src/llama.cpp/src/llama-sampling.cpp +2 -2
  96. package/src/llama.cpp/src/llama-vocab.cpp +69 -32
  97. package/src/llama.cpp/src/llama-vocab.h +1 -0
  98. package/src/llama.cpp/src/llama.cpp +11 -7
  99. package/src/llama.cpp/src/unicode.cpp +5 -0
  100. package/src/tts_utils.h +1 -1
  101. package/src/llama.cpp/common/json.hpp +0 -24766
  102. package/src/llama.cpp/common/minja/chat-template.hpp +0 -541
  103. package/src/llama.cpp/common/minja/minja.hpp +0 -2974
  104. package/src/llama.cpp/common/stb_image.h +0 -7988
  105. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +0 -8
  106. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +0 -13326
  107. package/src/llama.cpp/src/llama-kv-cache.cpp +0 -2827
  108. package/src/llama.cpp/src/llama-kv-cache.h +0 -515
  109. /package/src/llama.cpp/ggml/src/ggml-cpu/{cpu-feats-x86.cpp → arch/x86/cpu-feats.cpp} +0 -0
  110. /package/src/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-hbm.h → hbm.h} +0 -0
  111. /package/src/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-traits.h → traits.h} +0 -0
@@ -0,0 +1,2068 @@
1
+ #define GGML_COMMON_IMPL_C
2
+ #include "ggml-common.h"
3
+ #include "ggml-quants.h"
4
+ #include "ggml-impl.h"
5
+ #include "ggml-cpu.h"
6
+
7
+ #include "../../quants.h"
8
+ #include "../../ggml-cpu-impl.h"
9
+
10
+ #include <math.h>
11
+ #include <string.h>
12
+ #include <assert.h>
13
+ #include <float.h>
14
+ #include <stdlib.h> // for qsort
15
+ #include <stdio.h> // for GGML_ASSERT
16
+
17
+ #define GROUP_MAX_EPS 1e-15f
18
+ #define GROUP_MAX_EPS_IQ3_XXS 1e-8f
19
+ #define GROUP_MAX_EPS_IQ2_S 1e-8f
20
+ #define GROUP_MAX_EPS_IQ1_M 1e-7f
21
+ #define GROUP_MAX_EPS_IQ1_S 1e-12f
22
+
23
+ #define UNUSED GGML_UNUSED
24
+
25
+ void quantize_row_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
26
+ assert(QK8_0 == 32);
27
+ assert(k % QK8_0 == 0);
28
+ const int nb = k / QK8_0;
29
+
30
+ block_q8_0 * GGML_RESTRICT y = vy;
31
+
32
+ #if defined(__riscv_v)
33
+
34
+ size_t vl = QK8_0;
35
+
36
+ for (int i = 0; i < nb; i++) {
37
+ // load elements
38
+ vfloat32m8_t v_x = __riscv_vle32_v_f32m8(x+i*QK8_0, vl);
39
+
40
+ vfloat32m8_t vfabs = __riscv_vfabs_v_f32m8(v_x, vl);
41
+ vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0f, vl);
42
+ vfloat32m1_t vmax = __riscv_vfredmax_vs_f32m8_f32m1(vfabs, tmp, vl);
43
+ float amax = __riscv_vfmv_f_s_f32m1_f32(vmax);
44
+
45
+ const float d = amax / ((1 << 7) - 1);
46
+ const float id = d ? 1.0f/d : 0.0f;
47
+
48
+ y[i].d = GGML_FP32_TO_FP16(d);
49
+
50
+ vfloat32m8_t x0 = __riscv_vfmul_vf_f32m8(v_x, id, vl);
51
+
52
+ // convert to integer
53
+ vint16m4_t vi = __riscv_vfncvt_x_f_w_i16m4(x0, vl);
54
+ vint8m2_t vs = __riscv_vncvt_x_x_w_i8m2(vi, vl);
55
+
56
+ // store result
57
+ __riscv_vse8_v_i8m2(y[i].qs , vs, vl);
58
+ }
59
+ #else
60
+ GGML_UNUSED(nb);
61
+ // scalar
62
+ quantize_row_q8_0_ref(x, y, k);
63
+ #endif
64
+ }
65
+
66
+ void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
67
+ assert(k % QK8_1 == 0);
68
+ const int nb = k / QK8_1;
69
+
70
+ block_q8_1 * GGML_RESTRICT y = vy;
71
+
72
+ #if defined(__riscv_v)
73
+
74
+ size_t vl = QK8_1;
75
+
76
+ for (int i = 0; i < nb; i++) {
77
+ // load elements
78
+ vfloat32m8_t v_x = __riscv_vle32_v_f32m8(x+i*QK8_1, vl);
79
+
80
+ vfloat32m8_t vfabs = __riscv_vfabs_v_f32m8(v_x, vl);
81
+ vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0, vl);
82
+ vfloat32m1_t vmax = __riscv_vfredmax_vs_f32m8_f32m1(vfabs, tmp, vl);
83
+ float amax = __riscv_vfmv_f_s_f32m1_f32(vmax);
84
+
85
+ const float d = amax / ((1 << 7) - 1);
86
+ const float id = d ? 1.0f/d : 0.0f;
87
+
88
+ y[i].d = GGML_FP32_TO_FP16(d);
89
+
90
+ vfloat32m8_t x0 = __riscv_vfmul_vf_f32m8(v_x, id, vl);
91
+
92
+ // convert to integer
93
+ vint16m4_t vi = __riscv_vfncvt_x_f_w_i16m4(x0, vl);
94
+ vint8m2_t vs = __riscv_vncvt_x_x_w_i8m2(vi, vl);
95
+
96
+ // store result
97
+ __riscv_vse8_v_i8m2(y[i].qs , vs, vl);
98
+
99
+ // compute sum for y[i].s
100
+ vint16m1_t tmp2 = __riscv_vmv_v_x_i16m1(0, vl);
101
+ vint16m1_t vwrs = __riscv_vwredsum_vs_i8m2_i16m1(vs, tmp2, vl);
102
+
103
+ // set y[i].s
104
+ int sum = __riscv_vmv_x_s_i16m1_i16(vwrs);
105
+ y[i].s = GGML_FP32_TO_FP16(sum*d);
106
+ }
107
+
108
+ #else
109
+ GGML_UNUSED(nb);
110
+ // scalar
111
+ quantize_row_q8_1_ref(x, y, k);
112
+ #endif
113
+ }
114
+
115
+ //===================================== Dot products =================================
116
+
117
+ void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
118
+ const int qk = QK8_0;
119
+ const int nb = n / qk;
120
+
121
+ assert(n % qk == 0);
122
+ assert(nrc == 1);
123
+ UNUSED(nrc);
124
+ UNUSED(bx);
125
+ UNUSED(by);
126
+ UNUSED(bs);
127
+
128
+ const block_q4_0 * GGML_RESTRICT x = vx;
129
+ const block_q8_0 * GGML_RESTRICT y = vy;
130
+
131
+ int ib = 0;
132
+ float sumf = 0;
133
+
134
+ #if defined(__riscv_v)
135
+ size_t vl = qk / 2;
136
+
137
+ for (; ib < nb; ++ib) {
138
+ // load elements
139
+ vuint8m1_t tx = __riscv_vle8_v_u8m1(x[ib].qs, vl);
140
+
141
+ vint8m1_t y0 = __riscv_vle8_v_i8m1(y[ib].qs, vl);
142
+ vint8m1_t y1 = __riscv_vle8_v_i8m1(y[ib].qs+16, vl);
143
+
144
+ // mask and store lower part of x, and then upper part
145
+ vuint8m1_t x_a = __riscv_vand_vx_u8m1(tx, 0x0F, vl);
146
+ vuint8m1_t x_l = __riscv_vsrl_vx_u8m1(tx, 0x04, vl);
147
+
148
+ vint8m1_t x_ai = __riscv_vreinterpret_v_u8m1_i8m1(x_a);
149
+ vint8m1_t x_li = __riscv_vreinterpret_v_u8m1_i8m1(x_l);
150
+
151
+ // subtract offset
152
+ vint8m1_t v0 = __riscv_vsub_vx_i8m1(x_ai, 8, vl);
153
+ vint8m1_t v1 = __riscv_vsub_vx_i8m1(x_li, 8, vl);
154
+
155
+ vint16m2_t vec_mul1 = __riscv_vwmul_vv_i16m2(v0, y0, vl);
156
+ vint16m2_t vec_mul2 = __riscv_vwmacc_vv_i16m2(vec_mul1, v1, y1, vl);
157
+
158
+ vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
159
+ vint32m1_t vs2 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul2, vec_zero, vl);
160
+
161
+ int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);
162
+
163
+ sumf += sumi*GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d);
164
+ }
165
+
166
+ #endif
167
+ for (; ib < nb; ++ib) {
168
+ int sumi0 = 0;
169
+ int sumi1 = 0;
170
+
171
+ for (int j = 0; j < qk/2; ++j) {
172
+ const int v0 = (x[ib].qs[j] & 0x0F) - 8;
173
+ const int v1 = (x[ib].qs[j] >> 4) - 8;
174
+
175
+ sumi0 += (v0 * y[ib].qs[j]);
176
+ sumi1 += (v1 * y[ib].qs[j + qk/2]);
177
+ }
178
+
179
+ int sumi = sumi0 + sumi1;
180
+ sumf += sumi*GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d);
181
+ }
182
+
183
+ *s = sumf;
184
+ }
185
+
186
+ void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
187
+ const int qk = QK8_1;
188
+ const int nb = n / qk;
189
+
190
+ assert(n % qk == 0);
191
+ assert(nrc == 1);
192
+ UNUSED(nrc);
193
+ UNUSED(bx);
194
+ UNUSED(by);
195
+ UNUSED(bs);
196
+
197
+ const block_q4_1 * GGML_RESTRICT x = vx;
198
+ const block_q8_1 * GGML_RESTRICT y = vy;
199
+
200
+ int ib = 0;
201
+ float sumf = 0;
202
+
203
+ #if defined(__riscv_v)
204
+ size_t vl = qk / 2;
205
+
206
+ for (; ib < nb; ++ib) {
207
+ // load elements
208
+ vuint8m1_t tx = __riscv_vle8_v_u8m1(x[ib].qs, vl);
209
+
210
+ vint8m1_t y0 = __riscv_vle8_v_i8m1(y[ib].qs, vl);
211
+ vint8m1_t y1 = __riscv_vle8_v_i8m1(y[ib].qs+16, vl);
212
+
213
+ // mask and store lower part of x, and then upper part
214
+ vuint8m1_t x_a = __riscv_vand_vx_u8m1(tx, 0x0F, vl);
215
+ vuint8m1_t x_l = __riscv_vsrl_vx_u8m1(tx, 0x04, vl);
216
+
217
+ vint8m1_t v0 = __riscv_vreinterpret_v_u8m1_i8m1(x_a);
218
+ vint8m1_t v1 = __riscv_vreinterpret_v_u8m1_i8m1(x_l);
219
+
220
+ vint16m2_t vec_mul1 = __riscv_vwmul_vv_i16m2(v0, y0, vl);
221
+ vint16m2_t vec_mul2 = __riscv_vwmacc_vv_i16m2(vec_mul1, v1, y1, vl);
222
+
223
+ vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
224
+ vint32m1_t vs2 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul2, vec_zero, vl);
225
+
226
+ int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);
227
+
228
+ sumf += (GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d))*sumi + GGML_FP16_TO_FP32(x[ib].m)*GGML_FP16_TO_FP32(y[ib].s);
229
+ }
230
+
231
+ #endif
232
+ for (; ib < nb; ++ib) {
233
+ int sumi0 = 0;
234
+ int sumi1 = 0;
235
+
236
+ for (int j = 0; j < qk/2; ++j) {
237
+ const int v0 = (x[ib].qs[j] & 0x0F);
238
+ const int v1 = (x[ib].qs[j] >> 4);
239
+
240
+ sumi0 += (v0 * y[ib].qs[j]);
241
+ sumi1 += (v1 * y[ib].qs[j + qk/2]);
242
+ }
243
+
244
+ int sumi = sumi0 + sumi1;
245
+ sumf += (GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d))*sumi + GGML_FP16_TO_FP32(x[ib].m)*GGML_FP16_TO_FP32(y[ib].s);
246
+ }
247
+
248
+ *s = sumf;
249
+ }
250
+
251
+ void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
252
+ const int qk = QK8_0;
253
+ const int nb = n / qk;
254
+
255
+ int ib = 0;
256
+ float sumf = 0;
257
+
258
+ assert(n % qk == 0);
259
+ assert(qk == QK5_0);
260
+ assert(nrc == 1);
261
+ UNUSED(nrc);
262
+ UNUSED(bx);
263
+ UNUSED(by);
264
+ UNUSED(bs);
265
+
266
+ const block_q5_0 * GGML_RESTRICT x = vx;
267
+ const block_q8_0 * GGML_RESTRICT y = vy;
268
+
269
+ #if defined(__riscv_v)
270
+ size_t vl;
271
+ size_t vlenb = __riscv_vlenb();
272
+
273
+ for (; ib < nb; ++ib) {
274
+ vl = qk / 2;
275
+ vuint8m1_t v0 = __riscv_vle8_v_u8m1(x[ib].qs, vl);
276
+ vint8m1_t v0l = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(v0, 0x0F, vl));
277
+ vint8m1_t v0h = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(v0, 4, vl));
278
+ vint8m2_t v0c;
279
+ if (vlenb == 16) {
280
+ v0c = __riscv_vcreate_v_i8m1_i8m2(v0l, v0h);
281
+ } else {
282
+ v0l = __riscv_vslideup_vx_i8m1(v0l, v0h, 16, 32);
283
+ v0c = __riscv_vlmul_ext_v_i8m1_i8m2(v0l);
284
+ }
285
+
286
+ vl = qk;
287
+ vbool4_t qh = __riscv_vlm_v_b4(x[ib].qh, vl);
288
+ qh = __riscv_vmnand_mm_b4(qh, qh, vl);
289
+ vint8m2_t v0f = __riscv_vsub_vx_i8m2_mu(qh, v0c, v0c, 0x10, vl);
290
+ vint8m2_t v1 = __riscv_vle8_v_i8m2(y[ib].qs, vl);
291
+ vint16m4_t mul = __riscv_vwmul_vv_i16m4(v0f, v1, vl);
292
+ vint32m1_t zero = __riscv_vmv_v_x_i32m1(0, vl);
293
+ vint32m1_t sum = __riscv_vwredsum_vs_i16m4_i32m1(mul, zero, vl);
294
+ int32_t sumi = __riscv_vmv_x_s_i32m1_i32(sum);
295
+
296
+ sumf += (GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d)) * sumi;
297
+ }
298
+
299
+ #endif
300
+ for (; ib < nb; ++ib) {
301
+ uint32_t qh;
302
+ memcpy(&qh, x[ib].qh, sizeof(qh));
303
+
304
+ int sumi0 = 0;
305
+ int sumi1 = 0;
306
+
307
+ for (int j = 0; j < qk/2; ++j) {
308
+ const uint8_t xh_0 = ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4;
309
+ const uint8_t xh_1 = ((qh & (1u << (j + 16))) >> (j + 12));
310
+
311
+ const int32_t x0 = (int8_t)(((x[ib].qs[j] & 0x0F) | xh_0) - 16);
312
+ const int32_t x1 = (int8_t)(((x[ib].qs[j] >> 4) | xh_1) - 16);
313
+
314
+ sumi0 += (x0 * y[ib].qs[j]);
315
+ sumi1 += (x1 * y[ib].qs[j + qk/2]);
316
+ }
317
+
318
+ int sumi = sumi0 + sumi1;
319
+ sumf += (GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d)) * sumi;
320
+ }
321
+
322
+ *s = sumf;
323
+ }
324
+
325
+ void ggml_vec_dot_q5_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
326
+ const int qk = QK8_1;
327
+ const int nb = n / qk;
328
+
329
+ int ib = 0;
330
+ float sumf = 0;
331
+
332
+ assert(n % qk == 0);
333
+ assert(qk == QK5_1);
334
+ assert(nrc == 1);
335
+ UNUSED(nrc);
336
+ UNUSED(bx);
337
+ UNUSED(by);
338
+ UNUSED(bs);
339
+
340
+ const block_q5_1 * GGML_RESTRICT x = vx;
341
+ const block_q8_1 * GGML_RESTRICT y = vy;
342
+
343
+ #if defined(__riscv_v)
344
+ size_t vl;
345
+ size_t vlenb = __riscv_vlenb();
346
+
347
+ for (; ib < nb; ++ib) {
348
+ vl = qk / 2;
349
+ vuint8m1_t v0 = __riscv_vle8_v_u8m1(x[ib].qs, vl);
350
+ vint8m1_t v0l = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(v0, 0x0F, vl));
351
+ vint8m1_t v0h = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(v0, 4, vl));
352
+ vint8m2_t v0c;
353
+ if (vlenb == 16) {
354
+ v0c = __riscv_vcreate_v_i8m1_i8m2(v0l, v0h);
355
+ } else {
356
+ v0l = __riscv_vslideup_vx_i8m1(v0l, v0h, 16, 32);
357
+ v0c = __riscv_vlmul_ext_v_i8m1_i8m2(v0l);
358
+ }
359
+
360
+ vl = qk;
361
+ vbool4_t qh = __riscv_vlm_v_b4(x[ib].qh, vl);
362
+ vint8m2_t v0f = __riscv_vor_vx_i8m2_mu(qh, v0c, v0c, 0x10, vl);
363
+ vint8m2_t v1 = __riscv_vle8_v_i8m2(y[ib].qs, vl);
364
+ vint16m4_t mul = __riscv_vwmul_vv_i16m4(v0f, v1, vl);
365
+ vint32m1_t zero = __riscv_vmv_v_x_i32m1(0, vl);
366
+ vint32m1_t sum = __riscv_vwredsum_vs_i16m4_i32m1(mul, zero, vl);
367
+ int32_t sumi = __riscv_vmv_x_s_i32m1_i32(sum);
368
+
369
+ sumf += (GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d))*sumi + GGML_FP16_TO_FP32(x[ib].m)*GGML_FP16_TO_FP32(y[ib].s);
370
+ }
371
+
372
+ #endif
373
+ for (; ib < nb; ++ib) {
374
+ uint32_t qh;
375
+ memcpy(&qh, x[ib].qh, sizeof(qh));
376
+
377
+ int sumi0 = 0;
378
+ int sumi1 = 0;
379
+
380
+ for (int j = 0; j < qk/2; ++j) {
381
+ const uint8_t xh_0 = ((qh >> (j + 0)) << 4) & 0x10;
382
+ const uint8_t xh_1 = ((qh >> (j + 12)) ) & 0x10;
383
+
384
+ const int32_t x0 = (x[ib].qs[j] & 0xF) | xh_0;
385
+ const int32_t x1 = (x[ib].qs[j] >> 4) | xh_1;
386
+
387
+ sumi0 += (x0 * y[ib].qs[j]);
388
+ sumi1 += (x1 * y[ib].qs[j + qk/2]);
389
+ }
390
+
391
+ int sumi = sumi0 + sumi1;
392
+ sumf += (GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d))*sumi + GGML_FP16_TO_FP32(x[ib].m)*GGML_FP16_TO_FP32(y[ib].s);
393
+ }
394
+
395
+ *s = sumf;
396
+ }
397
+
398
+ void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
399
+ const int qk = QK8_0;
400
+ const int nb = n / qk;
401
+
402
+ assert(n % qk == 0);
403
+ assert(nrc == 1);
404
+ UNUSED(nrc);
405
+ UNUSED(bx);
406
+ UNUSED(by);
407
+ UNUSED(bs);
408
+
409
+ const block_q8_0 * GGML_RESTRICT x = vx;
410
+ const block_q8_0 * GGML_RESTRICT y = vy;
411
+
412
+ int ib = 0;
413
+ float sumf = 0;
414
+
415
+ #if defined(__riscv_v)
416
+ size_t vl = qk;
417
+
418
+ for (; ib < nb; ++ib) {
419
+ // load elements
420
+ vint8m2_t bx_0 = __riscv_vle8_v_i8m2(x[ib].qs, vl);
421
+ vint8m2_t by_0 = __riscv_vle8_v_i8m2(y[ib].qs, vl);
422
+
423
+ vint16m4_t vw_mul = __riscv_vwmul_vv_i16m4(bx_0, by_0, vl);
424
+
425
+ vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, vl);
426
+ vint32m1_t v_sum = __riscv_vwredsum_vs_i16m4_i32m1(vw_mul, v_zero, vl);
427
+
428
+ int sumi = __riscv_vmv_x_s_i32m1_i32(v_sum);
429
+
430
+ sumf += sumi*(GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d));
431
+ }
432
+
433
+ #endif
434
+ for (; ib < nb; ++ib) {
435
+ int sumi = 0;
436
+
437
+ for (int j = 0; j < qk; j++) {
438
+ sumi += x[ib].qs[j]*y[ib].qs[j];
439
+ }
440
+
441
+ sumf += sumi*(GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d));
442
+ }
443
+
444
+ *s = sumf;
445
+ }
446
+
447
+ void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
448
+ assert(nrc == 1);
449
+ UNUSED(nrc);
450
+ UNUSED(bx);
451
+ UNUSED(by);
452
+ UNUSED(bs);
453
+
454
+ const block_q2_K * GGML_RESTRICT x = vx;
455
+ const block_q8_K * GGML_RESTRICT y = vy;
456
+
457
+ const int nb = n / QK_K;
458
+
459
+ #if defined __riscv_xtheadvector
460
+
461
+ float sumf = 0;
462
+ uint8_t atmp[16];
463
+
464
+ for (int i = 0; i < nb; ++i) {
465
+ const uint8_t * q2 = x[i].qs;
466
+ const int8_t * q8 = y[i].qs;
467
+ const uint8_t * sc = x[i].scales;
468
+ const float dall = y[i].d * GGML_FP16_TO_FP32(x[i].d);
469
+ const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
470
+ uint8_t *patmp = atmp;
471
+ int vsums;
472
+ int tmp;
473
+ __asm__ __volatile__(
474
+ "th.vsetvli zero, %[vl16], e8, m1\n\t"
475
+ "th.vmv.v.x v8, zero\n\t"
476
+ "th.vlb.v v1, (%[sc])\n\t"
477
+ "th.vand.vi v0, v1, 0xF\n\t"
478
+ "th.vsrl.vi v1, v1, 4\n\t"
479
+ "th.vsb.v v0, (%[scale])\n\t"
480
+ "th.vwaddu.vx v16, v1, zero\n\t"
481
+ "th.vsetvli zero, %[vl16], e16, m2\n\t"
482
+ "th.vlh.v v2, (%[bsums])\n\t"
483
+ "th.vwmul.vv v4, v16, v2\n\t"
484
+ "th.vsetvli zero, %[vl16], e32, m4\n\t"
485
+ "th.vredsum.vs v8, v4, v8\n\t"
486
+ "th.vmv.x.s %[vsums], v8"
487
+ : [tmp] "=&r" (tmp), [vsums] "=&r" (vsums)
488
+ : [sc] "r" (sc), [scale] "r" (atmp), [bsums] "r" (y[i].bsums)
489
+ , [vl16] "r" (16)
490
+ : "memory"
491
+ , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
492
+ , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
493
+ , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
494
+ , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
495
+ );
496
+ sumf += dmin * vsums;
497
+ int isum = 0;
498
+
499
+ for (int j = 0; j < QK_K/128; ++j) {
500
+ __asm__ __volatile__(
501
+ "th.vsetvli zero, %[vl32], e8, m2\n\t"
502
+ "th.vlb.v v0, (%[q2])\n\t"
503
+ "th.vsrl.vi v2, v0, 2\n\t"
504
+ "th.vsrl.vi v4, v0, 4\n\t"
505
+ "th.vsrl.vi v6, v0, 6\n\t"
506
+ "th.vand.vi v0, v0, 0x3\n\t"
507
+ "th.vand.vi v2, v2, 0x3\n\t"
508
+ "th.vand.vi v4, v4, 0x3\n\t"
509
+ "th.vsetvli zero, %[vl128], e8, m8\n\t"
510
+ "th.vlb.v v8, (%[q8])\n\t"
511
+ "th.vsetvli zero, %[vl64], e8, m4\n\t"
512
+ "th.vwmul.vv v16, v0, v8\n\t"
513
+ "th.vwmul.vv v24, v4, v12\n\t"
514
+ "th.vsetvli zero, %[vl16], e16, m2\n\t"
515
+ "th.vmv.v.x v0, zero\n\t"
516
+ "th.vwredsum.vs v10, v16, v0\n\t"
517
+ "th.vwredsum.vs v9, v18, v0\n\t"
518
+ "th.vwredsum.vs v8, v20, v0\n\t"
519
+ "th.vwredsum.vs v7, v22, v0\n\t"
520
+ "th.vwredsum.vs v11, v24, v0\n\t"
521
+ "th.vwredsum.vs v12, v26, v0\n\t"
522
+ "th.vwredsum.vs v13, v28, v0\n\t"
523
+ "th.vwredsum.vs v14, v30, v0\n\t"
524
+ "li %[tmp], 4\n\t"
525
+ "th.vsetvli zero, %[tmp], e32, m1\n\t"
526
+ "th.vslideup.vi v10, v9, 1\n\t"
527
+ "th.vslideup.vi v8, v7, 1\n\t"
528
+ "th.vslideup.vi v11, v12, 1\n\t"
529
+ "th.vslideup.vi v13, v14, 1\n\t"
530
+ "th.vslideup.vi v10, v8, 2\n\t"
531
+ "th.vslideup.vi v11, v13, 2\n\t"
532
+ "li %[tmp], 8\n\t"
533
+ "th.vsetvli zero, %[tmp], e32, m2\n\t"
534
+ "th.vlbu.v v12, (%[scale])\n\t"
535
+ "th.vmul.vv v10, v10, v12\n\t"
536
+ "th.vredsum.vs v0, v10, v0\n\t"
537
+ "th.vmv.x.s %[tmp], v0\n\t"
538
+ "add %[isum], %[isum], %[tmp]"
539
+ : [tmp] "=&r" (tmp), [isum] "+&r" (isum)
540
+ : [q2] "r" (q2), [scale] "r" (patmp), [q8] "r" (q8)
541
+ , [vl16] "r" (16), [vl32] "r" (32), [vl64] "r" (64), [vl128] "r" (128)
542
+ : "memory"
543
+ , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
544
+ , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
545
+ , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
546
+ , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
547
+ );
548
+ q2 += 32; q8 += 128; patmp += 8;
549
+ }
550
+
551
+ sumf += dall * isum;
552
+ }
553
+
554
+ *s = sumf;
555
+
556
+ #elif defined __riscv_v
557
+
558
+ float sumf = 0;
559
+ uint8_t atmp[16];
560
+
561
+ const int vector_length = __riscv_vlenb() * 8;
562
+ uint8_t temp_01[32] = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
563
+ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 };
564
+
565
+ switch (vector_length) {
566
+ case 256:
567
+ for (int i = 0; i < nb; ++i) {
568
+ const uint8_t * q2 = x[i].qs;
569
+ const int8_t * q8 = y[i].qs;
570
+ const uint8_t * sc = x[i].scales;
571
+
572
+ const float dall = y[i].d * GGML_FP16_TO_FP32(x[i].d);
573
+ const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
574
+
575
+ size_t vl = 16;
576
+
577
+ vuint8m1_t scales = __riscv_vle8_v_u8m1(sc, vl);
578
+ vuint8m1_t aux = __riscv_vand_vx_u8m1(scales, 0x0F, vl);
579
+
580
+ vint16m1_t q8sums = __riscv_vle16_v_i16m1(y[i].bsums, vl);
581
+
582
+ vuint8mf2_t scales_2 = __riscv_vle8_v_u8mf2(sc, vl);
583
+ vuint8mf2_t mins8 = __riscv_vsrl_vx_u8mf2(scales_2, 0x4, vl);
584
+ vint16m1_t mins = __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(mins8, vl));
585
+ vint32m2_t prod = __riscv_vwmul_vv_i32m2(q8sums, mins, vl);
586
+ vint32m1_t vsums = __riscv_vredsum_vs_i32m2_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl);
587
+
588
+ sumf += dmin * __riscv_vmv_x_s_i32m1_i32(vsums);
589
+
590
+ vl = 32;
591
+
592
+ vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1);
593
+ vuint8m1_t v_b = __riscv_vle8_v_u8m1(temp_01, vl);
594
+
595
+ uint8_t is = 0;
596
+ int isum = 0;
597
+
598
+ for (int j = 0; j < QK_K / 128; ++j) {
599
+ // load Q2
600
+ vuint8m1_t q2_x = __riscv_vle8_v_u8m1(q2, vl);
601
+
602
+ vuint8m1_t q2_0 = __riscv_vand_vx_u8m1(q2_x, 0x03, vl);
603
+ vuint8m1_t q2_1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x2, vl), 0x03, vl);
604
+ vuint8m1_t q2_2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x4, vl), 0x03, vl);
605
+ vuint8m1_t q2_3 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x6, vl), 0x03, vl);
606
+
607
+ // duplicate scale elements for product
608
+ vuint8m1_t sc0 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 0 + is, vl), vl);
609
+ vuint8m1_t sc1 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 2 + is, vl), vl);
610
+ vuint8m1_t sc2 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 4 + is, vl), vl);
611
+ vuint8m1_t sc3 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 6 + is, vl), vl);
612
+
613
+ vint16m2_t p0 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_0, sc0, vl));
614
+ vint16m2_t p1 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_1, sc1, vl));
615
+ vint16m2_t p2 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_2, sc2, vl));
616
+ vint16m2_t p3 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_3, sc3, vl));
617
+
618
+ // load Q8
619
+ vint8m1_t q8_0 = __riscv_vle8_v_i8m1(q8, vl);
620
+ vint8m1_t q8_1 = __riscv_vle8_v_i8m1(q8 + 32, vl);
621
+ vint8m1_t q8_2 = __riscv_vle8_v_i8m1(q8 + 64, vl);
622
+ vint8m1_t q8_3 = __riscv_vle8_v_i8m1(q8 + 96, vl);
623
+
624
+ vint32m4_t s0 = __riscv_vwmul_vv_i32m4(p0, __riscv_vwcvt_x_x_v_i16m2(q8_0, vl), vl);
625
+ vint32m4_t s1 = __riscv_vwmul_vv_i32m4(p1, __riscv_vwcvt_x_x_v_i16m2(q8_1, vl), vl);
626
+ vint32m4_t s2 = __riscv_vwmul_vv_i32m4(p2, __riscv_vwcvt_x_x_v_i16m2(q8_2, vl), vl);
627
+ vint32m4_t s3 = __riscv_vwmul_vv_i32m4(p3, __riscv_vwcvt_x_x_v_i16m2(q8_3, vl), vl);
628
+
629
+ vint32m1_t isum0 = __riscv_vredsum_vs_i32m4_i32m1(__riscv_vadd_vv_i32m4(s0, s1, vl), vzero, vl);
630
+ vint32m1_t isum1 = __riscv_vredsum_vs_i32m4_i32m1(__riscv_vadd_vv_i32m4(s2, s3, vl), isum0, vl);
631
+
632
+ isum += __riscv_vmv_x_s_i32m1_i32(isum1);
633
+
634
+ q2 += 32;
635
+ q8 += 128;
636
+ is = 8;
637
+ }
638
+
639
+ sumf += dall * isum;
640
+ }
641
+ break;
642
+ case 128:
643
+ for (int i = 0; i < nb; ++i) {
644
+ const uint8_t * q2 = x[i].qs;
645
+ const int8_t * q8 = y[i].qs;
646
+ const uint8_t * sc = x[i].scales;
647
+ const float dall = y[i].d * GGML_FP16_TO_FP32(x[i].d);
648
+ const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
649
+ uint8_t *patmp = atmp;
650
+ int vsums;
651
+ int tmp;
652
+ __asm__ __volatile__(
653
+ "vsetivli zero, 16, e8, m1\n\t"
654
+ "vmv.v.x v8, zero\n\t"
655
+ "vle8.v v1, (%[sc])\n\t"
656
+ "vand.vi v0, v1, 0xF\n\t"
657
+ "vsrl.vi v1, v1, 4\n\t"
658
+ "vse8.v v0, (%[scale])\n\t"
659
+ "vsetivli zero, 16, e16, m2\n\t"
660
+ "vle16.v v2, (%[bsums])\n\t"
661
+ "vzext.vf2 v0, v1\n\t"
662
+ "vwmul.vv v4, v0, v2\n\t"
663
+ "vsetivli zero, 16, e32, m4\n\t"
664
+ "vredsum.vs v8, v4, v8\n\t"
665
+ "vmv.x.s %[vsums], v8"
666
+ : [tmp] "=&r" (tmp), [vsums] "=&r" (vsums)
667
+ : [sc] "r" (sc), [scale] "r" (atmp), [bsums] "r" (y[i].bsums)
668
+ : "memory"
669
+ , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
670
+ , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
671
+ , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
672
+ , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
673
+ );
674
+ sumf += dmin * vsums;
675
+ int isum = 0;
676
+
677
+ for (int j = 0; j < QK_K/128; ++j) {
678
+ __asm__ __volatile__(
679
+ "vsetvli zero, %[vl32], e8, m2\n\t"
680
+ "vle8.v v0, (%[q2])\n\t"
681
+ "vsrl.vi v2, v0, 2\n\t"
682
+ "vsrl.vi v4, v0, 4\n\t"
683
+ "vsrl.vi v6, v0, 6\n\t"
684
+ "vand.vi v0, v0, 0x3\n\t"
685
+ "vand.vi v2, v2, 0x3\n\t"
686
+ "vand.vi v4, v4, 0x3\n\t"
687
+ "vsetvli zero, %[vl128], e8, m8\n\t"
688
+ "vle8.v v8, (%[q8])\n\t"
689
+ "vsetvli zero, %[vl64], e8, m4\n\t"
690
+ "vwmul.vv v16, v0, v8\n\t"
691
+ "vwmul.vv v24, v4, v12\n\t"
692
+ "vsetivli zero, 16, e16, m2\n\t"
693
+ "vmv.v.x v0, zero\n\t"
694
+ "vwredsum.vs v10, v16, v0\n\t"
695
+ "vwredsum.vs v9, v18, v0\n\t"
696
+ "vwredsum.vs v8, v20, v0\n\t"
697
+ "vwredsum.vs v7, v22, v0\n\t"
698
+ "vwredsum.vs v11, v24, v0\n\t"
699
+ "vwredsum.vs v12, v26, v0\n\t"
700
+ "vwredsum.vs v13, v28, v0\n\t"
701
+ "vwredsum.vs v14, v30, v0\n\t"
702
+ "vsetivli zero, 4, e32, m1\n\t"
703
+ "vslideup.vi v10, v9, 1\n\t"
704
+ "vslideup.vi v8, v7, 1\n\t"
705
+ "vslideup.vi v11, v12, 1\n\t"
706
+ "vslideup.vi v13, v14, 1\n\t"
707
+ "vslideup.vi v10, v8, 2\n\t"
708
+ "vslideup.vi v11, v13, 2\n\t"
709
+ "vsetivli zero, 8, e32, m2\n\t"
710
+ "vle8.v v15, (%[scale])\n\t"
711
+ "vzext.vf4 v12, v15\n\t"
712
+ "vmul.vv v10, v10, v12\n\t"
713
+ "vredsum.vs v0, v10, v0\n\t"
714
+ "vmv.x.s %[tmp], v0\n\t"
715
+ "add %[isum], %[isum], %[tmp]"
716
+ : [tmp] "=&r" (tmp), [isum] "+&r" (isum)
717
+ : [q2] "r" (q2), [scale] "r" (patmp), [q8] "r" (q8)
718
+ , [vl32] "r" (32), [vl64] "r" (64), [vl128] "r" (128)
719
+ : "memory"
720
+ , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
721
+ , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
722
+ , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
723
+ , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
724
+ );
725
+ q2 += 32; q8 += 128; patmp += 8;
726
+ }
727
+
728
+ sumf += dall * isum;
729
+ }
730
+ break;
731
+ default:
732
+ assert(false && "Unsupported vector length");
733
+ break;
734
+ }
735
+
736
+ *s = sumf;
737
+
738
+ #else
739
+
740
+ float sumf = 0;
741
+
742
+ for (int i = 0; i < nb; ++i) {
743
+
744
+ const uint8_t * q2 = x[i].qs;
745
+ const int8_t * q8 = y[i].qs;
746
+ const uint8_t * sc = x[i].scales;
747
+
748
+ int summs = 0;
749
+ for (int j = 0; j < 16; ++j) {
750
+ summs += y[i].bsums[j] * (sc[j] >> 4);
751
+ }
752
+
753
+ const float dall = y[i].d * GGML_FP16_TO_FP32(x[i].d);
754
+ const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
755
+
756
+ int isum = 0;
757
+ int is = 0;
758
+ int d;
759
+ for (int k = 0; k < QK_K/128; ++k) {
760
+ int shift = 0;
761
+ for (int j = 0; j < 4; ++j) {
762
+ d = sc[is++] & 0xF;
763
+ int isuml = 0;
764
+ for (int l = 0; l < 16; ++l) isuml += q8[l] * ((q2[l] >> shift) & 3);
765
+ isum += d * isuml;
766
+ d = sc[is++] & 0xF;
767
+ isuml = 0;
768
+ for (int l = 16; l < 32; ++l) isuml += q8[l] * ((q2[l] >> shift) & 3);
769
+ isum += d * isuml;
770
+ shift += 2;
771
+ q8 += 32;
772
+ }
773
+ q2 += 32;
774
+ }
775
+ sumf += dall * isum - dmin * summs;
776
+ }
777
+ *s = sumf;
778
+ #endif
779
+ }
780
+
781
+ void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
782
+ assert(n % QK_K == 0);
783
+ assert(nrc == 1);
784
+ UNUSED(nrc);
785
+ UNUSED(bx);
786
+ UNUSED(by);
787
+ UNUSED(bs);
788
+
789
+ const uint32_t kmask1 = 0x03030303;
790
+ const uint32_t kmask2 = 0x0f0f0f0f;
791
+
792
+ const block_q3_K * GGML_RESTRICT x = vx;
793
+ const block_q8_K * GGML_RESTRICT y = vy;
794
+
795
+ const int nb = n / QK_K;
796
+
797
+ #if defined __riscv_xtheadvector
798
+
799
+ uint32_t utmp[4];
800
+ float sumf = 0;
801
+
802
+ for (int i = 0; i < nb; ++i) {
803
+ const uint8_t * restrict q3 = x[i].qs;
804
+ const uint8_t * restrict qh = x[i].hmask;
805
+ const int8_t * restrict q8 = y[i].qs;
806
+
807
+ int8_t * scale = (int8_t *)utmp;
808
+ int tmp;
809
+ __asm__ __volatile__(
810
+ "li %[tmp], 12\n\t"
811
+ "th.vsetvli zero, %[tmp], e8, m1\n\t"
812
+ "th.vlb.v v0, (%[s6b])\n\t"
813
+ "th.vmv.v.v v2, v0\n\t"
814
+ "li %[tmp], 2\n\t"
815
+ "th.vsetvli zero, %[tmp], e64, m1\n\t"
816
+ "th.vmv.v.x v9, %[sh]\n\t"\
817
+ "th.vslidedown.vi v1, v0, 1\n\t"
818
+ "th.vslide1up.vx v8, v9, zero\n\t" // {0, 0, 4, 4}
819
+ "th.vslideup.vi v0, v2, 1\n\t" // {aux[0], aux[1], aux[0], aux[1]}
820
+ "li %[tmp], 4\n\t"
821
+ "th.vsetvli zero, %[tmp], e32, m1\n\t"
822
+ "th.vid.v v9\n\t"
823
+ "th.vmv.x.s %[tmp], v1\n\t"
824
+ "th.vsll.vi v9, v9, 1\n\t" // {0, 2, 4, 6}
825
+ "th.vmv.v.x v1, %[tmp]\n\t" // {aux[2], aux[2], aux[2], aux[2]}
826
+ "th.vsrl.vv v4, v1, v9\n\t"
827
+ "th.vsrl.vv v2, v0, v8\n\t"
828
+ "th.vand.vx v5, v4, %[kmask1]\n\t"
829
+ "th.vand.vx v3, v2, %[kmask2]\n\t"
830
+ "th.vsll.vi v6, v5, 4\n\t"
831
+ "th.vor.vv v7, v6, v3\n\t"
832
+ "li %[tmp], 16\n\t"
833
+ "th.vsetvli zero, %[tmp], e8, m1\n\t"
834
+ "th.vsub.vx v0, v7, %[c]\n\t"
835
+ "th.vsb.v v0, (%[scale])"
836
+ : [tmp] "=&r" (tmp)
837
+ : [sh] "r" (0x0000000400000004), [s6b] "r" (x[i].scales), [c] "r" (32)
838
+ , [scale] "r" (scale), [kmask1] "r" (kmask1), [kmask2] "r" (kmask2)
839
+ : "memory"
840
+ , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
841
+ , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
842
+ , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
843
+ , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
844
+ );
845
+
846
+ uint8_t m = 1;
847
+ int isum = 0;
848
+ for (int j = 0; j < QK_K; j += 128) {
849
+ __asm__ __volatile__(
850
+ // fixme: use v0p7 mask layout directly
851
+ "th.vsetvli zero, %[vl32], e8, m2\n\t"
852
+ "th.vlb.v v8, (%[q3])\n\t"
853
+ "th.vsrl.vi v10, v8, 2\n\t"
854
+ "th.vsrl.vi v12, v8, 4\n\t"
855
+ "th.vsrl.vi v14, v8, 6\n\t"
856
+ "th.vand.vi v8, v8, 3\n\t"
857
+ "th.vand.vi v10, v10, 3\n\t"
858
+ "th.vand.vi v12, v12, 3\n\t"
859
+ "th.vlb.v v2, (%[qh])\n\t"
860
+ "th.vand.vx v4, v2, %[m]\n\t"
861
+ "slli %[m], %[m], 1\n\t"
862
+ "th.vmseq.vx v0, v4, zero\n\t"
863
+ "th.vadd.vi v8, v8, -4, v0.t\n\t"
864
+ "th.vand.vx v4, v2, %[m]\n\t"
865
+ "slli %[m], %[m], 1\n\t"
866
+ "th.vmseq.vx v0, v4, zero\n\t"
867
+ "th.vadd.vi v10, v10, -4, v0.t\n\t"
868
+ "th.vand.vx v4, v2, %[m]\n\t"
869
+ "slli %[m], %[m], 1\n\t"
870
+ "th.vmseq.vx v0, v4, zero\n\t"
871
+ "th.vadd.vi v12, v12, -4, v0.t\n\t"
872
+ "th.vand.vx v4, v2, %[m]\n\t"
873
+ "slli %[m], %[m], 1\n\t"
874
+ "th.vmseq.vx v0, v4, zero\n\t"
875
+ "th.vadd.vi v14, v14, -4, v0.t\n\t"
876
+ "th.vsetvli zero, %[vl128], e8, m8\n\t"
877
+ "th.vlb.v v0, (%[q8])\n\t"
878
+ "th.vsetvli zero, %[vl64], e8, m4\n\t"
879
+ "th.vwmul.vv v16, v0, v8\n\t"
880
+ "th.vwmul.vv v24, v4, v12\n\t"
881
+ "li %[tmp], 16\n\t"
882
+ "th.vsetvli zero, %[tmp], e16, m2\n\t"
883
+ "th.vmv.v.x v0, zero\n\t"
884
+ "th.vwredsum.vs v10, v16, v0\n\t"
885
+ "th.vwredsum.vs v9, v18, v0\n\t"
886
+ "th.vwredsum.vs v8, v20, v0\n\t"
887
+ "th.vwredsum.vs v7, v22, v0\n\t"
888
+ "th.vwredsum.vs v11, v24, v0\n\t"
889
+ "th.vwredsum.vs v12, v26, v0\n\t"
890
+ "th.vwredsum.vs v13, v28, v0\n\t"
891
+ "th.vwredsum.vs v14, v30, v0\n\t"
892
+ "li %[tmp], 4\n\t"
893
+ "th.vsetvli zero, %[tmp], e32, m1\n\t"
894
+ "th.vslideup.vi v10, v9, 1\n\t"
895
+ "th.vslideup.vi v8, v7, 1\n\t"
896
+ "th.vslideup.vi v11, v12, 1\n\t"
897
+ "th.vslideup.vi v13, v14, 1\n\t"
898
+ "th.vslideup.vi v10, v8, 2\n\t"
899
+ "th.vslideup.vi v11, v13, 2\n\t"
900
+ "li %[tmp], 8\n\t"
901
+ "th.vsetvli zero, %[tmp], e32, m2\n\t"
902
+ "th.vlb.v v12, (%[scale])\n\t"
903
+ "th.vmul.vv v10, v10, v12\n\t"
904
+ "th.vredsum.vs v0, v10, v0\n\t"
905
+ "th.vmv.x.s %[tmp], v0\n\t"
906
+ "add %[isum], %[isum], %[tmp]"
907
+ : [tmp] "=&r" (tmp), [m] "+&r" (m), [isum] "+&r" (isum)
908
+ : [vl128] "r" (128), [vl64] "r" (64), [vl32] "r" (32)
909
+ , [q3] "r" (q3), [qh] "r" (qh), [scale] "r" (scale), [q8] "r" (q8)
910
+ : "memory"
911
+ , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
912
+ , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
913
+ , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
914
+ , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
915
+ );
916
+ q3 += 32; q8 += 128; scale += 8;
917
+ }
918
+
919
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
920
+ sumf += d * isum;
921
+ }
922
+
923
+ *s = sumf;
924
+
925
+ #elif defined __riscv_v
926
+
927
+ uint32_t utmp[4];
928
+ float sumf = 0;
929
+ uint32_t aux[3];
930
+ const int vector_length = __riscv_vlenb() * 8;
931
+
932
+ switch (vector_length) {
933
+ case 256:
934
+ for (int i = 0; i < nb; ++i) {
935
+
936
+ const uint8_t * GGML_RESTRICT q3 = x[i].qs;
937
+ const uint8_t * GGML_RESTRICT qh = x[i].hmask;
938
+ const int8_t * GGML_RESTRICT q8 = y[i].qs;
939
+
940
+ memcpy(aux, x[i].scales, 12);
941
+ utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4);
942
+ utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4);
943
+ utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4);
944
+ utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4);
945
+
946
+ int8_t * scale = (int8_t *)utmp;
947
+ for (int j = 0; j < 16; ++j) scale[j] -= 32;
948
+
949
+
950
+ size_t vl = 32;
951
+ uint8_t m = 1;
952
+
953
+ vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1);
954
+ vuint8m1_t vqh = __riscv_vle8_v_u8m1(qh, vl);
955
+
956
+ int sum_t = 0;
957
+
958
+ for (int j = 0; j < QK_K; j += 128) {
959
+
960
+ vl = 32;
961
+
962
+ // load Q3
963
+ vuint8m1_t q3_x = __riscv_vle8_v_u8m1(q3, vl);
964
+
965
+ vint8m1_t q3_0 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q3_x, 0x03, vl));
966
+ vint8m1_t q3_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x2, vl), 0x03 , vl));
967
+ vint8m1_t q3_2 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x4, vl), 0x03 , vl));
968
+ vint8m1_t q3_3 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x6, vl), 0x03 , vl));
969
+
970
+ // compute mask for subtraction
971
+ vuint8m1_t qh_m0 = __riscv_vand_vx_u8m1(vqh, m, vl);
972
+ vbool8_t vmask_0 = __riscv_vmseq_vx_u8m1_b8(qh_m0, 0, vl);
973
+ vint8m1_t q3_m0 = __riscv_vsub_vx_i8m1_mu(vmask_0, q3_0, q3_0, 0x4, vl);
974
+ m <<= 1;
975
+
976
+ vuint8m1_t qh_m1 = __riscv_vand_vx_u8m1(vqh, m, vl);
977
+ vbool8_t vmask_1 = __riscv_vmseq_vx_u8m1_b8(qh_m1, 0, vl);
978
+ vint8m1_t q3_m1 = __riscv_vsub_vx_i8m1_mu(vmask_1, q3_1, q3_1, 0x4, vl);
979
+ m <<= 1;
980
+
981
+ vuint8m1_t qh_m2 = __riscv_vand_vx_u8m1(vqh, m, vl);
982
+ vbool8_t vmask_2 = __riscv_vmseq_vx_u8m1_b8(qh_m2, 0, vl);
983
+ vint8m1_t q3_m2 = __riscv_vsub_vx_i8m1_mu(vmask_2, q3_2, q3_2, 0x4, vl);
984
+ m <<= 1;
985
+
986
+ vuint8m1_t qh_m3 = __riscv_vand_vx_u8m1(vqh, m, vl);
987
+ vbool8_t vmask_3 = __riscv_vmseq_vx_u8m1_b8(qh_m3, 0, vl);
988
+ vint8m1_t q3_m3 = __riscv_vsub_vx_i8m1_mu(vmask_3, q3_3, q3_3, 0x4, vl);
989
+ m <<= 1;
990
+
991
+ // load Q8 and take product with Q3
992
+ vint16m2_t a0 = __riscv_vwmul_vv_i16m2(q3_m0, __riscv_vle8_v_i8m1(q8, vl), vl);
993
+ vint16m2_t a1 = __riscv_vwmul_vv_i16m2(q3_m1, __riscv_vle8_v_i8m1(q8+32, vl), vl);
994
+ vint16m2_t a2 = __riscv_vwmul_vv_i16m2(q3_m2, __riscv_vle8_v_i8m1(q8+64, vl), vl);
995
+ vint16m2_t a3 = __riscv_vwmul_vv_i16m2(q3_m3, __riscv_vle8_v_i8m1(q8+96, vl), vl);
996
+
997
+ vl = 16;
998
+
999
+ // retrieve lane to multiply with scale
1000
+ vint32m2_t aux0_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 0), (scale[0]), vl);
1001
+ vint32m2_t aux0_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 1), (scale[1]), vl);
1002
+ vint32m2_t aux1_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 0), (scale[2]), vl);
1003
+ vint32m2_t aux1_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 1), (scale[3]), vl);
1004
+ vint32m2_t aux2_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a2, 0), (scale[4]), vl);
1005
+ vint32m2_t aux2_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a2, 1), (scale[5]), vl);
1006
+ vint32m2_t aux3_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a3, 0), (scale[6]), vl);
1007
+ vint32m2_t aux3_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a3, 1), (scale[7]), vl);
1008
+
1009
+ vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux0_0, aux0_1, vl), vzero, vl);
1010
+ vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux1_0, aux1_1, vl), isum0, vl);
1011
+ vint32m1_t isum2 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux2_0, aux2_1, vl), isum1, vl);
1012
+ vint32m1_t isum3 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux3_0, aux3_1, vl), isum2, vl);
1013
+
1014
+ sum_t += __riscv_vmv_x_s_i32m1_i32(isum3);
1015
+
1016
+ q3 += 32; q8 += 128; scale += 8;
1017
+
1018
+ }
1019
+
1020
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
1021
+
1022
+ sumf += d*sum_t;
1023
+
1024
+ }
1025
+ break;
1026
+ case 128:
1027
+ for (int i = 0; i < nb; ++i) {
1028
+ const uint8_t * restrict q3 = x[i].qs;
1029
+ const uint8_t * restrict qh = x[i].hmask;
1030
+ const int8_t * restrict q8 = y[i].qs;
1031
+
1032
+ int8_t * scale = (int8_t *)utmp;
1033
+ int tmp;
1034
+ __asm__ __volatile__(
1035
+ "vsetivli zero, 12, e8, m1\n\t"
1036
+ "vle8.v v0, (%[s6b])\n\t"
1037
+ "vmv1r.v v2, v0\n\t"
1038
+ "vsetivli zero, 2, e64, m1\n\t"
1039
+ "vmv.v.x v9, %[sh]\n\t"\
1040
+ "vslidedown.vi v1, v0, 1\n\t"
1041
+ "vslide1up.vx v8, v9, zero\n\t" // {0, 0, 4, 4}
1042
+ "vslideup.vi v0, v2, 1\n\t" // {aux[0], aux[1], aux[0], aux[1]}
1043
+ "vsetivli zero, 4, e32, m1\n\t"
1044
+ "vid.v v9\n\t"
1045
+ "vmv.x.s %[tmp], v1\n\t"
1046
+ "vsll.vi v9, v9, 1\n\t" // {0, 2, 4, 6}
1047
+ "vmv.v.x v1, %[tmp]\n\t" // {aux[2], aux[2], aux[2], aux[2]}
1048
+ "vsrl.vv v4, v1, v9\n\t"
1049
+ "vsrl.vv v2, v0, v8\n\t"
1050
+ "vand.vx v5, v4, %[kmask1]\n\t"
1051
+ "vand.vx v3, v2, %[kmask2]\n\t"
1052
+ "vsll.vi v6, v5, 4\n\t"
1053
+ "vor.vv v7, v6, v3\n\t"
1054
+ "vsetivli zero, 16, e8, m1\n\t"
1055
+ "vsub.vx v0, v7, %[c]\n\t"
1056
+ "vse8.v v0, (%[scale])"
1057
+ : [tmp] "=&r" (tmp)
1058
+ : [sh] "r" (0x0000000400000004), [s6b] "r" (x[i].scales), [c] "r" (32)
1059
+ , [scale] "r" (scale), [kmask1] "r" (kmask1), [kmask2] "r" (kmask2)
1060
+ : "memory"
1061
+ , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
1062
+ , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
1063
+ , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
1064
+ , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
1065
+ );
1066
+
1067
+ uint8_t m = 1;
1068
+ int isum = 0;
1069
+ for (int j = 0; j < QK_K; j += 128) {
1070
+ __asm__ __volatile__(
1071
+ "vsetvli zero, %[vl32], e8, m2, ta, mu\n\t"
1072
+ "vle8.v v8, (%[q3])\n\t"
1073
+ "vsrl.vi v10, v8, 2\n\t"
1074
+ "vsrl.vi v12, v8, 4\n\t"
1075
+ "vsrl.vi v14, v8, 6\n\t"
1076
+ "vand.vi v8, v8, 3\n\t"
1077
+ "vand.vi v10, v10, 3\n\t"
1078
+ "vand.vi v12, v12, 3\n\t"
1079
+ "vle8.v v2, (%[qh])\n\t"
1080
+ "vand.vx v4, v2, %[m]\n\t"
1081
+ "slli %[m], %[m], 1\n\t"
1082
+ "vmseq.vx v0, v4, zero\n\t"
1083
+ "vadd.vi v8, v8, -4, v0.t\n\t"
1084
+ "vand.vx v4, v2, %[m]\n\t"
1085
+ "slli %[m], %[m], 1\n\t"
1086
+ "vmseq.vx v0, v4, zero\n\t"
1087
+ "vadd.vi v10, v10, -4, v0.t\n\t"
1088
+ "vand.vx v4, v2, %[m]\n\t"
1089
+ "slli %[m], %[m], 1\n\t"
1090
+ "vmseq.vx v0, v4, zero\n\t"
1091
+ "vadd.vi v12, v12, -4, v0.t\n\t"
1092
+ "vand.vx v4, v2, %[m]\n\t"
1093
+ "slli %[m], %[m], 1\n\t"
1094
+ "vmseq.vx v0, v4, zero\n\t"
1095
+ "vadd.vi v14, v14, -4, v0.t\n\t"
1096
+ "vsetvli zero, %[vl128], e8, m8\n\t"
1097
+ "vle8.v v0, (%[q8])\n\t"
1098
+ "vsetvli zero, %[vl64], e8, m4\n\t"
1099
+ "vwmul.vv v16, v0, v8\n\t"
1100
+ "vwmul.vv v24, v4, v12\n\t"
1101
+ "vsetivli zero, 16, e16, m2\n\t"
1102
+ "vmv.v.x v0, zero\n\t"
1103
+ "vwredsum.vs v10, v16, v0\n\t"
1104
+ "vwredsum.vs v9, v18, v0\n\t"
1105
+ "vwredsum.vs v8, v20, v0\n\t"
1106
+ "vwredsum.vs v7, v22, v0\n\t"
1107
+ "vwredsum.vs v11, v24, v0\n\t"
1108
+ "vwredsum.vs v12, v26, v0\n\t"
1109
+ "vwredsum.vs v13, v28, v0\n\t"
1110
+ "vwredsum.vs v14, v30, v0\n\t"
1111
+ "vsetivli zero, 4, e32, m1\n\t"
1112
+ "vslideup.vi v10, v9, 1\n\t"
1113
+ "vslideup.vi v8, v7, 1\n\t"
1114
+ "vslideup.vi v11, v12, 1\n\t"
1115
+ "vslideup.vi v13, v14, 1\n\t"
1116
+ "vslideup.vi v10, v8, 2\n\t"
1117
+ "vslideup.vi v11, v13, 2\n\t"
1118
+ "vsetivli zero, 8, e32, m2\n\t"
1119
+ "vle8.v v15, (%[scale])\n\t"
1120
+ "vsext.vf4 v12, v15\n\t"
1121
+ "vmul.vv v10, v10, v12\n\t"
1122
+ "vredsum.vs v0, v10, v0\n\t"
1123
+ "vmv.x.s %[tmp], v0\n\t"
1124
+ "add %[isum], %[isum], %[tmp]"
1125
+ : [tmp] "=&r" (tmp), [m] "+&r" (m), [isum] "+&r" (isum)
1126
+ : [vl128] "r" (128), [vl64] "r" (64), [vl32] "r" (32)
1127
+ , [q3] "r" (q3), [qh] "r" (qh), [scale] "r" (scale), [q8] "r" (q8)
1128
+ : "memory"
1129
+ , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
1130
+ , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
1131
+ , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
1132
+ , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
1133
+ );
1134
+ q3 += 32; q8 += 128; scale += 8;
1135
+ }
1136
+
1137
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
1138
+ sumf += d * isum;
1139
+ }
1140
+ break;
1141
+ default:
1142
+ assert(false && "Unsupported vector length");
1143
+ break;
1144
+ }
1145
+
1146
+ *s = sumf;
1147
+
1148
+ #else
1149
+ // scalar version
1150
+ // This function is written like this so the compiler can manage to vectorize most of it
1151
+ // Using -Ofast, GCC and clang manage to produce code that is within a factor of 2 or so from the
1152
+ // manually vectorized version above. Every other version I tried would run at least 4 times slower.
1153
+ // The ideal situation would be if we could just write the code once, and the compiler would
1154
+ // automatically produce the best possible set of machine instructions, instead of us having to manually
1155
+ // write vectorized versions for AVX, ARM_NEON, etc.
1156
+
1157
+ int8_t aux8[QK_K];
1158
+ int16_t aux16[8];
1159
+ float sums [8];
1160
+ int32_t aux32[8];
1161
+ memset(sums, 0, 8*sizeof(float));
1162
+
1163
+ uint32_t auxs[4];
1164
+ const int8_t * scales = (const int8_t*)auxs;
1165
+
1166
+ float sumf = 0;
1167
+ for (int i = 0; i < nb; ++i) {
1168
+ const uint8_t * GGML_RESTRICT q3 = x[i].qs;
1169
+ const uint8_t * GGML_RESTRICT hm = x[i].hmask;
1170
+ const int8_t * GGML_RESTRICT q8 = y[i].qs;
1171
+ memset(aux32, 0, 8*sizeof(int32_t));
1172
+ int8_t * GGML_RESTRICT a = aux8;
1173
+ uint8_t m = 1;
1174
+ for (int j = 0; j < QK_K; j += 128) {
1175
+ for (int l = 0; l < 32; ++l) a[l] = q3[l] & 3;
1176
+ for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
1177
+ a += 32; m <<= 1;
1178
+ for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 2) & 3;
1179
+ for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
1180
+ a += 32; m <<= 1;
1181
+ for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 4) & 3;
1182
+ for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
1183
+ a += 32; m <<= 1;
1184
+ for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 6) & 3;
1185
+ for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
1186
+ a += 32; m <<= 1;
1187
+ q3 += 32;
1188
+ }
1189
+ a = aux8;
1190
+
1191
+ memcpy(auxs, x[i].scales, 12);
1192
+ uint32_t tmp = auxs[2];
1193
+ auxs[2] = ((auxs[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);
1194
+ auxs[3] = ((auxs[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4);
1195
+ auxs[0] = (auxs[0] & kmask2) | (((tmp >> 0) & kmask1) << 4);
1196
+ auxs[1] = (auxs[1] & kmask2) | (((tmp >> 2) & kmask1) << 4);
1197
+ for (int j = 0; j < QK_K/16; ++j) {
1198
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
1199
+ for (int l = 0; l < 8; ++l) aux32[l] += (scales[j] - 32) * aux16[l];
1200
+ q8 += 8; a += 8;
1201
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
1202
+ for (int l = 0; l < 8; ++l) aux32[l] += (scales[j] - 32) * aux16[l];
1203
+ q8 += 8; a += 8;
1204
+ }
1205
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
1206
+ for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
1207
+ }
1208
+ for (int l = 0; l < 8; ++l) sumf += sums[l];
1209
+ *s = sumf;
1210
+
1211
+ #endif
1212
+
1213
+ }
1214
+
1215
+ void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
1216
+ assert(n % QK_K == 0);
1217
+ assert(nrc == 1);
1218
+ UNUSED(nrc);
1219
+ UNUSED(bx);
1220
+ UNUSED(by);
1221
+ UNUSED(bs);
1222
+
1223
+ const block_q4_K * GGML_RESTRICT x = vx;
1224
+ const block_q8_K * GGML_RESTRICT y = vy;
1225
+
1226
+ const int nb = n / QK_K;
1227
+
1228
+ static const uint32_t kmask1 = 0x3f3f3f3f;
1229
+ static const uint32_t kmask2 = 0x0f0f0f0f;
1230
+ static const uint32_t kmask3 = 0x03030303;
1231
+
1232
+ uint32_t utmp[4];
1233
+
1234
+ #if defined __riscv_xtheadvector
1235
+
1236
+ const uint8_t * scales = (const uint8_t*)&utmp[0];
1237
+ const uint8_t * mins = (const uint8_t*)&utmp[2];
1238
+
1239
+ float sumf = 0;
1240
+
1241
+ for (int i = 0; i < nb; ++i) {
1242
+ const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
1243
+ const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
1244
+
1245
+ int tmp, tmp2, sumi;
1246
+ __asm__ __volatile__(
1247
+ "li %[t1], 12\n\t"
1248
+ "th.vsetvli zero, %[t1], e8, m1\n\t"
1249
+ "th.vlb.v v1, (%[s6b])\n\t" // {aux[0], aux[1], aux[2]}
1250
+ "li %[t1], 4\n\t"
1251
+ "th.vsetvli zero, %[t1], e32, m1\n\t"
1252
+ "th.vslidedown.vi v2, v1, 2\n\t"
1253
+ "th.vmv.v.v v3, v2\n\t"
1254
+ "th.vslideup.vi v2, v3, 1\n\t" // {aux[2], aux[2]}
1255
+ "li %[t1], 2\n\t"
1256
+ "th.vsetvli zero, %[t1], e32, m1\n\t"
1257
+ "th.vmv.v.i v4, 4\n\t"
1258
+ "th.vand.vx v8, v1, %[kmask1]\n\t"
1259
+ "th.vslide1up.vx v5, v4, zero\n\t" // {0, 4}
1260
+ "th.vsrl.vi v6, v1, 6\n\t"
1261
+ "th.vsrl.vv v7, v2, v5\n\t"
1262
+ "th.vand.vx v0, v6, %[kmask3]\n\t"
1263
+ "th.vand.vx v2, v7, %[kmask2]\n\t"
1264
+ "th.vsll.vi v6, v0, 4\n\t"
1265
+ "li %[t2], 8\n\t"
1266
+ "addi %[t1], %[utmp], 4\n\t"
1267
+ "th.vor.vv v1, v6, v2\n\t"
1268
+ "th.vssw.v v8, (%[utmp]), %[t2]\n\t"
1269
+ "th.vssw.v v1, (%[t1]), %[t2]\n\t"
1270
+ "th.vsetvli zero, zero, e32, m2\n\t" // vl == 8
1271
+ "th.vlw.v v2, (%[bsums])\n\t"
1272
+ "th.vsetvli zero, %[t2], e16, m1\n\t"
1273
+ "th.vnsrl.vi v0, v2, 0\n\t"
1274
+ "th.vnsrl.vi v1, v2, 16\n\t"
1275
+ "th.vadd.vv v2, v0, v1\n\t"
1276
+ "th.vlbu.v v4, (%[mins])\n\t"
1277
+ "th.vwmul.vv v6, v4, v2\n\t"
1278
+ "th.vmv.v.x v0, zero\n\t"
1279
+ "th.vsetvli zero, %[t2], e32, m2\n\t"
1280
+ "th.vredsum.vs v0, v6, v0\n\t"
1281
+ "th.vmv.x.s %[sumi], v0"
1282
+ : [t1] "=&r" (tmp), [t2] "=&r" (tmp2), [sumi] "=&r" (sumi)
1283
+ : [bsums] "r" (y[i].bsums), [mins] "r" (mins), [utmp] "r" (utmp)
1284
+ , [s6b] "r" (x[i].scales), [kmask1] "r" (kmask1)
1285
+ , [kmask2] "r" (kmask2), [kmask3] "r" (kmask3)
1286
+ : "memory"
1287
+ , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
1288
+ , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
1289
+ , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
1290
+ , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
1291
+ );
1292
+ sumf -= dmin * sumi;
1293
+
1294
+ const uint8_t * restrict q4 = x[i].qs;
1295
+ const int8_t * restrict q8 = y[i].qs;
1296
+
1297
+ sumi = 0;
1298
+ const uint8_t * scale = scales;
1299
+
1300
+ for (int j = 0; j < QK_K/128; ++j) {
1301
+ int vl128 = 128, vl64 = 64, vl32 = 32;
1302
+ __asm__ __volatile__(
1303
+ "th.vsetvli zero, %[vl128], e8, m8\n\t"
1304
+ "th.vlb.v v8, (%[q8])\n\t"
1305
+ "th.vsetvli zero, %[vl64], e8, m4\n\t"
1306
+ "th.vlb.v v0, (%[q4])\n\t"
1307
+ "th.vsrl.vi v4, v0, 4\n\t"
1308
+ "th.vand.vi v0, v0, 0xF\n\t"
1309
+ "th.vsetvli zero, %[vl32], e8, m2\n\t"
1310
+ "th.vwmul.vv v28, v6, v14\n\t"
1311
+ "th.vwmul.vv v20, v4, v10\n\t"
1312
+ "th.vwmul.vv v24, v2, v12\n\t"
1313
+ "th.vwmul.vv v16, v0, v8\n\t"
1314
+ "li %[tmp], 4\n\t"
1315
+ "th.vsetvli zero, %[tmp], e32, m1\n\t"
1316
+ "th.vlbu.v v1, (%[scale])\n\t"
1317
+ "th.vmv.v.x v0, zero\n\t"
1318
+ "th.vsetvli zero, %[vl32], e16, m4\n\t"
1319
+ "th.vwredsum.vs v6, v24, v0\n\t"
1320
+ "th.vwredsum.vs v7, v28, v0\n\t"
1321
+ "th.vwredsum.vs v4, v16, v0\n\t"
1322
+ "th.vwredsum.vs v5, v20, v0\n\t"
1323
+ "th.vsetvli zero, %[tmp], e32, m1\n\t"
1324
+ "th.vslideup.vi v6, v7, 1\n\t"
1325
+ "th.vslideup.vi v4, v5, 1\n\t"
1326
+ "th.vslideup.vi v4, v6, 2\n\t"
1327
+ "th.vmul.vv v8, v4, v1\n\t"
1328
+ "th.vredsum.vs v0, v8, v0\n\t"
1329
+ "th.vmv.x.s %[tmp], v0\n\t"
1330
+ "add %[sumi], %[sumi], %[tmp]"
1331
+ : [tmp] "=&r" (tmp), [sumi] "+&r" (sumi)
1332
+ : [vl128] "r" (vl128), [vl64] "r" (vl64), [vl32] "r" (vl32)
1333
+ , [q4] "r" (q4), [q8] "r" (q8), [scale] "r" (scale)
1334
+ : "memory"
1335
+ , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
1336
+ , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
1337
+ , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
1338
+ , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
1339
+ );
1340
+
1341
+ q4 += 64; q8 += 128; scale += 4;
1342
+ }
1343
+
1344
+ sumf += d * sumi;
1345
+
1346
+ }
1347
+
1348
+ *s = sumf;
1349
+
1350
+ #elif defined __riscv_v
1351
+
1352
+ const uint8_t * scales = (const uint8_t*)&utmp[0];
1353
+ const uint8_t * mins = (const uint8_t*)&utmp[2];
1354
+
1355
+ float sumf = 0;
1356
+ const int vector_length = __riscv_vlenb() * 8;
1357
+
1358
+ switch (vector_length) {
1359
+ case 256:
1360
+ for (int i = 0; i < nb; ++i) {
1361
+
1362
+ size_t vl = 8;
1363
+
1364
+ const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
1365
+ const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
1366
+
1367
+ vint16mf2_t q8sums_0 = __riscv_vlse16_v_i16mf2(y[i].bsums, 4, vl);
1368
+ vint16mf2_t q8sums_1 = __riscv_vlse16_v_i16mf2(y[i].bsums+1, 4, vl);
1369
+ vint16mf2_t q8sums = __riscv_vadd_vv_i16mf2(q8sums_0, q8sums_1, vl);
1370
+
1371
+ memcpy(utmp, x[i].scales, 12);
1372
+ utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
1373
+ const uint32_t uaux = utmp[1] & kmask1;
1374
+ utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
1375
+ utmp[2] = uaux;
1376
+ utmp[0] &= kmask1;
1377
+
1378
+ vuint8mf4_t mins8 = __riscv_vle8_v_u8mf4(mins, vl);
1379
+ vint16mf2_t v_mins = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vzext_vf2_u16mf2(mins8, vl));
1380
+ vint32m1_t prod = __riscv_vwmul_vv_i32m1(q8sums, v_mins, vl);
1381
+
1382
+ vint32m1_t sumi = __riscv_vredsum_vs_i32m1_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl);
1383
+ sumf -= dmin * __riscv_vmv_x_s_i32m1_i32(sumi);
1384
+
1385
+ const uint8_t * GGML_RESTRICT q4 = x[i].qs;
1386
+ const int8_t * GGML_RESTRICT q8 = y[i].qs;
1387
+
1388
+ vl = 32;
1389
+
1390
+ int32_t sum_1 = 0;
1391
+ int32_t sum_2 = 0;
1392
+
1393
+ vint16m1_t vzero = __riscv_vmv_v_x_i16m1(0, 1);
1394
+
1395
+ for (int j = 0; j < QK_K/64; ++j) {
1396
+ // load Q4
1397
+ vuint8m1_t q4_x = __riscv_vle8_v_u8m1(q4, vl);
1398
+
1399
+ // load Q8 and multiply it with lower Q4 nibble
1400
+ vint8m1_t q8_0 = __riscv_vle8_v_i8m1(q8, vl);
1401
+ vint8m1_t q4_0 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q4_x, 0x0F, vl));
1402
+ vint16m2_t qv_0 = __riscv_vwmul_vv_i16m2(q4_0, q8_0, vl);
1403
+ vint16m1_t vs_0 = __riscv_vredsum_vs_i16m2_i16m1(qv_0, vzero, vl);
1404
+
1405
+ sum_1 += __riscv_vmv_x_s_i16m1_i16(vs_0) * scales[2*j+0];
1406
+
1407
+ // load Q8 and multiply it with upper Q4 nibble
1408
+ vint8m1_t q8_1 = __riscv_vle8_v_i8m1(q8+32, vl);
1409
+ vint8m1_t q4_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(q4_x, 0x04, vl));
1410
+ vint16m2_t qv_1 = __riscv_vwmul_vv_i16m2(q4_1, q8_1, vl);
1411
+ vint16m1_t vs_1 = __riscv_vredsum_vs_i16m2_i16m1(qv_1, vzero, vl);
1412
+
1413
+ sum_2 += __riscv_vmv_x_s_i16m1_i16(vs_1) * scales[2*j+1];
1414
+
1415
+ q4 += 32; q8 += 64;
1416
+
1417
+ }
1418
+
1419
+ sumf += d*(sum_1 + sum_2);
1420
+
1421
+ }
1422
+ break;
1423
+ case 128:
1424
+ for (int i = 0; i < nb; ++i) {
1425
+ const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
1426
+ const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
1427
+
1428
+ int tmp, tmp2, sumi;
1429
+ __asm__ __volatile__(
1430
+ "vsetivli zero, 12, e8, m1\n\t"
1431
+ "vle8.v v1, (%[s6b])\n\t" // {aux[0], aux[1], aux[2]}
1432
+ "vsetivli zero, 4, e32, m1\n\t"
1433
+ "vslidedown.vi v2, v1, 2\n\t"
1434
+ "vmv1r.v v3, v2\n\t"
1435
+ "vslideup.vi v2, v3, 1\n\t" // {aux[2], aux[2]}
1436
+ "vsetivli zero, 2, e32, m1\n\t"
1437
+ "vmv.v.i v4, 4\n\t"
1438
+ "vand.vx v8, v1, %[kmask1]\n\t"
1439
+ "vslide1up.vx v5, v4, zero\n\t" // {0, 4}
1440
+ "vsrl.vi v6, v1, 6\n\t"
1441
+ "vsrl.vv v7, v2, v5\n\t"
1442
+ "vand.vx v0, v6, %[kmask3]\n\t"
1443
+ "vand.vx v2, v7, %[kmask2]\n\t"
1444
+ "vsll.vi v6, v0, 4\n\t"
1445
+ "li %[t2], 8\n\t"
1446
+ "addi %[t1], %[utmp], 4\n\t"
1447
+ "vor.vv v1, v6, v2\n\t"
1448
+ "vsse32.v v8, (%[utmp]), %[t2]\n\t"
1449
+ "vsse32.v v1, (%[t1]), %[t2]\n\t"
1450
+ "vsetivli zero, 8, e16, m1\n\t"
1451
+ "vle32.v v2, (%[bsums])\n\t"
1452
+ "vnsrl.wi v0, v2, 0\n\t"
1453
+ "vnsrl.wi v1, v2, 16\n\t"
1454
+ "vadd.vv v2, v0, v1\n\t"
1455
+ "vle8.v v3, (%[mins])\n\t"
1456
+ "vzext.vf2 v4, v3\n\t"
1457
+ "vwmul.vv v6, v4, v2\n\t"
1458
+ "vmv.v.x v0, zero\n\t"
1459
+ "vsetivli zero, 8, e32, m2\n\t"
1460
+ "vredsum.vs v0, v6, v0\n\t"
1461
+ "vmv.x.s %[sumi], v0"
1462
+ : [t1] "=&r" (tmp), [t2] "=&r" (tmp2), [sumi] "=&r" (sumi)
1463
+ : [bsums] "r" (y[i].bsums), [mins] "r" (mins), [utmp] "r" (utmp)
1464
+ , [s6b] "r" (x[i].scales), [kmask1] "r" (kmask1)
1465
+ , [kmask2] "r" (kmask2), [kmask3] "r" (kmask3)
1466
+ : "memory"
1467
+ , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
1468
+ , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
1469
+ , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
1470
+ , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
1471
+ );
1472
+ sumf -= dmin * sumi;
1473
+
1474
+ const uint8_t * restrict q4 = x[i].qs;
1475
+ const int8_t * restrict q8 = y[i].qs;
1476
+
1477
+ sumi = 0;
1478
+ const uint8_t * scale = scales;
1479
+
1480
+ for (int j = 0; j < QK_K/128; ++j) {
1481
+ int vl128 = 128, vl64 = 64, vl32 = 32;
1482
+ __asm__ __volatile__(
1483
+ "vsetvli zero, %[vl128], e8, m8\n\t"
1484
+ "vle8.v v8, (%[q8])\n\t"
1485
+ "vsetvli zero, %[vl64], e8, m4\n\t"
1486
+ "vle8.v v0, (%[q4])\n\t"
1487
+ "vsrl.vi v4, v0, 4\n\t"
1488
+ "vand.vi v0, v0, 0xF\n\t"
1489
+ "vsetvli zero, %[vl32], e8, m2\n\t"
1490
+ "vwmul.vv v28, v6, v14\n\t"
1491
+ "vwmul.vv v20, v4, v10\n\t"
1492
+ "vwmul.vv v24, v2, v12\n\t"
1493
+ "vwmul.vv v16, v0, v8\n\t"
1494
+ "vsetivli zero, 4, e32, m1\n\t"
1495
+ "vle8.v v2, (%[scale])\n\t"
1496
+ "vmv.v.x v0, zero\n\t"
1497
+ "vzext.vf4 v1, v2\n\t"
1498
+ "vsetvli zero, %[vl32], e16, m4\n\t"
1499
+ "vwredsum.vs v6, v24, v0\n\t"
1500
+ "vwredsum.vs v7, v28, v0\n\t"
1501
+ "vwredsum.vs v4, v16, v0\n\t"
1502
+ "vwredsum.vs v5, v20, v0\n\t"
1503
+ "vsetivli zero, 4, e32, m1\n\t"
1504
+ "vslideup.vi v6, v7, 1\n\t"
1505
+ "vslideup.vi v4, v5, 1\n\t"
1506
+ "vslideup.vi v4, v6, 2\n\t"
1507
+ "vmul.vv v8, v4, v1\n\t"
1508
+ "vredsum.vs v0, v8, v0\n\t"
1509
+ "vmv.x.s %[tmp], v0\n\t"
1510
+ "add %[sumi], %[sumi], %[tmp]"
1511
+ : [tmp] "=&r" (tmp), [sumi] "+&r" (sumi)
1512
+ : [vl128] "r" (vl128), [vl64] "r" (vl64), [vl32] "r" (vl32)
1513
+ , [q4] "r" (q4), [q8] "r" (q8), [scale] "r" (scale)
1514
+ : "memory"
1515
+ , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
1516
+ , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
1517
+ , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
1518
+ , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
1519
+ );
1520
+
1521
+ q4 += 64; q8 += 128; scale += 4;
1522
+ }
1523
+
1524
+ sumf += d * sumi;
1525
+ }
1526
+ break;
1527
+ default:
1528
+ assert(false && "Unsupported vector length");
1529
+ break;
1530
+ }
1531
+
1532
+ *s = sumf;
1533
+
1534
+ #else
1535
+
1536
+ const uint8_t * scales = (const uint8_t*)&utmp[0];
1537
+ const uint8_t * mins = (const uint8_t*)&utmp[2];
1538
+
1539
+ int8_t aux8[QK_K];
1540
+ int16_t aux16[8];
1541
+ float sums [8];
1542
+ int32_t aux32[8];
1543
+ memset(sums, 0, 8*sizeof(float));
1544
+
1545
+ float sumf = 0;
1546
+ for (int i = 0; i < nb; ++i) {
1547
+ const uint8_t * GGML_RESTRICT q4 = x[i].qs;
1548
+ const int8_t * GGML_RESTRICT q8 = y[i].qs;
1549
+ memset(aux32, 0, 8*sizeof(int32_t));
1550
+ int8_t * GGML_RESTRICT a = aux8;
1551
+ for (int j = 0; j < QK_K/64; ++j) {
1552
+ for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] & 0xF);
1553
+ a += 32;
1554
+ for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] >> 4);
1555
+ a += 32; q4 += 32;
1556
+ }
1557
+ memcpy(utmp, x[i].scales, 12);
1558
+ utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
1559
+ const uint32_t uaux = utmp[1] & kmask1;
1560
+ utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
1561
+ utmp[2] = uaux;
1562
+ utmp[0] &= kmask1;
1563
+
1564
+ int sumi = 0;
1565
+ for (int j = 0; j < QK_K/16; ++j) sumi += y[i].bsums[j] * mins[j/2];
1566
+ a = aux8;
1567
+ int is = 0;
1568
+ for (int j = 0; j < QK_K/32; ++j) {
1569
+ int32_t scale = scales[is++];
1570
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
1571
+ for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
1572
+ q8 += 8; a += 8;
1573
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
1574
+ for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
1575
+ q8 += 8; a += 8;
1576
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
1577
+ for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
1578
+ q8 += 8; a += 8;
1579
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
1580
+ for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
1581
+ q8 += 8; a += 8;
1582
+ }
1583
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
1584
+ for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
1585
+ const float dmin = GGML_FP16_TO_FP32(x[i].dmin) * y[i].d;
1586
+ sumf -= dmin * sumi;
1587
+ }
1588
+ for (int l = 0; l < 8; ++l) sumf += sums[l];
1589
+ *s = sumf;
1590
+ #endif
1591
+ }
1592
+
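For reference, every q4_K and q5_K path above (intrinsics, inline assembly, and the scalar fallback) starts by expanding the 12 packed scale/min bytes of a super-block into eight 6-bit scales and eight 6-bit mins using the kmask1/kmask2/kmask3 shuffle on utmp. The following is a minimal standalone sketch of that step; the helper name unpack_scales_mins_k4 and its signature are illustrative, not part of ggml.

#include <stdint.h>
#include <string.h>

/* Illustrative helper (not in ggml): expand the 12 packed bytes of a
   q4_K/q5_K super-block into 8 scales and 8 mins, mirroring the
   kmask1/kmask2/kmask3 shuffle used by the kernels above. */
static void unpack_scales_mins_k4(const uint8_t packed[12],
                                  uint8_t scales[8], uint8_t mins[8]) {
    const uint32_t kmask1 = 0x3f3f3f3f;
    const uint32_t kmask2 = 0x0f0f0f0f;
    const uint32_t kmask3 = 0x03030303;

    uint32_t utmp[4];
    memcpy(utmp, packed, 12);
    utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
    const uint32_t uaux = utmp[1] & kmask1;
    utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
    utmp[2] = uaux;
    utmp[0] &= kmask1;

    memcpy(scales, &utmp[0], 8);  /* scales occupy utmp[0..1] */
    memcpy(mins,   &utmp[2], 8);  /* mins   occupy utmp[2..3] */
}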
1593
+ void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
1594
+ assert(n % QK_K == 0);
1595
+ assert(nrc == 1);
1596
+ UNUSED(nrc);
1597
+ UNUSED(bx);
1598
+ UNUSED(by);
1599
+ UNUSED(bs);
1600
+
1601
+ const block_q5_K * GGML_RESTRICT x = vx;
1602
+ const block_q8_K * GGML_RESTRICT y = vy;
1603
+
1604
+ const int nb = n / QK_K;
1605
+
1606
+ static const uint32_t kmask1 = 0x3f3f3f3f;
1607
+ static const uint32_t kmask2 = 0x0f0f0f0f;
1608
+ static const uint32_t kmask3 = 0x03030303;
1609
+
1610
+ uint32_t utmp[4];
1611
+
1612
+ #if defined __riscv_v
1613
+
1614
+ const uint8_t * scales = (const uint8_t*)&utmp[0];
1615
+ const uint8_t * mins = (const uint8_t*)&utmp[2];
1616
+
1617
+ float sumf = 0;
1618
+ float sums = 0.0;
1619
+
1620
+ size_t vl;
1621
+
1622
+ for (int i = 0; i < nb; ++i) {
1623
+
1624
+ vl = 8;
1625
+
1626
+ const uint8_t * GGML_RESTRICT q5 = x[i].qs;
1627
+ const uint8_t * GGML_RESTRICT hm = x[i].qh;
1628
+ const int8_t * GGML_RESTRICT q8 = y[i].qs;
1629
+
1630
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
1631
+ const float dmin = GGML_FP16_TO_FP32(x[i].dmin) * y[i].d;
1632
+
1633
+ vint16m1_t q8sums_0 = __riscv_vlse16_v_i16m1(y[i].bsums, 4, vl);
1634
+ vint16m1_t q8sums_1 = __riscv_vlse16_v_i16m1(y[i].bsums+1, 4, vl);
1635
+ vint16m1_t q8sums = __riscv_vadd_vv_i16m1(q8sums_0, q8sums_1, vl);
1636
+
1637
+ memcpy(utmp, x[i].scales, 12);
1638
+ utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
1639
+ const uint32_t uaux = utmp[1] & kmask1;
1640
+ utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
1641
+ utmp[2] = uaux;
1642
+ utmp[0] &= kmask1;
1643
+
1644
+ vuint8mf2_t mins8 = __riscv_vle8_v_u8mf2(mins, vl);
1645
+ vint16m1_t v_mins = __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(mins8, vl));
1646
+ vint32m2_t prod = __riscv_vwmul_vv_i32m2(q8sums, v_mins, vl);
1647
+
1648
+ vint32m1_t sumi = __riscv_vredsum_vs_i32m2_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl);
1649
+ sumf -= dmin * __riscv_vmv_x_s_i32m1_i32(sumi);
1650
+
1651
+ vl = 32;
1652
+ int32_t aux32 = 0;
1653
+ int is = 0;
1654
+
1655
+ uint8_t m = 1;
1656
+ vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1);
1657
+ vuint8m2_t vqh = __riscv_vle8_v_u8m2(hm, vl);
1658
+
1659
+ for (int j = 0; j < QK_K/64; ++j) {
1660
+ // load Q5 and Q8
1661
+ vuint8m2_t q5_x = __riscv_vle8_v_u8m2(q5, vl);
1662
+ vint8m2_t q8_y1 = __riscv_vle8_v_i8m2(q8, vl);
1663
+ vint8m2_t q8_y2 = __riscv_vle8_v_i8m2(q8+32, vl);
1664
+
1665
+ // compute mask for addition
1666
+ vint8m2_t q5_a = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vand_vx_u8m2(q5_x, 0x0F, vl));
1667
+ vuint8m2_t qh_m1 = __riscv_vand_vx_u8m2(vqh, m, vl);
1668
+ vbool4_t vmask_1 = __riscv_vmsne_vx_u8m2_b4(qh_m1, 0, vl);
1669
+ vint8m2_t q5_m1 = __riscv_vadd_vx_i8m2_mu(vmask_1, q5_a, q5_a, 16, vl);
1670
+ m <<= 1;
1671
+
1672
+ vint8m2_t q5_l = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vsrl_vx_u8m2(q5_x, 0x04, vl));
1673
+ vuint8m2_t qh_m2 = __riscv_vand_vx_u8m2(vqh, m, vl);
1674
+ vbool4_t vmask_2 = __riscv_vmsne_vx_u8m2_b4(qh_m2, 0, vl);
1675
+ vint8m2_t q5_m2 = __riscv_vadd_vx_i8m2_mu(vmask_2, q5_l, q5_l, 16, vl);
1676
+ m <<= 1;
1677
+
1678
+ vint16m4_t v0 = __riscv_vwmul_vv_i16m4(q5_m1, q8_y1, vl);
1679
+ vint16m4_t v1 = __riscv_vwmul_vv_i16m4(q5_m2, q8_y2, vl);
1680
+
1681
+ vint32m8_t vs1 = __riscv_vwmul_vx_i32m8(v0, scales[is++], vl);
1682
+ vint32m8_t vs2 = __riscv_vwmul_vx_i32m8(v1, scales[is++], vl);
1683
+
1684
+ vint32m1_t vacc1 = __riscv_vredsum_vs_i32m8_i32m1(vs1, vzero, vl);
1685
+ vint32m1_t vacc2 = __riscv_vredsum_vs_i32m8_i32m1(vs2, vacc1, vl);
1686
+
1687
+ aux32 += __riscv_vmv_x_s_i32m1_i32(vacc2);
1688
+ q5 += 32; q8 += 64;
1689
+
1690
+ }
1691
+
1692
+ sums += aux32 * d;
1693
+
1694
+ }
1695
+
1696
+ *s = sumf+sums;
1697
+
1698
+ #else
1699
+
1700
+ const uint8_t * scales = (const uint8_t*)&utmp[0];
1701
+ const uint8_t * mins = (const uint8_t*)&utmp[2];
1702
+
1703
+ int8_t aux8[QK_K];
1704
+ int16_t aux16[8];
1705
+ float sums [8];
1706
+ int32_t aux32[8];
1707
+ memset(sums, 0, 8*sizeof(float));
1708
+
1709
+ float sumf = 0;
1710
+ for (int i = 0; i < nb; ++i) {
1711
+ const uint8_t * GGML_RESTRICT q4 = x[i].qs;
1712
+ const uint8_t * GGML_RESTRICT hm = x[i].qh;
1713
+ const int8_t * GGML_RESTRICT q8 = y[i].qs;
1714
+ memset(aux32, 0, 8*sizeof(int32_t));
1715
+ int8_t * GGML_RESTRICT a = aux8;
1716
+ uint8_t m = 1;
1717
+ for (int j = 0; j < QK_K/64; ++j) {
1718
+ for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] & 0xF);
1719
+ for (int l = 0; l < 32; ++l) a[l] += (hm[l] & m ? 16 : 0);
1720
+ a += 32; m <<= 1;
1721
+ for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] >> 4);
1722
+ for (int l = 0; l < 32; ++l) a[l] += (hm[l] & m ? 16 : 0);
1723
+ a += 32; m <<= 1;
1724
+ q4 += 32;
1725
+ }
1726
+ memcpy(utmp, x[i].scales, 12);
1727
+ utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
1728
+ const uint32_t uaux = utmp[1] & kmask1;
1729
+ utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
1730
+ utmp[2] = uaux;
1731
+ utmp[0] &= kmask1;
1732
+
1733
+ int sumi = 0;
1734
+ for (int j = 0; j < QK_K/16; ++j) sumi += y[i].bsums[j] * mins[j/2];
1735
+ a = aux8;
1736
+ int is = 0;
1737
+ for (int j = 0; j < QK_K/32; ++j) {
1738
+ int32_t scale = scales[is++];
1739
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
1740
+ for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
1741
+ q8 += 8; a += 8;
1742
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
1743
+ for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
1744
+ q8 += 8; a += 8;
1745
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
1746
+ for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
1747
+ q8 += 8; a += 8;
1748
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
1749
+ for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
1750
+ q8 += 8; a += 8;
1751
+ }
1752
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
1753
+ for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
1754
+ const float dmin = GGML_FP16_TO_FP32(x[i].dmin) * y[i].d;
1755
+ sumf -= dmin * sumi;
1756
+ }
1757
+ for (int l = 0; l < 8; ++l) sumf += sums[l];
1758
+ *s = sumf;
1759
+ #endif
1760
+ }
1761
+
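For reference, the -dmin term in the q4_K/q5_K kernels above is accumulated before the nibble loops: the strided __riscv_vlse16_v loads pair adjacent bsums entries so that each 6-bit min multiplies bsums[2j] + bsums[2j+1], which is equivalent to the scalar fallback's sum over bsums[j] * mins[j/2]. Below is a scalar sketch of that correction, assuming the scales/mins were unpacked as in the previous sketch; the helper name k_quant_min_correction is illustrative, not part of ggml.

#include <stdint.h>

/* Illustrative only: the per-super-block correction the kernels subtract
   (sumf -= dmin * sumi). bsums[16] are the 16-element partial sums of the
   q8_K block, mins[8] come from the scale/min unpacking, and dmin is the
   combined y->d * GGML_FP16_TO_FP32(x->dmin) computed by the caller. */
static float k_quant_min_correction(const int16_t bsums[16],
                                    const uint8_t mins[8], float dmin) {
    int32_t sumi = 0;
    for (int j = 0; j < 8; ++j) {
        /* each min covers 32 weights, i.e. two adjacent bsums entries,
           matching the strided vector load + vadd in the RVV paths */
        sumi += (int32_t)(bsums[2*j] + bsums[2*j + 1]) * mins[j];
    }
    return dmin * (float)sumi;
}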
1762
+ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
1763
+ assert(n % QK_K == 0);
1764
+ assert(nrc == 1);
1765
+ UNUSED(nrc);
1766
+ UNUSED(bx);
1767
+ UNUSED(by);
1768
+ UNUSED(bs);
1769
+
1770
+ const block_q6_K * GGML_RESTRICT x = vx;
1771
+ const block_q8_K * GGML_RESTRICT y = vy;
1772
+
1773
+ const int nb = n / QK_K;
1774
+
1775
+ #if defined __riscv_xtheadvector
1776
+
1777
+ float sumf = 0;
1778
+
1779
+ for (int i = 0; i < nb; ++i) {
1780
+
1781
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
1782
+
1783
+ const uint8_t * restrict q6 = x[i].ql;
1784
+ const uint8_t * restrict qh = x[i].qh;
1785
+ const int8_t * restrict q8 = y[i].qs;
1786
+
1787
+ const int8_t * restrict scale = x[i].scales;
1788
+
1789
+ int sum_t = 0;
1790
+ int t0;
1791
+
1792
+ for (int j = 0; j < QK_K/128; ++j) {
1793
+ __asm__ __volatile__(
1794
+ "th.vsetvli zero, %[vl32], e8, m2\n\t" // vl == 32
1795
+ "th.vlb.v v4, (%[qh])\n\t"
1796
+ "th.vsll.vi v0, v4, 4\n\t"
1797
+ "th.vsll.vi v2, v4, 2\n\t"
1798
+ "th.vsrl.vi v6, v4, 2\n\t"
1799
+ "th.vsetvli zero, %[vl64], e8, m4\n\t" // vl == 64
1800
+ "th.vlb.v v8, (%[q6])\n\t"
1801
+ "th.vsrl.vi v12, v8, 4\n\t"
1802
+ "th.vand.vi v8, v8, 0xF\n\t"
1803
+ "th.vsetvli zero, %[vl128], e8, m8\n\t" // vl == 128
1804
+ "th.vand.vx v0, v0, %[mask]\n\t"
1805
+ "th.vor.vv v8, v8, v0\n\t"
1806
+ "th.vlb.v v0, (%[q8])\n\t"
1807
+ "th.vsub.vx v8, v8, %[vl32]\n\t"
1808
+ "th.vsetvli zero, %[vl64], e8, m4\n\t" // vl == 64
1809
+ "th.vwmul.vv v16, v0, v8\n\t"
1810
+ "th.vwmul.vv v24, v4, v12\n\t"
1811
+ "li %[t0], 16\n\t"
1812
+ "th.vsetvli zero, %[t0], e16, m2\n\t" // vl == 16
1813
+ "th.vmv.v.x v0, zero\n\t"
1814
+ "th.vwredsum.vs v10, v16, v0\n\t"
1815
+ "th.vwredsum.vs v9, v18, v0\n\t"
1816
+ "th.vwredsum.vs v8, v20, v0\n\t"
1817
+ "th.vwredsum.vs v7, v22, v0\n\t"
1818
+ "th.vwredsum.vs v11, v24, v0\n\t"
1819
+ "th.vwredsum.vs v12, v26, v0\n\t"
1820
+ "th.vwredsum.vs v13, v28, v0\n\t"
1821
+ "th.vwredsum.vs v14, v30, v0\n\t"
1822
+ "li %[t0], 4\n\t"
1823
+ "th.vsetvli zero, %[t0], e32, m1\n\t" // vl == 4
1824
+ "th.vslideup.vi v10, v9, 1\n\t"
1825
+ "th.vslideup.vi v8, v7, 1\n\t"
1826
+ "th.vslideup.vi v11, v12, 1\n\t"
1827
+ "th.vslideup.vi v13, v14, 1\n\t"
1828
+ "th.vslideup.vi v10, v8, 2\n\t"
1829
+ "th.vslideup.vi v11, v13, 2\n\t"
1830
+ "li %[t0], 8\n\t"
1831
+ "th.vsetvli zero, %[t0], e32, m2\n\t" // vl == 8
1832
+ "th.vlb.v v4, (%[scale])\n\t"
1833
+ "th.vmul.vv v2, v4, v10\n\t"
1834
+ "th.vredsum.vs v0, v2, v0\n\t"
1835
+ "th.vmv.x.s %[t0], v0\n\t"
1836
+ "add %[sumi], %[sumi], %[t0]"
1837
+ : [sumi] "+&r" (sum_t), [t0] "=&r" (t0)
1838
+ : [qh] "r" (qh), [q6] "r" (q6), [q8] "r" (q8), [scale] "r" (scale)
1839
+ , [vl32] "r" (32), [vl64] "r" (64), [vl128] "r" (128)
1840
+ , [mask] "r" (0x30)
1841
+ : "memory"
1842
+ , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
1843
+ , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
1844
+ , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
1845
+ , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
1846
+ );
1847
+ q6 += 64; qh += 32; q8 += 128; scale += 8;
1848
+ }
1849
+
1850
+ sumf += d * sum_t;
1851
+
1852
+ }
1853
+
1854
+ *s = sumf;
1855
+
1856
+ #elif defined __riscv_v
1857
+
1858
+ float sumf = 0;
1859
+ const int vector_length = __riscv_vlenb() * 8;
1860
+
1861
+ switch (vector_length) {
1862
+ case 256:
1863
+ for (int i = 0; i < nb; ++i) {
1864
+
1865
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
1866
+
1867
+ const uint8_t * GGML_RESTRICT q6 = x[i].ql;
1868
+ const uint8_t * GGML_RESTRICT qh = x[i].qh;
1869
+ const int8_t * GGML_RESTRICT q8 = y[i].qs;
1870
+
1871
+ const int8_t * GGML_RESTRICT scale = x[i].scales;
1872
+
1873
+ size_t vl;
1874
+
1875
+ vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1);
1876
+
1877
+ int sum_t = 0;
1878
+ int is = 0;
1879
+
1880
+ for (int j = 0; j < QK_K/128; ++j) {
1881
+
1882
+ vl = 32;
1883
+
1884
+ // load qh
1885
+ vuint8m1_t qh_x = __riscv_vle8_v_u8m1(qh, vl);
1886
+
1887
+ // load Q6
1888
+ vuint8m1_t q6_0 = __riscv_vle8_v_u8m1(q6, vl);
1889
+ vuint8m1_t q6_1 = __riscv_vle8_v_u8m1(q6+32, vl);
1890
+
1891
+ vuint8m1_t q6a_0 = __riscv_vand_vx_u8m1(q6_0, 0x0F, vl);
1892
+ vuint8m1_t q6a_1 = __riscv_vand_vx_u8m1(q6_1, 0x0F, vl);
1893
+ vuint8m1_t q6s_0 = __riscv_vsrl_vx_u8m1(q6_0, 0x04, vl);
1894
+ vuint8m1_t q6s_1 = __riscv_vsrl_vx_u8m1(q6_1, 0x04, vl);
1895
+
1896
+ vuint8m1_t qh_0 = __riscv_vand_vx_u8m1(qh_x, 0x03, vl);
1897
+ vuint8m1_t qh_1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x2, vl), 0x03 , vl);
1898
+ vuint8m1_t qh_2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x4, vl), 0x03 , vl);
1899
+ vuint8m1_t qh_3 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x6, vl), 0x03 , vl);
1900
+
1901
+ vuint8m1_t qhi_0 = __riscv_vor_vv_u8m1(q6a_0, __riscv_vsll_vx_u8m1(qh_0, 0x04, vl), vl);
1902
+ vuint8m1_t qhi_1 = __riscv_vor_vv_u8m1(q6a_1, __riscv_vsll_vx_u8m1(qh_1, 0x04, vl), vl);
1903
+ vuint8m1_t qhi_2 = __riscv_vor_vv_u8m1(q6s_0, __riscv_vsll_vx_u8m1(qh_2, 0x04, vl), vl);
1904
+ vuint8m1_t qhi_3 = __riscv_vor_vv_u8m1(q6s_1, __riscv_vsll_vx_u8m1(qh_3, 0x04, vl), vl);
1905
+
1906
+ vint8m1_t a_0 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_0), 32, vl);
1907
+ vint8m1_t a_1 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_1), 32, vl);
1908
+ vint8m1_t a_2 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_2), 32, vl);
1909
+ vint8m1_t a_3 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_3), 32, vl);
1910
+
1911
+ // load Q8 and take product
1912
+ vint16m2_t va_q_0 = __riscv_vwmul_vv_i16m2(a_0, __riscv_vle8_v_i8m1(q8, vl), vl);
1913
+ vint16m2_t va_q_1 = __riscv_vwmul_vv_i16m2(a_1, __riscv_vle8_v_i8m1(q8+32, vl), vl);
1914
+ vint16m2_t va_q_2 = __riscv_vwmul_vv_i16m2(a_2, __riscv_vle8_v_i8m1(q8+64, vl), vl);
1915
+ vint16m2_t va_q_3 = __riscv_vwmul_vv_i16m2(a_3, __riscv_vle8_v_i8m1(q8+96, vl), vl);
1916
+
1917
+ vl = 16;
1918
+
1919
+ vint32m2_t vaux_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 0), scale[is+0], vl);
1920
+ vint32m2_t vaux_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 1), scale[is+1], vl);
1921
+ vint32m2_t vaux_2 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 0), scale[is+2], vl);
1922
+ vint32m2_t vaux_3 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 1), scale[is+3], vl);
1923
+ vint32m2_t vaux_4 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 0), scale[is+4], vl);
1924
+ vint32m2_t vaux_5 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 1), scale[is+5], vl);
1925
+ vint32m2_t vaux_6 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 0), scale[is+6], vl);
1926
+ vint32m2_t vaux_7 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 1), scale[is+7], vl);
1927
+
1928
+ vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_0, vaux_1, vl), vzero, vl);
1929
+ vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_2, vaux_3, vl), isum0, vl);
1930
+ vint32m1_t isum2 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_4, vaux_5, vl), isum1, vl);
1931
+ vint32m1_t isum3 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_6, vaux_7, vl), isum2, vl);
1932
+
1933
+ sum_t += __riscv_vmv_x_s_i32m1_i32(isum3);
1934
+
1935
+ q6 += 64; qh += 32; q8 += 128; is=8;
1936
+
1937
+ }
1938
+
1939
+ sumf += d * sum_t;
1940
+
1941
+ }
1942
+ break;
1943
+ case 128:
1944
+ for (int i = 0; i < nb; ++i) {
1945
+
1946
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
1947
+
1948
+ const uint8_t * restrict q6 = x[i].ql;
1949
+ const uint8_t * restrict qh = x[i].qh;
1950
+ const int8_t * restrict q8 = y[i].qs;
1951
+
1952
+ const int8_t * restrict scale = x[i].scales;
1953
+
1954
+ int sum_t = 0;
1955
+ int t0;
1956
+
1957
+ for (int j = 0; j < QK_K/128; ++j) {
1958
+ __asm__ __volatile__(
1959
+ "vsetvli zero, %[vl32], e8, m2\n\t"
1960
+ "vle8.v v4, (%[qh])\n\t"
1961
+ "vsll.vi v0, v4, 4\n\t"
1962
+ "vsll.vi v2, v4, 2\n\t"
1963
+ "vsrl.vi v6, v4, 2\n\t"
1964
+ "vsetvli zero, %[vl64], e8, m4\n\t"
1965
+ "vle8.v v8, (%[q6])\n\t"
1966
+ "vsrl.vi v12, v8, 4\n\t"
1967
+ "vand.vi v8, v8, 0xF\n\t"
1968
+ "vsetvli zero, %[vl128], e8, m8\n\t"
1969
+ "vand.vx v0, v0, %[mask]\n\t"
1970
+ "vor.vv v8, v8, v0\n\t"
1971
+ "vle8.v v0, (%[q8])\n\t"
1972
+ "vsub.vx v8, v8, %[vl32]\n\t"
1973
+ "vsetvli zero, %[vl64], e8, m4\n\t"
1974
+ "vwmul.vv v16, v0, v8\n\t"
1975
+ "vwmul.vv v24, v4, v12\n\t"
1976
+ "vsetivli zero, 16, e16, m2\n\t"
1977
+ "vmv.v.x v0, zero\n\t"
1978
+ "vwredsum.vs v10, v16, v0\n\t"
1979
+ "vwredsum.vs v9, v18, v0\n\t"
1980
+ "vwredsum.vs v8, v20, v0\n\t"
1981
+ "vwredsum.vs v7, v22, v0\n\t"
1982
+ "vwredsum.vs v11, v24, v0\n\t"
1983
+ "vwredsum.vs v12, v26, v0\n\t"
1984
+ "vwredsum.vs v13, v28, v0\n\t"
1985
+ "vwredsum.vs v14, v30, v0\n\t"
1986
+ "vsetivli zero, 4, e32, m1\n\t"
1987
+ "vslideup.vi v10, v9, 1\n\t"
1988
+ "vslideup.vi v8, v7, 1\n\t"
1989
+ "vslideup.vi v11, v12, 1\n\t"
1990
+ "vslideup.vi v13, v14, 1\n\t"
1991
+ "vslideup.vi v10, v8, 2\n\t"
1992
+ "vslideup.vi v11, v13, 2\n\t"
1993
+ "vsetivli zero, 8, e32, m2\n\t"
1994
+ "vle8.v v2, (%[scale])\n\t"
1995
+ "vsext.vf4 v4, v2\n\t"
1996
+ "vmul.vv v2, v4, v10\n\t"
1997
+ "vredsum.vs v0, v2, v0\n\t"
1998
+ "vmv.x.s %[t0], v0\n\t"
1999
+ "add %[sumi], %[sumi], %[t0]"
2000
+ : [sumi] "+&r" (sum_t), [t0] "=&r" (t0)
2001
+ : [qh] "r" (qh), [q6] "r" (q6), [q8] "r" (q8), [scale] "r" (scale)
2002
+ , [vl32] "r" (32), [vl64] "r" (64), [vl128] "r" (128)
2003
+ , [mask] "r" (0x30)
2004
+ : "memory"
2005
+ , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
2006
+ , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
2007
+ , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
2008
+ , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
2009
+ );
2010
+ q6 += 64; qh += 32; q8 += 128; scale += 8;
2011
+ }
2012
+
2013
+ sumf += d * sum_t;
2014
+
2015
+ }
2016
+ break;
2017
+ default:
2018
+ assert(false && "Unsupported vector length");
2019
+ break;
2020
+ }
2021
+
2022
+ *s = sumf;
2023
+
2024
+ #else
2025
+
2026
+ int8_t aux8[QK_K];
2027
+ int16_t aux16[8];
2028
+ float sums [8];
2029
+ int32_t aux32[8];
2030
+ memset(sums, 0, 8*sizeof(float));
2031
+
2032
+ float sumf = 0;
2033
+ for (int i = 0; i < nb; ++i) {
2034
+ const uint8_t * GGML_RESTRICT q4 = x[i].ql;
2035
+ const uint8_t * GGML_RESTRICT qh = x[i].qh;
2036
+ const int8_t * GGML_RESTRICT q8 = y[i].qs;
2037
+ memset(aux32, 0, 8*sizeof(int32_t));
2038
+ int8_t * GGML_RESTRICT a = aux8;
2039
+ for (int j = 0; j < QK_K; j += 128) {
2040
+ for (int l = 0; l < 32; ++l) {
2041
+ a[l + 0] = (int8_t)((q4[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
2042
+ a[l + 32] = (int8_t)((q4[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
2043
+ a[l + 64] = (int8_t)((q4[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
2044
+ a[l + 96] = (int8_t)((q4[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
2045
+ }
2046
+ a += 128;
2047
+ q4 += 64;
2048
+ qh += 32;
2049
+ }
2050
+ a = aux8;
2051
+ int is = 0;
2052
+ for (int j = 0; j < QK_K/16; ++j) {
2053
+ int scale = x[i].scales[is++];
2054
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
2055
+ for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
2056
+ q8 += 8; a += 8;
2057
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
2058
+ for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
2059
+ q8 += 8; a += 8;
2060
+ }
2061
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
2062
+ for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
2063
+ }
2064
+ for (int l = 0; l < 8; ++l) sumf += sums[l];
2065
+ *s = sumf;
2066
+ #endif
2067
+ }
2068
+
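For reference, the q6_K kernels above rebuild each signed 6-bit weight by combining a 4-bit field from ql with a 2-bit field from qh and re-centring by 32, and the RVV paths pick the 256-bit or 128-bit variant at run time from __riscv_vlenb(). A one-line scalar sketch of that weight reconstruction follows; the helper name q6k_weight is illustrative, not part of ggml.

#include <stdint.h>

/* Illustrative only: reconstruct one signed q6_K weight. ql_nibble is the
   selected 4-bit field of the ql byte (low or high nibble), qh_byte holds
   the packed high bits, and pair_shift is 0, 2, 4 or 6, matching the
   ((ql & 0xF) | (((qh >> shift) & 3) << 4)) - 32 pattern in the fallback. */
static inline int8_t q6k_weight(uint8_t ql_nibble, uint8_t qh_byte, int pair_shift) {
    return (int8_t)(((ql_nibble & 0xF) | (((qh_byte >> pair_shift) & 3) << 4)) - 32);
}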