llama_cpp 0.9.0 → 0.9.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries. The hunks below all come from the bundled ggml.c source file: between 0.9.0 and 0.9.1 the vendored llama.cpp was updated to a revision that splits the FP16 support code and the quantization kernels out of ggml.c into the new ggml-impl.h and ggml-quants.h headers (and, presumably, a matching ggml-quants.c, which is not shown here).
@@ -1,10 +1,8 @@
  #define _CRT_SECURE_NO_DEPRECATE // Disables ridiculous "unsafe" warnigns on Windows
+ #define _USE_MATH_DEFINES // For M_PI on MSVC

- #include "ggml.h"
-
- #ifdef GGML_USE_K_QUANTS
- #include "k_quants.h"
- #endif
+ #include "ggml-impl.h"
+ #include "ggml-quants.h"

  #if defined(_MSC_VER) || defined(__MINGW32__)
  #include <malloc.h> // using malloc.h with MSC/MINGW
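Two things happen in this first hunk: `_USE_MATH_DEFINES` is added, and the old ggml.h/k_quants.h includes give way to ggml-impl.h and ggml-quants.h, the first sign of the quantization code moving into its own compilation unit (the deletions later in this diff confirm the move). The `_USE_MATH_DEFINES` line matters because MSVC's <math.h> only exposes constants such as M_PI when that macro is defined before the header is included. A minimal standalone sketch of the rule (not code from the package):

    /* Sketch: on MSVC, M_PI is hidden unless _USE_MATH_DEFINES is defined
       before <math.h> is pulled in (GCC/Clang expose it by default). */
    #define _USE_MATH_DEFINES
    #include <math.h>
    #include <stdio.h>

    int main(void) {
        printf("pi = %.17g\n", M_PI); /* fails to compile on MSVC without the define */
        return 0;
    }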
@@ -30,18 +28,6 @@
  #include <unistd.h>
  #endif

- // static_assert should be a #define, but if it's not,
- // fall back to the _Static_assert C11 keyword.
- // if C99 - static_assert is noop
- // ref: https://stackoverflow.com/a/53923785/4039976
- #ifndef static_assert
- #if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 201100L)
- #define static_assert(cond, msg) _Static_assert(cond, msg)
- #else
- #define static_assert(cond, msg) struct global_scope_noop_trick
- #endif
- #endif
-
  #if defined(_MSC_VER)
  // disable "possible loss of data" to avoid hundreds of casts
  // we should just be careful :)
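The `static_assert` shim is deleted outright even though ggml.c keeps using `static_assert`; presumably the fallback now lives in the new ggml-impl.h header, which is not shown in this diff. The pattern it implemented is a common one and worth spelling out; a sketch, assuming the macro did move rather than disappear:

    /* Sketch of the relocated fallback: map static_assert onto the C11
       _Static_assert keyword where available, otherwise onto a harmless
       no-op declaration so pre-C11 builds still compile. */
    #include <stdint.h>

    #ifndef static_assert
    #if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 201112L) /* C11 proper */
    #define static_assert(cond, msg) _Static_assert(cond, msg)
    #else
    #define static_assert(cond, msg) struct global_scope_noop_trick
    #endif
    #endif

    /* typical use: fail the build if a struct layout drifts */
    static_assert(sizeof(uint32_t) == 4, "uint32_t must be 4 bytes");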
@@ -109,23 +95,11 @@ typedef void * thread_ret_t;
  #include <unistd.h>

  #endif
+
  #ifdef GGML_USE_CPU_HBM
  #include <hbwmalloc.h>
  #endif

- // __FMA__ and __F16C__ are not defined in MSVC, however they are implied with AVX2/AVX512
- #if defined(_MSC_VER) && (defined(__AVX2__) || defined(__AVX512F__))
- #ifndef __FMA__
- #define __FMA__
- #endif
- #ifndef __F16C__
- #define __F16C__
- #endif
- #ifndef __SSE3__
- #define __SSE3__
- #endif
- #endif
-
  /*#define GGML_PERF*/
  #define GGML_DEBUG 0
  #define GGML_GELU_FP16
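This shim existed because MSVC never defines __FMA__ or __F16C__ itself, even when /arch:AVX2 guarantees both instruction sets, so any code gated on the GCC/Clang-style macros would silently take the slow path. Its removal here goes hand in hand with the F16C conversion code leaving ggml.c in the next hunk; presumably both now sit beside the relocated SIMD kernels. For context, the kind of code the macro gates (a sketch using the same intrinsic chain that appears in the deleted hunk below):

    /* Sketch: F16C-gated half -> float conversion. Without the MSVC shim,
       __F16C__ is never defined under MSVC and this fast path is skipped. */
    #ifdef __F16C__
    #include <immintrin.h>

    static inline float f16_to_f32(unsigned short h) {
        return _mm_cvtss_f32(_mm_cvtph_ps(_mm_cvtsi32_si128(h)));
    }
    #endif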
@@ -251,228 +225,27 @@ inline static void * ggml_aligned_malloc(size_t size) {
  #include "ggml-opencl.h"
  #endif

- #undef MIN
- #undef MAX
- #define MIN(a, b) ((a) < (b) ? (a) : (b))
- #define MAX(a, b) ((a) > (b) ? (a) : (b))
-
  // floating point type used to accumulate sums
  typedef double ggml_float;

- // 16-bit float
- // on Arm, we use __fp16
- // on x86, we use uint16_t
- #if defined(__ARM_NEON) && !defined(_MSC_VER)
-
- // if YCM cannot find <arm_neon.h>, make a symbolic link to it, for example:
- //
- // $ ln -sfn /Library/Developer/CommandLineTools/usr/lib/clang/13.1.6/include/arm_neon.h ./src/
- //
- #include <arm_neon.h>
-
- #define GGML_COMPUTE_FP16_TO_FP32(x) ((float) (x))
- #define GGML_COMPUTE_FP32_TO_FP16(x) (x)
-
- #define GGML_FP16_TO_FP32(x) ((float) (x))
- #define GGML_FP32_TO_FP16(x) (x)
-
- #else
-
- #ifdef __wasm_simd128__
- #include <wasm_simd128.h>
- #else
- #ifdef __POWER9_VECTOR__
- #include <altivec.h>
- #undef bool
- #define bool _Bool
- #else
- #if defined(_MSC_VER) || defined(__MINGW32__)
- #include <intrin.h>
- #else
- #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) || defined(__SSE3__)
- #if !defined(__riscv)
- #include <immintrin.h>
- #endif
- #endif
- #endif
- #endif
- #endif
-
- #ifdef __riscv_v_intrinsic
- #include <riscv_vector.h>
- #endif
-
- #ifdef __F16C__
-
- #ifdef _MSC_VER
- #define GGML_COMPUTE_FP16_TO_FP32(x) _mm_cvtss_f32(_mm_cvtph_ps(_mm_cvtsi32_si128(x)))
- #define GGML_COMPUTE_FP32_TO_FP16(x) _mm_extract_epi16(_mm_cvtps_ph(_mm_set_ss(x), 0), 0)
- #else
- #define GGML_COMPUTE_FP16_TO_FP32(x) _cvtsh_ss(x)
- #define GGML_COMPUTE_FP32_TO_FP16(x) _cvtss_sh(x, 0)
- #endif
-
- #elif defined(__POWER9_VECTOR__)
-
- #define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x)
- #define GGML_COMPUTE_FP32_TO_FP16(x) ggml_compute_fp32_to_fp16(x)
- /* the inline asm below is about 12% faster than the lookup method */
- #define GGML_FP16_TO_FP32(x) GGML_COMPUTE_FP16_TO_FP32(x)
- #define GGML_FP32_TO_FP16(x) GGML_COMPUTE_FP32_TO_FP16(x)
-
- static inline float ggml_compute_fp16_to_fp32(ggml_fp16_t h) {
- register float f;
- register double d;
- __asm__(
- "mtfprd %0,%2\n"
- "xscvhpdp %0,%0\n"
- "frsp %1,%0\n" :
- /* temp */ "=d"(d),
- /* out */ "=f"(f):
- /* in */ "r"(h));
- return f;
- }
-
- static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) {
- register double d;
- register ggml_fp16_t r;
- __asm__( /* xscvdphp can work on double or single precision */
- "xscvdphp %0,%2\n"
- "mffprd %1,%0\n" :
- /* temp */ "=d"(d),
- /* out */ "=r"(r):
- /* in */ "f"(f));
- return r;
- }
-
- #else
-
- // FP16 <-> FP32
- // ref: https://github.com/Maratyszcza/FP16
-
- static inline float fp32_from_bits(uint32_t w) {
- union {
- uint32_t as_bits;
- float as_value;
- } fp32;
- fp32.as_bits = w;
- return fp32.as_value;
- }
-
- static inline uint32_t fp32_to_bits(float f) {
- union {
- float as_value;
- uint32_t as_bits;
- } fp32;
- fp32.as_value = f;
- return fp32.as_bits;
- }
-
- static inline float ggml_compute_fp16_to_fp32(ggml_fp16_t h) {
- const uint32_t w = (uint32_t) h << 16;
- const uint32_t sign = w & UINT32_C(0x80000000);
- const uint32_t two_w = w + w;
-
- const uint32_t exp_offset = UINT32_C(0xE0) << 23;
- #if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__)
- const float exp_scale = 0x1.0p-112f;
- #else
- const float exp_scale = fp32_from_bits(UINT32_C(0x7800000));
- #endif
- const float normalized_value = fp32_from_bits((two_w >> 4) + exp_offset) * exp_scale;
-
- const uint32_t magic_mask = UINT32_C(126) << 23;
- const float magic_bias = 0.5f;
- const float denormalized_value = fp32_from_bits((two_w >> 17) | magic_mask) - magic_bias;
-
- const uint32_t denormalized_cutoff = UINT32_C(1) << 27;
- const uint32_t result = sign |
- (two_w < denormalized_cutoff ? fp32_to_bits(denormalized_value) : fp32_to_bits(normalized_value));
- return fp32_from_bits(result);
- }
-
- static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) {
- #if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__)
- const float scale_to_inf = 0x1.0p+112f;
- const float scale_to_zero = 0x1.0p-110f;
- #else
- const float scale_to_inf = fp32_from_bits(UINT32_C(0x77800000));
- const float scale_to_zero = fp32_from_bits(UINT32_C(0x08800000));
- #endif
- float base = (fabsf(f) * scale_to_inf) * scale_to_zero;
-
- const uint32_t w = fp32_to_bits(f);
- const uint32_t shl1_w = w + w;
- const uint32_t sign = w & UINT32_C(0x80000000);
- uint32_t bias = shl1_w & UINT32_C(0xFF000000);
- if (bias < UINT32_C(0x71000000)) {
- bias = UINT32_C(0x71000000);
- }
-
- base = fp32_from_bits((bias >> 1) + UINT32_C(0x07800000)) + base;
- const uint32_t bits = fp32_to_bits(base);
- const uint32_t exp_bits = (bits >> 13) & UINT32_C(0x00007C00);
- const uint32_t mantissa_bits = bits & UINT32_C(0x00000FFF);
- const uint32_t nonsign = exp_bits + mantissa_bits;
- return (sign >> 16) | (shl1_w > UINT32_C(0xFF000000) ? UINT16_C(0x7E00) : nonsign);
- }
-
- #define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x)
- #define GGML_COMPUTE_FP32_TO_FP16(x) ggml_compute_fp32_to_fp16(x)
-
- #endif // __F16C__
-
- #endif // __ARM_NEON
-
  //
  // global data
  //

  // precomputed gelu table for f16 (128 KB)
- static ggml_fp16_t table_gelu_f16[1 << 16];
+ static ggml_fp16_t ggml_table_gelu_f16[1 << 16];

  // precomputed quick gelu table for f16 (128 KB)
- static ggml_fp16_t table_gelu_quick_f16[1 << 16];
+ static ggml_fp16_t ggml_table_gelu_quick_f16[1 << 16];

  // precomputed silu table for f16 (128 KB)
- static ggml_fp16_t table_silu_f16[1 << 16];
+ static ggml_fp16_t ggml_table_silu_f16[1 << 16];

  // precomputed exp table for f16 (128 KB)
- static ggml_fp16_t table_exp_f16[1 << 16];
-
- // precomputed f32 table for f16 (256 KB)
- static float table_f32_f16[1 << 16];
-
- #if defined(__ARM_NEON) || defined(__wasm_simd128__)
- #define B1(c,s,n) 0x ## n ## c , 0x ## n ## s
- #define B2(c,s,n) B1(c,s,n ## c), B1(c,s,n ## s)
- #define B3(c,s,n) B2(c,s,n ## c), B2(c,s,n ## s)
- #define B4(c,s,n) B3(c,s,n ## c), B3(c,s,n ## s)
- #define B5(c,s,n) B4(c,s,n ## c), B4(c,s,n ## s)
- #define B6(c,s,n) B5(c,s,n ## c), B5(c,s,n ## s)
- #define B7(c,s,n) B6(c,s,n ## c), B6(c,s,n ## s)
- #define B8(c,s ) B7(c,s, c), B7(c,s, s)
-
- // precomputed tables for expanding 8bits to 8 bytes:
- static const uint64_t table_b2b_0[1 << 8] = { B8(00, 10) }; // ( b) << 4
- static const uint64_t table_b2b_1[1 << 8] = { B8(10, 00) }; // (!b) << 4
- #endif
-
- // On ARM NEON, it's quicker to directly convert x -> x instead of calling into ggml_lookup_fp16_to_fp32,
- // so we define GGML_FP16_TO_FP32 and GGML_FP32_TO_FP16 elsewhere for NEON.
- // This is also true for POWER9.
- #if !defined(GGML_FP16_TO_FP32) || !defined(GGML_FP32_TO_FP16)
-
- inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) {
- uint16_t s;
- memcpy(&s, &f, sizeof(uint16_t));
- return table_f32_f16[s];
- }
-
- #define GGML_FP16_TO_FP32(x) ggml_lookup_fp16_to_fp32(x)
- #define GGML_FP32_TO_FP16(x) GGML_COMPUTE_FP32_TO_FP16(x)
+ static ggml_fp16_t ggml_table_exp_f16[1 << 16];

- #endif
+ // precomputed f32 table for f16 (256 KB) (ggml-impl.h)
+ float ggml_table_f32_f16[1 << 16];

  // note: do not use these inside ggml.c
  // these are meant to be used via the ggml.h API
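This hunk deletes the whole FP16<->FP32 conversion layer (the NEON, F16C, POWER9 inline-asm, and portable bit-twiddling paths) and renames the precomputed activation tables with a ggml_ prefix. Note also that table_f32_f16 becomes ggml_table_f32_f16 and loses static: per the new comment it is declared in ggml-impl.h, so the relocated conversion code in other translation units can keep using it. The trick behind that table: an fp16 value has only 2^16 bit patterns, so every conversion can be precomputed once and a table load replaces per-call bit manipulation. A self-contained sketch of the idea, independent of ggml's actual helpers:

    /* Sketch of the lookup-table strategy: precompute all 65,536 possible
       fp16 -> fp32 conversions once, then convert with a single load. */
    #include <stdint.h>

    static float table_f32_f16_demo[1 << 16];

    void init_table(float (*fp16_to_fp32)(uint16_t)) {
        for (uint32_t i = 0; i < (1u << 16); ++i) {
            table_f32_f16_demo[i] = fp16_to_fp32((uint16_t) i);
        }
    }

    static inline float lookup_fp16_to_fp32(uint16_t h) {
        return table_f32_f16_demo[h]; /* one load instead of bit twiddling */
    }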
@@ -587,3071 +360,816 @@ int64_t ggml_cycles_per_ms(void) {

  static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);

- //
- // quantization
- //
+ static void ggml_vec_dot_f32(const int n, float * restrict s, const float * restrict x, const float * restrict y);
+ static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t * restrict x, ggml_fp16_t * restrict y);

- #define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
-
- #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__)
- // multiply int8_t, add results pairwise twice
- static inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) {
- // Get absolute values of x vectors
- const __m128i ax = _mm_sign_epi8(x, x);
- // Sign the values of the y vectors
- const __m128i sy = _mm_sign_epi8(y, x);
- // Perform multiplication and create 16-bit values
- const __m128i dot = _mm_maddubs_epi16(ax, sy);
- const __m128i ones = _mm_set1_epi16(1);
- return _mm_madd_epi16(ones, dot);
- }
-
- #if __AVX__ || __AVX2__ || __AVX512F__
- // horizontally add 8 floats
- static inline float hsum_float_8(const __m256 x) {
- __m128 res = _mm256_extractf128_ps(x, 1);
- res = _mm_add_ps(res, _mm256_castps256_ps128(x));
- res = _mm_add_ps(res, _mm_movehl_ps(res, res));
- res = _mm_add_ss(res, _mm_movehdup_ps(res));
- return _mm_cvtss_f32(res);
- }
-
- // horizontally add 8 int32_t
- static inline int hsum_i32_8(const __m256i a) {
- const __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(a), _mm256_extractf128_si256(a, 1));
- const __m128i hi64 = _mm_unpackhi_epi64(sum128, sum128);
- const __m128i sum64 = _mm_add_epi32(hi64, sum128);
- const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
- return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
- }
-
- // horizontally add 4 int32_t
- static inline int hsum_i32_4(const __m128i a) {
- const __m128i hi64 = _mm_unpackhi_epi64(a, a);
- const __m128i sum64 = _mm_add_epi32(hi64, a);
- const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
- return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
- }
-
- #if defined(__AVX2__) || defined(__AVX512F__)
- // spread 32 bits to 32 bytes { 0x00, 0xFF }
- static inline __m256i bytes_from_bits_32(const uint8_t * x) {
- uint32_t x32;
- memcpy(&x32, x, sizeof(uint32_t));
- const __m256i shuf_mask = _mm256_set_epi64x(
- 0x0303030303030303, 0x0202020202020202,
- 0x0101010101010101, 0x0000000000000000);
- __m256i bytes = _mm256_shuffle_epi8(_mm256_set1_epi32(x32), shuf_mask);
- const __m256i bit_mask = _mm256_set1_epi64x(0x7fbfdfeff7fbfdfe);
- bytes = _mm256_or_si256(bytes, bit_mask);
- return _mm256_cmpeq_epi8(bytes, _mm256_set1_epi64x(-1));
- }
-
- // Unpack 32 4-bit fields into 32 bytes
- // The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
- static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)
- {
- const __m128i tmp = _mm_loadu_si128((const __m128i *)rsi);
- const __m256i bytes = MM256_SET_M128I(_mm_srli_epi16(tmp, 4), tmp);
- const __m256i lowMask = _mm256_set1_epi8( 0xF );
- return _mm256_and_si256(lowMask, bytes);
- }
-
- // add int16_t pairwise and return as float vector
- static inline __m256 sum_i16_pairs_float(const __m256i x) {
- const __m256i ones = _mm256_set1_epi16(1);
- const __m256i summed_pairs = _mm256_madd_epi16(ones, x);
- return _mm256_cvtepi32_ps(summed_pairs);
- }
-
- static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) {
- #if __AVXVNNI__
- const __m256i zero = _mm256_setzero_si256();
- const __m256i summed_pairs = _mm256_dpbusd_epi32(zero, ax, sy);
- return _mm256_cvtepi32_ps(summed_pairs);
- #else
- // Perform multiplication and create 16-bit values
- const __m256i dot = _mm256_maddubs_epi16(ax, sy);
- return sum_i16_pairs_float(dot);
- #endif
- }
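Everything from MM256_SET_M128I down through these pairwise-multiply helpers is deleted with no replacement in this file; together with the kernels removed further below, they have evidently moved into the new ggml-quants unit. For readers less fluent in intrinsics, here is what the mul_sum_i8_pairs family computes, restated in scalar C (a sketch, flattened to one float; the real AVX helper returns eight 32-bit partial sums in a __m256):

    /* Scalar restatement of the int8 pairwise dot product: the sign /
       maddubs / madd dance above computes these same products, four at a
       time per 32-bit lane. */
    #include <stdint.h>

    static float mul_sum_i8_pairs_scalar(const int8_t x[32], const int8_t y[32]) {
        int32_t sum = 0;
        for (int i = 0; i < 32; ++i) {
            sum += (int32_t) x[i] * (int32_t) y[i];
        }
        return (float) sum;
    }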
+ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
+ [GGML_TYPE_I8] = {
+ .type_name = "i8",
+ .blck_size = 1,
+ .type_size = sizeof(int8_t),
+ .is_quantized = false,
+ },
+ [GGML_TYPE_I16] = {
+ .type_name = "i16",
+ .blck_size = 1,
+ .type_size = sizeof(int16_t),
+ .is_quantized = false,
+ },
+ [GGML_TYPE_I32] = {
+ .type_name = "i32",
+ .blck_size = 1,
+ .type_size = sizeof(int32_t),
+ .is_quantized = false,
+ },
+ [GGML_TYPE_F32] = {
+ .type_name = "f32",
+ .blck_size = 1,
+ .type_size = sizeof(float),
+ .is_quantized = false,
+ .vec_dot = (ggml_vec_dot_t) ggml_vec_dot_f32,
+ .vec_dot_type = GGML_TYPE_F32,
+ },
+ [GGML_TYPE_F16] = {
+ .type_name = "f16",
+ .blck_size = 1,
+ .type_size = sizeof(ggml_fp16_t),
+ .is_quantized = false,
+ .to_float = (ggml_to_float_t) ggml_fp16_to_fp32_row,
+ .from_float = (ggml_from_float_t) ggml_fp32_to_fp16_row,
+ .from_float_reference = (ggml_from_float_t) ggml_fp32_to_fp16_row,
+ .vec_dot = (ggml_vec_dot_t) ggml_vec_dot_f16,
+ .vec_dot_type = GGML_TYPE_F16,
+ },
+ [GGML_TYPE_Q4_0] = {
+ .type_name = "q4_0",
+ .blck_size = QK4_0,
+ .type_size = sizeof(block_q4_0),
+ .is_quantized = true,
+ .to_float = (ggml_to_float_t) dequantize_row_q4_0,
+ .from_float = quantize_row_q4_0,
+ .from_float_reference = (ggml_from_float_t) quantize_row_q4_0_reference,
+ .vec_dot = ggml_vec_dot_q4_0_q8_0,
+ .vec_dot_type = GGML_TYPE_Q8_0,
+ },
+ [GGML_TYPE_Q4_1] = {
+ .type_name = "q4_1",
+ .blck_size = QK4_1,
+ .type_size = sizeof(block_q4_1),
+ .is_quantized = true,
+ .to_float = (ggml_to_float_t) dequantize_row_q4_1,
+ .from_float = quantize_row_q4_1,
+ .from_float_reference = (ggml_from_float_t) quantize_row_q4_1_reference,
+ .vec_dot = ggml_vec_dot_q4_1_q8_1,
+ .vec_dot_type = GGML_TYPE_Q8_1,
+ },
+ [4] = { // GGML_TYPE_Q4_2
+ .type_name = "DEPRECATED",
+ .blck_size = 0,
+ .type_size = 0,
+ .is_quantized = false,
+ .to_float = NULL,
+ .from_float = NULL,
+ .from_float_reference = NULL,
+ .vec_dot = NULL,
+ .vec_dot_type = GGML_TYPE_COUNT,
+ },
+ [5] = { // GGML_TYPE_Q4_3
+ .type_name = "DEPRECATED",
+ .blck_size = 0,
+ .type_size = 0,
+ .is_quantized = false,
+ .to_float = NULL,
+ .from_float = NULL,
+ .from_float_reference = NULL,
+ .vec_dot = NULL,
+ .vec_dot_type = GGML_TYPE_COUNT,
+ },
+ [GGML_TYPE_Q5_0] = {
+ .type_name = "q5_0",
+ .blck_size = QK5_0,
+ .type_size = sizeof(block_q5_0),
+ .is_quantized = true,
+ .to_float = (ggml_to_float_t) dequantize_row_q5_0,
+ .from_float = quantize_row_q5_0,
+ .from_float_reference = (ggml_from_float_t) quantize_row_q5_0_reference,
+ .vec_dot = ggml_vec_dot_q5_0_q8_0,
+ .vec_dot_type = GGML_TYPE_Q8_0,
+ },
+ [GGML_TYPE_Q5_1] = {
+ .type_name = "q5_1",
+ .blck_size = QK5_1,
+ .type_size = sizeof(block_q5_1),
+ .is_quantized = true,
+ .to_float = (ggml_to_float_t) dequantize_row_q5_1,
+ .from_float = quantize_row_q5_1,
+ .from_float_reference = (ggml_from_float_t) quantize_row_q5_1_reference,
+ .vec_dot = ggml_vec_dot_q5_1_q8_1,
+ .vec_dot_type = GGML_TYPE_Q8_1,
+ },
+ [GGML_TYPE_Q8_0] = {
+ .type_name = "q8_0",
+ .blck_size = QK8_0,
+ .type_size = sizeof(block_q8_0),
+ .is_quantized = true,
+ .to_float = (ggml_to_float_t) dequantize_row_q8_0,
+ .from_float = quantize_row_q8_0,
+ .from_float_reference = (ggml_from_float_t) quantize_row_q8_0_reference,
+ .vec_dot = ggml_vec_dot_q8_0_q8_0,
+ .vec_dot_type = GGML_TYPE_Q8_0,
+ },
+ [GGML_TYPE_Q8_1] = {
+ .type_name = "q8_1",
+ .blck_size = QK8_1,
+ .type_size = sizeof(block_q8_1),
+ .is_quantized = true,
+ .from_float = quantize_row_q8_1,
+ .from_float_reference = (ggml_from_float_t) quantize_row_q8_1_reference,
+ .vec_dot_type = GGML_TYPE_Q8_1,
+ },
+ [GGML_TYPE_Q2_K] = {
+ .type_name = "q2_K",
+ .blck_size = QK_K,
+ .type_size = sizeof(block_q2_K),
+ .is_quantized = true,
+ .to_float = (ggml_to_float_t) dequantize_row_q2_K,
+ .from_float = quantize_row_q2_K,
+ .from_float_reference = (ggml_from_float_t) quantize_row_q2_K_reference,
+ .vec_dot = ggml_vec_dot_q2_K_q8_K,
+ .vec_dot_type = GGML_TYPE_Q8_K,
+ },
+ [GGML_TYPE_Q3_K] = {
+ .type_name = "q3_K",
+ .blck_size = QK_K,
+ .type_size = sizeof(block_q3_K),
+ .is_quantized = true,
+ .to_float = (ggml_to_float_t) dequantize_row_q3_K,
+ .from_float = quantize_row_q3_K,
+ .from_float_reference = (ggml_from_float_t) quantize_row_q3_K_reference,
+ .vec_dot = ggml_vec_dot_q3_K_q8_K,
+ .vec_dot_type = GGML_TYPE_Q8_K,
+ },
+ [GGML_TYPE_Q4_K] = {
+ .type_name = "q4_K",
+ .blck_size = QK_K,
+ .type_size = sizeof(block_q4_K),
+ .is_quantized = true,
+ .to_float = (ggml_to_float_t) dequantize_row_q4_K,
+ .from_float = quantize_row_q4_K,
+ .from_float_reference = (ggml_from_float_t) quantize_row_q4_K_reference,
+ .vec_dot = ggml_vec_dot_q4_K_q8_K,
+ .vec_dot_type = GGML_TYPE_Q8_K,
+ },
+ [GGML_TYPE_Q5_K] = {
+ .type_name = "q5_K",
+ .blck_size = QK_K,
+ .type_size = sizeof(block_q5_K),
+ .is_quantized = true,
+ .to_float = (ggml_to_float_t) dequantize_row_q5_K,
+ .from_float = quantize_row_q5_K,
+ .from_float_reference = (ggml_from_float_t) quantize_row_q5_K_reference,
+ .vec_dot = ggml_vec_dot_q5_K_q8_K,
+ .vec_dot_type = GGML_TYPE_Q8_K,
+ },
+ [GGML_TYPE_Q6_K] = {
+ .type_name = "q6_K",
+ .blck_size = QK_K,
+ .type_size = sizeof(block_q6_K),
+ .is_quantized = true,
+ .to_float = (ggml_to_float_t) dequantize_row_q6_K,
+ .from_float = quantize_row_q6_K,
+ .from_float_reference = (ggml_from_float_t) quantize_row_q6_K_reference,
+ .vec_dot = ggml_vec_dot_q6_K_q8_K,
+ .vec_dot_type = GGML_TYPE_Q8_K,
+ },
+ [GGML_TYPE_Q8_K] = {
+ .type_name = "q8_K",
+ .blck_size = QK_K,
+ .type_size = sizeof(block_q8_K),
+ .is_quantized = true,
+ .from_float = quantize_row_q8_K,
+ }
+ };

- // multiply int8_t, add results pairwise twice and return as float vector
- static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
- #if __AVXVNNIINT8__
- const __m256i zero = _mm256_setzero_si256();
- const __m256i summed_pairs = _mm256_dpbssd_epi32(zero, x, y);
- return _mm256_cvtepi32_ps(summed_pairs);
- #else
- // Get absolute values of x vectors
- const __m256i ax = _mm256_sign_epi8(x, x);
- // Sign the values of the y vectors
- const __m256i sy = _mm256_sign_epi8(y, x);
- return mul_sum_us8_pairs_float(ax, sy);
- #endif
+ // For internal test use
+ ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type) {
+ GGML_ASSERT(type < GGML_TYPE_COUNT);
+ return type_traits[type];
  }

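The type traits table moves up in the file (its old copy, which sat below the quantization kernels, is deleted at the end of this diff) and now relies on functions declared in ggml-quants.h. Each entry makes a tensor type self-describing: blck_size elements pack into type_size bytes, and vec_dot_type names the format the other operand should be converted to for dot products (q4_0 weights are dotted against q8_0 activations, q4_1 against q8_1, the K-quants against q8_K); the [4] and [5] slots keep the long-deprecated Q4_2/Q4_3 IDs reserved. A sketch of the size arithmetic the traits enable; the helper below is illustrative, not ggml's API:

    /* Illustrative only (not ggml's API): storage for n elements of a type,
       derived from its traits. n must be a multiple of the block size. */
    #include <assert.h>
    #include <stddef.h>

    typedef struct {
        const char * type_name;
        int          blck_size;  /* elements per block */
        size_t       type_size;  /* bytes per block */
    } traits_t;

    static size_t row_bytes(traits_t t, int n) {
        assert(n % t.blck_size == 0);
        return (size_t) (n / t.blck_size) * t.type_size;
    }
    /* e.g. q4_0 (blck_size 32, type_size 18): 4096 elements -> 2304 bytes */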
- static inline __m128i packNibbles( __m256i bytes )
- {
- // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
- #if __AVX512F__
- const __m256i bytes_srli_4 = _mm256_srli_epi16(bytes, 4); // 0000_0000_abcd_0000
- bytes = _mm256_or_si256(bytes, bytes_srli_4); // 0000_abcd_abcd_efgh
- return _mm256_cvtepi16_epi8(bytes); // abcd_efgh
- #else
- const __m256i lowByte = _mm256_set1_epi16( 0xFF );
- __m256i high = _mm256_andnot_si256( lowByte, bytes );
- __m256i low = _mm256_and_si256( lowByte, bytes );
- high = _mm256_srli_epi16( high, 4 );
- bytes = _mm256_or_si256( low, high );
-
- // Compress uint16_t lanes into bytes
- __m128i r0 = _mm256_castsi256_si128( bytes );
- __m128i r1 = _mm256_extracti128_si256( bytes, 1 );
- return _mm_packus_epi16( r0, r1 );
- #endif
- }
- #elif defined(__AVX__)
- // spread 32 bits to 32 bytes { 0x00, 0xFF }
- static inline __m256i bytes_from_bits_32(const uint8_t * x) {
- uint32_t x32;
- memcpy(&x32, x, sizeof(uint32_t));
- const __m128i shuf_maskl = _mm_set_epi64x(0x0101010101010101, 0x0000000000000000);
- const __m128i shuf_maskh = _mm_set_epi64x(0x0303030303030303, 0x0202020202020202);
- __m128i bytesl = _mm_shuffle_epi8(_mm_set1_epi32(x32), shuf_maskl);
- __m128i bytesh = _mm_shuffle_epi8(_mm_set1_epi32(x32), shuf_maskh);
- const __m128i bit_mask = _mm_set1_epi64x(0x7fbfdfeff7fbfdfe);
- bytesl = _mm_or_si128(bytesl, bit_mask);
- bytesh = _mm_or_si128(bytesh, bit_mask);
- bytesl = _mm_cmpeq_epi8(bytesl, _mm_set1_epi64x(-1));
- bytesh = _mm_cmpeq_epi8(bytesh, _mm_set1_epi64x(-1));
- return MM256_SET_M128I(bytesh, bytesl);
- }
-
- // Unpack 32 4-bit fields into 32 bytes
- // The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
- static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)
- {
- // Load 16 bytes from memory
- __m128i tmpl = _mm_loadu_si128((const __m128i *)rsi);
- __m128i tmph = _mm_srli_epi16(tmpl, 4);
- const __m128i lowMask = _mm_set1_epi8(0xF);
- tmpl = _mm_and_si128(lowMask, tmpl);
- tmph = _mm_and_si128(lowMask, tmph);
- return MM256_SET_M128I(tmph, tmpl);
- }
-
- // add int16_t pairwise and return as float vector
- static inline __m256 sum_i16_pairs_float(const __m128i xh, const __m128i xl) {
- const __m128i ones = _mm_set1_epi16(1);
- const __m128i summed_pairsl = _mm_madd_epi16(ones, xl);
- const __m128i summed_pairsh = _mm_madd_epi16(ones, xh);
- const __m256i summed_pairs = MM256_SET_M128I(summed_pairsh, summed_pairsl);
- return _mm256_cvtepi32_ps(summed_pairs);
- }
-
- static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) {
- const __m128i axl = _mm256_castsi256_si128(ax);
- const __m128i axh = _mm256_extractf128_si256(ax, 1);
- const __m128i syl = _mm256_castsi256_si128(sy);
- const __m128i syh = _mm256_extractf128_si256(sy, 1);
- // Perform multiplication and create 16-bit values
- const __m128i dotl = _mm_maddubs_epi16(axl, syl);
- const __m128i doth = _mm_maddubs_epi16(axh, syh);
- return sum_i16_pairs_float(doth, dotl);
- }
-
- // multiply int8_t, add results pairwise twice and return as float vector
- static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
- const __m128i xl = _mm256_castsi256_si128(x);
- const __m128i xh = _mm256_extractf128_si256(x, 1);
- const __m128i yl = _mm256_castsi256_si128(y);
- const __m128i yh = _mm256_extractf128_si256(y, 1);
- // Get absolute values of x vectors
- const __m128i axl = _mm_sign_epi8(xl, xl);
- const __m128i axh = _mm_sign_epi8(xh, xh);
- // Sign the values of the y vectors
- const __m128i syl = _mm_sign_epi8(yl, xl);
- const __m128i syh = _mm_sign_epi8(yh, xh);
- // Perform multiplication and create 16-bit values
- const __m128i dotl = _mm_maddubs_epi16(axl, syl);
- const __m128i doth = _mm_maddubs_epi16(axh, syh);
- return sum_i16_pairs_float(doth, dotl);
- }
-
- static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
- {
- // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
- const __m128i lowByte = _mm_set1_epi16( 0xFF );
- __m128i high = _mm_andnot_si128( lowByte, bytes1 );
- __m128i low = _mm_and_si128( lowByte, bytes1 );
- high = _mm_srli_epi16( high, 4 );
- bytes1 = _mm_or_si128( low, high );
- high = _mm_andnot_si128( lowByte, bytes2 );
- low = _mm_and_si128( lowByte, bytes2 );
- high = _mm_srli_epi16( high, 4 );
- bytes2 = _mm_or_si128( low, high );
-
- return _mm_packus_epi16( bytes1, bytes2);
- }
- #endif
- #elif defined(__SSSE3__)
- // horizontally add 4x4 floats
- static inline float hsum_float_4x4(const __m128 a, const __m128 b, const __m128 c, const __m128 d) {
- __m128 res_0 =_mm_hadd_ps(a, b);
- __m128 res_1 =_mm_hadd_ps(c, d);
- __m128 res =_mm_hadd_ps(res_0, res_1);
- res =_mm_hadd_ps(res, res);
- res =_mm_hadd_ps(res, res);
-
- return _mm_cvtss_f32(res);
- }
- #endif // __AVX__ || __AVX2__ || __AVX512F__
- #endif // defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__)
-
- #if defined(__ARM_NEON)
-
- #if !defined(__aarch64__)
-
- inline static int32_t vaddvq_s32(int32x4_t v) {
- return vgetq_lane_s32(v, 0) + vgetq_lane_s32(v, 1) + vgetq_lane_s32(v, 2) + vgetq_lane_s32(v, 3);
- }
-
- inline static float vaddvq_f32(float32x4_t v) {
- return vgetq_lane_f32(v, 0) + vgetq_lane_f32(v, 1) + vgetq_lane_f32(v, 2) + vgetq_lane_f32(v, 3);
- }
-
- inline static float vmaxvq_f32(float32x4_t v) {
- return
- MAX(MAX(vgetq_lane_f32(v, 0), vgetq_lane_f32(v, 1)),
- MAX(vgetq_lane_f32(v, 2), vgetq_lane_f32(v, 3)));
- }
-
- inline static int32x4_t vcvtnq_s32_f32(float32x4_t v) {
- int32x4_t res;
-
- res[0] = roundf(vgetq_lane_f32(v, 0));
- res[1] = roundf(vgetq_lane_f32(v, 1));
- res[2] = roundf(vgetq_lane_f32(v, 2));
- res[3] = roundf(vgetq_lane_f32(v, 3));
-
- return res;
- }
-
- #endif
- #endif
-
- #define QK4_0 32
- typedef struct {
- ggml_fp16_t d; // delta
- uint8_t qs[QK4_0 / 2]; // nibbles / quants
- } block_q4_0;
- static_assert(sizeof(block_q4_0) == sizeof(ggml_fp16_t) + QK4_0 / 2, "wrong q4_0 block size/padding");
-
- #define QK4_1 32
- typedef struct {
- ggml_fp16_t d; // delta
- ggml_fp16_t m; // min
- uint8_t qs[QK4_1 / 2]; // nibbles / quants
- } block_q4_1;
- static_assert(sizeof(block_q4_1) == 2 * sizeof(ggml_fp16_t) + QK4_1 / 2, "wrong q4_1 block size/padding");
-
- #define QK5_0 32
- typedef struct {
- ggml_fp16_t d; // delta
- uint8_t qh[4]; // 5-th bit of quants
- uint8_t qs[QK5_0 / 2]; // nibbles / quants
- } block_q5_0;
- static_assert(sizeof(block_q5_0) == sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_0 / 2, "wrong q5_0 block size/padding");
-
- #define QK5_1 32
- typedef struct {
- ggml_fp16_t d; // delta
- ggml_fp16_t m; // min
- uint8_t qh[4]; // 5-th bit of quants
- uint8_t qs[QK5_1 / 2]; // nibbles / quants
- } block_q5_1;
- static_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_1 / 2, "wrong q5_1 block size/padding");
-
- #define QK8_0 32
- typedef struct {
- ggml_fp16_t d; // delta
- int8_t qs[QK8_0]; // quants
- } block_q8_0;
- static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 block size/padding");
-
- #define QK8_1 32
- typedef struct {
- float d; // delta
- float s; // d * sum(qs[i])
- int8_t qs[QK8_1]; // quants
- } block_q8_1;
- static_assert(sizeof(block_q8_1) == 2*sizeof(float) + QK8_1, "wrong q8_1 block size/padding");
-
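These block layouts leave ggml.c with the rest of the quantization code, but they are what the type_size entries in the table above refer to, and they fix the effective bits per weight. A quick check of the arithmetic, as a sketch (fp16_t stands in for ggml_fp16_t):

    /* Storage per 32-element block, from the struct layouts above:
         q4_0: 2 (fp16 d)          + 16 (32 nibbles) = 18 bytes -> 4.5 bits/weight
         q5_0: 2 (fp16 d) + 4 (qh) + 16              = 22 bytes -> 5.5 bits/weight
         q8_0: 2 (fp16 d)          + 32 (int8 qs)    = 34 bytes -> 8.5 bits/weight */
    #include <stdint.h>

    typedef uint16_t fp16_t; /* stand-in for ggml_fp16_t */
    typedef struct { fp16_t d; uint8_t qs[16]; } demo_block_q4_0;

    _Static_assert(sizeof(demo_block_q4_0) == 18, "18 bytes per 32 weights");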
- // reference implementation for deterministic creation of model files
- static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k) {
- static const int qk = QK4_0;
-
- assert(k % qk == 0);
-
- const int nb = k / qk;
-
- for (int i = 0; i < nb; i++) {
- float amax = 0.0f; // absolute max
- float max = 0.0f;
-
- for (int j = 0; j < qk; j++) {
- const float v = x[i*qk + j];
- if (amax < fabsf(v)) {
- amax = fabsf(v);
- max = v;
- }
- }
-
- const float d = max / -8;
- const float id = d ? 1.0f/d : 0.0f;
-
- y[i].d = GGML_FP32_TO_FP16(d);
-
- for (int j = 0; j < qk/2; ++j) {
- const float x0 = x[i*qk + 0 + j]*id;
- const float x1 = x[i*qk + qk/2 + j]*id;
-
- const uint8_t xi0 = MIN(15, (int8_t)(x0 + 8.5f));
- const uint8_t xi1 = MIN(15, (int8_t)(x1 + 8.5f));
-
- y[i].qs[j] = xi0;
- y[i].qs[j] |= xi1 << 4;
- }
- }
- }
-
- static void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) {
- quantize_row_q4_0_reference(x, y, k);
- }
-
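The reference quantizer picks the signed value with the largest magnitude, maps it to -8 (hence d = max / -8), and stores every weight as a 4-bit code centered on 8; the "+ 8.5f" performs the shift and the rounding in one step. A worked round trip under those rules, as a runnable sketch (the MIN(15, ...) clamp is omitted here; it only matters at the opposite end of the range):

    #include <assert.h>
    #include <stdint.h>

    int main(void) {
        const float x  = -4.0f;                  /* block's largest-magnitude value */
        const float d  = x / -8;                 /* 0.5  */
        const float id = 1.0f / d;               /* 2.0  */
        const uint8_t q = (uint8_t)(int8_t)(x*id + 8.5f); /* (-8.0 + 8.5) -> 0 */
        const float back = (q - 8) * d;          /* (0 - 8) * 0.5 = -4.0 */
        assert(back == x);                       /* the extreme round-trips exactly */
        return 0;
    }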
- static void quantize_row_q4_1_reference(const float * restrict x, block_q4_1 * restrict y, int k) {
- const int qk = QK4_1;
-
- assert(k % qk == 0);
-
- const int nb = k / qk;
-
- for (int i = 0; i < nb; i++) {
- float min = FLT_MAX;
- float max = -FLT_MAX;
-
- for (int j = 0; j < qk; j++) {
- const float v = x[i*qk + j];
-
- if (v < min) min = v;
- if (v > max) max = v;
- }
-
- const float d = (max - min) / ((1 << 4) - 1);
- const float id = d ? 1.0f/d : 0.0f;
-
- y[i].d = GGML_FP32_TO_FP16(d);
- y[i].m = GGML_FP32_TO_FP16(min);
-
- for (int j = 0; j < qk/2; ++j) {
- const float x0 = (x[i*qk + 0 + j] - min)*id;
- const float x1 = (x[i*qk + qk/2 + j] - min)*id;
-
- const uint8_t xi0 = MIN(15, (int8_t)(x0 + 0.5f));
- const uint8_t xi1 = MIN(15, (int8_t)(x1 + 0.5f));
-
- y[i].qs[j] = xi0;
- y[i].qs[j] |= xi1 << 4;
- }
- }
- }
-
- static void quantize_row_q4_1(const float * restrict x, void * restrict y, int k) {
- quantize_row_q4_1_reference(x, y, k);
- }
-
- static void quantize_row_q5_0_reference(const float * restrict x, block_q5_0 * restrict y, int k) {
- static const int qk = QK5_0;
-
- assert(k % qk == 0);
-
- const int nb = k / qk;
-
- for (int i = 0; i < nb; i++) {
- float amax = 0.0f; // absolute max
- float max = 0.0f;
-
- for (int j = 0; j < qk; j++) {
- const float v = x[i*qk + j];
- if (amax < fabsf(v)) {
- amax = fabsf(v);
- max = v;
- }
- }
-
- const float d = max / -16;
- const float id = d ? 1.0f/d : 0.0f;
-
- y[i].d = GGML_FP32_TO_FP16(d);
-
- uint32_t qh = 0;
-
- for (int j = 0; j < qk/2; ++j) {
- const float x0 = x[i*qk + 0 + j]*id;
- const float x1 = x[i*qk + qk/2 + j]*id;
-
- const uint8_t xi0 = MIN(31, (int8_t)(x0 + 16.5f));
- const uint8_t xi1 = MIN(31, (int8_t)(x1 + 16.5f));
-
- y[i].qs[j] = (xi0 & 0x0F) | ((xi1 & 0x0F) << 4);
-
- // get the 5-th bit and store it in qh at the right position
- qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
- qh |= ((xi1 & 0x10u) >> 4) << (j + qk/2);
- }
-
- memcpy(&y[i].qh, &qh, sizeof(qh));
- }
- }
-
- static void quantize_row_q5_0(const float * restrict x, void * restrict y, int k) {
- quantize_row_q5_0_reference(x, y, k);
- }
-
- static void quantize_row_q5_1_reference(const float * restrict x, block_q5_1 * restrict y, int k) {
- const int qk = QK5_1;
-
- assert(k % qk == 0);
-
- const int nb = k / qk;
-
- for (int i = 0; i < nb; i++) {
- float min = FLT_MAX;
- float max = -FLT_MAX;
-
- for (int j = 0; j < qk; j++) {
- const float v = x[i*qk + j];
-
- if (v < min) min = v;
- if (v > max) max = v;
- }
-
- const float d = (max - min) / ((1 << 5) - 1);
- const float id = d ? 1.0f/d : 0.0f;
-
- y[i].d = GGML_FP32_TO_FP16(d);
- y[i].m = GGML_FP32_TO_FP16(min);
-
- uint32_t qh = 0;
-
- for (int j = 0; j < qk/2; ++j) {
- const float x0 = (x[i*qk + 0 + j] - min)*id;
- const float x1 = (x[i*qk + qk/2 + j] - min)*id;
-
- const uint8_t xi0 = (uint8_t)(x0 + 0.5f);
- const uint8_t xi1 = (uint8_t)(x1 + 0.5f);
-
- y[i].qs[j] = (xi0 & 0x0F) | ((xi1 & 0x0F) << 4);
-
- // get the 5-th bit and store it in qh at the right position
- qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
- qh |= ((xi1 & 0x10u) >> 4) << (j + qk/2);
- }
-
- memcpy(&y[i].qh, &qh, sizeof(y[i].qh));
- }
- }
-
- static void quantize_row_q5_1(const float * restrict x, void * restrict y, int k) {
- quantize_row_q5_1_reference(x, y, k);
- }
-
- // reference implementation for deterministic creation of model files
- static void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * restrict y, int k) {
- assert(k % QK8_0 == 0);
- const int nb = k / QK8_0;
-
- for (int i = 0; i < nb; i++) {
- float amax = 0.0f; // absolute max
-
- for (int j = 0; j < QK8_0; j++) {
- const float v = x[i*QK8_0 + j];
- amax = MAX(amax, fabsf(v));
- }
-
- const float d = amax / ((1 << 7) - 1);
- const float id = d ? 1.0f/d : 0.0f;
-
- y[i].d = GGML_FP32_TO_FP16(d);
-
- for (int j = 0; j < QK8_0; ++j) {
- const float x0 = x[i*QK8_0 + j]*id;
-
- y[i].qs[j] = roundf(x0);
- }
- }
- }
-
- static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int k) {
- assert(QK8_0 == 32);
- assert(k % QK8_0 == 0);
- const int nb = k / QK8_0;
-
- block_q8_0 * restrict y = vy;
-
- #if defined(__ARM_NEON)
- for (int i = 0; i < nb; i++) {
- float32x4_t srcv [8];
- float32x4_t asrcv[8];
- float32x4_t amaxv[8];
-
- for (int j = 0; j < 8; j++) srcv[j] = vld1q_f32(x + i*32 + 4*j);
- for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[j]);
-
- for (int j = 0; j < 4; j++) amaxv[2*j] = vmaxq_f32(asrcv[2*j], asrcv[2*j+1]);
- for (int j = 0; j < 2; j++) amaxv[4*j] = vmaxq_f32(amaxv[4*j], amaxv[4*j+2]);
- for (int j = 0; j < 1; j++) amaxv[8*j] = vmaxq_f32(amaxv[8*j], amaxv[8*j+4]);
-
- const float amax = vmaxvq_f32(amaxv[0]);
-
- const float d = amax / ((1 << 7) - 1);
- const float id = d ? 1.0f/d : 0.0f;
-
- y[i].d = GGML_FP32_TO_FP16(d);
-
- for (int j = 0; j < 8; j++) {
- const float32x4_t v = vmulq_n_f32(srcv[j], id);
- const int32x4_t vi = vcvtnq_s32_f32(v);
-
- y[i].qs[4*j + 0] = vgetq_lane_s32(vi, 0);
- y[i].qs[4*j + 1] = vgetq_lane_s32(vi, 1);
- y[i].qs[4*j + 2] = vgetq_lane_s32(vi, 2);
- y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3);
- }
- }
- #elif defined(__wasm_simd128__)
- for (int i = 0; i < nb; i++) {
- v128_t srcv [8];
- v128_t asrcv[8];
- v128_t amaxv[8];
-
- for (int j = 0; j < 8; j++) srcv[j] = wasm_v128_load(x + i*32 + 4*j);
- for (int j = 0; j < 8; j++) asrcv[j] = wasm_f32x4_abs(srcv[j]);
-
- for (int j = 0; j < 4; j++) amaxv[2*j] = wasm_f32x4_max(asrcv[2*j], asrcv[2*j+1]);
- for (int j = 0; j < 2; j++) amaxv[4*j] = wasm_f32x4_max(amaxv[4*j], amaxv[4*j+2]);
- for (int j = 0; j < 1; j++) amaxv[8*j] = wasm_f32x4_max(amaxv[8*j], amaxv[8*j+4]);
-
- const float amax = MAX(MAX(wasm_f32x4_extract_lane(amaxv[0], 0),
- wasm_f32x4_extract_lane(amaxv[0], 1)),
- MAX(wasm_f32x4_extract_lane(amaxv[0], 2),
- wasm_f32x4_extract_lane(amaxv[0], 3)));
-
- const float d = amax / ((1 << 7) - 1);
- const float id = d ? 1.0f/d : 0.0f;
-
- y[i].d = GGML_FP32_TO_FP16(d);
-
- for (int j = 0; j < 8; j++) {
- const v128_t v = wasm_f32x4_mul(srcv[j], wasm_f32x4_splat(id));
- const v128_t vi = wasm_i32x4_trunc_sat_f32x4(v);
-
- y[i].qs[4*j + 0] = wasm_i32x4_extract_lane(vi, 0);
- y[i].qs[4*j + 1] = wasm_i32x4_extract_lane(vi, 1);
- y[i].qs[4*j + 2] = wasm_i32x4_extract_lane(vi, 2);
- y[i].qs[4*j + 3] = wasm_i32x4_extract_lane(vi, 3);
- }
- }
- #elif defined(__AVX2__) || defined(__AVX__)
- for (int i = 0; i < nb; i++) {
- // Load elements into 4 AVX vectors
- __m256 v0 = _mm256_loadu_ps( x );
- __m256 v1 = _mm256_loadu_ps( x + 8 );
- __m256 v2 = _mm256_loadu_ps( x + 16 );
- __m256 v3 = _mm256_loadu_ps( x + 24 );
- x += 32;
-
- // Compute max(abs(e)) for the block
- const __m256 signBit = _mm256_set1_ps( -0.0f );
- __m256 maxAbs = _mm256_andnot_ps( signBit, v0 );
- maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) );
- maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );
- maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) );
-
- __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
- max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
- max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
- const float maxScalar = _mm_cvtss_f32( max4 );
-
- // Quantize these floats
- const float d = maxScalar / 127.f;
- y[i].d = GGML_FP32_TO_FP16(d);
- const float id = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f;
- const __m256 mul = _mm256_set1_ps( id );
-
- // Apply the multiplier
- v0 = _mm256_mul_ps( v0, mul );
- v1 = _mm256_mul_ps( v1, mul );
- v2 = _mm256_mul_ps( v2, mul );
- v3 = _mm256_mul_ps( v3, mul );
-
- // Round to nearest integer
- v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST );
- v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST );
- v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST );
- v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST );
-
- // Convert floats to integers
- __m256i i0 = _mm256_cvtps_epi32( v0 );
- __m256i i1 = _mm256_cvtps_epi32( v1 );
- __m256i i2 = _mm256_cvtps_epi32( v2 );
- __m256i i3 = _mm256_cvtps_epi32( v3 );
-
- #if defined(__AVX2__)
- // Convert int32 to int16
- i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15
- i2 = _mm256_packs_epi32( i2, i3 ); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31
- // Convert int16 to int8
- i0 = _mm256_packs_epi16( i0, i2 ); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31
-
- // We got our precious signed bytes, but the order is now wrong
- // These AVX2 pack instructions process 16-byte pieces independently
- // The following instruction is fixing the order
- const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
- i0 = _mm256_permutevar8x32_epi32( i0, perm );
-
- _mm256_storeu_si256((__m256i *)y[i].qs, i0);
- #else
- // Since we don't have in AVX some necessary functions,
- // we split the registers in half and call AVX2 analogs from SSE
- __m128i ni0 = _mm256_castsi256_si128( i0 );
- __m128i ni1 = _mm256_extractf128_si256( i0, 1);
- __m128i ni2 = _mm256_castsi256_si128( i1 );
- __m128i ni3 = _mm256_extractf128_si256( i1, 1);
- __m128i ni4 = _mm256_castsi256_si128( i2 );
- __m128i ni5 = _mm256_extractf128_si256( i2, 1);
- __m128i ni6 = _mm256_castsi256_si128( i3 );
- __m128i ni7 = _mm256_extractf128_si256( i3, 1);
-
- // Convert int32 to int16
- ni0 = _mm_packs_epi32( ni0, ni1 );
- ni2 = _mm_packs_epi32( ni2, ni3 );
- ni4 = _mm_packs_epi32( ni4, ni5 );
- ni6 = _mm_packs_epi32( ni6, ni7 );
- // Convert int16 to int8
- ni0 = _mm_packs_epi16( ni0, ni2 );
- ni4 = _mm_packs_epi16( ni4, ni6 );
-
- _mm_storeu_si128((__m128i *)(y[i].qs + 0), ni0);
- _mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4);
- #endif
- }
- #elif defined(__riscv_v_intrinsic)
-
- size_t vl = __riscv_vsetvl_e32m4(QK8_0);
-
- for (int i = 0; i < nb; i++) {
- // load elements
- vfloat32m4_t v_x = __riscv_vle32_v_f32m4(x+i*QK8_0, vl);
-
- vfloat32m4_t vfabs = __riscv_vfabs_v_f32m4(v_x, vl);
- vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0f, vl);
- vfloat32m1_t vmax = __riscv_vfredmax_vs_f32m4_f32m1(vfabs, tmp, vl);
- float amax = __riscv_vfmv_f_s_f32m1_f32(vmax);
-
- const float d = amax / ((1 << 7) - 1);
- const float id = d ? 1.0f/d : 0.0f;
-
- y[i].d = GGML_FP32_TO_FP16(d);
-
- vfloat32m4_t x0 = __riscv_vfmul_vf_f32m4(v_x, id, vl);
-
- // convert to integer
- vint16m2_t vi = __riscv_vfncvt_x_f_w_i16m2(x0, vl);
- vint8m1_t vs = __riscv_vncvt_x_x_w_i8m1(vi, vl);
-
- // store result
- __riscv_vse8_v_i8m1(y[i].qs , vs, vl);
- }
- #else
- // scalar
- quantize_row_q8_0_reference(x, y, k);
- #endif
- }
-
- // reference implementation for deterministic creation of model files
- static void quantize_row_q8_1_reference(const float * restrict x, block_q8_1 * restrict y, int k) {
- assert(QK8_1 == 32);
- assert(k % QK8_1 == 0);
- const int nb = k / QK8_1;
-
- for (int i = 0; i < nb; i++) {
- float amax = 0.0f; // absolute max
-
- for (int j = 0; j < QK8_1; j++) {
- const float v = x[i*QK8_1 + j];
- amax = MAX(amax, fabsf(v));
- }
-
- const float d = amax / ((1 << 7) - 1);
- const float id = d ? 1.0f/d : 0.0f;
-
- y[i].d = d;
-
- int sum = 0;
-
- for (int j = 0; j < QK8_1/2; ++j) {
- const float v0 = x[i*QK8_1 + j]*id;
- const float v1 = x[i*QK8_1 + QK8_1/2 + j]*id;
-
- y[i].qs[ j] = roundf(v0);
- y[i].qs[QK8_1/2 + j] = roundf(v1);
-
- sum += y[i].qs[ j];
- sum += y[i].qs[QK8_1/2 + j];
- }
-
- y[i].s = sum*d;
- }
- }
-
- static void quantize_row_q8_1(const float * restrict x, void * restrict vy, int k) {
- assert(k % QK8_1 == 0);
- const int nb = k / QK8_1;
-
- block_q8_1 * restrict y = vy;
-
- #if defined(__ARM_NEON)
- for (int i = 0; i < nb; i++) {
- float32x4_t srcv [8];
- float32x4_t asrcv[8];
- float32x4_t amaxv[8];
-
- for (int j = 0; j < 8; j++) srcv[j] = vld1q_f32(x + i*32 + 4*j);
- for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[j]);
-
- for (int j = 0; j < 4; j++) amaxv[2*j] = vmaxq_f32(asrcv[2*j], asrcv[2*j+1]);
- for (int j = 0; j < 2; j++) amaxv[4*j] = vmaxq_f32(amaxv[4*j], amaxv[4*j+2]);
- for (int j = 0; j < 1; j++) amaxv[8*j] = vmaxq_f32(amaxv[8*j], amaxv[8*j+4]);
-
- const float amax = vmaxvq_f32(amaxv[0]);
-
- const float d = amax / ((1 << 7) - 1);
- const float id = d ? 1.0f/d : 0.0f;
-
- y[i].d = d;
-
- int32x4_t accv = vdupq_n_s32(0);
-
- for (int j = 0; j < 8; j++) {
- const float32x4_t v = vmulq_n_f32(srcv[j], id);
- const int32x4_t vi = vcvtnq_s32_f32(v);
-
- y[i].qs[4*j + 0] = vgetq_lane_s32(vi, 0);
- y[i].qs[4*j + 1] = vgetq_lane_s32(vi, 1);
- y[i].qs[4*j + 2] = vgetq_lane_s32(vi, 2);
- y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3);
-
- accv = vaddq_s32(accv, vi);
- }
-
- y[i].s = d * vaddvq_s32(accv);
- }
- #elif defined(__wasm_simd128__)
- for (int i = 0; i < nb; i++) {
- v128_t srcv [8];
- v128_t asrcv[8];
- v128_t amaxv[8];
-
- for (int j = 0; j < 8; j++) srcv[j] = wasm_v128_load(x + i*32 + 4*j);
- for (int j = 0; j < 8; j++) asrcv[j] = wasm_f32x4_abs(srcv[j]);
-
- for (int j = 0; j < 4; j++) amaxv[2*j] = wasm_f32x4_max(asrcv[2*j], asrcv[2*j+1]);
- for (int j = 0; j < 2; j++) amaxv[4*j] = wasm_f32x4_max(amaxv[4*j], amaxv[4*j+2]);
- for (int j = 0; j < 1; j++) amaxv[8*j] = wasm_f32x4_max(amaxv[8*j], amaxv[8*j+4]);
-
- const float amax = MAX(MAX(wasm_f32x4_extract_lane(amaxv[0], 0),
- wasm_f32x4_extract_lane(amaxv[0], 1)),
- MAX(wasm_f32x4_extract_lane(amaxv[0], 2),
- wasm_f32x4_extract_lane(amaxv[0], 3)));
-
- const float d = amax / ((1 << 7) - 1);
- const float id = d ? 1.0f/d : 0.0f;
-
- y[i].d = d;
-
- v128_t accv = wasm_i32x4_splat(0);
-
- for (int j = 0; j < 8; j++) {
- const v128_t v = wasm_f32x4_mul(srcv[j], wasm_f32x4_splat(id));
- const v128_t vi = wasm_i32x4_trunc_sat_f32x4(v);
-
- y[i].qs[4*j + 0] = wasm_i32x4_extract_lane(vi, 0);
- y[i].qs[4*j + 1] = wasm_i32x4_extract_lane(vi, 1);
- y[i].qs[4*j + 2] = wasm_i32x4_extract_lane(vi, 2);
- y[i].qs[4*j + 3] = wasm_i32x4_extract_lane(vi, 3);
-
- accv = wasm_i32x4_add(accv, vi);
- }
-
- y[i].s = d * (wasm_i32x4_extract_lane(accv, 0) +
- wasm_i32x4_extract_lane(accv, 1) +
- wasm_i32x4_extract_lane(accv, 2) +
- wasm_i32x4_extract_lane(accv, 3));
- }
- #elif defined(__AVX2__) || defined(__AVX__)
- for (int i = 0; i < nb; i++) {
- // Load elements into 4 AVX vectors
- __m256 v0 = _mm256_loadu_ps( x );
- __m256 v1 = _mm256_loadu_ps( x + 8 );
- __m256 v2 = _mm256_loadu_ps( x + 16 );
- __m256 v3 = _mm256_loadu_ps( x + 24 );
- x += 32;
-
- // Compute max(abs(e)) for the block
- const __m256 signBit = _mm256_set1_ps( -0.0f );
- __m256 maxAbs = _mm256_andnot_ps( signBit, v0 );
- maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) );
- maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );
- maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) );
-
- __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
- max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
- max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
- const float maxScalar = _mm_cvtss_f32( max4 );
-
- // Quantize these floats
- const float d = maxScalar / 127.f;
- y[i].d = d;
- const float id = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f;
- const __m256 mul = _mm256_set1_ps( id );
-
- // Apply the multiplier
- v0 = _mm256_mul_ps( v0, mul );
- v1 = _mm256_mul_ps( v1, mul );
- v2 = _mm256_mul_ps( v2, mul );
- v3 = _mm256_mul_ps( v3, mul );
-
- // Round to nearest integer
- v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST );
- v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST );
- v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST );
- v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST );
-
- // Convert floats to integers
- __m256i i0 = _mm256_cvtps_epi32( v0 );
- __m256i i1 = _mm256_cvtps_epi32( v1 );
- __m256i i2 = _mm256_cvtps_epi32( v2 );
- __m256i i3 = _mm256_cvtps_epi32( v3 );
-
- #if defined(__AVX2__)
- // Compute the sum of the quants and set y[i].s
- y[i].s = d * hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3)));
-
- // Convert int32 to int16
- i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15
- i2 = _mm256_packs_epi32( i2, i3 ); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31
- // Convert int16 to int8
- i0 = _mm256_packs_epi16( i0, i2 ); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31
-
- // We got our precious signed bytes, but the order is now wrong
- // These AVX2 pack instructions process 16-byte pieces independently
- // The following instruction is fixing the order
- const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
- i0 = _mm256_permutevar8x32_epi32( i0, perm );
-
- _mm256_storeu_si256((__m256i *)y[i].qs, i0);
- #else
- // Since we don't have in AVX some necessary functions,
- // we split the registers in half and call AVX2 analogs from SSE
- __m128i ni0 = _mm256_castsi256_si128( i0 );
- __m128i ni1 = _mm256_extractf128_si256( i0, 1);
- __m128i ni2 = _mm256_castsi256_si128( i1 );
- __m128i ni3 = _mm256_extractf128_si256( i1, 1);
- __m128i ni4 = _mm256_castsi256_si128( i2 );
- __m128i ni5 = _mm256_extractf128_si256( i2, 1);
- __m128i ni6 = _mm256_castsi256_si128( i3 );
- __m128i ni7 = _mm256_extractf128_si256( i3, 1);
-
- // Compute the sum of the quants and set y[i].s
- const __m128i s0 = _mm_add_epi32(_mm_add_epi32(ni0, ni1), _mm_add_epi32(ni2, ni3));
- const __m128i s1 = _mm_add_epi32(_mm_add_epi32(ni4, ni5), _mm_add_epi32(ni6, ni7));
- y[i].s = d * hsum_i32_4(_mm_add_epi32(s0, s1));
-
- // Convert int32 to int16
- ni0 = _mm_packs_epi32( ni0, ni1 );
- ni2 = _mm_packs_epi32( ni2, ni3 );
- ni4 = _mm_packs_epi32( ni4, ni5 );
- ni6 = _mm_packs_epi32( ni6, ni7 );
- // Convert int16 to int8
- ni0 = _mm_packs_epi16( ni0, ni2 );
- ni4 = _mm_packs_epi16( ni4, ni6 );
-
- _mm_storeu_si128((__m128i *)(y[i].qs + 0), ni0);
- _mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4);
- #endif
- }
- #elif defined(__riscv_v_intrinsic)
-
- size_t vl = __riscv_vsetvl_e32m4(QK8_1);
-
- for (int i = 0; i < nb; i++) {
- // load elements
- vfloat32m4_t v_x = __riscv_vle32_v_f32m4(x+i*QK8_1, vl);
-
- vfloat32m4_t vfabs = __riscv_vfabs_v_f32m4(v_x, vl);
- vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0, vl);
- vfloat32m1_t vmax = __riscv_vfredmax_vs_f32m4_f32m1(vfabs, tmp, vl);
- float amax = __riscv_vfmv_f_s_f32m1_f32(vmax);
-
- const float d = amax / ((1 << 7) - 1);
- const float id = d ? 1.0f/d : 0.0f;
-
- y[i].d = d;
-
- vfloat32m4_t x0 = __riscv_vfmul_vf_f32m4(v_x, id, vl);
-
- // convert to integer
- vint16m2_t vi = __riscv_vfncvt_x_f_w_i16m2(x0, vl);
- vint8m1_t vs = __riscv_vncvt_x_x_w_i8m1(vi, vl);
-
- // store result
- __riscv_vse8_v_i8m1(y[i].qs , vs, vl);
-
- // compute sum for y[i].s
- vint16m1_t tmp2 = __riscv_vmv_v_x_i16m1(0, vl);
- vint16m1_t vwrs = __riscv_vwredsum_vs_i8m1_i16m1(vs, tmp2, vl);
-
- // set y[i].s
- int sum = __riscv_vmv_x_s_i16m1_i16(vwrs);
- y[i].s = sum*d;
- }
- #else
- // scalar
- quantize_row_q8_1_reference(x, y, k);
- #endif
- }
-
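A structural note on q8_1 versus q8_0, visible in every path above: q8_1 keeps d as a full float and additionally stores s = d * sum(qs). Consistent with the struct comment earlier in this diff, the precomputed sum is what lets the dot-product kernels for offset formats (q4_1, q5_1) fold the per-block minimum in without touching individual quants. With q4_1 weights x_j ~ d4*q_j + m and q8_1 activations y_j ~ d8*p_j:

    sum_j x_j*y_j ~ sum_j (d4*q_j + m) * (d8*p_j)
                  = d4*d8 * sum_j q_j*p_j + m * (d8 * sum_j p_j)
                  = d4*d8 * sum_j q_j*p_j + m * s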
- static void dequantize_row_q4_0(const block_q4_0 * restrict x, float * restrict y, int k) {
- static const int qk = QK4_0;
-
- assert(k % qk == 0);
-
- const int nb = k / qk;
-
- for (int i = 0; i < nb; i++) {
- const float d = GGML_FP16_TO_FP32(x[i].d);
-
- for (int j = 0; j < qk/2; ++j) {
- const int x0 = (x[i].qs[j] & 0x0F) - 8;
- const int x1 = (x[i].qs[j] >> 4) - 8;
-
- y[i*qk + j + 0 ] = x0*d;
- y[i*qk + j + qk/2] = x1*d;
- }
- }
- }
-
- static void dequantize_row_q4_1(const block_q4_1 * restrict x, float * restrict y, int k) {
- static const int qk = QK4_1;
-
- assert(k % qk == 0);
-
- const int nb = k / qk;
-
- for (int i = 0; i < nb; i++) {
- const float d = GGML_FP16_TO_FP32(x[i].d);
- const float m = GGML_FP16_TO_FP32(x[i].m);
-
- for (int j = 0; j < qk/2; ++j) {
- const int x0 = (x[i].qs[j] & 0x0F);
- const int x1 = (x[i].qs[j] >> 4);
-
- y[i*qk + j + 0 ] = x0*d + m;
- y[i*qk + j + qk/2] = x1*d + m;
- }
- }
- }
-
- static void dequantize_row_q5_0(const block_q5_0 * restrict x, float * restrict y, int k) {
- static const int qk = QK5_0;
-
- assert(k % qk == 0);
-
- const int nb = k / qk;
-
- for (int i = 0; i < nb; i++) {
- const float d = GGML_FP16_TO_FP32(x[i].d);
-
- uint32_t qh;
- memcpy(&qh, x[i].qh, sizeof(qh));
-
- for (int j = 0; j < qk/2; ++j) {
- const uint8_t xh_0 = ((qh >> (j + 0)) << 4) & 0x10;
- const uint8_t xh_1 = ((qh >> (j + 12)) ) & 0x10;
-
- const int32_t x0 = ((x[i].qs[j] & 0x0F) | xh_0) - 16;
- const int32_t x1 = ((x[i].qs[j] >> 4) | xh_1) - 16;
-
- y[i*qk + j + 0 ] = x0*d;
- y[i*qk + j + qk/2] = x1*d;
- }
- }
- }
-
- static void dequantize_row_q5_1(const block_q5_1 * restrict x, float * restrict y, int k) {
- static const int qk = QK5_1;
-
- assert(k % qk == 0);
-
- const int nb = k / qk;
-
- for (int i = 0; i < nb; i++) {
- const float d = GGML_FP16_TO_FP32(x[i].d);
- const float m = GGML_FP16_TO_FP32(x[i].m);
-
- uint32_t qh;
- memcpy(&qh, x[i].qh, sizeof(qh));
-
- for (int j = 0; j < qk/2; ++j) {
- const uint8_t xh_0 = ((qh >> (j + 0)) << 4) & 0x10;
- const uint8_t xh_1 = ((qh >> (j + 12)) ) & 0x10;
-
- const int x0 = (x[i].qs[j] & 0x0F) | xh_0;
- const int x1 = (x[i].qs[j] >> 4) | xh_1;
-
- y[i*qk + j + 0 ] = x0*d + m;
- y[i*qk + j + qk/2] = x1*d + m;
- }
- }
- }
-
- static void dequantize_row_q8_0(const void * restrict vx, float * restrict y, int k) {
- static const int qk = QK8_0;
-
- assert(k % qk == 0);
-
- const int nb = k / qk;
-
- const block_q8_0 * restrict x = vx;
-
- for (int i = 0; i < nb; i++) {
- const float d = GGML_FP16_TO_FP32(x[i].d);
-
- for (int j = 0; j < qk; ++j) {
- y[i*qk + j] = x[i].qs[j]*d;
- }
- }
- }
-
- static void ggml_vec_dot_f32(const int n, float * restrict s, const float * restrict x, const float * restrict y);
- static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t * restrict x, ggml_fp16_t * restrict y);
- static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
- static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
- static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
- static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
- static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
-
- static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
- [GGML_TYPE_I8] = {
- .type_name = "i8",
- .blck_size = 1,
- .type_size = sizeof(int8_t),
- .is_quantized = false,
- },
- [GGML_TYPE_I16] = {
- .type_name = "i16",
- .blck_size = 1,
- .type_size = sizeof(int16_t),
- .is_quantized = false,
- },
- [GGML_TYPE_I32] = {
- .type_name = "i32",
- .blck_size = 1,
- .type_size = sizeof(int32_t),
- .is_quantized = false,
- },
- [GGML_TYPE_F32] = {
- .type_name = "f32",
- .blck_size = 1,
- .type_size = sizeof(float),
- .is_quantized = false,
- .vec_dot = (ggml_vec_dot_t) ggml_vec_dot_f32,
- .vec_dot_type = GGML_TYPE_F32,
- },
- [GGML_TYPE_F16] = {
- .type_name = "f16",
- .blck_size = 1,
- .type_size = sizeof(ggml_fp16_t),
- .is_quantized = false,
- .to_float = (ggml_to_float_t) ggml_fp16_to_fp32_row,
- .from_float = (ggml_from_float_t) ggml_fp32_to_fp16_row,
- .from_float_reference = (ggml_from_float_t) ggml_fp32_to_fp16_row,
- .vec_dot = (ggml_vec_dot_t) ggml_vec_dot_f16,
- .vec_dot_type = GGML_TYPE_F16,
- },
- [GGML_TYPE_Q4_0] = {
- .type_name = "q4_0",
- .blck_size = QK4_0,
- .type_size = sizeof(block_q4_0),
- .is_quantized = true,
- .to_float = (ggml_to_float_t) dequantize_row_q4_0,
- .from_float = quantize_row_q4_0,
- .from_float_reference = (ggml_from_float_t) quantize_row_q4_0_reference,
- .vec_dot = ggml_vec_dot_q4_0_q8_0,
- .vec_dot_type = GGML_TYPE_Q8_0,
- },
- [GGML_TYPE_Q4_1] = {
- .type_name = "q4_1",
- .blck_size = QK4_1,
- .type_size = sizeof(block_q4_1),
- .is_quantized = true,
- .to_float = (ggml_to_float_t) dequantize_row_q4_1,
- .from_float = quantize_row_q4_1,
- .from_float_reference = (ggml_from_float_t) quantize_row_q4_1_reference,
- .vec_dot = ggml_vec_dot_q4_1_q8_1,
- .vec_dot_type = GGML_TYPE_Q8_1,
- },
- [GGML_TYPE_Q5_0] = {
- .type_name = "q5_0",
- .blck_size = QK5_0,
- .type_size = sizeof(block_q5_0),
- .is_quantized = true,
- .to_float = (ggml_to_float_t) dequantize_row_q5_0,
- .from_float = quantize_row_q5_0,
- .from_float_reference = (ggml_from_float_t) quantize_row_q5_0_reference,
- .vec_dot = ggml_vec_dot_q5_0_q8_0,
- .vec_dot_type = GGML_TYPE_Q8_0,
- },
- [GGML_TYPE_Q5_1] = {
- .type_name = "q5_1",
- .blck_size = QK5_1,
- .type_size = sizeof(block_q5_1),
- .is_quantized = true,
- .to_float = (ggml_to_float_t) dequantize_row_q5_1,
- .from_float = quantize_row_q5_1,
- .from_float_reference = (ggml_from_float_t) quantize_row_q5_1_reference,
- .vec_dot = ggml_vec_dot_q5_1_q8_1,
- .vec_dot_type = GGML_TYPE_Q8_1,
- },
- [GGML_TYPE_Q8_0] = {
- .type_name = "q8_0",
- .blck_size = QK8_0,
- .type_size = sizeof(block_q8_0),
- .is_quantized = true,
- .to_float = dequantize_row_q8_0,
- .from_float = quantize_row_q8_0,
- .from_float_reference = (ggml_from_float_t) quantize_row_q8_0_reference,
- .vec_dot = ggml_vec_dot_q8_0_q8_0,
- .vec_dot_type = GGML_TYPE_Q8_0,
1748
- },
1749
- [GGML_TYPE_Q8_1] = {
1750
- .type_name = "q8_1",
1751
- .blck_size = QK8_1,
1752
- .type_size = sizeof(block_q8_1),
1753
- .is_quantized = true,
1754
- .from_float = quantize_row_q8_1,
1755
- .from_float_reference = (ggml_from_float_t) quantize_row_q8_1_reference,
1756
- .vec_dot_type = GGML_TYPE_Q8_1,
1757
- },
1758
- #ifdef GGML_USE_K_QUANTS
1759
- [GGML_TYPE_Q2_K] = {
1760
- .type_name = "q2_K",
1761
- .blck_size = QK_K,
1762
- .type_size = sizeof(block_q2_K),
1763
- .is_quantized = true,
1764
- .to_float = (ggml_to_float_t) dequantize_row_q2_K,
1765
- .from_float = quantize_row_q2_K,
1766
- .from_float_reference = (ggml_from_float_t) quantize_row_q2_K_reference,
1767
- .vec_dot = ggml_vec_dot_q2_K_q8_K,
1768
- .vec_dot_type = GGML_TYPE_Q8_K,
1769
- },
1770
- [GGML_TYPE_Q3_K] = {
1771
- .type_name = "q3_K",
1772
- .blck_size = QK_K,
1773
- .type_size = sizeof(block_q3_K),
1774
- .is_quantized = true,
1775
- .to_float = (ggml_to_float_t) dequantize_row_q3_K,
1776
- .from_float = quantize_row_q3_K,
1777
- .from_float_reference = (ggml_from_float_t) quantize_row_q3_K_reference,
1778
- .vec_dot = ggml_vec_dot_q3_K_q8_K,
1779
- .vec_dot_type = GGML_TYPE_Q8_K,
1780
- },
1781
- [GGML_TYPE_Q4_K] = {
1782
- .type_name = "q4_K",
1783
- .blck_size = QK_K,
1784
- .type_size = sizeof(block_q4_K),
1785
- .is_quantized = true,
1786
- .to_float = (ggml_to_float_t) dequantize_row_q4_K,
1787
- .from_float = quantize_row_q4_K,
1788
- .from_float_reference = (ggml_from_float_t) quantize_row_q4_K_reference,
1789
- .vec_dot = ggml_vec_dot_q4_K_q8_K,
1790
- .vec_dot_type = GGML_TYPE_Q8_K,
1791
- },
1792
- [GGML_TYPE_Q5_K] = {
1793
- .type_name = "q5_K",
1794
- .blck_size = QK_K,
1795
- .type_size = sizeof(block_q5_K),
1796
- .is_quantized = true,
1797
- .to_float = (ggml_to_float_t) dequantize_row_q5_K,
1798
- .from_float = quantize_row_q5_K,
1799
- .from_float_reference = (ggml_from_float_t) quantize_row_q5_K_reference,
1800
- .vec_dot = ggml_vec_dot_q5_K_q8_K,
1801
- .vec_dot_type = GGML_TYPE_Q8_K,
1802
- },
1803
- [GGML_TYPE_Q6_K] = {
1804
- .type_name = "q6_K",
1805
- .blck_size = QK_K,
1806
- .type_size = sizeof(block_q6_K),
1807
- .is_quantized = true,
1808
- .to_float = (ggml_to_float_t) dequantize_row_q6_K,
1809
- .from_float = quantize_row_q6_K,
1810
- .from_float_reference = (ggml_from_float_t) quantize_row_q6_K_reference,
1811
- .vec_dot = ggml_vec_dot_q6_K_q8_K,
1812
- .vec_dot_type = GGML_TYPE_Q8_K,
1813
- },
1814
- [GGML_TYPE_Q8_K] = {
1815
- .type_name = "q8_K",
1816
- .blck_size = QK_K,
1817
- .type_size = sizeof(block_q8_K),
1818
- .is_quantized = true,
1819
- .from_float = quantize_row_q8_K,
1820
- }
1821
- #endif
1822
- };
1823
-
1824
- // For internal test use
1825
- ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type) {
1826
- GGML_ASSERT(type < GGML_TYPE_COUNT);
1827
- return type_traits[type];
1828
- }
1829
-
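
The type_traits table removed above is the dispatch hub for the whole file: generic code indexes it by type enum and calls through the stored function pointers instead of switching on the type. A cut-down model of that pattern, with illustrative demo_* names rather than the ggml API:

#include <stddef.h>
#include <stdio.h>

typedef void (*demo_to_float_t)(const void *x, float *y, int k);

enum demo_type { DEMO_TYPE_F32, DEMO_TYPE_COUNT };

// one traits row per type, indexed by the enum, as in type_traits above
typedef struct {
    const char      *type_name;
    size_t           type_size;
    demo_to_float_t  to_float;
} demo_type_traits;

static void demo_f32_to_float(const void *x, float *y, int k) {
    for (int i = 0; i < k; ++i) y[i] = ((const float *)x)[i];
}

static const demo_type_traits demo_traits[DEMO_TYPE_COUNT] = {
    [DEMO_TYPE_F32] = {
        .type_name = "f32",
        .type_size = sizeof(float),
        .to_float  = demo_f32_to_float,
    },
};

int main(void) {
    const float src[2] = { 1.0f, 2.0f };
    float dst[2];
    // generic code only sees the traits row, never the concrete type
    demo_traits[DEMO_TYPE_F32].to_float(src, dst, 2);
    printf("%s: %f %f\n", demo_traits[DEMO_TYPE_F32].type_name, dst[0], dst[1]);
    return 0;
}
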
1830
- //
1831
- // simd mappings
1832
- //
1833
-
1834
- // we define a common set of C macros which map to specific intrinsics based on the current architecture
1835
- // we then implement the fundamental computation operations below using only these macros
1836
- // adding support for new architectures requires defining the corresponding SIMD macros
1837
- //
1838
- // GGML_F32_STEP / GGML_F16_STEP
1839
- // number of elements to process in a single step
1840
- //
1841
- // GGML_F32_EPR / GGML_F16_EPR
1842
- // number of elements to fit in a single register
1843
- //
1844
-
1845
- #if defined(__ARM_NEON) && defined(__ARM_FEATURE_FMA)
1846
-
1847
- #define GGML_SIMD
1848
-
1849
- // F32 NEON
1850
-
1851
- #define GGML_F32_STEP 16
1852
- #define GGML_F32_EPR 4
1853
-
1854
- #define GGML_F32x4 float32x4_t
1855
- #define GGML_F32x4_ZERO vdupq_n_f32(0.0f)
1856
- #define GGML_F32x4_SET1(x) vdupq_n_f32(x)
1857
- #define GGML_F32x4_LOAD vld1q_f32
1858
- #define GGML_F32x4_STORE vst1q_f32
1859
- #define GGML_F32x4_FMA(a, b, c) vfmaq_f32(a, b, c)
1860
- #define GGML_F32x4_ADD vaddq_f32
1861
- #define GGML_F32x4_MUL vmulq_f32
1862
- #define GGML_F32x4_REDUCE_ONE(x) vaddvq_f32(x)
1863
- #define GGML_F32x4_REDUCE(res, x) \
1864
- { \
1865
- int offset = GGML_F32_ARR >> 1; \
1866
- for (int i = 0; i < offset; ++i) { \
1867
- x[i] = vaddq_f32(x[i], x[offset+i]); \
1868
- } \
1869
- offset >>= 1; \
1870
- for (int i = 0; i < offset; ++i) { \
1871
- x[i] = vaddq_f32(x[i], x[offset+i]); \
1872
- } \
1873
- offset >>= 1; \
1874
- for (int i = 0; i < offset; ++i) { \
1875
- x[i] = vaddq_f32(x[i], x[offset+i]); \
1876
- } \
1877
- res = GGML_F32x4_REDUCE_ONE(x[0]); \
1878
- }
1879
-
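
GGML_F32x4_REDUCE above folds the GGML_F32_ARR accumulator registers pairwise in halving passes before the final horizontal add (three passes are unrolled, enough for up to 8 registers). The same tree reduction against plain floats, as a sketch of the control flow the macro unrolls:

#include <stdio.h>

#define DEMO_ARR 4 // stand-in for GGML_F32_ARR (registers per step)

// fold DEMO_ARR partial sums pairwise: x[i] += x[offset+i], halving each
// pass; the macro above unrolls a fixed three passes instead of looping
static float demo_reduce(float x[DEMO_ARR]) {
    for (int offset = DEMO_ARR >> 1; offset > 0; offset >>= 1) {
        for (int i = 0; i < offset; ++i) {
            x[i] += x[offset + i];
        }
    }
    return x[0]; // in the SIMD version this is the final vaddvq_f32 lane sum
}

int main(void) {
    float acc[DEMO_ARR] = { 1.0f, 2.0f, 3.0f, 4.0f };
    printf("%f\n", demo_reduce(acc)); // 10.0
    return 0;
}
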
1880
- #define GGML_F32_VEC GGML_F32x4
1881
- #define GGML_F32_VEC_ZERO GGML_F32x4_ZERO
1882
- #define GGML_F32_VEC_SET1 GGML_F32x4_SET1
1883
- #define GGML_F32_VEC_LOAD GGML_F32x4_LOAD
1884
- #define GGML_F32_VEC_STORE GGML_F32x4_STORE
1885
- #define GGML_F32_VEC_FMA GGML_F32x4_FMA
1886
- #define GGML_F32_VEC_ADD GGML_F32x4_ADD
1887
- #define GGML_F32_VEC_MUL GGML_F32x4_MUL
1888
- #define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE
1889
-
1890
- // F16 NEON
1891
-
1892
- #if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
1893
- #define GGML_F16_STEP 32
1894
- #define GGML_F16_EPR 8
1895
-
1896
- #define GGML_F16x8 float16x8_t
1897
- #define GGML_F16x8_ZERO vdupq_n_f16(0.0f)
1898
- #define GGML_F16x8_SET1(x) vdupq_n_f16(x)
1899
- #define GGML_F16x8_LOAD vld1q_f16
1900
- #define GGML_F16x8_STORE vst1q_f16
1901
- #define GGML_F16x8_FMA(a, b, c) vfmaq_f16(a, b, c)
1902
- #define GGML_F16x8_ADD vaddq_f16
1903
- #define GGML_F16x8_MUL vmulq_f16
1904
- #define GGML_F16x8_REDUCE(res, x) \
1905
- do { \
1906
- int offset = GGML_F16_ARR >> 1; \
1907
- for (int i = 0; i < offset; ++i) { \
1908
- x[i] = vaddq_f16(x[i], x[offset+i]); \
1909
- } \
1910
- offset >>= 1; \
1911
- for (int i = 0; i < offset; ++i) { \
1912
- x[i] = vaddq_f16(x[i], x[offset+i]); \
1913
- } \
1914
- offset >>= 1; \
1915
- for (int i = 0; i < offset; ++i) { \
1916
- x[i] = vaddq_f16(x[i], x[offset+i]); \
1917
- } \
1918
- const float32x4_t t0 = vcvt_f32_f16(vget_low_f16 (x[0])); \
1919
- const float32x4_t t1 = vcvt_f32_f16(vget_high_f16(x[0])); \
1920
- res = (ggml_float) vaddvq_f32(vaddq_f32(t0, t1)); \
1921
- } while (0)
1922
-
1923
- #define GGML_F16_VEC GGML_F16x8
1924
- #define GGML_F16_VEC_ZERO GGML_F16x8_ZERO
1925
- #define GGML_F16_VEC_SET1 GGML_F16x8_SET1
1926
- #define GGML_F16_VEC_LOAD(p, i) GGML_F16x8_LOAD(p)
1927
- #define GGML_F16_VEC_STORE(p, r, i) GGML_F16x8_STORE(p, r[i])
1928
- #define GGML_F16_VEC_FMA GGML_F16x8_FMA
1929
- #define GGML_F16_VEC_ADD GGML_F16x8_ADD
1930
- #define GGML_F16_VEC_MUL GGML_F16x8_MUL
1931
- #define GGML_F16_VEC_REDUCE GGML_F16x8_REDUCE
1932
- #else
1933
- // if FP16 vector arithmetic is not supported, we use FP32 instead
1934
- // and take advantage of the vcvt_ functions to convert to/from FP16
1935
-
1936
- #define GGML_F16_STEP 16
1937
- #define GGML_F16_EPR 4
1938
-
1939
- #define GGML_F32Cx4 float32x4_t
1940
- #define GGML_F32Cx4_ZERO vdupq_n_f32(0.0f)
1941
- #define GGML_F32Cx4_SET1(x) vdupq_n_f32(x)
1942
- #define GGML_F32Cx4_LOAD(x) vcvt_f32_f16(vld1_f16(x))
1943
- #define GGML_F32Cx4_STORE(x, y) vst1_f16(x, vcvt_f16_f32(y))
1944
- #define GGML_F32Cx4_FMA(a, b, c) vfmaq_f32(a, b, c)
1945
- #define GGML_F32Cx4_ADD vaddq_f32
1946
- #define GGML_F32Cx4_MUL vmulq_f32
1947
- #define GGML_F32Cx4_REDUCE GGML_F32x4_REDUCE
1948
-
1949
- #define GGML_F16_VEC GGML_F32Cx4
1950
- #define GGML_F16_VEC_ZERO GGML_F32Cx4_ZERO
1951
- #define GGML_F16_VEC_SET1 GGML_F32Cx4_SET1
1952
- #define GGML_F16_VEC_LOAD(p, i) GGML_F32Cx4_LOAD(p)
1953
- #define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx4_STORE(p, r[i])
1954
- #define GGML_F16_VEC_FMA GGML_F32Cx4_FMA
1955
- #define GGML_F16_VEC_ADD GGML_F32Cx4_ADD
1956
- #define GGML_F16_VEC_MUL GGML_F32Cx4_MUL
1957
- #define GGML_F16_VEC_REDUCE GGML_F32Cx4_REDUCE
1958
- #endif
1959
-
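
The fallback above computes in FP32 and converts only at the load/store boundary. In scalar form the widen-compute-narrow pattern looks like this; a sketch with stand-in conversions, where the real code uses vcvt_f32_f16/vcvt_f16_f32 or the GGML_FP16_TO_FP32 conversion helpers:

#include <stdio.h>

typedef unsigned short demo_fp16; // stand-in for ggml_fp16_t (a bit pattern)

// toy conversions so the sketch is self-contained; real fp16 bit patterns
// are not numerically equal to their values, these casts are placeholders
static float     demo_fp16_to_fp32(demo_fp16 h) { return (float) h; }
static demo_fp16 demo_fp32_to_fp16(float f)     { return (demo_fp16) f; }

// y[i] += x[i] in "fp16", computed by widening to fp32 first
static void demo_add_f16(int n, demo_fp16 *y, const demo_fp16 *x) {
    for (int i = 0; i < n; ++i) {
        const float yf = demo_fp16_to_fp32(y[i]);
        const float xf = demo_fp16_to_fp32(x[i]);
        y[i] = demo_fp32_to_fp16(yf + xf);
    }
}

int main(void) {
    demo_fp16 y[2] = { 1, 2 }, x[2] = { 3, 4 };
    demo_add_f16(2, y, x);
    printf("%u %u\n", y[0], y[1]); // 4 6
    return 0;
}
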
1960
- #elif defined(__AVX__)
1961
-
1962
- #define GGML_SIMD
1963
-
1964
- // F32 AVX
1965
-
1966
- #define GGML_F32_STEP 32
1967
- #define GGML_F32_EPR 8
1968
-
1969
- #define GGML_F32x8 __m256
1970
- #define GGML_F32x8_ZERO _mm256_setzero_ps()
1971
- #define GGML_F32x8_SET1(x) _mm256_set1_ps(x)
1972
- #define GGML_F32x8_LOAD _mm256_loadu_ps
1973
- #define GGML_F32x8_STORE _mm256_storeu_ps
1974
- #if defined(__FMA__)
1975
- #define GGML_F32x8_FMA(a, b, c) _mm256_fmadd_ps(b, c, a)
1976
- #else
1977
- #define GGML_F32x8_FMA(a, b, c) _mm256_add_ps(_mm256_mul_ps(b, c), a)
1978
- #endif
1979
- #define GGML_F32x8_ADD _mm256_add_ps
1980
- #define GGML_F32x8_MUL _mm256_mul_ps
1981
- #define GGML_F32x8_REDUCE(res, x) \
1982
- do { \
1983
- int offset = GGML_F32_ARR >> 1; \
1984
- for (int i = 0; i < offset; ++i) { \
1985
- x[i] = _mm256_add_ps(x[i], x[offset+i]); \
1986
- } \
1987
- offset >>= 1; \
1988
- for (int i = 0; i < offset; ++i) { \
1989
- x[i] = _mm256_add_ps(x[i], x[offset+i]); \
1990
- } \
1991
- offset >>= 1; \
1992
- for (int i = 0; i < offset; ++i) { \
1993
- x[i] = _mm256_add_ps(x[i], x[offset+i]); \
1994
- } \
1995
- const __m128 t0 = _mm_add_ps(_mm256_castps256_ps128(x[0]), \
1996
- _mm256_extractf128_ps(x[0], 1)); \
1997
- const __m128 t1 = _mm_hadd_ps(t0, t0); \
1998
- res = _mm_cvtss_f32(_mm_hadd_ps(t1, t1)); \
1999
- } while (0)
2000
- // TODO: is this optimal?
2001
-
2002
- #define GGML_F32_VEC GGML_F32x8
2003
- #define GGML_F32_VEC_ZERO GGML_F32x8_ZERO
2004
- #define GGML_F32_VEC_SET1 GGML_F32x8_SET1
2005
- #define GGML_F32_VEC_LOAD GGML_F32x8_LOAD
2006
- #define GGML_F32_VEC_STORE GGML_F32x8_STORE
2007
- #define GGML_F32_VEC_FMA GGML_F32x8_FMA
2008
- #define GGML_F32_VEC_ADD GGML_F32x8_ADD
2009
- #define GGML_F32_VEC_MUL GGML_F32x8_MUL
2010
- #define GGML_F32_VEC_REDUCE GGML_F32x8_REDUCE
2011
-
2012
- // F16 AVX
2013
-
2014
- #define GGML_F16_STEP 32
2015
- #define GGML_F16_EPR 8
2016
-
2017
- // F16 arithmetic is not supported by AVX, so we use F32 instead
2018
-
2019
- #define GGML_F32Cx8 __m256
2020
- #define GGML_F32Cx8_ZERO _mm256_setzero_ps()
2021
- #define GGML_F32Cx8_SET1(x) _mm256_set1_ps(x)
2022
-
2023
- #if defined(__F16C__)
2024
- // the _mm256_cvt intrinsics require F16C
2025
- #define GGML_F32Cx8_LOAD(x) _mm256_cvtph_ps(_mm_loadu_si128((__m128i *)(x)))
2026
- #define GGML_F32Cx8_STORE(x, y) _mm_storeu_si128((__m128i *)(x), _mm256_cvtps_ph(y, 0))
2027
- #else
2028
- static inline __m256 __avx_f32cx8_load(ggml_fp16_t *x) {
2029
- float tmp[8];
2030
-
2031
- for (int i = 0; i < 8; i++) {
2032
- tmp[i] = GGML_FP16_TO_FP32(x[i]);
2033
- }
2034
-
2035
- return _mm256_loadu_ps(tmp);
2036
- }
2037
- static inline void __avx_f32cx8_store(ggml_fp16_t *x, __m256 y) {
2038
- float arr[8];
2039
-
2040
- _mm256_storeu_ps(arr, y);
2041
-
2042
- for (int i = 0; i < 8; i++)
2043
- x[i] = GGML_FP32_TO_FP16(arr[i]);
2044
- }
2045
- #define GGML_F32Cx8_LOAD(x) __avx_f32cx8_load(x)
2046
- #define GGML_F32Cx8_STORE(x, y) __avx_f32cx8_store(x, y)
2047
- #endif
2048
-
2049
- #define GGML_F32Cx8_FMA GGML_F32x8_FMA
2050
- #define GGML_F32Cx8_ADD _mm256_add_ps
2051
- #define GGML_F32Cx8_MUL _mm256_mul_ps
2052
- #define GGML_F32Cx8_REDUCE GGML_F32x8_REDUCE
2053
-
2054
- #define GGML_F16_VEC GGML_F32Cx8
2055
- #define GGML_F16_VEC_ZERO GGML_F32Cx8_ZERO
2056
- #define GGML_F16_VEC_SET1 GGML_F32Cx8_SET1
2057
- #define GGML_F16_VEC_LOAD(p, i) GGML_F32Cx8_LOAD(p)
2058
- #define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx8_STORE(p, r[i])
2059
- #define GGML_F16_VEC_FMA GGML_F32Cx8_FMA
2060
- #define GGML_F16_VEC_ADD GGML_F32Cx8_ADD
2061
- #define GGML_F16_VEC_MUL GGML_F32Cx8_MUL
2062
- #define GGML_F16_VEC_REDUCE GGML_F32Cx8_REDUCE
2063
-
2064
- #elif defined(__POWER9_VECTOR__)
2065
-
2066
- #define GGML_SIMD
2067
-
2068
- // F32 POWER9
2069
-
2070
- #define GGML_F32_STEP 32
2071
- #define GGML_F32_EPR 4
2072
-
2073
- #define GGML_F32x4 vector float
2074
- #define GGML_F32x4_ZERO 0.0f
2075
- #define GGML_F32x4_SET1 vec_splats
2076
- #define GGML_F32x4_LOAD(p) vec_xl(0, p)
2077
- #define GGML_F32x4_STORE(p, r) vec_xst(r, 0, p)
2078
- #define GGML_F32x4_FMA(a, b, c) vec_madd(b, c, a)
2079
- #define GGML_F32x4_ADD vec_add
2080
- #define GGML_F32x4_MUL vec_mul
2081
- #define GGML_F32x4_REDUCE(res, x) \
2082
- { \
2083
- int offset = GGML_F32_ARR >> 1; \
2084
- for (int i = 0; i < offset; ++i) { \
2085
- x[i] = vec_add(x[i], x[offset+i]); \
2086
- } \
2087
- offset >>= 1; \
2088
- for (int i = 0; i < offset; ++i) { \
2089
- x[i] = vec_add(x[i], x[offset+i]); \
2090
- } \
2091
- offset >>= 1; \
2092
- for (int i = 0; i < offset; ++i) { \
2093
- x[i] = vec_add(x[i], x[offset+i]); \
2094
- } \
2095
- res = vec_extract(x[0], 0) + \
2096
- vec_extract(x[0], 1) + \
2097
- vec_extract(x[0], 2) + \
2098
- vec_extract(x[0], 3); \
2099
- }
2100
-
2101
- #define GGML_F32_VEC GGML_F32x4
2102
- #define GGML_F32_VEC_ZERO GGML_F32x4_ZERO
2103
- #define GGML_F32_VEC_SET1 GGML_F32x4_SET1
2104
- #define GGML_F32_VEC_LOAD GGML_F32x4_LOAD
2105
- #define GGML_F32_VEC_STORE GGML_F32x4_STORE
2106
- #define GGML_F32_VEC_FMA GGML_F32x4_FMA
2107
- #define GGML_F32_VEC_ADD GGML_F32x4_ADD
2108
- #define GGML_F32_VEC_MUL GGML_F32x4_MUL
2109
- #define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE
2110
-
2111
- // F16 POWER9
2112
- #define GGML_F16_STEP GGML_F32_STEP
2113
- #define GGML_F16_EPR GGML_F32_EPR
2114
- #define GGML_F16_VEC GGML_F32x4
2115
- #define GGML_F16_VEC_ZERO GGML_F32x4_ZERO
2116
- #define GGML_F16_VEC_SET1 GGML_F32x4_SET1
2117
- #define GGML_F16_VEC_FMA GGML_F32x4_FMA
2118
- #define GGML_F16_VEC_REDUCE GGML_F32x4_REDUCE
2119
- // Use vec_xl, not vec_ld, in case the load address is not aligned.
2120
- #define GGML_F16_VEC_LOAD(p, i) (i & 0x1) ? \
2121
- vec_extract_fp32_from_shorth(vec_xl(0, p - GGML_F16_EPR)) : \
2122
- vec_extract_fp32_from_shortl(vec_xl(0, p))
2123
- #define GGML_ENDIAN_BYTE(i) ((unsigned char *)&(uint16_t){1})[i]
2124
- #define GGML_F16_VEC_STORE(p, r, i) \
2125
- if (i & 0x1) \
2126
- vec_xst(vec_pack_to_short_fp32(r[i - GGML_ENDIAN_BYTE(1)], \
2127
- r[i - GGML_ENDIAN_BYTE(0)]), \
2128
- 0, p - GGML_F16_EPR)
2129
-
2130
- #elif defined(__wasm_simd128__)
2131
-
2132
- #define GGML_SIMD
2133
-
2134
- // F32 WASM
2135
-
2136
- #define GGML_F32_STEP 16
2137
- #define GGML_F32_EPR 4
2138
-
2139
- #define GGML_F32x4 v128_t
2140
- #define GGML_F32x4_ZERO wasm_f32x4_splat(0.0f)
2141
- #define GGML_F32x4_SET1(x) wasm_f32x4_splat(x)
2142
- #define GGML_F32x4_LOAD wasm_v128_load
2143
- #define GGML_F32x4_STORE wasm_v128_store
2144
- #define GGML_F32x4_FMA(a, b, c) wasm_f32x4_add(wasm_f32x4_mul(b, c), a)
2145
- #define GGML_F32x4_ADD wasm_f32x4_add
2146
- #define GGML_F32x4_MUL wasm_f32x4_mul
2147
- #define GGML_F32x4_REDUCE(res, x) \
2148
- { \
2149
- int offset = GGML_F32_ARR >> 1; \
2150
- for (int i = 0; i < offset; ++i) { \
2151
- x[i] = wasm_f32x4_add(x[i], x[offset+i]); \
2152
- } \
2153
- offset >>= 1; \
2154
- for (int i = 0; i < offset; ++i) { \
2155
- x[i] = wasm_f32x4_add(x[i], x[offset+i]); \
2156
- } \
2157
- offset >>= 1; \
2158
- for (int i = 0; i < offset; ++i) { \
2159
- x[i] = wasm_f32x4_add(x[i], x[offset+i]); \
2160
- } \
2161
- res = wasm_f32x4_extract_lane(x[0], 0) + \
2162
- wasm_f32x4_extract_lane(x[0], 1) + \
2163
- wasm_f32x4_extract_lane(x[0], 2) + \
2164
- wasm_f32x4_extract_lane(x[0], 3); \
2165
- }
2166
-
2167
- #define GGML_F32_VEC GGML_F32x4
2168
- #define GGML_F32_VEC_ZERO GGML_F32x4_ZERO
2169
- #define GGML_F32_VEC_SET1 GGML_F32x4_SET1
2170
- #define GGML_F32_VEC_LOAD GGML_F32x4_LOAD
2171
- #define GGML_F32_VEC_STORE GGML_F32x4_STORE
2172
- #define GGML_F32_VEC_FMA GGML_F32x4_FMA
2173
- #define GGML_F32_VEC_ADD GGML_F32x4_ADD
2174
- #define GGML_F32_VEC_MUL GGML_F32x4_MUL
2175
- #define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE
2176
-
2177
- // F16 WASM
2178
-
2179
- #define GGML_F16_STEP 16
2180
- #define GGML_F16_EPR 4
2181
-
2182
- inline static v128_t __wasm_f16x4_load(const ggml_fp16_t * p) {
2183
- float tmp[4];
2184
-
2185
- tmp[0] = GGML_FP16_TO_FP32(p[0]);
2186
- tmp[1] = GGML_FP16_TO_FP32(p[1]);
2187
- tmp[2] = GGML_FP16_TO_FP32(p[2]);
2188
- tmp[3] = GGML_FP16_TO_FP32(p[3]);
2189
-
2190
- return wasm_v128_load(tmp);
2191
- }
2192
-
2193
- inline static void __wasm_f16x4_store(ggml_fp16_t * p, v128_t x) {
2194
- float tmp[4];
2195
-
2196
- wasm_v128_store(tmp, x);
2197
-
2198
- p[0] = GGML_FP32_TO_FP16(tmp[0]);
2199
- p[1] = GGML_FP32_TO_FP16(tmp[1]);
2200
- p[2] = GGML_FP32_TO_FP16(tmp[2]);
2201
- p[3] = GGML_FP32_TO_FP16(tmp[3]);
2202
- }
2203
-
2204
- #define GGML_F16x4 v128_t
2205
- #define GGML_F16x4_ZERO wasm_f32x4_splat(0.0f)
2206
- #define GGML_F16x4_SET1(x) wasm_f32x4_splat(x)
2207
- #define GGML_F16x4_LOAD(x) __wasm_f16x4_load(x)
2208
- #define GGML_F16x4_STORE(x, y) __wasm_f16x4_store(x, y)
2209
- #define GGML_F16x4_FMA GGML_F32x4_FMA
2210
- #define GGML_F16x4_ADD wasm_f32x4_add
2211
- #define GGML_F16x4_MUL wasm_f32x4_mul
2212
- #define GGML_F16x4_REDUCE(res, x) \
2213
- { \
2214
- int offset = GGML_F16_ARR >> 1; \
2215
- for (int i = 0; i < offset; ++i) { \
2216
- x[i] = wasm_f32x4_add(x[i], x[offset+i]); \
2217
- } \
2218
- offset >>= 1; \
2219
- for (int i = 0; i < offset; ++i) { \
2220
- x[i] = wasm_f32x4_add(x[i], x[offset+i]); \
2221
- } \
2222
- offset >>= 1; \
2223
- for (int i = 0; i < offset; ++i) { \
2224
- x[i] = wasm_f32x4_add(x[i], x[offset+i]); \
2225
- } \
2226
- res = wasm_f32x4_extract_lane(x[0], 0) + \
2227
- wasm_f32x4_extract_lane(x[0], 1) + \
2228
- wasm_f32x4_extract_lane(x[0], 2) + \
2229
- wasm_f32x4_extract_lane(x[0], 3); \
2230
- }
2231
-
2232
- #define GGML_F16_VEC GGML_F16x4
2233
- #define GGML_F16_VEC_ZERO GGML_F16x4_ZERO
2234
- #define GGML_F16_VEC_SET1 GGML_F16x4_SET1
2235
- #define GGML_F16_VEC_LOAD(p, i) GGML_F16x4_LOAD(p)
2236
- #define GGML_F16_VEC_STORE(p, r, i) GGML_F16x4_STORE(p, r[i])
2237
- #define GGML_F16_VEC_FMA GGML_F16x4_FMA
2238
- #define GGML_F16_VEC_ADD GGML_F16x4_ADD
2239
- #define GGML_F16_VEC_MUL GGML_F16x4_MUL
2240
- #define GGML_F16_VEC_REDUCE GGML_F16x4_REDUCE
2241
-
2242
- #elif defined(__SSE3__)
2243
-
2244
- #define GGML_SIMD
2245
-
2246
- // F32 SSE
2247
-
2248
- #define GGML_F32_STEP 32
2249
- #define GGML_F32_EPR 4
2250
-
2251
- #define GGML_F32x4 __m128
2252
- #define GGML_F32x4_ZERO _mm_setzero_ps()
2253
- #define GGML_F32x4_SET1(x) _mm_set1_ps(x)
2254
- #define GGML_F32x4_LOAD _mm_loadu_ps
2255
- #define GGML_F32x4_STORE _mm_storeu_ps
2256
- #if defined(__FMA__)
2257
- // TODO: Does this work?
2258
- #define GGML_F32x4_FMA(a, b, c) _mm_fmadd_ps(b, c, a)
2259
- #else
2260
- #define GGML_F32x4_FMA(a, b, c) _mm_add_ps(_mm_mul_ps(b, c), a)
2261
- #endif
2262
- #define GGML_F32x4_ADD _mm_add_ps
2263
- #define GGML_F32x4_MUL _mm_mul_ps
2264
- #define GGML_F32x4_REDUCE(res, x) \
2265
- { \
2266
- int offset = GGML_F32_ARR >> 1; \
2267
- for (int i = 0; i < offset; ++i) { \
2268
- x[i] = _mm_add_ps(x[i], x[offset+i]); \
2269
- } \
2270
- offset >>= 1; \
2271
- for (int i = 0; i < offset; ++i) { \
2272
- x[i] = _mm_add_ps(x[i], x[offset+i]); \
2273
- } \
2274
- offset >>= 1; \
2275
- for (int i = 0; i < offset; ++i) { \
2276
- x[i] = _mm_add_ps(x[i], x[offset+i]); \
2277
- } \
2278
- const __m128 t0 = _mm_hadd_ps(x[0], x[0]); \
2279
- res = _mm_cvtss_f32(_mm_hadd_ps(t0, t0)); \
2280
- }
2281
- // TODO: is this optimal?
2282
-
2283
- #define GGML_F32_VEC GGML_F32x4
2284
- #define GGML_F32_VEC_ZERO GGML_F32x4_ZERO
2285
- #define GGML_F32_VEC_SET1 GGML_F32x4_SET1
2286
- #define GGML_F32_VEC_LOAD GGML_F32x4_LOAD
2287
- #define GGML_F32_VEC_STORE GGML_F32x4_STORE
2288
- #define GGML_F32_VEC_FMA GGML_F32x4_FMA
2289
- #define GGML_F32_VEC_ADD GGML_F32x4_ADD
2290
- #define GGML_F32_VEC_MUL GGML_F32x4_MUL
2291
- #define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE
2292
-
2293
- // F16 SSE
2294
-
2295
- #define GGML_F16_STEP 32
2296
- #define GGML_F16_EPR 4
2297
-
2298
- static inline __m128 __sse_f16x4_load(ggml_fp16_t *x) {
2299
- float tmp[4];
2300
-
2301
- tmp[0] = GGML_FP16_TO_FP32(x[0]);
2302
- tmp[1] = GGML_FP16_TO_FP32(x[1]);
2303
- tmp[2] = GGML_FP16_TO_FP32(x[2]);
2304
- tmp[3] = GGML_FP16_TO_FP32(x[3]);
2305
-
2306
- return _mm_loadu_ps(tmp);
2307
- }
2308
-
2309
- static inline void __sse_f16x4_store(ggml_fp16_t *x, __m128 y) {
2310
- float arr[4];
2311
-
2312
- _mm_storeu_ps(arr, y);
2313
-
2314
- x[0] = GGML_FP32_TO_FP16(arr[0]);
2315
- x[1] = GGML_FP32_TO_FP16(arr[1]);
2316
- x[2] = GGML_FP32_TO_FP16(arr[2]);
2317
- x[3] = GGML_FP32_TO_FP16(arr[3]);
2318
- }
2319
-
2320
- #define GGML_F32Cx4 __m128
2321
- #define GGML_F32Cx4_ZERO _mm_setzero_ps()
2322
- #define GGML_F32Cx4_SET1(x) _mm_set1_ps(x)
2323
- #define GGML_F32Cx4_LOAD(x) __sse_f16x4_load(x)
2324
- #define GGML_F32Cx4_STORE(x, y) __sse_f16x4_store(x, y)
2325
- #define GGML_F32Cx4_FMA GGML_F32x4_FMA
2326
- #define GGML_F32Cx4_ADD _mm_add_ps
2327
- #define GGML_F32Cx4_MUL _mm_mul_ps
2328
- #define GGML_F32Cx4_REDUCE GGML_F32x4_REDUCE
2329
-
2330
- #define GGML_F16_VEC GGML_F32Cx4
2331
- #define GGML_F16_VEC_ZERO GGML_F32Cx4_ZERO
2332
- #define GGML_F16_VEC_SET1 GGML_F32Cx4_SET1
2333
- #define GGML_F16_VEC_LOAD(p, i) GGML_F32Cx4_LOAD(p)
2334
- #define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx4_STORE(p, r[i])
2335
- #define GGML_F16_VEC_FMA GGML_F32Cx4_FMA
2336
- #define GGML_F16_VEC_ADD GGML_F32Cx4_ADD
2337
- #define GGML_F16_VEC_MUL GGML_F32Cx4_MUL
2338
- #define GGML_F16_VEC_REDUCE GGML_F32Cx4_REDUCE
2339
-
2340
- #endif
2341
-
2342
- // GGML_F32_ARR / GGML_F16_ARR
2343
- // number of registers to use per step
2344
- #ifdef GGML_SIMD
2345
- #define GGML_F32_ARR (GGML_F32_STEP/GGML_F32_EPR)
2346
- #define GGML_F16_ARR (GGML_F16_STEP/GGML_F16_EPR)
2347
- #endif
2348
-
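
Concretely, the AVX mapping above gives GGML_F32_ARR = 32/8 = 4 accumulator registers per step, and the NEON mapping gives 16/4 = 4 as well. The relationship can be pinned down with compile-time checks, shown here with illustrative values rather than the real configuration:

#include <assert.h>

// illustrative AVX-like values: 32 floats per step, 8 floats per register
#define DEMO_F32_STEP 32
#define DEMO_F32_EPR   8
#define DEMO_F32_ARR  (DEMO_F32_STEP/DEMO_F32_EPR)

// the main loops assume STEP is an exact multiple of EPR
static_assert(DEMO_F32_STEP % DEMO_F32_EPR == 0, "STEP must be a multiple of EPR");
static_assert(DEMO_F32_ARR == 4, "4 accumulators per step with these values");

int main(void) { return 0; }
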
2349
- //
2350
- // fundamental operations
2351
- //
2352
-
2353
- inline static void ggml_vec_set_i8(const int n, int8_t * x, const int8_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
2354
-
2355
- inline static void ggml_vec_set_i16(const int n, int16_t * x, const int16_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
2356
-
2357
- inline static void ggml_vec_set_i32(const int n, int32_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
2358
-
2359
- inline static void ggml_vec_set_f16(const int n, ggml_fp16_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
2360
-
2361
- inline static void ggml_vec_add_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i] + y[i]; }
2362
- inline static void ggml_vec_add1_f32(const int n, float * z, const float * x, const float v) { for (int i = 0; i < n; ++i) z[i] = x[i] + v; }
2363
- inline static void ggml_vec_acc_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] += x[i]; }
2364
- inline static void ggml_vec_acc1_f32(const int n, float * y, const float v) { for (int i = 0; i < n; ++i) y[i] += v; }
2365
- inline static void ggml_vec_sub_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i] - y[i]; }
2366
- inline static void ggml_vec_set_f32 (const int n, float * x, const float v) { for (int i = 0; i < n; ++i) x[i] = v; }
2367
- inline static void ggml_vec_cpy_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]; }
2368
- inline static void ggml_vec_neg_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = -x[i]; }
2369
- inline static void ggml_vec_mul_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]*y[i]; }
2370
- inline static void ggml_vec_div_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]/y[i]; }
2371
-
2372
- static void ggml_vec_dot_f32(const int n, float * restrict s, const float * restrict x, const float * restrict y) {
2373
- #ifdef GGML_SIMD
2374
- float sumf = 0.0f;
2375
- const int np = (n & ~(GGML_F32_STEP - 1));
2376
-
2377
- GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
2378
-
2379
- GGML_F32_VEC ax[GGML_F32_ARR];
2380
- GGML_F32_VEC ay[GGML_F32_ARR];
2381
-
2382
- for (int i = 0; i < np; i += GGML_F32_STEP) {
2383
- for (int j = 0; j < GGML_F32_ARR; j++) {
2384
- ax[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR);
2385
- ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR);
2386
-
2387
- sum[j] = GGML_F32_VEC_FMA(sum[j], ax[j], ay[j]);
2388
- }
2389
- }
2390
-
2391
- // reduce sum0..sum3 to sum0
2392
- GGML_F32_VEC_REDUCE(sumf, sum);
2393
-
2394
- // leftovers
2395
- for (int i = np; i < n; ++i) {
2396
- sumf += x[i]*y[i];
2397
- }
2398
- #else
2399
- // scalar
2400
- ggml_float sumf = 0.0;
2401
- for (int i = 0; i < n; ++i) {
2402
- sumf += (ggml_float)(x[i]*y[i]);
2403
- }
2404
- #endif
2405
-
2406
- *s = sumf;
2407
- }
2408
-
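
ggml_vec_dot_f32 above splits the work into a SIMD main loop over np = n & ~(GGML_F32_STEP - 1) elements plus a scalar tail for the leftovers, and every vec_dot in this file follows that template. The rounding trick in isolation (valid only when STEP is a power of two):

#include <stdio.h>

// round n down to a multiple of step (step must be a power of two);
// this is the np = n & ~(GGML_F32_STEP - 1) computation from the code above
static int demo_round_down(int n, int step) {
    return n & ~(step - 1);
}

int main(void) {
    // with STEP = 16, a 100-element dot product does 96 elements in the
    // SIMD loop and the remaining 4 in the scalar leftover loop
    printf("np = %d, leftovers = %d\n",
           demo_round_down(100, 16), 100 - demo_round_down(100, 16));
    return 0;
}
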
2409
- static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t * restrict x, ggml_fp16_t * restrict y) {
2410
- ggml_float sumf = 0.0;
2411
-
2412
- #if defined(GGML_SIMD)
2413
- const int np = (n & ~(GGML_F16_STEP - 1));
2414
-
2415
- GGML_F16_VEC sum[GGML_F16_ARR] = { GGML_F16_VEC_ZERO };
2416
-
2417
- GGML_F16_VEC ax[GGML_F16_ARR];
2418
- GGML_F16_VEC ay[GGML_F16_ARR];
2419
-
2420
- for (int i = 0; i < np; i += GGML_F16_STEP) {
2421
- for (int j = 0; j < GGML_F16_ARR; j++) {
2422
- ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j);
2423
- ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
2424
-
2425
- sum[j] = GGML_F16_VEC_FMA(sum[j], ax[j], ay[j]);
2426
- }
2427
- }
2428
-
2429
- // reduce sum0..sum3 to sum0
2430
- GGML_F16_VEC_REDUCE(sumf, sum);
2431
-
2432
- // leftovers
2433
- for (int i = np; i < n; ++i) {
2434
- sumf += (ggml_float)(GGML_FP16_TO_FP32(x[i])*GGML_FP16_TO_FP32(y[i]));
2435
- }
2436
- #else
2437
- for (int i = 0; i < n; ++i) {
2438
- sumf += (ggml_float)(GGML_FP16_TO_FP32(x[i])*GGML_FP16_TO_FP32(y[i]));
2439
- }
2440
- #endif
2441
-
2442
- *s = sumf;
2443
- }
2444
-
2445
- static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
2446
- const int qk = QK8_0;
2447
- const int nb = n / qk;
2448
-
2449
- assert(n % qk == 0);
2450
-
2451
- const block_q4_0 * restrict x = vx;
2452
- const block_q8_0 * restrict y = vy;
2453
-
2454
- #if defined(__ARM_NEON)
2455
- float32x4_t sumv0 = vdupq_n_f32(0.0f);
2456
- float32x4_t sumv1 = vdupq_n_f32(0.0f);
2457
-
2458
- GGML_ASSERT(nb % 2 == 0); // TODO: handle odd nb
2459
- for (int i = 0; i < nb; i += 2) {
2460
- const block_q4_0 * restrict x0 = &x[i + 0];
2461
- const block_q4_0 * restrict x1 = &x[i + 1];
2462
- const block_q8_0 * restrict y0 = &y[i + 0];
2463
- const block_q8_0 * restrict y1 = &y[i + 1];
2464
-
2465
- const uint8x16_t m4b = vdupq_n_u8(0x0F);
2466
- const int8x16_t s8b = vdupq_n_s8(0x8);
2467
-
2468
- const uint8x16_t v0_0 = vld1q_u8(x0->qs);
2469
- const uint8x16_t v0_1 = vld1q_u8(x1->qs);
2470
-
2471
- // 4-bit -> 8-bit
2472
- const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
2473
- const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
2474
- const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
2475
- const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
2476
-
2477
- // sub 8
2478
- const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
2479
- const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
2480
- const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
2481
- const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
2482
-
2483
- // load y
2484
- const int8x16_t v1_0l = vld1q_s8(y0->qs);
2485
- const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
2486
- const int8x16_t v1_1l = vld1q_s8(y1->qs);
2487
- const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
2488
-
2489
- #if defined(__ARM_FEATURE_DOTPROD)
2490
- // dot product into int32x4_t
2491
- const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0l), v0_0hs, v1_0h);
2492
- const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1l), v0_1hs, v1_1h);
2493
-
2494
- sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
2495
- sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
2496
- #else
2497
- const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0l));
2498
- const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0l));
2499
- const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0h));
2500
- const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0h));
2501
-
2502
- const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1l));
2503
- const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1l));
2504
- const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1h));
2505
- const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1h));
2506
-
2507
- const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
2508
- const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
2509
- const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
2510
- const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
2511
-
2512
- sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
2513
- sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
2514
- #endif
2515
- }
2516
-
2517
- *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
2518
- #elif defined(__AVX2__)
2519
- // Initialize accumulator with zeros
2520
- __m256 acc = _mm256_setzero_ps();
2521
-
2522
- // Main loop
2523
- for (int i = 0; i < nb; ++i) {
2524
- /* Compute combined scale for the block */
2525
- const __m256 d = _mm256_set1_ps( GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d) );
2526
-
2527
- __m256i bx = bytes_from_nibbles_32(x[i].qs);
2528
-
2529
- // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
2530
- const __m256i off = _mm256_set1_epi8( 8 );
2531
- bx = _mm256_sub_epi8( bx, off );
2532
-
2533
- __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
2534
-
2535
- const __m256 q = mul_sum_i8_pairs_float(bx, by);
2536
-
2537
- /* Multiply q with scale and accumulate */
2538
- acc = _mm256_fmadd_ps( d, q, acc );
2539
- }
2540
-
2541
- *s = hsum_float_8(acc);
2542
- #elif defined(__AVX__)
2543
- // Initialize accumulator with zeros
2544
- __m256 acc = _mm256_setzero_ps();
2545
-
2546
- // Main loop
2547
- for (int i = 0; i < nb; ++i) {
2548
- // Compute combined scale for the block
2549
- const __m256 d = _mm256_set1_ps( GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d) );
2550
-
2551
- const __m128i lowMask = _mm_set1_epi8(0xF);
2552
- const __m128i off = _mm_set1_epi8(8);
2553
-
2554
- const __m128i tmp = _mm_loadu_si128((const __m128i *)x[i].qs);
2555
-
2556
- __m128i bx = _mm_and_si128(lowMask, tmp);
2557
- __m128i by = _mm_loadu_si128((const __m128i *)y[i].qs);
2558
- bx = _mm_sub_epi8(bx, off);
2559
- const __m128i i32_0 = mul_sum_i8_pairs(bx, by);
2560
-
2561
- bx = _mm_and_si128(lowMask, _mm_srli_epi64(tmp, 4));
2562
- by = _mm_loadu_si128((const __m128i *)(y[i].qs + 16));
2563
- bx = _mm_sub_epi8(bx, off);
2564
- const __m128i i32_1 = mul_sum_i8_pairs(bx, by);
2565
-
2566
- // Convert int32_t to float
2567
- __m256 p = _mm256_cvtepi32_ps(MM256_SET_M128I(i32_0, i32_1));
2568
-
2569
- // Apply the scale, and accumulate
2570
- acc = _mm256_add_ps(_mm256_mul_ps( d, p ), acc);
2571
- }
2572
-
2573
- *s = hsum_float_8(acc);
2574
- #elif defined(__SSSE3__)
2575
- // set constants
2576
- const __m128i lowMask = _mm_set1_epi8(0xF);
2577
- const __m128i off = _mm_set1_epi8(8);
2578
-
2579
- // Initialize accumulator with zeros
2580
- __m128 acc_0 = _mm_setzero_ps();
2581
- __m128 acc_1 = _mm_setzero_ps();
2582
- __m128 acc_2 = _mm_setzero_ps();
2583
- __m128 acc_3 = _mm_setzero_ps();
2584
-
2585
- // First round without accumulation
2586
- {
2587
- _mm_prefetch(&x[0] + sizeof(block_q4_0), _MM_HINT_T0);
2588
- _mm_prefetch(&y[0] + sizeof(block_q8_0), _MM_HINT_T0);
2589
-
2590
- // Compute combined scale for the block 0 and 1
2591
- const __m128 d_0_1 = _mm_set1_ps( GGML_FP16_TO_FP32(x[0].d) * GGML_FP16_TO_FP32(y[0].d) );
2592
-
2593
- const __m128i tmp_0_1 = _mm_loadu_si128((const __m128i *)x[0].qs);
2594
-
2595
- __m128i bx_0 = _mm_and_si128(lowMask, tmp_0_1);
2596
- __m128i by_0 = _mm_loadu_si128((const __m128i *)y[0].qs);
2597
- bx_0 = _mm_sub_epi8(bx_0, off);
2598
- const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0);
2599
-
2600
- __m128i bx_1 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_0_1, 4));
2601
- __m128i by_1 = _mm_loadu_si128((const __m128i *)(y[0].qs + 16));
2602
- bx_1 = _mm_sub_epi8(bx_1, off);
2603
- const __m128i i32_1 = mul_sum_i8_pairs(bx_1, by_1);
2604
-
2605
- _mm_prefetch(&x[1] + sizeof(block_q4_0), _MM_HINT_T0);
2606
- _mm_prefetch(&y[1] + sizeof(block_q8_0), _MM_HINT_T0);
2607
-
2608
- // Compute combined scale for the block 2 and 3
2609
- const __m128 d_2_3 = _mm_set1_ps( GGML_FP16_TO_FP32(x[1].d) * GGML_FP16_TO_FP32(y[1].d) );
2610
-
2611
- const __m128i tmp_2_3 = _mm_loadu_si128((const __m128i *)x[1].qs);
2612
-
2613
- __m128i bx_2 = _mm_and_si128(lowMask, tmp_2_3);
2614
- __m128i by_2 = _mm_loadu_si128((const __m128i *)y[1].qs);
2615
- bx_2 = _mm_sub_epi8(bx_2, off);
2616
- const __m128i i32_2 = mul_sum_i8_pairs(bx_2, by_2);
2617
-
2618
- __m128i bx_3 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_2_3, 4));
2619
- __m128i by_3 = _mm_loadu_si128((const __m128i *)(y[1].qs + 16));
2620
- bx_3 = _mm_sub_epi8(bx_3, off);
2621
- const __m128i i32_3 = mul_sum_i8_pairs(bx_3, by_3);
2622
-
2623
- // Convert int32_t to float
2624
- __m128 p0 = _mm_cvtepi32_ps(i32_0);
2625
- __m128 p1 = _mm_cvtepi32_ps(i32_1);
2626
- __m128 p2 = _mm_cvtepi32_ps(i32_2);
2627
- __m128 p3 = _mm_cvtepi32_ps(i32_3);
2628
-
2629
- // Apply the scale
2630
- acc_0 = _mm_mul_ps( d_0_1, p0 );
2631
- acc_1 = _mm_mul_ps( d_0_1, p1 );
2632
- acc_2 = _mm_mul_ps( d_2_3, p2 );
2633
- acc_3 = _mm_mul_ps( d_2_3, p3 );
2634
- }
2635
-
2636
- // Main loop
2637
- GGML_ASSERT(nb % 2 == 0); // TODO: handle odd nb
2638
- for (int i = 2; i < nb; i+=2) {
2639
- _mm_prefetch(&x[i] + sizeof(block_q4_0), _MM_HINT_T0);
2640
- _mm_prefetch(&y[i] + sizeof(block_q8_0), _MM_HINT_T0);
2641
-
2642
- // Compute combined scale for the block 0 and 1
2643
- const __m128 d_0_1 = _mm_set1_ps( GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d) );
2644
-
2645
- const __m128i tmp_0_1 = _mm_loadu_si128((const __m128i *)x[i].qs);
2646
-
2647
- __m128i bx_0 = _mm_and_si128(lowMask, tmp_0_1);
2648
- __m128i by_0 = _mm_loadu_si128((const __m128i *)y[i].qs);
2649
- bx_0 = _mm_sub_epi8(bx_0, off);
2650
- const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0);
2651
-
2652
- __m128i bx_1 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_0_1, 4));
2653
- __m128i by_1 = _mm_loadu_si128((const __m128i *)(y[i].qs + 16));
2654
- bx_1 = _mm_sub_epi8(bx_1, off);
2655
- const __m128i i32_1 = mul_sum_i8_pairs(bx_1, by_1);
2656
-
2657
- _mm_prefetch(&x[i] + 2 * sizeof(block_q4_0), _MM_HINT_T0);
2658
- _mm_prefetch(&y[i] + 2 * sizeof(block_q8_0), _MM_HINT_T0);
2659
-
2660
- // Compute combined scale for the block 2 and 3
2661
- const __m128 d_2_3 = _mm_set1_ps( GGML_FP16_TO_FP32(x[i + 1].d) * GGML_FP16_TO_FP32(y[i + 1].d) );
2662
-
2663
- const __m128i tmp_2_3 = _mm_loadu_si128((const __m128i *)x[i + 1].qs);
2664
-
2665
- __m128i bx_2 = _mm_and_si128(lowMask, tmp_2_3);
2666
- __m128i by_2 = _mm_loadu_si128((const __m128i *)y[i + 1].qs);
2667
- bx_2 = _mm_sub_epi8(bx_2, off);
2668
- const __m128i i32_2 = mul_sum_i8_pairs(bx_2, by_2);
2669
-
2670
- __m128i bx_3 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_2_3, 4));
2671
- __m128i by_3 = _mm_loadu_si128((const __m128i *)(y[i + 1].qs + 16));
2672
- bx_3 = _mm_sub_epi8(bx_3, off);
2673
- const __m128i i32_3 = mul_sum_i8_pairs(bx_3, by_3);
2674
-
2675
- // Convert int32_t to float
2676
- __m128 p0 = _mm_cvtepi32_ps(i32_0);
2677
- __m128 p1 = _mm_cvtepi32_ps(i32_1);
2678
- __m128 p2 = _mm_cvtepi32_ps(i32_2);
2679
- __m128 p3 = _mm_cvtepi32_ps(i32_3);
2680
-
2681
- // Apply the scale
2682
- __m128 p0_d = _mm_mul_ps( d_0_1, p0 );
2683
- __m128 p1_d = _mm_mul_ps( d_0_1, p1 );
2684
- __m128 p2_d = _mm_mul_ps( d_2_3, p2 );
2685
- __m128 p3_d = _mm_mul_ps( d_2_3, p3 );
2686
-
2687
- // Accumulate
2688
- acc_0 = _mm_add_ps(p0_d, acc_0);
2689
- acc_1 = _mm_add_ps(p1_d, acc_1);
2690
- acc_2 = _mm_add_ps(p2_d, acc_2);
2691
- acc_3 = _mm_add_ps(p3_d, acc_3);
2692
- }
2693
-
2694
- *s = hsum_float_4x4(acc_0, acc_1, acc_2, acc_3);
2695
- #elif defined(__riscv_v_intrinsic)
2696
- float sumf = 0.0;
2697
-
2698
- size_t vl = __riscv_vsetvl_e8m1(qk/2);
2699
-
2700
- for (int i = 0; i < nb; i++) {
2701
- // load elements
2702
- vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[i].qs, vl);
2703
-
2704
- vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[i].qs, vl);
2705
- vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[i].qs+16, vl);
2706
-
2707
- // mask and store lower part of x, and then upper part
2708
- vuint8mf2_t x_a = __riscv_vand_vx_u8mf2(tx, 0x0F, vl);
2709
- vuint8mf2_t x_l = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl);
2710
-
2711
- vint8mf2_t x_ai = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a);
2712
- vint8mf2_t x_li = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l);
2713
-
2714
- // subtract offset
2715
- vint8mf2_t v0 = __riscv_vsub_vx_i8mf2(x_ai, 8, vl);
2716
- vint8mf2_t v1 = __riscv_vsub_vx_i8mf2(x_li, 8, vl);
2717
-
2718
- vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl);
2719
- vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl);
2720
-
2721
- vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
2722
-
2723
- vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl);
2724
- vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl);
2725
-
2726
- int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);
2727
-
2728
- sumf += sumi*GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d);
2729
- }
560
+ //
561
+ // simd mappings
562
+ //
2730
563
 
2731
- *s = sumf;
2732
- #else
2733
- // scalar
2734
- float sumf = 0.0;
564
+ // we define a common set of C macros which map to specific intrinsics based on the current architecture
565
+ // we then implement the fundamental computation operations below using only these macros
566
+ // adding support for new architectures requires defining the corresponding SIMD macros
567
+ //
568
+ // GGML_F32_STEP / GGML_F16_STEP
569
+ // number of elements to process in a single step
570
+ //
571
+ // GGML_F32_EPR / GGML_F16_EPR
572
+ // number of elements to fit in a single register
573
+ //
2735
574
 
2736
- for (int i = 0; i < nb; i++) {
2737
- int sumi = 0;
575
+ #if defined(__ARM_NEON) && defined(__ARM_FEATURE_FMA)
2738
576
 
2739
- for (int j = 0; j < qk/2; ++j) {
2740
- const int v0 = (x[i].qs[j] & 0x0F) - 8;
2741
- const int v1 = (x[i].qs[j] >> 4) - 8;
577
+ #define GGML_SIMD
2742
578
 
2743
- sumi += (v0 * y[i].qs[j]) + (v1 * y[i].qs[j + qk/2]);
2744
- }
579
+ // F32 NEON
2745
580
 
2746
- sumf += sumi*GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d);
2747
- }
581
+ #define GGML_F32_STEP 16
582
+ #define GGML_F32_EPR 4
2748
583
 
2749
- *s = sumf;
2750
- #endif
584
+ #define GGML_F32x4 float32x4_t
585
+ #define GGML_F32x4_ZERO vdupq_n_f32(0.0f)
586
+ #define GGML_F32x4_SET1(x) vdupq_n_f32(x)
587
+ #define GGML_F32x4_LOAD vld1q_f32
588
+ #define GGML_F32x4_STORE vst1q_f32
589
+ #define GGML_F32x4_FMA(a, b, c) vfmaq_f32(a, b, c)
590
+ #define GGML_F32x4_ADD vaddq_f32
591
+ #define GGML_F32x4_MUL vmulq_f32
592
+ #define GGML_F32x4_REDUCE_ONE(x) vaddvq_f32(x)
593
+ #define GGML_F32x4_REDUCE(res, x) \
594
+ { \
595
+ int offset = GGML_F32_ARR >> 1; \
596
+ for (int i = 0; i < offset; ++i) { \
597
+ x[i] = vaddq_f32(x[i], x[offset+i]); \
598
+ } \
599
+ offset >>= 1; \
600
+ for (int i = 0; i < offset; ++i) { \
601
+ x[i] = vaddq_f32(x[i], x[offset+i]); \
602
+ } \
603
+ offset >>= 1; \
604
+ for (int i = 0; i < offset; ++i) { \
605
+ x[i] = vaddq_f32(x[i], x[offset+i]); \
606
+ } \
607
+ res = GGML_F32x4_REDUCE_ONE(x[0]); \
2751
608
  }
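
Every architecture branch above computes the same quantity as the scalar fallback: one integer dot product per block, scaled once by the product of the two block scales, i.e. d_x*d_y*sum(vx*vy). That factoring is what lets the SIMD paths stay in int8/int16 arithmetic until the very end. A self-contained scalar model, with simplified structs (plain float scales) and hypothetical demo_* names:

#include <stdint.h>
#include <stdio.h>

#define DQK 32 // block size shared by the q4_0 and q8_0 stand-ins

typedef struct { float d; uint8_t qs[DQK/2]; } demo_q4_0; // packed nibbles
typedef struct { float d; int8_t  qs[DQK];   } demo_q8_0; // plain int8

// per block: accumulate the integer dot product first, then apply the
// combined scale once -- the same factoring the SIMD branches rely on
static float demo_dot_q4_0_q8_0(const demo_q4_0 *x, const demo_q8_0 *y, int nb) {
    float sumf = 0.0f;
    for (int i = 0; i < nb; i++) {
        int sumi = 0;
        for (int j = 0; j < DQK/2; ++j) {
            const int v0 = (x[i].qs[j] & 0x0F) - 8;
            const int v1 = (x[i].qs[j] >>   4) - 8;
            sumi += v0 * y[i].qs[j] + v1 * y[i].qs[j + DQK/2];
        }
        sumf += sumi * x[i].d * y[i].d;
    }
    return sumf;
}

int main(void) {
    demo_q4_0 x = { .d = 0.5f };   // remaining qs bytes are zeroed
    demo_q8_0 y = { .d = 2.0f };
    x.qs[0] = (8 << 4) | 10;       // elements: 10-8 = 2 and 8-8 = 0
    y.qs[0] = 3;                   // pairs with the low nibble
    printf("%f\n", demo_dot_q4_0_q8_0(&x, &y, 1)); // 2*3 * 0.5 * 2 = 6.0
    return 0;
}
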
2752
609
 
2753
- static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
2754
- const int qk = QK8_1;
2755
- const int nb = n / qk;
2756
-
2757
- assert(n % qk == 0);
2758
-
2759
- const block_q4_1 * restrict x = vx;
2760
- const block_q8_1 * restrict y = vy;
2761
-
2762
- // TODO: add WASM SIMD
2763
- #if defined(__ARM_NEON)
2764
- float32x4_t sumv0 = vdupq_n_f32(0.0f);
2765
- float32x4_t sumv1 = vdupq_n_f32(0.0f);
2766
-
2767
- float summs = 0;
2768
-
2769
- GGML_ASSERT(nb % 2 == 0); // TODO: handle odd nb
2770
- for (int i = 0; i < nb; i += 2) {
2771
- const block_q4_1 * restrict x0 = &x[i + 0];
2772
- const block_q4_1 * restrict x1 = &x[i + 1];
2773
- const block_q8_1 * restrict y0 = &y[i + 0];
2774
- const block_q8_1 * restrict y1 = &y[i + 1];
2775
-
2776
- summs += GGML_FP16_TO_FP32(x0->m) * y0->s + GGML_FP16_TO_FP32(x1->m) * y1->s;
2777
-
2778
- const uint8x16_t m4b = vdupq_n_u8(0x0F);
2779
-
2780
- const uint8x16_t v0_0 = vld1q_u8(x0->qs);
2781
- const uint8x16_t v0_1 = vld1q_u8(x1->qs);
610
+ #define GGML_F32_VEC GGML_F32x4
611
+ #define GGML_F32_VEC_ZERO GGML_F32x4_ZERO
612
+ #define GGML_F32_VEC_SET1 GGML_F32x4_SET1
613
+ #define GGML_F32_VEC_LOAD GGML_F32x4_LOAD
614
+ #define GGML_F32_VEC_STORE GGML_F32x4_STORE
615
+ #define GGML_F32_VEC_FMA GGML_F32x4_FMA
616
+ #define GGML_F32_VEC_ADD GGML_F32x4_ADD
617
+ #define GGML_F32_VEC_MUL GGML_F32x4_MUL
618
+ #define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE
2782
619
 
2783
- // 4-bit -> 8-bit
2784
- const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
2785
- const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
2786
- const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
2787
- const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
620
+ // F16 NEON
2788
621
 
2789
- // load y
2790
- const int8x16_t v1_0l = vld1q_s8(y0->qs);
2791
- const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
2792
- const int8x16_t v1_1l = vld1q_s8(y1->qs);
2793
- const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
622
+ #if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
623
+ #define GGML_F16_STEP 32
624
+ #define GGML_F16_EPR 8
2794
625
 
2795
- #if defined(__ARM_FEATURE_DOTPROD)
2796
- // dot product into int32x4_t
2797
- const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0l), v0_0h, v1_0h);
2798
- const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1l), v0_1h, v1_1h);
626
+ #define GGML_F16x8 float16x8_t
627
+ #define GGML_F16x8_ZERO vdupq_n_f16(0.0f)
628
+ #define GGML_F16x8_SET1(x) vdupq_n_f16(x)
629
+ #define GGML_F16x8_LOAD vld1q_f16
630
+ #define GGML_F16x8_STORE vst1q_f16
631
+ #define GGML_F16x8_FMA(a, b, c) vfmaq_f16(a, b, c)
632
+ #define GGML_F16x8_ADD vaddq_f16
633
+ #define GGML_F16x8_MUL vmulq_f16
634
+ #define GGML_F16x8_REDUCE(res, x) \
635
+ do { \
636
+ int offset = GGML_F16_ARR >> 1; \
637
+ for (int i = 0; i < offset; ++i) { \
638
+ x[i] = vaddq_f16(x[i], x[offset+i]); \
639
+ } \
640
+ offset >>= 1; \
641
+ for (int i = 0; i < offset; ++i) { \
642
+ x[i] = vaddq_f16(x[i], x[offset+i]); \
643
+ } \
644
+ offset >>= 1; \
645
+ for (int i = 0; i < offset; ++i) { \
646
+ x[i] = vaddq_f16(x[i], x[offset+i]); \
647
+ } \
648
+ const float32x4_t t0 = vcvt_f32_f16(vget_low_f16 (x[0])); \
649
+ const float32x4_t t1 = vcvt_f32_f16(vget_high_f16(x[0])); \
650
+ res = (ggml_float) vaddvq_f32(vaddq_f32(t0, t1)); \
651
+ } while (0)
2799
652
 
2800
- sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), GGML_FP16_TO_FP32(x0->d)*y0->d);
2801
- sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), GGML_FP16_TO_FP32(x1->d)*y1->d);
653
+ #define GGML_F16_VEC GGML_F16x8
654
+ #define GGML_F16_VEC_ZERO GGML_F16x8_ZERO
655
+ #define GGML_F16_VEC_SET1 GGML_F16x8_SET1
656
+ #define GGML_F16_VEC_LOAD(p, i) GGML_F16x8_LOAD(p)
657
+ #define GGML_F16_VEC_STORE(p, r, i) GGML_F16x8_STORE(p, r[i])
658
+ #define GGML_F16_VEC_FMA GGML_F16x8_FMA
659
+ #define GGML_F16_VEC_ADD GGML_F16x8_ADD
660
+ #define GGML_F16_VEC_MUL GGML_F16x8_MUL
661
+ #define GGML_F16_VEC_REDUCE GGML_F16x8_REDUCE
2802
662
  #else
2803
- const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0l), vget_low_s8 (v1_0l));
2804
- const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0l), vget_high_s8(v1_0l));
2805
- const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0h), vget_low_s8 (v1_0h));
2806
- const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0h), vget_high_s8(v1_0h));
2807
-
2808
- const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1l), vget_low_s8 (v1_1l));
2809
- const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1l), vget_high_s8(v1_1l));
2810
- const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1h), vget_low_s8 (v1_1h));
2811
- const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1h), vget_high_s8(v1_1h));
2812
-
2813
- const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
2814
- const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
2815
- const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
2816
- const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
2817
-
2818
- sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), GGML_FP16_TO_FP32(x0->d)*y0->d);
2819
- sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), GGML_FP16_TO_FP32(x1->d)*y1->d);
2820
- #endif
2821
- }
2822
-
2823
- *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs;
2824
- #elif defined(__AVX2__) || defined(__AVX__)
2825
- // Initialize accumulator with zeros
2826
- __m256 acc = _mm256_setzero_ps();
2827
-
2828
- float summs = 0;
2829
-
2830
- // Main loop
2831
- for (int i = 0; i < nb; ++i) {
2832
- const float d0 = GGML_FP16_TO_FP32(x[i].d);
2833
- const float d1 = y[i].d;
2834
-
2835
- summs += GGML_FP16_TO_FP32(x[i].m) * y[i].s;
2836
-
2837
- const __m256 d0v = _mm256_set1_ps( d0 );
2838
- const __m256 d1v = _mm256_set1_ps( d1 );
2839
-
2840
- // Compute combined scales
2841
- const __m256 d0d1 = _mm256_mul_ps( d0v, d1v );
663
+ // if FP16 vector arithmetic is not supported, we use FP32 instead
664
+ // and take advantage of the vcvt_ functions to convert to/from FP16
2842
665
 
2843
- // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
2844
- const __m256i bx = bytes_from_nibbles_32(x[i].qs);
2845
- const __m256i by = _mm256_loadu_si256( (const __m256i *)y[i].qs );
666
+ #define GGML_F16_STEP 16
667
+ #define GGML_F16_EPR 4
2846
668
 
2847
- const __m256 xy = mul_sum_us8_pairs_float(bx, by);
669
+ #define GGML_F32Cx4 float32x4_t
670
+ #define GGML_F32Cx4_ZERO vdupq_n_f32(0.0f)
671
+ #define GGML_F32Cx4_SET1(x) vdupq_n_f32(x)
672
+ #define GGML_F32Cx4_LOAD(x) vcvt_f32_f16(vld1_f16(x))
673
+ #define GGML_F32Cx4_STORE(x, y) vst1_f16(x, vcvt_f16_f32(y))
674
+ #define GGML_F32Cx4_FMA(a, b, c) vfmaq_f32(a, b, c)
675
+ #define GGML_F32Cx4_ADD vaddq_f32
676
+ #define GGML_F32Cx4_MUL vmulq_f32
677
+ #define GGML_F32Cx4_REDUCE GGML_F32x4_REDUCE
2848
678
 
2849
- // Accumulate d0*d1*x*y
2850
- #if defined(__AVX2__)
2851
- acc = _mm256_fmadd_ps( d0d1, xy, acc );
2852
- #else
2853
- acc = _mm256_add_ps( _mm256_mul_ps( d0d1, xy ), acc );
679
+ #define GGML_F16_VEC GGML_F32Cx4
680
+ #define GGML_F16_VEC_ZERO GGML_F32Cx4_ZERO
681
+ #define GGML_F16_VEC_SET1 GGML_F32Cx4_SET1
682
+ #define GGML_F16_VEC_LOAD(p, i) GGML_F32Cx4_LOAD(p)
683
+ #define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx4_STORE(p, r[i])
684
+ #define GGML_F16_VEC_FMA GGML_F32Cx4_FMA
685
+ #define GGML_F16_VEC_ADD GGML_F32Cx4_ADD
686
+ #define GGML_F16_VEC_MUL GGML_F32Cx4_MUL
687
+ #define GGML_F16_VEC_REDUCE GGML_F32Cx4_REDUCE
2854
688
  #endif
2855
- }
2856
-
2857
- *s = hsum_float_8(acc) + summs;
2858
- #elif defined(__riscv_v_intrinsic)
2859
- float sumf = 0.0;
2860
-
2861
- size_t vl = __riscv_vsetvl_e8m1(qk/2);
2862
-
2863
- for (int i = 0; i < nb; i++) {
2864
- // load elements
2865
- vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[i].qs, vl);
2866
-
2867
- vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[i].qs, vl);
2868
- vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[i].qs+16, vl);
2869
-
2870
- // mask and store lower part of x, and then upper part
2871
- vuint8mf2_t x_a = __riscv_vand_vx_u8mf2(tx, 0x0F, vl);
2872
- vuint8mf2_t x_l = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl);
2873
-
2874
- vint8mf2_t v0 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a);
2875
- vint8mf2_t v1 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l);
2876
-
2877
- vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl);
2878
- vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl);
2879
-
2880
- vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
2881
-
2882
- vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl);
2883
- vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl);
2884
-
2885
- int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);
2886
-
2887
- sumf += (GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi + GGML_FP16_TO_FP32(x[i].m)*y[i].s;
2888
- }
2889
-
2890
- *s = sumf;
2891
- #else
2892
- // scalar
2893
- float sumf = 0.0;
2894
-
2895
- for (int i = 0; i < nb; i++) {
2896
- int sumi = 0;
2897
-
2898
- for (int j = 0; j < qk/2; ++j) {
2899
- const int v0 = (x[i].qs[j] & 0x0F);
2900
- const int v1 = (x[i].qs[j] >> 4);
2901
-
2902
- sumi += (v0 * y[i].qs[j]) + (v1 * y[i].qs[j + qk/2]);
2903
- }
2904
-
2905
- sumf += (GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi + GGML_FP16_TO_FP32(x[i].m)*y[i].s;
2906
- }
2907
689
 
2908
- *s = sumf;
2909
- #endif
2910
- }
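
The q4_1 path handles the block minimum algebraically: since each element is qx*d_x + m_x, the block dot product splits into d_x*d_y*sum(qx*qy) + m_x*sum(y), and q8_1 pre-stores d_y*sum(qy) in its s field, so the minimum costs one multiply-add per block (the summs accumulation above). As a sketch of that per-block decomposition:

#include <stdio.h>

// per-block q4_1 x q8_1 contribution, given the precomputed pieces:
//   sumi = integer dot of the unsigned nibbles with the int8 quants
//   s_y  = d_y * sum(y quants), stored in the q8_1 block's s field
// because x_j = qx_j*d_x + m_x, the block dot product factors as:
//   sum(x_j*y_j) = d_x*d_y*sumi + m_x*s_y
static float demo_q4_1_block(int sumi, float d_x, float d_y, float m_x, float s_y) {
    return d_x*d_y*sumi + m_x*s_y;
}

int main(void) {
    // tiny worked example: sumi = 6, scales 0.5 and 2.0, min 1.0, s_y = 4.0
    printf("%f\n", demo_q4_1_block(6, 0.5f, 2.0f, 1.0f, 4.0f)); // 10.0
    return 0;
}
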
690
+ #elif defined(__AVX__)
2911
691
 
2912
- static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
2913
- const int qk = QK8_0;
2914
- const int nb = n / qk;
692
+ #define GGML_SIMD
2915
693
 
2916
- assert(n % qk == 0);
2917
- assert(qk == QK5_0);
694
+ // F32 AVX
2918
695
 
2919
- const block_q5_0 * restrict x = vx;
2920
- const block_q8_0 * restrict y = vy;
696
+ #define GGML_F32_STEP 32
697
+ #define GGML_F32_EPR 8
2921
698
 
2922
- #if defined(__ARM_NEON)
2923
- float32x4_t sumv0 = vdupq_n_f32(0.0f);
2924
- float32x4_t sumv1 = vdupq_n_f32(0.0f);
2925
-
2926
- uint32_t qh0;
2927
- uint32_t qh1;
2928
-
2929
- uint64_t tmp0[4];
2930
- uint64_t tmp1[4];
2931
-
2932
- GGML_ASSERT(nb % 2 == 0); // TODO: handle odd nb
2933
- for (int i = 0; i < nb; i += 2) {
2934
- const block_q5_0 * restrict x0 = &x[i];
2935
- const block_q5_0 * restrict x1 = &x[i + 1];
2936
- const block_q8_0 * restrict y0 = &y[i];
2937
- const block_q8_0 * restrict y1 = &y[i + 1];
2938
-
2939
- const uint8x16_t m4b = vdupq_n_u8(0x0F);
2940
-
2941
- // extract the 5th bit via lookup table ((!b) << 4)
2942
- memcpy(&qh0, x0->qh, sizeof(qh0));
2943
- memcpy(&qh1, x1->qh, sizeof(qh1));
2944
-
2945
- tmp0[0] = table_b2b_1[(qh0 >> 0) & 0xFF];
2946
- tmp0[1] = table_b2b_1[(qh0 >> 8) & 0xFF];
2947
- tmp0[2] = table_b2b_1[(qh0 >> 16) & 0xFF];
2948
- tmp0[3] = table_b2b_1[(qh0 >> 24) ];
2949
-
2950
- tmp1[0] = table_b2b_1[(qh1 >> 0) & 0xFF];
2951
- tmp1[1] = table_b2b_1[(qh1 >> 8) & 0xFF];
2952
- tmp1[2] = table_b2b_1[(qh1 >> 16) & 0xFF];
2953
- tmp1[3] = table_b2b_1[(qh1 >> 24) ];
2954
-
2955
- const int8x16_t qhl0 = vld1q_s8((const int8_t *)(tmp0 + 0));
2956
- const int8x16_t qhh0 = vld1q_s8((const int8_t *)(tmp0 + 2));
2957
- const int8x16_t qhl1 = vld1q_s8((const int8_t *)(tmp1 + 0));
2958
- const int8x16_t qhh1 = vld1q_s8((const int8_t *)(tmp1 + 2));
2959
-
2960
- const uint8x16_t v0_0 = vld1q_u8(x0->qs);
2961
- const uint8x16_t v0_1 = vld1q_u8(x1->qs);
2962
-
2963
- // 4-bit -> 8-bit
2964
- int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
2965
- int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
2966
- int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
2967
- int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
2968
-
2969
- // add high bit and sub 16 (equivalent to sub 0x10 when bit is zero)
2970
- const int8x16_t v0_0lf = vsubq_s8(v0_0l, qhl0);
2971
- const int8x16_t v0_0hf = vsubq_s8(v0_0h, qhh0);
2972
- const int8x16_t v0_1lf = vsubq_s8(v0_1l, qhl1);
2973
- const int8x16_t v0_1hf = vsubq_s8(v0_1h, qhh1);
2974
-
2975
- // load y
2976
- const int8x16_t v1_0l = vld1q_s8(y0->qs);
2977
- const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
2978
- const int8x16_t v1_1l = vld1q_s8(y1->qs);
2979
- const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
2980
-
2981
- #if defined(__ARM_FEATURE_DOTPROD)
2982
- sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
2983
- vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l),
2984
- vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
2985
- sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
2986
- vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l),
2987
- vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
699
+ #define GGML_F32x8 __m256
700
+ #define GGML_F32x8_ZERO _mm256_setzero_ps()
701
+ #define GGML_F32x8_SET1(x) _mm256_set1_ps(x)
702
+ #define GGML_F32x8_LOAD _mm256_loadu_ps
703
+ #define GGML_F32x8_STORE _mm256_storeu_ps
704
+ #if defined(__FMA__)
705
+ #define GGML_F32x8_FMA(a, b, c) _mm256_fmadd_ps(b, c, a)
2988
706
  #else
2989
- const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lf), vget_low_s8 (v1_0l));
2990
- const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lf), vget_high_s8(v1_0l));
2991
- const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hf), vget_low_s8 (v1_0h));
2992
- const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hf), vget_high_s8(v1_0h));
2993
-
2994
- const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lf), vget_low_s8 (v1_1l));
2995
- const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lf), vget_high_s8(v1_1l));
2996
- const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hf), vget_low_s8 (v1_1h));
2997
- const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hf), vget_high_s8(v1_1h));
2998
-
2999
- const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
3000
- const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
3001
- const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
3002
- const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
3003
-
3004
- sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
3005
- sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
707
+ #define GGML_F32x8_FMA(a, b, c) _mm256_add_ps(_mm256_mul_ps(b, c), a)
3006
708
  #endif
3007
- }
3008
-
3009
- *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
3010
- #elif defined(__wasm_simd128__)
3011
- v128_t sumv = wasm_f32x4_splat(0.0f);
3012
-
3013
- uint32_t qh;
3014
- uint64_t tmp[4];
3015
-
3016
- // TODO: check if unrolling this is better
3017
- for (int i = 0; i < nb; ++i) {
3018
- const block_q5_0 * restrict x0 = &x[i];
3019
- const block_q8_0 * restrict y0 = &y[i];
3020
-
3021
- const v128_t m4b = wasm_i8x16_splat(0x0F);
3022
-
3023
- // extract the 5th bit
3024
- memcpy(&qh, x0->qh, sizeof(qh));
3025
-
3026
- tmp[0] = table_b2b_1[(qh >> 0) & 0xFF];
3027
- tmp[1] = table_b2b_1[(qh >> 8) & 0xFF];
3028
- tmp[2] = table_b2b_1[(qh >> 16) & 0xFF];
3029
- tmp[3] = table_b2b_1[(qh >> 24) ];
3030
-
3031
- const v128_t qhl = wasm_v128_load(tmp + 0);
3032
- const v128_t qhh = wasm_v128_load(tmp + 2);
3033
-
3034
- const v128_t v0 = wasm_v128_load(x0->qs);
3035
-
3036
- // 4-bit -> 8-bit
3037
- const v128_t v0l = wasm_v128_and (v0, m4b);
3038
- const v128_t v0h = wasm_u8x16_shr(v0, 4);
3039
-
3040
- // add high bit and sub 16 (equivalent to sub 0x10 when bit is zero)
3041
- const v128_t v0lf = wasm_i8x16_sub(v0l, qhl);
3042
- const v128_t v0hf = wasm_i8x16_sub(v0h, qhh);
3043
-
3044
- // load y
3045
- const v128_t v1l = wasm_v128_load(y0->qs);
3046
- const v128_t v1h = wasm_v128_load(y0->qs + 16);
3047
-
3048
- // int8x16 -> int16x8
3049
- const v128_t v0lfl = wasm_i16x8_extend_low_i8x16 (v0lf);
3050
- const v128_t v0lfh = wasm_i16x8_extend_high_i8x16(v0lf);
3051
- const v128_t v0hfl = wasm_i16x8_extend_low_i8x16 (v0hf);
3052
- const v128_t v0hfh = wasm_i16x8_extend_high_i8x16(v0hf);
3053
-
3054
- const v128_t v1ll = wasm_i16x8_extend_low_i8x16 (v1l);
3055
- const v128_t v1lh = wasm_i16x8_extend_high_i8x16(v1l);
3056
- const v128_t v1hl = wasm_i16x8_extend_low_i8x16 (v1h);
3057
- const v128_t v1hh = wasm_i16x8_extend_high_i8x16(v1h);
3058
-
3059
- // dot product
3060
- sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4(
3061
- wasm_i32x4_add(
3062
- wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0lfl, v1ll),
3063
- wasm_i32x4_dot_i16x8(v0lfh, v1lh)),
3064
- wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0hfl, v1hl),
3065
- wasm_i32x4_dot_i16x8(v0hfh, v1hh)))),
3066
- wasm_f32x4_splat(GGML_FP16_TO_FP32(x0->d) * GGML_FP16_TO_FP32(y0->d))));
3067
- }
3068
-
3069
- *s = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) +
3070
- wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3);
3071
- #elif defined(__AVX2__)
3072
- // Initialize accumulator with zeros
3073
- __m256 acc = _mm256_setzero_ps();
3074
-
3075
- // Main loop
3076
- for (int i = 0; i < nb; i++) {
3077
- /* Compute combined scale for the block */
3078
- const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d));
3079
-
3080
- __m256i bx = bytes_from_nibbles_32(x[i].qs);
3081
- __m256i bxhi = bytes_from_bits_32(x[i].qh);
3082
- bxhi = _mm256_andnot_si256(bxhi, _mm256_set1_epi8((char)0xF0));
3083
- bx = _mm256_or_si256(bx, bxhi);
3084
-
3085
- __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
3086
-
3087
- const __m256 q = mul_sum_i8_pairs_float(bx, by);
709
+ #define GGML_F32x8_ADD _mm256_add_ps
710
+ #define GGML_F32x8_MUL _mm256_mul_ps
711
+ #define GGML_F32x8_REDUCE(res, x) \
712
+ do { \
713
+ int offset = GGML_F32_ARR >> 1; \
714
+ for (int i = 0; i < offset; ++i) { \
715
+ x[i] = _mm256_add_ps(x[i], x[offset+i]); \
716
+ } \
717
+ offset >>= 1; \
718
+ for (int i = 0; i < offset; ++i) { \
719
+ x[i] = _mm256_add_ps(x[i], x[offset+i]); \
720
+ } \
721
+ offset >>= 1; \
722
+ for (int i = 0; i < offset; ++i) { \
723
+ x[i] = _mm256_add_ps(x[i], x[offset+i]); \
724
+ } \
725
+ const __m128 t0 = _mm_add_ps(_mm256_castps256_ps128(x[0]), \
726
+ _mm256_extractf128_ps(x[0], 1)); \
727
+ const __m128 t1 = _mm_hadd_ps(t0, t0); \
728
+ res = _mm_cvtss_f32(_mm_hadd_ps(t1, t1)); \
729
+ } while (0)
730
+ // TODO: is this optimal ?
3088
731
 
3089
- /* Multiply q with scale and accumulate */
3090
- acc = _mm256_fmadd_ps(d, q, acc);
3091
- }
732
+ #define GGML_F32_VEC GGML_F32x8
733
+ #define GGML_F32_VEC_ZERO GGML_F32x8_ZERO
734
+ #define GGML_F32_VEC_SET1 GGML_F32x8_SET1
735
+ #define GGML_F32_VEC_LOAD GGML_F32x8_LOAD
736
+ #define GGML_F32_VEC_STORE GGML_F32x8_STORE
737
+ #define GGML_F32_VEC_FMA GGML_F32x8_FMA
738
+ #define GGML_F32_VEC_ADD GGML_F32x8_ADD
739
+ #define GGML_F32_VEC_MUL GGML_F32x8_MUL
740
+ #define GGML_F32_VEC_REDUCE GGML_F32x8_REDUCE
3092
741
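Note on the AVX F32 block above: GGML_F32x8_FMA maps FMA(a, b, c) to _mm256_fmadd_ps(b, c, a), i.e. b*c + a, and the non-FMA fallback keeps the same operand order, so FMA(acc, x, y) always accumulates x*y into acc. GGML_F32x8_REDUCE folds the GGML_F32_ARR accumulators pairwise before one horizontal sum; its three fixed halving loops cover up to eight registers, so for AVX (32/8 = 4 accumulators) the last loop is a no-op. A scalar model of the same reduction tree, as a sketch rather than anything from the package source:

    // scalar model of GGML_F32x8_REDUCE for GGML_F32_ARR == 4 (illustrative)
    static float model_f32x8_reduce(float x[4][8]) {
        for (int offset = 4 >> 1; offset > 0; offset >>= 1) { // fold 4 -> 2 -> 1 registers
            for (int i = 0; i < offset; ++i) {
                for (int l = 0; l < 8; ++l) {
                    x[i][l] += x[offset + i][l];
                }
            }
        }
        float res = 0.0f;
        for (int l = 0; l < 8; ++l) { // horizontal sum of the surviving register
            res += x[0][l];
        }
        return res;
    }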
 
3093
- *s = hsum_float_8(acc);
3094
- #elif defined(__AVX__)
3095
- // Initialize accumulator with zeros
3096
- __m256 acc = _mm256_setzero_ps();
3097
- __m128i mask = _mm_set1_epi8((char)0xF0);
742
+ // F16 AVX
3098
743
 
3099
- // Main loop
3100
- for (int i = 0; i < nb; i++) {
3101
- /* Compute combined scale for the block */
3102
- const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d));
744
+ #define GGML_F16_STEP 32
745
+ #define GGML_F16_EPR 8
3103
746
 
3104
- __m256i bx = bytes_from_nibbles_32(x[i].qs);
3105
- const __m256i bxhi = bytes_from_bits_32(x[i].qh);
3106
- __m128i bxhil = _mm256_castsi256_si128(bxhi);
3107
- __m128i bxhih = _mm256_extractf128_si256(bxhi, 1);
3108
- bxhil = _mm_andnot_si128(bxhil, mask);
3109
- bxhih = _mm_andnot_si128(bxhih, mask);
3110
- __m128i bxl = _mm256_castsi256_si128(bx);
3111
- __m128i bxh = _mm256_extractf128_si256(bx, 1);
3112
- bxl = _mm_or_si128(bxl, bxhil);
3113
- bxh = _mm_or_si128(bxh, bxhih);
3114
- bx = MM256_SET_M128I(bxh, bxl);
747
+ // F16 arithmetic is not supported by AVX, so we use F32 instead
3115
748
 
3116
- const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
749
+ #define GGML_F32Cx8 __m256
750
+ #define GGML_F32Cx8_ZERO _mm256_setzero_ps()
751
+ #define GGML_F32Cx8_SET1(x) _mm256_set1_ps(x)
3117
752
 
3118
- const __m256 q = mul_sum_i8_pairs_float(bx, by);
753
+ #if defined(__F16C__)
754
+ // the _mm256_cvt intrinsics require F16C
755
+ #define GGML_F32Cx8_LOAD(x) _mm256_cvtph_ps(_mm_loadu_si128((__m128i *)(x)))
756
+ #define GGML_F32Cx8_STORE(x, y) _mm_storeu_si128((__m128i *)(x), _mm256_cvtps_ph(y, 0))
757
+ #else
758
+ static inline __m256 __avx_f32cx8_load(ggml_fp16_t *x) {
759
+ float tmp[8];
3119
760
 
3120
- /* Multiply q with scale and accumulate */
3121
- acc = _mm256_add_ps(_mm256_mul_ps(d, q), acc);
761
+ for (int i = 0; i < 8; i++) {
762
+ tmp[i] = GGML_FP16_TO_FP32(x[i]);
3122
763
  }
3123
764
 
3124
- *s = hsum_float_8(acc);
3125
- #elif defined(__riscv_v_intrinsic)
3126
- float sumf = 0.0;
765
+ return _mm256_loadu_ps(tmp);
766
+ }
767
+ static inline void __avx_f32cx8_store(ggml_fp16_t *x, __m256 y) {
768
+ float arr[8];
3127
769
 
3128
- uint32_t qh;
770
+ _mm256_storeu_ps(arr, y);
3129
771
 
3130
- size_t vl = __riscv_vsetvl_e8m1(qk/2);
772
+ for (int i = 0; i < 8; i++)
773
+ x[i] = GGML_FP32_TO_FP16(arr[i]);
774
+ }
775
+ #define GGML_F32Cx8_LOAD(x) __avx_f32cx8_load(x)
776
+ #define GGML_F32Cx8_STORE(x, y) __avx_f32cx8_store(x, y)
777
+ #endif
3131
778
 
3132
- // These temporary registers are for masking and shift operations
3133
- vuint32m2_t vt_1 = __riscv_vid_v_u32m2(vl);
3134
- vuint32m2_t vt_2 = __riscv_vsll_vv_u32m2(__riscv_vmv_v_x_u32m2(1, vl), vt_1, vl);
779
+ #define GGML_F32Cx8_FMA GGML_F32x8_FMA
780
+ #define GGML_F32Cx8_ADD _mm256_add_ps
781
+ #define GGML_F32Cx8_MUL _mm256_mul_ps
782
+ #define GGML_F32Cx8_REDUCE GGML_F32x8_REDUCE
3135
783
 
3136
- vuint32m2_t vt_3 = __riscv_vsll_vx_u32m2(vt_2, 16, vl);
3137
- vuint32m2_t vt_4 = __riscv_vadd_vx_u32m2(vt_1, 12, vl);
784
+ #define GGML_F16_VEC GGML_F32Cx8
785
+ #define GGML_F16_VEC_ZERO GGML_F32Cx8_ZERO
786
+ #define GGML_F16_VEC_SET1 GGML_F32Cx8_SET1
787
+ #define GGML_F16_VEC_LOAD(p, i) GGML_F32Cx8_LOAD(p)
788
+ #define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx8_STORE(p, r[i])
789
+ #define GGML_F16_VEC_FMA GGML_F32Cx8_FMA
790
+ #define GGML_F16_VEC_ADD GGML_F32Cx8_ADD
791
+ #define GGML_F16_VEC_MUL GGML_F32Cx8_MUL
792
+ #define GGML_F16_VEC_REDUCE GGML_F32Cx8_REDUCE
3138
793
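Note on the two GGML_F32Cx8 load/store paths: with F16C the conversion is a single instruction each way, and the immediate 0 passed to _mm256_cvtps_ph selects round-to-nearest-even; the scalar fallback produces the same values but bounces through a float[8] buffer. A minimal round-trip check, assuming an F16C-capable build (hypothetical helper, not part of this package):

    #include <immintrin.h>

    static int f16c_roundtrip_ok(void) {
        // every value below is exactly representable in IEEE binary16
        float src[8] = { 0.5f, -1.0f, 2.0f, 65504.0f, 0.0f, 1.5f, -0.25f, 3.0f };
        __m128i h = _mm256_cvtps_ph(_mm256_loadu_ps(src), 0); // 0 = round to nearest even
        float dst[8];
        _mm256_storeu_ps(dst, _mm256_cvtph_ps(h));
        for (int i = 0; i < 8; ++i) {
            if (dst[i] != src[i]) return 0;
        }
        return 1;
    }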
 
3139
- for (int i = 0; i < nb; i++) {
3140
- memcpy(&qh, x[i].qh, sizeof(uint32_t));
794
+ #elif defined(__POWER9_VECTOR__)
3141
795
 
3142
- // ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4;
3143
- vuint32m2_t xha_0 = __riscv_vand_vx_u32m2(vt_2, qh, vl);
3144
- vuint32m2_t xhr_0 = __riscv_vsrl_vv_u32m2(xha_0, vt_1, vl);
3145
- vuint32m2_t xhl_0 = __riscv_vsll_vx_u32m2(xhr_0, 4, vl);
796
+ #define GGML_SIMD
3146
797
 
3147
- // ((qh & (1u << (j + 16))) >> (j + 12));
3148
- vuint32m2_t xha_1 = __riscv_vand_vx_u32m2(vt_3, qh, vl);
3149
- vuint32m2_t xhl_1 = __riscv_vsrl_vv_u32m2(xha_1, vt_4, vl);
798
+ // F32 POWER9
3150
799
 
3151
- // narrowing
3152
- vuint16m1_t xhc_0 = __riscv_vncvt_x_x_w_u16m1(xhl_0, vl);
3153
- vuint8mf2_t xh_0 = __riscv_vncvt_x_x_w_u8mf2(xhc_0, vl);
800
+ #define GGML_F32_STEP 32
801
+ #define GGML_F32_EPR 4
3154
802
 
3155
- vuint16m1_t xhc_1 = __riscv_vncvt_x_x_w_u16m1(xhl_1, vl);
3156
- vuint8mf2_t xh_1 = __riscv_vncvt_x_x_w_u8mf2(xhc_1, vl);
803
+ #define GGML_F32x4 vector float
804
+ #define GGML_F32x4_ZERO 0.0f
805
+ #define GGML_F32x4_SET1 vec_splats
806
+ #define GGML_F32x4_LOAD(p) vec_xl(0, p)
807
+ #define GGML_F32x4_STORE(p, r) vec_xst(r, 0, p)
808
+ #define GGML_F32x4_FMA(a, b, c) vec_madd(b, c, a)
809
+ #define GGML_F32x4_ADD vec_add
810
+ #define GGML_F32x4_MUL vec_mul
811
+ #define GGML_F32x4_REDUCE(res, x) \
812
+ { \
813
+ int offset = GGML_F32_ARR >> 1; \
814
+ for (int i = 0; i < offset; ++i) { \
815
+ x[i] = vec_add(x[i], x[offset+i]); \
816
+ } \
817
+ offset >>= 1; \
818
+ for (int i = 0; i < offset; ++i) { \
819
+ x[i] = vec_add(x[i], x[offset+i]); \
820
+ } \
821
+ offset >>= 1; \
822
+ for (int i = 0; i < offset; ++i) { \
823
+ x[i] = vec_add(x[i], x[offset+i]); \
824
+ } \
825
+ res = vec_extract(x[0], 0) + \
826
+ vec_extract(x[0], 1) + \
827
+ vec_extract(x[0], 2) + \
828
+ vec_extract(x[0], 3); \
829
+ }
3157
830
 
3158
- // load
3159
- vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[i].qs, vl);
831
+ #define GGML_F32_VEC GGML_F32x4
832
+ #define GGML_F32_VEC_ZERO GGML_F32x4_ZERO
833
+ #define GGML_F32_VEC_SET1 GGML_F32x4_SET1
834
+ #define GGML_F32_VEC_LOAD GGML_F32x4_LOAD
835
+ #define GGML_F32_VEC_STORE GGML_F32x4_STORE
836
+ #define GGML_F32_VEC_FMA GGML_F32x4_FMA
837
+ #define GGML_F32_VEC_ADD GGML_F32x4_ADD
838
+ #define GGML_F32_VEC_MUL GGML_F32x4_MUL
839
+ #define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE
3160
840
 
3161
- vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[i].qs, vl);
3162
- vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[i].qs+16, vl);
841
+ // F16 POWER9
842
+ #define GGML_F16_STEP GGML_F32_STEP
843
+ #define GGML_F16_EPR GGML_F32_EPR
844
+ #define GGML_F16_VEC GGML_F32x4
845
+ #define GGML_F16_VEC_ZERO GGML_F32x4_ZERO
846
+ #define GGML_F16_VEC_SET1 GGML_F32x4_SET1
847
+ #define GGML_F16_VEC_FMA GGML_F32x4_FMA
848
+ #define GGML_F16_VEC_REDUCE GGML_F32x4_REDUCE
849
+ // Use vec_xl, not vec_ld, in case the load address is not aligned.
850
+ #define GGML_F16_VEC_LOAD(p, i) (i & 0x1) ? \
851
+ vec_extract_fp32_from_shorth(vec_xl(0, p - GGML_F16_EPR)) : \
852
+ vec_extract_fp32_from_shortl(vec_xl(0, p))
853
+ #define GGML_ENDIAN_BYTE(i) ((unsigned char *)&(uint16_t){1})[i]
854
+ #define GGML_F16_VEC_STORE(p, r, i) \
855
+ if (i & 0x1) \
856
+ vec_xst(vec_pack_to_short_fp32(r[i - GGML_ENDIAN_BYTE(1)], \
857
+ r[i - GGML_ENDIAN_BYTE(0)]), \
858
+ 0, p - GGML_F16_EPR)
3163
859
 
3164
- vuint8mf2_t x_at = __riscv_vand_vx_u8mf2(tx, 0x0F, vl);
3165
- vuint8mf2_t x_lt = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl);
860
+ #elif defined(__wasm_simd128__)
3166
861
 
3167
- vuint8mf2_t x_a = __riscv_vor_vv_u8mf2(x_at, xh_0, vl);
3168
- vuint8mf2_t x_l = __riscv_vor_vv_u8mf2(x_lt, xh_1, vl);
862
+ #define GGML_SIMD
3169
863
 
3170
- vint8mf2_t x_ai = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a);
3171
- vint8mf2_t x_li = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l);
864
+ // F32 WASM
3172
865
 
3173
- vint8mf2_t v0 = __riscv_vsub_vx_i8mf2(x_ai, 16, vl);
3174
- vint8mf2_t v1 = __riscv_vsub_vx_i8mf2(x_li, 16, vl);
866
+ #define GGML_F32_STEP 16
867
+ #define GGML_F32_EPR 4
3175
868
 
3176
- vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl);
3177
- vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl);
869
+ #define GGML_F32x4 v128_t
870
+ #define GGML_F32x4_ZERO wasm_f32x4_splat(0.0f)
871
+ #define GGML_F32x4_SET1(x) wasm_f32x4_splat(x)
872
+ #define GGML_F32x4_LOAD wasm_v128_load
873
+ #define GGML_F32x4_STORE wasm_v128_store
874
+ #define GGML_F32x4_FMA(a, b, c) wasm_f32x4_add(wasm_f32x4_mul(b, c), a)
875
+ #define GGML_F32x4_ADD wasm_f32x4_add
876
+ #define GGML_F32x4_MUL wasm_f32x4_mul
877
+ #define GGML_F32x4_REDUCE(res, x) \
878
+ { \
879
+ int offset = GGML_F32_ARR >> 1; \
880
+ for (int i = 0; i < offset; ++i) { \
881
+ x[i] = wasm_f32x4_add(x[i], x[offset+i]); \
882
+ } \
883
+ offset >>= 1; \
884
+ for (int i = 0; i < offset; ++i) { \
885
+ x[i] = wasm_f32x4_add(x[i], x[offset+i]); \
886
+ } \
887
+ offset >>= 1; \
888
+ for (int i = 0; i < offset; ++i) { \
889
+ x[i] = wasm_f32x4_add(x[i], x[offset+i]); \
890
+ } \
891
+ res = wasm_f32x4_extract_lane(x[0], 0) + \
892
+ wasm_f32x4_extract_lane(x[0], 1) + \
893
+ wasm_f32x4_extract_lane(x[0], 2) + \
894
+ wasm_f32x4_extract_lane(x[0], 3); \
895
+ }
3178
896
 
3179
- vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
897
+ #define GGML_F32_VEC GGML_F32x4
898
+ #define GGML_F32_VEC_ZERO GGML_F32x4_ZERO
899
+ #define GGML_F32_VEC_SET1 GGML_F32x4_SET1
900
+ #define GGML_F32_VEC_LOAD GGML_F32x4_LOAD
901
+ #define GGML_F32_VEC_STORE GGML_F32x4_STORE
902
+ #define GGML_F32_VEC_FMA GGML_F32x4_FMA
903
+ #define GGML_F32_VEC_ADD GGML_F32x4_ADD
904
+ #define GGML_F32_VEC_MUL GGML_F32x4_MUL
905
+ #define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE
3180
906
 
3181
- vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl);
3182
- vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl);
907
+ // F16 WASM
3183
908
 
3184
- int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);
909
+ #define GGML_F16_STEP 16
910
+ #define GGML_F16_EPR 4
3185
911
 
3186
- sumf += (GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d)) * sumi;
3187
- }
912
+ inline static v128_t __wasm_f16x4_load(const ggml_fp16_t * p) {
913
+ float tmp[4];
3188
914
 
3189
- *s = sumf;
3190
- #else
3191
- // scalar
3192
- float sumf = 0.0;
915
+ tmp[0] = GGML_FP16_TO_FP32(p[0]);
916
+ tmp[1] = GGML_FP16_TO_FP32(p[1]);
917
+ tmp[2] = GGML_FP16_TO_FP32(p[2]);
918
+ tmp[3] = GGML_FP16_TO_FP32(p[3]);
3193
919
 
3194
- for (int i = 0; i < nb; i++) {
3195
- uint32_t qh;
3196
- memcpy(&qh, x[i].qh, sizeof(qh));
920
+ return wasm_v128_load(tmp);
921
+ }
3197
922
 
3198
- int sumi = 0;
923
+ inline static void __wasm_f16x4_store(ggml_fp16_t * p, v128_t x) {
924
+ float tmp[4];
3199
925
 
3200
- for (int j = 0; j < qk/2; ++j) {
3201
- const uint8_t xh_0 = ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4;
3202
- const uint8_t xh_1 = ((qh & (1u << (j + 16))) >> (j + 12));
926
+ wasm_v128_store(tmp, x);
3203
927
 
3204
- const int32_t x0 = ((x[i].qs[j] & 0x0F) | xh_0) - 16;
3205
- const int32_t x1 = ((x[i].qs[j] >> 4) | xh_1) - 16;
928
+ p[0] = GGML_FP32_TO_FP16(tmp[0]);
929
+ p[1] = GGML_FP32_TO_FP16(tmp[1]);
930
+ p[2] = GGML_FP32_TO_FP16(tmp[2]);
931
+ p[3] = GGML_FP32_TO_FP16(tmp[3]);
932
+ }
3206
933
 
3207
- sumi += (x0 * y[i].qs[j]) + (x1 * y[i].qs[j + qk/2]);
3208
- }
934
+ #define GGML_F16x4 v128_t
935
+ #define GGML_F16x4_ZERO wasm_f32x4_splat(0.0f)
936
+ #define GGML_F16x4_SET1(x) wasm_f32x4_splat(x)
937
+ #define GGML_F16x4_LOAD(x) __wasm_f16x4_load(x)
938
+ #define GGML_F16x4_STORE(x, y) __wasm_f16x4_store(x, y)
939
+ #define GGML_F16x4_FMA GGML_F32x4_FMA
940
+ #define GGML_F16x4_ADD wasm_f32x4_add
941
+ #define GGML_F16x4_MUL wasm_f32x4_mul
942
+ #define GGML_F16x4_REDUCE(res, x) \
943
+ { \
944
+ int offset = GGML_F16_ARR >> 1; \
945
+ for (int i = 0; i < offset; ++i) { \
946
+ x[i] = wasm_f32x4_add(x[i], x[offset+i]); \
947
+ } \
948
+ offset >>= 1; \
949
+ for (int i = 0; i < offset; ++i) { \
950
+ x[i] = wasm_f32x4_add(x[i], x[offset+i]); \
951
+ } \
952
+ offset >>= 1; \
953
+ for (int i = 0; i < offset; ++i) { \
954
+ x[i] = wasm_f32x4_add(x[i], x[offset+i]); \
955
+ } \
956
+ res = wasm_f32x4_extract_lane(x[0], 0) + \
957
+ wasm_f32x4_extract_lane(x[0], 1) + \
958
+ wasm_f32x4_extract_lane(x[0], 2) + \
959
+ wasm_f32x4_extract_lane(x[0], 3); \
960
+ }
3209
961
 
3210
- sumf += (GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d)) * sumi;
3211
- }
962
+ #define GGML_F16_VEC GGML_F16x4
963
+ #define GGML_F16_VEC_ZERO GGML_F16x4_ZERO
964
+ #define GGML_F16_VEC_SET1 GGML_F16x4_SET1
965
+ #define GGML_F16_VEC_LOAD(p, i) GGML_F16x4_LOAD(p)
966
+ #define GGML_F16_VEC_STORE(p, r, i) GGML_F16x4_STORE(p, r[i])
967
+ #define GGML_F16_VEC_FMA GGML_F16x4_FMA
968
+ #define GGML_F16_VEC_ADD GGML_F16x4_ADD
969
+ #define GGML_F16_VEC_MUL GGML_F16x4_MUL
970
+ #define GGML_F16_VEC_REDUCE GGML_F16x4_REDUCE
3212
971
 
3213
- *s = sumf;
3214
- #endif
3215
- }
972
+ #elif defined(__SSE3__)
3216
973
 
3217
- static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
3218
- const int qk = QK8_1;
3219
- const int nb = n / qk;
974
+ #define GGML_SIMD
3220
975
 
3221
- assert(n % qk == 0);
3222
- assert(qk == QK5_1);
976
+ // F32 SSE
3223
977
 
3224
- const block_q5_1 * restrict x = vx;
3225
- const block_q8_1 * restrict y = vy;
978
+ #define GGML_F32_STEP 32
979
+ #define GGML_F32_EPR 4
3226
980
 
3227
- #if defined(__ARM_NEON)
3228
- float32x4_t sumv0 = vdupq_n_f32(0.0f);
3229
- float32x4_t sumv1 = vdupq_n_f32(0.0f);
3230
-
3231
- float summs0 = 0.0f;
3232
- float summs1 = 0.0f;
3233
-
3234
- uint32_t qh0;
3235
- uint32_t qh1;
3236
-
3237
- uint64_t tmp0[4];
3238
- uint64_t tmp1[4];
3239
-
3240
- GGML_ASSERT(nb % 2 == 0); // TODO: handle odd nb
3241
- for (int i = 0; i < nb; i += 2) {
3242
- const block_q5_1 * restrict x0 = &x[i];
3243
- const block_q5_1 * restrict x1 = &x[i + 1];
3244
- const block_q8_1 * restrict y0 = &y[i];
3245
- const block_q8_1 * restrict y1 = &y[i + 1];
3246
-
3247
- const uint8x16_t m4b = vdupq_n_u8(0x0F);
3248
-
3249
- summs0 += GGML_FP16_TO_FP32(x0->m) * y0->s;
3250
- summs1 += GGML_FP16_TO_FP32(x1->m) * y1->s;
3251
-
3252
- // extract the 5th bit via lookup table ((b) << 4)
3253
- memcpy(&qh0, x0->qh, sizeof(qh0));
3254
- memcpy(&qh1, x1->qh, sizeof(qh1));
3255
-
3256
- tmp0[0] = table_b2b_0[(qh0 >> 0) & 0xFF];
3257
- tmp0[1] = table_b2b_0[(qh0 >> 8) & 0xFF];
3258
- tmp0[2] = table_b2b_0[(qh0 >> 16) & 0xFF];
3259
- tmp0[3] = table_b2b_0[(qh0 >> 24) ];
3260
-
3261
- tmp1[0] = table_b2b_0[(qh1 >> 0) & 0xFF];
3262
- tmp1[1] = table_b2b_0[(qh1 >> 8) & 0xFF];
3263
- tmp1[2] = table_b2b_0[(qh1 >> 16) & 0xFF];
3264
- tmp1[3] = table_b2b_0[(qh1 >> 24) ];
3265
-
3266
- const int8x16_t qhl0 = vld1q_s8((const int8_t *)(tmp0 + 0));
3267
- const int8x16_t qhh0 = vld1q_s8((const int8_t *)(tmp0 + 2));
3268
- const int8x16_t qhl1 = vld1q_s8((const int8_t *)(tmp1 + 0));
3269
- const int8x16_t qhh1 = vld1q_s8((const int8_t *)(tmp1 + 2));
3270
-
3271
- const uint8x16_t v0_0 = vld1q_u8(x0->qs);
3272
- const uint8x16_t v0_1 = vld1q_u8(x1->qs);
3273
-
3274
- // 4-bit -> 8-bit
3275
- const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
3276
- const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
3277
- const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
3278
- const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
3279
-
3280
- // add high bit
3281
- const int8x16_t v0_0lf = vorrq_s8(v0_0l, qhl0);
3282
- const int8x16_t v0_0hf = vorrq_s8(v0_0h, qhh0);
3283
- const int8x16_t v0_1lf = vorrq_s8(v0_1l, qhl1);
3284
- const int8x16_t v0_1hf = vorrq_s8(v0_1h, qhh1);
3285
-
3286
- // load y
3287
- const int8x16_t v1_0l = vld1q_s8(y0->qs);
3288
- const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
3289
- const int8x16_t v1_1l = vld1q_s8(y1->qs);
3290
- const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
3291
-
3292
- #if defined(__ARM_FEATURE_DOTPROD)
3293
- sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
3294
- vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l),
3295
- vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), GGML_FP16_TO_FP32(x0->d)*y0->d);
3296
- sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
3297
- vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l),
3298
- vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), GGML_FP16_TO_FP32(x1->d)*y1->d);
981
+ #define GGML_F32x4 __m128
982
+ #define GGML_F32x4_ZERO _mm_setzero_ps()
983
+ #define GGML_F32x4_SET1(x) _mm_set1_ps(x)
984
+ #define GGML_F32x4_LOAD _mm_loadu_ps
985
+ #define GGML_F32x4_STORE _mm_storeu_ps
986
+ #if defined(__FMA__)
987
+ // TODO: Does this work?
988
+ #define GGML_F32x4_FMA(a, b, c) _mm_fmadd_ps(b, c, a)
3299
989
  #else
3300
- const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lf), vget_low_s8 (v1_0l));
3301
- const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lf), vget_high_s8(v1_0l));
3302
- const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hf), vget_low_s8 (v1_0h));
3303
- const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hf), vget_high_s8(v1_0h));
3304
-
3305
- const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lf), vget_low_s8 (v1_1l));
3306
- const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lf), vget_high_s8(v1_1l));
3307
- const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hf), vget_low_s8 (v1_1h));
3308
- const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hf), vget_high_s8(v1_1h));
3309
-
3310
- const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
3311
- const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
3312
- const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
3313
- const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
3314
-
3315
- sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), GGML_FP16_TO_FP32(x0->d)*y0->d);
3316
- sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), GGML_FP16_TO_FP32(x1->d)*y1->d);
990
+ #define GGML_F32x4_FMA(a, b, c) _mm_add_ps(_mm_mul_ps(b, c), a)
3317
991
  #endif
3318
- }
3319
-
3320
- *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs0 + summs1;
3321
- #elif defined(__wasm_simd128__)
3322
- v128_t sumv = wasm_f32x4_splat(0.0f);
3323
-
3324
- float summs = 0.0f;
3325
-
3326
- uint32_t qh;
3327
- uint64_t tmp[4];
3328
-
3329
- // TODO: check if unrolling this is better
3330
- for (int i = 0; i < nb; ++i) {
3331
- const block_q5_1 * restrict x0 = &x[i];
3332
- const block_q8_1 * restrict y0 = &y[i];
3333
-
3334
- summs += GGML_FP16_TO_FP32(x0->m) * y0->s;
3335
-
3336
- const v128_t m4b = wasm_i8x16_splat(0x0F);
3337
-
3338
- // extract the 5th bit
3339
- memcpy(&qh, x0->qh, sizeof(qh));
3340
-
3341
- tmp[0] = table_b2b_0[(qh >> 0) & 0xFF];
3342
- tmp[1] = table_b2b_0[(qh >> 8) & 0xFF];
3343
- tmp[2] = table_b2b_0[(qh >> 16) & 0xFF];
3344
- tmp[3] = table_b2b_0[(qh >> 24) ];
992
+ #define GGML_F32x4_ADD _mm_add_ps
993
+ #define GGML_F32x4_MUL _mm_mul_ps
994
+ #define GGML_F32x4_REDUCE(res, x) \
995
+ { \
996
+ int offset = GGML_F32_ARR >> 1; \
997
+ for (int i = 0; i < offset; ++i) { \
998
+ x[i] = _mm_add_ps(x[i], x[offset+i]); \
999
+ } \
1000
+ offset >>= 1; \
1001
+ for (int i = 0; i < offset; ++i) { \
1002
+ x[i] = _mm_add_ps(x[i], x[offset+i]); \
1003
+ } \
1004
+ offset >>= 1; \
1005
+ for (int i = 0; i < offset; ++i) { \
1006
+ x[i] = _mm_add_ps(x[i], x[offset+i]); \
1007
+ } \
1008
+ const __m128 t0 = _mm_hadd_ps(x[0], x[0]); \
1009
+ res = _mm_cvtss_f32(_mm_hadd_ps(t0, t0)); \
1010
+ }
1011
+ // TODO: is this optimal ?
3345
1012
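On the "TODO: is this optimal ?" above: the two _mm_hadd_ps calls are relatively slow horizontal adds on many microarchitectures. A commonly cited SSE1-only alternative, shown as a sketch rather than a proposed change to the package (whether it is faster depends on the target):

    #include <xmmintrin.h>

    static inline float sse_hsum_ps(__m128 v) {                      // v = [a, b, c, d]
        __m128 shuf = _mm_shuffle_ps(v, v, _MM_SHUFFLE(2, 3, 0, 1)); // [b, a, d, c]
        __m128 sums = _mm_add_ps(v, shuf);                           // [a+b, a+b, c+d, c+d]
        shuf = _mm_movehl_ps(shuf, sums);                            // lane 0 = c+d
        sums = _mm_add_ss(sums, shuf);                               // lane 0 = a+b+c+d
        return _mm_cvtss_f32(sums);
    }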
 
3346
- const v128_t qhl = wasm_v128_load(tmp + 0);
3347
- const v128_t qhh = wasm_v128_load(tmp + 2);
1013
+ #define GGML_F32_VEC GGML_F32x4
1014
+ #define GGML_F32_VEC_ZERO GGML_F32x4_ZERO
1015
+ #define GGML_F32_VEC_SET1 GGML_F32x4_SET1
1016
+ #define GGML_F32_VEC_LOAD GGML_F32x4_LOAD
1017
+ #define GGML_F32_VEC_STORE GGML_F32x4_STORE
1018
+ #define GGML_F32_VEC_FMA GGML_F32x4_FMA
1019
+ #define GGML_F32_VEC_ADD GGML_F32x4_ADD
1020
+ #define GGML_F32_VEC_MUL GGML_F32x4_MUL
1021
+ #define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE
3348
1022
 
3349
- const v128_t v0 = wasm_v128_load(x0->qs);
1023
+ // F16 SSE
3350
1024
 
3351
- // 4-bit -> 8-bit
3352
- const v128_t v0l = wasm_v128_and (v0, m4b);
3353
- const v128_t v0h = wasm_u8x16_shr(v0, 4);
1025
+ #define GGML_F16_STEP 32
1026
+ #define GGML_F16_EPR 4
3354
1027
 
3355
- // add high bit
3356
- const v128_t v0lf = wasm_v128_or(v0l, qhl);
3357
- const v128_t v0hf = wasm_v128_or(v0h, qhh);
1028
+ static inline __m128 __sse_f16x4_load(ggml_fp16_t *x) {
1029
+ float tmp[4];
3358
1030
 
3359
- // load y
3360
- const v128_t v1l = wasm_v128_load(y0->qs);
3361
- const v128_t v1h = wasm_v128_load(y0->qs + 16);
1031
+ tmp[0] = GGML_FP16_TO_FP32(x[0]);
1032
+ tmp[1] = GGML_FP16_TO_FP32(x[1]);
1033
+ tmp[2] = GGML_FP16_TO_FP32(x[2]);
1034
+ tmp[3] = GGML_FP16_TO_FP32(x[3]);
3362
1035
 
3363
- // int8x16 -> int16x8
3364
- const v128_t v0lfl = wasm_i16x8_extend_low_i8x16 (v0lf);
3365
- const v128_t v0lfh = wasm_i16x8_extend_high_i8x16(v0lf);
3366
- const v128_t v0hfl = wasm_i16x8_extend_low_i8x16 (v0hf);
3367
- const v128_t v0hfh = wasm_i16x8_extend_high_i8x16(v0hf);
1036
+ return _mm_loadu_ps(tmp);
1037
+ }
3368
1038
 
3369
- const v128_t v1ll = wasm_i16x8_extend_low_i8x16 (v1l);
3370
- const v128_t v1lh = wasm_i16x8_extend_high_i8x16(v1l);
3371
- const v128_t v1hl = wasm_i16x8_extend_low_i8x16 (v1h);
3372
- const v128_t v1hh = wasm_i16x8_extend_high_i8x16(v1h);
1039
+ static inline void __sse_f16x4_store(ggml_fp16_t *x, __m128 y) {
1040
+ float arr[4];
3373
1041
 
3374
- // dot product
3375
- sumv = wasm_f32x4_add(sumv,
3376
- wasm_f32x4_mul(wasm_f32x4_convert_i32x4(wasm_i32x4_add(
3377
- wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0lfl, v1ll),
3378
- wasm_i32x4_dot_i16x8(v0lfh, v1lh)),
3379
- wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0hfl, v1hl),
3380
- wasm_i32x4_dot_i16x8(v0hfh, v1hh)))),
3381
- wasm_f32x4_splat(GGML_FP16_TO_FP32(x0->d) * y0->d)));
3382
- }
1042
+ _mm_storeu_ps(arr, y);
3383
1043
 
3384
- *s = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) +
3385
- wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3) + summs;
3386
- #elif defined(__AVX2__)
3387
- // Initialize accumulator with zeros
3388
- __m256 acc = _mm256_setzero_ps();
1044
+ x[0] = GGML_FP32_TO_FP16(arr[0]);
1045
+ x[1] = GGML_FP32_TO_FP16(arr[1]);
1046
+ x[2] = GGML_FP32_TO_FP16(arr[2]);
1047
+ x[3] = GGML_FP32_TO_FP16(arr[3]);
1048
+ }
3389
1049
 
3390
- float summs = 0.0f;
1050
+ #define GGML_F32Cx4 __m128
1051
+ #define GGML_F32Cx4_ZERO _mm_setzero_ps()
1052
+ #define GGML_F32Cx4_SET1(x) _mm_set1_ps(x)
1053
+ #define GGML_F32Cx4_LOAD(x) __sse_f16x4_load(x)
1054
+ #define GGML_F32Cx4_STORE(x, y) __sse_f16x4_store(x, y)
1055
+ #define GGML_F32Cx4_FMA GGML_F32x4_FMA
1056
+ #define GGML_F32Cx4_ADD _mm_add_ps
1057
+ #define GGML_F32Cx4_MUL _mm_mul_ps
1058
+ #define GGML_F32Cx4_REDUCE GGML_F32x4_REDUCE
3391
1059
 
3392
- // Main loop
3393
- for (int i = 0; i < nb; i++) {
3394
- const __m256 dx = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d));
1060
+ #define GGML_F16_VEC GGML_F32Cx4
1061
+ #define GGML_F16_VEC_ZERO GGML_F32Cx4_ZERO
1062
+ #define GGML_F16_VEC_SET1 GGML_F32Cx4_SET1
1063
+ #define GGML_F16_VEC_LOAD(p, i) GGML_F32Cx4_LOAD(p)
1064
+ #define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx4_STORE(p, r[i])
1065
+ #define GGML_F16_VEC_FMA GGML_F32Cx4_FMA
1066
+ #define GGML_F16_VEC_ADD GGML_F32Cx4_ADD
1067
+ #define GGML_F16_VEC_MUL GGML_F32Cx4_MUL
1068
+ #define GGML_F16_VEC_REDUCE GGML_F32Cx4_REDUCE
3395
1069
 
3396
- summs += GGML_FP16_TO_FP32(x[i].m) * y[i].s;
1070
+ #endif
3397
1071
 
3398
- __m256i bx = bytes_from_nibbles_32(x[i].qs);
3399
- __m256i bxhi = bytes_from_bits_32(x[i].qh);
3400
- bxhi = _mm256_and_si256(bxhi, _mm256_set1_epi8(0x10));
3401
- bx = _mm256_or_si256(bx, bxhi);
1072
+ // GGML_F32_ARR / GGML_F16_ARR
1073
+ // number of registers to use per step
1074
+ #ifdef GGML_SIMD
1075
+ #define GGML_F32_ARR (GGML_F32_STEP/GGML_F32_EPR)
1076
+ #define GGML_F16_ARR (GGML_F16_STEP/GGML_F16_EPR)
1077
+ #endif
3402
1078
 
3403
- const __m256 dy = _mm256_set1_ps(y[i].d);
3404
- const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
1079
+ //
1080
+ // fundamental operations
1081
+ //
3405
1082
 
3406
- const __m256 q = mul_sum_us8_pairs_float(bx, by);
1083
+ inline static void ggml_vec_set_i8(const int n, int8_t * x, const int8_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
3407
1084
 
3408
- acc = _mm256_fmadd_ps(q, _mm256_mul_ps(dx, dy), acc);
3409
- }
1085
+ inline static void ggml_vec_set_i16(const int n, int16_t * x, const int16_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
3410
1086
 
3411
- *s = hsum_float_8(acc) + summs;
3412
- #elif defined(__AVX__)
3413
- // Initialize accumulator with zeros
3414
- __m256 acc = _mm256_setzero_ps();
3415
- __m128i mask = _mm_set1_epi8(0x10);
1087
+ inline static void ggml_vec_set_i32(const int n, int32_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
3416
1088
 
3417
- float summs = 0.0f;
1089
+ inline static void ggml_vec_set_f16(const int n, ggml_fp16_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
3418
1090
 
3419
- // Main loop
3420
- for (int i = 0; i < nb; i++) {
3421
- const __m256 dx = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d));
1091
+ inline static void ggml_vec_add_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i] + y[i]; }
1092
+ inline static void ggml_vec_add1_f32(const int n, float * z, const float * x, const float v) { for (int i = 0; i < n; ++i) z[i] = x[i] + v; }
1093
+ inline static void ggml_vec_acc_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] += x[i]; }
1094
+ inline static void ggml_vec_acc1_f32(const int n, float * y, const float v) { for (int i = 0; i < n; ++i) y[i] += v; }
1095
+ inline static void ggml_vec_sub_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i] - y[i]; }
1096
+ inline static void ggml_vec_set_f32 (const int n, float * x, const float v) { for (int i = 0; i < n; ++i) x[i] = v; }
1097
+ inline static void ggml_vec_cpy_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]; }
1098
+ inline static void ggml_vec_neg_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = -x[i]; }
1099
+ inline static void ggml_vec_mul_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]*y[i]; }
1100
+ inline static void ggml_vec_div_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]/y[i]; }
3422
1101
 
3423
- summs += GGML_FP16_TO_FP32(x[i].m) * y[i].s;
1102
+ static void ggml_vec_dot_f32(const int n, float * restrict s, const float * restrict x, const float * restrict y) {
1103
+ #ifdef GGML_SIMD
1104
+ float sumf = 0.0f;
1105
+ const int np = (n & ~(GGML_F32_STEP - 1));
3424
1106
 
3425
- __m256i bx = bytes_from_nibbles_32(x[i].qs);
3426
- const __m256i bxhi = bytes_from_bits_32(x[i].qh);
3427
- __m128i bxhil = _mm256_castsi256_si128(bxhi);
3428
- __m128i bxhih = _mm256_extractf128_si256(bxhi, 1);
3429
- bxhil = _mm_and_si128(bxhil, mask);
3430
- bxhih = _mm_and_si128(bxhih, mask);
3431
- __m128i bxl = _mm256_castsi256_si128(bx);
3432
- __m128i bxh = _mm256_extractf128_si256(bx, 1);
3433
- bxl = _mm_or_si128(bxl, bxhil);
3434
- bxh = _mm_or_si128(bxh, bxhih);
3435
- bx = MM256_SET_M128I(bxh, bxl);
1107
+ GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
3436
1108
 
3437
- const __m256 dy = _mm256_set1_ps(y[i].d);
3438
- const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
1109
+ GGML_F32_VEC ax[GGML_F32_ARR];
1110
+ GGML_F32_VEC ay[GGML_F32_ARR];
3439
1111
 
3440
- const __m256 q = mul_sum_us8_pairs_float(bx, by);
1112
+ for (int i = 0; i < np; i += GGML_F32_STEP) {
1113
+ for (int j = 0; j < GGML_F32_ARR; j++) {
1114
+ ax[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR);
1115
+ ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR);
3441
1116
 
3442
- acc = _mm256_add_ps(_mm256_mul_ps(q, _mm256_mul_ps(dx, dy)), acc);
1117
+ sum[j] = GGML_F32_VEC_FMA(sum[j], ax[j], ay[j]);
1118
+ }
3443
1119
  }
3444
1120
 
3445
- *s = hsum_float_8(acc) + summs;
3446
- #elif defined(__riscv_v_intrinsic)
3447
- float sumf = 0.0;
3448
-
3449
- uint32_t qh;
3450
-
3451
- size_t vl = __riscv_vsetvl_e8m1(qk/2);
3452
-
3453
- // temporary registers for shift operations
3454
- vuint32m2_t vt_1 = __riscv_vid_v_u32m2(vl);
3455
- vuint32m2_t vt_2 = __riscv_vadd_vx_u32m2(vt_1, 12, vl);
3456
-
3457
- for (int i = 0; i < nb; i++) {
3458
- memcpy(&qh, x[i].qh, sizeof(uint32_t));
3459
-
3460
- // load qh
3461
- vuint32m2_t vqh = __riscv_vmv_v_x_u32m2(qh, vl);
3462
-
3463
- // ((qh >> (j + 0)) << 4) & 0x10;
3464
- vuint32m2_t xhr_0 = __riscv_vsrl_vv_u32m2(vqh, vt_1, vl);
3465
- vuint32m2_t xhl_0 = __riscv_vsll_vx_u32m2(xhr_0, 4, vl);
3466
- vuint32m2_t xha_0 = __riscv_vand_vx_u32m2(xhl_0, 0x10, vl);
3467
-
3468
- // ((qh >> (j + 12)) ) & 0x10;
3469
- vuint32m2_t xhr_1 = __riscv_vsrl_vv_u32m2(vqh, vt_2, vl);
3470
- vuint32m2_t xha_1 = __riscv_vand_vx_u32m2(xhr_1, 0x10, vl);
3471
-
3472
- // narrowing
3473
- vuint16m1_t xhc_0 = __riscv_vncvt_x_x_w_u16m1(xha_0, vl);
3474
- vuint8mf2_t xh_0 = __riscv_vncvt_x_x_w_u8mf2(xhc_0, vl);
3475
-
3476
- vuint16m1_t xhc_1 = __riscv_vncvt_x_x_w_u16m1(xha_1, vl);
3477
- vuint8mf2_t xh_1 = __riscv_vncvt_x_x_w_u8mf2(xhc_1, vl);
3478
-
3479
- // load
3480
- vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[i].qs, vl);
3481
-
3482
- vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[i].qs, vl);
3483
- vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[i].qs+16, vl);
3484
-
3485
- vuint8mf2_t x_at = __riscv_vand_vx_u8mf2(tx, 0x0F, vl);
3486
- vuint8mf2_t x_lt = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl);
3487
-
3488
- vuint8mf2_t x_a = __riscv_vor_vv_u8mf2(x_at, xh_0, vl);
3489
- vuint8mf2_t x_l = __riscv_vor_vv_u8mf2(x_lt, xh_1, vl);
3490
-
3491
- vint8mf2_t v0 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a);
3492
- vint8mf2_t v1 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l);
3493
-
3494
- vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl);
3495
- vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl);
3496
-
3497
- vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
3498
-
3499
- vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl);
3500
- vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl);
3501
-
3502
- int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);
1121
+ // reduce sum0..sum3 to sum0
1122
+ GGML_F32_VEC_REDUCE(sumf, sum);
3503
1123
 
3504
- sumf += (GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi + GGML_FP16_TO_FP32(x[i].m)*y[i].s;
1124
+ // leftovers
1125
+ for (int i = np; i < n; ++i) {
1126
+ sumf += x[i]*y[i];
3505
1127
  }
3506
-
3507
- *s = sumf;
3508
1128
  #else
3509
1129
  // scalar
3510
- float sumf = 0.0;
3511
-
3512
- for (int i = 0; i < nb; i++) {
3513
- uint32_t qh;
3514
- memcpy(&qh, x[i].qh, sizeof(qh));
3515
-
3516
- int sumi = 0;
3517
-
3518
- for (int j = 0; j < qk/2; ++j) {
3519
- const uint8_t xh_0 = ((qh >> (j + 0)) << 4) & 0x10;
3520
- const uint8_t xh_1 = ((qh >> (j + 12)) ) & 0x10;
3521
-
3522
- const int32_t x0 = (x[i].qs[j] & 0xF) | xh_0;
3523
- const int32_t x1 = (x[i].qs[j] >> 4) | xh_1;
3524
-
3525
- sumi += (x0 * y[i].qs[j]) + (x1 * y[i].qs[j + qk/2]);
3526
- }
3527
-
3528
- sumf += (GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi + GGML_FP16_TO_FP32(x[i].m)*y[i].s;
1130
+ ggml_float sumf = 0.0;
1131
+ for (int i = 0; i < n; ++i) {
1132
+ sumf += (ggml_float)(x[i]*y[i]);
3529
1133
  }
1134
+ #endif
3530
1135
 
3531
1136
  *s = sumf;
3532
- #endif
3533
1137
  }
3534
1138
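Usage sketch for ggml_vec_dot_f32, with hypothetical arrays: the SIMD path consumes n rounded down to a multiple of GGML_F32_STEP and the scalar tail finishes the rest, so any n is valid. Note that the scalar-only build accumulates in ggml_float (double) while the SIMD build reduces in float.

    float x[5] = { 1, 2, 3, 4, 5 };
    float y[5] = { 5, 4, 3, 2, 1 };
    float s;
    ggml_vec_dot_f32(5, &s, x, y); // s = 5 + 8 + 9 + 8 + 5 = 35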
 
3535
- static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
3536
- const int qk = QK8_0;
3537
- const int nb = n / qk;
3538
-
3539
- assert(n % qk == 0);
3540
-
3541
- const block_q8_0 * restrict x = vx;
3542
- const block_q8_0 * restrict y = vy;
3543
-
3544
- #if defined(__ARM_NEON)
3545
- float32x4_t sumv0 = vdupq_n_f32(0.0f);
3546
- float32x4_t sumv1 = vdupq_n_f32(0.0f);
3547
-
3548
- GGML_ASSERT(nb % 2 == 0); // TODO: handle odd nb
3549
- for (int i = 0; i < nb; i += 2) {
3550
- const block_q8_0 * restrict x0 = &x[i + 0];
3551
- const block_q8_0 * restrict x1 = &x[i + 1];
3552
- const block_q8_0 * restrict y0 = &y[i + 0];
3553
- const block_q8_0 * restrict y1 = &y[i + 1];
3554
-
3555
- const int8x16_t x0_0 = vld1q_s8(x0->qs);
3556
- const int8x16_t x0_1 = vld1q_s8(x0->qs + 16);
3557
- const int8x16_t x1_0 = vld1q_s8(x1->qs);
3558
- const int8x16_t x1_1 = vld1q_s8(x1->qs + 16);
3559
-
3560
- // load y
3561
- const int8x16_t y0_0 = vld1q_s8(y0->qs);
3562
- const int8x16_t y0_1 = vld1q_s8(y0->qs + 16);
3563
- const int8x16_t y1_0 = vld1q_s8(y1->qs);
3564
- const int8x16_t y1_1 = vld1q_s8(y1->qs + 16);
3565
-
3566
- #if defined(__ARM_FEATURE_DOTPROD)
3567
- sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
3568
- vdotq_s32(vdupq_n_s32(0), x0_0, y0_0),
3569
- vdotq_s32(vdupq_n_s32(0), x0_1, y0_1))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
3570
-
3571
- sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
3572
- vdotq_s32(vdupq_n_s32(0), x1_0, y1_0),
3573
- vdotq_s32(vdupq_n_s32(0), x1_1, y1_1))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
1139
+ static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t * restrict x, ggml_fp16_t * restrict y) {
1140
+ ggml_float sumf = 0.0;
3574
1141
 
3575
- #else
3576
- const int16x8_t p0_0 = vmull_s8(vget_low_s8 (x0_0), vget_low_s8 (y0_0));
3577
- const int16x8_t p0_1 = vmull_s8(vget_high_s8(x0_0), vget_high_s8(y0_0));
3578
- const int16x8_t p0_2 = vmull_s8(vget_low_s8 (x0_1), vget_low_s8 (y0_1));
3579
- const int16x8_t p0_3 = vmull_s8(vget_high_s8(x0_1), vget_high_s8(y0_1));
3580
-
3581
- const int16x8_t p1_0 = vmull_s8(vget_low_s8 (x1_0), vget_low_s8 (y1_0));
3582
- const int16x8_t p1_1 = vmull_s8(vget_high_s8(x1_0), vget_high_s8(y1_0));
3583
- const int16x8_t p1_2 = vmull_s8(vget_low_s8 (x1_1), vget_low_s8 (y1_1));
3584
- const int16x8_t p1_3 = vmull_s8(vget_high_s8(x1_1), vget_high_s8(y1_1));
3585
-
3586
- const int32x4_t p0 = vaddq_s32(vpaddlq_s16(p0_0), vpaddlq_s16(p0_1));
3587
- const int32x4_t p1 = vaddq_s32(vpaddlq_s16(p0_2), vpaddlq_s16(p0_3));
3588
- const int32x4_t p2 = vaddq_s32(vpaddlq_s16(p1_0), vpaddlq_s16(p1_1));
3589
- const int32x4_t p3 = vaddq_s32(vpaddlq_s16(p1_2), vpaddlq_s16(p1_3));
3590
-
3591
- sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(p0, p1)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
3592
- sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(p2, p3)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
3593
- #endif
3594
- }
1142
+ #if defined(GGML_SIMD)
1143
+ const int np = (n & ~(GGML_F16_STEP - 1));
3595
1144
 
3596
- *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
3597
- #elif defined(__AVX2__) || defined(__AVX__)
3598
- // Initialize accumulator with zeros
3599
- __m256 acc = _mm256_setzero_ps();
1145
+ GGML_F16_VEC sum[GGML_F16_ARR] = { GGML_F16_VEC_ZERO };
3600
1146
 
3601
- // Main loop
3602
- for (int i = 0; i < nb; ++i) {
3603
- // Compute combined scale for the block
3604
- const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d));
3605
- __m256i bx = _mm256_loadu_si256((const __m256i *)x[i].qs);
3606
- __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
1147
+ GGML_F16_VEC ax[GGML_F16_ARR];
1148
+ GGML_F16_VEC ay[GGML_F16_ARR];
3607
1149
 
3608
- const __m256 q = mul_sum_i8_pairs_float(bx, by);
1150
+ for (int i = 0; i < np; i += GGML_F16_STEP) {
1151
+ for (int j = 0; j < GGML_F16_ARR; j++) {
1152
+ ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j);
1153
+ ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
3609
1154
 
3610
- // Multiply q with scale and accumulate
3611
- #if defined(__AVX2__)
3612
- acc = _mm256_fmadd_ps( d, q, acc );
3613
- #else
3614
- acc = _mm256_add_ps( _mm256_mul_ps( d, q ), acc );
3615
- #endif
1155
+ sum[j] = GGML_F16_VEC_FMA(sum[j], ax[j], ay[j]);
1156
+ }
3616
1157
  }
3617
1158
 
3618
- *s = hsum_float_8(acc);
3619
- #elif defined(__riscv_v_intrinsic)
3620
- float sumf = 0.0;
3621
- size_t vl = __riscv_vsetvl_e8m1(qk);
3622
-
3623
- for (int i = 0; i < nb; i++) {
3624
- // load elements
3625
- vint8m1_t bx = __riscv_vle8_v_i8m1(x[i].qs, vl);
3626
- vint8m1_t by = __riscv_vle8_v_i8m1(y[i].qs, vl);
3627
-
3628
- vint16m2_t vw_mul = __riscv_vwmul_vv_i16m2(bx, by, vl);
3629
-
3630
- vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, vl);
3631
- vint32m1_t v_sum = __riscv_vwredsum_vs_i16m2_i32m1(vw_mul, v_zero, vl);
3632
-
3633
- int sumi = __riscv_vmv_x_s_i32m1_i32(v_sum);
1159
+ // reduce sum0..sum3 to sum0
1160
+ GGML_F16_VEC_REDUCE(sumf, sum);
3634
1161
 
3635
- sumf += sumi*(GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d));
1162
+ // leftovers
1163
+ for (int i = np; i < n; ++i) {
1164
+ sumf += (ggml_float)(GGML_FP16_TO_FP32(x[i])*GGML_FP16_TO_FP32(y[i]));
3636
1165
  }
3637
-
3638
- *s = sumf;
3639
1166
  #else
3640
- // scalar
3641
- float sumf = 0.0;
3642
-
3643
- for (int i = 0; i < nb; i++) {
3644
- int sumi = 0;
3645
-
3646
- for (int j = 0; j < qk; j++) {
3647
- sumi += x[i].qs[j]*y[i].qs[j];
3648
- }
3649
-
3650
- sumf += sumi*(GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d));
1167
+ for (int i = 0; i < n; ++i) {
1168
+ sumf += (ggml_float)(GGML_FP16_TO_FP32(x[i])*GGML_FP16_TO_FP32(y[i]));
3651
1169
  }
1170
+ #endif
3652
1171
 
3653
1172
  *s = sumf;
3654
- #endif
3655
1173
  }
3656
1174
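Note: ggml_vec_dot_f16 has the same blocked shape, but the load/store macros also take the register index j because on POWER9 one vector load of eight halves feeds two float registers (the even/odd split above); the other targets ignore the index. Usage sketch with hypothetical buffers:

    ggml_fp16_t a[64], b[64];
    float s;
    /* ... fill a and b with f16 values ... */
    ggml_vec_dot_f16(64, &s, a, b); // each operand is widened to f32 per lane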
 
3657
1175
  // compute GGML_VEC_DOT_UNROLL dot products at once
@@ -3846,7 +1364,7 @@ inline static float ggml_gelu_f32(float x) {
3846
1364
  inline static void ggml_vec_gelu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
3847
1365
  const uint16_t * i16 = (const uint16_t *) x;
3848
1366
  for (int i = 0; i < n; ++i) {
3849
- y[i] = table_gelu_f16[i16[i]];
1367
+ y[i] = ggml_table_gelu_f16[i16[i]];
3850
1368
  }
3851
1369
  }
3852
1370
 
@@ -3856,7 +1374,7 @@ inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) {
3856
1374
  for (int i = 0; i < n; ++i) {
3857
1375
  ggml_fp16_t fp16 = GGML_FP32_TO_FP16(x[i]);
3858
1376
  memcpy(&t, &fp16, sizeof(uint16_t));
3859
- y[i] = GGML_FP16_TO_FP32(table_gelu_f16[t]);
1377
+ y[i] = GGML_FP16_TO_FP32(ggml_table_gelu_f16[t]);
3860
1378
  }
3861
1379
  }
3862
1380
  #else
@@ -3874,7 +1392,7 @@ inline static float ggml_gelu_quick_f32(float x) {
3874
1392
  //inline static void ggml_vec_gelu_quick_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
3875
1393
  // const uint16_t * i16 = (const uint16_t *) x;
3876
1394
  // for (int i = 0; i < n; ++i) {
3877
- // y[i] = table_gelu_quick_f16[i16[i]];
1395
+ // y[i] = ggml_table_gelu_quick_f16[i16[i]];
3878
1396
  // }
3879
1397
  //}
3880
1398
 
@@ -3884,7 +1402,7 @@ inline static void ggml_vec_gelu_quick_f32(const int n, float * y, const float *
3884
1402
  for (int i = 0; i < n; ++i) {
3885
1403
  ggml_fp16_t fp16 = GGML_FP32_TO_FP16(x[i]);
3886
1404
  memcpy(&t, &fp16, sizeof(uint16_t));
3887
- y[i] = GGML_FP16_TO_FP32(table_gelu_quick_f16[t]);
1405
+ y[i] = GGML_FP16_TO_FP32(ggml_table_gelu_quick_f16[t]);
3888
1406
  }
3889
1407
  }
3890
1408
  #else
@@ -3903,7 +1421,7 @@ inline static float ggml_silu_f32(float x) {
3903
1421
  //inline static void ggml_vec_silu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
3904
1422
  // const uint16_t * i16 = (const uint16_t *) x;
3905
1423
  // for (int i = 0; i < n; ++i) {
3906
- // y[i] = table_silu_f16[i16[i]];
1424
+ // y[i] = ggml_table_silu_f16[i16[i]];
3907
1425
  // }
3908
1426
  //}
3909
1427
 
@@ -3913,7 +1431,7 @@ inline static void ggml_vec_silu_f32(const int n, float * y, const float * x) {
3913
1431
  for (int i = 0; i < n; ++i) {
3914
1432
  ggml_fp16_t fp16 = GGML_FP32_TO_FP16(x[i]);
3915
1433
  memcpy(&t, &fp16, sizeof(uint16_t));
3916
- y[i] = GGML_FP16_TO_FP32(table_silu_f16[t]);
1434
+ y[i] = GGML_FP16_TO_FP32(ggml_table_silu_f16[t]);
3917
1435
  }
3918
1436
  }
3919
1437
  #else
@@ -4629,11 +2147,11 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
4629
2147
  for (int i = 0; i < (1 << 16); ++i) {
4630
2148
  uint16_t ui = i;
4631
2149
  memcpy(&ii, &ui, sizeof(ii));
4632
- const float f = table_f32_f16[i] = GGML_COMPUTE_FP16_TO_FP32(ii);
4633
- table_gelu_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_f32(f));
4634
- table_gelu_quick_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_quick_f32(f));
4635
- table_silu_f16[i] = GGML_FP32_TO_FP16(ggml_silu_f32(f));
4636
- table_exp_f16[i] = GGML_FP32_TO_FP16(expf(f));
2150
+ const float f = ggml_table_f32_f16[i] = GGML_COMPUTE_FP16_TO_FP32(ii);
2151
+ ggml_table_gelu_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_f32(f));
2152
+ ggml_table_gelu_quick_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_quick_f32(f));
2153
+ ggml_table_silu_f16[i] = GGML_FP32_TO_FP16(ggml_silu_f32(f));
2154
+ ggml_table_exp_f16[i] = GGML_FP32_TO_FP16(expf(f));
4637
2155
  }
4638
2156
 
4639
2157
  const uint64_t t_end = ggml_time_us(); UNUSED(t_end);
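Note on the loop above: it visits every 16-bit pattern once, so the renamed ggml_table_* arrays are full 65536-entry tables indexed by the raw f16 bits. A lookup then reduces to a memcpy and an array read, as in this sketch of the pattern used by ggml_vec_gelu_f16 earlier in the diff (hypothetical helper; assumes <string.h>):

    static inline ggml_fp16_t gelu_f16_lookup(ggml_fp16_t x) {
        uint16_t t;
        memcpy(&t, &x, sizeof(t));     // reinterpret the f16 bits as a table index
        return ggml_table_gelu_f16[t]; // one load instead of evaluating GELU
    }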
@@ -5636,7 +3154,7 @@ static struct ggml_tensor * ggml_add_cast_impl(
5636
3154
  // TODO: support less-strict constraint
5637
3155
  // GGML_ASSERT(ggml_can_repeat(b, a));
5638
3156
  GGML_ASSERT(ggml_can_repeat_rows(b, a));
5639
- GGML_ASSERT(ggml_is_quantized(a->type)); // currently only supported for quantized input
3157
+ GGML_ASSERT(ggml_is_quantized(a->type) || a->type == GGML_TYPE_F16); // currently only supported for quantized input and f16
5640
3158
 
5641
3159
  bool is_node = false;
5642
3160
 
@@ -7328,8 +4846,13 @@ static struct ggml_tensor * ggml_rope_impl(
7328
4846
  int n_dims,
7329
4847
  int mode,
7330
4848
  int n_ctx,
4849
+ int n_orig_ctx,
7331
4850
  float freq_base,
7332
4851
  float freq_scale,
4852
+ float ext_factor,
4853
+ float attn_factor,
4854
+ float beta_fast,
4855
+ float beta_slow,
7333
4856
  float xpos_base,
7334
4857
  bool xpos_down,
7335
4858
  bool inplace) {
@@ -7345,11 +4868,15 @@ static struct ggml_tensor * ggml_rope_impl(
7345
4868
 
7346
4869
  struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
7347
4870
 
7348
- int32_t params[8] = { /*n_past*/ 0, n_dims, mode, n_ctx };
7349
- memcpy(params + 4, &freq_base, sizeof(float));
7350
- memcpy(params + 5, &freq_scale, sizeof(float));
7351
- memcpy(params + 6, &xpos_base, sizeof(float));
7352
- memcpy(params + 7, &xpos_down, sizeof(bool));
4871
+ int32_t params[13] = { /*n_past*/ 0, n_dims, mode, n_ctx, n_orig_ctx };
4872
+ memcpy(params + 5, &freq_base, sizeof(float));
4873
+ memcpy(params + 6, &freq_scale, sizeof(float));
4874
+ memcpy(params + 7, &ext_factor, sizeof(float));
4875
+ memcpy(params + 8, &attn_factor, sizeof(float));
4876
+ memcpy(params + 9, &beta_fast, sizeof(float));
4877
+ memcpy(params + 10, &beta_slow, sizeof(float));
4878
+ memcpy(params + 11, &xpos_base, sizeof(float));
4879
+ memcpy(params + 12, &xpos_down, sizeof(bool));
7353
4880
  ggml_set_op_params(result, params, sizeof(params));
7354
4881
 
7355
4882
  result->op = GGML_OP_ROPE;
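Note: the float and bool arguments are spliced into the int32_t params array with memcpy rather than pointer casts, which keeps the type punning clear of strict-aliasing problems; the array grows from 8 to 13 words to make room for the five new YaRN fields. The consumer recovers them symmetrically, as later in this diff:

    float freq_base;
    memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));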
@@ -7367,7 +4894,9 @@ struct ggml_tensor * ggml_rope(
7367
4894
  int n_dims,
7368
4895
  int mode,
7369
4896
  int n_ctx) {
7370
- return ggml_rope_impl(ctx, a, b, n_dims, mode, n_ctx, 10000.0f, 1.0f, 0.0f, false, false);
4897
+ return ggml_rope_impl(
4898
+ ctx, a, b, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, false
4899
+ );
7371
4900
  }
7372
4901
 
7373
4902
  struct ggml_tensor * ggml_rope_inplace(
@@ -7377,7 +4906,9 @@ struct ggml_tensor * ggml_rope_inplace(
7377
4906
  int n_dims,
7378
4907
  int mode,
7379
4908
  int n_ctx) {
7380
- return ggml_rope_impl(ctx, a, b, n_dims, mode, n_ctx, 10000.0f, 1.0f, 0.0f, false, true);
4909
+ return ggml_rope_impl(
4910
+ ctx, a, b, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, true
4911
+ );
7381
4912
  }
7382
4913
 
7383
4914
  struct ggml_tensor * ggml_rope_custom(
@@ -7387,9 +4918,17 @@ struct ggml_tensor * ggml_rope_custom(
7387
4918
  int n_dims,
7388
4919
  int mode,
7389
4920
  int n_ctx,
4921
+ int n_orig_ctx,
7390
4922
  float freq_base,
7391
- float freq_scale) {
7392
- return ggml_rope_impl(ctx, a, b, n_dims, mode, n_ctx, freq_base, freq_scale, 0.0f, false, false);
4923
+ float freq_scale,
4924
+ float ext_factor,
4925
+ float attn_factor,
4926
+ float beta_fast,
4927
+ float beta_slow) {
4928
+ return ggml_rope_impl(
4929
+ ctx, a, b, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
4930
+ ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, false
4931
+ );
7393
4932
  }
7394
4933
 
7395
4934
  struct ggml_tensor * ggml_rope_custom_inplace(
@@ -7399,9 +4938,17 @@ struct ggml_tensor * ggml_rope_custom_inplace(
7399
4938
  int n_dims,
7400
4939
  int mode,
7401
4940
  int n_ctx,
4941
+ int n_orig_ctx,
7402
4942
  float freq_base,
7403
- float freq_scale) {
7404
- return ggml_rope_impl(ctx, a, b, n_dims, mode, n_ctx, freq_base, freq_scale, 0.0f, false, true);
4943
+ float freq_scale,
4944
+ float ext_factor,
4945
+ float attn_factor,
4946
+ float beta_fast,
4947
+ float beta_slow) {
4948
+ return ggml_rope_impl(
4949
+ ctx, a, b, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
4950
+ ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, true
4951
+ );
7405
4952
  }
7406
4953
 
7407
4954
  struct ggml_tensor * ggml_rope_xpos_inplace(
@@ -7411,7 +4958,7 @@ struct ggml_tensor * ggml_rope_xpos_inplace(
7411
4958
  int n_dims,
7412
4959
  float base,
7413
4960
  bool down) {
7414
- return ggml_rope_impl(ctx, a, b, n_dims, 0, 0, 10000.0f, 1.0f, base, down, true);
4961
+ return ggml_rope_impl(ctx, a, b, n_dims, 0, 0, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, base, down, true);
7415
4962
  }
7416
4963
 
7417
4964
  // ggml_rope_back
@@ -9410,9 +6957,15 @@ static void ggml_compute_forward_add_f16_f32(
9410
6957
 
9411
6958
  GGML_ASSERT(src0->type == GGML_TYPE_F16);
9412
6959
  GGML_ASSERT(src1->type == GGML_TYPE_F32);
9413
- GGML_ASSERT(dst->type == GGML_TYPE_F16);
9414
6960
 
9415
- GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
6961
+ if (dst->type == GGML_TYPE_F32) {
6962
+ GGML_ASSERT( nb0 == sizeof(float));
6963
+ }
6964
+ else {
6965
+ GGML_ASSERT(dst->type == GGML_TYPE_F16);
6966
+ GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
6967
+ }
6968
+
9416
6969
  GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
9417
6970
 
9418
6971
  // rows per thread
@@ -9423,18 +6976,35 @@ static void ggml_compute_forward_add_f16_f32(
9423
6976
  const int ir1 = MIN(ir0 + dr, nr);
9424
6977
 
9425
6978
  if (nb10 == sizeof(float)) {
9426
- for (int ir = ir0; ir < ir1; ++ir) {
9427
- // src0, src1 and dst are same shape => same indices
9428
- const int i3 = ir/(ne2*ne1);
9429
- const int i2 = (ir - i3*ne2*ne1)/ne1;
9430
- const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
9431
-
9432
- ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
9433
- ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
9434
- float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11);
9435
-
9436
- for (int i = 0; i < ne0; i++) {
9437
- dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + src1_ptr[i]);
6979
+ if (dst->type == GGML_TYPE_F16) {
6980
+ for (int ir = ir0; ir < ir1; ++ir) {
6981
+ // src0, src1 and dst are same shape => same indices
6982
+ const int i3 = ir/(ne2*ne1);
6983
+ const int i2 = (ir - i3*ne2*ne1)/ne1;
6984
+ const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
6985
+
6986
+ ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
6987
+ ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
6988
+ float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11);
6989
+
6990
+ for (int i = 0; i < ne0; i++) {
6991
+ dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + src1_ptr[i]);
6992
+ }
6993
+ }
6994
+ } else {
6995
+ for (int ir = ir0; ir < ir1; ++ir) {
6996
+ // src0, src1 and dst are same shape => same indices
6997
+ const int i3 = ir/(ne2*ne1);
6998
+ const int i2 = (ir - i3*ne2*ne1)/ne1;
6999
+ const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
7000
+
7001
+ float * dst_ptr = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
7002
+ ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
7003
+ float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11);
7004
+
7005
+ for (int i = 0; i < ne0; i++) {
7006
+ dst_ptr[i] = GGML_FP16_TO_FP32(src0_ptr[i]) + src1_ptr[i];
7007
+ }
9438
7008
  }
9439
7009
  }
9440
7010
  }
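Note on the row loops above: both branches unpack the flat row index before applying the byte strides, and differ only in whether each output element is narrowed back to f16.

    // ir = (i3*ne2 + i2)*ne1 + i1 is inverted as:
    //   i3 = ir/(ne2*ne1);
    //   i2 = (ir - i3*ne2*ne1)/ne1;
    //   i1 =  ir - i3*ne2*ne1 - i2*ne1;
    // byte offsets then use the per-dimension strides nb1/nb2/nb3 (dst)
    // and nb01..nb03 / nb11..nb13 (src0/src1).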
@@ -12996,7 +10566,7 @@ static void ggml_compute_forward_soft_max_f32(
12996
10566
  // const float val = (sp[i] == -INFINITY) ? 0.0 : exp(sp[i] - max);
12997
10567
  ggml_fp16_t s = GGML_FP32_TO_FP16(sp[i] - max);
12998
10568
  memcpy(&scvt, &s, sizeof(scvt));
12999
- const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt]);
10569
+ const float val = GGML_FP16_TO_FP32(ggml_table_exp_f16[scvt]);
13000
10570
  sum += (ggml_float)val;
13001
10571
  dp[i] = val;
13002
10572
  }
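Note: only the table name changes here; the mechanism is that the exponential is served from the precomputed ggml_table_exp_f16 by narrowing the argument to f16 and using its bit pattern as the index, trading a transcendental call per element for one conversion plus a load from a 128 KiB table, at the cost of f16 resolution on the argument.

    // val = expf(sp[i] - max) is computed as:
    // val = GGML_FP16_TO_FP32(ggml_table_exp_f16[ bits(GGML_FP32_TO_FP16(sp[i] - max)) ])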
@@ -13361,6 +10931,45 @@ static void ggml_compute_forward_clamp(
13361
10931
 
13362
10932
  // ggml_compute_forward_rope
13363
10933
 
10934
+ static float rope_yarn_ramp(const float low, const float high, const int i0) {
10935
+ const float y = (i0 / 2 - low) / MAX(0.001f, high - low);
10936
+ return 1 - MIN(1, MAX(0, y));
10937
+ }
10938
+
10939
+ // YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
10940
+ // MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
10941
+ static void rope_yarn(
10942
+ float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale,
10943
+ float * cos_theta, float * sin_theta
10944
+ ) {
10945
+ // Get n-d rotational scaling corrected for extrapolation
10946
+ float theta_interp = freq_scale * theta_extrap;
10947
+ float theta = theta_interp;
10948
+ if (ext_factor != 0.0f) {
10949
+ float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
10950
+ theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
10951
+
10952
+ // Get n-d magnitude scaling corrected for interpolation
10953
+ mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale);
10954
+ }
10955
+ *cos_theta = cosf(theta) * mscale;
10956
+ *sin_theta = sinf(theta) * mscale;
10957
+ }
10958
+
10959
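Gloss on rope_yarn above: with r = rope_yarn_ramp(low, high, i0) and mix m = r * ext_factor, the angle is the convex blend

    theta = (1 - m) * theta_interp + m * theta_extrap,  where theta_interp = freq_scale * theta_extrap

so ext_factor = 0 reduces to plain position-interpolated RoPE, and the magnitude correction 1 + 0.1*ln(1/freq_scale) is only applied when extrapolation is mixed in.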
+ // Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
10960
+ // `corr_dim(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
10961
+ static float ggml_rope_yarn_corr_dim(int n_dims, int n_orig_ctx, float n_rot, float base) {
10962
+ return n_dims * logf(n_orig_ctx / (n_rot * 2 * (float)M_PI)) / (2 * logf(base));
10963
+ }
10964
+
10965
+ void ggml_rope_yarn_corr_dims(
10966
+ int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]
10967
+ ) {
10968
+ // start and end correction dims
10969
+ dims[0] = MAX(0, floorf(ggml_rope_yarn_corr_dim(n_dims, n_orig_ctx, beta_fast, freq_base)));
10970
+ dims[1] = MIN(n_dims - 1, ceilf(ggml_rope_yarn_corr_dim(n_dims, n_orig_ctx, beta_slow, freq_base)));
10971
+ }
10972
+
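Worked example for ggml_rope_yarn_corr_dims, with illustrative values (n_dims = 128, n_orig_ctx = 4096, freq_base = 10000, beta_fast = 32, beta_slow = 1; none of these come from the diff):

    corr_dim(32) = 128 * ln(4096 / (32 * 2*pi)) / (2 * ln 10000) ~= 20.94 -> dims[0] = floor(20.94) = 20
    corr_dim(1)  = 128 * ln(4096 / ( 1 * 2*pi)) / (2 * ln 10000) ~= 45.03 -> dims[1] = ceil(45.03)  = 46

With ext_factor = 1, rotary pairs whose index i0/2 falls below dims[0] keep the extrapolated angle, pairs above dims[1] use the interpolated one, and the ramp blends the band in between.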
13364
10973
  static void ggml_compute_forward_rope_f32(
13365
10974
  const struct ggml_compute_params * params,
13366
10975
  const struct ggml_tensor * src0,
@@ -13370,21 +10979,26 @@ static void ggml_compute_forward_rope_f32(
13370
10979
  return;
13371
10980
  }
13372
10981
 
13373
- float freq_base;
13374
- float freq_scale;
10982
+ float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
13375
10983
 
13376
10984
  // these two only relevant for xPos RoPE:
13377
10985
  float xpos_base;
13378
10986
  bool xpos_down;
13379
10987
 
13380
- //const int n_past = ((int32_t *) dst->op_params)[0];
13381
- const int n_dims = ((int32_t *) dst->op_params)[1];
13382
- const int mode = ((int32_t *) dst->op_params)[2];
13383
- const int n_ctx = ((int32_t *) dst->op_params)[3];
13384
- memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float));
13385
- memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
13386
- memcpy(&xpos_base, (int32_t *) dst->op_params + 6, sizeof(float));
13387
- memcpy(&xpos_down, (int32_t *) dst->op_params + 7, sizeof(bool));
10988
+ //const int n_past = ((int32_t *) dst->op_params)[0];
10989
+ const int n_dims = ((int32_t *) dst->op_params)[1];
10990
+ const int mode = ((int32_t *) dst->op_params)[2];
10991
+ const int n_ctx = ((int32_t *) dst->op_params)[3];
10992
+ const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
10993
+
10994
+ memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
10995
+ memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
10996
+ memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
10997
+ memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
10998
+ memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
10999
+ memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
11000
+ memcpy(&xpos_base, (int32_t *) dst->op_params + 11, sizeof(float));
11001
+ memcpy(&xpos_down, (int32_t *) dst->op_params + 12, sizeof(bool));
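
Both RoPE kernels now unpack thirteen parameters from dst->op_params, using memcpy rather than pointer casts so the float and bool slots stay well-defined under strict aliasing. The packing side is not part of this hunk, so the sketch below is an illustrative mirror of the unpacking order above, not code copied from ggml:

    #include <stdbool.h>
    #include <stdint.h>
    #include <string.h>

    // pack in the same slot order the kernels read back out
    static void pack_rope_params(
        int32_t op_params[13],
        int n_past, int n_dims, int mode, int n_ctx, int n_orig_ctx,
        float freq_base, float freq_scale, float ext_factor, float attn_factor,
        float beta_fast, float beta_slow, float xpos_base, bool xpos_down
    ) {
        op_params[0] = n_past;
        op_params[1] = n_dims;
        op_params[2] = mode;
        op_params[3] = n_ctx;
        op_params[4] = n_orig_ctx;
        // memcpy (not a cast) keeps this well-defined under strict aliasing
        memcpy(op_params +  5, &freq_base,   sizeof(float));
        memcpy(op_params +  6, &freq_scale,  sizeof(float));
        memcpy(op_params +  7, &ext_factor,  sizeof(float));
        memcpy(op_params +  8, &attn_factor, sizeof(float));
        memcpy(op_params +  9, &beta_fast,   sizeof(float));
        memcpy(op_params + 10, &beta_slow,   sizeof(float));
        memcpy(op_params + 11, &xpos_base,   sizeof(float));
        memcpy(op_params + 12, &xpos_down,   sizeof(bool));
    }
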
13388
11002
 
13389
11003
  GGML_TENSOR_UNARY_OP_LOCALS
13390
11004
 
@@ -13412,6 +11026,9 @@ static void ggml_compute_forward_rope_f32(
13412
11026
  int ir = 0;
13413
11027
 
13414
11028
  const float theta_scale = powf(freq_base, -2.0f/n_dims);
11029
+ const float inv_ndims = -1.f/n_dims;
11030
+ float corr_dims[2];
11031
+ ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
13415
11032
 
13416
11033
  const bool is_neox = mode & 2;
13417
11034
  const bool is_glm = mode & 4;
@@ -13425,18 +11042,18 @@ static void ggml_compute_forward_rope_f32(
13425
11042
  if (ir++ < ir0) continue;
13426
11043
  if (ir > ir1) break;
13427
11044
 
13428
- float theta = freq_scale * (float)p;
11045
+ float theta_base = (float)p;
13429
11046
 
13430
11047
  if (is_glm) {
13431
- theta = MIN(p, n_ctx - 2);
11048
+ theta_base = MIN(p, n_ctx - 2);
13432
11049
  float block_theta = MAX(p - (n_ctx - 2), 0);
13433
11050
  for (int64_t i0 = 0; i0 < ne0 / 4; i0++) {
13434
- const float cos_theta = cosf(theta);
13435
- const float sin_theta = sinf(theta);
11051
+ const float cos_theta = cosf(theta_base);
11052
+ const float sin_theta = sinf(theta_base);
13436
11053
  const float cos_block_theta = cosf(block_theta);
13437
11054
  const float sin_block_theta = sinf(block_theta);
13438
11055
 
13439
- theta *= theta_scale;
11056
+ theta_base *= theta_scale;
13440
11057
  block_theta *= theta_scale;
13441
11058
 
13442
11059
  const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
@@ -13454,13 +11071,16 @@ static void ggml_compute_forward_rope_f32(
13454
11071
  }
13455
11072
  } else if (!is_neox) {
13456
11073
  for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
13457
- const float cos_theta = cosf(theta);
13458
- const float sin_theta = sinf(theta);
11074
+ float cos_theta, sin_theta;
11075
+ rope_yarn(
11076
+ theta_base, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta
11077
+ );
11078
+
13459
11079
  // zeta scaling for xPos only:
13460
11080
  float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), p / xpos_base) : 1.0f;
13461
11081
  if (xpos_down) zeta = 1.0f / zeta;
13462
11082
 
13463
- theta *= theta_scale;
11083
+ theta_base *= theta_scale;
13464
11084
 
13465
11085
  const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
13466
11086
  float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
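
For reference, the xPos damping factor zeta used just above the rotation can be evaluated in isolation: it decays each dimension pair as a power of the position. The values below are illustrative:

    #include <math.h>
    #include <stdbool.h>
    #include <stdio.h>

    int main(void) {
        const float ne0       = 128.0f; // row width
        const float xpos_base = 512.0f;
        const bool  xpos_down = false;
        const float p         = 100.0f; // position

        for (int i0 = 0; i0 < 8; i0 += 2) {
            // zeta = ((i0 + 0.4*ne0) / (1.4*ne0)) ^ (p / xpos_base)
            float zeta = powf((i0 + 0.4f*ne0) / (1.4f*ne0), p / xpos_base);
            if (xpos_down) zeta = 1.0f / zeta; // inverted for the key side
            printf("i0 = %d, zeta = %f\n", i0, zeta);
        }
        return 0;
    }
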
@@ -13474,12 +11094,19 @@ static void ggml_compute_forward_rope_f32(
13474
11094
  } else {
13475
11095
  // TODO: this might be wrong for ne0 != n_dims - need double check
13476
11096
  // ref: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py#LL251C1-L294C28
11097
+ theta_base *= freq_scale;
13477
11098
  for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
13478
11099
  for (int64_t ic = 0; ic < n_dims; ic += 2) {
13479
- const float cos_theta = cosf(theta);
13480
- const float sin_theta = sinf(theta);
11100
+ // simplified from `(ib * n_dims + ic) * inv_ndims`
11101
+ float cur_rot = inv_ndims * ic - ib;
11102
+
11103
+ float cos_theta, sin_theta;
11104
+ rope_yarn(
11105
+ theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
11106
+ &cos_theta, &sin_theta
11107
+ );
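
The `cur_rot` simplification holds because inv_ndims is -1/n_dims, so (ib*n_dims + ic) * inv_ndims = -ib - ic/n_dims = inv_ndims*ic - ib. A throwaway check of that algebra:

    #include <assert.h>
    #include <math.h>

    int main(void) {
        const int   n_dims    = 128;
        const float inv_ndims = -1.f/n_dims;
        for (int ib = 0; ib < 4; ++ib) {
            for (int ic = 0; ic < n_dims; ic += 2) {
                const float lhs = (ib*n_dims + ic) * inv_ndims; // original form
                const float rhs = inv_ndims*ic - ib;            // simplified form
                assert(fabsf(lhs - rhs) < 1e-6f);
            }
        }
        return 0;
    }
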
13481
11108
 
13482
- theta *= theta_scale;
11109
+ theta_base *= theta_scale;
13483
11110
 
13484
11111
  const int64_t i0 = ib*n_dims + ic/2;
13485
11112
 
@@ -13508,15 +11135,19 @@ static void ggml_compute_forward_rope_f16(
13508
11135
  return;
13509
11136
  }
13510
11137
 
13511
- float freq_base;
13512
- float freq_scale;
11138
+ float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
13513
11139
 
13514
- //const int n_past = ((int32_t *) dst->op_params)[0];
13515
- const int n_dims = ((int32_t *) dst->op_params)[1];
13516
- const int mode = ((int32_t *) dst->op_params)[2];
13517
- const int n_ctx = ((int32_t *) dst->op_params)[3];
13518
- memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float));
13519
- memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
11140
+ //const int n_past = ((int32_t *) dst->op_params)[0];
11141
+ const int n_dims = ((int32_t *) dst->op_params)[1];
11142
+ const int mode = ((int32_t *) dst->op_params)[2];
11143
+ const int n_ctx = ((int32_t *) dst->op_params)[3];
11144
+ const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
11145
+ memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
11146
+ memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
11147
+ memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
11148
+ memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
11149
+ memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
11150
+ memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
13520
11151
 
13521
11152
  GGML_TENSOR_UNARY_OP_LOCALS
13522
11153
 
@@ -13544,6 +11175,9 @@ static void ggml_compute_forward_rope_f16(
13544
11175
  int ir = 0;
13545
11176
 
13546
11177
  const float theta_scale = powf(freq_base, -2.0f/n_dims);
11178
+ const float inv_ndims = -1.f/n_dims;
11179
+ float corr_dims[2];
11180
+ ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
13547
11181
 
13548
11182
  const bool is_neox = mode & 2;
13549
11183
  const bool is_glm = mode & 4;
@@ -13557,18 +11191,18 @@ static void ggml_compute_forward_rope_f16(
13557
11191
  if (ir++ < ir0) continue;
13558
11192
  if (ir > ir1) break;
13559
11193
 
13560
- float theta = freq_scale * (float)p;
11194
+ float theta_base = (float)p;
13561
11195
 
13562
11196
  if (is_glm) {
13563
- theta = MIN(p, n_ctx - 2);
11197
+ theta_base = MIN(p, n_ctx - 2);
13564
11198
  float block_theta = MAX(p - (n_ctx - 2), 0);
13565
11199
  for (int64_t i0 = 0; i0 < ne0 / 4; i0++) {
13566
- const float cos_theta = cosf(theta);
13567
- const float sin_theta = sinf(theta);
11200
+ const float cos_theta = cosf(theta_base);
11201
+ const float sin_theta = sinf(theta_base);
13568
11202
  const float cos_block_theta = cosf(block_theta);
13569
11203
  const float sin_block_theta = sinf(block_theta);
13570
11204
 
13571
- theta *= theta_scale;
11205
+ theta_base *= theta_scale;
13572
11206
  block_theta *= theta_scale;
13573
11207
 
13574
11208
  const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
@@ -13586,10 +11220,12 @@ static void ggml_compute_forward_rope_f16(
13586
11220
  }
13587
11221
  } else if (!is_neox) {
13588
11222
  for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
13589
- const float cos_theta = cosf(theta);
13590
- const float sin_theta = sinf(theta);
11223
+ float cos_theta, sin_theta;
11224
+ rope_yarn(
11225
+ theta_base, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta
11226
+ );
13591
11227
 
13592
- theta *= theta_scale;
11228
+ theta_base *= theta_scale;
13593
11229
 
13594
11230
  const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
13595
11231
  ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
@@ -13603,12 +11239,19 @@ static void ggml_compute_forward_rope_f16(
13603
11239
  } else {
13604
11240
  // TODO: this might be wrong for ne0 != n_dims - need double check
13605
11241
  // ref: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py#LL251C1-L294C28
11242
+ theta_base *= freq_scale;
13606
11243
  for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
13607
11244
  for (int64_t ic = 0; ic < n_dims; ic += 2) {
13608
- const float cos_theta = cosf(theta);
13609
- const float sin_theta = sinf(theta);
11245
+ // simplified from `(ib * n_dims + ic) * inv_ndims`
11246
+ float cur_rot = inv_ndims * ic - ib;
11247
+
11248
+ float cos_theta, sin_theta;
11249
+ rope_yarn(
11250
+ theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
11251
+ &cos_theta, &sin_theta
11252
+ );
13610
11253
 
13611
- theta *= theta_scale;
11254
+ theta_base *= theta_scale;
13612
11255
 
13613
11256
  const int64_t i0 = ib*n_dims + ic/2;
13614
11257
 
@@ -13716,17 +11359,18 @@ static void ggml_compute_forward_rope_back_f32(
13716
11359
  if (ir++ < ir0) continue;
13717
11360
  if (ir > ir1) break;
13718
11361
 
13719
- float theta = freq_scale * (float)p;
11362
+ float theta_base = freq_scale * (float)p;
13720
11363
 
13721
11364
  if (!is_neox) {
13722
11365
  for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
13723
- const float cos_theta = cosf(theta);
13724
- const float sin_theta = sinf(theta);
11366
+ const float cos_theta = cosf(theta_base);
11367
+ const float sin_theta = sinf(theta_base);
11368
+
13725
11369
  // zeta scaling for xPos only:
13726
11370
  float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), p / xpos_base) : 1.0f;
13727
11371
  if (xpos_down) zeta = 1.0f / zeta;
13728
11372
 
13729
- theta *= theta_scale;
11373
+ theta_base *= theta_scale;
13730
11374
 
13731
11375
  const float * const dy = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
13732
11376
  float * dx = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
@@ -13740,10 +11384,10 @@ static void ggml_compute_forward_rope_back_f32(
13740
11384
  } else {
13741
11385
  for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
13742
11386
  for (int64_t ic = 0; ic < n_dims; ic += 2) {
13743
- const float cos_theta = cosf(theta);
13744
- const float sin_theta = sinf(theta);
11387
+ const float cos_theta = cosf(theta_base);
11388
+ const float sin_theta = sinf(theta_base);
13745
11389
 
13746
- theta *= theta_scale;
11390
+ theta_base *= theta_scale;
13747
11391
 
13748
11392
  const int64_t i0 = ib*n_dims + ic/2;
13749
11393
 
@@ -13816,14 +11460,14 @@ static void ggml_compute_forward_rope_back_f16(
13816
11460
  if (ir++ < ir0) continue;
13817
11461
  if (ir > ir1) break;
13818
11462
 
13819
- float theta = (float)p;
11463
+ float theta_base = (float)p;
13820
11464
 
13821
11465
  if (!is_neox) {
13822
11466
  for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
13823
- const float cos_theta = cosf(theta);
13824
- const float sin_theta = sinf(theta);
11467
+ const float cos_theta = cosf(theta_base);
11468
+ const float sin_theta = sinf(theta_base);
13825
11469
 
13826
- theta *= theta_scale;
11470
+ theta_base *= theta_scale;
13827
11471
 
13828
11472
  const ggml_fp16_t * const dy = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
13829
11473
  ggml_fp16_t * dx = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
@@ -13837,10 +11481,10 @@ static void ggml_compute_forward_rope_back_f16(
13837
11481
  } else {
13838
11482
  for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
13839
11483
  for (int64_t ic = 0; ic < n_dims; ic += 2) {
13840
- const float cos_theta = cosf(theta);
13841
- const float sin_theta = sinf(theta);
11484
+ const float cos_theta = cosf(theta_base);
11485
+ const float sin_theta = sinf(theta_base);
13842
11486
 
13843
- theta *= theta_scale;
11487
+ theta_base *= theta_scale;
13844
11488
 
13845
11489
  const int64_t i0 = ib*n_dims + ic/2;
13846
11490
 
@@ -15285,7 +12929,7 @@ static void ggml_compute_forward_flash_attn_f32(
15285
12929
  #else
15286
12930
  ggml_fp16_t s = GGML_FP32_TO_FP16(SS[j] - max);
15287
12931
  memcpy(&scvt[j], &s, sizeof(uint16_t));
15288
- const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt[j]]);
12932
+ const float val = GGML_FP16_TO_FP32(ggml_table_exp_f16[scvt[j]]);
15289
12933
  #endif
15290
12934
  sump[j] += (ggml_float)val;
15291
12935
  SS[j] = val;
@@ -15487,7 +13131,7 @@ static void ggml_compute_forward_flash_attn_f16(
15487
13131
  } else {
15488
13132
  ggml_fp16_t s = GGML_FP32_TO_FP16(SS[j] - max);
15489
13133
  memcpy(&scvt[j], &s, sizeof(uint16_t));
15490
- const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt[j]]);
13134
+ const float val = GGML_FP16_TO_FP32(ggml_table_exp_f16[scvt[j]]);
15491
13135
  sump[j] += (ggml_float)val;
15492
13136
  SS[j] = val;
15493
13137
  }
@@ -15938,7 +13582,7 @@ static void ggml_compute_forward_flash_attn_back_f32(
15938
13582
  #else
15939
13583
  ggml_fp16_t s = GGML_FP32_TO_FP16(SR[j] - max);
15940
13584
  memcpy(&scvt[j], &s, sizeof(uint16_t));
15941
- const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt[j]]);
13585
+ const float val = GGML_FP16_TO_FP32(ggml_table_exp_f16[scvt[j]]);
15942
13586
  #endif
15943
13587
  sump[j] += (ggml_float)val;
15944
13588
  SW[j] = val;
@@ -16688,7 +14332,7 @@ static void ggml_compute_forward_cross_entropy_loss_f32(
16688
14332
  #else
16689
14333
  ggml_fp16_t s = GGML_FP32_TO_FP16(s0[i] - max);
16690
14334
  memcpy(&scvt, &s, sizeof(scvt));
16691
- const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt]);
14335
+ const float val = GGML_FP16_TO_FP32(ggml_table_exp_f16[scvt]);
16692
14336
  #endif
16693
14337
  sum += (ggml_float)val;
16694
14338
  st[i] = val;
@@ -16802,7 +14446,7 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32(
16802
14446
  #else
16803
14447
  ggml_fp16_t s = GGML_FP32_TO_FP16(s0[i] - max);
16804
14448
  memcpy(&scvt, &s, sizeof(scvt));
16805
- const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt]);
14449
+ const float val = GGML_FP16_TO_FP32(ggml_table_exp_f16[scvt]);
16806
14450
  #endif
16807
14451
  sum += (ggml_float)val;
16808
14452
  ds0[i] = val;
@@ -17965,9 +15609,14 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
17965
15609
  src1,
17966
15610
  n_dims,
17967
15611
  mode,
15612
+ 0,
17968
15613
  n_ctx,
17969
15614
  freq_base,
17970
15615
  freq_scale,
15616
+ 0.0f,
15617
+ 1.0f,
15618
+ 0.0f,
15619
+ 0.0f,
17971
15620
  xpos_base,
17972
15621
  xpos_down,
17973
15622
  false),
@@ -21001,7 +18650,6 @@ size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, i
21001
18650
  block_q8_0 * block = (block_q8_0*)dst + start / QK8_0;
21002
18651
  result = ggml_quantize_q8_0(src + start, block, n, n, hist);
21003
18652
  } break;
21004
- #ifdef GGML_USE_K_QUANTS
21005
18653
  case GGML_TYPE_Q2_K:
21006
18654
  {
21007
18655
  GGML_ASSERT(start % QK_K == 0);
@@ -21032,7 +18680,6 @@ size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, i
21032
18680
  block_q6_K * block = (block_q6_K*)dst + start / QK_K;
21033
18681
  result = ggml_quantize_q6_K(src + start, block, n, n, hist);
21034
18682
  } break;
21035
- #endif
21036
18683
  case GGML_TYPE_F16:
21037
18684
  {
21038
18685
  int elemsize = sizeof(ggml_fp16_t);
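
With the GGML_USE_K_QUANTS guard gone, the K-quant cases are compiled unconditionally. A hedged usage sketch of ggml_quantize_chunk, inferred from the signature in the hunk header and the ggml_quantize_q8_0(src + start, block, n, n, hist) call above; the 16-entry histogram size is an assumption about ggml's convention, not something this diff confirms:

    #include <stdint.h>
    #include <string.h>
    #include "ggml.h"

    // quantize a whole fp32 buffer to Q4_0 in one chunk (start = 0)
    size_t quantize_all_q4_0(const float * src, void * dst, int n) {
        int64_t hist[16];                 // assumed histogram size
        memset(hist, 0, sizeof(hist));
        return ggml_quantize_chunk(GGML_TYPE_Q4_0, src, dst, /*start=*/0, n, hist);
    }
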
@@ -21164,8 +18811,7 @@ static bool gguf_fread_el(FILE * file, void * dst, size_t size, size_t * offset)
21164
18811
  return n == size;
21165
18812
  }
21166
18813
 
21167
- // NOTE: temporary handling of GGUFv1 >> remove after Oct 2023
21168
- static bool gguf_fread_str_cur(FILE * file, struct gguf_str * p, size_t * offset) {
18814
+ static bool gguf_fread_str(FILE * file, struct gguf_str * p, size_t * offset) {
21169
18815
  p->n = 0;
21170
18816
  p->data = NULL;
21171
18817
 
@@ -21177,19 +18823,6 @@ static bool gguf_fread_str_cur(FILE * file, struct gguf_str * p, size_t * offset
21177
18823
  return ok;
21178
18824
  }
21179
18825
 
21180
- static bool gguf_fread_str_v1(FILE * file, struct gguf_str * p, size_t * offset) {
21181
- p->n = 0;
21182
- p->data = NULL;
21183
-
21184
- bool ok = true;
21185
-
21186
- uint32_t n = 0;
21187
- ok = ok && gguf_fread_el(file, &n, sizeof(n), offset); p->data = calloc(n + 1, 1); p->n = n;
21188
- ok = ok && gguf_fread_el(file, p->data, p->n, offset);
21189
-
21190
- return ok;
21191
- }
21192
-
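
GGUF strings are length-prefixed: v1 used a 32-bit count (hence the deleted gguf_fread_str_v1), v2 a 64-bit one. A standalone sketch of reading such a string, simplified relative to ggml's reader, whose body is partly elided in this diff:

    #include <stdbool.h>
    #include <stdint.h>
    #include <stdio.h>
    #include <stdlib.h>

    // read a uint64_t byte count, then that many bytes (v2-style string)
    static bool read_lp_string(FILE * file, char ** data, uint64_t * n) {
        uint64_t len = 0;
        if (fread(&len, sizeof(len), 1, file) != 1) {
            return false;
        }
        char * buf = calloc(len + 1, 1); // +1 keeps a NUL terminator
        if (buf == NULL) {
            return false;
        }
        if (len > 0 && fread(buf, 1, len, file) != len) {
            free(buf);
            return false;
        }
        *data = buf;
        *n    = len;
        return true;
    }
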
21193
18826
  struct gguf_context * gguf_init_empty(void) {
21194
18827
  struct gguf_context * ctx = GGML_ALIGNED_MALLOC(sizeof(struct gguf_context));
21195
18828
 
@@ -21248,20 +18881,14 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
21248
18881
  ctx->data = NULL;
21249
18882
 
21250
18883
  ok = ok && gguf_fread_el(file, &ctx->header.version, sizeof(ctx->header.version), &offset);
18884
+ ok = ok && gguf_fread_el(file, &ctx->header.n_tensors, sizeof(ctx->header.n_tensors), &offset);
18885
+ ok = ok && gguf_fread_el(file, &ctx->header.n_kv, sizeof(ctx->header.n_kv), &offset);
21251
18886
 
21252
18887
  if (ctx->header.version == 1) {
21253
- // NOTE: temporary handling of GGUFv1 >> remove after Oct 2023
21254
- uint32_t n_tensors = 0;
21255
- uint32_t n_kv = 0;
21256
-
21257
- ok = ok && gguf_fread_el(file, &n_tensors, sizeof(n_tensors), &offset);
21258
- ok = ok && gguf_fread_el(file, &n_kv, sizeof(n_kv), &offset);
21259
-
21260
- ctx->header.n_tensors = n_tensors;
21261
- ctx->header.n_kv = n_kv;
21262
- } else {
21263
- ok = ok && gguf_fread_el(file, &ctx->header.n_tensors, sizeof(ctx->header.n_tensors), &offset);
21264
- ok = ok && gguf_fread_el(file, &ctx->header.n_kv, sizeof(ctx->header.n_kv), &offset);
18888
+ fprintf(stderr, "%s: GGUFv1 is no longer supported. please use a more up-to-date version\n", __func__);
18889
+ fclose(file);
18890
+ gguf_free(ctx);
18891
+ return NULL;
21265
18892
  }
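
The header fields are now read unconditionally in the v2 layout (32-bit version, then 64-bit tensor and kv counts), and version 1 files are rejected afterwards. A standalone sketch of that read order using plain fread; magic-number handling happens before this hunk and is omitted here:

    #include <stdint.h>
    #include <stdio.h>

    struct gguf_header_v2 {
        uint32_t version;
        uint64_t n_tensors;
        uint64_t n_kv;
    };

    // read version + counts in v2 layout, then reject v1 files
    static int read_header(FILE * file, struct gguf_header_v2 * h) {
        if (fread(&h->version,   sizeof(h->version),   1, file) != 1) return -1;
        if (fread(&h->n_tensors, sizeof(h->n_tensors), 1, file) != 1) return -1;
        if (fread(&h->n_kv,      sizeof(h->n_kv),      1, file) != 1) return -1;
        if (h->version == 1) {
            fprintf(stderr, "GGUFv1 is no longer supported\n");
            return -1;
        }
        return 0;
    }
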
21266
18893
 
21267
18894
  if (!ok) {
@@ -21272,12 +18899,6 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
21272
18899
  }
21273
18900
  }
21274
18901
 
21275
- // NOTE: temporary handling of GGUFv1 >> remove after Oct 2023
21276
- bool (* gguf_fread_str)(FILE *, struct gguf_str *, size_t *) = gguf_fread_str_cur;
21277
- if (ctx->header.version == 1) {
21278
- gguf_fread_str = gguf_fread_str_v1;
21279
- }
21280
-
21281
18902
  // read the kv pairs
21282
18903
  {
21283
18904
  ctx->kv = malloc(ctx->header.n_kv * sizeof(struct gguf_kv));
@@ -21308,15 +18929,7 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
21308
18929
  case GGUF_TYPE_ARRAY:
21309
18930
  {
21310
18931
  ok = ok && gguf_fread_el(file, &kv->value.arr.type, sizeof(kv->value.arr.type), &offset);
21311
-
21312
- if (ctx->header.version == 1) {
21313
- // NOTE: temporary handling of GGUFv1 >> remove after Oct 2023
21314
- uint32_t n = 0;
21315
- ok = ok && gguf_fread_el(file, &n, sizeof(n), &offset);
21316
- kv->value.arr.n = n;
21317
- } else {
21318
- ok = ok && gguf_fread_el(file, &kv->value.arr.n, sizeof(kv->value.arr.n), &offset);
21319
- }
18932
+ ok = ok && gguf_fread_el(file, &kv->value.arr.n, sizeof(kv->value.arr.n), &offset);
21320
18933
 
21321
18934
  switch (kv->value.arr.type) {
21322
18935
  case GGUF_TYPE_UINT8:
@@ -21375,14 +18988,7 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
21375
18988
  ok = ok && gguf_fread_str(file, &info->name, &offset);
21376
18989
  ok = ok && gguf_fread_el (file, &info->n_dims, sizeof(info->n_dims), &offset);
21377
18990
  for (uint32_t j = 0; j < info->n_dims; ++j) {
21378
- if (ctx->header.version == 1) {
21379
- // NOTE: temporary handling of GGUFv1 >> remove after Oct 2023
21380
- uint32_t t = 0;
21381
- ok = ok && gguf_fread_el(file, &t, sizeof(t), &offset);
21382
- info->ne[j] = t;
21383
- } else {
21384
- ok = ok && gguf_fread_el(file, &info->ne[j], sizeof(info->ne[j]), &offset);
21385
- }
18991
+ ok = ok && gguf_fread_el(file, &info->ne[j], sizeof(info->ne[j]), &offset);
21386
18992
  }
21387
18993
  ok = ok && gguf_fread_el (file, &info->type, sizeof(info->type), &offset);
21388
18994
  ok = ok && gguf_fread_el (file, &info->offset, sizeof(info->offset), &offset);
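
Tensor shapes get the same treatment: each dimension is read directly as a 64-bit value with no v1 fallback. A condensed sketch of the fields this loop fills in; the struct below is an illustration, not ggml's gguf_tensor_info definition:

    #include <stdint.h>
    #include <stdio.h>

    #define MAX_DIMS 4 // assumed cap, mirroring ggml's fixed dimension count

    struct tensor_info_sketch {
        uint32_t n_dims;
        uint64_t ne[MAX_DIMS]; // 64-bit per dimension in v2+
        int32_t  type;
        uint64_t offset;       // offset of the tensor data in the file
    };

    static int read_tensor_dims(FILE * file, struct tensor_info_sketch * info) {
        if (fread(&info->n_dims, sizeof(info->n_dims), 1, file) != 1) return -1;
        if (info->n_dims > MAX_DIMS) return -1;
        for (uint32_t j = 0; j < info->n_dims; ++j) {
            if (fread(&info->ne[j], sizeof(info->ne[j]), 1, file) != 1) return -1;
        }
        if (fread(&info->type,   sizeof(info->type),   1, file) != 1) return -1;
        if (fread(&info->offset, sizeof(info->offset), 1, file) != 1) return -1;
        return 0;
    }
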