llama_cpp 0.9.0 → 0.9.1

@@ -1,10 +1,8 @@
1
1
  #define _CRT_SECURE_NO_DEPRECATE // Disables ridiculous "unsafe" warnings on Windows
2
+ #define _USE_MATH_DEFINES // For M_PI on MSVC
2
3
 
3
- #include "ggml.h"
4
-
5
- #ifdef GGML_USE_K_QUANTS
6
- #include "k_quants.h"
7
- #endif
4
+ #include "ggml-impl.h"
5
+ #include "ggml-quants.h"
8
6
 
9
7
  #if defined(_MSC_VER) || defined(__MINGW32__)
10
8
  #include <malloc.h> // using malloc.h with MSC/MINGW
@@ -30,18 +28,6 @@
30
28
  #include <unistd.h>
31
29
  #endif
32
30
 
33
- // static_assert should be a #define, but if it's not,
34
- // fall back to the _Static_assert C11 keyword.
35
- // if C99 - static_assert is noop
36
- // ref: https://stackoverflow.com/a/53923785/4039976
37
- #ifndef static_assert
38
- #if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 201100L)
39
- #define static_assert(cond, msg) _Static_assert(cond, msg)
40
- #else
41
- #define static_assert(cond, msg) struct global_scope_noop_trick
42
- #endif
43
- #endif
44
-
45
31
  #if defined(_MSC_VER)
46
32
  // disable "possible loss of data" to avoid hundreds of casts
47
33
  // we should just be careful :)
@@ -109,23 +95,11 @@ typedef void * thread_ret_t;
109
95
  #include <unistd.h>
110
96
 
111
97
  #endif
98
+
112
99
  #ifdef GGML_USE_CPU_HBM
113
100
  #include <hbwmalloc.h>
114
101
  #endif
115
102
 
116
- // __FMA__ and __F16C__ are not defined in MSVC, however they are implied with AVX2/AVX512
117
- #if defined(_MSC_VER) && (defined(__AVX2__) || defined(__AVX512F__))
118
- #ifndef __FMA__
119
- #define __FMA__
120
- #endif
121
- #ifndef __F16C__
122
- #define __F16C__
123
- #endif
124
- #ifndef __SSE3__
125
- #define __SSE3__
126
- #endif
127
- #endif
128
-
129
103
  /*#define GGML_PERF*/
130
104
  #define GGML_DEBUG 0
131
105
  #define GGML_GELU_FP16
@@ -251,228 +225,27 @@ inline static void * ggml_aligned_malloc(size_t size) {
251
225
  #include "ggml-opencl.h"
252
226
  #endif
253
227
 
254
- #undef MIN
255
- #undef MAX
256
- #define MIN(a, b) ((a) < (b) ? (a) : (b))
257
- #define MAX(a, b) ((a) > (b) ? (a) : (b))
258
-
259
228
  // floating point type used to accumulate sums
260
229
  typedef double ggml_float;
261
230
 
262
- // 16-bit float
263
- // on Arm, we use __fp16
264
- // on x86, we use uint16_t
265
- #if defined(__ARM_NEON) && !defined(_MSC_VER)
266
-
267
- // if YCM cannot find <arm_neon.h>, make a symbolic link to it, for example:
268
- //
269
- // $ ln -sfn /Library/Developer/CommandLineTools/usr/lib/clang/13.1.6/include/arm_neon.h ./src/
270
- //
271
- #include <arm_neon.h>
272
-
273
- #define GGML_COMPUTE_FP16_TO_FP32(x) ((float) (x))
274
- #define GGML_COMPUTE_FP32_TO_FP16(x) (x)
275
-
276
- #define GGML_FP16_TO_FP32(x) ((float) (x))
277
- #define GGML_FP32_TO_FP16(x) (x)
278
-
279
- #else
280
-
281
- #ifdef __wasm_simd128__
282
- #include <wasm_simd128.h>
283
- #else
284
- #ifdef __POWER9_VECTOR__
285
- #include <altivec.h>
286
- #undef bool
287
- #define bool _Bool
288
- #else
289
- #if defined(_MSC_VER) || defined(__MINGW32__)
290
- #include <intrin.h>
291
- #else
292
- #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) || defined(__SSE3__)
293
- #if !defined(__riscv)
294
- #include <immintrin.h>
295
- #endif
296
- #endif
297
- #endif
298
- #endif
299
- #endif
300
-
301
- #ifdef __riscv_v_intrinsic
302
- #include <riscv_vector.h>
303
- #endif
304
-
305
- #ifdef __F16C__
306
-
307
- #ifdef _MSC_VER
308
- #define GGML_COMPUTE_FP16_TO_FP32(x) _mm_cvtss_f32(_mm_cvtph_ps(_mm_cvtsi32_si128(x)))
309
- #define GGML_COMPUTE_FP32_TO_FP16(x) _mm_extract_epi16(_mm_cvtps_ph(_mm_set_ss(x), 0), 0)
310
- #else
311
- #define GGML_COMPUTE_FP16_TO_FP32(x) _cvtsh_ss(x)
312
- #define GGML_COMPUTE_FP32_TO_FP16(x) _cvtss_sh(x, 0)
313
- #endif
314
-
315
- #elif defined(__POWER9_VECTOR__)
316
-
317
- #define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x)
318
- #define GGML_COMPUTE_FP32_TO_FP16(x) ggml_compute_fp32_to_fp16(x)
319
- /* the inline asm below is about 12% faster than the lookup method */
320
- #define GGML_FP16_TO_FP32(x) GGML_COMPUTE_FP16_TO_FP32(x)
321
- #define GGML_FP32_TO_FP16(x) GGML_COMPUTE_FP32_TO_FP16(x)
322
-
323
- static inline float ggml_compute_fp16_to_fp32(ggml_fp16_t h) {
324
- register float f;
325
- register double d;
326
- __asm__(
327
- "mtfprd %0,%2\n"
328
- "xscvhpdp %0,%0\n"
329
- "frsp %1,%0\n" :
330
- /* temp */ "=d"(d),
331
- /* out */ "=f"(f):
332
- /* in */ "r"(h));
333
- return f;
334
- }
335
-
336
- static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) {
337
- register double d;
338
- register ggml_fp16_t r;
339
- __asm__( /* xscvdphp can work on double or single precision */
340
- "xscvdphp %0,%2\n"
341
- "mffprd %1,%0\n" :
342
- /* temp */ "=d"(d),
343
- /* out */ "=r"(r):
344
- /* in */ "f"(f));
345
- return r;
346
- }
347
-
348
- #else
349
-
350
- // FP16 <-> FP32
351
- // ref: https://github.com/Maratyszcza/FP16
352
-
353
- static inline float fp32_from_bits(uint32_t w) {
354
- union {
355
- uint32_t as_bits;
356
- float as_value;
357
- } fp32;
358
- fp32.as_bits = w;
359
- return fp32.as_value;
360
- }
361
-
362
- static inline uint32_t fp32_to_bits(float f) {
363
- union {
364
- float as_value;
365
- uint32_t as_bits;
366
- } fp32;
367
- fp32.as_value = f;
368
- return fp32.as_bits;
369
- }
370
-
371
- static inline float ggml_compute_fp16_to_fp32(ggml_fp16_t h) {
372
- const uint32_t w = (uint32_t) h << 16;
373
- const uint32_t sign = w & UINT32_C(0x80000000);
374
- const uint32_t two_w = w + w;
375
-
376
- const uint32_t exp_offset = UINT32_C(0xE0) << 23;
377
- #if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__)
378
- const float exp_scale = 0x1.0p-112f;
379
- #else
380
- const float exp_scale = fp32_from_bits(UINT32_C(0x7800000));
381
- #endif
382
- const float normalized_value = fp32_from_bits((two_w >> 4) + exp_offset) * exp_scale;
383
-
384
- const uint32_t magic_mask = UINT32_C(126) << 23;
385
- const float magic_bias = 0.5f;
386
- const float denormalized_value = fp32_from_bits((two_w >> 17) | magic_mask) - magic_bias;
387
-
388
- const uint32_t denormalized_cutoff = UINT32_C(1) << 27;
389
- const uint32_t result = sign |
390
- (two_w < denormalized_cutoff ? fp32_to_bits(denormalized_value) : fp32_to_bits(normalized_value));
391
- return fp32_from_bits(result);
392
- }
393
-
394
- static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) {
395
- #if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__)
396
- const float scale_to_inf = 0x1.0p+112f;
397
- const float scale_to_zero = 0x1.0p-110f;
398
- #else
399
- const float scale_to_inf = fp32_from_bits(UINT32_C(0x77800000));
400
- const float scale_to_zero = fp32_from_bits(UINT32_C(0x08800000));
401
- #endif
402
- float base = (fabsf(f) * scale_to_inf) * scale_to_zero;
403
-
404
- const uint32_t w = fp32_to_bits(f);
405
- const uint32_t shl1_w = w + w;
406
- const uint32_t sign = w & UINT32_C(0x80000000);
407
- uint32_t bias = shl1_w & UINT32_C(0xFF000000);
408
- if (bias < UINT32_C(0x71000000)) {
409
- bias = UINT32_C(0x71000000);
410
- }
411
-
412
- base = fp32_from_bits((bias >> 1) + UINT32_C(0x07800000)) + base;
413
- const uint32_t bits = fp32_to_bits(base);
414
- const uint32_t exp_bits = (bits >> 13) & UINT32_C(0x00007C00);
415
- const uint32_t mantissa_bits = bits & UINT32_C(0x00000FFF);
416
- const uint32_t nonsign = exp_bits + mantissa_bits;
417
- return (sign >> 16) | (shl1_w > UINT32_C(0xFF000000) ? UINT16_C(0x7E00) : nonsign);
418
- }
419
-
420
- #define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x)
421
- #define GGML_COMPUTE_FP32_TO_FP16(x) ggml_compute_fp32_to_fp16(x)
422
-
423
- #endif // __F16C__
424
-
425
- #endif // __ARM_NEON
426
-
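Note: the block removed above carried ggml's scalar FP16 <-> FP32 conversion (the branch-free bit-twiddling approach from the FP16 reference linked in its comments); this release moves it into ggml-impl.h. For readers following the diff, here is a minimal field-by-field decoder. It assumes IEEE-754 binary16 input and only illustrates what the removed code computes; it is not the branch-free implementation ggml actually uses.

    #include <math.h>
    #include <stdint.h>

    // Naive half -> float decoder: split into sign/exponent/mantissa and handle
    // the three cases (subnormal, inf/NaN, normal) explicitly.
    static float fp16_to_fp32_naive(uint16_t h) {
        const uint32_t sign = (h >> 15) & 0x1;
        const uint32_t exp  = (h >> 10) & 0x1F;
        const uint32_t mant =  h        & 0x3FF;

        float value;
        if (exp == 0) {
            value = ldexpf((float) mant, -24);                      // zero / subnormal: mant * 2^-24
        } else if (exp == 0x1F) {
            value = mant == 0 ? INFINITY : NAN;                     // inf / NaN
        } else {
            value = ldexpf((float) (mant | 0x400), (int) exp - 25); // (1.mant) * 2^(exp - 15)
        }
        return sign ? -value : value;
    }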
427
231
  //
428
232
  // global data
429
233
  //
430
234
 
431
235
  // precomputed gelu table for f16 (128 KB)
432
- static ggml_fp16_t table_gelu_f16[1 << 16];
236
+ static ggml_fp16_t ggml_table_gelu_f16[1 << 16];
433
237
 
434
238
  // precomputed quick gelu table for f16 (128 KB)
435
- static ggml_fp16_t table_gelu_quick_f16[1 << 16];
239
+ static ggml_fp16_t ggml_table_gelu_quick_f16[1 << 16];
436
240
 
437
241
  // precomputed silu table for f16 (128 KB)
438
- static ggml_fp16_t table_silu_f16[1 << 16];
242
+ static ggml_fp16_t ggml_table_silu_f16[1 << 16];
439
243
 
440
244
  // precomputed exp table for f16 (128 KB)
441
- static ggml_fp16_t table_exp_f16[1 << 16];
442
-
443
- // precomputed f32 table for f16 (256 KB)
444
- static float table_f32_f16[1 << 16];
445
-
446
- #if defined(__ARM_NEON) || defined(__wasm_simd128__)
447
- #define B1(c,s,n) 0x ## n ## c , 0x ## n ## s
448
- #define B2(c,s,n) B1(c,s,n ## c), B1(c,s,n ## s)
449
- #define B3(c,s,n) B2(c,s,n ## c), B2(c,s,n ## s)
450
- #define B4(c,s,n) B3(c,s,n ## c), B3(c,s,n ## s)
451
- #define B5(c,s,n) B4(c,s,n ## c), B4(c,s,n ## s)
452
- #define B6(c,s,n) B5(c,s,n ## c), B5(c,s,n ## s)
453
- #define B7(c,s,n) B6(c,s,n ## c), B6(c,s,n ## s)
454
- #define B8(c,s ) B7(c,s, c), B7(c,s, s)
455
-
456
- // precomputed tables for expanding 8bits to 8 bytes:
457
- static const uint64_t table_b2b_0[1 << 8] = { B8(00, 10) }; // ( b) << 4
458
- static const uint64_t table_b2b_1[1 << 8] = { B8(10, 00) }; // (!b) << 4
459
- #endif
460
-
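Note: the B1..B8 macros removed above build two 256-entry tables at compile time: table_b2b_0[i] expands each bit b of i into one byte equal to (b) << 4, and table_b2b_1[i] into (!b) << 4. A runtime equivalent is sketched below, purely for illustration; the macro form exists so the tables can live in static const data.

    #include <stdint.h>

    // expand_b2b_0(x): byte k of the result is ((x >> k) & 1) << 4
    static uint64_t expand_b2b_0(uint8_t x) {
        uint64_t r = 0;
        for (int k = 0; k < 8; ++k) {
            r |= (uint64_t) (((x >> k) & 1u) << 4) << (8 * k);
        }
        return r;
    }

    // expand_b2b_1(x): same, but on the complemented bits ((!b) << 4)
    static uint64_t expand_b2b_1(uint8_t x) {
        return expand_b2b_0((uint8_t) ~x);
    }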
461
- // On ARM NEON, it's quicker to directly convert x -> x instead of calling into ggml_lookup_fp16_to_fp32,
462
- // so we define GGML_FP16_TO_FP32 and GGML_FP32_TO_FP16 elsewhere for NEON.
463
- // This is also true for POWER9.
464
- #if !defined(GGML_FP16_TO_FP32) || !defined(GGML_FP32_TO_FP16)
465
-
466
- inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) {
467
- uint16_t s;
468
- memcpy(&s, &f, sizeof(uint16_t));
469
- return table_f32_f16[s];
470
- }
471
-
472
- #define GGML_FP16_TO_FP32(x) ggml_lookup_fp16_to_fp32(x)
473
- #define GGML_FP32_TO_FP16(x) GGML_COMPUTE_FP32_TO_FP16(x)
245
+ static ggml_fp16_t ggml_table_exp_f16[1 << 16];
474
246
 
475
- #endif
247
+ // precomputed f32 table for f16 (256 KB) (ggml-impl.h)
248
+ float ggml_table_f32_f16[1 << 16];
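Note: ggml_table_f32_f16 is a full 2^16-entry fp16 -> fp32 table (every 16-bit pattern is an index), now shared through ggml-impl.h instead of the file-local table_f32_f16 removed above. The sketch below shows how such a table is filled once and used afterwards; the names and init hook are illustrative, not ggml's actual init path.

    #include <stdint.h>

    static float table_f32_from_f16[1 << 16];

    // fill once at startup with any scalar converter (e.g. the naive decoder sketched earlier)
    static void init_f16_table(float (*fp16_to_fp32)(uint16_t)) {
        for (uint32_t i = 0; i < (1u << 16); ++i) {
            table_f32_from_f16[i] = fp16_to_fp32((uint16_t) i);
        }
    }

    // afterwards a conversion is a single indexed load
    static inline float lookup_fp16_to_fp32(uint16_t h) {
        return table_f32_from_f16[h];
    }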
476
249
 
477
250
  // note: do not use these inside ggml.c
478
251
  // these are meant to be used via the ggml.h API
@@ -587,3071 +360,816 @@ int64_t ggml_cycles_per_ms(void) {
587
360
 
588
361
  static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
589
362
 
590
- //
591
- // quantization
592
- //
363
+ static void ggml_vec_dot_f32(const int n, float * restrict s, const float * restrict x, const float * restrict y);
364
+ static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t * restrict x, ggml_fp16_t * restrict y);
593
365
 
594
- #define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
595
-
596
- #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__)
597
- // multiply int8_t, add results pairwise twice
598
- static inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) {
599
- // Get absolute values of x vectors
600
- const __m128i ax = _mm_sign_epi8(x, x);
601
- // Sign the values of the y vectors
602
- const __m128i sy = _mm_sign_epi8(y, x);
603
- // Perform multiplication and create 16-bit values
604
- const __m128i dot = _mm_maddubs_epi16(ax, sy);
605
- const __m128i ones = _mm_set1_epi16(1);
606
- return _mm_madd_epi16(ones, dot);
607
- }
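Note on the removed mul_sum_i8_pairs: _mm_maddubs_epi16 multiplies an unsigned byte by a signed byte, so the code takes |x| and transfers x's sign onto y before multiplying; mathematically each 32-bit output lane is just the sum of four signed products. A scalar reference for one lane (a sketch, not part of ggml):

    #include <stdint.h>

    // one output lane of mul_sum_i8_pairs: the sum of four int8 products
    static inline int32_t mul_sum_i8_quad(const int8_t x[4], const int8_t y[4]) {
        int32_t acc = 0;
        for (int i = 0; i < 4; ++i) {
            acc += (int32_t) x[i] * (int32_t) y[i];
        }
        return acc;
    }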
608
-
609
- #if __AVX__ || __AVX2__ || __AVX512F__
610
- // horizontally add 8 floats
611
- static inline float hsum_float_8(const __m256 x) {
612
- __m128 res = _mm256_extractf128_ps(x, 1);
613
- res = _mm_add_ps(res, _mm256_castps256_ps128(x));
614
- res = _mm_add_ps(res, _mm_movehl_ps(res, res));
615
- res = _mm_add_ss(res, _mm_movehdup_ps(res));
616
- return _mm_cvtss_f32(res);
617
- }
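Note: hsum_float_8 folds the 256-bit register in halves (128 -> 64 -> 32 bits) to sum its 8 lanes. The scalar equivalent below computes the same sum, although the pairwise SIMD order can differ from left-to-right addition in the last bit.

    static inline float hsum_float_8_scalar(const float x[8]) {
        float sum = 0.0f;
        for (int i = 0; i < 8; ++i) {
            sum += x[i];
        }
        return sum;
    }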
618
-
619
- // horizontally add 8 int32_t
620
- static inline int hsum_i32_8(const __m256i a) {
621
- const __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(a), _mm256_extractf128_si256(a, 1));
622
- const __m128i hi64 = _mm_unpackhi_epi64(sum128, sum128);
623
- const __m128i sum64 = _mm_add_epi32(hi64, sum128);
624
- const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
625
- return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
626
- }
627
-
628
- // horizontally add 4 int32_t
629
- static inline int hsum_i32_4(const __m128i a) {
630
- const __m128i hi64 = _mm_unpackhi_epi64(a, a);
631
- const __m128i sum64 = _mm_add_epi32(hi64, a);
632
- const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
633
- return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
634
- }
635
-
636
- #if defined(__AVX2__) || defined(__AVX512F__)
637
- // spread 32 bits to 32 bytes { 0x00, 0xFF }
638
- static inline __m256i bytes_from_bits_32(const uint8_t * x) {
639
- uint32_t x32;
640
- memcpy(&x32, x, sizeof(uint32_t));
641
- const __m256i shuf_mask = _mm256_set_epi64x(
642
- 0x0303030303030303, 0x0202020202020202,
643
- 0x0101010101010101, 0x0000000000000000);
644
- __m256i bytes = _mm256_shuffle_epi8(_mm256_set1_epi32(x32), shuf_mask);
645
- const __m256i bit_mask = _mm256_set1_epi64x(0x7fbfdfeff7fbfdfe);
646
- bytes = _mm256_or_si256(bytes, bit_mask);
647
- return _mm256_cmpeq_epi8(bytes, _mm256_set1_epi64x(-1));
648
- }
649
-
650
- // Unpack 32 4-bit fields into 32 bytes
651
- // The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
652
- static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)
653
- {
654
- const __m128i tmp = _mm_loadu_si128((const __m128i *)rsi);
655
- const __m256i bytes = MM256_SET_M128I(_mm_srli_epi16(tmp, 4), tmp);
656
- const __m256i lowMask = _mm256_set1_epi8( 0xF );
657
- return _mm256_and_si256(lowMask, bytes);
658
- }
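Note: bytes_from_nibbles_32 is the unpack step used by the 4-bit dot products: 16 packed bytes become 32 values in [0, 15], with the low nibbles filling the first half of the output and the high nibbles the second half, matching how the q4_0/q4_1 quantizers interleave the two halves of a block. A scalar sketch, not ggml code:

    #include <stdint.h>

    static void bytes_from_nibbles_32_scalar(const uint8_t src[16], uint8_t dst[32]) {
        for (int j = 0; j < 16; ++j) {
            dst[j]      = src[j] & 0x0F; // low nibble  -> first half
            dst[j + 16] = src[j] >> 4;   // high nibble -> second half
        }
    }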
659
-
660
- // add int16_t pairwise and return as float vector
661
- static inline __m256 sum_i16_pairs_float(const __m256i x) {
662
- const __m256i ones = _mm256_set1_epi16(1);
663
- const __m256i summed_pairs = _mm256_madd_epi16(ones, x);
664
- return _mm256_cvtepi32_ps(summed_pairs);
665
- }
666
-
667
- static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) {
668
- #if __AVXVNNI__
669
- const __m256i zero = _mm256_setzero_si256();
670
- const __m256i summed_pairs = _mm256_dpbusd_epi32(zero, ax, sy);
671
- return _mm256_cvtepi32_ps(summed_pairs);
672
- #else
673
- // Perform multiplication and create 16-bit values
674
- const __m256i dot = _mm256_maddubs_epi16(ax, sy);
675
- return sum_i16_pairs_float(dot);
676
- #endif
677
- }
366
+ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
367
+ [GGML_TYPE_I8] = {
368
+ .type_name = "i8",
369
+ .blck_size = 1,
370
+ .type_size = sizeof(int8_t),
371
+ .is_quantized = false,
372
+ },
373
+ [GGML_TYPE_I16] = {
374
+ .type_name = "i16",
375
+ .blck_size = 1,
376
+ .type_size = sizeof(int16_t),
377
+ .is_quantized = false,
378
+ },
379
+ [GGML_TYPE_I32] = {
380
+ .type_name = "i32",
381
+ .blck_size = 1,
382
+ .type_size = sizeof(int32_t),
383
+ .is_quantized = false,
384
+ },
385
+ [GGML_TYPE_F32] = {
386
+ .type_name = "f32",
387
+ .blck_size = 1,
388
+ .type_size = sizeof(float),
389
+ .is_quantized = false,
390
+ .vec_dot = (ggml_vec_dot_t) ggml_vec_dot_f32,
391
+ .vec_dot_type = GGML_TYPE_F32,
392
+ },
393
+ [GGML_TYPE_F16] = {
394
+ .type_name = "f16",
395
+ .blck_size = 1,
396
+ .type_size = sizeof(ggml_fp16_t),
397
+ .is_quantized = false,
398
+ .to_float = (ggml_to_float_t) ggml_fp16_to_fp32_row,
399
+ .from_float = (ggml_from_float_t) ggml_fp32_to_fp16_row,
400
+ .from_float_reference = (ggml_from_float_t) ggml_fp32_to_fp16_row,
401
+ .vec_dot = (ggml_vec_dot_t) ggml_vec_dot_f16,
402
+ .vec_dot_type = GGML_TYPE_F16,
403
+ },
404
+ [GGML_TYPE_Q4_0] = {
405
+ .type_name = "q4_0",
406
+ .blck_size = QK4_0,
407
+ .type_size = sizeof(block_q4_0),
408
+ .is_quantized = true,
409
+ .to_float = (ggml_to_float_t) dequantize_row_q4_0,
410
+ .from_float = quantize_row_q4_0,
411
+ .from_float_reference = (ggml_from_float_t) quantize_row_q4_0_reference,
412
+ .vec_dot = ggml_vec_dot_q4_0_q8_0,
413
+ .vec_dot_type = GGML_TYPE_Q8_0,
414
+ },
415
+ [GGML_TYPE_Q4_1] = {
416
+ .type_name = "q4_1",
417
+ .blck_size = QK4_1,
418
+ .type_size = sizeof(block_q4_1),
419
+ .is_quantized = true,
420
+ .to_float = (ggml_to_float_t) dequantize_row_q4_1,
421
+ .from_float = quantize_row_q4_1,
422
+ .from_float_reference = (ggml_from_float_t) quantize_row_q4_1_reference,
423
+ .vec_dot = ggml_vec_dot_q4_1_q8_1,
424
+ .vec_dot_type = GGML_TYPE_Q8_1,
425
+ },
426
+ [4] = { // GGML_TYPE_Q4_2
427
+ .type_name = "DEPRECATED",
428
+ .blck_size = 0,
429
+ .type_size = 0,
430
+ .is_quantized = false,
431
+ .to_float = NULL,
432
+ .from_float = NULL,
433
+ .from_float_reference = NULL,
434
+ .vec_dot = NULL,
435
+ .vec_dot_type = GGML_TYPE_COUNT,
436
+ },
437
+ [5] = { // GGML_TYPE_Q4_3
438
+ .type_name = "DEPRECATED",
439
+ .blck_size = 0,
440
+ .type_size = 0,
441
+ .is_quantized = false,
442
+ .to_float = NULL,
443
+ .from_float = NULL,
444
+ .from_float_reference = NULL,
445
+ .vec_dot = NULL,
446
+ .vec_dot_type = GGML_TYPE_COUNT,
447
+ },
448
+ [GGML_TYPE_Q5_0] = {
449
+ .type_name = "q5_0",
450
+ .blck_size = QK5_0,
451
+ .type_size = sizeof(block_q5_0),
452
+ .is_quantized = true,
453
+ .to_float = (ggml_to_float_t) dequantize_row_q5_0,
454
+ .from_float = quantize_row_q5_0,
455
+ .from_float_reference = (ggml_from_float_t) quantize_row_q5_0_reference,
456
+ .vec_dot = ggml_vec_dot_q5_0_q8_0,
457
+ .vec_dot_type = GGML_TYPE_Q8_0,
458
+ },
459
+ [GGML_TYPE_Q5_1] = {
460
+ .type_name = "q5_1",
461
+ .blck_size = QK5_1,
462
+ .type_size = sizeof(block_q5_1),
463
+ .is_quantized = true,
464
+ .to_float = (ggml_to_float_t) dequantize_row_q5_1,
465
+ .from_float = quantize_row_q5_1,
466
+ .from_float_reference = (ggml_from_float_t) quantize_row_q5_1_reference,
467
+ .vec_dot = ggml_vec_dot_q5_1_q8_1,
468
+ .vec_dot_type = GGML_TYPE_Q8_1,
469
+ },
470
+ [GGML_TYPE_Q8_0] = {
471
+ .type_name = "q8_0",
472
+ .blck_size = QK8_0,
473
+ .type_size = sizeof(block_q8_0),
474
+ .is_quantized = true,
475
+ .to_float = (ggml_to_float_t) dequantize_row_q8_0,
476
+ .from_float = quantize_row_q8_0,
477
+ .from_float_reference = (ggml_from_float_t) quantize_row_q8_0_reference,
478
+ .vec_dot = ggml_vec_dot_q8_0_q8_0,
479
+ .vec_dot_type = GGML_TYPE_Q8_0,
480
+ },
481
+ [GGML_TYPE_Q8_1] = {
482
+ .type_name = "q8_1",
483
+ .blck_size = QK8_1,
484
+ .type_size = sizeof(block_q8_1),
485
+ .is_quantized = true,
486
+ .from_float = quantize_row_q8_1,
487
+ .from_float_reference = (ggml_from_float_t) quantize_row_q8_1_reference,
488
+ .vec_dot_type = GGML_TYPE_Q8_1,
489
+ },
490
+ [GGML_TYPE_Q2_K] = {
491
+ .type_name = "q2_K",
492
+ .blck_size = QK_K,
493
+ .type_size = sizeof(block_q2_K),
494
+ .is_quantized = true,
495
+ .to_float = (ggml_to_float_t) dequantize_row_q2_K,
496
+ .from_float = quantize_row_q2_K,
497
+ .from_float_reference = (ggml_from_float_t) quantize_row_q2_K_reference,
498
+ .vec_dot = ggml_vec_dot_q2_K_q8_K,
499
+ .vec_dot_type = GGML_TYPE_Q8_K,
500
+ },
501
+ [GGML_TYPE_Q3_K] = {
502
+ .type_name = "q3_K",
503
+ .blck_size = QK_K,
504
+ .type_size = sizeof(block_q3_K),
505
+ .is_quantized = true,
506
+ .to_float = (ggml_to_float_t) dequantize_row_q3_K,
507
+ .from_float = quantize_row_q3_K,
508
+ .from_float_reference = (ggml_from_float_t) quantize_row_q3_K_reference,
509
+ .vec_dot = ggml_vec_dot_q3_K_q8_K,
510
+ .vec_dot_type = GGML_TYPE_Q8_K,
511
+ },
512
+ [GGML_TYPE_Q4_K] = {
513
+ .type_name = "q4_K",
514
+ .blck_size = QK_K,
515
+ .type_size = sizeof(block_q4_K),
516
+ .is_quantized = true,
517
+ .to_float = (ggml_to_float_t) dequantize_row_q4_K,
518
+ .from_float = quantize_row_q4_K,
519
+ .from_float_reference = (ggml_from_float_t) quantize_row_q4_K_reference,
520
+ .vec_dot = ggml_vec_dot_q4_K_q8_K,
521
+ .vec_dot_type = GGML_TYPE_Q8_K,
522
+ },
523
+ [GGML_TYPE_Q5_K] = {
524
+ .type_name = "q5_K",
525
+ .blck_size = QK_K,
526
+ .type_size = sizeof(block_q5_K),
527
+ .is_quantized = true,
528
+ .to_float = (ggml_to_float_t) dequantize_row_q5_K,
529
+ .from_float = quantize_row_q5_K,
530
+ .from_float_reference = (ggml_from_float_t) quantize_row_q5_K_reference,
531
+ .vec_dot = ggml_vec_dot_q5_K_q8_K,
532
+ .vec_dot_type = GGML_TYPE_Q8_K,
533
+ },
534
+ [GGML_TYPE_Q6_K] = {
535
+ .type_name = "q6_K",
536
+ .blck_size = QK_K,
537
+ .type_size = sizeof(block_q6_K),
538
+ .is_quantized = true,
539
+ .to_float = (ggml_to_float_t) dequantize_row_q6_K,
540
+ .from_float = quantize_row_q6_K,
541
+ .from_float_reference = (ggml_from_float_t) quantize_row_q6_K_reference,
542
+ .vec_dot = ggml_vec_dot_q6_K_q8_K,
543
+ .vec_dot_type = GGML_TYPE_Q8_K,
544
+ },
545
+ [GGML_TYPE_Q8_K] = {
546
+ .type_name = "q8_K",
547
+ .blck_size = QK_K,
548
+ .type_size = sizeof(block_q8_K),
549
+ .is_quantized = true,
550
+ .from_float = quantize_row_q8_K,
551
+ }
552
+ };
678
553
 
679
- // multiply int8_t, add results pairwise twice and return as float vector
680
- static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
681
- #if __AVXVNNIINT8__
682
- const __m256i zero = _mm256_setzero_si256();
683
- const __m256i summed_pairs = _mm256_dpbssd_epi32(zero, x, y);
684
- return _mm256_cvtepi32_ps(summed_pairs);
685
- #else
686
- // Get absolute values of x vectors
687
- const __m256i ax = _mm256_sign_epi8(x, x);
688
- // Sign the values of the y vectors
689
- const __m256i sy = _mm256_sign_epi8(y, x);
690
- return mul_sum_us8_pairs_float(ax, sy);
691
- #endif
554
+ // For internal test use
555
+ ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type) {
556
+ GGML_ASSERT(type < GGML_TYPE_COUNT);
557
+ return type_traits[type];
692
558
  }
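Note: the designated-initializer table above now sits next to the forward declarations of the vec_dot kernels, whose quantization code has presumably moved to ggml-quants.c/ggml-quants.h (pulled in by the new includes at the top of the diff), and ggml_internal_get_type_traits exposes it "for internal test use". A hedged usage sketch, assuming the ggml_type_traits_t declaration in ggml.h uses int blck_size and size_t type_size as the initializers above suggest:

    #include <stdio.h>
    #include "ggml.h"

    // Query the traits of one quantized type; q4_0 packs 32 weights into an
    // 18-byte block (fp16 scale + 16 bytes of nibbles).
    static void print_q4_0_traits(void) {
        const ggml_type_traits_t tt = ggml_internal_get_type_traits(GGML_TYPE_Q4_0);
        printf("%s: blck_size=%d type_size=%zu quantized=%d\n",
               tt.type_name, tt.blck_size, tt.type_size, (int) tt.is_quantized);
    }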
693
559
 
694
- static inline __m128i packNibbles( __m256i bytes )
695
- {
696
- // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
697
- #if __AVX512F__
698
- const __m256i bytes_srli_4 = _mm256_srli_epi16(bytes, 4); // 0000_0000_abcd_0000
699
- bytes = _mm256_or_si256(bytes, bytes_srli_4); // 0000_abcd_abcd_efgh
700
- return _mm256_cvtepi16_epi8(bytes); // abcd_efgh
701
- #else
702
- const __m256i lowByte = _mm256_set1_epi16( 0xFF );
703
- __m256i high = _mm256_andnot_si256( lowByte, bytes );
704
- __m256i low = _mm256_and_si256( lowByte, bytes );
705
- high = _mm256_srli_epi16( high, 4 );
706
- bytes = _mm256_or_si256( low, high );
707
-
708
- // Compress uint16_t lanes into bytes
709
- __m128i r0 = _mm256_castsi256_si128( bytes );
710
- __m128i r1 = _mm256_extracti128_si256( bytes, 1 );
711
- return _mm_packus_epi16( r0, r1 );
712
- #endif
713
- }
714
- #elif defined(__AVX__)
715
- // spread 32 bits to 32 bytes { 0x00, 0xFF }
716
- static inline __m256i bytes_from_bits_32(const uint8_t * x) {
717
- uint32_t x32;
718
- memcpy(&x32, x, sizeof(uint32_t));
719
- const __m128i shuf_maskl = _mm_set_epi64x(0x0101010101010101, 0x0000000000000000);
720
- const __m128i shuf_maskh = _mm_set_epi64x(0x0303030303030303, 0x0202020202020202);
721
- __m128i bytesl = _mm_shuffle_epi8(_mm_set1_epi32(x32), shuf_maskl);
722
- __m128i bytesh = _mm_shuffle_epi8(_mm_set1_epi32(x32), shuf_maskh);
723
- const __m128i bit_mask = _mm_set1_epi64x(0x7fbfdfeff7fbfdfe);
724
- bytesl = _mm_or_si128(bytesl, bit_mask);
725
- bytesh = _mm_or_si128(bytesh, bit_mask);
726
- bytesl = _mm_cmpeq_epi8(bytesl, _mm_set1_epi64x(-1));
727
- bytesh = _mm_cmpeq_epi8(bytesh, _mm_set1_epi64x(-1));
728
- return MM256_SET_M128I(bytesh, bytesl);
729
- }
730
-
731
- // Unpack 32 4-bit fields into 32 bytes
732
- // The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
733
- static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)
734
- {
735
- // Load 16 bytes from memory
736
- __m128i tmpl = _mm_loadu_si128((const __m128i *)rsi);
737
- __m128i tmph = _mm_srli_epi16(tmpl, 4);
738
- const __m128i lowMask = _mm_set1_epi8(0xF);
739
- tmpl = _mm_and_si128(lowMask, tmpl);
740
- tmph = _mm_and_si128(lowMask, tmph);
741
- return MM256_SET_M128I(tmph, tmpl);
742
- }
743
-
744
- // add int16_t pairwise and return as float vector
745
- static inline __m256 sum_i16_pairs_float(const __m128i xh, const __m128i xl) {
746
- const __m128i ones = _mm_set1_epi16(1);
747
- const __m128i summed_pairsl = _mm_madd_epi16(ones, xl);
748
- const __m128i summed_pairsh = _mm_madd_epi16(ones, xh);
749
- const __m256i summed_pairs = MM256_SET_M128I(summed_pairsh, summed_pairsl);
750
- return _mm256_cvtepi32_ps(summed_pairs);
751
- }
752
-
753
- static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) {
754
- const __m128i axl = _mm256_castsi256_si128(ax);
755
- const __m128i axh = _mm256_extractf128_si256(ax, 1);
756
- const __m128i syl = _mm256_castsi256_si128(sy);
757
- const __m128i syh = _mm256_extractf128_si256(sy, 1);
758
- // Perform multiplication and create 16-bit values
759
- const __m128i dotl = _mm_maddubs_epi16(axl, syl);
760
- const __m128i doth = _mm_maddubs_epi16(axh, syh);
761
- return sum_i16_pairs_float(doth, dotl);
762
- }
763
-
764
- // multiply int8_t, add results pairwise twice and return as float vector
765
- static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
766
- const __m128i xl = _mm256_castsi256_si128(x);
767
- const __m128i xh = _mm256_extractf128_si256(x, 1);
768
- const __m128i yl = _mm256_castsi256_si128(y);
769
- const __m128i yh = _mm256_extractf128_si256(y, 1);
770
- // Get absolute values of x vectors
771
- const __m128i axl = _mm_sign_epi8(xl, xl);
772
- const __m128i axh = _mm_sign_epi8(xh, xh);
773
- // Sign the values of the y vectors
774
- const __m128i syl = _mm_sign_epi8(yl, xl);
775
- const __m128i syh = _mm_sign_epi8(yh, xh);
776
- // Perform multiplication and create 16-bit values
777
- const __m128i dotl = _mm_maddubs_epi16(axl, syl);
778
- const __m128i doth = _mm_maddubs_epi16(axh, syh);
779
- return sum_i16_pairs_float(doth, dotl);
780
- }
781
-
782
- static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
783
- {
784
- // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
785
- const __m128i lowByte = _mm_set1_epi16( 0xFF );
786
- __m128i high = _mm_andnot_si128( lowByte, bytes1 );
787
- __m128i low = _mm_and_si128( lowByte, bytes1 );
788
- high = _mm_srli_epi16( high, 4 );
789
- bytes1 = _mm_or_si128( low, high );
790
- high = _mm_andnot_si128( lowByte, bytes2 );
791
- low = _mm_and_si128( lowByte, bytes2 );
792
- high = _mm_srli_epi16( high, 4 );
793
- bytes2 = _mm_or_si128( low, high );
794
-
795
- return _mm_packus_epi16( bytes1, bytes2);
796
- }
797
- #endif
798
- #elif defined(__SSSE3__)
799
- // horizontally add 4x4 floats
800
- static inline float hsum_float_4x4(const __m128 a, const __m128 b, const __m128 c, const __m128 d) {
801
- __m128 res_0 =_mm_hadd_ps(a, b);
802
- __m128 res_1 =_mm_hadd_ps(c, d);
803
- __m128 res =_mm_hadd_ps(res_0, res_1);
804
- res =_mm_hadd_ps(res, res);
805
- res =_mm_hadd_ps(res, res);
806
-
807
- return _mm_cvtss_f32(res);
808
- }
809
- #endif // __AVX__ || __AVX2__ || __AVX512F__
810
- #endif // defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__)
811
-
812
- #if defined(__ARM_NEON)
813
-
814
- #if !defined(__aarch64__)
815
-
816
- inline static int32_t vaddvq_s32(int32x4_t v) {
817
- return vgetq_lane_s32(v, 0) + vgetq_lane_s32(v, 1) + vgetq_lane_s32(v, 2) + vgetq_lane_s32(v, 3);
818
- }
819
-
820
- inline static float vaddvq_f32(float32x4_t v) {
821
- return vgetq_lane_f32(v, 0) + vgetq_lane_f32(v, 1) + vgetq_lane_f32(v, 2) + vgetq_lane_f32(v, 3);
822
- }
823
-
824
- inline static float vmaxvq_f32(float32x4_t v) {
825
- return
826
- MAX(MAX(vgetq_lane_f32(v, 0), vgetq_lane_f32(v, 1)),
827
- MAX(vgetq_lane_f32(v, 2), vgetq_lane_f32(v, 3)));
828
- }
829
-
830
- inline static int32x4_t vcvtnq_s32_f32(float32x4_t v) {
831
- int32x4_t res;
832
-
833
- res[0] = roundf(vgetq_lane_f32(v, 0));
834
- res[1] = roundf(vgetq_lane_f32(v, 1));
835
- res[2] = roundf(vgetq_lane_f32(v, 2));
836
- res[3] = roundf(vgetq_lane_f32(v, 3));
837
-
838
- return res;
839
- }
840
-
841
- #endif
842
- #endif
843
-
844
- #define QK4_0 32
845
- typedef struct {
846
- ggml_fp16_t d; // delta
847
- uint8_t qs[QK4_0 / 2]; // nibbles / quants
848
- } block_q4_0;
849
- static_assert(sizeof(block_q4_0) == sizeof(ggml_fp16_t) + QK4_0 / 2, "wrong q4_0 block size/padding");
850
-
851
- #define QK4_1 32
852
- typedef struct {
853
- ggml_fp16_t d; // delta
854
- ggml_fp16_t m; // min
855
- uint8_t qs[QK4_1 / 2]; // nibbles / quants
856
- } block_q4_1;
857
- static_assert(sizeof(block_q4_1) == 2 * sizeof(ggml_fp16_t) + QK4_1 / 2, "wrong q4_1 block size/padding");
858
-
859
- #define QK5_0 32
860
- typedef struct {
861
- ggml_fp16_t d; // delta
862
- uint8_t qh[4]; // 5-th bit of quants
863
- uint8_t qs[QK5_0 / 2]; // nibbles / quants
864
- } block_q5_0;
865
- static_assert(sizeof(block_q5_0) == sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_0 / 2, "wrong q5_0 block size/padding");
866
-
867
- #define QK5_1 32
868
- typedef struct {
869
- ggml_fp16_t d; // delta
870
- ggml_fp16_t m; // min
871
- uint8_t qh[4]; // 5-th bit of quants
872
- uint8_t qs[QK5_1 / 2]; // nibbles / quants
873
- } block_q5_1;
874
- static_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_1 / 2, "wrong q5_1 block size/padding");
875
-
876
- #define QK8_0 32
877
- typedef struct {
878
- ggml_fp16_t d; // delta
879
- int8_t qs[QK8_0]; // quants
880
- } block_q8_0;
881
- static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 block size/padding");
882
-
883
- #define QK8_1 32
884
- typedef struct {
885
- float d; // delta
886
- float s; // d * sum(qs[i])
887
- int8_t qs[QK8_1]; // quants
888
- } block_q8_1;
889
- static_assert(sizeof(block_q8_1) == 2*sizeof(float) + QK8_1, "wrong q8_1 block size/padding");
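Note: every block format removed above packs 32 weights per block, so bytes-per-block * 8 / 32 gives the effective bits per weight; block_q8_1 additionally caches s = d * sum(qs), presumably so the min/offset term of a q4_1 or q5_1 dot product can reuse the precomputed sum. A standalone check of the size arithmetic (a sketch mirroring the struct definitions, which now live in ggml-quants.h after this refactor):

    #include <stdio.h>

    int main(void) {
        // bytes per 32-weight block
        const struct { const char *name; int bytes; } fmt[] = {
            { "q4_0", 2 + 16 },         // fp16 d            + 16 nibble bytes = 18
            { "q4_1", 2 + 2 + 16 },     // fp16 d, m         + nibbles         = 20
            { "q5_0", 2 + 4 + 16 },     // fp16 d + qh[4]    + nibbles         = 22
            { "q5_1", 2 + 2 + 4 + 16 }, // fp16 d, m + qh[4] + nibbles         = 24
            { "q8_0", 2 + 32 },         // fp16 d            + 32 int8 quants  = 34
            { "q8_1", 4 + 4 + 32 },     // float d, s        + 32 int8 quants  = 40
        };
        for (int i = 0; i < 6; ++i) {
            printf("%s: %2d bytes/block = %.2f bits/weight\n",
                   fmt[i].name, fmt[i].bytes, fmt[i].bytes * 8.0 / 32);
        }
        return 0;
    }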
890
-
891
- // reference implementation for deterministic creation of model files
892
- static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k) {
893
- static const int qk = QK4_0;
894
-
895
- assert(k % qk == 0);
896
-
897
- const int nb = k / qk;
898
-
899
- for (int i = 0; i < nb; i++) {
900
- float amax = 0.0f; // absolute max
901
- float max = 0.0f;
902
-
903
- for (int j = 0; j < qk; j++) {
904
- const float v = x[i*qk + j];
905
- if (amax < fabsf(v)) {
906
- amax = fabsf(v);
907
- max = v;
908
- }
909
- }
910
-
911
- const float d = max / -8;
912
- const float id = d ? 1.0f/d : 0.0f;
913
-
914
- y[i].d = GGML_FP32_TO_FP16(d);
915
-
916
- for (int j = 0; j < qk/2; ++j) {
917
- const float x0 = x[i*qk + 0 + j]*id;
918
- const float x1 = x[i*qk + qk/2 + j]*id;
919
-
920
- const uint8_t xi0 = MIN(15, (int8_t)(x0 + 8.5f));
921
- const uint8_t xi1 = MIN(15, (int8_t)(x1 + 8.5f));
922
-
923
- y[i].qs[j] = xi0;
924
- y[i].qs[j] |= xi1 << 4;
925
- }
926
- }
927
- }
928
-
929
- static void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) {
930
- quantize_row_q4_0_reference(x, y, k);
931
- }
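Note: in the q4_0 reference above the scale is chosen as d = max / -8, where max is the signed value of largest magnitude, so that extreme maps to quant 0 and all values land in the nibble range after the +8 offset. Below is a standalone scalar round trip using a plain float scale; it is only a sketch, since the real code stores d as ggml_fp16_t and packs two quants per byte.

    #include <math.h>

    // quantize one 32-value block to 4 bits and immediately dequantize it
    static void q4_0_roundtrip(const float *x, float *y, int qk /* = 32 */) {
        float amax = 0.0f, max = 0.0f;
        for (int j = 0; j < qk; ++j) {
            if (fabsf(x[j]) > amax) { amax = fabsf(x[j]); max = x[j]; }
        }
        const float d  = max / -8.0f;
        const float id = d ? 1.0f/d : 0.0f;
        for (int j = 0; j < qk; ++j) {
            int q = (int) (x[j]*id + 8.5f); // maps into [0, 16]; clamp to 15
            if (q > 15) q = 15;
            y[j] = (q - 8)*d;               // dequantize, as dequantize_row_q4_0 does
        }
    }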
932
-
933
- static void quantize_row_q4_1_reference(const float * restrict x, block_q4_1 * restrict y, int k) {
934
- const int qk = QK4_1;
935
-
936
- assert(k % qk == 0);
937
-
938
- const int nb = k / qk;
939
-
940
- for (int i = 0; i < nb; i++) {
941
- float min = FLT_MAX;
942
- float max = -FLT_MAX;
943
-
944
- for (int j = 0; j < qk; j++) {
945
- const float v = x[i*qk + j];
946
-
947
- if (v < min) min = v;
948
- if (v > max) max = v;
949
- }
950
-
951
- const float d = (max - min) / ((1 << 4) - 1);
952
- const float id = d ? 1.0f/d : 0.0f;
953
-
954
- y[i].d = GGML_FP32_TO_FP16(d);
955
- y[i].m = GGML_FP32_TO_FP16(min);
956
-
957
- for (int j = 0; j < qk/2; ++j) {
958
- const float x0 = (x[i*qk + 0 + j] - min)*id;
959
- const float x1 = (x[i*qk + qk/2 + j] - min)*id;
960
-
961
- const uint8_t xi0 = MIN(15, (int8_t)(x0 + 0.5f));
962
- const uint8_t xi1 = MIN(15, (int8_t)(x1 + 0.5f));
963
-
964
- y[i].qs[j] = xi0;
965
- y[i].qs[j] |= xi1 << 4;
966
- }
967
- }
968
- }
969
-
970
- static void quantize_row_q4_1(const float * restrict x, void * restrict y, int k) {
971
- quantize_row_q4_1_reference(x, y, k);
972
- }
973
-
974
- static void quantize_row_q5_0_reference(const float * restrict x, block_q5_0 * restrict y, int k) {
975
- static const int qk = QK5_0;
976
-
977
- assert(k % qk == 0);
978
-
979
- const int nb = k / qk;
980
-
981
- for (int i = 0; i < nb; i++) {
982
- float amax = 0.0f; // absolute max
983
- float max = 0.0f;
984
-
985
- for (int j = 0; j < qk; j++) {
986
- const float v = x[i*qk + j];
987
- if (amax < fabsf(v)) {
988
- amax = fabsf(v);
989
- max = v;
990
- }
991
- }
992
-
993
- const float d = max / -16;
994
- const float id = d ? 1.0f/d : 0.0f;
995
-
996
- y[i].d = GGML_FP32_TO_FP16(d);
997
-
998
- uint32_t qh = 0;
999
-
1000
- for (int j = 0; j < qk/2; ++j) {
1001
- const float x0 = x[i*qk + 0 + j]*id;
1002
- const float x1 = x[i*qk + qk/2 + j]*id;
1003
-
1004
- const uint8_t xi0 = MIN(31, (int8_t)(x0 + 16.5f));
1005
- const uint8_t xi1 = MIN(31, (int8_t)(x1 + 16.5f));
1006
-
1007
- y[i].qs[j] = (xi0 & 0x0F) | ((xi1 & 0x0F) << 4);
1008
-
1009
- // get the 5-th bit and store it in qh at the right position
1010
- qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
1011
- qh |= ((xi1 & 0x10u) >> 4) << (j + qk/2);
1012
- }
1013
-
1014
- memcpy(&y[i].qh, &qh, sizeof(qh));
1015
- }
1016
- }
1017
-
1018
- static void quantize_row_q5_0(const float * restrict x, void * restrict y, int k) {
1019
- quantize_row_q5_0_reference(x, y, k);
1020
- }
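Note: q5_0/q5_1 split each 5-bit quant across two places: the low nibble goes into the packed qs bytes and bit 4 is collected into the 32-bit qh word (bit j for the first half of the block, bit j+16 for the second). The round trip below mirrors the pack in quantize_row_q5_0_reference above and the unpack in dequantize_row_q5_0 further down; it is a standalone sketch, not ggml code.

    #include <assert.h>
    #include <stdint.h>

    int main(void) {
        const uint8_t xi0 = 0x15;  // 5-bit quant for position j in the first half  (21)
        const uint8_t xi1 = 0x0A;  // 5-bit quant for position j in the second half (10)
        const int     j   = 3;     // 0 <= j < qk/2 == 16

        // pack: the two low nibbles share one byte, the two 5th bits go into qh
        const uint8_t qs = (uint8_t) ((xi0 & 0x0F) | ((xi1 & 0x0F) << 4));
        uint32_t qh = 0;
        qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
        qh |= ((xi1 & 0x10u) >> 4) << (j + 16);

        // unpack, exactly as dequantize_row_q5_0 reconstructs the values
        const uint8_t r0 = (uint8_t) ((qs & 0x0F) | (((qh >> (j + 0)) << 4) & 0x10));
        const uint8_t r1 = (uint8_t) ((qs >>   4) | ( (qh >> (j + 12))      & 0x10));
        assert(r0 == xi0 && r1 == xi1);
        return 0;
    }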
1021
-
1022
- static void quantize_row_q5_1_reference(const float * restrict x, block_q5_1 * restrict y, int k) {
1023
- const int qk = QK5_1;
1024
-
1025
- assert(k % qk == 0);
1026
-
1027
- const int nb = k / qk;
1028
-
1029
- for (int i = 0; i < nb; i++) {
1030
- float min = FLT_MAX;
1031
- float max = -FLT_MAX;
1032
-
1033
- for (int j = 0; j < qk; j++) {
1034
- const float v = x[i*qk + j];
1035
-
1036
- if (v < min) min = v;
1037
- if (v > max) max = v;
1038
- }
1039
-
1040
- const float d = (max - min) / ((1 << 5) - 1);
1041
- const float id = d ? 1.0f/d : 0.0f;
1042
-
1043
- y[i].d = GGML_FP32_TO_FP16(d);
1044
- y[i].m = GGML_FP32_TO_FP16(min);
1045
-
1046
- uint32_t qh = 0;
1047
-
1048
- for (int j = 0; j < qk/2; ++j) {
1049
- const float x0 = (x[i*qk + 0 + j] - min)*id;
1050
- const float x1 = (x[i*qk + qk/2 + j] - min)*id;
1051
-
1052
- const uint8_t xi0 = (uint8_t)(x0 + 0.5f);
1053
- const uint8_t xi1 = (uint8_t)(x1 + 0.5f);
1054
-
1055
- y[i].qs[j] = (xi0 & 0x0F) | ((xi1 & 0x0F) << 4);
1056
-
1057
- // get the 5-th bit and store it in qh at the right position
1058
- qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
1059
- qh |= ((xi1 & 0x10u) >> 4) << (j + qk/2);
1060
- }
1061
-
1062
- memcpy(&y[i].qh, &qh, sizeof(y[i].qh));
1063
- }
1064
- }
1065
-
1066
- static void quantize_row_q5_1(const float * restrict x, void * restrict y, int k) {
1067
- quantize_row_q5_1_reference(x, y, k);
1068
- }
1069
-
1070
- // reference implementation for deterministic creation of model files
1071
- static void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * restrict y, int k) {
1072
- assert(k % QK8_0 == 0);
1073
- const int nb = k / QK8_0;
1074
-
1075
- for (int i = 0; i < nb; i++) {
1076
- float amax = 0.0f; // absolute max
1077
-
1078
- for (int j = 0; j < QK8_0; j++) {
1079
- const float v = x[i*QK8_0 + j];
1080
- amax = MAX(amax, fabsf(v));
1081
- }
1082
-
1083
- const float d = amax / ((1 << 7) - 1);
1084
- const float id = d ? 1.0f/d : 0.0f;
1085
-
1086
- y[i].d = GGML_FP32_TO_FP16(d);
1087
-
1088
- for (int j = 0; j < QK8_0; ++j) {
1089
- const float x0 = x[i*QK8_0 + j]*id;
1090
-
1091
- y[i].qs[j] = roundf(x0);
1092
- }
1093
- }
1094
- }
1095
-
1096
- static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int k) {
1097
- assert(QK8_0 == 32);
1098
- assert(k % QK8_0 == 0);
1099
- const int nb = k / QK8_0;
1100
-
1101
- block_q8_0 * restrict y = vy;
1102
-
1103
- #if defined(__ARM_NEON)
1104
- for (int i = 0; i < nb; i++) {
1105
- float32x4_t srcv [8];
1106
- float32x4_t asrcv[8];
1107
- float32x4_t amaxv[8];
1108
-
1109
- for (int j = 0; j < 8; j++) srcv[j] = vld1q_f32(x + i*32 + 4*j);
1110
- for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[j]);
1111
-
1112
- for (int j = 0; j < 4; j++) amaxv[2*j] = vmaxq_f32(asrcv[2*j], asrcv[2*j+1]);
1113
- for (int j = 0; j < 2; j++) amaxv[4*j] = vmaxq_f32(amaxv[4*j], amaxv[4*j+2]);
1114
- for (int j = 0; j < 1; j++) amaxv[8*j] = vmaxq_f32(amaxv[8*j], amaxv[8*j+4]);
1115
-
1116
- const float amax = vmaxvq_f32(amaxv[0]);
1117
-
1118
- const float d = amax / ((1 << 7) - 1);
1119
- const float id = d ? 1.0f/d : 0.0f;
1120
-
1121
- y[i].d = GGML_FP32_TO_FP16(d);
1122
-
1123
- for (int j = 0; j < 8; j++) {
1124
- const float32x4_t v = vmulq_n_f32(srcv[j], id);
1125
- const int32x4_t vi = vcvtnq_s32_f32(v);
1126
-
1127
- y[i].qs[4*j + 0] = vgetq_lane_s32(vi, 0);
1128
- y[i].qs[4*j + 1] = vgetq_lane_s32(vi, 1);
1129
- y[i].qs[4*j + 2] = vgetq_lane_s32(vi, 2);
1130
- y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3);
1131
- }
1132
- }
1133
- #elif defined(__wasm_simd128__)
1134
- for (int i = 0; i < nb; i++) {
1135
- v128_t srcv [8];
1136
- v128_t asrcv[8];
1137
- v128_t amaxv[8];
1138
-
1139
- for (int j = 0; j < 8; j++) srcv[j] = wasm_v128_load(x + i*32 + 4*j);
1140
- for (int j = 0; j < 8; j++) asrcv[j] = wasm_f32x4_abs(srcv[j]);
1141
-
1142
- for (int j = 0; j < 4; j++) amaxv[2*j] = wasm_f32x4_max(asrcv[2*j], asrcv[2*j+1]);
1143
- for (int j = 0; j < 2; j++) amaxv[4*j] = wasm_f32x4_max(amaxv[4*j], amaxv[4*j+2]);
1144
- for (int j = 0; j < 1; j++) amaxv[8*j] = wasm_f32x4_max(amaxv[8*j], amaxv[8*j+4]);
1145
-
1146
- const float amax = MAX(MAX(wasm_f32x4_extract_lane(amaxv[0], 0),
1147
- wasm_f32x4_extract_lane(amaxv[0], 1)),
1148
- MAX(wasm_f32x4_extract_lane(amaxv[0], 2),
1149
- wasm_f32x4_extract_lane(amaxv[0], 3)));
1150
-
1151
- const float d = amax / ((1 << 7) - 1);
1152
- const float id = d ? 1.0f/d : 0.0f;
1153
-
1154
- y[i].d = GGML_FP32_TO_FP16(d);
1155
-
1156
- for (int j = 0; j < 8; j++) {
1157
- const v128_t v = wasm_f32x4_mul(srcv[j], wasm_f32x4_splat(id));
1158
- const v128_t vi = wasm_i32x4_trunc_sat_f32x4(v);
1159
-
1160
- y[i].qs[4*j + 0] = wasm_i32x4_extract_lane(vi, 0);
1161
- y[i].qs[4*j + 1] = wasm_i32x4_extract_lane(vi, 1);
1162
- y[i].qs[4*j + 2] = wasm_i32x4_extract_lane(vi, 2);
1163
- y[i].qs[4*j + 3] = wasm_i32x4_extract_lane(vi, 3);
1164
- }
1165
- }
1166
- #elif defined(__AVX2__) || defined(__AVX__)
1167
- for (int i = 0; i < nb; i++) {
1168
- // Load elements into 4 AVX vectors
1169
- __m256 v0 = _mm256_loadu_ps( x );
1170
- __m256 v1 = _mm256_loadu_ps( x + 8 );
1171
- __m256 v2 = _mm256_loadu_ps( x + 16 );
1172
- __m256 v3 = _mm256_loadu_ps( x + 24 );
1173
- x += 32;
1174
-
1175
- // Compute max(abs(e)) for the block
1176
- const __m256 signBit = _mm256_set1_ps( -0.0f );
1177
- __m256 maxAbs = _mm256_andnot_ps( signBit, v0 );
1178
- maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) );
1179
- maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );
1180
- maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) );
1181
-
1182
- __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
1183
- max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
1184
- max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
1185
- const float maxScalar = _mm_cvtss_f32( max4 );
1186
-
1187
- // Quantize these floats
1188
- const float d = maxScalar / 127.f;
1189
- y[i].d = GGML_FP32_TO_FP16(d);
1190
- const float id = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f;
1191
- const __m256 mul = _mm256_set1_ps( id );
1192
-
1193
- // Apply the multiplier
1194
- v0 = _mm256_mul_ps( v0, mul );
1195
- v1 = _mm256_mul_ps( v1, mul );
1196
- v2 = _mm256_mul_ps( v2, mul );
1197
- v3 = _mm256_mul_ps( v3, mul );
1198
-
1199
- // Round to nearest integer
1200
- v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST );
1201
- v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST );
1202
- v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST );
1203
- v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST );
1204
-
1205
- // Convert floats to integers
1206
- __m256i i0 = _mm256_cvtps_epi32( v0 );
1207
- __m256i i1 = _mm256_cvtps_epi32( v1 );
1208
- __m256i i2 = _mm256_cvtps_epi32( v2 );
1209
- __m256i i3 = _mm256_cvtps_epi32( v3 );
1210
-
1211
- #if defined(__AVX2__)
1212
- // Convert int32 to int16
1213
- i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15
1214
- i2 = _mm256_packs_epi32( i2, i3 ); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31
1215
- // Convert int16 to int8
1216
- i0 = _mm256_packs_epi16( i0, i2 ); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31
1217
-
1218
- // We got our precious signed bytes, but the order is now wrong
1219
- // These AVX2 pack instructions process 16-byte pieces independently
1220
- // The following instruction is fixing the order
1221
- const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
1222
- i0 = _mm256_permutevar8x32_epi32( i0, perm );
1223
-
1224
- _mm256_storeu_si256((__m256i *)y[i].qs, i0);
1225
- #else
1226
- // Since we don't have in AVX some necessary functions,
1227
- // we split the registers in half and call AVX2 analogs from SSE
1228
- __m128i ni0 = _mm256_castsi256_si128( i0 );
1229
- __m128i ni1 = _mm256_extractf128_si256( i0, 1);
1230
- __m128i ni2 = _mm256_castsi256_si128( i1 );
1231
- __m128i ni3 = _mm256_extractf128_si256( i1, 1);
1232
- __m128i ni4 = _mm256_castsi256_si128( i2 );
1233
- __m128i ni5 = _mm256_extractf128_si256( i2, 1);
1234
- __m128i ni6 = _mm256_castsi256_si128( i3 );
1235
- __m128i ni7 = _mm256_extractf128_si256( i3, 1);
1236
-
1237
- // Convert int32 to int16
1238
- ni0 = _mm_packs_epi32( ni0, ni1 );
1239
- ni2 = _mm_packs_epi32( ni2, ni3 );
1240
- ni4 = _mm_packs_epi32( ni4, ni5 );
1241
- ni6 = _mm_packs_epi32( ni6, ni7 );
1242
- // Convert int16 to int8
1243
- ni0 = _mm_packs_epi16( ni0, ni2 );
1244
- ni4 = _mm_packs_epi16( ni4, ni6 );
1245
-
1246
- _mm_storeu_si128((__m128i *)(y[i].qs + 0), ni0);
1247
- _mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4);
1248
- #endif
1249
- }
1250
- #elif defined(__riscv_v_intrinsic)
1251
-
1252
- size_t vl = __riscv_vsetvl_e32m4(QK8_0);
1253
-
1254
- for (int i = 0; i < nb; i++) {
1255
- // load elements
1256
- vfloat32m4_t v_x = __riscv_vle32_v_f32m4(x+i*QK8_0, vl);
1257
-
1258
- vfloat32m4_t vfabs = __riscv_vfabs_v_f32m4(v_x, vl);
1259
- vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0f, vl);
1260
- vfloat32m1_t vmax = __riscv_vfredmax_vs_f32m4_f32m1(vfabs, tmp, vl);
1261
- float amax = __riscv_vfmv_f_s_f32m1_f32(vmax);
1262
-
1263
- const float d = amax / ((1 << 7) - 1);
1264
- const float id = d ? 1.0f/d : 0.0f;
1265
-
1266
- y[i].d = GGML_FP32_TO_FP16(d);
1267
-
1268
- vfloat32m4_t x0 = __riscv_vfmul_vf_f32m4(v_x, id, vl);
1269
-
1270
- // convert to integer
1271
- vint16m2_t vi = __riscv_vfncvt_x_f_w_i16m2(x0, vl);
1272
- vint8m1_t vs = __riscv_vncvt_x_x_w_i8m1(vi, vl);
1273
-
1274
- // store result
1275
- __riscv_vse8_v_i8m1(y[i].qs , vs, vl);
1276
- }
1277
- #else
1278
- // scalar
1279
- quantize_row_q8_0_reference(x, y, k);
1280
- #endif
1281
- }
1282
-
1283
- // reference implementation for deterministic creation of model files
1284
- static void quantize_row_q8_1_reference(const float * restrict x, block_q8_1 * restrict y, int k) {
1285
- assert(QK8_1 == 32);
1286
- assert(k % QK8_1 == 0);
1287
- const int nb = k / QK8_1;
1288
-
1289
- for (int i = 0; i < nb; i++) {
1290
- float amax = 0.0f; // absolute max
1291
-
1292
- for (int j = 0; j < QK8_1; j++) {
1293
- const float v = x[i*QK8_1 + j];
1294
- amax = MAX(amax, fabsf(v));
1295
- }
1296
-
1297
- const float d = amax / ((1 << 7) - 1);
1298
- const float id = d ? 1.0f/d : 0.0f;
1299
-
1300
- y[i].d = d;
1301
-
1302
- int sum = 0;
1303
-
1304
- for (int j = 0; j < QK8_1/2; ++j) {
1305
- const float v0 = x[i*QK8_1 + j]*id;
1306
- const float v1 = x[i*QK8_1 + QK8_1/2 + j]*id;
1307
-
1308
- y[i].qs[ j] = roundf(v0);
1309
- y[i].qs[QK8_1/2 + j] = roundf(v1);
1310
-
1311
- sum += y[i].qs[ j];
1312
- sum += y[i].qs[QK8_1/2 + j];
1313
- }
1314
-
1315
- y[i].s = sum*d;
1316
- }
1317
- }
1318
-
1319
- static void quantize_row_q8_1(const float * restrict x, void * restrict vy, int k) {
1320
- assert(k % QK8_1 == 0);
1321
- const int nb = k / QK8_1;
1322
-
1323
- block_q8_1 * restrict y = vy;
1324
-
1325
- #if defined(__ARM_NEON)
1326
- for (int i = 0; i < nb; i++) {
1327
- float32x4_t srcv [8];
1328
- float32x4_t asrcv[8];
1329
- float32x4_t amaxv[8];
1330
-
1331
- for (int j = 0; j < 8; j++) srcv[j] = vld1q_f32(x + i*32 + 4*j);
1332
- for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[j]);
1333
-
1334
- for (int j = 0; j < 4; j++) amaxv[2*j] = vmaxq_f32(asrcv[2*j], asrcv[2*j+1]);
1335
- for (int j = 0; j < 2; j++) amaxv[4*j] = vmaxq_f32(amaxv[4*j], amaxv[4*j+2]);
1336
- for (int j = 0; j < 1; j++) amaxv[8*j] = vmaxq_f32(amaxv[8*j], amaxv[8*j+4]);
1337
-
1338
- const float amax = vmaxvq_f32(amaxv[0]);
1339
-
1340
- const float d = amax / ((1 << 7) - 1);
1341
- const float id = d ? 1.0f/d : 0.0f;
1342
-
1343
- y[i].d = d;
1344
-
1345
- int32x4_t accv = vdupq_n_s32(0);
1346
-
1347
- for (int j = 0; j < 8; j++) {
1348
- const float32x4_t v = vmulq_n_f32(srcv[j], id);
1349
- const int32x4_t vi = vcvtnq_s32_f32(v);
1350
-
1351
- y[i].qs[4*j + 0] = vgetq_lane_s32(vi, 0);
1352
- y[i].qs[4*j + 1] = vgetq_lane_s32(vi, 1);
1353
- y[i].qs[4*j + 2] = vgetq_lane_s32(vi, 2);
1354
- y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3);
1355
-
1356
- accv = vaddq_s32(accv, vi);
1357
- }
1358
-
1359
- y[i].s = d * vaddvq_s32(accv);
1360
- }
1361
- #elif defined(__wasm_simd128__)
1362
- for (int i = 0; i < nb; i++) {
1363
- v128_t srcv [8];
1364
- v128_t asrcv[8];
1365
- v128_t amaxv[8];
1366
-
1367
- for (int j = 0; j < 8; j++) srcv[j] = wasm_v128_load(x + i*32 + 4*j);
1368
- for (int j = 0; j < 8; j++) asrcv[j] = wasm_f32x4_abs(srcv[j]);
1369
-
1370
- for (int j = 0; j < 4; j++) amaxv[2*j] = wasm_f32x4_max(asrcv[2*j], asrcv[2*j+1]);
1371
- for (int j = 0; j < 2; j++) amaxv[4*j] = wasm_f32x4_max(amaxv[4*j], amaxv[4*j+2]);
1372
- for (int j = 0; j < 1; j++) amaxv[8*j] = wasm_f32x4_max(amaxv[8*j], amaxv[8*j+4]);
1373
-
1374
- const float amax = MAX(MAX(wasm_f32x4_extract_lane(amaxv[0], 0),
1375
- wasm_f32x4_extract_lane(amaxv[0], 1)),
1376
- MAX(wasm_f32x4_extract_lane(amaxv[0], 2),
1377
- wasm_f32x4_extract_lane(amaxv[0], 3)));
1378
-
1379
- const float d = amax / ((1 << 7) - 1);
1380
- const float id = d ? 1.0f/d : 0.0f;
1381
-
1382
- y[i].d = d;
1383
-
1384
- v128_t accv = wasm_i32x4_splat(0);
1385
-
1386
- for (int j = 0; j < 8; j++) {
1387
- const v128_t v = wasm_f32x4_mul(srcv[j], wasm_f32x4_splat(id));
1388
- const v128_t vi = wasm_i32x4_trunc_sat_f32x4(v);
1389
-
1390
- y[i].qs[4*j + 0] = wasm_i32x4_extract_lane(vi, 0);
1391
- y[i].qs[4*j + 1] = wasm_i32x4_extract_lane(vi, 1);
1392
- y[i].qs[4*j + 2] = wasm_i32x4_extract_lane(vi, 2);
1393
- y[i].qs[4*j + 3] = wasm_i32x4_extract_lane(vi, 3);
1394
-
1395
- accv = wasm_i32x4_add(accv, vi);
1396
- }
1397
-
1398
- y[i].s = d * (wasm_i32x4_extract_lane(accv, 0) +
1399
- wasm_i32x4_extract_lane(accv, 1) +
1400
- wasm_i32x4_extract_lane(accv, 2) +
1401
- wasm_i32x4_extract_lane(accv, 3));
1402
- }
1403
- #elif defined(__AVX2__) || defined(__AVX__)
1404
- for (int i = 0; i < nb; i++) {
1405
- // Load elements into 4 AVX vectors
1406
- __m256 v0 = _mm256_loadu_ps( x );
1407
- __m256 v1 = _mm256_loadu_ps( x + 8 );
1408
- __m256 v2 = _mm256_loadu_ps( x + 16 );
1409
- __m256 v3 = _mm256_loadu_ps( x + 24 );
1410
- x += 32;
1411
-
1412
- // Compute max(abs(e)) for the block
1413
- const __m256 signBit = _mm256_set1_ps( -0.0f );
1414
- __m256 maxAbs = _mm256_andnot_ps( signBit, v0 );
1415
- maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) );
1416
- maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );
1417
- maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) );
1418
-
1419
- __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
1420
- max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
1421
- max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
1422
- const float maxScalar = _mm_cvtss_f32( max4 );
1423
-
1424
- // Quantize these floats
1425
- const float d = maxScalar / 127.f;
1426
- y[i].d = d;
1427
- const float id = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f;
1428
- const __m256 mul = _mm256_set1_ps( id );
1429
-
1430
- // Apply the multiplier
1431
- v0 = _mm256_mul_ps( v0, mul );
1432
- v1 = _mm256_mul_ps( v1, mul );
1433
- v2 = _mm256_mul_ps( v2, mul );
1434
- v3 = _mm256_mul_ps( v3, mul );
1435
-
1436
- // Round to nearest integer
1437
- v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST );
1438
- v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST );
1439
- v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST );
1440
- v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST );
1441
-
1442
- // Convert floats to integers
1443
- __m256i i0 = _mm256_cvtps_epi32( v0 );
1444
- __m256i i1 = _mm256_cvtps_epi32( v1 );
1445
- __m256i i2 = _mm256_cvtps_epi32( v2 );
1446
- __m256i i3 = _mm256_cvtps_epi32( v3 );
1447
-
1448
- #if defined(__AVX2__)
1449
- // Compute the sum of the quants and set y[i].s
1450
- y[i].s = d * hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3)));
1451
-
1452
- // Convert int32 to int16
1453
- i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15
1454
- i2 = _mm256_packs_epi32( i2, i3 ); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31
1455
- // Convert int16 to int8
1456
- i0 = _mm256_packs_epi16( i0, i2 ); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31
1457
-
1458
- // We got our precious signed bytes, but the order is now wrong
1459
- // These AVX2 pack instructions process 16-byte pieces independently
1460
- // The following instruction is fixing the order
1461
- const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
1462
- i0 = _mm256_permutevar8x32_epi32( i0, perm );
1463
-
1464
- _mm256_storeu_si256((__m256i *)y[i].qs, i0);
1465
- #else
1466
- // Since we don't have in AVX some necessary functions,
1467
- // we split the registers in half and call AVX2 analogs from SSE
1468
- __m128i ni0 = _mm256_castsi256_si128( i0 );
1469
- __m128i ni1 = _mm256_extractf128_si256( i0, 1);
1470
- __m128i ni2 = _mm256_castsi256_si128( i1 );
1471
- __m128i ni3 = _mm256_extractf128_si256( i1, 1);
1472
- __m128i ni4 = _mm256_castsi256_si128( i2 );
1473
- __m128i ni5 = _mm256_extractf128_si256( i2, 1);
1474
- __m128i ni6 = _mm256_castsi256_si128( i3 );
1475
- __m128i ni7 = _mm256_extractf128_si256( i3, 1);
1476
-
1477
- // Compute the sum of the quants and set y[i].s
1478
- const __m128i s0 = _mm_add_epi32(_mm_add_epi32(ni0, ni1), _mm_add_epi32(ni2, ni3));
1479
- const __m128i s1 = _mm_add_epi32(_mm_add_epi32(ni4, ni5), _mm_add_epi32(ni6, ni7));
1480
- y[i].s = d * hsum_i32_4(_mm_add_epi32(s0, s1));
1481
-
1482
- // Convert int32 to int16
1483
- ni0 = _mm_packs_epi32( ni0, ni1 );
1484
- ni2 = _mm_packs_epi32( ni2, ni3 );
1485
- ni4 = _mm_packs_epi32( ni4, ni5 );
1486
- ni6 = _mm_packs_epi32( ni6, ni7 );
1487
- // Convert int16 to int8
1488
- ni0 = _mm_packs_epi16( ni0, ni2 );
1489
- ni4 = _mm_packs_epi16( ni4, ni6 );
1490
-
1491
- _mm_storeu_si128((__m128i *)(y[i].qs + 0), ni0);
1492
- _mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4);
1493
- #endif
1494
- }
1495
- #elif defined(__riscv_v_intrinsic)
1496
-
1497
- size_t vl = __riscv_vsetvl_e32m4(QK8_1);
1498
-
1499
- for (int i = 0; i < nb; i++) {
1500
- // load elements
1501
- vfloat32m4_t v_x = __riscv_vle32_v_f32m4(x+i*QK8_1, vl);
1502
-
1503
- vfloat32m4_t vfabs = __riscv_vfabs_v_f32m4(v_x, vl);
1504
- vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0, vl);
1505
- vfloat32m1_t vmax = __riscv_vfredmax_vs_f32m4_f32m1(vfabs, tmp, vl);
1506
- float amax = __riscv_vfmv_f_s_f32m1_f32(vmax);
1507
-
1508
- const float d = amax / ((1 << 7) - 1);
1509
- const float id = d ? 1.0f/d : 0.0f;
1510
-
1511
- y[i].d = d;
1512
-
1513
- vfloat32m4_t x0 = __riscv_vfmul_vf_f32m4(v_x, id, vl);
1514
-
1515
- // convert to integer
1516
- vint16m2_t vi = __riscv_vfncvt_x_f_w_i16m2(x0, vl);
1517
- vint8m1_t vs = __riscv_vncvt_x_x_w_i8m1(vi, vl);
1518
-
1519
- // store result
1520
- __riscv_vse8_v_i8m1(y[i].qs , vs, vl);
1521
-
1522
- // compute sum for y[i].s
1523
- vint16m1_t tmp2 = __riscv_vmv_v_x_i16m1(0, vl);
1524
- vint16m1_t vwrs = __riscv_vwredsum_vs_i8m1_i16m1(vs, tmp2, vl);
1525
-
1526
- // set y[i].s
1527
- int sum = __riscv_vmv_x_s_i16m1_i16(vwrs);
1528
- y[i].s = sum*d;
1529
- }
1530
- #else
1531
- // scalar
1532
- quantize_row_q8_1_reference(x, y, k);
1533
- #endif
1534
- }
1535
-
1536
- static void dequantize_row_q4_0(const block_q4_0 * restrict x, float * restrict y, int k) {
1537
- static const int qk = QK4_0;
1538
-
1539
- assert(k % qk == 0);
1540
-
1541
- const int nb = k / qk;
1542
-
1543
- for (int i = 0; i < nb; i++) {
1544
- const float d = GGML_FP16_TO_FP32(x[i].d);
1545
-
1546
- for (int j = 0; j < qk/2; ++j) {
1547
- const int x0 = (x[i].qs[j] & 0x0F) - 8;
1548
- const int x1 = (x[i].qs[j] >> 4) - 8;
1549
-
1550
- y[i*qk + j + 0 ] = x0*d;
1551
- y[i*qk + j + qk/2] = x1*d;
1552
- }
1553
- }
1554
- }
1555
-
1556
- static void dequantize_row_q4_1(const block_q4_1 * restrict x, float * restrict y, int k) {
1557
- static const int qk = QK4_1;
1558
-
1559
- assert(k % qk == 0);
1560
-
1561
- const int nb = k / qk;
1562
-
1563
- for (int i = 0; i < nb; i++) {
1564
- const float d = GGML_FP16_TO_FP32(x[i].d);
1565
- const float m = GGML_FP16_TO_FP32(x[i].m);
1566
-
1567
- for (int j = 0; j < qk/2; ++j) {
1568
- const int x0 = (x[i].qs[j] & 0x0F);
1569
- const int x1 = (x[i].qs[j] >> 4);
1570
-
1571
- y[i*qk + j + 0 ] = x0*d + m;
1572
- y[i*qk + j + qk/2] = x1*d + m;
1573
- }
1574
- }
1575
- }
1576
-
1577
- static void dequantize_row_q5_0(const block_q5_0 * restrict x, float * restrict y, int k) {
1578
- static const int qk = QK5_0;
1579
-
1580
- assert(k % qk == 0);
1581
-
1582
- const int nb = k / qk;
1583
-
1584
- for (int i = 0; i < nb; i++) {
1585
- const float d = GGML_FP16_TO_FP32(x[i].d);
1586
-
1587
- uint32_t qh;
1588
- memcpy(&qh, x[i].qh, sizeof(qh));
1589
-
1590
- for (int j = 0; j < qk/2; ++j) {
1591
- const uint8_t xh_0 = ((qh >> (j + 0)) << 4) & 0x10;
1592
- const uint8_t xh_1 = ((qh >> (j + 12)) ) & 0x10;
1593
-
1594
- const int32_t x0 = ((x[i].qs[j] & 0x0F) | xh_0) - 16;
1595
- const int32_t x1 = ((x[i].qs[j] >> 4) | xh_1) - 16;
1596
-
1597
- y[i*qk + j + 0 ] = x0*d;
1598
- y[i*qk + j + qk/2] = x1*d;
1599
- }
1600
- }
1601
- }
1602
-
1603
- static void dequantize_row_q5_1(const block_q5_1 * restrict x, float * restrict y, int k) {
1604
- static const int qk = QK5_1;
1605
-
1606
- assert(k % qk == 0);
1607
-
1608
- const int nb = k / qk;
1609
-
1610
- for (int i = 0; i < nb; i++) {
1611
- const float d = GGML_FP16_TO_FP32(x[i].d);
1612
- const float m = GGML_FP16_TO_FP32(x[i].m);
1613
-
1614
- uint32_t qh;
1615
- memcpy(&qh, x[i].qh, sizeof(qh));
1616
-
1617
- for (int j = 0; j < qk/2; ++j) {
1618
- const uint8_t xh_0 = ((qh >> (j + 0)) << 4) & 0x10;
1619
- const uint8_t xh_1 = ((qh >> (j + 12)) ) & 0x10;
1620
-
1621
- const int x0 = (x[i].qs[j] & 0x0F) | xh_0;
1622
- const int x1 = (x[i].qs[j] >> 4) | xh_1;
1623
-
1624
- y[i*qk + j + 0 ] = x0*d + m;
1625
- y[i*qk + j + qk/2] = x1*d + m;
1626
- }
1627
- }
1628
- }
1629
-
1630
- static void dequantize_row_q8_0(const void * restrict vx, float * restrict y, int k) {
1631
- static const int qk = QK8_0;
1632
-
1633
- assert(k % qk == 0);
1634
-
1635
- const int nb = k / qk;
1636
-
1637
- const block_q8_0 * restrict x = vx;
1638
-
1639
- for (int i = 0; i < nb; i++) {
1640
- const float d = GGML_FP16_TO_FP32(x[i].d);
1641
-
1642
- for (int j = 0; j < qk; ++j) {
1643
- y[i*qk + j] = x[i].qs[j]*d;
1644
- }
1645
- }
1646
- }
1647
-
1648
- static void ggml_vec_dot_f32(const int n, float * restrict s, const float * restrict x, const float * restrict y);
1649
- static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t * restrict x, ggml_fp16_t * restrict y);
1650
- static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
1651
- static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
1652
- static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
1653
- static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
1654
- static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
1655
-
1656
- static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
1657
- [GGML_TYPE_I8] = {
1658
- .type_name = "i8",
1659
- .blck_size = 1,
1660
- .type_size = sizeof(int8_t),
1661
- .is_quantized = false,
1662
- },
1663
- [GGML_TYPE_I16] = {
1664
- .type_name = "i16",
1665
- .blck_size = 1,
1666
- .type_size = sizeof(int16_t),
1667
- .is_quantized = false,
1668
- },
1669
- [GGML_TYPE_I32] = {
1670
- .type_name = "i32",
1671
- .blck_size = 1,
1672
- .type_size = sizeof(int32_t),
1673
- .is_quantized = false,
1674
- },
1675
- [GGML_TYPE_F32] = {
1676
- .type_name = "f32",
1677
- .blck_size = 1,
1678
- .type_size = sizeof(float),
1679
- .is_quantized = false,
1680
- .vec_dot = (ggml_vec_dot_t) ggml_vec_dot_f32,
1681
- .vec_dot_type = GGML_TYPE_F32,
1682
- },
1683
- [GGML_TYPE_F16] = {
1684
- .type_name = "f16",
1685
- .blck_size = 1,
1686
- .type_size = sizeof(ggml_fp16_t),
1687
- .is_quantized = false,
1688
- .to_float = (ggml_to_float_t) ggml_fp16_to_fp32_row,
1689
- .from_float = (ggml_from_float_t) ggml_fp32_to_fp16_row,
1690
- .from_float_reference = (ggml_from_float_t) ggml_fp32_to_fp16_row,
1691
- .vec_dot = (ggml_vec_dot_t) ggml_vec_dot_f16,
1692
- .vec_dot_type = GGML_TYPE_F16,
1693
- },
1694
- [GGML_TYPE_Q4_0] = {
1695
- .type_name = "q4_0",
1696
- .blck_size = QK4_0,
1697
- .type_size = sizeof(block_q4_0),
1698
- .is_quantized = true,
1699
- .to_float = (ggml_to_float_t) dequantize_row_q4_0,
1700
- .from_float = quantize_row_q4_0,
1701
- .from_float_reference = (ggml_from_float_t) quantize_row_q4_0_reference,
1702
- .vec_dot = ggml_vec_dot_q4_0_q8_0,
1703
- .vec_dot_type = GGML_TYPE_Q8_0,
1704
- },
1705
- [GGML_TYPE_Q4_1] = {
1706
- .type_name = "q4_1",
1707
- .blck_size = QK4_1,
1708
- .type_size = sizeof(block_q4_1),
1709
- .is_quantized = true,
1710
- .to_float = (ggml_to_float_t) dequantize_row_q4_1,
1711
- .from_float = quantize_row_q4_1,
1712
- .from_float_reference = (ggml_from_float_t) quantize_row_q4_1_reference,
1713
- .vec_dot = ggml_vec_dot_q4_1_q8_1,
1714
- .vec_dot_type = GGML_TYPE_Q8_1,
1715
- },
1716
- [GGML_TYPE_Q5_0] = {
1717
- .type_name = "q5_0",
1718
- .blck_size = QK5_0,
1719
- .type_size = sizeof(block_q5_0),
1720
- .is_quantized = true,
1721
- .to_float = (ggml_to_float_t) dequantize_row_q5_0,
1722
- .from_float = quantize_row_q5_0,
1723
- .from_float_reference = (ggml_from_float_t) quantize_row_q5_0_reference,
1724
- .vec_dot = ggml_vec_dot_q5_0_q8_0,
1725
- .vec_dot_type = GGML_TYPE_Q8_0,
1726
- },
1727
- [GGML_TYPE_Q5_1] = {
1728
- .type_name = "q5_1",
1729
- .blck_size = QK5_1,
1730
- .type_size = sizeof(block_q5_1),
1731
- .is_quantized = true,
1732
- .to_float = (ggml_to_float_t) dequantize_row_q5_1,
1733
- .from_float = quantize_row_q5_1,
1734
- .from_float_reference = (ggml_from_float_t) quantize_row_q5_1_reference,
1735
- .vec_dot = ggml_vec_dot_q5_1_q8_1,
1736
- .vec_dot_type = GGML_TYPE_Q8_1,
1737
- },
1738
- [GGML_TYPE_Q8_0] = {
1739
- .type_name = "q8_0",
1740
- .blck_size = QK8_0,
1741
- .type_size = sizeof(block_q8_0),
1742
- .is_quantized = true,
1743
- .to_float = dequantize_row_q8_0,
1744
- .from_float = quantize_row_q8_0,
1745
- .from_float_reference = (ggml_from_float_t) quantize_row_q8_0_reference,
1746
- .vec_dot = ggml_vec_dot_q8_0_q8_0,
1747
- .vec_dot_type = GGML_TYPE_Q8_0,
1748
- },
1749
- [GGML_TYPE_Q8_1] = {
1750
- .type_name = "q8_1",
1751
- .blck_size = QK8_1,
1752
- .type_size = sizeof(block_q8_1),
1753
- .is_quantized = true,
1754
- .from_float = quantize_row_q8_1,
1755
- .from_float_reference = (ggml_from_float_t) quantize_row_q8_1_reference,
1756
- .vec_dot_type = GGML_TYPE_Q8_1,
1757
- },
1758
- #ifdef GGML_USE_K_QUANTS
1759
- [GGML_TYPE_Q2_K] = {
1760
- .type_name = "q2_K",
1761
- .blck_size = QK_K,
1762
- .type_size = sizeof(block_q2_K),
1763
- .is_quantized = true,
1764
- .to_float = (ggml_to_float_t) dequantize_row_q2_K,
1765
- .from_float = quantize_row_q2_K,
1766
- .from_float_reference = (ggml_from_float_t) quantize_row_q2_K_reference,
1767
- .vec_dot = ggml_vec_dot_q2_K_q8_K,
1768
- .vec_dot_type = GGML_TYPE_Q8_K,
1769
- },
1770
- [GGML_TYPE_Q3_K] = {
1771
- .type_name = "q3_K",
1772
- .blck_size = QK_K,
1773
- .type_size = sizeof(block_q3_K),
1774
- .is_quantized = true,
1775
- .to_float = (ggml_to_float_t) dequantize_row_q3_K,
1776
- .from_float = quantize_row_q3_K,
1777
- .from_float_reference = (ggml_from_float_t) quantize_row_q3_K_reference,
1778
- .vec_dot = ggml_vec_dot_q3_K_q8_K,
1779
- .vec_dot_type = GGML_TYPE_Q8_K,
1780
- },
1781
- [GGML_TYPE_Q4_K] = {
1782
- .type_name = "q4_K",
1783
- .blck_size = QK_K,
1784
- .type_size = sizeof(block_q4_K),
1785
- .is_quantized = true,
1786
- .to_float = (ggml_to_float_t) dequantize_row_q4_K,
1787
- .from_float = quantize_row_q4_K,
1788
- .from_float_reference = (ggml_from_float_t) quantize_row_q4_K_reference,
1789
- .vec_dot = ggml_vec_dot_q4_K_q8_K,
1790
- .vec_dot_type = GGML_TYPE_Q8_K,
1791
- },
1792
- [GGML_TYPE_Q5_K] = {
1793
- .type_name = "q5_K",
1794
- .blck_size = QK_K,
1795
- .type_size = sizeof(block_q5_K),
1796
- .is_quantized = true,
1797
- .to_float = (ggml_to_float_t) dequantize_row_q5_K,
1798
- .from_float = quantize_row_q5_K,
1799
- .from_float_reference = (ggml_from_float_t) quantize_row_q5_K_reference,
1800
- .vec_dot = ggml_vec_dot_q5_K_q8_K,
1801
- .vec_dot_type = GGML_TYPE_Q8_K,
1802
- },
1803
- [GGML_TYPE_Q6_K] = {
1804
- .type_name = "q6_K",
1805
- .blck_size = QK_K,
1806
- .type_size = sizeof(block_q6_K),
1807
- .is_quantized = true,
1808
- .to_float = (ggml_to_float_t) dequantize_row_q6_K,
1809
- .from_float = quantize_row_q6_K,
1810
- .from_float_reference = (ggml_from_float_t) quantize_row_q6_K_reference,
1811
- .vec_dot = ggml_vec_dot_q6_K_q8_K,
1812
- .vec_dot_type = GGML_TYPE_Q8_K,
1813
- },
1814
- [GGML_TYPE_Q8_K] = {
1815
- .type_name = "q8_K",
1816
- .blck_size = QK_K,
1817
- .type_size = sizeof(block_q8_K),
1818
- .is_quantized = true,
1819
- .from_float = quantize_row_q8_K,
1820
- }
1821
- #endif
1822
- };
1823
-
1824
- // For internal test use
1825
- ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type) {
1826
- GGML_ASSERT(type < GGML_TYPE_COUNT);
1827
- return type_traits[type];
1828
- }
1829
-
1830
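Everything ggml needs to size and convert tensor data comes from two numbers in the table above: blck_size (elements per block) and type_size (bytes per block). A hedged sketch of the derived row-size arithmetic (the demo helper below is illustrative, not the library's API at this version):

    #include <assert.h>
    #include <stddef.h>
    #include <stdint.h>

    // Byte size of a row of ne elements for a type with the given traits:
    // ne must be a multiple of blck_size; each group of blck_size elements costs type_size bytes.
    static size_t demo_row_size(size_t type_size, int blck_size, int64_t ne) {
        assert(ne % blck_size == 0);
        return (size_t)(ne / blck_size) * type_size;
    }
    // Example: a 4096-wide q4_0 row is 4096/32 = 128 blocks of sizeof(block_q4_0) = 18 bytes -> 2304 bytes.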
- //
1831
- // simd mappings
1832
- //
1833
-
1834
- // we define a common set of C macros which map to specific intrinsics based on the current architecture
1835
- // we then implement the fundamental computation operations below using only these macros
1836
- // adding support for new architectures requires defining the corresponding SIMD macros
1837
- //
1838
- // GGML_F32_STEP / GGML_F16_STEP
1839
- // number of elements to process in a single step
1840
- //
1841
- // GGML_F32_EPR / GGML_F16_EPR
1842
- // number of elements to fit in a single register
1843
- //
1844
-
1845
- #if defined(__ARM_NEON) && defined(__ARM_FEATURE_FMA)
1846
-
1847
- #define GGML_SIMD
1848
-
1849
- // F32 NEON
1850
-
1851
- #define GGML_F32_STEP 16
1852
- #define GGML_F32_EPR 4
1853
-
1854
- #define GGML_F32x4 float32x4_t
1855
- #define GGML_F32x4_ZERO vdupq_n_f32(0.0f)
1856
- #define GGML_F32x4_SET1(x) vdupq_n_f32(x)
1857
- #define GGML_F32x4_LOAD vld1q_f32
1858
- #define GGML_F32x4_STORE vst1q_f32
1859
- #define GGML_F32x4_FMA(a, b, c) vfmaq_f32(a, b, c)
1860
- #define GGML_F32x4_ADD vaddq_f32
1861
- #define GGML_F32x4_MUL vmulq_f32
1862
- #define GGML_F32x4_REDUCE_ONE(x) vaddvq_f32(x)
1863
- #define GGML_F32x4_REDUCE(res, x) \
1864
- { \
1865
- int offset = GGML_F32_ARR >> 1; \
1866
- for (int i = 0; i < offset; ++i) { \
1867
- x[i] = vaddq_f32(x[i], x[offset+i]); \
1868
- } \
1869
- offset >>= 1; \
1870
- for (int i = 0; i < offset; ++i) { \
1871
- x[i] = vaddq_f32(x[i], x[offset+i]); \
1872
- } \
1873
- offset >>= 1; \
1874
- for (int i = 0; i < offset; ++i) { \
1875
- x[i] = vaddq_f32(x[i], x[offset+i]); \
1876
- } \
1877
- res = GGML_F32x4_REDUCE_ONE(x[0]); \
1878
- }
1879
-
1880
- #define GGML_F32_VEC GGML_F32x4
1881
- #define GGML_F32_VEC_ZERO GGML_F32x4_ZERO
1882
- #define GGML_F32_VEC_SET1 GGML_F32x4_SET1
1883
- #define GGML_F32_VEC_LOAD GGML_F32x4_LOAD
1884
- #define GGML_F32_VEC_STORE GGML_F32x4_STORE
1885
- #define GGML_F32_VEC_FMA GGML_F32x4_FMA
1886
- #define GGML_F32_VEC_ADD GGML_F32x4_ADD
1887
- #define GGML_F32_VEC_MUL GGML_F32x4_MUL
1888
- #define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE
1889
-
1890
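GGML_F32x4_REDUCE above folds GGML_F32_ARR partial-sum registers into one value by repeated pairwise addition, then a final horizontal add of the surviving vector. A scalar model of the same shape, assuming four accumulators of four lanes each as on this NEON path:

    // Scalar model of GGML_F32x4_REDUCE: 4 accumulators x 4 lanes -> one float.
    static float demo_reduce_4x4(float acc[4][4]) {
        // pairwise tree: 4 -> 2 -> 1 accumulators (the macro hard-codes three halving passes,
        // which is enough for every GGML_F32_ARR used in this file)
        for (int step = 4 >> 1; step >= 1; step >>= 1) {
            for (int i = 0; i < step; ++i) {
                for (int lane = 0; lane < 4; ++lane) {
                    acc[i][lane] += acc[i + step][lane];
                }
            }
        }
        // final horizontal add of the surviving vector (vaddvq_f32 in the NEON macro)
        return acc[0][0] + acc[0][1] + acc[0][2] + acc[0][3];
    }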
- // F16 NEON
1891
-
1892
- #if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
1893
- #define GGML_F16_STEP 32
1894
- #define GGML_F16_EPR 8
1895
-
1896
- #define GGML_F16x8 float16x8_t
1897
- #define GGML_F16x8_ZERO vdupq_n_f16(0.0f)
1898
- #define GGML_F16x8_SET1(x) vdupq_n_f16(x)
1899
- #define GGML_F16x8_LOAD vld1q_f16
1900
- #define GGML_F16x8_STORE vst1q_f16
1901
- #define GGML_F16x8_FMA(a, b, c) vfmaq_f16(a, b, c)
1902
- #define GGML_F16x8_ADD vaddq_f16
1903
- #define GGML_F16x8_MUL vmulq_f16
1904
- #define GGML_F16x8_REDUCE(res, x) \
1905
- do { \
1906
- int offset = GGML_F16_ARR >> 1; \
1907
- for (int i = 0; i < offset; ++i) { \
1908
- x[i] = vaddq_f16(x[i], x[offset+i]); \
1909
- } \
1910
- offset >>= 1; \
1911
- for (int i = 0; i < offset; ++i) { \
1912
- x[i] = vaddq_f16(x[i], x[offset+i]); \
1913
- } \
1914
- offset >>= 1; \
1915
- for (int i = 0; i < offset; ++i) { \
1916
- x[i] = vaddq_f16(x[i], x[offset+i]); \
1917
- } \
1918
- const float32x4_t t0 = vcvt_f32_f16(vget_low_f16 (x[0])); \
1919
- const float32x4_t t1 = vcvt_f32_f16(vget_high_f16(x[0])); \
1920
- res = (ggml_float) vaddvq_f32(vaddq_f32(t0, t1)); \
1921
- } while (0)
1922
-
1923
- #define GGML_F16_VEC GGML_F16x8
1924
- #define GGML_F16_VEC_ZERO GGML_F16x8_ZERO
1925
- #define GGML_F16_VEC_SET1 GGML_F16x8_SET1
1926
- #define GGML_F16_VEC_LOAD(p, i) GGML_F16x8_LOAD(p)
1927
- #define GGML_F16_VEC_STORE(p, r, i) GGML_F16x8_STORE(p, r[i])
1928
- #define GGML_F16_VEC_FMA GGML_F16x8_FMA
1929
- #define GGML_F16_VEC_ADD GGML_F16x8_ADD
1930
- #define GGML_F16_VEC_MUL GGML_F16x8_MUL
1931
- #define GGML_F16_VEC_REDUCE GGML_F16x8_REDUCE
1932
- #else
1933
- // if FP16 vector arithmetic is not supported, we use FP32 instead
1934
- // and take advantage of the vcvt_ functions to convert to/from FP16
1935
-
1936
- #define GGML_F16_STEP 16
1937
- #define GGML_F16_EPR 4
1938
-
1939
- #define GGML_F32Cx4 float32x4_t
1940
- #define GGML_F32Cx4_ZERO vdupq_n_f32(0.0f)
1941
- #define GGML_F32Cx4_SET1(x) vdupq_n_f32(x)
1942
- #define GGML_F32Cx4_LOAD(x) vcvt_f32_f16(vld1_f16(x))
1943
- #define GGML_F32Cx4_STORE(x, y) vst1_f16(x, vcvt_f16_f32(y))
1944
- #define GGML_F32Cx4_FMA(a, b, c) vfmaq_f32(a, b, c)
1945
- #define GGML_F32Cx4_ADD vaddq_f32
1946
- #define GGML_F32Cx4_MUL vmulq_f32
1947
- #define GGML_F32Cx4_REDUCE GGML_F32x4_REDUCE
1948
-
1949
- #define GGML_F16_VEC GGML_F32Cx4
1950
- #define GGML_F16_VEC_ZERO GGML_F32Cx4_ZERO
1951
- #define GGML_F16_VEC_SET1 GGML_F32Cx4_SET1
1952
- #define GGML_F16_VEC_LOAD(p, i) GGML_F32Cx4_LOAD(p)
1953
- #define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx4_STORE(p, r[i])
1954
- #define GGML_F16_VEC_FMA GGML_F32Cx4_FMA
1955
- #define GGML_F16_VEC_ADD GGML_F32Cx4_ADD
1956
- #define GGML_F16_VEC_MUL GGML_F32Cx4_MUL
1957
- #define GGML_F16_VEC_REDUCE GGML_F32Cx4_REDUCE
1958
- #endif
1959
-
1960
- #elif defined(__AVX__)
1961
-
1962
- #define GGML_SIMD
1963
-
1964
- // F32 AVX
1965
-
1966
- #define GGML_F32_STEP 32
1967
- #define GGML_F32_EPR 8
1968
-
1969
- #define GGML_F32x8 __m256
1970
- #define GGML_F32x8_ZERO _mm256_setzero_ps()
1971
- #define GGML_F32x8_SET1(x) _mm256_set1_ps(x)
1972
- #define GGML_F32x8_LOAD _mm256_loadu_ps
1973
- #define GGML_F32x8_STORE _mm256_storeu_ps
1974
- #if defined(__FMA__)
1975
- #define GGML_F32x8_FMA(a, b, c) _mm256_fmadd_ps(b, c, a)
1976
- #else
1977
- #define GGML_F32x8_FMA(a, b, c) _mm256_add_ps(_mm256_mul_ps(b, c), a)
1978
- #endif
1979
- #define GGML_F32x8_ADD _mm256_add_ps
1980
- #define GGML_F32x8_MUL _mm256_mul_ps
1981
- #define GGML_F32x8_REDUCE(res, x) \
1982
- do { \
1983
- int offset = GGML_F32_ARR >> 1; \
1984
- for (int i = 0; i < offset; ++i) { \
1985
- x[i] = _mm256_add_ps(x[i], x[offset+i]); \
1986
- } \
1987
- offset >>= 1; \
1988
- for (int i = 0; i < offset; ++i) { \
1989
- x[i] = _mm256_add_ps(x[i], x[offset+i]); \
1990
- } \
1991
- offset >>= 1; \
1992
- for (int i = 0; i < offset; ++i) { \
1993
- x[i] = _mm256_add_ps(x[i], x[offset+i]); \
1994
- } \
1995
- const __m128 t0 = _mm_add_ps(_mm256_castps256_ps128(x[0]), \
1996
- _mm256_extractf128_ps(x[0], 1)); \
1997
- const __m128 t1 = _mm_hadd_ps(t0, t0); \
1998
- res = _mm_cvtss_f32(_mm_hadd_ps(t1, t1)); \
1999
- } while (0)
2000
- // TODO: is this optimal ?
2001
-
2002
- #define GGML_F32_VEC GGML_F32x8
2003
- #define GGML_F32_VEC_ZERO GGML_F32x8_ZERO
2004
- #define GGML_F32_VEC_SET1 GGML_F32x8_SET1
2005
- #define GGML_F32_VEC_LOAD GGML_F32x8_LOAD
2006
- #define GGML_F32_VEC_STORE GGML_F32x8_STORE
2007
- #define GGML_F32_VEC_FMA GGML_F32x8_FMA
2008
- #define GGML_F32_VEC_ADD GGML_F32x8_ADD
2009
- #define GGML_F32_VEC_MUL GGML_F32x8_MUL
2010
- #define GGML_F32_VEC_REDUCE GGML_F32x8_REDUCE
2011
-
2012
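The tail of GGML_F32x8_REDUCE above is a horizontal sum of a single __m256: fold the upper 128-bit lane onto the lower one, then two hadds. The same steps as a standalone function (compile with AVX enabled; this mirrors the macro rather than claiming an optimal reduction, a question the TODO above also leaves open):

    #include <immintrin.h>

    // Horizontal sum of the 8 floats in an AVX register (same steps as the macro's tail).
    static inline float demo_hsum_ps_256(__m256 v) {
        __m128 lo = _mm256_castps256_ps128(v);       // lanes 0..3
        __m128 hi = _mm256_extractf128_ps(v, 1);     // lanes 4..7
        __m128 s  = _mm_add_ps(lo, hi);              // 4 partial sums
        s = _mm_hadd_ps(s, s);                       // 2 partial sums
        s = _mm_hadd_ps(s, s);                       // the full sum in every slot
        return _mm_cvtss_f32(s);
    }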
- // F16 AVX
2013
-
2014
- #define GGML_F16_STEP 32
2015
- #define GGML_F16_EPR 8
2016
-
2017
- // F16 arithmetic is not supported by AVX, so we use F32 instead
2018
-
2019
- #define GGML_F32Cx8 __m256
2020
- #define GGML_F32Cx8_ZERO _mm256_setzero_ps()
2021
- #define GGML_F32Cx8_SET1(x) _mm256_set1_ps(x)
2022
-
2023
- #if defined(__F16C__)
2024
- // the _mm256_cvt intrinsics require F16C
2025
- #define GGML_F32Cx8_LOAD(x) _mm256_cvtph_ps(_mm_loadu_si128((__m128i *)(x)))
2026
- #define GGML_F32Cx8_STORE(x, y) _mm_storeu_si128((__m128i *)(x), _mm256_cvtps_ph(y, 0))
2027
- #else
2028
- static inline __m256 __avx_f32cx8_load(ggml_fp16_t *x) {
2029
- float tmp[8];
2030
-
2031
- for (int i = 0; i < 8; i++) {
2032
- tmp[i] = GGML_FP16_TO_FP32(x[i]);
2033
- }
2034
-
2035
- return _mm256_loadu_ps(tmp);
2036
- }
2037
- static inline void __avx_f32cx8_store(ggml_fp16_t *x, __m256 y) {
2038
- float arr[8];
2039
-
2040
- _mm256_storeu_ps(arr, y);
2041
-
2042
- for (int i = 0; i < 8; i++)
2043
- x[i] = GGML_FP32_TO_FP16(arr[i]);
2044
- }
2045
- #define GGML_F32Cx8_LOAD(x) __avx_f32cx8_load(x)
2046
- #define GGML_F32Cx8_STORE(x, y) __avx_f32cx8_store(x, y)
2047
- #endif
2048
-
2049
- #define GGML_F32Cx8_FMA GGML_F32x8_FMA
2050
- #define GGML_F32Cx8_ADD _mm256_add_ps
2051
- #define GGML_F32Cx8_MUL _mm256_mul_ps
2052
- #define GGML_F32Cx8_REDUCE GGML_F32x8_REDUCE
2053
-
2054
- #define GGML_F16_VEC GGML_F32Cx8
2055
- #define GGML_F16_VEC_ZERO GGML_F32Cx8_ZERO
2056
- #define GGML_F16_VEC_SET1 GGML_F32Cx8_SET1
2057
- #define GGML_F16_VEC_LOAD(p, i) GGML_F32Cx8_LOAD(p)
2058
- #define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx8_STORE(p, r[i])
2059
- #define GGML_F16_VEC_FMA GGML_F32Cx8_FMA
2060
- #define GGML_F16_VEC_ADD GGML_F32Cx8_ADD
2061
- #define GGML_F16_VEC_MUL GGML_F32Cx8_MUL
2062
- #define GGML_F16_VEC_REDUCE GGML_F32Cx8_REDUCE
2063
-
2064
- #elif defined(__POWER9_VECTOR__)
2065
-
2066
- #define GGML_SIMD
2067
-
2068
- // F32 POWER9
2069
-
2070
- #define GGML_F32_STEP 32
2071
- #define GGML_F32_EPR 4
2072
-
2073
- #define GGML_F32x4 vector float
2074
- #define GGML_F32x4_ZERO 0.0f
2075
- #define GGML_F32x4_SET1 vec_splats
2076
- #define GGML_F32x4_LOAD(p) vec_xl(0, p)
2077
- #define GGML_F32x4_STORE(p, r) vec_xst(r, 0, p)
2078
- #define GGML_F32x4_FMA(a, b, c) vec_madd(b, c, a)
2079
- #define GGML_F32x4_ADD vec_add
2080
- #define GGML_F32x4_MUL vec_mul
2081
- #define GGML_F32x4_REDUCE(res, x) \
2082
- { \
2083
- int offset = GGML_F32_ARR >> 1; \
2084
- for (int i = 0; i < offset; ++i) { \
2085
- x[i] = vec_add(x[i], x[offset+i]); \
2086
- } \
2087
- offset >>= 1; \
2088
- for (int i = 0; i < offset; ++i) { \
2089
- x[i] = vec_add(x[i], x[offset+i]); \
2090
- } \
2091
- offset >>= 1; \
2092
- for (int i = 0; i < offset; ++i) { \
2093
- x[i] = vec_add(x[i], x[offset+i]); \
2094
- } \
2095
- res = vec_extract(x[0], 0) + \
2096
- vec_extract(x[0], 1) + \
2097
- vec_extract(x[0], 2) + \
2098
- vec_extract(x[0], 3); \
2099
- }
2100
-
2101
- #define GGML_F32_VEC GGML_F32x4
2102
- #define GGML_F32_VEC_ZERO GGML_F32x4_ZERO
2103
- #define GGML_F32_VEC_SET1 GGML_F32x4_SET1
2104
- #define GGML_F32_VEC_LOAD GGML_F32x4_LOAD
2105
- #define GGML_F32_VEC_STORE GGML_F32x4_STORE
2106
- #define GGML_F32_VEC_FMA GGML_F32x4_FMA
2107
- #define GGML_F32_VEC_ADD GGML_F32x4_ADD
2108
- #define GGML_F32_VEC_MUL GGML_F32x4_MUL
2109
- #define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE
2110
-
2111
- // F16 POWER9
2112
- #define GGML_F16_STEP GGML_F32_STEP
2113
- #define GGML_F16_EPR GGML_F32_EPR
2114
- #define GGML_F16_VEC GGML_F32x4
2115
- #define GGML_F16_VEC_ZERO GGML_F32x4_ZERO
2116
- #define GGML_F16_VEC_SET1 GGML_F32x4_SET1
2117
- #define GGML_F16_VEC_FMA GGML_F32x4_FMA
2118
- #define GGML_F16_VEC_REDUCE GGML_F32x4_REDUCE
2119
- // Use vec_xl, not vec_ld, in case the load address is not aligned.
2120
- #define GGML_F16_VEC_LOAD(p, i) (i & 0x1) ? \
2121
- vec_extract_fp32_from_shorth(vec_xl(0, p - GGML_F16_EPR)) : \
2122
- vec_extract_fp32_from_shortl(vec_xl(0, p))
2123
- #define GGML_ENDIAN_BYTE(i) ((unsigned char *)&(uint16_t){1})[i]
2124
- #define GGML_F16_VEC_STORE(p, r, i) \
2125
- if (i & 0x1) \
2126
- vec_xst(vec_pack_to_short_fp32(r[i - GGML_ENDIAN_BYTE(1)], \
2127
- r[i - GGML_ENDIAN_BYTE(0)]), \
2128
- 0, p - GGML_F16_EPR)
2129
-
2130
- #elif defined(__wasm_simd128__)
2131
-
2132
- #define GGML_SIMD
2133
-
2134
- // F32 WASM
2135
-
2136
- #define GGML_F32_STEP 16
2137
- #define GGML_F32_EPR 4
2138
-
2139
- #define GGML_F32x4 v128_t
2140
- #define GGML_F32x4_ZERO wasm_f32x4_splat(0.0f)
2141
- #define GGML_F32x4_SET1(x) wasm_f32x4_splat(x)
2142
- #define GGML_F32x4_LOAD wasm_v128_load
2143
- #define GGML_F32x4_STORE wasm_v128_store
2144
- #define GGML_F32x4_FMA(a, b, c) wasm_f32x4_add(wasm_f32x4_mul(b, c), a)
2145
- #define GGML_F32x4_ADD wasm_f32x4_add
2146
- #define GGML_F32x4_MUL wasm_f32x4_mul
2147
- #define GGML_F32x4_REDUCE(res, x) \
2148
- { \
2149
- int offset = GGML_F32_ARR >> 1; \
2150
- for (int i = 0; i < offset; ++i) { \
2151
- x[i] = wasm_f32x4_add(x[i], x[offset+i]); \
2152
- } \
2153
- offset >>= 1; \
2154
- for (int i = 0; i < offset; ++i) { \
2155
- x[i] = wasm_f32x4_add(x[i], x[offset+i]); \
2156
- } \
2157
- offset >>= 1; \
2158
- for (int i = 0; i < offset; ++i) { \
2159
- x[i] = wasm_f32x4_add(x[i], x[offset+i]); \
2160
- } \
2161
- res = wasm_f32x4_extract_lane(x[0], 0) + \
2162
- wasm_f32x4_extract_lane(x[0], 1) + \
2163
- wasm_f32x4_extract_lane(x[0], 2) + \
2164
- wasm_f32x4_extract_lane(x[0], 3); \
2165
- }
2166
-
2167
- #define GGML_F32_VEC GGML_F32x4
2168
- #define GGML_F32_VEC_ZERO GGML_F32x4_ZERO
2169
- #define GGML_F32_VEC_SET1 GGML_F32x4_SET1
2170
- #define GGML_F32_VEC_LOAD GGML_F32x4_LOAD
2171
- #define GGML_F32_VEC_STORE GGML_F32x4_STORE
2172
- #define GGML_F32_VEC_FMA GGML_F32x4_FMA
2173
- #define GGML_F32_VEC_ADD GGML_F32x4_ADD
2174
- #define GGML_F32_VEC_MUL GGML_F32x4_MUL
2175
- #define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE
2176
-
2177
- // F16 WASM
2178
-
2179
- #define GGML_F16_STEP 16
2180
- #define GGML_F16_EPR 4
2181
-
2182
- inline static v128_t __wasm_f16x4_load(const ggml_fp16_t * p) {
2183
- float tmp[4];
2184
-
2185
- tmp[0] = GGML_FP16_TO_FP32(p[0]);
2186
- tmp[1] = GGML_FP16_TO_FP32(p[1]);
2187
- tmp[2] = GGML_FP16_TO_FP32(p[2]);
2188
- tmp[3] = GGML_FP16_TO_FP32(p[3]);
2189
-
2190
- return wasm_v128_load(tmp);
2191
- }
2192
-
2193
- inline static void __wasm_f16x4_store(ggml_fp16_t * p, v128_t x) {
2194
- float tmp[4];
2195
-
2196
- wasm_v128_store(tmp, x);
2197
-
2198
- p[0] = GGML_FP32_TO_FP16(tmp[0]);
2199
- p[1] = GGML_FP32_TO_FP16(tmp[1]);
2200
- p[2] = GGML_FP32_TO_FP16(tmp[2]);
2201
- p[3] = GGML_FP32_TO_FP16(tmp[3]);
2202
- }
2203
-
2204
- #define GGML_F16x4 v128_t
2205
- #define GGML_F16x4_ZERO wasm_f32x4_splat(0.0f)
2206
- #define GGML_F16x4_SET1(x) wasm_f32x4_splat(x)
2207
- #define GGML_F16x4_LOAD(x) __wasm_f16x4_load(x)
2208
- #define GGML_F16x4_STORE(x, y) __wasm_f16x4_store(x, y)
2209
- #define GGML_F16x4_FMA GGML_F32x4_FMA
2210
- #define GGML_F16x4_ADD wasm_f32x4_add
2211
- #define GGML_F16x4_MUL wasm_f32x4_mul
2212
- #define GGML_F16x4_REDUCE(res, x) \
2213
- { \
2214
- int offset = GGML_F16_ARR >> 1; \
2215
- for (int i = 0; i < offset; ++i) { \
2216
- x[i] = wasm_f32x4_add(x[i], x[offset+i]); \
2217
- } \
2218
- offset >>= 1; \
2219
- for (int i = 0; i < offset; ++i) { \
2220
- x[i] = wasm_f32x4_add(x[i], x[offset+i]); \
2221
- } \
2222
- offset >>= 1; \
2223
- for (int i = 0; i < offset; ++i) { \
2224
- x[i] = wasm_f32x4_add(x[i], x[offset+i]); \
2225
- } \
2226
- res = wasm_f32x4_extract_lane(x[0], 0) + \
2227
- wasm_f32x4_extract_lane(x[0], 1) + \
2228
- wasm_f32x4_extract_lane(x[0], 2) + \
2229
- wasm_f32x4_extract_lane(x[0], 3); \
2230
- }
2231
-
2232
- #define GGML_F16_VEC GGML_F16x4
2233
- #define GGML_F16_VEC_ZERO GGML_F16x4_ZERO
2234
- #define GGML_F16_VEC_SET1 GGML_F16x4_SET1
2235
- #define GGML_F16_VEC_LOAD(p, i) GGML_F16x4_LOAD(p)
2236
- #define GGML_F16_VEC_STORE(p, r, i) GGML_F16x4_STORE(p, r[i])
2237
- #define GGML_F16_VEC_FMA GGML_F16x4_FMA
2238
- #define GGML_F16_VEC_ADD GGML_F16x4_ADD
2239
- #define GGML_F16_VEC_MUL GGML_F16x4_MUL
2240
- #define GGML_F16_VEC_REDUCE GGML_F16x4_REDUCE
2241
-
2242
- #elif defined(__SSE3__)
2243
-
2244
- #define GGML_SIMD
2245
-
2246
- // F32 SSE
2247
-
2248
- #define GGML_F32_STEP 32
2249
- #define GGML_F32_EPR 4
2250
-
2251
- #define GGML_F32x4 __m128
2252
- #define GGML_F32x4_ZERO _mm_setzero_ps()
2253
- #define GGML_F32x4_SET1(x) _mm_set1_ps(x)
2254
- #define GGML_F32x4_LOAD _mm_loadu_ps
2255
- #define GGML_F32x4_STORE _mm_storeu_ps
2256
- #if defined(__FMA__)
2257
- // TODO: Does this work?
2258
- #define GGML_F32x4_FMA(a, b, c) _mm_fmadd_ps(b, c, a)
2259
- #else
2260
- #define GGML_F32x4_FMA(a, b, c) _mm_add_ps(_mm_mul_ps(b, c), a)
2261
- #endif
2262
- #define GGML_F32x4_ADD _mm_add_ps
2263
- #define GGML_F32x4_MUL _mm_mul_ps
2264
- #define GGML_F32x4_REDUCE(res, x) \
2265
- { \
2266
- int offset = GGML_F32_ARR >> 1; \
2267
- for (int i = 0; i < offset; ++i) { \
2268
- x[i] = _mm_add_ps(x[i], x[offset+i]); \
2269
- } \
2270
- offset >>= 1; \
2271
- for (int i = 0; i < offset; ++i) { \
2272
- x[i] = _mm_add_ps(x[i], x[offset+i]); \
2273
- } \
2274
- offset >>= 1; \
2275
- for (int i = 0; i < offset; ++i) { \
2276
- x[i] = _mm_add_ps(x[i], x[offset+i]); \
2277
- } \
2278
- const __m128 t0 = _mm_hadd_ps(x[0], x[0]); \
2279
- res = _mm_cvtss_f32(_mm_hadd_ps(t0, t0)); \
2280
- }
2281
- // TODO: is this optimal ?
2282
-
2283
- #define GGML_F32_VEC GGML_F32x4
2284
- #define GGML_F32_VEC_ZERO GGML_F32x4_ZERO
2285
- #define GGML_F32_VEC_SET1 GGML_F32x4_SET1
2286
- #define GGML_F32_VEC_LOAD GGML_F32x4_LOAD
2287
- #define GGML_F32_VEC_STORE GGML_F32x4_STORE
2288
- #define GGML_F32_VEC_FMA GGML_F32x4_FMA
2289
- #define GGML_F32_VEC_ADD GGML_F32x4_ADD
2290
- #define GGML_F32_VEC_MUL GGML_F32x4_MUL
2291
- #define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE
2292
-
2293
- // F16 SSE
2294
-
2295
- #define GGML_F16_STEP 32
2296
- #define GGML_F16_EPR 4
2297
-
2298
- static inline __m128 __sse_f16x4_load(ggml_fp16_t *x) {
2299
- float tmp[4];
2300
-
2301
- tmp[0] = GGML_FP16_TO_FP32(x[0]);
2302
- tmp[1] = GGML_FP16_TO_FP32(x[1]);
2303
- tmp[2] = GGML_FP16_TO_FP32(x[2]);
2304
- tmp[3] = GGML_FP16_TO_FP32(x[3]);
2305
-
2306
- return _mm_loadu_ps(tmp);
2307
- }
2308
-
2309
- static inline void __sse_f16x4_store(ggml_fp16_t *x, __m128 y) {
2310
- float arr[4];
2311
-
2312
- _mm_storeu_ps(arr, y);
2313
-
2314
- x[0] = GGML_FP32_TO_FP16(arr[0]);
2315
- x[1] = GGML_FP32_TO_FP16(arr[1]);
2316
- x[2] = GGML_FP32_TO_FP16(arr[2]);
2317
- x[3] = GGML_FP32_TO_FP16(arr[3]);
2318
- }
2319
-
2320
- #define GGML_F32Cx4 __m128
2321
- #define GGML_F32Cx4_ZERO _mm_setzero_ps()
2322
- #define GGML_F32Cx4_SET1(x) _mm_set1_ps(x)
2323
- #define GGML_F32Cx4_LOAD(x) __sse_f16x4_load(x)
2324
- #define GGML_F32Cx4_STORE(x, y) __sse_f16x4_store(x, y)
2325
- #define GGML_F32Cx4_FMA GGML_F32x4_FMA
2326
- #define GGML_F32Cx4_ADD _mm_add_ps
2327
- #define GGML_F32Cx4_MUL _mm_mul_ps
2328
- #define GGML_F32Cx4_REDUCE GGML_F32x4_REDUCE
2329
-
2330
- #define GGML_F16_VEC GGML_F32Cx4
2331
- #define GGML_F16_VEC_ZERO GGML_F32Cx4_ZERO
2332
- #define GGML_F16_VEC_SET1 GGML_F32Cx4_SET1
2333
- #define GGML_F16_VEC_LOAD(p, i) GGML_F32Cx4_LOAD(p)
2334
- #define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx4_STORE(p, r[i])
2335
- #define GGML_F16_VEC_FMA GGML_F32Cx4_FMA
2336
- #define GGML_F16_VEC_ADD GGML_F32Cx4_ADD
2337
- #define GGML_F16_VEC_MUL GGML_F32Cx4_MUL
2338
- #define GGML_F16_VEC_REDUCE GGML_F32Cx4_REDUCE
2339
-
2340
- #endif
2341
-
2342
- // GGML_F32_ARR / GGML_F16_ARR
2343
- // number of registers to use per step
2344
- #ifdef GGML_SIMD
2345
- #define GGML_F32_ARR (GGML_F32_STEP/GGML_F32_EPR)
2346
- #define GGML_F16_ARR (GGML_F16_STEP/GGML_F16_EPR)
2347
- #endif
2348
-
2349
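Worked numbers for the architectures in this hunk may help when reading the dot products below: GGML_F32_ARR is simply STEP/EPR, i.e. how many independent vector accumulators each inner loop keeps to hide FMA latency.

    // GGML_F32_STEP / GGML_F32_EPR = GGML_F32_ARR accumulators per step, e.g.:
    //   NEON: 16 / 4 = 4     AVX: 32 / 8 = 4     SSE3: 32 / 4 = 8     WASM: 16 / 4 = 4
    enum {
        DEMO_AVX_F32_STEP = 32,
        DEMO_AVX_F32_EPR  = 8,
        DEMO_AVX_F32_ARR  = DEMO_AVX_F32_STEP / DEMO_AVX_F32_EPR   // = 4, the size of sum[] in ggml_vec_dot_f32 on AVX
    };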
- //
2350
- // fundamental operations
2351
- //
2352
-
2353
- inline static void ggml_vec_set_i8(const int n, int8_t * x, const int8_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
2354
-
2355
- inline static void ggml_vec_set_i16(const int n, int16_t * x, const int16_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
2356
-
2357
- inline static void ggml_vec_set_i32(const int n, int32_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
2358
-
2359
- inline static void ggml_vec_set_f16(const int n, ggml_fp16_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
2360
-
2361
- inline static void ggml_vec_add_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i] + y[i]; }
2362
- inline static void ggml_vec_add1_f32(const int n, float * z, const float * x, const float v) { for (int i = 0; i < n; ++i) z[i] = x[i] + v; }
2363
- inline static void ggml_vec_acc_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] += x[i]; }
2364
- inline static void ggml_vec_acc1_f32(const int n, float * y, const float v) { for (int i = 0; i < n; ++i) y[i] += v; }
2365
- inline static void ggml_vec_sub_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i] - y[i]; }
2366
- inline static void ggml_vec_set_f32 (const int n, float * x, const float v) { for (int i = 0; i < n; ++i) x[i] = v; }
2367
- inline static void ggml_vec_cpy_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]; }
2368
- inline static void ggml_vec_neg_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = -x[i]; }
2369
- inline static void ggml_vec_mul_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]*y[i]; }
2370
- inline static void ggml_vec_div_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]/y[i]; }
2371
-
2372
- static void ggml_vec_dot_f32(const int n, float * restrict s, const float * restrict x, const float * restrict y) {
2373
- #ifdef GGML_SIMD
2374
- float sumf = 0.0f;
2375
- const int np = (n & ~(GGML_F32_STEP - 1));
2376
-
2377
- GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
2378
-
2379
- GGML_F32_VEC ax[GGML_F32_ARR];
2380
- GGML_F32_VEC ay[GGML_F32_ARR];
2381
-
2382
- for (int i = 0; i < np; i += GGML_F32_STEP) {
2383
- for (int j = 0; j < GGML_F32_ARR; j++) {
2384
- ax[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR);
2385
- ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR);
2386
-
2387
- sum[j] = GGML_F32_VEC_FMA(sum[j], ax[j], ay[j]);
2388
- }
2389
- }
2390
-
2391
- // reduce sum0..sum3 to sum0
2392
- GGML_F32_VEC_REDUCE(sumf, sum);
2393
-
2394
- // leftovers
2395
- for (int i = np; i < n; ++i) {
2396
- sumf += x[i]*y[i];
2397
- }
2398
- #else
2399
- // scalar
2400
- ggml_float sumf = 0.0;
2401
- for (int i = 0; i < n; ++i) {
2402
- sumf += (ggml_float)(x[i]*y[i]);
2403
- }
2404
- #endif
2405
-
2406
- *s = sumf;
2407
- }
2408
-
2409
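ggml_vec_dot_f32 above splits n into a SIMD part and scalar leftovers with np = n & ~(GGML_F32_STEP - 1); the mask rounds n down to a multiple of STEP and is only valid because STEP is a power of two. A tiny standalone illustration:

    #include <stdio.h>

    // np = n & ~(STEP - 1) rounds n down to a multiple of STEP (STEP must be a power of two).
    int main(void) {
        const int STEP = 16;   // GGML_F32_STEP on the NEON/WASM paths above
        for (int n = 60; n <= 70; n += 5) {
            const int np = n & ~(STEP - 1);
            printf("n=%2d -> %2d elements in the SIMD loop, %d scalar leftovers\n", n, np, n - np);
        }
        return 0;
    }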
- static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t * restrict x, ggml_fp16_t * restrict y) {
2410
- ggml_float sumf = 0.0;
2411
-
2412
- #if defined(GGML_SIMD)
2413
- const int np = (n & ~(GGML_F16_STEP - 1));
2414
-
2415
- GGML_F16_VEC sum[GGML_F16_ARR] = { GGML_F16_VEC_ZERO };
2416
-
2417
- GGML_F16_VEC ax[GGML_F16_ARR];
2418
- GGML_F16_VEC ay[GGML_F16_ARR];
2419
-
2420
- for (int i = 0; i < np; i += GGML_F16_STEP) {
2421
- for (int j = 0; j < GGML_F16_ARR; j++) {
2422
- ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j);
2423
- ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
2424
-
2425
- sum[j] = GGML_F16_VEC_FMA(sum[j], ax[j], ay[j]);
2426
- }
2427
- }
2428
-
2429
- // reduce sum0..sum3 to sum0
2430
- GGML_F16_VEC_REDUCE(sumf, sum);
2431
-
2432
- // leftovers
2433
- for (int i = np; i < n; ++i) {
2434
- sumf += (ggml_float)(GGML_FP16_TO_FP32(x[i])*GGML_FP16_TO_FP32(y[i]));
2435
- }
2436
- #else
2437
- for (int i = 0; i < n; ++i) {
2438
- sumf += (ggml_float)(GGML_FP16_TO_FP32(x[i])*GGML_FP16_TO_FP32(y[i]));
2439
- }
2440
- #endif
2441
-
2442
- *s = sumf;
2443
- }
2444
-
2445
- static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
2446
- const int qk = QK8_0;
2447
- const int nb = n / qk;
2448
-
2449
- assert(n % qk == 0);
2450
-
2451
- const block_q4_0 * restrict x = vx;
2452
- const block_q8_0 * restrict y = vy;
2453
-
2454
- #if defined(__ARM_NEON)
2455
- float32x4_t sumv0 = vdupq_n_f32(0.0f);
2456
- float32x4_t sumv1 = vdupq_n_f32(0.0f);
2457
-
2458
- GGML_ASSERT(nb % 2 == 0); // TODO: handle odd nb
2459
- for (int i = 0; i < nb; i += 2) {
2460
- const block_q4_0 * restrict x0 = &x[i + 0];
2461
- const block_q4_0 * restrict x1 = &x[i + 1];
2462
- const block_q8_0 * restrict y0 = &y[i + 0];
2463
- const block_q8_0 * restrict y1 = &y[i + 1];
2464
-
2465
- const uint8x16_t m4b = vdupq_n_u8(0x0F);
2466
- const int8x16_t s8b = vdupq_n_s8(0x8);
2467
-
2468
- const uint8x16_t v0_0 = vld1q_u8(x0->qs);
2469
- const uint8x16_t v0_1 = vld1q_u8(x1->qs);
2470
-
2471
- // 4-bit -> 8-bit
2472
- const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
2473
- const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
2474
- const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
2475
- const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
2476
-
2477
- // sub 8
2478
- const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
2479
- const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
2480
- const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
2481
- const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
2482
-
2483
- // load y
2484
- const int8x16_t v1_0l = vld1q_s8(y0->qs);
2485
- const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
2486
- const int8x16_t v1_1l = vld1q_s8(y1->qs);
2487
- const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
2488
-
2489
- #if defined(__ARM_FEATURE_DOTPROD)
2490
- // dot product into int32x4_t
2491
- const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0l), v0_0hs, v1_0h);
2492
- const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1l), v0_1hs, v1_1h);
2493
-
2494
- sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
2495
- sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
2496
- #else
2497
- const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0l));
2498
- const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0l));
2499
- const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0h));
2500
- const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0h));
2501
-
2502
- const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1l));
2503
- const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1l));
2504
- const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1h));
2505
- const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1h));
2506
-
2507
- const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
2508
- const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
2509
- const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
2510
- const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
2511
-
2512
- sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
2513
- sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
2514
- #endif
2515
- }
2516
-
2517
- *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
2518
- #elif defined(__AVX2__)
2519
- // Initialize accumulator with zeros
2520
- __m256 acc = _mm256_setzero_ps();
2521
-
2522
- // Main loop
2523
- for (int i = 0; i < nb; ++i) {
2524
- /* Compute combined scale for the block */
2525
- const __m256 d = _mm256_set1_ps( GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d) );
2526
-
2527
- __m256i bx = bytes_from_nibbles_32(x[i].qs);
2528
-
2529
- // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
2530
- const __m256i off = _mm256_set1_epi8( 8 );
2531
- bx = _mm256_sub_epi8( bx, off );
2532
-
2533
- __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
2534
-
2535
- const __m256 q = mul_sum_i8_pairs_float(bx, by);
2536
-
2537
- /* Multiply q with scale and accumulate */
2538
- acc = _mm256_fmadd_ps( d, q, acc );
2539
- }
2540
-
2541
- *s = hsum_float_8(acc);
2542
- #elif defined(__AVX__)
2543
- // Initialize accumulator with zeros
2544
- __m256 acc = _mm256_setzero_ps();
2545
-
2546
- // Main loop
2547
- for (int i = 0; i < nb; ++i) {
2548
- // Compute combined scale for the block
2549
- const __m256 d = _mm256_set1_ps( GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d) );
2550
-
2551
- const __m128i lowMask = _mm_set1_epi8(0xF);
2552
- const __m128i off = _mm_set1_epi8(8);
2553
-
2554
- const __m128i tmp = _mm_loadu_si128((const __m128i *)x[i].qs);
2555
-
2556
- __m128i bx = _mm_and_si128(lowMask, tmp);
2557
- __m128i by = _mm_loadu_si128((const __m128i *)y[i].qs);
2558
- bx = _mm_sub_epi8(bx, off);
2559
- const __m128i i32_0 = mul_sum_i8_pairs(bx, by);
2560
-
2561
- bx = _mm_and_si128(lowMask, _mm_srli_epi64(tmp, 4));
2562
- by = _mm_loadu_si128((const __m128i *)(y[i].qs + 16));
2563
- bx = _mm_sub_epi8(bx, off);
2564
- const __m128i i32_1 = mul_sum_i8_pairs(bx, by);
2565
-
2566
- // Convert int32_t to float
2567
- __m256 p = _mm256_cvtepi32_ps(MM256_SET_M128I(i32_0, i32_1));
2568
-
2569
- // Apply the scale, and accumulate
2570
- acc = _mm256_add_ps(_mm256_mul_ps( d, p ), acc);
2571
- }
2572
-
2573
- *s = hsum_float_8(acc);
2574
- #elif defined(__SSSE3__)
2575
- // set constants
2576
- const __m128i lowMask = _mm_set1_epi8(0xF);
2577
- const __m128i off = _mm_set1_epi8(8);
2578
-
2579
- // Initialize accumulator with zeros
2580
- __m128 acc_0 = _mm_setzero_ps();
2581
- __m128 acc_1 = _mm_setzero_ps();
2582
- __m128 acc_2 = _mm_setzero_ps();
2583
- __m128 acc_3 = _mm_setzero_ps();
2584
-
2585
- // First round without accumulation
2586
- {
2587
- _mm_prefetch(&x[0] + sizeof(block_q4_0), _MM_HINT_T0);
2588
- _mm_prefetch(&y[0] + sizeof(block_q8_0), _MM_HINT_T0);
2589
-
2590
- // Compute combined scale for the block 0 and 1
2591
- const __m128 d_0_1 = _mm_set1_ps( GGML_FP16_TO_FP32(x[0].d) * GGML_FP16_TO_FP32(y[0].d) );
2592
-
2593
- const __m128i tmp_0_1 = _mm_loadu_si128((const __m128i *)x[0].qs);
2594
-
2595
- __m128i bx_0 = _mm_and_si128(lowMask, tmp_0_1);
2596
- __m128i by_0 = _mm_loadu_si128((const __m128i *)y[0].qs);
2597
- bx_0 = _mm_sub_epi8(bx_0, off);
2598
- const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0);
2599
-
2600
- __m128i bx_1 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_0_1, 4));
2601
- __m128i by_1 = _mm_loadu_si128((const __m128i *)(y[0].qs + 16));
2602
- bx_1 = _mm_sub_epi8(bx_1, off);
2603
- const __m128i i32_1 = mul_sum_i8_pairs(bx_1, by_1);
2604
-
2605
- _mm_prefetch(&x[1] + sizeof(block_q4_0), _MM_HINT_T0);
2606
- _mm_prefetch(&y[1] + sizeof(block_q8_0), _MM_HINT_T0);
2607
-
2608
- // Compute combined scale for the block 2 and 3
2609
- const __m128 d_2_3 = _mm_set1_ps( GGML_FP16_TO_FP32(x[1].d) * GGML_FP16_TO_FP32(y[1].d) );
2610
-
2611
- const __m128i tmp_2_3 = _mm_loadu_si128((const __m128i *)x[1].qs);
2612
-
2613
- __m128i bx_2 = _mm_and_si128(lowMask, tmp_2_3);
2614
- __m128i by_2 = _mm_loadu_si128((const __m128i *)y[1].qs);
2615
- bx_2 = _mm_sub_epi8(bx_2, off);
2616
- const __m128i i32_2 = mul_sum_i8_pairs(bx_2, by_2);
2617
-
2618
- __m128i bx_3 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_2_3, 4));
2619
- __m128i by_3 = _mm_loadu_si128((const __m128i *)(y[1].qs + 16));
2620
- bx_3 = _mm_sub_epi8(bx_3, off);
2621
- const __m128i i32_3 = mul_sum_i8_pairs(bx_3, by_3);
2622
-
2623
- // Convert int32_t to float
2624
- __m128 p0 = _mm_cvtepi32_ps(i32_0);
2625
- __m128 p1 = _mm_cvtepi32_ps(i32_1);
2626
- __m128 p2 = _mm_cvtepi32_ps(i32_2);
2627
- __m128 p3 = _mm_cvtepi32_ps(i32_3);
2628
-
2629
- // Apply the scale
2630
- acc_0 = _mm_mul_ps( d_0_1, p0 );
2631
- acc_1 = _mm_mul_ps( d_0_1, p1 );
2632
- acc_2 = _mm_mul_ps( d_2_3, p2 );
2633
- acc_3 = _mm_mul_ps( d_2_3, p3 );
2634
- }
2635
-
2636
- // Main loop
2637
- GGML_ASSERT(nb % 2 == 0); // TODO: handle odd nb
2638
- for (int i = 2; i < nb; i+=2) {
2639
- _mm_prefetch(&x[i] + sizeof(block_q4_0), _MM_HINT_T0);
2640
- _mm_prefetch(&y[i] + sizeof(block_q8_0), _MM_HINT_T0);
2641
-
2642
- // Compute combined scale for the block 0 and 1
2643
- const __m128 d_0_1 = _mm_set1_ps( GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d) );
2644
-
2645
- const __m128i tmp_0_1 = _mm_loadu_si128((const __m128i *)x[i].qs);
2646
-
2647
- __m128i bx_0 = _mm_and_si128(lowMask, tmp_0_1);
2648
- __m128i by_0 = _mm_loadu_si128((const __m128i *)y[i].qs);
2649
- bx_0 = _mm_sub_epi8(bx_0, off);
2650
- const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0);
2651
-
2652
- __m128i bx_1 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_0_1, 4));
2653
- __m128i by_1 = _mm_loadu_si128((const __m128i *)(y[i].qs + 16));
2654
- bx_1 = _mm_sub_epi8(bx_1, off);
2655
- const __m128i i32_1 = mul_sum_i8_pairs(bx_1, by_1);
2656
-
2657
- _mm_prefetch(&x[i] + 2 * sizeof(block_q4_0), _MM_HINT_T0);
2658
- _mm_prefetch(&y[i] + 2 * sizeof(block_q8_0), _MM_HINT_T0);
2659
-
2660
- // Compute combined scale for the block 2 and 3
2661
- const __m128 d_2_3 = _mm_set1_ps( GGML_FP16_TO_FP32(x[i + 1].d) * GGML_FP16_TO_FP32(y[i + 1].d) );
2662
-
2663
- const __m128i tmp_2_3 = _mm_loadu_si128((const __m128i *)x[i + 1].qs);
2664
-
2665
- __m128i bx_2 = _mm_and_si128(lowMask, tmp_2_3);
2666
- __m128i by_2 = _mm_loadu_si128((const __m128i *)y[i + 1].qs);
2667
- bx_2 = _mm_sub_epi8(bx_2, off);
2668
- const __m128i i32_2 = mul_sum_i8_pairs(bx_2, by_2);
2669
-
2670
- __m128i bx_3 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_2_3, 4));
2671
- __m128i by_3 = _mm_loadu_si128((const __m128i *)(y[i + 1].qs + 16));
2672
- bx_3 = _mm_sub_epi8(bx_3, off);
2673
- const __m128i i32_3 = mul_sum_i8_pairs(bx_3, by_3);
2674
-
2675
- // Convert int32_t to float
2676
- __m128 p0 = _mm_cvtepi32_ps(i32_0);
2677
- __m128 p1 = _mm_cvtepi32_ps(i32_1);
2678
- __m128 p2 = _mm_cvtepi32_ps(i32_2);
2679
- __m128 p3 = _mm_cvtepi32_ps(i32_3);
2680
-
2681
- // Apply the scale
2682
- __m128 p0_d = _mm_mul_ps( d_0_1, p0 );
2683
- __m128 p1_d = _mm_mul_ps( d_0_1, p1 );
2684
- __m128 p2_d = _mm_mul_ps( d_2_3, p2 );
2685
- __m128 p3_d = _mm_mul_ps( d_2_3, p3 );
2686
-
2687
- // Accumulate
2688
- acc_0 = _mm_add_ps(p0_d, acc_0);
2689
- acc_1 = _mm_add_ps(p1_d, acc_1);
2690
- acc_2 = _mm_add_ps(p2_d, acc_2);
2691
- acc_3 = _mm_add_ps(p3_d, acc_3);
2692
- }
2693
-
2694
- *s = hsum_float_4x4(acc_0, acc_1, acc_2, acc_3);
2695
- #elif defined(__riscv_v_intrinsic)
2696
- float sumf = 0.0;
2697
-
2698
- size_t vl = __riscv_vsetvl_e8m1(qk/2);
2699
-
2700
- for (int i = 0; i < nb; i++) {
2701
- // load elements
2702
- vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[i].qs, vl);
2703
-
2704
- vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[i].qs, vl);
2705
- vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[i].qs+16, vl);
2706
-
2707
- // mask and store lower part of x, and then upper part
2708
- vuint8mf2_t x_a = __riscv_vand_vx_u8mf2(tx, 0x0F, vl);
2709
- vuint8mf2_t x_l = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl);
2710
-
2711
- vint8mf2_t x_ai = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a);
2712
- vint8mf2_t x_li = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l);
2713
-
2714
- // subtract offset
2715
- vint8mf2_t v0 = __riscv_vsub_vx_i8mf2(x_ai, 8, vl);
2716
- vint8mf2_t v1 = __riscv_vsub_vx_i8mf2(x_li, 8, vl);
2717
-
2718
- vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl);
2719
- vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl);
2720
-
2721
- vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
2722
-
2723
- vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl);
2724
- vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl);
2725
-
2726
- int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);
2727
-
2728
- sumf += sumi*GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d);
2729
- }
560
+ //
561
+ // simd mappings
562
+ //
2730
563
 
2731
- *s = sumf;
2732
- #else
2733
- // scalar
2734
- float sumf = 0.0;
564
+ // we define a common set of C macros which map to specific intrinsics based on the current architecture
565
+ // we then implement the fundamental computation operations below using only these macros
566
+ // adding support for new architectures requires defining the corresponding SIMD macros
567
+ //
568
+ // GGML_F32_STEP / GGML_F16_STEP
569
+ // number of elements to process in a single step
570
+ //
571
+ // GGML_F32_EPR / GGML_F16_EPR
572
+ // number of elements to fit in a single register
573
+ //
2735
574
 
2736
- for (int i = 0; i < nb; i++) {
2737
- int sumi = 0;
575
+ #if defined(__ARM_NEON) && defined(__ARM_FEATURE_FMA)
2738
576
 
2739
- for (int j = 0; j < qk/2; ++j) {
2740
- const int v0 = (x[i].qs[j] & 0x0F) - 8;
2741
- const int v1 = (x[i].qs[j] >> 4) - 8;
577
+ #define GGML_SIMD
2742
578
 
2743
- sumi += (v0 * y[i].qs[j]) + (v1 * y[i].qs[j + qk/2]);
2744
- }
579
+ // F32 NEON
2745
580
 
2746
- sumf += sumi*GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d);
2747
- }
581
+ #define GGML_F32_STEP 16
582
+ #define GGML_F32_EPR 4
2748
583
 
2749
- *s = sumf;
2750
- #endif
584
+ #define GGML_F32x4 float32x4_t
585
+ #define GGML_F32x4_ZERO vdupq_n_f32(0.0f)
586
+ #define GGML_F32x4_SET1(x) vdupq_n_f32(x)
587
+ #define GGML_F32x4_LOAD vld1q_f32
588
+ #define GGML_F32x4_STORE vst1q_f32
589
+ #define GGML_F32x4_FMA(a, b, c) vfmaq_f32(a, b, c)
590
+ #define GGML_F32x4_ADD vaddq_f32
591
+ #define GGML_F32x4_MUL vmulq_f32
592
+ #define GGML_F32x4_REDUCE_ONE(x) vaddvq_f32(x)
593
+ #define GGML_F32x4_REDUCE(res, x) \
594
+ { \
595
+ int offset = GGML_F32_ARR >> 1; \
596
+ for (int i = 0; i < offset; ++i) { \
597
+ x[i] = vaddq_f32(x[i], x[offset+i]); \
598
+ } \
599
+ offset >>= 1; \
600
+ for (int i = 0; i < offset; ++i) { \
601
+ x[i] = vaddq_f32(x[i], x[offset+i]); \
602
+ } \
603
+ offset >>= 1; \
604
+ for (int i = 0; i < offset; ++i) { \
605
+ x[i] = vaddq_f32(x[i], x[offset+i]); \
606
+ } \
607
+ res = GGML_F32x4_REDUCE_ONE(x[0]); \
2751
608
  }
2752
609
 
2753
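All of the branches of ggml_vec_dot_q4_0_q8_0 above exploit the same factorization: within a block, x_j = d_x*(q_j - 8) and y_j = d_y*r_j, so the block's contribution is d_x*d_y times an integer dot product. A self-contained per-block sketch (floats in place of the fp16 scales, demo_* name illustrative):

    #include <stdint.h>

    // One q4_0 block dotted with one q8_0 block:
    //   x_j = d_x * (q_j - 8),  y_j = d_y * r_j
    //   sum_j x_j*y_j = d_x*d_y * sum_j (q_j - 8)*r_j
    // so every branch accumulates an integer dot product and applies the scales once per block.
    static float demo_dot_q4_0_q8_0_block(float dx, const uint8_t qs[16],
                                          float dy, const int8_t  r[32]) {
        int sumi = 0;
        for (int j = 0; j < 16; ++j) {
            const int q_lo = (qs[j] & 0x0F) - 8;   // element j
            const int q_hi = (qs[j] >> 4)   - 8;   // element j + 16
            sumi += q_lo * r[j] + q_hi * r[j + 16];
        }
        return dx * dy * (float) sumi;
    }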
- static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
2754
- const int qk = QK8_1;
2755
- const int nb = n / qk;
2756
-
2757
- assert(n % qk == 0);
2758
-
2759
- const block_q4_1 * restrict x = vx;
2760
- const block_q8_1 * restrict y = vy;
2761
-
2762
- // TODO: add WASM SIMD
2763
- #if defined(__ARM_NEON)
2764
- float32x4_t sumv0 = vdupq_n_f32(0.0f);
2765
- float32x4_t sumv1 = vdupq_n_f32(0.0f);
2766
-
2767
- float summs = 0;
2768
-
2769
- GGML_ASSERT(nb % 2 == 0); // TODO: handle odd nb
2770
- for (int i = 0; i < nb; i += 2) {
2771
- const block_q4_1 * restrict x0 = &x[i + 0];
2772
- const block_q4_1 * restrict x1 = &x[i + 1];
2773
- const block_q8_1 * restrict y0 = &y[i + 0];
2774
- const block_q8_1 * restrict y1 = &y[i + 1];
2775
-
2776
- summs += GGML_FP16_TO_FP32(x0->m) * y0->s + GGML_FP16_TO_FP32(x1->m) * y1->s;
2777
-
2778
- const uint8x16_t m4b = vdupq_n_u8(0x0F);
2779
-
2780
- const uint8x16_t v0_0 = vld1q_u8(x0->qs);
2781
- const uint8x16_t v0_1 = vld1q_u8(x1->qs);
610
+ #define GGML_F32_VEC GGML_F32x4
611
+ #define GGML_F32_VEC_ZERO GGML_F32x4_ZERO
612
+ #define GGML_F32_VEC_SET1 GGML_F32x4_SET1
613
+ #define GGML_F32_VEC_LOAD GGML_F32x4_LOAD
614
+ #define GGML_F32_VEC_STORE GGML_F32x4_STORE
615
+ #define GGML_F32_VEC_FMA GGML_F32x4_FMA
616
+ #define GGML_F32_VEC_ADD GGML_F32x4_ADD
617
+ #define GGML_F32_VEC_MUL GGML_F32x4_MUL
618
+ #define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE
2782
619
 
2783
- // 4-bit -> 8-bit
2784
- const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
2785
- const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
2786
- const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
2787
- const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
620
+ // F16 NEON
2788
621
 
2789
- // load y
2790
- const int8x16_t v1_0l = vld1q_s8(y0->qs);
2791
- const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
2792
- const int8x16_t v1_1l = vld1q_s8(y1->qs);
2793
- const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
622
+ #if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
623
+ #define GGML_F16_STEP 32
624
+ #define GGML_F16_EPR 8
2794
625
 
2795
- #if defined(__ARM_FEATURE_DOTPROD)
2796
- // dot product into int32x4_t
2797
- const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0l), v0_0h, v1_0h);
2798
- const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1l), v0_1h, v1_1h);
626
+ #define GGML_F16x8 float16x8_t
627
+ #define GGML_F16x8_ZERO vdupq_n_f16(0.0f)
628
+ #define GGML_F16x8_SET1(x) vdupq_n_f16(x)
629
+ #define GGML_F16x8_LOAD vld1q_f16
630
+ #define GGML_F16x8_STORE vst1q_f16
631
+ #define GGML_F16x8_FMA(a, b, c) vfmaq_f16(a, b, c)
632
+ #define GGML_F16x8_ADD vaddq_f16
633
+ #define GGML_F16x8_MUL vmulq_f16
634
+ #define GGML_F16x8_REDUCE(res, x) \
635
+ do { \
636
+ int offset = GGML_F16_ARR >> 1; \
637
+ for (int i = 0; i < offset; ++i) { \
638
+ x[i] = vaddq_f16(x[i], x[offset+i]); \
639
+ } \
640
+ offset >>= 1; \
641
+ for (int i = 0; i < offset; ++i) { \
642
+ x[i] = vaddq_f16(x[i], x[offset+i]); \
643
+ } \
644
+ offset >>= 1; \
645
+ for (int i = 0; i < offset; ++i) { \
646
+ x[i] = vaddq_f16(x[i], x[offset+i]); \
647
+ } \
648
+ const float32x4_t t0 = vcvt_f32_f16(vget_low_f16 (x[0])); \
649
+ const float32x4_t t1 = vcvt_f32_f16(vget_high_f16(x[0])); \
650
+ res = (ggml_float) vaddvq_f32(vaddq_f32(t0, t1)); \
651
+ } while (0)
2799
652
 
2800
- sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), GGML_FP16_TO_FP32(x0->d)*y0->d);
2801
- sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), GGML_FP16_TO_FP32(x1->d)*y1->d);
653
+ #define GGML_F16_VEC GGML_F16x8
654
+ #define GGML_F16_VEC_ZERO GGML_F16x8_ZERO
655
+ #define GGML_F16_VEC_SET1 GGML_F16x8_SET1
656
+ #define GGML_F16_VEC_LOAD(p, i) GGML_F16x8_LOAD(p)
657
+ #define GGML_F16_VEC_STORE(p, r, i) GGML_F16x8_STORE(p, r[i])
658
+ #define GGML_F16_VEC_FMA GGML_F16x8_FMA
659
+ #define GGML_F16_VEC_ADD GGML_F16x8_ADD
660
+ #define GGML_F16_VEC_MUL GGML_F16x8_MUL
661
+ #define GGML_F16_VEC_REDUCE GGML_F16x8_REDUCE
2802
662
  #else
2803
- const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0l), vget_low_s8 (v1_0l));
2804
- const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0l), vget_high_s8(v1_0l));
2805
- const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0h), vget_low_s8 (v1_0h));
2806
- const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0h), vget_high_s8(v1_0h));
2807
-
2808
- const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1l), vget_low_s8 (v1_1l));
2809
- const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1l), vget_high_s8(v1_1l));
2810
- const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1h), vget_low_s8 (v1_1h));
2811
- const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1h), vget_high_s8(v1_1h));
2812
-
2813
- const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
2814
- const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
2815
- const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
2816
- const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
2817
-
2818
- sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), GGML_FP16_TO_FP32(x0->d)*y0->d);
2819
- sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), GGML_FP16_TO_FP32(x1->d)*y1->d);
2820
- #endif
2821
- }
2822
-
2823
- *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs;
2824
- #elif defined(__AVX2__) || defined(__AVX__)
2825
- // Initialize accumulator with zeros
2826
- __m256 acc = _mm256_setzero_ps();
2827
-
2828
- float summs = 0;
2829
-
2830
- // Main loop
2831
- for (int i = 0; i < nb; ++i) {
2832
- const float d0 = GGML_FP16_TO_FP32(x[i].d);
2833
- const float d1 = y[i].d;
2834
-
2835
- summs += GGML_FP16_TO_FP32(x[i].m) * y[i].s;
2836
-
2837
- const __m256 d0v = _mm256_set1_ps( d0 );
2838
- const __m256 d1v = _mm256_set1_ps( d1 );
2839
-
2840
- // Compute combined scales
2841
- const __m256 d0d1 = _mm256_mul_ps( d0v, d1v );
663
+ // if FP16 vector arithmetic is not supported, we use FP32 instead
664
+ // and take advantage of the vcvt_ functions to convert to/from FP16
2842
665
 
2843
- // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
2844
- const __m256i bx = bytes_from_nibbles_32(x[i].qs);
2845
- const __m256i by = _mm256_loadu_si256( (const __m256i *)y[i].qs );
666
+ #define GGML_F16_STEP 16
667
+ #define GGML_F16_EPR 4
2846
668
 
2847
- const __m256 xy = mul_sum_us8_pairs_float(bx, by);
669
+ #define GGML_F32Cx4 float32x4_t
670
+ #define GGML_F32Cx4_ZERO vdupq_n_f32(0.0f)
671
+ #define GGML_F32Cx4_SET1(x) vdupq_n_f32(x)
672
+ #define GGML_F32Cx4_LOAD(x) vcvt_f32_f16(vld1_f16(x))
673
+ #define GGML_F32Cx4_STORE(x, y) vst1_f16(x, vcvt_f16_f32(y))
674
+ #define GGML_F32Cx4_FMA(a, b, c) vfmaq_f32(a, b, c)
675
+ #define GGML_F32Cx4_ADD vaddq_f32
676
+ #define GGML_F32Cx4_MUL vmulq_f32
677
+ #define GGML_F32Cx4_REDUCE GGML_F32x4_REDUCE
2848
678
 
2849
- // Accumulate d0*d1*x*y
2850
- #if defined(__AVX2__)
2851
- acc = _mm256_fmadd_ps( d0d1, xy, acc );
2852
- #else
2853
- acc = _mm256_add_ps( _mm256_mul_ps( d0d1, xy ), acc );
679
+ #define GGML_F16_VEC GGML_F32Cx4
680
+ #define GGML_F16_VEC_ZERO GGML_F32Cx4_ZERO
681
+ #define GGML_F16_VEC_SET1 GGML_F32Cx4_SET1
682
+ #define GGML_F16_VEC_LOAD(p, i) GGML_F32Cx4_LOAD(p)
683
+ #define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx4_STORE(p, r[i])
684
+ #define GGML_F16_VEC_FMA GGML_F32Cx4_FMA
685
+ #define GGML_F16_VEC_ADD GGML_F32Cx4_ADD
686
+ #define GGML_F16_VEC_MUL GGML_F32Cx4_MUL
687
+ #define GGML_F16_VEC_REDUCE GGML_F32Cx4_REDUCE
2854
688
  #endif
2855
- }
2856
-
2857
- *s = hsum_float_8(acc) + summs;
2858
- #elif defined(__riscv_v_intrinsic)
2859
- float sumf = 0.0;
2860
-
2861
- size_t vl = __riscv_vsetvl_e8m1(qk/2);
2862
-
2863
- for (int i = 0; i < nb; i++) {
2864
- // load elements
2865
- vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[i].qs, vl);
2866
-
2867
- vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[i].qs, vl);
2868
- vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[i].qs+16, vl);
2869
-
2870
- // mask and store lower part of x, and then upper part
2871
- vuint8mf2_t x_a = __riscv_vand_vx_u8mf2(tx, 0x0F, vl);
2872
- vuint8mf2_t x_l = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl);
2873
-
2874
- vint8mf2_t v0 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a);
2875
- vint8mf2_t v1 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l);
2876
-
2877
- vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl);
2878
- vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl);
2879
-
2880
- vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
2881
-
2882
- vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl);
2883
- vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl);
2884
-
2885
- int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);
2886
-
2887
- sumf += (GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi + GGML_FP16_TO_FP32(x[i].m)*y[i].s;
2888
- }
2889
-
2890
- *s = sumf;
2891
- #else
2892
- // scalar
2893
- float sumf = 0.0;
2894
-
2895
- for (int i = 0; i < nb; i++) {
2896
- int sumi = 0;
2897
-
2898
- for (int j = 0; j < qk/2; ++j) {
2899
- const int v0 = (x[i].qs[j] & 0x0F);
2900
- const int v1 = (x[i].qs[j] >> 4);
2901
-
2902
- sumi += (v0 * y[i].qs[j]) + (v1 * y[i].qs[j + qk/2]);
2903
- }
2904
-
2905
- sumf += (GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi + GGML_FP16_TO_FP32(x[i].m)*y[i].s;
2906
- }
2907
689
 
2908
- *s = sumf;
2909
- #endif
2910
- }
690
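The q4_1 path above needs q8_1 rather than q8_0 because of the minimum m: with x_j = d*q_j + m and y_j = d_y*r_j, each block contributes d*d_y*Σ q_j r_j plus m*Σ y_j, and q8_1 precomputes Σ y_j as its s field (the summs term accumulated above). A per-block sketch under the same simplifications as before:

    #include <stdint.h>

    // One q4_1 block dotted with one q8_1 block:
    //   x_j = d*q_j + m,  y_j = dy*r_j,  s = dy * sum_j r_j
    //   sum_j x_j*y_j = d*dy * sum_j q_j*r_j + m*s
    // which is why the minimum costs only one extra multiply-add per block.
    static float demo_dot_q4_1_q8_1_block(float d, float m, const uint8_t qs[16],
                                          float dy, float s, const int8_t r[32]) {
        int sumi = 0;
        for (int j = 0; j < 16; ++j) {
            sumi += (qs[j] & 0x0F) * r[j] + (qs[j] >> 4) * r[j + 16];
        }
        return d * dy * (float) sumi + m * s;
    }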
+ #elif defined(__AVX__)
2911
691
 
2912
- static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
2913
- const int qk = QK8_0;
2914
- const int nb = n / qk;
692
+ #define GGML_SIMD
2915
693
 
2916
- assert(n % qk == 0);
2917
- assert(qk == QK5_0);
694
+ // F32 AVX
2918
695
 
2919
- const block_q5_0 * restrict x = vx;
2920
- const block_q8_0 * restrict y = vy;
696
+ #define GGML_F32_STEP 32
697
+ #define GGML_F32_EPR 8
2921
698
 
2922
- #if defined(__ARM_NEON)
2923
- float32x4_t sumv0 = vdupq_n_f32(0.0f);
2924
- float32x4_t sumv1 = vdupq_n_f32(0.0f);
2925
-
2926
- uint32_t qh0;
2927
- uint32_t qh1;
2928
-
2929
- uint64_t tmp0[4];
2930
- uint64_t tmp1[4];
2931
-
2932
- GGML_ASSERT(nb % 2 == 0); // TODO: handle odd nb
2933
- for (int i = 0; i < nb; i += 2) {
2934
- const block_q5_0 * restrict x0 = &x[i];
2935
- const block_q5_0 * restrict x1 = &x[i + 1];
2936
- const block_q8_0 * restrict y0 = &y[i];
2937
- const block_q8_0 * restrict y1 = &y[i + 1];
2938
-
2939
- const uint8x16_t m4b = vdupq_n_u8(0x0F);
2940
-
2941
- // extract the 5th bit via lookup table ((!b) << 4)
2942
- memcpy(&qh0, x0->qh, sizeof(qh0));
2943
- memcpy(&qh1, x1->qh, sizeof(qh1));
2944
-
2945
- tmp0[0] = table_b2b_1[(qh0 >> 0) & 0xFF];
2946
- tmp0[1] = table_b2b_1[(qh0 >> 8) & 0xFF];
2947
- tmp0[2] = table_b2b_1[(qh0 >> 16) & 0xFF];
2948
- tmp0[3] = table_b2b_1[(qh0 >> 24) ];
2949
-
2950
- tmp1[0] = table_b2b_1[(qh1 >> 0) & 0xFF];
2951
- tmp1[1] = table_b2b_1[(qh1 >> 8) & 0xFF];
2952
- tmp1[2] = table_b2b_1[(qh1 >> 16) & 0xFF];
2953
- tmp1[3] = table_b2b_1[(qh1 >> 24) ];
2954
-
2955
- const int8x16_t qhl0 = vld1q_s8((const int8_t *)(tmp0 + 0));
2956
- const int8x16_t qhh0 = vld1q_s8((const int8_t *)(tmp0 + 2));
2957
- const int8x16_t qhl1 = vld1q_s8((const int8_t *)(tmp1 + 0));
2958
- const int8x16_t qhh1 = vld1q_s8((const int8_t *)(tmp1 + 2));
2959
-
2960
- const uint8x16_t v0_0 = vld1q_u8(x0->qs);
2961
- const uint8x16_t v0_1 = vld1q_u8(x1->qs);
2962
-
2963
- // 4-bit -> 8-bit
2964
- int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
2965
- int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
2966
- int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
2967
- int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
2968
-
2969
- // add high bit and sub 16 (equivalent to sub 0x10 when bit is zero)
2970
- const int8x16_t v0_0lf = vsubq_s8(v0_0l, qhl0);
2971
- const int8x16_t v0_0hf = vsubq_s8(v0_0h, qhh0);
2972
- const int8x16_t v0_1lf = vsubq_s8(v0_1l, qhl1);
2973
- const int8x16_t v0_1hf = vsubq_s8(v0_1h, qhh1);
2974
-
2975
- // load y
2976
- const int8x16_t v1_0l = vld1q_s8(y0->qs);
2977
- const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
2978
- const int8x16_t v1_1l = vld1q_s8(y1->qs);
2979
- const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
2980
-
2981
- #if defined(__ARM_FEATURE_DOTPROD)
2982
- sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
2983
- vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l),
2984
- vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
2985
- sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
2986
- vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l),
2987
- vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
699
+ #define GGML_F32x8 __m256
700
+ #define GGML_F32x8_ZERO _mm256_setzero_ps()
701
+ #define GGML_F32x8_SET1(x) _mm256_set1_ps(x)
702
+ #define GGML_F32x8_LOAD _mm256_loadu_ps
703
+ #define GGML_F32x8_STORE _mm256_storeu_ps
704
+ #if defined(__FMA__)
705
+ #define GGML_F32x8_FMA(a, b, c) _mm256_fmadd_ps(b, c, a)
2988
706
  #else
2989
- const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lf), vget_low_s8 (v1_0l));
2990
- const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lf), vget_high_s8(v1_0l));
2991
- const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hf), vget_low_s8 (v1_0h));
2992
- const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hf), vget_high_s8(v1_0h));
2993
-
2994
- const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lf), vget_low_s8 (v1_1l));
2995
- const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lf), vget_high_s8(v1_1l));
2996
- const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hf), vget_low_s8 (v1_1h));
2997
- const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hf), vget_high_s8(v1_1h));
2998
-
2999
- const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
3000
- const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
3001
- const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
3002
- const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
3003
-
3004
- sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
3005
- sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
707
+ #define GGML_F32x8_FMA(a, b, c) _mm256_add_ps(_mm256_mul_ps(b, c), a)
3006
708
  #endif
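Both branches of GGML_F32x8_FMA above compute a + b*c on eight packed floats; the __FMA__ build simply does it in one fused instruction (a single rounding step instead of two). A minimal standalone sketch, with illustrative values and compile flags that are not taken from this diff:

// fma_demo.c -- illustrative only; build with e.g. `gcc -mavx -mfma fma_demo.c`
#include <immintrin.h>
#include <stdio.h>

int main(void) {
    __m256 a = _mm256_set1_ps(1.0f);
    __m256 b = _mm256_set1_ps(2.0f);
    __m256 c = _mm256_set1_ps(3.0f);

    __m256 fused = _mm256_fmadd_ps(b, c, a);              // a + b*c, fused
    __m256 plain = _mm256_add_ps(_mm256_mul_ps(b, c), a); // fallback, same value (up to rounding)

    float f[8], p[8];
    _mm256_storeu_ps(f, fused);
    _mm256_storeu_ps(p, plain);
    printf("fused: %g  plain: %g\n", f[0], p[0]);         // both print 7
    return 0;
}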
3007
- }
3008
-
3009
- *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
3010
- #elif defined(__wasm_simd128__)
3011
- v128_t sumv = wasm_f32x4_splat(0.0f);
3012
-
3013
- uint32_t qh;
3014
- uint64_t tmp[4];
3015
-
3016
- // TODO: check if unrolling this is better
3017
- for (int i = 0; i < nb; ++i) {
3018
- const block_q5_0 * restrict x0 = &x[i];
3019
- const block_q8_0 * restrict y0 = &y[i];
3020
-
3021
- const v128_t m4b = wasm_i8x16_splat(0x0F);
3022
-
3023
- // extract the 5th bit
3024
- memcpy(&qh, x0->qh, sizeof(qh));
3025
-
3026
- tmp[0] = table_b2b_1[(qh >> 0) & 0xFF];
3027
- tmp[1] = table_b2b_1[(qh >> 8) & 0xFF];
3028
- tmp[2] = table_b2b_1[(qh >> 16) & 0xFF];
3029
- tmp[3] = table_b2b_1[(qh >> 24) ];
3030
-
3031
- const v128_t qhl = wasm_v128_load(tmp + 0);
3032
- const v128_t qhh = wasm_v128_load(tmp + 2);
3033
-
3034
- const v128_t v0 = wasm_v128_load(x0->qs);
3035
-
3036
- // 4-bit -> 8-bit
3037
- const v128_t v0l = wasm_v128_and (v0, m4b);
3038
- const v128_t v0h = wasm_u8x16_shr(v0, 4);
3039
-
3040
- // add high bit and sub 16 (equivalent to sub 0x10 when bit is zero)
3041
- const v128_t v0lf = wasm_i8x16_sub(v0l, qhl);
3042
- const v128_t v0hf = wasm_i8x16_sub(v0h, qhh);
3043
-
3044
- // load y
3045
- const v128_t v1l = wasm_v128_load(y0->qs);
3046
- const v128_t v1h = wasm_v128_load(y0->qs + 16);
3047
-
3048
- // int8x16 -> int16x8
3049
- const v128_t v0lfl = wasm_i16x8_extend_low_i8x16 (v0lf);
3050
- const v128_t v0lfh = wasm_i16x8_extend_high_i8x16(v0lf);
3051
- const v128_t v0hfl = wasm_i16x8_extend_low_i8x16 (v0hf);
3052
- const v128_t v0hfh = wasm_i16x8_extend_high_i8x16(v0hf);
3053
-
3054
- const v128_t v1ll = wasm_i16x8_extend_low_i8x16 (v1l);
3055
- const v128_t v1lh = wasm_i16x8_extend_high_i8x16(v1l);
3056
- const v128_t v1hl = wasm_i16x8_extend_low_i8x16 (v1h);
3057
- const v128_t v1hh = wasm_i16x8_extend_high_i8x16(v1h);
3058
-
3059
- // dot product
3060
- sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4(
3061
- wasm_i32x4_add(
3062
- wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0lfl, v1ll),
3063
- wasm_i32x4_dot_i16x8(v0lfh, v1lh)),
3064
- wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0hfl, v1hl),
3065
- wasm_i32x4_dot_i16x8(v0hfh, v1hh)))),
3066
- wasm_f32x4_splat(GGML_FP16_TO_FP32(x0->d) * GGML_FP16_TO_FP32(y0->d))));
3067
- }
3068
-
3069
- *s = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) +
3070
- wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3);
3071
- #elif defined(__AVX2__)
3072
- // Initialize accumulator with zeros
3073
- __m256 acc = _mm256_setzero_ps();
3074
-
3075
- // Main loop
3076
- for (int i = 0; i < nb; i++) {
3077
- /* Compute combined scale for the block */
3078
- const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d));
3079
-
3080
- __m256i bx = bytes_from_nibbles_32(x[i].qs);
3081
- __m256i bxhi = bytes_from_bits_32(x[i].qh);
3082
- bxhi = _mm256_andnot_si256(bxhi, _mm256_set1_epi8((char)0xF0));
3083
- bx = _mm256_or_si256(bx, bxhi);
3084
-
3085
- __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
3086
-
3087
- const __m256 q = mul_sum_i8_pairs_float(bx, by);
709
+ #define GGML_F32x8_ADD _mm256_add_ps
710
+ #define GGML_F32x8_MUL _mm256_mul_ps
711
+ #define GGML_F32x8_REDUCE(res, x) \
712
+ do { \
713
+ int offset = GGML_F32_ARR >> 1; \
714
+ for (int i = 0; i < offset; ++i) { \
715
+ x[i] = _mm256_add_ps(x[i], x[offset+i]); \
716
+ } \
717
+ offset >>= 1; \
718
+ for (int i = 0; i < offset; ++i) { \
719
+ x[i] = _mm256_add_ps(x[i], x[offset+i]); \
720
+ } \
721
+ offset >>= 1; \
722
+ for (int i = 0; i < offset; ++i) { \
723
+ x[i] = _mm256_add_ps(x[i], x[offset+i]); \
724
+ } \
725
+ const __m128 t0 = _mm_add_ps(_mm256_castps256_ps128(x[0]), \
726
+ _mm256_extractf128_ps(x[0], 1)); \
727
+ const __m128 t1 = _mm_hadd_ps(t0, t0); \
728
+ res = _mm_cvtss_f32(_mm_hadd_ps(t1, t1)); \
729
+ } while (0)
730
+ // TODO: is this optimal ?
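GGML_F32x8_REDUCE collapses an array of GGML_F32_ARR partial-sum registers into one float: the unrolled passes fold the registers pairwise down to x[0] (the final pass is a no-op when only four accumulators are in use), and the eight lanes of x[0] are then summed horizontally. A minimal standalone sketch of the same pattern, using a loop in place of the unrolled passes and illustrative names:

// reduce_demo.c -- illustrative only; build with e.g. `gcc -mavx reduce_demo.c`
#include <immintrin.h>
#include <stdio.h>

int main(void) {
    __m256 x[4];
    for (int i = 0; i < 4; ++i) {
        x[i] = _mm256_set1_ps((float)(i + 1));   // all 8 lanes of x[i] hold i+1
    }

    // pairwise fold: 4 registers -> 2 -> 1 (equivalent to the macro's unrolled passes)
    for (int offset = 4 >> 1; offset > 0; offset >>= 1) {
        for (int i = 0; i < offset; ++i) {
            x[i] = _mm256_add_ps(x[i], x[offset + i]);
        }
    }

    // horizontal sum of the 8 lanes of x[0]
    const __m128 t0 = _mm_add_ps(_mm256_castps256_ps128(x[0]),
                                 _mm256_extractf128_ps(x[0], 1));
    const __m128 t1 = _mm_hadd_ps(t0, t0);
    const float res = _mm_cvtss_f32(_mm_hadd_ps(t1, t1));

    printf("%g\n", res);   // (1+2+3+4) summed over 8 lanes = 80
    return 0;
}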
3088
731
 
3089
- /* Multiply q with scale and accumulate */
3090
- acc = _mm256_fmadd_ps(d, q, acc);
3091
- }
732
+ #define GGML_F32_VEC GGML_F32x8
733
+ #define GGML_F32_VEC_ZERO GGML_F32x8_ZERO
734
+ #define GGML_F32_VEC_SET1 GGML_F32x8_SET1
735
+ #define GGML_F32_VEC_LOAD GGML_F32x8_LOAD
736
+ #define GGML_F32_VEC_STORE GGML_F32x8_STORE
737
+ #define GGML_F32_VEC_FMA GGML_F32x8_FMA
738
+ #define GGML_F32_VEC_ADD GGML_F32x8_ADD
739
+ #define GGML_F32_VEC_MUL GGML_F32x8_MUL
740
+ #define GGML_F32_VEC_REDUCE GGML_F32x8_REDUCE
3092
741
 
3093
- *s = hsum_float_8(acc);
3094
- #elif defined(__AVX__)
3095
- // Initialize accumulator with zeros
3096
- __m256 acc = _mm256_setzero_ps();
3097
- __m128i mask = _mm_set1_epi8((char)0xF0);
742
+ // F16 AVX
3098
743
 
3099
- // Main loop
3100
- for (int i = 0; i < nb; i++) {
3101
- /* Compute combined scale for the block */
3102
- const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d));
744
+ #define GGML_F16_STEP 32
745
+ #define GGML_F16_EPR 8
3103
746
 
3104
- __m256i bx = bytes_from_nibbles_32(x[i].qs);
3105
- const __m256i bxhi = bytes_from_bits_32(x[i].qh);
3106
- __m128i bxhil = _mm256_castsi256_si128(bxhi);
3107
- __m128i bxhih = _mm256_extractf128_si256(bxhi, 1);
3108
- bxhil = _mm_andnot_si128(bxhil, mask);
3109
- bxhih = _mm_andnot_si128(bxhih, mask);
3110
- __m128i bxl = _mm256_castsi256_si128(bx);
3111
- __m128i bxh = _mm256_extractf128_si256(bx, 1);
3112
- bxl = _mm_or_si128(bxl, bxhil);
3113
- bxh = _mm_or_si128(bxh, bxhih);
3114
- bx = MM256_SET_M128I(bxh, bxl);
747
+ // F16 arithmetic is not supported by AVX, so we use F32 instead
3115
748
 
3116
- const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
749
+ #define GGML_F32Cx8 __m256
750
+ #define GGML_F32Cx8_ZERO _mm256_setzero_ps()
751
+ #define GGML_F32Cx8_SET1(x) _mm256_set1_ps(x)
3117
752
 
3118
- const __m256 q = mul_sum_i8_pairs_float(bx, by);
753
+ #if defined(__F16C__)
754
+ // the _mm256_cvt intrinsics require F16C
755
+ #define GGML_F32Cx8_LOAD(x) _mm256_cvtph_ps(_mm_loadu_si128((__m128i *)(x)))
756
+ #define GGML_F32Cx8_STORE(x, y) _mm_storeu_si128((__m128i *)(x), _mm256_cvtps_ph(y, 0))
757
+ #else
758
+ static inline __m256 __avx_f32cx8_load(ggml_fp16_t *x) {
759
+ float tmp[8];
3119
760
 
3120
- /* Multiply q with scale and accumulate */
3121
- acc = _mm256_add_ps(_mm256_mul_ps(d, q), acc);
761
+ for (int i = 0; i < 8; i++) {
762
+ tmp[i] = GGML_FP16_TO_FP32(x[i]);
3122
763
  }
3123
764
 
3124
- *s = hsum_float_8(acc);
3125
- #elif defined(__riscv_v_intrinsic)
3126
- float sumf = 0.0;
765
+ return _mm256_loadu_ps(tmp);
766
+ }
767
+ static inline void __avx_f32cx8_store(ggml_fp16_t *x, __m256 y) {
768
+ float arr[8];
3127
769
 
3128
- uint32_t qh;
770
+ _mm256_storeu_ps(arr, y);
3129
771
 
3130
- size_t vl = __riscv_vsetvl_e8m1(qk/2);
772
+ for (int i = 0; i < 8; i++)
773
+ x[i] = GGML_FP32_TO_FP16(arr[i]);
774
+ }
775
+ #define GGML_F32Cx8_LOAD(x) __avx_f32cx8_load(x)
776
+ #define GGML_F32Cx8_STORE(x, y) __avx_f32cx8_store(x, y)
777
+ #endif
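When F16C is unavailable, the helpers above emulate GGML_F32Cx8_LOAD/STORE by converting the eight half-precision values one at a time through a temporary float buffer; with F16C the whole conversion is a single _mm256_cvtph_ps / _mm256_cvtps_ph. A standalone sketch of the F16C path (the input constants are illustrative fp16 bit patterns, not values from the diff):

// f16c_demo.c -- illustrative only; build with e.g. `gcc -mavx -mf16c f16c_demo.c`
#include <immintrin.h>
#include <stdint.h>
#include <stdio.h>

int main(void) {
    // fp16 bit patterns for 1.0 .. 8.0
    uint16_t h[8] = {0x3C00, 0x4000, 0x4200, 0x4400, 0x4500, 0x4600, 0x4700, 0x4800};

    // same shape as GGML_F32Cx8_LOAD: 8 halves -> 8 floats in one instruction
    __m256 f = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)h));

    float out[8];
    _mm256_storeu_ps(out, f);
    for (int i = 0; i < 8; ++i) printf("%g ", out[i]);   // prints 1 2 3 4 5 6 7 8
    printf("\n");
    return 0;
}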
3131
778
 
3132
- // These tempory registers are for masking and shift operations
3133
- vuint32m2_t vt_1 = __riscv_vid_v_u32m2(vl);
3134
- vuint32m2_t vt_2 = __riscv_vsll_vv_u32m2(__riscv_vmv_v_x_u32m2(1, vl), vt_1, vl);
779
+ #define GGML_F32Cx8_FMA GGML_F32x8_FMA
780
+ #define GGML_F32Cx8_ADD _mm256_add_ps
781
+ #define GGML_F32Cx8_MUL _mm256_mul_ps
782
+ #define GGML_F32Cx8_REDUCE GGML_F32x8_REDUCE
3135
783
 
3136
- vuint32m2_t vt_3 = __riscv_vsll_vx_u32m2(vt_2, 16, vl);
3137
- vuint32m2_t vt_4 = __riscv_vadd_vx_u32m2(vt_1, 12, vl);
784
+ #define GGML_F16_VEC GGML_F32Cx8
785
+ #define GGML_F16_VEC_ZERO GGML_F32Cx8_ZERO
786
+ #define GGML_F16_VEC_SET1 GGML_F32Cx8_SET1
787
+ #define GGML_F16_VEC_LOAD(p, i) GGML_F32Cx8_LOAD(p)
788
+ #define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx8_STORE(p, r[i])
789
+ #define GGML_F16_VEC_FMA GGML_F32Cx8_FMA
790
+ #define GGML_F16_VEC_ADD GGML_F32Cx8_ADD
791
+ #define GGML_F16_VEC_MUL GGML_F32Cx8_MUL
792
+ #define GGML_F16_VEC_REDUCE GGML_F32Cx8_REDUCE
3138
793
 
3139
- for (int i = 0; i < nb; i++) {
3140
- memcpy(&qh, x[i].qh, sizeof(uint32_t));
794
+ #elif defined(__POWER9_VECTOR__)
3141
795
 
3142
- // ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4;
3143
- vuint32m2_t xha_0 = __riscv_vand_vx_u32m2(vt_2, qh, vl);
3144
- vuint32m2_t xhr_0 = __riscv_vsrl_vv_u32m2(xha_0, vt_1, vl);
3145
- vuint32m2_t xhl_0 = __riscv_vsll_vx_u32m2(xhr_0, 4, vl);
796
+ #define GGML_SIMD
3146
797
 
3147
- // ((qh & (1u << (j + 16))) >> (j + 12));
3148
- vuint32m2_t xha_1 = __riscv_vand_vx_u32m2(vt_3, qh, vl);
3149
- vuint32m2_t xhl_1 = __riscv_vsrl_vv_u32m2(xha_1, vt_4, vl);
798
+ // F32 POWER9
3150
799
 
3151
- // narrowing
3152
- vuint16m1_t xhc_0 = __riscv_vncvt_x_x_w_u16m1(xhl_0, vl);
3153
- vuint8mf2_t xh_0 = __riscv_vncvt_x_x_w_u8mf2(xhc_0, vl);
800
+ #define GGML_F32_STEP 32
801
+ #define GGML_F32_EPR 4
3154
802
 
3155
- vuint16m1_t xhc_1 = __riscv_vncvt_x_x_w_u16m1(xhl_1, vl);
3156
- vuint8mf2_t xh_1 = __riscv_vncvt_x_x_w_u8mf2(xhc_1, vl);
803
+ #define GGML_F32x4 vector float
804
+ #define GGML_F32x4_ZERO 0.0f
805
+ #define GGML_F32x4_SET1 vec_splats
806
+ #define GGML_F32x4_LOAD(p) vec_xl(0, p)
807
+ #define GGML_F32x4_STORE(p, r) vec_xst(r, 0, p)
808
+ #define GGML_F32x4_FMA(a, b, c) vec_madd(b, c, a)
809
+ #define GGML_F32x4_ADD vec_add
810
+ #define GGML_F32x4_MUL vec_mul
811
+ #define GGML_F32x4_REDUCE(res, x) \
812
+ { \
813
+ int offset = GGML_F32_ARR >> 1; \
814
+ for (int i = 0; i < offset; ++i) { \
815
+ x[i] = vec_add(x[i], x[offset+i]); \
816
+ } \
817
+ offset >>= 1; \
818
+ for (int i = 0; i < offset; ++i) { \
819
+ x[i] = vec_add(x[i], x[offset+i]); \
820
+ } \
821
+ offset >>= 1; \
822
+ for (int i = 0; i < offset; ++i) { \
823
+ x[i] = vec_add(x[i], x[offset+i]); \
824
+ } \
825
+ res = vec_extract(x[0], 0) + \
826
+ vec_extract(x[0], 1) + \
827
+ vec_extract(x[0], 2) + \
828
+ vec_extract(x[0], 3); \
829
+ }
3157
830
 
3158
- // load
3159
- vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[i].qs, vl);
831
+ #define GGML_F32_VEC GGML_F32x4
832
+ #define GGML_F32_VEC_ZERO GGML_F32x4_ZERO
833
+ #define GGML_F32_VEC_SET1 GGML_F32x4_SET1
834
+ #define GGML_F32_VEC_LOAD GGML_F32x4_LOAD
835
+ #define GGML_F32_VEC_STORE GGML_F32x4_STORE
836
+ #define GGML_F32_VEC_FMA GGML_F32x4_FMA
837
+ #define GGML_F32_VEC_ADD GGML_F32x4_ADD
838
+ #define GGML_F32_VEC_MUL GGML_F32x4_MUL
839
+ #define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE
3160
840
 
3161
- vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[i].qs, vl);
3162
- vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[i].qs+16, vl);
841
+ // F16 POWER9
842
+ #define GGML_F16_STEP GGML_F32_STEP
843
+ #define GGML_F16_EPR GGML_F32_EPR
844
+ #define GGML_F16_VEC GGML_F32x4
845
+ #define GGML_F16_VEC_ZERO GGML_F32x4_ZERO
846
+ #define GGML_F16_VEC_SET1 GGML_F32x4_SET1
847
+ #define GGML_F16_VEC_FMA GGML_F32x4_FMA
848
+ #define GGML_F16_VEC_REDUCE GGML_F32x4_REDUCE
849
+ // Use vec_xl, not vec_ld, in case the load address is not aligned.
850
+ #define GGML_F16_VEC_LOAD(p, i) (i & 0x1) ? \
851
+ vec_extract_fp32_from_shorth(vec_xl(0, p - GGML_F16_EPR)) : \
852
+ vec_extract_fp32_from_shortl(vec_xl(0, p))
853
+ #define GGML_ENDIAN_BYTE(i) ((unsigned char *)&(uint16_t){1})[i]
854
+ #define GGML_F16_VEC_STORE(p, r, i) \
855
+ if (i & 0x1) \
856
+ vec_xst(vec_pack_to_short_fp32(r[i - GGML_ENDIAN_BYTE(1)], \
857
+ r[i - GGML_ENDIAN_BYTE(0)]), \
858
+ 0, p - GGML_F16_EPR)
3163
859
 
3164
- vuint8mf2_t x_at = __riscv_vand_vx_u8mf2(tx, 0x0F, vl);
3165
- vuint8mf2_t x_lt = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl);
860
+ #elif defined(__wasm_simd128__)
3166
861
 
3167
- vuint8mf2_t x_a = __riscv_vor_vv_u8mf2(x_at, xh_0, vl);
3168
- vuint8mf2_t x_l = __riscv_vor_vv_u8mf2(x_lt, xh_1, vl);
862
+ #define GGML_SIMD
3169
863
 
3170
- vint8mf2_t x_ai = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a);
3171
- vint8mf2_t x_li = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l);
864
+ // F32 WASM
3172
865
 
3173
- vint8mf2_t v0 = __riscv_vsub_vx_i8mf2(x_ai, 16, vl);
3174
- vint8mf2_t v1 = __riscv_vsub_vx_i8mf2(x_li, 16, vl);
866
+ #define GGML_F32_STEP 16
867
+ #define GGML_F32_EPR 4
3175
868
 
3176
- vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl);
3177
- vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl);
869
+ #define GGML_F32x4 v128_t
870
+ #define GGML_F32x4_ZERO wasm_f32x4_splat(0.0f)
871
+ #define GGML_F32x4_SET1(x) wasm_f32x4_splat(x)
872
+ #define GGML_F32x4_LOAD wasm_v128_load
873
+ #define GGML_F32x4_STORE wasm_v128_store
874
+ #define GGML_F32x4_FMA(a, b, c) wasm_f32x4_add(wasm_f32x4_mul(b, c), a)
875
+ #define GGML_F32x4_ADD wasm_f32x4_add
876
+ #define GGML_F32x4_MUL wasm_f32x4_mul
877
+ #define GGML_F32x4_REDUCE(res, x) \
878
+ { \
879
+ int offset = GGML_F32_ARR >> 1; \
880
+ for (int i = 0; i < offset; ++i) { \
881
+ x[i] = wasm_f32x4_add(x[i], x[offset+i]); \
882
+ } \
883
+ offset >>= 1; \
884
+ for (int i = 0; i < offset; ++i) { \
885
+ x[i] = wasm_f32x4_add(x[i], x[offset+i]); \
886
+ } \
887
+ offset >>= 1; \
888
+ for (int i = 0; i < offset; ++i) { \
889
+ x[i] = wasm_f32x4_add(x[i], x[offset+i]); \
890
+ } \
891
+ res = wasm_f32x4_extract_lane(x[0], 0) + \
892
+ wasm_f32x4_extract_lane(x[0], 1) + \
893
+ wasm_f32x4_extract_lane(x[0], 2) + \
894
+ wasm_f32x4_extract_lane(x[0], 3); \
895
+ }
3178
896
 
3179
- vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
897
+ #define GGML_F32_VEC GGML_F32x4
898
+ #define GGML_F32_VEC_ZERO GGML_F32x4_ZERO
899
+ #define GGML_F32_VEC_SET1 GGML_F32x4_SET1
900
+ #define GGML_F32_VEC_LOAD GGML_F32x4_LOAD
901
+ #define GGML_F32_VEC_STORE GGML_F32x4_STORE
902
+ #define GGML_F32_VEC_FMA GGML_F32x4_FMA
903
+ #define GGML_F32_VEC_ADD GGML_F32x4_ADD
904
+ #define GGML_F32_VEC_MUL GGML_F32x4_MUL
905
+ #define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE
3180
906
 
3181
- vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl);
3182
- vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl);
907
+ // F16 WASM
3183
908
 
3184
- int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);
909
+ #define GGML_F16_STEP 16
910
+ #define GGML_F16_EPR 4
3185
911
 
3186
- sumf += (GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d)) * sumi;
3187
- }
912
+ inline static v128_t __wasm_f16x4_load(const ggml_fp16_t * p) {
913
+ float tmp[4];
3188
914
 
3189
- *s = sumf;
3190
- #else
3191
- // scalar
3192
- float sumf = 0.0;
915
+ tmp[0] = GGML_FP16_TO_FP32(p[0]);
916
+ tmp[1] = GGML_FP16_TO_FP32(p[1]);
917
+ tmp[2] = GGML_FP16_TO_FP32(p[2]);
918
+ tmp[3] = GGML_FP16_TO_FP32(p[3]);
3193
919
 
3194
- for (int i = 0; i < nb; i++) {
3195
- uint32_t qh;
3196
- memcpy(&qh, x[i].qh, sizeof(qh));
920
+ return wasm_v128_load(tmp);
921
+ }
3197
922
 
3198
- int sumi = 0;
923
+ inline static void __wasm_f16x4_store(ggml_fp16_t * p, v128_t x) {
924
+ float tmp[4];
3199
925
 
3200
- for (int j = 0; j < qk/2; ++j) {
3201
- const uint8_t xh_0 = ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4;
3202
- const uint8_t xh_1 = ((qh & (1u << (j + 16))) >> (j + 12));
926
+ wasm_v128_store(tmp, x);
3203
927
 
3204
- const int32_t x0 = ((x[i].qs[j] & 0x0F) | xh_0) - 16;
3205
- const int32_t x1 = ((x[i].qs[j] >> 4) | xh_1) - 16;
928
+ p[0] = GGML_FP32_TO_FP16(tmp[0]);
929
+ p[1] = GGML_FP32_TO_FP16(tmp[1]);
930
+ p[2] = GGML_FP32_TO_FP16(tmp[2]);
931
+ p[3] = GGML_FP32_TO_FP16(tmp[3]);
932
+ }
3206
933
 
3207
- sumi += (x0 * y[i].qs[j]) + (x1 * y[i].qs[j + qk/2]);
3208
- }
934
+ #define GGML_F16x4 v128_t
935
+ #define GGML_F16x4_ZERO wasm_f32x4_splat(0.0f)
936
+ #define GGML_F16x4_SET1(x) wasm_f32x4_splat(x)
937
+ #define GGML_F16x4_LOAD(x) __wasm_f16x4_load(x)
938
+ #define GGML_F16x4_STORE(x, y) __wasm_f16x4_store(x, y)
939
+ #define GGML_F16x4_FMA GGML_F32x4_FMA
940
+ #define GGML_F16x4_ADD wasm_f32x4_add
941
+ #define GGML_F16x4_MUL wasm_f32x4_mul
942
+ #define GGML_F16x4_REDUCE(res, x) \
943
+ { \
944
+ int offset = GGML_F16_ARR >> 1; \
945
+ for (int i = 0; i < offset; ++i) { \
946
+ x[i] = wasm_f32x4_add(x[i], x[offset+i]); \
947
+ } \
948
+ offset >>= 1; \
949
+ for (int i = 0; i < offset; ++i) { \
950
+ x[i] = wasm_f32x4_add(x[i], x[offset+i]); \
951
+ } \
952
+ offset >>= 1; \
953
+ for (int i = 0; i < offset; ++i) { \
954
+ x[i] = wasm_f32x4_add(x[i], x[offset+i]); \
955
+ } \
956
+ res = wasm_f32x4_extract_lane(x[0], 0) + \
957
+ wasm_f32x4_extract_lane(x[0], 1) + \
958
+ wasm_f32x4_extract_lane(x[0], 2) + \
959
+ wasm_f32x4_extract_lane(x[0], 3); \
960
+ }
3209
961
 
3210
- sumf += (GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d)) * sumi;
3211
- }
962
+ #define GGML_F16_VEC GGML_F16x4
963
+ #define GGML_F16_VEC_ZERO GGML_F16x4_ZERO
964
+ #define GGML_F16_VEC_SET1 GGML_F16x4_SET1
965
+ #define GGML_F16_VEC_LOAD(p, i) GGML_F16x4_LOAD(p)
966
+ #define GGML_F16_VEC_STORE(p, r, i) GGML_F16x4_STORE(p, r[i])
967
+ #define GGML_F16_VEC_FMA GGML_F16x4_FMA
968
+ #define GGML_F16_VEC_ADD GGML_F16x4_ADD
969
+ #define GGML_F16_VEC_MUL GGML_F16x4_MUL
970
+ #define GGML_F16_VEC_REDUCE GGML_F16x4_REDUCE
3212
971
 
3213
- *s = sumf;
3214
- #endif
3215
- }
972
+ #elif defined(__SSE3__)
3216
973
 
3217
- static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
3218
- const int qk = QK8_1;
3219
- const int nb = n / qk;
974
+ #define GGML_SIMD
3220
975
 
3221
- assert(n % qk == 0);
3222
- assert(qk == QK5_1);
976
+ // F32 SSE
3223
977
 
3224
- const block_q5_1 * restrict x = vx;
3225
- const block_q8_1 * restrict y = vy;
978
+ #define GGML_F32_STEP 32
979
+ #define GGML_F32_EPR 4
3226
980
 
3227
- #if defined(__ARM_NEON)
3228
- float32x4_t sumv0 = vdupq_n_f32(0.0f);
3229
- float32x4_t sumv1 = vdupq_n_f32(0.0f);
3230
-
3231
- float summs0 = 0.0f;
3232
- float summs1 = 0.0f;
3233
-
3234
- uint32_t qh0;
3235
- uint32_t qh1;
3236
-
3237
- uint64_t tmp0[4];
3238
- uint64_t tmp1[4];
3239
-
3240
- GGML_ASSERT(nb % 2 == 0); // TODO: handle odd nb
3241
- for (int i = 0; i < nb; i += 2) {
3242
- const block_q5_1 * restrict x0 = &x[i];
3243
- const block_q5_1 * restrict x1 = &x[i + 1];
3244
- const block_q8_1 * restrict y0 = &y[i];
3245
- const block_q8_1 * restrict y1 = &y[i + 1];
3246
-
3247
- const uint8x16_t m4b = vdupq_n_u8(0x0F);
3248
-
3249
- summs0 += GGML_FP16_TO_FP32(x0->m) * y0->s;
3250
- summs1 += GGML_FP16_TO_FP32(x1->m) * y1->s;
3251
-
3252
- // extract the 5th bit via lookup table ((b) << 4)
3253
- memcpy(&qh0, x0->qh, sizeof(qh0));
3254
- memcpy(&qh1, x1->qh, sizeof(qh1));
3255
-
3256
- tmp0[0] = table_b2b_0[(qh0 >> 0) & 0xFF];
3257
- tmp0[1] = table_b2b_0[(qh0 >> 8) & 0xFF];
3258
- tmp0[2] = table_b2b_0[(qh0 >> 16) & 0xFF];
3259
- tmp0[3] = table_b2b_0[(qh0 >> 24) ];
3260
-
3261
- tmp1[0] = table_b2b_0[(qh1 >> 0) & 0xFF];
3262
- tmp1[1] = table_b2b_0[(qh1 >> 8) & 0xFF];
3263
- tmp1[2] = table_b2b_0[(qh1 >> 16) & 0xFF];
3264
- tmp1[3] = table_b2b_0[(qh1 >> 24) ];
3265
-
3266
- const int8x16_t qhl0 = vld1q_s8((const int8_t *)(tmp0 + 0));
3267
- const int8x16_t qhh0 = vld1q_s8((const int8_t *)(tmp0 + 2));
3268
- const int8x16_t qhl1 = vld1q_s8((const int8_t *)(tmp1 + 0));
3269
- const int8x16_t qhh1 = vld1q_s8((const int8_t *)(tmp1 + 2));
3270
-
3271
- const uint8x16_t v0_0 = vld1q_u8(x0->qs);
3272
- const uint8x16_t v0_1 = vld1q_u8(x1->qs);
3273
-
3274
- // 4-bit -> 8-bit
3275
- const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
3276
- const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
3277
- const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
3278
- const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
3279
-
3280
- // add high bit
3281
- const int8x16_t v0_0lf = vorrq_s8(v0_0l, qhl0);
3282
- const int8x16_t v0_0hf = vorrq_s8(v0_0h, qhh0);
3283
- const int8x16_t v0_1lf = vorrq_s8(v0_1l, qhl1);
3284
- const int8x16_t v0_1hf = vorrq_s8(v0_1h, qhh1);
3285
-
3286
- // load y
3287
- const int8x16_t v1_0l = vld1q_s8(y0->qs);
3288
- const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
3289
- const int8x16_t v1_1l = vld1q_s8(y1->qs);
3290
- const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
3291
-
3292
- #if defined(__ARM_FEATURE_DOTPROD)
3293
- sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
3294
- vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l),
3295
- vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), GGML_FP16_TO_FP32(x0->d)*y0->d);
3296
- sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
3297
- vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l),
3298
- vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), GGML_FP16_TO_FP32(x1->d)*y1->d);
981
+ #define GGML_F32x4 __m128
982
+ #define GGML_F32x4_ZERO _mm_setzero_ps()
983
+ #define GGML_F32x4_SET1(x) _mm_set1_ps(x)
984
+ #define GGML_F32x4_LOAD _mm_loadu_ps
985
+ #define GGML_F32x4_STORE _mm_storeu_ps
986
+ #if defined(__FMA__)
987
+ // TODO: Does this work?
988
+ #define GGML_F32x4_FMA(a, b, c) _mm_fmadd_ps(b, c, a)
3299
989
  #else
3300
- const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lf), vget_low_s8 (v1_0l));
3301
- const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lf), vget_high_s8(v1_0l));
3302
- const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hf), vget_low_s8 (v1_0h));
3303
- const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hf), vget_high_s8(v1_0h));
3304
-
3305
- const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lf), vget_low_s8 (v1_1l));
3306
- const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lf), vget_high_s8(v1_1l));
3307
- const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hf), vget_low_s8 (v1_1h));
3308
- const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hf), vget_high_s8(v1_1h));
3309
-
3310
- const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
3311
- const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
3312
- const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
3313
- const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
3314
-
3315
- sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), GGML_FP16_TO_FP32(x0->d)*y0->d);
3316
- sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), GGML_FP16_TO_FP32(x1->d)*y1->d);
990
+ #define GGML_F32x4_FMA(a, b, c) _mm_add_ps(_mm_mul_ps(b, c), a)
3317
991
  #endif
3318
- }
3319
-
3320
- *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs0 + summs1;
3321
- #elif defined(__wasm_simd128__)
3322
- v128_t sumv = wasm_f32x4_splat(0.0f);
3323
-
3324
- float summs = 0.0f;
3325
-
3326
- uint32_t qh;
3327
- uint64_t tmp[4];
3328
-
3329
- // TODO: check if unrolling this is better
3330
- for (int i = 0; i < nb; ++i) {
3331
- const block_q5_1 * restrict x0 = &x[i];
3332
- const block_q8_1 * restrict y0 = &y[i];
3333
-
3334
- summs += GGML_FP16_TO_FP32(x0->m) * y0->s;
3335
-
3336
- const v128_t m4b = wasm_i8x16_splat(0x0F);
3337
-
3338
- // extract the 5th bit
3339
- memcpy(&qh, x0->qh, sizeof(qh));
3340
-
3341
- tmp[0] = table_b2b_0[(qh >> 0) & 0xFF];
3342
- tmp[1] = table_b2b_0[(qh >> 8) & 0xFF];
3343
- tmp[2] = table_b2b_0[(qh >> 16) & 0xFF];
3344
- tmp[3] = table_b2b_0[(qh >> 24) ];
992
+ #define GGML_F32x4_ADD _mm_add_ps
993
+ #define GGML_F32x4_MUL _mm_mul_ps
994
+ #define GGML_F32x4_REDUCE(res, x) \
995
+ { \
996
+ int offset = GGML_F32_ARR >> 1; \
997
+ for (int i = 0; i < offset; ++i) { \
998
+ x[i] = _mm_add_ps(x[i], x[offset+i]); \
999
+ } \
1000
+ offset >>= 1; \
1001
+ for (int i = 0; i < offset; ++i) { \
1002
+ x[i] = _mm_add_ps(x[i], x[offset+i]); \
1003
+ } \
1004
+ offset >>= 1; \
1005
+ for (int i = 0; i < offset; ++i) { \
1006
+ x[i] = _mm_add_ps(x[i], x[offset+i]); \
1007
+ } \
1008
+ const __m128 t0 = _mm_hadd_ps(x[0], x[0]); \
1009
+ res = _mm_cvtss_f32(_mm_hadd_ps(t0, t0)); \
1010
+ }
1011
+ // TODO: is this optimal ?
3345
1012
 
3346
- const v128_t qhl = wasm_v128_load(tmp + 0);
3347
- const v128_t qhh = wasm_v128_load(tmp + 2);
1013
+ #define GGML_F32_VEC GGML_F32x4
1014
+ #define GGML_F32_VEC_ZERO GGML_F32x4_ZERO
1015
+ #define GGML_F32_VEC_SET1 GGML_F32x4_SET1
1016
+ #define GGML_F32_VEC_LOAD GGML_F32x4_LOAD
1017
+ #define GGML_F32_VEC_STORE GGML_F32x4_STORE
1018
+ #define GGML_F32_VEC_FMA GGML_F32x4_FMA
1019
+ #define GGML_F32_VEC_ADD GGML_F32x4_ADD
1020
+ #define GGML_F32_VEC_MUL GGML_F32x4_MUL
1021
+ #define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE
3348
1022
 
3349
- const v128_t v0 = wasm_v128_load(x0->qs);
1023
+ // F16 SSE
3350
1024
 
3351
- // 4-bit -> 8-bit
3352
- const v128_t v0l = wasm_v128_and (v0, m4b);
3353
- const v128_t v0h = wasm_u8x16_shr(v0, 4);
1025
+ #define GGML_F16_STEP 32
1026
+ #define GGML_F16_EPR 4
3354
1027
 
3355
- // add high bit
3356
- const v128_t v0lf = wasm_v128_or(v0l, qhl);
3357
- const v128_t v0hf = wasm_v128_or(v0h, qhh);
1028
+ static inline __m128 __sse_f16x4_load(ggml_fp16_t *x) {
1029
+ float tmp[4];
3358
1030
 
3359
- // load y
3360
- const v128_t v1l = wasm_v128_load(y0->qs);
3361
- const v128_t v1h = wasm_v128_load(y0->qs + 16);
1031
+ tmp[0] = GGML_FP16_TO_FP32(x[0]);
1032
+ tmp[1] = GGML_FP16_TO_FP32(x[1]);
1033
+ tmp[2] = GGML_FP16_TO_FP32(x[2]);
1034
+ tmp[3] = GGML_FP16_TO_FP32(x[3]);
3362
1035
 
3363
- // int8x16 -> int16x8
3364
- const v128_t v0lfl = wasm_i16x8_extend_low_i8x16 (v0lf);
3365
- const v128_t v0lfh = wasm_i16x8_extend_high_i8x16(v0lf);
3366
- const v128_t v0hfl = wasm_i16x8_extend_low_i8x16 (v0hf);
3367
- const v128_t v0hfh = wasm_i16x8_extend_high_i8x16(v0hf);
1036
+ return _mm_loadu_ps(tmp);
1037
+ }
3368
1038
 
3369
- const v128_t v1ll = wasm_i16x8_extend_low_i8x16 (v1l);
3370
- const v128_t v1lh = wasm_i16x8_extend_high_i8x16(v1l);
3371
- const v128_t v1hl = wasm_i16x8_extend_low_i8x16 (v1h);
3372
- const v128_t v1hh = wasm_i16x8_extend_high_i8x16(v1h);
1039
+ static inline void __sse_f16x4_store(ggml_fp16_t *x, __m128 y) {
1040
+ float arr[4];
3373
1041
 
3374
- // dot product
3375
- sumv = wasm_f32x4_add(sumv,
3376
- wasm_f32x4_mul(wasm_f32x4_convert_i32x4(wasm_i32x4_add(
3377
- wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0lfl, v1ll),
3378
- wasm_i32x4_dot_i16x8(v0lfh, v1lh)),
3379
- wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0hfl, v1hl),
3380
- wasm_i32x4_dot_i16x8(v0hfh, v1hh)))),
3381
- wasm_f32x4_splat(GGML_FP16_TO_FP32(x0->d) * y0->d)));
3382
- }
1042
+ _mm_storeu_ps(arr, y);
3383
1043
 
3384
- *s = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) +
3385
- wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3) + summs;
3386
- #elif defined(__AVX2__)
3387
- // Initialize accumulator with zeros
3388
- __m256 acc = _mm256_setzero_ps();
1044
+ x[0] = GGML_FP32_TO_FP16(arr[0]);
1045
+ x[1] = GGML_FP32_TO_FP16(arr[1]);
1046
+ x[2] = GGML_FP32_TO_FP16(arr[2]);
1047
+ x[3] = GGML_FP32_TO_FP16(arr[3]);
1048
+ }
3389
1049
 
3390
- float summs = 0.0f;
1050
+ #define GGML_F32Cx4 __m128
1051
+ #define GGML_F32Cx4_ZERO _mm_setzero_ps()
1052
+ #define GGML_F32Cx4_SET1(x) _mm_set1_ps(x)
1053
+ #define GGML_F32Cx4_LOAD(x) __sse_f16x4_load(x)
1054
+ #define GGML_F32Cx4_STORE(x, y) __sse_f16x4_store(x, y)
1055
+ #define GGML_F32Cx4_FMA GGML_F32x4_FMA
1056
+ #define GGML_F32Cx4_ADD _mm_add_ps
1057
+ #define GGML_F32Cx4_MUL _mm_mul_ps
1058
+ #define GGML_F32Cx4_REDUCE GGML_F32x4_REDUCE
3391
1059
 
3392
- // Main loop
3393
- for (int i = 0; i < nb; i++) {
3394
- const __m256 dx = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d));
1060
+ #define GGML_F16_VEC GGML_F32Cx4
1061
+ #define GGML_F16_VEC_ZERO GGML_F32Cx4_ZERO
1062
+ #define GGML_F16_VEC_SET1 GGML_F32Cx4_SET1
1063
+ #define GGML_F16_VEC_LOAD(p, i) GGML_F32Cx4_LOAD(p)
1064
+ #define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx4_STORE(p, r[i])
1065
+ #define GGML_F16_VEC_FMA GGML_F32Cx4_FMA
1066
+ #define GGML_F16_VEC_ADD GGML_F32Cx4_ADD
1067
+ #define GGML_F16_VEC_MUL GGML_F32Cx4_MUL
1068
+ #define GGML_F16_VEC_REDUCE GGML_F32Cx4_REDUCE
3395
1069
 
3396
- summs += GGML_FP16_TO_FP32(x[i].m) * y[i].s;
1070
+ #endif
3397
1071
 
3398
- __m256i bx = bytes_from_nibbles_32(x[i].qs);
3399
- __m256i bxhi = bytes_from_bits_32(x[i].qh);
3400
- bxhi = _mm256_and_si256(bxhi, _mm256_set1_epi8(0x10));
3401
- bx = _mm256_or_si256(bx, bxhi);
1072
+ // GGML_F32_ARR / GGML_F16_ARR
1073
+ // number of registers to use per step
1074
+ #ifdef GGML_SIMD
1075
+ #define GGML_F32_ARR (GGML_F32_STEP/GGML_F32_EPR)
1076
+ #define GGML_F16_ARR (GGML_F16_STEP/GGML_F16_EPR)
1077
+ #endif
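As a quick reading of these two defines: GGML_F32_STEP is how many floats one iteration of a vectorized loop consumes, GGML_F32_EPR is how many floats fit in one register, so GGML_F32_ARR is the number of independent accumulator registers kept in flight per step (which helps hide FMA latency). For example, on the AVX path above 32/8 = 4 accumulators of 8 lanes each; on the SSE and POWER9 paths 32/4 = 8 accumulators of 4 lanes. The same arithmetic applies to the F16 variants, and ggml_vec_dot_f32 just below shows the resulting load/FMA/reduce pattern.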
3402
1078
 
3403
- const __m256 dy = _mm256_set1_ps(y[i].d);
3404
- const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
1079
+ //
1080
+ // fundamental operations
1081
+ //
3405
1082
 
3406
- const __m256 q = mul_sum_us8_pairs_float(bx, by);
1083
+ inline static void ggml_vec_set_i8(const int n, int8_t * x, const int8_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
3407
1084
 
3408
- acc = _mm256_fmadd_ps(q, _mm256_mul_ps(dx, dy), acc);
3409
- }
1085
+ inline static void ggml_vec_set_i16(const int n, int16_t * x, const int16_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
3410
1086
 
3411
- *s = hsum_float_8(acc) + summs;
3412
- #elif defined(__AVX__)
3413
- // Initialize accumulator with zeros
3414
- __m256 acc = _mm256_setzero_ps();
3415
- __m128i mask = _mm_set1_epi8(0x10);
1087
+ inline static void ggml_vec_set_i32(const int n, int32_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
3416
1088
 
3417
- float summs = 0.0f;
1089
+ inline static void ggml_vec_set_f16(const int n, ggml_fp16_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
3418
1090
 
3419
- // Main loop
3420
- for (int i = 0; i < nb; i++) {
3421
- const __m256 dx = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d));
1091
+ inline static void ggml_vec_add_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i] + y[i]; }
1092
+ inline static void ggml_vec_add1_f32(const int n, float * z, const float * x, const float v) { for (int i = 0; i < n; ++i) z[i] = x[i] + v; }
1093
+ inline static void ggml_vec_acc_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] += x[i]; }
1094
+ inline static void ggml_vec_acc1_f32(const int n, float * y, const float v) { for (int i = 0; i < n; ++i) y[i] += v; }
1095
+ inline static void ggml_vec_sub_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i] - y[i]; }
1096
+ inline static void ggml_vec_set_f32 (const int n, float * x, const float v) { for (int i = 0; i < n; ++i) x[i] = v; }
1097
+ inline static void ggml_vec_cpy_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]; }
1098
+ inline static void ggml_vec_neg_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = -x[i]; }
1099
+ inline static void ggml_vec_mul_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]*y[i]; }
1100
+ inline static void ggml_vec_div_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]/y[i]; }
3422
1101
 
3423
- summs += GGML_FP16_TO_FP32(x[i].m) * y[i].s;
1102
+ static void ggml_vec_dot_f32(const int n, float * restrict s, const float * restrict x, const float * restrict y) {
1103
+ #ifdef GGML_SIMD
1104
+ float sumf = 0.0f;
1105
+ const int np = (n & ~(GGML_F32_STEP - 1));
3424
1106
 
3425
- __m256i bx = bytes_from_nibbles_32(x[i].qs);
3426
- const __m256i bxhi = bytes_from_bits_32(x[i].qh);
3427
- __m128i bxhil = _mm256_castsi256_si128(bxhi);
3428
- __m128i bxhih = _mm256_extractf128_si256(bxhi, 1);
3429
- bxhil = _mm_and_si128(bxhil, mask);
3430
- bxhih = _mm_and_si128(bxhih, mask);
3431
- __m128i bxl = _mm256_castsi256_si128(bx);
3432
- __m128i bxh = _mm256_extractf128_si256(bx, 1);
3433
- bxl = _mm_or_si128(bxl, bxhil);
3434
- bxh = _mm_or_si128(bxh, bxhih);
3435
- bx = MM256_SET_M128I(bxh, bxl);
1107
+ GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
3436
1108
 
3437
- const __m256 dy = _mm256_set1_ps(y[i].d);
3438
- const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
1109
+ GGML_F32_VEC ax[GGML_F32_ARR];
1110
+ GGML_F32_VEC ay[GGML_F32_ARR];
3439
1111
 
3440
- const __m256 q = mul_sum_us8_pairs_float(bx, by);
1112
+ for (int i = 0; i < np; i += GGML_F32_STEP) {
1113
+ for (int j = 0; j < GGML_F32_ARR; j++) {
1114
+ ax[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR);
1115
+ ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR);
3441
1116
 
3442
- acc = _mm256_add_ps(_mm256_mul_ps(q, _mm256_mul_ps(dx, dy)), acc);
1117
+ sum[j] = GGML_F32_VEC_FMA(sum[j], ax[j], ay[j]);
1118
+ }
3443
1119
  }
3444
1120
 
3445
- *s = hsum_float_8(acc) + summs;
3446
- #elif defined(__riscv_v_intrinsic)
3447
- float sumf = 0.0;
3448
-
3449
- uint32_t qh;
3450
-
3451
- size_t vl = __riscv_vsetvl_e8m1(qk/2);
3452
-
3453
- // temporary registers for shift operations
3454
- vuint32m2_t vt_1 = __riscv_vid_v_u32m2(vl);
3455
- vuint32m2_t vt_2 = __riscv_vadd_vx_u32m2(vt_1, 12, vl);
3456
-
3457
- for (int i = 0; i < nb; i++) {
3458
- memcpy(&qh, x[i].qh, sizeof(uint32_t));
3459
-
3460
- // load qh
3461
- vuint32m2_t vqh = __riscv_vmv_v_x_u32m2(qh, vl);
3462
-
3463
- // ((qh >> (j + 0)) << 4) & 0x10;
3464
- vuint32m2_t xhr_0 = __riscv_vsrl_vv_u32m2(vqh, vt_1, vl);
3465
- vuint32m2_t xhl_0 = __riscv_vsll_vx_u32m2(xhr_0, 4, vl);
3466
- vuint32m2_t xha_0 = __riscv_vand_vx_u32m2(xhl_0, 0x10, vl);
3467
-
3468
- // ((qh >> (j + 12)) ) & 0x10;
3469
- vuint32m2_t xhr_1 = __riscv_vsrl_vv_u32m2(vqh, vt_2, vl);
3470
- vuint32m2_t xha_1 = __riscv_vand_vx_u32m2(xhr_1, 0x10, vl);
3471
-
3472
- // narrowing
3473
- vuint16m1_t xhc_0 = __riscv_vncvt_x_x_w_u16m1(xha_0, vl);
3474
- vuint8mf2_t xh_0 = __riscv_vncvt_x_x_w_u8mf2(xhc_0, vl);
3475
-
3476
- vuint16m1_t xhc_1 = __riscv_vncvt_x_x_w_u16m1(xha_1, vl);
3477
- vuint8mf2_t xh_1 = __riscv_vncvt_x_x_w_u8mf2(xhc_1, vl);
3478
-
3479
- // load
3480
- vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[i].qs, vl);
3481
-
3482
- vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[i].qs, vl);
3483
- vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[i].qs+16, vl);
3484
-
3485
- vuint8mf2_t x_at = __riscv_vand_vx_u8mf2(tx, 0x0F, vl);
3486
- vuint8mf2_t x_lt = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl);
3487
-
3488
- vuint8mf2_t x_a = __riscv_vor_vv_u8mf2(x_at, xh_0, vl);
3489
- vuint8mf2_t x_l = __riscv_vor_vv_u8mf2(x_lt, xh_1, vl);
3490
-
3491
- vint8mf2_t v0 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a);
3492
- vint8mf2_t v1 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l);
3493
-
3494
- vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl);
3495
- vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl);
3496
-
3497
- vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
3498
-
3499
- vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl);
3500
- vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl);
3501
-
3502
- int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);
1121
+ // reduce sum0..sum3 to sum0
1122
+ GGML_F32_VEC_REDUCE(sumf, sum);
3503
1123
 
3504
- sumf += (GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi + GGML_FP16_TO_FP32(x[i].m)*y[i].s;
1124
+ // leftovers
1125
+ for (int i = np; i < n; ++i) {
1126
+ sumf += x[i]*y[i];
3505
1127
  }
3506
-
3507
- *s = sumf;
3508
1128
  #else
3509
1129
  // scalar
3510
- float sumf = 0.0;
3511
-
3512
- for (int i = 0; i < nb; i++) {
3513
- uint32_t qh;
3514
- memcpy(&qh, x[i].qh, sizeof(qh));
3515
-
3516
- int sumi = 0;
3517
-
3518
- for (int j = 0; j < qk/2; ++j) {
3519
- const uint8_t xh_0 = ((qh >> (j + 0)) << 4) & 0x10;
3520
- const uint8_t xh_1 = ((qh >> (j + 12)) ) & 0x10;
3521
-
3522
- const int32_t x0 = (x[i].qs[j] & 0xF) | xh_0;
3523
- const int32_t x1 = (x[i].qs[j] >> 4) | xh_1;
3524
-
3525
- sumi += (x0 * y[i].qs[j]) + (x1 * y[i].qs[j + qk/2]);
3526
- }
3527
-
3528
- sumf += (GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi + GGML_FP16_TO_FP32(x[i].m)*y[i].s;
1130
+ ggml_float sumf = 0.0;
1131
+ for (int i = 0; i < n; ++i) {
1132
+ sumf += (ggml_float)(x[i]*y[i]);
3529
1133
  }
1134
+ #endif
3530
1135
 
3531
1136
  *s = sumf;
3532
- #endif
3533
1137
  }
3534
1138
 
3535
- static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
3536
- const int qk = QK8_0;
3537
- const int nb = n / qk;
3538
-
3539
- assert(n % qk == 0);
3540
-
3541
- const block_q8_0 * restrict x = vx;
3542
- const block_q8_0 * restrict y = vy;
3543
-
3544
- #if defined(__ARM_NEON)
3545
- float32x4_t sumv0 = vdupq_n_f32(0.0f);
3546
- float32x4_t sumv1 = vdupq_n_f32(0.0f);
3547
-
3548
- GGML_ASSERT(nb % 2 == 0); // TODO: handle odd nb
3549
- for (int i = 0; i < nb; i += 2) {
3550
- const block_q8_0 * restrict x0 = &x[i + 0];
3551
- const block_q8_0 * restrict x1 = &x[i + 1];
3552
- const block_q8_0 * restrict y0 = &y[i + 0];
3553
- const block_q8_0 * restrict y1 = &y[i + 1];
3554
-
3555
- const int8x16_t x0_0 = vld1q_s8(x0->qs);
3556
- const int8x16_t x0_1 = vld1q_s8(x0->qs + 16);
3557
- const int8x16_t x1_0 = vld1q_s8(x1->qs);
3558
- const int8x16_t x1_1 = vld1q_s8(x1->qs + 16);
3559
-
3560
- // load y
3561
- const int8x16_t y0_0 = vld1q_s8(y0->qs);
3562
- const int8x16_t y0_1 = vld1q_s8(y0->qs + 16);
3563
- const int8x16_t y1_0 = vld1q_s8(y1->qs);
3564
- const int8x16_t y1_1 = vld1q_s8(y1->qs + 16);
3565
-
3566
- #if defined(__ARM_FEATURE_DOTPROD)
3567
- sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
3568
- vdotq_s32(vdupq_n_s32(0), x0_0, y0_0),
3569
- vdotq_s32(vdupq_n_s32(0), x0_1, y0_1))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
3570
-
3571
- sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
3572
- vdotq_s32(vdupq_n_s32(0), x1_0, y1_0),
3573
- vdotq_s32(vdupq_n_s32(0), x1_1, y1_1))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
1139
+ static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t * restrict x, ggml_fp16_t * restrict y) {
1140
+ ggml_float sumf = 0.0;
3574
1141
 
3575
- #else
3576
- const int16x8_t p0_0 = vmull_s8(vget_low_s8 (x0_0), vget_low_s8 (y0_0));
3577
- const int16x8_t p0_1 = vmull_s8(vget_high_s8(x0_0), vget_high_s8(y0_0));
3578
- const int16x8_t p0_2 = vmull_s8(vget_low_s8 (x0_1), vget_low_s8 (y0_1));
3579
- const int16x8_t p0_3 = vmull_s8(vget_high_s8(x0_1), vget_high_s8(y0_1));
3580
-
3581
- const int16x8_t p1_0 = vmull_s8(vget_low_s8 (x1_0), vget_low_s8 (y1_0));
3582
- const int16x8_t p1_1 = vmull_s8(vget_high_s8(x1_0), vget_high_s8(y1_0));
3583
- const int16x8_t p1_2 = vmull_s8(vget_low_s8 (x1_1), vget_low_s8 (y1_1));
3584
- const int16x8_t p1_3 = vmull_s8(vget_high_s8(x1_1), vget_high_s8(y1_1));
3585
-
3586
- const int32x4_t p0 = vaddq_s32(vpaddlq_s16(p0_0), vpaddlq_s16(p0_1));
3587
- const int32x4_t p1 = vaddq_s32(vpaddlq_s16(p0_2), vpaddlq_s16(p0_3));
3588
- const int32x4_t p2 = vaddq_s32(vpaddlq_s16(p1_0), vpaddlq_s16(p1_1));
3589
- const int32x4_t p3 = vaddq_s32(vpaddlq_s16(p1_2), vpaddlq_s16(p1_3));
3590
-
3591
- sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(p0, p1)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
3592
- sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(p2, p3)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
3593
- #endif
3594
- }
1142
+ #if defined(GGML_SIMD)
1143
+ const int np = (n & ~(GGML_F16_STEP - 1));
3595
1144
 
3596
- *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
3597
- #elif defined(__AVX2__) || defined(__AVX__)
3598
- // Initialize accumulator with zeros
3599
- __m256 acc = _mm256_setzero_ps();
1145
+ GGML_F16_VEC sum[GGML_F16_ARR] = { GGML_F16_VEC_ZERO };
3600
1146
 
3601
- // Main loop
3602
- for (int i = 0; i < nb; ++i) {
3603
- // Compute combined scale for the block
3604
- const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d));
3605
- __m256i bx = _mm256_loadu_si256((const __m256i *)x[i].qs);
3606
- __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
1147
+ GGML_F16_VEC ax[GGML_F16_ARR];
1148
+ GGML_F16_VEC ay[GGML_F16_ARR];
3607
1149
 
3608
- const __m256 q = mul_sum_i8_pairs_float(bx, by);
1150
+ for (int i = 0; i < np; i += GGML_F16_STEP) {
1151
+ for (int j = 0; j < GGML_F16_ARR; j++) {
1152
+ ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j);
1153
+ ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
3609
1154
 
3610
- // Multiply q with scale and accumulate
3611
- #if defined(__AVX2__)
3612
- acc = _mm256_fmadd_ps( d, q, acc );
3613
- #else
3614
- acc = _mm256_add_ps( _mm256_mul_ps( d, q ), acc );
3615
- #endif
1155
+ sum[j] = GGML_F16_VEC_FMA(sum[j], ax[j], ay[j]);
1156
+ }
3616
1157
  }
3617
1158
 
3618
- *s = hsum_float_8(acc);
3619
- #elif defined(__riscv_v_intrinsic)
3620
- float sumf = 0.0;
3621
- size_t vl = __riscv_vsetvl_e8m1(qk);
3622
-
3623
- for (int i = 0; i < nb; i++) {
3624
- // load elements
3625
- vint8m1_t bx = __riscv_vle8_v_i8m1(x[i].qs, vl);
3626
- vint8m1_t by = __riscv_vle8_v_i8m1(y[i].qs, vl);
3627
-
3628
- vint16m2_t vw_mul = __riscv_vwmul_vv_i16m2(bx, by, vl);
3629
-
3630
- vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, vl);
3631
- vint32m1_t v_sum = __riscv_vwredsum_vs_i16m2_i32m1(vw_mul, v_zero, vl);
3632
-
3633
- int sumi = __riscv_vmv_x_s_i32m1_i32(v_sum);
1159
+ // reduce sum0..sum3 to sum0
1160
+ GGML_F16_VEC_REDUCE(sumf, sum);
3634
1161
 
3635
- sumf += sumi*(GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d));
1162
+ // leftovers
1163
+ for (int i = np; i < n; ++i) {
1164
+ sumf += (ggml_float)(GGML_FP16_TO_FP32(x[i])*GGML_FP16_TO_FP32(y[i]));
3636
1165
  }
3637
-
3638
- *s = sumf;
3639
1166
  #else
3640
- // scalar
3641
- float sumf = 0.0;
3642
-
3643
- for (int i = 0; i < nb; i++) {
3644
- int sumi = 0;
3645
-
3646
- for (int j = 0; j < qk; j++) {
3647
- sumi += x[i].qs[j]*y[i].qs[j];
3648
- }
3649
-
3650
- sumf += sumi*(GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d));
1167
+ for (int i = 0; i < n; ++i) {
1168
+ sumf += (ggml_float)(GGML_FP16_TO_FP32(x[i])*GGML_FP16_TO_FP32(y[i]));
3651
1169
  }
1170
+ #endif
3652
1171
 
3653
1172
  *s = sumf;
3654
- #endif
3655
1173
  }
3656
1174
 
3657
1175
  // compute GGML_VEC_DOT_UNROLL dot products at once
@@ -3846,7 +1364,7 @@ inline static float ggml_gelu_f32(float x) {
3846
1364
  inline static void ggml_vec_gelu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
3847
1365
  const uint16_t * i16 = (const uint16_t *) x;
3848
1366
  for (int i = 0; i < n; ++i) {
3849
- y[i] = table_gelu_f16[i16[i]];
1367
+ y[i] = ggml_table_gelu_f16[i16[i]];
3850
1368
  }
3851
1369
  }
3852
1370
 
@@ -3856,7 +1374,7 @@ inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) {
3856
1374
  for (int i = 0; i < n; ++i) {
3857
1375
  ggml_fp16_t fp16 = GGML_FP32_TO_FP16(x[i]);
3858
1376
  memcpy(&t, &fp16, sizeof(uint16_t));
3859
- y[i] = GGML_FP16_TO_FP32(table_gelu_f16[t]);
1377
+ y[i] = GGML_FP16_TO_FP32(ggml_table_gelu_f16[t]);
3860
1378
  }
3861
1379
  }
3862
1380
  #else
@@ -3874,7 +1392,7 @@ inline static float ggml_gelu_quick_f32(float x) {
3874
1392
  //inline static void ggml_vec_gelu_quick_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
3875
1393
  // const uint16_t * i16 = (const uint16_t *) x;
3876
1394
  // for (int i = 0; i < n; ++i) {
3877
- // y[i] = table_gelu_quick_f16[i16[i]];
1395
+ // y[i] = ggml_table_gelu_quick_f16[i16[i]];
3878
1396
  // }
3879
1397
  //}
3880
1398
 
@@ -3884,7 +1402,7 @@ inline static void ggml_vec_gelu_quick_f32(const int n, float * y, const float *
3884
1402
  for (int i = 0; i < n; ++i) {
3885
1403
  ggml_fp16_t fp16 = GGML_FP32_TO_FP16(x[i]);
3886
1404
  memcpy(&t, &fp16, sizeof(uint16_t));
3887
- y[i] = GGML_FP16_TO_FP32(table_gelu_quick_f16[t]);
1405
+ y[i] = GGML_FP16_TO_FP32(ggml_table_gelu_quick_f16[t]);
3888
1406
  }
3889
1407
  }
3890
1408
  #else
@@ -3903,7 +1421,7 @@ inline static float ggml_silu_f32(float x) {
3903
1421
  //inline static void ggml_vec_silu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
3904
1422
  // const uint16_t * i16 = (const uint16_t *) x;
3905
1423
  // for (int i = 0; i < n; ++i) {
3906
- // y[i] = table_silu_f16[i16[i]];
1424
+ // y[i] = ggml_table_silu_f16[i16[i]];
3907
1425
  // }
3908
1426
  //}
3909
1427
 
@@ -3913,7 +1431,7 @@ inline static void ggml_vec_silu_f32(const int n, float * y, const float * x) {
3913
1431
  for (int i = 0; i < n; ++i) {
3914
1432
  ggml_fp16_t fp16 = GGML_FP32_TO_FP16(x[i]);
3915
1433
  memcpy(&t, &fp16, sizeof(uint16_t));
3916
- y[i] = GGML_FP16_TO_FP32(table_silu_f16[t]);
1434
+ y[i] = GGML_FP16_TO_FP32(ggml_table_silu_f16[t]);
3917
1435
  }
3918
1436
  }
3919
1437
  #else
@@ -4629,11 +2147,11 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
4629
2147
  for (int i = 0; i < (1 << 16); ++i) {
4630
2148
  uint16_t ui = i;
4631
2149
  memcpy(&ii, &ui, sizeof(ii));
4632
- const float f = table_f32_f16[i] = GGML_COMPUTE_FP16_TO_FP32(ii);
4633
- table_gelu_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_f32(f));
4634
- table_gelu_quick_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_quick_f32(f));
4635
- table_silu_f16[i] = GGML_FP32_TO_FP16(ggml_silu_f32(f));
4636
- table_exp_f16[i] = GGML_FP32_TO_FP16(expf(f));
2150
+ const float f = ggml_table_f32_f16[i] = GGML_COMPUTE_FP16_TO_FP32(ii);
2151
+ ggml_table_gelu_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_f32(f));
2152
+ ggml_table_gelu_quick_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_quick_f32(f));
2153
+ ggml_table_silu_f16[i] = GGML_FP32_TO_FP16(ggml_silu_f32(f));
2154
+ ggml_table_exp_f16[i] = GGML_FP32_TO_FP16(expf(f));
4637
2155
  }
4638
2156
 
4639
2157
  const uint64_t t_end = ggml_time_us(); UNUSED(t_end);
@@ -5636,7 +3154,7 @@ static struct ggml_tensor * ggml_add_cast_impl(
5636
3154
  // TODO: support less-strict constraint
5637
3155
  // GGML_ASSERT(ggml_can_repeat(b, a));
5638
3156
  GGML_ASSERT(ggml_can_repeat_rows(b, a));
5639
- GGML_ASSERT(ggml_is_quantized(a->type)); // currently only supported for quantized input
3157
+ GGML_ASSERT(ggml_is_quantized(a->type) || a->type == GGML_TYPE_F16); // currently only supported for quantized input and f16
5640
3158
 
5641
3159
  bool is_node = false;
5642
3160
 
@@ -7328,8 +4846,13 @@ static struct ggml_tensor * ggml_rope_impl(
7328
4846
  int n_dims,
7329
4847
  int mode,
7330
4848
  int n_ctx,
4849
+ int n_orig_ctx,
7331
4850
  float freq_base,
7332
4851
  float freq_scale,
4852
+ float ext_factor,
4853
+ float attn_factor,
4854
+ float beta_fast,
4855
+ float beta_slow,
7333
4856
  float xpos_base,
7334
4857
  bool xpos_down,
7335
4858
  bool inplace) {
@@ -7345,11 +4868,15 @@ static struct ggml_tensor * ggml_rope_impl(
7345
4868
 
7346
4869
  struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
7347
4870
 
7348
- int32_t params[8] = { /*n_past*/ 0, n_dims, mode, n_ctx };
7349
- memcpy(params + 4, &freq_base, sizeof(float));
7350
- memcpy(params + 5, &freq_scale, sizeof(float));
7351
- memcpy(params + 6, &xpos_base, sizeof(float));
7352
- memcpy(params + 7, &xpos_down, sizeof(bool));
4871
+ int32_t params[13] = { /*n_past*/ 0, n_dims, mode, n_ctx, n_orig_ctx };
4872
+ memcpy(params + 5, &freq_base, sizeof(float));
4873
+ memcpy(params + 6, &freq_scale, sizeof(float));
4874
+ memcpy(params + 7, &ext_factor, sizeof(float));
4875
+ memcpy(params + 8, &attn_factor, sizeof(float));
4876
+ memcpy(params + 9, &beta_fast, sizeof(float));
4877
+ memcpy(params + 10, &beta_slow, sizeof(float));
4878
+ memcpy(params + 11, &xpos_base, sizeof(float));
4879
+ memcpy(params + 12, &xpos_down, sizeof(bool));
7353
4880
  ggml_set_op_params(result, params, sizeof(params));
7354
4881
 
7355
4882
  result->op = GGML_OP_ROPE;
@@ -7367,7 +4894,9 @@ struct ggml_tensor * ggml_rope(
7367
4894
  int n_dims,
7368
4895
  int mode,
7369
4896
  int n_ctx) {
7370
- return ggml_rope_impl(ctx, a, b, n_dims, mode, n_ctx, 10000.0f, 1.0f, 0.0f, false, false);
4897
+ return ggml_rope_impl(
4898
+ ctx, a, b, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, false
4899
+ );
7371
4900
  }
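In the updated call above, the positional arguments map to n_orig_ctx = 0, freq_base = 10000, freq_scale = 1, ext_factor = 0, attn_factor = 1, beta_fast = beta_slow = 0, with xPos disabled. Since the rope_yarn helper introduced later in this diff only applies the YaRN correction when ext_factor is non-zero, ggml_rope (and the inplace variant below) keeps its previous plain-RoPE behaviour under the new signature.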
7372
4901
 
7373
4902
  struct ggml_tensor * ggml_rope_inplace(
@@ -7377,7 +4906,9 @@ struct ggml_tensor * ggml_rope_inplace(
7377
4906
  int n_dims,
7378
4907
  int mode,
7379
4908
  int n_ctx) {
7380
- return ggml_rope_impl(ctx, a, b, n_dims, mode, n_ctx, 10000.0f, 1.0f, 0.0f, false, true);
4909
+ return ggml_rope_impl(
4910
+ ctx, a, b, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, true
4911
+ );
7381
4912
  }
7382
4913
 
7383
4914
  struct ggml_tensor * ggml_rope_custom(
@@ -7387,9 +4918,17 @@ struct ggml_tensor * ggml_rope_custom(
7387
4918
  int n_dims,
7388
4919
  int mode,
7389
4920
  int n_ctx,
4921
+ int n_orig_ctx,
7390
4922
  float freq_base,
7391
- float freq_scale) {
7392
- return ggml_rope_impl(ctx, a, b, n_dims, mode, n_ctx, freq_base, freq_scale, 0.0f, false, false);
4923
+ float freq_scale,
4924
+ float ext_factor,
4925
+ float attn_factor,
4926
+ float beta_fast,
4927
+ float beta_slow) {
4928
+ return ggml_rope_impl(
4929
+ ctx, a, b, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
4930
+ ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, false
4931
+ );
7393
4932
  }
7394
4933
 
7395
4934
  struct ggml_tensor * ggml_rope_custom_inplace(
@@ -7399,9 +4938,17 @@ struct ggml_tensor * ggml_rope_custom_inplace(
7399
4938
  int n_dims,
7400
4939
  int mode,
7401
4940
  int n_ctx,
4941
+ int n_orig_ctx,
7402
4942
  float freq_base,
7403
- float freq_scale) {
7404
- return ggml_rope_impl(ctx, a, b, n_dims, mode, n_ctx, freq_base, freq_scale, 0.0f, false, true);
4943
+ float freq_scale,
4944
+ float ext_factor,
4945
+ float attn_factor,
4946
+ float beta_fast,
4947
+ float beta_slow) {
4948
+ return ggml_rope_impl(
4949
+ ctx, a, b, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
4950
+ ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, true
4951
+ );
7405
4952
  }
7406
4953
 
7407
4954
  struct ggml_tensor * ggml_rope_xpos_inplace(
@@ -7411,7 +4958,7 @@ struct ggml_tensor * ggml_rope_xpos_inplace(
7411
4958
  int n_dims,
7412
4959
  float base,
7413
4960
  bool down) {
7414
- return ggml_rope_impl(ctx, a, b, n_dims, 0, 0, 10000.0f, 1.0f, base, down, true);
4961
+ return ggml_rope_impl(ctx, a, b, n_dims, 0, 0, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, base, down, true);
7415
4962
  }
7416
4963
 
7417
4964
  // ggml_rope_back
@@ -9410,9 +6957,15 @@ static void ggml_compute_forward_add_f16_f32(
9410
6957
 
9411
6958
  GGML_ASSERT(src0->type == GGML_TYPE_F16);
9412
6959
  GGML_ASSERT(src1->type == GGML_TYPE_F32);
9413
- GGML_ASSERT(dst->type == GGML_TYPE_F16);
9414
6960
 
9415
- GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
6961
+ if (dst->type == GGML_TYPE_F32) {
6962
+ GGML_ASSERT( nb0 == sizeof(float));
6963
+ }
6964
+ else {
6965
+ GGML_ASSERT(dst->type == GGML_TYPE_F16);
6966
+ GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
6967
+ }
6968
+
9416
6969
  GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
9417
6970
 
9418
6971
  // rows per thread
@@ -9423,18 +6976,35 @@ static void ggml_compute_forward_add_f16_f32(
9423
6976
  const int ir1 = MIN(ir0 + dr, nr);
9424
6977
 
9425
6978
  if (nb10 == sizeof(float)) {
9426
- for (int ir = ir0; ir < ir1; ++ir) {
9427
- // src0, src1 and dst are same shape => same indices
9428
- const int i3 = ir/(ne2*ne1);
9429
- const int i2 = (ir - i3*ne2*ne1)/ne1;
9430
- const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
9431
-
9432
- ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
9433
- ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
9434
- float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11);
9435
-
9436
- for (int i = 0; i < ne0; i++) {
9437
- dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + src1_ptr[i]);
6979
+ if (dst->type == GGML_TYPE_F16) {
6980
+ for (int ir = ir0; ir < ir1; ++ir) {
6981
+ // src0, src1 and dst are same shape => same indices
6982
+ const int i3 = ir/(ne2*ne1);
6983
+ const int i2 = (ir - i3*ne2*ne1)/ne1;
6984
+ const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
6985
+
6986
+ ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
6987
+ ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
6988
+ float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11);
6989
+
6990
+ for (int i = 0; i < ne0; i++) {
6991
+ dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + src1_ptr[i]);
6992
+ }
6993
+ }
6994
+ } else {
6995
+ for (int ir = ir0; ir < ir1; ++ir) {
6996
+ // src0, src1 and dst are same shape => same indices
6997
+ const int i3 = ir/(ne2*ne1);
6998
+ const int i2 = (ir - i3*ne2*ne1)/ne1;
6999
+ const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
7000
+
7001
+ float * dst_ptr = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
7002
+ ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
7003
+ float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11);
7004
+
7005
+ for (int i = 0; i < ne0; i++) {
7006
+ dst_ptr[i] = GGML_FP16_TO_FP32(src0_ptr[i]) + src1_ptr[i];
7007
+ }
9438
7008
  }
9439
7009
  }
9440
7010
  }
@@ -12996,7 +10566,7 @@ static void ggml_compute_forward_soft_max_f32(
12996
10566
  // const float val = (sp[i] == -INFINITY) ? 0.0 : exp(sp[i] - max);
12997
10567
  ggml_fp16_t s = GGML_FP32_TO_FP16(sp[i] - max);
12998
10568
  memcpy(&scvt, &s, sizeof(scvt));
12999
- const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt]);
10569
+ const float val = GGML_FP16_TO_FP32(ggml_table_exp_f16[scvt]);
13000
10570
  sum += (ggml_float)val;
13001
10571
  dp[i] = val;
13002
10572
  }
@@ -13361,6 +10931,45 @@ static void ggml_compute_forward_clamp(
13361
10931
 
13362
10932
  // ggml_compute_forward_rope
13363
10933
 
10934
+ static float rope_yarn_ramp(const float low, const float high, const int i0) {
10935
+ const float y = (i0 / 2 - low) / MAX(0.001f, high - low);
10936
+ return 1 - MIN(1, MAX(0, y));
10937
+ }
10938
+
10939
+ // YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
10940
+ // MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
10941
+ static void rope_yarn(
10942
+ float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale,
10943
+ float * cos_theta, float * sin_theta
10944
+ ) {
10945
+ // Get n-d rotational scaling corrected for extrapolation
10946
+ float theta_interp = freq_scale * theta_extrap;
10947
+ float theta = theta_interp;
10948
+ if (ext_factor != 0.0f) {
10949
+ float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
10950
+ theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
10951
+
10952
+ // Get n-d magnitude scaling corrected for interpolation
10953
+ mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale);
10954
+ }
10955
+ *cos_theta = cosf(theta) * mscale;
10956
+ *sin_theta = sinf(theta) * mscale;
10957
+ }
10958
+
10959
+ // Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
10960
+ // `corr_dim(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
10961
+ static float ggml_rope_yarn_corr_dim(int n_dims, int n_orig_ctx, float n_rot, float base) {
10962
+ return n_dims * logf(n_orig_ctx / (n_rot * 2 * (float)M_PI)) / (2 * logf(base));
10963
+ }
10964
+
10965
+ void ggml_rope_yarn_corr_dims(
10966
+ int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]
10967
+ ) {
10968
+ // start and end correction dims
10969
+ dims[0] = MAX(0, floorf(ggml_rope_yarn_corr_dim(n_dims, n_orig_ctx, beta_fast, freq_base)));
10970
+ dims[1] = MIN(n_dims - 1, ceilf(ggml_rope_yarn_corr_dim(n_dims, n_orig_ctx, beta_slow, freq_base)));
10971
+ }
10972
+
13364
10973
  static void ggml_compute_forward_rope_f32(
13365
10974
  const struct ggml_compute_params * params,
13366
10975
  const struct ggml_tensor * src0,
@@ -13370,21 +10979,26 @@ static void ggml_compute_forward_rope_f32(
  return;
  }

- float freq_base;
- float freq_scale;
+ float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;

  // these two only relevant for xPos RoPE:
  float xpos_base;
  bool xpos_down;

- //const int n_past = ((int32_t *) dst->op_params)[0];
- const int n_dims = ((int32_t *) dst->op_params)[1];
- const int mode = ((int32_t *) dst->op_params)[2];
- const int n_ctx = ((int32_t *) dst->op_params)[3];
- memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float));
- memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
- memcpy(&xpos_base, (int32_t *) dst->op_params + 6, sizeof(float));
- memcpy(&xpos_down, (int32_t *) dst->op_params + 7, sizeof(bool));
+ //const int n_past = ((int32_t *) dst->op_params)[0];
+ const int n_dims = ((int32_t *) dst->op_params)[1];
+ const int mode = ((int32_t *) dst->op_params)[2];
+ const int n_ctx = ((int32_t *) dst->op_params)[3];
+ const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
+
+ memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
+ memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
+ memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
+ memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
+ memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
+ memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
+ memcpy(&xpos_base, (int32_t *) dst->op_params + 11, sizeof(float));
+ memcpy(&xpos_down, (int32_t *) dst->op_params + 12, sizeof(bool));

  GGML_TENSOR_UNARY_OP_LOCALS

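For reference, the reads above imply the following dst->op_params layout for the rope ops in this version. The slot names are illustrative only (a sketch derived from the index and memcpy reads in the hunk; ggml itself keeps these as anonymous int32/float slots):

    // Illustrative only: names mirror the reads above, not ggml identifiers.
    enum rope_op_param_slot {
        ROPE_P_N_PAST      = 0,  // unused here (commented out above)
        ROPE_P_N_DIMS      = 1,
        ROPE_P_MODE        = 2,
        ROPE_P_N_CTX       = 3,
        ROPE_P_N_ORIG_CTX  = 4,  // new: original (pre-scaling) context size used by YaRN
        ROPE_P_FREQ_BASE   = 5,  // float, read via memcpy
        ROPE_P_FREQ_SCALE  = 6,  // float
        ROPE_P_EXT_FACTOR  = 7,  // float, YaRN extrapolation mix
        ROPE_P_ATTN_FACTOR = 8,  // float, YaRN magnitude scale
        ROPE_P_BETA_FAST   = 9,  // float, YaRN fast correction beta
        ROPE_P_BETA_SLOW   = 10, // float, YaRN slow correction beta
        ROPE_P_XPOS_BASE   = 11, // float, xPos only
        ROPE_P_XPOS_DOWN   = 12, // bool, xPos only
    };
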
@@ -13412,6 +11026,9 @@ static void ggml_compute_forward_rope_f32(
  int ir = 0;

  const float theta_scale = powf(freq_base, -2.0f/n_dims);
+ const float inv_ndims = -1.f/n_dims;
+ float corr_dims[2];
+ ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);

  const bool is_neox = mode & 2;
  const bool is_glm = mode & 4;
@@ -13425,18 +11042,18 @@ static void ggml_compute_forward_rope_f32(
  if (ir++ < ir0) continue;
  if (ir > ir1) break;

- float theta = freq_scale * (float)p;
+ float theta_base = (float)p;

  if (is_glm) {
- theta = MIN(p, n_ctx - 2);
+ theta_base = MIN(p, n_ctx - 2);
  float block_theta = MAX(p - (n_ctx - 2), 0);
  for (int64_t i0 = 0; i0 < ne0 / 4; i0++) {
- const float cos_theta = cosf(theta);
- const float sin_theta = sinf(theta);
+ const float cos_theta = cosf(theta_base);
+ const float sin_theta = sinf(theta_base);
  const float cos_block_theta = cosf(block_theta);
  const float sin_block_theta = sinf(block_theta);

- theta *= theta_scale;
+ theta_base *= theta_scale;
  block_theta *= theta_scale;

  const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
@@ -13454,13 +11071,16 @@ static void ggml_compute_forward_rope_f32(
  }
  } else if (!is_neox) {
  for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
- const float cos_theta = cosf(theta);
- const float sin_theta = sinf(theta);
+ float cos_theta, sin_theta;
+ rope_yarn(
+ theta_base, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta
+ );
+
  // zeta scaling for xPos only:
  float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), p / xpos_base) : 1.0f;
  if (xpos_down) zeta = 1.0f / zeta;

- theta *= theta_scale;
+ theta_base *= theta_scale;

  const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
  float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
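
To see the new rope_yarn call in isolation, here is a small standalone C sketch. It copies the two helpers added earlier in this diff (MIT-licensed, per the comment above) and evaluates them for a single position; the corr_dims values and the main() driver are made up for illustration and are not part of ggml:

    // Standalone sketch, not part of the release.
    #include <math.h>
    #include <stdio.h>
    #include <stdint.h>

    #define MIN(a, b) ((a) < (b) ? (a) : (b))
    #define MAX(a, b) ((a) > (b) ? (a) : (b))

    static float rope_yarn_ramp(const float low, const float high, const int i0) {
        const float y = (i0 / 2 - low) / MAX(0.001f, high - low);
        return 1 - MIN(1, MAX(0, y));
    }

    static void rope_yarn(
        float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale,
        float * cos_theta, float * sin_theta
    ) {
        float theta_interp = freq_scale * theta_extrap;
        float theta = theta_interp;
        if (ext_factor != 0.0f) {
            float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
            theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
            mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale);
        }
        *cos_theta = cosf(theta) * mscale;
        *sin_theta = sinf(theta) * mscale;
    }

    int main(void) {
        float corr_dims[2] = { 8.0f, 24.0f };  // example correction range, chosen for illustration
        float cs, sn;

        // ext_factor == 0: plain linear position interpolation (theta scaled by freq_scale only)
        rope_yarn(100.0f, 0.25f, corr_dims, 4, 0.0f, 1.0f, &cs, &sn);
        printf("interp only: cos=%f sin=%f\n", cs, sn);

        // ext_factor == 1: low dims (i0/2 below corr_dims[0]) keep the extrapolated angle
        rope_yarn(100.0f, 0.25f, corr_dims, 4, 1.0f, 1.0f, &cs, &sn);
        printf("with yarn:   cos=%f sin=%f\n", cs, sn);
        return 0;
    }
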
@@ -13474,12 +11094,19 @@ static void ggml_compute_forward_rope_f32(
  } else {
  // TODO: this might be wrong for ne0 != n_dims - need double check
  // ref: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py#LL251C1-L294C28
+ theta_base *= freq_scale;
  for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
  for (int64_t ic = 0; ic < n_dims; ic += 2) {
- const float cos_theta = cosf(theta);
- const float sin_theta = sinf(theta);
+ // simplified from `(ib * n_dims + ic) * inv_ndims`
+ float cur_rot = inv_ndims * ic - ib;
+
+ float cos_theta, sin_theta;
+ rope_yarn(
+ theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
+ &cos_theta, &sin_theta
+ );

- theta *= theta_scale;
+ theta_base *= theta_scale;

  const int64_t i0 = ib*n_dims + ic/2;

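The "simplified from" comment above is a plain algebraic rewrite: with inv_ndims = -1/n_dims, inv_ndims*ic - ib equals (ib*n_dims + ic)*inv_ndims. A tiny self-contained check (illustrative only, not part of the release):

    #include <assert.h>
    #include <math.h>

    int main(void) {
        const int n_dims = 128;
        const float inv_ndims = -1.f / n_dims;
        for (int ib = 0; ib < 4; ++ib) {
            for (int ic = 0; ic < n_dims; ic += 2) {
                const float simplified = inv_ndims * ic - ib;
                const float original   = (ib * n_dims + ic) * inv_ndims;
                // equal up to float rounding
                assert(fabsf(simplified - original) < 1e-5f);
            }
        }
        return 0;
    }
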
@@ -13508,15 +11135,19 @@ static void ggml_compute_forward_rope_f16(
  return;
  }

- float freq_base;
- float freq_scale;
+ float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;

- //const int n_past = ((int32_t *) dst->op_params)[0];
- const int n_dims = ((int32_t *) dst->op_params)[1];
- const int mode = ((int32_t *) dst->op_params)[2];
- const int n_ctx = ((int32_t *) dst->op_params)[3];
- memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float));
- memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
+ //const int n_past = ((int32_t *) dst->op_params)[0];
+ const int n_dims = ((int32_t *) dst->op_params)[1];
+ const int mode = ((int32_t *) dst->op_params)[2];
+ const int n_ctx = ((int32_t *) dst->op_params)[3];
+ const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
+ memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
+ memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
+ memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
+ memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
+ memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
+ memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));

  GGML_TENSOR_UNARY_OP_LOCALS

@@ -13544,6 +11175,9 @@ static void ggml_compute_forward_rope_f16(
  int ir = 0;

  const float theta_scale = powf(freq_base, -2.0f/n_dims);
+ const float inv_ndims = -1.f/n_dims;
+ float corr_dims[2];
+ ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);

  const bool is_neox = mode & 2;
  const bool is_glm = mode & 4;
@@ -13557,18 +11191,18 @@ static void ggml_compute_forward_rope_f16(
  if (ir++ < ir0) continue;
  if (ir > ir1) break;

- float theta = freq_scale * (float)p;
+ float theta_base = (float)p;

  if (is_glm) {
- theta = MIN(p, n_ctx - 2);
+ theta_base = MIN(p, n_ctx - 2);
  float block_theta = MAX(p - (n_ctx - 2), 0);
  for (int64_t i0 = 0; i0 < ne0 / 4; i0++) {
- const float cos_theta = cosf(theta);
- const float sin_theta = sinf(theta);
+ const float cos_theta = cosf(theta_base);
+ const float sin_theta = sinf(theta_base);
  const float cos_block_theta = cosf(block_theta);
  const float sin_block_theta = sinf(block_theta);

- theta *= theta_scale;
+ theta_base *= theta_scale;
  block_theta *= theta_scale;

  const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
@@ -13586,10 +11220,12 @@ static void ggml_compute_forward_rope_f16(
  }
  } else if (!is_neox) {
  for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
- const float cos_theta = cosf(theta);
- const float sin_theta = sinf(theta);
+ float cos_theta, sin_theta;
+ rope_yarn(
+ theta_base, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta
+ );

- theta *= theta_scale;
+ theta_base *= theta_scale;

  const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
  ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
@@ -13603,12 +11239,19 @@ static void ggml_compute_forward_rope_f16(
  } else {
  // TODO: this might be wrong for ne0 != n_dims - need double check
  // ref: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py#LL251C1-L294C28
+ theta_base *= freq_scale;
  for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
  for (int64_t ic = 0; ic < n_dims; ic += 2) {
- const float cos_theta = cosf(theta);
- const float sin_theta = sinf(theta);
+ // simplified from `(ib * n_dims + ic) * inv_ndims`
+ float cur_rot = inv_ndims * ic - ib;
+
+ float cos_theta, sin_theta;
+ rope_yarn(
+ theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
+ &cos_theta, &sin_theta
+ );

- theta *= theta_scale;
+ theta_base *= theta_scale;

  const int64_t i0 = ib*n_dims + ic/2;

@@ -13716,17 +11359,18 @@ static void ggml_compute_forward_rope_back_f32(
  if (ir++ < ir0) continue;
  if (ir > ir1) break;

- float theta = freq_scale * (float)p;
+ float theta_base = freq_scale * (float)p;

  if (!is_neox) {
  for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
- const float cos_theta = cosf(theta);
- const float sin_theta = sinf(theta);
+ const float cos_theta = cosf(theta_base);
+ const float sin_theta = sinf(theta_base);
+
  // zeta scaling for xPos only:
  float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), p / xpos_base) : 1.0f;
  if (xpos_down) zeta = 1.0f / zeta;

- theta *= theta_scale;
+ theta_base *= theta_scale;

  const float * const dy = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
  float * dx = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
@@ -13740,10 +11384,10 @@ static void ggml_compute_forward_rope_back_f32(
  } else {
  for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
  for (int64_t ic = 0; ic < n_dims; ic += 2) {
- const float cos_theta = cosf(theta);
- const float sin_theta = sinf(theta);
+ const float cos_theta = cosf(theta_base);
+ const float sin_theta = sinf(theta_base);

- theta *= theta_scale;
+ theta_base *= theta_scale;

  const int64_t i0 = ib*n_dims + ic/2;

@@ -13816,14 +11460,14 @@ static void ggml_compute_forward_rope_back_f16(
  if (ir++ < ir0) continue;
  if (ir > ir1) break;

- float theta = (float)p;
+ float theta_base = (float)p;

  if (!is_neox) {
  for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
- const float cos_theta = cosf(theta);
- const float sin_theta = sinf(theta);
+ const float cos_theta = cosf(theta_base);
+ const float sin_theta = sinf(theta_base);

- theta *= theta_scale;
+ theta_base *= theta_scale;

  const ggml_fp16_t * const dy = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
  ggml_fp16_t * dx = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
@@ -13837,10 +11481,10 @@ static void ggml_compute_forward_rope_back_f16(
  } else {
  for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
  for (int64_t ic = 0; ic < n_dims; ic += 2) {
- const float cos_theta = cosf(theta);
- const float sin_theta = sinf(theta);
+ const float cos_theta = cosf(theta_base);
+ const float sin_theta = sinf(theta_base);

- theta *= theta_scale;
+ theta_base *= theta_scale;

  const int64_t i0 = ib*n_dims + ic/2;

@@ -15285,7 +12929,7 @@ static void ggml_compute_forward_flash_attn_f32(
  #else
  ggml_fp16_t s = GGML_FP32_TO_FP16(SS[j] - max);
  memcpy(&scvt[j], &s, sizeof(uint16_t));
- const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt[j]]);
+ const float val = GGML_FP16_TO_FP32(ggml_table_exp_f16[scvt[j]]);
  #endif
  sump[j] += (ggml_float)val;
  SS[j] = val;
@@ -15487,7 +13131,7 @@ static void ggml_compute_forward_flash_attn_f16(
  } else {
  ggml_fp16_t s = GGML_FP32_TO_FP16(SS[j] - max);
  memcpy(&scvt[j], &s, sizeof(uint16_t));
- const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt[j]]);
+ const float val = GGML_FP16_TO_FP32(ggml_table_exp_f16[scvt[j]]);
  sump[j] += (ggml_float)val;
  SS[j] = val;
  }
@@ -15938,7 +13582,7 @@ static void ggml_compute_forward_flash_attn_back_f32(
  #else
  ggml_fp16_t s = GGML_FP32_TO_FP16(SR[j] - max);
  memcpy(&scvt[j], &s, sizeof(uint16_t));
- const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt[j]]);
+ const float val = GGML_FP16_TO_FP32(ggml_table_exp_f16[scvt[j]]);
  #endif
  sump[j] += (ggml_float)val;
  SW[j] = val;
@@ -16688,7 +14332,7 @@ static void ggml_compute_forward_cross_entropy_loss_f32(
  #else
  ggml_fp16_t s = GGML_FP32_TO_FP16(s0[i] - max);
  memcpy(&scvt, &s, sizeof(scvt));
- const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt]);
+ const float val = GGML_FP16_TO_FP32(ggml_table_exp_f16[scvt]);
  #endif
  sum += (ggml_float)val;
  st[i] = val;
@@ -16802,7 +14446,7 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32(
  #else
  ggml_fp16_t s = GGML_FP32_TO_FP16(s0[i] - max);
  memcpy(&scvt, &s, sizeof(scvt));
- const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt]);
+ const float val = GGML_FP16_TO_FP32(ggml_table_exp_f16[scvt]);
  #endif
  sum += (ggml_float)val;
  ds0[i] = val;
@@ -17965,9 +15609,14 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
  src1,
  n_dims,
  mode,
+ 0,
  n_ctx,
  freq_base,
  freq_scale,
+ 0.0f,
+ 1.0f,
+ 0.0f,
+ 0.0f,
  xpos_base,
  xpos_down,
  false),
@@ -21001,7 +18650,6 @@ size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, i
  block_q8_0 * block = (block_q8_0*)dst + start / QK8_0;
  result = ggml_quantize_q8_0(src + start, block, n, n, hist);
  } break;
- #ifdef GGML_USE_K_QUANTS
  case GGML_TYPE_Q2_K:
  {
  GGML_ASSERT(start % QK_K == 0);
@@ -21032,7 +18680,6 @@ size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, i
  block_q6_K * block = (block_q6_K*)dst + start / QK_K;
  result = ggml_quantize_q6_K(src + start, block, n, n, hist);
  } break;
- #endif
  case GGML_TYPE_F16:
  {
  int elemsize = sizeof(ggml_fp16_t);
@@ -21164,8 +18811,7 @@ static bool gguf_fread_el(FILE * file, void * dst, size_t size, size_t * offset)
  return n == size;
  }

- // NOTE: temporary handling of GGUFv1 >> remove after Oct 2023
- static bool gguf_fread_str_cur(FILE * file, struct gguf_str * p, size_t * offset) {
+ static bool gguf_fread_str(FILE * file, struct gguf_str * p, size_t * offset) {
  p->n = 0;
  p->data = NULL;

@@ -21177,19 +18823,6 @@ static bool gguf_fread_str_cur(FILE * file, struct gguf_str * p, size_t * offset
  return ok;
  }

- static bool gguf_fread_str_v1(FILE * file, struct gguf_str * p, size_t * offset) {
- p->n = 0;
- p->data = NULL;
-
- bool ok = true;
-
- uint32_t n = 0;
- ok = ok && gguf_fread_el(file, &n, sizeof(n), offset); p->data = calloc(n + 1, 1); p->n = n;
- ok = ok && gguf_fread_el(file, p->data, p->n, offset);
-
- return ok;
- }
-
  struct gguf_context * gguf_init_empty(void) {
  struct gguf_context * ctx = GGML_ALIGNED_MALLOC(sizeof(struct gguf_context));

@@ -21248,20 +18881,14 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
  ctx->data = NULL;

  ok = ok && gguf_fread_el(file, &ctx->header.version, sizeof(ctx->header.version), &offset);
+ ok = ok && gguf_fread_el(file, &ctx->header.n_tensors, sizeof(ctx->header.n_tensors), &offset);
+ ok = ok && gguf_fread_el(file, &ctx->header.n_kv, sizeof(ctx->header.n_kv), &offset);

  if (ctx->header.version == 1) {
- // NOTE: temporary handling of GGUFv1 >> remove after Oct 2023
- uint32_t n_tensors = 0;
- uint32_t n_kv = 0;
-
- ok = ok && gguf_fread_el(file, &n_tensors, sizeof(n_tensors), &offset);
- ok = ok && gguf_fread_el(file, &n_kv, sizeof(n_kv), &offset);
-
- ctx->header.n_tensors = n_tensors;
- ctx->header.n_kv = n_kv;
- } else {
- ok = ok && gguf_fread_el(file, &ctx->header.n_tensors, sizeof(ctx->header.n_tensors), &offset);
- ok = ok && gguf_fread_el(file, &ctx->header.n_kv, sizeof(ctx->header.n_kv), &offset);
+ fprintf(stderr, "%s: GGUFv1 is no longer supported. please use a more up-to-date version\n", __func__);
+ fclose(file);
+ gguf_free(ctx);
+ return NULL;
  }

  if (!ok) {
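
The simplified header read above relies on the fact that, from GGUFv2 onward, the tensor and KV counts are fixed 64-bit fields placed right after the magic and the 32-bit version, which is why the 32-bit GGUFv1 path could be dropped and v1 files are now rejected. Below is a minimal reader sketch under that assumption; the struct and function names are illustrative, not ggml's API, and like the raw gguf_fread_el reads above it assumes a little-endian host:

    // Illustrative GGUFv2 header reader; not part of ggml.
    #include <stdint.h>
    #include <stdio.h>
    #include <string.h>

    struct gguf_header_v2 {
        char     magic[4];   // "GGUF"
        uint32_t version;    // 2 or later
        uint64_t n_tensors;  // 64-bit since GGUFv2 (32-bit in v1)
        uint64_t n_kv;       // 64-bit since GGUFv2 (32-bit in v1)
    };

    // Reads only the fixed-size part of the header; returns 0 on failure.
    static int read_gguf_header_v2(FILE * f, struct gguf_header_v2 * h) {
        int ok = 1;
        ok = ok && fread(h->magic,      1, 4,                    f) == 4;
        ok = ok && fread(&h->version,   1, sizeof(h->version),   f) == sizeof(h->version);
        ok = ok && fread(&h->n_tensors, 1, sizeof(h->n_tensors), f) == sizeof(h->n_tensors);
        ok = ok && fread(&h->n_kv,      1, sizeof(h->n_kv),      f) == sizeof(h->n_kv);
        ok = ok && memcmp(h->magic, "GGUF", 4) == 0;
        ok = ok && h->version >= 2;  // v1 files are rejected, as in the hunk above
        return ok;
    }

    int main(int argc, char ** argv) {
        if (argc < 2) { fprintf(stderr, "usage: %s model.gguf\n", argv[0]); return 1; }
        FILE * f = fopen(argv[1], "rb");
        if (!f) { perror("fopen"); return 1; }
        struct gguf_header_v2 h;
        if (!read_gguf_header_v2(f, &h)) { fprintf(stderr, "not a GGUFv2+ file\n"); fclose(f); return 1; }
        printf("version=%u n_tensors=%llu n_kv=%llu\n",
               (unsigned) h.version, (unsigned long long) h.n_tensors, (unsigned long long) h.n_kv);
        fclose(f);
        return 0;
    }
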
@@ -21272,12 +18899,6 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
  }
  }

- // NOTE: temporary handling of GGUFv1 >> remove after Oct 2023
- bool (* gguf_fread_str)(FILE *, struct gguf_str *, size_t *) = gguf_fread_str_cur;
- if (ctx->header.version == 1) {
- gguf_fread_str = gguf_fread_str_v1;
- }
-
  // read the kv pairs
  {
  ctx->kv = malloc(ctx->header.n_kv * sizeof(struct gguf_kv));
@@ -21308,15 +18929,7 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
  case GGUF_TYPE_ARRAY:
  {
  ok = ok && gguf_fread_el(file, &kv->value.arr.type, sizeof(kv->value.arr.type), &offset);
-
- if (ctx->header.version == 1) {
- // NOTE: temporary handling of GGUFv1 >> remove after Oct 2023
- uint32_t n = 0;
- ok = ok && gguf_fread_el(file, &n, sizeof(n), &offset);
- kv->value.arr.n = n;
- } else {
- ok = ok && gguf_fread_el(file, &kv->value.arr.n, sizeof(kv->value.arr.n), &offset);
- }
+ ok = ok && gguf_fread_el(file, &kv->value.arr.n, sizeof(kv->value.arr.n), &offset);

  switch (kv->value.arr.type) {
  case GGUF_TYPE_UINT8:
@@ -21375,14 +18988,7 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
  ok = ok && gguf_fread_str(file, &info->name, &offset);
  ok = ok && gguf_fread_el (file, &info->n_dims, sizeof(info->n_dims), &offset);
  for (uint32_t j = 0; j < info->n_dims; ++j) {
- if (ctx->header.version == 1) {
- // NOTE: temporary handling of GGUFv1 >> remove after Oct 2023
- uint32_t t = 0;
- ok = ok && gguf_fread_el(file, &t, sizeof(t), &offset);
- info->ne[j] = t;
- } else {
- ok = ok && gguf_fread_el(file, &info->ne[j], sizeof(info->ne[j]), &offset);
- }
+ ok = ok && gguf_fread_el(file, &info->ne[j], sizeof(info->ne[j]), &offset);
  }
  ok = ok && gguf_fread_el (file, &info->type, sizeof(info->type), &offset);
  ok = ok && gguf_fread_el (file, &info->offset, sizeof(info->offset), &offset);