llama_cpp 0.0.5 → 0.0.7

@@ -19,6 +19,7 @@
19
19
  #include <inttypes.h>
20
20
  #include <stdio.h>
21
21
  #include <float.h>
22
+ #include <limits.h>
22
23
 
23
24
  // if C99 - static_assert is noop
24
25
  // ref: https://stackoverflow.com/a/53923785/4039976
@@ -142,10 +143,14 @@ inline static void* ggml_aligned_malloc(size_t size) {
142
143
  } \
143
144
  } while (0)
144
145
 
145
- #ifdef GGML_USE_ACCELERATE
146
+ #if defined(GGML_USE_ACCELERATE)
146
147
  #include <Accelerate/Accelerate.h>
147
- #elif GGML_USE_OPENBLAS
148
+ #elif defined(GGML_USE_OPENBLAS)
148
149
  #include <cblas.h>
150
+ #elif defined(GGML_USE_CUBLAS)
151
+ #include "ggml-cuda.h"
152
+ #elif defined(GGML_USE_CLBLAST)
153
+ #include "ggml-opencl.h"
149
154
  #endif
150
155
 
151
156
  #undef MIN
@@ -325,6 +330,20 @@ static ggml_fp16_t table_exp_f16[1 << 16];
325
330
  // precomputed f32 table for f16 (256 KB)
326
331
  static float table_f32_f16[1 << 16];
327
332
 
333
+ #if defined(__ARM_NEON)
334
+ #define B1(c,s,n) 0x ## n ## c , 0x ## n ## s
335
+ #define B2(c,s,n) B1(c,s,n ## c), B1(c,s,n ## s)
336
+ #define B3(c,s,n) B2(c,s,n ## c), B2(c,s,n ## s)
337
+ #define B4(c,s,n) B3(c,s,n ## c), B3(c,s,n ## s)
338
+ #define B5(c,s,n) B4(c,s,n ## c), B4(c,s,n ## s)
339
+ #define B6(c,s,n) B5(c,s,n ## c), B5(c,s,n ## s)
340
+ #define B7(c,s,n) B6(c,s,n ## c), B6(c,s,n ## s)
341
+ #define B8(c,s ) B7(c,s, c), B7(c,s, s)
342
+
343
+ // precomputed tables for expanding 8bits to 8 bytes (shl 4)
344
+ static const uint64_t table_b2b_u[1 << 8] = { B8(00, 10) };
345
+ #endif
346
+
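The B1..B8 macro ladder above precomputes, for every 8-bit index, a 64-bit word in which bit j of the index is expanded into byte j, with set bits becoming 0x10 (hence "shl 4"). A minimal scalar sketch of what one table_b2b_u entry encodes, assuming that bit-to-byte ordering (the helper name below is illustrative, not from the source):

    #include <stdint.h>
    #include <stdio.h>

    // Scalar model of one table_b2b_u entry: bit j of `bits` selects byte j
    // of the result; a set bit yields 0x10 (1 << 4), a clear bit yields 0x00.
    static uint64_t b2b_shl4(uint8_t bits) {
        uint64_t out = 0;
        for (int j = 0; j < 8; ++j) {
            if (bits & (1u << j)) {
                out |= (uint64_t)0x10 << (8 * j);
            }
        }
        return out;
    }

    int main(void) {
        // bits 0, 5 and 7 set -> bytes 0, 5 and 7 of the result become 0x10
        printf("%016llx\n", (unsigned long long) b2b_shl4(0xA1));
        return 0;
    }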
328
347
  // On ARM NEON, it's quicker to directly convert x -> x instead of calling into ggml_lookup_fp16_to_fp32,
329
348
  // so we define GGML_FP16_TO_FP32 and GGML_FP32_TO_FP16 elsewhere for NEON.
330
349
  // This is also true for POWER9.
@@ -427,12 +446,69 @@ static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
427
446
  // quantization
428
447
  //
429
448
 
430
- // AVX routines provided by GH user Const-me
431
- // ref: https://github.com/ggerganov/ggml/pull/27#issuecomment-1464934600
449
+ #if __AVX__ || __AVX2__ || __AVX512F__
450
+ // Unpack 16 4-bit fields into 16 bytes
451
+ // The output vector contains 16 bytes, each one in [ 0 .. 15 ] interval
452
+ static inline __m128i bytes_from_nibbles_16(const uint8_t * rsi)
453
+ {
454
+ // Load 8 bytes from memory
455
+ __m128i tmp = _mm_loadl_epi64( ( const __m128i* )rsi );
456
+
457
+ // Expand bytes into uint16_t values
458
+ __m128i bytes = _mm_cvtepu8_epi16( tmp );
459
+
460
+ // Unpack values into individual bytes
461
+ const __m128i lowMask = _mm_set1_epi8( 0xF );
462
+ __m128i high = _mm_andnot_si128( lowMask, bytes );
463
+ __m128i low = _mm_and_si128( lowMask, bytes );
464
+ high = _mm_slli_epi16( high, 4 );
465
+ bytes = _mm_or_si128( low, high );
466
+ return bytes;
467
+ }
468
+
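As a reference for the SSE sequence above, a scalar model of bytes_from_nibbles_16 (the _scalar helper name is illustrative): 8 packed bytes go in, 16 bytes come out, with the low nibble of input byte j landing in output byte 2j and the high nibble in output byte 2j+1, matching how the quantized nibbles are stored.

    #include <stdint.h>

    // Scalar model: each output byte holds one 4-bit field in [0, 15].
    static void bytes_from_nibbles_16_scalar(const uint8_t * in, uint8_t out[16]) {
        for (int j = 0; j < 8; ++j) {
            out[2*j + 0] = in[j] & 0x0F; // low nibble  -> even output byte
            out[2*j + 1] = in[j] >> 4;   // high nibble -> odd output byte
        }
    }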
469
+ // horizontally add 8 floats
470
+ static inline float hsum_float_8(const __m256 x) {
471
+ __m128 res = _mm256_extractf128_ps(x, 1);
472
+ res = _mm_add_ps(res, _mm256_castps256_ps128(x));
473
+ res = _mm_add_ps(res, _mm_movehl_ps(res, res));
474
+ res = _mm_add_ss(res, _mm_movehdup_ps(res));
475
+ return _mm_cvtss_f32(res);
476
+ }
477
+
478
+ // horizontally add 8 int32_t
479
+ static inline int hsum_i32_8(const __m256i a) {
480
+ const __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(a), _mm256_extractf128_si256(a, 1));
481
+ const __m128i hi64 = _mm_unpackhi_epi64(sum128, sum128);
482
+ const __m128i sum64 = _mm_add_epi32(hi64, sum128);
483
+ const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
484
+ return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
485
+ }
486
+
487
+ // horizontally add 4 int32_t
488
+ static inline int hsum_i32_4(const __m128i a) {
489
+ const __m128i hi64 = _mm_unpackhi_epi64(a, a);
490
+ const __m128i sum64 = _mm_add_epi32(hi64, a);
491
+ const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
492
+ return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
493
+ }
494
+
432
495
  #if __AVX2__ || __AVX512F__
496
+ // spread 32 bits to 32 bytes { 0x00, 0xFF }
497
+ static inline __m256i bytes_from_bits_32(const uint8_t * x) {
498
+ uint32_t x32;
499
+ memcpy(&x32, x, sizeof(uint32_t));
500
+ const __m256i shuf_mask = _mm256_set_epi64x(
501
+ 0x0303030303030303, 0x0202020202020202,
502
+ 0x0101010101010101, 0x0000000000000000);
503
+ __m256i bytes = _mm256_shuffle_epi8(_mm256_set1_epi32(x32), shuf_mask);
504
+ const __m256i bit_mask = _mm256_set1_epi64x(0x7fbfdfeff7fbfdfe);
505
+ bytes = _mm256_or_si256(bytes, bit_mask);
506
+ return _mm256_cmpeq_epi8(bytes, _mm256_set1_epi64x(-1));
507
+ }
508
+
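The shuffle/or/cmpeq sequence above is equivalent to the following scalar loop, which may be easier to follow (the _scalar helper name is illustrative): byte i of the 32-byte result is 0xFF when bit i of the 32-bit input is set, and 0x00 otherwise.

    #include <stdint.h>
    #include <string.h>

    // Scalar model of bytes_from_bits_32.
    static void bytes_from_bits_32_scalar(const uint8_t * x, uint8_t out[32]) {
        uint32_t x32;
        memcpy(&x32, x, sizeof(uint32_t));
        for (int i = 0; i < 32; ++i) {
            out[i] = ((x32 >> i) & 1) ? 0xFF : 0x00;
        }
    }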
433
509
  // Unpack 32 4-bit fields into 32 bytes
434
510
  // The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
435
- static inline __m256i bytesFromNibbles( const uint8_t* rsi )
511
+ static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)
436
512
  {
437
513
  // Load 16 bytes from memory
438
514
  __m128i tmp = _mm_loadu_si128( ( const __m128i* )rsi );
@@ -449,9 +525,38 @@ static inline __m256i bytesFromNibbles( const uint8_t* rsi )
449
525
  return bytes;
450
526
  }
451
527
 
528
+ // add int16_t pairwise and return as float vector
529
+ static inline __m256 sum_i16_pairs_float(const __m256i x) {
530
+ const __m256i ones = _mm256_set1_epi16(1);
531
+ const __m256i summed_pairs = _mm256_madd_epi16(ones, x);
532
+ return _mm256_cvtepi32_ps(summed_pairs);
533
+ }
534
+
535
+ // multiply int8_t, add results pairwise twice and return as float vector
536
+ static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
537
+ // Get absolute values of x vectors
538
+ const __m256i ax = _mm256_sign_epi8(x, x);
539
+ // Sign the values of the y vectors
540
+ const __m256i sy = _mm256_sign_epi8(y, x);
541
+ #if __AVXVNNI__
542
+ const __m256i zero = _mm256_setzero_si256();
543
+ const __m256i summed_pairs = _mm256_dpbusd_epi32(zero, ax, sy);
544
+ return _mm256_cvtepi32_ps(summed_pairs);
545
+ #else
546
+ // Perform multiplication and create 16-bit values
547
+ const __m256i dot = _mm256_maddubs_epi16(ax, sy);
548
+ return sum_i16_pairs_float(dot);
549
+ #endif
550
+ }
551
+
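mul_sum_i8_pairs_float relies on the fact that the unsigned-by-signed multiply (VPMADDUBSW / VPDPBUSD) still produces the signed product once x's sign is moved onto y. A scalar sketch of that identity for the per-pair sums, before the final conversion to float and ignoring the saturation corner cases of the actual intrinsics (the helper name is illustrative):

    #include <stdint.h>

    // out[k] = x[2k]*y[2k] + x[2k+1]*y[2k+1], computed as |x| * (sign(x)*y).
    static void mul_sum_i8_pairs_scalar(const int8_t x[32], const int8_t y[32], int32_t out[16]) {
        for (int k = 0; k < 16; ++k) {
            int32_t acc = 0;
            for (int j = 0; j < 2; ++j) {
                const int32_t xv = x[2*k + j];
                const int32_t yv = y[2*k + j];
                const int32_t ax = xv < 0 ? -xv : xv;             // abs(x), always non-negative
                const int32_t sy = xv < 0 ? -yv : (xv ? yv : 0);  // y with x's sign applied
                acc += ax * sy;                                   // equals xv * yv
            }
            out[k] = acc;
        }
    }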
452
552
  static inline __m128i packNibbles( __m256i bytes )
453
553
  {
454
554
  // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
555
+ #if __AVX512F__
556
+ const __m256i bytes_srli_4 = _mm256_srli_epi16(bytes, 4); // 0000_0000_abcd_0000
557
+ bytes = _mm256_or_si256(bytes, bytes_srli_4); // 0000_abcd_abcd_efgh
558
+ return _mm256_cvtepi16_epi8(bytes); // abcd_efgh
559
+ #else
455
560
  const __m256i lowByte = _mm256_set1_epi16( 0xFF );
456
561
  __m256i high = _mm256_andnot_si256( lowByte, bytes );
457
562
  __m256i low = _mm256_and_si256( lowByte, bytes );
@@ -462,25 +567,9 @@ static inline __m128i packNibbles( __m256i bytes )
462
567
  __m128i r0 = _mm256_castsi256_si128( bytes );
463
568
  __m128i r1 = _mm256_extracti128_si256( bytes, 1 );
464
569
  return _mm_packus_epi16( r0, r1 );
570
+ #endif
465
571
  }
466
- #elif __AVX__
467
- static inline __m128i bytesFromNibbles( const uint8_t* rsi )
468
- {
469
- // Load 8 bytes from memory
470
- __m128i tmp = _mm_loadu_si64( ( const __m128i* )rsi );
471
-
472
- // Expand bytes into uint16_t values
473
- __m128i bytes = _mm_cvtepu8_epi16( tmp );
474
-
475
- // Unpack values into individual bytes
476
- const __m128i lowMask = _mm_set1_epi8( 0xF );
477
- __m128i high = _mm_andnot_si128( lowMask, bytes );
478
- __m128i low = _mm_and_si128( lowMask, bytes );
479
- high = _mm_slli_epi16( high, 4 );
480
- bytes = _mm_or_si128( low, high );
481
- return bytes;
482
- }
483
-
572
+ #else
484
573
  static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
485
574
  {
486
575
  // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
@@ -497,6 +586,7 @@ static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
497
586
  return _mm_packus_epi16( bytes1, bytes2);
498
587
  }
499
588
  #endif
589
+ #endif // __AVX__ || __AVX2__ || __AVX512F__
500
590
 
501
591
  #if __ARM_NEON
502
592
 
@@ -514,6 +604,18 @@ inline static uint16_t vaddvq_u8(uint8x16_t v) {
514
604
  (uint16_t)vgetq_lane_u8(v, 14) + (uint16_t)vgetq_lane_u8(v, 15);
515
605
  }
516
606
 
607
+ inline static int16_t vaddvq_s8(int8x16_t v) {
608
+ return
609
+ (int16_t)vgetq_lane_s8(v, 0) + (int16_t)vgetq_lane_s8(v, 1) +
610
+ (int16_t)vgetq_lane_s8(v, 2) + (int16_t)vgetq_lane_s8(v, 3) +
611
+ (int16_t)vgetq_lane_s8(v, 4) + (int16_t)vgetq_lane_s8(v, 5) +
612
+ (int16_t)vgetq_lane_s8(v, 6) + (int16_t)vgetq_lane_s8(v, 7) +
613
+ (int16_t)vgetq_lane_s8(v, 8) + (int16_t)vgetq_lane_s8(v, 9) +
614
+ (int16_t)vgetq_lane_s8(v, 10) + (int16_t)vgetq_lane_s8(v, 11) +
615
+ (int16_t)vgetq_lane_s8(v, 12) + (int16_t)vgetq_lane_s8(v, 13) +
616
+ (int16_t)vgetq_lane_s8(v, 14) + (int16_t)vgetq_lane_s8(v, 15);
617
+ }
618
+
517
619
  inline static int32_t vaddvq_s16(int16x8_t v) {
518
620
  return
519
621
  (int32_t)vgetq_lane_s16(v, 0) + (int32_t)vgetq_lane_s16(v, 1) +
@@ -583,7 +685,39 @@ typedef struct {
583
685
  float m; // min
584
686
  uint8_t qs[QK4_1 / 2]; // nibbles / quants
585
687
  } block_q4_1;
586
- static_assert(sizeof(block_q4_1) == sizeof(float) * 2 + QK4_1 / 2, "wrong q4_1 block size/padding");
688
+ static_assert(sizeof(block_q4_1) == 2 * sizeof(float) + QK4_1 / 2, "wrong q4_1 block size/padding");
689
+
690
+ #define QK4_2 16
691
+ typedef struct {
692
+ ggml_fp16_t d; // delta
693
+ uint8_t qs[QK4_2 / 2]; // nibbles / quants
694
+ } block_q4_2;
695
+ static_assert(sizeof(block_q4_2) == sizeof(ggml_fp16_t) + QK4_2 / 2, "wrong q4_2 block size/padding");
696
+
697
+ #define QK4_3 16
698
+ typedef struct {
699
+ ggml_fp16_t d; // delta
700
+ ggml_fp16_t m; // min
701
+ uint8_t qs[QK4_3 / 2]; // nibbles / quants
702
+ } block_q4_3;
703
+ static_assert(sizeof(block_q4_3) == 2 * sizeof(ggml_fp16_t) + QK4_3 / 2, "wrong q4_3 block size/padding");
704
+
705
+ #define QK5_0 32
706
+ typedef struct {
707
+ ggml_fp16_t d; // delta
708
+ uint8_t qh[4]; // 5-th bit of quants
709
+ uint8_t qs[QK5_0 / 2]; // nibbles / quants
710
+ } block_q5_0;
711
+ static_assert(sizeof(block_q5_0) == sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_0 / 2, "wrong q5_0 block size/padding");
712
+
713
+ #define QK5_1 32
714
+ typedef struct {
715
+ ggml_fp16_t d; // delta
716
+ ggml_fp16_t m; // min
717
+ uint8_t qh[4]; // 5-th bit of quants
718
+ uint8_t qs[QK5_1 / 2]; // nibbles / quants
719
+ } block_q5_1;
720
+ static_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_1 / 2, "wrong q5_1 block size/padding");
587
721
 
588
722
  #define QK8_0 32
589
723
  typedef struct {
@@ -592,6 +726,14 @@ typedef struct {
592
726
  } block_q8_0;
593
727
  static_assert(sizeof(block_q8_0) == sizeof(float) + QK8_0, "wrong q8_0 block size/padding");
594
728
 
729
+ #define QK8_1 32
730
+ typedef struct {
731
+ float d; // delta
732
+ float s0; // d * sum(qs[i]) low
733
+ float s1; // d * sum(qs[i]) high
734
+ int8_t qs[QK8_1]; // quants
735
+ } block_q8_1;
736
+ static_assert(sizeof(block_q8_1) == 3*sizeof(float) + QK8_1, "wrong q8_1 block size/padding");
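The s0/s1 fields precompute d times the sum of the low and high half of the quants, presumably so that dot products against the offset formats (q4_1, q4_3, q5_1) can fold the offset term in cheaply. A short derivation for a q4_1 block (value d4*q + m4) against a q8_1 block (value d8*q8):

    sum_i (d4*q[i] + m4) * (d8 * q8[i])
        = d4*d8 * sum_i q[i]*q8[i]  +  m4 * (d8 * sum_i q8[i])
        = d4*d8 * (integer dot product)  +  m4 * (s0 + s1)

so the per-block float work reduces to scaling one integer dot product and adding the stored sums.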
595
737
 
596
738
  // reference implementation for deterministic creation of model files
597
739
  static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k) {
@@ -602,13 +744,17 @@ static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * r
602
744
 
603
745
  for (int i = 0; i < nb; i++) {
604
746
  float amax = 0.0f; // absolute max
747
+ float max = 0.0f;
605
748
 
606
749
  for (int l = 0; l < QK4_0; l++) {
607
750
  const float v = x[i*QK4_0 + l];
608
- amax = MAX(amax, fabsf(v));
751
+ if (amax < fabsf(v)) {
752
+ amax = fabsf(v);
753
+ max = v;
754
+ }
609
755
  }
610
756
 
611
- const float d = amax / ((1 << 3) - 1);
757
+ const float d = max / -8;
612
758
  const float id = d ? 1.0f/d : 0.0f;
613
759
 
614
760
  y[i].d = d;
@@ -617,8 +763,8 @@ static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * r
617
763
  const float v0 = x[i*QK4_0 + l + 0]*id;
618
764
  const float v1 = x[i*QK4_0 + l + 1]*id;
619
765
 
620
- const uint8_t vi0 = (int8_t)roundf(v0) + 8;
621
- const uint8_t vi1 = (int8_t)roundf(v1) + 8;
766
+ const uint8_t vi0 = MIN(15, (int8_t)roundf(v0) + 8);
767
+ const uint8_t vi1 = MIN(15, (int8_t)roundf(v1) + 8);
622
768
 
623
769
  assert(vi0 < 16);
624
770
  assert(vi1 < 16);
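The switch from d = amax/7 to d = max/-8 uses the full signed nibble range. A short worked example of the new mapping (quant = round(v/d) + 8, clamped to 15) for a block whose largest-magnitude value is +2.0, so d = 2.0 / -8 = -0.25:

    v = +2.0  ->  round(+2.0 / -0.25) + 8 = -8 + 8 =  0
    v =  0.0  ->  round( 0.0 / -0.25) + 8 =  0 + 8 =  8
    v = -2.0  ->  round(-2.0 / -0.25) + 8 = +8 + 8 = 16  ->  MIN(15, 16) = 15

The negative scale maps the dominant endpoint to quant 0, and the MIN(15, ...) clamp catches the opposite-sign endpoint that would otherwise overflow the nibble.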
@@ -638,28 +784,42 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
638
784
 
639
785
  #if defined(__POWER9_VECTOR__)
640
786
  const vector float v85 = vec_splats(8.5f);
787
+ const vector signed int v15 = vec_splats(15);
641
788
  for (int i = 0; i < nb; i++) {
642
- float amax = 0.0f; // absolute max
789
+ float max = 0.0f;
790
+ float min = 0.0f;
643
791
 
644
792
  vector float srcv [8];
645
- vector float asrcv[8];
646
- vector float amaxv[8];
793
+ vector float maxv[8];
794
+ vector float minv[8];
647
795
 
648
796
  for (int l = 0; l < 8; l++) srcv[l] = *(vector float *)(x + i*32 + 4*l);
649
- for (int l = 0; l < 8; l++) asrcv[l] = vec_abs(srcv[l]);
650
-
651
- for (int l = 0; l < 4; l++) amaxv[2*l] = vec_max(asrcv[2*l], asrcv[2*l+1]);
652
- //for (int l = 0; l < 2; l++) amaxv[4*l] = vec_max(amaxv[4*l], amaxv[4*l+2]);
653
- amaxv[0] = vec_max(amaxv[0], amaxv[2]);
654
- amaxv[4] = vec_max(amaxv[4], amaxv[6]);
655
- //for (int l = 0; l < 1; l++) amaxv[8*l] = vec_max(amaxv[8*l], amaxv[8*l+4]);
656
- amaxv[0] = vec_max(amaxv[0], amaxv[4]);
657
-
658
- amax = MAX(
659
- MAX(vec_extract(amaxv[0], 0), vec_extract(amaxv[0], 1)),
660
- MAX(vec_extract(amaxv[0], 2), vec_extract(amaxv[0], 3)));
661
-
662
- const float d = amax / ((1 << 3) - 1);
797
+ //for (int l = 0; l < 8; l++) asrcv[l] = vec_abs(srcv[l]);
798
+
799
+ for (int l = 0; l < 4; l++) maxv[2*l] = vec_max(asrcv[2*l], asrcv[2*l+1]);
800
+ //for (int l = 0; l < 2; l++) maxv[4*l] = vec_max(maxv[4*l], maxv[4*l+2]);
801
+ maxv[0] = vec_max(maxv[0], maxv[2]);
802
+ maxv[4] = vec_max(maxv[4], maxv[6]);
803
+ //for (int l = 0; l < 1; l++) maxv[8*l] = vec_max(maxv[8*l], maxv[8*l+4]);
804
+ maxv[0] = vec_max(maxv[0], maxv[4]);
805
+
806
+ for (int l = 0; l < 4; l++) minv[2*l] = vec_min(asrcv[2*l], asrcv[2*l+1]);
807
+ //for (int l = 0; l < 2; l++) minv[4*l] = vec_min(minv[4*l], minv[4*l+2]);
808
+ minv[0] = vec_min(minv[0], minv[2]);
809
+ minv[4] = vec_min(minv[4], minv[6]);
810
+ //for (int l = 0; l < 1; l++) minv[8*l] = vec_min(minv[8*l], minv[8*l+4]);
811
+ minv[0] = vec_min(minv[0], minv[4]);
812
+
813
+
814
+ max = MAX(
815
+ MAX(vec_extract(maxv[0], 0), vec_extract(maxv[0], 1)),
816
+ MAX(vec_extract(maxv[0], 2), vec_extract(maxv[0], 3)));
817
+ min = MIN(
818
+ MIN(vec_extract(minv[0], 0), vec_extract(minv[0], 1)),
819
+ MIN(vec_extract(minv[0], 2), vec_extract(minv[0], 3)));
820
+
821
+ const float magnitude = max >= fabsf(min) ? max : min;
822
+ const float d = magnitude / -8;
663
823
  const float id = d ? 1.0/d : 0.0;
664
824
 
665
825
  y[i].d = d;
@@ -669,27 +829,33 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
669
829
  for (int l = 0; l < 8; l++) {
670
830
  const vector float vf = vec_madd(srcv[l], vid, v85);
671
831
  const vector signed int vi = vec_signed(vf);
832
+ const vector signed int vc = vec_min(vi, v15);
672
833
 
673
- pb[2*l + 0] = vec_extract(vi, 0) | (vec_extract(vi, 1) << 4);
674
- pb[2*l + 1] = vec_extract(vi, 2) | (vec_extract(vi, 3) << 4);
834
+ pb[2*l + 0] = vec_extract(vc, 0) | (vec_extract(vc, 1) << 4);
835
+ pb[2*l + 1] = vec_extract(vc, 2) | (vec_extract(vc, 3) << 4);
675
836
  }
676
837
  }
677
838
  #elif __ARM_NEON
678
839
  for (int i = 0; i < nb; i++) {
679
840
  float32x4_t srcv [8];
680
- float32x4_t asrcv[8];
681
- float32x4_t amaxv[8];
841
+ float32x4_t maxv[8];
842
+ float32x4_t minv[8];
682
843
 
683
844
  for (int l = 0; l < 8; l++) srcv[l] = vld1q_f32(x + i*32 + 4*l);
684
- for (int l = 0; l < 8; l++) asrcv[l] = vabsq_f32(srcv[l]);
685
845
 
686
- for (int l = 0; l < 4; l++) amaxv[2*l] = vmaxq_f32(asrcv[2*l], asrcv[2*l+1]);
687
- for (int l = 0; l < 2; l++) amaxv[4*l] = vmaxq_f32(amaxv[4*l], amaxv[4*l+2]);
688
- for (int l = 0; l < 1; l++) amaxv[8*l] = vmaxq_f32(amaxv[8*l], amaxv[8*l+4]);
846
+ for (int l = 0; l < 4; l++) maxv[2*l] = vmaxq_f32(srcv[2*l], srcv[2*l+1]);
847
+ for (int l = 0; l < 2; l++) maxv[4*l] = vmaxq_f32(maxv[4*l], maxv[4*l+2]);
848
+ for (int l = 0; l < 1; l++) maxv[8*l] = vmaxq_f32(maxv[8*l], maxv[8*l+4]);
689
849
 
690
- const float amax = vmaxvq_f32(amaxv[0]);
850
+ for (int l = 0; l < 4; l++) minv[2*l] = vminq_f32(srcv[2*l], srcv[2*l+1]);
851
+ for (int l = 0; l < 2; l++) minv[4*l] = vminq_f32(minv[4*l], minv[4*l+2]);
852
+ for (int l = 0; l < 1; l++) minv[8*l] = vminq_f32(minv[8*l], minv[8*l+4]);
853
+
854
+ const float max = vmaxvq_f32(maxv[0]);
855
+ const float min = vminvq_f32(minv[0]);
691
856
 
692
- const float d = amax / ((1 << 3) - 1);
857
+ const float magnitude = max >= fabsf(min) ? max : min;
858
+ const float d = magnitude / -8;
693
859
  const float id = d ? 1.0f/d : 0.0f;
694
860
 
695
861
  y[i].d = d;
@@ -698,9 +864,10 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
698
864
  const float32x4_t v = vmulq_n_f32(srcv[l], id);
699
865
  const float32x4_t vf = vaddq_f32(v, vdupq_n_f32(8.5f));
700
866
  const int32x4_t vi = vcvtq_s32_f32(vf);
867
+ const int32x4_t vc = vminq_s32(vi, vdupq_n_s32(15));
701
868
 
702
- y[i].qs[2*l + 0] = vgetq_lane_s32(vi, 0) | (vgetq_lane_s32(vi, 1) << 4);
703
- y[i].qs[2*l + 1] = vgetq_lane_s32(vi, 2) | (vgetq_lane_s32(vi, 3) << 4);
869
+ y[i].qs[2*l + 0] = vgetq_lane_s32(vc, 0) | (vgetq_lane_s32(vc, 1) << 4);
870
+ y[i].qs[2*l + 1] = vgetq_lane_s32(vc, 2) | (vgetq_lane_s32(vc, 3) << 4);
704
871
  }
705
872
  }
706
873
  #elif defined(__AVX2__)
@@ -712,22 +879,31 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
712
879
  __m256 v3 = _mm256_loadu_ps( x + 24 );
713
880
  x += 32;
714
881
 
715
- // Compute max(abs(e)) for the block
716
- const __m256 signBit = _mm256_set1_ps( -0.0f );
717
- __m256 maxAbs = _mm256_andnot_ps( signBit, v0 );
718
- maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) );
719
- maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );
720
- maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) );
882
+ // Compute max for the block
883
+ __m256 max = _mm256_max_ps( v0, v1 );
884
+ __m256 maxTmp = _mm256_max_ps( v2, v3 );
885
+ max = _mm256_max_ps( max, maxTmp );
721
886
 
722
- __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
887
+ __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( max, 1 ), _mm256_castps256_ps128( max ) );
723
888
  max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
724
889
  max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
725
890
  const float maxScalar = _mm_cvtss_f32( max4 );
726
891
 
892
+ // Compute min for the block
893
+ __m256 min = _mm256_min_ps( v0, v1 );
894
+ __m256 minTmp = _mm256_min_ps( v2, v3 );
895
+ min = _mm256_min_ps( min, minTmp );
896
+
897
+ __m128 min4 = _mm_min_ps( _mm256_extractf128_ps( min, 1 ), _mm256_castps256_ps128( min ) );
898
+ min4 = _mm_min_ps( min4, _mm_movehl_ps( min4, min4 ) );
899
+ min4 = _mm_min_ss( min4, _mm_movehdup_ps( min4 ) );
900
+ const float minScalar = _mm_cvtss_f32( min4 );
901
+
727
902
  // Quantize these floats
728
- const float d = maxScalar / 7.0f;
903
+ const float magnitude = maxScalar >= fabsf(minScalar) ? maxScalar : minScalar;
904
+ const float d = magnitude / -8.0f;
729
905
  y[i].d = d;
730
- const float id = ( maxScalar != 0.0f ) ? 7.0f / maxScalar : 0.0f;
906
+ const float id = ( magnitude != 0.0f ) ? -8.0f / magnitude : 0.0f;
731
907
  const __m256 mul = _mm256_set1_ps( id );
732
908
 
733
909
  // Apply the multiplier
@@ -760,9 +936,11 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
760
936
  const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
761
937
  i0 = _mm256_permutevar8x32_epi32( i0, perm );
762
938
 
763
- // Apply offset to translate the range from [ -7 .. +7 ] into [ +1 .. +15 ]
939
+ // Apply offset and clamp to translate the range from [ -8 .. +8 ] into [ +0 .. +15 ]
764
940
  const __m256i off = _mm256_set1_epi8( 8 );
765
941
  i0 = _mm256_add_epi8( i0, off );
942
+ const __m256i maxNibble = _mm256_set1_epi8( 15 );
943
+ i0 = _mm256_min_epi8( i0, maxNibble );
766
944
 
767
945
  // Compress the vector into 4 bit/value, and store
768
946
  __m128i res = packNibbles( i0 );
@@ -777,22 +955,31 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
777
955
  __m256 v3 = _mm256_loadu_ps( x + 24 );
778
956
  x += 32;
779
957
 
780
- // Compute max(abs(e)) for the block
781
- const __m256 signBit = _mm256_set1_ps( -0.0f );
782
- __m256 maxAbs = _mm256_andnot_ps( signBit, v0 );
783
- maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) );
784
- maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );
785
- maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) );
958
+ // Compute max for the block
959
+ __m256 max = _mm256_max_ps( v0, v1 );
960
+ __m256 maxTmp = _mm256_max_ps( v2, v3 );
961
+ max = _mm256_max_ps( max, maxTmp );
786
962
 
787
- __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
963
+ __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( max, 1 ), _mm256_castps256_ps128( max ) );
788
964
  max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
789
965
  max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
790
966
  const float maxScalar = _mm_cvtss_f32( max4 );
791
967
 
968
+ // Compute min for the block
969
+ __m256 min = _mm256_min_ps( v0, v1 );
970
+ __m256 minTmp = _mm256_min_ps( v2, v3 );
971
+ min = _mm256_min_ps( min, minTmp );
972
+
973
+ __m128 min4 = _mm_min_ps( _mm256_extractf128_ps( min, 1 ), _mm256_castps256_ps128( min ) );
974
+ min4 = _mm_min_ps( min4, _mm_movehl_ps( min4, min4 ) );
975
+ min4 = _mm_min_ss( min4, _mm_movehdup_ps( min4 ) );
976
+ const float minScalar = _mm_cvtss_f32( min4 );
977
+
792
978
  // Quantize these floats
793
- const float d = maxScalar / 7.0f;
979
+ const float magnitude = maxScalar >= fabsf(minScalar) ? maxScalar : minScalar;
980
+ const float d = magnitude / -8.0f;
794
981
  y[i].d = d;
795
- const float id = ( maxScalar != 0.0f ) ? 7.0f / maxScalar : 0.0f;
982
+ const float id = ( magnitude != 0.0f ) ? -8.0f / magnitude : 0.0f;
796
983
  const __m256 mul = _mm256_set1_ps( id );
797
984
 
798
985
  // Apply the multiplier
@@ -833,10 +1020,13 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
833
1020
  ni0 = _mm_packs_epi16( ni0, ni2 );
834
1021
  ni4 = _mm_packs_epi16( ni4, ni6 );
835
1022
 
836
- // Apply offset to translate the range from [ -7 .. +7 ] into [ +1 .. +15 ]
837
- const __m128i off = _mm_set1_epi8( 8);
1023
+ // Apply offset and clamp to translate the range from [ -8 .. +8 ] into [ +0 .. +15 ]
1024
+ const __m128i off = _mm_set1_epi8( 8 );
838
1025
  ni0 = _mm_add_epi8( ni0, off );
839
1026
  ni4 = _mm_add_epi8( ni4, off );
1027
+ const __m128i maxNibble = _mm_set1_epi8( 15 );
1028
+ ni0 = _mm_min_epi8( ni0, maxNibble );
1029
+ ni4 = _mm_min_epi8( ni4, maxNibble );
840
1030
 
841
1031
  // Compress the vector into 4 bit/value, and store
842
1032
  __m128i res = packNibbles( ni0, ni4 );
@@ -844,24 +1034,32 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
844
1034
  }
845
1035
  #elif defined(__wasm_simd128__)
846
1036
  for (int i = 0; i < nb; i++) {
847
- float amax = 0.0f; // absolute max
1037
+ float max = 0.0f;
1038
+ float min = 0.0f;
848
1039
 
849
1040
  v128_t srcv [8];
850
- v128_t asrcv[8];
851
- v128_t amaxv[8];
1041
+ v128_t maxv[8];
1042
+ v128_t minv[8];
852
1043
 
853
1044
  for (int l = 0; l < 8; l++) srcv[l] = wasm_v128_load(x + i*32 + 4*l);
854
- for (int l = 0; l < 8; l++) asrcv[l] = wasm_f32x4_abs(srcv[l]);
855
1045
 
856
- for (int l = 0; l < 4; l++) amaxv[2*l] = wasm_f32x4_max(asrcv[2*l], asrcv[2*l+1]);
857
- for (int l = 0; l < 2; l++) amaxv[4*l] = wasm_f32x4_max(amaxv[4*l], amaxv[4*l+2]);
858
- for (int l = 0; l < 1; l++) amaxv[8*l] = wasm_f32x4_max(amaxv[8*l], amaxv[8*l+4]);
1046
+ for (int l = 0; l < 4; l++) maxv[2*l] = wasm_f32x4_max(srcv[2*l], srcv[2*l+1]);
1047
+ for (int l = 0; l < 2; l++) maxv[4*l] = wasm_f32x4_max(maxv[4*l], maxv[4*l+2]);
1048
+ for (int l = 0; l < 1; l++) maxv[8*l] = wasm_f32x4_max(maxv[8*l], maxv[8*l+4]);
1049
+
1050
+ for (int l = 0; l < 4; l++) minv[2*l] = wasm_f32x4_min(srcv[2*l], srcv[2*l+1]);
1051
+ for (int l = 0; l < 2; l++) minv[4*l] = wasm_f32x4_min(minv[4*l], minv[4*l+2]);
1052
+ for (int l = 0; l < 1; l++) minv[8*l] = wasm_f32x4_min(minv[8*l], minv[8*l+4]);
859
1053
 
860
- amax = MAX(
861
- MAX(wasm_f32x4_extract_lane(amaxv[0], 0), wasm_f32x4_extract_lane(amaxv[0], 1)),
862
- MAX(wasm_f32x4_extract_lane(amaxv[0], 2), wasm_f32x4_extract_lane(amaxv[0], 3)));
1054
+ max = MAX(
1055
+ MAX(wasm_f32x4_extract_lane(maxv[0], 0), wasm_f32x4_extract_lane(maxv[0], 1)),
1056
+ MAX(wasm_f32x4_extract_lane(maxv[0], 2), wasm_f32x4_extract_lane(maxv[0], 3)));
1057
+ min = MIN(
1058
+ MIN(wasm_f32x4_extract_lane(minv[0], 0), wasm_f32x4_extract_lane(minv[0], 1)),
1059
+ MIN(wasm_f32x4_extract_lane(minv[0], 2), wasm_f32x4_extract_lane(minv[0], 3)));
863
1060
 
864
- const float d = amax / ((1 << 3) - 1);
1061
+ const float magnitude = max >= fabsf(min) ? max : min;
1062
+ const float d = magnitude / -8;
865
1063
  const float id = d ? 1.0/d : 0.0;
866
1064
 
867
1065
  y[i].d = d;
@@ -870,9 +1068,10 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
870
1068
  const v128_t v = wasm_f32x4_mul(srcv[l], wasm_f32x4_splat(id));
871
1069
  const v128_t vf = wasm_f32x4_add(v, wasm_f32x4_splat(8.5f));
872
1070
  const v128_t vi = wasm_i32x4_trunc_sat_f32x4(vf);
1071
+ const v128_t vc = wasm_i32x4_min_u(vi, wasm_i32x4_splat(15));
873
1072
 
874
- y[i].qs[2*l + 0] = wasm_i32x4_extract_lane(vi, 0) | (wasm_i32x4_extract_lane(vi, 1) << 4);
875
- y[i].qs[2*l + 1] = wasm_i32x4_extract_lane(vi, 2) | (wasm_i32x4_extract_lane(vi, 3) << 4);
1073
+ y[i].qs[2*l + 0] = wasm_i32x4_extract_lane(vc, 0) | (wasm_i32x4_extract_lane(vc, 1) << 4);
1074
+ y[i].qs[2*l + 1] = wasm_i32x4_extract_lane(vc, 2) | (wasm_i32x4_extract_lane(vc, 3) << 4);
876
1075
  }
877
1076
  }
878
1077
  #else
@@ -1045,6 +1244,193 @@ static void quantize_row_q4_1(const float * restrict x, void * restrict vy, int
1045
1244
  #endif
1046
1245
  }
1047
1246
 
1247
+ // reference implementation for deterministic creation of model files
1248
+ static void quantize_row_q4_2_reference(const float * restrict x, block_q4_2 * restrict y, int k) {
1249
+ assert(k % QK4_2 == 0);
1250
+
1251
+ const int nb = k / QK4_2;
1252
+
1253
+ for (int i = 0; i < nb; i++) {
1254
+ float amax = 0.0f; // absolute max
1255
+ float max = 0.0f;
1256
+
1257
+ for (int l = 0; l < QK4_2; l++) {
1258
+ const float v = x[i*QK4_2 + l];
1259
+ if (amax < fabsf(v)) {
1260
+ amax = fabsf(v);
1261
+ max = v;
1262
+ }
1263
+ }
1264
+
1265
+ const float d = max / -8;
1266
+
1267
+ const float id = d ? 1.0f/d : 0.0f;
1268
+
1269
+ y[i].d = GGML_FP32_TO_FP16(d);
1270
+
1271
+ for (int l = 0; l < QK4_2; l += 2) {
1272
+ const float v0 = x[i*QK4_2 + l + 0]*id;
1273
+ const float v1 = x[i*QK4_2 + l + 1]*id;
1274
+
1275
+ const uint8_t vi0 = MIN(15, (uint8_t)(v0 + 8.5f));
1276
+ const uint8_t vi1 = MIN(15, (uint8_t)(v1 + 8.5f));
1277
+
1278
+ assert(vi0 < 16);
1279
+ assert(vi1 < 16);
1280
+
1281
+ y[i].qs[l/2] = vi0 | (vi1 << 4);
1282
+ }
1283
+ }
1284
+ }
1285
+
1286
+ static void quantize_row_q4_2(const float * restrict x, void * restrict vy, int k) {
1287
+ assert(k % QK4_2 == 0);
1288
+
1289
+ block_q4_2 * restrict y = vy;
1290
+
1291
+ quantize_row_q4_2_reference(x, y, k);
1292
+ }
1293
+
1294
+ static void quantize_row_q4_3_reference(const float * restrict x, block_q4_3 * restrict y, int k) {
1295
+ assert(k % QK4_3 == 0);
1296
+ const int nb = k / QK4_3;
1297
+
1298
+ for (int i = 0; i < nb; i++) {
1299
+ float min = FLT_MAX;
1300
+ float max = -FLT_MAX;
1301
+
1302
+ for (int l = 0; l < QK4_3; l++) {
1303
+ const float v = x[i*QK4_3 + l];
1304
+ if (v < min) min = v;
1305
+ if (v > max) max = v;
1306
+ }
1307
+
1308
+ const float d = (max - min) / ((1 << 4) - 1);
1309
+ const float id = d ? 1.0f/d : 0.0f;
1310
+
1311
+ y[i].d = GGML_FP32_TO_FP16(d);
1312
+ y[i].m = GGML_FP32_TO_FP16(min);
1313
+
1314
+ for (int l = 0; l < QK4_3; l += 2) {
1315
+ const float v0 = (x[i*QK4_3 + l + 0] - min)*id;
1316
+ const float v1 = (x[i*QK4_3 + l + 1] - min)*id;
1317
+
1318
+ const uint8_t vi0 = (int) (v0 + 0.5f);
1319
+ const uint8_t vi1 = (int) (v1 + 0.5f);
1320
+
1321
+ assert(vi0 < 16);
1322
+ assert(vi1 < 16);
1323
+
1324
+ y[i].qs[l/2] = vi0 | (vi1 << 4);
1325
+ }
1326
+ }
1327
+ }
1328
+
1329
+ static void quantize_row_q4_3(const float * restrict x, void * restrict vy, int k) {
1330
+ assert(k % QK4_3 == 0);
1331
+
1332
+ block_q4_3 * restrict y = vy;
1333
+
1334
+ quantize_row_q4_3_reference(x, y, k);
1335
+ }
1336
+
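Unlike q4_2, q4_3 is an affine (min plus delta) format. A worked example of the mapping q = (v - min)/d rounded, with d = (max - min)/15, for a block with min = -1.0 and max = 2.0 (so d = 0.2):

    v = -1.0  ->  (v - min)/d =  0.0  ->  q =  0  ->  dequant  0*0.2 - 1.0 = -1.0
    v =  0.5  ->  (v - min)/d =  7.5  ->  q =  8  ->  dequant  8*0.2 - 1.0 =  0.6
    v =  2.0  ->  (v - min)/d = 15.0  ->  q = 15  ->  dequant 15*0.2 - 1.0 =  2.0

Both endpoints are representable exactly, and the worst-case rounding error is d/2.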
1337
+ static void quantize_row_q5_0_reference(const float * restrict x, block_q5_0 * restrict y, int k) {
1338
+ assert(k % QK5_0 == 0);
1339
+ const int nb = k / QK5_0;
1340
+
1341
+ for (int i = 0; i < nb; i++) {
1342
+ float amax = 0.0f; // absolute max
1343
+ float max = 0.0f;
1344
+
1345
+ for (int l = 0; l < QK5_0; l++) {
1346
+ const float v = x[i*QK5_0 + l];
1347
+ if (amax < fabsf(v)) {
1348
+ amax = fabsf(v);
1349
+ max = v;
1350
+ }
1351
+ }
1352
+
1353
+ const float d = max / -16;
1354
+ const float id = d ? 1.0f/d : 0.0f;
1355
+
1356
+ y[i].d = GGML_FP32_TO_FP16(d);
1357
+
1358
+ uint32_t qh = 0;
1359
+
1360
+ for (int l = 0; l < QK5_0; l += 2) {
1361
+ const float v0 = x[i*QK5_0 + l + 0]*id;
1362
+ const float v1 = x[i*QK5_0 + l + 1]*id;
1363
+
1364
+ const uint32_t vi0 = MIN(31, (int) (v0 + 16.5f));
1365
+ const uint32_t vi1 = MIN(31, (int) (v1 + 16.5f));
1366
+
1367
+ y[i].qs[l/2] = (vi0 & 0x0F) | ((vi1 & 0x0F) << 4);
1368
+
1369
+ // get the 5-th bit and store it in qh at the right position
1370
+ qh |= ((vi0 & 0x10) >> 4) << (l + 0);
1371
+ qh |= ((vi1 & 0x10) >> 4) << (l + 1);
1372
+ }
1373
+
1374
+ memcpy(&y[i].qh, &qh, sizeof(y[i].qh));
1375
+ }
1376
+ }
1377
+
1378
+ static void quantize_row_q5_0(const float * restrict x, void * restrict vy, int k) {
1379
+ assert(k % QK5_0 == 0);
1380
+
1381
+ block_q5_0 * restrict y = vy;
1382
+
1383
+ quantize_row_q5_0_reference(x, y, k);
1384
+ }
1385
+
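q5_0 splits each 5-bit quant across two places: the low four bits go into the packed nibbles in qs, and the fifth bits of all 32 quants are collected into the 32-bit qh. A scalar sketch of recovering quant l, mirroring the reference code above (the helper name is illustrative):

    #include <stdint.h>
    #include <string.h>

    // Reassemble the full 5-bit quant l of one q5_0 block and center it,
    // as the quantize/dequantize reference code does.
    static int q5_0_quant(const uint8_t qs[16], const uint8_t qh_bytes[4], int l) {
        uint32_t qh;
        memcpy(&qh, qh_bytes, sizeof(qh));                     // same layout as y[i].qh
        const uint8_t nib  = (l & 1) ? (qs[l/2] >> 4) : (qs[l/2] & 0x0F);
        const uint8_t bit5 = (qh >> l) & 1;
        return (int)(nib | (bit5 << 4)) - 16;                  // scaled by d to get the value
    }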
1386
+ static void quantize_row_q5_1_reference(const float * restrict x, block_q5_1 * restrict y, int k) {
1387
+ assert(k % QK5_1 == 0);
1388
+ const int nb = k / QK5_1;
1389
+
1390
+ for (int i = 0; i < nb; i++) {
1391
+ float min = FLT_MAX;
1392
+ float max = -FLT_MAX;
1393
+
1394
+ for (int l = 0; l < QK5_1; l++) {
1395
+ const float v = x[i*QK5_1 + l];
1396
+ if (v < min) min = v;
1397
+ if (v > max) max = v;
1398
+ }
1399
+
1400
+ const float d = (max - min) / ((1 << 5) - 1);
1401
+ const float id = d ? 1.0f/d : 0.0f;
1402
+
1403
+ y[i].d = GGML_FP32_TO_FP16(d);
1404
+ y[i].m = GGML_FP32_TO_FP16(min);
1405
+
1406
+ uint32_t qh = 0;
1407
+
1408
+ for (int l = 0; l < QK5_1; l += 2) {
1409
+ const float v0 = (x[i*QK5_1 + l + 0] - min)*id;
1410
+ const float v1 = (x[i*QK5_1 + l + 1] - min)*id;
1411
+
1412
+ const uint32_t vi0 = (int) (v0 + 0.5f);
1413
+ const uint32_t vi1 = (int) (v1 + 0.5f);
1414
+
1415
+ y[i].qs[l/2] = (vi0 & 0x0F) | ((vi1 & 0x0F) << 4);
1416
+
1417
+ // get the 5-th bit and store it in qh at the right position
1418
+ qh |= ((vi0 & 0x10) >> 4) << (l + 0);
1419
+ qh |= ((vi1 & 0x10) >> 4) << (l + 1);
1420
+ }
1421
+
1422
+ memcpy(&y[i].qh, &qh, sizeof(y[i].qh));
1423
+ }
1424
+ }
1425
+
1426
+ static void quantize_row_q5_1(const float * restrict x, void * restrict vy, int k) {
1427
+ assert(k % QK5_1 == 0);
1428
+
1429
+ block_q5_1 * restrict y = vy;
1430
+
1431
+ quantize_row_q5_1_reference(x, y, k);
1432
+ }
1433
+
1048
1434
  // reference implementation for deterministic creation of model files
1049
1435
  static void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * restrict y, int k) {
1050
1436
  assert(k % QK8_0 == 0);
@@ -1064,18 +1450,64 @@ static void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * r
1064
1450
  y[i].d = d;
1065
1451
 
1066
1452
  for (int l = 0; l < QK8_0; ++l) {
1067
- const float v = x[i*QK8_0 + l]*id;
1068
- y[i].qs[l] = roundf(v);
1453
+ const float v0 = x[i*QK8_0 + l]*id;
1454
+
1455
+ y[i].qs[l] = roundf(v0);
1069
1456
  }
1070
1457
  }
1071
1458
  }
1072
1459
 
1073
1460
  static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int k) {
1074
1461
  assert(k % QK8_0 == 0);
1075
- const int nb = k / QK8_0;
1076
1462
 
1077
1463
  block_q8_0 * restrict y = vy;
1078
1464
 
1465
+ quantize_row_q8_0_reference(x, y, k);
1466
+ }
1467
+
1468
+ // reference implementation for deterministic creation of model files
1469
+ static void quantize_row_q8_1_reference(const float * restrict x, block_q8_1 * restrict y, int k) {
1470
+ assert(k % QK8_1 == 0);
1471
+ const int nb = k / QK8_1;
1472
+
1473
+ for (int i = 0; i < nb; i++) {
1474
+ float amax = 0.0f; // absolute max
1475
+
1476
+ for (int l = 0; l < QK8_1; l++) {
1477
+ const float v = x[i*QK8_1 + l];
1478
+ amax = MAX(amax, fabsf(v));
1479
+ }
1480
+
1481
+ const float d = amax / ((1 << 7) - 1);
1482
+ const float id = d ? 1.0f/d : 0.0f;
1483
+
1484
+ y[i].d = d;
1485
+
1486
+ int sum0 = 0;
1487
+ int sum1 = 0;
1488
+
1489
+ for (int l = 0; l < QK8_1/2; ++l) {
1490
+ const float v0 = x[i*QK8_1 + l]*id;
1491
+ const float v1 = x[i*QK8_1 + QK8_1/2 + l]*id;
1492
+
1493
+ y[i].qs[ l] = roundf(v0);
1494
+ y[i].qs[QK8_1/2 + l] = roundf(v1);
1495
+
1496
+ sum0 += y[i].qs[ l];
1497
+ sum1 += y[i].qs[QK8_1/2 + l];
1498
+ }
1499
+
1500
+ y[i].s0 = d * sum0;
1501
+ y[i].s1 = d * sum1;
1502
+ }
1503
+ }
1504
+
1505
+ static void quantize_row_q8_1(const float * restrict x, void * restrict vy, int k) {
1506
+ assert(k % QK8_1 == 0);
1507
+ const int nb = k / QK8_1;
1508
+
1509
+ block_q8_1 * restrict y = vy;
1510
+
1079
1511
  #if defined(__ARM_NEON)
1080
1512
  for (int i = 0; i < nb; i++) {
1081
1513
  float32x4_t srcv [8];
@@ -1096,7 +1528,11 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
1096
1528
 
1097
1529
  y[i].d = d;
1098
1530
 
1099
- for (int l = 0; l < 8; l++) {
1531
+ int32x4_t accv0 = vdupq_n_s32(0);
1532
+ int32x4_t accv1 = vdupq_n_s32(0);
1533
+
1534
+ // low half
1535
+ for (int l = 0; l < 4; l++) {
1100
1536
  const float32x4_t v = vmulq_n_f32(srcv[l], id);
1101
1537
  const int32x4_t vi = vcvtnq_s32_f32(v);
1102
1538
 
@@ -1104,19 +1540,40 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
1104
1540
  y[i].qs[4*l + 1] = vgetq_lane_s32(vi, 1);
1105
1541
  y[i].qs[4*l + 2] = vgetq_lane_s32(vi, 2);
1106
1542
  y[i].qs[4*l + 3] = vgetq_lane_s32(vi, 3);
1543
+
1544
+ accv0 = vaddq_s32(accv0, vi);
1107
1545
  }
1108
- }
1109
- #elif defined(__AVX2__) || defined(__AVX__)
1110
- for (int i = 0; i < nb; i++) {
1111
- // Load elements into 4 AVX vectors
1112
- __m256 v0 = _mm256_loadu_ps( x );
1113
- __m256 v1 = _mm256_loadu_ps( x + 8 );
1114
- __m256 v2 = _mm256_loadu_ps( x + 16 );
1115
- __m256 v3 = _mm256_loadu_ps( x + 24 );
1116
- x += 32;
1117
1546
 
1118
- // Compute max(abs(e)) for the block
1119
- const __m256 signBit = _mm256_set1_ps( -0.0f );
1547
+ // high half
1548
+ for (int l = 4; l < 8; l++) {
1549
+ const float32x4_t v = vmulq_n_f32(srcv[l], id);
1550
+ const int32x4_t vi = vcvtnq_s32_f32(v);
1551
+
1552
+ y[i].qs[4*l + 0] = vgetq_lane_s32(vi, 0);
1553
+ y[i].qs[4*l + 1] = vgetq_lane_s32(vi, 1);
1554
+ y[i].qs[4*l + 2] = vgetq_lane_s32(vi, 2);
1555
+ y[i].qs[4*l + 3] = vgetq_lane_s32(vi, 3);
1556
+
1557
+ accv1 = vaddq_s32(accv1, vi);
1558
+ }
1559
+
1560
+ const int32_t sum0 = vaddvq_s32(accv0);
1561
+ const int32_t sum1 = vaddvq_s32(accv1);
1562
+
1563
+ y[i].s0 = d * sum0;
1564
+ y[i].s1 = d * sum1;
1565
+ }
1566
+ #elif defined(__AVX2__) || defined(__AVX__)
1567
+ for (int i = 0; i < nb; i++) {
1568
+ // Load elements into 4 AVX vectors
1569
+ __m256 v0 = _mm256_loadu_ps( x );
1570
+ __m256 v1 = _mm256_loadu_ps( x + 8 );
1571
+ __m256 v2 = _mm256_loadu_ps( x + 16 );
1572
+ __m256 v3 = _mm256_loadu_ps( x + 24 );
1573
+ x += 32;
1574
+
1575
+ // Compute max(abs(e)) for the block
1576
+ const __m256 signBit = _mm256_set1_ps( -0.0f );
1120
1577
  __m256 maxAbs = _mm256_andnot_ps( signBit, v0 );
1121
1578
  maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) );
1122
1579
  maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );
@@ -1152,6 +1609,11 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
1152
1609
  __m256i i3 = _mm256_cvtps_epi32( v3 );
1153
1610
 
1154
1611
  #if defined(__AVX2__)
1612
+ // Compute the sum of the quants and set y[i].s
1613
+ //y[i].s = d * hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3)));
1614
+ y[i].s0 = d * hsum_i32_8(_mm256_add_epi32(i0, i1));
1615
+ y[i].s1 = d * hsum_i32_8(_mm256_add_epi32(i2, i3));
1616
+
1155
1617
  // Convert int32 to int16
1156
1618
  i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15
1157
1619
  i2 = _mm256_packs_epi32( i2, i3 ); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31
@@ -1177,6 +1639,12 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
1177
1639
  __m128i ni6 = _mm256_castsi256_si128( i3 );
1178
1640
  __m128i ni7 = _mm256_extractf128_si256( i3, 1);
1179
1641
 
1642
+ // Compute the sum of the quants and set y[i].s
1643
+ const __m128i s0 = _mm_add_epi32(_mm_add_epi32(ni0, ni1), _mm_add_epi32(ni2, ni3));
1644
+ const __m128i s1 = _mm_add_epi32(_mm_add_epi32(ni4, ni5), _mm_add_epi32(ni6, ni7));
1645
+ y[i].s0 = d * hsum_i32_4(s0);
1646
+ y[i].s1 = d * hsum_i32_4(s1);
1647
+
1180
1648
  // Convert int32 to int16
1181
1649
  ni0 = _mm_packs_epi32( ni0, ni1 );
1182
1650
  ni2 = _mm_packs_epi32( ni2, ni3 );
@@ -1192,7 +1660,7 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
1192
1660
  }
1193
1661
  #else
1194
1662
  // scalar
1195
- quantize_row_q8_0_reference(x, y, k);
1663
+ quantize_row_q8_1_reference(x, y, k);
1196
1664
  #endif
1197
1665
  }
1198
1666
 
@@ -1211,7 +1679,7 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in
1211
1679
 
1212
1680
  for (int l = 0; l < QK4_0; l += 32) {
1213
1681
  // Load 32x4-bit integers into 32x8-bit integers
1214
- __m256i vx8 = bytesFromNibbles(pp+l/2);
1682
+ __m256i vx8 = bytes_from_nibbles_32(pp+l/2);
1215
1683
 
1216
1684
  // Subtract 8 from the integers
1217
1685
  vx8 = _mm256_sub_epi8(vx8, _mm256_set1_epi8(8));
@@ -1246,7 +1714,7 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in
1246
1714
  const uint8x8_t v8 = vld1_u8(pp + l/2);
1247
1715
 
1248
1716
  // Expand 4-bit qs to 8-bit bytes
1249
- const uint8x8_t v0 = vand_u8(v8, vdup_n_u8(0x0f));
1717
+ const uint8x8_t v0 = vand_u8(v8, vdup_n_u8(0x0F));
1250
1718
  const uint8x8_t v1 = vshr_n_u8(v8, 4);
1251
1719
 
1252
1720
  // Convert to signed 8-bit integers
@@ -1296,7 +1764,7 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in
1296
1764
  for (int l = 0; l < QK4_0; l += 2) {
1297
1765
  const uint8_t vi = pp[l/2];
1298
1766
 
1299
- const int8_t vi0 = vi & 0xf;
1767
+ const int8_t vi0 = vi & 0x0F;
1300
1768
  const int8_t vi1 = vi >> 4;
1301
1769
 
1302
1770
  const float v0 = (vi0 - 8)*d;
@@ -1329,7 +1797,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
1329
1797
 
1330
1798
  for (int l = 0; l < QK4_1; l += 32) {
1331
1799
  // Load 32x4-bit integers into 32x8-bit integers
1332
- __m256i vx8 = bytesFromNibbles(pp+l/2);
1800
+ __m256i vx8 = bytes_from_nibbles_32(pp+l/2);
1333
1801
 
1334
1802
  // Convert to 16-bit int
1335
1803
  const __m256i vx16_lo = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vx8, 0));
@@ -1362,7 +1830,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
1362
1830
  const uint8x8_t v8 = vld1_u8(pp + l/2);
1363
1831
 
1364
1832
  // Expand 4-bit qs to 8-bit bytes
1365
- const uint8x8_t v0 = vand_u8(v8, vdup_n_u8(0x0f));
1833
+ const uint8x8_t v0 = vand_u8(v8, vdup_n_u8(0x0F));
1366
1834
  const uint8x8_t v1 = vshr_n_u8(v8, 4);
1367
1835
 
1368
1836
  // Interleave and combine
@@ -1404,7 +1872,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
1404
1872
  for (int l = 0; l < QK4_1; l += 2) {
1405
1873
  const uint8_t vi = pp[l/2];
1406
1874
 
1407
- const int8_t vi0 = vi & 0xf;
1875
+ const int8_t vi0 = vi & 0x0F;
1408
1876
  const int8_t vi1 = vi >> 4;
1409
1877
 
1410
1878
  const float v0 = vi0*d + m;
@@ -1420,8 +1888,162 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
1420
1888
  #endif
1421
1889
  }
1422
1890
 
1423
- static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
1891
+ static void dequantize_row_q4_2(const void * restrict vx, float * restrict y, int k) {
1892
+ assert(k % QK4_2 == 0);
1893
+ const int nb = k / QK4_2;
1894
+
1895
+ const block_q4_2 * restrict x = vx;
1896
+
1897
+ for (int i = 0; i < nb; i++) {
1898
+ const float d = GGML_FP16_TO_FP32(x[i].d);
1899
+
1900
+ const uint8_t * restrict pp = x[i].qs;
1901
+
1902
+ for (int l = 0; l < QK4_2; l += 2) {
1903
+ const uint8_t vi = pp[l/2];
1904
+
1905
+ const int8_t vi0 = vi & 0x0F;
1906
+ const int8_t vi1 = vi >> 4;
1907
+
1908
+ const float v0 = (vi0 - 8)*d;
1909
+ const float v1 = (vi1 - 8)*d;
1910
+
1911
+ y[i*QK4_2 + l + 0] = v0;
1912
+ y[i*QK4_2 + l + 1] = v1;
1913
+
1914
+ assert(!isnan(y[i*QK4_2 + l + 0]));
1915
+ assert(!isnan(y[i*QK4_2 + l + 1]));
1916
+ }
1917
+ }
1918
+ }
1919
+
1920
+ static void dequantize_row_q4_3(const void * restrict vx, float * restrict y, int k) {
1921
+ assert(k % QK4_3 == 0);
1922
+ const int nb = k / QK4_3;
1923
+
1924
+ const block_q4_3 * restrict x = vx;
1925
+
1926
+ for (int i = 0; i < nb; i++) {
1927
+ const float d = GGML_FP16_TO_FP32(x[i].d);
1928
+ const float m = GGML_FP16_TO_FP32(x[i].m);
1929
+
1930
+ const uint8_t * restrict pp = x[i].qs;
1931
+
1932
+ for (int l = 0; l < QK4_3; l += 2) {
1933
+ const uint8_t vi = pp[l/2];
1934
+
1935
+ const int8_t vi0 = vi & 0x0F;
1936
+ const int8_t vi1 = vi >> 4;
1937
+
1938
+ const float v0 = vi0*d + m;
1939
+ const float v1 = vi1*d + m;
1940
+
1941
+ y[i*QK4_3 + l + 0] = v0;
1942
+ y[i*QK4_3 + l + 1] = v1;
1943
+
1944
+ assert(!isnan(y[i*QK4_3 + l + 0]));
1945
+ assert(!isnan(y[i*QK4_3 + l + 1]));
1946
+ }
1947
+ }
1948
+ }
1949
+
1950
+ static void dequantize_row_q5_0(const void * restrict vx, float * restrict y, int k) {
1951
+ assert(k % QK5_0 == 0);
1952
+ const int nb = k / QK5_0;
1953
+
1954
+ const block_q5_0 * restrict x = vx;
1955
+
1956
+ for (int i = 0; i < nb; i++) {
1957
+ const float d = GGML_FP16_TO_FP32(x[i].d);
1958
+
1959
+ const uint8_t * restrict pp = x[i].qs;
1960
+
1961
+ uint32_t qh;
1962
+ memcpy(&qh, x[i].qh, sizeof(qh));
1963
+
1964
+ for (int l = 0; l < QK5_0; l += 2) {
1965
+ const uint8_t vi = pp[l/2];
1966
+
1967
+ // extract the 5-th bit from qh
1968
+ const uint8_t vh0 = ((qh & (1 << (l + 0))) >> (l + 0)) << 4;
1969
+ const uint8_t vh1 = ((qh & (1 << (l + 1))) >> (l + 1)) << 4;
1970
+
1971
+ const int8_t vi0 = (vi & 0x0F) | vh0;
1972
+ const int8_t vi1 = (vi >> 4) | vh1;
1973
+
1974
+ const float v0 = (vi0 - 16)*d;
1975
+ const float v1 = (vi1 - 16)*d;
1976
+
1977
+ y[i*QK5_0 + l + 0] = v0;
1978
+ y[i*QK5_0 + l + 1] = v1;
1979
+
1980
+ assert(!isnan(y[i*QK5_0 + l + 0]));
1981
+ assert(!isnan(y[i*QK5_0 + l + 1]));
1982
+ }
1983
+ }
1984
+ }
1985
+
1986
+ static void dequantize_row_q5_1(const void * restrict vx, float * restrict y, int k) {
1987
+ assert(k % QK5_1 == 0);
1988
+ const int nb = k / QK5_1;
1989
+
1990
+ const block_q5_1 * restrict x = vx;
1991
+
1992
+ for (int i = 0; i < nb; i++) {
1993
+ const float d = GGML_FP16_TO_FP32(x[i].d);
1994
+ const float m = GGML_FP16_TO_FP32(x[i].m);
1995
+
1996
+ const uint8_t * restrict pp = x[i].qs;
1997
+
1998
+ uint32_t qh;
1999
+ memcpy(&qh, x[i].qh, sizeof(qh));
2000
+
2001
+ for (int l = 0; l < QK5_1; l += 2) {
2002
+ const uint8_t vi = pp[l/2];
2003
+
2004
+ // extract the 5-th bit from qh
2005
+ const uint8_t vh0 = ((qh & (1 << (l + 0))) >> (l + 0)) << 4;
2006
+ const uint8_t vh1 = ((qh & (1 << (l + 1))) >> (l + 1)) << 4;
2007
+
2008
+ const uint8_t vi0 = (vi & 0x0F) | vh0;
2009
+ const uint8_t vi1 = (vi >> 4) | vh1;
2010
+
2011
+ const float v0 = vi0*d + m;
2012
+ const float v1 = vi1*d + m;
2013
+
2014
+ y[i*QK5_1 + l + 0] = v0;
2015
+ y[i*QK5_1 + l + 1] = v1;
2016
+
2017
+ assert(!isnan(y[i*QK5_1 + l + 0]));
2018
+ assert(!isnan(y[i*QK5_1 + l + 1]));
2019
+ }
2020
+ }
2021
+ }
2022
+
2023
+ static void dequantize_row_q8_0(const void * restrict vx, float * restrict y, int k) {
2024
+ assert(k % QK8_0 == 0);
2025
+ const int nb = k / QK8_0;
2026
+
2027
+ const block_q8_0 * restrict x = vx;
2028
+
2029
+ for (int i = 0; i < nb; i++) {
2030
+ const float d = x[i].d;
2031
+
2032
+ const int8_t * restrict pp = x[i].qs;
2033
+
2034
+ for (int l = 0; l < QK8_0; ++l) {
2035
+ y[i*QK8_0 + l] = pp[l]*d;
2036
+ }
2037
+ }
2038
+ }
2039
+
1424
2040
  static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
2041
+ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
2042
+ static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
2043
+ static void ggml_vec_dot_q4_3_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
2044
+ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
2045
+ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
2046
+ static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
1425
2047
 
1426
2048
  static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
1427
2049
  [GGML_TYPE_Q4_0] = {
@@ -1430,15 +2052,64 @@ static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
1430
2052
  .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_0_reference,
1431
2053
  .quantize_row_q_dot = quantize_row_q8_0,
1432
2054
  .vec_dot_q = ggml_vec_dot_q4_0_q8_0,
2055
+ .vec_dot_type = GGML_TYPE_Q8_0,
1433
2056
  },
1434
2057
  [GGML_TYPE_Q4_1] = {
1435
2058
  .dequantize_row_q = dequantize_row_q4_1,
1436
2059
  .quantize_row_q = quantize_row_q4_1,
1437
2060
  .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_1_reference,
1438
- .quantize_row_q_dot = quantize_row_q4_1,
1439
- .vec_dot_q = ggml_vec_dot_q4_1,
2061
+ .quantize_row_q_dot = quantize_row_q8_1,
2062
+ .vec_dot_q = ggml_vec_dot_q4_1_q8_1,
2063
+ .vec_dot_type = GGML_TYPE_Q8_1,
2064
+ },
2065
+ [GGML_TYPE_Q4_2] = {
2066
+ .dequantize_row_q = dequantize_row_q4_2,
2067
+ .quantize_row_q = quantize_row_q4_2,
2068
+ .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_2_reference,
2069
+ .quantize_row_q_dot = quantize_row_q8_0,
2070
+ .vec_dot_q = ggml_vec_dot_q4_2_q8_0,
2071
+ .vec_dot_type = GGML_TYPE_Q8_0,
2072
+ },
2073
+ [GGML_TYPE_Q4_3] = {
2074
+ .dequantize_row_q = dequantize_row_q4_3,
2075
+ .quantize_row_q = quantize_row_q4_3,
2076
+ .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_3_reference,
2077
+ .quantize_row_q_dot = quantize_row_q8_1,
2078
+ .vec_dot_q = ggml_vec_dot_q4_3_q8_1,
2079
+ .vec_dot_type = GGML_TYPE_Q8_1,
2080
+ },
2081
+ [GGML_TYPE_Q5_0] = {
2082
+ .dequantize_row_q = dequantize_row_q5_0,
2083
+ .quantize_row_q = quantize_row_q5_0,
2084
+ .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q5_0_reference,
2085
+ .quantize_row_q_dot = quantize_row_q8_0,
2086
+ .vec_dot_q = ggml_vec_dot_q5_0_q8_0,
2087
+ .vec_dot_type = GGML_TYPE_Q8_0,
2088
+ },
2089
+ [GGML_TYPE_Q5_1] = {
2090
+ .dequantize_row_q = dequantize_row_q5_1,
2091
+ .quantize_row_q = quantize_row_q5_1,
2092
+ .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q5_1_reference,
2093
+ .quantize_row_q_dot = quantize_row_q8_1,
2094
+ .vec_dot_q = ggml_vec_dot_q5_1_q8_1,
2095
+ .vec_dot_type = GGML_TYPE_Q8_1,
2096
+ },
2097
+ [GGML_TYPE_Q8_0] = {
2098
+ .dequantize_row_q = dequantize_row_q8_0,
2099
+ .quantize_row_q = quantize_row_q8_0,
2100
+ .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q8_0_reference,
2101
+ .quantize_row_q_dot = quantize_row_q8_0,
2102
+ .vec_dot_q = ggml_vec_dot_q8_0_q8_0,
2103
+ .vec_dot_type = GGML_TYPE_Q8_0,
2104
+ },
2105
+ [GGML_TYPE_Q8_1] = {
2106
+ .dequantize_row_q = NULL, // TODO
2107
+ .quantize_row_q = quantize_row_q8_1,
2108
+ .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q8_1_reference,
2109
+ .quantize_row_q_dot = quantize_row_q8_1,
2110
+ .vec_dot_q = NULL, // TODO
2111
+ .vec_dot_type = GGML_TYPE_Q8_1,
1440
2112
  },
1441
- // TODO: GGML_TYPE_Q8_0
1442
2113
  };
1443
2114
 
1444
2115
  // For internal test use
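Each entry now pairs a weight type with the activation format its kernel expects via vec_dot_type. A hedged sketch of how a matrix-multiplication caller is presumably meant to use it (the caller itself is outside this excerpt; the helper name and parameters below are illustrative):

    // Quantize the f32 activation row with the weight type's companion format,
    // then run the matching mixed-type dot product.
    static float example_row_dot_q4_1(const void * q4_1_row,     // weight row, already quantized
                                      const float * act_row_f32, // activation row, f32
                                      void * scratch_q8_1,       // scratch for one q8_1 row
                                      int ne) {
        const quantize_fns_t * qfns = &quantize_fns[GGML_TYPE_Q4_1];
        // for this entry: qfns->vec_dot_type == GGML_TYPE_Q8_1
        qfns->quantize_row_q_dot(act_row_f32, scratch_q8_1, ne); // f32 -> q8_1
        float out = 0.0f;
        qfns->vec_dot_q(ne, &out, q4_1_row, scratch_q8_1);       // q4_1 . q8_1
        return out;
    }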
@@ -2004,191 +2675,6 @@ inline static void ggml_vec_dot_f32(const int n, float * restrict s, const float
2004
2675
  *s = sumf;
2005
2676
  }
2006
2677
 
2007
- #if __AVX512F__ && QK4_0 == 32
2008
- static inline __m512i bytes_from_q4_0_twoblocks_avx512( const __m512i blocks ) {
2009
- // The 64 bytes of `blocks` contain two consecutive Q4_0 blocks loaded from memory:
2010
- // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
2011
- // |63 62 61 60 59 58 57 56 55 54 53 52 51 50 49 48 47 46 45 44 43 42 41 40 39 38 37 36 35 34 33 32|
2012
- // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
2013
- // | :. =_ () [] <> () Zz Yy|
2014
- // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
2015
- // |31 30 29 28 27 26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10 09 08 07 06 05 04 03 02 01 00|
2016
- // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
2017
- // |Xx Ww Vv Uu Tt Ss Rr Qq Pp Oo Nn Mm Ll Kk Jj Ii Hh Gg Ff Ee Dd Cc Bb Aa |
2018
- // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
2019
- //
2020
- // Bytes 04..19 (block #0) and 24..39 (block #1) both contain 32 nibbles (4-bit unsigned integers).
2021
- // We have exactly 64 nibbles, so we want to place each nibble into a separate byte.
2022
- // Bytes 00..03 and 20..23 contain scales, which are irrelevant to this function.
2023
- // Bytes 40..63 are masked when loading the data, so they are zeroed out.
2024
- #ifdef __AVX512VBMI__
2025
- const __m512i byte_perm = _mm512_set_epi8(
2026
- 39, 38, 39, 38, 37, 36, 37, 36, 35, 34, 35, 34, 33, 32, 33, 32,
2027
- 31, 30, 31, 30, 29, 28, 29, 28, 27, 26, 27, 26, 25, 24, 25, 24,
2028
- 19, 18, 19, 18, 17, 16, 17, 16, 15, 14, 15, 14, 13, 12, 13, 12,
2029
- 11, 10, 11, 10, 9, 8, 9, 8, 7, 6, 7, 6, 5, 4, 5, 4
2030
- );
2031
- const __m512i permuted = _mm512_permutexvar_epi8( byte_perm, blocks );
2032
- // After applying VPERMB, `permuted` looks like this:
2033
- // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
2034
- // |63 62 61 60 59 58 57 56 55 54 53 52 51 50 49 48 47 46 45 44 43 42 41 40 39 38 37 36 35 34 33 32|
2035
- // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
2036
- // |:. =_ :. =_ () [] () [] <> () <> () Zz Yy Zz Yy Xx Ww Xx Ww Vv Uu Vv Uu Tt Ss Tt Ss Rr Qq Rr Qq|
2037
- // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
2038
- // |31 30 29 28 27 26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10 09 08 07 06 05 04 03 02 01 00|
2039
- // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
2040
- // |Pp Oo Pp Oo Nn Mm Nn Mm Ll Kk Ll Kk Jj Ii Jj Ii Hh Gg Hh Gg Ff Ee Ff Ee Dd Cc Dd Cc Bb Aa Bb Aa|
2041
- // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
2042
- #else
2043
- const __m512i word_perm = _mm512_set_epi16(
2044
- 19, 19, 18, 18, 17, 17, 16, 16, 15, 15, 14, 14, 13, 13, 12, 12,
2045
- 9, 9, 8, 8, 7, 7, 6, 6, 5, 5, 4, 4, 3, 3, 2, 2
2046
- );
2047
- const __m512i permuted = _mm512_permutexvar_epi16( word_perm, blocks );
2048
- // This is the fallback path for CPUs that don't support VPERMB. Since we permute 16-bit groups only,
2049
- // VPERMB can be replaced with VPERMW. We could always use VPERMW, but at least on Tiger Lake and
2050
- // Ice Lake VPERMW followed by a right shift is quite noticeably slower than VPERMB.
2051
- #endif
2052
-
2053
- // Shift every odd-numbered 16-bit group to the right by 4 bits.
2054
- const __mmask32 shift_mask = 0xaaaaaaaa;
2055
- const __m512i shifted = _mm512_mask_srai_epi16( permuted, shift_mask, permuted, 4 );
2056
- // After applying VPSRAW, `shifted` looks like this (the "empty" nibbles are filled with zeroes):
2057
- // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
2058
- // |63 62 61 60 59 58 57 56 55 54 53 52 51 50 49 48 47 46 45 44 43 42 41 40 39 38 37 36 35 34 33 32
2059
- // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
2060
- // | : .= :. =_ ( )[ () [] < >( <> () Z zY Zz Yy X xW Xx Ww V vU Vv Uu T tS Tt Ss R rQ Rr Qq
2061
- // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
2062
- // |31 30 29 28 27 26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10 09 08 07 06 05 04 03 02 01 00|
2063
- // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
2064
- // | P pO Pp Oo N nM Nn Mm L lK Ll Kk J jI Jj Ii H hG Hh Gg F fE Ff Ee D dC Dd Cc B bA Bb Aa|
2065
- // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
2066
-
2067
- // Now we just need to zero out the higher nibble in each byte, and we're done.
2068
- const __m512i low_nibble_mask = _mm512_set1_epi8( 0xf );
2069
- return _mm512_and_si512( low_nibble_mask, shifted );
2070
- // The final result looks like this:
2071
- // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
2072
- // |63 62 61 60 59 58 57 56 55 54 53 52 51 50 49 48 47 46 45 44 43 42 41 40 39 38 37 36 35 34 33 32|
2073
- // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
2074
- // | : = . _ ( [ ) ] < ( > ) Z Y z y X W x w V U v u T S t s R Q r q|
2075
- // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
2076
- // |31 30 29 28 27 26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10 09 08 07 06 05 04 03 02 01 00|
2077
- // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
2078
- // | P O p o N M n m L K l k J I j i H G h g F E f e D C d c B A b a|
2079
- // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
2080
- }
2081
-
2082
- static inline __m512 dot_q4_0_twoblocks_avx512(
2083
- __m512 acc,
2084
- const block_q4_0 * restrict x,
2085
- const block_q4_0 * restrict y,
2086
- int i
2087
- ) {
2088
- // A pair of Q4_0 blocks spans 40 bytes, while an AVX-512 register has 64. The remaining 24 bytes
2089
- // can potentially be unaddressable, so we make sure to mask them out before the load, even though
2090
- // we don't use them at all. This might hurt the performance slightly, since the compiler is forced
2091
- // to use e.g. `VMOVDQU64 REG, MASK, [ADDR] + VPERMB ..., REG` instead of just `VPERMB ..., [ADDR]`.
2092
- const __mmask8 load_mask = 0x1f;
2093
- const __m512i blocks_0 = _mm512_maskz_loadu_epi64( load_mask, &x[i] );
2094
- const __m512i blocks_1 = _mm512_maskz_loadu_epi64( load_mask, &y[i] );
2095
-
2096
- // We want to multiply the scales, so we interpret both registers as 16 32-bit floats:
2097
- // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
2098
- // | 15 | 14 | 13 | 12 | 11 | 10 | 09 | 08 | 07 | 06 | 05 | 04 | 03 | 02 | 01 | 00 |
2099
- // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
2100
- // blocks_0_float
2101
- // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
2102
- // | | | | | | | xx | xx | xx | xx | B | xx | xx | xx | xx | A |
2103
- // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
2104
- // blocks_1_float
2105
- // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
2106
- // | | | | | | | xx | xx | xx | xx | D | xx | xx | xx | xx | C |
2107
- // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
2108
- const __m512 blocks_0_float = _mm512_castsi512_ps( blocks_0 );
2109
- const __m512 blocks_1_float = _mm512_castsi512_ps( blocks_1 );
2110
- // We absolutely shouldn't touch the floats marked with `xx`: they contain some
2111
- // random data, which might very well underflow. At least on Intel, this leads
2112
- // to a huge penalty that can't be ignored (easily 100x or more) unless you
2113
- // compile your code with something like `-ffast-math` to enable FTZ/DAZ flags.
2114
- // (and ggml can't assume that you do)...
2115
- const __mmask16 scale_mul_mask = 0x21;
2116
- #ifdef __clang__
2117
- // ...however, clang decides to optimize the multiplication mask away:
2118
- // https://godbolt.org/z/P8PqdsfvW
2119
- // gcc and MSVC do the sane thing. This horrible workaround forces clang to emit the mask.
2120
- __m512i scales;
2121
- __asm__(
2122
- "vmulps %1, %2, %0%{%3%}"
2123
- : "=v" ( scales )
2124
- : "vm" ( blocks_0_float ), "v" ( blocks_1_float ), "Yk" ( scale_mul_mask )
2125
- );
2126
- #else
2127
- const __m512 scales = _mm512_maskz_mul_ps( scale_mul_mask, blocks_0_float, blocks_1_float );
2128
- #endif
2129
- const __m512i scale_perm = _mm512_set_epi32(
2130
- 5, 5, 5, 5, 5, 5, 5, 5,
2131
- 0, 0, 0, 0, 0, 0, 0, 0
2132
- );
2133
- const __m512 permuted_scales = _mm512_permutexvar_ps( scale_perm, scales );
2134
- // After VMULPS and VPERMPS, `permuted_scales` looks like this:
2135
- // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
2136
- // | 15 | 14 | 13 | 12 | 11 | 10 | 09 | 08 | 07 | 06 | 05 | 04 | 03 | 02 | 01 | 00 |
2137
- // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
2138
- // | B*D| B*D| B*D| B*D| B*D| B*D| B*D| B*D| A*C| A*C| A*C| A*C| A*C| A*C| A*C| A*C|
2139
- // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
2140
-
2141
- const __m512i bytes_0 = bytes_from_q4_0_twoblocks_avx512( blocks_0 );
2142
- const __m512i bytes_1 = bytes_from_q4_0_twoblocks_avx512( blocks_1 );
2143
-
2144
- // Now we want to compute dot products of 4-element byte vectors and store them in
2145
- // 32-bit integers. That is (only one 4-element vector is shown for clarity):
2146
- // +----+----+----+----+
2147
- // ... | 03 | 02 | 01 | 00 |
2148
- // +----+----+----+----+
2149
- // bytes_0
2150
- // +----+----+----+----+
2151
- // ... | D | C | B | A |
2152
- // +----+----+----+----+
2153
- // bytes_1
2154
- // +----+----+----+----+
2155
- // ... | H | G | F | E |
2156
- // +----+----+----+----+
2157
- // final_res_int
2158
- // +----+----+----+----+
2159
- // ... | A*E+B*F+C*G+D*H |
2160
- // +----+----+----+----+
2161
- const __m512i plus_8 = _mm512_set1_epi8( 8 );
2162
- const __m512i bytes_1_minus_8 = _mm512_sub_epi8( bytes_1, plus_8 );
2163
-
2164
- #ifdef __AVX512VNNI__
2165
- // We have VPDPBUSDS in AVX512-VNNI, which does exactly what we want, but with a catch:
2166
- // the *left* operand is supposed to be unsigned, while Q4_0 quantization subtracts 8
2167
- // from each nibble, so they can be negative. So, instead of `(bytes_0 - 8) * (bytes_1 - 8)`,
2168
- // we compute `bytes_0 * (bytes_1 - 8) + bytes_1 * (-8) + 64`. VPDPBUSDS uses an accumulator,
2169
- // which means we only need 2 instructions.
2170
- const __m512i dot_init = _mm512_set1_epi32( 4 * 64 );
2171
- const __m512i minus_8 = _mm512_set1_epi8( -8 );
2172
- const __m512i prod_0 = _mm512_dpbusds_epi32( dot_init, bytes_1, minus_8 );
2173
- const __m512i final_res_int = _mm512_dpbusds_epi32( prod_0, bytes_0, bytes_1_minus_8 );
2174
- #else
2175
- // As a fallback, we have VPMADDUBSW in AVX512-BW, which uses 16-bit products instead of 32-bit ones.
2176
- // It has the same catch as VPDPBUSDS: the left operand should be unsigned.
2177
- // This is essentially the AVX-512 version of the AVX-2 trick used by GH user Const-me
2178
- // ref: https://gist.github.com/Const-me/4d30e1fc767ab314596e16e90f53b6f4#file-matmultest-cpp-L119
2179
- const __m512i one = _mm512_set1_epi16( 1 );
2180
- const __m512i prod_0 = _mm512_maddubs_epi16( bytes_0, bytes_1_minus_8 );
2181
- const __m512i prod_1 = _mm512_maddubs_epi16( plus_8, bytes_1_minus_8 );
2182
- const __m512i diff = _mm512_sub_epi16( prod_0, prod_1 );
2183
- const __m512i final_res_int = _mm512_madd_epi16( diff, one );
2184
- #endif
2185
-
2186
- // Finally, we multiply the permuted scales and the 32-bit dot products, then accumulate.
2187
- const __m512 final_res_float = _mm512_cvtepi32_ps( final_res_int );
2188
- return _mm512_fmadd_ps( permuted_scales, final_res_float, acc );
2189
- }
2190
- #endif
2191
-
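(For reference, the algebra behind the VPDPBUSDS trick described in the removed AVX-512 helper above is just the expansion, for unsigned nibble bytes a and b in [0, 15]:

    (a - 8)*(b - 8) = a*(b - 8) + b*(-8) + 64

which lets the signed product be built from an unsigned left operand. VPDPBUSDS folds four byte products into each 32-bit lane, so the constant term contributes 4*64 per lane, which is exactly the `dot_init` value used in that code.)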
2192
2678
  inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t * restrict x, ggml_fp16_t * restrict y) {
2193
2679
  ggml_float sumf = 0.0;
2194
2680
 
@@ -2225,67 +2711,62 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t
2225
2711
  *s = sumf;
2226
2712
  }
2227
2713
 
2228
- static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
2229
- const int nb = n / QK4_0;
2714
+ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
2715
+ const int nb = n / QK8_0;
2230
2716
 
2231
- assert(n % QK4_0 == 0);
2717
+ assert(n % QK8_0 == 0);
2232
2718
  assert(nb % 2 == 0);
2233
2719
 
2234
2720
  const block_q4_0 * restrict x = vx;
2235
- const block_q4_0 * restrict y = vy;
2236
-
2237
- float sumf = 0.0;
2721
+ const block_q8_0 * restrict y = vy;
2238
2722
 
2239
2723
  #if defined(__ARM_NEON)
2240
- float sum0 = 0.0f;
2241
- float sum1 = 0.0f;
2724
+ float32x4_t sumv0 = vdupq_n_f32(0.0f);
2725
+ float32x4_t sumv1 = vdupq_n_f32(0.0f);
2242
2726
 
2243
2727
  for (int i = 0; i < nb; i += 2) {
2244
2728
  const block_q4_0 * restrict x0 = &x[i + 0];
2245
- const block_q4_0 * restrict y0 = &y[i + 0];
2246
2729
  const block_q4_0 * restrict x1 = &x[i + 1];
2247
- const block_q4_0 * restrict y1 = &y[i + 1];
2730
+ const block_q8_0 * restrict y0 = &y[i + 0];
2731
+ const block_q8_0 * restrict y1 = &y[i + 1];
2248
2732
 
2249
- const uint8x16_t m4b = vdupq_n_u8(0xf);
2250
- const int8x16_t s8b = vdupq_n_s8(0x8);
2733
+ const uint8x16_t m4b = vdupq_n_u8(0x0F);
2734
+ const int8x16_t s8b = vdupq_n_s8(0x8);
2251
2735
 
2252
2736
  const uint8x16_t v0_0 = vld1q_u8(x0->qs);
2253
- const uint8x16_t v1_0 = vld1q_u8(y0->qs);
2254
2737
  const uint8x16_t v0_1 = vld1q_u8(x1->qs);
2255
- const uint8x16_t v1_1 = vld1q_u8(y1->qs);
2256
2738
 
2257
2739
  // 4-bit -> 8-bit
2258
- const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8(v0_0, m4b));
2259
- const int8x16_t v1_0l = vreinterpretq_s8_u8(vandq_u8(v1_0, m4b));
2740
+ const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
2260
2741
  const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
2261
- const int8x16_t v1_0h = vreinterpretq_s8_u8(vshrq_n_u8(v1_0, 4));
2262
-
2263
- const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8(v0_1, m4b));
2264
- const int8x16_t v1_1l = vreinterpretq_s8_u8(vandq_u8(v1_1, m4b));
2742
+ const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
2265
2743
  const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
2266
- const int8x16_t v1_1h = vreinterpretq_s8_u8(vshrq_n_u8(v1_1, 4));
2267
2744
 
2268
2745
  // sub 8
2269
2746
  const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
2270
- const int8x16_t v1_0ls = vsubq_s8(v1_0l, s8b);
2271
2747
  const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
2272
- const int8x16_t v1_0hs = vsubq_s8(v1_0h, s8b);
2273
-
2274
2748
  const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
2275
- const int8x16_t v1_1ls = vsubq_s8(v1_1l, s8b);
2276
2749
  const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
2277
- const int8x16_t v1_1hs = vsubq_s8(v1_1h, s8b);
2750
+
2751
+ // load y
2752
+ const int8x16_t v1_0l = vld1q_s8(y0->qs);
2753
+ const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
2754
+ const int8x16_t v1_1l = vld1q_s8(y1->qs);
2755
+ const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
2756
+
2757
+ // interleave
2758
+ const int8x16_t v1_0ls = vuzp1q_s8(v1_0l, v1_0h);
2759
+ const int8x16_t v1_0hs = vuzp2q_s8(v1_0l, v1_0h);
2760
+ const int8x16_t v1_1ls = vuzp1q_s8(v1_1l, v1_1h);
2761
+ const int8x16_t v1_1hs = vuzp2q_s8(v1_1l, v1_1h);
2278
2762
 
2279
2763
  #if defined(__ARM_FEATURE_DOTPROD)
2280
2764
  // dot product into int32x4_t
2281
- int32x4_t p_0 = vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls);
2282
- int32x4_t p_1 = vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls);
2765
+ const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls), v0_0hs, v1_0hs);
2766
+ const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls), v0_1hs, v1_1hs);
2283
2767
 
2284
- p_0 = vdotq_s32(p_0, v0_0hs, v1_0hs);
2285
- p_1 = vdotq_s32(p_1, v0_1hs, v1_1hs);
2286
-
2287
- sum0 += x0->d*y0->d*vaddvq_s32(p_0);
2288
- sum1 += x1->d*y1->d*vaddvq_s32(p_1);
2768
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), x0->d*y0->d);
2769
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), x1->d*y1->d);
2289
2770
  #else
2290
2771
  const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls));
2291
2772
  const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls));
@@ -2297,125 +2778,41 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
2297
2778
  const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1hs));
2298
2779
  const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1hs));
2299
2780
 
2300
- const int16x8_t pl_0 = vaddq_s16(pl0l, pl0h);
2301
- const int16x8_t ph_0 = vaddq_s16(ph0l, ph0h);
2302
-
2303
- const int16x8_t pl_1 = vaddq_s16(pl1l, pl1h);
2304
- const int16x8_t ph_1 = vaddq_s16(ph1l, ph1h);
2305
-
2306
- const int16x8_t p_0 = vaddq_s16(pl_0, ph_0);
2307
- const int16x8_t p_1 = vaddq_s16(pl_1, ph_1);
2781
+ const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
2782
+ const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
2783
+ const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
2784
+ const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
2308
2785
 
2309
- sum0 += x0->d*y0->d*vaddvq_s16(p_0);
2310
- sum1 += x1->d*y1->d*vaddvq_s16(p_1);
2786
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), x0->d*y0->d);
2787
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), x1->d*y1->d);
2311
2788
  #endif
2312
2789
  }
2313
2790
 
2314
- sumf = sum0 + sum1;
2315
- #elif defined(__AVX512F__)
2791
+ *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
2792
+ #elif defined(__AVX2__)
2316
2793
  // Initialize accumulator with zeros
2317
- __m512 acc0 = _mm512_setzero_ps();
2318
- __m512 acc1 = _mm512_setzero_ps();
2319
-
2320
- const int superblock_size = 16;
2321
-
2322
- const int superblock_count = nb / superblock_size;
2794
+ __m256 acc = _mm256_setzero_ps();
2323
2795
 
2324
- for (int superblock_ix = 0; superblock_ix < superblock_count; superblock_ix += 1) {
2325
- int i = superblock_ix * superblock_size;
2796
+ // Main loop
2797
+ for (int i = 0; i < nb; ++i) {
2798
+ /* Compute combined scale for the block */
2799
+ const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );
2326
2800
 
2327
- acc0 = dot_q4_0_twoblocks_avx512( acc0, x, y, i+0 );
2328
- acc1 = dot_q4_0_twoblocks_avx512( acc1, x, y, i+2 );
2329
- acc0 = dot_q4_0_twoblocks_avx512( acc0, x, y, i+4 );
2330
- acc1 = dot_q4_0_twoblocks_avx512( acc1, x, y, i+6 );
2331
- acc0 = dot_q4_0_twoblocks_avx512( acc0, x, y, i+8 );
2332
- acc1 = dot_q4_0_twoblocks_avx512( acc1, x, y, i+10 );
2333
- acc0 = dot_q4_0_twoblocks_avx512( acc0, x, y, i+12 );
2334
- acc1 = dot_q4_0_twoblocks_avx512( acc1, x, y, i+14 );
2335
- }
2801
+ __m256i bx = bytes_from_nibbles_32(x[i].qs);
2336
2802
 
2337
- // Remainders
2338
- for (int i = superblock_count * superblock_size; i < nb; i += 2) {
2339
- acc0 = dot_q4_0_twoblocks_avx512( acc0, x, y, i );
2340
- }
2803
+ // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
2804
+ const __m256i off = _mm256_set1_epi8( 8 );
2805
+ bx = _mm256_sub_epi8( bx, off );
2341
2806
 
2342
- // Horizontal sum of all lanes of the accumulator
2343
- sumf = _mm512_reduce_add_ps( acc0 ) + _mm512_reduce_add_ps( acc1 );
2344
- #elif defined(__AVX2__)
2345
- // Initialize accumulator with zeros
2346
- __m256 acc = _mm256_setzero_ps();
2807
+ __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
2347
2808
 
2348
- /* Prepare the constants we will need during execution */
2349
- const __m256i lowMask = _mm256_set1_epi8( 0xF );
2350
- const __m256i offset_8 = _mm256_set1_epi16( 8 );
2809
+ const __m256 q = mul_sum_i8_pairs_float(bx, by);
2351
2810
 
2352
- #define UNROLL_COUNT 8
2353
- // make sure we only unroll multiples of the block count
2354
- assert(nb % UNROLL_COUNT == 0);
2811
+ /* Multiply q with scale and accumulate */
2812
+ acc = _mm256_fmadd_ps( d, q, acc );
2813
+ }
2355
2814
 
2356
- // Main loop
2357
- for (int i = 0; i < nb; i+=UNROLL_COUNT) {
2358
- // This loop will be unrolled by the compiler
2359
- for (int u=0;u<UNROLL_COUNT;u++) {
2360
- /* Compute combined scale for the block */
2361
- const __m256 scale = _mm256_mul_ps(
2362
- _mm256_broadcast_ss( &x[i+u].d ),
2363
- _mm256_broadcast_ss( &y[i+u].d ) );
2364
-
2365
- /* get input from x
2366
- Input: 32 Nibbles (16 bytes) at *x[i+u]
2367
- Output: 2 vectors with 16 values of type int16_t (x_high_q, x_low_q) */
2368
-
2369
- /* Load 16 bytes from memory */
2370
- const __m128i tmp_x = _mm_loadu_si128( ( const __m128i* ) x[i+u].qs);
2371
- /* Expand bytes into uint16_t values */
2372
- const __m256i bytes_x = _mm256_cvtepu8_epi16(tmp_x);
2373
- /* Unpack values into individual bytes */
2374
- __m256i x_low_q = _mm256_and_si256( lowMask, bytes_x );
2375
- const __m256i pre_shift_x_high_q = _mm256_andnot_si256( lowMask, bytes_x );
2376
- __m256i x_high_q = _mm256_srli_epi16( pre_shift_x_high_q, 4 );
2377
- /* Now we have two vectors with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval. */
2378
- x_high_q = _mm256_sub_epi16( x_high_q, offset_8 );
2379
- x_low_q = _mm256_sub_epi16( x_low_q, offset_8 );
2380
-
2381
- /* get input from y
2382
- Input: 32 Nibbles (16 bytes) at *y[i+u]
2383
- Output: 2 vectors with 16 values of type int16_t (y_high_q, y_low_q) */
2384
-
2385
- /* Load 16 bytes from memory */
2386
- const __m128i tmp_y = _mm_loadu_si128( (const __m128i* ) y[i+u].qs);
2387
- /* Expand bytes into uint16_t values */
2388
- const __m256i bytes_y = _mm256_cvtepu8_epi16(tmp_y);
2389
- /* Unpack values into individual bytes */
2390
- const __m256i pre_shift_y_high_q = _mm256_andnot_si256( lowMask, bytes_y );
2391
- __m256i y_high_q = _mm256_srli_epi16( pre_shift_y_high_q, 4 );
2392
- __m256i y_low_q = _mm256_and_si256( lowMask, bytes_y );
2393
- /* Now we have two vectors with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval. */
2394
- y_high_q = _mm256_sub_epi16( y_high_q, offset_8 );
2395
- y_low_q = _mm256_sub_epi16( y_low_q, offset_8 );
2396
-
2397
- /* Compute products of int16_t integers, add pairwise, store as int32_t */
2398
- __m256i xy_high_q = _mm256_madd_epi16( x_high_q, y_high_q );
2399
- __m256i xy_low_q = _mm256_madd_epi16( x_low_q, y_low_q );
2400
-
2401
- /* Accumulate the products of int32_t integers -> we now have a vector of 8 int32_t */
2402
- __m256i xy_q = _mm256_add_epi32( xy_high_q, xy_low_q );
2403
-
2404
- /* Convert vector of 8 int32_t to 8 floats */
2405
- __m256 q = _mm256_cvtepi32_ps( xy_q );
2406
-
2407
- /* Multiply q with scale and accumulate */
2408
- acc = _mm256_fmadd_ps( scale, q, acc );
2409
- }
2410
- }
2411
-
2412
- // Return horizontal sum of the acc vector
2413
- __m128 res = _mm256_extractf128_ps( acc, 1 );
2414
- res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
2415
- res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
2416
- res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
2417
-
2418
- sumf = _mm_cvtss_f32( res );
2815
+ *s = hsum_float_8(acc);
2419
2816
  #elif defined(__AVX__)
2420
2817
  // Initialize accumulator with zeros
2421
2818
  __m256 acc = _mm256_setzero_ps();
@@ -2428,13 +2825,12 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
2428
2825
  __m128i i32[2];
2429
2826
  for (int j = 0; j < 2; ++j) {
2430
2827
  // Load 8 bytes, and unpack 4 bit fields into bytes, making 16 bytes
2431
- __m128i bx = bytesFromNibbles( x[i].qs + 8*j );
2432
- __m128i by = bytesFromNibbles( y[i].qs + 8*j );
2828
+ __m128i bx = bytes_from_nibbles_16(x[i].qs + 8*j);
2829
+ __m128i by = _mm_loadu_si128((const __m128i *)(y[i].qs + 16*j));
2433
2830
 
2434
2831
  // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
2435
2832
  const __m128i off = _mm_set1_epi8( 8 );
2436
2833
  bx = _mm_sub_epi8( bx, off );
2437
- by = _mm_sub_epi8( by, off );
2438
2834
 
2439
2835
  // Get absolute values of x vectors
2440
2836
  const __m128i ax = _mm_sign_epi8(bx, bx);
@@ -2445,516 +2841,833 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
2445
2841
  // Perform multiplication and create 16-bit values
2446
2842
  const __m128i dot = _mm_maddubs_epi16(ax, sy);
2447
2843
 
2448
- const __m128i ones = _mm_set1_epi16(1);
2449
- i32[j] = _mm_madd_epi16(ones, dot);
2450
- }
2844
+ const __m128i ones = _mm_set1_epi16(1);
2845
+ i32[j] = _mm_madd_epi16(ones, dot);
2846
+ }
2847
+
2848
+ // Convert int32_t to float
2849
+ __m256 p = _mm256_cvtepi32_ps( _mm256_set_m128i( i32[0], i32[1] ));
2850
+ // Apply the scale, and accumulate
2851
+ acc = _mm256_add_ps(_mm256_mul_ps( d, p ), acc);
2852
+ }
2853
+
2854
+ *s = hsum_float_8(acc);
2855
+ #else
2856
+ // scalar
2857
+ float sumf = 0.0;
2858
+ for (int i = 0; i < nb; i++) {
2859
+ const float d0 = x[i].d;
2860
+ const float d1 = y[i].d;
2861
+
2862
+ const uint8_t * restrict p0 = x[i].qs;
2863
+ const int8_t * restrict p1 = y[i].qs;
2864
+
2865
+ int sumi = 0;
2866
+ for (int j = 0; j < QK8_0/2; j++) {
2867
+ const uint8_t v0 = p0[j];
2868
+
2869
+ const int i0 = (int8_t) (v0 & 0x0F) - 8;
2870
+ const int i1 = (int8_t) (v0 >> 4) - 8;
2871
+
2872
+ const int i2 = p1[2*j + 0];
2873
+ const int i3 = p1[2*j + 1];
2874
+
2875
+ sumi += i0*i2 + i1*i3;
2876
+ }
2877
+ sumf += d0*d1*sumi;
2878
+ }
2879
+ *s = sumf;
2880
+ #endif
2881
+ }
2882
+
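(The scalar branch of ggml_vec_dot_q4_0_q8_0 is the reference the SIMD branches above must match. A minimal self-contained sketch of that per-block computation, using simplified stand-ins for block_q4_0 / block_q8_0 and assuming a block size of 32, could look like this:

    /* Sketch of the scalar Q4_0 x Q8_0 reference above; blk_* are simplified
       stand-ins for the real block structs, QK assumed to be 32. */
    #include <stdint.h>

    #define QK 32

    typedef struct { float d; uint8_t qs[QK/2]; } blk_q4_0; /* scale + 16 nibble pairs */
    typedef struct { float d; int8_t  qs[QK];   } blk_q8_0; /* scale + 32 int8 quants  */

    static float dot_q4_0_q8_0_one(const blk_q4_0 * x, const blk_q8_0 * y) {
        int sumi = 0;
        for (int j = 0; j < QK/2; j++) {
            const int i0 = (x->qs[j] & 0x0F) - 8; /* low nibble  -> [-8, 7] */
            const int i1 = (x->qs[j] >>   4) - 8; /* high nibble -> [-8, 7] */
            sumi += i0*y->qs[2*j + 0] + i1*y->qs[2*j + 1];
        }
        return x->d * y->d * sumi; /* both block scales applied once per block */
    }

Each 4-bit quant is shifted into [-8, 7] before the integer dot product, and the two per-block scales are applied once at the end, which is what keeps the inner loop purely integer.)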
2883
+ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
2884
+ const int nb = n / QK8_1;
2885
+
2886
+ assert(n % QK8_1 == 0);
2887
+ assert(nb % 2 == 0);
2888
+
2889
+ const block_q4_1 * restrict x = vx;
2890
+ const block_q8_1 * restrict y = vy;
2891
+
2892
+ // TODO: add AVX / WASM SIMD / etc
2893
+ #if defined(__ARM_NEON)
2894
+ float32x4_t sumv0 = vdupq_n_f32(0.0f);
2895
+ float32x4_t sumv1 = vdupq_n_f32(0.0f);
2896
+
2897
+ float summs = 0;
2898
+
2899
+ for (int i = 0; i < nb; i += 2) {
2900
+ const block_q4_1 * restrict x0 = &x[i + 0];
2901
+ const block_q4_1 * restrict x1 = &x[i + 1];
2902
+ const block_q8_1 * restrict y0 = &y[i + 0];
2903
+ const block_q8_1 * restrict y1 = &y[i + 1];
2904
+
2905
+ summs += x0->m * (y0->s0 + y0->s1) + x1->m * (y1->s0 + y1->s1);
2906
+
2907
+ const uint8x16_t m4b = vdupq_n_u8(0x0F);
2908
+
2909
+ const uint8x16_t v0_0 = vld1q_u8(x0->qs);
2910
+ const uint8x16_t v0_1 = vld1q_u8(x1->qs);
2911
+
2912
+ // 4-bit -> 8-bit
2913
+ const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
2914
+ const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
2915
+ const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
2916
+ const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
2917
+
2918
+ // interleave
2919
+ const int8x16_t v0_0lz = vzip1q_s8(v0_0l, v0_0h);
2920
+ const int8x16_t v0_0hz = vzip2q_s8(v0_0l, v0_0h);
2921
+ const int8x16_t v0_1lz = vzip1q_s8(v0_1l, v0_1h);
2922
+ const int8x16_t v0_1hz = vzip2q_s8(v0_1l, v0_1h);
2923
+
2924
+ // load y
2925
+ const int8x16_t v1_0l = vld1q_s8(y0->qs);
2926
+ const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
2927
+ const int8x16_t v1_1l = vld1q_s8(y1->qs);
2928
+ const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
2929
+
2930
+ #if defined(__ARM_FEATURE_DOTPROD)
2931
+ // dot product into int32x4_t
2932
+ const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0lz, v1_0l), v0_0hz, v1_0h);
2933
+ const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1lz, v1_1l), v0_1hz, v1_1h);
2934
+
2935
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), x0->d*y0->d);
2936
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), x1->d*y1->d);
2937
+ #else
2938
+ const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lz), vget_low_s8 (v1_0l));
2939
+ const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lz), vget_high_s8(v1_0l));
2940
+ const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hz), vget_low_s8 (v1_0h));
2941
+ const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hz), vget_high_s8(v1_0h));
2942
+
2943
+ const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lz), vget_low_s8 (v1_1l));
2944
+ const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lz), vget_high_s8(v1_1l));
2945
+ const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hz), vget_low_s8 (v1_1h));
2946
+ const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hz), vget_high_s8(v1_1h));
2947
+
2948
+ const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
2949
+ const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
2950
+ const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
2951
+ const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
2952
+
2953
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), x0->d*y0->d);
2954
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), x1->d*y1->d);
2955
+ #endif
2956
+ }
2957
+
2958
+ *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs;
2959
+ #elif defined(__AVX2__)
2960
+ // Initialize accumulator with zeros
2961
+ __m256 acc = _mm256_setzero_ps();
2962
+
2963
+ float summs = 0;
2964
+
2965
+ // Main loop
2966
+ for (int i = 0; i < nb; ++i) {
2967
+ const float * d0 = &x[i].d;
2968
+ const float * d1 = &y[i].d;
2969
+
2970
+ summs += x[i].m * (y[i].s0 + y[i].s1);
2971
+
2972
+ const __m256 d0v = _mm256_broadcast_ss( d0 );
2973
+ const __m256 d1v = _mm256_broadcast_ss( d1 );
2974
+
2975
+ // Compute combined scales
2976
+ const __m256 d0d1 = _mm256_mul_ps( d0v, d1v );
2977
+
2978
+ // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
2979
+ const __m256i bx = bytes_from_nibbles_32(x[i].qs);
2980
+ const __m256i by = _mm256_loadu_si256( (const __m256i *)y[i].qs );
2981
+
2982
+ const __m256 xy = mul_sum_i8_pairs_float(bx, by);
2983
+
2984
+ // Accumulate d0*d1*x*y
2985
+ acc = _mm256_fmadd_ps( d0d1, xy, acc );
2986
+ }
2987
+
2988
+ *s = hsum_float_8(acc) + summs;
2989
+ #else
2990
+ // scalar
2991
+ float sumf = 0.0;
2992
+ for (int i = 0; i < nb; i++) {
2993
+ const float d0 = x[i].d;
2994
+ const float m0 = x[i].m;
2995
+ const float d1 = y[i].d;
2996
+
2997
+ const uint8_t * restrict p0 = x[i].qs;
2998
+ const int8_t * restrict p1 = y[i].qs;
2999
+
3000
+ // TODO: this is very slow ..
3001
+ for (int j = 0; j < QK8_1/2; j++) {
3002
+ const uint8_t v0 = p0[j];
3003
+
3004
+ const float f0 = d0*(v0 & 0x0F) + m0;
3005
+ const float f1 = d0*(v0 >> 4) + m0;
3006
+
3007
+ const float f2 = d1*p1[2*j + 0];
3008
+ const float f3 = d1*p1[2*j + 1];
3009
+
3010
+ sumf += f0*f2 + f1*f3;
3011
+ }
3012
+ }
3013
+ *s = sumf;
3014
+ #endif
3015
+ }
3016
+
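(The `summs` term in ggml_vec_dot_q4_1_q8_1 comes from expanding the affine Q4_1 quantization. With x_j = d_x*q_j + m and y_j = d_y*r_j:

    sum_j x_j*y_j = d_x*d_y * sum_j q_j*r_j + m * sum_j y_j

Comparing with the scalar branch above, y's s0 and s1 appear to hold the two pre-scaled half-block sums of y, so s0 + s1 = sum_j y_j and the offset costs one multiply-add per block instead of one per element.)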
3017
+ static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
3018
+ const int nb = n / QK8_0;
3019
+
3020
+ assert(n % QK8_0 == 0);
3021
+ assert(nb % 2 == 0);
3022
+ assert(QK8_0 == 2*QK4_2);
3023
+
3024
+ const block_q4_2 * restrict x = vx;
3025
+ const block_q8_0 * restrict y = vy;
3026
+
3027
+ #if defined(__ARM_NEON)
3028
+ float32x4_t sumv0 = vdupq_n_f32(0.0f);
3029
+ float32x4_t sumv1 = vdupq_n_f32(0.0f);
3030
+
3031
+ for (int i = 0; i < nb; i += 2) {
3032
+ const block_q4_2 * restrict x0_0 = &x[2*(i + 0) + 0];
3033
+ const block_q4_2 * restrict x0_1 = &x[2*(i + 0) + 1];
3034
+ const block_q4_2 * restrict x1_0 = &x[2*(i + 1) + 0];
3035
+ const block_q4_2 * restrict x1_1 = &x[2*(i + 1) + 1];
3036
+
3037
+ const block_q8_0 * restrict y0 = &y[i + 0];
3038
+ const block_q8_0 * restrict y1 = &y[i + 1];
3039
+
3040
+ const uint8x16_t m4b = vdupq_n_u8(0x0F);
3041
+ const int8x16_t s8b = vdupq_n_s8(0x8);
3042
+
3043
+ const uint8x16_t v0_0 = vcombine_u8(vld1_u8(x0_0->qs), vld1_u8(x0_1->qs));
3044
+ const uint8x16_t v0_1 = vcombine_u8(vld1_u8(x1_0->qs), vld1_u8(x1_1->qs));
3045
+
3046
+ // 4-bit -> 8-bit
3047
+ const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
3048
+ const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
3049
+ const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
3050
+ const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
3051
+
3052
+ // sub 8
3053
+ const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
3054
+ const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
3055
+ const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
3056
+ const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
3057
+
3058
+ // interleave
3059
+ const int8x16_t v0_0lz = vzip1q_s8(v0_0ls, v0_0hs);
3060
+ const int8x16_t v0_0hz = vzip2q_s8(v0_0ls, v0_0hs);
3061
+ const int8x16_t v0_1lz = vzip1q_s8(v0_1ls, v0_1hs);
3062
+ const int8x16_t v0_1hz = vzip2q_s8(v0_1ls, v0_1hs);
3063
+
3064
+ // load y
3065
+ const int8x16_t v1_0l = vld1q_s8(y0->qs);
3066
+ const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
3067
+ const int8x16_t v1_1l = vld1q_s8(y1->qs);
3068
+ const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
3069
+
3070
+ #if defined(__ARM_FEATURE_DOTPROD)
3071
+ sumv0 = vmlaq_n_f32(sumv0, vaddq_f32(
3072
+ vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0lz, v1_0l)), GGML_FP16_TO_FP32(x0_0->d)),
3073
+ vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0hz, v1_0h)), GGML_FP16_TO_FP32(x0_1->d))), y0->d);
3074
+
3075
+ sumv1 = vmlaq_n_f32(sumv1, vaddq_f32(
3076
+ vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_1lz, v1_1l)), GGML_FP16_TO_FP32(x1_0->d)),
3077
+ vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_1hz, v1_1h)), GGML_FP16_TO_FP32(x1_1->d))), y1->d);
3078
+ #else
3079
+ const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lz), vget_low_s8 (v1_0l));
3080
+ const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lz), vget_high_s8(v1_0l));
3081
+ const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hz), vget_low_s8 (v1_0h));
3082
+ const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hz), vget_high_s8(v1_0h));
3083
+
3084
+ const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lz), vget_low_s8 (v1_1l));
3085
+ const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lz), vget_high_s8(v1_1l));
3086
+ const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hz), vget_low_s8 (v1_1h));
3087
+ const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hz), vget_high_s8(v1_1h));
3088
+
3089
+ const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
3090
+ const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
3091
+ const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
3092
+ const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
3093
+
3094
+ sumv0 = vmlaq_n_f32(sumv0, vaddq_f32(
3095
+ vmulq_n_f32(vcvtq_f32_s32(pl0), GGML_FP16_TO_FP32(x0_0->d)),
3096
+ vmulq_n_f32(vcvtq_f32_s32(ph0), GGML_FP16_TO_FP32(x0_1->d))), y0->d);
3097
+
3098
+ sumv1 = vmlaq_n_f32(sumv1, vaddq_f32(
3099
+ vmulq_n_f32(vcvtq_f32_s32(pl1), GGML_FP16_TO_FP32(x1_0->d)),
3100
+ vmulq_n_f32(vcvtq_f32_s32(ph1), GGML_FP16_TO_FP32(x1_1->d))), y1->d);
3101
+ #endif
3102
+ }
3103
+
3104
+ *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
3105
+ #elif defined(__AVX2__)
3106
+ // Initialize accumulator with zeros
3107
+ __m256 acc = _mm256_setzero_ps();
3108
+
3109
+ // Main loop
3110
+ for (int i = 0; i < nb; i++) {
3111
+ /* Compute combined scale for the block */
3112
+ const __m128 d0 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 0].d));
3113
+ const __m128 d1 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 1].d));
3114
+ const __m256 d = _mm256_mul_ps(_mm256_set_m128(d1, d0), _mm256_broadcast_ss(&y[i].d));
3115
+
3116
+ __m128i bx0 = bytes_from_nibbles_16(x[2*i + 0].qs);
3117
+ __m128i bx1 = bytes_from_nibbles_16(x[2*i + 1].qs);
3118
+ __m256i bx = _mm256_set_m128i(bx1, bx0);
3119
+
3120
+ // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
3121
+ const __m256i off = _mm256_set1_epi8(8);
3122
+ bx = _mm256_sub_epi8(bx, off);
3123
+
3124
+ __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
3125
+
3126
+ const __m256 q = mul_sum_i8_pairs_float(bx, by);
3127
+
3128
+ /* Multiply q with scale and accumulate */
3129
+ acc = _mm256_fmadd_ps(d, q, acc);
3130
+ }
3131
+
3132
+ *s = hsum_float_8(acc);
3133
+ #else
3134
+ // scalar
3135
+ float sumf = 0.0;
3136
+ for (int i = 0; i < nb; i++) {
3137
+ const uint8_t * restrict x0 = x[2*i + 0].qs;
3138
+ const uint8_t * restrict x1 = x[2*i + 1].qs;
3139
+ const int8_t * restrict y0 = y[i].qs;
3140
+
3141
+ const float d0 = GGML_FP16_TO_FP32(x[2*i + 0].d);
3142
+ const float d1 = GGML_FP16_TO_FP32(x[2*i + 1].d);
3143
+
3144
+ int sumi_0 = 0;
3145
+ int sumi_1 = 0;
3146
+
3147
+ for (int j = 0; j < QK8_0/4; j++) {
3148
+ const uint8_t v0 = x0[j];
3149
+ const uint8_t v1 = x1[j];
3150
+
3151
+ const int i0_0 = (int8_t) (v0 & 0x0F) - 8;
3152
+ const int i1_0 = (int8_t) (v0 >> 4) - 8;
3153
+
3154
+ const int i0_1 = (int8_t) (v1 & 0x0F) - 8;
3155
+ const int i1_1 = (int8_t) (v1 >> 4) - 8;
3156
+
3157
+ const int i2_0 = y0[2*j + 0];
3158
+ const int i3_0 = y0[2*j + 1];
3159
+
3160
+ const int i2_1 = y0[2*(j + QK8_0/4) + 0];
3161
+ const int i3_1 = y0[2*(j + QK8_0/4) + 1];
3162
+
3163
+ sumi_0 += i0_0*i2_0 + i1_0*i3_0;
3164
+ sumi_1 += i0_1*i2_1 + i1_1*i3_1;
3165
+ }
3166
+
3167
+ sumf += (d0 * y[i].d) * sumi_0;
3168
+ sumf += (d1 * y[i].d) * sumi_1;
3169
+ }
3170
+ *s = sumf;
3171
+ #endif
3172
+ }
3173
+
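(Q4_2 halves the block size, so one Q8_0 block, asserted above as QK8_0 == 2*QK4_2, is paired with two consecutive Q4_2 sub-blocks, each carrying its own fp16 scale. Per Q8_0 block the scalar branch therefore reduces to d_y * (d_0*s_lo + d_1*s_hi), where s_lo and s_hi are the integer dot products of the first and second 16 re-centred quants against the matching halves of the Q8_0 block, and d_0, d_1 are converted from fp16 via GGML_FP16_TO_FP32 before use.)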
3174
+ static void ggml_vec_dot_q4_3_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
3175
+ const int nb = n / QK8_1;
3176
+
3177
+ assert(n % QK8_1 == 0);
3178
+ assert(nb % 2 == 0);
3179
+ assert(QK8_1 == 2*QK4_3);
3180
+
3181
+ const block_q4_3 * restrict x = vx;
3182
+ const block_q8_1 * restrict y = vy;
3183
+
3184
+ #if defined(__ARM_NEON)
3185
+ float32x4_t sumv0 = vdupq_n_f32(0.0f);
3186
+ float32x4_t sumv1 = vdupq_n_f32(0.0f);
3187
+
3188
+ float summs0 = 0.0f;
3189
+ float summs1 = 0.0f;
3190
+
3191
+ for (int i = 0; i < nb; ++i) {
3192
+ const block_q4_3 * restrict x0_0 = &x[2*(i + 0) + 0];
3193
+ const block_q4_3 * restrict x0_1 = &x[2*(i + 0) + 1];
3194
+
3195
+ const block_q8_1 * restrict y0 = &y[i + 0];
3196
+
3197
+ summs0 += GGML_FP16_TO_FP32(x0_0->m) * y0->s0;
3198
+ summs1 += GGML_FP16_TO_FP32(x0_1->m) * y0->s1;
3199
+
3200
+ const uint8x16_t v0_0 = vcombine_u8(vld1_u8(x0_0->qs), vld1_u8(x0_1->qs));
3201
+
3202
+ // 4-bit -> 8-bit
3203
+ const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, vdupq_n_u8(0x0F)));
3204
+ const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
3205
+
3206
+ // interleave
3207
+ const int8x16_t v0_0lz = vzip1q_s8(v0_0l, v0_0h);
3208
+ const int8x16_t v0_0hz = vzip2q_s8(v0_0l, v0_0h);
3209
+
3210
+ // load y
3211
+ const int8x16_t v1_0l = vld1q_s8(y0->qs);
3212
+ const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
3213
+
3214
+ const float x0_0d = GGML_FP16_TO_FP32(x0_0->d);
3215
+ const float x0_1d = GGML_FP16_TO_FP32(x0_1->d);
3216
+
3217
+ #if defined(__ARM_FEATURE_DOTPROD)
3218
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0lz, v1_0l)), x0_0d*y0->d);
3219
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0hz, v1_0h)), x0_1d*y0->d);
3220
+ #else
3221
+ const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lz), vget_low_s8 (v1_0l));
3222
+ const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lz), vget_high_s8(v1_0l));
3223
+ const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hz), vget_low_s8 (v1_0h));
3224
+ const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hz), vget_high_s8(v1_0h));
3225
+
3226
+ const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
3227
+ const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
3228
+
3229
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(pl0), x0_0d*y0->d);
3230
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(ph0), x0_1d*y0->d);
3231
+ #endif
3232
+ }
3233
+
3234
+ *s = vaddvq_f32(vaddq_f32(sumv0, sumv1)) + summs0 + summs1;
3235
+ #elif defined(__AVX2__)
3236
+ // Initialize accumulator with zeros
3237
+ __m256 acc = _mm256_setzero_ps();
3238
+ float summs = 0.0f;
3239
+
3240
+ // Main loop
3241
+ for (int i = 0; i < nb; i++) {
3242
+ const __m128 d0 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 0].d));
3243
+ const __m128 d1 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 1].d));
3244
+ const __m256 dx = _mm256_set_m128(d1, d0);
3245
+
3246
+ summs += GGML_FP16_TO_FP32(x[2*i + 0].m) * y[i].s0
3247
+ + GGML_FP16_TO_FP32(x[2*i + 1].m) * y[i].s1;
3248
+
3249
+ const __m128i bx0 = bytes_from_nibbles_16(x[2*i + 0].qs);
3250
+ const __m128i bx1 = bytes_from_nibbles_16(x[2*i + 1].qs);
3251
+ const __m256i bx = _mm256_set_m128i(bx1, bx0);
3252
+
3253
+ const __m256 dy = _mm256_broadcast_ss(&y[i].d);
3254
+ const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
3255
+
3256
+ const __m256 q = mul_sum_i8_pairs_float(bx, by);
3257
+
3258
+ acc = _mm256_fmadd_ps(q, _mm256_mul_ps(dx, dy), acc);
3259
+ }
3260
+
3261
+ *s = hsum_float_8(acc) + summs;
3262
+ #else
3263
+ // scalar
3264
+ float sumf = 0.0;
3265
+ for (int i = 0; i < nb; i++) {
3266
+ const uint8_t * restrict x0 = x[2*i + 0].qs;
3267
+ const uint8_t * restrict x1 = x[2*i + 1].qs;
3268
+ const int8_t * restrict y0 = y[i].qs;
3269
+
3270
+ const float d0 = GGML_FP16_TO_FP32(x[2*i + 0].d);
3271
+ const float m0 = GGML_FP16_TO_FP32(x[2*i + 0].m);
3272
+ const float d1 = GGML_FP16_TO_FP32(x[2*i + 1].d);
3273
+ const float m1 = GGML_FP16_TO_FP32(x[2*i + 1].m);
3274
+
3275
+ int sxy_0 = 0;
3276
+ int sxy_1 = 0;
3277
+
3278
+ for (int j = 0; j < QK8_1/4; j++) {
3279
+ const uint8_t v0 = x0[j];
3280
+ const uint8_t v1 = x1[j];
3281
+
3282
+ const int x0_0 = v0 & 0x0F;
3283
+ const int x1_0 = v0 >> 4;
3284
+
3285
+ const int x0_1 = v1 & 0x0F;
3286
+ const int x1_1 = v1 >> 4;
3287
+
3288
+ const int y0_0 = y0[2*j + 0];
3289
+ const int y1_0 = y0[2*j + 1];
3290
+
3291
+ const int y0_1 = y0[2*(j + QK8_1/4) + 0];
3292
+ const int y1_1 = y0[2*(j + QK8_1/4) + 1];
3293
+
3294
+ sxy_0 += x0_0*y0_0 + x1_0*y1_0;
3295
+ sxy_1 += x0_1*y0_1 + x1_1*y1_1;
3296
+ }
3297
+
3298
+ sumf += (d0*sxy_0 + d1*sxy_1)*y[i].d + m0*y[i].s0 + m1*y[i].s1;
3299
+ }
3300
+ *s = sumf;
3301
+ #endif
3302
+ }
3303
+
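(Q4_3 extends the same 2:1 pairing with a per-sub-block offset m, which is why it dots against Q8_1 rather than Q8_0: each offset pairs with the half-block sum of y it overlaps, giving the scalar accumulation (d_0*sxy_0 + d_1*sxy_1)*d_y + m_0*s_0 + m_1*s_1 seen above.)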
3304
+ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
3305
+ const int nb = n / QK8_0;
3306
+
3307
+ assert(n % QK8_0 == 0);
3308
+ assert(nb % 2 == 0);
3309
+ assert(QK8_0 == QK5_0);
3310
+
3311
+ const block_q5_0 * restrict x = vx;
3312
+ const block_q8_0 * restrict y = vy;
2451
3313
 
2452
- // Convert int32_t to float
2453
- __m256 p = _mm256_cvtepi32_ps( _mm256_set_m128i( i32[0], i32[1] ));
2454
- // Apply the scale, and accumulate
2455
- acc = _mm256_add_ps(_mm256_mul_ps( d, p ), acc);
2456
- }
3314
+ #if defined(__ARM_NEON)
3315
+ float32x4_t sumv = vdupq_n_f32(0.0f);
2457
3316
 
2458
- // Return horizontal sum of the acc vector
2459
- __m128 res = _mm256_extractf128_ps( acc, 1 );
2460
- res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
2461
- res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
2462
- res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
3317
+ uint64_t tmp[4];
2463
3318
 
2464
- sumf = _mm_cvtss_f32( res );
2465
- #elif defined(__wasm_simd128__)
2466
- // wasm simd
2467
- float sum0 = 0.0f;
2468
- float sum1 = 0.0f;
3319
+ for (int i = 0; i < nb; ++i) {
3320
+ const block_q5_0 * restrict x0 = &x[i];
3321
+ const block_q8_0 * restrict y0 = &y[i];
2469
3322
 
2470
- for (int i = 0; i < nb; i += 2) {
2471
- const block_q4_0 * restrict x0 = &x[i + 0];
2472
- const block_q4_0 * restrict y0 = &y[i + 0];
2473
- const block_q4_0 * restrict x1 = &x[i + 1];
2474
- const block_q4_0 * restrict y1 = &y[i + 1];
3323
+ const uint8x16_t m4b = vdupq_n_u8(0x0F);
3324
+ const int8x16_t s16b = vdupq_n_s8(0x10);
2475
3325
 
2476
- const v128_t m4b = wasm_u8x16_splat(0xf);
2477
- const v128_t s8b = wasm_i8x16_splat(0x8);
3326
+ // extract the 5th bit
3327
+ uint32_t qh;
3328
+ memcpy(&qh, x0->qh, sizeof(qh));
2478
3329
 
2479
- const v128_t v0_0 = wasm_v128_load(x0->qs);
2480
- const v128_t v0_1 = wasm_v128_load(y0->qs);
2481
- const v128_t v1_0 = wasm_v128_load(x1->qs);
2482
- const v128_t v1_1 = wasm_v128_load(y1->qs);
3330
+ tmp[0] = table_b2b_u[(qh >> 0) & 0xFF];
3331
+ tmp[1] = table_b2b_u[(qh >> 8) & 0xFF];
3332
+ tmp[2] = table_b2b_u[(qh >> 16) & 0xFF];
3333
+ tmp[3] = table_b2b_u[(qh >> 24) ];
2483
3334
 
2484
- // 4-bit -> 8-bit
2485
- const v128_t v0_0l = wasm_v128_and(v0_0, m4b);
2486
- const v128_t v1_0l = wasm_v128_and(v1_0, m4b);
3335
+ const int8x16_t qhl = vld1q_s8((const int8_t *)(tmp + 0));
3336
+ const int8x16_t qhh = vld1q_s8((const int8_t *)(tmp + 2));
2487
3337
 
2488
- const v128_t v0_0h = wasm_u8x16_shr(v0_0, 4);
2489
- const v128_t v1_0h = wasm_u8x16_shr(v1_0, 4);
3338
+ const uint8x16_t v0 = vld1q_u8(x0->qs);
2490
3339
 
2491
- const v128_t v0_1l = wasm_v128_and(v0_1, m4b);
2492
- const v128_t v1_1l = wasm_v128_and(v1_1, m4b);
3340
+ // 4-bit -> 8-bit
3341
+ const int8x16_t v0l = vreinterpretq_s8_u8(vandq_u8 (v0, m4b));
3342
+ const int8x16_t v0h = vreinterpretq_s8_u8(vshrq_n_u8(v0, 4));
2493
3343
 
2494
- const v128_t v0_1h = wasm_u8x16_shr(v0_1, 4);
2495
- const v128_t v1_1h = wasm_u8x16_shr(v1_1, 4);
3344
+ // interleave
3345
+ const int8x16_t v0lz = vzip1q_s8(v0l, v0h);
3346
+ const int8x16_t v0hz = vzip2q_s8(v0l, v0h);
2496
3347
 
2497
- // sub 8
2498
- const v128_t v0_0ls = wasm_i8x16_sub(v0_0l, s8b);
2499
- const v128_t v1_0ls = wasm_i8x16_sub(v1_0l, s8b);
3348
+ // add high bit and sub 16
3349
+ const int8x16_t v0lf = vsubq_s8(vorrq_s8(v0lz, qhl), s16b);
3350
+ const int8x16_t v0hf = vsubq_s8(vorrq_s8(v0hz, qhh), s16b);
2500
3351
 
2501
- const v128_t v0_0hs = wasm_i8x16_sub(v0_0h, s8b);
2502
- const v128_t v1_0hs = wasm_i8x16_sub(v1_0h, s8b);
3352
+ // load y
3353
+ const int8x16_t v1l = vld1q_s8(y0->qs);
3354
+ const int8x16_t v1h = vld1q_s8(y0->qs + 16);
2503
3355
 
2504
- const v128_t v0_1ls = wasm_i8x16_sub(v0_1l, s8b);
2505
- const v128_t v1_1ls = wasm_i8x16_sub(v1_1l, s8b);
3356
+ const float x0d = GGML_FP16_TO_FP32(x0->d);
2506
3357
 
2507
- const v128_t v0_1hs = wasm_i8x16_sub(v0_1h, s8b);
2508
- const v128_t v1_1hs = wasm_i8x16_sub(v1_1h, s8b);
3358
+ #if defined(__ARM_FEATURE_DOTPROD)
3359
+ sumv = vmlaq_n_f32(sumv, vcvtq_f32_s32(vaddq_s32(
3360
+ vdotq_s32(vdupq_n_s32(0), v0lf, v1l),
3361
+ vdotq_s32(vdupq_n_s32(0), v0hf, v1h))), x0d*y0->d);
3362
+ #else
3363
+ const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0lf), vget_low_s8 (v1l));
3364
+ const int16x8_t pl0h = vmull_s8(vget_high_s8(v0lf), vget_high_s8(v1l));
3365
+ const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0hf), vget_low_s8 (v1h));
3366
+ const int16x8_t ph0h = vmull_s8(vget_high_s8(v0hf), vget_high_s8(v1h));
2509
3367
 
2510
- // dot product into int16x8_t
2511
- const v128_t pl0l = wasm_i16x8_mul(wasm_i16x8_extend_low_i8x16(v0_0ls), wasm_i16x8_extend_low_i8x16(v1_0ls));
2512
- const v128_t pl0h = wasm_i16x8_mul(wasm_i16x8_extend_high_i8x16(v0_0ls), wasm_i16x8_extend_high_i8x16(v1_0ls));
3368
+ const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
3369
+ const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
2513
3370
 
2514
- const v128_t ph0l = wasm_i16x8_mul(wasm_i16x8_extend_low_i8x16(v0_0hs), wasm_i16x8_extend_low_i8x16(v1_0hs));
2515
- const v128_t ph0h = wasm_i16x8_mul(wasm_i16x8_extend_high_i8x16(v0_0hs), wasm_i16x8_extend_high_i8x16(v1_0hs));
3371
+ sumv = vmlaq_n_f32(sumv, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), x0d*y0->d);
3372
+ #endif
3373
+ }
2516
3374
 
2517
- const v128_t pl1l = wasm_i16x8_mul(wasm_i16x8_extend_low_i8x16(v0_1ls), wasm_i16x8_extend_low_i8x16(v1_1ls));
2518
- const v128_t pl1h = wasm_i16x8_mul(wasm_i16x8_extend_high_i8x16(v0_1ls), wasm_i16x8_extend_high_i8x16(v1_1ls));
3375
+ *s = vaddvq_f32(sumv);
3376
+ #elif defined(__AVX2__)
3377
+ // Initialize accumulator with zeros
3378
+ __m256 acc = _mm256_setzero_ps();
2519
3379
 
2520
- const v128_t ph1l = wasm_i16x8_mul(wasm_i16x8_extend_low_i8x16(v0_1hs), wasm_i16x8_extend_low_i8x16(v1_1hs));
2521
- const v128_t ph1h = wasm_i16x8_mul(wasm_i16x8_extend_high_i8x16(v0_1hs), wasm_i16x8_extend_high_i8x16(v1_1hs));
3380
+ // Main loop
3381
+ for (int i = 0; i < nb; i++) {
3382
+ /* Compute combined scale for the block */
3383
+ const __m256 d = _mm256_mul_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d)), _mm256_broadcast_ss(&y[i].d));
2522
3384
 
2523
- const v128_t pl_0 = wasm_i16x8_add(pl0l, pl0h);
2524
- const v128_t ph_0 = wasm_i16x8_add(ph0l, ph0h);
3385
+ __m256i bx = bytes_from_nibbles_32(x[i].qs);
3386
+ __m256i bxhi = bytes_from_bits_32(x[i].qh);
3387
+ bxhi = _mm256_andnot_si256(bxhi, _mm256_set1_epi8((char)0xF0));
3388
+ bx = _mm256_or_si256(bx, bxhi);
2525
3389
 
2526
- const v128_t pl_1 = wasm_i16x8_add(pl1l, pl1h);
2527
- const v128_t ph_1 = wasm_i16x8_add(ph1l, ph1h);
3390
+ __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
2528
3391
 
2529
- const v128_t p_0 = wasm_i16x8_add(pl_0, ph_0);
2530
- const v128_t p_1 = wasm_i16x8_add(pl_1, ph_1);
3392
+ const __m256 q = mul_sum_i8_pairs_float(bx, by);
2531
3393
 
2532
- sum0 += x0->d * y0->d * (
2533
- wasm_i16x8_extract_lane(p_0, 0) + wasm_i16x8_extract_lane(p_0, 1) +
2534
- wasm_i16x8_extract_lane(p_0, 2) + wasm_i16x8_extract_lane(p_0, 3) +
2535
- wasm_i16x8_extract_lane(p_0, 4) + wasm_i16x8_extract_lane(p_0, 5) +
2536
- wasm_i16x8_extract_lane(p_0, 6) + wasm_i16x8_extract_lane(p_0, 7));
2537
- sum1 += x1->d * y1->d * (
2538
- wasm_i16x8_extract_lane(p_1, 0) + wasm_i16x8_extract_lane(p_1, 1) +
2539
- wasm_i16x8_extract_lane(p_1, 2) + wasm_i16x8_extract_lane(p_1, 3) +
2540
- wasm_i16x8_extract_lane(p_1, 4) + wasm_i16x8_extract_lane(p_1, 5) +
2541
- wasm_i16x8_extract_lane(p_1, 6) + wasm_i16x8_extract_lane(p_1, 7));
3394
+ /* Multiply q with scale and accumulate */
3395
+ acc = _mm256_fmadd_ps(d, q, acc);
2542
3396
  }
2543
3397
 
2544
- sumf = sum0 + sum1;
3398
+ *s = hsum_float_8(acc);
2545
3399
  #else
2546
3400
  // scalar
3401
+ float sumf = 0.0;
2547
3402
  for (int i = 0; i < nb; i++) {
2548
- const float d0 = x[i].d;
2549
- const float d1 = y[i].d;
3403
+ const uint8_t * restrict x0 = x[i].qs;
3404
+ const int8_t * restrict y0 = y[i].qs;
2550
3405
 
2551
- const uint8_t * restrict p0 = x[i].qs;
2552
- const uint8_t * restrict p1 = y[i].qs;
3406
+ uint32_t qh;
3407
+ memcpy(&qh, x[i].qh, sizeof(qh));
2553
3408
 
2554
- int sumi = 0;
2555
- for (int j = 0; j < QK4_0/2; j++) {
2556
- const uint8_t v0 = p0[j];
2557
- const uint8_t v1 = p1[j];
3409
+ const float d = GGML_FP16_TO_FP32(x[i].d);
3410
+
3411
+ int sxy = 0;
2558
3412
 
2559
- const int i0 = (v0 & 0xf) - 8;
2560
- const int i1 = (v0 >> 4) - 8;
3413
+ for (int j = 0; j < QK8_0/2; j++) {
3414
+ const uint8_t v0 = x0[j];
2561
3415
 
2562
- const int i2 = (v1 & 0xf) - 8;
2563
- const int i3 = (v1 >> 4) - 8;
3416
+ const int x0_0h = ((qh & (1 << (2*j + 0))) >> (2*j + 0)) << 4;
3417
+ const int x1_0h = ((qh & (1 << (2*j + 1))) >> (2*j + 1)) << 4;
2564
3418
 
2565
- sumi += i0*i2 + i1*i3;
3419
+ const int x0_0 = ((v0 & 0x0F) | x0_0h) - 16;
3420
+ const int x1_0 = ((v0 >> 4) | x1_0h) - 16;
3421
+
3422
+ const int y0_0 = y0[2*j + 0];
3423
+ const int y1_0 = y0[2*j + 1];
3424
+
3425
+ sxy += x0_0*y0_0 + x1_0*y1_0;
2566
3426
  }
2567
- sumf += d0 * d1 * sumi;
2568
- }
2569
- #endif
2570
3427
 
3428
+ sumf += (d*sxy)*y[i].d;
3429
+ }
2571
3430
  *s = sumf;
3431
+ #endif
2572
3432
  }
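(The fiddly part of ggml_vec_dot_q5_0_q8_0 is reinserting the fifth bit that is stored separately in qh. A hypothetical standalone helper mirroring the scalar decode above, with illustrative names and the layout of 16 packed nibble pairs in qs plus 32 high bits in qh taken from the loads in the function, could look like:

    #include <stdint.h>

    /* Decode the j-th nibble pair of a Q5_0 block into two signed 5-bit values. */
    static inline void q5_0_decode_pair(const uint8_t * qs, uint32_t qh, int j,
                                        int * v0, int * v1) {
        const int b0 = (qh >> (2*j + 0)) & 1;            /* 5th bit of element 2*j     */
        const int b1 = (qh >> (2*j + 1)) & 1;            /* 5th bit of element 2*j + 1 */
        *v0 = (int)((qs[j] & 0x0F) | (b0 << 4)) - 16;    /* 5-bit value in [-16, 15]   */
        *v1 = (int)((qs[j] >>   4) | (b1 << 4)) - 16;
    }

On NEON the same reconstruction is done 16 lanes at a time: the table_b2b_u lookups turn each byte of qh into eight bytes that are OR'd straight into bit 4 of the nibble vectors before the common sub-16 step.)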
2573
3433
 
2574
- static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
2575
- const int nb = n / QK4_1;
3434
+ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
3435
+ const int nb = n / QK8_1;
2576
3436
 
2577
- const block_q4_1 * restrict x = vx;
2578
- const block_q4_1 * restrict y = vy;
2579
-
2580
- float sumf = 0.0;
3437
+ assert(n % QK8_1 == 0);
3438
+ assert(nb % 2 == 0);
3439
+ assert(QK8_1 == QK5_1);
2581
3440
 
2582
- #if defined(__AVX2__)
2583
- // Initialize accumulator with zeros
2584
- __m256 acc = _mm256_setzero_ps();
2585
- // Accumulator for constant offsets
2586
- float acc_offset = 0.0f;
3441
+ const block_q5_1 * restrict x = vx;
3442
+ const block_q8_1 * restrict y = vy;
2587
3443
 
2588
- // Main loop
2589
- for (int i = 0; i < nb; ++i) {
2590
- const float * d0 = &x[i].d;
2591
- const float * d1 = &y[i].d;
3444
+ #if defined(__ARM_NEON)
3445
+ float32x4_t sumv = vdupq_n_f32(0.0f);
2592
3446
 
2593
- const float * m0 = &x[i].m;
2594
- const float * m1 = &y[i].m;
3447
+ float summs = 0.0f;
2595
3448
 
2596
- const __m256 d0v = _mm256_broadcast_ss( d0 );
2597
- const __m256 d1v = _mm256_broadcast_ss( d1 );
2598
- const __m256 m0v = _mm256_broadcast_ss( m0 );
2599
- const __m256 m1v = _mm256_broadcast_ss( m1 );
3449
+ uint64_t tmp[4];
2600
3450
 
2601
- // Compute combined scale for the block
2602
- const __m256 scale_01 = _mm256_mul_ps( d0v, d1v );
3451
+ for (int i = 0; i < nb; ++i) {
3452
+ const block_q5_1 * restrict x0 = &x[i];
3453
+ const block_q8_1 * restrict y0 = &y[i];
2603
3454
 
2604
- // Compute cross scales for the block
2605
- const __m256 scale_0 = _mm256_mul_ps( d0v, m1v );
2606
- const __m256 scale_1 = _mm256_mul_ps( m0v, d1v );
2607
- const __m256 cross_scales = _mm256_blend_ps( scale_0, scale_1, 0xAA /* 0b10101010 */ );
3455
+ summs += GGML_FP16_TO_FP32(x0->m) * (y0->s0 + y0->s1);
2608
3456
 
2609
- // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
2610
- __m256i bx = bytesFromNibbles( x[i].qs );
2611
- __m256i by = bytesFromNibbles( y[i].qs );
2612
-
2613
- // Now we have a vector with bytes in [ 0 .. 15 ] interval.
2614
-
2615
- // Sign-extend first 16 signed bytes into int16_t
2616
- __m256i x16 = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( bx ) );
2617
- __m256i y16 = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( by ) );
2618
- // Compute products of int16_t integers, add pairwise
2619
- __m256i i32 = _mm256_madd_epi16( x16, y16 );
2620
-
2621
- // Sign-extend last 16 signed bytes into int16_t vectors
2622
- __m256i x16_h = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( bx, 1 ) );
2623
- __m256i y16_h = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( by, 1 ) );
2624
- // Accumulate products of int16_t integers
2625
- i32 = _mm256_add_epi32( i32, _mm256_madd_epi16( x16_h, y16_h ) );
2626
-
2627
- // compute sums of unsigned bytes in bx, by in blocks of 8.
2628
- // This results in a layout like X100 0000 X200 0000 X300 0000 X400 0000,
2629
- // which we then interleave as X100 Y100 X200 Y200 X300 Y300 X400 Y400.
2630
- // so if we then cast to 8 singles, we get 8 floats like [ x0_7, y0_7, x8_15, y8_15, x16_23, y16_23, x24_31, y24_31 ]
2631
- __m256i xsumi = _mm256_sad_epu8( bx, _mm256_setzero_si256() );
2632
- __m256i ysumi = _mm256_sad_epu8( by, _mm256_setzero_si256() );
2633
- __m256i sumsi = _mm256_or_si256( xsumi, _mm256_slli_si256( ysumi, 4 ) );
2634
- __m256 sums = _mm256_cvtepi32_ps( sumsi );
3457
+ // extract the 5th bit
3458
+ uint32_t qh;
3459
+ memcpy(&qh, x0->qh, sizeof(qh));
2635
3460
 
2636
- // Convert int32_t to float
2637
- __m256 p = _mm256_cvtepi32_ps( i32 );
2638
- // Apply the scale, and accumulate
2639
- // acc += d0*d1*x*y + d0*m1*x + d1*m0*y
2640
- acc = _mm256_fmadd_ps( scale_01, p, acc );
2641
- acc = _mm256_fmadd_ps( cross_scales, sums, acc );
2642
- // acc_offset += m0*m1 (for each entry in the block)
2643
- acc_offset += (*m0)*(*m1);
2644
- }
3461
+ tmp[0] = table_b2b_u[(qh >> 0) & 0xFF];
3462
+ tmp[1] = table_b2b_u[(qh >> 8) & 0xFF];
3463
+ tmp[2] = table_b2b_u[(qh >> 16) & 0xFF];
3464
+ tmp[3] = table_b2b_u[(qh >> 24) ];
2645
3465
 
2646
- // Return horizontal sum of the acc vector
2647
- __m128 res = _mm256_extractf128_ps( acc, 1 );
2648
- res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
2649
- res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
2650
- res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
3466
+ const int8x16_t qhl = vld1q_s8((const int8_t *)(tmp + 0));
3467
+ const int8x16_t qhh = vld1q_s8((const int8_t *)(tmp + 2));
2651
3468
 
2652
- sumf = _mm_cvtss_f32( res ) + acc_offset * QK4_1;
2653
- #elif defined(__ARM_NEON)
2654
- float sum00 = 0.0f;
2655
- float sum01 = 0.0f;
2656
- float sum10 = 0.0f;
2657
- float sum11 = 0.0f;
3469
+ const uint8x16_t v0 = vld1q_u8(x0->qs);
2658
3470
 
2659
- for (int i = 0; i < nb; i += 2) {
2660
- const block_q4_1 * restrict x0 = &x[i + 0];
2661
- const block_q4_1 * restrict y0 = &y[i + 0];
2662
- const block_q4_1 * restrict x1 = &x[i + 1];
2663
- const block_q4_1 * restrict y1 = &y[i + 1];
3471
+ // 4-bit -> 8-bit
3472
+ const int8x16_t v0l = vreinterpretq_s8_u8(vandq_u8 (v0, vdupq_n_u8(0x0F)));
3473
+ const int8x16_t v0h = vreinterpretq_s8_u8(vshrq_n_u8(v0, 4));
2664
3474
 
2665
- const uint8x16_t m4b = vdupq_n_u8(0xf);
3475
+ // interleave
3476
+ const int8x16_t v0lz = vzip1q_s8(v0l, v0h);
3477
+ const int8x16_t v0hz = vzip2q_s8(v0l, v0h);
2666
3478
 
2667
- const uint8x16_t v0_0 = vld1q_u8(x0->qs);
2668
- const uint8x16_t v1_0 = vld1q_u8(y0->qs);
2669
- const uint8x16_t v0_1 = vld1q_u8(x1->qs);
2670
- const uint8x16_t v1_1 = vld1q_u8(y1->qs);
3479
+ // add
3480
+ const int8x16_t v0lf = vorrq_s8(v0lz, qhl);
3481
+ const int8x16_t v0hf = vorrq_s8(v0hz, qhh);
2671
3482
 
2672
- // 4-bit -> 8-bit
2673
- const uint8x16_t v0_0l = vandq_u8(v0_0, m4b);
2674
- const uint8x16_t v1_0l = vandq_u8(v1_0, m4b);
2675
- const uint8x16_t v0_0h = vshrq_n_u8(v0_0, 4);
2676
- const uint8x16_t v1_0h = vshrq_n_u8(v1_0, 4);
3483
+ // load y
3484
+ const int8x16_t v1l = vld1q_s8(y0->qs);
3485
+ const int8x16_t v1h = vld1q_s8(y0->qs + 16);
2677
3486
 
2678
- const uint8x16_t v0_1l = vandq_u8(v0_1, m4b);
2679
- const uint8x16_t v1_1l = vandq_u8(v1_1, m4b);
2680
- const uint8x16_t v0_1h = vshrq_n_u8(v0_1, 4);
2681
- const uint8x16_t v1_1h = vshrq_n_u8(v1_1, 4);
3487
+ const float x0d = GGML_FP16_TO_FP32(x0->d);
2682
3488
 
2683
- sum00 += x0->m*y0->m;
2684
- sum01 += y0->m*x0->d*((uint16_t)vaddvq_u8(v0_0l) + (uint16_t)vaddvq_u8(v0_0h));
2685
- sum10 += x0->m*y0->d*((uint16_t)vaddvq_u8(v1_0l) + (uint16_t)vaddvq_u8(v1_0h));
3489
+ #if defined(__ARM_FEATURE_DOTPROD)
3490
+ sumv = vmlaq_n_f32(sumv, vcvtq_f32_s32(vaddq_s32(
3491
+ vdotq_s32(vdupq_n_s32(0), v0lf, v1l),
3492
+ vdotq_s32(vdupq_n_s32(0), v0hf, v1h))), x0d*y0->d);
3493
+ #else
3494
+ const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0lf), vget_low_s8 (v1l));
3495
+ const int16x8_t pl0h = vmull_s8(vget_high_s8(v0lf), vget_high_s8(v1l));
3496
+ const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0hf), vget_low_s8 (v1h));
3497
+ const int16x8_t ph0h = vmull_s8(vget_high_s8(v0hf), vget_high_s8(v1h));
2686
3498
 
2687
- sum00 += x1->m*y1->m;
2688
- sum01 += y1->m*x1->d*((uint16_t)vaddvq_u8(v0_1l) + (uint16_t)vaddvq_u8(v0_1h));
2689
- sum10 += x1->m*y1->d*((uint16_t)vaddvq_u8(v1_1l) + (uint16_t)vaddvq_u8(v1_1h));
3499
+ const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
3500
+ const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
2690
3501
 
2691
- #if defined(__ARM_FEATURE_DOTPROD)
2692
- // dot product into int32x4_t
2693
- uint32x4_t p_0 = vdotq_u32(vdupq_n_u32(0), v0_0l, v1_0l);
2694
- uint32x4_t p_1 = vdotq_u32(vdupq_n_u32(0), v0_1l, v1_1l);
3502
+ sumv = vmlaq_n_f32(sumv, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), x0d*y0->d);
3503
+ #endif
3504
+ }
2695
3505
 
2696
- p_0 = vdotq_u32(p_0, v0_0h, v1_0h);
2697
- p_1 = vdotq_u32(p_1, v0_1h, v1_1h);
3506
+ *s = vaddvq_f32(sumv) + summs;
3507
+ #elif defined(__AVX2__)
3508
+ // Initialize accumulator with zeros
3509
+ __m256 acc = _mm256_setzero_ps();
3510
+ float summs = 0.0f;
2698
3511
 
2699
- sum11 += x0->d*y0->d*vaddvq_u32(p_0);
2700
- sum11 += x1->d*y1->d*vaddvq_u32(p_1);
2701
- #else
2702
- const uint16x8_t pl0l = vmull_u8(vget_low_u8 (v0_0l), vget_low_u8 (v1_0l));
2703
- const uint16x8_t pl0h = vmull_u8(vget_high_u8(v0_0l), vget_high_u8(v1_0l));
2704
- const uint16x8_t ph0l = vmull_u8(vget_low_u8 (v0_0h), vget_low_u8 (v1_0h));
2705
- const uint16x8_t ph0h = vmull_u8(vget_high_u8(v0_0h), vget_high_u8(v1_0h));
3512
+ // Main loop
3513
+ for (int i = 0; i < nb; i++) {
3514
+ const __m256 dx = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d));
2706
3515
 
2707
- const uint16x8_t pl1l = vmull_u8(vget_low_u8 (v0_1l), vget_low_u8 (v1_1l));
2708
- const uint16x8_t pl1h = vmull_u8(vget_high_u8(v0_1l), vget_high_u8(v1_1l));
2709
- const uint16x8_t ph1l = vmull_u8(vget_low_u8 (v0_1h), vget_low_u8 (v1_1h));
2710
- const uint16x8_t ph1h = vmull_u8(vget_high_u8(v0_1h), vget_high_u8(v1_1h));
3516
+ summs += GGML_FP16_TO_FP32(x[i].m) * (y[i].s0 + y[i].s1);
2711
3517
 
2712
- const uint16x8_t pl_0 = vaddq_u16(pl0l, pl0h);
2713
- const uint16x8_t ph_0 = vaddq_u16(ph0l, ph0h);
3518
+ __m256i bx = bytes_from_nibbles_32(x[i].qs);
3519
+ __m256i bxhi = bytes_from_bits_32(x[i].qh);
3520
+ bxhi = _mm256_and_si256(bxhi, _mm256_set1_epi8(0x10));
3521
+ bx = _mm256_or_si256(bx, bxhi);
2714
3522
 
2715
- const uint16x8_t pl_1 = vaddq_u16(pl1l, pl1h);
2716
- const uint16x8_t ph_1 = vaddq_u16(ph1l, ph1h);
3523
+ const __m256 dy = _mm256_broadcast_ss(&y[i].d);
3524
+ const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
2717
3525
 
2718
- const uint16x8_t p_0 = vaddq_u16(pl_0, ph_0);
2719
- const uint16x8_t p_1 = vaddq_u16(pl_1, ph_1);
3526
+ const __m256 q = mul_sum_i8_pairs_float(bx, by);
2720
3527
 
2721
- sum11 += x0->d*y0->d*vaddvq_u16(p_0);
2722
- sum11 += x1->d*y1->d*vaddvq_u16(p_1);
2723
- #endif
3528
+ acc = _mm256_fmadd_ps(q, _mm256_mul_ps(dx, dy), acc);
2724
3529
  }
2725
3530
 
2726
- sumf = QK4_1*sum00 + sum01 + sum10 + sum11;
3531
+ *s = hsum_float_8(acc) + summs;
2727
3532
  #else
2728
- // scalar
3533
+ float sumf = 0.0;
3534
+
2729
3535
  for (int i = 0; i < nb; i++) {
2730
- const float d0 = x[i].d;
2731
- const float d1 = y[i].d;
3536
+ const uint8_t * restrict x0 = x[i].qs;
3537
+ const int8_t * restrict y0 = y[i].qs;
2732
3538
 
2733
- const float m0 = x[i].m;
2734
- const float m1 = y[i].m;
3539
+ uint32_t qh;
3540
+ memcpy(&qh, x[i].qh, sizeof(qh));
2735
3541
 
2736
- const uint8_t * restrict p0 = x[i].qs;
2737
- const uint8_t * restrict p1 = y[i].qs;
3542
+ const float d = GGML_FP16_TO_FP32(x[i].d);
3543
+ const float m = GGML_FP16_TO_FP32(x[i].m);
2738
3544
 
2739
- for (int j = 0; j < QK4_1/2; j++) {
2740
- const uint8_t v0 = p0[j];
2741
- const uint8_t v1 = p1[j];
3545
+ int sxy = 0;
2742
3546
 
2743
- const float f0 = d0*(v0 & 0xf) + m0;
2744
- const float f1 = d0*(v0 >> 4) + m0;
3547
+ for (int j = 0; j < QK8_1/2; j++) {
3548
+ const uint8_t v0 = x0[j];
2745
3549
 
2746
- const float f2 = d1*(v1 & 0xf) + m1;
2747
- const float f3 = d1*(v1 >> 4) + m1;
3550
+ const int x0_0h = ((qh & (1 << (2*j + 0))) >> (2*j + 0)) << 4;
3551
+ const int x1_0h = ((qh & (1 << (2*j + 1))) >> (2*j + 1)) << 4;
2748
3552
 
2749
- sumf += f0*f2 + f1*f3;
3553
+ const int x0_0 = (v0 & 0x0F) | x0_0h;
3554
+ const int x1_0 = (v0 >> 4) | x1_0h;
3555
+
3556
+ const int y0_0 = y0[2*j + 0];
3557
+ const int y1_0 = y0[2*j + 1];
3558
+
3559
+ sxy += x0_0*y0_0 + x1_0*y1_0;
2750
3560
  }
3561
+
3562
+ sumf += (d*sxy)*y[i].d + m*(y[i].s0 + y[i].s1);
2751
3563
  }
2752
- #endif
2753
3564
 
2754
3565
  *s = sumf;
3566
+ #endif
2755
3567
  }
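(Q5_1 reuses the same qh reconstruction but keeps the 5-bit value unsigned in [0, 31]: there is no sub-16, because the per-block offset m is folded in afterwards through m*(s0 + s1), exactly as in the Q4_1 path above.)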
2756
3568
 
2757
- static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
3569
+ static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
2758
3570
  const int nb = n / QK8_0;
2759
3571
 
2760
3572
  assert(n % QK8_0 == 0);
2761
3573
  assert(nb % 2 == 0);
3574
+ assert(QK8_0 == QK8_0);
2762
3575
 
2763
- const block_q4_0 * restrict x = vx;
3576
+ const block_q8_0 * restrict x = vx;
2764
3577
  const block_q8_0 * restrict y = vy;
2765
3578
 
2766
- float sumf = 0.0;
2767
-
2768
3579
  #if defined(__ARM_NEON)
2769
- float sum0 = 0.0f;
2770
- float sum1 = 0.0f;
3580
+ float32x4_t sumv0 = vdupq_n_f32(0.0f);
3581
+ float32x4_t sumv1 = vdupq_n_f32(0.0f);
2771
3582
 
2772
3583
  for (int i = 0; i < nb; i += 2) {
2773
- const block_q4_0 * restrict x0 = &x[i + 0];
2774
- const block_q4_0 * restrict x1 = &x[i + 1];
3584
+ const block_q8_0 * restrict x0 = &x[i + 0];
3585
+ const block_q8_0 * restrict x1 = &x[i + 1];
2775
3586
  const block_q8_0 * restrict y0 = &y[i + 0];
2776
3587
  const block_q8_0 * restrict y1 = &y[i + 1];
2777
3588
 
2778
- const uint8x16_t m4b = vdupq_n_u8(0xf);
2779
- const int8x16_t s8b = vdupq_n_s8(0x8);
2780
-
2781
- const uint8x16_t v0_0 = vld1q_u8(x0->qs);
2782
- const uint8x16_t v0_1 = vld1q_u8(x1->qs);
2783
-
2784
- // 4-bit -> 8-bit
2785
- const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
2786
- const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
2787
- const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
2788
- const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
2789
-
2790
- // sub 8
2791
- const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
2792
- const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
2793
- const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
2794
- const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
3589
+ const int8x16_t x0_0 = vld1q_s8(x0->qs);
3590
+ const int8x16_t x0_1 = vld1q_s8(x0->qs + 16);
3591
+ const int8x16_t x1_0 = vld1q_s8(x1->qs);
3592
+ const int8x16_t x1_1 = vld1q_s8(x1->qs + 16);
2795
3593
 
2796
3594
  // load y
2797
- const int8x16_t v1_0l = vld1q_s8(y0->qs);
2798
- const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
2799
- const int8x16_t v1_1l = vld1q_s8(y1->qs);
2800
- const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
2801
-
2802
- // interleave
2803
- const int8x16_t v1_0ls = vuzp1q_s8(v1_0l, v1_0h);
2804
- const int8x16_t v1_0hs = vuzp2q_s8(v1_0l, v1_0h);
2805
- const int8x16_t v1_1ls = vuzp1q_s8(v1_1l, v1_1h);
2806
- const int8x16_t v1_1hs = vuzp2q_s8(v1_1l, v1_1h);
3595
+ const int8x16_t y0_0 = vld1q_s8(y0->qs);
3596
+ const int8x16_t y0_1 = vld1q_s8(y0->qs + 16);
3597
+ const int8x16_t y1_0 = vld1q_s8(y1->qs);
3598
+ const int8x16_t y1_1 = vld1q_s8(y1->qs + 16);
2807
3599
 
2808
3600
  #if defined(__ARM_FEATURE_DOTPROD)
2809
- // dot product into int32x4_t
2810
- int32x4_t p_0 = vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls);
2811
- int32x4_t p_1 = vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls);
3601
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
3602
+ vdotq_s32(vdupq_n_s32(0), x0_0, y0_0),
3603
+ vdotq_s32(vdupq_n_s32(0), x0_1, y0_1))), x0->d*y0->d);
2812
3604
 
2813
- p_0 = vdotq_s32(p_0, v0_0hs, v1_0hs);
2814
- p_1 = vdotq_s32(p_1, v0_1hs, v1_1hs);
3605
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
3606
+ vdotq_s32(vdupq_n_s32(0), x1_0, y1_0),
3607
+ vdotq_s32(vdupq_n_s32(0), x1_1, y1_1))), x1->d*y1->d);
2815
3608
 
2816
- sum0 += x0->d*y0->d*vaddvq_s32(p_0);
2817
- sum1 += x1->d*y1->d*vaddvq_s32(p_1);
2818
3609
  #else
2819
- const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls));
2820
- const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls));
2821
- const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0hs));
2822
- const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0hs));
2823
-
2824
- const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1ls));
2825
- const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1ls));
2826
- const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1hs));
2827
- const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1hs));
2828
-
2829
- const int16x8_t pl_0 = vaddq_s16(pl0l, pl0h);
2830
- const int16x8_t ph_0 = vaddq_s16(ph0l, ph0h);
2831
-
2832
- const int16x8_t pl_1 = vaddq_s16(pl1l, pl1h);
2833
- const int16x8_t ph_1 = vaddq_s16(ph1l, ph1h);
2834
-
2835
- const int16x8_t p_0 = vaddq_s16(pl_0, ph_0);
2836
- const int16x8_t p_1 = vaddq_s16(pl_1, ph_1);
2837
-
2838
- sum0 += x0->d*y0->d*vaddvq_s16(p_0);
2839
- sum1 += x1->d*y1->d*vaddvq_s16(p_1);
3610
+ const int16x8_t p0_0 = vmull_s8(vget_low_s8 (x0_0), vget_low_s8 (y0_0));
3611
+ const int16x8_t p0_1 = vmull_s8(vget_high_s8(x0_0), vget_high_s8(y0_0));
3612
+ const int16x8_t p0_2 = vmull_s8(vget_low_s8 (x0_1), vget_low_s8 (y0_1));
3613
+ const int16x8_t p0_3 = vmull_s8(vget_high_s8(x0_1), vget_high_s8(y0_1));
3614
+
3615
+ const int16x8_t p1_0 = vmull_s8(vget_low_s8 (x1_0), vget_low_s8 (y1_0));
3616
+ const int16x8_t p1_1 = vmull_s8(vget_high_s8(x1_0), vget_high_s8(y1_0));
3617
+ const int16x8_t p1_2 = vmull_s8(vget_low_s8 (x1_1), vget_low_s8 (y1_1));
3618
+ const int16x8_t p1_3 = vmull_s8(vget_high_s8(x1_1), vget_high_s8(y1_1));
3619
+
3620
+ const int32x4_t p0 = vaddq_s32(vpaddlq_s16(p0_0), vpaddlq_s16(p0_1));
3621
+ const int32x4_t p1 = vaddq_s32(vpaddlq_s16(p0_2), vpaddlq_s16(p0_3));
3622
+ const int32x4_t p2 = vaddq_s32(vpaddlq_s16(p1_0), vpaddlq_s16(p1_1));
3623
+ const int32x4_t p3 = vaddq_s32(vpaddlq_s16(p1_2), vpaddlq_s16(p1_3));
3624
+
3625
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(p0, p1)), x0->d*y0->d);
3626
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(p2, p3)), x1->d*y1->d);
2840
3627
  #endif
2841
3628
  }
2842
3629
 
2843
- sumf = sum0 + sum1;
3630
+ *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
2844
3631
  #elif defined(__AVX2__)
2845
3632
  // Initialize accumulator with zeros
2846
3633
  __m256 acc = _mm256_setzero_ps();
2847
3634
 
2848
3635
  // Main loop
2849
3636
  for (int i = 0; i < nb; ++i) {
2850
- /* Compute combined scale for the block */
3637
+ // Compute combined scale for the block
2851
3638
  const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );
2852
-
2853
- __m256i bx = bytesFromNibbles(x[i].qs);
2854
-
2855
- // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
2856
- const __m256i off = _mm256_set1_epi8( 8 );
2857
- bx = _mm256_sub_epi8( bx, off );
2858
-
3639
+ __m256i bx = _mm256_loadu_si256((const __m256i *)x[i].qs);
2859
3640
  __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
2860
3641
 
2861
- // Get absolute values of x vectors
2862
- const __m256i ax = _mm256_sign_epi8(bx, bx);
2863
-
2864
- // Sign the values of the y vectors
2865
- const __m256i sy = _mm256_sign_epi8(by, bx);
3642
+ const __m256 q = mul_sum_i8_pairs_float(bx, by);
2866
3643
 
2867
- // Perform multiplication and create 16-bit values
2868
- const __m256i dot = _mm256_maddubs_epi16(ax, sy);
2869
-
2870
- const __m256i ones = _mm256_set1_epi16(1);
2871
- __m256i xy_q = _mm256_madd_epi16(ones, dot);
2872
-
2873
- /* Convert to vectore of 8 int32_t to 8 floats */
2874
- __m256 q = _mm256_cvtepi32_ps( xy_q );
2875
-
2876
- /* Multiply q with scale and accumulate */
3644
+ // Multiply q with scale and accumulate
2877
3645
  acc = _mm256_fmadd_ps( d, q, acc );
2878
3646
  }
2879
3647
 
2880
- // Return horizontal sum of the acc vector
2881
- __m128 res = _mm256_extractf128_ps( acc, 1 );
2882
- res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
2883
- res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
2884
- res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
2885
-
2886
- sumf = _mm_cvtss_f32( res );
2887
- #elif defined(__AVX__)
2888
- // Initialize accumulator with zeros
2889
- __m256 acc = _mm256_setzero_ps();
2890
-
2891
- // Main loop
2892
- for (int i = 0; i < nb; ++i) {
2893
- // Compute combined scale for the block
2894
- const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );
2895
-
2896
- __m128i i32[2];
2897
- for (int j = 0; j < 2; ++j) {
2898
- // Load 8 bytes, and unpack 4 bit fields into bytes, making 16 bytes
2899
- __m128i bx = bytesFromNibbles( x[i].qs + 8*j );
2900
- __m128i by = _mm_loadu_si128((const __m128i *)(y[i].qs + 16*j));
2901
-
2902
- // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
2903
- const __m128i off = _mm_set1_epi8( 8 );
2904
- bx = _mm_sub_epi8( bx, off );
2905
-
2906
- // Get absolute values of x vectors
2907
- const __m128i ax = _mm_sign_epi8(bx, bx);
2908
-
2909
- // Sign the values of the y vectors
2910
- const __m128i sy = _mm_sign_epi8(by, bx);
2911
-
2912
- // Perform multiplication and create 16-bit values
2913
- const __m128i dot = _mm_maddubs_epi16(ax, sy);
2914
-
2915
- const __m128i ones = _mm_set1_epi16(1);
2916
- i32[j] = _mm_madd_epi16(ones, dot);
2917
- }
2918
-
2919
- // Convert int32_t to float
2920
- __m256 p = _mm256_cvtepi32_ps( _mm256_set_m128i( i32[0], i32[1] ));
2921
- // Apply the scale, and accumulate
2922
- acc = _mm256_add_ps(_mm256_mul_ps( d, p ), acc);
2923
- }
2924
-
2925
- // Return horizontal sum of the acc vector
2926
- __m128 res = _mm256_extractf128_ps( acc, 1 );
2927
- res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
2928
- res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
2929
- res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
2930
-
2931
- sumf = _mm_cvtss_f32( res );
3648
+ *s = hsum_float_8(acc);
2932
3649
  #else
2933
3650
  // scalar
2934
- for (int i = 0; i < nb; i++) {
2935
- const float d0 = x[i].d;
2936
- const float d1 = y[i].d;
3651
+ float sumf = 0.0;
2937
3652
 
2938
- const uint8_t * restrict p0 = x[i].qs;
2939
- const int8_t * restrict p1 = y[i].qs;
3653
+ for (int i = 0; i < nb; i++) {
3654
+ const int8_t * restrict x0 = x[i].qs;
3655
+ const int8_t * restrict y0 = y[i].qs;
2940
3656
 
2941
3657
  int sumi = 0;
2942
- for (int j = 0; j < QK8_0/2; j++) {
2943
- const uint8_t v0 = p0[j];
2944
3658
 
2945
- const int i0 = (int8_t) (v0 & 0xf) - 8;
2946
- const int i1 = (int8_t) (v0 >> 4) - 8;
3659
+ for (int j = 0; j < QK8_0; j++) {
3660
+ const int v0 = x0[j];
3661
+ const int v1 = y0[j];
2947
3662
 
2948
- const int i2 = p1[2*j + 0];
2949
- const int i3 = p1[2*j + 1];
2950
-
2951
- sumi += i0*i2 + i1*i3;
3663
+ sumi += v0*v1;
2952
3664
  }
2953
- sumf += d0*d1*sumi;
3665
+
3666
+ sumf += (x[i].d*y[i].d)*sumi;
2954
3667
  }
2955
- #endif
2956
3668
 
2957
3669
  *s = sumf;
3670
+ #endif
2958
3671
  }
2959
3672
 
2960
3673
  // compute GGML_VEC_DOT_UNROLL dot products at once
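
The new ggml_vec_dot_q8_0_q8_0 above reduces, in its scalar fallback, to one integer dot product per block scaled by the product of the two block scales. A standalone reference of that computation, assuming the block layout used here (one float scale plus QK8_0 = 32 signed 8-bit quants; names are illustrative):

#include <stdio.h>
#include <stdint.h>

#define QK 32  // assumed block size, mirroring QK8_0

typedef struct { float d; int8_t qs[QK]; } blk_q8;  // sketch of a q8_0-style block

// dot product of n values stored as q8_0-style blocks:
// sum over blocks of (x.d * y.d) * <x.qs, y.qs>
static float dot_q8(int n, const blk_q8 *x, const blk_q8 *y) {
    float sumf = 0.0f;
    for (int i = 0; i < n/QK; ++i) {
        int sumi = 0;
        for (int j = 0; j < QK; ++j) {
            sumi += x[i].qs[j] * y[i].qs[j];
        }
        sumf += (x[i].d * y[i].d) * sumi;
    }
    return sumf;
}

int main(void) {
    blk_q8 x = {0.5f, {0}}, y = {2.0f, {0}};
    x.qs[0] = 3; y.qs[0] = 4;                 // single non-zero pair: 3*4 = 12
    printf("%f\n", dot_q8(QK, &x, &y));       // 0.5 * 2.0 * 12 = 12.0
    return 0;
}
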
@@ -3153,6 +3866,14 @@ inline static void ggml_vec_sum_f32(const int n, float * s, const float * x) {
3153
3866
  #endif
3154
3867
  }
3155
3868
 
3869
+ inline static void ggml_vec_sum_ggf(const int n, ggml_float * s, const float * x) {
3870
+ ggml_float sum = 0.0;
3871
+ for (int i = 0; i < n; ++i) {
3872
+ sum += (ggml_float)x[i];
3873
+ }
3874
+ *s = sum;
3875
+ }
3876
+
3156
3877
  inline static void ggml_vec_max_f32(const int n, float * s, const float * x) {
3157
3878
  #ifndef GGML_USE_ACCELERATE
3158
3879
  float max = -INFINITY;
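
ggml_vec_sum_ggf above accumulates into ggml_float rather than float, so long reductions lose less precision; the sum op later in this diff uses it to add per-row partial sums. A small illustration of why the wider accumulator matters (assuming ggml_float is a double-width type):

#include <stdio.h>

int main(void) {
    // 10 million values of 0.1f: float accumulation drifts, double stays close
    float  sf = 0.0f;
    double sd = 0.0;
    for (int i = 0; i < 10000000; ++i) {
        sf += 0.1f;
        sd += (double)0.1f;   // same widening as ggml_vec_sum_ggf
    }
    printf("float  acc: %f\n", sf);   // noticeably off from 1e6
    printf("double acc: %f\n", sd);   // ~1000000.0, up to the rounding of 0.1f itself
    return 0;
}
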
@@ -3203,24 +3924,34 @@ static const int GGML_BLCK_SIZE[GGML_TYPE_COUNT] = {
3203
3924
  [GGML_TYPE_F16] = 1,
3204
3925
  [GGML_TYPE_Q4_0] = QK4_0,
3205
3926
  [GGML_TYPE_Q4_1] = QK4_1,
3927
+ [GGML_TYPE_Q4_2] = QK4_2,
3928
+ [GGML_TYPE_Q4_3] = QK4_3,
3929
+ [GGML_TYPE_Q5_0] = QK5_0,
3930
+ [GGML_TYPE_Q5_1] = QK5_1,
3206
3931
  [GGML_TYPE_Q8_0] = QK8_0,
3932
+ [GGML_TYPE_Q8_1] = QK8_1,
3207
3933
  [GGML_TYPE_I8] = 1,
3208
3934
  [GGML_TYPE_I16] = 1,
3209
3935
  [GGML_TYPE_I32] = 1,
3210
3936
  };
3211
- static_assert(GGML_TYPE_COUNT == 8, "GGML_BLCK_SIZE is outdated");
3937
+ static_assert(GGML_TYPE_COUNT == 13, "GGML_BLCK_SIZE is outdated");
3212
3938
 
3213
3939
  static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = {
3214
3940
  [GGML_TYPE_F32] = sizeof(float),
3215
3941
  [GGML_TYPE_F16] = sizeof(ggml_fp16_t),
3216
3942
  [GGML_TYPE_Q4_0] = sizeof(block_q4_0),
3217
3943
  [GGML_TYPE_Q4_1] = sizeof(block_q4_1),
3944
+ [GGML_TYPE_Q4_2] = sizeof(block_q4_2),
3945
+ [GGML_TYPE_Q4_3] = sizeof(block_q4_3),
3946
+ [GGML_TYPE_Q5_0] = sizeof(block_q5_0),
3947
+ [GGML_TYPE_Q5_1] = sizeof(block_q5_1),
3218
3948
  [GGML_TYPE_Q8_0] = sizeof(block_q8_0),
3949
+ [GGML_TYPE_Q8_1] = sizeof(block_q8_1),
3219
3950
  [GGML_TYPE_I8] = sizeof(int8_t),
3220
3951
  [GGML_TYPE_I16] = sizeof(int16_t),
3221
3952
  [GGML_TYPE_I32] = sizeof(int32_t),
3222
3953
  };
3223
- static_assert(GGML_TYPE_COUNT == 8, "GGML_TYPE_SIZE is outdated");
3954
+ static_assert(GGML_TYPE_COUNT == 13, "GGML_TYPE_SIZE is outdated");
3224
3955
 
3225
3956
 
3226
3957
  static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = {
@@ -3228,12 +3959,34 @@ static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = {
3228
3959
  [GGML_TYPE_F16] = "f16",
3229
3960
  [GGML_TYPE_Q4_0] = "q4_0",
3230
3961
  [GGML_TYPE_Q4_1] = "q4_1",
3962
+ [GGML_TYPE_Q4_2] = "q4_2",
3963
+ [GGML_TYPE_Q4_3] = "q4_3",
3964
+ [GGML_TYPE_Q5_0] = "q5_0",
3965
+ [GGML_TYPE_Q5_1] = "q5_1",
3231
3966
  [GGML_TYPE_Q8_0] = "q8_0",
3967
+ [GGML_TYPE_Q8_1] = "q8_1",
3232
3968
  [GGML_TYPE_I8] = "i8",
3233
3969
  [GGML_TYPE_I16] = "i16",
3234
3970
  [GGML_TYPE_I32] = "i32",
3235
3971
  };
3236
- static_assert(GGML_TYPE_COUNT == 8, "GGML_TYPE_NAME is outdated");
3972
+ static_assert(GGML_TYPE_COUNT == 13, "GGML_TYPE_NAME is outdated");
3973
+
3974
+ static bool GGML_IS_QUANTIZED[GGML_TYPE_COUNT] = {
3975
+ [GGML_TYPE_F32] = false,
3976
+ [GGML_TYPE_F16] = false,
3977
+ [GGML_TYPE_Q4_0] = true,
3978
+ [GGML_TYPE_Q4_1] = true,
3979
+ [GGML_TYPE_Q4_2] = true,
3980
+ [GGML_TYPE_Q4_3] = true,
3981
+ [GGML_TYPE_Q5_0] = true,
3982
+ [GGML_TYPE_Q5_1] = true,
3983
+ [GGML_TYPE_Q8_0] = true,
3984
+ [GGML_TYPE_Q8_1] = true,
3985
+ [GGML_TYPE_I8] = false,
3986
+ [GGML_TYPE_I16] = false,
3987
+ [GGML_TYPE_I32] = false,
3988
+ };
3989
+ static_assert(GGML_TYPE_COUNT == 13, "GGML_IS_QUANTIZED is outdated");
3237
3990
 
3238
3991
  static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
3239
3992
  "NONE",
@@ -3495,6 +4248,10 @@ static inline bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct
3495
4248
  (t0->ne[3] == t1->ne[3]);
3496
4249
  }
3497
4250
 
4251
+ bool ggml_is_quantized(enum ggml_type type) {
4252
+ return GGML_IS_QUANTIZED[type];
4253
+ }
4254
+
3498
4255
  static inline bool ggml_is_transposed(const struct ggml_tensor * tensor) {
3499
4256
  return tensor->nb[0] > tensor->nb[1];
3500
4257
  }
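
The GGML_IS_QUANTIZED table and ggml_is_quantized() above let later code test for any quantized type with one call instead of enumerating Q4_0, Q4_1, and friends. A trimmed-down, self-contained sketch of the same table-plus-predicate pattern (the toy_* names are stand-ins, not ggml symbols):

#include <stdbool.h>
#include <stdio.h>

enum toy_type { TOY_F32, TOY_F16, TOY_Q4_0, TOY_Q8_0, TOY_TYPE_COUNT };

static const bool TOY_IS_QUANTIZED[TOY_TYPE_COUNT] = {
    [TOY_F32]  = false,
    [TOY_F16]  = false,
    [TOY_Q4_0] = true,
    [TOY_Q8_0] = true,
};
// keep the table in sync with the enum, as the static_asserts above do
_Static_assert(TOY_TYPE_COUNT == 4, "TOY_IS_QUANTIZED is outdated");

static bool toy_is_quantized(enum toy_type t) { return TOY_IS_QUANTIZED[t]; }

int main(void) {
    printf("%d %d\n", toy_is_quantized(TOY_F16), toy_is_quantized(TOY_Q4_0)); // 0 1
    return 0;
}
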
@@ -3605,6 +4362,13 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
3605
4362
  GGML_PRINT_DEBUG("%s: g_state initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f);
3606
4363
  }
3607
4364
 
4365
+ // initialize cuBLAS
4366
+ #if defined(GGML_USE_CUBLAS)
4367
+ ggml_init_cublas();
4368
+ #elif defined(GGML_USE_CLBLAST)
4369
+ ggml_cl_init();
4370
+ #endif
4371
+
3608
4372
  is_first_call = false;
3609
4373
  }
3610
4374
 
@@ -5535,7 +6299,6 @@ static void ggml_compute_forward_dup_f16(
5535
6299
  const struct ggml_compute_params * params,
5536
6300
  const struct ggml_tensor * src0,
5537
6301
  struct ggml_tensor * dst) {
5538
- GGML_ASSERT(params->ith == 0);
5539
6302
  GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
5540
6303
 
5541
6304
  if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
@@ -5547,6 +6310,11 @@ static void ggml_compute_forward_dup_f16(
5547
6310
  const int64_t ne02 = src0->ne[2];
5548
6311
  const int64_t ne03 = src0->ne[3];
5549
6312
 
6313
+ const int64_t ne0 = dst->ne[0];
6314
+ const int64_t ne1 = dst->ne[1];
6315
+ const int64_t ne2 = dst->ne[2];
6316
+ const int64_t ne3 = dst->ne[3];
6317
+
5550
6318
  const size_t nb00 = src0->nb[0];
5551
6319
  const size_t nb01 = src0->nb[1];
5552
6320
  const size_t nb02 = src0->nb[2];
@@ -5557,19 +6325,40 @@ static void ggml_compute_forward_dup_f16(
5557
6325
  const size_t nb2 = dst->nb[2];
5558
6326
  const size_t nb3 = dst->nb[3];
5559
6327
 
6328
+ const int ith = params->ith; // thread index
6329
+ const int nth = params->nth; // number of threads
6330
+
5560
6331
  if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst) && src0->type == dst->type) {
5561
- memcpy(dst->data, src0->data, ggml_nelements(dst) * GGML_TYPE_SIZE[src0->type]);
6332
+ // parallelize by elements
6333
+ const int ne = ggml_nelements(dst);
6334
+ const int dr = (ne + nth - 1) / nth;
6335
+ const int ie0 = dr * ith;
6336
+ const int ie1 = MIN(ie0 + dr, ne);
6337
+
6338
+ memcpy(
6339
+ ((char *) dst->data + ie0*nb0),
6340
+ ((char *) src0->data + ie0*nb00),
6341
+ (ie1 - ie0) * GGML_TYPE_SIZE[src0->type]);
6342
+
5562
6343
  return;
5563
6344
  }
5564
6345
 
6346
+ // parallelize by rows
6347
+ const int nr = ne01;
6348
+ // number of rows per thread
6349
+ const int dr = (nr + nth - 1) / nth;
6350
+ // row range for this thread
6351
+ const int ir0 = dr * ith;
6352
+ const int ir1 = MIN(ir0 + dr, nr);
6353
+
5565
6354
  if (src0->type == dst->type &&
5566
- src0->ne[0] == dst->ne[0] &&
5567
- src0->nb[0] == GGML_TYPE_SIZE[src0->type] && dst->nb[0] == GGML_TYPE_SIZE[dst->type]) {
6355
+ ne00 == ne0 &&
6356
+ nb00 == GGML_TYPE_SIZE[src0->type] && nb0 == GGML_TYPE_SIZE[dst->type]) {
5568
6357
  // copy by rows
5569
6358
  const size_t rs = ne00*nb00;
5570
6359
  for (int64_t i03 = 0; i03 < ne03; i03++) {
5571
6360
  for (int64_t i02 = 0; i02 < ne02; i02++) {
5572
- for (int64_t i01 = 0; i01 < ne01; i01++) {
6361
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
5573
6362
  memcpy(
5574
6363
  ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3),
5575
6364
  ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
@@ -5583,21 +6372,21 @@ static void ggml_compute_forward_dup_f16(
5583
6372
  // TODO: add more special-case implementations for tensor shapes/strides that can benefit from memcpy
5584
6373
 
5585
6374
  if (ggml_is_contiguous(dst)) {
5586
- if (src0->nb[0] == sizeof(ggml_fp16_t)) {
6375
+ if (nb00 == sizeof(ggml_fp16_t)) {
5587
6376
  if (dst->type == GGML_TYPE_F16) {
5588
6377
  size_t id = 0;
5589
- const size_t rs = ne00*nb00;
6378
+ const size_t rs = ne00 * nb00;
6379
+ char * dst_ptr = (char *) dst->data;
5590
6380
 
5591
6381
  for (int i03 = 0; i03 < ne03; i03++) {
5592
6382
  for (int i02 = 0; i02 < ne02; i02++) {
5593
- for (int i01 = 0; i01 < ne01; i01++) {
6383
+ id += rs * ir0;
6384
+ for (int i01 = ir0; i01 < ir1; i01++) {
5594
6385
  const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
5595
- char * dst_ptr = (char *) dst->data + id*rs;
5596
-
5597
- memcpy(dst_ptr, src0_ptr, rs);
5598
-
5599
- id++;
6386
+ memcpy(dst_ptr + id, src0_ptr, rs);
6387
+ id += rs;
5600
6388
  }
6389
+ id += rs * (ne01 - ir1);
5601
6390
  }
5602
6391
  }
5603
6392
  } else if (dst->type == GGML_TYPE_F32) {
@@ -5606,34 +6395,39 @@ static void ggml_compute_forward_dup_f16(
5606
6395
 
5607
6396
  for (int i03 = 0; i03 < ne03; i03++) {
5608
6397
  for (int i02 = 0; i02 < ne02; i02++) {
5609
- for (int i01 = 0; i01 < ne01; i01++) {
6398
+ id += ne00 * ir0;
6399
+ for (int i01 = ir0; i01 < ir1; i01++) {
6400
+ const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
5610
6401
  for (int i00 = 0; i00 < ne00; i00++) {
5611
- const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
5612
-
5613
- dst_ptr[id] = GGML_FP16_TO_FP32(*src0_ptr);
6402
+ dst_ptr[id] = GGML_FP16_TO_FP32(src0_ptr[i00]);
5614
6403
  id++;
5615
6404
  }
5616
6405
  }
6406
+ id += ne00 * (ne01 - ir1);
5617
6407
  }
5618
6408
  }
5619
- } else if (dst->type == GGML_TYPE_Q4_0 || dst->type == GGML_TYPE_Q4_1) {
6409
+ } else if (ggml_is_quantized(dst->type)) {
5620
6410
  quantize_row_q_t const quantize_row_q = quantize_fns[dst->type].quantize_row_q;
6411
+ float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
6412
+
5621
6413
  size_t id = 0;
5622
- uint8_t * dst_ptr = (uint8_t *) dst->data;
5623
- size_t dst_row_size = nb0 * (ne00 / GGML_BLCK_SIZE[dst->type]);
5624
- float * src0_f32 = (float *) params->wdata;
6414
+ size_t rs = nb0 * (ne00 / GGML_BLCK_SIZE[dst->type]);
6415
+ char * dst_ptr = (char *) dst->data;
5625
6416
 
5626
6417
  for (int i03 = 0; i03 < ne03; i03++) {
5627
6418
  for (int i02 = 0; i02 < ne02; i02++) {
5628
- for (int i01 = 0; i01 < ne01; i01++) {
6419
+ id += rs * ir0;
6420
+ for (int i01 = ir0; i01 < ir1; i01++) {
5629
6421
  const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
5630
- // convert to f32 and quantize
6422
+
5631
6423
  for (int i00 = 0; i00 < ne00; i00++) {
5632
6424
  src0_f32[i00] = GGML_FP16_TO_FP32(src0_ptr[i00]);
5633
6425
  }
6426
+
5634
6427
  quantize_row_q(src0_f32, dst_ptr + id, ne00);
5635
- id += dst_row_size;
6428
+ id += rs;
5636
6429
  }
6430
+ id += rs * (ne01 - ir1);
5637
6431
  }
5638
6432
  }
5639
6433
  } else {
@@ -5648,7 +6442,8 @@ static void ggml_compute_forward_dup_f16(
5648
6442
 
5649
6443
  for (int i03 = 0; i03 < ne03; i03++) {
5650
6444
  for (int i02 = 0; i02 < ne02; i02++) {
5651
- for (int i01 = 0; i01 < ne01; i01++) {
6445
+ id += ne00 * ir0;
6446
+ for (int i01 = ir0; i01 < ir1; i01++) {
5652
6447
  for (int i00 = 0; i00 < ne00; i00++) {
5653
6448
  const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
5654
6449
 
@@ -5656,6 +6451,7 @@ static void ggml_compute_forward_dup_f16(
5656
6451
  id++;
5657
6452
  }
5658
6453
  }
6454
+ id += ne00 * (ne01 - ir1);
5659
6455
  }
5660
6456
  }
5661
6457
  } else if (dst->type == GGML_TYPE_F16) {
@@ -5664,7 +6460,8 @@ static void ggml_compute_forward_dup_f16(
5664
6460
 
5665
6461
  for (int i03 = 0; i03 < ne03; i03++) {
5666
6462
  for (int i02 = 0; i02 < ne02; i02++) {
5667
- for (int i01 = 0; i01 < ne01; i01++) {
6463
+ id += ne00 * ir0;
6464
+ for (int i01 = ir0; i01 < ir1; i01++) {
5668
6465
  for (int i00 = 0; i00 < ne00; i00++) {
5669
6466
  const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
5670
6467
 
@@ -5672,6 +6469,7 @@ static void ggml_compute_forward_dup_f16(
5672
6469
  id++;
5673
6470
  }
5674
6471
  }
6472
+ id += ne00 * (ne01 - ir1);
5675
6473
  }
5676
6474
  }
5677
6475
  } else {
@@ -5690,7 +6488,20 @@ static void ggml_compute_forward_dup_f16(
5690
6488
  if (dst->type == GGML_TYPE_F16) {
5691
6489
  for (int64_t i03 = 0; i03 < ne03; i03++) {
5692
6490
  for (int64_t i02 = 0; i02 < ne02; i02++) {
5693
- for (int64_t i01 = 0; i01 < ne01; i01++) {
6491
+ i10 += ne00 * ir0;
6492
+ while (i10 >= ne0) {
6493
+ i10 -= ne0;
6494
+ if (++i11 == ne1) {
6495
+ i11 = 0;
6496
+ if (++i12 == ne2) {
6497
+ i12 = 0;
6498
+ if (++i13 == ne3) {
6499
+ i13 = 0;
6500
+ }
6501
+ }
6502
+ }
6503
+ }
6504
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
5694
6505
  for (int64_t i00 = 0; i00 < ne00; i00++) {
5695
6506
  const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
5696
6507
  char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
@@ -5711,25 +6522,51 @@ static void ggml_compute_forward_dup_f16(
5711
6522
  }
5712
6523
  }
5713
6524
  }
6525
+ i10 += ne00 * (ne01 - ir1);
6526
+ while (i10 >= ne0) {
6527
+ i10 -= ne0;
6528
+ if (++i11 == ne1) {
6529
+ i11 = 0;
6530
+ if (++i12 == ne2) {
6531
+ i12 = 0;
6532
+ if (++i13 == ne3) {
6533
+ i13 = 0;
6534
+ }
6535
+ }
6536
+ }
6537
+ }
5714
6538
  }
5715
6539
  }
5716
6540
  } else if (dst->type == GGML_TYPE_F32) {
5717
6541
  for (int64_t i03 = 0; i03 < ne03; i03++) {
5718
6542
  for (int64_t i02 = 0; i02 < ne02; i02++) {
5719
- for (int64_t i01 = 0; i01 < ne01; i01++) {
6543
+ i10 += ne00 * ir0;
6544
+ while (i10 >= ne0) {
6545
+ i10 -= ne0;
6546
+ if (++i11 == ne1) {
6547
+ i11 = 0;
6548
+ if (++i12 == ne2) {
6549
+ i12 = 0;
6550
+ if (++i13 == ne3) {
6551
+ i13 = 0;
6552
+ }
6553
+ }
6554
+ }
6555
+ }
6556
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
5720
6557
  for (int64_t i00 = 0; i00 < ne00; i00++) {
5721
6558
  const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
5722
6559
  char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
5723
6560
 
5724
6561
  *(float *) dst_ptr = GGML_FP16_TO_FP32(*(const ggml_fp16_t *) src0_ptr);
5725
6562
 
5726
- if (++i10 == ne00) {
6563
+ if (++i10 == ne0) {
5727
6564
  i10 = 0;
5728
- if (++i11 == ne01) {
6565
+ if (++i11 == ne1) {
5729
6566
  i11 = 0;
5730
- if (++i12 == ne02) {
6567
+ if (++i12 == ne2) {
5731
6568
  i12 = 0;
5732
- if (++i13 == ne03) {
6569
+ if (++i13 == ne3) {
5733
6570
  i13 = 0;
5734
6571
  }
5735
6572
  }
@@ -5737,6 +6574,19 @@ static void ggml_compute_forward_dup_f16(
5737
6574
  }
5738
6575
  }
5739
6576
  }
6577
+ i10 += ne00 * (ne01 - ir1);
6578
+ while (i10 >= ne0) {
6579
+ i10 -= ne0;
6580
+ if (++i11 == ne1) {
6581
+ i11 = 0;
6582
+ if (++i12 == ne2) {
6583
+ i12 = 0;
6584
+ if (++i13 == ne3) {
6585
+ i13 = 0;
6586
+ }
6587
+ }
6588
+ }
6589
+ }
5740
6590
  }
5741
6591
  }
5742
6592
  } else {
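
The dup kernels above are now multi-threaded by rows: thread ith of nth handles rows [ir0, ir1), where dr is the per-thread row count rounded up. A standalone sketch of that partition (function and variable names are illustrative):

#include <stdio.h>

#define MIN(a, b) ((a) < (b) ? (a) : (b))

// Row range for thread `ith` out of `nth` over `nr` rows, as in the dup kernels:
// dr rows per thread (rounded up); the last threads may get a short or empty range.
static void row_range(int nr, int ith, int nth, int *ir0, int *ir1) {
    const int dr = (nr + nth - 1) / nth;  // ceil(nr / nth)
    *ir0 = dr * ith;
    *ir1 = MIN(*ir0 + dr, nr);
}

int main(void) {
    int ir0, ir1;
    for (int ith = 0; ith < 4; ++ith) {
        row_range(10, ith, 4, &ir0, &ir1);
        printf("thread %d: rows [%d, %d)\n", ith, ir0, ir1);
    }
    // thread 0: [0,3)  thread 1: [3,6)  thread 2: [6,9)  thread 3: [9,10)
    return 0;
}
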
@@ -5748,7 +6598,6 @@ static void ggml_compute_forward_dup_f32(
5748
6598
  const struct ggml_compute_params * params,
5749
6599
  const struct ggml_tensor * src0,
5750
6600
  struct ggml_tensor * dst) {
5751
- GGML_ASSERT(params->ith == 0);
5752
6601
  GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
5753
6602
 
5754
6603
  if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
@@ -5760,6 +6609,11 @@ static void ggml_compute_forward_dup_f32(
5760
6609
  const int64_t ne02 = src0->ne[2];
5761
6610
  const int64_t ne03 = src0->ne[3];
5762
6611
 
6612
+ const int64_t ne0 = dst->ne[0];
6613
+ const int64_t ne1 = dst->ne[1];
6614
+ const int64_t ne2 = dst->ne[2];
6615
+ const int64_t ne3 = dst->ne[3];
6616
+
5763
6617
  const size_t nb00 = src0->nb[0];
5764
6618
  const size_t nb01 = src0->nb[1];
5765
6619
  const size_t nb02 = src0->nb[2];
@@ -5770,19 +6624,40 @@ static void ggml_compute_forward_dup_f32(
5770
6624
  const size_t nb2 = dst->nb[2];
5771
6625
  const size_t nb3 = dst->nb[3];
5772
6626
 
6627
+ const int ith = params->ith; // thread index
6628
+ const int nth = params->nth; // number of threads
6629
+
5773
6630
  if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst) && src0->type == dst->type) {
5774
- memcpy(dst->data, src0->data, ggml_nelements(dst) * GGML_TYPE_SIZE[src0->type]);
6631
+ // parallelize by elements
6632
+ const int ne = ggml_nelements(dst);
6633
+ const int dr = (ne + nth - 1) / nth;
6634
+ const int ie0 = dr * ith;
6635
+ const int ie1 = MIN(ie0 + dr, ne);
6636
+
6637
+ memcpy(
6638
+ ((char *) dst->data + ie0*nb0),
6639
+ ((char *) src0->data + ie0*nb00),
6640
+ (ie1 - ie0) * GGML_TYPE_SIZE[src0->type]);
6641
+
5775
6642
  return;
5776
6643
  }
5777
6644
 
6645
+ // parallelize by rows
6646
+ const int nr = ne01;
6647
+ // number of rows per thread
6648
+ const int dr = (nr + nth - 1) / nth;
6649
+ // row range for this thread
6650
+ const int ir0 = dr * ith;
6651
+ const int ir1 = MIN(ir0 + dr, nr);
6652
+
5778
6653
  if (src0->type == dst->type &&
5779
- src0->ne[0] == dst->ne[0] &&
5780
- src0->nb[0] == GGML_TYPE_SIZE[src0->type] && dst->nb[0] == GGML_TYPE_SIZE[dst->type]) {
6654
+ ne00 == ne0 &&
6655
+ nb00 == GGML_TYPE_SIZE[src0->type] && nb0 == GGML_TYPE_SIZE[dst->type]) {
5781
6656
  // copy by rows
5782
6657
  const size_t rs = ne00*nb00;
5783
6658
  for (int64_t i03 = 0; i03 < ne03; i03++) {
5784
6659
  for (int64_t i02 = 0; i02 < ne02; i02++) {
5785
- for (int64_t i01 = 0; i01 < ne01; i01++) {
6660
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
5786
6661
  memcpy(
5787
6662
  ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3),
5788
6663
  ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
@@ -5795,21 +6670,21 @@ static void ggml_compute_forward_dup_f32(
5795
6670
 
5796
6671
  if (ggml_is_contiguous(dst)) {
5797
6672
  // TODO: simplify
5798
- if (src0->nb[0] == sizeof(float)) {
6673
+ if (nb00 == sizeof(float)) {
5799
6674
  if (dst->type == GGML_TYPE_F32) {
5800
6675
  size_t id = 0;
5801
- const size_t rs = ne00*nb00;
6676
+ const size_t rs = ne00 * nb00;
6677
+ char * dst_ptr = (char *) dst->data;
5802
6678
 
5803
6679
  for (int i03 = 0; i03 < ne03; i03++) {
5804
6680
  for (int i02 = 0; i02 < ne02; i02++) {
5805
- for (int i01 = 0; i01 < ne01; i01++) {
6681
+ id += rs * ir0;
6682
+ for (int i01 = ir0; i01 < ir1; i01++) {
5806
6683
  const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
5807
- char * dst_ptr = (char *) dst->data + id*rs;
5808
-
5809
- memcpy(dst_ptr, src0_ptr, rs);
5810
-
5811
- id++;
6684
+ memcpy(dst_ptr + id, src0_ptr, rs);
6685
+ id += rs;
5812
6686
  }
6687
+ id += rs * (ne01 - ir1);
5813
6688
  }
5814
6689
  }
5815
6690
  } else if (dst->type == GGML_TYPE_F16) {
@@ -5818,7 +6693,8 @@ static void ggml_compute_forward_dup_f32(
5818
6693
 
5819
6694
  for (int i03 = 0; i03 < ne03; i03++) {
5820
6695
  for (int i02 = 0; i02 < ne02; i02++) {
5821
- for (int i01 = 0; i01 < ne01; i01++) {
6696
+ id += ne00 * ir0;
6697
+ for (int i01 = ir0; i01 < ir1; i01++) {
5822
6698
  for (int i00 = 0; i00 < ne00; i00++) {
5823
6699
  const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
5824
6700
 
@@ -5826,21 +6702,25 @@ static void ggml_compute_forward_dup_f32(
5826
6702
  id++;
5827
6703
  }
5828
6704
  }
6705
+ id += ne00 * (ne01 - ir1);
5829
6706
  }
5830
6707
  }
5831
- } else if (dst->type == GGML_TYPE_Q4_0 || dst->type == GGML_TYPE_Q4_1) {
6708
+ } else if (ggml_is_quantized(dst->type)) {
5832
6709
  quantize_row_q_t const quantize_row_q = quantize_fns[dst->type].quantize_row_q;
6710
+
5833
6711
  size_t id = 0;
5834
- uint8_t * dst_ptr = (uint8_t *) dst->data;
5835
- size_t dst_row_size = nb0 * (ne00 / GGML_BLCK_SIZE[dst->type]);
6712
+ size_t rs = nb0 * (ne00 / GGML_BLCK_SIZE[dst->type]);
6713
+ char * dst_ptr = (char *) dst->data;
5836
6714
 
5837
6715
  for (int i03 = 0; i03 < ne03; i03++) {
5838
6716
  for (int i02 = 0; i02 < ne02; i02++) {
5839
- for (int i01 = 0; i01 < ne01; i01++) {
6717
+ id += rs * ir0;
6718
+ for (int i01 = ir0; i01 < ir1; i01++) {
5840
6719
  const float * src0_ptr = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
5841
6720
  quantize_row_q(src0_ptr, dst_ptr + id, ne00);
5842
- id += dst_row_size;
6721
+ id += rs;
5843
6722
  }
6723
+ id += rs * (ne01 - ir1);
5844
6724
  }
5845
6725
  }
5846
6726
  } else {
@@ -5855,7 +6735,8 @@ static void ggml_compute_forward_dup_f32(
5855
6735
 
5856
6736
  for (int i03 = 0; i03 < ne03; i03++) {
5857
6737
  for (int i02 = 0; i02 < ne02; i02++) {
5858
- for (int i01 = 0; i01 < ne01; i01++) {
6738
+ id += ne00 * ir0;
6739
+ for (int i01 = ir0; i01 < ir1; i01++) {
5859
6740
  for (int i00 = 0; i00 < ne00; i00++) {
5860
6741
  const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
5861
6742
 
@@ -5863,6 +6744,7 @@ static void ggml_compute_forward_dup_f32(
5863
6744
  id++;
5864
6745
  }
5865
6746
  }
6747
+ id += ne00 * (ne01 - ir1);
5866
6748
  }
5867
6749
  }
5868
6750
  } else if (dst->type == GGML_TYPE_F16) {
@@ -5871,7 +6753,8 @@ static void ggml_compute_forward_dup_f32(
5871
6753
 
5872
6754
  for (int i03 = 0; i03 < ne03; i03++) {
5873
6755
  for (int i02 = 0; i02 < ne02; i02++) {
5874
- for (int i01 = 0; i01 < ne01; i01++) {
6756
+ id += ne00 * ir0;
6757
+ for (int i01 = ir0; i01 < ir1; i01++) {
5875
6758
  for (int i00 = 0; i00 < ne00; i00++) {
5876
6759
  const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
5877
6760
 
@@ -5879,6 +6762,7 @@ static void ggml_compute_forward_dup_f32(
5879
6762
  id++;
5880
6763
  }
5881
6764
  }
6765
+ id += ne00 * (ne01 - ir1);
5882
6766
  }
5883
6767
  }
5884
6768
  } else {
@@ -5890,6 +6774,7 @@ static void ggml_compute_forward_dup_f32(
5890
6774
  }
5891
6775
 
5892
6776
  // dst counters
6777
+
5893
6778
  int64_t i10 = 0;
5894
6779
  int64_t i11 = 0;
5895
6780
  int64_t i12 = 0;
@@ -5898,20 +6783,33 @@ static void ggml_compute_forward_dup_f32(
5898
6783
  if (dst->type == GGML_TYPE_F32) {
5899
6784
  for (int64_t i03 = 0; i03 < ne03; i03++) {
5900
6785
  for (int64_t i02 = 0; i02 < ne02; i02++) {
5901
- for (int64_t i01 = 0; i01 < ne01; i01++) {
6786
+ i10 += ne00 * ir0;
6787
+ while (i10 >= ne0) {
6788
+ i10 -= ne0;
6789
+ if (++i11 == ne1) {
6790
+ i11 = 0;
6791
+ if (++i12 == ne2) {
6792
+ i12 = 0;
6793
+ if (++i13 == ne3) {
6794
+ i13 = 0;
6795
+ }
6796
+ }
6797
+ }
6798
+ }
6799
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
5902
6800
  for (int64_t i00 = 0; i00 < ne00; i00++) {
5903
6801
  const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
5904
6802
  char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
5905
6803
 
5906
6804
  memcpy(dst_ptr, src0_ptr, sizeof(float));
5907
6805
 
5908
- if (++i10 == dst->ne[0]) {
6806
+ if (++i10 == ne0) {
5909
6807
  i10 = 0;
5910
- if (++i11 == dst->ne[1]) {
6808
+ if (++i11 == ne1) {
5911
6809
  i11 = 0;
5912
- if (++i12 == dst->ne[2]) {
6810
+ if (++i12 == ne2) {
5913
6811
  i12 = 0;
5914
- if (++i13 == dst->ne[3]) {
6812
+ if (++i13 == ne3) {
5915
6813
  i13 = 0;
5916
6814
  }
5917
6815
  }
@@ -5919,25 +6817,51 @@ static void ggml_compute_forward_dup_f32(
5919
6817
  }
5920
6818
  }
5921
6819
  }
6820
+ i10 += ne00 * (ne01 - ir1);
6821
+ while (i10 >= ne0) {
6822
+ i10 -= ne0;
6823
+ if (++i11 == ne1) {
6824
+ i11 = 0;
6825
+ if (++i12 == ne2) {
6826
+ i12 = 0;
6827
+ if (++i13 == ne3) {
6828
+ i13 = 0;
6829
+ }
6830
+ }
6831
+ }
6832
+ }
5922
6833
  }
5923
6834
  }
5924
6835
  } else if (dst->type == GGML_TYPE_F16) {
5925
6836
  for (int64_t i03 = 0; i03 < ne03; i03++) {
5926
6837
  for (int64_t i02 = 0; i02 < ne02; i02++) {
5927
- for (int64_t i01 = 0; i01 < ne01; i01++) {
6838
+ i10 += ne00 * ir0;
6839
+ while (i10 >= ne0) {
6840
+ i10 -= ne0;
6841
+ if (++i11 == ne1) {
6842
+ i11 = 0;
6843
+ if (++i12 == ne2) {
6844
+ i12 = 0;
6845
+ if (++i13 == ne3) {
6846
+ i13 = 0;
6847
+ }
6848
+ }
6849
+ }
6850
+ }
6851
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
5928
6852
  for (int64_t i00 = 0; i00 < ne00; i00++) {
5929
6853
  const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
5930
6854
  char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
5931
6855
 
5932
6856
  *(ggml_fp16_t *) dst_ptr = GGML_FP32_TO_FP16(*(const float *) src0_ptr);
5933
6857
 
5934
- if (++i10 == dst->ne[0]) {
6858
+ if (++i10 == ne0) {
5935
6859
  i10 = 0;
5936
- if (++i11 == dst->ne[1]) {
6860
+ if (++i11 == ne1) {
5937
6861
  i11 = 0;
5938
- if (++i12 == dst->ne[2]) {
6862
+ if (++i12 == ne2) {
5939
6863
  i12 = 0;
5940
- if (++i13 == dst->ne[3]) {
6864
+ if (++i13 == ne3) {
5941
6865
  i13 = 0;
5942
6866
  }
5943
6867
  }
@@ -5945,6 +6869,19 @@ static void ggml_compute_forward_dup_f32(
5945
6869
  }
5946
6870
  }
5947
6871
  }
6872
+ i10 += ne00 * (ne01 - ir1);
6873
+ while (i10 >= ne0) {
6874
+ i10 -= ne0;
6875
+ if (++i11 == ne1) {
6876
+ i11 = 0;
6877
+ if (++i12 == ne2) {
6878
+ i12 = 0;
6879
+ if (++i13 == ne3) {
6880
+ i13 = 0;
6881
+ }
6882
+ }
6883
+ }
6884
+ }
5948
6885
  }
5949
6886
  }
5950
6887
  } else {
@@ -6191,7 +7128,7 @@ static void ggml_compute_forward_add_q_f32(
6191
7128
  GGML_ASSERT(nb1 <= nb2);
6192
7129
  GGML_ASSERT(nb2 <= nb3);
6193
7130
 
6194
- GGML_ASSERT(src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1);
7131
+ GGML_ASSERT(ggml_is_quantized(src0->type));
6195
7132
  GGML_ASSERT(dst->type == src0->type);
6196
7133
  GGML_ASSERT(src1->type == GGML_TYPE_F32);
6197
7134
 
@@ -6205,7 +7142,7 @@ static void ggml_compute_forward_add_q_f32(
6205
7142
  const int ir0 = dr*ith;
6206
7143
  const int ir1 = MIN(ir0 + dr, nr);
6207
7144
 
6208
- float * wdata = (float*) params->wdata + ne00 * ith;
7145
+ float * wdata = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
6209
7146
 
6210
7147
  for (int ir = ir0; ir < ir1; ++ir) {
6211
7148
  // src0 indices
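
The per-thread scratch pointer above now leaves CACHE_LINE_SIZE_F32 floats of padding between the threads' regions, so two threads never write into the same cache line (false sharing). A minimal sketch of that indexing, assuming a 64-byte cache line:

#include <stdio.h>

#define NE00                32   // row width in this sketch
#define CACHE_LINE_SIZE     64
#define CACHE_LINE_SIZE_F32 (CACHE_LINE_SIZE/sizeof(float))

// Each thread gets ne00 floats of scratch plus one cache line of padding,
// mirroring: wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith
static float * thread_scratch(float *wdata, int ith) {
    return wdata + (NE00 + CACHE_LINE_SIZE_F32) * ith;
}

int main(void) {
    float wdata[4 * (NE00 + CACHE_LINE_SIZE_F32)];
    for (int ith = 0; ith < 4; ++ith) {
        printf("thread %d scratch offset: %ld floats\n",
               ith, (long)(thread_scratch(wdata, ith) - wdata));  // 0, 48, 96, 144
    }
    return 0;
}
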
@@ -6261,6 +7198,11 @@ static void ggml_compute_forward_add(
6261
7198
  } break;
6262
7199
  case GGML_TYPE_Q4_0:
6263
7200
  case GGML_TYPE_Q4_1:
7201
+ case GGML_TYPE_Q4_2:
7202
+ case GGML_TYPE_Q4_3:
7203
+ case GGML_TYPE_Q5_0:
7204
+ case GGML_TYPE_Q5_1:
7205
+ case GGML_TYPE_Q8_0:
6264
7206
  {
6265
7207
  ggml_compute_forward_add_q_f32(params, src0, src1, dst);
6266
7208
  } break;
@@ -6518,15 +7460,20 @@ static void ggml_compute_forward_sum_f32(
6518
7460
  const size_t nb02 = src0->nb[2];
6519
7461
  const size_t nb03 = src0->nb[3];
6520
7462
 
7463
+ ggml_float sum = 0;
7464
+ ggml_float row_sum = 0;
7465
+
6521
7466
  for (int64_t i03 = 0; i03 < ne03; i03++) {
6522
7467
  for (int64_t i02 = 0; i02 < ne02; i02++) {
6523
7468
  for (int64_t i01 = 0; i01 < ne01; i01++) {
6524
- ggml_vec_sum_f32(ne00,
6525
- (float *) (dst->data),
7469
+ ggml_vec_sum_ggf(ne00,
7470
+ &row_sum,
6526
7471
  (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03));
7472
+ sum += row_sum;
6527
7473
  }
6528
7474
  }
6529
7475
  }
7476
+ ((float *) dst->data)[0] = sum;
6530
7477
  }
6531
7478
 
6532
7479
  static void ggml_compute_forward_sum(
@@ -7161,7 +8108,7 @@ static void ggml_compute_forward_rms_norm(
7161
8108
 
7162
8109
  // ggml_compute_forward_mul_mat
7163
8110
 
7164
- #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
8111
+ #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
7165
8112
  // helper function to determine if it is better to use BLAS or not
7166
8113
  // for large matrices, BLAS is faster
7167
8114
  static bool ggml_compute_forward_mul_mat_use_blas(
@@ -7186,6 +8133,7 @@ static bool ggml_compute_forward_mul_mat_use_blas(
7186
8133
 
7187
8134
  return false;
7188
8135
  }
8136
+
7189
8137
  #endif
7190
8138
 
7191
8139
  static void ggml_compute_forward_mul_mat_f32(
@@ -7201,7 +8149,7 @@ static void ggml_compute_forward_mul_mat_f32(
7201
8149
  const int64_t ne02 = src0->ne[2];
7202
8150
  const int64_t ne03 = src0->ne[3];
7203
8151
 
7204
- #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
8152
+ #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
7205
8153
  const int64_t ne10 = src1->ne[0];
7206
8154
  #endif
7207
8155
  const int64_t ne11 = src1->ne[1];
@@ -7258,7 +8206,7 @@ static void ggml_compute_forward_mul_mat_f32(
7258
8206
  // nb01 >= nb00 - src0 is not transposed
7259
8207
  // compute by src0 rows
7260
8208
 
7261
- #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
8209
+ #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
7262
8210
  if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
7263
8211
  if (params->ith != 0) {
7264
8212
  return;
@@ -7272,6 +8220,19 @@ static void ggml_compute_forward_mul_mat_f32(
7272
8220
  return;
7273
8221
  }
7274
8222
 
8223
+ #if defined(GGML_USE_CUBLAS)
8224
+ const float alpha = 1.0f;
8225
+ const float beta = 0.0f;
8226
+ const int x_ne = ne01 * ne10;
8227
+ const int y_ne = ne11 * ne10;
8228
+ const int d_ne = ne11 * ne01;
8229
+
8230
+ size_t x_size, y_size, d_size;
8231
+ float *d_X = ggml_cuda_pool_malloc(sizeof(float) * x_ne, &x_size);
8232
+ float *d_Y = ggml_cuda_pool_malloc(sizeof(float) * y_ne, &y_size);
8233
+ float *d_D = ggml_cuda_pool_malloc(sizeof(float) * d_ne, &d_size);
8234
+ #endif
8235
+
7275
8236
  for (int64_t i03 = 0; i03 < ne03; i03++) {
7276
8237
  for (int64_t i02 = 0; i02 < ne02; i02++) {
7277
8238
  const float * x = (float *) ((char *) src0->data + i02*nb02 + i03*nb03);
@@ -7279,15 +8240,44 @@ static void ggml_compute_forward_mul_mat_f32(
7279
8240
 
7280
8241
  float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
7281
8242
 
8243
+ #if defined(GGML_USE_CUBLAS)
8244
+ // copy data to device
8245
+ CUDA_CHECK(cudaMemcpyAsync(d_X, x, sizeof(float) * x_ne, cudaMemcpyHostToDevice, g_cudaStream));
8246
+ CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(float) * y_ne, cudaMemcpyHostToDevice, g_cudaStream));
8247
+
8248
+ // compute
8249
+ CUBLAS_CHECK(
8250
+ cublasSgemm(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
8251
+ ne01, ne11, ne10,
8252
+ &alpha, d_X, ne00,
8253
+ d_Y, ne10,
8254
+ &beta, d_D, ne01));
8255
+
8256
+ // copy data to host
8257
+ CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, g_cudaStream));
8258
+ #elif defined(GGML_USE_CLBLAST)
7282
8259
  // zT = y * xT
8260
+ ggml_cl_sgemm_wrapper(GGML_BLAS_ORDER_ROW_MAJOR, GGML_BLAS_OP_N, GGML_BLAS_OP_T,
8261
+ ne11, ne01, ne10,
8262
+ 1.0f, y, ne10,
8263
+ x, ne10,
8264
+ 0.0f, d, ne01,
8265
+ GGML_TYPE_F32);
8266
+ #else
7283
8267
  cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
7284
8268
  ne11, ne01, ne10,
7285
8269
  1.0f, y, ne10,
7286
8270
  x, ne00,
7287
8271
  0.0f, d, ne01);
8272
+ #endif
7288
8273
  }
7289
8274
  }
7290
-
8275
+ #if defined(GGML_USE_CUBLAS)
8276
+ CUDA_CHECK(cudaStreamSynchronize(g_cudaStream));
8277
+ ggml_cuda_pool_free(d_X, x_size);
8278
+ ggml_cuda_pool_free(d_Y, y_size);
8279
+ ggml_cuda_pool_free(d_D, d_size);
8280
+ #endif
7291
8281
  //printf("CBLAS F32 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);
7292
8282
 
7293
8283
  return;
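
All three branches above (cuBLAS, CLBlast, CBLAS) compute the same product: dst is src1 times src0 transposed, so each dst element is the dot product of one src1 row with one src0 row, which is what the "zT = y * xT" comment describes. A plain C reference of that product (names mirror the ne01/ne10/ne11 dimensions above, with ne00 assumed equal to ne10):

#include <stdio.h>

// dst[i][j] = sum_k y[i][k] * x[j][k]   (dst: ne11 x ne01, row-major)
static void matmul_yxT(const float *x, const float *y, float *dst,
                       int ne01, int ne10, int ne11) {
    for (int i = 0; i < ne11; ++i) {
        for (int j = 0; j < ne01; ++j) {
            float sum = 0.0f;
            for (int k = 0; k < ne10; ++k) {
                sum += y[i*ne10 + k] * x[j*ne10 + k];
            }
            dst[i*ne01 + j] = sum;
        }
    }
}

int main(void) {
    const float x[2*3] = {1, 0, 0,  0, 1, 0};   // src0: 2 rows of 3
    const float y[1*3] = {4, 5, 6};             // src1: 1 row  of 3
    float d[1*2];
    matmul_yxT(x, y, d, 2, 3, 1);
    printf("%g %g\n", d[0], d[1]);              // prints: 4 5
    return 0;
}
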
@@ -7417,7 +8407,7 @@ static void ggml_compute_forward_mul_mat_f16_f32(
7417
8407
  // nb01 >= nb00 - src0 is not transposed
7418
8408
  // compute by src0 rows
7419
8409
 
7420
- #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
8410
+ #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
7421
8411
  if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
7422
8412
  GGML_ASSERT(nb10 == sizeof(float));
7423
8413
 
@@ -7433,10 +8423,35 @@ static void ggml_compute_forward_mul_mat_f16_f32(
7433
8423
  return;
7434
8424
  }
7435
8425
 
7436
- float * const wdata = params->wdata;
8426
+ #if defined(GGML_USE_CUBLAS)
8427
+ ggml_fp16_t * const wdata = params->wdata;
8428
+
8429
+ const float alpha = 1.0f;
8430
+ const float beta = 0.0f;
8431
+ const int x_ne = ne01 * ne10;
8432
+ const int y_ne = ne11 * ne10;
8433
+ const int d_ne = ne11 * ne01;
7437
8434
 
8435
+ size_t x_size, y_size, d_size;
8436
+ float *d_X = ggml_cuda_pool_malloc(sizeof(float) * x_ne, &x_size);
8437
+ float *d_Y = ggml_cuda_pool_malloc(sizeof(float) * y_ne, &y_size);
8438
+ float *d_D = ggml_cuda_pool_malloc(sizeof(float) * d_ne, &d_size);
8439
+ #else
8440
+ float * const wdata = params->wdata;
8441
+ #endif
7438
8442
  for (int64_t i03 = 0; i03 < ne03; i03++) {
7439
8443
  for (int64_t i02 = 0; i02 < ne02; i02++) {
8444
+ #if defined(GGML_USE_CUBLAS)
8445
+ // with cuBlAS, instead of converting src0 to fp32, we convert src1 to fp16
8446
+ {
8447
+ size_t id = 0;
8448
+ for (int64_t i01 = 0; i01 < ne11; ++i01) {
8449
+ for (int64_t i00 = 0; i00 < ne10; ++i00) {
8450
+ wdata[id++] = GGML_FP32_TO_FP16(*(float *) ((char *) src1->data + i03*nb13 + i02*nb12 + i01*nb11 + i00*nb10));
8451
+ }
8452
+ }
8453
+ }
8454
+ #else
7440
8455
  {
7441
8456
  size_t id = 0;
7442
8457
  for (int64_t i01 = 0; i01 < ne01; ++i01) {
@@ -7445,7 +8460,44 @@ static void ggml_compute_forward_mul_mat_f16_f32(
7445
8460
  }
7446
8461
  }
7447
8462
  }
8463
+ #endif
8464
+
8465
+ #if defined(GGML_USE_CUBLAS)
8466
+ const ggml_fp16_t * x = (ggml_fp16_t *) ((char *) src0->data + i02*nb02 + i03*nb03);
8467
+ const ggml_fp16_t * y = (ggml_fp16_t *) wdata;
8468
+
8469
+ float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
8470
+
8471
+ // copy data to device
8472
+ CUDA_CHECK(cudaMemcpyAsync(d_X, x, sizeof(ggml_fp16_t) * x_ne, cudaMemcpyHostToDevice, g_cudaStream));
8473
+ CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(ggml_fp16_t) * y_ne, cudaMemcpyHostToDevice, g_cudaStream));
8474
+
8475
+ // compute
8476
+ CUBLAS_CHECK(
8477
+ cublasGemmEx(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
8478
+ ne01, ne11, ne10,
8479
+ &alpha, d_X, CUDA_R_16F, ne00,
8480
+ d_Y, CUDA_R_16F, ne10,
8481
+ &beta, d_D, CUDA_R_32F, ne01,
8482
+ CUBLAS_COMPUTE_32F,
8483
+ CUBLAS_GEMM_DEFAULT));
8484
+
8485
+ // copy data to host
8486
+ CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, g_cudaStream));
8487
+ #elif defined(GGML_USE_CLBLAST)
8488
+ const float * x = wdata;
8489
+ const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
7448
8490
 
8491
+ float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
8492
+
8493
+ // zT = y * xT
8494
+ ggml_cl_sgemm_wrapper(GGML_BLAS_ORDER_ROW_MAJOR, GGML_BLAS_OP_N, GGML_BLAS_OP_T,
8495
+ ne11, ne01, ne10,
8496
+ 1.0f, y, ne10,
8497
+ x, ne10,
8498
+ 0.0f, d, ne01,
8499
+ GGML_TYPE_F32);
8500
+ #else
7449
8501
  const float * x = wdata;
7450
8502
  const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
7451
8503
 
@@ -7457,9 +8509,16 @@ static void ggml_compute_forward_mul_mat_f16_f32(
7457
8509
  1.0f, y, ne10,
7458
8510
  x, ne00,
7459
8511
  0.0f, d, ne01);
8512
+ #endif
7460
8513
  }
7461
8514
  }
7462
8515
 
8516
+ #if defined(GGML_USE_CUBLAS)
8517
+ CUDA_CHECK(cudaStreamSynchronize(g_cudaStream));
8518
+ ggml_cuda_pool_free(d_X, x_size);
8519
+ ggml_cuda_pool_free(d_Y, y_size);
8520
+ ggml_cuda_pool_free(d_D, d_size);
8521
+ #endif
7463
8522
  /*printf("CBLAS F16 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);*/
7464
8523
 
7465
8524
  return;
@@ -7592,6 +8651,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
7592
8651
  const enum ggml_type type = src0->type;
7593
8652
  quantize_row_q_t const quantize_row_q_dot = quantize_fns[type].quantize_row_q_dot;
7594
8653
  vec_dot_q_t const vec_dot_q = quantize_fns[type].vec_dot_q;
8654
+ enum ggml_type const vec_dot_type = quantize_fns[type].vec_dot_type;
7595
8655
 
7596
8656
  // we don't support permuted src0 or src1
7597
8657
  GGML_ASSERT(nb00 == (int) GGML_TYPE_SIZE[type]);
@@ -7611,7 +8671,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
7611
8671
  // nb01 >= nb00 - src0 is not transposed
7612
8672
  // compute by src0 rows
7613
8673
 
7614
- #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
8674
+ #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
7615
8675
  if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
7616
8676
  if (params->ith != 0) {
7617
8677
  return;
@@ -7625,11 +8685,66 @@ static void ggml_compute_forward_mul_mat_q_f32(
7625
8685
  return;
7626
8686
  }
7627
8687
 
8688
+ #if defined(GGML_USE_CUBLAS)
8689
+ const float alpha = 1.0f;
8690
+ const float beta = 0.0f;
8691
+ const int x_ne = ne01 * ne10;
8692
+ const int y_ne = ne11 * ne10;
8693
+ const int d_ne = ne11 * ne01;
8694
+
8695
+ size_t x_size, y_size, d_size, q_size;
8696
+ float *d_X = ggml_cuda_pool_malloc(sizeof(float) * x_ne, &x_size);
8697
+ float *d_Y = ggml_cuda_pool_malloc(sizeof(float) * y_ne, &y_size);
8698
+ float *d_D = ggml_cuda_pool_malloc(sizeof(float) * d_ne, &d_size);
8699
+ float *d_Q = ggml_cuda_pool_malloc(GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type], &q_size);
8700
+
8701
+ void (*dequantize_row_q_cuda)(const void * x, float * y, int k, cudaStream_t stream) = NULL;
8702
+ if (type == GGML_TYPE_Q4_0) {
8703
+ dequantize_row_q_cuda = dequantize_row_q4_0_cuda;
8704
+ }
8705
+ else if (type == GGML_TYPE_Q4_1) {
8706
+ dequantize_row_q_cuda = dequantize_row_q4_1_cuda;
8707
+ }
8708
+ else if (type == GGML_TYPE_Q4_2) {
8709
+ dequantize_row_q_cuda = dequantize_row_q4_2_cuda;
8710
+ }
8711
+ else if (type == GGML_TYPE_Q4_3) {
8712
+ dequantize_row_q_cuda = dequantize_row_q4_3_cuda;
8713
+ }
8714
+ else if (type == GGML_TYPE_Q5_0) {
8715
+ dequantize_row_q_cuda = dequantize_row_q5_0_cuda;
8716
+ }
8717
+ else if (type == GGML_TYPE_Q5_1) {
8718
+ dequantize_row_q_cuda = dequantize_row_q5_1_cuda;
8719
+ }
8720
+ else if (type == GGML_TYPE_Q8_0) {
8721
+ dequantize_row_q_cuda = dequantize_row_q8_0_cuda;
8722
+ }
8723
+ else {
8724
+ GGML_ASSERT(false);
8725
+ }
8726
+ #elif !defined(GGML_USE_CLBLAST)
7628
8727
  float * const wdata = params->wdata;
7629
8728
  dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q;
8729
+ #endif
7630
8730
 
7631
8731
  for (int64_t i03 = 0; i03 < ne03; i03++) {
7632
8732
  for (int64_t i02 = 0; i02 < ne02; i02++) {
8733
+ const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
8734
+
8735
+ float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
8736
+
8737
+ #if defined(GGML_USE_CUBLAS)
8738
+ // copy and dequantize on device
8739
+ CUDA_CHECK(
8740
+ cudaMemcpyAsync(d_Q, (char *) src0->data + i03*nb03 + i02*nb02,
8741
+ GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type], cudaMemcpyHostToDevice, g_cudaStream));
8742
+
8743
+ dequantize_row_q_cuda(d_Q, d_X, ne01 * ne00, g_cudaStream);
8744
+ CUDA_CHECK(cudaGetLastError());
8745
+ #elif defined(GGML_USE_CLBLAST)
8746
+ const void* x = (char *) src0->data + i03*nb03 + i02*nb02;
8747
+ #else
7633
8748
  {
7634
8749
  size_t id = 0;
7635
8750
  for (int64_t i01 = 0; i01 < ne01; ++i01) {
@@ -7637,21 +8752,49 @@ static void ggml_compute_forward_mul_mat_q_f32(
7637
8752
  id += ne00;
7638
8753
  }
7639
8754
  }
7640
-
7641
8755
  const float * x = wdata;
7642
- const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
8756
+ #endif
7643
8757
 
7644
- float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
7645
8758
 
8759
+ #if defined(GGML_USE_CUBLAS)
8760
+ // copy data to device
8761
+ CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(float) * y_ne, cudaMemcpyHostToDevice, g_cudaStream));
8762
+
8763
+ // compute
8764
+ CUBLAS_CHECK(
8765
+ cublasSgemm(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
8766
+ ne01, ne11, ne10,
8767
+ &alpha, d_X, ne00,
8768
+ d_Y, ne10,
8769
+ &beta, d_D, ne01));
8770
+
8771
+ // copy data to host
8772
+ CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, g_cudaStream));
8773
+ #elif defined(GGML_USE_CLBLAST)
7646
8774
  // zT = y * xT
8775
+ ggml_cl_sgemm_wrapper(GGML_BLAS_ORDER_ROW_MAJOR, GGML_BLAS_OP_N, GGML_BLAS_OP_T,
8776
+ ne11, ne01, ne10,
8777
+ 1.0f, y, ne10,
8778
+ x, ne10,
8779
+ 0.0f, d, ne01,
8780
+ type);
8781
+ #else
7647
8782
  cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
7648
8783
  ne11, ne01, ne10,
7649
8784
  1.0f, y, ne10,
7650
8785
  x, ne00,
7651
8786
  0.0f, d, ne01);
8787
+ #endif
7652
8788
  }
7653
8789
  }
7654
8790
 
8791
+ #if defined(GGML_USE_CUBLAS)
8792
+ CUDA_CHECK(cudaStreamSynchronize(g_cudaStream));
8793
+ ggml_cuda_pool_free(d_X, x_size);
8794
+ ggml_cuda_pool_free(d_Y, y_size);
8795
+ ggml_cuda_pool_free(d_D, d_size);
8796
+ ggml_cuda_pool_free(d_Q, q_size);
8797
+ #endif
7655
8798
  //printf("CBLAS = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);
7656
8799
 
7657
8800
  return;
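
In the cuBLAS branch above, the quantized src0 is copied to the device and dequantized there by a routine selected per type through a function pointer, after which the same f32 GEMM runs for every format. A host-side sketch of that dispatch pattern (toy block layout and names, not the real ggml/CUDA symbols):

#include <stdint.h>
#include <stdio.h>

typedef void (*dequantize_fn)(const void *src, float *dst, int k);

typedef struct { float d; int8_t qs[4]; } toy_q8;   // 4-value toy block

static void dequantize_toy_q8(const void *src, float *dst, int k) {
    const toy_q8 *b = (const toy_q8 *) src;
    for (int i = 0; i < k/4; ++i)
        for (int j = 0; j < 4; ++j)
            dst[i*4 + j] = b[i].d * b[i].qs[j];   // scale * quant
}

enum { TOY_Q8, TOY_COUNT };

int main(void) {
    // choose the dequantization routine once, by type, then reuse it per block
    dequantize_fn dequantize[TOY_COUNT] = { [TOY_Q8] = dequantize_toy_q8 };

    toy_q8 block = { 0.5f, {2, -4, 6, -8} };
    float out[4];
    dequantize[TOY_Q8](&block, out, 4);
    printf("%g %g %g %g\n", out[0], out[1], out[2], out[3]);  // 1 -2 3 -4
    return 0;
}
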
@@ -7660,7 +8803,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
7660
8803
 
7661
8804
  if (params->type == GGML_TASK_INIT) {
7662
8805
  char * wdata = params->wdata;
7663
- const size_t row_size = ne10*GGML_TYPE_SIZE[GGML_TYPE_Q8_0]/GGML_BLCK_SIZE[GGML_TYPE_Q8_0];
8806
+ const size_t row_size = ne10*GGML_TYPE_SIZE[vec_dot_type]/GGML_BLCK_SIZE[vec_dot_type];
7664
8807
 
7665
8808
  for (int64_t i13 = 0; i13 < ne13; ++i13) {
7666
8809
  for (int64_t i12 = 0; i12 < ne12; ++i12) {
@@ -7691,7 +8834,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
7691
8834
  const int ir1 = MIN(ir0 + dr, nr);
7692
8835
 
7693
8836
  void * wdata = params->wdata;
7694
- const size_t row_size = ne00*GGML_TYPE_SIZE[GGML_TYPE_Q8_0]/GGML_BLCK_SIZE[GGML_TYPE_Q8_0];
8837
+ const size_t row_size = ne00*GGML_TYPE_SIZE[vec_dot_type]/GGML_BLCK_SIZE[vec_dot_type];
7695
8838
 
7696
8839
  for (int ir = ir0; ir < ir1; ++ir) {
7697
8840
  // src0 indices
@@ -7739,7 +8882,12 @@ static void ggml_compute_forward_mul_mat(
7739
8882
  switch (src0->type) {
7740
8883
  case GGML_TYPE_Q4_0:
7741
8884
  case GGML_TYPE_Q4_1:
8885
+ case GGML_TYPE_Q4_2:
8886
+ case GGML_TYPE_Q4_3:
8887
+ case GGML_TYPE_Q5_0:
8888
+ case GGML_TYPE_Q5_1:
7742
8889
  case GGML_TYPE_Q8_0:
8890
+ case GGML_TYPE_Q8_1:
7743
8891
  {
7744
8892
  ggml_compute_forward_mul_mat_q_f32(params, src0, src1, dst);
7745
8893
  } break;
@@ -7756,34 +8904,6 @@ static void ggml_compute_forward_mul_mat(
7756
8904
  GGML_ASSERT(false);
7757
8905
  } break;
7758
8906
  }
7759
-
7760
- #if 0
7761
- if (src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_Q4_1) {
7762
- static int first = 8;
7763
- printf("src0: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", src0->ne[0], src0->ne[1], src0->ne[2]);
7764
- printf("src1: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", src1->ne[0], src1->ne[1], src1->ne[2]);
7765
- printf("dst: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", dst->ne[0], dst->ne[1], dst->ne[2]);
7766
- if (first) {
7767
- --first;
7768
- } else {
7769
- for (int k = 0; k < dst->ne[1]; ++k) {
7770
- for (int j = 0; j < dst->ne[0]/16; ++j) {
7771
- for (int i = 0; i < 16; ++i) {
7772
- printf("%8.4f ", ((float *) dst->data)[k*dst->ne[0] + j*16 + i]);
7773
- }
7774
- printf("\n");
7775
- }
7776
- printf("\n");
7777
- }
7778
- printf("\n");
7779
- exit(0);
7780
- }
7781
- } else {
7782
- printf("aaaa src0: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", src0->ne[0], src0->ne[1], src0->ne[2]);
7783
- printf("aaaa src1: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", src1->ne[0], src1->ne[1], src1->ne[2]);
7784
- printf("aaaa dst: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", dst->ne[0], dst->ne[1], dst->ne[2]);
7785
- }
7786
- #endif
7787
8907
  }
7788
8908
 
7789
8909
  // ggml_compute_forward_scale
@@ -7994,7 +9114,12 @@ static void ggml_compute_forward_get_rows(
7994
9114
  switch (src0->type) {
7995
9115
  case GGML_TYPE_Q4_0:
7996
9116
  case GGML_TYPE_Q4_1:
9117
+ case GGML_TYPE_Q4_2:
9118
+ case GGML_TYPE_Q4_3:
9119
+ case GGML_TYPE_Q5_0:
9120
+ case GGML_TYPE_Q5_1:
7997
9121
  case GGML_TYPE_Q8_0:
9122
+ case GGML_TYPE_Q8_1:
7998
9123
  {
7999
9124
  ggml_compute_forward_get_rows_q(params, src0, src1, dst);
8000
9125
  } break;
@@ -8132,6 +9257,7 @@ static void ggml_compute_forward_soft_max_f32(
8132
9257
 
8133
9258
  uint16_t scvt;
8134
9259
  for (int i = 0; i < nc; i++) {
9260
+ //printf("p[%3d] = %8.4f\n", i, p[i]);
8135
9261
  if (p[i] == -INFINITY) {
8136
9262
  p[i] = 0.0f;
8137
9263
  } else {
@@ -8224,9 +9350,11 @@ static void ggml_compute_forward_rope_f32(
8224
9350
 
8225
9351
  const float theta_scale = powf(10000.0, -2.0f/n_dims);
8226
9352
 
9353
+ const bool is_neox = mode & 2;
9354
+
8227
9355
  for (int64_t i3 = 0; i3 < ne3; i3++) {
8228
- for (int64_t i2 = (mode == 0 ? 0 : n_past); i2 < ne2; i2++) {
8229
- const int p = (mode == 0 ? n_past + i2 : i2);
9356
+ for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) {
9357
+ const int p = ((mode & 1) == 0 ? n_past + i2 : i2);
8230
9358
  for (int64_t i1 = 0; i1 < ne1; i1++) {
8231
9359
  if (ir++ < ir0) continue;
8232
9360
  if (ir > ir1) break;
@@ -8239,14 +9367,25 @@ static void ggml_compute_forward_rope_f32(
8239
9367
 
8240
9368
  theta *= theta_scale;
8241
9369
 
8242
- const float * const src = (float *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
8243
- float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
9370
+ if (!is_neox) {
9371
+ const float * const src = (float *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
9372
+ float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
9373
+
9374
+ const float x0 = src[0];
9375
+ const float x1 = src[1];
9376
+
9377
+ dst_data[0] = x0*cos_theta - x1*sin_theta;
9378
+ dst_data[1] = x0*sin_theta + x1*cos_theta;
9379
+ } else {
9380
+ const float * const src = (float *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + (i0/2)*nb0);
9381
+ float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + (i0/2)*nb0);
8244
9382
 
8245
- const float x0 = src[0];
8246
- const float x1 = src[1];
9383
+ const float x0 = src[0];
9384
+ const float x1 = src[n_dims/2];
8247
9385
 
8248
- dst_data[0] = x0*cos_theta - x1*sin_theta;
8249
- dst_data[1] = x0*sin_theta + x1*cos_theta;
9386
+ dst_data[0] = x0*cos_theta - x1*sin_theta;
9387
+ dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
9388
+ }
8250
9389
  }
8251
9390
  }
8252
9391
  }
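
The new is_neox branch above (selected by bit 2 of mode) changes only the pairing: the original mode rotates adjacent elements (2j, 2j + 1), while the NeoX mode rotates elements half a head apart (j, j + n_dims/2). The 2-D rotation applied to each pair is identical; a minimal sketch (compile with -lm):

#include <math.h>
#include <stdio.h>

// Rotate one RoPE pair (x0, x1) by angle theta; both code paths above use this form.
static void rope_pair(float *x0, float *x1, float theta) {
    const float c = cosf(theta), s = sinf(theta);
    const float a = *x0, b = *x1;
    *x0 = a*c - b*s;
    *x1 = a*s + b*c;
}

int main(void) {
    // toy row of n_dims = 8 values
    float row[8] = {1, 2, 3, 4, 5, 6, 7, 8};
    const int   n_dims = 8;
    const float theta  = 0.5f;

    float a[8], b[8];
    for (int i = 0; i < 8; ++i) { a[i] = row[i]; b[i] = row[i]; }

    rope_pair(&a[0], &a[1], theta);           // original pairing: (0, 1)
    rope_pair(&b[0], &b[n_dims/2], theta);    // NeoX pairing: (0, n_dims/2) = (0, 4)
    printf("%.3f %.3f | %.3f %.3f\n", a[0], a[1], b[0], b[n_dims/2]);
    return 0;
}
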
@@ -8301,9 +9440,11 @@ static void ggml_compute_forward_rope_f16(
8301
9440
 
8302
9441
  const float theta_scale = powf(10000.0, -2.0f/n_dims);
8303
9442
 
9443
+ const bool is_neox = mode & 2;
9444
+
8304
9445
  for (int64_t i3 = 0; i3 < ne3; i3++) {
8305
- for (int64_t i2 = (mode == 0 ? 0 : n_past); i2 < ne2; i2++) {
8306
- const int p = (mode == 0 ? n_past + i2 : i2);
9446
+ for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) {
9447
+ const int p = ((mode & 1) == 0 ? n_past + i2 : i2);
8307
9448
  for (int64_t i1 = 0; i1 < ne1; i1++) {
8308
9449
  if (ir++ < ir0) continue;
8309
9450
  if (ir > ir1) break;
@@ -8316,14 +9457,25 @@ static void ggml_compute_forward_rope_f16(
8316
9457
 
8317
9458
  theta *= theta_scale;
8318
9459
 
8319
- const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
8320
- ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
9460
+ if (!is_neox) {
9461
+ const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
9462
+ ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
9463
+
9464
+ const float x0 = GGML_FP16_TO_FP32(src[0]);
9465
+ const float x1 = GGML_FP16_TO_FP32(src[1]);
9466
+
9467
+ dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
9468
+ dst_data[1] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
9469
+ } else {
9470
+ const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + (i0/2)*nb0);
9471
+ ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + (i0/2)*nb0);
8321
9472
 
8322
- const float x0 = GGML_FP16_TO_FP32(src[0]);
8323
- const float x1 = GGML_FP16_TO_FP32(src[1]);
9473
+ const float x0 = GGML_FP16_TO_FP32(src[0]);
9474
+ const float x1 = GGML_FP16_TO_FP32(src[n_dims/2]);
8324
9475
 
8325
- dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
8326
- dst_data[1] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
9476
+ dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
9477
+ dst_data[n_dims/2] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
9478
+ }
8327
9479
  }
8328
9480
  }
8329
9481
  }
@@ -10402,11 +11554,11 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
10402
11554
  case GGML_OP_CPY:
10403
11555
  case GGML_OP_DUP:
10404
11556
  {
10405
- node->n_tasks = 1;
11557
+ node->n_tasks = n_threads;
10406
11558
 
10407
11559
  size_t cur = 0;
10408
- if (node->type == GGML_TYPE_Q4_0 || node->type == GGML_TYPE_Q4_1) {
10409
- cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->ne[0];
11560
+ if (ggml_is_quantized(node->type)) {
11561
+ cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->ne[0] * n_threads;
10410
11562
  }
10411
11563
 
10412
11564
  work_size = MAX(work_size, cur);
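Since CPY/DUP of quantized tensors now runs on n_threads tasks, the scratch estimate above reserves one dequantized F32 row per thread rather than a single shared row. A rough sketch of that sizing rule (the helper name is hypothetical):

#include <stddef.h>

// Sketch: scratch bytes for a parallel copy of a quantized tensor --
// one dequantized F32 row of ne0 elements per worker thread.
static size_t cpy_work_size_sketch(size_t ne0, int n_threads, int is_quantized) {
    return is_quantized ? sizeof(float) * ne0 * (size_t) n_threads : 0;
}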
@@ -10417,7 +11569,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
10417
11569
 
10418
11570
  size_t cur = 0;
10419
11571
 
10420
- if (node->src0->type == GGML_TYPE_Q4_0 || node->src0->type == GGML_TYPE_Q4_1) {
11572
+ if (ggml_is_quantized(node->src0->type)) {
10421
11573
  cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->src0->ne[0] * n_threads;
10422
11574
  }
10423
11575
 
@@ -10466,7 +11618,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
10466
11618
  size_t cur = 0;
10467
11619
 
10468
11620
  if (node->src0->type == GGML_TYPE_F16 && node->src1->type == GGML_TYPE_F32) {
10469
- #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
11621
+ #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
10470
11622
  if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
10471
11623
  node->n_tasks = 1; // TODO: this actually is doing nothing
10472
11624
  // the threads are still spinning
@@ -10482,15 +11634,16 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
10482
11634
  #endif
10483
11635
  } else if (node->src0->type == GGML_TYPE_F32 && node->src1->type == GGML_TYPE_F32) {
10484
11636
  cur = 0;
10485
- } else if (quantize_fns[node->src0->type].vec_dot_q && node->src1->type == GGML_TYPE_F32) {
10486
- #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
11637
+ } else if (ggml_is_quantized(node->src0->type) && node->src1->type == GGML_TYPE_F32) {
11638
+ #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
10487
11639
  if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
10488
11640
  node->n_tasks = 1;
10489
11641
  cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]);
10490
11642
  } else
10491
11643
  #endif
10492
11644
  {
10493
- cur = GGML_TYPE_SIZE[GGML_TYPE_Q8_0]*ggml_nelements(node->src1)/GGML_BLCK_SIZE[GGML_TYPE_Q8_0];
11645
+ const enum ggml_type type_q = quantize_fns[node->src0->type].vec_dot_type;
11646
+ cur = GGML_TYPE_SIZE[type_q]*ggml_nelements(node->src1)/GGML_BLCK_SIZE[type_q];
10494
11647
  }
10495
11648
  } else {
10496
11649
  GGML_ASSERT(false);
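The mul_mat scratch above is no longer hard-coded to Q8_0: each quantized weight type reports the quantized format its dot-product kernel consumes (vec_dot_type), and src1 is quantized into a buffer sized from that format's block layout. A reduced, table-driven sketch of the idea; the enum, struct, and byte counts below are placeholders rather than ggml's real GGML_TYPE_SIZE / GGML_BLCK_SIZE tables:

#include <stddef.h>

// Reduced illustration: each quantized weight type names the quantized
// activation format its dot-product kernel expects, and the src1 scratch
// buffer is sized from that format's block layout.
enum qtype { QT_Q8_0, QT_Q8_1, QT_COUNT };

struct qtraits {
    size_t block_bytes; // placeholder for GGML_TYPE_SIZE[type_q]
    int    block_elems; // placeholder for GGML_BLCK_SIZE[type_q]
};

static const struct qtraits TRAITS[QT_COUNT] = {
    [QT_Q8_0] = { 36, 32 }, // illustrative numbers only
    [QT_Q8_1] = { 40, 32 },
};

static size_t src1_quant_scratch(enum qtype vec_dot_type, size_t n_src1_elems) {
    const struct qtraits t = TRAITS[vec_dot_type];
    return t.block_bytes * (n_src1_elems / t.block_elems);
}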
@@ -10818,9 +11971,9 @@ void ggml_graph_print(const struct ggml_cgraph * cgraph) {
10818
11971
  for (int i = 0; i < cgraph->n_nodes; i++) {
10819
11972
  struct ggml_tensor * node = cgraph->nodes[i];
10820
11973
 
10821
- perf_total_per_op_us[node->op] += node->perf_time_us;
11974
+ perf_total_per_op_us[node->op] += MAX(1, node->perf_time_us);
10822
11975
 
10823
- GGML_PRINT(" - %3d: [ %" PRId64 ", %" PRId64 ", %" PRId64 "] %16s %s (%3d) cpu = %7.3f / %7.3f ms, wall = %7.3f / %7.3f ms\n",
11976
+ GGML_PRINT(" - %3d: [ %5" PRId64 ", %5" PRId64 ", %5" PRId64 "] %16s %s (%3d) cpu = %7.3f / %7.3f ms, wall = %7.3f / %7.3f ms\n",
10824
11977
  i,
10825
11978
  node->ne[0], node->ne[1], node->ne[2],
10826
11979
  GGML_OP_LABEL[node->op], node->is_param ? "x" : node->grad ? "g" : " ", node->perf_runs,
@@ -10834,13 +11987,17 @@ void ggml_graph_print(const struct ggml_cgraph * cgraph) {
10834
11987
  for (int i = 0; i < cgraph->n_leafs; i++) {
10835
11988
  struct ggml_tensor * node = cgraph->leafs[i];
10836
11989
 
10837
- GGML_PRINT(" - %3d: [ %" PRId64 ", %" PRId64 "] %8s\n",
11990
+ GGML_PRINT(" - %3d: [ %5" PRId64 ", %5" PRId64 "] %8s\n",
10838
11991
  i,
10839
11992
  node->ne[0], node->ne[1],
10840
11993
  GGML_OP_LABEL[node->op]);
10841
11994
  }
10842
11995
 
10843
11996
  for (int i = 0; i < GGML_OP_COUNT; i++) {
11997
+ if (perf_total_per_op_us[i] == 0) {
11998
+ continue;
11999
+ }
12000
+
10844
12001
  GGML_PRINT("perf_total_per_op_us[%16s] = %7.3f ms\n", GGML_OP_LABEL[i], (double) perf_total_per_op_us[i] / 1000.0);
10845
12002
  }
10846
12003
 
@@ -11674,7 +12831,7 @@ size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t *
11674
12831
 
11675
12832
  for (int i = 0; i < nb; i++) {
11676
12833
  for (int l = 0; l < QK4_0; l += 2) {
11677
- const uint8_t vi0 = y[i].qs[l/2] & 0xF;
12834
+ const uint8_t vi0 = y[i].qs[l/2] & 0x0F;
11678
12835
  const uint8_t vi1 = y[i].qs[l/2] >> 4;
11679
12836
 
11680
12837
  hist[vi0]++;
@@ -11697,7 +12854,7 @@ size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t *
11697
12854
 
11698
12855
  for (int i = 0; i < nb; i++) {
11699
12856
  for (int l = 0; l < QK4_1; l += 2) {
11700
- const uint8_t vi0 = y[i].qs[l/2] & 0xF;
12857
+ const uint8_t vi0 = y[i].qs[l/2] & 0x0F;
11701
12858
  const uint8_t vi1 = y[i].qs[l/2] >> 4;
11702
12859
 
11703
12860
  hist[vi0]++;
@@ -11709,6 +12866,184 @@ size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t *
11709
12866
  return (n/QK4_1*sizeof(block_q4_1));
11710
12867
  }
11711
12868
 
12869
+ size_t ggml_quantize_q4_2(const float * src, void * dst, int n, int k, int64_t * hist) {
12870
+ assert(k % QK4_2 == 0);
12871
+ const int nb = k / QK4_2;
12872
+
12873
+ for (int j = 0; j < n; j += k) {
12874
+ block_q4_2 * restrict y = (block_q4_2 *)dst + j/QK4_2;
12875
+
12876
+ quantize_row_q4_2_reference(src + j, y, k);
12877
+
12878
+ for (int i = 0; i < nb; i++) {
12879
+ for (int l = 0; l < QK4_2; l += 2) {
12880
+ const uint8_t vi0 = y[i].qs[l/2] & 0x0F;
12881
+ const uint8_t vi1 = y[i].qs[l/2] >> 4;
12882
+
12883
+ hist[vi0]++;
12884
+ hist[vi1]++;
12885
+ }
12886
+ }
12887
+ }
12888
+
12889
+ return (n/QK4_2*sizeof(block_q4_2));
12890
+ }
12891
+
12892
+ size_t ggml_quantize_q4_3(const float * src, void * dst, int n, int k, int64_t * hist) {
12893
+ assert(k % QK4_3 == 0);
12894
+ const int nb = k / QK4_3;
12895
+
12896
+ for (int j = 0; j < n; j += k) {
12897
+ block_q4_3 * restrict y = (block_q4_3 *)dst + j/QK4_3;
12898
+
12899
+ quantize_row_q4_3_reference(src + j, y, k);
12900
+
12901
+ for (int i = 0; i < nb; i++) {
12902
+ for (int l = 0; l < QK4_3; l += 2) {
12903
+ const uint8_t vi0 = y[i].qs[l/2] & 0x0F;
12904
+ const uint8_t vi1 = y[i].qs[l/2] >> 4;
12905
+
12906
+ hist[vi0]++;
12907
+ hist[vi1]++;
12908
+ }
12909
+ }
12910
+ }
12911
+
12912
+ return (n/QK4_3*sizeof(block_q4_3));
12913
+ }
12914
+
12915
+ size_t ggml_quantize_q5_0(const float * src, void * dst, int n, int k, int64_t * hist) {
12916
+ assert(k % QK5_0 == 0);
12917
+ const int nb = k / QK5_0;
12918
+
12919
+ for (int j = 0; j < n; j += k) {
12920
+ block_q5_0 * restrict y = (block_q5_0 *)dst + j/QK5_0;
12921
+
12922
+ quantize_row_q5_0_reference(src + j, y, k);
12923
+
12924
+ for (int i = 0; i < nb; i++) {
12925
+ uint32_t qh;
12926
+ memcpy(&qh, &y[i].qh, sizeof(qh));
12927
+
12928
+ for (int l = 0; l < QK5_0; l += 2) {
12929
+ const uint8_t vh0 = ((qh & (1 << (l + 0))) >> (l + 0)) << 4;
12930
+ const uint8_t vh1 = ((qh & (1 << (l + 1))) >> (l + 1)) << 4;
12931
+
12932
+ // cast to 16 bins
12933
+ const uint8_t vi0 = ((y[i].qs[l/2] & 0x0F) | vh0) / 2;
12934
+ const uint8_t vi1 = ((y[i].qs[l/2] >> 4) | vh1) / 2;
12935
+
12936
+ hist[vi0]++;
12937
+ hist[vi1]++;
12938
+ }
12939
+ }
12940
+ }
12941
+
12942
+ return (n/QK5_0*sizeof(block_q5_0));
12943
+ }
12944
+
12945
+ size_t ggml_quantize_q5_1(const float * src, void * dst, int n, int k, int64_t * hist) {
12946
+ assert(k % QK5_1 == 0);
12947
+ const int nb = k / QK5_1;
12948
+
12949
+ for (int j = 0; j < n; j += k) {
12950
+ block_q5_1 * restrict y = (block_q5_1 *)dst + j/QK5_1;
12951
+
12952
+ quantize_row_q5_1_reference(src + j, y, k);
12953
+
12954
+ for (int i = 0; i < nb; i++) {
12955
+ uint32_t qh;
12956
+ memcpy(&qh, &y[i].qh, sizeof(qh));
12957
+
12958
+ for (int l = 0; l < QK5_1; l += 2) {
12959
+ const uint8_t vh0 = ((qh & (1 << (l + 0))) >> (l + 0)) << 4;
12960
+ const uint8_t vh1 = ((qh & (1 << (l + 1))) >> (l + 1)) << 4;
12961
+
12962
+ // cast to 16 bins
12963
+ const uint8_t vi0 = ((y[i].qs[l/2] & 0x0F) | vh0) / 2;
12964
+ const uint8_t vi1 = ((y[i].qs[l/2] >> 4) | vh1) / 2;
12965
+
12966
+ hist[vi0]++;
12967
+ hist[vi1]++;
12968
+ }
12969
+ }
12970
+ }
12971
+
12972
+ return (n/QK5_1*sizeof(block_q5_1));
12973
+ }
12974
+
12975
+ size_t ggml_quantize_q8_0(const float * src, void * dst, int n, int k, int64_t * hist) {
12976
+ assert(k % QK8_0 == 0);
12977
+ const int nb = k / QK8_0;
12978
+
12979
+ for (int j = 0; j < n; j += k) {
12980
+ block_q8_0 * restrict y = (block_q8_0 *)dst + j/QK8_0;
12981
+
12982
+ quantize_row_q8_0_reference(src + j, y, k);
12983
+
12984
+ for (int i = 0; i < nb; i++) {
12985
+ for (int l = 0; l < QK8_0; ++l) {
12986
+ const int8_t vi = y[i].qs[l];
12987
+
12988
+ hist[vi/16 + 8]++;
12989
+ }
12990
+ }
12991
+ }
12992
+
12993
+ return (n/QK8_0*sizeof(block_q8_0));
12994
+ }
12995
+
12996
+ size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, int start, int n, int64_t * hist) {
12997
+ size_t result = 0;
12998
+ switch (type) {
12999
+ case GGML_TYPE_Q4_0:
13000
+ {
13001
+ GGML_ASSERT(start % QK4_0 == 0);
13002
+ block_q4_0 * block = (block_q4_0*)dst + start / QK4_0;
13003
+ result = ggml_quantize_q4_0(src + start, block, n, n, hist);
13004
+ } break;
13005
+ case GGML_TYPE_Q4_1:
13006
+ {
13007
+ GGML_ASSERT(start % QK4_1 == 0);
13008
+ block_q4_1 * block = (block_q4_1*)dst + start / QK4_1;
13009
+ result = ggml_quantize_q4_1(src + start, block, n, n, hist);
13010
+ } break;
13011
+ case GGML_TYPE_Q4_2:
13012
+ {
13013
+ GGML_ASSERT(start % QK4_2 == 0);
13014
+ block_q4_2 * block = (block_q4_2*)dst + start / QK4_2;
13015
+ result = ggml_quantize_q4_2(src + start, block, n, n, hist);
13016
+ } break;
13017
+ case GGML_TYPE_Q4_3:
13018
+ {
13019
+ GGML_ASSERT(start % QK4_3 == 0);
13020
+ block_q4_3 * block = (block_q4_3*)dst + start / QK4_3;
13021
+ result = ggml_quantize_q4_3(src + start, block, n, n, hist);
13022
+ } break;
13023
+ case GGML_TYPE_Q5_0:
13024
+ {
13025
+ GGML_ASSERT(start % QK5_0 == 0);
13026
+ block_q5_0 * block = (block_q5_0*)dst + start / QK5_0;
13027
+ result = ggml_quantize_q5_0(src + start, block, n, n, hist);
13028
+ } break;
13029
+ case GGML_TYPE_Q5_1:
13030
+ {
13031
+ GGML_ASSERT(start % QK5_1 == 0);
13032
+ block_q5_1 * block = (block_q5_1*)dst + start / QK5_1;
13033
+ result = ggml_quantize_q5_1(src + start, block, n, n, hist);
13034
+ } break;
13035
+ case GGML_TYPE_Q8_0:
13036
+ {
13037
+ GGML_ASSERT(start % QK8_0 == 0);
13038
+ block_q8_0 * block = (block_q8_0*)dst + start / QK8_0;
13039
+ result = ggml_quantize_q8_0(src + start, block, n, n, hist);
13040
+ } break;
13041
+ default:
13042
+ assert(false);
13043
+ }
13044
+ return result;
13045
+ }
13046
+
11712
13047
  ////////////////////////////////////////////////////////////////////////////////
11713
13048
 
11714
13049
  int ggml_cpu_has_avx(void) {
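In the q5_0/q5_1 histogram loops added above, each 5-bit code is rebuilt by OR-ing the stored low nibble with its high bit from qh (shifted into bit 4), and the 32 resulting levels are folded into the 16 histogram bins by halving. A self-contained sketch of that unpacking for one 32-element block, with an illustrative block struct standing in for the real block_q5_x layout:

#include <stdint.h>

// Toy 5-bit block mirroring the shape used above: 32 values stored as
// 16 bytes of low nibbles plus a 32-bit mask of high bits (names illustrative).
struct q5_block_sketch {
    uint8_t  qs[16]; // low 4 bits of each value, two values per byte
    uint32_t qh;     // bit l holds the 5th bit of value l
};

// Rebuild each 5-bit code (0..31) and fold the 32 levels into the
// 16 histogram bins by halving, exactly as the loops above do.
static void hist_q5_sketch(const struct q5_block_sketch * b, int64_t hist[16]) {
    for (int l = 0; l < 32; l += 2) {
        const uint8_t vh0 = (uint8_t)(((b->qh >> (l + 0)) & 1) << 4);
        const uint8_t vh1 = (uint8_t)(((b->qh >> (l + 1)) & 1) << 4);

        hist[((b->qs[l/2] & 0x0F) | vh0) / 2]++;
        hist[((b->qs[l/2] >>   4) | vh1) / 2]++;
    }
}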
@@ -11800,13 +13135,33 @@ int ggml_cpu_has_wasm_simd(void) {
11800
13135
  }
11801
13136
 
11802
13137
  int ggml_cpu_has_blas(void) {
11803
- #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
13138
+ #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
13139
+ return 1;
13140
+ #else
13141
+ return 0;
13142
+ #endif
13143
+ }
13144
+
13145
+ int ggml_cpu_has_cublas(void) {
13146
+ #if defined(GGML_USE_CUBLAS)
11804
13147
  return 1;
11805
13148
  #else
11806
13149
  return 0;
11807
13150
  #endif
11808
13151
  }
11809
13152
 
13153
+ int ggml_cpu_has_clblast(void) {
13154
+ #if defined(GGML_USE_CLBLAST)
13155
+ return 1;
13156
+ #else
13157
+ return 0;
13158
+ #endif
13159
+ }
13160
+
13161
+ int ggml_cpu_has_gpublas(void) {
13162
+ return ggml_cpu_has_cublas() || ggml_cpu_has_clblast();
13163
+ }
13164
+
11810
13165
  int ggml_cpu_has_sse3(void) {
11811
13166
  #if defined(__SSE3__)
11812
13167
  return 1;