llama_cpp 0.0.4 → 0.0.6

@@ -19,6 +19,7 @@
19
19
  #include <inttypes.h>
20
20
  #include <stdio.h>
21
21
  #include <float.h>
22
+ #include <limits.h>
22
23
 
23
24
  // if C99 - static_assert is noop
24
25
  // ref: https://stackoverflow.com/a/53923785/4039976
@@ -118,7 +119,16 @@ typedef void* thread_ret_t;
118
119
  #define GGML_ALIGNED_MALLOC(size) _aligned_malloc(size, GGML_MEM_ALIGN)
119
120
  #define GGML_ALIGNED_FREE(ptr) _aligned_free(ptr)
120
121
  #else
121
- #define GGML_ALIGNED_MALLOC(size) aligned_alloc(GGML_MEM_ALIGN, size)
122
+ inline static void* ggml_aligned_malloc(size_t size) {
123
+ void* aligned_memory = NULL;
124
+ int result = posix_memalign(&aligned_memory, GGML_MEM_ALIGN, size);
125
+ if (result != 0) {
126
+ // Handle allocation failure
127
+ return NULL;
128
+ }
129
+ return aligned_memory;
130
+ }
131
+ #define GGML_ALIGNED_MALLOC(size) ggml_aligned_malloc(size)
122
132
  #define GGML_ALIGNED_FREE(ptr) free(ptr)
123
133
  #endif
124
134
 
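The hunk above replaces the bare aligned_alloc call with a posix_memalign wrapper: C11 aligned_alloc formally requires the size to be a multiple of the alignment and is not available on every platform ggml targets, while posix_memalign has neither restriction and makes the failure path explicit through its return code. A minimal sketch of the wrapper in isolation; the GGML_MEM_ALIGN value of 16 is an assumption here, the real constant is defined elsewhere in ggml.c:

    #define _POSIX_C_SOURCE 200112L  // for posix_memalign
    #include <stdint.h>
    #include <stdio.h>
    #include <stdlib.h>

    #define GGML_MEM_ALIGN 16  // assumed for this sketch

    // same shape as the wrapper added in the hunk above
    static void * ggml_aligned_malloc(size_t size) {
        void * aligned_memory = NULL;
        if (posix_memalign(&aligned_memory, GGML_MEM_ALIGN, size) != 0) {
            return NULL;  // allocation failed
        }
        return aligned_memory;
    }

    int main(void) {
        void * p = ggml_aligned_malloc(1000);  // size need not be a multiple of 16
        if (p) {
            printf("aligned: %s\n", ((uintptr_t) p % GGML_MEM_ALIGN) == 0 ? "yes" : "no");
            free(p);
        }
        return 0;
    }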
@@ -133,10 +143,49 @@ typedef void* thread_ret_t;
133
143
  } \
134
144
  } while (0)
135
145
 
136
- #ifdef GGML_USE_ACCELERATE
146
+ #if defined(GGML_USE_ACCELERATE)
137
147
  #include <Accelerate/Accelerate.h>
138
- #elif GGML_USE_OPENBLAS
148
+ #elif defined(GGML_USE_OPENBLAS)
139
149
  #include <cblas.h>
150
+ #elif defined(GGML_USE_CUBLAS)
151
+ #include <cublas_v2.h>
152
+ #include <cuda_runtime.h>
153
+ #include "ggml-cuda.h"
154
+
155
+ #define CUDA_CHECK(err) \
156
+ do { \
157
+ cudaError_t err_ = (err); \
158
+ if (err_ != cudaSuccess) { \
159
+ printf("CUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__, \
160
+ cudaGetErrorString(err_)); \
161
+ exit(1); \
162
+ } \
163
+ } while (0)
164
+
165
+ #define CUBLAS_CHECK(err) \
166
+ do { \
167
+ cublasStatus_t err_ = (err); \
168
+ if (err_ != CUBLAS_STATUS_SUCCESS) { \
169
+ printf("cuBLAS error %d at %s:%d\n", err_, __FILE__, __LINE__); \
170
+ exit(1); \
171
+ } \
172
+ } while (0)
173
+
174
+ static cublasHandle_t cublasH = NULL;
175
+ static cudaStream_t cudaStream = NULL;
176
+ static void init_cublas(void) {
177
+ if (cublasH == NULL) {
178
+ // create cublas handle, bind a stream
179
+ CUBLAS_CHECK(cublasCreate(&cublasH));
180
+
181
+ CUDA_CHECK(cudaStreamCreateWithFlags(&cudaStream, cudaStreamNonBlocking));
182
+
183
+ CUBLAS_CHECK(cublasSetStream(cublasH, cudaStream));
184
+
185
+ // configure logging to stdout
186
+ // CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, NULL));
187
+ }
188
+ }
140
189
  #endif
141
190
 
142
191
  #undef MIN
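The cuBLAS block above only prepares state: a lazily created handle bound to a non-blocking stream, plus CUDA_CHECK/CUBLAS_CHECK macros that abort on the first failing call. How the handle is consumed lies outside this hunk; the sketch below is a hedged illustration of a single SGEMM through it, reusing the statics and macros defined above (cuBLAS expects column-major operands that already live in device memory):

    // sketch only: cublasH, cudaStream, init_cublas(), CUDA_CHECK and CUBLAS_CHECK
    // are the statics/macros from the hunk above; dA, dB, dC are device pointers
    static void sgemm_via_cublas(const float * dA, const float * dB, float * dC,
                                 int m, int n, int k) {
        init_cublas();

        const float alpha = 1.0f;
        const float beta  = 0.0f;

        // C(m x n) = A(m x k) * B(k x n), all column-major
        CUBLAS_CHECK(cublasSgemm(cublasH, CUBLAS_OP_N, CUBLAS_OP_N,
                                 m, n, k,
                                 &alpha, dA, m, dB, k, &beta, dC, m));

        // the handle is bound to cudaStream, so wait on that stream before reading dC
        CUDA_CHECK(cudaStreamSynchronize(cudaStream));
    }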
@@ -418,14 +467,30 @@ static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
418
467
  // quantization
419
468
  //
420
469
 
421
- #define QK 32
470
+ #if __AVX__ || __AVX2__ || __AVX512F__
471
+ // Unpack 16 4-bit fields into 16 bytes
472
+ // The output vector contains 16 bytes, each one in [ 0 .. 15 ] interval
473
+ static inline __m128i bytes_from_nibbles_16(const uint8_t * rsi)
474
+ {
475
+ // Load 8 bytes from memory
476
+ __m128i tmp = _mm_loadu_si64( ( const __m128i* )rsi );
477
+
478
+ // Expand bytes into uint16_t values
479
+ __m128i bytes = _mm_cvtepu8_epi16( tmp );
480
+
481
+ // Unpack values into individual bytes
482
+ const __m128i lowMask = _mm_set1_epi8( 0xF );
483
+ __m128i high = _mm_andnot_si128( lowMask, bytes );
484
+ __m128i low = _mm_and_si128( lowMask, bytes );
485
+ high = _mm_slli_epi16( high, 4 );
486
+ bytes = _mm_or_si128( low, high );
487
+ return bytes;
488
+ }
422
489
 
423
- // AVX routines provided by GH user Const-me
424
- // ref: https://github.com/ggerganov/ggml/pull/27#issuecomment-1464934600
425
490
  #if __AVX2__ || __AVX512F__
426
491
  // Unpack 32 4-bit fields into 32 bytes
427
492
  // The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
428
- static inline __m256i bytesFromNibbles( const uint8_t* rsi )
493
+ static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)
429
494
  {
430
495
  // Load 16 bytes from memory
431
496
  __m128i tmp = _mm_loadu_si128( ( const __m128i* )rsi );
@@ -456,24 +521,7 @@ static inline __m128i packNibbles( __m256i bytes )
456
521
  __m128i r1 = _mm256_extracti128_si256( bytes, 1 );
457
522
  return _mm_packus_epi16( r0, r1 );
458
523
  }
459
- #elif __AVX__
460
- static inline __m128i bytesFromNibbles( const uint8_t* rsi )
461
- {
462
- // Load 8 bytes from memory
463
- __m128i tmp = _mm_loadu_si64( ( const __m128i* )rsi );
464
-
465
- // Expand bytes into uint16_t values
466
- __m128i bytes = _mm_cvtepu8_epi16( tmp );
467
-
468
- // Unpack values into individual bytes
469
- const __m128i lowMask = _mm_set1_epi8( 0xF );
470
- __m128i high = _mm_andnot_si128( lowMask, bytes );
471
- __m128i low = _mm_and_si128( lowMask, bytes );
472
- high = _mm_slli_epi16( high, 4 );
473
- bytes = _mm_or_si128( low, high );
474
- return bytes;
475
- }
476
-
524
+ #else
477
525
  static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
478
526
  {
479
527
  // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
@@ -490,6 +538,7 @@ static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
490
538
  return _mm_packus_epi16( bytes1, bytes2);
491
539
  }
492
540
  #endif
541
+ #endif // __AVX__ || __AVX2__ || __AVX512F__
493
542
 
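Both unpack helpers above, and packNibbles, share one packing convention that the quantized block formats below rely on: byte j of a block holds element 2j in its low nibble and element 2j+1 in its high nibble. A scalar sketch of the same round trip, without intrinsics:

    #include <stdint.h>

    // pack 2*n 4-bit values (each in 0..15) into n bytes, low nibble first
    static void pack_nibbles(const uint8_t * q, uint8_t * out, int n) {
        for (int j = 0; j < n; j++) {
            out[j] = (uint8_t) ((q[2*j + 0] & 0x0F) | (q[2*j + 1] << 4));
        }
    }

    // inverse of the above; what bytes_from_nibbles_16/32 do 16 or 32 elements at a time
    static void unpack_nibbles(const uint8_t * in, uint8_t * q, int n) {
        for (int j = 0; j < n; j++) {
            q[2*j + 0] = in[j] & 0x0F;
            q[2*j + 1] = in[j] >> 4;
        }
    }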
494
543
  #if __ARM_NEON
495
544
 
@@ -507,6 +556,18 @@ inline static uint16_t vaddvq_u8(uint8x16_t v) {
507
556
  (uint16_t)vgetq_lane_u8(v, 14) + (uint16_t)vgetq_lane_u8(v, 15);
508
557
  }
509
558
 
559
+ inline static int16_t vaddvq_s8(int8x16_t v) {
560
+ return
561
+ (int16_t)vgetq_lane_s8(v, 0) + (int16_t)vgetq_lane_s8(v, 1) +
562
+ (int16_t)vgetq_lane_s8(v, 2) + (int16_t)vgetq_lane_s8(v, 3) +
563
+ (int16_t)vgetq_lane_s8(v, 4) + (int16_t)vgetq_lane_s8(v, 5) +
564
+ (int16_t)vgetq_lane_s8(v, 6) + (int16_t)vgetq_lane_s8(v, 7) +
565
+ (int16_t)vgetq_lane_s8(v, 8) + (int16_t)vgetq_lane_s8(v, 9) +
566
+ (int16_t)vgetq_lane_s8(v, 10) + (int16_t)vgetq_lane_s8(v, 11) +
567
+ (int16_t)vgetq_lane_s8(v, 12) + (int16_t)vgetq_lane_s8(v, 13) +
568
+ (int16_t)vgetq_lane_s8(v, 14) + (int16_t)vgetq_lane_s8(v, 15);
569
+ }
570
+
510
571
  inline static int32_t vaddvq_s16(int16x8_t v) {
511
572
  return
512
573
  (int32_t)vgetq_lane_s16(v, 0) + (int32_t)vgetq_lane_s16(v, 1) +
@@ -531,68 +592,88 @@ inline static float vaddvq_f32(float32x4_t v) {
531
592
  return vgetq_lane_f32(v, 0) + vgetq_lane_f32(v, 1) + vgetq_lane_f32(v, 2) + vgetq_lane_f32(v, 3);
532
593
  }
533
594
 
534
- inline float vminvq_f32(float32x4_t v) {
595
+ float vminvq_f32(float32x4_t v) {
535
596
  return
536
597
  MIN(MIN(vgetq_lane_f32(v, 0), vgetq_lane_f32(v, 1)),
537
598
  MIN(vgetq_lane_f32(v, 2), vgetq_lane_f32(v, 3)));
538
599
  }
539
600
 
540
- inline float vmaxvq_f32(float32x4_t v) {
601
+ float vmaxvq_f32(float32x4_t v) {
541
602
  return
542
603
  MAX(MAX(vgetq_lane_f32(v, 0), vgetq_lane_f32(v, 1)),
543
604
  MAX(vgetq_lane_f32(v, 2), vgetq_lane_f32(v, 3)));
544
605
  }
545
606
 
546
- inline int8x8_t vzip1_s8(int8x8_t a, int8x8_t b) {
607
+ int8x8_t vzip1_s8(int8x8_t a, int8x8_t b) {
547
608
  return vget_low_s8(vcombine_s8(a, b));
548
609
  }
549
610
 
550
- inline int8x8_t vzip2_s8(int8x8_t a, int8x8_t b) {
611
+ int8x8_t vzip2_s8(int8x8_t a, int8x8_t b) {
551
612
  return vget_high_s8(vcombine_s8(a, b));
552
613
  }
553
614
 
554
- inline uint8x8_t vzip1_u8(uint8x8_t a, uint8x8_t b) {
615
+ uint8x8_t vzip1_u8(uint8x8_t a, uint8x8_t b) {
555
616
  return vget_low_u8(vcombine_u8(a, b));
556
617
  }
557
618
 
558
- inline uint8x8_t vzip2_u8(uint8x8_t a, uint8x8_t b) {
619
+ uint8x8_t vzip2_u8(uint8x8_t a, uint8x8_t b) {
559
620
  return vget_high_u8(vcombine_u8(a, b));
560
621
  }
561
622
 
562
623
  #endif
563
624
  #endif
564
625
 
565
- // method 5
566
- // blocks of QK elements
567
- // represented with a single float (delta) and QK/2 8-bit ints (i.e QK 4-bit signed integer factors)
626
+
627
+ #define QK4_0 32
568
628
  typedef struct {
569
- float d; // delta
570
- uint8_t qs[QK / 2]; // nibbles / quants
629
+ float d; // delta
630
+ uint8_t qs[QK4_0 / 2]; // nibbles / quants
571
631
  } block_q4_0;
572
- static_assert(sizeof(block_q4_0) == sizeof(float) + QK / 2, "wrong q4_0 block size/padding");
632
+ static_assert(sizeof(block_q4_0) == sizeof(float) + QK4_0 / 2, "wrong q4_0 block size/padding");
573
633
 
574
- // method 4
575
- // blocks of QK elements
576
- // represented with 2 floats (delta + min) and QK/2 8-bit ints (i.e QK 4-bit unsigned integer factors)
634
+ #define QK4_1 32
577
635
  typedef struct {
578
- float d;
579
- float m;
580
- uint8_t qs[QK / 2]; // nibbles / quants
636
+ float d; // delta
637
+ float m; // min
638
+ uint8_t qs[QK4_1 / 2]; // nibbles / quants
581
639
  } block_q4_1;
582
- static_assert(sizeof(block_q4_1) == sizeof(float) * 2 + QK / 2, "wrong q4_1 block size/padding");
640
+ static_assert(sizeof(block_q4_1) == 2 * sizeof(float) + QK4_1 / 2, "wrong q4_1 block size/padding");
641
+
642
+ #define QK4_2 16
643
+ typedef struct {
644
+ ggml_fp16_t d; // delta
645
+ uint8_t qs[QK4_2 / 2]; // nibbles / quants
646
+ } block_q4_2;
647
+ static_assert(sizeof(block_q4_2) == sizeof(ggml_fp16_t) + QK4_2 / 2, "wrong q4_2 block size/padding");
648
+
649
+ #define QK4_3 16
650
+ typedef struct {
651
+ ggml_fp16_t d; // delta
652
+ ggml_fp16_t m; // min
653
+ uint8_t qs[QK4_3 / 2]; // nibbles / quants
654
+ } block_q4_3;
655
+ static_assert(sizeof(block_q4_3) == 2 * sizeof(ggml_fp16_t) + QK4_3 / 2, "wrong q4_3 block size/padding");
656
+
657
+ #define QK8_0 32
658
+ typedef struct {
659
+ float d; // delta
660
+ int8_t qs[QK8_0]; // quants
661
+ } block_q8_0;
662
+ static_assert(sizeof(block_q8_0) == sizeof(float) + QK8_0, "wrong q8_0 block size/padding");
663
+
583
664
 
584
665
  // reference implementation for deterministic creation of model files
585
666
  static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k) {
586
- assert(k % QK == 0);
587
- const int nb = k / QK;
667
+ assert(k % QK4_0 == 0);
668
+ const int nb = k / QK4_0;
588
669
 
589
- uint8_t pp[QK/2];
670
+ uint8_t pp[QK4_0/2];
590
671
 
591
672
  for (int i = 0; i < nb; i++) {
592
673
  float amax = 0.0f; // absolute max
593
674
 
594
- for (int l = 0; l < QK; l++) {
595
- const float v = x[i*QK + l];
675
+ for (int l = 0; l < QK4_0; l++) {
676
+ const float v = x[i*QK4_0 + l];
596
677
  amax = MAX(amax, fabsf(v));
597
678
  }
598
679
 
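The static_asserts above pin down the storage cost of each new block type: q4_0 packs 32 weights into 4 + 16 = 20 bytes (5 bits per weight), q4_1 into 24 bytes, q4_2 packs 16 weights into 10 bytes thanks to its fp16 scale, q4_3 into 12 bytes, and q8_0 keeps a full signed byte per weight plus a float scale (36 bytes per 32). A scalar sketch of the q4_0 round trip; the amax/7 scale is an assumption here, since the line that computes d falls outside the hunks shown:

    #include <math.h>

    #define QK4_0 32

    // quantize one q4_0 block and immediately dequantize it (scalar sketch)
    static void q4_0_roundtrip(const float * x, float * y) {
        float amax = 0.0f;   // absolute max of the block
        for (int l = 0; l < QK4_0; l++) amax = fmaxf(amax, fabsf(x[l]));

        const float d  = amax / ((1 << 3) - 1);   // assumed scale: amax / 7
        const float id = d ? 1.0f/d : 0.0f;

        for (int l = 0; l < QK4_0; l++) {
            const int q = (int) roundf(x[l]*id) + 8;   // 4-bit code in 0..15
            y[l] = (q - 8)*d;                          // matches dequantize_row_q4_0 below
        }
    }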
@@ -601,9 +682,9 @@ static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * r
601
682
 
602
683
  y[i].d = d;
603
684
 
604
- for (int l = 0; l < QK; l += 2) {
605
- const float v0 = x[i*QK + l + 0]*id;
606
- const float v1 = x[i*QK + l + 1]*id;
685
+ for (int l = 0; l < QK4_0; l += 2) {
686
+ const float v0 = x[i*QK4_0 + l + 0]*id;
687
+ const float v1 = x[i*QK4_0 + l + 1]*id;
607
688
 
608
689
  const uint8_t vi0 = (int8_t)roundf(v0) + 8;
609
690
  const uint8_t vi1 = (int8_t)roundf(v1) + 8;
@@ -619,8 +700,8 @@ static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * r
619
700
  }
620
701
 
621
702
  static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int k) {
622
- assert(k % QK == 0);
623
- const int nb = k / QK;
703
+ assert(k % QK4_0 == 0);
704
+ const int nb = k / QK4_0;
624
705
 
625
706
  block_q4_0 * restrict y = vy;
626
707
 
@@ -870,19 +951,19 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
870
951
  }
871
952
 
872
953
  static void quantize_row_q4_1_reference(const float * restrict x, void * restrict vy, int k) {
873
- assert(k % QK == 0);
874
- const int nb = k / QK;
954
+ assert(k % QK4_1 == 0);
955
+ const int nb = k / QK4_1;
875
956
 
876
957
  block_q4_1 * restrict y = vy;
877
958
 
878
- uint8_t pp[QK/2];
959
+ uint8_t pp[QK4_1/2];
879
960
 
880
961
  for (int i = 0; i < nb; i++) {
881
962
  float min = FLT_MAX;
882
963
  float max = -FLT_MAX;
883
964
 
884
- for (int l = 0; l < QK; l++) {
885
- const float v = x[i*QK + l];
965
+ for (int l = 0; l < QK4_1; l++) {
966
+ const float v = x[i*QK4_1 + l];
886
967
  if (v < min) min = v;
887
968
  if (v > max) max = v;
888
969
  }
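q4_1 is the affine variant of the same idea: the block stores a delta d and the block minimum m, quantizes q = round((x - m)/d), and reconstructs x ≈ q·d + m, which is what dequantize_row_q4_1 further down computes. The (max - min)/15 delta is assumed here by analogy with the q4_3 reference above; the q4_1 line that computes d sits outside the hunks shown. A scalar sketch of one block:

    #include <float.h>
    #include <math.h>

    // quantize one q4_1-style block (delta + min) and immediately dequantize it
    static void q4_1_roundtrip(const float * x, float * y, int qk) {
        float min = FLT_MAX, max = -FLT_MAX;
        for (int l = 0; l < qk; l++) {
            if (x[l] < min) min = x[l];
            if (x[l] > max) max = x[l];
        }
        const float d  = (max - min) / 15.0f;   // assumed delta
        const float id = d ? 1.0f/d : 0.0f;
        for (int l = 0; l < qk; l++) {
            const int q = (int) roundf((x[l] - min)*id);   // 4-bit code in 0..15
            y[l] = q*d + min;
        }
    }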
@@ -893,9 +974,9 @@ static void quantize_row_q4_1_reference(const float * restrict x, void * restric
893
974
  y[i].d = d;
894
975
  y[i].m = min;
895
976
 
896
- for (int l = 0; l < QK; l += 2) {
897
- const float v0 = (x[i*QK + l + 0] - min)*id;
898
- const float v1 = (x[i*QK + l + 1] - min)*id;
977
+ for (int l = 0; l < QK4_1; l += 2) {
978
+ const float v0 = (x[i*QK4_1 + l + 0] - min)*id;
979
+ const float v1 = (x[i*QK4_1 + l + 1] - min)*id;
899
980
 
900
981
  const uint8_t vi0 = roundf(v0);
901
982
  const uint8_t vi1 = roundf(v1);
@@ -911,9 +992,9 @@ static void quantize_row_q4_1_reference(const float * restrict x, void * restric
911
992
  }
912
993
 
913
994
  static void quantize_row_q4_1(const float * restrict x, void * restrict vy, int k) {
914
- assert(k % QK == 0);
995
+ assert(k % QK4_1 == 0);
915
996
 
916
- const int nb = k / QK;
997
+ const int nb = k / QK4_1;
917
998
 
918
999
  block_q4_1 * restrict y = vy;
919
1000
 
@@ -997,7 +1078,7 @@ static void quantize_row_q4_1(const float * restrict x, void * restrict vy, int
997
1078
  float32x4_t minv[8];
998
1079
  float32x4_t maxv[8];
999
1080
 
1000
- for (int l = 0; l < 8; l++) srcv[l] = vld1q_f32(x + i*QK + 4*l);
1081
+ for (int l = 0; l < 8; l++) srcv[l] = vld1q_f32(x + i*QK4_1 + 4*l);
1001
1082
 
1002
1083
  for (int l = 0; l < 4; l++) minv[2*l] = vminq_f32(srcv[2*l], srcv[2*l + 1]);
1003
1084
  for (int l = 0; l < 2; l++) minv[4*l] = vminq_f32(minv[4*l], minv[4*l + 2]);
@@ -1033,9 +1114,327 @@ static void quantize_row_q4_1(const float * restrict x, void * restrict vy, int
1033
1114
  #endif
1034
1115
  }
1035
1116
 
1117
+ // reference implementation for deterministic creation of model files
1118
+ static void quantize_row_q4_2_reference(const float * restrict x, block_q4_2 * restrict y, int k) {
1119
+ assert(k % QK4_2 == 0);
1120
+
1121
+ const int nb = k / QK4_2;
1122
+
1123
+ for (int i = 0; i < nb; i++) {
1124
+ float amax = 0.0f; // absolute max
1125
+
1126
+ for (int l = 0; l < QK4_2; l++) {
1127
+ const float v = x[i*QK4_2 + l];
1128
+ amax = MAX(amax, fabsf(v));
1129
+ }
1130
+
1131
+ const float d = amax / ((1 << 3) - 1);
1132
+
1133
+ const float id = d ? 1.0f/d : 0.0f;
1134
+
1135
+ y[i].d = GGML_FP32_TO_FP16(d);
1136
+
1137
+ for (int l = 0; l < QK4_2; l += 2) {
1138
+ const float v0 = x[i*QK4_2 + l + 0]*id;
1139
+ const float v1 = x[i*QK4_2 + l + 1]*id;
1140
+
1141
+ const uint8_t vi0 = (uint8_t)(v0 + 8.5f);
1142
+ const uint8_t vi1 = (uint8_t)(v1 + 8.5f);
1143
+
1144
+ assert(vi0 < 16);
1145
+ assert(vi1 < 16);
1146
+
1147
+ y[i].qs[l/2] = vi0 | (vi1 << 4);
1148
+ }
1149
+ }
1150
+ }
1151
+
1152
+ static inline int nearest_int(float fval) {
1153
+ assert(fval <= 4194303.f);
1154
+ float val = fval + 12582912.f;
1155
+ int i; memcpy(&i, &val, sizeof(int));
1156
+ return (i & 0x007fffff) - 0x00400000;
1157
+ }
1158
+
1159
+ static float kquantize_q4_with_bounds(int n, int nmin, int nmax, const float * restrict X, int nCandidates,
1160
+ const float * restrict candidates, int8_t * restrict L) {
1161
+ assert (nmin >= INT8_MIN);
1162
+ assert (nmax <= INT8_MAX);
1163
+ float amax = 0;
1164
+ for (int i=0; i<n; ++i) amax = MAX(amax, fabsf(X[i]));
1165
+ if (!amax) { // all zero
1166
+ for (int i=0; i<n; ++i) L[i] = 0;
1167
+ return 1.f;
1168
+ }
1169
+ float best = 0, bestScale = 0;
1170
+ for (int si=0; si<nCandidates; ++si) {
1171
+ float iscale = candidates[si]/amax;
1172
+ float sumlxP = 0; int suml2P = 0;
1173
+ float sumlxM = 0; int suml2M = 0;
1174
+ for (int i=0; i<n; ++i) {
1175
+ int l = nearest_int(iscale*X[i]);
1176
+ int lp = MAX(nmin, MIN(nmax, +l));
1177
+ int lm = MAX(nmin, MIN(nmax, -l));
1178
+ sumlxP += X[i]*lp; suml2P += lp*lp;
1179
+ sumlxM += X[i]*lm; suml2M += lm*lm;
1180
+ }
1181
+ float sumlxP2 = sumlxP*sumlxP;
1182
+ float sumlxM2 = sumlxM*sumlxM;
1183
+ if (sumlxP2*suml2M > sumlxM2*suml2P) {
1184
+ if (sumlxP2 > best*suml2P) {
1185
+ best = sumlxP2/suml2P; bestScale = iscale;
1186
+ }
1187
+ } else {
1188
+ if (sumlxM2 > best*suml2M) {
1189
+ best = sumlxM2/suml2M; bestScale = -iscale;
1190
+ }
1191
+ }
1192
+ }
1193
+ float sumlx = 0; int suml2 = 0;
1194
+ for (int i=0; i<n; ++i) {
1195
+ int l = nearest_int(bestScale*X[i]);
1196
+ l = MAX(nmin, MIN(nmax, l));
1197
+ sumlx += X[i]*l; suml2 += l*l;
1198
+ L[i] = l;
1199
+ }
1200
+ float scale = sumlx/suml2;
1201
+ return scale;
1202
+ }
1203
+
1204
+ static void quantize_row_q4_2_rmse(const float * restrict x, block_q4_2 * restrict y, int k) {
1205
+ #define CANDIDATE_COUNT 8
1206
+ static const float candidates[CANDIDATE_COUNT] = { +8.7f, +8.3f, +8.1f, +7.8f, +7.3f, +7.0f, +6.3f, +5.7f };
1207
+ assert(k % QK4_2 == 0);
1208
+
1209
+ int8_t L[QK4_2];
1210
+
1211
+ const int nb = k / QK4_2;
1212
+
1213
+ for (int i = 0; i < nb; i++) {
1214
+ float scale = kquantize_q4_with_bounds(QK4_2, -8, 7, x, CANDIDATE_COUNT, candidates, L);
1215
+ y[i].d = GGML_FP32_TO_FP16(scale);
1216
+
1217
+ for (int l = 0; l < QK4_2; l += 2) {
1218
+ const uint8_t vi0 = (uint8_t)(L[l+0] + 8);
1219
+ const uint8_t vi1 = (uint8_t)(L[l+1] + 8);
1220
+
1221
+ assert(vi0 < 16);
1222
+ assert(vi1 < 16);
1223
+
1224
+ y[i].qs[l/2] = vi0 | (vi1 << 4);
1225
+ }
1226
+
1227
+ x += QK4_2;
1228
+ }
1229
+ }
1230
+
1231
+ static void quantize_row_q4_2(const float * restrict x, void * restrict vy, int k) {
1232
+ assert(k % QK4_2 == 0);
1233
+
1234
+ block_q4_2 * restrict y = vy;
1235
+
1236
+ //quantize_row_q4_2_reference(x, y, k);
1237
+ // This produces the exact same format, just a better match to the input floats ("better" as measured by RMSE)
1238
+ quantize_row_q4_2_rmse(x, y, k);
1239
+ }
1240
+
1241
+ static void quantize_row_q4_3_reference(const float * restrict x, block_q4_3 * restrict y, int k) {
1242
+ assert(k % QK4_3 == 0);
1243
+ const int nb = k / QK4_3;
1244
+
1245
+ for (int i = 0; i < nb; i++) {
1246
+ float min = FLT_MAX;
1247
+ float max = -FLT_MAX;
1248
+
1249
+ for (int l = 0; l < QK4_3; l++) {
1250
+ const float v = x[i*QK4_3 + l];
1251
+ if (v < min) min = v;
1252
+ if (v > max) max = v;
1253
+ }
1254
+
1255
+ const float d = (max - min) / ((1 << 4) - 1);
1256
+ const float id = d ? 1.0f/d : 0.0f;
1257
+
1258
+ y[i].d = GGML_FP32_TO_FP16(d);
1259
+ y[i].m = GGML_FP32_TO_FP16(min);
1260
+
1261
+ for (int l = 0; l < QK4_3; l += 2) {
1262
+ const float v0 = (x[i*QK4_3 + l + 0] - min)*id;
1263
+ const float v1 = (x[i*QK4_3 + l + 1] - min)*id;
1264
+
1265
+ const uint8_t vi0 = (int) (v0 + 0.5f);
1266
+ const uint8_t vi1 = (int) (v1 + 0.5f);
1267
+
1268
+ assert(vi0 < 16);
1269
+ assert(vi1 < 16);
1270
+
1271
+ y[i].qs[l/2] = vi0 | (vi1 << 4);
1272
+ }
1273
+ }
1274
+ }
1275
+
1276
+ static void quantize_row_q4_3(const float * restrict x, void * restrict vy, int k) {
1277
+ assert(k % QK4_3 == 0);
1278
+
1279
+ block_q4_3 * restrict y = vy;
1280
+
1281
+ quantize_row_q4_3_reference(x, y, k);
1282
+ }
1283
+
1284
+ // reference implementation for deterministic creation of model files
1285
+ static void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * restrict y, int k) {
1286
+ assert(k % QK8_0 == 0);
1287
+ const int nb = k / QK8_0;
1288
+
1289
+ for (int i = 0; i < nb; i++) {
1290
+ float amax = 0.0f; // absolute max
1291
+
1292
+ for (int l = 0; l < QK8_0; l++) {
1293
+ const float v = x[i*QK8_0 + l];
1294
+ amax = MAX(amax, fabsf(v));
1295
+ }
1296
+
1297
+ const float d = amax / ((1 << 7) - 1);
1298
+ const float id = d ? 1.0f/d : 0.0f;
1299
+
1300
+ y[i].d = d;
1301
+
1302
+ for (int l = 0; l < QK8_0; ++l) {
1303
+ const float v = x[i*QK8_0 + l]*id;
1304
+ y[i].qs[l] = roundf(v);
1305
+ }
1306
+ }
1307
+ }
1308
+
1309
+ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int k) {
1310
+ assert(k % QK8_0 == 0);
1311
+ const int nb = k / QK8_0;
1312
+
1313
+ block_q8_0 * restrict y = vy;
1314
+
1315
+ #if defined(__ARM_NEON)
1316
+ for (int i = 0; i < nb; i++) {
1317
+ float32x4_t srcv [8];
1318
+ float32x4_t asrcv[8];
1319
+ float32x4_t amaxv[8];
1320
+
1321
+ for (int l = 0; l < 8; l++) srcv[l] = vld1q_f32(x + i*32 + 4*l);
1322
+ for (int l = 0; l < 8; l++) asrcv[l] = vabsq_f32(srcv[l]);
1323
+
1324
+ for (int l = 0; l < 4; l++) amaxv[2*l] = vmaxq_f32(asrcv[2*l], asrcv[2*l+1]);
1325
+ for (int l = 0; l < 2; l++) amaxv[4*l] = vmaxq_f32(amaxv[4*l], amaxv[4*l+2]);
1326
+ for (int l = 0; l < 1; l++) amaxv[8*l] = vmaxq_f32(amaxv[8*l], amaxv[8*l+4]);
1327
+
1328
+ const float amax = vmaxvq_f32(amaxv[0]);
1329
+
1330
+ const float d = amax / ((1 << 7) - 1);
1331
+ const float id = d ? 1.0f/d : 0.0f;
1332
+
1333
+ y[i].d = d;
1334
+
1335
+ for (int l = 0; l < 8; l++) {
1336
+ const float32x4_t v = vmulq_n_f32(srcv[l], id);
1337
+ const int32x4_t vi = vcvtnq_s32_f32(v);
1338
+
1339
+ y[i].qs[4*l + 0] = vgetq_lane_s32(vi, 0);
1340
+ y[i].qs[4*l + 1] = vgetq_lane_s32(vi, 1);
1341
+ y[i].qs[4*l + 2] = vgetq_lane_s32(vi, 2);
1342
+ y[i].qs[4*l + 3] = vgetq_lane_s32(vi, 3);
1343
+ }
1344
+ }
1345
+ #elif defined(__AVX2__) || defined(__AVX__)
1346
+ for (int i = 0; i < nb; i++) {
1347
+ // Load elements into 4 AVX vectors
1348
+ __m256 v0 = _mm256_loadu_ps( x );
1349
+ __m256 v1 = _mm256_loadu_ps( x + 8 );
1350
+ __m256 v2 = _mm256_loadu_ps( x + 16 );
1351
+ __m256 v3 = _mm256_loadu_ps( x + 24 );
1352
+ x += 32;
1353
+
1354
+ // Compute max(abs(e)) for the block
1355
+ const __m256 signBit = _mm256_set1_ps( -0.0f );
1356
+ __m256 maxAbs = _mm256_andnot_ps( signBit, v0 );
1357
+ maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) );
1358
+ maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );
1359
+ maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) );
1360
+
1361
+ __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
1362
+ max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
1363
+ max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
1364
+ const float maxScalar = _mm_cvtss_f32( max4 );
1365
+
1366
+ // Quantize these floats
1367
+ const float d = maxScalar / 127.f;
1368
+ y[i].d = d;
1369
+ const float id = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f;
1370
+ const __m256 mul = _mm256_set1_ps( id );
1371
+
1372
+ // Apply the multiplier
1373
+ v0 = _mm256_mul_ps( v0, mul );
1374
+ v1 = _mm256_mul_ps( v1, mul );
1375
+ v2 = _mm256_mul_ps( v2, mul );
1376
+ v3 = _mm256_mul_ps( v3, mul );
1377
+
1378
+ // Round to nearest integer
1379
+ v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST );
1380
+ v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST );
1381
+ v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST );
1382
+ v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST );
1383
+
1384
+ // Convert floats to integers
1385
+ __m256i i0 = _mm256_cvtps_epi32( v0 );
1386
+ __m256i i1 = _mm256_cvtps_epi32( v1 );
1387
+ __m256i i2 = _mm256_cvtps_epi32( v2 );
1388
+ __m256i i3 = _mm256_cvtps_epi32( v3 );
1389
+
1390
+ #if defined(__AVX2__)
1391
+ // Convert int32 to int16
1392
+ i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15
1393
+ i2 = _mm256_packs_epi32( i2, i3 ); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31
1394
+ // Convert int16 to int8
1395
+ i0 = _mm256_packs_epi16( i0, i2 ); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31
1396
+
1397
+ // We got our precious signed bytes, but the order is now wrong
1398
+ // These AVX2 pack instructions process 16-byte pieces independently
1399
+ // The following instruction is fixing the order
1400
+ const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
1401
+ i0 = _mm256_permutevar8x32_epi32( i0, perm );
1402
+
1403
+ _mm256_storeu_si256((__m256i *)y[i].qs, i0);
1404
+ #else
1405
+ // Since we don't have in AVX some necessary functions,
1406
+ // we split the registers in half and call AVX2 analogs from SSE
1407
+ __m128i ni0 = _mm256_castsi256_si128( i0 );
1408
+ __m128i ni1 = _mm256_extractf128_si256( i0, 1);
1409
+ __m128i ni2 = _mm256_castsi256_si128( i1 );
1410
+ __m128i ni3 = _mm256_extractf128_si256( i1, 1);
1411
+ __m128i ni4 = _mm256_castsi256_si128( i2 );
1412
+ __m128i ni5 = _mm256_extractf128_si256( i2, 1);
1413
+ __m128i ni6 = _mm256_castsi256_si128( i3 );
1414
+ __m128i ni7 = _mm256_extractf128_si256( i3, 1);
1415
+
1416
+ // Convert int32 to int16
1417
+ ni0 = _mm_packs_epi32( ni0, ni1 );
1418
+ ni2 = _mm_packs_epi32( ni2, ni3 );
1419
+ ni4 = _mm_packs_epi32( ni4, ni5 );
1420
+ ni6 = _mm_packs_epi32( ni6, ni7 );
1421
+ // Convert int16 to int8
1422
+ ni0 = _mm_packs_epi16( ni0, ni2 );
1423
+ ni4 = _mm_packs_epi16( ni4, ni6 );
1424
+
1425
+ _mm_storeu_si128((__m128i *)(y[i].qs + 0), ni0);
1426
+ _mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4);
1427
+ #endif
1428
+ }
1429
+ #else
1430
+ // scalar
1431
+ quantize_row_q8_0_reference(x, y, k);
1432
+ #endif
1433
+ }
1434
+
1036
1435
  static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, int k) {
1037
- assert(k % QK == 0);
1038
- const int nb = k / QK;
1436
+ assert(k % QK4_0 == 0);
1437
+ const int nb = k / QK4_0;
1039
1438
 
1040
1439
  const block_q4_0 * restrict x = vx;
1041
1440
 
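Two parts of the large block added above are worth spelling out. nearest_int() rounds by exploiting the binary32 format: adding 1.5·2^23 (12582912.0f) to any |x| <= 2^22 lands in a range where one unit in the last place equals 1, so the hardware's round-to-nearest leaves the rounded integer in the low mantissa bits, which the mask and subtract then recover. quantize_row_q4_2_rmse() uses it to pick the per-block scale not from amax alone but from eight candidate scales, keeping the one that maximizes (Σ x·l)²/Σ l², i.e. minimizes the squared reconstruction error. A small standalone check of the rounding trick (nearest_int reproduced from the hunk above, minus the range assert):

    #include <math.h>
    #include <stdio.h>
    #include <string.h>

    static inline int nearest_int(float fval) {
        float val = fval + 12582912.f;       // 1.5 * 2^23
        int i; memcpy(&i, &val, sizeof(int));
        return (i & 0x007fffff) - 0x00400000;
    }

    int main(void) {
        const float xs[] = { -3.7f, -0.4f, 0.4f, 2.6f, 1000.25f };
        for (int j = 0; j < 5; j++) {
            // agrees with roundf() away from ties; exact .5 cases round to even instead
            printf("% 9.2f -> %d (roundf: %d)\n", xs[j], nearest_int(xs[j]), (int) roundf(xs[j]));
        }
        return 0;
    }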
@@ -1046,9 +1445,9 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in
1046
1445
 
1047
1446
  const uint8_t * restrict pp = x[i].qs;
1048
1447
 
1049
- for (int l = 0; l < QK; l += 32) {
1448
+ for (int l = 0; l < QK4_0; l += 32) {
1050
1449
  // Load 32x4-bit integers into 32x8-bit integers
1051
- __m256i vx8 = bytesFromNibbles(pp+l/2);
1450
+ __m256i vx8 = bytes_from_nibbles_32(pp+l/2);
1052
1451
 
1053
1452
  // Subtract 8 from the integers
1054
1453
  vx8 = _mm256_sub_epi8(vx8, _mm256_set1_epi8(8));
@@ -1068,7 +1467,7 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in
1068
1467
  // Scale and store
1069
1468
  for (int j = 0; j < 4; j++) {
1070
1469
  const __m256 result = _mm256_mul_ps(vf[j], d_v);
1071
- _mm256_storeu_ps(y + i * QK + l + j*8, result);
1470
+ _mm256_storeu_ps(y + i * QK4_0 + l + j*8, result);
1072
1471
  }
1073
1472
  }
1074
1473
  }
@@ -1078,7 +1477,7 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in
1078
1477
 
1079
1478
  const uint8_t * restrict pp = x[i].qs;
1080
1479
 
1081
- for (int l = 0; l < QK; l += 16) {
1480
+ for (int l = 0; l < QK4_0; l += 16) {
1082
1481
  // Load 16x4-bit integers into 8x8-bit integers
1083
1482
  const uint8x8_t v8 = vld1_u8(pp + l/2);
1084
1483
 
@@ -1117,10 +1516,10 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in
1117
1516
  const float32x4_t r3 = vmulq_f32(vf_3, vd);
1118
1517
 
1119
1518
  // Store
1120
- vst1q_f32(y + i*QK + l + 0, r0);
1121
- vst1q_f32(y + i*QK + l + 4, r1);
1122
- vst1q_f32(y + i*QK + l + 8, r2);
1123
- vst1q_f32(y + i*QK + l + 12, r3);
1519
+ vst1q_f32(y + i*QK4_0 + l + 0, r0);
1520
+ vst1q_f32(y + i*QK4_0 + l + 4, r1);
1521
+ vst1q_f32(y + i*QK4_0 + l + 8, r2);
1522
+ vst1q_f32(y + i*QK4_0 + l + 12, r3);
1124
1523
  }
1125
1524
  }
1126
1525
  #else
@@ -1130,7 +1529,7 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in
1130
1529
 
1131
1530
  const uint8_t * restrict pp = x[i].qs;
1132
1531
 
1133
- for (int l = 0; l < QK; l += 2) {
1532
+ for (int l = 0; l < QK4_0; l += 2) {
1134
1533
  const uint8_t vi = pp[l/2];
1135
1534
 
1136
1535
  const int8_t vi0 = vi & 0xf;
@@ -1141,19 +1540,19 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in
1141
1540
 
1142
1541
  //printf("d = %f, vi = %d, vi0 = %d, vi1 = %d, v0 = %f, v1 = %f\n", d, vi, vi0, vi1, v0, v1);
1143
1542
 
1144
- y[i*QK + l + 0] = v0;
1145
- y[i*QK + l + 1] = v1;
1543
+ y[i*QK4_0 + l + 0] = v0;
1544
+ y[i*QK4_0 + l + 1] = v1;
1146
1545
 
1147
- assert(!isnan(y[i*QK + l + 0]));
1148
- assert(!isnan(y[i*QK + l + 1]));
1546
+ assert(!isnan(y[i*QK4_0 + l + 0]));
1547
+ assert(!isnan(y[i*QK4_0 + l + 1]));
1149
1548
  }
1150
1549
  }
1151
1550
  #endif
1152
1551
  }
1153
1552
 
1154
1553
  static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, int k) {
1155
- assert(k % QK == 0);
1156
- const int nb = k / QK;
1554
+ assert(k % QK4_1 == 0);
1555
+ const int nb = k / QK4_1;
1157
1556
 
1158
1557
  const block_q4_1 * restrict x = vx;
1159
1558
 
@@ -1164,9 +1563,9 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
1164
1563
 
1165
1564
  const uint8_t * restrict pp = x[i].qs;
1166
1565
 
1167
- for (int l = 0; l < QK; l += 32) {
1566
+ for (int l = 0; l < QK4_1; l += 32) {
1168
1567
  // Load 32x4-bit integers into 32x8-bit integers
1169
- __m256i vx8 = bytesFromNibbles(pp+l/2);
1568
+ __m256i vx8 = bytes_from_nibbles_32(pp+l/2);
1170
1569
 
1171
1570
  // Convert to 16-bit int
1172
1571
  const __m256i vx16_lo = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vx8, 0));
@@ -1183,7 +1582,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
1183
1582
  // Scale, add m and store
1184
1583
  for (int j = 0; j < 4; j++) {
1185
1584
  const __m256 result = _mm256_add_ps(_mm256_mul_ps(vf[j], d_v), d_m);
1186
- _mm256_storeu_ps(y + i * QK + l + j*8, result);
1585
+ _mm256_storeu_ps(y + i * QK4_1 + l + j*8, result);
1187
1586
  }
1188
1587
  }
1189
1588
  }
@@ -1194,7 +1593,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
1194
1593
 
1195
1594
  const uint8_t * restrict pp = x[i].qs;
1196
1595
 
1197
- for (int l = 0; l < QK; l += 16) {
1596
+ for (int l = 0; l < QK4_1; l += 16) {
1198
1597
  // Load 16x4-bit integers into 8x8-bit integers
1199
1598
  const uint8x8_t v8 = vld1_u8(pp + l/2);
1200
1599
 
@@ -1225,10 +1624,10 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
1225
1624
  const float32x4_t r3 = vmlaq_f32(vm, vf_3, vd);
1226
1625
 
1227
1626
  // Store
1228
- vst1q_f32(y + i*QK + l + 0, r0);
1229
- vst1q_f32(y + i*QK + l + 4, r1);
1230
- vst1q_f32(y + i*QK + l + 8, r2);
1231
- vst1q_f32(y + i*QK + l + 12, r3);
1627
+ vst1q_f32(y + i*QK4_1 + l + 0, r0);
1628
+ vst1q_f32(y + i*QK4_1 + l + 4, r1);
1629
+ vst1q_f32(y + i*QK4_1 + l + 8, r2);
1630
+ vst1q_f32(y + i*QK4_1 + l + 12, r3);
1232
1631
  }
1233
1632
  }
1234
1633
  #else
@@ -1238,7 +1637,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
1238
1637
 
1239
1638
  const uint8_t * restrict pp = x[i].qs;
1240
1639
 
1241
- for (int l = 0; l < QK; l += 2) {
1640
+ for (int l = 0; l < QK4_1; l += 2) {
1242
1641
  const uint8_t vi = pp[l/2];
1243
1642
 
1244
1643
  const int8_t vi0 = vi & 0xf;
@@ -1247,21 +1646,130 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
1247
1646
  const float v0 = vi0*d + m;
1248
1647
  const float v1 = vi1*d + m;
1249
1648
 
1250
- y[i*QK + l + 0] = v0;
1251
- y[i*QK + l + 1] = v1;
1649
+ y[i*QK4_1 + l + 0] = v0;
1650
+ y[i*QK4_1 + l + 1] = v1;
1252
1651
 
1253
- assert(!isnan(y[i*QK + l + 0]));
1254
- assert(!isnan(y[i*QK + l + 1]));
1652
+ assert(!isnan(y[i*QK4_1 + l + 0]));
1653
+ assert(!isnan(y[i*QK4_1 + l + 1]));
1255
1654
  }
1256
1655
  }
1257
1656
  #endif
1258
1657
  }
1259
1658
 
1260
- //
1261
- // simd mappings
1262
- //
1659
+ static void dequantize_row_q4_2(const void * restrict vx, float * restrict y, int k) {
1660
+ assert(k % QK4_2 == 0);
1661
+ const int nb = k / QK4_2;
1263
1662
 
1264
- // we define a common set of C macros which map to specific intrinsics based on the current architecture
1663
+ const block_q4_2 * restrict x = vx;
1664
+
1665
+ for (int i = 0; i < nb; i++) {
1666
+ const float d = GGML_FP16_TO_FP32(x[i].d);
1667
+
1668
+ const uint8_t * restrict pp = x[i].qs;
1669
+
1670
+ for (int l = 0; l < QK4_2; l += 2) {
1671
+ const uint8_t vi = pp[l/2];
1672
+
1673
+ const int8_t vi0 = vi & 0xf;
1674
+ const int8_t vi1 = vi >> 4;
1675
+
1676
+ const float v0 = (vi0 - 8)*d;
1677
+ const float v1 = (vi1 - 8)*d;
1678
+
1679
+ y[i*QK4_2 + l + 0] = v0;
1680
+ y[i*QK4_2 + l + 1] = v1;
1681
+
1682
+ assert(!isnan(y[i*QK4_2 + l + 0]));
1683
+ assert(!isnan(y[i*QK4_2 + l + 1]));
1684
+ }
1685
+ }
1686
+ }
1687
+
1688
+ static void dequantize_row_q4_3(const void * restrict vx, float * restrict y, int k) {
1689
+ assert(k % QK4_3 == 0);
1690
+ const int nb = k / QK4_3;
1691
+
1692
+ const block_q4_3 * restrict x = vx;
1693
+
1694
+ for (int i = 0; i < nb; i++) {
1695
+ const float d = GGML_FP16_TO_FP32(x[i].d);
1696
+ const float m = GGML_FP16_TO_FP32(x[i].m);
1697
+
1698
+ const uint8_t * restrict pp = x[i].qs;
1699
+
1700
+ for (int l = 0; l < QK4_3; l += 2) {
1701
+ const uint8_t vi = pp[l/2];
1702
+
1703
+ const int8_t vi0 = vi & 0xf;
1704
+ const int8_t vi1 = vi >> 4;
1705
+
1706
+ const float v0 = vi0*d + m;
1707
+ const float v1 = vi1*d + m;
1708
+
1709
+ y[i*QK4_3 + l + 0] = v0;
1710
+ y[i*QK4_3 + l + 1] = v1;
1711
+
1712
+ assert(!isnan(y[i*QK4_3 + l + 0]));
1713
+ assert(!isnan(y[i*QK4_3 + l + 1]));
1714
+ }
1715
+ }
1716
+ }
1717
+
1718
+ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
1719
+ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
1720
+ static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
1721
+ static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
1722
+
1723
+ static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
1724
+ [GGML_TYPE_Q4_0] = {
1725
+ .dequantize_row_q = dequantize_row_q4_0,
1726
+ .quantize_row_q = quantize_row_q4_0,
1727
+ .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_0_reference,
1728
+ .quantize_row_q_dot = quantize_row_q8_0,
1729
+ .vec_dot_q = ggml_vec_dot_q4_0_q8_0,
1730
+ },
1731
+ [GGML_TYPE_Q4_1] = {
1732
+ .dequantize_row_q = dequantize_row_q4_1,
1733
+ .quantize_row_q = quantize_row_q4_1,
1734
+ .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_1_reference,
1735
+ .quantize_row_q_dot = quantize_row_q8_0,
1736
+ .vec_dot_q = ggml_vec_dot_q4_1_q8_0,
1737
+ },
1738
+ [GGML_TYPE_Q4_2] = {
1739
+ .dequantize_row_q = dequantize_row_q4_2,
1740
+ .quantize_row_q = quantize_row_q4_2,
1741
+ .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_2_rmse, //quantize_row_q4_2_reference,
1742
+ .quantize_row_q_dot = quantize_row_q8_0,
1743
+ .vec_dot_q = ggml_vec_dot_q4_2_q8_0,
1744
+ },
1745
+ [GGML_TYPE_Q4_3] = {
1746
+ .dequantize_row_q = dequantize_row_q4_3,
1747
+ .quantize_row_q = quantize_row_q4_3,
1748
+ .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_3_reference, // TODO: RMSE optimization
1749
+ .quantize_row_q_dot = quantize_row_q8_0,
1750
+ .vec_dot_q = ggml_vec_dot_q4_3_q8_0,
1751
+ },
1752
+ [GGML_TYPE_Q8_0] = {
1753
+ .dequantize_row_q = NULL, // TODO
1754
+ .quantize_row_q = quantize_row_q8_0,
1755
+ .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q8_0_reference,
1756
+ .quantize_row_q_dot = quantize_row_q8_0,
1757
+ .vec_dot_q = NULL, // TODO
1758
+ },
1759
+ };
1760
+
1761
+ // For internal test use
1762
+ quantize_fns_t ggml_internal_get_quantize_fn(size_t i) {
1763
+ GGML_ASSERT(i < GGML_TYPE_COUNT);
1764
+ return quantize_fns[i];
1765
+ }
1766
+
1767
+
1768
+ //
1769
+ // simd mappings
1770
+ //
1771
+
1772
+ // we define a common set of C macros which map to specific intrinsics based on the current architecture
1265
1773
  // we then implement the fundamental computation operations below using only these macros
1266
1774
  // adding support for new architectures requires to define the corresponding SIMD macros
1267
1775
  //
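The quantize_fns table added above turns per-type quantization into data: each quantized type carries its row dequantizer, its row quantizer, the quantizer used for the dot-product operand (q8_0 for every type so far), and the mixed-type dot kernel, with the q8_0 entry's own dequantize/vec_dot slots still NULL. The ggml-specific types are not reproduced here, but the shape of the design is an ordinary table of function pointers indexed by type id; a self-contained analogue with hypothetical names:

    #include <stdio.h>

    typedef void (*row_fn)(const float * x, void * y, int k);
    typedef void (*dot_fn)(int n, float * s, const void * x, const void * y);

    // same shape as quantize_fns_t above, with illustrative names only
    typedef struct {
        row_fn quantize_row;       // quantize a row of this type
        row_fn quantize_row_dot;   // quantize the other dot-product operand (q8_0 in the diff)
        dot_fn vec_dot;            // mixed-type dot product
    } type_traits;

    enum { TYPE_Q4_0, TYPE_Q4_1, TYPE_COUNT };

    static void quantize_stub(const float * x, void * y, int k) { (void) x; (void) y; (void) k; }
    static void dot_stub(int n, float * s, const void * x, const void * y) {
        (void) n; (void) x; (void) y; *s = 0.0f;
    }

    static const type_traits traits[TYPE_COUNT] = {
        [TYPE_Q4_0] = { quantize_stub, quantize_stub, dot_stub },
        [TYPE_Q4_1] = { quantize_stub, quantize_stub, dot_stub },
    };

    int main(void) {
        // a mat-mul caller only needs the type id to pick all of its kernels
        float s = 0.0f;
        traits[TYPE_Q4_0].vec_dot(32, &s, NULL, NULL);
        printf("%f\n", s);
        return 0;
    }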
@@ -1813,37 +2321,6 @@ inline static void ggml_vec_dot_f32(const int n, float * restrict s, const float
1813
2321
  *s = sumf;
1814
2322
  }
1815
2323
 
1816
- #if __AVX512F__ && QK == 32
1817
- static inline __m512 dot_q4_0_oneblock_avx512(
1818
- __m512 acc,
1819
- const block_q4_0 * restrict x,
1820
- const block_q4_0 * restrict y,
1821
- int i
1822
- ) {
1823
- // Compute combined scale for the block
1824
- __m512 d = _mm512_set1_ps( x[i].d * y[i].d );
1825
-
1826
- __m256i bx = bytesFromNibbles( x[i].qs );
1827
- __m256i by = bytesFromNibbles( y[i].qs );
1828
-
1829
- // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
1830
- const __m256i off = _mm256_set1_epi8( 8 );
1831
- bx = _mm256_sub_epi8( bx, off );
1832
- by = _mm256_sub_epi8( by, off );
1833
-
1834
- // Sign-extend 16 signed bytes into int16_t
1835
- __m512i x32 = _mm512_cvtepi8_epi16( bx );
1836
- __m512i y32 = _mm512_cvtepi8_epi16( by );
1837
- // Compute products of int16_t integers, add pairwise
1838
- __m512i i64 = _mm512_madd_epi16( x32, y32 );
1839
-
1840
- // Convert int32_t to float
1841
- __m512 p = _mm512_cvtepi32_ps( i64 );
1842
- // Apply the scale, and accumulate
1843
- return _mm512_fmadd_ps( d, p, acc );
1844
- }
1845
- #endif
1846
-
1847
2324
  inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t * restrict x, ggml_fp16_t * restrict y) {
1848
2325
  ggml_float sumf = 0.0;
1849
2326
 
@@ -1880,67 +2357,64 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t
1880
2357
  *s = sumf;
1881
2358
  }
1882
2359
 
1883
- static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
1884
- const int nb = n / QK;
2360
+ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
2361
+ const int nb = n / QK8_0;
1885
2362
 
1886
- assert(n % QK == 0);
2363
+ assert(n % QK8_0 == 0);
1887
2364
  assert(nb % 2 == 0);
1888
2365
 
1889
2366
  const block_q4_0 * restrict x = vx;
1890
- const block_q4_0 * restrict y = vy;
2367
+ const block_q8_0 * restrict y = vy;
1891
2368
 
1892
2369
  float sumf = 0.0;
1893
2370
 
1894
2371
  #if defined(__ARM_NEON)
1895
- float sum0 = 0.0f;
1896
- float sum1 = 0.0f;
2372
+ float32x4_t sumv0 = vdupq_n_f32(0.0f);
2373
+ float32x4_t sumv1 = vdupq_n_f32(0.0f);
1897
2374
 
1898
2375
  for (int i = 0; i < nb; i += 2) {
1899
2376
  const block_q4_0 * restrict x0 = &x[i + 0];
1900
- const block_q4_0 * restrict y0 = &y[i + 0];
1901
2377
  const block_q4_0 * restrict x1 = &x[i + 1];
1902
- const block_q4_0 * restrict y1 = &y[i + 1];
2378
+ const block_q8_0 * restrict y0 = &y[i + 0];
2379
+ const block_q8_0 * restrict y1 = &y[i + 1];
1903
2380
 
1904
- const uint8x16_t m4b = vdupq_n_u8(0xf);
1905
- const int8x16_t s8b = vdupq_n_s8(0x8);
2381
+ const uint8x16_t m4b = vdupq_n_u8(0xf);
2382
+ const int8x16_t s8b = vdupq_n_s8(0x8);
1906
2383
 
1907
2384
  const uint8x16_t v0_0 = vld1q_u8(x0->qs);
1908
- const uint8x16_t v1_0 = vld1q_u8(y0->qs);
1909
2385
  const uint8x16_t v0_1 = vld1q_u8(x1->qs);
1910
- const uint8x16_t v1_1 = vld1q_u8(y1->qs);
1911
2386
 
1912
2387
  // 4-bit -> 8-bit
1913
- const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8(v0_0, m4b));
1914
- const int8x16_t v1_0l = vreinterpretq_s8_u8(vandq_u8(v1_0, m4b));
2388
+ const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
1915
2389
  const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
1916
- const int8x16_t v1_0h = vreinterpretq_s8_u8(vshrq_n_u8(v1_0, 4));
1917
-
1918
- const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8(v0_1, m4b));
1919
- const int8x16_t v1_1l = vreinterpretq_s8_u8(vandq_u8(v1_1, m4b));
2390
+ const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
1920
2391
  const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
1921
- const int8x16_t v1_1h = vreinterpretq_s8_u8(vshrq_n_u8(v1_1, 4));
1922
2392
 
1923
2393
  // sub 8
1924
2394
  const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
1925
- const int8x16_t v1_0ls = vsubq_s8(v1_0l, s8b);
1926
2395
  const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
1927
- const int8x16_t v1_0hs = vsubq_s8(v1_0h, s8b);
1928
-
1929
2396
  const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
1930
- const int8x16_t v1_1ls = vsubq_s8(v1_1l, s8b);
1931
2397
  const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
1932
- const int8x16_t v1_1hs = vsubq_s8(v1_1h, s8b);
2398
+
2399
+ // load y
2400
+ const int8x16_t v1_0l = vld1q_s8(y0->qs);
2401
+ const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
2402
+ const int8x16_t v1_1l = vld1q_s8(y1->qs);
2403
+ const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
2404
+
2405
+ // interleave
2406
+ const int8x16_t v1_0ls = vuzp1q_s8(v1_0l, v1_0h);
2407
+ const int8x16_t v1_0hs = vuzp2q_s8(v1_0l, v1_0h);
2408
+ const int8x16_t v1_1ls = vuzp1q_s8(v1_1l, v1_1h);
2409
+ const int8x16_t v1_1hs = vuzp2q_s8(v1_1l, v1_1h);
1933
2410
 
1934
2411
  #if defined(__ARM_FEATURE_DOTPROD)
1935
2412
  // dot product into int32x4_t
1936
- int32x4_t p_0 = vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls);
1937
- int32x4_t p_1 = vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls);
2413
+ const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls), v0_0hs, v1_0hs);
2414
+ const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls), v0_1hs, v1_1hs);
1938
2415
 
1939
- p_0 = vdotq_s32(p_0, v0_0hs, v1_0hs);
1940
- p_1 = vdotq_s32(p_1, v0_1hs, v1_1hs);
1941
-
1942
- sum0 += x0->d*y0->d*vaddvq_s32(p_0);
1943
- sum1 += x1->d*y1->d*vaddvq_s32(p_1);
2416
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), x0->d*y0->d);
2417
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), x1->d*y1->d);
1944
2418
  #else
1945
2419
  const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls));
1946
2420
  const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls));
@@ -1952,115 +2426,51 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
1952
2426
  const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1hs));
1953
2427
  const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1hs));
1954
2428
 
1955
- const int16x8_t pl_0 = vaddq_s16(pl0l, pl0h);
1956
- const int16x8_t ph_0 = vaddq_s16(ph0l, ph0h);
1957
-
1958
- const int16x8_t pl_1 = vaddq_s16(pl1l, pl1h);
1959
- const int16x8_t ph_1 = vaddq_s16(ph1l, ph1h);
2429
+ const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
2430
+ const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
2431
+ const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
2432
+ const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
1960
2433
 
1961
- const int16x8_t p_0 = vaddq_s16(pl_0, ph_0);
1962
- const int16x8_t p_1 = vaddq_s16(pl_1, ph_1);
1963
-
1964
- sum0 += x0->d*y0->d*vaddvq_s16(p_0);
1965
- sum1 += x1->d*y1->d*vaddvq_s16(p_1);
2434
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), x0->d*y0->d);
2435
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), x1->d*y1->d);
1966
2436
  #endif
1967
2437
  }
1968
2438
 
1969
- sumf = sum0 + sum1;
1970
- #elif defined(__AVX512F__)
2439
+ sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
2440
+ #elif defined(__AVX2__)
1971
2441
  // Initialize accumulator with zeros
1972
- __m512 acc0 = _mm512_setzero_ps();
1973
- __m512 acc1 = _mm512_setzero_ps();
2442
+ __m256 acc = _mm256_setzero_ps();
1974
2443
 
1975
- const int superblock_size = 8;
1976
- const int superblock_count = nb / superblock_size;
2444
+ // Main loop
2445
+ for (int i = 0; i < nb; ++i) {
2446
+ /* Compute combined scale for the block */
2447
+ const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );
1977
2448
 
1978
- for (int superblock_ix = 0; superblock_ix < superblock_count; superblock_ix += 1) {
1979
- int i = superblock_ix * superblock_size;
2449
+ __m256i bx = bytes_from_nibbles_32(x[i].qs);
1980
2450
 
1981
- acc0 = dot_q4_0_oneblock_avx512( acc0, x, y, i+0 );
1982
- acc1 = dot_q4_0_oneblock_avx512( acc1, x, y, i+1 );
1983
- acc0 = dot_q4_0_oneblock_avx512( acc0, x, y, i+2 );
1984
- acc1 = dot_q4_0_oneblock_avx512( acc1, x, y, i+3 );
1985
- acc0 = dot_q4_0_oneblock_avx512( acc0, x, y, i+4 );
1986
- acc1 = dot_q4_0_oneblock_avx512( acc1, x, y, i+5 );
1987
- acc0 = dot_q4_0_oneblock_avx512( acc0, x, y, i+6 );
1988
- acc1 = dot_q4_0_oneblock_avx512( acc1, x, y, i+7 );
1989
- }
2451
+ // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
2452
+ const __m256i off = _mm256_set1_epi8( 8 );
2453
+ bx = _mm256_sub_epi8( bx, off );
1990
2454
 
1991
- // Remainders
1992
- for (int i = superblock_count * superblock_size; i < nb; ++i) {
1993
- acc0 = dot_q4_0_oneblock_avx512( acc0, x, y, i );
1994
- }
2455
+ __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
1995
2456
 
1996
- // Horizontal sum of all lanes of the accumulator
1997
- sumf = _mm512_reduce_add_ps( acc0 ) + _mm512_reduce_add_ps( acc1 );
1998
- #elif defined(__AVX2__)
1999
- // Initialize accumulator with zeros
2000
- __m256 acc = _mm256_setzero_ps();
2457
+ // Get absolute values of x vectors
2458
+ const __m256i ax = _mm256_sign_epi8(bx, bx);
2001
2459
 
2002
- /* Prepare the constants we will need during execution */
2003
- const __m256i lowMask = _mm256_set1_epi8( 0xF );
2004
- const __m256i offset_8 = _mm256_set1_epi16( 8 );
2460
+ // Sign the values of the y vectors
2461
+ const __m256i sy = _mm256_sign_epi8(by, bx);
2005
2462
 
2006
- #define UNROLL_COUNT 8
2007
- // make sure we only unroll multiples of the block count
2008
- assert(nb % UNROLL_COUNT == 0);
2463
+ // Perform multiplication and create 16-bit values
2464
+ const __m256i dot = _mm256_maddubs_epi16(ax, sy);
2009
2465
 
2010
- // Main loop
2011
- for (int i = 0; i < nb; i+=UNROLL_COUNT) {
2012
- // This loop will be unrolled by the compiler
2013
- for (int u=0;u<UNROLL_COUNT;u++) {
2014
- /* Compute combined scale for the block */
2015
- const __m256 scale = _mm256_mul_ps(
2016
- _mm256_broadcast_ss( &x[i+u].d ),
2017
- _mm256_broadcast_ss( &y[i+u].d ) );
2018
-
2019
- /* get input from x
2020
- Input: 32 Nibbles (16 bytes) at *x[i+u]
2021
- Output: 2 vectors with 16 values of type int16_t (x_high_q, x_low_q) */
2022
-
2023
- /* Load 16 bytes from memory */
2024
- const __m128i tmp_x = _mm_loadu_si128( ( const __m128i* ) x[i+u].qs);
2025
- /* Expand bytes into uint16_t values */
2026
- const __m256i bytes_x = _mm256_cvtepu8_epi16(tmp_x);
2027
- /* Unpack values into individual bytes */
2028
- __m256i x_low_q = _mm256_and_si256( lowMask, bytes_x );
2029
- const __m256i pre_shift_x_high_q = _mm256_andnot_si256( lowMask, bytes_x );
2030
- __m256i x_high_q = _mm256_srli_epi16( pre_shift_x_high_q, 4 );
2031
- /* Now we have two vectors with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval. */
2032
- x_high_q = _mm256_sub_epi16( x_high_q, offset_8 );
2033
- x_low_q = _mm256_sub_epi16( x_low_q, offset_8 );
2034
-
2035
- /* get input from y
2036
- Input: 32 Nibbles (16 bytes) at *y[i+u]
2037
- Output: 2 vectors with 16 values of type int16_t (y_high_q, y_low_q) */
2038
-
2039
- /* Load 16 bytes from memory */
2040
- const __m128i tmp_y = _mm_loadu_si128( (const __m128i* ) y[i+u].qs);
2041
- /* Expand bytes into uint16_t values */
2042
- const __m256i bytes_y = _mm256_cvtepu8_epi16(tmp_y);
2043
- /* Unpack values into individual bytes */
2044
- const __m256i pre_shift_y_high_q = _mm256_andnot_si256( lowMask, bytes_y );
2045
- __m256i y_high_q = _mm256_srli_epi16( pre_shift_y_high_q, 4 );
2046
- __m256i y_low_q = _mm256_and_si256( lowMask, bytes_y );
2047
- /* Now we have two vectors with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval. */
2048
- y_high_q = _mm256_sub_epi16( y_high_q, offset_8 );
2049
- y_low_q = _mm256_sub_epi16( y_low_q, offset_8 );
2050
-
2051
- /* Compute products of int16_t integers, add pairwise, store as int32_t */
2052
- __m256i xy_high_q = _mm256_madd_epi16( x_high_q, y_high_q );
2053
- __m256i xy_low_q = _mm256_madd_epi16( x_low_q, y_low_q );
2054
-
2055
- /* Accumulate the products of int32_t integers -> we now have a vector of 8 int_32t */
2056
- __m256i xy_q = _mm256_add_epi32( xy_high_q, xy_low_q );
2057
-
2058
- /* Convert to vectore of 8 int32_t to 8 floats */
2059
- __m256 q = _mm256_cvtepi32_ps( xy_q );
2060
-
2061
- /* Multiply q with scale and accumulate */
2062
- acc = _mm256_fmadd_ps( scale, q, acc );
2063
- }
2466
+ const __m256i ones = _mm256_set1_epi16(1);
2467
+ __m256i xy_q = _mm256_madd_epi16(ones, dot);
2468
+
2469
+ /* Convert the vector of 8 int32_t to 8 floats */
2470
+ __m256 q = _mm256_cvtepi32_ps( xy_q );
2471
+
2472
+ /* Multiply q with scale and accumulate */
2473
+ acc = _mm256_fmadd_ps( d, q, acc );
2064
2474
  }
2065
2475
 
2066
2476
  // Return horizontal sum of the acc vector
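The rewritten AVX2 loop above relies on a standard pairing of two instructions: _mm256_maddubs_epi16 multiplies an unsigned byte vector by a signed byte vector, so to get a signed·signed product it is fed |x| (_mm256_sign_epi8(bx, bx)) and y with x's sign transferred onto it (_mm256_sign_epi8(by, bx)); |x|·sign(x)·y equals x·y, and with the 4-bit x in [-8, 7] the pairwise int16 sums stay far below the instruction's saturation limit. _mm256_madd_epi16 against a vector of ones then widens those sums to int32. A scalar check of the identity, including the x = 0 case where sign() zeroes its operand:

    #include <stdio.h>
    #include <stdlib.h>

    int main(void) {
        for (int x = -8; x <= 7; x++) {
            for (int y = -128; y <= 127; y++) {
                const int ax = abs(x);                            // |x|
                const int sy = (x > 0) ? y : (x < 0) ? -y : 0;    // y with x's sign applied
                if (ax*sy != x*y) { printf("mismatch at %d,%d\n", x, y); return 1; }
            }
        }
        printf("identity holds over the int4 x int8 range\n");
        return 0;
    }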
@@ -2082,13 +2492,12 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
2082
2492
  __m128i i32[2];
2083
2493
  for (int j = 0; j < 2; ++j) {
2084
2494
  // Load 8 bytes, and unpack 4 bit fields into bytes, making 16 bytes
2085
- __m128i bx = bytesFromNibbles( x[i].qs + 8*j );
2086
- __m128i by = bytesFromNibbles( y[i].qs + 8*j );
2495
+ __m128i bx = bytes_from_nibbles_16(x[i].qs + 8*j);
2496
+ __m128i by = _mm_loadu_si128((const __m128i *)(y[i].qs + 16*j));
2087
2497
 
2088
2498
  // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
2089
2499
  const __m128i off = _mm_set1_epi8( 8 );
2090
2500
  bx = _mm_sub_epi8( bx, off );
2091
- by = _mm_sub_epi8( by, off );
2092
2501
 
2093
2502
  // Get absolute values of x vectors
2094
2503
  const __m128i ax = _mm_sign_epi8(bx, bx);
@@ -2116,86 +2525,6 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
2116
2525
  res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
2117
2526
 
2118
2527
  sumf = _mm_cvtss_f32( res );
2119
- #elif defined(__wasm_simd128__)
2120
- // wasm simd
2121
- float sum0 = 0.0f;
2122
- float sum1 = 0.0f;
2123
-
2124
- for (int i = 0; i < nb; i += 2) {
2125
- const block_q4_0 * restrict x0 = &x[i + 0];
2126
- const block_q4_0 * restrict y0 = &y[i + 0];
2127
- const block_q4_0 * restrict x1 = &x[i + 1];
2128
- const block_q4_0 * restrict y1 = &y[i + 1];
2129
-
2130
- const v128_t m4b = wasm_u8x16_splat(0xf);
2131
- const v128_t s8b = wasm_i8x16_splat(0x8);
2132
-
2133
- const v128_t v0_0 = wasm_v128_load(x0->qs);
2134
- const v128_t v0_1 = wasm_v128_load(y0->qs);
2135
- const v128_t v1_0 = wasm_v128_load(x1->qs);
2136
- const v128_t v1_1 = wasm_v128_load(y1->qs);
2137
-
2138
- // 4-bit -> 8-bit
2139
- const v128_t v0_0l = wasm_v128_and(v0_0, m4b);
2140
- const v128_t v1_0l = wasm_v128_and(v1_0, m4b);
2141
-
2142
- const v128_t v0_0h = wasm_u8x16_shr(v0_0, 4);
2143
- const v128_t v1_0h = wasm_u8x16_shr(v1_0, 4);
2144
-
2145
- const v128_t v0_1l = wasm_v128_and(v0_1, m4b);
2146
- const v128_t v1_1l = wasm_v128_and(v1_1, m4b);
2147
-
2148
- const v128_t v0_1h = wasm_u8x16_shr(v0_1, 4);
2149
- const v128_t v1_1h = wasm_u8x16_shr(v1_1, 4);
2150
-
2151
- // sub 8
2152
- const v128_t v0_0ls = wasm_i8x16_sub(v0_0l, s8b);
2153
- const v128_t v1_0ls = wasm_i8x16_sub(v1_0l, s8b);
2154
-
2155
- const v128_t v0_0hs = wasm_i8x16_sub(v0_0h, s8b);
2156
- const v128_t v1_0hs = wasm_i8x16_sub(v1_0h, s8b);
2157
-
2158
- const v128_t v0_1ls = wasm_i8x16_sub(v0_1l, s8b);
2159
- const v128_t v1_1ls = wasm_i8x16_sub(v1_1l, s8b);
2160
-
2161
- const v128_t v0_1hs = wasm_i8x16_sub(v0_1h, s8b);
2162
- const v128_t v1_1hs = wasm_i8x16_sub(v1_1h, s8b);
2163
-
2164
- // dot product into int16x8_t
2165
- const v128_t pl0l = wasm_i16x8_mul(wasm_i16x8_extend_low_i8x16(v0_0ls), wasm_i16x8_extend_low_i8x16(v1_0ls));
2166
- const v128_t pl0h = wasm_i16x8_mul(wasm_i16x8_extend_high_i8x16(v0_0ls), wasm_i16x8_extend_high_i8x16(v1_0ls));
2167
-
2168
- const v128_t ph0l = wasm_i16x8_mul(wasm_i16x8_extend_low_i8x16(v0_0hs), wasm_i16x8_extend_low_i8x16(v1_0hs));
2169
- const v128_t ph0h = wasm_i16x8_mul(wasm_i16x8_extend_high_i8x16(v0_0hs), wasm_i16x8_extend_high_i8x16(v1_0hs));
2170
-
2171
- const v128_t pl1l = wasm_i16x8_mul(wasm_i16x8_extend_low_i8x16(v0_1ls), wasm_i16x8_extend_low_i8x16(v1_1ls));
2172
- const v128_t pl1h = wasm_i16x8_mul(wasm_i16x8_extend_high_i8x16(v0_1ls), wasm_i16x8_extend_high_i8x16(v1_1ls));
2173
-
2174
- const v128_t ph1l = wasm_i16x8_mul(wasm_i16x8_extend_low_i8x16(v0_1hs), wasm_i16x8_extend_low_i8x16(v1_1hs));
2175
- const v128_t ph1h = wasm_i16x8_mul(wasm_i16x8_extend_high_i8x16(v0_1hs), wasm_i16x8_extend_high_i8x16(v1_1hs));
2176
-
2177
- const v128_t pl_0 = wasm_i16x8_add(pl0l, pl0h);
2178
- const v128_t ph_0 = wasm_i16x8_add(ph0l, ph0h);
2179
-
2180
- const v128_t pl_1 = wasm_i16x8_add(pl1l, pl1h);
2181
- const v128_t ph_1 = wasm_i16x8_add(ph1l, ph1h);
2182
-
2183
- const v128_t p_0 = wasm_i16x8_add(pl_0, ph_0);
2184
- const v128_t p_1 = wasm_i16x8_add(pl_1, ph_1);
2185
-
2186
- sum0 += x0->d * y0->d * (
2187
- wasm_i16x8_extract_lane(p_0, 0) + wasm_i16x8_extract_lane(p_0, 1) +
2188
- wasm_i16x8_extract_lane(p_0, 2) + wasm_i16x8_extract_lane(p_0, 3) +
2189
- wasm_i16x8_extract_lane(p_0, 4) + wasm_i16x8_extract_lane(p_0, 5) +
2190
- wasm_i16x8_extract_lane(p_0, 6) + wasm_i16x8_extract_lane(p_0, 7));
2191
- sum1 += x1->d * y1->d * (
2192
- wasm_i16x8_extract_lane(p_1, 0) + wasm_i16x8_extract_lane(p_1, 1) +
2193
- wasm_i16x8_extract_lane(p_1, 2) + wasm_i16x8_extract_lane(p_1, 3) +
2194
- wasm_i16x8_extract_lane(p_1, 4) + wasm_i16x8_extract_lane(p_1, 5) +
2195
- wasm_i16x8_extract_lane(p_1, 6) + wasm_i16x8_extract_lane(p_1, 7));
2196
- }
2197
-
2198
- sumf = sum0 + sum1;
2199
2528
  #else
2200
2529
  // scalar
2201
2530
  for (int i = 0; i < nb; i++) {
@@ -2203,98 +2532,159 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
2203
2532
  const float d1 = y[i].d;
2204
2533
 
2205
2534
  const uint8_t * restrict p0 = x[i].qs;
2206
- const uint8_t * restrict p1 = y[i].qs;
2535
+ const int8_t * restrict p1 = y[i].qs;
2207
2536
 
2208
2537
  int sumi = 0;
2209
- for (int j = 0; j < QK/2; j++) {
2538
+ for (int j = 0; j < QK8_0/2; j++) {
2210
2539
  const uint8_t v0 = p0[j];
2211
- const uint8_t v1 = p1[j];
2212
2540
 
2213
- const int8_t i0 = (int8_t) (v0 & 0xf) - 8;
2214
- const int8_t i1 = (int8_t) (v0 >> 4) - 8;
2541
+ const int i0 = (int8_t) (v0 & 0xf) - 8;
2542
+ const int i1 = (int8_t) (v0 >> 4) - 8;
2215
2543
 
2216
- const int8_t i2 = (int8_t) (v1 & 0xf) - 8;
2217
- const int8_t i3 = (int8_t) (v1 >> 4) - 8;
2544
+ const int i2 = p1[2*j + 0];
2545
+ const int i3 = p1[2*j + 1];
2218
2546
 
2219
2547
  sumi += i0*i2 + i1*i3;
2220
2548
  }
2221
- sumf += d0 * d1 * sumi;
2549
+ sumf += d0*d1*sumi;
2222
2550
  }
2223
2551
  #endif
2224
2552
 
2225
2553
  *s = sumf;
2226
2554
  }
2227
2555
 
2228
- static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
2229
- const int nb = n / QK;
2556
+ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
2557
+ const int nb = n / QK8_0;
2558
+
2559
+ assert(n % QK8_0 == 0);
2560
+ assert(nb % 2 == 0);
2230
2561
 
2231
2562
  const block_q4_1 * restrict x = vx;
2232
- const block_q4_1 * restrict y = vy;
2563
+ const block_q8_0 * restrict y = vy;
2233
2564
 
2234
2565
  float sumf = 0.0;
2235
2566
 
2236
- #if defined(__AVX2__)
2567
+ // TODO: add AVX / WASM SIMD / etc
2568
+ #if defined(__ARM_NEON)
2569
+ float32x4_t sumv0 = vdupq_n_f32(0.0f);
2570
+ float32x4_t sumv1 = vdupq_n_f32(0.0f);
2571
+
2572
+ for (int i = 0; i < nb; i += 2) {
2573
+ const block_q4_1 * restrict x0 = &x[i + 0];
2574
+ const block_q4_1 * restrict x1 = &x[i + 1];
2575
+ const block_q8_0 * restrict y0 = &y[i + 0];
2576
+ const block_q8_0 * restrict y1 = &y[i + 1];
2577
+
2578
+ const uint8x16_t m4b = vdupq_n_u8(0xf);
2579
+
2580
+ const uint8x16_t v0_0 = vld1q_u8(x0->qs);
2581
+ const uint8x16_t v0_1 = vld1q_u8(x1->qs);
2582
+
2583
+ // 4-bit -> 8-bit
2584
+ const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
2585
+ const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
2586
+ const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
2587
+ const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
2588
+
2589
+ // load y
2590
+ const int8x16_t v1_0l = vld1q_s8(y0->qs);
2591
+ const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
2592
+ const int8x16_t v1_1l = vld1q_s8(y1->qs);
2593
+ const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
2594
+
2595
+ // interleave
2596
+ const int8x16_t v1_0ls = vuzp1q_s8(v1_0l, v1_0h);
2597
+ const int8x16_t v1_0hs = vuzp2q_s8(v1_0l, v1_0h);
2598
+ const int8x16_t v1_1ls = vuzp1q_s8(v1_1l, v1_1h);
2599
+ const int8x16_t v1_1hs = vuzp2q_s8(v1_1l, v1_1h);
2600
+
2601
+ const int16x8_t s0i = vaddq_s16(
2602
+ vaddq_s16(vmovl_s8(vget_low_s8(v1_0ls)), vmovl_s8(vget_high_s8(v1_0ls))),
2603
+ vaddq_s16(vmovl_s8(vget_low_s8(v1_0hs)), vmovl_s8(vget_high_s8(v1_0hs))));
2604
+
2605
+ const int16x8_t s1i = vaddq_s16(
2606
+ vaddq_s16(vmovl_s8(vget_low_s8(v1_1ls)), vmovl_s8(vget_high_s8(v1_1ls))),
2607
+ vaddq_s16(vmovl_s8(vget_low_s8(v1_1hs)), vmovl_s8(vget_high_s8(v1_1hs))));
2608
+
2609
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddl_s16(vget_low_s16(s0i), vget_high_s16(s0i))), x0->m*y0->d);
2610
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddl_s16(vget_low_s16(s1i), vget_high_s16(s1i))), x1->m*y1->d);
2611
+
2612
+ #if defined(__ARM_FEATURE_DOTPROD)
2613
+ // dot product into int32x4_t
2614
+ const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0ls), v0_0h, v1_0hs);
2615
+ const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1ls), v0_1h, v1_1hs);
2616
+
2617
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), x0->d*y0->d);
2618
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), x1->d*y1->d);
2619
+ #else
2620
+ const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0l), vget_low_s8 (v1_0ls));
2621
+ const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0l), vget_high_s8(v1_0ls));
2622
+ const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0h), vget_low_s8 (v1_0hs));
2623
+ const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0h), vget_high_s8(v1_0hs));
2624
+
2625
+ const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1l), vget_low_s8 (v1_1ls));
2626
+ const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1l), vget_high_s8(v1_1ls));
2627
+ const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1h), vget_low_s8 (v1_1hs));
2628
+ const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1h), vget_high_s8(v1_1hs));
2629
+
2630
+ const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
2631
+ const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
2632
+ const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
2633
+ const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
2634
+
2635
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), x0->d*y0->d);
2636
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), x1->d*y1->d);
2637
+ #endif
2638
+ }
2639
+
2640
+ sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
2641
+ #elif defined(__AVX2__)
2237
2642
  // Initialize accumulator with zeros
2238
2643
  __m256 acc = _mm256_setzero_ps();
2239
- // Accumulator for constant offsets
2240
- float acc_offset = 0.0f;
2241
2644
 
2242
2645
  // Main loop
2243
2646
  for (int i = 0; i < nb; ++i) {
2244
2647
  const float * d0 = &x[i].d;
2245
2648
  const float * d1 = &y[i].d;
2246
-
2247
2649
  const float * m0 = &x[i].m;
2248
- const float * m1 = &y[i].m;
2249
2650
 
2250
2651
  const __m256 d0v = _mm256_broadcast_ss( d0 );
2251
2652
  const __m256 d1v = _mm256_broadcast_ss( d1 );
2252
2653
  const __m256 m0v = _mm256_broadcast_ss( m0 );
2253
- const __m256 m1v = _mm256_broadcast_ss( m1 );
2254
2654
 
2255
- // Compute combined scale for the block
2256
- const __m256 scale_01 = _mm256_mul_ps( d0v, d1v );
2257
-
2258
- // Compute cross scales for the block
2259
- const __m256 scale_0 = _mm256_mul_ps( d0v, m1v );
2260
- const __m256 scale_1 = _mm256_mul_ps( m0v, d1v );
2261
- const __m256 cross_scales = _mm256_blend_ps( scale_0, scale_1, 0xAA /* 0b10101010 */ );
2655
+ // Compute combined scales
2656
+ const __m256 d0d1 = _mm256_mul_ps( d0v, d1v );
2657
+ const __m256 d1m0 = _mm256_mul_ps( d1v, m0v );
2262
2658
 
2263
2659
  // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
2264
- __m256i bx = bytesFromNibbles( x[i].qs );
2265
- __m256i by = bytesFromNibbles( y[i].qs );
2266
-
2267
- // Now we have a vector with bytes in [ 0 .. 15 ] interval.
2268
-
2269
- // Sign-extend first 16 signed bytes into int16_t
2270
- __m256i x16 = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( bx ) );
2271
- __m256i y16 = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( by ) );
2272
- // Compute products of int16_t integers, add pairwise
2273
- __m256i i32 = _mm256_madd_epi16( x16, y16 );
2274
-
2275
- // Sign-extend last 16 signed bytes into int16_t vectors
2276
- __m256i x16_h = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( bx, 1 ) );
2277
- __m256i y16_h = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( by, 1 ) );
2278
- // Accumulate products of int16_t integers
2279
- i32 = _mm256_add_epi32( i32, _mm256_madd_epi16( x16_h, y16_h ) );
2280
-
2281
- // compute sums of unsigned bytes in bx, by in blocks of 8.
2282
- // This results in a layout like X100 0000 X200 0000 X300 0000 X400 0000,
2283
- // which we then interleave as X100 Y100 X200 Y200 X300 Y300 X400 Y400.
2284
- // so if we then cast to 8 singles, we get 8 floats like [ x0_7, y0_7, x8_15, y8_15, x16_23, y16_23, x24_31, y24_31 ]
2285
- __m256i xsumi = _mm256_sad_epu8( bx, _mm256_setzero_si256() );
2286
- __m256i ysumi = _mm256_sad_epu8( by, _mm256_setzero_si256() );
2287
- __m256i sumsi = _mm256_or_si256( xsumi, _mm256_slli_si256( ysumi, 4 ) );
2288
- __m256 sums = _mm256_cvtepi32_ps( sumsi );
2660
+ const __m256i bx = bytes_from_nibbles_32(x[i].qs);
2661
+ const __m256i by = _mm256_loadu_si256( (const __m256i *)y[i].qs );
2289
2662
 
2290
- // Convert int32_t to float
2291
- __m256 p = _mm256_cvtepi32_ps( i32 );
2292
- // Apply the scale, and accumulate
2293
- // acc += d0*d1*x*y + d0*m1*x + d1*m0*y
2294
- acc = _mm256_fmadd_ps( scale_01, p, acc );
2295
- acc = _mm256_fmadd_ps( cross_scales, sums, acc );
2296
- // acc_offset += m0*m1 (for each entry in the block)
2297
- acc_offset += (*m0)*(*m1);
2663
+ // Get absolute values of x vectors
2664
+ const __m256i ax = _mm256_sign_epi8( bx, bx );
2665
+
2666
+ // Sign the values of the y vectors
2667
+ const __m256i sy = _mm256_sign_epi8( by, bx );
2668
+
2669
+ // Perform multiplication and create 16-bit values
2670
+ const __m256i dot = _mm256_maddubs_epi16( ax, sy );
2671
+ const __m256i ones = _mm256_set1_epi16( 1 );
2672
+ const __m256i xy_q = _mm256_madd_epi16( ones, dot );
2673
+
2674
+ // Convert the vector of 8 int32_t to 8 floats
2675
+ const __m256 xy = _mm256_cvtepi32_ps( xy_q );
2676
+
2677
+ // Accumulate d0*d1*x*y
2678
+ acc = _mm256_fmadd_ps( d0d1, xy, acc );
2679
+
2680
+ // Compute sum of y values
2681
+ const __m256i y16_l = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( by ) );
2682
+ const __m256i y16_h = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( by, 1 ) );
2683
+ const __m256i ysumi = _mm256_madd_epi16( _mm256_add_epi16(y16_l, y16_h), ones );
2684
+ const __m256 ysum = _mm256_cvtepi32_ps( ysumi );
2685
+
2686
+ // Accumulate d1*m0*y
2687
+ acc = _mm256_fmadd_ps( d1m0, ysum, acc );
2298
2688
  }
2299
2689
 
2300
2690
  // Return horizontal sum of the acc vector
@@ -2303,131 +2693,379 @@ static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * rest
2303
2693
  res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
2304
2694
  res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
2305
2695
 
2306
- sumf = _mm_cvtss_f32( res ) + acc_offset * QK;
2307
- #elif defined(__ARM_NEON)
2308
- float sum00 = 0.0f;
2309
- float sum01 = 0.0f;
2310
- float sum10 = 0.0f;
2311
- float sum11 = 0.0f;
2696
+ sumf = _mm_cvtss_f32( res );
2697
+ #else
2698
+ // scalar
2699
+ for (int i = 0; i < nb; i++) {
2700
+ const float d0 = x[i].d;
2701
+ const float m0 = x[i].m;
2702
+ const float d1 = y[i].d;
2703
+
2704
+ const uint8_t * restrict p0 = x[i].qs;
2705
+ const int8_t * restrict p1 = y[i].qs;
2706
+
2707
+ // TODO: this is very slow ..
2708
+ for (int j = 0; j < QK8_0/2; j++) {
2709
+ const uint8_t v0 = p0[j];
2710
+
2711
+ const float f0 = d0*(v0 & 0xf) + m0;
2712
+ const float f1 = d0*(v0 >> 4) + m0;
2713
+
2714
+ const float f2 = d1*p1[2*j + 0];
2715
+ const float f3 = d1*p1[2*j + 1];
2716
+
2717
+ sumf += f0*f2 + f1*f3;
2718
+ }
2719
+ }
2720
+ #endif
2721
+
2722
+ *s = sumf;
2723
+ }
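
Both new SIMD paths in this kernel accumulate two pieces per block: a quantized dot product scaled by x->d*y->d, and a plain sum of the y values scaled by x->m*y->d. That split follows from the Q4_1 dequantization x = d*q + m with q in [0..15]. A standalone numeric check of the identity (illustration only; the names below are made up, nothing is lifted from the package):

    #include <assert.h>
    #include <math.h>
    #include <stdio.h>

    // With x[k] = d0*q[k] + m0 and y[k] = d1*p[k], the block dot product is
    //   sum_k x[k]*y[k] = d0*d1*sum(q*p) + m0*d1*sum(p)
    // which is exactly the pair of terms the NEON and AVX2 paths accumulate.
    int main(void) {
        const float d0 = 0.03f, m0 = -0.2f, d1 = 0.01f;
        int q[32], p[32];
        for (int k = 0; k < 32; k++) { q[k] = k % 16; p[k] = (k*7) % 31 - 15; }

        float direct = 0.0f;
        int qp = 0, sp = 0;
        for (int k = 0; k < 32; k++) {
            direct += (d0*q[k] + m0) * (d1*p[k]);
            qp += q[k]*p[k];
            sp += p[k];
        }
        const float split = d0*d1*qp + m0*d1*sp;

        printf("direct = %f, split = %f\n", direct, split);
        assert(fabsf(direct - split) < 1e-4f);
        return 0;
    }
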
2724
+
2725
+ static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
2726
+ const int nb = n / QK8_0;
2727
+
2728
+ assert(n % QK8_0 == 0);
2729
+ assert(nb % 2 == 0);
2730
+ assert(QK8_0 == 2*QK4_2);
2731
+
2732
+ const block_q4_2 * restrict x = vx;
2733
+ const block_q8_0 * restrict y = vy;
2734
+
2735
+ float sumf = 0.0;
2736
+
2737
+ #if defined(__ARM_NEON)
2738
+ float32x4_t sumv0 = vdupq_n_f32(0.0f);
2739
+ float32x4_t sumv1 = vdupq_n_f32(0.0f);
2312
2740
 
2313
2741
  for (int i = 0; i < nb; i += 2) {
2314
- const block_q4_1 * restrict x0 = &x[i + 0];
2315
- const block_q4_1 * restrict y0 = &y[i + 0];
2316
- const block_q4_1 * restrict x1 = &x[i + 1];
2317
- const block_q4_1 * restrict y1 = &y[i + 1];
2742
+ const block_q4_2 * restrict x0_0 = &x[2*(i + 0) + 0];
2743
+ const block_q4_2 * restrict x0_1 = &x[2*(i + 0) + 1];
2744
+ const block_q4_2 * restrict x1_0 = &x[2*(i + 1) + 0];
2745
+ const block_q4_2 * restrict x1_1 = &x[2*(i + 1) + 1];
2318
2746
 
2319
- const uint8x16_t m4b = vdupq_n_u8(0xf);
2747
+ const block_q8_0 * restrict y0 = &y[i + 0];
2748
+ const block_q8_0 * restrict y1 = &y[i + 1];
2320
2749
 
2321
- const uint8x16_t v0_0 = vld1q_u8(x0->qs);
2322
- const uint8x16_t v1_0 = vld1q_u8(y0->qs);
2323
- const uint8x16_t v0_1 = vld1q_u8(x1->qs);
2324
- const uint8x16_t v1_1 = vld1q_u8(y1->qs);
2750
+ const uint8x16_t m4b = vdupq_n_u8(0xf);
2751
+ const int8x16_t s8b = vdupq_n_s8(0x8);
2752
+
2753
+ const uint8x16_t v0_0 = vcombine_u8(vld1_u8(x0_0->qs), vld1_u8(x0_1->qs));
2754
+ const uint8x16_t v0_1 = vcombine_u8(vld1_u8(x1_0->qs), vld1_u8(x1_1->qs));
2325
2755
 
2326
2756
  // 4-bit -> 8-bit
2327
- const uint8x16_t v0_0l = vandq_u8(v0_0, m4b);
2328
- const uint8x16_t v1_0l = vandq_u8(v1_0, m4b);
2329
- const uint8x16_t v0_0h = vshrq_n_u8(v0_0, 4);
2330
- const uint8x16_t v1_0h = vshrq_n_u8(v1_0, 4);
2757
+ const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
2758
+ const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
2759
+ const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
2760
+ const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
2331
2761
 
2332
- const uint8x16_t v0_1l = vandq_u8(v0_1, m4b);
2333
- const uint8x16_t v1_1l = vandq_u8(v1_1, m4b);
2334
- const uint8x16_t v0_1h = vshrq_n_u8(v0_1, 4);
2335
- const uint8x16_t v1_1h = vshrq_n_u8(v1_1, 4);
2762
+ // sub 8
2763
+ const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
2764
+ const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
2765
+ const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
2766
+ const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
2336
2767
 
2337
- sum00 += x0->m*y0->m;
2338
- sum01 += y0->m*x0->d*(vaddvq_u8(v0_0l) + vaddvq_u8(v0_0h));
2339
- sum10 += x0->m*y0->d*(vaddvq_u8(v1_0l) + vaddvq_u8(v1_0h));
2768
+ // interleave
2769
+ const int8x16_t v0_0lz = vzip1q_s8(v0_0ls, v0_0hs);
2770
+ const int8x16_t v0_0hz = vzip2q_s8(v0_0ls, v0_0hs);
2771
+ const int8x16_t v0_1lz = vzip1q_s8(v0_1ls, v0_1hs);
2772
+ const int8x16_t v0_1hz = vzip2q_s8(v0_1ls, v0_1hs);
2340
2773
 
2341
- sum00 += x1->m*y1->m;
2342
- sum01 += y1->m*x1->d*(vaddvq_u8(v0_1l) + vaddvq_u8(v0_1h));
2343
- sum10 += x1->m*y1->d*(vaddvq_u8(v1_1l) + vaddvq_u8(v1_1h));
2774
+ // load y
2775
+ const int8x16_t v1_0l = vld1q_s8(y0->qs);
2776
+ const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
2777
+ const int8x16_t v1_1l = vld1q_s8(y1->qs);
2778
+ const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
2344
2779
 
2345
2780
  #if defined(__ARM_FEATURE_DOTPROD)
2346
- // dot product into int32x4_t
2347
- uint32x4_t p_0 = vdotq_u32(vdupq_n_u32(0), v0_0l, v1_0l);
2348
- uint32x4_t p_1 = vdotq_u32(vdupq_n_u32(0), v0_1l, v1_1l);
2781
+ sumv0 = vmlaq_n_f32(sumv0, vaddq_f32(
2782
+ vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0lz, v1_0l)), GGML_FP16_TO_FP32(x0_0->d)),
2783
+ vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0hz, v1_0h)), GGML_FP16_TO_FP32(x0_1->d))), y0->d);
2349
2784
 
2350
- p_0 = vdotq_u32(p_0, v0_0h, v1_0h);
2351
- p_1 = vdotq_u32(p_1, v0_1h, v1_1h);
2352
-
2353
- sum11 += x0->d*y0->d*vaddvq_u32(p_0);
2354
- sum11 += x1->d*y1->d*vaddvq_u32(p_1);
2785
+ sumv1 = vmlaq_n_f32(sumv1, vaddq_f32(
2786
+ vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_1lz, v1_1l)), GGML_FP16_TO_FP32(x1_0->d)),
2787
+ vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_1hz, v1_1h)), GGML_FP16_TO_FP32(x1_1->d))), y1->d);
2355
2788
  #else
2356
- const uint16x8_t pl0l = vmull_u8(vget_low_u8 (v0_0l), vget_low_u8 (v1_0l));
2357
- const uint16x8_t pl0h = vmull_u8(vget_high_u8(v0_0l), vget_high_u8(v1_0l));
2358
- const uint16x8_t ph0l = vmull_u8(vget_low_u8 (v0_0h), vget_low_u8 (v1_0h));
2359
- const uint16x8_t ph0h = vmull_u8(vget_high_u8(v0_0h), vget_high_u8(v1_0h));
2789
+ const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lz), vget_low_s8 (v1_0l));
2790
+ const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lz), vget_high_s8(v1_0l));
2791
+ const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hz), vget_low_s8 (v1_0h));
2792
+ const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hz), vget_high_s8(v1_0h));
2793
+
2794
+ const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lz), vget_low_s8 (v1_1l));
2795
+ const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lz), vget_high_s8(v1_1l));
2796
+ const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hz), vget_low_s8 (v1_1h));
2797
+ const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hz), vget_high_s8(v1_1h));
2798
+
2799
+ const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
2800
+ const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
2801
+ const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
2802
+ const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
2803
+
2804
+ sumv0 = vmlaq_n_f32(sumv0, vaddq_f32(
2805
+ vmulq_n_f32(vcvtq_f32_s32(pl0), GGML_FP16_TO_FP32(x0_0->d)),
2806
+ vmulq_n_f32(vcvtq_f32_s32(ph0), GGML_FP16_TO_FP32(x0_1->d))), y0->d);
2807
+
2808
+ sumv1 = vmlaq_n_f32(sumv1, vaddq_f32(
2809
+ vmulq_n_f32(vcvtq_f32_s32(pl1), GGML_FP16_TO_FP32(x1_0->d)),
2810
+ vmulq_n_f32(vcvtq_f32_s32(ph1), GGML_FP16_TO_FP32(x1_1->d))), y1->d);
2811
+ #endif
2812
+ }
2813
+
2814
+ sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
2815
+ #elif defined(__AVX2__)
2816
+ // Initialize accumulator with zeros
2817
+ __m256 acc = _mm256_setzero_ps();
2818
+
2819
+ // Main loop
2820
+ for (int i = 0; i < nb; i++) {
2821
+ /* Compute combined scale for the block */
2822
+ const __m128 d0 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 0].d));
2823
+ const __m128 d1 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 1].d));
2824
+ const __m256 d = _mm256_mul_ps(_mm256_set_m128(d1, d0), _mm256_broadcast_ss(&y[i].d));
2360
2825
 
2361
- const uint16x8_t pl1l = vmull_u8(vget_low_u8 (v0_1l), vget_low_u8 (v1_1l));
2362
- const uint16x8_t pl1h = vmull_u8(vget_high_u8(v0_1l), vget_high_u8(v1_1l));
2363
- const uint16x8_t ph1l = vmull_u8(vget_low_u8 (v0_1h), vget_low_u8 (v1_1h));
2364
- const uint16x8_t ph1h = vmull_u8(vget_high_u8(v0_1h), vget_high_u8(v1_1h));
2826
+ __m128i bx0 = bytes_from_nibbles_16(x[2*i + 0].qs);
2827
+ __m128i bx1 = bytes_from_nibbles_16(x[2*i + 1].qs);
2828
+ __m256i bx = _mm256_set_m128i(bx1, bx0);
2365
2829
 
2366
- const uint16x8_t pl_0 = vaddq_u16(pl0l, pl0h);
2367
- const uint16x8_t ph_0 = vaddq_u16(ph0l, ph0h);
2830
+ // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
2831
+ const __m256i off = _mm256_set1_epi8(8);
2832
+ bx = _mm256_sub_epi8(bx, off);
2368
2833
 
2369
- const uint16x8_t pl_1 = vaddq_u16(pl1l, pl1h);
2370
- const uint16x8_t ph_1 = vaddq_u16(ph1l, ph1h);
2834
+ __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
2371
2835
 
2372
- const uint16x8_t p_0 = vaddq_u16(pl_0, ph_0);
2373
- const uint16x8_t p_1 = vaddq_u16(pl_1, ph_1);
2836
+ // Get absolute values of x vectors
2837
+ const __m256i ax = _mm256_sign_epi8(bx, bx);
2838
+ // Sign the values of the y vectors
2839
+ const __m256i sy = _mm256_sign_epi8(by, bx);
2840
+ // Perform multiplication and create 16-bit values
2841
+ const __m256i dot = _mm256_maddubs_epi16(ax, sy);
2374
2842
 
2375
- sum11 += x0->d*y0->d*vaddvq_u16(p_0);
2376
- sum11 += x1->d*y1->d*vaddvq_u16(p_1);
2377
- #endif
2843
+ const __m256i ones = _mm256_set1_epi16(1);
2844
+ __m256i xy_q = _mm256_madd_epi16(ones, dot);
2845
+
2846
+ /* Convert the vector of 8 int32_t to 8 floats */
2847
+ __m256 q = _mm256_cvtepi32_ps(xy_q);
2848
+
2849
+ /* Multiply q with scale and accumulate */
2850
+ acc = _mm256_fmadd_ps(d, q, acc);
2378
2851
  }
2379
2852
 
2380
- sumf = QK*sum00 + sum01 + sum10 + sum11;
2853
+ // Return horizontal sum of the acc vector
2854
+ __m128 res = _mm256_extractf128_ps(acc, 1);
2855
+ res = _mm_add_ps(res, _mm256_castps256_ps128(acc));
2856
+ res = _mm_add_ps(res, _mm_movehl_ps(res, res));
2857
+ res = _mm_add_ss(res, _mm_movehdup_ps(res));
2858
+
2859
+ sumf = _mm_cvtss_f32(res);
2381
2860
  #else
2382
2861
  // scalar
2383
2862
  for (int i = 0; i < nb; i++) {
2384
- const float d0 = x[i].d;
2385
- const float d1 = y[i].d;
2863
+ const uint8_t * restrict x0 = x[2*i + 0].qs;
2864
+ const uint8_t * restrict x1 = x[2*i + 1].qs;
2865
+ const int8_t * restrict y0 = y[i].qs;
2386
2866
 
2387
- const float m0 = x[i].m;
2388
- const float m1 = y[i].m;
2867
+ const float d0 = GGML_FP16_TO_FP32(x[2*i + 0].d);
2868
+ const float d1 = GGML_FP16_TO_FP32(x[2*i + 1].d);
2389
2869
 
2390
- const uint8_t * restrict p0 = x[i].qs;
2391
- const uint8_t * restrict p1 = y[i].qs;
2870
+ int sumi_0 = 0;
2871
+ int sumi_1 = 0;
2392
2872
 
2393
- for (int j = 0; j < QK/2; j++) {
2394
- const uint8_t v0 = p0[j];
2395
- const uint8_t v1 = p1[j];
2873
+ for (int j = 0; j < QK8_0/4; j++) {
2874
+ const uint8_t v0 = x0[j];
2875
+ const uint8_t v1 = x1[j];
2396
2876
 
2397
- const float f0 = d0*(v0 & 0xf) + m0;
2398
- const float f1 = d0*(v0 >> 4) + m0;
2877
+ const int i0_0 = (int8_t) (v0 & 0xf) - 8;
2878
+ const int i1_0 = (int8_t) (v0 >> 4) - 8;
2399
2879
 
2400
- const float f2 = d1*(v1 & 0xf) + m1;
2401
- const float f3 = d1*(v1 >> 4) + m1;
2880
+ const int i0_1 = (int8_t) (v1 & 0xf) - 8;
2881
+ const int i1_1 = (int8_t) (v1 >> 4) - 8;
2402
2882
 
2403
- sumf += f0*f2 + f1*f3;
2883
+ const int i2_0 = y0[2*j + 0];
2884
+ const int i3_0 = y0[2*j + 1];
2885
+
2886
+ const int i2_1 = y0[2*(j + QK8_0/4) + 0];
2887
+ const int i3_1 = y0[2*(j + QK8_0/4) + 1];
2888
+
2889
+ sumi_0 += i0_0*i2_0 + i1_0*i3_0;
2890
+ sumi_1 += i0_1*i2_1 + i1_1*i3_1;
2404
2891
  }
2892
+
2893
+ sumf += (d0 * y[i].d) * sumi_0;
2894
+ sumf += (d1 * y[i].d) * sumi_1;
2405
2895
  }
2406
2896
  #endif
2407
2897
 
2408
2898
  *s = sumf;
2409
2899
  }
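
The AVX2 paths for the 4-bit-minus-8 formats rely on one trick: _mm256_maddubs_epi16 multiplies unsigned bytes from its first operand by signed bytes from its second, so the kernel feeds it the absolute values of x (_mm256_sign_epi8(bx, bx)) and a copy of y carrying x's sign (_mm256_sign_epi8(by, bx)); with x in [-8, 7] the pairwise 16-bit sums cannot saturate. The scalar identity behind it, checked exhaustively over the relevant ranges (standalone illustration, no intrinsics):

    #include <assert.h>
    #include <stdlib.h>

    // |x| * (y carrying x's sign, or 0 when x == 0) equals x*y, which is why
    // an unsigned-by-signed multiply can stand in for a signed-by-signed one.
    static int product_via_unsigned(int x, int y) {
        const int ax = abs(x);                         // unsigned operand
        const int sy = (x > 0) ? y : (x < 0) ? -y : 0; // sign of x moved onto y
        return ax * sy;
    }

    int main(void) {
        for (int x = -8; x <= 7; x++) {
            for (int y = -128; y <= 127; y++) {
                assert(product_via_unsigned(x, y) == x*y);
            }
        }
        return 0;
    }
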
2410
2900
 
2411
- // compute GGML_VEC_DOT_UNROLL dot products at once
2412
- // xs - x row stride in bytes
2413
- inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * restrict s, void * restrict xv, ggml_fp16_t * restrict y) {
2414
- ggml_float sumf[GGML_VEC_DOT_UNROLL] = { 0.0 };
2901
+ static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
2902
+ const int nb = n / QK8_0;
2415
2903
 
2416
- ggml_fp16_t * restrict x[GGML_VEC_DOT_UNROLL];
2904
+ assert(n % QK8_0 == 0);
2905
+ assert(nb % 2 == 0);
2906
+ assert(QK8_0 == 2*QK4_2);
2417
2907
 
2418
- for (int i = 0; i < GGML_VEC_DOT_UNROLL; ++i) {
2419
- x[i] = (ggml_fp16_t *) ((char *) xv + i*xs);
2420
- }
2908
+ const block_q4_3 * restrict x = vx;
2909
+ const block_q8_0 * restrict y = vy;
2421
2910
 
2422
- #if defined(GGML_SIMD)
2423
- const int np = (n & ~(GGML_F16_STEP - 1));
2911
+ float sumf = 0.0;
2424
2912
 
2425
- GGML_F16_VEC sum[GGML_VEC_DOT_UNROLL][GGML_F16_ARR] = { { GGML_F16_VEC_ZERO } };
2913
+ #if defined(__ARM_NEON)
2914
+ float32x4_t sumv0 = vdupq_n_f32(0.0f);
2915
+ float32x4_t sumv1 = vdupq_n_f32(0.0f);
2426
2916
 
2427
- GGML_F16_VEC ax[GGML_F16_ARR];
2428
- GGML_F16_VEC ay[GGML_F16_ARR];
2917
+ for (int i = 0; i < nb; i += 2) {
2918
+ const block_q4_3 * restrict x0_0 = &x[2*(i + 0) + 0];
2919
+ const block_q4_3 * restrict x0_1 = &x[2*(i + 0) + 1];
2920
+ const block_q4_3 * restrict x1_0 = &x[2*(i + 1) + 0];
2921
+ const block_q4_3 * restrict x1_1 = &x[2*(i + 1) + 1];
2429
2922
 
2430
- for (int i = 0; i < np; i += GGML_F16_STEP) {
2923
+ const block_q8_0 * restrict y0 = &y[i + 0];
2924
+ const block_q8_0 * restrict y1 = &y[i + 1];
2925
+
2926
+ const uint8x16_t m4b = vdupq_n_u8(0xf);
2927
+
2928
+ const float x0_0d = GGML_FP16_TO_FP32(x0_0->d);
2929
+ const float x0_1d = GGML_FP16_TO_FP32(x0_1->d);
2930
+ const float x1_0d = GGML_FP16_TO_FP32(x1_0->d);
2931
+ const float x1_1d = GGML_FP16_TO_FP32(x1_1->d);
2932
+
2933
+ const float x0_0m = GGML_FP16_TO_FP32(x0_0->m);
2934
+ const float x0_1m = GGML_FP16_TO_FP32(x0_1->m);
2935
+ const float x1_0m = GGML_FP16_TO_FP32(x1_0->m);
2936
+ const float x1_1m = GGML_FP16_TO_FP32(x1_1->m);
2937
+
2938
+ const uint8x16_t v0_0 = vcombine_u8(vld1_u8(x0_0->qs), vld1_u8(x0_1->qs));
2939
+ const uint8x16_t v0_1 = vcombine_u8(vld1_u8(x1_0->qs), vld1_u8(x1_1->qs));
2940
+
2941
+ // 4-bit -> 8-bit
2942
+ const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
2943
+ const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
2944
+ const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
2945
+ const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
2946
+
2947
+ // interleave
2948
+ const int8x16_t v0_0lz = vzip1q_s8(v0_0l, v0_0h);
2949
+ const int8x16_t v0_0hz = vzip2q_s8(v0_0l, v0_0h);
2950
+ const int8x16_t v0_1lz = vzip1q_s8(v0_1l, v0_1h);
2951
+ const int8x16_t v0_1hz = vzip2q_s8(v0_1l, v0_1h);
2952
+
2953
+ // load y
2954
+ const int8x16_t v1_0l = vld1q_s8(y0->qs);
2955
+ const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
2956
+ const int8x16_t v1_1l = vld1q_s8(y1->qs);
2957
+ const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
2958
+
2959
+ const int16x8_t sy0_0 = vaddq_s16(vmovl_s8(vget_low_s8(v1_0l)), vmovl_s8(vget_high_s8(v1_0l)));
2960
+ const int16x8_t sy0_1 = vaddq_s16(vmovl_s8(vget_low_s8(v1_0h)), vmovl_s8(vget_high_s8(v1_0h)));
2961
+
2962
+ const int16x8_t sy1_0 = vaddq_s16(vmovl_s8(vget_low_s8(v1_1l)), vmovl_s8(vget_high_s8(v1_1l)));
2963
+ const int16x8_t sy1_1 = vaddq_s16(vmovl_s8(vget_low_s8(v1_1h)), vmovl_s8(vget_high_s8(v1_1h)));
2964
+
2965
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddl_s16(vget_low_s16(sy0_0), vget_high_s16(sy0_0))), x0_0m*y0->d);
2966
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddl_s16(vget_low_s16(sy0_1), vget_high_s16(sy0_1))), x0_1m*y0->d);
2967
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddl_s16(vget_low_s16(sy1_0), vget_high_s16(sy1_0))), x1_0m*y1->d);
2968
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddl_s16(vget_low_s16(sy1_1), vget_high_s16(sy1_1))), x1_1m*y1->d);
2969
+
2970
+ #if defined(__ARM_FEATURE_DOTPROD)
2971
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0lz, v1_0l)), x0_0d*y0->d);
2972
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0hz, v1_0h)), x0_1d*y0->d);
2973
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_1lz, v1_1l)), x1_0d*y1->d);
2974
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_1hz, v1_1h)), x1_1d*y1->d);
2975
+ #else
2976
+ const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lz), vget_low_s8 (v1_0l));
2977
+ const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lz), vget_high_s8(v1_0l));
2978
+ const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hz), vget_low_s8 (v1_0h));
2979
+ const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hz), vget_high_s8(v1_0h));
2980
+
2981
+ const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lz), vget_low_s8 (v1_1l));
2982
+ const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lz), vget_high_s8(v1_1l));
2983
+ const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hz), vget_low_s8 (v1_1h));
2984
+ const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hz), vget_high_s8(v1_1h));
2985
+
2986
+ const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
2987
+ const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
2988
+ const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
2989
+ const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
2990
+
2991
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(pl0), x0_0d*y0->d);
2992
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(ph0), x0_1d*y0->d);
2993
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(pl1), x1_0d*y1->d);
2994
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(ph1), x1_1d*y1->d);
2995
+ #endif
2996
+ }
2997
+
2998
+ sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
2999
+ #else
3000
+ // scalar
3001
+ for (int i = 0; i < nb; i++) {
3002
+ const uint8_t * restrict x0 = x[2*i + 0].qs;
3003
+ const uint8_t * restrict x1 = x[2*i + 1].qs;
3004
+ const int8_t * restrict y0 = y[i].qs;
3005
+
3006
+ const float d0 = GGML_FP16_TO_FP32(x[2*i + 0].d);
3007
+ const float m0 = GGML_FP16_TO_FP32(x[2*i + 0].m);
3008
+ const float d1 = GGML_FP16_TO_FP32(x[2*i + 1].d);
3009
+ const float m1 = GGML_FP16_TO_FP32(x[2*i + 1].m);
3010
+
3011
+ int sy_0 = 0;
3012
+ int sy_1 = 0;
3013
+
3014
+ int sxy_0 = 0;
3015
+ int sxy_1 = 0;
3016
+
3017
+ for (int j = 0; j < QK8_0/4; j++) {
3018
+ const uint8_t v0 = x0[j];
3019
+ const uint8_t v1 = x1[j];
3020
+
3021
+ const int x0_0 = v0 & 0xf;
3022
+ const int x1_0 = v0 >> 4;
3023
+
3024
+ const int x0_1 = v1 & 0xf;
3025
+ const int x1_1 = v1 >> 4;
3026
+
3027
+ const int y0_0 = y0[2*j + 0];
3028
+ const int y1_0 = y0[2*j + 1];
3029
+
3030
+ const int y0_1 = y0[2*(j + QK8_0/4) + 0];
3031
+ const int y1_1 = y0[2*(j + QK8_0/4) + 1];
3032
+
3033
+ sy_0 += y0_0 + y1_0;
3034
+ sy_1 += y0_1 + y1_1;
3035
+
3036
+ sxy_0 += x0_0*y0_0 + x1_0*y1_0;
3037
+ sxy_1 += x0_1*y0_1 + x1_1*y1_1;
3038
+ }
3039
+
3040
+ sumf += (d0*sxy_0 + m0*sy_0)*y[i].d;
3041
+ sumf += (d1*sxy_1 + m1*sy_1)*y[i].d;
3042
+ }
3043
+ #endif
3044
+
3045
+ *s = sumf;
3046
+ }
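
The scalar path above folds each Q8_0 block's scale in once per 16-wide sub-block: with x = d*q + m (q in [0..15]) and y = y_d*p, a sub-block contributes (d*sum(q*p) + m*sum(p))*y_d, which is the (d0*sxy_0 + m0*sy_0)*y[i].d form used above. A standalone check over one full 32-value y block and its two sub-blocks (illustration only, made-up values):

    #include <assert.h>
    #include <math.h>
    #include <stdio.h>

    // One Q8_0-style block y[k] = yd*p[k] paired with two 16-wide Q4_3-style
    // sub-blocks, each dequantizing as x = d*q + m with q in [0..15].
    int main(void) {
        const float d0 = 0.07f, m0 = 0.4f, d1 = 0.05f, m1 = -0.3f, yd = 0.02f;
        int q[32], p[32];
        for (int k = 0; k < 32; k++) { q[k] = (k*5) % 16; p[k] = (k*11) % 25 - 12; }

        float direct = 0.0f;
        int sxy_0 = 0, sy_0 = 0, sxy_1 = 0, sy_1 = 0;
        for (int k = 0; k < 32; k++) {
            const float d = (k < 16) ? d0 : d1;
            const float m = (k < 16) ? m0 : m1;
            direct += (d*q[k] + m) * (yd*p[k]);
            if (k < 16) { sxy_0 += q[k]*p[k]; sy_0 += p[k]; }
            else        { sxy_1 += q[k]*p[k]; sy_1 += p[k]; }
        }
        const float grouped = (d0*sxy_0 + m0*sy_0)*yd + (d1*sxy_1 + m1*sy_1)*yd;

        printf("direct = %f, grouped = %f\n", direct, grouped);
        assert(fabsf(direct - grouped) < 1e-4f);
        return 0;
    }
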
3047
+
3048
+
3049
+ // compute GGML_VEC_DOT_UNROLL dot products at once
3050
+ // xs - x row stride in bytes
3051
+ inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * restrict s, void * restrict xv, ggml_fp16_t * restrict y) {
3052
+ ggml_float sumf[GGML_VEC_DOT_UNROLL] = { 0.0 };
3053
+
3054
+ ggml_fp16_t * restrict x[GGML_VEC_DOT_UNROLL];
3055
+
3056
+ for (int i = 0; i < GGML_VEC_DOT_UNROLL; ++i) {
3057
+ x[i] = (ggml_fp16_t *) ((char *) xv + i*xs);
3058
+ }
3059
+
3060
+ #if defined(GGML_SIMD)
3061
+ const int np = (n & ~(GGML_F16_STEP - 1));
3062
+
3063
+ GGML_F16_VEC sum[GGML_VEC_DOT_UNROLL][GGML_F16_ARR] = { { GGML_F16_VEC_ZERO } };
3064
+
3065
+ GGML_F16_VEC ax[GGML_F16_ARR];
3066
+ GGML_F16_VEC ay[GGML_F16_ARR];
3067
+
3068
+ for (int i = 0; i < np; i += GGML_F16_STEP) {
2431
3069
  for (int j = 0; j < GGML_F16_ARR; j++) {
2432
3070
  ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
2433
3071
 
@@ -2652,24 +3290,30 @@ inline static void ggml_vec_norm_inv_f32(const int n, float * s, const float * x
2652
3290
  static const int GGML_BLCK_SIZE[GGML_TYPE_COUNT] = {
2653
3291
  [GGML_TYPE_F32] = 1,
2654
3292
  [GGML_TYPE_F16] = 1,
2655
- [GGML_TYPE_Q4_0] = QK,
2656
- [GGML_TYPE_Q4_1] = QK,
3293
+ [GGML_TYPE_Q4_0] = QK4_0,
3294
+ [GGML_TYPE_Q4_1] = QK4_1,
3295
+ [GGML_TYPE_Q4_2] = QK4_2,
3296
+ [GGML_TYPE_Q4_3] = QK4_3,
3297
+ [GGML_TYPE_Q8_0] = QK8_0,
2657
3298
  [GGML_TYPE_I8] = 1,
2658
3299
  [GGML_TYPE_I16] = 1,
2659
3300
  [GGML_TYPE_I32] = 1,
2660
3301
  };
2661
- static_assert(GGML_TYPE_COUNT == 7, "GGML_BLCK_SIZE is outdated");
3302
+ static_assert(GGML_TYPE_COUNT == 10, "GGML_BLCK_SIZE is outdated");
2662
3303
 
2663
3304
  static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = {
2664
3305
  [GGML_TYPE_F32] = sizeof(float),
2665
3306
  [GGML_TYPE_F16] = sizeof(ggml_fp16_t),
2666
3307
  [GGML_TYPE_Q4_0] = sizeof(block_q4_0),
2667
3308
  [GGML_TYPE_Q4_1] = sizeof(block_q4_1),
3309
+ [GGML_TYPE_Q4_2] = sizeof(block_q4_2),
3310
+ [GGML_TYPE_Q4_3] = sizeof(block_q4_3),
3311
+ [GGML_TYPE_Q8_0] = sizeof(block_q8_0),
2668
3312
  [GGML_TYPE_I8] = sizeof(int8_t),
2669
3313
  [GGML_TYPE_I16] = sizeof(int16_t),
2670
3314
  [GGML_TYPE_I32] = sizeof(int32_t),
2671
3315
  };
2672
- static_assert(GGML_TYPE_COUNT == 7, "GGML_TYPE_SIZE is outdated");
3316
+ static_assert(GGML_TYPE_COUNT == 10, "GGML_TYPE_SIZE is outdated");
2673
3317
 
2674
3318
 
2675
3319
  static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = {
@@ -2677,11 +3321,28 @@ static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = {
2677
3321
  [GGML_TYPE_F16] = "f16",
2678
3322
  [GGML_TYPE_Q4_0] = "q4_0",
2679
3323
  [GGML_TYPE_Q4_1] = "q4_1",
3324
+ [GGML_TYPE_Q4_2] = "q4_2",
3325
+ [GGML_TYPE_Q4_3] = "q4_3",
3326
+ [GGML_TYPE_Q8_0] = "q8_0",
2680
3327
  [GGML_TYPE_I8] = "i8",
2681
3328
  [GGML_TYPE_I16] = "i16",
2682
3329
  [GGML_TYPE_I32] = "i32",
2683
3330
  };
2684
- static_assert(GGML_TYPE_COUNT == 7, "GGML_TYPE_NAME is outdated");
3331
+ static_assert(GGML_TYPE_COUNT == 10, "GGML_TYPE_NAME is outdated");
3332
+
3333
+ static bool GGML_IS_QUANTIZED[GGML_TYPE_COUNT] = {
3334
+ [GGML_TYPE_F32] = false,
3335
+ [GGML_TYPE_F16] = false,
3336
+ [GGML_TYPE_Q4_0] = true,
3337
+ [GGML_TYPE_Q4_1] = true,
3338
+ [GGML_TYPE_Q4_2] = true,
3339
+ [GGML_TYPE_Q4_3] = true,
3340
+ [GGML_TYPE_Q8_0] = true,
3341
+ [GGML_TYPE_I8] = false,
3342
+ [GGML_TYPE_I16] = false,
3343
+ [GGML_TYPE_I32] = false,
3344
+ };
3345
+ static_assert(GGML_TYPE_COUNT == 10, "GGML_IS_QUANTIZED is outdated");
2685
3346
 
2686
3347
  static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
2687
3348
  "NONE",
@@ -2943,6 +3604,10 @@ static inline bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct
2943
3604
  (t0->ne[3] == t1->ne[3]);
2944
3605
  }
2945
3606
 
3607
+ bool ggml_is_quantized(enum ggml_type type) {
3608
+ return GGML_IS_QUANTIZED[type];
3609
+ }
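
The per-type tables extended above (GGML_BLCK_SIZE, GGML_TYPE_SIZE, GGML_TYPE_NAME, GGML_IS_QUANTIZED) have to grow together because sizes of quantized tensors are derived from them: a row of ne elements occupies ne / GGML_BLCK_SIZE[type] blocks of GGML_TYPE_SIZE[type] bytes each. A small sketch of that relation; the helper name and the example numbers (32-element, 20-byte blocks for a 4-bit type with a float scale) are assumptions for illustration, not values read out of the package:

    #include <stdint.h>
    #include <stdio.h>

    // Hypothetical helper mirroring "row bytes = blocks per row * bytes per
    // block" as implied by the GGML_BLCK_SIZE / GGML_TYPE_SIZE tables.
    static size_t row_size_bytes(int64_t ne, int64_t blck_size, size_t type_size) {
        return (size_t) (ne / blck_size) * type_size;
    }

    int main(void) {
        // Assumed example: a 4096-wide row, 32 elements per block, 20 bytes
        // per block -> 128 blocks -> 2560 bytes per row.
        printf("%zu bytes per row\n", row_size_bytes(4096, 32, 20));
        return 0;
    }
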
3610
+
2946
3611
  static inline bool ggml_is_transposed(const struct ggml_tensor * tensor) {
2947
3612
  return tensor->nb[0] > tensor->nb[1];
2948
3613
  }
@@ -3053,6 +3718,11 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
3053
3718
  GGML_PRINT_DEBUG("%s: g_state initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f);
3054
3719
  }
3055
3720
 
3721
+ // initialize cuBLAS
3722
+ #if defined(GGML_USE_CUBLAS)
3723
+ init_cublas();
3724
+ #endif
3725
+
3056
3726
  is_first_call = false;
3057
3727
  }
3058
3728
 
@@ -3354,14 +4024,6 @@ struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value) {
3354
4024
  char * const data = tensor->data;
3355
4025
 
3356
4026
  switch (tensor->type) {
3357
- case GGML_TYPE_Q4_0:
3358
- {
3359
- GGML_ASSERT(false);
3360
- } break;
3361
- case GGML_TYPE_Q4_1:
3362
- {
3363
- GGML_ASSERT(false);
3364
- } break;
3365
4027
  case GGML_TYPE_I8:
3366
4028
  {
3367
4029
  assert(tensor->nb[0] == sizeof(int8_t));
@@ -3397,7 +4059,7 @@ struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value) {
3397
4059
  ggml_vec_set_f32(nc, (float *)(data + i*n1), value);
3398
4060
  }
3399
4061
  } break;
3400
- case GGML_TYPE_COUNT:
4062
+ default:
3401
4063
  {
3402
4064
  GGML_ASSERT(false);
3403
4065
  } break;
@@ -3414,14 +4076,6 @@ struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value) {
3414
4076
  char * const data = tensor->data;
3415
4077
 
3416
4078
  switch (tensor->type) {
3417
- case GGML_TYPE_Q4_0:
3418
- {
3419
- GGML_ASSERT(false);
3420
- } break;
3421
- case GGML_TYPE_Q4_1:
3422
- {
3423
- GGML_ASSERT(false);
3424
- } break;
3425
4079
  case GGML_TYPE_I8:
3426
4080
  {
3427
4081
  assert(tensor->nb[0] == sizeof(int8_t));
@@ -3457,7 +4111,7 @@ struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value) {
3457
4111
  ggml_vec_set_f32(nc, (float *)(data + i*n1), value);
3458
4112
  }
3459
4113
  } break;
3460
- case GGML_TYPE_COUNT:
4114
+ default:
3461
4115
  {
3462
4116
  GGML_ASSERT(false);
3463
4117
  } break;
@@ -3468,14 +4122,6 @@ struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value) {
3468
4122
 
3469
4123
  int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i) {
3470
4124
  switch (tensor->type) {
3471
- case GGML_TYPE_Q4_0:
3472
- {
3473
- GGML_ASSERT(false);
3474
- } break;
3475
- case GGML_TYPE_Q4_1:
3476
- {
3477
- GGML_ASSERT(false);
3478
- } break;
3479
4125
  case GGML_TYPE_I8:
3480
4126
  {
3481
4127
  GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
@@ -3501,7 +4147,7 @@ int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i) {
3501
4147
  GGML_ASSERT(tensor->nb[0] == sizeof(float));
3502
4148
  return ((float *)(tensor->data))[i];
3503
4149
  } break;
3504
- case GGML_TYPE_COUNT:
4150
+ default:
3505
4151
  {
3506
4152
  GGML_ASSERT(false);
3507
4153
  } break;
@@ -3512,14 +4158,6 @@ int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i) {
3512
4158
 
3513
4159
  void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value) {
3514
4160
  switch (tensor->type) {
3515
- case GGML_TYPE_Q4_0:
3516
- {
3517
- GGML_ASSERT(false);
3518
- } break;
3519
- case GGML_TYPE_Q4_1:
3520
- {
3521
- GGML_ASSERT(false);
3522
- } break;
3523
4161
  case GGML_TYPE_I8:
3524
4162
  {
3525
4163
  GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
@@ -3545,7 +4183,7 @@ void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value) {
3545
4183
  GGML_ASSERT(tensor->nb[0] == sizeof(float));
3546
4184
  ((float *)(tensor->data))[i] = value;
3547
4185
  } break;
3548
- case GGML_TYPE_COUNT:
4186
+ default:
3549
4187
  {
3550
4188
  GGML_ASSERT(false);
3551
4189
  } break;
@@ -3554,14 +4192,6 @@ void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value) {
3554
4192
 
3555
4193
  float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i) {
3556
4194
  switch (tensor->type) {
3557
- case GGML_TYPE_Q4_0:
3558
- {
3559
- GGML_ASSERT(false);
3560
- } break;
3561
- case GGML_TYPE_Q4_1:
3562
- {
3563
- GGML_ASSERT(false);
3564
- } break;
3565
4195
  case GGML_TYPE_I8:
3566
4196
  {
3567
4197
  GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
@@ -3587,7 +4217,7 @@ float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i) {
3587
4217
  GGML_ASSERT(tensor->nb[0] == sizeof(float));
3588
4218
  return ((float *)(tensor->data))[i];
3589
4219
  } break;
3590
- case GGML_TYPE_COUNT:
4220
+ default:
3591
4221
  {
3592
4222
  GGML_ASSERT(false);
3593
4223
  } break;
@@ -3598,14 +4228,6 @@ float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i) {
3598
4228
 
3599
4229
  void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value) {
3600
4230
  switch (tensor->type) {
3601
- case GGML_TYPE_Q4_0:
3602
- {
3603
- GGML_ASSERT(false);
3604
- } break;
3605
- case GGML_TYPE_Q4_1:
3606
- {
3607
- GGML_ASSERT(false);
3608
- } break;
3609
4231
  case GGML_TYPE_I8:
3610
4232
  {
3611
4233
  GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
@@ -3631,7 +4253,7 @@ void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value) {
3631
4253
  GGML_ASSERT(tensor->nb[0] == sizeof(float));
3632
4254
  ((float *)(tensor->data))[i] = value;
3633
4255
  } break;
3634
- case GGML_TYPE_COUNT:
4256
+ default:
3635
4257
  {
3636
4258
  GGML_ASSERT(false);
3637
4259
  } break;
@@ -5031,7 +5653,6 @@ static void ggml_compute_forward_dup_f16(
5031
5653
  const struct ggml_compute_params * params,
5032
5654
  const struct ggml_tensor * src0,
5033
5655
  struct ggml_tensor * dst) {
5034
- GGML_ASSERT(params->ith == 0);
5035
5656
  GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
5036
5657
 
5037
5658
  if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
@@ -5043,6 +5664,11 @@ static void ggml_compute_forward_dup_f16(
5043
5664
  const int64_t ne02 = src0->ne[2];
5044
5665
  const int64_t ne03 = src0->ne[3];
5045
5666
 
5667
+ const int64_t ne0 = dst->ne[0];
5668
+ const int64_t ne1 = dst->ne[1];
5669
+ const int64_t ne2 = dst->ne[2];
5670
+ const int64_t ne3 = dst->ne[3];
5671
+
5046
5672
  const size_t nb00 = src0->nb[0];
5047
5673
  const size_t nb01 = src0->nb[1];
5048
5674
  const size_t nb02 = src0->nb[2];
@@ -5053,19 +5679,40 @@ static void ggml_compute_forward_dup_f16(
5053
5679
  const size_t nb2 = dst->nb[2];
5054
5680
  const size_t nb3 = dst->nb[3];
5055
5681
 
5682
+ const int ith = params->ith; // thread index
5683
+ const int nth = params->nth; // number of threads
5684
+
5056
5685
  if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst) && src0->type == dst->type) {
5057
- memcpy(dst->data, src0->data, ggml_nelements(dst) * GGML_TYPE_SIZE[src0->type]);
5686
+ // parallelize by elements
5687
+ const int ne = ggml_nelements(dst);
5688
+ const int dr = (ne + nth - 1) / nth;
5689
+ const int ie0 = dr * ith;
5690
+ const int ie1 = MIN(ie0 + dr, ne);
5691
+
5692
+ memcpy(
5693
+ ((char *) dst->data + ie0*nb0),
5694
+ ((char *) src0->data + ie0*nb00),
5695
+ (ie1 - ie0) * GGML_TYPE_SIZE[src0->type]);
5696
+
5058
5697
  return;
5059
5698
  }
5060
5699
 
5700
+ // parallelize by rows
5701
+ const int nr = ne01;
5702
+ // number of rows per thread
5703
+ const int dr = (nr + nth - 1) / nth;
5704
+ // row range for this thread
5705
+ const int ir0 = dr * ith;
5706
+ const int ir1 = MIN(ir0 + dr, nr);
5707
+
5061
5708
  if (src0->type == dst->type &&
5062
- src0->ne[0] == dst->ne[0] &&
5063
- src0->nb[0] == GGML_TYPE_SIZE[src0->type] && dst->nb[0] == GGML_TYPE_SIZE[dst->type]) {
5709
+ ne00 == ne0 &&
5710
+ nb00 == GGML_TYPE_SIZE[src0->type] && nb0 == GGML_TYPE_SIZE[dst->type]) {
5064
5711
  // copy by rows
5065
5712
  const size_t rs = ne00*nb00;
5066
5713
  for (int64_t i03 = 0; i03 < ne03; i03++) {
5067
5714
  for (int64_t i02 = 0; i02 < ne02; i02++) {
5068
- for (int64_t i01 = 0; i01 < ne01; i01++) {
5715
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
5069
5716
  memcpy(
5070
5717
  ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3),
5071
5718
  ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
@@ -5079,21 +5726,21 @@ static void ggml_compute_forward_dup_f16(
5079
5726
  // TODO: add more special-case implementations for tensor shapes/strides that can benefit from memcpy
5080
5727
 
5081
5728
  if (ggml_is_contiguous(dst)) {
5082
- if (src0->nb[0] == sizeof(ggml_fp16_t)) {
5729
+ if (nb00 == sizeof(ggml_fp16_t)) {
5083
5730
  if (dst->type == GGML_TYPE_F16) {
5084
5731
  size_t id = 0;
5085
- const size_t rs = ne00*nb00;
5732
+ const size_t rs = ne00 * nb00;
5733
+ char * dst_ptr = (char *) dst->data;
5086
5734
 
5087
5735
  for (int i03 = 0; i03 < ne03; i03++) {
5088
5736
  for (int i02 = 0; i02 < ne02; i02++) {
5089
- for (int i01 = 0; i01 < ne01; i01++) {
5737
+ id += rs * ir0;
5738
+ for (int i01 = ir0; i01 < ir1; i01++) {
5090
5739
  const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
5091
- char * dst_ptr = (char *) dst->data + id*rs;
5092
-
5093
- memcpy(dst_ptr, src0_ptr, rs);
5094
-
5095
- id++;
5740
+ memcpy(dst_ptr + id, src0_ptr, rs);
5741
+ id += rs;
5096
5742
  }
5743
+ id += rs * (ne01 - ir1);
5097
5744
  }
5098
5745
  }
5099
5746
  } else if (dst->type == GGML_TYPE_F32) {
@@ -5102,14 +5749,39 @@ static void ggml_compute_forward_dup_f16(
5102
5749
 
5103
5750
  for (int i03 = 0; i03 < ne03; i03++) {
5104
5751
  for (int i02 = 0; i02 < ne02; i02++) {
5105
- for (int i01 = 0; i01 < ne01; i01++) {
5752
+ id += ne00 * ir0;
5753
+ for (int i01 = ir0; i01 < ir1; i01++) {
5754
+ const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
5106
5755
  for (int i00 = 0; i00 < ne00; i00++) {
5107
- const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
5108
-
5109
- dst_ptr[id] = GGML_FP16_TO_FP32(*src0_ptr);
5756
+ dst_ptr[id] = GGML_FP16_TO_FP32(src0_ptr[i00]);
5110
5757
  id++;
5111
5758
  }
5112
5759
  }
5760
+ id += ne00 * (ne01 - ir1);
5761
+ }
5762
+ }
5763
+ } else if (ggml_is_quantized(dst->type)) {
5764
+ quantize_row_q_t const quantize_row_q = quantize_fns[dst->type].quantize_row_q;
5765
+ float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
5766
+
5767
+ size_t id = 0;
5768
+ size_t rs = nb0 * (ne00 / GGML_BLCK_SIZE[dst->type]);
5769
+ char * dst_ptr = (char *) dst->data;
5770
+
5771
+ for (int i03 = 0; i03 < ne03; i03++) {
5772
+ for (int i02 = 0; i02 < ne02; i02++) {
5773
+ id += rs * ir0;
5774
+ for (int i01 = ir0; i01 < ir1; i01++) {
5775
+ const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
5776
+
5777
+ for (int i00 = 0; i00 < ne00; i00++) {
5778
+ src0_f32[i00] = GGML_FP16_TO_FP32(src0_ptr[i00]);
5779
+ }
5780
+
5781
+ quantize_row_q(src0_f32, dst_ptr + id, ne00);
5782
+ id += rs;
5783
+ }
5784
+ id += rs * (ne01 - ir1);
5113
5785
  }
5114
5786
  }
5115
5787
  } else {
@@ -5124,7 +5796,8 @@ static void ggml_compute_forward_dup_f16(
5124
5796
 
5125
5797
  for (int i03 = 0; i03 < ne03; i03++) {
5126
5798
  for (int i02 = 0; i02 < ne02; i02++) {
5127
- for (int i01 = 0; i01 < ne01; i01++) {
5799
+ id += ne00 * ir0;
5800
+ for (int i01 = ir0; i01 < ir1; i01++) {
5128
5801
  for (int i00 = 0; i00 < ne00; i00++) {
5129
5802
  const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
5130
5803
 
@@ -5132,6 +5805,7 @@ static void ggml_compute_forward_dup_f16(
5132
5805
  id++;
5133
5806
  }
5134
5807
  }
5808
+ id += ne00 * (ne01 - ir1);
5135
5809
  }
5136
5810
  }
5137
5811
  } else if (dst->type == GGML_TYPE_F16) {
@@ -5140,7 +5814,8 @@ static void ggml_compute_forward_dup_f16(
5140
5814
 
5141
5815
  for (int i03 = 0; i03 < ne03; i03++) {
5142
5816
  for (int i02 = 0; i02 < ne02; i02++) {
5143
- for (int i01 = 0; i01 < ne01; i01++) {
5817
+ id += ne00 * ir0;
5818
+ for (int i01 = ir0; i01 < ir1; i01++) {
5144
5819
  for (int i00 = 0; i00 < ne00; i00++) {
5145
5820
  const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
5146
5821
 
@@ -5148,6 +5823,7 @@ static void ggml_compute_forward_dup_f16(
5148
5823
  id++;
5149
5824
  }
5150
5825
  }
5826
+ id += ne00 * (ne01 - ir1);
5151
5827
  }
5152
5828
  }
5153
5829
  } else {
@@ -5166,7 +5842,20 @@ static void ggml_compute_forward_dup_f16(
5166
5842
  if (dst->type == GGML_TYPE_F16) {
5167
5843
  for (int64_t i03 = 0; i03 < ne03; i03++) {
5168
5844
  for (int64_t i02 = 0; i02 < ne02; i02++) {
5169
- for (int64_t i01 = 0; i01 < ne01; i01++) {
5845
+ i10 += ne00 * ir0;
5846
+ while (i10 >= ne0) {
5847
+ i10 -= ne0;
5848
+ if (++i11 == ne1) {
5849
+ i11 = 0;
5850
+ if (++i12 == ne2) {
5851
+ i12 = 0;
5852
+ if (++i13 == ne3) {
5853
+ i13 = 0;
5854
+ }
5855
+ }
5856
+ }
5857
+ }
5858
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
5170
5859
  for (int64_t i00 = 0; i00 < ne00; i00++) {
5171
5860
  const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
5172
5861
  char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
@@ -5187,25 +5876,51 @@ static void ggml_compute_forward_dup_f16(
5187
5876
  }
5188
5877
  }
5189
5878
  }
5879
+ i10 += ne00 * (ne01 - ir1);
5880
+ while (i10 >= ne0) {
5881
+ i10 -= ne0;
5882
+ if (++i11 == ne1) {
5883
+ i11 = 0;
5884
+ if (++i12 == ne2) {
5885
+ i12 = 0;
5886
+ if (++i13 == ne3) {
5887
+ i13 = 0;
5888
+ }
5889
+ }
5890
+ }
5891
+ }
5190
5892
  }
5191
5893
  }
5192
5894
  } else if (dst->type == GGML_TYPE_F32) {
5193
5895
  for (int64_t i03 = 0; i03 < ne03; i03++) {
5194
5896
  for (int64_t i02 = 0; i02 < ne02; i02++) {
5195
- for (int64_t i01 = 0; i01 < ne01; i01++) {
5897
+ i10 += ne00 * ir0;
5898
+ while (i10 >= ne0) {
5899
+ i10 -= ne0;
5900
+ if (++i11 == ne1) {
5901
+ i11 = 0;
5902
+ if (++i12 == ne2) {
5903
+ i12 = 0;
5904
+ if (++i13 == ne3) {
5905
+ i13 = 0;
5906
+ }
5907
+ }
5908
+ }
5909
+ }
5910
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
5196
5911
  for (int64_t i00 = 0; i00 < ne00; i00++) {
5197
5912
  const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
5198
5913
  char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
5199
5914
 
5200
5915
  *(float *) dst_ptr = GGML_FP16_TO_FP32(*(const ggml_fp16_t *) src0_ptr);
5201
5916
 
5202
- if (++i10 == ne00) {
5917
+ if (++i10 == ne0) {
5203
5918
  i10 = 0;
5204
- if (++i11 == ne01) {
5919
+ if (++i11 == ne1) {
5205
5920
  i11 = 0;
5206
- if (++i12 == ne02) {
5921
+ if (++i12 == ne2) {
5207
5922
  i12 = 0;
5208
- if (++i13 == ne03) {
5923
+ if (++i13 == ne3) {
5209
5924
  i13 = 0;
5210
5925
  }
5211
5926
  }
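
Because each thread now copies only rows [ir0, ir1) of src0 while the destination may have a different shape, the code above fast-forwards the dst counters i10..i13 past the elements owned by other threads: it bumps i10 by the number of skipped elements and lets the while loop carry the overflow into i11, i12 and i13. A standalone check that this carry loop lands exactly where a direct decomposition of the flat offset would (made-up shapes, counters started at zero for the test):

    #include <assert.h>
    #include <stdint.h>
    #include <stdio.h>

    // The carry loop mirrors the diff: only i10 is bumped in bulk, the while
    // loop borrows ne0 at a time and cascades into the higher counters.
    int main(void) {
        const int64_t ne0 = 7, ne1 = 3, ne2 = 2, ne3 = 5;
        const int64_t total = ne0*ne1*ne2*ne3;

        for (int64_t skip = 0; skip < total; skip++) {
            int64_t i10 = skip, i11 = 0, i12 = 0, i13 = 0;
            while (i10 >= ne0) {
                i10 -= ne0;
                if (++i11 == ne1) {
                    i11 = 0;
                    if (++i12 == ne2) {
                        i12 = 0;
                        if (++i13 == ne3) {
                            i13 = 0;
                        }
                    }
                }
            }
            // direct decomposition of the same flat offset
            assert(i10 ==  skip % ne0);
            assert(i11 == (skip / ne0) % ne1);
            assert(i12 == (skip / (ne0*ne1)) % ne2);
            assert(i13 == (skip / (ne0*ne1*ne2)) % ne3);
        }
        printf("carry loop matches flat-offset decomposition\n");
        return 0;
    }
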
@@ -5213,6 +5928,19 @@ static void ggml_compute_forward_dup_f16(
5213
5928
  }
5214
5929
  }
5215
5930
  }
5931
+ i10 += ne00 * (ne01 - ir1);
5932
+ while (i10 >= ne0) {
5933
+ i10 -= ne0;
5934
+ if (++i11 == ne1) {
5935
+ i11 = 0;
5936
+ if (++i12 == ne2) {
5937
+ i12 = 0;
5938
+ if (++i13 == ne3) {
5939
+ i13 = 0;
5940
+ }
5941
+ }
5942
+ }
5943
+ }
5216
5944
  }
5217
5945
  }
5218
5946
  } else {
@@ -5224,7 +5952,6 @@ static void ggml_compute_forward_dup_f32(
5224
5952
  const struct ggml_compute_params * params,
5225
5953
  const struct ggml_tensor * src0,
5226
5954
  struct ggml_tensor * dst) {
5227
- GGML_ASSERT(params->ith == 0);
5228
5955
  GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
5229
5956
 
5230
5957
  if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
@@ -5236,6 +5963,11 @@ static void ggml_compute_forward_dup_f32(
5236
5963
  const int64_t ne02 = src0->ne[2];
5237
5964
  const int64_t ne03 = src0->ne[3];
5238
5965
 
5966
+ const int64_t ne0 = dst->ne[0];
5967
+ const int64_t ne1 = dst->ne[1];
5968
+ const int64_t ne2 = dst->ne[2];
5969
+ const int64_t ne3 = dst->ne[3];
5970
+
5239
5971
  const size_t nb00 = src0->nb[0];
5240
5972
  const size_t nb01 = src0->nb[1];
5241
5973
  const size_t nb02 = src0->nb[2];
@@ -5246,19 +5978,40 @@ static void ggml_compute_forward_dup_f32(
5246
5978
  const size_t nb2 = dst->nb[2];
5247
5979
  const size_t nb3 = dst->nb[3];
5248
5980
 
5981
+ const int ith = params->ith; // thread index
5982
+ const int nth = params->nth; // number of threads
5983
+
5249
5984
  if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst) && src0->type == dst->type) {
5250
- memcpy(dst->data, src0->data, ggml_nelements(dst) * GGML_TYPE_SIZE[src0->type]);
5985
+ // parallelize by elements
5986
+ const int ne = ggml_nelements(dst);
5987
+ const int dr = (ne + nth - 1) / nth;
5988
+ const int ie0 = dr * ith;
5989
+ const int ie1 = MIN(ie0 + dr, ne);
5990
+
5991
+ memcpy(
5992
+ ((char *) dst->data + ie0*nb0),
5993
+ ((char *) src0->data + ie0*nb00),
5994
+ (ie1 - ie0) * GGML_TYPE_SIZE[src0->type]);
5995
+
5251
5996
  return;
5252
5997
  }
5253
5998
 
5999
+ // parallelize by rows
6000
+ const int nr = ne01;
6001
+ // number of rows per thread
6002
+ const int dr = (nr + nth - 1) / nth;
6003
+ // row range for this thread
6004
+ const int ir0 = dr * ith;
6005
+ const int ir1 = MIN(ir0 + dr, nr);
6006
+
5254
6007
  if (src0->type == dst->type &&
5255
- src0->ne[0] == dst->ne[0] &&
5256
- src0->nb[0] == GGML_TYPE_SIZE[src0->type] && dst->nb[0] == GGML_TYPE_SIZE[dst->type]) {
6008
+ ne00 == ne0 &&
6009
+ nb00 == GGML_TYPE_SIZE[src0->type] && nb0 == GGML_TYPE_SIZE[dst->type]) {
5257
6010
  // copy by rows
5258
6011
  const size_t rs = ne00*nb00;
5259
6012
  for (int64_t i03 = 0; i03 < ne03; i03++) {
5260
6013
  for (int64_t i02 = 0; i02 < ne02; i02++) {
5261
- for (int64_t i01 = 0; i01 < ne01; i01++) {
6014
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
5262
6015
  memcpy(
5263
6016
  ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3),
5264
6017
  ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
@@ -5271,21 +6024,21 @@ static void ggml_compute_forward_dup_f32(
5271
6024
 
5272
6025
  if (ggml_is_contiguous(dst)) {
5273
6026
  // TODO: simplify
5274
- if (src0->nb[0] == sizeof(float)) {
6027
+ if (nb00 == sizeof(float)) {
5275
6028
  if (dst->type == GGML_TYPE_F32) {
5276
6029
  size_t id = 0;
5277
- const size_t rs = ne00*nb00;
6030
+ const size_t rs = ne00 * nb00;
6031
+ char * dst_ptr = (char *) dst->data;
5278
6032
 
5279
6033
  for (int i03 = 0; i03 < ne03; i03++) {
5280
6034
  for (int i02 = 0; i02 < ne02; i02++) {
5281
- for (int i01 = 0; i01 < ne01; i01++) {
6035
+ id += rs * ir0;
6036
+ for (int i01 = ir0; i01 < ir1; i01++) {
5282
6037
  const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
5283
- char * dst_ptr = (char *) dst->data + id*rs;
5284
-
5285
- memcpy(dst_ptr, src0_ptr, rs);
5286
-
5287
- id++;
6038
+ memcpy(dst_ptr + id, src0_ptr, rs);
6039
+ id += rs;
5288
6040
  }
6041
+ id += rs * (ne01 - ir1);
5289
6042
  }
5290
6043
  }
5291
6044
  } else if (dst->type == GGML_TYPE_F16) {
@@ -5294,7 +6047,8 @@ static void ggml_compute_forward_dup_f32(
5294
6047
 
5295
6048
  for (int i03 = 0; i03 < ne03; i03++) {
5296
6049
  for (int i02 = 0; i02 < ne02; i02++) {
5297
- for (int i01 = 0; i01 < ne01; i01++) {
6050
+ id += ne00 * ir0;
6051
+ for (int i01 = ir0; i01 < ir1; i01++) {
5298
6052
  for (int i00 = 0; i00 < ne00; i00++) {
5299
6053
  const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
5300
6054
 
@@ -5302,6 +6056,25 @@ static void ggml_compute_forward_dup_f32(
5302
6056
  id++;
5303
6057
  }
5304
6058
  }
6059
+ id += ne00 * (ne01 - ir1);
6060
+ }
6061
+ }
6062
+ } else if (ggml_is_quantized(dst->type)) {
6063
+ quantize_row_q_t const quantize_row_q = quantize_fns[dst->type].quantize_row_q;
6064
+
6065
+ size_t id = 0;
6066
+ size_t rs = nb0 * (ne00 / GGML_BLCK_SIZE[dst->type]);
6067
+ char * dst_ptr = (char *) dst->data;
6068
+
6069
+ for (int i03 = 0; i03 < ne03; i03++) {
6070
+ for (int i02 = 0; i02 < ne02; i02++) {
6071
+ id += rs * ir0;
6072
+ for (int i01 = ir0; i01 < ir1; i01++) {
6073
+ const float * src0_ptr = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
6074
+ quantize_row_q(src0_ptr, dst_ptr + id, ne00);
6075
+ id += rs;
6076
+ }
6077
+ id += rs * (ne01 - ir1);
5305
6078
  }
5306
6079
  }
5307
6080
  } else {
@@ -5316,7 +6089,8 @@ static void ggml_compute_forward_dup_f32(
5316
6089
 
5317
6090
  for (int i03 = 0; i03 < ne03; i03++) {
5318
6091
  for (int i02 = 0; i02 < ne02; i02++) {
5319
- for (int i01 = 0; i01 < ne01; i01++) {
6092
+ id += ne00 * ir0;
6093
+ for (int i01 = ir0; i01 < ir1; i01++) {
5320
6094
  for (int i00 = 0; i00 < ne00; i00++) {
5321
6095
  const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
5322
6096
 
@@ -5324,6 +6098,7 @@ static void ggml_compute_forward_dup_f32(
5324
6098
  id++;
5325
6099
  }
5326
6100
  }
6101
+ id += ne00 * (ne01 - ir1);
5327
6102
  }
5328
6103
  }
5329
6104
  } else if (dst->type == GGML_TYPE_F16) {
@@ -5332,7 +6107,8 @@ static void ggml_compute_forward_dup_f32(
5332
6107
 
5333
6108
  for (int i03 = 0; i03 < ne03; i03++) {
5334
6109
  for (int i02 = 0; i02 < ne02; i02++) {
5335
- for (int i01 = 0; i01 < ne01; i01++) {
6110
+ id += ne00 * ir0;
6111
+ for (int i01 = ir0; i01 < ir1; i01++) {
5336
6112
  for (int i00 = 0; i00 < ne00; i00++) {
5337
6113
  const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
5338
6114
 
@@ -5340,6 +6116,7 @@ static void ggml_compute_forward_dup_f32(
5340
6116
  id++;
5341
6117
  }
5342
6118
  }
6119
+ id += ne00 * (ne01 - ir1);
5343
6120
  }
5344
6121
  }
5345
6122
  } else {
@@ -5351,6 +6128,7 @@ static void ggml_compute_forward_dup_f32(
5351
6128
  }
5352
6129
 
5353
6130
  // dst counters
6131
+
5354
6132
  int64_t i10 = 0;
5355
6133
  int64_t i11 = 0;
5356
6134
  int64_t i12 = 0;
@@ -5359,20 +6137,33 @@ static void ggml_compute_forward_dup_f32(
5359
6137
  if (dst->type == GGML_TYPE_F32) {
5360
6138
  for (int64_t i03 = 0; i03 < ne03; i03++) {
5361
6139
  for (int64_t i02 = 0; i02 < ne02; i02++) {
5362
- for (int64_t i01 = 0; i01 < ne01; i01++) {
6140
+ i10 += ne00 * ir0;
6141
+ while (i10 >= ne0) {
6142
+ i10 -= ne0;
6143
+ if (++i11 == ne1) {
6144
+ i11 = 0;
6145
+ if (++i12 == ne2) {
6146
+ i12 = 0;
6147
+ if (++i13 == ne3) {
6148
+ i13 = 0;
6149
+ }
6150
+ }
6151
+ }
6152
+ }
6153
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
5363
6154
  for (int64_t i00 = 0; i00 < ne00; i00++) {
5364
6155
  const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
5365
6156
  char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
5366
6157
 
5367
6158
  memcpy(dst_ptr, src0_ptr, sizeof(float));
5368
6159
 
5369
- if (++i10 == dst->ne[0]) {
6160
+ if (++i10 == ne0) {
5370
6161
  i10 = 0;
5371
- if (++i11 == dst->ne[1]) {
6162
+ if (++i11 == ne1) {
5372
6163
  i11 = 0;
5373
- if (++i12 == dst->ne[2]) {
6164
+ if (++i12 == ne2) {
5374
6165
  i12 = 0;
5375
- if (++i13 == dst->ne[3]) {
6166
+ if (++i13 == ne3) {
5376
6167
  i13 = 0;
5377
6168
  }
5378
6169
  }
@@ -5380,25 +6171,51 @@ static void ggml_compute_forward_dup_f32(
5380
6171
  }
5381
6172
  }
5382
6173
  }
6174
+ i10 += ne00 * (ne01 - ir1);
6175
+ while (i10 >= ne0) {
6176
+ i10 -= ne0;
6177
+ if (++i11 == ne1) {
6178
+ i11 = 0;
6179
+ if (++i12 == ne2) {
6180
+ i12 = 0;
6181
+ if (++i13 == ne3) {
6182
+ i13 = 0;
6183
+ }
6184
+ }
6185
+ }
6186
+ }
5383
6187
  }
5384
6188
  }
5385
6189
  } else if (dst->type == GGML_TYPE_F16) {
5386
6190
  for (int64_t i03 = 0; i03 < ne03; i03++) {
5387
6191
  for (int64_t i02 = 0; i02 < ne02; i02++) {
5388
- for (int64_t i01 = 0; i01 < ne01; i01++) {
6192
+ i10 += ne00 * ir0;
6193
+ while (i10 >= ne0) {
6194
+ i10 -= ne0;
6195
+ if (++i11 == ne1) {
6196
+ i11 = 0;
6197
+ if (++i12 == ne2) {
6198
+ i12 = 0;
6199
+ if (++i13 == ne3) {
6200
+ i13 = 0;
6201
+ }
6202
+ }
6203
+ }
6204
+ }
6205
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
5389
6206
  for (int64_t i00 = 0; i00 < ne00; i00++) {
5390
6207
  const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
5391
6208
  char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
5392
6209
 
5393
6210
  *(ggml_fp16_t *) dst_ptr = GGML_FP32_TO_FP16(*(const float *) src0_ptr);
5394
6211
 
5395
- if (++i10 == dst->ne[0]) {
6212
+ if (++i10 == ne0) {
5396
6213
  i10 = 0;
5397
- if (++i11 == dst->ne[1]) {
6214
+ if (++i11 == ne1) {
5398
6215
  i11 = 0;
5399
- if (++i12 == dst->ne[2]) {
6216
+ if (++i12 == ne2) {
5400
6217
  i12 = 0;
5401
- if (++i13 == dst->ne[3]) {
6218
+ if (++i13 == ne3) {
5402
6219
  i13 = 0;
5403
6220
  }
5404
6221
  }
@@ -5406,6 +6223,19 @@ static void ggml_compute_forward_dup_f32(
5406
6223
  }
5407
6224
  }
5408
6225
  }
6226
+ i10 += ne00 * (ne01 - ir1);
6227
+ while (i10 >= ne0) {
6228
+ i10 -= ne0;
6229
+ if (++i11 == ne1) {
6230
+ i11 = 0;
6231
+ if (++i12 == ne2) {
6232
+ i12 = 0;
6233
+ if (++i13 == ne3) {
6234
+ i13 = 0;
6235
+ }
6236
+ }
6237
+ }
6238
+ }
5409
6239
  }
5410
6240
  }
5411
6241
  } else {
@@ -5426,12 +6256,7 @@ static void ggml_compute_forward_dup(
5426
6256
  {
5427
6257
  ggml_compute_forward_dup_f32(params, src0, dst);
5428
6258
  } break;
5429
- case GGML_TYPE_Q4_0:
5430
- case GGML_TYPE_Q4_1:
5431
- case GGML_TYPE_I8:
5432
- case GGML_TYPE_I16:
5433
- case GGML_TYPE_I32:
5434
- case GGML_TYPE_COUNT:
6259
+ default:
5435
6260
  {
5436
6261
  GGML_ASSERT(false);
5437
6262
  } break;
@@ -5497,6 +6322,212 @@ static void ggml_compute_forward_add_f32(
5497
6322
  }
5498
6323
  }
5499
6324
 
6325
+ static void ggml_compute_forward_add_f16_f32(
6326
+ const struct ggml_compute_params * params,
6327
+ const struct ggml_tensor * src0,
6328
+ const struct ggml_tensor * src1,
6329
+ struct ggml_tensor * dst) {
6330
+ GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
6331
+
6332
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
6333
+ return;
6334
+ }
6335
+
6336
+ const int ith = params->ith;
6337
+ const int nth = params->nth;
6338
+
6339
+ const int n = ggml_nrows(src0);
6340
+ const int nc = src0->ne[0];
6341
+
6342
+ const size_t nb00 = src0->nb[0];
6343
+ const size_t nb01 = src0->nb[1];
6344
+
6345
+ const size_t nb10 = src1->nb[0];
6346
+ const size_t nb11 = src1->nb[1];
6347
+
6348
+ const size_t nb0 = dst->nb[0];
6349
+ const size_t nb1 = dst->nb[1];
6350
+
6351
+ GGML_ASSERT(src0->type == GGML_TYPE_F16);
6352
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
6353
+ GGML_ASSERT(dst->type == GGML_TYPE_F16);
6354
+
6355
+ GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
6356
+ GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
6357
+
6358
+ if (nb10 == sizeof(float)) {
6359
+ for (int j = ith; j < n; j += nth) {
6360
+ ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + j*nb1);
6361
+ ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + j*nb01);
6362
+ for (int i = 0; i < nc; i++) {
6363
+ float * src1_ptr = (float *) ((char *) src1->data + j*nb11 + i*nb10);
6364
+ dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + *src1_ptr);
6365
+ }
6366
+ }
6367
+ }
6368
+ else {
6369
+ // src1 is not contiguous
6370
+ GGML_ASSERT(false);
6371
+ }
6372
+ }
6373
+
6374
+ static void ggml_compute_forward_add_f16_f16(
6375
+ const struct ggml_compute_params * params,
6376
+ const struct ggml_tensor * src0,
6377
+ const struct ggml_tensor * src1,
6378
+ struct ggml_tensor * dst) {
6379
+ GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
6380
+
6381
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
6382
+ return;
6383
+ }
6384
+
6385
+ const int ith = params->ith;
6386
+ const int nth = params->nth;
6387
+
6388
+ const int n = ggml_nrows(src0);
6389
+ const int nc = src0->ne[0];
6390
+
6391
+ const size_t nb00 = src0->nb[0];
6392
+ const size_t nb01 = src0->nb[1];
6393
+
6394
+ const size_t nb10 = src1->nb[0];
6395
+ const size_t nb11 = src1->nb[1];
6396
+
6397
+ const size_t nb0 = dst->nb[0];
6398
+ const size_t nb1 = dst->nb[1];
6399
+
6400
+ GGML_ASSERT(src0->type == GGML_TYPE_F16);
6401
+ GGML_ASSERT(src1->type == GGML_TYPE_F16);
6402
+ GGML_ASSERT(dst->type == GGML_TYPE_F16);
6403
+
6404
+ GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
6405
+ GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
6406
+
6407
+ if (nb10 == sizeof(ggml_fp16_t)) {
6408
+ for (int j = ith; j < n; j += nth) {
6409
+ ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + j*nb1);
6410
+ ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + j*nb01);
6411
+ for (int i = 0; i < nc; i++) {
6412
+ ggml_fp16_t * src1_ptr = (ggml_fp16_t *) ((char *) src1->data + j*nb11 + i*nb10);
6413
+ dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + GGML_FP16_TO_FP32(*src1_ptr));
6414
+ }
6415
+ }
6416
+ }
6417
+ else {
6418
+ // src1 is not contiguous
6419
+ GGML_ASSERT(false);
6420
+ }
6421
+ }
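ggml_compute_forward_add_f16_f32 and ggml_compute_forward_add_f16_f16 above both widen to F32 for the actual addition and narrow the result back to F16; there is no native half-precision arithmetic in this path. The inner loop, reduced to a row helper (name illustrative):

    // dst[i] = a[i] + b[i], with F16 storage and F32 accumulation
    static void add_row_f16_f32(ggml_fp16_t * dst, const ggml_fp16_t * a, const float * b, int nc) {
        for (int i = 0; i < nc; i++) {
            dst[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(a[i]) + b[i]);
        }
    }

Both variants also require contiguous src1 rows (nb10 equal to the element size) and assert otherwise.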
6422
+
6423
+ static void ggml_compute_forward_add_q_f32(
6424
+ const struct ggml_compute_params * params,
6425
+ const struct ggml_tensor * src0,
6426
+ const struct ggml_tensor * src1,
6427
+ struct ggml_tensor * dst) {
6428
+ GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
6429
+
6430
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
6431
+ return;
6432
+ }
6433
+
6434
+ const int64_t ne00 = src0->ne[0];
6435
+ const int64_t ne01 = src0->ne[1];
6436
+ const int64_t ne02 = src0->ne[2];
6437
+ const int64_t ne03 = src0->ne[3];
6438
+
6439
+ //const int64_t ne10 = src1->ne[0];
6440
+ //const int64_t ne11 = src1->ne[1];
6441
+ const int64_t ne12 = src1->ne[2];
6442
+ const int64_t ne13 = src1->ne[3];
6443
+
6444
+ //const int64_t ne0 = dst->ne[0];
6445
+ //const int64_t ne1 = dst->ne[1];
6446
+ const int64_t ne2 = dst->ne[2];
6447
+ const int64_t ne3 = dst->ne[3];
6448
+
6449
+ const int nb00 = src0->nb[0];
6450
+ const int nb01 = src0->nb[1];
6451
+ const int nb02 = src0->nb[2];
6452
+ const int nb03 = src0->nb[3];
6453
+
6454
+ const int nb10 = src1->nb[0];
6455
+ const int nb11 = src1->nb[1];
6456
+ const int nb12 = src1->nb[2];
6457
+ const int nb13 = src1->nb[3];
6458
+
6459
+ const int nb0 = dst->nb[0];
6460
+ const int nb1 = dst->nb[1];
6461
+ const int nb2 = dst->nb[2];
6462
+ const int nb3 = dst->nb[3];
6463
+
6464
+ const int ith = params->ith;
6465
+ const int nth = params->nth;
6466
+
6467
+ GGML_ASSERT(ne02 == ne12);
6468
+ GGML_ASSERT(ne03 == ne13);
6469
+ GGML_ASSERT(ne2 == ne12);
6470
+ GGML_ASSERT(ne3 == ne13);
6471
+
6472
+ const enum ggml_type type = src0->type;
6473
+ dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q;
6474
+ quantize_row_q_t const quantize_row_q = quantize_fns[type].quantize_row_q;
6475
+
6476
+ // we don't support permuted src0 or src1
6477
+ GGML_ASSERT(nb00 == (int) GGML_TYPE_SIZE[type]);
6478
+ GGML_ASSERT(nb10 == sizeof(float));
6479
+
6480
+ // dst cannot be transposed or permuted
6481
+ GGML_ASSERT(nb0 <= nb1);
6482
+ GGML_ASSERT(nb1 <= nb2);
6483
+ GGML_ASSERT(nb2 <= nb3);
6484
+
6485
+ GGML_ASSERT(ggml_is_quantized(src0->type));
6486
+ GGML_ASSERT(dst->type == src0->type);
6487
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
6488
+
6489
+ // total rows in src0
6490
+ const int nr = ne01*ne02*ne03;
6491
+
6492
+ // rows per thread
6493
+ const int dr = (nr + nth - 1)/nth;
6494
+
6495
+ // row range for this thread
6496
+ const int ir0 = dr*ith;
6497
+ const int ir1 = MIN(ir0 + dr, nr);
6498
+
6499
+ float * wdata = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
6500
+
6501
+ for (int ir = ir0; ir < ir1; ++ir) {
6502
+ // src0 indices
6503
+ const int i03 = ir/(ne02*ne01);
6504
+ const int i02 = (ir - i03*ne02*ne01)/ne01;
6505
+ const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
6506
+
6507
+ // src1 and dst are same shape as src0 => same indices
6508
+ const int i13 = i03;
6509
+ const int i12 = i02;
6510
+ const int i11 = i01;
6511
+
6512
+ const int i3 = i03;
6513
+ const int i2 = i02;
6514
+ const int i1 = i01;
6515
+
6516
+ void * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
6517
+ float * src1_row = (float *)((char *) src1->data + (i11*nb11 + i12*nb12 + i13*nb13));
6518
+ void * dst_row = (void *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb0));
6519
+
6520
+ assert(ne00 % 32 == 0);
6521
+
6522
+ // dequantize row from src0 to temp buffer
6523
+ dequantize_row_q(src0_row, wdata, ne00);
6524
+ // add src1
6525
+ ggml_vec_acc_f32(ne00, wdata, src1_row);
6526
+ // quantize row to dst
6527
+ quantize_row_q(wdata, dst_row, ne00);
6528
+ }
6529
+ }
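ggml_compute_forward_add_q_f32 above is the first operator in this release that writes back into a quantized tensor: each thread dequantizes a src0 row into its slice of the float work buffer, accumulates the F32 src1 row, and re-quantizes the result into dst. The per-row pipeline, condensed (wrapper name illustrative; the function-pointer fields come from the quantize_fns table this file already uses):

    // dequantize -> add -> requantize one row of ne00 values
    static void add_row_q_f32(enum ggml_type type,
                              const void * src0_row, const float * src1_row,
                              void * dst_row, float * wdata, int ne00) {
        quantize_fns[type].dequantize_row_q(src0_row, wdata, ne00); // q -> f32 scratch
        ggml_vec_acc_f32(ne00, wdata, src1_row);                    // wdata[i] += src1_row[i]
        quantize_fns[type].quantize_row_q(wdata, dst_row, ne00);    // f32 scratch -> q
    }

This per-thread F32 scratch is why the GGML_OP_ADD case in ggml_graph_compute, further down in this diff, now reserves extra work-buffer space when src0 is quantized.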
6530
+
5500
6531
  static void ggml_compute_forward_add(
5501
6532
  const struct ggml_compute_params * params,
5502
6533
  const struct ggml_tensor * src0,
@@ -5507,13 +6538,26 @@ static void ggml_compute_forward_add(
5507
6538
  {
5508
6539
  ggml_compute_forward_add_f32(params, src0, src1, dst);
5509
6540
  } break;
6541
+ case GGML_TYPE_F16:
6542
+ {
6543
+ if (src1->type == GGML_TYPE_F16) {
6544
+ ggml_compute_forward_add_f16_f16(params, src0, src1, dst);
6545
+ }
6546
+ else if (src1->type == GGML_TYPE_F32) {
6547
+ ggml_compute_forward_add_f16_f32(params, src0, src1, dst);
6548
+ }
6549
+ else {
6550
+ GGML_ASSERT(false);
6551
+ }
6552
+ } break;
5510
6553
  case GGML_TYPE_Q4_0:
5511
6554
  case GGML_TYPE_Q4_1:
5512
- case GGML_TYPE_I8:
5513
- case GGML_TYPE_I16:
5514
- case GGML_TYPE_I32:
5515
- case GGML_TYPE_F16:
5516
- case GGML_TYPE_COUNT:
6555
+ case GGML_TYPE_Q4_2:
6556
+ case GGML_TYPE_Q4_3:
6557
+ {
6558
+ ggml_compute_forward_add_q_f32(params, src0, src1, dst);
6559
+ } break;
6560
+ default:
5517
6561
  {
5518
6562
  GGML_ASSERT(false);
5519
6563
  } break;
@@ -5559,13 +6603,7 @@ static void ggml_compute_forward_sub(
5559
6603
  {
5560
6604
  ggml_compute_forward_sub_f32(params, src0, src1, dst);
5561
6605
  } break;
5562
- case GGML_TYPE_Q4_0:
5563
- case GGML_TYPE_Q4_1:
5564
- case GGML_TYPE_I8:
5565
- case GGML_TYPE_I16:
5566
- case GGML_TYPE_I32:
5567
- case GGML_TYPE_F16:
5568
- case GGML_TYPE_COUNT:
6606
+ default:
5569
6607
  {
5570
6608
  GGML_ASSERT(false);
5571
6609
  } break;
@@ -5611,13 +6649,7 @@ static void ggml_compute_forward_mul(
5611
6649
  {
5612
6650
  ggml_compute_forward_mul_f32(params, src0, src1, dst);
5613
6651
  } break;
5614
- case GGML_TYPE_Q4_0:
5615
- case GGML_TYPE_Q4_1:
5616
- case GGML_TYPE_I8:
5617
- case GGML_TYPE_I16:
5618
- case GGML_TYPE_I32:
5619
- case GGML_TYPE_F16:
5620
- case GGML_TYPE_COUNT:
6652
+ default:
5621
6653
  {
5622
6654
  GGML_ASSERT(false);
5623
6655
  } break;
@@ -5663,13 +6695,7 @@ static void ggml_compute_forward_div(
5663
6695
  {
5664
6696
  ggml_compute_forward_div_f32(params, src0, src1, dst);
5665
6697
  } break;
5666
- case GGML_TYPE_Q4_0:
5667
- case GGML_TYPE_Q4_1:
5668
- case GGML_TYPE_I8:
5669
- case GGML_TYPE_I16:
5670
- case GGML_TYPE_I32:
5671
- case GGML_TYPE_F16:
5672
- case GGML_TYPE_COUNT:
6698
+ default:
5673
6699
  {
5674
6700
  GGML_ASSERT(false);
5675
6701
  } break;
@@ -5711,13 +6737,7 @@ static void ggml_compute_forward_sqr(
5711
6737
  {
5712
6738
  ggml_compute_forward_sqr_f32(params, src0, dst);
5713
6739
  } break;
5714
- case GGML_TYPE_Q4_0:
5715
- case GGML_TYPE_Q4_1:
5716
- case GGML_TYPE_I8:
5717
- case GGML_TYPE_I16:
5718
- case GGML_TYPE_I32:
5719
- case GGML_TYPE_F16:
5720
- case GGML_TYPE_COUNT:
6740
+ default:
5721
6741
  {
5722
6742
  GGML_ASSERT(false);
5723
6743
  } break;
@@ -5759,13 +6779,7 @@ static void ggml_compute_forward_sqrt(
5759
6779
  {
5760
6780
  ggml_compute_forward_sqrt_f32(params, src0, dst);
5761
6781
  } break;
5762
- case GGML_TYPE_Q4_0:
5763
- case GGML_TYPE_Q4_1:
5764
- case GGML_TYPE_I8:
5765
- case GGML_TYPE_I16:
5766
- case GGML_TYPE_I32:
5767
- case GGML_TYPE_F16:
5768
- case GGML_TYPE_COUNT:
6782
+ default:
5769
6783
  {
5770
6784
  GGML_ASSERT(false);
5771
6785
  } break;
@@ -5817,13 +6831,7 @@ static void ggml_compute_forward_sum(
5817
6831
  {
5818
6832
  ggml_compute_forward_sum_f32(params, src0, dst);
5819
6833
  } break;
5820
- case GGML_TYPE_Q4_0:
5821
- case GGML_TYPE_Q4_1:
5822
- case GGML_TYPE_I8:
5823
- case GGML_TYPE_I16:
5824
- case GGML_TYPE_I32:
5825
- case GGML_TYPE_F16:
5826
- case GGML_TYPE_COUNT:
6834
+ default:
5827
6835
  {
5828
6836
  GGML_ASSERT(false);
5829
6837
  } break;
@@ -5894,13 +6902,7 @@ static void ggml_compute_forward_mean(
5894
6902
  {
5895
6903
  ggml_compute_forward_mean_f32(params, src0, dst);
5896
6904
  } break;
5897
- case GGML_TYPE_Q4_0:
5898
- case GGML_TYPE_Q4_1:
5899
- case GGML_TYPE_I8:
5900
- case GGML_TYPE_I16:
5901
- case GGML_TYPE_I32:
5902
- case GGML_TYPE_F16:
5903
- case GGML_TYPE_COUNT:
6905
+ default:
5904
6906
  {
5905
6907
  GGML_ASSERT(false);
5906
6908
  } break;
@@ -5958,13 +6960,7 @@ static void ggml_compute_forward_repeat(
5958
6960
  {
5959
6961
  ggml_compute_forward_repeat_f32(params, src0, dst);
5960
6962
  } break;
5961
- case GGML_TYPE_Q4_0:
5962
- case GGML_TYPE_Q4_1:
5963
- case GGML_TYPE_I8:
5964
- case GGML_TYPE_I16:
5965
- case GGML_TYPE_I32:
5966
- case GGML_TYPE_F16:
5967
- case GGML_TYPE_COUNT:
6963
+ default:
5968
6964
  {
5969
6965
  GGML_ASSERT(false);
5970
6966
  } break;
@@ -6006,13 +7002,7 @@ static void ggml_compute_forward_abs(
6006
7002
  {
6007
7003
  ggml_compute_forward_abs_f32(params, src0, dst);
6008
7004
  } break;
6009
- case GGML_TYPE_Q4_0:
6010
- case GGML_TYPE_Q4_1:
6011
- case GGML_TYPE_I8:
6012
- case GGML_TYPE_I16:
6013
- case GGML_TYPE_I32:
6014
- case GGML_TYPE_F16:
6015
- case GGML_TYPE_COUNT:
7005
+ default:
6016
7006
  {
6017
7007
  GGML_ASSERT(false);
6018
7008
  } break;
@@ -6054,13 +7044,7 @@ static void ggml_compute_forward_sgn(
6054
7044
  {
6055
7045
  ggml_compute_forward_sgn_f32(params, src0, dst);
6056
7046
  } break;
6057
- case GGML_TYPE_Q4_0:
6058
- case GGML_TYPE_Q4_1:
6059
- case GGML_TYPE_I8:
6060
- case GGML_TYPE_I16:
6061
- case GGML_TYPE_I32:
6062
- case GGML_TYPE_F16:
6063
- case GGML_TYPE_COUNT:
7047
+ default:
6064
7048
  {
6065
7049
  GGML_ASSERT(false);
6066
7050
  } break;
@@ -6102,13 +7086,7 @@ static void ggml_compute_forward_neg(
6102
7086
  {
6103
7087
  ggml_compute_forward_neg_f32(params, src0, dst);
6104
7088
  } break;
6105
- case GGML_TYPE_Q4_0:
6106
- case GGML_TYPE_Q4_1:
6107
- case GGML_TYPE_I8:
6108
- case GGML_TYPE_I16:
6109
- case GGML_TYPE_I32:
6110
- case GGML_TYPE_F16:
6111
- case GGML_TYPE_COUNT:
7089
+ default:
6112
7090
  {
6113
7091
  GGML_ASSERT(false);
6114
7092
  } break;
@@ -6150,13 +7128,7 @@ static void ggml_compute_forward_step(
6150
7128
  {
6151
7129
  ggml_compute_forward_step_f32(params, src0, dst);
6152
7130
  } break;
6153
- case GGML_TYPE_Q4_0:
6154
- case GGML_TYPE_Q4_1:
6155
- case GGML_TYPE_I8:
6156
- case GGML_TYPE_I16:
6157
- case GGML_TYPE_I32:
6158
- case GGML_TYPE_F16:
6159
- case GGML_TYPE_COUNT:
7131
+ default:
6160
7132
  {
6161
7133
  GGML_ASSERT(false);
6162
7134
  } break;
@@ -6193,18 +7165,12 @@ static void ggml_compute_forward_relu(
6193
7165
  const struct ggml_compute_params * params,
6194
7166
  const struct ggml_tensor * src0,
6195
7167
  struct ggml_tensor * dst) {
6196
- switch (src0->type) {
6197
- case GGML_TYPE_F32:
6198
- {
6199
- ggml_compute_forward_relu_f32(params, src0, dst);
6200
- } break;
6201
- case GGML_TYPE_Q4_0:
6202
- case GGML_TYPE_Q4_1:
6203
- case GGML_TYPE_I8:
6204
- case GGML_TYPE_I16:
6205
- case GGML_TYPE_I32:
6206
- case GGML_TYPE_F16:
6207
- case GGML_TYPE_COUNT:
7168
+ switch (src0->type) {
7169
+ case GGML_TYPE_F32:
7170
+ {
7171
+ ggml_compute_forward_relu_f32(params, src0, dst);
7172
+ } break;
7173
+ default:
6208
7174
  {
6209
7175
  GGML_ASSERT(false);
6210
7176
  } break;
@@ -6263,13 +7229,7 @@ static void ggml_compute_forward_gelu(
6263
7229
  {
6264
7230
  ggml_compute_forward_gelu_f32(params, src0, dst);
6265
7231
  } break;
6266
- case GGML_TYPE_Q4_0:
6267
- case GGML_TYPE_Q4_1:
6268
- case GGML_TYPE_I8:
6269
- case GGML_TYPE_I16:
6270
- case GGML_TYPE_I32:
6271
- case GGML_TYPE_F16:
6272
- case GGML_TYPE_COUNT:
7232
+ default:
6273
7233
  {
6274
7234
  GGML_ASSERT(false);
6275
7235
  } break;
@@ -6330,13 +7290,7 @@ static void ggml_compute_forward_silu(
6330
7290
  {
6331
7291
  ggml_compute_forward_silu_f32(params, src0, dst);
6332
7292
  } break;
6333
- case GGML_TYPE_Q4_0:
6334
- case GGML_TYPE_Q4_1:
6335
- case GGML_TYPE_I8:
6336
- case GGML_TYPE_I16:
6337
- case GGML_TYPE_I32:
6338
- case GGML_TYPE_F16:
6339
- case GGML_TYPE_COUNT:
7293
+ default:
6340
7294
  {
6341
7295
  GGML_ASSERT(false);
6342
7296
  } break;
@@ -6416,13 +7370,7 @@ static void ggml_compute_forward_norm(
6416
7370
  {
6417
7371
  ggml_compute_forward_norm_f32(params, src0, dst);
6418
7372
  } break;
6419
- case GGML_TYPE_Q4_0:
6420
- case GGML_TYPE_Q4_1:
6421
- case GGML_TYPE_I8:
6422
- case GGML_TYPE_I16:
6423
- case GGML_TYPE_I32:
6424
- case GGML_TYPE_F16:
6425
- case GGML_TYPE_COUNT:
7373
+ default:
6426
7374
  {
6427
7375
  GGML_ASSERT(false);
6428
7376
  } break;
@@ -6496,13 +7444,7 @@ static void ggml_compute_forward_rms_norm(
6496
7444
  {
6497
7445
  ggml_compute_forward_rms_norm_f32(params, src0, dst);
6498
7446
  } break;
6499
- case GGML_TYPE_Q4_0:
6500
- case GGML_TYPE_Q4_1:
6501
- case GGML_TYPE_I8:
6502
- case GGML_TYPE_I16:
6503
- case GGML_TYPE_I32:
6504
- case GGML_TYPE_F16:
6505
- case GGML_TYPE_COUNT:
7447
+ default:
6506
7448
  {
6507
7449
  GGML_ASSERT(false);
6508
7450
  } break;
@@ -6512,7 +7454,7 @@ static void ggml_compute_forward_rms_norm(
6512
7454
 
6513
7455
  // ggml_compute_forward_mul_mat
6514
7456
 
6515
- #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
7457
+ #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
6516
7458
  // helper function to determine if it is better to use BLAS or not
6517
7459
  // for large matrices, BLAS is faster
6518
7460
  static bool ggml_compute_forward_mul_mat_use_blas(
@@ -6552,7 +7494,7 @@ static void ggml_compute_forward_mul_mat_f32(
6552
7494
  const int64_t ne02 = src0->ne[2];
6553
7495
  const int64_t ne03 = src0->ne[3];
6554
7496
 
6555
- #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
7497
+ #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
6556
7498
  const int64_t ne10 = src1->ne[0];
6557
7499
  #endif
6558
7500
  const int64_t ne11 = src1->ne[1];
@@ -6609,7 +7551,7 @@ static void ggml_compute_forward_mul_mat_f32(
6609
7551
  // nb01 >= nb00 - src0 is not transposed
6610
7552
  // compute by src0 rows
6611
7553
 
6612
- #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
7554
+ #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
6613
7555
  if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
6614
7556
  if (params->ith != 0) {
6615
7557
  return;
@@ -6623,6 +7565,21 @@ static void ggml_compute_forward_mul_mat_f32(
6623
7565
  return;
6624
7566
  }
6625
7567
 
7568
+ #if defined(GGML_USE_CUBLAS)
7569
+ float *d_X = NULL;
7570
+ float *d_Y = NULL;
7571
+ float *d_D = NULL;
7572
+ const float alpha = 1.0f;
7573
+ const float beta = 0.0f;
7574
+ const int x_ne = ne01 * ne10;
7575
+ const int y_ne = ne11 * ne10;
7576
+ const int d_ne = ne11 * ne01;
7577
+
7578
+ CUDA_CHECK(cudaMalloc((void **)(&d_X), sizeof(float) * x_ne));
7579
+ CUDA_CHECK(cudaMalloc((void **)(&d_Y), sizeof(float) * y_ne));
7580
+ CUDA_CHECK(cudaMalloc((void **)(&d_D), sizeof(float) * d_ne));
7581
+ #endif
7582
+
6626
7583
  for (int64_t i03 = 0; i03 < ne03; i03++) {
6627
7584
  for (int64_t i02 = 0; i02 < ne02; i02++) {
6628
7585
  const float * x = (float *) ((char *) src0->data + i02*nb02 + i03*nb03);
@@ -6630,15 +7587,37 @@ static void ggml_compute_forward_mul_mat_f32(
6630
7587
 
6631
7588
  float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
6632
7589
 
7590
+ #if defined(GGML_USE_CUBLAS)
7591
+ // copy data to device
7592
+ CUDA_CHECK(cudaMemcpyAsync(d_X, x, sizeof(float) * x_ne, cudaMemcpyHostToDevice, cudaStream));
7593
+ CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(float) * y_ne, cudaMemcpyHostToDevice, cudaStream));
7594
+
7595
+ // compute
7596
+ CUBLAS_CHECK(
7597
+ cublasSgemm(cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
7598
+ ne01, ne11, ne10,
7599
+ &alpha, d_X, ne00,
7600
+ d_Y, ne10,
7601
+ &beta, d_D, ne01));
7602
+
7603
+ // copy data to host
7604
+ CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
7605
+ #else
6633
7606
  // zT = y * xT
6634
7607
  cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
6635
7608
  ne11, ne01, ne10,
6636
7609
  1.0f, y, ne10,
6637
7610
  x, ne00,
6638
7611
  0.0f, d, ne01);
7612
+ #endif
6639
7613
  }
6640
7614
  }
6641
-
7615
+ #if defined(GGML_USE_CUBLAS)
7616
+ CUDA_CHECK(cudaStreamSynchronize(cudaStream));
7617
+ CUDA_CHECK(cudaFree(d_X));
7618
+ CUDA_CHECK(cudaFree(d_Y));
7619
+ CUDA_CHECK(cudaFree(d_D));
7620
+ #endif
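With GGML_USE_CUBLAS defined, the BLAS branch above stages both operands on the device, runs the GEMM on the non-blocking stream created at init, and copies the result back before freeing the temporaries. Condensed to the essential round trip (same calls and leading dimensions as the hunk; error checking elided):

    // x: src0 slice [ne01 x ne00], y: src1 slice [ne11 x ne10], d: dst slice [ne11 x ne01], all row-major
    cudaMemcpyAsync(d_X, x, sizeof(float) * x_ne, cudaMemcpyHostToDevice, cudaStream);
    cudaMemcpyAsync(d_Y, y, sizeof(float) * y_ne, cudaMemcpyHostToDevice, cudaStream);

    // cuBLAS is column-major, so the row-major d = y * x^T is expressed as
    // op(X)^T * Y with the row strides used as leading dimensions
    cublasSgemm(cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
                ne01, ne11, ne10,
                &alpha, d_X, ne00,
                        d_Y, ne10,
                &beta,  d_D, ne01);

    cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream);
    cudaStreamSynchronize(cudaStream); // results are only valid after this

Note that d_X/d_Y/d_D are cudaMalloc'ed before and freed after the i02/i03 loop, so the device buffers are reused across the batch but re-allocated on every mul_mat call in the graph.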
6642
7621
  //printf("CBLAS F32 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);
6643
7622
 
6644
7623
  return;
@@ -6768,7 +7747,7 @@ static void ggml_compute_forward_mul_mat_f16_f32(
6768
7747
  // nb01 >= nb00 - src0 is not transposed
6769
7748
  // compute by src0 rows
6770
7749
 
6771
- #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
7750
+ #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
6772
7751
  if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
6773
7752
  GGML_ASSERT(nb10 == sizeof(float));
6774
7753
 
@@ -6784,10 +7763,37 @@ static void ggml_compute_forward_mul_mat_f16_f32(
6784
7763
  return;
6785
7764
  }
6786
7765
 
6787
- float * const wdata = params->wdata;
7766
+ #if defined(GGML_USE_CUBLAS)
7767
+ ggml_fp16_t * const wdata = params->wdata;
6788
7768
 
7769
+ float *d_X = NULL;
7770
+ float *d_Y = NULL;
7771
+ float *d_D = NULL;
7772
+ const float alpha = 1.0f;
7773
+ const float beta = 0.0f;
7774
+ const int x_ne = ne01 * ne10;
7775
+ const int y_ne = ne11 * ne10;
7776
+ const int d_ne = ne11 * ne01;
7777
+
7778
+ CUDA_CHECK(cudaMalloc((void **)(&d_X), sizeof(ggml_fp16_t) * x_ne));
7779
+ CUDA_CHECK(cudaMalloc((void **)(&d_Y), sizeof(float) * y_ne));
7780
+ CUDA_CHECK(cudaMalloc((void **)(&d_D), sizeof(float) * d_ne));
7781
+ #else
7782
+ float * const wdata = params->wdata;
7783
+ #endif
6789
7784
  for (int64_t i03 = 0; i03 < ne03; i03++) {
6790
7785
  for (int64_t i02 = 0; i02 < ne02; i02++) {
7786
+ #if defined(GGML_USE_CUBLAS)
7787
+ // with cuBLAS, instead of converting src0 to fp32, we convert src1 to fp16
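Mirror edit marker removed.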
7788
+ {
7789
+ size_t id = 0;
7790
+ for (int64_t i01 = 0; i01 < ne11; ++i01) {
7791
+ for (int64_t i00 = 0; i00 < ne10; ++i00) {
7792
+ wdata[id++] = GGML_FP32_TO_FP16(*(float *) ((char *) src1->data + i03*nb13 + i02*nb12 + i01*nb11 + i00*nb10));
7793
+ }
7794
+ }
7795
+ }
7796
+ #else
6791
7797
  {
6792
7798
  size_t id = 0;
6793
7799
  for (int64_t i01 = 0; i01 < ne01; ++i01) {
@@ -6796,7 +7802,31 @@ static void ggml_compute_forward_mul_mat_f16_f32(
6796
7802
  }
6797
7803
  }
6798
7804
  }
7805
+ #endif
6799
7806
 
7807
+ #if defined(GGML_USE_CUBLAS)
7808
+ const ggml_fp16_t * x = (ggml_fp16_t *) ((char *) src0->data + i02*nb02 + i03*nb03);
7809
+ const ggml_fp16_t * y = (ggml_fp16_t *) wdata;
7810
+
7811
+ float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
7812
+
7813
+ // copy data to device
7814
+ CUDA_CHECK(cudaMemcpyAsync(d_X, x, sizeof(ggml_fp16_t) * x_ne, cudaMemcpyHostToDevice, cudaStream));
7815
+ CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(ggml_fp16_t) * y_ne, cudaMemcpyHostToDevice, cudaStream));
7816
+
7817
+ // compute
7818
+ CUBLAS_CHECK(
7819
+ cublasGemmEx(cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
7820
+ ne01, ne11, ne10,
7821
+ &alpha, d_X, CUDA_R_16F, ne00,
7822
+ d_Y, CUDA_R_16F, ne10,
7823
+ &beta, d_D, CUDA_R_32F, ne01,
7824
+ CUBLAS_COMPUTE_32F,
7825
+ CUBLAS_GEMM_DEFAULT));
7826
+
7827
+ // copy data to host
7828
+ CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
7829
+ #else
6800
7830
  const float * x = wdata;
6801
7831
  const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
6802
7832
 
@@ -6808,9 +7838,16 @@ static void ggml_compute_forward_mul_mat_f16_f32(
6808
7838
  1.0f, y, ne10,
6809
7839
  x, ne00,
6810
7840
  0.0f, d, ne01);
7841
+ #endif
6811
7842
  }
6812
7843
  }
6813
7844
 
7845
+ #if defined(GGML_USE_CUBLAS)
7846
+ CUDA_CHECK(cudaStreamSynchronize(cudaStream));
7847
+ CUDA_CHECK(cudaFree(d_X));
7848
+ CUDA_CHECK(cudaFree(d_Y));
7849
+ CUDA_CHECK(cudaFree(d_D));
7850
+ #endif
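For F16 weights the same structure is reused, but, as the comment in the hunk says, src1 is converted to F16 on the host instead of expanding src0 to F32, and the product goes through cublasGemmEx with half-precision inputs and single-precision output and accumulation. The type triplet is the interesting part:

    // A (d_X) and B (d_Y) are CUDA_R_16F, C (d_D) is CUDA_R_32F;
    // with CUBLAS_COMPUTE_32F the alpha/beta scalars are plain floats
    cublasGemmEx(cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
                 ne01, ne11, ne10,
                 &alpha, d_X, CUDA_R_16F, ne00,
                         d_Y, CUDA_R_16F, ne10,
                 &beta,  d_D, CUDA_R_32F, ne01,
                 CUBLAS_COMPUTE_32F,
                 CUBLAS_GEMM_DEFAULT);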
6814
7851
  /*printf("CBLAS F16 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);*/
6815
7852
 
6816
7853
  return;
@@ -6894,27 +7931,6 @@ static void ggml_compute_forward_mul_mat_f16_f32(
6894
7931
  //}
6895
7932
  }
6896
7933
 
6897
- static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
6898
- [GGML_TYPE_Q4_0] = {
6899
- .dequantize_row_q = dequantize_row_q4_0,
6900
- .quantize_row_q = quantize_row_q4_0,
6901
- .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_0_reference,
6902
- .vec_dot_q = ggml_vec_dot_q4_0,
6903
- },
6904
- [GGML_TYPE_Q4_1] = {
6905
- .dequantize_row_q = dequantize_row_q4_1,
6906
- .quantize_row_q = quantize_row_q4_1,
6907
- .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_1_reference,
6908
- .vec_dot_q = ggml_vec_dot_q4_1,
6909
- },
6910
- };
6911
-
6912
- // For internal test use
6913
- quantize_fns_t ggml_internal_get_quantize_fn(size_t i) {
6914
- GGML_ASSERT(i < GGML_TYPE_COUNT);
6915
- return quantize_fns[i];
6916
- }
6917
-
6918
7934
  static void ggml_compute_forward_mul_mat_q_f32(
6919
7935
  const struct ggml_compute_params * params,
6920
7936
  const struct ggml_tensor * src0,
@@ -6962,8 +7978,8 @@ static void ggml_compute_forward_mul_mat_q_f32(
6962
7978
  GGML_ASSERT(ne3 == ne13);
6963
7979
 
6964
7980
  const enum ggml_type type = src0->type;
6965
- quantize_row_q_t const quantize_row_q = quantize_fns[type].quantize_row_q;
6966
- vec_dot_q_t const vec_dot_q = quantize_fns[type].vec_dot_q;
7981
+ quantize_row_q_t const quantize_row_q_dot = quantize_fns[type].quantize_row_q_dot;
7982
+ vec_dot_q_t const vec_dot_q = quantize_fns[type].vec_dot_q;
6967
7983
 
6968
7984
  // we don't support permuted src0 or src1
6969
7985
  GGML_ASSERT(nb00 == (int) GGML_TYPE_SIZE[type]);
@@ -6983,7 +7999,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
6983
7999
  // nb01 >= nb00 - src0 is not transposed
6984
8000
  // compute by src0 rows
6985
8001
 
6986
- #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
8002
+ #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
6987
8003
  if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
6988
8004
  if (params->ith != 0) {
6989
8005
  return;
@@ -6997,11 +8013,55 @@ static void ggml_compute_forward_mul_mat_q_f32(
6997
8013
  return;
6998
8014
  }
6999
8015
 
8016
+ #if defined(GGML_USE_CUBLAS)
8017
+ float *d_X = NULL;
8018
+ float *d_Y = NULL;
8019
+ float *d_D = NULL;
8020
+ float *d_Q = NULL;
8021
+ const float alpha = 1.0f;
8022
+ const float beta = 0.0f;
8023
+ const int x_ne = ne01 * ne10;
8024
+ const int y_ne = ne11 * ne10;
8025
+ const int d_ne = ne11 * ne01;
8026
+
8027
+ CUDA_CHECK(cudaMalloc((void **)(&d_X), sizeof(float) * x_ne));
8028
+ CUDA_CHECK(cudaMalloc((void **)(&d_Y), sizeof(float) * y_ne));
8029
+ CUDA_CHECK(cudaMalloc((void **)(&d_D), sizeof(float) * d_ne));
8030
+ CUDA_CHECK(cudaMalloc((void **)(&d_Q), GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type]));
8031
+
8032
+ void (*dequantize_row_q_cuda)(const void * x, float * y, int k, cudaStream_t stream) = NULL;
8033
+ if (type == GGML_TYPE_Q4_0) {
8034
+ dequantize_row_q_cuda = dequantize_row_q4_0_cuda;
8035
+ }
8036
+ else if (type == GGML_TYPE_Q4_1) {
8037
+ dequantize_row_q_cuda = dequantize_row_q4_1_cuda;
8038
+ }
8039
+ else if (type == GGML_TYPE_Q4_2) {
8040
+ dequantize_row_q_cuda = dequantize_row_q4_2_cuda;
8041
+ }
8042
+ else {
8043
+ GGML_ASSERT(false);
8044
+ }
8045
+ #else
7000
8046
  float * const wdata = params->wdata;
7001
8047
  dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q;
8048
+ #endif
7002
8049
 
7003
8050
  for (int64_t i03 = 0; i03 < ne03; i03++) {
7004
8051
  for (int64_t i02 = 0; i02 < ne02; i02++) {
8052
+ const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
8053
+
8054
+ float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
8055
+
8056
+ #if defined(GGML_USE_CUBLAS)
8057
+ // copy and dequantize on device
8058
+ CUDA_CHECK(
8059
+ cudaMemcpyAsync(d_Q, (char *) src0->data + i03*nb03 + i02*nb02,
8060
+ GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type], cudaMemcpyHostToDevice, cudaStream));
8061
+
8062
+ dequantize_row_q_cuda(d_Q, d_X, ne01 * ne00, cudaStream);
8063
+ CUDA_CHECK(cudaGetLastError());
8064
+ #else
7005
8065
  {
7006
8066
  size_t id = 0;
7007
8067
  for (int64_t i01 = 0; i01 < ne01; ++i01) {
@@ -7009,21 +8069,42 @@ static void ggml_compute_forward_mul_mat_q_f32(
7009
8069
  id += ne00;
7010
8070
  }
7011
8071
  }
7012
-
7013
8072
  const float * x = wdata;
7014
- const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
8073
+ #endif
7015
8074
 
7016
- float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
7017
8075
 
8076
+ #if defined(GGML_USE_CUBLAS)
8077
+ // copy data to device
8078
+ CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(float) * y_ne, cudaMemcpyHostToDevice, cudaStream));
8079
+
8080
+ // compute
8081
+ CUBLAS_CHECK(
8082
+ cublasSgemm(cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
8083
+ ne01, ne11, ne10,
8084
+ &alpha, d_X, ne00,
8085
+ d_Y, ne10,
8086
+ &beta, d_D, ne01));
8087
+
8088
+ // copy data to host
8089
+ CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
8090
+ #else
7018
8091
  // zT = y * xT
7019
8092
  cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
7020
8093
  ne11, ne01, ne10,
7021
8094
  1.0f, y, ne10,
7022
8095
  x, ne00,
7023
8096
  0.0f, d, ne01);
8097
+ #endif
7024
8098
  }
7025
8099
  }
7026
8100
 
8101
+ #if defined(GGML_USE_CUBLAS)
8102
+ CUDA_CHECK(cudaStreamSynchronize(cudaStream));
8103
+ CUDA_CHECK(cudaFree(d_X));
8104
+ CUDA_CHECK(cudaFree(d_Y));
8105
+ CUDA_CHECK(cudaFree(d_D));
8106
+ CUDA_CHECK(cudaFree(d_Q));
8107
+ #endif
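For quantized weights, the cuBLAS branch above does not dequantize on the host at all: the raw quantized blocks are copied to d_Q and expanded to F32 on the device by a per-format kernel before the SGEMM. The selection and launch, condensed from the hunk (the dequantize_row_q*_cuda kernels are declared by the CUDA support added in this release; Q4_3 and Q8_0 are not wired up to this path and hit the assert):

    void (*dequantize_row_q_cuda)(const void * x, float * y, int k, cudaStream_t stream) = NULL;
    if      (type == GGML_TYPE_Q4_0) dequantize_row_q_cuda = dequantize_row_q4_0_cuda;
    else if (type == GGML_TYPE_Q4_1) dequantize_row_q_cuda = dequantize_row_q4_1_cuda;
    else if (type == GGML_TYPE_Q4_2) dequantize_row_q_cuda = dequantize_row_q4_2_cuda;
    else GGML_ASSERT(false);

    // one i02/i03 slice of quantized src0: x_ne values packed into x_ne / block-size blocks
    cudaMemcpyAsync(d_Q, (char *) src0->data + i03*nb03 + i02*nb02,
                    GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type],
                    cudaMemcpyHostToDevice, cudaStream);
    dequantize_row_q_cuda(d_Q, d_X, ne01 * ne00, cudaStream);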
7027
8108
  //printf("CBLAS = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);
7028
8109
 
7029
8110
  return;
@@ -7032,12 +8113,12 @@ static void ggml_compute_forward_mul_mat_q_f32(
7032
8113
 
7033
8114
  if (params->type == GGML_TASK_INIT) {
7034
8115
  char * wdata = params->wdata;
7035
- const size_t row_size = ne10*GGML_TYPE_SIZE[type]/GGML_BLCK_SIZE[type];
8116
+ const size_t row_size = ne10*GGML_TYPE_SIZE[GGML_TYPE_Q8_0]/GGML_BLCK_SIZE[GGML_TYPE_Q8_0];
7036
8117
 
7037
8118
  for (int64_t i13 = 0; i13 < ne13; ++i13) {
7038
8119
  for (int64_t i12 = 0; i12 < ne12; ++i12) {
7039
8120
  for (int64_t i11 = 0; i11 < ne11; ++i11) {
7040
- quantize_row_q((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10);
8121
+ quantize_row_q_dot((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10);
7041
8122
  wdata += row_size;
7042
8123
  }
7043
8124
  }
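On the CPU path the change above decouples the dot-product format from src0's own format: during GGML_TASK_INIT every src1 row is quantized with quantize_row_q_dot into a Q8_0 scratch, so all the 4-bit formats now dot against the same 8-bit intermediate. The scratch row size therefore follows Q8_0 rather than src0's type:

    // bytes per quantized src1 row of ne10 values (ne10 must be a multiple of the block size)
    const size_t row_size = ne10 * GGML_TYPE_SIZE[GGML_TYPE_Q8_0] / GGML_BLCK_SIZE[GGML_TYPE_Q8_0];
    // e.g. ne10 = 4096 -> 4096 / 32 = 128 Q8_0 blocks per row

The mul_mat work-size estimate in ggml_graph_compute further below is updated the same way, sizing wdata from GGML_TYPE_SIZE[GGML_TYPE_Q8_0] instead of src0's type.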
@@ -7063,7 +8144,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
7063
8144
  const int ir1 = MIN(ir0 + dr, nr);
7064
8145
 
7065
8146
  void * wdata = params->wdata;
7066
- const size_t row_size = ne00*GGML_TYPE_SIZE[type]/GGML_BLCK_SIZE[type];
8147
+ const size_t row_size = ne00*GGML_TYPE_SIZE[GGML_TYPE_Q8_0]/GGML_BLCK_SIZE[GGML_TYPE_Q8_0];
7067
8148
 
7068
8149
  for (int ir = ir0; ir < ir1; ++ir) {
7069
8150
  // src0 indices
@@ -7111,6 +8192,9 @@ static void ggml_compute_forward_mul_mat(
7111
8192
  switch (src0->type) {
7112
8193
  case GGML_TYPE_Q4_0:
7113
8194
  case GGML_TYPE_Q4_1:
8195
+ case GGML_TYPE_Q4_2:
8196
+ case GGML_TYPE_Q4_3:
8197
+ case GGML_TYPE_Q8_0:
7114
8198
  {
7115
8199
  ggml_compute_forward_mul_mat_q_f32(params, src0, src1, dst);
7116
8200
  } break;
@@ -7122,42 +8206,11 @@ static void ggml_compute_forward_mul_mat(
7122
8206
  {
7123
8207
  ggml_compute_forward_mul_mat_f32(params, src0, src1, dst);
7124
8208
  } break;
7125
- case GGML_TYPE_I8:
7126
- case GGML_TYPE_I16:
7127
- case GGML_TYPE_I32:
7128
- case GGML_TYPE_COUNT:
8209
+ default:
7129
8210
  {
7130
8211
  GGML_ASSERT(false);
7131
8212
  } break;
7132
8213
  }
7133
-
7134
- #if 0
7135
- if (src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_Q4_1) {
7136
- static int first = 8;
7137
- printf("src0: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", src0->ne[0], src0->ne[1], src0->ne[2]);
7138
- printf("src1: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", src1->ne[0], src1->ne[1], src1->ne[2]);
7139
- printf("dst: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", dst->ne[0], dst->ne[1], dst->ne[2]);
7140
- if (first) {
7141
- --first;
7142
- } else {
7143
- for (int k = 0; k < dst->ne[1]; ++k) {
7144
- for (int j = 0; j < dst->ne[0]/16; ++j) {
7145
- for (int i = 0; i < 16; ++i) {
7146
- printf("%8.4f ", ((float *) dst->data)[k*dst->ne[0] + j*16 + i]);
7147
- }
7148
- printf("\n");
7149
- }
7150
- printf("\n");
7151
- }
7152
- printf("\n");
7153
- exit(0);
7154
- }
7155
- } else {
7156
- printf("aaaa src0: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", src0->ne[0], src0->ne[1], src0->ne[2]);
7157
- printf("aaaa src1: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", src1->ne[0], src1->ne[1], src1->ne[2]);
7158
- printf("aaaa dst: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", dst->ne[0], dst->ne[1], dst->ne[2]);
7159
- }
7160
- #endif
7161
8214
  }
7162
8215
 
7163
8216
  // ggml_compute_forward_scale
@@ -7207,13 +8260,7 @@ static void ggml_compute_forward_scale(
7207
8260
  {
7208
8261
  ggml_compute_forward_scale_f32(params, src0, src1, dst);
7209
8262
  } break;
7210
- case GGML_TYPE_Q4_0:
7211
- case GGML_TYPE_Q4_1:
7212
- case GGML_TYPE_I8:
7213
- case GGML_TYPE_I16:
7214
- case GGML_TYPE_I32:
7215
- case GGML_TYPE_F16:
7216
- case GGML_TYPE_COUNT:
8263
+ default:
7217
8264
  {
7218
8265
  GGML_ASSERT(false);
7219
8266
  } break;
@@ -7374,6 +8421,9 @@ static void ggml_compute_forward_get_rows(
7374
8421
  switch (src0->type) {
7375
8422
  case GGML_TYPE_Q4_0:
7376
8423
  case GGML_TYPE_Q4_1:
8424
+ case GGML_TYPE_Q4_2:
8425
+ case GGML_TYPE_Q4_3:
8426
+ case GGML_TYPE_Q8_0:
7377
8427
  {
7378
8428
  ggml_compute_forward_get_rows_q(params, src0, src1, dst);
7379
8429
  } break;
@@ -7385,10 +8435,7 @@ static void ggml_compute_forward_get_rows(
7385
8435
  {
7386
8436
  ggml_compute_forward_get_rows_f32(params, src0, src1, dst);
7387
8437
  } break;
7388
- case GGML_TYPE_I8:
7389
- case GGML_TYPE_I16:
7390
- case GGML_TYPE_I32:
7391
- case GGML_TYPE_COUNT:
8438
+ default:
7392
8439
  {
7393
8440
  GGML_ASSERT(false);
7394
8441
  } break;
@@ -7461,13 +8508,7 @@ static void ggml_compute_forward_diag_mask_inf(
7461
8508
  {
7462
8509
  ggml_compute_forward_diag_mask_inf_f32(params, src0, src1, dst);
7463
8510
  } break;
7464
- case GGML_TYPE_Q4_0:
7465
- case GGML_TYPE_Q4_1:
7466
- case GGML_TYPE_I8:
7467
- case GGML_TYPE_I16:
7468
- case GGML_TYPE_I32:
7469
- case GGML_TYPE_F16:
7470
- case GGML_TYPE_COUNT:
8511
+ default:
7471
8512
  {
7472
8513
  GGML_ASSERT(false);
7473
8514
  } break;
@@ -7555,13 +8596,7 @@ static void ggml_compute_forward_soft_max(
7555
8596
  {
7556
8597
  ggml_compute_forward_soft_max_f32(params, src0, dst);
7557
8598
  } break;
7558
- case GGML_TYPE_Q4_0:
7559
- case GGML_TYPE_Q4_1:
7560
- case GGML_TYPE_I8:
7561
- case GGML_TYPE_I16:
7562
- case GGML_TYPE_I32:
7563
- case GGML_TYPE_F16:
7564
- case GGML_TYPE_COUNT:
8599
+ default:
7565
8600
  {
7566
8601
  GGML_ASSERT(false);
7567
8602
  } break;
@@ -7618,9 +8653,11 @@ static void ggml_compute_forward_rope_f32(
7618
8653
 
7619
8654
  const float theta_scale = powf(10000.0, -2.0f/n_dims);
7620
8655
 
8656
+ const bool is_neox = mode & 2;
8657
+
7621
8658
  for (int64_t i3 = 0; i3 < ne3; i3++) {
7622
- for (int64_t i2 = (mode == 0 ? 0 : n_past); i2 < ne2; i2++) {
7623
- const int p = (mode == 0 ? n_past + i2 : i2);
8659
+ for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) {
8660
+ const int p = ((mode & 1) == 0 ? n_past + i2 : i2);
7624
8661
  for (int64_t i1 = 0; i1 < ne1; i1++) {
7625
8662
  if (ir++ < ir0) continue;
7626
8663
  if (ir > ir1) break;
@@ -7633,14 +8670,25 @@ static void ggml_compute_forward_rope_f32(
7633
8670
 
7634
8671
  theta *= theta_scale;
7635
8672
 
7636
- const float * const src = (float *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
7637
- float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
8673
+ if (!is_neox) {
8674
+ const float * const src = (float *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
8675
+ float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
8676
+
8677
+ const float x0 = src[0];
8678
+ const float x1 = src[1];
7638
8679
 
7639
- const float x0 = src[0];
7640
- const float x1 = src[1];
8680
+ dst_data[0] = x0*cos_theta - x1*sin_theta;
8681
+ dst_data[1] = x0*sin_theta + x1*cos_theta;
8682
+ } else {
8683
+ const float * const src = (float *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + (i0/2)*nb0);
8684
+ float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + (i0/2)*nb0);
7641
8685
 
7642
- dst_data[0] = x0*cos_theta - x1*sin_theta;
7643
- dst_data[1] = x0*sin_theta + x1*cos_theta;
8686
+ const float x0 = src[0];
8687
+ const float x1 = src[n_dims/2];
8688
+
8689
+ dst_data[0] = x0*cos_theta - x1*sin_theta;
8690
+ dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
8691
+ }
7644
8692
  }
7645
8693
  }
7646
8694
  }
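The RoPE hunk above splits the mode argument into bit flags: bit 0 changes how the position p is derived from n_past, and bit 1 (is_neox) switches to the GPT-NeoX pairing, where the rotated pair is (i, i + n_dims/2) instead of the adjacent (i, i+1). The rotation itself is identical in both cases; only the addressing changes (helper name illustrative):

    // rotate one pair of values by theta
    static void rope_pair(float * v0, float * v1, float cos_theta, float sin_theta) {
        const float x0 = *v0;
        const float x1 = *v1;
        *v0 = x0*cos_theta - x1*sin_theta;
        *v1 = x0*sin_theta + x1*cos_theta;
    }

    // standard: for i0 = 0, 2, ..., n_dims-2:  rope_pair(&row[i0],   &row[i0 + 1],          cos_theta, sin_theta);
    // neox:     for i0 = 0, 2, ..., n_dims-2:  rope_pair(&row[i0/2], &row[i0/2 + n_dims/2], cos_theta, sin_theta);

The F16 variant below is the same, with GGML_FP16_TO_FP32/GGML_FP32_TO_FP16 conversions around the arithmetic.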
@@ -7695,9 +8743,11 @@ static void ggml_compute_forward_rope_f16(
7695
8743
 
7696
8744
  const float theta_scale = powf(10000.0, -2.0f/n_dims);
7697
8745
 
8746
+ const bool is_neox = mode & 2;
8747
+
7698
8748
  for (int64_t i3 = 0; i3 < ne3; i3++) {
7699
- for (int64_t i2 = (mode == 0 ? 0 : n_past); i2 < ne2; i2++) {
7700
- const int p = (mode == 0 ? n_past + i2 : i2);
8749
+ for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) {
8750
+ const int p = ((mode & 1) == 0 ? n_past + i2 : i2);
7701
8751
  for (int64_t i1 = 0; i1 < ne1; i1++) {
7702
8752
  if (ir++ < ir0) continue;
7703
8753
  if (ir > ir1) break;
@@ -7710,14 +8760,25 @@ static void ggml_compute_forward_rope_f16(
7710
8760
 
7711
8761
  theta *= theta_scale;
7712
8762
 
7713
- const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
7714
- ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
8763
+ if (!is_neox) {
8764
+ const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
8765
+ ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
8766
+
8767
+ const float x0 = GGML_FP16_TO_FP32(src[0]);
8768
+ const float x1 = GGML_FP16_TO_FP32(src[1]);
7715
8769
 
7716
- const float x0 = ggml_fp16_to_fp32(src[0]);
7717
- const float x1 = ggml_fp16_to_fp32(src[1]);
8770
+ dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
8771
+ dst_data[1] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
8772
+ } else {
8773
+ const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + (i0/2)*nb0);
8774
+ ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + (i0/2)*nb0);
7718
8775
 
7719
- dst_data[0] = ggml_fp32_to_fp16(x0*cos_theta - x1*sin_theta);
7720
- dst_data[1] = ggml_fp32_to_fp16(x0*sin_theta + x1*cos_theta);
8776
+ const float x0 = GGML_FP16_TO_FP32(src[0]);
8777
+ const float x1 = GGML_FP16_TO_FP32(src[n_dims/2]);
8778
+
8779
+ dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
8780
+ dst_data[n_dims/2] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
8781
+ }
7721
8782
  }
7722
8783
  }
7723
8784
  }
@@ -7738,12 +8799,7 @@ static void ggml_compute_forward_rope(
7738
8799
  {
7739
8800
  ggml_compute_forward_rope_f32(params, src0, src1, dst);
7740
8801
  } break;
7741
- case GGML_TYPE_Q4_0:
7742
- case GGML_TYPE_Q4_1:
7743
- case GGML_TYPE_I8:
7744
- case GGML_TYPE_I16:
7745
- case GGML_TYPE_I32:
7746
- case GGML_TYPE_COUNT:
8802
+ default:
7747
8803
  {
7748
8804
  GGML_ASSERT(false);
7749
8805
  } break;
@@ -8006,12 +9062,7 @@ static void ggml_compute_forward_conv_1d_1s(
8006
9062
  {
8007
9063
  ggml_compute_forward_conv_1d_1s_f32(params, src0, src1, dst);
8008
9064
  } break;
8009
- case GGML_TYPE_Q4_0:
8010
- case GGML_TYPE_Q4_1:
8011
- case GGML_TYPE_I8:
8012
- case GGML_TYPE_I16:
8013
- case GGML_TYPE_I32:
8014
- case GGML_TYPE_COUNT:
9065
+ default:
8015
9066
  {
8016
9067
  GGML_ASSERT(false);
8017
9068
  } break;
@@ -8274,12 +9325,7 @@ static void ggml_compute_forward_conv_1d_2s(
8274
9325
  {
8275
9326
  ggml_compute_forward_conv_1d_2s_f32(params, src0, src1, dst);
8276
9327
  } break;
8277
- case GGML_TYPE_Q4_0:
8278
- case GGML_TYPE_Q4_1:
8279
- case GGML_TYPE_I8:
8280
- case GGML_TYPE_I16:
8281
- case GGML_TYPE_I32:
8282
- case GGML_TYPE_COUNT:
9328
+ default:
8283
9329
  {
8284
9330
  GGML_ASSERT(false);
8285
9331
  } break;
@@ -8759,12 +9805,7 @@ static void ggml_compute_forward_flash_attn(
8759
9805
  {
8760
9806
  ggml_compute_forward_flash_attn_f32(params, q, k, v, masked, dst);
8761
9807
  } break;
8762
- case GGML_TYPE_Q4_0:
8763
- case GGML_TYPE_Q4_1:
8764
- case GGML_TYPE_I8:
8765
- case GGML_TYPE_I16:
8766
- case GGML_TYPE_I32:
8767
- case GGML_TYPE_COUNT:
9808
+ default:
8768
9809
  {
8769
9810
  GGML_ASSERT(false);
8770
9811
  } break;
@@ -8970,12 +10011,7 @@ static void ggml_compute_forward_flash_ff(
8970
10011
  {
8971
10012
  GGML_ASSERT(false); // TODO
8972
10013
  } break;
8973
- case GGML_TYPE_Q4_0:
8974
- case GGML_TYPE_Q4_1:
8975
- case GGML_TYPE_I8:
8976
- case GGML_TYPE_I16:
8977
- case GGML_TYPE_I32:
8978
- case GGML_TYPE_COUNT:
10014
+ default:
8979
10015
  {
8980
10016
  GGML_ASSERT(false);
8981
10017
  } break;
@@ -9019,13 +10055,7 @@ static void ggml_compute_forward_map_unary(
9019
10055
  {
9020
10056
  ggml_compute_forward_map_unary_f32(params, src0, dst, fun);
9021
10057
  } break;
9022
- case GGML_TYPE_Q4_0:
9023
- case GGML_TYPE_Q4_1:
9024
- case GGML_TYPE_I8:
9025
- case GGML_TYPE_I16:
9026
- case GGML_TYPE_I32:
9027
- case GGML_TYPE_F16:
9028
- case GGML_TYPE_COUNT:
10058
+ default:
9029
10059
  {
9030
10060
  GGML_ASSERT(false);
9031
10061
  } break;
@@ -9074,13 +10104,7 @@ static void ggml_compute_forward_map_binary(
9074
10104
  {
9075
10105
  ggml_compute_forward_map_binary_f32(params, src0, src1, dst, fun);
9076
10106
  } break;
9077
- case GGML_TYPE_Q4_0:
9078
- case GGML_TYPE_Q4_1:
9079
- case GGML_TYPE_I8:
9080
- case GGML_TYPE_I16:
9081
- case GGML_TYPE_I32:
9082
- case GGML_TYPE_F16:
9083
- case GGML_TYPE_COUNT:
10107
+ default:
9084
10108
  {
9085
10109
  GGML_ASSERT(false);
9086
10110
  } break;
@@ -9830,13 +10854,29 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
9830
10854
  struct ggml_tensor * node = cgraph->nodes[i];
9831
10855
 
9832
10856
  switch (node->op) {
10857
+ case GGML_OP_CPY:
9833
10858
  case GGML_OP_DUP:
9834
10859
  {
9835
- node->n_tasks = 1;
10860
+ node->n_tasks = n_threads;
10861
+
10862
+ size_t cur = 0;
10863
+ if (ggml_is_quantized(node->type)) {
10864
+ cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->ne[0] * n_threads;
10865
+ }
10866
+
10867
+ work_size = MAX(work_size, cur);
9836
10868
  } break;
9837
10869
  case GGML_OP_ADD:
9838
10870
  {
9839
10871
  node->n_tasks = n_threads;
10872
+
10873
+ size_t cur = 0;
10874
+
10875
+ if (ggml_is_quantized(node->src0->type)) {
10876
+ cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->src0->ne[0] * n_threads;
10877
+ }
10878
+
10879
+ work_size = MAX(work_size, cur);
9840
10880
  } break;
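GGML_OP_CPY/DUP and GGML_OP_ADD are now planned as multi-threaded ops, and when a quantized tensor is involved the scheduler reserves one F32 row per worker in the shared work buffer (the scratch the dequantize/requantize paths earlier in this diff write into). The reservation is small; as a worked example:

    // ne[0] = 4096, n_threads = 8:
    // cur = sizeof(float) * 4096 * 8 = 131072 bytes (128 KiB) added to work_size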
9841
10881
  case GGML_OP_SUB:
9842
10882
  case GGML_OP_MUL:
@@ -9881,7 +10921,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
9881
10921
  size_t cur = 0;
9882
10922
 
9883
10923
  if (node->src0->type == GGML_TYPE_F16 && node->src1->type == GGML_TYPE_F32) {
9884
- #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
10924
+ #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
9885
10925
  if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
9886
10926
  node->n_tasks = 1; // TODO: this actually is doing nothing
9887
10927
  // the threads are still spinning
@@ -9897,15 +10937,15 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
9897
10937
  #endif
9898
10938
  } else if (node->src0->type == GGML_TYPE_F32 && node->src1->type == GGML_TYPE_F32) {
9899
10939
  cur = 0;
9900
- } else if (quantize_fns[node->src0->type].vec_dot_q && node->src1->type == GGML_TYPE_F32) {
9901
- #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
10940
+ } else if (ggml_is_quantized(node->src0->type) && node->src1->type == GGML_TYPE_F32) {
10941
+ #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
9902
10942
  if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
9903
10943
  node->n_tasks = 1;
9904
10944
  cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]);
9905
10945
  } else
9906
10946
  #endif
9907
10947
  {
9908
- cur = GGML_TYPE_SIZE[node->src0->type]*ggml_nelements(node->src1)/GGML_BLCK_SIZE[node->src0->type];
10948
+ cur = GGML_TYPE_SIZE[GGML_TYPE_Q8_0]*ggml_nelements(node->src1)/GGML_BLCK_SIZE[GGML_TYPE_Q8_0];
9909
10949
  }
9910
10950
  } else {
9911
10951
  GGML_ASSERT(false);
@@ -9917,7 +10957,6 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
9917
10957
  {
9918
10958
  node->n_tasks = n_threads;
9919
10959
  } break;
9920
- case GGML_OP_CPY:
9921
10960
  case GGML_OP_CONT:
9922
10961
  case GGML_OP_RESHAPE:
9923
10962
  case GGML_OP_VIEW:
@@ -11080,16 +12119,16 @@ enum ggml_opt_result ggml_opt(
11080
12119
  ////////////////////////////////////////////////////////////////////////////////
11081
12120
 
11082
12121
  size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t * hist) {
11083
- assert(k % QK == 0);
11084
- const int nb = k / QK;
12122
+ assert(k % QK4_0 == 0);
12123
+ const int nb = k / QK4_0;
11085
12124
 
11086
12125
  for (int j = 0; j < n; j += k) {
11087
- block_q4_0 * restrict y = (block_q4_0 *)dst + j/QK;
12126
+ block_q4_0 * restrict y = (block_q4_0 *)dst + j/QK4_0;
11088
12127
 
11089
12128
  quantize_row_q4_0_reference(src + j, y, k);
11090
12129
 
11091
12130
  for (int i = 0; i < nb; i++) {
11092
- for (int l = 0; l < QK; l += 2) {
12131
+ for (int l = 0; l < QK4_0; l += 2) {
11093
12132
  const uint8_t vi0 = y[i].qs[l/2] & 0xF;
11094
12133
  const uint8_t vi1 = y[i].qs[l/2] >> 4;
11095
12134
 
@@ -11099,20 +12138,67 @@ size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t *
11099
12138
  }
11100
12139
  }
11101
12140
 
11102
- return (n/QK*sizeof(block_q4_0));
12141
+ return (n/QK4_0*sizeof(block_q4_0));
11103
12142
  }
11104
12143
 
11105
12144
  size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t * hist) {
11106
- assert(k % QK == 0);
11107
- const int nb = k / QK;
12145
+ assert(k % QK4_1 == 0);
12146
+ const int nb = k / QK4_1;
11108
12147
 
11109
12148
  for (int j = 0; j < n; j += k) {
11110
- block_q4_1 * restrict y = (block_q4_1 *)dst + j/QK;
12149
+ block_q4_1 * restrict y = (block_q4_1 *)dst + j/QK4_1;
11111
12150
 
11112
12151
  quantize_row_q4_1_reference(src + j, y, k);
11113
12152
 
11114
12153
  for (int i = 0; i < nb; i++) {
11115
- for (int l = 0; l < QK; l += 2) {
12154
+ for (int l = 0; l < QK4_1; l += 2) {
12155
+ const uint8_t vi0 = y[i].qs[l/2] & 0xF;
12156
+ const uint8_t vi1 = y[i].qs[l/2] >> 4;
12157
+
12158
+ hist[vi0]++;
12159
+ hist[vi1]++;
12160
+ }
12161
+ }
12162
+ }
12163
+
12164
+ return (n/QK4_1*sizeof(block_q4_1));
12165
+ }
12166
+
12167
+ size_t ggml_quantize_q4_2(const float * src, void * dst, int n, int k, int64_t * hist) {
12168
+ assert(k % QK4_2 == 0);
12169
+ const int nb = k / QK4_2;
12170
+
12171
+ for (int j = 0; j < n; j += k) {
12172
+ block_q4_2 * restrict y = (block_q4_2 *)dst + j/QK4_2;
12173
+
12174
+ //quantize_row_q4_2_reference(src + j, y, k);
12175
+ quantize_row_q4_2_rmse(src + j, y, k);
12176
+
12177
+ for (int i = 0; i < nb; i++) {
12178
+ for (int l = 0; l < QK4_2; l += 2) {
12179
+ const uint8_t vi0 = y[i].qs[l/2] & 0xF;
12180
+ const uint8_t vi1 = y[i].qs[l/2] >> 4;
12181
+
12182
+ hist[vi0]++;
12183
+ hist[vi1]++;
12184
+ }
12185
+ }
12186
+ }
12187
+
12188
+ return (n/QK4_2*sizeof(block_q4_2));
12189
+ }
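ggml_quantize_q4_2 is the only helper here that does not call the plain reference quantizer: it uses quantize_row_q4_2_rmse, which searches for a block scale that minimizes the reconstruction error instead of deriving it directly from the absolute maximum. As a rough illustration of the idea only (this is not the quantize_row_q4_2_rmse implementation; it assumes <math.h> and <float.h>):

    // try a few candidate scales around amax/7 and keep the one with the
    // smallest squared reconstruction error for a [-8, 7] 4-bit grid
    static float pick_scale_q4(const float * x, int n) {
        float amax = 0.0f;
        for (int i = 0; i < n; i++) {
            const float ax = fabsf(x[i]);
            if (ax > amax) amax = ax;
        }
        float best_d   = amax / 7.0f;
        float best_err = FLT_MAX;
        for (int s = 0; s < 16; s++) {
            const float d = (amax / 7.0f) * (0.5f + 0.05f*s); // candidate scale
            if (d == 0.0f) break;
            float err = 0.0f;
            for (int i = 0; i < n; i++) {
                int q = (int) roundf(x[i] / d);
                if (q < -8) q = -8;
                if (q >  7) q =  7;
                const float r = x[i] - q*d;
                err += r*r;
            }
            if (err < best_err) { best_err = err; best_d = d; }
        }
        return best_d;
    }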
12190
+
12191
+ size_t ggml_quantize_q4_3(const float * src, void * dst, int n, int k, int64_t * hist) {
12192
+ assert(k % QK4_3 == 0);
12193
+ const int nb = k / QK4_3;
12194
+
12195
+ for (int j = 0; j < n; j += k) {
12196
+ block_q4_3 * restrict y = (block_q4_3 *)dst + j/QK4_3;
12197
+
12198
+ quantize_row_q4_3_reference(src + j, y, k);
12199
+
12200
+ for (int i = 0; i < nb; i++) {
12201
+ for (int l = 0; l < QK4_3; l += 2) {
11116
12202
  const uint8_t vi0 = y[i].qs[l/2] & 0xF;
11117
12203
  const uint8_t vi1 = y[i].qs[l/2] >> 4;
11118
12204
 
@@ -11122,7 +12208,40 @@ size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t *
11122
12208
  }
11123
12209
  }
11124
12210
 
11125
- return (n/QK*sizeof(block_q4_1));
12211
+ return (n/QK4_3*sizeof(block_q4_3));
12212
+ }
12213
+
12214
+ size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, int start, int n, int64_t * hist) {
12215
+ size_t result = 0;
12216
+ switch (type) {
12217
+ case GGML_TYPE_Q4_0:
12218
+ {
12219
+ GGML_ASSERT(start % QK4_0 == 0);
12220
+ block_q4_0 * block = (block_q4_0*)dst + start / QK4_0;
12221
+ result = ggml_quantize_q4_0(src + start, block, n, n, hist);
12222
+ } break;
12223
+ case GGML_TYPE_Q4_1:
12224
+ {
12225
+ GGML_ASSERT(start % QK4_1 == 0);
12226
+ block_q4_1 * block = (block_q4_1*)dst + start / QK4_1;
12227
+ result = ggml_quantize_q4_1(src + start, block, n, n, hist);
12228
+ } break;
12229
+ case GGML_TYPE_Q4_2:
12230
+ {
12231
+ GGML_ASSERT(start % QK4_2 == 0);
12232
+ block_q4_2 * block = (block_q4_2*)dst + start / QK4_2;
12233
+ result = ggml_quantize_q4_2(src + start, block, n, n, hist);
12234
+ } break;
12235
+ case GGML_TYPE_Q4_3:
12236
+ {
12237
+ GGML_ASSERT(start % QK4_3 == 0);
12238
+ block_q4_3 * block = (block_q4_3*)dst + start / QK4_3;
12239
+ result = ggml_quantize_q4_3(src + start, block, n, n, hist);
12240
+ } break;
12241
+ default:
12242
+ assert(false);
12243
+ }
12244
+ return result;
11126
12245
  }
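ggml_quantize_chunk is a new single entry point that dispatches on the target type and applies the block-aligned offset itself, so callers can quantize a large tensor in independent, block-aligned pieces (for example from several threads). A hedged usage sketch; only the signature comes from this diff, the chunking policy is the caller's:

    int64_t hist[16] = {0};          // 4-bit value histogram, accumulated across chunks
    size_t  bytes    = 0;
    const int chunk  = 32*1024;      // a multiple of every QK4_x block size

    for (int start = 0; start < nelements; start += chunk) {
        const int n = MIN(chunk, nelements - start);   // MIN as defined in this file
        bytes += ggml_quantize_chunk(GGML_TYPE_Q4_2, src, dst, start, n, hist);
    }

src and dst are the base pointers of the whole tensor; start is the element offset and must be a multiple of the block size for the chosen type, which the GGML_ASSERTs above enforce.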
11127
12246
 
11128
12247
  ////////////////////////////////////////////////////////////////////////////////
@@ -11151,6 +12270,22 @@ int ggml_cpu_has_avx512(void) {
11151
12270
  #endif
11152
12271
  }
11153
12272
 
12273
+ int ggml_cpu_has_avx512_vbmi(void) {
12274
+ #if defined(__AVX512VBMI__)
12275
+ return 1;
12276
+ #else
12277
+ return 0;
12278
+ #endif
12279
+ }
12280
+
12281
+ int ggml_cpu_has_avx512_vnni(void) {
12282
+ #if defined(__AVX512VNNI__)
12283
+ return 1;
12284
+ #else
12285
+ return 0;
12286
+ #endif
12287
+ }
12288
+
11154
12289
  int ggml_cpu_has_fma(void) {
11155
12290
  #if defined(__FMA__)
11156
12291
  return 1;
@@ -11200,7 +12335,15 @@ int ggml_cpu_has_wasm_simd(void) {
11200
12335
  }
11201
12336
 
11202
12337
  int ggml_cpu_has_blas(void) {
11203
- #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
12338
+ #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
12339
+ return 1;
12340
+ #else
12341
+ return 0;
12342
+ #endif
12343
+ }
12344
+
12345
+ int ggml_cpu_has_cublas(void) {
12346
+ #if defined(GGML_USE_CUBLAS)
11204
12347
  return 1;
11205
12348
  #else
11206
12349
  return 0;