llama_cpp 0.0.6 → 0.1.0

@@ -135,57 +135,14 @@ inline static void* ggml_aligned_malloc(size_t size) {
  #define UNUSED(x) (void)(x)
  #define SWAP(x, y, T) do { T SWAP = x; x = y; y = SWAP; } while (0)
 
- #define GGML_ASSERT(x) \
- do { \
- if (!(x)) { \
- fprintf(stderr, "GGML_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \
- abort(); \
- } \
- } while (0)
-
  #if defined(GGML_USE_ACCELERATE)
  #include <Accelerate/Accelerate.h>
  #elif defined(GGML_USE_OPENBLAS)
  #include <cblas.h>
  #elif defined(GGML_USE_CUBLAS)
- #include <cublas_v2.h>
- #include <cuda_runtime.h>
  #include "ggml-cuda.h"
-
- #define CUDA_CHECK(err) \
- do { \
- cudaError_t err_ = (err); \
- if (err_ != cudaSuccess) { \
- printf("CUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__, \
- cudaGetErrorString(err_)); \
- exit(1); \
- } \
- } while (0)
-
- #define CUBLAS_CHECK(err) \
- do { \
- cublasStatus_t err_ = (err); \
- if (err_ != CUBLAS_STATUS_SUCCESS) { \
- printf("cuBLAS error %d at %s:%d\n", err_, __FILE__, __LINE__); \
- exit(1); \
- } \
- } while (0)
-
- static cublasHandle_t cublasH = NULL;
- static cudaStream_t cudaStream = NULL;
- static void init_cublas(void) {
- if (cublasH == NULL) {
- // create cublas handle, bind a stream
- CUBLAS_CHECK(cublasCreate(&cublasH));
-
- CUDA_CHECK(cudaStreamCreateWithFlags(&cudaStream, cudaStreamNonBlocking));
-
- CUBLAS_CHECK(cublasSetStream(cublasH, cudaStream));
-
- // configure logging to stdout
- // CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, NULL));
- }
- }
+ #elif defined(GGML_USE_CLBLAST)
+ #include "ggml-opencl.h"
  #endif
 
  #undef MIN
@@ -223,9 +180,13 @@ typedef double ggml_float;
  #undef bool
  #define bool _Bool
  #else
+ #if defined(_MSC_VER) || defined(__MINGW32__)
+ #include <intrin.h>
+ #else
  #include <immintrin.h>
  #endif
  #endif
+ #endif
 
  #ifdef __F16C__
 
@@ -365,6 +326,20 @@ static ggml_fp16_t table_exp_f16[1 << 16];
365
326
  // precomputed f32 table for f16 (256 KB)
366
327
  static float table_f32_f16[1 << 16];
367
328
 
329
+ #if defined(__ARM_NEON) || defined(__wasm_simd128__)
330
+ #define B1(c,s,n) 0x ## n ## c , 0x ## n ## s
331
+ #define B2(c,s,n) B1(c,s,n ## c), B1(c,s,n ## s)
332
+ #define B3(c,s,n) B2(c,s,n ## c), B2(c,s,n ## s)
333
+ #define B4(c,s,n) B3(c,s,n ## c), B3(c,s,n ## s)
334
+ #define B5(c,s,n) B4(c,s,n ## c), B4(c,s,n ## s)
335
+ #define B6(c,s,n) B5(c,s,n ## c), B5(c,s,n ## s)
336
+ #define B7(c,s,n) B6(c,s,n ## c), B6(c,s,n ## s)
337
+ #define B8(c,s ) B7(c,s, c), B7(c,s, s)
338
+
339
+ // precomputed tables for expanding 8bits to 8 bytes (shl 4)
340
+ static const uint64_t table_b2b_u[1 << 8] = { B8(00, 10) };
341
+ #endif
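Editor's note: the B1..B8 macros above expand, at preprocessing time, into a 256-entry table. If the token-pasting order is read as described here, bit k of the table index becomes byte k of the 64-bit entry, holding 0x10 when the bit is set and 0x00 otherwise (the "shl 4" in the comment). A minimal stand-alone sketch of the same expansion, using a runtime loop instead of token pasting (illustrative only, not code from the diff):

    #include <stdint.h>
    #include <stdio.h>

    // Expand one byte to eight bytes: bit k of b becomes byte k of the result,
    // already shifted left by four (0x10 per set bit), as table_b2b_u does.
    static uint64_t expand_byte_shl4(uint8_t b) {
        uint64_t r = 0;
        for (int k = 0; k < 8; k++) {
            if (b & (1u << k)) {
                r |= (uint64_t)0x10 << (8 * k);
            }
        }
        return r;
    }

    int main(void) {
        // 0xB2 = 1011 0010b: bits 1, 4, 5 and 7 are set, so those bytes become 0x10
        printf("%016llx\n", (unsigned long long)expand_byte_shl4(0xB2));
        return 0;
    }

If that bit/byte ordering is right, the printed value should match table_b2b_u[0xB2].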
342
+
368
343
  // On ARM NEON, it's quicker to directly convert x -> x instead of calling into ggml_lookup_fp16_to_fp32,
369
344
  // so we define GGML_FP16_TO_FP32 and GGML_FP32_TO_FP16 elsewhere for NEON.
370
345
  // This is also true for POWER9.
@@ -391,6 +366,32 @@ ggml_fp16_t ggml_fp32_to_fp16(float x) {
391
366
  return GGML_FP32_TO_FP16(x);
392
367
  }
393
368
 
369
+ void ggml_fp16_to_fp32_row(const ggml_fp16_t * x, float * y, size_t n) {
370
+ for (size_t i = 0; i < n; i++) {
371
+ y[i] = GGML_FP16_TO_FP32(x[i]);
372
+ }
373
+ }
374
+
375
+ void ggml_fp32_to_fp16_row(const float * x, ggml_fp16_t * y, size_t n) {
376
+ size_t i = 0;
377
+ #if defined(__F16C__)
378
+ for (; i + 7 < n; i += 8) {
379
+ __m256 x_vec = _mm256_loadu_ps(x + i);
380
+ __m128i y_vec = _mm256_cvtps_ph(x_vec, _MM_FROUND_TO_NEAREST_INT);
381
+ _mm_storeu_si128((__m128i *)(y + i), y_vec);
382
+ }
383
+ for(; i + 3 < n; i += 4) {
384
+ __m128 x_vec = _mm_loadu_ps(x + i);
385
+ __m128i y_vec = _mm_cvtps_ph(x_vec, _MM_FROUND_TO_NEAREST_INT);
386
+ _mm_storel_epi64((__m128i *)(y + i), y_vec);
387
+ }
388
+ #endif
389
+ for (; i < n; i++) {
390
+ y[i] = GGML_FP32_TO_FP16(x[i]);
391
+ }
392
+ }
393
+
394
+
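Editor's note: ggml_fp32_to_fp16_row converts eight floats at a time with F16C (then four, then a scalar tail), while ggml_fp16_to_fp32_row is a plain scalar loop. A hedged usage sketch, assuming these two helpers and the ggml_fp16_t typedef are exposed through ggml.h:

    #include <stdio.h>
    #include "ggml.h"   // assumed to declare ggml_fp16_t and the *_row converters

    int main(void) {
        float src[8]  = { 0.0f, 0.5f, 1.0f, -2.0f, 3.25f, -4.5f, 65504.0f, 1e-4f };
        ggml_fp16_t h[8];
        float back[8];

        ggml_fp32_to_fp16_row(src, h, 8);   // fp32 -> fp16 (F16C path when available)
        ggml_fp16_to_fp32_row(h, back, 8);  // fp16 -> fp32

        for (int i = 0; i < 8; i++) {
            printf("%g -> %g\n", src[i], back[i]);  // small values lose precision
        }
        return 0;
    }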
394
395
  //
395
396
  // timing
396
397
  //
@@ -473,7 +474,7 @@ static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
473
474
  static inline __m128i bytes_from_nibbles_16(const uint8_t * rsi)
474
475
  {
475
476
  // Load 8 bytes from memory
476
- __m128i tmp = _mm_loadu_si64( ( const __m128i* )rsi );
477
+ __m128i tmp = _mm_loadl_epi64( ( const __m128i* )rsi );
477
478
 
478
479
  // Expand bytes into uint16_t values
479
480
  __m128i bytes = _mm_cvtepu8_epi16( tmp );
@@ -487,7 +488,46 @@ static inline __m128i bytes_from_nibbles_16(const uint8_t * rsi)
487
488
  return bytes;
488
489
  }
489
490
 
491
+ // horizontally add 8 floats
492
+ static inline float hsum_float_8(const __m256 x) {
493
+ __m128 res = _mm256_extractf128_ps(x, 1);
494
+ res = _mm_add_ps(res, _mm256_castps256_ps128(x));
495
+ res = _mm_add_ps(res, _mm_movehl_ps(res, res));
496
+ res = _mm_add_ss(res, _mm_movehdup_ps(res));
497
+ return _mm_cvtss_f32(res);
498
+ }
499
+
500
+ // horizontally add 8 int32_t
501
+ static inline int hsum_i32_8(const __m256i a) {
502
+ const __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(a), _mm256_extractf128_si256(a, 1));
503
+ const __m128i hi64 = _mm_unpackhi_epi64(sum128, sum128);
504
+ const __m128i sum64 = _mm_add_epi32(hi64, sum128);
505
+ const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
506
+ return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
507
+ }
508
+
509
+ // horizontally add 4 int32_t
510
+ static inline int hsum_i32_4(const __m128i a) {
511
+ const __m128i hi64 = _mm_unpackhi_epi64(a, a);
512
+ const __m128i sum64 = _mm_add_epi32(hi64, a);
513
+ const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
514
+ return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
515
+ }
516
+
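Editor's note: hsum_float_8, hsum_i32_8 and hsum_i32_4 all reduce a register to one scalar by repeatedly adding its upper half onto its lower half (extract, movehl, movehdup), so an 8-wide sum takes three vector additions instead of seven scalar ones. The same folding written out in plain C, as a reference model rather than a drop-in replacement:

    #include <stdio.h>

    // Fold-based horizontal sum: add the top half onto the bottom half until
    // one element is left; for n = 8 that is log2(8) = 3 rounds of additions.
    static float hsum8_by_folding(float v[8]) {
        for (int width = 4; width >= 1; width /= 2) {
            for (int i = 0; i < width; i++) {
                v[i] += v[i + width];
            }
        }
        return v[0];
    }

    int main(void) {
        float v[8] = { 1, 2, 3, 4, 5, 6, 7, 8 };
        printf("%g\n", hsum8_by_folding(v));   // prints 36
        return 0;
    }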
490
517
  #if __AVX2__ || __AVX512F__
518
+ // spread 32 bits to 32 bytes { 0x00, 0xFF }
519
+ static inline __m256i bytes_from_bits_32(const uint8_t * x) {
520
+ uint32_t x32;
521
+ memcpy(&x32, x, sizeof(uint32_t));
522
+ const __m256i shuf_mask = _mm256_set_epi64x(
523
+ 0x0303030303030303, 0x0202020202020202,
524
+ 0x0101010101010101, 0x0000000000000000);
525
+ __m256i bytes = _mm256_shuffle_epi8(_mm256_set1_epi32(x32), shuf_mask);
526
+ const __m256i bit_mask = _mm256_set1_epi64x(0x7fbfdfeff7fbfdfe);
527
+ bytes = _mm256_or_si256(bytes, bit_mask);
528
+ return _mm256_cmpeq_epi8(bytes, _mm256_set1_epi64x(-1));
529
+ }
530
+
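Editor's note: bytes_from_bits_32 spreads 32 packed bits into 32 bytes of 0x00/0xFF: the shuffle copies each source byte across eight result bytes, the OR mask fills in every bit except a different one per byte, and the compare with all-ones then keys off that one remaining bit. In scalar terms the intended result is simply (reference sketch, not from the diff):

    #include <stdint.h>
    #include <string.h>

    // Scalar model of bytes_from_bits_32: out[k] is 0xFF when bit k of the
    // 32-bit word stored at x is set, 0x00 otherwise.
    static void bytes_from_bits_32_ref(const uint8_t * x, uint8_t out[32]) {
        uint32_t x32;
        memcpy(&x32, x, sizeof(uint32_t));
        for (int k = 0; k < 32; k++) {
            out[k] = ((x32 >> k) & 1) ? 0xFF : 0x00;
        }
    }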
491
531
  // Unpack 32 4-bit fields into 32 bytes
492
532
  // The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
493
533
  static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)
@@ -507,9 +547,38 @@ static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)
507
547
  return bytes;
508
548
  }
509
549
 
550
+ // add int16_t pairwise and return as float vector
551
+ static inline __m256 sum_i16_pairs_float(const __m256i x) {
552
+ const __m256i ones = _mm256_set1_epi16(1);
553
+ const __m256i summed_pairs = _mm256_madd_epi16(ones, x);
554
+ return _mm256_cvtepi32_ps(summed_pairs);
555
+ }
556
+
557
+ // multiply int8_t, add results pairwise twice and return as float vector
558
+ static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
559
+ // Get absolute values of x vectors
560
+ const __m256i ax = _mm256_sign_epi8(x, x);
561
+ // Sign the values of the y vectors
562
+ const __m256i sy = _mm256_sign_epi8(y, x);
563
+ #if __AVXVNNI__
564
+ const __m256i zero = _mm256_setzero_si256();
565
+ const __m256i summed_pairs = _mm256_dpbusd_epi32(zero, ax, sy);
566
+ return _mm256_cvtepi32_ps(summed_pairs);
567
+ #else
568
+ // Perform multiplication and create 16-bit values
569
+ const __m256i dot = _mm256_maddubs_epi16(ax, sy);
570
+ return sum_i16_pairs_float(dot);
571
+ #endif
572
+ }
573
+
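Editor's note: mul_sum_i8_pairs_float produces eight partial dot products of signed bytes. Because _mm256_maddubs_epi16 (and _mm256_dpbusd_epi32 on the AVX-VNNI path) wants an unsigned first operand, the code passes |x| and transfers the sign of x onto y first; |x[i]| * (sign(x[i]) * y[i]) equals x[i] * y[i], so nothing changes numerically. A scalar model of the helper (illustrative):

    #include <stdint.h>

    // Scalar model of mul_sum_i8_pairs_float: every group of four adjacent
    // byte products is accumulated into one lane and returned as float.
    static void mul_sum_i8_pairs_ref(const int8_t x[32], const int8_t y[32], float out[8]) {
        for (int g = 0; g < 8; g++) {
            int32_t acc = 0;
            for (int k = 0; k < 4; k++) {
                acc += (int32_t)x[4*g + k] * (int32_t)y[4*g + k];
            }
            out[g] = (float)acc;
        }
    }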
510
574
  static inline __m128i packNibbles( __m256i bytes )
511
575
  {
512
576
  // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
577
+ #if __AVX512F__
578
+ const __m256i bytes_srli_4 = _mm256_srli_epi16(bytes, 4); // 0000_0000_abcd_0000
579
+ bytes = _mm256_or_si256(bytes, bytes_srli_4); // 0000_abcd_abcd_efgh
580
+ return _mm256_cvtepi16_epi8(bytes); // abcd_efgh
581
+ #else
513
582
  const __m256i lowByte = _mm256_set1_epi16( 0xFF );
514
583
  __m256i high = _mm256_andnot_si256( lowByte, bytes );
515
584
  __m256i low = _mm256_and_si256( lowByte, bytes );
@@ -520,6 +589,7 @@ static inline __m128i packNibbles( __m256i bytes )
520
589
  __m128i r0 = _mm256_castsi256_si128( bytes );
521
590
  __m128i r1 = _mm256_extracti128_si256( bytes, 1 );
522
591
  return _mm_packus_epi16( r0, r1 );
592
+ #endif
523
593
  }
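Editor's note: packNibbles squeezes 32 byte-sized values in [0..15] back into 16 nibble-packed bytes, with the even element of each pair in the low nibble. The new __AVX512F__ branch gets there by OR-ing a 4-bit right shift into place and truncating each 16-bit lane with _mm256_cvtepi16_epi8. Per pair the operation is just (scalar sketch, not from the diff):

    #include <stdint.h>

    // Scalar model of packNibbles: in[] holds 32 values in [0..15], out[] gets
    // 16 bytes with in[2*j] in the low nibble and in[2*j+1] in the high nibble.
    static void pack_nibbles_ref(const uint8_t in[32], uint8_t out[16]) {
        for (int j = 0; j < 16; j++) {
            out[j] = (uint8_t)((in[2*j] & 0x0F) | ((in[2*j + 1] & 0x0F) << 4));
        }
    }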
524
594
  #else
525
595
  static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
@@ -605,19 +675,102 @@ float vmaxvq_f32(float32x4_t v) {
605
675
  }
606
676
 
607
677
  int8x8_t vzip1_s8(int8x8_t a, int8x8_t b) {
608
- return vget_low_s8(vcombine_s8(a, b));
678
+ int8x8_t res;
679
+
680
+ res[0] = a[0]; res[1] = b[0];
681
+ res[2] = a[1]; res[3] = b[1];
682
+ res[4] = a[2]; res[5] = b[2];
683
+ res[6] = a[3]; res[7] = b[3];
684
+
685
+ return res;
609
686
  }
610
687
 
611
688
  int8x8_t vzip2_s8(int8x8_t a, int8x8_t b) {
612
- return vget_high_s8(vcombine_s8(a, b));
689
+ int8x8_t res;
690
+
691
+ res[0] = a[4]; res[1] = b[4];
692
+ res[2] = a[5]; res[3] = b[5];
693
+ res[4] = a[6]; res[5] = b[6];
694
+ res[6] = a[7]; res[7] = b[7];
695
+
696
+ return res;
613
697
  }
614
698
 
615
699
  uint8x8_t vzip1_u8(uint8x8_t a, uint8x8_t b) {
616
- return vget_low_u8(vcombine_u8(a, b));
700
+ uint8x8_t res;
701
+
702
+ res[0] = a[0]; res[1] = b[0];
703
+ res[2] = a[1]; res[3] = b[1];
704
+ res[4] = a[2]; res[5] = b[2];
705
+ res[6] = a[3]; res[7] = b[3];
706
+
707
+ return res;
617
708
  }
618
709
 
619
710
  uint8x8_t vzip2_u8(uint8x8_t a, uint8x8_t b) {
620
- return vget_high_u8(vcombine_u8(a, b));
711
+ uint8x8_t res;
712
+
713
+ res[0] = a[4]; res[1] = b[4];
714
+ res[2] = a[5]; res[3] = b[5];
715
+ res[4] = a[6]; res[5] = b[6];
716
+ res[6] = a[7]; res[7] = b[7];
717
+
718
+ return res;
719
+ }
720
+
721
+ int8x16_t vzip1q_s8(int8x16_t a, int8x16_t b) {
722
+ int8x16_t res;
723
+
724
+ res[0] = a[0]; res[1] = b[0]; res[2] = a[1]; res[3] = b[1];
725
+ res[4] = a[2]; res[5] = b[2]; res[6] = a[3]; res[7] = b[3];
726
+ res[8] = a[4]; res[9] = b[4]; res[10] = a[5]; res[11] = b[5];
727
+ res[12] = a[6]; res[13] = b[6]; res[14] = a[7]; res[15] = b[7];
728
+
729
+ return res;
730
+ }
731
+
732
+ int8x16_t vzip2q_s8(int8x16_t a, int8x16_t b) {
733
+ int8x16_t res;
734
+
735
+ res[0] = a[8]; res[1] = b[8]; res[2] = a[9]; res[3] = b[9];
736
+ res[4] = a[10]; res[5] = b[10]; res[6] = a[11]; res[7] = b[11];
737
+ res[8] = a[12]; res[9] = b[12]; res[10] = a[13]; res[11] = b[13];
738
+ res[12] = a[14]; res[13] = b[14]; res[14] = a[15]; res[15] = b[15];
739
+
740
+ return res;
741
+ }
742
+
743
+ uint8x16_t vzip1q_u8(uint8x16_t a, uint8x16_t b) {
744
+ uint8x16_t res;
745
+
746
+ res[0] = a[0]; res[1] = b[0]; res[2] = a[1]; res[3] = b[1];
747
+ res[4] = a[2]; res[5] = b[2]; res[6] = a[3]; res[7] = b[3];
748
+ res[8] = a[4]; res[9] = b[4]; res[10] = a[5]; res[11] = b[5];
749
+ res[12] = a[6]; res[13] = b[6]; res[14] = a[7]; res[15] = b[7];
750
+
751
+ return res;
752
+ }
753
+
754
+ uint8x16_t vzip2q_u8(uint8x16_t a, uint8x16_t b) {
755
+ uint8x16_t res;
756
+
757
+ res[0] = a[8]; res[1] = b[8]; res[2] = a[9]; res[3] = b[9];
758
+ res[4] = a[10]; res[5] = b[10]; res[6] = a[11]; res[7] = b[11];
759
+ res[8] = a[12]; res[9] = b[12]; res[10] = a[13]; res[11] = b[13];
760
+ res[12] = a[14]; res[13] = b[14]; res[14] = a[15]; res[15] = b[15];
761
+
762
+ return res;
763
+ }
764
+
765
+ int32x4_t vcvtnq_s32_f32(float32x4_t v) {
766
+ int32x4_t res;
767
+
768
+ res[0] = roundf(vgetq_lane_f32(v, 0));
769
+ res[1] = roundf(vgetq_lane_f32(v, 1));
770
+ res[2] = roundf(vgetq_lane_f32(v, 2));
771
+ res[3] = roundf(vgetq_lane_f32(v, 3));
772
+
773
+ return res;
621
774
  }
622
775
 
623
776
  #endif
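Editor's note: the functions above fill in NEON operations that otherwise only exist on AArch64 (the vzip1/vzip2 family and vcvtnq_s32_f32) using per-element indexing, presumably so the same kernels can build where only the 32-bit NEON subset is available. The zip semantics are a plain interleave of the two inputs' low or high halves; written out on ordinary arrays (illustration only):

    #include <stdint.h>
    #include <stdio.h>

    // Scalar picture of vzip1q_s8 / vzip2q_s8: interleave the low halves
    // (zip1) or the high halves (zip2) of two 16-byte vectors.
    static void zip1_s8_ref(const int8_t a[16], const int8_t b[16], int8_t out[16]) {
        for (int i = 0; i < 8; i++) { out[2*i] = a[i];     out[2*i + 1] = b[i];     }
    }
    static void zip2_s8_ref(const int8_t a[16], const int8_t b[16], int8_t out[16]) {
        for (int i = 0; i < 8; i++) { out[2*i] = a[8 + i]; out[2*i + 1] = b[8 + i]; }
    }

    int main(void) {
        int8_t a[16], b[16], lo[16], hi[16];
        for (int i = 0; i < 16; i++) { a[i] = (int8_t)i; b[i] = (int8_t)(100 + i); }
        zip1_s8_ref(a, b, lo);   // 0,100,1,101,...,7,107
        zip2_s8_ref(a, b, hi);   // 8,108,9,109,...,15,115
        printf("%d %d %d %d\n", lo[0], lo[1], hi[0], hi[1]);
        return 0;
    }

One subtlety worth noting: the vcvtnq_s32_f32 stand-in rounds with roundf(), i.e. halves away from zero, while the AArch64 instruction rounds halves to even; the two only differ on exact .5 inputs.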
@@ -646,13 +799,22 @@ typedef struct {
646
799
  } block_q4_2;
647
800
  static_assert(sizeof(block_q4_2) == sizeof(ggml_fp16_t) + QK4_2 / 2, "wrong q4_2 block size/padding");
648
801
 
649
- #define QK4_3 16
802
+ #define QK5_0 32
803
+ typedef struct {
804
+ ggml_fp16_t d; // delta
805
+ uint8_t qh[4]; // 5-th bit of quants
806
+ uint8_t qs[QK5_0 / 2]; // nibbles / quants
807
+ } block_q5_0;
808
+ static_assert(sizeof(block_q5_0) == sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_0 / 2, "wrong q5_0 block size/padding");
809
+
810
+ #define QK5_1 32
650
811
  typedef struct {
651
812
  ggml_fp16_t d; // delta
652
813
  ggml_fp16_t m; // min
653
- uint8_t qs[QK4_3 / 2]; // nibbles / quants
654
- } block_q4_3;
655
- static_assert(sizeof(block_q4_3) == 2 * sizeof(ggml_fp16_t) + QK4_3 / 2, "wrong q4_3 block size/padding");
814
+ uint8_t qh[4]; // 5-th bit of quants
815
+ uint8_t qs[QK5_1 / 2]; // nibbles / quants
816
+ } block_q5_1;
817
+ static_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_1 / 2, "wrong q5_1 block size/padding");
656
818
 
657
819
  #define QK8_0 32
658
820
  typedef struct {
@@ -661,6 +823,14 @@ typedef struct {
661
823
  } block_q8_0;
662
824
  static_assert(sizeof(block_q8_0) == sizeof(float) + QK8_0, "wrong q8_0 block size/padding");
663
825
 
826
+ #define QK8_1 32
827
+ typedef struct {
828
+ float d; // delta
829
+ float s0; // d * sum(qs[i]) low
830
+ float s1; // d * sum(qs[i]) high
831
+ int8_t qs[QK8_1]; // quants
832
+ } block_q8_1;
833
+ static_assert(sizeof(block_q8_1) == 3*sizeof(float) + QK8_1, "wrong q8_1 block size/padding");
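Editor's note: block_q8_1 differs from block_q8_0 only by the cached s0 and s1, the scale times the sum of the first and second sixteen quants (per the comments above). The payoff is in the dot products against the offset formats Q4_1/Q5_1 declared later in this diff: the term contributed by the block minimum m factors out as m * (s0 + s1), so it does not have to be recomputed per block. A hedged scalar sketch of that algebra, using hypothetical mirror structs with the quants already unpacked:

    #include <stdint.h>

    // Hypothetical mirrors of the real blocks, just to show the algebra;
    // q4 quants are assumed unpacked to one byte each (0..15).
    typedef struct { float d; float m;            uint8_t q[32]; } mirror_q4_1;
    typedef struct { float d; float s0; float s1; int8_t  q[32]; } mirror_q8_1;

    // One block of  sum_j (m + d4*q4[j]) * (d8*q8[j])
    //             = d4*d8 * sum_j q4[j]*q8[j]  +  m * (d8 * sum_j q8[j]),
    // and the second factor is exactly s0 + s1, cached by quantize_row_q8_1.
    static float dot_block_q4_1_q8_1(const mirror_q4_1 * a, const mirror_q8_1 * b) {
        int32_t sumi = 0;
        for (int j = 0; j < 32; j++) {
            sumi += (int32_t)a->q[j] * (int32_t)b->q[j];
        }
        return a->d * b->d * (float)sumi + a->m * (b->s0 + b->s1);
    }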
664
834
 
665
835
  // reference implementation for deterministic creation of model files
666
836
  static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k) {
@@ -671,13 +841,17 @@ static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * r
671
841
 
672
842
  for (int i = 0; i < nb; i++) {
673
843
  float amax = 0.0f; // absolute max
844
+ float max = 0.0f;
674
845
 
675
846
  for (int l = 0; l < QK4_0; l++) {
676
847
  const float v = x[i*QK4_0 + l];
677
- amax = MAX(amax, fabsf(v));
848
+ if (amax < fabsf(v)) {
849
+ amax = fabsf(v);
850
+ max = v;
851
+ }
678
852
  }
679
853
 
680
- const float d = amax / ((1 << 3) - 1);
854
+ const float d = max / -8;
681
855
  const float id = d ? 1.0f/d : 0.0f;
682
856
 
683
857
  y[i].d = d;
@@ -686,8 +860,8 @@ static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * r
686
860
  const float v0 = x[i*QK4_0 + l + 0]*id;
687
861
  const float v1 = x[i*QK4_0 + l + 1]*id;
688
862
 
689
- const uint8_t vi0 = (int8_t)roundf(v0) + 8;
690
- const uint8_t vi1 = (int8_t)roundf(v1) + 8;
863
+ const uint8_t vi0 = MIN(15, (int8_t)roundf(v0) + 8);
864
+ const uint8_t vi1 = MIN(15, (int8_t)roundf(v1) + 8);
691
865
 
692
866
  assert(vi0 < 16);
693
867
  assert(vi1 < 16);
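Editor's note: the reference quantizer now keys the scale off the signed value with the largest magnitude: d = max / -8 sends that extreme value exactly to the code -8, every other value lands inside [-8, +8], and after the +8 offset the single value that could round to 16 is clamped to 15 by the MIN above. A worked scalar example of the mapping (a sketch mirroring the loop above on four values, not a call into ggml):

    #include <math.h>
    #include <stdint.h>
    #include <stdio.h>

    #define MIN(a, b) ((a) < (b) ? (a) : (b))

    // Worked example of the q4_0 mapping: the largest-magnitude value becomes
    // code 0 (i.e. -8 before the offset) and dequantizes back exactly.
    int main(void) {
        const float x[4] = { 2.0f, -1.0f, 0.5f, -1.5f };

        float amax = 0.0f, max = 0.0f;
        for (int i = 0; i < 4; i++) {
            if (amax < fabsf(x[i])) { amax = fabsf(x[i]); max = x[i]; }
        }

        const float d  = max / -8;            // -0.25
        const float id = d ? 1.0f/d : 0.0f;   // -4

        for (int i = 0; i < 4; i++) {
            const uint8_t q = MIN(15, (int8_t)roundf(x[i]*id) + 8);
            printf("%5.2f -> code %2u -> %5.2f\n", x[i], q, (q - 8)*d);
        }
        return 0;
    }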
@@ -707,28 +881,43 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
707
881
 
708
882
  #if defined(__POWER9_VECTOR__)
709
883
  const vector float v85 = vec_splats(8.5f);
884
+ const vector signed int v15 = vec_splats(15);
710
885
  for (int i = 0; i < nb; i++) {
711
- float amax = 0.0f; // absolute max
886
+ float max = 0.0f;
887
+ float min = 0.0f;
712
888
 
889
+ vector float asrcv [8];
713
890
  vector float srcv [8];
714
- vector float asrcv[8];
715
- vector float amaxv[8];
891
+ vector float maxv[8];
892
+ vector float minv[8];
716
893
 
717
894
  for (int l = 0; l < 8; l++) srcv[l] = *(vector float *)(x + i*32 + 4*l);
718
- for (int l = 0; l < 8; l++) asrcv[l] = vec_abs(srcv[l]);
719
-
720
- for (int l = 0; l < 4; l++) amaxv[2*l] = vec_max(asrcv[2*l], asrcv[2*l+1]);
721
- //for (int l = 0; l < 2; l++) amaxv[4*l] = vec_max(amaxv[4*l], amaxv[4*l+2]);
722
- amaxv[0] = vec_max(amaxv[0], amaxv[2]);
723
- amaxv[4] = vec_max(amaxv[4], amaxv[6]);
724
- //for (int l = 0; l < 1; l++) amaxv[8*l] = vec_max(amaxv[8*l], amaxv[8*l+4]);
725
- amaxv[0] = vec_max(amaxv[0], amaxv[4]);
726
-
727
- amax = MAX(
728
- MAX(vec_extract(amaxv[0], 0), vec_extract(amaxv[0], 1)),
729
- MAX(vec_extract(amaxv[0], 2), vec_extract(amaxv[0], 3)));
730
-
731
- const float d = amax / ((1 << 3) - 1);
895
+ //for (int l = 0; l < 8; l++) asrcv[l] = vec_abs(srcv[l]);
896
+
897
+ for (int l = 0; l < 4; l++) maxv[2*l] = vec_max(asrcv[2*l], asrcv[2*l+1]);
898
+ //for (int l = 0; l < 2; l++) maxv[4*l] = vec_max(maxv[4*l], maxv[4*l+2]);
899
+ maxv[0] = vec_max(maxv[0], maxv[2]);
900
+ maxv[4] = vec_max(maxv[4], maxv[6]);
901
+ //for (int l = 0; l < 1; l++) maxv[8*l] = vec_max(maxv[8*l], maxv[8*l+4]);
902
+ maxv[0] = vec_max(maxv[0], maxv[4]);
903
+
904
+ for (int l = 0; l < 4; l++) minv[2*l] = vec_min(asrcv[2*l], asrcv[2*l+1]);
905
+ //for (int l = 0; l < 2; l++) minv[4*l] = vec_min(minv[4*l], minv[4*l+2]);
906
+ minv[0] = vec_min(minv[0], minv[2]);
907
+ minv[4] = vec_min(minv[4], minv[6]);
908
+ //for (int l = 0; l < 1; l++) minv[8*l] = vec_min(minv[8*l], minv[8*l+4]);
909
+ minv[0] = vec_min(minv[0], minv[4]);
910
+
911
+
912
+ max = MAX(
913
+ MAX(vec_extract(maxv[0], 0), vec_extract(maxv[0], 1)),
914
+ MAX(vec_extract(maxv[0], 2), vec_extract(maxv[0], 3)));
915
+ min = MIN(
916
+ MIN(vec_extract(minv[0], 0), vec_extract(minv[0], 1)),
917
+ MIN(vec_extract(minv[0], 2), vec_extract(minv[0], 3)));
918
+
919
+ const float magnitude = max >= fabsf(min) ? max : min;
920
+ const float d = magnitude / -8;
732
921
  const float id = d ? 1.0/d : 0.0;
733
922
 
734
923
  y[i].d = d;
@@ -738,27 +927,33 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
738
927
  for (int l = 0; l < 8; l++) {
739
928
  const vector float vf = vec_madd(srcv[l], vid, v85);
740
929
  const vector signed int vi = vec_signed(vf);
930
+ const vector signed int vc = vec_min(vi, v15);
741
931
 
742
- pb[2*l + 0] = vec_extract(vi, 0) | (vec_extract(vi, 1) << 4);
743
- pb[2*l + 1] = vec_extract(vi, 2) | (vec_extract(vi, 3) << 4);
932
+ pb[2*l + 0] = vec_extract(vc, 0) | (vec_extract(vc, 1) << 4);
933
+ pb[2*l + 1] = vec_extract(vc, 2) | (vec_extract(vc, 3) << 4);
744
934
  }
745
935
  }
746
936
  #elif __ARM_NEON
747
937
  for (int i = 0; i < nb; i++) {
748
938
  float32x4_t srcv [8];
749
- float32x4_t asrcv[8];
750
- float32x4_t amaxv[8];
939
+ float32x4_t maxv[8];
940
+ float32x4_t minv[8];
751
941
 
752
942
  for (int l = 0; l < 8; l++) srcv[l] = vld1q_f32(x + i*32 + 4*l);
753
- for (int l = 0; l < 8; l++) asrcv[l] = vabsq_f32(srcv[l]);
754
943
 
755
- for (int l = 0; l < 4; l++) amaxv[2*l] = vmaxq_f32(asrcv[2*l], asrcv[2*l+1]);
756
- for (int l = 0; l < 2; l++) amaxv[4*l] = vmaxq_f32(amaxv[4*l], amaxv[4*l+2]);
757
- for (int l = 0; l < 1; l++) amaxv[8*l] = vmaxq_f32(amaxv[8*l], amaxv[8*l+4]);
944
+ for (int l = 0; l < 4; l++) maxv[2*l] = vmaxq_f32(srcv[2*l], srcv[2*l+1]);
945
+ for (int l = 0; l < 2; l++) maxv[4*l] = vmaxq_f32(maxv[4*l], maxv[4*l+2]);
946
+ for (int l = 0; l < 1; l++) maxv[8*l] = vmaxq_f32(maxv[8*l], maxv[8*l+4]);
758
947
 
759
- const float amax = vmaxvq_f32(amaxv[0]);
948
+ for (int l = 0; l < 4; l++) minv[2*l] = vminq_f32(srcv[2*l], srcv[2*l+1]);
949
+ for (int l = 0; l < 2; l++) minv[4*l] = vminq_f32(minv[4*l], minv[4*l+2]);
950
+ for (int l = 0; l < 1; l++) minv[8*l] = vminq_f32(minv[8*l], minv[8*l+4]);
951
+
952
+ const float max = vmaxvq_f32(maxv[0]);
953
+ const float min = vminvq_f32(minv[0]);
760
954
 
761
- const float d = amax / ((1 << 3) - 1);
955
+ const float magnitude = max >= fabsf(min) ? max : min;
956
+ const float d = magnitude / -8;
762
957
  const float id = d ? 1.0f/d : 0.0f;
763
958
 
764
959
  y[i].d = d;
@@ -767,9 +962,10 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
767
962
  const float32x4_t v = vmulq_n_f32(srcv[l], id);
768
963
  const float32x4_t vf = vaddq_f32(v, vdupq_n_f32(8.5f));
769
964
  const int32x4_t vi = vcvtq_s32_f32(vf);
965
+ const int32x4_t vc = vminq_s32(vi, vdupq_n_s32(15));
770
966
 
771
- y[i].qs[2*l + 0] = vgetq_lane_s32(vi, 0) | (vgetq_lane_s32(vi, 1) << 4);
772
- y[i].qs[2*l + 1] = vgetq_lane_s32(vi, 2) | (vgetq_lane_s32(vi, 3) << 4);
967
+ y[i].qs[2*l + 0] = vgetq_lane_s32(vc, 0) | (vgetq_lane_s32(vc, 1) << 4);
968
+ y[i].qs[2*l + 1] = vgetq_lane_s32(vc, 2) | (vgetq_lane_s32(vc, 3) << 4);
773
969
  }
774
970
  }
775
971
  #elif defined(__AVX2__)
@@ -781,22 +977,31 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
781
977
  __m256 v3 = _mm256_loadu_ps( x + 24 );
782
978
  x += 32;
783
979
 
784
- // Compute max(abs(e)) for the block
785
- const __m256 signBit = _mm256_set1_ps( -0.0f );
786
- __m256 maxAbs = _mm256_andnot_ps( signBit, v0 );
787
- maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) );
788
- maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );
789
- maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) );
980
+ // Compute max for the block
981
+ __m256 max = _mm256_max_ps( v0, v1 );
982
+ __m256 maxTmp = _mm256_max_ps( v2, v3 );
983
+ max = _mm256_max_ps( max, maxTmp );
790
984
 
791
- __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
985
+ __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( max, 1 ), _mm256_castps256_ps128( max ) );
792
986
  max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
793
987
  max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
794
988
  const float maxScalar = _mm_cvtss_f32( max4 );
795
989
 
990
+ // Compute min for the block
991
+ __m256 min = _mm256_min_ps( v0, v1 );
992
+ __m256 minTmp = _mm256_min_ps( v2, v3 );
993
+ min = _mm256_min_ps( min, minTmp );
994
+
995
+ __m128 min4 = _mm_min_ps( _mm256_extractf128_ps( min, 1 ), _mm256_castps256_ps128( min ) );
996
+ min4 = _mm_min_ps( min4, _mm_movehl_ps( min4, min4 ) );
997
+ min4 = _mm_min_ss( min4, _mm_movehdup_ps( min4 ) );
998
+ const float minScalar = _mm_cvtss_f32( min4 );
999
+
796
1000
  // Quantize these floats
797
- const float d = maxScalar / 7.0f;
1001
+ const float magnitude = maxScalar >= fabsf(minScalar) ? maxScalar : minScalar;
1002
+ const float d = magnitude / -8.0f;
798
1003
  y[i].d = d;
799
- const float id = ( maxScalar != 0.0f ) ? 7.0f / maxScalar : 0.0f;
1004
+ const float id = ( magnitude != 0.0f ) ? -8.0f / magnitude : 0.0f;
800
1005
  const __m256 mul = _mm256_set1_ps( id );
801
1006
 
802
1007
  // Apply the multiplier
@@ -829,9 +1034,11 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
829
1034
  const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
830
1035
  i0 = _mm256_permutevar8x32_epi32( i0, perm );
831
1036
 
832
- // Apply offset to translate the range from [ -7 .. +7 ] into [ +1 .. +15 ]
1037
+ // Apply offset and clamp to translate the range from [ -8 .. +8 ] into [ +0 .. +15 ]
833
1038
  const __m256i off = _mm256_set1_epi8( 8 );
834
1039
  i0 = _mm256_add_epi8( i0, off );
1040
+ const __m256i maxNibble = _mm256_set1_epi8( 15 );
1041
+ i0 = _mm256_min_epi8( i0, maxNibble );
835
1042
 
836
1043
  // Compress the vector into 4 bit/value, and store
837
1044
  __m128i res = packNibbles( i0 );
@@ -846,22 +1053,31 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
846
1053
  __m256 v3 = _mm256_loadu_ps( x + 24 );
847
1054
  x += 32;
848
1055
 
849
- // Compute max(abs(e)) for the block
850
- const __m256 signBit = _mm256_set1_ps( -0.0f );
851
- __m256 maxAbs = _mm256_andnot_ps( signBit, v0 );
852
- maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) );
853
- maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );
854
- maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) );
1056
+ // Compute max for the block
1057
+ __m256 max = _mm256_max_ps( v0, v1 );
1058
+ __m256 maxTmp = _mm256_max_ps( v2, v3 );
1059
+ max = _mm256_max_ps( max, maxTmp );
855
1060
 
856
- __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
1061
+ __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( max, 1 ), _mm256_castps256_ps128( max ) );
857
1062
  max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
858
1063
  max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
859
1064
  const float maxScalar = _mm_cvtss_f32( max4 );
860
1065
 
1066
+ // Compute min for the block
1067
+ __m256 min = _mm256_min_ps( v0, v1 );
1068
+ __m256 minTmp = _mm256_min_ps( v2, v3 );
1069
+ min = _mm256_min_ps( min, minTmp );
1070
+
1071
+ __m128 min4 = _mm_min_ps( _mm256_extractf128_ps( min, 1 ), _mm256_castps256_ps128( min ) );
1072
+ min4 = _mm_min_ps( min4, _mm_movehl_ps( min4, min4 ) );
1073
+ min4 = _mm_min_ss( min4, _mm_movehdup_ps( min4 ) );
1074
+ const float minScalar = _mm_cvtss_f32( min4 );
1075
+
861
1076
  // Quantize these floats
862
- const float d = maxScalar / 7.0f;
1077
+ const float magnitude = maxScalar >= fabsf(minScalar) ? maxScalar : minScalar;
1078
+ const float d = magnitude / -8.0f;
863
1079
  y[i].d = d;
864
- const float id = ( maxScalar != 0.0f ) ? 7.0f / maxScalar : 0.0f;
1080
+ const float id = ( magnitude != 0.0f ) ? -8.0f / magnitude : 0.0f;
865
1081
  const __m256 mul = _mm256_set1_ps( id );
866
1082
 
867
1083
  // Apply the multiplier
@@ -902,10 +1118,13 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
902
1118
  ni0 = _mm_packs_epi16( ni0, ni2 );
903
1119
  ni4 = _mm_packs_epi16( ni4, ni6 );
904
1120
 
905
- // Apply offset to translate the range from [ -7 .. +7 ] into [ +1 .. +15 ]
906
- const __m128i off = _mm_set1_epi8( 8);
1121
+ // Apply offset and clamp to translate the range from [ -8 .. +8 ] into [ +0 .. +15 ]
1122
+ const __m128i off = _mm_set1_epi8( 8 );
907
1123
  ni0 = _mm_add_epi8( ni0, off );
908
1124
  ni4 = _mm_add_epi8( ni4, off );
1125
+ const __m128i maxNibble = _mm_set1_epi8( 15 );
1126
+ ni0 = _mm_min_epi8( ni0, maxNibble );
1127
+ ni4 = _mm_min_epi8( ni4, maxNibble );
909
1128
 
910
1129
  // Compress the vector into 4 bit/value, and store
911
1130
  __m128i res = packNibbles( ni0, ni4 );
@@ -913,24 +1132,32 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
913
1132
  }
914
1133
  #elif defined(__wasm_simd128__)
915
1134
  for (int i = 0; i < nb; i++) {
916
- float amax = 0.0f; // absolute max
1135
+ float max = 0.0f;
1136
+ float min = 0.0f;
917
1137
 
918
1138
  v128_t srcv [8];
919
- v128_t asrcv[8];
920
- v128_t amaxv[8];
1139
+ v128_t maxv[8];
1140
+ v128_t minv[8];
921
1141
 
922
1142
  for (int l = 0; l < 8; l++) srcv[l] = wasm_v128_load(x + i*32 + 4*l);
923
- for (int l = 0; l < 8; l++) asrcv[l] = wasm_f32x4_abs(srcv[l]);
924
1143
 
925
- for (int l = 0; l < 4; l++) amaxv[2*l] = wasm_f32x4_max(asrcv[2*l], asrcv[2*l+1]);
926
- for (int l = 0; l < 2; l++) amaxv[4*l] = wasm_f32x4_max(amaxv[4*l], amaxv[4*l+2]);
927
- for (int l = 0; l < 1; l++) amaxv[8*l] = wasm_f32x4_max(amaxv[8*l], amaxv[8*l+4]);
1144
+ for (int l = 0; l < 4; l++) maxv[2*l] = wasm_f32x4_max(srcv[2*l], srcv[2*l+1]);
1145
+ for (int l = 0; l < 2; l++) maxv[4*l] = wasm_f32x4_max(maxv[4*l], maxv[4*l+2]);
1146
+ for (int l = 0; l < 1; l++) maxv[8*l] = wasm_f32x4_max(maxv[8*l], maxv[8*l+4]);
1147
+
1148
+ for (int l = 0; l < 4; l++) minv[2*l] = wasm_f32x4_min(srcv[2*l], srcv[2*l+1]);
1149
+ for (int l = 0; l < 2; l++) minv[4*l] = wasm_f32x4_min(minv[4*l], minv[4*l+2]);
1150
+ for (int l = 0; l < 1; l++) minv[8*l] = wasm_f32x4_min(minv[8*l], minv[8*l+4]);
928
1151
 
929
- amax = MAX(
930
- MAX(wasm_f32x4_extract_lane(amaxv[0], 0), wasm_f32x4_extract_lane(amaxv[0], 1)),
931
- MAX(wasm_f32x4_extract_lane(amaxv[0], 2), wasm_f32x4_extract_lane(amaxv[0], 3)));
1152
+ max = MAX(
1153
+ MAX(wasm_f32x4_extract_lane(maxv[0], 0), wasm_f32x4_extract_lane(maxv[0], 1)),
1154
+ MAX(wasm_f32x4_extract_lane(maxv[0], 2), wasm_f32x4_extract_lane(maxv[0], 3)));
1155
+ min = MIN(
1156
+ MIN(wasm_f32x4_extract_lane(minv[0], 0), wasm_f32x4_extract_lane(minv[0], 1)),
1157
+ MIN(wasm_f32x4_extract_lane(minv[0], 2), wasm_f32x4_extract_lane(minv[0], 3)));
932
1158
 
933
- const float d = amax / ((1 << 3) - 1);
1159
+ const float magnitude = max >= fabsf(min) ? max : min;
1160
+ const float d = magnitude / -8;
934
1161
  const float id = d ? 1.0/d : 0.0;
935
1162
 
936
1163
  y[i].d = d;
@@ -939,9 +1166,10 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
939
1166
  const v128_t v = wasm_f32x4_mul(srcv[l], wasm_f32x4_splat(id));
940
1167
  const v128_t vf = wasm_f32x4_add(v, wasm_f32x4_splat(8.5f));
941
1168
  const v128_t vi = wasm_i32x4_trunc_sat_f32x4(vf);
1169
+ const v128_t vc = wasm_i32x4_min(vi, wasm_i32x4_splat(15));
942
1170
 
943
- y[i].qs[2*l + 0] = wasm_i32x4_extract_lane(vi, 0) | (wasm_i32x4_extract_lane(vi, 1) << 4);
944
- y[i].qs[2*l + 1] = wasm_i32x4_extract_lane(vi, 2) | (wasm_i32x4_extract_lane(vi, 3) << 4);
1171
+ y[i].qs[2*l + 0] = wasm_i32x4_extract_lane(vc, 0) | (wasm_i32x4_extract_lane(vc, 1) << 4);
1172
+ y[i].qs[2*l + 1] = wasm_i32x4_extract_lane(vc, 2) | (wasm_i32x4_extract_lane(vc, 3) << 4);
945
1173
  }
946
1174
  }
947
1175
  #else
@@ -1122,13 +1350,17 @@ static void quantize_row_q4_2_reference(const float * restrict x, block_q4_2 * r
1122
1350
 
1123
1351
  for (int i = 0; i < nb; i++) {
1124
1352
  float amax = 0.0f; // absolute max
1353
+ float max = 0.0f;
1125
1354
 
1126
1355
  for (int l = 0; l < QK4_2; l++) {
1127
1356
  const float v = x[i*QK4_2 + l];
1128
- amax = MAX(amax, fabsf(v));
1357
+ if (amax < fabsf(v)) {
1358
+ amax = fabsf(v);
1359
+ max = v;
1360
+ }
1129
1361
  }
1130
1362
 
1131
- const float d = amax / ((1 << 3) - 1);
1363
+ const float d = max / -8;
1132
1364
 
1133
1365
  const float id = d ? 1.0f/d : 0.0f;
1134
1366
 
@@ -1138,8 +1370,8 @@ static void quantize_row_q4_2_reference(const float * restrict x, block_q4_2 * r
1138
1370
  const float v0 = x[i*QK4_2 + l + 0]*id;
1139
1371
  const float v1 = x[i*QK4_2 + l + 1]*id;
1140
1372
 
1141
- const uint8_t vi0 = (uint8_t)(v0 + 8.5f);
1142
- const uint8_t vi1 = (uint8_t)(v1 + 8.5f);
1373
+ const uint8_t vi0 = MIN(15, (uint8_t)(v0 + 8.5f));
1374
+ const uint8_t vi1 = MIN(15, (uint8_t)(v1 + 8.5f));
1143
1375
 
1144
1376
  assert(vi0 < 16);
1145
1377
  assert(vi1 < 16);
@@ -1149,136 +1381,109 @@ static void quantize_row_q4_2_reference(const float * restrict x, block_q4_2 * r
1149
1381
  }
1150
1382
  }
1151
1383
 
1152
- static inline int nearest_int(float fval) {
1153
- assert(fval <= 4194303.f);
1154
- float val = fval + 12582912.f;
1155
- int i; memcpy(&i, &val, sizeof(int));
1156
- return (i & 0x007fffff) - 0x00400000;
1157
- }
1158
-
1159
- static float kquantize_q4_with_bounds(int n, int nmin, int nmax, const float * restrict X, int nCandidates,
1160
- const float * restrict candidates, int8_t * restrict L) {
1161
- assert (nmin >= INT8_MIN);
1162
- assert (nmax <= INT8_MAX);
1163
- float amax = 0;
1164
- for (int i=0; i<n; ++i) amax = MAX(amax, fabsf(X[i]));
1165
- if (!amax) { // all zero
1166
- for (int i=0; i<n; ++i) L[i] = 0;
1167
- return 1.f;
1168
- }
1169
- float best = 0, bestScale = 0;
1170
- for (int si=0; si<nCandidates; ++si) {
1171
- float iscale = candidates[si]/amax;
1172
- float sumlxP = 0; int suml2P = 0;
1173
- float sumlxM = 0; int suml2M = 0;
1174
- for (int i=0; i<n; ++i) {
1175
- int l = nearest_int(iscale*X[i]);
1176
- int lp = MAX(nmin, MIN(nmax, +l));
1177
- int lm = MAX(nmin, MIN(nmax, -l));
1178
- sumlxP += X[i]*lp; suml2P += lp*lp;
1179
- sumlxM += X[i]*lm; suml2M += lm*lm;
1180
- }
1181
- float sumlxP2 = sumlxP*sumlxP;
1182
- float sumlxM2 = sumlxM*sumlxM;
1183
- if (sumlxP2*suml2M > sumlxM2*suml2P) {
1184
- if (sumlxP2 > best*suml2P) {
1185
- best = sumlxP2/suml2P; bestScale = iscale;
1186
- }
1187
- } else {
1188
- if (sumlxM2 > best*suml2M) {
1189
- best = sumlxM2/suml2M; bestScale = -iscale;
1384
+ static void quantize_row_q4_2(const float * restrict x, void * restrict vy, int k) {
1385
+ assert(k % QK4_2 == 0);
1386
+
1387
+ block_q4_2 * restrict y = vy;
1388
+
1389
+ quantize_row_q4_2_reference(x, y, k);
1390
+ }
1391
+
1392
+ static void quantize_row_q5_0_reference(const float * restrict x, block_q5_0 * restrict y, int k) {
1393
+ assert(k % QK5_0 == 0);
1394
+ const int nb = k / QK5_0;
1395
+
1396
+ for (int i = 0; i < nb; i++) {
1397
+ float amax = 0.0f; // absolute max
1398
+ float max = 0.0f;
1399
+
1400
+ for (int l = 0; l < QK5_0; l++) {
1401
+ const float v = x[i*QK5_0 + l];
1402
+ if (amax < fabsf(v)) {
1403
+ amax = fabsf(v);
1404
+ max = v;
1190
1405
  }
1191
1406
  }
1192
- }
1193
- float sumlx = 0; int suml2 = 0;
1194
- for (int i=0; i<n; ++i) {
1195
- int l = nearest_int(bestScale*X[i]);
1196
- l = MAX(nmin, MIN(nmax, l));
1197
- sumlx += X[i]*l; suml2 += l*l;
1198
- L[i] = l;
1199
- }
1200
- float scale = sumlx/suml2;
1201
- return scale;
1202
- }
1203
1407
 
1204
- static void quantize_row_q4_2_rmse(const float * restrict x, block_q4_2 * restrict y, int k) {
1205
- #define CANDIDATE_COUNT 8
1206
- static const float candidates[CANDIDATE_COUNT] = { +8.7f, +8.3f, +8.1f, +7.8f, +7.3f, +7.0f, +6.3f, +5.7f };
1207
- assert(k % QK4_2 == 0);
1408
+ const float d = max / -16;
1409
+ const float id = d ? 1.0f/d : 0.0f;
1208
1410
 
1209
- int8_t L[QK4_2];
1411
+ y[i].d = GGML_FP32_TO_FP16(d);
1210
1412
 
1211
- const int nb = k / QK4_2;
1413
+ uint32_t qh = 0;
1212
1414
 
1213
- for (int i = 0; i < nb; i++) {
1214
- float scale = kquantize_q4_with_bounds(QK4_2, -8, 7, x, CANDIDATE_COUNT, candidates, L);
1215
- y[i].d = GGML_FP32_TO_FP16(scale);
1415
+ for (int l = 0; l < QK5_0; l += 2) {
1416
+ const float v0 = x[i*QK5_0 + l + 0]*id;
1417
+ const float v1 = x[i*QK5_0 + l + 1]*id;
1216
1418
 
1217
- for (int l = 0; l < QK4_2; l += 2) {
1218
- const uint8_t vi0 = (uint8_t)(L[l+0] + 8);
1219
- const uint8_t vi1 = (uint8_t)(L[l+1] + 8);
1419
+ const uint32_t vi0 = MIN(31, (int) (v0 + 16.5f));
1420
+ const uint32_t vi1 = MIN(31, (int) (v1 + 16.5f));
1220
1421
 
1221
- assert(vi0 < 16);
1222
- assert(vi1 < 16);
1422
+ y[i].qs[l/2] = (vi0 & 0x0F) | ((vi1 & 0x0F) << 4);
1223
1423
 
1224
- y[i].qs[l/2] = vi0 | (vi1 << 4);
1424
+ // get the 5-th bit and store it in qh at the right position
1425
+ qh |= ((vi0 & 0x10) >> 4) << (l + 0);
1426
+ qh |= ((vi1 & 0x10) >> 4) << (l + 1);
1225
1427
  }
1226
1428
 
1227
- x += QK4_2;
1429
+ memcpy(&y[i].qh, &qh, sizeof(y[i].qh));
1228
1430
  }
1229
1431
  }
1230
1432
 
1231
- static void quantize_row_q4_2(const float * restrict x, void * restrict vy, int k) {
1232
- assert(k % QK4_2 == 0);
1433
+ static void quantize_row_q5_0(const float * restrict x, void * restrict vy, int k) {
1434
+ assert(k % QK5_0 == 0);
1233
1435
 
1234
- block_q4_2 * restrict y = vy;
1436
+ block_q5_0 * restrict y = vy;
1235
1437
 
1236
- //quantize_row_q4_2_reference(x, y, k);
1237
- // This produces the exact same format, just better match to the input floats ("better" as measured by RMSE)
1238
- quantize_row_q4_2_rmse(x, y, k);
1438
+ quantize_row_q5_0_reference(x, y, k);
1239
1439
  }
1240
1440
 
1241
- static void quantize_row_q4_3_reference(const float * restrict x, block_q4_3 * restrict y, int k) {
1242
- assert(k % QK4_3 == 0);
1243
- const int nb = k / QK4_3;
1441
+ static void quantize_row_q5_1_reference(const float * restrict x, block_q5_1 * restrict y, int k) {
1442
+ assert(k % QK5_1 == 0);
1443
+ const int nb = k / QK5_1;
1244
1444
 
1245
1445
  for (int i = 0; i < nb; i++) {
1246
1446
  float min = FLT_MAX;
1247
1447
  float max = -FLT_MAX;
1248
1448
 
1249
- for (int l = 0; l < QK4_3; l++) {
1250
- const float v = x[i*QK4_3 + l];
1449
+ for (int l = 0; l < QK5_1; l++) {
1450
+ const float v = x[i*QK5_1 + l];
1251
1451
  if (v < min) min = v;
1252
1452
  if (v > max) max = v;
1253
1453
  }
1254
1454
 
1255
- const float d = (max - min) / ((1 << 4) - 1);
1455
+ const float d = (max - min) / ((1 << 5) - 1);
1256
1456
  const float id = d ? 1.0f/d : 0.0f;
1257
1457
 
1258
1458
  y[i].d = GGML_FP32_TO_FP16(d);
1259
1459
  y[i].m = GGML_FP32_TO_FP16(min);
1260
1460
 
1261
- for (int l = 0; l < QK4_3; l += 2) {
1262
- const float v0 = (x[i*QK4_3 + l + 0] - min)*id;
1263
- const float v1 = (x[i*QK4_3 + l + 1] - min)*id;
1461
+ uint32_t qh = 0;
1264
1462
 
1265
- const uint8_t vi0 = (int) (v0 + 0.5f);
1266
- const uint8_t vi1 = (int) (v1 + 0.5f);
1463
+ for (int l = 0; l < QK5_1; l += 2) {
1464
+ const float v0 = (x[i*QK5_1 + l + 0] - min)*id;
1465
+ const float v1 = (x[i*QK5_1 + l + 1] - min)*id;
1267
1466
 
1268
- assert(vi0 < 16);
1269
- assert(vi1 < 16);
1467
+ const uint32_t vi0 = (int) (v0 + 0.5f);
1468
+ const uint32_t vi1 = (int) (v1 + 0.5f);
1270
1469
 
1271
- y[i].qs[l/2] = vi0 | (vi1 << 4);
1470
+ y[i].qs[l/2] = (vi0 & 0x0F) | ((vi1 & 0x0F) << 4);
1471
+
1472
+ // get the 5-th bit and store it in qh at the right position
1473
+ qh |= ((vi0 & 0x10) >> 4) << (l + 0);
1474
+ qh |= ((vi1 & 0x10) >> 4) << (l + 1);
1272
1475
  }
1476
+
1477
+ memcpy(&y[i].qh, &qh, sizeof(y[i].qh));
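Editor's note: Q5_0 and Q5_1 keep the low four bits of every quant packed two per byte in qs, exactly like the 4-bit formats, and collect each quant's fifth bit into the 32-bit qh word, one bit per element. The round trip for one pair of codes, as a self-contained sketch of the packing above and the matching unpacking in dequantize_row_q5_0/q5_1:

    #include <stdint.h>
    #include <stdio.h>

    int main(void) {
        const uint32_t vi0 = 19, vi1 = 7;   // two 5-bit codes in [0, 31]
        const int l = 6;                    // position of this pair within the block

        // pack: low nibbles into one qs byte, fifth bits into qh at bits l and l+1
        uint8_t  qs = (uint8_t)((vi0 & 0x0F) | ((vi1 & 0x0F) << 4));
        uint32_t qh = 0;
        qh |= ((vi0 & 0x10) >> 4) << (l + 0);
        qh |= ((vi1 & 0x10) >> 4) << (l + 1);

        // unpack, as in dequantize_row_q5_0
        const uint8_t vh0 = (uint8_t)(((qh & (1u << (l + 0))) >> (l + 0)) << 4);
        const uint8_t vh1 = (uint8_t)(((qh & (1u << (l + 1))) >> (l + 1)) << 4);
        const int u0 = (qs & 0x0F) | vh0;
        const int u1 = (qs >>   4) | vh1;

        printf("packed (19, 7) -> unpacked (%d, %d)\n", u0, u1);   // (19, 7)
        return 0;
    }

Q5_0 then shifts the reconstructed code by -16 and multiplies by d, while Q5_1 multiplies by d and adds the stored minimum m.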
1273
1478
  }
1274
1479
  }
1275
1480
 
1276
- static void quantize_row_q4_3(const float * restrict x, void * restrict vy, int k) {
1277
- assert(k % QK4_3 == 0);
1481
+ static void quantize_row_q5_1(const float * restrict x, void * restrict vy, int k) {
1482
+ assert(k % QK5_1 == 0);
1278
1483
 
1279
- block_q4_3 * restrict y = vy;
1484
+ block_q5_1 * restrict y = vy;
1280
1485
 
1281
- quantize_row_q4_3_reference(x, y, k);
1486
+ quantize_row_q5_1_reference(x, y, k);
1282
1487
  }
1283
1488
 
1284
1489
  // reference implementation for deterministic creation of model files
@@ -1300,13 +1505,15 @@ static void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * r
1300
1505
  y[i].d = d;
1301
1506
 
1302
1507
  for (int l = 0; l < QK8_0; ++l) {
1303
- const float v = x[i*QK8_0 + l]*id;
1304
- y[i].qs[l] = roundf(v);
1508
+ const float v0 = x[i*QK8_0 + l]*id;
1509
+
1510
+ y[i].qs[l] = roundf(v0);
1305
1511
  }
1306
1512
  }
1307
1513
  }
1308
1514
 
1309
1515
  static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int k) {
1516
+ assert(QK8_0 == 32);
1310
1517
  assert(k % QK8_0 == 0);
1311
1518
  const int nb = k / QK8_0;
1312
1519
 
@@ -1432,95 +1639,295 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
1432
1639
  #endif
1433
1640
  }
1434
1641
 
1435
- static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, int k) {
1436
- assert(k % QK4_0 == 0);
1437
- const int nb = k / QK4_0;
1438
-
1439
- const block_q4_0 * restrict x = vx;
1642
+ // reference implementation for deterministic creation of model files
1643
+ static void quantize_row_q8_1_reference(const float * restrict x, block_q8_1 * restrict y, int k) {
1644
+ assert(QK8_1 == 32);
1645
+ assert(k % QK8_1 == 0);
1646
+ const int nb = k / QK8_1;
1440
1647
 
1441
- #if defined(__AVX2__)
1442
1648
  for (int i = 0; i < nb; i++) {
1443
- // scale factor
1444
- const __m256 d_v = _mm256_broadcast_ss(&x[i].d);
1649
+ float amax = 0.0f; // absolute max
1445
1650
 
1446
- const uint8_t * restrict pp = x[i].qs;
1651
+ for (int l = 0; l < QK8_1; l++) {
1652
+ const float v = x[i*QK8_1 + l];
1653
+ amax = MAX(amax, fabsf(v));
1654
+ }
1447
1655
 
1448
- for (int l = 0; l < QK4_0; l += 32) {
1449
- // Load 32x4-bit integers into 32x8-bit integers
1450
- __m256i vx8 = bytes_from_nibbles_32(pp+l/2);
1656
+ const float d = amax / ((1 << 7) - 1);
1657
+ const float id = d ? 1.0f/d : 0.0f;
1451
1658
 
1452
- // Subtract 8 from the integers
1453
- vx8 = _mm256_sub_epi8(vx8, _mm256_set1_epi8(8));
1659
+ y[i].d = d;
1454
1660
 
1455
- // Convert to 16-bit int
1456
- const __m256i vx16_lo = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vx8, 0));
1457
- const __m256i vx16_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vx8, 1));
1661
+ int sum0 = 0;
1662
+ int sum1 = 0;
1458
1663
 
1459
- // Convert to 32-bit int -> float 32
1460
- const __m256 vf[4] = {
1461
- _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_lo, 0))),
1462
- _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_lo, 1))),
1463
- _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_hi, 0))),
1464
- _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_hi, 1)))
1465
- };
1664
+ for (int l = 0; l < QK8_1/2; ++l) {
1665
+ const float v0 = x[i*QK8_1 + l]*id;
1666
+ const float v1 = x[i*QK8_1 + QK8_1/2 + l]*id;
1466
1667
 
1467
- // Scale and store
1468
- for (int j = 0; j < 4; j++) {
1469
- const __m256 result = _mm256_mul_ps(vf[j], d_v);
1470
- _mm256_storeu_ps(y + i * QK4_0 + l + j*8, result);
1471
- }
1668
+ y[i].qs[ l] = roundf(v0);
1669
+ y[i].qs[QK8_1/2 + l] = roundf(v1);
1670
+
1671
+ sum0 += y[i].qs[ l];
1672
+ sum1 += y[i].qs[QK8_1/2 + l];
1472
1673
  }
1674
+
1675
+ y[i].s0 = d * sum0;
1676
+ y[i].s1 = d * sum1;
1473
1677
  }
1474
- #elif defined(__ARM_NEON)
1475
- for (int i = 0; i < nb; i++) {
1476
- const float32x4_t vd = vdupq_n_f32(x[i].d);
1678
+ }
1477
1679
 
1478
- const uint8_t * restrict pp = x[i].qs;
1680
+ static void quantize_row_q8_1(const float * restrict x, void * restrict vy, int k) {
1681
+ assert(k % QK8_1 == 0);
1682
+ const int nb = k / QK8_1;
1479
1683
 
1480
- for (int l = 0; l < QK4_0; l += 16) {
1481
- // Load 16x4-bit integers into 8x8-bit integers
1482
- const uint8x8_t v8 = vld1_u8(pp + l/2);
1684
+ block_q8_1 * restrict y = vy;
1483
1685
 
1484
- // Expand 4-bit qs to 8-bit bytes
1485
- const uint8x8_t v0 = vand_u8(v8, vdup_n_u8(0x0f));
1486
- const uint8x8_t v1 = vshr_n_u8(v8, 4);
1686
+ #if defined(__ARM_NEON)
1687
+ for (int i = 0; i < nb; i++) {
1688
+ float32x4_t srcv [8];
1689
+ float32x4_t asrcv[8];
1690
+ float32x4_t amaxv[8];
1487
1691
 
1488
- // Convert to signed 8-bit integers
1489
- const int8x8_t vs_0 = vreinterpret_s8_u8(v0);
1490
- const int8x8_t vs_1 = vreinterpret_s8_u8(v1);
1692
+ for (int l = 0; l < 8; l++) srcv[l] = vld1q_f32(x + i*32 + 4*l);
1693
+ for (int l = 0; l < 8; l++) asrcv[l] = vabsq_f32(srcv[l]);
1491
1694
 
1492
- // Subtract 8 from each byte
1493
- const int8x8_t vb_0 = vsub_s8(vs_0, vdup_n_s8(8));
1494
- const int8x8_t vb_1 = vsub_s8(vs_1, vdup_n_s8(8));
1695
+ for (int l = 0; l < 4; l++) amaxv[2*l] = vmaxq_f32(asrcv[2*l], asrcv[2*l+1]);
1696
+ for (int l = 0; l < 2; l++) amaxv[4*l] = vmaxq_f32(amaxv[4*l], amaxv[4*l+2]);
1697
+ for (int l = 0; l < 1; l++) amaxv[8*l] = vmaxq_f32(amaxv[8*l], amaxv[8*l+4]);
1495
1698
 
1496
- // Interleave and combine
1497
- const int8x8_t vx_0 = vzip1_s8(vb_0, vb_1);
1498
- const int8x8_t vx_1 = vzip2_s8(vb_0, vb_1);
1699
+ const float amax = vmaxvq_f32(amaxv[0]);
1499
1700
 
1500
- const int8x16_t vq = vcombine_s8(vx_0, vx_1);
1701
+ const float d = amax / ((1 << 7) - 1);
1702
+ const float id = d ? 1.0f/d : 0.0f;
1501
1703
 
1502
- // convert to 2x int16x8_t
1503
- const int16x8_t vi_0 = vmovl_s8(vget_low_s8 (vq));
1504
- const int16x8_t vi_1 = vmovl_s8(vget_high_s8(vq));
1704
+ y[i].d = d;
1505
1705
 
1506
- // convert to 4x float32x4_t
1507
- const float32x4_t vf_0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16 (vi_0)));
1508
- const float32x4_t vf_1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vi_0)));
1509
- const float32x4_t vf_2 = vcvtq_f32_s32(vmovl_s16(vget_low_s16 (vi_1)));
1510
- const float32x4_t vf_3 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vi_1)));
1706
+ int32x4_t accv0 = vdupq_n_s32(0);
1707
+ int32x4_t accv1 = vdupq_n_s32(0);
1511
1708
 
1512
- // Multiply by d
1513
- const float32x4_t r0 = vmulq_f32(vf_0, vd);
1514
- const float32x4_t r1 = vmulq_f32(vf_1, vd);
1515
- const float32x4_t r2 = vmulq_f32(vf_2, vd);
1516
- const float32x4_t r3 = vmulq_f32(vf_3, vd);
1709
+ // low half
1710
+ for (int l = 0; l < 4; l++) {
1711
+ const float32x4_t v = vmulq_n_f32(srcv[l], id);
1712
+ const int32x4_t vi = vcvtnq_s32_f32(v);
1517
1713
 
1518
- // Store
1519
- vst1q_f32(y + i*QK4_0 + l + 0, r0);
1520
- vst1q_f32(y + i*QK4_0 + l + 4, r1);
1521
- vst1q_f32(y + i*QK4_0 + l + 8, r2);
1522
- vst1q_f32(y + i*QK4_0 + l + 12, r3);
1523
- }
1714
+ y[i].qs[4*l + 0] = vgetq_lane_s32(vi, 0);
1715
+ y[i].qs[4*l + 1] = vgetq_lane_s32(vi, 1);
1716
+ y[i].qs[4*l + 2] = vgetq_lane_s32(vi, 2);
1717
+ y[i].qs[4*l + 3] = vgetq_lane_s32(vi, 3);
1718
+
1719
+ accv0 = vaddq_s32(accv0, vi);
1720
+ }
1721
+
1722
+ // high half
1723
+ for (int l = 4; l < 8; l++) {
1724
+ const float32x4_t v = vmulq_n_f32(srcv[l], id);
1725
+ const int32x4_t vi = vcvtnq_s32_f32(v);
1726
+
1727
+ y[i].qs[4*l + 0] = vgetq_lane_s32(vi, 0);
1728
+ y[i].qs[4*l + 1] = vgetq_lane_s32(vi, 1);
1729
+ y[i].qs[4*l + 2] = vgetq_lane_s32(vi, 2);
1730
+ y[i].qs[4*l + 3] = vgetq_lane_s32(vi, 3);
1731
+
1732
+ accv1 = vaddq_s32(accv1, vi);
1733
+ }
1734
+
1735
+ const int32_t sum0 = vaddvq_s32(accv0);
1736
+ const int32_t sum1 = vaddvq_s32(accv1);
1737
+
1738
+ y[i].s0 = d * sum0;
1739
+ y[i].s1 = d * sum1;
1740
+ }
1741
+ #elif defined(__AVX2__) || defined(__AVX__)
1742
+ for (int i = 0; i < nb; i++) {
1743
+ // Load elements into 4 AVX vectors
1744
+ __m256 v0 = _mm256_loadu_ps( x );
1745
+ __m256 v1 = _mm256_loadu_ps( x + 8 );
1746
+ __m256 v2 = _mm256_loadu_ps( x + 16 );
1747
+ __m256 v3 = _mm256_loadu_ps( x + 24 );
1748
+ x += 32;
1749
+
1750
+ // Compute max(abs(e)) for the block
1751
+ const __m256 signBit = _mm256_set1_ps( -0.0f );
1752
+ __m256 maxAbs = _mm256_andnot_ps( signBit, v0 );
1753
+ maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) );
1754
+ maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );
1755
+ maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) );
1756
+
1757
+ __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
1758
+ max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
1759
+ max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
1760
+ const float maxScalar = _mm_cvtss_f32( max4 );
1761
+
1762
+ // Quantize these floats
1763
+ const float d = maxScalar / 127.f;
1764
+ y[i].d = d;
1765
+ const float id = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f;
1766
+ const __m256 mul = _mm256_set1_ps( id );
1767
+
1768
+ // Apply the multiplier
1769
+ v0 = _mm256_mul_ps( v0, mul );
1770
+ v1 = _mm256_mul_ps( v1, mul );
1771
+ v2 = _mm256_mul_ps( v2, mul );
1772
+ v3 = _mm256_mul_ps( v3, mul );
1773
+
1774
+ // Round to nearest integer
1775
+ v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST );
1776
+ v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST );
1777
+ v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST );
1778
+ v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST );
1779
+
1780
+ // Convert floats to integers
1781
+ __m256i i0 = _mm256_cvtps_epi32( v0 );
1782
+ __m256i i1 = _mm256_cvtps_epi32( v1 );
1783
+ __m256i i2 = _mm256_cvtps_epi32( v2 );
1784
+ __m256i i3 = _mm256_cvtps_epi32( v3 );
1785
+
1786
+ #if defined(__AVX2__)
1787
+ // Compute the sum of the quants and set y[i].s
1788
+ //y[i].s = d * hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3)));
1789
+ y[i].s0 = d * hsum_i32_8(_mm256_add_epi32(i0, i1));
1790
+ y[i].s1 = d * hsum_i32_8(_mm256_add_epi32(i2, i3));
1791
+
1792
+ // Convert int32 to int16
1793
+ i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15
1794
+ i2 = _mm256_packs_epi32( i2, i3 ); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31
1795
+ // Convert int16 to int8
1796
+ i0 = _mm256_packs_epi16( i0, i2 ); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31
1797
+
1798
+ // We got our precious signed bytes, but the order is now wrong
1799
+ // These AVX2 pack instructions process 16-byte pieces independently
1800
+ // The following instruction is fixing the order
1801
+ const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
1802
+ i0 = _mm256_permutevar8x32_epi32( i0, perm );
1803
+
1804
+ _mm256_storeu_si256((__m256i *)y[i].qs, i0);
1805
+ #else
1806
+ // Since we don't have in AVX some necessary functions,
1807
+ // we split the registers in half and call AVX2 analogs from SSE
1808
+ __m128i ni0 = _mm256_castsi256_si128( i0 );
1809
+ __m128i ni1 = _mm256_extractf128_si256( i0, 1);
1810
+ __m128i ni2 = _mm256_castsi256_si128( i1 );
1811
+ __m128i ni3 = _mm256_extractf128_si256( i1, 1);
1812
+ __m128i ni4 = _mm256_castsi256_si128( i2 );
1813
+ __m128i ni5 = _mm256_extractf128_si256( i2, 1);
1814
+ __m128i ni6 = _mm256_castsi256_si128( i3 );
1815
+ __m128i ni7 = _mm256_extractf128_si256( i3, 1);
1816
+
1817
+ // Compute the sum of the quants and set y[i].s
1818
+ const __m128i s0 = _mm_add_epi32(_mm_add_epi32(ni0, ni1), _mm_add_epi32(ni2, ni3));
1819
+ const __m128i s1 = _mm_add_epi32(_mm_add_epi32(ni4, ni5), _mm_add_epi32(ni6, ni7));
1820
+ y[i].s0 = d * hsum_i32_4(s0);
1821
+ y[i].s1 = d * hsum_i32_4(s1);
1822
+
1823
+ // Convert int32 to int16
1824
+ ni0 = _mm_packs_epi32( ni0, ni1 );
1825
+ ni2 = _mm_packs_epi32( ni2, ni3 );
1826
+ ni4 = _mm_packs_epi32( ni4, ni5 );
1827
+ ni6 = _mm_packs_epi32( ni6, ni7 );
1828
+ // Convert int16 to int8
1829
+ ni0 = _mm_packs_epi16( ni0, ni2 );
1830
+ ni4 = _mm_packs_epi16( ni4, ni6 );
1831
+
1832
+ _mm_storeu_si128((__m128i *)(y[i].qs + 0), ni0);
1833
+ _mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4);
1834
+ #endif
1835
+ }
1836
+ #else
1837
+ // scalar
1838
+ quantize_row_q8_1_reference(x, y, k);
1839
+ #endif
1840
+ }
1841
+
1842
+ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, int k) {
1843
+ assert(k % QK4_0 == 0);
1844
+ const int nb = k / QK4_0;
1845
+
1846
+ const block_q4_0 * restrict x = vx;
1847
+
1848
+ #if defined(__AVX2__)
1849
+ for (int i = 0; i < nb; i++) {
1850
+ // scale factor
1851
+ const __m256 d_v = _mm256_broadcast_ss(&x[i].d);
1852
+
1853
+ const uint8_t * restrict pp = x[i].qs;
1854
+
1855
+ for (int l = 0; l < QK4_0; l += 32) {
1856
+ // Load 32x4-bit integers into 32x8-bit integers
1857
+ __m256i vx8 = bytes_from_nibbles_32(pp+l/2);
1858
+
1859
+ // Subtract 8 from the integers
1860
+ vx8 = _mm256_sub_epi8(vx8, _mm256_set1_epi8(8));
1861
+
1862
+ // Convert to 16-bit int
1863
+ const __m256i vx16_lo = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vx8, 0));
1864
+ const __m256i vx16_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vx8, 1));
1865
+
1866
+ // Convert to 32-bit int -> float 32
1867
+ const __m256 vf[4] = {
1868
+ _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_lo, 0))),
1869
+ _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_lo, 1))),
1870
+ _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_hi, 0))),
1871
+ _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_hi, 1)))
1872
+ };
1873
+
1874
+ // Scale and store
1875
+ for (int j = 0; j < 4; j++) {
1876
+ const __m256 result = _mm256_mul_ps(vf[j], d_v);
1877
+ _mm256_storeu_ps(y + i * QK4_0 + l + j*8, result);
1878
+ }
1879
+ }
1880
+ }
1881
+ #elif defined(__ARM_NEON)
1882
+ for (int i = 0; i < nb; i++) {
1883
+ const float32x4_t vd = vdupq_n_f32(x[i].d);
1884
+
1885
+ const uint8_t * restrict pp = x[i].qs;
1886
+
1887
+ for (int l = 0; l < QK4_0; l += 16) {
1888
+ // Load 16x4-bit integers into 8x8-bit integers
1889
+ const uint8x8_t v8 = vld1_u8(pp + l/2);
1890
+
1891
+ // Expand 4-bit qs to 8-bit bytes
1892
+ const uint8x8_t v0 = vand_u8(v8, vdup_n_u8(0x0F));
1893
+ const uint8x8_t v1 = vshr_n_u8(v8, 4);
1894
+
1895
+ // Convert to signed 8-bit integers
1896
+ const int8x8_t vs_0 = vreinterpret_s8_u8(v0);
1897
+ const int8x8_t vs_1 = vreinterpret_s8_u8(v1);
1898
+
1899
+ // Subtract 8 from each byte
1900
+ const int8x8_t vb_0 = vsub_s8(vs_0, vdup_n_s8(8));
1901
+ const int8x8_t vb_1 = vsub_s8(vs_1, vdup_n_s8(8));
1902
+
1903
+ // Interleave and combine
1904
+ const int8x8_t vx_0 = vzip1_s8(vb_0, vb_1);
1905
+ const int8x8_t vx_1 = vzip2_s8(vb_0, vb_1);
1906
+
1907
+ const int8x16_t vq = vcombine_s8(vx_0, vx_1);
1908
+
1909
+ // convert to 2x int16x8_t
1910
+ const int16x8_t vi_0 = vmovl_s8(vget_low_s8 (vq));
1911
+ const int16x8_t vi_1 = vmovl_s8(vget_high_s8(vq));
1912
+
1913
+ // convert to 4x float32x4_t
1914
+ const float32x4_t vf_0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16 (vi_0)));
1915
+ const float32x4_t vf_1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vi_0)));
1916
+ const float32x4_t vf_2 = vcvtq_f32_s32(vmovl_s16(vget_low_s16 (vi_1)));
1917
+ const float32x4_t vf_3 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vi_1)));
1918
+
1919
+ // Multiply by d
1920
+ const float32x4_t r0 = vmulq_f32(vf_0, vd);
1921
+ const float32x4_t r1 = vmulq_f32(vf_1, vd);
1922
+ const float32x4_t r2 = vmulq_f32(vf_2, vd);
1923
+ const float32x4_t r3 = vmulq_f32(vf_3, vd);
1924
+
1925
+ // Store
1926
+ vst1q_f32(y + i*QK4_0 + l + 0, r0);
1927
+ vst1q_f32(y + i*QK4_0 + l + 4, r1);
1928
+ vst1q_f32(y + i*QK4_0 + l + 8, r2);
1929
+ vst1q_f32(y + i*QK4_0 + l + 12, r3);
1930
+ }
1524
1931
  }
1525
1932
  #else
1526
1933
  // scalar
@@ -1532,7 +1939,7 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in
1532
1939
  for (int l = 0; l < QK4_0; l += 2) {
1533
1940
  const uint8_t vi = pp[l/2];
1534
1941
 
1535
- const int8_t vi0 = vi & 0xf;
1942
+ const int8_t vi0 = vi & 0x0F;
1536
1943
  const int8_t vi1 = vi >> 4;
1537
1944
 
1538
1945
  const float v0 = (vi0 - 8)*d;
@@ -1598,7 +2005,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
1598
2005
  const uint8x8_t v8 = vld1_u8(pp + l/2);
1599
2006
 
1600
2007
  // Expand 4-bit qs to 8-bit bytes
1601
- const uint8x8_t v0 = vand_u8(v8, vdup_n_u8(0x0f));
2008
+ const uint8x8_t v0 = vand_u8(v8, vdup_n_u8(0x0F));
1602
2009
  const uint8x8_t v1 = vshr_n_u8(v8, 4);
1603
2010
 
1604
2011
  // Interleave and combine
@@ -1640,7 +2047,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
1640
2047
  for (int l = 0; l < QK4_1; l += 2) {
1641
2048
  const uint8_t vi = pp[l/2];
1642
2049
 
1643
- const int8_t vi0 = vi & 0xf;
2050
+ const int8_t vi0 = vi & 0x0F;
1644
2051
  const int8_t vi1 = vi >> 4;
1645
2052
 
1646
2053
  const float v0 = vi0*d + m;
@@ -1670,7 +2077,7 @@ static void dequantize_row_q4_2(const void * restrict vx, float * restrict y, in
1670
2077
  for (int l = 0; l < QK4_2; l += 2) {
1671
2078
  const uint8_t vi = pp[l/2];
1672
2079
 
1673
- const int8_t vi0 = vi & 0xf;
2080
+ const int8_t vi0 = vi & 0x0F;
1674
2081
  const int8_t vi1 = vi >> 4;
1675
2082
 
1676
2083
  const float v0 = (vi0 - 8)*d;
@@ -1685,11 +2092,47 @@ static void dequantize_row_q4_2(const void * restrict vx, float * restrict y, in
1685
2092
  }
1686
2093
  }
1687
2094
 
1688
- static void dequantize_row_q4_3(const void * restrict vx, float * restrict y, int k) {
1689
- assert(k % QK4_3 == 0);
1690
- const int nb = k / QK4_3;
2095
+ static void dequantize_row_q5_0(const void * restrict vx, float * restrict y, int k) {
2096
+ assert(k % QK5_0 == 0);
2097
+ const int nb = k / QK5_0;
2098
+
2099
+ const block_q5_0 * restrict x = vx;
2100
+
2101
+ for (int i = 0; i < nb; i++) {
2102
+ const float d = GGML_FP16_TO_FP32(x[i].d);
2103
+
2104
+ const uint8_t * restrict pp = x[i].qs;
2105
+
2106
+ uint32_t qh;
2107
+ memcpy(&qh, x[i].qh, sizeof(qh));
2108
+
2109
+ for (int l = 0; l < QK5_0; l += 2) {
2110
+ const uint8_t vi = pp[l/2];
2111
+
2112
+ // extract the 5-th bit from qh
2113
+ const uint8_t vh0 = ((qh & (1u << (l + 0))) >> (l + 0)) << 4;
2114
+ const uint8_t vh1 = ((qh & (1u << (l + 1))) >> (l + 1)) << 4;
2115
+
2116
+ const int8_t vi0 = (vi & 0x0F) | vh0;
2117
+ const int8_t vi1 = (vi >> 4) | vh1;
2118
+
2119
+ const float v0 = (vi0 - 16)*d;
2120
+ const float v1 = (vi1 - 16)*d;
2121
+
2122
+ y[i*QK5_0 + l + 0] = v0;
2123
+ y[i*QK5_0 + l + 1] = v1;
2124
+
2125
+ assert(!isnan(y[i*QK5_0 + l + 0]));
2126
+ assert(!isnan(y[i*QK5_0 + l + 1]));
2127
+ }
2128
+ }
2129
+ }
2130
+
2131
+ static void dequantize_row_q5_1(const void * restrict vx, float * restrict y, int k) {
2132
+ assert(k % QK5_1 == 0);
2133
+ const int nb = k / QK5_1;
1691
2134
 
1692
- const block_q4_3 * restrict x = vx;
2135
+ const block_q5_1 * restrict x = vx;
1693
2136
 
1694
2137
  for (int i = 0; i < nb; i++) {
1695
2138
  const float d = GGML_FP16_TO_FP32(x[i].d);
@@ -1697,28 +2140,54 @@ static void dequantize_row_q4_3(const void * restrict vx, float * restrict y, in
1697
2140
 
1698
2141
  const uint8_t * restrict pp = x[i].qs;
1699
2142
 
1700
- for (int l = 0; l < QK4_3; l += 2) {
2143
+ uint32_t qh;
2144
+ memcpy(&qh, x[i].qh, sizeof(qh));
2145
+
2146
+ for (int l = 0; l < QK5_1; l += 2) {
1701
2147
  const uint8_t vi = pp[l/2];
1702
2148
 
1703
- const int8_t vi0 = vi & 0xf;
1704
- const int8_t vi1 = vi >> 4;
2149
+ // extract the 5-th bit from qh
2150
+ const uint8_t vh0 = ((qh & (1u << (l + 0))) >> (l + 0)) << 4;
2151
+ const uint8_t vh1 = ((qh & (1u << (l + 1))) >> (l + 1)) << 4;
2152
+
2153
+ const uint8_t vi0 = (vi & 0x0F) | vh0;
2154
+ const uint8_t vi1 = (vi >> 4) | vh1;
1705
2155
 
1706
2156
  const float v0 = vi0*d + m;
1707
2157
  const float v1 = vi1*d + m;
1708
2158
 
1709
- y[i*QK4_3 + l + 0] = v0;
1710
- y[i*QK4_3 + l + 1] = v1;
2159
+ y[i*QK5_1 + l + 0] = v0;
2160
+ y[i*QK5_1 + l + 1] = v1;
2161
+
2162
+ assert(!isnan(y[i*QK5_1 + l + 0]));
2163
+ assert(!isnan(y[i*QK5_1 + l + 1]));
2164
+ }
2165
+ }
2166
+ }
2167
+
2168
+ static void dequantize_row_q8_0(const void * restrict vx, float * restrict y, int k) {
2169
+ assert(k % QK8_0 == 0);
2170
+ const int nb = k / QK8_0;
2171
+
2172
+ const block_q8_0 * restrict x = vx;
2173
+
2174
+ for (int i = 0; i < nb; i++) {
2175
+ const float d = x[i].d;
1711
2176
 
1712
- assert(!isnan(y[i*QK4_3 + l + 0]));
1713
- assert(!isnan(y[i*QK4_3 + l + 1]));
2177
+ const int8_t * restrict pp = x[i].qs;
2178
+
2179
+ for (int l = 0; l < QK8_0; ++l) {
2180
+ y[i*QK8_0 + l] = pp[l]*d;
1714
2181
  }
1715
2182
  }
1716
2183
  }
1717
2184
 
1718
2185
  static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
1719
- static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
2186
+ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
1720
2187
  static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
1721
- static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
2188
+ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
2189
+ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
2190
+ static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
1722
2191
 
1723
2192
  static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
1724
2193
  [GGML_TYPE_Q4_0] = {
@@ -1727,34 +2196,55 @@ static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
1727
2196
  .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_0_reference,
1728
2197
  .quantize_row_q_dot = quantize_row_q8_0,
1729
2198
  .vec_dot_q = ggml_vec_dot_q4_0_q8_0,
2199
+ .vec_dot_type = GGML_TYPE_Q8_0,
1730
2200
  },
1731
2201
  [GGML_TYPE_Q4_1] = {
1732
2202
  .dequantize_row_q = dequantize_row_q4_1,
1733
2203
  .quantize_row_q = quantize_row_q4_1,
1734
2204
  .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_1_reference,
1735
- .quantize_row_q_dot = quantize_row_q8_0,
1736
- .vec_dot_q = ggml_vec_dot_q4_1_q8_0,
2205
+ .quantize_row_q_dot = quantize_row_q8_1,
2206
+ .vec_dot_q = ggml_vec_dot_q4_1_q8_1,
2207
+ .vec_dot_type = GGML_TYPE_Q8_1,
1737
2208
  },
1738
2209
  [GGML_TYPE_Q4_2] = {
1739
2210
  .dequantize_row_q = dequantize_row_q4_2,
1740
2211
  .quantize_row_q = quantize_row_q4_2,
1741
- .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_2_rmse, //quantize_row_q4_2_reference,
2212
+ .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_2_reference,
1742
2213
  .quantize_row_q_dot = quantize_row_q8_0,
1743
2214
  .vec_dot_q = ggml_vec_dot_q4_2_q8_0,
2215
+ .vec_dot_type = GGML_TYPE_Q8_0,
1744
2216
  },
1745
- [GGML_TYPE_Q4_3] = {
1746
- .dequantize_row_q = dequantize_row_q4_3,
1747
- .quantize_row_q = quantize_row_q4_3,
1748
- .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_3_reference, // TODO: RMSE optimization
2217
+ [GGML_TYPE_Q5_0] = {
2218
+ .dequantize_row_q = dequantize_row_q5_0,
2219
+ .quantize_row_q = quantize_row_q5_0,
2220
+ .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q5_0_reference,
1749
2221
  .quantize_row_q_dot = quantize_row_q8_0,
1750
- .vec_dot_q = ggml_vec_dot_q4_3_q8_0,
2222
+ .vec_dot_q = ggml_vec_dot_q5_0_q8_0,
2223
+ .vec_dot_type = GGML_TYPE_Q8_0,
2224
+ },
2225
+ [GGML_TYPE_Q5_1] = {
2226
+ .dequantize_row_q = dequantize_row_q5_1,
2227
+ .quantize_row_q = quantize_row_q5_1,
2228
+ .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q5_1_reference,
2229
+ .quantize_row_q_dot = quantize_row_q8_1,
2230
+ .vec_dot_q = ggml_vec_dot_q5_1_q8_1,
2231
+ .vec_dot_type = GGML_TYPE_Q8_1,
1751
2232
  },
1752
2233
  [GGML_TYPE_Q8_0] = {
1753
- .dequantize_row_q = NULL, // TODO
2234
+ .dequantize_row_q = dequantize_row_q8_0,
1754
2235
  .quantize_row_q = quantize_row_q8_0,
1755
2236
  .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q8_0_reference,
1756
2237
  .quantize_row_q_dot = quantize_row_q8_0,
2238
+ .vec_dot_q = ggml_vec_dot_q8_0_q8_0,
2239
+ .vec_dot_type = GGML_TYPE_Q8_0,
2240
+ },
2241
+ [GGML_TYPE_Q8_1] = {
2242
+ .dequantize_row_q = NULL, // TODO
2243
+ .quantize_row_q = quantize_row_q8_1,
2244
+ .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q8_1_reference,
2245
+ .quantize_row_q_dot = quantize_row_q8_1,
1757
2246
  .vec_dot_q = NULL, // TODO
2247
+ .vec_dot_type = GGML_TYPE_Q8_1,
1758
2248
  },
1759
2249
  };
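The new vec_dot_type field records which 8-bit format the second dot-product operand must be quantized into (Q8_0 for the *_q8_0 kernels, Q8_1 for the *_q8_1 kernels). A minimal sketch of how a caller is expected to consume the table, assuming the function pointers follow the prototypes above and that the caller provides a scratch buffer large enough for one quantized row (the helper name is hypothetical):

static void example_row_dot(enum ggml_type type, int n, float * s,
                            const void * x_row, const float * y_row, void * y_scratch) {
    const quantize_fns_t * qfns = &quantize_fns[type];

    // quantize the f32 operand into the block format named by vec_dot_type ...
    qfns->quantize_row_q_dot(y_row, y_scratch, n);

    // ... then run the mixed-format dot product
    qfns->vec_dot_q(n, s, x_row, y_scratch);
}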
1760
2250
 
@@ -2366,8 +2856,6 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
2366
2856
  const block_q4_0 * restrict x = vx;
2367
2857
  const block_q8_0 * restrict y = vy;
2368
2858
 
2369
- float sumf = 0.0;
2370
-
2371
2859
  #if defined(__ARM_NEON)
2372
2860
  float32x4_t sumv0 = vdupq_n_f32(0.0f);
2373
2861
  float32x4_t sumv1 = vdupq_n_f32(0.0f);
@@ -2378,7 +2866,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
2378
2866
  const block_q8_0 * restrict y0 = &y[i + 0];
2379
2867
  const block_q8_0 * restrict y1 = &y[i + 1];
2380
2868
 
2381
- const uint8x16_t m4b = vdupq_n_u8(0xf);
2869
+ const uint8x16_t m4b = vdupq_n_u8(0x0F);
2382
2870
  const int8x16_t s8b = vdupq_n_s8(0x8);
2383
2871
 
2384
2872
  const uint8x16_t v0_0 = vld1q_u8(x0->qs);
@@ -2396,35 +2884,35 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
2396
2884
  const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
2397
2885
  const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
2398
2886
 
2887
+ // interleave
2888
+ const int8x16_t v0_0lz = vzip1q_s8(v0_0ls, v0_0hs);
2889
+ const int8x16_t v0_0hz = vzip2q_s8(v0_0ls, v0_0hs);
2890
+ const int8x16_t v0_1lz = vzip1q_s8(v0_1ls, v0_1hs);
2891
+ const int8x16_t v0_1hz = vzip2q_s8(v0_1ls, v0_1hs);
2892
+
2399
2893
  // load y
2400
2894
  const int8x16_t v1_0l = vld1q_s8(y0->qs);
2401
2895
  const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
2402
2896
  const int8x16_t v1_1l = vld1q_s8(y1->qs);
2403
2897
  const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
2404
2898
 
2405
- // interleave
2406
- const int8x16_t v1_0ls = vuzp1q_s8(v1_0l, v1_0h);
2407
- const int8x16_t v1_0hs = vuzp2q_s8(v1_0l, v1_0h);
2408
- const int8x16_t v1_1ls = vuzp1q_s8(v1_1l, v1_1h);
2409
- const int8x16_t v1_1hs = vuzp2q_s8(v1_1l, v1_1h);
2410
-
2411
2899
  #if defined(__ARM_FEATURE_DOTPROD)
2412
2900
  // dot product into int32x4_t
2413
- const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls), v0_0hs, v1_0hs);
2414
- const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls), v0_1hs, v1_1hs);
2901
+ const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0lz, v1_0l), v0_0hz, v1_0h);
2902
+ const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1lz, v1_1l), v0_1hz, v1_1h);
2415
2903
 
2416
2904
  sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), x0->d*y0->d);
2417
2905
  sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), x1->d*y1->d);
2418
2906
  #else
2419
- const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls));
2420
- const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls));
2421
- const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0hs));
2422
- const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0hs));
2907
+ const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lz), vget_low_s8 (v1_0l));
2908
+ const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lz), vget_high_s8(v1_0l));
2909
+ const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hz), vget_low_s8 (v1_0h));
2910
+ const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hz), vget_high_s8(v1_0h));
2423
2911
 
2424
- const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1ls));
2425
- const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1ls));
2426
- const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1hs));
2427
- const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1hs));
2912
+ const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lz), vget_low_s8 (v1_1l));
2913
+ const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lz), vget_high_s8(v1_1l));
2914
+ const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hz), vget_low_s8 (v1_1h));
2915
+ const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hz), vget_high_s8(v1_1h));
2428
2916
 
2429
2917
  const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
2430
2918
  const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
@@ -2436,7 +2924,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
2436
2924
  #endif
2437
2925
  }
2438
2926
 
2439
- sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
2927
+ *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
2440
2928
  #elif defined(__AVX2__)
2441
2929
  // Initialize accumulator with zeros
2442
2930
  __m256 acc = _mm256_setzero_ps();
@@ -2454,32 +2942,13 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
2454
2942
 
2455
2943
  __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
2456
2944
 
2457
- // Get absolute values of x vectors
2458
- const __m256i ax = _mm256_sign_epi8(bx, bx);
2459
-
2460
- // Sign the values of the y vectors
2461
- const __m256i sy = _mm256_sign_epi8(by, bx);
2462
-
2463
- // Perform multiplication and create 16-bit values
2464
- const __m256i dot = _mm256_maddubs_epi16(ax, sy);
2465
-
2466
- const __m256i ones = _mm256_set1_epi16(1);
2467
- __m256i xy_q = _mm256_madd_epi16(ones, dot);
2468
-
2469
- /* Convert to vectore of 8 int32_t to 8 floats */
2470
- __m256 q = _mm256_cvtepi32_ps( xy_q );
2945
+ const __m256 q = mul_sum_i8_pairs_float(bx, by);
2471
2946
 
2472
2947
  /* Multiply q with scale and accumulate */
2473
2948
  acc = _mm256_fmadd_ps( d, q, acc );
2474
2949
  }
2475
2950
 
2476
- // Return horizontal sum of the acc vector
2477
- __m128 res = _mm256_extractf128_ps( acc, 1 );
2478
- res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
2479
- res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
2480
- res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
2481
-
2482
- sumf = _mm_cvtss_f32( res );
2951
+ *s = hsum_float_8(acc);
2483
2952
  #elif defined(__AVX__)
2484
2953
  // Initialize accumulator with zeros
2485
2954
  __m256 acc = _mm256_setzero_ps();
@@ -2518,15 +2987,10 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
2518
2987
  acc = _mm256_add_ps(_mm256_mul_ps( d, p ), acc);
2519
2988
  }
2520
2989
 
2521
- // Return horizontal sum of the acc vector
2522
- __m128 res = _mm256_extractf128_ps( acc, 1 );
2523
- res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
2524
- res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
2525
- res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
2526
-
2527
- sumf = _mm_cvtss_f32( res );
2990
+ *s = hsum_float_8(acc);
2528
2991
  #else
2529
2992
  // scalar
2993
+ float sumf = 0.0;
2530
2994
  for (int i = 0; i < nb; i++) {
2531
2995
  const float d0 = x[i].d;
2532
2996
  const float d1 = y[i].d;
@@ -2538,8 +3002,8 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
2538
3002
  for (int j = 0; j < QK8_0/2; j++) {
2539
3003
  const uint8_t v0 = p0[j];
2540
3004
 
2541
- const int i0 = (int8_t) (v0 & 0xf) - 8;
2542
- const int i1 = (int8_t) (v0 >> 4) - 8;
3005
+ const int i0 = (int8_t) (v0 & 0x0F) - 8;
3006
+ const int i1 = (int8_t) (v0 >> 4) - 8;
2543
3007
 
2544
3008
  const int i2 = p1[2*j + 0];
2545
3009
  const int i3 = p1[2*j + 1];
@@ -2548,34 +3012,35 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
2548
3012
  }
2549
3013
  sumf += d0*d1*sumi;
2550
3014
  }
2551
- #endif
2552
-
2553
3015
  *s = sumf;
3016
+ #endif
2554
3017
  }
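For reference while reading the kernels above and below, this is the block layout they index into, restated as a hedged sketch (the canonical struct definitions appear earlier in ggml.c; the 32-element block size is inferred from the 16-byte and 32-byte loads above, and the names here are illustrative only):

typedef struct {
    float   d;        // scale
    uint8_t qs[16];   // 32 quants, two 4-bit values per byte
} example_block_q4_0;

typedef struct {
    float   d;        // scale
    int8_t  qs[32];   // 32 signed 8-bit quants
} example_block_q8_0;

// One-block scalar reference matching the scalar tail above: x_j ~ d*(nibble_j - 8), y_j ~ dy*q_j
static float example_dot_q4_0_q8_0(const example_block_q4_0 * x, const example_block_q8_0 * y) {
    int sumi = 0;
    for (int j = 0; j < 16; j++) {
        const int i0 = (x->qs[j] & 0x0F) - 8;
        const int i1 = (x->qs[j] >>   4) - 8;
        sumi += i0*y->qs[2*j + 0] + i1*y->qs[2*j + 1];
    }
    return x->d * y->d * (float) sumi;
}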
2555
3018
 
2556
- static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
2557
- const int nb = n / QK8_0;
3019
+ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
3020
+ const int nb = n / QK8_1;
2558
3021
 
2559
- assert(n % QK8_0 == 0);
3022
+ assert(n % QK8_1 == 0);
2560
3023
  assert(nb % 2 == 0);
2561
3024
 
2562
3025
  const block_q4_1 * restrict x = vx;
2563
- const block_q8_0 * restrict y = vy;
2564
-
2565
- float sumf = 0.0;
3026
+ const block_q8_1 * restrict y = vy;
2566
3027
 
2567
3028
  // TODO: add AVX / WASM SIMD / etc
2568
3029
  #if defined(__ARM_NEON)
2569
3030
  float32x4_t sumv0 = vdupq_n_f32(0.0f);
2570
3031
  float32x4_t sumv1 = vdupq_n_f32(0.0f);
2571
3032
 
3033
+ float summs = 0;
3034
+
2572
3035
  for (int i = 0; i < nb; i += 2) {
2573
3036
  const block_q4_1 * restrict x0 = &x[i + 0];
2574
3037
  const block_q4_1 * restrict x1 = &x[i + 1];
2575
- const block_q8_0 * restrict y0 = &y[i + 0];
2576
- const block_q8_0 * restrict y1 = &y[i + 1];
3038
+ const block_q8_1 * restrict y0 = &y[i + 0];
3039
+ const block_q8_1 * restrict y1 = &y[i + 1];
3040
+
3041
+ summs += x0->m * (y0->s0 + y0->s1) + x1->m * (y1->s0 + y1->s1);
2577
3042
 
2578
- const uint8x16_t m4b = vdupq_n_u8(0xf);
3043
+ const uint8x16_t m4b = vdupq_n_u8(0x0F);
2579
3044
 
2580
3045
  const uint8x16_t v0_0 = vld1q_u8(x0->qs);
2581
3046
  const uint8x16_t v0_1 = vld1q_u8(x1->qs);
@@ -2586,46 +3051,35 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
2586
3051
  const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
2587
3052
  const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
2588
3053
 
3054
+ // interleave
3055
+ const int8x16_t v0_0lz = vzip1q_s8(v0_0l, v0_0h);
3056
+ const int8x16_t v0_0hz = vzip2q_s8(v0_0l, v0_0h);
3057
+ const int8x16_t v0_1lz = vzip1q_s8(v0_1l, v0_1h);
3058
+ const int8x16_t v0_1hz = vzip2q_s8(v0_1l, v0_1h);
3059
+
2589
3060
  // load y
2590
3061
  const int8x16_t v1_0l = vld1q_s8(y0->qs);
2591
3062
  const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
2592
3063
  const int8x16_t v1_1l = vld1q_s8(y1->qs);
2593
3064
  const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
2594
3065
 
2595
- // interleave
2596
- const int8x16_t v1_0ls = vuzp1q_s8(v1_0l, v1_0h);
2597
- const int8x16_t v1_0hs = vuzp2q_s8(v1_0l, v1_0h);
2598
- const int8x16_t v1_1ls = vuzp1q_s8(v1_1l, v1_1h);
2599
- const int8x16_t v1_1hs = vuzp2q_s8(v1_1l, v1_1h);
2600
-
2601
- const int16x8_t s0i = vaddq_s16(
2602
- vaddq_s16(vmovl_s8(vget_low_s8(v1_0ls)), vmovl_s8(vget_high_s8(v1_0ls))),
2603
- vaddq_s16(vmovl_s8(vget_low_s8(v1_0hs)), vmovl_s8(vget_high_s8(v1_0hs))));
2604
-
2605
- const int16x8_t s1i = vaddq_s16(
2606
- vaddq_s16(vmovl_s8(vget_low_s8(v1_1ls)), vmovl_s8(vget_high_s8(v1_1ls))),
2607
- vaddq_s16(vmovl_s8(vget_low_s8(v1_1hs)), vmovl_s8(vget_high_s8(v1_1hs))));
2608
-
2609
- sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddl_s16(vget_low_s16(s0i), vget_high_s16(s0i))), x0->m*y0->d);
2610
- sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddl_s16(vget_low_s16(s1i), vget_high_s16(s1i))), x1->m*y1->d);
2611
-
2612
3066
  #if defined(__ARM_FEATURE_DOTPROD)
2613
3067
  // dot product into int32x4_t
2614
- const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0ls), v0_0h, v1_0hs);
2615
- const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1ls), v0_1h, v1_1hs);
3068
+ const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0lz, v1_0l), v0_0hz, v1_0h);
3069
+ const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1lz, v1_1l), v0_1hz, v1_1h);
2616
3070
 
2617
3071
  sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), x0->d*y0->d);
2618
3072
  sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), x1->d*y1->d);
2619
3073
  #else
2620
- const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0l), vget_low_s8 (v1_0ls));
2621
- const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0l), vget_high_s8(v1_0ls));
2622
- const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0h), vget_low_s8 (v1_0hs));
2623
- const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0h), vget_high_s8(v1_0hs));
3074
+ const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lz), vget_low_s8 (v1_0l));
3075
+ const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lz), vget_high_s8(v1_0l));
3076
+ const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hz), vget_low_s8 (v1_0h));
3077
+ const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hz), vget_high_s8(v1_0h));
2624
3078
 
2625
- const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1l), vget_low_s8 (v1_1ls));
2626
- const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1l), vget_high_s8(v1_1ls));
2627
- const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1h), vget_low_s8 (v1_1hs));
2628
- const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1h), vget_high_s8(v1_1hs));
3079
+ const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lz), vget_low_s8 (v1_1l));
3080
+ const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lz), vget_high_s8(v1_1l));
3081
+ const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hz), vget_low_s8 (v1_1h));
3082
+ const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hz), vget_high_s8(v1_1h));
2629
3083
 
2630
3084
  const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
2631
3085
  const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
@@ -2637,65 +3091,40 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
2637
3091
  #endif
2638
3092
  }
2639
3093
 
2640
- sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
3094
+ *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs;
2641
3095
  #elif defined(__AVX2__)
2642
3096
  // Initialize accumulator with zeros
2643
3097
  __m256 acc = _mm256_setzero_ps();
2644
3098
 
3099
+ float summs = 0;
3100
+
2645
3101
  // Main loop
2646
3102
  for (int i = 0; i < nb; ++i) {
2647
3103
  const float * d0 = &x[i].d;
2648
3104
  const float * d1 = &y[i].d;
2649
- const float * m0 = &x[i].m;
3105
+
3106
+ summs += x[i].m * (y[i].s0 + y[i].s1);
2650
3107
 
2651
3108
  const __m256 d0v = _mm256_broadcast_ss( d0 );
2652
3109
  const __m256 d1v = _mm256_broadcast_ss( d1 );
2653
- const __m256 m0v = _mm256_broadcast_ss( m0 );
2654
3110
 
2655
3111
  // Compute combined scales
2656
3112
  const __m256 d0d1 = _mm256_mul_ps( d0v, d1v );
2657
- const __m256 d1m0 = _mm256_mul_ps( d1v, m0v );
2658
3113
 
2659
3114
  // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
2660
3115
  const __m256i bx = bytes_from_nibbles_32(x[i].qs);
2661
3116
  const __m256i by = _mm256_loadu_si256( (const __m256i *)y[i].qs );
2662
3117
 
2663
- // Get absolute values of x vectors
2664
- const __m256i ax = _mm256_sign_epi8( bx, bx );
2665
-
2666
- // Sign the values of the y vectors
2667
- const __m256i sy = _mm256_sign_epi8( by, bx );
2668
-
2669
- // Perform multiplication and create 16-bit values
2670
- const __m256i dot = _mm256_maddubs_epi16( ax, sy );
2671
- const __m256i ones = _mm256_set1_epi16( 1 );
2672
- const __m256i xy_q = _mm256_madd_epi16( ones, dot );
2673
-
2674
- // Convert to vector of 8 int32_t to 8 floats
2675
- const __m256 xy = _mm256_cvtepi32_ps( xy_q );
3118
+ const __m256 xy = mul_sum_i8_pairs_float(bx, by);
2676
3119
 
2677
3120
  // Accumulate d0*d1*x*y
2678
3121
  acc = _mm256_fmadd_ps( d0d1, xy, acc );
2679
-
2680
- // Compute sum of y values
2681
- const __m256i y16_l = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( by ) );
2682
- const __m256i y16_h = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( by, 1 ) );
2683
- const __m256i ysumi = _mm256_madd_epi16( _mm256_add_epi16(y16_l, y16_h), ones );
2684
- const __m256 ysum = _mm256_cvtepi32_ps( ysumi );
2685
-
2686
- // Accumulate d1*m0*y
2687
- acc = _mm256_fmadd_ps( d1m0, ysum, acc );
2688
3122
  }
2689
3123
 
2690
- // Return horizontal sum of the acc vector
2691
- __m128 res = _mm256_extractf128_ps( acc, 1 );
2692
- res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
2693
- res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
2694
- res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
2695
-
2696
- sumf = _mm_cvtss_f32( res );
3124
+ *s = hsum_float_8(acc) + summs;
2697
3125
  #else
2698
3126
  // scalar
3127
+ float sumf = 0.0;
2699
3128
  for (int i = 0; i < nb; i++) {
2700
3129
  const float d0 = x[i].d;
2701
3130
  const float m0 = x[i].m;
@@ -2705,347 +3134,685 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
2705
3134
  const int8_t * restrict p1 = y[i].qs;
2706
3135
 
2707
3136
  // TODO: this is very slow ..
2708
- for (int j = 0; j < QK8_0/2; j++) {
3137
+ for (int j = 0; j < QK8_1/2; j++) {
2709
3138
  const uint8_t v0 = p0[j];
2710
3139
 
2711
- const float f0 = d0*(v0 & 0xf) + m0;
2712
- const float f1 = d0*(v0 >> 4) + m0;
3140
+ const float f0 = d0*(v0 & 0x0F) + m0;
3141
+ const float f1 = d0*(v0 >> 4) + m0;
3142
+
3143
+ const float f2 = d1*p1[2*j + 0];
3144
+ const float f3 = d1*p1[2*j + 1];
3145
+
3146
+ sumf += f0*f2 + f1*f3;
3147
+ }
3148
+ }
3149
+ *s = sumf;
3150
+ #endif
3151
+ }
3152
+
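Why the q4_1 x q8_1 pairing above can hoist the offset out of the inner loop: with x_j = d*q_j + m and y_j = dy*p_j, the block dot product splits into d*dy*sum(q_j*p_j) + m*(dy*sum(p_j)), and the second factor is exactly what the q8_1 block pre-stores as s0 + s1 (per-half sums already scaled by dy, judging from how summs is accumulated above). A minimal scalar sketch of the same split, with a hypothetical helper name:

static float example_dot_q4_1_block(float d, float m, const uint8_t * qs4,
                                    float dy, const int8_t * qs8, float s0_plus_s1) {
    int sumi = 0;
    for (int j = 0; j < 16; j++) {
        sumi += (qs4[j] & 0x0F)*qs8[2*j + 0] + (qs4[j] >> 4)*qs8[2*j + 1];
    }
    // d*dy*sum(q*p) from the quants, plus m*dy*sum(p) folded in through s0 + s1
    return d*dy*(float)sumi + m*s0_plus_s1;
}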
3153
+ static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
3154
+ const int nb = n / QK8_0;
3155
+
3156
+ assert(n % QK8_0 == 0);
3157
+ assert(nb % 2 == 0);
3158
+ assert(QK8_0 == 2*QK4_2);
3159
+
3160
+ const block_q4_2 * restrict x = vx;
3161
+ const block_q8_0 * restrict y = vy;
3162
+
3163
+ #if defined(__ARM_NEON)
3164
+ float32x4_t sumv0 = vdupq_n_f32(0.0f);
3165
+ float32x4_t sumv1 = vdupq_n_f32(0.0f);
3166
+
3167
+ for (int i = 0; i < nb; i += 2) {
3168
+ const block_q4_2 * restrict x0_0 = &x[2*(i + 0) + 0];
3169
+ const block_q4_2 * restrict x0_1 = &x[2*(i + 0) + 1];
3170
+ const block_q4_2 * restrict x1_0 = &x[2*(i + 1) + 0];
3171
+ const block_q4_2 * restrict x1_1 = &x[2*(i + 1) + 1];
3172
+
3173
+ const block_q8_0 * restrict y0 = &y[i + 0];
3174
+ const block_q8_0 * restrict y1 = &y[i + 1];
3175
+
3176
+ const uint8x16_t m4b = vdupq_n_u8(0x0F);
3177
+ const int8x16_t s8b = vdupq_n_s8(0x8);
3178
+
3179
+ const uint8x16_t v0_0 = vcombine_u8(vld1_u8(x0_0->qs), vld1_u8(x0_1->qs));
3180
+ const uint8x16_t v0_1 = vcombine_u8(vld1_u8(x1_0->qs), vld1_u8(x1_1->qs));
3181
+
3182
+ // 4-bit -> 8-bit
3183
+ const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
3184
+ const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
3185
+ const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
3186
+ const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
3187
+
3188
+ // sub 8
3189
+ const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
3190
+ const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
3191
+ const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
3192
+ const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
3193
+
3194
+ // interleave
3195
+ const int8x16_t v0_0lz = vzip1q_s8(v0_0ls, v0_0hs);
3196
+ const int8x16_t v0_0hz = vzip2q_s8(v0_0ls, v0_0hs);
3197
+ const int8x16_t v0_1lz = vzip1q_s8(v0_1ls, v0_1hs);
3198
+ const int8x16_t v0_1hz = vzip2q_s8(v0_1ls, v0_1hs);
3199
+
3200
+ // load y
3201
+ const int8x16_t v1_0l = vld1q_s8(y0->qs);
3202
+ const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
3203
+ const int8x16_t v1_1l = vld1q_s8(y1->qs);
3204
+ const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
3205
+
3206
+ #if defined(__ARM_FEATURE_DOTPROD)
3207
+ sumv0 = vmlaq_n_f32(sumv0, vaddq_f32(
3208
+ vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0lz, v1_0l)), GGML_FP16_TO_FP32(x0_0->d)),
3209
+ vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0hz, v1_0h)), GGML_FP16_TO_FP32(x0_1->d))), y0->d);
3210
+
3211
+ sumv1 = vmlaq_n_f32(sumv1, vaddq_f32(
3212
+ vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_1lz, v1_1l)), GGML_FP16_TO_FP32(x1_0->d)),
3213
+ vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_1hz, v1_1h)), GGML_FP16_TO_FP32(x1_1->d))), y1->d);
3214
+ #else
3215
+ const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lz), vget_low_s8 (v1_0l));
3216
+ const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lz), vget_high_s8(v1_0l));
3217
+ const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hz), vget_low_s8 (v1_0h));
3218
+ const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hz), vget_high_s8(v1_0h));
3219
+
3220
+ const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lz), vget_low_s8 (v1_1l));
3221
+ const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lz), vget_high_s8(v1_1l));
3222
+ const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hz), vget_low_s8 (v1_1h));
3223
+ const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hz), vget_high_s8(v1_1h));
3224
+
3225
+ const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
3226
+ const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
3227
+ const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
3228
+ const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
3229
+
3230
+ sumv0 = vmlaq_n_f32(sumv0, vaddq_f32(
3231
+ vmulq_n_f32(vcvtq_f32_s32(pl0), GGML_FP16_TO_FP32(x0_0->d)),
3232
+ vmulq_n_f32(vcvtq_f32_s32(ph0), GGML_FP16_TO_FP32(x0_1->d))), y0->d);
3233
+
3234
+ sumv1 = vmlaq_n_f32(sumv1, vaddq_f32(
3235
+ vmulq_n_f32(vcvtq_f32_s32(pl1), GGML_FP16_TO_FP32(x1_0->d)),
3236
+ vmulq_n_f32(vcvtq_f32_s32(ph1), GGML_FP16_TO_FP32(x1_1->d))), y1->d);
3237
+ #endif
3238
+ }
3239
+
3240
+ *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
3241
+ #elif defined(__AVX2__)
3242
+ // Initialize accumulator with zeros
3243
+ __m256 acc = _mm256_setzero_ps();
3244
+
3245
+ // Main loop
3246
+ for (int i = 0; i < nb; i++) {
3247
+ /* Compute combined scale for the block */
3248
+ const __m128 d0 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 0].d));
3249
+ const __m128 d1 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 1].d));
3250
+ const __m256 d = _mm256_mul_ps(_mm256_set_m128(d1, d0), _mm256_broadcast_ss(&y[i].d));
3251
+
3252
+ __m128i bx0 = bytes_from_nibbles_16(x[2*i + 0].qs);
3253
+ __m128i bx1 = bytes_from_nibbles_16(x[2*i + 1].qs);
3254
+ __m256i bx = _mm256_set_m128i(bx1, bx0);
3255
+
3256
+ // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
3257
+ const __m256i off = _mm256_set1_epi8(8);
3258
+ bx = _mm256_sub_epi8(bx, off);
3259
+
3260
+ __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
3261
+
3262
+ const __m256 q = mul_sum_i8_pairs_float(bx, by);
3263
+
3264
+ /* Multiply q with scale and accumulate */
3265
+ acc = _mm256_fmadd_ps(d, q, acc);
3266
+ }
3267
+
3268
+ *s = hsum_float_8(acc);
3269
+ #else
3270
+ // scalar
3271
+ float sumf = 0.0;
3272
+ for (int i = 0; i < nb; i++) {
3273
+ const uint8_t * restrict x0 = x[2*i + 0].qs;
3274
+ const uint8_t * restrict x1 = x[2*i + 1].qs;
3275
+ const int8_t * restrict y0 = y[i].qs;
3276
+
3277
+ const float d0 = GGML_FP16_TO_FP32(x[2*i + 0].d);
3278
+ const float d1 = GGML_FP16_TO_FP32(x[2*i + 1].d);
3279
+
3280
+ int sumi_0 = 0;
3281
+ int sumi_1 = 0;
3282
+
3283
+ for (int j = 0; j < QK8_0/4; j++) {
3284
+ const uint8_t v0 = x0[j];
3285
+ const uint8_t v1 = x1[j];
3286
+
3287
+ const int i0_0 = (int8_t) (v0 & 0x0F) - 8;
3288
+ const int i1_0 = (int8_t) (v0 >> 4) - 8;
3289
+
3290
+ const int i0_1 = (int8_t) (v1 & 0x0F) - 8;
3291
+ const int i1_1 = (int8_t) (v1 >> 4) - 8;
3292
+
3293
+ const int i2_0 = y0[2*j + 0];
3294
+ const int i3_0 = y0[2*j + 1];
3295
+
3296
+ const int i2_1 = y0[2*(j + QK8_0/4) + 0];
3297
+ const int i3_1 = y0[2*(j + QK8_0/4) + 1];
3298
+
3299
+ sumi_0 += i0_0*i2_0 + i1_0*i3_0;
3300
+ sumi_1 += i0_1*i2_1 + i1_1*i3_1;
3301
+ }
3302
+
3303
+ sumf += (d0 * y[i].d) * sumi_0;
3304
+ sumf += (d1 * y[i].d) * sumi_1;
3305
+ }
3306
+ *s = sumf;
3307
+ #endif
3308
+ }
3309
+
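A hedged restatement of the scalar q4_2 path above, since the y0[2*(j + QK8_0/4)] indexing is easy to misread: each block_q4_2 carries 16 quants with its own fp16 scale, so two consecutive q4_2 blocks cover one 32-wide q8_0 block, the first pairing with y[0..15] and the second with y[16..31]. Illustrative helper (hypothetical name, scales passed in already converted to f32):

static float example_dot_q4_2_pair(float d0, const uint8_t * qs0,
                                   float d1, const uint8_t * qs1,
                                   float dy, const int8_t * qy) {
    int s0 = 0, s1 = 0;
    for (int j = 0; j < 8; j++) {  // 8 bytes = 16 nibbles per sub-block
        s0 += ((qs0[j] & 0x0F) - 8)*qy[2*j + 0]      + ((qs0[j] >> 4) - 8)*qy[2*j + 1];
        s1 += ((qs1[j] & 0x0F) - 8)*qy[16 + 2*j + 0] + ((qs1[j] >> 4) - 8)*qy[16 + 2*j + 1];
    }
    return dy*(d0*(float)s0 + d1*(float)s1);
}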
3310
+ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
3311
+ const int nb = n / QK8_0;
3312
+
3313
+ assert(n % QK8_0 == 0);
3314
+ assert(nb % 2 == 0);
3315
+ assert(QK8_0 == QK5_0);
3316
+
3317
+ const block_q5_0 * restrict x = vx;
3318
+ const block_q8_0 * restrict y = vy;
3319
+
3320
+ #if defined(__ARM_NEON)
3321
+ float32x4_t sumv = vdupq_n_f32(0.0f);
3322
+
3323
+ uint64_t tmp[4];
3324
+
3325
+ for (int i = 0; i < nb; ++i) {
3326
+ const block_q5_0 * restrict x0 = &x[i];
3327
+ const block_q8_0 * restrict y0 = &y[i];
3328
+
3329
+ const uint8x16_t m4b = vdupq_n_u8(0x0F);
3330
+ const int8x16_t s16b = vdupq_n_s8(0x10);
3331
+
3332
+ // extract the 5th bit
3333
+ uint32_t qh;
3334
+ memcpy(&qh, x0->qh, sizeof(qh));
3335
+
3336
+ tmp[0] = table_b2b_u[(qh >> 0) & 0xFF];
3337
+ tmp[1] = table_b2b_u[(qh >> 8) & 0xFF];
3338
+ tmp[2] = table_b2b_u[(qh >> 16) & 0xFF];
3339
+ tmp[3] = table_b2b_u[(qh >> 24) ];
3340
+
3341
+ const int8x16_t qhl = vld1q_s8((const int8_t *)(tmp + 0));
3342
+ const int8x16_t qhh = vld1q_s8((const int8_t *)(tmp + 2));
3343
+
3344
+ const uint8x16_t v0 = vld1q_u8(x0->qs);
3345
+
3346
+ // 4-bit -> 8-bit
3347
+ const int8x16_t v0l = vreinterpretq_s8_u8(vandq_u8 (v0, m4b));
3348
+ const int8x16_t v0h = vreinterpretq_s8_u8(vshrq_n_u8(v0, 4));
3349
+
3350
+ // interleave
3351
+ const int8x16_t v0lz = vzip1q_s8(v0l, v0h);
3352
+ const int8x16_t v0hz = vzip2q_s8(v0l, v0h);
3353
+
3354
+ // add high bit and sub 16
3355
+ const int8x16_t v0lf = vsubq_s8(vorrq_s8(v0lz, qhl), s16b);
3356
+ const int8x16_t v0hf = vsubq_s8(vorrq_s8(v0hz, qhh), s16b);
3357
+
3358
+ // load y
3359
+ const int8x16_t v1l = vld1q_s8(y0->qs);
3360
+ const int8x16_t v1h = vld1q_s8(y0->qs + 16);
3361
+
3362
+ const float x0d = GGML_FP16_TO_FP32(x0->d);
3363
+
3364
+ #if defined(__ARM_FEATURE_DOTPROD)
3365
+ sumv = vmlaq_n_f32(sumv, vcvtq_f32_s32(vaddq_s32(
3366
+ vdotq_s32(vdupq_n_s32(0), v0lf, v1l),
3367
+ vdotq_s32(vdupq_n_s32(0), v0hf, v1h))), x0d*y0->d);
3368
+ #else
3369
+ const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0lf), vget_low_s8 (v1l));
3370
+ const int16x8_t pl0h = vmull_s8(vget_high_s8(v0lf), vget_high_s8(v1l));
3371
+ const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0hf), vget_low_s8 (v1h));
3372
+ const int16x8_t ph0h = vmull_s8(vget_high_s8(v0hf), vget_high_s8(v1h));
3373
+
3374
+ const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
3375
+ const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
3376
+
3377
+ sumv = vmlaq_n_f32(sumv, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), x0d*y0->d);
3378
+ #endif
3379
+ }
3380
+
3381
+ *s = vaddvq_f32(sumv);
3382
+ #elif defined(__wasm_simd128__)
3383
+ v128_t sumv = wasm_f32x4_splat(0.0f);
3384
+
3385
+ uint64_t tmp[4];
3386
+
3387
+ for (int i = 0; i < nb; ++i) {
3388
+ const block_q5_0 * restrict x0 = &x[i];
3389
+ const block_q8_0 * restrict y0 = &y[i];
3390
+
3391
+ const v128_t m4b = wasm_i8x16_splat(0x0F);
3392
+ const v128_t s16b = wasm_i8x16_splat(0x10);
3393
+
3394
+ // extract the 5th bit
3395
+ uint32_t qh;
3396
+ memcpy(&qh, x0->qh, sizeof(qh));
3397
+
3398
+ tmp[0] = table_b2b_u[(qh >> 0) & 0xFF];
3399
+ tmp[1] = table_b2b_u[(qh >> 8) & 0xFF];
3400
+ tmp[2] = table_b2b_u[(qh >> 16) & 0xFF];
3401
+ tmp[3] = table_b2b_u[(qh >> 24) ];
3402
+
3403
+ const v128_t qhl = wasm_v128_load(tmp + 0);
3404
+ const v128_t qhh = wasm_v128_load(tmp + 2);
3405
+
3406
+ const v128_t v0 = wasm_v128_load(x0->qs);
3407
+
3408
+ // 4-bit -> 8-bit
3409
+ const v128_t v0l = wasm_v128_and (v0, m4b);
3410
+ const v128_t v0h = wasm_u8x16_shr(v0, 4);
3411
+
3412
+ // interleave
3413
+ const v128_t v0lz = wasm_v8x16_shuffle(v0l, v0h, 0, 16, 1, 17, 2, 18, 3, 19, 4, 20, 5, 21, 6, 22, 7, 23);
3414
+ const v128_t v0hz = wasm_v8x16_shuffle(v0l, v0h, 8, 24, 9, 25, 10, 26, 11, 27, 12, 28, 13, 29, 14, 30, 15, 31);
3415
+
3416
+ // add high bit and sub 16
3417
+ const v128_t v0lf = wasm_i8x16_sub(wasm_v128_or(v0lz, qhl), s16b);
3418
+ const v128_t v0hf = wasm_i8x16_sub(wasm_v128_or(v0hz, qhh), s16b);
3419
+
3420
+ // load y
3421
+ const v128_t v1l = wasm_v128_load(y0->qs);
3422
+ const v128_t v1h = wasm_v128_load(y0->qs + 16);
3423
+
3424
+ // int8x16 -> int16x8
3425
+ const v128_t v0lfl = wasm_i16x8_extend_low_i8x16 (v0lf);
3426
+ const v128_t v0lfh = wasm_i16x8_extend_high_i8x16(v0lf);
3427
+ const v128_t v0hfl = wasm_i16x8_extend_low_i8x16 (v0hf);
3428
+ const v128_t v0hfh = wasm_i16x8_extend_high_i8x16(v0hf);
3429
+
3430
+ const v128_t v1ll = wasm_i16x8_extend_low_i8x16 (v1l);
3431
+ const v128_t v1lh = wasm_i16x8_extend_high_i8x16(v1l);
3432
+ const v128_t v1hl = wasm_i16x8_extend_low_i8x16 (v1h);
3433
+ const v128_t v1hh = wasm_i16x8_extend_high_i8x16(v1h);
3434
+
3435
+ const float x0d = GGML_FP16_TO_FP32(x0->d);
3436
+
3437
+ // dot product
3438
+ sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4(
3439
+ wasm_i32x4_add(
3440
+ wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0lfl, v1ll),
3441
+ wasm_i32x4_dot_i16x8(v0lfh, v1lh)),
3442
+ wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0hfl, v1hl),
3443
+ wasm_i32x4_dot_i16x8(v0hfh, v1hh)))), wasm_f32x4_splat(x0d*y0->d)));
3444
+ }
3445
+
3446
+ *s = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) +
3447
+ wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3);
3448
+ #elif defined(__AVX2__)
3449
+ // Initialize accumulator with zeros
3450
+ __m256 acc = _mm256_setzero_ps();
3451
+
3452
+ // Main loop
3453
+ for (int i = 0; i < nb; i++) {
3454
+ /* Compute combined scale for the block */
3455
+ const __m256 d = _mm256_mul_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d)), _mm256_broadcast_ss(&y[i].d));
3456
+
3457
+ __m256i bx = bytes_from_nibbles_32(x[i].qs);
3458
+ __m256i bxhi = bytes_from_bits_32(x[i].qh);
3459
+ bxhi = _mm256_andnot_si256(bxhi, _mm256_set1_epi8((char)0xF0));
3460
+ bx = _mm256_or_si256(bx, bxhi);
3461
+
3462
+ __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
3463
+
3464
+ const __m256 q = mul_sum_i8_pairs_float(bx, by);
3465
+
3466
+ /* Multiply q with scale and accumulate */
3467
+ acc = _mm256_fmadd_ps(d, q, acc);
3468
+ }
3469
+
3470
+ *s = hsum_float_8(acc);
3471
+ #else
3472
+ // scalar
3473
+ float sumf = 0.0;
3474
+ for (int i = 0; i < nb; i++) {
3475
+ const uint8_t * restrict x0 = x[i].qs;
3476
+ const int8_t * restrict y0 = y[i].qs;
3477
+
3478
+ uint32_t qh;
3479
+ memcpy(&qh, x[i].qh, sizeof(qh));
3480
+
3481
+ const float d = GGML_FP16_TO_FP32(x[i].d);
3482
+
3483
+ int sxy = 0;
3484
+
3485
+ for (int j = 0; j < QK8_0/2; j++) {
3486
+ const uint8_t v0 = x0[j];
3487
+
3488
+ const int x0_0h = ((qh & (1u << (2*j + 0))) >> (2*j + 0)) << 4;
3489
+ const int x1_0h = ((qh & (1u << (2*j + 1))) >> (2*j + 1)) << 4;
3490
+
3491
+ const int x0_0 = ((v0 & 0x0F) | x0_0h) - 16;
3492
+ const int x1_0 = ((v0 >> 4) | x1_0h) - 16;
3493
+
3494
+ const int y0_0 = y0[2*j + 0];
3495
+ const int y1_0 = y0[2*j + 1];
3496
+
3497
+ sxy += x0_0*y0_0 + x1_0*y1_0;
3498
+ }
3499
+
3500
+ sumf += (d*sxy)*y[i].d;
3501
+ }
3502
+ *s = sumf;
3503
+ #endif
3504
+ }
3505
+
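What the table_b2b_u lookups in the q5 kernels compute, stated as plain C: each of the 8 input bits becomes one output byte, 0x10 when the bit is set and 0x00 otherwise, so on a little-endian target the four lookups spread the 32 qh bits across 32 byte lanes that can be OR-ed onto the low nibbles. A scalar equivalent (hypothetical helper, shown only to document what the precomputed entries are expected to hold):

static uint64_t example_expand_bits_to_bytes(uint8_t b) {
    uint64_t r = 0;
    for (int i = 0; i < 8; ++i) {
        r |= (uint64_t)(((b >> i) & 1u) ? 0x10 : 0x00) << (8*i);  // bit i -> byte i
    }
    return r;
}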
3506
+ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
3507
+ const int nb = n / QK8_1;
3508
+
3509
+ assert(n % QK8_1 == 0);
3510
+ assert(nb % 2 == 0);
3511
+ assert(QK8_1 == QK5_1);
3512
+
3513
+ const block_q5_1 * restrict x = vx;
3514
+ const block_q8_1 * restrict y = vy;
3515
+
3516
+ #if defined(__ARM_NEON)
3517
+ float32x4_t sumv = vdupq_n_f32(0.0f);
3518
+
3519
+ float summs = 0.0f;
3520
+
3521
+ uint64_t tmp[4];
3522
+
3523
+ for (int i = 0; i < nb; ++i) {
3524
+ const block_q5_1 * restrict x0 = &x[i];
3525
+ const block_q8_1 * restrict y0 = &y[i];
3526
+
3527
+ summs += GGML_FP16_TO_FP32(x0->m) * (y0->s0 + y0->s1);
3528
+
3529
+ // extract the 5th bit
3530
+ uint32_t qh;
3531
+ memcpy(&qh, x0->qh, sizeof(qh));
3532
+
3533
+ tmp[0] = table_b2b_u[(qh >> 0) & 0xFF];
3534
+ tmp[1] = table_b2b_u[(qh >> 8) & 0xFF];
3535
+ tmp[2] = table_b2b_u[(qh >> 16) & 0xFF];
3536
+ tmp[3] = table_b2b_u[(qh >> 24) ];
3537
+
3538
+ const int8x16_t qhl = vld1q_s8((const int8_t *)(tmp + 0));
3539
+ const int8x16_t qhh = vld1q_s8((const int8_t *)(tmp + 2));
3540
+
3541
+ const uint8x16_t v0 = vld1q_u8(x0->qs);
3542
+
3543
+ // 4-bit -> 8-bit
3544
+ const int8x16_t v0l = vreinterpretq_s8_u8(vandq_u8 (v0, vdupq_n_u8(0x0F)));
3545
+ const int8x16_t v0h = vreinterpretq_s8_u8(vshrq_n_u8(v0, 4));
3546
+
3547
+ // interleave
3548
+ const int8x16_t v0lz = vzip1q_s8(v0l, v0h);
3549
+ const int8x16_t v0hz = vzip2q_s8(v0l, v0h);
3550
+
3551
+ // add
3552
+ const int8x16_t v0lf = vorrq_s8(v0lz, qhl);
3553
+ const int8x16_t v0hf = vorrq_s8(v0hz, qhh);
3554
+
3555
+ // load y
3556
+ const int8x16_t v1l = vld1q_s8(y0->qs);
3557
+ const int8x16_t v1h = vld1q_s8(y0->qs + 16);
3558
+
3559
+ const float x0d = GGML_FP16_TO_FP32(x0->d);
2713
3560
 
2714
- const float f2 = d1*p1[2*j + 0];
2715
- const float f3 = d1*p1[2*j + 1];
3561
+ #if defined(__ARM_FEATURE_DOTPROD)
3562
+ sumv = vmlaq_n_f32(sumv, vcvtq_f32_s32(vaddq_s32(
3563
+ vdotq_s32(vdupq_n_s32(0), v0lf, v1l),
3564
+ vdotq_s32(vdupq_n_s32(0), v0hf, v1h))), x0d*y0->d);
3565
+ #else
3566
+ const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0lf), vget_low_s8 (v1l));
3567
+ const int16x8_t pl0h = vmull_s8(vget_high_s8(v0lf), vget_high_s8(v1l));
3568
+ const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0hf), vget_low_s8 (v1h));
3569
+ const int16x8_t ph0h = vmull_s8(vget_high_s8(v0hf), vget_high_s8(v1h));
2716
3570
 
2717
- sumf += f0*f2 + f1*f3;
2718
- }
2719
- }
3571
+ const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
3572
+ const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
3573
+
3574
+ sumv = vmlaq_n_f32(sumv, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), x0d*y0->d);
2720
3575
  #endif
3576
+ }
2721
3577
 
2722
- *s = sumf;
2723
- }
3578
+ *s = vaddvq_f32(sumv) + summs;
3579
+ #elif defined(__wasm_simd128__)
3580
+ v128_t sumv = wasm_f32x4_splat(0.0f);
2724
3581
 
2725
- static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
2726
- const int nb = n / QK8_0;
3582
+ float summs = 0.0f;
2727
3583
 
2728
- assert(n % QK8_0 == 0);
2729
- assert(nb % 2 == 0);
2730
- assert(QK8_0 == 2*QK4_2);
3584
+ uint64_t tmp[4];
2731
3585
 
2732
- const block_q4_2 * restrict x = vx;
2733
- const block_q8_0 * restrict y = vy;
3586
+ for (int i = 0; i < nb; ++i) {
3587
+ const block_q5_1 * restrict x0 = &x[i];
3588
+ const block_q8_1 * restrict y0 = &y[i];
2734
3589
 
2735
- float sumf = 0.0;
3590
+ summs += GGML_FP16_TO_FP32(x0->m) * (y0->s0 + y0->s1);
2736
3591
 
2737
- #if defined(__ARM_NEON)
2738
- float32x4_t sumv0 = vdupq_n_f32(0.0f);
2739
- float32x4_t sumv1 = vdupq_n_f32(0.0f);
3592
+ const v128_t m4b = wasm_i8x16_splat(0x0F);
2740
3593
 
2741
- for (int i = 0; i < nb; i += 2) {
2742
- const block_q4_2 * restrict x0_0 = &x[2*(i + 0) + 0];
2743
- const block_q4_2 * restrict x0_1 = &x[2*(i + 0) + 1];
2744
- const block_q4_2 * restrict x1_0 = &x[2*(i + 1) + 0];
2745
- const block_q4_2 * restrict x1_1 = &x[2*(i + 1) + 1];
3594
+ // extract the 5th bit
3595
+ uint32_t qh;
3596
+ memcpy(&qh, x0->qh, sizeof(qh));
2746
3597
 
2747
- const block_q8_0 * restrict y0 = &y[i + 0];
2748
- const block_q8_0 * restrict y1 = &y[i + 1];
3598
+ tmp[0] = table_b2b_u[(qh >> 0) & 0xFF];
3599
+ tmp[1] = table_b2b_u[(qh >> 8) & 0xFF];
3600
+ tmp[2] = table_b2b_u[(qh >> 16) & 0xFF];
3601
+ tmp[3] = table_b2b_u[(qh >> 24) ];
2749
3602
 
2750
- const uint8x16_t m4b = vdupq_n_u8(0xf);
2751
- const int8x16_t s8b = vdupq_n_s8(0x8);
3603
+ const v128_t qhl = wasm_v128_load(tmp + 0);
3604
+ const v128_t qhh = wasm_v128_load(tmp + 2);
2752
3605
 
2753
- const uint8x16_t v0_0 = vcombine_u8(vld1_u8(x0_0->qs), vld1_u8(x0_1->qs));
2754
- const uint8x16_t v0_1 = vcombine_u8(vld1_u8(x1_0->qs), vld1_u8(x1_1->qs));
3606
+ const v128_t v0 = wasm_v128_load(x0->qs);
2755
3607
 
2756
3608
  // 4-bit -> 8-bit
2757
- const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
2758
- const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
2759
- const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
2760
- const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
3609
+ const v128_t v0l = wasm_v128_and (v0, m4b);
3610
+ const v128_t v0h = wasm_u8x16_shr(v0, 4);
2761
3611
 
2762
- // sub 8
2763
- const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
2764
- const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
2765
- const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
2766
- const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
3612
+ static bool x = true;
2767
3613
 
2768
3614
  // interleave
2769
- const int8x16_t v0_0lz = vzip1q_s8(v0_0ls, v0_0hs);
2770
- const int8x16_t v0_0hz = vzip2q_s8(v0_0ls, v0_0hs);
2771
- const int8x16_t v0_1lz = vzip1q_s8(v0_1ls, v0_1hs);
2772
- const int8x16_t v0_1hz = vzip2q_s8(v0_1ls, v0_1hs);
2773
-
2774
- // load y
2775
- const int8x16_t v1_0l = vld1q_s8(y0->qs);
2776
- const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
2777
- const int8x16_t v1_1l = vld1q_s8(y1->qs);
2778
- const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
3615
+ const v128_t v0lz = wasm_v8x16_shuffle(v0l, v0h, 0, 16, 1, 17, 2, 18, 3, 19, 4, 20, 5, 21, 6, 22, 7, 23);
3616
+ const v128_t v0hz = wasm_v8x16_shuffle(v0l, v0h, 8, 24, 9, 25, 10, 26, 11, 27, 12, 28, 13, 29, 14, 30, 15, 31);
2779
3617
 
2780
- #if defined(__ARM_FEATURE_DOTPROD)
2781
- sumv0 = vmlaq_n_f32(sumv0, vaddq_f32(
2782
- vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0lz, v1_0l)), GGML_FP16_TO_FP32(x0_0->d)),
2783
- vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0hz, v1_0h)), GGML_FP16_TO_FP32(x0_1->d))), y0->d);
3618
+ // add high bit
3619
+ const v128_t v0lf = wasm_v128_or(v0lz, qhl);
3620
+ const v128_t v0hf = wasm_v128_or(v0hz, qhh);
2784
3621
 
2785
- sumv1 = vmlaq_n_f32(sumv1, vaddq_f32(
2786
- vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_1lz, v1_1l)), GGML_FP16_TO_FP32(x1_0->d)),
2787
- vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_1hz, v1_1h)), GGML_FP16_TO_FP32(x1_1->d))), y1->d);
2788
- #else
2789
- const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lz), vget_low_s8 (v1_0l));
2790
- const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lz), vget_high_s8(v1_0l));
2791
- const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hz), vget_low_s8 (v1_0h));
2792
- const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hz), vget_high_s8(v1_0h));
3622
+ // load y
3623
+ const v128_t v1l = wasm_v128_load(y0->qs);
3624
+ const v128_t v1h = wasm_v128_load(y0->qs + 16);
2793
3625
 
2794
- const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lz), vget_low_s8 (v1_1l));
2795
- const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lz), vget_high_s8(v1_1l));
2796
- const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hz), vget_low_s8 (v1_1h));
2797
- const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hz), vget_high_s8(v1_1h));
3626
+ // int8x16 -> int16x8
3627
+ const v128_t v0lfl = wasm_i16x8_extend_low_i8x16 (v0lf);
3628
+ const v128_t v0lfh = wasm_i16x8_extend_high_i8x16(v0lf);
3629
+ const v128_t v0hfl = wasm_i16x8_extend_low_i8x16 (v0hf);
3630
+ const v128_t v0hfh = wasm_i16x8_extend_high_i8x16(v0hf);
2798
3631
 
2799
- const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
2800
- const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
2801
- const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
2802
- const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
3632
+ const v128_t v1ll = wasm_i16x8_extend_low_i8x16 (v1l);
3633
+ const v128_t v1lh = wasm_i16x8_extend_high_i8x16(v1l);
3634
+ const v128_t v1hl = wasm_i16x8_extend_low_i8x16 (v1h);
3635
+ const v128_t v1hh = wasm_i16x8_extend_high_i8x16(v1h);
2803
3636
 
2804
- sumv0 = vmlaq_n_f32(sumv0, vaddq_f32(
2805
- vmulq_n_f32(vcvtq_f32_s32(pl0), GGML_FP16_TO_FP32(x0_0->d)),
2806
- vmulq_n_f32(vcvtq_f32_s32(ph0), GGML_FP16_TO_FP32(x0_1->d))), y0->d);
3637
+ const float x0d = GGML_FP16_TO_FP32(x0->d);
2807
3638
 
2808
- sumv1 = vmlaq_n_f32(sumv1, vaddq_f32(
2809
- vmulq_n_f32(vcvtq_f32_s32(pl1), GGML_FP16_TO_FP32(x1_0->d)),
2810
- vmulq_n_f32(vcvtq_f32_s32(ph1), GGML_FP16_TO_FP32(x1_1->d))), y1->d);
2811
- #endif
3639
+ // dot product
3640
+ sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4(
3641
+ wasm_i32x4_add(
3642
+ wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0lfl, v1ll),
3643
+ wasm_i32x4_dot_i16x8(v0lfh, v1lh)),
3644
+ wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0hfl, v1hl),
3645
+ wasm_i32x4_dot_i16x8(v0hfh, v1hh)))), wasm_f32x4_splat(x0d*y0->d)));
2812
3646
  }
2813
3647
 
2814
- sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
3648
+ *s = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) +
3649
+ wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3) + summs;
2815
3650
  #elif defined(__AVX2__)
2816
3651
  // Initialize accumulator with zeros
2817
3652
  __m256 acc = _mm256_setzero_ps();
3653
+ float summs = 0.0f;
2818
3654
 
2819
3655
  // Main loop
2820
3656
  for (int i = 0; i < nb; i++) {
2821
- /* Compute combined scale for the block */
2822
- const __m128 d0 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 0].d));
2823
- const __m128 d1 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 1].d));
2824
- const __m256 d = _mm256_mul_ps(_mm256_set_m128(d1, d0), _mm256_broadcast_ss(&y[i].d));
2825
-
2826
- __m128i bx0 = bytes_from_nibbles_16(x[2*i + 0].qs);
2827
- __m128i bx1 = bytes_from_nibbles_16(x[2*i + 1].qs);
2828
- __m256i bx = _mm256_set_m128i(bx1, bx0);
2829
-
2830
- // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
2831
- const __m256i off = _mm256_set1_epi8(8);
2832
- bx = _mm256_sub_epi8(bx, off);
3657
+ const __m256 dx = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d));
2833
3658
 
2834
- __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
3659
+ summs += GGML_FP16_TO_FP32(x[i].m) * (y[i].s0 + y[i].s1);
2835
3660
 
2836
- // Get absolute values of x vectors
2837
- const __m256i ax = _mm256_sign_epi8(bx, bx);
2838
- // Sign the values of the y vectors
2839
- const __m256i sy = _mm256_sign_epi8(by, bx);
2840
- // Perform multiplication and create 16-bit values
2841
- const __m256i dot = _mm256_maddubs_epi16(ax, sy);
3661
+ __m256i bx = bytes_from_nibbles_32(x[i].qs);
3662
+ __m256i bxhi = bytes_from_bits_32(x[i].qh);
3663
+ bxhi = _mm256_and_si256(bxhi, _mm256_set1_epi8(0x10));
3664
+ bx = _mm256_or_si256(bx, bxhi);
2842
3665
 
2843
- const __m256i ones = _mm256_set1_epi16(1);
2844
- __m256i xy_q = _mm256_madd_epi16(ones, dot);
3666
+ const __m256 dy = _mm256_broadcast_ss(&y[i].d);
3667
+ const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
2845
3668
 
2846
- /* Convert to vectore of 8 int32_t to 8 floats */
2847
- __m256 q = _mm256_cvtepi32_ps(xy_q);
3669
+ const __m256 q = mul_sum_i8_pairs_float(bx, by);
2848
3670
 
2849
- /* Multiply q with scale and accumulate */
2850
- acc = _mm256_fmadd_ps(d, q, acc);
3671
+ acc = _mm256_fmadd_ps(q, _mm256_mul_ps(dx, dy), acc);
2851
3672
  }
2852
3673
 
2853
- // Return horizontal sum of the acc vector
2854
- __m128 res = _mm256_extractf128_ps(acc, 1);
2855
- res = _mm_add_ps(res, _mm256_castps256_ps128(acc));
2856
- res = _mm_add_ps(res, _mm_movehl_ps(res, res));
2857
- res = _mm_add_ss(res, _mm_movehdup_ps(res));
2858
-
2859
- sumf = _mm_cvtss_f32(res);
3674
+ *s = hsum_float_8(acc) + summs;
2860
3675
  #else
2861
- // scalar
3676
+ float sumf = 0.0;
3677
+
2862
3678
  for (int i = 0; i < nb; i++) {
2863
- const uint8_t * restrict x0 = x[2*i + 0].qs;
2864
- const uint8_t * restrict x1 = x[2*i + 1].qs;
3679
+ const uint8_t * restrict x0 = x[i].qs;
2865
3680
  const int8_t * restrict y0 = y[i].qs;
2866
3681
 
2867
- const float d0 = GGML_FP16_TO_FP32(x[2*i + 0].d);
2868
- const float d1 = GGML_FP16_TO_FP32(x[2*i + 1].d);
3682
+ uint32_t qh;
3683
+ memcpy(&qh, x[i].qh, sizeof(qh));
2869
3684
 
2870
- int sumi_0 = 0;
2871
- int sumi_1 = 0;
3685
+ const float d = GGML_FP16_TO_FP32(x[i].d);
3686
+ const float m = GGML_FP16_TO_FP32(x[i].m);
2872
3687
 
2873
- for (int j = 0; j < QK8_0/4; j++) {
2874
- const uint8_t v0 = x0[j];
2875
- const uint8_t v1 = x1[j];
3688
+ int sxy = 0;
2876
3689
 
2877
- const int i0_0 = (int8_t) (v0 & 0xf) - 8;
2878
- const int i1_0 = (int8_t) (v0 >> 4) - 8;
3690
+ for (int j = 0; j < QK8_1/2; j++) {
3691
+ const uint8_t v0 = x0[j];
2879
3692
 
2880
- const int i0_1 = (int8_t) (v1 & 0xf) - 8;
2881
- const int i1_1 = (int8_t) (v1 >> 4) - 8;
3693
+ const int x0_0h = ((qh & (1u << (2*j + 0))) >> (2*j + 0)) << 4;
3694
+ const int x1_0h = ((qh & (1u << (2*j + 1))) >> (2*j + 1)) << 4;
2882
3695
 
2883
- const int i2_0 = y0[2*j + 0];
2884
- const int i3_0 = y0[2*j + 1];
3696
+ const int x0_0 = (v0 & 0x0F) | x0_0h;
3697
+ const int x1_0 = (v0 >> 4) | x1_0h;
2885
3698
 
2886
- const int i2_1 = y0[2*(j + QK8_0/4) + 0];
2887
- const int i3_1 = y0[2*(j + QK8_0/4) + 1];
3699
+ const int y0_0 = y0[2*j + 0];
3700
+ const int y1_0 = y0[2*j + 1];
2888
3701
 
2889
- sumi_0 += i0_0*i2_0 + i1_0*i3_0;
2890
- sumi_1 += i0_1*i2_1 + i1_1*i3_1;
3702
+ sxy += x0_0*y0_0 + x1_0*y1_0;
2891
3703
  }
2892
3704
 
2893
- sumf += (d0 * y[i].d) * sumi_0;
2894
- sumf += (d1 * y[i].d) * sumi_1;
3705
+ sumf += (d*sxy)*y[i].d + m*(y[i].s0 + y[i].s1);
2895
3706
  }
2896
- #endif
2897
3707
 
2898
3708
  *s = sumf;
3709
+ #endif
2899
3710
  }
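The q5_1 scalar tail above combines the two ideas from the previous kernels: each element is rebuilt as (low nibble | 5th bit << 4) and kept unsigned, and the affine offset m is hoisted out of the loop through the pre-scaled sums s0 + s1, giving sumf += (d*sxy)*dy + m*(s0 + s1) per block. Per-element reconstruction, as a hypothetical helper:

static float example_dequant_q5_1(const uint8_t * qs, uint32_t qh, int l, float d, float m) {
    const int nib = (l % 2 == 0) ? (qs[l/2] & 0x0F) : (qs[l/2] >> 4);
    const int hi  = (int)((qh >> l) & 1u) << 4;   // 5th bit of element l
    return (float)(nib | hi)*d + m;               // unsigned, no -16 offset (unlike q5_0)
}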
2900
3711
 
2901
- static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
3712
+ static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
2902
3713
  const int nb = n / QK8_0;
2903
3714
 
2904
3715
  assert(n % QK8_0 == 0);
2905
3716
  assert(nb % 2 == 0);
2906
- assert(QK8_0 == 2*QK4_2);
3717
+ assert(QK8_0 == QK8_0);
2907
3718
 
2908
- const block_q4_3 * restrict x = vx;
3719
+ const block_q8_0 * restrict x = vx;
2909
3720
  const block_q8_0 * restrict y = vy;
2910
3721
 
2911
- float sumf = 0.0;
2912
-
2913
3722
  #if defined(__ARM_NEON)
2914
3723
  float32x4_t sumv0 = vdupq_n_f32(0.0f);
2915
3724
  float32x4_t sumv1 = vdupq_n_f32(0.0f);
2916
3725
 
2917
3726
  for (int i = 0; i < nb; i += 2) {
2918
- const block_q4_3 * restrict x0_0 = &x[2*(i + 0) + 0];
2919
- const block_q4_3 * restrict x0_1 = &x[2*(i + 0) + 1];
2920
- const block_q4_3 * restrict x1_0 = &x[2*(i + 1) + 0];
2921
- const block_q4_3 * restrict x1_1 = &x[2*(i + 1) + 1];
2922
-
3727
+ const block_q8_0 * restrict x0 = &x[i + 0];
3728
+ const block_q8_0 * restrict x1 = &x[i + 1];
2923
3729
  const block_q8_0 * restrict y0 = &y[i + 0];
2924
3730
  const block_q8_0 * restrict y1 = &y[i + 1];
2925
3731
 
2926
- const uint8x16_t m4b = vdupq_n_u8(0xf);
2927
-
2928
- const float x0_0d = GGML_FP16_TO_FP32(x0_0->d);
2929
- const float x0_1d = GGML_FP16_TO_FP32(x0_1->d);
2930
- const float x1_0d = GGML_FP16_TO_FP32(x1_0->d);
2931
- const float x1_1d = GGML_FP16_TO_FP32(x1_1->d);
2932
-
2933
- const float x0_0m = GGML_FP16_TO_FP32(x0_0->m);
2934
- const float x0_1m = GGML_FP16_TO_FP32(x0_1->m);
2935
- const float x1_0m = GGML_FP16_TO_FP32(x1_0->m);
2936
- const float x1_1m = GGML_FP16_TO_FP32(x1_1->m);
2937
-
2938
- const uint8x16_t v0_0 = vcombine_u8(vld1_u8(x0_0->qs), vld1_u8(x0_1->qs));
2939
- const uint8x16_t v0_1 = vcombine_u8(vld1_u8(x1_0->qs), vld1_u8(x1_1->qs));
2940
-
2941
- // 4-bit -> 8-bit
2942
- const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
2943
- const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
2944
- const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
2945
- const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
2946
-
2947
- // interleave
2948
- const int8x16_t v0_0lz = vzip1q_s8(v0_0l, v0_0h);
2949
- const int8x16_t v0_0hz = vzip2q_s8(v0_0l, v0_0h);
2950
- const int8x16_t v0_1lz = vzip1q_s8(v0_1l, v0_1h);
2951
- const int8x16_t v0_1hz = vzip2q_s8(v0_1l, v0_1h);
3732
+ const int8x16_t x0_0 = vld1q_s8(x0->qs);
3733
+ const int8x16_t x0_1 = vld1q_s8(x0->qs + 16);
3734
+ const int8x16_t x1_0 = vld1q_s8(x1->qs);
3735
+ const int8x16_t x1_1 = vld1q_s8(x1->qs + 16);
2952
3736
 
2953
3737
  // load y
2954
- const int8x16_t v1_0l = vld1q_s8(y0->qs);
2955
- const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
2956
- const int8x16_t v1_1l = vld1q_s8(y1->qs);
2957
- const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
2958
-
2959
- const int16x8_t sy0_0 = vaddq_s16(vmovl_s8(vget_low_s8(v1_0l)), vmovl_s8(vget_high_s8(v1_0l)));
2960
- const int16x8_t sy0_1 = vaddq_s16(vmovl_s8(vget_low_s8(v1_0h)), vmovl_s8(vget_high_s8(v1_0h)));
2961
-
2962
- const int16x8_t sy1_0 = vaddq_s16(vmovl_s8(vget_low_s8(v1_1l)), vmovl_s8(vget_high_s8(v1_1l)));
2963
- const int16x8_t sy1_1 = vaddq_s16(vmovl_s8(vget_low_s8(v1_1h)), vmovl_s8(vget_high_s8(v1_1h)));
2964
-
2965
- sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddl_s16(vget_low_s16(sy0_0), vget_high_s16(sy0_0))), x0_0m*y0->d);
2966
- sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddl_s16(vget_low_s16(sy0_1), vget_high_s16(sy0_1))), x0_1m*y0->d);
2967
- sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddl_s16(vget_low_s16(sy1_0), vget_high_s16(sy1_0))), x1_0m*y1->d);
2968
- sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddl_s16(vget_low_s16(sy1_1), vget_high_s16(sy1_1))), x1_1m*y1->d);
3738
+ const int8x16_t y0_0 = vld1q_s8(y0->qs);
3739
+ const int8x16_t y0_1 = vld1q_s8(y0->qs + 16);
3740
+ const int8x16_t y1_0 = vld1q_s8(y1->qs);
3741
+ const int8x16_t y1_1 = vld1q_s8(y1->qs + 16);
2969
3742
 
2970
3743
  #if defined(__ARM_FEATURE_DOTPROD)
2971
- sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0lz, v1_0l)), x0_0d*y0->d);
2972
- sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0hz, v1_0h)), x0_1d*y0->d);
2973
- sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_1lz, v1_1l)), x1_0d*y1->d);
2974
- sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_1hz, v1_1h)), x1_1d*y1->d);
2975
- #else
2976
- const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lz), vget_low_s8 (v1_0l));
2977
- const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lz), vget_high_s8(v1_0l));
2978
- const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hz), vget_low_s8 (v1_0h));
2979
- const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hz), vget_high_s8(v1_0h));
2980
-
2981
- const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lz), vget_low_s8 (v1_1l));
2982
- const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lz), vget_high_s8(v1_1l));
2983
- const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hz), vget_low_s8 (v1_1h));
2984
- const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hz), vget_high_s8(v1_1h));
3744
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
3745
+ vdotq_s32(vdupq_n_s32(0), x0_0, y0_0),
3746
+ vdotq_s32(vdupq_n_s32(0), x0_1, y0_1))), x0->d*y0->d);
2985
3747
 
2986
- const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
2987
- const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
2988
- const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
2989
- const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
3748
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
3749
+ vdotq_s32(vdupq_n_s32(0), x1_0, y1_0),
3750
+ vdotq_s32(vdupq_n_s32(0), x1_1, y1_1))), x1->d*y1->d);
2990
3751
 
2991
- sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(pl0), x0_0d*y0->d);
2992
- sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(ph0), x0_1d*y0->d);
2993
- sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(pl1), x1_0d*y1->d);
2994
- sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(ph1), x1_1d*y1->d);
3752
+ #else
3753
+ const int16x8_t p0_0 = vmull_s8(vget_low_s8 (x0_0), vget_low_s8 (y0_0));
3754
+ const int16x8_t p0_1 = vmull_s8(vget_high_s8(x0_0), vget_high_s8(y0_0));
3755
+ const int16x8_t p0_2 = vmull_s8(vget_low_s8 (x0_1), vget_low_s8 (y0_1));
3756
+ const int16x8_t p0_3 = vmull_s8(vget_high_s8(x0_1), vget_high_s8(y0_1));
3757
+
3758
+ const int16x8_t p1_0 = vmull_s8(vget_low_s8 (x1_0), vget_low_s8 (y1_0));
3759
+ const int16x8_t p1_1 = vmull_s8(vget_high_s8(x1_0), vget_high_s8(y1_0));
3760
+ const int16x8_t p1_2 = vmull_s8(vget_low_s8 (x1_1), vget_low_s8 (y1_1));
3761
+ const int16x8_t p1_3 = vmull_s8(vget_high_s8(x1_1), vget_high_s8(y1_1));
3762
+
3763
+ const int32x4_t p0 = vaddq_s32(vpaddlq_s16(p0_0), vpaddlq_s16(p0_1));
3764
+ const int32x4_t p1 = vaddq_s32(vpaddlq_s16(p0_2), vpaddlq_s16(p0_3));
3765
+ const int32x4_t p2 = vaddq_s32(vpaddlq_s16(p1_0), vpaddlq_s16(p1_1));
3766
+ const int32x4_t p3 = vaddq_s32(vpaddlq_s16(p1_2), vpaddlq_s16(p1_3));
3767
+
3768
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(p0, p1)), x0->d*y0->d);
3769
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(p2, p3)), x1->d*y1->d);
2995
3770
  #endif
2996
3771
  }
2997
3772
 
2998
- sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
2999
- #else
3000
- // scalar
3001
- for (int i = 0; i < nb; i++) {
3002
- const uint8_t * restrict x0 = x[2*i + 0].qs;
3003
- const uint8_t * restrict x1 = x[2*i + 1].qs;
3004
- const int8_t * restrict y0 = y[i].qs;
3005
-
3006
- const float d0 = GGML_FP16_TO_FP32(x[2*i + 0].d);
3007
- const float m0 = GGML_FP16_TO_FP32(x[2*i + 0].m);
3008
- const float d1 = GGML_FP16_TO_FP32(x[2*i + 1].d);
3009
- const float m1 = GGML_FP16_TO_FP32(x[2*i + 1].m);
3010
-
3011
- int sy_0 = 0;
3012
- int sy_1 = 0;
3773
+ *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
3774
+ #elif defined(__AVX2__)
3775
+ // Initialize accumulator with zeros
3776
+ __m256 acc = _mm256_setzero_ps();
3013
3777
 
3014
- int sxy_0 = 0;
3015
- int sxy_1 = 0;
3778
+ // Main loop
3779
+ for (int i = 0; i < nb; ++i) {
3780
+ // Compute combined scale for the block
3781
+ const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );
3782
+ __m256i bx = _mm256_loadu_si256((const __m256i *)x[i].qs);
3783
+ __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
3016
3784
 
3017
- for (int j = 0; j < QK8_0/4; j++) {
3018
- const uint8_t v0 = x0[j];
3019
- const uint8_t v1 = x1[j];
3785
+ const __m256 q = mul_sum_i8_pairs_float(bx, by);
3020
3786
 
3021
- const int x0_0 = v0 & 0xf;
3022
- const int x1_0 = v0 >> 4;
3787
+ // Multiply q with scale and accumulate
3788
+ acc = _mm256_fmadd_ps( d, q, acc );
3789
+ }
3023
3790
 
3024
- const int x0_1 = v1 & 0xf;
3025
- const int x1_1 = v1 >> 4;
3791
+ *s = hsum_float_8(acc);
3792
+ #else
3793
+ // scalar
3794
+ float sumf = 0.0;
3026
3795
 
3027
- const int y0_0 = y0[2*j + 0];
3028
- const int y1_0 = y0[2*j + 1];
3796
+ for (int i = 0; i < nb; i++) {
3797
+ const int8_t * restrict x0 = x[i].qs;
3798
+ const int8_t * restrict y0 = y[i].qs;
3029
3799
 
3030
- const int y0_1 = y0[2*(j + QK8_0/4) + 0];
3031
- const int y1_1 = y0[2*(j + QK8_0/4) + 1];
3800
+ int sumi = 0;
3032
3801
 
3033
- sy_0 += y0_0 + y1_0;
3034
- sy_1 += y0_1 + y1_1;
3802
+ for (int j = 0; j < QK8_0; j++) {
3803
+ const int v0 = x0[j];
3804
+ const int v1 = y0[j];
3035
3805
 
3036
- sxy_0 += x0_0*y0_0 + x1_0*y1_0;
3037
- sxy_1 += x0_1*y0_1 + x1_1*y1_1;
3806
+ sumi += v0*v1;
3038
3807
  }
3039
3808
 
3040
- sumf += (d0*sxy_0 + m0*sy_0)*y[i].d;
3041
- sumf += (d1*sxy_1 + m1*sy_1)*y[i].d;
3809
+ sumf += (x[i].d*y[i].d)*sumi;
3042
3810
  }
3043
- #endif
3044
3811
 
3045
3812
  *s = sumf;
3813
+ #endif
3046
3814
  }
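The NEON dot-product path, the NEON widening-multiply fallback and the AVX2 path above all reduce to the same block-wise computation as the scalar branch: a 32-element integer dot product per block, rescaled by the product of the two block scales. A minimal standalone sketch of that reference computation (the struct and function names here are illustrative stand-ins, not ggml identifiers):

#include <stdint.h>

typedef struct {
    float  d;        // block scale
    int8_t qs[32];   // QK8_0 signed 8-bit quants
} demo_block_q8_0;   // illustrative stand-in for block_q8_0

static float demo_dot_q8_0(const demo_block_q8_0 * x, const demo_block_q8_0 * y, int nb) {
    float sumf = 0.0f;
    for (int i = 0; i < nb; i++) {
        int sumi = 0;
        for (int j = 0; j < 32; j++) {
            sumi += x[i].qs[j] * y[i].qs[j];   // integer dot product inside the block
        }
        sumf += (x[i].d * y[i].d) * sumi;      // rescale once per block
    }
    return sumf;
}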
3047
3815
 
3048
-
3049
3816
  // compute GGML_VEC_DOT_UNROLL dot products at once
3050
3817
  // xs - x row stride in bytes
3051
3818
  inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * restrict s, void * restrict xv, ggml_fp16_t * restrict y) {
@@ -3242,6 +4009,14 @@ inline static void ggml_vec_sum_f32(const int n, float * s, const float * x) {
3242
4009
  #endif
3243
4010
  }
3244
4011
 
4012
+ inline static void ggml_vec_sum_ggf(const int n, ggml_float * s, const float * x) {
4013
+ ggml_float sum = 0.0;
4014
+ for (int i = 0; i < n; ++i) {
4015
+ sum += (ggml_float)x[i];
4016
+ }
4017
+ *s = sum;
4018
+ }
4019
+
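ggml_vec_sum_ggf accumulates in ggml_float, which is typedef'd to double earlier in this file, so long reductions do not lose precision the way a plain float accumulator can. A standalone toy program (not ggml code) showing the effect:

#include <stdio.h>

// Summing 10 million copies of 0.1f: the float accumulator drifts once it grows large,
// the double accumulator stays close to the expected 1e6.
int main(void) {
    float  sum_f32 = 0.0f;
    double sum_f64 = 0.0;
    for (int i = 0; i < 10000000; ++i) {
        sum_f32 += 0.1f;
        sum_f64 += (double) 0.1f;
    }
    printf("float  accumulator: %f\n", sum_f32);
    printf("double accumulator: %f\n", sum_f64);
    return 0;
}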
3245
4020
  inline static void ggml_vec_max_f32(const int n, float * s, const float * x) {
3246
4021
  #ifndef GGML_USE_ACCELERATE
3247
4022
  float max = -INFINITY;
@@ -3293,13 +4068,15 @@ static const int GGML_BLCK_SIZE[GGML_TYPE_COUNT] = {
3293
4068
  [GGML_TYPE_Q4_0] = QK4_0,
3294
4069
  [GGML_TYPE_Q4_1] = QK4_1,
3295
4070
  [GGML_TYPE_Q4_2] = QK4_2,
3296
- [GGML_TYPE_Q4_3] = QK4_3,
4071
+ [GGML_TYPE_Q5_0] = QK5_0,
4072
+ [GGML_TYPE_Q5_1] = QK5_1,
3297
4073
  [GGML_TYPE_Q8_0] = QK8_0,
4074
+ [GGML_TYPE_Q8_1] = QK8_1,
3298
4075
  [GGML_TYPE_I8] = 1,
3299
4076
  [GGML_TYPE_I16] = 1,
3300
4077
  [GGML_TYPE_I32] = 1,
3301
4078
  };
3302
- static_assert(GGML_TYPE_COUNT == 10, "GGML_BLCK_SIZE is outdated");
4079
+ static_assert(GGML_TYPE_COUNT == 13, "GGML_BLCK_SIZE is outdated");
3303
4080
 
3304
4081
  static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = {
3305
4082
  [GGML_TYPE_F32] = sizeof(float),
@@ -3307,13 +4084,15 @@ static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = {
3307
4084
  [GGML_TYPE_Q4_0] = sizeof(block_q4_0),
3308
4085
  [GGML_TYPE_Q4_1] = sizeof(block_q4_1),
3309
4086
  [GGML_TYPE_Q4_2] = sizeof(block_q4_2),
3310
- [GGML_TYPE_Q4_3] = sizeof(block_q4_3),
4087
+ [GGML_TYPE_Q5_0] = sizeof(block_q5_0),
4088
+ [GGML_TYPE_Q5_1] = sizeof(block_q5_1),
3311
4089
  [GGML_TYPE_Q8_0] = sizeof(block_q8_0),
4090
+ [GGML_TYPE_Q8_1] = sizeof(block_q8_1),
3312
4091
  [GGML_TYPE_I8] = sizeof(int8_t),
3313
4092
  [GGML_TYPE_I16] = sizeof(int16_t),
3314
4093
  [GGML_TYPE_I32] = sizeof(int32_t),
3315
4094
  };
3316
- static_assert(GGML_TYPE_COUNT == 10, "GGML_TYPE_SIZE is outdated");
4095
+ static_assert(GGML_TYPE_COUNT == 13, "GGML_TYPE_SIZE is outdated");
3317
4096
 
3318
4097
 
3319
4098
  static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = {
@@ -3322,13 +4101,15 @@ static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = {
3322
4101
  [GGML_TYPE_Q4_0] = "q4_0",
3323
4102
  [GGML_TYPE_Q4_1] = "q4_1",
3324
4103
  [GGML_TYPE_Q4_2] = "q4_2",
3325
- [GGML_TYPE_Q4_3] = "q4_3",
4104
+ [GGML_TYPE_Q5_0] = "q5_0",
4105
+ [GGML_TYPE_Q5_1] = "q5_1",
3326
4106
  [GGML_TYPE_Q8_0] = "q8_0",
4107
+ [GGML_TYPE_Q8_1] = "q8_1",
3327
4108
  [GGML_TYPE_I8] = "i8",
3328
4109
  [GGML_TYPE_I16] = "i16",
3329
4110
  [GGML_TYPE_I32] = "i32",
3330
4111
  };
3331
- static_assert(GGML_TYPE_COUNT == 10, "GGML_TYPE_NAME is outdated");
4112
+ static_assert(GGML_TYPE_COUNT == 13, "GGML_TYPE_NAME is outdated");
3332
4113
 
3333
4114
  static bool GGML_IS_QUANTIZED[GGML_TYPE_COUNT] = {
3334
4115
  [GGML_TYPE_F32] = false,
@@ -3336,13 +4117,15 @@ static bool GGML_IS_QUANTIZED[GGML_TYPE_COUNT] = {
3336
4117
  [GGML_TYPE_Q4_0] = true,
3337
4118
  [GGML_TYPE_Q4_1] = true,
3338
4119
  [GGML_TYPE_Q4_2] = true,
3339
- [GGML_TYPE_Q4_3] = true,
4120
+ [GGML_TYPE_Q5_0] = true,
4121
+ [GGML_TYPE_Q5_1] = true,
3340
4122
  [GGML_TYPE_Q8_0] = true,
4123
+ [GGML_TYPE_Q8_1] = true,
3341
4124
  [GGML_TYPE_I8] = false,
3342
4125
  [GGML_TYPE_I16] = false,
3343
4126
  [GGML_TYPE_I32] = false,
3344
4127
  };
3345
- static_assert(GGML_TYPE_COUNT == 10, "GGML_IS_QUANTIZED is outdated");
4128
+ static_assert(GGML_TYPE_COUNT == 13, "GGML_IS_QUANTIZED is outdated");
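These tables are consumed together: the byte size of a row of n elements is n * GGML_TYPE_SIZE[type] / GGML_BLCK_SIZE[type], with n a multiple of the block size. That is exactly the row_size expression the mul_mat code further down switches over to the new vec_dot_type. A hypothetical helper, shown only to make the relationship explicit (not part of ggml's API):

static size_t demo_row_size(enum ggml_type type, int64_t n) {
    // e.g. q5_0 with n = 4096 elements: 128 blocks of sizeof(block_q5_0) bytes each
    return n * GGML_TYPE_SIZE[type] / GGML_BLCK_SIZE[type];
}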
3346
4129
 
3347
4130
  static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
3348
4131
  "NONE",
@@ -3380,6 +4163,7 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
3380
4163
  "DIAG_MASK_INF",
3381
4164
  "SOFT_MAX",
3382
4165
  "ROPE",
4166
+ "ALIBI",
3383
4167
  "CONV_1D_1S",
3384
4168
  "CONV_1D_2S",
3385
4169
 
@@ -3390,7 +4174,7 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
3390
4174
  "MAP_BINARY",
3391
4175
  };
3392
4176
 
3393
- static_assert(GGML_OP_COUNT == 38, "GGML_OP_COUNT != 38");
4177
+ static_assert(GGML_OP_COUNT == 39, "GGML_OP_COUNT != 39");
3394
4178
 
3395
4179
  static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
3396
4180
  "none",
@@ -3428,6 +4212,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
3428
4212
  "diag_mask_inf(x)",
3429
4213
  "soft_max(x)",
3430
4214
  "rope(x)",
4215
+ "alibi(x)",
3431
4216
  "conv_1d_1s(x)",
3432
4217
  "conv_1d_2s(x)",
3433
4218
 
@@ -3438,7 +4223,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
3438
4223
  "f(x,y)",
3439
4224
  };
3440
4225
 
3441
- static_assert(GGML_OP_COUNT == 38, "GGML_OP_COUNT != 38");
4226
+ static_assert(GGML_OP_COUNT == 39, "GGML_OP_COUNT != 39");
3442
4227
 
3443
4228
  static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
3444
4229
  static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN");
@@ -3608,6 +4393,27 @@ bool ggml_is_quantized(enum ggml_type type) {
3608
4393
  return GGML_IS_QUANTIZED[type];
3609
4394
  }
3610
4395
 
4396
+ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
4397
+ enum ggml_type wtype = GGML_TYPE_COUNT;
4398
+
4399
+ switch (ftype) {
4400
+ case GGML_FTYPE_ALL_F32: wtype = GGML_TYPE_F32; break;
4401
+ case GGML_FTYPE_MOSTLY_F16: wtype = GGML_TYPE_F16; break;
4402
+ case GGML_FTYPE_MOSTLY_Q4_0: wtype = GGML_TYPE_Q4_0; break;
4403
+ case GGML_FTYPE_MOSTLY_Q4_1: wtype = GGML_TYPE_Q4_1; break;
4404
+ case GGML_FTYPE_MOSTLY_Q4_2: wtype = GGML_TYPE_Q4_2; break;
4405
+ case GGML_FTYPE_MOSTLY_Q5_0: wtype = GGML_TYPE_Q5_0; break;
4406
+ case GGML_FTYPE_MOSTLY_Q5_1: wtype = GGML_TYPE_Q5_1; break;
4407
+ case GGML_FTYPE_MOSTLY_Q8_0: wtype = GGML_TYPE_Q8_0; break;
4408
+ case GGML_FTYPE_UNKNOWN: wtype = GGML_TYPE_COUNT; break;
4409
+ case GGML_FTYPE_MOSTLY_Q4_1_SOME_F16: wtype = GGML_TYPE_COUNT; break;
4410
+ }
4411
+
4412
+ GGML_ASSERT(wtype != GGML_TYPE_COUNT);
4413
+
4414
+ return wtype;
4415
+ }
4416
+
3611
4417
  static inline bool ggml_is_transposed(const struct ggml_tensor * tensor) {
3612
4418
  return tensor->nb[0] > tensor->nb[1];
3613
4419
  }
@@ -3718,10 +4524,11 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
3718
4524
  GGML_PRINT_DEBUG("%s: g_state initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f);
3719
4525
  }
3720
4526
 
3721
- // initialize cuBLAS
3722
- #if defined(GGML_USE_CUBLAS)
3723
- init_cublas();
3724
- #endif
4527
+ #if defined(GGML_USE_CUBLAS)
4528
+ ggml_init_cublas();
4529
+ #elif defined(GGML_USE_CLBLAST)
4530
+ ggml_cl_init();
4531
+ #endif
3725
4532
 
3726
4533
  is_first_call = false;
3727
4534
  }
@@ -3802,7 +4609,7 @@ void ggml_free(struct ggml_context * ctx) {
3802
4609
  }
3803
4610
 
3804
4611
  size_t ggml_used_mem(const struct ggml_context * ctx) {
3805
- return ctx->objects_end->offs + ctx->objects_end->size;
4612
+ return ctx->objects_end == NULL ? 0 : ctx->objects_end->offs + ctx->objects_end->size;
3806
4613
  }
3807
4614
 
3808
4615
  size_t ggml_set_scratch(struct ggml_context * ctx, struct ggml_scratch scratch) {
@@ -3915,6 +4722,7 @@ struct ggml_tensor * ggml_new_tensor_impl(
3915
4722
  /*.perf_cycles =*/ 0,
3916
4723
  /*.perf_time_us =*/ 0,
3917
4724
  /*.data =*/ (data == NULL && !ctx->no_alloc) ? (void *)(result + 1) : data,
4725
+ /*.name =*/ { 0 },
3918
4726
  /*.pad =*/ { 0 },
3919
4727
  };
3920
4728
 
@@ -4269,6 +5077,15 @@ float * ggml_get_data_f32(const struct ggml_tensor * tensor) {
4269
5077
  return (float *)(tensor->data);
4270
5078
  }
4271
5079
 
5080
+ const char * ggml_get_name(const struct ggml_tensor * tensor) {
5081
+ return tensor->name;
5082
+ }
5083
+
5084
+ void ggml_set_name(struct ggml_tensor * tensor, const char * name) {
5085
+ strncpy(tensor->name, name, sizeof(tensor->name));
5086
+ tensor->name[sizeof(tensor->name) - 1] = '\0';
5087
+ }
5088
+
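The name field is primarily a debugging aid; it is what ggml_graph_dump_dot prints later in this diff, and what ggml_diag_mask_inf and ggml_rope below use to label their parameter tensors. A small usage sketch, assuming a ggml_context * ctx created elsewhere:

struct ggml_tensor * embd = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 8);
ggml_set_name(embd, "token_embeddings");
printf("tensor: %s\n", ggml_get_name(embd));   // -> "token_embeddings"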
4272
5089
  struct ggml_tensor * ggml_view_tensor(
4273
5090
  struct ggml_context * ctx,
4274
5091
  const struct ggml_tensor * src) {
@@ -5368,6 +6185,7 @@ struct ggml_tensor * ggml_diag_mask_inf(
5368
6185
  //struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
5369
6186
  struct ggml_tensor * result = ggml_view_tensor(ctx, a);
5370
6187
  struct ggml_tensor * b = ggml_new_i32(ctx, n_past);
6188
+ ggml_set_name(b, "n_past");
5371
6189
 
5372
6190
  result->op = GGML_OP_DIAG_MASK_INF;
5373
6191
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@@ -5393,22 +6211,55 @@ struct ggml_tensor * ggml_soft_max(
5393
6211
  //struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
5394
6212
  struct ggml_tensor * result = ggml_view_tensor(ctx, a);
5395
6213
 
5396
- result->op = GGML_OP_SOFT_MAX;
6214
+ result->op = GGML_OP_SOFT_MAX;
6215
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6216
+ result->src0 = a;
6217
+ result->src1 = NULL;
6218
+
6219
+ return result;
6220
+ }
6221
+
6222
+ // ggml_rope
6223
+
6224
+ struct ggml_tensor * ggml_rope(
6225
+ struct ggml_context * ctx,
6226
+ struct ggml_tensor * a,
6227
+ int n_past,
6228
+ int n_dims,
6229
+ int mode) {
6230
+ GGML_ASSERT(n_past >= 0);
6231
+ bool is_node = false;
6232
+
6233
+ if (a->grad) {
6234
+ GGML_ASSERT(false); // TODO: implement backward
6235
+ is_node = true;
6236
+ }
6237
+
6238
+ // TODO: when implement backward, fix this:
6239
+ //struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
6240
+ struct ggml_tensor * result = ggml_view_tensor(ctx, a);
6241
+
6242
+ struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 3);
6243
+ ((int32_t *) b->data)[0] = n_past;
6244
+ ((int32_t *) b->data)[1] = n_dims;
6245
+ ((int32_t *) b->data)[2] = mode;
6246
+ ggml_set_name(b, "n_past, n_dims, mode");
6247
+
6248
+ result->op = GGML_OP_ROPE;
5397
6249
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5398
6250
  result->src0 = a;
5399
- result->src1 = NULL;
6251
+ result->src1 = b;
5400
6252
 
5401
6253
  return result;
5402
6254
  }
5403
6255
 
5404
- // ggml_rope
6256
+ // ggml_alibi
5405
6257
 
5406
- struct ggml_tensor * ggml_rope(
6258
+ struct ggml_tensor * ggml_alibi(
5407
6259
  struct ggml_context * ctx,
5408
6260
  struct ggml_tensor * a,
5409
6261
  int n_past,
5410
- int n_dims,
5411
- int mode) {
6262
+ int n_head) {
5412
6263
  GGML_ASSERT(n_past >= 0);
5413
6264
  bool is_node = false;
5414
6265
 
@@ -5421,12 +6272,11 @@ struct ggml_tensor * ggml_rope(
5421
6272
  //struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
5422
6273
  struct ggml_tensor * result = ggml_view_tensor(ctx, a);
5423
6274
 
5424
- struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 3);
6275
+ struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 2);
5425
6276
  ((int32_t *) b->data)[0] = n_past;
5426
- ((int32_t *) b->data)[1] = n_dims;
5427
- ((int32_t *) b->data)[2] = mode;
6277
+ ((int32_t *) b->data)[1] = n_head;
5428
6278
 
5429
- result->op = GGML_OP_ROPE;
6279
+ result->op = GGML_OP_ALIBI;
5430
6280
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5431
6281
  result->src0 = a;
5432
6282
  result->src1 = b;
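With this change ggml_rope keeps its three integer parameters in a named I32 tensor on src1, while the new ggml_alibi records only (n_past, n_head). A minimal graph-building sketch, assuming the usual ggml_init parameters and a KQ_scaled-shaped input; the sizes below are placeholders:

struct ggml_init_params ip = {
    /*.mem_size   =*/ 16*1024*1024,
    /*.mem_buffer =*/ NULL,
    /*.no_alloc   =*/ false,
};
struct ggml_context * ctx = ggml_init(ip);

const int n_past = 0, n_head = 12, seq_len = 4;

// [all_seq_len, seq_len_without_past, n_head], matching the forward pass further down
struct ggml_tensor * kq       = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_past + seq_len, seq_len, n_head);
struct ggml_tensor * kq_alibi = ggml_alibi(ctx, kq, n_past, n_head);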
@@ -6553,7 +7403,9 @@ static void ggml_compute_forward_add(
6553
7403
  case GGML_TYPE_Q4_0:
6554
7404
  case GGML_TYPE_Q4_1:
6555
7405
  case GGML_TYPE_Q4_2:
6556
- case GGML_TYPE_Q4_3:
7406
+ case GGML_TYPE_Q5_0:
7407
+ case GGML_TYPE_Q5_1:
7408
+ case GGML_TYPE_Q8_0:
6557
7409
  {
6558
7410
  ggml_compute_forward_add_q_f32(params, src0, src1, dst);
6559
7411
  } break;
@@ -6811,15 +7663,20 @@ static void ggml_compute_forward_sum_f32(
6811
7663
  const size_t nb02 = src0->nb[2];
6812
7664
  const size_t nb03 = src0->nb[3];
6813
7665
 
7666
+ ggml_float sum = 0;
7667
+ ggml_float row_sum = 0;
7668
+
6814
7669
  for (int64_t i03 = 0; i03 < ne03; i03++) {
6815
7670
  for (int64_t i02 = 0; i02 < ne02; i02++) {
6816
7671
  for (int64_t i01 = 0; i01 < ne01; i01++) {
6817
- ggml_vec_sum_f32(ne00,
6818
- (float *) (dst->data),
7672
+ ggml_vec_sum_ggf(ne00,
7673
+ &row_sum,
6819
7674
  (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03));
7675
+ sum += row_sum;
6820
7676
  }
6821
7677
  }
6822
7678
  }
7679
+ ((float *) dst->data)[0] = sum;
6823
7680
  }
6824
7681
 
6825
7682
  static void ggml_compute_forward_sum(
@@ -7454,7 +8311,7 @@ static void ggml_compute_forward_rms_norm(
7454
8311
 
7455
8312
  // ggml_compute_forward_mul_mat
7456
8313
 
7457
- #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
8314
+ #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
7458
8315
  // helper function to determine if it is better to use BLAS or not
7459
8316
  // for large matrices, BLAS is faster
7460
8317
  static bool ggml_compute_forward_mul_mat_use_blas(
@@ -7471,7 +8328,8 @@ static bool ggml_compute_forward_mul_mat_use_blas(
7471
8328
 
7472
8329
  // TODO: find the optimal values for these
7473
8330
  if (ggml_is_contiguous(src0) &&
7474
- ggml_is_contiguous(src1) && ((ne0 >= 32 && ne1 >= 32 && ne10 >= 32))) {
8331
+ ggml_is_contiguous(src1) &&
8332
+ (ne0 >= 32 && ne1 >= 32 && ne10 >= 32)) {
7475
8333
 
7476
8334
  /*printf("BLAS: %d %d %d %d %d\n", ne0, ne1, ne10, ne00, ne01);*/
7477
8335
  return true;
@@ -7494,7 +8352,7 @@ static void ggml_compute_forward_mul_mat_f32(
7494
8352
  const int64_t ne02 = src0->ne[2];
7495
8353
  const int64_t ne03 = src0->ne[3];
7496
8354
 
7497
- #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
8355
+ #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
7498
8356
  const int64_t ne10 = src1->ne[0];
7499
8357
  #endif
7500
8358
  const int64_t ne11 = src1->ne[1];
@@ -7551,7 +8409,16 @@ static void ggml_compute_forward_mul_mat_f32(
7551
8409
  // nb01 >= nb00 - src0 is not transposed
7552
8410
  // compute by src0 rows
7553
8411
 
7554
- #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
8412
+ #if defined(GGML_USE_CUBLAS)
8413
+ if (ggml_cuda_can_mul_mat(src0, src1, dst)) {
8414
+ if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
8415
+ ggml_cuda_mul_mat(src0, src1, dst, params->wdata, params->wsize);
8416
+ }
8417
+ return;
8418
+ }
8419
+ #endif
8420
+
8421
+ #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
7555
8422
  if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
7556
8423
  if (params->ith != 0) {
7557
8424
  return;
@@ -7565,45 +8432,21 @@ static void ggml_compute_forward_mul_mat_f32(
7565
8432
  return;
7566
8433
  }
7567
8434
 
7568
- #if defined(GGML_USE_CUBLAS)
7569
- float *d_X = NULL;
7570
- float *d_Y = NULL;
7571
- float *d_D = NULL;
7572
- const float alpha = 1.0f;
7573
- const float beta = 0.0f;
7574
- const int x_ne = ne01 * ne10;
7575
- const int y_ne = ne11 * ne10;
7576
- const int d_ne = ne11 * ne01;
7577
-
7578
- CUDA_CHECK(cudaMalloc((void **)(&d_X), sizeof(float) * x_ne));
7579
- CUDA_CHECK(cudaMalloc((void **)(&d_Y), sizeof(float) * y_ne));
7580
- CUDA_CHECK(cudaMalloc((void **)(&d_D), sizeof(float) * d_ne));
7581
- #endif
7582
-
7583
8435
  for (int64_t i03 = 0; i03 < ne03; i03++) {
7584
8436
  for (int64_t i02 = 0; i02 < ne02; i02++) {
7585
8437
  const float * x = (float *) ((char *) src0->data + i02*nb02 + i03*nb03);
7586
8438
  const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
7587
-
7588
8439
  float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
7589
8440
 
7590
- #if defined(GGML_USE_CUBLAS)
7591
- // copy data to device
7592
- CUDA_CHECK(cudaMemcpyAsync(d_X, x, sizeof(float) * x_ne, cudaMemcpyHostToDevice, cudaStream));
7593
- CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(float) * y_ne, cudaMemcpyHostToDevice, cudaStream));
7594
-
7595
- // compute
7596
- CUBLAS_CHECK(
7597
- cublasSgemm(cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
7598
- ne01, ne11, ne10,
7599
- &alpha, d_X, ne00,
7600
- d_Y, ne10,
7601
- &beta, d_D, ne01));
7602
-
7603
- // copy data to host
7604
- CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
7605
- #else
8441
+ #if defined(GGML_USE_CLBLAST)
7606
8442
  // zT = y * xT
8443
+ ggml_cl_sgemm_wrapper(GGML_BLAS_ORDER_ROW_MAJOR, GGML_BLAS_OP_N, GGML_BLAS_OP_T,
8444
+ ne11, ne01, ne10,
8445
+ 1.0f, y, ne10,
8446
+ x, ne10,
8447
+ 0.0f, d, ne01,
8448
+ GGML_TYPE_F32);
8449
+ #else
7607
8450
  cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
7608
8451
  ne11, ne01, ne10,
7609
8452
  1.0f, y, ne10,
@@ -7612,12 +8455,6 @@ static void ggml_compute_forward_mul_mat_f32(
7612
8455
  #endif
7613
8456
  }
7614
8457
  }
7615
- #if defined(GGML_USE_CUBLAS)
7616
- CUDA_CHECK(cudaStreamSynchronize(cudaStream));
7617
- CUDA_CHECK(cudaFree(d_X));
7618
- CUDA_CHECK(cudaFree(d_Y));
7619
- CUDA_CHECK(cudaFree(d_D));
7620
- #endif
7621
8458
  //printf("CBLAS F32 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);
7622
8459
 
7623
8460
  return;
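Both branches above compute the same product: with row-major storage, dst (ne11 x ne01) = src1 (ne11 x ne10) * src0^T, i.e. every row of src1 is dotted against every row of src0, which is why both calls pass no-transpose for y and transpose for x with ne10 as the shared leading dimension. A naive reference loop, for orientation only (not a replacement for the BLAS call):

for (int64_t i11 = 0; i11 < ne11; i11++) {         // rows of src1 / rows of dst
    for (int64_t i01 = 0; i01 < ne01; i01++) {     // rows of src0 / columns of dst
        float acc = 0.0f;
        for (int64_t i10 = 0; i10 < ne10; i10++) { // shared inner dimension
            acc += y[i11*ne10 + i10] * x[i01*ne10 + i10];
        }
        d[i11*ne01 + i01] = acc;
    }
}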
@@ -7747,7 +8584,16 @@ static void ggml_compute_forward_mul_mat_f16_f32(
7747
8584
  // nb01 >= nb00 - src0 is not transposed
7748
8585
  // compute by src0 rows
7749
8586
 
7750
- #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
8587
+ #if defined(GGML_USE_CUBLAS)
8588
+ if (ggml_cuda_can_mul_mat(src0, src1, dst)) {
8589
+ if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
8590
+ ggml_cuda_mul_mat(src0, src1, dst, params->wdata, params->wsize);
8591
+ }
8592
+ return;
8593
+ }
8594
+ #endif
8595
+
8596
+ #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
7751
8597
  if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
7752
8598
  GGML_ASSERT(nb10 == sizeof(float));
7753
8599
 
@@ -7763,37 +8609,9 @@ static void ggml_compute_forward_mul_mat_f16_f32(
7763
8609
  return;
7764
8610
  }
7765
8611
 
7766
- #if defined(GGML_USE_CUBLAS)
7767
- ggml_fp16_t * const wdata = params->wdata;
7768
-
7769
- float *d_X = NULL;
7770
- float *d_Y = NULL;
7771
- float *d_D = NULL;
7772
- const float alpha = 1.0f;
7773
- const float beta = 0.0f;
7774
- const int x_ne = ne01 * ne10;
7775
- const int y_ne = ne11 * ne10;
7776
- const int d_ne = ne11 * ne01;
7777
-
7778
- CUDA_CHECK(cudaMalloc((void **)(&d_X), sizeof(ggml_fp16_t) * x_ne));
7779
- CUDA_CHECK(cudaMalloc((void **)(&d_Y), sizeof(float) * y_ne));
7780
- CUDA_CHECK(cudaMalloc((void **)(&d_D), sizeof(float) * d_ne));
7781
- #else
7782
- float * const wdata = params->wdata;
7783
- #endif
7784
8612
  for (int64_t i03 = 0; i03 < ne03; i03++) {
7785
8613
  for (int64_t i02 = 0; i02 < ne02; i02++) {
7786
- #if defined(GGML_USE_CUBLAS)
7787
- // with cuBlAS, instead of converting src0 to fp32, we convert src1 to fp16
7788
- {
7789
- size_t id = 0;
7790
- for (int64_t i01 = 0; i01 < ne11; ++i01) {
7791
- for (int64_t i00 = 0; i00 < ne10; ++i00) {
7792
- wdata[id++] = GGML_FP32_TO_FP16(*(float *) ((char *) src1->data + i03*nb13 + i02*nb12 + i01*nb11 + i00*nb10));
7793
- }
7794
- }
7795
- }
7796
- #else
8614
+ float * const wdata = params->wdata;
7797
8615
  {
7798
8616
  size_t id = 0;
7799
8617
  for (int64_t i01 = 0; i01 < ne01; ++i01) {
@@ -7801,31 +8619,23 @@ static void ggml_compute_forward_mul_mat_f16_f32(
7801
8619
  wdata[id++] = GGML_FP16_TO_FP32(*(ggml_fp16_t *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00));
7802
8620
  }
7803
8621
  }
8622
+
8623
+ assert(id*sizeof(float) <= params->wsize);
7804
8624
  }
7805
- #endif
7806
8625
 
7807
- #if defined(GGML_USE_CUBLAS)
7808
- const ggml_fp16_t * x = (ggml_fp16_t *) ((char *) src0->data + i02*nb02 + i03*nb03);
7809
- const ggml_fp16_t * y = (ggml_fp16_t *) wdata;
8626
+ #if defined(GGML_USE_CLBLAST)
8627
+ const float * x = wdata;
8628
+ const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
7810
8629
 
7811
8630
  float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
7812
8631
 
7813
- // copy data to device
7814
- CUDA_CHECK(cudaMemcpyAsync(d_X, x, sizeof(ggml_fp16_t) * x_ne, cudaMemcpyHostToDevice, cudaStream));
7815
- CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(ggml_fp16_t) * y_ne, cudaMemcpyHostToDevice, cudaStream));
7816
-
7817
- // compute
7818
- CUBLAS_CHECK(
7819
- cublasGemmEx(cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
7820
- ne01, ne11, ne10,
7821
- &alpha, d_X, CUDA_R_16F, ne00,
7822
- d_Y, CUDA_R_16F, ne10,
7823
- &beta, d_D, CUDA_R_32F, ne01,
7824
- CUBLAS_COMPUTE_32F,
7825
- CUBLAS_GEMM_DEFAULT));
7826
-
7827
- // copy data to host
7828
- CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
8632
+ // zT = y * xT
8633
+ ggml_cl_sgemm_wrapper(GGML_BLAS_ORDER_ROW_MAJOR, GGML_BLAS_OP_N, GGML_BLAS_OP_T,
8634
+ ne11, ne01, ne10,
8635
+ 1.0f, y, ne10,
8636
+ x, ne10,
8637
+ 0.0f, d, ne01,
8638
+ GGML_TYPE_F32);
7829
8639
  #else
7830
8640
  const float * x = wdata;
7831
8641
  const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
@@ -7842,12 +8652,6 @@ static void ggml_compute_forward_mul_mat_f16_f32(
7842
8652
  }
7843
8653
  }
7844
8654
 
7845
- #if defined(GGML_USE_CUBLAS)
7846
- CUDA_CHECK(cudaStreamSynchronize(cudaStream));
7847
- CUDA_CHECK(cudaFree(d_X));
7848
- CUDA_CHECK(cudaFree(d_Y));
7849
- CUDA_CHECK(cudaFree(d_D));
7850
- #endif
7851
8655
  /*printf("CBLAS F16 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);*/
7852
8656
 
7853
8657
  return;
@@ -7980,6 +8784,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
7980
8784
  const enum ggml_type type = src0->type;
7981
8785
  quantize_row_q_t const quantize_row_q_dot = quantize_fns[type].quantize_row_q_dot;
7982
8786
  vec_dot_q_t const vec_dot_q = quantize_fns[type].vec_dot_q;
8787
+ enum ggml_type const vec_dot_type = quantize_fns[type].vec_dot_type;
7983
8788
 
7984
8789
  // we don't support permuted src0 or src1
7985
8790
  GGML_ASSERT(nb00 == (int) GGML_TYPE_SIZE[type]);
@@ -7999,7 +8804,16 @@ static void ggml_compute_forward_mul_mat_q_f32(
7999
8804
  // nb01 >= nb00 - src0 is not transposed
8000
8805
  // compute by src0 rows
8001
8806
 
8002
- #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
8807
+ #if defined(GGML_USE_CUBLAS)
8808
+ if (ggml_cuda_can_mul_mat(src0, src1, dst)) {
8809
+ if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
8810
+ ggml_cuda_mul_mat(src0, src1, dst, params->wdata, params->wsize);
8811
+ }
8812
+ return;
8813
+ }
8814
+ #endif
8815
+
8816
+ #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
8003
8817
  if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
8004
8818
  if (params->ith != 0) {
8005
8819
  return;
@@ -8013,39 +8827,8 @@ static void ggml_compute_forward_mul_mat_q_f32(
8013
8827
  return;
8014
8828
  }
8015
8829
 
8016
- #if defined(GGML_USE_CUBLAS)
8017
- float *d_X = NULL;
8018
- float *d_Y = NULL;
8019
- float *d_D = NULL;
8020
- float *d_Q = NULL;
8021
- const float alpha = 1.0f;
8022
- const float beta = 0.0f;
8023
- const int x_ne = ne01 * ne10;
8024
- const int y_ne = ne11 * ne10;
8025
- const int d_ne = ne11 * ne01;
8026
-
8027
- CUDA_CHECK(cudaMalloc((void **)(&d_X), sizeof(float) * x_ne));
8028
- CUDA_CHECK(cudaMalloc((void **)(&d_Y), sizeof(float) * y_ne));
8029
- CUDA_CHECK(cudaMalloc((void **)(&d_D), sizeof(float) * d_ne));
8030
- CUDA_CHECK(cudaMalloc((void **)(&d_Q), GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type]));
8031
-
8032
- void (*dequantize_row_q_cuda)(const void * x, float * y, int k, cudaStream_t stream) = NULL;
8033
- if (type == GGML_TYPE_Q4_0) {
8034
- dequantize_row_q_cuda = dequantize_row_q4_0_cuda;
8035
- }
8036
- else if (type == GGML_TYPE_Q4_1) {
8037
- dequantize_row_q_cuda = dequantize_row_q4_1_cuda;
8038
- }
8039
- else if (type == GGML_TYPE_Q4_2) {
8040
- dequantize_row_q_cuda = dequantize_row_q4_2_cuda;
8041
- }
8042
- else {
8043
- GGML_ASSERT(false);
8044
- }
8045
- #else
8046
8830
  float * const wdata = params->wdata;
8047
8831
  dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q;
8048
- #endif
8049
8832
 
8050
8833
  for (int64_t i03 = 0; i03 < ne03; i03++) {
8051
8834
  for (int64_t i02 = 0; i02 < ne02; i02++) {
@@ -8053,14 +8836,8 @@ static void ggml_compute_forward_mul_mat_q_f32(
8053
8836
 
8054
8837
  float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
8055
8838
 
8056
- #if defined(GGML_USE_CUBLAS)
8057
- // copy and dequantize on device
8058
- CUDA_CHECK(
8059
- cudaMemcpyAsync(d_Q, (char *) src0->data + i03*nb03 + i02*nb02,
8060
- GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type], cudaMemcpyHostToDevice, cudaStream));
8061
-
8062
- dequantize_row_q_cuda(d_Q, d_X, ne01 * ne00, cudaStream);
8063
- CUDA_CHECK(cudaGetLastError());
8839
+ #if defined(GGML_USE_CLBLAST)
8840
+ const void* x = (char *) src0->data + i03*nb03 + i02*nb02;
8064
8841
  #else
8065
8842
  {
8066
8843
  size_t id = 0;
@@ -8068,27 +8845,22 @@ static void ggml_compute_forward_mul_mat_q_f32(
8068
8845
  dequantize_row_q((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01, wdata + id, ne00);
8069
8846
  id += ne00;
8070
8847
  }
8848
+
8849
+ assert(id*sizeof(float) <= params->wsize);
8071
8850
  }
8851
+
8072
8852
  const float * x = wdata;
8073
8853
  #endif
8074
8854
 
8075
-
8076
- #if defined(GGML_USE_CUBLAS)
8077
- // copy data to device
8078
- CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(float) * y_ne, cudaMemcpyHostToDevice, cudaStream));
8079
-
8080
- // compute
8081
- CUBLAS_CHECK(
8082
- cublasSgemm(cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
8083
- ne01, ne11, ne10,
8084
- &alpha, d_X, ne00,
8085
- d_Y, ne10,
8086
- &beta, d_D, ne01));
8087
-
8088
- // copy data to host
8089
- CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
8090
- #else
8855
+ #if defined(GGML_USE_CLBLAST)
8091
8856
  // zT = y * xT
8857
+ ggml_cl_sgemm_wrapper(GGML_BLAS_ORDER_ROW_MAJOR, GGML_BLAS_OP_N, GGML_BLAS_OP_T,
8858
+ ne11, ne01, ne10,
8859
+ 1.0f, y, ne10,
8860
+ x, ne10,
8861
+ 0.0f, d, ne01,
8862
+ type);
8863
+ #else
8092
8864
  cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
8093
8865
  ne11, ne01, ne10,
8094
8866
  1.0f, y, ne10,
@@ -8098,13 +8870,6 @@ static void ggml_compute_forward_mul_mat_q_f32(
8098
8870
  }
8099
8871
  }
8100
8872
 
8101
- #if defined(GGML_USE_CUBLAS)
8102
- CUDA_CHECK(cudaStreamSynchronize(cudaStream));
8103
- CUDA_CHECK(cudaFree(d_X));
8104
- CUDA_CHECK(cudaFree(d_Y));
8105
- CUDA_CHECK(cudaFree(d_D));
8106
- CUDA_CHECK(cudaFree(d_Q));
8107
- #endif
8108
8873
  //printf("CBLAS = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);
8109
8874
 
8110
8875
  return;
@@ -8113,7 +8878,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
8113
8878
 
8114
8879
  if (params->type == GGML_TASK_INIT) {
8115
8880
  char * wdata = params->wdata;
8116
- const size_t row_size = ne10*GGML_TYPE_SIZE[GGML_TYPE_Q8_0]/GGML_BLCK_SIZE[GGML_TYPE_Q8_0];
8881
+ const size_t row_size = ne10*GGML_TYPE_SIZE[vec_dot_type]/GGML_BLCK_SIZE[vec_dot_type];
8117
8882
 
8118
8883
  for (int64_t i13 = 0; i13 < ne13; ++i13) {
8119
8884
  for (int64_t i12 = 0; i12 < ne12; ++i12) {
@@ -8144,7 +8909,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
8144
8909
  const int ir1 = MIN(ir0 + dr, nr);
8145
8910
 
8146
8911
  void * wdata = params->wdata;
8147
- const size_t row_size = ne00*GGML_TYPE_SIZE[GGML_TYPE_Q8_0]/GGML_BLCK_SIZE[GGML_TYPE_Q8_0];
8912
+ const size_t row_size = ne00*GGML_TYPE_SIZE[vec_dot_type]/GGML_BLCK_SIZE[vec_dot_type];
8148
8913
 
8149
8914
  for (int ir = ir0; ir < ir1; ++ir) {
8150
8915
  // src0 indices
@@ -8193,8 +8958,10 @@ static void ggml_compute_forward_mul_mat(
8193
8958
  case GGML_TYPE_Q4_0:
8194
8959
  case GGML_TYPE_Q4_1:
8195
8960
  case GGML_TYPE_Q4_2:
8196
- case GGML_TYPE_Q4_3:
8961
+ case GGML_TYPE_Q5_0:
8962
+ case GGML_TYPE_Q5_1:
8197
8963
  case GGML_TYPE_Q8_0:
8964
+ case GGML_TYPE_Q8_1:
8198
8965
  {
8199
8966
  ggml_compute_forward_mul_mat_q_f32(params, src0, src1, dst);
8200
8967
  } break;
@@ -8422,8 +9189,10 @@ static void ggml_compute_forward_get_rows(
8422
9189
  case GGML_TYPE_Q4_0:
8423
9190
  case GGML_TYPE_Q4_1:
8424
9191
  case GGML_TYPE_Q4_2:
8425
- case GGML_TYPE_Q4_3:
9192
+ case GGML_TYPE_Q5_0:
9193
+ case GGML_TYPE_Q5_1:
8426
9194
  case GGML_TYPE_Q8_0:
9195
+ case GGML_TYPE_Q8_1:
8427
9196
  {
8428
9197
  ggml_compute_forward_get_rows_q(params, src0, src1, dst);
8429
9198
  } break;
@@ -8561,6 +9330,7 @@ static void ggml_compute_forward_soft_max_f32(
8561
9330
 
8562
9331
  uint16_t scvt;
8563
9332
  for (int i = 0; i < nc; i++) {
9333
+ //printf("p[%3d] = %8.4f\n", i, p[i]);
8564
9334
  if (p[i] == -INFINITY) {
8565
9335
  p[i] = 0.0f;
8566
9336
  } else {
@@ -8603,6 +9373,161 @@ static void ggml_compute_forward_soft_max(
8603
9373
  }
8604
9374
  }
8605
9375
 
9376
+ // ggml_compute_forward_alibi
9377
+
9378
+ static void ggml_compute_forward_alibi_f32(
9379
+ const struct ggml_compute_params * params,
9380
+ const struct ggml_tensor * src0,
9381
+ const struct ggml_tensor * src1,
9382
+ struct ggml_tensor * dst) {
9383
+ assert(params->ith == 0);
9384
+ assert(src1->type == GGML_TYPE_I32);
9385
+ assert(ggml_nelements(src1) == 2);
9386
+
9387
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
9388
+ return;
9389
+ }
9390
+
9391
+ const int n_past = ((int32_t *) src1->data)[0];
9392
+ const int n_head = ((int32_t *) src1->data)[1];
9393
+
9394
+ const int ne0 = src0->ne[0]; // all_seq_len = n_past + ne1
9395
+ const int ne1 = src0->ne[1]; // seq_len_without_past
9396
+ //const int ne2 = src0->ne[2]; // n_head -> this is k
9397
+ //const int ne3 = src0->ne[3]; // 1 -> bsz
9398
+
9399
+ const int n = ggml_nrows(src0);
9400
+ const int ne2_ne3 = n/ne1; // ne2*ne3
9401
+
9402
+ const int nb0 = src0->nb[0];
9403
+ const int nb1 = src0->nb[1];
9404
+ const int nb2 = src0->nb[2];
9405
+ //const int nb3 = src0->nb[3];
9406
+
9407
+ assert(nb0 == sizeof(float));
9408
+ assert(ne1 + n_past == ne0); (void) n_past;
9409
+
9410
+ // add alibi to src0 (KQ_scaled)
9411
+ const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
9412
+
9413
+ const float m0 = powf(2.0f, -8.0f / n_heads_log2_floor);
9414
+ const float m1 = powf(2.0f, -4.0f / n_heads_log2_floor);
9415
+
9416
+ for (int i = 0; i < ne0; i++) {
9417
+ for (int j = 0; j < ne1; j++) {
9418
+ for (int k = 0; k < ne2_ne3; k++) {
9419
+ float * const src = (float *)((char *) src0->data + i*nb0 + j*nb1 + k*nb2);
9420
+ float * pdst = (float *)((char *) dst->data + i*nb0 + j*nb1 + k*nb2);
9421
+
9422
+ // TODO: k*nb2 or k*nb3
9423
+
9424
+ float m_k;
9425
+
9426
+ if (k < n_heads_log2_floor) {
9427
+ m_k = powf(m0, k + 1);
9428
+ } else {
9429
+ m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1);
9430
+ }
9431
+
9432
+ pdst[0] = (j+1) * m_k + src[0];
9433
+ }
9434
+ }
9435
+ }
9436
+ }
9437
+
9438
+
9439
+ static void ggml_compute_forward_alibi_f16(
9440
+ const struct ggml_compute_params * params,
9441
+ const struct ggml_tensor * src0,
9442
+ const struct ggml_tensor * src1,
9443
+ struct ggml_tensor * dst) {
9444
+ assert(params->ith == 0);
9445
+ assert(src1->type == GGML_TYPE_I32);
9446
+ assert(ggml_nelements(src1) == 2);
9447
+
9448
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
9449
+ return;
9450
+ }
9451
+
9452
+ const int n_past = ((int32_t *) src1->data)[0];
9453
+ const int n_head = ((int32_t *) src1->data)[1];
9454
+
9455
+ const int ne0 = src0->ne[0]; // all_seq_len = n_past + ne1
9456
+ const int ne1 = src0->ne[1]; // seq_len_without_past
9457
+ //const int ne2 = src0->ne[2]; // n_head -> this is k
9458
+ //const int ne3 = src0->ne[3]; // 1 -> bsz
9459
+
9460
+ const int n = ggml_nrows(src0);
9461
+ const int ne2_ne3 = n/ne1; // ne2*ne3
9462
+
9463
+ const int nb0 = src0->nb[0];
9464
+ const int nb1 = src0->nb[1];
9465
+ const int nb2 = src0->nb[2];
9466
+ //const int nb3 = src0->nb[3];
9467
+
9468
+ assert(nb0 == sizeof(ggml_fp16_t));
9469
+ assert(ne1 + n_past == ne0); (void) n_past;
9470
+
9471
+ // add alibi to src0 (KQ_scaled)
9472
+ const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
9473
+
9474
+ const float m0 = powf(2.0f, -8.0f / n_heads_log2_floor);
9475
+ const float m1 = powf(2.0f, -4.0f / n_heads_log2_floor);
9476
+
9477
+ for (int i = 0; i < ne0; i++) {
9478
+ for (int j = 0; j < ne1; j++) {
9479
+ for (int k = 0; k < ne2_ne3; k++) {
9480
+ ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i*nb0 + j*nb1 + k*nb2);
9481
+ float * pdst = (float *)((char *) dst->data + i*nb0 + j*nb1 + k*nb2);
9482
+
9483
+ // TODO: k*nb2 or k*nb3
9484
+
9485
+ float m_k;
9486
+
9487
+ if (k < n_heads_log2_floor) {
9488
+ m_k = powf(m0, k + 1);
9489
+ } else {
9490
+ m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1);
9491
+ }
9492
+
9493
+ // we return F32
9494
+ pdst[0] = (j+1) * m_k + GGML_FP16_TO_FP32(src[0]);
9495
+ }
9496
+ }
9497
+ }
9498
+ }
9499
+
9500
+ static void ggml_compute_forward_alibi(
9501
+ const struct ggml_compute_params * params,
9502
+ const struct ggml_tensor * src0,
9503
+ const struct ggml_tensor * src1,
9504
+ struct ggml_tensor * dst) {
9505
+ switch (src0->type) {
9506
+ case GGML_TYPE_F16:
9507
+ {
9508
+ ggml_compute_forward_alibi_f16(params, src0, src1, dst);
9509
+ } break;
9510
+ case GGML_TYPE_F32:
9511
+ {
9512
+ ggml_compute_forward_alibi_f32(params, src0, src1, dst);
9513
+ } break;
9514
+ case GGML_TYPE_Q4_0:
9515
+ case GGML_TYPE_Q4_1:
9516
+ case GGML_TYPE_Q4_2:
9517
+ case GGML_TYPE_Q5_0:
9518
+ case GGML_TYPE_Q5_1:
9519
+ case GGML_TYPE_Q8_0:
9520
+ case GGML_TYPE_Q8_1:
9521
+ case GGML_TYPE_I8:
9522
+ case GGML_TYPE_I16:
9523
+ case GGML_TYPE_I32:
9524
+ case GGML_TYPE_COUNT:
9525
+ {
9526
+ GGML_ASSERT(false);
9527
+ } break;
9528
+ }
9529
+ }
9530
+
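The slope schedule above follows the ALiBi paper: for a power-of-two head count the slopes form the geometric series 2^(-8/n_head), 2^(-16/n_head), ...; when n_head is not a power of two, the first n_heads_log2_floor heads use that series and the remaining heads interleave a second series built from m1. For example, with n_head = 12 the first 8 heads get 0.5, 0.25, ..., 1/256 and heads 8..11 get 2^-0.5, 2^-1.5, 2^-2.5, 2^-3.5. A standalone sketch that just prints the per-head slopes (compile with -lm):

#include <math.h>
#include <stdio.h>

int main(void) {
    const int n_head  = 12;                               // deliberately not a power of two
    const int n_floor = 1 << (int) floor(log2(n_head));   // 8
    const float m0 = powf(2.0f, -8.0f / n_floor);         // 0.5
    const float m1 = powf(2.0f, -4.0f / n_floor);         // 2^-0.5

    for (int k = 0; k < n_head; k++) {
        const float m_k = k < n_floor ? powf(m0, k + 1)
                                      : powf(m1, 2*(k - n_floor) + 1);
        printf("head %2d: slope %.6f\n", k, m_k);
    }
    return 0;
}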
8606
9531
  // ggml_compute_forward_rope
8607
9532
 
8608
9533
  static void ggml_compute_forward_rope_f32(
@@ -10241,6 +11166,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
10241
11166
  {
10242
11167
  ggml_compute_forward_rope(params, tensor->src0, tensor->src1, tensor);
10243
11168
  } break;
11169
+ case GGML_OP_ALIBI:
11170
+ {
11171
+ ggml_compute_forward_alibi(params, tensor->src0, tensor->src1, tensor);
11172
+ } break;
10244
11173
  case GGML_OP_CONV_1D_1S:
10245
11174
  {
10246
11175
  ggml_compute_forward_conv_1d_1s(params, tensor->src0, tensor->src1, tensor);
@@ -10443,6 +11372,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
10443
11372
  {
10444
11373
  GGML_ASSERT(false); // TODO: not implemented
10445
11374
  } break;
11375
+ case GGML_OP_ALIBI:
11376
+ {
11377
+ GGML_ASSERT(false); // TODO: not implemented
11378
+ } break;
10446
11379
  case GGML_OP_SILU:
10447
11380
  {
10448
11381
  GGML_ASSERT(false); // TODO: not implemented
@@ -10920,15 +11853,21 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
10920
11853
 
10921
11854
  size_t cur = 0;
10922
11855
 
11856
+ #if defined(GGML_USE_CUBLAS)
11857
+ if (ggml_cuda_can_mul_mat(node->src0, node->src1, node)) {
11858
+ node->n_tasks = 1; // TODO: this actually is doing nothing
11859
+ // the threads are still spinning
11860
+ cur = ggml_cuda_mul_mat_get_wsize(node->src0, node->src1, node);
11861
+ }
11862
+ else
11863
+ #endif
10923
11864
  if (node->src0->type == GGML_TYPE_F16 && node->src1->type == GGML_TYPE_F32) {
10924
- #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
11865
+ #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
10925
11866
  if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
10926
11867
  node->n_tasks = 1; // TODO: this actually is doing nothing
10927
11868
  // the threads are still spinning
11869
+ // here we need memory just for single 2D matrix from src0
10928
11870
  cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]);
10929
- //printf("src0: ne0 = %d, ne1 = %d, ne = %d\n", node->src0->ne[0], node->src0->ne[1], node->src0->ne[0]*node->src0->ne[1]);
10930
- //printf("src1: ne0 = %d, ne1 = %d, ne = %d\n", node->src1->ne[0], node->src1->ne[1], node->src1->ne[0]*node->src1->ne[1]);
10931
- //printf("cur = %zu\n", cur);
10932
11871
  } else {
10933
11872
  cur = GGML_TYPE_SIZE[GGML_TYPE_F16]*ggml_nelements(node->src1);
10934
11873
  }
@@ -10937,15 +11876,21 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
10937
11876
  #endif
10938
11877
  } else if (node->src0->type == GGML_TYPE_F32 && node->src1->type == GGML_TYPE_F32) {
10939
11878
  cur = 0;
11879
+ #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
11880
+ if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
11881
+ node->n_tasks = 1;
11882
+ }
11883
+ #endif
10940
11884
  } else if (ggml_is_quantized(node->src0->type) && node->src1->type == GGML_TYPE_F32) {
10941
- #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
11885
+ #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
10942
11886
  if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
10943
11887
  node->n_tasks = 1;
10944
11888
  cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]);
10945
11889
  } else
10946
11890
  #endif
10947
11891
  {
10948
- cur = GGML_TYPE_SIZE[GGML_TYPE_Q8_0]*ggml_nelements(node->src1)/GGML_BLCK_SIZE[GGML_TYPE_Q8_0];
11892
+ const enum ggml_type type_q = quantize_fns[node->src0->type].vec_dot_type;
11893
+ cur = GGML_TYPE_SIZE[type_q]*ggml_nelements(node->src1)/GGML_BLCK_SIZE[type_q];
10949
11894
  }
10950
11895
  } else {
10951
11896
  GGML_ASSERT(false);
@@ -10975,6 +11920,10 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
10975
11920
  {
10976
11921
  node->n_tasks = n_threads;
10977
11922
  } break;
11923
+ case GGML_OP_ALIBI:
11924
+ {
11925
+ node->n_tasks = 1; //TODO
11926
+ } break;
10978
11927
  case GGML_OP_CONV_1D_1S:
10979
11928
  case GGML_OP_CONV_1D_2S:
10980
11929
  {
@@ -11273,9 +12222,9 @@ void ggml_graph_print(const struct ggml_cgraph * cgraph) {
11273
12222
  for (int i = 0; i < cgraph->n_nodes; i++) {
11274
12223
  struct ggml_tensor * node = cgraph->nodes[i];
11275
12224
 
11276
- perf_total_per_op_us[node->op] += node->perf_time_us;
12225
+ perf_total_per_op_us[node->op] += MAX(1, node->perf_time_us);
11277
12226
 
11278
- GGML_PRINT(" - %3d: [ %" PRId64 ", %" PRId64 ", %" PRId64 "] %16s %s (%3d) cpu = %7.3f / %7.3f ms, wall = %7.3f / %7.3f ms\n",
12227
+ GGML_PRINT(" - %3d: [ %5" PRId64 ", %5" PRId64 ", %5" PRId64 "] %16s %s (%3d) cpu = %7.3f / %7.3f ms, wall = %7.3f / %7.3f ms\n",
11279
12228
  i,
11280
12229
  node->ne[0], node->ne[1], node->ne[2],
11281
12230
  GGML_OP_LABEL[node->op], node->is_param ? "x" : node->grad ? "g" : " ", node->perf_runs,
@@ -11289,13 +12238,17 @@ void ggml_graph_print(const struct ggml_cgraph * cgraph) {
11289
12238
  for (int i = 0; i < cgraph->n_leafs; i++) {
11290
12239
  struct ggml_tensor * node = cgraph->leafs[i];
11291
12240
 
11292
- GGML_PRINT(" - %3d: [ %" PRId64 ", %" PRId64 "] %8s\n",
12241
+ GGML_PRINT(" - %3d: [ %5" PRId64 ", %5" PRId64 "] %8s\n",
11293
12242
  i,
11294
12243
  node->ne[0], node->ne[1],
11295
12244
  GGML_OP_LABEL[node->op]);
11296
12245
  }
11297
12246
 
11298
12247
  for (int i = 0; i < GGML_OP_COUNT; i++) {
12248
+ if (perf_total_per_op_us[i] == 0) {
12249
+ continue;
12250
+ }
12251
+
11299
12252
  GGML_PRINT("perf_total_per_op_us[%16s] = %7.3f ms\n", GGML_OP_LABEL[i], (double) perf_total_per_op_us[i] / 1000.0);
11300
12253
  }
11301
12254
 
@@ -11358,10 +12311,16 @@ void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph
11358
12311
  snprintf(color, sizeof(color), "white");
11359
12312
  }
11360
12313
 
11361
- fprintf(fp, " \"%p\" [ \
11362
- style = filled; fillcolor = %s; shape = record; \
11363
- label=\"%d [%" PRId64 ", %" PRId64 "] | <x>%s",
11364
- (void *) node, color,
12314
+ fprintf(fp, " \"%p\" [ "
12315
+ "style = filled; fillcolor = %s; shape = record; "
12316
+ "label=\"",
12317
+ (void *) node, color);
12318
+
12319
+ if (strlen(node->name) > 0) {
12320
+ fprintf(fp, "%s |", node->name);
12321
+ }
12322
+
12323
+ fprintf(fp, "%d [%" PRId64 ", %" PRId64 "] | <x>%s",
11365
12324
  i, node->ne[0], node->ne[1],
11366
12325
  GGML_OP_SYMBOL[node->op]);
11367
12326
 
@@ -11377,18 +12336,26 @@ label=\"%d [%" PRId64 ", %" PRId64 "] | <x>%s",
11377
12336
 
11378
12337
  snprintf(color, sizeof(color), "pink");
11379
12338
 
12339
+ fprintf(fp, " \"%p\" [ "
12340
+ "style = filled; fillcolor = %s; shape = record; "
12341
+ "label=\"<x>",
12342
+ (void *) node, color);
12343
+
12344
+ if (strlen(node->name) > 0) {
12345
+ fprintf(fp, "%s | ", node->name);
12346
+ }
11380
12347
  if (ggml_nelements(node) == 1) {
11381
- fprintf(fp, " \"%p\" [ \
11382
- style = filled; fillcolor = %s; shape = record; \
11383
- label=\"<x>%.1e\"; ]\n",
11384
- (void *) node, color, (double)ggml_get_f32_1d(node, 0));
11385
- } else {
11386
- fprintf(fp, " \"%p\" [ \
11387
- style = filled; fillcolor = %s; shape = record; \
11388
- label=\"<x>CONST %d [%" PRId64 ", %" PRId64 "]\"; ]\n",
11389
- (void *) node, color,
11390
- i, node->ne[0], node->ne[1]);
12348
+ if (node->type == GGML_TYPE_I8 || node->type == GGML_TYPE_I16 || node->type == GGML_TYPE_I32) {
12349
+ fprintf(fp, "%d", ggml_get_i32_1d(node, 0));
12350
+ }
12351
+ else {
12352
+ fprintf(fp, "%.1e", (double)ggml_get_f32_1d(node, 0));
12353
+ }
11391
12354
  }
12355
+ else {
12356
+ fprintf(fp, "CONST %d [%" PRId64 ", %" PRId64 "]", i, node->ne[0], node->ne[1]);
12357
+ }
12358
+ fprintf(fp, "\"; ]\n");
11392
12359
  }
11393
12360
 
11394
12361
  for (int i = 0; i < gb->n_nodes; i++) {
@@ -12129,7 +13096,7 @@ size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t *
12129
13096
 
12130
13097
  for (int i = 0; i < nb; i++) {
12131
13098
  for (int l = 0; l < QK4_0; l += 2) {
12132
- const uint8_t vi0 = y[i].qs[l/2] & 0xF;
13099
+ const uint8_t vi0 = y[i].qs[l/2] & 0x0F;
12133
13100
  const uint8_t vi1 = y[i].qs[l/2] >> 4;
12134
13101
 
12135
13102
  hist[vi0]++;
@@ -12152,7 +13119,7 @@ size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t *
12152
13119
 
12153
13120
  for (int i = 0; i < nb; i++) {
12154
13121
  for (int l = 0; l < QK4_1; l += 2) {
12155
- const uint8_t vi0 = y[i].qs[l/2] & 0xF;
13122
+ const uint8_t vi0 = y[i].qs[l/2] & 0x0F;
12156
13123
  const uint8_t vi1 = y[i].qs[l/2] >> 4;
12157
13124
 
12158
13125
  hist[vi0]++;
@@ -12171,12 +13138,11 @@ size_t ggml_quantize_q4_2(const float * src, void * dst, int n, int k, int64_t *
12171
13138
  for (int j = 0; j < n; j += k) {
12172
13139
  block_q4_2 * restrict y = (block_q4_2 *)dst + j/QK4_2;
12173
13140
 
12174
- //quantize_row_q4_2_reference(src + j, y, k);
12175
- quantize_row_q4_2_rmse(src + j, y, k);
13141
+ quantize_row_q4_2_reference(src + j, y, k);
12176
13142
 
12177
13143
  for (int i = 0; i < nb; i++) {
12178
13144
  for (int l = 0; l < QK4_2; l += 2) {
12179
- const uint8_t vi0 = y[i].qs[l/2] & 0xF;
13145
+ const uint8_t vi0 = y[i].qs[l/2] & 0x0F;
12180
13146
  const uint8_t vi1 = y[i].qs[l/2] >> 4;
12181
13147
 
12182
13148
  hist[vi0]++;
@@ -12188,19 +13154,56 @@ size_t ggml_quantize_q4_2(const float * src, void * dst, int n, int k, int64_t *
12188
13154
  return (n/QK4_2*sizeof(block_q4_2));
12189
13155
  }
12190
13156
 
12191
- size_t ggml_quantize_q4_3(const float * src, void * dst, int n, int k, int64_t * hist) {
12192
- assert(k % QK4_3 == 0);
12193
- const int nb = k / QK4_3;
13157
+ size_t ggml_quantize_q5_0(const float * src, void * dst, int n, int k, int64_t * hist) {
13158
+ assert(k % QK5_0 == 0);
13159
+ const int nb = k / QK5_0;
12194
13160
 
12195
13161
  for (int j = 0; j < n; j += k) {
12196
- block_q4_3 * restrict y = (block_q4_3 *)dst + j/QK4_3;
13162
+ block_q5_0 * restrict y = (block_q5_0 *)dst + j/QK5_0;
12197
13163
 
12198
- quantize_row_q4_3_reference(src + j, y, k);
13164
+ quantize_row_q5_0_reference(src + j, y, k);
12199
13165
 
12200
13166
  for (int i = 0; i < nb; i++) {
12201
- for (int l = 0; l < QK4_3; l += 2) {
12202
- const uint8_t vi0 = y[i].qs[l/2] & 0xF;
12203
- const uint8_t vi1 = y[i].qs[l/2] >> 4;
13167
+ uint32_t qh;
13168
+ memcpy(&qh, &y[i].qh, sizeof(qh));
13169
+
13170
+ for (int l = 0; l < QK5_0; l += 2) {
13171
+ const uint8_t vh0 = ((qh & (1u << (l + 0))) >> (l + 0)) << 4;
13172
+ const uint8_t vh1 = ((qh & (1u << (l + 1))) >> (l + 1)) << 4;
13173
+
13174
+ // cast to 16 bins
13175
+ const uint8_t vi0 = ((y[i].qs[l/2] & 0x0F) | vh0) / 2;
13176
+ const uint8_t vi1 = ((y[i].qs[l/2] >> 4) | vh1) / 2;
13177
+
13178
+ hist[vi0]++;
13179
+ hist[vi1]++;
13180
+ }
13181
+ }
13182
+ }
13183
+
13184
+ return (n/QK5_0*sizeof(block_q5_0));
13185
+ }
13186
+
13187
+ size_t ggml_quantize_q5_1(const float * src, void * dst, int n, int k, int64_t * hist) {
13188
+ assert(k % QK5_1 == 0);
13189
+ const int nb = k / QK5_1;
13190
+
13191
+ for (int j = 0; j < n; j += k) {
13192
+ block_q5_1 * restrict y = (block_q5_1 *)dst + j/QK5_1;
13193
+
13194
+ quantize_row_q5_1_reference(src + j, y, k);
13195
+
13196
+ for (int i = 0; i < nb; i++) {
13197
+ uint32_t qh;
13198
+ memcpy(&qh, &y[i].qh, sizeof(qh));
13199
+
13200
+ for (int l = 0; l < QK5_1; l += 2) {
13201
+ const uint8_t vh0 = ((qh & (1u << (l + 0))) >> (l + 0)) << 4;
13202
+ const uint8_t vh1 = ((qh & (1u << (l + 1))) >> (l + 1)) << 4;
13203
+
13204
+ // cast to 16 bins
13205
+ const uint8_t vi0 = ((y[i].qs[l/2] & 0x0F) | vh0) / 2;
13206
+ const uint8_t vi1 = ((y[i].qs[l/2] >> 4) | vh1) / 2;
12204
13207
 
12205
13208
  hist[vi0]++;
12206
13209
  hist[vi1]++;
@@ -12208,7 +13211,28 @@ size_t ggml_quantize_q4_3(const float * src, void * dst, int n, int k, int64_t *
12208
13211
  }
12209
13212
  }
12210
13213
 
12211
- return (n/QK4_3*sizeof(block_q4_3));
13214
+ return (n/QK5_1*sizeof(block_q5_1));
13215
+ }
13216
+
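For the q5 types each weight is stored as a 4-bit low nibble in qs plus one extra high bit packed into the 32-bit qh field, giving values in 0..31; the histogram loops above reassemble the 5-bit value and halve it so it still fits the caller's 16-bin histogram. The reconstruction in isolation (illustrative helper, not ggml API):

#include <stdint.h>

static inline void demo_q5_bins(uint32_t qh, const uint8_t * qs, int l,
                                uint8_t * bin0, uint8_t * bin1) {
    const uint8_t vh0 = ((qh >> (l + 0)) & 1) << 4;   // 5th bit of the even element
    const uint8_t vh1 = ((qh >> (l + 1)) & 1) << 4;   // 5th bit of the odd element

    const uint8_t v0 = (qs[l/2] & 0x0F) | vh0;        // full 5-bit value, 0..31
    const uint8_t v1 = (qs[l/2] >>   4) | vh1;

    *bin0 = v0 / 2;                                   // fold 32 levels into 16 bins
    *bin1 = v1 / 2;
}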
13217
+ size_t ggml_quantize_q8_0(const float * src, void * dst, int n, int k, int64_t * hist) {
13218
+ assert(k % QK8_0 == 0);
13219
+ const int nb = k / QK8_0;
13220
+
13221
+ for (int j = 0; j < n; j += k) {
13222
+ block_q8_0 * restrict y = (block_q8_0 *)dst + j/QK8_0;
13223
+
13224
+ quantize_row_q8_0_reference(src + j, y, k);
13225
+
13226
+ for (int i = 0; i < nb; i++) {
13227
+ for (int l = 0; l < QK8_0; ++l) {
13228
+ const int8_t vi = y[i].qs[l];
13229
+
13230
+ hist[vi/16 + 8]++;
13231
+ }
13232
+ }
13233
+ }
13234
+
13235
+ return (n/QK8_0*sizeof(block_q8_0));
12212
13236
  }
12213
13237
 
12214
13238
  size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, int start, int n, int64_t * hist) {
@@ -12232,11 +13256,23 @@ size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, i
12232
13256
  block_q4_2 * block = (block_q4_2*)dst + start / QK4_2;
12233
13257
  result = ggml_quantize_q4_2(src + start, block, n, n, hist);
12234
13258
  } break;
12235
- case GGML_TYPE_Q4_3:
13259
+ case GGML_TYPE_Q5_0:
13260
+ {
13261
+ GGML_ASSERT(start % QK5_0 == 0);
13262
+ block_q5_0 * block = (block_q5_0*)dst + start / QK5_0;
13263
+ result = ggml_quantize_q5_0(src + start, block, n, n, hist);
13264
+ } break;
13265
+ case GGML_TYPE_Q5_1:
13266
+ {
13267
+ GGML_ASSERT(start % QK5_1 == 0);
13268
+ block_q5_1 * block = (block_q5_1*)dst + start / QK5_1;
13269
+ result = ggml_quantize_q5_1(src + start, block, n, n, hist);
13270
+ } break;
13271
+ case GGML_TYPE_Q8_0:
12236
13272
  {
12237
- GGML_ASSERT(start % QK4_3 == 0);
12238
- block_q4_3 * block = (block_q4_3*)dst + start / QK4_3;
12239
- result = ggml_quantize_q4_3(src + start, block, n, n, hist);
13273
+ GGML_ASSERT(start % QK8_0 == 0);
13274
+ block_q8_0 * block = (block_q8_0*)dst + start / QK8_0;
13275
+ result = ggml_quantize_q8_0(src + start, block, n, n, hist);
12240
13276
  } break;
12241
13277
  default:
12242
13278
  assert(false);
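A minimal usage sketch for the dispatcher above, assuming a float buffer whose length is a multiple of the q5_0 block size (32) and the 16-bin histogram the quantizers expect; the buffer sizes here are placeholders:

float   src[4096];                    // filled elsewhere
uint8_t dst[4096];                    // comfortably larger than 4096/32 * sizeof(block_q5_0)
int64_t hist[16] = { 0 };

const size_t bytes = ggml_quantize_chunk(GGML_TYPE_Q5_0, src, dst, /*start=*/0, /*n=*/4096, hist);
printf("q5_0: %zu bytes\n", bytes);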
@@ -12335,7 +13371,7 @@ int ggml_cpu_has_wasm_simd(void) {
12335
13371
  }
12336
13372
 
12337
13373
  int ggml_cpu_has_blas(void) {
12338
- #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
13374
+ #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
12339
13375
  return 1;
12340
13376
  #else
12341
13377
  return 0;
@@ -12350,6 +13386,18 @@ int ggml_cpu_has_cublas(void) {
12350
13386
  #endif
12351
13387
  }
12352
13388
 
13389
+ int ggml_cpu_has_clblast(void) {
13390
+ #if defined(GGML_USE_CLBLAST)
13391
+ return 1;
13392
+ #else
13393
+ return 0;
13394
+ #endif
13395
+ }
13396
+
13397
+ int ggml_cpu_has_gpublas(void) {
13398
+ return ggml_cpu_has_cublas() || ggml_cpu_has_clblast();
13399
+ }
13400
+
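The new probes slot into the existing family of compile-time capability checks; a frontend can report them alongside the others, for example:

printf("BLAS = %d | cuBLAS = %d | CLBlast = %d | GPU BLAS = %d\n",
       ggml_cpu_has_blas(),    ggml_cpu_has_cublas(),
       ggml_cpu_has_clblast(), ggml_cpu_has_gpublas());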
12353
13401
  int ggml_cpu_has_sse3(void) {
12354
13402
  #if defined(__SSE3__)
12355
13403
  return 1;