llama_cpp 0.0.4 → 0.0.5

This diff shows the content of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their public registries.
@@ -118,7 +118,16 @@ typedef void* thread_ret_t;
118
118
  #define GGML_ALIGNED_MALLOC(size) _aligned_malloc(size, GGML_MEM_ALIGN)
119
119
  #define GGML_ALIGNED_FREE(ptr) _aligned_free(ptr)
120
120
  #else
121
- #define GGML_ALIGNED_MALLOC(size) aligned_alloc(GGML_MEM_ALIGN, size)
121
+ inline static void* ggml_aligned_malloc(size_t size) {
122
+ void* aligned_memory = NULL;
123
+ int result = posix_memalign(&aligned_memory, GGML_MEM_ALIGN, size);
124
+ if (result != 0) {
125
+ // Handle allocation failure
126
+ return NULL;
127
+ }
128
+ return aligned_memory;
129
+ }
130
+ #define GGML_ALIGNED_MALLOC(size) ggml_aligned_malloc(size)
122
131
  #define GGML_ALIGNED_FREE(ptr) free(ptr)
123
132
  #endif
124
133
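Worth noting about this first hunk: C11 aligned_alloc requires the requested size to be an integral multiple of the alignment (and is not available everywhere), while posix_memalign places no such requirement on the size, which is why the macro is now routed through a small helper. A minimal standalone sketch of the same pattern — the helper name and sizes below are illustrative only, not symbols from the package:

    #define _POSIX_C_SOURCE 200112L   /* expose posix_memalign in <stdlib.h> */
    #include <stdio.h>
    #include <stdlib.h>

    #define MEM_ALIGN 16              /* stand-in for GGML_MEM_ALIGN */

    static void * aligned_malloc_sketch(size_t size) {
        void * ptr = NULL;
        /* unlike C11 aligned_alloc, posix_memalign does not require size % MEM_ALIGN == 0 */
        if (posix_memalign(&ptr, MEM_ALIGN, size) != 0) {
            return NULL;              /* allocation failure */
        }
        return ptr;
    }

    int main(void) {
        float * buf = aligned_malloc_sketch(10 * sizeof(float)); /* 40 bytes, not a multiple of 16 */
        if (buf == NULL) {
            return 1;
        }
        printf("allocated %p, %d-byte aligned\n", (void *) buf, MEM_ALIGN);
        free(buf);                    /* released with plain free(), as GGML_ALIGNED_FREE does */
        return 0;
    }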
 
@@ -418,8 +427,6 @@ static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
418
427
  // quantization
419
428
  //
420
429
 
421
- #define QK 32
422
-
423
430
  // AVX routines provided by GH user Const-me
424
431
  // ref: https://github.com/ggerganov/ggml/pull/27#issuecomment-1464934600
425
432
  #if __AVX2__ || __AVX512F__
@@ -531,68 +538,73 @@ inline static float vaddvq_f32(float32x4_t v) {
531
538
  return vgetq_lane_f32(v, 0) + vgetq_lane_f32(v, 1) + vgetq_lane_f32(v, 2) + vgetq_lane_f32(v, 3);
532
539
  }
533
540
 
534
- inline float vminvq_f32(float32x4_t v) {
541
+ float vminvq_f32(float32x4_t v) {
535
542
  return
536
543
  MIN(MIN(vgetq_lane_f32(v, 0), vgetq_lane_f32(v, 1)),
537
544
  MIN(vgetq_lane_f32(v, 2), vgetq_lane_f32(v, 3)));
538
545
  }
539
546
 
540
- inline float vmaxvq_f32(float32x4_t v) {
547
+ float vmaxvq_f32(float32x4_t v) {
541
548
  return
542
549
  MAX(MAX(vgetq_lane_f32(v, 0), vgetq_lane_f32(v, 1)),
543
550
  MAX(vgetq_lane_f32(v, 2), vgetq_lane_f32(v, 3)));
544
551
  }
545
552
 
546
- inline int8x8_t vzip1_s8(int8x8_t a, int8x8_t b) {
553
+ int8x8_t vzip1_s8(int8x8_t a, int8x8_t b) {
547
554
  return vget_low_s8(vcombine_s8(a, b));
548
555
  }
549
556
 
550
- inline int8x8_t vzip2_s8(int8x8_t a, int8x8_t b) {
557
+ int8x8_t vzip2_s8(int8x8_t a, int8x8_t b) {
551
558
  return vget_high_s8(vcombine_s8(a, b));
552
559
  }
553
560
 
554
- inline uint8x8_t vzip1_u8(uint8x8_t a, uint8x8_t b) {
561
+ uint8x8_t vzip1_u8(uint8x8_t a, uint8x8_t b) {
555
562
  return vget_low_u8(vcombine_u8(a, b));
556
563
  }
557
564
 
558
- inline uint8x8_t vzip2_u8(uint8x8_t a, uint8x8_t b) {
565
+ uint8x8_t vzip2_u8(uint8x8_t a, uint8x8_t b) {
559
566
  return vget_high_u8(vcombine_u8(a, b));
560
567
  }
561
568
 
562
569
  #endif
563
570
  #endif
564
571
 
565
- // method 5
566
- // blocks of QK elements
567
- // represented with a single float (delta) and QK/2 8-bit ints (i.e QK 4-bit signed integer factors)
572
+
573
+ #define QK4_0 32
568
574
  typedef struct {
569
- float d; // delta
570
- uint8_t qs[QK / 2]; // nibbles / quants
575
+ float d; // delta
576
+ uint8_t qs[QK4_0 / 2]; // nibbles / quants
571
577
  } block_q4_0;
572
- static_assert(sizeof(block_q4_0) == sizeof(float) + QK / 2, "wrong q4_0 block size/padding");
578
+ static_assert(sizeof(block_q4_0) == sizeof(float) + QK4_0 / 2, "wrong q4_0 block size/padding");
573
579
 
574
- // method 4
575
- // blocks of QK elements
576
- // represented with 2 floats (delta + min) and QK/2 8-bit ints (i.e QK 4-bit unsigned integer factors)
580
+ #define QK4_1 32
577
581
  typedef struct {
578
- float d;
579
- float m;
580
- uint8_t qs[QK / 2]; // nibbles / quants
582
+ float d; // delta
583
+ float m; // min
584
+ uint8_t qs[QK4_1 / 2]; // nibbles / quants
581
585
  } block_q4_1;
582
- static_assert(sizeof(block_q4_1) == sizeof(float) * 2 + QK / 2, "wrong q4_1 block size/padding");
586
+ static_assert(sizeof(block_q4_1) == sizeof(float) * 2 + QK4_1 / 2, "wrong q4_1 block size/padding");
587
+
588
+ #define QK8_0 32
589
+ typedef struct {
590
+ float d; // delta
591
+ int8_t qs[QK8_0]; // quants
592
+ } block_q8_0;
593
+ static_assert(sizeof(block_q8_0) == sizeof(float) + QK8_0, "wrong q8_0 block size/padding");
594
+
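Splitting the old QK constant into QK4_0, QK4_1 and QK8_0 keeps every block at 32 elements for now but lets each format evolve independently. Assuming 4-byte floats and the padding-free layouts that the static_asserts above enforce, the storage cost per weight works out to:

    block_q4_0: 4 + 32/2 = 20 bytes per 32 weights  ->  20*8/32 = 5.0 bits/weight
    block_q4_1: 8 + 32/2 = 24 bytes per 32 weights  ->  24*8/32 = 6.0 bits/weight
    block_q8_0: 4 + 32   = 36 bytes per 32 weights  ->  36*8/32 = 9.0 bits/weight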
583
595
 
584
596
  // reference implementation for deterministic creation of model files
585
597
  static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k) {
586
- assert(k % QK == 0);
587
- const int nb = k / QK;
598
+ assert(k % QK4_0 == 0);
599
+ const int nb = k / QK4_0;
588
600
 
589
- uint8_t pp[QK/2];
601
+ uint8_t pp[QK4_0/2];
590
602
 
591
603
  for (int i = 0; i < nb; i++) {
592
604
  float amax = 0.0f; // absolute max
593
605
 
594
- for (int l = 0; l < QK; l++) {
595
- const float v = x[i*QK + l];
606
+ for (int l = 0; l < QK4_0; l++) {
607
+ const float v = x[i*QK4_0 + l];
596
608
  amax = MAX(amax, fabsf(v));
597
609
  }
598
610
 
@@ -601,9 +613,9 @@ static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * r
601
613
 
602
614
  y[i].d = d;
603
615
 
604
- for (int l = 0; l < QK; l += 2) {
605
- const float v0 = x[i*QK + l + 0]*id;
606
- const float v1 = x[i*QK + l + 1]*id;
616
+ for (int l = 0; l < QK4_0; l += 2) {
617
+ const float v0 = x[i*QK4_0 + l + 0]*id;
618
+ const float v1 = x[i*QK4_0 + l + 1]*id;
607
619
 
608
620
  const uint8_t vi0 = (int8_t)roundf(v0) + 8;
609
621
  const uint8_t vi1 = (int8_t)roundf(v1) + 8;
@@ -619,8 +631,8 @@ static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * r
619
631
  }
620
632
 
621
633
  static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int k) {
622
- assert(k % QK == 0);
623
- const int nb = k / QK;
634
+ assert(k % QK4_0 == 0);
635
+ const int nb = k / QK4_0;
624
636
 
625
637
  block_q4_0 * restrict y = vy;
626
638
 
@@ -870,19 +882,19 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
870
882
  }
871
883
 
872
884
  static void quantize_row_q4_1_reference(const float * restrict x, void * restrict vy, int k) {
873
- assert(k % QK == 0);
874
- const int nb = k / QK;
885
+ assert(k % QK4_1 == 0);
886
+ const int nb = k / QK4_1;
875
887
 
876
888
  block_q4_1 * restrict y = vy;
877
889
 
878
- uint8_t pp[QK/2];
890
+ uint8_t pp[QK4_1/2];
879
891
 
880
892
  for (int i = 0; i < nb; i++) {
881
893
  float min = FLT_MAX;
882
894
  float max = -FLT_MAX;
883
895
 
884
- for (int l = 0; l < QK; l++) {
885
- const float v = x[i*QK + l];
896
+ for (int l = 0; l < QK4_1; l++) {
897
+ const float v = x[i*QK4_1 + l];
886
898
  if (v < min) min = v;
887
899
  if (v > max) max = v;
888
900
  }
@@ -893,9 +905,9 @@ static void quantize_row_q4_1_reference(const float * restrict x, void * restric
893
905
  y[i].d = d;
894
906
  y[i].m = min;
895
907
 
896
- for (int l = 0; l < QK; l += 2) {
897
- const float v0 = (x[i*QK + l + 0] - min)*id;
898
- const float v1 = (x[i*QK + l + 1] - min)*id;
908
+ for (int l = 0; l < QK4_1; l += 2) {
909
+ const float v0 = (x[i*QK4_1 + l + 0] - min)*id;
910
+ const float v1 = (x[i*QK4_1 + l + 1] - min)*id;
899
911
 
900
912
  const uint8_t vi0 = roundf(v0);
901
913
  const uint8_t vi1 = roundf(v1);
@@ -911,9 +923,9 @@ static void quantize_row_q4_1_reference(const float * restrict x, void * restric
911
923
  }
912
924
 
913
925
  static void quantize_row_q4_1(const float * restrict x, void * restrict vy, int k) {
914
- assert(k % QK == 0);
926
+ assert(k % QK4_1 == 0);
915
927
 
916
- const int nb = k / QK;
928
+ const int nb = k / QK4_1;
917
929
 
918
930
  block_q4_1 * restrict y = vy;
919
931
 
@@ -997,7 +1009,7 @@ static void quantize_row_q4_1(const float * restrict x, void * restrict vy, int
997
1009
  float32x4_t minv[8];
998
1010
  float32x4_t maxv[8];
999
1011
 
1000
- for (int l = 0; l < 8; l++) srcv[l] = vld1q_f32(x + i*QK + 4*l);
1012
+ for (int l = 0; l < 8; l++) srcv[l] = vld1q_f32(x + i*QK4_1 + 4*l);
1001
1013
 
1002
1014
  for (int l = 0; l < 4; l++) minv[2*l] = vminq_f32(srcv[2*l], srcv[2*l + 1]);
1003
1015
  for (int l = 0; l < 2; l++) minv[4*l] = vminq_f32(minv[4*l], minv[4*l + 2]);
@@ -1033,9 +1045,160 @@ static void quantize_row_q4_1(const float * restrict x, void * restrict vy, int
1033
1045
  #endif
1034
1046
  }
1035
1047
 
1048
+ // reference implementation for deterministic creation of model files
1049
+ static void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * restrict y, int k) {
1050
+ assert(k % QK8_0 == 0);
1051
+ const int nb = k / QK8_0;
1052
+
1053
+ for (int i = 0; i < nb; i++) {
1054
+ float amax = 0.0f; // absolute max
1055
+
1056
+ for (int l = 0; l < QK8_0; l++) {
1057
+ const float v = x[i*QK8_0 + l];
1058
+ amax = MAX(amax, fabsf(v));
1059
+ }
1060
+
1061
+ const float d = amax / ((1 << 7) - 1);
1062
+ const float id = d ? 1.0f/d : 0.0f;
1063
+
1064
+ y[i].d = d;
1065
+
1066
+ for (int l = 0; l < QK8_0; ++l) {
1067
+ const float v = x[i*QK8_0 + l]*id;
1068
+ y[i].qs[l] = roundf(v);
1069
+ }
1070
+ }
1071
+ }
1072
+
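A quick numeric check of the reference formula above: if the largest magnitude in a block is amax = 12.7, then d = 12.7/127 = 0.1 and id = 10, so an input of 0.42 is stored as roundf(4.2) = 4 and dequantizes to 4*0.1 = 0.4, while -12.7 maps exactly to -127; the absolute round-trip error of any value in the block is bounded by d/2 = 0.05.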
1073
+ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int k) {
1074
+ assert(k % QK8_0 == 0);
1075
+ const int nb = k / QK8_0;
1076
+
1077
+ block_q8_0 * restrict y = vy;
1078
+
1079
+ #if defined(__ARM_NEON)
1080
+ for (int i = 0; i < nb; i++) {
1081
+ float32x4_t srcv [8];
1082
+ float32x4_t asrcv[8];
1083
+ float32x4_t amaxv[8];
1084
+
1085
+ for (int l = 0; l < 8; l++) srcv[l] = vld1q_f32(x + i*32 + 4*l);
1086
+ for (int l = 0; l < 8; l++) asrcv[l] = vabsq_f32(srcv[l]);
1087
+
1088
+ for (int l = 0; l < 4; l++) amaxv[2*l] = vmaxq_f32(asrcv[2*l], asrcv[2*l+1]);
1089
+ for (int l = 0; l < 2; l++) amaxv[4*l] = vmaxq_f32(amaxv[4*l], amaxv[4*l+2]);
1090
+ for (int l = 0; l < 1; l++) amaxv[8*l] = vmaxq_f32(amaxv[8*l], amaxv[8*l+4]);
1091
+
1092
+ const float amax = vmaxvq_f32(amaxv[0]);
1093
+
1094
+ const float d = amax / ((1 << 7) - 1);
1095
+ const float id = d ? 1.0f/d : 0.0f;
1096
+
1097
+ y[i].d = d;
1098
+
1099
+ for (int l = 0; l < 8; l++) {
1100
+ const float32x4_t v = vmulq_n_f32(srcv[l], id);
1101
+ const int32x4_t vi = vcvtnq_s32_f32(v);
1102
+
1103
+ y[i].qs[4*l + 0] = vgetq_lane_s32(vi, 0);
1104
+ y[i].qs[4*l + 1] = vgetq_lane_s32(vi, 1);
1105
+ y[i].qs[4*l + 2] = vgetq_lane_s32(vi, 2);
1106
+ y[i].qs[4*l + 3] = vgetq_lane_s32(vi, 3);
1107
+ }
1108
+ }
1109
+ #elif defined(__AVX2__) || defined(__AVX__)
1110
+ for (int i = 0; i < nb; i++) {
1111
+ // Load elements into 4 AVX vectors
1112
+ __m256 v0 = _mm256_loadu_ps( x );
1113
+ __m256 v1 = _mm256_loadu_ps( x + 8 );
1114
+ __m256 v2 = _mm256_loadu_ps( x + 16 );
1115
+ __m256 v3 = _mm256_loadu_ps( x + 24 );
1116
+ x += 32;
1117
+
1118
+ // Compute max(abs(e)) for the block
1119
+ const __m256 signBit = _mm256_set1_ps( -0.0f );
1120
+ __m256 maxAbs = _mm256_andnot_ps( signBit, v0 );
1121
+ maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) );
1122
+ maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );
1123
+ maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) );
1124
+
1125
+ __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
1126
+ max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
1127
+ max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
1128
+ const float maxScalar = _mm_cvtss_f32( max4 );
1129
+
1130
+ // Quantize these floats
1131
+ const float d = maxScalar / 127.f;
1132
+ y[i].d = d;
1133
+ const float id = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f;
1134
+ const __m256 mul = _mm256_set1_ps( id );
1135
+
1136
+ // Apply the multiplier
1137
+ v0 = _mm256_mul_ps( v0, mul );
1138
+ v1 = _mm256_mul_ps( v1, mul );
1139
+ v2 = _mm256_mul_ps( v2, mul );
1140
+ v3 = _mm256_mul_ps( v3, mul );
1141
+
1142
+ // Round to nearest integer
1143
+ v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST );
1144
+ v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST );
1145
+ v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST );
1146
+ v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST );
1147
+
1148
+ // Convert floats to integers
1149
+ __m256i i0 = _mm256_cvtps_epi32( v0 );
1150
+ __m256i i1 = _mm256_cvtps_epi32( v1 );
1151
+ __m256i i2 = _mm256_cvtps_epi32( v2 );
1152
+ __m256i i3 = _mm256_cvtps_epi32( v3 );
1153
+
1154
+ #if defined(__AVX2__)
1155
+ // Convert int32 to int16
1156
+ i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15
1157
+ i2 = _mm256_packs_epi32( i2, i3 ); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31
1158
+ // Convert int16 to int8
1159
+ i0 = _mm256_packs_epi16( i0, i2 ); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31
1160
+
1161
+ // We got our precious signed bytes, but the order is now wrong
1162
+ // These AVX2 pack instructions process 16-byte pieces independently
1163
+ // The following instruction fixes the order
1164
+ const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
1165
+ i0 = _mm256_permutevar8x32_epi32( i0, perm );
1166
+
1167
+ _mm256_storeu_si256((__m256i *)y[i].qs, i0);
1168
+ #else
1169
+ // Since AVX lacks some of the intrinsics we need,
1170
+ // we split the registers in half and use the SSE counterparts of the AVX2 instructions
1171
+ __m128i ni0 = _mm256_castsi256_si128( i0 );
1172
+ __m128i ni1 = _mm256_extractf128_si256( i0, 1);
1173
+ __m128i ni2 = _mm256_castsi256_si128( i1 );
1174
+ __m128i ni3 = _mm256_extractf128_si256( i1, 1);
1175
+ __m128i ni4 = _mm256_castsi256_si128( i2 );
1176
+ __m128i ni5 = _mm256_extractf128_si256( i2, 1);
1177
+ __m128i ni6 = _mm256_castsi256_si128( i3 );
1178
+ __m128i ni7 = _mm256_extractf128_si256( i3, 1);
1179
+
1180
+ // Convert int32 to int16
1181
+ ni0 = _mm_packs_epi32( ni0, ni1 );
1182
+ ni2 = _mm_packs_epi32( ni2, ni3 );
1183
+ ni4 = _mm_packs_epi32( ni4, ni5 );
1184
+ ni6 = _mm_packs_epi32( ni6, ni7 );
1185
+ // Convert int16 to int8
1186
+ ni0 = _mm_packs_epi16( ni0, ni2 );
1187
+ ni4 = _mm_packs_epi16( ni4, ni6 );
1188
+
1189
+ _mm_storeu_si128((__m128i *)(y[i].qs + 0), ni0);
1190
+ _mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4);
1191
+ #endif
1192
+ }
1193
+ #else
1194
+ // scalar
1195
+ quantize_row_q8_0_reference(x, y, k);
1196
+ #endif
1197
+ }
1198
+
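For experimenting with the format outside of ggml, the following self-contained program restates the scalar path that the function above falls back to when neither NEON nor AVX is available (it mirrors quantize_row_q8_0_reference for a single block and is not part of the package):

    /* build: cc q8_0_sketch.c -lm */
    #include <math.h>
    #include <stdint.h>
    #include <stdio.h>

    #define QK8_0 32

    int main(void) {
        float x[QK8_0];
        for (int l = 0; l < QK8_0; l++) {
            x[l] = 0.1f * (float)(l - 16);      /* sample data in [-1.6, 1.5] */
        }

        float amax = 0.0f;                      /* absolute max of the block */
        for (int l = 0; l < QK8_0; l++) {
            amax = fmaxf(amax, fabsf(x[l]));
        }

        const float d  = amax / 127.0f;         /* delta */
        const float id = d ? 1.0f/d : 0.0f;

        int8_t qs[QK8_0];
        for (int l = 0; l < QK8_0; l++) {
            qs[l] = (int8_t) roundf(x[l] * id); /* quantize to [-127, 127] */
        }

        /* round-trip: each value is recovered as qs[l]*d, within d/2 of the input */
        printf("d = %f  x[0] = %f  dequantized = %f\n", d, x[0], qs[0] * d);
        return 0;
    }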
1036
1199
  static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, int k) {
1037
- assert(k % QK == 0);
1038
- const int nb = k / QK;
1200
+ assert(k % QK4_0 == 0);
1201
+ const int nb = k / QK4_0;
1039
1202
 
1040
1203
  const block_q4_0 * restrict x = vx;
1041
1204
 
@@ -1046,7 +1209,7 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in
1046
1209
 
1047
1210
  const uint8_t * restrict pp = x[i].qs;
1048
1211
 
1049
- for (int l = 0; l < QK; l += 32) {
1212
+ for (int l = 0; l < QK4_0; l += 32) {
1050
1213
  // Load 32x4-bit integers into 32x8-bit integers
1051
1214
  __m256i vx8 = bytesFromNibbles(pp+l/2);
1052
1215
 
@@ -1068,7 +1231,7 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in
1068
1231
  // Scale and store
1069
1232
  for (int j = 0; j < 4; j++) {
1070
1233
  const __m256 result = _mm256_mul_ps(vf[j], d_v);
1071
- _mm256_storeu_ps(y + i * QK + l + j*8, result);
1234
+ _mm256_storeu_ps(y + i * QK4_0 + l + j*8, result);
1072
1235
  }
1073
1236
  }
1074
1237
  }
@@ -1078,7 +1241,7 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in
1078
1241
 
1079
1242
  const uint8_t * restrict pp = x[i].qs;
1080
1243
 
1081
- for (int l = 0; l < QK; l += 16) {
1244
+ for (int l = 0; l < QK4_0; l += 16) {
1082
1245
  // Load 16x4-bit integers into 8x8-bit integers
1083
1246
  const uint8x8_t v8 = vld1_u8(pp + l/2);
1084
1247
 
@@ -1117,10 +1280,10 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in
1117
1280
  const float32x4_t r3 = vmulq_f32(vf_3, vd);
1118
1281
 
1119
1282
  // Store
1120
- vst1q_f32(y + i*QK + l + 0, r0);
1121
- vst1q_f32(y + i*QK + l + 4, r1);
1122
- vst1q_f32(y + i*QK + l + 8, r2);
1123
- vst1q_f32(y + i*QK + l + 12, r3);
1283
+ vst1q_f32(y + i*QK4_0 + l + 0, r0);
1284
+ vst1q_f32(y + i*QK4_0 + l + 4, r1);
1285
+ vst1q_f32(y + i*QK4_0 + l + 8, r2);
1286
+ vst1q_f32(y + i*QK4_0 + l + 12, r3);
1124
1287
  }
1125
1288
  }
1126
1289
  #else
@@ -1130,7 +1293,7 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in
1130
1293
 
1131
1294
  const uint8_t * restrict pp = x[i].qs;
1132
1295
 
1133
- for (int l = 0; l < QK; l += 2) {
1296
+ for (int l = 0; l < QK4_0; l += 2) {
1134
1297
  const uint8_t vi = pp[l/2];
1135
1298
 
1136
1299
  const int8_t vi0 = vi & 0xf;
@@ -1141,19 +1304,19 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in
1141
1304
 
1142
1305
  //printf("d = %f, vi = %d, vi0 = %d, vi1 = %d, v0 = %f, v1 = %f\n", d, vi, vi0, vi1, v0, v1);
1143
1306
 
1144
- y[i*QK + l + 0] = v0;
1145
- y[i*QK + l + 1] = v1;
1307
+ y[i*QK4_0 + l + 0] = v0;
1308
+ y[i*QK4_0 + l + 1] = v1;
1146
1309
 
1147
- assert(!isnan(y[i*QK + l + 0]));
1148
- assert(!isnan(y[i*QK + l + 1]));
1310
+ assert(!isnan(y[i*QK4_0 + l + 0]));
1311
+ assert(!isnan(y[i*QK4_0 + l + 1]));
1149
1312
  }
1150
1313
  }
1151
1314
  #endif
1152
1315
  }
1153
1316
 
1154
1317
  static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, int k) {
1155
- assert(k % QK == 0);
1156
- const int nb = k / QK;
1318
+ assert(k % QK4_1 == 0);
1319
+ const int nb = k / QK4_1;
1157
1320
 
1158
1321
  const block_q4_1 * restrict x = vx;
1159
1322
 
@@ -1164,7 +1327,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
1164
1327
 
1165
1328
  const uint8_t * restrict pp = x[i].qs;
1166
1329
 
1167
- for (int l = 0; l < QK; l += 32) {
1330
+ for (int l = 0; l < QK4_1; l += 32) {
1168
1331
  // Load 32x4-bit integers into 32x8-bit integers
1169
1332
  __m256i vx8 = bytesFromNibbles(pp+l/2);
1170
1333
 
@@ -1183,7 +1346,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
1183
1346
  // Scale, add m and store
1184
1347
  for (int j = 0; j < 4; j++) {
1185
1348
  const __m256 result = _mm256_add_ps(_mm256_mul_ps(vf[j], d_v), d_m);
1186
- _mm256_storeu_ps(y + i * QK + l + j*8, result);
1349
+ _mm256_storeu_ps(y + i * QK4_1 + l + j*8, result);
1187
1350
  }
1188
1351
  }
1189
1352
  }
@@ -1194,7 +1357,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
1194
1357
 
1195
1358
  const uint8_t * restrict pp = x[i].qs;
1196
1359
 
1197
- for (int l = 0; l < QK; l += 16) {
1360
+ for (int l = 0; l < QK4_1; l += 16) {
1198
1361
  // Load 16x4-bit integers into 8x8-bit integers
1199
1362
  const uint8x8_t v8 = vld1_u8(pp + l/2);
1200
1363
 
@@ -1225,10 +1388,10 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
1225
1388
  const float32x4_t r3 = vmlaq_f32(vm, vf_3, vd);
1226
1389
 
1227
1390
  // Store
1228
- vst1q_f32(y + i*QK + l + 0, r0);
1229
- vst1q_f32(y + i*QK + l + 4, r1);
1230
- vst1q_f32(y + i*QK + l + 8, r2);
1231
- vst1q_f32(y + i*QK + l + 12, r3);
1391
+ vst1q_f32(y + i*QK4_1 + l + 0, r0);
1392
+ vst1q_f32(y + i*QK4_1 + l + 4, r1);
1393
+ vst1q_f32(y + i*QK4_1 + l + 8, r2);
1394
+ vst1q_f32(y + i*QK4_1 + l + 12, r3);
1232
1395
  }
1233
1396
  }
1234
1397
  #else
@@ -1238,7 +1401,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
1238
1401
 
1239
1402
  const uint8_t * restrict pp = x[i].qs;
1240
1403
 
1241
- for (int l = 0; l < QK; l += 2) {
1404
+ for (int l = 0; l < QK4_1; l += 2) {
1242
1405
  const uint8_t vi = pp[l/2];
1243
1406
 
1244
1407
  const int8_t vi0 = vi & 0xf;
@@ -1247,16 +1410,44 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
1247
1410
  const float v0 = vi0*d + m;
1248
1411
  const float v1 = vi1*d + m;
1249
1412
 
1250
- y[i*QK + l + 0] = v0;
1251
- y[i*QK + l + 1] = v1;
1413
+ y[i*QK4_1 + l + 0] = v0;
1414
+ y[i*QK4_1 + l + 1] = v1;
1252
1415
 
1253
- assert(!isnan(y[i*QK + l + 0]));
1254
- assert(!isnan(y[i*QK + l + 1]));
1416
+ assert(!isnan(y[i*QK4_1 + l + 0]));
1417
+ assert(!isnan(y[i*QK4_1 + l + 1]));
1255
1418
  }
1256
1419
  }
1257
1420
  #endif
1258
1421
  }
1259
1422
 
1423
+ static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
1424
+ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
1425
+
1426
+ static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
1427
+ [GGML_TYPE_Q4_0] = {
1428
+ .dequantize_row_q = dequantize_row_q4_0,
1429
+ .quantize_row_q = quantize_row_q4_0,
1430
+ .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_0_reference,
1431
+ .quantize_row_q_dot = quantize_row_q8_0,
1432
+ .vec_dot_q = ggml_vec_dot_q4_0_q8_0,
1433
+ },
1434
+ [GGML_TYPE_Q4_1] = {
1435
+ .dequantize_row_q = dequantize_row_q4_1,
1436
+ .quantize_row_q = quantize_row_q4_1,
1437
+ .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_1_reference,
1438
+ .quantize_row_q_dot = quantize_row_q4_1,
1439
+ .vec_dot_q = ggml_vec_dot_q4_1,
1440
+ },
1441
+ // TODO: GGML_TYPE_Q8_0
1442
+ };
1443
+
1444
+ // For internal test use
1445
+ quantize_fns_t ggml_internal_get_quantize_fn(size_t i) {
1446
+ GGML_ASSERT(i < GGML_TYPE_COUNT);
1447
+ return quantize_fns[i];
1448
+ }
1449
+
1450
+
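This table is the new dispatch point that lets the matrix-multiplication code stay type-agnostic: for each quantized type it records the row dequantizer, the row quantizer (plus a deterministic reference variant for model-file creation), the quantizer for the other operand of a dot product (q8_0 when the weights are q4_0), and the matching vec_dot kernel. Roughly, a caller is expected to use it along these lines — this is a non-compilable fragment, and src1_row_f32, q8_row_buf, q4_weight_row and ne00 are placeholder names rather than symbols from this file:

    const quantize_fns_t fns = ggml_internal_get_quantize_fn(GGML_TYPE_Q4_0);

    /* quantize one f32 activation row into the companion dot format (q8_0 here) ... */
    fns.quantize_row_q_dot(src1_row_f32, q8_row_buf, ne00);

    /* ... then take the mixed q4_0 x q8_0 dot product against a quantized weight row */
    float cell;
    fns.vec_dot_q(ne00, &cell, q4_weight_row, q8_row_buf);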
1260
1451
  //
1261
1452
  // simd mappings
1262
1453
  //
@@ -1813,34 +2004,188 @@ inline static void ggml_vec_dot_f32(const int n, float * restrict s, const float
1813
2004
  *s = sumf;
1814
2005
  }
1815
2006
 
1816
- #if __AVX512F__ && QK == 32
1817
- static inline __m512 dot_q4_0_oneblock_avx512(
2007
+ #if __AVX512F__ && QK4_0 == 32
2008
+ static inline __m512i bytes_from_q4_0_twoblocks_avx512( const __m512i blocks ) {
2009
+ // The 64 bytes of `blocks` contain two consecutive Q4_0 blocks loaded from memory:
2010
+ // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
2011
+ // |63 62 61 60 59 58 57 56 55 54 53 52 51 50 49 48 47 46 45 44 43 42 41 40 39 38 37 36 35 34 33 32|
2012
+ // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
2013
+ // | :. =_ () [] <> () Zz Yy|
2014
+ // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
2015
+ // |31 30 29 28 27 26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10 09 08 07 06 05 04 03 02 01 00|
2016
+ // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
2017
+ // |Xx Ww Vv Uu Tt Ss Rr Qq Pp Oo Nn Mm Ll Kk Jj Ii Hh Gg Ff Ee Dd Cc Bb Aa |
2018
+ // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
2019
+ //
2020
+ // Bytes 04..19 (block #0) and 24..39 (block #1) both contain 32 nibbles (4-bit unsigned integers).
2021
+ // We have exactly 64 nibbles, so we want to place each nibble into a separate byte.
2022
+ // Bytes 00..03 and 20..23 contain scales, which are irrelevant to this function.
2023
+ // Bytes 40..63 are masked when loading the data, so they are zeroed out.
2024
+ #ifdef __AVX512VBMI__
2025
+ const __m512i byte_perm = _mm512_set_epi8(
2026
+ 39, 38, 39, 38, 37, 36, 37, 36, 35, 34, 35, 34, 33, 32, 33, 32,
2027
+ 31, 30, 31, 30, 29, 28, 29, 28, 27, 26, 27, 26, 25, 24, 25, 24,
2028
+ 19, 18, 19, 18, 17, 16, 17, 16, 15, 14, 15, 14, 13, 12, 13, 12,
2029
+ 11, 10, 11, 10, 9, 8, 9, 8, 7, 6, 7, 6, 5, 4, 5, 4
2030
+ );
2031
+ const __m512i permuted = _mm512_permutexvar_epi8( byte_perm, blocks );
2032
+ // After applying VPERMB, `permuted` looks like this:
2033
+ // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
2034
+ // |63 62 61 60 59 58 57 56 55 54 53 52 51 50 49 48 47 46 45 44 43 42 41 40 39 38 37 36 35 34 33 32|
2035
+ // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
2036
+ // |:. =_ :. =_ () [] () [] <> () <> () Zz Yy Zz Yy Xx Ww Xx Ww Vv Uu Vv Uu Tt Ss Tt Ss Rr Qq Rr Qq|
2037
+ // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
2038
+ // |31 30 29 28 27 26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10 09 08 07 06 05 04 03 02 01 00|
2039
+ // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
2040
+ // |Pp Oo Pp Oo Nn Mm Nn Mm Ll Kk Ll Kk Jj Ii Jj Ii Hh Gg Hh Gg Ff Ee Ff Ee Dd Cc Dd Cc Bb Aa Bb Aa|
2041
+ // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
2042
+ #else
2043
+ const __m512i word_perm = _mm512_set_epi16(
2044
+ 19, 19, 18, 18, 17, 17, 16, 16, 15, 15, 14, 14, 13, 13, 12, 12,
2045
+ 9, 9, 8, 8, 7, 7, 6, 6, 5, 5, 4, 4, 3, 3, 2, 2
2046
+ );
2047
+ const __m512i permuted = _mm512_permutexvar_epi16( word_perm, blocks );
2048
+ // This is the fallback path for CPUs that don't support VPERMB. Since we permute 16-bit groups only,
2049
+ // VPERMB can be replaced with VPERMW. We could always use VPERMW, but at least on Tiger Lake and
2050
+ // Ice Lake VPERMW followed by a right shift is quite noticeably slower than VPERMB.
2051
+ #endif
2052
+
2053
+ // Shift every odd-numbered 16-bit group to the right by 4 bits.
2054
+ const __mmask32 shift_mask = 0xaaaaaaaa;
2055
+ const __m512i shifted = _mm512_mask_srai_epi16( permuted, shift_mask, permuted, 4 );
2056
+ // After applying VPSRAW, `shifted` looks like this (the "empty" nibbles are filled with zeroes):
2057
+ // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
2058
+ // |63 62 61 60 59 58 57 56 55 54 53 52 51 50 49 48 47 46 45 44 43 42 41 40 39 38 37 36 35 34 33 32
2059
+ // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
2060
+ // | : .= :. =_ ( )[ () [] < >( <> () Z zY Zz Yy X xW Xx Ww V vU Vv Uu T tS Tt Ss R rQ Rr Qq
2061
+ // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
2062
+ // |31 30 29 28 27 26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10 09 08 07 06 05 04 03 02 01 00|
2063
+ // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
2064
+ // | P pO Pp Oo N nM Nn Mm L lK Ll Kk J jI Jj Ii H hG Hh Gg F fE Ff Ee D dC Dd Cc B bA Bb Aa|
2065
+ // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
2066
+
2067
+ // Now we just need to zero out the higher nibble in each byte, and we're done.
2068
+ const __m512i low_nibble_mask = _mm512_set1_epi8( 0xf );
2069
+ return _mm512_and_si512( low_nibble_mask, shifted );
2070
+ // The final result looks like this:
2071
+ // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
2072
+ // |63 62 61 60 59 58 57 56 55 54 53 52 51 50 49 48 47 46 45 44 43 42 41 40 39 38 37 36 35 34 33 32|
2073
+ // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
2074
+ // | : = . _ ( [ ) ] < ( > ) Z Y z y X W x w V U v u T S t s R Q r q|
2075
+ // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
2076
+ // |31 30 29 28 27 26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10 09 08 07 06 05 04 03 02 01 00|
2077
+ // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
2078
+ // | P O p o N M n m L K l k J I j i H G h g F E f e D C d c B A b a|
2079
+ // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
2080
+ }
2081
+
2082
+ static inline __m512 dot_q4_0_twoblocks_avx512(
1818
2083
  __m512 acc,
1819
2084
  const block_q4_0 * restrict x,
1820
2085
  const block_q4_0 * restrict y,
1821
2086
  int i
1822
2087
  ) {
1823
- // Compute combined scale for the block
1824
- __m512 d = _mm512_set1_ps( x[i].d * y[i].d );
1825
-
1826
- __m256i bx = bytesFromNibbles( x[i].qs );
1827
- __m256i by = bytesFromNibbles( y[i].qs );
1828
-
1829
- // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
1830
- const __m256i off = _mm256_set1_epi8( 8 );
1831
- bx = _mm256_sub_epi8( bx, off );
1832
- by = _mm256_sub_epi8( by, off );
1833
-
1834
- // Sign-extend 16 signed bytes into int16_t
1835
- __m512i x32 = _mm512_cvtepi8_epi16( bx );
1836
- __m512i y32 = _mm512_cvtepi8_epi16( by );
1837
- // Compute products of int16_t integers, add pairwise
1838
- __m512i i64 = _mm512_madd_epi16( x32, y32 );
2088
+ // A pair of Q4_0 blocks spans 40 bytes, while an AVX-512 register has 64. The remaining 24 bytes
2089
+ // can potentially be unaddressable, so we make sure to mask them out before the load, even though
2090
+ // we don't use them at all. This might hurt the performance slightly, since the compiler is forced
2091
+ // to use e.g. `VMOVDQU64 REG, MASK, [ADDR] + VPERMB ..., REG` instead of just `VPERMB ..., [ADDR]`.
2092
+ const __mmask8 load_mask = 0x1f;
2093
+ const __m512i blocks_0 = _mm512_maskz_loadu_epi64( load_mask, &x[i] );
2094
+ const __m512i blocks_1 = _mm512_maskz_loadu_epi64( load_mask, &y[i] );
2095
+
2096
+ // We want to multiply the scales, so we interpret both registers as 16 32-bit floats:
2097
+ // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
2098
+ // | 15 | 14 | 13 | 12 | 11 | 10 | 09 | 08 | 07 | 06 | 05 | 04 | 03 | 02 | 01 | 00 |
2099
+ // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
2100
+ // blocks_0_float
2101
+ // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
2102
+ // | | | | | | | xx | xx | xx | xx | B | xx | xx | xx | xx | A |
2103
+ // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
2104
+ // blocks_1_float
2105
+ // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
2106
+ // | | | | | | | xx | xx | xx | xx | D | xx | xx | xx | xx | C |
2107
+ // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
2108
+ const __m512 blocks_0_float = _mm512_castsi512_ps( blocks_0 );
2109
+ const __m512 blocks_1_float = _mm512_castsi512_ps( blocks_1 );
2110
+ // We absolutely shouldn't touch the floats marked with `xx`: they contain some
2111
+ // random data, which might very well underflow. At least on Intel, this leads
2112
+ // to a huge penalty that can't be ignored (easily 100x or more) unless you
2113
+ // compile your code with something like `-ffast-math` to enable FTZ/DAZ flags.
2114
+ // (and ggml can't assume that you do)...
2115
+ const __mmask16 scale_mul_mask = 0x21;
2116
+ #ifdef __clang__
2117
+ // ...however, clang decides to optimize the multiplication mask away:
2118
+ // https://godbolt.org/z/P8PqdsfvW
2119
+ // gcc and MSVC do the sane thing. This horrible workaround forces clang to emit the mask.
2120
+ __m512i scales;
2121
+ __asm__(
2122
+ "vmulps %1, %2, %0%{%3%}"
2123
+ : "=v" ( scales )
2124
+ : "vm" ( blocks_0_float ), "v" ( blocks_1_float ), "Yk" ( scale_mul_mask )
2125
+ );
2126
+ #else
2127
+ const __m512 scales = _mm512_maskz_mul_ps( scale_mul_mask, blocks_0_float, blocks_1_float );
2128
+ #endif
2129
+ const __m512i scale_perm = _mm512_set_epi32(
2130
+ 5, 5, 5, 5, 5, 5, 5, 5,
2131
+ 0, 0, 0, 0, 0, 0, 0, 0
2132
+ );
2133
+ const __m512 permuted_scales = _mm512_permutexvar_ps( scale_perm, scales );
2134
+ // After VMULPS and VPERMPS, `permuted_scales` looks like this:
2135
+ // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
2136
+ // | 15 | 14 | 13 | 12 | 11 | 10 | 09 | 08 | 07 | 06 | 05 | 04 | 03 | 02 | 01 | 00 |
2137
+ // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
2138
+ // | B*D| B*D| B*D| B*D| B*D| B*D| B*D| B*D| A*C| A*C| A*C| A*C| A*C| A*C| A*C| A*C|
2139
+ // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
2140
+
2141
+ const __m512i bytes_0 = bytes_from_q4_0_twoblocks_avx512( blocks_0 );
2142
+ const __m512i bytes_1 = bytes_from_q4_0_twoblocks_avx512( blocks_1 );
2143
+
2144
+ // Now we want to compute dot products of 4-element byte vectors and store them in
2145
+ // 32-bit integers. That is (only one 4-element vector is shown for clarity):
2146
+ // +----+----+----+----+
2147
+ // ... | 03 | 02 | 01 | 00 |
2148
+ // +----+----+----+----+
2149
+ // bytes_0
2150
+ // +----+----+----+----+
2151
+ // ... | D | C | B | A |
2152
+ // +----+----+----+----+
2153
+ // bytes_1
2154
+ // +----+----+----+----+
2155
+ // ... | H | G | F | E |
2156
+ // +----+----+----+----+
2157
+ // final_res_int
2158
+ // +----+----+----+----+
2159
+ // ... | A*E+B*F+C*G+D*H |
2160
+ // +----+----+----+----+
2161
+ const __m512i plus_8 = _mm512_set1_epi8( 8 );
2162
+ const __m512i bytes_1_minus_8 = _mm512_sub_epi8( bytes_1, plus_8 );
2163
+
2164
+ #ifdef __AVX512VNNI__
2165
+ // We have VPDPBUSDS in AVX512-VNNI, which does exactly what we want, but with a catch:
2166
+ // the *left* operand is supposed to be unsigned, while Q4_0 quantization subtracts 8
2167
+ // from each nibble, so they can be negative. So, instead of `(bytes_0 - 8) * (bytes_1 - 8)`,
2168
+ // we compute `bytes_0 * (bytes_1 - 8) + bytes_1 * (-8) + 64`. VPDPBUSDS uses an accumulator,
2169
+ // which means we only need 2 instructions.
2170
+ const __m512i dot_init = _mm512_set1_epi32( 4 * 64 );
2171
+ const __m512i minus_8 = _mm512_set1_epi8( -8 );
2172
+ const __m512i prod_0 = _mm512_dpbusds_epi32( dot_init, bytes_1, minus_8 );
2173
+ const __m512i final_res_int = _mm512_dpbusds_epi32( prod_0, bytes_0, bytes_1_minus_8 );
2174
+ #else
2175
+ // As a fallback, we have VPMADDUBSW in AVX512-BW, which uses 16-bit products instead of 32-bit ones.
2176
+ // It has the same catch as VPDPBUSDS: the left operand should be unsigned.
2177
+ // This is essentially the AVX-512 version of the AVX-2 trick used by GH user Const-me
2178
+ // ref: https://gist.github.com/Const-me/4d30e1fc767ab314596e16e90f53b6f4#file-matmultest-cpp-L119
2179
+ const __m512i one = _mm512_set1_epi16( 1 );
2180
+ const __m512i prod_0 = _mm512_maddubs_epi16( bytes_0, bytes_1_minus_8 );
2181
+ const __m512i prod_1 = _mm512_maddubs_epi16( plus_8, bytes_1_minus_8 );
2182
+ const __m512i diff = _mm512_sub_epi16( prod_0, prod_1 );
2183
+ const __m512i final_res_int = _mm512_madd_epi16( diff, one );
2184
+ #endif
1839
2185
 
1840
- // Convert int32_t to float
1841
- __m512 p = _mm512_cvtepi32_ps( i64 );
1842
- // Apply the scale, and accumulate
1843
- return _mm512_fmadd_ps( d, p, acc );
2186
+ // Finally, we multiply the permuted scales and the 32-bit dot products, then accumulate.
2187
+ const __m512 final_res_float = _mm512_cvtepi32_ps( final_res_int );
2188
+ return _mm512_fmadd_ps( permuted_scales, final_res_float, acc );
1844
2189
  }
1845
2190
  #endif
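The arithmetic behind the VPDPBUSDS variant above deserves spelling out. Each stored nibble a, b lies in [0, 15] and represents the value a - 8 or b - 8. Because VPDPBUSDS treats its first source as unsigned, the signed product is rewritten as

    (a - 8)(b - 8) = a(b - 8) - 8(b - 8) = a(b - 8) - 8b + 64

so the kernel multiplies the raw (unsigned) nibbles a against (b - 8), adds b * (-8) with a second VPDPBUSDS, and folds the constant into the accumulator seed; each 32-bit lane accumulates four byte products, which is exactly why dot_init is 4 * 64.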
1846
2191
 
@@ -1881,9 +2226,9 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t
1881
2226
  }
1882
2227
 
1883
2228
  static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
1884
- const int nb = n / QK;
2229
+ const int nb = n / QK4_0;
1885
2230
 
1886
- assert(n % QK == 0);
2231
+ assert(n % QK4_0 == 0);
1887
2232
  assert(nb % 2 == 0);
1888
2233
 
1889
2234
  const block_q4_0 * restrict x = vx;
@@ -1972,25 +2317,26 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
1972
2317
  __m512 acc0 = _mm512_setzero_ps();
1973
2318
  __m512 acc1 = _mm512_setzero_ps();
1974
2319
 
1975
- const int superblock_size = 8;
2320
+ const int superblock_size = 16;
2321
+
1976
2322
  const int superblock_count = nb / superblock_size;
1977
2323
 
1978
2324
  for (int superblock_ix = 0; superblock_ix < superblock_count; superblock_ix += 1) {
1979
2325
  int i = superblock_ix * superblock_size;
1980
2326
 
1981
- acc0 = dot_q4_0_oneblock_avx512( acc0, x, y, i+0 );
1982
- acc1 = dot_q4_0_oneblock_avx512( acc1, x, y, i+1 );
1983
- acc0 = dot_q4_0_oneblock_avx512( acc0, x, y, i+2 );
1984
- acc1 = dot_q4_0_oneblock_avx512( acc1, x, y, i+3 );
1985
- acc0 = dot_q4_0_oneblock_avx512( acc0, x, y, i+4 );
1986
- acc1 = dot_q4_0_oneblock_avx512( acc1, x, y, i+5 );
1987
- acc0 = dot_q4_0_oneblock_avx512( acc0, x, y, i+6 );
1988
- acc1 = dot_q4_0_oneblock_avx512( acc1, x, y, i+7 );
2327
+ acc0 = dot_q4_0_twoblocks_avx512( acc0, x, y, i+0 );
2328
+ acc1 = dot_q4_0_twoblocks_avx512( acc1, x, y, i+2 );
2329
+ acc0 = dot_q4_0_twoblocks_avx512( acc0, x, y, i+4 );
2330
+ acc1 = dot_q4_0_twoblocks_avx512( acc1, x, y, i+6 );
2331
+ acc0 = dot_q4_0_twoblocks_avx512( acc0, x, y, i+8 );
2332
+ acc1 = dot_q4_0_twoblocks_avx512( acc1, x, y, i+10 );
2333
+ acc0 = dot_q4_0_twoblocks_avx512( acc0, x, y, i+12 );
2334
+ acc1 = dot_q4_0_twoblocks_avx512( acc1, x, y, i+14 );
1989
2335
  }
1990
2336
 
1991
2337
  // Remainders
1992
- for (int i = superblock_count * superblock_size; i < nb; ++i) {
1993
- acc0 = dot_q4_0_oneblock_avx512( acc0, x, y, i );
2338
+ for (int i = superblock_count * superblock_size; i < nb; i += 2) {
2339
+ acc0 = dot_q4_0_twoblocks_avx512( acc0, x, y, i );
1994
2340
  }
1995
2341
 
1996
2342
  // Horizontal sum of all lanes of the accumulator
@@ -2206,15 +2552,15 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
2206
2552
  const uint8_t * restrict p1 = y[i].qs;
2207
2553
 
2208
2554
  int sumi = 0;
2209
- for (int j = 0; j < QK/2; j++) {
2555
+ for (int j = 0; j < QK4_0/2; j++) {
2210
2556
  const uint8_t v0 = p0[j];
2211
2557
  const uint8_t v1 = p1[j];
2212
2558
 
2213
- const int8_t i0 = (int8_t) (v0 & 0xf) - 8;
2214
- const int8_t i1 = (int8_t) (v0 >> 4) - 8;
2559
+ const int i0 = (v0 & 0xf) - 8;
2560
+ const int i1 = (v0 >> 4) - 8;
2215
2561
 
2216
- const int8_t i2 = (int8_t) (v1 & 0xf) - 8;
2217
- const int8_t i3 = (int8_t) (v1 >> 4) - 8;
2562
+ const int i2 = (v1 & 0xf) - 8;
2563
+ const int i3 = (v1 >> 4) - 8;
2218
2564
 
2219
2565
  sumi += i0*i2 + i1*i3;
2220
2566
  }
@@ -2226,7 +2572,7 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
2226
2572
  }
2227
2573
 
2228
2574
  static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
2229
- const int nb = n / QK;
2575
+ const int nb = n / QK4_1;
2230
2576
 
2231
2577
  const block_q4_1 * restrict x = vx;
2232
2578
  const block_q4_1 * restrict y = vy;
@@ -2303,7 +2649,7 @@ static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * rest
2303
2649
  res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
2304
2650
  res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
2305
2651
 
2306
- sumf = _mm_cvtss_f32( res ) + acc_offset * QK;
2652
+ sumf = _mm_cvtss_f32( res ) + acc_offset * QK4_1;
2307
2653
  #elif defined(__ARM_NEON)
2308
2654
  float sum00 = 0.0f;
2309
2655
  float sum01 = 0.0f;
@@ -2335,12 +2681,12 @@ static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * rest
2335
2681
  const uint8x16_t v1_1h = vshrq_n_u8(v1_1, 4);
2336
2682
 
2337
2683
  sum00 += x0->m*y0->m;
2338
- sum01 += y0->m*x0->d*(vaddvq_u8(v0_0l) + vaddvq_u8(v0_0h));
2339
- sum10 += x0->m*y0->d*(vaddvq_u8(v1_0l) + vaddvq_u8(v1_0h));
2684
+ sum01 += y0->m*x0->d*((uint16_t)vaddvq_u8(v0_0l) + (uint16_t)vaddvq_u8(v0_0h));
2685
+ sum10 += x0->m*y0->d*((uint16_t)vaddvq_u8(v1_0l) + (uint16_t)vaddvq_u8(v1_0h));
2340
2686
 
2341
2687
  sum00 += x1->m*y1->m;
2342
- sum01 += y1->m*x1->d*(vaddvq_u8(v0_1l) + vaddvq_u8(v0_1h));
2343
- sum10 += x1->m*y1->d*(vaddvq_u8(v1_1l) + vaddvq_u8(v1_1h));
2688
+ sum01 += y1->m*x1->d*((uint16_t)vaddvq_u8(v0_1l) + (uint16_t)vaddvq_u8(v0_1h));
2689
+ sum10 += x1->m*y1->d*((uint16_t)vaddvq_u8(v1_1l) + (uint16_t)vaddvq_u8(v1_1h));
2344
2690
 
2345
2691
  #if defined(__ARM_FEATURE_DOTPROD)
2346
2692
  // dot product into int32x4_t
@@ -2377,7 +2723,7 @@ static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * rest
2377
2723
  #endif
2378
2724
  }
2379
2725
 
2380
- sumf = QK*sum00 + sum01 + sum10 + sum11;
2726
+ sumf = QK4_1*sum00 + sum01 + sum10 + sum11;
2381
2727
  #else
2382
2728
  // scalar
2383
2729
  for (int i = 0; i < nb; i++) {
@@ -2390,7 +2736,7 @@ static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * rest
2390
2736
  const uint8_t * restrict p0 = x[i].qs;
2391
2737
  const uint8_t * restrict p1 = y[i].qs;
2392
2738
 
2393
- for (int j = 0; j < QK/2; j++) {
2739
+ for (int j = 0; j < QK4_1/2; j++) {
2394
2740
  const uint8_t v0 = p0[j];
2395
2741
  const uint8_t v1 = p1[j];
2396
2742
 
@@ -2408,78 +2754,281 @@ static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * rest
2408
2754
  *s = sumf;
2409
2755
  }
2410
2756
 
2411
- // compute GGML_VEC_DOT_UNROLL dot products at once
2412
- // xs - x row stride in bytes
2413
- inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * restrict s, void * restrict xv, ggml_fp16_t * restrict y) {
2414
- ggml_float sumf[GGML_VEC_DOT_UNROLL] = { 0.0 };
2415
-
2416
- ggml_fp16_t * restrict x[GGML_VEC_DOT_UNROLL];
2757
+ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
2758
+ const int nb = n / QK8_0;
2417
2759
 
2418
- for (int i = 0; i < GGML_VEC_DOT_UNROLL; ++i) {
2419
- x[i] = (ggml_fp16_t *) ((char *) xv + i*xs);
2420
- }
2760
+ assert(n % QK8_0 == 0);
2761
+ assert(nb % 2 == 0);
2421
2762
 
2422
- #if defined(GGML_SIMD)
2423
- const int np = (n & ~(GGML_F16_STEP - 1));
2763
+ const block_q4_0 * restrict x = vx;
2764
+ const block_q8_0 * restrict y = vy;
2424
2765
 
2425
- GGML_F16_VEC sum[GGML_VEC_DOT_UNROLL][GGML_F16_ARR] = { { GGML_F16_VEC_ZERO } };
2766
+ float sumf = 0.0;
2426
2767
 
2427
- GGML_F16_VEC ax[GGML_F16_ARR];
2428
- GGML_F16_VEC ay[GGML_F16_ARR];
2768
+ #if defined(__ARM_NEON)
2769
+ float sum0 = 0.0f;
2770
+ float sum1 = 0.0f;
2429
2771
 
2430
- for (int i = 0; i < np; i += GGML_F16_STEP) {
2431
- for (int j = 0; j < GGML_F16_ARR; j++) {
2432
- ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
2772
+ for (int i = 0; i < nb; i += 2) {
2773
+ const block_q4_0 * restrict x0 = &x[i + 0];
2774
+ const block_q4_0 * restrict x1 = &x[i + 1];
2775
+ const block_q8_0 * restrict y0 = &y[i + 0];
2776
+ const block_q8_0 * restrict y1 = &y[i + 1];
2433
2777
 
2434
- for (int k = 0; k < GGML_VEC_DOT_UNROLL; ++k) {
2435
- ax[j] = GGML_F16_VEC_LOAD(x[k] + i + j*GGML_F16_EPR, j);
2778
+ const uint8x16_t m4b = vdupq_n_u8(0xf);
2779
+ const int8x16_t s8b = vdupq_n_s8(0x8);
2436
2780
 
2437
- sum[k][j] = GGML_F16_VEC_FMA(sum[k][j], ax[j], ay[j]);
2438
- }
2439
- }
2440
- }
2781
+ const uint8x16_t v0_0 = vld1q_u8(x0->qs);
2782
+ const uint8x16_t v0_1 = vld1q_u8(x1->qs);
2441
2783
 
2442
- // reduce sum0..sum3 to sum0
2443
- for (int k = 0; k < GGML_VEC_DOT_UNROLL; ++k) {
2444
- GGML_F16_VEC_REDUCE(sumf[k], sum[k]);
2445
- }
2784
+ // 4-bit -> 8-bit
2785
+ const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
2786
+ const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
2787
+ const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
2788
+ const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
2446
2789
 
2447
- // leftovers
2448
- for (int i = np; i < n; ++i) {
2449
- for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) {
2450
- sumf[j] += (ggml_float)(GGML_FP16_TO_FP32(x[j][i])*GGML_FP16_TO_FP32(y[i]));
2451
- }
2452
- }
2453
- #else
2454
- for (int i = 0; i < n; ++i) {
2455
- for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) {
2456
- sumf[j] += (ggml_float)(GGML_FP16_TO_FP32(x[j][i])*GGML_FP16_TO_FP32(y[i]));
2457
- }
2458
- }
2459
- #endif
2790
+ // sub 8
2791
+ const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
2792
+ const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
2793
+ const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
2794
+ const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
2460
2795
 
2461
- for (int i = 0; i < GGML_VEC_DOT_UNROLL; ++i) {
2462
- s[i] = sumf[i];
2463
- }
2464
- }
2796
+ // load y
2797
+ const int8x16_t v1_0l = vld1q_s8(y0->qs);
2798
+ const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
2799
+ const int8x16_t v1_1l = vld1q_s8(y1->qs);
2800
+ const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
2465
2801
 
2466
- inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float * restrict x, const float v) {
2467
- #if defined(GGML_SIMD)
2468
- const int np = (n & ~(GGML_F32_STEP - 1));
2802
+ // interleave
2803
+ const int8x16_t v1_0ls = vuzp1q_s8(v1_0l, v1_0h);
2804
+ const int8x16_t v1_0hs = vuzp2q_s8(v1_0l, v1_0h);
2805
+ const int8x16_t v1_1ls = vuzp1q_s8(v1_1l, v1_1h);
2806
+ const int8x16_t v1_1hs = vuzp2q_s8(v1_1l, v1_1h);
2469
2807
 
2470
- GGML_F32_VEC vx = GGML_F32_VEC_SET1(v);
2808
+ #if defined(__ARM_FEATURE_DOTPROD)
2809
+ // dot product into int32x4_t
2810
+ int32x4_t p_0 = vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls);
2811
+ int32x4_t p_1 = vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls);
2471
2812
 
2472
- GGML_F32_VEC ax[GGML_F32_ARR];
2473
- GGML_F32_VEC ay[GGML_F32_ARR];
2813
+ p_0 = vdotq_s32(p_0, v0_0hs, v1_0hs);
2814
+ p_1 = vdotq_s32(p_1, v0_1hs, v1_1hs);
2474
2815
 
2475
- for (int i = 0; i < np; i += GGML_F32_STEP) {
2476
- for (int j = 0; j < GGML_F32_ARR; j++) {
2477
- ax[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR);
2478
- ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR);
2479
- ay[j] = GGML_F32_VEC_FMA(ay[j], ax[j], vx);
2816
+ sum0 += x0->d*y0->d*vaddvq_s32(p_0);
2817
+ sum1 += x1->d*y1->d*vaddvq_s32(p_1);
2818
+ #else
2819
+ const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls));
2820
+ const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls));
2821
+ const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0hs));
2822
+ const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0hs));
2480
2823
 
2481
- GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]);
2482
- }
2824
+ const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1ls));
2825
+ const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1ls));
2826
+ const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1hs));
2827
+ const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1hs));
2828
+
2829
+ const int16x8_t pl_0 = vaddq_s16(pl0l, pl0h);
2830
+ const int16x8_t ph_0 = vaddq_s16(ph0l, ph0h);
2831
+
2832
+ const int16x8_t pl_1 = vaddq_s16(pl1l, pl1h);
2833
+ const int16x8_t ph_1 = vaddq_s16(ph1l, ph1h);
2834
+
2835
+ const int16x8_t p_0 = vaddq_s16(pl_0, ph_0);
2836
+ const int16x8_t p_1 = vaddq_s16(pl_1, ph_1);
2837
+
2838
+ sum0 += x0->d*y0->d*vaddvq_s16(p_0);
2839
+ sum1 += x1->d*y1->d*vaddvq_s16(p_1);
2840
+ #endif
2841
+ }
2842
+
2843
+ sumf = sum0 + sum1;
2844
+ #elif defined(__AVX2__)
2845
+ // Initialize accumulator with zeros
2846
+ __m256 acc = _mm256_setzero_ps();
2847
+
2848
+ // Main loop
2849
+ for (int i = 0; i < nb; ++i) {
2850
+ /* Compute combined scale for the block */
2851
+ const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );
2852
+
2853
+ __m256i bx = bytesFromNibbles(x[i].qs);
2854
+
2855
+ // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
2856
+ const __m256i off = _mm256_set1_epi8( 8 );
2857
+ bx = _mm256_sub_epi8( bx, off );
2858
+
2859
+ __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
2860
+
2861
+ // Get absolute values of x vectors
2862
+ const __m256i ax = _mm256_sign_epi8(bx, bx);
2863
+
2864
+ // Sign the values of the y vectors
2865
+ const __m256i sy = _mm256_sign_epi8(by, bx);
2866
+
2867
+ // Perform multiplication and create 16-bit values
2868
+ const __m256i dot = _mm256_maddubs_epi16(ax, sy);
2869
+
2870
+ const __m256i ones = _mm256_set1_epi16(1);
2871
+ __m256i xy_q = _mm256_madd_epi16(ones, dot);
2872
+
2873
+ /* Convert the vector of 8 int32_t to 8 floats */
2874
+ __m256 q = _mm256_cvtepi32_ps( xy_q );
2875
+
2876
+ /* Multiply q with scale and accumulate */
2877
+ acc = _mm256_fmadd_ps( d, q, acc );
2878
+ }
2879
+
2880
+ // Return horizontal sum of the acc vector
2881
+ __m128 res = _mm256_extractf128_ps( acc, 1 );
2882
+ res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
2883
+ res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
2884
+ res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
2885
+
2886
+ sumf = _mm_cvtss_f32( res );
2887
+ #elif defined(__AVX__)
2888
+ // Initialize accumulator with zeros
2889
+ __m256 acc = _mm256_setzero_ps();
2890
+
2891
+ // Main loop
2892
+ for (int i = 0; i < nb; ++i) {
2893
+ // Compute combined scale for the block
2894
+ const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );
2895
+
2896
+ __m128i i32[2];
2897
+ for (int j = 0; j < 2; ++j) {
2898
+ // Load 8 bytes, and unpack 4 bit fields into bytes, making 16 bytes
2899
+ __m128i bx = bytesFromNibbles( x[i].qs + 8*j );
2900
+ __m128i by = _mm_loadu_si128((const __m128i *)(y[i].qs + 16*j));
2901
+
2902
+ // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
2903
+ const __m128i off = _mm_set1_epi8( 8 );
2904
+ bx = _mm_sub_epi8( bx, off );
2905
+
2906
+ // Get absolute values of x vectors
2907
+ const __m128i ax = _mm_sign_epi8(bx, bx);
2908
+
2909
+ // Sign the values of the y vectors
2910
+ const __m128i sy = _mm_sign_epi8(by, bx);
2911
+
2912
+ // Perform multiplication and create 16-bit values
2913
+ const __m128i dot = _mm_maddubs_epi16(ax, sy);
2914
+
2915
+ const __m128i ones = _mm_set1_epi16(1);
2916
+ i32[j] = _mm_madd_epi16(ones, dot);
2917
+ }
2918
+
2919
+ // Convert int32_t to float
2920
+ __m256 p = _mm256_cvtepi32_ps( _mm256_set_m128i( i32[0], i32[1] ));
2921
+ // Apply the scale, and accumulate
2922
+ acc = _mm256_add_ps(_mm256_mul_ps( d, p ), acc);
2923
+ }
2924
+
2925
+ // Return horizontal sum of the acc vector
2926
+ __m128 res = _mm256_extractf128_ps( acc, 1 );
2927
+ res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
2928
+ res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
2929
+ res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
2930
+
2931
+ sumf = _mm_cvtss_f32( res );
2932
+ #else
2933
+ // scalar
2934
+ for (int i = 0; i < nb; i++) {
2935
+ const float d0 = x[i].d;
2936
+ const float d1 = y[i].d;
2937
+
2938
+ const uint8_t * restrict p0 = x[i].qs;
2939
+ const int8_t * restrict p1 = y[i].qs;
2940
+
2941
+ int sumi = 0;
2942
+ for (int j = 0; j < QK8_0/2; j++) {
2943
+ const uint8_t v0 = p0[j];
2944
+
2945
+ const int i0 = (int8_t) (v0 & 0xf) - 8;
2946
+ const int i1 = (int8_t) (v0 >> 4) - 8;
2947
+
2948
+ const int i2 = p1[2*j + 0];
2949
+ const int i3 = p1[2*j + 1];
2950
+
2951
+ sumi += i0*i2 + i1*i3;
2952
+ }
2953
+ sumf += d0*d1*sumi;
2954
+ }
2955
+ #endif
2956
+
2957
+ *s = sumf;
2958
+ }
2959
+
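One step of the AVX2 branch above is worth justifying: _mm256_maddubs_epi16 multiplies unsigned bytes by signed bytes, yet both bx (values in [-8, 7]) and by (values in [-127, 127]) are signed. The two _mm256_sign_epi8 calls resolve this, since for every byte lane

    sign(bx, bx) * sign(by, bx) = |bx| * (by with bx's sign applied) = bx * by

with the bx == 0 case giving 0 on both sides. Moreover |bx| <= 8 and |by| <= 127, so each pair sum inside maddubs is at most 2*8*127 = 2032, comfortably inside int16 range, and the instruction's saturation never triggers.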
2960
+ // compute GGML_VEC_DOT_UNROLL dot products at once
2961
+ // xs - x row stride in bytes
2962
+ inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * restrict s, void * restrict xv, ggml_fp16_t * restrict y) {
2963
+ ggml_float sumf[GGML_VEC_DOT_UNROLL] = { 0.0 };
2964
+
2965
+ ggml_fp16_t * restrict x[GGML_VEC_DOT_UNROLL];
2966
+
2967
+ for (int i = 0; i < GGML_VEC_DOT_UNROLL; ++i) {
2968
+ x[i] = (ggml_fp16_t *) ((char *) xv + i*xs);
2969
+ }
2970
+
2971
+ #if defined(GGML_SIMD)
2972
+ const int np = (n & ~(GGML_F16_STEP - 1));
2973
+
2974
+ GGML_F16_VEC sum[GGML_VEC_DOT_UNROLL][GGML_F16_ARR] = { { GGML_F16_VEC_ZERO } };
2975
+
2976
+ GGML_F16_VEC ax[GGML_F16_ARR];
2977
+ GGML_F16_VEC ay[GGML_F16_ARR];
2978
+
2979
+ for (int i = 0; i < np; i += GGML_F16_STEP) {
2980
+ for (int j = 0; j < GGML_F16_ARR; j++) {
2981
+ ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
2982
+
2983
+ for (int k = 0; k < GGML_VEC_DOT_UNROLL; ++k) {
2984
+ ax[j] = GGML_F16_VEC_LOAD(x[k] + i + j*GGML_F16_EPR, j);
2985
+
2986
+ sum[k][j] = GGML_F16_VEC_FMA(sum[k][j], ax[j], ay[j]);
2987
+ }
2988
+ }
2989
+ }
2990
+
2991
+ // reduce sum0..sum3 to sum0
2992
+ for (int k = 0; k < GGML_VEC_DOT_UNROLL; ++k) {
2993
+ GGML_F16_VEC_REDUCE(sumf[k], sum[k]);
2994
+ }
2995
+
2996
+ // leftovers
2997
+ for (int i = np; i < n; ++i) {
2998
+ for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) {
2999
+ sumf[j] += (ggml_float)(GGML_FP16_TO_FP32(x[j][i])*GGML_FP16_TO_FP32(y[i]));
3000
+ }
3001
+ }
3002
+ #else
3003
+ for (int i = 0; i < n; ++i) {
3004
+ for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) {
3005
+ sumf[j] += (ggml_float)(GGML_FP16_TO_FP32(x[j][i])*GGML_FP16_TO_FP32(y[i]));
3006
+ }
3007
+ }
3008
+ #endif
3009
+
3010
+ for (int i = 0; i < GGML_VEC_DOT_UNROLL; ++i) {
3011
+ s[i] = sumf[i];
3012
+ }
3013
+ }
3014
+
3015
+ inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float * restrict x, const float v) {
3016
+ #if defined(GGML_SIMD)
3017
+ const int np = (n & ~(GGML_F32_STEP - 1));
3018
+
3019
+ GGML_F32_VEC vx = GGML_F32_VEC_SET1(v);
3020
+
3021
+ GGML_F32_VEC ax[GGML_F32_ARR];
3022
+ GGML_F32_VEC ay[GGML_F32_ARR];
3023
+
3024
+ for (int i = 0; i < np; i += GGML_F32_STEP) {
3025
+ for (int j = 0; j < GGML_F32_ARR; j++) {
3026
+ ax[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR);
3027
+ ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR);
3028
+ ay[j] = GGML_F32_VEC_FMA(ay[j], ax[j], vx);
3029
+
3030
+ GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]);
3031
+ }
2483
3032
  }
2484
3033
 
2485
3034
  // leftovers
@@ -2652,24 +3201,26 @@ inline static void ggml_vec_norm_inv_f32(const int n, float * s, const float * x
2652
3201
  static const int GGML_BLCK_SIZE[GGML_TYPE_COUNT] = {
2653
3202
  [GGML_TYPE_F32] = 1,
2654
3203
  [GGML_TYPE_F16] = 1,
2655
- [GGML_TYPE_Q4_0] = QK,
2656
- [GGML_TYPE_Q4_1] = QK,
3204
+ [GGML_TYPE_Q4_0] = QK4_0,
3205
+ [GGML_TYPE_Q4_1] = QK4_1,
3206
+ [GGML_TYPE_Q8_0] = QK8_0,
2657
3207
  [GGML_TYPE_I8] = 1,
2658
3208
  [GGML_TYPE_I16] = 1,
2659
3209
  [GGML_TYPE_I32] = 1,
2660
3210
  };
2661
- static_assert(GGML_TYPE_COUNT == 7, "GGML_BLCK_SIZE is outdated");
3211
+ static_assert(GGML_TYPE_COUNT == 8, "GGML_BLCK_SIZE is outdated");
2662
3212
 
2663
3213
  static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = {
2664
3214
  [GGML_TYPE_F32] = sizeof(float),
2665
3215
  [GGML_TYPE_F16] = sizeof(ggml_fp16_t),
2666
3216
  [GGML_TYPE_Q4_0] = sizeof(block_q4_0),
2667
3217
  [GGML_TYPE_Q4_1] = sizeof(block_q4_1),
3218
+ [GGML_TYPE_Q8_0] = sizeof(block_q8_0),
2668
3219
  [GGML_TYPE_I8] = sizeof(int8_t),
2669
3220
  [GGML_TYPE_I16] = sizeof(int16_t),
2670
3221
  [GGML_TYPE_I32] = sizeof(int32_t),
2671
3222
  };
2672
- static_assert(GGML_TYPE_COUNT == 7, "GGML_TYPE_SIZE is outdated");
3223
+ static_assert(GGML_TYPE_COUNT == 8, "GGML_TYPE_SIZE is outdated");
2673
3224
 
2674
3225
 
2675
3226
  static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = {
@@ -2677,11 +3228,12 @@ static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = {
2677
3228
  [GGML_TYPE_F16] = "f16",
2678
3229
  [GGML_TYPE_Q4_0] = "q4_0",
2679
3230
  [GGML_TYPE_Q4_1] = "q4_1",
3231
+ [GGML_TYPE_Q8_0] = "q8_0",
2680
3232
  [GGML_TYPE_I8] = "i8",
2681
3233
  [GGML_TYPE_I16] = "i16",
2682
3234
  [GGML_TYPE_I32] = "i32",
2683
3235
  };
2684
- static_assert(GGML_TYPE_COUNT == 7, "GGML_TYPE_NAME is outdated");
3236
+ static_assert(GGML_TYPE_COUNT == 8, "GGML_TYPE_NAME is outdated");
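With q8_0 registered in all three tables, the rest of ggml can size and name tensors of the new type generically: a contiguous row of ne elements of type t occupies ne / GGML_BLCK_SIZE[t] * GGML_TYPE_SIZE[t] bytes, which is how the nb[1] stride is derived elsewhere in this file. For a 4096-wide row the sizes above give:

    f32 : 4096/1  *  4 = 16384 bytes
    q8_0: 4096/32 * 36 =  4608 bytes
    q4_1: 4096/32 * 24 =  3072 bytes
    q4_0: 4096/32 * 20 =  2560 bytes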
2685
3237
 
2686
3238
  static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
2687
3239
  "NONE",
@@ -3354,14 +3906,6 @@ struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value) {
3354
3906
  char * const data = tensor->data;
3355
3907
 
3356
3908
  switch (tensor->type) {
3357
- case GGML_TYPE_Q4_0:
3358
- {
3359
- GGML_ASSERT(false);
3360
- } break;
3361
- case GGML_TYPE_Q4_1:
3362
- {
3363
- GGML_ASSERT(false);
3364
- } break;
3365
3909
  case GGML_TYPE_I8:
3366
3910
  {
3367
3911
  assert(tensor->nb[0] == sizeof(int8_t));
@@ -3397,7 +3941,7 @@ struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value) {
3397
3941
  ggml_vec_set_f32(nc, (float *)(data + i*n1), value);
3398
3942
  }
3399
3943
  } break;
3400
- case GGML_TYPE_COUNT:
3944
+ default:
3401
3945
  {
3402
3946
  GGML_ASSERT(false);
3403
3947
  } break;
@@ -3414,14 +3958,6 @@ struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value) {
3414
3958
  char * const data = tensor->data;
3415
3959
 
3416
3960
  switch (tensor->type) {
3417
- case GGML_TYPE_Q4_0:
3418
- {
3419
- GGML_ASSERT(false);
3420
- } break;
3421
- case GGML_TYPE_Q4_1:
3422
- {
3423
- GGML_ASSERT(false);
3424
- } break;
3425
3961
  case GGML_TYPE_I8:
3426
3962
  {
3427
3963
  assert(tensor->nb[0] == sizeof(int8_t));
@@ -3457,7 +3993,7 @@ struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value) {
3457
3993
  ggml_vec_set_f32(nc, (float *)(data + i*n1), value);
3458
3994
  }
3459
3995
  } break;
3460
- case GGML_TYPE_COUNT:
3996
+ default:
3461
3997
  {
3462
3998
  GGML_ASSERT(false);
3463
3999
  } break;
@@ -3468,14 +4004,6 @@ struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value) {
3468
4004
 
3469
4005
  int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i) {
3470
4006
  switch (tensor->type) {
3471
- case GGML_TYPE_Q4_0:
3472
- {
3473
- GGML_ASSERT(false);
3474
- } break;
3475
- case GGML_TYPE_Q4_1:
3476
- {
3477
- GGML_ASSERT(false);
3478
- } break;
3479
4007
  case GGML_TYPE_I8:
3480
4008
  {
3481
4009
  GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
@@ -3501,7 +4029,7 @@ int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i) {
3501
4029
  GGML_ASSERT(tensor->nb[0] == sizeof(float));
3502
4030
  return ((float *)(tensor->data))[i];
3503
4031
  } break;
3504
- case GGML_TYPE_COUNT:
4032
+ default:
3505
4033
  {
3506
4034
  GGML_ASSERT(false);
3507
4035
  } break;
@@ -3512,14 +4040,6 @@ int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i) {
3512
4040
 
3513
4041
  void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value) {
3514
4042
  switch (tensor->type) {
3515
- case GGML_TYPE_Q4_0:
3516
- {
3517
- GGML_ASSERT(false);
3518
- } break;
3519
- case GGML_TYPE_Q4_1:
3520
- {
3521
- GGML_ASSERT(false);
3522
- } break;
3523
4043
  case GGML_TYPE_I8:
3524
4044
  {
3525
4045
  GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
@@ -3545,7 +4065,7 @@ void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value) {
3545
4065
  GGML_ASSERT(tensor->nb[0] == sizeof(float));
3546
4066
  ((float *)(tensor->data))[i] = value;
3547
4067
  } break;
3548
- case GGML_TYPE_COUNT:
4068
+ default:
3549
4069
  {
3550
4070
  GGML_ASSERT(false);
3551
4071
  } break;
@@ -3554,14 +4074,6 @@ void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value) {
3554
4074
 
3555
4075
  float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i) {
3556
4076
  switch (tensor->type) {
3557
- case GGML_TYPE_Q4_0:
3558
- {
3559
- GGML_ASSERT(false);
3560
- } break;
3561
- case GGML_TYPE_Q4_1:
3562
- {
3563
- GGML_ASSERT(false);
3564
- } break;
3565
4077
  case GGML_TYPE_I8:
3566
4078
  {
3567
4079
  GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
@@ -3587,7 +4099,7 @@ float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i) {
3587
4099
  GGML_ASSERT(tensor->nb[0] == sizeof(float));
3588
4100
  return ((float *)(tensor->data))[i];
3589
4101
  } break;
3590
- case GGML_TYPE_COUNT:
4102
+ default:
3591
4103
  {
3592
4104
  GGML_ASSERT(false);
3593
4105
  } break;
@@ -3598,14 +4110,6 @@ float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i) {
3598
4110
 
3599
4111
  void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value) {
3600
4112
  switch (tensor->type) {
3601
- case GGML_TYPE_Q4_0:
3602
- {
3603
- GGML_ASSERT(false);
3604
- } break;
3605
- case GGML_TYPE_Q4_1:
3606
- {
3607
- GGML_ASSERT(false);
3608
- } break;
3609
4113
  case GGML_TYPE_I8:
3610
4114
  {
3611
4115
  GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
@@ -3631,7 +4135,7 @@ void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value) {
3631
4135
  GGML_ASSERT(tensor->nb[0] == sizeof(float));
3632
4136
  ((float *)(tensor->data))[i] = value;
3633
4137
  } break;
3634
- case GGML_TYPE_COUNT:
4138
+ default:
3635
4139
  {
3636
4140
  GGML_ASSERT(false);
3637
4141
  } break;
@@ -5112,6 +5616,26 @@ static void ggml_compute_forward_dup_f16(
5112
5616
  }
5113
5617
  }
5114
5618
  }
5619
+ } else if (dst->type == GGML_TYPE_Q4_0 || dst->type == GGML_TYPE_Q4_1) {
5620
+ quantize_row_q_t const quantize_row_q = quantize_fns[dst->type].quantize_row_q;
5621
+ size_t id = 0;
5622
+ uint8_t * dst_ptr = (uint8_t *) dst->data;
5623
+ size_t dst_row_size = nb0 * (ne00 / GGML_BLCK_SIZE[dst->type]);
5624
+ float * src0_f32 = (float *) params->wdata;
5625
+
5626
+ for (int i03 = 0; i03 < ne03; i03++) {
5627
+ for (int i02 = 0; i02 < ne02; i02++) {
5628
+ for (int i01 = 0; i01 < ne01; i01++) {
5629
+ const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
5630
+ // convert to f32 and quantize
5631
+ for (int i00 = 0; i00 < ne00; i00++) {
5632
+ src0_f32[i00] = GGML_FP16_TO_FP32(src0_ptr[i00]);
5633
+ }
5634
+ quantize_row_q(src0_f32, dst_ptr + id, ne00);
5635
+ id += dst_row_size;
5636
+ }
5637
+ }
5638
+ }
5115
5639
  } else {
5116
5640
  GGML_ASSERT(false); // TODO: implement
5117
5641
  }
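ggml_compute_forward_dup_f16 can now write quantized destinations: each source row is widened to f32 into params->wdata and then handed to the per-type quantize_row_q. A self-contained sketch of that per-row step (the helper name is hypothetical; it only restates the loop above):

    static void dup_row_f16_to_q(const ggml_fp16_t * src, float * scratch, uint8_t * dst,
                                 int ne00, quantize_row_q_t quantize_row_q) {
        for (int i00 = 0; i00 < ne00; i00++) {
            scratch[i00] = GGML_FP16_TO_FP32(src[i00]);  // widen one value to f32
        }
        quantize_row_q(scratch, dst, ne00);              // pack one quantized row
    }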
@@ -5304,6 +5828,21 @@ static void ggml_compute_forward_dup_f32(
5304
5828
  }
5305
5829
  }
5306
5830
  }
5831
+ } else if (dst->type == GGML_TYPE_Q4_0 || dst->type == GGML_TYPE_Q4_1) {
5832
+ quantize_row_q_t const quantize_row_q = quantize_fns[dst->type].quantize_row_q;
5833
+ size_t id = 0;
5834
+ uint8_t * dst_ptr = (uint8_t *) dst->data;
5835
+ size_t dst_row_size = nb0 * (ne00 / GGML_BLCK_SIZE[dst->type]);
5836
+
5837
+ for (int i03 = 0; i03 < ne03; i03++) {
5838
+ for (int i02 = 0; i02 < ne02; i02++) {
5839
+ for (int i01 = 0; i01 < ne01; i01++) {
5840
+ const float * src0_ptr = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
5841
+ quantize_row_q(src0_ptr, dst_ptr + id, ne00);
5842
+ id += dst_row_size;
5843
+ }
5844
+ }
5845
+ }
5307
5846
  } else {
5308
5847
  GGML_ASSERT(false); // TODO: implement
5309
5848
  }
@@ -5426,12 +5965,7 @@ static void ggml_compute_forward_dup(
5426
5965
  {
5427
5966
  ggml_compute_forward_dup_f32(params, src0, dst);
5428
5967
  } break;
5429
- case GGML_TYPE_Q4_0:
5430
- case GGML_TYPE_Q4_1:
5431
- case GGML_TYPE_I8:
5432
- case GGML_TYPE_I16:
5433
- case GGML_TYPE_I32:
5434
- case GGML_TYPE_COUNT:
5968
+ default:
5435
5969
  {
5436
5970
  GGML_ASSERT(false);
5437
5971
  } break;
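With the per-type case list collapsed to a default, dup dispatches on F16/F32 and asserts otherwise; combined with the quantizing branches above, this is what lets ggml_cpy target a quantized tensor. A hedged usage sketch (shapes are illustrative; ctx is an initialized ggml context):

    // copy/quantize an f32 matrix into a q4_0 tensor via the new dup path
    struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32,  4096, 64);
    struct ggml_tensor * b = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, 4096, 64);
    struct ggml_tensor * c = ggml_cpy(ctx, a, b);  // runs ggml_compute_forward_dup_f32 when the graph is computed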
@@ -5463,37 +5997,243 @@ static void ggml_compute_forward_add_f32(
5463
5997
  const size_t nb10 = src1->nb[0];
5464
5998
  const size_t nb11 = src1->nb[1];
5465
5999
 
5466
- const size_t nb0 = dst->nb[0];
5467
- const size_t nb1 = dst->nb[1];
6000
+ const size_t nb0 = dst->nb[0];
6001
+ const size_t nb1 = dst->nb[1];
6002
+
6003
+ GGML_ASSERT( nb0 == sizeof(float));
6004
+ GGML_ASSERT(nb00 == sizeof(float));
6005
+
6006
+ if (nb10 == sizeof(float)) {
6007
+ for (int j = ith; j < n; j += nth) {
6008
+ #ifdef GGML_USE_ACCELERATE
6009
+ vDSP_vadd(
6010
+ (float *) ((char *) src0->data + j*nb01), 1,
6011
+ (float *) ((char *) src1->data + j*nb11), 1,
6012
+ (float *) ((char *) dst->data + j*nb1), 1, nc);
6013
+ #else
6014
+ ggml_vec_add_f32(nc,
6015
+ (float *) ((char *) dst->data + j*nb1),
6016
+ (float *) ((char *) src0->data + j*nb01),
6017
+ (float *) ((char *) src1->data + j*nb11));
6018
+ #endif
6019
+ }
6020
+ } else {
6021
+ // src1 is not contiguous
6022
+ for (int j = ith; j < n; j += nth) {
6023
+ float * dst_ptr = (float *) ((char *) dst->data + j*nb1);
6024
+ float * src0_ptr = (float *) ((char *) src0->data + j*nb01);
6025
+ for (int i = 0; i < nc; i++) {
6026
+ float * src1_ptr = (float *) ((char *) src1->data + j*nb11 + i*nb10);
6027
+
6028
+ dst_ptr[i] = src0_ptr[i] + *src1_ptr;
6029
+ }
6030
+ }
6031
+ }
6032
+ }
6033
+
6034
+ static void ggml_compute_forward_add_f16_f32(
6035
+ const struct ggml_compute_params * params,
6036
+ const struct ggml_tensor * src0,
6037
+ const struct ggml_tensor * src1,
6038
+ struct ggml_tensor * dst) {
6039
+ GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
6040
+
6041
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
6042
+ return;
6043
+ }
6044
+
6045
+ const int ith = params->ith;
6046
+ const int nth = params->nth;
6047
+
6048
+ const int n = ggml_nrows(src0);
6049
+ const int nc = src0->ne[0];
6050
+
6051
+ const size_t nb00 = src0->nb[0];
6052
+ const size_t nb01 = src0->nb[1];
6053
+
6054
+ const size_t nb10 = src1->nb[0];
6055
+ const size_t nb11 = src1->nb[1];
6056
+
6057
+ const size_t nb0 = dst->nb[0];
6058
+ const size_t nb1 = dst->nb[1];
6059
+
6060
+ GGML_ASSERT(src0->type == GGML_TYPE_F16);
6061
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
6062
+ GGML_ASSERT(dst->type == GGML_TYPE_F16);
6063
+
6064
+ GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
6065
+ GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
6066
+
6067
+ if (nb10 == sizeof(float)) {
6068
+ for (int j = ith; j < n; j += nth) {
6069
+ ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + j*nb1);
6070
+ ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + j*nb01);
6071
+ for (int i = 0; i < nc; i++) {
6072
+ float * src1_ptr = (float *) ((char *) src1->data + j*nb11 + i*nb10);
6073
+ dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + *src1_ptr);
6074
+ }
6075
+ }
6076
+ }
6077
+ else {
6078
+ // src1 is not contiguous
6079
+ GGML_ASSERT(false);
6080
+ }
6081
+ }
6082
+
6083
+ static void ggml_compute_forward_add_f16_f16(
6084
+ const struct ggml_compute_params * params,
6085
+ const struct ggml_tensor * src0,
6086
+ const struct ggml_tensor * src1,
6087
+ struct ggml_tensor * dst) {
6088
+ GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
6089
+
6090
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
6091
+ return;
6092
+ }
6093
+
6094
+ const int ith = params->ith;
6095
+ const int nth = params->nth;
6096
+
6097
+ const int n = ggml_nrows(src0);
6098
+ const int nc = src0->ne[0];
6099
+
6100
+ const size_t nb00 = src0->nb[0];
6101
+ const size_t nb01 = src0->nb[1];
6102
+
6103
+ const size_t nb10 = src1->nb[0];
6104
+ const size_t nb11 = src1->nb[1];
6105
+
6106
+ const size_t nb0 = dst->nb[0];
6107
+ const size_t nb1 = dst->nb[1];
6108
+
6109
+ GGML_ASSERT(src0->type == GGML_TYPE_F16);
6110
+ GGML_ASSERT(src1->type == GGML_TYPE_F16);
6111
+ GGML_ASSERT(dst->type == GGML_TYPE_F16);
6112
+
6113
+ GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
6114
+ GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
6115
+
6116
+ if (nb10 == sizeof(ggml_fp16_t)) {
6117
+ for (int j = ith; j < n; j += nth) {
6118
+ ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + j*nb1);
6119
+ ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + j*nb01);
6120
+ for (int i = 0; i < nc; i++) {
6121
+ ggml_fp16_t * src1_ptr = (ggml_fp16_t *) ((char *) src1->data + j*nb11 + i*nb10);
6122
+ dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + GGML_FP16_TO_FP32(*src1_ptr));
6123
+ }
6124
+ }
6125
+ }
6126
+ else {
6127
+ // src1 is not contiguous
6128
+ GGML_ASSERT(false);
6129
+ }
6130
+ }
6131
+
6132
+ static void ggml_compute_forward_add_q_f32(
6133
+ const struct ggml_compute_params * params,
6134
+ const struct ggml_tensor * src0,
6135
+ const struct ggml_tensor * src1,
6136
+ struct ggml_tensor * dst) {
6137
+ GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
6138
+
6139
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
6140
+ return;
6141
+ }
6142
+
6143
+ const int64_t ne00 = src0->ne[0];
6144
+ const int64_t ne01 = src0->ne[1];
6145
+ const int64_t ne02 = src0->ne[2];
6146
+ const int64_t ne03 = src0->ne[3];
6147
+
6148
+ //const int64_t ne10 = src1->ne[0];
6149
+ //const int64_t ne11 = src1->ne[1];
6150
+ const int64_t ne12 = src1->ne[2];
6151
+ const int64_t ne13 = src1->ne[3];
6152
+
6153
+ //const int64_t ne0 = dst->ne[0];
6154
+ //const int64_t ne1 = dst->ne[1];
6155
+ const int64_t ne2 = dst->ne[2];
6156
+ const int64_t ne3 = dst->ne[3];
6157
+
6158
+ const int nb00 = src0->nb[0];
6159
+ const int nb01 = src0->nb[1];
6160
+ const int nb02 = src0->nb[2];
6161
+ const int nb03 = src0->nb[3];
6162
+
6163
+ const int nb10 = src1->nb[0];
6164
+ const int nb11 = src1->nb[1];
6165
+ const int nb12 = src1->nb[2];
6166
+ const int nb13 = src1->nb[3];
6167
+
6168
+ const int nb0 = dst->nb[0];
6169
+ const int nb1 = dst->nb[1];
6170
+ const int nb2 = dst->nb[2];
6171
+ const int nb3 = dst->nb[3];
6172
+
6173
+ const int ith = params->ith;
6174
+ const int nth = params->nth;
6175
+
6176
+ GGML_ASSERT(ne02 == ne12);
6177
+ GGML_ASSERT(ne03 == ne13);
6178
+ GGML_ASSERT(ne2 == ne12);
6179
+ GGML_ASSERT(ne3 == ne13);
6180
+
6181
+ const enum ggml_type type = src0->type;
6182
+ dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q;
6183
+ quantize_row_q_t const quantize_row_q = quantize_fns[type].quantize_row_q;
6184
+
6185
+ // we don't support permuted src0 or src1
6186
+ GGML_ASSERT(nb00 == (int) GGML_TYPE_SIZE[type]);
6187
+ GGML_ASSERT(nb10 == sizeof(float));
6188
+
6189
+ // dst cannot be transposed or permuted
6190
+ GGML_ASSERT(nb0 <= nb1);
6191
+ GGML_ASSERT(nb1 <= nb2);
6192
+ GGML_ASSERT(nb2 <= nb3);
6193
+
6194
+ GGML_ASSERT(src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1);
6195
+ GGML_ASSERT(dst->type == src0->type);
6196
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
6197
+
6198
+ // total rows in src0
6199
+ const int nr = ne01*ne02*ne03;
6200
+
6201
+ // rows per thread
6202
+ const int dr = (nr + nth - 1)/nth;
6203
+
6204
+ // row range for this thread
6205
+ const int ir0 = dr*ith;
6206
+ const int ir1 = MIN(ir0 + dr, nr);
6207
+
6208
+ float * wdata = (float*) params->wdata + ne00 * ith;
6209
+
6210
+ for (int ir = ir0; ir < ir1; ++ir) {
6211
+ // src0 indices
6212
+ const int i03 = ir/(ne02*ne01);
6213
+ const int i02 = (ir - i03*ne02*ne01)/ne01;
6214
+ const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
6215
+
6216
+ // src1 and dst are same shape as src0 => same indices
6217
+ const int i13 = i03;
6218
+ const int i12 = i02;
6219
+ const int i11 = i01;
5468
6220
 
5469
- GGML_ASSERT( nb0 == sizeof(float));
5470
- GGML_ASSERT(nb00 == sizeof(float));
6221
+ const int i3 = i03;
6222
+ const int i2 = i02;
6223
+ const int i1 = i01;
5471
6224
 
5472
- if (nb10 == sizeof(float)) {
5473
- for (int j = ith; j < n; j += nth) {
5474
- #ifdef GGML_USE_ACCELERATE
5475
- vDSP_vadd(
5476
- (float *) ((char *) src0->data + j*nb01), 1,
5477
- (float *) ((char *) src1->data + j*nb11), 1,
5478
- (float *) ((char *) dst->data + j*nb1), 1, nc);
5479
- #else
5480
- ggml_vec_add_f32(nc,
5481
- (float *) ((char *) dst->data + j*nb1),
5482
- (float *) ((char *) src0->data + j*nb01),
5483
- (float *) ((char *) src1->data + j*nb11));
5484
- #endif
5485
- }
5486
- } else {
5487
- // src1 is not contiguous
5488
- for (int j = ith; j < n; j += nth) {
5489
- float * dst_ptr = (float *) ((char *) dst->data + j*nb1);
5490
- float * src0_ptr = (float *) ((char *) src0->data + j*nb01);
5491
- for (int i = 0; i < nc; i++) {
5492
- float * src1_ptr = (float *) ((char *) src1->data + j*nb11 + i*nb10);
6225
+ void * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
6226
+ float * src1_row = (float *)((char *) src1->data + (i11*nb11 + i12*nb12 + i13*nb13));
6227
+ void * dst_row = (void *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
5493
6228
 
5494
- dst_ptr[i] = src0_ptr[i] + *src1_ptr;
5495
- }
5496
- }
6229
+ assert(ne00 % 32 == 0);
6230
+
6231
+ // unquantize row from src0 to temp buffer
6232
+ dequantize_row_q(src0_row, wdata, ne00);
6233
+ // add src1
6234
+ ggml_vec_acc_f32(ne00, wdata, src1_row);
6235
+ // quantize row to dst
6236
+ quantize_row_q(wdata, dst_row, ne00);
5497
6237
  }
5498
6238
  }
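ggml_compute_forward_add_q_f32 processes whole rows: dequantize the src0 row into a per-thread f32 scratch slice, accumulate the f32 src1 row, and re-quantize into dst. The row update in isolation (the helper name is hypothetical):

    // per-row body of the loop above; wdata points at this thread's ne00-float scratch slice
    static void add_row_q_f32(const void * src0_row, const float * src1_row, void * dst_row,
                              int ne00, dequantize_row_q_t dequantize_row_q,
                              quantize_row_q_t quantize_row_q, float * wdata) {
        dequantize_row_q(src0_row, wdata, ne00);  // q4_x -> f32
        ggml_vec_acc_f32(ne00, wdata, src1_row);  // wdata[i] += src1_row[i]
        quantize_row_q(wdata, dst_row, ne00);     // f32 -> q4_x, written to dst
    }

Every row round-trips through the 4-bit grid, so each add re-quantizes its result; the forward_add dispatch below only routes Q4_0/Q4_1 src0 here.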
5499
6239
 
@@ -5507,13 +6247,24 @@ static void ggml_compute_forward_add(
5507
6247
  {
5508
6248
  ggml_compute_forward_add_f32(params, src0, src1, dst);
5509
6249
  } break;
6250
+ case GGML_TYPE_F16:
6251
+ {
6252
+ if (src1->type == GGML_TYPE_F16) {
6253
+ ggml_compute_forward_add_f16_f16(params, src0, src1, dst);
6254
+ }
6255
+ else if (src1->type == GGML_TYPE_F32) {
6256
+ ggml_compute_forward_add_f16_f32(params, src0, src1, dst);
6257
+ }
6258
+ else {
6259
+ GGML_ASSERT(false);
6260
+ }
6261
+ } break;
5510
6262
  case GGML_TYPE_Q4_0:
5511
6263
  case GGML_TYPE_Q4_1:
5512
- case GGML_TYPE_I8:
5513
- case GGML_TYPE_I16:
5514
- case GGML_TYPE_I32:
5515
- case GGML_TYPE_F16:
5516
- case GGML_TYPE_COUNT:
6264
+ {
6265
+ ggml_compute_forward_add_q_f32(params, src0, src1, dst);
6266
+ } break;
6267
+ default:
5517
6268
  {
5518
6269
  GGML_ASSERT(false);
5519
6270
  } break;
@@ -5559,13 +6310,7 @@ static void ggml_compute_forward_sub(
5559
6310
  {
5560
6311
  ggml_compute_forward_sub_f32(params, src0, src1, dst);
5561
6312
  } break;
5562
- case GGML_TYPE_Q4_0:
5563
- case GGML_TYPE_Q4_1:
5564
- case GGML_TYPE_I8:
5565
- case GGML_TYPE_I16:
5566
- case GGML_TYPE_I32:
5567
- case GGML_TYPE_F16:
5568
- case GGML_TYPE_COUNT:
6313
+ default:
5569
6314
  {
5570
6315
  GGML_ASSERT(false);
5571
6316
  } break;
@@ -5611,13 +6356,7 @@ static void ggml_compute_forward_mul(
5611
6356
  {
5612
6357
  ggml_compute_forward_mul_f32(params, src0, src1, dst);
5613
6358
  } break;
5614
- case GGML_TYPE_Q4_0:
5615
- case GGML_TYPE_Q4_1:
5616
- case GGML_TYPE_I8:
5617
- case GGML_TYPE_I16:
5618
- case GGML_TYPE_I32:
5619
- case GGML_TYPE_F16:
5620
- case GGML_TYPE_COUNT:
6359
+ default:
5621
6360
  {
5622
6361
  GGML_ASSERT(false);
5623
6362
  } break;
@@ -5663,13 +6402,7 @@ static void ggml_compute_forward_div(
5663
6402
  {
5664
6403
  ggml_compute_forward_div_f32(params, src0, src1, dst);
5665
6404
  } break;
5666
- case GGML_TYPE_Q4_0:
5667
- case GGML_TYPE_Q4_1:
5668
- case GGML_TYPE_I8:
5669
- case GGML_TYPE_I16:
5670
- case GGML_TYPE_I32:
5671
- case GGML_TYPE_F16:
5672
- case GGML_TYPE_COUNT:
6405
+ default:
5673
6406
  {
5674
6407
  GGML_ASSERT(false);
5675
6408
  } break;
@@ -5711,13 +6444,7 @@ static void ggml_compute_forward_sqr(
5711
6444
  {
5712
6445
  ggml_compute_forward_sqr_f32(params, src0, dst);
5713
6446
  } break;
5714
- case GGML_TYPE_Q4_0:
5715
- case GGML_TYPE_Q4_1:
5716
- case GGML_TYPE_I8:
5717
- case GGML_TYPE_I16:
5718
- case GGML_TYPE_I32:
5719
- case GGML_TYPE_F16:
5720
- case GGML_TYPE_COUNT:
6447
+ default:
5721
6448
  {
5722
6449
  GGML_ASSERT(false);
5723
6450
  } break;
@@ -5759,13 +6486,7 @@ static void ggml_compute_forward_sqrt(
5759
6486
  {
5760
6487
  ggml_compute_forward_sqrt_f32(params, src0, dst);
5761
6488
  } break;
5762
- case GGML_TYPE_Q4_0:
5763
- case GGML_TYPE_Q4_1:
5764
- case GGML_TYPE_I8:
5765
- case GGML_TYPE_I16:
5766
- case GGML_TYPE_I32:
5767
- case GGML_TYPE_F16:
5768
- case GGML_TYPE_COUNT:
6489
+ default:
5769
6490
  {
5770
6491
  GGML_ASSERT(false);
5771
6492
  } break;
@@ -5817,13 +6538,7 @@ static void ggml_compute_forward_sum(
5817
6538
  {
5818
6539
  ggml_compute_forward_sum_f32(params, src0, dst);
5819
6540
  } break;
5820
- case GGML_TYPE_Q4_0:
5821
- case GGML_TYPE_Q4_1:
5822
- case GGML_TYPE_I8:
5823
- case GGML_TYPE_I16:
5824
- case GGML_TYPE_I32:
5825
- case GGML_TYPE_F16:
5826
- case GGML_TYPE_COUNT:
6541
+ default:
5827
6542
  {
5828
6543
  GGML_ASSERT(false);
5829
6544
  } break;
@@ -5894,13 +6609,7 @@ static void ggml_compute_forward_mean(
5894
6609
  {
5895
6610
  ggml_compute_forward_mean_f32(params, src0, dst);
5896
6611
  } break;
5897
- case GGML_TYPE_Q4_0:
5898
- case GGML_TYPE_Q4_1:
5899
- case GGML_TYPE_I8:
5900
- case GGML_TYPE_I16:
5901
- case GGML_TYPE_I32:
5902
- case GGML_TYPE_F16:
5903
- case GGML_TYPE_COUNT:
6612
+ default:
5904
6613
  {
5905
6614
  GGML_ASSERT(false);
5906
6615
  } break;
@@ -5958,13 +6667,7 @@ static void ggml_compute_forward_repeat(
5958
6667
  {
5959
6668
  ggml_compute_forward_repeat_f32(params, src0, dst);
5960
6669
  } break;
5961
- case GGML_TYPE_Q4_0:
5962
- case GGML_TYPE_Q4_1:
5963
- case GGML_TYPE_I8:
5964
- case GGML_TYPE_I16:
5965
- case GGML_TYPE_I32:
5966
- case GGML_TYPE_F16:
5967
- case GGML_TYPE_COUNT:
6670
+ default:
5968
6671
  {
5969
6672
  GGML_ASSERT(false);
5970
6673
  } break;
@@ -6006,13 +6709,7 @@ static void ggml_compute_forward_abs(
6006
6709
  {
6007
6710
  ggml_compute_forward_abs_f32(params, src0, dst);
6008
6711
  } break;
6009
- case GGML_TYPE_Q4_0:
6010
- case GGML_TYPE_Q4_1:
6011
- case GGML_TYPE_I8:
6012
- case GGML_TYPE_I16:
6013
- case GGML_TYPE_I32:
6014
- case GGML_TYPE_F16:
6015
- case GGML_TYPE_COUNT:
6712
+ default:
6016
6713
  {
6017
6714
  GGML_ASSERT(false);
6018
6715
  } break;
@@ -6054,13 +6751,7 @@ static void ggml_compute_forward_sgn(
6054
6751
  {
6055
6752
  ggml_compute_forward_sgn_f32(params, src0, dst);
6056
6753
  } break;
6057
- case GGML_TYPE_Q4_0:
6058
- case GGML_TYPE_Q4_1:
6059
- case GGML_TYPE_I8:
6060
- case GGML_TYPE_I16:
6061
- case GGML_TYPE_I32:
6062
- case GGML_TYPE_F16:
6063
- case GGML_TYPE_COUNT:
6754
+ default:
6064
6755
  {
6065
6756
  GGML_ASSERT(false);
6066
6757
  } break;
@@ -6102,13 +6793,7 @@ static void ggml_compute_forward_neg(
6102
6793
  {
6103
6794
  ggml_compute_forward_neg_f32(params, src0, dst);
6104
6795
  } break;
6105
- case GGML_TYPE_Q4_0:
6106
- case GGML_TYPE_Q4_1:
6107
- case GGML_TYPE_I8:
6108
- case GGML_TYPE_I16:
6109
- case GGML_TYPE_I32:
6110
- case GGML_TYPE_F16:
6111
- case GGML_TYPE_COUNT:
6796
+ default:
6112
6797
  {
6113
6798
  GGML_ASSERT(false);
6114
6799
  } break;
@@ -6150,13 +6835,7 @@ static void ggml_compute_forward_step(
6150
6835
  {
6151
6836
  ggml_compute_forward_step_f32(params, src0, dst);
6152
6837
  } break;
6153
- case GGML_TYPE_Q4_0:
6154
- case GGML_TYPE_Q4_1:
6155
- case GGML_TYPE_I8:
6156
- case GGML_TYPE_I16:
6157
- case GGML_TYPE_I32:
6158
- case GGML_TYPE_F16:
6159
- case GGML_TYPE_COUNT:
6838
+ default:
6160
6839
  {
6161
6840
  GGML_ASSERT(false);
6162
6841
  } break;
@@ -6198,13 +6877,7 @@ static void ggml_compute_forward_relu(
6198
6877
  {
6199
6878
  ggml_compute_forward_relu_f32(params, src0, dst);
6200
6879
  } break;
6201
- case GGML_TYPE_Q4_0:
6202
- case GGML_TYPE_Q4_1:
6203
- case GGML_TYPE_I8:
6204
- case GGML_TYPE_I16:
6205
- case GGML_TYPE_I32:
6206
- case GGML_TYPE_F16:
6207
- case GGML_TYPE_COUNT:
6880
+ default:
6208
6881
  {
6209
6882
  GGML_ASSERT(false);
6210
6883
  } break;
@@ -6263,13 +6936,7 @@ static void ggml_compute_forward_gelu(
6263
6936
  {
6264
6937
  ggml_compute_forward_gelu_f32(params, src0, dst);
6265
6938
  } break;
6266
- case GGML_TYPE_Q4_0:
6267
- case GGML_TYPE_Q4_1:
6268
- case GGML_TYPE_I8:
6269
- case GGML_TYPE_I16:
6270
- case GGML_TYPE_I32:
6271
- case GGML_TYPE_F16:
6272
- case GGML_TYPE_COUNT:
6939
+ default:
6273
6940
  {
6274
6941
  GGML_ASSERT(false);
6275
6942
  } break;
@@ -6330,13 +6997,7 @@ static void ggml_compute_forward_silu(
6330
6997
  {
6331
6998
  ggml_compute_forward_silu_f32(params, src0, dst);
6332
6999
  } break;
6333
- case GGML_TYPE_Q4_0:
6334
- case GGML_TYPE_Q4_1:
6335
- case GGML_TYPE_I8:
6336
- case GGML_TYPE_I16:
6337
- case GGML_TYPE_I32:
6338
- case GGML_TYPE_F16:
6339
- case GGML_TYPE_COUNT:
7000
+ default:
6340
7001
  {
6341
7002
  GGML_ASSERT(false);
6342
7003
  } break;
@@ -6416,13 +7077,7 @@ static void ggml_compute_forward_norm(
6416
7077
  {
6417
7078
  ggml_compute_forward_norm_f32(params, src0, dst);
6418
7079
  } break;
6419
- case GGML_TYPE_Q4_0:
6420
- case GGML_TYPE_Q4_1:
6421
- case GGML_TYPE_I8:
6422
- case GGML_TYPE_I16:
6423
- case GGML_TYPE_I32:
6424
- case GGML_TYPE_F16:
6425
- case GGML_TYPE_COUNT:
7080
+ default:
6426
7081
  {
6427
7082
  GGML_ASSERT(false);
6428
7083
  } break;
@@ -6496,13 +7151,7 @@ static void ggml_compute_forward_rms_norm(
6496
7151
  {
6497
7152
  ggml_compute_forward_rms_norm_f32(params, src0, dst);
6498
7153
  } break;
6499
- case GGML_TYPE_Q4_0:
6500
- case GGML_TYPE_Q4_1:
6501
- case GGML_TYPE_I8:
6502
- case GGML_TYPE_I16:
6503
- case GGML_TYPE_I32:
6504
- case GGML_TYPE_F16:
6505
- case GGML_TYPE_COUNT:
7154
+ default:
6506
7155
  {
6507
7156
  GGML_ASSERT(false);
6508
7157
  } break;
@@ -6894,27 +7543,6 @@ static void ggml_compute_forward_mul_mat_f16_f32(
6894
7543
  //}
6895
7544
  }
6896
7545
 
6897
- static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
6898
- [GGML_TYPE_Q4_0] = {
6899
- .dequantize_row_q = dequantize_row_q4_0,
6900
- .quantize_row_q = quantize_row_q4_0,
6901
- .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_0_reference,
6902
- .vec_dot_q = ggml_vec_dot_q4_0,
6903
- },
6904
- [GGML_TYPE_Q4_1] = {
6905
- .dequantize_row_q = dequantize_row_q4_1,
6906
- .quantize_row_q = quantize_row_q4_1,
6907
- .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_1_reference,
6908
- .vec_dot_q = ggml_vec_dot_q4_1,
6909
- },
6910
- };
6911
-
6912
- // For internal test use
6913
- quantize_fns_t ggml_internal_get_quantize_fn(size_t i) {
6914
- GGML_ASSERT(i < GGML_TYPE_COUNT);
6915
- return quantize_fns[i];
6916
- }
6917
-
6918
7546
  static void ggml_compute_forward_mul_mat_q_f32(
6919
7547
  const struct ggml_compute_params * params,
6920
7548
  const struct ggml_tensor * src0,
@@ -6962,8 +7590,8 @@ static void ggml_compute_forward_mul_mat_q_f32(
6962
7590
  GGML_ASSERT(ne3 == ne13);
6963
7591
 
6964
7592
  const enum ggml_type type = src0->type;
6965
- quantize_row_q_t const quantize_row_q = quantize_fns[type].quantize_row_q;
6966
- vec_dot_q_t const vec_dot_q = quantize_fns[type].vec_dot_q;
7593
+ quantize_row_q_t const quantize_row_q_dot = quantize_fns[type].quantize_row_q_dot;
7594
+ vec_dot_q_t const vec_dot_q = quantize_fns[type].vec_dot_q;
6967
7595
 
6968
7596
  // we don't support permuted src0 or src1
6969
7597
  GGML_ASSERT(nb00 == (int) GGML_TYPE_SIZE[type]);
@@ -7032,12 +7660,12 @@ static void ggml_compute_forward_mul_mat_q_f32(
7032
7660
 
7033
7661
  if (params->type == GGML_TASK_INIT) {
7034
7662
  char * wdata = params->wdata;
7035
- const size_t row_size = ne10*GGML_TYPE_SIZE[type]/GGML_BLCK_SIZE[type];
7663
+ const size_t row_size = ne10*GGML_TYPE_SIZE[GGML_TYPE_Q8_0]/GGML_BLCK_SIZE[GGML_TYPE_Q8_0];
7036
7664
 
7037
7665
  for (int64_t i13 = 0; i13 < ne13; ++i13) {
7038
7666
  for (int64_t i12 = 0; i12 < ne12; ++i12) {
7039
7667
  for (int64_t i11 = 0; i11 < ne11; ++i11) {
7040
- quantize_row_q((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10);
7668
+ quantize_row_q_dot((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10);
7041
7669
  wdata += row_size;
7042
7670
  }
7043
7671
  }
@@ -7063,7 +7691,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
7063
7691
  const int ir1 = MIN(ir0 + dr, nr);
7064
7692
 
7065
7693
  void * wdata = params->wdata;
7066
- const size_t row_size = ne00*GGML_TYPE_SIZE[type]/GGML_BLCK_SIZE[type];
7694
+ const size_t row_size = ne00*GGML_TYPE_SIZE[GGML_TYPE_Q8_0]/GGML_BLCK_SIZE[GGML_TYPE_Q8_0];
7067
7695
 
7068
7696
  for (int ir = ir0; ir < ir1; ++ir) {
7069
7697
  // src0 indices
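For quantized mat-mul the src1 activations are no longer packed with src0's own quantize_row_q: during GGML_TASK_INIT each f32 row of src1 is converted once with quantize_row_q_dot (presumably a Q8_0 packer, matching the Q8_0-based row size), and vec_dot_q then mixes the 4-bit weights with the 8-bit activations. The static quantize_fns table removed above presumably now lives earlier in the file with this extra member. The work-buffer row size in isolation, assuming QK8_0 == 32 and the block_q8_0 layout sketched earlier:

    // bytes per quantized src1 row held in params->wdata during mul_mat
    static size_t row_size_q8_0(int64_t ne10) {
        return (size_t) ne10 * sizeof(block_q8_0) / QK8_0;
    }

Quantizing activations to 8 bits rather than 4 keeps the dot product cheap while losing noticeably less precision on the src1 side.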
@@ -7111,6 +7739,7 @@ static void ggml_compute_forward_mul_mat(
7111
7739
  switch (src0->type) {
7112
7740
  case GGML_TYPE_Q4_0:
7113
7741
  case GGML_TYPE_Q4_1:
7742
+ case GGML_TYPE_Q8_0:
7114
7743
  {
7115
7744
  ggml_compute_forward_mul_mat_q_f32(params, src0, src1, dst);
7116
7745
  } break;
@@ -7122,10 +7751,7 @@ static void ggml_compute_forward_mul_mat(
7122
7751
  {
7123
7752
  ggml_compute_forward_mul_mat_f32(params, src0, src1, dst);
7124
7753
  } break;
7125
- case GGML_TYPE_I8:
7126
- case GGML_TYPE_I16:
7127
- case GGML_TYPE_I32:
7128
- case GGML_TYPE_COUNT:
7754
+ default:
7129
7755
  {
7130
7756
  GGML_ASSERT(false);
7131
7757
  } break;
@@ -7207,13 +7833,7 @@ static void ggml_compute_forward_scale(
7207
7833
  {
7208
7834
  ggml_compute_forward_scale_f32(params, src0, src1, dst);
7209
7835
  } break;
7210
- case GGML_TYPE_Q4_0:
7211
- case GGML_TYPE_Q4_1:
7212
- case GGML_TYPE_I8:
7213
- case GGML_TYPE_I16:
7214
- case GGML_TYPE_I32:
7215
- case GGML_TYPE_F16:
7216
- case GGML_TYPE_COUNT:
7836
+ default:
7217
7837
  {
7218
7838
  GGML_ASSERT(false);
7219
7839
  } break;
@@ -7374,6 +7994,7 @@ static void ggml_compute_forward_get_rows(
7374
7994
  switch (src0->type) {
7375
7995
  case GGML_TYPE_Q4_0:
7376
7996
  case GGML_TYPE_Q4_1:
7997
+ case GGML_TYPE_Q8_0:
7377
7998
  {
7378
7999
  ggml_compute_forward_get_rows_q(params, src0, src1, dst);
7379
8000
  } break;
@@ -7385,10 +8006,7 @@ static void ggml_compute_forward_get_rows(
7385
8006
  {
7386
8007
  ggml_compute_forward_get_rows_f32(params, src0, src1, dst);
7387
8008
  } break;
7388
- case GGML_TYPE_I8:
7389
- case GGML_TYPE_I16:
7390
- case GGML_TYPE_I32:
7391
- case GGML_TYPE_COUNT:
8009
+ default:
7392
8010
  {
7393
8011
  GGML_ASSERT(false);
7394
8012
  } break;
@@ -7461,13 +8079,7 @@ static void ggml_compute_forward_diag_mask_inf(
7461
8079
  {
7462
8080
  ggml_compute_forward_diag_mask_inf_f32(params, src0, src1, dst);
7463
8081
  } break;
7464
- case GGML_TYPE_Q4_0:
7465
- case GGML_TYPE_Q4_1:
7466
- case GGML_TYPE_I8:
7467
- case GGML_TYPE_I16:
7468
- case GGML_TYPE_I32:
7469
- case GGML_TYPE_F16:
7470
- case GGML_TYPE_COUNT:
8082
+ default:
7471
8083
  {
7472
8084
  GGML_ASSERT(false);
7473
8085
  } break;
@@ -7555,13 +8167,7 @@ static void ggml_compute_forward_soft_max(
7555
8167
  {
7556
8168
  ggml_compute_forward_soft_max_f32(params, src0, dst);
7557
8169
  } break;
7558
- case GGML_TYPE_Q4_0:
7559
- case GGML_TYPE_Q4_1:
7560
- case GGML_TYPE_I8:
7561
- case GGML_TYPE_I16:
7562
- case GGML_TYPE_I32:
7563
- case GGML_TYPE_F16:
7564
- case GGML_TYPE_COUNT:
8170
+ default:
7565
8171
  {
7566
8172
  GGML_ASSERT(false);
7567
8173
  } break;
@@ -7713,11 +8319,11 @@ static void ggml_compute_forward_rope_f16(
7713
8319
  const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
7714
8320
  ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
7715
8321
 
7716
- const float x0 = ggml_fp16_to_fp32(src[0]);
7717
- const float x1 = ggml_fp16_to_fp32(src[1]);
8322
+ const float x0 = GGML_FP16_TO_FP32(src[0]);
8323
+ const float x1 = GGML_FP16_TO_FP32(src[1]);
7718
8324
 
7719
- dst_data[0] = ggml_fp32_to_fp16(x0*cos_theta - x1*sin_theta);
7720
- dst_data[1] = ggml_fp32_to_fp16(x0*sin_theta + x1*cos_theta);
8325
+ dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
8326
+ dst_data[1] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
7721
8327
  }
7722
8328
  }
7723
8329
  }
@@ -7738,12 +8344,7 @@ static void ggml_compute_forward_rope(
7738
8344
  {
7739
8345
  ggml_compute_forward_rope_f32(params, src0, src1, dst);
7740
8346
  } break;
7741
- case GGML_TYPE_Q4_0:
7742
- case GGML_TYPE_Q4_1:
7743
- case GGML_TYPE_I8:
7744
- case GGML_TYPE_I16:
7745
- case GGML_TYPE_I32:
7746
- case GGML_TYPE_COUNT:
8347
+ default:
7747
8348
  {
7748
8349
  GGML_ASSERT(false);
7749
8350
  } break;
@@ -8006,12 +8607,7 @@ static void ggml_compute_forward_conv_1d_1s(
8006
8607
  {
8007
8608
  ggml_compute_forward_conv_1d_1s_f32(params, src0, src1, dst);
8008
8609
  } break;
8009
- case GGML_TYPE_Q4_0:
8010
- case GGML_TYPE_Q4_1:
8011
- case GGML_TYPE_I8:
8012
- case GGML_TYPE_I16:
8013
- case GGML_TYPE_I32:
8014
- case GGML_TYPE_COUNT:
8610
+ default:
8015
8611
  {
8016
8612
  GGML_ASSERT(false);
8017
8613
  } break;
@@ -8274,12 +8870,7 @@ static void ggml_compute_forward_conv_1d_2s(
8274
8870
  {
8275
8871
  ggml_compute_forward_conv_1d_2s_f32(params, src0, src1, dst);
8276
8872
  } break;
8277
- case GGML_TYPE_Q4_0:
8278
- case GGML_TYPE_Q4_1:
8279
- case GGML_TYPE_I8:
8280
- case GGML_TYPE_I16:
8281
- case GGML_TYPE_I32:
8282
- case GGML_TYPE_COUNT:
8873
+ default:
8283
8874
  {
8284
8875
  GGML_ASSERT(false);
8285
8876
  } break;
@@ -8759,12 +9350,7 @@ static void ggml_compute_forward_flash_attn(
8759
9350
  {
8760
9351
  ggml_compute_forward_flash_attn_f32(params, q, k, v, masked, dst);
8761
9352
  } break;
8762
- case GGML_TYPE_Q4_0:
8763
- case GGML_TYPE_Q4_1:
8764
- case GGML_TYPE_I8:
8765
- case GGML_TYPE_I16:
8766
- case GGML_TYPE_I32:
8767
- case GGML_TYPE_COUNT:
9353
+ default:
8768
9354
  {
8769
9355
  GGML_ASSERT(false);
8770
9356
  } break;
@@ -8970,12 +9556,7 @@ static void ggml_compute_forward_flash_ff(
8970
9556
  {
8971
9557
  GGML_ASSERT(false); // TODO
8972
9558
  } break;
8973
- case GGML_TYPE_Q4_0:
8974
- case GGML_TYPE_Q4_1:
8975
- case GGML_TYPE_I8:
8976
- case GGML_TYPE_I16:
8977
- case GGML_TYPE_I32:
8978
- case GGML_TYPE_COUNT:
9559
+ default:
8979
9560
  {
8980
9561
  GGML_ASSERT(false);
8981
9562
  } break;
@@ -9019,13 +9600,7 @@ static void ggml_compute_forward_map_unary(
9019
9600
  {
9020
9601
  ggml_compute_forward_map_unary_f32(params, src0, dst, fun);
9021
9602
  } break;
9022
- case GGML_TYPE_Q4_0:
9023
- case GGML_TYPE_Q4_1:
9024
- case GGML_TYPE_I8:
9025
- case GGML_TYPE_I16:
9026
- case GGML_TYPE_I32:
9027
- case GGML_TYPE_F16:
9028
- case GGML_TYPE_COUNT:
9603
+ default:
9029
9604
  {
9030
9605
  GGML_ASSERT(false);
9031
9606
  } break;
@@ -9074,13 +9649,7 @@ static void ggml_compute_forward_map_binary(
9074
9649
  {
9075
9650
  ggml_compute_forward_map_binary_f32(params, src0, src1, dst, fun);
9076
9651
  } break;
9077
- case GGML_TYPE_Q4_0:
9078
- case GGML_TYPE_Q4_1:
9079
- case GGML_TYPE_I8:
9080
- case GGML_TYPE_I16:
9081
- case GGML_TYPE_I32:
9082
- case GGML_TYPE_F16:
9083
- case GGML_TYPE_COUNT:
9652
+ default:
9084
9653
  {
9085
9654
  GGML_ASSERT(false);
9086
9655
  } break;
@@ -9830,13 +10399,29 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
9830
10399
  struct ggml_tensor * node = cgraph->nodes[i];
9831
10400
 
9832
10401
  switch (node->op) {
10402
+ case GGML_OP_CPY:
9833
10403
  case GGML_OP_DUP:
9834
10404
  {
9835
10405
  node->n_tasks = 1;
10406
+
10407
+ size_t cur = 0;
10408
+ if (node->type == GGML_TYPE_Q4_0 || node->type == GGML_TYPE_Q4_1) {
10409
+ cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->ne[0];
10410
+ }
10411
+
10412
+ work_size = MAX(work_size, cur);
9836
10413
  } break;
9837
10414
  case GGML_OP_ADD:
9838
10415
  {
9839
10416
  node->n_tasks = n_threads;
10417
+
10418
+ size_t cur = 0;
10419
+
10420
+ if (node->src0->type == GGML_TYPE_Q4_0 || node->src0->type == GGML_TYPE_Q4_1) {
10421
+ cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->src0->ne[0] * n_threads;
10422
+ }
10423
+
10424
+ work_size = MAX(work_size, cur);
9840
10425
  } break;
9841
10426
  case GGML_OP_SUB:
9842
10427
  case GGML_OP_MUL:
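The graph planner now reserves scratch space for these ops: CPY/DUP into a quantized destination needs one f32 row (the wdata buffer used by the dup paths above), and a quantized ADD needs one f32 row per thread, since each thread dequantizes into its own slice. A hedged restatement of the ADD rule as a standalone helper:

    // scratch bytes reserved for GGML_OP_ADD when src0 is q4_0/q4_1:
    // one f32 row per thread, matching the per-thread slices in forward_add_q_f32
    static size_t add_q_work_size(const struct ggml_tensor * src0, int n_threads) {
        if (src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1) {
            return sizeof(float) * (size_t) src0->ne[0] * (size_t) n_threads;
        }
        return 0;
    }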
@@ -9905,7 +10490,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
9905
10490
  } else
9906
10491
  #endif
9907
10492
  {
9908
- cur = GGML_TYPE_SIZE[node->src0->type]*ggml_nelements(node->src1)/GGML_BLCK_SIZE[node->src0->type];
10493
+ cur = GGML_TYPE_SIZE[GGML_TYPE_Q8_0]*ggml_nelements(node->src1)/GGML_BLCK_SIZE[GGML_TYPE_Q8_0];
9909
10494
  }
9910
10495
  } else {
9911
10496
  GGML_ASSERT(false);
@@ -9917,7 +10502,6 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
9917
10502
  {
9918
10503
  node->n_tasks = n_threads;
9919
10504
  } break;
9920
- case GGML_OP_CPY:
9921
10505
  case GGML_OP_CONT:
9922
10506
  case GGML_OP_RESHAPE:
9923
10507
  case GGML_OP_VIEW:
@@ -11080,16 +11664,16 @@ enum ggml_opt_result ggml_opt(
11080
11664
  ////////////////////////////////////////////////////////////////////////////////
11081
11665
 
11082
11666
  size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t * hist) {
11083
- assert(k % QK == 0);
11084
- const int nb = k / QK;
11667
+ assert(k % QK4_0 == 0);
11668
+ const int nb = k / QK4_0;
11085
11669
 
11086
11670
  for (int j = 0; j < n; j += k) {
11087
- block_q4_0 * restrict y = (block_q4_0 *)dst + j/QK;
11671
+ block_q4_0 * restrict y = (block_q4_0 *)dst + j/QK4_0;
11088
11672
 
11089
11673
  quantize_row_q4_0_reference(src + j, y, k);
11090
11674
 
11091
11675
  for (int i = 0; i < nb; i++) {
11092
- for (int l = 0; l < QK; l += 2) {
11676
+ for (int l = 0; l < QK4_0; l += 2) {
11093
11677
  const uint8_t vi0 = y[i].qs[l/2] & 0xF;
11094
11678
  const uint8_t vi1 = y[i].qs[l/2] >> 4;
11095
11679
 
@@ -11099,20 +11683,20 @@ size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t *
11099
11683
  }
11100
11684
  }
11101
11685
 
11102
- return (n/QK*sizeof(block_q4_0));
11686
+ return (n/QK4_0*sizeof(block_q4_0));
11103
11687
  }
11104
11688
 
11105
11689
  size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t * hist) {
11106
- assert(k % QK == 0);
11107
- const int nb = k / QK;
11690
+ assert(k % QK4_1 == 0);
11691
+ const int nb = k / QK4_1;
11108
11692
 
11109
11693
  for (int j = 0; j < n; j += k) {
11110
- block_q4_1 * restrict y = (block_q4_1 *)dst + j/QK;
11694
+ block_q4_1 * restrict y = (block_q4_1 *)dst + j/QK4_1;
11111
11695
 
11112
11696
  quantize_row_q4_1_reference(src + j, y, k);
11113
11697
 
11114
11698
  for (int i = 0; i < nb; i++) {
11115
- for (int l = 0; l < QK; l += 2) {
11699
+ for (int l = 0; l < QK4_1; l += 2) {
11116
11700
  const uint8_t vi0 = y[i].qs[l/2] & 0xF;
11117
11701
  const uint8_t vi1 = y[i].qs[l/2] >> 4;
11118
11702
 
@@ -11122,7 +11706,7 @@ size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t *
11122
11706
  }
11123
11707
  }
11124
11708
 
11125
- return (n/QK*sizeof(block_q4_1));
11709
+ return (n/QK4_1*sizeof(block_q4_1));
11126
11710
  }
11127
11711
 
11128
11712
  ////////////////////////////////////////////////////////////////////////////////
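Both exported quantizers now use the per-format block constants and count one histogram entry per packed nibble, so the caller's hist array needs 16 bins. A hedged usage sketch (sizes are illustrative; src must hold a multiple of QK4_0 floats and be filled before the call):

    float   src[4096];                            // weights to quantize, 4096 = 128 * QK4_0
    uint8_t dst[(4096/32) * sizeof(block_q4_0)];  // room for 4096/QK4_0 blocks
    int64_t hist[16] = {0};                       // one bin per 4-bit value
    size_t  written = ggml_quantize_q4_0(src, dst, 4096, 4096, hist);
    // written == 4096/QK4_0 * sizeof(block_q4_0)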
@@ -11151,6 +11735,22 @@ int ggml_cpu_has_avx512(void) {
11151
11735
  #endif
11152
11736
  }
11153
11737
 
11738
+ int ggml_cpu_has_avx512_vbmi(void) {
11739
+ #if defined(__AVX512VBMI__)
11740
+ return 1;
11741
+ #else
11742
+ return 0;
11743
+ #endif
11744
+ }
11745
+
11746
+ int ggml_cpu_has_avx512_vnni(void) {
11747
+ #if defined(__AVX512VNNI__)
11748
+ return 1;
11749
+ #else
11750
+ return 0;
11751
+ #endif
11752
+ }
11753
+
11154
11754
  int ggml_cpu_has_fma(void) {
11155
11755
  #if defined(__FMA__)
11156
11756
  return 1;