llama_cpp 0.0.5 → 0.0.6

This diff shows the changes between publicly released versions of the package as they appear in the supported public registries. It is provided for informational purposes only.
@@ -19,6 +19,7 @@
19
19
  #include <inttypes.h>
20
20
  #include <stdio.h>
21
21
  #include <float.h>
22
+ #include <limits.h>
22
23
 
23
24
  // if C99 - static_assert is noop
24
25
  // ref: https://stackoverflow.com/a/53923785/4039976
@@ -142,10 +143,49 @@ inline static void* ggml_aligned_malloc(size_t size) {
142
143
  } \
143
144
  } while (0)
144
145
 
145
- #ifdef GGML_USE_ACCELERATE
146
+ #if defined(GGML_USE_ACCELERATE)
146
147
  #include <Accelerate/Accelerate.h>
147
- #elif GGML_USE_OPENBLAS
148
+ #elif defined(GGML_USE_OPENBLAS)
148
149
  #include <cblas.h>
150
+ #elif defined(GGML_USE_CUBLAS)
151
+ #include <cublas_v2.h>
152
+ #include <cuda_runtime.h>
153
+ #include "ggml-cuda.h"
154
+
155
+ #define CUDA_CHECK(err) \
156
+ do { \
157
+ cudaError_t err_ = (err); \
158
+ if (err_ != cudaSuccess) { \
159
+ printf("CUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__, \
160
+ cudaGetErrorString(err_)); \
161
+ exit(1); \
162
+ } \
163
+ } while (0)
164
+
165
+ #define CUBLAS_CHECK(err) \
166
+ do { \
167
+ cublasStatus_t err_ = (err); \
168
+ if (err_ != CUBLAS_STATUS_SUCCESS) { \
169
+ printf("cuBLAS error %d at %s:%d\n", err_, __FILE__, __LINE__); \
170
+ exit(1); \
171
+ } \
172
+ } while (0)
173
+
174
+ static cublasHandle_t cublasH = NULL;
175
+ static cudaStream_t cudaStream = NULL;
176
+ static void init_cublas(void) {
177
+ if (cublasH == NULL) {
178
+ // create cublas handle, bind a stream
179
+ CUBLAS_CHECK(cublasCreate(&cublasH));
180
+
181
+ CUDA_CHECK(cudaStreamCreateWithFlags(&cudaStream, cudaStreamNonBlocking));
182
+
183
+ CUBLAS_CHECK(cublasSetStream(cublasH, cudaStream));
184
+
185
+ // configure logging to stdout
186
+ // CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, NULL));
187
+ }
188
+ }
149
189
  #endif
150
190
 
151
191
  #undef MIN
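The CUDA_CHECK / CUBLAS_CHECK macros added above wrap every CUDA runtime and cuBLAS call so that a failure aborts with a file/line diagnostic instead of being silently ignored, and init_cublas() lazily creates one shared handle bound to a non-blocking stream. A minimal usage sketch (the helper name below is hypothetical and not part of this diff; only the macros, the cudaStream static, and standard CUDA runtime calls are assumed):

    // hypothetical helper: stage a host buffer on the device using the shared stream
    static void cuda_stage_example(const float * src, float * dst_dev, size_t n) {
        init_cublas(); // creates cublasH and cudaStream on first use

        CUDA_CHECK(cudaMemcpyAsync(dst_dev, src, n*sizeof(float),
                                   cudaMemcpyHostToDevice, cudaStream));
        CUDA_CHECK(cudaStreamSynchronize(cudaStream));
    }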
@@ -427,12 +467,30 @@ static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
427
467
  // quantization
428
468
  //
429
469
 
430
- // AVX routines provided by GH user Const-me
431
- // ref: https://github.com/ggerganov/ggml/pull/27#issuecomment-1464934600
470
+ #if __AVX__ || __AVX2__ || __AVX512F__
471
+ // Unpack 16 4-bit fields into 16 bytes
472
+ // The output vector contains 16 bytes, each one in [ 0 .. 15 ] interval
473
+ static inline __m128i bytes_from_nibbles_16(const uint8_t * rsi)
474
+ {
475
+ // Load 8 bytes from memory
476
+ __m128i tmp = _mm_loadu_si64( ( const __m128i* )rsi );
477
+
478
+ // Expand bytes into uint16_t values
479
+ __m128i bytes = _mm_cvtepu8_epi16( tmp );
480
+
481
+ // Unpack values into individual bytes
482
+ const __m128i lowMask = _mm_set1_epi8( 0xF );
483
+ __m128i high = _mm_andnot_si128( lowMask, bytes );
484
+ __m128i low = _mm_and_si128( lowMask, bytes );
485
+ high = _mm_slli_epi16( high, 4 );
486
+ bytes = _mm_or_si128( low, high );
487
+ return bytes;
488
+ }
489
+
432
490
  #if __AVX2__ || __AVX512F__
433
491
  // Unpack 32 4-bit fields into 32 bytes
434
492
  // The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
435
- static inline __m256i bytesFromNibbles( const uint8_t* rsi )
493
+ static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)
436
494
  {
437
495
  // Load 16 bytes from memory
438
496
  __m128i tmp = _mm_loadu_si128( ( const __m128i* )rsi );
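For reference, bytes_from_nibbles_16 and bytes_from_nibbles_32 (the renamed bytesFromNibbles) implement the same mapping at different widths; a scalar sketch of the 16-byte variant, for illustration only and not part of this diff:

    // output byte 2*k     = low  nibble of input byte k   (each result in [0 .. 15])
    // output byte 2*k + 1 = high nibble of input byte k
    static inline void bytes_from_nibbles_16_scalar(const uint8_t * src, uint8_t * dst) {
        for (int k = 0; k < 8; ++k) {
            dst[2*k + 0] = src[k] & 0x0F;
            dst[2*k + 1] = src[k] >> 4;
        }
    }

packNibbles() below performs the inverse: it packs pairs of bytes in [0 .. 15] back into single bytes holding two 4-bit fields.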
@@ -463,24 +521,7 @@ static inline __m128i packNibbles( __m256i bytes )
463
521
  __m128i r1 = _mm256_extracti128_si256( bytes, 1 );
464
522
  return _mm_packus_epi16( r0, r1 );
465
523
  }
466
- #elif __AVX__
467
- static inline __m128i bytesFromNibbles( const uint8_t* rsi )
468
- {
469
- // Load 8 bytes from memory
470
- __m128i tmp = _mm_loadu_si64( ( const __m128i* )rsi );
471
-
472
- // Expand bytes into uint16_t values
473
- __m128i bytes = _mm_cvtepu8_epi16( tmp );
474
-
475
- // Unpack values into individual bytes
476
- const __m128i lowMask = _mm_set1_epi8( 0xF );
477
- __m128i high = _mm_andnot_si128( lowMask, bytes );
478
- __m128i low = _mm_and_si128( lowMask, bytes );
479
- high = _mm_slli_epi16( high, 4 );
480
- bytes = _mm_or_si128( low, high );
481
- return bytes;
482
- }
483
-
524
+ #else
484
525
  static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
485
526
  {
486
527
  // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
@@ -497,6 +538,7 @@ static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
497
538
  return _mm_packus_epi16( bytes1, bytes2);
498
539
  }
499
540
  #endif
541
+ #endif // __AVX__ || __AVX2__ || __AVX512F__
500
542
 
501
543
  #if __ARM_NEON
502
544
 
@@ -514,6 +556,18 @@ inline static uint16_t vaddvq_u8(uint8x16_t v) {
514
556
  (uint16_t)vgetq_lane_u8(v, 14) + (uint16_t)vgetq_lane_u8(v, 15);
515
557
  }
516
558
 
559
+ inline static int16_t vaddvq_s8(int8x16_t v) {
560
+ return
561
+ (int16_t)vgetq_lane_s8(v, 0) + (int16_t)vgetq_lane_s8(v, 1) +
562
+ (int16_t)vgetq_lane_s8(v, 2) + (int16_t)vgetq_lane_s8(v, 3) +
563
+ (int16_t)vgetq_lane_s8(v, 4) + (int16_t)vgetq_lane_s8(v, 5) +
564
+ (int16_t)vgetq_lane_s8(v, 6) + (int16_t)vgetq_lane_s8(v, 7) +
565
+ (int16_t)vgetq_lane_s8(v, 8) + (int16_t)vgetq_lane_s8(v, 9) +
566
+ (int16_t)vgetq_lane_s8(v, 10) + (int16_t)vgetq_lane_s8(v, 11) +
567
+ (int16_t)vgetq_lane_s8(v, 12) + (int16_t)vgetq_lane_s8(v, 13) +
568
+ (int16_t)vgetq_lane_s8(v, 14) + (int16_t)vgetq_lane_s8(v, 15);
569
+ }
570
+
517
571
  inline static int32_t vaddvq_s16(int16x8_t v) {
518
572
  return
519
573
  (int32_t)vgetq_lane_s16(v, 0) + (int32_t)vgetq_lane_s16(v, 1) +
@@ -583,7 +637,22 @@ typedef struct {
583
637
  float m; // min
584
638
  uint8_t qs[QK4_1 / 2]; // nibbles / quants
585
639
  } block_q4_1;
586
- static_assert(sizeof(block_q4_1) == sizeof(float) * 2 + QK4_1 / 2, "wrong q4_1 block size/padding");
640
+ static_assert(sizeof(block_q4_1) == 2 * sizeof(float) + QK4_1 / 2, "wrong q4_1 block size/padding");
641
+
642
+ #define QK4_2 16
643
+ typedef struct {
644
+ ggml_fp16_t d; // delta
645
+ uint8_t qs[QK4_2 / 2]; // nibbles / quants
646
+ } block_q4_2;
647
+ static_assert(sizeof(block_q4_2) == sizeof(ggml_fp16_t) + QK4_2 / 2, "wrong q4_2 block size/padding");
648
+
649
+ #define QK4_3 16
650
+ typedef struct {
651
+ ggml_fp16_t d; // delta
652
+ ggml_fp16_t m; // min
653
+ uint8_t qs[QK4_3 / 2]; // nibbles / quants
654
+ } block_q4_3;
655
+ static_assert(sizeof(block_q4_3) == 2 * sizeof(ggml_fp16_t) + QK4_3 / 2, "wrong q4_3 block size/padding");
587
656
 
588
657
  #define QK8_0 32
589
658
  typedef struct {
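A quick size check for the two new block types (assuming the usual 2-byte ggml_fp16_t):

    block_q4_2: 2 + 16/2     = 10 bytes per 16 weights  ->  5.0 bits/weight (same density as Q4_0: 20 bytes per 32)
    block_q4_3: 2 + 2 + 16/2 = 12 bytes per 16 weights  ->  6.0 bits/weight (same density as Q4_1: 24 bytes per 32)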
@@ -1045,6 +1114,173 @@ static void quantize_row_q4_1(const float * restrict x, void * restrict vy, int
1045
1114
  #endif
1046
1115
  }
1047
1116
 
1117
+ // reference implementation for deterministic creation of model files
1118
+ static void quantize_row_q4_2_reference(const float * restrict x, block_q4_2 * restrict y, int k) {
1119
+ assert(k % QK4_2 == 0);
1120
+
1121
+ const int nb = k / QK4_2;
1122
+
1123
+ for (int i = 0; i < nb; i++) {
1124
+ float amax = 0.0f; // absolute max
1125
+
1126
+ for (int l = 0; l < QK4_2; l++) {
1127
+ const float v = x[i*QK4_2 + l];
1128
+ amax = MAX(amax, fabsf(v));
1129
+ }
1130
+
1131
+ const float d = amax / ((1 << 3) - 1);
1132
+
1133
+ const float id = d ? 1.0f/d : 0.0f;
1134
+
1135
+ y[i].d = GGML_FP32_TO_FP16(d);
1136
+
1137
+ for (int l = 0; l < QK4_2; l += 2) {
1138
+ const float v0 = x[i*QK4_2 + l + 0]*id;
1139
+ const float v1 = x[i*QK4_2 + l + 1]*id;
1140
+
1141
+ const uint8_t vi0 = (uint8_t)(v0 + 8.5f);
1142
+ const uint8_t vi1 = (uint8_t)(v1 + 8.5f);
1143
+
1144
+ assert(vi0 < 16);
1145
+ assert(vi1 < 16);
1146
+
1147
+ y[i].qs[l/2] = vi0 | (vi1 << 4);
1148
+ }
1149
+ }
1150
+ }
1151
+
1152
+ static inline int nearest_int(float fval) {
1153
+ assert(fval <= 4194303.f);
1154
+ float val = fval + 12582912.f;
1155
+ int i; memcpy(&i, &val, sizeof(int));
1156
+ return (i & 0x007fffff) - 0x00400000;
1157
+ }
1158
+
1159
+ static float kquantize_q4_with_bounds(int n, int nmin, int nmax, const float * restrict X, int nCandidates,
1160
+ const float * restrict candidates, int8_t * restrict L) {
1161
+ assert (nmin >= INT8_MIN);
1162
+ assert (nmax <= INT8_MAX);
1163
+ float amax = 0;
1164
+ for (int i=0; i<n; ++i) amax = MAX(amax, fabsf(X[i]));
1165
+ if (!amax) { // all zero
1166
+ for (int i=0; i<n; ++i) L[i] = 0;
1167
+ return 1.f;
1168
+ }
1169
+ float best = 0, bestScale = 0;
1170
+ for (int si=0; si<nCandidates; ++si) {
1171
+ float iscale = candidates[si]/amax;
1172
+ float sumlxP = 0; int suml2P = 0;
1173
+ float sumlxM = 0; int suml2M = 0;
1174
+ for (int i=0; i<n; ++i) {
1175
+ int l = nearest_int(iscale*X[i]);
1176
+ int lp = MAX(nmin, MIN(nmax, +l));
1177
+ int lm = MAX(nmin, MIN(nmax, -l));
1178
+ sumlxP += X[i]*lp; suml2P += lp*lp;
1179
+ sumlxM += X[i]*lm; suml2M += lm*lm;
1180
+ }
1181
+ float sumlxP2 = sumlxP*sumlxP;
1182
+ float sumlxM2 = sumlxM*sumlxM;
1183
+ if (sumlxP2*suml2M > sumlxM2*suml2P) {
1184
+ if (sumlxP2 > best*suml2P) {
1185
+ best = sumlxP2/suml2P; bestScale = iscale;
1186
+ }
1187
+ } else {
1188
+ if (sumlxM2 > best*suml2M) {
1189
+ best = sumlxM2/suml2M; bestScale = -iscale;
1190
+ }
1191
+ }
1192
+ }
1193
+ float sumlx = 0; int suml2 = 0;
1194
+ for (int i=0; i<n; ++i) {
1195
+ int l = nearest_int(bestScale*X[i]);
1196
+ l = MAX(nmin, MIN(nmax, l));
1197
+ sumlx += X[i]*l; suml2 += l*l;
1198
+ L[i] = l;
1199
+ }
1200
+ float scale = sumlx/suml2;
1201
+ return scale;
1202
+ }
1203
+
1204
+ static void quantize_row_q4_2_rmse(const float * restrict x, block_q4_2 * restrict y, int k) {
1205
+ #define CANDIDATE_COUNT 8
1206
+ static const float candidates[CANDIDATE_COUNT] = { +8.7f, +8.3f, +8.1f, +7.8f, +7.3f, +7.0f, +6.3f, +5.7f };
1207
+ assert(k % QK4_2 == 0);
1208
+
1209
+ int8_t L[QK4_2];
1210
+
1211
+ const int nb = k / QK4_2;
1212
+
1213
+ for (int i = 0; i < nb; i++) {
1214
+ float scale = kquantize_q4_with_bounds(QK4_2, -8, 7, x, CANDIDATE_COUNT, candidates, L);
1215
+ y[i].d = GGML_FP32_TO_FP16(scale);
1216
+
1217
+ for (int l = 0; l < QK4_2; l += 2) {
1218
+ const uint8_t vi0 = (uint8_t)(L[l+0] + 8);
1219
+ const uint8_t vi1 = (uint8_t)(L[l+1] + 8);
1220
+
1221
+ assert(vi0 < 16);
1222
+ assert(vi1 < 16);
1223
+
1224
+ y[i].qs[l/2] = vi0 | (vi1 << 4);
1225
+ }
1226
+
1227
+ x += QK4_2;
1228
+ }
1229
+ }
1230
+
1231
+ static void quantize_row_q4_2(const float * restrict x, void * restrict vy, int k) {
1232
+ assert(k % QK4_2 == 0);
1233
+
1234
+ block_q4_2 * restrict y = vy;
1235
+
1236
+ //quantize_row_q4_2_reference(x, y, k);
1237
+ // This produces the exact same format, just a better match to the input floats ("better" as measured by RMSE)
1238
+ quantize_row_q4_2_rmse(x, y, k);
1239
+ }
1240
+
1241
+ static void quantize_row_q4_3_reference(const float * restrict x, block_q4_3 * restrict y, int k) {
1242
+ assert(k % QK4_3 == 0);
1243
+ const int nb = k / QK4_3;
1244
+
1245
+ for (int i = 0; i < nb; i++) {
1246
+ float min = FLT_MAX;
1247
+ float max = -FLT_MAX;
1248
+
1249
+ for (int l = 0; l < QK4_3; l++) {
1250
+ const float v = x[i*QK4_3 + l];
1251
+ if (v < min) min = v;
1252
+ if (v > max) max = v;
1253
+ }
1254
+
1255
+ const float d = (max - min) / ((1 << 4) - 1);
1256
+ const float id = d ? 1.0f/d : 0.0f;
1257
+
1258
+ y[i].d = GGML_FP32_TO_FP16(d);
1259
+ y[i].m = GGML_FP32_TO_FP16(min);
1260
+
1261
+ for (int l = 0; l < QK4_3; l += 2) {
1262
+ const float v0 = (x[i*QK4_3 + l + 0] - min)*id;
1263
+ const float v1 = (x[i*QK4_3 + l + 1] - min)*id;
1264
+
1265
+ const uint8_t vi0 = (int) (v0 + 0.5f);
1266
+ const uint8_t vi1 = (int) (v1 + 0.5f);
1267
+
1268
+ assert(vi0 < 16);
1269
+ assert(vi1 < 16);
1270
+
1271
+ y[i].qs[l/2] = vi0 | (vi1 << 4);
1272
+ }
1273
+ }
1274
+ }
1275
+
1276
+ static void quantize_row_q4_3(const float * restrict x, void * restrict vy, int k) {
1277
+ assert(k % QK4_3 == 0);
1278
+
1279
+ block_q4_3 * restrict y = vy;
1280
+
1281
+ quantize_row_q4_3_reference(x, y, k);
1282
+ }
1283
+
1048
1284
  // reference implementation for deterministic creation of model files
1049
1285
  static void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * restrict y, int k) {
1050
1286
  assert(k % QK8_0 == 0);
@@ -1064,7 +1300,7 @@ static void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * r
1064
1300
  y[i].d = d;
1065
1301
 
1066
1302
  for (int l = 0; l < QK8_0; ++l) {
1067
- const float v = x[i*QK8_0 + l]*id;
1303
+ const float v = x[i*QK8_0 + l]*id;
1068
1304
  y[i].qs[l] = roundf(v);
1069
1305
  }
1070
1306
  }
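A short sketch of why nearest_int() above works (assuming IEEE-754 binary32 with round-to-nearest): 12582912.0f is 1.5 * 2^23, so for |fval| < 2^22 the sum fval + 12582912.0f falls in [2^23, 2^24), where consecutive floats are exactly 1 apart and the rounded integer ends up in the low 23 mantissa bits. Masking those bits and subtracting 0x00400000 (the contribution of the leading 1.5) recovers round-to-nearest of fval, for example:

    nearest_int( 3.4f):  3.4 + 12582912.0 rounds to 12582915.0, mantissa bits 0x400003, 0x400003 - 0x400000 =  3
    nearest_int(-2.7f): -2.7 + 12582912.0 rounds to 12582909.0, mantissa bits 0x3ffffd, 0x3ffffd - 0x400000 = -3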
@@ -1211,7 +1447,7 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in
1211
1447
 
1212
1448
  for (int l = 0; l < QK4_0; l += 32) {
1213
1449
  // Load 32x4-bit integers into 32x8-bit integers
1214
- __m256i vx8 = bytesFromNibbles(pp+l/2);
1450
+ __m256i vx8 = bytes_from_nibbles_32(pp+l/2);
1215
1451
 
1216
1452
  // Subtract 8 from the integers
1217
1453
  vx8 = _mm256_sub_epi8(vx8, _mm256_set1_epi8(8));
@@ -1329,7 +1565,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
1329
1565
 
1330
1566
  for (int l = 0; l < QK4_1; l += 32) {
1331
1567
  // Load 32x4-bit integers into 32x8-bit integers
1332
- __m256i vx8 = bytesFromNibbles(pp+l/2);
1568
+ __m256i vx8 = bytes_from_nibbles_32(pp+l/2);
1333
1569
 
1334
1570
  // Convert to 16-bit int
1335
1571
  const __m256i vx16_lo = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vx8, 0));
@@ -1420,8 +1656,69 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
1420
1656
  #endif
1421
1657
  }
1422
1658
 
1423
- static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
1659
+ static void dequantize_row_q4_2(const void * restrict vx, float * restrict y, int k) {
1660
+ assert(k % QK4_2 == 0);
1661
+ const int nb = k / QK4_2;
1662
+
1663
+ const block_q4_2 * restrict x = vx;
1664
+
1665
+ for (int i = 0; i < nb; i++) {
1666
+ const float d = GGML_FP16_TO_FP32(x[i].d);
1667
+
1668
+ const uint8_t * restrict pp = x[i].qs;
1669
+
1670
+ for (int l = 0; l < QK4_2; l += 2) {
1671
+ const uint8_t vi = pp[l/2];
1672
+
1673
+ const int8_t vi0 = vi & 0xf;
1674
+ const int8_t vi1 = vi >> 4;
1675
+
1676
+ const float v0 = (vi0 - 8)*d;
1677
+ const float v1 = (vi1 - 8)*d;
1678
+
1679
+ y[i*QK4_2 + l + 0] = v0;
1680
+ y[i*QK4_2 + l + 1] = v1;
1681
+
1682
+ assert(!isnan(y[i*QK4_2 + l + 0]));
1683
+ assert(!isnan(y[i*QK4_2 + l + 1]));
1684
+ }
1685
+ }
1686
+ }
1687
+
1688
+ static void dequantize_row_q4_3(const void * restrict vx, float * restrict y, int k) {
1689
+ assert(k % QK4_3 == 0);
1690
+ const int nb = k / QK4_3;
1691
+
1692
+ const block_q4_3 * restrict x = vx;
1693
+
1694
+ for (int i = 0; i < nb; i++) {
1695
+ const float d = GGML_FP16_TO_FP32(x[i].d);
1696
+ const float m = GGML_FP16_TO_FP32(x[i].m);
1697
+
1698
+ const uint8_t * restrict pp = x[i].qs;
1699
+
1700
+ for (int l = 0; l < QK4_3; l += 2) {
1701
+ const uint8_t vi = pp[l/2];
1702
+
1703
+ const int8_t vi0 = vi & 0xf;
1704
+ const int8_t vi1 = vi >> 4;
1705
+
1706
+ const float v0 = vi0*d + m;
1707
+ const float v1 = vi1*d + m;
1708
+
1709
+ y[i*QK4_3 + l + 0] = v0;
1710
+ y[i*QK4_3 + l + 1] = v1;
1711
+
1712
+ assert(!isnan(y[i*QK4_3 + l + 0]));
1713
+ assert(!isnan(y[i*QK4_3 + l + 1]));
1714
+ }
1715
+ }
1716
+ }
1717
+
1424
1718
  static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
1719
+ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
1720
+ static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
1721
+ static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
1425
1722
 
1426
1723
  static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
1427
1724
  [GGML_TYPE_Q4_0] = {
@@ -1435,10 +1732,30 @@ static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
1435
1732
  .dequantize_row_q = dequantize_row_q4_1,
1436
1733
  .quantize_row_q = quantize_row_q4_1,
1437
1734
  .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_1_reference,
1438
- .quantize_row_q_dot = quantize_row_q4_1,
1439
- .vec_dot_q = ggml_vec_dot_q4_1,
1735
+ .quantize_row_q_dot = quantize_row_q8_0,
1736
+ .vec_dot_q = ggml_vec_dot_q4_1_q8_0,
1737
+ },
1738
+ [GGML_TYPE_Q4_2] = {
1739
+ .dequantize_row_q = dequantize_row_q4_2,
1740
+ .quantize_row_q = quantize_row_q4_2,
1741
+ .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_2_rmse, //quantize_row_q4_2_reference,
1742
+ .quantize_row_q_dot = quantize_row_q8_0,
1743
+ .vec_dot_q = ggml_vec_dot_q4_2_q8_0,
1744
+ },
1745
+ [GGML_TYPE_Q4_3] = {
1746
+ .dequantize_row_q = dequantize_row_q4_3,
1747
+ .quantize_row_q = quantize_row_q4_3,
1748
+ .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_3_reference, // TODO: RMSE optimization
1749
+ .quantize_row_q_dot = quantize_row_q8_0,
1750
+ .vec_dot_q = ggml_vec_dot_q4_3_q8_0,
1751
+ },
1752
+ [GGML_TYPE_Q8_0] = {
1753
+ .dequantize_row_q = NULL, // TODO
1754
+ .quantize_row_q = quantize_row_q8_0,
1755
+ .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q8_0_reference,
1756
+ .quantize_row_q_dot = quantize_row_q8_0,
1757
+ .vec_dot_q = NULL, // TODO
1440
1758
  },
1441
- // TODO: GGML_TYPE_Q8_0
1442
1759
  };
1443
1760
 
1444
1761
  // For internal test use
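The table change above is the substance of this hunk: for Q4_1 (and the new Q4_2 / Q4_3) the second operand of vec_dot_q is now produced by quantize_row_q8_0 rather than by the type's own quantizer, so dot products run against Q8_0-quantized activations. A sketch of how a caller might drive the table for one row of a Q4_x type (dot_row_example and its fixed-size scratch buffer are invented for illustration; the field names and signatures are the ones shown in the hunk):

    static float dot_row_example(enum ggml_type type, const void * w_row, const float * act, int n) {
        const quantize_fns_t * fns = &quantize_fns[type];

        block_q8_0 q8[256];                  // scratch for up to 256*QK8_0 activations (illustration only)
        fns->quantize_row_q_dot(act, q8, n); // activations -> Q8_0

        float result = 0.0f;
        fns->vec_dot_q(n, &result, w_row, q8); // e.g. ggml_vec_dot_q4_2_q8_0 for GGML_TYPE_Q4_2
        return result;
    }

(The GGML_TYPE_Q8_0 entry itself still has vec_dot_q set to NULL, so this path applies to the Q4_x types only.)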
@@ -2004,191 +2321,6 @@ inline static void ggml_vec_dot_f32(const int n, float * restrict s, const float
2004
2321
  *s = sumf;
2005
2322
  }
2006
2323
 
2007
- #if __AVX512F__ && QK4_0 == 32
2008
- static inline __m512i bytes_from_q4_0_twoblocks_avx512( const __m512i blocks ) {
2009
- // The 64 bytes of `blocks` contain two consecutive Q4_0 blocks loaded from memory:
2010
- // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
2011
- // |63 62 61 60 59 58 57 56 55 54 53 52 51 50 49 48 47 46 45 44 43 42 41 40 39 38 37 36 35 34 33 32|
2012
- // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
2013
- // | :. =_ () [] <> () Zz Yy|
2014
- // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
2015
- // |31 30 29 28 27 26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10 09 08 07 06 05 04 03 02 01 00|
2016
- // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
2017
- // |Xx Ww Vv Uu Tt Ss Rr Qq Pp Oo Nn Mm Ll Kk Jj Ii Hh Gg Ff Ee Dd Cc Bb Aa |
2018
- // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
2019
- //
2020
- // Bytes 04..19 (block #0) and 24..39 (block #1) both contain 32 nibbles (4-bit unsigned integers).
2021
- // We have exactly 64 nibbles, so we want to place each nibble into a separate byte.
2022
- // Bytes 00..03 and 20..23 contain scales, which are irrelevant to this function.
2023
- // Bytes 40..63 are masked when loading the data, so they are zeroed out.
2024
- #ifdef __AVX512VBMI__
2025
- const __m512i byte_perm = _mm512_set_epi8(
2026
- 39, 38, 39, 38, 37, 36, 37, 36, 35, 34, 35, 34, 33, 32, 33, 32,
2027
- 31, 30, 31, 30, 29, 28, 29, 28, 27, 26, 27, 26, 25, 24, 25, 24,
2028
- 19, 18, 19, 18, 17, 16, 17, 16, 15, 14, 15, 14, 13, 12, 13, 12,
2029
- 11, 10, 11, 10, 9, 8, 9, 8, 7, 6, 7, 6, 5, 4, 5, 4
2030
- );
2031
- const __m512i permuted = _mm512_permutexvar_epi8( byte_perm, blocks );
2032
- // After applying VPERMB, `permuted` looks like this:
2033
- // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
2034
- // |63 62 61 60 59 58 57 56 55 54 53 52 51 50 49 48 47 46 45 44 43 42 41 40 39 38 37 36 35 34 33 32|
2035
- // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
2036
- // |:. =_ :. =_ () [] () [] <> () <> () Zz Yy Zz Yy Xx Ww Xx Ww Vv Uu Vv Uu Tt Ss Tt Ss Rr Qq Rr Qq|
2037
- // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
2038
- // |31 30 29 28 27 26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10 09 08 07 06 05 04 03 02 01 00|
2039
- // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
2040
- // |Pp Oo Pp Oo Nn Mm Nn Mm Ll Kk Ll Kk Jj Ii Jj Ii Hh Gg Hh Gg Ff Ee Ff Ee Dd Cc Dd Cc Bb Aa Bb Aa|
2041
- // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
2042
- #else
2043
- const __m512i word_perm = _mm512_set_epi16(
2044
- 19, 19, 18, 18, 17, 17, 16, 16, 15, 15, 14, 14, 13, 13, 12, 12,
2045
- 9, 9, 8, 8, 7, 7, 6, 6, 5, 5, 4, 4, 3, 3, 2, 2
2046
- );
2047
- const __m512i permuted = _mm512_permutexvar_epi16( word_perm, blocks );
2048
- // This is the fallback path for CPUs that don't support VPERMB. Since we permute 16-bit groups only,
2049
- // VPERMB can be replaced with VPERMW. We could always use VPERMW, but at least on Tiger Lake and
2050
- // Ice Lake VPERMW followed by a right shift is quite noticeably slower than VPERMB.
2051
- #endif
2052
-
2053
- // Shift every odd-numbered 16-bit group to the right by 4 bits.
2054
- const __mmask32 shift_mask = 0xaaaaaaaa;
2055
- const __m512i shifted = _mm512_mask_srai_epi16( permuted, shift_mask, permuted, 4 );
2056
- // After applying VPSRAW, `shifted` looks like this (the "empty" nibbles are filled with zeroes):
2057
- // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
2058
- // |63 62 61 60 59 58 57 56 55 54 53 52 51 50 49 48 47 46 45 44 43 42 41 40 39 38 37 36 35 34 33 32
2059
- // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
2060
- // | : .= :. =_ ( )[ () [] < >( <> () Z zY Zz Yy X xW Xx Ww V vU Vv Uu T tS Tt Ss R rQ Rr Qq
2061
- // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
2062
- // |31 30 29 28 27 26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10 09 08 07 06 05 04 03 02 01 00|
2063
- // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
2064
- // | P pO Pp Oo N nM Nn Mm L lK Ll Kk J jI Jj Ii H hG Hh Gg F fE Ff Ee D dC Dd Cc B bA Bb Aa|
2065
- // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
2066
-
2067
- // Now we just need to zero out the higher nibble in each byte, and we're done.
2068
- const __m512i low_nibble_mask = _mm512_set1_epi8( 0xf );
2069
- return _mm512_and_si512( low_nibble_mask, shifted );
2070
- // The final result looks like this:
2071
- // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
2072
- // |63 62 61 60 59 58 57 56 55 54 53 52 51 50 49 48 47 46 45 44 43 42 41 40 39 38 37 36 35 34 33 32|
2073
- // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
2074
- // | : = . _ ( [ ) ] < ( > ) Z Y z y X W x w V U v u T S t s R Q r q|
2075
- // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
2076
- // |31 30 29 28 27 26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10 09 08 07 06 05 04 03 02 01 00|
2077
- // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
2078
- // | P O p o N M n m L K l k J I j i H G h g F E f e D C d c B A b a|
2079
- // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
2080
- }
2081
-
2082
- static inline __m512 dot_q4_0_twoblocks_avx512(
2083
- __m512 acc,
2084
- const block_q4_0 * restrict x,
2085
- const block_q4_0 * restrict y,
2086
- int i
2087
- ) {
2088
- // A pair of Q4_0 blocks spans 40 bytes, while an AVX-512 register has 64. The remaining 24 bytes
2089
- // can potentially be unaddressable, so we make sure to mask them out before the load, even though
2090
- // we don't use them at all. This might hurt the performance slightly, since the compiler is forced
2091
- // to use e.g. `VMOVDQU64 REG, MASK, [ADDR] + VPERMB ..., REG` instead of just `VPERMB ..., [ADDR]`.
2092
- const __mmask8 load_mask = 0x1f;
2093
- const __m512i blocks_0 = _mm512_maskz_loadu_epi64( load_mask, &x[i] );
2094
- const __m512i blocks_1 = _mm512_maskz_loadu_epi64( load_mask, &y[i] );
2095
-
2096
- // We want to multiply the scales, so we interpret both registers as 16 32-bit floats:
2097
- // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
2098
- // | 15 | 14 | 13 | 12 | 11 | 10 | 09 | 08 | 07 | 06 | 05 | 04 | 03 | 02 | 01 | 00 |
2099
- // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
2100
- // blocks_0_float
2101
- // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
2102
- // | | | | | | | xx | xx | xx | xx | B | xx | xx | xx | xx | A |
2103
- // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
2104
- // blocks_1_float
2105
- // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
2106
- // | | | | | | | xx | xx | xx | xx | D | xx | xx | xx | xx | C |
2107
- // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
2108
- const __m512 blocks_0_float = _mm512_castsi512_ps( blocks_0 );
2109
- const __m512 blocks_1_float = _mm512_castsi512_ps( blocks_1 );
2110
- // We absolutely shouldn't touch the floats marked with `xx`: they contain some
2111
- // random data, which might very well underflow. At least on Intel, this leads
2112
- // to a huge penalty that can't be ignored (easily 100x or more) unless you
2113
- // compile your code with something like `-ffast-math` to enable FTZ/DAZ flags.
2114
- // (and ggml can't assume that you do)...
2115
- const __mmask16 scale_mul_mask = 0x21;
2116
- #ifdef __clang__
2117
- // ...however, clang decides to optimize the multiplication mask away:
2118
- // https://godbolt.org/z/P8PqdsfvW
2119
- // gcc and MSVC do the sane thing. This horrible workaround forces clang to emit the mask.
2120
- __m512i scales;
2121
- __asm__(
2122
- "vmulps %1, %2, %0%{%3%}"
2123
- : "=v" ( scales )
2124
- : "vm" ( blocks_0_float ), "v" ( blocks_1_float ), "Yk" ( scale_mul_mask )
2125
- );
2126
- #else
2127
- const __m512 scales = _mm512_maskz_mul_ps( scale_mul_mask, blocks_0_float, blocks_1_float );
2128
- #endif
2129
- const __m512i scale_perm = _mm512_set_epi32(
2130
- 5, 5, 5, 5, 5, 5, 5, 5,
2131
- 0, 0, 0, 0, 0, 0, 0, 0
2132
- );
2133
- const __m512 permuted_scales = _mm512_permutexvar_ps( scale_perm, scales );
2134
- // After VMULPS and VPERMPS, `permuted_scales` looks like this:
2135
- // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
2136
- // | 15 | 14 | 13 | 12 | 11 | 10 | 09 | 08 | 07 | 06 | 05 | 04 | 03 | 02 | 01 | 00 |
2137
- // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
2138
- // | B*D| B*D| B*D| B*D| B*D| B*D| B*D| B*D| A*C| A*C| A*C| A*C| A*C| A*C| A*C| A*C|
2139
- // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
2140
-
2141
- const __m512i bytes_0 = bytes_from_q4_0_twoblocks_avx512( blocks_0 );
2142
- const __m512i bytes_1 = bytes_from_q4_0_twoblocks_avx512( blocks_1 );
2143
-
2144
- // Now we want to compute dot products of 4-element byte vectors and store them in
2145
- // 32-bit integers. That is (only one 4-element vector is shown for clarity):
2146
- // +----+----+----+----+
2147
- // ... | 03 | 02 | 01 | 00 |
2148
- // +----+----+----+----+
2149
- // bytes_0
2150
- // +----+----+----+----+
2151
- // ... | D | C | B | A |
2152
- // +----+----+----+----+
2153
- // bytes_1
2154
- // +----+----+----+----+
2155
- // ... | H | G | F | E |
2156
- // +----+----+----+----+
2157
- // final_res_int
2158
- // +----+----+----+----+
2159
- // ... | A*E+B*F+C*G+D*H |
2160
- // +----+----+----+----+
2161
- const __m512i plus_8 = _mm512_set1_epi8( 8 );
2162
- const __m512i bytes_1_minus_8 = _mm512_sub_epi8( bytes_1, plus_8 );
2163
-
2164
- #ifdef __AVX512VNNI__
2165
- // We have VPDPBUSDS in AVX512-VNNI, which does exactly what we want, but with a catch:
2166
- // the *left* operand is supposed to be unsigned, while Q4_0 quantization subtracts 8
2167
- // from each nibble, so they can be negative. So, instead of `(bytes_0 - 8) * (bytes_1 - 8)`,
2168
- // we compute `bytes_0 * (bytes_1 - 8) + bytes_1 * (-8) + 64`. VPDPBUSDS uses an accumulator,
2169
- // which means we only need 2 instructions.
2170
- const __m512i dot_init = _mm512_set1_epi32( 4 * 64 );
2171
- const __m512i minus_8 = _mm512_set1_epi8( -8 );
2172
- const __m512i prod_0 = _mm512_dpbusds_epi32( dot_init, bytes_1, minus_8 );
2173
- const __m512i final_res_int = _mm512_dpbusds_epi32( prod_0, bytes_0, bytes_1_minus_8 );
2174
- #else
2175
- // As a fallback, we have VPMADDUBSW in AVX512-BW, which uses 16-bit products instead of 32-bit ones.
2176
- // It has the same catch as VPDPBUSDS: the left operand should be unsigned.
2177
- // This is essentially the AVX-512 version of the AVX-2 trick used by GH user Const-me
2178
- // ref: https://gist.github.com/Const-me/4d30e1fc767ab314596e16e90f53b6f4#file-matmultest-cpp-L119
2179
- const __m512i one = _mm512_set1_epi16( 1 );
2180
- const __m512i prod_0 = _mm512_maddubs_epi16( bytes_0, bytes_1_minus_8 );
2181
- const __m512i prod_1 = _mm512_maddubs_epi16( plus_8, bytes_1_minus_8 );
2182
- const __m512i diff = _mm512_sub_epi16( prod_0, prod_1 );
2183
- const __m512i final_res_int = _mm512_madd_epi16( diff, one );
2184
- #endif
2185
-
2186
- // Finally, we multiply the permuted scales and the 32-bit dot products, then accumulate.
2187
- const __m512 final_res_float = _mm512_cvtepi32_ps( final_res_int );
2188
- return _mm512_fmadd_ps( permuted_scales, final_res_float, acc );
2189
- }
2190
- #endif
2191
-
2192
2324
  inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t * restrict x, ggml_fp16_t * restrict y) {
2193
2325
  ggml_float sumf = 0.0;
2194
2326
 
@@ -2225,67 +2357,64 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t
2225
2357
  *s = sumf;
2226
2358
  }
2227
2359
 
2228
- static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
2229
- const int nb = n / QK4_0;
2360
+ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
2361
+ const int nb = n / QK8_0;
2230
2362
 
2231
- assert(n % QK4_0 == 0);
2363
+ assert(n % QK8_0 == 0);
2232
2364
  assert(nb % 2 == 0);
2233
2365
 
2234
2366
  const block_q4_0 * restrict x = vx;
2235
- const block_q4_0 * restrict y = vy;
2367
+ const block_q8_0 * restrict y = vy;
2236
2368
 
2237
2369
  float sumf = 0.0;
2238
2370
 
2239
2371
  #if defined(__ARM_NEON)
2240
- float sum0 = 0.0f;
2241
- float sum1 = 0.0f;
2372
+ float32x4_t sumv0 = vdupq_n_f32(0.0f);
2373
+ float32x4_t sumv1 = vdupq_n_f32(0.0f);
2242
2374
 
2243
2375
  for (int i = 0; i < nb; i += 2) {
2244
2376
  const block_q4_0 * restrict x0 = &x[i + 0];
2245
- const block_q4_0 * restrict y0 = &y[i + 0];
2246
2377
  const block_q4_0 * restrict x1 = &x[i + 1];
2247
- const block_q4_0 * restrict y1 = &y[i + 1];
2378
+ const block_q8_0 * restrict y0 = &y[i + 0];
2379
+ const block_q8_0 * restrict y1 = &y[i + 1];
2248
2380
 
2249
- const uint8x16_t m4b = vdupq_n_u8(0xf);
2250
- const int8x16_t s8b = vdupq_n_s8(0x8);
2381
+ const uint8x16_t m4b = vdupq_n_u8(0xf);
2382
+ const int8x16_t s8b = vdupq_n_s8(0x8);
2251
2383
 
2252
2384
  const uint8x16_t v0_0 = vld1q_u8(x0->qs);
2253
- const uint8x16_t v1_0 = vld1q_u8(y0->qs);
2254
2385
  const uint8x16_t v0_1 = vld1q_u8(x1->qs);
2255
- const uint8x16_t v1_1 = vld1q_u8(y1->qs);
2256
2386
 
2257
2387
  // 4-bit -> 8-bit
2258
- const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8(v0_0, m4b));
2259
- const int8x16_t v1_0l = vreinterpretq_s8_u8(vandq_u8(v1_0, m4b));
2388
+ const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
2260
2389
  const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
2261
- const int8x16_t v1_0h = vreinterpretq_s8_u8(vshrq_n_u8(v1_0, 4));
2262
-
2263
- const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8(v0_1, m4b));
2264
- const int8x16_t v1_1l = vreinterpretq_s8_u8(vandq_u8(v1_1, m4b));
2390
+ const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
2265
2391
  const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
2266
- const int8x16_t v1_1h = vreinterpretq_s8_u8(vshrq_n_u8(v1_1, 4));
2267
2392
 
2268
2393
  // sub 8
2269
2394
  const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
2270
- const int8x16_t v1_0ls = vsubq_s8(v1_0l, s8b);
2271
2395
  const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
2272
- const int8x16_t v1_0hs = vsubq_s8(v1_0h, s8b);
2273
-
2274
2396
  const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
2275
- const int8x16_t v1_1ls = vsubq_s8(v1_1l, s8b);
2276
2397
  const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
2277
- const int8x16_t v1_1hs = vsubq_s8(v1_1h, s8b);
2398
+
2399
+ // load y
2400
+ const int8x16_t v1_0l = vld1q_s8(y0->qs);
2401
+ const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
2402
+ const int8x16_t v1_1l = vld1q_s8(y1->qs);
2403
+ const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
2404
+
2405
+ // interleave
2406
+ const int8x16_t v1_0ls = vuzp1q_s8(v1_0l, v1_0h);
2407
+ const int8x16_t v1_0hs = vuzp2q_s8(v1_0l, v1_0h);
2408
+ const int8x16_t v1_1ls = vuzp1q_s8(v1_1l, v1_1h);
2409
+ const int8x16_t v1_1hs = vuzp2q_s8(v1_1l, v1_1h);
2278
2410
 
2279
2411
  #if defined(__ARM_FEATURE_DOTPROD)
2280
2412
  // dot product into int32x4_t
2281
- int32x4_t p_0 = vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls);
2282
- int32x4_t p_1 = vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls);
2413
+ const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls), v0_0hs, v1_0hs);
2414
+ const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls), v0_1hs, v1_1hs);
2283
2415
 
2284
- p_0 = vdotq_s32(p_0, v0_0hs, v1_0hs);
2285
- p_1 = vdotq_s32(p_1, v0_1hs, v1_1hs);
2286
-
2287
- sum0 += x0->d*y0->d*vaddvq_s32(p_0);
2288
- sum1 += x1->d*y1->d*vaddvq_s32(p_1);
2416
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), x0->d*y0->d);
2417
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), x1->d*y1->d);
2289
2418
  #else
2290
2419
  const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls));
2291
2420
  const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls));
@@ -2297,116 +2426,51 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
2297
2426
  const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1hs));
2298
2427
  const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1hs));
2299
2428
 
2300
- const int16x8_t pl_0 = vaddq_s16(pl0l, pl0h);
2301
- const int16x8_t ph_0 = vaddq_s16(ph0l, ph0h);
2302
-
2303
- const int16x8_t pl_1 = vaddq_s16(pl1l, pl1h);
2304
- const int16x8_t ph_1 = vaddq_s16(ph1l, ph1h);
2429
+ const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
2430
+ const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
2431
+ const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
2432
+ const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
2305
2433
 
2306
- const int16x8_t p_0 = vaddq_s16(pl_0, ph_0);
2307
- const int16x8_t p_1 = vaddq_s16(pl_1, ph_1);
2308
-
2309
- sum0 += x0->d*y0->d*vaddvq_s16(p_0);
2310
- sum1 += x1->d*y1->d*vaddvq_s16(p_1);
2434
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), x0->d*y0->d);
2435
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), x1->d*y1->d);
2311
2436
  #endif
2312
2437
  }
2313
2438
 
2314
- sumf = sum0 + sum1;
2315
- #elif defined(__AVX512F__)
2439
+ sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
2440
+ #elif defined(__AVX2__)
2316
2441
  // Initialize accumulator with zeros
2317
- __m512 acc0 = _mm512_setzero_ps();
2318
- __m512 acc1 = _mm512_setzero_ps();
2442
+ __m256 acc = _mm256_setzero_ps();
2319
2443
 
2320
- const int superblock_size = 16;
2444
+ // Main loop
2445
+ for (int i = 0; i < nb; ++i) {
2446
+ /* Compute combined scale for the block */
2447
+ const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );
2321
2448
 
2322
- const int superblock_count = nb / superblock_size;
2449
+ __m256i bx = bytes_from_nibbles_32(x[i].qs);
2323
2450
 
2324
- for (int superblock_ix = 0; superblock_ix < superblock_count; superblock_ix += 1) {
2325
- int i = superblock_ix * superblock_size;
2451
+ // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
2452
+ const __m256i off = _mm256_set1_epi8( 8 );
2453
+ bx = _mm256_sub_epi8( bx, off );
2326
2454
 
2327
- acc0 = dot_q4_0_twoblocks_avx512( acc0, x, y, i+0 );
2328
- acc1 = dot_q4_0_twoblocks_avx512( acc1, x, y, i+2 );
2329
- acc0 = dot_q4_0_twoblocks_avx512( acc0, x, y, i+4 );
2330
- acc1 = dot_q4_0_twoblocks_avx512( acc1, x, y, i+6 );
2331
- acc0 = dot_q4_0_twoblocks_avx512( acc0, x, y, i+8 );
2332
- acc1 = dot_q4_0_twoblocks_avx512( acc1, x, y, i+10 );
2333
- acc0 = dot_q4_0_twoblocks_avx512( acc0, x, y, i+12 );
2334
- acc1 = dot_q4_0_twoblocks_avx512( acc1, x, y, i+14 );
2335
- }
2455
+ __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
2336
2456
 
2337
- // Remainders
2338
- for (int i = superblock_count * superblock_size; i < nb; i += 2) {
2339
- acc0 = dot_q4_0_twoblocks_avx512( acc0, x, y, i );
2340
- }
2457
+ // Get absolute values of x vectors
2458
+ const __m256i ax = _mm256_sign_epi8(bx, bx);
2341
2459
 
2342
- // Horizontal sum of all lanes of the accumulator
2343
- sumf = _mm512_reduce_add_ps( acc0 ) + _mm512_reduce_add_ps( acc1 );
2344
- #elif defined(__AVX2__)
2345
- // Initialize accumulator with zeros
2346
- __m256 acc = _mm256_setzero_ps();
2460
+ // Sign the values of the y vectors
2461
+ const __m256i sy = _mm256_sign_epi8(by, bx);
2347
2462
 
2348
- /* Prepare the constants we will need during execution */
2349
- const __m256i lowMask = _mm256_set1_epi8( 0xF );
2350
- const __m256i offset_8 = _mm256_set1_epi16( 8 );
2463
+ // Perform multiplication and create 16-bit values
2464
+ const __m256i dot = _mm256_maddubs_epi16(ax, sy);
2351
2465
 
2352
- #define UNROLL_COUNT 8
2353
- // make sure we only unroll multiples of the block count
2354
- assert(nb % UNROLL_COUNT == 0);
2466
+ const __m256i ones = _mm256_set1_epi16(1);
2467
+ __m256i xy_q = _mm256_madd_epi16(ones, dot);
2355
2468
 
2356
- // Main loop
2357
- for (int i = 0; i < nb; i+=UNROLL_COUNT) {
2358
- // This loop will be unrolled by the compiler
2359
- for (int u=0;u<UNROLL_COUNT;u++) {
2360
- /* Compute combined scale for the block */
2361
- const __m256 scale = _mm256_mul_ps(
2362
- _mm256_broadcast_ss( &x[i+u].d ),
2363
- _mm256_broadcast_ss( &y[i+u].d ) );
2364
-
2365
- /* get input from x
2366
- Input: 32 Nibbles (16 bytes) at *x[i+u]
2367
- Output: 2 vectors with 16 values of type int16_t (x_high_q, x_low_q) */
2368
-
2369
- /* Load 16 bytes from memory */
2370
- const __m128i tmp_x = _mm_loadu_si128( ( const __m128i* ) x[i+u].qs);
2371
- /* Expand bytes into uint16_t values */
2372
- const __m256i bytes_x = _mm256_cvtepu8_epi16(tmp_x);
2373
- /* Unpack values into individual bytes */
2374
- __m256i x_low_q = _mm256_and_si256( lowMask, bytes_x );
2375
- const __m256i pre_shift_x_high_q = _mm256_andnot_si256( lowMask, bytes_x );
2376
- __m256i x_high_q = _mm256_srli_epi16( pre_shift_x_high_q, 4 );
2377
- /* Now we have two vectors with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval. */
2378
- x_high_q = _mm256_sub_epi16( x_high_q, offset_8 );
2379
- x_low_q = _mm256_sub_epi16( x_low_q, offset_8 );
2380
-
2381
- /* get input from y
2382
- Input: 32 Nibbles (16 bytes) at *y[i+u]
2383
- Output: 2 vectors with 16 values of type int16_t (y_high_q, y_low_q) */
2384
-
2385
- /* Load 16 bytes from memory */
2386
- const __m128i tmp_y = _mm_loadu_si128( (const __m128i* ) y[i+u].qs);
2387
- /* Expand bytes into uint16_t values */
2388
- const __m256i bytes_y = _mm256_cvtepu8_epi16(tmp_y);
2389
- /* Unpack values into individual bytes */
2390
- const __m256i pre_shift_y_high_q = _mm256_andnot_si256( lowMask, bytes_y );
2391
- __m256i y_high_q = _mm256_srli_epi16( pre_shift_y_high_q, 4 );
2392
- __m256i y_low_q = _mm256_and_si256( lowMask, bytes_y );
2393
- /* Now we have two vectors with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval. */
2394
- y_high_q = _mm256_sub_epi16( y_high_q, offset_8 );
2395
- y_low_q = _mm256_sub_epi16( y_low_q, offset_8 );
2396
-
2397
- /* Compute products of int16_t integers, add pairwise, store as int32_t */
2398
- __m256i xy_high_q = _mm256_madd_epi16( x_high_q, y_high_q );
2399
- __m256i xy_low_q = _mm256_madd_epi16( x_low_q, y_low_q );
2400
-
2401
- /* Accumulate the products of int32_t integers -> we now have a vector of 8 int_32t */
2402
- __m256i xy_q = _mm256_add_epi32( xy_high_q, xy_low_q );
2403
-
2404
- /* Convert to vectore of 8 int32_t to 8 floats */
2405
- __m256 q = _mm256_cvtepi32_ps( xy_q );
2406
-
2407
- /* Multiply q with scale and accumulate */
2408
- acc = _mm256_fmadd_ps( scale, q, acc );
2409
- }
2469
+ /* Convert the vector of 8 int32_t to 8 floats */
2470
+ __m256 q = _mm256_cvtepi32_ps( xy_q );
2471
+
2472
+ /* Multiply q with scale and accumulate */
2473
+ acc = _mm256_fmadd_ps( d, q, acc );
2410
2474
  }
2411
2475
 
2412
2476
  // Return horizontal sum of the acc vector
@@ -2428,13 +2492,12 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
2428
2492
  __m128i i32[2];
2429
2493
  for (int j = 0; j < 2; ++j) {
2430
2494
  // Load 8 bytes, and unpack 4 bit fields into bytes, making 16 bytes
2431
- __m128i bx = bytesFromNibbles( x[i].qs + 8*j );
2432
- __m128i by = bytesFromNibbles( y[i].qs + 8*j );
2495
+ __m128i bx = bytes_from_nibbles_16(x[i].qs + 8*j);
2496
+ __m128i by = _mm_loadu_si128((const __m128i *)(y[i].qs + 16*j));
2433
2497
 
2434
2498
  // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
2435
2499
  const __m128i off = _mm_set1_epi8( 8 );
2436
2500
  bx = _mm_sub_epi8( bx, off );
2437
- by = _mm_sub_epi8( by, off );
2438
2501
 
2439
2502
  // Get absolute values of x vectors
2440
2503
  const __m128i ax = _mm_sign_epi8(bx, bx);
@@ -2462,86 +2525,6 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
2462
2525
  res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
2463
2526
 
2464
2527
  sumf = _mm_cvtss_f32( res );
2465
- #elif defined(__wasm_simd128__)
2466
- // wasm simd
2467
- float sum0 = 0.0f;
2468
- float sum1 = 0.0f;
2469
-
2470
- for (int i = 0; i < nb; i += 2) {
2471
- const block_q4_0 * restrict x0 = &x[i + 0];
2472
- const block_q4_0 * restrict y0 = &y[i + 0];
2473
- const block_q4_0 * restrict x1 = &x[i + 1];
2474
- const block_q4_0 * restrict y1 = &y[i + 1];
2475
-
2476
- const v128_t m4b = wasm_u8x16_splat(0xf);
2477
- const v128_t s8b = wasm_i8x16_splat(0x8);
2478
-
2479
- const v128_t v0_0 = wasm_v128_load(x0->qs);
2480
- const v128_t v0_1 = wasm_v128_load(y0->qs);
2481
- const v128_t v1_0 = wasm_v128_load(x1->qs);
2482
- const v128_t v1_1 = wasm_v128_load(y1->qs);
2483
-
2484
- // 4-bit -> 8-bit
2485
- const v128_t v0_0l = wasm_v128_and(v0_0, m4b);
2486
- const v128_t v1_0l = wasm_v128_and(v1_0, m4b);
2487
-
2488
- const v128_t v0_0h = wasm_u8x16_shr(v0_0, 4);
2489
- const v128_t v1_0h = wasm_u8x16_shr(v1_0, 4);
2490
-
2491
- const v128_t v0_1l = wasm_v128_and(v0_1, m4b);
2492
- const v128_t v1_1l = wasm_v128_and(v1_1, m4b);
2493
-
2494
- const v128_t v0_1h = wasm_u8x16_shr(v0_1, 4);
2495
- const v128_t v1_1h = wasm_u8x16_shr(v1_1, 4);
2496
-
2497
- // sub 8
2498
- const v128_t v0_0ls = wasm_i8x16_sub(v0_0l, s8b);
2499
- const v128_t v1_0ls = wasm_i8x16_sub(v1_0l, s8b);
2500
-
2501
- const v128_t v0_0hs = wasm_i8x16_sub(v0_0h, s8b);
2502
- const v128_t v1_0hs = wasm_i8x16_sub(v1_0h, s8b);
2503
-
2504
- const v128_t v0_1ls = wasm_i8x16_sub(v0_1l, s8b);
2505
- const v128_t v1_1ls = wasm_i8x16_sub(v1_1l, s8b);
2506
-
2507
- const v128_t v0_1hs = wasm_i8x16_sub(v0_1h, s8b);
2508
- const v128_t v1_1hs = wasm_i8x16_sub(v1_1h, s8b);
2509
-
2510
- // dot product into int16x8_t
2511
- const v128_t pl0l = wasm_i16x8_mul(wasm_i16x8_extend_low_i8x16(v0_0ls), wasm_i16x8_extend_low_i8x16(v1_0ls));
2512
- const v128_t pl0h = wasm_i16x8_mul(wasm_i16x8_extend_high_i8x16(v0_0ls), wasm_i16x8_extend_high_i8x16(v1_0ls));
2513
-
2514
- const v128_t ph0l = wasm_i16x8_mul(wasm_i16x8_extend_low_i8x16(v0_0hs), wasm_i16x8_extend_low_i8x16(v1_0hs));
2515
- const v128_t ph0h = wasm_i16x8_mul(wasm_i16x8_extend_high_i8x16(v0_0hs), wasm_i16x8_extend_high_i8x16(v1_0hs));
2516
-
2517
- const v128_t pl1l = wasm_i16x8_mul(wasm_i16x8_extend_low_i8x16(v0_1ls), wasm_i16x8_extend_low_i8x16(v1_1ls));
2518
- const v128_t pl1h = wasm_i16x8_mul(wasm_i16x8_extend_high_i8x16(v0_1ls), wasm_i16x8_extend_high_i8x16(v1_1ls));
2519
-
2520
- const v128_t ph1l = wasm_i16x8_mul(wasm_i16x8_extend_low_i8x16(v0_1hs), wasm_i16x8_extend_low_i8x16(v1_1hs));
2521
- const v128_t ph1h = wasm_i16x8_mul(wasm_i16x8_extend_high_i8x16(v0_1hs), wasm_i16x8_extend_high_i8x16(v1_1hs));
2522
-
2523
- const v128_t pl_0 = wasm_i16x8_add(pl0l, pl0h);
2524
- const v128_t ph_0 = wasm_i16x8_add(ph0l, ph0h);
2525
-
2526
- const v128_t pl_1 = wasm_i16x8_add(pl1l, pl1h);
2527
- const v128_t ph_1 = wasm_i16x8_add(ph1l, ph1h);
2528
-
2529
- const v128_t p_0 = wasm_i16x8_add(pl_0, ph_0);
2530
- const v128_t p_1 = wasm_i16x8_add(pl_1, ph_1);
2531
-
2532
- sum0 += x0->d * y0->d * (
2533
- wasm_i16x8_extract_lane(p_0, 0) + wasm_i16x8_extract_lane(p_0, 1) +
2534
- wasm_i16x8_extract_lane(p_0, 2) + wasm_i16x8_extract_lane(p_0, 3) +
2535
- wasm_i16x8_extract_lane(p_0, 4) + wasm_i16x8_extract_lane(p_0, 5) +
2536
- wasm_i16x8_extract_lane(p_0, 6) + wasm_i16x8_extract_lane(p_0, 7));
2537
- sum1 += x1->d * y1->d * (
2538
- wasm_i16x8_extract_lane(p_1, 0) + wasm_i16x8_extract_lane(p_1, 1) +
2539
- wasm_i16x8_extract_lane(p_1, 2) + wasm_i16x8_extract_lane(p_1, 3) +
2540
- wasm_i16x8_extract_lane(p_1, 4) + wasm_i16x8_extract_lane(p_1, 5) +
2541
- wasm_i16x8_extract_lane(p_1, 6) + wasm_i16x8_extract_lane(p_1, 7));
2542
- }
2543
-
2544
- sumf = sum0 + sum1;
2545
2528
  #else
2546
2529
  // scalar
2547
2530
  for (int i = 0; i < nb; i++) {
@@ -2549,202 +2532,187 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
2549
2532
  const float d1 = y[i].d;
2550
2533
 
2551
2534
  const uint8_t * restrict p0 = x[i].qs;
2552
- const uint8_t * restrict p1 = y[i].qs;
2535
+ const int8_t * restrict p1 = y[i].qs;
2553
2536
 
2554
2537
  int sumi = 0;
2555
- for (int j = 0; j < QK4_0/2; j++) {
2538
+ for (int j = 0; j < QK8_0/2; j++) {
2556
2539
  const uint8_t v0 = p0[j];
2557
- const uint8_t v1 = p1[j];
2558
2540
 
2559
- const int i0 = (v0 & 0xf) - 8;
2560
- const int i1 = (v0 >> 4) - 8;
2541
+ const int i0 = (int8_t) (v0 & 0xf) - 8;
2542
+ const int i1 = (int8_t) (v0 >> 4) - 8;
2561
2543
 
2562
- const int i2 = (v1 & 0xf) - 8;
2563
- const int i3 = (v1 >> 4) - 8;
2544
+ const int i2 = p1[2*j + 0];
2545
+ const int i3 = p1[2*j + 1];
2564
2546
 
2565
2547
  sumi += i0*i2 + i1*i3;
2566
2548
  }
2567
- sumf += d0 * d1 * sumi;
2549
+ sumf += d0*d1*sumi;
2568
2550
  }
2569
2551
  #endif
2570
2552
 
2571
2553
  *s = sumf;
2572
2554
  }
2573
2555
 
2574
- static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
2575
- const int nb = n / QK4_1;
2556
+ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
2557
+ const int nb = n / QK8_0;
2558
+
2559
+ assert(n % QK8_0 == 0);
2560
+ assert(nb % 2 == 0);
2576
2561
 
2577
2562
  const block_q4_1 * restrict x = vx;
2578
- const block_q4_1 * restrict y = vy;
2563
+ const block_q8_0 * restrict y = vy;
2579
2564
 
2580
2565
  float sumf = 0.0;
2581
2566
 
2582
- #if defined(__AVX2__)
2583
- // Initialize accumulator with zeros
2584
- __m256 acc = _mm256_setzero_ps();
2585
- // Accumulator for constant offsets
2586
- float acc_offset = 0.0f;
2587
-
2588
- // Main loop
2589
- for (int i = 0; i < nb; ++i) {
2590
- const float * d0 = &x[i].d;
2591
- const float * d1 = &y[i].d;
2592
-
2593
- const float * m0 = &x[i].m;
2594
- const float * m1 = &y[i].m;
2595
-
2596
- const __m256 d0v = _mm256_broadcast_ss( d0 );
2597
- const __m256 d1v = _mm256_broadcast_ss( d1 );
2598
- const __m256 m0v = _mm256_broadcast_ss( m0 );
2599
- const __m256 m1v = _mm256_broadcast_ss( m1 );
2600
-
2601
- // Compute combined scale for the block
2602
- const __m256 scale_01 = _mm256_mul_ps( d0v, d1v );
2603
-
2604
- // Compute cross scales for the block
2605
- const __m256 scale_0 = _mm256_mul_ps( d0v, m1v );
2606
- const __m256 scale_1 = _mm256_mul_ps( m0v, d1v );
2607
- const __m256 cross_scales = _mm256_blend_ps( scale_0, scale_1, 0xAA /* 0b10101010 */ );
2608
-
2609
- // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
2610
- __m256i bx = bytesFromNibbles( x[i].qs );
2611
- __m256i by = bytesFromNibbles( y[i].qs );
2612
-
2613
- // Now we have a vector with bytes in [ 0 .. 15 ] interval.
2614
-
2615
- // Sign-extend first 16 signed bytes into int16_t
2616
- __m256i x16 = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( bx ) );
2617
- __m256i y16 = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( by ) );
2618
- // Compute products of int16_t integers, add pairwise
2619
- __m256i i32 = _mm256_madd_epi16( x16, y16 );
2620
-
2621
- // Sign-extend last 16 signed bytes into int16_t vectors
2622
- __m256i x16_h = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( bx, 1 ) );
2623
- __m256i y16_h = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( by, 1 ) );
2624
- // Accumulate products of int16_t integers
2625
- i32 = _mm256_add_epi32( i32, _mm256_madd_epi16( x16_h, y16_h ) );
2626
-
2627
- // compute sums of unsigned bytes in bx, by in blocks of 8.
2628
- // This results in a layout like X100 0000 X200 0000 X300 0000 X400 0000,
2629
- // which we then interleave as X100 Y100 X200 Y200 X300 Y300 X400 Y400.
2630
- // so if we then cast to 8 singles, we get 8 floats like [ x0_7, y0_7, x8_15, y8_15, x16_23, y16_23, x24_31, y24_31 ]
2631
- __m256i xsumi = _mm256_sad_epu8( bx, _mm256_setzero_si256() );
2632
- __m256i ysumi = _mm256_sad_epu8( by, _mm256_setzero_si256() );
2633
- __m256i sumsi = _mm256_or_si256( xsumi, _mm256_slli_si256( ysumi, 4 ) );
2634
- __m256 sums = _mm256_cvtepi32_ps( sumsi );
2635
-
2636
- // Convert int32_t to float
2637
- __m256 p = _mm256_cvtepi32_ps( i32 );
2638
- // Apply the scale, and accumulate
2639
- // acc += d0*d1*x*y + d0*m1*x + d1*m0*y
2640
- acc = _mm256_fmadd_ps( scale_01, p, acc );
2641
- acc = _mm256_fmadd_ps( cross_scales, sums, acc );
2642
- // acc_offset += m0*m1 (for each entry in the block)
2643
- acc_offset += (*m0)*(*m1);
2644
- }
2645
-
2646
- // Return horizontal sum of the acc vector
2647
- __m128 res = _mm256_extractf128_ps( acc, 1 );
2648
- res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
2649
- res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
2650
- res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
2651
-
2652
- sumf = _mm_cvtss_f32( res ) + acc_offset * QK4_1;
2653
- #elif defined(__ARM_NEON)
2654
- float sum00 = 0.0f;
2655
- float sum01 = 0.0f;
2656
- float sum10 = 0.0f;
2657
- float sum11 = 0.0f;
2567
+ // TODO: add AVX / WASM SIMD / etc
2568
+ #if defined(__ARM_NEON)
2569
+ float32x4_t sumv0 = vdupq_n_f32(0.0f);
2570
+ float32x4_t sumv1 = vdupq_n_f32(0.0f);
2658
2571
 
2659
2572
  for (int i = 0; i < nb; i += 2) {
2660
2573
  const block_q4_1 * restrict x0 = &x[i + 0];
2661
- const block_q4_1 * restrict y0 = &y[i + 0];
2662
2574
  const block_q4_1 * restrict x1 = &x[i + 1];
2663
- const block_q4_1 * restrict y1 = &y[i + 1];
2575
+ const block_q8_0 * restrict y0 = &y[i + 0];
2576
+ const block_q8_0 * restrict y1 = &y[i + 1];
2664
2577
 
2665
2578
  const uint8x16_t m4b = vdupq_n_u8(0xf);
2666
2579
 
2667
2580
  const uint8x16_t v0_0 = vld1q_u8(x0->qs);
2668
- const uint8x16_t v1_0 = vld1q_u8(y0->qs);
2669
2581
  const uint8x16_t v0_1 = vld1q_u8(x1->qs);
2670
- const uint8x16_t v1_1 = vld1q_u8(y1->qs);
2671
2582
 
2672
2583
  // 4-bit -> 8-bit
2673
- const uint8x16_t v0_0l = vandq_u8(v0_0, m4b);
2674
- const uint8x16_t v1_0l = vandq_u8(v1_0, m4b);
2675
- const uint8x16_t v0_0h = vshrq_n_u8(v0_0, 4);
2676
- const uint8x16_t v1_0h = vshrq_n_u8(v1_0, 4);
2584
+ const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
2585
+ const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
2586
+ const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
2587
+ const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
2588
+
2589
+ // load y
2590
+ const int8x16_t v1_0l = vld1q_s8(y0->qs);
2591
+ const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
2592
+ const int8x16_t v1_1l = vld1q_s8(y1->qs);
2593
+ const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
2677
2594
 
2678
- const uint8x16_t v0_1l = vandq_u8(v0_1, m4b);
2679
- const uint8x16_t v1_1l = vandq_u8(v1_1, m4b);
2680
- const uint8x16_t v0_1h = vshrq_n_u8(v0_1, 4);
2681
- const uint8x16_t v1_1h = vshrq_n_u8(v1_1, 4);
2595
+ // interleave
2596
+ const int8x16_t v1_0ls = vuzp1q_s8(v1_0l, v1_0h);
2597
+ const int8x16_t v1_0hs = vuzp2q_s8(v1_0l, v1_0h);
2598
+ const int8x16_t v1_1ls = vuzp1q_s8(v1_1l, v1_1h);
2599
+ const int8x16_t v1_1hs = vuzp2q_s8(v1_1l, v1_1h);
2682
2600
 
2683
- sum00 += x0->m*y0->m;
2684
- sum01 += y0->m*x0->d*((uint16_t)vaddvq_u8(v0_0l) + (uint16_t)vaddvq_u8(v0_0h));
2685
- sum10 += x0->m*y0->d*((uint16_t)vaddvq_u8(v1_0l) + (uint16_t)vaddvq_u8(v1_0h));
2601
+ const int16x8_t s0i = vaddq_s16(
2602
+ vaddq_s16(vmovl_s8(vget_low_s8(v1_0ls)), vmovl_s8(vget_high_s8(v1_0ls))),
2603
+ vaddq_s16(vmovl_s8(vget_low_s8(v1_0hs)), vmovl_s8(vget_high_s8(v1_0hs))));
2686
2604
 
2687
- sum00 += x1->m*y1->m;
2688
- sum01 += y1->m*x1->d*((uint16_t)vaddvq_u8(v0_1l) + (uint16_t)vaddvq_u8(v0_1h));
2689
- sum10 += x1->m*y1->d*((uint16_t)vaddvq_u8(v1_1l) + (uint16_t)vaddvq_u8(v1_1h));
2605
+ const int16x8_t s1i = vaddq_s16(
2606
+ vaddq_s16(vmovl_s8(vget_low_s8(v1_1ls)), vmovl_s8(vget_high_s8(v1_1ls))),
2607
+ vaddq_s16(vmovl_s8(vget_low_s8(v1_1hs)), vmovl_s8(vget_high_s8(v1_1hs))));
2608
+
2609
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddl_s16(vget_low_s16(s0i), vget_high_s16(s0i))), x0->m*y0->d);
2610
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddl_s16(vget_low_s16(s1i), vget_high_s16(s1i))), x1->m*y1->d);
2690
2611
 
2691
2612
  #if defined(__ARM_FEATURE_DOTPROD)
2692
2613
  // dot product into int32x4_t
2693
- uint32x4_t p_0 = vdotq_u32(vdupq_n_u32(0), v0_0l, v1_0l);
2694
- uint32x4_t p_1 = vdotq_u32(vdupq_n_u32(0), v0_1l, v1_1l);
2695
-
2696
- p_0 = vdotq_u32(p_0, v0_0h, v1_0h);
2697
- p_1 = vdotq_u32(p_1, v0_1h, v1_1h);
2614
+ const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0ls), v0_0h, v1_0hs);
2615
+ const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1ls), v0_1h, v1_1hs);
2698
2616
 
2699
- sum11 += x0->d*y0->d*vaddvq_u32(p_0);
2700
- sum11 += x1->d*y1->d*vaddvq_u32(p_1);
2617
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), x0->d*y0->d);
2618
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), x1->d*y1->d);
2701
2619
  #else
2702
- const uint16x8_t pl0l = vmull_u8(vget_low_u8 (v0_0l), vget_low_u8 (v1_0l));
2703
- const uint16x8_t pl0h = vmull_u8(vget_high_u8(v0_0l), vget_high_u8(v1_0l));
2704
- const uint16x8_t ph0l = vmull_u8(vget_low_u8 (v0_0h), vget_low_u8 (v1_0h));
2705
- const uint16x8_t ph0h = vmull_u8(vget_high_u8(v0_0h), vget_high_u8(v1_0h));
2620
+ const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0l), vget_low_s8 (v1_0ls));
2621
+ const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0l), vget_high_s8(v1_0ls));
2622
+ const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0h), vget_low_s8 (v1_0hs));
2623
+ const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0h), vget_high_s8(v1_0hs));
2624
+
2625
+ const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1l), vget_low_s8 (v1_1ls));
2626
+ const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1l), vget_high_s8(v1_1ls));
2627
+ const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1h), vget_low_s8 (v1_1hs));
2628
+ const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1h), vget_high_s8(v1_1hs));
2629
+
2630
+ const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
2631
+ const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
2632
+ const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
2633
+ const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
2634
+
2635
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), x0->d*y0->d);
2636
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), x1->d*y1->d);
2637
+ #endif
2638
+ }
2706
2639
 
2707
- const uint16x8_t pl1l = vmull_u8(vget_low_u8 (v0_1l), vget_low_u8 (v1_1l));
2708
- const uint16x8_t pl1h = vmull_u8(vget_high_u8(v0_1l), vget_high_u8(v1_1l));
2709
- const uint16x8_t ph1l = vmull_u8(vget_low_u8 (v0_1h), vget_low_u8 (v1_1h));
2710
- const uint16x8_t ph1h = vmull_u8(vget_high_u8(v0_1h), vget_high_u8(v1_1h));
2640
+ sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
2641
+ #elif defined(__AVX2__)
2642
+ // Initialize accumulator with zeros
2643
+ __m256 acc = _mm256_setzero_ps();
2711
2644
 
2712
- const uint16x8_t pl_0 = vaddq_u16(pl0l, pl0h);
2713
- const uint16x8_t ph_0 = vaddq_u16(ph0l, ph0h);
2645
+ // Main loop
2646
+ for (int i = 0; i < nb; ++i) {
2647
+ const float * d0 = &x[i].d;
2648
+ const float * d1 = &y[i].d;
2649
+ const float * m0 = &x[i].m;
2650
+
2651
+ const __m256 d0v = _mm256_broadcast_ss( d0 );
2652
+ const __m256 d1v = _mm256_broadcast_ss( d1 );
2653
+ const __m256 m0v = _mm256_broadcast_ss( m0 );
2714
2654
 
2715
- const uint16x8_t pl_1 = vaddq_u16(pl1l, pl1h);
2716
- const uint16x8_t ph_1 = vaddq_u16(ph1l, ph1h);
2655
+ // Compute combined scales
2656
+ const __m256 d0d1 = _mm256_mul_ps( d0v, d1v );
2657
+ const __m256 d1m0 = _mm256_mul_ps( d1v, m0v );
2717
2658
 
2718
- const uint16x8_t p_0 = vaddq_u16(pl_0, ph_0);
2719
- const uint16x8_t p_1 = vaddq_u16(pl_1, ph_1);
2659
+ // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
2660
+ const __m256i bx = bytes_from_nibbles_32(x[i].qs);
2661
+ const __m256i by = _mm256_loadu_si256( (const __m256i *)y[i].qs );
2720
2662
 
2721
- sum11 += x0->d*y0->d*vaddvq_u16(p_0);
2722
- sum11 += x1->d*y1->d*vaddvq_u16(p_1);
2723
- #endif
2663
+ // Get absolute values of x vectors
2664
+ const __m256i ax = _mm256_sign_epi8( bx, bx );
2665
+
2666
+ // Sign the values of the y vectors
2667
+ const __m256i sy = _mm256_sign_epi8( by, bx );
2668
+
2669
+ // Perform multiplication and create 16-bit values
2670
+ const __m256i dot = _mm256_maddubs_epi16( ax, sy );
2671
+ const __m256i ones = _mm256_set1_epi16( 1 );
2672
+ const __m256i xy_q = _mm256_madd_epi16( ones, dot );
2673
+
2674
+ // Convert to vector of 8 int32_t to 8 floats
2675
+ const __m256 xy = _mm256_cvtepi32_ps( xy_q );
2676
+
2677
+ // Accumulate d0*d1*x*y
2678
+ acc = _mm256_fmadd_ps( d0d1, xy, acc );
2679
+
2680
+ // Compute sum of y values
2681
+ const __m256i y16_l = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( by ) );
2682
+ const __m256i y16_h = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( by, 1 ) );
2683
+ const __m256i ysumi = _mm256_madd_epi16( _mm256_add_epi16(y16_l, y16_h), ones );
2684
+ const __m256 ysum = _mm256_cvtepi32_ps( ysumi );
2685
+
2686
+ // Accumulate d1*m0*y
2687
+ acc = _mm256_fmadd_ps( d1m0, ysum, acc );
2724
2688
  }
2725
2689
 
2726
- sumf = QK4_1*sum00 + sum01 + sum10 + sum11;
2690
+ // Return horizontal sum of the acc vector
2691
+ __m128 res = _mm256_extractf128_ps( acc, 1 );
2692
+ res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
2693
+ res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
2694
+ res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
2695
+
2696
+ sumf = _mm_cvtss_f32( res );
2727
2697
  #else
2728
2698
  // scalar
2729
2699
  for (int i = 0; i < nb; i++) {
2730
2700
  const float d0 = x[i].d;
2731
- const float d1 = y[i].d;
2732
-
2733
2701
  const float m0 = x[i].m;
2734
- const float m1 = y[i].m;
2702
+ const float d1 = y[i].d;
2735
2703
 
2736
2704
  const uint8_t * restrict p0 = x[i].qs;
2737
- const uint8_t * restrict p1 = y[i].qs;
2705
+ const int8_t * restrict p1 = y[i].qs;
2738
2706
 
2739
- for (int j = 0; j < QK4_1/2; j++) {
2707
+ // TODO: this is very slow ..
2708
+ for (int j = 0; j < QK8_0/2; j++) {
2740
2709
  const uint8_t v0 = p0[j];
2741
- const uint8_t v1 = p1[j];
2742
2710
 
2743
2711
  const float f0 = d0*(v0 & 0xf) + m0;
2744
2712
  const float f1 = d0*(v0 >> 4) + m0;
2745
2713
 
2746
- const float f2 = d1*(v1 & 0xf) + m1;
2747
- const float f3 = d1*(v1 >> 4) + m1;
2714
+ const float f2 = d1*p1[2*j + 0];
2715
+ const float f3 = d1*p1[2*j + 1];
2748
2716
 
2749
2717
  sumf += f0*f2 + f1*f3;
2750
2718
  }
@@ -2754,32 +2722,36 @@ static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * rest
2754
2722
  *s = sumf;
2755
2723
  }
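
For reference, the reworked q4_1 x q8_0 kernel above rests on one identity: a q4_1 weight decodes to d0*q + m0 and a q8_0 activation to d1*s, so each block contributes d0*d1*sum(q*s) + m0*d1*sum(s). That is why the AVX2 path keeps two accumulators (d0d1 against the dot product, d1m0 against the column sum of y) and why the NEON path pre-sums the y vectors into s0i/s1i before the dot products. A minimal scalar sketch of the same identity, assuming 32-element blocks and already-unpacked nibbles:

    #include <stdint.h>

    // one q4_1 block (32 unpacked nibbles q in 0..15, scale d0, offset m0)
    // against one q8_0 block (32 signed bytes s, scale d1)
    static float dot_q4_1_q8_0_block(const uint8_t q[32], float d0, float m0,
                                     const int8_t s[32], float d1) {
        int sxy = 0; // sum of q*s
        int sy  = 0; // sum of s
        for (int j = 0; j < 32; j++) {
            sxy += q[j] * s[j];
            sy  += s[j];
        }
        // sum over j of (d0*q[j] + m0) * (d1*s[j])
        return d0*d1*sxy + m0*d1*sy;
    }
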
2756
2724
 
2757
- static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
2725
+ static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
2758
2726
  const int nb = n / QK8_0;
2759
2727
 
2760
2728
  assert(n % QK8_0 == 0);
2761
2729
  assert(nb % 2 == 0);
2730
+ assert(QK8_0 == 2*QK4_2);
2762
2731
 
2763
- const block_q4_0 * restrict x = vx;
2732
+ const block_q4_2 * restrict x = vx;
2764
2733
  const block_q8_0 * restrict y = vy;
2765
2734
 
2766
2735
  float sumf = 0.0;
2767
2736
 
2768
2737
  #if defined(__ARM_NEON)
2769
- float sum0 = 0.0f;
2770
- float sum1 = 0.0f;
2738
+ float32x4_t sumv0 = vdupq_n_f32(0.0f);
2739
+ float32x4_t sumv1 = vdupq_n_f32(0.0f);
2771
2740
 
2772
2741
  for (int i = 0; i < nb; i += 2) {
2773
- const block_q4_0 * restrict x0 = &x[i + 0];
2774
- const block_q4_0 * restrict x1 = &x[i + 1];
2742
+ const block_q4_2 * restrict x0_0 = &x[2*(i + 0) + 0];
2743
+ const block_q4_2 * restrict x0_1 = &x[2*(i + 0) + 1];
2744
+ const block_q4_2 * restrict x1_0 = &x[2*(i + 1) + 0];
2745
+ const block_q4_2 * restrict x1_1 = &x[2*(i + 1) + 1];
2746
+
2775
2747
  const block_q8_0 * restrict y0 = &y[i + 0];
2776
2748
  const block_q8_0 * restrict y1 = &y[i + 1];
2777
2749
 
2778
2750
  const uint8x16_t m4b = vdupq_n_u8(0xf);
2779
2751
  const int8x16_t s8b = vdupq_n_s8(0x8);
2780
2752
 
2781
- const uint8x16_t v0_0 = vld1q_u8(x0->qs);
2782
- const uint8x16_t v0_1 = vld1q_u8(x1->qs);
2753
+ const uint8x16_t v0_0 = vcombine_u8(vld1_u8(x0_0->qs), vld1_u8(x0_1->qs));
2754
+ const uint8x16_t v0_1 = vcombine_u8(vld1_u8(x1_0->qs), vld1_u8(x1_1->qs));
2783
2755
 
2784
2756
  // 4-bit -> 8-bit
2785
2757
  const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
@@ -2793,77 +2765,78 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
2793
2765
  const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
2794
2766
  const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
2795
2767
 
2768
+ // interleave
2769
+ const int8x16_t v0_0lz = vzip1q_s8(v0_0ls, v0_0hs);
2770
+ const int8x16_t v0_0hz = vzip2q_s8(v0_0ls, v0_0hs);
2771
+ const int8x16_t v0_1lz = vzip1q_s8(v0_1ls, v0_1hs);
2772
+ const int8x16_t v0_1hz = vzip2q_s8(v0_1ls, v0_1hs);
2773
+
2796
2774
  // load y
2797
2775
  const int8x16_t v1_0l = vld1q_s8(y0->qs);
2798
2776
  const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
2799
2777
  const int8x16_t v1_1l = vld1q_s8(y1->qs);
2800
2778
  const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
2801
2779
 
2802
- // interleave
2803
- const int8x16_t v1_0ls = vuzp1q_s8(v1_0l, v1_0h);
2804
- const int8x16_t v1_0hs = vuzp2q_s8(v1_0l, v1_0h);
2805
- const int8x16_t v1_1ls = vuzp1q_s8(v1_1l, v1_1h);
2806
- const int8x16_t v1_1hs = vuzp2q_s8(v1_1l, v1_1h);
2807
-
2808
2780
  #if defined(__ARM_FEATURE_DOTPROD)
2809
- // dot product into int32x4_t
2810
- int32x4_t p_0 = vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls);
2811
- int32x4_t p_1 = vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls);
2812
-
2813
- p_0 = vdotq_s32(p_0, v0_0hs, v1_0hs);
2814
- p_1 = vdotq_s32(p_1, v0_1hs, v1_1hs);
2781
+ sumv0 = vmlaq_n_f32(sumv0, vaddq_f32(
2782
+ vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0lz, v1_0l)), GGML_FP16_TO_FP32(x0_0->d)),
2783
+ vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0hz, v1_0h)), GGML_FP16_TO_FP32(x0_1->d))), y0->d);
2815
2784
 
2816
- sum0 += x0->d*y0->d*vaddvq_s32(p_0);
2817
- sum1 += x1->d*y1->d*vaddvq_s32(p_1);
2785
+ sumv1 = vmlaq_n_f32(sumv1, vaddq_f32(
2786
+ vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_1lz, v1_1l)), GGML_FP16_TO_FP32(x1_0->d)),
2787
+ vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_1hz, v1_1h)), GGML_FP16_TO_FP32(x1_1->d))), y1->d);
2818
2788
  #else
2819
- const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls));
2820
- const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls));
2821
- const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0hs));
2822
- const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0hs));
2823
-
2824
- const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1ls));
2825
- const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1ls));
2826
- const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1hs));
2827
- const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1hs));
2828
-
2829
- const int16x8_t pl_0 = vaddq_s16(pl0l, pl0h);
2830
- const int16x8_t ph_0 = vaddq_s16(ph0l, ph0h);
2831
-
2832
- const int16x8_t pl_1 = vaddq_s16(pl1l, pl1h);
2833
- const int16x8_t ph_1 = vaddq_s16(ph1l, ph1h);
2834
-
2835
- const int16x8_t p_0 = vaddq_s16(pl_0, ph_0);
2836
- const int16x8_t p_1 = vaddq_s16(pl_1, ph_1);
2837
-
2838
- sum0 += x0->d*y0->d*vaddvq_s16(p_0);
2839
- sum1 += x1->d*y1->d*vaddvq_s16(p_1);
2789
+ const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lz), vget_low_s8 (v1_0l));
2790
+ const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lz), vget_high_s8(v1_0l));
2791
+ const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hz), vget_low_s8 (v1_0h));
2792
+ const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hz), vget_high_s8(v1_0h));
2793
+
2794
+ const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lz), vget_low_s8 (v1_1l));
2795
+ const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lz), vget_high_s8(v1_1l));
2796
+ const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hz), vget_low_s8 (v1_1h));
2797
+ const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hz), vget_high_s8(v1_1h));
2798
+
2799
+ const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
2800
+ const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
2801
+ const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
2802
+ const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
2803
+
2804
+ sumv0 = vmlaq_n_f32(sumv0, vaddq_f32(
2805
+ vmulq_n_f32(vcvtq_f32_s32(pl0), GGML_FP16_TO_FP32(x0_0->d)),
2806
+ vmulq_n_f32(vcvtq_f32_s32(ph0), GGML_FP16_TO_FP32(x0_1->d))), y0->d);
2807
+
2808
+ sumv1 = vmlaq_n_f32(sumv1, vaddq_f32(
2809
+ vmulq_n_f32(vcvtq_f32_s32(pl1), GGML_FP16_TO_FP32(x1_0->d)),
2810
+ vmulq_n_f32(vcvtq_f32_s32(ph1), GGML_FP16_TO_FP32(x1_1->d))), y1->d);
2840
2811
  #endif
2841
2812
  }
2842
2813
 
2843
- sumf = sum0 + sum1;
2814
+ sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
2844
2815
  #elif defined(__AVX2__)
2845
2816
  // Initialize accumulator with zeros
2846
2817
  __m256 acc = _mm256_setzero_ps();
2847
2818
 
2848
2819
  // Main loop
2849
- for (int i = 0; i < nb; ++i) {
2820
+ for (int i = 0; i < nb; i++) {
2850
2821
  /* Compute combined scale for the block */
2851
- const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );
2822
+ const __m128 d0 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 0].d));
2823
+ const __m128 d1 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 1].d));
2824
+ const __m256 d = _mm256_mul_ps(_mm256_set_m128(d1, d0), _mm256_broadcast_ss(&y[i].d));
2852
2825
 
2853
- __m256i bx = bytesFromNibbles(x[i].qs);
2826
+ __m128i bx0 = bytes_from_nibbles_16(x[2*i + 0].qs);
2827
+ __m128i bx1 = bytes_from_nibbles_16(x[2*i + 1].qs);
2828
+ __m256i bx = _mm256_set_m128i(bx1, bx0);
2854
2829
 
2855
2830
  // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
2856
- const __m256i off = _mm256_set1_epi8( 8 );
2857
- bx = _mm256_sub_epi8( bx, off );
2831
+ const __m256i off = _mm256_set1_epi8(8);
2832
+ bx = _mm256_sub_epi8(bx, off);
2858
2833
 
2859
2834
  __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
2860
2835
 
2861
2836
  // Get absolute values of x vectors
2862
2837
  const __m256i ax = _mm256_sign_epi8(bx, bx);
2863
-
2864
2838
  // Sign the values of the y vectors
2865
2839
  const __m256i sy = _mm256_sign_epi8(by, bx);
2866
-
2867
2840
  // Perform multiplication and create 16-bit values
2868
2841
  const __m256i dot = _mm256_maddubs_epi16(ax, sy);
2869
2842
 
@@ -2871,92 +2844,208 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
2871
2844
  __m256i xy_q = _mm256_madd_epi16(ones, dot);
2872
2845
 
2873
2846
  /* Convert to vectore of 8 int32_t to 8 floats */
2874
- __m256 q = _mm256_cvtepi32_ps( xy_q );
2847
+ __m256 q = _mm256_cvtepi32_ps(xy_q);
2875
2848
 
2876
2849
  /* Multiply q with scale and accumulate */
2877
- acc = _mm256_fmadd_ps( d, q, acc );
2850
+ acc = _mm256_fmadd_ps(d, q, acc);
2878
2851
  }
2879
2852
 
2880
2853
  // Return horizontal sum of the acc vector
2881
- __m128 res = _mm256_extractf128_ps( acc, 1 );
2882
- res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
2883
- res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
2884
- res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
2854
+ __m128 res = _mm256_extractf128_ps(acc, 1);
2855
+ res = _mm_add_ps(res, _mm256_castps256_ps128(acc));
2856
+ res = _mm_add_ps(res, _mm_movehl_ps(res, res));
2857
+ res = _mm_add_ss(res, _mm_movehdup_ps(res));
2885
2858
 
2886
- sumf = _mm_cvtss_f32( res );
2887
- #elif defined(__AVX__)
2888
- // Initialize accumulator with zeros
2889
- __m256 acc = _mm256_setzero_ps();
2859
+ sumf = _mm_cvtss_f32(res);
2860
+ #else
2861
+ // scalar
2862
+ for (int i = 0; i < nb; i++) {
2863
+ const uint8_t * restrict x0 = x[2*i + 0].qs;
2864
+ const uint8_t * restrict x1 = x[2*i + 1].qs;
2865
+ const int8_t * restrict y0 = y[i].qs;
2890
2866
 
2891
- // Main loop
2892
- for (int i = 0; i < nb; ++i) {
2893
- // Compute combined scale for the block
2894
- const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );
2867
+ const float d0 = GGML_FP16_TO_FP32(x[2*i + 0].d);
2868
+ const float d1 = GGML_FP16_TO_FP32(x[2*i + 1].d);
2895
2869
 
2896
- __m128i i32[2];
2897
- for (int j = 0; j < 2; ++j) {
2898
- // Load 8 bytes, and unpack 4 bit fields into bytes, making 16 bytes
2899
- __m128i bx = bytesFromNibbles( x[i].qs + 8*j );
2900
- __m128i by = _mm_loadu_si128((const __m128i *)(y[i].qs + 16*j));
2870
+ int sumi_0 = 0;
2871
+ int sumi_1 = 0;
2872
+
2873
+ for (int j = 0; j < QK8_0/4; j++) {
2874
+ const uint8_t v0 = x0[j];
2875
+ const uint8_t v1 = x1[j];
2876
+
2877
+ const int i0_0 = (int8_t) (v0 & 0xf) - 8;
2878
+ const int i1_0 = (int8_t) (v0 >> 4) - 8;
2879
+
2880
+ const int i0_1 = (int8_t) (v1 & 0xf) - 8;
2881
+ const int i1_1 = (int8_t) (v1 >> 4) - 8;
2882
+
2883
+ const int i2_0 = y0[2*j + 0];
2884
+ const int i3_0 = y0[2*j + 1];
2885
+
2886
+ const int i2_1 = y0[2*(j + QK8_0/4) + 0];
2887
+ const int i3_1 = y0[2*(j + QK8_0/4) + 1];
2888
+
2889
+ sumi_0 += i0_0*i2_0 + i1_0*i3_0;
2890
+ sumi_1 += i0_1*i2_1 + i1_1*i3_1;
2891
+ }
2892
+
2893
+ sumf += (d0 * y[i].d) * sumi_0;
2894
+ sumf += (d1 * y[i].d) * sumi_1;
2895
+ }
2896
+ #endif
2897
+
2898
+ *s = sumf;
2899
+ }
2900
+
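
ggml_vec_dot_q4_2_q8_0 above pairs two 16-element q4_2 blocks with every 32-element q8_0 block (hence the assert QK8_0 == 2*QK4_2 and the x[2*i + 0] / x[2*i + 1] indexing), and each half of a y block is scaled by its own fp16 d. The struct sketch below only illustrates the layout implied by that indexing; the real definitions and constants live earlier in the file, and the exact field widths here are assumptions:

    #include <stdint.h>

    typedef uint16_t ggml_fp16_t;     // fp16 stored as raw bits (illustrative)

    #define QK4_2 16                  // assumed from QK8_0 == 2*QK4_2 with QK8_0 == 32
    #define QK8_0 32

    typedef struct {
        ggml_fp16_t d;                // per-block scale, fp16 (see GGML_FP16_TO_FP32 above)
        uint8_t     qs[QK4_2/2];      // 16 nibbles packed into 8 bytes (see vld1_u8 above)
    } block_q4_2;

    typedef struct {
        float  d;                     // per-block scale
        int8_t qs[QK8_0];             // 32 signed 8-bit values
    } block_q8_0;

    // y[i] is dotted against x[2*i + 0] (its first 16 values) and
    // x[2*i + 1] (its last 16 values), each with its own scale d.
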
2901
+ static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
2902
+ const int nb = n / QK8_0;
2903
+
2904
+ assert(n % QK8_0 == 0);
2905
+ assert(nb % 2 == 0);
2906
+ assert(QK8_0 == 2*QK4_2);
2907
+
2908
+ const block_q4_3 * restrict x = vx;
2909
+ const block_q8_0 * restrict y = vy;
2910
+
2911
+ float sumf = 0.0;
2912
+
2913
+ #if defined(__ARM_NEON)
2914
+ float32x4_t sumv0 = vdupq_n_f32(0.0f);
2915
+ float32x4_t sumv1 = vdupq_n_f32(0.0f);
2916
+
2917
+ for (int i = 0; i < nb; i += 2) {
2918
+ const block_q4_3 * restrict x0_0 = &x[2*(i + 0) + 0];
2919
+ const block_q4_3 * restrict x0_1 = &x[2*(i + 0) + 1];
2920
+ const block_q4_3 * restrict x1_0 = &x[2*(i + 1) + 0];
2921
+ const block_q4_3 * restrict x1_1 = &x[2*(i + 1) + 1];
2922
+
2923
+ const block_q8_0 * restrict y0 = &y[i + 0];
2924
+ const block_q8_0 * restrict y1 = &y[i + 1];
2925
+
2926
+ const uint8x16_t m4b = vdupq_n_u8(0xf);
2927
+
2928
+ const float x0_0d = GGML_FP16_TO_FP32(x0_0->d);
2929
+ const float x0_1d = GGML_FP16_TO_FP32(x0_1->d);
2930
+ const float x1_0d = GGML_FP16_TO_FP32(x1_0->d);
2931
+ const float x1_1d = GGML_FP16_TO_FP32(x1_1->d);
2932
+
2933
+ const float x0_0m = GGML_FP16_TO_FP32(x0_0->m);
2934
+ const float x0_1m = GGML_FP16_TO_FP32(x0_1->m);
2935
+ const float x1_0m = GGML_FP16_TO_FP32(x1_0->m);
2936
+ const float x1_1m = GGML_FP16_TO_FP32(x1_1->m);
2937
+
2938
+ const uint8x16_t v0_0 = vcombine_u8(vld1_u8(x0_0->qs), vld1_u8(x0_1->qs));
2939
+ const uint8x16_t v0_1 = vcombine_u8(vld1_u8(x1_0->qs), vld1_u8(x1_1->qs));
2940
+
2941
+ // 4-bit -> 8-bit
2942
+ const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
2943
+ const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
2944
+ const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
2945
+ const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
2901
2946
 
2902
- // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
2903
- const __m128i off = _mm_set1_epi8( 8 );
2904
- bx = _mm_sub_epi8( bx, off );
2947
+ // interleave
2948
+ const int8x16_t v0_0lz = vzip1q_s8(v0_0l, v0_0h);
2949
+ const int8x16_t v0_0hz = vzip2q_s8(v0_0l, v0_0h);
2950
+ const int8x16_t v0_1lz = vzip1q_s8(v0_1l, v0_1h);
2951
+ const int8x16_t v0_1hz = vzip2q_s8(v0_1l, v0_1h);
2905
2952
 
2906
- // Get absolute values of x vectors
2907
- const __m128i ax = _mm_sign_epi8(bx, bx);
2953
+ // load y
2954
+ const int8x16_t v1_0l = vld1q_s8(y0->qs);
2955
+ const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
2956
+ const int8x16_t v1_1l = vld1q_s8(y1->qs);
2957
+ const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
2908
2958
 
2909
- // Sign the values of the y vectors
2910
- const __m128i sy = _mm_sign_epi8(by, bx);
2959
+ const int16x8_t sy0_0 = vaddq_s16(vmovl_s8(vget_low_s8(v1_0l)), vmovl_s8(vget_high_s8(v1_0l)));
2960
+ const int16x8_t sy0_1 = vaddq_s16(vmovl_s8(vget_low_s8(v1_0h)), vmovl_s8(vget_high_s8(v1_0h)));
2911
2961
 
2912
- // Perform multiplication and create 16-bit values
2913
- const __m128i dot = _mm_maddubs_epi16(ax, sy);
2962
+ const int16x8_t sy1_0 = vaddq_s16(vmovl_s8(vget_low_s8(v1_1l)), vmovl_s8(vget_high_s8(v1_1l)));
2963
+ const int16x8_t sy1_1 = vaddq_s16(vmovl_s8(vget_low_s8(v1_1h)), vmovl_s8(vget_high_s8(v1_1h)));
2914
2964
 
2915
- const __m128i ones = _mm_set1_epi16(1);
2916
- i32[j] = _mm_madd_epi16(ones, dot);
2917
- }
2965
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddl_s16(vget_low_s16(sy0_0), vget_high_s16(sy0_0))), x0_0m*y0->d);
2966
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddl_s16(vget_low_s16(sy0_1), vget_high_s16(sy0_1))), x0_1m*y0->d);
2967
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddl_s16(vget_low_s16(sy1_0), vget_high_s16(sy1_0))), x1_0m*y1->d);
2968
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddl_s16(vget_low_s16(sy1_1), vget_high_s16(sy1_1))), x1_1m*y1->d);
2918
2969
 
2919
- // Convert int32_t to float
2920
- __m256 p = _mm256_cvtepi32_ps( _mm256_set_m128i( i32[0], i32[1] ));
2921
- // Apply the scale, and accumulate
2922
- acc = _mm256_add_ps(_mm256_mul_ps( d, p ), acc);
2970
+ #if defined(__ARM_FEATURE_DOTPROD)
2971
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0lz, v1_0l)), x0_0d*y0->d);
2972
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0hz, v1_0h)), x0_1d*y0->d);
2973
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_1lz, v1_1l)), x1_0d*y1->d);
2974
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_1hz, v1_1h)), x1_1d*y1->d);
2975
+ #else
2976
+ const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lz), vget_low_s8 (v1_0l));
2977
+ const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lz), vget_high_s8(v1_0l));
2978
+ const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hz), vget_low_s8 (v1_0h));
2979
+ const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hz), vget_high_s8(v1_0h));
2980
+
2981
+ const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lz), vget_low_s8 (v1_1l));
2982
+ const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lz), vget_high_s8(v1_1l));
2983
+ const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hz), vget_low_s8 (v1_1h));
2984
+ const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hz), vget_high_s8(v1_1h));
2985
+
2986
+ const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
2987
+ const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
2988
+ const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
2989
+ const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
2990
+
2991
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(pl0), x0_0d*y0->d);
2992
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(ph0), x0_1d*y0->d);
2993
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(pl1), x1_0d*y1->d);
2994
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(ph1), x1_1d*y1->d);
2995
+ #endif
2923
2996
  }
2924
2997
 
2925
- // Return horizontal sum of the acc vector
2926
- __m128 res = _mm256_extractf128_ps( acc, 1 );
2927
- res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
2928
- res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
2929
- res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
2930
-
2931
- sumf = _mm_cvtss_f32( res );
2998
+ sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
2932
2999
  #else
2933
3000
  // scalar
2934
3001
  for (int i = 0; i < nb; i++) {
2935
- const float d0 = x[i].d;
2936
- const float d1 = y[i].d;
3002
+ const uint8_t * restrict x0 = x[2*i + 0].qs;
3003
+ const uint8_t * restrict x1 = x[2*i + 1].qs;
3004
+ const int8_t * restrict y0 = y[i].qs;
2937
3005
 
2938
- const uint8_t * restrict p0 = x[i].qs;
2939
- const int8_t * restrict p1 = y[i].qs;
3006
+ const float d0 = GGML_FP16_TO_FP32(x[2*i + 0].d);
3007
+ const float m0 = GGML_FP16_TO_FP32(x[2*i + 0].m);
3008
+ const float d1 = GGML_FP16_TO_FP32(x[2*i + 1].d);
3009
+ const float m1 = GGML_FP16_TO_FP32(x[2*i + 1].m);
2940
3010
 
2941
- int sumi = 0;
2942
- for (int j = 0; j < QK8_0/2; j++) {
2943
- const uint8_t v0 = p0[j];
3011
+ int sy_0 = 0;
3012
+ int sy_1 = 0;
2944
3013
 
2945
- const int i0 = (int8_t) (v0 & 0xf) - 8;
2946
- const int i1 = (int8_t) (v0 >> 4) - 8;
3014
+ int sxy_0 = 0;
3015
+ int sxy_1 = 0;
2947
3016
 
2948
- const int i2 = p1[2*j + 0];
2949
- const int i3 = p1[2*j + 1];
3017
+ for (int j = 0; j < QK8_0/4; j++) {
3018
+ const uint8_t v0 = x0[j];
3019
+ const uint8_t v1 = x1[j];
2950
3020
 
2951
- sumi += i0*i2 + i1*i3;
3021
+ const int x0_0 = v0 & 0xf;
3022
+ const int x1_0 = v0 >> 4;
3023
+
3024
+ const int x0_1 = v1 & 0xf;
3025
+ const int x1_1 = v1 >> 4;
3026
+
3027
+ const int y0_0 = y0[2*j + 0];
3028
+ const int y1_0 = y0[2*j + 1];
3029
+
3030
+ const int y0_1 = y0[2*(j + QK8_0/4) + 0];
3031
+ const int y1_1 = y0[2*(j + QK8_0/4) + 1];
3032
+
3033
+ sy_0 += y0_0 + y1_0;
3034
+ sy_1 += y0_1 + y1_1;
3035
+
3036
+ sxy_0 += x0_0*y0_0 + x1_0*y1_0;
3037
+ sxy_1 += x0_1*y0_1 + x1_1*y1_1;
2952
3038
  }
2953
- sumf += d0*d1*sumi;
3039
+
3040
+ sumf += (d0*sxy_0 + m0*sy_0)*y[i].d;
3041
+ sumf += (d1*sxy_1 + m1*sy_1)*y[i].d;
2954
3042
  }
2955
3043
  #endif
2956
3044
 
2957
3045
  *s = sumf;
2958
3046
  }
2959
3047
 
3048
+
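
q4_3 is the affine counterpart of q4_2: each 16-element sub-block carries an fp16 scale d and an fp16 offset m, and its nibbles stay unsigned (no -8 shift as in q4_0/q4_2). The scalar tail above makes the resulting identity explicit (per sub-block the contribution is y.d*(d*sxy + m*sy)), and the NEON path pre-computes the sy0_*/sy1_* column sums of y for the same reason. A one-sub-block reference, with the 16-element width assumed:

    #include <stdint.h>

    // one q4_3 sub-block (16 unpacked nibbles q in 0..15, scale d, offset m)
    // against the matching 16 bytes s of a q8_0 block with scale yd
    static float dot_q4_3_sub_block(const uint8_t q[16], float d, float m,
                                    const int8_t s[16], float yd) {
        int sxy = 0; // sum of q*s
        int sy  = 0; // sum of s
        for (int j = 0; j < 16; j++) {
            sxy += q[j] * s[j];
            sy  += s[j];
        }
        // sum over j of (d*q[j] + m) * (yd*s[j])
        return yd*(d*sxy + m*sy);
    }
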
2960
3049
  // compute GGML_VEC_DOT_UNROLL dot products at once
2961
3050
  // xs - x row stride in bytes
2962
3051
  inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * restrict s, void * restrict xv, ggml_fp16_t * restrict y) {
@@ -3203,24 +3292,28 @@ static const int GGML_BLCK_SIZE[GGML_TYPE_COUNT] = {
3203
3292
  [GGML_TYPE_F16] = 1,
3204
3293
  [GGML_TYPE_Q4_0] = QK4_0,
3205
3294
  [GGML_TYPE_Q4_1] = QK4_1,
3295
+ [GGML_TYPE_Q4_2] = QK4_2,
3296
+ [GGML_TYPE_Q4_3] = QK4_3,
3206
3297
  [GGML_TYPE_Q8_0] = QK8_0,
3207
3298
  [GGML_TYPE_I8] = 1,
3208
3299
  [GGML_TYPE_I16] = 1,
3209
3300
  [GGML_TYPE_I32] = 1,
3210
3301
  };
3211
- static_assert(GGML_TYPE_COUNT == 8, "GGML_BLCK_SIZE is outdated");
3302
+ static_assert(GGML_TYPE_COUNT == 10, "GGML_BLCK_SIZE is outdated");
3212
3303
 
3213
3304
  static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = {
3214
3305
  [GGML_TYPE_F32] = sizeof(float),
3215
3306
  [GGML_TYPE_F16] = sizeof(ggml_fp16_t),
3216
3307
  [GGML_TYPE_Q4_0] = sizeof(block_q4_0),
3217
3308
  [GGML_TYPE_Q4_1] = sizeof(block_q4_1),
3309
+ [GGML_TYPE_Q4_2] = sizeof(block_q4_2),
3310
+ [GGML_TYPE_Q4_3] = sizeof(block_q4_3),
3218
3311
  [GGML_TYPE_Q8_0] = sizeof(block_q8_0),
3219
3312
  [GGML_TYPE_I8] = sizeof(int8_t),
3220
3313
  [GGML_TYPE_I16] = sizeof(int16_t),
3221
3314
  [GGML_TYPE_I32] = sizeof(int32_t),
3222
3315
  };
3223
- static_assert(GGML_TYPE_COUNT == 8, "GGML_TYPE_SIZE is outdated");
3316
+ static_assert(GGML_TYPE_COUNT == 10, "GGML_TYPE_SIZE is outdated");
3224
3317
 
3225
3318
 
3226
3319
  static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = {
@@ -3228,12 +3321,28 @@ static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = {
3228
3321
  [GGML_TYPE_F16] = "f16",
3229
3322
  [GGML_TYPE_Q4_0] = "q4_0",
3230
3323
  [GGML_TYPE_Q4_1] = "q4_1",
3324
+ [GGML_TYPE_Q4_2] = "q4_2",
3325
+ [GGML_TYPE_Q4_3] = "q4_3",
3231
3326
  [GGML_TYPE_Q8_0] = "q8_0",
3232
3327
  [GGML_TYPE_I8] = "i8",
3233
3328
  [GGML_TYPE_I16] = "i16",
3234
3329
  [GGML_TYPE_I32] = "i32",
3235
3330
  };
3236
- static_assert(GGML_TYPE_COUNT == 8, "GGML_TYPE_NAME is outdated");
3331
+ static_assert(GGML_TYPE_COUNT == 10, "GGML_TYPE_NAME is outdated");
3332
+
3333
+ static bool GGML_IS_QUANTIZED[GGML_TYPE_COUNT] = {
3334
+ [GGML_TYPE_F32] = false,
3335
+ [GGML_TYPE_F16] = false,
3336
+ [GGML_TYPE_Q4_0] = true,
3337
+ [GGML_TYPE_Q4_1] = true,
3338
+ [GGML_TYPE_Q4_2] = true,
3339
+ [GGML_TYPE_Q4_3] = true,
3340
+ [GGML_TYPE_Q8_0] = true,
3341
+ [GGML_TYPE_I8] = false,
3342
+ [GGML_TYPE_I16] = false,
3343
+ [GGML_TYPE_I32] = false,
3344
+ };
3345
+ static_assert(GGML_TYPE_COUNT == 10, "GGML_IS_QUANTIZED is outdated");
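
All four tables (GGML_BLCK_SIZE, GGML_TYPE_SIZE, GGML_TYPE_NAME, GGML_IS_QUANTIZED) gain q4_2/q4_3 entries and bump their static_assert from 8 to 10 types, so forgetting one of them now fails at compile time. Everything else derives storage sizes from the first two tables: a row of ne elements of a block-quantized type takes (ne / GGML_BLCK_SIZE[type]) * GGML_TYPE_SIZE[type] bytes, which is the rs stride the dup changes below compute and the copy size the cuBLAS changes use. A helper-style sketch of that arithmetic (the name is illustrative, not part of the API; it assumes the two tables above are in scope):

    // bytes needed for one row of `ne` elements of a block-quantized `type`;
    // `ne` is assumed to be a multiple of the block size
    static size_t row_size_sketch(enum ggml_type type, int ne) {
        return GGML_TYPE_SIZE[type] * (size_t)(ne / GGML_BLCK_SIZE[type]);
    }
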
3237
3346
 
3238
3347
  static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
3239
3348
  "NONE",
@@ -3495,6 +3604,10 @@ static inline bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct
3495
3604
  (t0->ne[3] == t1->ne[3]);
3496
3605
  }
3497
3606
 
3607
+ bool ggml_is_quantized(enum ggml_type type) {
3608
+ return GGML_IS_QUANTIZED[type];
3609
+ }
3610
+
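
The new ggml_is_quantized() helper lets the rest of the file stop enumerating quantized types by hand; later hunks in this diff rewrite checks of the form below, so q4_2 and q4_3 (and any future type flagged in GGML_IS_QUANTIZED) are picked up automatically. Fragment paraphrasing the dup changes further down:

    // before: every new quantized type had to be added to each check
    //   if (dst->type == GGML_TYPE_Q4_0 || dst->type == GGML_TYPE_Q4_1) { ... }

    // after: one table lookup covers all quantized types
    if (ggml_is_quantized(dst->type)) {
        quantize_row_q_t const quantize_row_q = quantize_fns[dst->type].quantize_row_q;
        // ... quantize the converted row into dst ...
    }
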
3498
3611
  static inline bool ggml_is_transposed(const struct ggml_tensor * tensor) {
3499
3612
  return tensor->nb[0] > tensor->nb[1];
3500
3613
  }
@@ -3605,6 +3718,11 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
3605
3718
  GGML_PRINT_DEBUG("%s: g_state initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f);
3606
3719
  }
3607
3720
 
3721
+ // initialize cuBLAS
3722
+ #if defined(GGML_USE_CUBLAS)
3723
+ init_cublas();
3724
+ #endif
3725
+
3608
3726
  is_first_call = false;
3609
3727
  }
3610
3728
 
@@ -5535,7 +5653,6 @@ static void ggml_compute_forward_dup_f16(
5535
5653
  const struct ggml_compute_params * params,
5536
5654
  const struct ggml_tensor * src0,
5537
5655
  struct ggml_tensor * dst) {
5538
- GGML_ASSERT(params->ith == 0);
5539
5656
  GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
5540
5657
 
5541
5658
  if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
@@ -5547,6 +5664,11 @@ static void ggml_compute_forward_dup_f16(
5547
5664
  const int64_t ne02 = src0->ne[2];
5548
5665
  const int64_t ne03 = src0->ne[3];
5549
5666
 
5667
+ const int64_t ne0 = dst->ne[0];
5668
+ const int64_t ne1 = dst->ne[1];
5669
+ const int64_t ne2 = dst->ne[2];
5670
+ const int64_t ne3 = dst->ne[3];
5671
+
5550
5672
  const size_t nb00 = src0->nb[0];
5551
5673
  const size_t nb01 = src0->nb[1];
5552
5674
  const size_t nb02 = src0->nb[2];
@@ -5557,19 +5679,40 @@ static void ggml_compute_forward_dup_f16(
5557
5679
  const size_t nb2 = dst->nb[2];
5558
5680
  const size_t nb3 = dst->nb[3];
5559
5681
 
5682
+ const int ith = params->ith; // thread index
5683
+ const int nth = params->nth; // number of threads
5684
+
5560
5685
  if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst) && src0->type == dst->type) {
5561
- memcpy(dst->data, src0->data, ggml_nelements(dst) * GGML_TYPE_SIZE[src0->type]);
5686
+ // parallelize by elements
5687
+ const int ne = ggml_nelements(dst);
5688
+ const int dr = (ne + nth - 1) / nth;
5689
+ const int ie0 = dr * ith;
5690
+ const int ie1 = MIN(ie0 + dr, ne);
5691
+
5692
+ memcpy(
5693
+ ((char *) dst->data + ie0*nb0),
5694
+ ((char *) src0->data + ie0*nb00),
5695
+ (ie1 - ie0) * GGML_TYPE_SIZE[src0->type]);
5696
+
5562
5697
  return;
5563
5698
  }
5564
5699
 
5700
+ // parallelize by rows
5701
+ const int nr = ne01;
5702
+ // number of rows per thread
5703
+ const int dr = (nr + nth - 1) / nth;
5704
+ // row range for this thread
5705
+ const int ir0 = dr * ith;
5706
+ const int ir1 = MIN(ir0 + dr, nr);
5707
+
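
The dup kernels are now multi-threaded with ggml's usual ceil-divide split: dr = (nr + nth - 1)/nth rows per thread, thread ith taking the half-open range [ir0, ir1). The ranges tile all nr rows without overlap, and trailing threads may end up with an empty range when nth does not divide nr. A small standalone check of that property (worked example: nr = 10, nth = 4 gives dr = 3 and ranges [0,3), [3,6), [6,9), [9,10)):

    #include <assert.h>

    #define MIN(a, b) ((a) < (b) ? (a) : (b))

    int main(void) {
        const int nr  = 10; // rows (example values)
        const int nth = 4;  // threads
        const int dr  = (nr + nth - 1)/nth;

        int covered = 0;
        for (int ith = 0; ith < nth; ith++) {
            const int ir0 = dr*ith;
            const int ir1 = MIN(ir0 + dr, nr);
            if (ir1 > ir0) {
                covered += ir1 - ir0; // this thread's share
            }
        }
        assert(covered == nr); // every row handled exactly once
        return 0;
    }
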
5565
5708
  if (src0->type == dst->type &&
5566
- src0->ne[0] == dst->ne[0] &&
5567
- src0->nb[0] == GGML_TYPE_SIZE[src0->type] && dst->nb[0] == GGML_TYPE_SIZE[dst->type]) {
5709
+ ne00 == ne0 &&
5710
+ nb00 == GGML_TYPE_SIZE[src0->type] && nb0 == GGML_TYPE_SIZE[dst->type]) {
5568
5711
  // copy by rows
5569
5712
  const size_t rs = ne00*nb00;
5570
5713
  for (int64_t i03 = 0; i03 < ne03; i03++) {
5571
5714
  for (int64_t i02 = 0; i02 < ne02; i02++) {
5572
- for (int64_t i01 = 0; i01 < ne01; i01++) {
5715
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
5573
5716
  memcpy(
5574
5717
  ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3),
5575
5718
  ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
@@ -5583,21 +5726,21 @@ static void ggml_compute_forward_dup_f16(
5583
5726
  // TODO: add more special-case implementations for tensor shapes/strides that can benefit from memcpy
5584
5727
 
5585
5728
  if (ggml_is_contiguous(dst)) {
5586
- if (src0->nb[0] == sizeof(ggml_fp16_t)) {
5729
+ if (nb00 == sizeof(ggml_fp16_t)) {
5587
5730
  if (dst->type == GGML_TYPE_F16) {
5588
5731
  size_t id = 0;
5589
- const size_t rs = ne00*nb00;
5732
+ const size_t rs = ne00 * nb00;
5733
+ char * dst_ptr = (char *) dst->data;
5590
5734
 
5591
5735
  for (int i03 = 0; i03 < ne03; i03++) {
5592
5736
  for (int i02 = 0; i02 < ne02; i02++) {
5593
- for (int i01 = 0; i01 < ne01; i01++) {
5737
+ id += rs * ir0;
5738
+ for (int i01 = ir0; i01 < ir1; i01++) {
5594
5739
  const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
5595
- char * dst_ptr = (char *) dst->data + id*rs;
5596
-
5597
- memcpy(dst_ptr, src0_ptr, rs);
5598
-
5599
- id++;
5740
+ memcpy(dst_ptr + id, src0_ptr, rs);
5741
+ id += rs;
5600
5742
  }
5743
+ id += rs * (ne01 - ir1);
5601
5744
  }
5602
5745
  }
5603
5746
  } else if (dst->type == GGML_TYPE_F32) {
@@ -5606,34 +5749,39 @@ static void ggml_compute_forward_dup_f16(
5606
5749
 
5607
5750
  for (int i03 = 0; i03 < ne03; i03++) {
5608
5751
  for (int i02 = 0; i02 < ne02; i02++) {
5609
- for (int i01 = 0; i01 < ne01; i01++) {
5752
+ id += ne00 * ir0;
5753
+ for (int i01 = ir0; i01 < ir1; i01++) {
5754
+ const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
5610
5755
  for (int i00 = 0; i00 < ne00; i00++) {
5611
- const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
5612
-
5613
- dst_ptr[id] = GGML_FP16_TO_FP32(*src0_ptr);
5756
+ dst_ptr[id] = GGML_FP16_TO_FP32(src0_ptr[i00]);
5614
5757
  id++;
5615
5758
  }
5616
5759
  }
5760
+ id += ne00 * (ne01 - ir1);
5617
5761
  }
5618
5762
  }
5619
- } else if (dst->type == GGML_TYPE_Q4_0 || dst->type == GGML_TYPE_Q4_1) {
5763
+ } else if (ggml_is_quantized(dst->type)) {
5620
5764
  quantize_row_q_t const quantize_row_q = quantize_fns[dst->type].quantize_row_q;
5765
+ float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
5766
+
5621
5767
  size_t id = 0;
5622
- uint8_t * dst_ptr = (uint8_t *) dst->data;
5623
- size_t dst_row_size = nb0 * (ne00 / GGML_BLCK_SIZE[dst->type]);
5624
- float * src0_f32 = (float *) params->wdata;
5768
+ size_t rs = nb0 * (ne00 / GGML_BLCK_SIZE[dst->type]);
5769
+ char * dst_ptr = (char *) dst->data;
5625
5770
 
5626
5771
  for (int i03 = 0; i03 < ne03; i03++) {
5627
5772
  for (int i02 = 0; i02 < ne02; i02++) {
5628
- for (int i01 = 0; i01 < ne01; i01++) {
5773
+ id += rs * ir0;
5774
+ for (int i01 = ir0; i01 < ir1; i01++) {
5629
5775
  const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
5630
- // convert to f32 and quantize
5776
+
5631
5777
  for (int i00 = 0; i00 < ne00; i00++) {
5632
5778
  src0_f32[i00] = GGML_FP16_TO_FP32(src0_ptr[i00]);
5633
5779
  }
5780
+
5634
5781
  quantize_row_q(src0_f32, dst_ptr + id, ne00);
5635
- id += dst_row_size;
5782
+ id += rs;
5636
5783
  }
5784
+ id += rs * (ne01 - ir1);
5637
5785
  }
5638
5786
  }
5639
5787
  } else {
@@ -5648,7 +5796,8 @@ static void ggml_compute_forward_dup_f16(
5648
5796
 
5649
5797
  for (int i03 = 0; i03 < ne03; i03++) {
5650
5798
  for (int i02 = 0; i02 < ne02; i02++) {
5651
- for (int i01 = 0; i01 < ne01; i01++) {
5799
+ id += ne00 * ir0;
5800
+ for (int i01 = ir0; i01 < ir1; i01++) {
5652
5801
  for (int i00 = 0; i00 < ne00; i00++) {
5653
5802
  const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
5654
5803
 
@@ -5656,6 +5805,7 @@ static void ggml_compute_forward_dup_f16(
5656
5805
  id++;
5657
5806
  }
5658
5807
  }
5808
+ id += ne00 * (ne01 - ir1);
5659
5809
  }
5660
5810
  }
5661
5811
  } else if (dst->type == GGML_TYPE_F16) {
@@ -5664,7 +5814,8 @@ static void ggml_compute_forward_dup_f16(
5664
5814
 
5665
5815
  for (int i03 = 0; i03 < ne03; i03++) {
5666
5816
  for (int i02 = 0; i02 < ne02; i02++) {
5667
- for (int i01 = 0; i01 < ne01; i01++) {
5817
+ id += ne00 * ir0;
5818
+ for (int i01 = ir0; i01 < ir1; i01++) {
5668
5819
  for (int i00 = 0; i00 < ne00; i00++) {
5669
5820
  const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
5670
5821
 
@@ -5672,6 +5823,7 @@ static void ggml_compute_forward_dup_f16(
5672
5823
  id++;
5673
5824
  }
5674
5825
  }
5826
+ id += ne00 * (ne01 - ir1);
5675
5827
  }
5676
5828
  }
5677
5829
  } else {
@@ -5690,7 +5842,20 @@ static void ggml_compute_forward_dup_f16(
5690
5842
  if (dst->type == GGML_TYPE_F16) {
5691
5843
  for (int64_t i03 = 0; i03 < ne03; i03++) {
5692
5844
  for (int64_t i02 = 0; i02 < ne02; i02++) {
5693
- for (int64_t i01 = 0; i01 < ne01; i01++) {
5845
+ i10 += ne00 * ir0;
5846
+ while (i10 >= ne0) {
5847
+ i10 -= ne0;
5848
+ if (++i11 == ne1) {
5849
+ i11 = 0;
5850
+ if (++i12 == ne2) {
5851
+ i12 = 0;
5852
+ if (++i13 == ne3) {
5853
+ i13 = 0;
5854
+ }
5855
+ }
5856
+ }
5857
+ }
5858
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
5694
5859
  for (int64_t i00 = 0; i00 < ne00; i00++) {
5695
5860
  const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
5696
5861
  char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
@@ -5711,25 +5876,51 @@ static void ggml_compute_forward_dup_f16(
5711
5876
  }
5712
5877
  }
5713
5878
  }
5879
+ i10 += ne00 * (ne01 - ir1);
5880
+ while (i10 >= ne0) {
5881
+ i10 -= ne0;
5882
+ if (++i11 == ne1) {
5883
+ i11 = 0;
5884
+ if (++i12 == ne2) {
5885
+ i12 = 0;
5886
+ if (++i13 == ne3) {
5887
+ i13 = 0;
5888
+ }
5889
+ }
5890
+ }
5891
+ }
5714
5892
  }
5715
5893
  }
5716
5894
  } else if (dst->type == GGML_TYPE_F32) {
5717
5895
  for (int64_t i03 = 0; i03 < ne03; i03++) {
5718
5896
  for (int64_t i02 = 0; i02 < ne02; i02++) {
5719
- for (int64_t i01 = 0; i01 < ne01; i01++) {
5897
+ i10 += ne00 * ir0;
5898
+ while (i10 >= ne0) {
5899
+ i10 -= ne0;
5900
+ if (++i11 == ne1) {
5901
+ i11 = 0;
5902
+ if (++i12 == ne2) {
5903
+ i12 = 0;
5904
+ if (++i13 == ne3) {
5905
+ i13 = 0;
5906
+ }
5907
+ }
5908
+ }
5909
+ }
5910
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
5720
5911
  for (int64_t i00 = 0; i00 < ne00; i00++) {
5721
5912
  const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
5722
5913
  char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
5723
5914
 
5724
5915
  *(float *) dst_ptr = GGML_FP16_TO_FP32(*(const ggml_fp16_t *) src0_ptr);
5725
5916
 
5726
- if (++i10 == ne00) {
5917
+ if (++i10 == ne0) {
5727
5918
  i10 = 0;
5728
- if (++i11 == ne01) {
5919
+ if (++i11 == ne1) {
5729
5920
  i11 = 0;
5730
- if (++i12 == ne02) {
5921
+ if (++i12 == ne2) {
5731
5922
  i12 = 0;
5732
- if (++i13 == ne03) {
5923
+ if (++i13 == ne3) {
5733
5924
  i13 = 0;
5734
5925
  }
5735
5926
  }
@@ -5737,6 +5928,19 @@ static void ggml_compute_forward_dup_f16(
5737
5928
  }
5738
5929
  }
5739
5930
  }
5931
+ i10 += ne00 * (ne01 - ir1);
5932
+ while (i10 >= ne0) {
5933
+ i10 -= ne0;
5934
+ if (++i11 == ne1) {
5935
+ i11 = 0;
5936
+ if (++i12 == ne2) {
5937
+ i12 = 0;
5938
+ if (++i13 == ne3) {
5939
+ i13 = 0;
5940
+ }
5941
+ }
5942
+ }
5943
+ }
5740
5944
  }
5741
5945
  }
5742
5946
  } else {
@@ -5748,7 +5952,6 @@ static void ggml_compute_forward_dup_f32(
5748
5952
  const struct ggml_compute_params * params,
5749
5953
  const struct ggml_tensor * src0,
5750
5954
  struct ggml_tensor * dst) {
5751
- GGML_ASSERT(params->ith == 0);
5752
5955
  GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
5753
5956
 
5754
5957
  if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
@@ -5760,6 +5963,11 @@ static void ggml_compute_forward_dup_f32(
5760
5963
  const int64_t ne02 = src0->ne[2];
5761
5964
  const int64_t ne03 = src0->ne[3];
5762
5965
 
5966
+ const int64_t ne0 = dst->ne[0];
5967
+ const int64_t ne1 = dst->ne[1];
5968
+ const int64_t ne2 = dst->ne[2];
5969
+ const int64_t ne3 = dst->ne[3];
5970
+
5763
5971
  const size_t nb00 = src0->nb[0];
5764
5972
  const size_t nb01 = src0->nb[1];
5765
5973
  const size_t nb02 = src0->nb[2];
@@ -5770,19 +5978,40 @@ static void ggml_compute_forward_dup_f32(
5770
5978
  const size_t nb2 = dst->nb[2];
5771
5979
  const size_t nb3 = dst->nb[3];
5772
5980
 
5981
+ const int ith = params->ith; // thread index
5982
+ const int nth = params->nth; // number of threads
5983
+
5773
5984
  if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst) && src0->type == dst->type) {
5774
- memcpy(dst->data, src0->data, ggml_nelements(dst) * GGML_TYPE_SIZE[src0->type]);
5985
+ // parallelize by elements
5986
+ const int ne = ggml_nelements(dst);
5987
+ const int dr = (ne + nth - 1) / nth;
5988
+ const int ie0 = dr * ith;
5989
+ const int ie1 = MIN(ie0 + dr, ne);
5990
+
5991
+ memcpy(
5992
+ ((char *) dst->data + ie0*nb0),
5993
+ ((char *) src0->data + ie0*nb00),
5994
+ (ie1 - ie0) * GGML_TYPE_SIZE[src0->type]);
5995
+
5775
5996
  return;
5776
5997
  }
5777
5998
 
5999
+ // parallelize by rows
6000
+ const int nr = ne01;
6001
+ // number of rows per thread
6002
+ const int dr = (nr + nth - 1) / nth;
6003
+ // row range for this thread
6004
+ const int ir0 = dr * ith;
6005
+ const int ir1 = MIN(ir0 + dr, nr);
6006
+
5778
6007
  if (src0->type == dst->type &&
5779
- src0->ne[0] == dst->ne[0] &&
5780
- src0->nb[0] == GGML_TYPE_SIZE[src0->type] && dst->nb[0] == GGML_TYPE_SIZE[dst->type]) {
6008
+ ne00 == ne0 &&
6009
+ nb00 == GGML_TYPE_SIZE[src0->type] && nb0 == GGML_TYPE_SIZE[dst->type]) {
5781
6010
  // copy by rows
5782
6011
  const size_t rs = ne00*nb00;
5783
6012
  for (int64_t i03 = 0; i03 < ne03; i03++) {
5784
6013
  for (int64_t i02 = 0; i02 < ne02; i02++) {
5785
- for (int64_t i01 = 0; i01 < ne01; i01++) {
6014
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
5786
6015
  memcpy(
5787
6016
  ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3),
5788
6017
  ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
@@ -5795,21 +6024,21 @@ static void ggml_compute_forward_dup_f32(
5795
6024
 
5796
6025
  if (ggml_is_contiguous(dst)) {
5797
6026
  // TODO: simplify
5798
- if (src0->nb[0] == sizeof(float)) {
6027
+ if (nb00 == sizeof(float)) {
5799
6028
  if (dst->type == GGML_TYPE_F32) {
5800
6029
  size_t id = 0;
5801
- const size_t rs = ne00*nb00;
6030
+ const size_t rs = ne00 * nb00;
6031
+ char * dst_ptr = (char *) dst->data;
5802
6032
 
5803
6033
  for (int i03 = 0; i03 < ne03; i03++) {
5804
6034
  for (int i02 = 0; i02 < ne02; i02++) {
5805
- for (int i01 = 0; i01 < ne01; i01++) {
6035
+ id += rs * ir0;
6036
+ for (int i01 = ir0; i01 < ir1; i01++) {
5806
6037
  const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
5807
- char * dst_ptr = (char *) dst->data + id*rs;
5808
-
5809
- memcpy(dst_ptr, src0_ptr, rs);
5810
-
5811
- id++;
6038
+ memcpy(dst_ptr + id, src0_ptr, rs);
6039
+ id += rs;
5812
6040
  }
6041
+ id += rs * (ne01 - ir1);
5813
6042
  }
5814
6043
  }
5815
6044
  } else if (dst->type == GGML_TYPE_F16) {
@@ -5818,7 +6047,8 @@ static void ggml_compute_forward_dup_f32(
5818
6047
 
5819
6048
  for (int i03 = 0; i03 < ne03; i03++) {
5820
6049
  for (int i02 = 0; i02 < ne02; i02++) {
5821
- for (int i01 = 0; i01 < ne01; i01++) {
6050
+ id += ne00 * ir0;
6051
+ for (int i01 = ir0; i01 < ir1; i01++) {
5822
6052
  for (int i00 = 0; i00 < ne00; i00++) {
5823
6053
  const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
5824
6054
 
@@ -5826,21 +6056,25 @@ static void ggml_compute_forward_dup_f32(
5826
6056
  id++;
5827
6057
  }
5828
6058
  }
6059
+ id += ne00 * (ne01 - ir1);
5829
6060
  }
5830
6061
  }
5831
- } else if (dst->type == GGML_TYPE_Q4_0 || dst->type == GGML_TYPE_Q4_1) {
6062
+ } else if (ggml_is_quantized(dst->type)) {
5832
6063
  quantize_row_q_t const quantize_row_q = quantize_fns[dst->type].quantize_row_q;
6064
+
5833
6065
  size_t id = 0;
5834
- uint8_t * dst_ptr = (uint8_t *) dst->data;
5835
- size_t dst_row_size = nb0 * (ne00 / GGML_BLCK_SIZE[dst->type]);
6066
+ size_t rs = nb0 * (ne00 / GGML_BLCK_SIZE[dst->type]);
6067
+ char * dst_ptr = (char *) dst->data;
5836
6068
 
5837
6069
  for (int i03 = 0; i03 < ne03; i03++) {
5838
6070
  for (int i02 = 0; i02 < ne02; i02++) {
5839
- for (int i01 = 0; i01 < ne01; i01++) {
6071
+ id += rs * ir0;
6072
+ for (int i01 = ir0; i01 < ir1; i01++) {
5840
6073
  const float * src0_ptr = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
5841
6074
  quantize_row_q(src0_ptr, dst_ptr + id, ne00);
5842
- id += dst_row_size;
6075
+ id += rs;
5843
6076
  }
6077
+ id += rs * (ne01 - ir1);
5844
6078
  }
5845
6079
  }
5846
6080
  } else {
@@ -5855,7 +6089,8 @@ static void ggml_compute_forward_dup_f32(
5855
6089
 
5856
6090
  for (int i03 = 0; i03 < ne03; i03++) {
5857
6091
  for (int i02 = 0; i02 < ne02; i02++) {
5858
- for (int i01 = 0; i01 < ne01; i01++) {
6092
+ id += ne00 * ir0;
6093
+ for (int i01 = ir0; i01 < ir1; i01++) {
5859
6094
  for (int i00 = 0; i00 < ne00; i00++) {
5860
6095
  const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
5861
6096
 
@@ -5863,6 +6098,7 @@ static void ggml_compute_forward_dup_f32(
5863
6098
  id++;
5864
6099
  }
5865
6100
  }
6101
+ id += ne00 * (ne01 - ir1);
5866
6102
  }
5867
6103
  }
5868
6104
  } else if (dst->type == GGML_TYPE_F16) {
@@ -5871,7 +6107,8 @@ static void ggml_compute_forward_dup_f32(
5871
6107
 
5872
6108
  for (int i03 = 0; i03 < ne03; i03++) {
5873
6109
  for (int i02 = 0; i02 < ne02; i02++) {
5874
- for (int i01 = 0; i01 < ne01; i01++) {
6110
+ id += ne00 * ir0;
6111
+ for (int i01 = ir0; i01 < ir1; i01++) {
5875
6112
  for (int i00 = 0; i00 < ne00; i00++) {
5876
6113
  const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
5877
6114
 
@@ -5879,6 +6116,7 @@ static void ggml_compute_forward_dup_f32(
5879
6116
  id++;
5880
6117
  }
5881
6118
  }
6119
+ id += ne00 * (ne01 - ir1);
5882
6120
  }
5883
6121
  }
5884
6122
  } else {
@@ -5890,6 +6128,7 @@ static void ggml_compute_forward_dup_f32(
5890
6128
  }
5891
6129
 
5892
6130
  // dst counters
6131
+
5893
6132
  int64_t i10 = 0;
5894
6133
  int64_t i11 = 0;
5895
6134
  int64_t i12 = 0;
@@ -5898,20 +6137,33 @@ static void ggml_compute_forward_dup_f32(
5898
6137
  if (dst->type == GGML_TYPE_F32) {
5899
6138
  for (int64_t i03 = 0; i03 < ne03; i03++) {
5900
6139
  for (int64_t i02 = 0; i02 < ne02; i02++) {
5901
- for (int64_t i01 = 0; i01 < ne01; i01++) {
6140
+ i10 += ne00 * ir0;
6141
+ while (i10 >= ne0) {
6142
+ i10 -= ne0;
6143
+ if (++i11 == ne1) {
6144
+ i11 = 0;
6145
+ if (++i12 == ne2) {
6146
+ i12 = 0;
6147
+ if (++i13 == ne3) {
6148
+ i13 = 0;
6149
+ }
6150
+ }
6151
+ }
6152
+ }
6153
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
5902
6154
  for (int64_t i00 = 0; i00 < ne00; i00++) {
5903
6155
  const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
5904
6156
  char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
5905
6157
 
5906
6158
  memcpy(dst_ptr, src0_ptr, sizeof(float));
5907
6159
 
5908
- if (++i10 == dst->ne[0]) {
6160
+ if (++i10 == ne0) {
5909
6161
  i10 = 0;
5910
- if (++i11 == dst->ne[1]) {
6162
+ if (++i11 == ne1) {
5911
6163
  i11 = 0;
5912
- if (++i12 == dst->ne[2]) {
6164
+ if (++i12 == ne2) {
5913
6165
  i12 = 0;
5914
- if (++i13 == dst->ne[3]) {
6166
+ if (++i13 == ne3) {
5915
6167
  i13 = 0;
5916
6168
  }
5917
6169
  }
@@ -5919,25 +6171,51 @@ static void ggml_compute_forward_dup_f32(
5919
6171
  }
5920
6172
  }
5921
6173
  }
6174
+ i10 += ne00 * (ne01 - ir1);
6175
+ while (i10 >= ne0) {
6176
+ i10 -= ne0;
6177
+ if (++i11 == ne1) {
6178
+ i11 = 0;
6179
+ if (++i12 == ne2) {
6180
+ i12 = 0;
6181
+ if (++i13 == ne3) {
6182
+ i13 = 0;
6183
+ }
6184
+ }
6185
+ }
6186
+ }
5922
6187
  }
5923
6188
  }
5924
6189
  } else if (dst->type == GGML_TYPE_F16) {
5925
6190
  for (int64_t i03 = 0; i03 < ne03; i03++) {
5926
6191
  for (int64_t i02 = 0; i02 < ne02; i02++) {
5927
- for (int64_t i01 = 0; i01 < ne01; i01++) {
6192
+ i10 += ne00 * ir0;
6193
+ while (i10 >= ne0) {
6194
+ i10 -= ne0;
6195
+ if (++i11 == ne1) {
6196
+ i11 = 0;
6197
+ if (++i12 == ne2) {
6198
+ i12 = 0;
6199
+ if (++i13 == ne3) {
6200
+ i13 = 0;
6201
+ }
6202
+ }
6203
+ }
6204
+ }
6205
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
5928
6206
  for (int64_t i00 = 0; i00 < ne00; i00++) {
5929
6207
  const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
5930
6208
  char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
5931
6209
 
5932
6210
  *(ggml_fp16_t *) dst_ptr = GGML_FP32_TO_FP16(*(const float *) src0_ptr);
5933
6211
 
5934
- if (++i10 == dst->ne[0]) {
6212
+ if (++i10 == ne0) {
5935
6213
  i10 = 0;
5936
- if (++i11 == dst->ne[1]) {
6214
+ if (++i11 == ne1) {
5937
6215
  i11 = 0;
5938
- if (++i12 == dst->ne[2]) {
6216
+ if (++i12 == ne2) {
5939
6217
  i12 = 0;
5940
- if (++i13 == dst->ne[3]) {
6218
+ if (++i13 == ne3) {
5941
6219
  i13 = 0;
5942
6220
  }
5943
6221
  }
@@ -5945,6 +6223,19 @@ static void ggml_compute_forward_dup_f32(
5945
6223
  }
5946
6224
  }
5947
6225
  }
6226
+ i10 += ne00 * (ne01 - ir1);
6227
+ while (i10 >= ne0) {
6228
+ i10 -= ne0;
6229
+ if (++i11 == ne1) {
6230
+ i11 = 0;
6231
+ if (++i12 == ne2) {
6232
+ i12 = 0;
6233
+ if (++i13 == ne3) {
6234
+ i13 = 0;
6235
+ }
6236
+ }
6237
+ }
6238
+ }
5948
6239
  }
5949
6240
  }
5950
6241
  } else {
@@ -6191,7 +6482,7 @@ static void ggml_compute_forward_add_q_f32(
6191
6482
  GGML_ASSERT(nb1 <= nb2);
6192
6483
  GGML_ASSERT(nb2 <= nb3);
6193
6484
 
6194
- GGML_ASSERT(src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1);
6485
+ GGML_ASSERT(ggml_is_quantized(src0->type));
6195
6486
  GGML_ASSERT(dst->type == src0->type);
6196
6487
  GGML_ASSERT(src1->type == GGML_TYPE_F32);
6197
6488
 
@@ -6205,7 +6496,7 @@ static void ggml_compute_forward_add_q_f32(
6205
6496
  const int ir0 = dr*ith;
6206
6497
  const int ir1 = MIN(ir0 + dr, nr);
6207
6498
 
6208
- float * wdata = (float*) params->wdata + ne00 * ith;
6499
+ float * wdata = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
6209
6500
 
6210
6501
  for (int ir = ir0; ir < ir1; ++ir) {
6211
6502
  // src0 indices
@@ -6261,6 +6552,8 @@ static void ggml_compute_forward_add(
6261
6552
  } break;
6262
6553
  case GGML_TYPE_Q4_0:
6263
6554
  case GGML_TYPE_Q4_1:
6555
+ case GGML_TYPE_Q4_2:
6556
+ case GGML_TYPE_Q4_3:
6264
6557
  {
6265
6558
  ggml_compute_forward_add_q_f32(params, src0, src1, dst);
6266
6559
  } break;
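
Adding Q4_2/Q4_3 to this switch reuses the generic ggml_compute_forward_add_q_f32 path (dequantize the row, add, requantize). The other change in that function is that each thread's scratch row inside params->wdata is now padded by CACHE_LINE_SIZE_F32 floats, so neighbouring threads never write into the same cache line. Restating the layout from the hunk above:

    // per-thread scratch rows inside params->wdata:
    //
    //   thread 0: [ ne00 floats ][ pad ]
    //   thread 1:                       [ ne00 floats ][ pad ] ...
    //
    // the CACHE_LINE_SIZE_F32 pad keeps the rows on separate cache
    // lines and avoids false sharing between threads
    float * wdata = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
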
@@ -7161,7 +7454,7 @@ static void ggml_compute_forward_rms_norm(
7161
7454
 
7162
7455
  // ggml_compute_forward_mul_mat
7163
7456
 
7164
- #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
7457
+ #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
7165
7458
  // helper function to determine if it is better to use BLAS or not
7166
7459
  // for large matrices, BLAS is faster
7167
7460
  static bool ggml_compute_forward_mul_mat_use_blas(
@@ -7201,7 +7494,7 @@ static void ggml_compute_forward_mul_mat_f32(
7201
7494
  const int64_t ne02 = src0->ne[2];
7202
7495
  const int64_t ne03 = src0->ne[3];
7203
7496
 
7204
- #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
7497
+ #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
7205
7498
  const int64_t ne10 = src1->ne[0];
7206
7499
  #endif
7207
7500
  const int64_t ne11 = src1->ne[1];
@@ -7258,7 +7551,7 @@ static void ggml_compute_forward_mul_mat_f32(
7258
7551
  // nb01 >= nb00 - src0 is not transposed
7259
7552
  // compute by src0 rows
7260
7553
 
7261
- #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
7554
+ #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
7262
7555
  if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
7263
7556
  if (params->ith != 0) {
7264
7557
  return;
@@ -7272,6 +7565,21 @@ static void ggml_compute_forward_mul_mat_f32(
7272
7565
  return;
7273
7566
  }
7274
7567
 
7568
+ #if defined(GGML_USE_CUBLAS)
7569
+ float *d_X = NULL;
7570
+ float *d_Y = NULL;
7571
+ float *d_D = NULL;
7572
+ const float alpha = 1.0f;
7573
+ const float beta = 0.0f;
7574
+ const int x_ne = ne01 * ne10;
7575
+ const int y_ne = ne11 * ne10;
7576
+ const int d_ne = ne11 * ne01;
7577
+
7578
+ CUDA_CHECK(cudaMalloc((void **)(&d_X), sizeof(float) * x_ne));
7579
+ CUDA_CHECK(cudaMalloc((void **)(&d_Y), sizeof(float) * y_ne));
7580
+ CUDA_CHECK(cudaMalloc((void **)(&d_D), sizeof(float) * d_ne));
7581
+ #endif
7582
+
7275
7583
  for (int64_t i03 = 0; i03 < ne03; i03++) {
7276
7584
  for (int64_t i02 = 0; i02 < ne02; i02++) {
7277
7585
  const float * x = (float *) ((char *) src0->data + i02*nb02 + i03*nb03);
@@ -7279,15 +7587,37 @@ static void ggml_compute_forward_mul_mat_f32(
7279
7587
 
7280
7588
  float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
7281
7589
 
7590
+ #if defined(GGML_USE_CUBLAS)
7591
+ // copy data to device
7592
+ CUDA_CHECK(cudaMemcpyAsync(d_X, x, sizeof(float) * x_ne, cudaMemcpyHostToDevice, cudaStream));
7593
+ CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(float) * y_ne, cudaMemcpyHostToDevice, cudaStream));
7594
+
7595
+ // compute
7596
+ CUBLAS_CHECK(
7597
+ cublasSgemm(cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
7598
+ ne01, ne11, ne10,
7599
+ &alpha, d_X, ne00,
7600
+ d_Y, ne10,
7601
+ &beta, d_D, ne01));
7602
+
7603
+ // copy data to host
7604
+ CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
7605
+ #else
7282
7606
  // zT = y * xT
7283
7607
  cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
7284
7608
  ne11, ne01, ne10,
7285
7609
  1.0f, y, ne10,
7286
7610
  x, ne00,
7287
7611
  0.0f, d, ne01);
7612
+ #endif
7288
7613
  }
7289
7614
  }
7290
-
7615
+ #if defined(GGML_USE_CUBLAS)
7616
+ CUDA_CHECK(cudaStreamSynchronize(cudaStream));
7617
+ CUDA_CHECK(cudaFree(d_X));
7618
+ CUDA_CHECK(cudaFree(d_Y));
7619
+ CUDA_CHECK(cudaFree(d_D));
7620
+ #endif
7291
7621
  //printf("CBLAS F32 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);
7292
7622
 
7293
7623
  return;
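
The cuBLAS branch above computes the same product as the CPU BLAS call, just spelled for a column-major library: passing CUBLAS_OP_T for X and CUBLAS_OP_N for Y with an ne01-row output is the column-major form of the row-major "zT = y * xT", and the bytes that land in d are identical either way. Device buffers are allocated once around the whole i02/i03 loop, every copy and the GEMM go onto the single non-blocking cudaStream, and one cudaStreamSynchronize before the frees drains it. Side by side, with arguments as in the hunk above:

    // CPU (row-major):    d[ne11][ne01] = y[ne11][ne10] * x[ne01][ne10]^T
    //   cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
    //               ne11, ne01, ne10, 1.0f, y, ne10, x, ne00, 0.0f, d, ne01);
    //
    // GPU (column-major):  D(ne01 x ne11) = X^T(ne01 x ne10) * Y(ne10 x ne11)
    //   cublasSgemm(cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
    //               ne01, ne11, ne10, &alpha, d_X, ne00, d_Y, ne10, &beta, d_D, ne01);
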
@@ -7417,7 +7747,7 @@ static void ggml_compute_forward_mul_mat_f16_f32(
7417
7747
  // nb01 >= nb00 - src0 is not transposed
7418
7748
  // compute by src0 rows
7419
7749
 
7420
- #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
7750
+ #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
7421
7751
  if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
7422
7752
  GGML_ASSERT(nb10 == sizeof(float));
7423
7753
 
@@ -7433,10 +7763,37 @@ static void ggml_compute_forward_mul_mat_f16_f32(
7433
7763
  return;
7434
7764
  }
7435
7765
 
7436
- float * const wdata = params->wdata;
7766
+ #if defined(GGML_USE_CUBLAS)
7767
+ ggml_fp16_t * const wdata = params->wdata;
7437
7768
 
7769
+ float *d_X = NULL;
7770
+ float *d_Y = NULL;
7771
+ float *d_D = NULL;
7772
+ const float alpha = 1.0f;
7773
+ const float beta = 0.0f;
7774
+ const int x_ne = ne01 * ne10;
7775
+ const int y_ne = ne11 * ne10;
7776
+ const int d_ne = ne11 * ne01;
7777
+
7778
+ CUDA_CHECK(cudaMalloc((void **)(&d_X), sizeof(ggml_fp16_t) * x_ne));
7779
+ CUDA_CHECK(cudaMalloc((void **)(&d_Y), sizeof(float) * y_ne));
7780
+ CUDA_CHECK(cudaMalloc((void **)(&d_D), sizeof(float) * d_ne));
7781
+ #else
7782
+ float * const wdata = params->wdata;
7783
+ #endif
7438
7784
  for (int64_t i03 = 0; i03 < ne03; i03++) {
7439
7785
  for (int64_t i02 = 0; i02 < ne02; i02++) {
7786
+ #if defined(GGML_USE_CUBLAS)
7787
+ // with cuBlAS, instead of converting src0 to fp32, we convert src1 to fp16
7788
+ {
7789
+ size_t id = 0;
7790
+ for (int64_t i01 = 0; i01 < ne11; ++i01) {
7791
+ for (int64_t i00 = 0; i00 < ne10; ++i00) {
7792
+ wdata[id++] = GGML_FP32_TO_FP16(*(float *) ((char *) src1->data + i03*nb13 + i02*nb12 + i01*nb11 + i00*nb10));
7793
+ }
7794
+ }
7795
+ }
7796
+ #else
7440
7797
  {
7441
7798
  size_t id = 0;
7442
7799
  for (int64_t i01 = 0; i01 < ne01; ++i01) {
@@ -7445,7 +7802,31 @@ static void ggml_compute_forward_mul_mat_f16_f32(
7445
7802
  }
7446
7803
  }
7447
7804
  }
7805
+ #endif
7806
+
7807
+ #if defined(GGML_USE_CUBLAS)
7808
+ const ggml_fp16_t * x = (ggml_fp16_t *) ((char *) src0->data + i02*nb02 + i03*nb03);
7809
+ const ggml_fp16_t * y = (ggml_fp16_t *) wdata;
7810
+
7811
+ float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
7448
7812
 
7813
+ // copy data to device
7814
+ CUDA_CHECK(cudaMemcpyAsync(d_X, x, sizeof(ggml_fp16_t) * x_ne, cudaMemcpyHostToDevice, cudaStream));
7815
+ CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(ggml_fp16_t) * y_ne, cudaMemcpyHostToDevice, cudaStream));
7816
+
7817
+ // compute
7818
+ CUBLAS_CHECK(
7819
+ cublasGemmEx(cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
7820
+ ne01, ne11, ne10,
7821
+ &alpha, d_X, CUDA_R_16F, ne00,
7822
+ d_Y, CUDA_R_16F, ne10,
7823
+ &beta, d_D, CUDA_R_32F, ne01,
7824
+ CUBLAS_COMPUTE_32F,
7825
+ CUBLAS_GEMM_DEFAULT));
7826
+
7827
+ // copy data to host
7828
+ CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
7829
+ #else
7449
7830
  const float * x = wdata;
7450
7831
  const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
7451
7832
 
@@ -7457,9 +7838,16 @@ static void ggml_compute_forward_mul_mat_f16_f32(
7457
7838
  1.0f, y, ne10,
7458
7839
  x, ne00,
7459
7840
  0.0f, d, ne01);
7841
+ #endif
7460
7842
  }
7461
7843
  }
7462
7844
 
7845
+ #if defined(GGML_USE_CUBLAS)
7846
+ CUDA_CHECK(cudaStreamSynchronize(cudaStream));
7847
+ CUDA_CHECK(cudaFree(d_X));
7848
+ CUDA_CHECK(cudaFree(d_Y));
7849
+ CUDA_CHECK(cudaFree(d_D));
7850
+ #endif
7463
7851
  /*printf("CBLAS F16 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);*/
7464
7852
 
7465
7853
  return;
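
For the f16 x f32 case the conversion direction flips: the CPU BLAS path widens src0 to f32 in wdata, while the cuBLAS path narrows src1 to fp16, uploads both operands as fp16, and calls cublasGemmEx with CUDA_R_16F inputs and CUBLAS_COMPUTE_32F accumulation. Since src0 (the weights) is normally the much larger operand, this keeps the host-side conversion on the small operand and moves the big matrix across the bus at fp16 width rather than widening it first. A rough, purely illustrative size check with hypothetical shapes:

    #include <stdio.h>

    int main(void) {
        // hypothetical shapes, for illustration only:
        // ne01 x ne00 fp16 weights, ne11 x ne10 f32 activations, ne10 == ne00
        const long ne00 = 4096, ne01 = 4096, ne10 = 4096, ne11 = 512;

        const long weights_fp16_bytes = 2*ne01*ne00; // uploaded as-is:         32 MiB
        const long weights_f32_bytes  = 4*ne01*ne00; // if widened to f32 first: 64 MiB
        const long gpu_path_converts  = ne11*ne10;   // src1 f32 -> fp16:  ~2M values
        const long cpu_path_converts  = ne01*ne00;   // src0 f16 -> f32:  ~16M values

        printf("upload %ld vs %ld bytes, convert %ld vs %ld values\n",
               weights_fp16_bytes, weights_f32_bytes,
               gpu_path_converts, cpu_path_converts);
        return 0;
    }
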
@@ -7611,7 +7999,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
7611
7999
  // nb01 >= nb00 - src0 is not transposed
7612
8000
  // compute by src0 rows
7613
8001
 
7614
- #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
8002
+ #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
7615
8003
  if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
7616
8004
  if (params->ith != 0) {
7617
8005
  return;
@@ -7625,11 +8013,55 @@ static void ggml_compute_forward_mul_mat_q_f32(
7625
8013
  return;
7626
8014
  }
7627
8015
 
8016
+ #if defined(GGML_USE_CUBLAS)
8017
+ float *d_X = NULL;
8018
+ float *d_Y = NULL;
8019
+ float *d_D = NULL;
8020
+ float *d_Q = NULL;
8021
+ const float alpha = 1.0f;
8022
+ const float beta = 0.0f;
8023
+ const int x_ne = ne01 * ne10;
8024
+ const int y_ne = ne11 * ne10;
8025
+ const int d_ne = ne11 * ne01;
8026
+
8027
+ CUDA_CHECK(cudaMalloc((void **)(&d_X), sizeof(float) * x_ne));
8028
+ CUDA_CHECK(cudaMalloc((void **)(&d_Y), sizeof(float) * y_ne));
8029
+ CUDA_CHECK(cudaMalloc((void **)(&d_D), sizeof(float) * d_ne));
8030
+ CUDA_CHECK(cudaMalloc((void **)(&d_Q), GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type]));
8031
+
8032
+ void (*dequantize_row_q_cuda)(const void * x, float * y, int k, cudaStream_t stream) = NULL;
8033
+ if (type == GGML_TYPE_Q4_0) {
8034
+ dequantize_row_q_cuda = dequantize_row_q4_0_cuda;
8035
+ }
8036
+ else if (type == GGML_TYPE_Q4_1) {
8037
+ dequantize_row_q_cuda = dequantize_row_q4_1_cuda;
8038
+ }
8039
+ else if (type == GGML_TYPE_Q4_2) {
8040
+ dequantize_row_q_cuda = dequantize_row_q4_2_cuda;
8041
+ }
8042
+ else {
8043
+ GGML_ASSERT(false);
8044
+ }
8045
+ #else
7628
8046
  float * const wdata = params->wdata;
7629
8047
  dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q;
8048
+ #endif
7630
8049
 
7631
8050
  for (int64_t i03 = 0; i03 < ne03; i03++) {
7632
8051
  for (int64_t i02 = 0; i02 < ne02; i02++) {
8052
+ const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
8053
+
8054
+ float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
8055
+
8056
+ #if defined(GGML_USE_CUBLAS)
8057
+ // copy and dequantize on device
8058
+ CUDA_CHECK(
8059
+ cudaMemcpyAsync(d_Q, (char *) src0->data + i03*nb03 + i02*nb02,
8060
+ GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type], cudaMemcpyHostToDevice, cudaStream));
8061
+
8062
+ dequantize_row_q_cuda(d_Q, d_X, ne01 * ne00, cudaStream);
8063
+ CUDA_CHECK(cudaGetLastError());
8064
+ #else
7633
8065
  {
7634
8066
  size_t id = 0;
7635
8067
  for (int64_t i01 = 0; i01 < ne01; ++i01) {
@@ -7637,21 +8069,42 @@ static void ggml_compute_forward_mul_mat_q_f32(
7637
8069
  id += ne00;
7638
8070
  }
7639
8071
  }
7640
-
7641
8072
  const float * x = wdata;
7642
- const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
8073
+ #endif
7643
8074
 
7644
- float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
7645
8075
 
8076
+ #if defined(GGML_USE_CUBLAS)
8077
+ // copy data to device
8078
+ CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(float) * y_ne, cudaMemcpyHostToDevice, cudaStream));
8079
+
8080
+ // compute
8081
+ CUBLAS_CHECK(
8082
+ cublasSgemm(cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
8083
+ ne01, ne11, ne10,
8084
+ &alpha, d_X, ne00,
8085
+ d_Y, ne10,
8086
+ &beta, d_D, ne01));
8087
+
8088
+ // copy data to host
8089
+ CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
8090
+ #else
7646
8091
  // zT = y * xT
7647
8092
  cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
7648
8093
  ne11, ne01, ne10,
7649
8094
  1.0f, y, ne10,
7650
8095
  x, ne00,
7651
8096
  0.0f, d, ne01);
8097
+ #endif
7652
8098
  }
7653
8099
  }
7654
8100
 
8101
+ #if defined(GGML_USE_CUBLAS)
8102
+ CUDA_CHECK(cudaStreamSynchronize(cudaStream));
8103
+ CUDA_CHECK(cudaFree(d_X));
8104
+ CUDA_CHECK(cudaFree(d_Y));
8105
+ CUDA_CHECK(cudaFree(d_D));
8106
+ CUDA_CHECK(cudaFree(d_Q));
8107
+ #endif
7655
8108
  //printf("CBLAS = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);
7656
8109
 
7657
8110
  return;
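For reference (not part of the diff): in the quantized path the raw block data is copied to the device (d_Q) and expanded to f32 there before the SGEMM. The scalar sketch below shows what dequantize_row_q4_0_cuda produces, assuming the Q4_0 block layout of this ggml version (one f32 scale followed by 32 packed 4-bit quants per block); the type and names are illustrative, not the library's.

#include <stdint.h>

#define QK4_0_REF 32  // block size assumed to match this ggml version

// Assumed Q4_0 block layout: one f32 scale plus 16 bytes holding 32 packed
// 4-bit quants (low nibble first). Illustrative type, not ggml's.
typedef struct {
    float   d;
    uint8_t qs[QK4_0_REF / 2];
} block_q4_0_ref;

// Scalar sketch of the device-side dequantization: each 4-bit value is
// shifted to the signed range [-8, 7] and scaled by the block's d.
static void dequantize_row_q4_0_ref(const block_q4_0_ref *x, float *y, int k) {
    const int nb = k / QK4_0_REF;
    for (int i = 0; i < nb; ++i) {
        const float d = x[i].d;
        for (int l = 0; l < QK4_0_REF; l += 2) {
            const uint8_t b = x[i].qs[l/2];
            y[i*QK4_0_REF + l + 0] = ((b & 0x0F) - 8) * d;
            y[i*QK4_0_REF + l + 1] = ((b >>   4) - 8) * d;
        }
    }
}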
@@ -7739,6 +8192,8 @@ static void ggml_compute_forward_mul_mat(
7739
8192
  switch (src0->type) {
7740
8193
  case GGML_TYPE_Q4_0:
7741
8194
  case GGML_TYPE_Q4_1:
8195
+ case GGML_TYPE_Q4_2:
8196
+ case GGML_TYPE_Q4_3:
7742
8197
  case GGML_TYPE_Q8_0:
7743
8198
  {
7744
8199
  ggml_compute_forward_mul_mat_q_f32(params, src0, src1, dst);
@@ -7756,34 +8211,6 @@ static void ggml_compute_forward_mul_mat(
7756
8211
  GGML_ASSERT(false);
7757
8212
  } break;
7758
8213
  }
7759
-
7760
- #if 0
7761
- if (src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_Q4_1) {
7762
- static int first = 8;
7763
- printf("src0: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", src0->ne[0], src0->ne[1], src0->ne[2]);
7764
- printf("src1: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", src1->ne[0], src1->ne[1], src1->ne[2]);
7765
- printf("dst: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", dst->ne[0], dst->ne[1], dst->ne[2]);
7766
- if (first) {
7767
- --first;
7768
- } else {
7769
- for (int k = 0; k < dst->ne[1]; ++k) {
7770
- for (int j = 0; j < dst->ne[0]/16; ++j) {
7771
- for (int i = 0; i < 16; ++i) {
7772
- printf("%8.4f ", ((float *) dst->data)[k*dst->ne[0] + j*16 + i]);
7773
- }
7774
- printf("\n");
7775
- }
7776
- printf("\n");
7777
- }
7778
- printf("\n");
7779
- exit(0);
7780
- }
7781
- } else {
7782
- printf("aaaa src0: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", src0->ne[0], src0->ne[1], src0->ne[2]);
7783
- printf("aaaa src1: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", src1->ne[0], src1->ne[1], src1->ne[2]);
7784
- printf("aaaa dst: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", dst->ne[0], dst->ne[1], dst->ne[2]);
7785
- }
7786
- #endif
7787
8214
  }
7788
8215
 
7789
8216
  // ggml_compute_forward_scale
@@ -7994,6 +8421,8 @@ static void ggml_compute_forward_get_rows(
7994
8421
  switch (src0->type) {
7995
8422
  case GGML_TYPE_Q4_0:
7996
8423
  case GGML_TYPE_Q4_1:
8424
+ case GGML_TYPE_Q4_2:
8425
+ case GGML_TYPE_Q4_3:
7997
8426
  case GGML_TYPE_Q8_0:
7998
8427
  {
7999
8428
  ggml_compute_forward_get_rows_q(params, src0, src1, dst);
@@ -8224,9 +8653,11 @@ static void ggml_compute_forward_rope_f32(
8224
8653
 
8225
8654
  const float theta_scale = powf(10000.0, -2.0f/n_dims);
8226
8655
 
8656
+ const bool is_neox = mode & 2;
8657
+
8227
8658
  for (int64_t i3 = 0; i3 < ne3; i3++) {
8228
- for (int64_t i2 = (mode == 0 ? 0 : n_past); i2 < ne2; i2++) {
8229
- const int p = (mode == 0 ? n_past + i2 : i2);
8659
+ for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) {
8660
+ const int p = ((mode & 1) == 0 ? n_past + i2 : i2);
8230
8661
  for (int64_t i1 = 0; i1 < ne1; i1++) {
8231
8662
  if (ir++ < ir0) continue;
8232
8663
  if (ir > ir1) break;
@@ -8239,14 +8670,25 @@ static void ggml_compute_forward_rope_f32(
8239
8670
 
8240
8671
  theta *= theta_scale;
8241
8672
 
8242
- const float * const src = (float *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
8243
- float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
8673
+ if (!is_neox) {
8674
+ const float * const src = (float *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
8675
+ float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
8244
8676
 
8245
- const float x0 = src[0];
8246
- const float x1 = src[1];
8677
+ const float x0 = src[0];
8678
+ const float x1 = src[1];
8247
8679
 
8248
- dst_data[0] = x0*cos_theta - x1*sin_theta;
8249
- dst_data[1] = x0*sin_theta + x1*cos_theta;
8680
+ dst_data[0] = x0*cos_theta - x1*sin_theta;
8681
+ dst_data[1] = x0*sin_theta + x1*cos_theta;
8682
+ } else {
8683
+ const float * const src = (float *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + (i0/2)*nb0);
8684
+ float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + (i0/2)*nb0);
8685
+
8686
+ const float x0 = src[0];
8687
+ const float x1 = src[n_dims/2];
8688
+
8689
+ dst_data[0] = x0*cos_theta - x1*sin_theta;
8690
+ dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
8691
+ }
8250
8692
  }
8251
8693
  }
8252
8694
  }
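For reference (not part of the diff): mode now carries two flags. Bit 0 selects how the position p is derived (as before), and bit 1 selects GPT-NeoX-style rotation, which pairs element i0/2 with the element half a head away instead of the adjacent element. A simplified sketch of the two pairings on one contiguous head of n_dims floats (position and stride handling reduced to the essentials):

#include <math.h>

// Simplified sketch of the two pairing schemes above, applied to one head of
// n_dims contiguous floats. Assumes n_dims is even.
static void rope_one_head_ref(float *x, int n_dims, int p, float theta_scale, int is_neox) {
    float theta = (float) p;
    for (int i0 = 0; i0 < n_dims; i0 += 2) {
        const float cos_theta = cosf(theta);
        const float sin_theta = sinf(theta);
        theta *= theta_scale;

        if (!is_neox) {
            // default: rotate the adjacent pair (i0, i0 + 1)
            const float x0 = x[i0 + 0];
            const float x1 = x[i0 + 1];
            x[i0 + 0] = x0*cos_theta - x1*sin_theta;
            x[i0 + 1] = x0*sin_theta + x1*cos_theta;
        } else {
            // GPT-NeoX style: rotate (i0/2, i0/2 + n_dims/2)
            const float x0 = x[i0/2];
            const float x1 = x[i0/2 + n_dims/2];
            x[i0/2]            = x0*cos_theta - x1*sin_theta;
            x[i0/2 + n_dims/2] = x0*sin_theta + x1*cos_theta;
        }
    }
}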
@@ -8301,9 +8743,11 @@ static void ggml_compute_forward_rope_f16(
8301
8743
 
8302
8744
  const float theta_scale = powf(10000.0, -2.0f/n_dims);
8303
8745
 
8746
+ const bool is_neox = mode & 2;
8747
+
8304
8748
  for (int64_t i3 = 0; i3 < ne3; i3++) {
8305
- for (int64_t i2 = (mode == 0 ? 0 : n_past); i2 < ne2; i2++) {
8306
- const int p = (mode == 0 ? n_past + i2 : i2);
8749
+ for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) {
8750
+ const int p = ((mode & 1) == 0 ? n_past + i2 : i2);
8307
8751
  for (int64_t i1 = 0; i1 < ne1; i1++) {
8308
8752
  if (ir++ < ir0) continue;
8309
8753
  if (ir > ir1) break;
@@ -8316,14 +8760,25 @@ static void ggml_compute_forward_rope_f16(
8316
8760
 
8317
8761
  theta *= theta_scale;
8318
8762
 
8319
- const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
8320
- ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
8763
+ if (!is_neox) {
8764
+ const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
8765
+ ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
8321
8766
 
8322
- const float x0 = GGML_FP16_TO_FP32(src[0]);
8323
- const float x1 = GGML_FP16_TO_FP32(src[1]);
8767
+ const float x0 = GGML_FP16_TO_FP32(src[0]);
8768
+ const float x1 = GGML_FP16_TO_FP32(src[1]);
8324
8769
 
8325
- dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
8326
- dst_data[1] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
8770
+ dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
8771
+ dst_data[1] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
8772
+ } else {
8773
+ const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + (i0/2)*nb0);
8774
+ ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + (i0/2)*nb0);
8775
+
8776
+ const float x0 = GGML_FP16_TO_FP32(src[0]);
8777
+ const float x1 = GGML_FP16_TO_FP32(src[n_dims/2]);
8778
+
8779
+ dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
8780
+ dst_data[n_dims/2] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
8781
+ }
8327
8782
  }
8328
8783
  }
8329
8784
  }
@@ -10402,11 +10857,11 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
10402
10857
  case GGML_OP_CPY:
10403
10858
  case GGML_OP_DUP:
10404
10859
  {
10405
- node->n_tasks = 1;
10860
+ node->n_tasks = n_threads;
10406
10861
 
10407
10862
  size_t cur = 0;
10408
- if (node->type == GGML_TYPE_Q4_0 || node->type == GGML_TYPE_Q4_1) {
10409
- cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->ne[0];
10863
+ if (ggml_is_quantized(node->type)) {
10864
+ cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->ne[0] * n_threads;
10410
10865
  }
10411
10866
 
10412
10867
  work_size = MAX(work_size, cur);
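For reference (not part of the diff): GGML_OP_CPY/GGML_OP_DUP are now multi-threaded, so when the destination is any quantized type each thread needs its own f32 row of scratch, which is why the reservation grows from one row to one row per thread. A minimal sketch of that sizing rule (illustrative helper, not a ggml function):

#include <stddef.h>

// Illustrative sizing rule matching the change above: a copy/dup into a
// quantized type runs on all threads, and each thread first converts one
// row to f32, so the shared work buffer holds one f32 row per thread.
static size_t cpy_work_size_bytes(size_t ne0, int n_threads, int dst_is_quantized) {
    return dst_is_quantized ? sizeof(float) * ne0 * (size_t) n_threads : 0;
}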
@@ -10417,7 +10872,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
10417
10872
 
10418
10873
  size_t cur = 0;
10419
10874
 
10420
- if (node->src0->type == GGML_TYPE_Q4_0 || node->src0->type == GGML_TYPE_Q4_1) {
10875
+ if (ggml_is_quantized(node->src0->type)) {
10421
10876
  cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->src0->ne[0] * n_threads;
10422
10877
  }
10423
10878
 
@@ -10466,7 +10921,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
10466
10921
  size_t cur = 0;
10467
10922
 
10468
10923
  if (node->src0->type == GGML_TYPE_F16 && node->src1->type == GGML_TYPE_F32) {
10469
- #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
10924
+ #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
10470
10925
  if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
10471
10926
  node->n_tasks = 1; // TODO: this actually is doing nothing
10472
10927
  // the threads are still spinning
@@ -10482,8 +10937,8 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
10482
10937
  #endif
10483
10938
  } else if (node->src0->type == GGML_TYPE_F32 && node->src1->type == GGML_TYPE_F32) {
10484
10939
  cur = 0;
10485
- } else if (quantize_fns[node->src0->type].vec_dot_q && node->src1->type == GGML_TYPE_F32) {
10486
- #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
10940
+ } else if (ggml_is_quantized(node->src0->type) && node->src1->type == GGML_TYPE_F32) {
10941
+ #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
10487
10942
  if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
10488
10943
  node->n_tasks = 1;
10489
10944
  cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]);
@@ -11709,6 +12164,86 @@ size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t *
11709
12164
  return (n/QK4_1*sizeof(block_q4_1));
11710
12165
  }
11711
12166
 
12167
+ size_t ggml_quantize_q4_2(const float * src, void * dst, int n, int k, int64_t * hist) {
12168
+ assert(k % QK4_2 == 0);
12169
+ const int nb = k / QK4_2;
12170
+
12171
+ for (int j = 0; j < n; j += k) {
12172
+ block_q4_2 * restrict y = (block_q4_2 *)dst + j/QK4_2;
12173
+
12174
+ //quantize_row_q4_2_reference(src + j, y, k);
12175
+ quantize_row_q4_2_rmse(src + j, y, k);
12176
+
12177
+ for (int i = 0; i < nb; i++) {
12178
+ for (int l = 0; l < QK4_2; l += 2) {
12179
+ const uint8_t vi0 = y[i].qs[l/2] & 0xF;
12180
+ const uint8_t vi1 = y[i].qs[l/2] >> 4;
12181
+
12182
+ hist[vi0]++;
12183
+ hist[vi1]++;
12184
+ }
12185
+ }
12186
+ }
12187
+
12188
+ return (n/QK4_2*sizeof(block_q4_2));
12189
+ }
12190
+
12191
+ size_t ggml_quantize_q4_3(const float * src, void * dst, int n, int k, int64_t * hist) {
12192
+ assert(k % QK4_3 == 0);
12193
+ const int nb = k / QK4_3;
12194
+
12195
+ for (int j = 0; j < n; j += k) {
12196
+ block_q4_3 * restrict y = (block_q4_3 *)dst + j/QK4_3;
12197
+
12198
+ quantize_row_q4_3_reference(src + j, y, k);
12199
+
12200
+ for (int i = 0; i < nb; i++) {
12201
+ for (int l = 0; l < QK4_3; l += 2) {
12202
+ const uint8_t vi0 = y[i].qs[l/2] & 0xF;
12203
+ const uint8_t vi1 = y[i].qs[l/2] >> 4;
12204
+
12205
+ hist[vi0]++;
12206
+ hist[vi1]++;
12207
+ }
12208
+ }
12209
+ }
12210
+
12211
+ return (n/QK4_3*sizeof(block_q4_3));
12212
+ }
12213
+
12214
+ size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, int start, int n, int64_t * hist) {
12215
+ size_t result = 0;
12216
+ switch (type) {
12217
+ case GGML_TYPE_Q4_0:
12218
+ {
12219
+ GGML_ASSERT(start % QK4_0 == 0);
12220
+ block_q4_0 * block = (block_q4_0*)dst + start / QK4_0;
12221
+ result = ggml_quantize_q4_0(src + start, block, n, n, hist);
12222
+ } break;
12223
+ case GGML_TYPE_Q4_1:
12224
+ {
12225
+ GGML_ASSERT(start % QK4_1 == 0);
12226
+ block_q4_1 * block = (block_q4_1*)dst + start / QK4_1;
12227
+ result = ggml_quantize_q4_1(src + start, block, n, n, hist);
12228
+ } break;
12229
+ case GGML_TYPE_Q4_2:
12230
+ {
12231
+ GGML_ASSERT(start % QK4_2 == 0);
12232
+ block_q4_2 * block = (block_q4_2*)dst + start / QK4_2;
12233
+ result = ggml_quantize_q4_2(src + start, block, n, n, hist);
12234
+ } break;
12235
+ case GGML_TYPE_Q4_3:
12236
+ {
12237
+ GGML_ASSERT(start % QK4_3 == 0);
12238
+ block_q4_3 * block = (block_q4_3*)dst + start / QK4_3;
12239
+ result = ggml_quantize_q4_3(src + start, block, n, n, hist);
12240
+ } break;
12241
+ default:
12242
+ assert(false);
12243
+ }
12244
+ return result;
12245
+ }
12246
+
11712
12247
  ////////////////////////////////////////////////////////////////////////////////
11713
12248
 
11714
12249
  int ggml_cpu_has_avx(void) {
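For reference (not part of the diff): ggml_quantize_chunk quantizes the element range [start, start + n) of a float buffer into the matching block offset of dst, so callers can split one tensor across several calls (e.g., one per thread). A usage sketch, assuming the function is declared in this package's vendored ggml.h and that both halves are aligned to the Q4_2 block size; the helper name and sizes are made up for illustration.

#include <stdint.h>
#include "ggml.h"   // assumed include path for the vendored ggml header

// Hypothetical helper: quantize an f32 buffer to Q4_2 in two chunks, the way
// a caller might split one tensor across threads. Assumes dst is large enough
// and that nelements/2 is a multiple of the Q4_2 block size.
size_t quantize_in_two_chunks(const float *src, void *dst, int nelements) {
    int64_t hist[16] = {0};                    // one bucket per 4-bit value
    const int half = nelements / 2;

    size_t written = 0;
    written += ggml_quantize_chunk(GGML_TYPE_Q4_2, src, dst, 0,    half, hist);
    written += ggml_quantize_chunk(GGML_TYPE_Q4_2, src, dst, half, half, hist);
    return written;                            // total bytes of quantized blocks
}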
@@ -11800,7 +12335,15 @@ int ggml_cpu_has_wasm_simd(void) {
11800
12335
  }
11801
12336
 
11802
12337
  int ggml_cpu_has_blas(void) {
11803
- #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
12338
+ #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
12339
+ return 1;
12340
+ #else
12341
+ return 0;
12342
+ #endif
12343
+ }
12344
+
12345
+ int ggml_cpu_has_cublas(void) {
12346
+ #if defined(GGML_USE_CUBLAS)
11804
12347
  return 1;
11805
12348
  #else
11806
12349
  return 0;