llama_cpp 0.0.5 → 0.0.6

@@ -19,6 +19,7 @@
  #include <inttypes.h>
  #include <stdio.h>
  #include <float.h>
+ #include <limits.h>

  // if C99 - static_assert is noop
  // ref: https://stackoverflow.com/a/53923785/4039976
@@ -142,10 +143,49 @@ inline static void* ggml_aligned_malloc(size_t size) {
  } \
  } while (0)

- #ifdef GGML_USE_ACCELERATE
+ #if defined(GGML_USE_ACCELERATE)
  #include <Accelerate/Accelerate.h>
- #elif GGML_USE_OPENBLAS
+ #elif defined(GGML_USE_OPENBLAS)
  #include <cblas.h>
+ #elif defined(GGML_USE_CUBLAS)
+ #include <cublas_v2.h>
+ #include <cuda_runtime.h>
+ #include "ggml-cuda.h"
+
+ #define CUDA_CHECK(err) \
+ do { \
+ cudaError_t err_ = (err); \
+ if (err_ != cudaSuccess) { \
+ printf("CUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__, \
+ cudaGetErrorString(err_)); \
+ exit(1); \
+ } \
+ } while (0)
+
+ #define CUBLAS_CHECK(err) \
+ do { \
+ cublasStatus_t err_ = (err); \
+ if (err_ != CUBLAS_STATUS_SUCCESS) { \
+ printf("cuBLAS error %d at %s:%d\n", err_, __FILE__, __LINE__); \
+ exit(1); \
+ } \
+ } while (0)
+
+ static cublasHandle_t cublasH = NULL;
+ static cudaStream_t cudaStream = NULL;
+ static void init_cublas(void) {
+ if (cublasH == NULL) {
+ // create cublas handle, bind a stream
+ CUBLAS_CHECK(cublasCreate(&cublasH));
+
+ CUDA_CHECK(cudaStreamCreateWithFlags(&cudaStream, cudaStreamNonBlocking));
+
+ CUBLAS_CHECK(cublasSetStream(cublasH, cudaStream));
+
+ // configure logging to stdout
+ // CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, NULL));
+ }
+ }
  #endif

  #undef MIN
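
The new cuBLAS path adds two error-checking macros and a lazily initialized global handle bound to a non-blocking stream. The GEMM call sites themselves are introduced in other hunks of this release; a minimal sketch of how these helpers are meant to be used (the wrapper function and device pointers below are illustrative assumptions, not code from the diff):

    // Sketch only: column-major C = A * B with all operands already on the GPU.
    static void sgemm_on_device(int m, int n, int k,
                                const float * d_A, const float * d_B, float * d_C) {
        init_cublas();  // creates cublasH / cudaStream on the first call only
        const float alpha = 1.0f, beta = 0.0f;
        CUBLAS_CHECK(cublasSgemm(cublasH, CUBLAS_OP_N, CUBLAS_OP_N,
                                 m, n, k,
                                 &alpha, d_A, m, d_B, k,
                                 &beta,  d_C, m));
        CUDA_CHECK(cudaStreamSynchronize(cudaStream));  // wait for the GEMM to finish
    }
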
@@ -427,12 +467,30 @@ static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
  // quantization
  //

- // AVX routines provided by GH user Const-me
- // ref: https://github.com/ggerganov/ggml/pull/27#issuecomment-1464934600
+ #if __AVX__ || __AVX2__ || __AVX512F__
+ // Unpack 16 4-bit fields into 16 bytes
+ // The output vector contains 16 bytes, each one in [ 0 .. 15 ] interval
+ static inline __m128i bytes_from_nibbles_16(const uint8_t * rsi)
+ {
+ // Load 8 bytes from memory
+ __m128i tmp = _mm_loadu_si64( ( const __m128i* )rsi );
+
+ // Expand bytes into uint16_t values
+ __m128i bytes = _mm_cvtepu8_epi16( tmp );
+
+ // Unpack values into individual bytes
+ const __m128i lowMask = _mm_set1_epi8( 0xF );
+ __m128i high = _mm_andnot_si128( lowMask, bytes );
+ __m128i low = _mm_and_si128( lowMask, bytes );
+ high = _mm_slli_epi16( high, 4 );
+ bytes = _mm_or_si128( low, high );
+ return bytes;
+ }
+
  #if __AVX2__ || __AVX512F__
  // Unpack 32 4-bit fields into 32 bytes
  // The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
- static inline __m256i bytesFromNibbles( const uint8_t* rsi )
+ static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)
  {
  // Load 16 bytes from memory
  __m128i tmp = _mm_loadu_si128( ( const __m128i* )rsi );
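
The renamed helpers also make the nibble order explicit: the low nibble of source byte k becomes output byte 2k and the high nibble becomes byte 2k+1, which is exactly how the quantizers pack pairs as vi0 | (vi1 << 4). A scalar equivalent for reference (a sketch, not part of the diff):

    // Scalar view of bytes_from_nibbles_16; bytes_from_nibbles_32 does the same
    // for 16 input bytes / 32 nibbles.
    static void bytes_from_nibbles_16_scalar(const uint8_t * in, uint8_t * out) {
        for (int k = 0; k < 8; ++k) {
            out[2*k + 0] = in[k] & 0x0F;  // low nibble  -> even output byte
            out[2*k + 1] = in[k] >> 4;    // high nibble -> odd output byte
        }
    }
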
@@ -463,24 +521,7 @@ static inline __m128i packNibbles( __m256i bytes )
  __m128i r1 = _mm256_extracti128_si256( bytes, 1 );
  return _mm_packus_epi16( r0, r1 );
  }
- #elif __AVX__
- static inline __m128i bytesFromNibbles( const uint8_t* rsi )
- {
- // Load 8 bytes from memory
- __m128i tmp = _mm_loadu_si64( ( const __m128i* )rsi );
-
- // Expand bytes into uint16_t values
- __m128i bytes = _mm_cvtepu8_epi16( tmp );
-
- // Unpack values into individual bytes
- const __m128i lowMask = _mm_set1_epi8( 0xF );
- __m128i high = _mm_andnot_si128( lowMask, bytes );
- __m128i low = _mm_and_si128( lowMask, bytes );
- high = _mm_slli_epi16( high, 4 );
- bytes = _mm_or_si128( low, high );
- return bytes;
- }
-
+ #else
  static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
  {
  // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
@@ -497,6 +538,7 @@ static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
  return _mm_packus_epi16( bytes1, bytes2);
  }
  #endif
+ #endif // __AVX__ || __AVX2__ || __AVX512F__

  #if __ARM_NEON

@@ -514,6 +556,18 @@ inline static uint16_t vaddvq_u8(uint8x16_t v) {
  (uint16_t)vgetq_lane_u8(v, 14) + (uint16_t)vgetq_lane_u8(v, 15);
  }

+ inline static int16_t vaddvq_s8(int8x16_t v) {
+ return
+ (int16_t)vgetq_lane_s8(v, 0) + (int16_t)vgetq_lane_s8(v, 1) +
+ (int16_t)vgetq_lane_s8(v, 2) + (int16_t)vgetq_lane_s8(v, 3) +
+ (int16_t)vgetq_lane_s8(v, 4) + (int16_t)vgetq_lane_s8(v, 5) +
+ (int16_t)vgetq_lane_s8(v, 6) + (int16_t)vgetq_lane_s8(v, 7) +
+ (int16_t)vgetq_lane_s8(v, 8) + (int16_t)vgetq_lane_s8(v, 9) +
+ (int16_t)vgetq_lane_s8(v, 10) + (int16_t)vgetq_lane_s8(v, 11) +
+ (int16_t)vgetq_lane_s8(v, 12) + (int16_t)vgetq_lane_s8(v, 13) +
+ (int16_t)vgetq_lane_s8(v, 14) + (int16_t)vgetq_lane_s8(v, 15);
+ }
+
  inline static int32_t vaddvq_s16(int16x8_t v) {
  return
  (int32_t)vgetq_lane_s16(v, 0) + (int32_t)vgetq_lane_s16(v, 1) +
@@ -583,7 +637,22 @@ typedef struct {
  float m; // min
  uint8_t qs[QK4_1 / 2]; // nibbles / quants
  } block_q4_1;
- static_assert(sizeof(block_q4_1) == sizeof(float) * 2 + QK4_1 / 2, "wrong q4_1 block size/padding");
+ static_assert(sizeof(block_q4_1) == 2 * sizeof(float) + QK4_1 / 2, "wrong q4_1 block size/padding");
+
+ #define QK4_2 16
+ typedef struct {
+ ggml_fp16_t d; // delta
+ uint8_t qs[QK4_2 / 2]; // nibbles / quants
+ } block_q4_2;
+ static_assert(sizeof(block_q4_2) == sizeof(ggml_fp16_t) + QK4_2 / 2, "wrong q4_2 block size/padding");
+
+ #define QK4_3 16
+ typedef struct {
+ ggml_fp16_t d; // delta
+ ggml_fp16_t m; // min
+ uint8_t qs[QK4_3 / 2]; // nibbles / quants
+ } block_q4_3;
+ static_assert(sizeof(block_q4_3) == 2 * sizeof(ggml_fp16_t) + QK4_3 / 2, "wrong q4_3 block size/padding");

  #define QK8_0 32
  typedef struct {
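
The storage cost per weight follows directly from the struct layouts (the static_asserts guarantee there is no padding):

    block_q4_0: 4 + 16 bytes per 32 weights   -> 5.0 bits/weight
    block_q4_1: 4 + 4 + 16 bytes per 32       -> 6.0 bits/weight
    block_q4_2: 2 + 8 bytes per 16 weights    -> 5.0 bits/weight
    block_q4_3: 2 + 2 + 8 bytes per 16        -> 6.0 bits/weight
    block_q8_0: 4 + 32 bytes per 32 weights   -> 9.0 bits/weight

So Q4_2 and Q4_3 match the bit rates of Q4_0 and Q4_1 while halving the block size and storing the scale (and minimum) as fp16 instead of fp32. The dequantization added later in this diff reconstructs v = d*(q - 8) for Q4_2 and v = d*q + m for Q4_3.
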
@@ -1045,6 +1114,173 @@ static void quantize_row_q4_1(const float * restrict x, void * restrict vy, int
  #endif
  }

+ // reference implementation for deterministic creation of model files
+ static void quantize_row_q4_2_reference(const float * restrict x, block_q4_2 * restrict y, int k) {
+ assert(k % QK4_2 == 0);
+
+ const int nb = k / QK4_2;
+
+ for (int i = 0; i < nb; i++) {
+ float amax = 0.0f; // absolute max
+
+ for (int l = 0; l < QK4_2; l++) {
+ const float v = x[i*QK4_2 + l];
+ amax = MAX(amax, fabsf(v));
+ }
+
+ const float d = amax / ((1 << 3) - 1);
+
+ const float id = d ? 1.0f/d : 0.0f;
+
+ y[i].d = GGML_FP32_TO_FP16(d);
+
+ for (int l = 0; l < QK4_2; l += 2) {
+ const float v0 = x[i*QK4_2 + l + 0]*id;
+ const float v1 = x[i*QK4_2 + l + 1]*id;
+
+ const uint8_t vi0 = (uint8_t)(v0 + 8.5f);
+ const uint8_t vi1 = (uint8_t)(v1 + 8.5f);
+
+ assert(vi0 < 16);
+ assert(vi1 < 16);
+
+ y[i].qs[l/2] = vi0 | (vi1 << 4);
+ }
+ }
+ }
+
+ static inline int nearest_int(float fval) {
+ assert(fval <= 4194303.f);
+ float val = fval + 12582912.f;
+ int i; memcpy(&i, &val, sizeof(int));
+ return (i & 0x007fffff) - 0x00400000;
+ }
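
    /*
     * Illustrative aside (not part of the diff): why nearest_int() works.
     * 12582912.0f is 1.5 * 2^23. For |fval| < 2^22 the sum lies in [2^23, 2^24),
     * where the spacing between floats is exactly 1, so the hardware's
     * round-to-nearest rounds fval to an integer during the addition and the low
     * 23 mantissa bits of the result hold round(fval) + 0x00400000 (i.e. + 2^22).
     * Masking and subtracting that bias recovers the rounded value without lrintf.
     */
    static void nearest_int_demo(void) { // assumes the nearest_int() defined above
        assert(nearest_int( 3.7f) ==  4);
        assert(nearest_int(-2.3f) == -2);
        assert(nearest_int( 0.0f) ==  0);
    }
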
+
+ static float kquantize_q4_with_bounds(int n, int nmin, int nmax, const float * restrict X, int nCandidates,
+ const float * restrict candidates, int8_t * restrict L) {
+ assert (nmin >= INT8_MIN);
+ assert (nmax <= INT8_MAX);
+ float amax = 0;
+ for (int i=0; i<n; ++i) amax = MAX(amax, fabsf(X[i]));
+ if (!amax) { // all zero
+ for (int i=0; i<n; ++i) L[i] = 0;
+ return 1.f;
+ }
+ float best = 0, bestScale = 0;
+ for (int si=0; si<nCandidates; ++si) {
+ float iscale = candidates[si]/amax;
+ float sumlxP = 0; int suml2P = 0;
+ float sumlxM = 0; int suml2M = 0;
+ for (int i=0; i<n; ++i) {
+ int l = nearest_int(iscale*X[i]);
+ int lp = MAX(nmin, MIN(nmax, +l));
+ int lm = MAX(nmin, MIN(nmax, -l));
+ sumlxP += X[i]*lp; suml2P += lp*lp;
+ sumlxM += X[i]*lm; suml2M += lm*lm;
+ }
+ float sumlxP2 = sumlxP*sumlxP;
+ float sumlxM2 = sumlxM*sumlxM;
+ if (sumlxP2*suml2M > sumlxM2*suml2P) {
+ if (sumlxP2 > best*suml2P) {
+ best = sumlxP2/suml2P; bestScale = iscale;
+ }
+ } else {
+ if (sumlxM2 > best*suml2M) {
+ best = sumlxM2/suml2M; bestScale = -iscale;
+ }
+ }
+ }
+ float sumlx = 0; int suml2 = 0;
+ for (int i=0; i<n; ++i) {
+ int l = nearest_int(bestScale*X[i]);
+ l = MAX(nmin, MIN(nmax, l));
+ sumlx += X[i]*l; suml2 += l*l;
+ L[i] = l;
+ }
+ float scale = sumlx/suml2;
+ return scale;
+ }
+
+ static void quantize_row_q4_2_rmse(const float * restrict x, block_q4_2 * restrict y, int k) {
+ #define CANDIDATE_COUNT 8
+ static const float candidates[CANDIDATE_COUNT] = { +8.7f, +8.3f, +8.1f, +7.8f, +7.3f, +7.0f, +6.3f, +5.7f };
+ assert(k % QK4_2 == 0);
+
+ int8_t L[QK4_2];
+
+ const int nb = k / QK4_2;
+
+ for (int i = 0; i < nb; i++) {
+ float scale = kquantize_q4_with_bounds(QK4_2, -8, 7, x, CANDIDATE_COUNT, candidates, L);
+ y[i].d = GGML_FP32_TO_FP16(scale);
+
+ for (int l = 0; l < QK4_2; l += 2) {
+ const uint8_t vi0 = (uint8_t)(L[l+0] + 8);
+ const uint8_t vi1 = (uint8_t)(L[l+1] + 8);
+
+ assert(vi0 < 16);
+ assert(vi1 < 16);
+
+ y[i].qs[l/2] = vi0 | (vi1 << 4);
+ }
+
+ x += QK4_2;
+ }
+ }
+
+ static void quantize_row_q4_2(const float * restrict x, void * restrict vy, int k) {
+ assert(k % QK4_2 == 0);
+
+ block_q4_2 * restrict y = vy;
+
+ //quantize_row_q4_2_reference(x, y, k);
+ // This produces the exact same format, just better match to the input floats ("better" as measured by RMSE)
+ quantize_row_q4_2_rmse(x, y, k);
+ }
+
+ static void quantize_row_q4_3_reference(const float * restrict x, block_q4_3 * restrict y, int k) {
+ assert(k % QK4_3 == 0);
+ const int nb = k / QK4_3;
+
+ for (int i = 0; i < nb; i++) {
+ float min = FLT_MAX;
+ float max = -FLT_MAX;
+
+ for (int l = 0; l < QK4_3; l++) {
+ const float v = x[i*QK4_3 + l];
+ if (v < min) min = v;
+ if (v > max) max = v;
+ }
+
+ const float d = (max - min) / ((1 << 4) - 1);
+ const float id = d ? 1.0f/d : 0.0f;
+
+ y[i].d = GGML_FP32_TO_FP16(d);
+ y[i].m = GGML_FP32_TO_FP16(min);
+
+ for (int l = 0; l < QK4_3; l += 2) {
+ const float v0 = (x[i*QK4_3 + l + 0] - min)*id;
+ const float v1 = (x[i*QK4_3 + l + 1] - min)*id;
+
+ const uint8_t vi0 = (int) (v0 + 0.5f);
+ const uint8_t vi1 = (int) (v1 + 0.5f);
+
+ assert(vi0 < 16);
+ assert(vi1 < 16);
+
+ y[i].qs[l/2] = vi0 | (vi1 << 4);
+ }
+ }
+ }
+
+ static void quantize_row_q4_3(const float * restrict x, void * restrict vy, int k) {
+ assert(k % QK4_3 == 0);
+
+ block_q4_3 * restrict y = vy;
+
+ quantize_row_q4_3_reference(x, y, k);
+ }
+
  // reference implementation for deterministic creation of model files
  static void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * restrict y, int k) {
  assert(k % QK8_0 == 0);
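
The candidate search in kquantize_q4_with_bounds() is per-block least squares: for a fixed assignment of integer levels L_i in [-8, 7], the scale that minimizes sum_i (x_i - d*L_i)^2 is d = (sum_i x_i*L_i) / (sum_i L_i^2), and the remaining error is sum_i x_i^2 - (sum_i x_i*L_i)^2 / (sum_i L_i^2). The function therefore maximizes (sum x*L)^2 / (sum L^2) over the eight candidate inverse scales (trying both signs of the scale), then recomputes the optimal d for the winning assignment. quantize_row_q4_2() uses this RMSE-optimized variant by default, and the dispatch table further below registers it as the reference quantizer as well.
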
@@ -1064,7 +1300,7 @@ static void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * r
  y[i].d = d;

  for (int l = 0; l < QK8_0; ++l) {
- const float v = x[i*QK8_0 + l]*id;
+ const float v = x[i*QK8_0 + l]*id;
  y[i].qs[l] = roundf(v);
  }
  }
@@ -1211,7 +1447,7 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in

  for (int l = 0; l < QK4_0; l += 32) {
  // Load 32x4-bit integers into 32x8-bit integers
- __m256i vx8 = bytesFromNibbles(pp+l/2);
+ __m256i vx8 = bytes_from_nibbles_32(pp+l/2);

  // Subtract 8 from the integers
  vx8 = _mm256_sub_epi8(vx8, _mm256_set1_epi8(8));
@@ -1329,7 +1565,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in

  for (int l = 0; l < QK4_1; l += 32) {
  // Load 32x4-bit integers into 32x8-bit integers
- __m256i vx8 = bytesFromNibbles(pp+l/2);
+ __m256i vx8 = bytes_from_nibbles_32(pp+l/2);

  // Convert to 16-bit int
  const __m256i vx16_lo = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vx8, 0));
@@ -1420,8 +1656,69 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
  #endif
  }

- static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
+ static void dequantize_row_q4_2(const void * restrict vx, float * restrict y, int k) {
+ assert(k % QK4_2 == 0);
+ const int nb = k / QK4_2;
+
+ const block_q4_2 * restrict x = vx;
+
+ for (int i = 0; i < nb; i++) {
+ const float d = GGML_FP16_TO_FP32(x[i].d);
+
+ const uint8_t * restrict pp = x[i].qs;
+
+ for (int l = 0; l < QK4_2; l += 2) {
+ const uint8_t vi = pp[l/2];
+
+ const int8_t vi0 = vi & 0xf;
+ const int8_t vi1 = vi >> 4;
+
+ const float v0 = (vi0 - 8)*d;
+ const float v1 = (vi1 - 8)*d;
+
+ y[i*QK4_2 + l + 0] = v0;
+ y[i*QK4_2 + l + 1] = v1;
+
+ assert(!isnan(y[i*QK4_2 + l + 0]));
+ assert(!isnan(y[i*QK4_2 + l + 1]));
+ }
+ }
+ }
+
+ static void dequantize_row_q4_3(const void * restrict vx, float * restrict y, int k) {
+ assert(k % QK4_3 == 0);
+ const int nb = k / QK4_3;
+
+ const block_q4_3 * restrict x = vx;
+
+ for (int i = 0; i < nb; i++) {
+ const float d = GGML_FP16_TO_FP32(x[i].d);
+ const float m = GGML_FP16_TO_FP32(x[i].m);
+
+ const uint8_t * restrict pp = x[i].qs;
+
+ for (int l = 0; l < QK4_3; l += 2) {
+ const uint8_t vi = pp[l/2];
+
+ const int8_t vi0 = vi & 0xf;
+ const int8_t vi1 = vi >> 4;
+
+ const float v0 = vi0*d + m;
+ const float v1 = vi1*d + m;
+
+ y[i*QK4_3 + l + 0] = v0;
+ y[i*QK4_3 + l + 1] = v1;
+
+ assert(!isnan(y[i*QK4_3 + l + 0]));
+ assert(!isnan(y[i*QK4_3 + l + 1]));
+ }
+ }
+ }
+
  static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
+ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
+ static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
+ static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);

  static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
  [GGML_TYPE_Q4_0] = {
@@ -1435,10 +1732,30 @@ static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
  .dequantize_row_q = dequantize_row_q4_1,
  .quantize_row_q = quantize_row_q4_1,
  .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_1_reference,
- .quantize_row_q_dot = quantize_row_q4_1,
- .vec_dot_q = ggml_vec_dot_q4_1,
+ .quantize_row_q_dot = quantize_row_q8_0,
+ .vec_dot_q = ggml_vec_dot_q4_1_q8_0,
+ },
+ [GGML_TYPE_Q4_2] = {
+ .dequantize_row_q = dequantize_row_q4_2,
+ .quantize_row_q = quantize_row_q4_2,
+ .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_2_rmse, //quantize_row_q4_2_reference,
+ .quantize_row_q_dot = quantize_row_q8_0,
+ .vec_dot_q = ggml_vec_dot_q4_2_q8_0,
+ },
+ [GGML_TYPE_Q4_3] = {
+ .dequantize_row_q = dequantize_row_q4_3,
+ .quantize_row_q = quantize_row_q4_3,
+ .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_3_reference, // TODO: RMSE optimization
+ .quantize_row_q_dot = quantize_row_q8_0,
+ .vec_dot_q = ggml_vec_dot_q4_3_q8_0,
+ },
+ [GGML_TYPE_Q8_0] = {
+ .dequantize_row_q = NULL, // TODO
+ .quantize_row_q = quantize_row_q8_0,
+ .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q8_0_reference,
+ .quantize_row_q_dot = quantize_row_q8_0,
+ .vec_dot_q = NULL, // TODO
  },
- // TODO: GGML_TYPE_Q8_0
  };

  // For internal test use
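
Every 4-bit entry now sets .quantize_row_q_dot = quantize_row_q8_0 and pairs it with a *_q8_0 dot kernel: the activation row is quantized once to Q8_0 and then reused against every quantized weight row. A hedged sketch of that calling pattern (a hypothetical driver; the real matrix-multiplication code elsewhere in ggml.c is what actually invokes these hooks, and the scratch-buffer handling here is an assumption):

    static void matvec_rows_sketch(enum ggml_type wtype,
                                   const void * w_rows, size_t w_row_stride,
                                   const float * activation, int n_cols,
                                   float * out, int n_rows, void * q8_scratch) {
        const quantize_fns_t fns = quantize_fns[wtype];
        // quantize the fp32 activation row into the Q8_0 scratch buffer once
        fns.quantize_row_q_dot(activation, q8_scratch, n_cols);
        for (int r = 0; r < n_rows; ++r) {
            const char * w_row = (const char *) w_rows + r*w_row_stride;
            fns.vec_dot_q(n_cols, &out[r], w_row, q8_scratch);
        }
    }
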
@@ -2004,191 +2321,6 @@ inline static void ggml_vec_dot_f32(const int n, float * restrict s, const float
2004
2321
  *s = sumf;
2005
2322
  }
2006
2323
 
2007
- #if __AVX512F__ && QK4_0 == 32
2008
- static inline __m512i bytes_from_q4_0_twoblocks_avx512( const __m512i blocks ) {
2009
- // The 64 bytes of `blocks` contain two consecutive Q4_0 blocks loaded from memory:
2010
- // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
2011
- // |63 62 61 60 59 58 57 56 55 54 53 52 51 50 49 48 47 46 45 44 43 42 41 40 39 38 37 36 35 34 33 32|
2012
- // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
2013
- // | :. =_ () [] <> () Zz Yy|
2014
- // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
2015
- // |31 30 29 28 27 26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10 09 08 07 06 05 04 03 02 01 00|
2016
- // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
2017
- // |Xx Ww Vv Uu Tt Ss Rr Qq Pp Oo Nn Mm Ll Kk Jj Ii Hh Gg Ff Ee Dd Cc Bb Aa |
2018
- // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
2019
- //
2020
- // Bytes 04..19 (block #0) and 24..39 (block #1) both contain 32 nibbles (4-bit unsigned integers).
2021
- // We have exactly 64 nibbles, so we want to place each nibble into a separate byte.
2022
- // Bytes 00..03 and 20..23 contain scales, which are irrelevant to this function.
2023
- // Bytes 40..63 are masked when loading the data, so they are zeroed out.
2024
- #ifdef __AVX512VBMI__
2025
- const __m512i byte_perm = _mm512_set_epi8(
2026
- 39, 38, 39, 38, 37, 36, 37, 36, 35, 34, 35, 34, 33, 32, 33, 32,
2027
- 31, 30, 31, 30, 29, 28, 29, 28, 27, 26, 27, 26, 25, 24, 25, 24,
2028
- 19, 18, 19, 18, 17, 16, 17, 16, 15, 14, 15, 14, 13, 12, 13, 12,
2029
- 11, 10, 11, 10, 9, 8, 9, 8, 7, 6, 7, 6, 5, 4, 5, 4
2030
- );
2031
- const __m512i permuted = _mm512_permutexvar_epi8( byte_perm, blocks );
2032
- // After applying VPERMB, `permuted` looks like this:
2033
- // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
2034
- // |63 62 61 60 59 58 57 56 55 54 53 52 51 50 49 48 47 46 45 44 43 42 41 40 39 38 37 36 35 34 33 32|
2035
- // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
2036
- // |:. =_ :. =_ () [] () [] <> () <> () Zz Yy Zz Yy Xx Ww Xx Ww Vv Uu Vv Uu Tt Ss Tt Ss Rr Qq Rr Qq|
2037
- // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
2038
- // |31 30 29 28 27 26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10 09 08 07 06 05 04 03 02 01 00|
2039
- // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
2040
- // |Pp Oo Pp Oo Nn Mm Nn Mm Ll Kk Ll Kk Jj Ii Jj Ii Hh Gg Hh Gg Ff Ee Ff Ee Dd Cc Dd Cc Bb Aa Bb Aa|
2041
- // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
2042
- #else
2043
- const __m512i word_perm = _mm512_set_epi16(
2044
- 19, 19, 18, 18, 17, 17, 16, 16, 15, 15, 14, 14, 13, 13, 12, 12,
2045
- 9, 9, 8, 8, 7, 7, 6, 6, 5, 5, 4, 4, 3, 3, 2, 2
2046
- );
2047
- const __m512i permuted = _mm512_permutexvar_epi16( word_perm, blocks );
2048
- // This is the fallback path for CPUs that don't support VPERMB. Since we permute 16-bit groups only,
2049
- // VPERMB can be replaced with VPERMW. We could always use VPERMW, but at least on Tiger Lake and
2050
- // Ice Lake VPERMW followed by a right shift is quite noticeably slower than VPERMB.
2051
- #endif
2052
-
2053
- // Shift every odd-numbered 16-bit group to the right by 4 bits.
2054
- const __mmask32 shift_mask = 0xaaaaaaaa;
2055
- const __m512i shifted = _mm512_mask_srai_epi16( permuted, shift_mask, permuted, 4 );
2056
- // After applying VPSRAW, `shifted` looks like this (the "empty" nibbles are filled with zeroes):
2057
- // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
2058
- // |63 62 61 60 59 58 57 56 55 54 53 52 51 50 49 48 47 46 45 44 43 42 41 40 39 38 37 36 35 34 33 32
2059
- // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
2060
- // | : .= :. =_ ( )[ () [] < >( <> () Z zY Zz Yy X xW Xx Ww V vU Vv Uu T tS Tt Ss R rQ Rr Qq
2061
- // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
2062
- // |31 30 29 28 27 26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10 09 08 07 06 05 04 03 02 01 00|
2063
- // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
2064
- // | P pO Pp Oo N nM Nn Mm L lK Ll Kk J jI Jj Ii H hG Hh Gg F fE Ff Ee D dC Dd Cc B bA Bb Aa|
2065
- // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
2066
-
2067
- // Now we just need to zero out the higher nibble in each byte, and we're done.
2068
- const __m512i low_nibble_mask = _mm512_set1_epi8( 0xf );
2069
- return _mm512_and_si512( low_nibble_mask, shifted );
2070
- // The final result looks like this:
2071
- // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
2072
- // |63 62 61 60 59 58 57 56 55 54 53 52 51 50 49 48 47 46 45 44 43 42 41 40 39 38 37 36 35 34 33 32|
2073
- // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
2074
- // | : = . _ ( [ ) ] < ( > ) Z Y z y X W x w V U v u T S t s R Q r q|
2075
- // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
2076
- // |31 30 29 28 27 26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10 09 08 07 06 05 04 03 02 01 00|
2077
- // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
2078
- // | P O p o N M n m L K l k J I j i H G h g F E f e D C d c B A b a|
2079
- // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
2080
- }
2081
-
2082
- static inline __m512 dot_q4_0_twoblocks_avx512(
2083
- __m512 acc,
2084
- const block_q4_0 * restrict x,
2085
- const block_q4_0 * restrict y,
2086
- int i
2087
- ) {
2088
- // A pair of Q4_0 blocks spans 40 bytes, while an AVX-512 register has 64. The remaining 24 bytes
2089
- // can potentially be unaddressable, so we make sure to mask them out before the load, even though
2090
- // we don't use them at all. This might hurt the performance slightly, since the compiler is forced
2091
- // to use e.g. `VMOVDQU64 REG, MASK, [ADDR] + VPERMB ..., REG` instead of just `VPERMB ..., [ADDR]`.
2092
- const __mmask8 load_mask = 0x1f;
2093
- const __m512i blocks_0 = _mm512_maskz_loadu_epi64( load_mask, &x[i] );
2094
- const __m512i blocks_1 = _mm512_maskz_loadu_epi64( load_mask, &y[i] );
2095
-
2096
- // We want to multiply the scales, so we interpret both registers as 16 32-bit floats:
2097
- // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
2098
- // | 15 | 14 | 13 | 12 | 11 | 10 | 09 | 08 | 07 | 06 | 05 | 04 | 03 | 02 | 01 | 00 |
2099
- // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
2100
- // blocks_0_float
2101
- // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
2102
- // | | | | | | | xx | xx | xx | xx | B | xx | xx | xx | xx | A |
2103
- // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
2104
- // blocks_1_float
2105
- // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
2106
- // | | | | | | | xx | xx | xx | xx | D | xx | xx | xx | xx | C |
2107
- // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
2108
- const __m512 blocks_0_float = _mm512_castsi512_ps( blocks_0 );
2109
- const __m512 blocks_1_float = _mm512_castsi512_ps( blocks_1 );
2110
- // We absolutely shouldn't touch the floats marked with `xx`: they contain some
2111
- // random data, which might very well underflow. At least on Intel, this leads
2112
- // to a huge penalty that can't be ignored (easily 100x or more) unless you
2113
- // compile your code with something like `-ffast-math` to enable FTZ/DAZ flags.
2114
- // (and ggml can't assume that you do)...
2115
- const __mmask16 scale_mul_mask = 0x21;
2116
- #ifdef __clang__
2117
- // ...however, clang decides to optimize the multiplication mask away:
2118
- // https://godbolt.org/z/P8PqdsfvW
2119
- // gcc and MSVC do the sane thing. This horrible workaround forces clang to emit the mask.
2120
- __m512i scales;
2121
- __asm__(
2122
- "vmulps %1, %2, %0%{%3%}"
2123
- : "=v" ( scales )
2124
- : "vm" ( blocks_0_float ), "v" ( blocks_1_float ), "Yk" ( scale_mul_mask )
2125
- );
2126
- #else
2127
- const __m512 scales = _mm512_maskz_mul_ps( scale_mul_mask, blocks_0_float, blocks_1_float );
2128
- #endif
2129
- const __m512i scale_perm = _mm512_set_epi32(
2130
- 5, 5, 5, 5, 5, 5, 5, 5,
2131
- 0, 0, 0, 0, 0, 0, 0, 0
2132
- );
2133
- const __m512 permuted_scales = _mm512_permutexvar_ps( scale_perm, scales );
2134
- // After VMULPS and VPERMPS, `permuted_scales` looks like this:
2135
- // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
2136
- // | 15 | 14 | 13 | 12 | 11 | 10 | 09 | 08 | 07 | 06 | 05 | 04 | 03 | 02 | 01 | 00 |
2137
- // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
2138
- // | B*D| B*D| B*D| B*D| B*D| B*D| B*D| B*D| A*C| A*C| A*C| A*C| A*C| A*C| A*C| A*C|
2139
- // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
2140
-
2141
- const __m512i bytes_0 = bytes_from_q4_0_twoblocks_avx512( blocks_0 );
2142
- const __m512i bytes_1 = bytes_from_q4_0_twoblocks_avx512( blocks_1 );
2143
-
2144
- // Now we want to compute dot products of 4-element byte vectors and store them in
2145
- // 32-bit integers. That is (only one 4-element vector is shown for clarity):
2146
- // +----+----+----+----+
2147
- // ... | 03 | 02 | 01 | 00 |
2148
- // +----+----+----+----+
2149
- // bytes_0
2150
- // +----+----+----+----+
2151
- // ... | D | C | B | A |
2152
- // +----+----+----+----+
2153
- // bytes_1
2154
- // +----+----+----+----+
2155
- // ... | H | G | F | E |
2156
- // +----+----+----+----+
2157
- // final_res_int
2158
- // +----+----+----+----+
2159
- // ... | A*E+B*F+C*G+D*H |
2160
- // +----+----+----+----+
2161
- const __m512i plus_8 = _mm512_set1_epi8( 8 );
2162
- const __m512i bytes_1_minus_8 = _mm512_sub_epi8( bytes_1, plus_8 );
2163
-
2164
- #ifdef __AVX512VNNI__
2165
- // We have VPDPBUSDS in AVX512-VNNI, which does exactly what we want, but with a catch:
2166
- // the *left* operand is supposed to be unsigned, while Q4_0 quantization subtracts 8
2167
- // from each nibble, so they can be negative. So, instead of `(bytes_0 - 8) * (bytes_1 - 8)`,
2168
- // we compute `bytes_0 * (bytes_1 - 8) + bytes_1 * (-8) + 64`. VPDPBUSDS uses an accumulator,
2169
- // which means we only need 2 instructions.
2170
- const __m512i dot_init = _mm512_set1_epi32( 4 * 64 );
2171
- const __m512i minus_8 = _mm512_set1_epi8( -8 );
2172
- const __m512i prod_0 = _mm512_dpbusds_epi32( dot_init, bytes_1, minus_8 );
2173
- const __m512i final_res_int = _mm512_dpbusds_epi32( prod_0, bytes_0, bytes_1_minus_8 );
2174
- #else
2175
- // As a fallback, we have VPMADDUBSW in AVX512-BW, which uses 16-bit products instead of 32-bit ones.
2176
- // It has the same catch as VPDPBUSDS: the left operand should be unsigned.
2177
- // This is essentially the AVX-512 version of the AVX-2 trick used by GH user Const-me
2178
- // ref: https://gist.github.com/Const-me/4d30e1fc767ab314596e16e90f53b6f4#file-matmultest-cpp-L119
2179
- const __m512i one = _mm512_set1_epi16( 1 );
2180
- const __m512i prod_0 = _mm512_maddubs_epi16( bytes_0, bytes_1_minus_8 );
2181
- const __m512i prod_1 = _mm512_maddubs_epi16( plus_8, bytes_1_minus_8 );
2182
- const __m512i diff = _mm512_sub_epi16( prod_0, prod_1 );
2183
- const __m512i final_res_int = _mm512_madd_epi16( diff, one );
2184
- #endif
2185
-
2186
- // Finally, we multiply the permuted scales and the 32-bit dot products, then accumulate.
2187
- const __m512 final_res_float = _mm512_cvtepi32_ps( final_res_int );
2188
- return _mm512_fmadd_ps( permuted_scales, final_res_float, acc );
2189
- }
2190
- #endif
2191
-
2192
2324
  inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t * restrict x, ggml_fp16_t * restrict y) {
2193
2325
  ggml_float sumf = 0.0;
2194
2326
 
@@ -2225,67 +2357,64 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t
2225
2357
  *s = sumf;
2226
2358
  }
2227
2359
 
2228
- static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
2229
- const int nb = n / QK4_0;
2360
+ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
2361
+ const int nb = n / QK8_0;
2230
2362
 
2231
- assert(n % QK4_0 == 0);
2363
+ assert(n % QK8_0 == 0);
2232
2364
  assert(nb % 2 == 0);
2233
2365
 
2234
2366
  const block_q4_0 * restrict x = vx;
2235
- const block_q4_0 * restrict y = vy;
2367
+ const block_q8_0 * restrict y = vy;
2236
2368
 
2237
2369
  float sumf = 0.0;
2238
2370
 
2239
2371
  #if defined(__ARM_NEON)
2240
- float sum0 = 0.0f;
2241
- float sum1 = 0.0f;
2372
+ float32x4_t sumv0 = vdupq_n_f32(0.0f);
2373
+ float32x4_t sumv1 = vdupq_n_f32(0.0f);
2242
2374
 
2243
2375
  for (int i = 0; i < nb; i += 2) {
2244
2376
  const block_q4_0 * restrict x0 = &x[i + 0];
2245
- const block_q4_0 * restrict y0 = &y[i + 0];
2246
2377
  const block_q4_0 * restrict x1 = &x[i + 1];
2247
- const block_q4_0 * restrict y1 = &y[i + 1];
2378
+ const block_q8_0 * restrict y0 = &y[i + 0];
2379
+ const block_q8_0 * restrict y1 = &y[i + 1];
2248
2380
 
2249
- const uint8x16_t m4b = vdupq_n_u8(0xf);
2250
- const int8x16_t s8b = vdupq_n_s8(0x8);
2381
+ const uint8x16_t m4b = vdupq_n_u8(0xf);
2382
+ const int8x16_t s8b = vdupq_n_s8(0x8);
2251
2383
 
2252
2384
  const uint8x16_t v0_0 = vld1q_u8(x0->qs);
2253
- const uint8x16_t v1_0 = vld1q_u8(y0->qs);
2254
2385
  const uint8x16_t v0_1 = vld1q_u8(x1->qs);
2255
- const uint8x16_t v1_1 = vld1q_u8(y1->qs);
2256
2386
 
2257
2387
  // 4-bit -> 8-bit
2258
- const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8(v0_0, m4b));
2259
- const int8x16_t v1_0l = vreinterpretq_s8_u8(vandq_u8(v1_0, m4b));
2388
+ const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
2260
2389
  const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
2261
- const int8x16_t v1_0h = vreinterpretq_s8_u8(vshrq_n_u8(v1_0, 4));
2262
-
2263
- const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8(v0_1, m4b));
2264
- const int8x16_t v1_1l = vreinterpretq_s8_u8(vandq_u8(v1_1, m4b));
2390
+ const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
2265
2391
  const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
2266
- const int8x16_t v1_1h = vreinterpretq_s8_u8(vshrq_n_u8(v1_1, 4));
2267
2392
 
2268
2393
  // sub 8
2269
2394
  const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
2270
- const int8x16_t v1_0ls = vsubq_s8(v1_0l, s8b);
2271
2395
  const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
2272
- const int8x16_t v1_0hs = vsubq_s8(v1_0h, s8b);
2273
-
2274
2396
  const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
2275
- const int8x16_t v1_1ls = vsubq_s8(v1_1l, s8b);
2276
2397
  const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
2277
- const int8x16_t v1_1hs = vsubq_s8(v1_1h, s8b);
2398
+
2399
+ // load y
2400
+ const int8x16_t v1_0l = vld1q_s8(y0->qs);
2401
+ const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
2402
+ const int8x16_t v1_1l = vld1q_s8(y1->qs);
2403
+ const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
2404
+
2405
+ // interleave
2406
+ const int8x16_t v1_0ls = vuzp1q_s8(v1_0l, v1_0h);
2407
+ const int8x16_t v1_0hs = vuzp2q_s8(v1_0l, v1_0h);
2408
+ const int8x16_t v1_1ls = vuzp1q_s8(v1_1l, v1_1h);
2409
+ const int8x16_t v1_1hs = vuzp2q_s8(v1_1l, v1_1h);
2278
2410
 
2279
2411
  #if defined(__ARM_FEATURE_DOTPROD)
2280
2412
  // dot product into int32x4_t
2281
- int32x4_t p_0 = vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls);
2282
- int32x4_t p_1 = vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls);
2413
+ const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls), v0_0hs, v1_0hs);
2414
+ const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls), v0_1hs, v1_1hs);
2283
2415
 
2284
- p_0 = vdotq_s32(p_0, v0_0hs, v1_0hs);
2285
- p_1 = vdotq_s32(p_1, v0_1hs, v1_1hs);
2286
-
2287
- sum0 += x0->d*y0->d*vaddvq_s32(p_0);
2288
- sum1 += x1->d*y1->d*vaddvq_s32(p_1);
2416
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), x0->d*y0->d);
2417
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), x1->d*y1->d);
2289
2418
  #else
2290
2419
  const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls));
2291
2420
  const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls));
@@ -2297,116 +2426,51 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
2297
2426
  const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1hs));
2298
2427
  const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1hs));
2299
2428
 
2300
- const int16x8_t pl_0 = vaddq_s16(pl0l, pl0h);
2301
- const int16x8_t ph_0 = vaddq_s16(ph0l, ph0h);
2302
-
2303
- const int16x8_t pl_1 = vaddq_s16(pl1l, pl1h);
2304
- const int16x8_t ph_1 = vaddq_s16(ph1l, ph1h);
2429
+ const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
2430
+ const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
2431
+ const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
2432
+ const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
2305
2433
 
2306
- const int16x8_t p_0 = vaddq_s16(pl_0, ph_0);
2307
- const int16x8_t p_1 = vaddq_s16(pl_1, ph_1);
2308
-
2309
- sum0 += x0->d*y0->d*vaddvq_s16(p_0);
2310
- sum1 += x1->d*y1->d*vaddvq_s16(p_1);
2434
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), x0->d*y0->d);
2435
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), x1->d*y1->d);
2311
2436
  #endif
2312
2437
  }
2313
2438
 
2314
- sumf = sum0 + sum1;
2315
- #elif defined(__AVX512F__)
2439
+ sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
2440
+ #elif defined(__AVX2__)
2316
2441
  // Initialize accumulator with zeros
2317
- __m512 acc0 = _mm512_setzero_ps();
2318
- __m512 acc1 = _mm512_setzero_ps();
2442
+ __m256 acc = _mm256_setzero_ps();
2319
2443
 
2320
- const int superblock_size = 16;
2444
+ // Main loop
2445
+ for (int i = 0; i < nb; ++i) {
2446
+ /* Compute combined scale for the block */
2447
+ const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );
2321
2448
 
2322
- const int superblock_count = nb / superblock_size;
2449
+ __m256i bx = bytes_from_nibbles_32(x[i].qs);
2323
2450
 
2324
- for (int superblock_ix = 0; superblock_ix < superblock_count; superblock_ix += 1) {
2325
- int i = superblock_ix * superblock_size;
2451
+ // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
2452
+ const __m256i off = _mm256_set1_epi8( 8 );
2453
+ bx = _mm256_sub_epi8( bx, off );
2326
2454
 
2327
- acc0 = dot_q4_0_twoblocks_avx512( acc0, x, y, i+0 );
2328
- acc1 = dot_q4_0_twoblocks_avx512( acc1, x, y, i+2 );
2329
- acc0 = dot_q4_0_twoblocks_avx512( acc0, x, y, i+4 );
2330
- acc1 = dot_q4_0_twoblocks_avx512( acc1, x, y, i+6 );
2331
- acc0 = dot_q4_0_twoblocks_avx512( acc0, x, y, i+8 );
2332
- acc1 = dot_q4_0_twoblocks_avx512( acc1, x, y, i+10 );
2333
- acc0 = dot_q4_0_twoblocks_avx512( acc0, x, y, i+12 );
2334
- acc1 = dot_q4_0_twoblocks_avx512( acc1, x, y, i+14 );
2335
- }
2455
+ __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
2336
2456
 
2337
- // Remainders
2338
- for (int i = superblock_count * superblock_size; i < nb; i += 2) {
2339
- acc0 = dot_q4_0_twoblocks_avx512( acc0, x, y, i );
2340
- }
2457
+ // Get absolute values of x vectors
2458
+ const __m256i ax = _mm256_sign_epi8(bx, bx);
2341
2459
 
2342
- // Horizontal sum of all lanes of the accumulator
2343
- sumf = _mm512_reduce_add_ps( acc0 ) + _mm512_reduce_add_ps( acc1 );
2344
- #elif defined(__AVX2__)
2345
- // Initialize accumulator with zeros
2346
- __m256 acc = _mm256_setzero_ps();
2460
+ // Sign the values of the y vectors
2461
+ const __m256i sy = _mm256_sign_epi8(by, bx);
2347
2462
 
2348
- /* Prepare the constants we will need during execution */
2349
- const __m256i lowMask = _mm256_set1_epi8( 0xF );
2350
- const __m256i offset_8 = _mm256_set1_epi16( 8 );
2463
+ // Perform multiplication and create 16-bit values
2464
+ const __m256i dot = _mm256_maddubs_epi16(ax, sy);
2351
2465
 
2352
- #define UNROLL_COUNT 8
2353
- // make sure we only unroll multiples of the block count
2354
- assert(nb % UNROLL_COUNT == 0);
2466
+ const __m256i ones = _mm256_set1_epi16(1);
2467
+ __m256i xy_q = _mm256_madd_epi16(ones, dot);
2355
2468
 
2356
- // Main loop
2357
- for (int i = 0; i < nb; i+=UNROLL_COUNT) {
2358
- // This loop will be unrolled by the compiler
2359
- for (int u=0;u<UNROLL_COUNT;u++) {
2360
- /* Compute combined scale for the block */
2361
- const __m256 scale = _mm256_mul_ps(
2362
- _mm256_broadcast_ss( &x[i+u].d ),
2363
- _mm256_broadcast_ss( &y[i+u].d ) );
2364
-
2365
- /* get input from x
2366
- Input: 32 Nibbles (16 bytes) at *x[i+u]
2367
- Output: 2 vectors with 16 values of type int16_t (x_high_q, x_low_q) */
2368
-
2369
- /* Load 16 bytes from memory */
2370
- const __m128i tmp_x = _mm_loadu_si128( ( const __m128i* ) x[i+u].qs);
2371
- /* Expand bytes into uint16_t values */
2372
- const __m256i bytes_x = _mm256_cvtepu8_epi16(tmp_x);
2373
- /* Unpack values into individual bytes */
2374
- __m256i x_low_q = _mm256_and_si256( lowMask, bytes_x );
2375
- const __m256i pre_shift_x_high_q = _mm256_andnot_si256( lowMask, bytes_x );
2376
- __m256i x_high_q = _mm256_srli_epi16( pre_shift_x_high_q, 4 );
2377
- /* Now we have two vectors with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval. */
2378
- x_high_q = _mm256_sub_epi16( x_high_q, offset_8 );
2379
- x_low_q = _mm256_sub_epi16( x_low_q, offset_8 );
2380
-
2381
- /* get input from y
2382
- Input: 32 Nibbles (16 bytes) at *y[i+u]
2383
- Output: 2 vectors with 16 values of type int16_t (y_high_q, y_low_q) */
2384
-
2385
- /* Load 16 bytes from memory */
2386
- const __m128i tmp_y = _mm_loadu_si128( (const __m128i* ) y[i+u].qs);
2387
- /* Expand bytes into uint16_t values */
2388
- const __m256i bytes_y = _mm256_cvtepu8_epi16(tmp_y);
2389
- /* Unpack values into individual bytes */
2390
- const __m256i pre_shift_y_high_q = _mm256_andnot_si256( lowMask, bytes_y );
2391
- __m256i y_high_q = _mm256_srli_epi16( pre_shift_y_high_q, 4 );
2392
- __m256i y_low_q = _mm256_and_si256( lowMask, bytes_y );
2393
- /* Now we have two vectors with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval. */
2394
- y_high_q = _mm256_sub_epi16( y_high_q, offset_8 );
2395
- y_low_q = _mm256_sub_epi16( y_low_q, offset_8 );
2396
-
2397
- /* Compute products of int16_t integers, add pairwise, store as int32_t */
2398
- __m256i xy_high_q = _mm256_madd_epi16( x_high_q, y_high_q );
2399
- __m256i xy_low_q = _mm256_madd_epi16( x_low_q, y_low_q );
2400
-
2401
- /* Accumulate the products of int32_t integers -> we now have a vector of 8 int_32t */
2402
- __m256i xy_q = _mm256_add_epi32( xy_high_q, xy_low_q );
2403
-
2404
- /* Convert to vectore of 8 int32_t to 8 floats */
2405
- __m256 q = _mm256_cvtepi32_ps( xy_q );
2406
-
2407
- /* Multiply q with scale and accumulate */
2408
- acc = _mm256_fmadd_ps( scale, q, acc );
2409
- }
2469
+ /* Convert the vector of 8 int32_t to 8 floats */
2470
+ __m256 q = _mm256_cvtepi32_ps( xy_q );
2471
+
2472
+ /* Multiply q with scale and accumulate */
2473
+ acc = _mm256_fmadd_ps( d, q, acc );
2410
2474
  }
2411
2475
 
2412
2476
  // Return horizontal sum of the acc vector
@@ -2428,13 +2492,12 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
2428
2492
  __m128i i32[2];
2429
2493
  for (int j = 0; j < 2; ++j) {
2430
2494
  // Load 8 bytes, and unpack 4 bit fields into bytes, making 16 bytes
2431
- __m128i bx = bytesFromNibbles( x[i].qs + 8*j );
2432
- __m128i by = bytesFromNibbles( y[i].qs + 8*j );
2495
+ __m128i bx = bytes_from_nibbles_16(x[i].qs + 8*j);
2496
+ __m128i by = _mm_loadu_si128((const __m128i *)(y[i].qs + 16*j));
2433
2497
 
2434
2498
  // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
2435
2499
  const __m128i off = _mm_set1_epi8( 8 );
2436
2500
  bx = _mm_sub_epi8( bx, off );
2437
- by = _mm_sub_epi8( by, off );
2438
2501
 
2439
2502
  // Get absolute values of x vectors
2440
2503
  const __m128i ax = _mm_sign_epi8(bx, bx);
@@ -2462,86 +2525,6 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
2462
2525
  res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
2463
2526
 
2464
2527
  sumf = _mm_cvtss_f32( res );
2465
- #elif defined(__wasm_simd128__)
2466
- // wasm simd
2467
- float sum0 = 0.0f;
2468
- float sum1 = 0.0f;
2469
-
2470
- for (int i = 0; i < nb; i += 2) {
2471
- const block_q4_0 * restrict x0 = &x[i + 0];
2472
- const block_q4_0 * restrict y0 = &y[i + 0];
2473
- const block_q4_0 * restrict x1 = &x[i + 1];
2474
- const block_q4_0 * restrict y1 = &y[i + 1];
2475
-
2476
- const v128_t m4b = wasm_u8x16_splat(0xf);
2477
- const v128_t s8b = wasm_i8x16_splat(0x8);
2478
-
2479
- const v128_t v0_0 = wasm_v128_load(x0->qs);
2480
- const v128_t v0_1 = wasm_v128_load(y0->qs);
2481
- const v128_t v1_0 = wasm_v128_load(x1->qs);
2482
- const v128_t v1_1 = wasm_v128_load(y1->qs);
2483
-
2484
- // 4-bit -> 8-bit
2485
- const v128_t v0_0l = wasm_v128_and(v0_0, m4b);
2486
- const v128_t v1_0l = wasm_v128_and(v1_0, m4b);
2487
-
2488
- const v128_t v0_0h = wasm_u8x16_shr(v0_0, 4);
2489
- const v128_t v1_0h = wasm_u8x16_shr(v1_0, 4);
2490
-
2491
- const v128_t v0_1l = wasm_v128_and(v0_1, m4b);
2492
- const v128_t v1_1l = wasm_v128_and(v1_1, m4b);
2493
-
2494
- const v128_t v0_1h = wasm_u8x16_shr(v0_1, 4);
2495
- const v128_t v1_1h = wasm_u8x16_shr(v1_1, 4);
2496
-
2497
- // sub 8
2498
- const v128_t v0_0ls = wasm_i8x16_sub(v0_0l, s8b);
2499
- const v128_t v1_0ls = wasm_i8x16_sub(v1_0l, s8b);
2500
-
2501
- const v128_t v0_0hs = wasm_i8x16_sub(v0_0h, s8b);
2502
- const v128_t v1_0hs = wasm_i8x16_sub(v1_0h, s8b);
2503
-
2504
- const v128_t v0_1ls = wasm_i8x16_sub(v0_1l, s8b);
2505
- const v128_t v1_1ls = wasm_i8x16_sub(v1_1l, s8b);
2506
-
2507
- const v128_t v0_1hs = wasm_i8x16_sub(v0_1h, s8b);
2508
- const v128_t v1_1hs = wasm_i8x16_sub(v1_1h, s8b);
2509
-
2510
- // dot product into int16x8_t
2511
- const v128_t pl0l = wasm_i16x8_mul(wasm_i16x8_extend_low_i8x16(v0_0ls), wasm_i16x8_extend_low_i8x16(v1_0ls));
2512
- const v128_t pl0h = wasm_i16x8_mul(wasm_i16x8_extend_high_i8x16(v0_0ls), wasm_i16x8_extend_high_i8x16(v1_0ls));
2513
-
2514
- const v128_t ph0l = wasm_i16x8_mul(wasm_i16x8_extend_low_i8x16(v0_0hs), wasm_i16x8_extend_low_i8x16(v1_0hs));
2515
- const v128_t ph0h = wasm_i16x8_mul(wasm_i16x8_extend_high_i8x16(v0_0hs), wasm_i16x8_extend_high_i8x16(v1_0hs));
2516
-
2517
- const v128_t pl1l = wasm_i16x8_mul(wasm_i16x8_extend_low_i8x16(v0_1ls), wasm_i16x8_extend_low_i8x16(v1_1ls));
2518
- const v128_t pl1h = wasm_i16x8_mul(wasm_i16x8_extend_high_i8x16(v0_1ls), wasm_i16x8_extend_high_i8x16(v1_1ls));
2519
-
2520
- const v128_t ph1l = wasm_i16x8_mul(wasm_i16x8_extend_low_i8x16(v0_1hs), wasm_i16x8_extend_low_i8x16(v1_1hs));
2521
- const v128_t ph1h = wasm_i16x8_mul(wasm_i16x8_extend_high_i8x16(v0_1hs), wasm_i16x8_extend_high_i8x16(v1_1hs));
2522
-
2523
- const v128_t pl_0 = wasm_i16x8_add(pl0l, pl0h);
2524
- const v128_t ph_0 = wasm_i16x8_add(ph0l, ph0h);
2525
-
2526
- const v128_t pl_1 = wasm_i16x8_add(pl1l, pl1h);
2527
- const v128_t ph_1 = wasm_i16x8_add(ph1l, ph1h);
2528
-
2529
- const v128_t p_0 = wasm_i16x8_add(pl_0, ph_0);
2530
- const v128_t p_1 = wasm_i16x8_add(pl_1, ph_1);
2531
-
2532
- sum0 += x0->d * y0->d * (
2533
- wasm_i16x8_extract_lane(p_0, 0) + wasm_i16x8_extract_lane(p_0, 1) +
2534
- wasm_i16x8_extract_lane(p_0, 2) + wasm_i16x8_extract_lane(p_0, 3) +
2535
- wasm_i16x8_extract_lane(p_0, 4) + wasm_i16x8_extract_lane(p_0, 5) +
2536
- wasm_i16x8_extract_lane(p_0, 6) + wasm_i16x8_extract_lane(p_0, 7));
2537
- sum1 += x1->d * y1->d * (
2538
- wasm_i16x8_extract_lane(p_1, 0) + wasm_i16x8_extract_lane(p_1, 1) +
2539
- wasm_i16x8_extract_lane(p_1, 2) + wasm_i16x8_extract_lane(p_1, 3) +
2540
- wasm_i16x8_extract_lane(p_1, 4) + wasm_i16x8_extract_lane(p_1, 5) +
2541
- wasm_i16x8_extract_lane(p_1, 6) + wasm_i16x8_extract_lane(p_1, 7));
2542
- }
2543
-
2544
- sumf = sum0 + sum1;
2545
2528
  #else
2546
2529
  // scalar
2547
2530
  for (int i = 0; i < nb; i++) {
@@ -2549,202 +2532,187 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
2549
2532
  const float d1 = y[i].d;
2550
2533
 
2551
2534
  const uint8_t * restrict p0 = x[i].qs;
2552
- const uint8_t * restrict p1 = y[i].qs;
2535
+ const int8_t * restrict p1 = y[i].qs;
2553
2536
 
2554
2537
  int sumi = 0;
2555
- for (int j = 0; j < QK4_0/2; j++) {
2538
+ for (int j = 0; j < QK8_0/2; j++) {
2556
2539
  const uint8_t v0 = p0[j];
2557
- const uint8_t v1 = p1[j];
2558
2540
 
2559
- const int i0 = (v0 & 0xf) - 8;
2560
- const int i1 = (v0 >> 4) - 8;
2541
+ const int i0 = (int8_t) (v0 & 0xf) - 8;
2542
+ const int i1 = (int8_t) (v0 >> 4) - 8;
2561
2543
 
2562
- const int i2 = (v1 & 0xf) - 8;
2563
- const int i3 = (v1 >> 4) - 8;
2544
+ const int i2 = p1[2*j + 0];
2545
+ const int i3 = p1[2*j + 1];
2564
2546
 
2565
2547
  sumi += i0*i2 + i1*i3;
2566
2548
  }
2567
- sumf += d0 * d1 * sumi;
2549
+ sumf += d0*d1*sumi;
2568
2550
  }
2569
2551
  #endif
2570
2552
 
2571
2553
  *s = sumf;
2572
2554
  }
2573
2555
 
2574
- static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
2575
- const int nb = n / QK4_1;
2556
+ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
2557
+ const int nb = n / QK8_0;
2558
+
2559
+ assert(n % QK8_0 == 0);
2560
+ assert(nb % 2 == 0);
2576
2561
 
2577
2562
  const block_q4_1 * restrict x = vx;
2578
- const block_q4_1 * restrict y = vy;
2563
+ const block_q8_0 * restrict y = vy;
2579
2564
 
2580
2565
  float sumf = 0.0;
2581
2566
 
2582
- #if defined(__AVX2__)
2583
- // Initialize accumulator with zeros
2584
- __m256 acc = _mm256_setzero_ps();
2585
- // Accumulator for constant offsets
2586
- float acc_offset = 0.0f;
2587
-
2588
- // Main loop
2589
- for (int i = 0; i < nb; ++i) {
2590
- const float * d0 = &x[i].d;
2591
- const float * d1 = &y[i].d;
2592
-
2593
- const float * m0 = &x[i].m;
2594
- const float * m1 = &y[i].m;
2595
-
2596
- const __m256 d0v = _mm256_broadcast_ss( d0 );
2597
- const __m256 d1v = _mm256_broadcast_ss( d1 );
2598
- const __m256 m0v = _mm256_broadcast_ss( m0 );
2599
- const __m256 m1v = _mm256_broadcast_ss( m1 );
2600
-
2601
- // Compute combined scale for the block
2602
- const __m256 scale_01 = _mm256_mul_ps( d0v, d1v );
2603
-
2604
- // Compute cross scales for the block
2605
- const __m256 scale_0 = _mm256_mul_ps( d0v, m1v );
2606
- const __m256 scale_1 = _mm256_mul_ps( m0v, d1v );
2607
- const __m256 cross_scales = _mm256_blend_ps( scale_0, scale_1, 0xAA /* 0b10101010 */ );
2608
-
2609
- // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
2610
- __m256i bx = bytesFromNibbles( x[i].qs );
2611
- __m256i by = bytesFromNibbles( y[i].qs );
2612
-
2613
- // Now we have a vector with bytes in [ 0 .. 15 ] interval.
2614
-
2615
- // Sign-extend first 16 signed bytes into int16_t
2616
- __m256i x16 = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( bx ) );
2617
- __m256i y16 = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( by ) );
2618
- // Compute products of int16_t integers, add pairwise
2619
- __m256i i32 = _mm256_madd_epi16( x16, y16 );
2620
-
2621
- // Sign-extend last 16 signed bytes into int16_t vectors
2622
- __m256i x16_h = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( bx, 1 ) );
2623
- __m256i y16_h = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( by, 1 ) );
2624
- // Accumulate products of int16_t integers
2625
- i32 = _mm256_add_epi32( i32, _mm256_madd_epi16( x16_h, y16_h ) );
2626
-
2627
- // compute sums of unsigned bytes in bx, by in blocks of 8.
2628
- // This results in a layout like X100 0000 X200 0000 X300 0000 X400 0000,
2629
- // which we then interleave as X100 Y100 X200 Y200 X300 Y300 X400 Y400.
2630
- // so if we then cast to 8 singles, we get 8 floats like [ x0_7, y0_7, x8_15, y8_15, x16_23, y16_23, x24_31, y24_31 ]
2631
- __m256i xsumi = _mm256_sad_epu8( bx, _mm256_setzero_si256() );
2632
- __m256i ysumi = _mm256_sad_epu8( by, _mm256_setzero_si256() );
2633
- __m256i sumsi = _mm256_or_si256( xsumi, _mm256_slli_si256( ysumi, 4 ) );
2634
- __m256 sums = _mm256_cvtepi32_ps( sumsi );
2635
-
2636
- // Convert int32_t to float
2637
- __m256 p = _mm256_cvtepi32_ps( i32 );
2638
- // Apply the scale, and accumulate
2639
- // acc += d0*d1*x*y + d0*m1*x + d1*m0*y
2640
- acc = _mm256_fmadd_ps( scale_01, p, acc );
2641
- acc = _mm256_fmadd_ps( cross_scales, sums, acc );
2642
- // acc_offset += m0*m1 (for each entry in the block)
2643
- acc_offset += (*m0)*(*m1);
2644
- }
2645
-
2646
- // Return horizontal sum of the acc vector
2647
- __m128 res = _mm256_extractf128_ps( acc, 1 );
2648
- res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
2649
- res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
2650
- res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
2651
-
2652
- sumf = _mm_cvtss_f32( res ) + acc_offset * QK4_1;
2653
- #elif defined(__ARM_NEON)
2654
- float sum00 = 0.0f;
2655
- float sum01 = 0.0f;
2656
- float sum10 = 0.0f;
2657
- float sum11 = 0.0f;
2567
+ // TODO: add AVX / WASM SIMD / etc
2568
+ #if defined(__ARM_NEON)
2569
+ float32x4_t sumv0 = vdupq_n_f32(0.0f);
2570
+ float32x4_t sumv1 = vdupq_n_f32(0.0f);
2658
2571
 
2659
2572
  for (int i = 0; i < nb; i += 2) {
2660
2573
  const block_q4_1 * restrict x0 = &x[i + 0];
2661
- const block_q4_1 * restrict y0 = &y[i + 0];
2662
2574
  const block_q4_1 * restrict x1 = &x[i + 1];
2663
- const block_q4_1 * restrict y1 = &y[i + 1];
2575
+ const block_q8_0 * restrict y0 = &y[i + 0];
2576
+ const block_q8_0 * restrict y1 = &y[i + 1];
2664
2577
 
2665
2578
  const uint8x16_t m4b = vdupq_n_u8(0xf);
2666
2579
 
2667
2580
  const uint8x16_t v0_0 = vld1q_u8(x0->qs);
2668
- const uint8x16_t v1_0 = vld1q_u8(y0->qs);
2669
2581
  const uint8x16_t v0_1 = vld1q_u8(x1->qs);
2670
- const uint8x16_t v1_1 = vld1q_u8(y1->qs);
2671
2582
 
2672
2583
  // 4-bit -> 8-bit
2673
- const uint8x16_t v0_0l = vandq_u8(v0_0, m4b);
2674
- const uint8x16_t v1_0l = vandq_u8(v1_0, m4b);
2675
- const uint8x16_t v0_0h = vshrq_n_u8(v0_0, 4);
2676
- const uint8x16_t v1_0h = vshrq_n_u8(v1_0, 4);
2584
+ const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
2585
+ const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
2586
+ const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
2587
+ const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
2588
+
2589
+ // load y
2590
+ const int8x16_t v1_0l = vld1q_s8(y0->qs);
2591
+ const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
2592
+ const int8x16_t v1_1l = vld1q_s8(y1->qs);
2593
+ const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
2677
2594
 
2678
- const uint8x16_t v0_1l = vandq_u8(v0_1, m4b);
2679
- const uint8x16_t v1_1l = vandq_u8(v1_1, m4b);
2680
- const uint8x16_t v0_1h = vshrq_n_u8(v0_1, 4);
2681
- const uint8x16_t v1_1h = vshrq_n_u8(v1_1, 4);
2595
+ // interleave
2596
+ const int8x16_t v1_0ls = vuzp1q_s8(v1_0l, v1_0h);
2597
+ const int8x16_t v1_0hs = vuzp2q_s8(v1_0l, v1_0h);
2598
+ const int8x16_t v1_1ls = vuzp1q_s8(v1_1l, v1_1h);
2599
+ const int8x16_t v1_1hs = vuzp2q_s8(v1_1l, v1_1h);
2682
2600
 
2683
- sum00 += x0->m*y0->m;
2684
- sum01 += y0->m*x0->d*((uint16_t)vaddvq_u8(v0_0l) + (uint16_t)vaddvq_u8(v0_0h));
2685
- sum10 += x0->m*y0->d*((uint16_t)vaddvq_u8(v1_0l) + (uint16_t)vaddvq_u8(v1_0h));
2601
+ const int16x8_t s0i = vaddq_s16(
2602
+ vaddq_s16(vmovl_s8(vget_low_s8(v1_0ls)), vmovl_s8(vget_high_s8(v1_0ls))),
2603
+ vaddq_s16(vmovl_s8(vget_low_s8(v1_0hs)), vmovl_s8(vget_high_s8(v1_0hs))));
2686
2604
 
2687
- sum00 += x1->m*y1->m;
2688
- sum01 += y1->m*x1->d*((uint16_t)vaddvq_u8(v0_1l) + (uint16_t)vaddvq_u8(v0_1h));
2689
- sum10 += x1->m*y1->d*((uint16_t)vaddvq_u8(v1_1l) + (uint16_t)vaddvq_u8(v1_1h));
2605
+ const int16x8_t s1i = vaddq_s16(
2606
+ vaddq_s16(vmovl_s8(vget_low_s8(v1_1ls)), vmovl_s8(vget_high_s8(v1_1ls))),
2607
+ vaddq_s16(vmovl_s8(vget_low_s8(v1_1hs)), vmovl_s8(vget_high_s8(v1_1hs))));
2608
+
2609
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddl_s16(vget_low_s16(s0i), vget_high_s16(s0i))), x0->m*y0->d);
2610
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddl_s16(vget_low_s16(s1i), vget_high_s16(s1i))), x1->m*y1->d);
2690
2611
 
2691
2612
  #if defined(__ARM_FEATURE_DOTPROD)
2692
2613
  // dot product into int32x4_t
2693
- uint32x4_t p_0 = vdotq_u32(vdupq_n_u32(0), v0_0l, v1_0l);
2694
- uint32x4_t p_1 = vdotq_u32(vdupq_n_u32(0), v0_1l, v1_1l);
2695
-
2696
- p_0 = vdotq_u32(p_0, v0_0h, v1_0h);
2697
- p_1 = vdotq_u32(p_1, v0_1h, v1_1h);
2614
+ const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0ls), v0_0h, v1_0hs);
2615
+ const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1ls), v0_1h, v1_1hs);
2698
2616
 
2699
- sum11 += x0->d*y0->d*vaddvq_u32(p_0);
2700
- sum11 += x1->d*y1->d*vaddvq_u32(p_1);
2617
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), x0->d*y0->d);
2618
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), x1->d*y1->d);
2701
2619
  #else
2702
- const uint16x8_t pl0l = vmull_u8(vget_low_u8 (v0_0l), vget_low_u8 (v1_0l));
2703
- const uint16x8_t pl0h = vmull_u8(vget_high_u8(v0_0l), vget_high_u8(v1_0l));
2704
- const uint16x8_t ph0l = vmull_u8(vget_low_u8 (v0_0h), vget_low_u8 (v1_0h));
2705
- const uint16x8_t ph0h = vmull_u8(vget_high_u8(v0_0h), vget_high_u8(v1_0h));
2620
+ const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0l), vget_low_s8 (v1_0ls));
2621
+ const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0l), vget_high_s8(v1_0ls));
2622
+ const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0h), vget_low_s8 (v1_0hs));
2623
+ const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0h), vget_high_s8(v1_0hs));
2624
+
2625
+ const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1l), vget_low_s8 (v1_1ls));
2626
+ const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1l), vget_high_s8(v1_1ls));
2627
+ const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1h), vget_low_s8 (v1_1hs));
2628
+ const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1h), vget_high_s8(v1_1hs));
2629
+
2630
+ const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
2631
+ const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
2632
+ const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
2633
+ const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
2634
+
2635
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), x0->d*y0->d);
2636
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), x1->d*y1->d);
2637
+ #endif
2638
+ }
2706
2639
 
2707
- const uint16x8_t pl1l = vmull_u8(vget_low_u8 (v0_1l), vget_low_u8 (v1_1l));
2708
- const uint16x8_t pl1h = vmull_u8(vget_high_u8(v0_1l), vget_high_u8(v1_1l));
2709
- const uint16x8_t ph1l = vmull_u8(vget_low_u8 (v0_1h), vget_low_u8 (v1_1h));
2710
- const uint16x8_t ph1h = vmull_u8(vget_high_u8(v0_1h), vget_high_u8(v1_1h));
2640
+ sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
2641
+ #elif defined(__AVX2__)
2642
+ // Initialize accumulator with zeros
2643
+ __m256 acc = _mm256_setzero_ps();
2711
2644
 
2712
- const uint16x8_t pl_0 = vaddq_u16(pl0l, pl0h);
2713
- const uint16x8_t ph_0 = vaddq_u16(ph0l, ph0h);
2645
+ // Main loop
2646
+ for (int i = 0; i < nb; ++i) {
2647
+ const float * d0 = &x[i].d;
2648
+ const float * d1 = &y[i].d;
2649
+ const float * m0 = &x[i].m;
2650
+
2651
+ const __m256 d0v = _mm256_broadcast_ss( d0 );
2652
+ const __m256 d1v = _mm256_broadcast_ss( d1 );
2653
+ const __m256 m0v = _mm256_broadcast_ss( m0 );
2714
2654
 
2715
- const uint16x8_t pl_1 = vaddq_u16(pl1l, pl1h);
2716
- const uint16x8_t ph_1 = vaddq_u16(ph1l, ph1h);
2655
+ // Compute combined scales
2656
+ const __m256 d0d1 = _mm256_mul_ps( d0v, d1v );
2657
+ const __m256 d1m0 = _mm256_mul_ps( d1v, m0v );
2717
2658
 
2718
- const uint16x8_t p_0 = vaddq_u16(pl_0, ph_0);
2719
- const uint16x8_t p_1 = vaddq_u16(pl_1, ph_1);
2659
+ // Load 16 bytes and unpack the 4-bit fields into bytes, making 32 bytes
2660
+ const __m256i bx = bytes_from_nibbles_32(x[i].qs);
2661
+ const __m256i by = _mm256_loadu_si256( (const __m256i *)y[i].qs );
2720
2662
 
2721
- sum11 += x0->d*y0->d*vaddvq_u16(p_0);
2722
- sum11 += x1->d*y1->d*vaddvq_u16(p_1);
2723
- #endif
2663
+ // Get absolute values of x vectors
2664
+ const __m256i ax = _mm256_sign_epi8( bx, bx );
2665
+
2666
+ // Sign the values of the y vectors
2667
+ const __m256i sy = _mm256_sign_epi8( by, bx );
2668
+
2669
+ // Perform multiplication and create 16-bit values
2670
+ const __m256i dot = _mm256_maddubs_epi16( ax, sy );
2671
+ const __m256i ones = _mm256_set1_epi16( 1 );
2672
+ const __m256i xy_q = _mm256_madd_epi16( ones, dot );
2673
+
2674
+ // Convert the vector of 8 int32_t to 8 floats
2675
+ const __m256 xy = _mm256_cvtepi32_ps( xy_q );
2676
+
2677
+ // Accumulate d0*d1*x*y
2678
+ acc = _mm256_fmadd_ps( d0d1, xy, acc );
2679
+
2680
+ // Compute sum of y values
2681
+ const __m256i y16_l = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( by ) );
2682
+ const __m256i y16_h = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( by, 1 ) );
2683
+ const __m256i ysumi = _mm256_madd_epi16( _mm256_add_epi16(y16_l, y16_h), ones );
2684
+ const __m256 ysum = _mm256_cvtepi32_ps( ysumi );
2685
+
2686
+ // Accumulate d1*m0*y
2687
+ acc = _mm256_fmadd_ps( d1m0, ysum, acc );
2724
2688
  }
2725
2689
 
2726
- sumf = QK4_1*sum00 + sum01 + sum10 + sum11;
2690
+ // Return horizontal sum of the acc vector
2691
+ __m128 res = _mm256_extractf128_ps( acc, 1 );
2692
+ res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
2693
+ res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
2694
+ res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
2695
+
2696
+ sumf = _mm_cvtss_f32( res );
2727
2697
  #else
2728
2698
  // scalar
2729
2699
  for (int i = 0; i < nb; i++) {
2730
2700
  const float d0 = x[i].d;
2731
- const float d1 = y[i].d;
2732
-
2733
2701
  const float m0 = x[i].m;
2734
- const float m1 = y[i].m;
2702
+ const float d1 = y[i].d;
2735
2703
 
2736
2704
  const uint8_t * restrict p0 = x[i].qs;
2737
- const uint8_t * restrict p1 = y[i].qs;
2705
+ const int8_t * restrict p1 = y[i].qs;
2738
2706
 
2739
- for (int j = 0; j < QK4_1/2; j++) {
2707
+ // TODO: this is very slow ...
2708
+ for (int j = 0; j < QK8_0/2; j++) {
2740
2709
  const uint8_t v0 = p0[j];
2741
- const uint8_t v1 = p1[j];
2742
2710
 
2743
2711
  const float f0 = d0*(v0 & 0xf) + m0;
2744
2712
  const float f1 = d0*(v0 >> 4) + m0;
2745
2713
 
2746
- const float f2 = d1*(v1 & 0xf) + m1;
2747
- const float f3 = d1*(v1 >> 4) + m1;
2714
+ const float f2 = d1*p1[2*j + 0];
2715
+ const float f3 = d1*p1[2*j + 1];
2748
2716
 
2749
2717
  sumf += f0*f2 + f1*f3;
2750
2718
  }
@@ -2754,32 +2722,36 @@ static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * rest
2754
2722
  *s = sumf;
2755
2723
  }
2756
2724
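The Q4_1 × Q8_0 kernel above relies on the fact that each Q4_1 value decodes to d0*q + m0, so the per-block dot product factors into a term scaled by d0*d1 and a term scaled by m0*d1; that is why the NEON and AVX2 paths accumulate the quant products and the sum of the y quants separately. A minimal scalar sketch of that identity (a hypothetical helper for illustration, field layout as in the blocks used above, not code from this diff):

    // sum over one block of (d0*qx + m0) * (d1*qy)
    //   = d0*d1 * sum(qx*qy) + m0*d1 * sum(qy)
    static float q4_1_q8_0_block_dot(float d0, float m0, const uint8_t * xq,
                                     float d1, const int8_t * yq, int qk) {
        int sxy = 0; // sum of products of the quantized values
        int sy  = 0; // sum of the y quants (multiplies the block minimum)
        for (int j = 0; j < qk/2; j++) {
            const int x0 = xq[j] & 0xf;
            const int x1 = xq[j] >> 4;
            sxy += x0*yq[2*j + 0] + x1*yq[2*j + 1];
            sy  += yq[2*j + 0] + yq[2*j + 1];
        }
        return d0*d1*sxy + m0*d1*sy;
    }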
 
2757
- static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
2725
+ static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
2758
2726
  const int nb = n / QK8_0;
2759
2727
 
2760
2728
  assert(n % QK8_0 == 0);
2761
2729
  assert(nb % 2 == 0);
2730
+ assert(QK8_0 == 2*QK4_2);
2762
2731
 
2763
- const block_q4_0 * restrict x = vx;
2732
+ const block_q4_2 * restrict x = vx;
2764
2733
  const block_q8_0 * restrict y = vy;
2765
2734
 
2766
2735
  float sumf = 0.0;
2767
2736
 
2768
2737
  #if defined(__ARM_NEON)
2769
- float sum0 = 0.0f;
2770
- float sum1 = 0.0f;
2738
+ float32x4_t sumv0 = vdupq_n_f32(0.0f);
2739
+ float32x4_t sumv1 = vdupq_n_f32(0.0f);
2771
2740
 
2772
2741
  for (int i = 0; i < nb; i += 2) {
2773
- const block_q4_0 * restrict x0 = &x[i + 0];
2774
- const block_q4_0 * restrict x1 = &x[i + 1];
2742
+ const block_q4_2 * restrict x0_0 = &x[2*(i + 0) + 0];
2743
+ const block_q4_2 * restrict x0_1 = &x[2*(i + 0) + 1];
2744
+ const block_q4_2 * restrict x1_0 = &x[2*(i + 1) + 0];
2745
+ const block_q4_2 * restrict x1_1 = &x[2*(i + 1) + 1];
2746
+
2775
2747
  const block_q8_0 * restrict y0 = &y[i + 0];
2776
2748
  const block_q8_0 * restrict y1 = &y[i + 1];
2777
2749
 
2778
2750
  const uint8x16_t m4b = vdupq_n_u8(0xf);
2779
2751
  const int8x16_t s8b = vdupq_n_s8(0x8);
2780
2752
 
2781
- const uint8x16_t v0_0 = vld1q_u8(x0->qs);
2782
- const uint8x16_t v0_1 = vld1q_u8(x1->qs);
2753
+ const uint8x16_t v0_0 = vcombine_u8(vld1_u8(x0_0->qs), vld1_u8(x0_1->qs));
2754
+ const uint8x16_t v0_1 = vcombine_u8(vld1_u8(x1_0->qs), vld1_u8(x1_1->qs));
2783
2755
 
2784
2756
  // 4-bit -> 8-bit
2785
2757
  const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
@@ -2793,77 +2765,78 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
2793
2765
  const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
2794
2766
  const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
2795
2767
 
2768
+ // interleave
2769
+ const int8x16_t v0_0lz = vzip1q_s8(v0_0ls, v0_0hs);
2770
+ const int8x16_t v0_0hz = vzip2q_s8(v0_0ls, v0_0hs);
2771
+ const int8x16_t v0_1lz = vzip1q_s8(v0_1ls, v0_1hs);
2772
+ const int8x16_t v0_1hz = vzip2q_s8(v0_1ls, v0_1hs);
2773
+
2796
2774
  // load y
2797
2775
  const int8x16_t v1_0l = vld1q_s8(y0->qs);
2798
2776
  const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
2799
2777
  const int8x16_t v1_1l = vld1q_s8(y1->qs);
2800
2778
  const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
2801
2779
 
2802
- // interleave
2803
- const int8x16_t v1_0ls = vuzp1q_s8(v1_0l, v1_0h);
2804
- const int8x16_t v1_0hs = vuzp2q_s8(v1_0l, v1_0h);
2805
- const int8x16_t v1_1ls = vuzp1q_s8(v1_1l, v1_1h);
2806
- const int8x16_t v1_1hs = vuzp2q_s8(v1_1l, v1_1h);
2807
-
2808
2780
  #if defined(__ARM_FEATURE_DOTPROD)
2809
- // dot product into int32x4_t
2810
- int32x4_t p_0 = vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls);
2811
- int32x4_t p_1 = vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls);
2812
-
2813
- p_0 = vdotq_s32(p_0, v0_0hs, v1_0hs);
2814
- p_1 = vdotq_s32(p_1, v0_1hs, v1_1hs);
2781
+ sumv0 = vmlaq_n_f32(sumv0, vaddq_f32(
2782
+ vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0lz, v1_0l)), GGML_FP16_TO_FP32(x0_0->d)),
2783
+ vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0hz, v1_0h)), GGML_FP16_TO_FP32(x0_1->d))), y0->d);
2815
2784
 
2816
- sum0 += x0->d*y0->d*vaddvq_s32(p_0);
2817
- sum1 += x1->d*y1->d*vaddvq_s32(p_1);
2785
+ sumv1 = vmlaq_n_f32(sumv1, vaddq_f32(
2786
+ vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_1lz, v1_1l)), GGML_FP16_TO_FP32(x1_0->d)),
2787
+ vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_1hz, v1_1h)), GGML_FP16_TO_FP32(x1_1->d))), y1->d);
2818
2788
  #else
2819
- const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls));
2820
- const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls));
2821
- const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0hs));
2822
- const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0hs));
2823
-
2824
- const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1ls));
2825
- const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1ls));
2826
- const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1hs));
2827
- const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1hs));
2828
-
2829
- const int16x8_t pl_0 = vaddq_s16(pl0l, pl0h);
2830
- const int16x8_t ph_0 = vaddq_s16(ph0l, ph0h);
2831
-
2832
- const int16x8_t pl_1 = vaddq_s16(pl1l, pl1h);
2833
- const int16x8_t ph_1 = vaddq_s16(ph1l, ph1h);
2834
-
2835
- const int16x8_t p_0 = vaddq_s16(pl_0, ph_0);
2836
- const int16x8_t p_1 = vaddq_s16(pl_1, ph_1);
2837
-
2838
- sum0 += x0->d*y0->d*vaddvq_s16(p_0);
2839
- sum1 += x1->d*y1->d*vaddvq_s16(p_1);
2789
+ const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lz), vget_low_s8 (v1_0l));
2790
+ const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lz), vget_high_s8(v1_0l));
2791
+ const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hz), vget_low_s8 (v1_0h));
2792
+ const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hz), vget_high_s8(v1_0h));
2793
+
2794
+ const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lz), vget_low_s8 (v1_1l));
2795
+ const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lz), vget_high_s8(v1_1l));
2796
+ const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hz), vget_low_s8 (v1_1h));
2797
+ const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hz), vget_high_s8(v1_1h));
2798
+
2799
+ const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
2800
+ const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
2801
+ const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
2802
+ const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
2803
+
2804
+ sumv0 = vmlaq_n_f32(sumv0, vaddq_f32(
2805
+ vmulq_n_f32(vcvtq_f32_s32(pl0), GGML_FP16_TO_FP32(x0_0->d)),
2806
+ vmulq_n_f32(vcvtq_f32_s32(ph0), GGML_FP16_TO_FP32(x0_1->d))), y0->d);
2807
+
2808
+ sumv1 = vmlaq_n_f32(sumv1, vaddq_f32(
2809
+ vmulq_n_f32(vcvtq_f32_s32(pl1), GGML_FP16_TO_FP32(x1_0->d)),
2810
+ vmulq_n_f32(vcvtq_f32_s32(ph1), GGML_FP16_TO_FP32(x1_1->d))), y1->d);
2840
2811
  #endif
2841
2812
  }
2842
2813
 
2843
- sumf = sum0 + sum1;
2814
+ sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
2844
2815
  #elif defined(__AVX2__)
2845
2816
  // Initialize accumulator with zeros
2846
2817
  __m256 acc = _mm256_setzero_ps();
2847
2818
 
2848
2819
  // Main loop
2849
- for (int i = 0; i < nb; ++i) {
2820
+ for (int i = 0; i < nb; i++) {
2850
2821
  /* Compute combined scale for the block */
2851
- const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );
2822
+ const __m128 d0 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 0].d));
2823
+ const __m128 d1 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 1].d));
2824
+ const __m256 d = _mm256_mul_ps(_mm256_set_m128(d1, d0), _mm256_broadcast_ss(&y[i].d));
2852
2825
 
2853
- __m256i bx = bytesFromNibbles(x[i].qs);
2826
+ __m128i bx0 = bytes_from_nibbles_16(x[2*i + 0].qs);
2827
+ __m128i bx1 = bytes_from_nibbles_16(x[2*i + 1].qs);
2828
+ __m256i bx = _mm256_set_m128i(bx1, bx0);
2854
2829
 
2855
2830
  // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
2856
- const __m256i off = _mm256_set1_epi8( 8 );
2857
- bx = _mm256_sub_epi8( bx, off );
2831
+ const __m256i off = _mm256_set1_epi8(8);
2832
+ bx = _mm256_sub_epi8(bx, off);
2858
2833
 
2859
2834
  __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
2860
2835
 
2861
2836
  // Get absolute values of x vectors
2862
2837
  const __m256i ax = _mm256_sign_epi8(bx, bx);
2863
-
2864
2838
  // Sign the values of the y vectors
2865
2839
  const __m256i sy = _mm256_sign_epi8(by, bx);
2866
-
2867
2840
  // Perform multiplication and create 16-bit values
2868
2841
  const __m256i dot = _mm256_maddubs_epi16(ax, sy);
2869
2842
 
@@ -2871,92 +2844,208 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
2871
2844
  __m256i xy_q = _mm256_madd_epi16(ones, dot);
2872
2845
 
2873
2846
  /* Convert the vector of 8 int32_t to 8 floats */
2874
- __m256 q = _mm256_cvtepi32_ps( xy_q );
2847
+ __m256 q = _mm256_cvtepi32_ps(xy_q);
2875
2848
 
2876
2849
  /* Multiply q with scale and accumulate */
2877
- acc = _mm256_fmadd_ps( d, q, acc );
2850
+ acc = _mm256_fmadd_ps(d, q, acc);
2878
2851
  }
2879
2852
 
2880
2853
  // Return horizontal sum of the acc vector
2881
- __m128 res = _mm256_extractf128_ps( acc, 1 );
2882
- res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
2883
- res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
2884
- res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
2854
+ __m128 res = _mm256_extractf128_ps(acc, 1);
2855
+ res = _mm_add_ps(res, _mm256_castps256_ps128(acc));
2856
+ res = _mm_add_ps(res, _mm_movehl_ps(res, res));
2857
+ res = _mm_add_ss(res, _mm_movehdup_ps(res));
2885
2858
 
2886
- sumf = _mm_cvtss_f32( res );
2887
- #elif defined(__AVX__)
2888
- // Initialize accumulator with zeros
2889
- __m256 acc = _mm256_setzero_ps();
2859
+ sumf = _mm_cvtss_f32(res);
2860
+ #else
2861
+ // scalar
2862
+ for (int i = 0; i < nb; i++) {
2863
+ const uint8_t * restrict x0 = x[2*i + 0].qs;
2864
+ const uint8_t * restrict x1 = x[2*i + 1].qs;
2865
+ const int8_t * restrict y0 = y[i].qs;
2890
2866
 
2891
- // Main loop
2892
- for (int i = 0; i < nb; ++i) {
2893
- // Compute combined scale for the block
2894
- const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );
2867
+ const float d0 = GGML_FP16_TO_FP32(x[2*i + 0].d);
2868
+ const float d1 = GGML_FP16_TO_FP32(x[2*i + 1].d);
2895
2869
 
2896
- __m128i i32[2];
2897
- for (int j = 0; j < 2; ++j) {
2898
- // Load 8 bytes, and unpack 4 bit fields into bytes, making 16 bytes
2899
- __m128i bx = bytesFromNibbles( x[i].qs + 8*j );
2900
- __m128i by = _mm_loadu_si128((const __m128i *)(y[i].qs + 16*j));
2870
+ int sumi_0 = 0;
2871
+ int sumi_1 = 0;
2872
+
2873
+ for (int j = 0; j < QK8_0/4; j++) {
2874
+ const uint8_t v0 = x0[j];
2875
+ const uint8_t v1 = x1[j];
2876
+
2877
+ const int i0_0 = (int8_t) (v0 & 0xf) - 8;
2878
+ const int i1_0 = (int8_t) (v0 >> 4) - 8;
2879
+
2880
+ const int i0_1 = (int8_t) (v1 & 0xf) - 8;
2881
+ const int i1_1 = (int8_t) (v1 >> 4) - 8;
2882
+
2883
+ const int i2_0 = y0[2*j + 0];
2884
+ const int i3_0 = y0[2*j + 1];
2885
+
2886
+ const int i2_1 = y0[2*(j + QK8_0/4) + 0];
2887
+ const int i3_1 = y0[2*(j + QK8_0/4) + 1];
2888
+
2889
+ sumi_0 += i0_0*i2_0 + i1_0*i3_0;
2890
+ sumi_1 += i0_1*i2_1 + i1_1*i3_1;
2891
+ }
2892
+
2893
+ sumf += (d0 * y[i].d) * sumi_0;
2894
+ sumf += (d1 * y[i].d) * sumi_1;
2895
+ }
2896
+ #endif
2897
+
2898
+ *s = sumf;
2899
+ }
2900
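In the Q4_2 × Q8_0 kernel above, one Q8_0 block spans two Q4_2 sub-blocks (QK8_0 == 2*QK4_2), each carrying its own fp16 scale and no minimum, so the quants are simply offset by 8. A minimal scalar sketch of one such chunk, assuming the block layouts used above (an illustration, not code from this diff):

    static float q4_2_q8_0_chunk_dot(const block_q4_2 * x0, const block_q4_2 * x1,
                                     const block_q8_0 * y) {
        int sumi_0 = 0; // first half of the Q8_0 block against x0
        int sumi_1 = 0; // second half of the Q8_0 block against x1
        for (int j = 0; j < QK8_0/4; j++) {
            sumi_0 += ((x0->qs[j] & 0xf) - 8)*y->qs[2*j + 0]
                    + ((x0->qs[j] >>  4) - 8)*y->qs[2*j + 1];
            sumi_1 += ((x1->qs[j] & 0xf) - 8)*y->qs[2*(j + QK8_0/4) + 0]
                    + ((x1->qs[j] >>  4) - 8)*y->qs[2*(j + QK8_0/4) + 1];
        }
        return y->d*(GGML_FP16_TO_FP32(x0->d)*sumi_0 + GGML_FP16_TO_FP32(x1->d)*sumi_1);
    }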
+
2901
+ static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
2902
+ const int nb = n / QK8_0;
2903
+
2904
+ assert(n % QK8_0 == 0);
2905
+ assert(nb % 2 == 0);
2906
+ assert(QK8_0 == 2*QK4_2);
2907
+
2908
+ const block_q4_3 * restrict x = vx;
2909
+ const block_q8_0 * restrict y = vy;
2910
+
2911
+ float sumf = 0.0;
2912
+
2913
+ #if defined(__ARM_NEON)
2914
+ float32x4_t sumv0 = vdupq_n_f32(0.0f);
2915
+ float32x4_t sumv1 = vdupq_n_f32(0.0f);
2916
+
2917
+ for (int i = 0; i < nb; i += 2) {
2918
+ const block_q4_3 * restrict x0_0 = &x[2*(i + 0) + 0];
2919
+ const block_q4_3 * restrict x0_1 = &x[2*(i + 0) + 1];
2920
+ const block_q4_3 * restrict x1_0 = &x[2*(i + 1) + 0];
2921
+ const block_q4_3 * restrict x1_1 = &x[2*(i + 1) + 1];
2922
+
2923
+ const block_q8_0 * restrict y0 = &y[i + 0];
2924
+ const block_q8_0 * restrict y1 = &y[i + 1];
2925
+
2926
+ const uint8x16_t m4b = vdupq_n_u8(0xf);
2927
+
2928
+ const float x0_0d = GGML_FP16_TO_FP32(x0_0->d);
2929
+ const float x0_1d = GGML_FP16_TO_FP32(x0_1->d);
2930
+ const float x1_0d = GGML_FP16_TO_FP32(x1_0->d);
2931
+ const float x1_1d = GGML_FP16_TO_FP32(x1_1->d);
2932
+
2933
+ const float x0_0m = GGML_FP16_TO_FP32(x0_0->m);
2934
+ const float x0_1m = GGML_FP16_TO_FP32(x0_1->m);
2935
+ const float x1_0m = GGML_FP16_TO_FP32(x1_0->m);
2936
+ const float x1_1m = GGML_FP16_TO_FP32(x1_1->m);
2937
+
2938
+ const uint8x16_t v0_0 = vcombine_u8(vld1_u8(x0_0->qs), vld1_u8(x0_1->qs));
2939
+ const uint8x16_t v0_1 = vcombine_u8(vld1_u8(x1_0->qs), vld1_u8(x1_1->qs));
2940
+
2941
+ // 4-bit -> 8-bit
2942
+ const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
2943
+ const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
2944
+ const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
2945
+ const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
2901
2946
 
2902
- // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
2903
- const __m128i off = _mm_set1_epi8( 8 );
2904
- bx = _mm_sub_epi8( bx, off );
2947
+ // interleave
2948
+ const int8x16_t v0_0lz = vzip1q_s8(v0_0l, v0_0h);
2949
+ const int8x16_t v0_0hz = vzip2q_s8(v0_0l, v0_0h);
2950
+ const int8x16_t v0_1lz = vzip1q_s8(v0_1l, v0_1h);
2951
+ const int8x16_t v0_1hz = vzip2q_s8(v0_1l, v0_1h);
2905
2952
 
2906
- // Get absolute values of x vectors
2907
- const __m128i ax = _mm_sign_epi8(bx, bx);
2953
+ // load y
2954
+ const int8x16_t v1_0l = vld1q_s8(y0->qs);
2955
+ const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
2956
+ const int8x16_t v1_1l = vld1q_s8(y1->qs);
2957
+ const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
2908
2958
 
2909
- // Sign the values of the y vectors
2910
- const __m128i sy = _mm_sign_epi8(by, bx);
2959
+ const int16x8_t sy0_0 = vaddq_s16(vmovl_s8(vget_low_s8(v1_0l)), vmovl_s8(vget_high_s8(v1_0l)));
2960
+ const int16x8_t sy0_1 = vaddq_s16(vmovl_s8(vget_low_s8(v1_0h)), vmovl_s8(vget_high_s8(v1_0h)));
2911
2961
 
2912
- // Perform multiplication and create 16-bit values
2913
- const __m128i dot = _mm_maddubs_epi16(ax, sy);
2962
+ const int16x8_t sy1_0 = vaddq_s16(vmovl_s8(vget_low_s8(v1_1l)), vmovl_s8(vget_high_s8(v1_1l)));
2963
+ const int16x8_t sy1_1 = vaddq_s16(vmovl_s8(vget_low_s8(v1_1h)), vmovl_s8(vget_high_s8(v1_1h)));
2914
2964
 
2915
- const __m128i ones = _mm_set1_epi16(1);
2916
- i32[j] = _mm_madd_epi16(ones, dot);
2917
- }
2965
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddl_s16(vget_low_s16(sy0_0), vget_high_s16(sy0_0))), x0_0m*y0->d);
2966
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddl_s16(vget_low_s16(sy0_1), vget_high_s16(sy0_1))), x0_1m*y0->d);
2967
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddl_s16(vget_low_s16(sy1_0), vget_high_s16(sy1_0))), x1_0m*y1->d);
2968
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddl_s16(vget_low_s16(sy1_1), vget_high_s16(sy1_1))), x1_1m*y1->d);
2918
2969
 
2919
- // Convert int32_t to float
2920
- __m256 p = _mm256_cvtepi32_ps( _mm256_set_m128i( i32[0], i32[1] ));
2921
- // Apply the scale, and accumulate
2922
- acc = _mm256_add_ps(_mm256_mul_ps( d, p ), acc);
2970
+ #if defined(__ARM_FEATURE_DOTPROD)
2971
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0lz, v1_0l)), x0_0d*y0->d);
2972
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0hz, v1_0h)), x0_1d*y0->d);
2973
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_1lz, v1_1l)), x1_0d*y1->d);
2974
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_1hz, v1_1h)), x1_1d*y1->d);
2975
+ #else
2976
+ const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lz), vget_low_s8 (v1_0l));
2977
+ const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lz), vget_high_s8(v1_0l));
2978
+ const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hz), vget_low_s8 (v1_0h));
2979
+ const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hz), vget_high_s8(v1_0h));
2980
+
2981
+ const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lz), vget_low_s8 (v1_1l));
2982
+ const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lz), vget_high_s8(v1_1l));
2983
+ const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hz), vget_low_s8 (v1_1h));
2984
+ const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hz), vget_high_s8(v1_1h));
2985
+
2986
+ const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
2987
+ const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
2988
+ const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
2989
+ const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
2990
+
2991
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(pl0), x0_0d*y0->d);
2992
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(ph0), x0_1d*y0->d);
2993
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(pl1), x1_0d*y1->d);
2994
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(ph1), x1_1d*y1->d);
2995
+ #endif
2923
2996
  }
2924
2997
 
2925
- // Return horizontal sum of the acc vector
2926
- __m128 res = _mm256_extractf128_ps( acc, 1 );
2927
- res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
2928
- res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
2929
- res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
2930
-
2931
- sumf = _mm_cvtss_f32( res );
2998
+ sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
2932
2999
  #else
2933
3000
  // scalar
2934
3001
  for (int i = 0; i < nb; i++) {
2935
- const float d0 = x[i].d;
2936
- const float d1 = y[i].d;
3002
+ const uint8_t * restrict x0 = x[2*i + 0].qs;
3003
+ const uint8_t * restrict x1 = x[2*i + 1].qs;
3004
+ const int8_t * restrict y0 = y[i].qs;
2937
3005
 
2938
- const uint8_t * restrict p0 = x[i].qs;
2939
- const int8_t * restrict p1 = y[i].qs;
3006
+ const float d0 = GGML_FP16_TO_FP32(x[2*i + 0].d);
3007
+ const float m0 = GGML_FP16_TO_FP32(x[2*i + 0].m);
3008
+ const float d1 = GGML_FP16_TO_FP32(x[2*i + 1].d);
3009
+ const float m1 = GGML_FP16_TO_FP32(x[2*i + 1].m);
2940
3010
 
2941
- int sumi = 0;
2942
- for (int j = 0; j < QK8_0/2; j++) {
2943
- const uint8_t v0 = p0[j];
3011
+ int sy_0 = 0;
3012
+ int sy_1 = 0;
2944
3013
 
2945
- const int i0 = (int8_t) (v0 & 0xf) - 8;
2946
- const int i1 = (int8_t) (v0 >> 4) - 8;
3014
+ int sxy_0 = 0;
3015
+ int sxy_1 = 0;
2947
3016
 
2948
- const int i2 = p1[2*j + 0];
2949
- const int i3 = p1[2*j + 1];
3017
+ for (int j = 0; j < QK8_0/4; j++) {
3018
+ const uint8_t v0 = x0[j];
3019
+ const uint8_t v1 = x1[j];
2950
3020
 
2951
- sumi += i0*i2 + i1*i3;
3021
+ const int x0_0 = v0 & 0xf;
3022
+ const int x1_0 = v0 >> 4;
3023
+
3024
+ const int x0_1 = v1 & 0xf;
3025
+ const int x1_1 = v1 >> 4;
3026
+
3027
+ const int y0_0 = y0[2*j + 0];
3028
+ const int y1_0 = y0[2*j + 1];
3029
+
3030
+ const int y0_1 = y0[2*(j + QK8_0/4) + 0];
3031
+ const int y1_1 = y0[2*(j + QK8_0/4) + 1];
3032
+
3033
+ sy_0 += y0_0 + y1_0;
3034
+ sy_1 += y0_1 + y1_1;
3035
+
3036
+ sxy_0 += x0_0*y0_0 + x1_0*y1_0;
3037
+ sxy_1 += x0_1*y0_1 + x1_1*y1_1;
2952
3038
  }
2953
- sumf += d0*d1*sumi;
3039
+
3040
+ sumf += (d0*sxy_0 + m0*sy_0)*y[i].d;
3041
+ sumf += (d1*sxy_1 + m1*sy_1)*y[i].d;
2954
3042
  }
2955
3043
  #endif
2956
3044
 
2957
3045
  *s = sumf;
2958
3046
  }
2959
3047
 
3048
+
2960
3049
  // compute GGML_VEC_DOT_UNROLL dot products at once
2961
3050
  // xs - x row stride in bytes
2962
3051
  inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * restrict s, void * restrict xv, ggml_fp16_t * restrict y) {
@@ -3203,24 +3292,28 @@ static const int GGML_BLCK_SIZE[GGML_TYPE_COUNT] = {
3203
3292
  [GGML_TYPE_F16] = 1,
3204
3293
  [GGML_TYPE_Q4_0] = QK4_0,
3205
3294
  [GGML_TYPE_Q4_1] = QK4_1,
3295
+ [GGML_TYPE_Q4_2] = QK4_2,
3296
+ [GGML_TYPE_Q4_3] = QK4_3,
3206
3297
  [GGML_TYPE_Q8_0] = QK8_0,
3207
3298
  [GGML_TYPE_I8] = 1,
3208
3299
  [GGML_TYPE_I16] = 1,
3209
3300
  [GGML_TYPE_I32] = 1,
3210
3301
  };
3211
- static_assert(GGML_TYPE_COUNT == 8, "GGML_BLCK_SIZE is outdated");
3302
+ static_assert(GGML_TYPE_COUNT == 10, "GGML_BLCK_SIZE is outdated");
3212
3303
 
3213
3304
  static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = {
3214
3305
  [GGML_TYPE_F32] = sizeof(float),
3215
3306
  [GGML_TYPE_F16] = sizeof(ggml_fp16_t),
3216
3307
  [GGML_TYPE_Q4_0] = sizeof(block_q4_0),
3217
3308
  [GGML_TYPE_Q4_1] = sizeof(block_q4_1),
3309
+ [GGML_TYPE_Q4_2] = sizeof(block_q4_2),
3310
+ [GGML_TYPE_Q4_3] = sizeof(block_q4_3),
3218
3311
  [GGML_TYPE_Q8_0] = sizeof(block_q8_0),
3219
3312
  [GGML_TYPE_I8] = sizeof(int8_t),
3220
3313
  [GGML_TYPE_I16] = sizeof(int16_t),
3221
3314
  [GGML_TYPE_I32] = sizeof(int32_t),
3222
3315
  };
3223
- static_assert(GGML_TYPE_COUNT == 8, "GGML_TYPE_SIZE is outdated");
3316
+ static_assert(GGML_TYPE_COUNT == 10, "GGML_TYPE_SIZE is outdated");
3224
3317
 
3225
3318
 
3226
3319
  static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = {
@@ -3228,12 +3321,28 @@ static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = {
3228
3321
  [GGML_TYPE_F16] = "f16",
3229
3322
  [GGML_TYPE_Q4_0] = "q4_0",
3230
3323
  [GGML_TYPE_Q4_1] = "q4_1",
3324
+ [GGML_TYPE_Q4_2] = "q4_2",
3325
+ [GGML_TYPE_Q4_3] = "q4_3",
3231
3326
  [GGML_TYPE_Q8_0] = "q8_0",
3232
3327
  [GGML_TYPE_I8] = "i8",
3233
3328
  [GGML_TYPE_I16] = "i16",
3234
3329
  [GGML_TYPE_I32] = "i32",
3235
3330
  };
3236
- static_assert(GGML_TYPE_COUNT == 8, "GGML_TYPE_NAME is outdated");
3331
+ static_assert(GGML_TYPE_COUNT == 10, "GGML_TYPE_NAME is outdated");
3332
+
3333
+ static bool GGML_IS_QUANTIZED[GGML_TYPE_COUNT] = {
3334
+ [GGML_TYPE_F32] = false,
3335
+ [GGML_TYPE_F16] = false,
3336
+ [GGML_TYPE_Q4_0] = true,
3337
+ [GGML_TYPE_Q4_1] = true,
3338
+ [GGML_TYPE_Q4_2] = true,
3339
+ [GGML_TYPE_Q4_3] = true,
3340
+ [GGML_TYPE_Q8_0] = true,
3341
+ [GGML_TYPE_I8] = false,
3342
+ [GGML_TYPE_I16] = false,
3343
+ [GGML_TYPE_I32] = false,
3344
+ };
3345
+ static_assert(GGML_TYPE_COUNT == 10, "GGML_IS_QUANTIZED is outdated");
3237
3346
 
3238
3347
  static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
3239
3348
  "NONE",
@@ -3495,6 +3604,10 @@ static inline bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct
3495
3604
  (t0->ne[3] == t1->ne[3]);
3496
3605
  }
3497
3606
 
3607
+ bool ggml_is_quantized(enum ggml_type type) {
3608
+ return GGML_IS_QUANTIZED[type];
3609
+ }
3610
+
3498
3611
  static inline bool ggml_is_transposed(const struct ggml_tensor * tensor) {
3499
3612
  return tensor->nb[0] > tensor->nb[1];
3500
3613
  }
@@ -3605,6 +3718,11 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
3605
3718
  GGML_PRINT_DEBUG("%s: g_state initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f);
3606
3719
  }
3607
3720
 
3721
+ // initialize cuBLAS
3722
+ #if defined(GGML_USE_CUBLAS)
3723
+ init_cublas();
3724
+ #endif
3725
+
3608
3726
  is_first_call = false;
3609
3727
  }
3610
3728
 
@@ -5535,7 +5653,6 @@ static void ggml_compute_forward_dup_f16(
5535
5653
  const struct ggml_compute_params * params,
5536
5654
  const struct ggml_tensor * src0,
5537
5655
  struct ggml_tensor * dst) {
5538
- GGML_ASSERT(params->ith == 0);
5539
5656
  GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
5540
5657
 
5541
5658
  if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
@@ -5547,6 +5664,11 @@ static void ggml_compute_forward_dup_f16(
5547
5664
  const int64_t ne02 = src0->ne[2];
5548
5665
  const int64_t ne03 = src0->ne[3];
5549
5666
 
5667
+ const int64_t ne0 = dst->ne[0];
5668
+ const int64_t ne1 = dst->ne[1];
5669
+ const int64_t ne2 = dst->ne[2];
5670
+ const int64_t ne3 = dst->ne[3];
5671
+
5550
5672
  const size_t nb00 = src0->nb[0];
5551
5673
  const size_t nb01 = src0->nb[1];
5552
5674
  const size_t nb02 = src0->nb[2];
@@ -5557,19 +5679,40 @@ static void ggml_compute_forward_dup_f16(
5557
5679
  const size_t nb2 = dst->nb[2];
5558
5680
  const size_t nb3 = dst->nb[3];
5559
5681
 
5682
+ const int ith = params->ith; // thread index
5683
+ const int nth = params->nth; // number of threads
5684
+
5560
5685
  if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst) && src0->type == dst->type) {
5561
- memcpy(dst->data, src0->data, ggml_nelements(dst) * GGML_TYPE_SIZE[src0->type]);
5686
+ // parallelize by elements
5687
+ const int ne = ggml_nelements(dst);
5688
+ const int dr = (ne + nth - 1) / nth;
5689
+ const int ie0 = dr * ith;
5690
+ const int ie1 = MIN(ie0 + dr, ne);
5691
+
5692
+ memcpy(
5693
+ ((char *) dst->data + ie0*nb0),
5694
+ ((char *) src0->data + ie0*nb00),
5695
+ (ie1 - ie0) * GGML_TYPE_SIZE[src0->type]);
5696
+
5562
5697
  return;
5563
5698
  }
5564
5699
 
5700
+ // parallelize by rows
5701
+ const int nr = ne01;
5702
+ // number of rows per thread
5703
+ const int dr = (nr + nth - 1) / nth;
5704
+ // row range for this thread
5705
+ const int ir0 = dr * ith;
5706
+ const int ir1 = MIN(ir0 + dr, nr);
5707
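The split above is a plain ceil-divide over rows: each thread takes a contiguous range of dr rows, and the last range is clamped by MIN. A worked example with assumed values (nr = 10 rows, nth = 3 threads):

    // dr  = (10 + 3 - 1)/3 = 4
    // ith = 0 -> rows [0, 4)
    // ith = 1 -> rows [4, 8)
    // ith = 2 -> rows [8, 10)   (MIN(12, 10) clamps the upper bound)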
+
5565
5708
  if (src0->type == dst->type &&
5566
- src0->ne[0] == dst->ne[0] &&
5567
- src0->nb[0] == GGML_TYPE_SIZE[src0->type] && dst->nb[0] == GGML_TYPE_SIZE[dst->type]) {
5709
+ ne00 == ne0 &&
5710
+ nb00 == GGML_TYPE_SIZE[src0->type] && nb0 == GGML_TYPE_SIZE[dst->type]) {
5568
5711
  // copy by rows
5569
5712
  const size_t rs = ne00*nb00;
5570
5713
  for (int64_t i03 = 0; i03 < ne03; i03++) {
5571
5714
  for (int64_t i02 = 0; i02 < ne02; i02++) {
5572
- for (int64_t i01 = 0; i01 < ne01; i01++) {
5715
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
5573
5716
  memcpy(
5574
5717
  ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3),
5575
5718
  ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
@@ -5583,21 +5726,21 @@ static void ggml_compute_forward_dup_f16(
5583
5726
  // TODO: add more special-case implementations for tensor shapes/strides that can benefit from memcpy
5584
5727
 
5585
5728
  if (ggml_is_contiguous(dst)) {
5586
- if (src0->nb[0] == sizeof(ggml_fp16_t)) {
5729
+ if (nb00 == sizeof(ggml_fp16_t)) {
5587
5730
  if (dst->type == GGML_TYPE_F16) {
5588
5731
  size_t id = 0;
5589
- const size_t rs = ne00*nb00;
5732
+ const size_t rs = ne00 * nb00;
5733
+ char * dst_ptr = (char *) dst->data;
5590
5734
 
5591
5735
  for (int i03 = 0; i03 < ne03; i03++) {
5592
5736
  for (int i02 = 0; i02 < ne02; i02++) {
5593
- for (int i01 = 0; i01 < ne01; i01++) {
5737
+ id += rs * ir0;
5738
+ for (int i01 = ir0; i01 < ir1; i01++) {
5594
5739
  const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
5595
- char * dst_ptr = (char *) dst->data + id*rs;
5596
-
5597
- memcpy(dst_ptr, src0_ptr, rs);
5598
-
5599
- id++;
5740
+ memcpy(dst_ptr + id, src0_ptr, rs);
5741
+ id += rs;
5600
5742
  }
5743
+ id += rs * (ne01 - ir1);
5601
5744
  }
5602
5745
  }
5603
5746
  } else if (dst->type == GGML_TYPE_F32) {
@@ -5606,34 +5749,39 @@ static void ggml_compute_forward_dup_f16(
5606
5749
 
5607
5750
  for (int i03 = 0; i03 < ne03; i03++) {
5608
5751
  for (int i02 = 0; i02 < ne02; i02++) {
5609
- for (int i01 = 0; i01 < ne01; i01++) {
5752
+ id += ne00 * ir0;
5753
+ for (int i01 = ir0; i01 < ir1; i01++) {
5754
+ const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
5610
5755
  for (int i00 = 0; i00 < ne00; i00++) {
5611
- const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
5612
-
5613
- dst_ptr[id] = GGML_FP16_TO_FP32(*src0_ptr);
5756
+ dst_ptr[id] = GGML_FP16_TO_FP32(src0_ptr[i00]);
5614
5757
  id++;
5615
5758
  }
5616
5759
  }
5760
+ id += ne00 * (ne01 - ir1);
5617
5761
  }
5618
5762
  }
5619
- } else if (dst->type == GGML_TYPE_Q4_0 || dst->type == GGML_TYPE_Q4_1) {
5763
+ } else if (ggml_is_quantized(dst->type)) {
5620
5764
  quantize_row_q_t const quantize_row_q = quantize_fns[dst->type].quantize_row_q;
5765
+ float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
5766
+
5621
5767
  size_t id = 0;
5622
- uint8_t * dst_ptr = (uint8_t *) dst->data;
5623
- size_t dst_row_size = nb0 * (ne00 / GGML_BLCK_SIZE[dst->type]);
5624
- float * src0_f32 = (float *) params->wdata;
5768
+ size_t rs = nb0 * (ne00 / GGML_BLCK_SIZE[dst->type]);
5769
+ char * dst_ptr = (char *) dst->data;
5625
5770
 
5626
5771
  for (int i03 = 0; i03 < ne03; i03++) {
5627
5772
  for (int i02 = 0; i02 < ne02; i02++) {
5628
- for (int i01 = 0; i01 < ne01; i01++) {
5773
+ id += rs * ir0;
5774
+ for (int i01 = ir0; i01 < ir1; i01++) {
5629
5775
  const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
5630
- // convert to f32 and quantize
5776
+
5631
5777
  for (int i00 = 0; i00 < ne00; i00++) {
5632
5778
  src0_f32[i00] = GGML_FP16_TO_FP32(src0_ptr[i00]);
5633
5779
  }
5780
+
5634
5781
  quantize_row_q(src0_f32, dst_ptr + id, ne00);
5635
- id += dst_row_size;
5782
+ id += rs;
5636
5783
  }
5784
+ id += rs * (ne01 - ir1);
5637
5785
  }
5638
5786
  }
5639
5787
  } else {
@@ -5648,7 +5796,8 @@ static void ggml_compute_forward_dup_f16(
5648
5796
 
5649
5797
  for (int i03 = 0; i03 < ne03; i03++) {
5650
5798
  for (int i02 = 0; i02 < ne02; i02++) {
5651
- for (int i01 = 0; i01 < ne01; i01++) {
5799
+ id += ne00 * ir0;
5800
+ for (int i01 = ir0; i01 < ir1; i01++) {
5652
5801
  for (int i00 = 0; i00 < ne00; i00++) {
5653
5802
  const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
5654
5803
 
@@ -5656,6 +5805,7 @@ static void ggml_compute_forward_dup_f16(
5656
5805
  id++;
5657
5806
  }
5658
5807
  }
5808
+ id += ne00 * (ne01 - ir1);
5659
5809
  }
5660
5810
  }
5661
5811
  } else if (dst->type == GGML_TYPE_F16) {
@@ -5664,7 +5814,8 @@ static void ggml_compute_forward_dup_f16(
5664
5814
 
5665
5815
  for (int i03 = 0; i03 < ne03; i03++) {
5666
5816
  for (int i02 = 0; i02 < ne02; i02++) {
5667
- for (int i01 = 0; i01 < ne01; i01++) {
5817
+ id += ne00 * ir0;
5818
+ for (int i01 = ir0; i01 < ir1; i01++) {
5668
5819
  for (int i00 = 0; i00 < ne00; i00++) {
5669
5820
  const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
5670
5821
 
@@ -5672,6 +5823,7 @@ static void ggml_compute_forward_dup_f16(
5672
5823
  id++;
5673
5824
  }
5674
5825
  }
5826
+ id += ne00 * (ne01 - ir1);
5675
5827
  }
5676
5828
  }
5677
5829
  } else {
@@ -5690,7 +5842,20 @@ static void ggml_compute_forward_dup_f16(
5690
5842
  if (dst->type == GGML_TYPE_F16) {
5691
5843
  for (int64_t i03 = 0; i03 < ne03; i03++) {
5692
5844
  for (int64_t i02 = 0; i02 < ne02; i02++) {
5693
- for (int64_t i01 = 0; i01 < ne01; i01++) {
5845
+ i10 += ne00 * ir0;
5846
+ while (i10 >= ne0) {
5847
+ i10 -= ne0;
5848
+ if (++i11 == ne1) {
5849
+ i11 = 0;
5850
+ if (++i12 == ne2) {
5851
+ i12 = 0;
5852
+ if (++i13 == ne3) {
5853
+ i13 = 0;
5854
+ }
5855
+ }
5856
+ }
5857
+ }
5858
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
5694
5859
  for (int64_t i00 = 0; i00 < ne00; i00++) {
5695
5860
  const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
5696
5861
  char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
@@ -5711,25 +5876,51 @@ static void ggml_compute_forward_dup_f16(
5711
5876
  }
5712
5877
  }
5713
5878
  }
5879
+ i10 += ne00 * (ne01 - ir1);
5880
+ while (i10 >= ne0) {
5881
+ i10 -= ne0;
5882
+ if (++i11 == ne1) {
5883
+ i11 = 0;
5884
+ if (++i12 == ne2) {
5885
+ i12 = 0;
5886
+ if (++i13 == ne3) {
5887
+ i13 = 0;
5888
+ }
5889
+ }
5890
+ }
5891
+ }
5714
5892
  }
5715
5893
  }
5716
5894
  } else if (dst->type == GGML_TYPE_F32) {
5717
5895
  for (int64_t i03 = 0; i03 < ne03; i03++) {
5718
5896
  for (int64_t i02 = 0; i02 < ne02; i02++) {
5719
- for (int64_t i01 = 0; i01 < ne01; i01++) {
5897
+ i10 += ne00 * ir0;
5898
+ while (i10 >= ne0) {
5899
+ i10 -= ne0;
5900
+ if (++i11 == ne1) {
5901
+ i11 = 0;
5902
+ if (++i12 == ne2) {
5903
+ i12 = 0;
5904
+ if (++i13 == ne3) {
5905
+ i13 = 0;
5906
+ }
5907
+ }
5908
+ }
5909
+ }
5910
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
5720
5911
  for (int64_t i00 = 0; i00 < ne00; i00++) {
5721
5912
  const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
5722
5913
  char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
5723
5914
 
5724
5915
  *(float *) dst_ptr = GGML_FP16_TO_FP32(*(const ggml_fp16_t *) src0_ptr);
5725
5916
 
5726
- if (++i10 == ne00) {
5917
+ if (++i10 == ne0) {
5727
5918
  i10 = 0;
5728
- if (++i11 == ne01) {
5919
+ if (++i11 == ne1) {
5729
5920
  i11 = 0;
5730
- if (++i12 == ne02) {
5921
+ if (++i12 == ne2) {
5731
5922
  i12 = 0;
5732
- if (++i13 == ne03) {
5923
+ if (++i13 == ne3) {
5733
5924
  i13 = 0;
5734
5925
  }
5735
5926
  }
@@ -5737,6 +5928,19 @@ static void ggml_compute_forward_dup_f16(
5737
5928
  }
5738
5929
  }
5739
5930
  }
5931
+ i10 += ne00 * (ne01 - ir1);
5932
+ while (i10 >= ne0) {
5933
+ i10 -= ne0;
5934
+ if (++i11 == ne1) {
5935
+ i11 = 0;
5936
+ if (++i12 == ne2) {
5937
+ i12 = 0;
5938
+ if (++i13 == ne3) {
5939
+ i13 = 0;
5940
+ }
5941
+ }
5942
+ }
5943
+ }
5740
5944
  }
5741
5945
  }
5742
5946
  } else {
@@ -5748,7 +5952,6 @@ static void ggml_compute_forward_dup_f32(
5748
5952
  const struct ggml_compute_params * params,
5749
5953
  const struct ggml_tensor * src0,
5750
5954
  struct ggml_tensor * dst) {
5751
- GGML_ASSERT(params->ith == 0);
5752
5955
  GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
5753
5956
 
5754
5957
  if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
@@ -5760,6 +5963,11 @@ static void ggml_compute_forward_dup_f32(
5760
5963
  const int64_t ne02 = src0->ne[2];
5761
5964
  const int64_t ne03 = src0->ne[3];
5762
5965
 
5966
+ const int64_t ne0 = dst->ne[0];
5967
+ const int64_t ne1 = dst->ne[1];
5968
+ const int64_t ne2 = dst->ne[2];
5969
+ const int64_t ne3 = dst->ne[3];
5970
+
5763
5971
  const size_t nb00 = src0->nb[0];
5764
5972
  const size_t nb01 = src0->nb[1];
5765
5973
  const size_t nb02 = src0->nb[2];
@@ -5770,19 +5978,40 @@ static void ggml_compute_forward_dup_f32(
5770
5978
  const size_t nb2 = dst->nb[2];
5771
5979
  const size_t nb3 = dst->nb[3];
5772
5980
 
5981
+ const int ith = params->ith; // thread index
5982
+ const int nth = params->nth; // number of threads
5983
+
5773
5984
  if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst) && src0->type == dst->type) {
5774
- memcpy(dst->data, src0->data, ggml_nelements(dst) * GGML_TYPE_SIZE[src0->type]);
5985
+ // parallelize by elements
5986
+ const int ne = ggml_nelements(dst);
5987
+ const int dr = (ne + nth - 1) / nth;
5988
+ const int ie0 = dr * ith;
5989
+ const int ie1 = MIN(ie0 + dr, ne);
5990
+
5991
+ memcpy(
5992
+ ((char *) dst->data + ie0*nb0),
5993
+ ((char *) src0->data + ie0*nb00),
5994
+ (ie1 - ie0) * GGML_TYPE_SIZE[src0->type]);
5995
+
5775
5996
  return;
5776
5997
  }
5777
5998
 
5999
+ // parallelize by rows
6000
+ const int nr = ne01;
6001
+ // number of rows per thread
6002
+ const int dr = (nr + nth - 1) / nth;
6003
+ // row range for this thread
6004
+ const int ir0 = dr * ith;
6005
+ const int ir1 = MIN(ir0 + dr, nr);
6006
+
5778
6007
  if (src0->type == dst->type &&
5779
- src0->ne[0] == dst->ne[0] &&
5780
- src0->nb[0] == GGML_TYPE_SIZE[src0->type] && dst->nb[0] == GGML_TYPE_SIZE[dst->type]) {
6008
+ ne00 == ne0 &&
6009
+ nb00 == GGML_TYPE_SIZE[src0->type] && nb0 == GGML_TYPE_SIZE[dst->type]) {
5781
6010
  // copy by rows
5782
6011
  const size_t rs = ne00*nb00;
5783
6012
  for (int64_t i03 = 0; i03 < ne03; i03++) {
5784
6013
  for (int64_t i02 = 0; i02 < ne02; i02++) {
5785
- for (int64_t i01 = 0; i01 < ne01; i01++) {
6014
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
5786
6015
  memcpy(
5787
6016
  ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3),
5788
6017
  ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
@@ -5795,21 +6024,21 @@ static void ggml_compute_forward_dup_f32(
5795
6024
 
5796
6025
  if (ggml_is_contiguous(dst)) {
5797
6026
  // TODO: simplify
5798
- if (src0->nb[0] == sizeof(float)) {
6027
+ if (nb00 == sizeof(float)) {
5799
6028
  if (dst->type == GGML_TYPE_F32) {
5800
6029
  size_t id = 0;
5801
- const size_t rs = ne00*nb00;
6030
+ const size_t rs = ne00 * nb00;
6031
+ char * dst_ptr = (char *) dst->data;
5802
6032
 
5803
6033
  for (int i03 = 0; i03 < ne03; i03++) {
5804
6034
  for (int i02 = 0; i02 < ne02; i02++) {
5805
- for (int i01 = 0; i01 < ne01; i01++) {
6035
+ id += rs * ir0;
6036
+ for (int i01 = ir0; i01 < ir1; i01++) {
5806
6037
  const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
5807
- char * dst_ptr = (char *) dst->data + id*rs;
5808
-
5809
- memcpy(dst_ptr, src0_ptr, rs);
5810
-
5811
- id++;
6038
+ memcpy(dst_ptr + id, src0_ptr, rs);
6039
+ id += rs;
5812
6040
  }
6041
+ id += rs * (ne01 - ir1);
5813
6042
  }
5814
6043
  }
5815
6044
  } else if (dst->type == GGML_TYPE_F16) {
@@ -5818,7 +6047,8 @@ static void ggml_compute_forward_dup_f32(
5818
6047
 
5819
6048
  for (int i03 = 0; i03 < ne03; i03++) {
5820
6049
  for (int i02 = 0; i02 < ne02; i02++) {
5821
- for (int i01 = 0; i01 < ne01; i01++) {
6050
+ id += ne00 * ir0;
6051
+ for (int i01 = ir0; i01 < ir1; i01++) {
5822
6052
  for (int i00 = 0; i00 < ne00; i00++) {
5823
6053
  const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
5824
6054
 
@@ -5826,21 +6056,25 @@ static void ggml_compute_forward_dup_f32(
5826
6056
  id++;
5827
6057
  }
5828
6058
  }
6059
+ id += ne00 * (ne01 - ir1);
5829
6060
  }
5830
6061
  }
5831
- } else if (dst->type == GGML_TYPE_Q4_0 || dst->type == GGML_TYPE_Q4_1) {
6062
+ } else if (ggml_is_quantized(dst->type)) {
5832
6063
  quantize_row_q_t const quantize_row_q = quantize_fns[dst->type].quantize_row_q;
6064
+
5833
6065
  size_t id = 0;
5834
- uint8_t * dst_ptr = (uint8_t *) dst->data;
5835
- size_t dst_row_size = nb0 * (ne00 / GGML_BLCK_SIZE[dst->type]);
6066
+ size_t rs = nb0 * (ne00 / GGML_BLCK_SIZE[dst->type]);
6067
+ char * dst_ptr = (char *) dst->data;
5836
6068
 
5837
6069
  for (int i03 = 0; i03 < ne03; i03++) {
5838
6070
  for (int i02 = 0; i02 < ne02; i02++) {
5839
- for (int i01 = 0; i01 < ne01; i01++) {
6071
+ id += rs * ir0;
6072
+ for (int i01 = ir0; i01 < ir1; i01++) {
5840
6073
  const float * src0_ptr = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
5841
6074
  quantize_row_q(src0_ptr, dst_ptr + id, ne00);
5842
- id += dst_row_size;
6075
+ id += rs;
5843
6076
  }
6077
+ id += rs * (ne01 - ir1);
5844
6078
  }
5845
6079
  }
5846
6080
  } else {
@@ -5855,7 +6089,8 @@ static void ggml_compute_forward_dup_f32(
5855
6089
 
5856
6090
  for (int i03 = 0; i03 < ne03; i03++) {
5857
6091
  for (int i02 = 0; i02 < ne02; i02++) {
5858
- for (int i01 = 0; i01 < ne01; i01++) {
6092
+ id += ne00 * ir0;
6093
+ for (int i01 = ir0; i01 < ir1; i01++) {
5859
6094
  for (int i00 = 0; i00 < ne00; i00++) {
5860
6095
  const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
5861
6096
 
@@ -5863,6 +6098,7 @@ static void ggml_compute_forward_dup_f32(
5863
6098
  id++;
5864
6099
  }
5865
6100
  }
6101
+ id += ne00 * (ne01 - ir1);
5866
6102
  }
5867
6103
  }
5868
6104
  } else if (dst->type == GGML_TYPE_F16) {
@@ -5871,7 +6107,8 @@ static void ggml_compute_forward_dup_f32(
5871
6107
 
5872
6108
  for (int i03 = 0; i03 < ne03; i03++) {
5873
6109
  for (int i02 = 0; i02 < ne02; i02++) {
5874
- for (int i01 = 0; i01 < ne01; i01++) {
6110
+ id += ne00 * ir0;
6111
+ for (int i01 = ir0; i01 < ir1; i01++) {
5875
6112
  for (int i00 = 0; i00 < ne00; i00++) {
5876
6113
  const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
5877
6114
 
@@ -5879,6 +6116,7 @@ static void ggml_compute_forward_dup_f32(
5879
6116
  id++;
5880
6117
  }
5881
6118
  }
6119
+ id += ne00 * (ne01 - ir1);
5882
6120
  }
5883
6121
  }
5884
6122
  } else {
@@ -5890,6 +6128,7 @@ static void ggml_compute_forward_dup_f32(
5890
6128
  }
5891
6129
 
5892
6130
  // dst counters
6131
+
5893
6132
  int64_t i10 = 0;
5894
6133
  int64_t i11 = 0;
5895
6134
  int64_t i12 = 0;
@@ -5898,20 +6137,33 @@ static void ggml_compute_forward_dup_f32(
5898
6137
  if (dst->type == GGML_TYPE_F32) {
5899
6138
  for (int64_t i03 = 0; i03 < ne03; i03++) {
5900
6139
  for (int64_t i02 = 0; i02 < ne02; i02++) {
5901
- for (int64_t i01 = 0; i01 < ne01; i01++) {
6140
+ i10 += ne00 * ir0;
6141
+ while (i10 >= ne0) {
6142
+ i10 -= ne0;
6143
+ if (++i11 == ne1) {
6144
+ i11 = 0;
6145
+ if (++i12 == ne2) {
6146
+ i12 = 0;
6147
+ if (++i13 == ne3) {
6148
+ i13 = 0;
6149
+ }
6150
+ }
6151
+ }
6152
+ }
6153
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
5902
6154
  for (int64_t i00 = 0; i00 < ne00; i00++) {
5903
6155
  const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
5904
6156
  char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
5905
6157
 
5906
6158
  memcpy(dst_ptr, src0_ptr, sizeof(float));
5907
6159
 
5908
- if (++i10 == dst->ne[0]) {
6160
+ if (++i10 == ne0) {
5909
6161
  i10 = 0;
5910
- if (++i11 == dst->ne[1]) {
6162
+ if (++i11 == ne1) {
5911
6163
  i11 = 0;
5912
- if (++i12 == dst->ne[2]) {
6164
+ if (++i12 == ne2) {
5913
6165
  i12 = 0;
5914
- if (++i13 == dst->ne[3]) {
6166
+ if (++i13 == ne3) {
5915
6167
  i13 = 0;
5916
6168
  }
5917
6169
  }
@@ -5919,25 +6171,51 @@ static void ggml_compute_forward_dup_f32(
5919
6171
  }
5920
6172
  }
5921
6173
  }
6174
+ i10 += ne00 * (ne01 - ir1);
6175
+ while (i10 >= ne0) {
6176
+ i10 -= ne0;
6177
+ if (++i11 == ne1) {
6178
+ i11 = 0;
6179
+ if (++i12 == ne2) {
6180
+ i12 = 0;
6181
+ if (++i13 == ne3) {
6182
+ i13 = 0;
6183
+ }
6184
+ }
6185
+ }
6186
+ }
5922
6187
  }
5923
6188
  }
5924
6189
  } else if (dst->type == GGML_TYPE_F16) {
5925
6190
  for (int64_t i03 = 0; i03 < ne03; i03++) {
5926
6191
  for (int64_t i02 = 0; i02 < ne02; i02++) {
5927
- for (int64_t i01 = 0; i01 < ne01; i01++) {
6192
+ i10 += ne00 * ir0;
6193
+ while (i10 >= ne0) {
6194
+ i10 -= ne0;
6195
+ if (++i11 == ne1) {
6196
+ i11 = 0;
6197
+ if (++i12 == ne2) {
6198
+ i12 = 0;
6199
+ if (++i13 == ne3) {
6200
+ i13 = 0;
6201
+ }
6202
+ }
6203
+ }
6204
+ }
6205
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
5928
6206
  for (int64_t i00 = 0; i00 < ne00; i00++) {
5929
6207
  const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
5930
6208
  char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
5931
6209
 
5932
6210
  *(ggml_fp16_t *) dst_ptr = GGML_FP32_TO_FP16(*(const float *) src0_ptr);
5933
6211
 
5934
- if (++i10 == dst->ne[0]) {
6212
+ if (++i10 == ne0) {
5935
6213
  i10 = 0;
5936
- if (++i11 == dst->ne[1]) {
6214
+ if (++i11 == ne1) {
5937
6215
  i11 = 0;
5938
- if (++i12 == dst->ne[2]) {
6216
+ if (++i12 == ne2) {
5939
6217
  i12 = 0;
5940
- if (++i13 == dst->ne[3]) {
6218
+ if (++i13 == ne3) {
5941
6219
  i13 = 0;
5942
6220
  }
5943
6221
  }
@@ -5945,6 +6223,19 @@ static void ggml_compute_forward_dup_f32(
5945
6223
  }
5946
6224
  }
5947
6225
  }
6226
+ i10 += ne00 * (ne01 - ir1);
6227
+ while (i10 >= ne0) {
6228
+ i10 -= ne0;
6229
+ if (++i11 == ne1) {
6230
+ i11 = 0;
6231
+ if (++i12 == ne2) {
6232
+ i12 = 0;
6233
+ if (++i13 == ne3) {
6234
+ i13 = 0;
6235
+ }
6236
+ }
6237
+ }
6238
+ }
5948
6239
  }
5949
6240
  }
5950
6241
  } else {
@@ -6191,7 +6482,7 @@ static void ggml_compute_forward_add_q_f32(
6191
6482
  GGML_ASSERT(nb1 <= nb2);
6192
6483
  GGML_ASSERT(nb2 <= nb3);
6193
6484
 
6194
- GGML_ASSERT(src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1);
6485
+ GGML_ASSERT(ggml_is_quantized(src0->type));
6195
6486
  GGML_ASSERT(dst->type == src0->type);
6196
6487
  GGML_ASSERT(src1->type == GGML_TYPE_F32);
6197
6488
 
@@ -6205,7 +6496,7 @@ static void ggml_compute_forward_add_q_f32(
6205
6496
  const int ir0 = dr*ith;
6206
6497
  const int ir1 = MIN(ir0 + dr, nr);
6207
6498
 
6208
- float * wdata = (float*) params->wdata + ne00 * ith;
6499
+ float * wdata = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
6209
6500
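Padding each thread's scratch row by CACHE_LINE_SIZE_F32 floats (one cache line, i.e. 16 floats for the usual 64-byte line) keeps neighbouring threads' scratch regions on different cache lines and avoids false sharing. An illustrative layout with assumed sizes (ne00 = 4096, 64-byte cache lines):

    // stride per thread = 4096 + 16 floats
    // thread 0 writes wdata[0    .. 4095]
    // thread 1 writes wdata[4112 .. 8207]
    // -> no two threads ever touch floats that share a cache line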
 
6210
6501
  for (int ir = ir0; ir < ir1; ++ir) {
6211
6502
  // src0 indices
@@ -6261,6 +6552,8 @@ static void ggml_compute_forward_add(
6261
6552
  } break;
6262
6553
  case GGML_TYPE_Q4_0:
6263
6554
  case GGML_TYPE_Q4_1:
6555
+ case GGML_TYPE_Q4_2:
6556
+ case GGML_TYPE_Q4_3:
6264
6557
  {
6265
6558
  ggml_compute_forward_add_q_f32(params, src0, src1, dst);
6266
6559
  } break;
@@ -7161,7 +7454,7 @@ static void ggml_compute_forward_rms_norm(
7161
7454
 
7162
7455
  // ggml_compute_forward_mul_mat
7163
7456
 
7164
- #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
7457
+ #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
7165
7458
  // helper function to determine if it is better to use BLAS or not
7166
7459
  // for large matrices, BLAS is faster
7167
7460
  static bool ggml_compute_forward_mul_mat_use_blas(
@@ -7201,7 +7494,7 @@ static void ggml_compute_forward_mul_mat_f32(
7201
7494
  const int64_t ne02 = src0->ne[2];
7202
7495
  const int64_t ne03 = src0->ne[3];
7203
7496
 
7204
- #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
7497
+ #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
7205
7498
  const int64_t ne10 = src1->ne[0];
7206
7499
  #endif
7207
7500
  const int64_t ne11 = src1->ne[1];
@@ -7258,7 +7551,7 @@ static void ggml_compute_forward_mul_mat_f32(
7258
7551
  // nb01 >= nb00 - src0 is not transposed
7259
7552
  // compute by src0 rows
7260
7553
 
7261
- #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
7554
+ #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
7262
7555
  if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
7263
7556
  if (params->ith != 0) {
7264
7557
  return;
@@ -7272,6 +7565,21 @@ static void ggml_compute_forward_mul_mat_f32(
7272
7565
  return;
7273
7566
  }
7274
7567
 
7568
+ #if defined(GGML_USE_CUBLAS)
7569
+ float *d_X = NULL;
7570
+ float *d_Y = NULL;
7571
+ float *d_D = NULL;
7572
+ const float alpha = 1.0f;
7573
+ const float beta = 0.0f;
7574
+ const int x_ne = ne01 * ne10;
7575
+ const int y_ne = ne11 * ne10;
7576
+ const int d_ne = ne11 * ne01;
7577
+
7578
+ CUDA_CHECK(cudaMalloc((void **)(&d_X), sizeof(float) * x_ne));
7579
+ CUDA_CHECK(cudaMalloc((void **)(&d_Y), sizeof(float) * y_ne));
7580
+ CUDA_CHECK(cudaMalloc((void **)(&d_D), sizeof(float) * d_ne));
7581
+ #endif
7582
+
7275
7583
  for (int64_t i03 = 0; i03 < ne03; i03++) {
7276
7584
  for (int64_t i02 = 0; i02 < ne02; i02++) {
7277
7585
  const float * x = (float *) ((char *) src0->data + i02*nb02 + i03*nb03);
@@ -7279,15 +7587,37 @@ static void ggml_compute_forward_mul_mat_f32(
7279
7587
 
7280
7588
  float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
7281
7589
 
7590
+ #if defined(GGML_USE_CUBLAS)
7591
+ // copy data to device
7592
+ CUDA_CHECK(cudaMemcpyAsync(d_X, x, sizeof(float) * x_ne, cudaMemcpyHostToDevice, cudaStream));
7593
+ CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(float) * y_ne, cudaMemcpyHostToDevice, cudaStream));
7594
+
7595
+ // compute
7596
+ CUBLAS_CHECK(
7597
+ cublasSgemm(cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
7598
+ ne01, ne11, ne10,
7599
+ &alpha, d_X, ne00,
7600
+ d_Y, ne10,
7601
+ &beta, d_D, ne01));
7602
+
7603
+ // copy data to host
7604
+ CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
7605
+ #else
7282
7606
  // zT = y * xT
7283
7607
  cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
7284
7608
  ne11, ne01, ne10,
7285
7609
  1.0f, y, ne10,
7286
7610
  x, ne00,
7287
7611
  0.0f, d, ne01);
7612
+ #endif
7288
7613
  }
7289
7614
  }
7290
-
7615
+ #if defined(GGML_USE_CUBLAS)
7616
+ CUDA_CHECK(cudaStreamSynchronize(cudaStream));
7617
+ CUDA_CHECK(cudaFree(d_X));
7618
+ CUDA_CHECK(cudaFree(d_Y));
7619
+ CUDA_CHECK(cudaFree(d_D));
7620
+ #endif
7291
7621
  //printf("CBLAS F32 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);
7292
7622
 
7293
7623
  return;
@@ -7417,7 +7747,7 @@ static void ggml_compute_forward_mul_mat_f16_f32(
7417
7747
  // nb01 >= nb00 - src0 is not transposed
7418
7748
  // compute by src0 rows
7419
7749
 
7420
- #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
7750
+ #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
7421
7751
  if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
7422
7752
  GGML_ASSERT(nb10 == sizeof(float));
7423
7753
 
@@ -7433,10 +7763,37 @@ static void ggml_compute_forward_mul_mat_f16_f32(
7433
7763
  return;
7434
7764
  }
7435
7765
 
7436
- float * const wdata = params->wdata;
7766
+ #if defined(GGML_USE_CUBLAS)
7767
+ ggml_fp16_t * const wdata = params->wdata;
7437
7768
 
7769
+ float *d_X = NULL;
7770
+ float *d_Y = NULL;
7771
+ float *d_D = NULL;
7772
+ const float alpha = 1.0f;
7773
+ const float beta = 0.0f;
7774
+ const int x_ne = ne01 * ne10;
7775
+ const int y_ne = ne11 * ne10;
7776
+ const int d_ne = ne11 * ne01;
7777
+
7778
+ CUDA_CHECK(cudaMalloc((void **)(&d_X), sizeof(ggml_fp16_t) * x_ne));
7779
+ CUDA_CHECK(cudaMalloc((void **)(&d_Y), sizeof(float) * y_ne));
7780
+ CUDA_CHECK(cudaMalloc((void **)(&d_D), sizeof(float) * d_ne));
7781
+ #else
7782
+ float * const wdata = params->wdata;
7783
+ #endif
7438
7784
  for (int64_t i03 = 0; i03 < ne03; i03++) {
7439
7785
  for (int64_t i02 = 0; i02 < ne02; i02++) {
7786
+ #if defined(GGML_USE_CUBLAS)
7787
+ // with cuBLAS, instead of converting src0 to fp32, we convert src1 to fp16
7788
+ {
7789
+ size_t id = 0;
7790
+ for (int64_t i01 = 0; i01 < ne11; ++i01) {
7791
+ for (int64_t i00 = 0; i00 < ne10; ++i00) {
7792
+ wdata[id++] = GGML_FP32_TO_FP16(*(float *) ((char *) src1->data + i03*nb13 + i02*nb12 + i01*nb11 + i00*nb10));
7793
+ }
7794
+ }
7795
+ }
7796
+ #else
7440
7797
  {
7441
7798
  size_t id = 0;
7442
7799
  for (int64_t i01 = 0; i01 < ne01; ++i01) {
@@ -7445,7 +7802,31 @@ static void ggml_compute_forward_mul_mat_f16_f32(
7445
7802
  }
7446
7803
  }
7447
7804
  }
7805
+ #endif
7806
+
7807
+ #if defined(GGML_USE_CUBLAS)
7808
+ const ggml_fp16_t * x = (ggml_fp16_t *) ((char *) src0->data + i02*nb02 + i03*nb03);
7809
+ const ggml_fp16_t * y = (ggml_fp16_t *) wdata;
7810
+
7811
+ float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
7448
7812
 
7813
+ // copy data to device
7814
+ CUDA_CHECK(cudaMemcpyAsync(d_X, x, sizeof(ggml_fp16_t) * x_ne, cudaMemcpyHostToDevice, cudaStream));
7815
+ CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(ggml_fp16_t) * y_ne, cudaMemcpyHostToDevice, cudaStream));
7816
+
7817
+ // compute
7818
+ CUBLAS_CHECK(
7819
+ cublasGemmEx(cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
7820
+ ne01, ne11, ne10,
7821
+ &alpha, d_X, CUDA_R_16F, ne00,
7822
+ d_Y, CUDA_R_16F, ne10,
7823
+ &beta, d_D, CUDA_R_32F, ne01,
7824
+ CUBLAS_COMPUTE_32F,
7825
+ CUBLAS_GEMM_DEFAULT));
7826
+
7827
+ // copy data to host
7828
+ CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
7829
+ #else
7449
7830
  const float * x = wdata;
7450
7831
  const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
7451
7832
 
@@ -7457,9 +7838,16 @@ static void ggml_compute_forward_mul_mat_f16_f32(
7457
7838
  1.0f, y, ne10,
7458
7839
  x, ne00,
7459
7840
  0.0f, d, ne01);
7841
+ #endif
7460
7842
  }
7461
7843
  }
7462
7844
 
7845
+ #if defined(GGML_USE_CUBLAS)
7846
+ CUDA_CHECK(cudaStreamSynchronize(cudaStream));
7847
+ CUDA_CHECK(cudaFree(d_X));
7848
+ CUDA_CHECK(cudaFree(d_Y));
7849
+ CUDA_CHECK(cudaFree(d_D));
7850
+ #endif
7463
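In the fp16 branch, src0 stays in half precision on the device and src1 is converted to fp16 on the host instead, so the large weight matrix never has to be expanded to fp32. cublasGemmEx is then asked for fp16 inputs with fp32 accumulation; a sketch of the data flow per (i03, i02) slice (names as in the code above):

    // host:   wdata[ne10*ne11] = fp16(src1 slice)          (row-by-row conversion)
    // device: d_X <- src0 slice (fp16), d_Y <- wdata (fp16)
    // GemmEx: d_D (fp32) = X^T * Y with CUBLAS_COMPUTE_32F accumulation
    // host:   dst slice <- d_D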
7851
  /*printf("CBLAS F16 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);*/
7464
7852
 
7465
7853
  return;
@@ -7611,7 +7999,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
7611
7999
  // nb01 >= nb00 - src0 is not transposed
7612
8000
  // compute by src0 rows
7613
8001
 
7614
- #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
8002
+ #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
7615
8003
  if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
7616
8004
  if (params->ith != 0) {
7617
8005
  return;
@@ -7625,11 +8013,55 @@ static void ggml_compute_forward_mul_mat_q_f32(
7625
8013
  return;
7626
8014
  }
7627
8015
 
8016
+ #if defined(GGML_USE_CUBLAS)
8017
+ float *d_X = NULL;
8018
+ float *d_Y = NULL;
8019
+ float *d_D = NULL;
8020
+ float *d_Q = NULL;
8021
+ const float alpha = 1.0f;
8022
+ const float beta = 0.0f;
8023
+ const int x_ne = ne01 * ne10;
8024
+ const int y_ne = ne11 * ne10;
8025
+ const int d_ne = ne11 * ne01;
8026
+
8027
+ CUDA_CHECK(cudaMalloc((void **)(&d_X), sizeof(float) * x_ne));
8028
+ CUDA_CHECK(cudaMalloc((void **)(&d_Y), sizeof(float) * y_ne));
8029
+ CUDA_CHECK(cudaMalloc((void **)(&d_D), sizeof(float) * d_ne));
8030
+ CUDA_CHECK(cudaMalloc((void **)(&d_Q), GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type]));
8031
+
8032
+ void (*dequantize_row_q_cuda)(const void * x, float * y, int k, cudaStream_t stream) = NULL;
8033
+ if (type == GGML_TYPE_Q4_0) {
8034
+ dequantize_row_q_cuda = dequantize_row_q4_0_cuda;
8035
+ }
8036
+ else if (type == GGML_TYPE_Q4_1) {
8037
+ dequantize_row_q_cuda = dequantize_row_q4_1_cuda;
8038
+ }
8039
+ else if (type == GGML_TYPE_Q4_2) {
8040
+ dequantize_row_q_cuda = dequantize_row_q4_2_cuda;
8041
+ }
8042
+ else {
8043
+ GGML_ASSERT(false);
8044
+ }
8045
+ #else
7628
8046
  float * const wdata = params->wdata;
7629
8047
  dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q;
8048
+ #endif
7630
8049
 
7631
8050
  for (int64_t i03 = 0; i03 < ne03; i03++) {
7632
8051
  for (int64_t i02 = 0; i02 < ne02; i02++) {
8052
+ const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
8053
+
8054
+ float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
8055
+
8056
+ #if defined(GGML_USE_CUBLAS)
8057
+ // copy and dequantize on device
8058
+ CUDA_CHECK(
8059
+ cudaMemcpyAsync(d_Q, (char *) src0->data + i03*nb03 + i02*nb02,
8060
+ GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type], cudaMemcpyHostToDevice, cudaStream));
8061
+
8062
+ dequantize_row_q_cuda(d_Q, d_X, ne01 * ne00, cudaStream);
8063
+ CUDA_CHECK(cudaGetLastError());
8064
+ #else
7633
8065
  {
7634
8066
  size_t id = 0;
7635
8067
  for (int64_t i01 = 0; i01 < ne01; ++i01) {
@@ -7637,21 +8069,42 @@ static void ggml_compute_forward_mul_mat_q_f32(
7637
8069
  id += ne00;
7638
8070
  }
7639
8071
  }
7640
-
7641
8072
  const float * x = wdata;
7642
- const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
8073
+ #endif
7643
8074
 
7644
- float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
7645
8075
 
8076
+ #if defined(GGML_USE_CUBLAS)
8077
+ // copy data to device
8078
+ CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(float) * y_ne, cudaMemcpyHostToDevice, cudaStream));
8079
+
8080
+ // compute
8081
+ CUBLAS_CHECK(
8082
+ cublasSgemm(cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
8083
+ ne01, ne11, ne10,
8084
+ &alpha, d_X, ne00,
8085
+ d_Y, ne10,
8086
+ &beta, d_D, ne01));
8087
+
8088
+ // copy data to host
8089
+ CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
8090
+ #else
7646
8091
  // zT = y * xT
7647
8092
  cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
7648
8093
  ne11, ne01, ne10,
7649
8094
  1.0f, y, ne10,
7650
8095
  x, ne00,
7651
8096
  0.0f, d, ne01);
8097
+ #endif
7652
8098
  }
7653
8099
  }
7654
8100
 
8101
+ #if defined(GGML_USE_CUBLAS)
8102
+ CUDA_CHECK(cudaStreamSynchronize(cudaStream));
8103
+ CUDA_CHECK(cudaFree(d_X));
8104
+ CUDA_CHECK(cudaFree(d_Y));
8105
+ CUDA_CHECK(cudaFree(d_D));
8106
+ CUDA_CHECK(cudaFree(d_Q));
8107
+ #endif
7655
8108
  //printf("CBLAS = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);
7656
8109
 
7657
8110
  return;
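/*
 * Quantized path above: for each (i02, i03) slice the raw quantized blocks are
 * copied to d_Q, dequantized to fp32 on the GPU by the kernels declared in
 * ggml-cuda.h, and multiplied with the same "zT = y * xT" call shape as the CPU
 * cblas_sgemm fallback. All copies, the dequantize kernel and the GEMM are
 * enqueued on the single non-blocking cudaStream, so the one
 * cudaStreamSynchronize() after the loops is what guarantees dst is filled
 * before the device buffers are freed.
 *
 * For reference, the CPU fallback's dequantize_row_q does the equivalent of the
 * sketch below. The block layout here is an assumption for illustration only
 * (Q4_0-style: one float scale per 32 values, two 4-bit values packed per byte,
 * stored with a +8 offset); the authoritative layouts are defined earlier in
 * this file.
 */
typedef struct { float d; uint8_t qs[16]; } q4_0_block_sketch;   // 32 values per block (assumed layout)

static void dequantize_row_q4_0_sketch(const q4_0_block_sketch * x, float * y, int k) {
    for (int i = 0; i < k/32; i++) {
        const float d = x[i].d;                       // per-block scale
        for (int l = 0; l < 16; l++) {
            const uint8_t b = x[i].qs[l];
            y[i*32 + 2*l + 0] = ((b & 0xF) - 8)*d;    // low nibble  -> even element
            y[i*32 + 2*l + 1] = ((b >> 4)  - 8)*d;    // high nibble -> odd element
        }
    }
}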
@@ -7739,6 +8192,8 @@ static void ggml_compute_forward_mul_mat(
7739
8192
  switch (src0->type) {
7740
8193
  case GGML_TYPE_Q4_0:
7741
8194
  case GGML_TYPE_Q4_1:
8195
+ case GGML_TYPE_Q4_2:
8196
+ case GGML_TYPE_Q4_3:
7742
8197
  case GGML_TYPE_Q8_0:
7743
8198
  {
7744
8199
  ggml_compute_forward_mul_mat_q_f32(params, src0, src1, dst);
@@ -7756,34 +8211,6 @@ static void ggml_compute_forward_mul_mat(
7756
8211
  GGML_ASSERT(false);
7757
8212
  } break;
7758
8213
  }
7759
-
7760
- #if 0
7761
- if (src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_Q4_1) {
7762
- static int first = 8;
7763
- printf("src0: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", src0->ne[0], src0->ne[1], src0->ne[2]);
7764
- printf("src1: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", src1->ne[0], src1->ne[1], src1->ne[2]);
7765
- printf("dst: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", dst->ne[0], dst->ne[1], dst->ne[2]);
7766
- if (first) {
7767
- --first;
7768
- } else {
7769
- for (int k = 0; k < dst->ne[1]; ++k) {
7770
- for (int j = 0; j < dst->ne[0]/16; ++j) {
7771
- for (int i = 0; i < 16; ++i) {
7772
- printf("%8.4f ", ((float *) dst->data)[k*dst->ne[0] + j*16 + i]);
7773
- }
7774
- printf("\n");
7775
- }
7776
- printf("\n");
7777
- }
7778
- printf("\n");
7779
- exit(0);
7780
- }
7781
- } else {
7782
- printf("aaaa src0: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", src0->ne[0], src0->ne[1], src0->ne[2]);
7783
- printf("aaaa src1: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", src1->ne[0], src1->ne[1], src1->ne[2]);
7784
- printf("aaaa dst: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", dst->ne[0], dst->ne[1], dst->ne[2]);
7785
- }
7786
- #endif
7787
8214
  }
7788
8215
 
7789
8216
  // ggml_compute_forward_scale
@@ -7994,6 +8421,8 @@ static void ggml_compute_forward_get_rows(
7994
8421
  switch (src0->type) {
7995
8422
  case GGML_TYPE_Q4_0:
7996
8423
  case GGML_TYPE_Q4_1:
8424
+ case GGML_TYPE_Q4_2:
8425
+ case GGML_TYPE_Q4_3:
7997
8426
  case GGML_TYPE_Q8_0:
7998
8427
  {
7999
8428
  ggml_compute_forward_get_rows_q(params, src0, src1, dst);
@@ -8224,9 +8653,11 @@ static void ggml_compute_forward_rope_f32(
8224
8653
 
8225
8654
  const float theta_scale = powf(10000.0, -2.0f/n_dims);
8226
8655
 
8656
+ const bool is_neox = mode & 2;
8657
+
8227
8658
  for (int64_t i3 = 0; i3 < ne3; i3++) {
8228
- for (int64_t i2 = (mode == 0 ? 0 : n_past); i2 < ne2; i2++) {
8229
- const int p = (mode == 0 ? n_past + i2 : i2);
8659
+ for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) {
8660
+ const int p = ((mode & 1) == 0 ? n_past + i2 : i2);
8230
8661
  for (int64_t i1 = 0; i1 < ne1; i1++) {
8231
8662
  if (ir++ < ir0) continue;
8232
8663
  if (ir > ir1) break;
@@ -8239,14 +8670,25 @@ static void ggml_compute_forward_rope_f32(
8239
8670
 
8240
8671
  theta *= theta_scale;
8241
8672
 
8242
- const float * const src = (float *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
8243
- float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
8673
+ if (!is_neox) {
8674
+ const float * const src = (float *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
8675
+ float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
8244
8676
 
8245
- const float x0 = src[0];
8246
- const float x1 = src[1];
8677
+ const float x0 = src[0];
8678
+ const float x1 = src[1];
8247
8679
 
8248
- dst_data[0] = x0*cos_theta - x1*sin_theta;
8249
- dst_data[1] = x0*sin_theta + x1*cos_theta;
8680
+ dst_data[0] = x0*cos_theta - x1*sin_theta;
8681
+ dst_data[1] = x0*sin_theta + x1*cos_theta;
8682
+ } else {
8683
+ const float * const src = (float *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + (i0/2)*nb0);
8684
+ float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + (i0/2)*nb0);
8685
+
8686
+ const float x0 = src[0];
8687
+ const float x1 = src[n_dims/2];
8688
+
8689
+ dst_data[0] = x0*cos_theta - x1*sin_theta;
8690
+ dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
8691
+ }
8250
8692
  }
8251
8693
  }
8252
8694
  }
@@ -8301,9 +8743,11 @@ static void ggml_compute_forward_rope_f16(
8301
8743
 
8302
8744
  const float theta_scale = powf(10000.0, -2.0f/n_dims);
8303
8745
 
8746
+ const bool is_neox = mode & 2;
8747
+
8304
8748
  for (int64_t i3 = 0; i3 < ne3; i3++) {
8305
- for (int64_t i2 = (mode == 0 ? 0 : n_past); i2 < ne2; i2++) {
8306
- const int p = (mode == 0 ? n_past + i2 : i2);
8749
+ for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) {
8750
+ const int p = ((mode & 1) == 0 ? n_past + i2 : i2);
8307
8751
  for (int64_t i1 = 0; i1 < ne1; i1++) {
8308
8752
  if (ir++ < ir0) continue;
8309
8753
  if (ir > ir1) break;
@@ -8316,14 +8760,25 @@ static void ggml_compute_forward_rope_f16(
8316
8760
 
8317
8761
  theta *= theta_scale;
8318
8762
 
8319
- const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
8320
- ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
8763
+ if (!is_neox) {
8764
+ const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
8765
+ ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
8321
8766
 
8322
- const float x0 = GGML_FP16_TO_FP32(src[0]);
8323
- const float x1 = GGML_FP16_TO_FP32(src[1]);
8767
+ const float x0 = GGML_FP16_TO_FP32(src[0]);
8768
+ const float x1 = GGML_FP16_TO_FP32(src[1]);
8324
8769
 
8325
- dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
8326
- dst_data[1] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
8770
+ dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
8771
+ dst_data[1] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
8772
+ } else {
8773
+ const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + (i0/2)*nb0);
8774
+ ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + (i0/2)*nb0);
8775
+
8776
+ const float x0 = GGML_FP16_TO_FP32(src[0]);
8777
+ const float x1 = GGML_FP16_TO_FP32(src[n_dims/2]);
8778
+
8779
+ dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
8780
+ dst_data[n_dims/2] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
8781
+ }
8327
8782
  }
8328
8783
  }
8329
8784
  }
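/*
 * RoPE changes above: the rotation mode is now a bit field instead of 0/1.
 * Bit 0 keeps the old meaning (whether the position p counts from n_past or the
 * row loop starts at n_past), and bit 1 selects the GPT-NeoX rotation order:
 * instead of rotating adjacent pairs (i0, i0+1), it pairs element i0/2 with
 * element i0/2 + n_dims/2. The sketch below condenses the two branches for a
 * single row; it is an illustration only (names are made up, cosf/sinf come
 * from <math.h>, which ggml.c already includes).
 */
static void rope_row_sketch(float * x, int n_dims, float p, float theta_scale, int is_neox) {
    float theta = p;
    for (int i = 0; i < n_dims/2; ++i) {
        const float cos_theta = cosf(theta);
        const float sin_theta = sinf(theta);
        theta *= theta_scale;

        const int i0 = is_neox ? i            : 2*i;        // first element of the pair
        const int i1 = is_neox ? i + n_dims/2 : 2*i + 1;     // its rotation partner

        const float x0 = x[i0];
        const float x1 = x[i1];
        x[i0] = x0*cos_theta - x1*sin_theta;
        x[i1] = x0*sin_theta + x1*cos_theta;
    }
}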
@@ -10402,11 +10857,11 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
10402
10857
  case GGML_OP_CPY:
10403
10858
  case GGML_OP_DUP:
10404
10859
  {
10405
- node->n_tasks = 1;
10860
+ node->n_tasks = n_threads;
10406
10861
 
10407
10862
  size_t cur = 0;
10408
- if (node->type == GGML_TYPE_Q4_0 || node->type == GGML_TYPE_Q4_1) {
10409
- cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->ne[0];
10863
+ if (ggml_is_quantized(node->type)) {
10864
+ cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->ne[0] * n_threads;
10410
10865
  }
10411
10866
 
10412
10867
  work_size = MAX(work_size, cur);
@@ -10417,7 +10872,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
10417
10872
 
10418
10873
  size_t cur = 0;
10419
10874
 
10420
- if (node->src0->type == GGML_TYPE_Q4_0 || node->src0->type == GGML_TYPE_Q4_1) {
10875
+ if (ggml_is_quantized(node->src0->type)) {
10421
10876
  cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->src0->ne[0] * n_threads;
10422
10877
  }
10423
10878
 
@@ -10466,7 +10921,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
10466
10921
  size_t cur = 0;
10467
10922
 
10468
10923
  if (node->src0->type == GGML_TYPE_F16 && node->src1->type == GGML_TYPE_F32) {
10469
- #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
10924
+ #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
10470
10925
  if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
10471
10926
  node->n_tasks = 1; // TODO: this actually is doing nothing
10472
10927
  // the threads are still spinning
@@ -10482,8 +10937,8 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
10482
10937
  #endif
10483
10938
  } else if (node->src0->type == GGML_TYPE_F32 && node->src1->type == GGML_TYPE_F32) {
10484
10939
  cur = 0;
10485
- } else if (quantize_fns[node->src0->type].vec_dot_q && node->src1->type == GGML_TYPE_F32) {
10486
- #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
10940
+ } else if (ggml_is_quantized(node->src0->type) && node->src1->type == GGML_TYPE_F32) {
10941
+ #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
10487
10942
  if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
10488
10943
  node->n_tasks = 1;
10489
10944
  cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]);
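/*
 * Graph planning above: ggml_is_quantized() replaces the explicit Q4_0/Q4_1
 * checks so the new Q4_2/Q4_3 types get fp32 scratch space too, and CPY/DUP of
 * quantized tensors now runs with n_tasks = n_threads, which is why its scratch
 * is sized per thread (ne[0] * n_threads floats). The mul_mat BLAS fast paths
 * (and ggml_cpu_has_blas() below) are also taken when only GGML_USE_CUBLAS is
 * defined.
 */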
@@ -11709,6 +12164,86 @@ size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t *
11709
12164
  return (n/QK4_1*sizeof(block_q4_1));
11710
12165
  }
11711
12166
 
12167
+ size_t ggml_quantize_q4_2(const float * src, void * dst, int n, int k, int64_t * hist) {
12168
+ assert(k % QK4_2 == 0);
12169
+ const int nb = k / QK4_2;
12170
+
12171
+ for (int j = 0; j < n; j += k) {
12172
+ block_q4_2 * restrict y = (block_q4_2 *)dst + j/QK4_2;
12173
+
12174
+ //quantize_row_q4_2_reference(src + j, y, k);
12175
+ quantize_row_q4_2_rmse(src + j, y, k);
12176
+
12177
+ for (int i = 0; i < nb; i++) {
12178
+ for (int l = 0; l < QK4_2; l += 2) {
12179
+ const uint8_t vi0 = y[i].qs[l/2] & 0xF;
12180
+ const uint8_t vi1 = y[i].qs[l/2] >> 4;
12181
+
12182
+ hist[vi0]++;
12183
+ hist[vi1]++;
12184
+ }
12185
+ }
12186
+ }
12187
+
12188
+ return (n/QK4_2*sizeof(block_q4_2));
12189
+ }
12190
+
12191
+ size_t ggml_quantize_q4_3(const float * src, void * dst, int n, int k, int64_t * hist) {
12192
+ assert(k % QK4_3 == 0);
12193
+ const int nb = k / QK4_3;
12194
+
12195
+ for (int j = 0; j < n; j += k) {
12196
+ block_q4_3 * restrict y = (block_q4_3 *)dst + j/QK4_3;
12197
+
12198
+ quantize_row_q4_3_reference(src + j, y, k);
12199
+
12200
+ for (int i = 0; i < nb; i++) {
12201
+ for (int l = 0; l < QK4_3; l += 2) {
12202
+ const uint8_t vi0 = y[i].qs[l/2] & 0xF;
12203
+ const uint8_t vi1 = y[i].qs[l/2] >> 4;
12204
+
12205
+ hist[vi0]++;
12206
+ hist[vi1]++;
12207
+ }
12208
+ }
12209
+ }
12210
+
12211
+ return (n/QK4_3*sizeof(block_q4_3));
12212
+ }
12213
+
12214
+ size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, int start, int n, int64_t * hist) {
12215
+ size_t result = 0;
12216
+ switch (type) {
12217
+ case GGML_TYPE_Q4_0:
12218
+ {
12219
+ GGML_ASSERT(start % QK4_0 == 0);
12220
+ block_q4_0 * block = (block_q4_0*)dst + start / QK4_0;
12221
+ result = ggml_quantize_q4_0(src + start, block, n, n, hist);
12222
+ } break;
12223
+ case GGML_TYPE_Q4_1:
12224
+ {
12225
+ GGML_ASSERT(start % QK4_1 == 0);
12226
+ block_q4_1 * block = (block_q4_1*)dst + start / QK4_1;
12227
+ result = ggml_quantize_q4_1(src + start, block, n, n, hist);
12228
+ } break;
12229
+ case GGML_TYPE_Q4_2:
12230
+ {
12231
+ GGML_ASSERT(start % QK4_2 == 0);
12232
+ block_q4_2 * block = (block_q4_2*)dst + start / QK4_2;
12233
+ result = ggml_quantize_q4_2(src + start, block, n, n, hist);
12234
+ } break;
12235
+ case GGML_TYPE_Q4_3:
12236
+ {
12237
+ GGML_ASSERT(start % QK4_3 == 0);
12238
+ block_q4_3 * block = (block_q4_3*)dst + start / QK4_3;
12239
+ result = ggml_quantize_q4_3(src + start, block, n, n, hist);
12240
+ } break;
12241
+ default:
12242
+ assert(false);
12243
+ }
12244
+ return result;
12245
+ }
12246
+
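/*
 * ggml_quantize_chunk() above gives callers (e.g. a model quantization tool) a
 * single type-dispatched entry point: `start` and `n` are element offsets into
 * `src`, `start` must be block-aligned, and `hist` accumulates a histogram of
 * the emitted 4-bit values (16 bins, one per nibble value). A hypothetical
 * caller - names below are illustrative, not part of this diff - could
 * quantize a tensor in chunks like this:
 *
 *     int64_t hist[16] = {0};
 *     size_t  bytes    = 0;
 *     const int chunk  = nelements/nchunks;   // keep it a multiple of the block size
 *     for (int c = 0; c < nchunks; ++c) {
 *         bytes += ggml_quantize_chunk(GGML_TYPE_Q4_2, f32_data, quantized_data,
 *                                      c*chunk, chunk, hist);
 *     }
 *
 * The return value is the number of bytes written for that chunk, so `bytes`
 * ends up as the total size of the quantized buffer.
 */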
11712
12247
  ////////////////////////////////////////////////////////////////////////////////
11713
12248
 
11714
12249
  int ggml_cpu_has_avx(void) {
@@ -11800,7 +12335,15 @@ int ggml_cpu_has_wasm_simd(void) {
11800
12335
  }
11801
12336
 
11802
12337
  int ggml_cpu_has_blas(void) {
11803
- #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
12338
+ #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
12339
+ return 1;
12340
+ #else
12341
+ return 0;
12342
+ #endif
12343
+ }
12344
+
12345
+ int ggml_cpu_has_cublas(void) {
12346
+ #if defined(GGML_USE_CUBLAS)
11804
12347
  return 1;
11805
12348
  #else
11806
12349
  return 0;