llama_cpp 0.9.0 → 0.9.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,9 +1,10 @@
- #include "k_quants.h"
- #include "ggml.h"
+ #include "ggml-quants.h"
+ #include "ggml-impl.h"
 
  #include <math.h>
  #include <string.h>
  #include <assert.h>
+ #include <float.h>
 
  #ifdef __ARM_NEON
 
@@ -65,1251 +66,3480 @@ inline static int32_t vaddvq_s32(int32x4_t v) {
 
  #define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
 
- //
- // 2-6 bit quantization in super-blocks
- //
+ #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__)
+ // multiply int8_t, add results pairwise twice
+ static inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) {
+ // Get absolute values of x vectors
+ const __m128i ax = _mm_sign_epi8(x, x);
+ // Sign the values of the y vectors
+ const __m128i sy = _mm_sign_epi8(y, x);
+ // Perform multiplication and create 16-bit values
+ const __m128i dot = _mm_maddubs_epi16(ax, sy);
+ const __m128i ones = _mm_set1_epi16(1);
+ return _mm_madd_epi16(ones, dot);
+ }
 
- //
- // ===================== Helper functions
- //
- static inline int nearest_int(float fval) {
- assert(fval <= 4194303.f);
- float val = fval + 12582912.f;
- int i; memcpy(&i, &val, sizeof(int));
- return (i & 0x007fffff) - 0x00400000;
+ #if __AVX__ || __AVX2__ || __AVX512F__
+ // horizontally add 8 floats
+ static inline float hsum_float_8(const __m256 x) {
+ __m128 res = _mm256_extractf128_ps(x, 1);
+ res = _mm_add_ps(res, _mm256_castps256_ps128(x));
+ res = _mm_add_ps(res, _mm_movehl_ps(res, res));
+ res = _mm_add_ss(res, _mm_movehdup_ps(res));
+ return _mm_cvtss_f32(res);
  }
 
- static float make_qx_quants(int n, int nmax, const float * restrict x, int8_t * restrict L, int rmse_type) {
- float max = 0;
- float amax = 0;
- for (int i = 0; i < n; ++i) {
- float ax = fabsf(x[i]);
- if (ax > amax) { amax = ax; max = x[i]; }
- }
- if (amax < 1e-30f) { // all zero
- for (int i = 0; i < n; ++i) {
- L[i] = 0;
+ // horizontally add 8 int32_t
+ static inline int hsum_i32_8(const __m256i a) {
+ const __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(a), _mm256_extractf128_si256(a, 1));
+ const __m128i hi64 = _mm_unpackhi_epi64(sum128, sum128);
+ const __m128i sum64 = _mm_add_epi32(hi64, sum128);
+ const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
+ return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
+ }
+
+ // horizontally add 4 int32_t
+ static inline int hsum_i32_4(const __m128i a) {
+ const __m128i hi64 = _mm_unpackhi_epi64(a, a);
+ const __m128i sum64 = _mm_add_epi32(hi64, a);
+ const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
+ return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
+ }
+
+ #if defined(__AVX2__) || defined(__AVX512F__)
+ // spread 32 bits to 32 bytes { 0x00, 0xFF }
+ static inline __m256i bytes_from_bits_32(const uint8_t * x) {
+ uint32_t x32;
+ memcpy(&x32, x, sizeof(uint32_t));
+ const __m256i shuf_mask = _mm256_set_epi64x(
+ 0x0303030303030303, 0x0202020202020202,
+ 0x0101010101010101, 0x0000000000000000);
+ __m256i bytes = _mm256_shuffle_epi8(_mm256_set1_epi32(x32), shuf_mask);
+ const __m256i bit_mask = _mm256_set1_epi64x(0x7fbfdfeff7fbfdfe);
+ bytes = _mm256_or_si256(bytes, bit_mask);
+ return _mm256_cmpeq_epi8(bytes, _mm256_set1_epi64x(-1));
+ }
+
+ // Unpack 32 4-bit fields into 32 bytes
+ // The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
+ static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)
+ {
+ const __m128i tmp = _mm_loadu_si128((const __m128i *)rsi);
+ const __m256i bytes = MM256_SET_M128I(_mm_srli_epi16(tmp, 4), tmp);
+ const __m256i lowMask = _mm256_set1_epi8( 0xF );
+ return _mm256_and_si256(lowMask, bytes);
+ }
+
+ // add int16_t pairwise and return as float vector
+ static inline __m256 sum_i16_pairs_float(const __m256i x) {
+ const __m256i ones = _mm256_set1_epi16(1);
+ const __m256i summed_pairs = _mm256_madd_epi16(ones, x);
+ return _mm256_cvtepi32_ps(summed_pairs);
+ }
+
+ static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) {
+ #if __AVXVNNI__
+ const __m256i zero = _mm256_setzero_si256();
+ const __m256i summed_pairs = _mm256_dpbusd_epi32(zero, ax, sy);
+ return _mm256_cvtepi32_ps(summed_pairs);
+ #else
+ // Perform multiplication and create 16-bit values
+ const __m256i dot = _mm256_maddubs_epi16(ax, sy);
+ return sum_i16_pairs_float(dot);
+ #endif
+ }
+
+ // multiply int8_t, add results pairwise twice and return as float vector
+ static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
+ #if __AVXVNNIINT8__
+ const __m256i zero = _mm256_setzero_si256();
+ const __m256i summed_pairs = _mm256_dpbssd_epi32(zero, x, y);
+ return _mm256_cvtepi32_ps(summed_pairs);
+ #else
+ // Get absolute values of x vectors
+ const __m256i ax = _mm256_sign_epi8(x, x);
+ // Sign the values of the y vectors
+ const __m256i sy = _mm256_sign_epi8(y, x);
+ return mul_sum_us8_pairs_float(ax, sy);
+ #endif
+ }
+
+ static inline __m128i packNibbles( __m256i bytes )
+ {
+ // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
+ #if __AVX512F__
+ const __m256i bytes_srli_4 = _mm256_srli_epi16(bytes, 4); // 0000_0000_abcd_0000
+ bytes = _mm256_or_si256(bytes, bytes_srli_4); // 0000_abcd_abcd_efgh
+ return _mm256_cvtepi16_epi8(bytes); // abcd_efgh
+ #else
+ const __m256i lowByte = _mm256_set1_epi16( 0xFF );
+ __m256i high = _mm256_andnot_si256( lowByte, bytes );
+ __m256i low = _mm256_and_si256( lowByte, bytes );
+ high = _mm256_srli_epi16( high, 4 );
+ bytes = _mm256_or_si256( low, high );
+
+ // Compress uint16_t lanes into bytes
+ __m128i r0 = _mm256_castsi256_si128( bytes );
+ __m128i r1 = _mm256_extracti128_si256( bytes, 1 );
+ return _mm_packus_epi16( r0, r1 );
+ #endif
+ }
+ #elif defined(__AVX__)
+ // spread 32 bits to 32 bytes { 0x00, 0xFF }
+ static inline __m256i bytes_from_bits_32(const uint8_t * x) {
+ uint32_t x32;
+ memcpy(&x32, x, sizeof(uint32_t));
+ const __m128i shuf_maskl = _mm_set_epi64x(0x0101010101010101, 0x0000000000000000);
+ const __m128i shuf_maskh = _mm_set_epi64x(0x0303030303030303, 0x0202020202020202);
+ __m128i bytesl = _mm_shuffle_epi8(_mm_set1_epi32(x32), shuf_maskl);
+ __m128i bytesh = _mm_shuffle_epi8(_mm_set1_epi32(x32), shuf_maskh);
+ const __m128i bit_mask = _mm_set1_epi64x(0x7fbfdfeff7fbfdfe);
+ bytesl = _mm_or_si128(bytesl, bit_mask);
+ bytesh = _mm_or_si128(bytesh, bit_mask);
+ bytesl = _mm_cmpeq_epi8(bytesl, _mm_set1_epi64x(-1));
+ bytesh = _mm_cmpeq_epi8(bytesh, _mm_set1_epi64x(-1));
+ return MM256_SET_M128I(bytesh, bytesl);
+ }
+
+ // Unpack 32 4-bit fields into 32 bytes
+ // The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
+ static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)
+ {
+ // Load 16 bytes from memory
+ __m128i tmpl = _mm_loadu_si128((const __m128i *)rsi);
+ __m128i tmph = _mm_srli_epi16(tmpl, 4);
+ const __m128i lowMask = _mm_set1_epi8(0xF);
+ tmpl = _mm_and_si128(lowMask, tmpl);
+ tmph = _mm_and_si128(lowMask, tmph);
+ return MM256_SET_M128I(tmph, tmpl);
+ }
+
+ // add int16_t pairwise and return as float vector
+ static inline __m256 sum_i16_pairs_float(const __m128i xh, const __m128i xl) {
+ const __m128i ones = _mm_set1_epi16(1);
+ const __m128i summed_pairsl = _mm_madd_epi16(ones, xl);
+ const __m128i summed_pairsh = _mm_madd_epi16(ones, xh);
+ const __m256i summed_pairs = MM256_SET_M128I(summed_pairsh, summed_pairsl);
+ return _mm256_cvtepi32_ps(summed_pairs);
+ }
+
+ static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) {
+ const __m128i axl = _mm256_castsi256_si128(ax);
+ const __m128i axh = _mm256_extractf128_si256(ax, 1);
+ const __m128i syl = _mm256_castsi256_si128(sy);
+ const __m128i syh = _mm256_extractf128_si256(sy, 1);
+ // Perform multiplication and create 16-bit values
+ const __m128i dotl = _mm_maddubs_epi16(axl, syl);
+ const __m128i doth = _mm_maddubs_epi16(axh, syh);
+ return sum_i16_pairs_float(doth, dotl);
+ }
+
+ // multiply int8_t, add results pairwise twice and return as float vector
+ static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
+ const __m128i xl = _mm256_castsi256_si128(x);
+ const __m128i xh = _mm256_extractf128_si256(x, 1);
+ const __m128i yl = _mm256_castsi256_si128(y);
+ const __m128i yh = _mm256_extractf128_si256(y, 1);
+ // Get absolute values of x vectors
+ const __m128i axl = _mm_sign_epi8(xl, xl);
+ const __m128i axh = _mm_sign_epi8(xh, xh);
+ // Sign the values of the y vectors
+ const __m128i syl = _mm_sign_epi8(yl, xl);
+ const __m128i syh = _mm_sign_epi8(yh, xh);
+ // Perform multiplication and create 16-bit values
+ const __m128i dotl = _mm_maddubs_epi16(axl, syl);
+ const __m128i doth = _mm_maddubs_epi16(axh, syh);
+ return sum_i16_pairs_float(doth, dotl);
+ }
+
+ static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
+ {
+ // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
+ const __m128i lowByte = _mm_set1_epi16( 0xFF );
+ __m128i high = _mm_andnot_si128( lowByte, bytes1 );
+ __m128i low = _mm_and_si128( lowByte, bytes1 );
+ high = _mm_srli_epi16( high, 4 );
+ bytes1 = _mm_or_si128( low, high );
+ high = _mm_andnot_si128( lowByte, bytes2 );
+ low = _mm_and_si128( lowByte, bytes2 );
+ high = _mm_srli_epi16( high, 4 );
+ bytes2 = _mm_or_si128( low, high );
+
+ return _mm_packus_epi16( bytes1, bytes2);
+ }
+ #endif
+ #elif defined(__SSSE3__)
+ // horizontally add 4x4 floats
+ static inline float hsum_float_4x4(const __m128 a, const __m128 b, const __m128 c, const __m128 d) {
+ __m128 res_0 =_mm_hadd_ps(a, b);
+ __m128 res_1 =_mm_hadd_ps(c, d);
+ __m128 res =_mm_hadd_ps(res_0, res_1);
+ res =_mm_hadd_ps(res, res);
+ res =_mm_hadd_ps(res, res);
+
+ return _mm_cvtss_f32(res);
+ }
+ #endif // __AVX__ || __AVX2__ || __AVX512F__
+ #endif // defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__)
+
+ #if defined(__ARM_NEON)
+
+ #if !defined(__aarch64__)
+
+ inline static int32_t vaddvq_s32(int32x4_t v) {
+ return vgetq_lane_s32(v, 0) + vgetq_lane_s32(v, 1) + vgetq_lane_s32(v, 2) + vgetq_lane_s32(v, 3);
+ }
+
+ inline static float vaddvq_f32(float32x4_t v) {
+ return vgetq_lane_f32(v, 0) + vgetq_lane_f32(v, 1) + vgetq_lane_f32(v, 2) + vgetq_lane_f32(v, 3);
+ }
+
+ inline static float vmaxvq_f32(float32x4_t v) {
+ return
+ MAX(MAX(vgetq_lane_f32(v, 0), vgetq_lane_f32(v, 1)),
+ MAX(vgetq_lane_f32(v, 2), vgetq_lane_f32(v, 3)));
+ }
+
+ inline static int32x4_t vcvtnq_s32_f32(float32x4_t v) {
+ int32x4_t res;
+
+ res[0] = roundf(vgetq_lane_f32(v, 0));
+ res[1] = roundf(vgetq_lane_f32(v, 1));
+ res[2] = roundf(vgetq_lane_f32(v, 2));
+ res[3] = roundf(vgetq_lane_f32(v, 3));
+
+ return res;
+ }
+
+ #endif
+ #endif
+
+ #if defined(__ARM_NEON) || defined(__wasm_simd128__)
+ #define B1(c,s,n) 0x ## n ## c , 0x ## n ## s
+ #define B2(c,s,n) B1(c,s,n ## c), B1(c,s,n ## s)
+ #define B3(c,s,n) B2(c,s,n ## c), B2(c,s,n ## s)
+ #define B4(c,s,n) B3(c,s,n ## c), B3(c,s,n ## s)
+ #define B5(c,s,n) B4(c,s,n ## c), B4(c,s,n ## s)
+ #define B6(c,s,n) B5(c,s,n ## c), B5(c,s,n ## s)
+ #define B7(c,s,n) B6(c,s,n ## c), B6(c,s,n ## s)
+ #define B8(c,s ) B7(c,s, c), B7(c,s, s)
+
+ // precomputed tables for expanding 8bits to 8 bytes:
+ static const uint64_t table_b2b_0[1 << 8] = { B8(00, 10) }; // ( b) << 4
+ static const uint64_t table_b2b_1[1 << 8] = { B8(10, 00) }; // (!b) << 4
+ #endif
+
+ // reference implementation for deterministic creation of model files
+ void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k) {
+ static const int qk = QK4_0;
+
+ assert(k % qk == 0);
+
+ const int nb = k / qk;
+
+ for (int i = 0; i < nb; i++) {
+ float amax = 0.0f; // absolute max
+ float max = 0.0f;
+
+ for (int j = 0; j < qk; j++) {
+ const float v = x[i*qk + j];
+ if (amax < fabsf(v)) {
+ amax = fabsf(v);
+ max = v;
+ }
  }
- return 0.f;
- }
- float iscale = -nmax / max;
- if (rmse_type == 0) {
- for (int i = 0; i < n; ++i) {
- int l = nearest_int(iscale * x[i]);
- L[i] = nmax + MAX(-nmax, MIN(nmax-1, l));
+
+ const float d = max / -8;
+ const float id = d ? 1.0f/d : 0.0f;
+
+ y[i].d = GGML_FP32_TO_FP16(d);
+
+ for (int j = 0; j < qk/2; ++j) {
+ const float x0 = x[i*qk + 0 + j]*id;
+ const float x1 = x[i*qk + qk/2 + j]*id;
+
+ const uint8_t xi0 = MIN(15, (int8_t)(x0 + 8.5f));
+ const uint8_t xi1 = MIN(15, (int8_t)(x1 + 8.5f));
+
+ y[i].qs[j] = xi0;
+ y[i].qs[j] |= xi1 << 4;
  }
- return 1/iscale;
- }
- bool return_early = false;
- if (rmse_type < 0) {
- rmse_type = -rmse_type;
- return_early = true;
- }
- int weight_type = rmse_type%2;
- float sumlx = 0;
- float suml2 = 0;
- for (int i = 0; i < n; ++i) {
- int l = nearest_int(iscale * x[i]);
- l = MAX(-nmax, MIN(nmax-1, l));
- L[i] = l + nmax;
- float w = weight_type == 1 ? x[i] * x[i] : 1;
- sumlx += w*x[i]*l;
- suml2 += w*l*l;
  }
- float scale = sumlx/suml2;
- if (return_early) return suml2 > 0 ? 0.5f*(scale + 1/iscale) : 1/iscale;
- float best = scale * sumlx;
- for (int is = -9; is <= 9; ++is) {
- if (is == 0) {
- continue;
+ }
+
+ void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) {
+ quantize_row_q4_0_reference(x, y, k);
+ }
+
+ void quantize_row_q4_1_reference(const float * restrict x, block_q4_1 * restrict y, int k) {
+ const int qk = QK4_1;
+
+ assert(k % qk == 0);
+
+ const int nb = k / qk;
+
+ for (int i = 0; i < nb; i++) {
+ float min = FLT_MAX;
+ float max = -FLT_MAX;
+
+ for (int j = 0; j < qk; j++) {
+ const float v = x[i*qk + j];
+
+ if (v < min) min = v;
+ if (v > max) max = v;
  }
- iscale = -(nmax + 0.1f*is) / max;
- sumlx = suml2 = 0;
- for (int i = 0; i < n; ++i) {
- int l = nearest_int(iscale * x[i]);
- l = MAX(-nmax, MIN(nmax-1, l));
- float w = weight_type == 1 ? x[i] * x[i] : 1;
- sumlx += w*x[i]*l;
- suml2 += w*l*l;
+
+ const float d = (max - min) / ((1 << 4) - 1);
+ const float id = d ? 1.0f/d : 0.0f;
+
+ y[i].d = GGML_FP32_TO_FP16(d);
+ y[i].m = GGML_FP32_TO_FP16(min);
+
+ for (int j = 0; j < qk/2; ++j) {
+ const float x0 = (x[i*qk + 0 + j] - min)*id;
+ const float x1 = (x[i*qk + qk/2 + j] - min)*id;
+
+ const uint8_t xi0 = MIN(15, (int8_t)(x0 + 0.5f));
+ const uint8_t xi1 = MIN(15, (int8_t)(x1 + 0.5f));
+
+ y[i].qs[j] = xi0;
+ y[i].qs[j] |= xi1 << 4;
  }
- if (suml2 > 0 && sumlx*sumlx > best*suml2) {
- for (int i = 0; i < n; ++i) {
- int l = nearest_int(iscale * x[i]);
- L[i] = nmax + MAX(-nmax, MIN(nmax-1, l));
+ }
+ }
+
+ void quantize_row_q4_1(const float * restrict x, void * restrict y, int k) {
+ quantize_row_q4_1_reference(x, y, k);
+ }
+
+ void quantize_row_q5_0_reference(const float * restrict x, block_q5_0 * restrict y, int k) {
+ static const int qk = QK5_0;
+
+ assert(k % qk == 0);
+
+ const int nb = k / qk;
+
+ for (int i = 0; i < nb; i++) {
+ float amax = 0.0f; // absolute max
+ float max = 0.0f;
+
+ for (int j = 0; j < qk; j++) {
+ const float v = x[i*qk + j];
+ if (amax < fabsf(v)) {
+ amax = fabsf(v);
+ max = v;
  }
- scale = sumlx/suml2; best = scale*sumlx;
  }
+
+ const float d = max / -16;
+ const float id = d ? 1.0f/d : 0.0f;
+
+ y[i].d = GGML_FP32_TO_FP16(d);
+
+ uint32_t qh = 0;
+
+ for (int j = 0; j < qk/2; ++j) {
+ const float x0 = x[i*qk + 0 + j]*id;
+ const float x1 = x[i*qk + qk/2 + j]*id;
+
+ const uint8_t xi0 = MIN(31, (int8_t)(x0 + 16.5f));
+ const uint8_t xi1 = MIN(31, (int8_t)(x1 + 16.5f));
+
+ y[i].qs[j] = (xi0 & 0x0F) | ((xi1 & 0x0F) << 4);
+
+ // get the 5-th bit and store it in qh at the right position
+ qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
+ qh |= ((xi1 & 0x10u) >> 4) << (j + qk/2);
+ }
+
+ memcpy(&y[i].qh, &qh, sizeof(qh));
  }
- return scale;
  }
 
- static float make_q3_quants(int n, int nmax, const float * restrict x, int8_t * restrict L, bool do_rmse) {
- float max = 0;
- float amax = 0;
- for (int i = 0; i < n; ++i) {
- float ax = fabsf(x[i]);
- if (ax > amax) { amax = ax; max = x[i]; }
+ void quantize_row_q5_0(const float * restrict x, void * restrict y, int k) {
+ quantize_row_q5_0_reference(x, y, k);
+ }
+
+ void quantize_row_q5_1_reference(const float * restrict x, block_q5_1 * restrict y, int k) {
+ const int qk = QK5_1;
+
+ assert(k % qk == 0);
+
+ const int nb = k / qk;
+
+ for (int i = 0; i < nb; i++) {
+ float min = FLT_MAX;
+ float max = -FLT_MAX;
+
+ for (int j = 0; j < qk; j++) {
+ const float v = x[i*qk + j];
+
+ if (v < min) min = v;
+ if (v > max) max = v;
+ }
+
+ const float d = (max - min) / ((1 << 5) - 1);
+ const float id = d ? 1.0f/d : 0.0f;
+
+ y[i].d = GGML_FP32_TO_FP16(d);
+ y[i].m = GGML_FP32_TO_FP16(min);
+
+ uint32_t qh = 0;
+
+ for (int j = 0; j < qk/2; ++j) {
+ const float x0 = (x[i*qk + 0 + j] - min)*id;
+ const float x1 = (x[i*qk + qk/2 + j] - min)*id;
+
+ const uint8_t xi0 = (uint8_t)(x0 + 0.5f);
+ const uint8_t xi1 = (uint8_t)(x1 + 0.5f);
+
+ y[i].qs[j] = (xi0 & 0x0F) | ((xi1 & 0x0F) << 4);
+
+ // get the 5-th bit and store it in qh at the right position
+ qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
+ qh |= ((xi1 & 0x10u) >> 4) << (j + qk/2);
+ }
+
+ memcpy(&y[i].qh, &qh, sizeof(y[i].qh));
  }
- if (!amax) { // all zero
- for (int i = 0; i < n; ++i) { L[i] = 0; }
- return 0.f;
+ }
+
+ void quantize_row_q5_1(const float * restrict x, void * restrict y, int k) {
+ quantize_row_q5_1_reference(x, y, k);
+ }
+
+ // reference implementation for deterministic creation of model files
+ void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * restrict y, int k) {
+ assert(k % QK8_0 == 0);
+ const int nb = k / QK8_0;
+
+ for (int i = 0; i < nb; i++) {
+ float amax = 0.0f; // absolute max
+
+ for (int j = 0; j < QK8_0; j++) {
+ const float v = x[i*QK8_0 + j];
+ amax = MAX(amax, fabsf(v));
+ }
+
+ const float d = amax / ((1 << 7) - 1);
+ const float id = d ? 1.0f/d : 0.0f;
+
+ y[i].d = GGML_FP32_TO_FP16(d);
+
+ for (int j = 0; j < QK8_0; ++j) {
+ const float x0 = x[i*QK8_0 + j]*id;
+
+ y[i].qs[j] = roundf(x0);
+ }
  }
- float iscale = -nmax / max;
- if (do_rmse) {
- float sumlx = 0;
- float suml2 = 0;
- for (int i = 0; i < n; ++i) {
- int l = nearest_int(iscale * x[i]);
- l = MAX(-nmax, MIN(nmax-1, l));
- L[i] = l;
- float w = x[i]*x[i];
- sumlx += w*x[i]*l;
- suml2 += w*l*l;
+ }
+
+ void quantize_row_q8_0(const float * restrict x, void * restrict vy, int k) {
+ assert(QK8_0 == 32);
+ assert(k % QK8_0 == 0);
+ const int nb = k / QK8_0;
+
+ block_q8_0 * restrict y = vy;
+
+ #if defined(__ARM_NEON)
+ for (int i = 0; i < nb; i++) {
+ float32x4_t srcv [8];
+ float32x4_t asrcv[8];
+ float32x4_t amaxv[8];
+
+ for (int j = 0; j < 8; j++) srcv[j] = vld1q_f32(x + i*32 + 4*j);
+ for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[j]);
+
+ for (int j = 0; j < 4; j++) amaxv[2*j] = vmaxq_f32(asrcv[2*j], asrcv[2*j+1]);
+ for (int j = 0; j < 2; j++) amaxv[4*j] = vmaxq_f32(amaxv[4*j], amaxv[4*j+2]);
+ for (int j = 0; j < 1; j++) amaxv[8*j] = vmaxq_f32(amaxv[8*j], amaxv[8*j+4]);
+
+ const float amax = vmaxvq_f32(amaxv[0]);
+
+ const float d = amax / ((1 << 7) - 1);
+ const float id = d ? 1.0f/d : 0.0f;
+
+ y[i].d = GGML_FP32_TO_FP16(d);
+
+ for (int j = 0; j < 8; j++) {
+ const float32x4_t v = vmulq_n_f32(srcv[j], id);
+ const int32x4_t vi = vcvtnq_s32_f32(v);
+
+ y[i].qs[4*j + 0] = vgetq_lane_s32(vi, 0);
+ y[i].qs[4*j + 1] = vgetq_lane_s32(vi, 1);
+ y[i].qs[4*j + 2] = vgetq_lane_s32(vi, 2);
+ y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3);
  }
- for (int itry = 0; itry < 5; ++itry) {
- int n_changed = 0;
- for (int i = 0; i < n; ++i) {
- float w = x[i]*x[i];
- float slx = sumlx - w*x[i]*L[i];
- if (slx > 0) {
- float sl2 = suml2 - w*L[i]*L[i];
- int new_l = nearest_int(x[i] * sl2 / slx);
- new_l = MAX(-nmax, MIN(nmax-1, new_l));
- if (new_l != L[i]) {
- slx += w*x[i]*new_l;
- sl2 += w*new_l*new_l;
- if (sl2 > 0 && slx*slx*suml2 > sumlx*sumlx*sl2) {
- L[i] = new_l; sumlx = slx; suml2 = sl2;
- ++n_changed;
- }
- }
- }
+ }
+ #elif defined(__wasm_simd128__)
+ for (int i = 0; i < nb; i++) {
+ v128_t srcv [8];
+ v128_t asrcv[8];
+ v128_t amaxv[8];
+
+ for (int j = 0; j < 8; j++) srcv[j] = wasm_v128_load(x + i*32 + 4*j);
+ for (int j = 0; j < 8; j++) asrcv[j] = wasm_f32x4_abs(srcv[j]);
+
+ for (int j = 0; j < 4; j++) amaxv[2*j] = wasm_f32x4_max(asrcv[2*j], asrcv[2*j+1]);
+ for (int j = 0; j < 2; j++) amaxv[4*j] = wasm_f32x4_max(amaxv[4*j], amaxv[4*j+2]);
+ for (int j = 0; j < 1; j++) amaxv[8*j] = wasm_f32x4_max(amaxv[8*j], amaxv[8*j+4]);
+
+ const float amax = MAX(MAX(wasm_f32x4_extract_lane(amaxv[0], 0),
+ wasm_f32x4_extract_lane(amaxv[0], 1)),
+ MAX(wasm_f32x4_extract_lane(amaxv[0], 2),
+ wasm_f32x4_extract_lane(amaxv[0], 3)));
+
+ const float d = amax / ((1 << 7) - 1);
+ const float id = d ? 1.0f/d : 0.0f;
+
+ y[i].d = GGML_FP32_TO_FP16(d);
+
+ for (int j = 0; j < 8; j++) {
+ const v128_t v = wasm_f32x4_mul(srcv[j], wasm_f32x4_splat(id));
+ const v128_t vi = wasm_i32x4_trunc_sat_f32x4(v);
+
+ y[i].qs[4*j + 0] = wasm_i32x4_extract_lane(vi, 0);
+ y[i].qs[4*j + 1] = wasm_i32x4_extract_lane(vi, 1);
+ y[i].qs[4*j + 2] = wasm_i32x4_extract_lane(vi, 2);
+ y[i].qs[4*j + 3] = wasm_i32x4_extract_lane(vi, 3);
+ }
+ }
+ #elif defined(__AVX2__) || defined(__AVX__)
+ for (int i = 0; i < nb; i++) {
+ // Load elements into 4 AVX vectors
+ __m256 v0 = _mm256_loadu_ps( x );
+ __m256 v1 = _mm256_loadu_ps( x + 8 );
+ __m256 v2 = _mm256_loadu_ps( x + 16 );
+ __m256 v3 = _mm256_loadu_ps( x + 24 );
+ x += 32;
+
+ // Compute max(abs(e)) for the block
+ const __m256 signBit = _mm256_set1_ps( -0.0f );
+ __m256 maxAbs = _mm256_andnot_ps( signBit, v0 );
+ maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) );
+ maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );
+ maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) );
+
+ __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
+ max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
+ max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
+ const float maxScalar = _mm_cvtss_f32( max4 );
+
+ // Quantize these floats
+ const float d = maxScalar / 127.f;
+ y[i].d = GGML_FP32_TO_FP16(d);
+ const float id = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f;
+ const __m256 mul = _mm256_set1_ps( id );
+
+ // Apply the multiplier
+ v0 = _mm256_mul_ps( v0, mul );
+ v1 = _mm256_mul_ps( v1, mul );
+ v2 = _mm256_mul_ps( v2, mul );
+ v3 = _mm256_mul_ps( v3, mul );
+
+ // Round to nearest integer
+ v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST );
+ v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST );
+ v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST );
+ v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST );
+
+ // Convert floats to integers
+ __m256i i0 = _mm256_cvtps_epi32( v0 );
+ __m256i i1 = _mm256_cvtps_epi32( v1 );
+ __m256i i2 = _mm256_cvtps_epi32( v2 );
+ __m256i i3 = _mm256_cvtps_epi32( v3 );
+
+ #if defined(__AVX2__)
+ // Convert int32 to int16
+ i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15
+ i2 = _mm256_packs_epi32( i2, i3 ); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31
+ // Convert int16 to int8
+ i0 = _mm256_packs_epi16( i0, i2 ); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31
+
+ // We got our precious signed bytes, but the order is now wrong
+ // These AVX2 pack instructions process 16-byte pieces independently
+ // The following instruction is fixing the order
+ const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
+ i0 = _mm256_permutevar8x32_epi32( i0, perm );
+
+ _mm256_storeu_si256((__m256i *)y[i].qs, i0);
+ #else
+ // Since we don't have in AVX some necessary functions,
+ // we split the registers in half and call AVX2 analogs from SSE
+ __m128i ni0 = _mm256_castsi256_si128( i0 );
+ __m128i ni1 = _mm256_extractf128_si256( i0, 1);
+ __m128i ni2 = _mm256_castsi256_si128( i1 );
+ __m128i ni3 = _mm256_extractf128_si256( i1, 1);
+ __m128i ni4 = _mm256_castsi256_si128( i2 );
+ __m128i ni5 = _mm256_extractf128_si256( i2, 1);
+ __m128i ni6 = _mm256_castsi256_si128( i3 );
+ __m128i ni7 = _mm256_extractf128_si256( i3, 1);
+
+ // Convert int32 to int16
+ ni0 = _mm_packs_epi32( ni0, ni1 );
+ ni2 = _mm_packs_epi32( ni2, ni3 );
+ ni4 = _mm_packs_epi32( ni4, ni5 );
+ ni6 = _mm_packs_epi32( ni6, ni7 );
+ // Convert int16 to int8
+ ni0 = _mm_packs_epi16( ni0, ni2 );
+ ni4 = _mm_packs_epi16( ni4, ni6 );
+
+ _mm_storeu_si128((__m128i *)(y[i].qs + 0), ni0);
+ _mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4);
+ #endif
+ }
+ #elif defined(__riscv_v_intrinsic)
+
+ size_t vl = __riscv_vsetvl_e32m4(QK8_0);
+
+ for (int i = 0; i < nb; i++) {
+ // load elements
+ vfloat32m4_t v_x = __riscv_vle32_v_f32m4(x+i*QK8_0, vl);
+
+ vfloat32m4_t vfabs = __riscv_vfabs_v_f32m4(v_x, vl);
+ vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0f, vl);
+ vfloat32m1_t vmax = __riscv_vfredmax_vs_f32m4_f32m1(vfabs, tmp, vl);
+ float amax = __riscv_vfmv_f_s_f32m1_f32(vmax);
+
+ const float d = amax / ((1 << 7) - 1);
+ const float id = d ? 1.0f/d : 0.0f;
+
+ y[i].d = GGML_FP32_TO_FP16(d);
+
+ vfloat32m4_t x0 = __riscv_vfmul_vf_f32m4(v_x, id, vl);
+
+ // convert to integer
+ vint16m2_t vi = __riscv_vfncvt_x_f_w_i16m2(x0, vl);
+ vint8m1_t vs = __riscv_vncvt_x_x_w_i8m1(vi, vl);
+
+ // store result
+ __riscv_vse8_v_i8m1(y[i].qs , vs, vl);
+ }
+ #else
+ GGML_UNUSED(nb);
+ // scalar
+ quantize_row_q8_0_reference(x, y, k);
+ #endif
+ }
+
+ // reference implementation for deterministic creation of model files
+ void quantize_row_q8_1_reference(const float * restrict x, block_q8_1 * restrict y, int k) {
+ assert(QK8_1 == 32);
+ assert(k % QK8_1 == 0);
+ const int nb = k / QK8_1;
+
+ for (int i = 0; i < nb; i++) {
+ float amax = 0.0f; // absolute max
+
+ for (int j = 0; j < QK8_1; j++) {
+ const float v = x[i*QK8_1 + j];
+ amax = MAX(amax, fabsf(v));
+ }
+
+ const float d = amax / ((1 << 7) - 1);
+ const float id = d ? 1.0f/d : 0.0f;
+
+ y[i].d = d;
+
+ int sum = 0;
+
+ for (int j = 0; j < QK8_1/2; ++j) {
+ const float v0 = x[i*QK8_1 + j]*id;
+ const float v1 = x[i*QK8_1 + QK8_1/2 + j]*id;
+
+ y[i].qs[ j] = roundf(v0);
+ y[i].qs[QK8_1/2 + j] = roundf(v1);
+
+ sum += y[i].qs[ j];
+ sum += y[i].qs[QK8_1/2 + j];
+ }
+
+ y[i].s = sum*d;
+ }
+ }
+
+ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int k) {
+ assert(k % QK8_1 == 0);
+ const int nb = k / QK8_1;
+
+ block_q8_1 * restrict y = vy;
+
+ #if defined(__ARM_NEON)
+ for (int i = 0; i < nb; i++) {
+ float32x4_t srcv [8];
+ float32x4_t asrcv[8];
+ float32x4_t amaxv[8];
+
+ for (int j = 0; j < 8; j++) srcv[j] = vld1q_f32(x + i*32 + 4*j);
+ for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[j]);
+
+ for (int j = 0; j < 4; j++) amaxv[2*j] = vmaxq_f32(asrcv[2*j], asrcv[2*j+1]);
+ for (int j = 0; j < 2; j++) amaxv[4*j] = vmaxq_f32(amaxv[4*j], amaxv[4*j+2]);
+ for (int j = 0; j < 1; j++) amaxv[8*j] = vmaxq_f32(amaxv[8*j], amaxv[8*j+4]);
+
+ const float amax = vmaxvq_f32(amaxv[0]);
+
+ const float d = amax / ((1 << 7) - 1);
+ const float id = d ? 1.0f/d : 0.0f;
+
+ y[i].d = d;
+
+ int32x4_t accv = vdupq_n_s32(0);
+
+ for (int j = 0; j < 8; j++) {
+ const float32x4_t v = vmulq_n_f32(srcv[j], id);
+ const int32x4_t vi = vcvtnq_s32_f32(v);
+
+ y[i].qs[4*j + 0] = vgetq_lane_s32(vi, 0);
+ y[i].qs[4*j + 1] = vgetq_lane_s32(vi, 1);
+ y[i].qs[4*j + 2] = vgetq_lane_s32(vi, 2);
+ y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3);
+
+ accv = vaddq_s32(accv, vi);
+ }
+
+ y[i].s = d * vaddvq_s32(accv);
+ }
+ #elif defined(__wasm_simd128__)
+ for (int i = 0; i < nb; i++) {
+ v128_t srcv [8];
+ v128_t asrcv[8];
+ v128_t amaxv[8];
+
+ for (int j = 0; j < 8; j++) srcv[j] = wasm_v128_load(x + i*32 + 4*j);
+ for (int j = 0; j < 8; j++) asrcv[j] = wasm_f32x4_abs(srcv[j]);
+
+ for (int j = 0; j < 4; j++) amaxv[2*j] = wasm_f32x4_max(asrcv[2*j], asrcv[2*j+1]);
+ for (int j = 0; j < 2; j++) amaxv[4*j] = wasm_f32x4_max(amaxv[4*j], amaxv[4*j+2]);
+ for (int j = 0; j < 1; j++) amaxv[8*j] = wasm_f32x4_max(amaxv[8*j], amaxv[8*j+4]);
+
+ const float amax = MAX(MAX(wasm_f32x4_extract_lane(amaxv[0], 0),
+ wasm_f32x4_extract_lane(amaxv[0], 1)),
+ MAX(wasm_f32x4_extract_lane(amaxv[0], 2),
+ wasm_f32x4_extract_lane(amaxv[0], 3)));
+
+ const float d = amax / ((1 << 7) - 1);
+ const float id = d ? 1.0f/d : 0.0f;
+
+ y[i].d = d;
+
+ v128_t accv = wasm_i32x4_splat(0);
+
+ for (int j = 0; j < 8; j++) {
+ const v128_t v = wasm_f32x4_mul(srcv[j], wasm_f32x4_splat(id));
+ const v128_t vi = wasm_i32x4_trunc_sat_f32x4(v);
+
+ y[i].qs[4*j + 0] = wasm_i32x4_extract_lane(vi, 0);
+ y[i].qs[4*j + 1] = wasm_i32x4_extract_lane(vi, 1);
+ y[i].qs[4*j + 2] = wasm_i32x4_extract_lane(vi, 2);
+ y[i].qs[4*j + 3] = wasm_i32x4_extract_lane(vi, 3);
+
+ accv = wasm_i32x4_add(accv, vi);
+ }
+
+ y[i].s = d * (wasm_i32x4_extract_lane(accv, 0) +
+ wasm_i32x4_extract_lane(accv, 1) +
+ wasm_i32x4_extract_lane(accv, 2) +
+ wasm_i32x4_extract_lane(accv, 3));
+ }
+ #elif defined(__AVX2__) || defined(__AVX__)
+ for (int i = 0; i < nb; i++) {
+ // Load elements into 4 AVX vectors
+ __m256 v0 = _mm256_loadu_ps( x );
+ __m256 v1 = _mm256_loadu_ps( x + 8 );
+ __m256 v2 = _mm256_loadu_ps( x + 16 );
+ __m256 v3 = _mm256_loadu_ps( x + 24 );
+ x += 32;
+
+ // Compute max(abs(e)) for the block
+ const __m256 signBit = _mm256_set1_ps( -0.0f );
+ __m256 maxAbs = _mm256_andnot_ps( signBit, v0 );
+ maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) );
+ maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );
+ maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) );
+
+ __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
+ max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
+ max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
+ const float maxScalar = _mm_cvtss_f32( max4 );
+
+ // Quantize these floats
+ const float d = maxScalar / 127.f;
+ y[i].d = d;
+ const float id = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f;
+ const __m256 mul = _mm256_set1_ps( id );
+
+ // Apply the multiplier
+ v0 = _mm256_mul_ps( v0, mul );
+ v1 = _mm256_mul_ps( v1, mul );
+ v2 = _mm256_mul_ps( v2, mul );
+ v3 = _mm256_mul_ps( v3, mul );
+
+ // Round to nearest integer
+ v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST );
+ v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST );
+ v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST );
+ v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST );
+
+ // Convert floats to integers
+ __m256i i0 = _mm256_cvtps_epi32( v0 );
+ __m256i i1 = _mm256_cvtps_epi32( v1 );
+ __m256i i2 = _mm256_cvtps_epi32( v2 );
+ __m256i i3 = _mm256_cvtps_epi32( v3 );
+
+ #if defined(__AVX2__)
+ // Compute the sum of the quants and set y[i].s
+ y[i].s = d * hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3)));
+
+ // Convert int32 to int16
+ i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15
+ i2 = _mm256_packs_epi32( i2, i3 ); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31
+ // Convert int16 to int8
+ i0 = _mm256_packs_epi16( i0, i2 ); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31
+
+ // We got our precious signed bytes, but the order is now wrong
+ // These AVX2 pack instructions process 16-byte pieces independently
+ // The following instruction is fixing the order
+ const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
+ i0 = _mm256_permutevar8x32_epi32( i0, perm );
+
+ _mm256_storeu_si256((__m256i *)y[i].qs, i0);
+ #else
+ // Since we don't have in AVX some necessary functions,
+ // we split the registers in half and call AVX2 analogs from SSE
+ __m128i ni0 = _mm256_castsi256_si128( i0 );
+ __m128i ni1 = _mm256_extractf128_si256( i0, 1);
+ __m128i ni2 = _mm256_castsi256_si128( i1 );
+ __m128i ni3 = _mm256_extractf128_si256( i1, 1);
+ __m128i ni4 = _mm256_castsi256_si128( i2 );
+ __m128i ni5 = _mm256_extractf128_si256( i2, 1);
+ __m128i ni6 = _mm256_castsi256_si128( i3 );
+ __m128i ni7 = _mm256_extractf128_si256( i3, 1);
+
+ // Compute the sum of the quants and set y[i].s
+ const __m128i s0 = _mm_add_epi32(_mm_add_epi32(ni0, ni1), _mm_add_epi32(ni2, ni3));
+ const __m128i s1 = _mm_add_epi32(_mm_add_epi32(ni4, ni5), _mm_add_epi32(ni6, ni7));
+ y[i].s = d * hsum_i32_4(_mm_add_epi32(s0, s1));
+
+ // Convert int32 to int16
+ ni0 = _mm_packs_epi32( ni0, ni1 );
+ ni2 = _mm_packs_epi32( ni2, ni3 );
+ ni4 = _mm_packs_epi32( ni4, ni5 );
+ ni6 = _mm_packs_epi32( ni6, ni7 );
+ // Convert int16 to int8
+ ni0 = _mm_packs_epi16( ni0, ni2 );
+ ni4 = _mm_packs_epi16( ni4, ni6 );
+
+ _mm_storeu_si128((__m128i *)(y[i].qs + 0), ni0);
+ _mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4);
+ #endif
+ }
+ #elif defined(__riscv_v_intrinsic)
+
+ size_t vl = __riscv_vsetvl_e32m4(QK8_1);
+
+ for (int i = 0; i < nb; i++) {
+ // load elements
+ vfloat32m4_t v_x = __riscv_vle32_v_f32m4(x+i*QK8_1, vl);
+
+ vfloat32m4_t vfabs = __riscv_vfabs_v_f32m4(v_x, vl);
+ vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0, vl);
+ vfloat32m1_t vmax = __riscv_vfredmax_vs_f32m4_f32m1(vfabs, tmp, vl);
+ float amax = __riscv_vfmv_f_s_f32m1_f32(vmax);
+
+ const float d = amax / ((1 << 7) - 1);
+ const float id = d ? 1.0f/d : 0.0f;
+
+ y[i].d = d;
+
+ vfloat32m4_t x0 = __riscv_vfmul_vf_f32m4(v_x, id, vl);
+
+ // convert to integer
+ vint16m2_t vi = __riscv_vfncvt_x_f_w_i16m2(x0, vl);
+ vint8m1_t vs = __riscv_vncvt_x_x_w_i8m1(vi, vl);
+
+ // store result
+ __riscv_vse8_v_i8m1(y[i].qs , vs, vl);
+
+ // compute sum for y[i].s
+ vint16m1_t tmp2 = __riscv_vmv_v_x_i16m1(0, vl);
+ vint16m1_t vwrs = __riscv_vwredsum_vs_i8m1_i16m1(vs, tmp2, vl);
+
+ // set y[i].s
+ int sum = __riscv_vmv_x_s_i16m1_i16(vwrs);
+ y[i].s = sum*d;
+ }
+ #else
+ GGML_UNUSED(nb);
+ // scalar
+ quantize_row_q8_1_reference(x, y, k);
+ #endif
+ }
+
+ void dequantize_row_q4_0(const block_q4_0 * restrict x, float * restrict y, int k) {
+ static const int qk = QK4_0;
+
+ assert(k % qk == 0);
+
+ const int nb = k / qk;
+
+ for (int i = 0; i < nb; i++) {
+ const float d = GGML_FP16_TO_FP32(x[i].d);
+
+ for (int j = 0; j < qk/2; ++j) {
+ const int x0 = (x[i].qs[j] & 0x0F) - 8;
+ const int x1 = (x[i].qs[j] >> 4) - 8;
+
+ y[i*qk + j + 0 ] = x0*d;
+ y[i*qk + j + qk/2] = x1*d;
+ }
+ }
+ }
+
+ void dequantize_row_q4_1(const block_q4_1 * restrict x, float * restrict y, int k) {
+ static const int qk = QK4_1;
+
+ assert(k % qk == 0);
+
+ const int nb = k / qk;
+
+ for (int i = 0; i < nb; i++) {
+ const float d = GGML_FP16_TO_FP32(x[i].d);
+ const float m = GGML_FP16_TO_FP32(x[i].m);
+
+ for (int j = 0; j < qk/2; ++j) {
+ const int x0 = (x[i].qs[j] & 0x0F);
+ const int x1 = (x[i].qs[j] >> 4);
+
+ y[i*qk + j + 0 ] = x0*d + m;
+ y[i*qk + j + qk/2] = x1*d + m;
+ }
+ }
+ }
+
+ void dequantize_row_q5_0(const block_q5_0 * restrict x, float * restrict y, int k) {
+ static const int qk = QK5_0;
+
+ assert(k % qk == 0);
+
+ const int nb = k / qk;
+
+ for (int i = 0; i < nb; i++) {
+ const float d = GGML_FP16_TO_FP32(x[i].d);
+
+ uint32_t qh;
+ memcpy(&qh, x[i].qh, sizeof(qh));
+
+ for (int j = 0; j < qk/2; ++j) {
+ const uint8_t xh_0 = ((qh >> (j + 0)) << 4) & 0x10;
+ const uint8_t xh_1 = ((qh >> (j + 12)) ) & 0x10;
+
+ const int32_t x0 = ((x[i].qs[j] & 0x0F) | xh_0) - 16;
+ const int32_t x1 = ((x[i].qs[j] >> 4) | xh_1) - 16;
+
+ y[i*qk + j + 0 ] = x0*d;
+ y[i*qk + j + qk/2] = x1*d;
+ }
+ }
+ }
+
+ void dequantize_row_q5_1(const block_q5_1 * restrict x, float * restrict y, int k) {
+ static const int qk = QK5_1;
+
+ assert(k % qk == 0);
+
+ const int nb = k / qk;
+
+ for (int i = 0; i < nb; i++) {
+ const float d = GGML_FP16_TO_FP32(x[i].d);
+ const float m = GGML_FP16_TO_FP32(x[i].m);
+
+ uint32_t qh;
+ memcpy(&qh, x[i].qh, sizeof(qh));
+
+ for (int j = 0; j < qk/2; ++j) {
+ const uint8_t xh_0 = ((qh >> (j + 0)) << 4) & 0x10;
+ const uint8_t xh_1 = ((qh >> (j + 12)) ) & 0x10;
+
+ const int x0 = (x[i].qs[j] & 0x0F) | xh_0;
+ const int x1 = (x[i].qs[j] >> 4) | xh_1;
+
+ y[i*qk + j + 0 ] = x0*d + m;
+ y[i*qk + j + qk/2] = x1*d + m;
+ }
+ }
+ }
+
+ void dequantize_row_q8_0(const block_q8_0 * restrict x, float * restrict y, int k) {
+ static const int qk = QK8_0;
+
+ assert(k % qk == 0);
+
+ const int nb = k / qk;
+
+ for (int i = 0; i < nb; i++) {
+ const float d = GGML_FP16_TO_FP32(x[i].d);
+
+ for (int j = 0; j < qk; ++j) {
+ y[i*qk + j] = x[i].qs[j]*d;
+ }
+ }
+ }
+
+ //
+ // 2-6 bit quantization in super-blocks
+ //
+
+ //
+ // ===================== Helper functions
+ //
+ static inline int nearest_int(float fval) {
+ assert(fval <= 4194303.f);
+ float val = fval + 12582912.f;
+ int i; memcpy(&i, &val, sizeof(int));
+ return (i & 0x007fffff) - 0x00400000;
+ }
+
+ static float make_qx_quants(int n, int nmax, const float * restrict x, int8_t * restrict L, int rmse_type) {
+ float max = 0;
+ float amax = 0;
+ for (int i = 0; i < n; ++i) {
+ float ax = fabsf(x[i]);
+ if (ax > amax) { amax = ax; max = x[i]; }
+ }
+ if (amax < 1e-30f) { // all zero
+ for (int i = 0; i < n; ++i) {
+ L[i] = 0;
+ }
+ return 0.f;
+ }
+ float iscale = -nmax / max;
+ if (rmse_type == 0) {
+ for (int i = 0; i < n; ++i) {
+ int l = nearest_int(iscale * x[i]);
+ L[i] = nmax + MAX(-nmax, MIN(nmax-1, l));
+ }
+ return 1/iscale;
+ }
+ bool return_early = false;
+ if (rmse_type < 0) {
+ rmse_type = -rmse_type;
+ return_early = true;
+ }
+ int weight_type = rmse_type%2;
+ float sumlx = 0;
+ float suml2 = 0;
+ for (int i = 0; i < n; ++i) {
+ int l = nearest_int(iscale * x[i]);
+ l = MAX(-nmax, MIN(nmax-1, l));
+ L[i] = l + nmax;
+ float w = weight_type == 1 ? x[i] * x[i] : 1;
+ sumlx += w*x[i]*l;
+ suml2 += w*l*l;
+ }
+ float scale = sumlx/suml2;
+ if (return_early) return suml2 > 0 ? 0.5f*(scale + 1/iscale) : 1/iscale;
+ float best = scale * sumlx;
+ for (int is = -9; is <= 9; ++is) {
+ if (is == 0) {
+ continue;
+ }
+ iscale = -(nmax + 0.1f*is) / max;
+ sumlx = suml2 = 0;
+ for (int i = 0; i < n; ++i) {
+ int l = nearest_int(iscale * x[i]);
+ l = MAX(-nmax, MIN(nmax-1, l));
+ float w = weight_type == 1 ? x[i] * x[i] : 1;
+ sumlx += w*x[i]*l;
+ suml2 += w*l*l;
+ }
+ if (suml2 > 0 && sumlx*sumlx > best*suml2) {
+ for (int i = 0; i < n; ++i) {
+ int l = nearest_int(iscale * x[i]);
+ L[i] = nmax + MAX(-nmax, MIN(nmax-1, l));
+ }
+ scale = sumlx/suml2; best = scale*sumlx;
+ }
+ }
+ return scale;
+ }
+
+ static float make_q3_quants(int n, int nmax, const float * restrict x, int8_t * restrict L, bool do_rmse) {
+ float max = 0;
+ float amax = 0;
+ for (int i = 0; i < n; ++i) {
+ float ax = fabsf(x[i]);
+ if (ax > amax) { amax = ax; max = x[i]; }
+ }
+ if (!amax) { // all zero
+ for (int i = 0; i < n; ++i) { L[i] = 0; }
+ return 0.f;
+ }
+ float iscale = -nmax / max;
+ if (do_rmse) {
+ float sumlx = 0;
+ float suml2 = 0;
+ for (int i = 0; i < n; ++i) {
+ int l = nearest_int(iscale * x[i]);
+ l = MAX(-nmax, MIN(nmax-1, l));
+ L[i] = l;
+ float w = x[i]*x[i];
+ sumlx += w*x[i]*l;
+ suml2 += w*l*l;
+ }
+ for (int itry = 0; itry < 5; ++itry) {
+ int n_changed = 0;
+ for (int i = 0; i < n; ++i) {
+ float w = x[i]*x[i];
+ float slx = sumlx - w*x[i]*L[i];
+ if (slx > 0) {
+ float sl2 = suml2 - w*L[i]*L[i];
+ int new_l = nearest_int(x[i] * sl2 / slx);
+ new_l = MAX(-nmax, MIN(nmax-1, new_l));
+ if (new_l != L[i]) {
+ slx += w*x[i]*new_l;
+ sl2 += w*new_l*new_l;
+ if (sl2 > 0 && slx*slx*suml2 > sumlx*sumlx*sl2) {
+ L[i] = new_l; sumlx = slx; suml2 = sl2;
+ ++n_changed;
+ }
+ }
+ }
+ }
+ if (!n_changed) {
+ break;
+ }
+ }
+ for (int i = 0; i < n; ++i) {
+ L[i] += nmax;
+ }
+ return sumlx / suml2;
+ }
+ for (int i = 0; i < n; ++i) {
+ int l = nearest_int(iscale * x[i]);
+ l = MAX(-nmax, MIN(nmax-1, l));
+ L[i] = l + nmax;
+ }
+ return 1/iscale;
+ }
+
+ static float make_qkx1_quants(int n, int nmax, const float * restrict x, uint8_t * restrict L, float * restrict the_min,
+ int ntry, float alpha) {
+ float min = x[0];
+ float max = x[0];
+ for (int i = 1; i < n; ++i) {
+ if (x[i] < min) min = x[i];
+ if (x[i] > max) max = x[i];
+ }
+ if (max == min) {
+ for (int i = 0; i < n; ++i) L[i] = 0;
+ *the_min = 0;
+ return 0.f;
+ }
+ if (min > 0) min = 0;
+ float iscale = nmax/(max - min);
+ float scale = 1/iscale;
+ for (int itry = 0; itry < ntry; ++itry) {
+ float sumlx = 0; int suml2 = 0;
+ bool did_change = false;
+ for (int i = 0; i < n; ++i) {
+ int l = nearest_int(iscale*(x[i] - min));
+ l = MAX(0, MIN(nmax, l));
+ if (l != L[i]) {
+ L[i] = l;
+ did_change = true;
+ }
+ sumlx += (x[i] - min)*l;
+ suml2 += l*l;
+ }
+ scale = sumlx/suml2;
+ float sum = 0;
+ for (int i = 0; i < n; ++i) {
+ sum += x[i] - scale*L[i];
+ }
+ min = alpha*min + (1 - alpha)*sum/n;
+ if (min > 0) min = 0;
+ iscale = 1/scale;
+ if (!did_change) break;
+ }
+ *the_min = -min;
+ return scale;
+ }
+
+ static float make_qkx2_quants(int n, int nmax, const float * restrict x, const float * restrict weights,
+ uint8_t * restrict L, float * restrict the_min, uint8_t * restrict Laux,
+ float rmin, float rdelta, int nstep, bool use_mad) {
+ float min = x[0];
+ float max = x[0];
+ float sum_w = weights[0];
+ float sum_x = sum_w * x[0];
+ for (int i = 1; i < n; ++i) {
+ if (x[i] < min) min = x[i];
+ if (x[i] > max) max = x[i];
+ float w = weights[i];
+ sum_w += w;
+ sum_x += w * x[i];
+ }
+ if (min > 0) min = 0;
+ if (max == min) {
+ for (int i = 0; i < n; ++i) L[i] = 0;
+ *the_min = -min;
+ return 0.f;
+ }
+ float iscale = nmax/(max - min);
+ float scale = 1/iscale;
+ float best_mad = 0;
+ for (int i = 0; i < n; ++i) {
+ int l = nearest_int(iscale*(x[i] - min));
+ L[i] = MAX(0, MIN(nmax, l));
+ float diff = scale * L[i] + min - x[i];
+ diff = use_mad ? fabsf(diff) : diff * diff;
+ float w = weights[i];
+ best_mad += w * diff;
+ }
+ if (nstep < 1) {
+ *the_min = -min;
+ return scale;
+ }
+ for (int is = 0; is <= nstep; ++is) {
+ iscale = (rmin + rdelta*is + nmax)/(max - min);
+ float sum_l = 0, sum_l2 = 0, sum_xl = 0;
+ for (int i = 0; i < n; ++i) {
+ int l = nearest_int(iscale*(x[i] - min));
+ l = MAX(0, MIN(nmax, l));
+ Laux[i] = l;
+ float w = weights[i];
+ sum_l += w*l;
+ sum_l2 += w*l*l;
+ sum_xl += w*l*x[i];
+ }
+ float D = sum_w * sum_l2 - sum_l * sum_l;
+ if (D > 0) {
+ float this_scale = (sum_w * sum_xl - sum_x * sum_l)/D;
+ float this_min = (sum_l2 * sum_x - sum_l * sum_xl)/D;
+ if (this_min > 0) {
+ this_min = 0;
+ this_scale = sum_xl / sum_l2;
+ }
+ float mad = 0;
+ for (int i = 0; i < n; ++i) {
+ float diff = this_scale * Laux[i] + this_min - x[i];
+ diff = use_mad ? fabsf(diff) : diff * diff;
+ float w = weights[i];
+ mad += w * diff;
+ }
+ if (mad < best_mad) {
+ for (int i = 0; i < n; ++i) {
+ L[i] = Laux[i];
+ }
+ best_mad = mad;
+ scale = this_scale;
+ min = this_min;
+ }
+ }
+ }
+ *the_min = -min;
+ return scale;
+ }
+
+ #if QK_K == 256
+ static inline void get_scale_min_k4(int j, const uint8_t * restrict q, uint8_t * restrict d, uint8_t * restrict m) {
+ if (j < 4) {
+ *d = q[j] & 63; *m = q[j + 4] & 63;
+ } else {
+ *d = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4);
+ *m = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4);
+ }
+ }
+ #endif
+
+ //========================- 2-bit (de)-quantization
+
+ void quantize_row_q2_K_reference(const float * restrict x, block_q2_K * restrict y, int k) {
+ assert(k % QK_K == 0);
+ const int nb = k / QK_K;
+
+ uint8_t L[QK_K];
+ uint8_t Laux[16];
+ float weights[16];
+ float mins[QK_K/16];
+ float scales[QK_K/16];
+
+ const float q4scale = 15.f;
+
+ for (int i = 0; i < nb; i++) {
+ float max_scale = 0; // as we are deducting the min, scales are always positive
+ float max_min = 0;
+ for (int j = 0; j < QK_K/16; ++j) {
+ for (int l = 0; l < 16; ++l) weights[l] = fabsf(x[16*j + l]);
+ scales[j] = make_qkx2_quants(16, 3, x + 16*j, weights, L + 16*j, &mins[j], Laux, -0.5f, 0.1f, 15, true);
+ float scale = scales[j];
+ if (scale > max_scale) {
+ max_scale = scale;
+ }
+ float min = mins[j];
+ if (min > max_min) {
+ max_min = min;
+ }
+ }
+
+ if (max_scale > 0) {
+ float iscale = q4scale/max_scale;
+ for (int j = 0; j < QK_K/16; ++j) {
+ int l = nearest_int(iscale*scales[j]);
+ y[i].scales[j] = l;
+ }
+ y[i].d = GGML_FP32_TO_FP16(max_scale/q4scale);
+ } else {
+ for (int j = 0; j < QK_K/16; ++j) y[i].scales[j] = 0;
+ y[i].d = GGML_FP32_TO_FP16(0.f);
+ }
+ if (max_min > 0) {
+ float iscale = q4scale/max_min;
+ for (int j = 0; j < QK_K/16; ++j) {
+ int l = nearest_int(iscale*mins[j]);
+ y[i].scales[j] |= (l << 4);
+ }
+ y[i].dmin = GGML_FP32_TO_FP16(max_min/q4scale);
+ } else {
+ y[i].dmin = GGML_FP32_TO_FP16(0.f);
+ }
+ for (int j = 0; j < QK_K/16; ++j) {
+ const float d = GGML_FP16_TO_FP32(y[i].d) * (y[i].scales[j] & 0xF);
+ if (!d) continue;
+ const float dm = GGML_FP16_TO_FP32(y[i].dmin) * (y[i].scales[j] >> 4);
+ for (int ii = 0; ii < 16; ++ii) {
+ int l = nearest_int((x[16*j + ii] + dm)/d);
+ l = MAX(0, MIN(3, l));
+ L[16*j + ii] = l;
+ }
+ }
+
+ #if QK_K == 256
+ for (int j = 0; j < QK_K; j += 128) {
+ for (int l = 0; l < 32; ++l) {
+ y[i].qs[j/4 + l] = L[j + l] | (L[j + l + 32] << 2) | (L[j + l + 64] << 4) | (L[j + l + 96] << 6);
+ }
+ }
+ #else
+ for (int l = 0; l < 16; ++l) {
+ y[i].qs[l] = L[l] | (L[l + 16] << 2) | (L[l + 32] << 4) | (L[l + 48] << 6);
+ }
+ #endif
+
+ x += QK_K;
+
+ }
+ }
1434
+
1435
+ void dequantize_row_q2_K(const block_q2_K * restrict x, float * restrict y, int k) {
1436
+ assert(k % QK_K == 0);
1437
+ const int nb = k / QK_K;
1438
+
1439
+ for (int i = 0; i < nb; i++) {
1440
+
1441
+ const float d = GGML_FP16_TO_FP32(x[i].d);
1442
+ const float min = GGML_FP16_TO_FP32(x[i].dmin);
1443
+
1444
+ const uint8_t * q = x[i].qs;
1445
+
1446
+ #if QK_K == 256
1447
+ int is = 0;
1448
+ float dl, ml;
1449
+ for (int n = 0; n < QK_K; n += 128) {
1450
+ int shift = 0;
1451
+ for (int j = 0; j < 4; ++j) {
1452
+
1453
+ uint8_t sc = x[i].scales[is++];
1454
+ dl = d * (sc & 0xF); ml = min * (sc >> 4);
1455
+ for (int l = 0; l < 16; ++l) *y++ = dl * ((int8_t)((q[l] >> shift) & 3)) - ml;
1456
+
1457
+ sc = x[i].scales[is++];
1458
+ dl = d * (sc & 0xF); ml = min * (sc >> 4);
1459
+ for (int l = 0; l < 16; ++l) *y++ = dl * ((int8_t)((q[l+16] >> shift) & 3)) - ml;
1460
+
1461
+ shift += 2;
1462
+ }
1463
+ q += 32;
1464
+ }
1465
+ #else
1466
+ float dl1 = d * (x[i].scales[0] & 0xF), ml1 = min * (x[i].scales[0] >> 4);
1467
+ float dl2 = d * (x[i].scales[1] & 0xF), ml2 = min * (x[i].scales[1] >> 4);
1468
+ float dl3 = d * (x[i].scales[2] & 0xF), ml3 = min * (x[i].scales[2] >> 4);
1469
+ float dl4 = d * (x[i].scales[3] & 0xF), ml4 = min * (x[i].scales[3] >> 4);
1470
+ for (int l = 0; l < 16; ++l) {
1471
+ y[l+ 0] = dl1 * ((int8_t)((q[l] >> 0) & 3)) - ml1;
1472
+ y[l+16] = dl2 * ((int8_t)((q[l] >> 2) & 3)) - ml2;
1473
+ y[l+32] = dl3 * ((int8_t)((q[l] >> 4) & 3)) - ml3;
1474
+ y[l+48] = dl4 * ((int8_t)((q[l] >> 6) & 3)) - ml4;
1475
+ }
1476
+ y += QK_K;
1477
+ #endif
1478
+ }
1479
+ }
1480
+
1481
+ void quantize_row_q2_K(const float * restrict x, void * restrict vy, int k) {
1482
+ quantize_row_q2_K_reference(x, vy, k);
1483
+ }
1484
+
1485
+ size_t ggml_quantize_q2_K(const float * restrict src, void * restrict dst, int n, int k, int64_t * restrict hist) {
1486
+ (void)hist; // TODO: collect histograms
1487
+
1488
+ for (int j = 0; j < n; j += k) {
1489
+ block_q2_K * restrict y = (block_q2_K *)dst + j/QK_K;
1490
+ quantize_row_q2_K_reference(src + j, y, k);
1491
+ }
1492
+ return (n/QK_K*sizeof(block_q2_K));
1493
+ }
1494
+
1495
+ //========================= 3-bit (de)-quantization
1496
+
1497
+ void quantize_row_q3_K_reference(const float * restrict x, block_q3_K * restrict y, int k) {
1498
+ assert(k % QK_K == 0);
1499
+ const int nb = k / QK_K;
1500
+
1501
+ int8_t L[QK_K];
1502
+ float scales[QK_K / 16];
1503
+
1504
+ for (int i = 0; i < nb; i++) {
1505
+
1506
+ float max_scale = 0;
1507
+ float amax = 0;
1508
+ for (int j = 0; j < QK_K/16; ++j) {
1509
+ scales[j] = make_q3_quants(16, 4, x + 16*j, L + 16*j, true);
1510
+ float scale = fabsf(scales[j]);
1511
+ if (scale > amax) {
1512
+ amax = scale; max_scale = scales[j];
1513
+ }
1514
+ }
1515
+
1516
+ #if QK_K == 256
1517
+ memset(y[i].scales, 0, 12);
1518
+ if (max_scale) {
1519
+ float iscale = -32.f/max_scale;
1520
+ for (int j = 0; j < QK_K/16; ++j) {
1521
+ int8_t l = nearest_int(iscale*scales[j]);
1522
+ l = MAX(-32, MIN(31, l)) + 32;
1523
+ if (j < 8) {
1524
+ y[i].scales[j] = l & 0xF;
1525
+ } else {
1526
+ y[i].scales[j-8] |= ((l & 0xF) << 4);
1527
+ }
1528
+ l >>= 4;
1529
+ y[i].scales[j%4 + 8] |= (l << (2*(j/4)));
1530
+ }
1531
+ y[i].d = GGML_FP32_TO_FP16(1/iscale);
1532
+ } else {
1533
+ y[i].d = GGML_FP32_TO_FP16(0.f);
1534
+ }
1535
+
1536
+ int8_t sc;
1537
+ for (int j = 0; j < QK_K/16; ++j) {
1538
+ sc = j < 8 ? y[i].scales[j] & 0xF : y[i].scales[j-8] >> 4;
1539
+ sc = (sc | (((y[i].scales[8 + j%4] >> (2*(j/4))) & 3) << 4)) - 32;
1540
+ float d = GGML_FP16_TO_FP32(y[i].d) * sc;
1541
+ if (!d) {
1542
+ continue;
1543
+ }
1544
+ for (int ii = 0; ii < 16; ++ii) {
1545
+ int l = nearest_int(x[16*j + ii]/d);
1546
+ l = MAX(-4, MIN(3, l));
1547
+ L[16*j + ii] = l + 4;
1548
+ }
1549
+ }
1550
+ #else
1551
+ if (max_scale) {
1552
+ float iscale = -8.f/max_scale;
1553
+ for (int j = 0; j < QK_K/16; j+=2) {
1554
+ int l1 = nearest_int(iscale*scales[j]);
1555
+ l1 = 8 + MAX(-8, MIN(7, l1));
1556
+ int l2 = nearest_int(iscale*scales[j+1]);
1557
+ l2 = 8 + MAX(-8, MIN(7, l2));
1558
+ y[i].scales[j/2] = l1 | (l2 << 4);
1559
+ }
1560
+ y[i].d = GGML_FP32_TO_FP16(1/iscale);
1561
+ } else {
1562
+ for (int j = 0; j < QK_K/16; j+=2) {
1563
+ y[i].scales[j/2] = 0;
1564
+ }
1565
+ y[i].d = GGML_FP32_TO_FP16(0.f);
1566
+ }
1567
+ for (int j = 0; j < QK_K/16; ++j) {
1568
+ int s = j%2 == 0 ? y[i].scales[j/2] & 0xF : y[i].scales[j/2] >> 4;
1569
+ float d = GGML_FP16_TO_FP32(y[i].d) * (s - 8);
1570
+ if (!d) {
1571
+ continue;
1572
+ }
1573
+ for (int ii = 0; ii < 16; ++ii) {
1574
+ int l = nearest_int(x[16*j + ii]/d);
1575
+ l = MAX(-4, MIN(3, l));
1576
+ L[16*j + ii] = l + 4;
1577
+ }
1578
+ }
1579
+ #endif
1580
+
1581
+ memset(y[i].hmask, 0, QK_K/8);
1582
+ // We put the high-bit for the 1st 8 quants into bit 0, the next 8 into bit 1, etc.
1583
+ int m = 0;
1584
+ uint8_t hm = 1;
1585
+ for (int j = 0; j < QK_K; ++j) {
1586
+ if (L[j] > 3) {
1587
+ y[i].hmask[m] |= hm;
1588
+ L[j] -= 4;
1589
+ }
1590
+ if (++m == QK_K/8) {
1591
+ m = 0; hm <<= 1;
1592
+ }
1593
+ }
1594
+ #if QK_K == 256
1595
+ for (int j = 0; j < QK_K; j += 128) {
1596
+ for (int l = 0; l < 32; ++l) {
1597
+ y[i].qs[j/4 + l] = L[j + l] | (L[j + l + 32] << 2) | (L[j + l + 64] << 4) | (L[j + l + 96] << 6);
1598
+ }
1599
+ }
1600
+ #else
1601
+ for (int l = 0; l < 16; ++l) {
1602
+ y[i].qs[l] = L[l] | (L[l + 16] << 2) | (L[l + 32] << 4) | (L[l + 48] << 6);
1603
+ }
1604
+ #endif
1605
+
1606
+ x += QK_K;
1607
+ }
1608
+ }
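As a reading aid (not part of the diff), a scalar sketch of the bit layout the two packing loops above produce in the QK_K == 256 case: quant j keeps its low 2 bits in qs, four quants per byte with a stride of 32 inside each 128-quant group, and its third bit in hmask byte j % 32, bit j / 32. The helper name is hypothetical.

// Recover the unsigned 3-bit code of quant j from one q3_K block (sketch,
// QK_K == 256); the dequantizer subtracts 4 and applies the sub-block scale.
static inline int q3_decode_bits(const block_q3_K * b, int j) {
    const int grp = j / 128;                                   // 128-quant group
    const int l   = j % 128;                                   // offset in the group
    const int lo  = (b->qs[32*grp + l%32] >> (2*(l/32))) & 3;  // low 2 bits
    const int hi  = (b->hmask[j % 32] >> (j / 32)) & 1;        // stored high bit
    return lo + 4*hi;                                          // code in [0, 7]
}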
1609
+
1610
+ #if QK_K == 256
1611
+ void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int k) {
1612
+ assert(k % QK_K == 0);
1613
+ const int nb = k / QK_K;
1614
+
1615
+ const uint32_t kmask1 = 0x03030303;
1616
+ const uint32_t kmask2 = 0x0f0f0f0f;
1617
+
1618
+ uint32_t aux[4];
1619
+ const int8_t * scales = (const int8_t*)aux;
1620
+
1621
+ for (int i = 0; i < nb; i++) {
1622
+
1623
+ const float d_all = GGML_FP16_TO_FP32(x[i].d);
1624
+
1625
+ const uint8_t * restrict q = x[i].qs;
1626
+ const uint8_t * restrict hm = x[i].hmask;
1627
+ uint8_t m = 1;
1628
+
1629
+ memcpy(aux, x[i].scales, 12);
1630
+ uint32_t tmp = aux[2];
1631
+ aux[2] = ((aux[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);
1632
+ aux[3] = ((aux[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4);
1633
+ aux[0] = (aux[0] & kmask2) | (((tmp >> 0) & kmask1) << 4);
1634
+ aux[1] = (aux[1] & kmask2) | (((tmp >> 2) & kmask1) << 4);
1635
+
1636
+ int is = 0;
1637
+ float dl;
1638
+ for (int n = 0; n < QK_K; n += 128) {
1639
+ int shift = 0;
1640
+ for (int j = 0; j < 4; ++j) {
1641
+
1642
+ dl = d_all * (scales[is++] - 32);
1643
+ for (int l = 0; l < 16; ++l) {
1644
+ *y++ = dl * ((int8_t)((q[l+ 0] >> shift) & 3) - ((hm[l+ 0] & m) ? 0 : 4));
1645
+ }
1646
+
1647
+ dl = d_all * (scales[is++] - 32);
1648
+ for (int l = 0; l < 16; ++l) {
1649
+ *y++ = dl * ((int8_t)((q[l+16] >> shift) & 3) - ((hm[l+16] & m) ? 0 : 4));
1650
+ }
1651
+
1652
+ shift += 2;
1653
+ m <<= 1;
1654
+ }
1655
+ q += 32;
1656
+ }
1657
+
1658
+ }
1659
+ }
1660
+ #else
1661
+ void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int k) {
1662
+ assert(k % QK_K == 0);
1663
+ assert(QK_K == 64);
1664
+ const int nb = k / QK_K;
1665
+
1666
+ for (int i = 0; i < nb; i++) {
1667
+
1668
+ const float d_all = GGML_FP16_TO_FP32(x[i].d);
1669
+
1670
+ const uint8_t * restrict q = x[i].qs;
1671
+ const uint8_t * restrict hm = x[i].hmask;
1672
+
1673
+ const float d1 = d_all * ((x[i].scales[0] & 0xF) - 8);
1674
+ const float d2 = d_all * ((x[i].scales[0] >> 4) - 8);
1675
+ const float d3 = d_all * ((x[i].scales[1] & 0xF) - 8);
1676
+ const float d4 = d_all * ((x[i].scales[1] >> 4) - 8);
1677
+
1678
+ for (int l=0; l<8; ++l) {
1679
+ uint8_t h = hm[l];
1680
+ y[l+ 0] = d1 * ((int8_t)((q[l+0] >> 0) & 3) - ((h & 0x01) ? 0 : 4));
1681
+ y[l+ 8] = d1 * ((int8_t)((q[l+8] >> 0) & 3) - ((h & 0x02) ? 0 : 4));
1682
+ y[l+16] = d2 * ((int8_t)((q[l+0] >> 2) & 3) - ((h & 0x04) ? 0 : 4));
1683
+ y[l+24] = d2 * ((int8_t)((q[l+8] >> 2) & 3) - ((h & 0x08) ? 0 : 4));
1684
+ y[l+32] = d3 * ((int8_t)((q[l+0] >> 4) & 3) - ((h & 0x10) ? 0 : 4));
1685
+ y[l+40] = d3 * ((int8_t)((q[l+8] >> 4) & 3) - ((h & 0x20) ? 0 : 4));
1686
+ y[l+48] = d4 * ((int8_t)((q[l+0] >> 6) & 3) - ((h & 0x40) ? 0 : 4));
1687
+ y[l+56] = d4 * ((int8_t)((q[l+8] >> 6) & 3) - ((h & 0x80) ? 0 : 4));
1688
+ }
1689
+ y += QK_K;
1690
+ }
1691
+ }
1692
+ #endif
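The 12-byte scale field that the kmask1/kmask2 shuffle above expands holds sixteen 6-bit sub-block scales. A per-index sketch of the same extraction, mirroring the packing in quantize_row_q3_K_reference (QK_K == 256; the helper name is hypothetical):

// Extract the j-th (0..15) 6-bit scale from a q3_K block's 12 scale bytes and
// return the signed value in [-32, 31] that the dequantizer multiplies by d.
static inline int q3_get_scale(const uint8_t * scales, int j) {
    int sc = j < 8 ? (scales[j] & 0xF) : (scales[j-8] >> 4);   // low 4 bits
    sc |= ((scales[8 + j%4] >> (2*(j/4))) & 3) << 4;           // high 2 bits
    return sc - 32;
}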
1693
+
1694
+ void quantize_row_q3_K(const float * restrict x, void * restrict vy, int k) {
1695
+ quantize_row_q3_K_reference(x, vy, k);
1696
+ }
1697
+
1698
+ size_t ggml_quantize_q3_K(const float * restrict src, void * restrict dst, int n, int k, int64_t * restrict hist) {
1699
+ (void)hist; // TODO: collect histograms
1700
+
1701
+ for (int j = 0; j < n; j += k) {
1702
+ block_q3_K * restrict y = (block_q3_K *)dst + j/QK_K;
1703
+ quantize_row_q3_K_reference(src + j, y, k);
1704
+ }
1705
+ return (n/QK_K*sizeof(block_q3_K));
1706
+ }
1707
+
1708
+ // ====================== 4-bit (de)-quantization
1709
+
1710
+ void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict y, int k) {
1711
+ assert(k % QK_K == 0);
1712
+ const int nb = k / QK_K;
1713
+
1714
+ uint8_t L[QK_K];
1715
+ uint8_t Laux[32];
1716
+ float weights[32];
1717
+ float mins[QK_K/32];
1718
+ float scales[QK_K/32];
1719
+
1720
+ for (int i = 0; i < nb; i++) {
1721
+
1722
+ float max_scale = 0; // as we are deducting the min, scales are always positive
1723
+ float max_min = 0;
1724
+ for (int j = 0; j < QK_K/32; ++j) {
1725
+ //scales[j] = make_qkx1_quants(32, 15, x + 32*j, L + 32*j, &mins[j], 9, 0.5f);
1726
+ float sum_x2 = 0;
1727
+ for (int l = 0; l < 32; ++l) sum_x2 += x[32*j + l] * x[32*j + l];
1728
+ float av_x = sqrtf(sum_x2/32);
1729
+ for (int l = 0; l < 32; ++l) weights[l] = av_x + fabsf(x[32*j + l]);
1730
+ scales[j] = make_qkx2_quants(32, 15, x + 32*j, weights, L + 32*j, &mins[j], Laux, -1.f, 0.1f, 20, false);
1731
+ float scale = scales[j];
1732
+ if (scale > max_scale) {
1733
+ max_scale = scale;
1734
+ }
1735
+ float min = mins[j];
1736
+ if (min > max_min) {
1737
+ max_min = min;
1738
+ }
1739
+ }
1740
+
1741
+ #if QK_K == 256
1742
+ float inv_scale = max_scale > 0 ? 63.f/max_scale : 0.f;
1743
+ float inv_min = max_min > 0 ? 63.f/max_min : 0.f;
1744
+ for (int j = 0; j < QK_K/32; ++j) {
1745
+ uint8_t ls = nearest_int(inv_scale*scales[j]);
1746
+ uint8_t lm = nearest_int(inv_min*mins[j]);
1747
+ ls = MIN(63, ls);
1748
+ lm = MIN(63, lm);
1749
+ if (j < 4) {
1750
+ y[i].scales[j] = ls;
1751
+ y[i].scales[j+4] = lm;
1752
+ } else {
1753
+ y[i].scales[j+4] = (ls & 0xF) | ((lm & 0xF) << 4);
1754
+ y[i].scales[j-4] |= ((ls >> 4) << 6);
1755
+ y[i].scales[j-0] |= ((lm >> 4) << 6);
1756
+ }
1757
+ }
1758
+ y[i].d = GGML_FP32_TO_FP16(max_scale/63.f);
1759
+ y[i].dmin = GGML_FP32_TO_FP16(max_min/63.f);
1760
+
1761
+ uint8_t sc, m;
1762
+ for (int j = 0; j < QK_K/32; ++j) {
1763
+ get_scale_min_k4(j, y[i].scales, &sc, &m);
1764
+ const float d = GGML_FP16_TO_FP32(y[i].d) * sc;
1765
+ if (!d) continue;
1766
+ const float dm = GGML_FP16_TO_FP32(y[i].dmin) * m;
1767
+ for (int ii = 0; ii < 32; ++ii) {
1768
+ int l = nearest_int((x[32*j + ii] + dm)/d);
1769
+ l = MAX(0, MIN(15, l));
1770
+ L[32*j + ii] = l;
1771
+ }
1772
+ }
1773
+ #else
1774
+ const float s_factor = 15.f;
1775
+ float inv_scale = max_scale > 0 ? s_factor/max_scale : 0.f;
1776
+ float inv_min = max_min > 0 ? s_factor/max_min : 0.f;
1777
+ int d1 = nearest_int(inv_scale*scales[0]);
1778
+ int m1 = nearest_int(inv_min*mins[0]);
1779
+ int d2 = nearest_int(inv_scale*scales[1]);
1780
+ int m2 = nearest_int(inv_min*mins[1]);
1781
+ y[i].scales[0] = d1 | (m1 << 4);
1782
+ y[i].scales[1] = d2 | (m2 << 4);
1783
+ y[i].d[0] = GGML_FP32_TO_FP16(max_scale/s_factor);
1784
+ y[i].d[1] = GGML_FP32_TO_FP16(max_min/s_factor);
1785
+
1786
+ float sumlx = 0;
1787
+ int suml2 = 0;
1788
+ for (int j = 0; j < QK_K/32; ++j) {
1789
+ const uint8_t sd = y[i].scales[j] & 0xF;
1790
+ const uint8_t sm = y[i].scales[j] >> 4;
1791
+ const float d = GGML_FP16_TO_FP32(y[i].d[0]) * sd;
1792
+ if (!d) continue;
1793
+ const float m = GGML_FP16_TO_FP32(y[i].d[1]) * sm;
1794
+ for (int ii = 0; ii < 32; ++ii) {
1795
+ int l = nearest_int((x[32*j + ii] + m)/d);
1796
+ l = MAX(0, MIN(15, l));
1797
+ L[32*j + ii] = l;
1798
+ sumlx += (x[32*j + ii] + m)*l*sd;
1799
+ suml2 += l*l*sd*sd;
1800
+ }
1801
+ }
1802
+ if (suml2) {
1803
+ y[i].d[0] = GGML_FP32_TO_FP16(sumlx/suml2);
1804
+ }
1805
+ #endif
1806
+ uint8_t * q = y[i].qs;
1807
+ for (int j = 0; j < QK_K; j += 64) {
1808
+ for (int l = 0; l < 32; ++l) q[l] = L[j + l] | (L[j + l + 32] << 4);
1809
+ q += 32;
1810
+ }
1811
+
1812
+ x += QK_K;
1813
+
1814
+ }
1815
+ }
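The scale/min packing above is what the get_scale_min_k4 helper used below undoes: sub-blocks 0..3 keep their 6-bit scale and min in whole bytes, while sub-blocks 4..7 keep their low 4 bits in bytes 8..11 and their top 2 bits in the spare high bits of bytes 0..7. A sketch of the unpacking (hypothetical name, same logic as the helper):

// Recover the 6-bit scale (d) and min (m) of sub-block j from the packed
// 12-byte scales field shared by q4_K and q5_K (QK_K == 256 layout).
static inline void qk_get_scale_min(const uint8_t * q, int j, uint8_t * d, uint8_t * m) {
    if (j < 4) {
        *d = q[j] & 63;
        *m = q[j + 4] & 63;
    } else {
        *d = (q[j + 4] & 0xF) | ((q[j - 4] >> 6) << 4);
        *m = (q[j + 4] >>  4) | ((q[j - 0] >> 6) << 4);
    }
}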
1816
+
1817
+ void dequantize_row_q4_K(const block_q4_K * restrict x, float * restrict y, int k) {
1818
+ assert(k % QK_K == 0);
1819
+ const int nb = k / QK_K;
1820
+
1821
+ for (int i = 0; i < nb; i++) {
1822
+
1823
+ const uint8_t * q = x[i].qs;
1824
+
1825
+ #if QK_K == 256
1826
+
1827
+ const float d = GGML_FP16_TO_FP32(x[i].d);
1828
+ const float min = GGML_FP16_TO_FP32(x[i].dmin);
1829
+
1830
+ int is = 0;
1831
+ uint8_t sc, m;
1832
+ for (int j = 0; j < QK_K; j += 64) {
1833
+ get_scale_min_k4(is + 0, x[i].scales, &sc, &m);
1834
+ const float d1 = d * sc; const float m1 = min * m;
1835
+ get_scale_min_k4(is + 1, x[i].scales, &sc, &m);
1836
+ const float d2 = d * sc; const float m2 = min * m;
1837
+ for (int l = 0; l < 32; ++l) *y++ = d1 * (q[l] & 0xF) - m1;
1838
+ for (int l = 0; l < 32; ++l) *y++ = d2 * (q[l] >> 4) - m2;
1839
+ q += 32; is += 2;
1840
+ }
1841
+ #else
1842
+ const float dall = GGML_FP16_TO_FP32(x[i].d[0]);
1843
+ const float mall = GGML_FP16_TO_FP32(x[i].d[1]);
1844
+ const float d1 = dall * (x[i].scales[0] & 0xF), m1 = mall * (x[i].scales[0] >> 4);
1845
+ const float d2 = dall * (x[i].scales[1] & 0xF), m2 = mall * (x[i].scales[1] >> 4);
1846
+ for (int l = 0; l < 32; ++l) {
1847
+ y[l+ 0] = d1 * (q[l] & 0xF) - m1;
1848
+ y[l+32] = d2 * (q[l] >> 4) - m2;
1849
+ }
1850
+ y += QK_K;
1851
+ #endif
1852
+
1853
+ }
1854
+ }
1855
+
1856
+ void quantize_row_q4_K(const float * restrict x, void * restrict vy, int k) {
1857
+ assert(k % QK_K == 0);
1858
+ block_q4_K * restrict y = vy;
1859
+ quantize_row_q4_K_reference(x, y, k);
1860
+ }
1861
+
1862
+ size_t ggml_quantize_q4_K(const float * restrict src, void * restrict dst, int n, int k, int64_t * restrict hist) {
1863
+ assert(k % QK_K == 0);
1864
+ (void)hist; // TODO: collect histograms
1865
+
1866
+ for (int j = 0; j < n; j += k) {
1867
+ block_q4_K * restrict y = (block_q4_K *)dst + j/QK_K;
1868
+ quantize_row_q4_K_reference(src + j, y, k);
1869
+ }
1870
+ return (n/QK_K*sizeof(block_q4_K));
1871
+ }
1872
+
1873
+ // ====================== 5-bit (de)-quantization
1874
+
1875
+ void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict y, int k) {
1876
+ assert(k % QK_K == 0);
1877
+ const int nb = k / QK_K;
1878
+
1879
+ #if QK_K == 256
1880
+ uint8_t L[QK_K];
1881
+ float mins[QK_K/32];
1882
+ float scales[QK_K/32];
1883
+ float weights[32];
1884
+ uint8_t Laux[32];
1885
+ #else
1886
+ int8_t L[QK_K];
1887
+ float scales[QK_K/16];
1888
+ #endif
1889
+
1890
+ for (int i = 0; i < nb; i++) {
1891
+
1892
+ #if QK_K == 256
1893
+
1894
+ float max_scale = 0; // as we are deducting the min, scales are always positive
1895
+ float max_min = 0;
1896
+ for (int j = 0; j < QK_K/32; ++j) {
1897
+ //scales[j] = make_qkx1_quants(32, 31, x + 32*j, L + 32*j, &mins[j], 9, 0.5f);
1898
+ float sum_x2 = 0;
1899
+ for (int l = 0; l < 32; ++l) sum_x2 += x[32*j + l] * x[32*j + l];
1900
+ float av_x = sqrtf(sum_x2/32);
1901
+ for (int l = 0; l < 32; ++l) weights[l] = av_x + fabsf(x[32*j + l]);
1902
+ scales[j] = make_qkx2_quants(32, 31, x + 32*j, weights, L + 32*j, &mins[j], Laux, -0.5f, 0.1f, 15, false);
1903
+ float scale = scales[j];
1904
+ if (scale > max_scale) {
1905
+ max_scale = scale;
1906
+ }
1907
+ float min = mins[j];
1908
+ if (min > max_min) {
1909
+ max_min = min;
1910
+ }
1911
+ }
1912
+
1913
+ float inv_scale = max_scale > 0 ? 63.f/max_scale : 0.f;
1914
+ float inv_min = max_min > 0 ? 63.f/max_min : 0.f;
1915
+ for (int j = 0; j < QK_K/32; ++j) {
1916
+ uint8_t ls = nearest_int(inv_scale*scales[j]);
1917
+ uint8_t lm = nearest_int(inv_min*mins[j]);
1918
+ ls = MIN(63, ls);
1919
+ lm = MIN(63, lm);
1920
+ if (j < 4) {
1921
+ y[i].scales[j] = ls;
1922
+ y[i].scales[j+4] = lm;
1923
+ } else {
1924
+ y[i].scales[j+4] = (ls & 0xF) | ((lm & 0xF) << 4);
1925
+ y[i].scales[j-4] |= ((ls >> 4) << 6);
1926
+ y[i].scales[j-0] |= ((lm >> 4) << 6);
1927
+ }
1928
+ }
1929
+ y[i].d = GGML_FP32_TO_FP16(max_scale/63.f);
1930
+ y[i].dmin = GGML_FP32_TO_FP16(max_min/63.f);
1931
+
1932
+ uint8_t sc, m;
1933
+ for (int j = 0; j < QK_K/32; ++j) {
1934
+ get_scale_min_k4(j, y[i].scales, &sc, &m);
1935
+ const float d = GGML_FP16_TO_FP32(y[i].d) * sc;
1936
+ if (!d) continue;
1937
+ const float dm = GGML_FP16_TO_FP32(y[i].dmin) * m;
1938
+ for (int ii = 0; ii < 32; ++ii) {
1939
+ int l = nearest_int((x[32*j + ii] + dm)/d);
1940
+ l = MAX(0, MIN(31, l));
1941
+ L[32*j + ii] = l;
1942
+ }
1943
+ }
1944
+
1945
+ uint8_t * restrict qh = y[i].qh;
1946
+ uint8_t * restrict ql = y[i].qs;
1947
+ memset(qh, 0, QK_K/8);
1948
+
1949
+ uint8_t m1 = 1, m2 = 2;
1950
+ for (int n = 0; n < QK_K; n += 64) {
1951
+ for (int j = 0; j < 32; ++j) {
1952
+ int l1 = L[n + j];
1953
+ if (l1 > 15) {
1954
+ l1 -= 16; qh[j] |= m1;
1955
+ }
1956
+ int l2 = L[n + j + 32];
1957
+ if (l2 > 15) {
1958
+ l2 -= 16; qh[j] |= m2;
1959
+ }
1960
+ ql[j] = l1 | (l2 << 4);
1961
+ }
1962
+ m1 <<= 2; m2 <<= 2;
1963
+ ql += 32;
1964
+ }
1965
+ #else
1966
+ float max_scale = 0, amax = 0;
1967
+ for (int j = 0; j < QK_K/16; ++j) {
1968
+ scales[j] = make_qx_quants(16, 16, x + 16*j, L + 16*j, 1);
1969
+ float abs_scale = fabsf(scales[j]);
1970
+ if (abs_scale > amax) {
1971
+ amax = abs_scale;
1972
+ max_scale = scales[j];
1973
+ }
1974
+ }
1975
+
1976
+ float iscale = -128.f/max_scale;
1977
+ for (int j = 0; j < QK_K/16; ++j) {
1978
+ int l = nearest_int(iscale*scales[j]);
1979
+ y[i].scales[j] = MAX(-128, MIN(127, l));
1980
+ }
1981
+ y[i].d = GGML_FP32_TO_FP16(1/iscale);
1982
+
1983
+ for (int j = 0; j < QK_K/16; ++j) {
1984
+ const float d = GGML_FP16_TO_FP32(y[i].d) * y[i].scales[j];
1985
+ if (!d) continue;
1986
+ for (int ii = 0; ii < 16; ++ii) {
1987
+ int l = nearest_int(x[16*j + ii]/d);
1988
+ l = MAX(-16, MIN(15, l));
1989
+ L[16*j + ii] = l + 16;
1990
+ }
1991
+ }
1992
+
1993
+ uint8_t * restrict qh = y[i].qh;
1994
+ uint8_t * restrict ql = y[i].qs;
1995
+ memset(qh, 0, QK_K/8);
1996
+
1997
+ for (int j = 0; j < 32; ++j) {
1998
+ int jm = j%8;
1999
+ int is = j/8;
2000
+ int l1 = L[j];
2001
+ if (l1 > 15) {
2002
+ l1 -= 16; qh[jm] |= (1 << is);
2003
+ }
2004
+ int l2 = L[j + 32];
2005
+ if (l2 > 15) {
2006
+ l2 -= 16; qh[jm] |= (1 << (4 + is));
2007
+ }
2008
+ ql[j] = l1 | (l2 << 4);
2009
+ }
2010
+ #endif
2011
+
2012
+ x += QK_K;
2013
+
2014
+ }
2015
+ }
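A scalar sketch (not part of the diff; hypothetical helper name) of where an individual 5-bit code lands after the QK_K == 256 packing loop above: the low 4 bits go to qs, the fifth bit to qh, with each 64-quant group claiming two bit positions of every qh byte.

// Recover the unsigned 5-bit code of quant j in a q5_K block (sketch,
// QK_K == 256); the per-sub-block scale and min are applied afterwards.
static inline int q5_decode_bits(const block_q5_K * b, int j) {
    const int grp = j / 64;                  // 64-quant group (0..3)
    const int l   = j % 64;                  // position inside the group
    const int lo  = l < 32 ? (b->qs[32*grp + l] & 0xF)
                           : (b->qs[32*grp + l - 32] >> 4);
    const int hi  = (b->qh[l % 32] >> (2*grp + (l < 32 ? 0 : 1))) & 1;
    return lo + 16*hi;                       // code in [0, 31]
}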
2016
+
2017
+ void dequantize_row_q5_K(const block_q5_K * restrict x, float * restrict y, int k) {
2018
+ assert(k % QK_K == 0);
2019
+ const int nb = k / QK_K;
2020
+
2021
+ for (int i = 0; i < nb; i++) {
2022
+
2023
+ const uint8_t * ql = x[i].qs;
2024
+ const uint8_t * qh = x[i].qh;
2025
+
2026
+ #if QK_K == 256
2027
+
2028
+ const float d = GGML_FP16_TO_FP32(x[i].d);
2029
+ const float min = GGML_FP16_TO_FP32(x[i].dmin);
2030
+
2031
+ int is = 0;
2032
+ uint8_t sc, m;
2033
+ uint8_t u1 = 1, u2 = 2;
2034
+ for (int j = 0; j < QK_K; j += 64) {
2035
+ get_scale_min_k4(is + 0, x[i].scales, &sc, &m);
2036
+ const float d1 = d * sc; const float m1 = min * m;
2037
+ get_scale_min_k4(is + 1, x[i].scales, &sc, &m);
2038
+ const float d2 = d * sc; const float m2 = min * m;
2039
+ for (int l = 0; l < 32; ++l) *y++ = d1 * ((ql[l] & 0xF) + (qh[l] & u1 ? 16 : 0)) - m1;
2040
+ for (int l = 0; l < 32; ++l) *y++ = d2 * ((ql[l] >> 4) + (qh[l] & u2 ? 16 : 0)) - m2;
2041
+ ql += 32; is += 2;
2042
+ u1 <<= 2; u2 <<= 2;
2043
+ }
2044
+ #else
2045
+ float d = GGML_FP16_TO_FP32(x[i].d);
2046
+ const int8_t * restrict s = x[i].scales;
2047
+ for (int l = 0; l < 8; ++l) {
2048
+ y[l+ 0] = d * s[0] * ((ql[l+ 0] & 0xF) - (qh[l] & 0x01 ? 0 : 16));
2049
+ y[l+ 8] = d * s[0] * ((ql[l+ 8] & 0xF) - (qh[l] & 0x02 ? 0 : 16));
2050
+ y[l+16] = d * s[1] * ((ql[l+16] & 0xF) - (qh[l] & 0x04 ? 0 : 16));
2051
+ y[l+24] = d * s[1] * ((ql[l+24] & 0xF) - (qh[l] & 0x08 ? 0 : 16));
2052
+ y[l+32] = d * s[2] * ((ql[l+ 0] >> 4) - (qh[l] & 0x10 ? 0 : 16));
2053
+ y[l+40] = d * s[2] * ((ql[l+ 8] >> 4) - (qh[l] & 0x20 ? 0 : 16));
2054
+ y[l+48] = d * s[3] * ((ql[l+16] >> 4) - (qh[l] & 0x40 ? 0 : 16));
2055
+ y[l+56] = d * s[3] * ((ql[l+24] >> 4) - (qh[l] & 0x80 ? 0 : 16));
2056
+ }
2057
+ y += QK_K;
2058
+ #endif
2059
+ }
2060
+ }
2061
+
2062
+ void quantize_row_q5_K(const float * restrict x, void * restrict vy, int k) {
2063
+ assert(k % QK_K == 0);
2064
+ block_q5_K * restrict y = vy;
2065
+ quantize_row_q5_K_reference(x, y, k);
2066
+ }
2067
+
2068
+ size_t ggml_quantize_q5_K(const float * restrict src, void * restrict dst, int n, int k, int64_t * restrict hist) {
2069
+ assert(k % QK_K == 0);
2070
+ (void)hist; // TODO: collect histograms
2071
+
2072
+ for (int j = 0; j < n; j += k) {
2073
+ block_q5_K * restrict y = (block_q5_K *)dst + j/QK_K;
2074
+ quantize_row_q5_K_reference(src + j, y, k);
2075
+ }
2076
+ return (n/QK_K*sizeof(block_q5_K));
2077
+ }
2078
+
2079
+ // ====================== 6-bit (de)-quantization
2080
+
2081
+ void quantize_row_q6_K_reference(const float * restrict x, block_q6_K * restrict y, int k) {
2082
+ assert(k % QK_K == 0);
2083
+ const int nb = k / QK_K;
2084
+
2085
+ int8_t L[QK_K];
2086
+ float scales[QK_K/16];
2087
+
2088
+ for (int i = 0; i < nb; i++) {
2089
+
2090
+ float max_scale = 0;
2091
+ float max_abs_scale = 0;
2092
+
2093
+ for (int ib = 0; ib < QK_K/16; ++ib) {
2094
+
2095
+ const float scale = make_qx_quants(16, 32, x + 16*ib, L + 16*ib, 1);
2096
+ scales[ib] = scale;
2097
+
2098
+ const float abs_scale = fabsf(scale);
2099
+ if (abs_scale > max_abs_scale) {
2100
+ max_abs_scale = abs_scale;
2101
+ max_scale = scale;
2102
+ }
2103
+
2104
+ }
2105
+
2106
+ if (!max_abs_scale) {
2107
+ memset(&y[i], 0, sizeof(block_q6_K));
2108
+ y[i].d = GGML_FP32_TO_FP16(0.f);
2109
+ x += QK_K;
2110
+ continue;
2111
+ }
2112
+
2113
+ float iscale = -128.f/max_scale;
2114
+ y[i].d = GGML_FP32_TO_FP16(1/iscale);
2115
+ for (int ib = 0; ib < QK_K/16; ++ib) {
2116
+ y[i].scales[ib] = MIN(127, nearest_int(iscale*scales[ib]));
2117
+ }
2118
+
2119
+ for (int j = 0; j < QK_K/16; ++j) {
2120
+ float d = GGML_FP16_TO_FP32(y[i].d) * y[i].scales[j];
2121
+ if (!d) {
2122
+ continue;
187
2123
  }
188
- if (!n_changed) {
189
- break;
2124
+ for (int ii = 0; ii < 16; ++ii) {
2125
+ int l = nearest_int(x[16*j + ii]/d);
2126
+ l = MAX(-32, MIN(31, l));
2127
+ L[16*j + ii] = l + 32;
190
2128
  }
191
2129
  }
192
- for (int i = 0; i < n; ++i) {
193
- L[i] += nmax;
194
- }
195
- return sumlx / suml2;
196
- }
197
- for (int i = 0; i < n; ++i) {
198
- int l = nearest_int(iscale * x[i]);
199
- l = MAX(-nmax, MIN(nmax-1, l));
200
- L[i] = l + nmax;
201
- }
202
- return 1/iscale;
203
- }
204
2130
 
205
- static float make_qkx1_quants(int n, int nmax, const float * restrict x, uint8_t * restrict L, float * restrict the_min,
206
- int ntry, float alpha) {
207
- float min = x[0];
208
- float max = x[0];
209
- for (int i = 1; i < n; ++i) {
210
- if (x[i] < min) min = x[i];
211
- if (x[i] > max) max = x[i];
212
- }
213
- if (max == min) {
214
- for (int i = 0; i < n; ++i) L[i] = 0;
215
- *the_min = 0;
216
- return 0.f;
217
- }
218
- if (min > 0) min = 0;
219
- float iscale = nmax/(max - min);
220
- float scale = 1/iscale;
221
- for (int itry = 0; itry < ntry; ++itry) {
222
- float sumlx = 0; int suml2 = 0;
223
- bool did_change = false;
224
- for (int i = 0; i < n; ++i) {
225
- int l = nearest_int(iscale*(x[i] - min));
226
- l = MAX(0, MIN(nmax, l));
227
- if (l != L[i]) {
228
- L[i] = l;
229
- did_change = true;
2131
+ uint8_t * restrict ql = y[i].ql;
2132
+ uint8_t * restrict qh = y[i].qh;
2133
+ #if QK_K == 256
2134
+ for (int j = 0; j < QK_K; j += 128) {
2135
+ for (int l = 0; l < 32; ++l) {
2136
+ const uint8_t q1 = L[j + l + 0] & 0xF;
2137
+ const uint8_t q2 = L[j + l + 32] & 0xF;
2138
+ const uint8_t q3 = L[j + l + 64] & 0xF;
2139
+ const uint8_t q4 = L[j + l + 96] & 0xF;
2140
+ ql[l+ 0] = q1 | (q3 << 4);
2141
+ ql[l+32] = q2 | (q4 << 4);
2142
+ qh[l] = (L[j + l] >> 4) | ((L[j + l + 32] >> 4) << 2) | ((L[j + l + 64] >> 4) << 4) | ((L[j + l + 96] >> 4) << 6);
230
2143
  }
231
- sumlx += (x[i] - min)*l;
232
- suml2 += l*l;
2144
+ ql += 64;
2145
+ qh += 32;
233
2146
  }
234
- scale = sumlx/suml2;
235
- float sum = 0;
236
- for (int i = 0; i < n; ++i) {
237
- sum += x[i] - scale*L[i];
2147
+ #else
2148
+ for (int l = 0; l < 32; ++l) {
2149
+ const uint8_t q1 = L[l + 0] & 0xF;
2150
+ const uint8_t q2 = L[l + 32] & 0xF;
2151
+ ql[l] = q1 | (q2 << 4);
238
2152
  }
239
- min = alpha*min + (1 - alpha)*sum/n;
240
- if (min > 0) min = 0;
241
- iscale = 1/scale;
242
- if (!did_change) break;
2153
+ for (int l = 0; l < 16; ++l) {
2154
+ qh[l] = (L[l] >> 4) | ((L[l + 16] >> 4) << 2) | ((L[l + 32] >> 4) << 4) | ((L[l + 48] >> 4) << 6);
2155
+ }
2156
+ #endif
2157
+
2158
+ x += QK_K;
2159
+
243
2160
  }
244
- *the_min = -min;
245
- return scale;
246
2161
  }
247
2162
 
248
- static float make_qkx2_quants(int n, int nmax, const float * restrict x, const float * restrict weights,
249
- uint8_t * restrict L, float * restrict the_min, uint8_t * restrict Laux,
250
- float rmin, float rdelta, int nstep, bool use_mad) {
251
- float min = x[0];
252
- float max = x[0];
253
- float sum_w = weights[0];
254
- float sum_x = sum_w * x[0];
255
- for (int i = 1; i < n; ++i) {
256
- if (x[i] < min) min = x[i];
257
- if (x[i] > max) max = x[i];
258
- float w = weights[i];
259
- sum_w += w;
260
- sum_x += w * x[i];
261
- }
262
- if (min > 0) min = 0;
263
- if (max == min) {
264
- for (int i = 0; i < n; ++i) L[i] = 0;
265
- *the_min = -min;
266
- return 0.f;
267
- }
268
- float iscale = nmax/(max - min);
269
- float scale = 1/iscale;
270
- float best_mad = 0;
271
- for (int i = 0; i < n; ++i) {
272
- int l = nearest_int(iscale*(x[i] - min));
273
- L[i] = MAX(0, MIN(nmax, l));
274
- float diff = scale * L[i] + min - x[i];
275
- diff = use_mad ? fabsf(diff) : diff * diff;
276
- float w = weights[i];
277
- best_mad += w * diff;
278
- }
279
- if (nstep < 1) {
280
- *the_min = -min;
281
- return scale;
282
- }
283
- for (int is = 0; is <= nstep; ++is) {
284
- iscale = (rmin + rdelta*is + nmax)/(max - min);
285
- float sum_l = 0, sum_l2 = 0, sum_xl = 0;
286
- for (int i = 0; i < n; ++i) {
287
- int l = nearest_int(iscale*(x[i] - min));
288
- l = MAX(0, MIN(nmax, l));
289
- Laux[i] = l;
290
- float w = weights[i];
291
- sum_l += w*l;
292
- sum_l2 += w*l*l;
293
- sum_xl += w*l*x[i];
294
- }
295
- float D = sum_w * sum_l2 - sum_l * sum_l;
296
- if (D > 0) {
297
- float this_scale = (sum_w * sum_xl - sum_x * sum_l)/D;
298
- float this_min = (sum_l2 * sum_x - sum_l * sum_xl)/D;
299
- if (this_min > 0) {
300
- this_min = 0;
301
- this_scale = sum_xl / sum_l2;
302
- }
303
- float mad = 0;
304
- for (int i = 0; i < n; ++i) {
305
- float diff = this_scale * Laux[i] + this_min - x[i];
306
- diff = use_mad ? fabsf(diff) : diff * diff;
307
- float w = weights[i];
308
- mad += w * diff;
309
- }
310
- if (mad < best_mad) {
311
- for (int i = 0; i < n; ++i) {
312
- L[i] = Laux[i];
313
- }
314
- best_mad = mad;
315
- scale = this_scale;
316
- min = this_min;
2163
+ void dequantize_row_q6_K(const block_q6_K * restrict x, float * restrict y, int k) {
2164
+ assert(k % QK_K == 0);
2165
+ const int nb = k / QK_K;
2166
+
2167
+ for (int i = 0; i < nb; i++) {
2168
+
2169
+ const float d = GGML_FP16_TO_FP32(x[i].d);
2170
+
2171
+ const uint8_t * restrict ql = x[i].ql;
2172
+ const uint8_t * restrict qh = x[i].qh;
2173
+ const int8_t * restrict sc = x[i].scales;
2174
+
2175
+ #if QK_K == 256
2176
+ for (int n = 0; n < QK_K; n += 128) {
2177
+ for (int l = 0; l < 32; ++l) {
2178
+ int is = l/16;
2179
+ const int8_t q1 = (int8_t)((ql[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
2180
+ const int8_t q2 = (int8_t)((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
2181
+ const int8_t q3 = (int8_t)((ql[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
2182
+ const int8_t q4 = (int8_t)((ql[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
2183
+ y[l + 0] = d * sc[is + 0] * q1;
2184
+ y[l + 32] = d * sc[is + 2] * q2;
2185
+ y[l + 64] = d * sc[is + 4] * q3;
2186
+ y[l + 96] = d * sc[is + 6] * q4;
317
2187
  }
2188
+ y += 128;
2189
+ ql += 64;
2190
+ qh += 32;
2191
+ sc += 8;
2192
+ }
2193
+ #else
2194
+ for (int l = 0; l < 16; ++l) {
2195
+ const int8_t q1 = (int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
2196
+ const int8_t q2 = (int8_t)((ql[l+16] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
2197
+ const int8_t q3 = (int8_t)((ql[l+ 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
2198
+ const int8_t q4 = (int8_t)((ql[l+16] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
2199
+ y[l+ 0] = d * sc[0] * q1;
2200
+ y[l+16] = d * sc[1] * q2;
2201
+ y[l+32] = d * sc[2] * q3;
2202
+ y[l+48] = d * sc[3] * q4;
318
2203
  }
2204
+ y += 64;
2205
+ #endif
2206
+
319
2207
  }
320
- *the_min = -min;
321
- return scale;
322
2208
  }
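Combining the ql/qh packing from quantize_row_q6_K_reference with the shifts used in the loop above gives the following scalar sketch of a single 6-bit lookup (QK_K == 256 branch; hypothetical helper, not part of the diff):

// Recover the signed 6-bit value of quant j in a q6_K block (sketch,
// QK_K == 256); the per-sub-block int8 scale is applied afterwards.
static inline int q6_decode(const block_q6_K * b, int j) {
    const int grp = j / 128, o = j % 128;    // 128-quant group and offset
    const int qt  = o / 32,  l = o % 32;     // quarter (0..3) and lane
    const uint8_t lob = b->ql[64*grp + l + (qt & 1 ? 32 : 0)];
    const int lo = (qt < 2) ? (lob & 0xF) : (lob >> 4);        // low 4 bits
    const int hi = (b->qh[32*grp + l] >> (2*qt)) & 3;          // high 2 bits
    return (lo | (hi << 4)) - 32;            // value in [-32, 31]
}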
323
2209
 
324
- #if QK_K == 256
325
- static inline void get_scale_min_k4(int j, const uint8_t * restrict q, uint8_t * restrict d, uint8_t * restrict m) {
326
- if (j < 4) {
327
- *d = q[j] & 63; *m = q[j + 4] & 63;
328
- } else {
329
- *d = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4);
330
- *m = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4);
2210
+ void quantize_row_q6_K(const float * restrict x, void * restrict vy, int k) {
2211
+ assert(k % QK_K == 0);
2212
+ block_q6_K * restrict y = vy;
2213
+ quantize_row_q6_K_reference(x, y, k);
2214
+ }
2215
+
2216
+ size_t ggml_quantize_q6_K(const float * src, void * dst, int n, int k, int64_t * hist) {
2217
+ assert(k % QK_K == 0);
2218
+ (void)hist; // TODO: collect histograms
2219
+
2220
+ for (int j = 0; j < n; j += k) {
2221
+ block_q6_K * restrict y = (block_q6_K *)dst + j/QK_K;
2222
+ quantize_row_q6_K_reference(src + j, y, k);
331
2223
  }
2224
+ return (n/QK_K*sizeof(block_q6_K));
332
2225
  }
333
- #endif
334
2226
 
335
- //========================- 2-bit (de)-quantization
2227
+ //===================================== Q8_K ==============================================
336
2228
 
337
- void quantize_row_q2_K_reference(const float * restrict x, block_q2_K * restrict y, int k) {
2229
+ void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int k) {
338
2230
  assert(k % QK_K == 0);
339
2231
  const int nb = k / QK_K;
340
2232
 
341
- uint8_t L[QK_K];
342
- uint8_t Laux[16];
343
- float weights[16];
344
- float mins[QK_K/16];
345
- float scales[QK_K/16];
346
-
347
- const float q4scale = 15.f;
348
-
349
2233
  for (int i = 0; i < nb; i++) {
350
- float max_scale = 0; // as we are deducting the min, scales are always positive
351
- float max_min = 0;
352
- for (int j = 0; j < QK_K/16; ++j) {
353
- for (int l = 0; l < 16; ++l) weights[l] = fabsf(x[16*j + l]);
354
- scales[j] = make_qkx2_quants(16, 3, x + 16*j, weights, L + 16*j, &mins[j], Laux, -0.5f, 0.1f, 15, true);
355
- float scale = scales[j];
356
- if (scale > max_scale) {
357
- max_scale = scale;
358
- }
359
- float min = mins[j];
360
- if (min > max_min) {
361
- max_min = min;
362
- }
363
- }
364
2234
 
365
- if (max_scale > 0) {
366
- float iscale = q4scale/max_scale;
367
- for (int j = 0; j < QK_K/16; ++j) {
368
- int l = nearest_int(iscale*scales[j]);
369
- y[i].scales[j] = l;
370
- }
371
- y[i].d = ggml_fp32_to_fp16(max_scale/q4scale);
372
- } else {
373
- for (int j = 0; j < QK_K/16; ++j) y[i].scales[j] = 0;
374
- y[i].d = ggml_fp32_to_fp16(0.f);
375
- }
376
- if (max_min > 0) {
377
- float iscale = q4scale/max_min;
378
- for (int j = 0; j < QK_K/16; ++j) {
379
- int l = nearest_int(iscale*mins[j]);
380
- y[i].scales[j] |= (l << 4);
2235
+ float max = 0;
2236
+ float amax = 0;
2237
+ for (int j = 0; j < QK_K; ++j) {
2238
+ float ax = fabsf(x[j]);
2239
+ if (ax > amax) {
2240
+ amax = ax; max = x[j];
381
2241
  }
382
- y[i].dmin = ggml_fp32_to_fp16(max_min/q4scale);
383
- } else {
384
- y[i].dmin = ggml_fp32_to_fp16(0.f);
2242
+ }
2243
+ if (!amax) {
2244
+ y[i].d = 0;
2245
+ memset(y[i].qs, 0, QK_K);
2246
+ x += QK_K;
2247
+ continue;
2248
+ }
2249
+ const float iscale = -128.f/max;
2250
+ for (int j = 0; j < QK_K; ++j) {
2251
+ int v = nearest_int(iscale*x[j]);
2252
+ y[i].qs[j] = MIN(127, v);
385
2253
  }
386
2254
  for (int j = 0; j < QK_K/16; ++j) {
387
- const float d = ggml_fp16_to_fp32(y[i].d) * (y[i].scales[j] & 0xF);
388
- if (!d) continue;
389
- const float dm = ggml_fp16_to_fp32(y[i].dmin) * (y[i].scales[j] >> 4);
2255
+ int sum = 0;
390
2256
  for (int ii = 0; ii < 16; ++ii) {
391
- int l = nearest_int((x[16*j + ii] + dm)/d);
392
- l = MAX(0, MIN(3, l));
393
- L[16*j + ii] = l;
394
- }
395
- }
396
-
397
- #if QK_K == 256
398
- for (int j = 0; j < QK_K; j += 128) {
399
- for (int l = 0; l < 32; ++l) {
400
- y[i].qs[j/4 + l] = L[j + l] | (L[j + l + 32] << 2) | (L[j + l + 64] << 4) | (L[j + l + 96] << 6);
2257
+ sum += y[i].qs[j*16 + ii];
401
2258
  }
2259
+ y[i].bsums[j] = sum;
402
2260
  }
403
- #else
404
- for (int l = 0; l < 16; ++l) {
405
- y[i].qs[l] = L[l] | (L[l + 16] << 2) | (L[l + 32] << 4) | (L[l + 48] << 6);
406
- }
407
- #endif
408
-
2261
+ y[i].d = 1/iscale;
409
2262
  x += QK_K;
410
-
411
2263
  }
412
2264
  }
413
2265
 
414
- void dequantize_row_q2_K(const block_q2_K * restrict x, float * restrict y, int k) {
2266
+ void dequantize_row_q8_K(const block_q8_K * restrict x, float * restrict y, int k) {
415
2267
  assert(k % QK_K == 0);
416
2268
  const int nb = k / QK_K;
417
2269
 
418
2270
  for (int i = 0; i < nb; i++) {
2271
+ for (int j = 0; j < QK_K; ++j) {
2272
+ *y++ = x[i].d * x[i].qs[j];
2273
+ }
2274
+ }
2275
+ }
419
2276
 
420
- const float d = ggml_fp16_to_fp32(x[i].d);
421
- const float min = ggml_fp16_to_fp32(x[i].dmin);
2277
+ void quantize_row_q8_K(const float * restrict x, void * restrict y, int k) {
2278
+ quantize_row_q8_K_reference(x, y, k);
2279
+ }
422
2280
 
423
- const uint8_t * q = x[i].qs;
2281
+ //===================================== Dot products =================================
424
2282
 
425
- #if QK_K == 256
426
- int is = 0;
427
- float dl, ml;
428
- for (int n = 0; n < QK_K; n += 128) {
429
- int shift = 0;
430
- for (int j = 0; j < 4; ++j) {
2283
+ //
2284
+ // Helper functions
2285
+ //
2286
+ #if __AVX__ || __AVX2__ || __AVX512F__
431
2287
 
432
- uint8_t sc = x[i].scales[is++];
433
- dl = d * (sc & 0xF); ml = min * (sc >> 4);
434
- for (int l = 0; l < 16; ++l) *y++ = dl * ((int8_t)((q[l] >> shift) & 3)) - ml;
2288
+ // shuffles to pick the required scales in dot products
2289
+ static inline __m256i get_scale_shuffle_q3k(int i) {
2290
+ static const uint8_t k_shuffle[128] = {
2291
+ 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3,
2292
+ 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7,
2293
+ 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,
2294
+ 12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13, 14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,
2295
+ };
2296
+ return _mm256_loadu_si256((const __m256i*)k_shuffle + i);
2297
+ }
2298
+ static inline __m256i get_scale_shuffle_k4(int i) {
2299
+ static const uint8_t k_shuffle[256] = {
2300
+ 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1,
2301
+ 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3,
2302
+ 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5,
2303
+ 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7,
2304
+ 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9,
2305
+ 10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,
2306
+ 12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,
2307
+ 14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15
2308
+ };
2309
+ return _mm256_loadu_si256((const __m256i*)k_shuffle + i);
2310
+ }
2311
+ static inline __m128i get_scale_shuffle(int i) {
2312
+ static const uint8_t k_shuffle[128] = {
2313
+ 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1,
2314
+ 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3,
2315
+ 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5,
2316
+ 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7,
2317
+ 8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9,
2318
+ 10,10,10,10,10,10,10,10, 11,11,11,11,11,11,11,11,
2319
+ 12,12,12,12,12,12,12,12, 13,13,13,13,13,13,13,13,
2320
+ 14,14,14,14,14,14,14,14, 15,15,15,15,15,15,15,15
2321
+ };
2322
+ return _mm_loadu_si128((const __m128i*)k_shuffle + i);
2323
+ }
2324
+ #endif
435
2325
 
436
- sc = x[i].scales[is++];
437
- dl = d * (sc & 0xF); ml = min * (sc >> 4);
438
- for (int l = 0; l < 16; ++l) *y++ = dl * ((int8_t)((q[l+16] >> shift) & 3)) - ml;
2326
+ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
2327
+ const int qk = QK8_0;
2328
+ const int nb = n / qk;
439
2329
 
440
- shift += 2;
441
- }
442
- q += 32;
443
- }
2330
+ assert(n % qk == 0);
2331
+
2332
+ const block_q4_0 * restrict x = vx;
2333
+ const block_q8_0 * restrict y = vy;
2334
+
2335
+ #if defined(__ARM_NEON)
2336
+ float32x4_t sumv0 = vdupq_n_f32(0.0f);
2337
+ float32x4_t sumv1 = vdupq_n_f32(0.0f);
2338
+
2339
+ assert(nb % 2 == 0); // TODO: handle odd nb
2340
+
2341
+ for (int i = 0; i < nb; i += 2) {
2342
+ const block_q4_0 * restrict x0 = &x[i + 0];
2343
+ const block_q4_0 * restrict x1 = &x[i + 1];
2344
+ const block_q8_0 * restrict y0 = &y[i + 0];
2345
+ const block_q8_0 * restrict y1 = &y[i + 1];
2346
+
2347
+ const uint8x16_t m4b = vdupq_n_u8(0x0F);
2348
+ const int8x16_t s8b = vdupq_n_s8(0x8);
2349
+
2350
+ const uint8x16_t v0_0 = vld1q_u8(x0->qs);
2351
+ const uint8x16_t v0_1 = vld1q_u8(x1->qs);
2352
+
2353
+ // 4-bit -> 8-bit
2354
+ const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
2355
+ const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
2356
+ const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
2357
+ const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
2358
+
2359
+ // sub 8
2360
+ const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
2361
+ const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
2362
+ const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
2363
+ const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
2364
+
2365
+ // load y
2366
+ const int8x16_t v1_0l = vld1q_s8(y0->qs);
2367
+ const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
2368
+ const int8x16_t v1_1l = vld1q_s8(y1->qs);
2369
+ const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
2370
+
2371
+ #if defined(__ARM_FEATURE_DOTPROD)
2372
+ // dot product into int32x4_t
2373
+ const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0l), v0_0hs, v1_0h);
2374
+ const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1l), v0_1hs, v1_1h);
2375
+
2376
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
2377
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
444
2378
  #else
445
- float dl1 = d * (x[i].scales[0] & 0xF), ml1 = min * (x[i].scales[0] >> 4);
446
- float dl2 = d * (x[i].scales[1] & 0xF), ml2 = min * (x[i].scales[1] >> 4);
447
- float dl3 = d * (x[i].scales[2] & 0xF), ml3 = min * (x[i].scales[2] >> 4);
448
- float dl4 = d * (x[i].scales[3] & 0xF), ml4 = min * (x[i].scales[3] >> 4);
449
- for (int l = 0; l < 16; ++l) {
450
- y[l+ 0] = dl1 * ((int8_t)((q[l] >> 0) & 3)) - ml1;
451
- y[l+16] = dl2 * ((int8_t)((q[l] >> 2) & 3)) - ml2;
452
- y[l+32] = dl3 * ((int8_t)((q[l] >> 4) & 3)) - ml3;
453
- y[l+48] = dl4 * ((int8_t)((q[l] >> 6) & 3)) - ml4;
454
- }
455
- y += QK_K;
2379
+ const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0l));
2380
+ const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0l));
2381
+ const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0h));
2382
+ const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0h));
2383
+
2384
+ const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1l));
2385
+ const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1l));
2386
+ const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1h));
2387
+ const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1h));
2388
+
2389
+ const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
2390
+ const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
2391
+ const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
2392
+ const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
2393
+
2394
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
2395
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
456
2396
  #endif
457
2397
  }
458
- }
459
2398
 
460
- void quantize_row_q2_K(const float * restrict x, void * restrict vy, int k) {
461
- quantize_row_q2_K_reference(x, vy, k);
462
- }
2399
+ *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
2400
+ #elif defined(__AVX2__)
2401
+ // Initialize accumulator with zeros
2402
+ __m256 acc = _mm256_setzero_ps();
463
2403
 
464
- size_t ggml_quantize_q2_K(const float * restrict src, void * restrict dst, int n, int k, int64_t * restrict hist) {
465
- (void)hist; // TODO: collect histograms
2404
+ // Main loop
2405
+ for (int i = 0; i < nb; ++i) {
2406
+ /* Compute combined scale for the block */
2407
+ const __m256 d = _mm256_set1_ps( GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d) );
466
2408
 
467
- for (int j = 0; j < n; j += k) {
468
- block_q2_K * restrict y = (block_q2_K *)dst + j/QK_K;
469
- quantize_row_q2_K_reference(src + j, y, k);
2409
+ __m256i bx = bytes_from_nibbles_32(x[i].qs);
2410
+
2411
+ // Now we have a vector with bytes in the [ 0 .. 15 ] interval. Offset them into the [ -8 .. +7 ] interval.
2412
+ const __m256i off = _mm256_set1_epi8( 8 );
2413
+ bx = _mm256_sub_epi8( bx, off );
2414
+
2415
+ __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
2416
+
2417
+ const __m256 q = mul_sum_i8_pairs_float(bx, by);
2418
+
2419
+ /* Multiply q with scale and accumulate */
2420
+ acc = _mm256_fmadd_ps( d, q, acc );
470
2421
  }
471
- return (n/QK_K*sizeof(block_q2_K));
472
- }
473
2422
 
474
- //========================= 3-bit (de)-quantization
2423
+ *s = hsum_float_8(acc);
2424
+ #elif defined(__AVX__)
2425
+ // Initialize accumulator with zeros
2426
+ __m256 acc = _mm256_setzero_ps();
475
2427
 
476
- void quantize_row_q3_K_reference(const float * restrict x, block_q3_K * restrict y, int k) {
477
- assert(k % QK_K == 0);
478
- const int nb = k / QK_K;
2428
+ // Main loop
2429
+ for (int i = 0; i < nb; ++i) {
2430
+ // Compute combined scale for the block
2431
+ const __m256 d = _mm256_set1_ps( GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d) );
479
2432
 
480
- int8_t L[QK_K];
481
- float scales[QK_K / 16];
2433
+ const __m128i lowMask = _mm_set1_epi8(0xF);
2434
+ const __m128i off = _mm_set1_epi8(8);
2435
+
2436
+ const __m128i tmp = _mm_loadu_si128((const __m128i *)x[i].qs);
2437
+
2438
+ __m128i bx = _mm_and_si128(lowMask, tmp);
2439
+ __m128i by = _mm_loadu_si128((const __m128i *)y[i].qs);
2440
+ bx = _mm_sub_epi8(bx, off);
2441
+ const __m128i i32_0 = mul_sum_i8_pairs(bx, by);
2442
+
2443
+ bx = _mm_and_si128(lowMask, _mm_srli_epi64(tmp, 4));
2444
+ by = _mm_loadu_si128((const __m128i *)(y[i].qs + 16));
2445
+ bx = _mm_sub_epi8(bx, off);
2446
+ const __m128i i32_1 = mul_sum_i8_pairs(bx, by);
2447
+
2448
+ // Convert int32_t to float
2449
+ __m256 p = _mm256_cvtepi32_ps(MM256_SET_M128I(i32_0, i32_1));
2450
+
2451
+ // Apply the scale, and accumulate
2452
+ acc = _mm256_add_ps(_mm256_mul_ps( d, p ), acc);
2453
+ }
2454
+
2455
+ *s = hsum_float_8(acc);
2456
+ #elif defined(__SSSE3__)
2457
+ // set constants
2458
+ const __m128i lowMask = _mm_set1_epi8(0xF);
2459
+ const __m128i off = _mm_set1_epi8(8);
2460
+
2461
+ // Initialize accumulator with zeros
2462
+ __m128 acc_0 = _mm_setzero_ps();
2463
+ __m128 acc_1 = _mm_setzero_ps();
2464
+ __m128 acc_2 = _mm_setzero_ps();
2465
+ __m128 acc_3 = _mm_setzero_ps();
2466
+
2467
+ // First round without accumulation
2468
+ {
2469
+ _mm_prefetch(&x[0] + sizeof(block_q4_0), _MM_HINT_T0);
2470
+ _mm_prefetch(&y[0] + sizeof(block_q8_0), _MM_HINT_T0);
2471
+
2472
+ // Compute combined scale for the block 0 and 1
2473
+ const __m128 d_0_1 = _mm_set1_ps( GGML_FP16_TO_FP32(x[0].d) * GGML_FP16_TO_FP32(y[0].d) );
2474
+
2475
+ const __m128i tmp_0_1 = _mm_loadu_si128((const __m128i *)x[0].qs);
2476
+
2477
+ __m128i bx_0 = _mm_and_si128(lowMask, tmp_0_1);
2478
+ __m128i by_0 = _mm_loadu_si128((const __m128i *)y[0].qs);
2479
+ bx_0 = _mm_sub_epi8(bx_0, off);
2480
+ const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0);
2481
+
2482
+ __m128i bx_1 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_0_1, 4));
2483
+ __m128i by_1 = _mm_loadu_si128((const __m128i *)(y[0].qs + 16));
2484
+ bx_1 = _mm_sub_epi8(bx_1, off);
2485
+ const __m128i i32_1 = mul_sum_i8_pairs(bx_1, by_1);
2486
+
2487
+ _mm_prefetch(&x[1] + sizeof(block_q4_0), _MM_HINT_T0);
2488
+ _mm_prefetch(&y[1] + sizeof(block_q8_0), _MM_HINT_T0);
2489
+
2490
+ // Compute combined scale for the block 2 and 3
2491
+ const __m128 d_2_3 = _mm_set1_ps( GGML_FP16_TO_FP32(x[1].d) * GGML_FP16_TO_FP32(y[1].d) );
2492
+
2493
+ const __m128i tmp_2_3 = _mm_loadu_si128((const __m128i *)x[1].qs);
2494
+
2495
+ __m128i bx_2 = _mm_and_si128(lowMask, tmp_2_3);
2496
+ __m128i by_2 = _mm_loadu_si128((const __m128i *)y[1].qs);
2497
+ bx_2 = _mm_sub_epi8(bx_2, off);
2498
+ const __m128i i32_2 = mul_sum_i8_pairs(bx_2, by_2);
2499
+
2500
+ __m128i bx_3 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_2_3, 4));
2501
+ __m128i by_3 = _mm_loadu_si128((const __m128i *)(y[1].qs + 16));
2502
+ bx_3 = _mm_sub_epi8(bx_3, off);
2503
+ const __m128i i32_3 = mul_sum_i8_pairs(bx_3, by_3);
2504
+
2505
+ // Convert int32_t to float
2506
+ __m128 p0 = _mm_cvtepi32_ps(i32_0);
2507
+ __m128 p1 = _mm_cvtepi32_ps(i32_1);
2508
+ __m128 p2 = _mm_cvtepi32_ps(i32_2);
2509
+ __m128 p3 = _mm_cvtepi32_ps(i32_3);
2510
+
2511
+ // Apply the scale
2512
+ acc_0 = _mm_mul_ps( d_0_1, p0 );
2513
+ acc_1 = _mm_mul_ps( d_0_1, p1 );
2514
+ acc_2 = _mm_mul_ps( d_2_3, p2 );
2515
+ acc_3 = _mm_mul_ps( d_2_3, p3 );
2516
+ }
2517
+
2518
+ assert(nb % 2 == 0); // TODO: handle odd nb
2519
+
2520
+ // Main loop
2521
+ for (int i = 2; i < nb; i+=2) {
2522
+ _mm_prefetch(&x[i] + sizeof(block_q4_0), _MM_HINT_T0);
2523
+ _mm_prefetch(&y[i] + sizeof(block_q8_0), _MM_HINT_T0);
2524
+
2525
+ // Compute combined scale for the block 0 and 1
2526
+ const __m128 d_0_1 = _mm_set1_ps( GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d) );
2527
+
2528
+ const __m128i tmp_0_1 = _mm_loadu_si128((const __m128i *)x[i].qs);
2529
+
2530
+ __m128i bx_0 = _mm_and_si128(lowMask, tmp_0_1);
2531
+ __m128i by_0 = _mm_loadu_si128((const __m128i *)y[i].qs);
2532
+ bx_0 = _mm_sub_epi8(bx_0, off);
2533
+ const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0);
2534
+
2535
+ __m128i bx_1 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_0_1, 4));
2536
+ __m128i by_1 = _mm_loadu_si128((const __m128i *)(y[i].qs + 16));
2537
+ bx_1 = _mm_sub_epi8(bx_1, off);
2538
+ const __m128i i32_1 = mul_sum_i8_pairs(bx_1, by_1);
2539
+
2540
+ _mm_prefetch(&x[i] + 2 * sizeof(block_q4_0), _MM_HINT_T0);
2541
+ _mm_prefetch(&y[i] + 2 * sizeof(block_q8_0), _MM_HINT_T0);
2542
+
2543
+ // Compute combined scale for the block 2 and 3
2544
+ const __m128 d_2_3 = _mm_set1_ps( GGML_FP16_TO_FP32(x[i + 1].d) * GGML_FP16_TO_FP32(y[i + 1].d) );
2545
+
2546
+ const __m128i tmp_2_3 = _mm_loadu_si128((const __m128i *)x[i + 1].qs);
2547
+
2548
+ __m128i bx_2 = _mm_and_si128(lowMask, tmp_2_3);
2549
+ __m128i by_2 = _mm_loadu_si128((const __m128i *)y[i + 1].qs);
2550
+ bx_2 = _mm_sub_epi8(bx_2, off);
2551
+ const __m128i i32_2 = mul_sum_i8_pairs(bx_2, by_2);
2552
+
2553
+ __m128i bx_3 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_2_3, 4));
2554
+ __m128i by_3 = _mm_loadu_si128((const __m128i *)(y[i + 1].qs + 16));
2555
+ bx_3 = _mm_sub_epi8(bx_3, off);
2556
+ const __m128i i32_3 = mul_sum_i8_pairs(bx_3, by_3);
2557
+
2558
+ // Convert int32_t to float
2559
+ __m128 p0 = _mm_cvtepi32_ps(i32_0);
2560
+ __m128 p1 = _mm_cvtepi32_ps(i32_1);
2561
+ __m128 p2 = _mm_cvtepi32_ps(i32_2);
2562
+ __m128 p3 = _mm_cvtepi32_ps(i32_3);
2563
+
2564
+ // Apply the scale
2565
+ __m128 p0_d = _mm_mul_ps( d_0_1, p0 );
2566
+ __m128 p1_d = _mm_mul_ps( d_0_1, p1 );
2567
+ __m128 p2_d = _mm_mul_ps( d_2_3, p2 );
2568
+ __m128 p3_d = _mm_mul_ps( d_2_3, p3 );
2569
+
2570
+ // Accumulate
2571
+ acc_0 = _mm_add_ps(p0_d, acc_0);
2572
+ acc_1 = _mm_add_ps(p1_d, acc_1);
2573
+ acc_2 = _mm_add_ps(p2_d, acc_2);
2574
+ acc_3 = _mm_add_ps(p3_d, acc_3);
2575
+ }
2576
+
2577
+ *s = hsum_float_4x4(acc_0, acc_1, acc_2, acc_3);
2578
+ #elif defined(__riscv_v_intrinsic)
2579
+ float sumf = 0.0;
2580
+
2581
+ size_t vl = __riscv_vsetvl_e8m1(qk/2);
482
2582
 
483
2583
  for (int i = 0; i < nb; i++) {
2584
+ // load elements
2585
+ vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[i].qs, vl);
484
2586
 
485
- float max_scale = 0;
486
- float amax = 0;
487
- for (int j = 0; j < QK_K/16; ++j) {
488
- scales[j] = make_q3_quants(16, 4, x + 16*j, L + 16*j, true);
489
- float scale = fabsf(scales[j]);
490
- if (scale > amax) {
491
- amax = scale; max_scale = scales[j];
492
- }
493
- }
2587
+ vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[i].qs, vl);
2588
+ vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[i].qs+16, vl);
494
2589
 
495
- #if QK_K == 256
496
- memset(y[i].scales, 0, 12);
497
- if (max_scale) {
498
- float iscale = -32.f/max_scale;
499
- for (int j = 0; j < QK_K/16; ++j) {
500
- int8_t l = nearest_int(iscale*scales[j]);
501
- l = MAX(-32, MIN(31, l)) + 32;
502
- if (j < 8) {
503
- y[i].scales[j] = l & 0xF;
504
- } else {
505
- y[i].scales[j-8] |= ((l & 0xF) << 4);
506
- }
507
- l >>= 4;
508
- y[i].scales[j%4 + 8] |= (l << (2*(j/4)));
509
- }
510
- y[i].d = ggml_fp32_to_fp16(1/iscale);
511
- } else {
512
- y[i].d = ggml_fp32_to_fp16(0.f);
513
- }
2590
+ // mask and store lower part of x, and then upper part
2591
+ vuint8mf2_t x_a = __riscv_vand_vx_u8mf2(tx, 0x0F, vl);
2592
+ vuint8mf2_t x_l = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl);
514
2593
 
515
- int8_t sc;
516
- for (int j = 0; j < QK_K/16; ++j) {
517
- sc = j < 8 ? y[i].scales[j] & 0xF : y[i].scales[j-8] >> 4;
518
- sc = (sc | (((y[i].scales[8 + j%4] >> (2*(j/4))) & 3) << 4)) - 32;
519
- float d = ggml_fp16_to_fp32(y[i].d) * sc;
520
- if (!d) {
521
- continue;
522
- }
523
- for (int ii = 0; ii < 16; ++ii) {
524
- int l = nearest_int(x[16*j + ii]/d);
525
- l = MAX(-4, MIN(3, l));
526
- L[16*j + ii] = l + 4;
527
- }
528
- }
529
- #else
530
- if (max_scale) {
531
- float iscale = -8.f/max_scale;
532
- for (int j = 0; j < QK_K/16; j+=2) {
533
- int l1 = nearest_int(iscale*scales[j]);
534
- l1 = 8 + MAX(-8, MIN(7, l1));
535
- int l2 = nearest_int(iscale*scales[j+1]);
536
- l2 = 8 + MAX(-8, MIN(7, l2));
537
- y[i].scales[j/2] = l1 | (l2 << 4);
538
- }
539
- y[i].d = ggml_fp32_to_fp16(1/iscale);
540
- } else {
541
- for (int j = 0; j < QK_K/16; j+=2) {
542
- y[i].scales[j/2] = 0;
543
- }
544
- y[i].d = ggml_fp32_to_fp16(0.f);
545
- }
546
- for (int j = 0; j < QK_K/16; ++j) {
547
- int s = j%2 == 0 ? y[i].scales[j/2] & 0xF : y[i].scales[j/2] >> 4;
548
- float d = ggml_fp16_to_fp32(y[i].d) * (s - 8);
549
- if (!d) {
550
- continue;
551
- }
552
- for (int ii = 0; ii < 16; ++ii) {
553
- int l = nearest_int(x[16*j + ii]/d);
554
- l = MAX(-4, MIN(3, l));
555
- L[16*j + ii] = l + 4;
556
- }
557
- }
558
- #endif
2594
+ vint8mf2_t x_ai = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a);
2595
+ vint8mf2_t x_li = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l);
2596
+
2597
+ // subtract offset
2598
+ vint8mf2_t v0 = __riscv_vsub_vx_i8mf2(x_ai, 8, vl);
2599
+ vint8mf2_t v1 = __riscv_vsub_vx_i8mf2(x_li, 8, vl);
2600
+
2601
+ vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl);
2602
+ vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl);
2603
+
2604
+ vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
2605
+
2606
+ vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl);
2607
+ vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl);
2608
+
2609
+ int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);
2610
+
2611
+ sumf += sumi*GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d);
2612
+ }
559
2613
 
560
- memset(y[i].hmask, 0, QK_K/8);
561
- // We put the high-bit for the 1st 8 quants into bit 0, the next 8 into bit 1, etc.
562
- int m = 0;
563
- uint8_t hm = 1;
564
- for (int j = 0; j < QK_K; ++j) {
565
- if (L[j] > 3) {
566
- y[i].hmask[m] |= hm;
567
- L[j] -= 4;
568
- }
569
- if (++m == QK_K/8) {
570
- m = 0; hm <<= 1;
571
- }
572
- }
573
- #if QK_K == 256
574
- for (int j = 0; j < QK_K; j += 128) {
575
- for (int l = 0; l < 32; ++l) {
576
- y[i].qs[j/4 + l] = L[j + l] | (L[j + l + 32] << 2) | (L[j + l + 64] << 4) | (L[j + l + 96] << 6);
577
- }
578
- }
2614
+ *s = sumf;
579
2615
  #else
580
- for (int l = 0; l < 16; ++l) {
581
- y[i].qs[l] = L[l] | (L[l + 16] << 2) | (L[l + 32] << 4) | (L[l + 48] << 6);
2616
+ // scalar
2617
+ float sumf = 0.0;
2618
+
2619
+ for (int i = 0; i < nb; i++) {
2620
+ int sumi = 0;
2621
+
2622
+ for (int j = 0; j < qk/2; ++j) {
2623
+ const int v0 = (x[i].qs[j] & 0x0F) - 8;
2624
+ const int v1 = (x[i].qs[j] >> 4) - 8;
2625
+
2626
+ sumi += (v0 * y[i].qs[j]) + (v1 * y[i].qs[j + qk/2]);
582
2627
  }
583
- #endif
584
2628
 
585
- x += QK_K;
2629
+ sumf += sumi*GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d);
586
2630
  }
2631
+
2632
+ *s = sumf;
2633
+ #endif
587
2634
  }
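Every branch above (NEON, AVX2, AVX, SSSE3, RISC-V vector, scalar) computes the same quantity, spelled out by the scalar fallback:

\[
\mathtt{*s} \;=\; \sum_{i=0}^{nb-1} d_{x,i}\, d_{y,i} \sum_{j=0}^{qk-1} \bigl(q_{x,ij} - 8\bigr)\, q_{y,ij} ,
\]

where the q_{x,ij} are the unpacked 4-bit nibbles in [0, 15] and the subtraction of 8 matches the off constant used in the SIMD paths.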
588
2635
 
589
- #if QK_K == 256
590
- void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int k) {
591
- assert(k % QK_K == 0);
592
- const int nb = k / QK_K;
2636
+ void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
2637
+ const int qk = QK8_1;
2638
+ const int nb = n / qk;
593
2639
 
594
- const uint32_t kmask1 = 0x03030303;
595
- const uint32_t kmask2 = 0x0f0f0f0f;
2640
+ assert(n % qk == 0);
596
2641
 
597
- uint32_t aux[4];
598
- const int8_t * scales = (const int8_t*)aux;
2642
+ const block_q4_1 * restrict x = vx;
2643
+ const block_q8_1 * restrict y = vy;
599
2644
 
600
- for (int i = 0; i < nb; i++) {
2645
+ // TODO: add WASM SIMD
2646
+ #if defined(__ARM_NEON)
2647
+ float32x4_t sumv0 = vdupq_n_f32(0.0f);
2648
+ float32x4_t sumv1 = vdupq_n_f32(0.0f);
601
2649
 
602
- const float d_all = ggml_fp16_to_fp32(x[i].d);
2650
+ float summs = 0;
603
2651
 
604
- const uint8_t * restrict q = x[i].qs;
605
- const uint8_t * restrict hm = x[i].hmask;
606
- uint8_t m = 1;
2652
+ assert(nb % 2 == 0); // TODO: handle odd nb
607
2653
 
608
- memcpy(aux, x[i].scales, 12);
609
- uint32_t tmp = aux[2];
610
- aux[2] = ((aux[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);
611
- aux[3] = ((aux[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4);
612
- aux[0] = (aux[0] & kmask2) | (((tmp >> 0) & kmask1) << 4);
613
- aux[1] = (aux[1] & kmask2) | (((tmp >> 2) & kmask1) << 4);
2654
+ for (int i = 0; i < nb; i += 2) {
2655
+ const block_q4_1 * restrict x0 = &x[i + 0];
2656
+ const block_q4_1 * restrict x1 = &x[i + 1];
2657
+ const block_q8_1 * restrict y0 = &y[i + 0];
2658
+ const block_q8_1 * restrict y1 = &y[i + 1];
614
2659
 
615
- int is = 0;
616
- float dl;
617
- for (int n = 0; n < QK_K; n += 128) {
618
- int shift = 0;
619
- for (int j = 0; j < 4; ++j) {
2660
+ summs += GGML_FP16_TO_FP32(x0->m) * y0->s + GGML_FP16_TO_FP32(x1->m) * y1->s;
620
2661
 
621
- dl = d_all * (scales[is++] - 32);
622
- for (int l = 0; l < 16; ++l) {
623
- *y++ = dl * ((int8_t)((q[l+ 0] >> shift) & 3) - ((hm[l+ 0] & m) ? 0 : 4));
624
- }
2662
+ const uint8x16_t m4b = vdupq_n_u8(0x0F);
625
2663
 
626
- dl = d_all * (scales[is++] - 32);
627
- for (int l = 0; l < 16; ++l) {
628
- *y++ = dl * ((int8_t)((q[l+16] >> shift) & 3) - ((hm[l+16] & m) ? 0 : 4));
629
- }
2664
+ const uint8x16_t v0_0 = vld1q_u8(x0->qs);
2665
+ const uint8x16_t v0_1 = vld1q_u8(x1->qs);
630
2666
 
631
- shift += 2;
632
- m <<= 1;
633
- }
634
- q += 32;
635
- }
2667
+ // 4-bit -> 8-bit
2668
+ const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
2669
+ const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
2670
+ const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
2671
+ const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
2672
+
2673
+ // load y
2674
+ const int8x16_t v1_0l = vld1q_s8(y0->qs);
2675
+ const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
2676
+ const int8x16_t v1_1l = vld1q_s8(y1->qs);
2677
+ const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
2678
+
2679
+ #if defined(__ARM_FEATURE_DOTPROD)
2680
+ // dot product into int32x4_t
2681
+ const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0l), v0_0h, v1_0h);
2682
+ const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1l), v0_1h, v1_1h);
636
2683
 
2684
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), GGML_FP16_TO_FP32(x0->d)*y0->d);
2685
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), GGML_FP16_TO_FP32(x1->d)*y1->d);
2686
+ #else
2687
+ const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0l), vget_low_s8 (v1_0l));
2688
+ const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0l), vget_high_s8(v1_0l));
2689
+ const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0h), vget_low_s8 (v1_0h));
2690
+ const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0h), vget_high_s8(v1_0h));
2691
+
2692
+ const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1l), vget_low_s8 (v1_1l));
2693
+ const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1l), vget_high_s8(v1_1l));
2694
+ const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1h), vget_low_s8 (v1_1h));
2695
+ const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1h), vget_high_s8(v1_1h));
2696
+
2697
+ const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
2698
+ const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
2699
+ const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
2700
+ const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
2701
+
2702
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), GGML_FP16_TO_FP32(x0->d)*y0->d);
2703
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), GGML_FP16_TO_FP32(x1->d)*y1->d);
2704
+ #endif
637
2705
  }
638
- }
2706
+
2707
+ *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs;
2708
+ #elif defined(__AVX2__) || defined(__AVX__)
2709
+ // Initialize accumulator with zeros
2710
+ __m256 acc = _mm256_setzero_ps();
2711
+
2712
+ float summs = 0;
2713
+
2714
+ // Main loop
2715
+ for (int i = 0; i < nb; ++i) {
2716
+ const float d0 = GGML_FP16_TO_FP32(x[i].d);
2717
+ const float d1 = y[i].d;
2718
+
2719
+ summs += GGML_FP16_TO_FP32(x[i].m) * y[i].s;
2720
+
2721
+ const __m256 d0v = _mm256_set1_ps( d0 );
2722
+ const __m256 d1v = _mm256_set1_ps( d1 );
2723
+
2724
+ // Compute combined scales
2725
+ const __m256 d0d1 = _mm256_mul_ps( d0v, d1v );
2726
+
2727
+ // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
2728
+ const __m256i bx = bytes_from_nibbles_32(x[i].qs);
2729
+ const __m256i by = _mm256_loadu_si256( (const __m256i *)y[i].qs );
2730
+
2731
+ const __m256 xy = mul_sum_us8_pairs_float(bx, by);
2732
+
2733
+ // Accumulate d0*d1*x*y
2734
+ #if defined(__AVX2__)
2735
+ acc = _mm256_fmadd_ps( d0d1, xy, acc );
639
2736
  #else
640
- void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int k) {
641
- assert(k % QK_K == 0);
642
- assert(QK_K == 64);
643
- const int nb = k / QK_K;
2737
+ acc = _mm256_add_ps( _mm256_mul_ps( d0d1, xy ), acc );
2738
+ #endif
2739
+ }
2740
+
2741
+ *s = hsum_float_8(acc) + summs;
2742
+ #elif defined(__riscv_v_intrinsic)
2743
+ float sumf = 0.0;
2744
+
2745
+ size_t vl = __riscv_vsetvl_e8m1(qk/2);
644
2746
 
645
2747
  for (int i = 0; i < nb; i++) {
2748
+ // load elements
2749
+ vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[i].qs, vl);
646
2750
 
647
- const float d_all = ggml_fp16_to_fp32(x[i].d);
2751
+ vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[i].qs, vl);
2752
+ vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[i].qs+16, vl);
648
2753
 
649
- const uint8_t * restrict q = x[i].qs;
650
- const uint8_t * restrict hm = x[i].hmask;
2754
+ // mask and store lower part of x, and then upper part
2755
+ vuint8mf2_t x_a = __riscv_vand_vx_u8mf2(tx, 0x0F, vl);
2756
+ vuint8mf2_t x_l = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl);
651
2757
 
652
- const float d1 = d_all * ((x[i].scales[0] & 0xF) - 8);
653
- const float d2 = d_all * ((x[i].scales[0] >> 4) - 8);
654
- const float d3 = d_all * ((x[i].scales[1] & 0xF) - 8);
655
- const float d4 = d_all * ((x[i].scales[1] >> 4) - 8);
2758
+ vint8mf2_t v0 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a);
2759
+ vint8mf2_t v1 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l);
656
2760
 
657
- for (int l=0; l<8; ++l) {
658
- uint8_t h = hm[l];
659
- y[l+ 0] = d1 * ((int8_t)((q[l+0] >> 0) & 3) - ((h & 0x01) ? 0 : 4));
660
- y[l+ 8] = d1 * ((int8_t)((q[l+8] >> 0) & 3) - ((h & 0x02) ? 0 : 4));
661
- y[l+16] = d2 * ((int8_t)((q[l+0] >> 2) & 3) - ((h & 0x04) ? 0 : 4));
662
- y[l+24] = d2 * ((int8_t)((q[l+8] >> 2) & 3) - ((h & 0x08) ? 0 : 4));
663
- y[l+32] = d3 * ((int8_t)((q[l+0] >> 4) & 3) - ((h & 0x10) ? 0 : 4));
664
- y[l+40] = d3 * ((int8_t)((q[l+8] >> 4) & 3) - ((h & 0x20) ? 0 : 4));
665
- y[l+48] = d4 * ((int8_t)((q[l+0] >> 6) & 3) - ((h & 0x40) ? 0 : 4));
666
- y[l+56] = d4 * ((int8_t)((q[l+8] >> 6) & 3) - ((h & 0x80) ? 0 : 4));
2761
+ vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl);
2762
+ vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl);
2763
+
2764
+ vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
2765
+
2766
+ vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl);
2767
+ vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl);
2768
+
2769
+ int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);
2770
+
2771
+ sumf += (GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi + GGML_FP16_TO_FP32(x[i].m)*y[i].s;
2772
+ }
2773
+
2774
+ *s = sumf;
2775
+ #else
2776
+ // scalar
2777
+ float sumf = 0.0;
2778
+
2779
+ for (int i = 0; i < nb; i++) {
2780
+ int sumi = 0;
2781
+
2782
+ for (int j = 0; j < qk/2; ++j) {
2783
+ const int v0 = (x[i].qs[j] & 0x0F);
2784
+ const int v1 = (x[i].qs[j] >> 4);
2785
+
2786
+ sumi += (v0 * y[i].qs[j]) + (v1 * y[i].qs[j + qk/2]);
667
2787
  }
668
- y += QK_K;
2788
+
2789
+ sumf += (GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi + GGML_FP16_TO_FP32(x[i].m)*y[i].s;
669
2790
  }
670
- }
671
- #endif
672
2791
 
673
- void quantize_row_q3_K(const float * restrict x, void * restrict vy, int k) {
674
- quantize_row_q3_K_reference(x, vy, k);
2792
+ *s = sumf;
2793
+ #endif
675
2794
  }
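The scalar fallback above makes the structure of the Q4_1 x Q8_1 kernel explicit: each 4-bit weight dequantizes to d*v + m, so the block minimum can be folded into the single summs term instead of being applied per element. A minimal standalone sketch of one block pair, assuming QK8_1 == 32 and that a Q8_1 block's s field stores d times the sum of its 32 int8 values (the helper name is illustrative, not part of the source):

#include <stdint.h>

static float dot_q4_1_block_ref(float dx, float mx, const uint8_t *qs,   // one Q4_1 block
                                float dy, float ys, const int8_t *q8) {  // one Q8_1 block
    int sumi = 0;
    for (int j = 0; j < 16; ++j) {
        const int v0 = qs[j] & 0x0F;          // weights 0..15: low nibbles
        const int v1 = qs[j] >> 4;            // weights 16..31: high nibbles
        sumi += v0*q8[j] + v1*q8[j + 16];
    }
    // sum_j (dx*v_j + mx) * (dy*q8_j) = dx*dy*sumi + mx*(dy*sum_j q8_j) = dx*dy*sumi + mx*ys
    return dx*dy*sumi + mx*ys;
}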
676
2795
 
677
- size_t ggml_quantize_q3_K(const float * restrict src, void * restrict dst, int n, int k, int64_t * restrict hist) {
678
- (void)hist; // TODO: collect histograms
2796
+ void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
2797
+ const int qk = QK8_0;
2798
+ const int nb = n / qk;
679
2799
 
680
- for (int j = 0; j < n; j += k) {
681
- block_q3_K * restrict y = (block_q3_K *)dst + j/QK_K;
682
- quantize_row_q3_K_reference(src + j, y, k);
2800
+ assert(n % qk == 0);
2801
+ assert(qk == QK5_0);
2802
+
2803
+ const block_q5_0 * restrict x = vx;
2804
+ const block_q8_0 * restrict y = vy;
2805
+
2806
+ #if defined(__ARM_NEON)
2807
+ float32x4_t sumv0 = vdupq_n_f32(0.0f);
2808
+ float32x4_t sumv1 = vdupq_n_f32(0.0f);
2809
+
2810
+ uint32_t qh0;
2811
+ uint32_t qh1;
2812
+
2813
+ uint64_t tmp0[4];
2814
+ uint64_t tmp1[4];
2815
+
2816
+ assert(nb % 2 == 0); // TODO: handle odd nb
2817
+
2818
+ for (int i = 0; i < nb; i += 2) {
2819
+ const block_q5_0 * restrict x0 = &x[i];
2820
+ const block_q5_0 * restrict x1 = &x[i + 1];
2821
+ const block_q8_0 * restrict y0 = &y[i];
2822
+ const block_q8_0 * restrict y1 = &y[i + 1];
2823
+
2824
+ const uint8x16_t m4b = vdupq_n_u8(0x0F);
2825
+
2826
+ // extract the 5th bit via lookup table ((!b) << 4)
2827
+ memcpy(&qh0, x0->qh, sizeof(qh0));
2828
+ memcpy(&qh1, x1->qh, sizeof(qh1));
2829
+
2830
+ tmp0[0] = table_b2b_1[(qh0 >> 0) & 0xFF];
2831
+ tmp0[1] = table_b2b_1[(qh0 >> 8) & 0xFF];
2832
+ tmp0[2] = table_b2b_1[(qh0 >> 16) & 0xFF];
2833
+ tmp0[3] = table_b2b_1[(qh0 >> 24) ];
2834
+
2835
+ tmp1[0] = table_b2b_1[(qh1 >> 0) & 0xFF];
2836
+ tmp1[1] = table_b2b_1[(qh1 >> 8) & 0xFF];
2837
+ tmp1[2] = table_b2b_1[(qh1 >> 16) & 0xFF];
2838
+ tmp1[3] = table_b2b_1[(qh1 >> 24) ];
2839
+
2840
+ const int8x16_t qhl0 = vld1q_s8((const int8_t *)(tmp0 + 0));
2841
+ const int8x16_t qhh0 = vld1q_s8((const int8_t *)(tmp0 + 2));
2842
+ const int8x16_t qhl1 = vld1q_s8((const int8_t *)(tmp1 + 0));
2843
+ const int8x16_t qhh1 = vld1q_s8((const int8_t *)(tmp1 + 2));
2844
+
2845
+ const uint8x16_t v0_0 = vld1q_u8(x0->qs);
2846
+ const uint8x16_t v0_1 = vld1q_u8(x1->qs);
2847
+
2848
+ // 4-bit -> 8-bit
2849
+ int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
2850
+ int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
2851
+ int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
2852
+ int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
2853
+
2854
+ // add high bit and sub 16 (equivalent to sub 0x10 when bit is zero)
2855
+ const int8x16_t v0_0lf = vsubq_s8(v0_0l, qhl0);
2856
+ const int8x16_t v0_0hf = vsubq_s8(v0_0h, qhh0);
2857
+ const int8x16_t v0_1lf = vsubq_s8(v0_1l, qhl1);
2858
+ const int8x16_t v0_1hf = vsubq_s8(v0_1h, qhh1);
2859
+
2860
+ // load y
2861
+ const int8x16_t v1_0l = vld1q_s8(y0->qs);
2862
+ const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
2863
+ const int8x16_t v1_1l = vld1q_s8(y1->qs);
2864
+ const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
2865
+
2866
+ #if defined(__ARM_FEATURE_DOTPROD)
2867
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
2868
+ vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l),
2869
+ vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
2870
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
2871
+ vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l),
2872
+ vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
2873
+ #else
2874
+ const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lf), vget_low_s8 (v1_0l));
2875
+ const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lf), vget_high_s8(v1_0l));
2876
+ const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hf), vget_low_s8 (v1_0h));
2877
+ const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hf), vget_high_s8(v1_0h));
2878
+
2879
+ const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lf), vget_low_s8 (v1_1l));
2880
+ const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lf), vget_high_s8(v1_1l));
2881
+ const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hf), vget_low_s8 (v1_1h));
2882
+ const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hf), vget_high_s8(v1_1h));
2883
+
2884
+ const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
2885
+ const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
2886
+ const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
2887
+ const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
2888
+
2889
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
2890
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
2891
+ #endif
683
2892
  }
684
- return (n/QK_K*sizeof(block_q3_K));
685
- }
686
2893
 
687
- // ====================== 4-bit (de)-quantization
2894
+ *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
2895
+ #elif defined(__wasm_simd128__)
2896
+ v128_t sumv = wasm_f32x4_splat(0.0f);
688
2897
 
689
- void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict y, int k) {
690
- assert(k % QK_K == 0);
691
- const int nb = k / QK_K;
2898
+ uint32_t qh;
2899
+ uint64_t tmp[4];
692
2900
 
693
- uint8_t L[QK_K];
694
- uint8_t Laux[32];
695
- float weights[32];
696
- float mins[QK_K/32];
697
- float scales[QK_K/32];
2901
+ // TODO: check if unrolling this is better
2902
+ for (int i = 0; i < nb; ++i) {
2903
+ const block_q5_0 * restrict x0 = &x[i];
2904
+ const block_q8_0 * restrict y0 = &y[i];
2905
+
2906
+ const v128_t m4b = wasm_i8x16_splat(0x0F);
2907
+
2908
+ // extract the 5th bit
2909
+ memcpy(&qh, x0->qh, sizeof(qh));
2910
+
2911
+ tmp[0] = table_b2b_1[(qh >> 0) & 0xFF];
2912
+ tmp[1] = table_b2b_1[(qh >> 8) & 0xFF];
2913
+ tmp[2] = table_b2b_1[(qh >> 16) & 0xFF];
2914
+ tmp[3] = table_b2b_1[(qh >> 24) ];
2915
+
2916
+ const v128_t qhl = wasm_v128_load(tmp + 0);
2917
+ const v128_t qhh = wasm_v128_load(tmp + 2);
2918
+
2919
+ const v128_t v0 = wasm_v128_load(x0->qs);
2920
+
2921
+ // 4-bit -> 8-bit
2922
+ const v128_t v0l = wasm_v128_and (v0, m4b);
2923
+ const v128_t v0h = wasm_u8x16_shr(v0, 4);
2924
+
2925
+ // add high bit and sub 16 (equivalent to sub 0x10 when bit is zero)
2926
+ const v128_t v0lf = wasm_i8x16_sub(v0l, qhl);
2927
+ const v128_t v0hf = wasm_i8x16_sub(v0h, qhh);
2928
+
2929
+ // load y
2930
+ const v128_t v1l = wasm_v128_load(y0->qs);
2931
+ const v128_t v1h = wasm_v128_load(y0->qs + 16);
2932
+
2933
+ // int8x16 -> int16x8
2934
+ const v128_t v0lfl = wasm_i16x8_extend_low_i8x16 (v0lf);
2935
+ const v128_t v0lfh = wasm_i16x8_extend_high_i8x16(v0lf);
2936
+ const v128_t v0hfl = wasm_i16x8_extend_low_i8x16 (v0hf);
2937
+ const v128_t v0hfh = wasm_i16x8_extend_high_i8x16(v0hf);
2938
+
2939
+ const v128_t v1ll = wasm_i16x8_extend_low_i8x16 (v1l);
2940
+ const v128_t v1lh = wasm_i16x8_extend_high_i8x16(v1l);
2941
+ const v128_t v1hl = wasm_i16x8_extend_low_i8x16 (v1h);
2942
+ const v128_t v1hh = wasm_i16x8_extend_high_i8x16(v1h);
2943
+
2944
+ // dot product
2945
+ sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4(
2946
+ wasm_i32x4_add(
2947
+ wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0lfl, v1ll),
2948
+ wasm_i32x4_dot_i16x8(v0lfh, v1lh)),
2949
+ wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0hfl, v1hl),
2950
+ wasm_i32x4_dot_i16x8(v0hfh, v1hh)))),
2951
+ wasm_f32x4_splat(GGML_FP16_TO_FP32(x0->d) * GGML_FP16_TO_FP32(y0->d))));
2952
+ }
2953
+
2954
+ *s = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) +
2955
+ wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3);
2956
+ #elif defined(__AVX2__)
2957
+ // Initialize accumulator with zeros
2958
+ __m256 acc = _mm256_setzero_ps();
698
2959
 
2960
+ // Main loop
699
2961
  for (int i = 0; i < nb; i++) {
2962
+ /* Compute combined scale for the block */
2963
+ const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d));
700
2964
 
701
- float max_scale = 0; // as we are deducting the min, scales are always positive
702
- float max_min = 0;
703
- for (int j = 0; j < QK_K/32; ++j) {
704
- //scales[j] = make_qkx1_quants(32, 15, x + 32*j, L + 32*j, &mins[j], 9, 0.5f);
705
- float sum_x2 = 0;
706
- for (int l = 0; l < 32; ++l) sum_x2 += x[32*j + l] * x[32*j + l];
707
- float av_x = sqrtf(sum_x2/32);
708
- for (int l = 0; l < 32; ++l) weights[l] = av_x + fabsf(x[32*j + l]);
709
- scales[j] = make_qkx2_quants(32, 15, x + 32*j, weights, L + 32*j, &mins[j], Laux, -1.f, 0.1f, 20, false);
710
- float scale = scales[j];
711
- if (scale > max_scale) {
712
- max_scale = scale;
713
- }
714
- float min = mins[j];
715
- if (min > max_min) {
716
- max_min = min;
717
- }
718
- }
2965
+ __m256i bx = bytes_from_nibbles_32(x[i].qs);
2966
+ __m256i bxhi = bytes_from_bits_32(x[i].qh);
2967
+ bxhi = _mm256_andnot_si256(bxhi, _mm256_set1_epi8((char)0xF0));
2968
+ bx = _mm256_or_si256(bx, bxhi);
719
2969
 
720
- #if QK_K == 256
721
- float inv_scale = max_scale > 0 ? 63.f/max_scale : 0.f;
722
- float inv_min = max_min > 0 ? 63.f/max_min : 0.f;
723
- for (int j = 0; j < QK_K/32; ++j) {
724
- uint8_t ls = nearest_int(inv_scale*scales[j]);
725
- uint8_t lm = nearest_int(inv_min*mins[j]);
726
- ls = MIN(63, ls);
727
- lm = MIN(63, lm);
728
- if (j < 4) {
729
- y[i].scales[j] = ls;
730
- y[i].scales[j+4] = lm;
731
- } else {
732
- y[i].scales[j+4] = (ls & 0xF) | ((lm & 0xF) << 4);
733
- y[i].scales[j-4] |= ((ls >> 4) << 6);
734
- y[i].scales[j-0] |= ((lm >> 4) << 6);
735
- }
736
- }
737
- y[i].d = ggml_fp32_to_fp16(max_scale/63.f);
738
- y[i].dmin = ggml_fp32_to_fp16(max_min/63.f);
2970
+ __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
739
2971
 
740
- uint8_t sc, m;
741
- for (int j = 0; j < QK_K/32; ++j) {
742
- get_scale_min_k4(j, y[i].scales, &sc, &m);
743
- const float d = ggml_fp16_to_fp32(y[i].d) * sc;
744
- if (!d) continue;
745
- const float dm = ggml_fp16_to_fp32(y[i].dmin) * m;
746
- for (int ii = 0; ii < 32; ++ii) {
747
- int l = nearest_int((x[32*j + ii] + dm)/d);
748
- l = MAX(0, MIN(15, l));
749
- L[32*j + ii] = l;
750
- }
751
- }
752
- #else
753
- const float s_factor = 15.f;
754
- float inv_scale = max_scale > 0 ? s_factor/max_scale : 0.f;
755
- float inv_min = max_min > 0 ? s_factor/max_min : 0.f;
756
- int d1 = nearest_int(inv_scale*scales[0]);
757
- int m1 = nearest_int(inv_min*mins[0]);
758
- int d2 = nearest_int(inv_scale*scales[1]);
759
- int m2 = nearest_int(inv_min*mins[1]);
760
- y[i].scales[0] = d1 | (m1 << 4);
761
- y[i].scales[1] = d2 | (m2 << 4);
762
- y[i].d[0] = ggml_fp32_to_fp16(max_scale/s_factor);
763
- y[i].d[1] = ggml_fp32_to_fp16(max_min/s_factor);
2972
+ const __m256 q = mul_sum_i8_pairs_float(bx, by);
764
2973
 
765
- float sumlx = 0;
766
- int suml2 = 0;
767
- for (int j = 0; j < QK_K/32; ++j) {
768
- const uint8_t sd = y[i].scales[j] & 0xF;
769
- const uint8_t sm = y[i].scales[j] >> 4;
770
- const float d = ggml_fp16_to_fp32(y[i].d[0]) * sd;
771
- if (!d) continue;
772
- const float m = ggml_fp16_to_fp32(y[i].d[1]) * sm;
773
- for (int ii = 0; ii < 32; ++ii) {
774
- int l = nearest_int((x[32*j + ii] + m)/d);
775
- l = MAX(0, MIN(15, l));
776
- L[32*j + ii] = l;
777
- sumlx += (x[32*j + ii] + m)*l*sd;
778
- suml2 += l*l*sd*sd;
779
- }
780
- }
781
- if (suml2) {
782
- y[i].d[0] = ggml_fp32_to_fp16(sumlx/suml2);
783
- }
784
- #endif
785
- uint8_t * q = y[i].qs;
786
- for (int j = 0; j < QK_K; j += 64) {
787
- for (int l = 0; l < 32; ++l) q[l] = L[j + l] | (L[j + l + 32] << 4);
788
- q += 32;
789
- }
2974
+ /* Multiply q with scale and accumulate */
2975
+ acc = _mm256_fmadd_ps(d, q, acc);
2976
+ }
2977
+
2978
+ *s = hsum_float_8(acc);
2979
+ #elif defined(__AVX__)
2980
+ // Initialize accumulator with zeros
2981
+ __m256 acc = _mm256_setzero_ps();
2982
+ __m128i mask = _mm_set1_epi8((char)0xF0);
2983
+
2984
+ // Main loop
2985
+ for (int i = 0; i < nb; i++) {
2986
+ /* Compute combined scale for the block */
2987
+ const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d));
2988
+
2989
+ __m256i bx = bytes_from_nibbles_32(x[i].qs);
2990
+ const __m256i bxhi = bytes_from_bits_32(x[i].qh);
2991
+ __m128i bxhil = _mm256_castsi256_si128(bxhi);
2992
+ __m128i bxhih = _mm256_extractf128_si256(bxhi, 1);
2993
+ bxhil = _mm_andnot_si128(bxhil, mask);
2994
+ bxhih = _mm_andnot_si128(bxhih, mask);
2995
+ __m128i bxl = _mm256_castsi256_si128(bx);
2996
+ __m128i bxh = _mm256_extractf128_si256(bx, 1);
2997
+ bxl = _mm_or_si128(bxl, bxhil);
2998
+ bxh = _mm_or_si128(bxh, bxhih);
2999
+ bx = MM256_SET_M128I(bxh, bxl);
3000
+
3001
+ const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
3002
+
3003
+ const __m256 q = mul_sum_i8_pairs_float(bx, by);
3004
+
3005
+ /* Multiply q with scale and accumulate */
3006
+ acc = _mm256_add_ps(_mm256_mul_ps(d, q), acc);
3007
+ }
790
3008
 
791
- x += QK_K;
3009
+ *s = hsum_float_8(acc);
3010
+ #elif defined(__riscv_v_intrinsic)
3011
+ float sumf = 0.0;
792
3012
 
793
- }
794
- }
3013
+ uint32_t qh;
795
3014
 
796
- void dequantize_row_q4_K(const block_q4_K * restrict x, float * restrict y, int k) {
797
- assert(k % QK_K == 0);
798
- const int nb = k / QK_K;
3015
+ size_t vl = __riscv_vsetvl_e8m1(qk/2);
3016
+
3017
+ // These temporary registers are for masking and shift operations
3018
+ vuint32m2_t vt_1 = __riscv_vid_v_u32m2(vl);
3019
+ vuint32m2_t vt_2 = __riscv_vsll_vv_u32m2(__riscv_vmv_v_x_u32m2(1, vl), vt_1, vl);
3020
+
3021
+ vuint32m2_t vt_3 = __riscv_vsll_vx_u32m2(vt_2, 16, vl);
3022
+ vuint32m2_t vt_4 = __riscv_vadd_vx_u32m2(vt_1, 12, vl);
799
3023
 
800
3024
  for (int i = 0; i < nb; i++) {
3025
+ memcpy(&qh, x[i].qh, sizeof(uint32_t));
801
3026
 
802
- const uint8_t * q = x[i].qs;
3027
+ // ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4;
3028
+ vuint32m2_t xha_0 = __riscv_vand_vx_u32m2(vt_2, qh, vl);
3029
+ vuint32m2_t xhr_0 = __riscv_vsrl_vv_u32m2(xha_0, vt_1, vl);
3030
+ vuint32m2_t xhl_0 = __riscv_vsll_vx_u32m2(xhr_0, 4, vl);
803
3031
 
804
- #if QK_K == 256
3032
+ // ((qh & (1u << (j + 16))) >> (j + 12));
3033
+ vuint32m2_t xha_1 = __riscv_vand_vx_u32m2(vt_3, qh, vl);
3034
+ vuint32m2_t xhl_1 = __riscv_vsrl_vv_u32m2(xha_1, vt_4, vl);
805
3035
 
806
- const float d = ggml_fp16_to_fp32(x[i].d);
807
- const float min = ggml_fp16_to_fp32(x[i].dmin);
3036
+ // narrowing
3037
+ vuint16m1_t xhc_0 = __riscv_vncvt_x_x_w_u16m1(xhl_0, vl);
3038
+ vuint8mf2_t xh_0 = __riscv_vncvt_x_x_w_u8mf2(xhc_0, vl);
808
3039
 
809
- int is = 0;
810
- uint8_t sc, m;
811
- for (int j = 0; j < QK_K; j += 64) {
812
- get_scale_min_k4(is + 0, x[i].scales, &sc, &m);
813
- const float d1 = d * sc; const float m1 = min * m;
814
- get_scale_min_k4(is + 1, x[i].scales, &sc, &m);
815
- const float d2 = d * sc; const float m2 = min * m;
816
- for (int l = 0; l < 32; ++l) *y++ = d1 * (q[l] & 0xF) - m1;
817
- for (int l = 0; l < 32; ++l) *y++ = d2 * (q[l] >> 4) - m2;
818
- q += 32; is += 2;
819
- }
3040
+ vuint16m1_t xhc_1 = __riscv_vncvt_x_x_w_u16m1(xhl_1, vl);
3041
+ vuint8mf2_t xh_1 = __riscv_vncvt_x_x_w_u8mf2(xhc_1, vl);
3042
+
3043
+ // load
3044
+ vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[i].qs, vl);
3045
+
3046
+ vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[i].qs, vl);
3047
+ vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[i].qs+16, vl);
3048
+
3049
+ vuint8mf2_t x_at = __riscv_vand_vx_u8mf2(tx, 0x0F, vl);
3050
+ vuint8mf2_t x_lt = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl);
3051
+
3052
+ vuint8mf2_t x_a = __riscv_vor_vv_u8mf2(x_at, xh_0, vl);
3053
+ vuint8mf2_t x_l = __riscv_vor_vv_u8mf2(x_lt, xh_1, vl);
3054
+
3055
+ vint8mf2_t x_ai = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a);
3056
+ vint8mf2_t x_li = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l);
3057
+
3058
+ vint8mf2_t v0 = __riscv_vsub_vx_i8mf2(x_ai, 16, vl);
3059
+ vint8mf2_t v1 = __riscv_vsub_vx_i8mf2(x_li, 16, vl);
3060
+
3061
+ vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl);
3062
+ vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl);
3063
+
3064
+ vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
3065
+
3066
+ vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl);
3067
+ vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl);
3068
+
3069
+ int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);
3070
+
3071
+ sumf += (GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d)) * sumi;
3072
+ }
3073
+
3074
+ *s = sumf;
820
3075
  #else
821
- const float dall = ggml_fp16_to_fp32(x[i].d[0]);
822
- const float mall = ggml_fp16_to_fp32(x[i].d[1]);
823
- const float d1 = dall * (x[i].scales[0] & 0xF), m1 = mall * (x[i].scales[0] >> 4);
824
- const float d2 = dall * (x[i].scales[1] & 0xF), m2 = mall * (x[i].scales[1] >> 4);
825
- for (int l = 0; l < 32; ++l) {
826
- y[l+ 0] = d1 * (q[l] & 0xF) - m1;
827
- y[l+32] = d2 * (q[l] >> 4) - m2;
3076
+ // scalar
3077
+ float sumf = 0.0;
3078
+
3079
+ for (int i = 0; i < nb; i++) {
3080
+ uint32_t qh;
3081
+ memcpy(&qh, x[i].qh, sizeof(qh));
3082
+
3083
+ int sumi = 0;
3084
+
3085
+ for (int j = 0; j < qk/2; ++j) {
3086
+ const uint8_t xh_0 = ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4;
3087
+ const uint8_t xh_1 = ((qh & (1u << (j + 16))) >> (j + 12));
3088
+
3089
+ const int32_t x0 = ((x[i].qs[j] & 0x0F) | xh_0) - 16;
3090
+ const int32_t x1 = ((x[i].qs[j] >> 4) | xh_1) - 16;
3091
+
3092
+ sumi += (x0 * y[i].qs[j]) + (x1 * y[i].qs[j + qk/2]);
828
3093
  }
829
- y += QK_K;
830
- #endif
831
3094
 
3095
+ sumf += (GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d)) * sumi;
832
3096
  }
833
- }
834
3097
 
835
- void quantize_row_q4_K(const float * restrict x, void * restrict vy, int k) {
836
- assert(k % QK_K == 0);
837
- block_q4_K * restrict y = vy;
838
- quantize_row_q4_K_reference(x, y, k);
3098
+ *s = sumf;
3099
+ #endif
839
3100
  }
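Put side by side with the SIMD branches, the scalar Q5_0 path shows what all of them compute per block: the low four bits come from the packed nibbles, the fifth bit comes from the 32-bit qh field, and the result is recentred by 16. The table_b2b_1 lookup plus subtraction in the NEON and WASM branches produces the same values, just 16 lanes at a time. A minimal scalar sketch of one Q5_0 x Q8_0 block pair, assuming QK5_0 == QK8_0 == 32 (the helper name is illustrative):

#include <stdint.h>
#include <string.h>

static float dot_q5_0_block_ref(float dx, const uint8_t *qs, const uint8_t *qh_bytes,
                                float dy, const int8_t *q8) {
    uint32_t qh;
    memcpy(&qh, qh_bytes, sizeof(qh));                  // 32 "fifth" bits, one per weight
    int sumi = 0;
    for (int j = 0; j < 16; ++j) {
        const int xh_0 = ((qh >> (j +  0)) & 1) << 4;   // high bit of weight j
        const int xh_1 = ((qh >> (j + 16)) & 1) << 4;   // high bit of weight j + 16
        const int x0 = ((qs[j] & 0x0F) | xh_0) - 16;    // 5-bit value recentred to [-16, 15]
        const int x1 = ((qs[j] >>   4) | xh_1) - 16;
        sumi += x0*q8[j] + x1*q8[j + 16];
    }
    return dx * dy * (float)sumi;
}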
840
3101
 
841
- size_t ggml_quantize_q4_K(const float * restrict src, void * restrict dst, int n, int k, int64_t * restrict hist) {
842
- assert(k % QK_K == 0);
843
- (void)hist; // TODO: collect histograms
3102
+ void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
3103
+ const int qk = QK8_1;
3104
+ const int nb = n / qk;
844
3105
 
845
- for (int j = 0; j < n; j += k) {
846
- block_q4_K * restrict y = (block_q4_K *)dst + j/QK_K;
847
- quantize_row_q4_K_reference(src + j, y, k);
848
- }
849
- return (n/QK_K*sizeof(block_q4_K));
850
- }
3106
+ assert(n % qk == 0);
3107
+ assert(qk == QK5_1);
851
3108
 
852
- // ====================== 5-bit (de)-quantization
3109
+ const block_q5_1 * restrict x = vx;
3110
+ const block_q8_1 * restrict y = vy;
853
3111
 
854
- void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict y, int k) {
855
- assert(k % QK_K == 0);
856
- const int nb = k / QK_K;
3112
+ #if defined(__ARM_NEON)
3113
+ float32x4_t sumv0 = vdupq_n_f32(0.0f);
3114
+ float32x4_t sumv1 = vdupq_n_f32(0.0f);
857
3115
 
858
- #if QK_K == 256
859
- uint8_t L[QK_K];
860
- float mins[QK_K/32];
861
- float scales[QK_K/32];
862
- float weights[32];
863
- uint8_t Laux[32];
3116
+ float summs0 = 0.0f;
3117
+ float summs1 = 0.0f;
3118
+
3119
+ uint32_t qh0;
3120
+ uint32_t qh1;
3121
+
3122
+ uint64_t tmp0[4];
3123
+ uint64_t tmp1[4];
3124
+
3125
+ assert(nb % 2 == 0); // TODO: handle odd nb
3126
+
3127
+ for (int i = 0; i < nb; i += 2) {
3128
+ const block_q5_1 * restrict x0 = &x[i];
3129
+ const block_q5_1 * restrict x1 = &x[i + 1];
3130
+ const block_q8_1 * restrict y0 = &y[i];
3131
+ const block_q8_1 * restrict y1 = &y[i + 1];
3132
+
3133
+ const uint8x16_t m4b = vdupq_n_u8(0x0F);
3134
+
3135
+ summs0 += GGML_FP16_TO_FP32(x0->m) * y0->s;
3136
+ summs1 += GGML_FP16_TO_FP32(x1->m) * y1->s;
3137
+
3138
+ // extract the 5th bit via lookup table ((b) << 4)
3139
+ memcpy(&qh0, x0->qh, sizeof(qh0));
3140
+ memcpy(&qh1, x1->qh, sizeof(qh1));
3141
+
3142
+ tmp0[0] = table_b2b_0[(qh0 >> 0) & 0xFF];
3143
+ tmp0[1] = table_b2b_0[(qh0 >> 8) & 0xFF];
3144
+ tmp0[2] = table_b2b_0[(qh0 >> 16) & 0xFF];
3145
+ tmp0[3] = table_b2b_0[(qh0 >> 24) ];
3146
+
3147
+ tmp1[0] = table_b2b_0[(qh1 >> 0) & 0xFF];
3148
+ tmp1[1] = table_b2b_0[(qh1 >> 8) & 0xFF];
3149
+ tmp1[2] = table_b2b_0[(qh1 >> 16) & 0xFF];
3150
+ tmp1[3] = table_b2b_0[(qh1 >> 24) ];
3151
+
3152
+ const int8x16_t qhl0 = vld1q_s8((const int8_t *)(tmp0 + 0));
3153
+ const int8x16_t qhh0 = vld1q_s8((const int8_t *)(tmp0 + 2));
3154
+ const int8x16_t qhl1 = vld1q_s8((const int8_t *)(tmp1 + 0));
3155
+ const int8x16_t qhh1 = vld1q_s8((const int8_t *)(tmp1 + 2));
3156
+
3157
+ const uint8x16_t v0_0 = vld1q_u8(x0->qs);
3158
+ const uint8x16_t v0_1 = vld1q_u8(x1->qs);
3159
+
3160
+ // 4-bit -> 8-bit
3161
+ const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
3162
+ const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
3163
+ const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
3164
+ const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
3165
+
3166
+ // add high bit
3167
+ const int8x16_t v0_0lf = vorrq_s8(v0_0l, qhl0);
3168
+ const int8x16_t v0_0hf = vorrq_s8(v0_0h, qhh0);
3169
+ const int8x16_t v0_1lf = vorrq_s8(v0_1l, qhl1);
3170
+ const int8x16_t v0_1hf = vorrq_s8(v0_1h, qhh1);
3171
+
3172
+ // load y
3173
+ const int8x16_t v1_0l = vld1q_s8(y0->qs);
3174
+ const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
3175
+ const int8x16_t v1_1l = vld1q_s8(y1->qs);
3176
+ const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
3177
+
3178
+ #if defined(__ARM_FEATURE_DOTPROD)
3179
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
3180
+ vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l),
3181
+ vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), GGML_FP16_TO_FP32(x0->d)*y0->d);
3182
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
3183
+ vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l),
3184
+ vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), GGML_FP16_TO_FP32(x1->d)*y1->d);
864
3185
  #else
865
- int8_t L[QK_K];
866
- float scales[QK_K/16];
3186
+ const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lf), vget_low_s8 (v1_0l));
3187
+ const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lf), vget_high_s8(v1_0l));
3188
+ const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hf), vget_low_s8 (v1_0h));
3189
+ const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hf), vget_high_s8(v1_0h));
3190
+
3191
+ const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lf), vget_low_s8 (v1_1l));
3192
+ const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lf), vget_high_s8(v1_1l));
3193
+ const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hf), vget_low_s8 (v1_1h));
3194
+ const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hf), vget_high_s8(v1_1h));
3195
+
3196
+ const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
3197
+ const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
3198
+ const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
3199
+ const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
3200
+
3201
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), GGML_FP16_TO_FP32(x0->d)*y0->d);
3202
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), GGML_FP16_TO_FP32(x1->d)*y1->d);
867
3203
  #endif
3204
+ }
868
3205
 
869
- for (int i = 0; i < nb; i++) {
3206
+ *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs0 + summs1;
3207
+ #elif defined(__wasm_simd128__)
3208
+ v128_t sumv = wasm_f32x4_splat(0.0f);
870
3209
 
871
- #if QK_K == 256
3210
+ float summs = 0.0f;
872
3211
 
873
- float max_scale = 0; // as we are deducting the min, scales are always positive
874
- float max_min = 0;
875
- for (int j = 0; j < QK_K/32; ++j) {
876
- //scales[j] = make_qkx1_quants(32, 31, x + 32*j, L + 32*j, &mins[j], 9, 0.5f);
877
- float sum_x2 = 0;
878
- for (int l = 0; l < 32; ++l) sum_x2 += x[32*j + l] * x[32*j + l];
879
- float av_x = sqrtf(sum_x2/32);
880
- for (int l = 0; l < 32; ++l) weights[l] = av_x + fabsf(x[32*j + l]);
881
- scales[j] = make_qkx2_quants(32, 31, x + 32*j, weights, L + 32*j, &mins[j], Laux, -0.5f, 0.1f, 15, false);
882
- float scale = scales[j];
883
- if (scale > max_scale) {
884
- max_scale = scale;
885
- }
886
- float min = mins[j];
887
- if (min > max_min) {
888
- max_min = min;
889
- }
890
- }
3212
+ uint32_t qh;
3213
+ uint64_t tmp[4];
891
3214
 
892
- float inv_scale = max_scale > 0 ? 63.f/max_scale : 0.f;
893
- float inv_min = max_min > 0 ? 63.f/max_min : 0.f;
894
- for (int j = 0; j < QK_K/32; ++j) {
895
- uint8_t ls = nearest_int(inv_scale*scales[j]);
896
- uint8_t lm = nearest_int(inv_min*mins[j]);
897
- ls = MIN(63, ls);
898
- lm = MIN(63, lm);
899
- if (j < 4) {
900
- y[i].scales[j] = ls;
901
- y[i].scales[j+4] = lm;
902
- } else {
903
- y[i].scales[j+4] = (ls & 0xF) | ((lm & 0xF) << 4);
904
- y[i].scales[j-4] |= ((ls >> 4) << 6);
905
- y[i].scales[j-0] |= ((lm >> 4) << 6);
906
- }
907
- }
908
- y[i].d = ggml_fp32_to_fp16(max_scale/63.f);
909
- y[i].dmin = ggml_fp32_to_fp16(max_min/63.f);
3215
+ // TODO: check if unrolling this is better
3216
+ for (int i = 0; i < nb; ++i) {
3217
+ const block_q5_1 * restrict x0 = &x[i];
3218
+ const block_q8_1 * restrict y0 = &y[i];
910
3219
 
911
- uint8_t sc, m;
912
- for (int j = 0; j < QK_K/32; ++j) {
913
- get_scale_min_k4(j, y[i].scales, &sc, &m);
914
- const float d = ggml_fp16_to_fp32(y[i].d) * sc;
915
- if (!d) continue;
916
- const float dm = ggml_fp16_to_fp32(y[i].dmin) * m;
917
- for (int ii = 0; ii < 32; ++ii) {
918
- int l = nearest_int((x[32*j + ii] + dm)/d);
919
- l = MAX(0, MIN(31, l));
920
- L[32*j + ii] = l;
921
- }
922
- }
3220
+ summs += GGML_FP16_TO_FP32(x0->m) * y0->s;
923
3221
 
924
- uint8_t * restrict qh = y[i].qh;
925
- uint8_t * restrict ql = y[i].qs;
926
- memset(qh, 0, QK_K/8);
3222
+ const v128_t m4b = wasm_i8x16_splat(0x0F);
927
3223
 
928
- uint8_t m1 = 1, m2 = 2;
929
- for (int n = 0; n < QK_K; n += 64) {
930
- for (int j = 0; j < 32; ++j) {
931
- int l1 = L[n + j];
932
- if (l1 > 15) {
933
- l1 -= 16; qh[j] |= m1;
934
- }
935
- int l2 = L[n + j + 32];
936
- if (l2 > 15) {
937
- l2 -= 16; qh[j] |= m2;
938
- }
939
- ql[j] = l1 | (l2 << 4);
940
- }
941
- m1 <<= 2; m2 <<= 2;
942
- ql += 32;
943
- }
944
- #else
945
- float max_scale = 0, amax = 0;
946
- for (int j = 0; j < QK_K/16; ++j) {
947
- scales[j] = make_qx_quants(16, 16, x + 16*j, L + 16*j, 1);
948
- float abs_scale = fabsf(scales[j]);
949
- if (abs_scale > amax) {
950
- amax = abs_scale;
951
- max_scale = scales[j];
952
- }
953
- }
3224
+ // extract the 5th bit
3225
+ memcpy(&qh, x0->qh, sizeof(qh));
954
3226
 
955
- float iscale = -128.f/max_scale;
956
- for (int j = 0; j < QK_K/16; ++j) {
957
- int l = nearest_int(iscale*scales[j]);
958
- y[i].scales[j] = MAX(-128, MIN(127, l));
959
- }
960
- y[i].d = ggml_fp32_to_fp16(1/iscale);
3227
+ tmp[0] = table_b2b_0[(qh >> 0) & 0xFF];
3228
+ tmp[1] = table_b2b_0[(qh >> 8) & 0xFF];
3229
+ tmp[2] = table_b2b_0[(qh >> 16) & 0xFF];
3230
+ tmp[3] = table_b2b_0[(qh >> 24) ];
961
3231
 
962
- for (int j = 0; j < QK_K/16; ++j) {
963
- const float d = ggml_fp16_to_fp32(y[i].d) * y[i].scales[j];
964
- if (!d) continue;
965
- for (int ii = 0; ii < 16; ++ii) {
966
- int l = nearest_int(x[16*j + ii]/d);
967
- l = MAX(-16, MIN(15, l));
968
- L[16*j + ii] = l + 16;
969
- }
970
- }
3232
+ const v128_t qhl = wasm_v128_load(tmp + 0);
3233
+ const v128_t qhh = wasm_v128_load(tmp + 2);
971
3234
 
972
- uint8_t * restrict qh = y[i].qh;
973
- uint8_t * restrict ql = y[i].qs;
974
- memset(qh, 0, QK_K/8);
3235
+ const v128_t v0 = wasm_v128_load(x0->qs);
975
3236
 
976
- for (int j = 0; j < 32; ++j) {
977
- int jm = j%8;
978
- int is = j/8;
979
- int l1 = L[j];
980
- if (l1 > 15) {
981
- l1 -= 16; qh[jm] |= (1 << is);
982
- }
983
- int l2 = L[j + 32];
984
- if (l2 > 15) {
985
- l2 -= 16; qh[jm] |= (1 << (4 + is));
986
- }
987
- ql[j] = l1 | (l2 << 4);
988
- }
989
- #endif
3237
+ // 4-bit -> 8-bit
3238
+ const v128_t v0l = wasm_v128_and (v0, m4b);
3239
+ const v128_t v0h = wasm_u8x16_shr(v0, 4);
990
3240
 
991
- x += QK_K;
3241
+ // add high bit
3242
+ const v128_t v0lf = wasm_v128_or(v0l, qhl);
3243
+ const v128_t v0hf = wasm_v128_or(v0h, qhh);
3244
+
3245
+ // load y
3246
+ const v128_t v1l = wasm_v128_load(y0->qs);
3247
+ const v128_t v1h = wasm_v128_load(y0->qs + 16);
3248
+
3249
+ // int8x16 -> int16x8
3250
+ const v128_t v0lfl = wasm_i16x8_extend_low_i8x16 (v0lf);
3251
+ const v128_t v0lfh = wasm_i16x8_extend_high_i8x16(v0lf);
3252
+ const v128_t v0hfl = wasm_i16x8_extend_low_i8x16 (v0hf);
3253
+ const v128_t v0hfh = wasm_i16x8_extend_high_i8x16(v0hf);
992
3254
 
3255
+ const v128_t v1ll = wasm_i16x8_extend_low_i8x16 (v1l);
3256
+ const v128_t v1lh = wasm_i16x8_extend_high_i8x16(v1l);
3257
+ const v128_t v1hl = wasm_i16x8_extend_low_i8x16 (v1h);
3258
+ const v128_t v1hh = wasm_i16x8_extend_high_i8x16(v1h);
3259
+
3260
+ // dot product
3261
+ sumv = wasm_f32x4_add(sumv,
3262
+ wasm_f32x4_mul(wasm_f32x4_convert_i32x4(wasm_i32x4_add(
3263
+ wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0lfl, v1ll),
3264
+ wasm_i32x4_dot_i16x8(v0lfh, v1lh)),
3265
+ wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0hfl, v1hl),
3266
+ wasm_i32x4_dot_i16x8(v0hfh, v1hh)))),
3267
+ wasm_f32x4_splat(GGML_FP16_TO_FP32(x0->d) * y0->d)));
993
3268
  }
994
- }
995
3269
 
996
- void dequantize_row_q5_K(const block_q5_K * restrict x, float * restrict y, int k) {
997
- assert(k % QK_K == 0);
998
- const int nb = k / QK_K;
3270
+ *s = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) +
3271
+ wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3) + summs;
3272
+ #elif defined(__AVX2__)
3273
+ // Initialize accumulator with zeros
3274
+ __m256 acc = _mm256_setzero_ps();
999
3275
 
3276
+ float summs = 0.0f;
3277
+
3278
+ // Main loop
1000
3279
  for (int i = 0; i < nb; i++) {
3280
+ const __m256 dx = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d));
1001
3281
 
1002
- const uint8_t * ql = x[i].qs;
1003
- const uint8_t * qh = x[i].qh;
3282
+ summs += GGML_FP16_TO_FP32(x[i].m) * y[i].s;
1004
3283
 
1005
- #if QK_K == 256
3284
+ __m256i bx = bytes_from_nibbles_32(x[i].qs);
3285
+ __m256i bxhi = bytes_from_bits_32(x[i].qh);
3286
+ bxhi = _mm256_and_si256(bxhi, _mm256_set1_epi8(0x10));
3287
+ bx = _mm256_or_si256(bx, bxhi);
1006
3288
 
1007
- const float d = ggml_fp16_to_fp32(x[i].d);
1008
- const float min = ggml_fp16_to_fp32(x[i].dmin);
3289
+ const __m256 dy = _mm256_set1_ps(y[i].d);
3290
+ const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
1009
3291
 
1010
- int is = 0;
1011
- uint8_t sc, m;
1012
- uint8_t u1 = 1, u2 = 2;
1013
- for (int j = 0; j < QK_K; j += 64) {
1014
- get_scale_min_k4(is + 0, x[i].scales, &sc, &m);
1015
- const float d1 = d * sc; const float m1 = min * m;
1016
- get_scale_min_k4(is + 1, x[i].scales, &sc, &m);
1017
- const float d2 = d * sc; const float m2 = min * m;
1018
- for (int l = 0; l < 32; ++l) *y++ = d1 * ((ql[l] & 0xF) + (qh[l] & u1 ? 16 : 0)) - m1;
1019
- for (int l = 0; l < 32; ++l) *y++ = d2 * ((ql[l] >> 4) + (qh[l] & u2 ? 16 : 0)) - m2;
1020
- ql += 32; is += 2;
1021
- u1 <<= 2; u2 <<= 2;
1022
- }
1023
- #else
1024
- float d = ggml_fp16_to_fp32(x[i].d);
1025
- const int8_t * restrict s = x[i].scales;
1026
- for (int l = 0; l < 8; ++l) {
1027
- y[l+ 0] = d * s[0] * ((ql[l+ 0] & 0xF) - (qh[l] & 0x01 ? 0 : 16));
1028
- y[l+ 8] = d * s[0] * ((ql[l+ 8] & 0xF) - (qh[l] & 0x02 ? 0 : 16));
1029
- y[l+16] = d * s[1] * ((ql[l+16] & 0xF) - (qh[l] & 0x04 ? 0 : 16));
1030
- y[l+24] = d * s[1] * ((ql[l+24] & 0xF) - (qh[l] & 0x08 ? 0 : 16));
1031
- y[l+32] = d * s[2] * ((ql[l+ 0] >> 4) - (qh[l] & 0x10 ? 0 : 16));
1032
- y[l+40] = d * s[2] * ((ql[l+ 8] >> 4) - (qh[l] & 0x20 ? 0 : 16));
1033
- y[l+48] = d * s[3] * ((ql[l+16] >> 4) - (qh[l] & 0x40 ? 0 : 16));
1034
- y[l+56] = d * s[3] * ((ql[l+24] >> 4) - (qh[l] & 0x80 ? 0 : 16));
1035
- }
1036
- y += QK_K;
1037
- #endif
3292
+ const __m256 q = mul_sum_us8_pairs_float(bx, by);
3293
+
3294
+ acc = _mm256_fmadd_ps(q, _mm256_mul_ps(dx, dy), acc);
1038
3295
  }
1039
- }
1040
3296
 
1041
- void quantize_row_q5_K(const float * restrict x, void * restrict vy, int k) {
1042
- assert(k % QK_K == 0);
1043
- block_q5_K * restrict y = vy;
1044
- quantize_row_q5_K_reference(x, y, k);
1045
- }
3297
+ *s = hsum_float_8(acc) + summs;
3298
+ #elif defined(__AVX__)
3299
+ // Initialize accumulator with zeros
3300
+ __m256 acc = _mm256_setzero_ps();
3301
+ __m128i mask = _mm_set1_epi8(0x10);
3302
+
3303
+ float summs = 0.0f;
3304
+
3305
+ // Main loop
3306
+ for (int i = 0; i < nb; i++) {
3307
+ const __m256 dx = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d));
3308
+
3309
+ summs += GGML_FP16_TO_FP32(x[i].m) * y[i].s;
3310
+
3311
+ __m256i bx = bytes_from_nibbles_32(x[i].qs);
3312
+ const __m256i bxhi = bytes_from_bits_32(x[i].qh);
3313
+ __m128i bxhil = _mm256_castsi256_si128(bxhi);
3314
+ __m128i bxhih = _mm256_extractf128_si256(bxhi, 1);
3315
+ bxhil = _mm_and_si128(bxhil, mask);
3316
+ bxhih = _mm_and_si128(bxhih, mask);
3317
+ __m128i bxl = _mm256_castsi256_si128(bx);
3318
+ __m128i bxh = _mm256_extractf128_si256(bx, 1);
3319
+ bxl = _mm_or_si128(bxl, bxhil);
3320
+ bxh = _mm_or_si128(bxh, bxhih);
3321
+ bx = MM256_SET_M128I(bxh, bxl);
3322
+
3323
+ const __m256 dy = _mm256_set1_ps(y[i].d);
3324
+ const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
1046
3325
 
1047
- size_t ggml_quantize_q5_K(const float * restrict src, void * restrict dst, int n, int k, int64_t * restrict hist) {
1048
- assert(k % QK_K == 0);
1049
- (void)hist; // TODO: collect histograms
3326
+ const __m256 q = mul_sum_us8_pairs_float(bx, by);
1050
3327
 
1051
- for (int j = 0; j < n; j += k) {
1052
- block_q5_K * restrict y = (block_q5_K *)dst + j/QK_K;
1053
- quantize_row_q5_K_reference(src + j, y, k);
3328
+ acc = _mm256_add_ps(_mm256_mul_ps(q, _mm256_mul_ps(dx, dy)), acc);
1054
3329
  }
1055
- return (n/QK_K*sizeof(block_q5_K));
1056
- }
1057
3330
 
1058
- // ====================== 6-bit (de)-quantization
3331
+ *s = hsum_float_8(acc) + summs;
3332
+ #elif defined(__riscv_v_intrinsic)
3333
+ float sumf = 0.0;
1059
3334
 
1060
- void quantize_row_q6_K_reference(const float * restrict x, block_q6_K * restrict y, int k) {
1061
- assert(k % QK_K == 0);
1062
- const int nb = k / QK_K;
3335
+ uint32_t qh;
1063
3336
 
1064
- int8_t L[QK_K];
1065
- float scales[QK_K/16];
3337
+ size_t vl = __riscv_vsetvl_e8m1(qk/2);
3338
+
3339
+ // temporary registers for shift operations
3340
+ vuint32m2_t vt_1 = __riscv_vid_v_u32m2(vl);
3341
+ vuint32m2_t vt_2 = __riscv_vadd_vx_u32m2(vt_1, 12, vl);
1066
3342
 
1067
3343
  for (int i = 0; i < nb; i++) {
3344
+ memcpy(&qh, x[i].qh, sizeof(uint32_t));
1068
3345
 
1069
- float max_scale = 0;
1070
- float max_abs_scale = 0;
3346
+ // load qh
3347
+ vuint32m2_t vqh = __riscv_vmv_v_x_u32m2(qh, vl);
1071
3348
 
1072
- for (int ib = 0; ib < QK_K/16; ++ib) {
3349
+ // ((qh >> (j + 0)) << 4) & 0x10;
3350
+ vuint32m2_t xhr_0 = __riscv_vsrl_vv_u32m2(vqh, vt_1, vl);
3351
+ vuint32m2_t xhl_0 = __riscv_vsll_vx_u32m2(xhr_0, 4, vl);
3352
+ vuint32m2_t xha_0 = __riscv_vand_vx_u32m2(xhl_0, 0x10, vl);
1073
3353
 
1074
- const float scale = make_qx_quants(16, 32, x + 16*ib, L + 16*ib, 1);
1075
- scales[ib] = scale;
3354
+ // ((qh >> (j + 12)) ) & 0x10;
3355
+ vuint32m2_t xhr_1 = __riscv_vsrl_vv_u32m2(vqh, vt_2, vl);
3356
+ vuint32m2_t xha_1 = __riscv_vand_vx_u32m2(xhr_1, 0x10, vl);
1076
3357
 
1077
- const float abs_scale = fabsf(scale);
1078
- if (abs_scale > max_abs_scale) {
1079
- max_abs_scale = abs_scale;
1080
- max_scale = scale;
1081
- }
3358
+ // narrowing
3359
+ vuint16m1_t xhc_0 = __riscv_vncvt_x_x_w_u16m1(xha_0, vl);
3360
+ vuint8mf2_t xh_0 = __riscv_vncvt_x_x_w_u8mf2(xhc_0, vl);
1082
3361
 
1083
- }
3362
+ vuint16m1_t xhc_1 = __riscv_vncvt_x_x_w_u16m1(xha_1, vl);
3363
+ vuint8mf2_t xh_1 = __riscv_vncvt_x_x_w_u8mf2(xhc_1, vl);
1084
3364
 
1085
- if (!max_abs_scale) {
1086
- memset(&y[i], 0, sizeof(block_q6_K));
1087
- y[i].d = ggml_fp32_to_fp16(0.f);
1088
- x += QK_K;
1089
- continue;
1090
- }
3365
+ // load
3366
+ vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[i].qs, vl);
1091
3367
 
1092
- float iscale = -128.f/max_scale;
1093
- y[i].d = ggml_fp32_to_fp16(1/iscale);
1094
- for (int ib = 0; ib < QK_K/16; ++ib) {
1095
- y[i].scales[ib] = MIN(127, nearest_int(iscale*scales[ib]));
1096
- }
3368
+ vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[i].qs, vl);
3369
+ vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[i].qs+16, vl);
1097
3370
 
1098
- for (int j = 0; j < QK_K/16; ++j) {
1099
- float d = ggml_fp16_to_fp32(y[i].d) * y[i].scales[j];
1100
- if (!d) {
1101
- continue;
1102
- }
1103
- for (int ii = 0; ii < 16; ++ii) {
1104
- int l = nearest_int(x[16*j + ii]/d);
1105
- l = MAX(-32, MIN(31, l));
1106
- L[16*j + ii] = l + 32;
1107
- }
1108
- }
3371
+ vuint8mf2_t x_at = __riscv_vand_vx_u8mf2(tx, 0x0F, vl);
3372
+ vuint8mf2_t x_lt = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl);
1109
3373
 
1110
- uint8_t * restrict ql = y[i].ql;
1111
- uint8_t * restrict qh = y[i].qh;
1112
- #if QK_K == 256
1113
- for (int j = 0; j < QK_K; j += 128) {
1114
- for (int l = 0; l < 32; ++l) {
1115
- const uint8_t q1 = L[j + l + 0] & 0xF;
1116
- const uint8_t q2 = L[j + l + 32] & 0xF;
1117
- const uint8_t q3 = L[j + l + 64] & 0xF;
1118
- const uint8_t q4 = L[j + l + 96] & 0xF;
1119
- ql[l+ 0] = q1 | (q3 << 4);
1120
- ql[l+32] = q2 | (q4 << 4);
1121
- qh[l] = (L[j + l] >> 4) | ((L[j + l + 32] >> 4) << 2) | ((L[j + l + 64] >> 4) << 4) | ((L[j + l + 96] >> 4) << 6);
1122
- }
1123
- ql += 64;
1124
- qh += 32;
1125
- }
1126
- #else
1127
- for (int l = 0; l < 32; ++l) {
1128
- const uint8_t q1 = L[l + 0] & 0xF;
1129
- const uint8_t q2 = L[l + 32] & 0xF;
1130
- ql[l] = q1 | (q2 << 4);
1131
- }
1132
- for (int l = 0; l < 16; ++l) {
1133
- qh[l] = (L[l] >> 4) | ((L[l + 16] >> 4) << 2) | ((L[l + 32] >> 4) << 4) | ((L[l + 48] >> 4) << 6);
1134
- }
1135
- #endif
3374
+ vuint8mf2_t x_a = __riscv_vor_vv_u8mf2(x_at, xh_0, vl);
3375
+ vuint8mf2_t x_l = __riscv_vor_vv_u8mf2(x_lt, xh_1, vl);
1136
3376
 
1137
- x += QK_K;
3377
+ vint8mf2_t v0 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a);
3378
+ vint8mf2_t v1 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l);
3379
+
3380
+ vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl);
3381
+ vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl);
3382
+
3383
+ vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
1138
3384
 
3385
+ vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl);
3386
+ vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl);
3387
+
3388
+ int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);
3389
+
3390
+ sumf += (GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi + GGML_FP16_TO_FP32(x[i].m)*y[i].s;
1139
3391
  }
1140
- }
1141
3392
 
1142
- void dequantize_row_q6_K(const block_q6_K * restrict x, float * restrict y, int k) {
1143
- assert(k % QK_K == 0);
1144
- const int nb = k / QK_K;
3393
+ *s = sumf;
3394
+ #else
3395
+ // scalar
3396
+ float sumf = 0.0;
1145
3397
 
1146
3398
  for (int i = 0; i < nb; i++) {
3399
+ uint32_t qh;
3400
+ memcpy(&qh, x[i].qh, sizeof(qh));
1147
3401
 
1148
- const float d = ggml_fp16_to_fp32(x[i].d);
3402
+ int sumi = 0;
1149
3403
 
1150
- const uint8_t * restrict ql = x[i].ql;
1151
- const uint8_t * restrict qh = x[i].qh;
1152
- const int8_t * restrict sc = x[i].scales;
3404
+ for (int j = 0; j < qk/2; ++j) {
3405
+ const uint8_t xh_0 = ((qh >> (j + 0)) << 4) & 0x10;
3406
+ const uint8_t xh_1 = ((qh >> (j + 12)) ) & 0x10;
1153
3407
 
1154
- #if QK_K == 256
1155
- for (int n = 0; n < QK_K; n += 128) {
1156
- for (int l = 0; l < 32; ++l) {
1157
- int is = l/16;
1158
- const int8_t q1 = (int8_t)((ql[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
1159
- const int8_t q2 = (int8_t)((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
1160
- const int8_t q3 = (int8_t)((ql[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
1161
- const int8_t q4 = (int8_t)((ql[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
1162
- y[l + 0] = d * sc[is + 0] * q1;
1163
- y[l + 32] = d * sc[is + 2] * q2;
1164
- y[l + 64] = d * sc[is + 4] * q3;
1165
- y[l + 96] = d * sc[is + 6] * q4;
1166
- }
1167
- y += 128;
1168
- ql += 64;
1169
- qh += 32;
1170
- sc += 8;
1171
- }
1172
- #else
1173
- for (int l = 0; l < 16; ++l) {
1174
- const int8_t q1 = (int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
1175
- const int8_t q2 = (int8_t)((ql[l+16] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
1176
- const int8_t q3 = (int8_t)((ql[l+ 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
1177
- const int8_t q4 = (int8_t)((ql[l+16] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
1178
- y[l+ 0] = d * sc[0] * q1;
1179
- y[l+16] = d * sc[1] * q2;
1180
- y[l+32] = d * sc[2] * q3;
1181
- y[l+48] = d * sc[3] * q4;
3408
+ const int32_t x0 = (x[i].qs[j] & 0xF) | xh_0;
3409
+ const int32_t x1 = (x[i].qs[j] >> 4) | xh_1;
3410
+
3411
+ sumi += (x0 * y[i].qs[j]) + (x1 * y[i].qs[j + qk/2]);
1182
3412
  }
1183
- y += 64;
1184
- #endif
1185
3413
 
3414
+ sumf += (GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi + GGML_FP16_TO_FP32(x[i].m)*y[i].s;
1186
3415
  }
1187
- }
1188
3416
 
1189
- void quantize_row_q6_K(const float * restrict x, void * restrict vy, int k) {
1190
- assert(k % QK_K == 0);
1191
- block_q6_K * restrict y = vy;
1192
- quantize_row_q6_K_reference(x, y, k);
3417
+ *s = sumf;
3418
+ #endif
1193
3419
  }
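Q5_1 differs from Q5_0 only in how the reconstructed 5-bit value is used: it stays in [0, 31] and the block minimum m is folded in through the precomputed Q8_1 s term, exactly as in the Q4_1 kernel above. A minimal scalar sketch of one Q5_1 x Q8_1 block pair, under the same assumptions as the earlier sketches (the helper name is illustrative):

#include <stdint.h>
#include <string.h>

static float dot_q5_1_block_ref(float dx, float mx, const uint8_t *qs, const uint8_t *qh_bytes,
                                float dy, float ys, const int8_t *q8) {
    uint32_t qh;
    memcpy(&qh, qh_bytes, sizeof(qh));
    int sumi = 0;
    for (int j = 0; j < 16; ++j) {
        const int x0 = (qs[j] & 0x0F) | (((qh >> (j +  0)) & 1) << 4);  // 5-bit value in [0, 31]
        const int x1 = (qs[j] >>   4) | (((qh >> (j + 16)) & 1) << 4);
        sumi += x0*q8[j] + x1*q8[j + 16];
    }
    return dx*dy*sumi + mx*ys;   // minimum folded in via ys = dy * sum(q8), as in Q4_1
}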
1194
3420
 
1195
- size_t ggml_quantize_q6_K(const float * src, void * dst, int n, int k, int64_t * hist) {
1196
- assert(k % QK_K == 0);
1197
- (void)hist; // TODO: collect histograms
3421
+ void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
3422
+ const int qk = QK8_0;
3423
+ const int nb = n / qk;
1198
3424
 
1199
- for (int j = 0; j < n; j += k) {
1200
- block_q6_K * restrict y = (block_q6_K *)dst + j/QK_K;
1201
- quantize_row_q6_K_reference(src + j, y, k);
3425
+ assert(n % qk == 0);
3426
+
3427
+ const block_q8_0 * restrict x = vx;
3428
+ const block_q8_0 * restrict y = vy;
3429
+
3430
+ #if defined(__ARM_NEON)
3431
+ float32x4_t sumv0 = vdupq_n_f32(0.0f);
3432
+ float32x4_t sumv1 = vdupq_n_f32(0.0f);
3433
+
3434
+ assert(nb % 2 == 0); // TODO: handle odd nb
3435
+
3436
+ for (int i = 0; i < nb; i += 2) {
3437
+ const block_q8_0 * restrict x0 = &x[i + 0];
3438
+ const block_q8_0 * restrict x1 = &x[i + 1];
3439
+ const block_q8_0 * restrict y0 = &y[i + 0];
3440
+ const block_q8_0 * restrict y1 = &y[i + 1];
3441
+
3442
+ const int8x16_t x0_0 = vld1q_s8(x0->qs);
3443
+ const int8x16_t x0_1 = vld1q_s8(x0->qs + 16);
3444
+ const int8x16_t x1_0 = vld1q_s8(x1->qs);
3445
+ const int8x16_t x1_1 = vld1q_s8(x1->qs + 16);
3446
+
3447
+ // load y
3448
+ const int8x16_t y0_0 = vld1q_s8(y0->qs);
3449
+ const int8x16_t y0_1 = vld1q_s8(y0->qs + 16);
3450
+ const int8x16_t y1_0 = vld1q_s8(y1->qs);
3451
+ const int8x16_t y1_1 = vld1q_s8(y1->qs + 16);
3452
+
3453
+ #if defined(__ARM_FEATURE_DOTPROD)
3454
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
3455
+ vdotq_s32(vdupq_n_s32(0), x0_0, y0_0),
3456
+ vdotq_s32(vdupq_n_s32(0), x0_1, y0_1))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
3457
+
3458
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
3459
+ vdotq_s32(vdupq_n_s32(0), x1_0, y1_0),
3460
+ vdotq_s32(vdupq_n_s32(0), x1_1, y1_1))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
3461
+
3462
+ #else
3463
+ const int16x8_t p0_0 = vmull_s8(vget_low_s8 (x0_0), vget_low_s8 (y0_0));
3464
+ const int16x8_t p0_1 = vmull_s8(vget_high_s8(x0_0), vget_high_s8(y0_0));
3465
+ const int16x8_t p0_2 = vmull_s8(vget_low_s8 (x0_1), vget_low_s8 (y0_1));
3466
+ const int16x8_t p0_3 = vmull_s8(vget_high_s8(x0_1), vget_high_s8(y0_1));
3467
+
3468
+ const int16x8_t p1_0 = vmull_s8(vget_low_s8 (x1_0), vget_low_s8 (y1_0));
3469
+ const int16x8_t p1_1 = vmull_s8(vget_high_s8(x1_0), vget_high_s8(y1_0));
3470
+ const int16x8_t p1_2 = vmull_s8(vget_low_s8 (x1_1), vget_low_s8 (y1_1));
3471
+ const int16x8_t p1_3 = vmull_s8(vget_high_s8(x1_1), vget_high_s8(y1_1));
3472
+
3473
+ const int32x4_t p0 = vaddq_s32(vpaddlq_s16(p0_0), vpaddlq_s16(p0_1));
3474
+ const int32x4_t p1 = vaddq_s32(vpaddlq_s16(p0_2), vpaddlq_s16(p0_3));
3475
+ const int32x4_t p2 = vaddq_s32(vpaddlq_s16(p1_0), vpaddlq_s16(p1_1));
3476
+ const int32x4_t p3 = vaddq_s32(vpaddlq_s16(p1_2), vpaddlq_s16(p1_3));
3477
+
3478
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(p0, p1)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
3479
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(p2, p3)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
3480
+ #endif
1202
3481
  }
1203
- return (n/QK_K*sizeof(block_q6_K));
1204
- }
1205
3482
 
1206
- //===================================== Q8_K ==============================================
3483
+ *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
3484
+ #elif defined(__AVX2__) || defined(__AVX__)
3485
+ // Initialize accumulator with zeros
3486
+ __m256 acc = _mm256_setzero_ps();
1207
3487
 
1208
- void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int k) {
1209
- assert(k % QK_K == 0);
1210
- const int nb = k / QK_K;
3488
+ // Main loop
3489
+ for (int i = 0; i < nb; ++i) {
3490
+ // Compute combined scale for the block
3491
+ const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d));
3492
+ __m256i bx = _mm256_loadu_si256((const __m256i *)x[i].qs);
3493
+ __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
1211
3494
 
1212
- for (int i = 0; i < nb; i++) {
3495
+ const __m256 q = mul_sum_i8_pairs_float(bx, by);
1213
3496
 
1214
- float max = 0;
1215
- float amax = 0;
1216
- for (int j = 0; j < QK_K; ++j) {
1217
- float ax = fabsf(x[j]);
1218
- if (ax > amax) {
1219
- amax = ax; max = x[j];
1220
- }
1221
- }
1222
- if (!amax) {
1223
- y[i].d = 0;
1224
- memset(y[i].qs, 0, QK_K);
1225
- x += QK_K;
1226
- continue;
1227
- }
1228
- const float iscale = -128.f/max;
1229
- for (int j = 0; j < QK_K; ++j) {
1230
- int v = nearest_int(iscale*x[j]);
1231
- y[i].qs[j] = MIN(127, v);
1232
- }
1233
- for (int j = 0; j < QK_K/16; ++j) {
1234
- int sum = 0;
1235
- for (int ii = 0; ii < 16; ++ii) {
1236
- sum += y[i].qs[j*16 + ii];
1237
- }
1238
- y[i].bsums[j] = sum;
1239
- }
1240
- y[i].d = 1/iscale;
1241
- x += QK_K;
3497
+ // Multiply q with scale and accumulate
3498
+ #if defined(__AVX2__)
3499
+ acc = _mm256_fmadd_ps( d, q, acc );
3500
+ #else
3501
+ acc = _mm256_add_ps( _mm256_mul_ps( d, q ), acc );
3502
+ #endif
1242
3503
  }
1243
- }
1244
3504
 
1245
- void dequantize_row_q8_K(const block_q8_K * restrict x, float * restrict y, int k) {
1246
- assert(k % QK_K == 0);
1247
- const int nb = k / QK_K;
3505
+ *s = hsum_float_8(acc);
3506
+ #elif defined(__riscv_v_intrinsic)
3507
+ float sumf = 0.0;
3508
+ size_t vl = __riscv_vsetvl_e8m1(qk);
1248
3509
 
1249
3510
  for (int i = 0; i < nb; i++) {
1250
- for (int j = 0; j < QK_K; ++j) {
1251
- *y++ = x[i].d * x[i].qs[j];
1252
- }
3511
+ // load elements
3512
+ vint8m1_t bx = __riscv_vle8_v_i8m1(x[i].qs, vl);
3513
+ vint8m1_t by = __riscv_vle8_v_i8m1(y[i].qs, vl);
3514
+
3515
+ vint16m2_t vw_mul = __riscv_vwmul_vv_i16m2(bx, by, vl);
3516
+
3517
+ vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, vl);
3518
+ vint32m1_t v_sum = __riscv_vwredsum_vs_i16m2_i32m1(vw_mul, v_zero, vl);
3519
+
3520
+ int sumi = __riscv_vmv_x_s_i32m1_i32(v_sum);
3521
+
3522
+ sumf += sumi*(GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d));
1253
3523
  }
1254
- }
1255
3524
 
1256
- void quantize_row_q8_K(const float * restrict x, void * restrict y, int k) {
1257
- quantize_row_q8_K_reference(x, y, k);
1258
- }
3525
+ *s = sumf;
3526
+ #else
3527
+ // scalar
3528
+ float sumf = 0.0;
1259
3529
 
1260
- //===================================== Dot products =================================
3530
+ for (int i = 0; i < nb; i++) {
3531
+ int sumi = 0;
1261
3532
 
1262
- //
1263
- // Helper functions
1264
- //
1265
- #if __AVX__ || __AVX2__ || __AVX512F__
3533
+ for (int j = 0; j < qk; j++) {
3534
+ sumi += x[i].qs[j]*y[i].qs[j];
3535
+ }
1266
3536
 
1267
- // horizontally add 8 floats
1268
- static inline float hsum_float_8(const __m256 x) {
1269
- __m128 res = _mm256_extractf128_ps(x, 1);
1270
- res = _mm_add_ps(res, _mm256_castps256_ps128(x));
1271
- res = _mm_add_ps(res, _mm_movehl_ps(res, res));
1272
- res = _mm_add_ss(res, _mm_movehdup_ps(res));
1273
- return _mm_cvtss_f32(res);
1274
- }
3537
+ sumf += sumi*(GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d));
3538
+ }
1275
3539
 
1276
- // shuffles to pick the required scales in dot products
1277
- static inline __m256i get_scale_shuffle_q3k(int i) {
1278
- static const uint8_t k_shuffle[128] = {
1279
- 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3,
1280
- 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7,
1281
- 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,
1282
- 12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13, 14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,
1283
- };
1284
- return _mm256_loadu_si256((const __m256i*)k_shuffle + i);
1285
- }
1286
- static inline __m256i get_scale_shuffle_k4(int i) {
1287
- static const uint8_t k_shuffle[256] = {
1288
- 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1,
1289
- 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3,
1290
- 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5,
1291
- 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7,
1292
- 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9,
1293
- 10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,
1294
- 12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,
1295
- 14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15
1296
- };
1297
- return _mm256_loadu_si256((const __m256i*)k_shuffle + i);
1298
- }
1299
- static inline __m128i get_scale_shuffle(int i) {
1300
- static const uint8_t k_shuffle[128] = {
1301
- 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1,
1302
- 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3,
1303
- 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5,
1304
- 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7,
1305
- 8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9,
1306
- 10,10,10,10,10,10,10,10, 11,11,11,11,11,11,11,11,
1307
- 12,12,12,12,12,12,12,12, 13,13,13,13,13,13,13,13,
1308
- 14,14,14,14,14,14,14,14, 15,15,15,15,15,15,15,15
1309
- };
1310
- return _mm_loadu_si128((const __m128i*)k_shuffle + i);
1311
- }
3540
+ *s = sumf;
1312
3541
  #endif
3542
+ }
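With both operands already in int8, the Q8_0 x Q8_0 kernel needs no unpacking at all: every branch above is an int8 dot product per 32-element block, accumulated in integers and scaled once by the product of the two fp16 block scales. A minimal scalar sketch of one block pair (the helper name is illustrative):

#include <stdint.h>

static float dot_q8_0_block_ref(float dx, const int8_t *xqs, float dy, const int8_t *yqs) {
    int sumi = 0;
    for (int j = 0; j < 32; ++j) {
        sumi += xqs[j] * yqs[j];     // integer accumulation avoids per-element float rounding
    }
    return dx * dy * (float)sumi;
}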
1313
3543
 
1314
3544
  #if QK_K == 256
1315
3545
  void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
@@ -1334,8 +3564,8 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
1334
3564
 
1335
3565
  for (int i = 0; i < nb; ++i) {
1336
3566
 
1337
- const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
1338
- const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin);
3567
+ const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
3568
+ const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
1339
3569
 
1340
3570
  const uint8_t * restrict q2 = x[i].qs;
1341
3571
  const int8_t * restrict q8 = y[i].qs;
@@ -1413,8 +3643,8 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
1413
3643
 
1414
3644
  for (int i = 0; i < nb; ++i) {
1415
3645
 
1416
- const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
1417
- const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin);
3646
+ const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
3647
+ const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
1418
3648
 
1419
3649
  const uint8_t * restrict q2 = x[i].qs;
1420
3650
  const int8_t * restrict q8 = y[i].qs;
@@ -1480,8 +3710,8 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
1480
3710
 
1481
3711
  for (int i = 0; i < nb; ++i) {
1482
 
- const float dall = y[i].d * ggml_fp16_to_fp32(x[i].d);
- const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin);
+ const float dall = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+ const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
 
  const uint8_t * restrict q2 = x[i].qs;
  const int8_t * restrict q8 = y[i].qs;
@@ -1588,8 +3818,8 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
  const int8_t * q8 = y[i].qs;
  const uint8_t * sc = x[i].scales;
 
- const float dall = y[i].d * ggml_fp16_to_fp32(x[i].d);
- const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin);
+ const float dall = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+ const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
 
  size_t vl = 16;
 
@@ -1675,8 +3905,8 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
  summs += y[i].bsums[j] * (sc[j] >> 4);
  }
 
- const float dall = y[i].d * ggml_fp16_to_fp32(x[i].d);
- const float dmin = y[i].d * ggml_fp16_to_fp32(x[i].dmin);
+ const float dall = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+ const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
 
  int isum = 0;
  int is = 0;
@@ -1793,8 +4023,8 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
 
  for (int i = 0; i < nb; ++i) {
 
- const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
- const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin);
+ const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+ const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
 
  const uint8_t * restrict q2 = x[i].qs;
  const int8_t * restrict q8 = y[i].qs;
@@ -1845,8 +4075,8 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
 
  for (int i = 0; i < nb; ++i) {
 
- const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
- const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin);
+ const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+ const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
 
  const uint8_t * restrict q2 = x[i].qs;
  const int8_t * restrict q8 = y[i].qs;
@@ -1960,8 +4190,8 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
  summs += y[i].bsums[j] * (sc[j] >> 4);
  }
 
- const float dall = y[i].d * ggml_fp16_to_fp32(x[i].d);
- const float dmin = y[i].d * ggml_fp16_to_fp32(x[i].dmin);
+ const float dall = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+ const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
 
  isum[0] = isum[1] = isum[2] = isum[3] = 0;
  for (int l = 0; l < 16; ++l) {
@@ -2014,7 +4244,7 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
 
  for (int i = 0; i < nb; ++i) {
 
- const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
+ const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
 
  const uint8_t * restrict q3 = x[i].qs;
  const uint8_t * restrict qh = x[i].hmask;
@@ -2122,7 +4352,7 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
 
  for (int i = 0; i < nb; ++i) {
 
- const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
+ const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
 
  const uint8_t * restrict q3 = x[i].qs;
  const int8_t * restrict q8 = y[i].qs;
@@ -2227,7 +4457,7 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
 
  for (int i = 0; i < nb; ++i) {
 
- const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
+ const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
 
  const uint8_t * restrict q3 = x[i].qs;
  const int8_t * restrict q8 = y[i].qs;
@@ -2448,7 +4678,7 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
 
  }
 
- const float d = ggml_fp16_to_fp32(x[i].d) * y[i].d;
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
 
  sumf += d*sum_t;
 
@@ -2513,7 +4743,7 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
  for (int l = 0; l < 8; ++l) aux32[l] += (scales[j] - 32) * aux16[l];
  q8 += 8; a += 8;
  }
- const float d = ggml_fp16_to_fp32(x[i].d) * y[i].d;
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
  for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
  }
  for (int l = 0; l < 8; ++l) sumf += sums[l];
@@ -2615,7 +4845,7 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
 
  for (int i = 0; i < nb; ++i) {
 
- const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
+ const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
 
  const uint8_t * restrict q3 = x[i].qs;
  const int8_t * restrict q8 = y[i].qs;
@@ -2686,7 +4916,7 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
 
  for (int i = 0; i < nb; ++i) {
 
- const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
+ const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
 
  const uint8_t * restrict q3 = x[i].qs;
  const int8_t * restrict q8 = y[i].qs;
@@ -2871,7 +5101,7 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
  q8 += 8; a += 8;
  for (int l = 0; l < 8; ++l) aux32[l] += scales[j] * aux16[l];
  }
- const float d = ggml_fp16_to_fp32(x[i].d) * y[i].d;
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
  for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
  }
  for (int l = 0; l < 8; ++l) sumf += sums[l];
@@ -2911,8 +5141,8 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
 
  for (int i = 0; i < nb; ++i) {
 
- const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
- const float dmin = y[i].d * ggml_fp16_to_fp32(x[i].dmin);
+ const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+ const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
 
  const int16x8_t q8sums = vpaddq_s16(vld1q_s16(y[i].bsums), vld1q_s16(y[i].bsums + 8));
 
@@ -2994,8 +5224,8 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
 
  for (int i = 0; i < nb; ++i) {
 
- const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
- const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin);
+ const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+ const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
 
  memcpy(utmp, x[i].scales, 12);
  utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
@@ -3060,8 +5290,8 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
 
  for (int i = 0; i < nb; ++i) {
 
- const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
- const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin);
+ const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+ const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
 
  const uint8_t * restrict q4 = x[i].qs;
  const int8_t * restrict q8 = y[i].qs;
@@ -3143,8 +5373,8 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
 
  size_t vl = 8;
 
- const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
- const float dmin = y[i].d * ggml_fp16_to_fp32(x[i].dmin);
+ const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+ const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
 
  vint16mf2_t q8sums_0 = __riscv_vlse16_v_i16mf2(y[i].bsums, 4, vl);
  vint16mf2_t q8sums_1 = __riscv_vlse16_v_i16mf2(y[i].bsums+1, 4, vl);
@@ -3254,9 +5484,9 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
  for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
  q8 += 8; a += 8;
  }
- const float d = ggml_fp16_to_fp32(x[i].d) * y[i].d;
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
  for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
- const float dmin = ggml_fp16_to_fp32(x[i].dmin) * y[i].d;
+ const float dmin = GGML_FP16_TO_FP32(x[i].dmin) * y[i].d;
  sumf -= dmin * sumi;
  }
  for (int l = 0; l < 8; ++l) sumf += sums[l];
@@ -3358,8 +5588,8 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
 
  for (int i = 0; i < nb; ++i) {
 
- const float d = ggml_fp16_to_fp32(x[i].d[0]) * y[i].d;
- const float m = ggml_fp16_to_fp32(x[i].d[1]) * y[i].d;
+ const float d = GGML_FP16_TO_FP32(x[i].d[0]) * y[i].d;
+ const float m = GGML_FP16_TO_FP32(x[i].d[1]) * y[i].d;
  const __m256 vd = _mm256_set1_ps(d);
 
  const uint16_t * a = (const uint16_t *)x[i].scales;
@@ -3404,8 +5634,8 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
 
  for (int i = 0; i < nb; ++i) {
 
- const float d = ggml_fp16_to_fp32(x[i].d[0]) * y[i].d;
- const float m = ggml_fp16_to_fp32(x[i].d[1]) * y[i].d;
+ const float d = GGML_FP16_TO_FP32(x[i].d[0]) * y[i].d;
+ const float m = GGML_FP16_TO_FP32(x[i].d[1]) * y[i].d;
  const __m256 vd = _mm256_set1_ps(d);
 
  const uint16_t * a = (const uint16_t *)x[i].scales;
@@ -3461,8 +5691,8 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
  s16[0] = b[0] & 0x0f0f;
  s16[1] = (b[0] >> 4) & 0x0f0f;
 
- sumf -= y[i].d * ggml_fp16_to_fp32(x[i].d[1]) * (scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3]));
- const float d = y[i].d * ggml_fp16_to_fp32(x[i].d[0]);
+ sumf -= y[i].d * GGML_FP16_TO_FP32(x[i].d[1]) * (scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3]));
+ const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d[0]);
 
  size_t vl = 32;
 
@@ -3511,9 +5741,9 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
  s16[0] = b[0] & 0x0f0f;
  s16[1] = (b[0] >> 4) & 0x0f0f;
 
- sumf -= y[i].d * ggml_fp16_to_fp32(x[i].d[1]) * (scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3]));
+ sumf -= y[i].d * GGML_FP16_TO_FP32(x[i].d[1]) * (scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3]));
 
- const float d = y[i].d * ggml_fp16_to_fp32(x[i].d[0]);
+ const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d[0]);
 
  for (int j = 0; j < QK_K/32; ++j) {
  for (int l = 0; l < 16; ++l) aux16[l] = q8[l] * a[l];
@@ -3561,8 +5791,8 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
 
  for (int i = 0; i < nb; ++i) {
 
- const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
- const float dmin = y[i].d * ggml_fp16_to_fp32(x[i].dmin);
+ const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+ const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
 
  const int16x8_t q8sums = vpaddq_s16(vld1q_s16(y[i].bsums), vld1q_s16(y[i].bsums + 8));
 
@@ -3650,8 +5880,8 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
  const int8_t * restrict q8 = y[i].qs;
 
  #if QK_K == 256
- const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
- const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin);
+ const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+ const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
 
  memcpy(utmp, x[i].scales, 12);
  utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
@@ -3732,8 +5962,8 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
 
  for (int i = 0; i < nb; ++i) {
 
- const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
- const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin);
+ const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+ const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
 
  const uint8_t * restrict q5 = x[i].qs;
  const int8_t * restrict q8 = y[i].qs;
@@ -3837,8 +6067,8 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
  const uint8_t * restrict hm = x[i].qh;
  const int8_t * restrict q8 = y[i].qs;
 
- const float d = ggml_fp16_to_fp32(x[i].d) * y[i].d;
- const float dmin = ggml_fp16_to_fp32(x[i].dmin) * y[i].d;
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+ const float dmin = GGML_FP16_TO_FP32(x[i].dmin) * y[i].d;
 
  vint16mf2_t q8sums_0 = __riscv_vlse16_v_i16mf2(y[i].bsums, 4, vl);
  vint16mf2_t q8sums_1 = __riscv_vlse16_v_i16mf2(y[i].bsums+1, 4, vl);
@@ -3960,9 +6190,9 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
  for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
  q8 += 8; a += 8;
  }
- const float d = ggml_fp16_to_fp32(x[i].d) * y[i].d;
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
  for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
- const float dmin = ggml_fp16_to_fp32(x[i].dmin) * y[i].d;
+ const float dmin = GGML_FP16_TO_FP32(x[i].dmin) * y[i].d;
  sumf -= dmin * sumi;
  }
  for (int l = 0; l < 8; ++l) sumf += sums[l];
@@ -4060,7 +6290,7 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
  const uint8_t * restrict q5 = x[i].qs;
  const int8_t * restrict q8 = y[i].qs;
 
- const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
+ const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
 
  const __m256i q5bits = _mm256_loadu_si256((const __m256i*)q5);
 
@@ -4106,7 +6336,7 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
  const uint8_t * restrict q5 = x[i].qs;
  const int8_t * restrict q8 = y[i].qs;
 
- const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
+ const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
 
  const __m256i q5bits = _mm256_loadu_si256((const __m256i*)q5);
 
@@ -4243,7 +6473,7 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
  for (int l = 0; l < 8; ++l) a[8*is + l] -= (hm[l] & m ? 0 : 16);
  }
 
- const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
+ const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
  const int8_t * restrict sc = x[i].scales;
 
  for (int j = 0; j < QK_K/16; ++j) {
@@ -4286,7 +6516,7 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
 
  for (int i = 0; i < nb; ++i) {
 
- const float d_all = ggml_fp16_to_fp32(x[i].d);
+ const float d_all = GGML_FP16_TO_FP32(x[i].d);
 
  const uint8_t * restrict q6 = x[i].ql;
  const uint8_t * restrict qh = x[i].qh;
@@ -4418,7 +6648,7 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
 
  for (int i = 0; i < nb; ++i) {
 
- const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
+ const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
 
  const uint8_t * restrict q4 = x[i].ql;
  const uint8_t * restrict qh = x[i].qh;
@@ -4498,7 +6728,7 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
 
  for (int i = 0; i < nb; ++i) {
 
- const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
+ const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
 
  const uint8_t * restrict q4 = x[i].ql;
  const uint8_t * restrict qh = x[i].qh;
@@ -4610,7 +6840,7 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
  float sumf = 0;
  for (int i = 0; i < nb; ++i) {
 
- const float d = ggml_fp16_to_fp32(x[i].d) * y[i].d;
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
 
  const uint8_t * restrict q6 = x[i].ql;
  const uint8_t * restrict qh = x[i].qh;
@@ -4727,7 +6957,7 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
  for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
  q8 += 8; a += 8;
  }
- const float d = ggml_fp16_to_fp32(x[i].d) * y[i].d;
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
  for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
  }
  for (int l = 0; l < 8; ++l) sumf += sums[l];
@@ -4825,7 +7055,7 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
 
  for (int i = 0; i < nb; ++i) {
 
- const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
+ const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
 
  const uint8_t * restrict q4 = x[i].ql;
  const uint8_t * restrict qh = x[i].qh;
@@ -4882,7 +7112,7 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
 
  for (int i = 0; i < nb; ++i) {
 
- const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
+ const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
 
  const uint8_t * restrict q4 = x[i].ql;
  const uint8_t * restrict qh = x[i].qh;
@@ -5041,7 +7271,7 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
  for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
  q8 += 8; a += 8;
  }
- const float d = ggml_fp16_to_fp32(x[i].d) * y[i].d;
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
  for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
  }
  for (int l = 0; l < 8; ++l) sumf += sums[l];
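
Note: every hunk above makes the same mechanical substitution — the ggml_fp16_to_fp32() function call used by the 0.9.0 k_quants.c is replaced by the GGML_FP16_TO_FP32() macro, presumably pulled in through the new ggml-impl.h include. The macro's real definition is platform-dependent and is not shown in this diff, so the snippet below is only a rough, hypothetical sketch of what such a half-to-float conversion of the stored block scales does; the name example_fp16_to_fp32 and the raw-uint16_t storage assumption are illustrative, not the library's actual code.

    #include <stdint.h>
    #include <string.h>

    /* Hypothetical portable fallback -- NOT the library's actual definition.
     * Converts an IEEE-754 binary16 value, stored as raw bits in a uint16_t
     * (as the block scales x[i].d / x[i].dmin are), to a 32-bit float. */
    static inline float example_fp16_to_fp32(uint16_t h) {
        const uint32_t sign = (uint32_t)(h & 0x8000u) << 16;   /* sign bit -> bit 31  */
        const uint32_t exp  = (h >> 10) & 0x1Fu;                /* 5-bit exponent      */
        const uint32_t mant = h & 0x03FFu;                      /* 10-bit mantissa     */
        uint32_t bits;
        float    out;
        if (exp == 0) {
            if (mant == 0) {
                bits = sign;                                    /* +/- zero            */
            } else {
                /* subnormal half: value = mant * 2^-24, renormalized by the FPU */
                out = (float)mant * (1.0f/16777216.0f);
                memcpy(&bits, &out, sizeof(bits));
                bits |= sign;
            }
        } else if (exp == 0x1Fu) {
            bits = sign | 0x7F800000u | (mant << 13);           /* inf / NaN           */
        } else {
            bits = sign | ((exp + 112u) << 23) | (mant << 13);  /* rebias 15 -> 127    */
        }
        memcpy(&out, &bits, sizeof(out));
        return out;
    }

Under that assumption, a scale line from the hunks above such as `const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);` simply widens the half-precision block scale to float before multiplying it by the q8_K row scale.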