llama_cpp 0.0.6 → 0.0.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -148,44 +148,9 @@ inline static void* ggml_aligned_malloc(size_t size) {
  #elif defined(GGML_USE_OPENBLAS)
  #include <cblas.h>
  #elif defined(GGML_USE_CUBLAS)
- #include <cublas_v2.h>
- #include <cuda_runtime.h>
  #include "ggml-cuda.h"
-
- #define CUDA_CHECK(err) \
- do { \
- cudaError_t err_ = (err); \
- if (err_ != cudaSuccess) { \
- printf("CUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__, \
- cudaGetErrorString(err_)); \
- exit(1); \
- } \
- } while (0)
-
- #define CUBLAS_CHECK(err) \
- do { \
- cublasStatus_t err_ = (err); \
- if (err_ != CUBLAS_STATUS_SUCCESS) { \
- printf("cuBLAS error %d at %s:%d\n", err_, __FILE__, __LINE__); \
- exit(1); \
- } \
- } while (0)
-
- static cublasHandle_t cublasH = NULL;
- static cudaStream_t cudaStream = NULL;
- static void init_cublas(void) {
- if (cublasH == NULL) {
- // create cublas handle, bind a stream
- CUBLAS_CHECK(cublasCreate(&cublasH));
-
- CUDA_CHECK(cudaStreamCreateWithFlags(&cudaStream, cudaStreamNonBlocking));
-
- CUBLAS_CHECK(cublasSetStream(cublasH, cudaStream));
-
- // configure logging to stdout
- // CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, NULL));
- }
- }
+ #elif defined(GGML_USE_CLBLAST)
+ #include "ggml-opencl.h"
  #endif

  #undef MIN
@@ -365,6 +330,20 @@ static ggml_fp16_t table_exp_f16[1 << 16];
  // precomputed f32 table for f16 (256 KB)
  static float table_f32_f16[1 << 16];

+ #if defined(__ARM_NEON)
+ #define B1(c,s,n) 0x ## n ## c , 0x ## n ## s
+ #define B2(c,s,n) B1(c,s,n ## c), B1(c,s,n ## s)
+ #define B3(c,s,n) B2(c,s,n ## c), B2(c,s,n ## s)
+ #define B4(c,s,n) B3(c,s,n ## c), B3(c,s,n ## s)
+ #define B5(c,s,n) B4(c,s,n ## c), B4(c,s,n ## s)
+ #define B6(c,s,n) B5(c,s,n ## c), B5(c,s,n ## s)
+ #define B7(c,s,n) B6(c,s,n ## c), B6(c,s,n ## s)
+ #define B8(c,s ) B7(c,s, c), B7(c,s, s)
+
+ // precomputed tables for expanding 8bits to 8 bytes (shl 4)
+ static const uint64_t table_b2b_u[1 << 8] = { B8(00, 10) };
+ #endif
+
  // On ARM NEON, it's quicker to directly convert x -> x instead of calling into ggml_lookup_fp16_to_fp32,
  // so we define GGML_FP16_TO_FP32 and GGML_FP32_TO_FP16 elsewhere for NEON.
  // This is also true for POWER9.
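The B1..B8 macros above build a 256-entry lookup table at preprocessing time: byte k of table_b2b_u[i] is bit k of i shifted left by 4 (0x00 or 0x10), so an 8-bit mask expands to 8 nibble-aligned bytes. A scalar sketch of what each entry encodes (my reading of the macro expansion, not code from the diff):

    #include <stdint.h>

    // Scalar equivalent of table_b2b_u[i] as generated by B8(00, 10):
    // byte k of the result is ((i >> k) & 1) << 4.
    static uint64_t b2b_u_scalar(uint8_t i) {
        uint64_t r = 0;
        for (int k = 0; k < 8; k++) {
            r |= (uint64_t)(((i >> k) & 1) << 4) << (8 * k);
        }
        return r; // e.g. b2b_u_scalar(0x03) == 0x0000000000001010
    }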
@@ -473,7 +452,7 @@ static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
  static inline __m128i bytes_from_nibbles_16(const uint8_t * rsi)
  {
  // Load 8 bytes from memory
- __m128i tmp = _mm_loadu_si64( ( const __m128i* )rsi );
+ __m128i tmp = _mm_loadl_epi64( ( const __m128i* )rsi );

  // Expand bytes into uint16_t values
  __m128i bytes = _mm_cvtepu8_epi16( tmp );
@@ -487,7 +466,46 @@ static inline __m128i bytes_from_nibbles_16(const uint8_t * rsi)
  return bytes;
  }

+ // horizontally add 8 floats
+ static inline float hsum_float_8(const __m256 x) {
+ __m128 res = _mm256_extractf128_ps(x, 1);
+ res = _mm_add_ps(res, _mm256_castps256_ps128(x));
+ res = _mm_add_ps(res, _mm_movehl_ps(res, res));
+ res = _mm_add_ss(res, _mm_movehdup_ps(res));
+ return _mm_cvtss_f32(res);
+ }
+
+ // horizontally add 8 int32_t
+ static inline int hsum_i32_8(const __m256i a) {
+ const __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(a), _mm256_extractf128_si256(a, 1));
+ const __m128i hi64 = _mm_unpackhi_epi64(sum128, sum128);
+ const __m128i sum64 = _mm_add_epi32(hi64, sum128);
+ const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
+ return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
+ }
+
+ // horizontally add 4 int32_t
+ static inline int hsum_i32_4(const __m128i a) {
+ const __m128i hi64 = _mm_unpackhi_epi64(a, a);
+ const __m128i sum64 = _mm_add_epi32(hi64, a);
+ const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
+ return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
+ }
+
  #if __AVX2__ || __AVX512F__
+ // spread 32 bits to 32 bytes { 0x00, 0xFF }
+ static inline __m256i bytes_from_bits_32(const uint8_t * x) {
+ uint32_t x32;
+ memcpy(&x32, x, sizeof(uint32_t));
+ const __m256i shuf_mask = _mm256_set_epi64x(
+ 0x0303030303030303, 0x0202020202020202,
+ 0x0101010101010101, 0x0000000000000000);
+ __m256i bytes = _mm256_shuffle_epi8(_mm256_set1_epi32(x32), shuf_mask);
+ const __m256i bit_mask = _mm256_set1_epi64x(0x7fbfdfeff7fbfdfe);
+ bytes = _mm256_or_si256(bytes, bit_mask);
+ return _mm256_cmpeq_epi8(bytes, _mm256_set1_epi64x(-1));
+ }
+
  // Unpack 32 4-bit fields into 32 bytes
  // The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
  static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)
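The helpers added above have straightforward scalar meanings: hsum_float_8 returns the sum of the 8 float lanes, hsum_i32_8 and hsum_i32_4 the sum of the 8 or 4 int32 lanes, and bytes_from_bits_32 spreads each of 32 bits to a full byte. A scalar sketch of the last one (an equivalence I infer from the intrinsics, not code from the diff):

    #include <stdint.h>
    #include <string.h>

    // out[j] = 0xFF if bit j of the 32-bit word at x is set, else 0x00.
    // The AVX2 version broadcasts byte j/8 into 8 lanes, ORs with
    // 0x7fbfdfeff7fbfdfe (each byte has all bits set except the one under
    // test), and compares against all-ones to get the 0x00/0xFF mask.
    static void bytes_from_bits_32_scalar(const uint8_t * x, uint8_t out[32]) {
        uint32_t x32;
        memcpy(&x32, x, sizeof(x32));
        for (int j = 0; j < 32; j++) {
            out[j] = ((x32 >> j) & 1) ? 0xFF : 0x00;
        }
    }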
@@ -507,9 +525,38 @@ static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)
  return bytes;
  }

+ // add int16_t pairwise and return as float vector
+ static inline __m256 sum_i16_pairs_float(const __m256i x) {
+ const __m256i ones = _mm256_set1_epi16(1);
+ const __m256i summed_pairs = _mm256_madd_epi16(ones, x);
+ return _mm256_cvtepi32_ps(summed_pairs);
+ }
+
+ // multiply int8_t, add results pairwise twice and return as float vector
+ static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
+ // Get absolute values of x vectors
+ const __m256i ax = _mm256_sign_epi8(x, x);
+ // Sign the values of the y vectors
+ const __m256i sy = _mm256_sign_epi8(y, x);
+ #if __AVXVNNI__
+ const __m256i zero = _mm256_setzero_si256();
+ const __m256i summed_pairs = _mm256_dpbusd_epi32(zero, ax, sy);
+ return _mm256_cvtepi32_ps(summed_pairs);
+ #else
+ // Perform multiplication and create 16-bit values
+ const __m256i dot = _mm256_maddubs_epi16(ax, sy);
+ return sum_i16_pairs_float(dot);
+ #endif
+ }
+
  static inline __m128i packNibbles( __m256i bytes )
  {
  // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
+ #if __AVX512F__
+ const __m256i bytes_srli_4 = _mm256_srli_epi16(bytes, 4); // 0000_0000_abcd_0000
+ bytes = _mm256_or_si256(bytes, bytes_srli_4); // 0000_abcd_abcd_efgh
+ return _mm256_cvtepi16_epi8(bytes); // abcd_efgh
+ #else
  const __m256i lowByte = _mm256_set1_epi16( 0xFF );
  __m256i high = _mm256_andnot_si256( lowByte, bytes );
  __m256i low = _mm256_and_si256( lowByte, bytes );
@@ -520,6 +567,7 @@ static inline __m128i packNibbles( __m256i bytes )
  __m128i r0 = _mm256_castsi256_si128( bytes );
  __m128i r1 = _mm256_extracti128_si256( bytes, 1 );
  return _mm_packus_epi16( r0, r1 );
+ #endif
  }
  #else
  static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
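mul_sum_i8_pairs_float is the workhorse the rewritten dot products below share: per 32-bit lane it produces the sum of four adjacent int8 products, as a float. The abs/sign pair keeps each product equal to x*y while giving _mm256_maddubs_epi16 the unsigned operand it requires (and its int16 saturation cannot trigger for the operand ranges used here). A scalar sketch of the semantics (my reading, not code from the diff):

    #include <stdint.h>

    // out[k] = (float)(x[4k]*y[4k] + x[4k+1]*y[4k+1] + x[4k+2]*y[4k+2] + x[4k+3]*y[4k+3])
    static void mul_sum_i8_pairs_float_scalar(const int8_t x[32], const int8_t y[32], float out[8]) {
        for (int k = 0; k < 8; k++) {
            int32_t sum = 0;
            for (int j = 0; j < 4; j++) {
                sum += (int32_t)x[4*k + j] * (int32_t)y[4*k + j];
            }
            out[k] = (float)sum;
        }
    }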
@@ -654,6 +702,23 @@ typedef struct {
  } block_q4_3;
  static_assert(sizeof(block_q4_3) == 2 * sizeof(ggml_fp16_t) + QK4_3 / 2, "wrong q4_3 block size/padding");

+ #define QK5_0 32
+ typedef struct {
+ ggml_fp16_t d; // delta
+ uint8_t qh[4]; // 5-th bit of quants
+ uint8_t qs[QK5_0 / 2]; // nibbles / quants
+ } block_q5_0;
+ static_assert(sizeof(block_q5_0) == sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_0 / 2, "wrong q5_0 block size/padding");
+
+ #define QK5_1 32
+ typedef struct {
+ ggml_fp16_t d; // delta
+ ggml_fp16_t m; // min
+ uint8_t qh[4]; // 5-th bit of quants
+ uint8_t qs[QK5_1 / 2]; // nibbles / quants
+ } block_q5_1;
+ static_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_1 / 2, "wrong q5_1 block size/padding");
+
  #define QK8_0 32
  typedef struct {
  float d; // delta
@@ -661,6 +726,14 @@ typedef struct {
  } block_q8_0;
  static_assert(sizeof(block_q8_0) == sizeof(float) + QK8_0, "wrong q8_0 block size/padding");

+ #define QK8_1 32
+ typedef struct {
+ float d; // delta
+ float s0; // d * sum(qs[i]) low
+ float s1; // d * sum(qs[i]) high
+ int8_t qs[QK8_1]; // quants
+ } block_q8_1;
+ static_assert(sizeof(block_q8_1) == 3*sizeof(float) + QK8_1, "wrong q8_1 block size/padding");

  // reference implementation for deterministic creation of model files
  static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k) {
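The new s0/s1 fields hold d times the sum of the first and second 16 quants. For the offset formats (q4_1, and later q5_1) this turns the per-block minimum into a precomputed correction: with x_i = d_x*q_i + m_x and y_i = d_y*r_i, the block dot product is d_x*d_y*sum(q_i*r_i) + m_x*(s0 + s1), so the kernels only need the integer dot product plus one multiply-add. A scalar sketch of that identity (my summary, not code from the diff; qx holds the unpacked 4-bit quants):

    // Block dot product of a q4_1 block (scale d, min m, quants qx[32])
    // against a block_q8_1, using the precomputed sums.
    static float block_dot_q4_1_q8_1_sketch(float d, float m, const uint8_t qx[32], const block_q8_1 * by) {
        int sumi = 0;
        for (int i = 0; i < 32; i++) {
            sumi += qx[i] * by->qs[i];                    // integer part
        }
        return d * by->d * sumi + m * (by->s0 + by->s1);  // min folded into s0+s1
    }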
@@ -671,13 +744,17 @@ static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * r

  for (int i = 0; i < nb; i++) {
  float amax = 0.0f; // absolute max
+ float max = 0.0f;

  for (int l = 0; l < QK4_0; l++) {
  const float v = x[i*QK4_0 + l];
- amax = MAX(amax, fabsf(v));
+ if (amax < fabsf(v)) {
+ amax = fabsf(v);
+ max = v;
+ }
  }

- const float d = amax / ((1 << 3) - 1);
+ const float d = max / -8;
  const float id = d ? 1.0f/d : 0.0f;

  y[i].d = d;
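The scale is now derived from the signed value with the largest magnitude, divided by -8, so that this extreme maps exactly to quant level 0 after the +8 offset and the full [-8 .. 7] integer range is usable; the old amax / 7 never produced level 0. Because the opposite extreme can round to +8, the packing below clamps with MIN(15, ...). A worked example (numbers mine):

    // Block whose largest-magnitude value is max = 2.0 (positive):
    //   d  = 2.0 / -8 = -0.25,  id = 1/d = -4.0
    //   v =  2.0 -> roundf(v*id) + 8 = roundf(-8.0) + 8 =  0
    //   v = -2.0 -> roundf( 8.0) + 8 = 16, clamped to 15 by MIN(15, ...)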
@@ -686,8 +763,8 @@ static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * r
  const float v0 = x[i*QK4_0 + l + 0]*id;
  const float v1 = x[i*QK4_0 + l + 1]*id;

- const uint8_t vi0 = (int8_t)roundf(v0) + 8;
- const uint8_t vi1 = (int8_t)roundf(v1) + 8;
+ const uint8_t vi0 = MIN(15, (int8_t)roundf(v0) + 8);
+ const uint8_t vi1 = MIN(15, (int8_t)roundf(v1) + 8);

  assert(vi0 < 16);
  assert(vi1 < 16);
@@ -707,28 +784,42 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int

  #if defined(__POWER9_VECTOR__)
  const vector float v85 = vec_splats(8.5f);
+ const vector signed int v15 = vec_splats(15);
  for (int i = 0; i < nb; i++) {
- float amax = 0.0f; // absolute max
+ float max = 0.0f;
+ float min = 0.0f;

  vector float srcv [8];
- vector float asrcv[8];
- vector float amaxv[8];
+ vector float maxv[8];
+ vector float minv[8];

  for (int l = 0; l < 8; l++) srcv[l] = *(vector float *)(x + i*32 + 4*l);
- for (int l = 0; l < 8; l++) asrcv[l] = vec_abs(srcv[l]);
-
- for (int l = 0; l < 4; l++) amaxv[2*l] = vec_max(asrcv[2*l], asrcv[2*l+1]);
- //for (int l = 0; l < 2; l++) amaxv[4*l] = vec_max(amaxv[4*l], amaxv[4*l+2]);
- amaxv[0] = vec_max(amaxv[0], amaxv[2]);
- amaxv[4] = vec_max(amaxv[4], amaxv[6]);
- //for (int l = 0; l < 1; l++) amaxv[8*l] = vec_max(amaxv[8*l], amaxv[8*l+4]);
- amaxv[0] = vec_max(amaxv[0], amaxv[4]);
-
- amax = MAX(
- MAX(vec_extract(amaxv[0], 0), vec_extract(amaxv[0], 1)),
- MAX(vec_extract(amaxv[0], 2), vec_extract(amaxv[0], 3)));
-
- const float d = amax / ((1 << 3) - 1);
+ //for (int l = 0; l < 8; l++) asrcv[l] = vec_abs(srcv[l]);
+
+ for (int l = 0; l < 4; l++) maxv[2*l] = vec_max(asrcv[2*l], asrcv[2*l+1]);
+ //for (int l = 0; l < 2; l++) maxv[4*l] = vec_max(maxv[4*l], maxv[4*l+2]);
+ maxv[0] = vec_max(maxv[0], maxv[2]);
+ maxv[4] = vec_max(maxv[4], maxv[6]);
+ //for (int l = 0; l < 1; l++) maxv[8*l] = vec_max(maxv[8*l], maxv[8*l+4]);
+ maxv[0] = vec_max(maxv[0], maxv[4]);
+
+ for (int l = 0; l < 4; l++) minv[2*l] = vec_min(asrcv[2*l], asrcv[2*l+1]);
+ //for (int l = 0; l < 2; l++) minv[4*l] = vec_min(minv[4*l], minv[4*l+2]);
+ minv[0] = vec_min(minv[0], minv[2]);
+ minv[4] = vec_min(minv[4], minv[6]);
+ //for (int l = 0; l < 1; l++) minv[8*l] = vec_min(minv[8*l], minv[8*l+4]);
+ minv[0] = vec_min(minv[0], minv[4]);
+
+
+ max = MAX(
+ MAX(vec_extract(maxv[0], 0), vec_extract(maxv[0], 1)),
+ MAX(vec_extract(maxv[0], 2), vec_extract(maxv[0], 3)));
+ min = MIN(
+ MIN(vec_extract(minv[0], 0), vec_extract(minv[0], 1)),
+ MIN(vec_extract(minv[0], 2), vec_extract(minv[0], 3)));
+
+ const float magnitude = max >= fabsf(min) ? max : min;
+ const float d = magnitude / -8;
  const float id = d ? 1.0/d : 0.0;

  y[i].d = d;
@@ -738,27 +829,33 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
  for (int l = 0; l < 8; l++) {
  const vector float vf = vec_madd(srcv[l], vid, v85);
  const vector signed int vi = vec_signed(vf);
+ const vector signed int vc = vec_min(vi, v15);

- pb[2*l + 0] = vec_extract(vi, 0) | (vec_extract(vi, 1) << 4);
- pb[2*l + 1] = vec_extract(vi, 2) | (vec_extract(vi, 3) << 4);
+ pb[2*l + 0] = vec_extract(vc, 0) | (vec_extract(vc, 1) << 4);
+ pb[2*l + 1] = vec_extract(vc, 2) | (vec_extract(vc, 3) << 4);
  }
  }
  #elif __ARM_NEON
  for (int i = 0; i < nb; i++) {
  float32x4_t srcv [8];
- float32x4_t asrcv[8];
- float32x4_t amaxv[8];
+ float32x4_t maxv[8];
+ float32x4_t minv[8];

  for (int l = 0; l < 8; l++) srcv[l] = vld1q_f32(x + i*32 + 4*l);
- for (int l = 0; l < 8; l++) asrcv[l] = vabsq_f32(srcv[l]);

- for (int l = 0; l < 4; l++) amaxv[2*l] = vmaxq_f32(asrcv[2*l], asrcv[2*l+1]);
- for (int l = 0; l < 2; l++) amaxv[4*l] = vmaxq_f32(amaxv[4*l], amaxv[4*l+2]);
- for (int l = 0; l < 1; l++) amaxv[8*l] = vmaxq_f32(amaxv[8*l], amaxv[8*l+4]);
+ for (int l = 0; l < 4; l++) maxv[2*l] = vmaxq_f32(srcv[2*l], srcv[2*l+1]);
+ for (int l = 0; l < 2; l++) maxv[4*l] = vmaxq_f32(maxv[4*l], maxv[4*l+2]);
+ for (int l = 0; l < 1; l++) maxv[8*l] = vmaxq_f32(maxv[8*l], maxv[8*l+4]);

- const float amax = vmaxvq_f32(amaxv[0]);
+ for (int l = 0; l < 4; l++) minv[2*l] = vminq_f32(srcv[2*l], srcv[2*l+1]);
+ for (int l = 0; l < 2; l++) minv[4*l] = vminq_f32(minv[4*l], minv[4*l+2]);
+ for (int l = 0; l < 1; l++) minv[8*l] = vminq_f32(minv[8*l], minv[8*l+4]);
+
+ const float max = vmaxvq_f32(maxv[0]);
+ const float min = vminvq_f32(minv[0]);

- const float d = amax / ((1 << 3) - 1);
+ const float magnitude = max >= fabsf(min) ? max : min;
+ const float d = magnitude / -8;
  const float id = d ? 1.0f/d : 0.0f;

  y[i].d = d;
@@ -767,9 +864,10 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
  const float32x4_t v = vmulq_n_f32(srcv[l], id);
  const float32x4_t vf = vaddq_f32(v, vdupq_n_f32(8.5f));
  const int32x4_t vi = vcvtq_s32_f32(vf);
+ const int32x4_t vc = vminq_s32(vi, vdupq_n_s32(15));

- y[i].qs[2*l + 0] = vgetq_lane_s32(vi, 0) | (vgetq_lane_s32(vi, 1) << 4);
- y[i].qs[2*l + 1] = vgetq_lane_s32(vi, 2) | (vgetq_lane_s32(vi, 3) << 4);
+ y[i].qs[2*l + 0] = vgetq_lane_s32(vc, 0) | (vgetq_lane_s32(vc, 1) << 4);
+ y[i].qs[2*l + 1] = vgetq_lane_s32(vc, 2) | (vgetq_lane_s32(vc, 3) << 4);
  }
  }
  #elif defined(__AVX2__)
@@ -781,22 +879,31 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
  __m256 v3 = _mm256_loadu_ps( x + 24 );
  x += 32;

- // Compute max(abs(e)) for the block
- const __m256 signBit = _mm256_set1_ps( -0.0f );
- __m256 maxAbs = _mm256_andnot_ps( signBit, v0 );
- maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) );
- maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );
- maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) );
+ // Compute max for the block
+ __m256 max = _mm256_max_ps( v0, v1 );
+ __m256 maxTmp = _mm256_max_ps( v2, v3 );
+ max = _mm256_max_ps( max, maxTmp );

- __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
+ __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( max, 1 ), _mm256_castps256_ps128( max ) );
  max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
  max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
  const float maxScalar = _mm_cvtss_f32( max4 );

+ // Compute min for the block
+ __m256 min = _mm256_min_ps( v0, v1 );
+ __m256 minTmp = _mm256_min_ps( v2, v3 );
+ min = _mm256_min_ps( min, minTmp );
+
+ __m128 min4 = _mm_min_ps( _mm256_extractf128_ps( min, 1 ), _mm256_castps256_ps128( min ) );
+ min4 = _mm_min_ps( min4, _mm_movehl_ps( min4, min4 ) );
+ min4 = _mm_min_ss( min4, _mm_movehdup_ps( min4 ) );
+ const float minScalar = _mm_cvtss_f32( min4 );
+
  // Quantize these floats
- const float d = maxScalar / 7.0f;
+ const float magnitude = maxScalar >= fabsf(minScalar) ? maxScalar : minScalar;
+ const float d = magnitude / -8.0f;
  y[i].d = d;
- const float id = ( maxScalar != 0.0f ) ? 7.0f / maxScalar : 0.0f;
+ const float id = ( magnitude != 0.0f ) ? -8.0f / magnitude : 0.0f;
  const __m256 mul = _mm256_set1_ps( id );

  // Apply the multiplier
@@ -829,9 +936,11 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
  const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
  i0 = _mm256_permutevar8x32_epi32( i0, perm );

- // Apply offset to translate the range from [ -7 .. +7 ] into [ +1 .. +15 ]
+ // Apply offset and clamp to translate the range from [ -8 .. +8 ] into [ +0 .. +15 ]
  const __m256i off = _mm256_set1_epi8( 8 );
  i0 = _mm256_add_epi8( i0, off );
+ const __m256i maxNibble = _mm256_set1_epi8( 15 );
+ i0 = _mm256_min_epi8( i0, maxNibble );

  // Compress the vector into 4 bit/value, and store
  __m128i res = packNibbles( i0 );
@@ -846,22 +955,31 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
  __m256 v3 = _mm256_loadu_ps( x + 24 );
  x += 32;

- // Compute max(abs(e)) for the block
- const __m256 signBit = _mm256_set1_ps( -0.0f );
- __m256 maxAbs = _mm256_andnot_ps( signBit, v0 );
- maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) );
- maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );
- maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) );
+ // Compute max for the block
+ __m256 max = _mm256_max_ps( v0, v1 );
+ __m256 maxTmp = _mm256_max_ps( v2, v3 );
+ max = _mm256_max_ps( max, maxTmp );

- __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
+ __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( max, 1 ), _mm256_castps256_ps128( max ) );
  max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
  max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
  const float maxScalar = _mm_cvtss_f32( max4 );

+ // Compute min for the block
+ __m256 min = _mm256_min_ps( v0, v1 );
+ __m256 minTmp = _mm256_min_ps( v2, v3 );
+ min = _mm256_min_ps( min, minTmp );
+
+ __m128 min4 = _mm_min_ps( _mm256_extractf128_ps( min, 1 ), _mm256_castps256_ps128( min ) );
+ min4 = _mm_min_ps( min4, _mm_movehl_ps( min4, min4 ) );
+ min4 = _mm_min_ss( min4, _mm_movehdup_ps( min4 ) );
+ const float minScalar = _mm_cvtss_f32( min4 );
+
  // Quantize these floats
- const float d = maxScalar / 7.0f;
+ const float magnitude = maxScalar >= fabsf(minScalar) ? maxScalar : minScalar;
+ const float d = magnitude / -8.0f;
  y[i].d = d;
- const float id = ( maxScalar != 0.0f ) ? 7.0f / maxScalar : 0.0f;
+ const float id = ( magnitude != 0.0f ) ? -8.0f / magnitude : 0.0f;
  const __m256 mul = _mm256_set1_ps( id );

  // Apply the multiplier
@@ -902,10 +1020,13 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
  ni0 = _mm_packs_epi16( ni0, ni2 );
  ni4 = _mm_packs_epi16( ni4, ni6 );

- // Apply offset to translate the range from [ -7 .. +7 ] into [ +1 .. +15 ]
- const __m128i off = _mm_set1_epi8( 8);
+ // Apply offset and clamp to translate the range from [ -8 .. +8 ] into [ +0 .. +15 ]
+ const __m128i off = _mm_set1_epi8( 8 );
  ni0 = _mm_add_epi8( ni0, off );
  ni4 = _mm_add_epi8( ni4, off );
+ const __m128i maxNibble = _mm_set1_epi8( 15 );
+ ni0 = _mm_min_epi8( ni0, maxNibble );
+ ni4 = _mm_min_epi8( ni4, maxNibble );

  // Compress the vector into 4 bit/value, and store
  __m128i res = packNibbles( ni0, ni4 );
@@ -913,24 +1034,32 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
  }
  #elif defined(__wasm_simd128__)
  for (int i = 0; i < nb; i++) {
- float amax = 0.0f; // absolute max
+ float max = 0.0f;
+ float min = 0.0f;

  v128_t srcv [8];
- v128_t asrcv[8];
- v128_t amaxv[8];
+ v128_t maxv[8];
+ v128_t minv[8];

  for (int l = 0; l < 8; l++) srcv[l] = wasm_v128_load(x + i*32 + 4*l);
- for (int l = 0; l < 8; l++) asrcv[l] = wasm_f32x4_abs(srcv[l]);

- for (int l = 0; l < 4; l++) amaxv[2*l] = wasm_f32x4_max(asrcv[2*l], asrcv[2*l+1]);
- for (int l = 0; l < 2; l++) amaxv[4*l] = wasm_f32x4_max(amaxv[4*l], amaxv[4*l+2]);
- for (int l = 0; l < 1; l++) amaxv[8*l] = wasm_f32x4_max(amaxv[8*l], amaxv[8*l+4]);
+ for (int l = 0; l < 4; l++) maxv[2*l] = wasm_f32x4_max(srcv[2*l], srcv[2*l+1]);
+ for (int l = 0; l < 2; l++) maxv[4*l] = wasm_f32x4_max(maxv[4*l], maxv[4*l+2]);
+ for (int l = 0; l < 1; l++) maxv[8*l] = wasm_f32x4_max(maxv[8*l], maxv[8*l+4]);
+
+ for (int l = 0; l < 4; l++) minv[2*l] = wasm_f32x4_min(srcv[2*l], srcv[2*l+1]);
+ for (int l = 0; l < 2; l++) minv[4*l] = wasm_f32x4_min(minv[4*l], minv[4*l+2]);
+ for (int l = 0; l < 1; l++) minv[8*l] = wasm_f32x4_min(minv[8*l], minv[8*l+4]);

- amax = MAX(
- MAX(wasm_f32x4_extract_lane(amaxv[0], 0), wasm_f32x4_extract_lane(amaxv[0], 1)),
- MAX(wasm_f32x4_extract_lane(amaxv[0], 2), wasm_f32x4_extract_lane(amaxv[0], 3)));
+ max = MAX(
+ MAX(wasm_f32x4_extract_lane(maxv[0], 0), wasm_f32x4_extract_lane(maxv[0], 1)),
+ MAX(wasm_f32x4_extract_lane(maxv[0], 2), wasm_f32x4_extract_lane(maxv[0], 3)));
+ min = MIN(
+ MIN(wasm_f32x4_extract_lane(minv[0], 0), wasm_f32x4_extract_lane(minv[0], 1)),
+ MIN(wasm_f32x4_extract_lane(minv[0], 2), wasm_f32x4_extract_lane(minv[0], 3)));

- const float d = amax / ((1 << 3) - 1);
+ const float magnitude = max >= fabsf(min) ? max : min;
+ const float d = magnitude / -8;
  const float id = d ? 1.0/d : 0.0;

  y[i].d = d;
@@ -939,9 +1068,10 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
  const v128_t v = wasm_f32x4_mul(srcv[l], wasm_f32x4_splat(id));
  const v128_t vf = wasm_f32x4_add(v, wasm_f32x4_splat(8.5f));
  const v128_t vi = wasm_i32x4_trunc_sat_f32x4(vf);
+ const v128_t vc = wasm_i32x4_min_u(vi, wasm_i32x4_splat(15));

- y[i].qs[2*l + 0] = wasm_i32x4_extract_lane(vi, 0) | (wasm_i32x4_extract_lane(vi, 1) << 4);
- y[i].qs[2*l + 1] = wasm_i32x4_extract_lane(vi, 2) | (wasm_i32x4_extract_lane(vi, 3) << 4);
+ y[i].qs[2*l + 0] = wasm_i32x4_extract_lane(vc, 0) | (wasm_i32x4_extract_lane(vc, 1) << 4);
+ y[i].qs[2*l + 1] = wasm_i32x4_extract_lane(vc, 2) | (wasm_i32x4_extract_lane(vc, 3) << 4);
  }
  }
  #else
@@ -1122,13 +1252,17 @@ static void quantize_row_q4_2_reference(const float * restrict x, block_q4_2 * r

  for (int i = 0; i < nb; i++) {
  float amax = 0.0f; // absolute max
+ float max = 0.0f;

  for (int l = 0; l < QK4_2; l++) {
  const float v = x[i*QK4_2 + l];
- amax = MAX(amax, fabsf(v));
+ if (amax < fabsf(v)) {
+ amax = fabsf(v);
+ max = v;
+ }
  }

- const float d = amax / ((1 << 3) - 1);
+ const float d = max / -8;

  const float id = d ? 1.0f/d : 0.0f;

@@ -1138,93 +1272,14 @@ static void quantize_row_q4_2_reference(const float * restrict x, block_q4_2 * r
  const float v0 = x[i*QK4_2 + l + 0]*id;
  const float v1 = x[i*QK4_2 + l + 1]*id;

- const uint8_t vi0 = (uint8_t)(v0 + 8.5f);
- const uint8_t vi1 = (uint8_t)(v1 + 8.5f);
-
- assert(vi0 < 16);
- assert(vi1 < 16);
-
- y[i].qs[l/2] = vi0 | (vi1 << 4);
- }
- }
- }
-
- static inline int nearest_int(float fval) {
- assert(fval <= 4194303.f);
- float val = fval + 12582912.f;
- int i; memcpy(&i, &val, sizeof(int));
- return (i & 0x007fffff) - 0x00400000;
- }
-
- static float kquantize_q4_with_bounds(int n, int nmin, int nmax, const float * restrict X, int nCandidates,
- const float * restrict candidates, int8_t * restrict L) {
- assert (nmin >= INT8_MIN);
- assert (nmax <= INT8_MAX);
- float amax = 0;
- for (int i=0; i<n; ++i) amax = MAX(amax, fabsf(X[i]));
- if (!amax) { // all zero
- for (int i=0; i<n; ++i) L[i] = 0;
- return 1.f;
- }
- float best = 0, bestScale = 0;
- for (int si=0; si<nCandidates; ++si) {
- float iscale = candidates[si]/amax;
- float sumlxP = 0; int suml2P = 0;
- float sumlxM = 0; int suml2M = 0;
- for (int i=0; i<n; ++i) {
- int l = nearest_int(iscale*X[i]);
- int lp = MAX(nmin, MIN(nmax, +l));
- int lm = MAX(nmin, MIN(nmax, -l));
- sumlxP += X[i]*lp; suml2P += lp*lp;
- sumlxM += X[i]*lm; suml2M += lm*lm;
- }
- float sumlxP2 = sumlxP*sumlxP;
- float sumlxM2 = sumlxM*sumlxM;
- if (sumlxP2*suml2M > sumlxM2*suml2P) {
- if (sumlxP2 > best*suml2P) {
- best = sumlxP2/suml2P; bestScale = iscale;
- }
- } else {
- if (sumlxM2 > best*suml2M) {
- best = sumlxM2/suml2M; bestScale = -iscale;
- }
- }
- }
- float sumlx = 0; int suml2 = 0;
- for (int i=0; i<n; ++i) {
- int l = nearest_int(bestScale*X[i]);
- l = MAX(nmin, MIN(nmax, l));
- sumlx += X[i]*l; suml2 += l*l;
- L[i] = l;
- }
- float scale = sumlx/suml2;
- return scale;
- }
-
- static void quantize_row_q4_2_rmse(const float * restrict x, block_q4_2 * restrict y, int k) {
- #define CANDIDATE_COUNT 8
- static const float candidates[CANDIDATE_COUNT] = { +8.7f, +8.3f, +8.1f, +7.8f, +7.3f, +7.0f, +6.3f, +5.7f };
- assert(k % QK4_2 == 0);
-
- int8_t L[QK4_2];
-
- const int nb = k / QK4_2;
-
- for (int i = 0; i < nb; i++) {
- float scale = kquantize_q4_with_bounds(QK4_2, -8, 7, x, CANDIDATE_COUNT, candidates, L);
- y[i].d = GGML_FP32_TO_FP16(scale);
-
- for (int l = 0; l < QK4_2; l += 2) {
- const uint8_t vi0 = (uint8_t)(L[l+0] + 8);
- const uint8_t vi1 = (uint8_t)(L[l+1] + 8);
+ const uint8_t vi0 = MIN(15, (uint8_t)(v0 + 8.5f));
+ const uint8_t vi1 = MIN(15, (uint8_t)(v1 + 8.5f));

  assert(vi0 < 16);
  assert(vi1 < 16);

  y[i].qs[l/2] = vi0 | (vi1 << 4);
  }
-
- x += QK4_2;
  }
  }

@@ -1233,9 +1288,7 @@ static void quantize_row_q4_2(const float * restrict x, void * restrict vy, int

  block_q4_2 * restrict y = vy;

- //quantize_row_q4_2_reference(x, y, k);
- // This produces the exact same format, just better match to the input floats ("better" as measured by RMSE)
- quantize_row_q4_2_rmse(x, y, k);
+ quantize_row_q4_2_reference(x, y, k);
  }

  static void quantize_row_q4_3_reference(const float * restrict x, block_q4_3 * restrict y, int k) {
@@ -1281,6 +1334,103 @@ static void quantize_row_q4_3(const float * restrict x, void * restrict vy, int
  quantize_row_q4_3_reference(x, y, k);
  }

+ static void quantize_row_q5_0_reference(const float * restrict x, block_q5_0 * restrict y, int k) {
+ assert(k % QK5_0 == 0);
+ const int nb = k / QK5_0;
+
+ for (int i = 0; i < nb; i++) {
+ float amax = 0.0f; // absolute max
+ float max = 0.0f;
+
+ for (int l = 0; l < QK5_0; l++) {
+ const float v = x[i*QK5_0 + l];
+ if (amax < fabsf(v)) {
+ amax = fabsf(v);
+ max = v;
+ }
+ }
+
+ const float d = max / -16;
+ const float id = d ? 1.0f/d : 0.0f;
+
+ y[i].d = GGML_FP32_TO_FP16(d);
+
+ uint32_t qh = 0;
+
+ for (int l = 0; l < QK5_0; l += 2) {
+ const float v0 = x[i*QK5_0 + l + 0]*id;
+ const float v1 = x[i*QK5_0 + l + 1]*id;
+
+ const uint32_t vi0 = MIN(31, (int) (v0 + 16.5f));
+ const uint32_t vi1 = MIN(31, (int) (v1 + 16.5f));
+
+ y[i].qs[l/2] = (vi0 & 0x0F) | ((vi1 & 0x0F) << 4);
+
+ // get the 5-th bit and store it in qh at the right position
+ qh |= ((vi0 & 0x10) >> 4) << (l + 0);
+ qh |= ((vi1 & 0x10) >> 4) << (l + 1);
+ }
+
+ memcpy(&y[i].qh, &qh, sizeof(y[i].qh));
+ }
+ }
+
+ static void quantize_row_q5_0(const float * restrict x, void * restrict vy, int k) {
+ assert(k % QK5_0 == 0);
+
+ block_q5_0 * restrict y = vy;
+
+ quantize_row_q5_0_reference(x, y, k);
+ }
+
+ static void quantize_row_q5_1_reference(const float * restrict x, block_q5_1 * restrict y, int k) {
+ assert(k % QK5_1 == 0);
+ const int nb = k / QK5_1;
+
+ for (int i = 0; i < nb; i++) {
+ float min = FLT_MAX;
+ float max = -FLT_MAX;
+
+ for (int l = 0; l < QK5_1; l++) {
+ const float v = x[i*QK5_1 + l];
+ if (v < min) min = v;
+ if (v > max) max = v;
+ }
+
+ const float d = (max - min) / ((1 << 5) - 1);
+ const float id = d ? 1.0f/d : 0.0f;
+
+ y[i].d = GGML_FP32_TO_FP16(d);
+ y[i].m = GGML_FP32_TO_FP16(min);
+
+ uint32_t qh = 0;
+
+ for (int l = 0; l < QK5_1; l += 2) {
+ const float v0 = (x[i*QK5_1 + l + 0] - min)*id;
+ const float v1 = (x[i*QK5_1 + l + 1] - min)*id;
+
+ const uint32_t vi0 = (int) (v0 + 0.5f);
+ const uint32_t vi1 = (int) (v1 + 0.5f);
+
+ y[i].qs[l/2] = (vi0 & 0x0F) | ((vi1 & 0x0F) << 4);
+
+ // get the 5-th bit and store it in qh at the right position
+ qh |= ((vi0 & 0x10) >> 4) << (l + 0);
+ qh |= ((vi1 & 0x10) >> 4) << (l + 1);
+ }
+
+ memcpy(&y[i].qh, &qh, sizeof(y[i].qh));
+ }
+ }
+
+ static void quantize_row_q5_1(const float * restrict x, void * restrict vy, int k) {
+ assert(k % QK5_1 == 0);
+
+ block_q5_1 * restrict y = vy;
+
+ quantize_row_q5_1_reference(x, y, k);
+ }
+
  // reference implementation for deterministic creation of model files
  static void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * restrict y, int k) {
  assert(k % QK8_0 == 0);
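In the q5_0/q5_1 encoders above, each quant is 5 bits: the low nibble is packed two-per-byte into qs, and the 5th bit of element l becomes bit l of the qh bitmap. A worked example for one pair (numbers mine):

    // Quants vi0 = 19 (0b10011) and vi1 = 3 (0b00011) at positions l, l+1:
    //   qs[l/2]    = (19 & 0x0F) | ((3 & 0x0F) << 4) = 0x03 | 0x30 = 0x33
    //   qh bit l   = (19 >> 4) & 1 = 1
    //   qh bit l+1 = ( 3 >> 4) & 1 = 0
    // dequantize_row_q5_0 below ORs the qh bit back in at bit position 4.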
@@ -1300,18 +1450,64 @@ static void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * r
  y[i].d = d;

  for (int l = 0; l < QK8_0; ++l) {
- const float v = x[i*QK8_0 + l]*id;
- y[i].qs[l] = roundf(v);
+ const float v0 = x[i*QK8_0 + l]*id;
+
+ y[i].qs[l] = roundf(v0);
  }
  }
  }

  static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int k) {
  assert(k % QK8_0 == 0);
- const int nb = k / QK8_0;

  block_q8_0 * restrict y = vy;

+ quantize_row_q8_0_reference(x, y, k);
+ }
+
+ // reference implementation for deterministic creation of model files
+ static void quantize_row_q8_1_reference(const float * restrict x, block_q8_1 * restrict y, int k) {
+ assert(k % QK8_1 == 0);
+ const int nb = k / QK8_1;
+
+ for (int i = 0; i < nb; i++) {
+ float amax = 0.0f; // absolute max
+
+ for (int l = 0; l < QK8_1; l++) {
+ const float v = x[i*QK8_1 + l];
+ amax = MAX(amax, fabsf(v));
+ }
+
+ const float d = amax / ((1 << 7) - 1);
+ const float id = d ? 1.0f/d : 0.0f;
+
+ y[i].d = d;
+
+ int sum0 = 0;
+ int sum1 = 0;
+
+ for (int l = 0; l < QK8_1/2; ++l) {
+ const float v0 = x[i*QK8_1 + l]*id;
+ const float v1 = x[i*QK8_1 + QK8_1/2 + l]*id;
+
+ y[i].qs[ l] = roundf(v0);
+ y[i].qs[QK8_1/2 + l] = roundf(v1);
+
+ sum0 += y[i].qs[ l];
+ sum1 += y[i].qs[QK8_1/2 + l];
+ }
+
+ y[i].s0 = d * sum0;
+ y[i].s1 = d * sum1;
+ }
+ }
+
+ static void quantize_row_q8_1(const float * restrict x, void * restrict vy, int k) {
+ assert(k % QK8_1 == 0);
+ const int nb = k / QK8_1;
+
+ block_q8_1 * restrict y = vy;
+
  #if defined(__ARM_NEON)
  for (int i = 0; i < nb; i++) {
  float32x4_t srcv [8];
@@ -1332,7 +1528,24 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int

  y[i].d = d;

- for (int l = 0; l < 8; l++) {
+ int32x4_t accv0 = vdupq_n_s32(0);
+ int32x4_t accv1 = vdupq_n_s32(0);
+
+ // low half
+ for (int l = 0; l < 4; l++) {
+ const float32x4_t v = vmulq_n_f32(srcv[l], id);
+ const int32x4_t vi = vcvtnq_s32_f32(v);
+
+ y[i].qs[4*l + 0] = vgetq_lane_s32(vi, 0);
+ y[i].qs[4*l + 1] = vgetq_lane_s32(vi, 1);
+ y[i].qs[4*l + 2] = vgetq_lane_s32(vi, 2);
+ y[i].qs[4*l + 3] = vgetq_lane_s32(vi, 3);
+
+ accv0 = vaddq_s32(accv0, vi);
+ }
+
+ // high half
+ for (int l = 4; l < 8; l++) {
  const float32x4_t v = vmulq_n_f32(srcv[l], id);
  const int32x4_t vi = vcvtnq_s32_f32(v);

@@ -1340,7 +1553,15 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
  y[i].qs[4*l + 1] = vgetq_lane_s32(vi, 1);
  y[i].qs[4*l + 2] = vgetq_lane_s32(vi, 2);
  y[i].qs[4*l + 3] = vgetq_lane_s32(vi, 3);
+
+ accv1 = vaddq_s32(accv1, vi);
  }
+
+ const int32_t sum0 = vaddvq_s32(accv0);
+ const int32_t sum1 = vaddvq_s32(accv1);
+
+ y[i].s0 = d * sum0;
+ y[i].s1 = d * sum1;
  }
  #elif defined(__AVX2__) || defined(__AVX__)
  for (int i = 0; i < nb; i++) {
@@ -1388,6 +1609,11 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
  __m256i i3 = _mm256_cvtps_epi32( v3 );

  #if defined(__AVX2__)
+ // Compute the sum of the quants and set y[i].s
+ //y[i].s = d * hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3)));
+ y[i].s0 = d * hsum_i32_8(_mm256_add_epi32(i0, i1));
+ y[i].s1 = d * hsum_i32_8(_mm256_add_epi32(i2, i3));
+
  // Convert int32 to int16
  i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15
  i2 = _mm256_packs_epi32( i2, i3 ); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31
@@ -1413,6 +1639,12 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
  __m128i ni6 = _mm256_castsi256_si128( i3 );
  __m128i ni7 = _mm256_extractf128_si256( i3, 1);

+ // Compute the sum of the quants and set y[i].s
+ const __m128i s0 = _mm_add_epi32(_mm_add_epi32(ni0, ni1), _mm_add_epi32(ni2, ni3));
+ const __m128i s1 = _mm_add_epi32(_mm_add_epi32(ni4, ni5), _mm_add_epi32(ni6, ni7));
+ y[i].s0 = d * hsum_i32_4(s0);
+ y[i].s1 = d * hsum_i32_4(s1);
+
  // Convert int32 to int16
  ni0 = _mm_packs_epi32( ni0, ni1 );
  ni2 = _mm_packs_epi32( ni2, ni3 );
@@ -1428,7 +1660,7 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
  }
  #else
  // scalar
- quantize_row_q8_0_reference(x, y, k);
+ quantize_row_q8_1_reference(x, y, k);
  #endif
  }

@@ -1482,7 +1714,7 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in
  const uint8x8_t v8 = vld1_u8(pp + l/2);

  // Expand 4-bit qs to 8-bit bytes
- const uint8x8_t v0 = vand_u8(v8, vdup_n_u8(0x0f));
+ const uint8x8_t v0 = vand_u8(v8, vdup_n_u8(0x0F));
  const uint8x8_t v1 = vshr_n_u8(v8, 4);

  // Convert to signed 8-bit integers
@@ -1532,7 +1764,7 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in
  for (int l = 0; l < QK4_0; l += 2) {
  const uint8_t vi = pp[l/2];

- const int8_t vi0 = vi & 0xf;
+ const int8_t vi0 = vi & 0x0F;
  const int8_t vi1 = vi >> 4;

  const float v0 = (vi0 - 8)*d;
@@ -1598,7 +1830,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
  const uint8x8_t v8 = vld1_u8(pp + l/2);

  // Expand 4-bit qs to 8-bit bytes
- const uint8x8_t v0 = vand_u8(v8, vdup_n_u8(0x0f));
+ const uint8x8_t v0 = vand_u8(v8, vdup_n_u8(0x0F));
  const uint8x8_t v1 = vshr_n_u8(v8, 4);

  // Interleave and combine
@@ -1640,7 +1872,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
  for (int l = 0; l < QK4_1; l += 2) {
  const uint8_t vi = pp[l/2];

- const int8_t vi0 = vi & 0xf;
+ const int8_t vi0 = vi & 0x0F;
  const int8_t vi1 = vi >> 4;

  const float v0 = vi0*d + m;
@@ -1670,7 +1902,7 @@ static void dequantize_row_q4_2(const void * restrict vx, float * restrict y, in
  for (int l = 0; l < QK4_2; l += 2) {
  const uint8_t vi = pp[l/2];

- const int8_t vi0 = vi & 0xf;
+ const int8_t vi0 = vi & 0x0F;
  const int8_t vi1 = vi >> 4;

  const float v0 = (vi0 - 8)*d;
@@ -1700,7 +1932,7 @@ static void dequantize_row_q4_3(const void * restrict vx, float * restrict y, in
  for (int l = 0; l < QK4_3; l += 2) {
  const uint8_t vi = pp[l/2];

- const int8_t vi0 = vi & 0xf;
+ const int8_t vi0 = vi & 0x0F;
  const int8_t vi1 = vi >> 4;

  const float v0 = vi0*d + m;
@@ -1715,54 +1947,176 @@ static void dequantize_row_q4_3(const void * restrict vx, float * restrict y, in
  }
  }

- static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
- static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
- static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
- static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
-
- static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
- [GGML_TYPE_Q4_0] = {
- .dequantize_row_q = dequantize_row_q4_0,
- .quantize_row_q = quantize_row_q4_0,
- .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_0_reference,
- .quantize_row_q_dot = quantize_row_q8_0,
- .vec_dot_q = ggml_vec_dot_q4_0_q8_0,
- },
- [GGML_TYPE_Q4_1] = {
- .dequantize_row_q = dequantize_row_q4_1,
- .quantize_row_q = quantize_row_q4_1,
- .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_1_reference,
- .quantize_row_q_dot = quantize_row_q8_0,
- .vec_dot_q = ggml_vec_dot_q4_1_q8_0,
- },
- [GGML_TYPE_Q4_2] = {
- .dequantize_row_q = dequantize_row_q4_2,
- .quantize_row_q = quantize_row_q4_2,
- .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_2_rmse, //quantize_row_q4_2_reference,
- .quantize_row_q_dot = quantize_row_q8_0,
- .vec_dot_q = ggml_vec_dot_q4_2_q8_0,
- },
- [GGML_TYPE_Q4_3] = {
- .dequantize_row_q = dequantize_row_q4_3,
- .quantize_row_q = quantize_row_q4_3,
- .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_3_reference, // TODO: RMSE optimization
- .quantize_row_q_dot = quantize_row_q8_0,
- .vec_dot_q = ggml_vec_dot_q4_3_q8_0,
- },
- [GGML_TYPE_Q8_0] = {
- .dequantize_row_q = NULL, // TODO
- .quantize_row_q = quantize_row_q8_0,
- .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q8_0_reference,
- .quantize_row_q_dot = quantize_row_q8_0,
- .vec_dot_q = NULL, // TODO
- },
- };
+ static void dequantize_row_q5_0(const void * restrict vx, float * restrict y, int k) {
+ assert(k % QK5_0 == 0);
+ const int nb = k / QK5_0;

- // For internal test use
- quantize_fns_t ggml_internal_get_quantize_fn(size_t i) {
- GGML_ASSERT(i < GGML_TYPE_COUNT);
- return quantize_fns[i];
- }
+ const block_q5_0 * restrict x = vx;
+
+ for (int i = 0; i < nb; i++) {
+ const float d = GGML_FP16_TO_FP32(x[i].d);
+
+ const uint8_t * restrict pp = x[i].qs;
+
+ uint32_t qh;
+ memcpy(&qh, x[i].qh, sizeof(qh));
+
+ for (int l = 0; l < QK5_0; l += 2) {
+ const uint8_t vi = pp[l/2];
+
+ // extract the 5-th bit from qh
+ const uint8_t vh0 = ((qh & (1 << (l + 0))) >> (l + 0)) << 4;
+ const uint8_t vh1 = ((qh & (1 << (l + 1))) >> (l + 1)) << 4;
+
+ const int8_t vi0 = (vi & 0x0F) | vh0;
+ const int8_t vi1 = (vi >> 4) | vh1;
+
+ const float v0 = (vi0 - 16)*d;
+ const float v1 = (vi1 - 16)*d;
+
+ y[i*QK5_0 + l + 0] = v0;
+ y[i*QK5_0 + l + 1] = v1;
+
+ assert(!isnan(y[i*QK5_0 + l + 0]));
+ assert(!isnan(y[i*QK5_0 + l + 1]));
+ }
+ }
+ }
+
+ static void dequantize_row_q5_1(const void * restrict vx, float * restrict y, int k) {
+ assert(k % QK5_1 == 0);
+ const int nb = k / QK5_1;
+
+ const block_q5_1 * restrict x = vx;
+
+ for (int i = 0; i < nb; i++) {
+ const float d = GGML_FP16_TO_FP32(x[i].d);
+ const float m = GGML_FP16_TO_FP32(x[i].m);
+
+ const uint8_t * restrict pp = x[i].qs;
+
+ uint32_t qh;
+ memcpy(&qh, x[i].qh, sizeof(qh));
+
+ for (int l = 0; l < QK5_1; l += 2) {
+ const uint8_t vi = pp[l/2];
+
+ // extract the 5-th bit from qh
+ const uint8_t vh0 = ((qh & (1 << (l + 0))) >> (l + 0)) << 4;
+ const uint8_t vh1 = ((qh & (1 << (l + 1))) >> (l + 1)) << 4;
+
+ const uint8_t vi0 = (vi & 0x0F) | vh0;
+ const uint8_t vi1 = (vi >> 4) | vh1;
+
+ const float v0 = vi0*d + m;
+ const float v1 = vi1*d + m;
+
+ y[i*QK5_1 + l + 0] = v0;
+ y[i*QK5_1 + l + 1] = v1;
+
+ assert(!isnan(y[i*QK5_1 + l + 0]));
+ assert(!isnan(y[i*QK5_1 + l + 1]));
+ }
+ }
+ }
+
+ static void dequantize_row_q8_0(const void * restrict vx, float * restrict y, int k) {
+ assert(k % QK8_0 == 0);
+ const int nb = k / QK8_0;
+
+ const block_q8_0 * restrict x = vx;
+
+ for (int i = 0; i < nb; i++) {
+ const float d = x[i].d;
+
+ const int8_t * restrict pp = x[i].qs;
+
+ for (int l = 0; l < QK8_0; ++l) {
+ y[i*QK8_0 + l] = pp[l]*d;
+ }
+ }
+ }
+
+ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
+ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
+ static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
+ static void ggml_vec_dot_q4_3_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
+ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
+ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
+ static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
+
+ static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
+ [GGML_TYPE_Q4_0] = {
+ .dequantize_row_q = dequantize_row_q4_0,
+ .quantize_row_q = quantize_row_q4_0,
+ .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_0_reference,
+ .quantize_row_q_dot = quantize_row_q8_0,
+ .vec_dot_q = ggml_vec_dot_q4_0_q8_0,
+ .vec_dot_type = GGML_TYPE_Q8_0,
+ },
+ [GGML_TYPE_Q4_1] = {
+ .dequantize_row_q = dequantize_row_q4_1,
+ .quantize_row_q = quantize_row_q4_1,
+ .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_1_reference,
+ .quantize_row_q_dot = quantize_row_q8_1,
+ .vec_dot_q = ggml_vec_dot_q4_1_q8_1,
+ .vec_dot_type = GGML_TYPE_Q8_1,
+ },
+ [GGML_TYPE_Q4_2] = {
+ .dequantize_row_q = dequantize_row_q4_2,
+ .quantize_row_q = quantize_row_q4_2,
+ .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_2_reference,
+ .quantize_row_q_dot = quantize_row_q8_0,
+ .vec_dot_q = ggml_vec_dot_q4_2_q8_0,
+ .vec_dot_type = GGML_TYPE_Q8_0,
+ },
+ [GGML_TYPE_Q4_3] = {
+ .dequantize_row_q = dequantize_row_q4_3,
+ .quantize_row_q = quantize_row_q4_3,
+ .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_3_reference,
+ .quantize_row_q_dot = quantize_row_q8_1,
+ .vec_dot_q = ggml_vec_dot_q4_3_q8_1,
+ .vec_dot_type = GGML_TYPE_Q8_1,
+ },
+ [GGML_TYPE_Q5_0] = {
+ .dequantize_row_q = dequantize_row_q5_0,
+ .quantize_row_q = quantize_row_q5_0,
+ .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q5_0_reference,
+ .quantize_row_q_dot = quantize_row_q8_0,
+ .vec_dot_q = ggml_vec_dot_q5_0_q8_0,
+ .vec_dot_type = GGML_TYPE_Q8_0,
+ },
+ [GGML_TYPE_Q5_1] = {
+ .dequantize_row_q = dequantize_row_q5_1,
+ .quantize_row_q = quantize_row_q5_1,
+ .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q5_1_reference,
+ .quantize_row_q_dot = quantize_row_q8_1,
+ .vec_dot_q = ggml_vec_dot_q5_1_q8_1,
+ .vec_dot_type = GGML_TYPE_Q8_1,
+ },
+ [GGML_TYPE_Q8_0] = {
+ .dequantize_row_q = dequantize_row_q8_0,
+ .quantize_row_q = quantize_row_q8_0,
+ .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q8_0_reference,
+ .quantize_row_q_dot = quantize_row_q8_0,
+ .vec_dot_q = ggml_vec_dot_q8_0_q8_0,
+ .vec_dot_type = GGML_TYPE_Q8_0,
+ },
+ [GGML_TYPE_Q8_1] = {
+ .dequantize_row_q = NULL, // TODO
+ .quantize_row_q = quantize_row_q8_1,
+ .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q8_1_reference,
+ .quantize_row_q_dot = quantize_row_q8_1,
+ .vec_dot_q = NULL, // TODO
+ .vec_dot_type = GGML_TYPE_Q8_1,
+ },
+ };
+
+ // For internal test use
+ quantize_fns_t ggml_internal_get_quantize_fn(size_t i) {
+ GGML_ASSERT(i < GGML_TYPE_COUNT);
+ return quantize_fns[i];
+ }


  //
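The new vec_dot_type field names the quantization the second operand of vec_dot_q must be in, and quantize_row_q_dot produces exactly that format. A sketch of the presumed call pattern (the buffer names are hypothetical, not from this diff):

    // Quantize an f32 activation row into the dot-side format, then dispatch.
    // Here GGML_TYPE_Q4_1 pairs with GGML_TYPE_Q8_1 (see vec_dot_type above).
    const quantize_fns_t fns = ggml_internal_get_quantize_fn(GGML_TYPE_Q4_1);
    fns.quantize_row_q_dot(src_row, wdata, n);   // src_row: float[n]; wdata sized for Q8_1 blocks
    float sum;
    fns.vec_dot_q(n, &sum, q4_1_row, wdata);     // q4_1_row: the quantized weight row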
@@ -2366,8 +2720,6 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
  const block_q4_0 * restrict x = vx;
  const block_q8_0 * restrict y = vy;

- float sumf = 0.0;
-
  #if defined(__ARM_NEON)
  float32x4_t sumv0 = vdupq_n_f32(0.0f);
  float32x4_t sumv1 = vdupq_n_f32(0.0f);
@@ -2378,7 +2730,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
  const block_q8_0 * restrict y0 = &y[i + 0];
  const block_q8_0 * restrict y1 = &y[i + 1];

- const uint8x16_t m4b = vdupq_n_u8(0xf);
+ const uint8x16_t m4b = vdupq_n_u8(0x0F);
  const int8x16_t s8b = vdupq_n_s8(0x8);

  const uint8x16_t v0_0 = vld1q_u8(x0->qs);
@@ -2436,7 +2788,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
  #endif
  }

- sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
+ *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
  #elif defined(__AVX2__)
  // Initialize accumulator with zeros
  __m256 acc = _mm256_setzero_ps();
@@ -2454,32 +2806,13 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *

  __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);

- // Get absolute values of x vectors
- const __m256i ax = _mm256_sign_epi8(bx, bx);
-
- // Sign the values of the y vectors
- const __m256i sy = _mm256_sign_epi8(by, bx);
-
- // Perform multiplication and create 16-bit values
- const __m256i dot = _mm256_maddubs_epi16(ax, sy);
-
- const __m256i ones = _mm256_set1_epi16(1);
- __m256i xy_q = _mm256_madd_epi16(ones, dot);
-
- /* Convert to vectore of 8 int32_t to 8 floats */
- __m256 q = _mm256_cvtepi32_ps( xy_q );
+ const __m256 q = mul_sum_i8_pairs_float(bx, by);

  /* Multiply q with scale and accumulate */
  acc = _mm256_fmadd_ps( d, q, acc );
  }

- // Return horizontal sum of the acc vector
- __m128 res = _mm256_extractf128_ps( acc, 1 );
- res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
- res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
- res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
-
- sumf = _mm_cvtss_f32( res );
+ *s = hsum_float_8(acc);
  #elif defined(__AVX__)
  // Initialize accumulator with zeros
  __m256 acc = _mm256_setzero_ps();
@@ -2518,15 +2851,10 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
  acc = _mm256_add_ps(_mm256_mul_ps( d, p ), acc);
  }

- // Return horizontal sum of the acc vector
- __m128 res = _mm256_extractf128_ps( acc, 1 );
- res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
- res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
- res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
-
- sumf = _mm_cvtss_f32( res );
+ *s = hsum_float_8(acc);
  #else
  // scalar
+ float sumf = 0.0;
  for (int i = 0; i < nb; i++) {
  const float d0 = x[i].d;
  const float d1 = y[i].d;
@@ -2538,8 +2866,8 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
  for (int j = 0; j < QK8_0/2; j++) {
  const uint8_t v0 = p0[j];

- const int i0 = (int8_t) (v0 & 0xf) - 8;
- const int i1 = (int8_t) (v0 >> 4) - 8;
+ const int i0 = (int8_t) (v0 & 0x0F) - 8;
+ const int i1 = (int8_t) (v0 >> 4)   - 8;

  const int i2 = p1[2*j + 0];
  const int i3 = p1[2*j + 1];
@@ -2548,34 +2876,35 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
  }
  sumf += d0*d1*sumi;
  }
- #endif
-
  *s = sumf;
+ #endif
  }

- static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
- const int nb = n / QK8_0;
+ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+ const int nb = n / QK8_1;

- assert(n % QK8_0 == 0);
+ assert(n % QK8_1 == 0);
  assert(nb % 2 == 0);

  const block_q4_1 * restrict x = vx;
- const block_q8_0 * restrict y = vy;
-
- float sumf = 0.0;
+ const block_q8_1 * restrict y = vy;

  // TODO: add AVX / WASM SIMD / etc
  #if defined(__ARM_NEON)
  float32x4_t sumv0 = vdupq_n_f32(0.0f);
  float32x4_t sumv1 = vdupq_n_f32(0.0f);

+ float summs = 0;
+
  for (int i = 0; i < nb; i += 2) {
  const block_q4_1 * restrict x0 = &x[i + 0];
  const block_q4_1 * restrict x1 = &x[i + 1];
- const block_q8_0 * restrict y0 = &y[i + 0];
- const block_q8_0 * restrict y1 = &y[i + 1];
+ const block_q8_1 * restrict y0 = &y[i + 0];
+ const block_q8_1 * restrict y1 = &y[i + 1];

- const uint8x16_t m4b = vdupq_n_u8(0xf);
+ summs += x0->m * (y0->s0 + y0->s1) + x1->m * (y1->s0 + y1->s1);
+
+ const uint8x16_t m4b = vdupq_n_u8(0x0F);

  const uint8x16_t v0_0 = vld1q_u8(x0->qs);
  const uint8x16_t v0_1 = vld1q_u8(x1->qs);
@@ -2586,46 +2915,35 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
  const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
  const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));

+ // interleave
+ const int8x16_t v0_0lz = vzip1q_s8(v0_0l, v0_0h);
+ const int8x16_t v0_0hz = vzip2q_s8(v0_0l, v0_0h);
+ const int8x16_t v0_1lz = vzip1q_s8(v0_1l, v0_1h);
+ const int8x16_t v0_1hz = vzip2q_s8(v0_1l, v0_1h);
+
  // load y
  const int8x16_t v1_0l = vld1q_s8(y0->qs);
  const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
  const int8x16_t v1_1l = vld1q_s8(y1->qs);
  const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);

- // interleave
- const int8x16_t v1_0ls = vuzp1q_s8(v1_0l, v1_0h);
- const int8x16_t v1_0hs = vuzp2q_s8(v1_0l, v1_0h);
- const int8x16_t v1_1ls = vuzp1q_s8(v1_1l, v1_1h);
- const int8x16_t v1_1hs = vuzp2q_s8(v1_1l, v1_1h);
-
- const int16x8_t s0i = vaddq_s16(
- vaddq_s16(vmovl_s8(vget_low_s8(v1_0ls)), vmovl_s8(vget_high_s8(v1_0ls))),
- vaddq_s16(vmovl_s8(vget_low_s8(v1_0hs)), vmovl_s8(vget_high_s8(v1_0hs))));
-
- const int16x8_t s1i = vaddq_s16(
- vaddq_s16(vmovl_s8(vget_low_s8(v1_1ls)), vmovl_s8(vget_high_s8(v1_1ls))),
- vaddq_s16(vmovl_s8(vget_low_s8(v1_1hs)), vmovl_s8(vget_high_s8(v1_1hs))));
-
- sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddl_s16(vget_low_s16(s0i), vget_high_s16(s0i))), x0->m*y0->d);
- sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddl_s16(vget_low_s16(s1i), vget_high_s16(s1i))), x1->m*y1->d);
-
  #if defined(__ARM_FEATURE_DOTPROD)
  // dot product into int32x4_t
- const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0ls), v0_0h, v1_0hs);
- const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1ls), v0_1h, v1_1hs);
+ const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0lz, v1_0l), v0_0hz, v1_0h);
+ const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1lz, v1_1l), v0_1hz, v1_1h);

  sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), x0->d*y0->d);
  sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), x1->d*y1->d);
  #else
- const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0l), vget_low_s8 (v1_0ls));
- const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0l), vget_high_s8(v1_0ls));
- const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0h), vget_low_s8 (v1_0hs));
- const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0h), vget_high_s8(v1_0hs));
+ const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lz), vget_low_s8 (v1_0l));
+ const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lz), vget_high_s8(v1_0l));
+ const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hz), vget_low_s8 (v1_0h));
+ const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hz), vget_high_s8(v1_0h));

- const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1l), vget_low_s8 (v1_1ls));
- const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1l), vget_high_s8(v1_1ls));
- const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1h), vget_low_s8 (v1_1hs));
- const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1h), vget_high_s8(v1_1hs));
+ const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lz), vget_low_s8 (v1_1l));
+ const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lz), vget_high_s8(v1_1l));
+ const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hz), vget_low_s8 (v1_1h));
+ const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hz), vget_high_s8(v1_1h));

  const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
  const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
@@ -2637,65 +2955,40 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
2637
2955
  #endif
2638
2956
  }
2639
2957
 
2640
- sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
2958
+ *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs;
2641
2959
  #elif defined(__AVX2__)
2642
2960
  // Initialize accumulator with zeros
2643
2961
  __m256 acc = _mm256_setzero_ps();
2644
2962
 
2963
+ float summs = 0;
2964
+
2645
2965
  // Main loop
2646
2966
  for (int i = 0; i < nb; ++i) {
2647
2967
  const float * d0 = &x[i].d;
2648
2968
  const float * d1 = &y[i].d;
2649
- const float * m0 = &x[i].m;
2969
+
2970
+ summs += x[i].m * (y[i].s0 + y[i].s1);
2650
2971
 
2651
2972
  const __m256 d0v = _mm256_broadcast_ss( d0 );
2652
2973
  const __m256 d1v = _mm256_broadcast_ss( d1 );
2653
- const __m256 m0v = _mm256_broadcast_ss( m0 );
2654
2974
 
2655
2975
  // Compute combined scales
2656
2976
  const __m256 d0d1 = _mm256_mul_ps( d0v, d1v );
2657
- const __m256 d1m0 = _mm256_mul_ps( d1v, m0v );
2658
2977
 
2659
2978
  // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
2660
2979
  const __m256i bx = bytes_from_nibbles_32(x[i].qs);
2661
2980
  const __m256i by = _mm256_loadu_si256( (const __m256i *)y[i].qs );
2662
2981
 
2663
- // Get absolute values of x vectors
2664
- const __m256i ax = _mm256_sign_epi8( bx, bx );
2665
-
2666
- // Sign the values of the y vectors
2667
- const __m256i sy = _mm256_sign_epi8( by, bx );
2668
-
2669
- // Perform multiplication and create 16-bit values
2670
- const __m256i dot = _mm256_maddubs_epi16( ax, sy );
2671
- const __m256i ones = _mm256_set1_epi16( 1 );
2672
- const __m256i xy_q = _mm256_madd_epi16( ones, dot );
2673
-
2674
- // Convert the vector of 8 int32_t to 8 floats
2675
- const __m256 xy = _mm256_cvtepi32_ps( xy_q );
2982
+ const __m256 xy = mul_sum_i8_pairs_float(bx, by);
2676
2983
 
2677
2984
  // Accumulate d0*d1*x*y
2678
2985
  acc = _mm256_fmadd_ps( d0d1, xy, acc );
2679
-
2680
- // Compute sum of y values
2681
- const __m256i y16_l = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( by ) );
2682
- const __m256i y16_h = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( by, 1 ) );
2683
- const __m256i ysumi = _mm256_madd_epi16( _mm256_add_epi16(y16_l, y16_h), ones );
2684
- const __m256 ysum = _mm256_cvtepi32_ps( ysumi );
2685
-
2686
- // Accumulate d1*m0*y
2687
- acc = _mm256_fmadd_ps( d1m0, ysum, acc );
2688
2986
  }
2689
2987
 
2690
- // Return horizontal sum of the acc vector
2691
- __m128 res = _mm256_extractf128_ps( acc, 1 );
2692
- res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
2693
- res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
2694
- res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
2695
-
2696
- sumf = _mm_cvtss_f32( res );
2988
+ *s = hsum_float_8(acc) + summs;
2697
2989
  #else
2698
2990
  // scalar
2991
+ float sumf = 0.0;
2699
2992
  for (int i = 0; i < nb; i++) {
2700
2993
  const float d0 = x[i].d;
2701
2994
  const float m0 = x[i].m;
@@ -2705,11 +2998,11 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
2705
2998
  const int8_t * restrict p1 = y[i].qs;
2706
2999
 
2707
3000
  // TODO: this is very slow ..
2708
- for (int j = 0; j < QK8_0/2; j++) {
3001
+ for (int j = 0; j < QK8_1/2; j++) {
2709
3002
  const uint8_t v0 = p0[j];
2710
3003
 
2711
- const float f0 = d0*(v0 & 0xf) + m0;
2712
- const float f1 = d0*(v0 >> 4) + m0;
3004
+ const float f0 = d0*(v0 & 0x0F) + m0;
3005
+ const float f1 = d0*(v0 >> 4) + m0;
2713
3006
 
2714
3007
  const float f2 = d1*p1[2*j + 0];
2715
3008
  const float f3 = d1*p1[2*j + 1];
@@ -2717,9 +3010,8 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
2717
3010
  sumf += f0*f2 + f1*f3;
2718
3011
  }
2719
3012
  }
2720
- #endif
2721
-
2722
3013
  *s = sumf;
3014
+ #endif
2723
3015
  }
2724
3016
 
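The q4_1 rewrite above rests on an algebraic split. A q4_1 weight dequantizes to d*q + m, so a block dot product factors into d*dy*sum(qx*qy) plus a term that depends only on y; q8_1 stores that y-only term per half-block as s0 and s1, already scaled by y's own d, which is why the scalar path adds m*(y[i].s0 + y[i].s1) with no extra factor. A minimal scalar sketch of the identity, using illustrative 8-element mini-blocks rather than the real 32-element ggml structs:

#include <stdint.h>
#include <stdio.h>

#define K 8  /* illustrative block size; ggml uses 32 */

/* hypothetical mini-blocks mirroring the roles of block_q4_1 / block_q8_1 */
typedef struct { float d, m; uint8_t q[K]; } blk_x;   /* value_i = d*q_i + m  */
typedef struct { float d, s; int8_t  q[K]; } blk_y;   /* s = d * sum_i(q_i)   */

int main(void) {
    blk_x x = { 0.5f, -1.0f, {1, 2, 3, 4, 5, 6, 7, 8} };
    blk_y y = { 0.25f, 0.0f, {-3, -2, -1, 0, 1, 2, 3, 4} };

    int sq = 0;
    for (int j = 0; j < K; j++) sq += y.q[j];
    y.s = y.d * sq;                       /* precomputed at quantization time */

    /* reference: dequantize both sides and dot */
    float ref = 0.0f;
    for (int j = 0; j < K; j++) ref += (x.d*x.q[j] + x.m) * (y.d*y.q[j]);

    /* factored form used by the kernel: d*dy*sum(qx*qy) + m*s */
    int sxy = 0;
    for (int j = 0; j < K; j++) sxy += x.q[j]*y.q[j];
    float fast = x.d*y.d*sxy + x.m*y.s;

    printf("%f == %f\n", ref, fast);      /* identical up to float rounding */
    return 0;
}

This split is what lets the SIMD loop drop the per-iteration y-sum reduction that the removed s0i/s1i and ysum code used to perform.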
2725
3017
  static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
@@ -2732,8 +3024,6 @@ static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void *
2732
3024
  const block_q4_2 * restrict x = vx;
2733
3025
  const block_q8_0 * restrict y = vy;
2734
3026
 
2735
- float sumf = 0.0;
2736
-
2737
3027
  #if defined(__ARM_NEON)
2738
3028
  float32x4_t sumv0 = vdupq_n_f32(0.0f);
2739
3029
  float32x4_t sumv1 = vdupq_n_f32(0.0f);
@@ -2747,7 +3037,7 @@ static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void *
2747
3037
  const block_q8_0 * restrict y0 = &y[i + 0];
2748
3038
  const block_q8_0 * restrict y1 = &y[i + 1];
2749
3039
 
2750
- const uint8x16_t m4b = vdupq_n_u8(0xf);
3040
+ const uint8x16_t m4b = vdupq_n_u8(0x0F);
2751
3041
  const int8x16_t s8b = vdupq_n_s8(0x8);
2752
3042
 
2753
3043
  const uint8x16_t v0_0 = vcombine_u8(vld1_u8(x0_0->qs), vld1_u8(x0_1->qs));
@@ -2782,270 +3072,604 @@ static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void *
2782
3072
  vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0lz, v1_0l)), GGML_FP16_TO_FP32(x0_0->d)),
2783
3073
  vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0hz, v1_0h)), GGML_FP16_TO_FP32(x0_1->d))), y0->d);
2784
3074
 
2785
- sumv1 = vmlaq_n_f32(sumv1, vaddq_f32(
2786
- vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_1lz, v1_1l)), GGML_FP16_TO_FP32(x1_0->d)),
2787
- vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_1hz, v1_1h)), GGML_FP16_TO_FP32(x1_1->d))), y1->d);
2788
- #else
2789
- const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lz), vget_low_s8 (v1_0l));
2790
- const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lz), vget_high_s8(v1_0l));
2791
- const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hz), vget_low_s8 (v1_0h));
2792
- const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hz), vget_high_s8(v1_0h));
3075
+ sumv1 = vmlaq_n_f32(sumv1, vaddq_f32(
3076
+ vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_1lz, v1_1l)), GGML_FP16_TO_FP32(x1_0->d)),
3077
+ vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_1hz, v1_1h)), GGML_FP16_TO_FP32(x1_1->d))), y1->d);
3078
+ #else
3079
+ const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lz), vget_low_s8 (v1_0l));
3080
+ const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lz), vget_high_s8(v1_0l));
3081
+ const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hz), vget_low_s8 (v1_0h));
3082
+ const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hz), vget_high_s8(v1_0h));
3083
+
3084
+ const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lz), vget_low_s8 (v1_1l));
3085
+ const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lz), vget_high_s8(v1_1l));
3086
+ const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hz), vget_low_s8 (v1_1h));
3087
+ const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hz), vget_high_s8(v1_1h));
3088
+
3089
+ const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
3090
+ const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
3091
+ const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
3092
+ const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
3093
+
3094
+ sumv0 = vmlaq_n_f32(sumv0, vaddq_f32(
3095
+ vmulq_n_f32(vcvtq_f32_s32(pl0), GGML_FP16_TO_FP32(x0_0->d)),
3096
+ vmulq_n_f32(vcvtq_f32_s32(ph0), GGML_FP16_TO_FP32(x0_1->d))), y0->d);
3097
+
3098
+ sumv1 = vmlaq_n_f32(sumv1, vaddq_f32(
3099
+ vmulq_n_f32(vcvtq_f32_s32(pl1), GGML_FP16_TO_FP32(x1_0->d)),
3100
+ vmulq_n_f32(vcvtq_f32_s32(ph1), GGML_FP16_TO_FP32(x1_1->d))), y1->d);
3101
+ #endif
3102
+ }
3103
+
3104
+ *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
3105
+ #elif defined(__AVX2__)
3106
+ // Initialize accumulator with zeros
3107
+ __m256 acc = _mm256_setzero_ps();
3108
+
3109
+ // Main loop
3110
+ for (int i = 0; i < nb; i++) {
3111
+ /* Compute combined scale for the block */
3112
+ const __m128 d0 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 0].d));
3113
+ const __m128 d1 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 1].d));
3114
+ const __m256 d = _mm256_mul_ps(_mm256_set_m128(d1, d0), _mm256_broadcast_ss(&y[i].d));
3115
+
3116
+ __m128i bx0 = bytes_from_nibbles_16(x[2*i + 0].qs);
3117
+ __m128i bx1 = bytes_from_nibbles_16(x[2*i + 1].qs);
3118
+ __m256i bx = _mm256_set_m128i(bx1, bx0);
3119
+
3120
+ // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
3121
+ const __m256i off = _mm256_set1_epi8(8);
3122
+ bx = _mm256_sub_epi8(bx, off);
3123
+
3124
+ __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
3125
+
3126
+ const __m256 q = mul_sum_i8_pairs_float(bx, by);
3127
+
3128
+ /* Multiply q with scale and accumulate */
3129
+ acc = _mm256_fmadd_ps(d, q, acc);
3130
+ }
3131
+
3132
+ *s = hsum_float_8(acc);
3133
+ #else
3134
+ // scalar
3135
+ float sumf = 0.0;
3136
+ for (int i = 0; i < nb; i++) {
3137
+ const uint8_t * restrict x0 = x[2*i + 0].qs;
3138
+ const uint8_t * restrict x1 = x[2*i + 1].qs;
3139
+ const int8_t * restrict y0 = y[i].qs;
3140
+
3141
+ const float d0 = GGML_FP16_TO_FP32(x[2*i + 0].d);
3142
+ const float d1 = GGML_FP16_TO_FP32(x[2*i + 1].d);
3143
+
3144
+ int sumi_0 = 0;
3145
+ int sumi_1 = 0;
3146
+
3147
+ for (int j = 0; j < QK8_0/4; j++) {
3148
+ const uint8_t v0 = x0[j];
3149
+ const uint8_t v1 = x1[j];
3150
+
3151
+ const int i0_0 = (int8_t) (v0 & 0x0F) - 8;
3152
+ const int i1_0 = (int8_t) (v0 >> 4) - 8;
3153
+
3154
+ const int i0_1 = (int8_t) (v1 & 0x0F) - 8;
3155
+ const int i1_1 = (int8_t) (v1 >> 4) - 8;
3156
+
3157
+ const int i2_0 = y0[2*j + 0];
3158
+ const int i3_0 = y0[2*j + 1];
3159
+
3160
+ const int i2_1 = y0[2*(j + QK8_0/4) + 0];
3161
+ const int i3_1 = y0[2*(j + QK8_0/4) + 1];
3162
+
3163
+ sumi_0 += i0_0*i2_0 + i1_0*i3_0;
3164
+ sumi_1 += i0_1*i2_1 + i1_1*i3_1;
3165
+ }
3166
+
3167
+ sumf += (d0 * y[i].d) * sumi_0;
3168
+ sumf += (d1 * y[i].d) * sumi_1;
3169
+ }
3170
+ *s = sumf;
3171
+ #endif
3172
+ }
3173
+
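Both new q4_2 paths expand packed nibbles the same way: mask the low four bits, shift for the high four, then recenter by 8 so values land in [-8, +7]. The AVX2 branch does the recentering with a vector subtract of `off`; the scalar branch writes it as (int8_t)(v & 0x0F) - 8. A standalone sketch of that expansion:

#include <stdint.h>
#include <assert.h>

/* Scalar model of the nibble expansion used above: each packed byte holds
 * two 4-bit quants, and both the SIMD and scalar paths recenter them by 8. */
static void unpack_q4(uint8_t packed, int8_t *lo, int8_t *hi) {
    *lo = (int8_t)(packed & 0x0F) - 8;   /* low nibble  -> [-8, 7] */
    *hi = (int8_t)(packed >> 4)   - 8;   /* high nibble -> [-8, 7] */
}

int main(void) {
    for (int v = 0; v < 256; v++) {
        int8_t lo, hi;
        unpack_q4((uint8_t)v, &lo, &hi);
        assert(-8 <= lo && lo <= 7);
        assert(-8 <= hi && hi <= 7);
    }
    return 0;
}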
3174
+ static void ggml_vec_dot_q4_3_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
3175
+ const int nb = n / QK8_1;
3176
+
3177
+ assert(n % QK8_1 == 0);
3178
+ assert(nb % 2 == 0);
3179
+ assert(QK8_1 == 2*QK4_3);
3180
+
3181
+ const block_q4_3 * restrict x = vx;
3182
+ const block_q8_1 * restrict y = vy;
3183
+
3184
+ #if defined(__ARM_NEON)
3185
+ float32x4_t sumv0 = vdupq_n_f32(0.0f);
3186
+ float32x4_t sumv1 = vdupq_n_f32(0.0f);
3187
+
3188
+ float summs0 = 0.0f;
3189
+ float summs1 = 0.0f;
3190
+
3191
+ for (int i = 0; i < nb; ++i) {
3192
+ const block_q4_3 * restrict x0_0 = &x[2*(i + 0) + 0];
3193
+ const block_q4_3 * restrict x0_1 = &x[2*(i + 0) + 1];
3194
+
3195
+ const block_q8_1 * restrict y0 = &y[i + 0];
3196
+
3197
+ summs0 += GGML_FP16_TO_FP32(x0_0->m) * y0->s0;
3198
+ summs1 += GGML_FP16_TO_FP32(x0_1->m) * y0->s1;
3199
+
3200
+ const uint8x16_t v0_0 = vcombine_u8(vld1_u8(x0_0->qs), vld1_u8(x0_1->qs));
3201
+
3202
+ // 4-bit -> 8-bit
3203
+ const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, vdupq_n_u8(0x0F)));
3204
+ const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
3205
+
3206
+ // interleave
3207
+ const int8x16_t v0_0lz = vzip1q_s8(v0_0l, v0_0h);
3208
+ const int8x16_t v0_0hz = vzip2q_s8(v0_0l, v0_0h);
3209
+
3210
+ // load y
3211
+ const int8x16_t v1_0l = vld1q_s8(y0->qs);
3212
+ const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
3213
+
3214
+ const float x0_0d = GGML_FP16_TO_FP32(x0_0->d);
3215
+ const float x0_1d = GGML_FP16_TO_FP32(x0_1->d);
3216
+
3217
+ #if defined(__ARM_FEATURE_DOTPROD)
3218
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0lz, v1_0l)), x0_0d*y0->d);
3219
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0hz, v1_0h)), x0_1d*y0->d);
3220
+ #else
3221
+ const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lz), vget_low_s8 (v1_0l));
3222
+ const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lz), vget_high_s8(v1_0l));
3223
+ const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hz), vget_low_s8 (v1_0h));
3224
+ const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hz), vget_high_s8(v1_0h));
3225
+
3226
+ const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
3227
+ const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
3228
+
3229
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(pl0), x0_0d*y0->d);
3230
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(ph0), x0_1d*y0->d);
3231
+ #endif
3232
+ }
3233
+
3234
+ *s = vaddvq_f32(vaddq_f32(sumv0, sumv1)) + summs0 + summs1;
3235
+ #elif defined(__AVX2__)
3236
+ // Initialize accumulator with zeros
3237
+ __m256 acc = _mm256_setzero_ps();
3238
+ float summs = 0.0f;
3239
+
3240
+ // Main loop
3241
+ for (int i = 0; i < nb; i++) {
3242
+ const __m128 d0 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 0].d));
3243
+ const __m128 d1 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 1].d));
3244
+ const __m256 dx = _mm256_set_m128(d1, d0);
3245
+
3246
+ summs += GGML_FP16_TO_FP32(x[2*i + 0].m) * y[i].s0
3247
+ + GGML_FP16_TO_FP32(x[2*i + 1].m) * y[i].s1;
3248
+
3249
+ const __m128i bx0 = bytes_from_nibbles_16(x[2*i + 0].qs);
3250
+ const __m128i bx1 = bytes_from_nibbles_16(x[2*i + 1].qs);
3251
+ const __m256i bx = _mm256_set_m128i(bx1, bx0);
3252
+
3253
+ const __m256 dy = _mm256_broadcast_ss(&y[i].d);
3254
+ const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
3255
+
3256
+ const __m256 q = mul_sum_i8_pairs_float(bx, by);
3257
+
3258
+ acc = _mm256_fmadd_ps(q, _mm256_mul_ps(dx, dy), acc);
3259
+ }
3260
+
3261
+ *s = hsum_float_8(acc) + summs;
3262
+ #else
3263
+ // scalar
3264
+ float sumf = 0.0;
3265
+ for (int i = 0; i < nb; i++) {
3266
+ const uint8_t * restrict x0 = x[2*i + 0].qs;
3267
+ const uint8_t * restrict x1 = x[2*i + 1].qs;
3268
+ const int8_t * restrict y0 = y[i].qs;
3269
+
3270
+ const float d0 = GGML_FP16_TO_FP32(x[2*i + 0].d);
3271
+ const float m0 = GGML_FP16_TO_FP32(x[2*i + 0].m);
3272
+ const float d1 = GGML_FP16_TO_FP32(x[2*i + 1].d);
3273
+ const float m1 = GGML_FP16_TO_FP32(x[2*i + 1].m);
3274
+
3275
+ int sxy_0 = 0;
3276
+ int sxy_1 = 0;
3277
+
3278
+ for (int j = 0; j < QK8_1/4; j++) {
3279
+ const uint8_t v0 = x0[j];
3280
+ const uint8_t v1 = x1[j];
3281
+
3282
+ const int x0_0 = v0 & 0x0F;
3283
+ const int x1_0 = v0 >> 4;
3284
+
3285
+ const int x0_1 = v1 & 0x0F;
3286
+ const int x1_1 = v1 >> 4;
3287
+
3288
+ const int y0_0 = y0[2*j + 0];
3289
+ const int y1_0 = y0[2*j + 1];
3290
+
3291
+ const int y0_1 = y0[2*(j + QK8_1/4) + 0];
3292
+ const int y1_1 = y0[2*(j + QK8_1/4) + 1];
3293
+
3294
+ sxy_0 += x0_0*y0_0 + x1_0*y1_0;
3295
+ sxy_1 += x0_1*y0_1 + x1_1*y1_1;
3296
+ }
3297
+
3298
+ sumf += (d0*sxy_0 + d1*sxy_1)*y[i].d + m0*y[i].s0 + m1*y[i].s1;
3299
+ }
3300
+ *s = sumf;
3301
+ #endif
3302
+ }
3303
+
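ggml_vec_dot_q4_3_q8_1 pairs two q4_3 half-blocks, each with its own fp16 scale d and offset m, against a single q8_1 block, so the q4_1 accumulation identity applies per half: sumf += (d0*sxy_0 + d1*sxy_1)*y.d + m0*y.s0 + m1*y.s1. A scalar sketch with illustrative 4-element halves, plain floats standing in for ggml_fp16_t:

#include <stdio.h>

#define H 4  /* illustrative half-block length; ggml uses 16 */

int main(void) {
    /* two hypothetical x half-blocks: value_i = d*q_i + m */
    float d0 = 0.5f, m0 = 1.0f, d1 = 0.25f, m1 = -2.0f;
    int   qx0[H] = {1, 5, 9, 13}, qx1[H] = {0, 2, 4, 6};

    /* one y block: value_i = dy*q_i; s0/s1 = dy * per-half quant sums */
    float dy = 0.125f;
    int   qy0[H] = {-4, 3, -2, 1}, qy1[H] = {7, -6, 5, -1};

    int sxy0 = 0, sxy1 = 0, sy0 = 0, sy1 = 0;
    for (int j = 0; j < H; j++) {
        sxy0 += qx0[j]*qy0[j];  sy0 += qy0[j];
        sxy1 += qx1[j]*qy1[j];  sy1 += qy1[j];
    }
    float s0 = dy*sy0, s1 = dy*sy1;

    /* reference: dequantize and dot */
    float ref = 0.0f;
    for (int j = 0; j < H; j++) {
        ref += (d0*qx0[j] + m0)*(dy*qy0[j]);
        ref += (d1*qx1[j] + m1)*(dy*qy1[j]);
    }

    /* factored form used by the kernel */
    float fast = (d0*sxy0 + d1*sxy1)*dy + m0*s0 + m1*s1;
    printf("%f == %f\n", ref, fast);
    return 0;
}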
3304
+ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
3305
+ const int nb = n / QK8_0;
3306
+
3307
+ assert(n % QK8_0 == 0);
3308
+ assert(nb % 2 == 0);
3309
+ assert(QK8_0 == QK5_0);
3310
+
3311
+ const block_q5_0 * restrict x = vx;
3312
+ const block_q8_0 * restrict y = vy;
3313
+
3314
+ #if defined(__ARM_NEON)
3315
+ float32x4_t sumv = vdupq_n_f32(0.0f);
3316
+
3317
+ uint64_t tmp[4];
3318
+
3319
+ for (int i = 0; i < nb; ++i) {
3320
+ const block_q5_0 * restrict x0 = &x[i];
3321
+ const block_q8_0 * restrict y0 = &y[i];
3322
+
3323
+ const uint8x16_t m4b = vdupq_n_u8(0x0F);
3324
+ const int8x16_t s16b = vdupq_n_s8(0x10);
3325
+
3326
+ // extract the 5th bit
3327
+ uint32_t qh;
3328
+ memcpy(&qh, x0->qh, sizeof(qh));
3329
+
3330
+ tmp[0] = table_b2b_u[(qh >> 0) & 0xFF];
3331
+ tmp[1] = table_b2b_u[(qh >> 8) & 0xFF];
3332
+ tmp[2] = table_b2b_u[(qh >> 16) & 0xFF];
3333
+ tmp[3] = table_b2b_u[(qh >> 24) ];
3334
+
3335
+ const int8x16_t qhl = vld1q_s8((const int8_t *)(tmp + 0));
3336
+ const int8x16_t qhh = vld1q_s8((const int8_t *)(tmp + 2));
3337
+
3338
+ const uint8x16_t v0 = vld1q_u8(x0->qs);
3339
+
3340
+ // 4-bit -> 8-bit
3341
+ const int8x16_t v0l = vreinterpretq_s8_u8(vandq_u8 (v0, m4b));
3342
+ const int8x16_t v0h = vreinterpretq_s8_u8(vshrq_n_u8(v0, 4));
3343
+
3344
+ // interleave
3345
+ const int8x16_t v0lz = vzip1q_s8(v0l, v0h);
3346
+ const int8x16_t v0hz = vzip2q_s8(v0l, v0h);
3347
+
3348
+ // add high bit and sub 16
3349
+ const int8x16_t v0lf = vsubq_s8(vorrq_s8(v0lz, qhl), s16b);
3350
+ const int8x16_t v0hf = vsubq_s8(vorrq_s8(v0hz, qhh), s16b);
3351
+
3352
+ // load y
3353
+ const int8x16_t v1l = vld1q_s8(y0->qs);
3354
+ const int8x16_t v1h = vld1q_s8(y0->qs + 16);
3355
+
3356
+ const float x0d = GGML_FP16_TO_FP32(x0->d);
3357
+
3358
+ #if defined(__ARM_FEATURE_DOTPROD)
3359
+ sumv = vmlaq_n_f32(sumv, vcvtq_f32_s32(vaddq_s32(
3360
+ vdotq_s32(vdupq_n_s32(0), v0lf, v1l),
3361
+ vdotq_s32(vdupq_n_s32(0), v0hf, v1h))), x0d*y0->d);
3362
+ #else
3363
+ const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0lf), vget_low_s8 (v1l));
3364
+ const int16x8_t pl0h = vmull_s8(vget_high_s8(v0lf), vget_high_s8(v1l));
3365
+ const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0hf), vget_low_s8 (v1h));
3366
+ const int16x8_t ph0h = vmull_s8(vget_high_s8(v0hf), vget_high_s8(v1h));
3367
+
3368
+ const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
3369
+ const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
3370
+
3371
+ sumv = vmlaq_n_f32(sumv, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), x0d*y0->d);
3372
+ #endif
3373
+ }
3374
+
3375
+ *s = vaddvq_f32(sumv);
3376
+ #elif defined(__AVX2__)
3377
+ // Initialize accumulator with zeros
3378
+ __m256 acc = _mm256_setzero_ps();
3379
+
3380
+ // Main loop
3381
+ for (int i = 0; i < nb; i++) {
3382
+ /* Compute combined scale for the block */
3383
+ const __m256 d = _mm256_mul_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d)), _mm256_broadcast_ss(&y[i].d));
3384
+
3385
+ __m256i bx = bytes_from_nibbles_32(x[i].qs);
3386
+ __m256i bxhi = bytes_from_bits_32(x[i].qh);
3387
+ bxhi = _mm256_andnot_si256(bxhi, _mm256_set1_epi8((char)0xF0));
3388
+ bx = _mm256_or_si256(bx, bxhi);
3389
+
3390
+ __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
3391
+
3392
+ const __m256 q = mul_sum_i8_pairs_float(bx, by);
3393
+
3394
+ /* Multiply q with scale and accumulate */
3395
+ acc = _mm256_fmadd_ps(d, q, acc);
3396
+ }
3397
+
3398
+ *s = hsum_float_8(acc);
3399
+ #else
3400
+ // scalar
3401
+ float sumf = 0.0;
3402
+ for (int i = 0; i < nb; i++) {
3403
+ const uint8_t * restrict x0 = x[i].qs;
3404
+ const int8_t * restrict y0 = y[i].qs;
3405
+
3406
+ uint32_t qh;
3407
+ memcpy(&qh, x[i].qh, sizeof(qh));
3408
+
3409
+ const float d = GGML_FP16_TO_FP32(x[i].d);
3410
+
3411
+ int sxy = 0;
3412
+
3413
+ for (int j = 0; j < QK8_0/2; j++) {
3414
+ const uint8_t v0 = x0[j];
3415
+
3416
+ const int x0_0h = ((qh & (1 << (2*j + 0))) >> (2*j + 0)) << 4;
3417
+ const int x1_0h = ((qh & (1 << (2*j + 1))) >> (2*j + 1)) << 4;
3418
+
3419
+ const int x0_0 = ((v0 & 0x0F) | x0_0h) - 16;
3420
+ const int x1_0 = ((v0 >> 4) | x1_0h) - 16;
3421
+
3422
+ const int y0_0 = y0[2*j + 0];
3423
+ const int y1_0 = y0[2*j + 1];
3424
+
3425
+ sxy += x0_0*y0_0 + x1_0*y1_0;
3426
+ }
3427
+
3428
+ sumf += (d*sxy)*y[i].d;
3429
+ }
3430
+ *s = sumf;
3431
+ #endif
3432
+ }
3433
+
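q5_0 keeps the fifth bit of each quant in the 32-bit qh field. The NEON path expands eight of those bits at a time through table_b2b_u into bytes that are either 0x00 or 0x10, the bit already shifted into nibble position, then ORs them onto the unpacked low nibbles and subtracts 16; the scalar loop writes the same thing per element as ((qh >> j) & 1) << 4. A runtime stand-in for the table, assuming it maps bit k of its index to byte k of the 64-bit result, which is the ordering the kernel's lane layout requires (b2b() below is an illustrative re-derivation, not the macro table itself):

#include <stdint.h>
#include <assert.h>

/* Expand 8 bits into 8 bytes, each byte 0x00 or 0x10 (bit pre-shifted by 4),
 * with bit k of the input landing in byte k of the result. */
static uint64_t b2b(uint8_t b) {
    uint64_t r = 0;
    for (int k = 0; k < 8; k++)
        r |= (uint64_t)(((b >> k) & 1) << 4) << (8*k);
    return r;
}

int main(void) {
    uint32_t qh = 0xA5C33C5Au;          /* arbitrary example high bits */
    uint8_t  hi[32];
    for (int g = 0; g < 4; g++) {       /* 4 groups of 8 bits, like tmp[0..3] */
        uint64_t t = b2b((qh >> (8*g)) & 0xFF);
        for (int k = 0; k < 8; k++) hi[8*g + k] = (t >> (8*k)) & 0xFF;
    }
    /* must agree with the scalar form used in the fallback loop */
    for (int j = 0; j < 32; j++)
        assert(hi[j] == (((qh >> j) & 1) << 4));
    return 0;
}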
3434
+ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
3435
+ const int nb = n / QK8_1;
3436
+
3437
+ assert(n % QK8_1 == 0);
3438
+ assert(nb % 2 == 0);
3439
+ assert(QK8_1 == QK5_1);
3440
+
3441
+ const block_q5_1 * restrict x = vx;
3442
+ const block_q8_1 * restrict y = vy;
3443
+
3444
+ #if defined(__ARM_NEON)
3445
+ float32x4_t sumv = vdupq_n_f32(0.0f);
3446
+
3447
+ float summs = 0.0f;
3448
+
3449
+ uint64_t tmp[4];
3450
+
3451
+ for (int i = 0; i < nb; ++i) {
3452
+ const block_q5_1 * restrict x0 = &x[i];
3453
+ const block_q8_1 * restrict y0 = &y[i];
3454
+
3455
+ summs += GGML_FP16_TO_FP32(x0->m) * (y0->s0 + y0->s1);
3456
+
3457
+ // extract the 5th bit
3458
+ uint32_t qh;
3459
+ memcpy(&qh, x0->qh, sizeof(qh));
3460
+
3461
+ tmp[0] = table_b2b_u[(qh >> 0) & 0xFF];
3462
+ tmp[1] = table_b2b_u[(qh >> 8) & 0xFF];
3463
+ tmp[2] = table_b2b_u[(qh >> 16) & 0xFF];
3464
+ tmp[3] = table_b2b_u[(qh >> 24) ];
3465
+
3466
+ const int8x16_t qhl = vld1q_s8((const int8_t *)(tmp + 0));
3467
+ const int8x16_t qhh = vld1q_s8((const int8_t *)(tmp + 2));
3468
+
3469
+ const uint8x16_t v0 = vld1q_u8(x0->qs);
3470
+
3471
+ // 4-bit -> 8-bit
3472
+ const int8x16_t v0l = vreinterpretq_s8_u8(vandq_u8 (v0, vdupq_n_u8(0x0F)));
3473
+ const int8x16_t v0h = vreinterpretq_s8_u8(vshrq_n_u8(v0, 4));
3474
+
3475
+ // interleave
3476
+ const int8x16_t v0lz = vzip1q_s8(v0l, v0h);
3477
+ const int8x16_t v0hz = vzip2q_s8(v0l, v0h);
3478
+
3479
+ // add high bit
3480
+ const int8x16_t v0lf = vorrq_s8(v0lz, qhl);
3481
+ const int8x16_t v0hf = vorrq_s8(v0hz, qhh);
3482
+
3483
+ // load y
3484
+ const int8x16_t v1l = vld1q_s8(y0->qs);
3485
+ const int8x16_t v1h = vld1q_s8(y0->qs + 16);
3486
+
3487
+ const float x0d = GGML_FP16_TO_FP32(x0->d);
2793
3488
 
2794
- const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lz), vget_low_s8 (v1_1l));
2795
- const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lz), vget_high_s8(v1_1l));
2796
- const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hz), vget_low_s8 (v1_1h));
2797
- const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hz), vget_high_s8(v1_1h));
3489
+ #if defined(__ARM_FEATURE_DOTPROD)
3490
+ sumv = vmlaq_n_f32(sumv, vcvtq_f32_s32(vaddq_s32(
3491
+ vdotq_s32(vdupq_n_s32(0), v0lf, v1l),
3492
+ vdotq_s32(vdupq_n_s32(0), v0hf, v1h))), x0d*y0->d);
3493
+ #else
3494
+ const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0lf), vget_low_s8 (v1l));
3495
+ const int16x8_t pl0h = vmull_s8(vget_high_s8(v0lf), vget_high_s8(v1l));
3496
+ const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0hf), vget_low_s8 (v1h));
3497
+ const int16x8_t ph0h = vmull_s8(vget_high_s8(v0hf), vget_high_s8(v1h));
2798
3498
 
2799
3499
  const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
2800
3500
  const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
2801
- const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
2802
- const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
2803
3501
 
2804
- sumv0 = vmlaq_n_f32(sumv0, vaddq_f32(
2805
- vmulq_n_f32(vcvtq_f32_s32(pl0), GGML_FP16_TO_FP32(x0_0->d)),
2806
- vmulq_n_f32(vcvtq_f32_s32(ph0), GGML_FP16_TO_FP32(x0_1->d))), y0->d);
2807
-
2808
- sumv1 = vmlaq_n_f32(sumv1, vaddq_f32(
2809
- vmulq_n_f32(vcvtq_f32_s32(pl1), GGML_FP16_TO_FP32(x1_0->d)),
2810
- vmulq_n_f32(vcvtq_f32_s32(ph1), GGML_FP16_TO_FP32(x1_1->d))), y1->d);
3502
+ sumv = vmlaq_n_f32(sumv, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), x0d*y0->d);
2811
3503
  #endif
2812
3504
  }
2813
3505
 
2814
- sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
3506
+ *s = vaddvq_f32(sumv) + summs;
2815
3507
  #elif defined(__AVX2__)
2816
3508
  // Initialize accumulator with zeros
2817
3509
  __m256 acc = _mm256_setzero_ps();
3510
+ float summs = 0.0f;
2818
3511
 
2819
3512
  // Main loop
2820
3513
  for (int i = 0; i < nb; i++) {
2821
- /* Compute combined scale for the block */
2822
- const __m128 d0 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 0].d));
2823
- const __m128 d1 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 1].d));
2824
- const __m256 d = _mm256_mul_ps(_mm256_set_m128(d1, d0), _mm256_broadcast_ss(&y[i].d));
2825
-
2826
- __m128i bx0 = bytes_from_nibbles_16(x[2*i + 0].qs);
2827
- __m128i bx1 = bytes_from_nibbles_16(x[2*i + 1].qs);
2828
- __m256i bx = _mm256_set_m128i(bx1, bx0);
2829
-
2830
- // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
2831
- const __m256i off = _mm256_set1_epi8(8);
2832
- bx = _mm256_sub_epi8(bx, off);
3514
+ const __m256 dx = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d));
2833
3515
 
2834
- __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
3516
+ summs += GGML_FP16_TO_FP32(x[i].m) * (y[i].s0 + y[i].s1);
2835
3517
 
2836
- // Get absolute values of x vectors
2837
- const __m256i ax = _mm256_sign_epi8(bx, bx);
2838
- // Sign the values of the y vectors
2839
- const __m256i sy = _mm256_sign_epi8(by, bx);
2840
- // Perform multiplication and create 16-bit values
2841
- const __m256i dot = _mm256_maddubs_epi16(ax, sy);
3518
+ __m256i bx = bytes_from_nibbles_32(x[i].qs);
3519
+ __m256i bxhi = bytes_from_bits_32(x[i].qh);
3520
+ bxhi = _mm256_and_si256(bxhi, _mm256_set1_epi8(0x10));
3521
+ bx = _mm256_or_si256(bx, bxhi);
2842
3522
 
2843
- const __m256i ones = _mm256_set1_epi16(1);
2844
- __m256i xy_q = _mm256_madd_epi16(ones, dot);
3523
+ const __m256 dy = _mm256_broadcast_ss(&y[i].d);
3524
+ const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
2845
3525
 
2846
- /* Convert the vector of 8 int32_t to 8 floats */
2847
- __m256 q = _mm256_cvtepi32_ps(xy_q);
3526
+ const __m256 q = mul_sum_i8_pairs_float(bx, by);
2848
3527
 
2849
- /* Multiply q with scale and accumulate */
2850
- acc = _mm256_fmadd_ps(d, q, acc);
3528
+ acc = _mm256_fmadd_ps(q, _mm256_mul_ps(dx, dy), acc);
2851
3529
  }
2852
3530
 
2853
- // Return horizontal sum of the acc vector
2854
- __m128 res = _mm256_extractf128_ps(acc, 1);
2855
- res = _mm_add_ps(res, _mm256_castps256_ps128(acc));
2856
- res = _mm_add_ps(res, _mm_movehl_ps(res, res));
2857
- res = _mm_add_ss(res, _mm_movehdup_ps(res));
2858
-
2859
- sumf = _mm_cvtss_f32(res);
3531
+ *s = hsum_float_8(acc) + summs;
2860
3532
  #else
2861
- // scalar
3533
+ float sumf = 0.0;
3534
+
2862
3535
  for (int i = 0; i < nb; i++) {
2863
- const uint8_t * restrict x0 = x[2*i + 0].qs;
2864
- const uint8_t * restrict x1 = x[2*i + 1].qs;
3536
+ const uint8_t * restrict x0 = x[i].qs;
2865
3537
  const int8_t * restrict y0 = y[i].qs;
2866
3538
 
2867
- const float d0 = GGML_FP16_TO_FP32(x[2*i + 0].d);
2868
- const float d1 = GGML_FP16_TO_FP32(x[2*i + 1].d);
3539
+ uint32_t qh;
3540
+ memcpy(&qh, x[i].qh, sizeof(qh));
2869
3541
 
2870
- int sumi_0 = 0;
2871
- int sumi_1 = 0;
3542
+ const float d = GGML_FP16_TO_FP32(x[i].d);
3543
+ const float m = GGML_FP16_TO_FP32(x[i].m);
2872
3544
 
2873
- for (int j = 0; j < QK8_0/4; j++) {
2874
- const uint8_t v0 = x0[j];
2875
- const uint8_t v1 = x1[j];
3545
+ int sxy = 0;
2876
3546
 
2877
- const int i0_0 = (int8_t) (v0 & 0xf) - 8;
2878
- const int i1_0 = (int8_t) (v0 >> 4) - 8;
3547
+ for (int j = 0; j < QK8_1/2; j++) {
3548
+ const uint8_t v0 = x0[j];
2879
3549
 
2880
- const int i0_1 = (int8_t) (v1 & 0xf) - 8;
2881
- const int i1_1 = (int8_t) (v1 >> 4) - 8;
3550
+ const int x0_0h = ((qh & (1 << (2*j + 0))) >> (2*j + 0)) << 4;
3551
+ const int x1_0h = ((qh & (1 << (2*j + 1))) >> (2*j + 1)) << 4;
2882
3552
 
2883
- const int i2_0 = y0[2*j + 0];
2884
- const int i3_0 = y0[2*j + 1];
3553
+ const int x0_0 = (v0 & 0x0F) | x0_0h;
3554
+ const int x1_0 = (v0 >> 4) | x1_0h;
2885
3555
 
2886
- const int i2_1 = y0[2*(j + QK8_0/4) + 0];
2887
- const int i3_1 = y0[2*(j + QK8_0/4) + 1];
3556
+ const int y0_0 = y0[2*j + 0];
3557
+ const int y1_0 = y0[2*j + 1];
2888
3558
 
2889
- sumi_0 += i0_0*i2_0 + i1_0*i3_0;
2890
- sumi_1 += i0_1*i2_1 + i1_1*i3_1;
3559
+ sxy += x0_0*y0_0 + x1_0*y1_0;
2891
3560
  }
2892
3561
 
2893
- sumf += (d0 * y[i].d) * sumi_0;
2894
- sumf += (d1 * y[i].d) * sumi_1;
3562
+ sumf += (d*sxy)*y[i].d + m*(y[i].s0 + y[i].s1);
2895
3563
  }
2896
- #endif
2897
3564
 
2898
3565
  *s = sumf;
3566
+ #endif
2899
3567
  }
2900
3568
 
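q5_1 reconstructs each value as d*(nibble | bit<<4) + m: because the block carries an explicit minimum m, the high bit is ORed in without the subtract-16 recentering that q5_0 needs, and the m term again folds into the precomputed s0/s1 sums. The per-element arithmetic as a tiny worked example:

#include <stdio.h>
#include <stdint.h>

int main(void) {
    /* one q5_1-style element: a 5-bit quant split into nibble + high bit */
    const float   d = 0.125f, m = -3.0f;
    const uint8_t nibble = 0x0B;          /* low 4 bits               */
    const int     hbit   = 1;             /* 5th bit, taken from qh   */

    const int q5 = nibble | (hbit << 4);  /* 0..31, here 27           */
    printf("value = %f\n", d*q5 + m);     /* 0.125*27 - 3.0 = 0.375   */
    return 0;
}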
2901
- static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
3569
+ static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
2902
3570
  const int nb = n / QK8_0;
2903
3571
 
2904
3572
  assert(n % QK8_0 == 0);
2905
3573
  assert(nb % 2 == 0);
2906
- assert(QK8_0 == 2*QK4_2);
3574
+ assert(QK8_0 == QK8_0);
2907
3575
 
2908
- const block_q4_3 * restrict x = vx;
3576
+ const block_q8_0 * restrict x = vx;
2909
3577
  const block_q8_0 * restrict y = vy;
2910
3578
 
2911
- float sumf = 0.0;
2912
-
2913
3579
  #if defined(__ARM_NEON)
2914
3580
  float32x4_t sumv0 = vdupq_n_f32(0.0f);
2915
3581
  float32x4_t sumv1 = vdupq_n_f32(0.0f);
2916
3582
 
2917
3583
  for (int i = 0; i < nb; i += 2) {
2918
- const block_q4_3 * restrict x0_0 = &x[2*(i + 0) + 0];
2919
- const block_q4_3 * restrict x0_1 = &x[2*(i + 0) + 1];
2920
- const block_q4_3 * restrict x1_0 = &x[2*(i + 1) + 0];
2921
- const block_q4_3 * restrict x1_1 = &x[2*(i + 1) + 1];
2922
-
3584
+ const block_q8_0 * restrict x0 = &x[i + 0];
3585
+ const block_q8_0 * restrict x1 = &x[i + 1];
2923
3586
  const block_q8_0 * restrict y0 = &y[i + 0];
2924
3587
  const block_q8_0 * restrict y1 = &y[i + 1];
2925
3588
 
2926
- const uint8x16_t m4b = vdupq_n_u8(0xf);
2927
-
2928
- const float x0_0d = GGML_FP16_TO_FP32(x0_0->d);
2929
- const float x0_1d = GGML_FP16_TO_FP32(x0_1->d);
2930
- const float x1_0d = GGML_FP16_TO_FP32(x1_0->d);
2931
- const float x1_1d = GGML_FP16_TO_FP32(x1_1->d);
2932
-
2933
- const float x0_0m = GGML_FP16_TO_FP32(x0_0->m);
2934
- const float x0_1m = GGML_FP16_TO_FP32(x0_1->m);
2935
- const float x1_0m = GGML_FP16_TO_FP32(x1_0->m);
2936
- const float x1_1m = GGML_FP16_TO_FP32(x1_1->m);
2937
-
2938
- const uint8x16_t v0_0 = vcombine_u8(vld1_u8(x0_0->qs), vld1_u8(x0_1->qs));
2939
- const uint8x16_t v0_1 = vcombine_u8(vld1_u8(x1_0->qs), vld1_u8(x1_1->qs));
2940
-
2941
- // 4-bit -> 8-bit
2942
- const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
2943
- const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
2944
- const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
2945
- const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
2946
-
2947
- // interleave
2948
- const int8x16_t v0_0lz = vzip1q_s8(v0_0l, v0_0h);
2949
- const int8x16_t v0_0hz = vzip2q_s8(v0_0l, v0_0h);
2950
- const int8x16_t v0_1lz = vzip1q_s8(v0_1l, v0_1h);
2951
- const int8x16_t v0_1hz = vzip2q_s8(v0_1l, v0_1h);
3589
+ const int8x16_t x0_0 = vld1q_s8(x0->qs);
3590
+ const int8x16_t x0_1 = vld1q_s8(x0->qs + 16);
3591
+ const int8x16_t x1_0 = vld1q_s8(x1->qs);
3592
+ const int8x16_t x1_1 = vld1q_s8(x1->qs + 16);
2952
3593
 
2953
3594
  // load y
2954
- const int8x16_t v1_0l = vld1q_s8(y0->qs);
2955
- const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
2956
- const int8x16_t v1_1l = vld1q_s8(y1->qs);
2957
- const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
2958
-
2959
- const int16x8_t sy0_0 = vaddq_s16(vmovl_s8(vget_low_s8(v1_0l)), vmovl_s8(vget_high_s8(v1_0l)));
2960
- const int16x8_t sy0_1 = vaddq_s16(vmovl_s8(vget_low_s8(v1_0h)), vmovl_s8(vget_high_s8(v1_0h)));
2961
-
2962
- const int16x8_t sy1_0 = vaddq_s16(vmovl_s8(vget_low_s8(v1_1l)), vmovl_s8(vget_high_s8(v1_1l)));
2963
- const int16x8_t sy1_1 = vaddq_s16(vmovl_s8(vget_low_s8(v1_1h)), vmovl_s8(vget_high_s8(v1_1h)));
2964
-
2965
- sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddl_s16(vget_low_s16(sy0_0), vget_high_s16(sy0_0))), x0_0m*y0->d);
2966
- sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddl_s16(vget_low_s16(sy0_1), vget_high_s16(sy0_1))), x0_1m*y0->d);
2967
- sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddl_s16(vget_low_s16(sy1_0), vget_high_s16(sy1_0))), x1_0m*y1->d);
2968
- sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddl_s16(vget_low_s16(sy1_1), vget_high_s16(sy1_1))), x1_1m*y1->d);
3595
+ const int8x16_t y0_0 = vld1q_s8(y0->qs);
3596
+ const int8x16_t y0_1 = vld1q_s8(y0->qs + 16);
3597
+ const int8x16_t y1_0 = vld1q_s8(y1->qs);
3598
+ const int8x16_t y1_1 = vld1q_s8(y1->qs + 16);
2969
3599
 
2970
3600
  #if defined(__ARM_FEATURE_DOTPROD)
2971
- sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0lz, v1_0l)), x0_0d*y0->d);
2972
- sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0hz, v1_0h)), x0_1d*y0->d);
2973
- sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_1lz, v1_1l)), x1_0d*y1->d);
2974
- sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_1hz, v1_1h)), x1_1d*y1->d);
2975
- #else
2976
- const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lz), vget_low_s8 (v1_0l));
2977
- const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lz), vget_high_s8(v1_0l));
2978
- const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hz), vget_low_s8 (v1_0h));
2979
- const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hz), vget_high_s8(v1_0h));
2980
-
2981
- const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lz), vget_low_s8 (v1_1l));
2982
- const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lz), vget_high_s8(v1_1l));
2983
- const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hz), vget_low_s8 (v1_1h));
2984
- const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hz), vget_high_s8(v1_1h));
3601
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
3602
+ vdotq_s32(vdupq_n_s32(0), x0_0, y0_0),
3603
+ vdotq_s32(vdupq_n_s32(0), x0_1, y0_1))), x0->d*y0->d);
2985
3604
 
2986
- const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
2987
- const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
2988
- const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
2989
- const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
3605
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
3606
+ vdotq_s32(vdupq_n_s32(0), x1_0, y1_0),
3607
+ vdotq_s32(vdupq_n_s32(0), x1_1, y1_1))), x1->d*y1->d);
2990
3608
 
2991
- sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(pl0), x0_0d*y0->d);
2992
- sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(ph0), x0_1d*y0->d);
2993
- sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(pl1), x1_0d*y1->d);
2994
- sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(ph1), x1_1d*y1->d);
3609
+ #else
3610
+ const int16x8_t p0_0 = vmull_s8(vget_low_s8 (x0_0), vget_low_s8 (y0_0));
3611
+ const int16x8_t p0_1 = vmull_s8(vget_high_s8(x0_0), vget_high_s8(y0_0));
3612
+ const int16x8_t p0_2 = vmull_s8(vget_low_s8 (x0_1), vget_low_s8 (y0_1));
3613
+ const int16x8_t p0_3 = vmull_s8(vget_high_s8(x0_1), vget_high_s8(y0_1));
3614
+
3615
+ const int16x8_t p1_0 = vmull_s8(vget_low_s8 (x1_0), vget_low_s8 (y1_0));
3616
+ const int16x8_t p1_1 = vmull_s8(vget_high_s8(x1_0), vget_high_s8(y1_0));
3617
+ const int16x8_t p1_2 = vmull_s8(vget_low_s8 (x1_1), vget_low_s8 (y1_1));
3618
+ const int16x8_t p1_3 = vmull_s8(vget_high_s8(x1_1), vget_high_s8(y1_1));
3619
+
3620
+ const int32x4_t p0 = vaddq_s32(vpaddlq_s16(p0_0), vpaddlq_s16(p0_1));
3621
+ const int32x4_t p1 = vaddq_s32(vpaddlq_s16(p0_2), vpaddlq_s16(p0_3));
3622
+ const int32x4_t p2 = vaddq_s32(vpaddlq_s16(p1_0), vpaddlq_s16(p1_1));
3623
+ const int32x4_t p3 = vaddq_s32(vpaddlq_s16(p1_2), vpaddlq_s16(p1_3));
3624
+
3625
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(p0, p1)), x0->d*y0->d);
3626
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(p2, p3)), x1->d*y1->d);
2995
3627
  #endif
2996
3628
  }
2997
3629
 
2998
- sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
2999
- #else
3000
- // scalar
3001
- for (int i = 0; i < nb; i++) {
3002
- const uint8_t * restrict x0 = x[2*i + 0].qs;
3003
- const uint8_t * restrict x1 = x[2*i + 1].qs;
3004
- const int8_t * restrict y0 = y[i].qs;
3005
-
3006
- const float d0 = GGML_FP16_TO_FP32(x[2*i + 0].d);
3007
- const float m0 = GGML_FP16_TO_FP32(x[2*i + 0].m);
3008
- const float d1 = GGML_FP16_TO_FP32(x[2*i + 1].d);
3009
- const float m1 = GGML_FP16_TO_FP32(x[2*i + 1].m);
3010
-
3011
- int sy_0 = 0;
3012
- int sy_1 = 0;
3630
+ *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
3631
+ #elif defined(__AVX2__)
3632
+ // Initialize accumulator with zeros
3633
+ __m256 acc = _mm256_setzero_ps();
3013
3634
 
3014
- int sxy_0 = 0;
3015
- int sxy_1 = 0;
3635
+ // Main loop
3636
+ for (int i = 0; i < nb; ++i) {
3637
+ // Compute combined scale for the block
3638
+ const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );
3639
+ __m256i bx = _mm256_loadu_si256((const __m256i *)x[i].qs);
3640
+ __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
3016
3641
 
3017
- for (int j = 0; j < QK8_0/4; j++) {
3018
- const uint8_t v0 = x0[j];
3019
- const uint8_t v1 = x1[j];
3642
+ const __m256 q = mul_sum_i8_pairs_float(bx, by);
3020
3643
 
3021
- const int x0_0 = v0 & 0xf;
3022
- const int x1_0 = v0 >> 4;
3644
+ // Multiply q with scale and accumulate
3645
+ acc = _mm256_fmadd_ps( d, q, acc );
3646
+ }
3023
3647
 
3024
- const int x0_1 = v1 & 0xf;
3025
- const int x1_1 = v1 >> 4;
3648
+ *s = hsum_float_8(acc);
3649
+ #else
3650
+ // scalar
3651
+ float sumf = 0.0;
3026
3652
 
3027
- const int y0_0 = y0[2*j + 0];
3028
- const int y1_0 = y0[2*j + 1];
3653
+ for (int i = 0; i < nb; i++) {
3654
+ const int8_t * restrict x0 = x[i].qs;
3655
+ const int8_t * restrict y0 = y[i].qs;
3029
3656
 
3030
- const int y0_1 = y0[2*(j + QK8_0/4) + 0];
3031
- const int y1_1 = y0[2*(j + QK8_0/4) + 1];
3657
+ int sumi = 0;
3032
3658
 
3033
- sy_0 += y0_0 + y1_0;
3034
- sy_1 += y0_1 + y1_1;
3659
+ for (int j = 0; j < QK8_0; j++) {
3660
+ const int v0 = x0[j];
3661
+ const int v1 = y0[j];
3035
3662
 
3036
- sxy_0 += x0_0*y0_0 + x1_0*y1_0;
3037
- sxy_1 += x0_1*y0_1 + x1_1*y1_1;
3663
+ sumi += v0*v1;
3038
3664
  }
3039
3665
 
3040
- sumf += (d0*sxy_0 + m0*sy_0)*y[i].d;
3041
- sumf += (d1*sxy_1 + m1*sy_1)*y[i].d;
3666
+ sumf += (x[i].d*y[i].d)*sumi;
3042
3667
  }
3043
- #endif
3044
3668
 
3045
3669
  *s = sumf;
3670
+ #endif
3046
3671
  }
3047
3672
 
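Every NEON kernel above pairs a __ARM_FEATURE_DOTPROD branch with a fallback built from vmull_s8 (widening 8-bit multiply) and vpaddlq_s16 (pairwise widening add). The two paths group the sixteen int8 products into int32 lanes differently, vdotq_s32 in runs of four versus the fallback's pair-then-half layout, but their horizontal totals are identical, and the total is all that survives the final vaddvq_f32. A scalar model of the vdotq_s32 grouping:

#include <stdint.h>
#include <assert.h>

/* What vdotq_s32 computes for one int8x16 pair: each int32 lane accumulates
 * four adjacent products. The vmull_s8 + vpaddlq_s16 fallback distributes
 * the same products across lanes differently, but the sum over lanes agrees. */
static void dot_lanes(const int8_t a[16], const int8_t b[16], int32_t acc[4]) {
    for (int lane = 0; lane < 4; lane++) {
        int32_t s = 0;
        for (int k = 0; k < 4; k++)
            s += (int32_t)a[4*lane + k] * b[4*lane + k];
        acc[lane] += s;
    }
}

int main(void) {
    int8_t a[16], b[16];
    for (int i = 0; i < 16; i++) { a[i] = (int8_t)(i - 8); b[i] = (int8_t)(3*i - 20); }

    int32_t acc[4] = {0, 0, 0, 0};
    dot_lanes(a, b, acc);
    int32_t total = acc[0] + acc[1] + acc[2] + acc[3];

    int32_t ref = 0;
    for (int i = 0; i < 16; i++) ref += a[i]*b[i];
    assert(total == ref);
    return 0;
}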
3048
-
3049
3673
  // compute GGML_VEC_DOT_UNROLL dot products at once
3050
3674
  // xs - x row stride in bytes
3051
3675
  inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * restrict s, void * restrict xv, ggml_fp16_t * restrict y) {
@@ -3242,6 +3866,14 @@ inline static void ggml_vec_sum_f32(const int n, float * s, const float * x) {
3242
3866
  #endif
3243
3867
  }
3244
3868
 
3869
+ inline static void ggml_vec_sum_ggf(const int n, ggml_float * s, const float * x) {
3870
+ ggml_float sum = 0.0;
3871
+ for (int i = 0; i < n; ++i) {
3872
+ sum += (ggml_float)x[i];
3873
+ }
3874
+ *s = sum;
3875
+ }
3876
+
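ggml_vec_sum_ggf accumulates in ggml_float (typically double) rather than float, so long reductions stop losing low-order bits once the running sum dwarfs each addend; ggml_compute_forward_sum_f32 below switches to it for exactly this reason. A standalone demonstration of the failure mode it avoids, purely illustrative and not ggml code:

#include <stdio.h>

int main(void) {
    /* Summing 20 million copies of 0.1f: once the float accumulator reaches
     * the millions, its spacing between representable values exceeds 0.1,
     * so additions round away; a double accumulator does not. */
    const int n = 20000000;
    float  sf = 0.0f;
    double sd = 0.0;
    for (int i = 0; i < n; i++) { sf += 0.1f; sd += (double)0.1f; }
    printf("float : %f\n", sf);   /* visibly short of 2,000,000          */
    printf("double: %f\n", sd);   /* ~2,000,000, up to 0.1f's own error  */
    return 0;
}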
3245
3877
  inline static void ggml_vec_max_f32(const int n, float * s, const float * x) {
3246
3878
  #ifndef GGML_USE_ACCELERATE
3247
3879
  float max = -INFINITY;
@@ -3294,12 +3926,15 @@ static const int GGML_BLCK_SIZE[GGML_TYPE_COUNT] = {
3294
3926
  [GGML_TYPE_Q4_1] = QK4_1,
3295
3927
  [GGML_TYPE_Q4_2] = QK4_2,
3296
3928
  [GGML_TYPE_Q4_3] = QK4_3,
3929
+ [GGML_TYPE_Q5_0] = QK5_0,
3930
+ [GGML_TYPE_Q5_1] = QK5_1,
3297
3931
  [GGML_TYPE_Q8_0] = QK8_0,
3932
+ [GGML_TYPE_Q8_1] = QK8_1,
3298
3933
  [GGML_TYPE_I8] = 1,
3299
3934
  [GGML_TYPE_I16] = 1,
3300
3935
  [GGML_TYPE_I32] = 1,
3301
3936
  };
3302
- static_assert(GGML_TYPE_COUNT == 10, "GGML_BLCK_SIZE is outdated");
3937
+ static_assert(GGML_TYPE_COUNT == 13, "GGML_BLCK_SIZE is outdated");
3303
3938
 
3304
3939
  static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = {
3305
3940
  [GGML_TYPE_F32] = sizeof(float),
@@ -3308,12 +3943,15 @@ static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = {
3308
3943
  [GGML_TYPE_Q4_1] = sizeof(block_q4_1),
3309
3944
  [GGML_TYPE_Q4_2] = sizeof(block_q4_2),
3310
3945
  [GGML_TYPE_Q4_3] = sizeof(block_q4_3),
3946
+ [GGML_TYPE_Q5_0] = sizeof(block_q5_0),
3947
+ [GGML_TYPE_Q5_1] = sizeof(block_q5_1),
3311
3948
  [GGML_TYPE_Q8_0] = sizeof(block_q8_0),
3949
+ [GGML_TYPE_Q8_1] = sizeof(block_q8_1),
3312
3950
  [GGML_TYPE_I8] = sizeof(int8_t),
3313
3951
  [GGML_TYPE_I16] = sizeof(int16_t),
3314
3952
  [GGML_TYPE_I32] = sizeof(int32_t),
3315
3953
  };
3316
- static_assert(GGML_TYPE_COUNT == 10, "GGML_TYPE_SIZE is outdated");
3954
+ static_assert(GGML_TYPE_COUNT == 13, "GGML_TYPE_SIZE is outdated");
3317
3955
 
3318
3956
 
3319
3957
  static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = {
@@ -3323,12 +3961,15 @@ static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = {
3323
3961
  [GGML_TYPE_Q4_1] = "q4_1",
3324
3962
  [GGML_TYPE_Q4_2] = "q4_2",
3325
3963
  [GGML_TYPE_Q4_3] = "q4_3",
3964
+ [GGML_TYPE_Q5_0] = "q5_0",
3965
+ [GGML_TYPE_Q5_1] = "q5_1",
3326
3966
  [GGML_TYPE_Q8_0] = "q8_0",
3967
+ [GGML_TYPE_Q8_1] = "q8_1",
3327
3968
  [GGML_TYPE_I8] = "i8",
3328
3969
  [GGML_TYPE_I16] = "i16",
3329
3970
  [GGML_TYPE_I32] = "i32",
3330
3971
  };
3331
- static_assert(GGML_TYPE_COUNT == 10, "GGML_TYPE_NAME is outdated");
3972
+ static_assert(GGML_TYPE_COUNT == 13, "GGML_TYPE_NAME is outdated");
3332
3973
 
3333
3974
  static bool GGML_IS_QUANTIZED[GGML_TYPE_COUNT] = {
3334
3975
  [GGML_TYPE_F32] = false,
@@ -3337,12 +3978,15 @@ static bool GGML_IS_QUANTIZED[GGML_TYPE_COUNT] = {
3337
3978
  [GGML_TYPE_Q4_1] = true,
3338
3979
  [GGML_TYPE_Q4_2] = true,
3339
3980
  [GGML_TYPE_Q4_3] = true,
3981
+ [GGML_TYPE_Q5_0] = true,
3982
+ [GGML_TYPE_Q5_1] = true,
3340
3983
  [GGML_TYPE_Q8_0] = true,
3984
+ [GGML_TYPE_Q8_1] = true,
3341
3985
  [GGML_TYPE_I8] = false,
3342
3986
  [GGML_TYPE_I16] = false,
3343
3987
  [GGML_TYPE_I32] = false,
3344
3988
  };
3345
- static_assert(GGML_TYPE_COUNT == 10, "GGML_IS_QUANTIZED is outdated");
3989
+ static_assert(GGML_TYPE_COUNT == 13, "GGML_IS_QUANTIZED is outdated");
3346
3990
 
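Each table above uses designated initializers guarded by a static_assert on GGML_TYPE_COUNT, so adding q5_0, q5_1, and q8_1 refuses to compile until every parallel table is touched; without the assert, designated initializers would silently zero-fill the missing entries. The idiom in miniature, with hypothetical names:

#include <assert.h>
#include <stddef.h>

enum my_type { T_F32, T_F16, T_Q4, T_COUNT };   /* hypothetical enum */

/* Designated initializers tolerate gaps, so each parallel table carries a
 * compile-time tripwire keyed to the enum's element count. */
static const size_t MY_SIZE[T_COUNT] = {
    [T_F32] = 4,
    [T_F16] = 2,
    [T_Q4]  = 20,
};
static_assert(T_COUNT == 3, "MY_SIZE is outdated");

int main(void) { return (int)MY_SIZE[T_F16] - 2; }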
3347
3991
  static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
3348
3992
  "NONE",
@@ -3720,7 +4364,9 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
3720
4364
 
3721
4365
  // initialize cuBLAS
3722
4366
  #if defined(GGML_USE_CUBLAS)
3723
- init_cublas();
4367
+ ggml_init_cublas();
4368
+ #elif defined(GGML_USE_CLBLAST)
4369
+ ggml_cl_init();
3724
4370
  #endif
3725
4371
 
3726
4372
  is_first_call = false;
@@ -6554,6 +7200,9 @@ static void ggml_compute_forward_add(
6554
7200
  case GGML_TYPE_Q4_1:
6555
7201
  case GGML_TYPE_Q4_2:
6556
7202
  case GGML_TYPE_Q4_3:
7203
+ case GGML_TYPE_Q5_0:
7204
+ case GGML_TYPE_Q5_1:
7205
+ case GGML_TYPE_Q8_0:
6557
7206
  {
6558
7207
  ggml_compute_forward_add_q_f32(params, src0, src1, dst);
6559
7208
  } break;
@@ -6811,15 +7460,20 @@ static void ggml_compute_forward_sum_f32(
6811
7460
  const size_t nb02 = src0->nb[2];
6812
7461
  const size_t nb03 = src0->nb[3];
6813
7462
 
7463
+ ggml_float sum = 0;
7464
+ ggml_float row_sum = 0;
7465
+
6814
7466
  for (int64_t i03 = 0; i03 < ne03; i03++) {
6815
7467
  for (int64_t i02 = 0; i02 < ne02; i02++) {
6816
7468
  for (int64_t i01 = 0; i01 < ne01; i01++) {
6817
- ggml_vec_sum_f32(ne00,
6818
- (float *) (dst->data),
7469
+ ggml_vec_sum_ggf(ne00,
7470
+ &row_sum,
6819
7471
  (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03));
7472
+ sum += row_sum;
6820
7473
  }
6821
7474
  }
6822
7475
  }
7476
+ ((float *) dst->data)[0] = sum;
6823
7477
  }
6824
7478
 
6825
7479
  static void ggml_compute_forward_sum(
@@ -7454,7 +8108,7 @@ static void ggml_compute_forward_rms_norm(
7454
8108
 
7455
8109
  // ggml_compute_forward_mul_mat
7456
8110
 
7457
- #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
8111
+ #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
7458
8112
  // helper function to determine if it is better to use BLAS or not
7459
8113
  // for large matrices, BLAS is faster
7460
8114
  static bool ggml_compute_forward_mul_mat_use_blas(
@@ -7479,6 +8133,7 @@ static bool ggml_compute_forward_mul_mat_use_blas(
7479
8133
 
7480
8134
  return false;
7481
8135
  }
8136
+
7482
8137
  #endif
7483
8138
 
7484
8139
  static void ggml_compute_forward_mul_mat_f32(
@@ -7494,7 +8149,7 @@ static void ggml_compute_forward_mul_mat_f32(
7494
8149
  const int64_t ne02 = src0->ne[2];
7495
8150
  const int64_t ne03 = src0->ne[3];
7496
8151
 
7497
- #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
8152
+ #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
7498
8153
  const int64_t ne10 = src1->ne[0];
7499
8154
  #endif
7500
8155
  const int64_t ne11 = src1->ne[1];
@@ -7551,7 +8206,7 @@ static void ggml_compute_forward_mul_mat_f32(
7551
8206
  // nb01 >= nb00 - src0 is not transposed
7552
8207
  // compute by src0 rows
7553
8208
 
7554
- #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
8209
+ #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
7555
8210
  if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
7556
8211
  if (params->ith != 0) {
7557
8212
  return;
@@ -7566,18 +8221,16 @@ static void ggml_compute_forward_mul_mat_f32(
7566
8221
  }
7567
8222
 
7568
8223
  #if defined(GGML_USE_CUBLAS)
7569
- float *d_X = NULL;
7570
- float *d_Y = NULL;
7571
- float *d_D = NULL;
7572
8224
  const float alpha = 1.0f;
7573
8225
  const float beta = 0.0f;
7574
8226
  const int x_ne = ne01 * ne10;
7575
8227
  const int y_ne = ne11 * ne10;
7576
8228
  const int d_ne = ne11 * ne01;
7577
8229
 
7578
- CUDA_CHECK(cudaMalloc((void **)(&d_X), sizeof(float) * x_ne));
7579
- CUDA_CHECK(cudaMalloc((void **)(&d_Y), sizeof(float) * y_ne));
7580
- CUDA_CHECK(cudaMalloc((void **)(&d_D), sizeof(float) * d_ne));
8230
+ size_t x_size, y_size, d_size;
8231
+ float *d_X = ggml_cuda_pool_malloc(sizeof(float) * x_ne, &x_size);
8232
+ float *d_Y = ggml_cuda_pool_malloc(sizeof(float) * y_ne, &y_size);
8233
+ float *d_D = ggml_cuda_pool_malloc(sizeof(float) * d_ne, &d_size);
7581
8234
  #endif
7582
8235
 
7583
8236
  for (int64_t i03 = 0; i03 < ne03; i03++) {
@@ -7589,21 +8242,28 @@ static void ggml_compute_forward_mul_mat_f32(
7589
8242
 
7590
8243
  #if defined(GGML_USE_CUBLAS)
7591
8244
  // copy data to device
7592
- CUDA_CHECK(cudaMemcpyAsync(d_X, x, sizeof(float) * x_ne, cudaMemcpyHostToDevice, cudaStream));
7593
- CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(float) * y_ne, cudaMemcpyHostToDevice, cudaStream));
8245
+ CUDA_CHECK(cudaMemcpyAsync(d_X, x, sizeof(float) * x_ne, cudaMemcpyHostToDevice, g_cudaStream));
8246
+ CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(float) * y_ne, cudaMemcpyHostToDevice, g_cudaStream));
7594
8247
 
7595
8248
  // compute
7596
8249
  CUBLAS_CHECK(
7597
- cublasSgemm(cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
8250
+ cublasSgemm(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
7598
8251
  ne01, ne11, ne10,
7599
8252
  &alpha, d_X, ne00,
7600
8253
  d_Y, ne10,
7601
8254
  &beta, d_D, ne01));
7602
8255
 
7603
8256
  // copy data to host
7604
- CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
7605
- #else
8257
+ CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, g_cudaStream));
8258
+ #elif defined(GGML_USE_CLBLAST)
7606
8259
  // zT = y * xT
8260
+ ggml_cl_sgemm_wrapper(GGML_BLAS_ORDER_ROW_MAJOR, GGML_BLAS_OP_N, GGML_BLAS_OP_T,
8261
+ ne11, ne01, ne10,
8262
+ 1.0f, y, ne10,
8263
+ x, ne10,
8264
+ 0.0f, d, ne01,
8265
+ GGML_TYPE_F32);
8266
+ #else
7607
8267
  cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
7608
8268
  ne11, ne01, ne10,
7609
8269
  1.0f, y, ne10,
@@ -7613,10 +8273,10 @@ static void ggml_compute_forward_mul_mat_f32(
7613
8273
  }
7614
8274
  }
7615
8275
  #if defined(GGML_USE_CUBLAS)
7616
- CUDA_CHECK(cudaStreamSynchronize(cudaStream));
7617
- CUDA_CHECK(cudaFree(d_X));
7618
- CUDA_CHECK(cudaFree(d_Y));
7619
- CUDA_CHECK(cudaFree(d_D));
8276
+ CUDA_CHECK(cudaStreamSynchronize(g_cudaStream));
8277
+ ggml_cuda_pool_free(d_X, x_size);
8278
+ ggml_cuda_pool_free(d_Y, y_size);
8279
+ ggml_cuda_pool_free(d_D, d_size);
7620
8280
  #endif
7621
8281
  //printf("CBLAS F32 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);
7622
8282
 
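The comment "zT = y * xT" names the layout trick shared by the cblas_sgemm and ggml_cl_sgemm_wrapper calls: rows of x and y are both contiguous along ne10, so instead of transposing anything in memory the call computes the ne11 x ne01 product y * x^T straight into d, which is dst in row-major order. A plain-C reference for what that GEMM call computes:

#include <stdio.h>

/* d[r][c] = sum_k y[r][k] * x[c][k];
 * y is ne11 x ne10, x is ne01 x ne10, d is ne11 x ne01, all row-major.
 * Indexing x by its row c while filling d's column c is the "x^T" part. */
static void sgemm_nt_ref(int ne11, int ne01, int ne10,
                         const float *y, const float *x, float *d) {
    for (int r = 0; r < ne11; r++) {
        for (int c = 0; c < ne01; c++) {
            float s = 0.0f;
            for (int k = 0; k < ne10; k++)
                s += y[r*ne10 + k] * x[c*ne10 + k];
            d[r*ne01 + c] = s;
        }
    }
}

int main(void) {
    const float y[2*3] = { 1, 2, 3,  4, 5, 6 };   /* 2 x 3           */
    const float x[2*3] = { 1, 0, 1,  0, 1, 0 };   /* 2 x 3, src0 rows */
    float d[2*2];
    sgemm_nt_ref(2, 2, 3, y, x, d);
    printf("%g %g / %g %g\n", d[0], d[1], d[2], d[3]);  /* 4 2 / 10 5 */
    return 0;
}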
@@ -7747,7 +8407,7 @@ static void ggml_compute_forward_mul_mat_f16_f32(
7747
8407
  // nb01 >= nb00 - src0 is not transposed
7748
8408
  // compute by src0 rows
7749
8409
 
7750
- #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
8410
+ #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
7751
8411
  if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
7752
8412
  GGML_ASSERT(nb10 == sizeof(float));
7753
8413
 
@@ -7766,18 +8426,16 @@ static void ggml_compute_forward_mul_mat_f16_f32(
7766
8426
  #if defined(GGML_USE_CUBLAS)
7767
8427
  ggml_fp16_t * const wdata = params->wdata;
7768
8428
 
7769
- float *d_X = NULL;
7770
- float *d_Y = NULL;
7771
- float *d_D = NULL;
7772
8429
  const float alpha = 1.0f;
7773
8430
  const float beta = 0.0f;
7774
8431
  const int x_ne = ne01 * ne10;
7775
8432
  const int y_ne = ne11 * ne10;
7776
8433
  const int d_ne = ne11 * ne01;
7777
8434
 
7778
- CUDA_CHECK(cudaMalloc((void **)(&d_X), sizeof(ggml_fp16_t) * x_ne));
7779
- CUDA_CHECK(cudaMalloc((void **)(&d_Y), sizeof(float) * y_ne));
7780
- CUDA_CHECK(cudaMalloc((void **)(&d_D), sizeof(float) * d_ne));
8435
+ size_t x_size, y_size, d_size;
8436
+ float *d_X = ggml_cuda_pool_malloc(sizeof(float) * x_ne, &x_size);
8437
+ float *d_Y = ggml_cuda_pool_malloc(sizeof(float) * y_ne, &y_size);
8438
+ float *d_D = ggml_cuda_pool_malloc(sizeof(float) * d_ne, &d_size);
7781
8439
  #else
7782
8440
  float * const wdata = params->wdata;
7783
8441
  #endif
@@ -7811,12 +8469,12 @@ static void ggml_compute_forward_mul_mat_f16_f32(
7811
8469
  float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
7812
8470
 
7813
8471
  // copy data to device
7814
- CUDA_CHECK(cudaMemcpyAsync(d_X, x, sizeof(ggml_fp16_t) * x_ne, cudaMemcpyHostToDevice, cudaStream));
7815
- CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(ggml_fp16_t) * y_ne, cudaMemcpyHostToDevice, cudaStream));
8472
+ CUDA_CHECK(cudaMemcpyAsync(d_X, x, sizeof(ggml_fp16_t) * x_ne, cudaMemcpyHostToDevice, g_cudaStream));
8473
+ CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(ggml_fp16_t) * y_ne, cudaMemcpyHostToDevice, g_cudaStream));
7816
8474
 
7817
8475
  // compute
7818
8476
  CUBLAS_CHECK(
7819
- cublasGemmEx(cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
8477
+ cublasGemmEx(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
7820
8478
  ne01, ne11, ne10,
7821
8479
  &alpha, d_X, CUDA_R_16F, ne00,
7822
8480
  d_Y, CUDA_R_16F, ne10,
@@ -7825,7 +8483,20 @@ static void ggml_compute_forward_mul_mat_f16_f32(
7825
8483
  CUBLAS_GEMM_DEFAULT));
7826
8484
 
7827
8485
  // copy data to host
7828
- CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
8486
+ CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, g_cudaStream));
8487
+ #elif defined(GGML_USE_CLBLAST)
8488
+ const float * x = wdata;
8489
+ const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
8490
+
8491
+ float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
8492
+
8493
+ // zT = y * xT
8494
+ ggml_cl_sgemm_wrapper(GGML_BLAS_ORDER_ROW_MAJOR, GGML_BLAS_OP_N, GGML_BLAS_OP_T,
8495
+ ne11, ne01, ne10,
8496
+ 1.0f, y, ne10,
8497
+ x, ne10,
8498
+ 0.0f, d, ne01,
8499
+ GGML_TYPE_F32);
7829
8500
  #else
7830
8501
  const float * x = wdata;
7831
8502
  const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
@@ -7843,10 +8514,10 @@ static void ggml_compute_forward_mul_mat_f16_f32(
7843
8514
  }
7844
8515
 
7845
8516
  #if defined(GGML_USE_CUBLAS)
7846
- CUDA_CHECK(cudaStreamSynchronize(cudaStream));
7847
- CUDA_CHECK(cudaFree(d_X));
7848
- CUDA_CHECK(cudaFree(d_Y));
7849
- CUDA_CHECK(cudaFree(d_D));
8517
+ CUDA_CHECK(cudaStreamSynchronize(g_cudaStream));
8518
+ ggml_cuda_pool_free(d_X, x_size);
8519
+ ggml_cuda_pool_free(d_Y, y_size);
8520
+ ggml_cuda_pool_free(d_D, d_size);
7850
8521
  #endif
7851
8522
  /*printf("CBLAS F16 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);*/
7852
8523
 
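Replacing cudaMalloc/cudaFree with ggml_cuda_pool_malloc/ggml_cuda_pool_free removes a device-allocation round-trip from every matrix multiply. The pool itself lives in ggml-cuda.cu and is not shown in this diff; only its contract is visible here: return a buffer of at least `size` bytes, report the size actually reserved through the out-parameter, and take that size back on free. A minimal host-memory sketch of such a pool, illustrative only and not the ggml-cuda implementation:

#include <stdlib.h>
#include <stddef.h>

#define POOL_SLOTS 16

/* one cached buffer per slot; a NULL ptr marks the slot empty */
static struct { void *ptr; size_t size; } g_pool[POOL_SLOTS];

/* Return a buffer of at least `size` bytes, reusing a cached one when
 * possible; *actual_size reports what was really reserved. */
static void * pool_malloc(size_t size, size_t *actual_size) {
    for (int i = 0; i < POOL_SLOTS; i++) {
        if (g_pool[i].ptr != NULL && g_pool[i].size >= size) {
            void *p = g_pool[i].ptr;
            *actual_size = g_pool[i].size;
            g_pool[i].ptr = NULL;             /* slot handed out */
            return p;
        }
    }
    *actual_size = size;
    return malloc(size);                      /* cudaMalloc in the real pool */
}

static void pool_free(void *ptr, size_t size) {
    for (int i = 0; i < POOL_SLOTS; i++) {
        if (g_pool[i].ptr == NULL) {          /* cache for reuse */
            g_pool[i].ptr  = ptr;
            g_pool[i].size = size;
            return;
        }
    }
    free(ptr);                                /* pool full: really release */
}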
@@ -7980,6 +8651,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
7980
8651
  const enum ggml_type type = src0->type;
7981
8652
  quantize_row_q_t const quantize_row_q_dot = quantize_fns[type].quantize_row_q_dot;
7982
8653
  vec_dot_q_t const vec_dot_q = quantize_fns[type].vec_dot_q;
8654
+ enum ggml_type const vec_dot_type = quantize_fns[type].vec_dot_type;
7983
8655
 
7984
8656
  // we don't support permuted src0 or src1
7985
8657
  GGML_ASSERT(nb00 == (int) GGML_TYPE_SIZE[type]);
@@ -7999,7 +8671,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
7999
8671
  // nb01 >= nb00 - src0 is not transposed
8000
8672
  // compute by src0 rows
8001
8673
 
8002
- #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
8674
+ #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
8003
8675
  if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
8004
8676
  if (params->ith != 0) {
8005
8677
  return;
@@ -8014,20 +8686,17 @@ static void ggml_compute_forward_mul_mat_q_f32(
8014
8686
  }
8015
8687
 
8016
8688
  #if defined(GGML_USE_CUBLAS)
8017
- float *d_X = NULL;
8018
- float *d_Y = NULL;
8019
- float *d_D = NULL;
8020
- float *d_Q = NULL;
8021
8689
  const float alpha = 1.0f;
8022
8690
  const float beta = 0.0f;
8023
8691
  const int x_ne = ne01 * ne10;
8024
8692
  const int y_ne = ne11 * ne10;
8025
8693
  const int d_ne = ne11 * ne01;
8026
8694
 
8027
- CUDA_CHECK(cudaMalloc((void **)(&d_X), sizeof(float) * x_ne));
8028
- CUDA_CHECK(cudaMalloc((void **)(&d_Y), sizeof(float) * y_ne));
8029
- CUDA_CHECK(cudaMalloc((void **)(&d_D), sizeof(float) * d_ne));
8030
- CUDA_CHECK(cudaMalloc((void **)(&d_Q), GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type]));
8695
+ size_t x_size, y_size, d_size, q_size;
8696
+ float *d_X = ggml_cuda_pool_malloc(sizeof(float) * x_ne, &x_size);
8697
+ float *d_Y = ggml_cuda_pool_malloc(sizeof(float) * y_ne, &y_size);
8698
+ float *d_D = ggml_cuda_pool_malloc(sizeof(float) * d_ne, &d_size);
8699
+ float *d_Q = ggml_cuda_pool_malloc(GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type], &q_size);
8031
8700
 
8032
8701
  void (*dequantize_row_q_cuda)(const void * x, float * y, int k, cudaStream_t stream) = NULL;
8033
8702
  if (type == GGML_TYPE_Q4_0) {
@@ -8039,10 +8708,22 @@ static void ggml_compute_forward_mul_mat_q_f32(
8039
8708
  else if (type == GGML_TYPE_Q4_2) {
8040
8709
  dequantize_row_q_cuda = dequantize_row_q4_2_cuda;
8041
8710
  }
8711
+ else if (type == GGML_TYPE_Q4_3) {
8712
+ dequantize_row_q_cuda = dequantize_row_q4_3_cuda;
8713
+ }
8714
+ else if (type == GGML_TYPE_Q5_0) {
8715
+ dequantize_row_q_cuda = dequantize_row_q5_0_cuda;
8716
+ }
8717
+ else if (type == GGML_TYPE_Q5_1) {
8718
+ dequantize_row_q_cuda = dequantize_row_q5_1_cuda;
8719
+ }
8720
+ else if (type == GGML_TYPE_Q8_0) {
8721
+ dequantize_row_q_cuda = dequantize_row_q8_0_cuda;
8722
+ }
8042
8723
  else {
8043
8724
  GGML_ASSERT(false);
8044
8725
  }
8045
- #else
8726
+ #elif !defined(GGML_USE_CLBLAST)
8046
8727
  float * const wdata = params->wdata;
8047
8728
  dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q;
8048
8729
  #endif
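The if/else chain above selects a per-type CUDA dequantization kernel through a function pointer. The same dispatch can be phrased as a designated-initializer table, the shape quantize_fns[] already has elsewhere in this file. A sketch only: it compiles solely in the GGML_USE_CUBLAS build like the chain it mirrors, and the q4_1 entry is an assumption by symmetry, since that branch falls between the hunks shown here.

// Table-driven equivalent of the dispatch chain above (sketch).
typedef void (*dequantize_row_q_cuda_t)(const void * x, float * y, int k, cudaStream_t stream);

static const dequantize_row_q_cuda_t dequantize_row_q_cuda_fns[GGML_TYPE_COUNT] = {
    [GGML_TYPE_Q4_0] = dequantize_row_q4_0_cuda,
    [GGML_TYPE_Q4_1] = dequantize_row_q4_1_cuda,   // assumed name
    [GGML_TYPE_Q4_2] = dequantize_row_q4_2_cuda,
    [GGML_TYPE_Q4_3] = dequantize_row_q4_3_cuda,
    [GGML_TYPE_Q5_0] = dequantize_row_q5_0_cuda,
    [GGML_TYPE_Q5_1] = dequantize_row_q5_1_cuda,
    [GGML_TYPE_Q8_0] = dequantize_row_q8_0_cuda,
};

// usage:
//   dequantize_row_q_cuda_t fn = dequantize_row_q_cuda_fns[type];
//   GGML_ASSERT(fn != NULL);
//   fn(d_Q, d_X, ne01 * ne00, g_cudaStream);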
@@ -8057,10 +8738,12 @@ static void ggml_compute_forward_mul_mat_q_f32(
8057
8738
  // copy and dequantize on device
8058
8739
  CUDA_CHECK(
8059
8740
  cudaMemcpyAsync(d_Q, (char *) src0->data + i03*nb03 + i02*nb02,
8060
- GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type], cudaMemcpyHostToDevice, cudaStream));
8741
+ GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type], cudaMemcpyHostToDevice, g_cudaStream));
8061
8742
 
8062
- dequantize_row_q_cuda(d_Q, d_X, ne01 * ne00, cudaStream);
8743
+ dequantize_row_q_cuda(d_Q, d_X, ne01 * ne00, g_cudaStream);
8063
8744
  CUDA_CHECK(cudaGetLastError());
8745
+ #elif defined(GGML_USE_CLBLAST)
8746
+ const void* x = (char *) src0->data + i03*nb03 + i02*nb02;
8064
8747
  #else
8065
8748
  {
8066
8749
  size_t id = 0;
@@ -8075,20 +8758,27 @@ static void ggml_compute_forward_mul_mat_q_f32(
8075
8758
 
8076
8759
  #if defined(GGML_USE_CUBLAS)
8077
8760
  // copy data to device
8078
- CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(float) * y_ne, cudaMemcpyHostToDevice, cudaStream));
8761
+ CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(float) * y_ne, cudaMemcpyHostToDevice, g_cudaStream));
8079
8762
 
8080
8763
  // compute
8081
8764
  CUBLAS_CHECK(
8082
- cublasSgemm(cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
8765
+ cublasSgemm(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
8083
8766
  ne01, ne11, ne10,
8084
8767
  &alpha, d_X, ne00,
8085
8768
  d_Y, ne10,
8086
8769
  &beta, d_D, ne01));
8087
8770
 
8088
8771
  // copy data to host
8089
- CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
8090
- #else
8772
+ CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, g_cudaStream));
8773
+ #elif defined(GGML_USE_CLBLAST)
8091
8774
  // zT = y * xT
8775
+ ggml_cl_sgemm_wrapper(GGML_BLAS_ORDER_ROW_MAJOR, GGML_BLAS_OP_N, GGML_BLAS_OP_T,
8776
+ ne11, ne01, ne10,
8777
+ 1.0f, y, ne10,
8778
+ x, ne10,
8779
+ 0.0f, d, ne01,
8780
+ type);
8781
+ #else
8092
8782
  cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
8093
8783
  ne11, ne01, ne10,
8094
8784
  1.0f, y, ne10,
@@ -8099,11 +8789,11 @@ static void ggml_compute_forward_mul_mat_q_f32(
8099
8789
  }
8100
8790
 
8101
8791
  #if defined(GGML_USE_CUBLAS)
8102
- CUDA_CHECK(cudaStreamSynchronize(cudaStream));
8103
- CUDA_CHECK(cudaFree(d_X));
8104
- CUDA_CHECK(cudaFree(d_Y));
8105
- CUDA_CHECK(cudaFree(d_D));
8106
- CUDA_CHECK(cudaFree(d_Q));
8792
+ CUDA_CHECK(cudaStreamSynchronize(g_cudaStream));
8793
+ ggml_cuda_pool_free(d_X, x_size);
8794
+ ggml_cuda_pool_free(d_Y, y_size);
8795
+ ggml_cuda_pool_free(d_D, d_size);
8796
+ ggml_cuda_pool_free(d_Q, q_size);
8107
8797
  #endif
8108
8798
  //printf("CBLAS = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);
8109
8799
 
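Note: the three branches above compute the same product and differ only in
library and storage convention. With X the ne01 x ne10 matrix dequantized from
src0 and Y the ne11 x ne10 matrix from src1, each call yields D = Y * X^T as
an ne11 x ne01 row-major buffer. cuBLAS is column-major, so it is asked for
the ne01 x ne11 column-major result instead; a row-major M x N matrix and a
column-major N x M matrix share the same memory layout, which is also what the
"zT = y * xT" comment in the CLBlast branch is pointing at. In LaTeX:

    D = Y X^{\top}, \qquad
    Y \in \mathbb{R}^{\mathrm{ne}_{11} \times \mathrm{ne}_{10}}, \quad
    X \in \mathbb{R}^{\mathrm{ne}_{01} \times \mathrm{ne}_{10}}, \qquad
    \mathrm{rowmajor}(D) \equiv \mathrm{colmajor}(D^{\top})
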
@@ -8113,7 +8803,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
8113
8803
 
8114
8804
  if (params->type == GGML_TASK_INIT) {
8115
8805
  char * wdata = params->wdata;
8116
- const size_t row_size = ne10*GGML_TYPE_SIZE[GGML_TYPE_Q8_0]/GGML_BLCK_SIZE[GGML_TYPE_Q8_0];
8806
+ const size_t row_size = ne10*GGML_TYPE_SIZE[vec_dot_type]/GGML_BLCK_SIZE[vec_dot_type];
8117
8807
 
8118
8808
  for (int64_t i13 = 0; i13 < ne13; ++i13) {
8119
8809
  for (int64_t i12 = 0; i12 < ne12; ++i12) {
@@ -8144,7 +8834,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
8144
8834
  const int ir1 = MIN(ir0 + dr, nr);
8145
8835
 
8146
8836
  void * wdata = params->wdata;
8147
- const size_t row_size = ne00*GGML_TYPE_SIZE[GGML_TYPE_Q8_0]/GGML_BLCK_SIZE[GGML_TYPE_Q8_0];
8837
+ const size_t row_size = ne00*GGML_TYPE_SIZE[vec_dot_type]/GGML_BLCK_SIZE[vec_dot_type];
8148
8838
 
8149
8839
  for (int ir = ir0; ir < ir1; ++ir) {
8150
8840
  // src0 indices
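
Note: both row_size computations above now key off vec_dot_type rather than a
hardcoded GGML_TYPE_Q8_0: src1 rows are quantized into whichever companion
format the src0 type's dot product expects. The formula itself is unchanged;
a worked example with illustrative sizes (a 32-element block whose struct
occupies 36 bytes):

    // row_size = ne00 * GGML_TYPE_SIZE[t] / GGML_BLCK_SIZE[t]
    // e.g. ne00 = 4096, type size 36, block size 32:
    //   4096 / 32 = 128 blocks per row, 128 * 36 = 4608 bytes per row
    static size_t quantized_row_size(enum ggml_type t, int64_t ne00) {
        return (size_t) ne00 * GGML_TYPE_SIZE[t] / GGML_BLCK_SIZE[t];
    }
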
@@ -8194,7 +8884,10 @@ static void ggml_compute_forward_mul_mat(
8194
8884
  case GGML_TYPE_Q4_1:
8195
8885
  case GGML_TYPE_Q4_2:
8196
8886
  case GGML_TYPE_Q4_3:
8887
+ case GGML_TYPE_Q5_0:
8888
+ case GGML_TYPE_Q5_1:
8197
8889
  case GGML_TYPE_Q8_0:
8890
+ case GGML_TYPE_Q8_1:
8198
8891
  {
8199
8892
  ggml_compute_forward_mul_mat_q_f32(params, src0, src1, dst);
8200
8893
  } break;
@@ -8423,7 +9116,10 @@ static void ggml_compute_forward_get_rows(
8423
9116
  case GGML_TYPE_Q4_1:
8424
9117
  case GGML_TYPE_Q4_2:
8425
9118
  case GGML_TYPE_Q4_3:
9119
+ case GGML_TYPE_Q5_0:
9120
+ case GGML_TYPE_Q5_1:
8426
9121
  case GGML_TYPE_Q8_0:
9122
+ case GGML_TYPE_Q8_1:
8427
9123
  {
8428
9124
  ggml_compute_forward_get_rows_q(params, src0, src1, dst);
8429
9125
  } break;
@@ -8561,6 +9257,7 @@ static void ggml_compute_forward_soft_max_f32(
8561
9257
 
8562
9258
  uint16_t scvt;
8563
9259
  for (int i = 0; i < nc; i++) {
9260
+ //printf("p[%3d] = %8.4f\n", i, p[i]);
8564
9261
  if (p[i] == -INFINITY) {
8565
9262
  p[i] = 0.0f;
8566
9263
  } else {
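
Note: the loop above exponentiates through the 64K-entry table_exp_f16 lookup
rather than calling expf(), with -INFINITY (fully masked positions) mapping
straight to 0. A sketch of the else-branch body that follows the shown lines,
reconstructed from the surrounding declarations rather than quoted from this
hunk (max is assumed to be the row maximum computed earlier in the function):

    ggml_fp16_t s = GGML_FP32_TO_FP16(p[i] - max); // shift by row max for stability
    memcpy(&scvt, &s, sizeof(scvt));               // fp16 bit pattern as table index
    const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt]);
    sum += (ggml_float) val;
    p[i] = val;
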
@@ -10921,7 +11618,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
10921
11618
  size_t cur = 0;
10922
11619
 
10923
11620
  if (node->src0->type == GGML_TYPE_F16 && node->src1->type == GGML_TYPE_F32) {
10924
- #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
11621
+ #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
10925
11622
  if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
10926
11623
  node->n_tasks = 1; // TODO: this actually does nothing
10927
11624
  // the threads are still spinning
@@ -10938,14 +11635,15 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
10938
11635
  } else if (node->src0->type == GGML_TYPE_F32 && node->src1->type == GGML_TYPE_F32) {
10939
11636
  cur = 0;
10940
11637
  } else if (ggml_is_quantized(node->src0->type) && node->src1->type == GGML_TYPE_F32) {
10941
- #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
11638
+ #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
10942
11639
  if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
10943
11640
  node->n_tasks = 1;
10944
11641
  cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]);
10945
11642
  } else
10946
11643
  #endif
10947
11644
  {
10948
- cur = GGML_TYPE_SIZE[GGML_TYPE_Q8_0]*ggml_nelements(node->src1)/GGML_BLCK_SIZE[GGML_TYPE_Q8_0];
11645
+ const enum ggml_type type_q = quantize_fns[node->src0->type].vec_dot_type;
11646
+ cur = GGML_TYPE_SIZE[type_q]*ggml_nelements(node->src1)/GGML_BLCK_SIZE[type_q];
10949
11647
  }
10950
11648
  } else {
10951
11649
  GGML_ASSERT(false);
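
Note: the planner change mirrors the mul_mat changes above: the work buffer
for a quantized mul_mat must hold all of src1 re-quantized to the dot-product
companion type, which is no longer assumed to be Q8_0. The new Q8_1 cases in
the operator switches exist for the same reason: some formats (Q4_1, Q5_1)
dot against Q8_1 rather than Q8_0.
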
@@ -11273,9 +11971,9 @@ void ggml_graph_print(const struct ggml_cgraph * cgraph) {
11273
11971
  for (int i = 0; i < cgraph->n_nodes; i++) {
11274
11972
  struct ggml_tensor * node = cgraph->nodes[i];
11275
11973
 
11276
- perf_total_per_op_us[node->op] += node->perf_time_us;
11974
+ perf_total_per_op_us[node->op] += MAX(1, node->perf_time_us);
11277
11975
 
11278
- GGML_PRINT(" - %3d: [ %" PRId64 ", %" PRId64 ", %" PRId64 "] %16s %s (%3d) cpu = %7.3f / %7.3f ms, wall = %7.3f / %7.3f ms\n",
11976
+ GGML_PRINT(" - %3d: [ %5" PRId64 ", %5" PRId64 ", %5" PRId64 "] %16s %s (%3d) cpu = %7.3f / %7.3f ms, wall = %7.3f / %7.3f ms\n",
11279
11977
  i,
11280
11978
  node->ne[0], node->ne[1], node->ne[2],
11281
11979
  GGML_OP_LABEL[node->op], node->is_param ? "x" : node->grad ? "g" : " ", node->perf_runs,
@@ -11289,13 +11987,17 @@ void ggml_graph_print(const struct ggml_cgraph * cgraph) {
11289
11987
  for (int i = 0; i < cgraph->n_leafs; i++) {
11290
11988
  struct ggml_tensor * node = cgraph->leafs[i];
11291
11989
 
11292
- GGML_PRINT(" - %3d: [ %" PRId64 ", %" PRId64 "] %8s\n",
11990
+ GGML_PRINT(" - %3d: [ %5" PRId64 ", %5" PRId64 "] %8s\n",
11293
11991
  i,
11294
11992
  node->ne[0], node->ne[1],
11295
11993
  GGML_OP_LABEL[node->op]);
11296
11994
  }
11297
11995
 
11298
11996
  for (int i = 0; i < GGML_OP_COUNT; i++) {
11997
+ if (perf_total_per_op_us[i] == 0) {
11998
+ continue;
11999
+ }
12000
+
11299
12001
  GGML_PRINT("perf_total_per_op_us[%16s] = %7.3f ms\n", GGML_OP_LABEL[i], (double) perf_total_per_op_us[i] / 1000.0);
11300
12002
  }
11301
12003
 
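Note: the two printing changes cooperate: accumulating MAX(1, node->perf_time_us)
guarantees that any op which actually ran contributes at least one microsecond,
so the new perf_total_per_op_us[i] == 0 skip hides only ops absent from the
graph, not fast ops that measured zero.
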
@@ -12129,7 +12831,7 @@ size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t *
12129
12831
 
12130
12832
  for (int i = 0; i < nb; i++) {
12131
12833
  for (int l = 0; l < QK4_0; l += 2) {
12132
- const uint8_t vi0 = y[i].qs[l/2] & 0xF;
12834
+ const uint8_t vi0 = y[i].qs[l/2] & 0x0F;
12133
12835
  const uint8_t vi1 = y[i].qs[l/2] >> 4;
12134
12836
 
12135
12837
  hist[vi0]++;
@@ -12152,7 +12854,7 @@ size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t *
12152
12854
 
12153
12855
  for (int i = 0; i < nb; i++) {
12154
12856
  for (int l = 0; l < QK4_1; l += 2) {
12155
- const uint8_t vi0 = y[i].qs[l/2] & 0xF;
12857
+ const uint8_t vi0 = y[i].qs[l/2] & 0x0F;
12156
12858
  const uint8_t vi1 = y[i].qs[l/2] >> 4;
12157
12859
 
12158
12860
  hist[vi0]++;
@@ -12171,12 +12873,11 @@ size_t ggml_quantize_q4_2(const float * src, void * dst, int n, int k, int64_t *
12171
12873
  for (int j = 0; j < n; j += k) {
12172
12874
  block_q4_2 * restrict y = (block_q4_2 *)dst + j/QK4_2;
12173
12875
 
12174
- //quantize_row_q4_2_reference(src + j, y, k);
12175
- quantize_row_q4_2_rmse(src + j, y, k);
12876
+ quantize_row_q4_2_reference(src + j, y, k);
12176
12877
 
12177
12878
  for (int i = 0; i < nb; i++) {
12178
12879
  for (int l = 0; l < QK4_2; l += 2) {
12179
- const uint8_t vi0 = y[i].qs[l/2] & 0xF;
12880
+ const uint8_t vi0 = y[i].qs[l/2] & 0x0F;
12180
12881
  const uint8_t vi1 = y[i].qs[l/2] >> 4;
12181
12882
 
12182
12883
  hist[vi0]++;
@@ -12199,7 +12900,7 @@ size_t ggml_quantize_q4_3(const float * src, void * dst, int n, int k, int64_t *
12199
12900
 
12200
12901
  for (int i = 0; i < nb; i++) {
12201
12902
  for (int l = 0; l < QK4_3; l += 2) {
12202
- const uint8_t vi0 = y[i].qs[l/2] & 0xF;
12903
+ const uint8_t vi0 = y[i].qs[l/2] & 0x0F;
12203
12904
  const uint8_t vi1 = y[i].qs[l/2] >> 4;
12204
12905
 
12205
12906
  hist[vi0]++;
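Note: the 0xF -> 0x0F edits in these histogram loops are purely cosmetic,
while the q4_2 hunk reverts from the RMSE-minimizing quantizer to the plain
reference quantizer. For reference, the unpacking pattern all of these loops
share, two 4-bit codes per byte:

    const uint8_t byte = y[i].qs[l/2];
    const uint8_t vi0  = byte & 0x0F; // low nibble  = element l
    const uint8_t vi1  = byte >> 4;   // high nibble = element l + 1
    hist[vi0]++;                      // 16 bins, one per 4-bit code
    hist[vi1]++;
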
@@ -12211,6 +12912,87 @@ size_t ggml_quantize_q4_3(const float * src, void * dst, int n, int k, int64_t *
12211
12912
  return (n/QK4_3*sizeof(block_q4_3));
12212
12913
  }
12213
12914
 
12915
+ size_t ggml_quantize_q5_0(const float * src, void * dst, int n, int k, int64_t * hist) {
12916
+ assert(k % QK5_0 == 0);
12917
+ const int nb = k / QK5_0;
12918
+
12919
+ for (int j = 0; j < n; j += k) {
12920
+ block_q5_0 * restrict y = (block_q5_0 *)dst + j/QK5_0;
12921
+
12922
+ quantize_row_q5_0_reference(src + j, y, k);
12923
+
12924
+ for (int i = 0; i < nb; i++) {
12925
+ uint32_t qh;
12926
+ memcpy(&qh, &y[i].qh, sizeof(qh));
12927
+
12928
+ for (int l = 0; l < QK5_0; l += 2) {
12929
+ const uint8_t vh0 = ((qh & (1 << (l + 0))) >> (l + 0)) << 4;
12930
+ const uint8_t vh1 = ((qh & (1 << (l + 1))) >> (l + 1)) << 4;
12931
+
12932
+ // fold the 32 quantization levels into 16 histogram bins
12933
+ const uint8_t vi0 = ((y[i].qs[l/2] & 0x0F) | vh0) / 2;
12934
+ const uint8_t vi1 = ((y[i].qs[l/2] >> 4) | vh1) / 2;
12935
+
12936
+ hist[vi0]++;
12937
+ hist[vi1]++;
12938
+ }
12939
+ }
12940
+ }
12941
+
12942
+ return (n/QK5_0*sizeof(block_q5_0));
12943
+ }
12944
+
12945
+ size_t ggml_quantize_q5_1(const float * src, void * dst, int n, int k, int64_t * hist) {
12946
+ assert(k % QK5_1 == 0);
12947
+ const int nb = k / QK5_1;
12948
+
12949
+ for (int j = 0; j < n; j += k) {
12950
+ block_q5_1 * restrict y = (block_q5_1 *)dst + j/QK5_1;
12951
+
12952
+ quantize_row_q5_1_reference(src + j, y, k);
12953
+
12954
+ for (int i = 0; i < nb; i++) {
12955
+ uint32_t qh;
12956
+ memcpy(&qh, &y[i].qh, sizeof(qh));
12957
+
12958
+ for (int l = 0; l < QK5_1; l += 2) {
12959
+ const uint8_t vh0 = ((qh & (1 << (l + 0))) >> (l + 0)) << 4;
12960
+ const uint8_t vh1 = ((qh & (1 << (l + 1))) >> (l + 1)) << 4;
12961
+
12962
+ // fold the 32 quantization levels into 16 histogram bins
12963
+ const uint8_t vi0 = ((y[i].qs[l/2] & 0x0F) | vh0) / 2;
12964
+ const uint8_t vi1 = ((y[i].qs[l/2] >> 4) | vh1) / 2;
12965
+
12966
+ hist[vi0]++;
12967
+ hist[vi1]++;
12968
+ }
12969
+ }
12970
+ }
12971
+
12972
+ return (n/QK5_1*sizeof(block_q5_1));
12973
+ }
12974
+
12975
+ size_t ggml_quantize_q8_0(const float * src, void * dst, int n, int k, int64_t * hist) {
12976
+ assert(k % QK8_0 == 0);
12977
+ const int nb = k / QK8_0;
12978
+
12979
+ for (int j = 0; j < n; j += k) {
12980
+ block_q8_0 * restrict y = (block_q8_0 *)dst + j/QK8_0;
12981
+
12982
+ quantize_row_q8_0_reference(src + j, y, k);
12983
+
12984
+ for (int i = 0; i < nb; i++) {
12985
+ for (int l = 0; l < QK8_0; ++l) {
12986
+ const int8_t vi = y[i].qs[l];
12987
+
12988
+ hist[vi/16 + 8]++;
12989
+ }
12990
+ }
12991
+ }
12992
+
12993
+ return (n/QK8_0*sizeof(block_q8_0));
12994
+ }
12995
+
12214
12996
  size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, int start, int n, int64_t * hist) {
12215
12997
  size_t result = 0;
12216
12998
  switch (type) {
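
Note: the new 5-bit formats store the low four bits of each element in the qs
nibbles and collect the fifth bits in the 32-bit qh field, one bit per element.
The histogram loops shift each fifth bit up to bit position 4, OR it onto the
nibble, then halve the 5-bit code so every format shares the same 16 histogram
bins; ggml_quantize_q8_0 folds its signed 8-bit range into the same bins with
hist[vi/16 + 8]. A sketch of the per-element decode those loops imply
(hypothetical helper, consistent with the code above):

    // recover the j-th 5-bit code (0..31) of a Q5 block
    static inline uint8_t q5_code(const uint8_t * qs, uint32_t qh, int j) {
        const uint8_t lo = (j % 2 == 0) ? (qs[j/2] & 0x0F) // even j: low nibble
                                        : (qs[j/2] >> 4);  // odd j: high nibble
        const uint8_t hi = (uint8_t) ((qh >> j) & 1) << 4; // fifth bit from qh
        return lo | hi;
    }
    // the histogram then bins q5_code(...) / 2: 32 levels -> 16 bins
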
@@ -12238,6 +13020,24 @@ size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, i
12238
13020
  block_q4_3 * block = (block_q4_3*)dst + start / QK4_3;
12239
13021
  result = ggml_quantize_q4_3(src + start, block, n, n, hist);
12240
13022
  } break;
13023
+ case GGML_TYPE_Q5_0:
13024
+ {
13025
+ GGML_ASSERT(start % QK5_0 == 0);
13026
+ block_q5_0 * block = (block_q5_0*)dst + start / QK5_0;
13027
+ result = ggml_quantize_q5_0(src + start, block, n, n, hist);
13028
+ } break;
13029
+ case GGML_TYPE_Q5_1:
13030
+ {
13031
+ GGML_ASSERT(start % QK5_1 == 0);
13032
+ block_q5_1 * block = (block_q5_1*)dst + start / QK5_1;
13033
+ result = ggml_quantize_q5_1(src + start, block, n, n, hist);
13034
+ } break;
13035
+ case GGML_TYPE_Q8_0:
13036
+ {
13037
+ GGML_ASSERT(start % QK8_0 == 0);
13038
+ block_q8_0 * block = (block_q8_0*)dst + start / QK8_0;
13039
+ result = ggml_quantize_q8_0(src + start, block, n, n, hist);
13040
+ } break;
12241
13041
  default:
12242
13042
  assert(false);
12243
13043
  }
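
Note: ggml_quantize_chunk now covers every shipping quantization format. A
hypothetical call quantizing an entire buffer in one chunk (start = 0),
matching the signature shown above; buffer sizes are illustrative:

    float   src[4096];       // filled elsewhere
    uint8_t dst[4096];       // large enough for 4096/QK5_0 block_q5_0 structs
    int64_t hist[16] = {0};  // 16 bins, as used by the histogram loops
    const size_t bytes = ggml_quantize_chunk(GGML_TYPE_Q5_0, src, dst, 0, 4096, hist);
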
@@ -12335,7 +13135,7 @@ int ggml_cpu_has_wasm_simd(void) {
12335
13135
  }
12336
13136
 
12337
13137
  int ggml_cpu_has_blas(void) {
12338
- #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
13138
+ #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
12339
13139
  return 1;
12340
13140
  #else
12341
13141
  return 0;
@@ -12350,6 +13150,18 @@ int ggml_cpu_has_cublas(void) {
12350
13150
  #endif
12351
13151
  }
12352
13152
 
13153
+ int ggml_cpu_has_clblast(void) {
13154
+ #if defined(GGML_USE_CLBLAST)
13155
+ return 1;
13156
+ #else
13157
+ return 0;
13158
+ #endif
13159
+ }
13160
+
13161
+ int ggml_cpu_has_gpublas(void) {
13162
+ return ggml_cpu_has_cublas() || ggml_cpu_has_clblast();
13163
+ }
13164
+
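Note: ggml_cpu_has_gpublas() gives callers a single check for either GPU
backend. A hypothetical startup report using only accessors added or touched
in this diff:

    printf("BLAS = %d | cuBLAS = %d | CLBlast = %d | GPU BLAS = %d\n",
           ggml_cpu_has_blas(), ggml_cpu_has_cublas(),
           ggml_cpu_has_clblast(), ggml_cpu_has_gpublas());
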
12353
13165
  int ggml_cpu_has_sse3(void) {
12354
13166
  #if defined(__SSE3__)
12355
13167
  return 1;