llama_cpp 0.0.6 → 0.0.7

@@ -148,44 +148,9 @@ inline static void* ggml_aligned_malloc(size_t size) {
148
148
  #elif defined(GGML_USE_OPENBLAS)
149
149
  #include <cblas.h>
150
150
  #elif defined(GGML_USE_CUBLAS)
151
- #include <cublas_v2.h>
152
- #include <cuda_runtime.h>
153
151
  #include "ggml-cuda.h"
154
-
155
- #define CUDA_CHECK(err) \
156
- do { \
157
- cudaError_t err_ = (err); \
158
- if (err_ != cudaSuccess) { \
159
- printf("CUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__, \
160
- cudaGetErrorString(err_)); \
161
- exit(1); \
162
- } \
163
- } while (0)
164
-
165
- #define CUBLAS_CHECK(err) \
166
- do { \
167
- cublasStatus_t err_ = (err); \
168
- if (err_ != CUBLAS_STATUS_SUCCESS) { \
169
- printf("cuBLAS error %d at %s:%d\n", err_, __FILE__, __LINE__); \
170
- exit(1); \
171
- } \
172
- } while (0)
173
-
174
- static cublasHandle_t cublasH = NULL;
175
- static cudaStream_t cudaStream = NULL;
176
- static void init_cublas(void) {
177
- if (cublasH == NULL) {
178
- // create cublas handle, bind a stream
179
- CUBLAS_CHECK(cublasCreate(&cublasH));
180
-
181
- CUDA_CHECK(cudaStreamCreateWithFlags(&cudaStream, cudaStreamNonBlocking));
182
-
183
- CUBLAS_CHECK(cublasSetStream(cublasH, cudaStream));
184
-
185
- // configure logging to stdout
186
- // CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, NULL));
187
- }
188
- }
152
+ #elif defined(GGML_USE_CLBLAST)
153
+ #include "ggml-opencl.h"
189
154
  #endif
190
155
 
191
156
  #undef MIN
@@ -365,6 +330,20 @@ static ggml_fp16_t table_exp_f16[1 << 16];
365
330
  // precomputed f32 table for f16 (256 KB)
366
331
  static float table_f32_f16[1 << 16];
367
332
 
333
+ #if defined(__ARM_NEON)
334
+ #define B1(c,s,n) 0x ## n ## c , 0x ## n ## s
335
+ #define B2(c,s,n) B1(c,s,n ## c), B1(c,s,n ## s)
336
+ #define B3(c,s,n) B2(c,s,n ## c), B2(c,s,n ## s)
337
+ #define B4(c,s,n) B3(c,s,n ## c), B3(c,s,n ## s)
338
+ #define B5(c,s,n) B4(c,s,n ## c), B4(c,s,n ## s)
339
+ #define B6(c,s,n) B5(c,s,n ## c), B5(c,s,n ## s)
340
+ #define B7(c,s,n) B6(c,s,n ## c), B6(c,s,n ## s)
341
+ #define B8(c,s ) B7(c,s, c), B7(c,s, s)
342
+
343
+ // precomputed tables for expanding 8bits to 8 bytes (shl 4)
344
+ static const uint64_t table_b2b_u[1 << 8] = { B8(00, 10) };
345
+ #endif
346
+
368
347
  // On ARM NEON, it's quicker to directly convert x -> x instead of calling into ggml_lookup_fp16_to_fp32,
369
348
  // so we define GGML_FP16_TO_FP32 and GGML_FP32_TO_FP16 elsewhere for NEON.
370
349
  // This is also true for POWER9.
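Note on the new NEON table above: assuming the token-pasting in B1..B8 is read correctly, byte j of table_b2b_u[i] is 0x10 when bit j of i is set and 0x00 otherwise, i.e. each input bit is expanded into a full byte that already carries the "shl 4". A scalar construction of the same table, for illustration only:

#include <stdint.h>

static void build_table_b2b_u(uint64_t table[256]) {
    for (int i = 0; i < 256; ++i) {
        uint64_t v = 0;
        for (int j = 0; j < 8; ++j) {
            if (i & (1 << j)) {
                v |= (uint64_t)0x10 << (8*j); // bit j of the index -> byte j of the entry
            }
        }
        table[i] = v;
    }
}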
@@ -473,7 +452,7 @@ static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
473
452
  static inline __m128i bytes_from_nibbles_16(const uint8_t * rsi)
474
453
  {
475
454
  // Load 8 bytes from memory
476
- __m128i tmp = _mm_loadu_si64( ( const __m128i* )rsi );
455
+ __m128i tmp = _mm_loadl_epi64( ( const __m128i* )rsi );
477
456
 
478
457
  // Expand bytes into uint16_t values
479
458
  __m128i bytes = _mm_cvtepu8_epi16( tmp );
@@ -487,7 +466,46 @@ static inline __m128i bytes_from_nibbles_16(const uint8_t * rsi)
487
466
  return bytes;
488
467
  }
489
468
 
469
+ // horizontally add 8 floats
470
+ static inline float hsum_float_8(const __m256 x) {
471
+ __m128 res = _mm256_extractf128_ps(x, 1);
472
+ res = _mm_add_ps(res, _mm256_castps256_ps128(x));
473
+ res = _mm_add_ps(res, _mm_movehl_ps(res, res));
474
+ res = _mm_add_ss(res, _mm_movehdup_ps(res));
475
+ return _mm_cvtss_f32(res);
476
+ }
477
+
478
+ // horizontally add 8 int32_t
479
+ static inline int hsum_i32_8(const __m256i a) {
480
+ const __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(a), _mm256_extractf128_si256(a, 1));
481
+ const __m128i hi64 = _mm_unpackhi_epi64(sum128, sum128);
482
+ const __m128i sum64 = _mm_add_epi32(hi64, sum128);
483
+ const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
484
+ return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
485
+ }
486
+
487
+ // horizontally add 4 int32_t
488
+ static inline int hsum_i32_4(const __m128i a) {
489
+ const __m128i hi64 = _mm_unpackhi_epi64(a, a);
490
+ const __m128i sum64 = _mm_add_epi32(hi64, a);
491
+ const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
492
+ return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
493
+ }
494
+
490
495
  #if __AVX2__ || __AVX512F__
496
+ // spread 32 bits to 32 bytes { 0x00, 0xFF }
497
+ static inline __m256i bytes_from_bits_32(const uint8_t * x) {
498
+ uint32_t x32;
499
+ memcpy(&x32, x, sizeof(uint32_t));
500
+ const __m256i shuf_mask = _mm256_set_epi64x(
501
+ 0x0303030303030303, 0x0202020202020202,
502
+ 0x0101010101010101, 0x0000000000000000);
503
+ __m256i bytes = _mm256_shuffle_epi8(_mm256_set1_epi32(x32), shuf_mask);
504
+ const __m256i bit_mask = _mm256_set1_epi64x(0x7fbfdfeff7fbfdfe);
505
+ bytes = _mm256_or_si256(bytes, bit_mask);
506
+ return _mm256_cmpeq_epi8(bytes, _mm256_set1_epi64x(-1));
507
+ }
508
+
491
509
  // Unpack 32 4-bit fields into 32 bytes
492
510
  // The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
493
511
  static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)
@@ -507,9 +525,38 @@ static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)
507
525
  return bytes;
508
526
  }
509
527
 
528
+ // add int16_t pairwise and return as float vector
529
+ static inline __m256 sum_i16_pairs_float(const __m256i x) {
530
+ const __m256i ones = _mm256_set1_epi16(1);
531
+ const __m256i summed_pairs = _mm256_madd_epi16(ones, x);
532
+ return _mm256_cvtepi32_ps(summed_pairs);
533
+ }
534
+
535
+ // multiply int8_t, add results pairwise twice and return as float vector
536
+ static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
537
+ // Get absolute values of x vectors
538
+ const __m256i ax = _mm256_sign_epi8(x, x);
539
+ // Sign the values of the y vectors
540
+ const __m256i sy = _mm256_sign_epi8(y, x);
541
+ #if __AVXVNNI__
542
+ const __m256i zero = _mm256_setzero_si256();
543
+ const __m256i summed_pairs = _mm256_dpbusd_epi32(zero, ax, sy);
544
+ return _mm256_cvtepi32_ps(summed_pairs);
545
+ #else
546
+ // Perform multiplication and create 16-bit values
547
+ const __m256i dot = _mm256_maddubs_epi16(ax, sy);
548
+ return sum_i16_pairs_float(dot);
549
+ #endif
550
+ }
551
+
510
552
  static inline __m128i packNibbles( __m256i bytes )
511
553
  {
512
554
  // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
555
+ #if __AVX512F__
556
+ const __m256i bytes_srli_4 = _mm256_srli_epi16(bytes, 4); // 0000_0000_abcd_0000
557
+ bytes = _mm256_or_si256(bytes, bytes_srli_4); // 0000_abcd_abcd_efgh
558
+ return _mm256_cvtepi16_epi8(bytes); // abcd_efgh
559
+ #else
513
560
  const __m256i lowByte = _mm256_set1_epi16( 0xFF );
514
561
  __m256i high = _mm256_andnot_si256( lowByte, bytes );
515
562
  __m256i low = _mm256_and_si256( lowByte, bytes );
@@ -520,6 +567,7 @@ static inline __m128i packNibbles( __m256i bytes )
520
567
  __m128i r0 = _mm256_castsi256_si128( bytes );
521
568
  __m128i r1 = _mm256_extracti128_si256( bytes, 1 );
522
569
  return _mm_packus_epi16( r0, r1 );
570
+ #endif
523
571
  }
524
572
  #else
525
573
  static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
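Scalar equivalents of the new AVX helpers introduced above, for illustration only (not part of the diff): hsum_float_8 reduces the eight lanes of a __m256 to a single float, and mul_sum_i8_pairs_float produces, per 32-bit lane, the sum of the four signed byte products x[j]*y[j] as a float; the abs/sign dance around _mm256_maddubs_epi16 is what makes the signed multiply possible on AVX2.

#include <stdint.h>

static inline float hsum_float_8_scalar(const float v[8]) {
    float s = 0.0f;
    for (int j = 0; j < 8; ++j) s += v[j];
    return s;
}

static inline void mul_sum_i8_pairs_float_scalar(const int8_t x[32], const int8_t y[32], float out[8]) {
    for (int lane = 0; lane < 8; ++lane) {
        int32_t acc = 0;
        for (int j = 0; j < 4; ++j) {
            acc += (int32_t)x[4*lane + j] * (int32_t)y[4*lane + j];
        }
        out[lane] = (float)acc;
    }
}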
@@ -654,6 +702,23 @@ typedef struct {
654
702
  } block_q4_3;
655
703
  static_assert(sizeof(block_q4_3) == 2 * sizeof(ggml_fp16_t) + QK4_3 / 2, "wrong q4_3 block size/padding");
656
704
 
705
+ #define QK5_0 32
706
+ typedef struct {
707
+ ggml_fp16_t d; // delta
708
+ uint8_t qh[4]; // 5-th bit of quants
709
+ uint8_t qs[QK5_0 / 2]; // nibbles / quants
710
+ } block_q5_0;
711
+ static_assert(sizeof(block_q5_0) == sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_0 / 2, "wrong q5_0 block size/padding");
712
+
713
+ #define QK5_1 32
714
+ typedef struct {
715
+ ggml_fp16_t d; // delta
716
+ ggml_fp16_t m; // min
717
+ uint8_t qh[4]; // 5-th bit of quants
718
+ uint8_t qs[QK5_1 / 2]; // nibbles / quants
719
+ } block_q5_1;
720
+ static_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_1 / 2, "wrong q5_1 block size/padding");
721
+
657
722
  #define QK8_0 32
658
723
  typedef struct {
659
724
  float d; // delta
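A quick size check on the two new 5-bit block layouts above (illustrative only; the static_asserts in the diff already pin the exact sizes): block_q5_0 spends 2 + 4 + 16 = 22 bytes per 32 weights (5.5 bits/weight) and block_q5_1 spends 2 + 2 + 4 + 16 = 24 bytes per 32 weights (6.0 bits/weight).

_Static_assert(2 /*d*/ + 4 /*qh*/ + 32/2 /*qs*/ == 22, "q5_0: 22 bytes per 32 weights = 5.5 bits/weight");
_Static_assert(2 /*d*/ + 2 /*m*/ + 4 /*qh*/ + 32/2 /*qs*/ == 24, "q5_1: 24 bytes per 32 weights = 6.0 bits/weight");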
@@ -661,6 +726,14 @@ typedef struct {
661
726
  } block_q8_0;
662
727
  static_assert(sizeof(block_q8_0) == sizeof(float) + QK8_0, "wrong q8_0 block size/padding");
663
728
 
729
+ #define QK8_1 32
730
+ typedef struct {
731
+ float d; // delta
732
+ float s0; // d * sum(qs[i]) low
733
+ float s1; // d * sum(qs[i]) high
734
+ int8_t qs[QK8_1]; // quants
735
+ } block_q8_1;
736
+ static_assert(sizeof(block_q8_1) == 3*sizeof(float) + QK8_1, "wrong q8_1 block size/padding");
664
737
 
665
738
  // reference implementation for deterministic creation of model files
666
739
  static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k) {
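Why block_q8_1 above carries the precomputed s0/s1 sums (a hedged reading, spelled out scalar-style below): for weight formats that store a minimum m (q4_1, q4_3, q5_1), each weight is w = d_x*q + m and each activation is a = d_y*p, so the block dot product splits into d_x*d_y*sum(q*p) + m_x*(d_y*sum(p)). The second factor is exactly s0 + s1, so the minimum contributes one multiply-add per block instead of a per-element pass; this is the summs += x0->m * (y0->s0 + y0->s1) term that appears in ggml_vec_dot_q4_1_q8_1 further down. A hypothetical naive sketch, assuming the block_q4_1 layout (float d, float m, nibbles in qs):

static float dot_q4_1_q8_1_scalar(const block_q4_1 * restrict x, const block_q8_1 * restrict y) {
    int sumi = 0;
    for (int l = 0; l < QK8_1/2; l++) {
        const uint8_t v = x->qs[l];
        sumi += (v & 0x0F) * y->qs[2*l + 0];   // element 2l
        sumi += (v >>   4) * y->qs[2*l + 1];   // element 2l + 1
    }
    return x->d * y->d * sumi                  // d_x * d_y * sum(q*p)
         + x->m * (y->s0 + y->s1);             // m_x * d_y * sum(p), prefolded into s0 + s1
}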
@@ -671,13 +744,17 @@ static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * r
671
744
 
672
745
  for (int i = 0; i < nb; i++) {
673
746
  float amax = 0.0f; // absolute max
747
+ float max = 0.0f;
674
748
 
675
749
  for (int l = 0; l < QK4_0; l++) {
676
750
  const float v = x[i*QK4_0 + l];
677
- amax = MAX(amax, fabsf(v));
751
+ if (amax < fabsf(v)) {
752
+ amax = fabsf(v);
753
+ max = v;
754
+ }
678
755
  }
679
756
 
680
- const float d = amax / ((1 << 3) - 1);
757
+ const float d = max / -8;
681
758
  const float id = d ? 1.0f/d : 0.0f;
682
759
 
683
760
  y[i].d = d;
@@ -686,8 +763,8 @@ static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * r
686
763
  const float v0 = x[i*QK4_0 + l + 0]*id;
687
764
  const float v1 = x[i*QK4_0 + l + 1]*id;
688
765
 
689
- const uint8_t vi0 = (int8_t)roundf(v0) + 8;
690
- const uint8_t vi1 = (int8_t)roundf(v1) + 8;
766
+ const uint8_t vi0 = MIN(15, (int8_t)roundf(v0) + 8);
767
+ const uint8_t vi1 = MIN(15, (int8_t)roundf(v1) + 8);
691
768
 
692
769
  assert(vi0 < 16);
693
770
  assert(vi1 < 16);
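The new q4_0 mapping above takes the scale from the signed value with the largest magnitude and divides by -8, so that extreme always lands exactly on a representable level; values that round just past the top of the range are pulled back by the MIN(15, ...) clamp. A small illustrative check with made-up numbers:

#include <assert.h>
#include <math.h>
#include <stdint.h>

static void q4_0_mapping_check(void) {
    const float max = 1.6f, min = -2.0f;                    // hypothetical block extremes
    const float magnitude = max >= fabsf(min) ? max : min;  // -2.0f here, since |min| wins
    const float d  = magnitude / -8;                        // 0.25f
    const float id = d ? 1.0f/d : 0.0f;                     // 4.0f

    assert((int8_t)roundf(min*id) + 8 == 0);   // the chosen extreme always maps to level 0
    assert((int8_t)roundf(max*id) + 8 == 14);  // +1.6 maps to level 14
    // the clamp to 15 matters in the opposite case: when the positive max sets the scale,
    // d is negative and a value near -max can round to level 16
}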
@@ -707,28 +784,42 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
707
784
 
708
785
  #if defined(__POWER9_VECTOR__)
709
786
  const vector float v85 = vec_splats(8.5f);
787
+ const vector signed int v15 = vec_splats(15);
710
788
  for (int i = 0; i < nb; i++) {
711
- float amax = 0.0f; // absolute max
789
+ float max = 0.0f;
790
+ float min = 0.0f;
712
791
 
713
792
  vector float srcv [8];
714
- vector float asrcv[8];
715
- vector float amaxv[8];
793
+ vector float maxv[8];
794
+ vector float minv[8];
716
795
 
717
796
  for (int l = 0; l < 8; l++) srcv[l] = *(vector float *)(x + i*32 + 4*l);
718
- for (int l = 0; l < 8; l++) asrcv[l] = vec_abs(srcv[l]);
719
-
720
- for (int l = 0; l < 4; l++) amaxv[2*l] = vec_max(asrcv[2*l], asrcv[2*l+1]);
721
- //for (int l = 0; l < 2; l++) amaxv[4*l] = vec_max(amaxv[4*l], amaxv[4*l+2]);
722
- amaxv[0] = vec_max(amaxv[0], amaxv[2]);
723
- amaxv[4] = vec_max(amaxv[4], amaxv[6]);
724
- //for (int l = 0; l < 1; l++) amaxv[8*l] = vec_max(amaxv[8*l], amaxv[8*l+4]);
725
- amaxv[0] = vec_max(amaxv[0], amaxv[4]);
726
-
727
- amax = MAX(
728
- MAX(vec_extract(amaxv[0], 0), vec_extract(amaxv[0], 1)),
729
- MAX(vec_extract(amaxv[0], 2), vec_extract(amaxv[0], 3)));
730
-
731
- const float d = amax / ((1 << 3) - 1);
797
+ //for (int l = 0; l < 8; l++) asrcv[l] = vec_abs(srcv[l]);
798
+
799
+ for (int l = 0; l < 4; l++) maxv[2*l] = vec_max(asrcv[2*l], asrcv[2*l+1]);
800
+ //for (int l = 0; l < 2; l++) maxv[4*l] = vec_max(maxv[4*l], maxv[4*l+2]);
801
+ maxv[0] = vec_max(maxv[0], maxv[2]);
802
+ maxv[4] = vec_max(maxv[4], maxv[6]);
803
+ //for (int l = 0; l < 1; l++) maxv[8*l] = vec_max(maxv[8*l], maxv[8*l+4]);
804
+ maxv[0] = vec_max(maxv[0], maxv[4]);
805
+
806
+ for (int l = 0; l < 4; l++) minv[2*l] = vec_min(asrcv[2*l], asrcv[2*l+1]);
807
+ //for (int l = 0; l < 2; l++) minv[4*l] = vec_min(minv[4*l], minv[4*l+2]);
808
+ minv[0] = vec_min(minv[0], minv[2]);
809
+ minv[4] = vec_min(minv[4], minv[6]);
810
+ //for (int l = 0; l < 1; l++) minv[8*l] = vec_min(minv[8*l], minv[8*l+4]);
811
+ minv[0] = vec_min(minv[0], minv[4]);
812
+
813
+
814
+ max = MAX(
815
+ MAX(vec_extract(maxv[0], 0), vec_extract(maxv[0], 1)),
816
+ MAX(vec_extract(maxv[0], 2), vec_extract(maxv[0], 3)));
817
+ min = MIN(
818
+ MIN(vec_extract(minv[0], 0), vec_extract(minv[0], 1)),
819
+ MIN(vec_extract(minv[0], 2), vec_extract(minv[0], 3)));
820
+
821
+ const float magnitude = max >= fabsf(min) ? max : min;
822
+ const float d = magnitude / -8;
732
823
  const float id = d ? 1.0/d : 0.0;
733
824
 
734
825
  y[i].d = d;
@@ -738,27 +829,33 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
738
829
  for (int l = 0; l < 8; l++) {
739
830
  const vector float vf = vec_madd(srcv[l], vid, v85);
740
831
  const vector signed int vi = vec_signed(vf);
832
+ const vector signed int vc = vec_min(vi, v15);
741
833
 
742
- pb[2*l + 0] = vec_extract(vi, 0) | (vec_extract(vi, 1) << 4);
743
- pb[2*l + 1] = vec_extract(vi, 2) | (vec_extract(vi, 3) << 4);
834
+ pb[2*l + 0] = vec_extract(vc, 0) | (vec_extract(vc, 1) << 4);
835
+ pb[2*l + 1] = vec_extract(vc, 2) | (vec_extract(vc, 3) << 4);
744
836
  }
745
837
  }
746
838
  #elif __ARM_NEON
747
839
  for (int i = 0; i < nb; i++) {
748
840
  float32x4_t srcv [8];
749
- float32x4_t asrcv[8];
750
- float32x4_t amaxv[8];
841
+ float32x4_t maxv[8];
842
+ float32x4_t minv[8];
751
843
 
752
844
  for (int l = 0; l < 8; l++) srcv[l] = vld1q_f32(x + i*32 + 4*l);
753
- for (int l = 0; l < 8; l++) asrcv[l] = vabsq_f32(srcv[l]);
754
845
 
755
- for (int l = 0; l < 4; l++) amaxv[2*l] = vmaxq_f32(asrcv[2*l], asrcv[2*l+1]);
756
- for (int l = 0; l < 2; l++) amaxv[4*l] = vmaxq_f32(amaxv[4*l], amaxv[4*l+2]);
757
- for (int l = 0; l < 1; l++) amaxv[8*l] = vmaxq_f32(amaxv[8*l], amaxv[8*l+4]);
846
+ for (int l = 0; l < 4; l++) maxv[2*l] = vmaxq_f32(srcv[2*l], srcv[2*l+1]);
847
+ for (int l = 0; l < 2; l++) maxv[4*l] = vmaxq_f32(maxv[4*l], maxv[4*l+2]);
848
+ for (int l = 0; l < 1; l++) maxv[8*l] = vmaxq_f32(maxv[8*l], maxv[8*l+4]);
758
849
 
759
- const float amax = vmaxvq_f32(amaxv[0]);
850
+ for (int l = 0; l < 4; l++) minv[2*l] = vminq_f32(srcv[2*l], srcv[2*l+1]);
851
+ for (int l = 0; l < 2; l++) minv[4*l] = vminq_f32(minv[4*l], minv[4*l+2]);
852
+ for (int l = 0; l < 1; l++) minv[8*l] = vminq_f32(minv[8*l], minv[8*l+4]);
853
+
854
+ const float max = vmaxvq_f32(maxv[0]);
855
+ const float min = vminvq_f32(minv[0]);
760
856
 
761
- const float d = amax / ((1 << 3) - 1);
857
+ const float magnitude = max >= fabsf(min) ? max : min;
858
+ const float d = magnitude / -8;
762
859
  const float id = d ? 1.0f/d : 0.0f;
763
860
 
764
861
  y[i].d = d;
@@ -767,9 +864,10 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
767
864
  const float32x4_t v = vmulq_n_f32(srcv[l], id);
768
865
  const float32x4_t vf = vaddq_f32(v, vdupq_n_f32(8.5f));
769
866
  const int32x4_t vi = vcvtq_s32_f32(vf);
867
+ const int32x4_t vc = vminq_s32(vi, vdupq_n_s32(15));
770
868
 
771
- y[i].qs[2*l + 0] = vgetq_lane_s32(vi, 0) | (vgetq_lane_s32(vi, 1) << 4);
772
- y[i].qs[2*l + 1] = vgetq_lane_s32(vi, 2) | (vgetq_lane_s32(vi, 3) << 4);
869
+ y[i].qs[2*l + 0] = vgetq_lane_s32(vc, 0) | (vgetq_lane_s32(vc, 1) << 4);
870
+ y[i].qs[2*l + 1] = vgetq_lane_s32(vc, 2) | (vgetq_lane_s32(vc, 3) << 4);
773
871
  }
774
872
  }
775
873
  #elif defined(__AVX2__)
@@ -781,22 +879,31 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
781
879
  __m256 v3 = _mm256_loadu_ps( x + 24 );
782
880
  x += 32;
783
881
 
784
- // Compute max(abs(e)) for the block
785
- const __m256 signBit = _mm256_set1_ps( -0.0f );
786
- __m256 maxAbs = _mm256_andnot_ps( signBit, v0 );
787
- maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) );
788
- maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );
789
- maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) );
882
+ // Compute max for the block
883
+ __m256 max = _mm256_max_ps( v0, v1 );
884
+ __m256 maxTmp = _mm256_max_ps( v2, v3 );
885
+ max = _mm256_max_ps( max, maxTmp );
790
886
 
791
- __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
887
+ __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( max, 1 ), _mm256_castps256_ps128( max ) );
792
888
  max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
793
889
  max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
794
890
  const float maxScalar = _mm_cvtss_f32( max4 );
795
891
 
892
+ // Compute min for the block
893
+ __m256 min = _mm256_min_ps( v0, v1 );
894
+ __m256 minTmp = _mm256_min_ps( v2, v3 );
895
+ min = _mm256_min_ps( min, minTmp );
896
+
897
+ __m128 min4 = _mm_min_ps( _mm256_extractf128_ps( min, 1 ), _mm256_castps256_ps128( min ) );
898
+ min4 = _mm_min_ps( min4, _mm_movehl_ps( min4, min4 ) );
899
+ min4 = _mm_min_ss( min4, _mm_movehdup_ps( min4 ) );
900
+ const float minScalar = _mm_cvtss_f32( min4 );
901
+
796
902
  // Quantize these floats
797
- const float d = maxScalar / 7.0f;
903
+ const float magnitude = maxScalar >= fabsf(minScalar) ? maxScalar : minScalar;
904
+ const float d = magnitude / -8.0f;
798
905
  y[i].d = d;
799
- const float id = ( maxScalar != 0.0f ) ? 7.0f / maxScalar : 0.0f;
906
+ const float id = ( magnitude != 0.0f ) ? -8.0f / magnitude : 0.0f;
800
907
  const __m256 mul = _mm256_set1_ps( id );
801
908
 
802
909
  // Apply the multiplier
@@ -829,9 +936,11 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
829
936
  const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
830
937
  i0 = _mm256_permutevar8x32_epi32( i0, perm );
831
938
 
832
- // Apply offset to translate the range from [ -7 .. +7 ] into [ +1 .. +15 ]
939
+ // Apply offset and clamp to translate the range from [ -8 .. +8 ] into [ +0 .. +15 ]
833
940
  const __m256i off = _mm256_set1_epi8( 8 );
834
941
  i0 = _mm256_add_epi8( i0, off );
942
+ const __m256i maxNibble = _mm256_set1_epi8( 15 );
943
+ i0 = _mm256_min_epi8( i0, maxNibble );
835
944
 
836
945
  // Compress the vector into 4 bit/value, and store
837
946
  __m128i res = packNibbles( i0 );
@@ -846,22 +955,31 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
846
955
  __m256 v3 = _mm256_loadu_ps( x + 24 );
847
956
  x += 32;
848
957
 
849
- // Compute max(abs(e)) for the block
850
- const __m256 signBit = _mm256_set1_ps( -0.0f );
851
- __m256 maxAbs = _mm256_andnot_ps( signBit, v0 );
852
- maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) );
853
- maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );
854
- maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) );
958
+ // Compute max for the block
959
+ __m256 max = _mm256_max_ps( v0, v1 );
960
+ __m256 maxTmp = _mm256_max_ps( v2, v3 );
961
+ max = _mm256_max_ps( max, maxTmp );
855
962
 
856
- __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
963
+ __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( max, 1 ), _mm256_castps256_ps128( max ) );
857
964
  max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
858
965
  max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
859
966
  const float maxScalar = _mm_cvtss_f32( max4 );
860
967
 
968
+ // Compute min for the block
969
+ __m256 min = _mm256_min_ps( v0, v1 );
970
+ __m256 minTmp = _mm256_min_ps( v2, v3 );
971
+ min = _mm256_min_ps( min, minTmp );
972
+
973
+ __m128 min4 = _mm_min_ps( _mm256_extractf128_ps( min, 1 ), _mm256_castps256_ps128( min ) );
974
+ min4 = _mm_min_ps( min4, _mm_movehl_ps( min4, min4 ) );
975
+ min4 = _mm_min_ss( min4, _mm_movehdup_ps( min4 ) );
976
+ const float minScalar = _mm_cvtss_f32( min4 );
977
+
861
978
  // Quantize these floats
862
- const float d = maxScalar / 7.0f;
979
+ const float magnitude = maxScalar >= fabsf(minScalar) ? maxScalar : minScalar;
980
+ const float d = magnitude / -8.0f;
863
981
  y[i].d = d;
864
- const float id = ( maxScalar != 0.0f ) ? 7.0f / maxScalar : 0.0f;
982
+ const float id = ( magnitude != 0.0f ) ? -8.0f / magnitude : 0.0f;
865
983
  const __m256 mul = _mm256_set1_ps( id );
866
984
 
867
985
  // Apply the multiplier
@@ -902,10 +1020,13 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
902
1020
  ni0 = _mm_packs_epi16( ni0, ni2 );
903
1021
  ni4 = _mm_packs_epi16( ni4, ni6 );
904
1022
 
905
- // Apply offset to translate the range from [ -7 .. +7 ] into [ +1 .. +15 ]
906
- const __m128i off = _mm_set1_epi8( 8);
1023
+ // Apply offset and clamp to translate the range from [ -8 .. +8 ] into [ +0 .. +15 ]
1024
+ const __m128i off = _mm_set1_epi8( 8 );
907
1025
  ni0 = _mm_add_epi8( ni0, off );
908
1026
  ni4 = _mm_add_epi8( ni4, off );
1027
+ const __m128i maxNibble = _mm_set1_epi8( 15 );
1028
+ ni0 = _mm_min_epi8( ni0, maxNibble );
1029
+ ni4 = _mm_min_epi8( ni4, maxNibble );
909
1030
 
910
1031
  // Compress the vector into 4 bit/value, and store
911
1032
  __m128i res = packNibbles( ni0, ni4 );
@@ -913,24 +1034,32 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
913
1034
  }
914
1035
  #elif defined(__wasm_simd128__)
915
1036
  for (int i = 0; i < nb; i++) {
916
- float amax = 0.0f; // absolute max
1037
+ float max = 0.0f;
1038
+ float min = 0.0f;
917
1039
 
918
1040
  v128_t srcv [8];
919
- v128_t asrcv[8];
920
- v128_t amaxv[8];
1041
+ v128_t maxv[8];
1042
+ v128_t minv[8];
921
1043
 
922
1044
  for (int l = 0; l < 8; l++) srcv[l] = wasm_v128_load(x + i*32 + 4*l);
923
- for (int l = 0; l < 8; l++) asrcv[l] = wasm_f32x4_abs(srcv[l]);
924
1045
 
925
- for (int l = 0; l < 4; l++) amaxv[2*l] = wasm_f32x4_max(asrcv[2*l], asrcv[2*l+1]);
926
- for (int l = 0; l < 2; l++) amaxv[4*l] = wasm_f32x4_max(amaxv[4*l], amaxv[4*l+2]);
927
- for (int l = 0; l < 1; l++) amaxv[8*l] = wasm_f32x4_max(amaxv[8*l], amaxv[8*l+4]);
1046
+ for (int l = 0; l < 4; l++) maxv[2*l] = wasm_f32x4_max(srcv[2*l], srcv[2*l+1]);
1047
+ for (int l = 0; l < 2; l++) maxv[4*l] = wasm_f32x4_max(maxv[4*l], maxv[4*l+2]);
1048
+ for (int l = 0; l < 1; l++) maxv[8*l] = wasm_f32x4_max(maxv[8*l], maxv[8*l+4]);
1049
+
1050
+ for (int l = 0; l < 4; l++) minv[2*l] = wasm_f32x4_min(srcv[2*l], srcv[2*l+1]);
1051
+ for (int l = 0; l < 2; l++) minv[4*l] = wasm_f32x4_min(minv[4*l], minv[4*l+2]);
1052
+ for (int l = 0; l < 1; l++) minv[8*l] = wasm_f32x4_min(minv[8*l], minv[8*l+4]);
928
1053
 
929
- amax = MAX(
930
- MAX(wasm_f32x4_extract_lane(amaxv[0], 0), wasm_f32x4_extract_lane(amaxv[0], 1)),
931
- MAX(wasm_f32x4_extract_lane(amaxv[0], 2), wasm_f32x4_extract_lane(amaxv[0], 3)));
1054
+ max = MAX(
1055
+ MAX(wasm_f32x4_extract_lane(maxv[0], 0), wasm_f32x4_extract_lane(maxv[0], 1)),
1056
+ MAX(wasm_f32x4_extract_lane(maxv[0], 2), wasm_f32x4_extract_lane(maxv[0], 3)));
1057
+ min = MIN(
1058
+ MIN(wasm_f32x4_extract_lane(minv[0], 0), wasm_f32x4_extract_lane(minv[0], 1)),
1059
+ MIN(wasm_f32x4_extract_lane(minv[0], 2), wasm_f32x4_extract_lane(minv[0], 3)));
932
1060
 
933
- const float d = amax / ((1 << 3) - 1);
1061
+ const float magnitude = max >= fabsf(min) ? max : min;
1062
+ const float d = magnitude / -8;
934
1063
  const float id = d ? 1.0/d : 0.0;
935
1064
 
936
1065
  y[i].d = d;
@@ -939,9 +1068,10 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
939
1068
  const v128_t v = wasm_f32x4_mul(srcv[l], wasm_f32x4_splat(id));
940
1069
  const v128_t vf = wasm_f32x4_add(v, wasm_f32x4_splat(8.5f));
941
1070
  const v128_t vi = wasm_i32x4_trunc_sat_f32x4(vf);
1071
+ const v128_t vc = wasm_i32x4_min_u(vi, wasm_i32x4_splat(15));
942
1072
 
943
- y[i].qs[2*l + 0] = wasm_i32x4_extract_lane(vi, 0) | (wasm_i32x4_extract_lane(vi, 1) << 4);
944
- y[i].qs[2*l + 1] = wasm_i32x4_extract_lane(vi, 2) | (wasm_i32x4_extract_lane(vi, 3) << 4);
1073
+ y[i].qs[2*l + 0] = wasm_i32x4_extract_lane(vc, 0) | (wasm_i32x4_extract_lane(vc, 1) << 4);
1074
+ y[i].qs[2*l + 1] = wasm_i32x4_extract_lane(vc, 2) | (wasm_i32x4_extract_lane(vc, 3) << 4);
945
1075
  }
946
1076
  }
947
1077
  #else
@@ -1122,13 +1252,17 @@ static void quantize_row_q4_2_reference(const float * restrict x, block_q4_2 * r
1122
1252
 
1123
1253
  for (int i = 0; i < nb; i++) {
1124
1254
  float amax = 0.0f; // absolute max
1255
+ float max = 0.0f;
1125
1256
 
1126
1257
  for (int l = 0; l < QK4_2; l++) {
1127
1258
  const float v = x[i*QK4_2 + l];
1128
- amax = MAX(amax, fabsf(v));
1259
+ if (amax < fabsf(v)) {
1260
+ amax = fabsf(v);
1261
+ max = v;
1262
+ }
1129
1263
  }
1130
1264
 
1131
- const float d = amax / ((1 << 3) - 1);
1265
+ const float d = max / -8;
1132
1266
 
1133
1267
  const float id = d ? 1.0f/d : 0.0f;
1134
1268
 
@@ -1138,93 +1272,14 @@ static void quantize_row_q4_2_reference(const float * restrict x, block_q4_2 * r
1138
1272
  const float v0 = x[i*QK4_2 + l + 0]*id;
1139
1273
  const float v1 = x[i*QK4_2 + l + 1]*id;
1140
1274
 
1141
- const uint8_t vi0 = (uint8_t)(v0 + 8.5f);
1142
- const uint8_t vi1 = (uint8_t)(v1 + 8.5f);
1143
-
1144
- assert(vi0 < 16);
1145
- assert(vi1 < 16);
1146
-
1147
- y[i].qs[l/2] = vi0 | (vi1 << 4);
1148
- }
1149
- }
1150
- }
1151
-
1152
- static inline int nearest_int(float fval) {
1153
- assert(fval <= 4194303.f);
1154
- float val = fval + 12582912.f;
1155
- int i; memcpy(&i, &val, sizeof(int));
1156
- return (i & 0x007fffff) - 0x00400000;
1157
- }
1158
-
1159
- static float kquantize_q4_with_bounds(int n, int nmin, int nmax, const float * restrict X, int nCandidates,
1160
- const float * restrict candidates, int8_t * restrict L) {
1161
- assert (nmin >= INT8_MIN);
1162
- assert (nmax <= INT8_MAX);
1163
- float amax = 0;
1164
- for (int i=0; i<n; ++i) amax = MAX(amax, fabsf(X[i]));
1165
- if (!amax) { // all zero
1166
- for (int i=0; i<n; ++i) L[i] = 0;
1167
- return 1.f;
1168
- }
1169
- float best = 0, bestScale = 0;
1170
- for (int si=0; si<nCandidates; ++si) {
1171
- float iscale = candidates[si]/amax;
1172
- float sumlxP = 0; int suml2P = 0;
1173
- float sumlxM = 0; int suml2M = 0;
1174
- for (int i=0; i<n; ++i) {
1175
- int l = nearest_int(iscale*X[i]);
1176
- int lp = MAX(nmin, MIN(nmax, +l));
1177
- int lm = MAX(nmin, MIN(nmax, -l));
1178
- sumlxP += X[i]*lp; suml2P += lp*lp;
1179
- sumlxM += X[i]*lm; suml2M += lm*lm;
1180
- }
1181
- float sumlxP2 = sumlxP*sumlxP;
1182
- float sumlxM2 = sumlxM*sumlxM;
1183
- if (sumlxP2*suml2M > sumlxM2*suml2P) {
1184
- if (sumlxP2 > best*suml2P) {
1185
- best = sumlxP2/suml2P; bestScale = iscale;
1186
- }
1187
- } else {
1188
- if (sumlxM2 > best*suml2M) {
1189
- best = sumlxM2/suml2M; bestScale = -iscale;
1190
- }
1191
- }
1192
- }
1193
- float sumlx = 0; int suml2 = 0;
1194
- for (int i=0; i<n; ++i) {
1195
- int l = nearest_int(bestScale*X[i]);
1196
- l = MAX(nmin, MIN(nmax, l));
1197
- sumlx += X[i]*l; suml2 += l*l;
1198
- L[i] = l;
1199
- }
1200
- float scale = sumlx/suml2;
1201
- return scale;
1202
- }
1203
-
1204
- static void quantize_row_q4_2_rmse(const float * restrict x, block_q4_2 * restrict y, int k) {
1205
- #define CANDIDATE_COUNT 8
1206
- static const float candidates[CANDIDATE_COUNT] = { +8.7f, +8.3f, +8.1f, +7.8f, +7.3f, +7.0f, +6.3f, +5.7f };
1207
- assert(k % QK4_2 == 0);
1208
-
1209
- int8_t L[QK4_2];
1210
-
1211
- const int nb = k / QK4_2;
1212
-
1213
- for (int i = 0; i < nb; i++) {
1214
- float scale = kquantize_q4_with_bounds(QK4_2, -8, 7, x, CANDIDATE_COUNT, candidates, L);
1215
- y[i].d = GGML_FP32_TO_FP16(scale);
1216
-
1217
- for (int l = 0; l < QK4_2; l += 2) {
1218
- const uint8_t vi0 = (uint8_t)(L[l+0] + 8);
1219
- const uint8_t vi1 = (uint8_t)(L[l+1] + 8);
1275
+ const uint8_t vi0 = MIN(15, (uint8_t)(v0 + 8.5f));
1276
+ const uint8_t vi1 = MIN(15, (uint8_t)(v1 + 8.5f));
1220
1277
 
1221
1278
  assert(vi0 < 16);
1222
1279
  assert(vi1 < 16);
1223
1280
 
1224
1281
  y[i].qs[l/2] = vi0 | (vi1 << 4);
1225
1282
  }
1226
-
1227
- x += QK4_2;
1228
1283
  }
1229
1284
  }
1230
1285
 
@@ -1233,9 +1288,7 @@ static void quantize_row_q4_2(const float * restrict x, void * restrict vy, int
1233
1288
 
1234
1289
  block_q4_2 * restrict y = vy;
1235
1290
 
1236
- //quantize_row_q4_2_reference(x, y, k);
1237
- // This produces the exact same format, just better match to the input floats ("better" as measured by RMSE)
1238
- quantize_row_q4_2_rmse(x, y, k);
1291
+ quantize_row_q4_2_reference(x, y, k);
1239
1292
  }
1240
1293
 
1241
1294
  static void quantize_row_q4_3_reference(const float * restrict x, block_q4_3 * restrict y, int k) {
@@ -1281,6 +1334,103 @@ static void quantize_row_q4_3(const float * restrict x, void * restrict vy, int
1281
1334
  quantize_row_q4_3_reference(x, y, k);
1282
1335
  }
1283
1336
 
1337
+ static void quantize_row_q5_0_reference(const float * restrict x, block_q5_0 * restrict y, int k) {
1338
+ assert(k % QK5_0 == 0);
1339
+ const int nb = k / QK5_0;
1340
+
1341
+ for (int i = 0; i < nb; i++) {
1342
+ float amax = 0.0f; // absolute max
1343
+ float max = 0.0f;
1344
+
1345
+ for (int l = 0; l < QK5_0; l++) {
1346
+ const float v = x[i*QK5_0 + l];
1347
+ if (amax < fabsf(v)) {
1348
+ amax = fabsf(v);
1349
+ max = v;
1350
+ }
1351
+ }
1352
+
1353
+ const float d = max / -16;
1354
+ const float id = d ? 1.0f/d : 0.0f;
1355
+
1356
+ y[i].d = GGML_FP32_TO_FP16(d);
1357
+
1358
+ uint32_t qh = 0;
1359
+
1360
+ for (int l = 0; l < QK5_0; l += 2) {
1361
+ const float v0 = x[i*QK5_0 + l + 0]*id;
1362
+ const float v1 = x[i*QK5_0 + l + 1]*id;
1363
+
1364
+ const uint32_t vi0 = MIN(31, (int) (v0 + 16.5f));
1365
+ const uint32_t vi1 = MIN(31, (int) (v1 + 16.5f));
1366
+
1367
+ y[i].qs[l/2] = (vi0 & 0x0F) | ((vi1 & 0x0F) << 4);
1368
+
1369
+ // get the 5-th bit and store it in qh at the right position
1370
+ qh |= ((vi0 & 0x10) >> 4) << (l + 0);
1371
+ qh |= ((vi1 & 0x10) >> 4) << (l + 1);
1372
+ }
1373
+
1374
+ memcpy(&y[i].qh, &qh, sizeof(y[i].qh));
1375
+ }
1376
+ }
1377
+
1378
+ static void quantize_row_q5_0(const float * restrict x, void * restrict vy, int k) {
1379
+ assert(k % QK5_0 == 0);
1380
+
1381
+ block_q5_0 * restrict y = vy;
1382
+
1383
+ quantize_row_q5_0_reference(x, y, k);
1384
+ }
1385
+
1386
+ static void quantize_row_q5_1_reference(const float * restrict x, block_q5_1 * restrict y, int k) {
1387
+ assert(k % QK5_1 == 0);
1388
+ const int nb = k / QK5_1;
1389
+
1390
+ for (int i = 0; i < nb; i++) {
1391
+ float min = FLT_MAX;
1392
+ float max = -FLT_MAX;
1393
+
1394
+ for (int l = 0; l < QK5_1; l++) {
1395
+ const float v = x[i*QK5_1 + l];
1396
+ if (v < min) min = v;
1397
+ if (v > max) max = v;
1398
+ }
1399
+
1400
+ const float d = (max - min) / ((1 << 5) - 1);
1401
+ const float id = d ? 1.0f/d : 0.0f;
1402
+
1403
+ y[i].d = GGML_FP32_TO_FP16(d);
1404
+ y[i].m = GGML_FP32_TO_FP16(min);
1405
+
1406
+ uint32_t qh = 0;
1407
+
1408
+ for (int l = 0; l < QK5_1; l += 2) {
1409
+ const float v0 = (x[i*QK5_1 + l + 0] - min)*id;
1410
+ const float v1 = (x[i*QK5_1 + l + 1] - min)*id;
1411
+
1412
+ const uint32_t vi0 = (int) (v0 + 0.5f);
1413
+ const uint32_t vi1 = (int) (v1 + 0.5f);
1414
+
1415
+ y[i].qs[l/2] = (vi0 & 0x0F) | ((vi1 & 0x0F) << 4);
1416
+
1417
+ // get the 5-th bit and store it in qh at the right position
1418
+ qh |= ((vi0 & 0x10) >> 4) << (l + 0);
1419
+ qh |= ((vi1 & 0x10) >> 4) << (l + 1);
1420
+ }
1421
+
1422
+ memcpy(&y[i].qh, &qh, sizeof(y[i].qh));
1423
+ }
1424
+ }
1425
+
1426
+ static void quantize_row_q5_1(const float * restrict x, void * restrict vy, int k) {
1427
+ assert(k % QK5_1 == 0);
1428
+
1429
+ block_q5_1 * restrict y = vy;
1430
+
1431
+ quantize_row_q5_1_reference(x, y, k);
1432
+ }
1433
+
1284
1434
  // reference implementation for deterministic creation of model files
1285
1435
  static void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * restrict y, int k) {
1286
1436
  assert(k % QK8_0 == 0);
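The qh bookkeeping in quantize_row_q5_0_reference / quantize_row_q5_1_reference above packs the low 4 bits of each 5-bit quant into a nibble of qs and the 5th bit into bit l of qh; the dequantize_row_q5_* functions added further down reverse it. A minimal round-trip check with made-up values:

#include <assert.h>
#include <stdint.h>

static void q5_pack_roundtrip_check(void) {
    const uint32_t vi0 = 19, vi1 = 7;  // hypothetical 5-bit quants at positions l and l+1
    const int l = 6;

    const uint8_t qs = (vi0 & 0x0F) | ((vi1 & 0x0F) << 4);
    uint32_t qh = 0;
    qh |= ((vi0 & 0x10) >> 4) << (l + 0);
    qh |= ((vi1 & 0x10) >> 4) << (l + 1);

    // reverse, as dequantize_row_q5_* does below
    const uint8_t vh0 = ((qh & (1u << (l + 0))) >> (l + 0)) << 4;
    const uint8_t vh1 = ((qh & (1u << (l + 1))) >> (l + 1)) << 4;
    assert(((qs & 0x0F) | vh0) == vi0);
    assert(((qs >>   4) | vh1) == vi1);
}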
@@ -1300,18 +1450,64 @@ static void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * r
1300
1450
  y[i].d = d;
1301
1451
 
1302
1452
  for (int l = 0; l < QK8_0; ++l) {
1303
- const float v = x[i*QK8_0 + l]*id;
1304
- y[i].qs[l] = roundf(v);
1453
+ const float v0 = x[i*QK8_0 + l]*id;
1454
+
1455
+ y[i].qs[l] = roundf(v0);
1305
1456
  }
1306
1457
  }
1307
1458
  }
1308
1459
 
1309
1460
  static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int k) {
1310
1461
  assert(k % QK8_0 == 0);
1311
- const int nb = k / QK8_0;
1312
1462
 
1313
1463
  block_q8_0 * restrict y = vy;
1314
1464
 
1465
+ quantize_row_q8_0_reference(x, y, k);
1466
+ }
1467
+
1468
+ // reference implementation for deterministic creation of model files
1469
+ static void quantize_row_q8_1_reference(const float * restrict x, block_q8_1 * restrict y, int k) {
1470
+ assert(k % QK8_1 == 0);
1471
+ const int nb = k / QK8_1;
1472
+
1473
+ for (int i = 0; i < nb; i++) {
1474
+ float amax = 0.0f; // absolute max
1475
+
1476
+ for (int l = 0; l < QK8_1; l++) {
1477
+ const float v = x[i*QK8_1 + l];
1478
+ amax = MAX(amax, fabsf(v));
1479
+ }
1480
+
1481
+ const float d = amax / ((1 << 7) - 1);
1482
+ const float id = d ? 1.0f/d : 0.0f;
1483
+
1484
+ y[i].d = d;
1485
+
1486
+ int sum0 = 0;
1487
+ int sum1 = 0;
1488
+
1489
+ for (int l = 0; l < QK8_1/2; ++l) {
1490
+ const float v0 = x[i*QK8_1 + l]*id;
1491
+ const float v1 = x[i*QK8_1 + QK8_1/2 + l]*id;
1492
+
1493
+ y[i].qs[ l] = roundf(v0);
1494
+ y[i].qs[QK8_1/2 + l] = roundf(v1);
1495
+
1496
+ sum0 += y[i].qs[ l];
1497
+ sum1 += y[i].qs[QK8_1/2 + l];
1498
+ }
1499
+
1500
+ y[i].s0 = d * sum0;
1501
+ y[i].s1 = d * sum1;
1502
+ }
1503
+ }
1504
+
1505
+ static void quantize_row_q8_1(const float * restrict x, void * restrict vy, int k) {
1506
+ assert(k % QK8_1 == 0);
1507
+ const int nb = k / QK8_1;
1508
+
1509
+ block_q8_1 * restrict y = vy;
1510
+
1315
1511
  #if defined(__ARM_NEON)
1316
1512
  for (int i = 0; i < nb; i++) {
1317
1513
  float32x4_t srcv [8];
@@ -1332,7 +1528,24 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
1332
1528
 
1333
1529
  y[i].d = d;
1334
1530
 
1335
- for (int l = 0; l < 8; l++) {
1531
+ int32x4_t accv0 = vdupq_n_s32(0);
1532
+ int32x4_t accv1 = vdupq_n_s32(0);
1533
+
1534
+ // low half
1535
+ for (int l = 0; l < 4; l++) {
1536
+ const float32x4_t v = vmulq_n_f32(srcv[l], id);
1537
+ const int32x4_t vi = vcvtnq_s32_f32(v);
1538
+
1539
+ y[i].qs[4*l + 0] = vgetq_lane_s32(vi, 0);
1540
+ y[i].qs[4*l + 1] = vgetq_lane_s32(vi, 1);
1541
+ y[i].qs[4*l + 2] = vgetq_lane_s32(vi, 2);
1542
+ y[i].qs[4*l + 3] = vgetq_lane_s32(vi, 3);
1543
+
1544
+ accv0 = vaddq_s32(accv0, vi);
1545
+ }
1546
+
1547
+ // high half
1548
+ for (int l = 4; l < 8; l++) {
1336
1549
  const float32x4_t v = vmulq_n_f32(srcv[l], id);
1337
1550
  const int32x4_t vi = vcvtnq_s32_f32(v);
1338
1551
 
@@ -1340,7 +1553,15 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
1340
1553
  y[i].qs[4*l + 1] = vgetq_lane_s32(vi, 1);
1341
1554
  y[i].qs[4*l + 2] = vgetq_lane_s32(vi, 2);
1342
1555
  y[i].qs[4*l + 3] = vgetq_lane_s32(vi, 3);
1556
+
1557
+ accv1 = vaddq_s32(accv1, vi);
1343
1558
  }
1559
+
1560
+ const int32_t sum0 = vaddvq_s32(accv0);
1561
+ const int32_t sum1 = vaddvq_s32(accv1);
1562
+
1563
+ y[i].s0 = d * sum0;
1564
+ y[i].s1 = d * sum1;
1344
1565
  }
1345
1566
  #elif defined(__AVX2__) || defined(__AVX__)
1346
1567
  for (int i = 0; i < nb; i++) {
@@ -1388,6 +1609,11 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
1388
1609
  __m256i i3 = _mm256_cvtps_epi32( v3 );
1389
1610
 
1390
1611
  #if defined(__AVX2__)
1612
+ // Compute the sum of the quants and set y[i].s
1613
+ //y[i].s = d * hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3)));
1614
+ y[i].s0 = d * hsum_i32_8(_mm256_add_epi32(i0, i1));
1615
+ y[i].s1 = d * hsum_i32_8(_mm256_add_epi32(i2, i3));
1616
+
1391
1617
  // Convert int32 to int16
1392
1618
  i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15
1393
1619
  i2 = _mm256_packs_epi32( i2, i3 ); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31
@@ -1413,6 +1639,12 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
1413
1639
  __m128i ni6 = _mm256_castsi256_si128( i3 );
1414
1640
  __m128i ni7 = _mm256_extractf128_si256( i3, 1);
1415
1641
 
1642
+ // Compute the sum of the quants and set y[i].s
1643
+ const __m128i s0 = _mm_add_epi32(_mm_add_epi32(ni0, ni1), _mm_add_epi32(ni2, ni3));
1644
+ const __m128i s1 = _mm_add_epi32(_mm_add_epi32(ni4, ni5), _mm_add_epi32(ni6, ni7));
1645
+ y[i].s0 = d * hsum_i32_4(s0);
1646
+ y[i].s1 = d * hsum_i32_4(s1);
1647
+
1416
1648
  // Convert int32 to int16
1417
1649
  ni0 = _mm_packs_epi32( ni0, ni1 );
1418
1650
  ni2 = _mm_packs_epi32( ni2, ni3 );
@@ -1428,7 +1660,7 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
1428
1660
  }
1429
1661
  #else
1430
1662
  // scalar
1431
- quantize_row_q8_0_reference(x, y, k);
1663
+ quantize_row_q8_1_reference(x, y, k);
1432
1664
  #endif
1433
1665
  }
1434
1666
 
@@ -1482,7 +1714,7 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in
1482
1714
  const uint8x8_t v8 = vld1_u8(pp + l/2);
1483
1715
 
1484
1716
  // Expand 4-bit qs to 8-bit bytes
1485
- const uint8x8_t v0 = vand_u8(v8, vdup_n_u8(0x0f));
1717
+ const uint8x8_t v0 = vand_u8(v8, vdup_n_u8(0x0F));
1486
1718
  const uint8x8_t v1 = vshr_n_u8(v8, 4);
1487
1719
 
1488
1720
  // Convert to signed 8-bit integers
@@ -1532,7 +1764,7 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in
1532
1764
  for (int l = 0; l < QK4_0; l += 2) {
1533
1765
  const uint8_t vi = pp[l/2];
1534
1766
 
1535
- const int8_t vi0 = vi & 0xf;
1767
+ const int8_t vi0 = vi & 0x0F;
1536
1768
  const int8_t vi1 = vi >> 4;
1537
1769
 
1538
1770
  const float v0 = (vi0 - 8)*d;
@@ -1598,7 +1830,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
1598
1830
  const uint8x8_t v8 = vld1_u8(pp + l/2);
1599
1831
 
1600
1832
  // Expand 4-bit qs to 8-bit bytes
1601
- const uint8x8_t v0 = vand_u8(v8, vdup_n_u8(0x0f));
1833
+ const uint8x8_t v0 = vand_u8(v8, vdup_n_u8(0x0F));
1602
1834
  const uint8x8_t v1 = vshr_n_u8(v8, 4);
1603
1835
 
1604
1836
  // Interleave and combine
@@ -1640,7 +1872,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
1640
1872
  for (int l = 0; l < QK4_1; l += 2) {
1641
1873
  const uint8_t vi = pp[l/2];
1642
1874
 
1643
- const int8_t vi0 = vi & 0xf;
1875
+ const int8_t vi0 = vi & 0x0F;
1644
1876
  const int8_t vi1 = vi >> 4;
1645
1877
 
1646
1878
  const float v0 = vi0*d + m;
@@ -1670,7 +1902,7 @@ static void dequantize_row_q4_2(const void * restrict vx, float * restrict y, in
1670
1902
  for (int l = 0; l < QK4_2; l += 2) {
1671
1903
  const uint8_t vi = pp[l/2];
1672
1904
 
1673
- const int8_t vi0 = vi & 0xf;
1905
+ const int8_t vi0 = vi & 0x0F;
1674
1906
  const int8_t vi1 = vi >> 4;
1675
1907
 
1676
1908
  const float v0 = (vi0 - 8)*d;
@@ -1700,7 +1932,7 @@ static void dequantize_row_q4_3(const void * restrict vx, float * restrict y, in
1700
1932
  for (int l = 0; l < QK4_3; l += 2) {
1701
1933
  const uint8_t vi = pp[l/2];
1702
1934
 
1703
- const int8_t vi0 = vi & 0xf;
1935
+ const int8_t vi0 = vi & 0x0F;
1704
1936
  const int8_t vi1 = vi >> 4;
1705
1937
 
1706
1938
  const float v0 = vi0*d + m;
@@ -1715,54 +1947,176 @@ static void dequantize_row_q4_3(const void * restrict vx, float * restrict y, in
1715
1947
  }
1716
1948
  }
1717
1949
 
1718
- static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
1719
- static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
1720
- static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
1721
- static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
1722
-
1723
- static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
1724
- [GGML_TYPE_Q4_0] = {
1725
- .dequantize_row_q = dequantize_row_q4_0,
1726
- .quantize_row_q = quantize_row_q4_0,
1727
- .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_0_reference,
1728
- .quantize_row_q_dot = quantize_row_q8_0,
1729
- .vec_dot_q = ggml_vec_dot_q4_0_q8_0,
1730
- },
1731
- [GGML_TYPE_Q4_1] = {
1732
- .dequantize_row_q = dequantize_row_q4_1,
1733
- .quantize_row_q = quantize_row_q4_1,
1734
- .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_1_reference,
1735
- .quantize_row_q_dot = quantize_row_q8_0,
1736
- .vec_dot_q = ggml_vec_dot_q4_1_q8_0,
1737
- },
1738
- [GGML_TYPE_Q4_2] = {
1739
- .dequantize_row_q = dequantize_row_q4_2,
1740
- .quantize_row_q = quantize_row_q4_2,
1741
- .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_2_rmse, //quantize_row_q4_2_reference,
1742
- .quantize_row_q_dot = quantize_row_q8_0,
1743
- .vec_dot_q = ggml_vec_dot_q4_2_q8_0,
1744
- },
1745
- [GGML_TYPE_Q4_3] = {
1746
- .dequantize_row_q = dequantize_row_q4_3,
1747
- .quantize_row_q = quantize_row_q4_3,
1748
- .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_3_reference, // TODO: RMSE optimization
1749
- .quantize_row_q_dot = quantize_row_q8_0,
1750
- .vec_dot_q = ggml_vec_dot_q4_3_q8_0,
1751
- },
1752
- [GGML_TYPE_Q8_0] = {
1753
- .dequantize_row_q = NULL, // TODO
1754
- .quantize_row_q = quantize_row_q8_0,
1755
- .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q8_0_reference,
1756
- .quantize_row_q_dot = quantize_row_q8_0,
1757
- .vec_dot_q = NULL, // TODO
1758
- },
1759
- };
1950
+ static void dequantize_row_q5_0(const void * restrict vx, float * restrict y, int k) {
1951
+ assert(k % QK5_0 == 0);
1952
+ const int nb = k / QK5_0;
1760
1953
 
1761
- // For internal test use
1762
- quantize_fns_t ggml_internal_get_quantize_fn(size_t i) {
1763
- GGML_ASSERT(i < GGML_TYPE_COUNT);
1764
- return quantize_fns[i];
1765
- }
1954
+ const block_q5_0 * restrict x = vx;
1955
+
1956
+ for (int i = 0; i < nb; i++) {
1957
+ const float d = GGML_FP16_TO_FP32(x[i].d);
1958
+
1959
+ const uint8_t * restrict pp = x[i].qs;
1960
+
1961
+ uint32_t qh;
1962
+ memcpy(&qh, x[i].qh, sizeof(qh));
1963
+
1964
+ for (int l = 0; l < QK5_0; l += 2) {
1965
+ const uint8_t vi = pp[l/2];
1966
+
1967
+ // extract the 5-th bit from qh
1968
+ const uint8_t vh0 = ((qh & (1 << (l + 0))) >> (l + 0)) << 4;
1969
+ const uint8_t vh1 = ((qh & (1 << (l + 1))) >> (l + 1)) << 4;
1970
+
1971
+ const int8_t vi0 = (vi & 0x0F) | vh0;
1972
+ const int8_t vi1 = (vi >> 4) | vh1;
1973
+
1974
+ const float v0 = (vi0 - 16)*d;
1975
+ const float v1 = (vi1 - 16)*d;
1976
+
1977
+ y[i*QK5_0 + l + 0] = v0;
1978
+ y[i*QK5_0 + l + 1] = v1;
1979
+
1980
+ assert(!isnan(y[i*QK5_0 + l + 0]));
1981
+ assert(!isnan(y[i*QK5_0 + l + 1]));
1982
+ }
1983
+ }
1984
+ }
1985
+
1986
+ static void dequantize_row_q5_1(const void * restrict vx, float * restrict y, int k) {
1987
+ assert(k % QK5_1 == 0);
1988
+ const int nb = k / QK5_1;
1989
+
1990
+ const block_q5_1 * restrict x = vx;
1991
+
1992
+ for (int i = 0; i < nb; i++) {
1993
+ const float d = GGML_FP16_TO_FP32(x[i].d);
1994
+ const float m = GGML_FP16_TO_FP32(x[i].m);
1995
+
1996
+ const uint8_t * restrict pp = x[i].qs;
1997
+
1998
+ uint32_t qh;
1999
+ memcpy(&qh, x[i].qh, sizeof(qh));
2000
+
2001
+ for (int l = 0; l < QK5_1; l += 2) {
2002
+ const uint8_t vi = pp[l/2];
2003
+
2004
+ // extract the 5-th bit from qh
2005
+ const uint8_t vh0 = ((qh & (1 << (l + 0))) >> (l + 0)) << 4;
2006
+ const uint8_t vh1 = ((qh & (1 << (l + 1))) >> (l + 1)) << 4;
2007
+
2008
+ const uint8_t vi0 = (vi & 0x0F) | vh0;
2009
+ const uint8_t vi1 = (vi >> 4) | vh1;
2010
+
2011
+ const float v0 = vi0*d + m;
2012
+ const float v1 = vi1*d + m;
2013
+
2014
+ y[i*QK5_1 + l + 0] = v0;
2015
+ y[i*QK5_1 + l + 1] = v1;
2016
+
2017
+ assert(!isnan(y[i*QK5_1 + l + 0]));
2018
+ assert(!isnan(y[i*QK5_1 + l + 1]));
2019
+ }
2020
+ }
2021
+ }
2022
+
2023
+ static void dequantize_row_q8_0(const void * restrict vx, float * restrict y, int k) {
2024
+ assert(k % QK8_0 == 0);
2025
+ const int nb = k / QK8_0;
2026
+
2027
+ const block_q8_0 * restrict x = vx;
2028
+
2029
+ for (int i = 0; i < nb; i++) {
2030
+ const float d = x[i].d;
2031
+
2032
+ const int8_t * restrict pp = x[i].qs;
2033
+
2034
+ for (int l = 0; l < QK8_0; ++l) {
2035
+ y[i*QK8_0 + l] = pp[l]*d;
2036
+ }
2037
+ }
2038
+ }
2039
+
2040
+ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
2041
+ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
2042
+ static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
2043
+ static void ggml_vec_dot_q4_3_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
2044
+ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
2045
+ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
2046
+ static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
2047
+
2048
+ static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
2049
+ [GGML_TYPE_Q4_0] = {
2050
+ .dequantize_row_q = dequantize_row_q4_0,
2051
+ .quantize_row_q = quantize_row_q4_0,
2052
+ .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_0_reference,
2053
+ .quantize_row_q_dot = quantize_row_q8_0,
2054
+ .vec_dot_q = ggml_vec_dot_q4_0_q8_0,
2055
+ .vec_dot_type = GGML_TYPE_Q8_0,
2056
+ },
2057
+ [GGML_TYPE_Q4_1] = {
2058
+ .dequantize_row_q = dequantize_row_q4_1,
2059
+ .quantize_row_q = quantize_row_q4_1,
2060
+ .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_1_reference,
2061
+ .quantize_row_q_dot = quantize_row_q8_1,
2062
+ .vec_dot_q = ggml_vec_dot_q4_1_q8_1,
2063
+ .vec_dot_type = GGML_TYPE_Q8_1,
2064
+ },
2065
+ [GGML_TYPE_Q4_2] = {
2066
+ .dequantize_row_q = dequantize_row_q4_2,
2067
+ .quantize_row_q = quantize_row_q4_2,
2068
+ .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_2_reference,
2069
+ .quantize_row_q_dot = quantize_row_q8_0,
2070
+ .vec_dot_q = ggml_vec_dot_q4_2_q8_0,
2071
+ .vec_dot_type = GGML_TYPE_Q8_0,
2072
+ },
2073
+ [GGML_TYPE_Q4_3] = {
2074
+ .dequantize_row_q = dequantize_row_q4_3,
2075
+ .quantize_row_q = quantize_row_q4_3,
2076
+ .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_3_reference,
2077
+ .quantize_row_q_dot = quantize_row_q8_1,
2078
+ .vec_dot_q = ggml_vec_dot_q4_3_q8_1,
2079
+ .vec_dot_type = GGML_TYPE_Q8_1,
2080
+ },
2081
+ [GGML_TYPE_Q5_0] = {
2082
+ .dequantize_row_q = dequantize_row_q5_0,
2083
+ .quantize_row_q = quantize_row_q5_0,
2084
+ .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q5_0_reference,
2085
+ .quantize_row_q_dot = quantize_row_q8_0,
2086
+ .vec_dot_q = ggml_vec_dot_q5_0_q8_0,
2087
+ .vec_dot_type = GGML_TYPE_Q8_0,
2088
+ },
2089
+ [GGML_TYPE_Q5_1] = {
2090
+ .dequantize_row_q = dequantize_row_q5_1,
2091
+ .quantize_row_q = quantize_row_q5_1,
2092
+ .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q5_1_reference,
2093
+ .quantize_row_q_dot = quantize_row_q8_1,
2094
+ .vec_dot_q = ggml_vec_dot_q5_1_q8_1,
2095
+ .vec_dot_type = GGML_TYPE_Q8_1,
2096
+ },
2097
+ [GGML_TYPE_Q8_0] = {
2098
+ .dequantize_row_q = dequantize_row_q8_0,
2099
+ .quantize_row_q = quantize_row_q8_0,
2100
+ .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q8_0_reference,
2101
+ .quantize_row_q_dot = quantize_row_q8_0,
2102
+ .vec_dot_q = ggml_vec_dot_q8_0_q8_0,
2103
+ .vec_dot_type = GGML_TYPE_Q8_0,
2104
+ },
2105
+ [GGML_TYPE_Q8_1] = {
2106
+ .dequantize_row_q = NULL, // TODO
2107
+ .quantize_row_q = quantize_row_q8_1,
2108
+ .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q8_1_reference,
2109
+ .quantize_row_q_dot = quantize_row_q8_1,
2110
+ .vec_dot_q = NULL, // TODO
2111
+ .vec_dot_type = GGML_TYPE_Q8_1,
2112
+ },
2113
+ };
2114
+
2115
+ // For internal test use
2116
+ quantize_fns_t ggml_internal_get_quantize_fn(size_t i) {
2117
+ GGML_ASSERT(i < GGML_TYPE_COUNT);
2118
+ return quantize_fns[i];
2119
+ }
1766
2120
 
1767
2121
 
1768
2122
  //
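The new vec_dot_type field in quantize_fns_t records which 8-bit format the activation row must be quantized to before vec_dot_q can be used. A hedged sketch (dot_one_row is a hypothetical helper, not the actual ggml matrix-multiply code) of how a caller would combine the table entries:

static float dot_one_row(enum ggml_type wtype, const void * w_row, const float * act, int n, void * wdata) {
    const quantize_fns_t fns = ggml_internal_get_quantize_fn(wtype);

    // quantize the f32 activations into the companion format named by fns.vec_dot_type;
    // wdata must be large enough to hold n values in that format
    fns.quantize_row_q_dot(act, wdata, n);

    float result = 0.0f;
    fns.vec_dot_q(n, &result, w_row, wdata);
    return result;
}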
@@ -2366,8 +2720,6 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
2366
2720
  const block_q4_0 * restrict x = vx;
2367
2721
  const block_q8_0 * restrict y = vy;
2368
2722
 
2369
- float sumf = 0.0;
2370
-
2371
2723
  #if defined(__ARM_NEON)
2372
2724
  float32x4_t sumv0 = vdupq_n_f32(0.0f);
2373
2725
  float32x4_t sumv1 = vdupq_n_f32(0.0f);
@@ -2378,7 +2730,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
2378
2730
  const block_q8_0 * restrict y0 = &y[i + 0];
2379
2731
  const block_q8_0 * restrict y1 = &y[i + 1];
2380
2732
 
2381
- const uint8x16_t m4b = vdupq_n_u8(0xf);
2733
+ const uint8x16_t m4b = vdupq_n_u8(0x0F);
2382
2734
  const int8x16_t s8b = vdupq_n_s8(0x8);
2383
2735
 
2384
2736
  const uint8x16_t v0_0 = vld1q_u8(x0->qs);
@@ -2436,7 +2788,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
2436
2788
  #endif
2437
2789
  }
2438
2790
 
2439
- sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
2791
+ *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
2440
2792
  #elif defined(__AVX2__)
2441
2793
  // Initialize accumulator with zeros
2442
2794
  __m256 acc = _mm256_setzero_ps();
@@ -2454,32 +2806,13 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
2454
2806
 
2455
2807
  __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
2456
2808
 
2457
- // Get absolute values of x vectors
2458
- const __m256i ax = _mm256_sign_epi8(bx, bx);
2459
-
2460
- // Sign the values of the y vectors
2461
- const __m256i sy = _mm256_sign_epi8(by, bx);
2462
-
2463
- // Perform multiplication and create 16-bit values
2464
- const __m256i dot = _mm256_maddubs_epi16(ax, sy);
2465
-
2466
- const __m256i ones = _mm256_set1_epi16(1);
2467
- __m256i xy_q = _mm256_madd_epi16(ones, dot);
2468
-
2469
- /* Convert to vectore of 8 int32_t to 8 floats */
2470
- __m256 q = _mm256_cvtepi32_ps( xy_q );
2809
+ const __m256 q = mul_sum_i8_pairs_float(bx, by);
2471
2810
 
2472
2811
  /* Multiply q with scale and accumulate */
2473
2812
  acc = _mm256_fmadd_ps( d, q, acc );
2474
2813
  }
2475
2814
 
2476
- // Return horizontal sum of the acc vector
2477
- __m128 res = _mm256_extractf128_ps( acc, 1 );
2478
- res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
2479
- res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
2480
- res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
2481
-
2482
- sumf = _mm_cvtss_f32( res );
2815
+ *s = hsum_float_8(acc);
2483
2816
  #elif defined(__AVX__)
2484
2817
  // Initialize accumulator with zeros
2485
2818
  __m256 acc = _mm256_setzero_ps();
@@ -2518,15 +2851,10 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
2518
2851
  acc = _mm256_add_ps(_mm256_mul_ps( d, p ), acc);
2519
2852
  }
2520
2853
 
2521
- // Return horizontal sum of the acc vector
2522
- __m128 res = _mm256_extractf128_ps( acc, 1 );
2523
- res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
2524
- res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
2525
- res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
2526
-
2527
- sumf = _mm_cvtss_f32( res );
2854
+ *s = hsum_float_8(acc);
2528
2855
  #else
2529
2856
  // scalar
2857
+ float sumf = 0.0;
2530
2858
  for (int i = 0; i < nb; i++) {
2531
2859
  const float d0 = x[i].d;
2532
2860
  const float d1 = y[i].d;
@@ -2538,8 +2866,8 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
2538
2866
  for (int j = 0; j < QK8_0/2; j++) {
2539
2867
  const uint8_t v0 = p0[j];
2540
2868
 
2541
- const int i0 = (int8_t) (v0 & 0xf) - 8;
2542
- const int i1 = (int8_t) (v0 >> 4) - 8;
2869
+ const int i0 = (int8_t) (v0 & 0x0F) - 8;
2870
+ const int i1 = (int8_t) (v0 >> 4) - 8;
2543
2871
 
2544
2872
  const int i2 = p1[2*j + 0];
2545
2873
  const int i3 = p1[2*j + 1];
@@ -2548,34 +2876,35 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
2548
2876
  }
2549
2877
  sumf += d0*d1*sumi;
2550
2878
  }
2551
- #endif
2552
-
2553
2879
  *s = sumf;
2880
+ #endif
2554
2881
  }
2555
2882
 
2556
- static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
2557
- const int nb = n / QK8_0;
2883
+ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
2884
+ const int nb = n / QK8_1;
2558
2885
 
2559
- assert(n % QK8_0 == 0);
2886
+ assert(n % QK8_1 == 0);
2560
2887
  assert(nb % 2 == 0);
2561
2888
 
2562
2889
  const block_q4_1 * restrict x = vx;
2563
- const block_q8_0 * restrict y = vy;
2564
-
2565
- float sumf = 0.0;
2890
+ const block_q8_1 * restrict y = vy;
2566
2891
 
2567
2892
  // TODO: add AVX / WASM SIMD / etc
2568
2893
  #if defined(__ARM_NEON)
2569
2894
  float32x4_t sumv0 = vdupq_n_f32(0.0f);
2570
2895
  float32x4_t sumv1 = vdupq_n_f32(0.0f);
2571
2896
 
2897
+ float summs = 0;
2898
+
2572
2899
  for (int i = 0; i < nb; i += 2) {
2573
2900
  const block_q4_1 * restrict x0 = &x[i + 0];
2574
2901
  const block_q4_1 * restrict x1 = &x[i + 1];
2575
- const block_q8_0 * restrict y0 = &y[i + 0];
2576
- const block_q8_0 * restrict y1 = &y[i + 1];
2902
+ const block_q8_1 * restrict y0 = &y[i + 0];
2903
+ const block_q8_1 * restrict y1 = &y[i + 1];
2577
2904
 
2578
- const uint8x16_t m4b = vdupq_n_u8(0xf);
2905
+ summs += x0->m * (y0->s0 + y0->s1) + x1->m * (y1->s0 + y1->s1);
2906
+
2907
+ const uint8x16_t m4b = vdupq_n_u8(0x0F);
2579
2908
 
2580
2909
  const uint8x16_t v0_0 = vld1q_u8(x0->qs);
2581
2910
  const uint8x16_t v0_1 = vld1q_u8(x1->qs);
@@ -2586,46 +2915,35 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
2586
2915
  const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
2587
2916
  const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
2588
2917
 
2918
+ // interleave
2919
+ const int8x16_t v0_0lz = vzip1q_s8(v0_0l, v0_0h);
2920
+ const int8x16_t v0_0hz = vzip2q_s8(v0_0l, v0_0h);
2921
+ const int8x16_t v0_1lz = vzip1q_s8(v0_1l, v0_1h);
2922
+ const int8x16_t v0_1hz = vzip2q_s8(v0_1l, v0_1h);
2923
+
2589
2924
  // load y
2590
2925
  const int8x16_t v1_0l = vld1q_s8(y0->qs);
2591
2926
  const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
2592
2927
  const int8x16_t v1_1l = vld1q_s8(y1->qs);
2593
2928
  const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
2594
2929
 
2595
- // interleave
2596
- const int8x16_t v1_0ls = vuzp1q_s8(v1_0l, v1_0h);
2597
- const int8x16_t v1_0hs = vuzp2q_s8(v1_0l, v1_0h);
2598
- const int8x16_t v1_1ls = vuzp1q_s8(v1_1l, v1_1h);
2599
- const int8x16_t v1_1hs = vuzp2q_s8(v1_1l, v1_1h);
2600
-
2601
- const int16x8_t s0i = vaddq_s16(
2602
- vaddq_s16(vmovl_s8(vget_low_s8(v1_0ls)), vmovl_s8(vget_high_s8(v1_0ls))),
2603
- vaddq_s16(vmovl_s8(vget_low_s8(v1_0hs)), vmovl_s8(vget_high_s8(v1_0hs))));
2604
-
2605
- const int16x8_t s1i = vaddq_s16(
2606
- vaddq_s16(vmovl_s8(vget_low_s8(v1_1ls)), vmovl_s8(vget_high_s8(v1_1ls))),
2607
- vaddq_s16(vmovl_s8(vget_low_s8(v1_1hs)), vmovl_s8(vget_high_s8(v1_1hs))));
2608
-
2609
- sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddl_s16(vget_low_s16(s0i), vget_high_s16(s0i))), x0->m*y0->d);
2610
- sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddl_s16(vget_low_s16(s1i), vget_high_s16(s1i))), x1->m*y1->d);
2611
-
2612
2930
  #if defined(__ARM_FEATURE_DOTPROD)
2613
2931
  // dot product into int32x4_t
2614
- const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0ls), v0_0h, v1_0hs);
2615
- const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1ls), v0_1h, v1_1hs);
2932
+ const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0lz, v1_0l), v0_0hz, v1_0h);
2933
+ const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1lz, v1_1l), v0_1hz, v1_1h);
2616
2934
 
2617
2935
  sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), x0->d*y0->d);
2618
2936
  sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), x1->d*y1->d);
2619
2937
  #else
2620
- const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0l), vget_low_s8 (v1_0ls));
2621
- const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0l), vget_high_s8(v1_0ls));
2622
- const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0h), vget_low_s8 (v1_0hs));
2623
- const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0h), vget_high_s8(v1_0hs));
2938
+ const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lz), vget_low_s8 (v1_0l));
2939
+ const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lz), vget_high_s8(v1_0l));
2940
+ const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hz), vget_low_s8 (v1_0h));
2941
+ const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hz), vget_high_s8(v1_0h));
2624
2942
 
2625
- const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1l), vget_low_s8 (v1_1ls));
2626
- const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1l), vget_high_s8(v1_1ls));
2627
- const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1h), vget_low_s8 (v1_1hs));
2628
- const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1h), vget_high_s8(v1_1hs));
2943
+ const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lz), vget_low_s8 (v1_1l));
2944
+ const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lz), vget_high_s8(v1_1l));
2945
+ const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hz), vget_low_s8 (v1_1h));
2946
+ const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hz), vget_high_s8(v1_1h));
2629
2947
 
2630
2948
  const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
2631
2949
  const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
@@ -2637,65 +2955,40 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
2637
2955
  #endif
2638
2956
  }
2639
2957
 
2640
- sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
2958
+ *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs;
2641
2959
  #elif defined(__AVX2__)
2642
2960
  // Initialize accumulator with zeros
2643
2961
  __m256 acc = _mm256_setzero_ps();
2644
2962
 
2963
+ float summs = 0;
2964
+
2645
2965
  // Main loop
2646
2966
  for (int i = 0; i < nb; ++i) {
2647
2967
  const float * d0 = &x[i].d;
2648
2968
  const float * d1 = &y[i].d;
2649
- const float * m0 = &x[i].m;
2969
+
2970
+ summs += x[i].m * (y[i].s0 + y[i].s1);
2650
2971
 
2651
2972
  const __m256 d0v = _mm256_broadcast_ss( d0 );
2652
2973
  const __m256 d1v = _mm256_broadcast_ss( d1 );
2653
- const __m256 m0v = _mm256_broadcast_ss( m0 );
2654
2974
 
2655
2975
  // Compute combined scales
2656
2976
  const __m256 d0d1 = _mm256_mul_ps( d0v, d1v );
2657
- const __m256 d1m0 = _mm256_mul_ps( d1v, m0v );
2658
2977
 
2659
2978
  // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
2660
2979
  const __m256i bx = bytes_from_nibbles_32(x[i].qs);
2661
2980
  const __m256i by = _mm256_loadu_si256( (const __m256i *)y[i].qs );
2662
2981
 
2663
- // Get absolute values of x vectors
2664
- const __m256i ax = _mm256_sign_epi8( bx, bx );
2665
-
2666
- // Sign the values of the y vectors
2667
- const __m256i sy = _mm256_sign_epi8( by, bx );
2668
-
2669
- // Perform multiplication and create 16-bit values
2670
- const __m256i dot = _mm256_maddubs_epi16( ax, sy );
2671
- const __m256i ones = _mm256_set1_epi16( 1 );
2672
- const __m256i xy_q = _mm256_madd_epi16( ones, dot );
2673
-
2674
- // Convert to vector of 8 int32_t to 8 floats
2675
- const __m256 xy = _mm256_cvtepi32_ps( xy_q );
2982
+ const __m256 xy = mul_sum_i8_pairs_float(bx, by);
2676
2983
 
2677
2984
  // Accumulate d0*d1*x*y
2678
2985
  acc = _mm256_fmadd_ps( d0d1, xy, acc );
2679
-
2680
- // Compute sum of y values
2681
- const __m256i y16_l = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( by ) );
2682
- const __m256i y16_h = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( by, 1 ) );
2683
- const __m256i ysumi = _mm256_madd_epi16( _mm256_add_epi16(y16_l, y16_h), ones );
2684
- const __m256 ysum = _mm256_cvtepi32_ps( ysumi );
2685
-
2686
- // Accumulate d1*m0*y
2687
- acc = _mm256_fmadd_ps( d1m0, ysum, acc );
2688
2986
  }
2689
2987
 
2690
- // Return horizontal sum of the acc vector
2691
- __m128 res = _mm256_extractf128_ps( acc, 1 );
2692
- res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
2693
- res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
2694
- res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
2695
-
2696
- sumf = _mm_cvtss_f32( res );
2988
+ *s = hsum_float_8(acc) + summs;
2697
2989
  #else
2698
2990
  // scalar
2991
+ float sumf = 0.0;
2699
2992
  for (int i = 0; i < nb; i++) {
2700
2993
  const float d0 = x[i].d;
2701
2994
  const float m0 = x[i].m;
@@ -2705,11 +2998,11 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
2705
2998
  const int8_t * restrict p1 = y[i].qs;
2706
2999
 
2707
3000
  // TODO: this is very slow ..
2708
- for (int j = 0; j < QK8_0/2; j++) {
3001
+ for (int j = 0; j < QK8_1/2; j++) {
2709
3002
  const uint8_t v0 = p0[j];
2710
3003
 
2711
- const float f0 = d0*(v0 & 0xf) + m0;
2712
- const float f1 = d0*(v0 >> 4) + m0;
3004
+ const float f0 = d0*(v0 & 0x0F) + m0;
3005
+ const float f1 = d0*(v0 >> 4) + m0;
2713
3006
 
2714
3007
  const float f2 = d1*p1[2*j + 0];
2715
3008
  const float f3 = d1*p1[2*j + 1];
@@ -2717,9 +3010,8 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
2717
3010
  sumf += f0*f2 + f1*f3;
2718
3011
  }
2719
3012
  }
2720
- #endif
2721
-
2722
3013
  *s = sumf;
3014
+ #endif
2723
3015
  }
2724
3016
 
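The NEON and AVX2 paths above fold the q4_1 offset term into `summs` using the per-half sums `s0`/`s1` carried by `block_q8_1`. A minimal scalar sketch of the identity they rely on (block size and values below are made up for illustration):

#include <stdio.h>

int main(void) {
    enum { N = 32 };                        // stand-in for QK8_1
    const float dx = 0.05f, mx = -0.40f;    // q4_1 block scale and offset
    const float dy = 0.02f;                 // q8_1 block scale
    int qx[N], qy[N];
    for (int j = 0; j < N; ++j) { qx[j] = j % 16; qy[j] = (j % 7) - 3; }

    // direct: sum_j (dx*qx + mx) * (dy*qy)
    float direct = 0.0f;
    for (int j = 0; j < N; ++j) direct += (dx*qx[j] + mx) * (dy*qy[j]);

    // factored: dx*dy*sum(qx*qy) + mx*(dy*sum(qy)); dy*sum(qy) is what s0 + s1 precompute
    int sumi = 0, sumy = 0;
    for (int j = 0; j < N; ++j) { sumi += qx[j]*qy[j]; sumy += qy[j]; }
    const float factored = dx*dy*sumi + mx*(dy*sumy);

    printf("direct = %f, factored = %f\n", direct, factored);
    return 0;
}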
2725
3017
  static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
@@ -2732,8 +3024,6 @@ static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void *
2732
3024
  const block_q4_2 * restrict x = vx;
2733
3025
  const block_q8_0 * restrict y = vy;
2734
3026
 
2735
- float sumf = 0.0;
2736
-
2737
3027
  #if defined(__ARM_NEON)
2738
3028
  float32x4_t sumv0 = vdupq_n_f32(0.0f);
2739
3029
  float32x4_t sumv1 = vdupq_n_f32(0.0f);
@@ -2747,7 +3037,7 @@ static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void *
2747
3037
  const block_q8_0 * restrict y0 = &y[i + 0];
2748
3038
  const block_q8_0 * restrict y1 = &y[i + 1];
2749
3039
 
2750
- const uint8x16_t m4b = vdupq_n_u8(0xf);
3040
+ const uint8x16_t m4b = vdupq_n_u8(0x0F);
2751
3041
  const int8x16_t s8b = vdupq_n_s8(0x8);
2752
3042
 
2753
3043
  const uint8x16_t v0_0 = vcombine_u8(vld1_u8(x0_0->qs), vld1_u8(x0_1->qs));
@@ -2782,270 +3072,604 @@ static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void *
2782
3072
  vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0lz, v1_0l)), GGML_FP16_TO_FP32(x0_0->d)),
2783
3073
  vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0hz, v1_0h)), GGML_FP16_TO_FP32(x0_1->d))), y0->d);
2784
3074
 
2785
- sumv1 = vmlaq_n_f32(sumv1, vaddq_f32(
2786
- vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_1lz, v1_1l)), GGML_FP16_TO_FP32(x1_0->d)),
2787
- vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_1hz, v1_1h)), GGML_FP16_TO_FP32(x1_1->d))), y1->d);
2788
- #else
2789
- const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lz), vget_low_s8 (v1_0l));
2790
- const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lz), vget_high_s8(v1_0l));
2791
- const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hz), vget_low_s8 (v1_0h));
2792
- const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hz), vget_high_s8(v1_0h));
3075
+ sumv1 = vmlaq_n_f32(sumv1, vaddq_f32(
3076
+ vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_1lz, v1_1l)), GGML_FP16_TO_FP32(x1_0->d)),
3077
+ vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_1hz, v1_1h)), GGML_FP16_TO_FP32(x1_1->d))), y1->d);
3078
+ #else
3079
+ const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lz), vget_low_s8 (v1_0l));
3080
+ const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lz), vget_high_s8(v1_0l));
3081
+ const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hz), vget_low_s8 (v1_0h));
3082
+ const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hz), vget_high_s8(v1_0h));
3083
+
3084
+ const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lz), vget_low_s8 (v1_1l));
3085
+ const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lz), vget_high_s8(v1_1l));
3086
+ const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hz), vget_low_s8 (v1_1h));
3087
+ const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hz), vget_high_s8(v1_1h));
3088
+
3089
+ const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
3090
+ const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
3091
+ const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
3092
+ const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
3093
+
3094
+ sumv0 = vmlaq_n_f32(sumv0, vaddq_f32(
3095
+ vmulq_n_f32(vcvtq_f32_s32(pl0), GGML_FP16_TO_FP32(x0_0->d)),
3096
+ vmulq_n_f32(vcvtq_f32_s32(ph0), GGML_FP16_TO_FP32(x0_1->d))), y0->d);
3097
+
3098
+ sumv1 = vmlaq_n_f32(sumv1, vaddq_f32(
3099
+ vmulq_n_f32(vcvtq_f32_s32(pl1), GGML_FP16_TO_FP32(x1_0->d)),
3100
+ vmulq_n_f32(vcvtq_f32_s32(ph1), GGML_FP16_TO_FP32(x1_1->d))), y1->d);
3101
+ #endif
3102
+ }
3103
+
3104
+ *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
3105
+ #elif defined(__AVX2__)
3106
+ // Initialize accumulator with zeros
3107
+ __m256 acc = _mm256_setzero_ps();
3108
+
3109
+ // Main loop
3110
+ for (int i = 0; i < nb; i++) {
3111
+ /* Compute combined scale for the block */
3112
+ const __m128 d0 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 0].d));
3113
+ const __m128 d1 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 1].d));
3114
+ const __m256 d = _mm256_mul_ps(_mm256_set_m128(d1, d0), _mm256_broadcast_ss(&y[i].d));
3115
+
3116
+ __m128i bx0 = bytes_from_nibbles_16(x[2*i + 0].qs);
3117
+ __m128i bx1 = bytes_from_nibbles_16(x[2*i + 1].qs);
3118
+ __m256i bx = _mm256_set_m128i(bx1, bx0);
3119
+
3120
+ // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
3121
+ const __m256i off = _mm256_set1_epi8(8);
3122
+ bx = _mm256_sub_epi8(bx, off);
3123
+
3124
+ __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
3125
+
3126
+ const __m256 q = mul_sum_i8_pairs_float(bx, by);
3127
+
3128
+ /* Multiply q with scale and accumulate */
3129
+ acc = _mm256_fmadd_ps(d, q, acc);
3130
+ }
3131
+
3132
+ *s = hsum_float_8(acc);
3133
+ #else
3134
+ // scalar
3135
+ float sumf = 0.0;
3136
+ for (int i = 0; i < nb; i++) {
3137
+ const uint8_t * restrict x0 = x[2*i + 0].qs;
3138
+ const uint8_t * restrict x1 = x[2*i + 1].qs;
3139
+ const int8_t * restrict y0 = y[i].qs;
3140
+
3141
+ const float d0 = GGML_FP16_TO_FP32(x[2*i + 0].d);
3142
+ const float d1 = GGML_FP16_TO_FP32(x[2*i + 1].d);
3143
+
3144
+ int sumi_0 = 0;
3145
+ int sumi_1 = 0;
3146
+
3147
+ for (int j = 0; j < QK8_0/4; j++) {
3148
+ const uint8_t v0 = x0[j];
3149
+ const uint8_t v1 = x1[j];
3150
+
3151
+ const int i0_0 = (int8_t) (v0 & 0x0F) - 8;
3152
+ const int i1_0 = (int8_t) (v0 >> 4) - 8;
3153
+
3154
+ const int i0_1 = (int8_t) (v1 & 0x0F) - 8;
3155
+ const int i1_1 = (int8_t) (v1 >> 4) - 8;
3156
+
3157
+ const int i2_0 = y0[2*j + 0];
3158
+ const int i3_0 = y0[2*j + 1];
3159
+
3160
+ const int i2_1 = y0[2*(j + QK8_0/4) + 0];
3161
+ const int i3_1 = y0[2*(j + QK8_0/4) + 1];
3162
+
3163
+ sumi_0 += i0_0*i2_0 + i1_0*i3_0;
3164
+ sumi_1 += i0_1*i2_1 + i1_1*i3_1;
3165
+ }
3166
+
3167
+ sumf += (d0 * y[i].d) * sumi_0;
3168
+ sumf += (d1 * y[i].d) * sumi_1;
3169
+ }
3170
+ *s = sumf;
3171
+ #endif
3172
+ }
3173
+
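Both q4_2 paths above widen the packed nibbles the same way the scalar fallback does: each byte yields two quants in [0, 15] that are re-centered to [-8, 7]. A standalone scalar sketch (the input bytes are arbitrary):

#include <stdint.h>
#include <stdio.h>

int main(void) {
    const uint8_t qs[8] = { 0x0F, 0x80, 0x7A, 0x01, 0xFF, 0x00, 0x39, 0xC4 };  // arbitrary packed nibbles
    int8_t q[16];
    for (int j = 0; j < 8; ++j) {
        q[2*j + 0] = (int8_t)(qs[j] & 0x0F) - 8;   // low nibble  -> [-8, 7]
        q[2*j + 1] = (int8_t)(qs[j] >>   4) - 8;   // high nibble -> [-8, 7]
    }
    for (int j = 0; j < 16; ++j) printf("%d ", q[j]);
    printf("\n");
    return 0;
}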
3174
+ static void ggml_vec_dot_q4_3_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
3175
+ const int nb = n / QK8_1;
3176
+
3177
+ assert(n % QK8_1 == 0);
3178
+ assert(nb % 2 == 0);
3179
+ assert(QK8_1 == 2*QK4_3);
3180
+
3181
+ const block_q4_3 * restrict x = vx;
3182
+ const block_q8_1 * restrict y = vy;
3183
+
3184
+ #if defined(__ARM_NEON)
3185
+ float32x4_t sumv0 = vdupq_n_f32(0.0f);
3186
+ float32x4_t sumv1 = vdupq_n_f32(0.0f);
3187
+
3188
+ float summs0 = 0.0f;
3189
+ float summs1 = 0.0f;
3190
+
3191
+ for (int i = 0; i < nb; ++i) {
3192
+ const block_q4_3 * restrict x0_0 = &x[2*(i + 0) + 0];
3193
+ const block_q4_3 * restrict x0_1 = &x[2*(i + 0) + 1];
3194
+
3195
+ const block_q8_1 * restrict y0 = &y[i + 0];
3196
+
3197
+ summs0 += GGML_FP16_TO_FP32(x0_0->m) * y0->s0;
3198
+ summs1 += GGML_FP16_TO_FP32(x0_1->m) * y0->s1;
3199
+
3200
+ const uint8x16_t v0_0 = vcombine_u8(vld1_u8(x0_0->qs), vld1_u8(x0_1->qs));
3201
+
3202
+ // 4-bit -> 8-bit
3203
+ const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, vdupq_n_u8(0x0F)));
3204
+ const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
3205
+
3206
+ // interleave
3207
+ const int8x16_t v0_0lz = vzip1q_s8(v0_0l, v0_0h);
3208
+ const int8x16_t v0_0hz = vzip2q_s8(v0_0l, v0_0h);
3209
+
3210
+ // load y
3211
+ const int8x16_t v1_0l = vld1q_s8(y0->qs);
3212
+ const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
3213
+
3214
+ const float x0_0d = GGML_FP16_TO_FP32(x0_0->d);
3215
+ const float x0_1d = GGML_FP16_TO_FP32(x0_1->d);
3216
+
3217
+ #if defined(__ARM_FEATURE_DOTPROD)
3218
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0lz, v1_0l)), x0_0d*y0->d);
3219
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0hz, v1_0h)), x0_1d*y0->d);
3220
+ #else
3221
+ const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lz), vget_low_s8 (v1_0l));
3222
+ const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lz), vget_high_s8(v1_0l));
3223
+ const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hz), vget_low_s8 (v1_0h));
3224
+ const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hz), vget_high_s8(v1_0h));
3225
+
3226
+ const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
3227
+ const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
3228
+
3229
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(pl0), x0_0d*y0->d);
3230
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(ph0), x0_1d*y0->d);
3231
+ #endif
3232
+ }
3233
+
3234
+ *s = vaddvq_f32(vaddq_f32(sumv0, sumv1)) + summs0 + summs1;
3235
+ #elif defined(__AVX2__)
3236
+ // Initialize accumulator with zeros
3237
+ __m256 acc = _mm256_setzero_ps();
3238
+ float summs = 0.0f;
3239
+
3240
+ // Main loop
3241
+ for (int i = 0; i < nb; i++) {
3242
+ const __m128 d0 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 0].d));
3243
+ const __m128 d1 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 1].d));
3244
+ const __m256 dx = _mm256_set_m128(d1, d0);
3245
+
3246
+ summs += GGML_FP16_TO_FP32(x[2*i + 0].m) * y[i].s0
3247
+ + GGML_FP16_TO_FP32(x[2*i + 1].m) * y[i].s1;
3248
+
3249
+ const __m128i bx0 = bytes_from_nibbles_16(x[2*i + 0].qs);
3250
+ const __m128i bx1 = bytes_from_nibbles_16(x[2*i + 1].qs);
3251
+ const __m256i bx = _mm256_set_m128i(bx1, bx0);
3252
+
3253
+ const __m256 dy = _mm256_broadcast_ss(&y[i].d);
3254
+ const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
3255
+
3256
+ const __m256 q = mul_sum_i8_pairs_float(bx, by);
3257
+
3258
+ acc = _mm256_fmadd_ps(q, _mm256_mul_ps(dx, dy), acc);
3259
+ }
3260
+
3261
+ *s = hsum_float_8(acc) + summs;
3262
+ #else
3263
+ // scalar
3264
+ float sumf = 0.0;
3265
+ for (int i = 0; i < nb; i++) {
3266
+ const uint8_t * restrict x0 = x[2*i + 0].qs;
3267
+ const uint8_t * restrict x1 = x[2*i + 1].qs;
3268
+ const int8_t * restrict y0 = y[i].qs;
3269
+
3270
+ const float d0 = GGML_FP16_TO_FP32(x[2*i + 0].d);
3271
+ const float m0 = GGML_FP16_TO_FP32(x[2*i + 0].m);
3272
+ const float d1 = GGML_FP16_TO_FP32(x[2*i + 1].d);
3273
+ const float m1 = GGML_FP16_TO_FP32(x[2*i + 1].m);
3274
+
3275
+ int sxy_0 = 0;
3276
+ int sxy_1 = 0;
3277
+
3278
+ for (int j = 0; j < QK8_1/4; j++) {
3279
+ const uint8_t v0 = x0[j];
3280
+ const uint8_t v1 = x1[j];
3281
+
3282
+ const int x0_0 = v0 & 0x0F;
3283
+ const int x1_0 = v0 >> 4;
3284
+
3285
+ const int x0_1 = v1 & 0x0F;
3286
+ const int x1_1 = v1 >> 4;
3287
+
3288
+ const int y0_0 = y0[2*j + 0];
3289
+ const int y1_0 = y0[2*j + 1];
3290
+
3291
+ const int y0_1 = y0[2*(j + QK8_1/4) + 0];
3292
+ const int y1_1 = y0[2*(j + QK8_1/4) + 1];
3293
+
3294
+ sxy_0 += x0_0*y0_0 + x1_0*y1_0;
3295
+ sxy_1 += x0_1*y0_1 + x1_1*y1_1;
3296
+ }
3297
+
3298
+ sumf += (d0*sxy_0 + d1*sxy_1)*y[i].d + m0*y[i].s0 + m1*y[i].s1;
3299
+ }
3300
+ *s = sumf;
3301
+ #endif
3302
+ }
3303
+
3304
+ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
3305
+ const int nb = n / QK8_0;
3306
+
3307
+ assert(n % QK8_0 == 0);
3308
+ assert(nb % 2 == 0);
3309
+ assert(QK8_0 == QK5_0);
3310
+
3311
+ const block_q5_0 * restrict x = vx;
3312
+ const block_q8_0 * restrict y = vy;
3313
+
3314
+ #if defined(__ARM_NEON)
3315
+ float32x4_t sumv = vdupq_n_f32(0.0f);
3316
+
3317
+ uint64_t tmp[4];
3318
+
3319
+ for (int i = 0; i < nb; ++i) {
3320
+ const block_q5_0 * restrict x0 = &x[i];
3321
+ const block_q8_0 * restrict y0 = &y[i];
3322
+
3323
+ const uint8x16_t m4b = vdupq_n_u8(0x0F);
3324
+ const int8x16_t s16b = vdupq_n_s8(0x10);
3325
+
3326
+ // extract the 5th bit
3327
+ uint32_t qh;
3328
+ memcpy(&qh, x0->qh, sizeof(qh));
3329
+
3330
+ tmp[0] = table_b2b_u[(qh >> 0) & 0xFF];
3331
+ tmp[1] = table_b2b_u[(qh >> 8) & 0xFF];
3332
+ tmp[2] = table_b2b_u[(qh >> 16) & 0xFF];
3333
+ tmp[3] = table_b2b_u[(qh >> 24) ];
3334
+
3335
+ const int8x16_t qhl = vld1q_s8((const int8_t *)(tmp + 0));
3336
+ const int8x16_t qhh = vld1q_s8((const int8_t *)(tmp + 2));
3337
+
3338
+ const uint8x16_t v0 = vld1q_u8(x0->qs);
3339
+
3340
+ // 4-bit -> 8-bit
3341
+ const int8x16_t v0l = vreinterpretq_s8_u8(vandq_u8 (v0, m4b));
3342
+ const int8x16_t v0h = vreinterpretq_s8_u8(vshrq_n_u8(v0, 4));
3343
+
3344
+ // interleave
3345
+ const int8x16_t v0lz = vzip1q_s8(v0l, v0h);
3346
+ const int8x16_t v0hz = vzip2q_s8(v0l, v0h);
3347
+
3348
+ // add high bit and sub 16
3349
+ const int8x16_t v0lf = vsubq_s8(vorrq_s8(v0lz, qhl), s16b);
3350
+ const int8x16_t v0hf = vsubq_s8(vorrq_s8(v0hz, qhh), s16b);
3351
+
3352
+ // load y
3353
+ const int8x16_t v1l = vld1q_s8(y0->qs);
3354
+ const int8x16_t v1h = vld1q_s8(y0->qs + 16);
3355
+
3356
+ const float x0d = GGML_FP16_TO_FP32(x0->d);
3357
+
3358
+ #if defined(__ARM_FEATURE_DOTPROD)
3359
+ sumv = vmlaq_n_f32(sumv, vcvtq_f32_s32(vaddq_s32(
3360
+ vdotq_s32(vdupq_n_s32(0), v0lf, v1l),
3361
+ vdotq_s32(vdupq_n_s32(0), v0hf, v1h))), x0d*y0->d);
3362
+ #else
3363
+ const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0lf), vget_low_s8 (v1l));
3364
+ const int16x8_t pl0h = vmull_s8(vget_high_s8(v0lf), vget_high_s8(v1l));
3365
+ const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0hf), vget_low_s8 (v1h));
3366
+ const int16x8_t ph0h = vmull_s8(vget_high_s8(v0hf), vget_high_s8(v1h));
3367
+
3368
+ const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
3369
+ const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
3370
+
3371
+ sumv = vmlaq_n_f32(sumv, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), x0d*y0->d);
3372
+ #endif
3373
+ }
3374
+
3375
+ *s = vaddvq_f32(sumv);
3376
+ #elif defined(__AVX2__)
3377
+ // Initialize accumulator with zeros
3378
+ __m256 acc = _mm256_setzero_ps();
3379
+
3380
+ // Main loop
3381
+ for (int i = 0; i < nb; i++) {
3382
+ /* Compute combined scale for the block */
3383
+ const __m256 d = _mm256_mul_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d)), _mm256_broadcast_ss(&y[i].d));
3384
+
3385
+ __m256i bx = bytes_from_nibbles_32(x[i].qs);
3386
+ __m256i bxhi = bytes_from_bits_32(x[i].qh);
3387
+ bxhi = _mm256_andnot_si256(bxhi, _mm256_set1_epi8((char)0xF0));
3388
+ bx = _mm256_or_si256(bx, bxhi);
3389
+
3390
+ __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
3391
+
3392
+ const __m256 q = mul_sum_i8_pairs_float(bx, by);
3393
+
3394
+ /* Multiply q with scale and accumulate */
3395
+ acc = _mm256_fmadd_ps(d, q, acc);
3396
+ }
3397
+
3398
+ *s = hsum_float_8(acc);
3399
+ #else
3400
+ // scalar
3401
+ float sumf = 0.0;
3402
+ for (int i = 0; i < nb; i++) {
3403
+ const uint8_t * restrict x0 = x[i].qs;
3404
+ const int8_t * restrict y0 = y[i].qs;
3405
+
3406
+ uint32_t qh;
3407
+ memcpy(&qh, x[i].qh, sizeof(qh));
3408
+
3409
+ const float d = GGML_FP16_TO_FP32(x[i].d);
3410
+
3411
+ int sxy = 0;
3412
+
3413
+ for (int j = 0; j < QK8_0/2; j++) {
3414
+ const uint8_t v0 = x0[j];
3415
+
3416
+ const int x0_0h = ((qh & (1 << (2*j + 0))) >> (2*j + 0)) << 4;
3417
+ const int x1_0h = ((qh & (1 << (2*j + 1))) >> (2*j + 1)) << 4;
3418
+
3419
+ const int x0_0 = ((v0 & 0x0F) | x0_0h) - 16;
3420
+ const int x1_0 = ((v0 >> 4) | x1_0h) - 16;
3421
+
3422
+ const int y0_0 = y0[2*j + 0];
3423
+ const int y1_0 = y0[2*j + 1];
3424
+
3425
+ sxy += x0_0*y0_0 + x1_0*y1_0;
3426
+ }
3427
+
3428
+ sumf += (d*sxy)*y[i].d;
3429
+ }
3430
+ *s = sumf;
3431
+ #endif
3432
+ }
3433
+
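The `table_b2b_u` lookups above spread the 32 bits of `qh` into per-quant high bits in one shot; the scalar fallback right above does the same reconstruction one bit at a time. A self-contained sketch of that reconstruction (the block contents are arbitrary):

#include <stdint.h>
#include <string.h>
#include <stdio.h>

int main(void) {
    uint8_t qs[16];                                           // packed low nibbles, arbitrary
    const uint8_t qh_bytes[4] = { 0xA5, 0x3C, 0x0F, 0xF0 };   // packed 5th bits, arbitrary
    for (int j = 0; j < 16; ++j) qs[j] = (uint8_t)(j | ((15 - j) << 4));

    uint32_t qh;
    memcpy(&qh, qh_bytes, sizeof(qh));

    int q[32];
    for (int j = 0; j < 16; ++j) {
        const int h0 = ((qh >> (2*j + 0)) & 1) << 4;   // 5th bit of the even quant
        const int h1 = ((qh >> (2*j + 1)) & 1) << 4;   // 5th bit of the odd quant
        q[2*j + 0] = ((qs[j] & 0x0F) | h0) - 16;       // q5_0 quants end up in [-16, 15]
        q[2*j + 1] = ((qs[j] >>   4) | h1) - 16;
    }

    for (int j = 0; j < 32; ++j) printf("%d ", q[j]);
    printf("\n");
    return 0;
}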
3434
+ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
3435
+ const int nb = n / QK8_1;
3436
+
3437
+ assert(n % QK8_1 == 0);
3438
+ assert(nb % 2 == 0);
3439
+ assert(QK8_1 == QK5_1);
3440
+
3441
+ const block_q5_1 * restrict x = vx;
3442
+ const block_q8_1 * restrict y = vy;
3443
+
3444
+ #if defined(__ARM_NEON)
3445
+ float32x4_t sumv = vdupq_n_f32(0.0f);
3446
+
3447
+ float summs = 0.0f;
3448
+
3449
+ uint64_t tmp[4];
3450
+
3451
+ for (int i = 0; i < nb; ++i) {
3452
+ const block_q5_1 * restrict x0 = &x[i];
3453
+ const block_q8_1 * restrict y0 = &y[i];
3454
+
3455
+ summs += GGML_FP16_TO_FP32(x0->m) * (y0->s0 + y0->s1);
3456
+
3457
+ // extract the 5th bit
3458
+ uint32_t qh;
3459
+ memcpy(&qh, x0->qh, sizeof(qh));
3460
+
3461
+ tmp[0] = table_b2b_u[(qh >> 0) & 0xFF];
3462
+ tmp[1] = table_b2b_u[(qh >> 8) & 0xFF];
3463
+ tmp[2] = table_b2b_u[(qh >> 16) & 0xFF];
3464
+ tmp[3] = table_b2b_u[(qh >> 24) ];
3465
+
3466
+ const int8x16_t qhl = vld1q_s8((const int8_t *)(tmp + 0));
3467
+ const int8x16_t qhh = vld1q_s8((const int8_t *)(tmp + 2));
3468
+
3469
+ const uint8x16_t v0 = vld1q_u8(x0->qs);
3470
+
3471
+ // 4-bit -> 8-bit
3472
+ const int8x16_t v0l = vreinterpretq_s8_u8(vandq_u8 (v0, vdupq_n_u8(0x0F)));
3473
+ const int8x16_t v0h = vreinterpretq_s8_u8(vshrq_n_u8(v0, 4));
3474
+
3475
+ // interleave
3476
+ const int8x16_t v0lz = vzip1q_s8(v0l, v0h);
3477
+ const int8x16_t v0hz = vzip2q_s8(v0l, v0h);
3478
+
3479
+ // add
3480
+ const int8x16_t v0lf = vorrq_s8(v0lz, qhl);
3481
+ const int8x16_t v0hf = vorrq_s8(v0hz, qhh);
3482
+
3483
+ // load y
3484
+ const int8x16_t v1l = vld1q_s8(y0->qs);
3485
+ const int8x16_t v1h = vld1q_s8(y0->qs + 16);
3486
+
3487
+ const float x0d = GGML_FP16_TO_FP32(x0->d);
2793
3488
 
2794
- const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lz), vget_low_s8 (v1_1l));
2795
- const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lz), vget_high_s8(v1_1l));
2796
- const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hz), vget_low_s8 (v1_1h));
2797
- const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hz), vget_high_s8(v1_1h));
3489
+ #if defined(__ARM_FEATURE_DOTPROD)
3490
+ sumv = vmlaq_n_f32(sumv, vcvtq_f32_s32(vaddq_s32(
3491
+ vdotq_s32(vdupq_n_s32(0), v0lf, v1l),
3492
+ vdotq_s32(vdupq_n_s32(0), v0hf, v1h))), x0d*y0->d);
3493
+ #else
3494
+ const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0lf), vget_low_s8 (v1l));
3495
+ const int16x8_t pl0h = vmull_s8(vget_high_s8(v0lf), vget_high_s8(v1l));
3496
+ const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0hf), vget_low_s8 (v1h));
3497
+ const int16x8_t ph0h = vmull_s8(vget_high_s8(v0hf), vget_high_s8(v1h));
2798
3498
 
2799
3499
  const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
2800
3500
  const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
2801
- const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
2802
- const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
2803
3501
 
2804
- sumv0 = vmlaq_n_f32(sumv0, vaddq_f32(
2805
- vmulq_n_f32(vcvtq_f32_s32(pl0), GGML_FP16_TO_FP32(x0_0->d)),
2806
- vmulq_n_f32(vcvtq_f32_s32(ph0), GGML_FP16_TO_FP32(x0_1->d))), y0->d);
2807
-
2808
- sumv1 = vmlaq_n_f32(sumv1, vaddq_f32(
2809
- vmulq_n_f32(vcvtq_f32_s32(pl1), GGML_FP16_TO_FP32(x1_0->d)),
2810
- vmulq_n_f32(vcvtq_f32_s32(ph1), GGML_FP16_TO_FP32(x1_1->d))), y1->d);
3502
+ sumv = vmlaq_n_f32(sumv, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), x0d*y0->d);
2811
3503
  #endif
2812
3504
  }
2813
3505
 
2814
- sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
3506
+ *s = vaddvq_f32(sumv) + summs;
2815
3507
  #elif defined(__AVX2__)
2816
3508
  // Initialize accumulator with zeros
2817
3509
  __m256 acc = _mm256_setzero_ps();
3510
+ float summs = 0.0f;
2818
3511
 
2819
3512
  // Main loop
2820
3513
  for (int i = 0; i < nb; i++) {
2821
- /* Compute combined scale for the block */
2822
- const __m128 d0 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 0].d));
2823
- const __m128 d1 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 1].d));
2824
- const __m256 d = _mm256_mul_ps(_mm256_set_m128(d1, d0), _mm256_broadcast_ss(&y[i].d));
2825
-
2826
- __m128i bx0 = bytes_from_nibbles_16(x[2*i + 0].qs);
2827
- __m128i bx1 = bytes_from_nibbles_16(x[2*i + 1].qs);
2828
- __m256i bx = _mm256_set_m128i(bx1, bx0);
2829
-
2830
- // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
2831
- const __m256i off = _mm256_set1_epi8(8);
2832
- bx = _mm256_sub_epi8(bx, off);
3514
+ const __m256 dx = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d));
2833
3515
 
2834
- __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
3516
+ summs += GGML_FP16_TO_FP32(x[i].m) * (y[i].s0 + y[i].s1);
2835
3517
 
2836
- // Get absolute values of x vectors
2837
- const __m256i ax = _mm256_sign_epi8(bx, bx);
2838
- // Sign the values of the y vectors
2839
- const __m256i sy = _mm256_sign_epi8(by, bx);
2840
- // Perform multiplication and create 16-bit values
2841
- const __m256i dot = _mm256_maddubs_epi16(ax, sy);
3518
+ __m256i bx = bytes_from_nibbles_32(x[i].qs);
3519
+ __m256i bxhi = bytes_from_bits_32(x[i].qh);
3520
+ bxhi = _mm256_and_si256(bxhi, _mm256_set1_epi8(0x10));
3521
+ bx = _mm256_or_si256(bx, bxhi);
2842
3522
 
2843
- const __m256i ones = _mm256_set1_epi16(1);
2844
- __m256i xy_q = _mm256_madd_epi16(ones, dot);
3523
+ const __m256 dy = _mm256_broadcast_ss(&y[i].d);
3524
+ const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
2845
3525
 
2846
- /* Convert to vectore of 8 int32_t to 8 floats */
2847
- __m256 q = _mm256_cvtepi32_ps(xy_q);
3526
+ const __m256 q = mul_sum_i8_pairs_float(bx, by);
2848
3527
 
2849
- /* Multiply q with scale and accumulate */
2850
- acc = _mm256_fmadd_ps(d, q, acc);
3528
+ acc = _mm256_fmadd_ps(q, _mm256_mul_ps(dx, dy), acc);
2851
3529
  }
2852
3530
 
2853
- // Return horizontal sum of the acc vector
2854
- __m128 res = _mm256_extractf128_ps(acc, 1);
2855
- res = _mm_add_ps(res, _mm256_castps256_ps128(acc));
2856
- res = _mm_add_ps(res, _mm_movehl_ps(res, res));
2857
- res = _mm_add_ss(res, _mm_movehdup_ps(res));
2858
-
2859
- sumf = _mm_cvtss_f32(res);
3531
+ *s = hsum_float_8(acc) + summs;
2860
3532
  #else
2861
- // scalar
3533
+ float sumf = 0.0;
3534
+
2862
3535
  for (int i = 0; i < nb; i++) {
2863
- const uint8_t * restrict x0 = x[2*i + 0].qs;
2864
- const uint8_t * restrict x1 = x[2*i + 1].qs;
3536
+ const uint8_t * restrict x0 = x[i].qs;
2865
3537
  const int8_t * restrict y0 = y[i].qs;
2866
3538
 
2867
- const float d0 = GGML_FP16_TO_FP32(x[2*i + 0].d);
2868
- const float d1 = GGML_FP16_TO_FP32(x[2*i + 1].d);
3539
+ uint32_t qh;
3540
+ memcpy(&qh, x[i].qh, sizeof(qh));
2869
3541
 
2870
- int sumi_0 = 0;
2871
- int sumi_1 = 0;
3542
+ const float d = GGML_FP16_TO_FP32(x[i].d);
3543
+ const float m = GGML_FP16_TO_FP32(x[i].m);
2872
3544
 
2873
- for (int j = 0; j < QK8_0/4; j++) {
2874
- const uint8_t v0 = x0[j];
2875
- const uint8_t v1 = x1[j];
3545
+ int sxy = 0;
2876
3546
 
2877
- const int i0_0 = (int8_t) (v0 & 0xf) - 8;
2878
- const int i1_0 = (int8_t) (v0 >> 4) - 8;
3547
+ for (int j = 0; j < QK8_1/2; j++) {
3548
+ const uint8_t v0 = x0[j];
2879
3549
 
2880
- const int i0_1 = (int8_t) (v1 & 0xf) - 8;
2881
- const int i1_1 = (int8_t) (v1 >> 4) - 8;
3550
+ const int x0_0h = ((qh & (1 << (2*j + 0))) >> (2*j + 0)) << 4;
3551
+ const int x1_0h = ((qh & (1 << (2*j + 1))) >> (2*j + 1)) << 4;
2882
3552
 
2883
- const int i2_0 = y0[2*j + 0];
2884
- const int i3_0 = y0[2*j + 1];
3553
+ const int x0_0 = (v0 & 0x0F) | x0_0h;
3554
+ const int x1_0 = (v0 >> 4) | x1_0h;
2885
3555
 
2886
- const int i2_1 = y0[2*(j + QK8_0/4) + 0];
2887
- const int i3_1 = y0[2*(j + QK8_0/4) + 1];
3556
+ const int y0_0 = y0[2*j + 0];
3557
+ const int y1_0 = y0[2*j + 1];
2888
3558
 
2889
- sumi_0 += i0_0*i2_0 + i1_0*i3_0;
2890
- sumi_1 += i0_1*i2_1 + i1_1*i3_1;
3559
+ sxy += x0_0*y0_0 + x1_0*y1_0;
2891
3560
  }
2892
3561
 
2893
- sumf += (d0 * y[i].d) * sumi_0;
2894
- sumf += (d1 * y[i].d) * sumi_1;
3562
+ sumf += (d*sxy)*y[i].d + m*(y[i].s0 + y[i].s1);
2895
3563
  }
2896
- #endif
2897
3564
 
2898
3565
  *s = sumf;
3566
+ #endif
2899
3567
  }
2900
3568
 
2901
- static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
3569
+ static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
2902
3570
  const int nb = n / QK8_0;
2903
3571
 
2904
3572
  assert(n % QK8_0 == 0);
2905
3573
  assert(nb % 2 == 0);
2906
- assert(QK8_0 == 2*QK4_2);
3574
+ assert(QK8_0 == QK8_0);
2907
3575
 
2908
- const block_q4_3 * restrict x = vx;
3576
+ const block_q8_0 * restrict x = vx;
2909
3577
  const block_q8_0 * restrict y = vy;
2910
3578
 
2911
- float sumf = 0.0;
2912
-
2913
3579
  #if defined(__ARM_NEON)
2914
3580
  float32x4_t sumv0 = vdupq_n_f32(0.0f);
2915
3581
  float32x4_t sumv1 = vdupq_n_f32(0.0f);
2916
3582
 
2917
3583
  for (int i = 0; i < nb; i += 2) {
2918
- const block_q4_3 * restrict x0_0 = &x[2*(i + 0) + 0];
2919
- const block_q4_3 * restrict x0_1 = &x[2*(i + 0) + 1];
2920
- const block_q4_3 * restrict x1_0 = &x[2*(i + 1) + 0];
2921
- const block_q4_3 * restrict x1_1 = &x[2*(i + 1) + 1];
2922
-
3584
+ const block_q8_0 * restrict x0 = &x[i + 0];
3585
+ const block_q8_0 * restrict x1 = &x[i + 1];
2923
3586
  const block_q8_0 * restrict y0 = &y[i + 0];
2924
3587
  const block_q8_0 * restrict y1 = &y[i + 1];
2925
3588
 
2926
- const uint8x16_t m4b = vdupq_n_u8(0xf);
2927
-
2928
- const float x0_0d = GGML_FP16_TO_FP32(x0_0->d);
2929
- const float x0_1d = GGML_FP16_TO_FP32(x0_1->d);
2930
- const float x1_0d = GGML_FP16_TO_FP32(x1_0->d);
2931
- const float x1_1d = GGML_FP16_TO_FP32(x1_1->d);
2932
-
2933
- const float x0_0m = GGML_FP16_TO_FP32(x0_0->m);
2934
- const float x0_1m = GGML_FP16_TO_FP32(x0_1->m);
2935
- const float x1_0m = GGML_FP16_TO_FP32(x1_0->m);
2936
- const float x1_1m = GGML_FP16_TO_FP32(x1_1->m);
2937
-
2938
- const uint8x16_t v0_0 = vcombine_u8(vld1_u8(x0_0->qs), vld1_u8(x0_1->qs));
2939
- const uint8x16_t v0_1 = vcombine_u8(vld1_u8(x1_0->qs), vld1_u8(x1_1->qs));
2940
-
2941
- // 4-bit -> 8-bit
2942
- const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
2943
- const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
2944
- const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
2945
- const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
2946
-
2947
- // interleave
2948
- const int8x16_t v0_0lz = vzip1q_s8(v0_0l, v0_0h);
2949
- const int8x16_t v0_0hz = vzip2q_s8(v0_0l, v0_0h);
2950
- const int8x16_t v0_1lz = vzip1q_s8(v0_1l, v0_1h);
2951
- const int8x16_t v0_1hz = vzip2q_s8(v0_1l, v0_1h);
3589
+ const int8x16_t x0_0 = vld1q_s8(x0->qs);
3590
+ const int8x16_t x0_1 = vld1q_s8(x0->qs + 16);
3591
+ const int8x16_t x1_0 = vld1q_s8(x1->qs);
3592
+ const int8x16_t x1_1 = vld1q_s8(x1->qs + 16);
2952
3593
 
2953
3594
  // load y
2954
- const int8x16_t v1_0l = vld1q_s8(y0->qs);
2955
- const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
2956
- const int8x16_t v1_1l = vld1q_s8(y1->qs);
2957
- const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
2958
-
2959
- const int16x8_t sy0_0 = vaddq_s16(vmovl_s8(vget_low_s8(v1_0l)), vmovl_s8(vget_high_s8(v1_0l)));
2960
- const int16x8_t sy0_1 = vaddq_s16(vmovl_s8(vget_low_s8(v1_0h)), vmovl_s8(vget_high_s8(v1_0h)));
2961
-
2962
- const int16x8_t sy1_0 = vaddq_s16(vmovl_s8(vget_low_s8(v1_1l)), vmovl_s8(vget_high_s8(v1_1l)));
2963
- const int16x8_t sy1_1 = vaddq_s16(vmovl_s8(vget_low_s8(v1_1h)), vmovl_s8(vget_high_s8(v1_1h)));
2964
-
2965
- sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddl_s16(vget_low_s16(sy0_0), vget_high_s16(sy0_0))), x0_0m*y0->d);
2966
- sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddl_s16(vget_low_s16(sy0_1), vget_high_s16(sy0_1))), x0_1m*y0->d);
2967
- sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddl_s16(vget_low_s16(sy1_0), vget_high_s16(sy1_0))), x1_0m*y1->d);
2968
- sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddl_s16(vget_low_s16(sy1_1), vget_high_s16(sy1_1))), x1_1m*y1->d);
3595
+ const int8x16_t y0_0 = vld1q_s8(y0->qs);
3596
+ const int8x16_t y0_1 = vld1q_s8(y0->qs + 16);
3597
+ const int8x16_t y1_0 = vld1q_s8(y1->qs);
3598
+ const int8x16_t y1_1 = vld1q_s8(y1->qs + 16);
2969
3599
 
2970
3600
  #if defined(__ARM_FEATURE_DOTPROD)
2971
- sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0lz, v1_0l)), x0_0d*y0->d);
2972
- sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0hz, v1_0h)), x0_1d*y0->d);
2973
- sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_1lz, v1_1l)), x1_0d*y1->d);
2974
- sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_1hz, v1_1h)), x1_1d*y1->d);
2975
- #else
2976
- const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lz), vget_low_s8 (v1_0l));
2977
- const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lz), vget_high_s8(v1_0l));
2978
- const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hz), vget_low_s8 (v1_0h));
2979
- const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hz), vget_high_s8(v1_0h));
2980
-
2981
- const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lz), vget_low_s8 (v1_1l));
2982
- const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lz), vget_high_s8(v1_1l));
2983
- const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hz), vget_low_s8 (v1_1h));
2984
- const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hz), vget_high_s8(v1_1h));
3601
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
3602
+ vdotq_s32(vdupq_n_s32(0), x0_0, y0_0),
3603
+ vdotq_s32(vdupq_n_s32(0), x0_1, y0_1))), x0->d*y0->d);
2985
3604
 
2986
- const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
2987
- const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
2988
- const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
2989
- const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
3605
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
3606
+ vdotq_s32(vdupq_n_s32(0), x1_0, y1_0),
3607
+ vdotq_s32(vdupq_n_s32(0), x1_1, y1_1))), x1->d*y1->d);
2990
3608
 
2991
- sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(pl0), x0_0d*y0->d);
2992
- sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(ph0), x0_1d*y0->d);
2993
- sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(pl1), x1_0d*y1->d);
2994
- sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(ph1), x1_1d*y1->d);
3609
+ #else
3610
+ const int16x8_t p0_0 = vmull_s8(vget_low_s8 (x0_0), vget_low_s8 (y0_0));
3611
+ const int16x8_t p0_1 = vmull_s8(vget_high_s8(x0_0), vget_high_s8(y0_0));
3612
+ const int16x8_t p0_2 = vmull_s8(vget_low_s8 (x0_1), vget_low_s8 (y0_1));
3613
+ const int16x8_t p0_3 = vmull_s8(vget_high_s8(x0_1), vget_high_s8(y0_1));
3614
+
3615
+ const int16x8_t p1_0 = vmull_s8(vget_low_s8 (x1_0), vget_low_s8 (y1_0));
3616
+ const int16x8_t p1_1 = vmull_s8(vget_high_s8(x1_0), vget_high_s8(y1_0));
3617
+ const int16x8_t p1_2 = vmull_s8(vget_low_s8 (x1_1), vget_low_s8 (y1_1));
3618
+ const int16x8_t p1_3 = vmull_s8(vget_high_s8(x1_1), vget_high_s8(y1_1));
3619
+
3620
+ const int32x4_t p0 = vaddq_s32(vpaddlq_s16(p0_0), vpaddlq_s16(p0_1));
3621
+ const int32x4_t p1 = vaddq_s32(vpaddlq_s16(p0_2), vpaddlq_s16(p0_3));
3622
+ const int32x4_t p2 = vaddq_s32(vpaddlq_s16(p1_0), vpaddlq_s16(p1_1));
3623
+ const int32x4_t p3 = vaddq_s32(vpaddlq_s16(p1_2), vpaddlq_s16(p1_3));
3624
+
3625
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(p0, p1)), x0->d*y0->d);
3626
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(p2, p3)), x1->d*y1->d);
2995
3627
  #endif
2996
3628
  }
2997
3629
 
2998
- sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
2999
- #else
3000
- // scalar
3001
- for (int i = 0; i < nb; i++) {
3002
- const uint8_t * restrict x0 = x[2*i + 0].qs;
3003
- const uint8_t * restrict x1 = x[2*i + 1].qs;
3004
- const int8_t * restrict y0 = y[i].qs;
3005
-
3006
- const float d0 = GGML_FP16_TO_FP32(x[2*i + 0].d);
3007
- const float m0 = GGML_FP16_TO_FP32(x[2*i + 0].m);
3008
- const float d1 = GGML_FP16_TO_FP32(x[2*i + 1].d);
3009
- const float m1 = GGML_FP16_TO_FP32(x[2*i + 1].m);
3010
-
3011
- int sy_0 = 0;
3012
- int sy_1 = 0;
3630
+ *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
3631
+ #elif defined(__AVX2__)
3632
+ // Initialize accumulator with zeros
3633
+ __m256 acc = _mm256_setzero_ps();
3013
3634
 
3014
- int sxy_0 = 0;
3015
- int sxy_1 = 0;
3635
+ // Main loop
3636
+ for (int i = 0; i < nb; ++i) {
3637
+ // Compute combined scale for the block
3638
+ const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );
3639
+ __m256i bx = _mm256_loadu_si256((const __m256i *)x[i].qs);
3640
+ __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
3016
3641
 
3017
- for (int j = 0; j < QK8_0/4; j++) {
3018
- const uint8_t v0 = x0[j];
3019
- const uint8_t v1 = x1[j];
3642
+ const __m256 q = mul_sum_i8_pairs_float(bx, by);
3020
3643
 
3021
- const int x0_0 = v0 & 0xf;
3022
- const int x1_0 = v0 >> 4;
3644
+ // Multiply q with scale and accumulate
3645
+ acc = _mm256_fmadd_ps( d, q, acc );
3646
+ }
3023
3647
 
3024
- const int x0_1 = v1 & 0xf;
3025
- const int x1_1 = v1 >> 4;
3648
+ *s = hsum_float_8(acc);
3649
+ #else
3650
+ // scalar
3651
+ float sumf = 0.0;
3026
3652
 
3027
- const int y0_0 = y0[2*j + 0];
3028
- const int y1_0 = y0[2*j + 1];
3653
+ for (int i = 0; i < nb; i++) {
3654
+ const int8_t * restrict x0 = x[i].qs;
3655
+ const int8_t * restrict y0 = y[i].qs;
3029
3656
 
3030
- const int y0_1 = y0[2*(j + QK8_0/4) + 0];
3031
- const int y1_1 = y0[2*(j + QK8_0/4) + 1];
3657
+ int sumi = 0;
3032
3658
 
3033
- sy_0 += y0_0 + y1_0;
3034
- sy_1 += y0_1 + y1_1;
3659
+ for (int j = 0; j < QK8_0; j++) {
3660
+ const int v0 = x0[j];
3661
+ const int v1 = y0[j];
3035
3662
 
3036
- sxy_0 += x0_0*y0_0 + x1_0*y1_0;
3037
- sxy_1 += x0_1*y0_1 + x1_1*y1_1;
3663
+ sumi += v0*v1;
3038
3664
  }
3039
3665
 
3040
- sumf += (d0*sxy_0 + m0*sy_0)*y[i].d;
3041
- sumf += (d1*sxy_1 + m1*sy_1)*y[i].d;
3666
+ sumf += (x[i].d*y[i].d)*sumi;
3042
3667
  }
3043
- #endif
3044
3668
 
3045
3669
  *s = sumf;
3670
+ #endif
3046
3671
  }
3047
3672
 
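Every AVX2 block above ends in mul_sum_i8_pairs_float(bx, by), whose definition is outside these hunks. The sketch below is only a scalar stand-in for what the callers need from it: the products bx[j]*by[j] reduced to a few float partial sums (the grouping into fours is an assumption; only the overall sum matters to the callers above).

#include <stdint.h>
#include <stdio.h>

// scalar stand-in: reduce 32 signed 8-bit products to 8 float partial sums
static void mul_sum_i8_pairs_float_ref(const int8_t bx[32], const int8_t by[32], float out[8]) {
    for (int g = 0; g < 8; ++g) {
        int32_t acc = 0;
        for (int j = 0; j < 4; ++j) {
            acc += (int32_t)bx[4*g + j] * (int32_t)by[4*g + j];
        }
        out[g] = (float)acc;
    }
}

int main(void) {
    int8_t bx[32], by[32];
    float out[8], total = 0.0f;
    for (int j = 0; j < 32; ++j) { bx[j] = (int8_t)(j - 16); by[j] = (int8_t)(3 - j % 7); }
    mul_sum_i8_pairs_float_ref(bx, by, out);
    for (int g = 0; g < 8; ++g) total += out[g];
    printf("block dot product = %f\n", total);
    return 0;
}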
3048
-
3049
3673
  // compute GGML_VEC_DOT_UNROLL dot products at once
3050
3674
  // xs - x row stride in bytes
3051
3675
  inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * restrict s, void * restrict xv, ggml_fp16_t * restrict y) {
@@ -3242,6 +3866,14 @@ inline static void ggml_vec_sum_f32(const int n, float * s, const float * x) {
3242
3866
  #endif
3243
3867
  }
3244
3868
 
3869
+ inline static void ggml_vec_sum_ggf(const int n, ggml_float * s, const float * x) {
3870
+ ggml_float sum = 0.0;
3871
+ for (int i = 0; i < n; ++i) {
3872
+ sum += (ggml_float)x[i];
3873
+ }
3874
+ *s = sum;
3875
+ }
3876
+
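ggml_vec_sum_ggf above accumulates into ggml_float rather than float; assuming ggml_float is a wider type such as double, this avoids the usual float-accumulator stall once the running sum dwarfs the addends:

#include <stdio.h>

int main(void) {
    const int n = (1 << 24) + 1000;   // a long row of ones
    float  sum_f32 = 0.0f;
    double sum_f64 = 0.0;
    for (int i = 0; i < n; ++i) {
        sum_f32 += 1.0f;              // stalls at 2^24: 16777216.0f + 1.0f rounds back to 16777216.0f
        sum_f64 += 1.0;               // wider accumulator keeps the exact count
    }
    printf("f32 = %.1f, f64 = %.1f, expected = %d\n", sum_f32, sum_f64, n);
    return 0;
}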
3245
3877
  inline static void ggml_vec_max_f32(const int n, float * s, const float * x) {
3246
3878
  #ifndef GGML_USE_ACCELERATE
3247
3879
  float max = -INFINITY;
@@ -3294,12 +3926,15 @@ static const int GGML_BLCK_SIZE[GGML_TYPE_COUNT] = {
3294
3926
  [GGML_TYPE_Q4_1] = QK4_1,
3295
3927
  [GGML_TYPE_Q4_2] = QK4_2,
3296
3928
  [GGML_TYPE_Q4_3] = QK4_3,
3929
+ [GGML_TYPE_Q5_0] = QK5_0,
3930
+ [GGML_TYPE_Q5_1] = QK5_1,
3297
3931
  [GGML_TYPE_Q8_0] = QK8_0,
3932
+ [GGML_TYPE_Q8_1] = QK8_1,
3298
3933
  [GGML_TYPE_I8] = 1,
3299
3934
  [GGML_TYPE_I16] = 1,
3300
3935
  [GGML_TYPE_I32] = 1,
3301
3936
  };
3302
- static_assert(GGML_TYPE_COUNT == 10, "GGML_BLCK_SIZE is outdated");
3937
+ static_assert(GGML_TYPE_COUNT == 13, "GGML_BLCK_SIZE is outdated");
3303
3938
 
3304
3939
  static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = {
3305
3940
  [GGML_TYPE_F32] = sizeof(float),
@@ -3308,12 +3943,15 @@ static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = {
3308
3943
  [GGML_TYPE_Q4_1] = sizeof(block_q4_1),
3309
3944
  [GGML_TYPE_Q4_2] = sizeof(block_q4_2),
3310
3945
  [GGML_TYPE_Q4_3] = sizeof(block_q4_3),
3946
+ [GGML_TYPE_Q5_0] = sizeof(block_q5_0),
3947
+ [GGML_TYPE_Q5_1] = sizeof(block_q5_1),
3311
3948
  [GGML_TYPE_Q8_0] = sizeof(block_q8_0),
3949
+ [GGML_TYPE_Q8_1] = sizeof(block_q8_1),
3312
3950
  [GGML_TYPE_I8] = sizeof(int8_t),
3313
3951
  [GGML_TYPE_I16] = sizeof(int16_t),
3314
3952
  [GGML_TYPE_I32] = sizeof(int32_t),
3315
3953
  };
3316
- static_assert(GGML_TYPE_COUNT == 10, "GGML_TYPE_SIZE is outdated");
3954
+ static_assert(GGML_TYPE_COUNT == 13, "GGML_TYPE_SIZE is outdated");
3317
3955
 
3318
3956
 
3319
3957
  static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = {
@@ -3323,12 +3961,15 @@ static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = {
3323
3961
  [GGML_TYPE_Q4_1] = "q4_1",
3324
3962
  [GGML_TYPE_Q4_2] = "q4_2",
3325
3963
  [GGML_TYPE_Q4_3] = "q4_3",
3964
+ [GGML_TYPE_Q5_0] = "q5_0",
3965
+ [GGML_TYPE_Q5_1] = "q5_1",
3326
3966
  [GGML_TYPE_Q8_0] = "q8_0",
3967
+ [GGML_TYPE_Q8_1] = "q8_1",
3327
3968
  [GGML_TYPE_I8] = "i8",
3328
3969
  [GGML_TYPE_I16] = "i16",
3329
3970
  [GGML_TYPE_I32] = "i32",
3330
3971
  };
3331
- static_assert(GGML_TYPE_COUNT == 10, "GGML_TYPE_NAME is outdated");
3972
+ static_assert(GGML_TYPE_COUNT == 13, "GGML_TYPE_NAME is outdated");
3332
3973
 
3333
3974
  static bool GGML_IS_QUANTIZED[GGML_TYPE_COUNT] = {
3334
3975
  [GGML_TYPE_F32] = false,
@@ -3337,12 +3978,15 @@ static bool GGML_IS_QUANTIZED[GGML_TYPE_COUNT] = {
3337
3978
  [GGML_TYPE_Q4_1] = true,
3338
3979
  [GGML_TYPE_Q4_2] = true,
3339
3980
  [GGML_TYPE_Q4_3] = true,
3981
+ [GGML_TYPE_Q5_0] = true,
3982
+ [GGML_TYPE_Q5_1] = true,
3340
3983
  [GGML_TYPE_Q8_0] = true,
3984
+ [GGML_TYPE_Q8_1] = true,
3341
3985
  [GGML_TYPE_I8] = false,
3342
3986
  [GGML_TYPE_I16] = false,
3343
3987
  [GGML_TYPE_I32] = false,
3344
3988
  };
3345
- static_assert(GGML_TYPE_COUNT == 10, "GGML_IS_QUANTIZED is outdated");
3989
+ static_assert(GGML_TYPE_COUNT == 13, "GGML_IS_QUANTIZED is outdated");
3346
3990
 
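The four tables above follow the same pattern: designated initializers keyed by the type enum, guarded by a static_assert on GGML_TYPE_COUNT so that adding a type without revisiting every table fails to compile. A minimal sketch of that pattern with hypothetical names:

#include <assert.h>
#include <stddef.h>
#include <stdio.h>

enum my_type { MY_F32, MY_F16, MY_Q4, MY_TYPE_COUNT };

static const size_t MY_TYPE_SIZE[MY_TYPE_COUNT] = {
    [MY_F32] = 4,
    [MY_F16] = 2,
    [MY_Q4]  = 20,   // hypothetical block size: one scale plus packed nibbles
};
static_assert(MY_TYPE_COUNT == 3, "MY_TYPE_SIZE is outdated");  // trips when the enum grows

int main(void) {
    printf("MY_Q4 element size: %zu bytes\n", MY_TYPE_SIZE[MY_Q4]);
    return 0;
}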
3347
3991
  static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
3348
3992
  "NONE",
@@ -3720,7 +4364,9 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
3720
4364
 
3721
4365
  // initialize cuBLAS
3722
4366
  #if defined(GGML_USE_CUBLAS)
3723
- init_cublas();
4367
+ ggml_init_cublas();
4368
+ #elif defined(GGML_USE_CLBLAST)
4369
+ ggml_cl_init();
3724
4370
  #endif
3725
4371
 
3726
4372
  is_first_call = false;
@@ -6554,6 +7200,9 @@ static void ggml_compute_forward_add(
6554
7200
  case GGML_TYPE_Q4_1:
6555
7201
  case GGML_TYPE_Q4_2:
6556
7202
  case GGML_TYPE_Q4_3:
7203
+ case GGML_TYPE_Q5_0:
7204
+ case GGML_TYPE_Q5_1:
7205
+ case GGML_TYPE_Q8_0:
6557
7206
  {
6558
7207
  ggml_compute_forward_add_q_f32(params, src0, src1, dst);
6559
7208
  } break;
@@ -6811,15 +7460,20 @@ static void ggml_compute_forward_sum_f32(
6811
7460
  const size_t nb02 = src0->nb[2];
6812
7461
  const size_t nb03 = src0->nb[3];
6813
7462
 
7463
+ ggml_float sum = 0;
7464
+ ggml_float row_sum = 0;
7465
+
6814
7466
  for (int64_t i03 = 0; i03 < ne03; i03++) {
6815
7467
  for (int64_t i02 = 0; i02 < ne02; i02++) {
6816
7468
  for (int64_t i01 = 0; i01 < ne01; i01++) {
6817
- ggml_vec_sum_f32(ne00,
6818
- (float *) (dst->data),
7469
+ ggml_vec_sum_ggf(ne00,
7470
+ &row_sum,
6819
7471
  (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03));
7472
+ sum += row_sum;
6820
7473
  }
6821
7474
  }
6822
7475
  }
7476
+ ((float *) dst->data)[0] = sum;
6823
7477
  }
6824
7478
 
6825
7479
  static void ggml_compute_forward_sum(
@@ -7454,7 +8108,7 @@ static void ggml_compute_forward_rms_norm(
7454
8108
 
7455
8109
  // ggml_compute_forward_mul_mat
7456
8110
 
7457
- #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
8111
+ #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
7458
8112
  // helper function to determine if it is better to use BLAS or not
7459
8113
  // for large matrices, BLAS is faster
7460
8114
  static bool ggml_compute_forward_mul_mat_use_blas(
@@ -7479,6 +8133,7 @@ static bool ggml_compute_forward_mul_mat_use_blas(
7479
8133
 
7480
8134
  return false;
7481
8135
  }
8136
+
7482
8137
  #endif
7483
8138
 
7484
8139
  static void ggml_compute_forward_mul_mat_f32(
@@ -7494,7 +8149,7 @@ static void ggml_compute_forward_mul_mat_f32(
7494
8149
  const int64_t ne02 = src0->ne[2];
7495
8150
  const int64_t ne03 = src0->ne[3];
7496
8151
 
7497
- #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
8152
+ #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
7498
8153
  const int64_t ne10 = src1->ne[0];
7499
8154
  #endif
7500
8155
  const int64_t ne11 = src1->ne[1];
@@ -7551,7 +8206,7 @@ static void ggml_compute_forward_mul_mat_f32(
7551
8206
  // nb01 >= nb00 - src0 is not transposed
7552
8207
  // compute by src0 rows
7553
8208
 
7554
- #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
8209
+ #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
7555
8210
  if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
7556
8211
  if (params->ith != 0) {
7557
8212
  return;
@@ -7566,18 +8221,16 @@ static void ggml_compute_forward_mul_mat_f32(
7566
8221
  }
7567
8222
 
7568
8223
  #if defined(GGML_USE_CUBLAS)
7569
- float *d_X = NULL;
7570
- float *d_Y = NULL;
7571
- float *d_D = NULL;
7572
8224
  const float alpha = 1.0f;
7573
8225
  const float beta = 0.0f;
7574
8226
  const int x_ne = ne01 * ne10;
7575
8227
  const int y_ne = ne11 * ne10;
7576
8228
  const int d_ne = ne11 * ne01;
7577
8229
 
7578
- CUDA_CHECK(cudaMalloc((void **)(&d_X), sizeof(float) * x_ne));
7579
- CUDA_CHECK(cudaMalloc((void **)(&d_Y), sizeof(float) * y_ne));
7580
- CUDA_CHECK(cudaMalloc((void **)(&d_D), sizeof(float) * d_ne));
8230
+ size_t x_size, y_size, d_size;
8231
+ float *d_X = ggml_cuda_pool_malloc(sizeof(float) * x_ne, &x_size);
8232
+ float *d_Y = ggml_cuda_pool_malloc(sizeof(float) * y_ne, &y_size);
8233
+ float *d_D = ggml_cuda_pool_malloc(sizeof(float) * d_ne, &d_size);
7581
8234
  #endif
7582
8235
 
7583
8236
  for (int64_t i03 = 0; i03 < ne03; i03++) {
@@ -7589,21 +8242,28 @@ static void ggml_compute_forward_mul_mat_f32(
7589
8242
 
7590
8243
  #if defined(GGML_USE_CUBLAS)
7591
8244
  // copy data to device
7592
- CUDA_CHECK(cudaMemcpyAsync(d_X, x, sizeof(float) * x_ne, cudaMemcpyHostToDevice, cudaStream));
7593
- CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(float) * y_ne, cudaMemcpyHostToDevice, cudaStream));
8245
+ CUDA_CHECK(cudaMemcpyAsync(d_X, x, sizeof(float) * x_ne, cudaMemcpyHostToDevice, g_cudaStream));
8246
+ CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(float) * y_ne, cudaMemcpyHostToDevice, g_cudaStream));
7594
8247
 
7595
8248
  // compute
7596
8249
  CUBLAS_CHECK(
7597
- cublasSgemm(cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
8250
+ cublasSgemm(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
7598
8251
  ne01, ne11, ne10,
7599
8252
  &alpha, d_X, ne00,
7600
8253
  d_Y, ne10,
7601
8254
  &beta, d_D, ne01));
7602
8255
 
7603
8256
  // copy data to host
7604
- CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
7605
- #else
8257
+ CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, g_cudaStream));
8258
+ #elif defined(GGML_USE_CLBLAST)
7606
8259
  // zT = y * xT
8260
+ ggml_cl_sgemm_wrapper(GGML_BLAS_ORDER_ROW_MAJOR, GGML_BLAS_OP_N, GGML_BLAS_OP_T,
8261
+ ne11, ne01, ne10,
8262
+ 1.0f, y, ne10,
8263
+ x, ne10,
8264
+ 0.0f, d, ne01,
8265
+ GGML_TYPE_F32);
8266
+ #else
7607
8267
  cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
7608
8268
  ne11, ne01, ne10,
7609
8269
  1.0f, y, ne10,
@@ -7613,10 +8273,10 @@ static void ggml_compute_forward_mul_mat_f32(
7613
8273
  }
7614
8274
  }
7615
8275
  #if defined(GGML_USE_CUBLAS)
7616
- CUDA_CHECK(cudaStreamSynchronize(cudaStream));
7617
- CUDA_CHECK(cudaFree(d_X));
7618
- CUDA_CHECK(cudaFree(d_Y));
7619
- CUDA_CHECK(cudaFree(d_D));
8276
+ CUDA_CHECK(cudaStreamSynchronize(g_cudaStream));
8277
+ ggml_cuda_pool_free(d_X, x_size);
8278
+ ggml_cuda_pool_free(d_Y, y_size);
8279
+ ggml_cuda_pool_free(d_D, d_size);
7620
8280
  #endif
7621
8281
  //printf("CBLAS F32 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);
7622
8282
 
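The cudaMalloc/cudaFree pairs above are replaced with ggml_cuda_pool_malloc/ggml_cuda_pool_free, which recycle device buffers across matmuls instead of hitting the allocator on every call. The real pool lives in ggml-cuda; the host-memory sketch below only illustrates the reuse idea (slot count and policy are made up):

#include <stdlib.h>

#define POOL_SLOTS 16

static struct { void * ptr; size_t size; } g_pool[POOL_SLOTS];

// hand back a cached buffer that is big enough, or allocate a fresh one
static void * pool_malloc(size_t size, size_t * actual_size) {
    for (int i = 0; i < POOL_SLOTS; ++i) {
        if (g_pool[i].ptr != NULL && g_pool[i].size >= size) {
            void * ptr = g_pool[i].ptr;
            *actual_size = g_pool[i].size;   // caller remembers the real capacity
            g_pool[i].ptr = NULL;
            return ptr;
        }
    }
    *actual_size = size;
    return malloc(size);
}

// return a buffer to the pool; only fall back to free() when the pool is full
static void pool_free(void * ptr, size_t size) {
    for (int i = 0; i < POOL_SLOTS; ++i) {
        if (g_pool[i].ptr == NULL) { g_pool[i].ptr = ptr; g_pool[i].size = size; return; }
    }
    free(ptr);
}

int main(void) {
    size_t cap = 0;
    float * d_X = pool_malloc(1024 * sizeof(float), &cap);
    pool_free(d_X, cap);                      // the next request below reuses the same buffer
    float * d_Y = pool_malloc(512 * sizeof(float), &cap);
    pool_free(d_Y, cap);
    return 0;
}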
@@ -7747,7 +8407,7 @@ static void ggml_compute_forward_mul_mat_f16_f32(
7747
8407
  // nb01 >= nb00 - src0 is not transposed
7748
8408
  // compute by src0 rows
7749
8409
 
7750
- #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
8410
+ #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
7751
8411
  if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
7752
8412
  GGML_ASSERT(nb10 == sizeof(float));
7753
8413
 
@@ -7766,18 +8426,16 @@ static void ggml_compute_forward_mul_mat_f16_f32(
7766
8426
  #if defined(GGML_USE_CUBLAS)
7767
8427
  ggml_fp16_t * const wdata = params->wdata;
7768
8428
 
7769
- float *d_X = NULL;
7770
- float *d_Y = NULL;
7771
- float *d_D = NULL;
7772
8429
  const float alpha = 1.0f;
7773
8430
  const float beta = 0.0f;
7774
8431
  const int x_ne = ne01 * ne10;
7775
8432
  const int y_ne = ne11 * ne10;
7776
8433
  const int d_ne = ne11 * ne01;
7777
8434
 
7778
- CUDA_CHECK(cudaMalloc((void **)(&d_X), sizeof(ggml_fp16_t) * x_ne));
7779
- CUDA_CHECK(cudaMalloc((void **)(&d_Y), sizeof(float) * y_ne));
7780
- CUDA_CHECK(cudaMalloc((void **)(&d_D), sizeof(float) * d_ne));
8435
+ size_t x_size, y_size, d_size;
8436
+ float *d_X = ggml_cuda_pool_malloc(sizeof(float) * x_ne, &x_size);
8437
+ float *d_Y = ggml_cuda_pool_malloc(sizeof(float) * y_ne, &y_size);
8438
+ float *d_D = ggml_cuda_pool_malloc(sizeof(float) * d_ne, &d_size);
7781
8439
  #else
7782
8440
  float * const wdata = params->wdata;
7783
8441
  #endif
@@ -7811,12 +8469,12 @@ static void ggml_compute_forward_mul_mat_f16_f32(
7811
8469
  float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
7812
8470
 
7813
8471
  // copy data to device
7814
- CUDA_CHECK(cudaMemcpyAsync(d_X, x, sizeof(ggml_fp16_t) * x_ne, cudaMemcpyHostToDevice, cudaStream));
7815
- CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(ggml_fp16_t) * y_ne, cudaMemcpyHostToDevice, cudaStream));
8472
+ CUDA_CHECK(cudaMemcpyAsync(d_X, x, sizeof(ggml_fp16_t) * x_ne, cudaMemcpyHostToDevice, g_cudaStream));
8473
+ CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(ggml_fp16_t) * y_ne, cudaMemcpyHostToDevice, g_cudaStream));
7816
8474
 
7817
8475
  // compute
7818
8476
  CUBLAS_CHECK(
7819
- cublasGemmEx(cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
8477
+ cublasGemmEx(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
7820
8478
  ne01, ne11, ne10,
7821
8479
  &alpha, d_X, CUDA_R_16F, ne00,
7822
8480
  d_Y, CUDA_R_16F, ne10,
@@ -7825,7 +8483,20 @@ static void ggml_compute_forward_mul_mat_f16_f32(
7825
8483
  CUBLAS_GEMM_DEFAULT));
7826
8484
 
7827
8485
  // copy data to host
7828
- CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
8486
+ CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, g_cudaStream));
8487
+ #elif defined(GGML_USE_CLBLAST)
8488
+ const float * x = wdata;
8489
+ const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
8490
+
8491
+ float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
8492
+
8493
+ // zT = y * xT
8494
+ ggml_cl_sgemm_wrapper(GGML_BLAS_ORDER_ROW_MAJOR, GGML_BLAS_OP_N, GGML_BLAS_OP_T,
8495
+ ne11, ne01, ne10,
8496
+ 1.0f, y, ne10,
8497
+ x, ne10,
8498
+ 0.0f, d, ne01,
8499
+ GGML_TYPE_F32);
7829
8500
  #else
7830
8501
  const float * x = wdata;
7831
8502
  const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
@@ -7843,10 +8514,10 @@ static void ggml_compute_forward_mul_mat_f16_f32(
7843
8514
  }
7844
8515
 
7845
8516
  #if defined(GGML_USE_CUBLAS)
7846
- CUDA_CHECK(cudaStreamSynchronize(cudaStream));
7847
- CUDA_CHECK(cudaFree(d_X));
7848
- CUDA_CHECK(cudaFree(d_Y));
7849
- CUDA_CHECK(cudaFree(d_D));
8517
+ CUDA_CHECK(cudaStreamSynchronize(g_cudaStream));
8518
+ ggml_cuda_pool_free(d_X, x_size);
8519
+ ggml_cuda_pool_free(d_Y, y_size);
8520
+ ggml_cuda_pool_free(d_D, d_size);
7850
8521
  #endif
7851
8522
  /*printf("CBLAS F16 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);*/
7852
8523
 
@@ -7980,6 +8651,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
7980
8651
  const enum ggml_type type = src0->type;
7981
8652
  quantize_row_q_t const quantize_row_q_dot = quantize_fns[type].quantize_row_q_dot;
7982
8653
  vec_dot_q_t const vec_dot_q = quantize_fns[type].vec_dot_q;
8654
+ enum ggml_type const vec_dot_type = quantize_fns[type].vec_dot_type;
7983
8655
 
7984
8656
  // we don't support permuted src0 or src1
7985
8657
  GGML_ASSERT(nb00 == (int) GGML_TYPE_SIZE[type]);
@@ -7999,7 +8671,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
7999
8671
  // nb01 >= nb00 - src0 is not transposed
8000
8672
  // compute by src0 rows
8001
8673
 
8002
- #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
8674
+ #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
8003
8675
  if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
8004
8676
  if (params->ith != 0) {
8005
8677
  return;
@@ -8014,20 +8686,17 @@ static void ggml_compute_forward_mul_mat_q_f32(
8014
8686
  }
8015
8687
 
8016
8688
  #if defined(GGML_USE_CUBLAS)
8017
- float *d_X = NULL;
8018
- float *d_Y = NULL;
8019
- float *d_D = NULL;
8020
- float *d_Q = NULL;
8021
8689
  const float alpha = 1.0f;
8022
8690
  const float beta = 0.0f;
8023
8691
  const int x_ne = ne01 * ne10;
8024
8692
  const int y_ne = ne11 * ne10;
8025
8693
  const int d_ne = ne11 * ne01;
8026
8694
 
8027
- CUDA_CHECK(cudaMalloc((void **)(&d_X), sizeof(float) * x_ne));
8028
- CUDA_CHECK(cudaMalloc((void **)(&d_Y), sizeof(float) * y_ne));
8029
- CUDA_CHECK(cudaMalloc((void **)(&d_D), sizeof(float) * d_ne));
8030
- CUDA_CHECK(cudaMalloc((void **)(&d_Q), GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type]));
8695
+ size_t x_size, y_size, d_size, q_size;
8696
+ float *d_X = ggml_cuda_pool_malloc(sizeof(float) * x_ne, &x_size);
8697
+ float *d_Y = ggml_cuda_pool_malloc(sizeof(float) * y_ne, &y_size);
8698
+ float *d_D = ggml_cuda_pool_malloc(sizeof(float) * d_ne, &d_size);
8699
+ float *d_Q = ggml_cuda_pool_malloc(GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type], &q_size);
8031
8700
 
8032
8701
  void (*dequantize_row_q_cuda)(const void * x, float * y, int k, cudaStream_t stream) = NULL;
8033
8702
  if (type == GGML_TYPE_Q4_0) {
@@ -8039,10 +8708,22 @@ static void ggml_compute_forward_mul_mat_q_f32(
8039
8708
  else if (type == GGML_TYPE_Q4_2) {
8040
8709
  dequantize_row_q_cuda = dequantize_row_q4_2_cuda;
8041
8710
  }
8711
+ else if (type == GGML_TYPE_Q4_3) {
8712
+ dequantize_row_q_cuda = dequantize_row_q4_3_cuda;
8713
+ }
8714
+ else if (type == GGML_TYPE_Q5_0) {
8715
+ dequantize_row_q_cuda = dequantize_row_q5_0_cuda;
8716
+ }
8717
+ else if (type == GGML_TYPE_Q5_1) {
8718
+ dequantize_row_q_cuda = dequantize_row_q5_1_cuda;
8719
+ }
8720
+ else if (type == GGML_TYPE_Q8_0) {
8721
+ dequantize_row_q_cuda = dequantize_row_q8_0_cuda;
8722
+ }
8042
8723
  else {
8043
8724
  GGML_ASSERT(false);
8044
8725
  }
8045
- #else
8726
+ #elif !defined(GGML_USE_CLBLAST)
8046
8727
  float * const wdata = params->wdata;
8047
8728
  dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q;
8048
8729
  #endif
@@ -8057,10 +8738,12 @@ static void ggml_compute_forward_mul_mat_q_f32(
  // copy and dequantize on device
  CUDA_CHECK(
  cudaMemcpyAsync(d_Q, (char *) src0->data + i03*nb03 + i02*nb02,
- GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type], cudaMemcpyHostToDevice, cudaStream));
+ GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type], cudaMemcpyHostToDevice, g_cudaStream));

- dequantize_row_q_cuda(d_Q, d_X, ne01 * ne00, cudaStream);
+ dequantize_row_q_cuda(d_Q, d_X, ne01 * ne00, g_cudaStream);
  CUDA_CHECK(cudaGetLastError());
+ #elif defined(GGML_USE_CLBLAST)
+ const void* x = (char *) src0->data + i03*nb03 + i02*nb02;
  #else
  {
  size_t id = 0;
@@ -8075,20 +8758,27 @@ static void ggml_compute_forward_mul_mat_q_f32(

  #if defined(GGML_USE_CUBLAS)
  // copy data to device
- CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(float) * y_ne, cudaMemcpyHostToDevice, cudaStream));
+ CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(float) * y_ne, cudaMemcpyHostToDevice, g_cudaStream));

  // compute
  CUBLAS_CHECK(
- cublasSgemm(cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
+ cublasSgemm(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
  ne01, ne11, ne10,
  &alpha, d_X, ne00,
  d_Y, ne10,
  &beta, d_D, ne01));

  // copy data to host
- CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
- #else
+ CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, g_cudaStream));
+ #elif defined(GGML_USE_CLBLAST)
  // zT = y * xT
+ ggml_cl_sgemm_wrapper(GGML_BLAS_ORDER_ROW_MAJOR, GGML_BLAS_OP_N, GGML_BLAS_OP_T,
+ ne11, ne01, ne10,
+ 1.0f, y, ne10,
+ x, ne10,
+ 0.0f, d, ne01,
+ type);
+ #else
  cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
  ne11, ne01, ne10,
  1.0f, y, ne10,
@@ -8099,11 +8789,11 @@ static void ggml_compute_forward_mul_mat_q_f32(
  }

  #if defined(GGML_USE_CUBLAS)
- CUDA_CHECK(cudaStreamSynchronize(cudaStream));
- CUDA_CHECK(cudaFree(d_X));
- CUDA_CHECK(cudaFree(d_Y));
- CUDA_CHECK(cudaFree(d_D));
- CUDA_CHECK(cudaFree(d_Q));
+ CUDA_CHECK(cudaStreamSynchronize(g_cudaStream));
+ ggml_cuda_pool_free(d_X, x_size);
+ ggml_cuda_pool_free(d_Y, y_size);
+ ggml_cuda_pool_free(d_D, d_size);
+ ggml_cuda_pool_free(d_Q, q_size);
  #endif
  //printf("CBLAS = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);

@@ -8113,7 +8803,7 @@ static void ggml_compute_forward_mul_mat_q_f32(

  if (params->type == GGML_TASK_INIT) {
  char * wdata = params->wdata;
- const size_t row_size = ne10*GGML_TYPE_SIZE[GGML_TYPE_Q8_0]/GGML_BLCK_SIZE[GGML_TYPE_Q8_0];
+ const size_t row_size = ne10*GGML_TYPE_SIZE[vec_dot_type]/GGML_BLCK_SIZE[vec_dot_type];

  for (int64_t i13 = 0; i13 < ne13; ++i13) {
  for (int64_t i12 = 0; i12 < ne12; ++i12) {
@@ -8144,7 +8834,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
  const int ir1 = MIN(ir0 + dr, nr);

  void * wdata = params->wdata;
- const size_t row_size = ne00*GGML_TYPE_SIZE[GGML_TYPE_Q8_0]/GGML_BLCK_SIZE[GGML_TYPE_Q8_0];
+ const size_t row_size = ne00*GGML_TYPE_SIZE[vec_dot_type]/GGML_BLCK_SIZE[vec_dot_type];

  for (int ir = ir0; ir < ir1; ++ir) {
  // src0 indices
@@ -8194,7 +8884,10 @@ static void ggml_compute_forward_mul_mat(
  case GGML_TYPE_Q4_1:
  case GGML_TYPE_Q4_2:
  case GGML_TYPE_Q4_3:
+ case GGML_TYPE_Q5_0:
+ case GGML_TYPE_Q5_1:
  case GGML_TYPE_Q8_0:
+ case GGML_TYPE_Q8_1:
  {
  ggml_compute_forward_mul_mat_q_f32(params, src0, src1, dst);
  } break;
@@ -8423,7 +9116,10 @@ static void ggml_compute_forward_get_rows(
  case GGML_TYPE_Q4_1:
  case GGML_TYPE_Q4_2:
  case GGML_TYPE_Q4_3:
+ case GGML_TYPE_Q5_0:
+ case GGML_TYPE_Q5_1:
  case GGML_TYPE_Q8_0:
+ case GGML_TYPE_Q8_1:
  {
  ggml_compute_forward_get_rows_q(params, src0, src1, dst);
  } break;
@@ -8561,6 +9257,7 @@ static void ggml_compute_forward_soft_max_f32(

  uint16_t scvt;
  for (int i = 0; i < nc; i++) {
+ //printf("p[%3d] = %8.4f\n", i, p[i]);
  if (p[i] == -INFINITY) {
  p[i] = 0.0f;
  } else {
@@ -10921,7 +11618,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
  size_t cur = 0;

  if (node->src0->type == GGML_TYPE_F16 && node->src1->type == GGML_TYPE_F32) {
- #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
+ #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
  if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
  node->n_tasks = 1; // TODO: this actually is doing nothing
  // the threads are still spinning
@@ -10938,14 +11635,15 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
  } else if (node->src0->type == GGML_TYPE_F32 && node->src1->type == GGML_TYPE_F32) {
  cur = 0;
  } else if (ggml_is_quantized(node->src0->type) && node->src1->type == GGML_TYPE_F32) {
- #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
+ #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
  if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
  node->n_tasks = 1;
  cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]);
  } else
  #endif
  {
- cur = GGML_TYPE_SIZE[GGML_TYPE_Q8_0]*ggml_nelements(node->src1)/GGML_BLCK_SIZE[GGML_TYPE_Q8_0];
+ const enum ggml_type type_q = quantize_fns[node->src0->type].vec_dot_type;
+ cur = GGML_TYPE_SIZE[type_q]*ggml_nelements(node->src1)/GGML_BLCK_SIZE[type_q];
  }
  } else {
  GGML_ASSERT(false);
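The work-buffer estimate above no longer assumes the quantized dot product always consumes a Q8_0 copy of src1; it reads the companion type from ggml's quantize_fns table, so types whose kernels accumulate against a different intermediate format get the right buffer size. Restated as a small helper, purely as a sketch using the names visible in this diff:

    // sketch: scratch bytes ggml_graph_compute reserves for one quantized mul_mat node
    static size_t mul_mat_q_work_size(const struct ggml_tensor * src0, const struct ggml_tensor * src1) {
        // the type src1 is re-quantized to before the vec_dot kernel runs
        const enum ggml_type type_q = quantize_fns[src0->type].vec_dot_type;
        return GGML_TYPE_SIZE[type_q]*ggml_nelements(src1)/GGML_BLCK_SIZE[type_q];
    }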
@@ -11273,9 +11971,9 @@ void ggml_graph_print(const struct ggml_cgraph * cgraph) {
  for (int i = 0; i < cgraph->n_nodes; i++) {
  struct ggml_tensor * node = cgraph->nodes[i];

- perf_total_per_op_us[node->op] += node->perf_time_us;
+ perf_total_per_op_us[node->op] += MAX(1, node->perf_time_us);

- GGML_PRINT(" - %3d: [ %" PRId64 ", %" PRId64 ", %" PRId64 "] %16s %s (%3d) cpu = %7.3f / %7.3f ms, wall = %7.3f / %7.3f ms\n",
+ GGML_PRINT(" - %3d: [ %5" PRId64 ", %5" PRId64 ", %5" PRId64 "] %16s %s (%3d) cpu = %7.3f / %7.3f ms, wall = %7.3f / %7.3f ms\n",
  i,
  node->ne[0], node->ne[1], node->ne[2],
  GGML_OP_LABEL[node->op], node->is_param ? "x" : node->grad ? "g" : " ", node->perf_runs,
@@ -11289,13 +11987,17 @@ void ggml_graph_print(const struct ggml_cgraph * cgraph) {
  for (int i = 0; i < cgraph->n_leafs; i++) {
  struct ggml_tensor * node = cgraph->leafs[i];

- GGML_PRINT(" - %3d: [ %" PRId64 ", %" PRId64 "] %8s\n",
+ GGML_PRINT(" - %3d: [ %5" PRId64 ", %5" PRId64 "] %8s\n",
  i,
  node->ne[0], node->ne[1],
  GGML_OP_LABEL[node->op]);
  }

  for (int i = 0; i < GGML_OP_COUNT; i++) {
+ if (perf_total_per_op_us[i] == 0) {
+ continue;
+ }
+
  GGML_PRINT("perf_total_per_op_us[%16s] = %7.3f ms\n", GGML_OP_LABEL[i], (double) perf_total_per_op_us[i] / 1000.0);
  }
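ggml_graph_print is the profiling helper these changes touch: timings are now clamped to at least 1 us per node and only ops that actually ran are listed in the per-op summary. A minimal sketch of how it is typically called from user code, assuming the ggml.h API bundled with this gem (the n_threads field and the arena size here are assumptions for the sketch):

    #include "ggml.h"

    int main(void) {
        struct ggml_init_params params = {
            /*.mem_size   =*/ 16*1024*1024,   // arbitrary scratch arena for this sketch
            /*.mem_buffer =*/ NULL,
            /*.no_alloc   =*/ false,
        };
        struct ggml_context * ctx = ggml_init(params);

        // a tiny graph: c = a * b (matrix product over the shared first dimension)
        struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 64, 32);
        struct ggml_tensor * b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 64, 16);
        struct ggml_tensor * c = ggml_mul_mat(ctx, a, b);

        struct ggml_cgraph gf = ggml_build_forward(c);
        gf.n_threads = 1;                      // assumed field in this ggml version
        ggml_graph_compute(ctx, &gf);

        ggml_graph_print(&gf);                 // per-node and per-op timing summary

        ggml_free(ctx);
        return 0;
    }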

@@ -12129,7 +12831,7 @@ size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t *

  for (int i = 0; i < nb; i++) {
  for (int l = 0; l < QK4_0; l += 2) {
- const uint8_t vi0 = y[i].qs[l/2] & 0xF;
+ const uint8_t vi0 = y[i].qs[l/2] & 0x0F;
  const uint8_t vi1 = y[i].qs[l/2] >> 4;

  hist[vi0]++;
@@ -12152,7 +12854,7 @@ size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t *

  for (int i = 0; i < nb; i++) {
  for (int l = 0; l < QK4_1; l += 2) {
- const uint8_t vi0 = y[i].qs[l/2] & 0xF;
+ const uint8_t vi0 = y[i].qs[l/2] & 0x0F;
  const uint8_t vi1 = y[i].qs[l/2] >> 4;

  hist[vi0]++;
@@ -12171,12 +12873,11 @@ size_t ggml_quantize_q4_2(const float * src, void * dst, int n, int k, int64_t *
  for (int j = 0; j < n; j += k) {
  block_q4_2 * restrict y = (block_q4_2 *)dst + j/QK4_2;

- //quantize_row_q4_2_reference(src + j, y, k);
- quantize_row_q4_2_rmse(src + j, y, k);
+ quantize_row_q4_2_reference(src + j, y, k);

  for (int i = 0; i < nb; i++) {
  for (int l = 0; l < QK4_2; l += 2) {
- const uint8_t vi0 = y[i].qs[l/2] & 0xF;
+ const uint8_t vi0 = y[i].qs[l/2] & 0x0F;
  const uint8_t vi1 = y[i].qs[l/2] >> 4;

  hist[vi0]++;
@@ -12199,7 +12900,7 @@ size_t ggml_quantize_q4_3(const float * src, void * dst, int n, int k, int64_t *

  for (int i = 0; i < nb; i++) {
  for (int l = 0; l < QK4_3; l += 2) {
- const uint8_t vi0 = y[i].qs[l/2] & 0xF;
+ const uint8_t vi0 = y[i].qs[l/2] & 0x0F;
  const uint8_t vi1 = y[i].qs[l/2] >> 4;

  hist[vi0]++;
@@ -12211,6 +12912,87 @@ size_t ggml_quantize_q4_3(const float * src, void * dst, int n, int k, int64_t *
  return (n/QK4_3*sizeof(block_q4_3));
  }

+ size_t ggml_quantize_q5_0(const float * src, void * dst, int n, int k, int64_t * hist) {
+ assert(k % QK5_0 == 0);
+ const int nb = k / QK5_0;
+
+ for (int j = 0; j < n; j += k) {
+ block_q5_0 * restrict y = (block_q5_0 *)dst + j/QK5_0;
+
+ quantize_row_q5_0_reference(src + j, y, k);
+
+ for (int i = 0; i < nb; i++) {
+ uint32_t qh;
+ memcpy(&qh, &y[i].qh, sizeof(qh));
+
+ for (int l = 0; l < QK5_0; l += 2) {
+ const uint8_t vh0 = ((qh & (1 << (l + 0))) >> (l + 0)) << 4;
+ const uint8_t vh1 = ((qh & (1 << (l + 1))) >> (l + 1)) << 4;
+
+ // cast to 16 bins
+ const uint8_t vi0 = ((y[i].qs[l/2] & 0x0F) | vh0) / 2;
+ const uint8_t vi1 = ((y[i].qs[l/2] >> 4) | vh1) / 2;
+
+ hist[vi0]++;
+ hist[vi1]++;
+ }
+ }
+ }
+
+ return (n/QK5_0*sizeof(block_q5_0));
+ }
+
+ size_t ggml_quantize_q5_1(const float * src, void * dst, int n, int k, int64_t * hist) {
+ assert(k % QK5_1 == 0);
+ const int nb = k / QK5_1;
+
+ for (int j = 0; j < n; j += k) {
+ block_q5_1 * restrict y = (block_q5_1 *)dst + j/QK5_1;
+
+ quantize_row_q5_1_reference(src + j, y, k);
+
+ for (int i = 0; i < nb; i++) {
+ uint32_t qh;
+ memcpy(&qh, &y[i].qh, sizeof(qh));
+
+ for (int l = 0; l < QK5_1; l += 2) {
+ const uint8_t vh0 = ((qh & (1 << (l + 0))) >> (l + 0)) << 4;
+ const uint8_t vh1 = ((qh & (1 << (l + 1))) >> (l + 1)) << 4;
+
+ // cast to 16 bins
+ const uint8_t vi0 = ((y[i].qs[l/2] & 0x0F) | vh0) / 2;
+ const uint8_t vi1 = ((y[i].qs[l/2] >> 4) | vh1) / 2;
+
+ hist[vi0]++;
+ hist[vi1]++;
+ }
+ }
+ }
+
+ return (n/QK5_1*sizeof(block_q5_1));
+ }
+
+ size_t ggml_quantize_q8_0(const float * src, void * dst, int n, int k, int64_t * hist) {
+ assert(k % QK8_0 == 0);
+ const int nb = k / QK8_0;
+
+ for (int j = 0; j < n; j += k) {
+ block_q8_0 * restrict y = (block_q8_0 *)dst + j/QK8_0;
+
+ quantize_row_q8_0_reference(src + j, y, k);
+
+ for (int i = 0; i < nb; i++) {
+ for (int l = 0; l < QK8_0; ++l) {
+ const int8_t vi = y[i].qs[l];
+
+ hist[vi/16 + 8]++;
+ }
+ }
+ }
+
+ return (n/QK8_0*sizeof(block_q8_0));
+ }
+
  size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, int start, int n, int64_t * hist) {
  size_t result = 0;
  switch (type) {
@@ -12238,6 +13020,24 @@ size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, i
  block_q4_3 * block = (block_q4_3*)dst + start / QK4_3;
  result = ggml_quantize_q4_3(src + start, block, n, n, hist);
  } break;
+ case GGML_TYPE_Q5_0:
+ {
+ GGML_ASSERT(start % QK5_0 == 0);
+ block_q5_0 * block = (block_q5_0*)dst + start / QK5_0;
+ result = ggml_quantize_q5_0(src + start, block, n, n, hist);
+ } break;
+ case GGML_TYPE_Q5_1:
+ {
+ GGML_ASSERT(start % QK5_1 == 0);
+ block_q5_1 * block = (block_q5_1*)dst + start / QK5_1;
+ result = ggml_quantize_q5_1(src + start, block, n, n, hist);
+ } break;
+ case GGML_TYPE_Q8_0:
+ {
+ GGML_ASSERT(start % QK8_0 == 0);
+ block_q8_0 * block = (block_q8_0*)dst + start / QK8_0;
+ result = ggml_quantize_q8_0(src + start, block, n, n, hist);
+ } break;
  default:
  assert(false);
  }
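The ggml_quantize_q5_* helpers above only build a 16-bin histogram, but they show the 5-bit packing: the low four bits of element l sit in a nibble of qs, and the fifth bit is bit l of the 32-bit qh mask. A hedged sketch of how a single q5_0 weight would be recovered from that layout; the field names come from the code above, while the subtract-16 centring and the multiplication by the scale d follow the usual q5_0 convention and are assumptions here:

    #include <stdint.h>

    // illustrative only: reconstruct element l (0 <= l < QK5_0) of one q5_0 block,
    // given its scale d (passed here as float), its packed nibbles qs[] and its 32-bit high-bit mask qh
    static inline float unpack_q5_0(float d, const uint8_t * qs, uint32_t qh, int l) {
        const uint8_t lo = (l & 1) ? (qs[l/2] >> 4) : (qs[l/2] & 0x0F);  // low 4 bits
        const uint8_t hi = (uint8_t)(((qh >> l) & 1) << 4);              // 5th bit
        const int q = (int)(lo | hi) - 16;                               // assumed signed centring
        return d * (float)q;
    }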
@@ -12335,7 +13135,7 @@ int ggml_cpu_has_wasm_simd(void) {
  }

  int ggml_cpu_has_blas(void) {
- #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
+ #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
  return 1;
  #else
  return 0;
@@ -12350,6 +13150,18 @@ int ggml_cpu_has_cublas(void) {
  #endif
  }

+ int ggml_cpu_has_clblast(void) {
+ #if defined(GGML_USE_CLBLAST)
+ return 1;
+ #else
+ return 0;
+ #endif
+ }
+
+ int ggml_cpu_has_gpublas(void) {
+ return ggml_cpu_has_cublas() || ggml_cpu_has_clblast();
+ }
+
  int ggml_cpu_has_sse3(void) {
  #if defined(__SSE3__)
  return 1;
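The new ggml_cpu_has_clblast and ggml_cpu_has_gpublas queries follow the pattern of the existing ggml_cpu_has_* functions, so a caller can report at runtime which acceleration path the library was compiled with. A small usage sketch, assuming the queries are exported in ggml.h like the others:

    #include <stdio.h>
    #include "ggml.h"

    int main(void) {
        // each query returns 1 if the corresponding backend was compiled in, 0 otherwise
        printf("BLAS:     %d\n", ggml_cpu_has_blas());
        printf("cuBLAS:   %d\n", ggml_cpu_has_cublas());
        printf("CLBlast:  %d\n", ggml_cpu_has_clblast());
        printf("GPU BLAS: %d\n", ggml_cpu_has_gpublas());
        return 0;
    }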