llama_cpp 0.9.0 → 0.9.1

@@ -1,9 +1,10 @@
- #include "k_quants.h"
- #include "ggml.h"
+ #include "ggml-quants.h"
+ #include "ggml-impl.h"
 
  #include <math.h>
  #include <string.h>
  #include <assert.h>
+ #include <float.h>
 
  #ifdef __ARM_NEON
 
@@ -65,1251 +66,3480 @@ inline static int32_t vaddvq_s32(int32x4_t v) {
65
66
 
66
67
  #define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
67
68
 
68
- //
69
- // 2-6 bit quantization in super-blocks
70
- //
69
+ #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__)
70
+ // multiply int8_t, add results pairwise twice
71
+ static inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) {
72
+ // Get absolute values of x vectors
73
+ const __m128i ax = _mm_sign_epi8(x, x);
74
+ // Sign the values of the y vectors
75
+ const __m128i sy = _mm_sign_epi8(y, x);
76
+ // Perform multiplication and create 16-bit values
77
+ const __m128i dot = _mm_maddubs_epi16(ax, sy);
78
+ const __m128i ones = _mm_set1_epi16(1);
79
+ return _mm_madd_epi16(ones, dot);
80
+ }
71
81
 
72
- //
73
- // ===================== Helper functions
74
- //
75
- static inline int nearest_int(float fval) {
76
- assert(fval <= 4194303.f);
77
- float val = fval + 12582912.f;
78
- int i; memcpy(&i, &val, sizeof(int));
79
- return (i & 0x007fffff) - 0x00400000;
82
+ #if __AVX__ || __AVX2__ || __AVX512F__
83
+ // horizontally add 8 floats
84
+ static inline float hsum_float_8(const __m256 x) {
85
+ __m128 res = _mm256_extractf128_ps(x, 1);
86
+ res = _mm_add_ps(res, _mm256_castps256_ps128(x));
87
+ res = _mm_add_ps(res, _mm_movehl_ps(res, res));
88
+ res = _mm_add_ss(res, _mm_movehdup_ps(res));
89
+ return _mm_cvtss_f32(res);
80
90
  }
81
91
 
82
- static float make_qx_quants(int n, int nmax, const float * restrict x, int8_t * restrict L, int rmse_type) {
83
- float max = 0;
84
- float amax = 0;
85
- for (int i = 0; i < n; ++i) {
86
- float ax = fabsf(x[i]);
87
- if (ax > amax) { amax = ax; max = x[i]; }
88
- }
89
- if (amax < 1e-30f) { // all zero
90
- for (int i = 0; i < n; ++i) {
91
- L[i] = 0;
92
+ // horizontally add 8 int32_t
93
+ static inline int hsum_i32_8(const __m256i a) {
94
+ const __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(a), _mm256_extractf128_si256(a, 1));
95
+ const __m128i hi64 = _mm_unpackhi_epi64(sum128, sum128);
96
+ const __m128i sum64 = _mm_add_epi32(hi64, sum128);
97
+ const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
98
+ return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
99
+ }
100
+
101
+ // horizontally add 4 int32_t
102
+ static inline int hsum_i32_4(const __m128i a) {
103
+ const __m128i hi64 = _mm_unpackhi_epi64(a, a);
104
+ const __m128i sum64 = _mm_add_epi32(hi64, a);
105
+ const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
106
+ return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
107
+ }
108
+
109
+ #if defined(__AVX2__) || defined(__AVX512F__)
110
+ // spread 32 bits to 32 bytes { 0x00, 0xFF }
111
+ static inline __m256i bytes_from_bits_32(const uint8_t * x) {
112
+ uint32_t x32;
113
+ memcpy(&x32, x, sizeof(uint32_t));
114
+ const __m256i shuf_mask = _mm256_set_epi64x(
115
+ 0x0303030303030303, 0x0202020202020202,
116
+ 0x0101010101010101, 0x0000000000000000);
117
+ __m256i bytes = _mm256_shuffle_epi8(_mm256_set1_epi32(x32), shuf_mask);
118
+ const __m256i bit_mask = _mm256_set1_epi64x(0x7fbfdfeff7fbfdfe);
119
+ bytes = _mm256_or_si256(bytes, bit_mask);
120
+ return _mm256_cmpeq_epi8(bytes, _mm256_set1_epi64x(-1));
121
+ }
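
[Editor's note] The helper above turns each of the 32 input bits into a full 0x00/0xFF byte, which the q5 dot products later in this file use as a mask for the high quant bits. A scalar sketch of the same mapping (illustrative only, not part of the gem source):

// Scalar equivalent of bytes_from_bits_32 (sketch): bit i of the little-endian
// 32-bit word becomes output byte i, 0xFF if the bit is set, 0x00 otherwise.
static void bytes_from_bits_32_scalar(const uint8_t * x, uint8_t * out) {
    uint32_t x32;
    memcpy(&x32, x, sizeof(uint32_t));
    for (int i = 0; i < 32; ++i) {
        out[i] = ((x32 >> i) & 1) ? 0xFF : 0x00;
    }
}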
122
+
123
+ // Unpack 32 4-bit fields into 32 bytes
124
+ // The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
125
+ static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)
126
+ {
127
+ const __m128i tmp = _mm_loadu_si128((const __m128i *)rsi);
128
+ const __m256i bytes = MM256_SET_M128I(_mm_srli_epi16(tmp, 4), tmp);
129
+ const __m256i lowMask = _mm256_set1_epi8( 0xF );
130
+ return _mm256_and_si256(lowMask, bytes);
131
+ }
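
[Editor's note] bytes_from_nibbles_32 is the vector form of splitting 16 packed bytes into their low and high nibbles. A scalar sketch (illustrative only, not part of the gem source):

// Scalar equivalent of bytes_from_nibbles_32 (sketch): 16 input bytes hold 32
// packed 4-bit values; the low nibbles land in out[0..15] and the high nibbles
// in out[16..31], each in the range [0, 15].
static void bytes_from_nibbles_32_scalar(const uint8_t * in, uint8_t * out) {
    for (int i = 0; i < 16; ++i) {
        out[i]      = in[i] & 0x0F;
        out[i + 16] = in[i] >> 4;
    }
}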
132
+
133
+ // add int16_t pairwise and return as float vector
134
+ static inline __m256 sum_i16_pairs_float(const __m256i x) {
135
+ const __m256i ones = _mm256_set1_epi16(1);
136
+ const __m256i summed_pairs = _mm256_madd_epi16(ones, x);
137
+ return _mm256_cvtepi32_ps(summed_pairs);
138
+ }
139
+
140
+ static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) {
141
+ #if __AVXVNNI__
142
+ const __m256i zero = _mm256_setzero_si256();
143
+ const __m256i summed_pairs = _mm256_dpbusd_epi32(zero, ax, sy);
144
+ return _mm256_cvtepi32_ps(summed_pairs);
145
+ #else
146
+ // Perform multiplication and create 16-bit values
147
+ const __m256i dot = _mm256_maddubs_epi16(ax, sy);
148
+ return sum_i16_pairs_float(dot);
149
+ #endif
150
+ }
151
+
152
+ // multiply int8_t, add results pairwise twice and return as float vector
153
+ static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
154
+ #if __AVXVNNIINT8__
155
+ const __m256i zero = _mm256_setzero_si256();
156
+ const __m256i summed_pairs = _mm256_dpbssd_epi32(zero, x, y);
157
+ return _mm256_cvtepi32_ps(summed_pairs);
158
+ #else
159
+ // Get absolute values of x vectors
160
+ const __m256i ax = _mm256_sign_epi8(x, x);
161
+ // Sign the values of the y vectors
162
+ const __m256i sy = _mm256_sign_epi8(y, x);
163
+ return mul_sum_us8_pairs_float(ax, sy);
164
+ #endif
165
+ }
166
+
167
+ static inline __m128i packNibbles( __m256i bytes )
168
+ {
169
+ // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
170
+ #if __AVX512F__
171
+ const __m256i bytes_srli_4 = _mm256_srli_epi16(bytes, 4); // 0000_0000_abcd_0000
172
+ bytes = _mm256_or_si256(bytes, bytes_srli_4); // 0000_abcd_abcd_efgh
173
+ return _mm256_cvtepi16_epi8(bytes); // abcd_efgh
174
+ #else
175
+ const __m256i lowByte = _mm256_set1_epi16( 0xFF );
176
+ __m256i high = _mm256_andnot_si256( lowByte, bytes );
177
+ __m256i low = _mm256_and_si256( lowByte, bytes );
178
+ high = _mm256_srli_epi16( high, 4 );
179
+ bytes = _mm256_or_si256( low, high );
180
+
181
+ // Compress uint16_t lanes into bytes
182
+ __m128i r0 = _mm256_castsi256_si128( bytes );
183
+ __m128i r1 = _mm256_extracti128_si256( bytes, 1 );
184
+ return _mm_packus_epi16( r0, r1 );
185
+ #endif
186
+ }
187
+ #elif defined(__AVX__)
188
+ // spread 32 bits to 32 bytes { 0x00, 0xFF }
189
+ static inline __m256i bytes_from_bits_32(const uint8_t * x) {
190
+ uint32_t x32;
191
+ memcpy(&x32, x, sizeof(uint32_t));
192
+ const __m128i shuf_maskl = _mm_set_epi64x(0x0101010101010101, 0x0000000000000000);
193
+ const __m128i shuf_maskh = _mm_set_epi64x(0x0303030303030303, 0x0202020202020202);
194
+ __m128i bytesl = _mm_shuffle_epi8(_mm_set1_epi32(x32), shuf_maskl);
195
+ __m128i bytesh = _mm_shuffle_epi8(_mm_set1_epi32(x32), shuf_maskh);
196
+ const __m128i bit_mask = _mm_set1_epi64x(0x7fbfdfeff7fbfdfe);
197
+ bytesl = _mm_or_si128(bytesl, bit_mask);
198
+ bytesh = _mm_or_si128(bytesh, bit_mask);
199
+ bytesl = _mm_cmpeq_epi8(bytesl, _mm_set1_epi64x(-1));
200
+ bytesh = _mm_cmpeq_epi8(bytesh, _mm_set1_epi64x(-1));
201
+ return MM256_SET_M128I(bytesh, bytesl);
202
+ }
203
+
204
+ // Unpack 32 4-bit fields into 32 bytes
205
+ // The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
206
+ static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)
207
+ {
208
+ // Load 16 bytes from memory
209
+ __m128i tmpl = _mm_loadu_si128((const __m128i *)rsi);
210
+ __m128i tmph = _mm_srli_epi16(tmpl, 4);
211
+ const __m128i lowMask = _mm_set1_epi8(0xF);
212
+ tmpl = _mm_and_si128(lowMask, tmpl);
213
+ tmph = _mm_and_si128(lowMask, tmph);
214
+ return MM256_SET_M128I(tmph, tmpl);
215
+ }
216
+
217
+ // add int16_t pairwise and return as float vector
218
+ static inline __m256 sum_i16_pairs_float(const __m128i xh, const __m128i xl) {
219
+ const __m128i ones = _mm_set1_epi16(1);
220
+ const __m128i summed_pairsl = _mm_madd_epi16(ones, xl);
221
+ const __m128i summed_pairsh = _mm_madd_epi16(ones, xh);
222
+ const __m256i summed_pairs = MM256_SET_M128I(summed_pairsh, summed_pairsl);
223
+ return _mm256_cvtepi32_ps(summed_pairs);
224
+ }
225
+
226
+ static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) {
227
+ const __m128i axl = _mm256_castsi256_si128(ax);
228
+ const __m128i axh = _mm256_extractf128_si256(ax, 1);
229
+ const __m128i syl = _mm256_castsi256_si128(sy);
230
+ const __m128i syh = _mm256_extractf128_si256(sy, 1);
231
+ // Perform multiplication and create 16-bit values
232
+ const __m128i dotl = _mm_maddubs_epi16(axl, syl);
233
+ const __m128i doth = _mm_maddubs_epi16(axh, syh);
234
+ return sum_i16_pairs_float(doth, dotl);
235
+ }
236
+
237
+ // multiply int8_t, add results pairwise twice and return as float vector
238
+ static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
239
+ const __m128i xl = _mm256_castsi256_si128(x);
240
+ const __m128i xh = _mm256_extractf128_si256(x, 1);
241
+ const __m128i yl = _mm256_castsi256_si128(y);
242
+ const __m128i yh = _mm256_extractf128_si256(y, 1);
243
+ // Get absolute values of x vectors
244
+ const __m128i axl = _mm_sign_epi8(xl, xl);
245
+ const __m128i axh = _mm_sign_epi8(xh, xh);
246
+ // Sign the values of the y vectors
247
+ const __m128i syl = _mm_sign_epi8(yl, xl);
248
+ const __m128i syh = _mm_sign_epi8(yh, xh);
249
+ // Perform multiplication and create 16-bit values
250
+ const __m128i dotl = _mm_maddubs_epi16(axl, syl);
251
+ const __m128i doth = _mm_maddubs_epi16(axh, syh);
252
+ return sum_i16_pairs_float(doth, dotl);
253
+ }
254
+
255
+ static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
256
+ {
257
+ // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
258
+ const __m128i lowByte = _mm_set1_epi16( 0xFF );
259
+ __m128i high = _mm_andnot_si128( lowByte, bytes1 );
260
+ __m128i low = _mm_and_si128( lowByte, bytes1 );
261
+ high = _mm_srli_epi16( high, 4 );
262
+ bytes1 = _mm_or_si128( low, high );
263
+ high = _mm_andnot_si128( lowByte, bytes2 );
264
+ low = _mm_and_si128( lowByte, bytes2 );
265
+ high = _mm_srli_epi16( high, 4 );
266
+ bytes2 = _mm_or_si128( low, high );
267
+
268
+ return _mm_packus_epi16( bytes1, bytes2);
269
+ }
270
+ #endif
271
+ #elif defined(__SSSE3__)
272
+ // horizontally add 4x4 floats
273
+ static inline float hsum_float_4x4(const __m128 a, const __m128 b, const __m128 c, const __m128 d) {
274
+ __m128 res_0 =_mm_hadd_ps(a, b);
275
+ __m128 res_1 =_mm_hadd_ps(c, d);
276
+ __m128 res =_mm_hadd_ps(res_0, res_1);
277
+ res =_mm_hadd_ps(res, res);
278
+ res =_mm_hadd_ps(res, res);
279
+
280
+ return _mm_cvtss_f32(res);
281
+ }
282
+ #endif // __AVX__ || __AVX2__ || __AVX512F__
283
+ #endif // defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__)
284
+
285
+ #if defined(__ARM_NEON)
286
+
287
+ #if !defined(__aarch64__)
288
+
289
+ inline static int32_t vaddvq_s32(int32x4_t v) {
290
+ return vgetq_lane_s32(v, 0) + vgetq_lane_s32(v, 1) + vgetq_lane_s32(v, 2) + vgetq_lane_s32(v, 3);
291
+ }
292
+
293
+ inline static float vaddvq_f32(float32x4_t v) {
294
+ return vgetq_lane_f32(v, 0) + vgetq_lane_f32(v, 1) + vgetq_lane_f32(v, 2) + vgetq_lane_f32(v, 3);
295
+ }
296
+
297
+ inline static float vmaxvq_f32(float32x4_t v) {
298
+ return
299
+ MAX(MAX(vgetq_lane_f32(v, 0), vgetq_lane_f32(v, 1)),
300
+ MAX(vgetq_lane_f32(v, 2), vgetq_lane_f32(v, 3)));
301
+ }
302
+
303
+ inline static int32x4_t vcvtnq_s32_f32(float32x4_t v) {
304
+ int32x4_t res;
305
+
306
+ res[0] = roundf(vgetq_lane_f32(v, 0));
307
+ res[1] = roundf(vgetq_lane_f32(v, 1));
308
+ res[2] = roundf(vgetq_lane_f32(v, 2));
309
+ res[3] = roundf(vgetq_lane_f32(v, 3));
310
+
311
+ return res;
312
+ }
313
+
314
+ #endif
315
+ #endif
316
+
317
+ #if defined(__ARM_NEON) || defined(__wasm_simd128__)
318
+ #define B1(c,s,n) 0x ## n ## c , 0x ## n ## s
319
+ #define B2(c,s,n) B1(c,s,n ## c), B1(c,s,n ## s)
320
+ #define B3(c,s,n) B2(c,s,n ## c), B2(c,s,n ## s)
321
+ #define B4(c,s,n) B3(c,s,n ## c), B3(c,s,n ## s)
322
+ #define B5(c,s,n) B4(c,s,n ## c), B4(c,s,n ## s)
323
+ #define B6(c,s,n) B5(c,s,n ## c), B5(c,s,n ## s)
324
+ #define B7(c,s,n) B6(c,s,n ## c), B6(c,s,n ## s)
325
+ #define B8(c,s ) B7(c,s, c), B7(c,s, s)
326
+
327
+ // precomputed tables for expanding 8bits to 8 bytes:
328
+ static const uint64_t table_b2b_0[1 << 8] = { B8(00, 10) }; // ( b) << 4
329
+ static const uint64_t table_b2b_1[1 << 8] = { B8(10, 00) }; // (!b) << 4
330
+ #endif
331
+
332
+ // reference implementation for deterministic creation of model files
333
+ void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k) {
334
+ static const int qk = QK4_0;
335
+
336
+ assert(k % qk == 0);
337
+
338
+ const int nb = k / qk;
339
+
340
+ for (int i = 0; i < nb; i++) {
341
+ float amax = 0.0f; // absolute max
342
+ float max = 0.0f;
343
+
344
+ for (int j = 0; j < qk; j++) {
345
+ const float v = x[i*qk + j];
346
+ if (amax < fabsf(v)) {
347
+ amax = fabsf(v);
348
+ max = v;
349
+ }
92
350
  }
93
- return 0.f;
94
- }
95
- float iscale = -nmax / max;
96
- if (rmse_type == 0) {
97
- for (int i = 0; i < n; ++i) {
98
- int l = nearest_int(iscale * x[i]);
99
- L[i] = nmax + MAX(-nmax, MIN(nmax-1, l));
351
+
352
+ const float d = max / -8;
353
+ const float id = d ? 1.0f/d : 0.0f;
354
+
355
+ y[i].d = GGML_FP32_TO_FP16(d);
356
+
357
+ for (int j = 0; j < qk/2; ++j) {
358
+ const float x0 = x[i*qk + 0 + j]*id;
359
+ const float x1 = x[i*qk + qk/2 + j]*id;
360
+
361
+ const uint8_t xi0 = MIN(15, (int8_t)(x0 + 8.5f));
362
+ const uint8_t xi1 = MIN(15, (int8_t)(x1 + 8.5f));
363
+
364
+ y[i].qs[j] = xi0;
365
+ y[i].qs[j] |= xi1 << 4;
100
366
  }
101
- return 1/iscale;
102
- }
103
- bool return_early = false;
104
- if (rmse_type < 0) {
105
- rmse_type = -rmse_type;
106
- return_early = true;
107
- }
108
- int weight_type = rmse_type%2;
109
- float sumlx = 0;
110
- float suml2 = 0;
111
- for (int i = 0; i < n; ++i) {
112
- int l = nearest_int(iscale * x[i]);
113
- l = MAX(-nmax, MIN(nmax-1, l));
114
- L[i] = l + nmax;
115
- float w = weight_type == 1 ? x[i] * x[i] : 1;
116
- sumlx += w*x[i]*l;
117
- suml2 += w*l*l;
118
367
  }
119
- float scale = sumlx/suml2;
120
- if (return_early) return suml2 > 0 ? 0.5f*(scale + 1/iscale) : 1/iscale;
121
- float best = scale * sumlx;
122
- for (int is = -9; is <= 9; ++is) {
123
- if (is == 0) {
124
- continue;
368
+ }
369
+
370
+ void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) {
371
+ quantize_row_q4_0_reference(x, y, k);
372
+ }
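
[Editor's note] The reference Q4_0 quantizer above keeps one scale per 32-value block, chosen so that the element with the largest magnitude maps to -8, and packs the offset quants two per byte. A self-contained toy round trip, assuming a plain float scale instead of the fp16 field of block_q4_0 (the toy_* names are illustrative, not part of the library):

#include <math.h>
#include <stdint.h>
#include <stdio.h>

#define TOY_QK 32

typedef struct {
    float   d;               // block scale (the real block_q4_0 stores an fp16 here)
    uint8_t qs[TOY_QK / 2];  // 32 quants, two 4-bit values per byte
} toy_block_q4_0;

static void toy_quantize_q4_0(const float * x, toy_block_q4_0 * y) {
    float amax = 0.0f, max = 0.0f;
    for (int j = 0; j < TOY_QK; j++) {
        if (fabsf(x[j]) > amax) { amax = fabsf(x[j]); max = x[j]; }
    }
    const float d  = max / -8;              // largest-magnitude value maps to -8
    const float id = d ? 1.0f/d : 0.0f;
    y->d = d;
    for (int j = 0; j < TOY_QK/2; ++j) {
        int xi0 = (int)(x[j]*id + 8.5f);              // quant offset by +8, then clamped to 15
        int xi1 = (int)(x[TOY_QK/2 + j]*id + 8.5f);
        if (xi0 > 15) xi0 = 15;
        if (xi1 > 15) xi1 = 15;
        y->qs[j] = (uint8_t)(xi0 | (xi1 << 4));
    }
}

static void toy_dequantize_q4_0(const toy_block_q4_0 * y, float * x) {
    for (int j = 0; j < TOY_QK/2; ++j) {
        x[j]            = ((y->qs[j] & 0x0F) - 8) * y->d;
        x[TOY_QK/2 + j] = ((y->qs[j] >>   4) - 8) * y->d;
    }
}

int main(void) {
    float x[TOY_QK], xr[TOY_QK];
    for (int j = 0; j < TOY_QK; ++j) x[j] = sinf(0.3f*(float)j);  // arbitrary test data
    toy_block_q4_0 b;
    toy_quantize_q4_0(x, &b);
    toy_dequantize_q4_0(&b, xr);
    float err = 0.0f;
    for (int j = 0; j < TOY_QK; ++j) err = fmaxf(err, fabsf(x[j] - xr[j]));
    printf("scale d = %g, max round-trip error = %g\n", (double)b.d, (double)err);
    return 0;
}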
373
+
374
+ void quantize_row_q4_1_reference(const float * restrict x, block_q4_1 * restrict y, int k) {
375
+ const int qk = QK4_1;
376
+
377
+ assert(k % qk == 0);
378
+
379
+ const int nb = k / qk;
380
+
381
+ for (int i = 0; i < nb; i++) {
382
+ float min = FLT_MAX;
383
+ float max = -FLT_MAX;
384
+
385
+ for (int j = 0; j < qk; j++) {
386
+ const float v = x[i*qk + j];
387
+
388
+ if (v < min) min = v;
389
+ if (v > max) max = v;
125
390
  }
126
- iscale = -(nmax + 0.1f*is) / max;
127
- sumlx = suml2 = 0;
128
- for (int i = 0; i < n; ++i) {
129
- int l = nearest_int(iscale * x[i]);
130
- l = MAX(-nmax, MIN(nmax-1, l));
131
- float w = weight_type == 1 ? x[i] * x[i] : 1;
132
- sumlx += w*x[i]*l;
133
- suml2 += w*l*l;
391
+
392
+ const float d = (max - min) / ((1 << 4) - 1);
393
+ const float id = d ? 1.0f/d : 0.0f;
394
+
395
+ y[i].d = GGML_FP32_TO_FP16(d);
396
+ y[i].m = GGML_FP32_TO_FP16(min);
397
+
398
+ for (int j = 0; j < qk/2; ++j) {
399
+ const float x0 = (x[i*qk + 0 + j] - min)*id;
400
+ const float x1 = (x[i*qk + qk/2 + j] - min)*id;
401
+
402
+ const uint8_t xi0 = MIN(15, (int8_t)(x0 + 0.5f));
403
+ const uint8_t xi1 = MIN(15, (int8_t)(x1 + 0.5f));
404
+
405
+ y[i].qs[j] = xi0;
406
+ y[i].qs[j] |= xi1 << 4;
134
407
  }
135
- if (suml2 > 0 && sumlx*sumlx > best*suml2) {
136
- for (int i = 0; i < n; ++i) {
137
- int l = nearest_int(iscale * x[i]);
138
- L[i] = nmax + MAX(-nmax, MIN(nmax-1, l));
408
+ }
409
+ }
410
+
411
+ void quantize_row_q4_1(const float * restrict x, void * restrict y, int k) {
412
+ quantize_row_q4_1_reference(x, y, k);
413
+ }
414
+
415
+ void quantize_row_q5_0_reference(const float * restrict x, block_q5_0 * restrict y, int k) {
416
+ static const int qk = QK5_0;
417
+
418
+ assert(k % qk == 0);
419
+
420
+ const int nb = k / qk;
421
+
422
+ for (int i = 0; i < nb; i++) {
423
+ float amax = 0.0f; // absolute max
424
+ float max = 0.0f;
425
+
426
+ for (int j = 0; j < qk; j++) {
427
+ const float v = x[i*qk + j];
428
+ if (amax < fabsf(v)) {
429
+ amax = fabsf(v);
430
+ max = v;
139
431
  }
140
- scale = sumlx/suml2; best = scale*sumlx;
141
432
  }
433
+
434
+ const float d = max / -16;
435
+ const float id = d ? 1.0f/d : 0.0f;
436
+
437
+ y[i].d = GGML_FP32_TO_FP16(d);
438
+
439
+ uint32_t qh = 0;
440
+
441
+ for (int j = 0; j < qk/2; ++j) {
442
+ const float x0 = x[i*qk + 0 + j]*id;
443
+ const float x1 = x[i*qk + qk/2 + j]*id;
444
+
445
+ const uint8_t xi0 = MIN(31, (int8_t)(x0 + 16.5f));
446
+ const uint8_t xi1 = MIN(31, (int8_t)(x1 + 16.5f));
447
+
448
+ y[i].qs[j] = (xi0 & 0x0F) | ((xi1 & 0x0F) << 4);
449
+
450
+ // get the 5-th bit and store it in qh at the right position
451
+ qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
452
+ qh |= ((xi1 & 0x10u) >> 4) << (j + qk/2);
453
+ }
454
+
455
+ memcpy(&y[i].qh, &qh, sizeof(qh));
142
456
  }
143
- return scale;
144
457
  }
145
458
 
146
- static float make_q3_quants(int n, int nmax, const float * restrict x, int8_t * restrict L, bool do_rmse) {
147
- float max = 0;
148
- float amax = 0;
149
- for (int i = 0; i < n; ++i) {
150
- float ax = fabsf(x[i]);
151
- if (ax > amax) { amax = ax; max = x[i]; }
459
+ void quantize_row_q5_0(const float * restrict x, void * restrict y, int k) {
460
+ quantize_row_q5_0_reference(x, y, k);
461
+ }
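
[Editor's note] q5_0 extends the 4-bit layout with a fifth bit per quant, gathered into the 32-bit qh word: bit j for the first half of the block and bit j + 16 for the second half, exactly as packed above and re-extracted in dequantize_row_q5_0 further down. A small sketch of the recovery step (q5_high_bit is a hypothetical helper, not library code):

// Returns 0x00 or 0x10, ready to OR onto the low nibble from qs[j] before
// subtracting the offset of 16 (illustrative sketch of the qh layout only).
static inline uint8_t q5_high_bit(uint32_t qh, int j, int second_half) {
    const int bit = j + (second_half ? 16 : 0);
    return (uint8_t)(((qh >> bit) & 1) << 4);
}
// e.g. const int x0 = ((qs[j] & 0x0F) | q5_high_bit(qh, j, 0)) - 16;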
462
+
463
+ void quantize_row_q5_1_reference(const float * restrict x, block_q5_1 * restrict y, int k) {
464
+ const int qk = QK5_1;
465
+
466
+ assert(k % qk == 0);
467
+
468
+ const int nb = k / qk;
469
+
470
+ for (int i = 0; i < nb; i++) {
471
+ float min = FLT_MAX;
472
+ float max = -FLT_MAX;
473
+
474
+ for (int j = 0; j < qk; j++) {
475
+ const float v = x[i*qk + j];
476
+
477
+ if (v < min) min = v;
478
+ if (v > max) max = v;
479
+ }
480
+
481
+ const float d = (max - min) / ((1 << 5) - 1);
482
+ const float id = d ? 1.0f/d : 0.0f;
483
+
484
+ y[i].d = GGML_FP32_TO_FP16(d);
485
+ y[i].m = GGML_FP32_TO_FP16(min);
486
+
487
+ uint32_t qh = 0;
488
+
489
+ for (int j = 0; j < qk/2; ++j) {
490
+ const float x0 = (x[i*qk + 0 + j] - min)*id;
491
+ const float x1 = (x[i*qk + qk/2 + j] - min)*id;
492
+
493
+ const uint8_t xi0 = (uint8_t)(x0 + 0.5f);
494
+ const uint8_t xi1 = (uint8_t)(x1 + 0.5f);
495
+
496
+ y[i].qs[j] = (xi0 & 0x0F) | ((xi1 & 0x0F) << 4);
497
+
498
+ // get the 5-th bit and store it in qh at the right position
499
+ qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
500
+ qh |= ((xi1 & 0x10u) >> 4) << (j + qk/2);
501
+ }
502
+
503
+ memcpy(&y[i].qh, &qh, sizeof(y[i].qh));
152
504
  }
153
- if (!amax) { // all zero
154
- for (int i = 0; i < n; ++i) { L[i] = 0; }
155
- return 0.f;
505
+ }
506
+
507
+ void quantize_row_q5_1(const float * restrict x, void * restrict y, int k) {
508
+ quantize_row_q5_1_reference(x, y, k);
509
+ }
510
+
511
+ // reference implementation for deterministic creation of model files
512
+ void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * restrict y, int k) {
513
+ assert(k % QK8_0 == 0);
514
+ const int nb = k / QK8_0;
515
+
516
+ for (int i = 0; i < nb; i++) {
517
+ float amax = 0.0f; // absolute max
518
+
519
+ for (int j = 0; j < QK8_0; j++) {
520
+ const float v = x[i*QK8_0 + j];
521
+ amax = MAX(amax, fabsf(v));
522
+ }
523
+
524
+ const float d = amax / ((1 << 7) - 1);
525
+ const float id = d ? 1.0f/d : 0.0f;
526
+
527
+ y[i].d = GGML_FP32_TO_FP16(d);
528
+
529
+ for (int j = 0; j < QK8_0; ++j) {
530
+ const float x0 = x[i*QK8_0 + j]*id;
531
+
532
+ y[i].qs[j] = roundf(x0);
533
+ }
156
534
  }
157
- float iscale = -nmax / max;
158
- if (do_rmse) {
159
- float sumlx = 0;
160
- float suml2 = 0;
161
- for (int i = 0; i < n; ++i) {
162
- int l = nearest_int(iscale * x[i]);
163
- l = MAX(-nmax, MIN(nmax-1, l));
164
- L[i] = l;
165
- float w = x[i]*x[i];
166
- sumlx += w*x[i]*l;
167
- suml2 += w*l*l;
535
+ }
536
+
537
+ void quantize_row_q8_0(const float * restrict x, void * restrict vy, int k) {
538
+ assert(QK8_0 == 32);
539
+ assert(k % QK8_0 == 0);
540
+ const int nb = k / QK8_0;
541
+
542
+ block_q8_0 * restrict y = vy;
543
+
544
+ #if defined(__ARM_NEON)
545
+ for (int i = 0; i < nb; i++) {
546
+ float32x4_t srcv [8];
547
+ float32x4_t asrcv[8];
548
+ float32x4_t amaxv[8];
549
+
550
+ for (int j = 0; j < 8; j++) srcv[j] = vld1q_f32(x + i*32 + 4*j);
551
+ for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[j]);
552
+
553
+ for (int j = 0; j < 4; j++) amaxv[2*j] = vmaxq_f32(asrcv[2*j], asrcv[2*j+1]);
554
+ for (int j = 0; j < 2; j++) amaxv[4*j] = vmaxq_f32(amaxv[4*j], amaxv[4*j+2]);
555
+ for (int j = 0; j < 1; j++) amaxv[8*j] = vmaxq_f32(amaxv[8*j], amaxv[8*j+4]);
556
+
557
+ const float amax = vmaxvq_f32(amaxv[0]);
558
+
559
+ const float d = amax / ((1 << 7) - 1);
560
+ const float id = d ? 1.0f/d : 0.0f;
561
+
562
+ y[i].d = GGML_FP32_TO_FP16(d);
563
+
564
+ for (int j = 0; j < 8; j++) {
565
+ const float32x4_t v = vmulq_n_f32(srcv[j], id);
566
+ const int32x4_t vi = vcvtnq_s32_f32(v);
567
+
568
+ y[i].qs[4*j + 0] = vgetq_lane_s32(vi, 0);
569
+ y[i].qs[4*j + 1] = vgetq_lane_s32(vi, 1);
570
+ y[i].qs[4*j + 2] = vgetq_lane_s32(vi, 2);
571
+ y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3);
168
572
  }
169
- for (int itry = 0; itry < 5; ++itry) {
170
- int n_changed = 0;
171
- for (int i = 0; i < n; ++i) {
172
- float w = x[i]*x[i];
173
- float slx = sumlx - w*x[i]*L[i];
174
- if (slx > 0) {
175
- float sl2 = suml2 - w*L[i]*L[i];
176
- int new_l = nearest_int(x[i] * sl2 / slx);
177
- new_l = MAX(-nmax, MIN(nmax-1, new_l));
178
- if (new_l != L[i]) {
179
- slx += w*x[i]*new_l;
180
- sl2 += w*new_l*new_l;
181
- if (sl2 > 0 && slx*slx*suml2 > sumlx*sumlx*sl2) {
182
- L[i] = new_l; sumlx = slx; suml2 = sl2;
183
- ++n_changed;
184
- }
185
- }
186
- }
573
+ }
574
+ #elif defined(__wasm_simd128__)
575
+ for (int i = 0; i < nb; i++) {
576
+ v128_t srcv [8];
577
+ v128_t asrcv[8];
578
+ v128_t amaxv[8];
579
+
580
+ for (int j = 0; j < 8; j++) srcv[j] = wasm_v128_load(x + i*32 + 4*j);
581
+ for (int j = 0; j < 8; j++) asrcv[j] = wasm_f32x4_abs(srcv[j]);
582
+
583
+ for (int j = 0; j < 4; j++) amaxv[2*j] = wasm_f32x4_max(asrcv[2*j], asrcv[2*j+1]);
584
+ for (int j = 0; j < 2; j++) amaxv[4*j] = wasm_f32x4_max(amaxv[4*j], amaxv[4*j+2]);
585
+ for (int j = 0; j < 1; j++) amaxv[8*j] = wasm_f32x4_max(amaxv[8*j], amaxv[8*j+4]);
586
+
587
+ const float amax = MAX(MAX(wasm_f32x4_extract_lane(amaxv[0], 0),
588
+ wasm_f32x4_extract_lane(amaxv[0], 1)),
589
+ MAX(wasm_f32x4_extract_lane(amaxv[0], 2),
590
+ wasm_f32x4_extract_lane(amaxv[0], 3)));
591
+
592
+ const float d = amax / ((1 << 7) - 1);
593
+ const float id = d ? 1.0f/d : 0.0f;
594
+
595
+ y[i].d = GGML_FP32_TO_FP16(d);
596
+
597
+ for (int j = 0; j < 8; j++) {
598
+ const v128_t v = wasm_f32x4_mul(srcv[j], wasm_f32x4_splat(id));
599
+ const v128_t vi = wasm_i32x4_trunc_sat_f32x4(v);
600
+
601
+ y[i].qs[4*j + 0] = wasm_i32x4_extract_lane(vi, 0);
602
+ y[i].qs[4*j + 1] = wasm_i32x4_extract_lane(vi, 1);
603
+ y[i].qs[4*j + 2] = wasm_i32x4_extract_lane(vi, 2);
604
+ y[i].qs[4*j + 3] = wasm_i32x4_extract_lane(vi, 3);
605
+ }
606
+ }
607
+ #elif defined(__AVX2__) || defined(__AVX__)
608
+ for (int i = 0; i < nb; i++) {
609
+ // Load elements into 4 AVX vectors
610
+ __m256 v0 = _mm256_loadu_ps( x );
611
+ __m256 v1 = _mm256_loadu_ps( x + 8 );
612
+ __m256 v2 = _mm256_loadu_ps( x + 16 );
613
+ __m256 v3 = _mm256_loadu_ps( x + 24 );
614
+ x += 32;
615
+
616
+ // Compute max(abs(e)) for the block
617
+ const __m256 signBit = _mm256_set1_ps( -0.0f );
618
+ __m256 maxAbs = _mm256_andnot_ps( signBit, v0 );
619
+ maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) );
620
+ maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );
621
+ maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) );
622
+
623
+ __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
624
+ max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
625
+ max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
626
+ const float maxScalar = _mm_cvtss_f32( max4 );
627
+
628
+ // Quantize these floats
629
+ const float d = maxScalar / 127.f;
630
+ y[i].d = GGML_FP32_TO_FP16(d);
631
+ const float id = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f;
632
+ const __m256 mul = _mm256_set1_ps( id );
633
+
634
+ // Apply the multiplier
635
+ v0 = _mm256_mul_ps( v0, mul );
636
+ v1 = _mm256_mul_ps( v1, mul );
637
+ v2 = _mm256_mul_ps( v2, mul );
638
+ v3 = _mm256_mul_ps( v3, mul );
639
+
640
+ // Round to nearest integer
641
+ v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST );
642
+ v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST );
643
+ v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST );
644
+ v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST );
645
+
646
+ // Convert floats to integers
647
+ __m256i i0 = _mm256_cvtps_epi32( v0 );
648
+ __m256i i1 = _mm256_cvtps_epi32( v1 );
649
+ __m256i i2 = _mm256_cvtps_epi32( v2 );
650
+ __m256i i3 = _mm256_cvtps_epi32( v3 );
651
+
652
+ #if defined(__AVX2__)
653
+ // Convert int32 to int16
654
+ i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15
655
+ i2 = _mm256_packs_epi32( i2, i3 ); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31
656
+ // Convert int16 to int8
657
+ i0 = _mm256_packs_epi16( i0, i2 ); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31
658
+
659
+ // We got our precious signed bytes, but the order is now wrong
660
+ // These AVX2 pack instructions process 16-byte pieces independently
661
+ // The following instruction is fixing the order
662
+ const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
663
+ i0 = _mm256_permutevar8x32_epi32( i0, perm );
664
+
665
+ _mm256_storeu_si256((__m256i *)y[i].qs, i0);
666
+ #else
667
+ // Since we don't have in AVX some necessary functions,
668
+ // we split the registers in half and call AVX2 analogs from SSE
669
+ __m128i ni0 = _mm256_castsi256_si128( i0 );
670
+ __m128i ni1 = _mm256_extractf128_si256( i0, 1);
671
+ __m128i ni2 = _mm256_castsi256_si128( i1 );
672
+ __m128i ni3 = _mm256_extractf128_si256( i1, 1);
673
+ __m128i ni4 = _mm256_castsi256_si128( i2 );
674
+ __m128i ni5 = _mm256_extractf128_si256( i2, 1);
675
+ __m128i ni6 = _mm256_castsi256_si128( i3 );
676
+ __m128i ni7 = _mm256_extractf128_si256( i3, 1);
677
+
678
+ // Convert int32 to int16
679
+ ni0 = _mm_packs_epi32( ni0, ni1 );
680
+ ni2 = _mm_packs_epi32( ni2, ni3 );
681
+ ni4 = _mm_packs_epi32( ni4, ni5 );
682
+ ni6 = _mm_packs_epi32( ni6, ni7 );
683
+ // Convert int16 to int8
684
+ ni0 = _mm_packs_epi16( ni0, ni2 );
685
+ ni4 = _mm_packs_epi16( ni4, ni6 );
686
+
687
+ _mm_storeu_si128((__m128i *)(y[i].qs + 0), ni0);
688
+ _mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4);
689
+ #endif
690
+ }
691
+ #elif defined(__riscv_v_intrinsic)
692
+
693
+ size_t vl = __riscv_vsetvl_e32m4(QK8_0);
694
+
695
+ for (int i = 0; i < nb; i++) {
696
+ // load elements
697
+ vfloat32m4_t v_x = __riscv_vle32_v_f32m4(x+i*QK8_0, vl);
698
+
699
+ vfloat32m4_t vfabs = __riscv_vfabs_v_f32m4(v_x, vl);
700
+ vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0f, vl);
701
+ vfloat32m1_t vmax = __riscv_vfredmax_vs_f32m4_f32m1(vfabs, tmp, vl);
702
+ float amax = __riscv_vfmv_f_s_f32m1_f32(vmax);
703
+
704
+ const float d = amax / ((1 << 7) - 1);
705
+ const float id = d ? 1.0f/d : 0.0f;
706
+
707
+ y[i].d = GGML_FP32_TO_FP16(d);
708
+
709
+ vfloat32m4_t x0 = __riscv_vfmul_vf_f32m4(v_x, id, vl);
710
+
711
+ // convert to integer
712
+ vint16m2_t vi = __riscv_vfncvt_x_f_w_i16m2(x0, vl);
713
+ vint8m1_t vs = __riscv_vncvt_x_x_w_i8m1(vi, vl);
714
+
715
+ // store result
716
+ __riscv_vse8_v_i8m1(y[i].qs , vs, vl);
717
+ }
718
+ #else
719
+ GGML_UNUSED(nb);
720
+ // scalar
721
+ quantize_row_q8_0_reference(x, y, k);
722
+ #endif
723
+ }
724
+
725
+ // reference implementation for deterministic creation of model files
726
+ void quantize_row_q8_1_reference(const float * restrict x, block_q8_1 * restrict y, int k) {
727
+ assert(QK8_1 == 32);
728
+ assert(k % QK8_1 == 0);
729
+ const int nb = k / QK8_1;
730
+
731
+ for (int i = 0; i < nb; i++) {
732
+ float amax = 0.0f; // absolute max
733
+
734
+ for (int j = 0; j < QK8_1; j++) {
735
+ const float v = x[i*QK8_1 + j];
736
+ amax = MAX(amax, fabsf(v));
737
+ }
738
+
739
+ const float d = amax / ((1 << 7) - 1);
740
+ const float id = d ? 1.0f/d : 0.0f;
741
+
742
+ y[i].d = d;
743
+
744
+ int sum = 0;
745
+
746
+ for (int j = 0; j < QK8_1/2; ++j) {
747
+ const float v0 = x[i*QK8_1 + j]*id;
748
+ const float v1 = x[i*QK8_1 + QK8_1/2 + j]*id;
749
+
750
+ y[i].qs[ j] = roundf(v0);
751
+ y[i].qs[QK8_1/2 + j] = roundf(v1);
752
+
753
+ sum += y[i].qs[ j];
754
+ sum += y[i].qs[QK8_1/2 + j];
755
+ }
756
+
757
+ y[i].s = sum*d;
758
+ }
759
+ }
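
[Editor's note] A hedged note on the extra s field: unlike block_q8_0, block_q8_1 stores s = d * (sum of the block's quants). The q4_1/q5_1 formats carry a per-block minimum m in addition to their scale, and the corresponding dot products later in this file can fold that minimum in with a single multiply because

\[
\sum_j (d_4\,q_{4,j} + m)\,(d_8\,q_{8,j})
  \;=\; d_4 d_8 \sum_j q_{4,j}\,q_{8,j} \;+\; m \cdot \underbrace{d_8 \sum_j q_{8,j}}_{s}.
\]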
760
+
761
+ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int k) {
762
+ assert(k % QK8_1 == 0);
763
+ const int nb = k / QK8_1;
764
+
765
+ block_q8_1 * restrict y = vy;
766
+
767
+ #if defined(__ARM_NEON)
768
+ for (int i = 0; i < nb; i++) {
769
+ float32x4_t srcv [8];
770
+ float32x4_t asrcv[8];
771
+ float32x4_t amaxv[8];
772
+
773
+ for (int j = 0; j < 8; j++) srcv[j] = vld1q_f32(x + i*32 + 4*j);
774
+ for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[j]);
775
+
776
+ for (int j = 0; j < 4; j++) amaxv[2*j] = vmaxq_f32(asrcv[2*j], asrcv[2*j+1]);
777
+ for (int j = 0; j < 2; j++) amaxv[4*j] = vmaxq_f32(amaxv[4*j], amaxv[4*j+2]);
778
+ for (int j = 0; j < 1; j++) amaxv[8*j] = vmaxq_f32(amaxv[8*j], amaxv[8*j+4]);
779
+
780
+ const float amax = vmaxvq_f32(amaxv[0]);
781
+
782
+ const float d = amax / ((1 << 7) - 1);
783
+ const float id = d ? 1.0f/d : 0.0f;
784
+
785
+ y[i].d = d;
786
+
787
+ int32x4_t accv = vdupq_n_s32(0);
788
+
789
+ for (int j = 0; j < 8; j++) {
790
+ const float32x4_t v = vmulq_n_f32(srcv[j], id);
791
+ const int32x4_t vi = vcvtnq_s32_f32(v);
792
+
793
+ y[i].qs[4*j + 0] = vgetq_lane_s32(vi, 0);
794
+ y[i].qs[4*j + 1] = vgetq_lane_s32(vi, 1);
795
+ y[i].qs[4*j + 2] = vgetq_lane_s32(vi, 2);
796
+ y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3);
797
+
798
+ accv = vaddq_s32(accv, vi);
799
+ }
800
+
801
+ y[i].s = d * vaddvq_s32(accv);
802
+ }
803
+ #elif defined(__wasm_simd128__)
804
+ for (int i = 0; i < nb; i++) {
805
+ v128_t srcv [8];
806
+ v128_t asrcv[8];
807
+ v128_t amaxv[8];
808
+
809
+ for (int j = 0; j < 8; j++) srcv[j] = wasm_v128_load(x + i*32 + 4*j);
810
+ for (int j = 0; j < 8; j++) asrcv[j] = wasm_f32x4_abs(srcv[j]);
811
+
812
+ for (int j = 0; j < 4; j++) amaxv[2*j] = wasm_f32x4_max(asrcv[2*j], asrcv[2*j+1]);
813
+ for (int j = 0; j < 2; j++) amaxv[4*j] = wasm_f32x4_max(amaxv[4*j], amaxv[4*j+2]);
814
+ for (int j = 0; j < 1; j++) amaxv[8*j] = wasm_f32x4_max(amaxv[8*j], amaxv[8*j+4]);
815
+
816
+ const float amax = MAX(MAX(wasm_f32x4_extract_lane(amaxv[0], 0),
817
+ wasm_f32x4_extract_lane(amaxv[0], 1)),
818
+ MAX(wasm_f32x4_extract_lane(amaxv[0], 2),
819
+ wasm_f32x4_extract_lane(amaxv[0], 3)));
820
+
821
+ const float d = amax / ((1 << 7) - 1);
822
+ const float id = d ? 1.0f/d : 0.0f;
823
+
824
+ y[i].d = d;
825
+
826
+ v128_t accv = wasm_i32x4_splat(0);
827
+
828
+ for (int j = 0; j < 8; j++) {
829
+ const v128_t v = wasm_f32x4_mul(srcv[j], wasm_f32x4_splat(id));
830
+ const v128_t vi = wasm_i32x4_trunc_sat_f32x4(v);
831
+
832
+ y[i].qs[4*j + 0] = wasm_i32x4_extract_lane(vi, 0);
833
+ y[i].qs[4*j + 1] = wasm_i32x4_extract_lane(vi, 1);
834
+ y[i].qs[4*j + 2] = wasm_i32x4_extract_lane(vi, 2);
835
+ y[i].qs[4*j + 3] = wasm_i32x4_extract_lane(vi, 3);
836
+
837
+ accv = wasm_i32x4_add(accv, vi);
838
+ }
839
+
840
+ y[i].s = d * (wasm_i32x4_extract_lane(accv, 0) +
841
+ wasm_i32x4_extract_lane(accv, 1) +
842
+ wasm_i32x4_extract_lane(accv, 2) +
843
+ wasm_i32x4_extract_lane(accv, 3));
844
+ }
845
+ #elif defined(__AVX2__) || defined(__AVX__)
846
+ for (int i = 0; i < nb; i++) {
847
+ // Load elements into 4 AVX vectors
848
+ __m256 v0 = _mm256_loadu_ps( x );
849
+ __m256 v1 = _mm256_loadu_ps( x + 8 );
850
+ __m256 v2 = _mm256_loadu_ps( x + 16 );
851
+ __m256 v3 = _mm256_loadu_ps( x + 24 );
852
+ x += 32;
853
+
854
+ // Compute max(abs(e)) for the block
855
+ const __m256 signBit = _mm256_set1_ps( -0.0f );
856
+ __m256 maxAbs = _mm256_andnot_ps( signBit, v0 );
857
+ maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) );
858
+ maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );
859
+ maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) );
860
+
861
+ __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
862
+ max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
863
+ max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
864
+ const float maxScalar = _mm_cvtss_f32( max4 );
865
+
866
+ // Quantize these floats
867
+ const float d = maxScalar / 127.f;
868
+ y[i].d = d;
869
+ const float id = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f;
870
+ const __m256 mul = _mm256_set1_ps( id );
871
+
872
+ // Apply the multiplier
873
+ v0 = _mm256_mul_ps( v0, mul );
874
+ v1 = _mm256_mul_ps( v1, mul );
875
+ v2 = _mm256_mul_ps( v2, mul );
876
+ v3 = _mm256_mul_ps( v3, mul );
877
+
878
+ // Round to nearest integer
879
+ v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST );
880
+ v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST );
881
+ v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST );
882
+ v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST );
883
+
884
+ // Convert floats to integers
885
+ __m256i i0 = _mm256_cvtps_epi32( v0 );
886
+ __m256i i1 = _mm256_cvtps_epi32( v1 );
887
+ __m256i i2 = _mm256_cvtps_epi32( v2 );
888
+ __m256i i3 = _mm256_cvtps_epi32( v3 );
889
+
890
+ #if defined(__AVX2__)
891
+ // Compute the sum of the quants and set y[i].s
892
+ y[i].s = d * hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3)));
893
+
894
+ // Convert int32 to int16
895
+ i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15
896
+ i2 = _mm256_packs_epi32( i2, i3 ); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31
897
+ // Convert int16 to int8
898
+ i0 = _mm256_packs_epi16( i0, i2 ); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31
899
+
900
+ // We got our precious signed bytes, but the order is now wrong
901
+ // These AVX2 pack instructions process 16-byte pieces independently
902
+ // The following instruction is fixing the order
903
+ const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
904
+ i0 = _mm256_permutevar8x32_epi32( i0, perm );
905
+
906
+ _mm256_storeu_si256((__m256i *)y[i].qs, i0);
907
+ #else
908
+ // Since we don't have in AVX some necessary functions,
909
+ // we split the registers in half and call AVX2 analogs from SSE
910
+ __m128i ni0 = _mm256_castsi256_si128( i0 );
911
+ __m128i ni1 = _mm256_extractf128_si256( i0, 1);
912
+ __m128i ni2 = _mm256_castsi256_si128( i1 );
913
+ __m128i ni3 = _mm256_extractf128_si256( i1, 1);
914
+ __m128i ni4 = _mm256_castsi256_si128( i2 );
915
+ __m128i ni5 = _mm256_extractf128_si256( i2, 1);
916
+ __m128i ni6 = _mm256_castsi256_si128( i3 );
917
+ __m128i ni7 = _mm256_extractf128_si256( i3, 1);
918
+
919
+ // Compute the sum of the quants and set y[i].s
920
+ const __m128i s0 = _mm_add_epi32(_mm_add_epi32(ni0, ni1), _mm_add_epi32(ni2, ni3));
921
+ const __m128i s1 = _mm_add_epi32(_mm_add_epi32(ni4, ni5), _mm_add_epi32(ni6, ni7));
922
+ y[i].s = d * hsum_i32_4(_mm_add_epi32(s0, s1));
923
+
924
+ // Convert int32 to int16
925
+ ni0 = _mm_packs_epi32( ni0, ni1 );
926
+ ni2 = _mm_packs_epi32( ni2, ni3 );
927
+ ni4 = _mm_packs_epi32( ni4, ni5 );
928
+ ni6 = _mm_packs_epi32( ni6, ni7 );
929
+ // Convert int16 to int8
930
+ ni0 = _mm_packs_epi16( ni0, ni2 );
931
+ ni4 = _mm_packs_epi16( ni4, ni6 );
932
+
933
+ _mm_storeu_si128((__m128i *)(y[i].qs + 0), ni0);
934
+ _mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4);
935
+ #endif
936
+ }
937
+ #elif defined(__riscv_v_intrinsic)
938
+
939
+ size_t vl = __riscv_vsetvl_e32m4(QK8_1);
940
+
941
+ for (int i = 0; i < nb; i++) {
942
+ // load elements
943
+ vfloat32m4_t v_x = __riscv_vle32_v_f32m4(x+i*QK8_1, vl);
944
+
945
+ vfloat32m4_t vfabs = __riscv_vfabs_v_f32m4(v_x, vl);
946
+ vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0, vl);
947
+ vfloat32m1_t vmax = __riscv_vfredmax_vs_f32m4_f32m1(vfabs, tmp, vl);
948
+ float amax = __riscv_vfmv_f_s_f32m1_f32(vmax);
949
+
950
+ const float d = amax / ((1 << 7) - 1);
951
+ const float id = d ? 1.0f/d : 0.0f;
952
+
953
+ y[i].d = d;
954
+
955
+ vfloat32m4_t x0 = __riscv_vfmul_vf_f32m4(v_x, id, vl);
956
+
957
+ // convert to integer
958
+ vint16m2_t vi = __riscv_vfncvt_x_f_w_i16m2(x0, vl);
959
+ vint8m1_t vs = __riscv_vncvt_x_x_w_i8m1(vi, vl);
960
+
961
+ // store result
962
+ __riscv_vse8_v_i8m1(y[i].qs , vs, vl);
963
+
964
+ // compute sum for y[i].s
965
+ vint16m1_t tmp2 = __riscv_vmv_v_x_i16m1(0, vl);
966
+ vint16m1_t vwrs = __riscv_vwredsum_vs_i8m1_i16m1(vs, tmp2, vl);
967
+
968
+ // set y[i].s
969
+ int sum = __riscv_vmv_x_s_i16m1_i16(vwrs);
970
+ y[i].s = sum*d;
971
+ }
972
+ #else
973
+ GGML_UNUSED(nb);
974
+ // scalar
975
+ quantize_row_q8_1_reference(x, y, k);
976
+ #endif
977
+ }
978
+
979
+ void dequantize_row_q4_0(const block_q4_0 * restrict x, float * restrict y, int k) {
980
+ static const int qk = QK4_0;
981
+
982
+ assert(k % qk == 0);
983
+
984
+ const int nb = k / qk;
985
+
986
+ for (int i = 0; i < nb; i++) {
987
+ const float d = GGML_FP16_TO_FP32(x[i].d);
988
+
989
+ for (int j = 0; j < qk/2; ++j) {
990
+ const int x0 = (x[i].qs[j] & 0x0F) - 8;
991
+ const int x1 = (x[i].qs[j] >> 4) - 8;
992
+
993
+ y[i*qk + j + 0 ] = x0*d;
994
+ y[i*qk + j + qk/2] = x1*d;
995
+ }
996
+ }
997
+ }
998
+
999
+ void dequantize_row_q4_1(const block_q4_1 * restrict x, float * restrict y, int k) {
1000
+ static const int qk = QK4_1;
1001
+
1002
+ assert(k % qk == 0);
1003
+
1004
+ const int nb = k / qk;
1005
+
1006
+ for (int i = 0; i < nb; i++) {
1007
+ const float d = GGML_FP16_TO_FP32(x[i].d);
1008
+ const float m = GGML_FP16_TO_FP32(x[i].m);
1009
+
1010
+ for (int j = 0; j < qk/2; ++j) {
1011
+ const int x0 = (x[i].qs[j] & 0x0F);
1012
+ const int x1 = (x[i].qs[j] >> 4);
1013
+
1014
+ y[i*qk + j + 0 ] = x0*d + m;
1015
+ y[i*qk + j + qk/2] = x1*d + m;
1016
+ }
1017
+ }
1018
+ }
1019
+
1020
+ void dequantize_row_q5_0(const block_q5_0 * restrict x, float * restrict y, int k) {
1021
+ static const int qk = QK5_0;
1022
+
1023
+ assert(k % qk == 0);
1024
+
1025
+ const int nb = k / qk;
1026
+
1027
+ for (int i = 0; i < nb; i++) {
1028
+ const float d = GGML_FP16_TO_FP32(x[i].d);
1029
+
1030
+ uint32_t qh;
1031
+ memcpy(&qh, x[i].qh, sizeof(qh));
1032
+
1033
+ for (int j = 0; j < qk/2; ++j) {
1034
+ const uint8_t xh_0 = ((qh >> (j + 0)) << 4) & 0x10;
1035
+ const uint8_t xh_1 = ((qh >> (j + 12)) ) & 0x10;
1036
+
1037
+ const int32_t x0 = ((x[i].qs[j] & 0x0F) | xh_0) - 16;
1038
+ const int32_t x1 = ((x[i].qs[j] >> 4) | xh_1) - 16;
1039
+
1040
+ y[i*qk + j + 0 ] = x0*d;
1041
+ y[i*qk + j + qk/2] = x1*d;
1042
+ }
1043
+ }
1044
+ }
1045
+
1046
+ void dequantize_row_q5_1(const block_q5_1 * restrict x, float * restrict y, int k) {
1047
+ static const int qk = QK5_1;
1048
+
1049
+ assert(k % qk == 0);
1050
+
1051
+ const int nb = k / qk;
1052
+
1053
+ for (int i = 0; i < nb; i++) {
1054
+ const float d = GGML_FP16_TO_FP32(x[i].d);
1055
+ const float m = GGML_FP16_TO_FP32(x[i].m);
1056
+
1057
+ uint32_t qh;
1058
+ memcpy(&qh, x[i].qh, sizeof(qh));
1059
+
1060
+ for (int j = 0; j < qk/2; ++j) {
1061
+ const uint8_t xh_0 = ((qh >> (j + 0)) << 4) & 0x10;
1062
+ const uint8_t xh_1 = ((qh >> (j + 12)) ) & 0x10;
1063
+
1064
+ const int x0 = (x[i].qs[j] & 0x0F) | xh_0;
1065
+ const int x1 = (x[i].qs[j] >> 4) | xh_1;
1066
+
1067
+ y[i*qk + j + 0 ] = x0*d + m;
1068
+ y[i*qk + j + qk/2] = x1*d + m;
1069
+ }
1070
+ }
1071
+ }
1072
+
1073
+ void dequantize_row_q8_0(const block_q8_0 * restrict x, float * restrict y, int k) {
1074
+ static const int qk = QK8_0;
1075
+
1076
+ assert(k % qk == 0);
1077
+
1078
+ const int nb = k / qk;
1079
+
1080
+ for (int i = 0; i < nb; i++) {
1081
+ const float d = GGML_FP16_TO_FP32(x[i].d);
1082
+
1083
+ for (int j = 0; j < qk; ++j) {
1084
+ y[i*qk + j] = x[i].qs[j]*d;
1085
+ }
1086
+ }
1087
+ }
1088
+
1089
+ //
1090
+ // 2-6 bit quantization in super-blocks
1091
+ //
1092
+
1093
+ //
1094
+ // ===================== Helper functions
1095
+ //
1096
+ static inline int nearest_int(float fval) {
1097
+ assert(fval <= 4194303.f);
1098
+ float val = fval + 12582912.f;
1099
+ int i; memcpy(&i, &val, sizeof(int));
1100
+ return (i & 0x007fffff) - 0x00400000;
1101
+ }
1102
+
1103
+ static float make_qx_quants(int n, int nmax, const float * restrict x, int8_t * restrict L, int rmse_type) {
1104
+ float max = 0;
1105
+ float amax = 0;
1106
+ for (int i = 0; i < n; ++i) {
1107
+ float ax = fabsf(x[i]);
1108
+ if (ax > amax) { amax = ax; max = x[i]; }
1109
+ }
1110
+ if (amax < 1e-30f) { // all zero
1111
+ for (int i = 0; i < n; ++i) {
1112
+ L[i] = 0;
1113
+ }
1114
+ return 0.f;
1115
+ }
1116
+ float iscale = -nmax / max;
1117
+ if (rmse_type == 0) {
1118
+ for (int i = 0; i < n; ++i) {
1119
+ int l = nearest_int(iscale * x[i]);
1120
+ L[i] = nmax + MAX(-nmax, MIN(nmax-1, l));
1121
+ }
1122
+ return 1/iscale;
1123
+ }
1124
+ bool return_early = false;
1125
+ if (rmse_type < 0) {
1126
+ rmse_type = -rmse_type;
1127
+ return_early = true;
1128
+ }
1129
+ int weight_type = rmse_type%2;
1130
+ float sumlx = 0;
1131
+ float suml2 = 0;
1132
+ for (int i = 0; i < n; ++i) {
1133
+ int l = nearest_int(iscale * x[i]);
1134
+ l = MAX(-nmax, MIN(nmax-1, l));
1135
+ L[i] = l + nmax;
1136
+ float w = weight_type == 1 ? x[i] * x[i] : 1;
1137
+ sumlx += w*x[i]*l;
1138
+ suml2 += w*l*l;
1139
+ }
1140
+ float scale = sumlx/suml2;
1141
+ if (return_early) return suml2 > 0 ? 0.5f*(scale + 1/iscale) : 1/iscale;
1142
+ float best = scale * sumlx;
1143
+ for (int is = -9; is <= 9; ++is) {
1144
+ if (is == 0) {
1145
+ continue;
1146
+ }
1147
+ iscale = -(nmax + 0.1f*is) / max;
1148
+ sumlx = suml2 = 0;
1149
+ for (int i = 0; i < n; ++i) {
1150
+ int l = nearest_int(iscale * x[i]);
1151
+ l = MAX(-nmax, MIN(nmax-1, l));
1152
+ float w = weight_type == 1 ? x[i] * x[i] : 1;
1153
+ sumlx += w*x[i]*l;
1154
+ suml2 += w*l*l;
1155
+ }
1156
+ if (suml2 > 0 && sumlx*sumlx > best*suml2) {
1157
+ for (int i = 0; i < n; ++i) {
1158
+ int l = nearest_int(iscale * x[i]);
1159
+ L[i] = nmax + MAX(-nmax, MIN(nmax-1, l));
1160
+ }
1161
+ scale = sumlx/suml2; best = scale*sumlx;
1162
+ }
1163
+ }
1164
+ return scale;
1165
+ }
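
[Editor's note] For a fixed assignment of quants, the optimal scale is sumlx/suml2 and the resulting weighted error is

\[
\left.\sum_i w_i (x_i - s\,l_i)^2\right|_{s = \mathrm{sumlx}/\mathrm{suml2}}
  \;=\; \sum_i w_i x_i^2 \;-\; \frac{\mathrm{sumlx}^2}{\mathrm{suml2}},
\]

so the sumlx*sumlx > best*suml2 comparison in make_qx_quants is simply picking the candidate iscale whose refitted scale gives the smallest error (best tracks sumlx^2/suml2).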
1166
+
1167
+ static float make_q3_quants(int n, int nmax, const float * restrict x, int8_t * restrict L, bool do_rmse) {
1168
+ float max = 0;
1169
+ float amax = 0;
1170
+ for (int i = 0; i < n; ++i) {
1171
+ float ax = fabsf(x[i]);
1172
+ if (ax > amax) { amax = ax; max = x[i]; }
1173
+ }
1174
+ if (!amax) { // all zero
1175
+ for (int i = 0; i < n; ++i) { L[i] = 0; }
1176
+ return 0.f;
1177
+ }
1178
+ float iscale = -nmax / max;
1179
+ if (do_rmse) {
1180
+ float sumlx = 0;
1181
+ float suml2 = 0;
1182
+ for (int i = 0; i < n; ++i) {
1183
+ int l = nearest_int(iscale * x[i]);
1184
+ l = MAX(-nmax, MIN(nmax-1, l));
1185
+ L[i] = l;
1186
+ float w = x[i]*x[i];
1187
+ sumlx += w*x[i]*l;
1188
+ suml2 += w*l*l;
1189
+ }
1190
+ for (int itry = 0; itry < 5; ++itry) {
1191
+ int n_changed = 0;
1192
+ for (int i = 0; i < n; ++i) {
1193
+ float w = x[i]*x[i];
1194
+ float slx = sumlx - w*x[i]*L[i];
1195
+ if (slx > 0) {
1196
+ float sl2 = suml2 - w*L[i]*L[i];
1197
+ int new_l = nearest_int(x[i] * sl2 / slx);
1198
+ new_l = MAX(-nmax, MIN(nmax-1, new_l));
1199
+ if (new_l != L[i]) {
1200
+ slx += w*x[i]*new_l;
1201
+ sl2 += w*new_l*new_l;
1202
+ if (sl2 > 0 && slx*slx*suml2 > sumlx*sumlx*sl2) {
1203
+ L[i] = new_l; sumlx = slx; suml2 = sl2;
1204
+ ++n_changed;
1205
+ }
1206
+ }
1207
+ }
1208
+ }
1209
+ if (!n_changed) {
1210
+ break;
1211
+ }
1212
+ }
1213
+ for (int i = 0; i < n; ++i) {
1214
+ L[i] += nmax;
1215
+ }
1216
+ return sumlx / suml2;
1217
+ }
1218
+ for (int i = 0; i < n; ++i) {
1219
+ int l = nearest_int(iscale * x[i]);
1220
+ l = MAX(-nmax, MIN(nmax-1, l));
1221
+ L[i] = l + nmax;
1222
+ }
1223
+ return 1/iscale;
1224
+ }
1225
+
1226
+ static float make_qkx1_quants(int n, int nmax, const float * restrict x, uint8_t * restrict L, float * restrict the_min,
1227
+ int ntry, float alpha) {
1228
+ float min = x[0];
1229
+ float max = x[0];
1230
+ for (int i = 1; i < n; ++i) {
1231
+ if (x[i] < min) min = x[i];
1232
+ if (x[i] > max) max = x[i];
1233
+ }
1234
+ if (max == min) {
1235
+ for (int i = 0; i < n; ++i) L[i] = 0;
1236
+ *the_min = 0;
1237
+ return 0.f;
1238
+ }
1239
+ if (min > 0) min = 0;
1240
+ float iscale = nmax/(max - min);
1241
+ float scale = 1/iscale;
1242
+ for (int itry = 0; itry < ntry; ++itry) {
1243
+ float sumlx = 0; int suml2 = 0;
1244
+ bool did_change = false;
1245
+ for (int i = 0; i < n; ++i) {
1246
+ int l = nearest_int(iscale*(x[i] - min));
1247
+ l = MAX(0, MIN(nmax, l));
1248
+ if (l != L[i]) {
1249
+ L[i] = l;
1250
+ did_change = true;
1251
+ }
1252
+ sumlx += (x[i] - min)*l;
1253
+ suml2 += l*l;
1254
+ }
1255
+ scale = sumlx/suml2;
1256
+ float sum = 0;
1257
+ for (int i = 0; i < n; ++i) {
1258
+ sum += x[i] - scale*L[i];
1259
+ }
1260
+ min = alpha*min + (1 - alpha)*sum/n;
1261
+ if (min > 0) min = 0;
1262
+ iscale = 1/scale;
1263
+ if (!did_change) break;
1264
+ }
1265
+ *the_min = -min;
1266
+ return scale;
1267
+ }
1268
+
1269
+ static float make_qkx2_quants(int n, int nmax, const float * restrict x, const float * restrict weights,
1270
+ uint8_t * restrict L, float * restrict the_min, uint8_t * restrict Laux,
1271
+ float rmin, float rdelta, int nstep, bool use_mad) {
1272
+ float min = x[0];
1273
+ float max = x[0];
1274
+ float sum_w = weights[0];
1275
+ float sum_x = sum_w * x[0];
1276
+ for (int i = 1; i < n; ++i) {
1277
+ if (x[i] < min) min = x[i];
1278
+ if (x[i] > max) max = x[i];
1279
+ float w = weights[i];
1280
+ sum_w += w;
1281
+ sum_x += w * x[i];
1282
+ }
1283
+ if (min > 0) min = 0;
1284
+ if (max == min) {
1285
+ for (int i = 0; i < n; ++i) L[i] = 0;
1286
+ *the_min = -min;
1287
+ return 0.f;
1288
+ }
1289
+ float iscale = nmax/(max - min);
1290
+ float scale = 1/iscale;
1291
+ float best_mad = 0;
1292
+ for (int i = 0; i < n; ++i) {
1293
+ int l = nearest_int(iscale*(x[i] - min));
1294
+ L[i] = MAX(0, MIN(nmax, l));
1295
+ float diff = scale * L[i] + min - x[i];
1296
+ diff = use_mad ? fabsf(diff) : diff * diff;
1297
+ float w = weights[i];
1298
+ best_mad += w * diff;
1299
+ }
1300
+ if (nstep < 1) {
1301
+ *the_min = -min;
1302
+ return scale;
1303
+ }
1304
+ for (int is = 0; is <= nstep; ++is) {
1305
+ iscale = (rmin + rdelta*is + nmax)/(max - min);
1306
+ float sum_l = 0, sum_l2 = 0, sum_xl = 0;
1307
+ for (int i = 0; i < n; ++i) {
1308
+ int l = nearest_int(iscale*(x[i] - min));
1309
+ l = MAX(0, MIN(nmax, l));
1310
+ Laux[i] = l;
1311
+ float w = weights[i];
1312
+ sum_l += w*l;
1313
+ sum_l2 += w*l*l;
1314
+ sum_xl += w*l*x[i];
1315
+ }
1316
+ float D = sum_w * sum_l2 - sum_l * sum_l;
1317
+ if (D > 0) {
1318
+ float this_scale = (sum_w * sum_xl - sum_x * sum_l)/D;
1319
+ float this_min = (sum_l2 * sum_x - sum_l * sum_xl)/D;
1320
+ if (this_min > 0) {
1321
+ this_min = 0;
1322
+ this_scale = sum_xl / sum_l2;
1323
+ }
1324
+ float mad = 0;
1325
+ for (int i = 0; i < n; ++i) {
1326
+ float diff = this_scale * Laux[i] + this_min - x[i];
1327
+ diff = use_mad ? fabsf(diff) : diff * diff;
1328
+ float w = weights[i];
1329
+ mad += w * diff;
1330
+ }
1331
+ if (mad < best_mad) {
1332
+ for (int i = 0; i < n; ++i) {
1333
+ L[i] = Laux[i];
1334
+ }
1335
+ best_mad = mad;
1336
+ scale = this_scale;
1337
+ min = this_min;
1338
+ }
1339
+ }
1340
+ }
1341
+ *the_min = -min;
1342
+ return scale;
1343
+ }
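
[Editor's note] make_qkx2_quants searches a grid of candidate scales and, for each induced assignment L, refits scale and min by weighted least squares; D, this_scale and this_min are Cramer's rule applied to the normal equations of

\[
\min_{a,\,b}\; \sum_i w_i\,(a L_i + b - x_i)^2
\;\;\Longrightarrow\;\;
\begin{pmatrix} \sum_i w_i L_i^2 & \sum_i w_i L_i \\ \sum_i w_i L_i & \sum_i w_i \end{pmatrix}
\begin{pmatrix} a \\ b \end{pmatrix}
=
\begin{pmatrix} \sum_i w_i L_i x_i \\ \sum_i w_i x_i \end{pmatrix},
\]

with a the scale, b the min (returned negated as *the_min), and D the determinant; when the fitted min comes out positive it is clamped to zero and the scale refitted as sum_xl/sum_l2.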
1344
+
1345
+ #if QK_K == 256
1346
+ static inline void get_scale_min_k4(int j, const uint8_t * restrict q, uint8_t * restrict d, uint8_t * restrict m) {
1347
+ if (j < 4) {
1348
+ *d = q[j] & 63; *m = q[j + 4] & 63;
1349
+ } else {
1350
+ *d = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4);
1351
+ *m = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4);
1352
+ }
1353
+ }
1354
+ #endif
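
[Editor's note] get_scale_min_k4 above unpacks the 12-byte scales array of a 256-value K-quant super-block into eight 6-bit scales and eight 6-bit mins. As a hedged sketch, one packing consistent with that unpack helper is shown below; the actual packing is done by the q4_K/q5_K quantizers, which fall outside this excerpt, and pack_scale_min_k4 is a hypothetical name.

// Hypothetical helper (illustrative only): pack eight 6-bit scales and eight
// 6-bit mins into the 12-byte layout expected by get_scale_min_k4.
// Bytes 0..3 : low 6 bits of sc[0..3], top 2 bits = high bits of sc[4..7]
// Bytes 4..7 : low 6 bits of mn[0..3], top 2 bits = high bits of mn[4..7]
// Bytes 8..11: low nibble = low 4 bits of sc[4..7], high nibble = low 4 bits of mn[4..7]
static void pack_scale_min_k4(const uint8_t sc[8], const uint8_t mn[8], uint8_t q[12]) {
    for (int j = 0; j < 4; ++j) {
        q[j]     = (sc[j] & 63) | ((sc[j+4] >> 4) << 6);
        q[j + 4] = (mn[j] & 63) | ((mn[j+4] >> 4) << 6);
        q[j + 8] = (sc[j+4] & 0xF) | ((mn[j+4] & 0xF) << 4);
    }
}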
1355
+
1356
+ //========================- 2-bit (de)-quantization
1357
+
1358
+ void quantize_row_q2_K_reference(const float * restrict x, block_q2_K * restrict y, int k) {
1359
+ assert(k % QK_K == 0);
1360
+ const int nb = k / QK_K;
1361
+
1362
+ uint8_t L[QK_K];
1363
+ uint8_t Laux[16];
1364
+ float weights[16];
1365
+ float mins[QK_K/16];
1366
+ float scales[QK_K/16];
1367
+
1368
+ const float q4scale = 15.f;
1369
+
1370
+ for (int i = 0; i < nb; i++) {
1371
+ float max_scale = 0; // as we are deducting the min, scales are always positive
1372
+ float max_min = 0;
1373
+ for (int j = 0; j < QK_K/16; ++j) {
1374
+ for (int l = 0; l < 16; ++l) weights[l] = fabsf(x[16*j + l]);
1375
+ scales[j] = make_qkx2_quants(16, 3, x + 16*j, weights, L + 16*j, &mins[j], Laux, -0.5f, 0.1f, 15, true);
1376
+ float scale = scales[j];
1377
+ if (scale > max_scale) {
1378
+ max_scale = scale;
1379
+ }
1380
+ float min = mins[j];
1381
+ if (min > max_min) {
1382
+ max_min = min;
1383
+ }
1384
+ }
1385
+
1386
+ if (max_scale > 0) {
1387
+ float iscale = q4scale/max_scale;
1388
+ for (int j = 0; j < QK_K/16; ++j) {
1389
+ int l = nearest_int(iscale*scales[j]);
1390
+ y[i].scales[j] = l;
1391
+ }
1392
+ y[i].d = GGML_FP32_TO_FP16(max_scale/q4scale);
1393
+ } else {
1394
+ for (int j = 0; j < QK_K/16; ++j) y[i].scales[j] = 0;
1395
+ y[i].d = GGML_FP32_TO_FP16(0.f);
1396
+ }
1397
+ if (max_min > 0) {
1398
+ float iscale = q4scale/max_min;
1399
+ for (int j = 0; j < QK_K/16; ++j) {
1400
+ int l = nearest_int(iscale*mins[j]);
1401
+ y[i].scales[j] |= (l << 4);
1402
+ }
1403
+ y[i].dmin = GGML_FP32_TO_FP16(max_min/q4scale);
1404
+ } else {
1405
+ y[i].dmin = GGML_FP32_TO_FP16(0.f);
1406
+ }
1407
+ for (int j = 0; j < QK_K/16; ++j) {
1408
+ const float d = GGML_FP16_TO_FP32(y[i].d) * (y[i].scales[j] & 0xF);
1409
+ if (!d) continue;
1410
+ const float dm = GGML_FP16_TO_FP32(y[i].dmin) * (y[i].scales[j] >> 4);
1411
+ for (int ii = 0; ii < 16; ++ii) {
1412
+ int l = nearest_int((x[16*j + ii] + dm)/d);
1413
+ l = MAX(0, MIN(3, l));
1414
+ L[16*j + ii] = l;
1415
+ }
1416
+ }
1417
+
1418
+ #if QK_K == 256
1419
+ for (int j = 0; j < QK_K; j += 128) {
1420
+ for (int l = 0; l < 32; ++l) {
1421
+ y[i].qs[j/4 + l] = L[j + l] | (L[j + l + 32] << 2) | (L[j + l + 64] << 4) | (L[j + l + 96] << 6);
1422
+ }
1423
+ }
1424
+ #else
1425
+ for (int l = 0; l < 16; ++l) {
1426
+ y[i].qs[l] = L[l] | (L[l + 16] << 2) | (L[l + 32] << 4) | (L[l + 48] << 6);
1427
+ }
1428
+ #endif
1429
+
1430
+ x += QK_K;
1431
+
1432
+ }
1433
+ }
1434
+
1435
+ void dequantize_row_q2_K(const block_q2_K * restrict x, float * restrict y, int k) {
1436
+ assert(k % QK_K == 0);
1437
+ const int nb = k / QK_K;
1438
+
1439
+ for (int i = 0; i < nb; i++) {
1440
+
1441
+ const float d = GGML_FP16_TO_FP32(x[i].d);
1442
+ const float min = GGML_FP16_TO_FP32(x[i].dmin);
1443
+
1444
+ const uint8_t * q = x[i].qs;
1445
+
1446
+ #if QK_K == 256
1447
+ int is = 0;
1448
+ float dl, ml;
1449
+ for (int n = 0; n < QK_K; n += 128) {
1450
+ int shift = 0;
1451
+ for (int j = 0; j < 4; ++j) {
1452
+
1453
+ uint8_t sc = x[i].scales[is++];
1454
+ dl = d * (sc & 0xF); ml = min * (sc >> 4);
1455
+ for (int l = 0; l < 16; ++l) *y++ = dl * ((int8_t)((q[l] >> shift) & 3)) - ml;
1456
+
1457
+ sc = x[i].scales[is++];
1458
+ dl = d * (sc & 0xF); ml = min * (sc >> 4);
1459
+ for (int l = 0; l < 16; ++l) *y++ = dl * ((int8_t)((q[l+16] >> shift) & 3)) - ml;
1460
+
1461
+ shift += 2;
1462
+ }
1463
+ q += 32;
1464
+ }
1465
+ #else
1466
+ float dl1 = d * (x[i].scales[0] & 0xF), ml1 = min * (x[i].scales[0] >> 4);
1467
+ float dl2 = d * (x[i].scales[1] & 0xF), ml2 = min * (x[i].scales[1] >> 4);
1468
+ float dl3 = d * (x[i].scales[2] & 0xF), ml3 = min * (x[i].scales[2] >> 4);
1469
+ float dl4 = d * (x[i].scales[3] & 0xF), ml4 = min * (x[i].scales[3] >> 4);
1470
+ for (int l = 0; l < 16; ++l) {
1471
+ y[l+ 0] = dl1 * ((int8_t)((q[l] >> 0) & 3)) - ml1;
1472
+ y[l+16] = dl2 * ((int8_t)((q[l] >> 2) & 3)) - ml2;
1473
+ y[l+32] = dl3 * ((int8_t)((q[l] >> 4) & 3)) - ml3;
1474
+ y[l+48] = dl4 * ((int8_t)((q[l] >> 6) & 3)) - ml4;
1475
+ }
1476
+ y += QK_K;
1477
+ #endif
1478
+ }
1479
+ }
1480
+
1481
+ void quantize_row_q2_K(const float * restrict x, void * restrict vy, int k) {
1482
+ quantize_row_q2_K_reference(x, vy, k);
1483
+ }
1484
+
1485
+ size_t ggml_quantize_q2_K(const float * restrict src, void * restrict dst, int n, int k, int64_t * restrict hist) {
1486
+ (void)hist; // TODO: collect histograms
1487
+
1488
+ for (int j = 0; j < n; j += k) {
1489
+ block_q2_K * restrict y = (block_q2_K *)dst + j/QK_K;
1490
+ quantize_row_q2_K_reference(src + j, y, k);
1491
+ }
1492
+ return (n/QK_K*sizeof(block_q2_K));
1493
+ }
1494
+
1495
+ //========================= 3-bit (de)-quantization
1496
+
1497
+ void quantize_row_q3_K_reference(const float * restrict x, block_q3_K * restrict y, int k) {
1498
+ assert(k % QK_K == 0);
1499
+ const int nb = k / QK_K;
1500
+
1501
+ int8_t L[QK_K];
1502
+ float scales[QK_K / 16];
1503
+
1504
+ for (int i = 0; i < nb; i++) {
1505
+
1506
+ float max_scale = 0;
1507
+ float amax = 0;
1508
+ for (int j = 0; j < QK_K/16; ++j) {
1509
+ scales[j] = make_q3_quants(16, 4, x + 16*j, L + 16*j, true);
1510
+ float scale = fabsf(scales[j]);
1511
+ if (scale > amax) {
1512
+ amax = scale; max_scale = scales[j];
1513
+ }
1514
+ }
1515
+
1516
+ #if QK_K == 256
1517
+ memset(y[i].scales, 0, 12);
1518
+ if (max_scale) {
1519
+ float iscale = -32.f/max_scale;
1520
+ for (int j = 0; j < QK_K/16; ++j) {
1521
+ int8_t l = nearest_int(iscale*scales[j]);
1522
+ l = MAX(-32, MIN(31, l)) + 32;
1523
+ if (j < 8) {
1524
+ y[i].scales[j] = l & 0xF;
1525
+ } else {
1526
+ y[i].scales[j-8] |= ((l & 0xF) << 4);
1527
+ }
1528
+ l >>= 4;
1529
+ y[i].scales[j%4 + 8] |= (l << (2*(j/4)));
1530
+ }
1531
+ y[i].d = GGML_FP32_TO_FP16(1/iscale);
1532
+ } else {
1533
+ y[i].d = GGML_FP32_TO_FP16(0.f);
1534
+ }
1535
+
1536
+ int8_t sc;
1537
+ for (int j = 0; j < QK_K/16; ++j) {
1538
+ sc = j < 8 ? y[i].scales[j] & 0xF : y[i].scales[j-8] >> 4;
1539
+ sc = (sc | (((y[i].scales[8 + j%4] >> (2*(j/4))) & 3) << 4)) - 32;
1540
+ float d = GGML_FP16_TO_FP32(y[i].d) * sc;
1541
+ if (!d) {
1542
+ continue;
1543
+ }
1544
+ for (int ii = 0; ii < 16; ++ii) {
1545
+ int l = nearest_int(x[16*j + ii]/d);
1546
+ l = MAX(-4, MIN(3, l));
1547
+ L[16*j + ii] = l + 4;
1548
+ }
1549
+ }
1550
+ #else
1551
+ if (max_scale) {
1552
+ float iscale = -8.f/max_scale;
1553
+ for (int j = 0; j < QK_K/16; j+=2) {
1554
+ int l1 = nearest_int(iscale*scales[j]);
1555
+ l1 = 8 + MAX(-8, MIN(7, l1));
1556
+ int l2 = nearest_int(iscale*scales[j+1]);
1557
+ l2 = 8 + MAX(-8, MIN(7, l2));
1558
+ y[i].scales[j/2] = l1 | (l2 << 4);
1559
+ }
1560
+ y[i].d = GGML_FP32_TO_FP16(1/iscale);
1561
+ } else {
1562
+ for (int j = 0; j < QK_K/16; j+=2) {
1563
+ y[i].scales[j/2] = 0;
1564
+ }
1565
+ y[i].d = GGML_FP32_TO_FP16(0.f);
1566
+ }
1567
+ for (int j = 0; j < QK_K/16; ++j) {
1568
+ int s = j%2 == 0 ? y[i].scales[j/2] & 0xF : y[i].scales[j/2] >> 4;
1569
+ float d = GGML_FP16_TO_FP32(y[i].d) * (s - 8);
1570
+ if (!d) {
1571
+ continue;
1572
+ }
1573
+ for (int ii = 0; ii < 16; ++ii) {
1574
+ int l = nearest_int(x[16*j + ii]/d);
1575
+ l = MAX(-4, MIN(3, l));
1576
+ L[16*j + ii] = l + 4;
1577
+ }
1578
+ }
1579
+ #endif
1580
+
1581
+ memset(y[i].hmask, 0, QK_K/8);
1582
+ // We put the high-bit for the 1st 8 quants into bit 0, the next 8 into bit 1, etc.
1583
+ int m = 0;
1584
+ uint8_t hm = 1;
1585
+ for (int j = 0; j < QK_K; ++j) {
1586
+ if (L[j] > 3) {
1587
+ y[i].hmask[m] |= hm;
1588
+ L[j] -= 4;
1589
+ }
1590
+ if (++m == QK_K/8) {
1591
+ m = 0; hm <<= 1;
1592
+ }
1593
+ }
1594
+ #if QK_K == 256
1595
+ for (int j = 0; j < QK_K; j += 128) {
1596
+ for (int l = 0; l < 32; ++l) {
1597
+ y[i].qs[j/4 + l] = L[j + l] | (L[j + l + 32] << 2) | (L[j + l + 64] << 4) | (L[j + l + 96] << 6);
1598
+ }
1599
+ }
1600
+ #else
1601
+ for (int l = 0; l < 16; ++l) {
1602
+ y[i].qs[l] = L[l] | (L[l + 16] << 2) | (L[l + 32] << 4) | (L[l + 48] << 6);
1603
+ }
1604
+ #endif
1605
+
1606
+ x += QK_K;
1607
+ }
1608
+ }
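An illustrative helper (editorial addition) for the hmask layout produced above: the packing loop stores the high bit of quant j in byte j % (QK_K/8), bit j / (QK_K/8). It assumes QK_K and <stdint.h> are in scope.

// Sketch: recover the high (3rd) bit of quant index j from the packed hmask.
static inline int q3_K_high_bit(const uint8_t * hmask, int j) {
    return (hmask[j % (QK_K/8)] >> (j / (QK_K/8))) & 1;
}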
1609
+
1610
+ #if QK_K == 256
1611
+ void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int k) {
1612
+ assert(k % QK_K == 0);
1613
+ const int nb = k / QK_K;
1614
+
1615
+ const uint32_t kmask1 = 0x03030303;
1616
+ const uint32_t kmask2 = 0x0f0f0f0f;
1617
+
1618
+ uint32_t aux[4];
1619
+ const int8_t * scales = (const int8_t*)aux;
1620
+
1621
+ for (int i = 0; i < nb; i++) {
1622
+
1623
+ const float d_all = GGML_FP16_TO_FP32(x[i].d);
1624
+
1625
+ const uint8_t * restrict q = x[i].qs;
1626
+ const uint8_t * restrict hm = x[i].hmask;
1627
+ uint8_t m = 1;
1628
+
1629
+ memcpy(aux, x[i].scales, 12);
1630
+ uint32_t tmp = aux[2];
1631
+ aux[2] = ((aux[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);
1632
+ aux[3] = ((aux[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4);
1633
+ aux[0] = (aux[0] & kmask2) | (((tmp >> 0) & kmask1) << 4);
1634
+ aux[1] = (aux[1] & kmask2) | (((tmp >> 2) & kmask1) << 4);
1635
+
1636
+ int is = 0;
1637
+ float dl;
1638
+ for (int n = 0; n < QK_K; n += 128) {
1639
+ int shift = 0;
1640
+ for (int j = 0; j < 4; ++j) {
1641
+
1642
+ dl = d_all * (scales[is++] - 32);
1643
+ for (int l = 0; l < 16; ++l) {
1644
+ *y++ = dl * ((int8_t)((q[l+ 0] >> shift) & 3) - ((hm[l+ 0] & m) ? 0 : 4));
1645
+ }
1646
+
1647
+ dl = d_all * (scales[is++] - 32);
1648
+ for (int l = 0; l < 16; ++l) {
1649
+ *y++ = dl * ((int8_t)((q[l+16] >> shift) & 3) - ((hm[l+16] & m) ? 0 : 4));
1650
+ }
1651
+
1652
+ shift += 2;
1653
+ m <<= 1;
1654
+ }
1655
+ q += 32;
1656
+ }
1657
+
1658
+ }
1659
+ }
1660
+ #else
1661
+ void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int k) {
1662
+ assert(k % QK_K == 0);
1663
+ assert(QK_K == 64);
1664
+ const int nb = k / QK_K;
1665
+
1666
+ for (int i = 0; i < nb; i++) {
1667
+
1668
+ const float d_all = GGML_FP16_TO_FP32(x[i].d);
1669
+
1670
+ const uint8_t * restrict q = x[i].qs;
1671
+ const uint8_t * restrict hm = x[i].hmask;
1672
+
1673
+ const float d1 = d_all * ((x[i].scales[0] & 0xF) - 8);
1674
+ const float d2 = d_all * ((x[i].scales[0] >> 4) - 8);
1675
+ const float d3 = d_all * ((x[i].scales[1] & 0xF) - 8);
1676
+ const float d4 = d_all * ((x[i].scales[1] >> 4) - 8);
1677
+
1678
+ for (int l=0; l<8; ++l) {
1679
+ uint8_t h = hm[l];
1680
+ y[l+ 0] = d1 * ((int8_t)((q[l+0] >> 0) & 3) - ((h & 0x01) ? 0 : 4));
1681
+ y[l+ 8] = d1 * ((int8_t)((q[l+8] >> 0) & 3) - ((h & 0x02) ? 0 : 4));
1682
+ y[l+16] = d2 * ((int8_t)((q[l+0] >> 2) & 3) - ((h & 0x04) ? 0 : 4));
1683
+ y[l+24] = d2 * ((int8_t)((q[l+8] >> 2) & 3) - ((h & 0x08) ? 0 : 4));
1684
+ y[l+32] = d3 * ((int8_t)((q[l+0] >> 4) & 3) - ((h & 0x10) ? 0 : 4));
1685
+ y[l+40] = d3 * ((int8_t)((q[l+8] >> 4) & 3) - ((h & 0x20) ? 0 : 4));
1686
+ y[l+48] = d4 * ((int8_t)((q[l+0] >> 6) & 3) - ((h & 0x40) ? 0 : 4));
1687
+ y[l+56] = d4 * ((int8_t)((q[l+8] >> 6) & 3) - ((h & 0x80) ? 0 : 4));
1688
+ }
1689
+ y += QK_K;
1690
+ }
1691
+ }
1692
+ #endif
1693
+
1694
+ void quantize_row_q3_K(const float * restrict x, void * restrict vy, int k) {
1695
+ quantize_row_q3_K_reference(x, vy, k);
1696
+ }
1697
+
1698
+ size_t ggml_quantize_q3_K(const float * restrict src, void * restrict dst, int n, int k, int64_t * restrict hist) {
1699
+ (void)hist; // TODO: collect histograms
1700
+
1701
+ for (int j = 0; j < n; j += k) {
1702
+ block_q3_K * restrict y = (block_q3_K *)dst + j/QK_K;
1703
+ quantize_row_q3_K_reference(src + j, y, k);
1704
+ }
1705
+ return (n/QK_K*sizeof(block_q3_K));
1706
+ }
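For reference, a standalone sketch (editorial addition) of the 12-byte q3_K scale layout written in the QK_K == 256 branch above; it mirrors the read-back the quantizer itself performs (low 4 bits in the first 8 bytes, top 2 bits packed into the last 4 bytes) and assumes QK_K == 256 with j in [0, QK_K/16).

// Sketch: decode the 6-bit sub-block scale j and re-center it by -32, as the quantizer does.
static inline int q3_K_scale(const uint8_t * scales, int j) {
    int sc = j < 8 ? (scales[j] & 0xF) : (scales[j - 8] >> 4);
    sc |= ((scales[8 + j%4] >> (2*(j/4))) & 3) << 4;
    return sc - 32;
}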
1707
+
1708
+ // ====================== 4-bit (de)-quantization
1709
+
1710
+ void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict y, int k) {
1711
+ assert(k % QK_K == 0);
1712
+ const int nb = k / QK_K;
1713
+
1714
+ uint8_t L[QK_K];
1715
+ uint8_t Laux[32];
1716
+ float weights[32];
1717
+ float mins[QK_K/32];
1718
+ float scales[QK_K/32];
1719
+
1720
+ for (int i = 0; i < nb; i++) {
1721
+
1722
+ float max_scale = 0; // as we are deducting the min, scales are always positive
1723
+ float max_min = 0;
1724
+ for (int j = 0; j < QK_K/32; ++j) {
1725
+ //scales[j] = make_qkx1_quants(32, 15, x + 32*j, L + 32*j, &mins[j], 9, 0.5f);
1726
+ float sum_x2 = 0;
1727
+ for (int l = 0; l < 32; ++l) sum_x2 += x[32*j + l] * x[32*j + l];
1728
+ float av_x = sqrtf(sum_x2/32);
1729
+ for (int l = 0; l < 32; ++l) weights[l] = av_x + fabsf(x[32*j + l]);
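// note (editorial): the weights here are w_l = sqrt(sum(x^2)/32) + |x_l|, i.e. each value's
// own magnitude plus the sub-block RMS, so the weighted fit in make_qkx2_quants presumably
// lets the larger entries have more influence on the chosen scale and min.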
1730
+ scales[j] = make_qkx2_quants(32, 15, x + 32*j, weights, L + 32*j, &mins[j], Laux, -1.f, 0.1f, 20, false);
1731
+ float scale = scales[j];
1732
+ if (scale > max_scale) {
1733
+ max_scale = scale;
1734
+ }
1735
+ float min = mins[j];
1736
+ if (min > max_min) {
1737
+ max_min = min;
1738
+ }
1739
+ }
1740
+
1741
+ #if QK_K == 256
1742
+ float inv_scale = max_scale > 0 ? 63.f/max_scale : 0.f;
1743
+ float inv_min = max_min > 0 ? 63.f/max_min : 0.f;
1744
+ for (int j = 0; j < QK_K/32; ++j) {
1745
+ uint8_t ls = nearest_int(inv_scale*scales[j]);
1746
+ uint8_t lm = nearest_int(inv_min*mins[j]);
1747
+ ls = MIN(63, ls);
1748
+ lm = MIN(63, lm);
1749
+ if (j < 4) {
1750
+ y[i].scales[j] = ls;
1751
+ y[i].scales[j+4] = lm;
1752
+ } else {
1753
+ y[i].scales[j+4] = (ls & 0xF) | ((lm & 0xF) << 4);
1754
+ y[i].scales[j-4] |= ((ls >> 4) << 6);
1755
+ y[i].scales[j-0] |= ((lm >> 4) << 6);
1756
+ }
1757
+ }
1758
+ y[i].d = GGML_FP32_TO_FP16(max_scale/63.f);
1759
+ y[i].dmin = GGML_FP32_TO_FP16(max_min/63.f);
1760
+
1761
+ uint8_t sc, m;
1762
+ for (int j = 0; j < QK_K/32; ++j) {
1763
+ get_scale_min_k4(j, y[i].scales, &sc, &m);
1764
+ const float d = GGML_FP16_TO_FP32(y[i].d) * sc;
1765
+ if (!d) continue;
1766
+ const float dm = GGML_FP16_TO_FP32(y[i].dmin) * m;
1767
+ for (int ii = 0; ii < 32; ++ii) {
1768
+ int l = nearest_int((x[32*j + ii] + dm)/d);
1769
+ l = MAX(0, MIN(15, l));
1770
+ L[32*j + ii] = l;
1771
+ }
1772
+ }
1773
+ #else
1774
+ const float s_factor = 15.f;
1775
+ float inv_scale = max_scale > 0 ? s_factor/max_scale : 0.f;
1776
+ float inv_min = max_min > 0 ? s_factor/max_min : 0.f;
1777
+ int d1 = nearest_int(inv_scale*scales[0]);
1778
+ int m1 = nearest_int(inv_min*mins[0]);
1779
+ int d2 = nearest_int(inv_scale*scales[1]);
1780
+ int m2 = nearest_int(inv_min*mins[1]);
1781
+ y[i].scales[0] = d1 | (m1 << 4);
1782
+ y[i].scales[1] = d2 | (m2 << 4);
1783
+ y[i].d[0] = GGML_FP32_TO_FP16(max_scale/s_factor);
1784
+ y[i].d[1] = GGML_FP32_TO_FP16(max_min/s_factor);
1785
+
1786
+ float sumlx = 0;
1787
+ int suml2 = 0;
1788
+ for (int j = 0; j < QK_K/32; ++j) {
1789
+ const uint8_t sd = y[i].scales[j] & 0xF;
1790
+ const uint8_t sm = y[i].scales[j] >> 4;
1791
+ const float d = GGML_FP16_TO_FP32(y[i].d[0]) * sd;
1792
+ if (!d) continue;
1793
+ const float m = GGML_FP16_TO_FP32(y[i].d[1]) * sm;
1794
+ for (int ii = 0; ii < 32; ++ii) {
1795
+ int l = nearest_int((x[32*j + ii] + m)/d);
1796
+ l = MAX(0, MIN(15, l));
1797
+ L[32*j + ii] = l;
1798
+ sumlx += (x[32*j + ii] + m)*l*sd;
1799
+ suml2 += l*l*sd*sd;
1800
+ }
1801
+ }
1802
+ if (suml2) {
1803
+ y[i].d[0] = GGML_FP32_TO_FP16(sumlx/suml2);
1804
+ }
1805
+ #endif
1806
+ uint8_t * q = y[i].qs;
1807
+ for (int j = 0; j < QK_K; j += 64) {
1808
+ for (int l = 0; l < 32; ++l) q[l] = L[j + l] | (L[j + l + 32] << 4);
1809
+ q += 32;
1810
+ }
1811
+
1812
+ x += QK_K;
1813
+
1814
+ }
1815
+ }
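The 6-bit scale/min packing in the QK_K == 256 branch above is the layout that get_scale_min_k4 (called a few lines back) reads; a hedged standalone sketch of it, assuming QK_K == 256, j in [0, QK_K/32) and <stdint.h> in scope.

// Sketch: sub-blocks 0..3 keep scale j in scales[j] and min j in scales[j+4] (6 bits each);
// sub-blocks 4..7 keep their low 4 bits in scales[j+4] and their top 2 bits in the spare
// high bits of scales[j-4] (scale) and scales[j] (min).
static inline void q4_K_get_scale_min(const uint8_t * scales, int j, uint8_t * sc, uint8_t * m) {
    if (j < 4) {
        *sc = scales[j]     & 63;
        *m  = scales[j + 4] & 63;
    } else {
        *sc = (scales[j + 4] & 0xF) | ((scales[j - 4] >> 6) << 4);
        *m  = (scales[j + 4] >>  4) | ((scales[j    ] >> 6) << 4);
    }
}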
1816
+
1817
+ void dequantize_row_q4_K(const block_q4_K * restrict x, float * restrict y, int k) {
1818
+ assert(k % QK_K == 0);
1819
+ const int nb = k / QK_K;
1820
+
1821
+ for (int i = 0; i < nb; i++) {
1822
+
1823
+ const uint8_t * q = x[i].qs;
1824
+
1825
+ #if QK_K == 256
1826
+
1827
+ const float d = GGML_FP16_TO_FP32(x[i].d);
1828
+ const float min = GGML_FP16_TO_FP32(x[i].dmin);
1829
+
1830
+ int is = 0;
1831
+ uint8_t sc, m;
1832
+ for (int j = 0; j < QK_K; j += 64) {
1833
+ get_scale_min_k4(is + 0, x[i].scales, &sc, &m);
1834
+ const float d1 = d * sc; const float m1 = min * m;
1835
+ get_scale_min_k4(is + 1, x[i].scales, &sc, &m);
1836
+ const float d2 = d * sc; const float m2 = min * m;
1837
+ for (int l = 0; l < 32; ++l) *y++ = d1 * (q[l] & 0xF) - m1;
1838
+ for (int l = 0; l < 32; ++l) *y++ = d2 * (q[l] >> 4) - m2;
1839
+ q += 32; is += 2;
1840
+ }
1841
+ #else
1842
+ const float dall = GGML_FP16_TO_FP32(x[i].d[0]);
1843
+ const float mall = GGML_FP16_TO_FP32(x[i].d[1]);
1844
+ const float d1 = dall * (x[i].scales[0] & 0xF), m1 = mall * (x[i].scales[0] >> 4);
1845
+ const float d2 = dall * (x[i].scales[1] & 0xF), m2 = mall * (x[i].scales[1] >> 4);
1846
+ for (int l = 0; l < 32; ++l) {
1847
+ y[l+ 0] = d1 * (q[l] & 0xF) - m1;
1848
+ y[l+32] = d2 * (q[l] >> 4) - m2;
1849
+ }
1850
+ y += QK_K;
1851
+ #endif
1852
+
1853
+ }
1854
+ }
1855
+
1856
+ void quantize_row_q4_K(const float * restrict x, void * restrict vy, int k) {
1857
+ assert(k % QK_K == 0);
1858
+ block_q4_K * restrict y = vy;
1859
+ quantize_row_q4_K_reference(x, y, k);
1860
+ }
1861
+
1862
+ size_t ggml_quantize_q4_K(const float * restrict src, void * restrict dst, int n, int k, int64_t * restrict hist) {
1863
+ assert(k % QK_K == 0);
1864
+ (void)hist; // TODO: collect histograms
1865
+
1866
+ for (int j = 0; j < n; j += k) {
1867
+ block_q4_K * restrict y = (block_q4_K *)dst + j/QK_K;
1868
+ quantize_row_q4_K_reference(src + j, y, k);
1869
+ }
1870
+ return (n/QK_K*sizeof(block_q4_K));
1871
+ }
1872
+
1873
+ // ====================== 5-bit (de)-quantization
1874
+
1875
+ void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict y, int k) {
1876
+ assert(k % QK_K == 0);
1877
+ const int nb = k / QK_K;
1878
+
1879
+ #if QK_K == 256
1880
+ uint8_t L[QK_K];
1881
+ float mins[QK_K/32];
1882
+ float scales[QK_K/32];
1883
+ float weights[32];
1884
+ uint8_t Laux[32];
1885
+ #else
1886
+ int8_t L[QK_K];
1887
+ float scales[QK_K/16];
1888
+ #endif
1889
+
1890
+ for (int i = 0; i < nb; i++) {
1891
+
1892
+ #if QK_K == 256
1893
+
1894
+ float max_scale = 0; // as we are deducting the min, scales are always positive
1895
+ float max_min = 0;
1896
+ for (int j = 0; j < QK_K/32; ++j) {
1897
+ //scales[j] = make_qkx1_quants(32, 31, x + 32*j, L + 32*j, &mins[j], 9, 0.5f);
1898
+ float sum_x2 = 0;
1899
+ for (int l = 0; l < 32; ++l) sum_x2 += x[32*j + l] * x[32*j + l];
1900
+ float av_x = sqrtf(sum_x2/32);
1901
+ for (int l = 0; l < 32; ++l) weights[l] = av_x + fabsf(x[32*j + l]);
1902
+ scales[j] = make_qkx2_quants(32, 31, x + 32*j, weights, L + 32*j, &mins[j], Laux, -0.5f, 0.1f, 15, false);
1903
+ float scale = scales[j];
1904
+ if (scale > max_scale) {
1905
+ max_scale = scale;
1906
+ }
1907
+ float min = mins[j];
1908
+ if (min > max_min) {
1909
+ max_min = min;
1910
+ }
1911
+ }
1912
+
1913
+ float inv_scale = max_scale > 0 ? 63.f/max_scale : 0.f;
1914
+ float inv_min = max_min > 0 ? 63.f/max_min : 0.f;
1915
+ for (int j = 0; j < QK_K/32; ++j) {
1916
+ uint8_t ls = nearest_int(inv_scale*scales[j]);
1917
+ uint8_t lm = nearest_int(inv_min*mins[j]);
1918
+ ls = MIN(63, ls);
1919
+ lm = MIN(63, lm);
1920
+ if (j < 4) {
1921
+ y[i].scales[j] = ls;
1922
+ y[i].scales[j+4] = lm;
1923
+ } else {
1924
+ y[i].scales[j+4] = (ls & 0xF) | ((lm & 0xF) << 4);
1925
+ y[i].scales[j-4] |= ((ls >> 4) << 6);
1926
+ y[i].scales[j-0] |= ((lm >> 4) << 6);
1927
+ }
1928
+ }
1929
+ y[i].d = GGML_FP32_TO_FP16(max_scale/63.f);
1930
+ y[i].dmin = GGML_FP32_TO_FP16(max_min/63.f);
1931
+
1932
+ uint8_t sc, m;
1933
+ for (int j = 0; j < QK_K/32; ++j) {
1934
+ get_scale_min_k4(j, y[i].scales, &sc, &m);
1935
+ const float d = GGML_FP16_TO_FP32(y[i].d) * sc;
1936
+ if (!d) continue;
1937
+ const float dm = GGML_FP16_TO_FP32(y[i].dmin) * m;
1938
+ for (int ii = 0; ii < 32; ++ii) {
1939
+ int l = nearest_int((x[32*j + ii] + dm)/d);
1940
+ l = MAX(0, MIN(31, l));
1941
+ L[32*j + ii] = l;
1942
+ }
1943
+ }
1944
+
1945
+ uint8_t * restrict qh = y[i].qh;
1946
+ uint8_t * restrict ql = y[i].qs;
1947
+ memset(qh, 0, QK_K/8);
1948
+
1949
+ uint8_t m1 = 1, m2 = 2;
1950
+ for (int n = 0; n < QK_K; n += 64) {
1951
+ for (int j = 0; j < 32; ++j) {
1952
+ int l1 = L[n + j];
1953
+ if (l1 > 15) {
1954
+ l1 -= 16; qh[j] |= m1;
1955
+ }
1956
+ int l2 = L[n + j + 32];
1957
+ if (l2 > 15) {
1958
+ l2 -= 16; qh[j] |= m2;
1959
+ }
1960
+ ql[j] = l1 | (l2 << 4);
1961
+ }
1962
+ m1 <<= 2; m2 <<= 2;
1963
+ ql += 32;
1964
+ }
1965
+ #else
1966
+ float max_scale = 0, amax = 0;
1967
+ for (int j = 0; j < QK_K/16; ++j) {
1968
+ scales[j] = make_qx_quants(16, 16, x + 16*j, L + 16*j, 1);
1969
+ float abs_scale = fabsf(scales[j]);
1970
+ if (abs_scale > amax) {
1971
+ amax = abs_scale;
1972
+ max_scale = scales[j];
1973
+ }
1974
+ }
1975
+
1976
+ float iscale = -128.f/max_scale;
1977
+ for (int j = 0; j < QK_K/16; ++j) {
1978
+ int l = nearest_int(iscale*scales[j]);
1979
+ y[i].scales[j] = MAX(-128, MIN(127, l));
1980
+ }
1981
+ y[i].d = GGML_FP32_TO_FP16(1/iscale);
1982
+
1983
+ for (int j = 0; j < QK_K/16; ++j) {
1984
+ const float d = GGML_FP16_TO_FP32(y[i].d) * y[i].scales[j];
1985
+ if (!d) continue;
1986
+ for (int ii = 0; ii < 16; ++ii) {
1987
+ int l = nearest_int(x[16*j + ii]/d);
1988
+ l = MAX(-16, MIN(15, l));
1989
+ L[16*j + ii] = l + 16;
1990
+ }
1991
+ }
1992
+
1993
+ uint8_t * restrict qh = y[i].qh;
1994
+ uint8_t * restrict ql = y[i].qs;
1995
+ memset(qh, 0, QK_K/8);
1996
+
1997
+ for (int j = 0; j < 32; ++j) {
1998
+ int jm = j%8;
1999
+ int is = j/8;
2000
+ int l1 = L[j];
2001
+ if (l1 > 15) {
2002
+ l1 -= 16; qh[jm] |= (1 << is);
2003
+ }
2004
+ int l2 = L[j + 32];
2005
+ if (l2 > 15) {
2006
+ l2 -= 16; qh[jm] |= (1 << (4 + is));
2007
+ }
2008
+ ql[j] = l1 | (l2 << 4);
2009
+ }
2010
+ #endif
2011
+
2012
+ x += QK_K;
2013
+
2014
+ }
2015
+ }
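An illustrative helper (editorial addition) for the QK_K == 256 qh packing above: the 5th bit of quant j lands in qh[j % 32], bit j / 32, which is what the dequantizer below undoes with its shifting u1/u2 masks. Assumes QK_K == 256 and <stdint.h> in scope.

// Sketch: recover the 5th bit of quant index j (0..QK_K-1) from the packed qh.
static inline int q5_K_high_bit(const uint8_t * qh, int j) {
    return (qh[j % 32] >> (j / 32)) & 1;
}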
2016
+
2017
+ void dequantize_row_q5_K(const block_q5_K * restrict x, float * restrict y, int k) {
2018
+ assert(k % QK_K == 0);
2019
+ const int nb = k / QK_K;
2020
+
2021
+ for (int i = 0; i < nb; i++) {
2022
+
2023
+ const uint8_t * ql = x[i].qs;
2024
+ const uint8_t * qh = x[i].qh;
2025
+
2026
+ #if QK_K == 256
2027
+
2028
+ const float d = GGML_FP16_TO_FP32(x[i].d);
2029
+ const float min = GGML_FP16_TO_FP32(x[i].dmin);
2030
+
2031
+ int is = 0;
2032
+ uint8_t sc, m;
2033
+ uint8_t u1 = 1, u2 = 2;
2034
+ for (int j = 0; j < QK_K; j += 64) {
2035
+ get_scale_min_k4(is + 0, x[i].scales, &sc, &m);
2036
+ const float d1 = d * sc; const float m1 = min * m;
2037
+ get_scale_min_k4(is + 1, x[i].scales, &sc, &m);
2038
+ const float d2 = d * sc; const float m2 = min * m;
2039
+ for (int l = 0; l < 32; ++l) *y++ = d1 * ((ql[l] & 0xF) + (qh[l] & u1 ? 16 : 0)) - m1;
2040
+ for (int l = 0; l < 32; ++l) *y++ = d2 * ((ql[l] >> 4) + (qh[l] & u2 ? 16 : 0)) - m2;
2041
+ ql += 32; is += 2;
2042
+ u1 <<= 2; u2 <<= 2;
2043
+ }
2044
+ #else
2045
+ float d = GGML_FP16_TO_FP32(x[i].d);
2046
+ const int8_t * restrict s = x[i].scales;
2047
+ for (int l = 0; l < 8; ++l) {
2048
+ y[l+ 0] = d * s[0] * ((ql[l+ 0] & 0xF) - (qh[l] & 0x01 ? 0 : 16));
2049
+ y[l+ 8] = d * s[0] * ((ql[l+ 8] & 0xF) - (qh[l] & 0x02 ? 0 : 16));
2050
+ y[l+16] = d * s[1] * ((ql[l+16] & 0xF) - (qh[l] & 0x04 ? 0 : 16));
2051
+ y[l+24] = d * s[1] * ((ql[l+24] & 0xF) - (qh[l] & 0x08 ? 0 : 16));
2052
+ y[l+32] = d * s[2] * ((ql[l+ 0] >> 4) - (qh[l] & 0x10 ? 0 : 16));
2053
+ y[l+40] = d * s[2] * ((ql[l+ 8] >> 4) - (qh[l] & 0x20 ? 0 : 16));
2054
+ y[l+48] = d * s[3] * ((ql[l+16] >> 4) - (qh[l] & 0x40 ? 0 : 16));
2055
+ y[l+56] = d * s[3] * ((ql[l+24] >> 4) - (qh[l] & 0x80 ? 0 : 16));
2056
+ }
2057
+ y += QK_K;
2058
+ #endif
2059
+ }
2060
+ }
2061
+
2062
+ void quantize_row_q5_K(const float * restrict x, void * restrict vy, int k) {
2063
+ assert(k % QK_K == 0);
2064
+ block_q5_K * restrict y = vy;
2065
+ quantize_row_q5_K_reference(x, y, k);
2066
+ }
2067
+
2068
+ size_t ggml_quantize_q5_K(const float * restrict src, void * restrict dst, int n, int k, int64_t * restrict hist) {
2069
+ assert(k % QK_K == 0);
2070
+ (void)hist; // TODO: collect histograms
2071
+
2072
+ for (int j = 0; j < n; j += k) {
2073
+ block_q5_K * restrict y = (block_q5_K *)dst + j/QK_K;
2074
+ quantize_row_q5_K_reference(src + j, y, k);
2075
+ }
2076
+ return (n/QK_K*sizeof(block_q5_K));
2077
+ }
2078
+
2079
+ // ====================== 6-bit (de)-quantization
2080
+
2081
+ void quantize_row_q6_K_reference(const float * restrict x, block_q6_K * restrict y, int k) {
2082
+ assert(k % QK_K == 0);
2083
+ const int nb = k / QK_K;
2084
+
2085
+ int8_t L[QK_K];
2086
+ float scales[QK_K/16];
2087
+
2088
+ for (int i = 0; i < nb; i++) {
2089
+
2090
+ float max_scale = 0;
2091
+ float max_abs_scale = 0;
2092
+
2093
+ for (int ib = 0; ib < QK_K/16; ++ib) {
2094
+
2095
+ const float scale = make_qx_quants(16, 32, x + 16*ib, L + 16*ib, 1);
2096
+ scales[ib] = scale;
2097
+
2098
+ const float abs_scale = fabsf(scale);
2099
+ if (abs_scale > max_abs_scale) {
2100
+ max_abs_scale = abs_scale;
2101
+ max_scale = scale;
2102
+ }
2103
+
2104
+ }
2105
+
2106
+ if (!max_abs_scale) {
2107
+ memset(&y[i], 0, sizeof(block_q6_K));
2108
+ y[i].d = GGML_FP32_TO_FP16(0.f);
2109
+ x += QK_K;
2110
+ continue;
2111
+ }
2112
+
2113
+ float iscale = -128.f/max_scale;
2114
+ y[i].d = GGML_FP32_TO_FP16(1/iscale);
2115
+ for (int ib = 0; ib < QK_K/16; ++ib) {
2116
+ y[i].scales[ib] = MIN(127, nearest_int(iscale*scales[ib]));
2117
+ }
2118
+
2119
+ for (int j = 0; j < QK_K/16; ++j) {
2120
+ float d = GGML_FP16_TO_FP32(y[i].d) * y[i].scales[j];
2121
+ if (!d) {
2122
+ continue;
187
2123
  }
188
- if (!n_changed) {
189
- break;
2124
+ for (int ii = 0; ii < 16; ++ii) {
2125
+ int l = nearest_int(x[16*j + ii]/d);
2126
+ l = MAX(-32, MIN(31, l));
2127
+ L[16*j + ii] = l + 32;
190
2128
  }
191
2129
  }
192
- for (int i = 0; i < n; ++i) {
193
- L[i] += nmax;
194
- }
195
- return sumlx / suml2;
196
- }
197
- for (int i = 0; i < n; ++i) {
198
- int l = nearest_int(iscale * x[i]);
199
- l = MAX(-nmax, MIN(nmax-1, l));
200
- L[i] = l + nmax;
201
- }
202
- return 1/iscale;
203
- }
204
2130
 
205
- static float make_qkx1_quants(int n, int nmax, const float * restrict x, uint8_t * restrict L, float * restrict the_min,
206
- int ntry, float alpha) {
207
- float min = x[0];
208
- float max = x[0];
209
- for (int i = 1; i < n; ++i) {
210
- if (x[i] < min) min = x[i];
211
- if (x[i] > max) max = x[i];
212
- }
213
- if (max == min) {
214
- for (int i = 0; i < n; ++i) L[i] = 0;
215
- *the_min = 0;
216
- return 0.f;
217
- }
218
- if (min > 0) min = 0;
219
- float iscale = nmax/(max - min);
220
- float scale = 1/iscale;
221
- for (int itry = 0; itry < ntry; ++itry) {
222
- float sumlx = 0; int suml2 = 0;
223
- bool did_change = false;
224
- for (int i = 0; i < n; ++i) {
225
- int l = nearest_int(iscale*(x[i] - min));
226
- l = MAX(0, MIN(nmax, l));
227
- if (l != L[i]) {
228
- L[i] = l;
229
- did_change = true;
2131
+ uint8_t * restrict ql = y[i].ql;
2132
+ uint8_t * restrict qh = y[i].qh;
2133
+ #if QK_K == 256
2134
+ for (int j = 0; j < QK_K; j += 128) {
2135
+ for (int l = 0; l < 32; ++l) {
2136
+ const uint8_t q1 = L[j + l + 0] & 0xF;
2137
+ const uint8_t q2 = L[j + l + 32] & 0xF;
2138
+ const uint8_t q3 = L[j + l + 64] & 0xF;
2139
+ const uint8_t q4 = L[j + l + 96] & 0xF;
2140
+ ql[l+ 0] = q1 | (q3 << 4);
2141
+ ql[l+32] = q2 | (q4 << 4);
2142
+ qh[l] = (L[j + l] >> 4) | ((L[j + l + 32] >> 4) << 2) | ((L[j + l + 64] >> 4) << 4) | ((L[j + l + 96] >> 4) << 6);
230
2143
  }
231
- sumlx += (x[i] - min)*l;
232
- suml2 += l*l;
2144
+ ql += 64;
2145
+ qh += 32;
233
2146
  }
234
- scale = sumlx/suml2;
235
- float sum = 0;
236
- for (int i = 0; i < n; ++i) {
237
- sum += x[i] - scale*L[i];
2147
+ #else
2148
+ for (int l = 0; l < 32; ++l) {
2149
+ const uint8_t q1 = L[l + 0] & 0xF;
2150
+ const uint8_t q2 = L[l + 32] & 0xF;
2151
+ ql[l] = q1 | (q2 << 4);
238
2152
  }
239
- min = alpha*min + (1 - alpha)*sum/n;
240
- if (min > 0) min = 0;
241
- iscale = 1/scale;
242
- if (!did_change) break;
2153
+ for (int l = 0; l < 16; ++l) {
2154
+ qh[l] = (L[l] >> 4) | ((L[l + 16] >> 4) << 2) | ((L[l + 32] >> 4) << 4) | ((L[l + 48] >> 4) << 6);
2155
+ }
2156
+ #endif
2157
+
2158
+ x += QK_K;
2159
+
243
2160
  }
244
- *the_min = -min;
245
- return scale;
246
2161
  }
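An illustrative decode (editorial addition) of the QK_K == 256 q6_K packing above, matching the dequantizer later in this diff: within each 128-quant group, lane l of sub-group g keeps its low 4 bits in ql[l + 32*(g & 1)] (low or high nibble) and its top 2 bits in qh[l] at bit position 2*g. Assumes <stdint.h> in scope.

// Sketch: recover the 6-bit quant (before the -32 re-centering) at position t of one
// 128-wide group, given that group's ql (64 bytes) and qh (32 bytes).
static inline int q6_K_quant(const uint8_t * ql, const uint8_t * qh, int t) {
    const int l = t % 32;                    // lane within the group
    const int g = t / 32;                    // sub-group 0..3
    const uint8_t b = ql[l + 32*(g & 1)];    // byte holding the low 4 bits
    const int lo = (g & 2) ? (b >> 4) : (b & 0xF);
    const int hi = (qh[l] >> (2*g)) & 3;     // top 2 bits
    return lo | (hi << 4);                   // 0..63
}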
247
2162
 
248
- static float make_qkx2_quants(int n, int nmax, const float * restrict x, const float * restrict weights,
249
- uint8_t * restrict L, float * restrict the_min, uint8_t * restrict Laux,
250
- float rmin, float rdelta, int nstep, bool use_mad) {
251
- float min = x[0];
252
- float max = x[0];
253
- float sum_w = weights[0];
254
- float sum_x = sum_w * x[0];
255
- for (int i = 1; i < n; ++i) {
256
- if (x[i] < min) min = x[i];
257
- if (x[i] > max) max = x[i];
258
- float w = weights[i];
259
- sum_w += w;
260
- sum_x += w * x[i];
261
- }
262
- if (min > 0) min = 0;
263
- if (max == min) {
264
- for (int i = 0; i < n; ++i) L[i] = 0;
265
- *the_min = -min;
266
- return 0.f;
267
- }
268
- float iscale = nmax/(max - min);
269
- float scale = 1/iscale;
270
- float best_mad = 0;
271
- for (int i = 0; i < n; ++i) {
272
- int l = nearest_int(iscale*(x[i] - min));
273
- L[i] = MAX(0, MIN(nmax, l));
274
- float diff = scale * L[i] + min - x[i];
275
- diff = use_mad ? fabsf(diff) : diff * diff;
276
- float w = weights[i];
277
- best_mad += w * diff;
278
- }
279
- if (nstep < 1) {
280
- *the_min = -min;
281
- return scale;
282
- }
283
- for (int is = 0; is <= nstep; ++is) {
284
- iscale = (rmin + rdelta*is + nmax)/(max - min);
285
- float sum_l = 0, sum_l2 = 0, sum_xl = 0;
286
- for (int i = 0; i < n; ++i) {
287
- int l = nearest_int(iscale*(x[i] - min));
288
- l = MAX(0, MIN(nmax, l));
289
- Laux[i] = l;
290
- float w = weights[i];
291
- sum_l += w*l;
292
- sum_l2 += w*l*l;
293
- sum_xl += w*l*x[i];
294
- }
295
- float D = sum_w * sum_l2 - sum_l * sum_l;
296
- if (D > 0) {
297
- float this_scale = (sum_w * sum_xl - sum_x * sum_l)/D;
298
- float this_min = (sum_l2 * sum_x - sum_l * sum_xl)/D;
299
- if (this_min > 0) {
300
- this_min = 0;
301
- this_scale = sum_xl / sum_l2;
302
- }
303
- float mad = 0;
304
- for (int i = 0; i < n; ++i) {
305
- float diff = this_scale * Laux[i] + this_min - x[i];
306
- diff = use_mad ? fabsf(diff) : diff * diff;
307
- float w = weights[i];
308
- mad += w * diff;
309
- }
310
- if (mad < best_mad) {
311
- for (int i = 0; i < n; ++i) {
312
- L[i] = Laux[i];
313
- }
314
- best_mad = mad;
315
- scale = this_scale;
316
- min = this_min;
2163
+ void dequantize_row_q6_K(const block_q6_K * restrict x, float * restrict y, int k) {
2164
+ assert(k % QK_K == 0);
2165
+ const int nb = k / QK_K;
2166
+
2167
+ for (int i = 0; i < nb; i++) {
2168
+
2169
+ const float d = GGML_FP16_TO_FP32(x[i].d);
2170
+
2171
+ const uint8_t * restrict ql = x[i].ql;
2172
+ const uint8_t * restrict qh = x[i].qh;
2173
+ const int8_t * restrict sc = x[i].scales;
2174
+
2175
+ #if QK_K == 256
2176
+ for (int n = 0; n < QK_K; n += 128) {
2177
+ for (int l = 0; l < 32; ++l) {
2178
+ int is = l/16;
2179
+ const int8_t q1 = (int8_t)((ql[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
2180
+ const int8_t q2 = (int8_t)((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
2181
+ const int8_t q3 = (int8_t)((ql[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
2182
+ const int8_t q4 = (int8_t)((ql[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
2183
+ y[l + 0] = d * sc[is + 0] * q1;
2184
+ y[l + 32] = d * sc[is + 2] * q2;
2185
+ y[l + 64] = d * sc[is + 4] * q3;
2186
+ y[l + 96] = d * sc[is + 6] * q4;
317
2187
  }
2188
+ y += 128;
2189
+ ql += 64;
2190
+ qh += 32;
2191
+ sc += 8;
2192
+ }
2193
+ #else
2194
+ for (int l = 0; l < 16; ++l) {
2195
+ const int8_t q1 = (int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
2196
+ const int8_t q2 = (int8_t)((ql[l+16] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
2197
+ const int8_t q3 = (int8_t)((ql[l+ 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
2198
+ const int8_t q4 = (int8_t)((ql[l+16] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
2199
+ y[l+ 0] = d * sc[0] * q1;
2200
+ y[l+16] = d * sc[1] * q2;
2201
+ y[l+32] = d * sc[2] * q3;
2202
+ y[l+48] = d * sc[3] * q4;
318
2203
  }
2204
+ y += 64;
2205
+ #endif
2206
+
319
2207
  }
320
- *the_min = -min;
321
- return scale;
322
2208
  }
323
2209
 
324
- #if QK_K == 256
325
- static inline void get_scale_min_k4(int j, const uint8_t * restrict q, uint8_t * restrict d, uint8_t * restrict m) {
326
- if (j < 4) {
327
- *d = q[j] & 63; *m = q[j + 4] & 63;
328
- } else {
329
- *d = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4);
330
- *m = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4);
2210
+ void quantize_row_q6_K(const float * restrict x, void * restrict vy, int k) {
2211
+ assert(k % QK_K == 0);
2212
+ block_q6_K * restrict y = vy;
2213
+ quantize_row_q6_K_reference(x, y, k);
2214
+ }
2215
+
2216
+ size_t ggml_quantize_q6_K(const float * src, void * dst, int n, int k, int64_t * hist) {
2217
+ assert(k % QK_K == 0);
2218
+ (void)hist; // TODO: collect histograms
2219
+
2220
+ for (int j = 0; j < n; j += k) {
2221
+ block_q6_K * restrict y = (block_q6_K *)dst + j/QK_K;
2222
+ quantize_row_q6_K_reference(src + j, y, k);
331
2223
  }
2224
+ return (n/QK_K*sizeof(block_q6_K));
332
2225
  }
333
- #endif
334
2226
 
335
- //========================- 2-bit (de)-quantization
2227
+ //===================================== Q8_K ==============================================
336
2228
 
337
- void quantize_row_q2_K_reference(const float * restrict x, block_q2_K * restrict y, int k) {
2229
+ void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int k) {
338
2230
  assert(k % QK_K == 0);
339
2231
  const int nb = k / QK_K;
340
2232
 
341
- uint8_t L[QK_K];
342
- uint8_t Laux[16];
343
- float weights[16];
344
- float mins[QK_K/16];
345
- float scales[QK_K/16];
346
-
347
- const float q4scale = 15.f;
348
-
349
2233
  for (int i = 0; i < nb; i++) {
350
- float max_scale = 0; // as we are deducting the min, scales are always positive
351
- float max_min = 0;
352
- for (int j = 0; j < QK_K/16; ++j) {
353
- for (int l = 0; l < 16; ++l) weights[l] = fabsf(x[16*j + l]);
354
- scales[j] = make_qkx2_quants(16, 3, x + 16*j, weights, L + 16*j, &mins[j], Laux, -0.5f, 0.1f, 15, true);
355
- float scale = scales[j];
356
- if (scale > max_scale) {
357
- max_scale = scale;
358
- }
359
- float min = mins[j];
360
- if (min > max_min) {
361
- max_min = min;
362
- }
363
- }
364
2234
 
365
- if (max_scale > 0) {
366
- float iscale = q4scale/max_scale;
367
- for (int j = 0; j < QK_K/16; ++j) {
368
- int l = nearest_int(iscale*scales[j]);
369
- y[i].scales[j] = l;
370
- }
371
- y[i].d = ggml_fp32_to_fp16(max_scale/q4scale);
372
- } else {
373
- for (int j = 0; j < QK_K/16; ++j) y[i].scales[j] = 0;
374
- y[i].d = ggml_fp32_to_fp16(0.f);
375
- }
376
- if (max_min > 0) {
377
- float iscale = q4scale/max_min;
378
- for (int j = 0; j < QK_K/16; ++j) {
379
- int l = nearest_int(iscale*mins[j]);
380
- y[i].scales[j] |= (l << 4);
2235
+ float max = 0;
2236
+ float amax = 0;
2237
+ for (int j = 0; j < QK_K; ++j) {
2238
+ float ax = fabsf(x[j]);
2239
+ if (ax > amax) {
2240
+ amax = ax; max = x[j];
381
2241
  }
382
- y[i].dmin = ggml_fp32_to_fp16(max_min/q4scale);
383
- } else {
384
- y[i].dmin = ggml_fp32_to_fp16(0.f);
2242
+ }
2243
+ if (!amax) {
2244
+ y[i].d = 0;
2245
+ memset(y[i].qs, 0, QK_K);
2246
+ x += QK_K;
2247
+ continue;
2248
+ }
2249
+ const float iscale = -128.f/max;
2250
+ for (int j = 0; j < QK_K; ++j) {
2251
+ int v = nearest_int(iscale*x[j]);
2252
+ y[i].qs[j] = MIN(127, v);
385
2253
  }
386
2254
  for (int j = 0; j < QK_K/16; ++j) {
387
- const float d = ggml_fp16_to_fp32(y[i].d) * (y[i].scales[j] & 0xF);
388
- if (!d) continue;
389
- const float dm = ggml_fp16_to_fp32(y[i].dmin) * (y[i].scales[j] >> 4);
2255
+ int sum = 0;
390
2256
  for (int ii = 0; ii < 16; ++ii) {
391
- int l = nearest_int((x[16*j + ii] + dm)/d);
392
- l = MAX(0, MIN(3, l));
393
- L[16*j + ii] = l;
394
- }
395
- }
396
-
397
- #if QK_K == 256
398
- for (int j = 0; j < QK_K; j += 128) {
399
- for (int l = 0; l < 32; ++l) {
400
- y[i].qs[j/4 + l] = L[j + l] | (L[j + l + 32] << 2) | (L[j + l + 64] << 4) | (L[j + l + 96] << 6);
2257
+ sum += y[i].qs[j*16 + ii];
401
2258
  }
2259
+ y[i].bsums[j] = sum;
402
2260
  }
403
- #else
404
- for (int l = 0; l < 16; ++l) {
405
- y[i].qs[l] = L[l] | (L[l + 16] << 2) | (L[l + 32] << 4) | (L[l + 48] << 6);
406
- }
407
- #endif
408
-
2261
+ y[i].d = 1/iscale;
409
2262
  x += QK_K;
410
-
411
2263
  }
412
2264
  }
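A brief worked example of the signed scaling above (illustrative numbers only): if the largest-magnitude value is max = 2.0, then iscale = -128/2 = -64, so that value quantizes to nearest_int(-64 * 2.0) = -128, the one extra code available on the negative side, and the stored d = 1/iscale = -1/64 reproduces it exactly, since (-1/64) * (-128) = 2.0. A value near -2.0 on the opposite side would map to about +128 and is clamped to 127 by the MIN above.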
413
2265
 
414
- void dequantize_row_q2_K(const block_q2_K * restrict x, float * restrict y, int k) {
2266
+ void dequantize_row_q8_K(const block_q8_K * restrict x, float * restrict y, int k) {
415
2267
  assert(k % QK_K == 0);
416
2268
  const int nb = k / QK_K;
417
2269
 
418
2270
  for (int i = 0; i < nb; i++) {
2271
+ for (int j = 0; j < QK_K; ++j) {
2272
+ *y++ = x[i].d * x[i].qs[j];
2273
+ }
2274
+ }
2275
+ }
419
2276
 
420
- const float d = ggml_fp16_to_fp32(x[i].d);
421
- const float min = ggml_fp16_to_fp32(x[i].dmin);
2277
+ void quantize_row_q8_K(const float * restrict x, void * restrict y, int k) {
2278
+ quantize_row_q8_K_reference(x, y, k);
2279
+ }
422
2280
 
423
- const uint8_t * q = x[i].qs;
2281
+ //===================================== Dot products =================================
424
2282
 
425
- #if QK_K == 256
426
- int is = 0;
427
- float dl, ml;
428
- for (int n = 0; n < QK_K; n += 128) {
429
- int shift = 0;
430
- for (int j = 0; j < 4; ++j) {
2283
+ //
2284
+ // Helper functions
2285
+ //
2286
+ #if __AVX__ || __AVX2__ || __AVX512F__
431
2287
 
432
- uint8_t sc = x[i].scales[is++];
433
- dl = d * (sc & 0xF); ml = min * (sc >> 4);
434
- for (int l = 0; l < 16; ++l) *y++ = dl * ((int8_t)((q[l] >> shift) & 3)) - ml;
2288
+ // shuffles to pick the required scales in dot products
2289
+ static inline __m256i get_scale_shuffle_q3k(int i) {
2290
+ static const uint8_t k_shuffle[128] = {
2291
+ 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3,
2292
+ 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7,
2293
+ 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,
2294
+ 12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13, 14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,
2295
+ };
2296
+ return _mm256_loadu_si256((const __m256i*)k_shuffle + i);
2297
+ }
2298
+ static inline __m256i get_scale_shuffle_k4(int i) {
2299
+ static const uint8_t k_shuffle[256] = {
2300
+ 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1,
2301
+ 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3,
2302
+ 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5,
2303
+ 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7,
2304
+ 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9,
2305
+ 10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,
2306
+ 12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,
2307
+ 14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15
2308
+ };
2309
+ return _mm256_loadu_si256((const __m256i*)k_shuffle + i);
2310
+ }
2311
+ static inline __m128i get_scale_shuffle(int i) {
2312
+ static const uint8_t k_shuffle[128] = {
2313
+ 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1,
2314
+ 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3,
2315
+ 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5,
2316
+ 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7,
2317
+ 8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9,
2318
+ 10,10,10,10,10,10,10,10, 11,11,11,11,11,11,11,11,
2319
+ 12,12,12,12,12,12,12,12, 13,13,13,13,13,13,13,13,
2320
+ 14,14,14,14,14,14,14,14, 15,15,15,15,15,15,15,15
2321
+ };
2322
+ return _mm_loadu_si128((const __m128i*)k_shuffle + i);
2323
+ }
2324
+ #endif
435
2325
 
436
- sc = x[i].scales[is++];
437
- dl = d * (sc & 0xF); ml = min * (sc >> 4);
438
- for (int l = 0; l < 16; ++l) *y++ = dl * ((int8_t)((q[l+16] >> shift) & 3)) - ml;
2326
+ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
2327
+ const int qk = QK8_0;
2328
+ const int nb = n / qk;
439
2329
 
440
- shift += 2;
441
- }
442
- q += 32;
443
- }
2330
+ assert(n % qk == 0);
2331
+
2332
+ const block_q4_0 * restrict x = vx;
2333
+ const block_q8_0 * restrict y = vy;
2334
+
2335
+ #if defined(__ARM_NEON)
2336
+ float32x4_t sumv0 = vdupq_n_f32(0.0f);
2337
+ float32x4_t sumv1 = vdupq_n_f32(0.0f);
2338
+
2339
+ assert(nb % 2 == 0); // TODO: handle odd nb
2340
+
2341
+ for (int i = 0; i < nb; i += 2) {
2342
+ const block_q4_0 * restrict x0 = &x[i + 0];
2343
+ const block_q4_0 * restrict x1 = &x[i + 1];
2344
+ const block_q8_0 * restrict y0 = &y[i + 0];
2345
+ const block_q8_0 * restrict y1 = &y[i + 1];
2346
+
2347
+ const uint8x16_t m4b = vdupq_n_u8(0x0F);
2348
+ const int8x16_t s8b = vdupq_n_s8(0x8);
2349
+
2350
+ const uint8x16_t v0_0 = vld1q_u8(x0->qs);
2351
+ const uint8x16_t v0_1 = vld1q_u8(x1->qs);
2352
+
2353
+ // 4-bit -> 8-bit
2354
+ const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
2355
+ const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
2356
+ const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
2357
+ const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
2358
+
2359
+ // sub 8
2360
+ const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
2361
+ const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
2362
+ const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
2363
+ const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
2364
+
2365
+ // load y
2366
+ const int8x16_t v1_0l = vld1q_s8(y0->qs);
2367
+ const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
2368
+ const int8x16_t v1_1l = vld1q_s8(y1->qs);
2369
+ const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
2370
+
2371
+ #if defined(__ARM_FEATURE_DOTPROD)
2372
+ // dot product into int32x4_t
2373
+ const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0l), v0_0hs, v1_0h);
2374
+ const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1l), v0_1hs, v1_1h);
2375
+
2376
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
2377
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
444
2378
  #else
445
- float dl1 = d * (x[i].scales[0] & 0xF), ml1 = min * (x[i].scales[0] >> 4);
446
- float dl2 = d * (x[i].scales[1] & 0xF), ml2 = min * (x[i].scales[1] >> 4);
447
- float dl3 = d * (x[i].scales[2] & 0xF), ml3 = min * (x[i].scales[2] >> 4);
448
- float dl4 = d * (x[i].scales[3] & 0xF), ml4 = min * (x[i].scales[3] >> 4);
449
- for (int l = 0; l < 16; ++l) {
450
- y[l+ 0] = dl1 * ((int8_t)((q[l] >> 0) & 3)) - ml1;
451
- y[l+16] = dl2 * ((int8_t)((q[l] >> 2) & 3)) - ml2;
452
- y[l+32] = dl3 * ((int8_t)((q[l] >> 4) & 3)) - ml3;
453
- y[l+48] = dl4 * ((int8_t)((q[l] >> 6) & 3)) - ml4;
454
- }
455
- y += QK_K;
2379
+ const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0l));
2380
+ const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0l));
2381
+ const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0h));
2382
+ const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0h));
2383
+
2384
+ const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1l));
2385
+ const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1l));
2386
+ const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1h));
2387
+ const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1h));
2388
+
2389
+ const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
2390
+ const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
2391
+ const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
2392
+ const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
2393
+
2394
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
2395
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
456
2396
  #endif
457
2397
  }
458
- }
459
2398
 
460
- void quantize_row_q2_K(const float * restrict x, void * restrict vy, int k) {
461
- quantize_row_q2_K_reference(x, vy, k);
462
- }
2399
+ *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
2400
+ #elif defined(__AVX2__)
2401
+ // Initialize accumulator with zeros
2402
+ __m256 acc = _mm256_setzero_ps();
463
2403
 
464
- size_t ggml_quantize_q2_K(const float * restrict src, void * restrict dst, int n, int k, int64_t * restrict hist) {
465
- (void)hist; // TODO: collect histograms
2404
+ // Main loop
2405
+ for (int i = 0; i < nb; ++i) {
2406
+ /* Compute combined scale for the block */
2407
+ const __m256 d = _mm256_set1_ps( GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d) );
466
2408
 
467
- for (int j = 0; j < n; j += k) {
468
- block_q2_K * restrict y = (block_q2_K *)dst + j/QK_K;
469
- quantize_row_q2_K_reference(src + j, y, k);
2409
+ __m256i bx = bytes_from_nibbles_32(x[i].qs);
2410
+
2411
+ // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
2412
+ const __m256i off = _mm256_set1_epi8( 8 );
2413
+ bx = _mm256_sub_epi8( bx, off );
2414
+
2415
+ __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
2416
+
2417
+ const __m256 q = mul_sum_i8_pairs_float(bx, by);
2418
+
2419
+ /* Multiply q with scale and accumulate */
2420
+ acc = _mm256_fmadd_ps( d, q, acc );
470
2421
  }
471
- return (n/QK_K*sizeof(block_q2_K));
472
- }
473
2422
 
474
- //========================= 3-bit (de)-quantization
2423
+ *s = hsum_float_8(acc);
2424
+ #elif defined(__AVX__)
2425
+ // Initialize accumulator with zeros
2426
+ __m256 acc = _mm256_setzero_ps();
475
2427
 
476
- void quantize_row_q3_K_reference(const float * restrict x, block_q3_K * restrict y, int k) {
477
- assert(k % QK_K == 0);
478
- const int nb = k / QK_K;
2428
+ // Main loop
2429
+ for (int i = 0; i < nb; ++i) {
2430
+ // Compute combined scale for the block
2431
+ const __m256 d = _mm256_set1_ps( GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d) );
479
2432
 
480
- int8_t L[QK_K];
481
- float scales[QK_K / 16];
2433
+ const __m128i lowMask = _mm_set1_epi8(0xF);
2434
+ const __m128i off = _mm_set1_epi8(8);
2435
+
2436
+ const __m128i tmp = _mm_loadu_si128((const __m128i *)x[i].qs);
2437
+
2438
+ __m128i bx = _mm_and_si128(lowMask, tmp);
2439
+ __m128i by = _mm_loadu_si128((const __m128i *)y[i].qs);
2440
+ bx = _mm_sub_epi8(bx, off);
2441
+ const __m128i i32_0 = mul_sum_i8_pairs(bx, by);
2442
+
2443
+ bx = _mm_and_si128(lowMask, _mm_srli_epi64(tmp, 4));
2444
+ by = _mm_loadu_si128((const __m128i *)(y[i].qs + 16));
2445
+ bx = _mm_sub_epi8(bx, off);
2446
+ const __m128i i32_1 = mul_sum_i8_pairs(bx, by);
2447
+
2448
+ // Convert int32_t to float
2449
+ __m256 p = _mm256_cvtepi32_ps(MM256_SET_M128I(i32_0, i32_1));
2450
+
2451
+ // Apply the scale, and accumulate
2452
+ acc = _mm256_add_ps(_mm256_mul_ps( d, p ), acc);
2453
+ }
2454
+
2455
+ *s = hsum_float_8(acc);
2456
+ #elif defined(__SSSE3__)
2457
+ // set constants
2458
+ const __m128i lowMask = _mm_set1_epi8(0xF);
2459
+ const __m128i off = _mm_set1_epi8(8);
2460
+
2461
+ // Initialize accumulator with zeros
2462
+ __m128 acc_0 = _mm_setzero_ps();
2463
+ __m128 acc_1 = _mm_setzero_ps();
2464
+ __m128 acc_2 = _mm_setzero_ps();
2465
+ __m128 acc_3 = _mm_setzero_ps();
2466
+
2467
+ // First round without accumulation
2468
+ {
2469
+ _mm_prefetch(&x[0] + sizeof(block_q4_0), _MM_HINT_T0);
2470
+ _mm_prefetch(&y[0] + sizeof(block_q8_0), _MM_HINT_T0);
2471
+
2472
+ // Compute combined scale for the block 0 and 1
2473
+ const __m128 d_0_1 = _mm_set1_ps( GGML_FP16_TO_FP32(x[0].d) * GGML_FP16_TO_FP32(y[0].d) );
2474
+
2475
+ const __m128i tmp_0_1 = _mm_loadu_si128((const __m128i *)x[0].qs);
2476
+
2477
+ __m128i bx_0 = _mm_and_si128(lowMask, tmp_0_1);
2478
+ __m128i by_0 = _mm_loadu_si128((const __m128i *)y[0].qs);
2479
+ bx_0 = _mm_sub_epi8(bx_0, off);
2480
+ const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0);
2481
+
2482
+ __m128i bx_1 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_0_1, 4));
2483
+ __m128i by_1 = _mm_loadu_si128((const __m128i *)(y[0].qs + 16));
2484
+ bx_1 = _mm_sub_epi8(bx_1, off);
2485
+ const __m128i i32_1 = mul_sum_i8_pairs(bx_1, by_1);
2486
+
2487
+ _mm_prefetch(&x[1] + sizeof(block_q4_0), _MM_HINT_T0);
2488
+ _mm_prefetch(&y[1] + sizeof(block_q8_0), _MM_HINT_T0);
2489
+
2490
+ // Compute combined scale for the block 2 and 3
2491
+ const __m128 d_2_3 = _mm_set1_ps( GGML_FP16_TO_FP32(x[1].d) * GGML_FP16_TO_FP32(y[1].d) );
2492
+
2493
+ const __m128i tmp_2_3 = _mm_loadu_si128((const __m128i *)x[1].qs);
2494
+
2495
+ __m128i bx_2 = _mm_and_si128(lowMask, tmp_2_3);
2496
+ __m128i by_2 = _mm_loadu_si128((const __m128i *)y[1].qs);
2497
+ bx_2 = _mm_sub_epi8(bx_2, off);
2498
+ const __m128i i32_2 = mul_sum_i8_pairs(bx_2, by_2);
2499
+
2500
+ __m128i bx_3 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_2_3, 4));
2501
+ __m128i by_3 = _mm_loadu_si128((const __m128i *)(y[1].qs + 16));
2502
+ bx_3 = _mm_sub_epi8(bx_3, off);
2503
+ const __m128i i32_3 = mul_sum_i8_pairs(bx_3, by_3);
2504
+
2505
+ // Convert int32_t to float
2506
+ __m128 p0 = _mm_cvtepi32_ps(i32_0);
2507
+ __m128 p1 = _mm_cvtepi32_ps(i32_1);
2508
+ __m128 p2 = _mm_cvtepi32_ps(i32_2);
2509
+ __m128 p3 = _mm_cvtepi32_ps(i32_3);
2510
+
2511
+ // Apply the scale
2512
+ acc_0 = _mm_mul_ps( d_0_1, p0 );
2513
+ acc_1 = _mm_mul_ps( d_0_1, p1 );
2514
+ acc_2 = _mm_mul_ps( d_2_3, p2 );
2515
+ acc_3 = _mm_mul_ps( d_2_3, p3 );
2516
+ }
2517
+
2518
+ assert(nb % 2 == 0); // TODO: handle odd nb
2519
+
2520
+ // Main loop
2521
+ for (int i = 2; i < nb; i+=2) {
2522
+ _mm_prefetch(&x[i] + sizeof(block_q4_0), _MM_HINT_T0);
2523
+ _mm_prefetch(&y[i] + sizeof(block_q8_0), _MM_HINT_T0);
2524
+
2525
+ // Compute combined scale for the block 0 and 1
2526
+ const __m128 d_0_1 = _mm_set1_ps( GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d) );
2527
+
2528
+ const __m128i tmp_0_1 = _mm_loadu_si128((const __m128i *)x[i].qs);
2529
+
2530
+ __m128i bx_0 = _mm_and_si128(lowMask, tmp_0_1);
2531
+ __m128i by_0 = _mm_loadu_si128((const __m128i *)y[i].qs);
2532
+ bx_0 = _mm_sub_epi8(bx_0, off);
2533
+ const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0);
2534
+
2535
+ __m128i bx_1 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_0_1, 4));
2536
+ __m128i by_1 = _mm_loadu_si128((const __m128i *)(y[i].qs + 16));
2537
+ bx_1 = _mm_sub_epi8(bx_1, off);
2538
+ const __m128i i32_1 = mul_sum_i8_pairs(bx_1, by_1);
2539
+
2540
+ _mm_prefetch(&x[i] + 2 * sizeof(block_q4_0), _MM_HINT_T0);
2541
+ _mm_prefetch(&y[i] + 2 * sizeof(block_q8_0), _MM_HINT_T0);
2542
+
2543
+ // Compute combined scale for the block 2 and 3
2544
+ const __m128 d_2_3 = _mm_set1_ps( GGML_FP16_TO_FP32(x[i + 1].d) * GGML_FP16_TO_FP32(y[i + 1].d) );
2545
+
2546
+ const __m128i tmp_2_3 = _mm_loadu_si128((const __m128i *)x[i + 1].qs);
2547
+
2548
+ __m128i bx_2 = _mm_and_si128(lowMask, tmp_2_3);
2549
+ __m128i by_2 = _mm_loadu_si128((const __m128i *)y[i + 1].qs);
2550
+ bx_2 = _mm_sub_epi8(bx_2, off);
2551
+ const __m128i i32_2 = mul_sum_i8_pairs(bx_2, by_2);
2552
+
2553
+ __m128i bx_3 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_2_3, 4));
2554
+ __m128i by_3 = _mm_loadu_si128((const __m128i *)(y[i + 1].qs + 16));
2555
+ bx_3 = _mm_sub_epi8(bx_3, off);
2556
+ const __m128i i32_3 = mul_sum_i8_pairs(bx_3, by_3);
2557
+
2558
+ // Convert int32_t to float
2559
+ __m128 p0 = _mm_cvtepi32_ps(i32_0);
2560
+ __m128 p1 = _mm_cvtepi32_ps(i32_1);
2561
+ __m128 p2 = _mm_cvtepi32_ps(i32_2);
2562
+ __m128 p3 = _mm_cvtepi32_ps(i32_3);
2563
+
2564
+ // Apply the scale
2565
+ __m128 p0_d = _mm_mul_ps( d_0_1, p0 );
2566
+ __m128 p1_d = _mm_mul_ps( d_0_1, p1 );
2567
+ __m128 p2_d = _mm_mul_ps( d_2_3, p2 );
2568
+ __m128 p3_d = _mm_mul_ps( d_2_3, p3 );
2569
+
2570
+ // Accumulate
2571
+ acc_0 = _mm_add_ps(p0_d, acc_0);
2572
+ acc_1 = _mm_add_ps(p1_d, acc_1);
2573
+ acc_2 = _mm_add_ps(p2_d, acc_2);
2574
+ acc_3 = _mm_add_ps(p3_d, acc_3);
2575
+ }
2576
+
2577
+ *s = hsum_float_4x4(acc_0, acc_1, acc_2, acc_3);
2578
+ #elif defined(__riscv_v_intrinsic)
2579
+ float sumf = 0.0;
2580
+
2581
+ size_t vl = __riscv_vsetvl_e8m1(qk/2);
482
2582
 
483
2583
  for (int i = 0; i < nb; i++) {
2584
+ // load elements
2585
+ vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[i].qs, vl);
484
2586
 
485
- float max_scale = 0;
486
- float amax = 0;
487
- for (int j = 0; j < QK_K/16; ++j) {
488
- scales[j] = make_q3_quants(16, 4, x + 16*j, L + 16*j, true);
489
- float scale = fabsf(scales[j]);
490
- if (scale > amax) {
491
- amax = scale; max_scale = scales[j];
492
- }
493
- }
2587
+ vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[i].qs, vl);
2588
+ vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[i].qs+16, vl);
494
2589
 
495
- #if QK_K == 256
496
- memset(y[i].scales, 0, 12);
497
- if (max_scale) {
498
- float iscale = -32.f/max_scale;
499
- for (int j = 0; j < QK_K/16; ++j) {
500
- int8_t l = nearest_int(iscale*scales[j]);
501
- l = MAX(-32, MIN(31, l)) + 32;
502
- if (j < 8) {
503
- y[i].scales[j] = l & 0xF;
504
- } else {
505
- y[i].scales[j-8] |= ((l & 0xF) << 4);
506
- }
507
- l >>= 4;
508
- y[i].scales[j%4 + 8] |= (l << (2*(j/4)));
509
- }
510
- y[i].d = ggml_fp32_to_fp16(1/iscale);
511
- } else {
512
- y[i].d = ggml_fp32_to_fp16(0.f);
513
- }
2590
+ // mask and store lower part of x, and then upper part
2591
+ vuint8mf2_t x_a = __riscv_vand_vx_u8mf2(tx, 0x0F, vl);
2592
+ vuint8mf2_t x_l = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl);
514
2593
 
515
- int8_t sc;
516
- for (int j = 0; j < QK_K/16; ++j) {
517
- sc = j < 8 ? y[i].scales[j] & 0xF : y[i].scales[j-8] >> 4;
518
- sc = (sc | (((y[i].scales[8 + j%4] >> (2*(j/4))) & 3) << 4)) - 32;
519
- float d = ggml_fp16_to_fp32(y[i].d) * sc;
520
- if (!d) {
521
- continue;
522
- }
523
- for (int ii = 0; ii < 16; ++ii) {
524
- int l = nearest_int(x[16*j + ii]/d);
525
- l = MAX(-4, MIN(3, l));
526
- L[16*j + ii] = l + 4;
527
- }
528
- }
529
- #else
530
- if (max_scale) {
531
- float iscale = -8.f/max_scale;
532
- for (int j = 0; j < QK_K/16; j+=2) {
533
- int l1 = nearest_int(iscale*scales[j]);
534
- l1 = 8 + MAX(-8, MIN(7, l1));
535
- int l2 = nearest_int(iscale*scales[j+1]);
536
- l2 = 8 + MAX(-8, MIN(7, l2));
537
- y[i].scales[j/2] = l1 | (l2 << 4);
538
- }
539
- y[i].d = ggml_fp32_to_fp16(1/iscale);
540
- } else {
541
- for (int j = 0; j < QK_K/16; j+=2) {
542
- y[i].scales[j/2] = 0;
543
- }
544
- y[i].d = ggml_fp32_to_fp16(0.f);
545
- }
546
- for (int j = 0; j < QK_K/16; ++j) {
547
- int s = j%2 == 0 ? y[i].scales[j/2] & 0xF : y[i].scales[j/2] >> 4;
548
- float d = ggml_fp16_to_fp32(y[i].d) * (s - 8);
549
- if (!d) {
550
- continue;
551
- }
552
- for (int ii = 0; ii < 16; ++ii) {
553
- int l = nearest_int(x[16*j + ii]/d);
554
- l = MAX(-4, MIN(3, l));
555
- L[16*j + ii] = l + 4;
556
- }
557
- }
558
- #endif
2594
+ vint8mf2_t x_ai = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a);
2595
+ vint8mf2_t x_li = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l);
2596
+
2597
+ // subtract offset
2598
+ vint8mf2_t v0 = __riscv_vsub_vx_i8mf2(x_ai, 8, vl);
2599
+ vint8mf2_t v1 = __riscv_vsub_vx_i8mf2(x_li, 8, vl);
2600
+
2601
+ vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl);
2602
+ vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl);
2603
+
2604
+ vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
2605
+
2606
+ vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl);
2607
+ vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl);
2608
+
2609
+ int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);
2610
+
2611
+ sumf += sumi*GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d);
2612
+ }
559
2613
 
560
- memset(y[i].hmask, 0, QK_K/8);
561
- // We put the high-bit for the 1st 8 quants into bit 0, the next 8 into bit 1, etc.
562
- int m = 0;
563
- uint8_t hm = 1;
564
- for (int j = 0; j < QK_K; ++j) {
565
- if (L[j] > 3) {
566
- y[i].hmask[m] |= hm;
567
- L[j] -= 4;
568
- }
569
- if (++m == QK_K/8) {
570
- m = 0; hm <<= 1;
571
- }
572
- }
573
- #if QK_K == 256
574
- for (int j = 0; j < QK_K; j += 128) {
575
- for (int l = 0; l < 32; ++l) {
576
- y[i].qs[j/4 + l] = L[j + l] | (L[j + l + 32] << 2) | (L[j + l + 64] << 4) | (L[j + l + 96] << 6);
577
- }
578
- }
2614
+ *s = sumf;
579
2615
  #else
580
- for (int l = 0; l < 16; ++l) {
581
- y[i].qs[l] = L[l] | (L[l + 16] << 2) | (L[l + 32] << 4) | (L[l + 48] << 6);
2616
+ // scalar
2617
+ float sumf = 0.0;
2618
+
2619
+ for (int i = 0; i < nb; i++) {
2620
+ int sumi = 0;
2621
+
2622
+ for (int j = 0; j < qk/2; ++j) {
2623
+ const int v0 = (x[i].qs[j] & 0x0F) - 8;
2624
+ const int v1 = (x[i].qs[j] >> 4) - 8;
2625
+
2626
+ sumi += (v0 * y[i].qs[j]) + (v1 * y[i].qs[j + qk/2]);
582
2627
  }
583
- #endif
584
2628
 
585
- x += QK_K;
2629
+ sumf += sumi*GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d);
586
2630
  }
2631
+
2632
+ *s = sumf;
2633
+ #endif
587
2634
  }
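For reference, the element layout the scalar tail above relies on, written as a hedged standalone decode of one q4_0 block (editorial addition; assumes block_q4_0, QK4_0 == 32 and GGML_FP16_TO_FP32 from this file are in scope).

// Sketch: expand one q4_0 block to floats; byte j carries element j in its low nibble and
// element j + QK4_0/2 in its high nibble, both stored with a +8 offset.
static void dequant_block_q4_0(const block_q4_0 * b, float * out) {
    const float d = GGML_FP16_TO_FP32(b->d);
    for (int j = 0; j < QK4_0/2; ++j) {
        out[j]           = d * ((b->qs[j] & 0x0F) - 8);
        out[j + QK4_0/2] = d * ((b->qs[j] >>   4) - 8);
    }
}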
588
2635
 
589
- #if QK_K == 256
590
- void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int k) {
591
- assert(k % QK_K == 0);
592
- const int nb = k / QK_K;
2636
+ void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
2637
+ const int qk = QK8_1;
2638
+ const int nb = n / qk;
593
2639
 
594
- const uint32_t kmask1 = 0x03030303;
595
- const uint32_t kmask2 = 0x0f0f0f0f;
2640
+ assert(n % qk == 0);
596
2641
 
597
- uint32_t aux[4];
598
- const int8_t * scales = (const int8_t*)aux;
2642
+ const block_q4_1 * restrict x = vx;
2643
+ const block_q8_1 * restrict y = vy;
599
2644
 
600
- for (int i = 0; i < nb; i++) {
2645
+ // TODO: add WASM SIMD
2646
+ #if defined(__ARM_NEON)
2647
+ float32x4_t sumv0 = vdupq_n_f32(0.0f);
2648
+ float32x4_t sumv1 = vdupq_n_f32(0.0f);
601
2649
 
602
- const float d_all = ggml_fp16_to_fp32(x[i].d);
2650
+ float summs = 0;
603
2651
 
604
- const uint8_t * restrict q = x[i].qs;
605
- const uint8_t * restrict hm = x[i].hmask;
606
- uint8_t m = 1;
2652
+ assert(nb % 2 == 0); // TODO: handle odd nb
607
2653
 
608
- memcpy(aux, x[i].scales, 12);
609
- uint32_t tmp = aux[2];
610
- aux[2] = ((aux[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);
611
- aux[3] = ((aux[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4);
612
- aux[0] = (aux[0] & kmask2) | (((tmp >> 0) & kmask1) << 4);
613
- aux[1] = (aux[1] & kmask2) | (((tmp >> 2) & kmask1) << 4);
2654
+ for (int i = 0; i < nb; i += 2) {
2655
+ const block_q4_1 * restrict x0 = &x[i + 0];
2656
+ const block_q4_1 * restrict x1 = &x[i + 1];
2657
+ const block_q8_1 * restrict y0 = &y[i + 0];
2658
+ const block_q8_1 * restrict y1 = &y[i + 1];
614
2659
 
615
- int is = 0;
616
- float dl;
617
- for (int n = 0; n < QK_K; n += 128) {
618
- int shift = 0;
619
- for (int j = 0; j < 4; ++j) {
2660
+ summs += GGML_FP16_TO_FP32(x0->m) * y0->s + GGML_FP16_TO_FP32(x1->m) * y1->s;
620
2661
 
621
- dl = d_all * (scales[is++] - 32);
622
- for (int l = 0; l < 16; ++l) {
623
- *y++ = dl * ((int8_t)((q[l+ 0] >> shift) & 3) - ((hm[l+ 0] & m) ? 0 : 4));
624
- }
2662
+ const uint8x16_t m4b = vdupq_n_u8(0x0F);
625
2663
 
626
- dl = d_all * (scales[is++] - 32);
627
- for (int l = 0; l < 16; ++l) {
628
- *y++ = dl * ((int8_t)((q[l+16] >> shift) & 3) - ((hm[l+16] & m) ? 0 : 4));
629
- }
2664
+ const uint8x16_t v0_0 = vld1q_u8(x0->qs);
2665
+ const uint8x16_t v0_1 = vld1q_u8(x1->qs);
630
2666
 
631
- shift += 2;
632
- m <<= 1;
633
- }
634
- q += 32;
635
- }
2667
+ // 4-bit -> 8-bit
2668
+ const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
2669
+ const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
2670
+ const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
2671
+ const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
2672
+
2673
+ // load y
2674
+ const int8x16_t v1_0l = vld1q_s8(y0->qs);
2675
+ const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
2676
+ const int8x16_t v1_1l = vld1q_s8(y1->qs);
2677
+ const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
2678
+
2679
+ #if defined(__ARM_FEATURE_DOTPROD)
2680
+ // dot product into int32x4_t
2681
+ const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0l), v0_0h, v1_0h);
2682
+ const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1l), v0_1h, v1_1h);
636
2683
 
2684
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), GGML_FP16_TO_FP32(x0->d)*y0->d);
2685
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), GGML_FP16_TO_FP32(x1->d)*y1->d);
2686
+ #else
2687
+ const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0l), vget_low_s8 (v1_0l));
2688
+ const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0l), vget_high_s8(v1_0l));
2689
+ const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0h), vget_low_s8 (v1_0h));
2690
+ const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0h), vget_high_s8(v1_0h));
2691
+
2692
+ const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1l), vget_low_s8 (v1_1l));
2693
+ const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1l), vget_high_s8(v1_1l));
2694
+ const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1h), vget_low_s8 (v1_1h));
2695
+ const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1h), vget_high_s8(v1_1h));
2696
+
2697
+ const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
2698
+ const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
2699
+ const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
2700
+ const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
2701
+
2702
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), GGML_FP16_TO_FP32(x0->d)*y0->d);
2703
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), GGML_FP16_TO_FP32(x1->d)*y1->d);
2704
+ #endif
637
2705
  }
638
- }
2706
+
2707
+ *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs;
2708
+ #elif defined(__AVX2__) || defined(__AVX__)
2709
+ // Initialize accumulator with zeros
2710
+ __m256 acc = _mm256_setzero_ps();
2711
+
2712
+ float summs = 0;
2713
+
2714
+ // Main loop
2715
+ for (int i = 0; i < nb; ++i) {
2716
+ const float d0 = GGML_FP16_TO_FP32(x[i].d);
2717
+ const float d1 = y[i].d;
2718
+
2719
+ summs += GGML_FP16_TO_FP32(x[i].m) * y[i].s;
2720
+
2721
+ const __m256 d0v = _mm256_set1_ps( d0 );
2722
+ const __m256 d1v = _mm256_set1_ps( d1 );
2723
+
2724
+ // Compute combined scales
2725
+ const __m256 d0d1 = _mm256_mul_ps( d0v, d1v );
2726
+
2727
+ // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
2728
+ const __m256i bx = bytes_from_nibbles_32(x[i].qs);
2729
+ const __m256i by = _mm256_loadu_si256( (const __m256i *)y[i].qs );
2730
+
2731
+ const __m256 xy = mul_sum_us8_pairs_float(bx, by);
2732
+
2733
+ // Accumulate d0*d1*x*y
2734
+ #if defined(__AVX2__)
2735
+ acc = _mm256_fmadd_ps( d0d1, xy, acc );
639
2736
  #else
640
- void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int k) {
641
- assert(k % QK_K == 0);
642
- assert(QK_K == 64);
643
- const int nb = k / QK_K;
2737
+ acc = _mm256_add_ps( _mm256_mul_ps( d0d1, xy ), acc );
2738
+ #endif
2739
+ }
2740
+
2741
+ *s = hsum_float_8(acc) + summs;
2742
+ #elif defined(__riscv_v_intrinsic)
2743
+ float sumf = 0.0;
2744
+
2745
+ size_t vl = __riscv_vsetvl_e8m1(qk/2);
644
2746
 
645
2747
  for (int i = 0; i < nb; i++) {
2748
+ // load elements
2749
+ vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[i].qs, vl);
646
2750
 
647
- const float d_all = ggml_fp16_to_fp32(x[i].d);
2751
+ vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[i].qs, vl);
2752
+ vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[i].qs+16, vl);
648
2753
 
649
- const uint8_t * restrict q = x[i].qs;
650
- const uint8_t * restrict hm = x[i].hmask;
2754
+ // mask and store lower part of x, and then upper part
2755
+ vuint8mf2_t x_a = __riscv_vand_vx_u8mf2(tx, 0x0F, vl);
2756
+ vuint8mf2_t x_l = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl);
651
2757
 
652
- const float d1 = d_all * ((x[i].scales[0] & 0xF) - 8);
653
- const float d2 = d_all * ((x[i].scales[0] >> 4) - 8);
654
- const float d3 = d_all * ((x[i].scales[1] & 0xF) - 8);
655
- const float d4 = d_all * ((x[i].scales[1] >> 4) - 8);
2758
+ vint8mf2_t v0 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a);
2759
+ vint8mf2_t v1 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l);
656
2760
 
657
- for (int l=0; l<8; ++l) {
658
- uint8_t h = hm[l];
659
- y[l+ 0] = d1 * ((int8_t)((q[l+0] >> 0) & 3) - ((h & 0x01) ? 0 : 4));
660
- y[l+ 8] = d1 * ((int8_t)((q[l+8] >> 0) & 3) - ((h & 0x02) ? 0 : 4));
661
- y[l+16] = d2 * ((int8_t)((q[l+0] >> 2) & 3) - ((h & 0x04) ? 0 : 4));
662
- y[l+24] = d2 * ((int8_t)((q[l+8] >> 2) & 3) - ((h & 0x08) ? 0 : 4));
663
- y[l+32] = d3 * ((int8_t)((q[l+0] >> 4) & 3) - ((h & 0x10) ? 0 : 4));
664
- y[l+40] = d3 * ((int8_t)((q[l+8] >> 4) & 3) - ((h & 0x20) ? 0 : 4));
665
- y[l+48] = d4 * ((int8_t)((q[l+0] >> 6) & 3) - ((h & 0x40) ? 0 : 4));
666
- y[l+56] = d4 * ((int8_t)((q[l+8] >> 6) & 3) - ((h & 0x80) ? 0 : 4));
2761
+ vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl);
2762
+ vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl);
2763
+
2764
+ vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
2765
+
2766
+ vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl);
2767
+ vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl);
2768
+
2769
+ int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);
2770
+
2771
+ sumf += (GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi + GGML_FP16_TO_FP32(x[i].m)*y[i].s;
2772
+ }
2773
+
2774
+ *s = sumf;
2775
+ #else
2776
+ // scalar
2777
+ float sumf = 0.0;
2778
+
2779
+ for (int i = 0; i < nb; i++) {
2780
+ int sumi = 0;
2781
+
2782
+ for (int j = 0; j < qk/2; ++j) {
2783
+ const int v0 = (x[i].qs[j] & 0x0F);
2784
+ const int v1 = (x[i].qs[j] >> 4);
2785
+
2786
+ sumi += (v0 * y[i].qs[j]) + (v1 * y[i].qs[j + qk/2]);
667
2787
  }
668
- y += QK_K;
2788
+
2789
+ sumf += (GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi + GGML_FP16_TO_FP32(x[i].m)*y[i].s;
669
2790
  }
670
- }
671
- #endif
672
2791
 
673
- void quantize_row_q3_K(const float * restrict x, void * restrict vy, int k) {
674
- quantize_row_q3_K_reference(x, vy, k);
2792
+ *s = sumf;
2793
+ #endif
675
2794
  }
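A note on the `m * y[i].s` term used by the q4_1 scalar and SIMD paths above (and by the q5_1 path further down): each dequantized weight is `d*q_j + m` and each q8_1 value is `dy*yq_j`, so the block dot product splits into `d*dy*sum(q_j*yq_j) + m*(dy*sum(yq_j))`, and the q8_1 block already carries `dy*sum(yq_j)` as `s`. The standalone sketch below (a toy 32-element block with made-up `d`, `m`, `dy`; not part of the gem) just checks that identity numerically:

    #include <stdio.h>

    int main(void) {
        enum { QK = 32 };
        // toy quantized block: x_j ~= d*q_j + m, with q_j in [0,15]
        const float d = 0.05f, m = -0.40f;
        // toy q8_1-style block: y_j ~= dy*yq_j, s = dy * sum(yq_j)
        const float dy = 0.02f;
        int q[QK], yq[QK];
        for (int j = 0; j < QK; ++j) { q[j] = (j*7) % 16; yq[j] = (j % 17) - 8; }

        // direct dot product of the dequantized values
        float direct = 0.0f;
        for (int j = 0; j < QK; ++j) direct += (d*q[j] + m) * (dy*yq[j]);

        // factored form used by the kernels: d*dy*sum(q*yq) + m*s
        int sumi = 0, sumy = 0;
        for (int j = 0; j < QK; ++j) { sumi += q[j]*yq[j]; sumy += yq[j]; }
        const float s        = dy * sumy;          // precomputed once per q8_1 block
        const float factored = d*dy*sumi + m*s;    // what the branches above accumulate

        printf("direct=%f factored=%f\n", direct, factored); // agree up to rounding
        return 0;
    }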
676
2795
 
677
- size_t ggml_quantize_q3_K(const float * restrict src, void * restrict dst, int n, int k, int64_t * restrict hist) {
678
- (void)hist; // TODO: collect histograms
2796
+ void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
2797
+ const int qk = QK8_0;
2798
+ const int nb = n / qk;
679
2799
 
680
- for (int j = 0; j < n; j += k) {
681
- block_q3_K * restrict y = (block_q3_K *)dst + j/QK_K;
682
- quantize_row_q3_K_reference(src + j, y, k);
2800
+ assert(n % qk == 0);
2801
+ assert(qk == QK5_0);
2802
+
2803
+ const block_q5_0 * restrict x = vx;
2804
+ const block_q8_0 * restrict y = vy;
2805
+
2806
+ #if defined(__ARM_NEON)
2807
+ float32x4_t sumv0 = vdupq_n_f32(0.0f);
2808
+ float32x4_t sumv1 = vdupq_n_f32(0.0f);
2809
+
2810
+ uint32_t qh0;
2811
+ uint32_t qh1;
2812
+
2813
+ uint64_t tmp0[4];
2814
+ uint64_t tmp1[4];
2815
+
2816
+ assert(nb % 2 == 0); // TODO: handle odd nb
2817
+
2818
+ for (int i = 0; i < nb; i += 2) {
2819
+ const block_q5_0 * restrict x0 = &x[i];
2820
+ const block_q5_0 * restrict x1 = &x[i + 1];
2821
+ const block_q8_0 * restrict y0 = &y[i];
2822
+ const block_q8_0 * restrict y1 = &y[i + 1];
2823
+
2824
+ const uint8x16_t m4b = vdupq_n_u8(0x0F);
2825
+
2826
+ // extract the 5th bit via lookup table ((!b) << 4)
2827
+ memcpy(&qh0, x0->qh, sizeof(qh0));
2828
+ memcpy(&qh1, x1->qh, sizeof(qh1));
2829
+
2830
+ tmp0[0] = table_b2b_1[(qh0 >> 0) & 0xFF];
2831
+ tmp0[1] = table_b2b_1[(qh0 >> 8) & 0xFF];
2832
+ tmp0[2] = table_b2b_1[(qh0 >> 16) & 0xFF];
2833
+ tmp0[3] = table_b2b_1[(qh0 >> 24) ];
2834
+
2835
+ tmp1[0] = table_b2b_1[(qh1 >> 0) & 0xFF];
2836
+ tmp1[1] = table_b2b_1[(qh1 >> 8) & 0xFF];
2837
+ tmp1[2] = table_b2b_1[(qh1 >> 16) & 0xFF];
2838
+ tmp1[3] = table_b2b_1[(qh1 >> 24) ];
2839
+
2840
+ const int8x16_t qhl0 = vld1q_s8((const int8_t *)(tmp0 + 0));
2841
+ const int8x16_t qhh0 = vld1q_s8((const int8_t *)(tmp0 + 2));
2842
+ const int8x16_t qhl1 = vld1q_s8((const int8_t *)(tmp1 + 0));
2843
+ const int8x16_t qhh1 = vld1q_s8((const int8_t *)(tmp1 + 2));
2844
+
2845
+ const uint8x16_t v0_0 = vld1q_u8(x0->qs);
2846
+ const uint8x16_t v0_1 = vld1q_u8(x1->qs);
2847
+
2848
+ // 4-bit -> 8-bit
2849
+ int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
2850
+ int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
2851
+ int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
2852
+ int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
2853
+
2854
+ // add high bit and sub 16 (equivalent to sub 0x10 when bit is zero)
2855
+ const int8x16_t v0_0lf = vsubq_s8(v0_0l, qhl0);
2856
+ const int8x16_t v0_0hf = vsubq_s8(v0_0h, qhh0);
2857
+ const int8x16_t v0_1lf = vsubq_s8(v0_1l, qhl1);
2858
+ const int8x16_t v0_1hf = vsubq_s8(v0_1h, qhh1);
2859
+
2860
+ // load y
2861
+ const int8x16_t v1_0l = vld1q_s8(y0->qs);
2862
+ const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
2863
+ const int8x16_t v1_1l = vld1q_s8(y1->qs);
2864
+ const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
2865
+
2866
+ #if defined(__ARM_FEATURE_DOTPROD)
2867
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
2868
+ vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l),
2869
+ vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
2870
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
2871
+ vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l),
2872
+ vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
2873
+ #else
2874
+ const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lf), vget_low_s8 (v1_0l));
2875
+ const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lf), vget_high_s8(v1_0l));
2876
+ const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hf), vget_low_s8 (v1_0h));
2877
+ const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hf), vget_high_s8(v1_0h));
2878
+
2879
+ const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lf), vget_low_s8 (v1_1l));
2880
+ const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lf), vget_high_s8(v1_1l));
2881
+ const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hf), vget_low_s8 (v1_1h));
2882
+ const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hf), vget_high_s8(v1_1h));
2883
+
2884
+ const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
2885
+ const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
2886
+ const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
2887
+ const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
2888
+
2889
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
2890
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
2891
+ #endif
683
2892
  }
684
- return (n/QK_K*sizeof(block_q3_K));
685
- }
686
2893
 
687
- // ====================== 4-bit (de)-quantization
2894
+ *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
2895
+ #elif defined(__wasm_simd128__)
2896
+ v128_t sumv = wasm_f32x4_splat(0.0f);
688
2897
 
689
- void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict y, int k) {
690
- assert(k % QK_K == 0);
691
- const int nb = k / QK_K;
2898
+ uint32_t qh;
2899
+ uint64_t tmp[4];
692
2900
 
693
- uint8_t L[QK_K];
694
- uint8_t Laux[32];
695
- float weights[32];
696
- float mins[QK_K/32];
697
- float scales[QK_K/32];
2901
+ // TODO: check if unrolling this is better
2902
+ for (int i = 0; i < nb; ++i) {
2903
+ const block_q5_0 * restrict x0 = &x[i];
2904
+ const block_q8_0 * restrict y0 = &y[i];
2905
+
2906
+ const v128_t m4b = wasm_i8x16_splat(0x0F);
2907
+
2908
+ // extract the 5th bit
2909
+ memcpy(&qh, x0->qh, sizeof(qh));
2910
+
2911
+ tmp[0] = table_b2b_1[(qh >> 0) & 0xFF];
2912
+ tmp[1] = table_b2b_1[(qh >> 8) & 0xFF];
2913
+ tmp[2] = table_b2b_1[(qh >> 16) & 0xFF];
2914
+ tmp[3] = table_b2b_1[(qh >> 24) ];
2915
+
2916
+ const v128_t qhl = wasm_v128_load(tmp + 0);
2917
+ const v128_t qhh = wasm_v128_load(tmp + 2);
2918
+
2919
+ const v128_t v0 = wasm_v128_load(x0->qs);
2920
+
2921
+ // 4-bit -> 8-bit
2922
+ const v128_t v0l = wasm_v128_and (v0, m4b);
2923
+ const v128_t v0h = wasm_u8x16_shr(v0, 4);
2924
+
2925
+ // add high bit and sub 16 (equivalent to sub 0x10 when bit is zero)
2926
+ const v128_t v0lf = wasm_i8x16_sub(v0l, qhl);
2927
+ const v128_t v0hf = wasm_i8x16_sub(v0h, qhh);
2928
+
2929
+ // load y
2930
+ const v128_t v1l = wasm_v128_load(y0->qs);
2931
+ const v128_t v1h = wasm_v128_load(y0->qs + 16);
2932
+
2933
+ // int8x16 -> int16x8
2934
+ const v128_t v0lfl = wasm_i16x8_extend_low_i8x16 (v0lf);
2935
+ const v128_t v0lfh = wasm_i16x8_extend_high_i8x16(v0lf);
2936
+ const v128_t v0hfl = wasm_i16x8_extend_low_i8x16 (v0hf);
2937
+ const v128_t v0hfh = wasm_i16x8_extend_high_i8x16(v0hf);
2938
+
2939
+ const v128_t v1ll = wasm_i16x8_extend_low_i8x16 (v1l);
2940
+ const v128_t v1lh = wasm_i16x8_extend_high_i8x16(v1l);
2941
+ const v128_t v1hl = wasm_i16x8_extend_low_i8x16 (v1h);
2942
+ const v128_t v1hh = wasm_i16x8_extend_high_i8x16(v1h);
2943
+
2944
+ // dot product
2945
+ sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4(
2946
+ wasm_i32x4_add(
2947
+ wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0lfl, v1ll),
2948
+ wasm_i32x4_dot_i16x8(v0lfh, v1lh)),
2949
+ wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0hfl, v1hl),
2950
+ wasm_i32x4_dot_i16x8(v0hfh, v1hh)))),
2951
+ wasm_f32x4_splat(GGML_FP16_TO_FP32(x0->d) * GGML_FP16_TO_FP32(y0->d))));
2952
+ }
2953
+
2954
+ *s = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) +
2955
+ wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3);
2956
+ #elif defined(__AVX2__)
2957
+ // Initialize accumulator with zeros
2958
+ __m256 acc = _mm256_setzero_ps();
698
2959
 
2960
+ // Main loop
699
2961
  for (int i = 0; i < nb; i++) {
2962
+ /* Compute combined scale for the block */
2963
+ const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d));
700
2964
 
701
- float max_scale = 0; // as we are deducting the min, scales are always positive
702
- float max_min = 0;
703
- for (int j = 0; j < QK_K/32; ++j) {
704
- //scales[j] = make_qkx1_quants(32, 15, x + 32*j, L + 32*j, &mins[j], 9, 0.5f);
705
- float sum_x2 = 0;
706
- for (int l = 0; l < 32; ++l) sum_x2 += x[32*j + l] * x[32*j + l];
707
- float av_x = sqrtf(sum_x2/32);
708
- for (int l = 0; l < 32; ++l) weights[l] = av_x + fabsf(x[32*j + l]);
709
- scales[j] = make_qkx2_quants(32, 15, x + 32*j, weights, L + 32*j, &mins[j], Laux, -1.f, 0.1f, 20, false);
710
- float scale = scales[j];
711
- if (scale > max_scale) {
712
- max_scale = scale;
713
- }
714
- float min = mins[j];
715
- if (min > max_min) {
716
- max_min = min;
717
- }
718
- }
2965
+ __m256i bx = bytes_from_nibbles_32(x[i].qs);
2966
+ __m256i bxhi = bytes_from_bits_32(x[i].qh);
2967
+ bxhi = _mm256_andnot_si256(bxhi, _mm256_set1_epi8((char)0xF0));
2968
+ bx = _mm256_or_si256(bx, bxhi);
719
2969
 
720
- #if QK_K == 256
721
- float inv_scale = max_scale > 0 ? 63.f/max_scale : 0.f;
722
- float inv_min = max_min > 0 ? 63.f/max_min : 0.f;
723
- for (int j = 0; j < QK_K/32; ++j) {
724
- uint8_t ls = nearest_int(inv_scale*scales[j]);
725
- uint8_t lm = nearest_int(inv_min*mins[j]);
726
- ls = MIN(63, ls);
727
- lm = MIN(63, lm);
728
- if (j < 4) {
729
- y[i].scales[j] = ls;
730
- y[i].scales[j+4] = lm;
731
- } else {
732
- y[i].scales[j+4] = (ls & 0xF) | ((lm & 0xF) << 4);
733
- y[i].scales[j-4] |= ((ls >> 4) << 6);
734
- y[i].scales[j-0] |= ((lm >> 4) << 6);
735
- }
736
- }
737
- y[i].d = ggml_fp32_to_fp16(max_scale/63.f);
738
- y[i].dmin = ggml_fp32_to_fp16(max_min/63.f);
2970
+ __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
739
2971
 
740
- uint8_t sc, m;
741
- for (int j = 0; j < QK_K/32; ++j) {
742
- get_scale_min_k4(j, y[i].scales, &sc, &m);
743
- const float d = ggml_fp16_to_fp32(y[i].d) * sc;
744
- if (!d) continue;
745
- const float dm = ggml_fp16_to_fp32(y[i].dmin) * m;
746
- for (int ii = 0; ii < 32; ++ii) {
747
- int l = nearest_int((x[32*j + ii] + dm)/d);
748
- l = MAX(0, MIN(15, l));
749
- L[32*j + ii] = l;
750
- }
751
- }
752
- #else
753
- const float s_factor = 15.f;
754
- float inv_scale = max_scale > 0 ? s_factor/max_scale : 0.f;
755
- float inv_min = max_min > 0 ? s_factor/max_min : 0.f;
756
- int d1 = nearest_int(inv_scale*scales[0]);
757
- int m1 = nearest_int(inv_min*mins[0]);
758
- int d2 = nearest_int(inv_scale*scales[1]);
759
- int m2 = nearest_int(inv_min*mins[1]);
760
- y[i].scales[0] = d1 | (m1 << 4);
761
- y[i].scales[1] = d2 | (m2 << 4);
762
- y[i].d[0] = ggml_fp32_to_fp16(max_scale/s_factor);
763
- y[i].d[1] = ggml_fp32_to_fp16(max_min/s_factor);
2972
+ const __m256 q = mul_sum_i8_pairs_float(bx, by);
764
2973
 
765
- float sumlx = 0;
766
- int suml2 = 0;
767
- for (int j = 0; j < QK_K/32; ++j) {
768
- const uint8_t sd = y[i].scales[j] & 0xF;
769
- const uint8_t sm = y[i].scales[j] >> 4;
770
- const float d = ggml_fp16_to_fp32(y[i].d[0]) * sd;
771
- if (!d) continue;
772
- const float m = ggml_fp16_to_fp32(y[i].d[1]) * sm;
773
- for (int ii = 0; ii < 32; ++ii) {
774
- int l = nearest_int((x[32*j + ii] + m)/d);
775
- l = MAX(0, MIN(15, l));
776
- L[32*j + ii] = l;
777
- sumlx += (x[32*j + ii] + m)*l*sd;
778
- suml2 += l*l*sd*sd;
779
- }
780
- }
781
- if (suml2) {
782
- y[i].d[0] = ggml_fp32_to_fp16(sumlx/suml2);
783
- }
784
- #endif
785
- uint8_t * q = y[i].qs;
786
- for (int j = 0; j < QK_K; j += 64) {
787
- for (int l = 0; l < 32; ++l) q[l] = L[j + l] | (L[j + l + 32] << 4);
788
- q += 32;
789
- }
2974
+ /* Multiply q with scale and accumulate */
2975
+ acc = _mm256_fmadd_ps(d, q, acc);
2976
+ }
2977
+
2978
+ *s = hsum_float_8(acc);
2979
+ #elif defined(__AVX__)
2980
+ // Initialize accumulator with zeros
2981
+ __m256 acc = _mm256_setzero_ps();
2982
+ __m128i mask = _mm_set1_epi8((char)0xF0);
2983
+
2984
+ // Main loop
2985
+ for (int i = 0; i < nb; i++) {
2986
+ /* Compute combined scale for the block */
2987
+ const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d));
2988
+
2989
+ __m256i bx = bytes_from_nibbles_32(x[i].qs);
2990
+ const __m256i bxhi = bytes_from_bits_32(x[i].qh);
2991
+ __m128i bxhil = _mm256_castsi256_si128(bxhi);
2992
+ __m128i bxhih = _mm256_extractf128_si256(bxhi, 1);
2993
+ bxhil = _mm_andnot_si128(bxhil, mask);
2994
+ bxhih = _mm_andnot_si128(bxhih, mask);
2995
+ __m128i bxl = _mm256_castsi256_si128(bx);
2996
+ __m128i bxh = _mm256_extractf128_si256(bx, 1);
2997
+ bxl = _mm_or_si128(bxl, bxhil);
2998
+ bxh = _mm_or_si128(bxh, bxhih);
2999
+ bx = MM256_SET_M128I(bxh, bxl);
3000
+
3001
+ const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
3002
+
3003
+ const __m256 q = mul_sum_i8_pairs_float(bx, by);
3004
+
3005
+ /* Multiply q with scale and accumulate */
3006
+ acc = _mm256_add_ps(_mm256_mul_ps(d, q), acc);
3007
+ }
790
3008
 
791
- x += QK_K;
3009
+ *s = hsum_float_8(acc);
3010
+ #elif defined(__riscv_v_intrinsic)
3011
+ float sumf = 0.0;
792
3012
 
793
- }
794
- }
3013
+ uint32_t qh;
795
3014
 
796
- void dequantize_row_q4_K(const block_q4_K * restrict x, float * restrict y, int k) {
797
- assert(k % QK_K == 0);
798
- const int nb = k / QK_K;
3015
+ size_t vl = __riscv_vsetvl_e8m1(qk/2);
3016
+
3017
+ // These temporary registers are for masking and shift operations
3018
+ vuint32m2_t vt_1 = __riscv_vid_v_u32m2(vl);
3019
+ vuint32m2_t vt_2 = __riscv_vsll_vv_u32m2(__riscv_vmv_v_x_u32m2(1, vl), vt_1, vl);
3020
+
3021
+ vuint32m2_t vt_3 = __riscv_vsll_vx_u32m2(vt_2, 16, vl);
3022
+ vuint32m2_t vt_4 = __riscv_vadd_vx_u32m2(vt_1, 12, vl);
799
3023
 
800
3024
  for (int i = 0; i < nb; i++) {
3025
+ memcpy(&qh, x[i].qh, sizeof(uint32_t));
801
3026
 
802
- const uint8_t * q = x[i].qs;
3027
+ // ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4;
3028
+ vuint32m2_t xha_0 = __riscv_vand_vx_u32m2(vt_2, qh, vl);
3029
+ vuint32m2_t xhr_0 = __riscv_vsrl_vv_u32m2(xha_0, vt_1, vl);
3030
+ vuint32m2_t xhl_0 = __riscv_vsll_vx_u32m2(xhr_0, 4, vl);
803
3031
 
804
- #if QK_K == 256
3032
+ // ((qh & (1u << (j + 16))) >> (j + 12));
3033
+ vuint32m2_t xha_1 = __riscv_vand_vx_u32m2(vt_3, qh, vl);
3034
+ vuint32m2_t xhl_1 = __riscv_vsrl_vv_u32m2(xha_1, vt_4, vl);
805
3035
 
806
- const float d = ggml_fp16_to_fp32(x[i].d);
807
- const float min = ggml_fp16_to_fp32(x[i].dmin);
3036
+ // narrowing
3037
+ vuint16m1_t xhc_0 = __riscv_vncvt_x_x_w_u16m1(xhl_0, vl);
3038
+ vuint8mf2_t xh_0 = __riscv_vncvt_x_x_w_u8mf2(xhc_0, vl);
808
3039
 
809
- int is = 0;
810
- uint8_t sc, m;
811
- for (int j = 0; j < QK_K; j += 64) {
812
- get_scale_min_k4(is + 0, x[i].scales, &sc, &m);
813
- const float d1 = d * sc; const float m1 = min * m;
814
- get_scale_min_k4(is + 1, x[i].scales, &sc, &m);
815
- const float d2 = d * sc; const float m2 = min * m;
816
- for (int l = 0; l < 32; ++l) *y++ = d1 * (q[l] & 0xF) - m1;
817
- for (int l = 0; l < 32; ++l) *y++ = d2 * (q[l] >> 4) - m2;
818
- q += 32; is += 2;
819
- }
3040
+ vuint16m1_t xhc_1 = __riscv_vncvt_x_x_w_u16m1(xhl_1, vl);
3041
+ vuint8mf2_t xh_1 = __riscv_vncvt_x_x_w_u8mf2(xhc_1, vl);
3042
+
3043
+ // load
3044
+ vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[i].qs, vl);
3045
+
3046
+ vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[i].qs, vl);
3047
+ vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[i].qs+16, vl);
3048
+
3049
+ vuint8mf2_t x_at = __riscv_vand_vx_u8mf2(tx, 0x0F, vl);
3050
+ vuint8mf2_t x_lt = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl);
3051
+
3052
+ vuint8mf2_t x_a = __riscv_vor_vv_u8mf2(x_at, xh_0, vl);
3053
+ vuint8mf2_t x_l = __riscv_vor_vv_u8mf2(x_lt, xh_1, vl);
3054
+
3055
+ vint8mf2_t x_ai = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a);
3056
+ vint8mf2_t x_li = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l);
3057
+
3058
+ vint8mf2_t v0 = __riscv_vsub_vx_i8mf2(x_ai, 16, vl);
3059
+ vint8mf2_t v1 = __riscv_vsub_vx_i8mf2(x_li, 16, vl);
3060
+
3061
+ vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl);
3062
+ vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl);
3063
+
3064
+ vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
3065
+
3066
+ vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl);
3067
+ vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl);
3068
+
3069
+ int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);
3070
+
3071
+ sumf += (GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d)) * sumi;
3072
+ }
3073
+
3074
+ *s = sumf;
820
3075
  #else
821
- const float dall = ggml_fp16_to_fp32(x[i].d[0]);
822
- const float mall = ggml_fp16_to_fp32(x[i].d[1]);
823
- const float d1 = dall * (x[i].scales[0] & 0xF), m1 = mall * (x[i].scales[0] >> 4);
824
- const float d2 = dall * (x[i].scales[1] & 0xF), m2 = mall * (x[i].scales[1] >> 4);
825
- for (int l = 0; l < 32; ++l) {
826
- y[l+ 0] = d1 * (q[l] & 0xF) - m1;
827
- y[l+32] = d2 * (q[l] >> 4) - m2;
3076
+ // scalar
3077
+ float sumf = 0.0;
3078
+
3079
+ for (int i = 0; i < nb; i++) {
3080
+ uint32_t qh;
3081
+ memcpy(&qh, x[i].qh, sizeof(qh));
3082
+
3083
+ int sumi = 0;
3084
+
3085
+ for (int j = 0; j < qk/2; ++j) {
3086
+ const uint8_t xh_0 = ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4;
3087
+ const uint8_t xh_1 = ((qh & (1u << (j + 16))) >> (j + 12));
3088
+
3089
+ const int32_t x0 = ((x[i].qs[j] & 0x0F) | xh_0) - 16;
3090
+ const int32_t x1 = ((x[i].qs[j] >> 4) | xh_1) - 16;
3091
+
3092
+ sumi += (x0 * y[i].qs[j]) + (x1 * y[i].qs[j + qk/2]);
828
3093
  }
829
- y += QK_K;
830
- #endif
831
3094
 
3095
+ sumf += (GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d)) * sumi;
832
3096
  }
833
- }
834
3097
 
835
- void quantize_row_q4_K(const float * restrict x, void * restrict vy, int k) {
836
- assert(k % QK_K == 0);
837
- block_q4_K * restrict y = vy;
838
- quantize_row_q4_K_reference(x, y, k);
3098
+ *s = sumf;
3099
+ #endif
839
3100
  }
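The scalar branch of ggml_vec_dot_q5_0_q8_0 above spells out the Q5_0 packing: each byte of `qs` holds two low nibbles (elements j and j+16), the 32 fifth bits of the block sit in the 32-bit `qh`, and the reassembled 5-bit value is re-centered by subtracting 16. As a rough reference only (plain arrays instead of block_q5_0, block size 32 as implied by the loops, hypothetical helper name), the same unpacking written as a tiny dequantizer:

    #include <stdint.h>
    #include <stdio.h>

    // rebuild one 32-element Q5_0 block: out[j] = d * (signed 5-bit value)
    static void dequant_q5_0_block(const uint8_t qs[16], uint32_t qh, float d, float out[32]) {
        for (int j = 0; j < 16; ++j) {
            // fifth bit of element j and of element j+16, moved into bit position 4
            const uint8_t xh_0 = (uint8_t)(((qh >> (j +  0)) << 4) & 0x10);
            const uint8_t xh_1 = (uint8_t)(((qh >> (j + 16)) << 4) & 0x10);
            // low nibble holds elements 0..15, high nibble holds elements 16..31
            const int32_t x0 = (int32_t)((qs[j] & 0x0F) | xh_0) - 16;   // range [-16, 15]
            const int32_t x1 = (int32_t)((qs[j] >>   4) | xh_1) - 16;
            out[j]      = d * x0;
            out[j + 16] = d * x1;
        }
    }

    int main(void) {
        uint8_t qs[16]; float out[32];
        for (int j = 0; j < 16; ++j) qs[j] = (uint8_t)(j * 0x11);   // arbitrary test pattern
        dequant_q5_0_block(qs, 0xA5A5F00Fu, 0.01f, out);
        printf("%f %f\n", out[0], out[16]);
        return 0;
    }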
840
3101
 
841
- size_t ggml_quantize_q4_K(const float * restrict src, void * restrict dst, int n, int k, int64_t * restrict hist) {
842
- assert(k % QK_K == 0);
843
- (void)hist; // TODO: collect histograms
3102
+ void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
3103
+ const int qk = QK8_1;
3104
+ const int nb = n / qk;
844
3105
 
845
- for (int j = 0; j < n; j += k) {
846
- block_q4_K * restrict y = (block_q4_K *)dst + j/QK_K;
847
- quantize_row_q4_K_reference(src + j, y, k);
848
- }
849
- return (n/QK_K*sizeof(block_q4_K));
850
- }
3106
+ assert(n % qk == 0);
3107
+ assert(qk == QK5_1);
851
3108
 
852
- // ====================== 5-bit (de)-quantization
3109
+ const block_q5_1 * restrict x = vx;
3110
+ const block_q8_1 * restrict y = vy;
853
3111
 
854
- void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict y, int k) {
855
- assert(k % QK_K == 0);
856
- const int nb = k / QK_K;
3112
+ #if defined(__ARM_NEON)
3113
+ float32x4_t sumv0 = vdupq_n_f32(0.0f);
3114
+ float32x4_t sumv1 = vdupq_n_f32(0.0f);
857
3115
 
858
- #if QK_K == 256
859
- uint8_t L[QK_K];
860
- float mins[QK_K/32];
861
- float scales[QK_K/32];
862
- float weights[32];
863
- uint8_t Laux[32];
3116
+ float summs0 = 0.0f;
3117
+ float summs1 = 0.0f;
3118
+
3119
+ uint32_t qh0;
3120
+ uint32_t qh1;
3121
+
3122
+ uint64_t tmp0[4];
3123
+ uint64_t tmp1[4];
3124
+
3125
+ assert(nb % 2 == 0); // TODO: handle odd nb
3126
+
3127
+ for (int i = 0; i < nb; i += 2) {
3128
+ const block_q5_1 * restrict x0 = &x[i];
3129
+ const block_q5_1 * restrict x1 = &x[i + 1];
3130
+ const block_q8_1 * restrict y0 = &y[i];
3131
+ const block_q8_1 * restrict y1 = &y[i + 1];
3132
+
3133
+ const uint8x16_t m4b = vdupq_n_u8(0x0F);
3134
+
3135
+ summs0 += GGML_FP16_TO_FP32(x0->m) * y0->s;
3136
+ summs1 += GGML_FP16_TO_FP32(x1->m) * y1->s;
3137
+
3138
+ // extract the 5th bit via lookup table ((b) << 4)
3139
+ memcpy(&qh0, x0->qh, sizeof(qh0));
3140
+ memcpy(&qh1, x1->qh, sizeof(qh1));
3141
+
3142
+ tmp0[0] = table_b2b_0[(qh0 >> 0) & 0xFF];
3143
+ tmp0[1] = table_b2b_0[(qh0 >> 8) & 0xFF];
3144
+ tmp0[2] = table_b2b_0[(qh0 >> 16) & 0xFF];
3145
+ tmp0[3] = table_b2b_0[(qh0 >> 24) ];
3146
+
3147
+ tmp1[0] = table_b2b_0[(qh1 >> 0) & 0xFF];
3148
+ tmp1[1] = table_b2b_0[(qh1 >> 8) & 0xFF];
3149
+ tmp1[2] = table_b2b_0[(qh1 >> 16) & 0xFF];
3150
+ tmp1[3] = table_b2b_0[(qh1 >> 24) ];
3151
+
3152
+ const int8x16_t qhl0 = vld1q_s8((const int8_t *)(tmp0 + 0));
3153
+ const int8x16_t qhh0 = vld1q_s8((const int8_t *)(tmp0 + 2));
3154
+ const int8x16_t qhl1 = vld1q_s8((const int8_t *)(tmp1 + 0));
3155
+ const int8x16_t qhh1 = vld1q_s8((const int8_t *)(tmp1 + 2));
3156
+
3157
+ const uint8x16_t v0_0 = vld1q_u8(x0->qs);
3158
+ const uint8x16_t v0_1 = vld1q_u8(x1->qs);
3159
+
3160
+ // 4-bit -> 8-bit
3161
+ const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
3162
+ const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
3163
+ const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
3164
+ const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
3165
+
3166
+ // add high bit
3167
+ const int8x16_t v0_0lf = vorrq_s8(v0_0l, qhl0);
3168
+ const int8x16_t v0_0hf = vorrq_s8(v0_0h, qhh0);
3169
+ const int8x16_t v0_1lf = vorrq_s8(v0_1l, qhl1);
3170
+ const int8x16_t v0_1hf = vorrq_s8(v0_1h, qhh1);
3171
+
3172
+ // load y
3173
+ const int8x16_t v1_0l = vld1q_s8(y0->qs);
3174
+ const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
3175
+ const int8x16_t v1_1l = vld1q_s8(y1->qs);
3176
+ const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
3177
+
3178
+ #if defined(__ARM_FEATURE_DOTPROD)
3179
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
3180
+ vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l),
3181
+ vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), GGML_FP16_TO_FP32(x0->d)*y0->d);
3182
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
3183
+ vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l),
3184
+ vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), GGML_FP16_TO_FP32(x1->d)*y1->d);
864
3185
  #else
865
- int8_t L[QK_K];
866
- float scales[QK_K/16];
3186
+ const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lf), vget_low_s8 (v1_0l));
3187
+ const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lf), vget_high_s8(v1_0l));
3188
+ const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hf), vget_low_s8 (v1_0h));
3189
+ const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hf), vget_high_s8(v1_0h));
3190
+
3191
+ const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lf), vget_low_s8 (v1_1l));
3192
+ const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lf), vget_high_s8(v1_1l));
3193
+ const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hf), vget_low_s8 (v1_1h));
3194
+ const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hf), vget_high_s8(v1_1h));
3195
+
3196
+ const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
3197
+ const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
3198
+ const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
3199
+ const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
3200
+
3201
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), GGML_FP16_TO_FP32(x0->d)*y0->d);
3202
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), GGML_FP16_TO_FP32(x1->d)*y1->d);
867
3203
  #endif
3204
+ }
868
3205
 
869
- for (int i = 0; i < nb; i++) {
3206
+ *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs0 + summs1;
3207
+ #elif defined(__wasm_simd128__)
3208
+ v128_t sumv = wasm_f32x4_splat(0.0f);
870
3209
 
871
- #if QK_K == 256
3210
+ float summs = 0.0f;
872
3211
 
873
- float max_scale = 0; // as we are deducting the min, scales are always positive
874
- float max_min = 0;
875
- for (int j = 0; j < QK_K/32; ++j) {
876
- //scales[j] = make_qkx1_quants(32, 31, x + 32*j, L + 32*j, &mins[j], 9, 0.5f);
877
- float sum_x2 = 0;
878
- for (int l = 0; l < 32; ++l) sum_x2 += x[32*j + l] * x[32*j + l];
879
- float av_x = sqrtf(sum_x2/32);
880
- for (int l = 0; l < 32; ++l) weights[l] = av_x + fabsf(x[32*j + l]);
881
- scales[j] = make_qkx2_quants(32, 31, x + 32*j, weights, L + 32*j, &mins[j], Laux, -0.5f, 0.1f, 15, false);
882
- float scale = scales[j];
883
- if (scale > max_scale) {
884
- max_scale = scale;
885
- }
886
- float min = mins[j];
887
- if (min > max_min) {
888
- max_min = min;
889
- }
890
- }
3212
+ uint32_t qh;
3213
+ uint64_t tmp[4];
891
3214
 
892
- float inv_scale = max_scale > 0 ? 63.f/max_scale : 0.f;
893
- float inv_min = max_min > 0 ? 63.f/max_min : 0.f;
894
- for (int j = 0; j < QK_K/32; ++j) {
895
- uint8_t ls = nearest_int(inv_scale*scales[j]);
896
- uint8_t lm = nearest_int(inv_min*mins[j]);
897
- ls = MIN(63, ls);
898
- lm = MIN(63, lm);
899
- if (j < 4) {
900
- y[i].scales[j] = ls;
901
- y[i].scales[j+4] = lm;
902
- } else {
903
- y[i].scales[j+4] = (ls & 0xF) | ((lm & 0xF) << 4);
904
- y[i].scales[j-4] |= ((ls >> 4) << 6);
905
- y[i].scales[j-0] |= ((lm >> 4) << 6);
906
- }
907
- }
908
- y[i].d = ggml_fp32_to_fp16(max_scale/63.f);
909
- y[i].dmin = ggml_fp32_to_fp16(max_min/63.f);
3215
+ // TODO: check if unrolling this is better
3216
+ for (int i = 0; i < nb; ++i) {
3217
+ const block_q5_1 * restrict x0 = &x[i];
3218
+ const block_q8_1 * restrict y0 = &y[i];
910
3219
 
911
- uint8_t sc, m;
912
- for (int j = 0; j < QK_K/32; ++j) {
913
- get_scale_min_k4(j, y[i].scales, &sc, &m);
914
- const float d = ggml_fp16_to_fp32(y[i].d) * sc;
915
- if (!d) continue;
916
- const float dm = ggml_fp16_to_fp32(y[i].dmin) * m;
917
- for (int ii = 0; ii < 32; ++ii) {
918
- int l = nearest_int((x[32*j + ii] + dm)/d);
919
- l = MAX(0, MIN(31, l));
920
- L[32*j + ii] = l;
921
- }
922
- }
3220
+ summs += GGML_FP16_TO_FP32(x0->m) * y0->s;
923
3221
 
924
- uint8_t * restrict qh = y[i].qh;
925
- uint8_t * restrict ql = y[i].qs;
926
- memset(qh, 0, QK_K/8);
3222
+ const v128_t m4b = wasm_i8x16_splat(0x0F);
927
3223
 
928
- uint8_t m1 = 1, m2 = 2;
929
- for (int n = 0; n < QK_K; n += 64) {
930
- for (int j = 0; j < 32; ++j) {
931
- int l1 = L[n + j];
932
- if (l1 > 15) {
933
- l1 -= 16; qh[j] |= m1;
934
- }
935
- int l2 = L[n + j + 32];
936
- if (l2 > 15) {
937
- l2 -= 16; qh[j] |= m2;
938
- }
939
- ql[j] = l1 | (l2 << 4);
940
- }
941
- m1 <<= 2; m2 <<= 2;
942
- ql += 32;
943
- }
944
- #else
945
- float max_scale = 0, amax = 0;
946
- for (int j = 0; j < QK_K/16; ++j) {
947
- scales[j] = make_qx_quants(16, 16, x + 16*j, L + 16*j, 1);
948
- float abs_scale = fabsf(scales[j]);
949
- if (abs_scale > amax) {
950
- amax = abs_scale;
951
- max_scale = scales[j];
952
- }
953
- }
3224
+ // extract the 5th bit
3225
+ memcpy(&qh, x0->qh, sizeof(qh));
954
3226
 
955
- float iscale = -128.f/max_scale;
956
- for (int j = 0; j < QK_K/16; ++j) {
957
- int l = nearest_int(iscale*scales[j]);
958
- y[i].scales[j] = MAX(-128, MIN(127, l));
959
- }
960
- y[i].d = ggml_fp32_to_fp16(1/iscale);
3227
+ tmp[0] = table_b2b_0[(qh >> 0) & 0xFF];
3228
+ tmp[1] = table_b2b_0[(qh >> 8) & 0xFF];
3229
+ tmp[2] = table_b2b_0[(qh >> 16) & 0xFF];
3230
+ tmp[3] = table_b2b_0[(qh >> 24) ];
961
3231
 
962
- for (int j = 0; j < QK_K/16; ++j) {
963
- const float d = ggml_fp16_to_fp32(y[i].d) * y[i].scales[j];
964
- if (!d) continue;
965
- for (int ii = 0; ii < 16; ++ii) {
966
- int l = nearest_int(x[16*j + ii]/d);
967
- l = MAX(-16, MIN(15, l));
968
- L[16*j + ii] = l + 16;
969
- }
970
- }
3232
+ const v128_t qhl = wasm_v128_load(tmp + 0);
3233
+ const v128_t qhh = wasm_v128_load(tmp + 2);
971
3234
 
972
- uint8_t * restrict qh = y[i].qh;
973
- uint8_t * restrict ql = y[i].qs;
974
- memset(qh, 0, QK_K/8);
3235
+ const v128_t v0 = wasm_v128_load(x0->qs);
975
3236
 
976
- for (int j = 0; j < 32; ++j) {
977
- int jm = j%8;
978
- int is = j/8;
979
- int l1 = L[j];
980
- if (l1 > 15) {
981
- l1 -= 16; qh[jm] |= (1 << is);
982
- }
983
- int l2 = L[j + 32];
984
- if (l2 > 15) {
985
- l2 -= 16; qh[jm] |= (1 << (4 + is));
986
- }
987
- ql[j] = l1 | (l2 << 4);
988
- }
989
- #endif
3237
+ // 4-bit -> 8-bit
3238
+ const v128_t v0l = wasm_v128_and (v0, m4b);
3239
+ const v128_t v0h = wasm_u8x16_shr(v0, 4);
990
3240
 
991
- x += QK_K;
3241
+ // add high bit
3242
+ const v128_t v0lf = wasm_v128_or(v0l, qhl);
3243
+ const v128_t v0hf = wasm_v128_or(v0h, qhh);
3244
+
3245
+ // load y
3246
+ const v128_t v1l = wasm_v128_load(y0->qs);
3247
+ const v128_t v1h = wasm_v128_load(y0->qs + 16);
3248
+
3249
+ // int8x16 -> int16x8
3250
+ const v128_t v0lfl = wasm_i16x8_extend_low_i8x16 (v0lf);
3251
+ const v128_t v0lfh = wasm_i16x8_extend_high_i8x16(v0lf);
3252
+ const v128_t v0hfl = wasm_i16x8_extend_low_i8x16 (v0hf);
3253
+ const v128_t v0hfh = wasm_i16x8_extend_high_i8x16(v0hf);
992
3254
 
3255
+ const v128_t v1ll = wasm_i16x8_extend_low_i8x16 (v1l);
3256
+ const v128_t v1lh = wasm_i16x8_extend_high_i8x16(v1l);
3257
+ const v128_t v1hl = wasm_i16x8_extend_low_i8x16 (v1h);
3258
+ const v128_t v1hh = wasm_i16x8_extend_high_i8x16(v1h);
3259
+
3260
+ // dot product
3261
+ sumv = wasm_f32x4_add(sumv,
3262
+ wasm_f32x4_mul(wasm_f32x4_convert_i32x4(wasm_i32x4_add(
3263
+ wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0lfl, v1ll),
3264
+ wasm_i32x4_dot_i16x8(v0lfh, v1lh)),
3265
+ wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0hfl, v1hl),
3266
+ wasm_i32x4_dot_i16x8(v0hfh, v1hh)))),
3267
+ wasm_f32x4_splat(GGML_FP16_TO_FP32(x0->d) * y0->d)));
993
3268
  }
994
- }
995
3269
 
996
- void dequantize_row_q5_K(const block_q5_K * restrict x, float * restrict y, int k) {
997
- assert(k % QK_K == 0);
998
- const int nb = k / QK_K;
3270
+ *s = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) +
3271
+ wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3) + summs;
3272
+ #elif defined(__AVX2__)
3273
+ // Initialize accumulator with zeros
3274
+ __m256 acc = _mm256_setzero_ps();
999
3275
 
3276
+ float summs = 0.0f;
3277
+
3278
+ // Main loop
1000
3279
  for (int i = 0; i < nb; i++) {
3280
+ const __m256 dx = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d));
1001
3281
 
1002
- const uint8_t * ql = x[i].qs;
1003
- const uint8_t * qh = x[i].qh;
3282
+ summs += GGML_FP16_TO_FP32(x[i].m) * y[i].s;
1004
3283
 
1005
- #if QK_K == 256
3284
+ __m256i bx = bytes_from_nibbles_32(x[i].qs);
3285
+ __m256i bxhi = bytes_from_bits_32(x[i].qh);
3286
+ bxhi = _mm256_and_si256(bxhi, _mm256_set1_epi8(0x10));
3287
+ bx = _mm256_or_si256(bx, bxhi);
1006
3288
 
1007
- const float d = ggml_fp16_to_fp32(x[i].d);
1008
- const float min = ggml_fp16_to_fp32(x[i].dmin);
3289
+ const __m256 dy = _mm256_set1_ps(y[i].d);
3290
+ const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
1009
3291
 
1010
- int is = 0;
1011
- uint8_t sc, m;
1012
- uint8_t u1 = 1, u2 = 2;
1013
- for (int j = 0; j < QK_K; j += 64) {
1014
- get_scale_min_k4(is + 0, x[i].scales, &sc, &m);
1015
- const float d1 = d * sc; const float m1 = min * m;
1016
- get_scale_min_k4(is + 1, x[i].scales, &sc, &m);
1017
- const float d2 = d * sc; const float m2 = min * m;
1018
- for (int l = 0; l < 32; ++l) *y++ = d1 * ((ql[l] & 0xF) + (qh[l] & u1 ? 16 : 0)) - m1;
1019
- for (int l = 0; l < 32; ++l) *y++ = d2 * ((ql[l] >> 4) + (qh[l] & u2 ? 16 : 0)) - m2;
1020
- ql += 32; is += 2;
1021
- u1 <<= 2; u2 <<= 2;
1022
- }
1023
- #else
1024
- float d = ggml_fp16_to_fp32(x[i].d);
1025
- const int8_t * restrict s = x[i].scales;
1026
- for (int l = 0; l < 8; ++l) {
1027
- y[l+ 0] = d * s[0] * ((ql[l+ 0] & 0xF) - (qh[l] & 0x01 ? 0 : 16));
1028
- y[l+ 8] = d * s[0] * ((ql[l+ 8] & 0xF) - (qh[l] & 0x02 ? 0 : 16));
1029
- y[l+16] = d * s[1] * ((ql[l+16] & 0xF) - (qh[l] & 0x04 ? 0 : 16));
1030
- y[l+24] = d * s[1] * ((ql[l+24] & 0xF) - (qh[l] & 0x08 ? 0 : 16));
1031
- y[l+32] = d * s[2] * ((ql[l+ 0] >> 4) - (qh[l] & 0x10 ? 0 : 16));
1032
- y[l+40] = d * s[2] * ((ql[l+ 8] >> 4) - (qh[l] & 0x20 ? 0 : 16));
1033
- y[l+48] = d * s[3] * ((ql[l+16] >> 4) - (qh[l] & 0x40 ? 0 : 16));
1034
- y[l+56] = d * s[3] * ((ql[l+24] >> 4) - (qh[l] & 0x80 ? 0 : 16));
1035
- }
1036
- y += QK_K;
1037
- #endif
3292
+ const __m256 q = mul_sum_us8_pairs_float(bx, by);
3293
+
3294
+ acc = _mm256_fmadd_ps(q, _mm256_mul_ps(dx, dy), acc);
1038
3295
  }
1039
- }
1040
3296
 
1041
- void quantize_row_q5_K(const float * restrict x, void * restrict vy, int k) {
1042
- assert(k % QK_K == 0);
1043
- block_q5_K * restrict y = vy;
1044
- quantize_row_q5_K_reference(x, y, k);
1045
- }
3297
+ *s = hsum_float_8(acc) + summs;
3298
+ #elif defined(__AVX__)
3299
+ // Initialize accumulator with zeros
3300
+ __m256 acc = _mm256_setzero_ps();
3301
+ __m128i mask = _mm_set1_epi8(0x10);
3302
+
3303
+ float summs = 0.0f;
3304
+
3305
+ // Main loop
3306
+ for (int i = 0; i < nb; i++) {
3307
+ const __m256 dx = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d));
3308
+
3309
+ summs += GGML_FP16_TO_FP32(x[i].m) * y[i].s;
3310
+
3311
+ __m256i bx = bytes_from_nibbles_32(x[i].qs);
3312
+ const __m256i bxhi = bytes_from_bits_32(x[i].qh);
3313
+ __m128i bxhil = _mm256_castsi256_si128(bxhi);
3314
+ __m128i bxhih = _mm256_extractf128_si256(bxhi, 1);
3315
+ bxhil = _mm_and_si128(bxhil, mask);
3316
+ bxhih = _mm_and_si128(bxhih, mask);
3317
+ __m128i bxl = _mm256_castsi256_si128(bx);
3318
+ __m128i bxh = _mm256_extractf128_si256(bx, 1);
3319
+ bxl = _mm_or_si128(bxl, bxhil);
3320
+ bxh = _mm_or_si128(bxh, bxhih);
3321
+ bx = MM256_SET_M128I(bxh, bxl);
3322
+
3323
+ const __m256 dy = _mm256_set1_ps(y[i].d);
3324
+ const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
1046
3325
 
1047
- size_t ggml_quantize_q5_K(const float * restrict src, void * restrict dst, int n, int k, int64_t * restrict hist) {
1048
- assert(k % QK_K == 0);
1049
- (void)hist; // TODO: collect histograms
3326
+ const __m256 q = mul_sum_us8_pairs_float(bx, by);
1050
3327
 
1051
- for (int j = 0; j < n; j += k) {
1052
- block_q5_K * restrict y = (block_q5_K *)dst + j/QK_K;
1053
- quantize_row_q5_K_reference(src + j, y, k);
3328
+ acc = _mm256_add_ps(_mm256_mul_ps(q, _mm256_mul_ps(dx, dy)), acc);
1054
3329
  }
1055
- return (n/QK_K*sizeof(block_q5_K));
1056
- }
1057
3330
 
1058
- // ====================== 6-bit (de)-quantization
3331
+ *s = hsum_float_8(acc) + summs;
3332
+ #elif defined(__riscv_v_intrinsic)
3333
+ float sumf = 0.0;
1059
3334
 
1060
- void quantize_row_q6_K_reference(const float * restrict x, block_q6_K * restrict y, int k) {
1061
- assert(k % QK_K == 0);
1062
- const int nb = k / QK_K;
3335
+ uint32_t qh;
1063
3336
 
1064
- int8_t L[QK_K];
1065
- float scales[QK_K/16];
3337
+ size_t vl = __riscv_vsetvl_e8m1(qk/2);
3338
+
3339
+ // temporary registers for shift operations
3340
+ vuint32m2_t vt_1 = __riscv_vid_v_u32m2(vl);
3341
+ vuint32m2_t vt_2 = __riscv_vadd_vx_u32m2(vt_1, 12, vl);
1066
3342
 
1067
3343
  for (int i = 0; i < nb; i++) {
3344
+ memcpy(&qh, x[i].qh, sizeof(uint32_t));
1068
3345
 
1069
- float max_scale = 0;
1070
- float max_abs_scale = 0;
3346
+ // load qh
3347
+ vuint32m2_t vqh = __riscv_vmv_v_x_u32m2(qh, vl);
1071
3348
 
1072
- for (int ib = 0; ib < QK_K/16; ++ib) {
3349
+ // ((qh >> (j + 0)) << 4) & 0x10;
3350
+ vuint32m2_t xhr_0 = __riscv_vsrl_vv_u32m2(vqh, vt_1, vl);
3351
+ vuint32m2_t xhl_0 = __riscv_vsll_vx_u32m2(xhr_0, 4, vl);
3352
+ vuint32m2_t xha_0 = __riscv_vand_vx_u32m2(xhl_0, 0x10, vl);
1073
3353
 
1074
- const float scale = make_qx_quants(16, 32, x + 16*ib, L + 16*ib, 1);
1075
- scales[ib] = scale;
3354
+ // ((qh >> (j + 12)) ) & 0x10;
3355
+ vuint32m2_t xhr_1 = __riscv_vsrl_vv_u32m2(vqh, vt_2, vl);
3356
+ vuint32m2_t xha_1 = __riscv_vand_vx_u32m2(xhr_1, 0x10, vl);
1076
3357
 
1077
- const float abs_scale = fabsf(scale);
1078
- if (abs_scale > max_abs_scale) {
1079
- max_abs_scale = abs_scale;
1080
- max_scale = scale;
1081
- }
3358
+ // narrowing
3359
+ vuint16m1_t xhc_0 = __riscv_vncvt_x_x_w_u16m1(xha_0, vl);
3360
+ vuint8mf2_t xh_0 = __riscv_vncvt_x_x_w_u8mf2(xhc_0, vl);
1082
3361
 
1083
- }
3362
+ vuint16m1_t xhc_1 = __riscv_vncvt_x_x_w_u16m1(xha_1, vl);
3363
+ vuint8mf2_t xh_1 = __riscv_vncvt_x_x_w_u8mf2(xhc_1, vl);
1084
3364
 
1085
- if (!max_abs_scale) {
1086
- memset(&y[i], 0, sizeof(block_q6_K));
1087
- y[i].d = ggml_fp32_to_fp16(0.f);
1088
- x += QK_K;
1089
- continue;
1090
- }
3365
+ // load
3366
+ vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[i].qs, vl);
1091
3367
 
1092
- float iscale = -128.f/max_scale;
1093
- y[i].d = ggml_fp32_to_fp16(1/iscale);
1094
- for (int ib = 0; ib < QK_K/16; ++ib) {
1095
- y[i].scales[ib] = MIN(127, nearest_int(iscale*scales[ib]));
1096
- }
3368
+ vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[i].qs, vl);
3369
+ vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[i].qs+16, vl);
1097
3370
 
1098
- for (int j = 0; j < QK_K/16; ++j) {
1099
- float d = ggml_fp16_to_fp32(y[i].d) * y[i].scales[j];
1100
- if (!d) {
1101
- continue;
1102
- }
1103
- for (int ii = 0; ii < 16; ++ii) {
1104
- int l = nearest_int(x[16*j + ii]/d);
1105
- l = MAX(-32, MIN(31, l));
1106
- L[16*j + ii] = l + 32;
1107
- }
1108
- }
3371
+ vuint8mf2_t x_at = __riscv_vand_vx_u8mf2(tx, 0x0F, vl);
3372
+ vuint8mf2_t x_lt = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl);
1109
3373
 
1110
- uint8_t * restrict ql = y[i].ql;
1111
- uint8_t * restrict qh = y[i].qh;
1112
- #if QK_K == 256
1113
- for (int j = 0; j < QK_K; j += 128) {
1114
- for (int l = 0; l < 32; ++l) {
1115
- const uint8_t q1 = L[j + l + 0] & 0xF;
1116
- const uint8_t q2 = L[j + l + 32] & 0xF;
1117
- const uint8_t q3 = L[j + l + 64] & 0xF;
1118
- const uint8_t q4 = L[j + l + 96] & 0xF;
1119
- ql[l+ 0] = q1 | (q3 << 4);
1120
- ql[l+32] = q2 | (q4 << 4);
1121
- qh[l] = (L[j + l] >> 4) | ((L[j + l + 32] >> 4) << 2) | ((L[j + l + 64] >> 4) << 4) | ((L[j + l + 96] >> 4) << 6);
1122
- }
1123
- ql += 64;
1124
- qh += 32;
1125
- }
1126
- #else
1127
- for (int l = 0; l < 32; ++l) {
1128
- const uint8_t q1 = L[l + 0] & 0xF;
1129
- const uint8_t q2 = L[l + 32] & 0xF;
1130
- ql[l] = q1 | (q2 << 4);
1131
- }
1132
- for (int l = 0; l < 16; ++l) {
1133
- qh[l] = (L[l] >> 4) | ((L[l + 16] >> 4) << 2) | ((L[l + 32] >> 4) << 4) | ((L[l + 48] >> 4) << 6);
1134
- }
1135
- #endif
3374
+ vuint8mf2_t x_a = __riscv_vor_vv_u8mf2(x_at, xh_0, vl);
3375
+ vuint8mf2_t x_l = __riscv_vor_vv_u8mf2(x_lt, xh_1, vl);
1136
3376
 
1137
- x += QK_K;
3377
+ vint8mf2_t v0 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a);
3378
+ vint8mf2_t v1 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l);
3379
+
3380
+ vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl);
3381
+ vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl);
3382
+
3383
+ vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
1138
3384
 
3385
+ vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl);
3386
+ vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl);
3387
+
3388
+ int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);
3389
+
3390
+ sumf += (GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi + GGML_FP16_TO_FP32(x[i].m)*y[i].s;
1139
3391
  }
1140
- }
1141
3392
 
1142
- void dequantize_row_q6_K(const block_q6_K * restrict x, float * restrict y, int k) {
1143
- assert(k % QK_K == 0);
1144
- const int nb = k / QK_K;
3393
+ *s = sumf;
3394
+ #else
3395
+ // scalar
3396
+ float sumf = 0.0;
1145
3397
 
1146
3398
  for (int i = 0; i < nb; i++) {
3399
+ uint32_t qh;
3400
+ memcpy(&qh, x[i].qh, sizeof(qh));
1147
3401
 
1148
- const float d = ggml_fp16_to_fp32(x[i].d);
3402
+ int sumi = 0;
1149
3403
 
1150
- const uint8_t * restrict ql = x[i].ql;
1151
- const uint8_t * restrict qh = x[i].qh;
1152
- const int8_t * restrict sc = x[i].scales;
3404
+ for (int j = 0; j < qk/2; ++j) {
3405
+ const uint8_t xh_0 = ((qh >> (j + 0)) << 4) & 0x10;
3406
+ const uint8_t xh_1 = ((qh >> (j + 12)) ) & 0x10;
1153
3407
 
1154
- #if QK_K == 256
1155
- for (int n = 0; n < QK_K; n += 128) {
1156
- for (int l = 0; l < 32; ++l) {
1157
- int is = l/16;
1158
- const int8_t q1 = (int8_t)((ql[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
1159
- const int8_t q2 = (int8_t)((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
1160
- const int8_t q3 = (int8_t)((ql[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
1161
- const int8_t q4 = (int8_t)((ql[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
1162
- y[l + 0] = d * sc[is + 0] * q1;
1163
- y[l + 32] = d * sc[is + 2] * q2;
1164
- y[l + 64] = d * sc[is + 4] * q3;
1165
- y[l + 96] = d * sc[is + 6] * q4;
1166
- }
1167
- y += 128;
1168
- ql += 64;
1169
- qh += 32;
1170
- sc += 8;
1171
- }
1172
- #else
1173
- for (int l = 0; l < 16; ++l) {
1174
- const int8_t q1 = (int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
1175
- const int8_t q2 = (int8_t)((ql[l+16] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
1176
- const int8_t q3 = (int8_t)((ql[l+ 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
1177
- const int8_t q4 = (int8_t)((ql[l+16] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
1178
- y[l+ 0] = d * sc[0] * q1;
1179
- y[l+16] = d * sc[1] * q2;
1180
- y[l+32] = d * sc[2] * q3;
1181
- y[l+48] = d * sc[3] * q4;
3408
+ const int32_t x0 = (x[i].qs[j] & 0xF) | xh_0;
3409
+ const int32_t x1 = (x[i].qs[j] >> 4) | xh_1;
3410
+
3411
+ sumi += (x0 * y[i].qs[j]) + (x1 * y[i].qs[j + qk/2]);
1182
3412
  }
1183
- y += 64;
1184
- #endif
1185
3413
 
3414
+ sumf += (GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi + GGML_FP16_TO_FP32(x[i].m)*y[i].s;
1186
3415
  }
1187
- }
1188
3416
 
1189
- void quantize_row_q6_K(const float * restrict x, void * restrict vy, int k) {
1190
- assert(k % QK_K == 0);
1191
- block_q6_K * restrict y = vy;
1192
- quantize_row_q6_K_reference(x, y, k);
3417
+ *s = sumf;
3418
+ #endif
1193
3419
  }
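Both q5 kernels above spend much of their SIMD work on spreading the 32 packed high bits of `qh` across byte lanes. For the q5_1 flavour (the `table_b2b_0` lookup on NEON, or `bytes_from_bits_32` followed by the `0x10` mask on AVX) the result is simply bit j of `qh` moved into bit 4 of byte j, ready to be OR-ed onto a nibble; the q5_0 flavour uses the complemented bit so it can be subtracted instead. A scalar model of the q5_1 case, as an illustrative sketch only with a made-up helper name:

    #include <stdint.h>
    #include <stdio.h>

    // spread 32 bits into 32 bytes of 0x00 / 0x10 (bit j -> bit 4 of byte j)
    static void spread_bits_to_bytes(uint32_t qh, uint8_t out[32]) {
        for (int j = 0; j < 32; ++j) {
            out[j] = (uint8_t)(((qh >> j) & 1u) << 4);
        }
    }

    int main(void) {
        uint8_t hi[32];
        spread_bits_to_bytes(0x80000001u, hi);
        printf("%02x %02x %02x\n", hi[0], hi[1], hi[31]);   // 10 00 10
        return 0;
    }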
1194
3420
 
1195
- size_t ggml_quantize_q6_K(const float * src, void * dst, int n, int k, int64_t * hist) {
1196
- assert(k % QK_K == 0);
1197
- (void)hist; // TODO: collect histograms
3421
+ void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
3422
+ const int qk = QK8_0;
3423
+ const int nb = n / qk;
1198
3424
 
1199
- for (int j = 0; j < n; j += k) {
1200
- block_q6_K * restrict y = (block_q6_K *)dst + j/QK_K;
1201
- quantize_row_q6_K_reference(src + j, y, k);
3425
+ assert(n % qk == 0);
3426
+
3427
+ const block_q8_0 * restrict x = vx;
3428
+ const block_q8_0 * restrict y = vy;
3429
+
3430
+ #if defined(__ARM_NEON)
3431
+ float32x4_t sumv0 = vdupq_n_f32(0.0f);
3432
+ float32x4_t sumv1 = vdupq_n_f32(0.0f);
3433
+
3434
+ assert(nb % 2 == 0); // TODO: handle odd nb
3435
+
3436
+ for (int i = 0; i < nb; i += 2) {
3437
+ const block_q8_0 * restrict x0 = &x[i + 0];
3438
+ const block_q8_0 * restrict x1 = &x[i + 1];
3439
+ const block_q8_0 * restrict y0 = &y[i + 0];
3440
+ const block_q8_0 * restrict y1 = &y[i + 1];
3441
+
3442
+ const int8x16_t x0_0 = vld1q_s8(x0->qs);
3443
+ const int8x16_t x0_1 = vld1q_s8(x0->qs + 16);
3444
+ const int8x16_t x1_0 = vld1q_s8(x1->qs);
3445
+ const int8x16_t x1_1 = vld1q_s8(x1->qs + 16);
3446
+
3447
+ // load y
3448
+ const int8x16_t y0_0 = vld1q_s8(y0->qs);
3449
+ const int8x16_t y0_1 = vld1q_s8(y0->qs + 16);
3450
+ const int8x16_t y1_0 = vld1q_s8(y1->qs);
3451
+ const int8x16_t y1_1 = vld1q_s8(y1->qs + 16);
3452
+
3453
+ #if defined(__ARM_FEATURE_DOTPROD)
3454
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
3455
+ vdotq_s32(vdupq_n_s32(0), x0_0, y0_0),
3456
+ vdotq_s32(vdupq_n_s32(0), x0_1, y0_1))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
3457
+
3458
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
3459
+ vdotq_s32(vdupq_n_s32(0), x1_0, y1_0),
3460
+ vdotq_s32(vdupq_n_s32(0), x1_1, y1_1))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
3461
+
3462
+ #else
3463
+ const int16x8_t p0_0 = vmull_s8(vget_low_s8 (x0_0), vget_low_s8 (y0_0));
3464
+ const int16x8_t p0_1 = vmull_s8(vget_high_s8(x0_0), vget_high_s8(y0_0));
3465
+ const int16x8_t p0_2 = vmull_s8(vget_low_s8 (x0_1), vget_low_s8 (y0_1));
3466
+ const int16x8_t p0_3 = vmull_s8(vget_high_s8(x0_1), vget_high_s8(y0_1));
3467
+
3468
+ const int16x8_t p1_0 = vmull_s8(vget_low_s8 (x1_0), vget_low_s8 (y1_0));
3469
+ const int16x8_t p1_1 = vmull_s8(vget_high_s8(x1_0), vget_high_s8(y1_0));
3470
+ const int16x8_t p1_2 = vmull_s8(vget_low_s8 (x1_1), vget_low_s8 (y1_1));
3471
+ const int16x8_t p1_3 = vmull_s8(vget_high_s8(x1_1), vget_high_s8(y1_1));
3472
+
3473
+ const int32x4_t p0 = vaddq_s32(vpaddlq_s16(p0_0), vpaddlq_s16(p0_1));
3474
+ const int32x4_t p1 = vaddq_s32(vpaddlq_s16(p0_2), vpaddlq_s16(p0_3));
3475
+ const int32x4_t p2 = vaddq_s32(vpaddlq_s16(p1_0), vpaddlq_s16(p1_1));
3476
+ const int32x4_t p3 = vaddq_s32(vpaddlq_s16(p1_2), vpaddlq_s16(p1_3));
3477
+
3478
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(p0, p1)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
3479
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(p2, p3)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
3480
+ #endif
1202
3481
  }
1203
- return (n/QK_K*sizeof(block_q6_K));
1204
- }
1205
3482
 
1206
- //===================================== Q8_K ==============================================
3483
+ *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
3484
+ #elif defined(__AVX2__) || defined(__AVX__)
3485
+ // Initialize accumulator with zeros
3486
+ __m256 acc = _mm256_setzero_ps();
1207
3487
 
1208
- void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int k) {
1209
- assert(k % QK_K == 0);
1210
- const int nb = k / QK_K;
3488
+ // Main loop
3489
+ for (int i = 0; i < nb; ++i) {
3490
+ // Compute combined scale for the block
3491
+ const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d));
3492
+ __m256i bx = _mm256_loadu_si256((const __m256i *)x[i].qs);
3493
+ __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
1211
3494
 
1212
- for (int i = 0; i < nb; i++) {
3495
+ const __m256 q = mul_sum_i8_pairs_float(bx, by);
1213
3496
 
1214
- float max = 0;
1215
- float amax = 0;
1216
- for (int j = 0; j < QK_K; ++j) {
1217
- float ax = fabsf(x[j]);
1218
- if (ax > amax) {
1219
- amax = ax; max = x[j];
1220
- }
1221
- }
1222
- if (!amax) {
1223
- y[i].d = 0;
1224
- memset(y[i].qs, 0, QK_K);
1225
- x += QK_K;
1226
- continue;
1227
- }
1228
- const float iscale = -128.f/max;
1229
- for (int j = 0; j < QK_K; ++j) {
1230
- int v = nearest_int(iscale*x[j]);
1231
- y[i].qs[j] = MIN(127, v);
1232
- }
1233
- for (int j = 0; j < QK_K/16; ++j) {
1234
- int sum = 0;
1235
- for (int ii = 0; ii < 16; ++ii) {
1236
- sum += y[i].qs[j*16 + ii];
1237
- }
1238
- y[i].bsums[j] = sum;
1239
- }
1240
- y[i].d = 1/iscale;
1241
- x += QK_K;
3497
+ // Multiply q with scale and accumulate
3498
+ #if defined(__AVX2__)
3499
+ acc = _mm256_fmadd_ps( d, q, acc );
3500
+ #else
3501
+ acc = _mm256_add_ps( _mm256_mul_ps( d, q ), acc );
3502
+ #endif
1242
3503
  }
1243
- }
1244
3504
 
1245
- void dequantize_row_q8_K(const block_q8_K * restrict x, float * restrict y, int k) {
1246
- assert(k % QK_K == 0);
1247
- const int nb = k / QK_K;
3505
+ *s = hsum_float_8(acc);
3506
+ #elif defined(__riscv_v_intrinsic)
3507
+ float sumf = 0.0;
3508
+ size_t vl = __riscv_vsetvl_e8m1(qk);
1248
3509
 
1249
3510
  for (int i = 0; i < nb; i++) {
1250
- for (int j = 0; j < QK_K; ++j) {
1251
- *y++ = x[i].d * x[i].qs[j];
1252
- }
3511
+ // load elements
3512
+ vint8m1_t bx = __riscv_vle8_v_i8m1(x[i].qs, vl);
3513
+ vint8m1_t by = __riscv_vle8_v_i8m1(y[i].qs, vl);
3514
+
3515
+ vint16m2_t vw_mul = __riscv_vwmul_vv_i16m2(bx, by, vl);
3516
+
3517
+ vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, vl);
3518
+ vint32m1_t v_sum = __riscv_vwredsum_vs_i16m2_i32m1(vw_mul, v_zero, vl);
3519
+
3520
+ int sumi = __riscv_vmv_x_s_i32m1_i32(v_sum);
3521
+
3522
+ sumf += sumi*(GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d));
1253
3523
  }
1254
- }
1255
3524
 
1256
- void quantize_row_q8_K(const float * restrict x, void * restrict y, int k) {
1257
- quantize_row_q8_K_reference(x, y, k);
1258
- }
3525
+ *s = sumf;
3526
+ #else
3527
+ // scalar
3528
+ float sumf = 0.0;
1259
3529
 
1260
- //===================================== Dot ptoducts =================================
3530
+ for (int i = 0; i < nb; i++) {
3531
+ int sumi = 0;
1261
3532
 
1262
- //
1263
- // Helper functions
1264
- //
1265
- #if __AVX__ || __AVX2__ || __AVX512F__
3533
+ for (int j = 0; j < qk; j++) {
3534
+ sumi += x[i].qs[j]*y[i].qs[j];
3535
+ }
1266
3536
 
1267
- // horizontally add 8 floats
1268
- static inline float hsum_float_8(const __m256 x) {
1269
- __m128 res = _mm256_extractf128_ps(x, 1);
1270
- res = _mm_add_ps(res, _mm256_castps256_ps128(x));
1271
- res = _mm_add_ps(res, _mm_movehl_ps(res, res));
1272
- res = _mm_add_ss(res, _mm_movehdup_ps(res));
1273
- return _mm_cvtss_f32(res);
1274
- }
3537
+ sumf += sumi*(GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d));
3538
+ }
1275
3539
 
1276
- // shuffles to pick the required scales in dot products
1277
- static inline __m256i get_scale_shuffle_q3k(int i) {
1278
- static const uint8_t k_shuffle[128] = {
1279
- 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3,
1280
- 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7,
1281
- 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,
1282
- 12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13, 14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,
1283
- };
1284
- return _mm256_loadu_si256((const __m256i*)k_shuffle + i);
1285
- }
1286
- static inline __m256i get_scale_shuffle_k4(int i) {
1287
- static const uint8_t k_shuffle[256] = {
1288
- 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1,
1289
- 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3,
1290
- 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5,
1291
- 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7,
1292
- 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9,
1293
- 10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,
1294
- 12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,
1295
- 14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15
1296
- };
1297
- return _mm256_loadu_si256((const __m256i*)k_shuffle + i);
1298
- }
1299
- static inline __m128i get_scale_shuffle(int i) {
1300
- static const uint8_t k_shuffle[128] = {
1301
- 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1,
1302
- 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3,
1303
- 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5,
1304
- 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7,
1305
- 8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9,
1306
- 10,10,10,10,10,10,10,10, 11,11,11,11,11,11,11,11,
1307
- 12,12,12,12,12,12,12,12, 13,13,13,13,13,13,13,13,
1308
- 14,14,14,14,14,14,14,14, 15,15,15,15,15,15,15,15
1309
- };
1310
- return _mm_loadu_si128((const __m128i*)k_shuffle + i);
1311
- }
3540
+ *s = sumf;
1312
3541
  #endif
3542
+ }
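A quick back-of-the-envelope check (standalone, not from the diff) of the headroom that lets the q8_0 x q8_0 branches above keep a whole block's sum of products in 32-bit integer lanes, converting to float and applying the two fp16 deltas only once per block:

    #include <stdio.h>
    #include <limits.h>

    int main(void) {
        // 32 int8 x int8 products per QK8_0 block, each at most 128*128 in magnitude
        const long per_block = 32L * 128 * 128;   // 524288
        printf("worst-case |sumi| per block: %ld (INT32_MAX = %d)\n", per_block, INT_MAX);
        return 0;
    }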
1313
3543
 
1314
3544
  #if QK_K == 256
1315
3545
  void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
@@ -1334,8 +3564,8 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
1334
3564
 
1335
3565
  for (int i = 0; i < nb; ++i) {
1336
3566
 
1337
- const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
1338
- const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin);
3567
+ const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
3568
+ const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
1339
3569
 
1340
3570
  const uint8_t * restrict q2 = x[i].qs;
1341
3571
  const int8_t * restrict q8 = y[i].qs;
@@ -1413,8 +3643,8 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
1413
3643
 
1414
3644
  for (int i = 0; i < nb; ++i) {
1415
3645
 
1416
- const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
1417
- const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin);
3646
+ const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
3647
+ const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
1418
3648
 
1419
3649
  const uint8_t * restrict q2 = x[i].qs;
1420
3650
  const int8_t * restrict q8 = y[i].qs;
@@ -1480,8 +3710,8 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
1480
3710
 
1481
3711
  for (int i = 0; i < nb; ++i) {
1482
3712
 
1483
- const float dall = y[i].d * ggml_fp16_to_fp32(x[i].d);
1484
- const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin);
3713
+ const float dall = y[i].d * GGML_FP16_TO_FP32(x[i].d);
3714
+ const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
1485
3715
 
1486
3716
  const uint8_t * restrict q2 = x[i].qs;
1487
3717
  const int8_t * restrict q8 = y[i].qs;
@@ -1588,8 +3818,8 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
1588
3818
  const int8_t * q8 = y[i].qs;
1589
3819
  const uint8_t * sc = x[i].scales;
1590
3820
 
1591
- const float dall = y[i].d * ggml_fp16_to_fp32(x[i].d);
1592
- const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin);
3821
+ const float dall = y[i].d * GGML_FP16_TO_FP32(x[i].d);
3822
+ const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
1593
3823
 
1594
3824
  size_t vl = 16;
1595
3825
 
@@ -1675,8 +3905,8 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
1675
3905
  summs += y[i].bsums[j] * (sc[j] >> 4);
1676
3906
  }
1677
3907
 
1678
- const float dall = y[i].d * ggml_fp16_to_fp32(x[i].d);
1679
- const float dmin = y[i].d * ggml_fp16_to_fp32(x[i].dmin);
3908
+ const float dall = y[i].d * GGML_FP16_TO_FP32(x[i].d);
3909
+ const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
1680
3910
 
1681
3911
  int isum = 0;
1682
3912
  int is = 0;
@@ -1793,8 +4023,8 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
1793
4023
 
1794
4024
  for (int i = 0; i < nb; ++i) {
1795
4025
 
1796
- const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
1797
- const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin);
4026
+ const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
4027
+ const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
1798
4028
 
1799
4029
  const uint8_t * restrict q2 = x[i].qs;
1800
4030
  const int8_t * restrict q8 = y[i].qs;
@@ -1845,8 +4075,8 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
1845
4075
 
1846
4076
  for (int i = 0; i < nb; ++i) {
1847
4077
 
1848
- const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
1849
- const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin);
4078
+ const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
4079
+ const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
1850
4080
 
1851
4081
  const uint8_t * restrict q2 = x[i].qs;
1852
4082
  const int8_t * restrict q8 = y[i].qs;
@@ -1960,8 +4190,8 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
1960
4190
  summs += y[i].bsums[j] * (sc[j] >> 4);
1961
4191
  }
1962
4192
 
1963
- const float dall = y[i].d * ggml_fp16_to_fp32(x[i].d);
1964
- const float dmin = y[i].d * ggml_fp16_to_fp32(x[i].dmin);
4193
+ const float dall = y[i].d * GGML_FP16_TO_FP32(x[i].d);
4194
+ const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
1965
4195
 
1966
4196
  isum[0] = isum[1] = isum[2] = isum[3] = 0;
1967
4197
  for (int l = 0; l < 16; ++l) {
@@ -2014,7 +4244,7 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
2014
4244
 
2015
4245
  for (int i = 0; i < nb; ++i) {
2016
4246
 
2017
- const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
4247
+ const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
2018
4248
 
2019
4249
  const uint8_t * restrict q3 = x[i].qs;
2020
4250
  const uint8_t * restrict qh = x[i].hmask;
@@ -2122,7 +4352,7 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
2122
4352
 
2123
4353
  for (int i = 0; i < nb; ++i) {
2124
4354
 
2125
- const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
4355
+ const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
2126
4356
 
2127
4357
  const uint8_t * restrict q3 = x[i].qs;
2128
4358
  const int8_t * restrict q8 = y[i].qs;
@@ -2227,7 +4457,7 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
2227
4457
 
2228
4458
  for (int i = 0; i < nb; ++i) {
2229
4459
 
2230
- const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
4460
+ const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
2231
4461
 
2232
4462
  const uint8_t * restrict q3 = x[i].qs;
2233
4463
  const int8_t * restrict q8 = y[i].qs;
@@ -2448,7 +4678,7 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
2448
4678
 
2449
4679
  }
2450
4680
 
2451
- const float d = ggml_fp16_to_fp32(x[i].d) * y[i].d;
4681
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
2452
4682
 
2453
4683
  sumf += d*sum_t;
2454
4684
 
@@ -2513,7 +4743,7 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
2513
4743
  for (int l = 0; l < 8; ++l) aux32[l] += (scales[j] - 32) * aux16[l];
2514
4744
  q8 += 8; a += 8;
2515
4745
  }
2516
- const float d = ggml_fp16_to_fp32(x[i].d) * y[i].d;
4746
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
2517
4747
  for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
2518
4748
  }
2519
4749
  for (int l = 0; l < 8; ++l) sumf += sums[l];
@@ -2615,7 +4845,7 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
2615
4845
 
2616
4846
  for (int i = 0; i < nb; ++i) {
2617
4847
 
2618
- const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
4848
+ const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
2619
4849
 
2620
4850
  const uint8_t * restrict q3 = x[i].qs;
2621
4851
  const int8_t * restrict q8 = y[i].qs;
@@ -2686,7 +4916,7 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
2686
4916
 
2687
4917
  for (int i = 0; i < nb; ++i) {
2688
4918
 
2689
- const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
4919
+ const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
2690
4920
 
2691
4921
  const uint8_t * restrict q3 = x[i].qs;
2692
4922
  const int8_t * restrict q8 = y[i].qs;
@@ -2871,7 +5101,7 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
2871
5101
  q8 += 8; a += 8;
2872
5102
  for (int l = 0; l < 8; ++l) aux32[l] += scales[j] * aux16[l];
2873
5103
  }
2874
- const float d = ggml_fp16_to_fp32(x[i].d) * y[i].d;
5104
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
2875
5105
  for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
2876
5106
  }
2877
5107
  for (int l = 0; l < 8; ++l) sumf += sums[l];
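
A detail worth noting in the q3_K hunks: the 6-bit sub-block scales are stored with a +32 bias, which is why the reference loop above multiplies by (scales[j] - 32). The sketch below isolates that step with a simplified layout; the names and the flat per-sub-block partial sums are illustrative, not the exact ggml data flow.

#include <stdint.h>

/* Biased-scale accumulation in the spirit of the q3_K kernel: unbias the
   stored 6-bit scales before weighting each sub-block partial dot product. */
static float accumulate_q3_like(float d,                    /* d = y->d * fp16_to_fp32(x->d)        */
                                const uint8_t scales[16],   /* unpacked 6-bit scales, biased by +32 */
                                const int32_t partial[16]) {/* per-sub-block integer dot products   */
    int32_t isum = 0;
    for (int j = 0; j < 16; ++j) {
        isum += ((int) scales[j] - 32) * partial[j];
    }
    return d * isum;
}
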
@@ -2911,8 +5141,8 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
2911
5141
 
2912
5142
  for (int i = 0; i < nb; ++i) {
2913
5143
 
2914
- const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
2915
- const float dmin = y[i].d * ggml_fp16_to_fp32(x[i].dmin);
5144
+ const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
5145
+ const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
2916
5146
 
2917
5147
  const int16x8_t q8sums = vpaddq_s16(vld1q_s16(y[i].bsums), vld1q_s16(y[i].bsums + 8));
2918
5148
 
@@ -2994,8 +5224,8 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
2994
5224
 
2995
5225
  for (int i = 0; i < nb; ++i) {
2996
5226
 
2997
- const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
2998
- const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin);
5227
+ const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
5228
+ const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
2999
5229
 
3000
5230
  memcpy(utmp, x[i].scales, 12);
3001
5231
  utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
@@ -3060,8 +5290,8 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
3060
5290
 
3061
5291
  for (int i = 0; i < nb; ++i) {
3062
5292
 
3063
- const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
3064
- const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin);
5293
+ const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
5294
+ const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
3065
5295
 
3066
5296
  const uint8_t * restrict q4 = x[i].qs;
3067
5297
  const int8_t * restrict q8 = y[i].qs;
@@ -3143,8 +5373,8 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
3143
5373
 
3144
5374
  size_t vl = 8;
3145
5375
 
3146
- const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
3147
- const float dmin = y[i].d * ggml_fp16_to_fp32(x[i].dmin);
5376
+ const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
5377
+ const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
3148
5378
 
3149
5379
  vint16mf2_t q8sums_0 = __riscv_vlse16_v_i16mf2(y[i].bsums, 4, vl);
3150
5380
  vint16mf2_t q8sums_1 = __riscv_vlse16_v_i16mf2(y[i].bsums+1, 4, vl);
@@ -3254,9 +5484,9 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
3254
5484
  for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
3255
5485
  q8 += 8; a += 8;
3256
5486
  }
3257
- const float d = ggml_fp16_to_fp32(x[i].d) * y[i].d;
5487
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
3258
5488
  for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
3259
- const float dmin = ggml_fp16_to_fp32(x[i].dmin) * y[i].d;
5489
+ const float dmin = GGML_FP16_TO_FP32(x[i].dmin) * y[i].d;
3260
5490
  sumf -= dmin * sumi;
3261
5491
  }
3262
5492
  for (int l = 0; l < 8; ++l) sumf += sums[l];
@@ -3358,8 +5588,8 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
3358
5588
 
3359
5589
  for (int i = 0; i < nb; ++i) {
3360
5590
 
3361
- const float d = ggml_fp16_to_fp32(x[i].d[0]) * y[i].d;
3362
- const float m = ggml_fp16_to_fp32(x[i].d[1]) * y[i].d;
5591
+ const float d = GGML_FP16_TO_FP32(x[i].d[0]) * y[i].d;
5592
+ const float m = GGML_FP16_TO_FP32(x[i].d[1]) * y[i].d;
3363
5593
  const __m256 vd = _mm256_set1_ps(d);
3364
5594
 
3365
5595
  const uint16_t * a = (const uint16_t *)x[i].scales;
@@ -3404,8 +5634,8 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
3404
5634
 
3405
5635
  for (int i = 0; i < nb; ++i) {
3406
5636
 
3407
- const float d = ggml_fp16_to_fp32(x[i].d[0]) * y[i].d;
3408
- const float m = ggml_fp16_to_fp32(x[i].d[1]) * y[i].d;
5637
+ const float d = GGML_FP16_TO_FP32(x[i].d[0]) * y[i].d;
5638
+ const float m = GGML_FP16_TO_FP32(x[i].d[1]) * y[i].d;
3409
5639
  const __m256 vd = _mm256_set1_ps(d);
3410
5640
 
3411
5641
  const uint16_t * a = (const uint16_t *)x[i].scales;
@@ -3461,8 +5691,8 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
3461
5691
  s16[0] = b[0] & 0x0f0f;
3462
5692
  s16[1] = (b[0] >> 4) & 0x0f0f;
3463
5693
 
3464
- sumf -= y[i].d * ggml_fp16_to_fp32(x[i].d[1]) * (scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3]));
3465
- const float d = y[i].d * ggml_fp16_to_fp32(x[i].d[0]);
5694
+ sumf -= y[i].d * GGML_FP16_TO_FP32(x[i].d[1]) * (scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3]));
5695
+ const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d[0]);
3466
5696
 
3467
5697
  size_t vl = 32;
3468
5698
 
@@ -3511,9 +5741,9 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
3511
5741
  s16[0] = b[0] & 0x0f0f;
3512
5742
  s16[1] = (b[0] >> 4) & 0x0f0f;
3513
5743
 
3514
- sumf -= y[i].d * ggml_fp16_to_fp32(x[i].d[1]) * (scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3]));
5744
+ sumf -= y[i].d * GGML_FP16_TO_FP32(x[i].d[1]) * (scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3]));
3515
5745
 
3516
- const float d = y[i].d * ggml_fp16_to_fp32(x[i].d[0]);
5746
+ const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d[0]);
3517
5747
 
3518
5748
  for (int j = 0; j < QK_K/32; ++j) {
3519
5749
  for (int l = 0; l < 16; ++l) aux16[l] = q8[l] * a[l];
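
The q4_K hunks operate on bytes that each hold two 4-bit quants, which is why the scalar fallback masks with 0x0f0f and shifts by 4. The helper below shows the basic nibble split for one 64-value chunk; the exact placement of low- and high-nibble groups within a ggml super-block is more involved, so treat this as a simplified illustration.

#include <stdint.h>

/* Minimal sketch of 4-bit unpacking: each byte of qs carries two quants,
   low nibbles first, high nibbles second. Layout simplified for illustration. */
static void unpack_q4_like(const uint8_t qs[32], uint8_t out[64]) {
    for (int l = 0; l < 32; ++l) {
        out[l]      = qs[l] & 0x0F;  /* low nibble  -> first 32 values  */
        out[l + 32] = qs[l] >> 4;    /* high nibble -> second 32 values */
    }
}
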
@@ -3561,8 +5791,8 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
3561
5791
 
3562
5792
  for (int i = 0; i < nb; ++i) {
3563
5793
 
3564
- const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
3565
- const float dmin = y[i].d * ggml_fp16_to_fp32(x[i].dmin);
5794
+ const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
5795
+ const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
3566
5796
 
3567
5797
  const int16x8_t q8sums = vpaddq_s16(vld1q_s16(y[i].bsums), vld1q_s16(y[i].bsums + 8));
3568
5798
 
@@ -3650,8 +5880,8 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
3650
5880
  const int8_t * restrict q8 = y[i].qs;
3651
5881
 
3652
5882
  #if QK_K == 256
3653
- const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
3654
- const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin);
5883
+ const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
5884
+ const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
3655
5885
 
3656
5886
  memcpy(utmp, x[i].scales, 12);
3657
5887
  utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
@@ -3732,8 +5962,8 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
3732
5962
 
3733
5963
  for (int i = 0; i < nb; ++i) {
3734
5964
 
3735
- const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
3736
- const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin);
5965
+ const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
5966
+ const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
3737
5967
 
3738
5968
  const uint8_t * restrict q5 = x[i].qs;
3739
5969
  const int8_t * restrict q8 = y[i].qs;
@@ -3837,8 +6067,8 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
3837
6067
  const uint8_t * restrict hm = x[i].qh;
3838
6068
  const int8_t * restrict q8 = y[i].qs;
3839
6069
 
3840
- const float d = ggml_fp16_to_fp32(x[i].d) * y[i].d;
3841
- const float dmin = ggml_fp16_to_fp32(x[i].dmin) * y[i].d;
6070
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
6071
+ const float dmin = GGML_FP16_TO_FP32(x[i].dmin) * y[i].d;
3842
6072
 
3843
6073
  vint16mf2_t q8sums_0 = __riscv_vlse16_v_i16mf2(y[i].bsums, 4, vl);
3844
6074
  vint16mf2_t q8sums_1 = __riscv_vlse16_v_i16mf2(y[i].bsums+1, 4, vl);
@@ -3960,9 +6190,9 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
3960
6190
  for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
3961
6191
  q8 += 8; a += 8;
3962
6192
  }
3963
- const float d = ggml_fp16_to_fp32(x[i].d) * y[i].d;
6193
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
3964
6194
  for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
3965
- const float dmin = ggml_fp16_to_fp32(x[i].dmin) * y[i].d;
6195
+ const float dmin = GGML_FP16_TO_FP32(x[i].dmin) * y[i].d;
3966
6196
  sumf -= dmin * sumi;
3967
6197
  }
3968
6198
  for (int l = 0; l < 8; ++l) sumf += sums[l];
@@ -4060,7 +6290,7 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
4060
6290
  const uint8_t * restrict q5 = x[i].qs;
4061
6291
  const int8_t * restrict q8 = y[i].qs;
4062
6292
 
4063
- const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
6293
+ const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
4064
6294
 
4065
6295
  const __m256i q5bits = _mm256_loadu_si256((const __m256i*)q5);
4066
6296
 
@@ -4106,7 +6336,7 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
4106
6336
  const uint8_t * restrict q5 = x[i].qs;
4107
6337
  const int8_t * restrict q8 = y[i].qs;
4108
6338
 
4109
- const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
6339
+ const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
4110
6340
 
4111
6341
  const __m256i q5bits = _mm256_loadu_si256((const __m256i*)q5);
4112
6342
 
@@ -4243,7 +6473,7 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
4243
6473
  for (int l = 0; l < 8; ++l) a[8*is + l] -= (hm[l] & m ? 0 : 16);
4244
6474
  }
4245
6475
 
4246
- const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
6476
+ const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
4247
6477
  const int8_t * restrict sc = x[i].scales;
4248
6478
 
4249
6479
  for (int j = 0; j < QK_K/16; ++j) {
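
The q5_K reference loop visible above first fills a buffer with 4-bit nibbles and then subtracts 16 wherever the corresponding qh bit is clear, reconstructing a 5-bit value. The one-liner below restates that rule in isolation; the indexing into qh and the bit masks are simplified for illustration.

#include <stdint.h>

/* 5-bit reconstruction in the spirit of the q5_K fallback: low 4 bits from the
   nibble stream, 5th bit from the qh bitfield (a cleared bit subtracts 16). */
static inline int8_t q5_like_value(uint8_t nibble, uint8_t qh_byte, uint8_t bit_mask) {
    return (int8_t) nibble - ((qh_byte & bit_mask) ? 0 : 16);
}
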
@@ -4286,7 +6516,7 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
4286
6516
 
4287
6517
  for (int i = 0; i < nb; ++i) {
4288
6518
 
4289
- const float d_all = ggml_fp16_to_fp32(x[i].d);
6519
+ const float d_all = GGML_FP16_TO_FP32(x[i].d);
4290
6520
 
4291
6521
  const uint8_t * restrict q6 = x[i].ql;
4292
6522
  const uint8_t * restrict qh = x[i].qh;
@@ -4418,7 +6648,7 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
4418
6648
 
4419
6649
  for (int i = 0; i < nb; ++i) {
4420
6650
 
4421
- const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
6651
+ const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
4422
6652
 
4423
6653
  const uint8_t * restrict q4 = x[i].ql;
4424
6654
  const uint8_t * restrict qh = x[i].qh;
@@ -4498,7 +6728,7 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
4498
6728
 
4499
6729
  for (int i = 0; i < nb; ++i) {
4500
6730
 
4501
- const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
6731
+ const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
4502
6732
 
4503
6733
  const uint8_t * restrict q4 = x[i].ql;
4504
6734
  const uint8_t * restrict qh = x[i].qh;
@@ -4610,7 +6840,7 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
4610
6840
  float sumf = 0;
4611
6841
  for (int i = 0; i < nb; ++i) {
4612
6842
 
4613
- const float d = ggml_fp16_to_fp32(x[i].d) * y[i].d;
6843
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
4614
6844
 
4615
6845
  const uint8_t * restrict q6 = x[i].ql;
4616
6846
  const uint8_t * restrict qh = x[i].qh;
@@ -4727,7 +6957,7 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
4727
6957
  for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
4728
6958
  q8 += 8; a += 8;
4729
6959
  }
4730
- const float d = ggml_fp16_to_fp32(x[i].d) * y[i].d;
6960
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
4731
6961
  for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
4732
6962
  }
4733
6963
  for (int l = 0; l < 8; ++l) sumf += sums[l];
@@ -4825,7 +7055,7 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
4825
7055
 
4826
7056
  for (int i = 0; i < nb; ++i) {
4827
7057
 
4828
- const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
7058
+ const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
4829
7059
 
4830
7060
  const uint8_t * restrict q4 = x[i].ql;
4831
7061
  const uint8_t * restrict qh = x[i].qh;
@@ -4882,7 +7112,7 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
4882
7112
 
4883
7113
  for (int i = 0; i < nb; ++i) {
4884
7114
 
4885
- const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
7115
+ const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
4886
7116
 
4887
7117
  const uint8_t * restrict q4 = x[i].ql;
4888
7118
  const uint8_t * restrict qh = x[i].qh;
@@ -5041,7 +7271,7 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
5041
7271
  for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
5042
7272
  q8 += 8; a += 8;
5043
7273
  }
5044
- const float d = ggml_fp16_to_fp32(x[i].d) * y[i].d;
7274
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
5045
7275
  for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
5046
7276
  }
5047
7277
  for (int l = 0; l < 8; ++l) sumf += sums[l];
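
For completeness, the q6_K kernels above combine 4 low bits from ql with 2 high bits from qh and recenter the result by subtracting 32 before the scale-weighted accumulation. The helper below sketches that bit layout only; the per-block indexing and shift schedule in the upstream code are more involved, so this is an assumption-labelled illustration rather than the ggml implementation.

#include <stdint.h>

/* Illustrative 6-bit reconstruction: 4 low bits from ql, 2 high bits from qh,
   values 0..63 recentered to -32..31. Simplified indexing, illustration only. */
static inline int8_t q6_like_value(uint8_t ql_byte, uint8_t qh_byte, int qh_shift) {
    const int lo = ql_byte & 0x0F;            /* 4 low bits  */
    const int hi = (qh_byte >> qh_shift) & 3; /* 2 high bits */
    return (int8_t) ((hi << 4) | lo) - 32;
}
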