llama_cpp 0.0.5 → 0.0.7

This diff shows the contents of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two package versions as they appear in their public registry.
@@ -19,6 +19,7 @@
19
19
  #include <inttypes.h>
20
20
  #include <stdio.h>
21
21
  #include <float.h>
22
+ #include <limits.h>
22
23
 
23
24
  // if C99 - static_assert is noop
24
25
  // ref: https://stackoverflow.com/a/53923785/4039976
@@ -142,10 +143,14 @@ inline static void* ggml_aligned_malloc(size_t size) {
142
143
  } \
143
144
  } while (0)
144
145
 
145
- #ifdef GGML_USE_ACCELERATE
146
+ #if defined(GGML_USE_ACCELERATE)
146
147
  #include <Accelerate/Accelerate.h>
147
- #elif GGML_USE_OPENBLAS
148
+ #elif defined(GGML_USE_OPENBLAS)
148
149
  #include <cblas.h>
150
+ #elif defined(GGML_USE_CUBLAS)
151
+ #include "ggml-cuda.h"
152
+ #elif defined(GGML_USE_CLBLAST)
153
+ #include "ggml-opencl.h"
149
154
  #endif
150
155
 
151
156
  #undef MIN
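
Note on the hunk above: the backend selection moves from bare macro tests (#elif GGML_USE_OPENBLAS) to defined() tests and gains cuBLAS/CLBlast include branches. A minimal sketch of why the defined() form is the more robust test, using a hypothetical macro FOO (not from the diff): a macro defined with no value leaves a bare #elif with nothing to evaluate, while defined() only asks whether the name exists.

    #include <stdio.h>

    #define FOO                 /* defined with no value (plain #define, or -DFOO= on the command line) */

    int main(void) {
    #if defined(FOO)            /* valid whether or not FOO expands to anything */
        puts("FOO is defined");
    #endif
        /* A bare "#elif FOO" in the same situation would expand to "#elif" with no
           expression left and fail to preprocess; if FOO were not defined at all,
           the bare form would instead evaluate to 0 and silently fall through. */
        return 0;
    }
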
@@ -325,6 +330,20 @@ static ggml_fp16_t table_exp_f16[1 << 16];
325
330
  // precomputed f32 table for f16 (256 KB)
326
331
  static float table_f32_f16[1 << 16];
327
332
 
333
+ #if defined(__ARM_NEON)
334
+ #define B1(c,s,n) 0x ## n ## c , 0x ## n ## s
335
+ #define B2(c,s,n) B1(c,s,n ## c), B1(c,s,n ## s)
336
+ #define B3(c,s,n) B2(c,s,n ## c), B2(c,s,n ## s)
337
+ #define B4(c,s,n) B3(c,s,n ## c), B3(c,s,n ## s)
338
+ #define B5(c,s,n) B4(c,s,n ## c), B4(c,s,n ## s)
339
+ #define B6(c,s,n) B5(c,s,n ## c), B5(c,s,n ## s)
340
+ #define B7(c,s,n) B6(c,s,n ## c), B6(c,s,n ## s)
341
+ #define B8(c,s ) B7(c,s, c), B7(c,s, s)
342
+
343
+ // precomputed tables for expanding 8bits to 8 bytes (shl 4)
344
+ static const uint64_t table_b2b_u[1 << 8] = { B8(00, 10) };
345
+ #endif
346
+
328
347
  // On ARM NEON, it's quicker to directly convert x -> x instead of calling into ggml_lookup_fp16_to_fp32,
329
348
  // so we define GGML_FP16_TO_FP32 and GGML_FP32_TO_FP16 elsewhere for NEON.
330
349
  // This is also true for POWER9.
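
The B1..B8 macros above generate table_b2b_u at compile time: entry i is a 64-bit word whose byte k is 0x10 when bit k of i is set and 0x00 otherwise (the "expanding 8 bits to 8 bytes (shl 4)" mentioned in the comment). A hedged scalar sketch of the same mapping, useful for sanity-checking the table (the function name is illustrative, not part of the diff):

    #include <stdint.h>

    // What table_b2b_u[v] should contain: bit k of v becomes 0x10 in byte k
    // (byte 0 being the least significant byte of the 64-bit result).
    static uint64_t expand_bits_to_bytes_shl4(uint8_t v) {
        uint64_t out = 0;
        for (int k = 0; k < 8; ++k) {
            if ((v >> k) & 1) {
                out |= (uint64_t)0x10 << (8*k);
            }
        }
        return out;
    }
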
@@ -427,12 +446,69 @@ static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
427
446
  // quantization
428
447
  //
429
448
 
430
- // AVX routines provided by GH user Const-me
431
- // ref: https://github.com/ggerganov/ggml/pull/27#issuecomment-1464934600
449
+ #if __AVX__ || __AVX2__ || __AVX512F__
450
+ // Unpack 16 4-bit fields into 16 bytes
451
+ // The output vector contains 16 bytes, each one in [ 0 .. 15 ] interval
452
+ static inline __m128i bytes_from_nibbles_16(const uint8_t * rsi)
453
+ {
454
+ // Load 8 bytes from memory
455
+ __m128i tmp = _mm_loadl_epi64( ( const __m128i* )rsi );
456
+
457
+ // Expand bytes into uint16_t values
458
+ __m128i bytes = _mm_cvtepu8_epi16( tmp );
459
+
460
+ // Unpack values into individual bytes
461
+ const __m128i lowMask = _mm_set1_epi8( 0xF );
462
+ __m128i high = _mm_andnot_si128( lowMask, bytes );
463
+ __m128i low = _mm_and_si128( lowMask, bytes );
464
+ high = _mm_slli_epi16( high, 4 );
465
+ bytes = _mm_or_si128( low, high );
466
+ return bytes;
467
+ }
468
+
469
+ // horizontally add 8 floats
470
+ static inline float hsum_float_8(const __m256 x) {
471
+ __m128 res = _mm256_extractf128_ps(x, 1);
472
+ res = _mm_add_ps(res, _mm256_castps256_ps128(x));
473
+ res = _mm_add_ps(res, _mm_movehl_ps(res, res));
474
+ res = _mm_add_ss(res, _mm_movehdup_ps(res));
475
+ return _mm_cvtss_f32(res);
476
+ }
477
+
478
+ // horizontally add 8 int32_t
479
+ static inline int hsum_i32_8(const __m256i a) {
480
+ const __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(a), _mm256_extractf128_si256(a, 1));
481
+ const __m128i hi64 = _mm_unpackhi_epi64(sum128, sum128);
482
+ const __m128i sum64 = _mm_add_epi32(hi64, sum128);
483
+ const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
484
+ return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
485
+ }
486
+
487
+ // horizontally add 4 int32_t
488
+ static inline int hsum_i32_4(const __m128i a) {
489
+ const __m128i hi64 = _mm_unpackhi_epi64(a, a);
490
+ const __m128i sum64 = _mm_add_epi32(hi64, a);
491
+ const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
492
+ return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
493
+ }
494
+
432
495
  #if __AVX2__ || __AVX512F__
496
+ // spread 32 bits to 32 bytes { 0x00, 0xFF }
497
+ static inline __m256i bytes_from_bits_32(const uint8_t * x) {
498
+ uint32_t x32;
499
+ memcpy(&x32, x, sizeof(uint32_t));
500
+ const __m256i shuf_mask = _mm256_set_epi64x(
501
+ 0x0303030303030303, 0x0202020202020202,
502
+ 0x0101010101010101, 0x0000000000000000);
503
+ __m256i bytes = _mm256_shuffle_epi8(_mm256_set1_epi32(x32), shuf_mask);
504
+ const __m256i bit_mask = _mm256_set1_epi64x(0x7fbfdfeff7fbfdfe);
505
+ bytes = _mm256_or_si256(bytes, bit_mask);
506
+ return _mm256_cmpeq_epi8(bytes, _mm256_set1_epi64x(-1));
507
+ }
508
+
433
509
  // Unpack 32 4-bit fields into 32 bytes
434
510
  // The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
435
- static inline __m256i bytesFromNibbles( const uint8_t* rsi )
511
+ static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)
436
512
  {
437
513
  // Load 16 bytes from memory
438
514
  __m128i tmp = _mm_loadu_si128( ( const __m128i* )rsi );
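
The hunk above renames bytesFromNibbles to bytes_from_nibbles_32, adds a 16-byte SSE variant, the horizontal-sum helpers, and bytes_from_bits_32. The nibble helpers all implement the same layout; a scalar sketch of that layout and its inverse (function names are illustrative, not from the diff):

    #include <stdint.h>

    // Scalar equivalent of bytes_from_nibbles_16/_32: each packed byte holds two
    // quants, the even-indexed quant in the low nibble and the odd-indexed quant
    // in the high nibble, matching the "vi0 | (vi1 << 4)" packing used by the
    // quantize_row_* routines.
    static void unpack_nibbles_ref(const uint8_t * src, uint8_t * dst, int n_packed) {
        for (int j = 0; j < n_packed; ++j) {
            dst[2*j + 0] = src[j] & 0x0F;   // low nibble  -> even output byte
            dst[2*j + 1] = src[j] >> 4;     // high nibble -> odd output byte
        }
    }

    // Scalar equivalent of packNibbles: the exact inverse of the above.
    static void pack_nibbles_ref(const uint8_t * src, uint8_t * dst, int n_packed) {
        for (int j = 0; j < n_packed; ++j) {
            dst[j] = (uint8_t)((src[2*j] & 0x0F) | ((src[2*j + 1] & 0x0F) << 4));
        }
    }
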
@@ -449,9 +525,38 @@ static inline __m256i bytesFromNibbles( const uint8_t* rsi )
449
525
  return bytes;
450
526
  }
451
527
 
528
+ // add int16_t pairwise and return as float vector
529
+ static inline __m256 sum_i16_pairs_float(const __m256i x) {
530
+ const __m256i ones = _mm256_set1_epi16(1);
531
+ const __m256i summed_pairs = _mm256_madd_epi16(ones, x);
532
+ return _mm256_cvtepi32_ps(summed_pairs);
533
+ }
534
+
535
+ // multiply int8_t, add results pairwise twice and return as float vector
536
+ static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
537
+ // Get absolute values of x vectors
538
+ const __m256i ax = _mm256_sign_epi8(x, x);
539
+ // Sign the values of the y vectors
540
+ const __m256i sy = _mm256_sign_epi8(y, x);
541
+ #if __AVXVNNI__
542
+ const __m256i zero = _mm256_setzero_si256();
543
+ const __m256i summed_pairs = _mm256_dpbusd_epi32(zero, ax, sy);
544
+ return _mm256_cvtepi32_ps(summed_pairs);
545
+ #else
546
+ // Perform multiplication and create 16-bit values
547
+ const __m256i dot = _mm256_maddubs_epi16(ax, sy);
548
+ return sum_i16_pairs_float(dot);
549
+ #endif
550
+ }
551
+
452
552
  static inline __m128i packNibbles( __m256i bytes )
453
553
  {
454
554
  // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
555
+ #if __AVX512F__
556
+ const __m256i bytes_srli_4 = _mm256_srli_epi16(bytes, 4); // 0000_0000_abcd_0000
557
+ bytes = _mm256_or_si256(bytes, bytes_srli_4); // 0000_abcd_abcd_efgh
558
+ return _mm256_cvtepi16_epi8(bytes); // abcd_efgh
559
+ #else
455
560
  const __m256i lowByte = _mm256_set1_epi16( 0xFF );
456
561
  __m256i high = _mm256_andnot_si256( lowByte, bytes );
457
562
  __m256i low = _mm256_and_si256( lowByte, bytes );
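
The new mul_sum_i8_pairs_float above is the workhorse of the AVX dot-product paths: maddubs needs an unsigned-times-signed multiply, so the code feeds it |x| and y signed by x, which keeps every product equal to x*y; the two madd steps then collapse groups of four adjacent products into one 32-bit sum before converting to float. A scalar sketch of the value it computes (the function name is illustrative, and edge cases of the intrinsic sign trick are ignored):

    #include <stdint.h>

    // Each of the 8 output floats is the sum of 4 adjacent products x[4k+j]*y[4k+j],
    // which is what the maddubs + madd(ones) sequence produces for 32 int8 inputs.
    static void mul_sum_i8_pairs_ref(const int8_t x[32], const int8_t y[32], float out[8]) {
        for (int k = 0; k < 8; ++k) {
            int32_t acc = 0;
            for (int j = 0; j < 4; ++j) {
                acc += (int32_t)x[4*k + j] * (int32_t)y[4*k + j];
            }
            out[k] = (float)acc;
        }
    }
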
@@ -462,25 +567,9 @@ static inline __m128i packNibbles( __m256i bytes )
462
567
  __m128i r0 = _mm256_castsi256_si128( bytes );
463
568
  __m128i r1 = _mm256_extracti128_si256( bytes, 1 );
464
569
  return _mm_packus_epi16( r0, r1 );
570
+ #endif
465
571
  }
466
- #elif __AVX__
467
- static inline __m128i bytesFromNibbles( const uint8_t* rsi )
468
- {
469
- // Load 8 bytes from memory
470
- __m128i tmp = _mm_loadu_si64( ( const __m128i* )rsi );
471
-
472
- // Expand bytes into uint16_t values
473
- __m128i bytes = _mm_cvtepu8_epi16( tmp );
474
-
475
- // Unpack values into individual bytes
476
- const __m128i lowMask = _mm_set1_epi8( 0xF );
477
- __m128i high = _mm_andnot_si128( lowMask, bytes );
478
- __m128i low = _mm_and_si128( lowMask, bytes );
479
- high = _mm_slli_epi16( high, 4 );
480
- bytes = _mm_or_si128( low, high );
481
- return bytes;
482
- }
483
-
572
+ #else
484
573
  static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
485
574
  {
486
575
  // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
@@ -497,6 +586,7 @@ static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
497
586
  return _mm_packus_epi16( bytes1, bytes2);
498
587
  }
499
588
  #endif
589
+ #endif // __AVX__ || __AVX2__ || __AVX512F__
500
590
 
501
591
  #if __ARM_NEON
502
592
 
@@ -514,6 +604,18 @@ inline static uint16_t vaddvq_u8(uint8x16_t v) {
514
604
  (uint16_t)vgetq_lane_u8(v, 14) + (uint16_t)vgetq_lane_u8(v, 15);
515
605
  }
516
606
 
607
+ inline static int16_t vaddvq_s8(int8x16_t v) {
608
+ return
609
+ (int16_t)vgetq_lane_s8(v, 0) + (int16_t)vgetq_lane_s8(v, 1) +
610
+ (int16_t)vgetq_lane_s8(v, 2) + (int16_t)vgetq_lane_s8(v, 3) +
611
+ (int16_t)vgetq_lane_s8(v, 4) + (int16_t)vgetq_lane_s8(v, 5) +
612
+ (int16_t)vgetq_lane_s8(v, 6) + (int16_t)vgetq_lane_s8(v, 7) +
613
+ (int16_t)vgetq_lane_s8(v, 8) + (int16_t)vgetq_lane_s8(v, 9) +
614
+ (int16_t)vgetq_lane_s8(v, 10) + (int16_t)vgetq_lane_s8(v, 11) +
615
+ (int16_t)vgetq_lane_s8(v, 12) + (int16_t)vgetq_lane_s8(v, 13) +
616
+ (int16_t)vgetq_lane_s8(v, 14) + (int16_t)vgetq_lane_s8(v, 15);
617
+ }
618
+
517
619
  inline static int32_t vaddvq_s16(int16x8_t v) {
518
620
  return
519
621
  (int32_t)vgetq_lane_s16(v, 0) + (int32_t)vgetq_lane_s16(v, 1) +
@@ -583,7 +685,39 @@ typedef struct {
583
685
  float m; // min
584
686
  uint8_t qs[QK4_1 / 2]; // nibbles / quants
585
687
  } block_q4_1;
586
- static_assert(sizeof(block_q4_1) == sizeof(float) * 2 + QK4_1 / 2, "wrong q4_1 block size/padding");
688
+ static_assert(sizeof(block_q4_1) == 2 * sizeof(float) + QK4_1 / 2, "wrong q4_1 block size/padding");
689
+
690
+ #define QK4_2 16
691
+ typedef struct {
692
+ ggml_fp16_t d; // delta
693
+ uint8_t qs[QK4_2 / 2]; // nibbles / quants
694
+ } block_q4_2;
695
+ static_assert(sizeof(block_q4_2) == sizeof(ggml_fp16_t) + QK4_2 / 2, "wrong q4_2 block size/padding");
696
+
697
+ #define QK4_3 16
698
+ typedef struct {
699
+ ggml_fp16_t d; // delta
700
+ ggml_fp16_t m; // min
701
+ uint8_t qs[QK4_3 / 2]; // nibbles / quants
702
+ } block_q4_3;
703
+ static_assert(sizeof(block_q4_3) == 2 * sizeof(ggml_fp16_t) + QK4_3 / 2, "wrong q4_3 block size/padding");
704
+
705
+ #define QK5_0 32
706
+ typedef struct {
707
+ ggml_fp16_t d; // delta
708
+ uint8_t qh[4]; // 5-th bit of quants
709
+ uint8_t qs[QK5_0 / 2]; // nibbles / quants
710
+ } block_q5_0;
711
+ static_assert(sizeof(block_q5_0) == sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_0 / 2, "wrong q5_0 block size/padding");
712
+
713
+ #define QK5_1 32
714
+ typedef struct {
715
+ ggml_fp16_t d; // delta
716
+ ggml_fp16_t m; // min
717
+ uint8_t qh[4]; // 5-th bit of quants
718
+ uint8_t qs[QK5_1 / 2]; // nibbles / quants
719
+ } block_q5_1;
720
+ static_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_1 / 2, "wrong q5_1 block size/padding");
587
721
 
588
722
  #define QK8_0 32
589
723
  typedef struct {
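
The four new block layouts above trade accuracy for size in different ways; their storage cost follows directly from the struct definitions (ggml_fp16_t is 2 bytes, qh is 4 bytes):

    q4_2: 2 + QK4_2/2     = 2 + 8      = 10 bytes per 16 weights = 5.0 bits/weight
    q4_3: 2*2 + QK4_3/2   = 4 + 8      = 12 bytes per 16 weights = 6.0 bits/weight
    q5_0: 2 + 4 + QK5_0/2 = 2 + 4 + 16 = 22 bytes per 32 weights = 5.5 bits/weight
    q5_1: 4 + 4 + QK5_1/2 = 4 + 4 + 16 = 24 bytes per 32 weights = 6.0 bits/weight
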
@@ -592,6 +726,14 @@ typedef struct {
592
726
  } block_q8_0;
593
727
  static_assert(sizeof(block_q8_0) == sizeof(float) + QK8_0, "wrong q8_0 block size/padding");
594
728
 
729
+ #define QK8_1 32
730
+ typedef struct {
731
+ float d; // delta
732
+ float s0; // d * sum(qs[i]) low
733
+ float s1; // d * sum(qs[i]) high
734
+ int8_t qs[QK8_1]; // quants
735
+ } block_q8_1;
736
+ static_assert(sizeof(block_q8_1) == 3*sizeof(float) + QK8_1, "wrong q8_1 block size/padding");
595
737
 
596
738
  // reference implementation for deterministic creation of model files
597
739
  static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k) {
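
The s0 and s1 fields of block_q8_1 above cache d times the sum of each half of the quants. A hedged reading of why (the ggml_vec_dot_*_q8_1 kernels that consume them are only declared, not shown, in this excerpt): for the offset formats the constant term of the weight block factors out of the dot product. For one q4_1 block against one q8_1 block,

    sum_l (d4*q4[l] + m) * (d8*q8[l])
      = d4*d8 * sum_l q4[l]*q8[l]     +  m * (d8 * sum_l q8[l])
      = d4*d8 * (integer dot product) +  m * (s0 + s1)

so the m-dependent part costs one multiply-add per block instead of a second pass over the quants; splitting the sum into halves presumably matches the half-block accumulation visible in the NEON and AVX quantization paths further down.
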
@@ -602,13 +744,17 @@ static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * r
602
744
 
603
745
  for (int i = 0; i < nb; i++) {
604
746
  float amax = 0.0f; // absolute max
747
+ float max = 0.0f;
605
748
 
606
749
  for (int l = 0; l < QK4_0; l++) {
607
750
  const float v = x[i*QK4_0 + l];
608
- amax = MAX(amax, fabsf(v));
751
+ if (amax < fabsf(v)) {
752
+ amax = fabsf(v);
753
+ max = v;
754
+ }
609
755
  }
610
756
 
611
- const float d = amax / ((1 << 3) - 1);
757
+ const float d = max / -8;
612
758
  const float id = d ? 1.0f/d : 0.0f;
613
759
 
614
760
  y[i].d = d;
@@ -617,8 +763,8 @@ static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * r
617
763
  const float v0 = x[i*QK4_0 + l + 0]*id;
618
764
  const float v1 = x[i*QK4_0 + l + 1]*id;
619
765
 
620
- const uint8_t vi0 = (int8_t)roundf(v0) + 8;
621
- const uint8_t vi1 = (int8_t)roundf(v1) + 8;
766
+ const uint8_t vi0 = MIN(15, (int8_t)roundf(v0) + 8);
767
+ const uint8_t vi1 = MIN(15, (int8_t)roundf(v1) + 8);
622
768
 
623
769
  assert(vi0 < 16);
624
770
  assert(vi1 < 16);
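
The reference q4_0 quantizer above changes the scale from amax/7 to max/-8, where max is the signed value with the largest magnitude, so that value now lands exactly on quant -8 and the full [-8..7] grid is used; the MIN(15, ...) clamp catches the one case where a value at the opposite extreme rounds to +8. A scalar sketch of the per-value mapping (assumes max_signed is the signed, largest-magnitude element of the block, exactly as tracked in the loop above):

    #include <math.h>
    #include <stdint.h>

    static uint8_t q4_0_quantize_one(float v, float max_signed) {
        const float d  = max_signed / -8.0f;     // the extreme value maps to quant -8
        const float id = d ? 1.0f/d : 0.0f;
        const int   q  = (int)roundf(v*id) + 8;  // shift [-8..+8] into [0..16]
        return (uint8_t)(q > 15 ? 15 : q);       // clamp the +16 corner case to 15
    }

    // Example: with block extremes -2.0f and +1.0f, max_signed = -2.0f, d = 0.25f,
    // so -2.0f encodes as 0 and +1.0f as 12. If the largest-magnitude value is
    // +2.0f and the block also contains -2.0f, then d = -0.25f and -2.0f rounds
    // to 16, which the clamp turns into 15.
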
@@ -638,28 +784,42 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
638
784
 
639
785
  #if defined(__POWER9_VECTOR__)
640
786
  const vector float v85 = vec_splats(8.5f);
787
+ const vector signed int v15 = vec_splats(15);
641
788
  for (int i = 0; i < nb; i++) {
642
- float amax = 0.0f; // absolute max
789
+ float max = 0.0f;
790
+ float min = 0.0f;
643
791
 
644
792
  vector float srcv [8];
645
- vector float asrcv[8];
646
- vector float amaxv[8];
793
+ vector float maxv[8];
794
+ vector float minv[8];
647
795
 
648
796
  for (int l = 0; l < 8; l++) srcv[l] = *(vector float *)(x + i*32 + 4*l);
649
- for (int l = 0; l < 8; l++) asrcv[l] = vec_abs(srcv[l]);
650
-
651
- for (int l = 0; l < 4; l++) amaxv[2*l] = vec_max(asrcv[2*l], asrcv[2*l+1]);
652
- //for (int l = 0; l < 2; l++) amaxv[4*l] = vec_max(amaxv[4*l], amaxv[4*l+2]);
653
- amaxv[0] = vec_max(amaxv[0], amaxv[2]);
654
- amaxv[4] = vec_max(amaxv[4], amaxv[6]);
655
- //for (int l = 0; l < 1; l++) amaxv[8*l] = vec_max(amaxv[8*l], amaxv[8*l+4]);
656
- amaxv[0] = vec_max(amaxv[0], amaxv[4]);
657
-
658
- amax = MAX(
659
- MAX(vec_extract(amaxv[0], 0), vec_extract(amaxv[0], 1)),
660
- MAX(vec_extract(amaxv[0], 2), vec_extract(amaxv[0], 3)));
661
-
662
- const float d = amax / ((1 << 3) - 1);
797
+ //for (int l = 0; l < 8; l++) asrcv[l] = vec_abs(srcv[l]);
798
+
799
+ for (int l = 0; l < 4; l++) maxv[2*l] = vec_max(asrcv[2*l], asrcv[2*l+1]);
800
+ //for (int l = 0; l < 2; l++) maxv[4*l] = vec_max(maxv[4*l], maxv[4*l+2]);
801
+ maxv[0] = vec_max(maxv[0], maxv[2]);
802
+ maxv[4] = vec_max(maxv[4], maxv[6]);
803
+ //for (int l = 0; l < 1; l++) maxv[8*l] = vec_max(maxv[8*l], maxv[8*l+4]);
804
+ maxv[0] = vec_max(maxv[0], maxv[4]);
805
+
806
+ for (int l = 0; l < 4; l++) minv[2*l] = vec_min(asrcv[2*l], asrcv[2*l+1]);
807
+ //for (int l = 0; l < 2; l++) minv[4*l] = vec_min(minv[4*l], minv[4*l+2]);
808
+ minv[0] = vec_min(minv[0], minv[2]);
809
+ minv[4] = vec_min(minv[4], minv[6]);
810
+ //for (int l = 0; l < 1; l++) minv[8*l] = vec_min(minv[8*l], minv[8*l+4]);
811
+ minv[0] = vec_min(minv[0], minv[4]);
812
+
813
+
814
+ max = MAX(
815
+ MAX(vec_extract(maxv[0], 0), vec_extract(maxv[0], 1)),
816
+ MAX(vec_extract(maxv[0], 2), vec_extract(maxv[0], 3)));
817
+ min = MIN(
818
+ MIN(vec_extract(minv[0], 0), vec_extract(minv[0], 1)),
819
+ MIN(vec_extract(minv[0], 2), vec_extract(minv[0], 3)));
820
+
821
+ const float magnitude = max >= fabsf(min) ? max : min;
822
+ const float d = magnitude / -8;
663
823
  const float id = d ? 1.0/d : 0.0;
664
824
 
665
825
  y[i].d = d;
@@ -669,27 +829,33 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
669
829
  for (int l = 0; l < 8; l++) {
670
830
  const vector float vf = vec_madd(srcv[l], vid, v85);
671
831
  const vector signed int vi = vec_signed(vf);
832
+ const vector signed int vc = vec_min(vi, v15);
672
833
 
673
- pb[2*l + 0] = vec_extract(vi, 0) | (vec_extract(vi, 1) << 4);
674
- pb[2*l + 1] = vec_extract(vi, 2) | (vec_extract(vi, 3) << 4);
834
+ pb[2*l + 0] = vec_extract(vc, 0) | (vec_extract(vc, 1) << 4);
835
+ pb[2*l + 1] = vec_extract(vc, 2) | (vec_extract(vc, 3) << 4);
675
836
  }
676
837
  }
677
838
  #elif __ARM_NEON
678
839
  for (int i = 0; i < nb; i++) {
679
840
  float32x4_t srcv [8];
680
- float32x4_t asrcv[8];
681
- float32x4_t amaxv[8];
841
+ float32x4_t maxv[8];
842
+ float32x4_t minv[8];
682
843
 
683
844
  for (int l = 0; l < 8; l++) srcv[l] = vld1q_f32(x + i*32 + 4*l);
684
- for (int l = 0; l < 8; l++) asrcv[l] = vabsq_f32(srcv[l]);
685
845
 
686
- for (int l = 0; l < 4; l++) amaxv[2*l] = vmaxq_f32(asrcv[2*l], asrcv[2*l+1]);
687
- for (int l = 0; l < 2; l++) amaxv[4*l] = vmaxq_f32(amaxv[4*l], amaxv[4*l+2]);
688
- for (int l = 0; l < 1; l++) amaxv[8*l] = vmaxq_f32(amaxv[8*l], amaxv[8*l+4]);
846
+ for (int l = 0; l < 4; l++) maxv[2*l] = vmaxq_f32(srcv[2*l], srcv[2*l+1]);
847
+ for (int l = 0; l < 2; l++) maxv[4*l] = vmaxq_f32(maxv[4*l], maxv[4*l+2]);
848
+ for (int l = 0; l < 1; l++) maxv[8*l] = vmaxq_f32(maxv[8*l], maxv[8*l+4]);
689
849
 
690
- const float amax = vmaxvq_f32(amaxv[0]);
850
+ for (int l = 0; l < 4; l++) minv[2*l] = vminq_f32(srcv[2*l], srcv[2*l+1]);
851
+ for (int l = 0; l < 2; l++) minv[4*l] = vminq_f32(minv[4*l], minv[4*l+2]);
852
+ for (int l = 0; l < 1; l++) minv[8*l] = vminq_f32(minv[8*l], minv[8*l+4]);
853
+
854
+ const float max = vmaxvq_f32(maxv[0]);
855
+ const float min = vminvq_f32(minv[0]);
691
856
 
692
- const float d = amax / ((1 << 3) - 1);
857
+ const float magnitude = max >= fabsf(min) ? max : min;
858
+ const float d = magnitude / -8;
693
859
  const float id = d ? 1.0f/d : 0.0f;
694
860
 
695
861
  y[i].d = d;
@@ -698,9 +864,10 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
698
864
  const float32x4_t v = vmulq_n_f32(srcv[l], id);
699
865
  const float32x4_t vf = vaddq_f32(v, vdupq_n_f32(8.5f));
700
866
  const int32x4_t vi = vcvtq_s32_f32(vf);
867
+ const int32x4_t vc = vminq_s32(vi, vdupq_n_s32(15));
701
868
 
702
- y[i].qs[2*l + 0] = vgetq_lane_s32(vi, 0) | (vgetq_lane_s32(vi, 1) << 4);
703
- y[i].qs[2*l + 1] = vgetq_lane_s32(vi, 2) | (vgetq_lane_s32(vi, 3) << 4);
869
+ y[i].qs[2*l + 0] = vgetq_lane_s32(vc, 0) | (vgetq_lane_s32(vc, 1) << 4);
870
+ y[i].qs[2*l + 1] = vgetq_lane_s32(vc, 2) | (vgetq_lane_s32(vc, 3) << 4);
704
871
  }
705
872
  }
706
873
  #elif defined(__AVX2__)
@@ -712,22 +879,31 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
712
879
  __m256 v3 = _mm256_loadu_ps( x + 24 );
713
880
  x += 32;
714
881
 
715
- // Compute max(abs(e)) for the block
716
- const __m256 signBit = _mm256_set1_ps( -0.0f );
717
- __m256 maxAbs = _mm256_andnot_ps( signBit, v0 );
718
- maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) );
719
- maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );
720
- maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) );
882
+ // Compute max for the block
883
+ __m256 max = _mm256_max_ps( v0, v1 );
884
+ __m256 maxTmp = _mm256_max_ps( v2, v3 );
885
+ max = _mm256_max_ps( max, maxTmp );
721
886
 
722
- __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
887
+ __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( max, 1 ), _mm256_castps256_ps128( max ) );
723
888
  max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
724
889
  max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
725
890
  const float maxScalar = _mm_cvtss_f32( max4 );
726
891
 
892
+ // Compute min for the block
893
+ __m256 min = _mm256_min_ps( v0, v1 );
894
+ __m256 minTmp = _mm256_min_ps( v2, v3 );
895
+ min = _mm256_min_ps( min, minTmp );
896
+
897
+ __m128 min4 = _mm_min_ps( _mm256_extractf128_ps( min, 1 ), _mm256_castps256_ps128( min ) );
898
+ min4 = _mm_min_ps( min4, _mm_movehl_ps( min4, min4 ) );
899
+ min4 = _mm_min_ss( min4, _mm_movehdup_ps( min4 ) );
900
+ const float minScalar = _mm_cvtss_f32( min4 );
901
+
727
902
  // Quantize these floats
728
- const float d = maxScalar / 7.0f;
903
+ const float magnitude = maxScalar >= fabsf(minScalar) ? maxScalar : minScalar;
904
+ const float d = magnitude / -8.0f;
729
905
  y[i].d = d;
730
- const float id = ( maxScalar != 0.0f ) ? 7.0f / maxScalar : 0.0f;
906
+ const float id = ( magnitude != 0.0f ) ? -8.0f / magnitude : 0.0f;
731
907
  const __m256 mul = _mm256_set1_ps( id );
732
908
 
733
909
  // Apply the multiplier
@@ -760,9 +936,11 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
760
936
  const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
761
937
  i0 = _mm256_permutevar8x32_epi32( i0, perm );
762
938
 
763
- // Apply offset to translate the range from [ -7 .. +7 ] into [ +1 .. +15 ]
939
+ // Apply offset and clamp to translate the range from [ -8 .. +8 ] into [ +0 .. +15 ]
764
940
  const __m256i off = _mm256_set1_epi8( 8 );
765
941
  i0 = _mm256_add_epi8( i0, off );
942
+ const __m256i maxNibble = _mm256_set1_epi8( 15 );
943
+ i0 = _mm256_min_epi8( i0, maxNibble );
766
944
 
767
945
  // Compress the vector into 4 bit/value, and store
768
946
  __m128i res = packNibbles( i0 );
@@ -777,22 +955,31 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
777
955
  __m256 v3 = _mm256_loadu_ps( x + 24 );
778
956
  x += 32;
779
957
 
780
- // Compute max(abs(e)) for the block
781
- const __m256 signBit = _mm256_set1_ps( -0.0f );
782
- __m256 maxAbs = _mm256_andnot_ps( signBit, v0 );
783
- maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) );
784
- maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );
785
- maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) );
958
+ // Compute max for the block
959
+ __m256 max = _mm256_max_ps( v0, v1 );
960
+ __m256 maxTmp = _mm256_max_ps( v2, v3 );
961
+ max = _mm256_max_ps( max, maxTmp );
786
962
 
787
- __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
963
+ __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( max, 1 ), _mm256_castps256_ps128( max ) );
788
964
  max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
789
965
  max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
790
966
  const float maxScalar = _mm_cvtss_f32( max4 );
791
967
 
968
+ // Compute min for the block
969
+ __m256 min = _mm256_min_ps( v0, v1 );
970
+ __m256 minTmp = _mm256_min_ps( v2, v3 );
971
+ min = _mm256_min_ps( min, minTmp );
972
+
973
+ __m128 min4 = _mm_min_ps( _mm256_extractf128_ps( min, 1 ), _mm256_castps256_ps128( min ) );
974
+ min4 = _mm_min_ps( min4, _mm_movehl_ps( min4, min4 ) );
975
+ min4 = _mm_min_ss( min4, _mm_movehdup_ps( min4 ) );
976
+ const float minScalar = _mm_cvtss_f32( min4 );
977
+
792
978
  // Quantize these floats
793
- const float d = maxScalar / 7.0f;
979
+ const float magnitude = maxScalar >= fabsf(minScalar) ? maxScalar : minScalar;
980
+ const float d = magnitude / -8.0f;
794
981
  y[i].d = d;
795
- const float id = ( maxScalar != 0.0f ) ? 7.0f / maxScalar : 0.0f;
982
+ const float id = ( magnitude != 0.0f ) ? -8.0f / magnitude : 0.0f;
796
983
  const __m256 mul = _mm256_set1_ps( id );
797
984
 
798
985
  // Apply the multiplier
@@ -833,10 +1020,13 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
833
1020
  ni0 = _mm_packs_epi16( ni0, ni2 );
834
1021
  ni4 = _mm_packs_epi16( ni4, ni6 );
835
1022
 
836
- // Apply offset to translate the range from [ -7 .. +7 ] into [ +1 .. +15 ]
837
- const __m128i off = _mm_set1_epi8( 8);
1023
+ // Apply offset and clamp to translate the range from [ -8 .. +8 ] into [ +0 .. +15 ]
1024
+ const __m128i off = _mm_set1_epi8( 8 );
838
1025
  ni0 = _mm_add_epi8( ni0, off );
839
1026
  ni4 = _mm_add_epi8( ni4, off );
1027
+ const __m128i maxNibble = _mm_set1_epi8( 15 );
1028
+ ni0 = _mm_min_epi8( ni0, maxNibble );
1029
+ ni4 = _mm_min_epi8( ni4, maxNibble );
840
1030
 
841
1031
  // Compress the vector into 4 bit/value, and store
842
1032
  __m128i res = packNibbles( ni0, ni4 );
@@ -844,24 +1034,32 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
844
1034
  }
845
1035
  #elif defined(__wasm_simd128__)
846
1036
  for (int i = 0; i < nb; i++) {
847
- float amax = 0.0f; // absolute max
1037
+ float max = 0.0f;
1038
+ float min = 0.0f;
848
1039
 
849
1040
  v128_t srcv [8];
850
- v128_t asrcv[8];
851
- v128_t amaxv[8];
1041
+ v128_t maxv[8];
1042
+ v128_t minv[8];
852
1043
 
853
1044
  for (int l = 0; l < 8; l++) srcv[l] = wasm_v128_load(x + i*32 + 4*l);
854
- for (int l = 0; l < 8; l++) asrcv[l] = wasm_f32x4_abs(srcv[l]);
855
1045
 
856
- for (int l = 0; l < 4; l++) amaxv[2*l] = wasm_f32x4_max(asrcv[2*l], asrcv[2*l+1]);
857
- for (int l = 0; l < 2; l++) amaxv[4*l] = wasm_f32x4_max(amaxv[4*l], amaxv[4*l+2]);
858
- for (int l = 0; l < 1; l++) amaxv[8*l] = wasm_f32x4_max(amaxv[8*l], amaxv[8*l+4]);
1046
+ for (int l = 0; l < 4; l++) maxv[2*l] = wasm_f32x4_max(srcv[2*l], srcv[2*l+1]);
1047
+ for (int l = 0; l < 2; l++) maxv[4*l] = wasm_f32x4_max(maxv[4*l], maxv[4*l+2]);
1048
+ for (int l = 0; l < 1; l++) maxv[8*l] = wasm_f32x4_max(maxv[8*l], maxv[8*l+4]);
1049
+
1050
+ for (int l = 0; l < 4; l++) minv[2*l] = wasm_f32x4_min(srcv[2*l], srcv[2*l+1]);
1051
+ for (int l = 0; l < 2; l++) minv[4*l] = wasm_f32x4_min(minv[4*l], minv[4*l+2]);
1052
+ for (int l = 0; l < 1; l++) minv[8*l] = wasm_f32x4_min(minv[8*l], minv[8*l+4]);
859
1053
 
860
- amax = MAX(
861
- MAX(wasm_f32x4_extract_lane(amaxv[0], 0), wasm_f32x4_extract_lane(amaxv[0], 1)),
862
- MAX(wasm_f32x4_extract_lane(amaxv[0], 2), wasm_f32x4_extract_lane(amaxv[0], 3)));
1054
+ max = MAX(
1055
+ MAX(wasm_f32x4_extract_lane(maxv[0], 0), wasm_f32x4_extract_lane(maxv[0], 1)),
1056
+ MAX(wasm_f32x4_extract_lane(maxv[0], 2), wasm_f32x4_extract_lane(maxv[0], 3)));
1057
+ min = MIN(
1058
+ MIN(wasm_f32x4_extract_lane(minv[0], 0), wasm_f32x4_extract_lane(minv[0], 1)),
1059
+ MIN(wasm_f32x4_extract_lane(minv[0], 2), wasm_f32x4_extract_lane(minv[0], 3)));
863
1060
 
864
- const float d = amax / ((1 << 3) - 1);
1061
+ const float magnitude = max >= fabsf(min) ? max : min;
1062
+ const float d = magnitude / -8;
865
1063
  const float id = d ? 1.0/d : 0.0;
866
1064
 
867
1065
  y[i].d = d;
@@ -870,9 +1068,10 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
870
1068
  const v128_t v = wasm_f32x4_mul(srcv[l], wasm_f32x4_splat(id));
871
1069
  const v128_t vf = wasm_f32x4_add(v, wasm_f32x4_splat(8.5f));
872
1070
  const v128_t vi = wasm_i32x4_trunc_sat_f32x4(vf);
1071
+ const v128_t vc = wasm_i32x4_min_u(vi, wasm_i32x4_splat(15));
873
1072
 
874
- y[i].qs[2*l + 0] = wasm_i32x4_extract_lane(vi, 0) | (wasm_i32x4_extract_lane(vi, 1) << 4);
875
- y[i].qs[2*l + 1] = wasm_i32x4_extract_lane(vi, 2) | (wasm_i32x4_extract_lane(vi, 3) << 4);
1073
+ y[i].qs[2*l + 0] = wasm_i32x4_extract_lane(vc, 0) | (wasm_i32x4_extract_lane(vc, 1) << 4);
1074
+ y[i].qs[2*l + 1] = wasm_i32x4_extract_lane(vc, 2) | (wasm_i32x4_extract_lane(vc, 3) << 4);
876
1075
  }
877
1076
  }
878
1077
  #else
@@ -1045,6 +1244,193 @@ static void quantize_row_q4_1(const float * restrict x, void * restrict vy, int
1045
1244
  #endif
1046
1245
  }
1047
1246
 
1247
+ // reference implementation for deterministic creation of model files
1248
+ static void quantize_row_q4_2_reference(const float * restrict x, block_q4_2 * restrict y, int k) {
1249
+ assert(k % QK4_2 == 0);
1250
+
1251
+ const int nb = k / QK4_2;
1252
+
1253
+ for (int i = 0; i < nb; i++) {
1254
+ float amax = 0.0f; // absolute max
1255
+ float max = 0.0f;
1256
+
1257
+ for (int l = 0; l < QK4_2; l++) {
1258
+ const float v = x[i*QK4_2 + l];
1259
+ if (amax < fabsf(v)) {
1260
+ amax = fabsf(v);
1261
+ max = v;
1262
+ }
1263
+ }
1264
+
1265
+ const float d = max / -8;
1266
+
1267
+ const float id = d ? 1.0f/d : 0.0f;
1268
+
1269
+ y[i].d = GGML_FP32_TO_FP16(d);
1270
+
1271
+ for (int l = 0; l < QK4_2; l += 2) {
1272
+ const float v0 = x[i*QK4_2 + l + 0]*id;
1273
+ const float v1 = x[i*QK4_2 + l + 1]*id;
1274
+
1275
+ const uint8_t vi0 = MIN(15, (uint8_t)(v0 + 8.5f));
1276
+ const uint8_t vi1 = MIN(15, (uint8_t)(v1 + 8.5f));
1277
+
1278
+ assert(vi0 < 16);
1279
+ assert(vi1 < 16);
1280
+
1281
+ y[i].qs[l/2] = vi0 | (vi1 << 4);
1282
+ }
1283
+ }
1284
+ }
1285
+
1286
+ static void quantize_row_q4_2(const float * restrict x, void * restrict vy, int k) {
1287
+ assert(k % QK4_2 == 0);
1288
+
1289
+ block_q4_2 * restrict y = vy;
1290
+
1291
+ quantize_row_q4_2_reference(x, y, k);
1292
+ }
1293
+
1294
+ static void quantize_row_q4_3_reference(const float * restrict x, block_q4_3 * restrict y, int k) {
1295
+ assert(k % QK4_3 == 0);
1296
+ const int nb = k / QK4_3;
1297
+
1298
+ for (int i = 0; i < nb; i++) {
1299
+ float min = FLT_MAX;
1300
+ float max = -FLT_MAX;
1301
+
1302
+ for (int l = 0; l < QK4_3; l++) {
1303
+ const float v = x[i*QK4_3 + l];
1304
+ if (v < min) min = v;
1305
+ if (v > max) max = v;
1306
+ }
1307
+
1308
+ const float d = (max - min) / ((1 << 4) - 1);
1309
+ const float id = d ? 1.0f/d : 0.0f;
1310
+
1311
+ y[i].d = GGML_FP32_TO_FP16(d);
1312
+ y[i].m = GGML_FP32_TO_FP16(min);
1313
+
1314
+ for (int l = 0; l < QK4_3; l += 2) {
1315
+ const float v0 = (x[i*QK4_3 + l + 0] - min)*id;
1316
+ const float v1 = (x[i*QK4_3 + l + 1] - min)*id;
1317
+
1318
+ const uint8_t vi0 = (int) (v0 + 0.5f);
1319
+ const uint8_t vi1 = (int) (v1 + 0.5f);
1320
+
1321
+ assert(vi0 < 16);
1322
+ assert(vi1 < 16);
1323
+
1324
+ y[i].qs[l/2] = vi0 | (vi1 << 4);
1325
+ }
1326
+ }
1327
+ }
1328
+
1329
+ static void quantize_row_q4_3(const float * restrict x, void * restrict vy, int k) {
1330
+ assert(k % QK4_3 == 0);
1331
+
1332
+ block_q4_3 * restrict y = vy;
1333
+
1334
+ quantize_row_q4_3_reference(x, y, k);
1335
+ }
1336
+
1337
+ static void quantize_row_q5_0_reference(const float * restrict x, block_q5_0 * restrict y, int k) {
1338
+ assert(k % QK5_0 == 0);
1339
+ const int nb = k / QK5_0;
1340
+
1341
+ for (int i = 0; i < nb; i++) {
1342
+ float amax = 0.0f; // absolute max
1343
+ float max = 0.0f;
1344
+
1345
+ for (int l = 0; l < QK5_0; l++) {
1346
+ const float v = x[i*QK5_0 + l];
1347
+ if (amax < fabsf(v)) {
1348
+ amax = fabsf(v);
1349
+ max = v;
1350
+ }
1351
+ }
1352
+
1353
+ const float d = max / -16;
1354
+ const float id = d ? 1.0f/d : 0.0f;
1355
+
1356
+ y[i].d = GGML_FP32_TO_FP16(d);
1357
+
1358
+ uint32_t qh = 0;
1359
+
1360
+ for (int l = 0; l < QK5_0; l += 2) {
1361
+ const float v0 = x[i*QK5_0 + l + 0]*id;
1362
+ const float v1 = x[i*QK5_0 + l + 1]*id;
1363
+
1364
+ const uint32_t vi0 = MIN(31, (int) (v0 + 16.5f));
1365
+ const uint32_t vi1 = MIN(31, (int) (v1 + 16.5f));
1366
+
1367
+ y[i].qs[l/2] = (vi0 & 0x0F) | ((vi1 & 0x0F) << 4);
1368
+
1369
+ // get the 5-th bit and store it in qh at the right position
1370
+ qh |= ((vi0 & 0x10) >> 4) << (l + 0);
1371
+ qh |= ((vi1 & 0x10) >> 4) << (l + 1);
1372
+ }
1373
+
1374
+ memcpy(&y[i].qh, &qh, sizeof(y[i].qh));
1375
+ }
1376
+ }
1377
+
1378
+ static void quantize_row_q5_0(const float * restrict x, void * restrict vy, int k) {
1379
+ assert(k % QK5_0 == 0);
1380
+
1381
+ block_q5_0 * restrict y = vy;
1382
+
1383
+ quantize_row_q5_0_reference(x, y, k);
1384
+ }
1385
+
1386
+ static void quantize_row_q5_1_reference(const float * restrict x, block_q5_1 * restrict y, int k) {
1387
+ assert(k % QK5_1 == 0);
1388
+ const int nb = k / QK5_1;
1389
+
1390
+ for (int i = 0; i < nb; i++) {
1391
+ float min = FLT_MAX;
1392
+ float max = -FLT_MAX;
1393
+
1394
+ for (int l = 0; l < QK5_1; l++) {
1395
+ const float v = x[i*QK5_1 + l];
1396
+ if (v < min) min = v;
1397
+ if (v > max) max = v;
1398
+ }
1399
+
1400
+ const float d = (max - min) / ((1 << 5) - 1);
1401
+ const float id = d ? 1.0f/d : 0.0f;
1402
+
1403
+ y[i].d = GGML_FP32_TO_FP16(d);
1404
+ y[i].m = GGML_FP32_TO_FP16(min);
1405
+
1406
+ uint32_t qh = 0;
1407
+
1408
+ for (int l = 0; l < QK5_1; l += 2) {
1409
+ const float v0 = (x[i*QK5_1 + l + 0] - min)*id;
1410
+ const float v1 = (x[i*QK5_1 + l + 1] - min)*id;
1411
+
1412
+ const uint32_t vi0 = (int) (v0 + 0.5f);
1413
+ const uint32_t vi1 = (int) (v1 + 0.5f);
1414
+
1415
+ y[i].qs[l/2] = (vi0 & 0x0F) | ((vi1 & 0x0F) << 4);
1416
+
1417
+ // get the 5-th bit and store it in qh at the right position
1418
+ qh |= ((vi0 & 0x10) >> 4) << (l + 0);
1419
+ qh |= ((vi1 & 0x10) >> 4) << (l + 1);
1420
+ }
1421
+
1422
+ memcpy(&y[i].qh, &qh, sizeof(y[i].qh));
1423
+ }
1424
+ }
1425
+
1426
+ static void quantize_row_q5_1(const float * restrict x, void * restrict vy, int k) {
1427
+ assert(k % QK5_1 == 0);
1428
+
1429
+ block_q5_1 * restrict y = vy;
1430
+
1431
+ quantize_row_q5_1_reference(x, y, k);
1432
+ }
1433
+
1048
1434
  // reference implementation for deterministic creation of model files
1049
1435
  static void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * restrict y, int k) {
1050
1436
  assert(k % QK8_0 == 0);
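
The q5_0/q5_1 quantizers above split each 5-bit quant across two places: the low four bits go into the shared qs nibble array (even quant in the low nibble) and the fifth bit goes into bit l of the 32-bit qh word. A scalar sketch of the round trip (the helper name is illustrative; the dequantizers later in the diff perform exactly this recombination before applying d, or d and m):

    #include <stdint.h>

    // Recover the raw 5-bit quant (0..31) for element l of a q5_0/q5_1 block;
    // q5_0 then uses (quant - 16)*d, while q5_1 uses quant*d + m.
    static int unpack_q5_quant(const uint8_t * qs, uint32_t qh, int l) {
        const uint8_t nib  = (l & 1) ? (qs[l/2] >> 4) : (qs[l/2] & 0x0F); // low 4 bits
        const uint8_t bit5 = (uint8_t)((qh >> l) & 1);                    // 5th bit
        return (int)(nib | (bit5 << 4));
    }
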
@@ -1064,18 +1450,64 @@ static void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * r
1064
1450
  y[i].d = d;
1065
1451
 
1066
1452
  for (int l = 0; l < QK8_0; ++l) {
1067
- const float v = x[i*QK8_0 + l]*id;
1068
- y[i].qs[l] = roundf(v);
1453
+ const float v0 = x[i*QK8_0 + l]*id;
1454
+
1455
+ y[i].qs[l] = roundf(v0);
1069
1456
  }
1070
1457
  }
1071
1458
  }
1072
1459
 
1073
1460
  static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int k) {
1074
1461
  assert(k % QK8_0 == 0);
1075
- const int nb = k / QK8_0;
1076
1462
 
1077
1463
  block_q8_0 * restrict y = vy;
1078
1464
 
1465
+ quantize_row_q8_0_reference(x, y, k);
1466
+ }
1467
+
1468
+ // reference implementation for deterministic creation of model files
1469
+ static void quantize_row_q8_1_reference(const float * restrict x, block_q8_1 * restrict y, int k) {
1470
+ assert(k % QK8_1 == 0);
1471
+ const int nb = k / QK8_1;
1472
+
1473
+ for (int i = 0; i < nb; i++) {
1474
+ float amax = 0.0f; // absolute max
1475
+
1476
+ for (int l = 0; l < QK8_1; l++) {
1477
+ const float v = x[i*QK8_1 + l];
1478
+ amax = MAX(amax, fabsf(v));
1479
+ }
1480
+
1481
+ const float d = amax / ((1 << 7) - 1);
1482
+ const float id = d ? 1.0f/d : 0.0f;
1483
+
1484
+ y[i].d = d;
1485
+
1486
+ int sum0 = 0;
1487
+ int sum1 = 0;
1488
+
1489
+ for (int l = 0; l < QK8_1/2; ++l) {
1490
+ const float v0 = x[i*QK8_1 + l]*id;
1491
+ const float v1 = x[i*QK8_1 + QK8_1/2 + l]*id;
1492
+
1493
+ y[i].qs[ l] = roundf(v0);
1494
+ y[i].qs[QK8_1/2 + l] = roundf(v1);
1495
+
1496
+ sum0 += y[i].qs[ l];
1497
+ sum1 += y[i].qs[QK8_1/2 + l];
1498
+ }
1499
+
1500
+ y[i].s0 = d * sum0;
1501
+ y[i].s1 = d * sum1;
1502
+ }
1503
+ }
1504
+
1505
+ static void quantize_row_q8_1(const float * restrict x, void * restrict vy, int k) {
1506
+ assert(k % QK8_1 == 0);
1507
+ const int nb = k / QK8_1;
1508
+
1509
+ block_q8_1 * restrict y = vy;
1510
+
1079
1511
  #if defined(__ARM_NEON)
1080
1512
  for (int i = 0; i < nb; i++) {
1081
1513
  float32x4_t srcv [8];
@@ -1096,7 +1528,11 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
1096
1528
 
1097
1529
  y[i].d = d;
1098
1530
 
1099
- for (int l = 0; l < 8; l++) {
1531
+ int32x4_t accv0 = vdupq_n_s32(0);
1532
+ int32x4_t accv1 = vdupq_n_s32(0);
1533
+
1534
+ // low half
1535
+ for (int l = 0; l < 4; l++) {
1100
1536
  const float32x4_t v = vmulq_n_f32(srcv[l], id);
1101
1537
  const int32x4_t vi = vcvtnq_s32_f32(v);
1102
1538
 
@@ -1104,19 +1540,40 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
1104
1540
  y[i].qs[4*l + 1] = vgetq_lane_s32(vi, 1);
1105
1541
  y[i].qs[4*l + 2] = vgetq_lane_s32(vi, 2);
1106
1542
  y[i].qs[4*l + 3] = vgetq_lane_s32(vi, 3);
1543
+
1544
+ accv0 = vaddq_s32(accv0, vi);
1107
1545
  }
1108
- }
1109
- #elif defined(__AVX2__) || defined(__AVX__)
1110
- for (int i = 0; i < nb; i++) {
1111
- // Load elements into 4 AVX vectors
1112
- __m256 v0 = _mm256_loadu_ps( x );
1113
- __m256 v1 = _mm256_loadu_ps( x + 8 );
1114
- __m256 v2 = _mm256_loadu_ps( x + 16 );
1115
- __m256 v3 = _mm256_loadu_ps( x + 24 );
1116
- x += 32;
1117
1546
 
1118
- // Compute max(abs(e)) for the block
1119
- const __m256 signBit = _mm256_set1_ps( -0.0f );
1547
+ // high half
1548
+ for (int l = 4; l < 8; l++) {
1549
+ const float32x4_t v = vmulq_n_f32(srcv[l], id);
1550
+ const int32x4_t vi = vcvtnq_s32_f32(v);
1551
+
1552
+ y[i].qs[4*l + 0] = vgetq_lane_s32(vi, 0);
1553
+ y[i].qs[4*l + 1] = vgetq_lane_s32(vi, 1);
1554
+ y[i].qs[4*l + 2] = vgetq_lane_s32(vi, 2);
1555
+ y[i].qs[4*l + 3] = vgetq_lane_s32(vi, 3);
1556
+
1557
+ accv1 = vaddq_s32(accv1, vi);
1558
+ }
1559
+
1560
+ const int32_t sum0 = vaddvq_s32(accv0);
1561
+ const int32_t sum1 = vaddvq_s32(accv1);
1562
+
1563
+ y[i].s0 = d * sum0;
1564
+ y[i].s1 = d * sum1;
1565
+ }
1566
+ #elif defined(__AVX2__) || defined(__AVX__)
1567
+ for (int i = 0; i < nb; i++) {
1568
+ // Load elements into 4 AVX vectors
1569
+ __m256 v0 = _mm256_loadu_ps( x );
1570
+ __m256 v1 = _mm256_loadu_ps( x + 8 );
1571
+ __m256 v2 = _mm256_loadu_ps( x + 16 );
1572
+ __m256 v3 = _mm256_loadu_ps( x + 24 );
1573
+ x += 32;
1574
+
1575
+ // Compute max(abs(e)) for the block
1576
+ const __m256 signBit = _mm256_set1_ps( -0.0f );
1120
1577
  __m256 maxAbs = _mm256_andnot_ps( signBit, v0 );
1121
1578
  maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) );
1122
1579
  maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );
@@ -1152,6 +1609,11 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
1152
1609
  __m256i i3 = _mm256_cvtps_epi32( v3 );
1153
1610
 
1154
1611
  #if defined(__AVX2__)
1612
+ // Compute the sum of the quants and set y[i].s
1613
+ //y[i].s = d * hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3)));
1614
+ y[i].s0 = d * hsum_i32_8(_mm256_add_epi32(i0, i1));
1615
+ y[i].s1 = d * hsum_i32_8(_mm256_add_epi32(i2, i3));
1616
+
1155
1617
  // Convert int32 to int16
1156
1618
  i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15
1157
1619
  i2 = _mm256_packs_epi32( i2, i3 ); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31
@@ -1177,6 +1639,12 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
1177
1639
  __m128i ni6 = _mm256_castsi256_si128( i3 );
1178
1640
  __m128i ni7 = _mm256_extractf128_si256( i3, 1);
1179
1641
 
1642
+ // Compute the sum of the quants and set y[i].s
1643
+ const __m128i s0 = _mm_add_epi32(_mm_add_epi32(ni0, ni1), _mm_add_epi32(ni2, ni3));
1644
+ const __m128i s1 = _mm_add_epi32(_mm_add_epi32(ni4, ni5), _mm_add_epi32(ni6, ni7));
1645
+ y[i].s0 = d * hsum_i32_4(s0);
1646
+ y[i].s1 = d * hsum_i32_4(s1);
1647
+
1180
1648
  // Convert int32 to int16
1181
1649
  ni0 = _mm_packs_epi32( ni0, ni1 );
1182
1650
  ni2 = _mm_packs_epi32( ni2, ni3 );
@@ -1192,7 +1660,7 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
1192
1660
  }
1193
1661
  #else
1194
1662
  // scalar
1195
- quantize_row_q8_0_reference(x, y, k);
1663
+ quantize_row_q8_1_reference(x, y, k);
1196
1664
  #endif
1197
1665
  }
1198
1666
 
@@ -1211,7 +1679,7 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in
1211
1679
 
1212
1680
  for (int l = 0; l < QK4_0; l += 32) {
1213
1681
  // Load 32x4-bit integers into 32x8-bit integers
1214
- __m256i vx8 = bytesFromNibbles(pp+l/2);
1682
+ __m256i vx8 = bytes_from_nibbles_32(pp+l/2);
1215
1683
 
1216
1684
  // Subtract 8 from the integers
1217
1685
  vx8 = _mm256_sub_epi8(vx8, _mm256_set1_epi8(8));
@@ -1246,7 +1714,7 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in
1246
1714
  const uint8x8_t v8 = vld1_u8(pp + l/2);
1247
1715
 
1248
1716
  // Expand 4-bit qs to 8-bit bytes
1249
- const uint8x8_t v0 = vand_u8(v8, vdup_n_u8(0x0f));
1717
+ const uint8x8_t v0 = vand_u8(v8, vdup_n_u8(0x0F));
1250
1718
  const uint8x8_t v1 = vshr_n_u8(v8, 4);
1251
1719
 
1252
1720
  // Convert to signed 8-bit integers
@@ -1296,7 +1764,7 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in
1296
1764
  for (int l = 0; l < QK4_0; l += 2) {
1297
1765
  const uint8_t vi = pp[l/2];
1298
1766
 
1299
- const int8_t vi0 = vi & 0xf;
1767
+ const int8_t vi0 = vi & 0x0F;
1300
1768
  const int8_t vi1 = vi >> 4;
1301
1769
 
1302
1770
  const float v0 = (vi0 - 8)*d;
@@ -1329,7 +1797,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
1329
1797
 
1330
1798
  for (int l = 0; l < QK4_1; l += 32) {
1331
1799
  // Load 32x4-bit integers into 32x8-bit integers
1332
- __m256i vx8 = bytesFromNibbles(pp+l/2);
1800
+ __m256i vx8 = bytes_from_nibbles_32(pp+l/2);
1333
1801
 
1334
1802
  // Convert to 16-bit int
1335
1803
  const __m256i vx16_lo = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vx8, 0));
@@ -1362,7 +1830,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
1362
1830
  const uint8x8_t v8 = vld1_u8(pp + l/2);
1363
1831
 
1364
1832
  // Expand 4-bit qs to 8-bit bytes
1365
- const uint8x8_t v0 = vand_u8(v8, vdup_n_u8(0x0f));
1833
+ const uint8x8_t v0 = vand_u8(v8, vdup_n_u8(0x0F));
1366
1834
  const uint8x8_t v1 = vshr_n_u8(v8, 4);
1367
1835
 
1368
1836
  // Interleave and combine
@@ -1404,7 +1872,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
1404
1872
  for (int l = 0; l < QK4_1; l += 2) {
1405
1873
  const uint8_t vi = pp[l/2];
1406
1874
 
1407
- const int8_t vi0 = vi & 0xf;
1875
+ const int8_t vi0 = vi & 0x0F;
1408
1876
  const int8_t vi1 = vi >> 4;
1409
1877
 
1410
1878
  const float v0 = vi0*d + m;
@@ -1420,8 +1888,162 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
1420
1888
  #endif
1421
1889
  }
1422
1890
 
1423
- static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
1891
+ static void dequantize_row_q4_2(const void * restrict vx, float * restrict y, int k) {
1892
+ assert(k % QK4_2 == 0);
1893
+ const int nb = k / QK4_2;
1894
+
1895
+ const block_q4_2 * restrict x = vx;
1896
+
1897
+ for (int i = 0; i < nb; i++) {
1898
+ const float d = GGML_FP16_TO_FP32(x[i].d);
1899
+
1900
+ const uint8_t * restrict pp = x[i].qs;
1901
+
1902
+ for (int l = 0; l < QK4_2; l += 2) {
1903
+ const uint8_t vi = pp[l/2];
1904
+
1905
+ const int8_t vi0 = vi & 0x0F;
1906
+ const int8_t vi1 = vi >> 4;
1907
+
1908
+ const float v0 = (vi0 - 8)*d;
1909
+ const float v1 = (vi1 - 8)*d;
1910
+
1911
+ y[i*QK4_2 + l + 0] = v0;
1912
+ y[i*QK4_2 + l + 1] = v1;
1913
+
1914
+ assert(!isnan(y[i*QK4_2 + l + 0]));
1915
+ assert(!isnan(y[i*QK4_2 + l + 1]));
1916
+ }
1917
+ }
1918
+ }
1919
+
1920
+ static void dequantize_row_q4_3(const void * restrict vx, float * restrict y, int k) {
1921
+ assert(k % QK4_3 == 0);
1922
+ const int nb = k / QK4_3;
1923
+
1924
+ const block_q4_3 * restrict x = vx;
1925
+
1926
+ for (int i = 0; i < nb; i++) {
1927
+ const float d = GGML_FP16_TO_FP32(x[i].d);
1928
+ const float m = GGML_FP16_TO_FP32(x[i].m);
1929
+
1930
+ const uint8_t * restrict pp = x[i].qs;
1931
+
1932
+ for (int l = 0; l < QK4_3; l += 2) {
1933
+ const uint8_t vi = pp[l/2];
1934
+
1935
+ const int8_t vi0 = vi & 0x0F;
1936
+ const int8_t vi1 = vi >> 4;
1937
+
1938
+ const float v0 = vi0*d + m;
1939
+ const float v1 = vi1*d + m;
1940
+
1941
+ y[i*QK4_3 + l + 0] = v0;
1942
+ y[i*QK4_3 + l + 1] = v1;
1943
+
1944
+ assert(!isnan(y[i*QK4_3 + l + 0]));
1945
+ assert(!isnan(y[i*QK4_3 + l + 1]));
1946
+ }
1947
+ }
1948
+ }
1949
+
1950
+ static void dequantize_row_q5_0(const void * restrict vx, float * restrict y, int k) {
1951
+ assert(k % QK5_0 == 0);
1952
+ const int nb = k / QK5_0;
1953
+
1954
+ const block_q5_0 * restrict x = vx;
1955
+
1956
+ for (int i = 0; i < nb; i++) {
1957
+ const float d = GGML_FP16_TO_FP32(x[i].d);
1958
+
1959
+ const uint8_t * restrict pp = x[i].qs;
1960
+
1961
+ uint32_t qh;
1962
+ memcpy(&qh, x[i].qh, sizeof(qh));
1963
+
1964
+ for (int l = 0; l < QK5_0; l += 2) {
1965
+ const uint8_t vi = pp[l/2];
1966
+
1967
+ // extract the 5-th bit from qh
1968
+ const uint8_t vh0 = ((qh & (1 << (l + 0))) >> (l + 0)) << 4;
1969
+ const uint8_t vh1 = ((qh & (1 << (l + 1))) >> (l + 1)) << 4;
1970
+
1971
+ const int8_t vi0 = (vi & 0x0F) | vh0;
1972
+ const int8_t vi1 = (vi >> 4) | vh1;
1973
+
1974
+ const float v0 = (vi0 - 16)*d;
1975
+ const float v1 = (vi1 - 16)*d;
1976
+
1977
+ y[i*QK5_0 + l + 0] = v0;
1978
+ y[i*QK5_0 + l + 1] = v1;
1979
+
1980
+ assert(!isnan(y[i*QK5_0 + l + 0]));
1981
+ assert(!isnan(y[i*QK5_0 + l + 1]));
1982
+ }
1983
+ }
1984
+ }
1985
+
1986
+ static void dequantize_row_q5_1(const void * restrict vx, float * restrict y, int k) {
1987
+ assert(k % QK5_1 == 0);
1988
+ const int nb = k / QK5_1;
1989
+
1990
+ const block_q5_1 * restrict x = vx;
1991
+
1992
+ for (int i = 0; i < nb; i++) {
1993
+ const float d = GGML_FP16_TO_FP32(x[i].d);
1994
+ const float m = GGML_FP16_TO_FP32(x[i].m);
1995
+
1996
+ const uint8_t * restrict pp = x[i].qs;
1997
+
1998
+ uint32_t qh;
1999
+ memcpy(&qh, x[i].qh, sizeof(qh));
2000
+
2001
+ for (int l = 0; l < QK5_1; l += 2) {
2002
+ const uint8_t vi = pp[l/2];
2003
+
2004
+ // extract the 5-th bit from qh
2005
+ const uint8_t vh0 = ((qh & (1 << (l + 0))) >> (l + 0)) << 4;
2006
+ const uint8_t vh1 = ((qh & (1 << (l + 1))) >> (l + 1)) << 4;
2007
+
2008
+ const uint8_t vi0 = (vi & 0x0F) | vh0;
2009
+ const uint8_t vi1 = (vi >> 4) | vh1;
2010
+
2011
+ const float v0 = vi0*d + m;
2012
+ const float v1 = vi1*d + m;
2013
+
2014
+ y[i*QK5_1 + l + 0] = v0;
2015
+ y[i*QK5_1 + l + 1] = v1;
2016
+
2017
+ assert(!isnan(y[i*QK5_1 + l + 0]));
2018
+ assert(!isnan(y[i*QK5_1 + l + 1]));
2019
+ }
2020
+ }
2021
+ }
2022
+
2023
+ static void dequantize_row_q8_0(const void * restrict vx, float * restrict y, int k) {
2024
+ assert(k % QK8_0 == 0);
2025
+ const int nb = k / QK8_0;
2026
+
2027
+ const block_q8_0 * restrict x = vx;
2028
+
2029
+ for (int i = 0; i < nb; i++) {
2030
+ const float d = x[i].d;
2031
+
2032
+ const int8_t * restrict pp = x[i].qs;
2033
+
2034
+ for (int l = 0; l < QK8_0; ++l) {
2035
+ y[i*QK8_0 + l] = pp[l]*d;
2036
+ }
2037
+ }
2038
+ }
2039
+
1424
2040
  static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
2041
+ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
2042
+ static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
2043
+ static void ggml_vec_dot_q4_3_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
2044
+ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
2045
+ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
2046
+ static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
1425
2047
 
1426
2048
  static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
1427
2049
  [GGML_TYPE_Q4_0] = {
@@ -1430,15 +2052,64 @@ static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
1430
2052
  .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_0_reference,
1431
2053
  .quantize_row_q_dot = quantize_row_q8_0,
1432
2054
  .vec_dot_q = ggml_vec_dot_q4_0_q8_0,
2055
+ .vec_dot_type = GGML_TYPE_Q8_0,
1433
2056
  },
1434
2057
  [GGML_TYPE_Q4_1] = {
1435
2058
  .dequantize_row_q = dequantize_row_q4_1,
1436
2059
  .quantize_row_q = quantize_row_q4_1,
1437
2060
  .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_1_reference,
1438
- .quantize_row_q_dot = quantize_row_q4_1,
1439
- .vec_dot_q = ggml_vec_dot_q4_1,
2061
+ .quantize_row_q_dot = quantize_row_q8_1,
2062
+ .vec_dot_q = ggml_vec_dot_q4_1_q8_1,
2063
+ .vec_dot_type = GGML_TYPE_Q8_1,
2064
+ },
2065
+ [GGML_TYPE_Q4_2] = {
2066
+ .dequantize_row_q = dequantize_row_q4_2,
2067
+ .quantize_row_q = quantize_row_q4_2,
2068
+ .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_2_reference,
2069
+ .quantize_row_q_dot = quantize_row_q8_0,
2070
+ .vec_dot_q = ggml_vec_dot_q4_2_q8_0,
2071
+ .vec_dot_type = GGML_TYPE_Q8_0,
2072
+ },
2073
+ [GGML_TYPE_Q4_3] = {
2074
+ .dequantize_row_q = dequantize_row_q4_3,
2075
+ .quantize_row_q = quantize_row_q4_3,
2076
+ .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_3_reference,
2077
+ .quantize_row_q_dot = quantize_row_q8_1,
2078
+ .vec_dot_q = ggml_vec_dot_q4_3_q8_1,
2079
+ .vec_dot_type = GGML_TYPE_Q8_1,
2080
+ },
2081
+ [GGML_TYPE_Q5_0] = {
2082
+ .dequantize_row_q = dequantize_row_q5_0,
2083
+ .quantize_row_q = quantize_row_q5_0,
2084
+ .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q5_0_reference,
2085
+ .quantize_row_q_dot = quantize_row_q8_0,
2086
+ .vec_dot_q = ggml_vec_dot_q5_0_q8_0,
2087
+ .vec_dot_type = GGML_TYPE_Q8_0,
2088
+ },
2089
+ [GGML_TYPE_Q5_1] = {
2090
+ .dequantize_row_q = dequantize_row_q5_1,
2091
+ .quantize_row_q = quantize_row_q5_1,
2092
+ .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q5_1_reference,
2093
+ .quantize_row_q_dot = quantize_row_q8_1,
2094
+ .vec_dot_q = ggml_vec_dot_q5_1_q8_1,
2095
+ .vec_dot_type = GGML_TYPE_Q8_1,
2096
+ },
2097
+ [GGML_TYPE_Q8_0] = {
2098
+ .dequantize_row_q = dequantize_row_q8_0,
2099
+ .quantize_row_q = quantize_row_q8_0,
2100
+ .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q8_0_reference,
2101
+ .quantize_row_q_dot = quantize_row_q8_0,
2102
+ .vec_dot_q = ggml_vec_dot_q8_0_q8_0,
2103
+ .vec_dot_type = GGML_TYPE_Q8_0,
2104
+ },
2105
+ [GGML_TYPE_Q8_1] = {
2106
+ .dequantize_row_q = NULL, // TODO
2107
+ .quantize_row_q = quantize_row_q8_1,
2108
+ .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q8_1_reference,
2109
+ .quantize_row_q_dot = quantize_row_q8_1,
2110
+ .vec_dot_q = NULL, // TODO
2111
+ .vec_dot_type = GGML_TYPE_Q8_1,
1440
2112
  },
1441
- // TODO: GGML_TYPE_Q8_0
1442
2113
  };
1443
2114
 
1444
2115
  // For internal test use
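
The expanded quantize_fns table above now records, per weight type, which activation quantization its dot-product kernel expects (vec_dot_type). A hedged sketch of how a matmul driver would consult it; the field names and signatures come from the initializers and declarations above, but the driver code itself is not part of this diff, and every variable name below is hypothetical:

    // Hypothetical driver-side lookup (not from the diff).
    const quantize_fns_t * qfns = &quantize_fns[wtype];     // e.g. wtype == GGML_TYPE_Q4_1
    const enum ggml_type   vdt  = qfns->vec_dot_type;       // -> GGML_TYPE_Q8_1

    // Quantize one f32 row of the second operand into the format the kernel
    // expects, then run the mixed-format dot product against a weight row.
    quantize_fns[vdt].quantize_row_q_dot(src1_row_f32, qbuf, ne00);
    qfns->vec_dot_q(ne00, &dst_value, weight_row, qbuf);
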
@@ -2004,191 +2675,6 @@ inline static void ggml_vec_dot_f32(const int n, float * restrict s, const float
2004
2675
  *s = sumf;
2005
2676
  }
2006
2677
 
2007
- #if __AVX512F__ && QK4_0 == 32
2008
- static inline __m512i bytes_from_q4_0_twoblocks_avx512( const __m512i blocks ) {
2009
- // The 64 bytes of `blocks` contain two consecutive Q4_0 blocks loaded from memory:
2010
- // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
2011
- // |63 62 61 60 59 58 57 56 55 54 53 52 51 50 49 48 47 46 45 44 43 42 41 40 39 38 37 36 35 34 33 32|
2012
- // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
2013
- // | :. =_ () [] <> () Zz Yy|
2014
- // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
2015
- // |31 30 29 28 27 26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10 09 08 07 06 05 04 03 02 01 00|
2016
- // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
2017
- // |Xx Ww Vv Uu Tt Ss Rr Qq Pp Oo Nn Mm Ll Kk Jj Ii Hh Gg Ff Ee Dd Cc Bb Aa |
2018
- // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
2019
- //
2020
- // Bytes 04..19 (block #0) and 24..39 (block #1) both contain 32 nibbles (4-bit unsigned integers).
2021
- // We have exactly 64 nibbles, so we want to place each nibble into a separate byte.
2022
- // Bytes 00..03 and 20..23 contain scales, which are irrelevant to this function.
2023
- // Bytes 40..63 are masked when loading the data, so they are zeroed out.
2024
- #ifdef __AVX512VBMI__
2025
- const __m512i byte_perm = _mm512_set_epi8(
2026
- 39, 38, 39, 38, 37, 36, 37, 36, 35, 34, 35, 34, 33, 32, 33, 32,
2027
- 31, 30, 31, 30, 29, 28, 29, 28, 27, 26, 27, 26, 25, 24, 25, 24,
2028
- 19, 18, 19, 18, 17, 16, 17, 16, 15, 14, 15, 14, 13, 12, 13, 12,
2029
- 11, 10, 11, 10, 9, 8, 9, 8, 7, 6, 7, 6, 5, 4, 5, 4
2030
- );
2031
- const __m512i permuted = _mm512_permutexvar_epi8( byte_perm, blocks );
2032
- // After applying VPERMB, `permuted` looks like this:
2033
- // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
2034
- // |63 62 61 60 59 58 57 56 55 54 53 52 51 50 49 48 47 46 45 44 43 42 41 40 39 38 37 36 35 34 33 32|
2035
- // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
2036
- // |:. =_ :. =_ () [] () [] <> () <> () Zz Yy Zz Yy Xx Ww Xx Ww Vv Uu Vv Uu Tt Ss Tt Ss Rr Qq Rr Qq|
2037
- // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
2038
- // |31 30 29 28 27 26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10 09 08 07 06 05 04 03 02 01 00|
2039
- // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
2040
- // |Pp Oo Pp Oo Nn Mm Nn Mm Ll Kk Ll Kk Jj Ii Jj Ii Hh Gg Hh Gg Ff Ee Ff Ee Dd Cc Dd Cc Bb Aa Bb Aa|
2041
- // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
2042
- #else
2043
- const __m512i word_perm = _mm512_set_epi16(
2044
- 19, 19, 18, 18, 17, 17, 16, 16, 15, 15, 14, 14, 13, 13, 12, 12,
2045
- 9, 9, 8, 8, 7, 7, 6, 6, 5, 5, 4, 4, 3, 3, 2, 2
2046
- );
2047
- const __m512i permuted = _mm512_permutexvar_epi16( word_perm, blocks );
2048
- // This is the fallback path for CPUs that don't support VPERMB. Since we permute 16-bit groups only,
2049
- // VPERMB can be replaced with VPERMW. We could always use VPERMW, but at least on Tiger Lake and
2050
- // Ice Lake VPERMW followed by a right shift is quite noticeably slower than VPERMB.
2051
- #endif
2052
-
2053
- // Shift every odd-numbered 16-bit group to the right by 4 bits.
2054
- const __mmask32 shift_mask = 0xaaaaaaaa;
2055
- const __m512i shifted = _mm512_mask_srai_epi16( permuted, shift_mask, permuted, 4 );
2056
- // After applying VPSRAW, `shifted` looks like this (the "empty" nibbles are filled with zeroes):
2057
- // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
2058
- // |63 62 61 60 59 58 57 56 55 54 53 52 51 50 49 48 47 46 45 44 43 42 41 40 39 38 37 36 35 34 33 32
2059
- // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
2060
- // | : .= :. =_ ( )[ () [] < >( <> () Z zY Zz Yy X xW Xx Ww V vU Vv Uu T tS Tt Ss R rQ Rr Qq
2061
- // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
2062
- // |31 30 29 28 27 26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10 09 08 07 06 05 04 03 02 01 00|
2063
- // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
2064
- // | P pO Pp Oo N nM Nn Mm L lK Ll Kk J jI Jj Ii H hG Hh Gg F fE Ff Ee D dC Dd Cc B bA Bb Aa|
2065
- // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
2066
-
2067
- // Now we just need to zero out the higher nibble in each byte, and we're done.
2068
- const __m512i low_nibble_mask = _mm512_set1_epi8( 0xf );
2069
- return _mm512_and_si512( low_nibble_mask, shifted );
2070
- // The final result looks like this:
2071
- // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
2072
- // |63 62 61 60 59 58 57 56 55 54 53 52 51 50 49 48 47 46 45 44 43 42 41 40 39 38 37 36 35 34 33 32|
2073
- // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
2074
- // | : = . _ ( [ ) ] < ( > ) Z Y z y X W x w V U v u T S t s R Q r q|
2075
- // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
2076
- // |31 30 29 28 27 26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10 09 08 07 06 05 04 03 02 01 00|
2077
- // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
2078
- // | P O p o N M n m L K l k J I j i H G h g F E f e D C d c B A b a|
2079
- // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
2080
- }
2081
-
2082
- static inline __m512 dot_q4_0_twoblocks_avx512(
2083
- __m512 acc,
2084
- const block_q4_0 * restrict x,
2085
- const block_q4_0 * restrict y,
2086
- int i
2087
- ) {
2088
- // A pair of Q4_0 blocks spans 40 bytes, while an AVX-512 register has 64. The remaining 24 bytes
2089
- // can potentially be unaddressable, so we make sure to mask them out before the load, even though
2090
- // we don't use them at all. This might hurt the performance slightly, since the compiler is forced
2091
- // to use e.g. `VMOVDQU64 REG, MASK, [ADDR] + VPERMB ..., REG` instead of just `VPERMB ..., [ADDR]`.
2092
- const __mmask8 load_mask = 0x1f;
2093
- const __m512i blocks_0 = _mm512_maskz_loadu_epi64( load_mask, &x[i] );
2094
- const __m512i blocks_1 = _mm512_maskz_loadu_epi64( load_mask, &y[i] );
2095
-
2096
- // We want to multiply the scales, so we interpret both registers as 16 32-bit floats:
2097
- // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
2098
- // | 15 | 14 | 13 | 12 | 11 | 10 | 09 | 08 | 07 | 06 | 05 | 04 | 03 | 02 | 01 | 00 |
2099
- // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
2100
- // blocks_0_float
2101
- // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
2102
- // | | | | | | | xx | xx | xx | xx | B | xx | xx | xx | xx | A |
2103
- // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
2104
- // blocks_1_float
2105
- // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
2106
- // | | | | | | | xx | xx | xx | xx | D | xx | xx | xx | xx | C |
2107
- // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
2108
- const __m512 blocks_0_float = _mm512_castsi512_ps( blocks_0 );
2109
- const __m512 blocks_1_float = _mm512_castsi512_ps( blocks_1 );
2110
- // We absolutely shouldn't touch the floats marked with `xx`: they contain some
2111
- // random data, which might very well underflow. At least on Intel, this leads
2112
- // to a huge penalty that can't be ignored (easily 100x or more) unless you
2113
- // compile your code with something like `-ffast-math` to enable FTZ/DAZ flags.
2114
- // (and ggml can't assume that you do)...
2115
- const __mmask16 scale_mul_mask = 0x21;
2116
- #ifdef __clang__
2117
- // ...however, clang decides to optimize the multiplication mask away:
2118
- // https://godbolt.org/z/P8PqdsfvW
2119
- // gcc and MSVC do the sane thing. This horrible workaround forces clang to emit the mask.
2120
- __m512i scales;
2121
- __asm__(
2122
- "vmulps %1, %2, %0%{%3%}"
2123
- : "=v" ( scales )
2124
- : "vm" ( blocks_0_float ), "v" ( blocks_1_float ), "Yk" ( scale_mul_mask )
2125
- );
2126
- #else
2127
- const __m512 scales = _mm512_maskz_mul_ps( scale_mul_mask, blocks_0_float, blocks_1_float );
2128
- #endif
2129
- const __m512i scale_perm = _mm512_set_epi32(
2130
- 5, 5, 5, 5, 5, 5, 5, 5,
2131
- 0, 0, 0, 0, 0, 0, 0, 0
2132
- );
2133
- const __m512 permuted_scales = _mm512_permutexvar_ps( scale_perm, scales );
2134
- // After VMULPS and VPERMPS, `permuted_scales` looks like this:
2135
- // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
2136
- // | 15 | 14 | 13 | 12 | 11 | 10 | 09 | 08 | 07 | 06 | 05 | 04 | 03 | 02 | 01 | 00 |
2137
- // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
2138
- // | B*D| B*D| B*D| B*D| B*D| B*D| B*D| B*D| A*C| A*C| A*C| A*C| A*C| A*C| A*C| A*C|
2139
- // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
2140
-
2141
- const __m512i bytes_0 = bytes_from_q4_0_twoblocks_avx512( blocks_0 );
2142
- const __m512i bytes_1 = bytes_from_q4_0_twoblocks_avx512( blocks_1 );
2143
-
2144
- // Now we want to compute dot products of 4-element byte vectors and store them in
2145
- // 32-bit integers. That is (only one 4-element vector is shown for clarity):
2146
- // +----+----+----+----+
2147
- // ... | 03 | 02 | 01 | 00 |
2148
- // +----+----+----+----+
2149
- // bytes_0
2150
- // +----+----+----+----+
2151
- // ... | D | C | B | A |
2152
- // +----+----+----+----+
2153
- // bytes_1
2154
- // +----+----+----+----+
2155
- // ... | H | G | F | E |
2156
- // +----+----+----+----+
2157
- // final_res_int
2158
- // +----+----+----+----+
2159
- // ... | A*E+B*F+C*G+D*H |
2160
- // +----+----+----+----+
2161
- const __m512i plus_8 = _mm512_set1_epi8( 8 );
2162
- const __m512i bytes_1_minus_8 = _mm512_sub_epi8( bytes_1, plus_8 );
2163
-
2164
- #ifdef __AVX512VNNI__
2165
- // We have VPDPBUSDS in AVX512-VNNI, which does exactly what we want, but with a catch:
2166
- // the *left* operand is supposed to be unsigned, while Q4_0 quantization subtracts 8
2167
- // from each nibble, so they can be negative. So, instead of `(bytes_0 - 8) * (bytes_1 - 8)`,
2168
- // we compute `bytes_0 * (bytes_1 - 8) + bytes_1 * (-8) + 64`. VPDPBUSDS uses an accumulator,
2169
- // which means we only need 2 instructions.
2170
- const __m512i dot_init = _mm512_set1_epi32( 4 * 64 );
2171
- const __m512i minus_8 = _mm512_set1_epi8( -8 );
2172
- const __m512i prod_0 = _mm512_dpbusds_epi32( dot_init, bytes_1, minus_8 );
2173
- const __m512i final_res_int = _mm512_dpbusds_epi32( prod_0, bytes_0, bytes_1_minus_8 );
2174
- #else
2175
- // As a fallback, we have VPMADDUBSW in AVX512-BW, which uses 16-bit products instead of 32-bit ones.
2176
- // It has the same catch as VPDPBUSDS: the left operand should be unsigned.
2177
- // This is essentially the AVX-512 version of the AVX-2 trick used by GH user Const-me
2178
- // ref: https://gist.github.com/Const-me/4d30e1fc767ab314596e16e90f53b6f4#file-matmultest-cpp-L119
2179
- const __m512i one = _mm512_set1_epi16( 1 );
2180
- const __m512i prod_0 = _mm512_maddubs_epi16( bytes_0, bytes_1_minus_8 );
2181
- const __m512i prod_1 = _mm512_maddubs_epi16( plus_8, bytes_1_minus_8 );
2182
- const __m512i diff = _mm512_sub_epi16( prod_0, prod_1 );
2183
- const __m512i final_res_int = _mm512_madd_epi16( diff, one );
2184
- #endif
2185
-
2186
- // Finally, we multiply the permuted scales and the 32-bit dot products, then accumulate.
2187
- const __m512 final_res_float = _mm512_cvtepi32_ps( final_res_int );
2188
- return _mm512_fmadd_ps( permuted_scales, final_res_float, acc );
2189
- }
2190
- #endif
2191
-
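
The removed `dot_q4_0_twoblocks_avx512` helper above relied on the identity `(a - 8)*(b - 8) == a*(b - 8) - 8*b + 64` so that VPDPBUSDS, whose left operand must be unsigned, could still be used; the `4*64` accumulator initializer covers the four byte pairs that feed each 32-bit lane. A minimal scalar check of that rearrangement (illustration only, not part of this diff):

    #include <assert.h>
    #include <stdint.h>

    /* Reference: both operands recentered by -8 before multiplying. */
    static int32_t dot4_reference(const uint8_t a[4], const uint8_t b[4]) {
        int32_t acc = 0;
        for (int k = 0; k < 4; ++k) {
            acc += ((int32_t)a[k] - 8) * ((int32_t)b[k] - 8);
        }
        return acc;
    }

    /* Rearranged form: a*(b - 8) + b*(-8) + 64 per pair, with the 4*64 bias
       folded into the accumulator initializer (one 32-bit lane sums 4 pairs). */
    static int32_t dot4_rearranged(const uint8_t a[4], const uint8_t b[4]) {
        int32_t acc = 4 * 64;
        for (int k = 0; k < 4; ++k) {
            acc += (int32_t)b[k] * -8;                    /* like VPDPBUSDS(acc, bytes_1, -8)            */
            acc += (int32_t)a[k] * ((int32_t)b[k] - 8);   /* like VPDPBUSDS(acc, bytes_0, bytes_1_minus_8) */
        }
        return acc;
    }

    int main(void) {
        const uint8_t a[4] = { 0, 7, 8, 15 };
        const uint8_t b[4] = { 15, 8, 7, 0 };
        assert(dot4_reference(a, b) == dot4_rearranged(a, b));
        return 0;
    }
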
2192
2678
  inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t * restrict x, ggml_fp16_t * restrict y) {
2193
2679
  ggml_float sumf = 0.0;
2194
2680
 
@@ -2225,67 +2711,62 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t
2225
2711
  *s = sumf;
2226
2712
  }
2227
2713
 
2228
- static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
2229
- const int nb = n / QK4_0;
2714
+ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
2715
+ const int nb = n / QK8_0;
2230
2716
 
2231
- assert(n % QK4_0 == 0);
2717
+ assert(n % QK8_0 == 0);
2232
2718
  assert(nb % 2 == 0);
2233
2719
 
2234
2720
  const block_q4_0 * restrict x = vx;
2235
- const block_q4_0 * restrict y = vy;
2236
-
2237
- float sumf = 0.0;
2721
+ const block_q8_0 * restrict y = vy;
2238
2722
 
2239
2723
  #if defined(__ARM_NEON)
2240
- float sum0 = 0.0f;
2241
- float sum1 = 0.0f;
2724
+ float32x4_t sumv0 = vdupq_n_f32(0.0f);
2725
+ float32x4_t sumv1 = vdupq_n_f32(0.0f);
2242
2726
 
2243
2727
  for (int i = 0; i < nb; i += 2) {
2244
2728
  const block_q4_0 * restrict x0 = &x[i + 0];
2245
- const block_q4_0 * restrict y0 = &y[i + 0];
2246
2729
  const block_q4_0 * restrict x1 = &x[i + 1];
2247
- const block_q4_0 * restrict y1 = &y[i + 1];
2730
+ const block_q8_0 * restrict y0 = &y[i + 0];
2731
+ const block_q8_0 * restrict y1 = &y[i + 1];
2248
2732
 
2249
- const uint8x16_t m4b = vdupq_n_u8(0xf);
2250
- const int8x16_t s8b = vdupq_n_s8(0x8);
2733
+ const uint8x16_t m4b = vdupq_n_u8(0x0F);
2734
+ const int8x16_t s8b = vdupq_n_s8(0x8);
2251
2735
 
2252
2736
  const uint8x16_t v0_0 = vld1q_u8(x0->qs);
2253
- const uint8x16_t v1_0 = vld1q_u8(y0->qs);
2254
2737
  const uint8x16_t v0_1 = vld1q_u8(x1->qs);
2255
- const uint8x16_t v1_1 = vld1q_u8(y1->qs);
2256
2738
 
2257
2739
  // 4-bit -> 8-bit
2258
- const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8(v0_0, m4b));
2259
- const int8x16_t v1_0l = vreinterpretq_s8_u8(vandq_u8(v1_0, m4b));
2740
+ const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
2260
2741
  const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
2261
- const int8x16_t v1_0h = vreinterpretq_s8_u8(vshrq_n_u8(v1_0, 4));
2262
-
2263
- const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8(v0_1, m4b));
2264
- const int8x16_t v1_1l = vreinterpretq_s8_u8(vandq_u8(v1_1, m4b));
2742
+ const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
2265
2743
  const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
2266
- const int8x16_t v1_1h = vreinterpretq_s8_u8(vshrq_n_u8(v1_1, 4));
2267
2744
 
2268
2745
  // sub 8
2269
2746
  const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
2270
- const int8x16_t v1_0ls = vsubq_s8(v1_0l, s8b);
2271
2747
  const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
2272
- const int8x16_t v1_0hs = vsubq_s8(v1_0h, s8b);
2273
-
2274
2748
  const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
2275
- const int8x16_t v1_1ls = vsubq_s8(v1_1l, s8b);
2276
2749
  const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
2277
- const int8x16_t v1_1hs = vsubq_s8(v1_1h, s8b);
2750
+
2751
+ // load y
2752
+ const int8x16_t v1_0l = vld1q_s8(y0->qs);
2753
+ const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
2754
+ const int8x16_t v1_1l = vld1q_s8(y1->qs);
2755
+ const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
2756
+
2757
+ // interleave
2758
+ const int8x16_t v1_0ls = vuzp1q_s8(v1_0l, v1_0h);
2759
+ const int8x16_t v1_0hs = vuzp2q_s8(v1_0l, v1_0h);
2760
+ const int8x16_t v1_1ls = vuzp1q_s8(v1_1l, v1_1h);
2761
+ const int8x16_t v1_1hs = vuzp2q_s8(v1_1l, v1_1h);
2278
2762
 
2279
2763
  #if defined(__ARM_FEATURE_DOTPROD)
2280
2764
  // dot product into int32x4_t
2281
- int32x4_t p_0 = vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls);
2282
- int32x4_t p_1 = vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls);
2765
+ const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls), v0_0hs, v1_0hs);
2766
+ const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls), v0_1hs, v1_1hs);
2283
2767
 
2284
- p_0 = vdotq_s32(p_0, v0_0hs, v1_0hs);
2285
- p_1 = vdotq_s32(p_1, v0_1hs, v1_1hs);
2286
-
2287
- sum0 += x0->d*y0->d*vaddvq_s32(p_0);
2288
- sum1 += x1->d*y1->d*vaddvq_s32(p_1);
2768
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), x0->d*y0->d);
2769
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), x1->d*y1->d);
2289
2770
  #else
2290
2771
  const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls));
2291
2772
  const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls));
@@ -2297,125 +2778,41 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
2297
2778
  const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1hs));
2298
2779
  const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1hs));
2299
2780
 
2300
- const int16x8_t pl_0 = vaddq_s16(pl0l, pl0h);
2301
- const int16x8_t ph_0 = vaddq_s16(ph0l, ph0h);
2302
-
2303
- const int16x8_t pl_1 = vaddq_s16(pl1l, pl1h);
2304
- const int16x8_t ph_1 = vaddq_s16(ph1l, ph1h);
2305
-
2306
- const int16x8_t p_0 = vaddq_s16(pl_0, ph_0);
2307
- const int16x8_t p_1 = vaddq_s16(pl_1, ph_1);
2781
+ const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
2782
+ const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
2783
+ const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
2784
+ const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
2308
2785
 
2309
- sum0 += x0->d*y0->d*vaddvq_s16(p_0);
2310
- sum1 += x1->d*y1->d*vaddvq_s16(p_1);
2786
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), x0->d*y0->d);
2787
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), x1->d*y1->d);
2311
2788
  #endif
2312
2789
  }
2313
2790
 
2314
- sumf = sum0 + sum1;
2315
- #elif defined(__AVX512F__)
2791
+ *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
2792
+ #elif defined(__AVX2__)
2316
2793
  // Initialize accumulator with zeros
2317
- __m512 acc0 = _mm512_setzero_ps();
2318
- __m512 acc1 = _mm512_setzero_ps();
2319
-
2320
- const int superblock_size = 16;
2321
-
2322
- const int superblock_count = nb / superblock_size;
2794
+ __m256 acc = _mm256_setzero_ps();
2323
2795
 
2324
- for (int superblock_ix = 0; superblock_ix < superblock_count; superblock_ix += 1) {
2325
- int i = superblock_ix * superblock_size;
2796
+ // Main loop
2797
+ for (int i = 0; i < nb; ++i) {
2798
+ /* Compute combined scale for the block */
2799
+ const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );
2326
2800
 
2327
- acc0 = dot_q4_0_twoblocks_avx512( acc0, x, y, i+0 );
2328
- acc1 = dot_q4_0_twoblocks_avx512( acc1, x, y, i+2 );
2329
- acc0 = dot_q4_0_twoblocks_avx512( acc0, x, y, i+4 );
2330
- acc1 = dot_q4_0_twoblocks_avx512( acc1, x, y, i+6 );
2331
- acc0 = dot_q4_0_twoblocks_avx512( acc0, x, y, i+8 );
2332
- acc1 = dot_q4_0_twoblocks_avx512( acc1, x, y, i+10 );
2333
- acc0 = dot_q4_0_twoblocks_avx512( acc0, x, y, i+12 );
2334
- acc1 = dot_q4_0_twoblocks_avx512( acc1, x, y, i+14 );
2335
- }
2801
+ __m256i bx = bytes_from_nibbles_32(x[i].qs);
2336
2802
 
2337
- // Remainders
2338
- for (int i = superblock_count * superblock_size; i < nb; i += 2) {
2339
- acc0 = dot_q4_0_twoblocks_avx512( acc0, x, y, i );
2340
- }
2803
+ // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
2804
+ const __m256i off = _mm256_set1_epi8( 8 );
2805
+ bx = _mm256_sub_epi8( bx, off );
2341
2806
 
2342
- // Horizontal sum of all lanes of the accumulator
2343
- sumf = _mm512_reduce_add_ps( acc0 ) + _mm512_reduce_add_ps( acc1 );
2344
- #elif defined(__AVX2__)
2345
- // Initialize accumulator with zeros
2346
- __m256 acc = _mm256_setzero_ps();
2807
+ __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
2347
2808
 
2348
- /* Prepare the constants we will need during execution */
2349
- const __m256i lowMask = _mm256_set1_epi8( 0xF );
2350
- const __m256i offset_8 = _mm256_set1_epi16( 8 );
2809
+ const __m256 q = mul_sum_i8_pairs_float(bx, by);
2351
2810
 
2352
- #define UNROLL_COUNT 8
2353
- // make sure we only unroll multiples of the block count
2354
- assert(nb % UNROLL_COUNT == 0);
2811
+ /* Multiply q with scale and accumulate */
2812
+ acc = _mm256_fmadd_ps( d, q, acc );
2813
+ }
2355
2814
 
2356
- // Main loop
2357
- for (int i = 0; i < nb; i+=UNROLL_COUNT) {
2358
- // This loop will be unrolled by the compiler
2359
- for (int u=0;u<UNROLL_COUNT;u++) {
2360
- /* Compute combined scale for the block */
2361
- const __m256 scale = _mm256_mul_ps(
2362
- _mm256_broadcast_ss( &x[i+u].d ),
2363
- _mm256_broadcast_ss( &y[i+u].d ) );
2364
-
2365
- /* get input from x
2366
- Input: 32 Nibbles (16 bytes) at *x[i+u]
2367
- Output: 2 vectors with 16 values of type int16_t (x_high_q, x_low_q) */
2368
-
2369
- /* Load 16 bytes from memory */
2370
- const __m128i tmp_x = _mm_loadu_si128( ( const __m128i* ) x[i+u].qs);
2371
- /* Expand bytes into uint16_t values */
2372
- const __m256i bytes_x = _mm256_cvtepu8_epi16(tmp_x);
2373
- /* Unpack values into individual bytes */
2374
- __m256i x_low_q = _mm256_and_si256( lowMask, bytes_x );
2375
- const __m256i pre_shift_x_high_q = _mm256_andnot_si256( lowMask, bytes_x );
2376
- __m256i x_high_q = _mm256_srli_epi16( pre_shift_x_high_q, 4 );
2377
- /* Now we have two vectors with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval. */
2378
- x_high_q = _mm256_sub_epi16( x_high_q, offset_8 );
2379
- x_low_q = _mm256_sub_epi16( x_low_q, offset_8 );
2380
-
2381
- /* get input from y
2382
- Input: 32 Nibbles (16 bytes) at *y[i+u]
2383
- Output: 2 vectors with 16 values of type int16_t (y_high_q, y_low_q) */
2384
-
2385
- /* Load 16 bytes from memory */
2386
- const __m128i tmp_y = _mm_loadu_si128( (const __m128i* ) y[i+u].qs);
2387
- /* Expand bytes into uint16_t values */
2388
- const __m256i bytes_y = _mm256_cvtepu8_epi16(tmp_y);
2389
- /* Unpack values into individual bytes */
2390
- const __m256i pre_shift_y_high_q = _mm256_andnot_si256( lowMask, bytes_y );
2391
- __m256i y_high_q = _mm256_srli_epi16( pre_shift_y_high_q, 4 );
2392
- __m256i y_low_q = _mm256_and_si256( lowMask, bytes_y );
2393
- /* Now we have two vectors with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval. */
2394
- y_high_q = _mm256_sub_epi16( y_high_q, offset_8 );
2395
- y_low_q = _mm256_sub_epi16( y_low_q, offset_8 );
2396
-
2397
- /* Compute products of int16_t integers, add pairwise, store as int32_t */
2398
- __m256i xy_high_q = _mm256_madd_epi16( x_high_q, y_high_q );
2399
- __m256i xy_low_q = _mm256_madd_epi16( x_low_q, y_low_q );
2400
-
2401
- /* Accumulate the products of int32_t integers -> we now have a vector of 8 int_32t */
2402
- __m256i xy_q = _mm256_add_epi32( xy_high_q, xy_low_q );
2403
-
2404
- /* Convert the vector of 8 int32_t to 8 floats */
2405
- __m256 q = _mm256_cvtepi32_ps( xy_q );
2406
-
2407
- /* Multiply q with scale and accumulate */
2408
- acc = _mm256_fmadd_ps( scale, q, acc );
2409
- }
2410
- }
2411
-
2412
- // Return horizontal sum of the acc vector
2413
- __m128 res = _mm256_extractf128_ps( acc, 1 );
2414
- res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
2415
- res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
2416
- res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
2417
-
2418
- sumf = _mm_cvtss_f32( res );
2815
+ *s = hsum_float_8(acc);
2419
2816
  #elif defined(__AVX__)
2420
2817
  // Initialize accumulator with zeros
2421
2818
  __m256 acc = _mm256_setzero_ps();
@@ -2428,13 +2825,12 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
2428
2825
  __m128i i32[2];
2429
2826
  for (int j = 0; j < 2; ++j) {
2430
2827
  // Load 8 bytes, and unpack 4 bit fields into bytes, making 16 bytes
2431
- __m128i bx = bytesFromNibbles( x[i].qs + 8*j );
2432
- __m128i by = bytesFromNibbles( y[i].qs + 8*j );
2828
+ __m128i bx = bytes_from_nibbles_16(x[i].qs + 8*j);
2829
+ __m128i by = _mm_loadu_si128((const __m128i *)(y[i].qs + 16*j));
2433
2830
 
2434
2831
  // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
2435
2832
  const __m128i off = _mm_set1_epi8( 8 );
2436
2833
  bx = _mm_sub_epi8( bx, off );
2437
- by = _mm_sub_epi8( by, off );
2438
2834
 
2439
2835
  // Get absolute values of x vectors
2440
2836
  const __m128i ax = _mm_sign_epi8(bx, bx);
@@ -2445,516 +2841,833 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
2445
2841
  // Perform multiplication and create 16-bit values
2446
2842
  const __m128i dot = _mm_maddubs_epi16(ax, sy);
2447
2843
 
2448
- const __m128i ones = _mm_set1_epi16(1);
2449
- i32[j] = _mm_madd_epi16(ones, dot);
2450
- }
2844
+ const __m128i ones = _mm_set1_epi16(1);
2845
+ i32[j] = _mm_madd_epi16(ones, dot);
2846
+ }
2847
+
2848
+ // Convert int32_t to float
2849
+ __m256 p = _mm256_cvtepi32_ps( _mm256_set_m128i( i32[0], i32[1] ));
2850
+ // Apply the scale, and accumulate
2851
+ acc = _mm256_add_ps(_mm256_mul_ps( d, p ), acc);
2852
+ }
2853
+
2854
+ *s = hsum_float_8(acc);
2855
+ #else
2856
+ // scalar
2857
+ float sumf = 0.0;
2858
+ for (int i = 0; i < nb; i++) {
2859
+ const float d0 = x[i].d;
2860
+ const float d1 = y[i].d;
2861
+
2862
+ const uint8_t * restrict p0 = x[i].qs;
2863
+ const int8_t * restrict p1 = y[i].qs;
2864
+
2865
+ int sumi = 0;
2866
+ for (int j = 0; j < QK8_0/2; j++) {
2867
+ const uint8_t v0 = p0[j];
2868
+
2869
+ const int i0 = (int8_t) (v0 & 0x0F) - 8;
2870
+ const int i1 = (int8_t) (v0 >> 4) - 8;
2871
+
2872
+ const int i2 = p1[2*j + 0];
2873
+ const int i3 = p1[2*j + 1];
2874
+
2875
+ sumi += i0*i2 + i1*i3;
2876
+ }
2877
+ sumf += d0*d1*sumi;
2878
+ }
2879
+ *s = sumf;
2880
+ #endif
2881
+ }
2882
+
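
The q4_0 x q8_0 kernel above comes down to the scalar fallback at the end of the function: unpack two 4-bit quants per byte, recenter them to [-8, 7], take the integer dot product against the q8_0 quants, and apply the combined scale once per block. A self-contained sketch, assuming 32-element blocks and f32 scales (the `blk_*` structs here are illustrative, not the package's types):

    #include <stdint.h>

    #define QK 32  /* assumed block size; matches QK4_0 == QK8_0 in this version */

    typedef struct { float d; uint8_t qs[QK/2]; } blk_q4_0;  /* assumed layout */
    typedef struct { float d; int8_t  qs[QK];   } blk_q8_0;  /* assumed layout */

    static float dot_q4_0_q8_0_scalar(int n, const blk_q4_0 * x, const blk_q8_0 * y) {
        float sumf = 0.0f;
        for (int i = 0; i < n/QK; i++) {
            int sumi = 0;
            for (int j = 0; j < QK/2; j++) {
                const uint8_t v = x[i].qs[j];
                const int i0 = (int)(v & 0x0F) - 8;  /* low nibble  -> [-8, 7] */
                const int i1 = (int)(v >> 4)   - 8;  /* high nibble -> [-8, 7] */
                sumi += i0*y[i].qs[2*j + 0] + i1*y[i].qs[2*j + 1];
            }
            sumf += x[i].d * y[i].d * sumi;          /* one combined scale per block */
        }
        return sumf;
    }
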
2883
+ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
2884
+ const int nb = n / QK8_1;
2885
+
2886
+ assert(n % QK8_1 == 0);
2887
+ assert(nb % 2 == 0);
2888
+
2889
+ const block_q4_1 * restrict x = vx;
2890
+ const block_q8_1 * restrict y = vy;
2891
+
2892
+ // TODO: add AVX / WASM SIMD / etc
2893
+ #if defined(__ARM_NEON)
2894
+ float32x4_t sumv0 = vdupq_n_f32(0.0f);
2895
+ float32x4_t sumv1 = vdupq_n_f32(0.0f);
2896
+
2897
+ float summs = 0;
2898
+
2899
+ for (int i = 0; i < nb; i += 2) {
2900
+ const block_q4_1 * restrict x0 = &x[i + 0];
2901
+ const block_q4_1 * restrict x1 = &x[i + 1];
2902
+ const block_q8_1 * restrict y0 = &y[i + 0];
2903
+ const block_q8_1 * restrict y1 = &y[i + 1];
2904
+
2905
+ summs += x0->m * (y0->s0 + y0->s1) + x1->m * (y1->s0 + y1->s1);
2906
+
2907
+ const uint8x16_t m4b = vdupq_n_u8(0x0F);
2908
+
2909
+ const uint8x16_t v0_0 = vld1q_u8(x0->qs);
2910
+ const uint8x16_t v0_1 = vld1q_u8(x1->qs);
2911
+
2912
+ // 4-bit -> 8-bit
2913
+ const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
2914
+ const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
2915
+ const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
2916
+ const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
2917
+
2918
+ // interleave
2919
+ const int8x16_t v0_0lz = vzip1q_s8(v0_0l, v0_0h);
2920
+ const int8x16_t v0_0hz = vzip2q_s8(v0_0l, v0_0h);
2921
+ const int8x16_t v0_1lz = vzip1q_s8(v0_1l, v0_1h);
2922
+ const int8x16_t v0_1hz = vzip2q_s8(v0_1l, v0_1h);
2923
+
2924
+ // load y
2925
+ const int8x16_t v1_0l = vld1q_s8(y0->qs);
2926
+ const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
2927
+ const int8x16_t v1_1l = vld1q_s8(y1->qs);
2928
+ const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
2929
+
2930
+ #if defined(__ARM_FEATURE_DOTPROD)
2931
+ // dot product into int32x4_t
2932
+ const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0lz, v1_0l), v0_0hz, v1_0h);
2933
+ const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1lz, v1_1l), v0_1hz, v1_1h);
2934
+
2935
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), x0->d*y0->d);
2936
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), x1->d*y1->d);
2937
+ #else
2938
+ const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lz), vget_low_s8 (v1_0l));
2939
+ const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lz), vget_high_s8(v1_0l));
2940
+ const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hz), vget_low_s8 (v1_0h));
2941
+ const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hz), vget_high_s8(v1_0h));
2942
+
2943
+ const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lz), vget_low_s8 (v1_1l));
2944
+ const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lz), vget_high_s8(v1_1l));
2945
+ const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hz), vget_low_s8 (v1_1h));
2946
+ const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hz), vget_high_s8(v1_1h));
2947
+
2948
+ const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
2949
+ const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
2950
+ const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
2951
+ const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
2952
+
2953
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), x0->d*y0->d);
2954
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), x1->d*y1->d);
2955
+ #endif
2956
+ }
2957
+
2958
+ *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs;
2959
+ #elif defined(__AVX2__)
2960
+ // Initialize accumulator with zeros
2961
+ __m256 acc = _mm256_setzero_ps();
2962
+
2963
+ float summs = 0;
2964
+
2965
+ // Main loop
2966
+ for (int i = 0; i < nb; ++i) {
2967
+ const float * d0 = &x[i].d;
2968
+ const float * d1 = &y[i].d;
2969
+
2970
+ summs += x[i].m * (y[i].s0 + y[i].s1);
2971
+
2972
+ const __m256 d0v = _mm256_broadcast_ss( d0 );
2973
+ const __m256 d1v = _mm256_broadcast_ss( d1 );
2974
+
2975
+ // Compute combined scales
2976
+ const __m256 d0d1 = _mm256_mul_ps( d0v, d1v );
2977
+
2978
+ // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
2979
+ const __m256i bx = bytes_from_nibbles_32(x[i].qs);
2980
+ const __m256i by = _mm256_loadu_si256( (const __m256i *)y[i].qs );
2981
+
2982
+ const __m256 xy = mul_sum_i8_pairs_float(bx, by);
2983
+
2984
+ // Accumulate d0*d1*x*y
2985
+ acc = _mm256_fmadd_ps( d0d1, xy, acc );
2986
+ }
2987
+
2988
+ *s = hsum_float_8(acc) + summs;
2989
+ #else
2990
+ // scalar
2991
+ float sumf = 0.0;
2992
+ for (int i = 0; i < nb; i++) {
2993
+ const float d0 = x[i].d;
2994
+ const float m0 = x[i].m;
2995
+ const float d1 = y[i].d;
2996
+
2997
+ const uint8_t * restrict p0 = x[i].qs;
2998
+ const int8_t * restrict p1 = y[i].qs;
2999
+
3000
+ // TODO: this is very slow ..
3001
+ for (int j = 0; j < QK8_1/2; j++) {
3002
+ const uint8_t v0 = p0[j];
3003
+
3004
+ const float f0 = d0*(v0 & 0x0F) + m0;
3005
+ const float f1 = d0*(v0 >> 4) + m0;
3006
+
3007
+ const float f2 = d1*p1[2*j + 0];
3008
+ const float f3 = d1*p1[2*j + 1];
3009
+
3010
+ sumf += f0*f2 + f1*f3;
3011
+ }
3012
+ }
3013
+ *s = sumf;
3014
+ #endif
3015
+ }
3016
+
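
For q4_1 (value = d*q + m) against q8_1, the kernels above keep the offset out of the SIMD loop: sum((d*q + m)*(yd*q8)) splits into d*yd*sum(q*q8) plus m*(s0 + s1), where s0 and s1 appear to hold y->d times the sum of each half of y's quants. A small check of that decomposition (toy data; the meaning of s0/s1 is an assumption, not taken from this diff):

    #include <assert.h>
    #include <math.h>
    #include <stdint.h>

    #define QK 32

    int main(void) {
        /* toy block: q4 quants in [0, 15], q8 quants in [-16, 15] */
        uint8_t q4[QK]; int8_t q8[QK];
        for (int j = 0; j < QK; ++j) { q4[j] = (uint8_t)(j % 16); q8[j] = (int8_t)(j - 16); }
        const float xd = 0.03f, xm = -0.2f, yd = 0.05f;

        /* reference: dequantize both sides, then dot */
        float ref = 0.0f;
        for (int j = 0; j < QK; ++j) ref += (xd*q4[j] + xm) * (yd*q8[j]);

        /* kernel-style: integer dot product plus the offset term */
        int sumi = 0, sum8 = 0;
        for (int j = 0; j < QK; ++j) { sumi += q4[j]*q8[j]; sum8 += q8[j]; }
        const float s0_plus_s1 = yd * sum8;     /* what y->s0 + y->s1 is assumed to hold */
        const float fast = xd*yd*sumi + xm*s0_plus_s1;

        assert(fabsf(ref - fast) < 1e-4f);
        return 0;
    }
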
3017
+ static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
3018
+ const int nb = n / QK8_0;
3019
+
3020
+ assert(n % QK8_0 == 0);
3021
+ assert(nb % 2 == 0);
3022
+ assert(QK8_0 == 2*QK4_2);
3023
+
3024
+ const block_q4_2 * restrict x = vx;
3025
+ const block_q8_0 * restrict y = vy;
3026
+
3027
+ #if defined(__ARM_NEON)
3028
+ float32x4_t sumv0 = vdupq_n_f32(0.0f);
3029
+ float32x4_t sumv1 = vdupq_n_f32(0.0f);
3030
+
3031
+ for (int i = 0; i < nb; i += 2) {
3032
+ const block_q4_2 * restrict x0_0 = &x[2*(i + 0) + 0];
3033
+ const block_q4_2 * restrict x0_1 = &x[2*(i + 0) + 1];
3034
+ const block_q4_2 * restrict x1_0 = &x[2*(i + 1) + 0];
3035
+ const block_q4_2 * restrict x1_1 = &x[2*(i + 1) + 1];
3036
+
3037
+ const block_q8_0 * restrict y0 = &y[i + 0];
3038
+ const block_q8_0 * restrict y1 = &y[i + 1];
3039
+
3040
+ const uint8x16_t m4b = vdupq_n_u8(0x0F);
3041
+ const int8x16_t s8b = vdupq_n_s8(0x8);
3042
+
3043
+ const uint8x16_t v0_0 = vcombine_u8(vld1_u8(x0_0->qs), vld1_u8(x0_1->qs));
3044
+ const uint8x16_t v0_1 = vcombine_u8(vld1_u8(x1_0->qs), vld1_u8(x1_1->qs));
3045
+
3046
+ // 4-bit -> 8-bit
3047
+ const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
3048
+ const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
3049
+ const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
3050
+ const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
3051
+
3052
+ // sub 8
3053
+ const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
3054
+ const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
3055
+ const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
3056
+ const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
3057
+
3058
+ // interleave
3059
+ const int8x16_t v0_0lz = vzip1q_s8(v0_0ls, v0_0hs);
3060
+ const int8x16_t v0_0hz = vzip2q_s8(v0_0ls, v0_0hs);
3061
+ const int8x16_t v0_1lz = vzip1q_s8(v0_1ls, v0_1hs);
3062
+ const int8x16_t v0_1hz = vzip2q_s8(v0_1ls, v0_1hs);
3063
+
3064
+ // load y
3065
+ const int8x16_t v1_0l = vld1q_s8(y0->qs);
3066
+ const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
3067
+ const int8x16_t v1_1l = vld1q_s8(y1->qs);
3068
+ const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
3069
+
3070
+ #if defined(__ARM_FEATURE_DOTPROD)
3071
+ sumv0 = vmlaq_n_f32(sumv0, vaddq_f32(
3072
+ vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0lz, v1_0l)), GGML_FP16_TO_FP32(x0_0->d)),
3073
+ vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0hz, v1_0h)), GGML_FP16_TO_FP32(x0_1->d))), y0->d);
3074
+
3075
+ sumv1 = vmlaq_n_f32(sumv1, vaddq_f32(
3076
+ vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_1lz, v1_1l)), GGML_FP16_TO_FP32(x1_0->d)),
3077
+ vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_1hz, v1_1h)), GGML_FP16_TO_FP32(x1_1->d))), y1->d);
3078
+ #else
3079
+ const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lz), vget_low_s8 (v1_0l));
3080
+ const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lz), vget_high_s8(v1_0l));
3081
+ const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hz), vget_low_s8 (v1_0h));
3082
+ const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hz), vget_high_s8(v1_0h));
3083
+
3084
+ const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lz), vget_low_s8 (v1_1l));
3085
+ const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lz), vget_high_s8(v1_1l));
3086
+ const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hz), vget_low_s8 (v1_1h));
3087
+ const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hz), vget_high_s8(v1_1h));
3088
+
3089
+ const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
3090
+ const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
3091
+ const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
3092
+ const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
3093
+
3094
+ sumv0 = vmlaq_n_f32(sumv0, vaddq_f32(
3095
+ vmulq_n_f32(vcvtq_f32_s32(pl0), GGML_FP16_TO_FP32(x0_0->d)),
3096
+ vmulq_n_f32(vcvtq_f32_s32(ph0), GGML_FP16_TO_FP32(x0_1->d))), y0->d);
3097
+
3098
+ sumv1 = vmlaq_n_f32(sumv1, vaddq_f32(
3099
+ vmulq_n_f32(vcvtq_f32_s32(pl1), GGML_FP16_TO_FP32(x1_0->d)),
3100
+ vmulq_n_f32(vcvtq_f32_s32(ph1), GGML_FP16_TO_FP32(x1_1->d))), y1->d);
3101
+ #endif
3102
+ }
3103
+
3104
+ *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
3105
+ #elif defined(__AVX2__)
3106
+ // Initialize accumulator with zeros
3107
+ __m256 acc = _mm256_setzero_ps();
3108
+
3109
+ // Main loop
3110
+ for (int i = 0; i < nb; i++) {
3111
+ /* Compute combined scale for the block */
3112
+ const __m128 d0 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 0].d));
3113
+ const __m128 d1 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 1].d));
3114
+ const __m256 d = _mm256_mul_ps(_mm256_set_m128(d1, d0), _mm256_broadcast_ss(&y[i].d));
3115
+
3116
+ __m128i bx0 = bytes_from_nibbles_16(x[2*i + 0].qs);
3117
+ __m128i bx1 = bytes_from_nibbles_16(x[2*i + 1].qs);
3118
+ __m256i bx = _mm256_set_m128i(bx1, bx0);
3119
+
3120
+ // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
3121
+ const __m256i off = _mm256_set1_epi8(8);
3122
+ bx = _mm256_sub_epi8(bx, off);
3123
+
3124
+ __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
3125
+
3126
+ const __m256 q = mul_sum_i8_pairs_float(bx, by);
3127
+
3128
+ /* Multiply q with scale and accumulate */
3129
+ acc = _mm256_fmadd_ps(d, q, acc);
3130
+ }
3131
+
3132
+ *s = hsum_float_8(acc);
3133
+ #else
3134
+ // scalar
3135
+ float sumf = 0.0;
3136
+ for (int i = 0; i < nb; i++) {
3137
+ const uint8_t * restrict x0 = x[2*i + 0].qs;
3138
+ const uint8_t * restrict x1 = x[2*i + 1].qs;
3139
+ const int8_t * restrict y0 = y[i].qs;
3140
+
3141
+ const float d0 = GGML_FP16_TO_FP32(x[2*i + 0].d);
3142
+ const float d1 = GGML_FP16_TO_FP32(x[2*i + 1].d);
3143
+
3144
+ int sumi_0 = 0;
3145
+ int sumi_1 = 0;
3146
+
3147
+ for (int j = 0; j < QK8_0/4; j++) {
3148
+ const uint8_t v0 = x0[j];
3149
+ const uint8_t v1 = x1[j];
3150
+
3151
+ const int i0_0 = (int8_t) (v0 & 0x0F) - 8;
3152
+ const int i1_0 = (int8_t) (v0 >> 4) - 8;
3153
+
3154
+ const int i0_1 = (int8_t) (v1 & 0x0F) - 8;
3155
+ const int i1_1 = (int8_t) (v1 >> 4) - 8;
3156
+
3157
+ const int i2_0 = y0[2*j + 0];
3158
+ const int i3_0 = y0[2*j + 1];
3159
+
3160
+ const int i2_1 = y0[2*(j + QK8_0/4) + 0];
3161
+ const int i3_1 = y0[2*(j + QK8_0/4) + 1];
3162
+
3163
+ sumi_0 += i0_0*i2_0 + i1_0*i3_0;
3164
+ sumi_1 += i0_1*i2_1 + i1_1*i3_1;
3165
+ }
3166
+
3167
+ sumf += (d0 * y[i].d) * sumi_0;
3168
+ sumf += (d1 * y[i].d) * sumi_1;
3169
+ }
3170
+ *s = sumf;
3171
+ #endif
3172
+ }
3173
+
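
q4_2 halves the sub-block size: each 32-element q8_0 block is matched against two 16-element q4_2 sub-blocks, each with its own fp16 scale. A sketch of one such pairing, following the scalar path above (scales already converted to float; illustration only):

    #include <stdint.h>

    /* d0/d1 are the two sub-block scales already converted from fp16 to float. */
    static float dot_q4_2_pair(float d0, const uint8_t qs0[8],
                               float d1, const uint8_t qs1[8],
                               float yd, const int8_t yqs[32]) {
        int sumi0 = 0, sumi1 = 0;
        for (int j = 0; j < 8; j++) {
            sumi0 += ((int)(qs0[j] & 0x0F) - 8)*yqs[2*j + 0]
                   + ((int)(qs0[j] >>   4) - 8)*yqs[2*j + 1];
            sumi1 += ((int)(qs1[j] & 0x0F) - 8)*yqs[16 + 2*j + 0]
                   + ((int)(qs1[j] >>   4) - 8)*yqs[16 + 2*j + 1];
        }
        return (d0*sumi0 + d1*sumi1)*yd;
    }
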
3174
+ static void ggml_vec_dot_q4_3_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
3175
+ const int nb = n / QK8_1;
3176
+
3177
+ assert(n % QK8_1 == 0);
3178
+ assert(nb % 2 == 0);
3179
+ assert(QK8_1 == 2*QK4_3);
3180
+
3181
+ const block_q4_3 * restrict x = vx;
3182
+ const block_q8_1 * restrict y = vy;
3183
+
3184
+ #if defined(__ARM_NEON)
3185
+ float32x4_t sumv0 = vdupq_n_f32(0.0f);
3186
+ float32x4_t sumv1 = vdupq_n_f32(0.0f);
3187
+
3188
+ float summs0 = 0.0f;
3189
+ float summs1 = 0.0f;
3190
+
3191
+ for (int i = 0; i < nb; ++i) {
3192
+ const block_q4_3 * restrict x0_0 = &x[2*(i + 0) + 0];
3193
+ const block_q4_3 * restrict x0_1 = &x[2*(i + 0) + 1];
3194
+
3195
+ const block_q8_1 * restrict y0 = &y[i + 0];
3196
+
3197
+ summs0 += GGML_FP16_TO_FP32(x0_0->m) * y0->s0;
3198
+ summs1 += GGML_FP16_TO_FP32(x0_1->m) * y0->s1;
3199
+
3200
+ const uint8x16_t v0_0 = vcombine_u8(vld1_u8(x0_0->qs), vld1_u8(x0_1->qs));
3201
+
3202
+ // 4-bit -> 8-bit
3203
+ const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, vdupq_n_u8(0x0F)));
3204
+ const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
3205
+
3206
+ // interleave
3207
+ const int8x16_t v0_0lz = vzip1q_s8(v0_0l, v0_0h);
3208
+ const int8x16_t v0_0hz = vzip2q_s8(v0_0l, v0_0h);
3209
+
3210
+ // load y
3211
+ const int8x16_t v1_0l = vld1q_s8(y0->qs);
3212
+ const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
3213
+
3214
+ const float x0_0d = GGML_FP16_TO_FP32(x0_0->d);
3215
+ const float x0_1d = GGML_FP16_TO_FP32(x0_1->d);
3216
+
3217
+ #if defined(__ARM_FEATURE_DOTPROD)
3218
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0lz, v1_0l)), x0_0d*y0->d);
3219
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0hz, v1_0h)), x0_1d*y0->d);
3220
+ #else
3221
+ const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lz), vget_low_s8 (v1_0l));
3222
+ const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lz), vget_high_s8(v1_0l));
3223
+ const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hz), vget_low_s8 (v1_0h));
3224
+ const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hz), vget_high_s8(v1_0h));
3225
+
3226
+ const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
3227
+ const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
3228
+
3229
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(pl0), x0_0d*y0->d);
3230
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(ph0), x0_1d*y0->d);
3231
+ #endif
3232
+ }
3233
+
3234
+ *s = vaddvq_f32(vaddq_f32(sumv0, sumv1)) + summs0 + summs1;
3235
+ #elif defined(__AVX2__)
3236
+ // Initialize accumulator with zeros
3237
+ __m256 acc = _mm256_setzero_ps();
3238
+ float summs = 0.0f;
3239
+
3240
+ // Main loop
3241
+ for (int i = 0; i < nb; i++) {
3242
+ const __m128 d0 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 0].d));
3243
+ const __m128 d1 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 1].d));
3244
+ const __m256 dx = _mm256_set_m128(d1, d0);
3245
+
3246
+ summs += GGML_FP16_TO_FP32(x[2*i + 0].m) * y[i].s0
3247
+ + GGML_FP16_TO_FP32(x[2*i + 1].m) * y[i].s1;
3248
+
3249
+ const __m128i bx0 = bytes_from_nibbles_16(x[2*i + 0].qs);
3250
+ const __m128i bx1 = bytes_from_nibbles_16(x[2*i + 1].qs);
3251
+ const __m256i bx = _mm256_set_m128i(bx1, bx0);
3252
+
3253
+ const __m256 dy = _mm256_broadcast_ss(&y[i].d);
3254
+ const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
3255
+
3256
+ const __m256 q = mul_sum_i8_pairs_float(bx, by);
3257
+
3258
+ acc = _mm256_fmadd_ps(q, _mm256_mul_ps(dx, dy), acc);
3259
+ }
3260
+
3261
+ *s = hsum_float_8(acc) + summs;
3262
+ #else
3263
+ // scalar
3264
+ float sumf = 0.0;
3265
+ for (int i = 0; i < nb; i++) {
3266
+ const uint8_t * restrict x0 = x[2*i + 0].qs;
3267
+ const uint8_t * restrict x1 = x[2*i + 1].qs;
3268
+ const int8_t * restrict y0 = y[i].qs;
3269
+
3270
+ const float d0 = GGML_FP16_TO_FP32(x[2*i + 0].d);
3271
+ const float m0 = GGML_FP16_TO_FP32(x[2*i + 0].m);
3272
+ const float d1 = GGML_FP16_TO_FP32(x[2*i + 1].d);
3273
+ const float m1 = GGML_FP16_TO_FP32(x[2*i + 1].m);
3274
+
3275
+ int sxy_0 = 0;
3276
+ int sxy_1 = 0;
3277
+
3278
+ for (int j = 0; j < QK8_1/4; j++) {
3279
+ const uint8_t v0 = x0[j];
3280
+ const uint8_t v1 = x1[j];
3281
+
3282
+ const int x0_0 = v0 & 0x0F;
3283
+ const int x1_0 = v0 >> 4;
3284
+
3285
+ const int x0_1 = v1 & 0x0F;
3286
+ const int x1_1 = v1 >> 4;
3287
+
3288
+ const int y0_0 = y0[2*j + 0];
3289
+ const int y1_0 = y0[2*j + 1];
3290
+
3291
+ const int y0_1 = y0[2*(j + QK8_1/4) + 0];
3292
+ const int y1_1 = y0[2*(j + QK8_1/4) + 1];
3293
+
3294
+ sxy_0 += x0_0*y0_0 + x1_0*y1_0;
3295
+ sxy_1 += x0_1*y0_1 + x1_1*y1_1;
3296
+ }
3297
+
3298
+ sumf += (d0*sxy_0 + d1*sxy_1)*y[i].d + m0*y[i].s0 + m1*y[i].s1;
3299
+ }
3300
+ *s = sumf;
3301
+ #endif
3302
+ }
3303
+
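
q4_3 is the offset variant of q4_2: each 16-element sub-block carries an fp16 scale and offset, and each offset pairs with the matching half-sum of the q8_1 block (m0 with s0, m1 with s1). A sketch of one block pair, mirroring the scalar path above (s0/s1 semantics assumed as for q4_1):

    #include <stdint.h>

    /* d0/m0 and d1/m1 are the two sub-blocks' fp16 scale/offset pairs as floats;
       s0/s1 are assumed to hold y->d times the sum of each half of y's quants. */
    static float dot_q4_3_pair(float d0, float m0, const uint8_t qs0[8],
                               float d1, float m1, const uint8_t qs1[8],
                               float yd, const int8_t yqs[32],
                               float s0, float s1) {
        int sxy0 = 0, sxy1 = 0;
        for (int j = 0; j < 8; j++) {
            sxy0 += (qs0[j] & 0x0F)*yqs[2*j + 0] + (qs0[j] >> 4)*yqs[2*j + 1];
            sxy1 += (qs1[j] & 0x0F)*yqs[16 + 2*j + 0] + (qs1[j] >> 4)*yqs[16 + 2*j + 1];
        }
        return (d0*sxy0 + d1*sxy1)*yd + m0*s0 + m1*s1;
    }
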
3304
+ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
3305
+ const int nb = n / QK8_0;
3306
+
3307
+ assert(n % QK8_0 == 0);
3308
+ assert(nb % 2 == 0);
3309
+ assert(QK8_0 == QK5_0);
3310
+
3311
+ const block_q5_0 * restrict x = vx;
3312
+ const block_q8_0 * restrict y = vy;
2451
3313
 
2452
- // Convert int32_t to float
2453
- __m256 p = _mm256_cvtepi32_ps( _mm256_set_m128i( i32[0], i32[1] ));
2454
- // Apply the scale, and accumulate
2455
- acc = _mm256_add_ps(_mm256_mul_ps( d, p ), acc);
2456
- }
3314
+ #if defined(__ARM_NEON)
3315
+ float32x4_t sumv = vdupq_n_f32(0.0f);
2457
3316
 
2458
- // Return horizontal sum of the acc vector
2459
- __m128 res = _mm256_extractf128_ps( acc, 1 );
2460
- res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
2461
- res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
2462
- res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
3317
+ uint64_t tmp[4];
2463
3318
 
2464
- sumf = _mm_cvtss_f32( res );
2465
- #elif defined(__wasm_simd128__)
2466
- // wasm simd
2467
- float sum0 = 0.0f;
2468
- float sum1 = 0.0f;
3319
+ for (int i = 0; i < nb; ++i) {
3320
+ const block_q5_0 * restrict x0 = &x[i];
3321
+ const block_q8_0 * restrict y0 = &y[i];
2469
3322
 
2470
- for (int i = 0; i < nb; i += 2) {
2471
- const block_q4_0 * restrict x0 = &x[i + 0];
2472
- const block_q4_0 * restrict y0 = &y[i + 0];
2473
- const block_q4_0 * restrict x1 = &x[i + 1];
2474
- const block_q4_0 * restrict y1 = &y[i + 1];
3323
+ const uint8x16_t m4b = vdupq_n_u8(0x0F);
3324
+ const int8x16_t s16b = vdupq_n_s8(0x10);
2475
3325
 
2476
- const v128_t m4b = wasm_u8x16_splat(0xf);
2477
- const v128_t s8b = wasm_i8x16_splat(0x8);
3326
+ // extract the 5th bit
3327
+ uint32_t qh;
3328
+ memcpy(&qh, x0->qh, sizeof(qh));
2478
3329
 
2479
- const v128_t v0_0 = wasm_v128_load(x0->qs);
2480
- const v128_t v0_1 = wasm_v128_load(y0->qs);
2481
- const v128_t v1_0 = wasm_v128_load(x1->qs);
2482
- const v128_t v1_1 = wasm_v128_load(y1->qs);
3330
+ tmp[0] = table_b2b_u[(qh >> 0) & 0xFF];
3331
+ tmp[1] = table_b2b_u[(qh >> 8) & 0xFF];
3332
+ tmp[2] = table_b2b_u[(qh >> 16) & 0xFF];
3333
+ tmp[3] = table_b2b_u[(qh >> 24) ];
2483
3334
 
2484
- // 4-bit -> 8-bit
2485
- const v128_t v0_0l = wasm_v128_and(v0_0, m4b);
2486
- const v128_t v1_0l = wasm_v128_and(v1_0, m4b);
3335
+ const int8x16_t qhl = vld1q_s8((const int8_t *)(tmp + 0));
3336
+ const int8x16_t qhh = vld1q_s8((const int8_t *)(tmp + 2));
2487
3337
 
2488
- const v128_t v0_0h = wasm_u8x16_shr(v0_0, 4);
2489
- const v128_t v1_0h = wasm_u8x16_shr(v1_0, 4);
3338
+ const uint8x16_t v0 = vld1q_u8(x0->qs);
2490
3339
 
2491
- const v128_t v0_1l = wasm_v128_and(v0_1, m4b);
2492
- const v128_t v1_1l = wasm_v128_and(v1_1, m4b);
3340
+ // 4-bit -> 8-bit
3341
+ const int8x16_t v0l = vreinterpretq_s8_u8(vandq_u8 (v0, m4b));
3342
+ const int8x16_t v0h = vreinterpretq_s8_u8(vshrq_n_u8(v0, 4));
2493
3343
 
2494
- const v128_t v0_1h = wasm_u8x16_shr(v0_1, 4);
2495
- const v128_t v1_1h = wasm_u8x16_shr(v1_1, 4);
3344
+ // interleave
3345
+ const int8x16_t v0lz = vzip1q_s8(v0l, v0h);
3346
+ const int8x16_t v0hz = vzip2q_s8(v0l, v0h);
2496
3347
 
2497
- // sub 8
2498
- const v128_t v0_0ls = wasm_i8x16_sub(v0_0l, s8b);
2499
- const v128_t v1_0ls = wasm_i8x16_sub(v1_0l, s8b);
3348
+ // add high bit and sub 16
3349
+ const int8x16_t v0lf = vsubq_s8(vorrq_s8(v0lz, qhl), s16b);
3350
+ const int8x16_t v0hf = vsubq_s8(vorrq_s8(v0hz, qhh), s16b);
2500
3351
 
2501
- const v128_t v0_0hs = wasm_i8x16_sub(v0_0h, s8b);
2502
- const v128_t v1_0hs = wasm_i8x16_sub(v1_0h, s8b);
3352
+ // load y
3353
+ const int8x16_t v1l = vld1q_s8(y0->qs);
3354
+ const int8x16_t v1h = vld1q_s8(y0->qs + 16);
2503
3355
 
2504
- const v128_t v0_1ls = wasm_i8x16_sub(v0_1l, s8b);
2505
- const v128_t v1_1ls = wasm_i8x16_sub(v1_1l, s8b);
3356
+ const float x0d = GGML_FP16_TO_FP32(x0->d);
2506
3357
 
2507
- const v128_t v0_1hs = wasm_i8x16_sub(v0_1h, s8b);
2508
- const v128_t v1_1hs = wasm_i8x16_sub(v1_1h, s8b);
3358
+ #if defined(__ARM_FEATURE_DOTPROD)
3359
+ sumv = vmlaq_n_f32(sumv, vcvtq_f32_s32(vaddq_s32(
3360
+ vdotq_s32(vdupq_n_s32(0), v0lf, v1l),
3361
+ vdotq_s32(vdupq_n_s32(0), v0hf, v1h))), x0d*y0->d);
3362
+ #else
3363
+ const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0lf), vget_low_s8 (v1l));
3364
+ const int16x8_t pl0h = vmull_s8(vget_high_s8(v0lf), vget_high_s8(v1l));
3365
+ const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0hf), vget_low_s8 (v1h));
3366
+ const int16x8_t ph0h = vmull_s8(vget_high_s8(v0hf), vget_high_s8(v1h));
2509
3367
 
2510
- // dot product into int16x8_t
2511
- const v128_t pl0l = wasm_i16x8_mul(wasm_i16x8_extend_low_i8x16(v0_0ls), wasm_i16x8_extend_low_i8x16(v1_0ls));
2512
- const v128_t pl0h = wasm_i16x8_mul(wasm_i16x8_extend_high_i8x16(v0_0ls), wasm_i16x8_extend_high_i8x16(v1_0ls));
3368
+ const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
3369
+ const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
2513
3370
 
2514
- const v128_t ph0l = wasm_i16x8_mul(wasm_i16x8_extend_low_i8x16(v0_0hs), wasm_i16x8_extend_low_i8x16(v1_0hs));
2515
- const v128_t ph0h = wasm_i16x8_mul(wasm_i16x8_extend_high_i8x16(v0_0hs), wasm_i16x8_extend_high_i8x16(v1_0hs));
3371
+ sumv = vmlaq_n_f32(sumv, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), x0d*y0->d);
3372
+ #endif
3373
+ }
2516
3374
 
2517
- const v128_t pl1l = wasm_i16x8_mul(wasm_i16x8_extend_low_i8x16(v0_1ls), wasm_i16x8_extend_low_i8x16(v1_1ls));
2518
- const v128_t pl1h = wasm_i16x8_mul(wasm_i16x8_extend_high_i8x16(v0_1ls), wasm_i16x8_extend_high_i8x16(v1_1ls));
3375
+ *s = vaddvq_f32(sumv);
3376
+ #elif defined(__AVX2__)
3377
+ // Initialize accumulator with zeros
3378
+ __m256 acc = _mm256_setzero_ps();
2519
3379
 
2520
- const v128_t ph1l = wasm_i16x8_mul(wasm_i16x8_extend_low_i8x16(v0_1hs), wasm_i16x8_extend_low_i8x16(v1_1hs));
2521
- const v128_t ph1h = wasm_i16x8_mul(wasm_i16x8_extend_high_i8x16(v0_1hs), wasm_i16x8_extend_high_i8x16(v1_1hs));
3380
+ // Main loop
3381
+ for (int i = 0; i < nb; i++) {
3382
+ /* Compute combined scale for the block */
3383
+ const __m256 d = _mm256_mul_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d)), _mm256_broadcast_ss(&y[i].d));
2522
3384
 
2523
- const v128_t pl_0 = wasm_i16x8_add(pl0l, pl0h);
2524
- const v128_t ph_0 = wasm_i16x8_add(ph0l, ph0h);
3385
+ __m256i bx = bytes_from_nibbles_32(x[i].qs);
3386
+ __m256i bxhi = bytes_from_bits_32(x[i].qh);
3387
+ bxhi = _mm256_andnot_si256(bxhi, _mm256_set1_epi8((char)0xF0));
3388
+ bx = _mm256_or_si256(bx, bxhi);
2525
3389
 
2526
- const v128_t pl_1 = wasm_i16x8_add(pl1l, pl1h);
2527
- const v128_t ph_1 = wasm_i16x8_add(ph1l, ph1h);
3390
+ __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
2528
3391
 
2529
- const v128_t p_0 = wasm_i16x8_add(pl_0, ph_0);
2530
- const v128_t p_1 = wasm_i16x8_add(pl_1, ph_1);
3392
+ const __m256 q = mul_sum_i8_pairs_float(bx, by);
2531
3393
 
2532
- sum0 += x0->d * y0->d * (
2533
- wasm_i16x8_extract_lane(p_0, 0) + wasm_i16x8_extract_lane(p_0, 1) +
2534
- wasm_i16x8_extract_lane(p_0, 2) + wasm_i16x8_extract_lane(p_0, 3) +
2535
- wasm_i16x8_extract_lane(p_0, 4) + wasm_i16x8_extract_lane(p_0, 5) +
2536
- wasm_i16x8_extract_lane(p_0, 6) + wasm_i16x8_extract_lane(p_0, 7));
2537
- sum1 += x1->d * y1->d * (
2538
- wasm_i16x8_extract_lane(p_1, 0) + wasm_i16x8_extract_lane(p_1, 1) +
2539
- wasm_i16x8_extract_lane(p_1, 2) + wasm_i16x8_extract_lane(p_1, 3) +
2540
- wasm_i16x8_extract_lane(p_1, 4) + wasm_i16x8_extract_lane(p_1, 5) +
2541
- wasm_i16x8_extract_lane(p_1, 6) + wasm_i16x8_extract_lane(p_1, 7));
3394
+ /* Multiply q with scale and accumulate */
3395
+ acc = _mm256_fmadd_ps(d, q, acc);
2542
3396
  }
2543
3397
 
2544
- sumf = sum0 + sum1;
3398
+ *s = hsum_float_8(acc);
2545
3399
  #else
2546
3400
  // scalar
3401
+ float sumf = 0.0;
2547
3402
  for (int i = 0; i < nb; i++) {
2548
- const float d0 = x[i].d;
2549
- const float d1 = y[i].d;
3403
+ const uint8_t * restrict x0 = x[i].qs;
3404
+ const int8_t * restrict y0 = y[i].qs;
2550
3405
 
2551
- const uint8_t * restrict p0 = x[i].qs;
2552
- const uint8_t * restrict p1 = y[i].qs;
3406
+ uint32_t qh;
3407
+ memcpy(&qh, x[i].qh, sizeof(qh));
2553
3408
 
2554
- int sumi = 0;
2555
- for (int j = 0; j < QK4_0/2; j++) {
2556
- const uint8_t v0 = p0[j];
2557
- const uint8_t v1 = p1[j];
3409
+ const float d = GGML_FP16_TO_FP32(x[i].d);
3410
+
3411
+ int sxy = 0;
2558
3412
 
2559
- const int i0 = (v0 & 0xf) - 8;
2560
- const int i1 = (v0 >> 4) - 8;
3413
+ for (int j = 0; j < QK8_0/2; j++) {
3414
+ const uint8_t v0 = x0[j];
2561
3415
 
2562
- const int i2 = (v1 & 0xf) - 8;
2563
- const int i3 = (v1 >> 4) - 8;
3416
+ const int x0_0h = ((qh & (1 << (2*j + 0))) >> (2*j + 0)) << 4;
3417
+ const int x1_0h = ((qh & (1 << (2*j + 1))) >> (2*j + 1)) << 4;
2564
3418
 
2565
- sumi += i0*i2 + i1*i3;
3419
+ const int x0_0 = ((v0 & 0x0F) | x0_0h) - 16;
3420
+ const int x1_0 = ((v0 >> 4) | x1_0h) - 16;
3421
+
3422
+ const int y0_0 = y0[2*j + 0];
3423
+ const int y1_0 = y0[2*j + 1];
3424
+
3425
+ sxy += x0_0*y0_0 + x1_0*y1_0;
2566
3426
  }
2567
- sumf += d0 * d1 * sumi;
2568
- }
2569
- #endif
2570
3427
 
3428
+ sumf += (d*sxy)*y[i].d;
3429
+ }
2571
3430
  *s = sumf;
3431
+ #endif
2572
3432
  }
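
q5_0 stores the fifth bit of each quant in the separate 32-bit field qh. A sketch of the per-quant reconstruction used by the scalar path above (illustration only, not the package's code):

    #include <stdint.h>

    /* Reconstruct the k-th quant (k in [0, 31]) of a q5_0 block: the low 4 bits
       come from the packed nibbles in qs, the 5th bit from the 32-bit field qh,
       and 16 is subtracted to center the value in [-16, 15]. */
    static inline int q5_0_unpack(const uint8_t * qs, uint32_t qh, int k) {
        const uint8_t byte   = qs[k/2];
        const int     nibble = (k % 2 == 0) ? (byte & 0x0F) : (byte >> 4);
        const int     high   = (int)((qh >> k) & 1) << 4;
        return (nibble | high) - 16;
    }

The NEON path builds the same 0x10 high bits in bulk by indexing table_b2b_u with successive bytes of qh and OR-ing the result into the interleaved nibbles before subtracting 16.
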
2573
3433
 
2574
- static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
2575
- const int nb = n / QK4_1;
3434
+ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
3435
+ const int nb = n / QK8_1;
2576
3436
 
2577
- const block_q4_1 * restrict x = vx;
2578
- const block_q4_1 * restrict y = vy;
2579
-
2580
- float sumf = 0.0;
3437
+ assert(n % QK8_1 == 0);
3438
+ assert(nb % 2 == 0);
3439
+ assert(QK8_1 == QK5_1);
2581
3440
 
2582
- #if defined(__AVX2__)
2583
- // Initialize accumulator with zeros
2584
- __m256 acc = _mm256_setzero_ps();
2585
- // Accumulator for constant offsets
2586
- float acc_offset = 0.0f;
3441
+ const block_q5_1 * restrict x = vx;
3442
+ const block_q8_1 * restrict y = vy;
2587
3443
 
2588
- // Main loop
2589
- for (int i = 0; i < nb; ++i) {
2590
- const float * d0 = &x[i].d;
2591
- const float * d1 = &y[i].d;
3444
+ #if defined(__ARM_NEON)
3445
+ float32x4_t sumv = vdupq_n_f32(0.0f);
2592
3446
 
2593
- const float * m0 = &x[i].m;
2594
- const float * m1 = &y[i].m;
3447
+ float summs = 0.0f;
2595
3448
 
2596
- const __m256 d0v = _mm256_broadcast_ss( d0 );
2597
- const __m256 d1v = _mm256_broadcast_ss( d1 );
2598
- const __m256 m0v = _mm256_broadcast_ss( m0 );
2599
- const __m256 m1v = _mm256_broadcast_ss( m1 );
3449
+ uint64_t tmp[4];
2600
3450
 
2601
- // Compute combined scale for the block
2602
- const __m256 scale_01 = _mm256_mul_ps( d0v, d1v );
3451
+ for (int i = 0; i < nb; ++i) {
3452
+ const block_q5_1 * restrict x0 = &x[i];
3453
+ const block_q8_1 * restrict y0 = &y[i];
2603
3454
 
2604
- // Compute cross scales for the block
2605
- const __m256 scale_0 = _mm256_mul_ps( d0v, m1v );
2606
- const __m256 scale_1 = _mm256_mul_ps( m0v, d1v );
2607
- const __m256 cross_scales = _mm256_blend_ps( scale_0, scale_1, 0xAA /* 0b10101010 */ );
3455
+ summs += GGML_FP16_TO_FP32(x0->m) * (y0->s0 + y0->s1);
2608
3456
 
2609
- // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
2610
- __m256i bx = bytesFromNibbles( x[i].qs );
2611
- __m256i by = bytesFromNibbles( y[i].qs );
2612
-
2613
- // Now we have a vector with bytes in [ 0 .. 15 ] interval.
2614
-
2615
- // Sign-extend first 16 signed bytes into int16_t
2616
- __m256i x16 = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( bx ) );
2617
- __m256i y16 = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( by ) );
2618
- // Compute products of int16_t integers, add pairwise
2619
- __m256i i32 = _mm256_madd_epi16( x16, y16 );
2620
-
2621
- // Sign-extend last 16 signed bytes into int16_t vectors
2622
- __m256i x16_h = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( bx, 1 ) );
2623
- __m256i y16_h = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( by, 1 ) );
2624
- // Accumulate products of int16_t integers
2625
- i32 = _mm256_add_epi32( i32, _mm256_madd_epi16( x16_h, y16_h ) );
2626
-
2627
- // compute sums of unsigned bytes in bx, by in blocks of 8.
2628
- // This results in a layout like X100 0000 X200 0000 X300 0000 X400 0000,
2629
- // which we then interleave as X100 Y100 X200 Y200 X300 Y300 X400 Y400.
2630
- // so if we then cast to 8 singles, we get 8 floats like [ x0_7, y0_7, x8_15, y8_15, x16_23, y16_23, x24_31, y24_31 ]
2631
- __m256i xsumi = _mm256_sad_epu8( bx, _mm256_setzero_si256() );
2632
- __m256i ysumi = _mm256_sad_epu8( by, _mm256_setzero_si256() );
2633
- __m256i sumsi = _mm256_or_si256( xsumi, _mm256_slli_si256( ysumi, 4 ) );
2634
- __m256 sums = _mm256_cvtepi32_ps( sumsi );
3457
+ // extract the 5th bit
3458
+ uint32_t qh;
3459
+ memcpy(&qh, x0->qh, sizeof(qh));
2635
3460
 
2636
- // Convert int32_t to float
2637
- __m256 p = _mm256_cvtepi32_ps( i32 );
2638
- // Apply the scale, and accumulate
2639
- // acc += d0*d1*x*y + d0*m1*x + d1*m0*y
2640
- acc = _mm256_fmadd_ps( scale_01, p, acc );
2641
- acc = _mm256_fmadd_ps( cross_scales, sums, acc );
2642
- // acc_offset += m0*m1 (for each entry in the block)
2643
- acc_offset += (*m0)*(*m1);
2644
- }
3461
+ tmp[0] = table_b2b_u[(qh >> 0) & 0xFF];
3462
+ tmp[1] = table_b2b_u[(qh >> 8) & 0xFF];
3463
+ tmp[2] = table_b2b_u[(qh >> 16) & 0xFF];
3464
+ tmp[3] = table_b2b_u[(qh >> 24) ];
2645
3465
 
2646
- // Return horizontal sum of the acc vector
2647
- __m128 res = _mm256_extractf128_ps( acc, 1 );
2648
- res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
2649
- res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
2650
- res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
3466
+ const int8x16_t qhl = vld1q_s8((const int8_t *)(tmp + 0));
3467
+ const int8x16_t qhh = vld1q_s8((const int8_t *)(tmp + 2));
2651
3468
 
2652
- sumf = _mm_cvtss_f32( res ) + acc_offset * QK4_1;
2653
- #elif defined(__ARM_NEON)
2654
- float sum00 = 0.0f;
2655
- float sum01 = 0.0f;
2656
- float sum10 = 0.0f;
2657
- float sum11 = 0.0f;
3469
+ const uint8x16_t v0 = vld1q_u8(x0->qs);
2658
3470
 
2659
- for (int i = 0; i < nb; i += 2) {
2660
- const block_q4_1 * restrict x0 = &x[i + 0];
2661
- const block_q4_1 * restrict y0 = &y[i + 0];
2662
- const block_q4_1 * restrict x1 = &x[i + 1];
2663
- const block_q4_1 * restrict y1 = &y[i + 1];
3471
+ // 4-bit -> 8-bit
3472
+ const int8x16_t v0l = vreinterpretq_s8_u8(vandq_u8 (v0, vdupq_n_u8(0x0F)));
3473
+ const int8x16_t v0h = vreinterpretq_s8_u8(vshrq_n_u8(v0, 4));
2664
3474
 
2665
- const uint8x16_t m4b = vdupq_n_u8(0xf);
3475
+ // interleave
3476
+ const int8x16_t v0lz = vzip1q_s8(v0l, v0h);
3477
+ const int8x16_t v0hz = vzip2q_s8(v0l, v0h);
2666
3478
 
2667
- const uint8x16_t v0_0 = vld1q_u8(x0->qs);
2668
- const uint8x16_t v1_0 = vld1q_u8(y0->qs);
2669
- const uint8x16_t v0_1 = vld1q_u8(x1->qs);
2670
- const uint8x16_t v1_1 = vld1q_u8(y1->qs);
3479
+ // add
3480
+ const int8x16_t v0lf = vorrq_s8(v0lz, qhl);
3481
+ const int8x16_t v0hf = vorrq_s8(v0hz, qhh);
2671
3482
 
2672
- // 4-bit -> 8-bit
2673
- const uint8x16_t v0_0l = vandq_u8(v0_0, m4b);
2674
- const uint8x16_t v1_0l = vandq_u8(v1_0, m4b);
2675
- const uint8x16_t v0_0h = vshrq_n_u8(v0_0, 4);
2676
- const uint8x16_t v1_0h = vshrq_n_u8(v1_0, 4);
3483
+ // load y
3484
+ const int8x16_t v1l = vld1q_s8(y0->qs);
3485
+ const int8x16_t v1h = vld1q_s8(y0->qs + 16);
2677
3486
 
2678
- const uint8x16_t v0_1l = vandq_u8(v0_1, m4b);
2679
- const uint8x16_t v1_1l = vandq_u8(v1_1, m4b);
2680
- const uint8x16_t v0_1h = vshrq_n_u8(v0_1, 4);
2681
- const uint8x16_t v1_1h = vshrq_n_u8(v1_1, 4);
3487
+ const float x0d = GGML_FP16_TO_FP32(x0->d);
2682
3488
 
2683
- sum00 += x0->m*y0->m;
2684
- sum01 += y0->m*x0->d*((uint16_t)vaddvq_u8(v0_0l) + (uint16_t)vaddvq_u8(v0_0h));
2685
- sum10 += x0->m*y0->d*((uint16_t)vaddvq_u8(v1_0l) + (uint16_t)vaddvq_u8(v1_0h));
3489
+ #if defined(__ARM_FEATURE_DOTPROD)
3490
+ sumv = vmlaq_n_f32(sumv, vcvtq_f32_s32(vaddq_s32(
3491
+ vdotq_s32(vdupq_n_s32(0), v0lf, v1l),
3492
+ vdotq_s32(vdupq_n_s32(0), v0hf, v1h))), x0d*y0->d);
3493
+ #else
3494
+ const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0lf), vget_low_s8 (v1l));
3495
+ const int16x8_t pl0h = vmull_s8(vget_high_s8(v0lf), vget_high_s8(v1l));
3496
+ const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0hf), vget_low_s8 (v1h));
3497
+ const int16x8_t ph0h = vmull_s8(vget_high_s8(v0hf), vget_high_s8(v1h));
2686
3498
 
2687
- sum00 += x1->m*y1->m;
2688
- sum01 += y1->m*x1->d*((uint16_t)vaddvq_u8(v0_1l) + (uint16_t)vaddvq_u8(v0_1h));
2689
- sum10 += x1->m*y1->d*((uint16_t)vaddvq_u8(v1_1l) + (uint16_t)vaddvq_u8(v1_1h));
3499
+ const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
3500
+ const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
2690
3501
 
2691
- #if defined(__ARM_FEATURE_DOTPROD)
2692
- // dot product into int32x4_t
2693
- uint32x4_t p_0 = vdotq_u32(vdupq_n_u32(0), v0_0l, v1_0l);
2694
- uint32x4_t p_1 = vdotq_u32(vdupq_n_u32(0), v0_1l, v1_1l);
3502
+ sumv = vmlaq_n_f32(sumv, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), x0d*y0->d);
3503
+ #endif
3504
+ }
2695
3505
 
2696
- p_0 = vdotq_u32(p_0, v0_0h, v1_0h);
2697
- p_1 = vdotq_u32(p_1, v0_1h, v1_1h);
3506
+ *s = vaddvq_f32(sumv) + summs;
3507
+ #elif defined(__AVX2__)
3508
+ // Initialize accumulator with zeros
3509
+ __m256 acc = _mm256_setzero_ps();
3510
+ float summs = 0.0f;
2698
3511
 
2699
- sum11 += x0->d*y0->d*vaddvq_u32(p_0);
2700
- sum11 += x1->d*y1->d*vaddvq_u32(p_1);
2701
- #else
2702
- const uint16x8_t pl0l = vmull_u8(vget_low_u8 (v0_0l), vget_low_u8 (v1_0l));
2703
- const uint16x8_t pl0h = vmull_u8(vget_high_u8(v0_0l), vget_high_u8(v1_0l));
2704
- const uint16x8_t ph0l = vmull_u8(vget_low_u8 (v0_0h), vget_low_u8 (v1_0h));
2705
- const uint16x8_t ph0h = vmull_u8(vget_high_u8(v0_0h), vget_high_u8(v1_0h));
3512
+ // Main loop
3513
+ for (int i = 0; i < nb; i++) {
3514
+ const __m256 dx = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d));
2706
3515
 
2707
- const uint16x8_t pl1l = vmull_u8(vget_low_u8 (v0_1l), vget_low_u8 (v1_1l));
2708
- const uint16x8_t pl1h = vmull_u8(vget_high_u8(v0_1l), vget_high_u8(v1_1l));
2709
- const uint16x8_t ph1l = vmull_u8(vget_low_u8 (v0_1h), vget_low_u8 (v1_1h));
2710
- const uint16x8_t ph1h = vmull_u8(vget_high_u8(v0_1h), vget_high_u8(v1_1h));
3516
+ summs += GGML_FP16_TO_FP32(x[i].m) * (y[i].s0 + y[i].s1);
2711
3517
 
2712
- const uint16x8_t pl_0 = vaddq_u16(pl0l, pl0h);
2713
- const uint16x8_t ph_0 = vaddq_u16(ph0l, ph0h);
3518
+ __m256i bx = bytes_from_nibbles_32(x[i].qs);
3519
+ __m256i bxhi = bytes_from_bits_32(x[i].qh);
3520
+ bxhi = _mm256_and_si256(bxhi, _mm256_set1_epi8(0x10));
3521
+ bx = _mm256_or_si256(bx, bxhi);
2714
3522
 
2715
- const uint16x8_t pl_1 = vaddq_u16(pl1l, pl1h);
2716
- const uint16x8_t ph_1 = vaddq_u16(ph1l, ph1h);
3523
+ const __m256 dy = _mm256_broadcast_ss(&y[i].d);
3524
+ const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
2717
3525
 
2718
- const uint16x8_t p_0 = vaddq_u16(pl_0, ph_0);
2719
- const uint16x8_t p_1 = vaddq_u16(pl_1, ph_1);
3526
+ const __m256 q = mul_sum_i8_pairs_float(bx, by);
2720
3527
 
2721
- sum11 += x0->d*y0->d*vaddvq_u16(p_0);
2722
- sum11 += x1->d*y1->d*vaddvq_u16(p_1);
2723
- #endif
3528
+ acc = _mm256_fmadd_ps(q, _mm256_mul_ps(dx, dy), acc);
2724
3529
  }
2725
3530
 
2726
- sumf = QK4_1*sum00 + sum01 + sum10 + sum11;
3531
+ *s = hsum_float_8(acc) + summs;
2727
3532
  #else
2728
- // scalar
3533
+ float sumf = 0.0;
3534
+
2729
3535
  for (int i = 0; i < nb; i++) {
2730
- const float d0 = x[i].d;
2731
- const float d1 = y[i].d;
3536
+ const uint8_t * restrict x0 = x[i].qs;
3537
+ const int8_t * restrict y0 = y[i].qs;
2732
3538
 
2733
- const float m0 = x[i].m;
2734
- const float m1 = y[i].m;
3539
+ uint32_t qh;
3540
+ memcpy(&qh, x[i].qh, sizeof(qh));
2735
3541
 
2736
- const uint8_t * restrict p0 = x[i].qs;
2737
- const uint8_t * restrict p1 = y[i].qs;
3542
+ const float d = GGML_FP16_TO_FP32(x[i].d);
3543
+ const float m = GGML_FP16_TO_FP32(x[i].m);
2738
3544
 
2739
- for (int j = 0; j < QK4_1/2; j++) {
2740
- const uint8_t v0 = p0[j];
2741
- const uint8_t v1 = p1[j];
3545
+ int sxy = 0;
2742
3546
 
2743
- const float f0 = d0*(v0 & 0xf) + m0;
2744
- const float f1 = d0*(v0 >> 4) + m0;
3547
+ for (int j = 0; j < QK8_1/2; j++) {
3548
+ const uint8_t v0 = x0[j];
2745
3549
 
2746
- const float f2 = d1*(v1 & 0xf) + m1;
2747
- const float f3 = d1*(v1 >> 4) + m1;
3550
+ const int x0_0h = ((qh & (1 << (2*j + 0))) >> (2*j + 0)) << 4;
3551
+ const int x1_0h = ((qh & (1 << (2*j + 1))) >> (2*j + 1)) << 4;
2748
3552
 
2749
- sumf += f0*f2 + f1*f3;
3553
+ const int x0_0 = (v0 & 0x0F) | x0_0h;
3554
+ const int x1_0 = (v0 >> 4) | x1_0h;
3555
+
3556
+ const int y0_0 = y0[2*j + 0];
3557
+ const int y1_0 = y0[2*j + 1];
3558
+
3559
+ sxy += x0_0*y0_0 + x1_0*y1_0;
2750
3560
  }
3561
+
3562
+ sumf += (d*sxy)*y[i].d + m*(y[i].s0 + y[i].s1);
2751
3563
  }
2752
- #endif
2753
3564
 
2754
3565
  *s = sumf;
3566
+ #endif
2755
3567
  }
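Note on the hunk above: the dot product it introduces relies on the *_1 decomposition. Each stored value is d*q + m, where q is rebuilt from the low nibble in qs plus one bit of qh shifted up to bit 4, so the block dot product splits into d*dy*sum(q_j*qy_j) plus m times the sum of the dequantized y values; that second term is what the precomputed s0/s1 fields of the q8_1 block feed into summs. A scalar sketch of the split on a toy block (stand-in sizes and values, not the real block layouts):

    #include <stdint.h>
    #include <stdio.h>

    // Toy *_1-style block dot product: x_j = d*q_j + m, y_j = dy*qy_j.
    // dot(x, y) = d*dy*sum(q_j*qy_j) + m*(dy*sum(qy_j)); the last factor is the
    // kind of per-block y sum that q8_1 precomputes as s0/s1.
    int main(void) {
        enum { QK = 8 };                                       // toy block size
        const int    q [QK] = { 0, 3, 7, 12, 15, 20, 27, 31 }; // 5-bit quants of x
        const int8_t qy[QK] = { -4, 2, 1, -1, 3, 0, -2, 5 };   // int8 quants of y
        const float d = 0.1f, m = -1.5f, dy = 0.05f;

        int   sxy   = 0;    // sum of integer products
        float sum_y = 0.0f; // sum of dequantized y values (plays the role of s0 + s1)
        for (int j = 0; j < QK; ++j) {
            sxy   += q[j] * (int)qy[j];
            sum_y += dy * qy[j];
        }
        const float dot_split = (d*sxy)*dy + m*sum_y;

        // same value computed directly, element by element
        float dot_ref = 0.0f;
        for (int j = 0; j < QK; ++j) {
            dot_ref += (d*q[j] + m) * (dy*qy[j]);
        }
        printf("split = %f  reference = %f\n", dot_split, dot_ref);
        return 0;
    }

The AVX2 path above performs exactly this split: acc collects the q*qy products scaled by dx*dy, while summs collects the m*(s0 + s1) terms added once at the end.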
2756
3568
 
2757
- static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
3569
+ static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
2758
3570
  const int nb = n / QK8_0;
2759
3571
 
2760
3572
  assert(n % QK8_0 == 0);
2761
3573
  assert(nb % 2 == 0);
3574
+ assert(QK8_0 == QK8_0);
2762
3575
 
2763
- const block_q4_0 * restrict x = vx;
3576
+ const block_q8_0 * restrict x = vx;
2764
3577
  const block_q8_0 * restrict y = vy;
2765
3578
 
2766
- float sumf = 0.0;
2767
-
2768
3579
  #if defined(__ARM_NEON)
2769
- float sum0 = 0.0f;
2770
- float sum1 = 0.0f;
3580
+ float32x4_t sumv0 = vdupq_n_f32(0.0f);
3581
+ float32x4_t sumv1 = vdupq_n_f32(0.0f);
2771
3582
 
2772
3583
  for (int i = 0; i < nb; i += 2) {
2773
- const block_q4_0 * restrict x0 = &x[i + 0];
2774
- const block_q4_0 * restrict x1 = &x[i + 1];
3584
+ const block_q8_0 * restrict x0 = &x[i + 0];
3585
+ const block_q8_0 * restrict x1 = &x[i + 1];
2775
3586
  const block_q8_0 * restrict y0 = &y[i + 0];
2776
3587
  const block_q8_0 * restrict y1 = &y[i + 1];
2777
3588
 
2778
- const uint8x16_t m4b = vdupq_n_u8(0xf);
2779
- const int8x16_t s8b = vdupq_n_s8(0x8);
2780
-
2781
- const uint8x16_t v0_0 = vld1q_u8(x0->qs);
2782
- const uint8x16_t v0_1 = vld1q_u8(x1->qs);
2783
-
2784
- // 4-bit -> 8-bit
2785
- const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
2786
- const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
2787
- const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
2788
- const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
2789
-
2790
- // sub 8
2791
- const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
2792
- const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
2793
- const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
2794
- const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
3589
+ const int8x16_t x0_0 = vld1q_s8(x0->qs);
3590
+ const int8x16_t x0_1 = vld1q_s8(x0->qs + 16);
3591
+ const int8x16_t x1_0 = vld1q_s8(x1->qs);
3592
+ const int8x16_t x1_1 = vld1q_s8(x1->qs + 16);
2795
3593
 
2796
3594
  // load y
2797
- const int8x16_t v1_0l = vld1q_s8(y0->qs);
2798
- const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
2799
- const int8x16_t v1_1l = vld1q_s8(y1->qs);
2800
- const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
2801
-
2802
- // interleave
2803
- const int8x16_t v1_0ls = vuzp1q_s8(v1_0l, v1_0h);
2804
- const int8x16_t v1_0hs = vuzp2q_s8(v1_0l, v1_0h);
2805
- const int8x16_t v1_1ls = vuzp1q_s8(v1_1l, v1_1h);
2806
- const int8x16_t v1_1hs = vuzp2q_s8(v1_1l, v1_1h);
3595
+ const int8x16_t y0_0 = vld1q_s8(y0->qs);
3596
+ const int8x16_t y0_1 = vld1q_s8(y0->qs + 16);
3597
+ const int8x16_t y1_0 = vld1q_s8(y1->qs);
3598
+ const int8x16_t y1_1 = vld1q_s8(y1->qs + 16);
2807
3599
 
2808
3600
  #if defined(__ARM_FEATURE_DOTPROD)
2809
- // dot product into int32x4_t
2810
- int32x4_t p_0 = vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls);
2811
- int32x4_t p_1 = vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls);
3601
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
3602
+ vdotq_s32(vdupq_n_s32(0), x0_0, y0_0),
3603
+ vdotq_s32(vdupq_n_s32(0), x0_1, y0_1))), x0->d*y0->d);
2812
3604
 
2813
- p_0 = vdotq_s32(p_0, v0_0hs, v1_0hs);
2814
- p_1 = vdotq_s32(p_1, v0_1hs, v1_1hs);
3605
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
3606
+ vdotq_s32(vdupq_n_s32(0), x1_0, y1_0),
3607
+ vdotq_s32(vdupq_n_s32(0), x1_1, y1_1))), x1->d*y1->d);
2815
3608
 
2816
- sum0 += x0->d*y0->d*vaddvq_s32(p_0);
2817
- sum1 += x1->d*y1->d*vaddvq_s32(p_1);
2818
3609
  #else
2819
- const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls));
2820
- const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls));
2821
- const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0hs));
2822
- const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0hs));
2823
-
2824
- const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1ls));
2825
- const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1ls));
2826
- const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1hs));
2827
- const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1hs));
2828
-
2829
- const int16x8_t pl_0 = vaddq_s16(pl0l, pl0h);
2830
- const int16x8_t ph_0 = vaddq_s16(ph0l, ph0h);
2831
-
2832
- const int16x8_t pl_1 = vaddq_s16(pl1l, pl1h);
2833
- const int16x8_t ph_1 = vaddq_s16(ph1l, ph1h);
2834
-
2835
- const int16x8_t p_0 = vaddq_s16(pl_0, ph_0);
2836
- const int16x8_t p_1 = vaddq_s16(pl_1, ph_1);
2837
-
2838
- sum0 += x0->d*y0->d*vaddvq_s16(p_0);
2839
- sum1 += x1->d*y1->d*vaddvq_s16(p_1);
3610
+ const int16x8_t p0_0 = vmull_s8(vget_low_s8 (x0_0), vget_low_s8 (y0_0));
3611
+ const int16x8_t p0_1 = vmull_s8(vget_high_s8(x0_0), vget_high_s8(y0_0));
3612
+ const int16x8_t p0_2 = vmull_s8(vget_low_s8 (x0_1), vget_low_s8 (y0_1));
3613
+ const int16x8_t p0_3 = vmull_s8(vget_high_s8(x0_1), vget_high_s8(y0_1));
3614
+
3615
+ const int16x8_t p1_0 = vmull_s8(vget_low_s8 (x1_0), vget_low_s8 (y1_0));
3616
+ const int16x8_t p1_1 = vmull_s8(vget_high_s8(x1_0), vget_high_s8(y1_0));
3617
+ const int16x8_t p1_2 = vmull_s8(vget_low_s8 (x1_1), vget_low_s8 (y1_1));
3618
+ const int16x8_t p1_3 = vmull_s8(vget_high_s8(x1_1), vget_high_s8(y1_1));
3619
+
3620
+ const int32x4_t p0 = vaddq_s32(vpaddlq_s16(p0_0), vpaddlq_s16(p0_1));
3621
+ const int32x4_t p1 = vaddq_s32(vpaddlq_s16(p0_2), vpaddlq_s16(p0_3));
3622
+ const int32x4_t p2 = vaddq_s32(vpaddlq_s16(p1_0), vpaddlq_s16(p1_1));
3623
+ const int32x4_t p3 = vaddq_s32(vpaddlq_s16(p1_2), vpaddlq_s16(p1_3));
3624
+
3625
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(p0, p1)), x0->d*y0->d);
3626
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(p2, p3)), x1->d*y1->d);
2840
3627
  #endif
2841
3628
  }
2842
3629
 
2843
- sumf = sum0 + sum1;
3630
+ *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
2844
3631
  #elif defined(__AVX2__)
2845
3632
  // Initialize accumulator with zeros
2846
3633
  __m256 acc = _mm256_setzero_ps();
2847
3634
 
2848
3635
  // Main loop
2849
3636
  for (int i = 0; i < nb; ++i) {
2850
- /* Compute combined scale for the block */
3637
+ // Compute combined scale for the block
2851
3638
  const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );
2852
-
2853
- __m256i bx = bytesFromNibbles(x[i].qs);
2854
-
2855
- // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
2856
- const __m256i off = _mm256_set1_epi8( 8 );
2857
- bx = _mm256_sub_epi8( bx, off );
2858
-
3639
+ __m256i bx = _mm256_loadu_si256((const __m256i *)x[i].qs);
2859
3640
  __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
2860
3641
 
2861
- // Get absolute values of x vectors
2862
- const __m256i ax = _mm256_sign_epi8(bx, bx);
2863
-
2864
- // Sign the values of the y vectors
2865
- const __m256i sy = _mm256_sign_epi8(by, bx);
3642
+ const __m256 q = mul_sum_i8_pairs_float(bx, by);
2866
3643
 
2867
- // Perform multiplication and create 16-bit values
2868
- const __m256i dot = _mm256_maddubs_epi16(ax, sy);
2869
-
2870
- const __m256i ones = _mm256_set1_epi16(1);
2871
- __m256i xy_q = _mm256_madd_epi16(ones, dot);
2872
-
2873
- /* Convert to vectore of 8 int32_t to 8 floats */
2874
- __m256 q = _mm256_cvtepi32_ps( xy_q );
2875
-
2876
- /* Multiply q with scale and accumulate */
3644
+ // Multiply q with scale and accumulate
2877
3645
  acc = _mm256_fmadd_ps( d, q, acc );
2878
3646
  }
2879
3647
 
2880
- // Return horizontal sum of the acc vector
2881
- __m128 res = _mm256_extractf128_ps( acc, 1 );
2882
- res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
2883
- res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
2884
- res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
2885
-
2886
- sumf = _mm_cvtss_f32( res );
2887
- #elif defined(__AVX__)
2888
- // Initialize accumulator with zeros
2889
- __m256 acc = _mm256_setzero_ps();
2890
-
2891
- // Main loop
2892
- for (int i = 0; i < nb; ++i) {
2893
- // Compute combined scale for the block
2894
- const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );
2895
-
2896
- __m128i i32[2];
2897
- for (int j = 0; j < 2; ++j) {
2898
- // Load 8 bytes, and unpack 4 bit fields into bytes, making 16 bytes
2899
- __m128i bx = bytesFromNibbles( x[i].qs + 8*j );
2900
- __m128i by = _mm_loadu_si128((const __m128i *)(y[i].qs + 16*j));
2901
-
2902
- // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
2903
- const __m128i off = _mm_set1_epi8( 8 );
2904
- bx = _mm_sub_epi8( bx, off );
2905
-
2906
- // Get absolute values of x vectors
2907
- const __m128i ax = _mm_sign_epi8(bx, bx);
2908
-
2909
- // Sign the values of the y vectors
2910
- const __m128i sy = _mm_sign_epi8(by, bx);
2911
-
2912
- // Perform multiplication and create 16-bit values
2913
- const __m128i dot = _mm_maddubs_epi16(ax, sy);
2914
-
2915
- const __m128i ones = _mm_set1_epi16(1);
2916
- i32[j] = _mm_madd_epi16(ones, dot);
2917
- }
2918
-
2919
- // Convert int32_t to float
2920
- __m256 p = _mm256_cvtepi32_ps( _mm256_set_m128i( i32[0], i32[1] ));
2921
- // Apply the scale, and accumulate
2922
- acc = _mm256_add_ps(_mm256_mul_ps( d, p ), acc);
2923
- }
2924
-
2925
- // Return horizontal sum of the acc vector
2926
- __m128 res = _mm256_extractf128_ps( acc, 1 );
2927
- res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
2928
- res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
2929
- res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
2930
-
2931
- sumf = _mm_cvtss_f32( res );
3648
+ *s = hsum_float_8(acc);
2932
3649
  #else
2933
3650
  // scalar
2934
- for (int i = 0; i < nb; i++) {
2935
- const float d0 = x[i].d;
2936
- const float d1 = y[i].d;
3651
+ float sumf = 0.0;
2937
3652
 
2938
- const uint8_t * restrict p0 = x[i].qs;
2939
- const int8_t * restrict p1 = y[i].qs;
3653
+ for (int i = 0; i < nb; i++) {
3654
+ const int8_t * restrict x0 = x[i].qs;
3655
+ const int8_t * restrict y0 = y[i].qs;
2940
3656
 
2941
3657
  int sumi = 0;
2942
- for (int j = 0; j < QK8_0/2; j++) {
2943
- const uint8_t v0 = p0[j];
2944
3658
 
2945
- const int i0 = (int8_t) (v0 & 0xf) - 8;
2946
- const int i1 = (int8_t) (v0 >> 4) - 8;
3659
+ for (int j = 0; j < QK8_0; j++) {
3660
+ const int v0 = x0[j];
3661
+ const int v1 = y0[j];
2947
3662
 
2948
- const int i2 = p1[2*j + 0];
2949
- const int i3 = p1[2*j + 1];
2950
-
2951
- sumi += i0*i2 + i1*i3;
3663
+ sumi += v0*v1;
2952
3664
  }
2953
- sumf += d0*d1*sumi;
3665
+
3666
+ sumf += (x[i].d*y[i].d)*sumi;
2954
3667
  }
2955
- #endif
2956
3668
 
2957
3669
  *s = sumf;
3670
+ #endif
2958
3671
  }
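In the NEON branch of the new q8_0 x q8_0 kernel, vdotq_s32 folds every group of four int8 products into one int32 lane, which is why two calls per block cover all 32 elements; the #else branch rebuilds the same sums with widening multiplies (vmull_s8) and pairwise adds. A scalar model of a single vdotq_s32 step, for reference only:

    #include <stdint.h>
    #include <stdio.h>

    // Scalar model of vdotq_s32(acc, a, b): each of the 4 output lanes gains the
    // dot product of the corresponding 4 int8 lanes of a and b.
    static void sdot_q_s32(int32_t acc[4], const int8_t a[16], const int8_t b[16]) {
        for (int lane = 0; lane < 4; ++lane) {
            int32_t sum = 0;
            for (int k = 0; k < 4; ++k) {
                sum += (int32_t)a[4*lane + k] * (int32_t)b[4*lane + k];
            }
            acc[lane] += sum;
        }
    }

    int main(void) {
        int32_t acc[4] = { 0, 0, 0, 0 };
        int8_t  a[16], b[16];
        for (int i = 0; i < 16; ++i) {
            a[i] = (int8_t)(i - 8);
            b[i] = (int8_t)(2*i - 15);
        }
        sdot_q_s32(acc, a, b);
        printf("%d %d %d %d\n", acc[0], acc[1], acc[2], acc[3]);
        return 0;
    }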
2959
3672
 
2960
3673
  // compute GGML_VEC_DOT_UNROLL dot products at once
@@ -3153,6 +3866,14 @@ inline static void ggml_vec_sum_f32(const int n, float * s, const float * x) {
3153
3866
  #endif
3154
3867
  }
3155
3868
 
3869
+ inline static void ggml_vec_sum_ggf(const int n, ggml_float * s, const float * x) {
3870
+ ggml_float sum = 0.0;
3871
+ for (int i = 0; i < n; ++i) {
3872
+ sum += (ggml_float)x[i];
3873
+ }
3874
+ *s = sum;
3875
+ }
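The new ggml_vec_sum_ggf deliberately accumulates in ggml_float (a wider type than float in typical builds) and only stores the result at the end. A small standalone illustration of why the wider accumulator matters on long sums (plain C, not ggml code):

    #include <stdio.h>

    int main(void) {
        const int n = 10 * 1000 * 1000;
        float  sum_f32 = 0.0f;
        double sum_f64 = 0.0;
        for (int i = 0; i < n; ++i) {
            // adding a small value to a large float accumulator loses low bits
            sum_f32 += 0.1f;
            sum_f64 += (double)0.1f;
        }
        // the float accumulator drifts visibly away from n * 0.1
        printf("f32 accumulator: %f\n", sum_f32);
        printf("f64 accumulator: %f\n", sum_f64);
        return 0;
    }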
3876
+
3156
3877
  inline static void ggml_vec_max_f32(const int n, float * s, const float * x) {
3157
3878
  #ifndef GGML_USE_ACCELERATE
3158
3879
  float max = -INFINITY;
@@ -3203,24 +3924,34 @@ static const int GGML_BLCK_SIZE[GGML_TYPE_COUNT] = {
3203
3924
  [GGML_TYPE_F16] = 1,
3204
3925
  [GGML_TYPE_Q4_0] = QK4_0,
3205
3926
  [GGML_TYPE_Q4_1] = QK4_1,
3927
+ [GGML_TYPE_Q4_2] = QK4_2,
3928
+ [GGML_TYPE_Q4_3] = QK4_3,
3929
+ [GGML_TYPE_Q5_0] = QK5_0,
3930
+ [GGML_TYPE_Q5_1] = QK5_1,
3206
3931
  [GGML_TYPE_Q8_0] = QK8_0,
3932
+ [GGML_TYPE_Q8_1] = QK8_1,
3207
3933
  [GGML_TYPE_I8] = 1,
3208
3934
  [GGML_TYPE_I16] = 1,
3209
3935
  [GGML_TYPE_I32] = 1,
3210
3936
  };
3211
- static_assert(GGML_TYPE_COUNT == 8, "GGML_BLCK_SIZE is outdated");
3937
+ static_assert(GGML_TYPE_COUNT == 13, "GGML_BLCK_SIZE is outdated");
3212
3938
 
3213
3939
  static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = {
3214
3940
  [GGML_TYPE_F32] = sizeof(float),
3215
3941
  [GGML_TYPE_F16] = sizeof(ggml_fp16_t),
3216
3942
  [GGML_TYPE_Q4_0] = sizeof(block_q4_0),
3217
3943
  [GGML_TYPE_Q4_1] = sizeof(block_q4_1),
3944
+ [GGML_TYPE_Q4_2] = sizeof(block_q4_2),
3945
+ [GGML_TYPE_Q4_3] = sizeof(block_q4_3),
3946
+ [GGML_TYPE_Q5_0] = sizeof(block_q5_0),
3947
+ [GGML_TYPE_Q5_1] = sizeof(block_q5_1),
3218
3948
  [GGML_TYPE_Q8_0] = sizeof(block_q8_0),
3949
+ [GGML_TYPE_Q8_1] = sizeof(block_q8_1),
3219
3950
  [GGML_TYPE_I8] = sizeof(int8_t),
3220
3951
  [GGML_TYPE_I16] = sizeof(int16_t),
3221
3952
  [GGML_TYPE_I32] = sizeof(int32_t),
3222
3953
  };
3223
- static_assert(GGML_TYPE_COUNT == 8, "GGML_TYPE_SIZE is outdated");
3954
+ static_assert(GGML_TYPE_COUNT == 13, "GGML_TYPE_SIZE is outdated");
3224
3955
 
3225
3956
 
3226
3957
  static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = {
@@ -3228,12 +3959,34 @@ static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = {
3228
3959
  [GGML_TYPE_F16] = "f16",
3229
3960
  [GGML_TYPE_Q4_0] = "q4_0",
3230
3961
  [GGML_TYPE_Q4_1] = "q4_1",
3962
+ [GGML_TYPE_Q4_2] = "q4_2",
3963
+ [GGML_TYPE_Q4_3] = "q4_3",
3964
+ [GGML_TYPE_Q5_0] = "q5_0",
3965
+ [GGML_TYPE_Q5_1] = "q5_1",
3231
3966
  [GGML_TYPE_Q8_0] = "q8_0",
3967
+ [GGML_TYPE_Q8_1] = "q8_1",
3232
3968
  [GGML_TYPE_I8] = "i8",
3233
3969
  [GGML_TYPE_I16] = "i16",
3234
3970
  [GGML_TYPE_I32] = "i32",
3235
3971
  };
3236
- static_assert(GGML_TYPE_COUNT == 8, "GGML_TYPE_NAME is outdated");
3972
+ static_assert(GGML_TYPE_COUNT == 13, "GGML_TYPE_NAME is outdated");
3973
+
3974
+ static bool GGML_IS_QUANTIZED[GGML_TYPE_COUNT] = {
3975
+ [GGML_TYPE_F32] = false,
3976
+ [GGML_TYPE_F16] = false,
3977
+ [GGML_TYPE_Q4_0] = true,
3978
+ [GGML_TYPE_Q4_1] = true,
3979
+ [GGML_TYPE_Q4_2] = true,
3980
+ [GGML_TYPE_Q4_3] = true,
3981
+ [GGML_TYPE_Q5_0] = true,
3982
+ [GGML_TYPE_Q5_1] = true,
3983
+ [GGML_TYPE_Q8_0] = true,
3984
+ [GGML_TYPE_Q8_1] = true,
3985
+ [GGML_TYPE_I8] = false,
3986
+ [GGML_TYPE_I16] = false,
3987
+ [GGML_TYPE_I32] = false,
3988
+ };
3989
+ static_assert(GGML_TYPE_COUNT == 13, "GGML_IS_QUANTIZED is outdated");
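These per-type tables are what the rest of the file uses to size quantized buffers: a row of n elements occupies n / GGML_BLCK_SIZE[type] blocks of GGML_TYPE_SIZE[type] bytes, the same formula that reappears below as the row_size of the vec_dot_type scratch rows. A minimal sketch of that arithmetic with stand-in constants (illustrative values, not the real tables):

    #include <stddef.h>
    #include <stdio.h>

    // Stand-in values for illustration only; the real tables are indexed by enum ggml_type.
    #define EXAMPLE_BLCK_SIZE 32u  // elements per quantized block
    #define EXAMPLE_TYPE_SIZE 24u  // bytes per block (scale + min + packed quants)

    static size_t example_row_size(size_t n_elements) {
        // n_elements must be a multiple of the block size for quantized types
        return (n_elements / EXAMPLE_BLCK_SIZE) * EXAMPLE_TYPE_SIZE;
    }

    int main(void) {
        printf("row of 4096 elements: %zu bytes\n", example_row_size(4096)); // 128 blocks * 24 bytes
        return 0;
    }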
3237
3990
 
3238
3991
  static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
3239
3992
  "NONE",
@@ -3495,6 +4248,10 @@ static inline bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct
3495
4248
  (t0->ne[3] == t1->ne[3]);
3496
4249
  }
3497
4250
 
4251
+ bool ggml_is_quantized(enum ggml_type type) {
4252
+ return GGML_IS_QUANTIZED[type];
4253
+ }
4254
+
3498
4255
  static inline bool ggml_is_transposed(const struct ggml_tensor * tensor) {
3499
4256
  return tensor->nb[0] > tensor->nb[1];
3500
4257
  }
@@ -3605,6 +4362,13 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
3605
4362
  GGML_PRINT_DEBUG("%s: g_state initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f);
3606
4363
  }
3607
4364
 
4365
+ // initialize cuBLAS
4366
+ #if defined(GGML_USE_CUBLAS)
4367
+ ggml_init_cublas();
4368
+ #elif defined(GGML_USE_CLBLAST)
4369
+ ggml_cl_init();
4370
+ #endif
4371
+
3608
4372
  is_first_call = false;
3609
4373
  }
3610
4374
 
@@ -5535,7 +6299,6 @@ static void ggml_compute_forward_dup_f16(
5535
6299
  const struct ggml_compute_params * params,
5536
6300
  const struct ggml_tensor * src0,
5537
6301
  struct ggml_tensor * dst) {
5538
- GGML_ASSERT(params->ith == 0);
5539
6302
  GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
5540
6303
 
5541
6304
  if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
@@ -5547,6 +6310,11 @@ static void ggml_compute_forward_dup_f16(
5547
6310
  const int64_t ne02 = src0->ne[2];
5548
6311
  const int64_t ne03 = src0->ne[3];
5549
6312
 
6313
+ const int64_t ne0 = dst->ne[0];
6314
+ const int64_t ne1 = dst->ne[1];
6315
+ const int64_t ne2 = dst->ne[2];
6316
+ const int64_t ne3 = dst->ne[3];
6317
+
5550
6318
  const size_t nb00 = src0->nb[0];
5551
6319
  const size_t nb01 = src0->nb[1];
5552
6320
  const size_t nb02 = src0->nb[2];
@@ -5557,19 +6325,40 @@ static void ggml_compute_forward_dup_f16(
5557
6325
  const size_t nb2 = dst->nb[2];
5558
6326
  const size_t nb3 = dst->nb[3];
5559
6327
 
6328
+ const int ith = params->ith; // thread index
6329
+ const int nth = params->nth; // number of threads
6330
+
5560
6331
  if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst) && src0->type == dst->type) {
5561
- memcpy(dst->data, src0->data, ggml_nelements(dst) * GGML_TYPE_SIZE[src0->type]);
6332
+ // parallelize by elements
6333
+ const int ne = ggml_nelements(dst);
6334
+ const int dr = (ne + nth - 1) / nth;
6335
+ const int ie0 = dr * ith;
6336
+ const int ie1 = MIN(ie0 + dr, ne);
6337
+
6338
+ memcpy(
6339
+ ((char *) dst->data + ie0*nb0),
6340
+ ((char *) src0->data + ie0*nb00),
6341
+ (ie1 - ie0) * GGML_TYPE_SIZE[src0->type]);
6342
+
5562
6343
  return;
5563
6344
  }
5564
6345
 
6346
+ // parallelize by rows
6347
+ const int nr = ne01;
6348
+ // number of rows per thread
6349
+ const int dr = (nr + nth - 1) / nth;
6350
+ // row range for this thread
6351
+ const int ir0 = dr * ith;
6352
+ const int ir1 = MIN(ir0 + dr, nr);
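The dup kernels now run on all threads; the lines above carve the ne01 rows into contiguous, non-overlapping slices with a ceil-divide. A standalone sketch of how the slices cover every row exactly once (toy numbers):

    #include <stdio.h>

    #define MIN(a, b) ((a) < (b) ? (a) : (b))

    int main(void) {
        const int nr  = 10; // total rows
        const int nth = 4;  // threads
        const int dr  = (nr + nth - 1) / nth; // rows per thread, rounded up

        for (int ith = 0; ith < nth; ++ith) {
            const int ir0 = dr * ith;
            const int ir1 = MIN(ir0 + dr, nr); // the last thread may get a short slice
            printf("thread %d: rows [%d, %d)\n", ith, ir0, ir1);
        }
        return 0;
    }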
6353
+
5565
6354
  if (src0->type == dst->type &&
5566
- src0->ne[0] == dst->ne[0] &&
5567
- src0->nb[0] == GGML_TYPE_SIZE[src0->type] && dst->nb[0] == GGML_TYPE_SIZE[dst->type]) {
6355
+ ne00 == ne0 &&
6356
+ nb00 == GGML_TYPE_SIZE[src0->type] && nb0 == GGML_TYPE_SIZE[dst->type]) {
5568
6357
  // copy by rows
5569
6358
  const size_t rs = ne00*nb00;
5570
6359
  for (int64_t i03 = 0; i03 < ne03; i03++) {
5571
6360
  for (int64_t i02 = 0; i02 < ne02; i02++) {
5572
- for (int64_t i01 = 0; i01 < ne01; i01++) {
6361
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
5573
6362
  memcpy(
5574
6363
  ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3),
5575
6364
  ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
@@ -5583,21 +6372,21 @@ static void ggml_compute_forward_dup_f16(
5583
6372
  // TODO: add more special-case implementations for tensor shapes/strides that can benefit from memcpy
5584
6373
 
5585
6374
  if (ggml_is_contiguous(dst)) {
5586
- if (src0->nb[0] == sizeof(ggml_fp16_t)) {
6375
+ if (nb00 == sizeof(ggml_fp16_t)) {
5587
6376
  if (dst->type == GGML_TYPE_F16) {
5588
6377
  size_t id = 0;
5589
- const size_t rs = ne00*nb00;
6378
+ const size_t rs = ne00 * nb00;
6379
+ char * dst_ptr = (char *) dst->data;
5590
6380
 
5591
6381
  for (int i03 = 0; i03 < ne03; i03++) {
5592
6382
  for (int i02 = 0; i02 < ne02; i02++) {
5593
- for (int i01 = 0; i01 < ne01; i01++) {
6383
+ id += rs * ir0;
6384
+ for (int i01 = ir0; i01 < ir1; i01++) {
5594
6385
  const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
5595
- char * dst_ptr = (char *) dst->data + id*rs;
5596
-
5597
- memcpy(dst_ptr, src0_ptr, rs);
5598
-
5599
- id++;
6386
+ memcpy(dst_ptr + id, src0_ptr, rs);
6387
+ id += rs;
5600
6388
  }
6389
+ id += rs * (ne01 - ir1);
5601
6390
  }
5602
6391
  }
5603
6392
  } else if (dst->type == GGML_TYPE_F32) {
@@ -5606,34 +6395,39 @@ static void ggml_compute_forward_dup_f16(
5606
6395
 
5607
6396
  for (int i03 = 0; i03 < ne03; i03++) {
5608
6397
  for (int i02 = 0; i02 < ne02; i02++) {
5609
- for (int i01 = 0; i01 < ne01; i01++) {
6398
+ id += ne00 * ir0;
6399
+ for (int i01 = ir0; i01 < ir1; i01++) {
6400
+ const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
5610
6401
  for (int i00 = 0; i00 < ne00; i00++) {
5611
- const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
5612
-
5613
- dst_ptr[id] = GGML_FP16_TO_FP32(*src0_ptr);
6402
+ dst_ptr[id] = GGML_FP16_TO_FP32(src0_ptr[i00]);
5614
6403
  id++;
5615
6404
  }
5616
6405
  }
6406
+ id += ne00 * (ne01 - ir1);
5617
6407
  }
5618
6408
  }
5619
- } else if (dst->type == GGML_TYPE_Q4_0 || dst->type == GGML_TYPE_Q4_1) {
6409
+ } else if (ggml_is_quantized(dst->type)) {
5620
6410
  quantize_row_q_t const quantize_row_q = quantize_fns[dst->type].quantize_row_q;
6411
+ float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
6412
+
5621
6413
  size_t id = 0;
5622
- uint8_t * dst_ptr = (uint8_t *) dst->data;
5623
- size_t dst_row_size = nb0 * (ne00 / GGML_BLCK_SIZE[dst->type]);
5624
- float * src0_f32 = (float *) params->wdata;
6414
+ size_t rs = nb0 * (ne00 / GGML_BLCK_SIZE[dst->type]);
6415
+ char * dst_ptr = (char *) dst->data;
5625
6416
 
5626
6417
  for (int i03 = 0; i03 < ne03; i03++) {
5627
6418
  for (int i02 = 0; i02 < ne02; i02++) {
5628
- for (int i01 = 0; i01 < ne01; i01++) {
6419
+ id += rs * ir0;
6420
+ for (int i01 = ir0; i01 < ir1; i01++) {
5629
6421
  const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
5630
- // convert to f32 and quantize
6422
+
5631
6423
  for (int i00 = 0; i00 < ne00; i00++) {
5632
6424
  src0_f32[i00] = GGML_FP16_TO_FP32(src0_ptr[i00]);
5633
6425
  }
6426
+
5634
6427
  quantize_row_q(src0_f32, dst_ptr + id, ne00);
5635
- id += dst_row_size;
6428
+ id += rs;
5636
6429
  }
6430
+ id += rs * (ne01 - ir1);
5637
6431
  }
5638
6432
  }
5639
6433
  } else {
@@ -5648,7 +6442,8 @@ static void ggml_compute_forward_dup_f16(
5648
6442
 
5649
6443
  for (int i03 = 0; i03 < ne03; i03++) {
5650
6444
  for (int i02 = 0; i02 < ne02; i02++) {
5651
- for (int i01 = 0; i01 < ne01; i01++) {
6445
+ id += ne00 * ir0;
6446
+ for (int i01 = ir0; i01 < ir1; i01++) {
5652
6447
  for (int i00 = 0; i00 < ne00; i00++) {
5653
6448
  const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
5654
6449
 
@@ -5656,6 +6451,7 @@ static void ggml_compute_forward_dup_f16(
5656
6451
  id++;
5657
6452
  }
5658
6453
  }
6454
+ id += ne00 * (ne01 - ir1);
5659
6455
  }
5660
6456
  }
5661
6457
  } else if (dst->type == GGML_TYPE_F16) {
@@ -5664,7 +6460,8 @@ static void ggml_compute_forward_dup_f16(
5664
6460
 
5665
6461
  for (int i03 = 0; i03 < ne03; i03++) {
5666
6462
  for (int i02 = 0; i02 < ne02; i02++) {
5667
- for (int i01 = 0; i01 < ne01; i01++) {
6463
+ id += ne00 * ir0;
6464
+ for (int i01 = ir0; i01 < ir1; i01++) {
5668
6465
  for (int i00 = 0; i00 < ne00; i00++) {
5669
6466
  const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
5670
6467
 
@@ -5672,6 +6469,7 @@ static void ggml_compute_forward_dup_f16(
5672
6469
  id++;
5673
6470
  }
5674
6471
  }
6472
+ id += ne00 * (ne01 - ir1);
5675
6473
  }
5676
6474
  }
5677
6475
  } else {
@@ -5690,7 +6488,20 @@ static void ggml_compute_forward_dup_f16(
5690
6488
  if (dst->type == GGML_TYPE_F16) {
5691
6489
  for (int64_t i03 = 0; i03 < ne03; i03++) {
5692
6490
  for (int64_t i02 = 0; i02 < ne02; i02++) {
5693
- for (int64_t i01 = 0; i01 < ne01; i01++) {
6491
+ i10 += ne00 * ir0;
6492
+ while (i10 >= ne0) {
6493
+ i10 -= ne0;
6494
+ if (++i11 == ne1) {
6495
+ i11 = 0;
6496
+ if (++i12 == ne2) {
6497
+ i12 = 0;
6498
+ if (++i13 == ne3) {
6499
+ i13 = 0;
6500
+ }
6501
+ }
6502
+ }
6503
+ }
6504
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
5694
6505
  for (int64_t i00 = 0; i00 < ne00; i00++) {
5695
6506
  const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
5696
6507
  char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
@@ -5711,25 +6522,51 @@ static void ggml_compute_forward_dup_f16(
5711
6522
  }
5712
6523
  }
5713
6524
  }
6525
+ i10 += ne00 * (ne01 - ir1);
6526
+ while (i10 >= ne0) {
6527
+ i10 -= ne0;
6528
+ if (++i11 == ne1) {
6529
+ i11 = 0;
6530
+ if (++i12 == ne2) {
6531
+ i12 = 0;
6532
+ if (++i13 == ne3) {
6533
+ i13 = 0;
6534
+ }
6535
+ }
6536
+ }
6537
+ }
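The repeated while blocks advance the dst write counters past the rows handled by other threads; they are a manual carry across the four dst dimensions, equivalent to converting a flat element offset into (i10, i11, i12, i13). A minimal model of the same carry with made-up dimensions:

    #include <stdio.h>

    int main(void) {
        const int ne0 = 4, ne1 = 3, ne2 = 2, ne3 = 2; // dst dimensions (made up)
        int i0 = 0, i1 = 0, i2 = 0, i3 = 0;

        // skip ahead by 10 elements, carrying into the higher dimensions
        i0 += 10;
        while (i0 >= ne0) {
            i0 -= ne0;
            if (++i1 == ne1) {
                i1 = 0;
                if (++i2 == ne2) {
                    i2 = 0;
                    if (++i3 == ne3) {
                        i3 = 0;
                    }
                }
            }
        }
        printf("counters after skip: %d %d %d %d\n", i0, i1, i2, i3); // 2 2 0 0
        return 0;
    }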
5714
6538
  }
5715
6539
  }
5716
6540
  } else if (dst->type == GGML_TYPE_F32) {
5717
6541
  for (int64_t i03 = 0; i03 < ne03; i03++) {
5718
6542
  for (int64_t i02 = 0; i02 < ne02; i02++) {
5719
- for (int64_t i01 = 0; i01 < ne01; i01++) {
6543
+ i10 += ne00 * ir0;
6544
+ while (i10 >= ne0) {
6545
+ i10 -= ne0;
6546
+ if (++i11 == ne1) {
6547
+ i11 = 0;
6548
+ if (++i12 == ne2) {
6549
+ i12 = 0;
6550
+ if (++i13 == ne3) {
6551
+ i13 = 0;
6552
+ }
6553
+ }
6554
+ }
6555
+ }
6556
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
5720
6557
  for (int64_t i00 = 0; i00 < ne00; i00++) {
5721
6558
  const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
5722
6559
  char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
5723
6560
 
5724
6561
  *(float *) dst_ptr = GGML_FP16_TO_FP32(*(const ggml_fp16_t *) src0_ptr);
5725
6562
 
5726
- if (++i10 == ne00) {
6563
+ if (++i10 == ne0) {
5727
6564
  i10 = 0;
5728
- if (++i11 == ne01) {
6565
+ if (++i11 == ne1) {
5729
6566
  i11 = 0;
5730
- if (++i12 == ne02) {
6567
+ if (++i12 == ne2) {
5731
6568
  i12 = 0;
5732
- if (++i13 == ne03) {
6569
+ if (++i13 == ne3) {
5733
6570
  i13 = 0;
5734
6571
  }
5735
6572
  }
@@ -5737,6 +6574,19 @@ static void ggml_compute_forward_dup_f16(
5737
6574
  }
5738
6575
  }
5739
6576
  }
6577
+ i10 += ne00 * (ne01 - ir1);
6578
+ while (i10 >= ne0) {
6579
+ i10 -= ne0;
6580
+ if (++i11 == ne1) {
6581
+ i11 = 0;
6582
+ if (++i12 == ne2) {
6583
+ i12 = 0;
6584
+ if (++i13 == ne3) {
6585
+ i13 = 0;
6586
+ }
6587
+ }
6588
+ }
6589
+ }
5740
6590
  }
5741
6591
  }
5742
6592
  } else {
@@ -5748,7 +6598,6 @@ static void ggml_compute_forward_dup_f32(
5748
6598
  const struct ggml_compute_params * params,
5749
6599
  const struct ggml_tensor * src0,
5750
6600
  struct ggml_tensor * dst) {
5751
- GGML_ASSERT(params->ith == 0);
5752
6601
  GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
5753
6602
 
5754
6603
  if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
@@ -5760,6 +6609,11 @@ static void ggml_compute_forward_dup_f32(
5760
6609
  const int64_t ne02 = src0->ne[2];
5761
6610
  const int64_t ne03 = src0->ne[3];
5762
6611
 
6612
+ const int64_t ne0 = dst->ne[0];
6613
+ const int64_t ne1 = dst->ne[1];
6614
+ const int64_t ne2 = dst->ne[2];
6615
+ const int64_t ne3 = dst->ne[3];
6616
+
5763
6617
  const size_t nb00 = src0->nb[0];
5764
6618
  const size_t nb01 = src0->nb[1];
5765
6619
  const size_t nb02 = src0->nb[2];
@@ -5770,19 +6624,40 @@ static void ggml_compute_forward_dup_f32(
5770
6624
  const size_t nb2 = dst->nb[2];
5771
6625
  const size_t nb3 = dst->nb[3];
5772
6626
 
6627
+ const int ith = params->ith; // thread index
6628
+ const int nth = params->nth; // number of threads
6629
+
5773
6630
  if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst) && src0->type == dst->type) {
5774
- memcpy(dst->data, src0->data, ggml_nelements(dst) * GGML_TYPE_SIZE[src0->type]);
6631
+ // parallelize by elements
6632
+ const int ne = ggml_nelements(dst);
6633
+ const int dr = (ne + nth - 1) / nth;
6634
+ const int ie0 = dr * ith;
6635
+ const int ie1 = MIN(ie0 + dr, ne);
6636
+
6637
+ memcpy(
6638
+ ((char *) dst->data + ie0*nb0),
6639
+ ((char *) src0->data + ie0*nb00),
6640
+ (ie1 - ie0) * GGML_TYPE_SIZE[src0->type]);
6641
+
5775
6642
  return;
5776
6643
  }
5777
6644
 
6645
+ // parallelize by rows
6646
+ const int nr = ne01;
6647
+ // number of rows per thread
6648
+ const int dr = (nr + nth - 1) / nth;
6649
+ // row range for this thread
6650
+ const int ir0 = dr * ith;
6651
+ const int ir1 = MIN(ir0 + dr, nr);
6652
+
5778
6653
  if (src0->type == dst->type &&
5779
- src0->ne[0] == dst->ne[0] &&
5780
- src0->nb[0] == GGML_TYPE_SIZE[src0->type] && dst->nb[0] == GGML_TYPE_SIZE[dst->type]) {
6654
+ ne00 == ne0 &&
6655
+ nb00 == GGML_TYPE_SIZE[src0->type] && nb0 == GGML_TYPE_SIZE[dst->type]) {
5781
6656
  // copy by rows
5782
6657
  const size_t rs = ne00*nb00;
5783
6658
  for (int64_t i03 = 0; i03 < ne03; i03++) {
5784
6659
  for (int64_t i02 = 0; i02 < ne02; i02++) {
5785
- for (int64_t i01 = 0; i01 < ne01; i01++) {
6660
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
5786
6661
  memcpy(
5787
6662
  ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3),
5788
6663
  ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
@@ -5795,21 +6670,21 @@ static void ggml_compute_forward_dup_f32(
5795
6670
 
5796
6671
  if (ggml_is_contiguous(dst)) {
5797
6672
  // TODO: simplify
5798
- if (src0->nb[0] == sizeof(float)) {
6673
+ if (nb00 == sizeof(float)) {
5799
6674
  if (dst->type == GGML_TYPE_F32) {
5800
6675
  size_t id = 0;
5801
- const size_t rs = ne00*nb00;
6676
+ const size_t rs = ne00 * nb00;
6677
+ char * dst_ptr = (char *) dst->data;
5802
6678
 
5803
6679
  for (int i03 = 0; i03 < ne03; i03++) {
5804
6680
  for (int i02 = 0; i02 < ne02; i02++) {
5805
- for (int i01 = 0; i01 < ne01; i01++) {
6681
+ id += rs * ir0;
6682
+ for (int i01 = ir0; i01 < ir1; i01++) {
5806
6683
  const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
5807
- char * dst_ptr = (char *) dst->data + id*rs;
5808
-
5809
- memcpy(dst_ptr, src0_ptr, rs);
5810
-
5811
- id++;
6684
+ memcpy(dst_ptr + id, src0_ptr, rs);
6685
+ id += rs;
5812
6686
  }
6687
+ id += rs * (ne01 - ir1);
5813
6688
  }
5814
6689
  }
5815
6690
  } else if (dst->type == GGML_TYPE_F16) {
@@ -5818,7 +6693,8 @@ static void ggml_compute_forward_dup_f32(
5818
6693
 
5819
6694
  for (int i03 = 0; i03 < ne03; i03++) {
5820
6695
  for (int i02 = 0; i02 < ne02; i02++) {
5821
- for (int i01 = 0; i01 < ne01; i01++) {
6696
+ id += ne00 * ir0;
6697
+ for (int i01 = ir0; i01 < ir1; i01++) {
5822
6698
  for (int i00 = 0; i00 < ne00; i00++) {
5823
6699
  const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
5824
6700
 
@@ -5826,21 +6702,25 @@ static void ggml_compute_forward_dup_f32(
5826
6702
  id++;
5827
6703
  }
5828
6704
  }
6705
+ id += ne00 * (ne01 - ir1);
5829
6706
  }
5830
6707
  }
5831
- } else if (dst->type == GGML_TYPE_Q4_0 || dst->type == GGML_TYPE_Q4_1) {
6708
+ } else if (ggml_is_quantized(dst->type)) {
5832
6709
  quantize_row_q_t const quantize_row_q = quantize_fns[dst->type].quantize_row_q;
6710
+
5833
6711
  size_t id = 0;
5834
- uint8_t * dst_ptr = (uint8_t *) dst->data;
5835
- size_t dst_row_size = nb0 * (ne00 / GGML_BLCK_SIZE[dst->type]);
6712
+ size_t rs = nb0 * (ne00 / GGML_BLCK_SIZE[dst->type]);
6713
+ char * dst_ptr = (char *) dst->data;
5836
6714
 
5837
6715
  for (int i03 = 0; i03 < ne03; i03++) {
5838
6716
  for (int i02 = 0; i02 < ne02; i02++) {
5839
- for (int i01 = 0; i01 < ne01; i01++) {
6717
+ id += rs * ir0;
6718
+ for (int i01 = ir0; i01 < ir1; i01++) {
5840
6719
  const float * src0_ptr = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
5841
6720
  quantize_row_q(src0_ptr, dst_ptr + id, ne00);
5842
- id += dst_row_size;
6721
+ id += rs;
5843
6722
  }
6723
+ id += rs * (ne01 - ir1);
5844
6724
  }
5845
6725
  }
5846
6726
  } else {
@@ -5855,7 +6735,8 @@ static void ggml_compute_forward_dup_f32(
5855
6735
 
5856
6736
  for (int i03 = 0; i03 < ne03; i03++) {
5857
6737
  for (int i02 = 0; i02 < ne02; i02++) {
5858
- for (int i01 = 0; i01 < ne01; i01++) {
6738
+ id += ne00 * ir0;
6739
+ for (int i01 = ir0; i01 < ir1; i01++) {
5859
6740
  for (int i00 = 0; i00 < ne00; i00++) {
5860
6741
  const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
5861
6742
 
@@ -5863,6 +6744,7 @@ static void ggml_compute_forward_dup_f32(
5863
6744
  id++;
5864
6745
  }
5865
6746
  }
6747
+ id += ne00 * (ne01 - ir1);
5866
6748
  }
5867
6749
  }
5868
6750
  } else if (dst->type == GGML_TYPE_F16) {
@@ -5871,7 +6753,8 @@ static void ggml_compute_forward_dup_f32(
5871
6753
 
5872
6754
  for (int i03 = 0; i03 < ne03; i03++) {
5873
6755
  for (int i02 = 0; i02 < ne02; i02++) {
5874
- for (int i01 = 0; i01 < ne01; i01++) {
6756
+ id += ne00 * ir0;
6757
+ for (int i01 = ir0; i01 < ir1; i01++) {
5875
6758
  for (int i00 = 0; i00 < ne00; i00++) {
5876
6759
  const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
5877
6760
 
@@ -5879,6 +6762,7 @@ static void ggml_compute_forward_dup_f32(
5879
6762
  id++;
5880
6763
  }
5881
6764
  }
6765
+ id += ne00 * (ne01 - ir1);
5882
6766
  }
5883
6767
  }
5884
6768
  } else {
@@ -5890,6 +6774,7 @@ static void ggml_compute_forward_dup_f32(
5890
6774
  }
5891
6775
 
5892
6776
  // dst counters
6777
+
5893
6778
  int64_t i10 = 0;
5894
6779
  int64_t i11 = 0;
5895
6780
  int64_t i12 = 0;
@@ -5898,20 +6783,33 @@ static void ggml_compute_forward_dup_f32(
5898
6783
  if (dst->type == GGML_TYPE_F32) {
5899
6784
  for (int64_t i03 = 0; i03 < ne03; i03++) {
5900
6785
  for (int64_t i02 = 0; i02 < ne02; i02++) {
5901
- for (int64_t i01 = 0; i01 < ne01; i01++) {
6786
+ i10 += ne00 * ir0;
6787
+ while (i10 >= ne0) {
6788
+ i10 -= ne0;
6789
+ if (++i11 == ne1) {
6790
+ i11 = 0;
6791
+ if (++i12 == ne2) {
6792
+ i12 = 0;
6793
+ if (++i13 == ne3) {
6794
+ i13 = 0;
6795
+ }
6796
+ }
6797
+ }
6798
+ }
6799
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
5902
6800
  for (int64_t i00 = 0; i00 < ne00; i00++) {
5903
6801
  const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
5904
6802
  char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
5905
6803
 
5906
6804
  memcpy(dst_ptr, src0_ptr, sizeof(float));
5907
6805
 
5908
- if (++i10 == dst->ne[0]) {
6806
+ if (++i10 == ne0) {
5909
6807
  i10 = 0;
5910
- if (++i11 == dst->ne[1]) {
6808
+ if (++i11 == ne1) {
5911
6809
  i11 = 0;
5912
- if (++i12 == dst->ne[2]) {
6810
+ if (++i12 == ne2) {
5913
6811
  i12 = 0;
5914
- if (++i13 == dst->ne[3]) {
6812
+ if (++i13 == ne3) {
5915
6813
  i13 = 0;
5916
6814
  }
5917
6815
  }
@@ -5919,25 +6817,51 @@ static void ggml_compute_forward_dup_f32(
5919
6817
  }
5920
6818
  }
5921
6819
  }
6820
+ i10 += ne00 * (ne01 - ir1);
6821
+ while (i10 >= ne0) {
6822
+ i10 -= ne0;
6823
+ if (++i11 == ne1) {
6824
+ i11 = 0;
6825
+ if (++i12 == ne2) {
6826
+ i12 = 0;
6827
+ if (++i13 == ne3) {
6828
+ i13 = 0;
6829
+ }
6830
+ }
6831
+ }
6832
+ }
5922
6833
  }
5923
6834
  }
5924
6835
  } else if (dst->type == GGML_TYPE_F16) {
5925
6836
  for (int64_t i03 = 0; i03 < ne03; i03++) {
5926
6837
  for (int64_t i02 = 0; i02 < ne02; i02++) {
5927
- for (int64_t i01 = 0; i01 < ne01; i01++) {
6838
+ i10 += ne00 * ir0;
6839
+ while (i10 >= ne0) {
6840
+ i10 -= ne0;
6841
+ if (++i11 == ne1) {
6842
+ i11 = 0;
6843
+ if (++i12 == ne2) {
6844
+ i12 = 0;
6845
+ if (++i13 == ne3) {
6846
+ i13 = 0;
6847
+ }
6848
+ }
6849
+ }
6850
+ }
6851
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
5928
6852
  for (int64_t i00 = 0; i00 < ne00; i00++) {
5929
6853
  const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
5930
6854
  char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
5931
6855
 
5932
6856
  *(ggml_fp16_t *) dst_ptr = GGML_FP32_TO_FP16(*(const float *) src0_ptr);
5933
6857
 
5934
- if (++i10 == dst->ne[0]) {
6858
+ if (++i10 == ne0) {
5935
6859
  i10 = 0;
5936
- if (++i11 == dst->ne[1]) {
6860
+ if (++i11 == ne1) {
5937
6861
  i11 = 0;
5938
- if (++i12 == dst->ne[2]) {
6862
+ if (++i12 == ne2) {
5939
6863
  i12 = 0;
5940
- if (++i13 == dst->ne[3]) {
6864
+ if (++i13 == ne3) {
5941
6865
  i13 = 0;
5942
6866
  }
5943
6867
  }
@@ -5945,6 +6869,19 @@ static void ggml_compute_forward_dup_f32(
5945
6869
  }
5946
6870
  }
5947
6871
  }
6872
+ i10 += ne00 * (ne01 - ir1);
6873
+ while (i10 >= ne0) {
6874
+ i10 -= ne0;
6875
+ if (++i11 == ne1) {
6876
+ i11 = 0;
6877
+ if (++i12 == ne2) {
6878
+ i12 = 0;
6879
+ if (++i13 == ne3) {
6880
+ i13 = 0;
6881
+ }
6882
+ }
6883
+ }
6884
+ }
5948
6885
  }
5949
6886
  }
5950
6887
  } else {
@@ -6191,7 +7128,7 @@ static void ggml_compute_forward_add_q_f32(
6191
7128
  GGML_ASSERT(nb1 <= nb2);
6192
7129
  GGML_ASSERT(nb2 <= nb3);
6193
7130
 
6194
- GGML_ASSERT(src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1);
7131
+ GGML_ASSERT(ggml_is_quantized(src0->type));
6195
7132
  GGML_ASSERT(dst->type == src0->type);
6196
7133
  GGML_ASSERT(src1->type == GGML_TYPE_F32);
6197
7134
 
@@ -6205,7 +7142,7 @@ static void ggml_compute_forward_add_q_f32(
6205
7142
  const int ir0 = dr*ith;
6206
7143
  const int ir1 = MIN(ir0 + dr, nr);
6207
7144
 
6208
- float * wdata = (float*) params->wdata + ne00 * ith;
7145
+ float * wdata = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
6209
7146
 
6210
7147
  for (int ir = ir0; ir < ir1; ++ir) {
6211
7148
  // src0 indices
@@ -6261,6 +7198,11 @@ static void ggml_compute_forward_add(
6261
7198
  } break;
6262
7199
  case GGML_TYPE_Q4_0:
6263
7200
  case GGML_TYPE_Q4_1:
7201
+ case GGML_TYPE_Q4_2:
7202
+ case GGML_TYPE_Q4_3:
7203
+ case GGML_TYPE_Q5_0:
7204
+ case GGML_TYPE_Q5_1:
7205
+ case GGML_TYPE_Q8_0:
6264
7206
  {
6265
7207
  ggml_compute_forward_add_q_f32(params, src0, src1, dst);
6266
7208
  } break;
@@ -6518,15 +7460,20 @@ static void ggml_compute_forward_sum_f32(
6518
7460
  const size_t nb02 = src0->nb[2];
6519
7461
  const size_t nb03 = src0->nb[3];
6520
7462
 
7463
+ ggml_float sum = 0;
7464
+ ggml_float row_sum = 0;
7465
+
6521
7466
  for (int64_t i03 = 0; i03 < ne03; i03++) {
6522
7467
  for (int64_t i02 = 0; i02 < ne02; i02++) {
6523
7468
  for (int64_t i01 = 0; i01 < ne01; i01++) {
6524
- ggml_vec_sum_f32(ne00,
6525
- (float *) (dst->data),
7469
+ ggml_vec_sum_ggf(ne00,
7470
+ &row_sum,
6526
7471
  (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03));
7472
+ sum += row_sum;
6527
7473
  }
6528
7474
  }
6529
7475
  }
7476
+ ((float *) dst->data)[0] = sum;
6530
7477
  }
6531
7478
 
6532
7479
  static void ggml_compute_forward_sum(
@@ -7161,7 +8108,7 @@ static void ggml_compute_forward_rms_norm(
7161
8108
 
7162
8109
  // ggml_compute_forward_mul_mat
7163
8110
 
7164
- #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
8111
+ #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
7165
8112
  // helper function to determine if it is better to use BLAS or not
7166
8113
  // for large matrices, BLAS is faster
7167
8114
  static bool ggml_compute_forward_mul_mat_use_blas(
@@ -7186,6 +8133,7 @@ static bool ggml_compute_forward_mul_mat_use_blas(
7186
8133
 
7187
8134
  return false;
7188
8135
  }
8136
+
7189
8137
  #endif
7190
8138
 
7191
8139
  static void ggml_compute_forward_mul_mat_f32(
@@ -7201,7 +8149,7 @@ static void ggml_compute_forward_mul_mat_f32(
7201
8149
  const int64_t ne02 = src0->ne[2];
7202
8150
  const int64_t ne03 = src0->ne[3];
7203
8151
 
7204
- #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
8152
+ #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
7205
8153
  const int64_t ne10 = src1->ne[0];
7206
8154
  #endif
7207
8155
  const int64_t ne11 = src1->ne[1];
@@ -7258,7 +8206,7 @@ static void ggml_compute_forward_mul_mat_f32(
7258
8206
  // nb01 >= nb00 - src0 is not transposed
7259
8207
  // compute by src0 rows
7260
8208
 
7261
- #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
8209
+ #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
7262
8210
  if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
7263
8211
  if (params->ith != 0) {
7264
8212
  return;
@@ -7272,6 +8220,19 @@ static void ggml_compute_forward_mul_mat_f32(
7272
8220
  return;
7273
8221
  }
7274
8222
 
8223
+ #if defined(GGML_USE_CUBLAS)
8224
+ const float alpha = 1.0f;
8225
+ const float beta = 0.0f;
8226
+ const int x_ne = ne01 * ne10;
8227
+ const int y_ne = ne11 * ne10;
8228
+ const int d_ne = ne11 * ne01;
8229
+
8230
+ size_t x_size, y_size, d_size;
8231
+ float *d_X = ggml_cuda_pool_malloc(sizeof(float) * x_ne, &x_size);
8232
+ float *d_Y = ggml_cuda_pool_malloc(sizeof(float) * y_ne, &y_size);
8233
+ float *d_D = ggml_cuda_pool_malloc(sizeof(float) * d_ne, &d_size);
8234
+ #endif
8235
+
7275
8236
  for (int64_t i03 = 0; i03 < ne03; i03++) {
7276
8237
  for (int64_t i02 = 0; i02 < ne02; i02++) {
7277
8238
  const float * x = (float *) ((char *) src0->data + i02*nb02 + i03*nb03);
@@ -7279,15 +8240,44 @@ static void ggml_compute_forward_mul_mat_f32(
7279
8240
 
7280
8241
  float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
7281
8242
 
8243
+ #if defined(GGML_USE_CUBLAS)
8244
+ // copy data to device
8245
+ CUDA_CHECK(cudaMemcpyAsync(d_X, x, sizeof(float) * x_ne, cudaMemcpyHostToDevice, g_cudaStream));
8246
+ CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(float) * y_ne, cudaMemcpyHostToDevice, g_cudaStream));
8247
+
8248
+ // compute
8249
+ CUBLAS_CHECK(
8250
+ cublasSgemm(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
8251
+ ne01, ne11, ne10,
8252
+ &alpha, d_X, ne00,
8253
+ d_Y, ne10,
8254
+ &beta, d_D, ne01));
8255
+
8256
+ // copy data to host
8257
+ CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, g_cudaStream));
8258
+ #elif defined(GGML_USE_CLBLAST)
7282
8259
  // zT = y * xT
8260
+ ggml_cl_sgemm_wrapper(GGML_BLAS_ORDER_ROW_MAJOR, GGML_BLAS_OP_N, GGML_BLAS_OP_T,
8261
+ ne11, ne01, ne10,
8262
+ 1.0f, y, ne10,
8263
+ x, ne10,
8264
+ 0.0f, d, ne01,
8265
+ GGML_TYPE_F32);
8266
+ #else
7283
8267
  cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
7284
8268
  ne11, ne01, ne10,
7285
8269
  1.0f, y, ne10,
7286
8270
  x, ne00,
7287
8271
  0.0f, d, ne01);
8272
+ #endif
7288
8273
  }
7289
8274
  }
7290
-
8275
+ #if defined(GGML_USE_CUBLAS)
8276
+ CUDA_CHECK(cudaStreamSynchronize(g_cudaStream));
8277
+ ggml_cuda_pool_free(d_X, x_size);
8278
+ ggml_cuda_pool_free(d_Y, y_size);
8279
+ ggml_cuda_pool_free(d_D, d_size);
8280
+ #endif
7291
8281
  //printf("CBLAS F32 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);
7292
8282
 
7293
8283
  return;
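All three branches above compute the same thing, dst = src1 * src0^T, with M = ne11, N = ne01 and K = ne10; the cuBLAS call expresses it through column-major transposes while the cblas/CLBlast calls take the row-major form directly. A naive reference for that call shape, with tiny made-up matrices:

    #include <stdio.h>

    // C (m x n) = A (m x k) * B^T, where B is (n x k); all matrices row-major.
    // This mirrors cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, m, n, k, ...).
    static void sgemm_abt_ref(int m, int n, int k,
                              const float * A, const float * B, float * C) {
        for (int i = 0; i < m; ++i) {
            for (int j = 0; j < n; ++j) {
                float sum = 0.0f;
                for (int l = 0; l < k; ++l) {
                    sum += A[i*k + l] * B[j*k + l]; // row i of A dot row j of B
                }
                C[i*n + j] = sum;
            }
        }
    }

    int main(void) {
        // m = ne11 (src1 rows), n = ne01 (src0 rows), k = ne10 (shared inner dim)
        const float A[2*3] = { 1, 2, 3,  4, 5, 6 }; // 2 x 3
        const float B[2*3] = { 1, 0, 1,  0, 1, 0 }; // 2 x 3, used transposed
        float C[2*2];
        sgemm_abt_ref(2, 2, 3, A, B, C);
        printf("%g %g\n%g %g\n", C[0], C[1], C[2], C[3]); // 4 2 / 10 5
        return 0;
    }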
@@ -7417,7 +8407,7 @@ static void ggml_compute_forward_mul_mat_f16_f32(
7417
8407
  // nb01 >= nb00 - src0 is not transposed
7418
8408
  // compute by src0 rows
7419
8409
 
7420
- #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
8410
+ #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
7421
8411
  if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
7422
8412
  GGML_ASSERT(nb10 == sizeof(float));
7423
8413
 
@@ -7433,10 +8423,35 @@ static void ggml_compute_forward_mul_mat_f16_f32(
7433
8423
  return;
7434
8424
  }
7435
8425
 
7436
- float * const wdata = params->wdata;
8426
+ #if defined(GGML_USE_CUBLAS)
8427
+ ggml_fp16_t * const wdata = params->wdata;
8428
+
8429
+ const float alpha = 1.0f;
8430
+ const float beta = 0.0f;
8431
+ const int x_ne = ne01 * ne10;
8432
+ const int y_ne = ne11 * ne10;
8433
+ const int d_ne = ne11 * ne01;
7437
8434
 
8435
+ size_t x_size, y_size, d_size;
8436
+ float *d_X = ggml_cuda_pool_malloc(sizeof(float) * x_ne, &x_size);
8437
+ float *d_Y = ggml_cuda_pool_malloc(sizeof(float) * y_ne, &y_size);
8438
+ float *d_D = ggml_cuda_pool_malloc(sizeof(float) * d_ne, &d_size);
8439
+ #else
8440
+ float * const wdata = params->wdata;
8441
+ #endif
7438
8442
  for (int64_t i03 = 0; i03 < ne03; i03++) {
7439
8443
  for (int64_t i02 = 0; i02 < ne02; i02++) {
8444
+ #if defined(GGML_USE_CUBLAS)
8445
+ // with cuBlAS, instead of converting src0 to fp32, we convert src1 to fp16
8446
+ {
8447
+ size_t id = 0;
8448
+ for (int64_t i01 = 0; i01 < ne11; ++i01) {
8449
+ for (int64_t i00 = 0; i00 < ne10; ++i00) {
8450
+ wdata[id++] = GGML_FP32_TO_FP16(*(float *) ((char *) src1->data + i03*nb13 + i02*nb12 + i01*nb11 + i00*nb10));
8451
+ }
8452
+ }
8453
+ }
8454
+ #else
7440
8455
  {
7441
8456
  size_t id = 0;
7442
8457
  for (int64_t i01 = 0; i01 < ne01; ++i01) {
@@ -7445,7 +8460,44 @@ static void ggml_compute_forward_mul_mat_f16_f32(
7445
8460
  }
7446
8461
  }
7447
8462
  }
8463
+ #endif
8464
+
8465
+ #if defined(GGML_USE_CUBLAS)
8466
+ const ggml_fp16_t * x = (ggml_fp16_t *) ((char *) src0->data + i02*nb02 + i03*nb03);
8467
+ const ggml_fp16_t * y = (ggml_fp16_t *) wdata;
8468
+
8469
+ float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
8470
+
8471
+ // copy data to device
8472
+ CUDA_CHECK(cudaMemcpyAsync(d_X, x, sizeof(ggml_fp16_t) * x_ne, cudaMemcpyHostToDevice, g_cudaStream));
8473
+ CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(ggml_fp16_t) * y_ne, cudaMemcpyHostToDevice, g_cudaStream));
8474
+
8475
+ // compute
8476
+ CUBLAS_CHECK(
8477
+ cublasGemmEx(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
8478
+ ne01, ne11, ne10,
8479
+ &alpha, d_X, CUDA_R_16F, ne00,
8480
+ d_Y, CUDA_R_16F, ne10,
8481
+ &beta, d_D, CUDA_R_32F, ne01,
8482
+ CUBLAS_COMPUTE_32F,
8483
+ CUBLAS_GEMM_DEFAULT));
8484
+
8485
+ // copy data to host
8486
+ CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, g_cudaStream));
8487
+ #elif defined(GGML_USE_CLBLAST)
8488
+ const float * x = wdata;
8489
+ const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
7448
8490
 
8491
+ float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
8492
+
8493
+ // zT = y * xT
8494
+ ggml_cl_sgemm_wrapper(GGML_BLAS_ORDER_ROW_MAJOR, GGML_BLAS_OP_N, GGML_BLAS_OP_T,
8495
+ ne11, ne01, ne10,
8496
+ 1.0f, y, ne10,
8497
+ x, ne10,
8498
+ 0.0f, d, ne01,
8499
+ GGML_TYPE_F32);
8500
+ #else
7449
8501
  const float * x = wdata;
7450
8502
  const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
7451
8503
 
@@ -7457,9 +8509,16 @@ static void ggml_compute_forward_mul_mat_f16_f32(
7457
8509
  1.0f, y, ne10,
7458
8510
  x, ne00,
7459
8511
  0.0f, d, ne01);
8512
+ #endif
7460
8513
  }
7461
8514
  }
7462
8515
 
8516
+ #if defined(GGML_USE_CUBLAS)
8517
+ CUDA_CHECK(cudaStreamSynchronize(g_cudaStream));
8518
+ ggml_cuda_pool_free(d_X, x_size);
8519
+ ggml_cuda_pool_free(d_Y, y_size);
8520
+ ggml_cuda_pool_free(d_D, d_size);
8521
+ #endif
7463
8522
  /*printf("CBLAS F16 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);*/
7464
8523
 
7465
8524
  return;
@@ -7592,6 +8651,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
7592
8651
  const enum ggml_type type = src0->type;
7593
8652
  quantize_row_q_t const quantize_row_q_dot = quantize_fns[type].quantize_row_q_dot;
7594
8653
  vec_dot_q_t const vec_dot_q = quantize_fns[type].vec_dot_q;
8654
+ enum ggml_type const vec_dot_type = quantize_fns[type].vec_dot_type;
7595
8655
 
7596
8656
  // we don't support permuted src0 or src1
7597
8657
  GGML_ASSERT(nb00 == (int) GGML_TYPE_SIZE[type]);
@@ -7611,7 +8671,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
7611
8671
  // nb01 >= nb00 - src0 is not transposed
7612
8672
  // compute by src0 rows
7613
8673
 
7614
- #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
8674
+ #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
7615
8675
  if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
7616
8676
  if (params->ith != 0) {
7617
8677
  return;
@@ -7625,11 +8685,66 @@ static void ggml_compute_forward_mul_mat_q_f32(
7625
8685
  return;
7626
8686
  }
7627
8687
 
8688
+ #if defined(GGML_USE_CUBLAS)
8689
+ const float alpha = 1.0f;
8690
+ const float beta = 0.0f;
8691
+ const int x_ne = ne01 * ne10;
8692
+ const int y_ne = ne11 * ne10;
8693
+ const int d_ne = ne11 * ne01;
8694
+
8695
+ size_t x_size, y_size, d_size, q_size;
8696
+ float *d_X = ggml_cuda_pool_malloc(sizeof(float) * x_ne, &x_size);
8697
+ float *d_Y = ggml_cuda_pool_malloc(sizeof(float) * y_ne, &y_size);
8698
+ float *d_D = ggml_cuda_pool_malloc(sizeof(float) * d_ne, &d_size);
8699
+ float *d_Q = ggml_cuda_pool_malloc(GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type], &q_size);
8700
+
8701
+ void (*dequantize_row_q_cuda)(const void * x, float * y, int k, cudaStream_t stream) = NULL;
8702
+ if (type == GGML_TYPE_Q4_0) {
8703
+ dequantize_row_q_cuda = dequantize_row_q4_0_cuda;
8704
+ }
8705
+ else if (type == GGML_TYPE_Q4_1) {
8706
+ dequantize_row_q_cuda = dequantize_row_q4_1_cuda;
8707
+ }
8708
+ else if (type == GGML_TYPE_Q4_2) {
8709
+ dequantize_row_q_cuda = dequantize_row_q4_2_cuda;
8710
+ }
8711
+ else if (type == GGML_TYPE_Q4_3) {
8712
+ dequantize_row_q_cuda = dequantize_row_q4_3_cuda;
8713
+ }
8714
+ else if (type == GGML_TYPE_Q5_0) {
8715
+ dequantize_row_q_cuda = dequantize_row_q5_0_cuda;
8716
+ }
8717
+ else if (type == GGML_TYPE_Q5_1) {
8718
+ dequantize_row_q_cuda = dequantize_row_q5_1_cuda;
8719
+ }
8720
+ else if (type == GGML_TYPE_Q8_0) {
8721
+ dequantize_row_q_cuda = dequantize_row_q8_0_cuda;
8722
+ }
8723
+ else {
8724
+ GGML_ASSERT(false);
8725
+ }
8726
+ #elif !defined(GGML_USE_CLBLAST)
7628
8727
  float * const wdata = params->wdata;
7629
8728
  dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q;
8729
+ #endif
7630
8730
 
7631
8731
  for (int64_t i03 = 0; i03 < ne03; i03++) {
7632
8732
  for (int64_t i02 = 0; i02 < ne02; i02++) {
8733
+ const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
8734
+
8735
+ float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
8736
+
8737
+ #if defined(GGML_USE_CUBLAS)
8738
+ // copy and dequantize on device
8739
+ CUDA_CHECK(
8740
+ cudaMemcpyAsync(d_Q, (char *) src0->data + i03*nb03 + i02*nb02,
8741
+ GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type], cudaMemcpyHostToDevice, g_cudaStream));
8742
+
8743
+ dequantize_row_q_cuda(d_Q, d_X, ne01 * ne00, g_cudaStream);
8744
+ CUDA_CHECK(cudaGetLastError());
8745
+ #elif defined(GGML_USE_CLBLAST)
8746
+ const void* x = (char *) src0->data + i03*nb03 + i02*nb02;
8747
+ #else
7633
8748
  {
7634
8749
  size_t id = 0;
7635
8750
  for (int64_t i01 = 0; i01 < ne01; ++i01) {
@@ -7637,21 +8752,49 @@ static void ggml_compute_forward_mul_mat_q_f32(
7637
8752
  id += ne00;
7638
8753
  }
7639
8754
  }
7640
-
7641
8755
  const float * x = wdata;
7642
- const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
8756
+ #endif
7643
8757
 
7644
- float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
7645
8758
 
8759
+ #if defined(GGML_USE_CUBLAS)
8760
+ // copy data to device
8761
+ CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(float) * y_ne, cudaMemcpyHostToDevice, g_cudaStream));
8762
+
8763
+ // compute
8764
+ CUBLAS_CHECK(
8765
+ cublasSgemm(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
8766
+ ne01, ne11, ne10,
8767
+ &alpha, d_X, ne00,
8768
+ d_Y, ne10,
8769
+ &beta, d_D, ne01));
8770
+
8771
+ // copy data to host
8772
+ CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, g_cudaStream));
8773
+ #elif defined(GGML_USE_CLBLAST)
7646
8774
  // zT = y * xT
8775
+ ggml_cl_sgemm_wrapper(GGML_BLAS_ORDER_ROW_MAJOR, GGML_BLAS_OP_N, GGML_BLAS_OP_T,
8776
+ ne11, ne01, ne10,
8777
+ 1.0f, y, ne10,
8778
+ x, ne10,
8779
+ 0.0f, d, ne01,
8780
+ type);
8781
+ #else
7647
8782
  cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
7648
8783
  ne11, ne01, ne10,
7649
8784
  1.0f, y, ne10,
7650
8785
  x, ne00,
7651
8786
  0.0f, d, ne01);
8787
+ #endif
7652
8788
  }
7653
8789
  }
7654
8790
 
8791
+ #if defined(GGML_USE_CUBLAS)
8792
+ CUDA_CHECK(cudaStreamSynchronize(g_cudaStream));
8793
+ ggml_cuda_pool_free(d_X, x_size);
8794
+ ggml_cuda_pool_free(d_Y, y_size);
8795
+ ggml_cuda_pool_free(d_D, d_size);
8796
+ ggml_cuda_pool_free(d_Q, q_size);
8797
+ #endif
7655
8798
  //printf("CBLAS = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);
7656
8799
 
7657
8800
  return;
@@ -7660,7 +8803,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
7660
8803
 
7661
8804
  if (params->type == GGML_TASK_INIT) {
7662
8805
  char * wdata = params->wdata;
7663
- const size_t row_size = ne10*GGML_TYPE_SIZE[GGML_TYPE_Q8_0]/GGML_BLCK_SIZE[GGML_TYPE_Q8_0];
8806
+ const size_t row_size = ne10*GGML_TYPE_SIZE[vec_dot_type]/GGML_BLCK_SIZE[vec_dot_type];
7664
8807
 
7665
8808
  for (int64_t i13 = 0; i13 < ne13; ++i13) {
7666
8809
  for (int64_t i12 = 0; i12 < ne12; ++i12) {
@@ -7691,7 +8834,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
7691
8834
  const int ir1 = MIN(ir0 + dr, nr);
7692
8835
 
7693
8836
  void * wdata = params->wdata;
7694
- const size_t row_size = ne00*GGML_TYPE_SIZE[GGML_TYPE_Q8_0]/GGML_BLCK_SIZE[GGML_TYPE_Q8_0];
8837
+ const size_t row_size = ne00*GGML_TYPE_SIZE[vec_dot_type]/GGML_BLCK_SIZE[vec_dot_type];
7695
8838
 
7696
8839
  for (int ir = ir0; ir < ir1; ++ir) {
7697
8840
  // src0 indices
@@ -7739,7 +8882,12 @@ static void ggml_compute_forward_mul_mat(
7739
8882
  switch (src0->type) {
7740
8883
  case GGML_TYPE_Q4_0:
7741
8884
  case GGML_TYPE_Q4_1:
8885
+ case GGML_TYPE_Q4_2:
8886
+ case GGML_TYPE_Q4_3:
8887
+ case GGML_TYPE_Q5_0:
8888
+ case GGML_TYPE_Q5_1:
7742
8889
  case GGML_TYPE_Q8_0:
8890
+ case GGML_TYPE_Q8_1:
7743
8891
  {
7744
8892
  ggml_compute_forward_mul_mat_q_f32(params, src0, src1, dst);
7745
8893
  } break;
@@ -7756,34 +8904,6 @@ static void ggml_compute_forward_mul_mat(
7756
8904
  GGML_ASSERT(false);
7757
8905
  } break;
7758
8906
  }
7759
-
7760
- #if 0
7761
- if (src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_Q4_1) {
7762
- static int first = 8;
7763
- printf("src0: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", src0->ne[0], src0->ne[1], src0->ne[2]);
7764
- printf("src1: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", src1->ne[0], src1->ne[1], src1->ne[2]);
7765
- printf("dst: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", dst->ne[0], dst->ne[1], dst->ne[2]);
7766
- if (first) {
7767
- --first;
7768
- } else {
7769
- for (int k = 0; k < dst->ne[1]; ++k) {
7770
- for (int j = 0; j < dst->ne[0]/16; ++j) {
7771
- for (int i = 0; i < 16; ++i) {
7772
- printf("%8.4f ", ((float *) dst->data)[k*dst->ne[0] + j*16 + i]);
7773
- }
7774
- printf("\n");
7775
- }
7776
- printf("\n");
7777
- }
7778
- printf("\n");
7779
- exit(0);
7780
- }
7781
- } else {
7782
- printf("aaaa src0: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", src0->ne[0], src0->ne[1], src0->ne[2]);
7783
- printf("aaaa src1: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", src1->ne[0], src1->ne[1], src1->ne[2]);
7784
- printf("aaaa dst: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", dst->ne[0], dst->ne[1], dst->ne[2]);
7785
- }
7786
- #endif
7787
8907
  }
7788
8908
 
7789
8909
  // ggml_compute_forward_scale
@@ -7994,7 +9114,12 @@ static void ggml_compute_forward_get_rows(
7994
9114
  switch (src0->type) {
7995
9115
  case GGML_TYPE_Q4_0:
7996
9116
  case GGML_TYPE_Q4_1:
9117
+ case GGML_TYPE_Q4_2:
9118
+ case GGML_TYPE_Q4_3:
9119
+ case GGML_TYPE_Q5_0:
9120
+ case GGML_TYPE_Q5_1:
7997
9121
  case GGML_TYPE_Q8_0:
9122
+ case GGML_TYPE_Q8_1:
7998
9123
  {
7999
9124
  ggml_compute_forward_get_rows_q(params, src0, src1, dst);
8000
9125
  } break;
@@ -8132,6 +9257,7 @@ static void ggml_compute_forward_soft_max_f32(
8132
9257
 
8133
9258
  uint16_t scvt;
8134
9259
  for (int i = 0; i < nc; i++) {
9260
+ //printf("p[%3d] = %8.4f\n", i, p[i]);
8135
9261
  if (p[i] == -INFINITY) {
8136
9262
  p[i] = 0.0f;
8137
9263
  } else {
@@ -8224,9 +9350,11 @@ static void ggml_compute_forward_rope_f32(
8224
9350
 
8225
9351
  const float theta_scale = powf(10000.0, -2.0f/n_dims);
8226
9352
 
9353
+ const bool is_neox = mode & 2;
9354
+
8227
9355
  for (int64_t i3 = 0; i3 < ne3; i3++) {
8228
- for (int64_t i2 = (mode == 0 ? 0 : n_past); i2 < ne2; i2++) {
8229
- const int p = (mode == 0 ? n_past + i2 : i2);
9356
+ for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) {
9357
+ const int p = ((mode & 1) == 0 ? n_past + i2 : i2);
8230
9358
  for (int64_t i1 = 0; i1 < ne1; i1++) {
8231
9359
  if (ir++ < ir0) continue;
8232
9360
  if (ir > ir1) break;
@@ -8239,14 +9367,25 @@ static void ggml_compute_forward_rope_f32(
8239
9367
 
8240
9368
  theta *= theta_scale;
8241
9369
 
8242
- const float * const src = (float *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
8243
- float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
9370
+ if (!is_neox) {
9371
+ const float * const src = (float *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
9372
+ float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
9373
+
9374
+ const float x0 = src[0];
9375
+ const float x1 = src[1];
9376
+
9377
+ dst_data[0] = x0*cos_theta - x1*sin_theta;
9378
+ dst_data[1] = x0*sin_theta + x1*cos_theta;
9379
+ } else {
9380
+ const float * const src = (float *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + (i0/2)*nb0);
9381
+ float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + (i0/2)*nb0);
8244
9382
 
8245
- const float x0 = src[0];
8246
- const float x1 = src[1];
9383
+ const float x0 = src[0];
9384
+ const float x1 = src[n_dims/2];
8247
9385
 
8248
- dst_data[0] = x0*cos_theta - x1*sin_theta;
8249
- dst_data[1] = x0*sin_theta + x1*cos_theta;
9386
+ dst_data[0] = x0*cos_theta - x1*sin_theta;
9387
+ dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
9388
+ }
8250
9389
  }
8251
9390
  }
8252
9391
  }
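
Editor's note on the rope_f32 hunks above: bit 0 of mode now controls how the position p is derived, and bit 1 (is_neox) selects which elements form a rotated pair, either the interleaved (x[i0], x[i0+1]) layout or the GPT-NeoX split-half (x[i0/2], x[i0/2 + n_dims/2]) layout. A condensed single-row sketch of the two pairings (my own helper, not the ggml API):

    #include <math.h>

    // Rotary position embedding on one row of n_dims floats at position p.
    // Pair i is rotated by theta_i = p * 10000^(-2i/n_dims).
    static void rope_row(float * x, int n_dims, int p, int is_neox) {
        const float theta_scale = powf(10000.0f, -2.0f/n_dims);
        float theta = (float) p;
        for (int i0 = 0; i0 < n_dims; i0 += 2) {
            const float cos_theta = cosf(theta);
            const float sin_theta = sinf(theta);
            theta *= theta_scale;

            // pick the pair to rotate
            float * a = is_neox ? &x[i0/2]            : &x[i0];
            float * b = is_neox ? &x[i0/2 + n_dims/2] : &x[i0 + 1];

            const float x0 = *a;
            const float x1 = *b;
            *a = x0*cos_theta - x1*sin_theta;
            *b = x0*sin_theta + x1*cos_theta;
        }
    }
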
@@ -8301,9 +9440,11 @@ static void ggml_compute_forward_rope_f16(
8301
9440
 
8302
9441
  const float theta_scale = powf(10000.0, -2.0f/n_dims);
8303
9442
 
9443
+ const bool is_neox = mode & 2;
9444
+
8304
9445
  for (int64_t i3 = 0; i3 < ne3; i3++) {
8305
- for (int64_t i2 = (mode == 0 ? 0 : n_past); i2 < ne2; i2++) {
8306
- const int p = (mode == 0 ? n_past + i2 : i2);
9446
+ for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) {
9447
+ const int p = ((mode & 1) == 0 ? n_past + i2 : i2);
8307
9448
  for (int64_t i1 = 0; i1 < ne1; i1++) {
8308
9449
  if (ir++ < ir0) continue;
8309
9450
  if (ir > ir1) break;
@@ -8316,14 +9457,25 @@ static void ggml_compute_forward_rope_f16(
8316
9457
 
8317
9458
  theta *= theta_scale;
8318
9459
 
8319
- const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
8320
- ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
9460
+ if (!is_neox) {
9461
+ const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
9462
+ ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
9463
+
9464
+ const float x0 = GGML_FP16_TO_FP32(src[0]);
9465
+ const float x1 = GGML_FP16_TO_FP32(src[1]);
9466
+
9467
+ dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
9468
+ dst_data[1] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
9469
+ } else {
9470
+ const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + (i0/2)*nb0);
9471
+ ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + (i0/2)*nb0);
8321
9472
 
8322
- const float x0 = GGML_FP16_TO_FP32(src[0]);
8323
- const float x1 = GGML_FP16_TO_FP32(src[1]);
9473
+ const float x0 = GGML_FP16_TO_FP32(src[0]);
9474
+ const float x1 = GGML_FP16_TO_FP32(src[n_dims/2]);
8324
9475
 
8325
- dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
8326
- dst_data[1] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
9476
+ dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
9477
+ dst_data[n_dims/2] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
9478
+ }
8327
9479
  }
8328
9480
  }
8329
9481
  }
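
Editor's note: the f16 variant applies the same transform as the f32 one, only converting through GGML_FP16_TO_FP32 / GGML_FP32_TO_FP16 around the arithmetic. In both variants each selected pair undergoes the plane rotation

    \begin{pmatrix} x_0' \\ x_1' \end{pmatrix} =
    \begin{pmatrix} \cos\theta_i & -\sin\theta_i \\ \sin\theta_i & \cos\theta_i \end{pmatrix}
    \begin{pmatrix} x_0 \\ x_1 \end{pmatrix},
    \qquad \theta_i = p \cdot 10000^{-2i/n_\text{dims}}.
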
@@ -10402,11 +11554,11 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
10402
11554
  case GGML_OP_CPY:
10403
11555
  case GGML_OP_DUP:
10404
11556
  {
10405
- node->n_tasks = 1;
11557
+ node->n_tasks = n_threads;
10406
11558
 
10407
11559
  size_t cur = 0;
10408
- if (node->type == GGML_TYPE_Q4_0 || node->type == GGML_TYPE_Q4_1) {
10409
- cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->ne[0];
11560
+ if (ggml_is_quantized(node->type)) {
11561
+ cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->ne[0] * n_threads;
10410
11562
  }
10411
11563
 
10412
11564
  work_size = MAX(work_size, cur);
@@ -10417,7 +11569,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
10417
11569
 
10418
11570
  size_t cur = 0;
10419
11571
 
10420
- if (node->src0->type == GGML_TYPE_Q4_0 || node->src0->type == GGML_TYPE_Q4_1) {
11572
+ if (ggml_is_quantized(node->src0->type)) {
10421
11573
  cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->src0->ne[0] * n_threads;
10422
11574
  }
10423
11575
 
@@ -10466,7 +11618,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
10466
11618
  size_t cur = 0;
10467
11619
 
10468
11620
  if (node->src0->type == GGML_TYPE_F16 && node->src1->type == GGML_TYPE_F32) {
10469
- #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
11621
+ #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
10470
11622
  if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
10471
11623
  node->n_tasks = 1; // TODO: this actually is doing nothing
10472
11624
  // the threads are still spinning
@@ -10482,15 +11634,16 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
10482
11634
  #endif
10483
11635
  } else if (node->src0->type == GGML_TYPE_F32 && node->src1->type == GGML_TYPE_F32) {
10484
11636
  cur = 0;
10485
- } else if (quantize_fns[node->src0->type].vec_dot_q && node->src1->type == GGML_TYPE_F32) {
10486
- #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
11637
+ } else if (ggml_is_quantized(node->src0->type) && node->src1->type == GGML_TYPE_F32) {
11638
+ #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
10487
11639
  if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
10488
11640
  node->n_tasks = 1;
10489
11641
  cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]);
10490
11642
  } else
10491
11643
  #endif
10492
11644
  {
10493
- cur = GGML_TYPE_SIZE[GGML_TYPE_Q8_0]*ggml_nelements(node->src1)/GGML_BLCK_SIZE[GGML_TYPE_Q8_0];
11645
+ const enum ggml_type type_q = quantize_fns[node->src0->type].vec_dot_type;
11646
+ cur = GGML_TYPE_SIZE[type_q]*ggml_nelements(node->src1)/GGML_BLCK_SIZE[type_q];
10494
11647
  }
10495
11648
  } else {
10496
11649
  GGML_ASSERT(false);
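
Editor's note on the graph-compute hunks above: the scratch ("work") buffer sizing is generalized from hard-coded Q4_0/Q4_1 checks to ggml_is_quantized(), and for quantized mul_mat the buffer is sized from the companion type the dot product actually consumes (vec_dot_type) rather than always assuming Q8_0. A rough sketch of the two sizing rules, with stand-in constants of my own (not ggml's real GGML_TYPE_SIZE / GGML_BLCK_SIZE tables):

    #include <stddef.h>
    #include <stdint.h>

    // Stand-in constants (assumptions): a Q8_0-style block packs 32 values
    // into 4 (scale) + 32 (int8 codes) = 36 bytes.
    enum  { QBLOCK_VALUES = 32 };
    static const size_t QBLOCK_BYTES = 36;

    // CPY/DUP into a quantized tensor: each thread dequantizes/requantizes
    // through one f32 row of scratch, hence the factor n_threads.
    static size_t cpy_work_size(int64_t ne0, int n_threads) {
        return sizeof(float) * (size_t) ne0 * (size_t) n_threads;
    }

    // Quantized mul_mat without BLAS: src1 is quantized once into the work
    // buffer using the dot product's companion type, so the buffer must hold
    // all of src1's elements in that format.
    static size_t mul_mat_work_size(int64_t src1_nelements) {
        return QBLOCK_BYTES * (size_t) (src1_nelements / QBLOCK_VALUES);
    }
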
@@ -10818,9 +11971,9 @@ void ggml_graph_print(const struct ggml_cgraph * cgraph) {
10818
11971
  for (int i = 0; i < cgraph->n_nodes; i++) {
10819
11972
  struct ggml_tensor * node = cgraph->nodes[i];
10820
11973
 
10821
- perf_total_per_op_us[node->op] += node->perf_time_us;
11974
+ perf_total_per_op_us[node->op] += MAX(1, node->perf_time_us);
10822
11975
 
10823
- GGML_PRINT(" - %3d: [ %" PRId64 ", %" PRId64 ", %" PRId64 "] %16s %s (%3d) cpu = %7.3f / %7.3f ms, wall = %7.3f / %7.3f ms\n",
11976
+ GGML_PRINT(" - %3d: [ %5" PRId64 ", %5" PRId64 ", %5" PRId64 "] %16s %s (%3d) cpu = %7.3f / %7.3f ms, wall = %7.3f / %7.3f ms\n",
10824
11977
  i,
10825
11978
  node->ne[0], node->ne[1], node->ne[2],
10826
11979
  GGML_OP_LABEL[node->op], node->is_param ? "x" : node->grad ? "g" : " ", node->perf_runs,
@@ -10834,13 +11987,17 @@ void ggml_graph_print(const struct ggml_cgraph * cgraph) {
10834
11987
  for (int i = 0; i < cgraph->n_leafs; i++) {
10835
11988
  struct ggml_tensor * node = cgraph->leafs[i];
10836
11989
 
10837
- GGML_PRINT(" - %3d: [ %" PRId64 ", %" PRId64 "] %8s\n",
11990
+ GGML_PRINT(" - %3d: [ %5" PRId64 ", %5" PRId64 "] %8s\n",
10838
11991
  i,
10839
11992
  node->ne[0], node->ne[1],
10840
11993
  GGML_OP_LABEL[node->op]);
10841
11994
  }
10842
11995
 
10843
11996
  for (int i = 0; i < GGML_OP_COUNT; i++) {
11997
+ if (perf_total_per_op_us[i] == 0) {
11998
+ continue;
11999
+ }
12000
+
10844
12001
  GGML_PRINT("perf_total_per_op_us[%16s] = %7.3f ms\n", GGML_OP_LABEL[i], (double) perf_total_per_op_us[i] / 1000.0);
10845
12002
  }
10846
12003
 
@@ -11674,7 +12831,7 @@ size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t *
11674
12831
 
11675
12832
  for (int i = 0; i < nb; i++) {
11676
12833
  for (int l = 0; l < QK4_0; l += 2) {
11677
- const uint8_t vi0 = y[i].qs[l/2] & 0xF;
12834
+ const uint8_t vi0 = y[i].qs[l/2] & 0x0F;
11678
12835
  const uint8_t vi1 = y[i].qs[l/2] >> 4;
11679
12836
 
11680
12837
  hist[vi0]++;
@@ -11697,7 +12854,7 @@ size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t *
11697
12854
 
11698
12855
  for (int i = 0; i < nb; i++) {
11699
12856
  for (int l = 0; l < QK4_1; l += 2) {
11700
- const uint8_t vi0 = y[i].qs[l/2] & 0xF;
12857
+ const uint8_t vi0 = y[i].qs[l/2] & 0x0F;
11701
12858
  const uint8_t vi1 = y[i].qs[l/2] >> 4;
11702
12859
 
11703
12860
  hist[vi0]++;
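
Editor's note: the two histogram hunks above are cosmetic (0xF becomes 0x0F); the logic is unchanged. Each packed byte yields two 4-bit codes in [0..15], and both are counted in the 16-entry histogram. A tiny equivalent sketch:

    #include <stdint.h>

    // Tally 4-bit codes from nbytes of packed nibbles (low nibble = even index).
    // hist must point to 16 zero-initialized counters.
    static void hist_nibbles(const uint8_t * qs, int nbytes, int64_t hist[16]) {
        for (int i = 0; i < nbytes; ++i) {
            hist[qs[i] & 0x0F]++;  // even-index code
            hist[qs[i] >>   4]++;  // odd-index code
        }
    }
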
@@ -11709,6 +12866,184 @@ size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t *
11709
12866
  return (n/QK4_1*sizeof(block_q4_1));
11710
12867
  }
11711
12868
 
12869
+ size_t ggml_quantize_q4_2(const float * src, void * dst, int n, int k, int64_t * hist) {
12870
+ assert(k % QK4_2 == 0);
12871
+ const int nb = k / QK4_2;
12872
+
12873
+ for (int j = 0; j < n; j += k) {
12874
+ block_q4_2 * restrict y = (block_q4_2 *)dst + j/QK4_2;
12875
+
12876
+ quantize_row_q4_2_reference(src + j, y, k);
12877
+
12878
+ for (int i = 0; i < nb; i++) {
12879
+ for (int l = 0; l < QK4_2; l += 2) {
12880
+ const uint8_t vi0 = y[i].qs[l/2] & 0x0F;
12881
+ const uint8_t vi1 = y[i].qs[l/2] >> 4;
12882
+
12883
+ hist[vi0]++;
12884
+ hist[vi1]++;
12885
+ }
12886
+ }
12887
+ }
12888
+
12889
+ return (n/QK4_2*sizeof(block_q4_2));
12890
+ }
12891
+
12892
+ size_t ggml_quantize_q4_3(const float * src, void * dst, int n, int k, int64_t * hist) {
12893
+ assert(k % QK4_3 == 0);
12894
+ const int nb = k / QK4_3;
12895
+
12896
+ for (int j = 0; j < n; j += k) {
12897
+ block_q4_3 * restrict y = (block_q4_3 *)dst + j/QK4_3;
12898
+
12899
+ quantize_row_q4_3_reference(src + j, y, k);
12900
+
12901
+ for (int i = 0; i < nb; i++) {
12902
+ for (int l = 0; l < QK4_3; l += 2) {
12903
+ const uint8_t vi0 = y[i].qs[l/2] & 0x0F;
12904
+ const uint8_t vi1 = y[i].qs[l/2] >> 4;
12905
+
12906
+ hist[vi0]++;
12907
+ hist[vi1]++;
12908
+ }
12909
+ }
12910
+ }
12911
+
12912
+ return (n/QK4_3*sizeof(block_q4_3));
12913
+ }
12914
+
12915
+ size_t ggml_quantize_q5_0(const float * src, void * dst, int n, int k, int64_t * hist) {
12916
+ assert(k % QK5_0 == 0);
12917
+ const int nb = k / QK5_0;
12918
+
12919
+ for (int j = 0; j < n; j += k) {
12920
+ block_q5_0 * restrict y = (block_q5_0 *)dst + j/QK5_0;
12921
+
12922
+ quantize_row_q5_0_reference(src + j, y, k);
12923
+
12924
+ for (int i = 0; i < nb; i++) {
12925
+ uint32_t qh;
12926
+ memcpy(&qh, &y[i].qh, sizeof(qh));
12927
+
12928
+ for (int l = 0; l < QK5_0; l += 2) {
12929
+ const uint8_t vh0 = ((qh & (1 << (l + 0))) >> (l + 0)) << 4;
12930
+ const uint8_t vh1 = ((qh & (1 << (l + 1))) >> (l + 1)) << 4;
12931
+
12932
+ // cast to 16 bins
12933
+ const uint8_t vi0 = ((y[i].qs[l/2] & 0x0F) | vh0) / 2;
12934
+ const uint8_t vi1 = ((y[i].qs[l/2] >> 4) | vh1) / 2;
12935
+
12936
+ hist[vi0]++;
12937
+ hist[vi1]++;
12938
+ }
12939
+ }
12940
+ }
12941
+
12942
+ return (n/QK5_0*sizeof(block_q5_0));
12943
+ }
12944
+
12945
+ size_t ggml_quantize_q5_1(const float * src, void * dst, int n, int k, int64_t * hist) {
12946
+ assert(k % QK5_1 == 0);
12947
+ const int nb = k / QK5_1;
12948
+
12949
+ for (int j = 0; j < n; j += k) {
12950
+ block_q5_1 * restrict y = (block_q5_1 *)dst + j/QK5_1;
12951
+
12952
+ quantize_row_q5_1_reference(src + j, y, k);
12953
+
12954
+ for (int i = 0; i < nb; i++) {
12955
+ uint32_t qh;
12956
+ memcpy(&qh, &y[i].qh, sizeof(qh));
12957
+
12958
+ for (int l = 0; l < QK5_1; l += 2) {
12959
+ const uint8_t vh0 = ((qh & (1 << (l + 0))) >> (l + 0)) << 4;
12960
+ const uint8_t vh1 = ((qh & (1 << (l + 1))) >> (l + 1)) << 4;
12961
+
12962
+ // cast to 16 bins
12963
+ const uint8_t vi0 = ((y[i].qs[l/2] & 0x0F) | vh0) / 2;
12964
+ const uint8_t vi1 = ((y[i].qs[l/2] >> 4) | vh1) / 2;
12965
+
12966
+ hist[vi0]++;
12967
+ hist[vi1]++;
12968
+ }
12969
+ }
12970
+ }
12971
+
12972
+ return (n/QK5_1*sizeof(block_q5_1));
12973
+ }
12974
+
12975
+ size_t ggml_quantize_q8_0(const float * src, void * dst, int n, int k, int64_t * hist) {
12976
+ assert(k % QK8_0 == 0);
12977
+ const int nb = k / QK8_0;
12978
+
12979
+ for (int j = 0; j < n; j += k) {
12980
+ block_q8_0 * restrict y = (block_q8_0 *)dst + j/QK8_0;
12981
+
12982
+ quantize_row_q8_0_reference(src + j, y, k);
12983
+
12984
+ for (int i = 0; i < nb; i++) {
12985
+ for (int l = 0; l < QK8_0; ++l) {
12986
+ const int8_t vi = y[i].qs[l];
12987
+
12988
+ hist[vi/16 + 8]++;
12989
+ }
12990
+ }
12991
+ }
12992
+
12993
+ return (n/QK8_0*sizeof(block_q8_0));
12994
+ }
12995
+
12996
+ size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, int start, int n, int64_t * hist) {
12997
+ size_t result = 0;
12998
+ switch (type) {
12999
+ case GGML_TYPE_Q4_0:
13000
+ {
13001
+ GGML_ASSERT(start % QK4_0 == 0);
13002
+ block_q4_0 * block = (block_q4_0*)dst + start / QK4_0;
13003
+ result = ggml_quantize_q4_0(src + start, block, n, n, hist);
13004
+ } break;
13005
+ case GGML_TYPE_Q4_1:
13006
+ {
13007
+ GGML_ASSERT(start % QK4_1 == 0);
13008
+ block_q4_1 * block = (block_q4_1*)dst + start / QK4_1;
13009
+ result = ggml_quantize_q4_1(src + start, block, n, n, hist);
13010
+ } break;
13011
+ case GGML_TYPE_Q4_2:
13012
+ {
13013
+ GGML_ASSERT(start % QK4_2 == 0);
13014
+ block_q4_2 * block = (block_q4_2*)dst + start / QK4_2;
13015
+ result = ggml_quantize_q4_2(src + start, block, n, n, hist);
13016
+ } break;
13017
+ case GGML_TYPE_Q4_3:
13018
+ {
13019
+ GGML_ASSERT(start % QK4_3 == 0);
13020
+ block_q4_3 * block = (block_q4_3*)dst + start / QK4_3;
13021
+ result = ggml_quantize_q4_3(src + start, block, n, n, hist);
13022
+ } break;
13023
+ case GGML_TYPE_Q5_0:
13024
+ {
13025
+ GGML_ASSERT(start % QK5_0 == 0);
13026
+ block_q5_0 * block = (block_q5_0*)dst + start / QK5_0;
13027
+ result = ggml_quantize_q5_0(src + start, block, n, n, hist);
13028
+ } break;
13029
+ case GGML_TYPE_Q5_1:
13030
+ {
13031
+ GGML_ASSERT(start % QK5_1 == 0);
13032
+ block_q5_1 * block = (block_q5_1*)dst + start / QK5_1;
13033
+ result = ggml_quantize_q5_1(src + start, block, n, n, hist);
13034
+ } break;
13035
+ case GGML_TYPE_Q8_0:
13036
+ {
13037
+ GGML_ASSERT(start % QK8_0 == 0);
13038
+ block_q8_0 * block = (block_q8_0*)dst + start / QK8_0;
13039
+ result = ggml_quantize_q8_0(src + start, block, n, n, hist);
13040
+ } break;
13041
+ default:
13042
+ assert(false);
13043
+ }
13044
+ return result;
13045
+ }
13046
+
11712
13047
  ////////////////////////////////////////////////////////////////////////////////
11713
13048
 
11714
13049
  int ggml_cpu_has_avx(void) {
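
Editor's note on the new quantize helpers above: in the q5_0/q5_1 paths, each 5-bit code is reassembled from a low nibble in qs plus one bit of the 32-bit qh mask shifted into bit position 4, and the histogram folds the 32 possible codes into 16 bins with the final / 2; ggml_quantize_chunk then simply dispatches a [start, start + n) range of floats to the matching per-type quantizer. A hedged sketch of the 5-bit reassembly on a simplified block layout (my own struct, not ggml's exact block_q5_0):

    #include <stdint.h>
    #include <string.h>

    #define QK5 32                  // values per block (assumption: mirrors QK5_0)

    typedef struct {
        float   d;                  // block scale (simplified; the real block stores fp16)
        uint8_t qh[4];              // 32 high bits, one per value
        uint8_t qs[QK5 / 2];        // 32 low nibbles, two per byte
    } block_q5;                     // hypothetical stand-in for block_q5_0

    // Recover the 5-bit code (0..31) of value l within a block.
    static uint8_t q5_code(const block_q5 * x, int l) {
        uint32_t qh;
        memcpy(&qh, x->qh, sizeof(qh));                    // avoid unaligned/aliasing issues

        const uint8_t lo = (l & 1) ? (x->qs[l/2] >> 4)     // odd index: high nibble
                                   : (x->qs[l/2] & 0x0F);  // even index: low nibble
        const uint8_t hi = (uint8_t) (((qh >> l) & 1) << 4);
        return lo | hi;                                    // histogram above bins this as code / 2
    }
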
@@ -11800,13 +13135,33 @@ int ggml_cpu_has_wasm_simd(void) {
11800
13135
  }
11801
13136
 
11802
13137
  int ggml_cpu_has_blas(void) {
11803
- #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
13138
+ #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
13139
+ return 1;
13140
+ #else
13141
+ return 0;
13142
+ #endif
13143
+ }
13144
+
13145
+ int ggml_cpu_has_cublas(void) {
13146
+ #if defined(GGML_USE_CUBLAS)
11804
13147
  return 1;
11805
13148
  #else
11806
13149
  return 0;
11807
13150
  #endif
11808
13151
  }
11809
13152
 
13153
+ int ggml_cpu_has_clblast(void) {
13154
+ #if defined(GGML_USE_CLBLAST)
13155
+ return 1;
13156
+ #else
13157
+ return 0;
13158
+ #endif
13159
+ }
13160
+
13161
+ int ggml_cpu_has_gpublas(void) {
13162
+ return ggml_cpu_has_cublas() || ggml_cpu_has_clblast();
13163
+ }
13164
+
11810
13165
  int ggml_cpu_has_sse3(void) {
11811
13166
  #if defined(__SSE3__)
11812
13167
  return 1;
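
Editor's note: the added ggml_cpu_has_cublas / ggml_cpu_has_clblast / ggml_cpu_has_gpublas helpers mirror the existing feature probes, turning a compile-time #ifdef into a 0/1 answer queryable at runtime. A minimal usage sketch, assuming the matching declarations in ggml.h (the printout format is my own, not llama.cpp's system-info string):

    #include <stdio.h>
    #include "ggml.h"

    int main(void) {
        printf("BLAS     = %d\n", ggml_cpu_has_blas());
        printf("cuBLAS   = %d\n", ggml_cpu_has_cublas());
        printf("CLBlast  = %d\n", ggml_cpu_has_clblast());
        printf("GPU BLAS = %d\n", ggml_cpu_has_gpublas());
        return 0;
    }
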