llama_cpp 0.0.6 → 0.1.0

This diff shows the content of publicly released package versions as published to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
@@ -135,57 +135,14 @@ inline static void* ggml_aligned_malloc(size_t size) {
135
135
  #define UNUSED(x) (void)(x)
136
136
  #define SWAP(x, y, T) do { T SWAP = x; x = y; y = SWAP; } while (0)
137
137
 
138
- #define GGML_ASSERT(x) \
139
- do { \
140
- if (!(x)) { \
141
- fprintf(stderr, "GGML_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \
142
- abort(); \
143
- } \
144
- } while (0)
145
-
146
138
  #if defined(GGML_USE_ACCELERATE)
147
139
  #include <Accelerate/Accelerate.h>
148
140
  #elif defined(GGML_USE_OPENBLAS)
149
141
  #include <cblas.h>
150
142
  #elif defined(GGML_USE_CUBLAS)
151
- #include <cublas_v2.h>
152
- #include <cuda_runtime.h>
153
143
  #include "ggml-cuda.h"
154
-
155
- #define CUDA_CHECK(err) \
156
- do { \
157
- cudaError_t err_ = (err); \
158
- if (err_ != cudaSuccess) { \
159
- printf("CUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__, \
160
- cudaGetErrorString(err_)); \
161
- exit(1); \
162
- } \
163
- } while (0)
164
-
165
- #define CUBLAS_CHECK(err) \
166
- do { \
167
- cublasStatus_t err_ = (err); \
168
- if (err_ != CUBLAS_STATUS_SUCCESS) { \
169
- printf("cuBLAS error %d at %s:%d\n", err_, __FILE__, __LINE__); \
170
- exit(1); \
171
- } \
172
- } while (0)
173
-
174
- static cublasHandle_t cublasH = NULL;
175
- static cudaStream_t cudaStream = NULL;
176
- static void init_cublas(void) {
177
- if (cublasH == NULL) {
178
- // create cublas handle, bind a stream
179
- CUBLAS_CHECK(cublasCreate(&cublasH));
180
-
181
- CUDA_CHECK(cudaStreamCreateWithFlags(&cudaStream, cudaStreamNonBlocking));
182
-
183
- CUBLAS_CHECK(cublasSetStream(cublasH, cudaStream));
184
-
185
- // configure logging to stdout
186
- // CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, NULL));
187
- }
188
- }
144
+ #elif defined(GGML_USE_CLBLAST)
145
+ #include "ggml-opencl.h"
189
146
  #endif
190
147
 
191
148
  #undef MIN
@@ -223,9 +180,13 @@ typedef double ggml_float;
223
180
  #undef bool
224
181
  #define bool _Bool
225
182
  #else
183
+ #if defined(_MSC_VER) || defined(__MINGW32__)
184
+ #include <intrin.h>
185
+ #else
226
186
  #include <immintrin.h>
227
187
  #endif
228
188
  #endif
189
+ #endif
229
190
 
230
191
  #ifdef __F16C__
231
192
 
@@ -365,6 +326,20 @@ static ggml_fp16_t table_exp_f16[1 << 16];
365
326
  // precomputed f32 table for f16 (256 KB)
366
327
  static float table_f32_f16[1 << 16];
367
328
 
329
+ #if defined(__ARM_NEON) || defined(__wasm_simd128__)
330
+ #define B1(c,s,n) 0x ## n ## c , 0x ## n ## s
331
+ #define B2(c,s,n) B1(c,s,n ## c), B1(c,s,n ## s)
332
+ #define B3(c,s,n) B2(c,s,n ## c), B2(c,s,n ## s)
333
+ #define B4(c,s,n) B3(c,s,n ## c), B3(c,s,n ## s)
334
+ #define B5(c,s,n) B4(c,s,n ## c), B4(c,s,n ## s)
335
+ #define B6(c,s,n) B5(c,s,n ## c), B5(c,s,n ## s)
336
+ #define B7(c,s,n) B6(c,s,n ## c), B6(c,s,n ## s)
337
+ #define B8(c,s ) B7(c,s, c), B7(c,s, s)
338
+
339
+ // precomputed tables for expanding 8bits to 8 bytes (shl 4)
340
+ static const uint64_t table_b2b_u[1 << 8] = { B8(00, 10) };
341
+ #endif
342
+
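The B1..B8 macros above build a 256-entry table at compile time: entry i of table_b2b_u spreads the 8 bits of i across 8 bytes, writing 0x10 (the "shl 4" bit) into byte j whenever bit j of i is set. A minimal scalar sketch of the same mapping, for illustration only (b2b_u_scalar is a hypothetical helper, not part of the package):

    #include <assert.h>
    #include <stdint.h>

    // Scalar equivalent of table_b2b_u[i]: bit j of i -> 0x10 in byte j of the result.
    static uint64_t b2b_u_scalar(uint8_t i) {
        uint64_t r = 0;
        for (int j = 0; j < 8; j++) {
            if (i & (1u << j)) {
                r |= (uint64_t)0x10 << (8*j);
            }
        }
        return r;
    }

    int main(void) {
        assert(b2b_u_scalar(0x05) == 0x0000000000100010ULL); // bits 0 and 2 -> bytes 0 and 2
        assert(b2b_u_scalar(0xFF) == 0x1010101010101010ULL);
        return 0;
    }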
368
343
  // On ARM NEON, it's quicker to directly convert x -> x instead of calling into ggml_lookup_fp16_to_fp32,
369
344
  // so we define GGML_FP16_TO_FP32 and GGML_FP32_TO_FP16 elsewhere for NEON.
370
345
  // This is also true for POWER9.
@@ -391,6 +366,32 @@ ggml_fp16_t ggml_fp32_to_fp16(float x) {
391
366
  return GGML_FP32_TO_FP16(x);
392
367
  }
393
368
 
369
+ void ggml_fp16_to_fp32_row(const ggml_fp16_t * x, float * y, size_t n) {
370
+ for (size_t i = 0; i < n; i++) {
371
+ y[i] = GGML_FP16_TO_FP32(x[i]);
372
+ }
373
+ }
374
+
375
+ void ggml_fp32_to_fp16_row(const float * x, ggml_fp16_t * y, size_t n) {
376
+ size_t i = 0;
377
+ #if defined(__F16C__)
378
+ for (; i + 7 < n; i += 8) {
379
+ __m256 x_vec = _mm256_loadu_ps(x + i);
380
+ __m128i y_vec = _mm256_cvtps_ph(x_vec, _MM_FROUND_TO_NEAREST_INT);
381
+ _mm_storeu_si128((__m128i *)(y + i), y_vec);
382
+ }
383
+ for(; i + 3 < n; i += 4) {
384
+ __m128 x_vec = _mm_loadu_ps(x + i);
385
+ __m128i y_vec = _mm_cvtps_ph(x_vec, _MM_FROUND_TO_NEAREST_INT);
386
+ _mm_storel_epi64((__m128i *)(y + i), y_vec);
387
+ }
388
+ #endif
389
+ for (; i < n; i++) {
390
+ y[i] = GGML_FP32_TO_FP16(x[i]);
391
+ }
392
+ }
393
+
394
+
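The two row converters above are new in this version: the fp16-to-fp32 direction is a plain scalar loop, while fp32-to-fp16 takes an F16C fast path (8-wide, then 4-wide) before falling back to GGML_FP32_TO_FP16 for the tail. A hedged usage sketch, assuming the release's ggml.h exposes these functions (check the header you ship):

    #include <stdio.h>
    #include "ggml.h"

    int main(void) {
        float src[6]  = { -1.5f, -0.25f, 0.0f, 0.1f, 2.5f, 100.0f };
        ggml_fp16_t h[6];
        float back[6];

        ggml_fp32_to_fp16_row(src, h, 6);   // F16C path handles 8-/4-element chunks, scalar tail does the rest
        ggml_fp16_to_fp32_row(h, back, 6);  // plain scalar loop

        for (int i = 0; i < 6; i++) {
            printf("%g -> %g\n", src[i], back[i]); // values round-trip within fp16 precision
        }
        return 0;
    }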
394
395
  //
395
396
  // timing
396
397
  //
@@ -473,7 +474,7 @@ static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
473
474
  static inline __m128i bytes_from_nibbles_16(const uint8_t * rsi)
474
475
  {
475
476
  // Load 8 bytes from memory
476
- __m128i tmp = _mm_loadu_si64( ( const __m128i* )rsi );
477
+ __m128i tmp = _mm_loadl_epi64( ( const __m128i* )rsi );
477
478
 
478
479
  // Expand bytes into uint16_t values
479
480
  __m128i bytes = _mm_cvtepu8_epi16( tmp );
@@ -487,7 +488,46 @@ static inline __m128i bytes_from_nibbles_16(const uint8_t * rsi)
487
488
  return bytes;
488
489
  }
489
490
 
491
+ // horizontally add 8 floats
492
+ static inline float hsum_float_8(const __m256 x) {
493
+ __m128 res = _mm256_extractf128_ps(x, 1);
494
+ res = _mm_add_ps(res, _mm256_castps256_ps128(x));
495
+ res = _mm_add_ps(res, _mm_movehl_ps(res, res));
496
+ res = _mm_add_ss(res, _mm_movehdup_ps(res));
497
+ return _mm_cvtss_f32(res);
498
+ }
499
+
500
+ // horizontally add 8 int32_t
501
+ static inline int hsum_i32_8(const __m256i a) {
502
+ const __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(a), _mm256_extractf128_si256(a, 1));
503
+ const __m128i hi64 = _mm_unpackhi_epi64(sum128, sum128);
504
+ const __m128i sum64 = _mm_add_epi32(hi64, sum128);
505
+ const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
506
+ return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
507
+ }
508
+
509
+ // horizontally add 4 int32_t
510
+ static inline int hsum_i32_4(const __m128i a) {
511
+ const __m128i hi64 = _mm_unpackhi_epi64(a, a);
512
+ const __m128i sum64 = _mm_add_epi32(hi64, a);
513
+ const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
514
+ return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
515
+ }
516
+
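hsum_float_8, hsum_i32_8 and hsum_i32_4 are horizontal reductions: they fold the upper half onto the lower half until a single lane remains. A self-contained sanity check for the float variant (same body as above, wrapped in a standalone program; build with -mavx or newer):

    #include <immintrin.h>
    #include <stdio.h>

    // copied from the snippet above: fold 256 -> 128 -> 64 -> 32 bits, then extract lane 0
    static inline float hsum_float_8(const __m256 x) {
        __m128 res = _mm256_extractf128_ps(x, 1);
        res = _mm_add_ps(res, _mm256_castps256_ps128(x));
        res = _mm_add_ps(res, _mm_movehl_ps(res, res));
        res = _mm_add_ss(res, _mm_movehdup_ps(res));
        return _mm_cvtss_f32(res);
    }

    int main(void) {
        const float v[8] = { 1, 2, 3, 4, 5, 6, 7, 8 };
        printf("%f\n", hsum_float_8(_mm256_loadu_ps(v))); // expect 36.000000
        return 0;
    }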
490
517
  #if __AVX2__ || __AVX512F__
518
+ // spread 32 bits to 32 bytes { 0x00, 0xFF }
519
+ static inline __m256i bytes_from_bits_32(const uint8_t * x) {
520
+ uint32_t x32;
521
+ memcpy(&x32, x, sizeof(uint32_t));
522
+ const __m256i shuf_mask = _mm256_set_epi64x(
523
+ 0x0303030303030303, 0x0202020202020202,
524
+ 0x0101010101010101, 0x0000000000000000);
525
+ __m256i bytes = _mm256_shuffle_epi8(_mm256_set1_epi32(x32), shuf_mask);
526
+ const __m256i bit_mask = _mm256_set1_epi64x(0x7fbfdfeff7fbfdfe);
527
+ bytes = _mm256_or_si256(bytes, bit_mask);
528
+ return _mm256_cmpeq_epi8(bytes, _mm256_set1_epi64x(-1));
529
+ }
530
+
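bytes_from_bits_32 turns a 32-bit mask into 32 bytes of 0x00/0xFF: the shuffle replicates each source byte across 8 output bytes, the OR with the inverted bit mask leaves a byte equal to 0xFF only if its bit was set, and the final compare widens that to a full 0x00/0xFF lane. A scalar model of the same transform (illustrative only; bytes_from_bits_32_ref is a hypothetical name):

    #include <stdint.h>
    #include <string.h>

    // Byte j of the output is 0xFF if bit j of the 32-bit input is set, else 0x00.
    static void bytes_from_bits_32_ref(const uint8_t * x, uint8_t out[32]) {
        uint32_t x32;
        memcpy(&x32, x, sizeof(uint32_t));
        for (int j = 0; j < 32; j++) {
            out[j] = (x32 >> j) & 1 ? 0xFF : 0x00;
        }
    }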
491
531
  // Unpack 32 4-bit fields into 32 bytes
492
532
  // The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
493
533
  static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)
@@ -507,9 +547,38 @@ static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)
507
547
  return bytes;
508
548
  }
509
549
 
550
+ // add int16_t pairwise and return as float vector
551
+ static inline __m256 sum_i16_pairs_float(const __m256i x) {
552
+ const __m256i ones = _mm256_set1_epi16(1);
553
+ const __m256i summed_pairs = _mm256_madd_epi16(ones, x);
554
+ return _mm256_cvtepi32_ps(summed_pairs);
555
+ }
556
+
557
+ // multiply int8_t, add results pairwise twice and return as float vector
558
+ static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
559
+ // Get absolute values of x vectors
560
+ const __m256i ax = _mm256_sign_epi8(x, x);
561
+ // Sign the values of the y vectors
562
+ const __m256i sy = _mm256_sign_epi8(y, x);
563
+ #if __AVXVNNI__
564
+ const __m256i zero = _mm256_setzero_si256();
565
+ const __m256i summed_pairs = _mm256_dpbusd_epi32(zero, ax, sy);
566
+ return _mm256_cvtepi32_ps(summed_pairs);
567
+ #else
568
+ // Perform multiplication and create 16-bit values
569
+ const __m256i dot = _mm256_maddubs_epi16(ax, sy);
570
+ return sum_i16_pairs_float(dot);
571
+ #endif
572
+ }
573
+
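mul_sum_i8_pairs_float relies on a sign trick: _mm256_maddubs_epi16 needs an unsigned first operand, so the code feeds it |x| and y with x's sign folded in, which leaves every product x[i]*y[i] unchanged. Each float of the result covers four adjacent products. A scalar model, ignoring the int16 saturation that maddubs can introduce for extreme int8 inputs (mul_sum_i8_pairs_float_ref is a hypothetical name):

    #include <stdint.h>

    // out[j] = sum over k of x[4j+k]*y[4j+k], as a float, for j = 0..7
    static void mul_sum_i8_pairs_float_ref(const int8_t x[32], const int8_t y[32], float out[8]) {
        for (int j = 0; j < 8; j++) {
            int32_t acc = 0;
            for (int k = 0; k < 4; k++) {
                acc += (int32_t)x[4*j + k] * (int32_t)y[4*j + k];
            }
            out[j] = (float)acc;
        }
    }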
510
574
  static inline __m128i packNibbles( __m256i bytes )
511
575
  {
512
576
  // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
577
+ #if __AVX512F__
578
+ const __m256i bytes_srli_4 = _mm256_srli_epi16(bytes, 4); // 0000_0000_abcd_0000
579
+ bytes = _mm256_or_si256(bytes, bytes_srli_4); // 0000_abcd_abcd_efgh
580
+ return _mm256_cvtepi16_epi8(bytes); // abcd_efgh
581
+ #else
513
582
  const __m256i lowByte = _mm256_set1_epi16( 0xFF );
514
583
  __m256i high = _mm256_andnot_si256( lowByte, bytes );
515
584
  __m256i low = _mm256_and_si256( lowByte, bytes );
@@ -520,6 +589,7 @@ static inline __m128i packNibbles( __m256i bytes )
520
589
  __m128i r0 = _mm256_castsi256_si128( bytes );
521
590
  __m128i r1 = _mm256_extracti128_si256( bytes, 1 );
522
591
  return _mm_packus_epi16( r0, r1 );
592
+ #endif
523
593
  }
524
594
  #else
525
595
  static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
@@ -605,19 +675,102 @@ float vmaxvq_f32(float32x4_t v) {
605
675
  }
606
676
 
607
677
  int8x8_t vzip1_s8(int8x8_t a, int8x8_t b) {
608
- return vget_low_s8(vcombine_s8(a, b));
678
+ int8x8_t res;
679
+
680
+ res[0] = a[0]; res[1] = b[0];
681
+ res[2] = a[1]; res[3] = b[1];
682
+ res[4] = a[2]; res[5] = b[2];
683
+ res[6] = a[3]; res[7] = b[3];
684
+
685
+ return res;
609
686
  }
610
687
 
611
688
  int8x8_t vzip2_s8(int8x8_t a, int8x8_t b) {
612
- return vget_high_s8(vcombine_s8(a, b));
689
+ int8x8_t res;
690
+
691
+ res[0] = a[4]; res[1] = b[4];
692
+ res[2] = a[5]; res[3] = b[5];
693
+ res[4] = a[6]; res[5] = b[6];
694
+ res[6] = a[7]; res[7] = b[7];
695
+
696
+ return res;
613
697
  }
614
698
 
615
699
  uint8x8_t vzip1_u8(uint8x8_t a, uint8x8_t b) {
616
- return vget_low_u8(vcombine_u8(a, b));
700
+ uint8x8_t res;
701
+
702
+ res[0] = a[0]; res[1] = b[0];
703
+ res[2] = a[1]; res[3] = b[1];
704
+ res[4] = a[2]; res[5] = b[2];
705
+ res[6] = a[3]; res[7] = b[3];
706
+
707
+ return res;
617
708
  }
618
709
 
619
710
  uint8x8_t vzip2_u8(uint8x8_t a, uint8x8_t b) {
620
- return vget_high_u8(vcombine_u8(a, b));
711
+ uint8x8_t res;
712
+
713
+ res[0] = a[4]; res[1] = b[4];
714
+ res[2] = a[5]; res[3] = b[5];
715
+ res[4] = a[6]; res[5] = b[6];
716
+ res[6] = a[7]; res[7] = b[7];
717
+
718
+ return res;
719
+ }
720
+
721
+ int8x16_t vzip1q_s8(int8x16_t a, int8x16_t b) {
722
+ int8x16_t res;
723
+
724
+ res[0] = a[0]; res[1] = b[0]; res[2] = a[1]; res[3] = b[1];
725
+ res[4] = a[2]; res[5] = b[2]; res[6] = a[3]; res[7] = b[3];
726
+ res[8] = a[4]; res[9] = b[4]; res[10] = a[5]; res[11] = b[5];
727
+ res[12] = a[6]; res[13] = b[6]; res[14] = a[7]; res[15] = b[7];
728
+
729
+ return res;
730
+ }
731
+
732
+ int8x16_t vzip2q_s8(int8x16_t a, int8x16_t b) {
733
+ int8x16_t res;
734
+
735
+ res[0] = a[8]; res[1] = b[8]; res[2] = a[9]; res[3] = b[9];
736
+ res[4] = a[10]; res[5] = b[10]; res[6] = a[11]; res[7] = b[11];
737
+ res[8] = a[12]; res[9] = b[12]; res[10] = a[13]; res[11] = b[13];
738
+ res[12] = a[14]; res[13] = b[14]; res[14] = a[15]; res[15] = b[15];
739
+
740
+ return res;
741
+ }
742
+
743
+ uint8x16_t vzip1q_u8(uint8x16_t a, uint8x16_t b) {
744
+ uint8x16_t res;
745
+
746
+ res[0] = a[0]; res[1] = b[0]; res[2] = a[1]; res[3] = b[1];
747
+ res[4] = a[2]; res[5] = b[2]; res[6] = a[3]; res[7] = b[3];
748
+ res[8] = a[4]; res[9] = b[4]; res[10] = a[5]; res[11] = b[5];
749
+ res[12] = a[6]; res[13] = b[6]; res[14] = a[7]; res[15] = b[7];
750
+
751
+ return res;
752
+ }
753
+
754
+ uint8x16_t vzip2q_u8(uint8x16_t a, uint8x16_t b) {
755
+ uint8x16_t res;
756
+
757
+ res[0] = a[8]; res[1] = b[8]; res[2] = a[9]; res[3] = b[9];
758
+ res[4] = a[10]; res[5] = b[10]; res[6] = a[11]; res[7] = b[11];
759
+ res[8] = a[12]; res[9] = b[12]; res[10] = a[13]; res[11] = b[13];
760
+ res[12] = a[14]; res[13] = b[14]; res[14] = a[15]; res[15] = b[15];
761
+
762
+ return res;
763
+ }
764
+
765
+ int32x4_t vcvtnq_s32_f32(float32x4_t v) {
766
+ int32x4_t res;
767
+
768
+ res[0] = roundf(vgetq_lane_f32(v, 0));
769
+ res[1] = roundf(vgetq_lane_f32(v, 1));
770
+ res[2] = roundf(vgetq_lane_f32(v, 2));
771
+ res[3] = roundf(vgetq_lane_f32(v, 3));
772
+
773
+ return res;
621
774
  }
622
775
 
623
776
  #endif
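These shims exist for ARM targets whose NEON support lacks the AArch64-only intrinsics: the old vzip1/vzip2 replacements simply returned one of the inputs unmodified (the low or high half of the concatenation), whereas the new versions perform the actual interleave, and vcvtnq_s32_f32 adds round-to-nearest float-to-int conversion. For reference, the interleave in scalar form (zip1_u8_ref/zip2_u8_ref are hypothetical names):

    #include <stdint.h>

    // vzip1: interleave the low halves of a and b; vzip2: interleave the high halves
    static void zip1_u8_ref(const uint8_t a[8], const uint8_t b[8], uint8_t out[8]) {
        for (int i = 0; i < 4; i++) { out[2*i] = a[i];     out[2*i + 1] = b[i];     }
    }

    static void zip2_u8_ref(const uint8_t a[8], const uint8_t b[8], uint8_t out[8]) {
        for (int i = 0; i < 4; i++) { out[2*i] = a[4 + i]; out[2*i + 1] = b[4 + i]; }
    }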
@@ -646,13 +799,22 @@ typedef struct {
646
799
  } block_q4_2;
647
800
  static_assert(sizeof(block_q4_2) == sizeof(ggml_fp16_t) + QK4_2 / 2, "wrong q4_2 block size/padding");
648
801
 
649
- #define QK4_3 16
802
+ #define QK5_0 32
803
+ typedef struct {
804
+ ggml_fp16_t d; // delta
805
+ uint8_t qh[4]; // 5-th bit of quants
806
+ uint8_t qs[QK5_0 / 2]; // nibbles / quants
807
+ } block_q5_0;
808
+ static_assert(sizeof(block_q5_0) == sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_0 / 2, "wrong q5_0 block size/padding");
809
+
810
+ #define QK5_1 32
650
811
  typedef struct {
651
812
  ggml_fp16_t d; // delta
652
813
  ggml_fp16_t m; // min
653
- uint8_t qs[QK4_3 / 2]; // nibbles / quants
654
- } block_q4_3;
655
- static_assert(sizeof(block_q4_3) == 2 * sizeof(ggml_fp16_t) + QK4_3 / 2, "wrong q4_3 block size/padding");
814
+ uint8_t qh[4]; // 5-th bit of quants
815
+ uint8_t qs[QK5_1 / 2]; // nibbles / quants
816
+ } block_q5_1;
817
+ static_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_1 / 2, "wrong q5_1 block size/padding");
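block_q5_0 and block_q5_1 replace the removed Q4_3 format: each of the 32 weights now gets 5 bits, with the low nibble stored in qs and the 5th bit of every weight collected into the 32-bit qh field. A minimal scalar sketch of that packing, mirroring the reference quantize/dequantize routines further down in this diff (pack_q5/unpack_q5 are hypothetical helpers):

    #include <stdint.h>

    // quant l (0..31) is a 5-bit value: low nibble -> qs[l/2], bit 4 -> bit l of qh
    static void pack_q5(const uint8_t q[32], uint8_t qs[16], uint8_t qh_out[4]) {
        uint32_t qh = 0;
        for (int l = 0; l < 32; l += 2) {
            qs[l/2] = (uint8_t)((q[l] & 0x0F) | ((q[l + 1] & 0x0F) << 4));
            qh |= (uint32_t)((q[l]     & 0x10) >> 4) << (l + 0);
            qh |= (uint32_t)((q[l + 1] & 0x10) >> 4) << (l + 1);
        }
        // byte-wise little-endian store; matches the memcpy(&y[i].qh, &qh, ...) on little-endian targets
        for (int b = 0; b < 4; b++) qh_out[b] = (uint8_t)(qh >> (8*b));
    }

    static uint8_t unpack_q5(const uint8_t qs[16], const uint8_t qh_in[4], int l) {
        const uint32_t qh = (uint32_t)qh_in[0] | ((uint32_t)qh_in[1] << 8)
                          | ((uint32_t)qh_in[2] << 16) | ((uint32_t)qh_in[3] << 24);
        const uint8_t lo = (l % 2 == 0) ? (uint8_t)(qs[l/2] & 0x0F) : (uint8_t)(qs[l/2] >> 4);
        const uint8_t hi = (uint8_t)(((qh >> l) & 1) << 4);
        return (uint8_t)(lo | hi); // q5_0 then subtracts 16; q5_1 scales by d and adds m
    }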
656
818
 
657
819
  #define QK8_0 32
658
820
  typedef struct {
@@ -661,6 +823,14 @@ typedef struct {
661
823
  } block_q8_0;
662
824
  static_assert(sizeof(block_q8_0) == sizeof(float) + QK8_0, "wrong q8_0 block size/padding");
663
825
 
826
+ #define QK8_1 32
827
+ typedef struct {
828
+ float d; // delta
829
+ float s0; // d * sum(qs[i]) low
830
+ float s1; // d * sum(qs[i]) high
831
+ int8_t qs[QK8_1]; // quants
832
+ } block_q8_1;
833
+ static_assert(sizeof(block_q8_1) == 3*sizeof(float) + QK8_1, "wrong q8_1 block size/padding");
664
834
 
665
835
  // reference implementation for deterministic creation of model files
666
836
  static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k) {
@@ -671,13 +841,17 @@ static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * r
671
841
 
672
842
  for (int i = 0; i < nb; i++) {
673
843
  float amax = 0.0f; // absolute max
844
+ float max = 0.0f;
674
845
 
675
846
  for (int l = 0; l < QK4_0; l++) {
676
847
  const float v = x[i*QK4_0 + l];
677
- amax = MAX(amax, fabsf(v));
848
+ if (amax < fabsf(v)) {
849
+ amax = fabsf(v);
850
+ max = v;
851
+ }
678
852
  }
679
853
 
680
- const float d = amax / ((1 << 3) - 1);
854
+ const float d = max / -8;
681
855
  const float id = d ? 1.0f/d : 0.0f;
682
856
 
683
857
  y[i].d = d;
@@ -686,8 +860,8 @@ static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * r
686
860
  const float v0 = x[i*QK4_0 + l + 0]*id;
687
861
  const float v1 = x[i*QK4_0 + l + 1]*id;
688
862
 
689
- const uint8_t vi0 = (int8_t)roundf(v0) + 8;
690
- const uint8_t vi1 = (int8_t)roundf(v1) + 8;
863
+ const uint8_t vi0 = MIN(15, (int8_t)roundf(v0) + 8);
864
+ const uint8_t vi1 = MIN(15, (int8_t)roundf(v1) + 8);
691
865
 
692
866
  assert(vi0 < 16);
693
867
  assert(vi1 < 16);
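The Q4_0 reference quantizer above changes scheme: instead of scaling by the absolute maximum over 7, it keeps the signed element with the largest magnitude and sets d = max / -8, so that element lands exactly on quant value -8 and the full [-8, 7] range of the nibble is used; the MIN(15, ...) clamp catches the opposite-sign extreme that would otherwise round to 16. A compact scalar sketch of one block under the new scheme (q4_0_block_sketch is a hypothetical name):

    #include <math.h>
    #include <stdint.h>

    static void q4_0_block_sketch(const float x[32], float * d_out, uint8_t qs[16]) {
        float amax = 0.0f, max = 0.0f;
        for (int l = 0; l < 32; l++) {
            if (amax < fabsf(x[l])) { amax = fabsf(x[l]); max = x[l]; }
        }
        const float d  = max / -8.0f;
        const float id = d ? 1.0f/d : 0.0f;
        *d_out = d;
        for (int l = 0; l < 32; l += 2) {
            const int v0 = (int)roundf(x[l + 0]*id) + 8; // in 0..16 before the clamp
            const int v1 = (int)roundf(x[l + 1]*id) + 8;
            const uint8_t q0 = (uint8_t)(v0 > 15 ? 15 : v0);
            const uint8_t q1 = (uint8_t)(v1 > 15 ? 15 : v1);
            qs[l/2] = (uint8_t)(q0 | (q1 << 4));
        }
        // dequantization mirrors dequantize_row_q4_0: x[l] ~ (nibble - 8) * d
    }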
@@ -707,28 +881,43 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
707
881
 
708
882
  #if defined(__POWER9_VECTOR__)
709
883
  const vector float v85 = vec_splats(8.5f);
884
+ const vector signed int v15 = vec_splats(15);
710
885
  for (int i = 0; i < nb; i++) {
711
- float amax = 0.0f; // absolute max
886
+ float max = 0.0f;
887
+ float min = 0.0f;
712
888
 
889
+ vector float asrcv [8];
713
890
  vector float srcv [8];
714
- vector float asrcv[8];
715
- vector float amaxv[8];
891
+ vector float maxv[8];
892
+ vector float minv[8];
716
893
 
717
894
  for (int l = 0; l < 8; l++) srcv[l] = *(vector float *)(x + i*32 + 4*l);
718
- for (int l = 0; l < 8; l++) asrcv[l] = vec_abs(srcv[l]);
719
-
720
- for (int l = 0; l < 4; l++) amaxv[2*l] = vec_max(asrcv[2*l], asrcv[2*l+1]);
721
- //for (int l = 0; l < 2; l++) amaxv[4*l] = vec_max(amaxv[4*l], amaxv[4*l+2]);
722
- amaxv[0] = vec_max(amaxv[0], amaxv[2]);
723
- amaxv[4] = vec_max(amaxv[4], amaxv[6]);
724
- //for (int l = 0; l < 1; l++) amaxv[8*l] = vec_max(amaxv[8*l], amaxv[8*l+4]);
725
- amaxv[0] = vec_max(amaxv[0], amaxv[4]);
726
-
727
- amax = MAX(
728
- MAX(vec_extract(amaxv[0], 0), vec_extract(amaxv[0], 1)),
729
- MAX(vec_extract(amaxv[0], 2), vec_extract(amaxv[0], 3)));
730
-
731
- const float d = amax / ((1 << 3) - 1);
895
+ //for (int l = 0; l < 8; l++) asrcv[l] = vec_abs(srcv[l]);
896
+
897
+ for (int l = 0; l < 4; l++) maxv[2*l] = vec_max(asrcv[2*l], asrcv[2*l+1]);
898
+ //for (int l = 0; l < 2; l++) maxv[4*l] = vec_max(maxv[4*l], maxv[4*l+2]);
899
+ maxv[0] = vec_max(maxv[0], maxv[2]);
900
+ maxv[4] = vec_max(maxv[4], maxv[6]);
901
+ //for (int l = 0; l < 1; l++) maxv[8*l] = vec_max(maxv[8*l], maxv[8*l+4]);
902
+ maxv[0] = vec_max(maxv[0], maxv[4]);
903
+
904
+ for (int l = 0; l < 4; l++) minv[2*l] = vec_min(asrcv[2*l], asrcv[2*l+1]);
905
+ //for (int l = 0; l < 2; l++) minv[4*l] = vec_min(minv[4*l], minv[4*l+2]);
906
+ minv[0] = vec_min(minv[0], minv[2]);
907
+ minv[4] = vec_min(minv[4], minv[6]);
908
+ //for (int l = 0; l < 1; l++) minv[8*l] = vec_min(minv[8*l], minv[8*l+4]);
909
+ minv[0] = vec_min(minv[0], minv[4]);
910
+
911
+
912
+ max = MAX(
913
+ MAX(vec_extract(maxv[0], 0), vec_extract(maxv[0], 1)),
914
+ MAX(vec_extract(maxv[0], 2), vec_extract(maxv[0], 3)));
915
+ min = MIN(
916
+ MIN(vec_extract(minv[0], 0), vec_extract(minv[0], 1)),
917
+ MIN(vec_extract(minv[0], 2), vec_extract(minv[0], 3)));
918
+
919
+ const float magnitude = max >= fabsf(min) ? max : min;
920
+ const float d = magnitude / -8;
732
921
  const float id = d ? 1.0/d : 0.0;
733
922
 
734
923
  y[i].d = d;
@@ -738,27 +927,33 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
738
927
  for (int l = 0; l < 8; l++) {
739
928
  const vector float vf = vec_madd(srcv[l], vid, v85);
740
929
  const vector signed int vi = vec_signed(vf);
930
+ const vector signed int vc = vec_min(vi, v15);
741
931
 
742
- pb[2*l + 0] = vec_extract(vi, 0) | (vec_extract(vi, 1) << 4);
743
- pb[2*l + 1] = vec_extract(vi, 2) | (vec_extract(vi, 3) << 4);
932
+ pb[2*l + 0] = vec_extract(vc, 0) | (vec_extract(vc, 1) << 4);
933
+ pb[2*l + 1] = vec_extract(vc, 2) | (vec_extract(vc, 3) << 4);
744
934
  }
745
935
  }
746
936
  #elif __ARM_NEON
747
937
  for (int i = 0; i < nb; i++) {
748
938
  float32x4_t srcv [8];
749
- float32x4_t asrcv[8];
750
- float32x4_t amaxv[8];
939
+ float32x4_t maxv[8];
940
+ float32x4_t minv[8];
751
941
 
752
942
  for (int l = 0; l < 8; l++) srcv[l] = vld1q_f32(x + i*32 + 4*l);
753
- for (int l = 0; l < 8; l++) asrcv[l] = vabsq_f32(srcv[l]);
754
943
 
755
- for (int l = 0; l < 4; l++) amaxv[2*l] = vmaxq_f32(asrcv[2*l], asrcv[2*l+1]);
756
- for (int l = 0; l < 2; l++) amaxv[4*l] = vmaxq_f32(amaxv[4*l], amaxv[4*l+2]);
757
- for (int l = 0; l < 1; l++) amaxv[8*l] = vmaxq_f32(amaxv[8*l], amaxv[8*l+4]);
944
+ for (int l = 0; l < 4; l++) maxv[2*l] = vmaxq_f32(srcv[2*l], srcv[2*l+1]);
945
+ for (int l = 0; l < 2; l++) maxv[4*l] = vmaxq_f32(maxv[4*l], maxv[4*l+2]);
946
+ for (int l = 0; l < 1; l++) maxv[8*l] = vmaxq_f32(maxv[8*l], maxv[8*l+4]);
758
947
 
759
- const float amax = vmaxvq_f32(amaxv[0]);
948
+ for (int l = 0; l < 4; l++) minv[2*l] = vminq_f32(srcv[2*l], srcv[2*l+1]);
949
+ for (int l = 0; l < 2; l++) minv[4*l] = vminq_f32(minv[4*l], minv[4*l+2]);
950
+ for (int l = 0; l < 1; l++) minv[8*l] = vminq_f32(minv[8*l], minv[8*l+4]);
951
+
952
+ const float max = vmaxvq_f32(maxv[0]);
953
+ const float min = vminvq_f32(minv[0]);
760
954
 
761
- const float d = amax / ((1 << 3) - 1);
955
+ const float magnitude = max >= fabsf(min) ? max : min;
956
+ const float d = magnitude / -8;
762
957
  const float id = d ? 1.0f/d : 0.0f;
763
958
 
764
959
  y[i].d = d;
@@ -767,9 +962,10 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
767
962
  const float32x4_t v = vmulq_n_f32(srcv[l], id);
768
963
  const float32x4_t vf = vaddq_f32(v, vdupq_n_f32(8.5f));
769
964
  const int32x4_t vi = vcvtq_s32_f32(vf);
965
+ const int32x4_t vc = vminq_s32(vi, vdupq_n_s32(15));
770
966
 
771
- y[i].qs[2*l + 0] = vgetq_lane_s32(vi, 0) | (vgetq_lane_s32(vi, 1) << 4);
772
- y[i].qs[2*l + 1] = vgetq_lane_s32(vi, 2) | (vgetq_lane_s32(vi, 3) << 4);
967
+ y[i].qs[2*l + 0] = vgetq_lane_s32(vc, 0) | (vgetq_lane_s32(vc, 1) << 4);
968
+ y[i].qs[2*l + 1] = vgetq_lane_s32(vc, 2) | (vgetq_lane_s32(vc, 3) << 4);
773
969
  }
774
970
  }
775
971
  #elif defined(__AVX2__)
@@ -781,22 +977,31 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
781
977
  __m256 v3 = _mm256_loadu_ps( x + 24 );
782
978
  x += 32;
783
979
 
784
- // Compute max(abs(e)) for the block
785
- const __m256 signBit = _mm256_set1_ps( -0.0f );
786
- __m256 maxAbs = _mm256_andnot_ps( signBit, v0 );
787
- maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) );
788
- maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );
789
- maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) );
980
+ // Compute max for the block
981
+ __m256 max = _mm256_max_ps( v0, v1 );
982
+ __m256 maxTmp = _mm256_max_ps( v2, v3 );
983
+ max = _mm256_max_ps( max, maxTmp );
790
984
 
791
- __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
985
+ __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( max, 1 ), _mm256_castps256_ps128( max ) );
792
986
  max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
793
987
  max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
794
988
  const float maxScalar = _mm_cvtss_f32( max4 );
795
989
 
990
+ // Compute min for the block
991
+ __m256 min = _mm256_min_ps( v0, v1 );
992
+ __m256 minTmp = _mm256_min_ps( v2, v3 );
993
+ min = _mm256_min_ps( min, minTmp );
994
+
995
+ __m128 min4 = _mm_min_ps( _mm256_extractf128_ps( min, 1 ), _mm256_castps256_ps128( min ) );
996
+ min4 = _mm_min_ps( min4, _mm_movehl_ps( min4, min4 ) );
997
+ min4 = _mm_min_ss( min4, _mm_movehdup_ps( min4 ) );
998
+ const float minScalar = _mm_cvtss_f32( min4 );
999
+
796
1000
  // Quantize these floats
797
- const float d = maxScalar / 7.0f;
1001
+ const float magnitude = maxScalar >= fabsf(minScalar) ? maxScalar : minScalar;
1002
+ const float d = magnitude / -8.0f;
798
1003
  y[i].d = d;
799
- const float id = ( maxScalar != 0.0f ) ? 7.0f / maxScalar : 0.0f;
1004
+ const float id = ( magnitude != 0.0f ) ? -8.0f / magnitude : 0.0f;
800
1005
  const __m256 mul = _mm256_set1_ps( id );
801
1006
 
802
1007
  // Apply the multiplier
@@ -829,9 +1034,11 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
829
1034
  const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
830
1035
  i0 = _mm256_permutevar8x32_epi32( i0, perm );
831
1036
 
832
- // Apply offset to translate the range from [ -7 .. +7 ] into [ +1 .. +15 ]
1037
+ // Apply offset and clamp to translate the range from [ -8 .. +8 ] into [ +0 .. +15 ]
833
1038
  const __m256i off = _mm256_set1_epi8( 8 );
834
1039
  i0 = _mm256_add_epi8( i0, off );
1040
+ const __m256i maxNibble = _mm256_set1_epi8( 15 );
1041
+ i0 = _mm256_min_epi8( i0, maxNibble );
835
1042
 
836
1043
  // Compress the vector into 4 bit/value, and store
837
1044
  __m128i res = packNibbles( i0 );
@@ -846,22 +1053,31 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
846
1053
  __m256 v3 = _mm256_loadu_ps( x + 24 );
847
1054
  x += 32;
848
1055
 
849
- // Compute max(abs(e)) for the block
850
- const __m256 signBit = _mm256_set1_ps( -0.0f );
851
- __m256 maxAbs = _mm256_andnot_ps( signBit, v0 );
852
- maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) );
853
- maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );
854
- maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) );
1056
+ // Compute max for the block
1057
+ __m256 max = _mm256_max_ps( v0, v1 );
1058
+ __m256 maxTmp = _mm256_max_ps( v2, v3 );
1059
+ max = _mm256_max_ps( max, maxTmp );
855
1060
 
856
- __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
1061
+ __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( max, 1 ), _mm256_castps256_ps128( max ) );
857
1062
  max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
858
1063
  max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
859
1064
  const float maxScalar = _mm_cvtss_f32( max4 );
860
1065
 
1066
+ // Compute min for the block
1067
+ __m256 min = _mm256_min_ps( v0, v1 );
1068
+ __m256 minTmp = _mm256_min_ps( v2, v3 );
1069
+ min = _mm256_min_ps( min, minTmp );
1070
+
1071
+ __m128 min4 = _mm_min_ps( _mm256_extractf128_ps( min, 1 ), _mm256_castps256_ps128( min ) );
1072
+ min4 = _mm_min_ps( min4, _mm_movehl_ps( min4, min4 ) );
1073
+ min4 = _mm_min_ss( min4, _mm_movehdup_ps( min4 ) );
1074
+ const float minScalar = _mm_cvtss_f32( min4 );
1075
+
861
1076
  // Quantize these floats
862
- const float d = maxScalar / 7.0f;
1077
+ const float magnitude = maxScalar >= fabsf(minScalar) ? maxScalar : minScalar;
1078
+ const float d = magnitude / -8.0f;
863
1079
  y[i].d = d;
864
- const float id = ( maxScalar != 0.0f ) ? 7.0f / maxScalar : 0.0f;
1080
+ const float id = ( magnitude != 0.0f ) ? -8.0f / magnitude : 0.0f;
865
1081
  const __m256 mul = _mm256_set1_ps( id );
866
1082
 
867
1083
  // Apply the multiplier
@@ -902,10 +1118,13 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
902
1118
  ni0 = _mm_packs_epi16( ni0, ni2 );
903
1119
  ni4 = _mm_packs_epi16( ni4, ni6 );
904
1120
 
905
- // Apply offset to translate the range from [ -7 .. +7 ] into [ +1 .. +15 ]
906
- const __m128i off = _mm_set1_epi8( 8);
1121
+ // Apply offset and clamp to translate the range from [ -8 .. +8 ] into [ +0 .. +15 ]
1122
+ const __m128i off = _mm_set1_epi8( 8 );
907
1123
  ni0 = _mm_add_epi8( ni0, off );
908
1124
  ni4 = _mm_add_epi8( ni4, off );
1125
+ const __m128i maxNibble = _mm_set1_epi8( 15 );
1126
+ ni0 = _mm_min_epi8( ni0, maxNibble );
1127
+ ni4 = _mm_min_epi8( ni4, maxNibble );
909
1128
 
910
1129
  // Compress the vector into 4 bit/value, and store
911
1130
  __m128i res = packNibbles( ni0, ni4 );
@@ -913,24 +1132,32 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
913
1132
  }
914
1133
  #elif defined(__wasm_simd128__)
915
1134
  for (int i = 0; i < nb; i++) {
916
- float amax = 0.0f; // absolute max
1135
+ float max = 0.0f;
1136
+ float min = 0.0f;
917
1137
 
918
1138
  v128_t srcv [8];
919
- v128_t asrcv[8];
920
- v128_t amaxv[8];
1139
+ v128_t maxv[8];
1140
+ v128_t minv[8];
921
1141
 
922
1142
  for (int l = 0; l < 8; l++) srcv[l] = wasm_v128_load(x + i*32 + 4*l);
923
- for (int l = 0; l < 8; l++) asrcv[l] = wasm_f32x4_abs(srcv[l]);
924
1143
 
925
- for (int l = 0; l < 4; l++) amaxv[2*l] = wasm_f32x4_max(asrcv[2*l], asrcv[2*l+1]);
926
- for (int l = 0; l < 2; l++) amaxv[4*l] = wasm_f32x4_max(amaxv[4*l], amaxv[4*l+2]);
927
- for (int l = 0; l < 1; l++) amaxv[8*l] = wasm_f32x4_max(amaxv[8*l], amaxv[8*l+4]);
1144
+ for (int l = 0; l < 4; l++) maxv[2*l] = wasm_f32x4_max(srcv[2*l], srcv[2*l+1]);
1145
+ for (int l = 0; l < 2; l++) maxv[4*l] = wasm_f32x4_max(maxv[4*l], maxv[4*l+2]);
1146
+ for (int l = 0; l < 1; l++) maxv[8*l] = wasm_f32x4_max(maxv[8*l], maxv[8*l+4]);
1147
+
1148
+ for (int l = 0; l < 4; l++) minv[2*l] = wasm_f32x4_min(srcv[2*l], srcv[2*l+1]);
1149
+ for (int l = 0; l < 2; l++) minv[4*l] = wasm_f32x4_min(minv[4*l], minv[4*l+2]);
1150
+ for (int l = 0; l < 1; l++) minv[8*l] = wasm_f32x4_min(minv[8*l], minv[8*l+4]);
928
1151
 
929
- amax = MAX(
930
- MAX(wasm_f32x4_extract_lane(amaxv[0], 0), wasm_f32x4_extract_lane(amaxv[0], 1)),
931
- MAX(wasm_f32x4_extract_lane(amaxv[0], 2), wasm_f32x4_extract_lane(amaxv[0], 3)));
1152
+ max = MAX(
1153
+ MAX(wasm_f32x4_extract_lane(maxv[0], 0), wasm_f32x4_extract_lane(maxv[0], 1)),
1154
+ MAX(wasm_f32x4_extract_lane(maxv[0], 2), wasm_f32x4_extract_lane(maxv[0], 3)));
1155
+ min = MIN(
1156
+ MIN(wasm_f32x4_extract_lane(minv[0], 0), wasm_f32x4_extract_lane(minv[0], 1)),
1157
+ MIN(wasm_f32x4_extract_lane(minv[0], 2), wasm_f32x4_extract_lane(minv[0], 3)));
932
1158
 
933
- const float d = amax / ((1 << 3) - 1);
1159
+ const float magnitude = max >= fabsf(min) ? max : min;
1160
+ const float d = magnitude / -8;
934
1161
  const float id = d ? 1.0/d : 0.0;
935
1162
 
936
1163
  y[i].d = d;
@@ -939,9 +1166,10 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
939
1166
  const v128_t v = wasm_f32x4_mul(srcv[l], wasm_f32x4_splat(id));
940
1167
  const v128_t vf = wasm_f32x4_add(v, wasm_f32x4_splat(8.5f));
941
1168
  const v128_t vi = wasm_i32x4_trunc_sat_f32x4(vf);
1169
+ const v128_t vc = wasm_i32x4_min(vi, wasm_i32x4_splat(15));
942
1170
 
943
- y[i].qs[2*l + 0] = wasm_i32x4_extract_lane(vi, 0) | (wasm_i32x4_extract_lane(vi, 1) << 4);
944
- y[i].qs[2*l + 1] = wasm_i32x4_extract_lane(vi, 2) | (wasm_i32x4_extract_lane(vi, 3) << 4);
1171
+ y[i].qs[2*l + 0] = wasm_i32x4_extract_lane(vc, 0) | (wasm_i32x4_extract_lane(vc, 1) << 4);
1172
+ y[i].qs[2*l + 1] = wasm_i32x4_extract_lane(vc, 2) | (wasm_i32x4_extract_lane(vc, 3) << 4);
945
1173
  }
946
1174
  }
947
1175
  #else
@@ -1122,13 +1350,17 @@ static void quantize_row_q4_2_reference(const float * restrict x, block_q4_2 * r
1122
1350
 
1123
1351
  for (int i = 0; i < nb; i++) {
1124
1352
  float amax = 0.0f; // absolute max
1353
+ float max = 0.0f;
1125
1354
 
1126
1355
  for (int l = 0; l < QK4_2; l++) {
1127
1356
  const float v = x[i*QK4_2 + l];
1128
- amax = MAX(amax, fabsf(v));
1357
+ if (amax < fabsf(v)) {
1358
+ amax = fabsf(v);
1359
+ max = v;
1360
+ }
1129
1361
  }
1130
1362
 
1131
- const float d = amax / ((1 << 3) - 1);
1363
+ const float d = max / -8;
1132
1364
 
1133
1365
  const float id = d ? 1.0f/d : 0.0f;
1134
1366
 
@@ -1138,8 +1370,8 @@ static void quantize_row_q4_2_reference(const float * restrict x, block_q4_2 * r
1138
1370
  const float v0 = x[i*QK4_2 + l + 0]*id;
1139
1371
  const float v1 = x[i*QK4_2 + l + 1]*id;
1140
1372
 
1141
- const uint8_t vi0 = (uint8_t)(v0 + 8.5f);
1142
- const uint8_t vi1 = (uint8_t)(v1 + 8.5f);
1373
+ const uint8_t vi0 = MIN(15, (uint8_t)(v0 + 8.5f));
1374
+ const uint8_t vi1 = MIN(15, (uint8_t)(v1 + 8.5f));
1143
1375
 
1144
1376
  assert(vi0 < 16);
1145
1377
  assert(vi1 < 16);
@@ -1149,136 +1381,109 @@ static void quantize_row_q4_2_reference(const float * restrict x, block_q4_2 * r
1149
1381
  }
1150
1382
  }
1151
1383
 
1152
- static inline int nearest_int(float fval) {
1153
- assert(fval <= 4194303.f);
1154
- float val = fval + 12582912.f;
1155
- int i; memcpy(&i, &val, sizeof(int));
1156
- return (i & 0x007fffff) - 0x00400000;
1157
- }
1158
-
1159
- static float kquantize_q4_with_bounds(int n, int nmin, int nmax, const float * restrict X, int nCandidates,
1160
- const float * restrict candidates, int8_t * restrict L) {
1161
- assert (nmin >= INT8_MIN);
1162
- assert (nmax <= INT8_MAX);
1163
- float amax = 0;
1164
- for (int i=0; i<n; ++i) amax = MAX(amax, fabsf(X[i]));
1165
- if (!amax) { // all zero
1166
- for (int i=0; i<n; ++i) L[i] = 0;
1167
- return 1.f;
1168
- }
1169
- float best = 0, bestScale = 0;
1170
- for (int si=0; si<nCandidates; ++si) {
1171
- float iscale = candidates[si]/amax;
1172
- float sumlxP = 0; int suml2P = 0;
1173
- float sumlxM = 0; int suml2M = 0;
1174
- for (int i=0; i<n; ++i) {
1175
- int l = nearest_int(iscale*X[i]);
1176
- int lp = MAX(nmin, MIN(nmax, +l));
1177
- int lm = MAX(nmin, MIN(nmax, -l));
1178
- sumlxP += X[i]*lp; suml2P += lp*lp;
1179
- sumlxM += X[i]*lm; suml2M += lm*lm;
1180
- }
1181
- float sumlxP2 = sumlxP*sumlxP;
1182
- float sumlxM2 = sumlxM*sumlxM;
1183
- if (sumlxP2*suml2M > sumlxM2*suml2P) {
1184
- if (sumlxP2 > best*suml2P) {
1185
- best = sumlxP2/suml2P; bestScale = iscale;
1186
- }
1187
- } else {
1188
- if (sumlxM2 > best*suml2M) {
1189
- best = sumlxM2/suml2M; bestScale = -iscale;
1384
+ static void quantize_row_q4_2(const float * restrict x, void * restrict vy, int k) {
1385
+ assert(k % QK4_2 == 0);
1386
+
1387
+ block_q4_2 * restrict y = vy;
1388
+
1389
+ quantize_row_q4_2_reference(x, y, k);
1390
+ }
1391
+
1392
+ static void quantize_row_q5_0_reference(const float * restrict x, block_q5_0 * restrict y, int k) {
1393
+ assert(k % QK5_0 == 0);
1394
+ const int nb = k / QK5_0;
1395
+
1396
+ for (int i = 0; i < nb; i++) {
1397
+ float amax = 0.0f; // absolute max
1398
+ float max = 0.0f;
1399
+
1400
+ for (int l = 0; l < QK5_0; l++) {
1401
+ const float v = x[i*QK5_0 + l];
1402
+ if (amax < fabsf(v)) {
1403
+ amax = fabsf(v);
1404
+ max = v;
1190
1405
  }
1191
1406
  }
1192
- }
1193
- float sumlx = 0; int suml2 = 0;
1194
- for (int i=0; i<n; ++i) {
1195
- int l = nearest_int(bestScale*X[i]);
1196
- l = MAX(nmin, MIN(nmax, l));
1197
- sumlx += X[i]*l; suml2 += l*l;
1198
- L[i] = l;
1199
- }
1200
- float scale = sumlx/suml2;
1201
- return scale;
1202
- }
1203
1407
 
1204
- static void quantize_row_q4_2_rmse(const float * restrict x, block_q4_2 * restrict y, int k) {
1205
- #define CANDIDATE_COUNT 8
1206
- static const float candidates[CANDIDATE_COUNT] = { +8.7f, +8.3f, +8.1f, +7.8f, +7.3f, +7.0f, +6.3f, +5.7f };
1207
- assert(k % QK4_2 == 0);
1408
+ const float d = max / -16;
1409
+ const float id = d ? 1.0f/d : 0.0f;
1208
1410
 
1209
- int8_t L[QK4_2];
1411
+ y[i].d = GGML_FP32_TO_FP16(d);
1210
1412
 
1211
- const int nb = k / QK4_2;
1413
+ uint32_t qh = 0;
1212
1414
 
1213
- for (int i = 0; i < nb; i++) {
1214
- float scale = kquantize_q4_with_bounds(QK4_2, -8, 7, x, CANDIDATE_COUNT, candidates, L);
1215
- y[i].d = GGML_FP32_TO_FP16(scale);
1415
+ for (int l = 0; l < QK5_0; l += 2) {
1416
+ const float v0 = x[i*QK5_0 + l + 0]*id;
1417
+ const float v1 = x[i*QK5_0 + l + 1]*id;
1216
1418
 
1217
- for (int l = 0; l < QK4_2; l += 2) {
1218
- const uint8_t vi0 = (uint8_t)(L[l+0] + 8);
1219
- const uint8_t vi1 = (uint8_t)(L[l+1] + 8);
1419
+ const uint32_t vi0 = MIN(31, (int) (v0 + 16.5f));
1420
+ const uint32_t vi1 = MIN(31, (int) (v1 + 16.5f));
1220
1421
 
1221
- assert(vi0 < 16);
1222
- assert(vi1 < 16);
1422
+ y[i].qs[l/2] = (vi0 & 0x0F) | ((vi1 & 0x0F) << 4);
1223
1423
 
1224
- y[i].qs[l/2] = vi0 | (vi1 << 4);
1424
+ // get the 5-th bit and store it in qh at the right position
1425
+ qh |= ((vi0 & 0x10) >> 4) << (l + 0);
1426
+ qh |= ((vi1 & 0x10) >> 4) << (l + 1);
1225
1427
  }
1226
1428
 
1227
- x += QK4_2;
1429
+ memcpy(&y[i].qh, &qh, sizeof(y[i].qh));
1228
1430
  }
1229
1431
  }
1230
1432
 
1231
- static void quantize_row_q4_2(const float * restrict x, void * restrict vy, int k) {
1232
- assert(k % QK4_2 == 0);
1433
+ static void quantize_row_q5_0(const float * restrict x, void * restrict vy, int k) {
1434
+ assert(k % QK5_0 == 0);
1233
1435
 
1234
- block_q4_2 * restrict y = vy;
1436
+ block_q5_0 * restrict y = vy;
1235
1437
 
1236
- //quantize_row_q4_2_reference(x, y, k);
1237
- // This produces the exact same format, just better match to the input floats ("better" as measured by RMSE)
1238
- quantize_row_q4_2_rmse(x, y, k);
1438
+ quantize_row_q5_0_reference(x, y, k);
1239
1439
  }
1240
1440
 
1241
- static void quantize_row_q4_3_reference(const float * restrict x, block_q4_3 * restrict y, int k) {
1242
- assert(k % QK4_3 == 0);
1243
- const int nb = k / QK4_3;
1441
+ static void quantize_row_q5_1_reference(const float * restrict x, block_q5_1 * restrict y, int k) {
1442
+ assert(k % QK5_1 == 0);
1443
+ const int nb = k / QK5_1;
1244
1444
 
1245
1445
  for (int i = 0; i < nb; i++) {
1246
1446
  float min = FLT_MAX;
1247
1447
  float max = -FLT_MAX;
1248
1448
 
1249
- for (int l = 0; l < QK4_3; l++) {
1250
- const float v = x[i*QK4_3 + l];
1449
+ for (int l = 0; l < QK5_1; l++) {
1450
+ const float v = x[i*QK5_1 + l];
1251
1451
  if (v < min) min = v;
1252
1452
  if (v > max) max = v;
1253
1453
  }
1254
1454
 
1255
- const float d = (max - min) / ((1 << 4) - 1);
1455
+ const float d = (max - min) / ((1 << 5) - 1);
1256
1456
  const float id = d ? 1.0f/d : 0.0f;
1257
1457
 
1258
1458
  y[i].d = GGML_FP32_TO_FP16(d);
1259
1459
  y[i].m = GGML_FP32_TO_FP16(min);
1260
1460
 
1261
- for (int l = 0; l < QK4_3; l += 2) {
1262
- const float v0 = (x[i*QK4_3 + l + 0] - min)*id;
1263
- const float v1 = (x[i*QK4_3 + l + 1] - min)*id;
1461
+ uint32_t qh = 0;
1264
1462
 
1265
- const uint8_t vi0 = (int) (v0 + 0.5f);
1266
- const uint8_t vi1 = (int) (v1 + 0.5f);
1463
+ for (int l = 0; l < QK5_1; l += 2) {
1464
+ const float v0 = (x[i*QK5_1 + l + 0] - min)*id;
1465
+ const float v1 = (x[i*QK5_1 + l + 1] - min)*id;
1267
1466
 
1268
- assert(vi0 < 16);
1269
- assert(vi1 < 16);
1467
+ const uint32_t vi0 = (int) (v0 + 0.5f);
1468
+ const uint32_t vi1 = (int) (v1 + 0.5f);
1270
1469
 
1271
- y[i].qs[l/2] = vi0 | (vi1 << 4);
1470
+ y[i].qs[l/2] = (vi0 & 0x0F) | ((vi1 & 0x0F) << 4);
1471
+
1472
+ // get the 5-th bit and store it in qh at the right position
1473
+ qh |= ((vi0 & 0x10) >> 4) << (l + 0);
1474
+ qh |= ((vi1 & 0x10) >> 4) << (l + 1);
1272
1475
  }
1476
+
1477
+ memcpy(&y[i].qh, &qh, sizeof(y[i].qh));
1273
1478
  }
1274
1479
  }
1275
1480
 
1276
- static void quantize_row_q4_3(const float * restrict x, void * restrict vy, int k) {
1277
- assert(k % QK4_3 == 0);
1481
+ static void quantize_row_q5_1(const float * restrict x, void * restrict vy, int k) {
1482
+ assert(k % QK5_1 == 0);
1278
1483
 
1279
- block_q4_3 * restrict y = vy;
1484
+ block_q5_1 * restrict y = vy;
1280
1485
 
1281
- quantize_row_q4_3_reference(x, y, k);
1486
+ quantize_row_q5_1_reference(x, y, k);
1282
1487
  }
1283
1488
 
1284
1489
  // reference implementation for deterministic creation of model files
@@ -1300,13 +1505,15 @@ static void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * r
1300
1505
  y[i].d = d;
1301
1506
 
1302
1507
  for (int l = 0; l < QK8_0; ++l) {
1303
- const float v = x[i*QK8_0 + l]*id;
1304
- y[i].qs[l] = roundf(v);
1508
+ const float v0 = x[i*QK8_0 + l]*id;
1509
+
1510
+ y[i].qs[l] = roundf(v0);
1305
1511
  }
1306
1512
  }
1307
1513
  }
1308
1514
 
1309
1515
  static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int k) {
1516
+ assert(QK8_0 == 32);
1310
1517
  assert(k % QK8_0 == 0);
1311
1518
  const int nb = k / QK8_0;
1312
1519
 
@@ -1432,95 +1639,295 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
1432
1639
  #endif
1433
1640
  }
1434
1641
 
1435
- static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, int k) {
1436
- assert(k % QK4_0 == 0);
1437
- const int nb = k / QK4_0;
1438
-
1439
- const block_q4_0 * restrict x = vx;
1642
+ // reference implementation for deterministic creation of model files
1643
+ static void quantize_row_q8_1_reference(const float * restrict x, block_q8_1 * restrict y, int k) {
1644
+ assert(QK8_1 == 32);
1645
+ assert(k % QK8_1 == 0);
1646
+ const int nb = k / QK8_1;
1440
1647
 
1441
- #if defined(__AVX2__)
1442
1648
  for (int i = 0; i < nb; i++) {
1443
- // scale factor
1444
- const __m256 d_v = _mm256_broadcast_ss(&x[i].d);
1649
+ float amax = 0.0f; // absolute max
1445
1650
 
1446
- const uint8_t * restrict pp = x[i].qs;
1651
+ for (int l = 0; l < QK8_1; l++) {
1652
+ const float v = x[i*QK8_1 + l];
1653
+ amax = MAX(amax, fabsf(v));
1654
+ }
1447
1655
 
1448
- for (int l = 0; l < QK4_0; l += 32) {
1449
- // Load 32x4-bit integers into 32x8-bit integers
1450
- __m256i vx8 = bytes_from_nibbles_32(pp+l/2);
1656
+ const float d = amax / ((1 << 7) - 1);
1657
+ const float id = d ? 1.0f/d : 0.0f;
1451
1658
 
1452
- // Subtract 8 from the integers
1453
- vx8 = _mm256_sub_epi8(vx8, _mm256_set1_epi8(8));
1659
+ y[i].d = d;
1454
1660
 
1455
- // Convert to 16-bit int
1456
- const __m256i vx16_lo = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vx8, 0));
1457
- const __m256i vx16_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vx8, 1));
1661
+ int sum0 = 0;
1662
+ int sum1 = 0;
1458
1663
 
1459
- // Convert to 32-bit int -> float 32
1460
- const __m256 vf[4] = {
1461
- _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_lo, 0))),
1462
- _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_lo, 1))),
1463
- _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_hi, 0))),
1464
- _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_hi, 1)))
1465
- };
1664
+ for (int l = 0; l < QK8_1/2; ++l) {
1665
+ const float v0 = x[i*QK8_1 + l]*id;
1666
+ const float v1 = x[i*QK8_1 + QK8_1/2 + l]*id;
1466
1667
 
1467
- // Scale and store
1468
- for (int j = 0; j < 4; j++) {
1469
- const __m256 result = _mm256_mul_ps(vf[j], d_v);
1470
- _mm256_storeu_ps(y + i * QK4_0 + l + j*8, result);
1471
- }
1668
+ y[i].qs[ l] = roundf(v0);
1669
+ y[i].qs[QK8_1/2 + l] = roundf(v1);
1670
+
1671
+ sum0 += y[i].qs[ l];
1672
+ sum1 += y[i].qs[QK8_1/2 + l];
1472
1673
  }
1674
+
1675
+ y[i].s0 = d * sum0;
1676
+ y[i].s1 = d * sum1;
1473
1677
  }
1474
- #elif defined(__ARM_NEON)
1475
- for (int i = 0; i < nb; i++) {
1476
- const float32x4_t vd = vdupq_n_f32(x[i].d);
1678
+ }
1477
1679
 
1478
- const uint8_t * restrict pp = x[i].qs;
1680
+ static void quantize_row_q8_1(const float * restrict x, void * restrict vy, int k) {
1681
+ assert(k % QK8_1 == 0);
1682
+ const int nb = k / QK8_1;
1479
1683
 
1480
- for (int l = 0; l < QK4_0; l += 16) {
1481
- // Load 16x4-bit integers into 8x8-bit integers
1482
- const uint8x8_t v8 = vld1_u8(pp + l/2);
1684
+ block_q8_1 * restrict y = vy;
1483
1685
 
1484
- // Expand 4-bit qs to 8-bit bytes
1485
- const uint8x8_t v0 = vand_u8(v8, vdup_n_u8(0x0f));
1486
- const uint8x8_t v1 = vshr_n_u8(v8, 4);
1686
+ #if defined(__ARM_NEON)
1687
+ for (int i = 0; i < nb; i++) {
1688
+ float32x4_t srcv [8];
1689
+ float32x4_t asrcv[8];
1690
+ float32x4_t amaxv[8];
1487
1691
 
1488
- // Convert to signed 8-bit integers
1489
- const int8x8_t vs_0 = vreinterpret_s8_u8(v0);
1490
- const int8x8_t vs_1 = vreinterpret_s8_u8(v1);
1692
+ for (int l = 0; l < 8; l++) srcv[l] = vld1q_f32(x + i*32 + 4*l);
1693
+ for (int l = 0; l < 8; l++) asrcv[l] = vabsq_f32(srcv[l]);
1491
1694
 
1492
- // Subtract 8 from each byte
1493
- const int8x8_t vb_0 = vsub_s8(vs_0, vdup_n_s8(8));
1494
- const int8x8_t vb_1 = vsub_s8(vs_1, vdup_n_s8(8));
1695
+ for (int l = 0; l < 4; l++) amaxv[2*l] = vmaxq_f32(asrcv[2*l], asrcv[2*l+1]);
1696
+ for (int l = 0; l < 2; l++) amaxv[4*l] = vmaxq_f32(amaxv[4*l], amaxv[4*l+2]);
1697
+ for (int l = 0; l < 1; l++) amaxv[8*l] = vmaxq_f32(amaxv[8*l], amaxv[8*l+4]);
1495
1698
 
1496
- // Interleave and combine
1497
- const int8x8_t vx_0 = vzip1_s8(vb_0, vb_1);
1498
- const int8x8_t vx_1 = vzip2_s8(vb_0, vb_1);
1699
+ const float amax = vmaxvq_f32(amaxv[0]);
1499
1700
 
1500
- const int8x16_t vq = vcombine_s8(vx_0, vx_1);
1701
+ const float d = amax / ((1 << 7) - 1);
1702
+ const float id = d ? 1.0f/d : 0.0f;
1501
1703
 
1502
- // convert to 2x int16x8_t
1503
- const int16x8_t vi_0 = vmovl_s8(vget_low_s8 (vq));
1504
- const int16x8_t vi_1 = vmovl_s8(vget_high_s8(vq));
1704
+ y[i].d = d;
1505
1705
 
1506
- // convert to 4x float32x4_t
1507
- const float32x4_t vf_0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16 (vi_0)));
1508
- const float32x4_t vf_1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vi_0)));
1509
- const float32x4_t vf_2 = vcvtq_f32_s32(vmovl_s16(vget_low_s16 (vi_1)));
1510
- const float32x4_t vf_3 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vi_1)));
1706
+ int32x4_t accv0 = vdupq_n_s32(0);
1707
+ int32x4_t accv1 = vdupq_n_s32(0);
1511
1708
 
1512
- // Multiply by d
1513
- const float32x4_t r0 = vmulq_f32(vf_0, vd);
1514
- const float32x4_t r1 = vmulq_f32(vf_1, vd);
1515
- const float32x4_t r2 = vmulq_f32(vf_2, vd);
1516
- const float32x4_t r3 = vmulq_f32(vf_3, vd);
1709
+ // low half
1710
+ for (int l = 0; l < 4; l++) {
1711
+ const float32x4_t v = vmulq_n_f32(srcv[l], id);
1712
+ const int32x4_t vi = vcvtnq_s32_f32(v);
1517
1713
 
1518
- // Store
1519
- vst1q_f32(y + i*QK4_0 + l + 0, r0);
1520
- vst1q_f32(y + i*QK4_0 + l + 4, r1);
1521
- vst1q_f32(y + i*QK4_0 + l + 8, r2);
1522
- vst1q_f32(y + i*QK4_0 + l + 12, r3);
1523
- }
1714
+ y[i].qs[4*l + 0] = vgetq_lane_s32(vi, 0);
1715
+ y[i].qs[4*l + 1] = vgetq_lane_s32(vi, 1);
1716
+ y[i].qs[4*l + 2] = vgetq_lane_s32(vi, 2);
1717
+ y[i].qs[4*l + 3] = vgetq_lane_s32(vi, 3);
1718
+
1719
+ accv0 = vaddq_s32(accv0, vi);
1720
+ }
1721
+
1722
+ // high half
1723
+ for (int l = 4; l < 8; l++) {
1724
+ const float32x4_t v = vmulq_n_f32(srcv[l], id);
1725
+ const int32x4_t vi = vcvtnq_s32_f32(v);
1726
+
1727
+ y[i].qs[4*l + 0] = vgetq_lane_s32(vi, 0);
1728
+ y[i].qs[4*l + 1] = vgetq_lane_s32(vi, 1);
1729
+ y[i].qs[4*l + 2] = vgetq_lane_s32(vi, 2);
1730
+ y[i].qs[4*l + 3] = vgetq_lane_s32(vi, 3);
1731
+
1732
+ accv1 = vaddq_s32(accv1, vi);
1733
+ }
1734
+
1735
+ const int32_t sum0 = vaddvq_s32(accv0);
1736
+ const int32_t sum1 = vaddvq_s32(accv1);
1737
+
1738
+ y[i].s0 = d * sum0;
1739
+ y[i].s1 = d * sum1;
1740
+ }
1741
+ #elif defined(__AVX2__) || defined(__AVX__)
1742
+ for (int i = 0; i < nb; i++) {
1743
+ // Load elements into 4 AVX vectors
1744
+ __m256 v0 = _mm256_loadu_ps( x );
1745
+ __m256 v1 = _mm256_loadu_ps( x + 8 );
1746
+ __m256 v2 = _mm256_loadu_ps( x + 16 );
1747
+ __m256 v3 = _mm256_loadu_ps( x + 24 );
1748
+ x += 32;
1749
+
1750
+ // Compute max(abs(e)) for the block
1751
+ const __m256 signBit = _mm256_set1_ps( -0.0f );
1752
+ __m256 maxAbs = _mm256_andnot_ps( signBit, v0 );
1753
+ maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) );
1754
+ maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );
1755
+ maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) );
1756
+
1757
+ __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
1758
+ max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
1759
+ max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
1760
+ const float maxScalar = _mm_cvtss_f32( max4 );
1761
+
1762
+ // Quantize these floats
1763
+ const float d = maxScalar / 127.f;
1764
+ y[i].d = d;
1765
+ const float id = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f;
1766
+ const __m256 mul = _mm256_set1_ps( id );
1767
+
1768
+ // Apply the multiplier
1769
+ v0 = _mm256_mul_ps( v0, mul );
1770
+ v1 = _mm256_mul_ps( v1, mul );
1771
+ v2 = _mm256_mul_ps( v2, mul );
1772
+ v3 = _mm256_mul_ps( v3, mul );
1773
+
1774
+ // Round to nearest integer
1775
+ v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST );
1776
+ v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST );
1777
+ v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST );
1778
+ v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST );
1779
+
1780
+ // Convert floats to integers
1781
+ __m256i i0 = _mm256_cvtps_epi32( v0 );
1782
+ __m256i i1 = _mm256_cvtps_epi32( v1 );
1783
+ __m256i i2 = _mm256_cvtps_epi32( v2 );
1784
+ __m256i i3 = _mm256_cvtps_epi32( v3 );
1785
+
1786
+ #if defined(__AVX2__)
1787
+ // Compute the sum of the quants and set y[i].s
1788
+ //y[i].s = d * hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3)));
1789
+ y[i].s0 = d * hsum_i32_8(_mm256_add_epi32(i0, i1));
1790
+ y[i].s1 = d * hsum_i32_8(_mm256_add_epi32(i2, i3));
1791
+
1792
+ // Convert int32 to int16
1793
+ i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15
1794
+ i2 = _mm256_packs_epi32( i2, i3 ); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31
1795
+ // Convert int16 to int8
1796
+ i0 = _mm256_packs_epi16( i0, i2 ); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31
1797
+
1798
+ // We got our precious signed bytes, but the order is now wrong
1799
+ // These AVX2 pack instructions process 16-byte pieces independently
1800
+ // The following instruction is fixing the order
1801
+ const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
1802
+ i0 = _mm256_permutevar8x32_epi32( i0, perm );
1803
+
1804
+ _mm256_storeu_si256((__m256i *)y[i].qs, i0);
1805
+ #else
1806
+ // Since we don't have in AVX some necessary functions,
1807
+ // we split the registers in half and call AVX2 analogs from SSE
1808
+ __m128i ni0 = _mm256_castsi256_si128( i0 );
1809
+ __m128i ni1 = _mm256_extractf128_si256( i0, 1);
1810
+ __m128i ni2 = _mm256_castsi256_si128( i1 );
1811
+ __m128i ni3 = _mm256_extractf128_si256( i1, 1);
1812
+ __m128i ni4 = _mm256_castsi256_si128( i2 );
1813
+ __m128i ni5 = _mm256_extractf128_si256( i2, 1);
1814
+ __m128i ni6 = _mm256_castsi256_si128( i3 );
1815
+ __m128i ni7 = _mm256_extractf128_si256( i3, 1);
1816
+
1817
+ // Compute the sum of the quants and set y[i].s
1818
+ const __m128i s0 = _mm_add_epi32(_mm_add_epi32(ni0, ni1), _mm_add_epi32(ni2, ni3));
1819
+ const __m128i s1 = _mm_add_epi32(_mm_add_epi32(ni4, ni5), _mm_add_epi32(ni6, ni7));
1820
+ y[i].s0 = d * hsum_i32_4(s0);
1821
+ y[i].s1 = d * hsum_i32_4(s1);
1822
+
1823
+ // Convert int32 to int16
1824
+ ni0 = _mm_packs_epi32( ni0, ni1 );
1825
+ ni2 = _mm_packs_epi32( ni2, ni3 );
1826
+ ni4 = _mm_packs_epi32( ni4, ni5 );
1827
+ ni6 = _mm_packs_epi32( ni6, ni7 );
1828
+ // Convert int16 to int8
1829
+ ni0 = _mm_packs_epi16( ni0, ni2 );
1830
+ ni4 = _mm_packs_epi16( ni4, ni6 );
1831
+
1832
+ _mm_storeu_si128((__m128i *)(y[i].qs + 0), ni0);
1833
+ _mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4);
1834
+ #endif
1835
+ }
1836
+ #else
1837
+ // scalar
1838
+ quantize_row_q8_1_reference(x, y, k);
1839
+ #endif
1840
+ }
1841
+
1842
+ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, int k) {
1843
+ assert(k % QK4_0 == 0);
1844
+ const int nb = k / QK4_0;
1845
+
1846
+ const block_q4_0 * restrict x = vx;
1847
+
1848
+ #if defined(__AVX2__)
1849
+ for (int i = 0; i < nb; i++) {
1850
+ // scale factor
1851
+ const __m256 d_v = _mm256_broadcast_ss(&x[i].d);
1852
+
1853
+ const uint8_t * restrict pp = x[i].qs;
1854
+
1855
+ for (int l = 0; l < QK4_0; l += 32) {
1856
+ // Load 32x4-bit integers into 32x8-bit integers
1857
+ __m256i vx8 = bytes_from_nibbles_32(pp+l/2);
1858
+
1859
+ // Subtract 8 from the integers
1860
+ vx8 = _mm256_sub_epi8(vx8, _mm256_set1_epi8(8));
1861
+
1862
+ // Convert to 16-bit int
1863
+ const __m256i vx16_lo = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vx8, 0));
1864
+ const __m256i vx16_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vx8, 1));
1865
+
1866
+ // Convert to 32-bit int -> float 32
1867
+ const __m256 vf[4] = {
1868
+ _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_lo, 0))),
1869
+ _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_lo, 1))),
1870
+ _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_hi, 0))),
1871
+ _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_hi, 1)))
1872
+ };
1873
+
1874
+ // Scale and store
1875
+ for (int j = 0; j < 4; j++) {
1876
+ const __m256 result = _mm256_mul_ps(vf[j], d_v);
1877
+ _mm256_storeu_ps(y + i * QK4_0 + l + j*8, result);
1878
+ }
1879
+ }
1880
+ }
1881
+ #elif defined(__ARM_NEON)
1882
+ for (int i = 0; i < nb; i++) {
1883
+ const float32x4_t vd = vdupq_n_f32(x[i].d);
1884
+
1885
+ const uint8_t * restrict pp = x[i].qs;
1886
+
1887
+ for (int l = 0; l < QK4_0; l += 16) {
1888
+ // Load 16x4-bit integers into 8x8-bit integers
1889
+ const uint8x8_t v8 = vld1_u8(pp + l/2);
1890
+
1891
+ // Expand 4-bit qs to 8-bit bytes
1892
+ const uint8x8_t v0 = vand_u8(v8, vdup_n_u8(0x0F));
1893
+ const uint8x8_t v1 = vshr_n_u8(v8, 4);
1894
+
1895
+ // Convert to signed 8-bit integers
1896
+ const int8x8_t vs_0 = vreinterpret_s8_u8(v0);
1897
+ const int8x8_t vs_1 = vreinterpret_s8_u8(v1);
1898
+
1899
+ // Subtract 8 from each byte
1900
+ const int8x8_t vb_0 = vsub_s8(vs_0, vdup_n_s8(8));
1901
+ const int8x8_t vb_1 = vsub_s8(vs_1, vdup_n_s8(8));
1902
+
1903
+ // Interleave and combine
1904
+ const int8x8_t vx_0 = vzip1_s8(vb_0, vb_1);
1905
+ const int8x8_t vx_1 = vzip2_s8(vb_0, vb_1);
1906
+
1907
+ const int8x16_t vq = vcombine_s8(vx_0, vx_1);
1908
+
1909
+ // convert to 2x int16x8_t
1910
+ const int16x8_t vi_0 = vmovl_s8(vget_low_s8 (vq));
1911
+ const int16x8_t vi_1 = vmovl_s8(vget_high_s8(vq));
1912
+
1913
+ // convert to 4x float32x4_t
1914
+ const float32x4_t vf_0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16 (vi_0)));
1915
+ const float32x4_t vf_1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vi_0)));
1916
+ const float32x4_t vf_2 = vcvtq_f32_s32(vmovl_s16(vget_low_s16 (vi_1)));
1917
+ const float32x4_t vf_3 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vi_1)));
1918
+
1919
+ // Multiply by d
1920
+ const float32x4_t r0 = vmulq_f32(vf_0, vd);
1921
+ const float32x4_t r1 = vmulq_f32(vf_1, vd);
1922
+ const float32x4_t r2 = vmulq_f32(vf_2, vd);
1923
+ const float32x4_t r3 = vmulq_f32(vf_3, vd);
1924
+
1925
+ // Store
1926
+ vst1q_f32(y + i*QK4_0 + l + 0, r0);
1927
+ vst1q_f32(y + i*QK4_0 + l + 4, r1);
1928
+ vst1q_f32(y + i*QK4_0 + l + 8, r2);
1929
+ vst1q_f32(y + i*QK4_0 + l + 12, r3);
1930
+ }
1524
1931
  }
1525
1932
  #else
1526
1933
  // scalar
@@ -1532,7 +1939,7 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in
1532
1939
  for (int l = 0; l < QK4_0; l += 2) {
1533
1940
  const uint8_t vi = pp[l/2];
1534
1941
 
1535
- const int8_t vi0 = vi & 0xf;
1942
+ const int8_t vi0 = vi & 0x0F;
1536
1943
  const int8_t vi1 = vi >> 4;
1537
1944
 
1538
1945
  const float v0 = (vi0 - 8)*d;
@@ -1598,7 +2005,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
1598
2005
  const uint8x8_t v8 = vld1_u8(pp + l/2);
1599
2006
 
1600
2007
  // Expand 4-bit qs to 8-bit bytes
1601
- const uint8x8_t v0 = vand_u8(v8, vdup_n_u8(0x0f));
2008
+ const uint8x8_t v0 = vand_u8(v8, vdup_n_u8(0x0F));
1602
2009
  const uint8x8_t v1 = vshr_n_u8(v8, 4);
1603
2010
 
1604
2011
  // Interleave and combine
@@ -1640,7 +2047,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
1640
2047
  for (int l = 0; l < QK4_1; l += 2) {
1641
2048
  const uint8_t vi = pp[l/2];
1642
2049
 
1643
- const int8_t vi0 = vi & 0xf;
2050
+ const int8_t vi0 = vi & 0x0F;
1644
2051
  const int8_t vi1 = vi >> 4;
1645
2052
 
1646
2053
  const float v0 = vi0*d + m;
@@ -1670,7 +2077,7 @@ static void dequantize_row_q4_2(const void * restrict vx, float * restrict y, in
1670
2077
  for (int l = 0; l < QK4_2; l += 2) {
1671
2078
  const uint8_t vi = pp[l/2];
1672
2079
 
1673
- const int8_t vi0 = vi & 0xf;
2080
+ const int8_t vi0 = vi & 0x0F;
1674
2081
  const int8_t vi1 = vi >> 4;
1675
2082
 
1676
2083
  const float v0 = (vi0 - 8)*d;
@@ -1685,11 +2092,47 @@ static void dequantize_row_q4_2(const void * restrict vx, float * restrict y, in
1685
2092
  }
1686
2093
  }
1687
2094
 
1688
- static void dequantize_row_q4_3(const void * restrict vx, float * restrict y, int k) {
1689
- assert(k % QK4_3 == 0);
1690
- const int nb = k / QK4_3;
2095
+ static void dequantize_row_q5_0(const void * restrict vx, float * restrict y, int k) {
2096
+ assert(k % QK5_0 == 0);
2097
+ const int nb = k / QK5_0;
2098
+
2099
+ const block_q5_0 * restrict x = vx;
2100
+
2101
+ for (int i = 0; i < nb; i++) {
2102
+ const float d = GGML_FP16_TO_FP32(x[i].d);
2103
+
2104
+ const uint8_t * restrict pp = x[i].qs;
2105
+
2106
+ uint32_t qh;
2107
+ memcpy(&qh, x[i].qh, sizeof(qh));
2108
+
2109
+ for (int l = 0; l < QK5_0; l += 2) {
2110
+ const uint8_t vi = pp[l/2];
2111
+
2112
+ // extract the 5-th bit from qh
2113
+ const uint8_t vh0 = ((qh & (1u << (l + 0))) >> (l + 0)) << 4;
2114
+ const uint8_t vh1 = ((qh & (1u << (l + 1))) >> (l + 1)) << 4;
2115
+
2116
+ const int8_t vi0 = (vi & 0x0F) | vh0;
2117
+ const int8_t vi1 = (vi >> 4) | vh1;
2118
+
2119
+ const float v0 = (vi0 - 16)*d;
2120
+ const float v1 = (vi1 - 16)*d;
2121
+
2122
+ y[i*QK5_0 + l + 0] = v0;
2123
+ y[i*QK5_0 + l + 1] = v1;
2124
+
2125
+ assert(!isnan(y[i*QK5_0 + l + 0]));
2126
+ assert(!isnan(y[i*QK5_0 + l + 1]));
2127
+ }
2128
+ }
2129
+ }
2130
+
2131
+ static void dequantize_row_q5_1(const void * restrict vx, float * restrict y, int k) {
2132
+ assert(k % QK5_1 == 0);
2133
+ const int nb = k / QK5_1;
1691
2134
 
1692
- const block_q4_3 * restrict x = vx;
2135
+ const block_q5_1 * restrict x = vx;
1693
2136
 
1694
2137
  for (int i = 0; i < nb; i++) {
1695
2138
  const float d = GGML_FP16_TO_FP32(x[i].d);
@@ -1697,28 +2140,54 @@ static void dequantize_row_q4_3(const void * restrict vx, float * restrict y, in
1697
2140
 
1698
2141
  const uint8_t * restrict pp = x[i].qs;
1699
2142
 
1700
- for (int l = 0; l < QK4_3; l += 2) {
2143
+ uint32_t qh;
2144
+ memcpy(&qh, x[i].qh, sizeof(qh));
2145
+
2146
+ for (int l = 0; l < QK5_1; l += 2) {
1701
2147
  const uint8_t vi = pp[l/2];
1702
2148
 
1703
- const int8_t vi0 = vi & 0xf;
1704
- const int8_t vi1 = vi >> 4;
2149
+ // extract the 5-th bit from qh
2150
+ const uint8_t vh0 = ((qh & (1u << (l + 0))) >> (l + 0)) << 4;
2151
+ const uint8_t vh1 = ((qh & (1u << (l + 1))) >> (l + 1)) << 4;
2152
+
2153
+ const uint8_t vi0 = (vi & 0x0F) | vh0;
2154
+ const uint8_t vi1 = (vi >> 4) | vh1;
1705
2155
 
1706
2156
  const float v0 = vi0*d + m;
1707
2157
  const float v1 = vi1*d + m;
1708
2158
 
1709
- y[i*QK4_3 + l + 0] = v0;
1710
- y[i*QK4_3 + l + 1] = v1;
2159
+ y[i*QK5_1 + l + 0] = v0;
2160
+ y[i*QK5_1 + l + 1] = v1;
2161
+
2162
+ assert(!isnan(y[i*QK5_1 + l + 0]));
2163
+ assert(!isnan(y[i*QK5_1 + l + 1]));
2164
+ }
2165
+ }
2166
+ }
2167
+
2168
+ static void dequantize_row_q8_0(const void * restrict vx, float * restrict y, int k) {
2169
+ assert(k % QK8_0 == 0);
2170
+ const int nb = k / QK8_0;
2171
+
2172
+ const block_q8_0 * restrict x = vx;
2173
+
2174
+ for (int i = 0; i < nb; i++) {
2175
+ const float d = x[i].d;
1711
2176
 
1712
- assert(!isnan(y[i*QK4_3 + l + 0]));
1713
- assert(!isnan(y[i*QK4_3 + l + 1]));
2177
+ const int8_t * restrict pp = x[i].qs;
2178
+
2179
+ for (int l = 0; l < QK8_0; ++l) {
2180
+ y[i*QK8_0 + l] = pp[l]*d;
1714
2181
  }
1715
2182
  }
1716
2183
  }
1717
2184
 
1718
2185
  static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
1719
- static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
2186
+ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
1720
2187
  static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
1721
- static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
2188
+ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
2189
+ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
2190
+ static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
1722
2191
 
1723
2192
  static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
1724
2193
  [GGML_TYPE_Q4_0] = {
@@ -1727,34 +2196,55 @@ static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
1727
2196
  .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_0_reference,
1728
2197
  .quantize_row_q_dot = quantize_row_q8_0,
1729
2198
  .vec_dot_q = ggml_vec_dot_q4_0_q8_0,
2199
+ .vec_dot_type = GGML_TYPE_Q8_0,
1730
2200
  },
1731
2201
  [GGML_TYPE_Q4_1] = {
1732
2202
  .dequantize_row_q = dequantize_row_q4_1,
1733
2203
  .quantize_row_q = quantize_row_q4_1,
1734
2204
  .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_1_reference,
1735
- .quantize_row_q_dot = quantize_row_q8_0,
1736
- .vec_dot_q = ggml_vec_dot_q4_1_q8_0,
2205
+ .quantize_row_q_dot = quantize_row_q8_1,
2206
+ .vec_dot_q = ggml_vec_dot_q4_1_q8_1,
2207
+ .vec_dot_type = GGML_TYPE_Q8_1,
1737
2208
  },
1738
2209
  [GGML_TYPE_Q4_2] = {
1739
2210
  .dequantize_row_q = dequantize_row_q4_2,
1740
2211
  .quantize_row_q = quantize_row_q4_2,
1741
- .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_2_rmse, //quantize_row_q4_2_reference,
2212
+ .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_2_reference,
1742
2213
  .quantize_row_q_dot = quantize_row_q8_0,
1743
2214
  .vec_dot_q = ggml_vec_dot_q4_2_q8_0,
2215
+ .vec_dot_type = GGML_TYPE_Q8_0,
1744
2216
  },
1745
- [GGML_TYPE_Q4_3] = {
1746
- .dequantize_row_q = dequantize_row_q4_3,
1747
- .quantize_row_q = quantize_row_q4_3,
1748
- .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_3_reference, // TODO: RMSE optimization
2217
+ [GGML_TYPE_Q5_0] = {
2218
+ .dequantize_row_q = dequantize_row_q5_0,
2219
+ .quantize_row_q = quantize_row_q5_0,
2220
+ .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q5_0_reference,
1749
2221
  .quantize_row_q_dot = quantize_row_q8_0,
1750
- .vec_dot_q = ggml_vec_dot_q4_3_q8_0,
2222
+ .vec_dot_q = ggml_vec_dot_q5_0_q8_0,
2223
+ .vec_dot_type = GGML_TYPE_Q8_0,
2224
+ },
2225
+ [GGML_TYPE_Q5_1] = {
2226
+ .dequantize_row_q = dequantize_row_q5_1,
2227
+ .quantize_row_q = quantize_row_q5_1,
2228
+ .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q5_1_reference,
2229
+ .quantize_row_q_dot = quantize_row_q8_1,
2230
+ .vec_dot_q = ggml_vec_dot_q5_1_q8_1,
2231
+ .vec_dot_type = GGML_TYPE_Q8_1,
1751
2232
  },
1752
2233
  [GGML_TYPE_Q8_0] = {
1753
- .dequantize_row_q = NULL, // TODO
2234
+ .dequantize_row_q = dequantize_row_q8_0,
1754
2235
  .quantize_row_q = quantize_row_q8_0,
1755
2236
  .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q8_0_reference,
1756
2237
  .quantize_row_q_dot = quantize_row_q8_0,
2238
+ .vec_dot_q = ggml_vec_dot_q8_0_q8_0,
2239
+ .vec_dot_type = GGML_TYPE_Q8_0,
2240
+ },
2241
+ [GGML_TYPE_Q8_1] = {
2242
+ .dequantize_row_q = NULL, // TODO
2243
+ .quantize_row_q = quantize_row_q8_1,
2244
+ .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q8_1_reference,
2245
+ .quantize_row_q_dot = quantize_row_q8_1,
1757
2246
  .vec_dot_q = NULL, // TODO
2247
+ .vec_dot_type = GGML_TYPE_Q8_1,
1758
2248
  },
1759
2249
  };
1760
2250
 
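Note: the rewritten AVX2 paths in the hunks below call two helpers, hsum_float_8 and mul_sum_i8_pairs_float, which are defined elsewhere in the new ggml.c and do not appear in this excerpt. A minimal sketch of what they do, reconstructed from the inline instruction sequences that these hunks remove (the exact upstream definitions may differ slightly), is:

#include <immintrin.h>

// Horizontal sum of the 8 floats in an AVX accumulator
// (same instruction sequence as the removed inline code).
static inline float hsum_float_8(const __m256 x) {
    __m128 res = _mm256_extractf128_ps(x, 1);
    res = _mm_add_ps(res, _mm256_castps256_ps128(x));
    res = _mm_add_ps(res, _mm_movehl_ps(res, res));
    res = _mm_add_ss(res, _mm_movehdup_ps(res));
    return _mm_cvtss_f32(res);
}

// Multiply corresponding signed 8-bit values of x and y, sum each group of
// four products into a 32-bit lane, and convert to 8 floats. The abs/sign
// trick gives _mm256_maddubs_epi16 the unsigned first operand it requires
// while preserving the sign of the product.
static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
    const __m256i ax   = _mm256_sign_epi8(x, x);        // |x|
    const __m256i sy   = _mm256_sign_epi8(y, x);        // y with the sign of x
    const __m256i dot  = _mm256_maddubs_epi16(ax, sy);  // 16 x s16 pair sums
    const __m256i ones = _mm256_set1_epi16(1);
    const __m256i sums = _mm256_madd_epi16(ones, dot);  // 8 x s32 quad sums
    return _mm256_cvtepi32_ps(sums);
}

The vec_dot kernels below use them as acc = _mm256_fmadd_ps(d, mul_sum_i8_pairs_float(bx, by), acc) per block, followed by *s = hsum_float_8(acc).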
@@ -2366,8 +2856,6 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
2366
2856
  const block_q4_0 * restrict x = vx;
2367
2857
  const block_q8_0 * restrict y = vy;
2368
2858
 
2369
- float sumf = 0.0;
2370
-
2371
2859
  #if defined(__ARM_NEON)
2372
2860
  float32x4_t sumv0 = vdupq_n_f32(0.0f);
2373
2861
  float32x4_t sumv1 = vdupq_n_f32(0.0f);
@@ -2378,7 +2866,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
2378
2866
  const block_q8_0 * restrict y0 = &y[i + 0];
2379
2867
  const block_q8_0 * restrict y1 = &y[i + 1];
2380
2868
 
2381
- const uint8x16_t m4b = vdupq_n_u8(0xf);
2869
+ const uint8x16_t m4b = vdupq_n_u8(0x0F);
2382
2870
  const int8x16_t s8b = vdupq_n_s8(0x8);
2383
2871
 
2384
2872
  const uint8x16_t v0_0 = vld1q_u8(x0->qs);
@@ -2396,35 +2884,35 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
2396
2884
  const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
2397
2885
  const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
2398
2886
 
2887
+ // interleave
2888
+ const int8x16_t v0_0lz = vzip1q_s8(v0_0ls, v0_0hs);
2889
+ const int8x16_t v0_0hz = vzip2q_s8(v0_0ls, v0_0hs);
2890
+ const int8x16_t v0_1lz = vzip1q_s8(v0_1ls, v0_1hs);
2891
+ const int8x16_t v0_1hz = vzip2q_s8(v0_1ls, v0_1hs);
2892
+
2399
2893
  // load y
2400
2894
  const int8x16_t v1_0l = vld1q_s8(y0->qs);
2401
2895
  const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
2402
2896
  const int8x16_t v1_1l = vld1q_s8(y1->qs);
2403
2897
  const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
2404
2898
 
2405
- // interleave
2406
- const int8x16_t v1_0ls = vuzp1q_s8(v1_0l, v1_0h);
2407
- const int8x16_t v1_0hs = vuzp2q_s8(v1_0l, v1_0h);
2408
- const int8x16_t v1_1ls = vuzp1q_s8(v1_1l, v1_1h);
2409
- const int8x16_t v1_1hs = vuzp2q_s8(v1_1l, v1_1h);
2410
-
2411
2899
  #if defined(__ARM_FEATURE_DOTPROD)
2412
2900
  // dot product into int32x4_t
2413
- const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls), v0_0hs, v1_0hs);
2414
- const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls), v0_1hs, v1_1hs);
2901
+ const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0lz, v1_0l), v0_0hz, v1_0h);
2902
+ const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1lz, v1_1l), v0_1hz, v1_1h);
2415
2903
 
2416
2904
  sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), x0->d*y0->d);
2417
2905
  sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), x1->d*y1->d);
2418
2906
  #else
2419
- const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls));
2420
- const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls));
2421
- const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0hs));
2422
- const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0hs));
2907
+ const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lz), vget_low_s8 (v1_0l));
2908
+ const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lz), vget_high_s8(v1_0l));
2909
+ const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hz), vget_low_s8 (v1_0h));
2910
+ const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hz), vget_high_s8(v1_0h));
2423
2911
 
2424
- const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1ls));
2425
- const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1ls));
2426
- const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1hs));
2427
- const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1hs));
2912
+ const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lz), vget_low_s8 (v1_1l));
2913
+ const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lz), vget_high_s8(v1_1l));
2914
+ const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hz), vget_low_s8 (v1_1h));
2915
+ const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hz), vget_high_s8(v1_1h));
2428
2916
 
2429
2917
  const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
2430
2918
  const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
@@ -2436,7 +2924,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
2436
2924
  #endif
2437
2925
  }
2438
2926
 
2439
- sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
2927
+ *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
2440
2928
  #elif defined(__AVX2__)
2441
2929
  // Initialize accumulator with zeros
2442
2930
  __m256 acc = _mm256_setzero_ps();
@@ -2454,32 +2942,13 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
2454
2942
 
2455
2943
  __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
2456
2944
 
2457
- // Get absolute values of x vectors
2458
- const __m256i ax = _mm256_sign_epi8(bx, bx);
2459
-
2460
- // Sign the values of the y vectors
2461
- const __m256i sy = _mm256_sign_epi8(by, bx);
2462
-
2463
- // Perform multiplication and create 16-bit values
2464
- const __m256i dot = _mm256_maddubs_epi16(ax, sy);
2465
-
2466
- const __m256i ones = _mm256_set1_epi16(1);
2467
- __m256i xy_q = _mm256_madd_epi16(ones, dot);
2468
-
2469
- /* Convert to vectore of 8 int32_t to 8 floats */
2470
- __m256 q = _mm256_cvtepi32_ps( xy_q );
2945
+ const __m256 q = mul_sum_i8_pairs_float(bx, by);
2471
2946
 
2472
2947
  /* Multiply q with scale and accumulate */
2473
2948
  acc = _mm256_fmadd_ps( d, q, acc );
2474
2949
  }
2475
2950
 
2476
- // Return horizontal sum of the acc vector
2477
- __m128 res = _mm256_extractf128_ps( acc, 1 );
2478
- res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
2479
- res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
2480
- res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
2481
-
2482
- sumf = _mm_cvtss_f32( res );
2951
+ *s = hsum_float_8(acc);
2483
2952
  #elif defined(__AVX__)
2484
2953
  // Initialize accumulator with zeros
2485
2954
  __m256 acc = _mm256_setzero_ps();
@@ -2518,15 +2987,10 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
2518
2987
  acc = _mm256_add_ps(_mm256_mul_ps( d, p ), acc);
2519
2988
  }
2520
2989
 
2521
- // Return horizontal sum of the acc vector
2522
- __m128 res = _mm256_extractf128_ps( acc, 1 );
2523
- res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
2524
- res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
2525
- res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
2526
-
2527
- sumf = _mm_cvtss_f32( res );
2990
+ *s = hsum_float_8(acc);
2528
2991
  #else
2529
2992
  // scalar
2993
+ float sumf = 0.0;
2530
2994
  for (int i = 0; i < nb; i++) {
2531
2995
  const float d0 = x[i].d;
2532
2996
  const float d1 = y[i].d;
@@ -2538,8 +3002,8 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
2538
3002
  for (int j = 0; j < QK8_0/2; j++) {
2539
3003
  const uint8_t v0 = p0[j];
2540
3004
 
2541
- const int i0 = (int8_t) (v0 & 0xf) - 8;
2542
- const int i1 = (int8_t) (v0 >> 4) - 8;
3005
+ const int i0 = (int8_t) (v0 & 0x0F) - 8;
3006
+ const int i1 = (int8_t) (v0 >> 4) - 8;
2543
3007
 
2544
3008
  const int i2 = p1[2*j + 0];
2545
3009
  const int i3 = p1[2*j + 1];
@@ -2548,34 +3012,35 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
2548
3012
  }
2549
3013
  sumf += d0*d1*sumi;
2550
3014
  }
2551
- #endif
2552
-
2553
3015
  *s = sumf;
3016
+ #endif
2554
3017
  }
2555
3018
 
2556
- static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
2557
- const int nb = n / QK8_0;
3019
+ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
3020
+ const int nb = n / QK8_1;
2558
3021
 
2559
- assert(n % QK8_0 == 0);
3022
+ assert(n % QK8_1 == 0);
2560
3023
  assert(nb % 2 == 0);
2561
3024
 
2562
3025
  const block_q4_1 * restrict x = vx;
2563
- const block_q8_0 * restrict y = vy;
2564
-
2565
- float sumf = 0.0;
3026
+ const block_q8_1 * restrict y = vy;
2566
3027
 
2567
3028
  // TODO: add AVX / WASM SIMD / etc
2568
3029
  #if defined(__ARM_NEON)
2569
3030
  float32x4_t sumv0 = vdupq_n_f32(0.0f);
2570
3031
  float32x4_t sumv1 = vdupq_n_f32(0.0f);
2571
3032
 
3033
+ float summs = 0;
3034
+
2572
3035
  for (int i = 0; i < nb; i += 2) {
2573
3036
  const block_q4_1 * restrict x0 = &x[i + 0];
2574
3037
  const block_q4_1 * restrict x1 = &x[i + 1];
2575
- const block_q8_0 * restrict y0 = &y[i + 0];
2576
- const block_q8_0 * restrict y1 = &y[i + 1];
3038
+ const block_q8_1 * restrict y0 = &y[i + 0];
3039
+ const block_q8_1 * restrict y1 = &y[i + 1];
3040
+
3041
+ summs += x0->m * (y0->s0 + y0->s1) + x1->m * (y1->s0 + y1->s1);
2577
3042
 
2578
- const uint8x16_t m4b = vdupq_n_u8(0xf);
3043
+ const uint8x16_t m4b = vdupq_n_u8(0x0F);
2579
3044
 
2580
3045
  const uint8x16_t v0_0 = vld1q_u8(x0->qs);
2581
3046
  const uint8x16_t v0_1 = vld1q_u8(x1->qs);
@@ -2586,46 +3051,35 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
2586
3051
  const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
2587
3052
  const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
2588
3053
 
3054
+ // interleave
3055
+ const int8x16_t v0_0lz = vzip1q_s8(v0_0l, v0_0h);
3056
+ const int8x16_t v0_0hz = vzip2q_s8(v0_0l, v0_0h);
3057
+ const int8x16_t v0_1lz = vzip1q_s8(v0_1l, v0_1h);
3058
+ const int8x16_t v0_1hz = vzip2q_s8(v0_1l, v0_1h);
3059
+
2589
3060
  // load y
2590
3061
  const int8x16_t v1_0l = vld1q_s8(y0->qs);
2591
3062
  const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
2592
3063
  const int8x16_t v1_1l = vld1q_s8(y1->qs);
2593
3064
  const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
2594
3065
 
2595
- // interleave
2596
- const int8x16_t v1_0ls = vuzp1q_s8(v1_0l, v1_0h);
2597
- const int8x16_t v1_0hs = vuzp2q_s8(v1_0l, v1_0h);
2598
- const int8x16_t v1_1ls = vuzp1q_s8(v1_1l, v1_1h);
2599
- const int8x16_t v1_1hs = vuzp2q_s8(v1_1l, v1_1h);
2600
-
2601
- const int16x8_t s0i = vaddq_s16(
2602
- vaddq_s16(vmovl_s8(vget_low_s8(v1_0ls)), vmovl_s8(vget_high_s8(v1_0ls))),
2603
- vaddq_s16(vmovl_s8(vget_low_s8(v1_0hs)), vmovl_s8(vget_high_s8(v1_0hs))));
2604
-
2605
- const int16x8_t s1i = vaddq_s16(
2606
- vaddq_s16(vmovl_s8(vget_low_s8(v1_1ls)), vmovl_s8(vget_high_s8(v1_1ls))),
2607
- vaddq_s16(vmovl_s8(vget_low_s8(v1_1hs)), vmovl_s8(vget_high_s8(v1_1hs))));
2608
-
2609
- sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddl_s16(vget_low_s16(s0i), vget_high_s16(s0i))), x0->m*y0->d);
2610
- sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddl_s16(vget_low_s16(s1i), vget_high_s16(s1i))), x1->m*y1->d);
2611
-
2612
3066
  #if defined(__ARM_FEATURE_DOTPROD)
2613
3067
  // dot product into int32x4_t
2614
- const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0ls), v0_0h, v1_0hs);
2615
- const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1ls), v0_1h, v1_1hs);
3068
+ const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0lz, v1_0l), v0_0hz, v1_0h);
3069
+ const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1lz, v1_1l), v0_1hz, v1_1h);
2616
3070
 
2617
3071
  sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), x0->d*y0->d);
2618
3072
  sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), x1->d*y1->d);
2619
3073
  #else
2620
- const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0l), vget_low_s8 (v1_0ls));
2621
- const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0l), vget_high_s8(v1_0ls));
2622
- const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0h), vget_low_s8 (v1_0hs));
2623
- const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0h), vget_high_s8(v1_0hs));
3074
+ const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lz), vget_low_s8 (v1_0l));
3075
+ const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lz), vget_high_s8(v1_0l));
3076
+ const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hz), vget_low_s8 (v1_0h));
3077
+ const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hz), vget_high_s8(v1_0h));
2624
3078
 
2625
- const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1l), vget_low_s8 (v1_1ls));
2626
- const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1l), vget_high_s8(v1_1ls));
2627
- const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1h), vget_low_s8 (v1_1hs));
2628
- const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1h), vget_high_s8(v1_1hs));
3079
+ const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lz), vget_low_s8 (v1_1l));
3080
+ const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lz), vget_high_s8(v1_1l));
3081
+ const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hz), vget_low_s8 (v1_1h));
3082
+ const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hz), vget_high_s8(v1_1h));
2629
3083
 
2630
3084
  const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
2631
3085
  const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
@@ -2637,65 +3091,40 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
2637
3091
  #endif
2638
3092
  }
2639
3093
 
2640
- sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
3094
+ *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs;
2641
3095
  #elif defined(__AVX2__)
2642
3096
  // Initialize accumulator with zeros
2643
3097
  __m256 acc = _mm256_setzero_ps();
2644
3098
 
3099
+ float summs = 0;
3100
+
2645
3101
  // Main loop
2646
3102
  for (int i = 0; i < nb; ++i) {
2647
3103
  const float * d0 = &x[i].d;
2648
3104
  const float * d1 = &y[i].d;
2649
- const float * m0 = &x[i].m;
3105
+
3106
+ summs += x[i].m * (y[i].s0 + y[i].s1);
2650
3107
 
2651
3108
  const __m256 d0v = _mm256_broadcast_ss( d0 );
2652
3109
  const __m256 d1v = _mm256_broadcast_ss( d1 );
2653
- const __m256 m0v = _mm256_broadcast_ss( m0 );
2654
3110
 
2655
3111
  // Compute combined scales
2656
3112
  const __m256 d0d1 = _mm256_mul_ps( d0v, d1v );
2657
- const __m256 d1m0 = _mm256_mul_ps( d1v, m0v );
2658
3113
 
2659
3114
  // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
2660
3115
  const __m256i bx = bytes_from_nibbles_32(x[i].qs);
2661
3116
  const __m256i by = _mm256_loadu_si256( (const __m256i *)y[i].qs );
2662
3117
 
2663
- // Get absolute values of x vectors
2664
- const __m256i ax = _mm256_sign_epi8( bx, bx );
2665
-
2666
- // Sign the values of the y vectors
2667
- const __m256i sy = _mm256_sign_epi8( by, bx );
2668
-
2669
- // Perform multiplication and create 16-bit values
2670
- const __m256i dot = _mm256_maddubs_epi16( ax, sy );
2671
- const __m256i ones = _mm256_set1_epi16( 1 );
2672
- const __m256i xy_q = _mm256_madd_epi16( ones, dot );
2673
-
2674
- // Convert to vector of 8 int32_t to 8 floats
2675
- const __m256 xy = _mm256_cvtepi32_ps( xy_q );
3118
+ const __m256 xy = mul_sum_i8_pairs_float(bx, by);
2676
3119
 
2677
3120
  // Accumulate d0*d1*x*y
2678
3121
  acc = _mm256_fmadd_ps( d0d1, xy, acc );
2679
-
2680
- // Compute sum of y values
2681
- const __m256i y16_l = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( by ) );
2682
- const __m256i y16_h = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( by, 1 ) );
2683
- const __m256i ysumi = _mm256_madd_epi16( _mm256_add_epi16(y16_l, y16_h), ones );
2684
- const __m256 ysum = _mm256_cvtepi32_ps( ysumi );
2685
-
2686
- // Accumulate d1*m0*y
2687
- acc = _mm256_fmadd_ps( d1m0, ysum, acc );
2688
3122
  }
2689
3123
 
2690
- // Return horizontal sum of the acc vector
2691
- __m128 res = _mm256_extractf128_ps( acc, 1 );
2692
- res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
2693
- res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
2694
- res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
2695
-
2696
- sumf = _mm_cvtss_f32( res );
3124
+ *s = hsum_float_8(acc) + summs;
2697
3125
  #else
2698
3126
  // scalar
3127
+ float sumf = 0.0;
2699
3128
  for (int i = 0; i < nb; i++) {
2700
3129
  const float d0 = x[i].d;
2701
3130
  const float m0 = x[i].m;
@@ -2705,347 +3134,685 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
2705
3134
  const int8_t * restrict p1 = y[i].qs;
2706
3135
 
2707
3136
  // TODO: this is very slow ..
2708
- for (int j = 0; j < QK8_0/2; j++) {
3137
+ for (int j = 0; j < QK8_1/2; j++) {
2709
3138
  const uint8_t v0 = p0[j];
2710
3139
 
2711
- const float f0 = d0*(v0 & 0xf) + m0;
2712
- const float f1 = d0*(v0 >> 4) + m0;
3140
+ const float f0 = d0*(v0 & 0x0F) + m0;
3141
+ const float f1 = d0*(v0 >> 4) + m0;
3142
+
3143
+ const float f2 = d1*p1[2*j + 0];
3144
+ const float f3 = d1*p1[2*j + 1];
3145
+
3146
+ sumf += f0*f2 + f1*f3;
3147
+ }
3148
+ }
3149
+ *s = sumf;
3150
+ #endif
3151
+ }
3152
+
3153
+ static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
3154
+ const int nb = n / QK8_0;
3155
+
3156
+ assert(n % QK8_0 == 0);
3157
+ assert(nb % 2 == 0);
3158
+ assert(QK8_0 == 2*QK4_2);
3159
+
3160
+ const block_q4_2 * restrict x = vx;
3161
+ const block_q8_0 * restrict y = vy;
3162
+
3163
+ #if defined(__ARM_NEON)
3164
+ float32x4_t sumv0 = vdupq_n_f32(0.0f);
3165
+ float32x4_t sumv1 = vdupq_n_f32(0.0f);
3166
+
3167
+ for (int i = 0; i < nb; i += 2) {
3168
+ const block_q4_2 * restrict x0_0 = &x[2*(i + 0) + 0];
3169
+ const block_q4_2 * restrict x0_1 = &x[2*(i + 0) + 1];
3170
+ const block_q4_2 * restrict x1_0 = &x[2*(i + 1) + 0];
3171
+ const block_q4_2 * restrict x1_1 = &x[2*(i + 1) + 1];
3172
+
3173
+ const block_q8_0 * restrict y0 = &y[i + 0];
3174
+ const block_q8_0 * restrict y1 = &y[i + 1];
3175
+
3176
+ const uint8x16_t m4b = vdupq_n_u8(0x0F);
3177
+ const int8x16_t s8b = vdupq_n_s8(0x8);
3178
+
3179
+ const uint8x16_t v0_0 = vcombine_u8(vld1_u8(x0_0->qs), vld1_u8(x0_1->qs));
3180
+ const uint8x16_t v0_1 = vcombine_u8(vld1_u8(x1_0->qs), vld1_u8(x1_1->qs));
3181
+
3182
+ // 4-bit -> 8-bit
3183
+ const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
3184
+ const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
3185
+ const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
3186
+ const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
3187
+
3188
+ // sub 8
3189
+ const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
3190
+ const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
3191
+ const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
3192
+ const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
3193
+
3194
+ // interleave
3195
+ const int8x16_t v0_0lz = vzip1q_s8(v0_0ls, v0_0hs);
3196
+ const int8x16_t v0_0hz = vzip2q_s8(v0_0ls, v0_0hs);
3197
+ const int8x16_t v0_1lz = vzip1q_s8(v0_1ls, v0_1hs);
3198
+ const int8x16_t v0_1hz = vzip2q_s8(v0_1ls, v0_1hs);
3199
+
3200
+ // load y
3201
+ const int8x16_t v1_0l = vld1q_s8(y0->qs);
3202
+ const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
3203
+ const int8x16_t v1_1l = vld1q_s8(y1->qs);
3204
+ const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
3205
+
3206
+ #if defined(__ARM_FEATURE_DOTPROD)
3207
+ sumv0 = vmlaq_n_f32(sumv0, vaddq_f32(
3208
+ vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0lz, v1_0l)), GGML_FP16_TO_FP32(x0_0->d)),
3209
+ vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0hz, v1_0h)), GGML_FP16_TO_FP32(x0_1->d))), y0->d);
3210
+
3211
+ sumv1 = vmlaq_n_f32(sumv1, vaddq_f32(
3212
+ vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_1lz, v1_1l)), GGML_FP16_TO_FP32(x1_0->d)),
3213
+ vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_1hz, v1_1h)), GGML_FP16_TO_FP32(x1_1->d))), y1->d);
3214
+ #else
3215
+ const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lz), vget_low_s8 (v1_0l));
3216
+ const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lz), vget_high_s8(v1_0l));
3217
+ const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hz), vget_low_s8 (v1_0h));
3218
+ const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hz), vget_high_s8(v1_0h));
3219
+
3220
+ const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lz), vget_low_s8 (v1_1l));
3221
+ const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lz), vget_high_s8(v1_1l));
3222
+ const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hz), vget_low_s8 (v1_1h));
3223
+ const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hz), vget_high_s8(v1_1h));
3224
+
3225
+ const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
3226
+ const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
3227
+ const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
3228
+ const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
3229
+
3230
+ sumv0 = vmlaq_n_f32(sumv0, vaddq_f32(
3231
+ vmulq_n_f32(vcvtq_f32_s32(pl0), GGML_FP16_TO_FP32(x0_0->d)),
3232
+ vmulq_n_f32(vcvtq_f32_s32(ph0), GGML_FP16_TO_FP32(x0_1->d))), y0->d);
3233
+
3234
+ sumv1 = vmlaq_n_f32(sumv1, vaddq_f32(
3235
+ vmulq_n_f32(vcvtq_f32_s32(pl1), GGML_FP16_TO_FP32(x1_0->d)),
3236
+ vmulq_n_f32(vcvtq_f32_s32(ph1), GGML_FP16_TO_FP32(x1_1->d))), y1->d);
3237
+ #endif
3238
+ }
3239
+
3240
+ *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
3241
+ #elif defined(__AVX2__)
3242
+ // Initialize accumulator with zeros
3243
+ __m256 acc = _mm256_setzero_ps();
3244
+
3245
+ // Main loop
3246
+ for (int i = 0; i < nb; i++) {
3247
+ /* Compute combined scale for the block */
3248
+ const __m128 d0 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 0].d));
3249
+ const __m128 d1 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 1].d));
3250
+ const __m256 d = _mm256_mul_ps(_mm256_set_m128(d1, d0), _mm256_broadcast_ss(&y[i].d));
3251
+
3252
+ __m128i bx0 = bytes_from_nibbles_16(x[2*i + 0].qs);
3253
+ __m128i bx1 = bytes_from_nibbles_16(x[2*i + 1].qs);
3254
+ __m256i bx = _mm256_set_m128i(bx1, bx0);
3255
+
3256
+ // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
3257
+ const __m256i off = _mm256_set1_epi8(8);
3258
+ bx = _mm256_sub_epi8(bx, off);
3259
+
3260
+ __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
3261
+
3262
+ const __m256 q = mul_sum_i8_pairs_float(bx, by);
3263
+
3264
+ /* Multiply q with scale and accumulate */
3265
+ acc = _mm256_fmadd_ps(d, q, acc);
3266
+ }
3267
+
3268
+ *s = hsum_float_8(acc);
3269
+ #else
3270
+ // scalar
3271
+ float sumf = 0.0;
3272
+ for (int i = 0; i < nb; i++) {
3273
+ const uint8_t * restrict x0 = x[2*i + 0].qs;
3274
+ const uint8_t * restrict x1 = x[2*i + 1].qs;
3275
+ const int8_t * restrict y0 = y[i].qs;
3276
+
3277
+ const float d0 = GGML_FP16_TO_FP32(x[2*i + 0].d);
3278
+ const float d1 = GGML_FP16_TO_FP32(x[2*i + 1].d);
3279
+
3280
+ int sumi_0 = 0;
3281
+ int sumi_1 = 0;
3282
+
3283
+ for (int j = 0; j < QK8_0/4; j++) {
3284
+ const uint8_t v0 = x0[j];
3285
+ const uint8_t v1 = x1[j];
3286
+
3287
+ const int i0_0 = (int8_t) (v0 & 0x0F) - 8;
3288
+ const int i1_0 = (int8_t) (v0 >> 4) - 8;
3289
+
3290
+ const int i0_1 = (int8_t) (v1 & 0x0F) - 8;
3291
+ const int i1_1 = (int8_t) (v1 >> 4) - 8;
3292
+
3293
+ const int i2_0 = y0[2*j + 0];
3294
+ const int i3_0 = y0[2*j + 1];
3295
+
3296
+ const int i2_1 = y0[2*(j + QK8_0/4) + 0];
3297
+ const int i3_1 = y0[2*(j + QK8_0/4) + 1];
3298
+
3299
+ sumi_0 += i0_0*i2_0 + i1_0*i3_0;
3300
+ sumi_1 += i0_1*i2_1 + i1_1*i3_1;
3301
+ }
3302
+
3303
+ sumf += (d0 * y[i].d) * sumi_0;
3304
+ sumf += (d1 * y[i].d) * sumi_1;
3305
+ }
3306
+ *s = sumf;
3307
+ #endif
3308
+ }
3309
+
3310
+ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
3311
+ const int nb = n / QK8_0;
3312
+
3313
+ assert(n % QK8_0 == 0);
3314
+ assert(nb % 2 == 0);
3315
+ assert(QK8_0 == QK5_0);
3316
+
3317
+ const block_q5_0 * restrict x = vx;
3318
+ const block_q8_0 * restrict y = vy;
3319
+
3320
+ #if defined(__ARM_NEON)
3321
+ float32x4_t sumv = vdupq_n_f32(0.0f);
3322
+
3323
+ uint64_t tmp[4];
3324
+
3325
+ for (int i = 0; i < nb; ++i) {
3326
+ const block_q5_0 * restrict x0 = &x[i];
3327
+ const block_q8_0 * restrict y0 = &y[i];
3328
+
3329
+ const uint8x16_t m4b = vdupq_n_u8(0x0F);
3330
+ const int8x16_t s16b = vdupq_n_s8(0x10);
3331
+
3332
+ // extract the 5th bit
3333
+ uint32_t qh;
3334
+ memcpy(&qh, x0->qh, sizeof(qh));
3335
+
3336
+ tmp[0] = table_b2b_u[(qh >> 0) & 0xFF];
3337
+ tmp[1] = table_b2b_u[(qh >> 8) & 0xFF];
3338
+ tmp[2] = table_b2b_u[(qh >> 16) & 0xFF];
3339
+ tmp[3] = table_b2b_u[(qh >> 24) ];
3340
+
3341
+ const int8x16_t qhl = vld1q_s8((const int8_t *)(tmp + 0));
3342
+ const int8x16_t qhh = vld1q_s8((const int8_t *)(tmp + 2));
3343
+
3344
+ const uint8x16_t v0 = vld1q_u8(x0->qs);
3345
+
3346
+ // 4-bit -> 8-bit
3347
+ const int8x16_t v0l = vreinterpretq_s8_u8(vandq_u8 (v0, m4b));
3348
+ const int8x16_t v0h = vreinterpretq_s8_u8(vshrq_n_u8(v0, 4));
3349
+
3350
+ // interleave
3351
+ const int8x16_t v0lz = vzip1q_s8(v0l, v0h);
3352
+ const int8x16_t v0hz = vzip2q_s8(v0l, v0h);
3353
+
3354
+ // add high bit and sub 16
3355
+ const int8x16_t v0lf = vsubq_s8(vorrq_s8(v0lz, qhl), s16b);
3356
+ const int8x16_t v0hf = vsubq_s8(vorrq_s8(v0hz, qhh), s16b);
3357
+
3358
+ // load y
3359
+ const int8x16_t v1l = vld1q_s8(y0->qs);
3360
+ const int8x16_t v1h = vld1q_s8(y0->qs + 16);
3361
+
3362
+ const float x0d = GGML_FP16_TO_FP32(x0->d);
3363
+
3364
+ #if defined(__ARM_FEATURE_DOTPROD)
3365
+ sumv = vmlaq_n_f32(sumv, vcvtq_f32_s32(vaddq_s32(
3366
+ vdotq_s32(vdupq_n_s32(0), v0lf, v1l),
3367
+ vdotq_s32(vdupq_n_s32(0), v0hf, v1h))), x0d*y0->d);
3368
+ #else
3369
+ const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0lf), vget_low_s8 (v1l));
3370
+ const int16x8_t pl0h = vmull_s8(vget_high_s8(v0lf), vget_high_s8(v1l));
3371
+ const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0hf), vget_low_s8 (v1h));
3372
+ const int16x8_t ph0h = vmull_s8(vget_high_s8(v0hf), vget_high_s8(v1h));
3373
+
3374
+ const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
3375
+ const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
3376
+
3377
+ sumv = vmlaq_n_f32(sumv, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), x0d*y0->d);
3378
+ #endif
3379
+ }
3380
+
3381
+ *s = vaddvq_f32(sumv);
3382
+ #elif defined(__wasm_simd128__)
3383
+ v128_t sumv = wasm_f32x4_splat(0.0f);
3384
+
3385
+ uint64_t tmp[4];
3386
+
3387
+ for (int i = 0; i < nb; ++i) {
3388
+ const block_q5_0 * restrict x0 = &x[i];
3389
+ const block_q8_0 * restrict y0 = &y[i];
3390
+
3391
+ const v128_t m4b = wasm_i8x16_splat(0x0F);
3392
+ const v128_t s16b = wasm_i8x16_splat(0x10);
3393
+
3394
+ // extract the 5th bit
3395
+ uint32_t qh;
3396
+ memcpy(&qh, x0->qh, sizeof(qh));
3397
+
3398
+ tmp[0] = table_b2b_u[(qh >> 0) & 0xFF];
3399
+ tmp[1] = table_b2b_u[(qh >> 8) & 0xFF];
3400
+ tmp[2] = table_b2b_u[(qh >> 16) & 0xFF];
3401
+ tmp[3] = table_b2b_u[(qh >> 24) ];
3402
+
3403
+ const v128_t qhl = wasm_v128_load(tmp + 0);
3404
+ const v128_t qhh = wasm_v128_load(tmp + 2);
3405
+
3406
+ const v128_t v0 = wasm_v128_load(x0->qs);
3407
+
3408
+ // 4-bit -> 8-bit
3409
+ const v128_t v0l = wasm_v128_and (v0, m4b);
3410
+ const v128_t v0h = wasm_u8x16_shr(v0, 4);
3411
+
3412
+ // interleave
3413
+ const v128_t v0lz = wasm_v8x16_shuffle(v0l, v0h, 0, 16, 1, 17, 2, 18, 3, 19, 4, 20, 5, 21, 6, 22, 7, 23);
3414
+ const v128_t v0hz = wasm_v8x16_shuffle(v0l, v0h, 8, 24, 9, 25, 10, 26, 11, 27, 12, 28, 13, 29, 14, 30, 15, 31);
3415
+
3416
+ // add high bit and sub 16
3417
+ const v128_t v0lf = wasm_i8x16_sub(wasm_v128_or(v0lz, qhl), s16b);
3418
+ const v128_t v0hf = wasm_i8x16_sub(wasm_v128_or(v0hz, qhh), s16b);
3419
+
3420
+ // load y
3421
+ const v128_t v1l = wasm_v128_load(y0->qs);
3422
+ const v128_t v1h = wasm_v128_load(y0->qs + 16);
3423
+
3424
+ // int8x16 -> int16x8
3425
+ const v128_t v0lfl = wasm_i16x8_extend_low_i8x16 (v0lf);
3426
+ const v128_t v0lfh = wasm_i16x8_extend_high_i8x16(v0lf);
3427
+ const v128_t v0hfl = wasm_i16x8_extend_low_i8x16 (v0hf);
3428
+ const v128_t v0hfh = wasm_i16x8_extend_high_i8x16(v0hf);
3429
+
3430
+ const v128_t v1ll = wasm_i16x8_extend_low_i8x16 (v1l);
3431
+ const v128_t v1lh = wasm_i16x8_extend_high_i8x16(v1l);
3432
+ const v128_t v1hl = wasm_i16x8_extend_low_i8x16 (v1h);
3433
+ const v128_t v1hh = wasm_i16x8_extend_high_i8x16(v1h);
3434
+
3435
+ const float x0d = GGML_FP16_TO_FP32(x0->d);
3436
+
3437
+ // dot product
3438
+ sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4(
3439
+ wasm_i32x4_add(
3440
+ wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0lfl, v1ll),
3441
+ wasm_i32x4_dot_i16x8(v0lfh, v1lh)),
3442
+ wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0hfl, v1hl),
3443
+ wasm_i32x4_dot_i16x8(v0hfh, v1hh)))), wasm_f32x4_splat(x0d*y0->d)));
3444
+ }
3445
+
3446
+ *s = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) +
3447
+ wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3);
3448
+ #elif defined(__AVX2__)
3449
+ // Initialize accumulator with zeros
3450
+ __m256 acc = _mm256_setzero_ps();
3451
+
3452
+ // Main loop
3453
+ for (int i = 0; i < nb; i++) {
3454
+ /* Compute combined scale for the block */
3455
+ const __m256 d = _mm256_mul_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d)), _mm256_broadcast_ss(&y[i].d));
3456
+
3457
+ __m256i bx = bytes_from_nibbles_32(x[i].qs);
3458
+ __m256i bxhi = bytes_from_bits_32(x[i].qh);
3459
+ bxhi = _mm256_andnot_si256(bxhi, _mm256_set1_epi8((char)0xF0));
3460
+ bx = _mm256_or_si256(bx, bxhi);
3461
+
3462
+ __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
3463
+
3464
+ const __m256 q = mul_sum_i8_pairs_float(bx, by);
3465
+
3466
+ /* Multiply q with scale and accumulate */
3467
+ acc = _mm256_fmadd_ps(d, q, acc);
3468
+ }
3469
+
3470
+ *s = hsum_float_8(acc);
3471
+ #else
3472
+ // scalar
3473
+ float sumf = 0.0;
3474
+ for (int i = 0; i < nb; i++) {
3475
+ const uint8_t * restrict x0 = x[i].qs;
3476
+ const int8_t * restrict y0 = y[i].qs;
3477
+
3478
+ uint32_t qh;
3479
+ memcpy(&qh, x[i].qh, sizeof(qh));
3480
+
3481
+ const float d = GGML_FP16_TO_FP32(x[i].d);
3482
+
3483
+ int sxy = 0;
3484
+
3485
+ for (int j = 0; j < QK8_0/2; j++) {
3486
+ const uint8_t v0 = x0[j];
3487
+
3488
+ const int x0_0h = ((qh & (1u << (2*j + 0))) >> (2*j + 0)) << 4;
3489
+ const int x1_0h = ((qh & (1u << (2*j + 1))) >> (2*j + 1)) << 4;
3490
+
3491
+ const int x0_0 = ((v0 & 0x0F) | x0_0h) - 16;
3492
+ const int x1_0 = ((v0 >> 4) | x1_0h) - 16;
3493
+
3494
+ const int y0_0 = y0[2*j + 0];
3495
+ const int y1_0 = y0[2*j + 1];
3496
+
3497
+ sxy += x0_0*y0_0 + x1_0*y1_0;
3498
+ }
3499
+
3500
+ sumf += (d*sxy)*y[i].d;
3501
+ }
3502
+ *s = sumf;
3503
+ #endif
3504
+ }
3505
+
3506
+ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
3507
+ const int nb = n / QK8_1;
3508
+
3509
+ assert(n % QK8_1 == 0);
3510
+ assert(nb % 2 == 0);
3511
+ assert(QK8_1 == QK5_1);
3512
+
3513
+ const block_q5_1 * restrict x = vx;
3514
+ const block_q8_1 * restrict y = vy;
3515
+
3516
+ #if defined(__ARM_NEON)
3517
+ float32x4_t sumv = vdupq_n_f32(0.0f);
3518
+
3519
+ float summs = 0.0f;
3520
+
3521
+ uint64_t tmp[4];
3522
+
3523
+ for (int i = 0; i < nb; ++i) {
3524
+ const block_q5_1 * restrict x0 = &x[i];
3525
+ const block_q8_1 * restrict y0 = &y[i];
3526
+
3527
+ summs += GGML_FP16_TO_FP32(x0->m) * (y0->s0 + y0->s1);
3528
+
3529
+ // extract the 5th bit
3530
+ uint32_t qh;
3531
+ memcpy(&qh, x0->qh, sizeof(qh));
3532
+
3533
+ tmp[0] = table_b2b_u[(qh >> 0) & 0xFF];
3534
+ tmp[1] = table_b2b_u[(qh >> 8) & 0xFF];
3535
+ tmp[2] = table_b2b_u[(qh >> 16) & 0xFF];
3536
+ tmp[3] = table_b2b_u[(qh >> 24) ];
3537
+
3538
+ const int8x16_t qhl = vld1q_s8((const int8_t *)(tmp + 0));
3539
+ const int8x16_t qhh = vld1q_s8((const int8_t *)(tmp + 2));
3540
+
3541
+ const uint8x16_t v0 = vld1q_u8(x0->qs);
3542
+
3543
+ // 4-bit -> 8-bit
3544
+ const int8x16_t v0l = vreinterpretq_s8_u8(vandq_u8 (v0, vdupq_n_u8(0x0F)));
3545
+ const int8x16_t v0h = vreinterpretq_s8_u8(vshrq_n_u8(v0, 4));
3546
+
3547
+ // interleave
3548
+ const int8x16_t v0lz = vzip1q_s8(v0l, v0h);
3549
+ const int8x16_t v0hz = vzip2q_s8(v0l, v0h);
3550
+
3551
+ // add
3552
+ const int8x16_t v0lf = vorrq_s8(v0lz, qhl);
3553
+ const int8x16_t v0hf = vorrq_s8(v0hz, qhh);
3554
+
3555
+ // load y
3556
+ const int8x16_t v1l = vld1q_s8(y0->qs);
3557
+ const int8x16_t v1h = vld1q_s8(y0->qs + 16);
3558
+
3559
+ const float x0d = GGML_FP16_TO_FP32(x0->d);
2713
3560
 
2714
- const float f2 = d1*p1[2*j + 0];
2715
- const float f3 = d1*p1[2*j + 1];
3561
+ #if defined(__ARM_FEATURE_DOTPROD)
3562
+ sumv = vmlaq_n_f32(sumv, vcvtq_f32_s32(vaddq_s32(
3563
+ vdotq_s32(vdupq_n_s32(0), v0lf, v1l),
3564
+ vdotq_s32(vdupq_n_s32(0), v0hf, v1h))), x0d*y0->d);
3565
+ #else
3566
+ const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0lf), vget_low_s8 (v1l));
3567
+ const int16x8_t pl0h = vmull_s8(vget_high_s8(v0lf), vget_high_s8(v1l));
3568
+ const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0hf), vget_low_s8 (v1h));
3569
+ const int16x8_t ph0h = vmull_s8(vget_high_s8(v0hf), vget_high_s8(v1h));
2716
3570
 
2717
- sumf += f0*f2 + f1*f3;
2718
- }
2719
- }
3571
+ const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
3572
+ const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
3573
+
3574
+ sumv = vmlaq_n_f32(sumv, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), x0d*y0->d);
2720
3575
  #endif
3576
+ }
2721
3577
 
2722
- *s = sumf;
2723
- }
3578
+ *s = vaddvq_f32(sumv) + summs;
3579
+ #elif defined(__wasm_simd128__)
3580
+ v128_t sumv = wasm_f32x4_splat(0.0f);
2724
3581
 
2725
- static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
2726
- const int nb = n / QK8_0;
3582
+ float summs = 0.0f;
2727
3583
 
2728
- assert(n % QK8_0 == 0);
2729
- assert(nb % 2 == 0);
2730
- assert(QK8_0 == 2*QK4_2);
3584
+ uint64_t tmp[4];
2731
3585
 
2732
- const block_q4_2 * restrict x = vx;
2733
- const block_q8_0 * restrict y = vy;
3586
+ for (int i = 0; i < nb; ++i) {
3587
+ const block_q5_1 * restrict x0 = &x[i];
3588
+ const block_q8_1 * restrict y0 = &y[i];
2734
3589
 
2735
- float sumf = 0.0;
3590
+ summs += GGML_FP16_TO_FP32(x0->m) * (y0->s0 + y0->s1);
2736
3591
 
2737
- #if defined(__ARM_NEON)
2738
- float32x4_t sumv0 = vdupq_n_f32(0.0f);
2739
- float32x4_t sumv1 = vdupq_n_f32(0.0f);
3592
+ const v128_t m4b = wasm_i8x16_splat(0x0F);
2740
3593
 
2741
- for (int i = 0; i < nb; i += 2) {
2742
- const block_q4_2 * restrict x0_0 = &x[2*(i + 0) + 0];
2743
- const block_q4_2 * restrict x0_1 = &x[2*(i + 0) + 1];
2744
- const block_q4_2 * restrict x1_0 = &x[2*(i + 1) + 0];
2745
- const block_q4_2 * restrict x1_1 = &x[2*(i + 1) + 1];
3594
+ // extract the 5th bit
3595
+ uint32_t qh;
3596
+ memcpy(&qh, x0->qh, sizeof(qh));
2746
3597
 
2747
- const block_q8_0 * restrict y0 = &y[i + 0];
2748
- const block_q8_0 * restrict y1 = &y[i + 1];
3598
+ tmp[0] = table_b2b_u[(qh >> 0) & 0xFF];
3599
+ tmp[1] = table_b2b_u[(qh >> 8) & 0xFF];
3600
+ tmp[2] = table_b2b_u[(qh >> 16) & 0xFF];
3601
+ tmp[3] = table_b2b_u[(qh >> 24) ];
2749
3602
 
2750
- const uint8x16_t m4b = vdupq_n_u8(0xf);
2751
- const int8x16_t s8b = vdupq_n_s8(0x8);
3603
+ const v128_t qhl = wasm_v128_load(tmp + 0);
3604
+ const v128_t qhh = wasm_v128_load(tmp + 2);
2752
3605
 
2753
- const uint8x16_t v0_0 = vcombine_u8(vld1_u8(x0_0->qs), vld1_u8(x0_1->qs));
2754
- const uint8x16_t v0_1 = vcombine_u8(vld1_u8(x1_0->qs), vld1_u8(x1_1->qs));
3606
+ const v128_t v0 = wasm_v128_load(x0->qs);
2755
3607
 
2756
3608
  // 4-bit -> 8-bit
2757
- const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
2758
- const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
2759
- const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
2760
- const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
3609
+ const v128_t v0l = wasm_v128_and (v0, m4b);
3610
+ const v128_t v0h = wasm_u8x16_shr(v0, 4);
2761
3611
 
2762
- // sub 8
2763
- const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
2764
- const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
2765
- const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
2766
- const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
3612
+ static bool x = true;
2767
3613
 
2768
3614
  // interleave
2769
- const int8x16_t v0_0lz = vzip1q_s8(v0_0ls, v0_0hs);
2770
- const int8x16_t v0_0hz = vzip2q_s8(v0_0ls, v0_0hs);
2771
- const int8x16_t v0_1lz = vzip1q_s8(v0_1ls, v0_1hs);
2772
- const int8x16_t v0_1hz = vzip2q_s8(v0_1ls, v0_1hs);
2773
-
2774
- // load y
2775
- const int8x16_t v1_0l = vld1q_s8(y0->qs);
2776
- const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
2777
- const int8x16_t v1_1l = vld1q_s8(y1->qs);
2778
- const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
3615
+ const v128_t v0lz = wasm_v8x16_shuffle(v0l, v0h, 0, 16, 1, 17, 2, 18, 3, 19, 4, 20, 5, 21, 6, 22, 7, 23);
3616
+ const v128_t v0hz = wasm_v8x16_shuffle(v0l, v0h, 8, 24, 9, 25, 10, 26, 11, 27, 12, 28, 13, 29, 14, 30, 15, 31);
2779
3617
 
2780
- #if defined(__ARM_FEATURE_DOTPROD)
2781
- sumv0 = vmlaq_n_f32(sumv0, vaddq_f32(
2782
- vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0lz, v1_0l)), GGML_FP16_TO_FP32(x0_0->d)),
2783
- vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0hz, v1_0h)), GGML_FP16_TO_FP32(x0_1->d))), y0->d);
3618
+ // add high bit
3619
+ const v128_t v0lf = wasm_v128_or(v0lz, qhl);
3620
+ const v128_t v0hf = wasm_v128_or(v0hz, qhh);
2784
3621
 
2785
- sumv1 = vmlaq_n_f32(sumv1, vaddq_f32(
2786
- vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_1lz, v1_1l)), GGML_FP16_TO_FP32(x1_0->d)),
2787
- vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_1hz, v1_1h)), GGML_FP16_TO_FP32(x1_1->d))), y1->d);
2788
- #else
2789
- const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lz), vget_low_s8 (v1_0l));
2790
- const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lz), vget_high_s8(v1_0l));
2791
- const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hz), vget_low_s8 (v1_0h));
2792
- const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hz), vget_high_s8(v1_0h));
3622
+ // load y
3623
+ const v128_t v1l = wasm_v128_load(y0->qs);
3624
+ const v128_t v1h = wasm_v128_load(y0->qs + 16);
2793
3625
 
2794
- const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lz), vget_low_s8 (v1_1l));
2795
- const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lz), vget_high_s8(v1_1l));
2796
- const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hz), vget_low_s8 (v1_1h));
2797
- const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hz), vget_high_s8(v1_1h));
3626
+ // int8x16 -> int16x8
3627
+ const v128_t v0lfl = wasm_i16x8_extend_low_i8x16 (v0lf);
3628
+ const v128_t v0lfh = wasm_i16x8_extend_high_i8x16(v0lf);
3629
+ const v128_t v0hfl = wasm_i16x8_extend_low_i8x16 (v0hf);
3630
+ const v128_t v0hfh = wasm_i16x8_extend_high_i8x16(v0hf);
2798
3631
 
2799
- const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
2800
- const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
2801
- const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
2802
- const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
3632
+ const v128_t v1ll = wasm_i16x8_extend_low_i8x16 (v1l);
3633
+ const v128_t v1lh = wasm_i16x8_extend_high_i8x16(v1l);
3634
+ const v128_t v1hl = wasm_i16x8_extend_low_i8x16 (v1h);
3635
+ const v128_t v1hh = wasm_i16x8_extend_high_i8x16(v1h);
2803
3636
 
2804
- sumv0 = vmlaq_n_f32(sumv0, vaddq_f32(
2805
- vmulq_n_f32(vcvtq_f32_s32(pl0), GGML_FP16_TO_FP32(x0_0->d)),
2806
- vmulq_n_f32(vcvtq_f32_s32(ph0), GGML_FP16_TO_FP32(x0_1->d))), y0->d);
3637
+ const float x0d = GGML_FP16_TO_FP32(x0->d);
2807
3638
 
2808
- sumv1 = vmlaq_n_f32(sumv1, vaddq_f32(
2809
- vmulq_n_f32(vcvtq_f32_s32(pl1), GGML_FP16_TO_FP32(x1_0->d)),
2810
- vmulq_n_f32(vcvtq_f32_s32(ph1), GGML_FP16_TO_FP32(x1_1->d))), y1->d);
2811
- #endif
3639
+ // dot product
3640
+ sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4(
3641
+ wasm_i32x4_add(
3642
+ wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0lfl, v1ll),
3643
+ wasm_i32x4_dot_i16x8(v0lfh, v1lh)),
3644
+ wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0hfl, v1hl),
3645
+ wasm_i32x4_dot_i16x8(v0hfh, v1hh)))), wasm_f32x4_splat(x0d*y0->d)));
2812
3646
  }
2813
3647
 
2814
- sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
3648
+ *s = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) +
3649
+ wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3) + summs;
2815
3650
  #elif defined(__AVX2__)
2816
3651
  // Initialize accumulator with zeros
2817
3652
  __m256 acc = _mm256_setzero_ps();
3653
+ float summs = 0.0f;
2818
3654
 
2819
3655
  // Main loop
2820
3656
  for (int i = 0; i < nb; i++) {
2821
- /* Compute combined scale for the block */
2822
- const __m128 d0 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 0].d));
2823
- const __m128 d1 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 1].d));
2824
- const __m256 d = _mm256_mul_ps(_mm256_set_m128(d1, d0), _mm256_broadcast_ss(&y[i].d));
2825
-
2826
- __m128i bx0 = bytes_from_nibbles_16(x[2*i + 0].qs);
2827
- __m128i bx1 = bytes_from_nibbles_16(x[2*i + 1].qs);
2828
- __m256i bx = _mm256_set_m128i(bx1, bx0);
2829
-
2830
- // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
2831
- const __m256i off = _mm256_set1_epi8(8);
2832
- bx = _mm256_sub_epi8(bx, off);
3657
+ const __m256 dx = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d));
2833
3658
 
2834
- __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
3659
+ summs += GGML_FP16_TO_FP32(x[i].m) * (y[i].s0 + y[i].s1);
2835
3660
 
2836
- // Get absolute values of x vectors
2837
- const __m256i ax = _mm256_sign_epi8(bx, bx);
2838
- // Sign the values of the y vectors
2839
- const __m256i sy = _mm256_sign_epi8(by, bx);
2840
- // Perform multiplication and create 16-bit values
2841
- const __m256i dot = _mm256_maddubs_epi16(ax, sy);
3661
+ __m256i bx = bytes_from_nibbles_32(x[i].qs);
3662
+ __m256i bxhi = bytes_from_bits_32(x[i].qh);
3663
+ bxhi = _mm256_and_si256(bxhi, _mm256_set1_epi8(0x10));
3664
+ bx = _mm256_or_si256(bx, bxhi);
2842
3665
 
2843
- const __m256i ones = _mm256_set1_epi16(1);
2844
- __m256i xy_q = _mm256_madd_epi16(ones, dot);
3666
+ const __m256 dy = _mm256_broadcast_ss(&y[i].d);
3667
+ const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
2845
3668
 
2846
- /* Convert to vectore of 8 int32_t to 8 floats */
2847
- __m256 q = _mm256_cvtepi32_ps(xy_q);
3669
+ const __m256 q = mul_sum_i8_pairs_float(bx, by);
2848
3670
 
2849
- /* Multiply q with scale and accumulate */
2850
- acc = _mm256_fmadd_ps(d, q, acc);
3671
+ acc = _mm256_fmadd_ps(q, _mm256_mul_ps(dx, dy), acc);
2851
3672
  }
2852
3673
 
2853
- // Return horizontal sum of the acc vector
2854
- __m128 res = _mm256_extractf128_ps(acc, 1);
2855
- res = _mm_add_ps(res, _mm256_castps256_ps128(acc));
2856
- res = _mm_add_ps(res, _mm_movehl_ps(res, res));
2857
- res = _mm_add_ss(res, _mm_movehdup_ps(res));
2858
-
2859
- sumf = _mm_cvtss_f32(res);
3674
+ *s = hsum_float_8(acc) + summs;
2860
3675
  #else
2861
- // scalar
3676
+ float sumf = 0.0;
3677
+
2862
3678
  for (int i = 0; i < nb; i++) {
2863
- const uint8_t * restrict x0 = x[2*i + 0].qs;
2864
- const uint8_t * restrict x1 = x[2*i + 1].qs;
3679
+ const uint8_t * restrict x0 = x[i].qs;
2865
3680
  const int8_t * restrict y0 = y[i].qs;
2866
3681
 
2867
- const float d0 = GGML_FP16_TO_FP32(x[2*i + 0].d);
2868
- const float d1 = GGML_FP16_TO_FP32(x[2*i + 1].d);
3682
+ uint32_t qh;
3683
+ memcpy(&qh, x[i].qh, sizeof(qh));
2869
3684
 
2870
- int sumi_0 = 0;
2871
- int sumi_1 = 0;
3685
+ const float d = GGML_FP16_TO_FP32(x[i].d);
3686
+ const float m = GGML_FP16_TO_FP32(x[i].m);
2872
3687
 
2873
- for (int j = 0; j < QK8_0/4; j++) {
2874
- const uint8_t v0 = x0[j];
2875
- const uint8_t v1 = x1[j];
3688
+ int sxy = 0;
2876
3689
 
2877
- const int i0_0 = (int8_t) (v0 & 0xf) - 8;
2878
- const int i1_0 = (int8_t) (v0 >> 4) - 8;
3690
+ for (int j = 0; j < QK8_1/2; j++) {
3691
+ const uint8_t v0 = x0[j];
2879
3692
 
2880
- const int i0_1 = (int8_t) (v1 & 0xf) - 8;
2881
- const int i1_1 = (int8_t) (v1 >> 4) - 8;
3693
+ const int x0_0h = ((qh & (1u << (2*j + 0))) >> (2*j + 0)) << 4;
3694
+ const int x1_0h = ((qh & (1u << (2*j + 1))) >> (2*j + 1)) << 4;
2882
3695
 
2883
- const int i2_0 = y0[2*j + 0];
2884
- const int i3_0 = y0[2*j + 1];
3696
+ const int x0_0 = (v0 & 0x0F) | x0_0h;
3697
+ const int x1_0 = (v0 >> 4) | x1_0h;
2885
3698
 
2886
- const int i2_1 = y0[2*(j + QK8_0/4) + 0];
2887
- const int i3_1 = y0[2*(j + QK8_0/4) + 1];
3699
+ const int y0_0 = y0[2*j + 0];
3700
+ const int y1_0 = y0[2*j + 1];
2888
3701
 
2889
- sumi_0 += i0_0*i2_0 + i1_0*i3_0;
2890
- sumi_1 += i0_1*i2_1 + i1_1*i3_1;
3702
+ sxy += x0_0*y0_0 + x1_0*y1_0;
2891
3703
  }
2892
3704
 
2893
- sumf += (d0 * y[i].d) * sumi_0;
2894
- sumf += (d1 * y[i].d) * sumi_1;
3705
+ sumf += (d*sxy)*y[i].d + m*(y[i].s0 + y[i].s1);
2895
3706
  }
2896
- #endif
2897
3707
 
2898
3708
  *s = sumf;
3709
+ #endif
2899
3710
  }
2900
3711
 
2901
- static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
3712
+ static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
2902
3713
  const int nb = n / QK8_0;
2903
3714
 
2904
3715
  assert(n % QK8_0 == 0);
2905
3716
  assert(nb % 2 == 0);
2906
- assert(QK8_0 == 2*QK4_2);
3717
+ assert(QK8_0 == QK8_0);
2907
3718
 
2908
- const block_q4_3 * restrict x = vx;
3719
+ const block_q8_0 * restrict x = vx;
2909
3720
  const block_q8_0 * restrict y = vy;
2910
3721
 
2911
- float sumf = 0.0;
2912
-
2913
3722
  #if defined(__ARM_NEON)
2914
3723
  float32x4_t sumv0 = vdupq_n_f32(0.0f);
2915
3724
  float32x4_t sumv1 = vdupq_n_f32(0.0f);
2916
3725
 
2917
3726
  for (int i = 0; i < nb; i += 2) {
2918
- const block_q4_3 * restrict x0_0 = &x[2*(i + 0) + 0];
2919
- const block_q4_3 * restrict x0_1 = &x[2*(i + 0) + 1];
2920
- const block_q4_3 * restrict x1_0 = &x[2*(i + 1) + 0];
2921
- const block_q4_3 * restrict x1_1 = &x[2*(i + 1) + 1];
2922
-
3727
+ const block_q8_0 * restrict x0 = &x[i + 0];
3728
+ const block_q8_0 * restrict x1 = &x[i + 1];
2923
3729
  const block_q8_0 * restrict y0 = &y[i + 0];
2924
3730
  const block_q8_0 * restrict y1 = &y[i + 1];
2925
3731
 
2926
- const uint8x16_t m4b = vdupq_n_u8(0xf);
2927
-
2928
- const float x0_0d = GGML_FP16_TO_FP32(x0_0->d);
2929
- const float x0_1d = GGML_FP16_TO_FP32(x0_1->d);
2930
- const float x1_0d = GGML_FP16_TO_FP32(x1_0->d);
2931
- const float x1_1d = GGML_FP16_TO_FP32(x1_1->d);
2932
-
2933
- const float x0_0m = GGML_FP16_TO_FP32(x0_0->m);
2934
- const float x0_1m = GGML_FP16_TO_FP32(x0_1->m);
2935
- const float x1_0m = GGML_FP16_TO_FP32(x1_0->m);
2936
- const float x1_1m = GGML_FP16_TO_FP32(x1_1->m);
2937
-
2938
- const uint8x16_t v0_0 = vcombine_u8(vld1_u8(x0_0->qs), vld1_u8(x0_1->qs));
2939
- const uint8x16_t v0_1 = vcombine_u8(vld1_u8(x1_0->qs), vld1_u8(x1_1->qs));
2940
-
2941
- // 4-bit -> 8-bit
2942
- const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
2943
- const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
2944
- const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
2945
- const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
2946
-
2947
- // interleave
2948
- const int8x16_t v0_0lz = vzip1q_s8(v0_0l, v0_0h);
2949
- const int8x16_t v0_0hz = vzip2q_s8(v0_0l, v0_0h);
2950
- const int8x16_t v0_1lz = vzip1q_s8(v0_1l, v0_1h);
2951
- const int8x16_t v0_1hz = vzip2q_s8(v0_1l, v0_1h);
3732
+ const int8x16_t x0_0 = vld1q_s8(x0->qs);
3733
+ const int8x16_t x0_1 = vld1q_s8(x0->qs + 16);
3734
+ const int8x16_t x1_0 = vld1q_s8(x1->qs);
3735
+ const int8x16_t x1_1 = vld1q_s8(x1->qs + 16);
2952
3736
 
2953
3737
  // load y
2954
- const int8x16_t v1_0l = vld1q_s8(y0->qs);
2955
- const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
2956
- const int8x16_t v1_1l = vld1q_s8(y1->qs);
2957
- const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
2958
-
2959
- const int16x8_t sy0_0 = vaddq_s16(vmovl_s8(vget_low_s8(v1_0l)), vmovl_s8(vget_high_s8(v1_0l)));
2960
- const int16x8_t sy0_1 = vaddq_s16(vmovl_s8(vget_low_s8(v1_0h)), vmovl_s8(vget_high_s8(v1_0h)));
2961
-
2962
- const int16x8_t sy1_0 = vaddq_s16(vmovl_s8(vget_low_s8(v1_1l)), vmovl_s8(vget_high_s8(v1_1l)));
2963
- const int16x8_t sy1_1 = vaddq_s16(vmovl_s8(vget_low_s8(v1_1h)), vmovl_s8(vget_high_s8(v1_1h)));
2964
-
2965
- sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddl_s16(vget_low_s16(sy0_0), vget_high_s16(sy0_0))), x0_0m*y0->d);
2966
- sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddl_s16(vget_low_s16(sy0_1), vget_high_s16(sy0_1))), x0_1m*y0->d);
2967
- sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddl_s16(vget_low_s16(sy1_0), vget_high_s16(sy1_0))), x1_0m*y1->d);
2968
- sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddl_s16(vget_low_s16(sy1_1), vget_high_s16(sy1_1))), x1_1m*y1->d);
3738
+ const int8x16_t y0_0 = vld1q_s8(y0->qs);
3739
+ const int8x16_t y0_1 = vld1q_s8(y0->qs + 16);
3740
+ const int8x16_t y1_0 = vld1q_s8(y1->qs);
3741
+ const int8x16_t y1_1 = vld1q_s8(y1->qs + 16);
2969
3742
 
2970
3743
  #if defined(__ARM_FEATURE_DOTPROD)
2971
- sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0lz, v1_0l)), x0_0d*y0->d);
2972
- sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0hz, v1_0h)), x0_1d*y0->d);
2973
- sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_1lz, v1_1l)), x1_0d*y1->d);
2974
- sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_1hz, v1_1h)), x1_1d*y1->d);
2975
- #else
2976
- const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lz), vget_low_s8 (v1_0l));
2977
- const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lz), vget_high_s8(v1_0l));
2978
- const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hz), vget_low_s8 (v1_0h));
2979
- const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hz), vget_high_s8(v1_0h));
2980
-
2981
- const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lz), vget_low_s8 (v1_1l));
2982
- const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lz), vget_high_s8(v1_1l));
2983
- const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hz), vget_low_s8 (v1_1h));
2984
- const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hz), vget_high_s8(v1_1h));
3744
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
3745
+ vdotq_s32(vdupq_n_s32(0), x0_0, y0_0),
3746
+ vdotq_s32(vdupq_n_s32(0), x0_1, y0_1))), x0->d*y0->d);
2985
3747
 
2986
- const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
2987
- const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
2988
- const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
2989
- const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
3748
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
3749
+ vdotq_s32(vdupq_n_s32(0), x1_0, y1_0),
3750
+ vdotq_s32(vdupq_n_s32(0), x1_1, y1_1))), x1->d*y1->d);
2990
3751
 
2991
- sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(pl0), x0_0d*y0->d);
2992
- sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(ph0), x0_1d*y0->d);
2993
- sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(pl1), x1_0d*y1->d);
2994
- sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(ph1), x1_1d*y1->d);
3752
+ #else
3753
+ const int16x8_t p0_0 = vmull_s8(vget_low_s8 (x0_0), vget_low_s8 (y0_0));
3754
+ const int16x8_t p0_1 = vmull_s8(vget_high_s8(x0_0), vget_high_s8(y0_0));
3755
+ const int16x8_t p0_2 = vmull_s8(vget_low_s8 (x0_1), vget_low_s8 (y0_1));
3756
+ const int16x8_t p0_3 = vmull_s8(vget_high_s8(x0_1), vget_high_s8(y0_1));
3757
+
3758
+ const int16x8_t p1_0 = vmull_s8(vget_low_s8 (x1_0), vget_low_s8 (y1_0));
3759
+ const int16x8_t p1_1 = vmull_s8(vget_high_s8(x1_0), vget_high_s8(y1_0));
3760
+ const int16x8_t p1_2 = vmull_s8(vget_low_s8 (x1_1), vget_low_s8 (y1_1));
3761
+ const int16x8_t p1_3 = vmull_s8(vget_high_s8(x1_1), vget_high_s8(y1_1));
3762
+
3763
+ const int32x4_t p0 = vaddq_s32(vpaddlq_s16(p0_0), vpaddlq_s16(p0_1));
3764
+ const int32x4_t p1 = vaddq_s32(vpaddlq_s16(p0_2), vpaddlq_s16(p0_3));
3765
+ const int32x4_t p2 = vaddq_s32(vpaddlq_s16(p1_0), vpaddlq_s16(p1_1));
3766
+ const int32x4_t p3 = vaddq_s32(vpaddlq_s16(p1_2), vpaddlq_s16(p1_3));
3767
+
3768
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(p0, p1)), x0->d*y0->d);
3769
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(p2, p3)), x1->d*y1->d);
2995
3770
  #endif
2996
3771
  }
2997
3772
 
2998
- sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
2999
- #else
3000
- // scalar
3001
- for (int i = 0; i < nb; i++) {
3002
- const uint8_t * restrict x0 = x[2*i + 0].qs;
3003
- const uint8_t * restrict x1 = x[2*i + 1].qs;
3004
- const int8_t * restrict y0 = y[i].qs;
3005
-
3006
- const float d0 = GGML_FP16_TO_FP32(x[2*i + 0].d);
3007
- const float m0 = GGML_FP16_TO_FP32(x[2*i + 0].m);
3008
- const float d1 = GGML_FP16_TO_FP32(x[2*i + 1].d);
3009
- const float m1 = GGML_FP16_TO_FP32(x[2*i + 1].m);
3010
-
3011
- int sy_0 = 0;
3012
- int sy_1 = 0;
3773
+ *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
3774
+ #elif defined(__AVX2__)
3775
+ // Initialize accumulator with zeros
3776
+ __m256 acc = _mm256_setzero_ps();
3013
3777
 
3014
- int sxy_0 = 0;
3015
- int sxy_1 = 0;
3778
+ // Main loop
3779
+ for (int i = 0; i < nb; ++i) {
3780
+ // Compute combined scale for the block
3781
+ const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );
3782
+ __m256i bx = _mm256_loadu_si256((const __m256i *)x[i].qs);
3783
+ __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
3016
3784
 
3017
- for (int j = 0; j < QK8_0/4; j++) {
3018
- const uint8_t v0 = x0[j];
3019
- const uint8_t v1 = x1[j];
3785
+ const __m256 q = mul_sum_i8_pairs_float(bx, by);
3020
3786
 
3021
- const int x0_0 = v0 & 0xf;
3022
- const int x1_0 = v0 >> 4;
3787
+ // Multiply q with scale and accumulate
3788
+ acc = _mm256_fmadd_ps( d, q, acc );
3789
+ }
3023
3790
 
3024
- const int x0_1 = v1 & 0xf;
3025
- const int x1_1 = v1 >> 4;
3791
+ *s = hsum_float_8(acc);
3792
+ #else
3793
+ // scalar
3794
+ float sumf = 0.0;
3026
3795
 
3027
- const int y0_0 = y0[2*j + 0];
3028
- const int y1_0 = y0[2*j + 1];
3796
+ for (int i = 0; i < nb; i++) {
3797
+ const int8_t * restrict x0 = x[i].qs;
3798
+ const int8_t * restrict y0 = y[i].qs;
3029
3799
 
3030
- const int y0_1 = y0[2*(j + QK8_0/4) + 0];
3031
- const int y1_1 = y0[2*(j + QK8_0/4) + 1];
3800
+ int sumi = 0;
3032
3801
 
3033
- sy_0 += y0_0 + y1_0;
3034
- sy_1 += y0_1 + y1_1;
3802
+ for (int j = 0; j < QK8_0; j++) {
3803
+ const int v0 = x0[j];
3804
+ const int v1 = y0[j];
3035
3805
 
3036
- sxy_0 += x0_0*y0_0 + x1_0*y1_0;
3037
- sxy_1 += x0_1*y0_1 + x1_1*y1_1;
3806
+ sumi += v0*v1;
3038
3807
  }
3039
3808
 
3040
- sumf += (d0*sxy_0 + m0*sy_0)*y[i].d;
3041
- sumf += (d1*sxy_1 + m1*sy_1)*y[i].d;
3809
+ sumf += (x[i].d*y[i].d)*sumi;
3042
3810
  }
3043
- #endif
3044
3811
 
3045
3812
  *s = sumf;
3813
+ #endif
3046
3814
  }
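// For reference, a minimal scalar model of the __ARM_FEATURE_DOTPROD path above
// (illustrative sketch only, not part of the diff): vdotq_s32 fills four int32 lanes,
// each accumulating four adjacent int8*int8 products, so two calls plus a vaddq_s32
// reduce one 32-element Q8_0 block to four partial sums before the float scale
// x0->d*y0->d is applied.
static inline void sdot_lanes_sketch(int32_t acc[4], const int8_t a[16], const int8_t b[16]) {
    for (int l = 0; l < 4; ++l) {
        for (int j = 0; j < 4; ++j) {
            acc[l] += (int32_t) a[4*l + j] * (int32_t) b[4*l + j];
        }
    }
}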
3047
3815
 
3048
-
3049
3816
  // compute GGML_VEC_DOT_UNROLL dot products at once
3050
3817
  // xs - x row stride in bytes
3051
3818
  inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * restrict s, void * restrict xv, ggml_fp16_t * restrict y) {
@@ -3242,6 +4009,14 @@ inline static void ggml_vec_sum_f32(const int n, float * s, const float * x) {
3242
4009
  #endif
3243
4010
  }
3244
4011
 
4012
+ inline static void ggml_vec_sum_ggf(const int n, ggml_float * s, const float * x) {
4013
+ ggml_float sum = 0.0;
4014
+ for (int i = 0; i < n; ++i) {
4015
+ sum += (ggml_float)x[i];
4016
+ }
4017
+ *s = sum;
4018
+ }
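// Minimal sketch (hedged) of how ggml_vec_sum_ggf is meant to be used: it accumulates in
// ggml_float (a double, per the typedef near the top of the file), so summing many rows
// does not lose precision the way a single f32 accumulator can. The helper name and layout
// below are illustrative, mirroring ggml_compute_forward_sum_f32 further down in this diff.
static float sum_rows_sketch(const float * data, int ne00, int64_t nrows) {
    ggml_float sum = 0;
    for (int64_t r = 0; r < nrows; ++r) {
        ggml_float row_sum = 0;
        ggml_vec_sum_ggf(ne00, &row_sum, data + r*ne00);  // per-row sum, accumulated in double
        sum += row_sum;
    }
    return (float) sum;  // stored back as f32, as the forward_sum kernel does
}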
4019
+
3245
4020
  inline static void ggml_vec_max_f32(const int n, float * s, const float * x) {
3246
4021
  #ifndef GGML_USE_ACCELERATE
3247
4022
  float max = -INFINITY;
@@ -3293,13 +4068,15 @@ static const int GGML_BLCK_SIZE[GGML_TYPE_COUNT] = {
3293
4068
  [GGML_TYPE_Q4_0] = QK4_0,
3294
4069
  [GGML_TYPE_Q4_1] = QK4_1,
3295
4070
  [GGML_TYPE_Q4_2] = QK4_2,
3296
- [GGML_TYPE_Q4_3] = QK4_3,
4071
+ [GGML_TYPE_Q5_0] = QK5_0,
4072
+ [GGML_TYPE_Q5_1] = QK5_1,
3297
4073
  [GGML_TYPE_Q8_0] = QK8_0,
4074
+ [GGML_TYPE_Q8_1] = QK8_1,
3298
4075
  [GGML_TYPE_I8] = 1,
3299
4076
  [GGML_TYPE_I16] = 1,
3300
4077
  [GGML_TYPE_I32] = 1,
3301
4078
  };
3302
- static_assert(GGML_TYPE_COUNT == 10, "GGML_BLCK_SIZE is outdated");
4079
+ static_assert(GGML_TYPE_COUNT == 13, "GGML_BLCK_SIZE is outdated");
3303
4080
 
3304
4081
  static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = {
3305
4082
  [GGML_TYPE_F32] = sizeof(float),
@@ -3307,13 +4084,15 @@ static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = {
3307
4084
  [GGML_TYPE_Q4_0] = sizeof(block_q4_0),
3308
4085
  [GGML_TYPE_Q4_1] = sizeof(block_q4_1),
3309
4086
  [GGML_TYPE_Q4_2] = sizeof(block_q4_2),
3310
- [GGML_TYPE_Q4_3] = sizeof(block_q4_3),
4087
+ [GGML_TYPE_Q5_0] = sizeof(block_q5_0),
4088
+ [GGML_TYPE_Q5_1] = sizeof(block_q5_1),
3311
4089
  [GGML_TYPE_Q8_0] = sizeof(block_q8_0),
4090
+ [GGML_TYPE_Q8_1] = sizeof(block_q8_1),
3312
4091
  [GGML_TYPE_I8] = sizeof(int8_t),
3313
4092
  [GGML_TYPE_I16] = sizeof(int16_t),
3314
4093
  [GGML_TYPE_I32] = sizeof(int32_t),
3315
4094
  };
3316
- static_assert(GGML_TYPE_COUNT == 10, "GGML_TYPE_SIZE is outdated");
4095
+ static_assert(GGML_TYPE_COUNT == 13, "GGML_TYPE_SIZE is outdated");
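// These per-type tables drive the size arithmetic used throughout the file: storing n
// elements of type t takes n * GGML_TYPE_SIZE[t] / GGML_BLCK_SIZE[t] bytes, the same
// expression the row_size computations in ggml_compute_forward_mul_mat_q_f32 use further
// down. Illustrative helper (not part of the API), assuming n is a multiple of the block size:
static size_t row_bytes_sketch(enum ggml_type t, int64_t n) {
    return n*GGML_TYPE_SIZE[t]/GGML_BLCK_SIZE[t];
}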
3317
4096
 
3318
4097
 
3319
4098
  static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = {
@@ -3322,13 +4101,15 @@ static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = {
3322
4101
  [GGML_TYPE_Q4_0] = "q4_0",
3323
4102
  [GGML_TYPE_Q4_1] = "q4_1",
3324
4103
  [GGML_TYPE_Q4_2] = "q4_2",
3325
- [GGML_TYPE_Q4_3] = "q4_3",
4104
+ [GGML_TYPE_Q5_0] = "q5_0",
4105
+ [GGML_TYPE_Q5_1] = "q5_1",
3326
4106
  [GGML_TYPE_Q8_0] = "q8_0",
4107
+ [GGML_TYPE_Q8_1] = "q8_1",
3327
4108
  [GGML_TYPE_I8] = "i8",
3328
4109
  [GGML_TYPE_I16] = "i16",
3329
4110
  [GGML_TYPE_I32] = "i32",
3330
4111
  };
3331
- static_assert(GGML_TYPE_COUNT == 10, "GGML_TYPE_NAME is outdated");
4112
+ static_assert(GGML_TYPE_COUNT == 13, "GGML_TYPE_NAME is outdated");
3332
4113
 
3333
4114
  static bool GGML_IS_QUANTIZED[GGML_TYPE_COUNT] = {
3334
4115
  [GGML_TYPE_F32] = false,
@@ -3336,13 +4117,15 @@ static bool GGML_IS_QUANTIZED[GGML_TYPE_COUNT] = {
3336
4117
  [GGML_TYPE_Q4_0] = true,
3337
4118
  [GGML_TYPE_Q4_1] = true,
3338
4119
  [GGML_TYPE_Q4_2] = true,
3339
- [GGML_TYPE_Q4_3] = true,
4120
+ [GGML_TYPE_Q5_0] = true,
4121
+ [GGML_TYPE_Q5_1] = true,
3340
4122
  [GGML_TYPE_Q8_0] = true,
4123
+ [GGML_TYPE_Q8_1] = true,
3341
4124
  [GGML_TYPE_I8] = false,
3342
4125
  [GGML_TYPE_I16] = false,
3343
4126
  [GGML_TYPE_I32] = false,
3344
4127
  };
3345
- static_assert(GGML_TYPE_COUNT == 10, "GGML_IS_QUANTIZED is outdated");
4128
+ static_assert(GGML_TYPE_COUNT == 13, "GGML_IS_QUANTIZED is outdated");
3346
4129
 
3347
4130
  static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
3348
4131
  "NONE",
@@ -3380,6 +4163,7 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
3380
4163
  "DIAG_MASK_INF",
3381
4164
  "SOFT_MAX",
3382
4165
  "ROPE",
4166
+ "ALIBI",
3383
4167
  "CONV_1D_1S",
3384
4168
  "CONV_1D_2S",
3385
4169
 
@@ -3390,7 +4174,7 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
3390
4174
  "MAP_BINARY",
3391
4175
  };
3392
4176
 
3393
- static_assert(GGML_OP_COUNT == 38, "GGML_OP_COUNT != 38");
4177
+ static_assert(GGML_OP_COUNT == 39, "GGML_OP_COUNT != 39");
3394
4178
 
3395
4179
  static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
3396
4180
  "none",
@@ -3428,6 +4212,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
3428
4212
  "diag_mask_inf(x)",
3429
4213
  "soft_max(x)",
3430
4214
  "rope(x)",
4215
+ "alibi(x)",
3431
4216
  "conv_1d_1s(x)",
3432
4217
  "conv_1d_2s(x)",
3433
4218
 
@@ -3438,7 +4223,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
3438
4223
  "f(x,y)",
3439
4224
  };
3440
4225
 
3441
- static_assert(GGML_OP_COUNT == 38, "GGML_OP_COUNT != 38");
4226
+ static_assert(GGML_OP_COUNT == 39, "GGML_OP_COUNT != 39");
3442
4227
 
3443
4228
  static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
3444
4229
  static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN");
@@ -3608,6 +4393,27 @@ bool ggml_is_quantized(enum ggml_type type) {
3608
4393
  return GGML_IS_QUANTIZED[type];
3609
4394
  }
3610
4395
 
4396
+ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
4397
+ enum ggml_type wtype = GGML_TYPE_COUNT;
4398
+
4399
+ switch (ftype) {
4400
+ case GGML_FTYPE_ALL_F32: wtype = GGML_TYPE_F32; break;
4401
+ case GGML_FTYPE_MOSTLY_F16: wtype = GGML_TYPE_F16; break;
4402
+ case GGML_FTYPE_MOSTLY_Q4_0: wtype = GGML_TYPE_Q4_0; break;
4403
+ case GGML_FTYPE_MOSTLY_Q4_1: wtype = GGML_TYPE_Q4_1; break;
4404
+ case GGML_FTYPE_MOSTLY_Q4_2: wtype = GGML_TYPE_Q4_2; break;
4405
+ case GGML_FTYPE_MOSTLY_Q5_0: wtype = GGML_TYPE_Q5_0; break;
4406
+ case GGML_FTYPE_MOSTLY_Q5_1: wtype = GGML_TYPE_Q5_1; break;
4407
+ case GGML_FTYPE_MOSTLY_Q8_0: wtype = GGML_TYPE_Q8_0; break;
4408
+ case GGML_FTYPE_UNKNOWN: wtype = GGML_TYPE_COUNT; break;
4409
+ case GGML_FTYPE_MOSTLY_Q4_1_SOME_F16: wtype = GGML_TYPE_COUNT; break;
4410
+ }
4411
+
4412
+ GGML_ASSERT(wtype != GGML_TYPE_COUNT);
4413
+
4414
+ return wtype;
4415
+ }
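// Hedged usage sketch: callers that read a whole-file ftype from a model header can map it
// to the tensor type used for the bulk of the weights. The wrapper below is illustrative.
static enum ggml_type weight_type_sketch(void) {
    const enum ggml_ftype ftype = GGML_FTYPE_MOSTLY_Q5_1;  // e.g. parsed from a model file (hypothetical)
    return ggml_ftype_to_ggml_type(ftype);                 // -> GGML_TYPE_Q5_1
}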
4416
+
3611
4417
  static inline bool ggml_is_transposed(const struct ggml_tensor * tensor) {
3612
4418
  return tensor->nb[0] > tensor->nb[1];
3613
4419
  }
@@ -3718,10 +4524,11 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
3718
4524
  GGML_PRINT_DEBUG("%s: g_state initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f);
3719
4525
  }
3720
4526
 
3721
- // initialize cuBLAS
3722
- #if defined(GGML_USE_CUBLAS)
3723
- init_cublas();
3724
- #endif
4527
+ #if defined(GGML_USE_CUBLAS)
4528
+ ggml_init_cublas();
4529
+ #elif defined(GGML_USE_CLBLAST)
4530
+ ggml_cl_init();
4531
+ #endif
3725
4532
 
3726
4533
  is_first_call = false;
3727
4534
  }
@@ -3802,7 +4609,7 @@ void ggml_free(struct ggml_context * ctx) {
3802
4609
  }
3803
4610
 
3804
4611
  size_t ggml_used_mem(const struct ggml_context * ctx) {
3805
- return ctx->objects_end->offs + ctx->objects_end->size;
4612
+ return ctx->objects_end == NULL ? 0 : ctx->objects_end->offs + ctx->objects_end->size;
3806
4613
  }
3807
4614
 
3808
4615
  size_t ggml_set_scratch(struct ggml_context * ctx, struct ggml_scratch scratch) {
@@ -3915,6 +4722,7 @@ struct ggml_tensor * ggml_new_tensor_impl(
3915
4722
  /*.perf_cycles =*/ 0,
3916
4723
  /*.perf_time_us =*/ 0,
3917
4724
  /*.data =*/ (data == NULL && !ctx->no_alloc) ? (void *)(result + 1) : data,
4725
+ /*.name =*/ { 0 },
3918
4726
  /*.pad =*/ { 0 },
3919
4727
  };
3920
4728
 
@@ -4269,6 +5077,15 @@ float * ggml_get_data_f32(const struct ggml_tensor * tensor) {
4269
5077
  return (float *)(tensor->data);
4270
5078
  }
4271
5079
 
5080
+ const char * ggml_get_name(const struct ggml_tensor * tensor) {
5081
+ return tensor->name;
5082
+ }
5083
+
5084
+ void ggml_set_name(struct ggml_tensor * tensor, const char * name) {
5085
+ strncpy(tensor->name, name, sizeof(tensor->name));
5086
+ tensor->name[sizeof(tensor->name) - 1] = '\0';
5087
+ }
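// Usage sketch (hedged): the tensor creation call, context and label are illustrative. The
// name is truncated to the fixed-size field and later shows up in the ggml_graph_dump_dot
// output changed further down in this diff.
static void name_tensor_sketch(struct ggml_context * ctx) {
    struct ggml_tensor * t = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1024);
    ggml_set_name(t, "layer0.attn_output");
    printf("%s\n", ggml_get_name(t));  // -> "layer0.attn_output"
}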
5088
+
4272
5089
  struct ggml_tensor * ggml_view_tensor(
4273
5090
  struct ggml_context * ctx,
4274
5091
  const struct ggml_tensor * src) {
@@ -5368,6 +6185,7 @@ struct ggml_tensor * ggml_diag_mask_inf(
5368
6185
  //struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
5369
6186
  struct ggml_tensor * result = ggml_view_tensor(ctx, a);
5370
6187
  struct ggml_tensor * b = ggml_new_i32(ctx, n_past);
6188
+ ggml_set_name(b, "n_past");
5371
6189
 
5372
6190
  result->op = GGML_OP_DIAG_MASK_INF;
5373
6191
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@@ -5393,22 +6211,55 @@ struct ggml_tensor * ggml_soft_max(
5393
6211
  //struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
5394
6212
  struct ggml_tensor * result = ggml_view_tensor(ctx, a);
5395
6213
 
5396
- result->op = GGML_OP_SOFT_MAX;
6214
+ result->op = GGML_OP_SOFT_MAX;
6215
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6216
+ result->src0 = a;
6217
+ result->src1 = NULL;
6218
+
6219
+ return result;
6220
+ }
6221
+
6222
+ // ggml_rope
6223
+
6224
+ struct ggml_tensor * ggml_rope(
6225
+ struct ggml_context * ctx,
6226
+ struct ggml_tensor * a,
6227
+ int n_past,
6228
+ int n_dims,
6229
+ int mode) {
6230
+ GGML_ASSERT(n_past >= 0);
6231
+ bool is_node = false;
6232
+
6233
+ if (a->grad) {
6234
+ GGML_ASSERT(false); // TODO: implement backward
6235
+ is_node = true;
6236
+ }
6237
+
6238
+ // TODO: when implement backward, fix this:
6239
+ //struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
6240
+ struct ggml_tensor * result = ggml_view_tensor(ctx, a);
6241
+
6242
+ struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 3);
6243
+ ((int32_t *) b->data)[0] = n_past;
6244
+ ((int32_t *) b->data)[1] = n_dims;
6245
+ ((int32_t *) b->data)[2] = mode;
6246
+ ggml_set_name(b, "n_past, n_dims, mode");
6247
+
6248
+ result->op = GGML_OP_ROPE;
5397
6249
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5398
6250
  result->src0 = a;
5399
- result->src1 = NULL;
6251
+ result->src1 = b;
5400
6252
 
5401
6253
  return result;
5402
6254
  }
5403
6255
 
5404
- // ggml_rope
6256
+ // ggml_alibi
5405
6257
 
5406
- struct ggml_tensor * ggml_rope(
6258
+ struct ggml_tensor * ggml_alibi(
5407
6259
  struct ggml_context * ctx,
5408
6260
  struct ggml_tensor * a,
5409
6261
  int n_past,
5410
- int n_dims,
5411
- int mode) {
6262
+ int n_head) {
5412
6263
  GGML_ASSERT(n_past >= 0);
5413
6264
  bool is_node = false;
5414
6265
 
@@ -5421,12 +6272,11 @@ struct ggml_tensor * ggml_rope(
5421
6272
  //struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
5422
6273
  struct ggml_tensor * result = ggml_view_tensor(ctx, a);
5423
6274
 
5424
- struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 3);
6275
+ struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 2);
5425
6276
  ((int32_t *) b->data)[0] = n_past;
5426
- ((int32_t *) b->data)[1] = n_dims;
5427
- ((int32_t *) b->data)[2] = mode;
6277
+ ((int32_t *) b->data)[1] = n_head;
5428
6278
 
5429
- result->op = GGML_OP_ROPE;
6279
+ result->op = GGML_OP_ALIBI;
5430
6280
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5431
6281
  result->src0 = a;
5432
6282
  result->src1 = b;
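A hedged usage sketch: the compute kernel added further down describes src0 as KQ_scaled, so ggml_alibi slots into an attention graph where a RoPE-based model would rotate Q/K instead; BLOOM-style models add the bias to the scaled attention scores. The context and tensor names below are illustrative:

    struct ggml_tensor * KQ_biased = ggml_alibi(ctx0, KQ_scaled, n_past, n_head);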
@@ -6553,7 +7403,9 @@ static void ggml_compute_forward_add(
6553
7403
  case GGML_TYPE_Q4_0:
6554
7404
  case GGML_TYPE_Q4_1:
6555
7405
  case GGML_TYPE_Q4_2:
6556
- case GGML_TYPE_Q4_3:
7406
+ case GGML_TYPE_Q5_0:
7407
+ case GGML_TYPE_Q5_1:
7408
+ case GGML_TYPE_Q8_0:
6557
7409
  {
6558
7410
  ggml_compute_forward_add_q_f32(params, src0, src1, dst);
6559
7411
  } break;
@@ -6811,15 +7663,20 @@ static void ggml_compute_forward_sum_f32(
6811
7663
  const size_t nb02 = src0->nb[2];
6812
7664
  const size_t nb03 = src0->nb[3];
6813
7665
 
7666
+ ggml_float sum = 0;
7667
+ ggml_float row_sum = 0;
7668
+
6814
7669
  for (int64_t i03 = 0; i03 < ne03; i03++) {
6815
7670
  for (int64_t i02 = 0; i02 < ne02; i02++) {
6816
7671
  for (int64_t i01 = 0; i01 < ne01; i01++) {
6817
- ggml_vec_sum_f32(ne00,
6818
- (float *) (dst->data),
7672
+ ggml_vec_sum_ggf(ne00,
7673
+ &row_sum,
6819
7674
  (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03));
7675
+ sum += row_sum;
6820
7676
  }
6821
7677
  }
6822
7678
  }
7679
+ ((float *) dst->data)[0] = sum;
6823
7680
  }
6824
7681
 
6825
7682
  static void ggml_compute_forward_sum(
@@ -7454,7 +8311,7 @@ static void ggml_compute_forward_rms_norm(
7454
8311
 
7455
8312
  // ggml_compute_forward_mul_mat
7456
8313
 
7457
- #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
8314
+ #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
7458
8315
  // helper function to determine if it is better to use BLAS or not
7459
8316
  // for large matrices, BLAS is faster
7460
8317
  static bool ggml_compute_forward_mul_mat_use_blas(
@@ -7471,7 +8328,8 @@ static bool ggml_compute_forward_mul_mat_use_blas(
7471
8328
 
7472
8329
  // TODO: find the optimal values for these
7473
8330
  if (ggml_is_contiguous(src0) &&
7474
- ggml_is_contiguous(src1) && ((ne0 >= 32 && ne1 >= 32 && ne10 >= 32))) {
8331
+ ggml_is_contiguous(src1) &&
8332
+ (ne0 >= 32 && ne1 >= 32 && ne10 >= 32)) {
7475
8333
 
7476
8334
  /*printf("BLAS: %d %d %d %d %d\n", ne0, ne1, ne10, ne00, ne01);*/
7477
8335
  return true;
@@ -7494,7 +8352,7 @@ static void ggml_compute_forward_mul_mat_f32(
7494
8352
  const int64_t ne02 = src0->ne[2];
7495
8353
  const int64_t ne03 = src0->ne[3];
7496
8354
 
7497
- #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
8355
+ #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
7498
8356
  const int64_t ne10 = src1->ne[0];
7499
8357
  #endif
7500
8358
  const int64_t ne11 = src1->ne[1];
@@ -7551,7 +8409,16 @@ static void ggml_compute_forward_mul_mat_f32(
7551
8409
  // nb01 >= nb00 - src0 is not transposed
7552
8410
  // compute by src0 rows
7553
8411
 
7554
- #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
8412
+ #if defined(GGML_USE_CUBLAS)
8413
+ if (ggml_cuda_can_mul_mat(src0, src1, dst)) {
8414
+ if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
8415
+ ggml_cuda_mul_mat(src0, src1, dst, params->wdata, params->wsize);
8416
+ }
8417
+ return;
8418
+ }
8419
+ #endif
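// Note on the guard above (hedged reading of the diff): only thread 0 runs the CUDA matmul,
// and only during GGML_TASK_COMPUTE; every thread then returns, so the CPU BLAS and
// hand-written paths below are skipped entirely. The same guard is repeated in the f16 and
// quantized mul_mat variants later in this diff, replacing the inline cuBLAS blocks that
// were removed.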
8420
+
8421
+ #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
7555
8422
  if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
7556
8423
  if (params->ith != 0) {
7557
8424
  return;
@@ -7565,45 +8432,21 @@ static void ggml_compute_forward_mul_mat_f32(
7565
8432
  return;
7566
8433
  }
7567
8434
 
7568
- #if defined(GGML_USE_CUBLAS)
7569
- float *d_X = NULL;
7570
- float *d_Y = NULL;
7571
- float *d_D = NULL;
7572
- const float alpha = 1.0f;
7573
- const float beta = 0.0f;
7574
- const int x_ne = ne01 * ne10;
7575
- const int y_ne = ne11 * ne10;
7576
- const int d_ne = ne11 * ne01;
7577
-
7578
- CUDA_CHECK(cudaMalloc((void **)(&d_X), sizeof(float) * x_ne));
7579
- CUDA_CHECK(cudaMalloc((void **)(&d_Y), sizeof(float) * y_ne));
7580
- CUDA_CHECK(cudaMalloc((void **)(&d_D), sizeof(float) * d_ne));
7581
- #endif
7582
-
7583
8435
  for (int64_t i03 = 0; i03 < ne03; i03++) {
7584
8436
  for (int64_t i02 = 0; i02 < ne02; i02++) {
7585
8437
  const float * x = (float *) ((char *) src0->data + i02*nb02 + i03*nb03);
7586
8438
  const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
7587
-
7588
8439
  float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
7589
8440
 
7590
- #if defined(GGML_USE_CUBLAS)
7591
- // copy data to device
7592
- CUDA_CHECK(cudaMemcpyAsync(d_X, x, sizeof(float) * x_ne, cudaMemcpyHostToDevice, cudaStream));
7593
- CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(float) * y_ne, cudaMemcpyHostToDevice, cudaStream));
7594
-
7595
- // compute
7596
- CUBLAS_CHECK(
7597
- cublasSgemm(cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
7598
- ne01, ne11, ne10,
7599
- &alpha, d_X, ne00,
7600
- d_Y, ne10,
7601
- &beta, d_D, ne01));
7602
-
7603
- // copy data to host
7604
- CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
7605
- #else
8441
+ #if defined(GGML_USE_CLBLAST)
7606
8442
  // zT = y * xT
8443
+ ggml_cl_sgemm_wrapper(GGML_BLAS_ORDER_ROW_MAJOR, GGML_BLAS_OP_N, GGML_BLAS_OP_T,
8444
+ ne11, ne01, ne10,
8445
+ 1.0f, y, ne10,
8446
+ x, ne10,
8447
+ 0.0f, d, ne01,
8448
+ GGML_TYPE_F32);
8449
+ #else
7607
8450
  cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
7608
8451
  ne11, ne01, ne10,
7609
8452
  1.0f, y, ne10,
@@ -7612,12 +8455,6 @@ static void ggml_compute_forward_mul_mat_f32(
7612
8455
  #endif
7613
8456
  }
7614
8457
  }
7615
- #if defined(GGML_USE_CUBLAS)
7616
- CUDA_CHECK(cudaStreamSynchronize(cudaStream));
7617
- CUDA_CHECK(cudaFree(d_X));
7618
- CUDA_CHECK(cudaFree(d_Y));
7619
- CUDA_CHECK(cudaFree(d_D));
7620
- #endif
7621
8458
  //printf("CBLAS F32 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);
7622
8459
 
7623
8460
  return;
@@ -7747,7 +8584,16 @@ static void ggml_compute_forward_mul_mat_f16_f32(
7747
8584
  // nb01 >= nb00 - src0 is not transposed
7748
8585
  // compute by src0 rows
7749
8586
 
7750
- #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
8587
+ #if defined(GGML_USE_CUBLAS)
8588
+ if (ggml_cuda_can_mul_mat(src0, src1, dst)) {
8589
+ if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
8590
+ ggml_cuda_mul_mat(src0, src1, dst, params->wdata, params->wsize);
8591
+ }
8592
+ return;
8593
+ }
8594
+ #endif
8595
+
8596
+ #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
7751
8597
  if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
7752
8598
  GGML_ASSERT(nb10 == sizeof(float));
7753
8599
 
@@ -7763,37 +8609,9 @@ static void ggml_compute_forward_mul_mat_f16_f32(
7763
8609
  return;
7764
8610
  }
7765
8611
 
7766
- #if defined(GGML_USE_CUBLAS)
7767
- ggml_fp16_t * const wdata = params->wdata;
7768
-
7769
- float *d_X = NULL;
7770
- float *d_Y = NULL;
7771
- float *d_D = NULL;
7772
- const float alpha = 1.0f;
7773
- const float beta = 0.0f;
7774
- const int x_ne = ne01 * ne10;
7775
- const int y_ne = ne11 * ne10;
7776
- const int d_ne = ne11 * ne01;
7777
-
7778
- CUDA_CHECK(cudaMalloc((void **)(&d_X), sizeof(ggml_fp16_t) * x_ne));
7779
- CUDA_CHECK(cudaMalloc((void **)(&d_Y), sizeof(float) * y_ne));
7780
- CUDA_CHECK(cudaMalloc((void **)(&d_D), sizeof(float) * d_ne));
7781
- #else
7782
- float * const wdata = params->wdata;
7783
- #endif
7784
8612
  for (int64_t i03 = 0; i03 < ne03; i03++) {
7785
8613
  for (int64_t i02 = 0; i02 < ne02; i02++) {
7786
- #if defined(GGML_USE_CUBLAS)
7787
- // with cuBlAS, instead of converting src0 to fp32, we convert src1 to fp16
7788
- {
7789
- size_t id = 0;
7790
- for (int64_t i01 = 0; i01 < ne11; ++i01) {
7791
- for (int64_t i00 = 0; i00 < ne10; ++i00) {
7792
- wdata[id++] = GGML_FP32_TO_FP16(*(float *) ((char *) src1->data + i03*nb13 + i02*nb12 + i01*nb11 + i00*nb10));
7793
- }
7794
- }
7795
- }
7796
- #else
8614
+ float * const wdata = params->wdata;
7797
8615
  {
7798
8616
  size_t id = 0;
7799
8617
  for (int64_t i01 = 0; i01 < ne01; ++i01) {
@@ -7801,31 +8619,23 @@ static void ggml_compute_forward_mul_mat_f16_f32(
7801
8619
  wdata[id++] = GGML_FP16_TO_FP32(*(ggml_fp16_t *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00));
7802
8620
  }
7803
8621
  }
8622
+
8623
+ assert(id*sizeof(float) <= params->wsize);
7804
8624
  }
7805
- #endif
7806
8625
 
7807
- #if defined(GGML_USE_CUBLAS)
7808
- const ggml_fp16_t * x = (ggml_fp16_t *) ((char *) src0->data + i02*nb02 + i03*nb03);
7809
- const ggml_fp16_t * y = (ggml_fp16_t *) wdata;
8626
+ #if defined(GGML_USE_CLBLAST)
8627
+ const float * x = wdata;
8628
+ const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
7810
8629
 
7811
8630
  float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
7812
8631
 
7813
- // copy data to device
7814
- CUDA_CHECK(cudaMemcpyAsync(d_X, x, sizeof(ggml_fp16_t) * x_ne, cudaMemcpyHostToDevice, cudaStream));
7815
- CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(ggml_fp16_t) * y_ne, cudaMemcpyHostToDevice, cudaStream));
7816
-
7817
- // compute
7818
- CUBLAS_CHECK(
7819
- cublasGemmEx(cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
7820
- ne01, ne11, ne10,
7821
- &alpha, d_X, CUDA_R_16F, ne00,
7822
- d_Y, CUDA_R_16F, ne10,
7823
- &beta, d_D, CUDA_R_32F, ne01,
7824
- CUBLAS_COMPUTE_32F,
7825
- CUBLAS_GEMM_DEFAULT));
7826
-
7827
- // copy data to host
7828
- CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
8632
+ // zT = y * xT
8633
+ ggml_cl_sgemm_wrapper(GGML_BLAS_ORDER_ROW_MAJOR, GGML_BLAS_OP_N, GGML_BLAS_OP_T,
8634
+ ne11, ne01, ne10,
8635
+ 1.0f, y, ne10,
8636
+ x, ne10,
8637
+ 0.0f, d, ne01,
8638
+ GGML_TYPE_F32);
7829
8639
  #else
7830
8640
  const float * x = wdata;
7831
8641
  const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
@@ -7842,12 +8652,6 @@ static void ggml_compute_forward_mul_mat_f16_f32(
7842
8652
  }
7843
8653
  }
7844
8654
 
7845
- #if defined(GGML_USE_CUBLAS)
7846
- CUDA_CHECK(cudaStreamSynchronize(cudaStream));
7847
- CUDA_CHECK(cudaFree(d_X));
7848
- CUDA_CHECK(cudaFree(d_Y));
7849
- CUDA_CHECK(cudaFree(d_D));
7850
- #endif
7851
8655
  /*printf("CBLAS F16 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);*/
7852
8656
 
7853
8657
  return;
@@ -7980,6 +8784,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
7980
8784
  const enum ggml_type type = src0->type;
7981
8785
  quantize_row_q_t const quantize_row_q_dot = quantize_fns[type].quantize_row_q_dot;
7982
8786
  vec_dot_q_t const vec_dot_q = quantize_fns[type].vec_dot_q;
8787
+ enum ggml_type const vec_dot_type = quantize_fns[type].vec_dot_type;
7983
8788
 
7984
8789
  // we don't support permuted src0 or src1
7985
8790
  GGML_ASSERT(nb00 == (int) GGML_TYPE_SIZE[type]);
@@ -7999,7 +8804,16 @@ static void ggml_compute_forward_mul_mat_q_f32(
7999
8804
  // nb01 >= nb00 - src0 is not transposed
8000
8805
  // compute by src0 rows
8001
8806
 
8002
- #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
8807
+ #if defined(GGML_USE_CUBLAS)
8808
+ if (ggml_cuda_can_mul_mat(src0, src1, dst)) {
8809
+ if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
8810
+ ggml_cuda_mul_mat(src0, src1, dst, params->wdata, params->wsize);
8811
+ }
8812
+ return;
8813
+ }
8814
+ #endif
8815
+
8816
+ #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
8003
8817
  if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
8004
8818
  if (params->ith != 0) {
8005
8819
  return;
@@ -8013,39 +8827,8 @@ static void ggml_compute_forward_mul_mat_q_f32(
8013
8827
  return;
8014
8828
  }
8015
8829
 
8016
- #if defined(GGML_USE_CUBLAS)
8017
- float *d_X = NULL;
8018
- float *d_Y = NULL;
8019
- float *d_D = NULL;
8020
- float *d_Q = NULL;
8021
- const float alpha = 1.0f;
8022
- const float beta = 0.0f;
8023
- const int x_ne = ne01 * ne10;
8024
- const int y_ne = ne11 * ne10;
8025
- const int d_ne = ne11 * ne01;
8026
-
8027
- CUDA_CHECK(cudaMalloc((void **)(&d_X), sizeof(float) * x_ne));
8028
- CUDA_CHECK(cudaMalloc((void **)(&d_Y), sizeof(float) * y_ne));
8029
- CUDA_CHECK(cudaMalloc((void **)(&d_D), sizeof(float) * d_ne));
8030
- CUDA_CHECK(cudaMalloc((void **)(&d_Q), GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type]));
8031
-
8032
- void (*dequantize_row_q_cuda)(const void * x, float * y, int k, cudaStream_t stream) = NULL;
8033
- if (type == GGML_TYPE_Q4_0) {
8034
- dequantize_row_q_cuda = dequantize_row_q4_0_cuda;
8035
- }
8036
- else if (type == GGML_TYPE_Q4_1) {
8037
- dequantize_row_q_cuda = dequantize_row_q4_1_cuda;
8038
- }
8039
- else if (type == GGML_TYPE_Q4_2) {
8040
- dequantize_row_q_cuda = dequantize_row_q4_2_cuda;
8041
- }
8042
- else {
8043
- GGML_ASSERT(false);
8044
- }
8045
- #else
8046
8830
  float * const wdata = params->wdata;
8047
8831
  dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q;
8048
- #endif
8049
8832
 
8050
8833
  for (int64_t i03 = 0; i03 < ne03; i03++) {
8051
8834
  for (int64_t i02 = 0; i02 < ne02; i02++) {
@@ -8053,14 +8836,8 @@ static void ggml_compute_forward_mul_mat_q_f32(
8053
8836
 
8054
8837
  float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
8055
8838
 
8056
- #if defined(GGML_USE_CUBLAS)
8057
- // copy and dequantize on device
8058
- CUDA_CHECK(
8059
- cudaMemcpyAsync(d_Q, (char *) src0->data + i03*nb03 + i02*nb02,
8060
- GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type], cudaMemcpyHostToDevice, cudaStream));
8061
-
8062
- dequantize_row_q_cuda(d_Q, d_X, ne01 * ne00, cudaStream);
8063
- CUDA_CHECK(cudaGetLastError());
8839
+ #if defined(GGML_USE_CLBLAST)
8840
+ const void* x = (char *) src0->data + i03*nb03 + i02*nb02;
8064
8841
  #else
8065
8842
  {
8066
8843
  size_t id = 0;
@@ -8068,27 +8845,22 @@ static void ggml_compute_forward_mul_mat_q_f32(
8068
8845
  dequantize_row_q((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01, wdata + id, ne00);
8069
8846
  id += ne00;
8070
8847
  }
8848
+
8849
+ assert(id*sizeof(float) <= params->wsize);
8071
8850
  }
8851
+
8072
8852
  const float * x = wdata;
8073
8853
  #endif
8074
8854
 
8075
-
8076
- #if defined(GGML_USE_CUBLAS)
8077
- // copy data to device
8078
- CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(float) * y_ne, cudaMemcpyHostToDevice, cudaStream));
8079
-
8080
- // compute
8081
- CUBLAS_CHECK(
8082
- cublasSgemm(cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
8083
- ne01, ne11, ne10,
8084
- &alpha, d_X, ne00,
8085
- d_Y, ne10,
8086
- &beta, d_D, ne01));
8087
-
8088
- // copy data to host
8089
- CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
8090
- #else
8855
+ #if defined(GGML_USE_CLBLAST)
8091
8856
  // zT = y * xT
8857
+ ggml_cl_sgemm_wrapper(GGML_BLAS_ORDER_ROW_MAJOR, GGML_BLAS_OP_N, GGML_BLAS_OP_T,
8858
+ ne11, ne01, ne10,
8859
+ 1.0f, y, ne10,
8860
+ x, ne10,
8861
+ 0.0f, d, ne01,
8862
+ type);
8863
+ #else
8092
8864
  cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
8093
8865
  ne11, ne01, ne10,
8094
8866
  1.0f, y, ne10,
@@ -8098,13 +8870,6 @@ static void ggml_compute_forward_mul_mat_q_f32(
8098
8870
  }
8099
8871
  }
8100
8872
 
8101
- #if defined(GGML_USE_CUBLAS)
8102
- CUDA_CHECK(cudaStreamSynchronize(cudaStream));
8103
- CUDA_CHECK(cudaFree(d_X));
8104
- CUDA_CHECK(cudaFree(d_Y));
8105
- CUDA_CHECK(cudaFree(d_D));
8106
- CUDA_CHECK(cudaFree(d_Q));
8107
- #endif
8108
8873
  //printf("CBLAS = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);
8109
8874
 
8110
8875
  return;
@@ -8113,7 +8878,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
8113
8878
 
8114
8879
  if (params->type == GGML_TASK_INIT) {
8115
8880
  char * wdata = params->wdata;
8116
- const size_t row_size = ne10*GGML_TYPE_SIZE[GGML_TYPE_Q8_0]/GGML_BLCK_SIZE[GGML_TYPE_Q8_0];
8881
+ const size_t row_size = ne10*GGML_TYPE_SIZE[vec_dot_type]/GGML_BLCK_SIZE[vec_dot_type];
8117
8882
 
8118
8883
  for (int64_t i13 = 0; i13 < ne13; ++i13) {
8119
8884
  for (int64_t i12 = 0; i12 < ne12; ++i12) {
@@ -8144,7 +8909,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
8144
8909
  const int ir1 = MIN(ir0 + dr, nr);
8145
8910
 
8146
8911
  void * wdata = params->wdata;
8147
- const size_t row_size = ne00*GGML_TYPE_SIZE[GGML_TYPE_Q8_0]/GGML_BLCK_SIZE[GGML_TYPE_Q8_0];
8912
+ const size_t row_size = ne00*GGML_TYPE_SIZE[vec_dot_type]/GGML_BLCK_SIZE[vec_dot_type];
8148
8913
 
8149
8914
  for (int ir = ir0; ir < ir1; ++ir) {
8150
8915
  // src0 indices
@@ -8193,8 +8958,10 @@ static void ggml_compute_forward_mul_mat(
8193
8958
  case GGML_TYPE_Q4_0:
8194
8959
  case GGML_TYPE_Q4_1:
8195
8960
  case GGML_TYPE_Q4_2:
8196
- case GGML_TYPE_Q4_3:
8961
+ case GGML_TYPE_Q5_0:
8962
+ case GGML_TYPE_Q5_1:
8197
8963
  case GGML_TYPE_Q8_0:
8964
+ case GGML_TYPE_Q8_1:
8198
8965
  {
8199
8966
  ggml_compute_forward_mul_mat_q_f32(params, src0, src1, dst);
8200
8967
  } break;
@@ -8422,8 +9189,10 @@ static void ggml_compute_forward_get_rows(
8422
9189
  case GGML_TYPE_Q4_0:
8423
9190
  case GGML_TYPE_Q4_1:
8424
9191
  case GGML_TYPE_Q4_2:
8425
- case GGML_TYPE_Q4_3:
9192
+ case GGML_TYPE_Q5_0:
9193
+ case GGML_TYPE_Q5_1:
8426
9194
  case GGML_TYPE_Q8_0:
9195
+ case GGML_TYPE_Q8_1:
8427
9196
  {
8428
9197
  ggml_compute_forward_get_rows_q(params, src0, src1, dst);
8429
9198
  } break;
@@ -8561,6 +9330,7 @@ static void ggml_compute_forward_soft_max_f32(
8561
9330
 
8562
9331
  uint16_t scvt;
8563
9332
  for (int i = 0; i < nc; i++) {
9333
+ //printf("p[%3d] = %8.4f\n", i, p[i]);
8564
9334
  if (p[i] == -INFINITY) {
8565
9335
  p[i] = 0.0f;
8566
9336
  } else {
@@ -8603,6 +9373,161 @@ static void ggml_compute_forward_soft_max(
8603
9373
  }
8604
9374
  }
8605
9375
 
9376
+ // ggml_compute_forward_alibi
9377
+
9378
+ static void ggml_compute_forward_alibi_f32(
9379
+ const struct ggml_compute_params * params,
9380
+ const struct ggml_tensor * src0,
9381
+ const struct ggml_tensor * src1,
9382
+ struct ggml_tensor * dst) {
9383
+ assert(params->ith == 0);
9384
+ assert(src1->type == GGML_TYPE_I32);
9385
+ assert(ggml_nelements(src1) == 2);
9386
+
9387
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
9388
+ return;
9389
+ }
9390
+
9391
+ const int n_past = ((int32_t *) src1->data)[0];
9392
+ const int n_head = ((int32_t *) src1->data)[1];
9393
+
9394
+ const int ne0 = src0->ne[0]; // all_seq_len = n_past + ne1
9395
+ const int ne1 = src0->ne[1]; // seq_len_without_past
9396
+ //const int ne2 = src0->ne[2]; // n_head -> this is k
9397
+ //const int ne3 = src0->ne[3]; // 1 -> bsz
9398
+
9399
+ const int n = ggml_nrows(src0);
9400
+ const int ne2_ne3 = n/ne1; // ne2*ne3
9401
+
9402
+ const int nb0 = src0->nb[0];
9403
+ const int nb1 = src0->nb[1];
9404
+ const int nb2 = src0->nb[2];
9405
+ //const int nb3 = src0->nb[3];
9406
+
9407
+ assert(nb0 == sizeof(float));
9408
+ assert(ne1 + n_past == ne0); (void) n_past;
9409
+
9410
+ // add alibi to src0 (KQ_scaled)
9411
+ const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
9412
+
9413
+ const float m0 = powf(2.0f, -8.0f / n_heads_log2_floor);
9414
+ const float m1 = powf(2.0f, -4.0f / n_heads_log2_floor);
9415
+
9416
+ for (int i = 0; i < ne0; i++) {
9417
+ for (int j = 0; j < ne1; j++) {
9418
+ for (int k = 0; k < ne2_ne3; k++) {
9419
+ float * const src = (float *)((char *) src0->data + i*nb0 + j*nb1 + k*nb2);
9420
+ float * pdst = (float *)((char *) dst->data + i*nb0 + j*nb1 + k*nb2);
9421
+
9422
+ // TODO: k*nb2 or k*nb3
9423
+
9424
+ float m_k;
9425
+
9426
+ if (k < n_heads_log2_floor) {
9427
+ m_k = powf(m0, k + 1);
9428
+ } else {
9429
+ m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1);
9430
+ }
9431
+
9432
+ pdst[0] = (j+1) * m_k + src[0];
9433
+ }
9434
+ }
9435
+ }
9436
+ }
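// Worked example of the slope schedule above (illustrative): for n_head = 8,
// n_heads_log2_floor = 8, so m0 = 2^(-8/8) = 0.5 and head k gets m_k = 0.5^(k+1):
//   k   : 0    1    2    3     4     5     6      7
//   m_k : 1/2  1/4  1/8  1/16  1/32  1/64  1/128  1/256
// For a non-power-of-two n_head (say 12), heads 0..7 take powers of m0 while heads 8..11
// take the odd powers of m1 = 2^(-4/8), interleaving extra slopes as in the ALiBi paper.
// The bias added to each score is (j+1)*m_k, i.e. it depends on the query position j only.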
9437
+
9438
+
9439
+ static void ggml_compute_forward_alibi_f16(
9440
+ const struct ggml_compute_params * params,
9441
+ const struct ggml_tensor * src0,
9442
+ const struct ggml_tensor * src1,
9443
+ struct ggml_tensor * dst) {
9444
+ assert(params->ith == 0);
9445
+ assert(src1->type == GGML_TYPE_I32);
9446
+ assert(ggml_nelements(src1) == 2);
9447
+
9448
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
9449
+ return;
9450
+ }
9451
+
9452
+ const int n_past = ((int32_t *) src1->data)[0];
9453
+ const int n_head = ((int32_t *) src1->data)[1];
9454
+
9455
+ const int ne0 = src0->ne[0]; // all_seq_len = n_past + ne1
9456
+ const int ne1 = src0->ne[1]; // seq_len_without_past
9457
+ //const int ne2 = src0->ne[2]; // n_head -> this is k
9458
+ //const int ne3 = src0->ne[3]; // 1 -> bsz
9459
+
9460
+ const int n = ggml_nrows(src0);
9461
+ const int ne2_ne3 = n/ne1; // ne2*ne3
9462
+
9463
+ const int nb0 = src0->nb[0];
9464
+ const int nb1 = src0->nb[1];
9465
+ const int nb2 = src0->nb[2];
9466
+ //const int nb3 = src0->nb[3];
9467
+
9468
+ assert(nb0 == sizeof(ggml_fp16_t));
9469
+ assert(ne1 + n_past == ne0); (void) n_past;
9470
+
9471
+ // add alibi to src0 (KQ_scaled)
9472
+ const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
9473
+
9474
+ const float m0 = powf(2.0f, -8.0f / n_heads_log2_floor);
9475
+ const float m1 = powf(2.0f, -4.0f / n_heads_log2_floor);
9476
+
9477
+ for (int i = 0; i < ne0; i++) {
9478
+ for (int j = 0; j < ne1; j++) {
9479
+ for (int k = 0; k < ne2_ne3; k++) {
9480
+ ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i*nb0 + j*nb1 + k*nb2);
9481
+ float * pdst = (float *)((char *) dst->data + i*nb0 + j*nb1 + k*nb2);
9482
+
9483
+ // TODO: k*nb2 or k*nb3
9484
+
9485
+ float m_k;
9486
+
9487
+ if (k < n_heads_log2_floor) {
9488
+ m_k = powf(m0, k + 1);
9489
+ } else {
9490
+ m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1);
9491
+ }
9492
+
9493
+ // we return F32
9494
+ pdst[0] = (j+1) * m_k + GGML_FP16_TO_FP32(src[0]);
9495
+ }
9496
+ }
9497
+ }
9498
+ }
9499
+
9500
+ static void ggml_compute_forward_alibi(
9501
+ const struct ggml_compute_params * params,
9502
+ const struct ggml_tensor * src0,
9503
+ const struct ggml_tensor * src1,
9504
+ struct ggml_tensor * dst) {
9505
+ switch (src0->type) {
9506
+ case GGML_TYPE_F16:
9507
+ {
9508
+ ggml_compute_forward_alibi_f16(params, src0, src1, dst);
9509
+ } break;
9510
+ case GGML_TYPE_F32:
9511
+ {
9512
+ ggml_compute_forward_alibi_f32(params, src0, src1, dst);
9513
+ } break;
9514
+ case GGML_TYPE_Q4_0:
9515
+ case GGML_TYPE_Q4_1:
9516
+ case GGML_TYPE_Q4_2:
9517
+ case GGML_TYPE_Q5_0:
9518
+ case GGML_TYPE_Q5_1:
9519
+ case GGML_TYPE_Q8_0:
9520
+ case GGML_TYPE_Q8_1:
9521
+ case GGML_TYPE_I8:
9522
+ case GGML_TYPE_I16:
9523
+ case GGML_TYPE_I32:
9524
+ case GGML_TYPE_COUNT:
9525
+ {
9526
+ GGML_ASSERT(false);
9527
+ } break;
9528
+ }
9529
+ }
9530
+
8606
9531
  // ggml_compute_forward_rope
8607
9532
 
8608
9533
  static void ggml_compute_forward_rope_f32(
@@ -10241,6 +11166,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
10241
11166
  {
10242
11167
  ggml_compute_forward_rope(params, tensor->src0, tensor->src1, tensor);
10243
11168
  } break;
11169
+ case GGML_OP_ALIBI:
11170
+ {
11171
+ ggml_compute_forward_alibi(params, tensor->src0, tensor->src1, tensor);
11172
+ } break;
10244
11173
  case GGML_OP_CONV_1D_1S:
10245
11174
  {
10246
11175
  ggml_compute_forward_conv_1d_1s(params, tensor->src0, tensor->src1, tensor);
@@ -10443,6 +11372,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
10443
11372
  {
10444
11373
  GGML_ASSERT(false); // TODO: not implemented
10445
11374
  } break;
11375
+ case GGML_OP_ALIBI:
11376
+ {
11377
+ GGML_ASSERT(false); // TODO: not implemented
11378
+ } break;
10446
11379
  case GGML_OP_SILU:
10447
11380
  {
10448
11381
  GGML_ASSERT(false); // TODO: not implemented
@@ -10920,15 +11853,21 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
10920
11853
 
10921
11854
  size_t cur = 0;
10922
11855
 
11856
+ #if defined(GGML_USE_CUBLAS)
11857
+ if (ggml_cuda_can_mul_mat(node->src0, node->src1, node)) {
11858
+ node->n_tasks = 1; // TODO: this actually is doing nothing
11859
+ // the threads are still spinning
11860
+ cur = ggml_cuda_mul_mat_get_wsize(node->src0, node->src1, node);
11861
+ }
11862
+ else
11863
+ #endif
10923
11864
  if (node->src0->type == GGML_TYPE_F16 && node->src1->type == GGML_TYPE_F32) {
10924
- #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
11865
+ #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
10925
11866
  if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
10926
11867
  node->n_tasks = 1; // TODO: this actually is doing nothing
10927
11868
  // the threads are still spinning
11869
+ // here we need memory just for single 2D matrix from src0
10928
11870
  cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]);
10929
- //printf("src0: ne0 = %d, ne1 = %d, ne = %d\n", node->src0->ne[0], node->src0->ne[1], node->src0->ne[0]*node->src0->ne[1]);
10930
- //printf("src1: ne0 = %d, ne1 = %d, ne = %d\n", node->src1->ne[0], node->src1->ne[1], node->src1->ne[0]*node->src1->ne[1]);
10931
- //printf("cur = %zu\n", cur);
10932
11871
  } else {
10933
11872
  cur = GGML_TYPE_SIZE[GGML_TYPE_F16]*ggml_nelements(node->src1);
10934
11873
  }
@@ -10937,15 +11876,21 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
10937
11876
  #endif
10938
11877
  } else if (node->src0->type == GGML_TYPE_F32 && node->src1->type == GGML_TYPE_F32) {
10939
11878
  cur = 0;
11879
+ #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
11880
+ if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
11881
+ node->n_tasks = 1;
11882
+ }
11883
+ #endif
10940
11884
  } else if (ggml_is_quantized(node->src0->type) && node->src1->type == GGML_TYPE_F32) {
10941
- #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
11885
+ #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
10942
11886
  if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
10943
11887
  node->n_tasks = 1;
10944
11888
  cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]);
10945
11889
  } else
10946
11890
  #endif
10947
11891
  {
10948
- cur = GGML_TYPE_SIZE[GGML_TYPE_Q8_0]*ggml_nelements(node->src1)/GGML_BLCK_SIZE[GGML_TYPE_Q8_0];
11892
+ const enum ggml_type type_q = quantize_fns[node->src0->type].vec_dot_type;
11893
+ cur = GGML_TYPE_SIZE[type_q]*ggml_nelements(node->src1)/GGML_BLCK_SIZE[type_q];
10949
11894
  }
10950
11895
  } else {
10951
11896
  GGML_ASSERT(false);
@@ -10975,6 +11920,10 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
10975
11920
  {
10976
11921
  node->n_tasks = n_threads;
10977
11922
  } break;
11923
+ case GGML_OP_ALIBI:
11924
+ {
11925
+ node->n_tasks = 1; //TODO
11926
+ } break;
10978
11927
  case GGML_OP_CONV_1D_1S:
10979
11928
  case GGML_OP_CONV_1D_2S:
10980
11929
  {
@@ -11273,9 +12222,9 @@ void ggml_graph_print(const struct ggml_cgraph * cgraph) {
11273
12222
  for (int i = 0; i < cgraph->n_nodes; i++) {
11274
12223
  struct ggml_tensor * node = cgraph->nodes[i];
11275
12224
 
11276
- perf_total_per_op_us[node->op] += node->perf_time_us;
12225
+ perf_total_per_op_us[node->op] += MAX(1, node->perf_time_us);
11277
12226
 
11278
- GGML_PRINT(" - %3d: [ %" PRId64 ", %" PRId64 ", %" PRId64 "] %16s %s (%3d) cpu = %7.3f / %7.3f ms, wall = %7.3f / %7.3f ms\n",
12227
+ GGML_PRINT(" - %3d: [ %5" PRId64 ", %5" PRId64 ", %5" PRId64 "] %16s %s (%3d) cpu = %7.3f / %7.3f ms, wall = %7.3f / %7.3f ms\n",
11279
12228
  i,
11280
12229
  node->ne[0], node->ne[1], node->ne[2],
11281
12230
  GGML_OP_LABEL[node->op], node->is_param ? "x" : node->grad ? "g" : " ", node->perf_runs,
@@ -11289,13 +12238,17 @@ void ggml_graph_print(const struct ggml_cgraph * cgraph) {
11289
12238
  for (int i = 0; i < cgraph->n_leafs; i++) {
11290
12239
  struct ggml_tensor * node = cgraph->leafs[i];
11291
12240
 
11292
- GGML_PRINT(" - %3d: [ %" PRId64 ", %" PRId64 "] %8s\n",
12241
+ GGML_PRINT(" - %3d: [ %5" PRId64 ", %5" PRId64 "] %8s\n",
11293
12242
  i,
11294
12243
  node->ne[0], node->ne[1],
11295
12244
  GGML_OP_LABEL[node->op]);
11296
12245
  }
11297
12246
 
11298
12247
  for (int i = 0; i < GGML_OP_COUNT; i++) {
12248
+ if (perf_total_per_op_us[i] == 0) {
12249
+ continue;
12250
+ }
12251
+
11299
12252
  GGML_PRINT("perf_total_per_op_us[%16s] = %7.3f ms\n", GGML_OP_LABEL[i], (double) perf_total_per_op_us[i] / 1000.0);
11300
12253
  }
11301
12254
 
@@ -11358,10 +12311,16 @@ void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph
11358
12311
  snprintf(color, sizeof(color), "white");
11359
12312
  }
11360
12313
 
11361
- fprintf(fp, " \"%p\" [ \
11362
- style = filled; fillcolor = %s; shape = record; \
11363
- label=\"%d [%" PRId64 ", %" PRId64 "] | <x>%s",
11364
- (void *) node, color,
12314
+ fprintf(fp, " \"%p\" [ "
12315
+ "style = filled; fillcolor = %s; shape = record; "
12316
+ "label=\"",
12317
+ (void *) node, color);
12318
+
12319
+ if (strlen(node->name) > 0) {
12320
+ fprintf(fp, "%s |", node->name);
12321
+ }
12322
+
12323
+ fprintf(fp, "%d [%" PRId64 ", %" PRId64 "] | <x>%s",
11365
12324
  i, node->ne[0], node->ne[1],
11366
12325
  GGML_OP_SYMBOL[node->op]);
11367
12326
 
@@ -11377,18 +12336,26 @@ label=\"%d [%" PRId64 ", %" PRId64 "] | <x>%s",
11377
12336
 
11378
12337
  snprintf(color, sizeof(color), "pink");
11379
12338
 
12339
+ fprintf(fp, " \"%p\" [ "
12340
+ "style = filled; fillcolor = %s; shape = record; "
12341
+ "label=\"<x>",
12342
+ (void *) node, color);
12343
+
12344
+ if (strlen(node->name) > 0) {
12345
+ fprintf(fp, "%s | ", node->name);
12346
+ }
11380
12347
  if (ggml_nelements(node) == 1) {
11381
- fprintf(fp, " \"%p\" [ \
11382
- style = filled; fillcolor = %s; shape = record; \
11383
- label=\"<x>%.1e\"; ]\n",
11384
- (void *) node, color, (double)ggml_get_f32_1d(node, 0));
11385
- } else {
11386
- fprintf(fp, " \"%p\" [ \
11387
- style = filled; fillcolor = %s; shape = record; \
11388
- label=\"<x>CONST %d [%" PRId64 ", %" PRId64 "]\"; ]\n",
11389
- (void *) node, color,
11390
- i, node->ne[0], node->ne[1]);
12348
+ if (node->type == GGML_TYPE_I8 || node->type == GGML_TYPE_I16 || node->type == GGML_TYPE_I32) {
12349
+ fprintf(fp, "%d", ggml_get_i32_1d(node, 0));
12350
+ }
12351
+ else {
12352
+ fprintf(fp, "%.1e", (double)ggml_get_f32_1d(node, 0));
12353
+ }
11391
12354
  }
12355
+ else {
12356
+ fprintf(fp, "CONST %d [%" PRId64 ", %" PRId64 "]", i, node->ne[0], node->ne[1]);
12357
+ }
12358
+ fprintf(fp, "\"; ]\n");
11392
12359
  }
11393
12360
 
11394
12361
  for (int i = 0; i < gb->n_nodes; i++) {
@@ -12129,7 +13096,7 @@ size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t *
12129
13096
 
12130
13097
  for (int i = 0; i < nb; i++) {
12131
13098
  for (int l = 0; l < QK4_0; l += 2) {
12132
- const uint8_t vi0 = y[i].qs[l/2] & 0xF;
13099
+ const uint8_t vi0 = y[i].qs[l/2] & 0x0F;
12133
13100
  const uint8_t vi1 = y[i].qs[l/2] >> 4;
12134
13101
 
12135
13102
  hist[vi0]++;
@@ -12152,7 +13119,7 @@ size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t *
12152
13119
 
12153
13120
  for (int i = 0; i < nb; i++) {
12154
13121
  for (int l = 0; l < QK4_1; l += 2) {
12155
- const uint8_t vi0 = y[i].qs[l/2] & 0xF;
13122
+ const uint8_t vi0 = y[i].qs[l/2] & 0x0F;
12156
13123
  const uint8_t vi1 = y[i].qs[l/2] >> 4;
12157
13124
 
12158
13125
  hist[vi0]++;
@@ -12171,12 +13138,11 @@ size_t ggml_quantize_q4_2(const float * src, void * dst, int n, int k, int64_t *
12171
13138
  for (int j = 0; j < n; j += k) {
12172
13139
  block_q4_2 * restrict y = (block_q4_2 *)dst + j/QK4_2;
12173
13140
 
12174
- //quantize_row_q4_2_reference(src + j, y, k);
12175
- quantize_row_q4_2_rmse(src + j, y, k);
13141
+ quantize_row_q4_2_reference(src + j, y, k);
12176
13142
 
12177
13143
  for (int i = 0; i < nb; i++) {
12178
13144
  for (int l = 0; l < QK4_2; l += 2) {
12179
- const uint8_t vi0 = y[i].qs[l/2] & 0xF;
13145
+ const uint8_t vi0 = y[i].qs[l/2] & 0x0F;
12180
13146
  const uint8_t vi1 = y[i].qs[l/2] >> 4;
12181
13147
 
12182
13148
  hist[vi0]++;
@@ -12188,19 +13154,56 @@ size_t ggml_quantize_q4_2(const float * src, void * dst, int n, int k, int64_t *
12188
13154
  return (n/QK4_2*sizeof(block_q4_2));
12189
13155
  }
12190
13156
 
12191
- size_t ggml_quantize_q4_3(const float * src, void * dst, int n, int k, int64_t * hist) {
12192
- assert(k % QK4_3 == 0);
12193
- const int nb = k / QK4_3;
13157
+ size_t ggml_quantize_q5_0(const float * src, void * dst, int n, int k, int64_t * hist) {
13158
+ assert(k % QK5_0 == 0);
13159
+ const int nb = k / QK5_0;
12194
13160
 
12195
13161
  for (int j = 0; j < n; j += k) {
12196
- block_q4_3 * restrict y = (block_q4_3 *)dst + j/QK4_3;
13162
+ block_q5_0 * restrict y = (block_q5_0 *)dst + j/QK5_0;
12197
13163
 
12198
- quantize_row_q4_3_reference(src + j, y, k);
13164
+ quantize_row_q5_0_reference(src + j, y, k);
12199
13165
 
12200
13166
  for (int i = 0; i < nb; i++) {
12201
- for (int l = 0; l < QK4_3; l += 2) {
12202
- const uint8_t vi0 = y[i].qs[l/2] & 0xF;
12203
- const uint8_t vi1 = y[i].qs[l/2] >> 4;
13167
+ uint32_t qh;
13168
+ memcpy(&qh, &y[i].qh, sizeof(qh));
13169
+
13170
+ for (int l = 0; l < QK5_0; l += 2) {
13171
+ const uint8_t vh0 = ((qh & (1u << (l + 0))) >> (l + 0)) << 4;
13172
+ const uint8_t vh1 = ((qh & (1u << (l + 1))) >> (l + 1)) << 4;
13173
+
13174
+ // cast to 16 bins
13175
+ const uint8_t vi0 = ((y[i].qs[l/2] & 0x0F) | vh0) / 2;
13176
+ const uint8_t vi1 = ((y[i].qs[l/2] >> 4) | vh1) / 2;
13177
+
13178
+ hist[vi0]++;
13179
+ hist[vi1]++;
13180
+ }
13181
+ }
13182
+ }
13183
+
13184
+ return (n/QK5_0*sizeof(block_q5_0));
13185
+ }
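// Worked example of the 5-bit reconstruction above (illustrative): each block element keeps
// its low 4 bits in a nibble of qs[] and its 5th (high) bit in bit l of the 32-bit qh field.
// If qs[l/2] & 0x0F == 0x07 and bit l of qh is set, the 5-bit code is 0x10 | 0x07 = 23, and
// the /2 folds the 32 possible codes into the 16 histogram bins (bin 11 here). With
// QK5_0 == 32, that is 16 bytes of nibbles plus 4 bytes of high bits per block, in addition
// to the per-block scale defined with the block struct earlier in the file (not shown in
// this hunk).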
13186
+
13187
+ size_t ggml_quantize_q5_1(const float * src, void * dst, int n, int k, int64_t * hist) {
13188
+ assert(k % QK5_1 == 0);
13189
+ const int nb = k / QK5_1;
13190
+
13191
+ for (int j = 0; j < n; j += k) {
13192
+ block_q5_1 * restrict y = (block_q5_1 *)dst + j/QK5_1;
13193
+
13194
+ quantize_row_q5_1_reference(src + j, y, k);
13195
+
13196
+ for (int i = 0; i < nb; i++) {
13197
+ uint32_t qh;
13198
+ memcpy(&qh, &y[i].qh, sizeof(qh));
13199
+
13200
+ for (int l = 0; l < QK5_1; l += 2) {
13201
+ const uint8_t vh0 = ((qh & (1u << (l + 0))) >> (l + 0)) << 4;
13202
+ const uint8_t vh1 = ((qh & (1u << (l + 1))) >> (l + 1)) << 4;
13203
+
13204
+ // cast to 16 bins
13205
+ const uint8_t vi0 = ((y[i].qs[l/2] & 0x0F) | vh0) / 2;
13206
+ const uint8_t vi1 = ((y[i].qs[l/2] >> 4) | vh1) / 2;
12204
13207
 
12205
13208
  hist[vi0]++;
12206
13209
  hist[vi1]++;
@@ -12208,7 +13211,28 @@ size_t ggml_quantize_q4_3(const float * src, void * dst, int n, int k, int64_t *
12208
13211
  }
12209
13212
  }
12210
13213
 
12211
- return (n/QK4_3*sizeof(block_q4_3));
13214
+ return (n/QK5_1*sizeof(block_q5_1));
13215
+ }
13216
+
13217
+ size_t ggml_quantize_q8_0(const float * src, void * dst, int n, int k, int64_t * hist) {
13218
+ assert(k % QK8_0 == 0);
13219
+ const int nb = k / QK8_0;
13220
+
13221
+ for (int j = 0; j < n; j += k) {
13222
+ block_q8_0 * restrict y = (block_q8_0 *)dst + j/QK8_0;
13223
+
13224
+ quantize_row_q8_0_reference(src + j, y, k);
13225
+
13226
+ for (int i = 0; i < nb; i++) {
13227
+ for (int l = 0; l < QK8_0; ++l) {
13228
+ const int8_t vi = y[i].qs[l];
13229
+
13230
+ hist[vi/16 + 8]++;
13231
+ }
13232
+ }
13233
+ }
13234
+
13235
+ return (n/QK8_0*sizeof(block_q8_0));
12212
13236
  }
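// Worked example of the binning above (illustrative): hist has 16 bins and vi is a signed
// 8-bit quant, so vi/16 + 8 maps [-128, 127] onto 0..15: vi = -128 -> 0, vi = 0 -> 8,
// vi = 127 -> 15. C division truncates toward zero, so the bin around zero (covering
// -15..15) is roughly twice as wide as the others.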
12213
13237
 
12214
13238
  size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, int start, int n, int64_t * hist) {
@@ -12232,11 +13256,23 @@ size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, i
12232
13256
  block_q4_2 * block = (block_q4_2*)dst + start / QK4_2;
12233
13257
  result = ggml_quantize_q4_2(src + start, block, n, n, hist);
12234
13258
  } break;
12235
- case GGML_TYPE_Q4_3:
13259
+ case GGML_TYPE_Q5_0:
13260
+ {
13261
+ GGML_ASSERT(start % QK5_0 == 0);
13262
+ block_q5_0 * block = (block_q5_0*)dst + start / QK5_0;
13263
+ result = ggml_quantize_q5_0(src + start, block, n, n, hist);
13264
+ } break;
13265
+ case GGML_TYPE_Q5_1:
13266
+ {
13267
+ GGML_ASSERT(start % QK5_1 == 0);
13268
+ block_q5_1 * block = (block_q5_1*)dst + start / QK5_1;
13269
+ result = ggml_quantize_q5_1(src + start, block, n, n, hist);
13270
+ } break;
13271
+ case GGML_TYPE_Q8_0:
12236
13272
  {
12237
- GGML_ASSERT(start % QK4_3 == 0);
12238
- block_q4_3 * block = (block_q4_3*)dst + start / QK4_3;
12239
- result = ggml_quantize_q4_3(src + start, block, n, n, hist);
13273
+ GGML_ASSERT(start % QK8_0 == 0);
13274
+ block_q8_0 * block = (block_q8_0*)dst + start / QK8_0;
13275
+ result = ggml_quantize_q8_0(src + start, block, n, n, hist);
12240
13276
  } break;
12241
13277
  default:
12242
13278
  assert(false);
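A hedged usage sketch for the updated dispatch: src_f32, dst_buf, start and n are hypothetical names, start must be a multiple of the block size (as the GGML_ASSERTs above enforce), and hist needs the 16 bins that the per-type helpers index:

    int64_t hist[16] = {0};
    const size_t bytes = ggml_quantize_chunk(GGML_TYPE_Q5_0, src_f32, dst_buf, start, n, hist);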
@@ -12335,7 +13371,7 @@ int ggml_cpu_has_wasm_simd(void) {
12335
13371
  }
12336
13372
 
12337
13373
  int ggml_cpu_has_blas(void) {
12338
- #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
13374
+ #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
12339
13375
  return 1;
12340
13376
  #else
12341
13377
  return 0;
@@ -12350,6 +13386,18 @@ int ggml_cpu_has_cublas(void) {
12350
13386
  #endif
12351
13387
  }
12352
13388
 
13389
+ int ggml_cpu_has_clblast(void) {
13390
+ #if defined(GGML_USE_CLBLAST)
13391
+ return 1;
13392
+ #else
13393
+ return 0;
13394
+ #endif
13395
+ }
13396
+
13397
+ int ggml_cpu_has_gpublas(void) {
13398
+ return ggml_cpu_has_cublas() || ggml_cpu_has_clblast();
13399
+ }
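// Usage sketch (hedged; the format string is illustrative): frontends typically surface
// these probes in a system-info line, and ggml_cpu_has_gpublas() is simply the OR of the
// cuBLAS and CLBlast probes above.
//
//     printf("BLAS = %d | cuBLAS = %d | CLBlast = %d | GPU BLAS = %d\n",
//            ggml_cpu_has_blas(), ggml_cpu_has_cublas(),
//            ggml_cpu_has_clblast(), ggml_cpu_has_gpublas());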
13400
+
12353
13401
  int ggml_cpu_has_sse3(void) {
12354
13402
  #if defined(__SSE3__)
12355
13403
  return 1;