llama_cpp 0.0.4 → 0.0.5

@@ -118,7 +118,16 @@ typedef void* thread_ret_t;
118
118
  #define GGML_ALIGNED_MALLOC(size) _aligned_malloc(size, GGML_MEM_ALIGN)
119
119
  #define GGML_ALIGNED_FREE(ptr) _aligned_free(ptr)
120
120
  #else
121
- #define GGML_ALIGNED_MALLOC(size) aligned_alloc(GGML_MEM_ALIGN, size)
121
+ inline static void* ggml_aligned_malloc(size_t size) {
122
+ void* aligned_memory = NULL;
123
+ int result = posix_memalign(&aligned_memory, GGML_MEM_ALIGN, size);
124
+ if (result != 0) {
125
+ // Handle allocation failure
126
+ return NULL;
127
+ }
128
+ return aligned_memory;
129
+ }
130
+ #define GGML_ALIGNED_MALLOC(size) ggml_aligned_malloc(size)
122
131
  #define GGML_ALIGNED_FREE(ptr) free(ptr)
123
132
  #endif
124
133
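The hunk above swaps aligned_alloc for a posix_memalign wrapper, presumably for portability: aligned_alloc only arrived in C11, is missing from older macOS SDKs, and requires the size to be a multiple of the alignment, while posix_memalign has neither restriction. A minimal standalone sketch of the same pattern, POSIX-only like the #else branch above (GGML_MEM_ALIGN is assumed to be 16 here purely for illustration):

// sketch of the posix_memalign wrapper pattern used above; not ggml code
#define _POSIX_C_SOURCE 200112L   // expose posix_memalign under strict compile modes
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>

#define MEM_ALIGN 16

static void * aligned_malloc_sketch(size_t size) {
    void * p = NULL;
    // posix_memalign returns 0 on success and an error number (not errno) on failure
    if (posix_memalign(&p, MEM_ALIGN, size) != 0) {
        return NULL;
    }
    return p;
}

int main(void) {
    float * buf = aligned_malloc_sketch(1024 * sizeof(float));
    if (buf == NULL) return 1;
    printf("16-byte aligned: %s\n", ((uintptr_t) buf % MEM_ALIGN) == 0 ? "yes" : "no");
    free(buf); // memory from posix_memalign is released with plain free()
    return 0;
}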
 
@@ -418,8 +427,6 @@ static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
418
427
  // quantization
419
428
  //
420
429
 
421
- #define QK 32
422
-
423
430
  // AVX routines provided by GH user Const-me
424
431
  // ref: https://github.com/ggerganov/ggml/pull/27#issuecomment-1464934600
425
432
  #if __AVX2__ || __AVX512F__
@@ -531,68 +538,73 @@ inline static float vaddvq_f32(float32x4_t v) {
531
538
  return vgetq_lane_f32(v, 0) + vgetq_lane_f32(v, 1) + vgetq_lane_f32(v, 2) + vgetq_lane_f32(v, 3);
532
539
  }
533
540
 
534
- inline float vminvq_f32(float32x4_t v) {
541
+ float vminvq_f32(float32x4_t v) {
535
542
  return
536
543
  MIN(MIN(vgetq_lane_f32(v, 0), vgetq_lane_f32(v, 1)),
537
544
  MIN(vgetq_lane_f32(v, 2), vgetq_lane_f32(v, 3)));
538
545
  }
539
546
 
540
- inline float vmaxvq_f32(float32x4_t v) {
547
+ float vmaxvq_f32(float32x4_t v) {
541
548
  return
542
549
  MAX(MAX(vgetq_lane_f32(v, 0), vgetq_lane_f32(v, 1)),
543
550
  MAX(vgetq_lane_f32(v, 2), vgetq_lane_f32(v, 3)));
544
551
  }
545
552
 
546
- inline int8x8_t vzip1_s8(int8x8_t a, int8x8_t b) {
553
+ int8x8_t vzip1_s8(int8x8_t a, int8x8_t b) {
547
554
  return vget_low_s8(vcombine_s8(a, b));
548
555
  }
549
556
 
550
- inline int8x8_t vzip2_s8(int8x8_t a, int8x8_t b) {
557
+ int8x8_t vzip2_s8(int8x8_t a, int8x8_t b) {
551
558
  return vget_high_s8(vcombine_s8(a, b));
552
559
  }
553
560
 
554
- inline uint8x8_t vzip1_u8(uint8x8_t a, uint8x8_t b) {
561
+ uint8x8_t vzip1_u8(uint8x8_t a, uint8x8_t b) {
555
562
  return vget_low_u8(vcombine_u8(a, b));
556
563
  }
557
564
 
558
- inline uint8x8_t vzip2_u8(uint8x8_t a, uint8x8_t b) {
565
+ uint8x8_t vzip2_u8(uint8x8_t a, uint8x8_t b) {
559
566
  return vget_high_u8(vcombine_u8(a, b));
560
567
  }
561
568
 
562
569
  #endif
563
570
  #endif
564
571
 
565
- // method 5
566
- // blocks of QK elements
567
- // represented with a single float (delta) and QK/2 8-bit ints (i.e QK 4-bit signed integer factors)
572
+
573
+ #define QK4_0 32
568
574
  typedef struct {
569
- float d; // delta
570
- uint8_t qs[QK / 2]; // nibbles / quants
575
+ float d; // delta
576
+ uint8_t qs[QK4_0 / 2]; // nibbles / quants
571
577
  } block_q4_0;
572
- static_assert(sizeof(block_q4_0) == sizeof(float) + QK / 2, "wrong q4_0 block size/padding");
578
+ static_assert(sizeof(block_q4_0) == sizeof(float) + QK4_0 / 2, "wrong q4_0 block size/padding");
573
579
 
574
- // method 4
575
- // blocks of QK elements
576
- // represented with 2 floats (delta + min) and QK/2 8-bit ints (i.e QK 4-bit unsigned integer factors)
580
+ #define QK4_1 32
577
581
  typedef struct {
578
- float d;
579
- float m;
580
- uint8_t qs[QK / 2]; // nibbles / quants
582
+ float d; // delta
583
+ float m; // min
584
+ uint8_t qs[QK4_1 / 2]; // nibbles / quants
581
585
  } block_q4_1;
582
- static_assert(sizeof(block_q4_1) == sizeof(float) * 2 + QK / 2, "wrong q4_1 block size/padding");
586
+ static_assert(sizeof(block_q4_1) == sizeof(float) * 2 + QK4_1 / 2, "wrong q4_1 block size/padding");
587
+
588
+ #define QK8_0 32
589
+ typedef struct {
590
+ float d; // delta
591
+ int8_t qs[QK8_0]; // quants
592
+ } block_q8_0;
593
+ static_assert(sizeof(block_q8_0) == sizeof(float) + QK8_0, "wrong q8_0 block size/padding");
594
+
583
595
 
584
596
  // reference implementation for deterministic creation of model files
585
597
  static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k) {
586
- assert(k % QK == 0);
587
- const int nb = k / QK;
598
+ assert(k % QK4_0 == 0);
599
+ const int nb = k / QK4_0;
588
600
 
589
- uint8_t pp[QK/2];
601
+ uint8_t pp[QK4_0/2];
590
602
 
591
603
  for (int i = 0; i < nb; i++) {
592
604
  float amax = 0.0f; // absolute max
593
605
 
594
- for (int l = 0; l < QK; l++) {
595
- const float v = x[i*QK + l];
606
+ for (int l = 0; l < QK4_0; l++) {
607
+ const float v = x[i*QK4_0 + l];
596
608
  amax = MAX(amax, fabsf(v));
597
609
  }
598
610
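For orientation, the three block layouts introduced earlier in this hunk trade accuracy against size per 32 weights: Q4_0 is one float scale plus 16 packed nibbles, Q4_1 adds a second float for the per-block minimum, and Q8_0 stores full signed bytes. A standalone size check that mirrors the structs from the hunk (the byte counts hold only with 4-byte floats and no struct padding, which is what the static_asserts above already enforce):

// sketch: per-block memory cost of the formats defined in this hunk
#include <stdint.h>
#include <stdio.h>

#define QK4_0 32
typedef struct { float d; uint8_t qs[QK4_0 / 2]; } block_q4_0;          // 4 + 16 = 20 B
#define QK4_1 32
typedef struct { float d; float m; uint8_t qs[QK4_1 / 2]; } block_q4_1; // 8 + 16 = 24 B
#define QK8_0 32
typedef struct { float d; int8_t qs[QK8_0]; } block_q8_0;               // 4 + 32 = 36 B

int main(void) {
    printf("q4_0: %zu bytes/block = %.1f bits/weight\n", sizeof(block_q4_0), 8.0 * sizeof(block_q4_0) / QK4_0);
    printf("q4_1: %zu bytes/block = %.1f bits/weight\n", sizeof(block_q4_1), 8.0 * sizeof(block_q4_1) / QK4_1);
    printf("q8_0: %zu bytes/block = %.1f bits/weight\n", sizeof(block_q8_0), 8.0 * sizeof(block_q8_0) / QK8_0);
    return 0;
}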
 
@@ -601,9 +613,9 @@ static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * r
601
613
 
602
614
  y[i].d = d;
603
615
 
604
- for (int l = 0; l < QK; l += 2) {
605
- const float v0 = x[i*QK + l + 0]*id;
606
- const float v1 = x[i*QK + l + 1]*id;
616
+ for (int l = 0; l < QK4_0; l += 2) {
617
+ const float v0 = x[i*QK4_0 + l + 0]*id;
618
+ const float v1 = x[i*QK4_0 + l + 1]*id;
607
619
 
608
620
  const uint8_t vi0 = (int8_t)roundf(v0) + 8;
609
621
  const uint8_t vi1 = (int8_t)roundf(v1) + 8;
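The reference quantizer above maps each block of 32 floats onto 4-bit codes: each value becomes round(v/d) in [-7, 7] and is stored offset by +8 as an unsigned nibble, two nibbles per byte. The scale computation sits in the unchanged lines of this function, so the divisor amax/7 below is an assumption taken from the surrounding ggml source. A standalone scalar sketch of the round trip for one block (nibble packing omitted for brevity):

// sketch: scalar Q4_0 quantize + dequantize of one 32-float block;
// mirrors quantize_row_q4_0_reference above, but is not the ggml function
#include <math.h>
#include <stdint.h>
#include <stdio.h>

#define QK4_0 32

static void q4_0_roundtrip(const float * x, float * out) {
    float amax = 0.0f; // absolute max of the block
    for (int l = 0; l < QK4_0; l++) amax = fmaxf(amax, fabsf(x[l]));

    const float d  = amax / 7.0f;         // assumed scale: amax / ((1 << 3) - 1)
    const float id = d ? 1.0f / d : 0.0f;

    for (int l = 0; l < QK4_0; l++) {
        const uint8_t q = (uint8_t)((int8_t)roundf(x[l] * id) + 8); // nibble in [1, 15]
        out[l] = ((int8_t)q - 8) * d;                               // dequantize
    }
}

int main(void) {
    float x[QK4_0], y[QK4_0];
    for (int l = 0; l < QK4_0; l++) x[l] = 0.1f * (l - 16);
    q4_0_roundtrip(x, y);
    printf("x[0] = % .3f  reconstructed = % .3f\n", x[0], y[0]);
    return 0;
}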
@@ -619,8 +631,8 @@ static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * r
619
631
  }
620
632
 
621
633
  static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int k) {
622
- assert(k % QK == 0);
623
- const int nb = k / QK;
634
+ assert(k % QK4_0 == 0);
635
+ const int nb = k / QK4_0;
624
636
 
625
637
  block_q4_0 * restrict y = vy;
626
638
 
@@ -870,19 +882,19 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
870
882
  }
871
883
 
872
884
  static void quantize_row_q4_1_reference(const float * restrict x, void * restrict vy, int k) {
873
- assert(k % QK == 0);
874
- const int nb = k / QK;
885
+ assert(k % QK4_1 == 0);
886
+ const int nb = k / QK4_1;
875
887
 
876
888
  block_q4_1 * restrict y = vy;
877
889
 
878
- uint8_t pp[QK/2];
890
+ uint8_t pp[QK4_1/2];
879
891
 
880
892
  for (int i = 0; i < nb; i++) {
881
893
  float min = FLT_MAX;
882
894
  float max = -FLT_MAX;
883
895
 
884
- for (int l = 0; l < QK; l++) {
885
- const float v = x[i*QK + l];
896
+ for (int l = 0; l < QK4_1; l++) {
897
+ const float v = x[i*QK4_1 + l];
886
898
  if (v < min) min = v;
887
899
  if (v > max) max = v;
888
900
  }
@@ -893,9 +905,9 @@ static void quantize_row_q4_1_reference(const float * restrict x, void * restric
893
905
  y[i].d = d;
894
906
  y[i].m = min;
895
907
 
896
- for (int l = 0; l < QK; l += 2) {
897
- const float v0 = (x[i*QK + l + 0] - min)*id;
898
- const float v1 = (x[i*QK + l + 1] - min)*id;
908
+ for (int l = 0; l < QK4_1; l += 2) {
909
+ const float v0 = (x[i*QK4_1 + l + 0] - min)*id;
910
+ const float v1 = (x[i*QK4_1 + l + 1] - min)*id;
899
911
 
900
912
  const uint8_t vi0 = roundf(v0);
901
913
  const uint8_t vi1 = roundf(v1);
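Q4_1, handled just above, keeps a per-block minimum as well as a scale, so the nibbles are unsigned offsets from min rather than signed values around zero: q = round((v - min)/d), reconstructed as q*d + m. The divisor (max - min)/15 used below comes from the unchanged part of this function, so treat it as an assumption. A one-value sketch of the arithmetic:

// sketch: Q4_1 quantization of a single value inside a block whose
// min/max are already known; not the ggml function itself
#include <math.h>
#include <stdint.h>
#include <stdio.h>

int main(void) {
    const float min = -0.50f, max = 1.00f, v = 0.37f;

    const float d  = (max - min) / 15.0f;  // assumed: (max - min) / ((1 << 4) - 1)
    const float id = d ? 1.0f / d : 0.0f;

    const uint8_t q = (uint8_t) roundf((v - min) * id); // quantize to [0, 15]
    const float   r = q * d + min;                      // dequantize: q*d + m

    printf("v = %.3f  q = %u  reconstructed = %.3f\n", v, (unsigned) q, r);
    return 0;
}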
@@ -911,9 +923,9 @@ static void quantize_row_q4_1_reference(const float * restrict x, void * restric
911
923
  }
912
924
 
913
925
  static void quantize_row_q4_1(const float * restrict x, void * restrict vy, int k) {
914
- assert(k % QK == 0);
926
+ assert(k % QK4_1 == 0);
915
927
 
916
- const int nb = k / QK;
928
+ const int nb = k / QK4_1;
917
929
 
918
930
  block_q4_1 * restrict y = vy;
919
931
 
@@ -997,7 +1009,7 @@ static void quantize_row_q4_1(const float * restrict x, void * restrict vy, int
997
1009
  float32x4_t minv[8];
998
1010
  float32x4_t maxv[8];
999
1011
 
1000
- for (int l = 0; l < 8; l++) srcv[l] = vld1q_f32(x + i*QK + 4*l);
1012
+ for (int l = 0; l < 8; l++) srcv[l] = vld1q_f32(x + i*QK4_1 + 4*l);
1001
1013
 
1002
1014
  for (int l = 0; l < 4; l++) minv[2*l] = vminq_f32(srcv[2*l], srcv[2*l + 1]);
1003
1015
  for (int l = 0; l < 2; l++) minv[4*l] = vminq_f32(minv[4*l], minv[4*l + 2]);
@@ -1033,9 +1045,160 @@ static void quantize_row_q4_1(const float * restrict x, void * restrict vy, int
1033
1045
  #endif
1034
1046
  }
1035
1047
 
1048
+ // reference implementation for deterministic creation of model files
1049
+ static void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * restrict y, int k) {
1050
+ assert(k % QK8_0 == 0);
1051
+ const int nb = k / QK8_0;
1052
+
1053
+ for (int i = 0; i < nb; i++) {
1054
+ float amax = 0.0f; // absolute max
1055
+
1056
+ for (int l = 0; l < QK8_0; l++) {
1057
+ const float v = x[i*QK8_0 + l];
1058
+ amax = MAX(amax, fabsf(v));
1059
+ }
1060
+
1061
+ const float d = amax / ((1 << 7) - 1);
1062
+ const float id = d ? 1.0f/d : 0.0f;
1063
+
1064
+ y[i].d = d;
1065
+
1066
+ for (int l = 0; l < QK8_0; ++l) {
1067
+ const float v = x[i*QK8_0 + l]*id;
1068
+ y[i].qs[l] = roundf(v);
1069
+ }
1070
+ }
1071
+ }
1072
+
1073
+ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int k) {
1074
+ assert(k % QK8_0 == 0);
1075
+ const int nb = k / QK8_0;
1076
+
1077
+ block_q8_0 * restrict y = vy;
1078
+
1079
+ #if defined(__ARM_NEON)
1080
+ for (int i = 0; i < nb; i++) {
1081
+ float32x4_t srcv [8];
1082
+ float32x4_t asrcv[8];
1083
+ float32x4_t amaxv[8];
1084
+
1085
+ for (int l = 0; l < 8; l++) srcv[l] = vld1q_f32(x + i*32 + 4*l);
1086
+ for (int l = 0; l < 8; l++) asrcv[l] = vabsq_f32(srcv[l]);
1087
+
1088
+ for (int l = 0; l < 4; l++) amaxv[2*l] = vmaxq_f32(asrcv[2*l], asrcv[2*l+1]);
1089
+ for (int l = 0; l < 2; l++) amaxv[4*l] = vmaxq_f32(amaxv[4*l], amaxv[4*l+2]);
1090
+ for (int l = 0; l < 1; l++) amaxv[8*l] = vmaxq_f32(amaxv[8*l], amaxv[8*l+4]);
1091
+
1092
+ const float amax = vmaxvq_f32(amaxv[0]);
1093
+
1094
+ const float d = amax / ((1 << 7) - 1);
1095
+ const float id = d ? 1.0f/d : 0.0f;
1096
+
1097
+ y[i].d = d;
1098
+
1099
+ for (int l = 0; l < 8; l++) {
1100
+ const float32x4_t v = vmulq_n_f32(srcv[l], id);
1101
+ const int32x4_t vi = vcvtnq_s32_f32(v);
1102
+
1103
+ y[i].qs[4*l + 0] = vgetq_lane_s32(vi, 0);
1104
+ y[i].qs[4*l + 1] = vgetq_lane_s32(vi, 1);
1105
+ y[i].qs[4*l + 2] = vgetq_lane_s32(vi, 2);
1106
+ y[i].qs[4*l + 3] = vgetq_lane_s32(vi, 3);
1107
+ }
1108
+ }
1109
+ #elif defined(__AVX2__) || defined(__AVX__)
1110
+ for (int i = 0; i < nb; i++) {
1111
+ // Load elements into 4 AVX vectors
1112
+ __m256 v0 = _mm256_loadu_ps( x );
1113
+ __m256 v1 = _mm256_loadu_ps( x + 8 );
1114
+ __m256 v2 = _mm256_loadu_ps( x + 16 );
1115
+ __m256 v3 = _mm256_loadu_ps( x + 24 );
1116
+ x += 32;
1117
+
1118
+ // Compute max(abs(e)) for the block
1119
+ const __m256 signBit = _mm256_set1_ps( -0.0f );
1120
+ __m256 maxAbs = _mm256_andnot_ps( signBit, v0 );
1121
+ maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) );
1122
+ maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );
1123
+ maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) );
1124
+
1125
+ __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
1126
+ max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
1127
+ max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
1128
+ const float maxScalar = _mm_cvtss_f32( max4 );
1129
+
1130
+ // Quantize these floats
1131
+ const float d = maxScalar / 127.f;
1132
+ y[i].d = d;
1133
+ const float id = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f;
1134
+ const __m256 mul = _mm256_set1_ps( id );
1135
+
1136
+ // Apply the multiplier
1137
+ v0 = _mm256_mul_ps( v0, mul );
1138
+ v1 = _mm256_mul_ps( v1, mul );
1139
+ v2 = _mm256_mul_ps( v2, mul );
1140
+ v3 = _mm256_mul_ps( v3, mul );
1141
+
1142
+ // Round to nearest integer
1143
+ v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST );
1144
+ v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST );
1145
+ v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST );
1146
+ v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST );
1147
+
1148
+ // Convert floats to integers
1149
+ __m256i i0 = _mm256_cvtps_epi32( v0 );
1150
+ __m256i i1 = _mm256_cvtps_epi32( v1 );
1151
+ __m256i i2 = _mm256_cvtps_epi32( v2 );
1152
+ __m256i i3 = _mm256_cvtps_epi32( v3 );
1153
+
1154
+ #if defined(__AVX2__)
1155
+ // Convert int32 to int16
1156
+ i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15
1157
+ i2 = _mm256_packs_epi32( i2, i3 ); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31
1158
+ // Convert int16 to int8
1159
+ i0 = _mm256_packs_epi16( i0, i2 ); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31
1160
+
1161
+ // We got our precious signed bytes, but the order is now wrong
1162
+ // These AVX2 pack instructions process 16-byte pieces independently
1163
+ // The following instruction is fixing the order
1164
+ const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
1165
+ i0 = _mm256_permutevar8x32_epi32( i0, perm );
1166
+
1167
+ _mm256_storeu_si256((__m256i *)y[i].qs, i0);
1168
+ #else
1169
+ // Since we don't have in AVX some necessary functions,
1170
+ // we split the registers in half and call AVX2 analogs from SSE
1171
+ __m128i ni0 = _mm256_castsi256_si128( i0 );
1172
+ __m128i ni1 = _mm256_extractf128_si256( i0, 1);
1173
+ __m128i ni2 = _mm256_castsi256_si128( i1 );
1174
+ __m128i ni3 = _mm256_extractf128_si256( i1, 1);
1175
+ __m128i ni4 = _mm256_castsi256_si128( i2 );
1176
+ __m128i ni5 = _mm256_extractf128_si256( i2, 1);
1177
+ __m128i ni6 = _mm256_castsi256_si128( i3 );
1178
+ __m128i ni7 = _mm256_extractf128_si256( i3, 1);
1179
+
1180
+ // Convert int32 to int16
1181
+ ni0 = _mm_packs_epi32( ni0, ni1 );
1182
+ ni2 = _mm_packs_epi32( ni2, ni3 );
1183
+ ni4 = _mm_packs_epi32( ni4, ni5 );
1184
+ ni6 = _mm_packs_epi32( ni6, ni7 );
1185
+ // Convert int16 to int8
1186
+ ni0 = _mm_packs_epi16( ni0, ni2 );
1187
+ ni4 = _mm_packs_epi16( ni4, ni6 );
1188
+
1189
+ _mm_storeu_si128((__m128i *)(y[i].qs + 0), ni0);
1190
+ _mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4);
1191
+ #endif
1192
+ }
1193
+ #else
1194
+ // scalar
1195
+ quantize_row_q8_0_reference(x, y, k);
1196
+ #endif
1197
+ }
1198
+
1036
1199
  static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, int k) {
1037
- assert(k % QK == 0);
1038
- const int nb = k / QK;
1200
+ assert(k % QK4_0 == 0);
1201
+ const int nb = k / QK4_0;
1039
1202
 
1040
1203
  const block_q4_0 * restrict x = vx;
1041
1204
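The Q8_0 format added in this hunk keeps the familiar one-float-scale-per-32-values layout but stores each quant as a full signed byte: d = amax/127 and q = round(v/d). In this version it appears to serve mainly as the activation-side operand of the new Q4_0 x Q8_0 dot product wired up in the quantize_fns table further down. A standalone scalar sketch mirroring quantize_row_q8_0_reference:

// sketch: scalar Q8_0 quantization of one 32-float block; not the ggml function
#include <math.h>
#include <stdint.h>
#include <stdio.h>

#define QK8_0 32

int main(void) {
    float  x[QK8_0];
    int8_t qs[QK8_0];
    for (int l = 0; l < QK8_0; l++) x[l] = sinf(0.3f * l);

    float amax = 0.0f; // absolute max of the block
    for (int l = 0; l < QK8_0; l++) amax = fmaxf(amax, fabsf(x[l]));

    const float d  = amax / 127.0f;       // amax / ((1 << 7) - 1)
    const float id = d ? 1.0f / d : 0.0f;

    for (int l = 0; l < QK8_0; l++) qs[l] = (int8_t) roundf(x[l] * id);

    printf("d = %f  x[5] = %f  q[5] = %d  back = %f\n", d, x[5], qs[5], qs[5] * d);
    return 0;
}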
 
@@ -1046,7 +1209,7 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in
1046
1209
 
1047
1210
  const uint8_t * restrict pp = x[i].qs;
1048
1211
 
1049
- for (int l = 0; l < QK; l += 32) {
1212
+ for (int l = 0; l < QK4_0; l += 32) {
1050
1213
  // Load 32x4-bit integers into 32x8-bit integers
1051
1214
  __m256i vx8 = bytesFromNibbles(pp+l/2);
1052
1215
 
@@ -1068,7 +1231,7 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in
1068
1231
  // Scale and store
1069
1232
  for (int j = 0; j < 4; j++) {
1070
1233
  const __m256 result = _mm256_mul_ps(vf[j], d_v);
1071
- _mm256_storeu_ps(y + i * QK + l + j*8, result);
1234
+ _mm256_storeu_ps(y + i * QK4_0 + l + j*8, result);
1072
1235
  }
1073
1236
  }
1074
1237
  }
@@ -1078,7 +1241,7 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in
1078
1241
 
1079
1242
  const uint8_t * restrict pp = x[i].qs;
1080
1243
 
1081
- for (int l = 0; l < QK; l += 16) {
1244
+ for (int l = 0; l < QK4_0; l += 16) {
1082
1245
  // Load 16x4-bit integers into 8x8-bit integers
1083
1246
  const uint8x8_t v8 = vld1_u8(pp + l/2);
1084
1247
 
@@ -1117,10 +1280,10 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in
1117
1280
  const float32x4_t r3 = vmulq_f32(vf_3, vd);
1118
1281
 
1119
1282
  // Store
1120
- vst1q_f32(y + i*QK + l + 0, r0);
1121
- vst1q_f32(y + i*QK + l + 4, r1);
1122
- vst1q_f32(y + i*QK + l + 8, r2);
1123
- vst1q_f32(y + i*QK + l + 12, r3);
1283
+ vst1q_f32(y + i*QK4_0 + l + 0, r0);
1284
+ vst1q_f32(y + i*QK4_0 + l + 4, r1);
1285
+ vst1q_f32(y + i*QK4_0 + l + 8, r2);
1286
+ vst1q_f32(y + i*QK4_0 + l + 12, r3);
1124
1287
  }
1125
1288
  }
1126
1289
  #else
@@ -1130,7 +1293,7 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in
1130
1293
 
1131
1294
  const uint8_t * restrict pp = x[i].qs;
1132
1295
 
1133
- for (int l = 0; l < QK; l += 2) {
1296
+ for (int l = 0; l < QK4_0; l += 2) {
1134
1297
  const uint8_t vi = pp[l/2];
1135
1298
 
1136
1299
  const int8_t vi0 = vi & 0xf;
@@ -1141,19 +1304,19 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in
1141
1304
 
1142
1305
  //printf("d = %f, vi = %d, vi0 = %d, vi1 = %d, v0 = %f, v1 = %f\n", d, vi, vi0, vi1, v0, v1);
1143
1306
 
1144
- y[i*QK + l + 0] = v0;
1145
- y[i*QK + l + 1] = v1;
1307
+ y[i*QK4_0 + l + 0] = v0;
1308
+ y[i*QK4_0 + l + 1] = v1;
1146
1309
 
1147
- assert(!isnan(y[i*QK + l + 0]));
1148
- assert(!isnan(y[i*QK + l + 1]));
1310
+ assert(!isnan(y[i*QK4_0 + l + 0]));
1311
+ assert(!isnan(y[i*QK4_0 + l + 1]));
1149
1312
  }
1150
1313
  }
1151
1314
  #endif
1152
1315
  }
1153
1316
 
1154
1317
  static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, int k) {
1155
- assert(k % QK == 0);
1156
- const int nb = k / QK;
1318
+ assert(k % QK4_1 == 0);
1319
+ const int nb = k / QK4_1;
1157
1320
 
1158
1321
  const block_q4_1 * restrict x = vx;
1159
1322
 
@@ -1164,7 +1327,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
1164
1327
 
1165
1328
  const uint8_t * restrict pp = x[i].qs;
1166
1329
 
1167
- for (int l = 0; l < QK; l += 32) {
1330
+ for (int l = 0; l < QK4_1; l += 32) {
1168
1331
  // Load 32x4-bit integers into 32x8-bit integers
1169
1332
  __m256i vx8 = bytesFromNibbles(pp+l/2);
1170
1333
 
@@ -1183,7 +1346,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
1183
1346
  // Scale, add m and store
1184
1347
  for (int j = 0; j < 4; j++) {
1185
1348
  const __m256 result = _mm256_add_ps(_mm256_mul_ps(vf[j], d_v), d_m);
1186
- _mm256_storeu_ps(y + i * QK + l + j*8, result);
1349
+ _mm256_storeu_ps(y + i * QK4_1 + l + j*8, result);
1187
1350
  }
1188
1351
  }
1189
1352
  }
@@ -1194,7 +1357,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
1194
1357
 
1195
1358
  const uint8_t * restrict pp = x[i].qs;
1196
1359
 
1197
- for (int l = 0; l < QK; l += 16) {
1360
+ for (int l = 0; l < QK4_1; l += 16) {
1198
1361
  // Load 16x4-bit integers into 8x8-bit integers
1199
1362
  const uint8x8_t v8 = vld1_u8(pp + l/2);
1200
1363
 
@@ -1225,10 +1388,10 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
1225
1388
  const float32x4_t r3 = vmlaq_f32(vm, vf_3, vd);
1226
1389
 
1227
1390
  // Store
1228
- vst1q_f32(y + i*QK + l + 0, r0);
1229
- vst1q_f32(y + i*QK + l + 4, r1);
1230
- vst1q_f32(y + i*QK + l + 8, r2);
1231
- vst1q_f32(y + i*QK + l + 12, r3);
1391
+ vst1q_f32(y + i*QK4_1 + l + 0, r0);
1392
+ vst1q_f32(y + i*QK4_1 + l + 4, r1);
1393
+ vst1q_f32(y + i*QK4_1 + l + 8, r2);
1394
+ vst1q_f32(y + i*QK4_1 + l + 12, r3);
1232
1395
  }
1233
1396
  }
1234
1397
  #else
@@ -1238,7 +1401,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
1238
1401
 
1239
1402
  const uint8_t * restrict pp = x[i].qs;
1240
1403
 
1241
- for (int l = 0; l < QK; l += 2) {
1404
+ for (int l = 0; l < QK4_1; l += 2) {
1242
1405
  const uint8_t vi = pp[l/2];
1243
1406
 
1244
1407
  const int8_t vi0 = vi & 0xf;
@@ -1247,16 +1410,44 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
1247
1410
  const float v0 = vi0*d + m;
1248
1411
  const float v1 = vi1*d + m;
1249
1412
 
1250
- y[i*QK + l + 0] = v0;
1251
- y[i*QK + l + 1] = v1;
1413
+ y[i*QK4_1 + l + 0] = v0;
1414
+ y[i*QK4_1 + l + 1] = v1;
1252
1415
 
1253
- assert(!isnan(y[i*QK + l + 0]));
1254
- assert(!isnan(y[i*QK + l + 1]));
1416
+ assert(!isnan(y[i*QK4_1 + l + 0]));
1417
+ assert(!isnan(y[i*QK4_1 + l + 1]));
1255
1418
  }
1256
1419
  }
1257
1420
  #endif
1258
1421
  }
1259
1422
 
1423
+ static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
1424
+ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
1425
+
1426
+ static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
1427
+ [GGML_TYPE_Q4_0] = {
1428
+ .dequantize_row_q = dequantize_row_q4_0,
1429
+ .quantize_row_q = quantize_row_q4_0,
1430
+ .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_0_reference,
1431
+ .quantize_row_q_dot = quantize_row_q8_0,
1432
+ .vec_dot_q = ggml_vec_dot_q4_0_q8_0,
1433
+ },
1434
+ [GGML_TYPE_Q4_1] = {
1435
+ .dequantize_row_q = dequantize_row_q4_1,
1436
+ .quantize_row_q = quantize_row_q4_1,
1437
+ .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_1_reference,
1438
+ .quantize_row_q_dot = quantize_row_q4_1,
1439
+ .vec_dot_q = ggml_vec_dot_q4_1,
1440
+ },
1441
+ // TODO: GGML_TYPE_Q8_0
1442
+ };
1443
+
1444
+ // For internal test use
1445
+ quantize_fns_t ggml_internal_get_quantize_fn(size_t i) {
1446
+ GGML_ASSERT(i < GGML_TYPE_COUNT);
1447
+ return quantize_fns[i];
1448
+ }
1449
+
1450
+
1260
1451
  //
1261
1452
  // simd mappings
1262
1453
  //
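The quantize_fns table and ggml_internal_get_quantize_fn above replace per-type switches with a lookup: each quantized type carries its row dequantizer, its row quantizer, the quantizer used for the dot-product operand (Q8_0 in the Q4_0 entry), and the matching vec_dot kernel. Pairing Q4_0 weights with an 8-bit quantized activation row keeps the inner loop in integer arithmetic while losing less precision than quantizing both sides to 4 bits. A standalone sketch of the same dispatch-table pattern with made-up types (own typedefs, not the ggml API):

// sketch: designated-initializer dispatch table, the pattern used above
#include <stdio.h>

typedef void (*quantize_row_fn)(const float * x, void * y, int k);

typedef struct {
    quantize_row_fn quantize_row_q;     // row quantizer for weights
    quantize_row_fn quantize_row_q_dot; // quantizer for the dot-product operand
} fns_t;

static void quantize_row_a(const float * x, void * y, int k) { (void) x; (void) y; printf("quantize A, k = %d\n", k); }
static void quantize_row_b(const float * x, void * y, int k) { (void) x; (void) y; printf("quantize B, k = %d\n", k); }

enum my_type { TYPE_A, TYPE_B, TYPE_COUNT };

static const fns_t table[TYPE_COUNT] = {
    [TYPE_A] = { .quantize_row_q = quantize_row_a, .quantize_row_q_dot = quantize_row_b },
    [TYPE_B] = { .quantize_row_q = quantize_row_b, .quantize_row_q_dot = quantize_row_b },
};

int main(void) {
    table[TYPE_A].quantize_row_q(NULL, NULL, 32);     // dispatch without a switch
    table[TYPE_A].quantize_row_q_dot(NULL, NULL, 32); // operand-side quantizer
    return 0;
}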
@@ -1813,34 +2004,188 @@ inline static void ggml_vec_dot_f32(const int n, float * restrict s, const float
1813
2004
  *s = sumf;
1814
2005
  }
1815
2006
 
1816
- #if __AVX512F__ && QK == 32
1817
- static inline __m512 dot_q4_0_oneblock_avx512(
2007
+ #if __AVX512F__ && QK4_0 == 32
2008
+ static inline __m512i bytes_from_q4_0_twoblocks_avx512( const __m512i blocks ) {
2009
+ // The 64 bytes of `blocks` contain two consecutive Q4_0 blocks loaded from memory:
2010
+ // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
2011
+ // |63 62 61 60 59 58 57 56 55 54 53 52 51 50 49 48 47 46 45 44 43 42 41 40 39 38 37 36 35 34 33 32|
2012
+ // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
2013
+ // | :. =_ () [] <> () Zz Yy|
2014
+ // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
2015
+ // |31 30 29 28 27 26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10 09 08 07 06 05 04 03 02 01 00|
2016
+ // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
2017
+ // |Xx Ww Vv Uu Tt Ss Rr Qq Pp Oo Nn Mm Ll Kk Jj Ii Hh Gg Ff Ee Dd Cc Bb Aa |
2018
+ // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
2019
+ //
2020
+ // Bytes 04..19 (block #0) and 24..39 (block #1) both contain 32 nibbles (4-bit unsigned integers).
2021
+ // We have exactly 64 nibbles, so we want to place each nibble into a separate byte.
2022
+ // Bytes 00..03 and 20..23 contain scales, which are irrelevant to this function.
2023
+ // Bytes 40..63 are masked when loading the data, so they are zeroed out.
2024
+ #ifdef __AVX512VBMI__
2025
+ const __m512i byte_perm = _mm512_set_epi8(
2026
+ 39, 38, 39, 38, 37, 36, 37, 36, 35, 34, 35, 34, 33, 32, 33, 32,
2027
+ 31, 30, 31, 30, 29, 28, 29, 28, 27, 26, 27, 26, 25, 24, 25, 24,
2028
+ 19, 18, 19, 18, 17, 16, 17, 16, 15, 14, 15, 14, 13, 12, 13, 12,
2029
+ 11, 10, 11, 10, 9, 8, 9, 8, 7, 6, 7, 6, 5, 4, 5, 4
2030
+ );
2031
+ const __m512i permuted = _mm512_permutexvar_epi8( byte_perm, blocks );
2032
+ // After applying VPERMB, `permuted` looks like this:
2033
+ // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
2034
+ // |63 62 61 60 59 58 57 56 55 54 53 52 51 50 49 48 47 46 45 44 43 42 41 40 39 38 37 36 35 34 33 32|
2035
+ // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
2036
+ // |:. =_ :. =_ () [] () [] <> () <> () Zz Yy Zz Yy Xx Ww Xx Ww Vv Uu Vv Uu Tt Ss Tt Ss Rr Qq Rr Qq|
2037
+ // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
2038
+ // |31 30 29 28 27 26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10 09 08 07 06 05 04 03 02 01 00|
2039
+ // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
2040
+ // |Pp Oo Pp Oo Nn Mm Nn Mm Ll Kk Ll Kk Jj Ii Jj Ii Hh Gg Hh Gg Ff Ee Ff Ee Dd Cc Dd Cc Bb Aa Bb Aa|
2041
+ // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
2042
+ #else
2043
+ const __m512i word_perm = _mm512_set_epi16(
2044
+ 19, 19, 18, 18, 17, 17, 16, 16, 15, 15, 14, 14, 13, 13, 12, 12,
2045
+ 9, 9, 8, 8, 7, 7, 6, 6, 5, 5, 4, 4, 3, 3, 2, 2
2046
+ );
2047
+ const __m512i permuted = _mm512_permutexvar_epi16( word_perm, blocks );
2048
+ // This is the fallback path for CPUs that don't support VPERMB. Since we permute 16-bit groups only,
2049
+ // VPERMB can be replaced with VPERMW. We could always use VPERMW, but at least on Tiger Lake and
2050
+ // Ice Lake VPERMW followed by a right shift is quite noticeably slower than VPERMB.
2051
+ #endif
2052
+
2053
+ // Shift every odd-numbered 16-bit group to the right by 4 bits.
2054
+ const __mmask32 shift_mask = 0xaaaaaaaa;
2055
+ const __m512i shifted = _mm512_mask_srai_epi16( permuted, shift_mask, permuted, 4 );
2056
+ // After applying VPSRAW, `shifted` looks like this (the "empty" nibbles are filled with zeroes):
2057
+ // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
2058
+ // |63 62 61 60 59 58 57 56 55 54 53 52 51 50 49 48 47 46 45 44 43 42 41 40 39 38 37 36 35 34 33 32
2059
+ // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
2060
+ // | : .= :. =_ ( )[ () [] < >( <> () Z zY Zz Yy X xW Xx Ww V vU Vv Uu T tS Tt Ss R rQ Rr Qq
2061
+ // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
2062
+ // |31 30 29 28 27 26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10 09 08 07 06 05 04 03 02 01 00|
2063
+ // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
2064
+ // | P pO Pp Oo N nM Nn Mm L lK Ll Kk J jI Jj Ii H hG Hh Gg F fE Ff Ee D dC Dd Cc B bA Bb Aa|
2065
+ // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
2066
+
2067
+ // Now we just need to zero out the higher nibble in each byte, and we're done.
2068
+ const __m512i low_nibble_mask = _mm512_set1_epi8( 0xf );
2069
+ return _mm512_and_si512( low_nibble_mask, shifted );
2070
+ // The final result looks like this:
2071
+ // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
2072
+ // |63 62 61 60 59 58 57 56 55 54 53 52 51 50 49 48 47 46 45 44 43 42 41 40 39 38 37 36 35 34 33 32|
2073
+ // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
2074
+ // | : = . _ ( [ ) ] < ( > ) Z Y z y X W x w V U v u T S t s R Q r q|
2075
+ // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
2076
+ // |31 30 29 28 27 26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10 09 08 07 06 05 04 03 02 01 00|
2077
+ // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
2078
+ // | P O p o N M n m L K l k J I j i H G h g F E f e D C d c B A b a|
2079
+ // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
2080
+ }
2081
+
2082
+ static inline __m512 dot_q4_0_twoblocks_avx512(
1818
2083
  __m512 acc,
1819
2084
  const block_q4_0 * restrict x,
1820
2085
  const block_q4_0 * restrict y,
1821
2086
  int i
1822
2087
  ) {
1823
- // Compute combined scale for the block
1824
- __m512 d = _mm512_set1_ps( x[i].d * y[i].d );
1825
-
1826
- __m256i bx = bytesFromNibbles( x[i].qs );
1827
- __m256i by = bytesFromNibbles( y[i].qs );
1828
-
1829
- // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
1830
- const __m256i off = _mm256_set1_epi8( 8 );
1831
- bx = _mm256_sub_epi8( bx, off );
1832
- by = _mm256_sub_epi8( by, off );
1833
-
1834
- // Sign-extend 16 signed bytes into int16_t
1835
- __m512i x32 = _mm512_cvtepi8_epi16( bx );
1836
- __m512i y32 = _mm512_cvtepi8_epi16( by );
1837
- // Compute products of int16_t integers, add pairwise
1838
- __m512i i64 = _mm512_madd_epi16( x32, y32 );
2088
+ // A pair of Q4_0 blocks spans 40 bytes, while an AVX-512 register has 64. The remaining 24 bytes
2089
+ // can potentially be unaddressable, so we make sure to mask them out before the load, even though
2090
+ // we don't use them at all. This might hurt the performance slightly, since the compiler is forced
2091
+ // to use e.g. `VMOVDQU64 REG, MASK, [ADDR] + VPERMB ..., REG` instead of just `VPERMB ..., [ADDR]`.
2092
+ const __mmask8 load_mask = 0x1f;
2093
+ const __m512i blocks_0 = _mm512_maskz_loadu_epi64( load_mask, &x[i] );
2094
+ const __m512i blocks_1 = _mm512_maskz_loadu_epi64( load_mask, &y[i] );
2095
+
2096
+ // We want to multiply the scales, so we interpret both registers as 16 32-bit floats:
2097
+ // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
2098
+ // | 15 | 14 | 13 | 12 | 11 | 10 | 09 | 08 | 07 | 06 | 05 | 04 | 03 | 02 | 01 | 00 |
2099
+ // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
2100
+ // blocks_0_float
2101
+ // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
2102
+ // | | | | | | | xx | xx | xx | xx | B | xx | xx | xx | xx | A |
2103
+ // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
2104
+ // blocks_1_float
2105
+ // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
2106
+ // | | | | | | | xx | xx | xx | xx | D | xx | xx | xx | xx | C |
2107
+ // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
2108
+ const __m512 blocks_0_float = _mm512_castsi512_ps( blocks_0 );
2109
+ const __m512 blocks_1_float = _mm512_castsi512_ps( blocks_1 );
2110
+ // We absolutely shouldn't touch the floats marked with `xx`: they contain some
2111
+ // random data, which might very well underflow. At least on Intel, this leads
2112
+ // to a huge penalty that can't be ignored (easily 100x or more) unless you
2113
+ // compile your code with something like `-ffast-math` to enable FTZ/DAZ flags.
2114
+ // (and ggml can't assume that you do)...
2115
+ const __mmask16 scale_mul_mask = 0x21;
2116
+ #ifdef __clang__
2117
+ // ...however, clang decides to optimize the multiplication mask away:
2118
+ // https://godbolt.org/z/P8PqdsfvW
2119
+ // gcc and MSVC do the sane thing. This horrible workaround forces clang to emit the mask.
2120
+ __m512i scales;
2121
+ __asm__(
2122
+ "vmulps %1, %2, %0%{%3%}"
2123
+ : "=v" ( scales )
2124
+ : "vm" ( blocks_0_float ), "v" ( blocks_1_float ), "Yk" ( scale_mul_mask )
2125
+ );
2126
+ #else
2127
+ const __m512 scales = _mm512_maskz_mul_ps( scale_mul_mask, blocks_0_float, blocks_1_float );
2128
+ #endif
2129
+ const __m512i scale_perm = _mm512_set_epi32(
2130
+ 5, 5, 5, 5, 5, 5, 5, 5,
2131
+ 0, 0, 0, 0, 0, 0, 0, 0
2132
+ );
2133
+ const __m512 permuted_scales = _mm512_permutexvar_ps( scale_perm, scales );
2134
+ // After VMULPS and VPERMPS, `permuted_scales` looks like this:
2135
+ // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
2136
+ // | 15 | 14 | 13 | 12 | 11 | 10 | 09 | 08 | 07 | 06 | 05 | 04 | 03 | 02 | 01 | 00 |
2137
+ // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
2138
+ // | B*D| B*D| B*D| B*D| B*D| B*D| B*D| B*D| A*C| A*C| A*C| A*C| A*C| A*C| A*C| A*C|
2139
+ // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
2140
+
2141
+ const __m512i bytes_0 = bytes_from_q4_0_twoblocks_avx512( blocks_0 );
2142
+ const __m512i bytes_1 = bytes_from_q4_0_twoblocks_avx512( blocks_1 );
2143
+
2144
+ // Now we want to compute dot products of 4-element byte vectors and store them in
2145
+ // 32-bit integers. That is (only one 4-element vector is shown for clarity):
2146
+ // +----+----+----+----+
2147
+ // ... | 03 | 02 | 01 | 00 |
2148
+ // +----+----+----+----+
2149
+ // bytes_0
2150
+ // +----+----+----+----+
2151
+ // ... | D | C | B | A |
2152
+ // +----+----+----+----+
2153
+ // bytes_1
2154
+ // +----+----+----+----+
2155
+ // ... | H | G | F | E |
2156
+ // +----+----+----+----+
2157
+ // final_res_int
2158
+ // +----+----+----+----+
2159
+ // ... | A*E+B*F+C*G+D*H |
2160
+ // +----+----+----+----+
2161
+ const __m512i plus_8 = _mm512_set1_epi8( 8 );
2162
+ const __m512i bytes_1_minus_8 = _mm512_sub_epi8( bytes_1, plus_8 );
2163
+
2164
+ #ifdef __AVX512VNNI__
2165
+ // We have VPDPBUSDS in AVX512-VNNI, which does exactly what we want, but with a catch:
2166
+ // the *left* operand is supposed to be unsigned, while Q4_0 quantization subtracts 8
2167
+ // from each nibble, so they can be negative. So, instead of `(bytes_0 - 8) * (bytes_1 - 8)`,
2168
+ // we compute `bytes_0 * (bytes_1 - 8) + bytes_1 * (-8) + 64`. VPDPBUSDS uses an accumulator,
2169
+ // which means we only need 2 instructions.
2170
+ const __m512i dot_init = _mm512_set1_epi32( 4 * 64 );
2171
+ const __m512i minus_8 = _mm512_set1_epi8( -8 );
2172
+ const __m512i prod_0 = _mm512_dpbusds_epi32( dot_init, bytes_1, minus_8 );
2173
+ const __m512i final_res_int = _mm512_dpbusds_epi32( prod_0, bytes_0, bytes_1_minus_8 );
2174
+ #else
2175
+ // As a fallback, we have VPMADDUBSW in AVX512-BW, which uses 16-bit products instead of 32-bit ones.
2176
+ // It has the same catch as VPDPBUSDS: the left operand should be unsigned.
2177
+ // This is essentially the AVX-512 version of the AVX-2 trick used by GH user Const-me
2178
+ // ref: https://gist.github.com/Const-me/4d30e1fc767ab314596e16e90f53b6f4#file-matmultest-cpp-L119
2179
+ const __m512i one = _mm512_set1_epi16( 1 );
2180
+ const __m512i prod_0 = _mm512_maddubs_epi16( bytes_0, bytes_1_minus_8 );
2181
+ const __m512i prod_1 = _mm512_maddubs_epi16( plus_8, bytes_1_minus_8 );
2182
+ const __m512i diff = _mm512_sub_epi16( prod_0, prod_1 );
2183
+ const __m512i final_res_int = _mm512_madd_epi16( diff, one );
2184
+ #endif
1839
2185
 
1840
- // Convert int32_t to float
1841
- __m512 p = _mm512_cvtepi32_ps( i64 );
1842
- // Apply the scale, and accumulate
1843
- return _mm512_fmadd_ps( d, p, acc );
2186
+ // Finally, we multiply the permuted scales and the 32-bit dot products, then accumulate.
2187
+ const __m512 final_res_float = _mm512_cvtepi32_ps( final_res_int );
2188
+ return _mm512_fmadd_ps( permuted_scales, final_res_float, acc );
1844
2189
  }
1845
2190
  #endif
1846
2191
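The VNNI path above relies on an algebraic rewrite because VPDPBUSDS treats its first operand as unsigned: for nibbles a, b in [0, 15], (a - 8)(b - 8) = a(b - 8) - 8b + 64, so the accumulator is seeded with 4*64 (four byte products per 32-bit lane) and the two VPDPBUSDS calls add the remaining terms. A throwaway check of the identity over all nibble pairs:

// sketch: exhaustive check of the identity behind the VPDPBUSDS trick above
#include <stdio.h>

int main(void) {
    for (int a = 0; a < 16; a++) {
        for (int b = 0; b < 16; b++) {
            const int direct  = (a - 8) * (b - 8);           // what the dot product needs
            const int rewrite = a * (b - 8) + b * (-8) + 64; // unsigned-left form used above
            if (direct != rewrite) { printf("mismatch at a=%d b=%d\n", a, b); return 1; }
        }
    }
    printf("identity holds for all 4-bit pairs\n");
    return 0;
}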
 
@@ -1881,9 +2226,9 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t
1881
2226
  }
1882
2227
 
1883
2228
  static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
1884
- const int nb = n / QK;
2229
+ const int nb = n / QK4_0;
1885
2230
 
1886
- assert(n % QK == 0);
2231
+ assert(n % QK4_0 == 0);
1887
2232
  assert(nb % 2 == 0);
1888
2233
 
1889
2234
  const block_q4_0 * restrict x = vx;
@@ -1972,25 +2317,26 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
1972
2317
  __m512 acc0 = _mm512_setzero_ps();
1973
2318
  __m512 acc1 = _mm512_setzero_ps();
1974
2319
 
1975
- const int superblock_size = 8;
2320
+ const int superblock_size = 16;
2321
+
1976
2322
  const int superblock_count = nb / superblock_size;
1977
2323
 
1978
2324
  for (int superblock_ix = 0; superblock_ix < superblock_count; superblock_ix += 1) {
1979
2325
  int i = superblock_ix * superblock_size;
1980
2326
 
1981
- acc0 = dot_q4_0_oneblock_avx512( acc0, x, y, i+0 );
1982
- acc1 = dot_q4_0_oneblock_avx512( acc1, x, y, i+1 );
1983
- acc0 = dot_q4_0_oneblock_avx512( acc0, x, y, i+2 );
1984
- acc1 = dot_q4_0_oneblock_avx512( acc1, x, y, i+3 );
1985
- acc0 = dot_q4_0_oneblock_avx512( acc0, x, y, i+4 );
1986
- acc1 = dot_q4_0_oneblock_avx512( acc1, x, y, i+5 );
1987
- acc0 = dot_q4_0_oneblock_avx512( acc0, x, y, i+6 );
1988
- acc1 = dot_q4_0_oneblock_avx512( acc1, x, y, i+7 );
2327
+ acc0 = dot_q4_0_twoblocks_avx512( acc0, x, y, i+0 );
2328
+ acc1 = dot_q4_0_twoblocks_avx512( acc1, x, y, i+2 );
2329
+ acc0 = dot_q4_0_twoblocks_avx512( acc0, x, y, i+4 );
2330
+ acc1 = dot_q4_0_twoblocks_avx512( acc1, x, y, i+6 );
2331
+ acc0 = dot_q4_0_twoblocks_avx512( acc0, x, y, i+8 );
2332
+ acc1 = dot_q4_0_twoblocks_avx512( acc1, x, y, i+10 );
2333
+ acc0 = dot_q4_0_twoblocks_avx512( acc0, x, y, i+12 );
2334
+ acc1 = dot_q4_0_twoblocks_avx512( acc1, x, y, i+14 );
1989
2335
  }
1990
2336
 
1991
2337
  // Remainders
1992
- for (int i = superblock_count * superblock_size; i < nb; ++i) {
1993
- acc0 = dot_q4_0_oneblock_avx512( acc0, x, y, i );
2338
+ for (int i = superblock_count * superblock_size; i < nb; i += 2) {
2339
+ acc0 = dot_q4_0_twoblocks_avx512( acc0, x, y, i );
1994
2340
  }
1995
2341
 
1996
2342
  // Horizontal sum of all lanes of the accumulator
@@ -2206,15 +2552,15 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
2206
2552
  const uint8_t * restrict p1 = y[i].qs;
2207
2553
 
2208
2554
  int sumi = 0;
2209
- for (int j = 0; j < QK/2; j++) {
2555
+ for (int j = 0; j < QK4_0/2; j++) {
2210
2556
  const uint8_t v0 = p0[j];
2211
2557
  const uint8_t v1 = p1[j];
2212
2558
 
2213
- const int8_t i0 = (int8_t) (v0 & 0xf) - 8;
2214
- const int8_t i1 = (int8_t) (v0 >> 4) - 8;
2559
+ const int i0 = (v0 & 0xf) - 8;
2560
+ const int i1 = (v0 >> 4) - 8;
2215
2561
 
2216
- const int8_t i2 = (int8_t) (v1 & 0xf) - 8;
2217
- const int8_t i3 = (int8_t) (v1 >> 4) - 8;
2562
+ const int i2 = (v1 & 0xf) - 8;
2563
+ const int i3 = (v1 >> 4) - 8;
2218
2564
 
2219
2565
  sumi += i0*i2 + i1*i3;
2220
2566
  }
@@ -2226,7 +2572,7 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
2226
2572
  }
2227
2573
 
2228
2574
  static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
2229
- const int nb = n / QK;
2575
+ const int nb = n / QK4_1;
2230
2576
 
2231
2577
  const block_q4_1 * restrict x = vx;
2232
2578
  const block_q4_1 * restrict y = vy;
@@ -2303,7 +2649,7 @@ static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * rest
2303
2649
  res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
2304
2650
  res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
2305
2651
 
2306
- sumf = _mm_cvtss_f32( res ) + acc_offset * QK;
2652
+ sumf = _mm_cvtss_f32( res ) + acc_offset * QK4_1;
2307
2653
  #elif defined(__ARM_NEON)
2308
2654
  float sum00 = 0.0f;
2309
2655
  float sum01 = 0.0f;
@@ -2335,12 +2681,12 @@ static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * rest
2335
2681
  const uint8x16_t v1_1h = vshrq_n_u8(v1_1, 4);
2336
2682
 
2337
2683
  sum00 += x0->m*y0->m;
2338
- sum01 += y0->m*x0->d*(vaddvq_u8(v0_0l) + vaddvq_u8(v0_0h));
2339
- sum10 += x0->m*y0->d*(vaddvq_u8(v1_0l) + vaddvq_u8(v1_0h));
2684
+ sum01 += y0->m*x0->d*((uint16_t)vaddvq_u8(v0_0l) + (uint16_t)vaddvq_u8(v0_0h));
2685
+ sum10 += x0->m*y0->d*((uint16_t)vaddvq_u8(v1_0l) + (uint16_t)vaddvq_u8(v1_0h));
2340
2686
 
2341
2687
  sum00 += x1->m*y1->m;
2342
- sum01 += y1->m*x1->d*(vaddvq_u8(v0_1l) + vaddvq_u8(v0_1h));
2343
- sum10 += x1->m*y1->d*(vaddvq_u8(v1_1l) + vaddvq_u8(v1_1h));
2688
+ sum01 += y1->m*x1->d*((uint16_t)vaddvq_u8(v0_1l) + (uint16_t)vaddvq_u8(v0_1h));
2689
+ sum10 += x1->m*y1->d*((uint16_t)vaddvq_u8(v1_1l) + (uint16_t)vaddvq_u8(v1_1h));
2344
2690
 
2345
2691
  #if defined(__ARM_FEATURE_DOTPROD)
2346
2692
  // dot product into int32x4_t
@@ -2377,7 +2723,7 @@ static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * rest
2377
2723
  #endif
2378
2724
  }
2379
2725
 
2380
- sumf = QK*sum00 + sum01 + sum10 + sum11;
2726
+ sumf = QK4_1*sum00 + sum01 + sum10 + sum11;
2381
2727
  #else
2382
2728
  // scalar
2383
2729
  for (int i = 0; i < nb; i++) {
@@ -2390,7 +2736,7 @@ static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * rest
2390
2736
  const uint8_t * restrict p0 = x[i].qs;
2391
2737
  const uint8_t * restrict p1 = y[i].qs;
2392
2738
 
2393
- for (int j = 0; j < QK/2; j++) {
2739
+ for (int j = 0; j < QK4_1/2; j++) {
2394
2740
  const uint8_t v0 = p0[j];
2395
2741
  const uint8_t v1 = p1[j];
2396
2742
 
@@ -2408,78 +2754,281 @@ static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * rest
2408
2754
  *s = sumf;
2409
2755
  }
2410
2756
 
2411
- // compute GGML_VEC_DOT_UNROLL dot products at once
2412
- // xs - x row stride in bytes
2413
- inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * restrict s, void * restrict xv, ggml_fp16_t * restrict y) {
2414
- ggml_float sumf[GGML_VEC_DOT_UNROLL] = { 0.0 };
2415
-
2416
- ggml_fp16_t * restrict x[GGML_VEC_DOT_UNROLL];
2757
+ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
2758
+ const int nb = n / QK8_0;
2417
2759
 
2418
- for (int i = 0; i < GGML_VEC_DOT_UNROLL; ++i) {
2419
- x[i] = (ggml_fp16_t *) ((char *) xv + i*xs);
2420
- }
2760
+ assert(n % QK8_0 == 0);
2761
+ assert(nb % 2 == 0);
2421
2762
 
2422
- #if defined(GGML_SIMD)
2423
- const int np = (n & ~(GGML_F16_STEP - 1));
2763
+ const block_q4_0 * restrict x = vx;
2764
+ const block_q8_0 * restrict y = vy;
2424
2765
 
2425
- GGML_F16_VEC sum[GGML_VEC_DOT_UNROLL][GGML_F16_ARR] = { { GGML_F16_VEC_ZERO } };
2766
+ float sumf = 0.0;
2426
2767
 
2427
- GGML_F16_VEC ax[GGML_F16_ARR];
2428
- GGML_F16_VEC ay[GGML_F16_ARR];
2768
+ #if defined(__ARM_NEON)
2769
+ float sum0 = 0.0f;
2770
+ float sum1 = 0.0f;
2429
2771
 
2430
- for (int i = 0; i < np; i += GGML_F16_STEP) {
2431
- for (int j = 0; j < GGML_F16_ARR; j++) {
2432
- ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
2772
+ for (int i = 0; i < nb; i += 2) {
2773
+ const block_q4_0 * restrict x0 = &x[i + 0];
2774
+ const block_q4_0 * restrict x1 = &x[i + 1];
2775
+ const block_q8_0 * restrict y0 = &y[i + 0];
2776
+ const block_q8_0 * restrict y1 = &y[i + 1];
2433
2777
 
2434
- for (int k = 0; k < GGML_VEC_DOT_UNROLL; ++k) {
2435
- ax[j] = GGML_F16_VEC_LOAD(x[k] + i + j*GGML_F16_EPR, j);
2778
+ const uint8x16_t m4b = vdupq_n_u8(0xf);
2779
+ const int8x16_t s8b = vdupq_n_s8(0x8);
2436
2780
 
2437
- sum[k][j] = GGML_F16_VEC_FMA(sum[k][j], ax[j], ay[j]);
2438
- }
2439
- }
2440
- }
2781
+ const uint8x16_t v0_0 = vld1q_u8(x0->qs);
2782
+ const uint8x16_t v0_1 = vld1q_u8(x1->qs);
2441
2783
 
2442
- // reduce sum0..sum3 to sum0
2443
- for (int k = 0; k < GGML_VEC_DOT_UNROLL; ++k) {
2444
- GGML_F16_VEC_REDUCE(sumf[k], sum[k]);
2445
- }
2784
+ // 4-bit -> 8-bit
2785
+ const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
2786
+ const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
2787
+ const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
2788
+ const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
2446
2789
 
2447
- // leftovers
2448
- for (int i = np; i < n; ++i) {
2449
- for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) {
2450
- sumf[j] += (ggml_float)(GGML_FP16_TO_FP32(x[j][i])*GGML_FP16_TO_FP32(y[i]));
2451
- }
2452
- }
2453
- #else
2454
- for (int i = 0; i < n; ++i) {
2455
- for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) {
2456
- sumf[j] += (ggml_float)(GGML_FP16_TO_FP32(x[j][i])*GGML_FP16_TO_FP32(y[i]));
2457
- }
2458
- }
2459
- #endif
2790
+ // sub 8
2791
+ const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
2792
+ const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
2793
+ const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
2794
+ const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
2460
2795
 
2461
- for (int i = 0; i < GGML_VEC_DOT_UNROLL; ++i) {
2462
- s[i] = sumf[i];
2463
- }
2464
- }
2796
+ // load y
2797
+ const int8x16_t v1_0l = vld1q_s8(y0->qs);
2798
+ const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
2799
+ const int8x16_t v1_1l = vld1q_s8(y1->qs);
2800
+ const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
2465
2801
 
2466
- inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float * restrict x, const float v) {
2467
- #if defined(GGML_SIMD)
2468
- const int np = (n & ~(GGML_F32_STEP - 1));
2802
+ // interleave
2803
+ const int8x16_t v1_0ls = vuzp1q_s8(v1_0l, v1_0h);
2804
+ const int8x16_t v1_0hs = vuzp2q_s8(v1_0l, v1_0h);
2805
+ const int8x16_t v1_1ls = vuzp1q_s8(v1_1l, v1_1h);
2806
+ const int8x16_t v1_1hs = vuzp2q_s8(v1_1l, v1_1h);
2469
2807
 
2470
- GGML_F32_VEC vx = GGML_F32_VEC_SET1(v);
2808
+ #if defined(__ARM_FEATURE_DOTPROD)
2809
+ // dot product into int32x4_t
2810
+ int32x4_t p_0 = vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls);
2811
+ int32x4_t p_1 = vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls);
2471
2812
 
2472
- GGML_F32_VEC ax[GGML_F32_ARR];
2473
- GGML_F32_VEC ay[GGML_F32_ARR];
2813
+ p_0 = vdotq_s32(p_0, v0_0hs, v1_0hs);
2814
+ p_1 = vdotq_s32(p_1, v0_1hs, v1_1hs);
2474
2815
 
2475
- for (int i = 0; i < np; i += GGML_F32_STEP) {
2476
- for (int j = 0; j < GGML_F32_ARR; j++) {
2477
- ax[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR);
2478
- ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR);
2479
- ay[j] = GGML_F32_VEC_FMA(ay[j], ax[j], vx);
2816
+ sum0 += x0->d*y0->d*vaddvq_s32(p_0);
2817
+ sum1 += x1->d*y1->d*vaddvq_s32(p_1);
2818
+ #else
2819
+ const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls));
2820
+ const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls));
2821
+ const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0hs));
2822
+ const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0hs));
2480
2823
 
2481
- GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]);
2482
- }
2824
+ const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1ls));
2825
+ const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1ls));
2826
+ const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1hs));
2827
+ const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1hs));
2828
+
2829
+ const int16x8_t pl_0 = vaddq_s16(pl0l, pl0h);
2830
+ const int16x8_t ph_0 = vaddq_s16(ph0l, ph0h);
2831
+
2832
+ const int16x8_t pl_1 = vaddq_s16(pl1l, pl1h);
2833
+ const int16x8_t ph_1 = vaddq_s16(ph1l, ph1h);
2834
+
2835
+ const int16x8_t p_0 = vaddq_s16(pl_0, ph_0);
2836
+ const int16x8_t p_1 = vaddq_s16(pl_1, ph_1);
2837
+
2838
+ sum0 += x0->d*y0->d*vaddvq_s16(p_0);
2839
+ sum1 += x1->d*y1->d*vaddvq_s16(p_1);
2840
+ #endif
2841
+ }
2842
+
2843
+ sumf = sum0 + sum1;
2844
+ #elif defined(__AVX2__)
2845
+ // Initialize accumulator with zeros
2846
+ __m256 acc = _mm256_setzero_ps();
2847
+
2848
+ // Main loop
2849
+ for (int i = 0; i < nb; ++i) {
2850
+ /* Compute combined scale for the block */
2851
+ const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );
2852
+
2853
+ __m256i bx = bytesFromNibbles(x[i].qs);
2854
+
2855
+ // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
2856
+ const __m256i off = _mm256_set1_epi8( 8 );
2857
+ bx = _mm256_sub_epi8( bx, off );
2858
+
2859
+ __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
2860
+
2861
+ // Get absolute values of x vectors
2862
+ const __m256i ax = _mm256_sign_epi8(bx, bx);
2863
+
2864
+ // Sign the values of the y vectors
2865
+ const __m256i sy = _mm256_sign_epi8(by, bx);
2866
+
2867
+ // Perform multiplication and create 16-bit values
2868
+ const __m256i dot = _mm256_maddubs_epi16(ax, sy);
2869
+
2870
+ const __m256i ones = _mm256_set1_epi16(1);
2871
+ __m256i xy_q = _mm256_madd_epi16(ones, dot);
2872
+
2873
+ /* Convert vector of 8 int32_t to 8 floats */
2874
+ __m256 q = _mm256_cvtepi32_ps( xy_q );
2875
+
2876
+ /* Multiply q with scale and accumulate */
2877
+ acc = _mm256_fmadd_ps( d, q, acc );
2878
+ }
2879
+
2880
+ // Return horizontal sum of the acc vector
2881
+ __m128 res = _mm256_extractf128_ps( acc, 1 );
2882
+ res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
2883
+ res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
2884
+ res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
2885
+
2886
+ sumf = _mm_cvtss_f32( res );
2887
+ #elif defined(__AVX__)
2888
+ // Initialize accumulator with zeros
2889
+ __m256 acc = _mm256_setzero_ps();
2890
+
2891
+ // Main loop
2892
+ for (int i = 0; i < nb; ++i) {
2893
+ // Compute combined scale for the block
2894
+ const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );
2895
+
2896
+ __m128i i32[2];
2897
+ for (int j = 0; j < 2; ++j) {
2898
+ // Load 8 bytes, and unpack 4 bit fields into bytes, making 16 bytes
2899
+ __m128i bx = bytesFromNibbles( x[i].qs + 8*j );
2900
+ __m128i by = _mm_loadu_si128((const __m128i *)(y[i].qs + 16*j));
2901
+
2902
+ // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
2903
+ const __m128i off = _mm_set1_epi8( 8 );
2904
+ bx = _mm_sub_epi8( bx, off );
2905
+
2906
+ // Get absolute values of x vectors
2907
+ const __m128i ax = _mm_sign_epi8(bx, bx);
2908
+
2909
+ // Sign the values of the y vectors
2910
+ const __m128i sy = _mm_sign_epi8(by, bx);
2911
+
2912
+ // Perform multiplication and create 16-bit values
2913
+ const __m128i dot = _mm_maddubs_epi16(ax, sy);
2914
+
2915
+ const __m128i ones = _mm_set1_epi16(1);
2916
+ i32[j] = _mm_madd_epi16(ones, dot);
2917
+ }
2918
+
2919
+ // Convert int32_t to float
2920
+ __m256 p = _mm256_cvtepi32_ps( _mm256_set_m128i( i32[0], i32[1] ));
2921
+ // Apply the scale, and accumulate
2922
+ acc = _mm256_add_ps(_mm256_mul_ps( d, p ), acc);
2923
+ }
2924
+
2925
+ // Return horizontal sum of the acc vector
2926
+ __m128 res = _mm256_extractf128_ps( acc, 1 );
2927
+ res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
2928
+ res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
2929
+ res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
2930
+
2931
+ sumf = _mm_cvtss_f32( res );
2932
+ #else
2933
+ // scalar
2934
+ for (int i = 0; i < nb; i++) {
2935
+ const float d0 = x[i].d;
2936
+ const float d1 = y[i].d;
2937
+
2938
+ const uint8_t * restrict p0 = x[i].qs;
2939
+ const int8_t * restrict p1 = y[i].qs;
2940
+
2941
+ int sumi = 0;
2942
+ for (int j = 0; j < QK8_0/2; j++) {
2943
+ const uint8_t v0 = p0[j];
2944
+
2945
+ const int i0 = (int8_t) (v0 & 0xf) - 8;
2946
+ const int i1 = (int8_t) (v0 >> 4) - 8;
2947
+
2948
+ const int i2 = p1[2*j + 0];
2949
+ const int i3 = p1[2*j + 1];
2950
+
2951
+ sumi += i0*i2 + i1*i3;
2952
+ }
2953
+ sumf += d0*d1*sumi;
2954
+ }
2955
+ #endif
2956
+
2957
+ *s = sumf;
2958
+ }
2959
+
2960
+ // compute GGML_VEC_DOT_UNROLL dot products at once
2961
+ // xs - x row stride in bytes
2962
+ inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * restrict s, void * restrict xv, ggml_fp16_t * restrict y) {
2963
+ ggml_float sumf[GGML_VEC_DOT_UNROLL] = { 0.0 };
2964
+
2965
+ ggml_fp16_t * restrict x[GGML_VEC_DOT_UNROLL];
2966
+
2967
+ for (int i = 0; i < GGML_VEC_DOT_UNROLL; ++i) {
2968
+ x[i] = (ggml_fp16_t *) ((char *) xv + i*xs);
2969
+ }
2970
+
2971
+ #if defined(GGML_SIMD)
2972
+ const int np = (n & ~(GGML_F16_STEP - 1));
2973
+
2974
+ GGML_F16_VEC sum[GGML_VEC_DOT_UNROLL][GGML_F16_ARR] = { { GGML_F16_VEC_ZERO } };
2975
+
2976
+ GGML_F16_VEC ax[GGML_F16_ARR];
2977
+ GGML_F16_VEC ay[GGML_F16_ARR];
2978
+
2979
+ for (int i = 0; i < np; i += GGML_F16_STEP) {
2980
+ for (int j = 0; j < GGML_F16_ARR; j++) {
2981
+ ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
2982
+
2983
+ for (int k = 0; k < GGML_VEC_DOT_UNROLL; ++k) {
2984
+ ax[j] = GGML_F16_VEC_LOAD(x[k] + i + j*GGML_F16_EPR, j);
2985
+
2986
+ sum[k][j] = GGML_F16_VEC_FMA(sum[k][j], ax[j], ay[j]);
2987
+ }
2988
+ }
2989
+ }
2990
+
2991
+ // reduce sum0..sum3 to sum0
2992
+ for (int k = 0; k < GGML_VEC_DOT_UNROLL; ++k) {
2993
+ GGML_F16_VEC_REDUCE(sumf[k], sum[k]);
2994
+ }
2995
+
2996
+ // leftovers
2997
+ for (int i = np; i < n; ++i) {
2998
+ for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) {
2999
+ sumf[j] += (ggml_float)(GGML_FP16_TO_FP32(x[j][i])*GGML_FP16_TO_FP32(y[i]));
3000
+ }
3001
+ }
3002
+ #else
3003
+ for (int i = 0; i < n; ++i) {
3004
+ for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) {
3005
+ sumf[j] += (ggml_float)(GGML_FP16_TO_FP32(x[j][i])*GGML_FP16_TO_FP32(y[i]));
3006
+ }
3007
+ }
3008
+ #endif
3009
+
3010
+ for (int i = 0; i < GGML_VEC_DOT_UNROLL; ++i) {
3011
+ s[i] = sumf[i];
3012
+ }
3013
+ }
3014
+
3015
+ inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float * restrict x, const float v) {
3016
+ #if defined(GGML_SIMD)
3017
+ const int np = (n & ~(GGML_F32_STEP - 1));
3018
+
3019
+ GGML_F32_VEC vx = GGML_F32_VEC_SET1(v);
3020
+
3021
+ GGML_F32_VEC ax[GGML_F32_ARR];
3022
+ GGML_F32_VEC ay[GGML_F32_ARR];
3023
+
3024
+ for (int i = 0; i < np; i += GGML_F32_STEP) {
3025
+ for (int j = 0; j < GGML_F32_ARR; j++) {
3026
+ ax[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR);
3027
+ ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR);
3028
+ ay[j] = GGML_F32_VEC_FMA(ay[j], ax[j], vx);
3029
+
3030
+ GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]);
3031
+ }
2483
3032
  }
2484
3033
 
2485
3034
  // leftovers
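Stripped of the SIMD paths, the new ggml_vec_dot_q4_0_q8_0 computes, per pair of blocks, d_x * d_y * sum over l of (x_q[l] - 8) * y_q[l], which is why it needs only one integer accumulation and one float multiply per block. A standalone sketch comparing that against the exact float dot product on a single block (the quantizers are re-derived inline; the +8 nibble offset and the packing are omitted since they cancel out here):

// sketch: what the Q4_0 x Q8_0 mixed dot product computes, vs. the float dot
#include <math.h>
#include <stdio.h>

#define QK 32  // QK4_0 == QK8_0 == 32 in this version

int main(void) {
    float x[QK], y[QK];
    for (int l = 0; l < QK; l++) { x[l] = sinf(0.37f * l + 0.1f); y[l] = cosf(0.21f * l); }

    // weights as Q4_0: d_x = amax/7, signed 4-bit quants
    float ax = 0.0f; for (int l = 0; l < QK; l++) ax = fmaxf(ax, fabsf(x[l]));
    const float dx = ax / 7.0f;
    int qx[QK]; for (int l = 0; l < QK; l++) qx[l] = (int) roundf(x[l] / dx);

    // activations as Q8_0: d_y = amax/127, signed 8-bit quants
    float ay = 0.0f; for (int l = 0; l < QK; l++) ay = fmaxf(ay, fabsf(y[l]));
    const float dy = ay / 127.0f;
    int qy[QK]; for (int l = 0; l < QK; l++) qy[l] = (int) roundf(y[l] / dy);

    // one integer accumulation per block, scaled once at the end
    int sumi = 0; for (int l = 0; l < QK; l++) sumi += qx[l] * qy[l];
    const float approx = dx * dy * sumi;

    float exact = 0.0f; for (int l = 0; l < QK; l++) exact += x[l] * y[l];
    printf("float dot = %f   q4_0 x q8_0 dot = %f\n", exact, approx);
    return 0;
}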
@@ -2652,24 +3201,26 @@ inline static void ggml_vec_norm_inv_f32(const int n, float * s, const float * x
2652
3201
  static const int GGML_BLCK_SIZE[GGML_TYPE_COUNT] = {
2653
3202
  [GGML_TYPE_F32] = 1,
2654
3203
  [GGML_TYPE_F16] = 1,
2655
- [GGML_TYPE_Q4_0] = QK,
2656
- [GGML_TYPE_Q4_1] = QK,
3204
+ [GGML_TYPE_Q4_0] = QK4_0,
3205
+ [GGML_TYPE_Q4_1] = QK4_1,
3206
+ [GGML_TYPE_Q8_0] = QK8_0,
2657
3207
  [GGML_TYPE_I8] = 1,
2658
3208
  [GGML_TYPE_I16] = 1,
2659
3209
  [GGML_TYPE_I32] = 1,
2660
3210
  };
2661
- static_assert(GGML_TYPE_COUNT == 7, "GGML_BLCK_SIZE is outdated");
3211
+ static_assert(GGML_TYPE_COUNT == 8, "GGML_BLCK_SIZE is outdated");
2662
3212
 
2663
3213
  static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = {
2664
3214
  [GGML_TYPE_F32] = sizeof(float),
2665
3215
  [GGML_TYPE_F16] = sizeof(ggml_fp16_t),
2666
3216
  [GGML_TYPE_Q4_0] = sizeof(block_q4_0),
2667
3217
  [GGML_TYPE_Q4_1] = sizeof(block_q4_1),
3218
+ [GGML_TYPE_Q8_0] = sizeof(block_q8_0),
2668
3219
  [GGML_TYPE_I8] = sizeof(int8_t),
2669
3220
  [GGML_TYPE_I16] = sizeof(int16_t),
2670
3221
  [GGML_TYPE_I32] = sizeof(int32_t),
2671
3222
  };
2672
- static_assert(GGML_TYPE_COUNT == 7, "GGML_TYPE_SIZE is outdated");
3223
+ static_assert(GGML_TYPE_COUNT == 8, "GGML_TYPE_SIZE is outdated");
2673
3224
 
2674
3225
 
2675
3226
  static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = {
@@ -2677,11 +3228,12 @@ static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = {
2677
3228
  [GGML_TYPE_F16] = "f16",
2678
3229
  [GGML_TYPE_Q4_0] = "q4_0",
2679
3230
  [GGML_TYPE_Q4_1] = "q4_1",
3231
+ [GGML_TYPE_Q8_0] = "q8_0",
2680
3232
  [GGML_TYPE_I8] = "i8",
2681
3233
  [GGML_TYPE_I16] = "i16",
2682
3234
  [GGML_TYPE_I32] = "i32",
2683
3235
  };
2684
- static_assert(GGML_TYPE_COUNT == 7, "GGML_TYPE_NAME is outdated");
3236
+ static_assert(GGML_TYPE_COUNT == 8, "GGML_TYPE_NAME is outdated");
2685
3237
 
2686
3238
  static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
2687
3239
  "NONE",
@@ -3354,14 +3906,6 @@ struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value) {
3354
3906
  char * const data = tensor->data;
3355
3907
 
3356
3908
  switch (tensor->type) {
3357
- case GGML_TYPE_Q4_0:
3358
- {
3359
- GGML_ASSERT(false);
3360
- } break;
3361
- case GGML_TYPE_Q4_1:
3362
- {
3363
- GGML_ASSERT(false);
3364
- } break;
3365
3909
  case GGML_TYPE_I8:
3366
3910
  {
3367
3911
  assert(tensor->nb[0] == sizeof(int8_t));
@@ -3397,7 +3941,7 @@ struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value) {
3397
3941
  ggml_vec_set_f32(nc, (float *)(data + i*n1), value);
3398
3942
  }
3399
3943
  } break;
3400
- case GGML_TYPE_COUNT:
3944
+ default:
3401
3945
  {
3402
3946
  GGML_ASSERT(false);
3403
3947
  } break;
@@ -3414,14 +3958,6 @@ struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value) {
3414
3958
  char * const data = tensor->data;
3415
3959
 
3416
3960
  switch (tensor->type) {
3417
- case GGML_TYPE_Q4_0:
3418
- {
3419
- GGML_ASSERT(false);
3420
- } break;
3421
- case GGML_TYPE_Q4_1:
3422
- {
3423
- GGML_ASSERT(false);
3424
- } break;
3425
3961
  case GGML_TYPE_I8:
3426
3962
  {
3427
3963
  assert(tensor->nb[0] == sizeof(int8_t));
@@ -3457,7 +3993,7 @@ struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value) {
3457
3993
  ggml_vec_set_f32(nc, (float *)(data + i*n1), value);
3458
3994
  }
3459
3995
  } break;
3460
- case GGML_TYPE_COUNT:
3996
+ default:
3461
3997
  {
3462
3998
  GGML_ASSERT(false);
3463
3999
  } break;
@@ -3468,14 +4004,6 @@ struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value) {
3468
4004
 
3469
4005
  int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i) {
3470
4006
  switch (tensor->type) {
3471
- case GGML_TYPE_Q4_0:
3472
- {
3473
- GGML_ASSERT(false);
3474
- } break;
3475
- case GGML_TYPE_Q4_1:
3476
- {
3477
- GGML_ASSERT(false);
3478
- } break;
3479
4007
  case GGML_TYPE_I8:
3480
4008
  {
3481
4009
  GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
@@ -3501,7 +4029,7 @@ int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i) {
3501
4029
  GGML_ASSERT(tensor->nb[0] == sizeof(float));
3502
4030
  return ((float *)(tensor->data))[i];
3503
4031
  } break;
3504
- case GGML_TYPE_COUNT:
4032
+ default:
3505
4033
  {
3506
4034
  GGML_ASSERT(false);
3507
4035
  } break;
@@ -3512,14 +4040,6 @@ int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i) {
3512
4040
 
3513
4041
  void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value) {
3514
4042
  switch (tensor->type) {
3515
- case GGML_TYPE_Q4_0:
3516
- {
3517
- GGML_ASSERT(false);
3518
- } break;
3519
- case GGML_TYPE_Q4_1:
3520
- {
3521
- GGML_ASSERT(false);
3522
- } break;
3523
4043
  case GGML_TYPE_I8:
3524
4044
  {
3525
4045
  GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
@@ -3545,7 +4065,7 @@ void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value) {
3545
4065
  GGML_ASSERT(tensor->nb[0] == sizeof(float));
3546
4066
  ((float *)(tensor->data))[i] = value;
3547
4067
  } break;
3548
- case GGML_TYPE_COUNT:
4068
+ default:
3549
4069
  {
3550
4070
  GGML_ASSERT(false);
3551
4071
  } break;
@@ -3554,14 +4074,6 @@ void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value) {
3554
4074
 
3555
4075
  float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i) {
3556
4076
  switch (tensor->type) {
3557
- case GGML_TYPE_Q4_0:
3558
- {
3559
- GGML_ASSERT(false);
3560
- } break;
3561
- case GGML_TYPE_Q4_1:
3562
- {
3563
- GGML_ASSERT(false);
3564
- } break;
3565
4077
  case GGML_TYPE_I8:
3566
4078
  {
3567
4079
  GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
@@ -3587,7 +4099,7 @@ float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i) {
3587
4099
  GGML_ASSERT(tensor->nb[0] == sizeof(float));
3588
4100
  return ((float *)(tensor->data))[i];
3589
4101
  } break;
3590
- case GGML_TYPE_COUNT:
4102
+ default:
3591
4103
  {
3592
4104
  GGML_ASSERT(false);
3593
4105
  } break;
@@ -3598,14 +4110,6 @@ float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i) {
3598
4110
 
3599
4111
  void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value) {
3600
4112
  switch (tensor->type) {
3601
- case GGML_TYPE_Q4_0:
3602
- {
3603
- GGML_ASSERT(false);
3604
- } break;
3605
- case GGML_TYPE_Q4_1:
3606
- {
3607
- GGML_ASSERT(false);
3608
- } break;
3609
4113
  case GGML_TYPE_I8:
3610
4114
  {
3611
4115
  GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
@@ -3631,7 +4135,7 @@ void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value) {
3631
4135
  GGML_ASSERT(tensor->nb[0] == sizeof(float));
3632
4136
  ((float *)(tensor->data))[i] = value;
3633
4137
  } break;
3634
- case GGML_TYPE_COUNT:
4138
+ default:
3635
4139
  {
3636
4140
  GGML_ASSERT(false);
3637
4141
  } break;
@@ -5112,6 +5616,26 @@ static void ggml_compute_forward_dup_f16(
5112
5616
  }
5113
5617
  }
5114
5618
  }
5619
+ } else if (dst->type == GGML_TYPE_Q4_0 || dst->type == GGML_TYPE_Q4_1) {
5620
+ quantize_row_q_t const quantize_row_q = quantize_fns[dst->type].quantize_row_q;
5621
+ size_t id = 0;
5622
+ uint8_t * dst_ptr = (uint8_t *) dst->data;
5623
+ size_t dst_row_size = nb0 * (ne00 / GGML_BLCK_SIZE[dst->type]);
5624
+ float * src0_f32 = (float *) params->wdata;
5625
+
5626
+ for (int i03 = 0; i03 < ne03; i03++) {
5627
+ for (int i02 = 0; i02 < ne02; i02++) {
5628
+ for (int i01 = 0; i01 < ne01; i01++) {
5629
+ const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
5630
+ // convert to f32 and quantize
5631
+ for (int i00 = 0; i00 < ne00; i00++) {
5632
+ src0_f32[i00] = GGML_FP16_TO_FP32(src0_ptr[i00]);
5633
+ }
5634
+ quantize_row_q(src0_f32, dst_ptr + id, ne00);
5635
+ id += dst_row_size;
5636
+ }
5637
+ }
5638
+ }
5115
5639
  } else {
5116
5640
  GGML_ASSERT(false); // TODO: implement
5117
5641
  }
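Because quantize_row_q consumes f32 input, the new f16 -> quantized dup path above stages each source row through params->wdata: the f16 values are widened element by element, then the whole row is quantized into dst. The scratch buffer is sized later in this diff (see the GGML_OP_DUP/GGML_OP_CPY case in ggml_graph_compute) as exactly one f32 row:

    // one f32 row of staging per dup/cpy task (ne00 = 4096 is illustrative)
    size_t staging_bytes = sizeof(float) * (size_t) ne00;   // 4 * 4096 = 16 KiB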
@@ -5304,6 +5828,21 @@ static void ggml_compute_forward_dup_f32(
5304
5828
  }
5305
5829
  }
5306
5830
  }
5831
+ } else if (dst->type == GGML_TYPE_Q4_0 || dst->type == GGML_TYPE_Q4_1) {
5832
+ quantize_row_q_t const quantize_row_q = quantize_fns[dst->type].quantize_row_q;
5833
+ size_t id = 0;
5834
+ uint8_t * dst_ptr = (uint8_t *) dst->data;
5835
+ size_t dst_row_size = nb0 * (ne00 / GGML_BLCK_SIZE[dst->type]);
5836
+
5837
+ for (int i03 = 0; i03 < ne03; i03++) {
5838
+ for (int i02 = 0; i02 < ne02; i02++) {
5839
+ for (int i01 = 0; i01 < ne01; i01++) {
5840
+ const float * src0_ptr = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
5841
+ quantize_row_q(src0_ptr, dst_ptr + id, ne00);
5842
+ id += dst_row_size;
5843
+ }
5844
+ }
5845
+ }
5307
5846
  } else {
5308
5847
  GGML_ASSERT(false); // TODO: implement
5309
5848
  }
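The f32 -> quantized dup path above skips the staging step and quantizes straight from the source row. The dst_row_size arithmetic is worth spelling out: for quantized tensors nb0 is the per-block stride, so the row size is blocks-per-row times block size. With illustrative numbers for a Q4_0 destination:

    // dst_row_size = nb0 * (ne00 / GGML_BLCK_SIZE[dst->type])
    const size_t example_row_size = sizeof(block_q4_0) * (4096 / QK4_0);   // 20 * 128 = 2560 bytes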
@@ -5426,12 +5965,7 @@ static void ggml_compute_forward_dup(
5426
5965
  {
5427
5966
  ggml_compute_forward_dup_f32(params, src0, dst);
5428
5967
  } break;
5429
- case GGML_TYPE_Q4_0:
5430
- case GGML_TYPE_Q4_1:
5431
- case GGML_TYPE_I8:
5432
- case GGML_TYPE_I16:
5433
- case GGML_TYPE_I32:
5434
- case GGML_TYPE_COUNT:
5968
+ default:
5435
5969
  {
5436
5970
  GGML_ASSERT(false);
5437
5971
  } break;
@@ -5463,37 +5997,243 @@ static void ggml_compute_forward_add_f32(
5463
5997
  const size_t nb10 = src1->nb[0];
5464
5998
  const size_t nb11 = src1->nb[1];
5465
5999
 
5466
- const size_t nb0 = dst->nb[0];
5467
- const size_t nb1 = dst->nb[1];
6000
+ const size_t nb0 = dst->nb[0];
6001
+ const size_t nb1 = dst->nb[1];
6002
+
6003
+ GGML_ASSERT( nb0 == sizeof(float));
6004
+ GGML_ASSERT(nb00 == sizeof(float));
6005
+
6006
+ if (nb10 == sizeof(float)) {
6007
+ for (int j = ith; j < n; j += nth) {
6008
+ #ifdef GGML_USE_ACCELERATE
6009
+ vDSP_vadd(
6010
+ (float *) ((char *) src0->data + j*nb01), 1,
6011
+ (float *) ((char *) src1->data + j*nb11), 1,
6012
+ (float *) ((char *) dst->data + j*nb1), 1, nc);
6013
+ #else
6014
+ ggml_vec_add_f32(nc,
6015
+ (float *) ((char *) dst->data + j*nb1),
6016
+ (float *) ((char *) src0->data + j*nb01),
6017
+ (float *) ((char *) src1->data + j*nb11));
6018
+ #endif
6019
+ }
6020
+ } else {
6021
+ // src1 is not contiguous
6022
+ for (int j = ith; j < n; j += nth) {
6023
+ float * dst_ptr = (float *) ((char *) dst->data + j*nb1);
6024
+ float * src0_ptr = (float *) ((char *) src0->data + j*nb01);
6025
+ for (int i = 0; i < nc; i++) {
6026
+ float * src1_ptr = (float *) ((char *) src1->data + j*nb11 + i*nb10);
6027
+
6028
+ dst_ptr[i] = src0_ptr[i] + *src1_ptr;
6029
+ }
6030
+ }
6031
+ }
6032
+ }
6033
+
6034
+ static void ggml_compute_forward_add_f16_f32(
6035
+ const struct ggml_compute_params * params,
6036
+ const struct ggml_tensor * src0,
6037
+ const struct ggml_tensor * src1,
6038
+ struct ggml_tensor * dst) {
6039
+ GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
6040
+
6041
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
6042
+ return;
6043
+ }
6044
+
6045
+ const int ith = params->ith;
6046
+ const int nth = params->nth;
6047
+
6048
+ const int n = ggml_nrows(src0);
6049
+ const int nc = src0->ne[0];
6050
+
6051
+ const size_t nb00 = src0->nb[0];
6052
+ const size_t nb01 = src0->nb[1];
6053
+
6054
+ const size_t nb10 = src1->nb[0];
6055
+ const size_t nb11 = src1->nb[1];
6056
+
6057
+ const size_t nb0 = dst->nb[0];
6058
+ const size_t nb1 = dst->nb[1];
6059
+
6060
+ GGML_ASSERT(src0->type == GGML_TYPE_F16);
6061
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
6062
+ GGML_ASSERT(dst->type == GGML_TYPE_F16);
6063
+
6064
+ GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
6065
+ GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
6066
+
6067
+ if (nb10 == sizeof(float)) {
6068
+ for (int j = ith; j < n; j += nth) {
6069
+ ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + j*nb1);
6070
+ ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + j*nb01);
6071
+ for (int i = 0; i < nc; i++) {
6072
+ float * src1_ptr = (float *) ((char *) src1->data + j*nb11 + i*nb10);
6073
+ dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + *src1_ptr);
6074
+ }
6075
+ }
6076
+ }
6077
+ else {
6078
+ // src1 is not contiguous
6079
+ GGML_ASSERT(false);
6080
+ }
6081
+ }
6082
+
6083
+ static void ggml_compute_forward_add_f16_f16(
6084
+ const struct ggml_compute_params * params,
6085
+ const struct ggml_tensor * src0,
6086
+ const struct ggml_tensor * src1,
6087
+ struct ggml_tensor * dst) {
6088
+ GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
6089
+
6090
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
6091
+ return;
6092
+ }
6093
+
6094
+ const int ith = params->ith;
6095
+ const int nth = params->nth;
6096
+
6097
+ const int n = ggml_nrows(src0);
6098
+ const int nc = src0->ne[0];
6099
+
6100
+ const size_t nb00 = src0->nb[0];
6101
+ const size_t nb01 = src0->nb[1];
6102
+
6103
+ const size_t nb10 = src1->nb[0];
6104
+ const size_t nb11 = src1->nb[1];
6105
+
6106
+ const size_t nb0 = dst->nb[0];
6107
+ const size_t nb1 = dst->nb[1];
6108
+
6109
+ GGML_ASSERT(src0->type == GGML_TYPE_F16);
6110
+ GGML_ASSERT(src1->type == GGML_TYPE_F16);
6111
+ GGML_ASSERT(dst->type == GGML_TYPE_F16);
6112
+
6113
+ GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
6114
+ GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
6115
+
6116
+ if (nb10 == sizeof(ggml_fp16_t)) {
6117
+ for (int j = ith; j < n; j += nth) {
6118
+ ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + j*nb1);
6119
+ ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + j*nb01);
6120
+ for (int i = 0; i < nc; i++) {
6121
+ ggml_fp16_t * src1_ptr = (ggml_fp16_t *) ((char *) src1->data + j*nb11 + i*nb10);
6122
+ dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + GGML_FP16_TO_FP32(*src1_ptr));
6123
+ }
6124
+ }
6125
+ }
6126
+ else {
6127
+ // src1 is not contiguous
6128
+ GGML_ASSERT(false);
6129
+ }
6130
+ }
6131
+
6132
+ static void ggml_compute_forward_add_q_f32(
6133
+ const struct ggml_compute_params * params,
6134
+ const struct ggml_tensor * src0,
6135
+ const struct ggml_tensor * src1,
6136
+ struct ggml_tensor * dst) {
6137
+ GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
6138
+
6139
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
6140
+ return;
6141
+ }
6142
+
6143
+ const int64_t ne00 = src0->ne[0];
6144
+ const int64_t ne01 = src0->ne[1];
6145
+ const int64_t ne02 = src0->ne[2];
6146
+ const int64_t ne03 = src0->ne[3];
6147
+
6148
+ //const int64_t ne10 = src1->ne[0];
6149
+ //const int64_t ne11 = src1->ne[1];
6150
+ const int64_t ne12 = src1->ne[2];
6151
+ const int64_t ne13 = src1->ne[3];
6152
+
6153
+ //const int64_t ne0 = dst->ne[0];
6154
+ //const int64_t ne1 = dst->ne[1];
6155
+ const int64_t ne2 = dst->ne[2];
6156
+ const int64_t ne3 = dst->ne[3];
6157
+
6158
+ const int nb00 = src0->nb[0];
6159
+ const int nb01 = src0->nb[1];
6160
+ const int nb02 = src0->nb[2];
6161
+ const int nb03 = src0->nb[3];
6162
+
6163
+ const int nb10 = src1->nb[0];
6164
+ const int nb11 = src1->nb[1];
6165
+ const int nb12 = src1->nb[2];
6166
+ const int nb13 = src1->nb[3];
6167
+
6168
+ const int nb0 = dst->nb[0];
6169
+ const int nb1 = dst->nb[1];
6170
+ const int nb2 = dst->nb[2];
6171
+ const int nb3 = dst->nb[3];
6172
+
6173
+ const int ith = params->ith;
6174
+ const int nth = params->nth;
6175
+
6176
+ GGML_ASSERT(ne02 == ne12);
6177
+ GGML_ASSERT(ne03 == ne13);
6178
+ GGML_ASSERT(ne2 == ne12);
6179
+ GGML_ASSERT(ne3 == ne13);
6180
+
6181
+ const enum ggml_type type = src0->type;
6182
+ dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q;
6183
+ quantize_row_q_t const quantize_row_q = quantize_fns[type].quantize_row_q;
6184
+
6185
+ // we don't support permuted src0 or src1
6186
+ GGML_ASSERT(nb00 == (int) GGML_TYPE_SIZE[type]);
6187
+ GGML_ASSERT(nb10 == sizeof(float));
6188
+
6189
+ // dst cannot be transposed or permuted
6190
+ GGML_ASSERT(nb0 <= nb1);
6191
+ GGML_ASSERT(nb1 <= nb2);
6192
+ GGML_ASSERT(nb2 <= nb3);
6193
+
6194
+ GGML_ASSERT(src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1);
6195
+ GGML_ASSERT(dst->type == src0->type);
6196
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
6197
+
6198
+ // total rows in src0
6199
+ const int nr = ne01*ne02*ne03;
6200
+
6201
+ // rows per thread
6202
+ const int dr = (nr + nth - 1)/nth;
6203
+
6204
+ // row range for this thread
6205
+ const int ir0 = dr*ith;
6206
+ const int ir1 = MIN(ir0 + dr, nr);
6207
+
6208
+ float * wdata = (float*) params->wdata + ne00 * ith;
6209
+
6210
+ for (int ir = ir0; ir < ir1; ++ir) {
6211
+ // src0 indices
6212
+ const int i03 = ir/(ne02*ne01);
6213
+ const int i02 = (ir - i03*ne02*ne01)/ne01;
6214
+ const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
6215
+
6216
+ // src1 and dst are same shape as src0 => same indices
6217
+ const int i13 = i03;
6218
+ const int i12 = i02;
6219
+ const int i11 = i01;
5468
6220
 
5469
- GGML_ASSERT( nb0 == sizeof(float));
5470
- GGML_ASSERT(nb00 == sizeof(float));
6221
+ const int i3 = i03;
6222
+ const int i2 = i02;
6223
+ const int i1 = i01;
5471
6224
 
5472
- if (nb10 == sizeof(float)) {
5473
- for (int j = ith; j < n; j += nth) {
5474
- #ifdef GGML_USE_ACCELERATE
5475
- vDSP_vadd(
5476
- (float *) ((char *) src0->data + j*nb01), 1,
5477
- (float *) ((char *) src1->data + j*nb11), 1,
5478
- (float *) ((char *) dst->data + j*nb1), 1, nc);
5479
- #else
5480
- ggml_vec_add_f32(nc,
5481
- (float *) ((char *) dst->data + j*nb1),
5482
- (float *) ((char *) src0->data + j*nb01),
5483
- (float *) ((char *) src1->data + j*nb11));
5484
- #endif
5485
- }
5486
- } else {
5487
- // src1 is not contiguous
5488
- for (int j = ith; j < n; j += nth) {
5489
- float * dst_ptr = (float *) ((char *) dst->data + j*nb1);
5490
- float * src0_ptr = (float *) ((char *) src0->data + j*nb01);
5491
- for (int i = 0; i < nc; i++) {
5492
- float * src1_ptr = (float *) ((char *) src1->data + j*nb11 + i*nb10);
6225
+ void * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
6226
+ float * src1_row = (float *)((char *) src1->data + (i11*nb11 + i12*nb12 + i13*nb13));
6227
+ void * dst_row = (void *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3)); // offset along dim 3 uses the nb3 stride
5493
6228
 
5494
- dst_ptr[i] = src0_ptr[i] + *src1_ptr;
5495
- }
5496
- }
6229
+ assert(ne00 % 32 == 0);
6230
+
6231
+ // unquantize row from src0 to temp buffer
6232
+ dequantize_row_q(src0_row, wdata, ne00);
6233
+ // add src1
6234
+ ggml_vec_acc_f32(ne00, wdata, src1_row);
6235
+ // quantize row to dst
6236
+ quantize_row_q(wdata, dst_row, ne00);
5497
6237
  }
5498
6238
  }
5499
6239
 
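ggml_compute_forward_add_q_f32 above works row by row through a per-thread f32 scratch buffer at params->wdata + ne00*ith: the quantized src0 row is expanded, the f32 src1 row is accumulated, and the sum is re-quantized over the destination row. Note that dst keeps src0's quantized type, so each add is a lossy round trip through 4-bit precision. Condensed, the per-row pipeline is:

    dequantize_row_q(src0_row, wdata, ne00);    // Q4_x -> f32 into this thread's scratch row
    ggml_vec_acc_f32(ne00, wdata, src1_row);    // wdata[i] += src1_row[i]
    quantize_row_q(wdata, dst_row, ne00);       // f32 -> Q4_x (re-quantized)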
@@ -5507,13 +6247,24 @@ static void ggml_compute_forward_add(
5507
6247
  {
5508
6248
  ggml_compute_forward_add_f32(params, src0, src1, dst);
5509
6249
  } break;
6250
+ case GGML_TYPE_F16:
6251
+ {
6252
+ if (src1->type == GGML_TYPE_F16) {
6253
+ ggml_compute_forward_add_f16_f16(params, src0, src1, dst);
6254
+ }
6255
+ else if (src1->type == GGML_TYPE_F32) {
6256
+ ggml_compute_forward_add_f16_f32(params, src0, src1, dst);
6257
+ }
6258
+ else {
6259
+ GGML_ASSERT(false);
6260
+ }
6261
+ } break;
5510
6262
  case GGML_TYPE_Q4_0:
5511
6263
  case GGML_TYPE_Q4_1:
5512
- case GGML_TYPE_I8:
5513
- case GGML_TYPE_I16:
5514
- case GGML_TYPE_I32:
5515
- case GGML_TYPE_F16:
5516
- case GGML_TYPE_COUNT:
6264
+ {
6265
+ ggml_compute_forward_add_q_f32(params, src0, src1, dst);
6266
+ } break;
6267
+ default:
5517
6268
  {
5518
6269
  GGML_ASSERT(false);
5519
6270
  } break;
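With the dispatch above, ggml_add now covers f16 += f32, f16 += f16 and quantized += f32 operands in addition to the plain f32 case. A hedged usage sketch (not part of this diff; sizes and names are illustrative) of adding an f32 delta into a Q4_0 tensor:

    struct ggml_init_params ip = { .mem_size = 64*1024*1024, .mem_buffer = NULL };
    struct ggml_context * ctx = ggml_init(ip);
    struct ggml_tensor * w = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, 4096, 32);
    struct ggml_tensor * d = ggml_new_tensor_2d(ctx, GGML_TYPE_F32,  4096, 32);
    struct ggml_tensor * s = ggml_add(ctx, w, d);   // dst inherits src0's type, so s is Q4_0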
@@ -5559,13 +6310,7 @@ static void ggml_compute_forward_sub(
5559
6310
  {
5560
6311
  ggml_compute_forward_sub_f32(params, src0, src1, dst);
5561
6312
  } break;
5562
- case GGML_TYPE_Q4_0:
5563
- case GGML_TYPE_Q4_1:
5564
- case GGML_TYPE_I8:
5565
- case GGML_TYPE_I16:
5566
- case GGML_TYPE_I32:
5567
- case GGML_TYPE_F16:
5568
- case GGML_TYPE_COUNT:
6313
+ default:
5569
6314
  {
5570
6315
  GGML_ASSERT(false);
5571
6316
  } break;
@@ -5611,13 +6356,7 @@ static void ggml_compute_forward_mul(
5611
6356
  {
5612
6357
  ggml_compute_forward_mul_f32(params, src0, src1, dst);
5613
6358
  } break;
5614
- case GGML_TYPE_Q4_0:
5615
- case GGML_TYPE_Q4_1:
5616
- case GGML_TYPE_I8:
5617
- case GGML_TYPE_I16:
5618
- case GGML_TYPE_I32:
5619
- case GGML_TYPE_F16:
5620
- case GGML_TYPE_COUNT:
6359
+ default:
5621
6360
  {
5622
6361
  GGML_ASSERT(false);
5623
6362
  } break;
@@ -5663,13 +6402,7 @@ static void ggml_compute_forward_div(
5663
6402
  {
5664
6403
  ggml_compute_forward_div_f32(params, src0, src1, dst);
5665
6404
  } break;
5666
- case GGML_TYPE_Q4_0:
5667
- case GGML_TYPE_Q4_1:
5668
- case GGML_TYPE_I8:
5669
- case GGML_TYPE_I16:
5670
- case GGML_TYPE_I32:
5671
- case GGML_TYPE_F16:
5672
- case GGML_TYPE_COUNT:
6405
+ default:
5673
6406
  {
5674
6407
  GGML_ASSERT(false);
5675
6408
  } break;
@@ -5711,13 +6444,7 @@ static void ggml_compute_forward_sqr(
5711
6444
  {
5712
6445
  ggml_compute_forward_sqr_f32(params, src0, dst);
5713
6446
  } break;
5714
- case GGML_TYPE_Q4_0:
5715
- case GGML_TYPE_Q4_1:
5716
- case GGML_TYPE_I8:
5717
- case GGML_TYPE_I16:
5718
- case GGML_TYPE_I32:
5719
- case GGML_TYPE_F16:
5720
- case GGML_TYPE_COUNT:
6447
+ default:
5721
6448
  {
5722
6449
  GGML_ASSERT(false);
5723
6450
  } break;
@@ -5759,13 +6486,7 @@ static void ggml_compute_forward_sqrt(
5759
6486
  {
5760
6487
  ggml_compute_forward_sqrt_f32(params, src0, dst);
5761
6488
  } break;
5762
- case GGML_TYPE_Q4_0:
5763
- case GGML_TYPE_Q4_1:
5764
- case GGML_TYPE_I8:
5765
- case GGML_TYPE_I16:
5766
- case GGML_TYPE_I32:
5767
- case GGML_TYPE_F16:
5768
- case GGML_TYPE_COUNT:
6489
+ default:
5769
6490
  {
5770
6491
  GGML_ASSERT(false);
5771
6492
  } break;
@@ -5817,13 +6538,7 @@ static void ggml_compute_forward_sum(
5817
6538
  {
5818
6539
  ggml_compute_forward_sum_f32(params, src0, dst);
5819
6540
  } break;
5820
- case GGML_TYPE_Q4_0:
5821
- case GGML_TYPE_Q4_1:
5822
- case GGML_TYPE_I8:
5823
- case GGML_TYPE_I16:
5824
- case GGML_TYPE_I32:
5825
- case GGML_TYPE_F16:
5826
- case GGML_TYPE_COUNT:
6541
+ default:
5827
6542
  {
5828
6543
  GGML_ASSERT(false);
5829
6544
  } break;
@@ -5894,13 +6609,7 @@ static void ggml_compute_forward_mean(
5894
6609
  {
5895
6610
  ggml_compute_forward_mean_f32(params, src0, dst);
5896
6611
  } break;
5897
- case GGML_TYPE_Q4_0:
5898
- case GGML_TYPE_Q4_1:
5899
- case GGML_TYPE_I8:
5900
- case GGML_TYPE_I16:
5901
- case GGML_TYPE_I32:
5902
- case GGML_TYPE_F16:
5903
- case GGML_TYPE_COUNT:
6612
+ default:
5904
6613
  {
5905
6614
  GGML_ASSERT(false);
5906
6615
  } break;
@@ -5958,13 +6667,7 @@ static void ggml_compute_forward_repeat(
5958
6667
  {
5959
6668
  ggml_compute_forward_repeat_f32(params, src0, dst);
5960
6669
  } break;
5961
- case GGML_TYPE_Q4_0:
5962
- case GGML_TYPE_Q4_1:
5963
- case GGML_TYPE_I8:
5964
- case GGML_TYPE_I16:
5965
- case GGML_TYPE_I32:
5966
- case GGML_TYPE_F16:
5967
- case GGML_TYPE_COUNT:
6670
+ default:
5968
6671
  {
5969
6672
  GGML_ASSERT(false);
5970
6673
  } break;
@@ -6006,13 +6709,7 @@ static void ggml_compute_forward_abs(
6006
6709
  {
6007
6710
  ggml_compute_forward_abs_f32(params, src0, dst);
6008
6711
  } break;
6009
- case GGML_TYPE_Q4_0:
6010
- case GGML_TYPE_Q4_1:
6011
- case GGML_TYPE_I8:
6012
- case GGML_TYPE_I16:
6013
- case GGML_TYPE_I32:
6014
- case GGML_TYPE_F16:
6015
- case GGML_TYPE_COUNT:
6712
+ default:
6016
6713
  {
6017
6714
  GGML_ASSERT(false);
6018
6715
  } break;
@@ -6054,13 +6751,7 @@ static void ggml_compute_forward_sgn(
6054
6751
  {
6055
6752
  ggml_compute_forward_sgn_f32(params, src0, dst);
6056
6753
  } break;
6057
- case GGML_TYPE_Q4_0:
6058
- case GGML_TYPE_Q4_1:
6059
- case GGML_TYPE_I8:
6060
- case GGML_TYPE_I16:
6061
- case GGML_TYPE_I32:
6062
- case GGML_TYPE_F16:
6063
- case GGML_TYPE_COUNT:
6754
+ default:
6064
6755
  {
6065
6756
  GGML_ASSERT(false);
6066
6757
  } break;
@@ -6102,13 +6793,7 @@ static void ggml_compute_forward_neg(
6102
6793
  {
6103
6794
  ggml_compute_forward_neg_f32(params, src0, dst);
6104
6795
  } break;
6105
- case GGML_TYPE_Q4_0:
6106
- case GGML_TYPE_Q4_1:
6107
- case GGML_TYPE_I8:
6108
- case GGML_TYPE_I16:
6109
- case GGML_TYPE_I32:
6110
- case GGML_TYPE_F16:
6111
- case GGML_TYPE_COUNT:
6796
+ default:
6112
6797
  {
6113
6798
  GGML_ASSERT(false);
6114
6799
  } break;
@@ -6150,13 +6835,7 @@ static void ggml_compute_forward_step(
6150
6835
  {
6151
6836
  ggml_compute_forward_step_f32(params, src0, dst);
6152
6837
  } break;
6153
- case GGML_TYPE_Q4_0:
6154
- case GGML_TYPE_Q4_1:
6155
- case GGML_TYPE_I8:
6156
- case GGML_TYPE_I16:
6157
- case GGML_TYPE_I32:
6158
- case GGML_TYPE_F16:
6159
- case GGML_TYPE_COUNT:
6838
+ default:
6160
6839
  {
6161
6840
  GGML_ASSERT(false);
6162
6841
  } break;
@@ -6198,13 +6877,7 @@ static void ggml_compute_forward_relu(
6198
6877
  {
6199
6878
  ggml_compute_forward_relu_f32(params, src0, dst);
6200
6879
  } break;
6201
- case GGML_TYPE_Q4_0:
6202
- case GGML_TYPE_Q4_1:
6203
- case GGML_TYPE_I8:
6204
- case GGML_TYPE_I16:
6205
- case GGML_TYPE_I32:
6206
- case GGML_TYPE_F16:
6207
- case GGML_TYPE_COUNT:
6880
+ default:
6208
6881
  {
6209
6882
  GGML_ASSERT(false);
6210
6883
  } break;
@@ -6263,13 +6936,7 @@ static void ggml_compute_forward_gelu(
6263
6936
  {
6264
6937
  ggml_compute_forward_gelu_f32(params, src0, dst);
6265
6938
  } break;
6266
- case GGML_TYPE_Q4_0:
6267
- case GGML_TYPE_Q4_1:
6268
- case GGML_TYPE_I8:
6269
- case GGML_TYPE_I16:
6270
- case GGML_TYPE_I32:
6271
- case GGML_TYPE_F16:
6272
- case GGML_TYPE_COUNT:
6939
+ default:
6273
6940
  {
6274
6941
  GGML_ASSERT(false);
6275
6942
  } break;
@@ -6330,13 +6997,7 @@ static void ggml_compute_forward_silu(
6330
6997
  {
6331
6998
  ggml_compute_forward_silu_f32(params, src0, dst);
6332
6999
  } break;
6333
- case GGML_TYPE_Q4_0:
6334
- case GGML_TYPE_Q4_1:
6335
- case GGML_TYPE_I8:
6336
- case GGML_TYPE_I16:
6337
- case GGML_TYPE_I32:
6338
- case GGML_TYPE_F16:
6339
- case GGML_TYPE_COUNT:
7000
+ default:
6340
7001
  {
6341
7002
  GGML_ASSERT(false);
6342
7003
  } break;
@@ -6416,13 +7077,7 @@ static void ggml_compute_forward_norm(
6416
7077
  {
6417
7078
  ggml_compute_forward_norm_f32(params, src0, dst);
6418
7079
  } break;
6419
- case GGML_TYPE_Q4_0:
6420
- case GGML_TYPE_Q4_1:
6421
- case GGML_TYPE_I8:
6422
- case GGML_TYPE_I16:
6423
- case GGML_TYPE_I32:
6424
- case GGML_TYPE_F16:
6425
- case GGML_TYPE_COUNT:
7080
+ default:
6426
7081
  {
6427
7082
  GGML_ASSERT(false);
6428
7083
  } break;
@@ -6496,13 +7151,7 @@ static void ggml_compute_forward_rms_norm(
6496
7151
  {
6497
7152
  ggml_compute_forward_rms_norm_f32(params, src0, dst);
6498
7153
  } break;
6499
- case GGML_TYPE_Q4_0:
6500
- case GGML_TYPE_Q4_1:
6501
- case GGML_TYPE_I8:
6502
- case GGML_TYPE_I16:
6503
- case GGML_TYPE_I32:
6504
- case GGML_TYPE_F16:
6505
- case GGML_TYPE_COUNT:
7154
+ default:
6506
7155
  {
6507
7156
  GGML_ASSERT(false);
6508
7157
  } break;
@@ -6894,27 +7543,6 @@ static void ggml_compute_forward_mul_mat_f16_f32(
6894
7543
  //}
6895
7544
  }
6896
7545
 
6897
- static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
6898
- [GGML_TYPE_Q4_0] = {
6899
- .dequantize_row_q = dequantize_row_q4_0,
6900
- .quantize_row_q = quantize_row_q4_0,
6901
- .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_0_reference,
6902
- .vec_dot_q = ggml_vec_dot_q4_0,
6903
- },
6904
- [GGML_TYPE_Q4_1] = {
6905
- .dequantize_row_q = dequantize_row_q4_1,
6906
- .quantize_row_q = quantize_row_q4_1,
6907
- .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_1_reference,
6908
- .vec_dot_q = ggml_vec_dot_q4_1,
6909
- },
6910
- };
6911
-
6912
- // For internal test use
6913
- quantize_fns_t ggml_internal_get_quantize_fn(size_t i) {
6914
- GGML_ASSERT(i < GGML_TYPE_COUNT);
6915
- return quantize_fns[i];
6916
- }
6917
-
6918
7546
  static void ggml_compute_forward_mul_mat_q_f32(
6919
7547
  const struct ggml_compute_params * params,
6920
7548
  const struct ggml_tensor * src0,
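The quantize_fns table and ggml_internal_get_quantize_fn are deleted from this spot because they now live earlier in the file, ahead of the dup/add kernels that started using them, and the table gains a quantize_row_q_dot member plus a Q8_0 entry. A rough sketch of the relocated Q4_0 entry, reconstructed from the names used in this diff (the exact kernel names are assumptions):

    [GGML_TYPE_Q4_0] = {
        .dequantize_row_q         = dequantize_row_q4_0,
        .quantize_row_q           = quantize_row_q4_0,
        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_0_reference,
        .quantize_row_q_dot       = quantize_row_q8_0,       // src1 rows -> Q8_0
        .vec_dot_q                = ggml_vec_dot_q4_0_q8_0,  // Q4_0 x Q8_0 dot product
    },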
@@ -6962,8 +7590,8 @@ static void ggml_compute_forward_mul_mat_q_f32(
6962
7590
  GGML_ASSERT(ne3 == ne13);
6963
7591
 
6964
7592
  const enum ggml_type type = src0->type;
6965
- quantize_row_q_t const quantize_row_q = quantize_fns[type].quantize_row_q;
6966
- vec_dot_q_t const vec_dot_q = quantize_fns[type].vec_dot_q;
7593
+ quantize_row_q_t const quantize_row_q_dot = quantize_fns[type].quantize_row_q_dot;
7594
+ vec_dot_q_t const vec_dot_q = quantize_fns[type].vec_dot_q;
6967
7595
 
6968
7596
  // we don't support permuted src0 or src1
6969
7597
  GGML_ASSERT(nb00 == (int) GGML_TYPE_SIZE[type]);
@@ -7032,12 +7660,12 @@ static void ggml_compute_forward_mul_mat_q_f32(
7032
7660
 
7033
7661
  if (params->type == GGML_TASK_INIT) {
7034
7662
  char * wdata = params->wdata;
7035
- const size_t row_size = ne10*GGML_TYPE_SIZE[type]/GGML_BLCK_SIZE[type];
7663
+ const size_t row_size = ne10*GGML_TYPE_SIZE[GGML_TYPE_Q8_0]/GGML_BLCK_SIZE[GGML_TYPE_Q8_0];
7036
7664
 
7037
7665
  for (int64_t i13 = 0; i13 < ne13; ++i13) {
7038
7666
  for (int64_t i12 = 0; i12 < ne12; ++i12) {
7039
7667
  for (int64_t i11 = 0; i11 < ne11; ++i11) {
7040
- quantize_row_q((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10);
7668
+ quantize_row_q_dot((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10);
7041
7669
  wdata += row_size;
7042
7670
  }
7043
7671
  }
@@ -7063,7 +7691,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
7063
7691
  const int ir1 = MIN(ir0 + dr, nr);
7064
7692
 
7065
7693
  void * wdata = params->wdata;
7066
- const size_t row_size = ne00*GGML_TYPE_SIZE[type]/GGML_BLCK_SIZE[type];
7694
+ const size_t row_size = ne00*GGML_TYPE_SIZE[GGML_TYPE_Q8_0]/GGML_BLCK_SIZE[GGML_TYPE_Q8_0];
7067
7695
 
7068
7696
  for (int ir = ir0; ir < ir1; ++ir) {
7069
7697
  // src0 indices
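Both row_size computations in ggml_compute_forward_mul_mat_q_f32 are now keyed to GGML_TYPE_Q8_0 instead of src0's own type, because the f32 src1 rows are always quantized to Q8_0 in the work buffer before the dot products, whatever 4-bit format src0 uses. With illustrative numbers (ne10 = 4096, and the 36-byte block_q8_0 sketched above):

    const size_t example_src1_row = 4096 * 36 / 32;   // = 4608 bytes per quantized src1 row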
@@ -7111,6 +7739,7 @@ static void ggml_compute_forward_mul_mat(
7111
7739
  switch (src0->type) {
7112
7740
  case GGML_TYPE_Q4_0:
7113
7741
  case GGML_TYPE_Q4_1:
7742
+ case GGML_TYPE_Q8_0:
7114
7743
  {
7115
7744
  ggml_compute_forward_mul_mat_q_f32(params, src0, src1, dst);
7116
7745
  } break;
@@ -7122,10 +7751,7 @@ static void ggml_compute_forward_mul_mat(
7122
7751
  {
7123
7752
  ggml_compute_forward_mul_mat_f32(params, src0, src1, dst);
7124
7753
  } break;
7125
- case GGML_TYPE_I8:
7126
- case GGML_TYPE_I16:
7127
- case GGML_TYPE_I32:
7128
- case GGML_TYPE_COUNT:
7754
+ default:
7129
7755
  {
7130
7756
  GGML_ASSERT(false);
7131
7757
  } break;
@@ -7207,13 +7833,7 @@ static void ggml_compute_forward_scale(
7207
7833
  {
7208
7834
  ggml_compute_forward_scale_f32(params, src0, src1, dst);
7209
7835
  } break;
7210
- case GGML_TYPE_Q4_0:
7211
- case GGML_TYPE_Q4_1:
7212
- case GGML_TYPE_I8:
7213
- case GGML_TYPE_I16:
7214
- case GGML_TYPE_I32:
7215
- case GGML_TYPE_F16:
7216
- case GGML_TYPE_COUNT:
7836
+ default:
7217
7837
  {
7218
7838
  GGML_ASSERT(false);
7219
7839
  } break;
@@ -7374,6 +7994,7 @@ static void ggml_compute_forward_get_rows(
7374
7994
  switch (src0->type) {
7375
7995
  case GGML_TYPE_Q4_0:
7376
7996
  case GGML_TYPE_Q4_1:
7997
+ case GGML_TYPE_Q8_0:
7377
7998
  {
7378
7999
  ggml_compute_forward_get_rows_q(params, src0, src1, dst);
7379
8000
  } break;
@@ -7385,10 +8006,7 @@ static void ggml_compute_forward_get_rows(
7385
8006
  {
7386
8007
  ggml_compute_forward_get_rows_f32(params, src0, src1, dst);
7387
8008
  } break;
7388
- case GGML_TYPE_I8:
7389
- case GGML_TYPE_I16:
7390
- case GGML_TYPE_I32:
7391
- case GGML_TYPE_COUNT:
8009
+ default:
7392
8010
  {
7393
8011
  GGML_ASSERT(false);
7394
8012
  } break;
@@ -7461,13 +8079,7 @@ static void ggml_compute_forward_diag_mask_inf(
7461
8079
  {
7462
8080
  ggml_compute_forward_diag_mask_inf_f32(params, src0, src1, dst);
7463
8081
  } break;
7464
- case GGML_TYPE_Q4_0:
7465
- case GGML_TYPE_Q4_1:
7466
- case GGML_TYPE_I8:
7467
- case GGML_TYPE_I16:
7468
- case GGML_TYPE_I32:
7469
- case GGML_TYPE_F16:
7470
- case GGML_TYPE_COUNT:
8082
+ default:
7471
8083
  {
7472
8084
  GGML_ASSERT(false);
7473
8085
  } break;
@@ -7555,13 +8167,7 @@ static void ggml_compute_forward_soft_max(
7555
8167
  {
7556
8168
  ggml_compute_forward_soft_max_f32(params, src0, dst);
7557
8169
  } break;
7558
- case GGML_TYPE_Q4_0:
7559
- case GGML_TYPE_Q4_1:
7560
- case GGML_TYPE_I8:
7561
- case GGML_TYPE_I16:
7562
- case GGML_TYPE_I32:
7563
- case GGML_TYPE_F16:
7564
- case GGML_TYPE_COUNT:
8170
+ default:
7565
8171
  {
7566
8172
  GGML_ASSERT(false);
7567
8173
  } break;
@@ -7713,11 +8319,11 @@ static void ggml_compute_forward_rope_f16(
7713
8319
  const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
7714
8320
  ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
7715
8321
 
7716
- const float x0 = ggml_fp16_to_fp32(src[0]);
7717
- const float x1 = ggml_fp16_to_fp32(src[1]);
8322
+ const float x0 = GGML_FP16_TO_FP32(src[0]);
8323
+ const float x1 = GGML_FP16_TO_FP32(src[1]);
7718
8324
 
7719
- dst_data[0] = ggml_fp32_to_fp16(x0*cos_theta - x1*sin_theta);
7720
- dst_data[1] = ggml_fp32_to_fp16(x0*sin_theta + x1*cos_theta);
8325
+ dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
8326
+ dst_data[1] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
7721
8327
  }
7722
8328
  }
7723
8329
  }
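Replacing the ggml_fp16_to_fp32/ggml_fp32_to_fp16 calls with the upper-case macros removes a function call per element in this rotary-embedding inner loop; the macros inline to a native conversion or a precomputed table lookup depending on the target. A hedged sketch of the ARM NEON definitions (they live near the top of ggml.c and are assumptions here), where ggml_fp16_t is the hardware __fp16 type:

    #define GGML_FP16_TO_FP32(x) ((float) (x))
    #define GGML_FP32_TO_FP16(x) (x)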
@@ -7738,12 +8344,7 @@ static void ggml_compute_forward_rope(
7738
8344
  {
7739
8345
  ggml_compute_forward_rope_f32(params, src0, src1, dst);
7740
8346
  } break;
7741
- case GGML_TYPE_Q4_0:
7742
- case GGML_TYPE_Q4_1:
7743
- case GGML_TYPE_I8:
7744
- case GGML_TYPE_I16:
7745
- case GGML_TYPE_I32:
7746
- case GGML_TYPE_COUNT:
8347
+ default:
7747
8348
  {
7748
8349
  GGML_ASSERT(false);
7749
8350
  } break;
@@ -8006,12 +8607,7 @@ static void ggml_compute_forward_conv_1d_1s(
8006
8607
  {
8007
8608
  ggml_compute_forward_conv_1d_1s_f32(params, src0, src1, dst);
8008
8609
  } break;
8009
- case GGML_TYPE_Q4_0:
8010
- case GGML_TYPE_Q4_1:
8011
- case GGML_TYPE_I8:
8012
- case GGML_TYPE_I16:
8013
- case GGML_TYPE_I32:
8014
- case GGML_TYPE_COUNT:
8610
+ default:
8015
8611
  {
8016
8612
  GGML_ASSERT(false);
8017
8613
  } break;
@@ -8274,12 +8870,7 @@ static void ggml_compute_forward_conv_1d_2s(
8274
8870
  {
8275
8871
  ggml_compute_forward_conv_1d_2s_f32(params, src0, src1, dst);
8276
8872
  } break;
8277
- case GGML_TYPE_Q4_0:
8278
- case GGML_TYPE_Q4_1:
8279
- case GGML_TYPE_I8:
8280
- case GGML_TYPE_I16:
8281
- case GGML_TYPE_I32:
8282
- case GGML_TYPE_COUNT:
8873
+ default:
8283
8874
  {
8284
8875
  GGML_ASSERT(false);
8285
8876
  } break;
@@ -8759,12 +9350,7 @@ static void ggml_compute_forward_flash_attn(
8759
9350
  {
8760
9351
  ggml_compute_forward_flash_attn_f32(params, q, k, v, masked, dst);
8761
9352
  } break;
8762
- case GGML_TYPE_Q4_0:
8763
- case GGML_TYPE_Q4_1:
8764
- case GGML_TYPE_I8:
8765
- case GGML_TYPE_I16:
8766
- case GGML_TYPE_I32:
8767
- case GGML_TYPE_COUNT:
9353
+ default:
8768
9354
  {
8769
9355
  GGML_ASSERT(false);
8770
9356
  } break;
@@ -8970,12 +9556,7 @@ static void ggml_compute_forward_flash_ff(
8970
9556
  {
8971
9557
  GGML_ASSERT(false); // TODO
8972
9558
  } break;
8973
- case GGML_TYPE_Q4_0:
8974
- case GGML_TYPE_Q4_1:
8975
- case GGML_TYPE_I8:
8976
- case GGML_TYPE_I16:
8977
- case GGML_TYPE_I32:
8978
- case GGML_TYPE_COUNT:
9559
+ default:
8979
9560
  {
8980
9561
  GGML_ASSERT(false);
8981
9562
  } break;
@@ -9019,13 +9600,7 @@ static void ggml_compute_forward_map_unary(
9019
9600
  {
9020
9601
  ggml_compute_forward_map_unary_f32(params, src0, dst, fun);
9021
9602
  } break;
9022
- case GGML_TYPE_Q4_0:
9023
- case GGML_TYPE_Q4_1:
9024
- case GGML_TYPE_I8:
9025
- case GGML_TYPE_I16:
9026
- case GGML_TYPE_I32:
9027
- case GGML_TYPE_F16:
9028
- case GGML_TYPE_COUNT:
9603
+ default:
9029
9604
  {
9030
9605
  GGML_ASSERT(false);
9031
9606
  } break;
@@ -9074,13 +9649,7 @@ static void ggml_compute_forward_map_binary(
9074
9649
  {
9075
9650
  ggml_compute_forward_map_binary_f32(params, src0, src1, dst, fun);
9076
9651
  } break;
9077
- case GGML_TYPE_Q4_0:
9078
- case GGML_TYPE_Q4_1:
9079
- case GGML_TYPE_I8:
9080
- case GGML_TYPE_I16:
9081
- case GGML_TYPE_I32:
9082
- case GGML_TYPE_F16:
9083
- case GGML_TYPE_COUNT:
9652
+ default:
9084
9653
  {
9085
9654
  GGML_ASSERT(false);
9086
9655
  } break;
@@ -9830,13 +10399,29 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
9830
10399
  struct ggml_tensor * node = cgraph->nodes[i];
9831
10400
 
9832
10401
  switch (node->op) {
10402
+ case GGML_OP_CPY:
9833
10403
  case GGML_OP_DUP:
9834
10404
  {
9835
10405
  node->n_tasks = 1;
10406
+
10407
+ size_t cur = 0;
10408
+ if (node->type == GGML_TYPE_Q4_0 || node->type == GGML_TYPE_Q4_1) {
10409
+ cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->ne[0];
10410
+ }
10411
+
10412
+ work_size = MAX(work_size, cur);
9836
10413
  } break;
9837
10414
  case GGML_OP_ADD:
9838
10415
  {
9839
10416
  node->n_tasks = n_threads;
10417
+
10418
+ size_t cur = 0;
10419
+
10420
+ if (node->src0->type == GGML_TYPE_Q4_0 || node->src0->type == GGML_TYPE_Q4_1) {
10421
+ cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->src0->ne[0] * n_threads;
10422
+ }
10423
+
10424
+ work_size = MAX(work_size, cur);
9840
10425
  } break;
9841
10426
  case GGML_OP_SUB:
9842
10427
  case GGML_OP_MUL:
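The work-buffer planner now accounts for the new quantized paths: a DUP/CPY into a quantized destination needs one f32 row of staging, and an ADD with a quantized src0 needs one f32 row per thread, since each worker dequantizes and re-quantizes its own rows concurrently. With illustrative numbers (ne0 = 4096, 8 threads):

    size_t cur_cpy = sizeof(float) * 4096;       //  16 KiB for CPY/DUP into Q4_x
    size_t cur_add = sizeof(float) * 4096 * 8;   // 128 KiB for ADD with a Q4_x src0
    // work_size keeps the running maximum over all nodes in the graph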
@@ -9905,7 +10490,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
9905
10490
  } else
9906
10491
  #endif
9907
10492
  {
9908
- cur = GGML_TYPE_SIZE[node->src0->type]*ggml_nelements(node->src1)/GGML_BLCK_SIZE[node->src0->type];
10493
+ cur = GGML_TYPE_SIZE[GGML_TYPE_Q8_0]*ggml_nelements(node->src1)/GGML_BLCK_SIZE[GGML_TYPE_Q8_0];
9909
10494
  }
9910
10495
  } else {
9911
10496
  GGML_ASSERT(false);
@@ -9917,7 +10502,6 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
9917
10502
  {
9918
10503
  node->n_tasks = n_threads;
9919
10504
  } break;
9920
- case GGML_OP_CPY:
9921
10505
  case GGML_OP_CONT:
9922
10506
  case GGML_OP_RESHAPE:
9923
10507
  case GGML_OP_VIEW:
@@ -11080,16 +11664,16 @@ enum ggml_opt_result ggml_opt(
11080
11664
  ////////////////////////////////////////////////////////////////////////////////
11081
11665
 
11082
11666
  size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t * hist) {
11083
- assert(k % QK == 0);
11084
- const int nb = k / QK;
11667
+ assert(k % QK4_0 == 0);
11668
+ const int nb = k / QK4_0;
11085
11669
 
11086
11670
  for (int j = 0; j < n; j += k) {
11087
- block_q4_0 * restrict y = (block_q4_0 *)dst + j/QK;
11671
+ block_q4_0 * restrict y = (block_q4_0 *)dst + j/QK4_0;
11088
11672
 
11089
11673
  quantize_row_q4_0_reference(src + j, y, k);
11090
11674
 
11091
11675
  for (int i = 0; i < nb; i++) {
11092
- for (int l = 0; l < QK; l += 2) {
11676
+ for (int l = 0; l < QK4_0; l += 2) {
11093
11677
  const uint8_t vi0 = y[i].qs[l/2] & 0xF;
11094
11678
  const uint8_t vi1 = y[i].qs[l/2] >> 4;
11095
11679
 
@@ -11099,20 +11683,20 @@ size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t *
11099
11683
  }
11100
11684
  }
11101
11685
 
11102
- return (n/QK*sizeof(block_q4_0));
11686
+ return (n/QK4_0*sizeof(block_q4_0));
11103
11687
  }
11104
11688
 
11105
11689
  size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t * hist) {
11106
- assert(k % QK == 0);
11107
- const int nb = k / QK;
11690
+ assert(k % QK4_1 == 0);
11691
+ const int nb = k / QK4_1;
11108
11692
 
11109
11693
  for (int j = 0; j < n; j += k) {
11110
- block_q4_1 * restrict y = (block_q4_1 *)dst + j/QK;
11694
+ block_q4_1 * restrict y = (block_q4_1 *)dst + j/QK4_1;
11111
11695
 
11112
11696
  quantize_row_q4_1_reference(src + j, y, k);
11113
11697
 
11114
11698
  for (int i = 0; i < nb; i++) {
11115
- for (int l = 0; l < QK; l += 2) {
11699
+ for (int l = 0; l < QK4_1; l += 2) {
11116
11700
  const uint8_t vi0 = y[i].qs[l/2] & 0xF;
11117
11701
  const uint8_t vi1 = y[i].qs[l/2] >> 4;
11118
11702
 
@@ -11122,7 +11706,7 @@ size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t *
11122
11706
  }
11123
11707
  }
11124
11708
 
11125
- return (n/QK*sizeof(block_q4_1));
11709
+ return (n/QK4_1*sizeof(block_q4_1));
11126
11710
  }
11127
11711
 
11128
11712
  ////////////////////////////////////////////////////////////////////////////////
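ggml_quantize_q4_0 and ggml_quantize_q4_1 above now use their per-format block counts (QK4_0, QK4_1) instead of the old shared QK, and still return the number of bytes written while filling a nibble histogram. A hedged usage sketch (buffer sizes, the 16-bin histogram and all names are illustrative):

    float   src[2 * 1024];               // 2 rows of 1024 f32 weights, filled elsewhere
    uint8_t dst[2 * 1024 / 32 * 20];     // 64 blocks * sizeof(block_q4_0)
    int64_t hist[16] = {0};              // one bin per 4-bit quant value
    size_t  written = ggml_quantize_q4_0(src, dst, 2*1024, 1024, hist);
    // written == 2*1024/32 * sizeof(block_q4_0) == 1280 bytes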
@@ -11151,6 +11735,22 @@ int ggml_cpu_has_avx512(void) {
11151
11735
  #endif
11152
11736
  }
11153
11737
 
11738
+ int ggml_cpu_has_avx512_vbmi(void) {
11739
+ #if defined(__AVX512VBMI__)
11740
+ return 1;
11741
+ #else
11742
+ return 0;
11743
+ #endif
11744
+ }
11745
+
11746
+ int ggml_cpu_has_avx512_vnni(void) {
11747
+ #if defined(__AVX512VNNI__)
11748
+ return 1;
11749
+ #else
11750
+ return 0;
11751
+ #endif
11752
+ }
11753
+
11154
11754
  int ggml_cpu_has_fma(void) {
11155
11755
  #if defined(__FMA__)
11156
11756
  return 1;