llama_cpp 0.0.4 → 0.0.6

@@ -19,6 +19,7 @@
19
19
  #include <inttypes.h>
20
20
  #include <stdio.h>
21
21
  #include <float.h>
22
+ #include <limits.h>
22
23
 
23
24
  // if C99 - static_assert is noop
24
25
  // ref: https://stackoverflow.com/a/53923785/4039976
@@ -118,7 +119,16 @@ typedef void* thread_ret_t;
118
119
  #define GGML_ALIGNED_MALLOC(size) _aligned_malloc(size, GGML_MEM_ALIGN)
119
120
  #define GGML_ALIGNED_FREE(ptr) _aligned_free(ptr)
120
121
  #else
121
- #define GGML_ALIGNED_MALLOC(size) aligned_alloc(GGML_MEM_ALIGN, size)
122
+ inline static void* ggml_aligned_malloc(size_t size) {
123
+ void* aligned_memory = NULL;
124
+ int result = posix_memalign(&aligned_memory, GGML_MEM_ALIGN, size);
125
+ if (result != 0) {
126
+ // Handle allocation failure
127
+ return NULL;
128
+ }
129
+ return aligned_memory;
130
+ }
131
+ #define GGML_ALIGNED_MALLOC(size) ggml_aligned_malloc(size)
122
132
  #define GGML_ALIGNED_FREE(ptr) free(ptr)
123
133
  #endif
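
Note: the switch from aligned_alloc to a posix_memalign wrapper sidesteps aligned_alloc's requirement that the size be a multiple of the alignment and works on platforms where aligned_alloc is missing (e.g. older macOS); posix_memalign only requires the alignment to be a power of two and a multiple of sizeof(void *). A standalone sketch of the new wrapper in use, assuming GGML_MEM_ALIGN is 16 as in ggml's default:

    #define _POSIX_C_SOURCE 200112L   /* for posix_memalign */
    #include <stdint.h>
    #include <stdio.h>
    #include <stdlib.h>

    #define GGML_MEM_ALIGN 16         /* assumed value; power of two, multiple of sizeof(void *) */

    static inline void * ggml_aligned_malloc(size_t size) {
        void * aligned_memory = NULL;
        if (posix_memalign(&aligned_memory, GGML_MEM_ALIGN, size) != 0) {
            return NULL;              /* allocation (or bad alignment) failure */
        }
        return aligned_memory;
    }

    int main(void) {
        float * buf = ggml_aligned_malloc(1000 * sizeof(float));
        if (buf == NULL) { fprintf(stderr, "allocation failed\n"); return 1; }
        printf("16-byte aligned: %s\n", ((uintptr_t) buf % GGML_MEM_ALIGN) == 0 ? "yes" : "no");
        free(buf);                    /* memory from posix_memalign is released with free() */
        return 0;
    }
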
124
134
 
@@ -133,10 +143,49 @@ typedef void* thread_ret_t;
133
143
  } \
134
144
  } while (0)
135
145
 
136
- #ifdef GGML_USE_ACCELERATE
146
+ #if defined(GGML_USE_ACCELERATE)
137
147
  #include <Accelerate/Accelerate.h>
138
- #elif GGML_USE_OPENBLAS
148
+ #elif defined(GGML_USE_OPENBLAS)
139
149
  #include <cblas.h>
150
+ #elif defined(GGML_USE_CUBLAS)
151
+ #include <cublas_v2.h>
152
+ #include <cuda_runtime.h>
153
+ #include "ggml-cuda.h"
154
+
155
+ #define CUDA_CHECK(err) \
156
+ do { \
157
+ cudaError_t err_ = (err); \
158
+ if (err_ != cudaSuccess) { \
159
+ printf("CUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__, \
160
+ cudaGetErrorString(err_)); \
161
+ exit(1); \
162
+ } \
163
+ } while (0)
164
+
165
+ #define CUBLAS_CHECK(err) \
166
+ do { \
167
+ cublasStatus_t err_ = (err); \
168
+ if (err_ != CUBLAS_STATUS_SUCCESS) { \
169
+ printf("cuBLAS error %d at %s:%d\n", err_, __FILE__, __LINE__); \
170
+ exit(1); \
171
+ } \
172
+ } while (0)
173
+
174
+ static cublasHandle_t cublasH = NULL;
175
+ static cudaStream_t cudaStream = NULL;
176
+ static void init_cublas(void) {
177
+ if (cublasH == NULL) {
178
+ // create cublas handle, bind a stream
179
+ CUBLAS_CHECK(cublasCreate(&cublasH));
180
+
181
+ CUDA_CHECK(cudaStreamCreateWithFlags(&cudaStream, cudaStreamNonBlocking));
182
+
183
+ CUBLAS_CHECK(cublasSetStream(cublasH, cudaStream));
184
+
185
+ // configure logging to stdout
186
+ // CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, NULL));
187
+ }
188
+ }
140
189
  #endif
141
190
 
142
191
  #undef MIN
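
Note: the new CUDA_CHECK and CUBLAS_CHECK macros use the do { ... } while (0) idiom so that a multi-statement check expands to a single statement and composes safely with unbraced if/else. A CUDA-free sketch of the same pattern (CHECK and my_call are made-up names for illustration only):

    #include <stdio.h>
    #include <stdlib.h>

    /* hypothetical error-check macro in the same style as CUDA_CHECK above */
    #define CHECK(err)                                                   \
        do {                                                             \
            int err_ = (err);                                            \
            if (err_ != 0) {                                             \
                printf("error %d at %s:%d\n", err_, __FILE__, __LINE__); \
                exit(1);                                                 \
            }                                                            \
        } while (0)

    static int my_call(int x) { return x < 0 ? -1 : 0; }  /* made-up callee */

    int main(void) {
        if (1)
            CHECK(my_call(42));   /* expands to one statement, so this       */
        else                      /* if/else pairs up exactly as it reads    */
            puts("unreachable");
        return 0;
    }
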
@@ -418,14 +467,30 @@ static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
418
467
  // quantization
419
468
  //
420
469
 
421
- #define QK 32
470
+ #if __AVX__ || __AVX2__ || __AVX512F__
471
+ // Unpack 16 4-bit fields into 16 bytes
472
+ // The output vector contains 16 bytes, each one in [ 0 .. 15 ] interval
473
+ static inline __m128i bytes_from_nibbles_16(const uint8_t * rsi)
474
+ {
475
+ // Load 8 bytes from memory
476
+ __m128i tmp = _mm_loadu_si64( ( const __m128i* )rsi );
477
+
478
+ // Expand bytes into uint16_t values
479
+ __m128i bytes = _mm_cvtepu8_epi16( tmp );
480
+
481
+ // Unpack values into individual bytes
482
+ const __m128i lowMask = _mm_set1_epi8( 0xF );
483
+ __m128i high = _mm_andnot_si128( lowMask, bytes );
484
+ __m128i low = _mm_and_si128( lowMask, bytes );
485
+ high = _mm_slli_epi16( high, 4 );
486
+ bytes = _mm_or_si128( low, high );
487
+ return bytes;
488
+ }
422
489
 
423
- // AVX routines provided by GH user Const-me
424
- // ref: https://github.com/ggerganov/ggml/pull/27#issuecomment-1464934600
425
490
  #if __AVX2__ || __AVX512F__
426
491
  // Unpack 32 4-bit fields into 32 bytes
427
492
  // The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
428
- static inline __m256i bytesFromNibbles( const uint8_t* rsi )
493
+ static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)
429
494
  {
430
495
  // Load 16 bytes from memory
431
496
  __m128i tmp = _mm_loadu_si128( ( const __m128i* )rsi );
@@ -456,24 +521,7 @@ static inline __m128i packNibbles( __m256i bytes )
456
521
  __m128i r1 = _mm256_extracti128_si256( bytes, 1 );
457
522
  return _mm_packus_epi16( r0, r1 );
458
523
  }
459
- #elif __AVX__
460
- static inline __m128i bytesFromNibbles( const uint8_t* rsi )
461
- {
462
- // Load 8 bytes from memory
463
- __m128i tmp = _mm_loadu_si64( ( const __m128i* )rsi );
464
-
465
- // Expand bytes into uint16_t values
466
- __m128i bytes = _mm_cvtepu8_epi16( tmp );
467
-
468
- // Unpack values into individual bytes
469
- const __m128i lowMask = _mm_set1_epi8( 0xF );
470
- __m128i high = _mm_andnot_si128( lowMask, bytes );
471
- __m128i low = _mm_and_si128( lowMask, bytes );
472
- high = _mm_slli_epi16( high, 4 );
473
- bytes = _mm_or_si128( low, high );
474
- return bytes;
475
- }
476
-
524
+ #else
477
525
  static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
478
526
  {
479
527
  // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
@@ -490,6 +538,7 @@ static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
490
538
  return _mm_packus_epi16( bytes1, bytes2);
491
539
  }
492
540
  #endif
541
+ #endif // __AVX__ || __AVX2__ || __AVX512F__
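
Note: the renamed bytes_from_nibbles_16/_32 helpers expand packed 4-bit quants back to one byte per value: input byte i carries element 2*i in its low nibble and element 2*i+1 in its high nibble, matching the qs[l/2] = vi0 | (vi1 << 4) packing used by the quantize_row_q4_* functions below. A scalar sketch of the 16-wide variant (the scalar helper name is made up, not part of the diff):

    #include <stdint.h>
    #include <stdio.h>

    /* scalar equivalent of bytes_from_nibbles_16: 8 packed bytes -> 16 values in [0, 15] */
    static void bytes_from_nibbles_16_scalar(const uint8_t * in, uint8_t * out) {
        for (int i = 0; i < 8; ++i) {
            out[2*i + 0] = in[i] & 0x0F;   /* low nibble  -> even output element */
            out[2*i + 1] = in[i] >> 4;     /* high nibble -> odd output element  */
        }
    }

    int main(void) {
        const uint8_t packed[8] = { 0x21, 0x43, 0x65, 0x87, 0xA9, 0xCB, 0xED, 0x0F };
        uint8_t unpacked[16];
        bytes_from_nibbles_16_scalar(packed, unpacked);
        for (int i = 0; i < 16; ++i) printf("%d ", unpacked[i]);  /* 1 2 3 ... 15 0 */
        printf("\n");
        return 0;
    }
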
493
542
 
494
543
  #if __ARM_NEON
495
544
 
@@ -507,6 +556,18 @@ inline static uint16_t vaddvq_u8(uint8x16_t v) {
507
556
  (uint16_t)vgetq_lane_u8(v, 14) + (uint16_t)vgetq_lane_u8(v, 15);
508
557
  }
509
558
 
559
+ inline static int16_t vaddvq_s8(int8x16_t v) {
560
+ return
561
+ (int16_t)vgetq_lane_s8(v, 0) + (int16_t)vgetq_lane_s8(v, 1) +
562
+ (int16_t)vgetq_lane_s8(v, 2) + (int16_t)vgetq_lane_s8(v, 3) +
563
+ (int16_t)vgetq_lane_s8(v, 4) + (int16_t)vgetq_lane_s8(v, 5) +
564
+ (int16_t)vgetq_lane_s8(v, 6) + (int16_t)vgetq_lane_s8(v, 7) +
565
+ (int16_t)vgetq_lane_s8(v, 8) + (int16_t)vgetq_lane_s8(v, 9) +
566
+ (int16_t)vgetq_lane_s8(v, 10) + (int16_t)vgetq_lane_s8(v, 11) +
567
+ (int16_t)vgetq_lane_s8(v, 12) + (int16_t)vgetq_lane_s8(v, 13) +
568
+ (int16_t)vgetq_lane_s8(v, 14) + (int16_t)vgetq_lane_s8(v, 15);
569
+ }
570
+
510
571
  inline static int32_t vaddvq_s16(int16x8_t v) {
511
572
  return
512
573
  (int32_t)vgetq_lane_s16(v, 0) + (int32_t)vgetq_lane_s16(v, 1) +
@@ -531,68 +592,88 @@ inline static float vaddvq_f32(float32x4_t v) {
531
592
  return vgetq_lane_f32(v, 0) + vgetq_lane_f32(v, 1) + vgetq_lane_f32(v, 2) + vgetq_lane_f32(v, 3);
532
593
  }
533
594
 
534
- inline float vminvq_f32(float32x4_t v) {
595
+ float vminvq_f32(float32x4_t v) {
535
596
  return
536
597
  MIN(MIN(vgetq_lane_f32(v, 0), vgetq_lane_f32(v, 1)),
537
598
  MIN(vgetq_lane_f32(v, 2), vgetq_lane_f32(v, 3)));
538
599
  }
539
600
 
540
- inline float vmaxvq_f32(float32x4_t v) {
601
+ float vmaxvq_f32(float32x4_t v) {
541
602
  return
542
603
  MAX(MAX(vgetq_lane_f32(v, 0), vgetq_lane_f32(v, 1)),
543
604
  MAX(vgetq_lane_f32(v, 2), vgetq_lane_f32(v, 3)));
544
605
  }
545
606
 
546
- inline int8x8_t vzip1_s8(int8x8_t a, int8x8_t b) {
607
+ int8x8_t vzip1_s8(int8x8_t a, int8x8_t b) {
547
608
  return vget_low_s8(vcombine_s8(a, b));
548
609
  }
549
610
 
550
- inline int8x8_t vzip2_s8(int8x8_t a, int8x8_t b) {
611
+ int8x8_t vzip2_s8(int8x8_t a, int8x8_t b) {
551
612
  return vget_high_s8(vcombine_s8(a, b));
552
613
  }
553
614
 
554
- inline uint8x8_t vzip1_u8(uint8x8_t a, uint8x8_t b) {
615
+ uint8x8_t vzip1_u8(uint8x8_t a, uint8x8_t b) {
555
616
  return vget_low_u8(vcombine_u8(a, b));
556
617
  }
557
618
 
558
- inline uint8x8_t vzip2_u8(uint8x8_t a, uint8x8_t b) {
619
+ uint8x8_t vzip2_u8(uint8x8_t a, uint8x8_t b) {
559
620
  return vget_high_u8(vcombine_u8(a, b));
560
621
  }
561
622
 
562
623
  #endif
563
624
  #endif
564
625
 
565
- // method 5
566
- // blocks of QK elements
567
- // represented with a single float (delta) and QK/2 8-bit ints (i.e QK 4-bit signed integer factors)
626
+
627
+ #define QK4_0 32
568
628
  typedef struct {
569
- float d; // delta
570
- uint8_t qs[QK / 2]; // nibbles / quants
629
+ float d; // delta
630
+ uint8_t qs[QK4_0 / 2]; // nibbles / quants
571
631
  } block_q4_0;
572
- static_assert(sizeof(block_q4_0) == sizeof(float) + QK / 2, "wrong q4_0 block size/padding");
632
+ static_assert(sizeof(block_q4_0) == sizeof(float) + QK4_0 / 2, "wrong q4_0 block size/padding");
573
633
 
574
- // method 4
575
- // blocks of QK elements
576
- // represented with 2 floats (delta + min) and QK/2 8-bit ints (i.e QK 4-bit unsigned integer factors)
634
+ #define QK4_1 32
577
635
  typedef struct {
578
- float d;
579
- float m;
580
- uint8_t qs[QK / 2]; // nibbles / quants
636
+ float d; // delta
637
+ float m; // min
638
+ uint8_t qs[QK4_1 / 2]; // nibbles / quants
581
639
  } block_q4_1;
582
- static_assert(sizeof(block_q4_1) == sizeof(float) * 2 + QK / 2, "wrong q4_1 block size/padding");
640
+ static_assert(sizeof(block_q4_1) == 2 * sizeof(float) + QK4_1 / 2, "wrong q4_1 block size/padding");
641
+
642
+ #define QK4_2 16
643
+ typedef struct {
644
+ ggml_fp16_t d; // delta
645
+ uint8_t qs[QK4_2 / 2]; // nibbles / quants
646
+ } block_q4_2;
647
+ static_assert(sizeof(block_q4_2) == sizeof(ggml_fp16_t) + QK4_2 / 2, "wrong q4_2 block size/padding");
648
+
649
+ #define QK4_3 16
650
+ typedef struct {
651
+ ggml_fp16_t d; // delta
652
+ ggml_fp16_t m; // min
653
+ uint8_t qs[QK4_3 / 2]; // nibbles / quants
654
+ } block_q4_3;
655
+ static_assert(sizeof(block_q4_3) == 2 * sizeof(ggml_fp16_t) + QK4_3 / 2, "wrong q4_3 block size/padding");
656
+
657
+ #define QK8_0 32
658
+ typedef struct {
659
+ float d; // delta
660
+ int8_t qs[QK8_0]; // quants
661
+ } block_q8_0;
662
+ static_assert(sizeof(block_q8_0) == sizeof(float) + QK8_0, "wrong q8_0 block size/padding");
663
+
583
664
 
584
665
  // reference implementation for deterministic creation of model files
585
666
  static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k) {
586
- assert(k % QK == 0);
587
- const int nb = k / QK;
667
+ assert(k % QK4_0 == 0);
668
+ const int nb = k / QK4_0;
588
669
 
589
- uint8_t pp[QK/2];
670
+ uint8_t pp[QK4_0/2];
590
671
 
591
672
  for (int i = 0; i < nb; i++) {
592
673
  float amax = 0.0f; // absolute max
593
674
 
594
- for (int l = 0; l < QK; l++) {
595
- const float v = x[i*QK + l];
675
+ for (int l = 0; l < QK4_0; l++) {
676
+ const float v = x[i*QK4_0 + l];
596
677
  amax = MAX(amax, fabsf(v));
597
678
  }
598
679
 
@@ -601,9 +682,9 @@ static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * r
601
682
 
602
683
  y[i].d = d;
603
684
 
604
- for (int l = 0; l < QK; l += 2) {
605
- const float v0 = x[i*QK + l + 0]*id;
606
- const float v1 = x[i*QK + l + 1]*id;
685
+ for (int l = 0; l < QK4_0; l += 2) {
686
+ const float v0 = x[i*QK4_0 + l + 0]*id;
687
+ const float v1 = x[i*QK4_0 + l + 1]*id;
607
688
 
608
689
  const uint8_t vi0 = (int8_t)roundf(v0) + 8;
609
690
  const uint8_t vi1 = (int8_t)roundf(v1) + 8;
@@ -619,8 +700,8 @@ static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * r
619
700
  }
620
701
 
621
702
  static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int k) {
622
- assert(k % QK == 0);
623
- const int nb = k / QK;
703
+ assert(k % QK4_0 == 0);
704
+ const int nb = k / QK4_0;
624
705
 
625
706
  block_q4_0 * restrict y = vy;
626
707
 
@@ -870,19 +951,19 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
870
951
  }
871
952
 
872
953
  static void quantize_row_q4_1_reference(const float * restrict x, void * restrict vy, int k) {
873
- assert(k % QK == 0);
874
- const int nb = k / QK;
954
+ assert(k % QK4_1 == 0);
955
+ const int nb = k / QK4_1;
875
956
 
876
957
  block_q4_1 * restrict y = vy;
877
958
 
878
- uint8_t pp[QK/2];
959
+ uint8_t pp[QK4_1/2];
879
960
 
880
961
  for (int i = 0; i < nb; i++) {
881
962
  float min = FLT_MAX;
882
963
  float max = -FLT_MAX;
883
964
 
884
- for (int l = 0; l < QK; l++) {
885
- const float v = x[i*QK + l];
965
+ for (int l = 0; l < QK4_1; l++) {
966
+ const float v = x[i*QK4_1 + l];
886
967
  if (v < min) min = v;
887
968
  if (v > max) max = v;
888
969
  }
@@ -893,9 +974,9 @@ static void quantize_row_q4_1_reference(const float * restrict x, void * restric
893
974
  y[i].d = d;
894
975
  y[i].m = min;
895
976
 
896
- for (int l = 0; l < QK; l += 2) {
897
- const float v0 = (x[i*QK + l + 0] - min)*id;
898
- const float v1 = (x[i*QK + l + 1] - min)*id;
977
+ for (int l = 0; l < QK4_1; l += 2) {
978
+ const float v0 = (x[i*QK4_1 + l + 0] - min)*id;
979
+ const float v1 = (x[i*QK4_1 + l + 1] - min)*id;
899
980
 
900
981
  const uint8_t vi0 = roundf(v0);
901
982
  const uint8_t vi1 = roundf(v1);
@@ -911,9 +992,9 @@ static void quantize_row_q4_1_reference(const float * restrict x, void * restric
911
992
  }
912
993
 
913
994
  static void quantize_row_q4_1(const float * restrict x, void * restrict vy, int k) {
914
- assert(k % QK == 0);
995
+ assert(k % QK4_1 == 0);
915
996
 
916
- const int nb = k / QK;
997
+ const int nb = k / QK4_1;
917
998
 
918
999
  block_q4_1 * restrict y = vy;
919
1000
 
@@ -997,7 +1078,7 @@ static void quantize_row_q4_1(const float * restrict x, void * restrict vy, int
997
1078
  float32x4_t minv[8];
998
1079
  float32x4_t maxv[8];
999
1080
 
1000
- for (int l = 0; l < 8; l++) srcv[l] = vld1q_f32(x + i*QK + 4*l);
1081
+ for (int l = 0; l < 8; l++) srcv[l] = vld1q_f32(x + i*QK4_1 + 4*l);
1001
1082
 
1002
1083
  for (int l = 0; l < 4; l++) minv[2*l] = vminq_f32(srcv[2*l], srcv[2*l + 1]);
1003
1084
  for (int l = 0; l < 2; l++) minv[4*l] = vminq_f32(minv[4*l], minv[4*l + 2]);
@@ -1033,9 +1114,327 @@ static void quantize_row_q4_1(const float * restrict x, void * restrict vy, int
1033
1114
  #endif
1034
1115
  }
1035
1116
 
1117
+ // reference implementation for deterministic creation of model files
1118
+ static void quantize_row_q4_2_reference(const float * restrict x, block_q4_2 * restrict y, int k) {
1119
+ assert(k % QK4_2 == 0);
1120
+
1121
+ const int nb = k / QK4_2;
1122
+
1123
+ for (int i = 0; i < nb; i++) {
1124
+ float amax = 0.0f; // absolute max
1125
+
1126
+ for (int l = 0; l < QK4_2; l++) {
1127
+ const float v = x[i*QK4_2 + l];
1128
+ amax = MAX(amax, fabsf(v));
1129
+ }
1130
+
1131
+ const float d = amax / ((1 << 3) - 1);
1132
+
1133
+ const float id = d ? 1.0f/d : 0.0f;
1134
+
1135
+ y[i].d = GGML_FP32_TO_FP16(d);
1136
+
1137
+ for (int l = 0; l < QK4_2; l += 2) {
1138
+ const float v0 = x[i*QK4_2 + l + 0]*id;
1139
+ const float v1 = x[i*QK4_2 + l + 1]*id;
1140
+
1141
+ const uint8_t vi0 = (uint8_t)(v0 + 8.5f);
1142
+ const uint8_t vi1 = (uint8_t)(v1 + 8.5f);
1143
+
1144
+ assert(vi0 < 16);
1145
+ assert(vi1 < 16);
1146
+
1147
+ y[i].qs[l/2] = vi0 | (vi1 << 4);
1148
+ }
1149
+ }
1150
+ }
1151
+
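
Note, with illustrative numbers: for a 16-value q4_2 block whose largest magnitude is 1.75, d = 1.75 / 7 = 0.25 and id = 4. A value of 0.6 maps to 0.6*4 + 8.5 = 10.9, which truncates to the nibble 10; dequantizing later gives (10 - 8) * 0.25 = 0.5, an error of 0.1, within the half-step d/2 = 0.125.
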
1152
+ static inline int nearest_int(float fval) {
1153
+ assert(fval <= 4194303.f);
1154
+ float val = fval + 12582912.f;
1155
+ int i; memcpy(&i, &val, sizeof(int));
1156
+ return (i & 0x007fffff) - 0x00400000;
1157
+ }
1158
+
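
Note: nearest_int uses the classic float trick. Adding 12582912.0f (= 1.5 * 2^23) pushes the value into a range where the addition itself rounds away the fractional part and the rounded integer lands in the low 23 mantissa bits, so masking with 0x007fffff and subtracting the 0x00400000 bias recovers it. A small self-contained check against libm rounding, assuming the default round-to-nearest-even FP mode:

    #include <assert.h>
    #include <math.h>
    #include <stdio.h>
    #include <string.h>

    static inline int nearest_int(float fval) {   /* copied from the diff above */
        assert(fval <= 4194303.f);
        float val = fval + 12582912.f;            /* 1.5 * 2^23 */
        int i; memcpy(&i, &val, sizeof(int));
        return (i & 0x007fffff) - 0x00400000;
    }

    int main(void) {
        for (float x = -1000.0f; x <= 1000.0f; x += 0.37f) {
            assert(nearest_int(x) == (int) lrintf(x));  /* both round to nearest */
        }
        printf("ok\n");
        return 0;
    }
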
1159
+ static float kquantize_q4_with_bounds(int n, int nmin, int nmax, const float * restrict X, int nCandidates,
1160
+ const float * restrict candidates, int8_t * restrict L) {
1161
+ assert (nmin >= INT8_MIN);
1162
+ assert (nmax <= INT8_MAX);
1163
+ float amax = 0;
1164
+ for (int i=0; i<n; ++i) amax = MAX(amax, fabsf(X[i]));
1165
+ if (!amax) { // all zero
1166
+ for (int i=0; i<n; ++i) L[i] = 0;
1167
+ return 1.f;
1168
+ }
1169
+ float best = 0, bestScale = 0;
1170
+ for (int si=0; si<nCandidates; ++si) {
1171
+ float iscale = candidates[si]/amax;
1172
+ float sumlxP = 0; int suml2P = 0;
1173
+ float sumlxM = 0; int suml2M = 0;
1174
+ for (int i=0; i<n; ++i) {
1175
+ int l = nearest_int(iscale*X[i]);
1176
+ int lp = MAX(nmin, MIN(nmax, +l));
1177
+ int lm = MAX(nmin, MIN(nmax, -l));
1178
+ sumlxP += X[i]*lp; suml2P += lp*lp;
1179
+ sumlxM += X[i]*lm; suml2M += lm*lm;
1180
+ }
1181
+ float sumlxP2 = sumlxP*sumlxP;
1182
+ float sumlxM2 = sumlxM*sumlxM;
1183
+ if (sumlxP2*suml2M > sumlxM2*suml2P) {
1184
+ if (sumlxP2 > best*suml2P) {
1185
+ best = sumlxP2/suml2P; bestScale = iscale;
1186
+ }
1187
+ } else {
1188
+ if (sumlxM2 > best*suml2M) {
1189
+ best = sumlxM2/suml2M; bestScale = -iscale;
1190
+ }
1191
+ }
1192
+ }
1193
+ float sumlx = 0; int suml2 = 0;
1194
+ for (int i=0; i<n; ++i) {
1195
+ int l = nearest_int(bestScale*X[i]);
1196
+ l = MAX(nmin, MIN(nmax, l));
1197
+ sumlx += X[i]*l; suml2 += l*l;
1198
+ L[i] = l;
1199
+ }
1200
+ float scale = sumlx/suml2;
1201
+ return scale;
1202
+ }
1203
+
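
Note: the selection rule in kquantize_q4_with_bounds is a least-squares argument. For fixed integer levels l_i, the scale s that minimizes sum_i (x_i - s*l_i)^2 is s = (sum_i x_i*l_i) / (sum_i l_i^2), with residual sum_i x_i^2 - (sum_i x_i*l_i)^2 / (sum_i l_i^2). Since sum_i x_i^2 does not depend on the candidate scale, maximizing (sum x*l)^2 / (sum l^2) -- the quantity kept in best -- minimizes the RMSE, and the P/M branches try each level assignment and its mirror image so the sign of the scale is chosen at the same time.
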
1204
+ static void quantize_row_q4_2_rmse(const float * restrict x, block_q4_2 * restrict y, int k) {
1205
+ #define CANDIDATE_COUNT 8
1206
+ static const float candidates[CANDIDATE_COUNT] = { +8.7f, +8.3f, +8.1f, +7.8f, +7.3f, +7.0f, +6.3f, +5.7f };
1207
+ assert(k % QK4_2 == 0);
1208
+
1209
+ int8_t L[QK4_2];
1210
+
1211
+ const int nb = k / QK4_2;
1212
+
1213
+ for (int i = 0; i < nb; i++) {
1214
+ float scale = kquantize_q4_with_bounds(QK4_2, -8, 7, x, CANDIDATE_COUNT, candidates, L);
1215
+ y[i].d = GGML_FP32_TO_FP16(scale);
1216
+
1217
+ for (int l = 0; l < QK4_2; l += 2) {
1218
+ const uint8_t vi0 = (uint8_t)(L[l+0] + 8);
1219
+ const uint8_t vi1 = (uint8_t)(L[l+1] + 8);
1220
+
1221
+ assert(vi0 < 16);
1222
+ assert(vi1 < 16);
1223
+
1224
+ y[i].qs[l/2] = vi0 | (vi1 << 4);
1225
+ }
1226
+
1227
+ x += QK4_2;
1228
+ }
1229
+ }
1230
+
1231
+ static void quantize_row_q4_2(const float * restrict x, void * restrict vy, int k) {
1232
+ assert(k % QK4_2 == 0);
1233
+
1234
+ block_q4_2 * restrict y = vy;
1235
+
1236
+ //quantize_row_q4_2_reference(x, y, k);
1237
+ // This produces the exact same format, just better match to the input floats ("better" as measured by RMSE)
1238
+ quantize_row_q4_2_rmse(x, y, k);
1239
+ }
1240
+
1241
+ static void quantize_row_q4_3_reference(const float * restrict x, block_q4_3 * restrict y, int k) {
1242
+ assert(k % QK4_3 == 0);
1243
+ const int nb = k / QK4_3;
1244
+
1245
+ for (int i = 0; i < nb; i++) {
1246
+ float min = FLT_MAX;
1247
+ float max = -FLT_MAX;
1248
+
1249
+ for (int l = 0; l < QK4_3; l++) {
1250
+ const float v = x[i*QK4_3 + l];
1251
+ if (v < min) min = v;
1252
+ if (v > max) max = v;
1253
+ }
1254
+
1255
+ const float d = (max - min) / ((1 << 4) - 1);
1256
+ const float id = d ? 1.0f/d : 0.0f;
1257
+
1258
+ y[i].d = GGML_FP32_TO_FP16(d);
1259
+ y[i].m = GGML_FP32_TO_FP16(min);
1260
+
1261
+ for (int l = 0; l < QK4_3; l += 2) {
1262
+ const float v0 = (x[i*QK4_3 + l + 0] - min)*id;
1263
+ const float v1 = (x[i*QK4_3 + l + 1] - min)*id;
1264
+
1265
+ const uint8_t vi0 = (int) (v0 + 0.5f);
1266
+ const uint8_t vi1 = (int) (v1 + 0.5f);
1267
+
1268
+ assert(vi0 < 16);
1269
+ assert(vi1 < 16);
1270
+
1271
+ y[i].qs[l/2] = vi0 | (vi1 << 4);
1272
+ }
1273
+ }
1274
+ }
1275
+
1276
+ static void quantize_row_q4_3(const float * restrict x, void * restrict vy, int k) {
1277
+ assert(k % QK4_3 == 0);
1278
+
1279
+ block_q4_3 * restrict y = vy;
1280
+
1281
+ quantize_row_q4_3_reference(x, y, k);
1282
+ }
1283
+
1284
+ // reference implementation for deterministic creation of model files
1285
+ static void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * restrict y, int k) {
1286
+ assert(k % QK8_0 == 0);
1287
+ const int nb = k / QK8_0;
1288
+
1289
+ for (int i = 0; i < nb; i++) {
1290
+ float amax = 0.0f; // absolute max
1291
+
1292
+ for (int l = 0; l < QK8_0; l++) {
1293
+ const float v = x[i*QK8_0 + l];
1294
+ amax = MAX(amax, fabsf(v));
1295
+ }
1296
+
1297
+ const float d = amax / ((1 << 7) - 1);
1298
+ const float id = d ? 1.0f/d : 0.0f;
1299
+
1300
+ y[i].d = d;
1301
+
1302
+ for (int l = 0; l < QK8_0; ++l) {
1303
+ const float v = x[i*QK8_0 + l]*id;
1304
+ y[i].qs[l] = roundf(v);
1305
+ }
1306
+ }
1307
+ }
1308
+
1309
+ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int k) {
1310
+ assert(k % QK8_0 == 0);
1311
+ const int nb = k / QK8_0;
1312
+
1313
+ block_q8_0 * restrict y = vy;
1314
+
1315
+ #if defined(__ARM_NEON)
1316
+ for (int i = 0; i < nb; i++) {
1317
+ float32x4_t srcv [8];
1318
+ float32x4_t asrcv[8];
1319
+ float32x4_t amaxv[8];
1320
+
1321
+ for (int l = 0; l < 8; l++) srcv[l] = vld1q_f32(x + i*32 + 4*l);
1322
+ for (int l = 0; l < 8; l++) asrcv[l] = vabsq_f32(srcv[l]);
1323
+
1324
+ for (int l = 0; l < 4; l++) amaxv[2*l] = vmaxq_f32(asrcv[2*l], asrcv[2*l+1]);
1325
+ for (int l = 0; l < 2; l++) amaxv[4*l] = vmaxq_f32(amaxv[4*l], amaxv[4*l+2]);
1326
+ for (int l = 0; l < 1; l++) amaxv[8*l] = vmaxq_f32(amaxv[8*l], amaxv[8*l+4]);
1327
+
1328
+ const float amax = vmaxvq_f32(amaxv[0]);
1329
+
1330
+ const float d = amax / ((1 << 7) - 1);
1331
+ const float id = d ? 1.0f/d : 0.0f;
1332
+
1333
+ y[i].d = d;
1334
+
1335
+ for (int l = 0; l < 8; l++) {
1336
+ const float32x4_t v = vmulq_n_f32(srcv[l], id);
1337
+ const int32x4_t vi = vcvtnq_s32_f32(v);
1338
+
1339
+ y[i].qs[4*l + 0] = vgetq_lane_s32(vi, 0);
1340
+ y[i].qs[4*l + 1] = vgetq_lane_s32(vi, 1);
1341
+ y[i].qs[4*l + 2] = vgetq_lane_s32(vi, 2);
1342
+ y[i].qs[4*l + 3] = vgetq_lane_s32(vi, 3);
1343
+ }
1344
+ }
1345
+ #elif defined(__AVX2__) || defined(__AVX__)
1346
+ for (int i = 0; i < nb; i++) {
1347
+ // Load elements into 4 AVX vectors
1348
+ __m256 v0 = _mm256_loadu_ps( x );
1349
+ __m256 v1 = _mm256_loadu_ps( x + 8 );
1350
+ __m256 v2 = _mm256_loadu_ps( x + 16 );
1351
+ __m256 v3 = _mm256_loadu_ps( x + 24 );
1352
+ x += 32;
1353
+
1354
+ // Compute max(abs(e)) for the block
1355
+ const __m256 signBit = _mm256_set1_ps( -0.0f );
1356
+ __m256 maxAbs = _mm256_andnot_ps( signBit, v0 );
1357
+ maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) );
1358
+ maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );
1359
+ maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) );
1360
+
1361
+ __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
1362
+ max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
1363
+ max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
1364
+ const float maxScalar = _mm_cvtss_f32( max4 );
1365
+
1366
+ // Quantize these floats
1367
+ const float d = maxScalar / 127.f;
1368
+ y[i].d = d;
1369
+ const float id = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f;
1370
+ const __m256 mul = _mm256_set1_ps( id );
1371
+
1372
+ // Apply the multiplier
1373
+ v0 = _mm256_mul_ps( v0, mul );
1374
+ v1 = _mm256_mul_ps( v1, mul );
1375
+ v2 = _mm256_mul_ps( v2, mul );
1376
+ v3 = _mm256_mul_ps( v3, mul );
1377
+
1378
+ // Round to nearest integer
1379
+ v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST );
1380
+ v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST );
1381
+ v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST );
1382
+ v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST );
1383
+
1384
+ // Convert floats to integers
1385
+ __m256i i0 = _mm256_cvtps_epi32( v0 );
1386
+ __m256i i1 = _mm256_cvtps_epi32( v1 );
1387
+ __m256i i2 = _mm256_cvtps_epi32( v2 );
1388
+ __m256i i3 = _mm256_cvtps_epi32( v3 );
1389
+
1390
+ #if defined(__AVX2__)
1391
+ // Convert int32 to int16
1392
+ i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15
1393
+ i2 = _mm256_packs_epi32( i2, i3 ); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31
1394
+ // Convert int16 to int8
1395
+ i0 = _mm256_packs_epi16( i0, i2 ); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31
1396
+
1397
+ // We got our precious signed bytes, but the order is now wrong
1398
+ // These AVX2 pack instructions process 16-byte pieces independently
1399
+ // The following instruction is fixing the order
1400
+ const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
1401
+ i0 = _mm256_permutevar8x32_epi32( i0, perm );
1402
+
1403
+ _mm256_storeu_si256((__m256i *)y[i].qs, i0);
1404
+ #else
1405
+ // Since we don't have in AVX some necessary functions,
1406
+ // we split the registers in half and call AVX2 analogs from SSE
1407
+ __m128i ni0 = _mm256_castsi256_si128( i0 );
1408
+ __m128i ni1 = _mm256_extractf128_si256( i0, 1);
1409
+ __m128i ni2 = _mm256_castsi256_si128( i1 );
1410
+ __m128i ni3 = _mm256_extractf128_si256( i1, 1);
1411
+ __m128i ni4 = _mm256_castsi256_si128( i2 );
1412
+ __m128i ni5 = _mm256_extractf128_si256( i2, 1);
1413
+ __m128i ni6 = _mm256_castsi256_si128( i3 );
1414
+ __m128i ni7 = _mm256_extractf128_si256( i3, 1);
1415
+
1416
+ // Convert int32 to int16
1417
+ ni0 = _mm_packs_epi32( ni0, ni1 );
1418
+ ni2 = _mm_packs_epi32( ni2, ni3 );
1419
+ ni4 = _mm_packs_epi32( ni4, ni5 );
1420
+ ni6 = _mm_packs_epi32( ni6, ni7 );
1421
+ // Convert int16 to int8
1422
+ ni0 = _mm_packs_epi16( ni0, ni2 );
1423
+ ni4 = _mm_packs_epi16( ni4, ni6 );
1424
+
1425
+ _mm_storeu_si128((__m128i *)(y[i].qs + 0), ni0);
1426
+ _mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4);
1427
+ #endif
1428
+ }
1429
+ #else
1430
+ // scalar
1431
+ quantize_row_q8_0_reference(x, y, k);
1432
+ #endif
1433
+ }
1434
+
1036
1435
  static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, int k) {
1037
- assert(k % QK == 0);
1038
- const int nb = k / QK;
1436
+ assert(k % QK4_0 == 0);
1437
+ const int nb = k / QK4_0;
1039
1438
 
1040
1439
  const block_q4_0 * restrict x = vx;
1041
1440
 
@@ -1046,9 +1445,9 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in
1046
1445
 
1047
1446
  const uint8_t * restrict pp = x[i].qs;
1048
1447
 
1049
- for (int l = 0; l < QK; l += 32) {
1448
+ for (int l = 0; l < QK4_0; l += 32) {
1050
1449
  // Load 32x4-bit integers into 32x8-bit integers
1051
- __m256i vx8 = bytesFromNibbles(pp+l/2);
1450
+ __m256i vx8 = bytes_from_nibbles_32(pp+l/2);
1052
1451
 
1053
1452
  // Subtract 8 from the integers
1054
1453
  vx8 = _mm256_sub_epi8(vx8, _mm256_set1_epi8(8));
@@ -1068,7 +1467,7 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in
1068
1467
  // Scale and store
1069
1468
  for (int j = 0; j < 4; j++) {
1070
1469
  const __m256 result = _mm256_mul_ps(vf[j], d_v);
1071
- _mm256_storeu_ps(y + i * QK + l + j*8, result);
1470
+ _mm256_storeu_ps(y + i * QK4_0 + l + j*8, result);
1072
1471
  }
1073
1472
  }
1074
1473
  }
@@ -1078,7 +1477,7 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in
1078
1477
 
1079
1478
  const uint8_t * restrict pp = x[i].qs;
1080
1479
 
1081
- for (int l = 0; l < QK; l += 16) {
1480
+ for (int l = 0; l < QK4_0; l += 16) {
1082
1481
  // Load 16x4-bit integers into 8x8-bit integers
1083
1482
  const uint8x8_t v8 = vld1_u8(pp + l/2);
1084
1483
 
@@ -1117,10 +1516,10 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in
1117
1516
  const float32x4_t r3 = vmulq_f32(vf_3, vd);
1118
1517
 
1119
1518
  // Store
1120
- vst1q_f32(y + i*QK + l + 0, r0);
1121
- vst1q_f32(y + i*QK + l + 4, r1);
1122
- vst1q_f32(y + i*QK + l + 8, r2);
1123
- vst1q_f32(y + i*QK + l + 12, r3);
1519
+ vst1q_f32(y + i*QK4_0 + l + 0, r0);
1520
+ vst1q_f32(y + i*QK4_0 + l + 4, r1);
1521
+ vst1q_f32(y + i*QK4_0 + l + 8, r2);
1522
+ vst1q_f32(y + i*QK4_0 + l + 12, r3);
1124
1523
  }
1125
1524
  }
1126
1525
  #else
@@ -1130,7 +1529,7 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in
1130
1529
 
1131
1530
  const uint8_t * restrict pp = x[i].qs;
1132
1531
 
1133
- for (int l = 0; l < QK; l += 2) {
1532
+ for (int l = 0; l < QK4_0; l += 2) {
1134
1533
  const uint8_t vi = pp[l/2];
1135
1534
 
1136
1535
  const int8_t vi0 = vi & 0xf;
@@ -1141,19 +1540,19 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in
1141
1540
 
1142
1541
  //printf("d = %f, vi = %d, vi0 = %d, vi1 = %d, v0 = %f, v1 = %f\n", d, vi, vi0, vi1, v0, v1);
1143
1542
 
1144
- y[i*QK + l + 0] = v0;
1145
- y[i*QK + l + 1] = v1;
1543
+ y[i*QK4_0 + l + 0] = v0;
1544
+ y[i*QK4_0 + l + 1] = v1;
1146
1545
 
1147
- assert(!isnan(y[i*QK + l + 0]));
1148
- assert(!isnan(y[i*QK + l + 1]));
1546
+ assert(!isnan(y[i*QK4_0 + l + 0]));
1547
+ assert(!isnan(y[i*QK4_0 + l + 1]));
1149
1548
  }
1150
1549
  }
1151
1550
  #endif
1152
1551
  }
1153
1552
 
1154
1553
  static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, int k) {
1155
- assert(k % QK == 0);
1156
- const int nb = k / QK;
1554
+ assert(k % QK4_1 == 0);
1555
+ const int nb = k / QK4_1;
1157
1556
 
1158
1557
  const block_q4_1 * restrict x = vx;
1159
1558
 
@@ -1164,9 +1563,9 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
1164
1563
 
1165
1564
  const uint8_t * restrict pp = x[i].qs;
1166
1565
 
1167
- for (int l = 0; l < QK; l += 32) {
1566
+ for (int l = 0; l < QK4_1; l += 32) {
1168
1567
  // Load 32x4-bit integers into 32x8-bit integers
1169
- __m256i vx8 = bytesFromNibbles(pp+l/2);
1568
+ __m256i vx8 = bytes_from_nibbles_32(pp+l/2);
1170
1569
 
1171
1570
  // Convert to 16-bit int
1172
1571
  const __m256i vx16_lo = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vx8, 0));
@@ -1183,7 +1582,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
1183
1582
  // Scale, add m and store
1184
1583
  for (int j = 0; j < 4; j++) {
1185
1584
  const __m256 result = _mm256_add_ps(_mm256_mul_ps(vf[j], d_v), d_m);
1186
- _mm256_storeu_ps(y + i * QK + l + j*8, result);
1585
+ _mm256_storeu_ps(y + i * QK4_1 + l + j*8, result);
1187
1586
  }
1188
1587
  }
1189
1588
  }
@@ -1194,7 +1593,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
1194
1593
 
1195
1594
  const uint8_t * restrict pp = x[i].qs;
1196
1595
 
1197
- for (int l = 0; l < QK; l += 16) {
1596
+ for (int l = 0; l < QK4_1; l += 16) {
1198
1597
  // Load 16x4-bit integers into 8x8-bit integers
1199
1598
  const uint8x8_t v8 = vld1_u8(pp + l/2);
1200
1599
 
@@ -1225,10 +1624,10 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
1225
1624
  const float32x4_t r3 = vmlaq_f32(vm, vf_3, vd);
1226
1625
 
1227
1626
  // Store
1228
- vst1q_f32(y + i*QK + l + 0, r0);
1229
- vst1q_f32(y + i*QK + l + 4, r1);
1230
- vst1q_f32(y + i*QK + l + 8, r2);
1231
- vst1q_f32(y + i*QK + l + 12, r3);
1627
+ vst1q_f32(y + i*QK4_1 + l + 0, r0);
1628
+ vst1q_f32(y + i*QK4_1 + l + 4, r1);
1629
+ vst1q_f32(y + i*QK4_1 + l + 8, r2);
1630
+ vst1q_f32(y + i*QK4_1 + l + 12, r3);
1232
1631
  }
1233
1632
  }
1234
1633
  #else
@@ -1238,7 +1637,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
1238
1637
 
1239
1638
  const uint8_t * restrict pp = x[i].qs;
1240
1639
 
1241
- for (int l = 0; l < QK; l += 2) {
1640
+ for (int l = 0; l < QK4_1; l += 2) {
1242
1641
  const uint8_t vi = pp[l/2];
1243
1642
 
1244
1643
  const int8_t vi0 = vi & 0xf;
@@ -1247,21 +1646,130 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
1247
1646
  const float v0 = vi0*d + m;
1248
1647
  const float v1 = vi1*d + m;
1249
1648
 
1250
- y[i*QK + l + 0] = v0;
1251
- y[i*QK + l + 1] = v1;
1649
+ y[i*QK4_1 + l + 0] = v0;
1650
+ y[i*QK4_1 + l + 1] = v1;
1252
1651
 
1253
- assert(!isnan(y[i*QK + l + 0]));
1254
- assert(!isnan(y[i*QK + l + 1]));
1652
+ assert(!isnan(y[i*QK4_1 + l + 0]));
1653
+ assert(!isnan(y[i*QK4_1 + l + 1]));
1255
1654
  }
1256
1655
  }
1257
1656
  #endif
1258
1657
  }
1259
1658
 
1260
- //
1261
- // simd mappings
1262
- //
1659
+ static void dequantize_row_q4_2(const void * restrict vx, float * restrict y, int k) {
1660
+ assert(k % QK4_2 == 0);
1661
+ const int nb = k / QK4_2;
1263
1662
 
1264
- // we define a common set of C macros which map to specific intrinsics based on the current architecture
1663
+ const block_q4_2 * restrict x = vx;
1664
+
1665
+ for (int i = 0; i < nb; i++) {
1666
+ const float d = GGML_FP16_TO_FP32(x[i].d);
1667
+
1668
+ const uint8_t * restrict pp = x[i].qs;
1669
+
1670
+ for (int l = 0; l < QK4_2; l += 2) {
1671
+ const uint8_t vi = pp[l/2];
1672
+
1673
+ const int8_t vi0 = vi & 0xf;
1674
+ const int8_t vi1 = vi >> 4;
1675
+
1676
+ const float v0 = (vi0 - 8)*d;
1677
+ const float v1 = (vi1 - 8)*d;
1678
+
1679
+ y[i*QK4_2 + l + 0] = v0;
1680
+ y[i*QK4_2 + l + 1] = v1;
1681
+
1682
+ assert(!isnan(y[i*QK4_2 + l + 0]));
1683
+ assert(!isnan(y[i*QK4_2 + l + 1]));
1684
+ }
1685
+ }
1686
+ }
1687
+
1688
+ static void dequantize_row_q4_3(const void * restrict vx, float * restrict y, int k) {
1689
+ assert(k % QK4_3 == 0);
1690
+ const int nb = k / QK4_3;
1691
+
1692
+ const block_q4_3 * restrict x = vx;
1693
+
1694
+ for (int i = 0; i < nb; i++) {
1695
+ const float d = GGML_FP16_TO_FP32(x[i].d);
1696
+ const float m = GGML_FP16_TO_FP32(x[i].m);
1697
+
1698
+ const uint8_t * restrict pp = x[i].qs;
1699
+
1700
+ for (int l = 0; l < QK4_3; l += 2) {
1701
+ const uint8_t vi = pp[l/2];
1702
+
1703
+ const int8_t vi0 = vi & 0xf;
1704
+ const int8_t vi1 = vi >> 4;
1705
+
1706
+ const float v0 = vi0*d + m;
1707
+ const float v1 = vi1*d + m;
1708
+
1709
+ y[i*QK4_3 + l + 0] = v0;
1710
+ y[i*QK4_3 + l + 1] = v1;
1711
+
1712
+ assert(!isnan(y[i*QK4_3 + l + 0]));
1713
+ assert(!isnan(y[i*QK4_3 + l + 1]));
1714
+ }
1715
+ }
1716
+ }
1717
+
1718
+ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
1719
+ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
1720
+ static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
1721
+ static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
1722
+
1723
+ static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
1724
+ [GGML_TYPE_Q4_0] = {
1725
+ .dequantize_row_q = dequantize_row_q4_0,
1726
+ .quantize_row_q = quantize_row_q4_0,
1727
+ .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_0_reference,
1728
+ .quantize_row_q_dot = quantize_row_q8_0,
1729
+ .vec_dot_q = ggml_vec_dot_q4_0_q8_0,
1730
+ },
1731
+ [GGML_TYPE_Q4_1] = {
1732
+ .dequantize_row_q = dequantize_row_q4_1,
1733
+ .quantize_row_q = quantize_row_q4_1,
1734
+ .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_1_reference,
1735
+ .quantize_row_q_dot = quantize_row_q8_0,
1736
+ .vec_dot_q = ggml_vec_dot_q4_1_q8_0,
1737
+ },
1738
+ [GGML_TYPE_Q4_2] = {
1739
+ .dequantize_row_q = dequantize_row_q4_2,
1740
+ .quantize_row_q = quantize_row_q4_2,
1741
+ .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_2_rmse, //quantize_row_q4_2_reference,
1742
+ .quantize_row_q_dot = quantize_row_q8_0,
1743
+ .vec_dot_q = ggml_vec_dot_q4_2_q8_0,
1744
+ },
1745
+ [GGML_TYPE_Q4_3] = {
1746
+ .dequantize_row_q = dequantize_row_q4_3,
1747
+ .quantize_row_q = quantize_row_q4_3,
1748
+ .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_3_reference, // TODO: RMSE optimization
1749
+ .quantize_row_q_dot = quantize_row_q8_0,
1750
+ .vec_dot_q = ggml_vec_dot_q4_3_q8_0,
1751
+ },
1752
+ [GGML_TYPE_Q8_0] = {
1753
+ .dequantize_row_q = NULL, // TODO
1754
+ .quantize_row_q = quantize_row_q8_0,
1755
+ .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q8_0_reference,
1756
+ .quantize_row_q_dot = quantize_row_q8_0,
1757
+ .vec_dot_q = NULL, // TODO
1758
+ },
1759
+ };
1760
+
1761
+ // For internal test use
1762
+ quantize_fns_t ggml_internal_get_quantize_fn(size_t i) {
1763
+ GGML_ASSERT(i < GGML_TYPE_COUNT);
1764
+ return quantize_fns[i];
1765
+ }
1766
+
1767
+
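
Note: the new quantize_fns table gives every quantized type a uniform interface (quantize, reference quantize, dequantize, and the matching q8_0 dot kernel). A rough round-trip sketch of how a test might drive it, assuming ggml.h of this version exposes quantize_fns_t and ggml_internal_get_quantize_fn as declared here:

    #include <math.h>
    #include <stdio.h>
    #include <stdlib.h>
    #include "ggml.h"   /* assumed to declare quantize_fns_t and ggml_internal_get_quantize_fn */

    int main(void) {
        enum { N = 256 };                        /* must be a multiple of QK4_0 (32) */
        float src[N], dst[N];
        for (int i = 0; i < N; ++i) src[i] = sinf(0.1f * i);

        const quantize_fns_t fns = ggml_internal_get_quantize_fn(GGML_TYPE_Q4_0);

        void * q = malloc((N/32) * 20);          /* 20 bytes per 32-value q4_0 block, see static_assert above */
        fns.quantize_row_q  (src, q, N);         /* fp32 -> q4_0 */
        fns.dequantize_row_q(q, dst, N);         /* q4_0 -> fp32 */

        double se = 0.0;
        for (int i = 0; i < N; ++i) se += (src[i] - dst[i]) * (src[i] - dst[i]);
        printf("rmse = %f\n", sqrt(se / N));

        free(q);
        return 0;
    }
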
1768
+ //
1769
+ // simd mappings
1770
+ //
1771
+
1772
+ // we define a common set of C macros which map to specific intrinsics based on the current architecture
1265
1773
  // we then implement the fundamental computation operations below using only these macros
1266
1774
  // adding support for new architectures requires to define the corresponding SIMD macros
1267
1775
  //
@@ -1813,37 +2321,6 @@ inline static void ggml_vec_dot_f32(const int n, float * restrict s, const float
1813
2321
  *s = sumf;
1814
2322
  }
1815
2323
 
1816
- #if __AVX512F__ && QK == 32
1817
- static inline __m512 dot_q4_0_oneblock_avx512(
1818
- __m512 acc,
1819
- const block_q4_0 * restrict x,
1820
- const block_q4_0 * restrict y,
1821
- int i
1822
- ) {
1823
- // Compute combined scale for the block
1824
- __m512 d = _mm512_set1_ps( x[i].d * y[i].d );
1825
-
1826
- __m256i bx = bytesFromNibbles( x[i].qs );
1827
- __m256i by = bytesFromNibbles( y[i].qs );
1828
-
1829
- // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
1830
- const __m256i off = _mm256_set1_epi8( 8 );
1831
- bx = _mm256_sub_epi8( bx, off );
1832
- by = _mm256_sub_epi8( by, off );
1833
-
1834
- // Sign-extend 16 signed bytes into int16_t
1835
- __m512i x32 = _mm512_cvtepi8_epi16( bx );
1836
- __m512i y32 = _mm512_cvtepi8_epi16( by );
1837
- // Compute products of int16_t integers, add pairwise
1838
- __m512i i64 = _mm512_madd_epi16( x32, y32 );
1839
-
1840
- // Convert int32_t to float
1841
- __m512 p = _mm512_cvtepi32_ps( i64 );
1842
- // Apply the scale, and accumulate
1843
- return _mm512_fmadd_ps( d, p, acc );
1844
- }
1845
- #endif
1846
-
1847
2324
  inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t * restrict x, ggml_fp16_t * restrict y) {
1848
2325
  ggml_float sumf = 0.0;
1849
2326
 
@@ -1880,67 +2357,64 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t
1880
2357
  *s = sumf;
1881
2358
  }
1882
2359
 
1883
- static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
1884
- const int nb = n / QK;
2360
+ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
2361
+ const int nb = n / QK8_0;
1885
2362
 
1886
- assert(n % QK == 0);
2363
+ assert(n % QK8_0 == 0);
1887
2364
  assert(nb % 2 == 0);
1888
2365
 
1889
2366
  const block_q4_0 * restrict x = vx;
1890
- const block_q4_0 * restrict y = vy;
2367
+ const block_q8_0 * restrict y = vy;
1891
2368
 
1892
2369
  float sumf = 0.0;
1893
2370
 
1894
2371
  #if defined(__ARM_NEON)
1895
- float sum0 = 0.0f;
1896
- float sum1 = 0.0f;
2372
+ float32x4_t sumv0 = vdupq_n_f32(0.0f);
2373
+ float32x4_t sumv1 = vdupq_n_f32(0.0f);
1897
2374
 
1898
2375
  for (int i = 0; i < nb; i += 2) {
1899
2376
  const block_q4_0 * restrict x0 = &x[i + 0];
1900
- const block_q4_0 * restrict y0 = &y[i + 0];
1901
2377
  const block_q4_0 * restrict x1 = &x[i + 1];
1902
- const block_q4_0 * restrict y1 = &y[i + 1];
2378
+ const block_q8_0 * restrict y0 = &y[i + 0];
2379
+ const block_q8_0 * restrict y1 = &y[i + 1];
1903
2380
 
1904
- const uint8x16_t m4b = vdupq_n_u8(0xf);
1905
- const int8x16_t s8b = vdupq_n_s8(0x8);
2381
+ const uint8x16_t m4b = vdupq_n_u8(0xf);
2382
+ const int8x16_t s8b = vdupq_n_s8(0x8);
1906
2383
 
1907
2384
  const uint8x16_t v0_0 = vld1q_u8(x0->qs);
1908
- const uint8x16_t v1_0 = vld1q_u8(y0->qs);
1909
2385
  const uint8x16_t v0_1 = vld1q_u8(x1->qs);
1910
- const uint8x16_t v1_1 = vld1q_u8(y1->qs);
1911
2386
 
1912
2387
  // 4-bit -> 8-bit
1913
- const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8(v0_0, m4b));
1914
- const int8x16_t v1_0l = vreinterpretq_s8_u8(vandq_u8(v1_0, m4b));
2388
+ const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
1915
2389
  const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
1916
- const int8x16_t v1_0h = vreinterpretq_s8_u8(vshrq_n_u8(v1_0, 4));
1917
-
1918
- const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8(v0_1, m4b));
1919
- const int8x16_t v1_1l = vreinterpretq_s8_u8(vandq_u8(v1_1, m4b));
2390
+ const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
1920
2391
  const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
1921
- const int8x16_t v1_1h = vreinterpretq_s8_u8(vshrq_n_u8(v1_1, 4));
1922
2392
 
1923
2393
  // sub 8
1924
2394
  const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
1925
- const int8x16_t v1_0ls = vsubq_s8(v1_0l, s8b);
1926
2395
  const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
1927
- const int8x16_t v1_0hs = vsubq_s8(v1_0h, s8b);
1928
-
1929
2396
  const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
1930
- const int8x16_t v1_1ls = vsubq_s8(v1_1l, s8b);
1931
2397
  const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
1932
- const int8x16_t v1_1hs = vsubq_s8(v1_1h, s8b);
2398
+
2399
+ // load y
2400
+ const int8x16_t v1_0l = vld1q_s8(y0->qs);
2401
+ const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
2402
+ const int8x16_t v1_1l = vld1q_s8(y1->qs);
2403
+ const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
2404
+
2405
+ // interleave
2406
+ const int8x16_t v1_0ls = vuzp1q_s8(v1_0l, v1_0h);
2407
+ const int8x16_t v1_0hs = vuzp2q_s8(v1_0l, v1_0h);
2408
+ const int8x16_t v1_1ls = vuzp1q_s8(v1_1l, v1_1h);
2409
+ const int8x16_t v1_1hs = vuzp2q_s8(v1_1l, v1_1h);
1933
2410
 
1934
2411
  #if defined(__ARM_FEATURE_DOTPROD)
1935
2412
  // dot product into int32x4_t
1936
- int32x4_t p_0 = vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls);
1937
- int32x4_t p_1 = vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls);
2413
+ const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls), v0_0hs, v1_0hs);
2414
+ const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls), v0_1hs, v1_1hs);
1938
2415
 
1939
- p_0 = vdotq_s32(p_0, v0_0hs, v1_0hs);
1940
- p_1 = vdotq_s32(p_1, v0_1hs, v1_1hs);
1941
-
1942
- sum0 += x0->d*y0->d*vaddvq_s32(p_0);
1943
- sum1 += x1->d*y1->d*vaddvq_s32(p_1);
2416
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), x0->d*y0->d);
2417
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), x1->d*y1->d);
1944
2418
  #else
1945
2419
  const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls));
1946
2420
  const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls));
@@ -1952,115 +2426,51 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
1952
2426
  const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1hs));
1953
2427
  const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1hs));
1954
2428
 
1955
- const int16x8_t pl_0 = vaddq_s16(pl0l, pl0h);
1956
- const int16x8_t ph_0 = vaddq_s16(ph0l, ph0h);
1957
-
1958
- const int16x8_t pl_1 = vaddq_s16(pl1l, pl1h);
1959
- const int16x8_t ph_1 = vaddq_s16(ph1l, ph1h);
2429
+ const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
2430
+ const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
2431
+ const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
2432
+ const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
1960
2433
 
1961
- const int16x8_t p_0 = vaddq_s16(pl_0, ph_0);
1962
- const int16x8_t p_1 = vaddq_s16(pl_1, ph_1);
1963
-
1964
- sum0 += x0->d*y0->d*vaddvq_s16(p_0);
1965
- sum1 += x1->d*y1->d*vaddvq_s16(p_1);
2434
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), x0->d*y0->d);
2435
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), x1->d*y1->d);
1966
2436
  #endif
1967
2437
  }
1968
2438
 
1969
- sumf = sum0 + sum1;
1970
- #elif defined(__AVX512F__)
2439
+ sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
2440
+ #elif defined(__AVX2__)
1971
2441
  // Initialize accumulator with zeros
1972
- __m512 acc0 = _mm512_setzero_ps();
1973
- __m512 acc1 = _mm512_setzero_ps();
2442
+ __m256 acc = _mm256_setzero_ps();
1974
2443
 
1975
- const int superblock_size = 8;
1976
- const int superblock_count = nb / superblock_size;
2444
+ // Main loop
2445
+ for (int i = 0; i < nb; ++i) {
2446
+ /* Compute combined scale for the block */
2447
+ const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );
1977
2448
 
1978
- for (int superblock_ix = 0; superblock_ix < superblock_count; superblock_ix += 1) {
1979
- int i = superblock_ix * superblock_size;
2449
+ __m256i bx = bytes_from_nibbles_32(x[i].qs);
1980
2450
 
1981
- acc0 = dot_q4_0_oneblock_avx512( acc0, x, y, i+0 );
1982
- acc1 = dot_q4_0_oneblock_avx512( acc1, x, y, i+1 );
1983
- acc0 = dot_q4_0_oneblock_avx512( acc0, x, y, i+2 );
1984
- acc1 = dot_q4_0_oneblock_avx512( acc1, x, y, i+3 );
1985
- acc0 = dot_q4_0_oneblock_avx512( acc0, x, y, i+4 );
1986
- acc1 = dot_q4_0_oneblock_avx512( acc1, x, y, i+5 );
1987
- acc0 = dot_q4_0_oneblock_avx512( acc0, x, y, i+6 );
1988
- acc1 = dot_q4_0_oneblock_avx512( acc1, x, y, i+7 );
1989
- }
2451
+ // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
2452
+ const __m256i off = _mm256_set1_epi8( 8 );
2453
+ bx = _mm256_sub_epi8( bx, off );
1990
2454
 
1991
- // Remainders
1992
- for (int i = superblock_count * superblock_size; i < nb; ++i) {
1993
- acc0 = dot_q4_0_oneblock_avx512( acc0, x, y, i );
1994
- }
2455
+ __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
1995
2456
 
1996
- // Horizontal sum of all lanes of the accumulator
1997
- sumf = _mm512_reduce_add_ps( acc0 ) + _mm512_reduce_add_ps( acc1 );
1998
- #elif defined(__AVX2__)
1999
- // Initialize accumulator with zeros
2000
- __m256 acc = _mm256_setzero_ps();
2457
+ // Get absolute values of x vectors
2458
+ const __m256i ax = _mm256_sign_epi8(bx, bx);
2001
2459
 
2002
- /* Prepare the constants we will need during execution */
2003
- const __m256i lowMask = _mm256_set1_epi8( 0xF );
2004
- const __m256i offset_8 = _mm256_set1_epi16( 8 );
2460
+ // Sign the values of the y vectors
2461
+ const __m256i sy = _mm256_sign_epi8(by, bx);
2005
2462
 
2006
- #define UNROLL_COUNT 8
2007
- // make sure we only unroll multiples of the block count
2008
- assert(nb % UNROLL_COUNT == 0);
2463
+ // Perform multiplication and create 16-bit values
2464
+ const __m256i dot = _mm256_maddubs_epi16(ax, sy);
2009
2465
 
2010
- // Main loop
2011
- for (int i = 0; i < nb; i+=UNROLL_COUNT) {
2012
- // This loop will be unrolled by the compiler
2013
- for (int u=0;u<UNROLL_COUNT;u++) {
2014
- /* Compute combined scale for the block */
2015
- const __m256 scale = _mm256_mul_ps(
2016
- _mm256_broadcast_ss( &x[i+u].d ),
2017
- _mm256_broadcast_ss( &y[i+u].d ) );
2018
-
2019
- /* get input from x
2020
- Input: 32 Nibbles (16 bytes) at *x[i+u]
2021
- Output: 2 vectors with 16 values of type int16_t (x_high_q, x_low_q) */
2022
-
2023
- /* Load 16 bytes from memory */
2024
- const __m128i tmp_x = _mm_loadu_si128( ( const __m128i* ) x[i+u].qs);
2025
- /* Expand bytes into uint16_t values */
2026
- const __m256i bytes_x = _mm256_cvtepu8_epi16(tmp_x);
2027
- /* Unpack values into individual bytes */
2028
- __m256i x_low_q = _mm256_and_si256( lowMask, bytes_x );
2029
- const __m256i pre_shift_x_high_q = _mm256_andnot_si256( lowMask, bytes_x );
2030
- __m256i x_high_q = _mm256_srli_epi16( pre_shift_x_high_q, 4 );
2031
- /* Now we have two vectors with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval. */
2032
- x_high_q = _mm256_sub_epi16( x_high_q, offset_8 );
2033
- x_low_q = _mm256_sub_epi16( x_low_q, offset_8 );
2034
-
2035
- /* get input from y
2036
- Input: 32 Nibbles (16 bytes) at *y[i+u]
2037
- Output: 2 vectors with 16 values of type int16_t (y_high_q, y_low_q) */
2038
-
2039
- /* Load 16 bytes from memory */
2040
- const __m128i tmp_y = _mm_loadu_si128( (const __m128i* ) y[i+u].qs);
2041
- /* Expand bytes into uint16_t values */
2042
- const __m256i bytes_y = _mm256_cvtepu8_epi16(tmp_y);
2043
- /* Unpack values into individual bytes */
2044
- const __m256i pre_shift_y_high_q = _mm256_andnot_si256( lowMask, bytes_y );
2045
- __m256i y_high_q = _mm256_srli_epi16( pre_shift_y_high_q, 4 );
2046
- __m256i y_low_q = _mm256_and_si256( lowMask, bytes_y );
2047
- /* Now we have two vectors with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval. */
2048
- y_high_q = _mm256_sub_epi16( y_high_q, offset_8 );
2049
- y_low_q = _mm256_sub_epi16( y_low_q, offset_8 );
2050
-
2051
- /* Compute products of int16_t integers, add pairwise, store as int32_t */
2052
- __m256i xy_high_q = _mm256_madd_epi16( x_high_q, y_high_q );
2053
- __m256i xy_low_q = _mm256_madd_epi16( x_low_q, y_low_q );
2054
-
2055
- /* Accumulate the products of int32_t integers -> we now have a vector of 8 int_32t */
2056
- __m256i xy_q = _mm256_add_epi32( xy_high_q, xy_low_q );
2057
-
2058
- /* Convert to vectore of 8 int32_t to 8 floats */
2059
- __m256 q = _mm256_cvtepi32_ps( xy_q );
2060
-
2061
- /* Multiply q with scale and accumulate */
2062
- acc = _mm256_fmadd_ps( scale, q, acc );
2063
- }
2466
+ const __m256i ones = _mm256_set1_epi16(1);
2467
+ __m256i xy_q = _mm256_madd_epi16(ones, dot);
2468
+
2469
+ /* Convert to vectore of 8 int32_t to 8 floats */
2470
+ __m256 q = _mm256_cvtepi32_ps( xy_q );
2471
+
2472
+ /* Multiply q with scale and accumulate */
2473
+ acc = _mm256_fmadd_ps( d, q, acc );
2064
2474
  }
2065
2475
 
2066
2476
  // Return horizontal sum of the acc vector
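
Note: the rewritten AVX2 path relies on _mm256_maddubs_epi16, which multiplies unsigned bytes by signed bytes; that is why the code takes |x| with _mm256_sign_epi8(bx, bx) and moves x's sign onto y with _mm256_sign_epi8(by, bx). With |x| <= 8 after the -8 offset, the pairwise sums also stay far below the int16 saturation limit of maddubs. A scalar check of the identity this depends on (plain C, nothing ggml-specific):

    #include <assert.h>
    #include <stdio.h>
    #include <stdlib.h>

    int main(void) {
        for (int x = -8; x <= 7; ++x) {              /* q4_0 value after the -8 offset */
            for (int y = -128; y <= 127; ++y) {      /* any q8_0 value */
                int ax = abs(x);                             /* _mm256_sign_epi8(bx, bx) */
                int sy = (x > 0) ? y : (x < 0) ? -y : 0;     /* _mm256_sign_epi8(by, bx) */
                assert(ax * sy == x * y);            /* products are preserved */
            }
        }
        printf("ok\n");
        return 0;
    }
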
@@ -2082,13 +2492,12 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
2082
2492
  __m128i i32[2];
2083
2493
  for (int j = 0; j < 2; ++j) {
2084
2494
  // Load 8 bytes, and unpack 4 bit fields into bytes, making 16 bytes
2085
- __m128i bx = bytesFromNibbles( x[i].qs + 8*j );
2086
- __m128i by = bytesFromNibbles( y[i].qs + 8*j );
2495
+ __m128i bx = bytes_from_nibbles_16(x[i].qs + 8*j);
2496
+ __m128i by = _mm_loadu_si128((const __m128i *)(y[i].qs + 16*j));
2087
2497
 
2088
2498
  // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
2089
2499
  const __m128i off = _mm_set1_epi8( 8 );
2090
2500
  bx = _mm_sub_epi8( bx, off );
2091
- by = _mm_sub_epi8( by, off );
2092
2501
 
2093
2502
  // Get absolute values of x vectors
2094
2503
  const __m128i ax = _mm_sign_epi8(bx, bx);
@@ -2116,86 +2525,6 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
2116
2525
  res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
2117
2526
 
2118
2527
  sumf = _mm_cvtss_f32( res );
2119
- #elif defined(__wasm_simd128__)
2120
- // wasm simd
2121
- float sum0 = 0.0f;
2122
- float sum1 = 0.0f;
2123
-
2124
- for (int i = 0; i < nb; i += 2) {
2125
- const block_q4_0 * restrict x0 = &x[i + 0];
2126
- const block_q4_0 * restrict y0 = &y[i + 0];
2127
- const block_q4_0 * restrict x1 = &x[i + 1];
2128
- const block_q4_0 * restrict y1 = &y[i + 1];
2129
-
2130
- const v128_t m4b = wasm_u8x16_splat(0xf);
2131
- const v128_t s8b = wasm_i8x16_splat(0x8);
2132
-
2133
- const v128_t v0_0 = wasm_v128_load(x0->qs);
2134
- const v128_t v0_1 = wasm_v128_load(y0->qs);
2135
- const v128_t v1_0 = wasm_v128_load(x1->qs);
2136
- const v128_t v1_1 = wasm_v128_load(y1->qs);
2137
-
2138
- // 4-bit -> 8-bit
2139
- const v128_t v0_0l = wasm_v128_and(v0_0, m4b);
2140
- const v128_t v1_0l = wasm_v128_and(v1_0, m4b);
2141
-
2142
- const v128_t v0_0h = wasm_u8x16_shr(v0_0, 4);
2143
- const v128_t v1_0h = wasm_u8x16_shr(v1_0, 4);
2144
-
2145
- const v128_t v0_1l = wasm_v128_and(v0_1, m4b);
2146
- const v128_t v1_1l = wasm_v128_and(v1_1, m4b);
2147
-
2148
- const v128_t v0_1h = wasm_u8x16_shr(v0_1, 4);
2149
- const v128_t v1_1h = wasm_u8x16_shr(v1_1, 4);
2150
-
2151
- // sub 8
2152
- const v128_t v0_0ls = wasm_i8x16_sub(v0_0l, s8b);
2153
- const v128_t v1_0ls = wasm_i8x16_sub(v1_0l, s8b);
2154
-
2155
- const v128_t v0_0hs = wasm_i8x16_sub(v0_0h, s8b);
2156
- const v128_t v1_0hs = wasm_i8x16_sub(v1_0h, s8b);
2157
-
2158
- const v128_t v0_1ls = wasm_i8x16_sub(v0_1l, s8b);
2159
- const v128_t v1_1ls = wasm_i8x16_sub(v1_1l, s8b);
2160
-
2161
- const v128_t v0_1hs = wasm_i8x16_sub(v0_1h, s8b);
2162
- const v128_t v1_1hs = wasm_i8x16_sub(v1_1h, s8b);
2163
-
2164
- // dot product into int16x8_t
2165
- const v128_t pl0l = wasm_i16x8_mul(wasm_i16x8_extend_low_i8x16(v0_0ls), wasm_i16x8_extend_low_i8x16(v1_0ls));
2166
- const v128_t pl0h = wasm_i16x8_mul(wasm_i16x8_extend_high_i8x16(v0_0ls), wasm_i16x8_extend_high_i8x16(v1_0ls));
2167
-
2168
- const v128_t ph0l = wasm_i16x8_mul(wasm_i16x8_extend_low_i8x16(v0_0hs), wasm_i16x8_extend_low_i8x16(v1_0hs));
2169
- const v128_t ph0h = wasm_i16x8_mul(wasm_i16x8_extend_high_i8x16(v0_0hs), wasm_i16x8_extend_high_i8x16(v1_0hs));
2170
-
2171
- const v128_t pl1l = wasm_i16x8_mul(wasm_i16x8_extend_low_i8x16(v0_1ls), wasm_i16x8_extend_low_i8x16(v1_1ls));
2172
- const v128_t pl1h = wasm_i16x8_mul(wasm_i16x8_extend_high_i8x16(v0_1ls), wasm_i16x8_extend_high_i8x16(v1_1ls));
2173
-
2174
- const v128_t ph1l = wasm_i16x8_mul(wasm_i16x8_extend_low_i8x16(v0_1hs), wasm_i16x8_extend_low_i8x16(v1_1hs));
2175
- const v128_t ph1h = wasm_i16x8_mul(wasm_i16x8_extend_high_i8x16(v0_1hs), wasm_i16x8_extend_high_i8x16(v1_1hs));
2176
-
2177
- const v128_t pl_0 = wasm_i16x8_add(pl0l, pl0h);
2178
- const v128_t ph_0 = wasm_i16x8_add(ph0l, ph0h);
2179
-
2180
- const v128_t pl_1 = wasm_i16x8_add(pl1l, pl1h);
2181
- const v128_t ph_1 = wasm_i16x8_add(ph1l, ph1h);
2182
-
2183
- const v128_t p_0 = wasm_i16x8_add(pl_0, ph_0);
2184
- const v128_t p_1 = wasm_i16x8_add(pl_1, ph_1);
2185
-
2186
- sum0 += x0->d * y0->d * (
2187
- wasm_i16x8_extract_lane(p_0, 0) + wasm_i16x8_extract_lane(p_0, 1) +
2188
- wasm_i16x8_extract_lane(p_0, 2) + wasm_i16x8_extract_lane(p_0, 3) +
2189
- wasm_i16x8_extract_lane(p_0, 4) + wasm_i16x8_extract_lane(p_0, 5) +
2190
- wasm_i16x8_extract_lane(p_0, 6) + wasm_i16x8_extract_lane(p_0, 7));
2191
- sum1 += x1->d * y1->d * (
2192
- wasm_i16x8_extract_lane(p_1, 0) + wasm_i16x8_extract_lane(p_1, 1) +
2193
- wasm_i16x8_extract_lane(p_1, 2) + wasm_i16x8_extract_lane(p_1, 3) +
2194
- wasm_i16x8_extract_lane(p_1, 4) + wasm_i16x8_extract_lane(p_1, 5) +
2195
- wasm_i16x8_extract_lane(p_1, 6) + wasm_i16x8_extract_lane(p_1, 7));
2196
- }
2197
-
2198
- sumf = sum0 + sum1;
2199
2528
  #else
2200
2529
  // scalar
2201
2530
  for (int i = 0; i < nb; i++) {
@@ -2203,98 +2532,159 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
2203
2532
  const float d1 = y[i].d;
2204
2533
 
2205
2534
  const uint8_t * restrict p0 = x[i].qs;
2206
- const uint8_t * restrict p1 = y[i].qs;
2535
+ const int8_t * restrict p1 = y[i].qs;
2207
2536
 
2208
2537
  int sumi = 0;
2209
- for (int j = 0; j < QK/2; j++) {
2538
+ for (int j = 0; j < QK8_0/2; j++) {
2210
2539
  const uint8_t v0 = p0[j];
2211
- const uint8_t v1 = p1[j];
2212
2540
 
2213
- const int8_t i0 = (int8_t) (v0 & 0xf) - 8;
2214
- const int8_t i1 = (int8_t) (v0 >> 4) - 8;
2541
+ const int i0 = (int8_t) (v0 & 0xf) - 8;
2542
+ const int i1 = (int8_t) (v0 >> 4) - 8;
2215
2543
 
2216
- const int8_t i2 = (int8_t) (v1 & 0xf) - 8;
2217
- const int8_t i3 = (int8_t) (v1 >> 4) - 8;
2544
+ const int i2 = p1[2*j + 0];
2545
+ const int i3 = p1[2*j + 1];
2218
2546
 
2219
2547
  sumi += i0*i2 + i1*i3;
2220
2548
  }
2221
- sumf += d0 * d1 * sumi;
2549
+ sumf += d0*d1*sumi;
2222
2550
  }
2223
2551
  #endif
2224
2552
 
2225
2553
  *s = sumf;
2226
2554
  }
2227
2555
 
2228
- static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
2229
- const int nb = n / QK;
2556
+ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
2557
+ const int nb = n / QK8_0;
2558
+
2559
+ assert(n % QK8_0 == 0);
2560
+ assert(nb % 2 == 0);
2230
2561
 
2231
2562
  const block_q4_1 * restrict x = vx;
2232
- const block_q4_1 * restrict y = vy;
2563
+ const block_q8_0 * restrict y = vy;
2233
2564
 
2234
2565
  float sumf = 0.0;
2235
2566
 
2236
- #if defined(__AVX2__)
2567
+ // TODO: add AVX / WASM SIMD / etc
2568
+ #if defined(__ARM_NEON)
2569
+ float32x4_t sumv0 = vdupq_n_f32(0.0f);
2570
+ float32x4_t sumv1 = vdupq_n_f32(0.0f);
2571
+
2572
+ for (int i = 0; i < nb; i += 2) {
2573
+ const block_q4_1 * restrict x0 = &x[i + 0];
2574
+ const block_q4_1 * restrict x1 = &x[i + 1];
2575
+ const block_q8_0 * restrict y0 = &y[i + 0];
2576
+ const block_q8_0 * restrict y1 = &y[i + 1];
2577
+
2578
+ const uint8x16_t m4b = vdupq_n_u8(0xf);
2579
+
2580
+ const uint8x16_t v0_0 = vld1q_u8(x0->qs);
2581
+ const uint8x16_t v0_1 = vld1q_u8(x1->qs);
2582
+
2583
+ // 4-bit -> 8-bit
2584
+ const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
2585
+ const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
2586
+ const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
2587
+ const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
2588
+
2589
+ // load y
2590
+ const int8x16_t v1_0l = vld1q_s8(y0->qs);
2591
+ const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
2592
+ const int8x16_t v1_1l = vld1q_s8(y1->qs);
2593
+ const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
2594
+
2595
+ // interleave
2596
+ const int8x16_t v1_0ls = vuzp1q_s8(v1_0l, v1_0h);
2597
+ const int8x16_t v1_0hs = vuzp2q_s8(v1_0l, v1_0h);
2598
+ const int8x16_t v1_1ls = vuzp1q_s8(v1_1l, v1_1h);
2599
+ const int8x16_t v1_1hs = vuzp2q_s8(v1_1l, v1_1h);
2600
+
2601
+ const int16x8_t s0i = vaddq_s16(
2602
+ vaddq_s16(vmovl_s8(vget_low_s8(v1_0ls)), vmovl_s8(vget_high_s8(v1_0ls))),
2603
+ vaddq_s16(vmovl_s8(vget_low_s8(v1_0hs)), vmovl_s8(vget_high_s8(v1_0hs))));
2604
+
2605
+ const int16x8_t s1i = vaddq_s16(
2606
+ vaddq_s16(vmovl_s8(vget_low_s8(v1_1ls)), vmovl_s8(vget_high_s8(v1_1ls))),
2607
+ vaddq_s16(vmovl_s8(vget_low_s8(v1_1hs)), vmovl_s8(vget_high_s8(v1_1hs))));
2608
+
2609
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddl_s16(vget_low_s16(s0i), vget_high_s16(s0i))), x0->m*y0->d);
2610
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddl_s16(vget_low_s16(s1i), vget_high_s16(s1i))), x1->m*y1->d);
2611
+
2612
+ #if defined(__ARM_FEATURE_DOTPROD)
2613
+ // dot product into int32x4_t
2614
+ const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0ls), v0_0h, v1_0hs);
2615
+ const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1ls), v0_1h, v1_1hs);
2616
+
2617
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), x0->d*y0->d);
2618
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), x1->d*y1->d);
2619
+ #else
2620
+ const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0l), vget_low_s8 (v1_0ls));
2621
+ const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0l), vget_high_s8(v1_0ls));
2622
+ const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0h), vget_low_s8 (v1_0hs));
2623
+ const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0h), vget_high_s8(v1_0hs));
2624
+
2625
+ const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1l), vget_low_s8 (v1_1ls));
2626
+ const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1l), vget_high_s8(v1_1ls));
2627
+ const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1h), vget_low_s8 (v1_1hs));
2628
+ const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1h), vget_high_s8(v1_1hs));
2629
+
2630
+ const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
2631
+ const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
2632
+ const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
2633
+ const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
2634
+
2635
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), x0->d*y0->d);
2636
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), x1->d*y1->d);
2637
+ #endif
2638
+ }
2639
+
2640
+ sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
2641
+ #elif defined(__AVX2__)
2237
2642
  // Initialize accumulator with zeros
2238
2643
  __m256 acc = _mm256_setzero_ps();
2239
- // Accumulator for constant offsets
2240
- float acc_offset = 0.0f;
2241
2644
 
2242
2645
  // Main loop
2243
2646
  for (int i = 0; i < nb; ++i) {
2244
2647
  const float * d0 = &x[i].d;
2245
2648
  const float * d1 = &y[i].d;
2246
-
2247
2649
  const float * m0 = &x[i].m;
2248
- const float * m1 = &y[i].m;
2249
2650
 
2250
2651
  const __m256 d0v = _mm256_broadcast_ss( d0 );
2251
2652
  const __m256 d1v = _mm256_broadcast_ss( d1 );
2252
2653
  const __m256 m0v = _mm256_broadcast_ss( m0 );
2253
- const __m256 m1v = _mm256_broadcast_ss( m1 );
2254
2654
 
2255
- // Compute combined scale for the block
2256
- const __m256 scale_01 = _mm256_mul_ps( d0v, d1v );
2257
-
2258
- // Compute cross scales for the block
2259
- const __m256 scale_0 = _mm256_mul_ps( d0v, m1v );
2260
- const __m256 scale_1 = _mm256_mul_ps( m0v, d1v );
2261
- const __m256 cross_scales = _mm256_blend_ps( scale_0, scale_1, 0xAA /* 0b10101010 */ );
2655
+ // Compute combined scales
2656
+ const __m256 d0d1 = _mm256_mul_ps( d0v, d1v );
2657
+ const __m256 d1m0 = _mm256_mul_ps( d1v, m0v );
2262
2658
 
2263
2659
  // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
2264
- __m256i bx = bytesFromNibbles( x[i].qs );
2265
- __m256i by = bytesFromNibbles( y[i].qs );
2266
-
2267
- // Now we have a vector with bytes in [ 0 .. 15 ] interval.
2268
-
2269
- // Sign-extend first 16 signed bytes into int16_t
2270
- __m256i x16 = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( bx ) );
2271
- __m256i y16 = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( by ) );
2272
- // Compute products of int16_t integers, add pairwise
2273
- __m256i i32 = _mm256_madd_epi16( x16, y16 );
2274
-
2275
- // Sign-extend last 16 signed bytes into int16_t vectors
2276
- __m256i x16_h = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( bx, 1 ) );
2277
- __m256i y16_h = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( by, 1 ) );
2278
- // Accumulate products of int16_t integers
2279
- i32 = _mm256_add_epi32( i32, _mm256_madd_epi16( x16_h, y16_h ) );
2280
-
2281
- // compute sums of unsigned bytes in bx, by in blocks of 8.
2282
- // This results in a layout like X100 0000 X200 0000 X300 0000 X400 0000,
2283
- // which we then interleave as X100 Y100 X200 Y200 X300 Y300 X400 Y400.
2284
- // so if we then cast to 8 singles, we get 8 floats like [ x0_7, y0_7, x8_15, y8_15, x16_23, y16_23, x24_31, y24_31 ]
2285
- __m256i xsumi = _mm256_sad_epu8( bx, _mm256_setzero_si256() );
2286
- __m256i ysumi = _mm256_sad_epu8( by, _mm256_setzero_si256() );
2287
- __m256i sumsi = _mm256_or_si256( xsumi, _mm256_slli_si256( ysumi, 4 ) );
2288
- __m256 sums = _mm256_cvtepi32_ps( sumsi );
2660
+ const __m256i bx = bytes_from_nibbles_32(x[i].qs);
2661
+ const __m256i by = _mm256_loadu_si256( (const __m256i *)y[i].qs );
2289
2662
 
2290
- // Convert int32_t to float
2291
- __m256 p = _mm256_cvtepi32_ps( i32 );
2292
- // Apply the scale, and accumulate
2293
- // acc += d0*d1*x*y + d0*m1*x + d1*m0*y
2294
- acc = _mm256_fmadd_ps( scale_01, p, acc );
2295
- acc = _mm256_fmadd_ps( cross_scales, sums, acc );
2296
- // acc_offset += m0*m1 (for each entry in the block)
2297
- acc_offset += (*m0)*(*m1);
2663
+ // Get absolute values of x vectors
2664
+ const __m256i ax = _mm256_sign_epi8( bx, bx );
2665
+
2666
+ // Sign the values of the y vectors
2667
+ const __m256i sy = _mm256_sign_epi8( by, bx );
2668
+
2669
+ // Perform multiplication and create 16-bit values
2670
+ const __m256i dot = _mm256_maddubs_epi16( ax, sy );
2671
+ const __m256i ones = _mm256_set1_epi16( 1 );
2672
+ const __m256i xy_q = _mm256_madd_epi16( ones, dot );
2673
+
2674
+ // Convert to vector of 8 int32_t to 8 floats
2675
+ // Convert the vector of 8 int32_t to 8 floats
2676
+
2677
+ // Accumulate d0*d1*x*y
2678
+ acc = _mm256_fmadd_ps( d0d1, xy, acc );
2679
+
2680
+ // Compute sum of y values
2681
+ const __m256i y16_l = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( by ) );
2682
+ const __m256i y16_h = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( by, 1 ) );
2683
+ const __m256i ysumi = _mm256_madd_epi16( _mm256_add_epi16(y16_l, y16_h), ones );
2684
+ const __m256 ysum = _mm256_cvtepi32_ps( ysumi );
2685
+
2686
+ // Accumulate d1*m0*y
2687
+ acc = _mm256_fmadd_ps( d1m0, ysum, acc );
2298
2688
  }
2299
2689
 
2300
2690
  // Return horizontal sum of the acc vector
@@ -2303,131 +2693,379 @@ static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * rest
2303
2693
  res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
2304
2694
  res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
2305
2695
 
2306
- sumf = _mm_cvtss_f32( res ) + acc_offset * QK;
2307
- #elif defined(__ARM_NEON)
2308
- float sum00 = 0.0f;
2309
- float sum01 = 0.0f;
2310
- float sum10 = 0.0f;
2311
- float sum11 = 0.0f;
2696
+ sumf = _mm_cvtss_f32( res );
2697
+ #else
2698
+ // scalar
2699
+ for (int i = 0; i < nb; i++) {
2700
+ const float d0 = x[i].d;
2701
+ const float m0 = x[i].m;
2702
+ const float d1 = y[i].d;
2703
+
2704
+ const uint8_t * restrict p0 = x[i].qs;
2705
+ const int8_t * restrict p1 = y[i].qs;
2706
+
2707
+ // TODO: this is very slow ...
2708
+ for (int j = 0; j < QK8_0/2; j++) {
2709
+ const uint8_t v0 = p0[j];
2710
+
2711
+ const float f0 = d0*(v0 & 0xf) + m0;
2712
+ const float f1 = d0*(v0 >> 4) + m0;
2713
+
2714
+ const float f2 = d1*p1[2*j + 0];
2715
+ const float f3 = d1*p1[2*j + 1];
2716
+
2717
+ sumf += f0*f2 + f1*f3;
2718
+ }
2719
+ }
2720
+ #endif
2721
+
2722
+ *s = sumf;
2723
+ }
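The SIMD paths above keep two accumulators (one scaled by d0*d1, the other by d1*m0) because the q4_1 dequantization x = d0*q0 + m0, dotted against q8_0 values d1*q1, factors as sum_j (d0*q0_j + m0)*(d1*q1_j) = d0*d1*sum_j q0_j*q1_j + d1*m0*sum_j q1_j. A small self-contained check of that identity, using arbitrary example values rather than real block data:

// Verifies the factorization used by the q4_1 x q8_0 kernels above.
// The numbers are arbitrary; this only illustrates the algebra.
#include <stdio.h>

int main(void) {
    const float d0 = 0.05f, m0 = -0.4f, d1 = 0.02f;
    const int   q0[4] = { 1, 7, 15, 3 };     // unsigned 4-bit quants
    const int   q1[4] = { -3, 9, 127, -8 };  // signed 8-bit quants

    float direct = 0.0f;
    int   sxy = 0, sy = 0;
    for (int j = 0; j < 4; j++) {
        direct += (d0*q0[j] + m0) * (d1*q1[j]);  // dequantize, then multiply
        sxy    += q0[j]*q1[j];                   // integer dot product
        sy     += q1[j];                         // sum of the q8 quants
    }
    const float factored = d0*d1*sxy + d1*m0*sy; // the two SIMD accumulators
    printf("direct = %f, factored = %f\n", direct, factored);
    return 0;
}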
2724
+
2725
+ static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
2726
+ const int nb = n / QK8_0;
2727
+
2728
+ assert(n % QK8_0 == 0);
2729
+ assert(nb % 2 == 0);
2730
+ assert(QK8_0 == 2*QK4_2);
2731
+
2732
+ const block_q4_2 * restrict x = vx;
2733
+ const block_q8_0 * restrict y = vy;
2734
+
2735
+ float sumf = 0.0;
2736
+
2737
+ #if defined(__ARM_NEON)
2738
+ float32x4_t sumv0 = vdupq_n_f32(0.0f);
2739
+ float32x4_t sumv1 = vdupq_n_f32(0.0f);
2312
2740
 
2313
2741
  for (int i = 0; i < nb; i += 2) {
2314
- const block_q4_1 * restrict x0 = &x[i + 0];
2315
- const block_q4_1 * restrict y0 = &y[i + 0];
2316
- const block_q4_1 * restrict x1 = &x[i + 1];
2317
- const block_q4_1 * restrict y1 = &y[i + 1];
2742
+ const block_q4_2 * restrict x0_0 = &x[2*(i + 0) + 0];
2743
+ const block_q4_2 * restrict x0_1 = &x[2*(i + 0) + 1];
2744
+ const block_q4_2 * restrict x1_0 = &x[2*(i + 1) + 0];
2745
+ const block_q4_2 * restrict x1_1 = &x[2*(i + 1) + 1];
2318
2746
 
2319
- const uint8x16_t m4b = vdupq_n_u8(0xf);
2747
+ const block_q8_0 * restrict y0 = &y[i + 0];
2748
+ const block_q8_0 * restrict y1 = &y[i + 1];
2320
2749
 
2321
- const uint8x16_t v0_0 = vld1q_u8(x0->qs);
2322
- const uint8x16_t v1_0 = vld1q_u8(y0->qs);
2323
- const uint8x16_t v0_1 = vld1q_u8(x1->qs);
2324
- const uint8x16_t v1_1 = vld1q_u8(y1->qs);
2750
+ const uint8x16_t m4b = vdupq_n_u8(0xf);
2751
+ const int8x16_t s8b = vdupq_n_s8(0x8);
2752
+
2753
+ const uint8x16_t v0_0 = vcombine_u8(vld1_u8(x0_0->qs), vld1_u8(x0_1->qs));
2754
+ const uint8x16_t v0_1 = vcombine_u8(vld1_u8(x1_0->qs), vld1_u8(x1_1->qs));
2325
2755
 
2326
2756
  // 4-bit -> 8-bit
2327
- const uint8x16_t v0_0l = vandq_u8(v0_0, m4b);
2328
- const uint8x16_t v1_0l = vandq_u8(v1_0, m4b);
2329
- const uint8x16_t v0_0h = vshrq_n_u8(v0_0, 4);
2330
- const uint8x16_t v1_0h = vshrq_n_u8(v1_0, 4);
2757
+ const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
2758
+ const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
2759
+ const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
2760
+ const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
2331
2761
 
2332
- const uint8x16_t v0_1l = vandq_u8(v0_1, m4b);
2333
- const uint8x16_t v1_1l = vandq_u8(v1_1, m4b);
2334
- const uint8x16_t v0_1h = vshrq_n_u8(v0_1, 4);
2335
- const uint8x16_t v1_1h = vshrq_n_u8(v1_1, 4);
2762
+ // sub 8
2763
+ const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
2764
+ const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
2765
+ const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
2766
+ const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
2336
2767
 
2337
- sum00 += x0->m*y0->m;
2338
- sum01 += y0->m*x0->d*(vaddvq_u8(v0_0l) + vaddvq_u8(v0_0h));
2339
- sum10 += x0->m*y0->d*(vaddvq_u8(v1_0l) + vaddvq_u8(v1_0h));
2768
+ // interleave
2769
+ const int8x16_t v0_0lz = vzip1q_s8(v0_0ls, v0_0hs);
2770
+ const int8x16_t v0_0hz = vzip2q_s8(v0_0ls, v0_0hs);
2771
+ const int8x16_t v0_1lz = vzip1q_s8(v0_1ls, v0_1hs);
2772
+ const int8x16_t v0_1hz = vzip2q_s8(v0_1ls, v0_1hs);
2340
2773
 
2341
- sum00 += x1->m*y1->m;
2342
- sum01 += y1->m*x1->d*(vaddvq_u8(v0_1l) + vaddvq_u8(v0_1h));
2343
- sum10 += x1->m*y1->d*(vaddvq_u8(v1_1l) + vaddvq_u8(v1_1h));
2774
+ // load y
2775
+ const int8x16_t v1_0l = vld1q_s8(y0->qs);
2776
+ const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
2777
+ const int8x16_t v1_1l = vld1q_s8(y1->qs);
2778
+ const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
2344
2779
 
2345
2780
  #if defined(__ARM_FEATURE_DOTPROD)
2346
- // dot product into int32x4_t
2347
- uint32x4_t p_0 = vdotq_u32(vdupq_n_u32(0), v0_0l, v1_0l);
2348
- uint32x4_t p_1 = vdotq_u32(vdupq_n_u32(0), v0_1l, v1_1l);
2781
+ sumv0 = vmlaq_n_f32(sumv0, vaddq_f32(
2782
+ vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0lz, v1_0l)), GGML_FP16_TO_FP32(x0_0->d)),
2783
+ vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0hz, v1_0h)), GGML_FP16_TO_FP32(x0_1->d))), y0->d);
2349
2784
 
2350
- p_0 = vdotq_u32(p_0, v0_0h, v1_0h);
2351
- p_1 = vdotq_u32(p_1, v0_1h, v1_1h);
2352
-
2353
- sum11 += x0->d*y0->d*vaddvq_u32(p_0);
2354
- sum11 += x1->d*y1->d*vaddvq_u32(p_1);
2785
+ sumv1 = vmlaq_n_f32(sumv1, vaddq_f32(
2786
+ vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_1lz, v1_1l)), GGML_FP16_TO_FP32(x1_0->d)),
2787
+ vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_1hz, v1_1h)), GGML_FP16_TO_FP32(x1_1->d))), y1->d);
2355
2788
  #else
2356
- const uint16x8_t pl0l = vmull_u8(vget_low_u8 (v0_0l), vget_low_u8 (v1_0l));
2357
- const uint16x8_t pl0h = vmull_u8(vget_high_u8(v0_0l), vget_high_u8(v1_0l));
2358
- const uint16x8_t ph0l = vmull_u8(vget_low_u8 (v0_0h), vget_low_u8 (v1_0h));
2359
- const uint16x8_t ph0h = vmull_u8(vget_high_u8(v0_0h), vget_high_u8(v1_0h));
2789
+ const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lz), vget_low_s8 (v1_0l));
2790
+ const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lz), vget_high_s8(v1_0l));
2791
+ const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hz), vget_low_s8 (v1_0h));
2792
+ const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hz), vget_high_s8(v1_0h));
2793
+
2794
+ const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lz), vget_low_s8 (v1_1l));
2795
+ const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lz), vget_high_s8(v1_1l));
2796
+ const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hz), vget_low_s8 (v1_1h));
2797
+ const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hz), vget_high_s8(v1_1h));
2798
+
2799
+ const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
2800
+ const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
2801
+ const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
2802
+ const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
2803
+
2804
+ sumv0 = vmlaq_n_f32(sumv0, vaddq_f32(
2805
+ vmulq_n_f32(vcvtq_f32_s32(pl0), GGML_FP16_TO_FP32(x0_0->d)),
2806
+ vmulq_n_f32(vcvtq_f32_s32(ph0), GGML_FP16_TO_FP32(x0_1->d))), y0->d);
2807
+
2808
+ sumv1 = vmlaq_n_f32(sumv1, vaddq_f32(
2809
+ vmulq_n_f32(vcvtq_f32_s32(pl1), GGML_FP16_TO_FP32(x1_0->d)),
2810
+ vmulq_n_f32(vcvtq_f32_s32(ph1), GGML_FP16_TO_FP32(x1_1->d))), y1->d);
2811
+ #endif
2812
+ }
2813
+
2814
+ sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
2815
+ #elif defined(__AVX2__)
2816
+ // Initialize accumulator with zeros
2817
+ __m256 acc = _mm256_setzero_ps();
2818
+
2819
+ // Main loop
2820
+ for (int i = 0; i < nb; i++) {
2821
+ /* Compute combined scale for the block */
2822
+ const __m128 d0 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 0].d));
2823
+ const __m128 d1 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 1].d));
2824
+ const __m256 d = _mm256_mul_ps(_mm256_set_m128(d1, d0), _mm256_broadcast_ss(&y[i].d));
2360
2825
 
2361
- const uint16x8_t pl1l = vmull_u8(vget_low_u8 (v0_1l), vget_low_u8 (v1_1l));
2362
- const uint16x8_t pl1h = vmull_u8(vget_high_u8(v0_1l), vget_high_u8(v1_1l));
2363
- const uint16x8_t ph1l = vmull_u8(vget_low_u8 (v0_1h), vget_low_u8 (v1_1h));
2364
- const uint16x8_t ph1h = vmull_u8(vget_high_u8(v0_1h), vget_high_u8(v1_1h));
2826
+ __m128i bx0 = bytes_from_nibbles_16(x[2*i + 0].qs);
2827
+ __m128i bx1 = bytes_from_nibbles_16(x[2*i + 1].qs);
2828
+ __m256i bx = _mm256_set_m128i(bx1, bx0);
2365
2829
 
2366
- const uint16x8_t pl_0 = vaddq_u16(pl0l, pl0h);
2367
- const uint16x8_t ph_0 = vaddq_u16(ph0l, ph0h);
2830
+ // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
2831
+ const __m256i off = _mm256_set1_epi8(8);
2832
+ bx = _mm256_sub_epi8(bx, off);
2368
2833
 
2369
- const uint16x8_t pl_1 = vaddq_u16(pl1l, pl1h);
2370
- const uint16x8_t ph_1 = vaddq_u16(ph1l, ph1h);
2834
+ __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
2371
2835
 
2372
- const uint16x8_t p_0 = vaddq_u16(pl_0, ph_0);
2373
- const uint16x8_t p_1 = vaddq_u16(pl_1, ph_1);
2836
+ // Get absolute values of x vectors
2837
+ const __m256i ax = _mm256_sign_epi8(bx, bx);
2838
+ // Sign the values of the y vectors
2839
+ const __m256i sy = _mm256_sign_epi8(by, bx);
2840
+ // Perform multiplication and create 16-bit values
2841
+ const __m256i dot = _mm256_maddubs_epi16(ax, sy);
2374
2842
 
2375
- sum11 += x0->d*y0->d*vaddvq_u16(p_0);
2376
- sum11 += x1->d*y1->d*vaddvq_u16(p_1);
2377
- #endif
2843
+ const __m256i ones = _mm256_set1_epi16(1);
2844
+ __m256i xy_q = _mm256_madd_epi16(ones, dot);
2845
+
2846
+ /* Convert to vectore of 8 int32_t to 8 floats */
2847
+ __m256 q = _mm256_cvtepi32_ps(xy_q);
2848
+
2849
+ /* Multiply q with scale and accumulate */
2850
+ acc = _mm256_fmadd_ps(d, q, acc);
2378
2851
  }
2379
2852
 
2380
- sumf = QK*sum00 + sum01 + sum10 + sum11;
2853
+ // Return horizontal sum of the acc vector
2854
+ __m128 res = _mm256_extractf128_ps(acc, 1);
2855
+ res = _mm_add_ps(res, _mm256_castps256_ps128(acc));
2856
+ res = _mm_add_ps(res, _mm_movehl_ps(res, res));
2857
+ res = _mm_add_ss(res, _mm_movehdup_ps(res));
2858
+
2859
+ sumf = _mm_cvtss_f32(res);
2381
2860
  #else
2382
2861
  // scalar
2383
2862
  for (int i = 0; i < nb; i++) {
2384
- const float d0 = x[i].d;
2385
- const float d1 = y[i].d;
2863
+ const uint8_t * restrict x0 = x[2*i + 0].qs;
2864
+ const uint8_t * restrict x1 = x[2*i + 1].qs;
2865
+ const int8_t * restrict y0 = y[i].qs;
2386
2866
 
2387
- const float m0 = x[i].m;
2388
- const float m1 = y[i].m;
2867
+ const float d0 = GGML_FP16_TO_FP32(x[2*i + 0].d);
2868
+ const float d1 = GGML_FP16_TO_FP32(x[2*i + 1].d);
2389
2869
 
2390
- const uint8_t * restrict p0 = x[i].qs;
2391
- const uint8_t * restrict p1 = y[i].qs;
2870
+ int sumi_0 = 0;
2871
+ int sumi_1 = 0;
2392
2872
 
2393
- for (int j = 0; j < QK/2; j++) {
2394
- const uint8_t v0 = p0[j];
2395
- const uint8_t v1 = p1[j];
2873
+ for (int j = 0; j < QK8_0/4; j++) {
2874
+ const uint8_t v0 = x0[j];
2875
+ const uint8_t v1 = x1[j];
2396
2876
 
2397
- const float f0 = d0*(v0 & 0xf) + m0;
2398
- const float f1 = d0*(v0 >> 4) + m0;
2877
+ const int i0_0 = (int8_t) (v0 & 0xf) - 8;
2878
+ const int i1_0 = (int8_t) (v0 >> 4) - 8;
2399
2879
 
2400
- const float f2 = d1*(v1 & 0xf) + m1;
2401
- const float f3 = d1*(v1 >> 4) + m1;
2880
+ const int i0_1 = (int8_t) (v1 & 0xf) - 8;
2881
+ const int i1_1 = (int8_t) (v1 >> 4) - 8;
2402
2882
 
2403
- sumf += f0*f2 + f1*f3;
2883
+ const int i2_0 = y0[2*j + 0];
2884
+ const int i3_0 = y0[2*j + 1];
2885
+
2886
+ const int i2_1 = y0[2*(j + QK8_0/4) + 0];
2887
+ const int i3_1 = y0[2*(j + QK8_0/4) + 1];
2888
+
2889
+ sumi_0 += i0_0*i2_0 + i1_0*i3_0;
2890
+ sumi_1 += i0_1*i2_1 + i1_1*i3_1;
2404
2891
  }
2892
+
2893
+ sumf += (d0 * y[i].d) * sumi_0;
2894
+ sumf += (d1 * y[i].d) * sumi_1;
2405
2895
  }
2406
2896
  #endif
2407
2897
 
2408
2898
  *s = sumf;
2409
2899
  }
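Because QK8_0 == 2*QK4_2 (asserted above), each q8_0 block is dotted against two consecutive q4_2 sub-blocks, each with its own scale: the first covers y quants 0..15, the second 16..31. Below is a scalar sketch of that pairing; the block layouts are assumptions taken from this diff, and the scales are kept as plain floats here, whereas the real q4_2 blocks store fp16 scales (hence GGML_FP16_TO_FP32 above).

// Sketch of the q4_2 x q8_0 pairing: one 32-value q8_0 block against two
// 16-value q4_2 sub-blocks. Struct layouts are assumed for illustration.
#include <stdint.h>
#include <stdio.h>

#define QK8_0 32
#define QK4_2 16

typedef struct { float d; uint8_t qs[QK4_2/2]; } blk_q4_2;  // assumed layout
typedef struct { float d; int8_t  qs[QK8_0];   } blk_q8_0;  // assumed layout

static float dot_q4_2_pair_q8_0(const blk_q4_2 x[2], const blk_q8_0 *y) {
    float sum = 0.0f;
    for (int h = 0; h < 2; h++) {                 // h selects the q4_2 sub-block
        int sumi = 0;
        for (int j = 0; j < QK4_2/2; j++) {
            const uint8_t v  = x[h].qs[j];
            const int     i0 = (int8_t)(v & 0xf) - 8;
            const int     i1 = (int8_t)(v >> 4)  - 8;
            sumi += i0*y->qs[QK4_2*h + 2*j + 0] + i1*y->qs[QK4_2*h + 2*j + 1];
        }
        sum += x[h].d * y->d * (float)sumi;       // each sub-block has its own scale
    }
    return sum;
}

int main(void) {
    blk_q4_2 x[2] = { { 0.10f, {0} }, { 0.25f, {0} } };
    blk_q8_0 y    = { 0.02f, {0} };
    for (int j = 0; j < QK4_2/2; j++) { x[0].qs[j] = 0x18; x[1].qs[j] = 0x7f; }
    for (int j = 0; j < QK8_0;   j++) { y.qs[j] = (int8_t)(j - 16); }
    printf("dot = %f\n", dot_q4_2_pair_q8_0(x, &y));
    return 0;
}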
2410
2900
 
2411
- // compute GGML_VEC_DOT_UNROLL dot products at once
2412
- // xs - x row stride in bytes
2413
- inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * restrict s, void * restrict xv, ggml_fp16_t * restrict y) {
2414
- ggml_float sumf[GGML_VEC_DOT_UNROLL] = { 0.0 };
2901
+ static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
2902
+ const int nb = n / QK8_0;
2415
2903
 
2416
- ggml_fp16_t * restrict x[GGML_VEC_DOT_UNROLL];
2904
+ assert(n % QK8_0 == 0);
2905
+ assert(nb % 2 == 0);
2906
+ assert(QK8_0 == 2*QK4_2);
2417
2907
 
2418
- for (int i = 0; i < GGML_VEC_DOT_UNROLL; ++i) {
2419
- x[i] = (ggml_fp16_t *) ((char *) xv + i*xs);
2420
- }
2908
+ const block_q4_3 * restrict x = vx;
2909
+ const block_q8_0 * restrict y = vy;
2421
2910
 
2422
- #if defined(GGML_SIMD)
2423
- const int np = (n & ~(GGML_F16_STEP - 1));
2911
+ float sumf = 0.0;
2424
2912
 
2425
- GGML_F16_VEC sum[GGML_VEC_DOT_UNROLL][GGML_F16_ARR] = { { GGML_F16_VEC_ZERO } };
2913
+ #if defined(__ARM_NEON)
2914
+ float32x4_t sumv0 = vdupq_n_f32(0.0f);
2915
+ float32x4_t sumv1 = vdupq_n_f32(0.0f);
2426
2916
 
2427
- GGML_F16_VEC ax[GGML_F16_ARR];
2428
- GGML_F16_VEC ay[GGML_F16_ARR];
2917
+ for (int i = 0; i < nb; i += 2) {
2918
+ const block_q4_3 * restrict x0_0 = &x[2*(i + 0) + 0];
2919
+ const block_q4_3 * restrict x0_1 = &x[2*(i + 0) + 1];
2920
+ const block_q4_3 * restrict x1_0 = &x[2*(i + 1) + 0];
2921
+ const block_q4_3 * restrict x1_1 = &x[2*(i + 1) + 1];
2429
2922
 
2430
- for (int i = 0; i < np; i += GGML_F16_STEP) {
2923
+ const block_q8_0 * restrict y0 = &y[i + 0];
2924
+ const block_q8_0 * restrict y1 = &y[i + 1];
2925
+
2926
+ const uint8x16_t m4b = vdupq_n_u8(0xf);
2927
+
2928
+ const float x0_0d = GGML_FP16_TO_FP32(x0_0->d);
2929
+ const float x0_1d = GGML_FP16_TO_FP32(x0_1->d);
2930
+ const float x1_0d = GGML_FP16_TO_FP32(x1_0->d);
2931
+ const float x1_1d = GGML_FP16_TO_FP32(x1_1->d);
2932
+
2933
+ const float x0_0m = GGML_FP16_TO_FP32(x0_0->m);
2934
+ const float x0_1m = GGML_FP16_TO_FP32(x0_1->m);
2935
+ const float x1_0m = GGML_FP16_TO_FP32(x1_0->m);
2936
+ const float x1_1m = GGML_FP16_TO_FP32(x1_1->m);
2937
+
2938
+ const uint8x16_t v0_0 = vcombine_u8(vld1_u8(x0_0->qs), vld1_u8(x0_1->qs));
2939
+ const uint8x16_t v0_1 = vcombine_u8(vld1_u8(x1_0->qs), vld1_u8(x1_1->qs));
2940
+
2941
+ // 4-bit -> 8-bit
2942
+ const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
2943
+ const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
2944
+ const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
2945
+ const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
2946
+
2947
+ // interleave
2948
+ const int8x16_t v0_0lz = vzip1q_s8(v0_0l, v0_0h);
2949
+ const int8x16_t v0_0hz = vzip2q_s8(v0_0l, v0_0h);
2950
+ const int8x16_t v0_1lz = vzip1q_s8(v0_1l, v0_1h);
2951
+ const int8x16_t v0_1hz = vzip2q_s8(v0_1l, v0_1h);
2952
+
2953
+ // load y
2954
+ const int8x16_t v1_0l = vld1q_s8(y0->qs);
2955
+ const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
2956
+ const int8x16_t v1_1l = vld1q_s8(y1->qs);
2957
+ const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
2958
+
2959
+ const int16x8_t sy0_0 = vaddq_s16(vmovl_s8(vget_low_s8(v1_0l)), vmovl_s8(vget_high_s8(v1_0l)));
2960
+ const int16x8_t sy0_1 = vaddq_s16(vmovl_s8(vget_low_s8(v1_0h)), vmovl_s8(vget_high_s8(v1_0h)));
2961
+
2962
+ const int16x8_t sy1_0 = vaddq_s16(vmovl_s8(vget_low_s8(v1_1l)), vmovl_s8(vget_high_s8(v1_1l)));
2963
+ const int16x8_t sy1_1 = vaddq_s16(vmovl_s8(vget_low_s8(v1_1h)), vmovl_s8(vget_high_s8(v1_1h)));
2964
+
2965
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddl_s16(vget_low_s16(sy0_0), vget_high_s16(sy0_0))), x0_0m*y0->d);
2966
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddl_s16(vget_low_s16(sy0_1), vget_high_s16(sy0_1))), x0_1m*y0->d);
2967
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddl_s16(vget_low_s16(sy1_0), vget_high_s16(sy1_0))), x1_0m*y1->d);
2968
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddl_s16(vget_low_s16(sy1_1), vget_high_s16(sy1_1))), x1_1m*y1->d);
2969
+
2970
+ #if defined(__ARM_FEATURE_DOTPROD)
2971
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0lz, v1_0l)), x0_0d*y0->d);
2972
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0hz, v1_0h)), x0_1d*y0->d);
2973
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_1lz, v1_1l)), x1_0d*y1->d);
2974
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_1hz, v1_1h)), x1_1d*y1->d);
2975
+ #else
2976
+ const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lz), vget_low_s8 (v1_0l));
2977
+ const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lz), vget_high_s8(v1_0l));
2978
+ const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hz), vget_low_s8 (v1_0h));
2979
+ const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hz), vget_high_s8(v1_0h));
2980
+
2981
+ const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lz), vget_low_s8 (v1_1l));
2982
+ const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lz), vget_high_s8(v1_1l));
2983
+ const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hz), vget_low_s8 (v1_1h));
2984
+ const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hz), vget_high_s8(v1_1h));
2985
+
2986
+ const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
2987
+ const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
2988
+ const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
2989
+ const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
2990
+
2991
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(pl0), x0_0d*y0->d);
2992
+ sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(ph0), x0_1d*y0->d);
2993
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(pl1), x1_0d*y1->d);
2994
+ sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(ph1), x1_1d*y1->d);
2995
+ #endif
2996
+ }
2997
+
2998
+ sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
2999
+ #else
3000
+ // scalar
3001
+ for (int i = 0; i < nb; i++) {
3002
+ const uint8_t * restrict x0 = x[2*i + 0].qs;
3003
+ const uint8_t * restrict x1 = x[2*i + 1].qs;
3004
+ const int8_t * restrict y0 = y[i].qs;
3005
+
3006
+ const float d0 = GGML_FP16_TO_FP32(x[2*i + 0].d);
3007
+ const float m0 = GGML_FP16_TO_FP32(x[2*i + 0].m);
3008
+ const float d1 = GGML_FP16_TO_FP32(x[2*i + 1].d);
3009
+ const float m1 = GGML_FP16_TO_FP32(x[2*i + 1].m);
3010
+
3011
+ int sy_0 = 0;
3012
+ int sy_1 = 0;
3013
+
3014
+ int sxy_0 = 0;
3015
+ int sxy_1 = 0;
3016
+
3017
+ for (int j = 0; j < QK8_0/4; j++) {
3018
+ const uint8_t v0 = x0[j];
3019
+ const uint8_t v1 = x1[j];
3020
+
3021
+ const int x0_0 = v0 & 0xf;
3022
+ const int x1_0 = v0 >> 4;
3023
+
3024
+ const int x0_1 = v1 & 0xf;
3025
+ const int x1_1 = v1 >> 4;
3026
+
3027
+ const int y0_0 = y0[2*j + 0];
3028
+ const int y1_0 = y0[2*j + 1];
3029
+
3030
+ const int y0_1 = y0[2*(j + QK8_0/4) + 0];
3031
+ const int y1_1 = y0[2*(j + QK8_0/4) + 1];
3032
+
3033
+ sy_0 += y0_0 + y1_0;
3034
+ sy_1 += y0_1 + y1_1;
3035
+
3036
+ sxy_0 += x0_0*y0_0 + x1_0*y1_0;
3037
+ sxy_1 += x0_1*y0_1 + x1_1*y1_1;
3038
+ }
3039
+
3040
+ sumf += (d0*sxy_0 + m0*sy_0)*y[i].d;
3041
+ sumf += (d1*sxy_1 + m1*sy_1)*y[i].d;
3042
+ }
3043
+ #endif
3044
+
3045
+ *s = sumf;
3046
+ }
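All of the NEON kernels above carry the same pair of code paths: with __ARM_FEATURE_DOTPROD, vdotq_s32 folds four int8 products per lane straight into int32 lanes; without it, vmull_s8 widens the products to int16 and vpaddlq_s16 pairwise-adds them into int32. The lane layouts differ, but the per-vector total is identical, and the total is all these kernels consume. A plain-C model of the two reductions (illustrative only, no intrinsics):

// Models the two NEON reductions used above on one 16-byte vector pair:
// a dot-product style (4 lanes of 4 products) and a widen+pairwise-add style
// (8 lanes of 2 products). Both collapse to the same scalar total.
#include <stdint.h>
#include <stdio.h>

int main(void) {
    int8_t a[16], b[16];
    for (int i = 0; i < 16; i++) { a[i] = (int8_t)(i - 8); b[i] = (int8_t)(3*i - 20); }

    // vdotq_s32-style: four int32 lanes, each summing four adjacent products
    int32_t lanes[4] = { 0, 0, 0, 0 };
    for (int k = 0; k < 4; k++)
        for (int t = 0; t < 4; t++)
            lanes[k] += a[4*k + t] * b[4*k + t];

    // fallback-style: widen to int16 products, then pairwise-add into int32
    int32_t pairs[8];
    for (int k = 0; k < 8; k++)
        pairs[k] = (int32_t)a[2*k]*b[2*k] + (int32_t)a[2*k + 1]*b[2*k + 1];

    int32_t total_dot = 0, total_pairs = 0;
    for (int k = 0; k < 4; k++) total_dot   += lanes[k];
    for (int k = 0; k < 8; k++) total_pairs += pairs[k];
    printf("dotprod total = %d, fallback total = %d\n", total_dot, total_pairs);
    return 0;
}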
3047
+
3048
+
3049
+ // compute GGML_VEC_DOT_UNROLL dot products at once
3050
+ // xs - x row stride in bytes
3051
+ inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * restrict s, void * restrict xv, ggml_fp16_t * restrict y) {
3052
+ ggml_float sumf[GGML_VEC_DOT_UNROLL] = { 0.0 };
3053
+
3054
+ ggml_fp16_t * restrict x[GGML_VEC_DOT_UNROLL];
3055
+
3056
+ for (int i = 0; i < GGML_VEC_DOT_UNROLL; ++i) {
3057
+ x[i] = (ggml_fp16_t *) ((char *) xv + i*xs);
3058
+ }
3059
+
3060
+ #if defined(GGML_SIMD)
3061
+ const int np = (n & ~(GGML_F16_STEP - 1));
3062
+
3063
+ GGML_F16_VEC sum[GGML_VEC_DOT_UNROLL][GGML_F16_ARR] = { { GGML_F16_VEC_ZERO } };
3064
+
3065
+ GGML_F16_VEC ax[GGML_F16_ARR];
3066
+ GGML_F16_VEC ay[GGML_F16_ARR];
3067
+
3068
+ for (int i = 0; i < np; i += GGML_F16_STEP) {
2431
3069
  for (int j = 0; j < GGML_F16_ARR; j++) {
2432
3070
  ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
2433
3071
 
@@ -2652,24 +3290,30 @@ inline static void ggml_vec_norm_inv_f32(const int n, float * s, const float * x
2652
3290
  static const int GGML_BLCK_SIZE[GGML_TYPE_COUNT] = {
2653
3291
  [GGML_TYPE_F32] = 1,
2654
3292
  [GGML_TYPE_F16] = 1,
2655
- [GGML_TYPE_Q4_0] = QK,
2656
- [GGML_TYPE_Q4_1] = QK,
3293
+ [GGML_TYPE_Q4_0] = QK4_0,
3294
+ [GGML_TYPE_Q4_1] = QK4_1,
3295
+ [GGML_TYPE_Q4_2] = QK4_2,
3296
+ [GGML_TYPE_Q4_3] = QK4_3,
3297
+ [GGML_TYPE_Q8_0] = QK8_0,
2657
3298
  [GGML_TYPE_I8] = 1,
2658
3299
  [GGML_TYPE_I16] = 1,
2659
3300
  [GGML_TYPE_I32] = 1,
2660
3301
  };
2661
- static_assert(GGML_TYPE_COUNT == 7, "GGML_BLCK_SIZE is outdated");
3302
+ static_assert(GGML_TYPE_COUNT == 10, "GGML_BLCK_SIZE is outdated");
2662
3303
 
2663
3304
  static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = {
2664
3305
  [GGML_TYPE_F32] = sizeof(float),
2665
3306
  [GGML_TYPE_F16] = sizeof(ggml_fp16_t),
2666
3307
  [GGML_TYPE_Q4_0] = sizeof(block_q4_0),
2667
3308
  [GGML_TYPE_Q4_1] = sizeof(block_q4_1),
3309
+ [GGML_TYPE_Q4_2] = sizeof(block_q4_2),
3310
+ [GGML_TYPE_Q4_3] = sizeof(block_q4_3),
3311
+ [GGML_TYPE_Q8_0] = sizeof(block_q8_0),
2668
3312
  [GGML_TYPE_I8] = sizeof(int8_t),
2669
3313
  [GGML_TYPE_I16] = sizeof(int16_t),
2670
3314
  [GGML_TYPE_I32] = sizeof(int32_t),
2671
3315
  };
2672
- static_assert(GGML_TYPE_COUNT == 7, "GGML_TYPE_SIZE is outdated");
3316
+ static_assert(GGML_TYPE_COUNT == 10, "GGML_TYPE_SIZE is outdated");
2673
3317
 
2674
3318
 
2675
3319
  static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = {
@@ -2677,11 +3321,28 @@ static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = {
2677
3321
  [GGML_TYPE_F16] = "f16",
2678
3322
  [GGML_TYPE_Q4_0] = "q4_0",
2679
3323
  [GGML_TYPE_Q4_1] = "q4_1",
3324
+ [GGML_TYPE_Q4_2] = "q4_2",
3325
+ [GGML_TYPE_Q4_3] = "q4_3",
3326
+ [GGML_TYPE_Q8_0] = "q8_0",
2680
3327
  [GGML_TYPE_I8] = "i8",
2681
3328
  [GGML_TYPE_I16] = "i16",
2682
3329
  [GGML_TYPE_I32] = "i32",
2683
3330
  };
2684
- static_assert(GGML_TYPE_COUNT == 7, "GGML_TYPE_NAME is outdated");
3331
+ static_assert(GGML_TYPE_COUNT == 10, "GGML_TYPE_NAME is outdated");
3332
+
3333
+ static bool GGML_IS_QUANTIZED[GGML_TYPE_COUNT] = {
3334
+ [GGML_TYPE_F32] = false,
3335
+ [GGML_TYPE_F16] = false,
3336
+ [GGML_TYPE_Q4_0] = true,
3337
+ [GGML_TYPE_Q4_1] = true,
3338
+ [GGML_TYPE_Q4_2] = true,
3339
+ [GGML_TYPE_Q4_3] = true,
3340
+ [GGML_TYPE_Q8_0] = true,
3341
+ [GGML_TYPE_I8] = false,
3342
+ [GGML_TYPE_I16] = false,
3343
+ [GGML_TYPE_I32] = false,
3344
+ };
3345
+ static_assert(GGML_TYPE_COUNT == 10, "GGML_IS_QUANTIZED is outdated");
2685
3346
 
2686
3347
  static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
2687
3348
  "NONE",
@@ -2943,6 +3604,10 @@ static inline bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct
2943
3604
  (t0->ne[3] == t1->ne[3]);
2944
3605
  }
2945
3606
 
3607
+ bool ggml_is_quantized(enum ggml_type type) {
3608
+ return GGML_IS_QUANTIZED[type];
3609
+ }
3610
+
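A note on how the two tables above combine in practice: a row of ne elements occupies ne / GGML_BLCK_SIZE[type] blocks of GGML_TYPE_SIZE[type] bytes each, which is the same arithmetic the parallelized dup paths below use for their quantized row stride. A hedged sketch with local names (row_size_bytes is not a ggml function, and the 20-byte block size is only an illustrative value):

// Sketch of the block-size / type-size arithmetic; names and numbers here are
// local to this example, not part of the ggml API.
#include <stddef.h>
#include <stdio.h>

static size_t row_size_bytes(size_t ne, size_t blck_size, size_t type_size) {
    return (ne / blck_size) * type_size;  // assumes ne is a multiple of blck_size
}

int main(void) {
    // e.g. a 4096-wide row of 32-element blocks, each block 20 bytes
    printf("%zu bytes per row\n", row_size_bytes(4096, 32, 20));
    return 0;
}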
2946
3611
  static inline bool ggml_is_transposed(const struct ggml_tensor * tensor) {
2947
3612
  return tensor->nb[0] > tensor->nb[1];
2948
3613
  }
@@ -3053,6 +3718,11 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
3053
3718
  GGML_PRINT_DEBUG("%s: g_state initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f);
3054
3719
  }
3055
3720
 
3721
+ // initialize cuBLAS
3722
+ #if defined(GGML_USE_CUBLAS)
3723
+ init_cublas();
3724
+ #endif
3725
+
3056
3726
  is_first_call = false;
3057
3727
  }
3058
3728
 
@@ -3354,14 +4024,6 @@ struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value) {
3354
4024
  char * const data = tensor->data;
3355
4025
 
3356
4026
  switch (tensor->type) {
3357
- case GGML_TYPE_Q4_0:
3358
- {
3359
- GGML_ASSERT(false);
3360
- } break;
3361
- case GGML_TYPE_Q4_1:
3362
- {
3363
- GGML_ASSERT(false);
3364
- } break;
3365
4027
  case GGML_TYPE_I8:
3366
4028
  {
3367
4029
  assert(tensor->nb[0] == sizeof(int8_t));
@@ -3397,7 +4059,7 @@ struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value) {
3397
4059
  ggml_vec_set_f32(nc, (float *)(data + i*n1), value);
3398
4060
  }
3399
4061
  } break;
3400
- case GGML_TYPE_COUNT:
4062
+ default:
3401
4063
  {
3402
4064
  GGML_ASSERT(false);
3403
4065
  } break;
@@ -3414,14 +4076,6 @@ struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value) {
3414
4076
  char * const data = tensor->data;
3415
4077
 
3416
4078
  switch (tensor->type) {
3417
- case GGML_TYPE_Q4_0:
3418
- {
3419
- GGML_ASSERT(false);
3420
- } break;
3421
- case GGML_TYPE_Q4_1:
3422
- {
3423
- GGML_ASSERT(false);
3424
- } break;
3425
4079
  case GGML_TYPE_I8:
3426
4080
  {
3427
4081
  assert(tensor->nb[0] == sizeof(int8_t));
@@ -3457,7 +4111,7 @@ struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value) {
3457
4111
  ggml_vec_set_f32(nc, (float *)(data + i*n1), value);
3458
4112
  }
3459
4113
  } break;
3460
- case GGML_TYPE_COUNT:
4114
+ default:
3461
4115
  {
3462
4116
  GGML_ASSERT(false);
3463
4117
  } break;
@@ -3468,14 +4122,6 @@ struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value) {
3468
4122
 
3469
4123
  int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i) {
3470
4124
  switch (tensor->type) {
3471
- case GGML_TYPE_Q4_0:
3472
- {
3473
- GGML_ASSERT(false);
3474
- } break;
3475
- case GGML_TYPE_Q4_1:
3476
- {
3477
- GGML_ASSERT(false);
3478
- } break;
3479
4125
  case GGML_TYPE_I8:
3480
4126
  {
3481
4127
  GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
@@ -3501,7 +4147,7 @@ int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i) {
3501
4147
  GGML_ASSERT(tensor->nb[0] == sizeof(float));
3502
4148
  return ((float *)(tensor->data))[i];
3503
4149
  } break;
3504
- case GGML_TYPE_COUNT:
4150
+ default:
3505
4151
  {
3506
4152
  GGML_ASSERT(false);
3507
4153
  } break;
@@ -3512,14 +4158,6 @@ int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i) {
3512
4158
 
3513
4159
  void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value) {
3514
4160
  switch (tensor->type) {
3515
- case GGML_TYPE_Q4_0:
3516
- {
3517
- GGML_ASSERT(false);
3518
- } break;
3519
- case GGML_TYPE_Q4_1:
3520
- {
3521
- GGML_ASSERT(false);
3522
- } break;
3523
4161
  case GGML_TYPE_I8:
3524
4162
  {
3525
4163
  GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
@@ -3545,7 +4183,7 @@ void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value) {
3545
4183
  GGML_ASSERT(tensor->nb[0] == sizeof(float));
3546
4184
  ((float *)(tensor->data))[i] = value;
3547
4185
  } break;
3548
- case GGML_TYPE_COUNT:
4186
+ default:
3549
4187
  {
3550
4188
  GGML_ASSERT(false);
3551
4189
  } break;
@@ -3554,14 +4192,6 @@ void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value) {
3554
4192
 
3555
4193
  float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i) {
3556
4194
  switch (tensor->type) {
3557
- case GGML_TYPE_Q4_0:
3558
- {
3559
- GGML_ASSERT(false);
3560
- } break;
3561
- case GGML_TYPE_Q4_1:
3562
- {
3563
- GGML_ASSERT(false);
3564
- } break;
3565
4195
  case GGML_TYPE_I8:
3566
4196
  {
3567
4197
  GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
@@ -3587,7 +4217,7 @@ float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i) {
3587
4217
  GGML_ASSERT(tensor->nb[0] == sizeof(float));
3588
4218
  return ((float *)(tensor->data))[i];
3589
4219
  } break;
3590
- case GGML_TYPE_COUNT:
4220
+ default:
3591
4221
  {
3592
4222
  GGML_ASSERT(false);
3593
4223
  } break;
@@ -3598,14 +4228,6 @@ float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i) {
3598
4228
 
3599
4229
  void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value) {
3600
4230
  switch (tensor->type) {
3601
- case GGML_TYPE_Q4_0:
3602
- {
3603
- GGML_ASSERT(false);
3604
- } break;
3605
- case GGML_TYPE_Q4_1:
3606
- {
3607
- GGML_ASSERT(false);
3608
- } break;
3609
4231
  case GGML_TYPE_I8:
3610
4232
  {
3611
4233
  GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
@@ -3631,7 +4253,7 @@ void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value) {
3631
4253
  GGML_ASSERT(tensor->nb[0] == sizeof(float));
3632
4254
  ((float *)(tensor->data))[i] = value;
3633
4255
  } break;
3634
- case GGML_TYPE_COUNT:
4256
+ default:
3635
4257
  {
3636
4258
  GGML_ASSERT(false);
3637
4259
  } break;
@@ -5031,7 +5653,6 @@ static void ggml_compute_forward_dup_f16(
5031
5653
  const struct ggml_compute_params * params,
5032
5654
  const struct ggml_tensor * src0,
5033
5655
  struct ggml_tensor * dst) {
5034
- GGML_ASSERT(params->ith == 0);
5035
5656
  GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
5036
5657
 
5037
5658
  if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
@@ -5043,6 +5664,11 @@ static void ggml_compute_forward_dup_f16(
5043
5664
  const int64_t ne02 = src0->ne[2];
5044
5665
  const int64_t ne03 = src0->ne[3];
5045
5666
 
5667
+ const int64_t ne0 = dst->ne[0];
5668
+ const int64_t ne1 = dst->ne[1];
5669
+ const int64_t ne2 = dst->ne[2];
5670
+ const int64_t ne3 = dst->ne[3];
5671
+
5046
5672
  const size_t nb00 = src0->nb[0];
5047
5673
  const size_t nb01 = src0->nb[1];
5048
5674
  const size_t nb02 = src0->nb[2];
@@ -5053,19 +5679,40 @@ static void ggml_compute_forward_dup_f16(
5053
5679
  const size_t nb2 = dst->nb[2];
5054
5680
  const size_t nb3 = dst->nb[3];
5055
5681
 
5682
+ const int ith = params->ith; // thread index
5683
+ const int nth = params->nth; // number of threads
5684
+
5056
5685
  if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst) && src0->type == dst->type) {
5057
- memcpy(dst->data, src0->data, ggml_nelements(dst) * GGML_TYPE_SIZE[src0->type]);
5686
+ // parallelize by elements
5687
+ const int ne = ggml_nelements(dst);
5688
+ const int dr = (ne + nth - 1) / nth;
5689
+ const int ie0 = dr * ith;
5690
+ const int ie1 = MIN(ie0 + dr, ne);
5691
+
5692
+ memcpy(
5693
+ ((char *) dst->data + ie0*nb0),
5694
+ ((char *) src0->data + ie0*nb00),
5695
+ (ie1 - ie0) * GGML_TYPE_SIZE[src0->type]);
5696
+
5058
5697
  return;
5059
5698
  }
5060
5699
 
5700
+ // parallelize by rows
5701
+ const int nr = ne01;
5702
+ // number of rows per thread
5703
+ const int dr = (nr + nth - 1) / nth;
5704
+ // row range for this thread
5705
+ const int ir0 = dr * ith;
5706
+ const int ir1 = MIN(ir0 + dr, nr);
5707
+
5061
5708
  if (src0->type == dst->type &&
5062
- src0->ne[0] == dst->ne[0] &&
5063
- src0->nb[0] == GGML_TYPE_SIZE[src0->type] && dst->nb[0] == GGML_TYPE_SIZE[dst->type]) {
5709
+ ne00 == ne0 &&
5710
+ nb00 == GGML_TYPE_SIZE[src0->type] && nb0 == GGML_TYPE_SIZE[dst->type]) {
5064
5711
  // copy by rows
5065
5712
  const size_t rs = ne00*nb00;
5066
5713
  for (int64_t i03 = 0; i03 < ne03; i03++) {
5067
5714
  for (int64_t i02 = 0; i02 < ne02; i02++) {
5068
- for (int64_t i01 = 0; i01 < ne01; i01++) {
5715
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
5069
5716
  memcpy(
5070
5717
  ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3),
5071
5718
  ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
@@ -5079,21 +5726,21 @@ static void ggml_compute_forward_dup_f16(
5079
5726
  // TODO: add more special-case implementations for tensor shapes/strides that can benefit from memcpy
5080
5727
 
5081
5728
  if (ggml_is_contiguous(dst)) {
5082
- if (src0->nb[0] == sizeof(ggml_fp16_t)) {
5729
+ if (nb00 == sizeof(ggml_fp16_t)) {
5083
5730
  if (dst->type == GGML_TYPE_F16) {
5084
5731
  size_t id = 0;
5085
- const size_t rs = ne00*nb00;
5732
+ const size_t rs = ne00 * nb00;
5733
+ char * dst_ptr = (char *) dst->data;
5086
5734
 
5087
5735
  for (int i03 = 0; i03 < ne03; i03++) {
5088
5736
  for (int i02 = 0; i02 < ne02; i02++) {
5089
- for (int i01 = 0; i01 < ne01; i01++) {
5737
+ id += rs * ir0;
5738
+ for (int i01 = ir0; i01 < ir1; i01++) {
5090
5739
  const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
5091
- char * dst_ptr = (char *) dst->data + id*rs;
5092
-
5093
- memcpy(dst_ptr, src0_ptr, rs);
5094
-
5095
- id++;
5740
+ memcpy(dst_ptr + id, src0_ptr, rs);
5741
+ id += rs;
5096
5742
  }
5743
+ id += rs * (ne01 - ir1);
5097
5744
  }
5098
5745
  }
5099
5746
  } else if (dst->type == GGML_TYPE_F32) {
@@ -5102,14 +5749,39 @@ static void ggml_compute_forward_dup_f16(
5102
5749
 
5103
5750
  for (int i03 = 0; i03 < ne03; i03++) {
5104
5751
  for (int i02 = 0; i02 < ne02; i02++) {
5105
- for (int i01 = 0; i01 < ne01; i01++) {
5752
+ id += ne00 * ir0;
5753
+ for (int i01 = ir0; i01 < ir1; i01++) {
5754
+ const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
5106
5755
  for (int i00 = 0; i00 < ne00; i00++) {
5107
- const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
5108
-
5109
- dst_ptr[id] = GGML_FP16_TO_FP32(*src0_ptr);
5756
+ dst_ptr[id] = GGML_FP16_TO_FP32(src0_ptr[i00]);
5110
5757
  id++;
5111
5758
  }
5112
5759
  }
5760
+ id += ne00 * (ne01 - ir1);
5761
+ }
5762
+ }
5763
+ } else if (ggml_is_quantized(dst->type)) {
5764
+ quantize_row_q_t const quantize_row_q = quantize_fns[dst->type].quantize_row_q;
5765
+ float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
5766
+
5767
+ size_t id = 0;
5768
+ size_t rs = nb0 * (ne00 / GGML_BLCK_SIZE[dst->type]);
5769
+ char * dst_ptr = (char *) dst->data;
5770
+
5771
+ for (int i03 = 0; i03 < ne03; i03++) {
5772
+ for (int i02 = 0; i02 < ne02; i02++) {
5773
+ id += rs * ir0;
5774
+ for (int i01 = ir0; i01 < ir1; i01++) {
5775
+ const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
5776
+
5777
+ for (int i00 = 0; i00 < ne00; i00++) {
5778
+ src0_f32[i00] = GGML_FP16_TO_FP32(src0_ptr[i00]);
5779
+ }
5780
+
5781
+ quantize_row_q(src0_f32, dst_ptr + id, ne00);
5782
+ id += rs;
5783
+ }
5784
+ id += rs * (ne01 - ir1);
5113
5785
  }
5114
5786
  }
5115
5787
  } else {
@@ -5124,7 +5796,8 @@ static void ggml_compute_forward_dup_f16(
5124
5796
 
5125
5797
  for (int i03 = 0; i03 < ne03; i03++) {
5126
5798
  for (int i02 = 0; i02 < ne02; i02++) {
5127
- for (int i01 = 0; i01 < ne01; i01++) {
5799
+ id += ne00 * ir0;
5800
+ for (int i01 = ir0; i01 < ir1; i01++) {
5128
5801
  for (int i00 = 0; i00 < ne00; i00++) {
5129
5802
  const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
5130
5803
 
@@ -5132,6 +5805,7 @@ static void ggml_compute_forward_dup_f16(
5132
5805
  id++;
5133
5806
  }
5134
5807
  }
5808
+ id += ne00 * (ne01 - ir1);
5135
5809
  }
5136
5810
  }
5137
5811
  } else if (dst->type == GGML_TYPE_F16) {
@@ -5140,7 +5814,8 @@ static void ggml_compute_forward_dup_f16(
5140
5814
 
5141
5815
  for (int i03 = 0; i03 < ne03; i03++) {
5142
5816
  for (int i02 = 0; i02 < ne02; i02++) {
5143
- for (int i01 = 0; i01 < ne01; i01++) {
5817
+ id += ne00 * ir0;
5818
+ for (int i01 = ir0; i01 < ir1; i01++) {
5144
5819
  for (int i00 = 0; i00 < ne00; i00++) {
5145
5820
  const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
5146
5821
 
@@ -5148,6 +5823,7 @@ static void ggml_compute_forward_dup_f16(
5148
5823
  id++;
5149
5824
  }
5150
5825
  }
5826
+ id += ne00 * (ne01 - ir1);
5151
5827
  }
5152
5828
  }
5153
5829
  } else {
@@ -5166,7 +5842,20 @@ static void ggml_compute_forward_dup_f16(
5166
5842
  if (dst->type == GGML_TYPE_F16) {
5167
5843
  for (int64_t i03 = 0; i03 < ne03; i03++) {
5168
5844
  for (int64_t i02 = 0; i02 < ne02; i02++) {
5169
- for (int64_t i01 = 0; i01 < ne01; i01++) {
5845
+ i10 += ne00 * ir0;
5846
+ while (i10 >= ne0) {
5847
+ i10 -= ne0;
5848
+ if (++i11 == ne1) {
5849
+ i11 = 0;
5850
+ if (++i12 == ne2) {
5851
+ i12 = 0;
5852
+ if (++i13 == ne3) {
5853
+ i13 = 0;
5854
+ }
5855
+ }
5856
+ }
5857
+ }
5858
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
5170
5859
  for (int64_t i00 = 0; i00 < ne00; i00++) {
5171
5860
  const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
5172
5861
  char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
@@ -5187,25 +5876,51 @@ static void ggml_compute_forward_dup_f16(
5187
5876
  }
5188
5877
  }
5189
5878
  }
5879
+ i10 += ne00 * (ne01 - ir1);
5880
+ while (i10 >= ne0) {
5881
+ i10 -= ne0;
5882
+ if (++i11 == ne1) {
5883
+ i11 = 0;
5884
+ if (++i12 == ne2) {
5885
+ i12 = 0;
5886
+ if (++i13 == ne3) {
5887
+ i13 = 0;
5888
+ }
5889
+ }
5890
+ }
5891
+ }
5190
5892
  }
5191
5893
  }
5192
5894
  } else if (dst->type == GGML_TYPE_F32) {
5193
5895
  for (int64_t i03 = 0; i03 < ne03; i03++) {
5194
5896
  for (int64_t i02 = 0; i02 < ne02; i02++) {
5195
- for (int64_t i01 = 0; i01 < ne01; i01++) {
5897
+ i10 += ne00 * ir0;
5898
+ while (i10 >= ne0) {
5899
+ i10 -= ne0;
5900
+ if (++i11 == ne1) {
5901
+ i11 = 0;
5902
+ if (++i12 == ne2) {
5903
+ i12 = 0;
5904
+ if (++i13 == ne3) {
5905
+ i13 = 0;
5906
+ }
5907
+ }
5908
+ }
5909
+ }
5910
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
5196
5911
  for (int64_t i00 = 0; i00 < ne00; i00++) {
5197
5912
  const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
5198
5913
  char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
5199
5914
 
5200
5915
  *(float *) dst_ptr = GGML_FP16_TO_FP32(*(const ggml_fp16_t *) src0_ptr);
5201
5916
 
5202
- if (++i10 == ne00) {
5917
+ if (++i10 == ne0) {
5203
5918
  i10 = 0;
5204
- if (++i11 == ne01) {
5919
+ if (++i11 == ne1) {
5205
5920
  i11 = 0;
5206
- if (++i12 == ne02) {
5921
+ if (++i12 == ne2) {
5207
5922
  i12 = 0;
5208
- if (++i13 == ne03) {
5923
+ if (++i13 == ne3) {
5209
5924
  i13 = 0;
5210
5925
  }
5211
5926
  }
@@ -5213,6 +5928,19 @@ static void ggml_compute_forward_dup_f16(
5213
5928
  }
5214
5929
  }
5215
5930
  }
5931
+ i10 += ne00 * (ne01 - ir1);
5932
+ while (i10 >= ne0) {
5933
+ i10 -= ne0;
5934
+ if (++i11 == ne1) {
5935
+ i11 = 0;
5936
+ if (++i12 == ne2) {
5937
+ i12 = 0;
5938
+ if (++i13 == ne3) {
5939
+ i13 = 0;
5940
+ }
5941
+ }
5942
+ }
5943
+ }
5216
5944
  }
5217
5945
  }
5218
5946
  } else {
@@ -5224,7 +5952,6 @@ static void ggml_compute_forward_dup_f32(
5224
5952
  const struct ggml_compute_params * params,
5225
5953
  const struct ggml_tensor * src0,
5226
5954
  struct ggml_tensor * dst) {
5227
- GGML_ASSERT(params->ith == 0);
5228
5955
  GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
5229
5956
 
5230
5957
  if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
@@ -5236,6 +5963,11 @@ static void ggml_compute_forward_dup_f32(
5236
5963
  const int64_t ne02 = src0->ne[2];
5237
5964
  const int64_t ne03 = src0->ne[3];
5238
5965
 
5966
+ const int64_t ne0 = dst->ne[0];
5967
+ const int64_t ne1 = dst->ne[1];
5968
+ const int64_t ne2 = dst->ne[2];
5969
+ const int64_t ne3 = dst->ne[3];
5970
+
5239
5971
  const size_t nb00 = src0->nb[0];
5240
5972
  const size_t nb01 = src0->nb[1];
5241
5973
  const size_t nb02 = src0->nb[2];
@@ -5246,19 +5978,40 @@ static void ggml_compute_forward_dup_f32(
5246
5978
  const size_t nb2 = dst->nb[2];
5247
5979
  const size_t nb3 = dst->nb[3];
5248
5980
 
5981
+ const int ith = params->ith; // thread index
5982
+ const int nth = params->nth; // number of threads
5983
+
5249
5984
  if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst) && src0->type == dst->type) {
5250
- memcpy(dst->data, src0->data, ggml_nelements(dst) * GGML_TYPE_SIZE[src0->type]);
5985
+ // parallelize by elements
5986
+ const int ne = ggml_nelements(dst);
5987
+ const int dr = (ne + nth - 1) / nth;
5988
+ const int ie0 = dr * ith;
5989
+ const int ie1 = MIN(ie0 + dr, ne);
5990
+
5991
+ memcpy(
5992
+ ((char *) dst->data + ie0*nb0),
5993
+ ((char *) src0->data + ie0*nb00),
5994
+ (ie1 - ie0) * GGML_TYPE_SIZE[src0->type]);
5995
+
5251
5996
  return;
5252
5997
  }
5253
5998
 
5999
+ // parallelize by rows
6000
+ const int nr = ne01;
6001
+ // number of rows per thread
6002
+ const int dr = (nr + nth - 1) / nth;
6003
+ // row range for this thread
6004
+ const int ir0 = dr * ith;
6005
+ const int ir1 = MIN(ir0 + dr, nr);
6006
+
5254
6007
  if (src0->type == dst->type &&
5255
- src0->ne[0] == dst->ne[0] &&
5256
- src0->nb[0] == GGML_TYPE_SIZE[src0->type] && dst->nb[0] == GGML_TYPE_SIZE[dst->type]) {
6008
+ ne00 == ne0 &&
6009
+ nb00 == GGML_TYPE_SIZE[src0->type] && nb0 == GGML_TYPE_SIZE[dst->type]) {
5257
6010
  // copy by rows
5258
6011
  const size_t rs = ne00*nb00;
5259
6012
  for (int64_t i03 = 0; i03 < ne03; i03++) {
5260
6013
  for (int64_t i02 = 0; i02 < ne02; i02++) {
5261
- for (int64_t i01 = 0; i01 < ne01; i01++) {
6014
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
5262
6015
  memcpy(
5263
6016
  ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3),
5264
6017
  ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
@@ -5271,21 +6024,21 @@ static void ggml_compute_forward_dup_f32(
5271
6024
 
5272
6025
  if (ggml_is_contiguous(dst)) {
5273
6026
  // TODO: simplify
5274
- if (src0->nb[0] == sizeof(float)) {
6027
+ if (nb00 == sizeof(float)) {
5275
6028
  if (dst->type == GGML_TYPE_F32) {
5276
6029
  size_t id = 0;
5277
- const size_t rs = ne00*nb00;
6030
+ const size_t rs = ne00 * nb00;
6031
+ char * dst_ptr = (char *) dst->data;
5278
6032
 
5279
6033
  for (int i03 = 0; i03 < ne03; i03++) {
5280
6034
  for (int i02 = 0; i02 < ne02; i02++) {
5281
- for (int i01 = 0; i01 < ne01; i01++) {
6035
+ id += rs * ir0;
6036
+ for (int i01 = ir0; i01 < ir1; i01++) {
5282
6037
  const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
5283
- char * dst_ptr = (char *) dst->data + id*rs;
5284
-
5285
- memcpy(dst_ptr, src0_ptr, rs);
5286
-
5287
- id++;
6038
+ memcpy(dst_ptr + id, src0_ptr, rs);
6039
+ id += rs;
5288
6040
  }
6041
+ id += rs * (ne01 - ir1);
5289
6042
  }
5290
6043
  }
5291
6044
  } else if (dst->type == GGML_TYPE_F16) {
@@ -5294,7 +6047,8 @@ static void ggml_compute_forward_dup_f32(
5294
6047
 
5295
6048
  for (int i03 = 0; i03 < ne03; i03++) {
5296
6049
  for (int i02 = 0; i02 < ne02; i02++) {
5297
- for (int i01 = 0; i01 < ne01; i01++) {
6050
+ id += ne00 * ir0;
6051
+ for (int i01 = ir0; i01 < ir1; i01++) {
5298
6052
  for (int i00 = 0; i00 < ne00; i00++) {
5299
6053
  const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
5300
6054
 
@@ -5302,6 +6056,25 @@ static void ggml_compute_forward_dup_f32(
5302
6056
  id++;
5303
6057
  }
5304
6058
  }
6059
+ id += ne00 * (ne01 - ir1);
6060
+ }
6061
+ }
6062
+ } else if (ggml_is_quantized(dst->type)) {
6063
+ quantize_row_q_t const quantize_row_q = quantize_fns[dst->type].quantize_row_q;
6064
+
6065
+ size_t id = 0;
6066
+ size_t rs = nb0 * (ne00 / GGML_BLCK_SIZE[dst->type]);
6067
+ char * dst_ptr = (char *) dst->data;
6068
+
6069
+ for (int i03 = 0; i03 < ne03; i03++) {
6070
+ for (int i02 = 0; i02 < ne02; i02++) {
6071
+ id += rs * ir0;
6072
+ for (int i01 = ir0; i01 < ir1; i01++) {
6073
+ const float * src0_ptr = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
6074
+ quantize_row_q(src0_ptr, dst_ptr + id, ne00);
6075
+ id += rs;
6076
+ }
6077
+ id += rs * (ne01 - ir1);
5305
6078
  }
5306
6079
  }
5307
6080
  } else {
@@ -5316,7 +6089,8 @@ static void ggml_compute_forward_dup_f32(
5316
6089
 
5317
6090
  for (int i03 = 0; i03 < ne03; i03++) {
5318
6091
  for (int i02 = 0; i02 < ne02; i02++) {
5319
- for (int i01 = 0; i01 < ne01; i01++) {
6092
+ id += ne00 * ir0;
6093
+ for (int i01 = ir0; i01 < ir1; i01++) {
5320
6094
  for (int i00 = 0; i00 < ne00; i00++) {
5321
6095
  const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
5322
6096
 
@@ -5324,6 +6098,7 @@ static void ggml_compute_forward_dup_f32(
5324
6098
  id++;
5325
6099
  }
5326
6100
  }
6101
+ id += ne00 * (ne01 - ir1);
5327
6102
  }
5328
6103
  }
5329
6104
  } else if (dst->type == GGML_TYPE_F16) {
@@ -5332,7 +6107,8 @@ static void ggml_compute_forward_dup_f32(
5332
6107
 
5333
6108
  for (int i03 = 0; i03 < ne03; i03++) {
5334
6109
  for (int i02 = 0; i02 < ne02; i02++) {
5335
- for (int i01 = 0; i01 < ne01; i01++) {
6110
+ id += ne00 * ir0;
6111
+ for (int i01 = ir0; i01 < ir1; i01++) {
5336
6112
  for (int i00 = 0; i00 < ne00; i00++) {
5337
6113
  const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
5338
6114
 
@@ -5340,6 +6116,7 @@ static void ggml_compute_forward_dup_f32(
5340
6116
  id++;
5341
6117
  }
5342
6118
  }
6119
+ id += ne00 * (ne01 - ir1);
5343
6120
  }
5344
6121
  }
5345
6122
  } else {
@@ -5351,6 +6128,7 @@ static void ggml_compute_forward_dup_f32(
5351
6128
  }
5352
6129
 
5353
6130
  // dst counters
6131
+
5354
6132
  int64_t i10 = 0;
5355
6133
  int64_t i11 = 0;
5356
6134
  int64_t i12 = 0;
@@ -5359,20 +6137,33 @@ static void ggml_compute_forward_dup_f32(
5359
6137
  if (dst->type == GGML_TYPE_F32) {
5360
6138
  for (int64_t i03 = 0; i03 < ne03; i03++) {
5361
6139
  for (int64_t i02 = 0; i02 < ne02; i02++) {
5362
- for (int64_t i01 = 0; i01 < ne01; i01++) {
6140
+ i10 += ne00 * ir0;
6141
+ while (i10 >= ne0) {
6142
+ i10 -= ne0;
6143
+ if (++i11 == ne1) {
6144
+ i11 = 0;
6145
+ if (++i12 == ne2) {
6146
+ i12 = 0;
6147
+ if (++i13 == ne3) {
6148
+ i13 = 0;
6149
+ }
6150
+ }
6151
+ }
6152
+ }
6153
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
5363
6154
  for (int64_t i00 = 0; i00 < ne00; i00++) {
5364
6155
  const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
5365
6156
  char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
5366
6157
 
5367
6158
  memcpy(dst_ptr, src0_ptr, sizeof(float));
5368
6159
 
5369
- if (++i10 == dst->ne[0]) {
6160
+ if (++i10 == ne0) {
5370
6161
  i10 = 0;
5371
- if (++i11 == dst->ne[1]) {
6162
+ if (++i11 == ne1) {
5372
6163
  i11 = 0;
5373
- if (++i12 == dst->ne[2]) {
6164
+ if (++i12 == ne2) {
5374
6165
  i12 = 0;
5375
- if (++i13 == dst->ne[3]) {
6166
+ if (++i13 == ne3) {
5376
6167
  i13 = 0;
5377
6168
  }
5378
6169
  }
@@ -5380,25 +6171,51 @@ static void ggml_compute_forward_dup_f32(
5380
6171
  }
5381
6172
  }
5382
6173
  }
6174
+ i10 += ne00 * (ne01 - ir1);
6175
+ while (i10 >= ne0) {
6176
+ i10 -= ne0;
6177
+ if (++i11 == ne1) {
6178
+ i11 = 0;
6179
+ if (++i12 == ne2) {
6180
+ i12 = 0;
6181
+ if (++i13 == ne3) {
6182
+ i13 = 0;
6183
+ }
6184
+ }
6185
+ }
6186
+ }
5383
6187
  }
5384
6188
  }
5385
6189
  } else if (dst->type == GGML_TYPE_F16) {
5386
6190
  for (int64_t i03 = 0; i03 < ne03; i03++) {
5387
6191
  for (int64_t i02 = 0; i02 < ne02; i02++) {
5388
- for (int64_t i01 = 0; i01 < ne01; i01++) {
6192
+ i10 += ne00 * ir0;
6193
+ while (i10 >= ne0) {
6194
+ i10 -= ne0;
6195
+ if (++i11 == ne1) {
6196
+ i11 = 0;
6197
+ if (++i12 == ne2) {
6198
+ i12 = 0;
6199
+ if (++i13 == ne3) {
6200
+ i13 = 0;
6201
+ }
6202
+ }
6203
+ }
6204
+ }
6205
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
5389
6206
  for (int64_t i00 = 0; i00 < ne00; i00++) {
5390
6207
  const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
5391
6208
  char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
5392
6209
 
5393
6210
  *(ggml_fp16_t *) dst_ptr = GGML_FP32_TO_FP16(*(const float *) src0_ptr);
5394
6211
 
5395
- if (++i10 == dst->ne[0]) {
6212
+ if (++i10 == ne0) {
5396
6213
  i10 = 0;
5397
- if (++i11 == dst->ne[1]) {
6214
+ if (++i11 == ne1) {
5398
6215
  i11 = 0;
5399
- if (++i12 == dst->ne[2]) {
6216
+ if (++i12 == ne2) {
5400
6217
  i12 = 0;
5401
- if (++i13 == dst->ne[3]) {
6218
+ if (++i13 == ne3) {
5402
6219
  i13 = 0;
5403
6220
  }
5404
6221
  }
@@ -5406,6 +6223,19 @@ static void ggml_compute_forward_dup_f32(
5406
6223
  }
5407
6224
  }
5408
6225
  }
6226
+ i10 += ne00 * (ne01 - ir1);
6227
+ while (i10 >= ne0) {
6228
+ i10 -= ne0;
6229
+ if (++i11 == ne1) {
6230
+ i11 = 0;
6231
+ if (++i12 == ne2) {
6232
+ i12 = 0;
6233
+ if (++i13 == ne3) {
6234
+ i13 = 0;
6235
+ }
6236
+ }
6237
+ }
6238
+ }
5409
6239
  }
5410
6240
  }
5411
6241
  } else {
@@ -5426,12 +6256,7 @@ static void ggml_compute_forward_dup(
5426
6256
  {
5427
6257
  ggml_compute_forward_dup_f32(params, src0, dst);
5428
6258
  } break;
5429
- case GGML_TYPE_Q4_0:
5430
- case GGML_TYPE_Q4_1:
5431
- case GGML_TYPE_I8:
5432
- case GGML_TYPE_I16:
5433
- case GGML_TYPE_I32:
5434
- case GGML_TYPE_COUNT:
6259
+ default:
5435
6260
  {
5436
6261
  GGML_ASSERT(false);
5437
6262
  } break;
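The parallelized dup paths above all split rows across threads with the same ceiling-division pattern: dr = (nr + nth - 1)/nth rows per thread, with each thread's upper bound clamped by MIN so the chunks exactly tile [0, nr). A tiny standalone sketch of that partitioning:

// Reproduces the row partitioning used by the dup kernels above.
#include <stdio.h>

#define MIN(a, b) ((a) < (b) ? (a) : (b))

int main(void) {
    const int nr  = 10;                   // total rows
    const int nth = 4;                    // number of threads
    const int dr  = (nr + nth - 1) / nth; // rows per thread, rounded up
    for (int ith = 0; ith < nth; ith++) {
        const int ir0 = dr * ith;          // first row for this thread
        const int ir1 = MIN(ir0 + dr, nr); // one past the last row
        printf("thread %d: rows [%d, %d)\n", ith, ir0, ir1);
    }
    return 0;
}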
@@ -5497,6 +6322,212 @@ static void ggml_compute_forward_add_f32(
5497
6322
  }
5498
6323
  }
5499
6324
 
6325
+ static void ggml_compute_forward_add_f16_f32(
6326
+ const struct ggml_compute_params * params,
6327
+ const struct ggml_tensor * src0,
6328
+ const struct ggml_tensor * src1,
6329
+ struct ggml_tensor * dst) {
6330
+ GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
6331
+
6332
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
6333
+ return;
6334
+ }
6335
+
6336
+ const int ith = params->ith;
6337
+ const int nth = params->nth;
6338
+
6339
+ const int n = ggml_nrows(src0);
6340
+ const int nc = src0->ne[0];
6341
+
6342
+ const size_t nb00 = src0->nb[0];
6343
+ const size_t nb01 = src0->nb[1];
6344
+
6345
+ const size_t nb10 = src1->nb[0];
6346
+ const size_t nb11 = src1->nb[1];
6347
+
6348
+ const size_t nb0 = dst->nb[0];
6349
+ const size_t nb1 = dst->nb[1];
6350
+
6351
+ GGML_ASSERT(src0->type == GGML_TYPE_F16);
6352
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
6353
+ GGML_ASSERT(dst->type == GGML_TYPE_F16);
6354
+
6355
+ GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
6356
+ GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
6357
+
6358
+ if (nb10 == sizeof(float)) {
6359
+ for (int j = ith; j < n; j += nth) {
6360
+ ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + j*nb1);
6361
+ ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + j*nb01);
6362
+ for (int i = 0; i < nc; i++) {
6363
+ float * src1_ptr = (float *) ((char *) src1->data + j*nb11 + i*nb10);
6364
+ dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + *src1_ptr);
6365
+ }
6366
+ }
6367
+ }
6368
+ else {
6369
+ // src1 is not contiguous
6370
+ GGML_ASSERT(false);
6371
+ }
6372
+ }
6373
+
6374
+ static void ggml_compute_forward_add_f16_f16(
6375
+ const struct ggml_compute_params * params,
6376
+ const struct ggml_tensor * src0,
6377
+ const struct ggml_tensor * src1,
6378
+ struct ggml_tensor * dst) {
6379
+ GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
6380
+
6381
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
6382
+ return;
6383
+ }
6384
+
6385
+ const int ith = params->ith;
6386
+ const int nth = params->nth;
6387
+
6388
+ const int n = ggml_nrows(src0);
6389
+ const int nc = src0->ne[0];
6390
+
6391
+ const size_t nb00 = src0->nb[0];
6392
+ const size_t nb01 = src0->nb[1];
6393
+
6394
+ const size_t nb10 = src1->nb[0];
6395
+ const size_t nb11 = src1->nb[1];
6396
+
6397
+ const size_t nb0 = dst->nb[0];
6398
+ const size_t nb1 = dst->nb[1];
6399
+
6400
+ GGML_ASSERT(src0->type == GGML_TYPE_F16);
6401
+ GGML_ASSERT(src1->type == GGML_TYPE_F16);
6402
+ GGML_ASSERT(dst->type == GGML_TYPE_F16);
6403
+
6404
+ GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
6405
+ GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
6406
+
6407
+ if (nb10 == sizeof(ggml_fp16_t)) {
6408
+ for (int j = ith; j < n; j += nth) {
6409
+ ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + j*nb1);
6410
+ ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + j*nb01);
6411
+ for (int i = 0; i < nc; i++) {
6412
+ ggml_fp16_t * src1_ptr = (ggml_fp16_t *) ((char *) src1->data + j*nb11 + i*nb10);
6413
+ dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + GGML_FP16_TO_FP32(*src1_ptr));
6414
+ }
6415
+ }
6416
+ }
6417
+ else {
6418
+ // src1 is not contiguous
6419
+ GGML_ASSERT(false);
6420
+ }
6421
+ }
6422
+
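
Both F16 add kernels above do the arithmetic in f32 and round back to f16 per element. A condensed sketch of that inner loop for contiguous rows, assuming the public conversion helpers ggml_fp16_to_fp32/ggml_fp32_to_fp16 from ggml.h:

    #include "ggml.h"

    // Sketch: dst[i] (f16) = src0[i] (f16) + src1[i] (f32), accumulating in f32
    // and rounding once -- the per-element pattern of
    // ggml_compute_forward_add_f16_f32 above. Assumes contiguous rows.
    static void add_row_f16_f32(ggml_fp16_t * dst, const ggml_fp16_t * src0,
                                const float * src1, int n) {
        for (int i = 0; i < n; i++) {
            dst[i] = ggml_fp32_to_fp16(ggml_fp16_to_fp32(src0[i]) + src1[i]);
        }
    }
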
6423
+ static void ggml_compute_forward_add_q_f32(
6424
+ const struct ggml_compute_params * params,
6425
+ const struct ggml_tensor * src0,
6426
+ const struct ggml_tensor * src1,
6427
+ struct ggml_tensor * dst) {
6428
+ GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
6429
+
6430
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
6431
+ return;
6432
+ }
6433
+
6434
+ const int64_t ne00 = src0->ne[0];
6435
+ const int64_t ne01 = src0->ne[1];
6436
+ const int64_t ne02 = src0->ne[2];
6437
+ const int64_t ne03 = src0->ne[3];
6438
+
6439
+ //const int64_t ne10 = src1->ne[0];
6440
+ //const int64_t ne11 = src1->ne[1];
6441
+ const int64_t ne12 = src1->ne[2];
6442
+ const int64_t ne13 = src1->ne[3];
6443
+
6444
+ //const int64_t ne0 = dst->ne[0];
6445
+ //const int64_t ne1 = dst->ne[1];
6446
+ const int64_t ne2 = dst->ne[2];
6447
+ const int64_t ne3 = dst->ne[3];
6448
+
6449
+ const int nb00 = src0->nb[0];
6450
+ const int nb01 = src0->nb[1];
6451
+ const int nb02 = src0->nb[2];
6452
+ const int nb03 = src0->nb[3];
6453
+
6454
+ const int nb10 = src1->nb[0];
6455
+ const int nb11 = src1->nb[1];
6456
+ const int nb12 = src1->nb[2];
6457
+ const int nb13 = src1->nb[3];
6458
+
6459
+ const int nb0 = dst->nb[0];
6460
+ const int nb1 = dst->nb[1];
6461
+ const int nb2 = dst->nb[2];
6462
+ const int nb3 = dst->nb[3];
6463
+
6464
+ const int ith = params->ith;
6465
+ const int nth = params->nth;
6466
+
6467
+ GGML_ASSERT(ne02 == ne12);
6468
+ GGML_ASSERT(ne03 == ne13);
6469
+ GGML_ASSERT(ne2 == ne12);
6470
+ GGML_ASSERT(ne3 == ne13);
6471
+
6472
+ const enum ggml_type type = src0->type;
6473
+ dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q;
6474
+ quantize_row_q_t const quantize_row_q = quantize_fns[type].quantize_row_q;
6475
+
6476
+ // we don't support permuted src0 or src1
6477
+ GGML_ASSERT(nb00 == (int) GGML_TYPE_SIZE[type]);
6478
+ GGML_ASSERT(nb10 == sizeof(float));
6479
+
6480
+ // dst cannot be transposed or permuted
6481
+ GGML_ASSERT(nb0 <= nb1);
6482
+ GGML_ASSERT(nb1 <= nb2);
6483
+ GGML_ASSERT(nb2 <= nb3);
6484
+
6485
+ GGML_ASSERT(ggml_is_quantized(src0->type));
6486
+ GGML_ASSERT(dst->type == src0->type);
6487
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
6488
+
6489
+ // total rows in src0
6490
+ const int nr = ne01*ne02*ne03;
6491
+
6492
+ // rows per thread
6493
+ const int dr = (nr + nth - 1)/nth;
6494
+
6495
+ // row range for this thread
6496
+ const int ir0 = dr*ith;
6497
+ const int ir1 = MIN(ir0 + dr, nr);
6498
+
6499
+ float * wdata = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
6500
+
6501
+ for (int ir = ir0; ir < ir1; ++ir) {
6502
+ // src0 indices
6503
+ const int i03 = ir/(ne02*ne01);
6504
+ const int i02 = (ir - i03*ne02*ne01)/ne01;
6505
+ const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
6506
+
6507
+ // src1 and dst are same shape as src0 => same indices
6508
+ const int i13 = i03;
6509
+ const int i12 = i02;
6510
+ const int i11 = i01;
6511
+
6512
+ const int i3 = i03;
6513
+ const int i2 = i02;
6514
+ const int i1 = i01;
6515
+
6516
+ void * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
6517
+ float * src1_row = (float *)((char *) src1->data + (i11*nb11 + i12*nb12 + i13*nb13));
6518
+ void * dst_row = (void *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
6519
+
6520
+ assert(ne00 % 32 == 0);
6521
+
6522
+ // unquantize row from src0 to temp buffer
6523
+ dequantize_row_q(src0_row, wdata, ne00);
6524
+ // add src1
6525
+ ggml_vec_acc_f32(ne00, wdata, src1_row);
6526
+ // quantize row to dst
6527
+ quantize_row_q(wdata, dst_row, ne00);
6528
+ }
6529
+ }
6530
+
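
The quantized add above works row by row: expand the src0 row into a per-thread f32 scratch buffer, accumulate src1, then re-quantize into dst. A schematic version with the (internal) quantization callbacks passed in explicitly rather than looked up in quantize_fns:

    #include <stddef.h>

    typedef void (*dequant_fn)(const void * q, float * f, int n);
    typedef void (*quant_fn)  (const float * f, void * q, int n);

    // One row of the quantized add: q -> f32, accumulate, f32 -> q.
    // `scratch` must hold at least n floats (one slice per thread).
    static void add_row_q_f32(const void * src0_row, const float * src1_row,
                              void * dst_row, float * scratch, int n,
                              dequant_fn dequant, quant_fn quant) {
        dequant(src0_row, scratch, n);
        for (int i = 0; i < n; i++) {
            scratch[i] += src1_row[i];
        }
        quant(scratch, dst_row, n);
    }
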
5500
6531
  static void ggml_compute_forward_add(
5501
6532
  const struct ggml_compute_params * params,
5502
6533
  const struct ggml_tensor * src0,
@@ -5507,13 +6538,26 @@ static void ggml_compute_forward_add(
5507
6538
  {
5508
6539
  ggml_compute_forward_add_f32(params, src0, src1, dst);
5509
6540
  } break;
6541
+ case GGML_TYPE_F16:
6542
+ {
6543
+ if (src1->type == GGML_TYPE_F16) {
6544
+ ggml_compute_forward_add_f16_f16(params, src0, src1, dst);
6545
+ }
6546
+ else if (src1->type == GGML_TYPE_F32) {
6547
+ ggml_compute_forward_add_f16_f32(params, src0, src1, dst);
6548
+ }
6549
+ else {
6550
+ GGML_ASSERT(false);
6551
+ }
6552
+ } break;
5510
6553
  case GGML_TYPE_Q4_0:
5511
6554
  case GGML_TYPE_Q4_1:
5512
- case GGML_TYPE_I8:
5513
- case GGML_TYPE_I16:
5514
- case GGML_TYPE_I32:
5515
- case GGML_TYPE_F16:
5516
- case GGML_TYPE_COUNT:
6555
+ case GGML_TYPE_Q4_2:
6556
+ case GGML_TYPE_Q4_3:
6557
+ {
6558
+ ggml_compute_forward_add_q_f32(params, src0, src1, dst);
6559
+ } break;
6560
+ default:
5517
6561
  {
5518
6562
  GGML_ASSERT(false);
5519
6563
  } break;
@@ -5559,13 +6603,7 @@ static void ggml_compute_forward_sub(
5559
6603
  {
5560
6604
  ggml_compute_forward_sub_f32(params, src0, src1, dst);
5561
6605
  } break;
5562
- case GGML_TYPE_Q4_0:
5563
- case GGML_TYPE_Q4_1:
5564
- case GGML_TYPE_I8:
5565
- case GGML_TYPE_I16:
5566
- case GGML_TYPE_I32:
5567
- case GGML_TYPE_F16:
5568
- case GGML_TYPE_COUNT:
6606
+ default:
5569
6607
  {
5570
6608
  GGML_ASSERT(false);
5571
6609
  } break;
@@ -5611,13 +6649,7 @@ static void ggml_compute_forward_mul(
5611
6649
  {
5612
6650
  ggml_compute_forward_mul_f32(params, src0, src1, dst);
5613
6651
  } break;
5614
- case GGML_TYPE_Q4_0:
5615
- case GGML_TYPE_Q4_1:
5616
- case GGML_TYPE_I8:
5617
- case GGML_TYPE_I16:
5618
- case GGML_TYPE_I32:
5619
- case GGML_TYPE_F16:
5620
- case GGML_TYPE_COUNT:
6652
+ default:
5621
6653
  {
5622
6654
  GGML_ASSERT(false);
5623
6655
  } break;
@@ -5663,13 +6695,7 @@ static void ggml_compute_forward_div(
5663
6695
  {
5664
6696
  ggml_compute_forward_div_f32(params, src0, src1, dst);
5665
6697
  } break;
5666
- case GGML_TYPE_Q4_0:
5667
- case GGML_TYPE_Q4_1:
5668
- case GGML_TYPE_I8:
5669
- case GGML_TYPE_I16:
5670
- case GGML_TYPE_I32:
5671
- case GGML_TYPE_F16:
5672
- case GGML_TYPE_COUNT:
6698
+ default:
5673
6699
  {
5674
6700
  GGML_ASSERT(false);
5675
6701
  } break;
@@ -5711,13 +6737,7 @@ static void ggml_compute_forward_sqr(
5711
6737
  {
5712
6738
  ggml_compute_forward_sqr_f32(params, src0, dst);
5713
6739
  } break;
5714
- case GGML_TYPE_Q4_0:
5715
- case GGML_TYPE_Q4_1:
5716
- case GGML_TYPE_I8:
5717
- case GGML_TYPE_I16:
5718
- case GGML_TYPE_I32:
5719
- case GGML_TYPE_F16:
5720
- case GGML_TYPE_COUNT:
6740
+ default:
5721
6741
  {
5722
6742
  GGML_ASSERT(false);
5723
6743
  } break;
@@ -5759,13 +6779,7 @@ static void ggml_compute_forward_sqrt(
5759
6779
  {
5760
6780
  ggml_compute_forward_sqrt_f32(params, src0, dst);
5761
6781
  } break;
5762
- case GGML_TYPE_Q4_0:
5763
- case GGML_TYPE_Q4_1:
5764
- case GGML_TYPE_I8:
5765
- case GGML_TYPE_I16:
5766
- case GGML_TYPE_I32:
5767
- case GGML_TYPE_F16:
5768
- case GGML_TYPE_COUNT:
6782
+ default:
5769
6783
  {
5770
6784
  GGML_ASSERT(false);
5771
6785
  } break;
@@ -5817,13 +6831,7 @@ static void ggml_compute_forward_sum(
5817
6831
  {
5818
6832
  ggml_compute_forward_sum_f32(params, src0, dst);
5819
6833
  } break;
5820
- case GGML_TYPE_Q4_0:
5821
- case GGML_TYPE_Q4_1:
5822
- case GGML_TYPE_I8:
5823
- case GGML_TYPE_I16:
5824
- case GGML_TYPE_I32:
5825
- case GGML_TYPE_F16:
5826
- case GGML_TYPE_COUNT:
6834
+ default:
5827
6835
  {
5828
6836
  GGML_ASSERT(false);
5829
6837
  } break;
@@ -5894,13 +6902,7 @@ static void ggml_compute_forward_mean(
5894
6902
  {
5895
6903
  ggml_compute_forward_mean_f32(params, src0, dst);
5896
6904
  } break;
5897
- case GGML_TYPE_Q4_0:
5898
- case GGML_TYPE_Q4_1:
5899
- case GGML_TYPE_I8:
5900
- case GGML_TYPE_I16:
5901
- case GGML_TYPE_I32:
5902
- case GGML_TYPE_F16:
5903
- case GGML_TYPE_COUNT:
6905
+ default:
5904
6906
  {
5905
6907
  GGML_ASSERT(false);
5906
6908
  } break;
@@ -5958,13 +6960,7 @@ static void ggml_compute_forward_repeat(
5958
6960
  {
5959
6961
  ggml_compute_forward_repeat_f32(params, src0, dst);
5960
6962
  } break;
5961
- case GGML_TYPE_Q4_0:
5962
- case GGML_TYPE_Q4_1:
5963
- case GGML_TYPE_I8:
5964
- case GGML_TYPE_I16:
5965
- case GGML_TYPE_I32:
5966
- case GGML_TYPE_F16:
5967
- case GGML_TYPE_COUNT:
6963
+ default:
5968
6964
  {
5969
6965
  GGML_ASSERT(false);
5970
6966
  } break;
@@ -6006,13 +7002,7 @@ static void ggml_compute_forward_abs(
6006
7002
  {
6007
7003
  ggml_compute_forward_abs_f32(params, src0, dst);
6008
7004
  } break;
6009
- case GGML_TYPE_Q4_0:
6010
- case GGML_TYPE_Q4_1:
6011
- case GGML_TYPE_I8:
6012
- case GGML_TYPE_I16:
6013
- case GGML_TYPE_I32:
6014
- case GGML_TYPE_F16:
6015
- case GGML_TYPE_COUNT:
7005
+ default:
6016
7006
  {
6017
7007
  GGML_ASSERT(false);
6018
7008
  } break;
@@ -6054,13 +7044,7 @@ static void ggml_compute_forward_sgn(
6054
7044
  {
6055
7045
  ggml_compute_forward_sgn_f32(params, src0, dst);
6056
7046
  } break;
6057
- case GGML_TYPE_Q4_0:
6058
- case GGML_TYPE_Q4_1:
6059
- case GGML_TYPE_I8:
6060
- case GGML_TYPE_I16:
6061
- case GGML_TYPE_I32:
6062
- case GGML_TYPE_F16:
6063
- case GGML_TYPE_COUNT:
7047
+ default:
6064
7048
  {
6065
7049
  GGML_ASSERT(false);
6066
7050
  } break;
@@ -6102,13 +7086,7 @@ static void ggml_compute_forward_neg(
6102
7086
  {
6103
7087
  ggml_compute_forward_neg_f32(params, src0, dst);
6104
7088
  } break;
6105
- case GGML_TYPE_Q4_0:
6106
- case GGML_TYPE_Q4_1:
6107
- case GGML_TYPE_I8:
6108
- case GGML_TYPE_I16:
6109
- case GGML_TYPE_I32:
6110
- case GGML_TYPE_F16:
6111
- case GGML_TYPE_COUNT:
7089
+ default:
6112
7090
  {
6113
7091
  GGML_ASSERT(false);
6114
7092
  } break;
@@ -6150,13 +7128,7 @@ static void ggml_compute_forward_step(
6150
7128
  {
6151
7129
  ggml_compute_forward_step_f32(params, src0, dst);
6152
7130
  } break;
6153
- case GGML_TYPE_Q4_0:
6154
- case GGML_TYPE_Q4_1:
6155
- case GGML_TYPE_I8:
6156
- case GGML_TYPE_I16:
6157
- case GGML_TYPE_I32:
6158
- case GGML_TYPE_F16:
6159
- case GGML_TYPE_COUNT:
7131
+ default:
6160
7132
  {
6161
7133
  GGML_ASSERT(false);
6162
7134
  } break;
@@ -6193,18 +7165,12 @@ static void ggml_compute_forward_relu(
6193
7165
  const struct ggml_compute_params * params,
6194
7166
  const struct ggml_tensor * src0,
6195
7167
  struct ggml_tensor * dst) {
6196
- switch (src0->type) {
6197
- case GGML_TYPE_F32:
6198
- {
6199
- ggml_compute_forward_relu_f32(params, src0, dst);
6200
- } break;
6201
- case GGML_TYPE_Q4_0:
6202
- case GGML_TYPE_Q4_1:
6203
- case GGML_TYPE_I8:
6204
- case GGML_TYPE_I16:
6205
- case GGML_TYPE_I32:
6206
- case GGML_TYPE_F16:
6207
- case GGML_TYPE_COUNT:
7168
+ switch (src0->type) {
7169
+ case GGML_TYPE_F32:
7170
+ {
7171
+ ggml_compute_forward_relu_f32(params, src0, dst);
7172
+ } break;
7173
+ default:
6208
7174
  {
6209
7175
  GGML_ASSERT(false);
6210
7176
  } break;
@@ -6263,13 +7229,7 @@ static void ggml_compute_forward_gelu(
6263
7229
  {
6264
7230
  ggml_compute_forward_gelu_f32(params, src0, dst);
6265
7231
  } break;
6266
- case GGML_TYPE_Q4_0:
6267
- case GGML_TYPE_Q4_1:
6268
- case GGML_TYPE_I8:
6269
- case GGML_TYPE_I16:
6270
- case GGML_TYPE_I32:
6271
- case GGML_TYPE_F16:
6272
- case GGML_TYPE_COUNT:
7232
+ default:
6273
7233
  {
6274
7234
  GGML_ASSERT(false);
6275
7235
  } break;
@@ -6330,13 +7290,7 @@ static void ggml_compute_forward_silu(
6330
7290
  {
6331
7291
  ggml_compute_forward_silu_f32(params, src0, dst);
6332
7292
  } break;
6333
- case GGML_TYPE_Q4_0:
6334
- case GGML_TYPE_Q4_1:
6335
- case GGML_TYPE_I8:
6336
- case GGML_TYPE_I16:
6337
- case GGML_TYPE_I32:
6338
- case GGML_TYPE_F16:
6339
- case GGML_TYPE_COUNT:
7293
+ default:
6340
7294
  {
6341
7295
  GGML_ASSERT(false);
6342
7296
  } break;
@@ -6416,13 +7370,7 @@ static void ggml_compute_forward_norm(
6416
7370
  {
6417
7371
  ggml_compute_forward_norm_f32(params, src0, dst);
6418
7372
  } break;
6419
- case GGML_TYPE_Q4_0:
6420
- case GGML_TYPE_Q4_1:
6421
- case GGML_TYPE_I8:
6422
- case GGML_TYPE_I16:
6423
- case GGML_TYPE_I32:
6424
- case GGML_TYPE_F16:
6425
- case GGML_TYPE_COUNT:
7373
+ default:
6426
7374
  {
6427
7375
  GGML_ASSERT(false);
6428
7376
  } break;
@@ -6496,13 +7444,7 @@ static void ggml_compute_forward_rms_norm(
6496
7444
  {
6497
7445
  ggml_compute_forward_rms_norm_f32(params, src0, dst);
6498
7446
  } break;
6499
- case GGML_TYPE_Q4_0:
6500
- case GGML_TYPE_Q4_1:
6501
- case GGML_TYPE_I8:
6502
- case GGML_TYPE_I16:
6503
- case GGML_TYPE_I32:
6504
- case GGML_TYPE_F16:
6505
- case GGML_TYPE_COUNT:
7447
+ default:
6506
7448
  {
6507
7449
  GGML_ASSERT(false);
6508
7450
  } break;
@@ -6512,7 +7454,7 @@ static void ggml_compute_forward_rms_norm(
6512
7454
 
6513
7455
  // ggml_compute_forward_mul_mat
6514
7456
 
6515
- #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
7457
+ #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
6516
7458
  // helper function to determine if it is better to use BLAS or not
6517
7459
  // for large matrices, BLAS is faster
6518
7460
  static bool ggml_compute_forward_mul_mat_use_blas(
@@ -6552,7 +7494,7 @@ static void ggml_compute_forward_mul_mat_f32(
6552
7494
  const int64_t ne02 = src0->ne[2];
6553
7495
  const int64_t ne03 = src0->ne[3];
6554
7496
 
6555
- #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
7497
+ #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
6556
7498
  const int64_t ne10 = src1->ne[0];
6557
7499
  #endif
6558
7500
  const int64_t ne11 = src1->ne[1];
@@ -6609,7 +7551,7 @@ static void ggml_compute_forward_mul_mat_f32(
6609
7551
  // nb01 >= nb00 - src0 is not transposed
6610
7552
  // compute by src0 rows
6611
7553
 
6612
- #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
7554
+ #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
6613
7555
  if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
6614
7556
  if (params->ith != 0) {
6615
7557
  return;
@@ -6623,6 +7565,21 @@ static void ggml_compute_forward_mul_mat_f32(
6623
7565
  return;
6624
7566
  }
6625
7567
 
7568
+ #if defined(GGML_USE_CUBLAS)
7569
+ float *d_X = NULL;
7570
+ float *d_Y = NULL;
7571
+ float *d_D = NULL;
7572
+ const float alpha = 1.0f;
7573
+ const float beta = 0.0f;
7574
+ const int x_ne = ne01 * ne10;
7575
+ const int y_ne = ne11 * ne10;
7576
+ const int d_ne = ne11 * ne01;
7577
+
7578
+ CUDA_CHECK(cudaMalloc((void **)(&d_X), sizeof(float) * x_ne));
7579
+ CUDA_CHECK(cudaMalloc((void **)(&d_Y), sizeof(float) * y_ne));
7580
+ CUDA_CHECK(cudaMalloc((void **)(&d_D), sizeof(float) * d_ne));
7581
+ #endif
7582
+
6626
7583
  for (int64_t i03 = 0; i03 < ne03; i03++) {
6627
7584
  for (int64_t i02 = 0; i02 < ne02; i02++) {
6628
7585
  const float * x = (float *) ((char *) src0->data + i02*nb02 + i03*nb03);
@@ -6630,15 +7587,37 @@ static void ggml_compute_forward_mul_mat_f32(
6630
7587
 
6631
7588
  float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
6632
7589
 
7590
+ #if defined(GGML_USE_CUBLAS)
7591
+ // copy data to device
7592
+ CUDA_CHECK(cudaMemcpyAsync(d_X, x, sizeof(float) * x_ne, cudaMemcpyHostToDevice, cudaStream));
7593
+ CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(float) * y_ne, cudaMemcpyHostToDevice, cudaStream));
7594
+
7595
+ // compute
7596
+ CUBLAS_CHECK(
7597
+ cublasSgemm(cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
7598
+ ne01, ne11, ne10,
7599
+ &alpha, d_X, ne00,
7600
+ d_Y, ne10,
7601
+ &beta, d_D, ne01));
7602
+
7603
+ // copy data to host
7604
+ CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
7605
+ #else
6633
7606
  // zT = y * xT
6634
7607
  cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
6635
7608
  ne11, ne01, ne10,
6636
7609
  1.0f, y, ne10,
6637
7610
  x, ne00,
6638
7611
  0.0f, d, ne01);
7612
+ #endif
6639
7613
  }
6640
7614
  }
6641
-
7615
+ #if defined(GGML_USE_CUBLAS)
7616
+ CUDA_CHECK(cudaStreamSynchronize(cudaStream));
7617
+ CUDA_CHECK(cudaFree(d_X));
7618
+ CUDA_CHECK(cudaFree(d_Y));
7619
+ CUDA_CHECK(cudaFree(d_D));
7620
+ #endif
6642
7621
  //printf("CBLAS F32 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);
6643
7622
 
6644
7623
  return;
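
cuBLAS is column-major, so the row-major `zT = y * xT` of the cblas path is expressed above as D = Xᵀ·Y: with X read column-major (ld ne00), Y column-major (ld ne10) and D column-major (ld ne01), D(i,j) = Σ_k x[i][k]·y[j][k] lands at the same byte offset as the row-major d[j][i]. A plain-C check of that index equivalence with hypothetical sizes:

    #include <stdio.h>

    int main(void) {
        enum { NE00 = 3, NE01 = 2, NE10 = 3, NE11 = 4 };  // ne00 == ne10 (shared k)
        const float x[NE01][NE00] = {{1, 2, 3}, {4, 5, 6}};
        const float y[NE11][NE10] = {{1, 0, 1}, {0, 1, 0}, {2, 2, 2}, {1, 1, 0}};
        float d[NE11][NE01];

        // d[j][i] = sum_k y[j][k] * x[i][k] -- what both the cblas_sgemm call
        // (row-major, second operand transposed) and the cublasSgemm call
        // (column-major, first operand transposed) compute for this layout.
        for (int j = 0; j < NE11; j++) {
            for (int i = 0; i < NE01; i++) {
                float s = 0.0f;
                for (int k = 0; k < NE10; k++) {
                    s += y[j][k] * x[i][k];
                }
                d[j][i] = s;
            }
        }

        for (int j = 0; j < NE11; j++) {
            printf("%6.1f %6.1f\n", d[j][0], d[j][1]);
        }
        return 0;
    }
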
@@ -6768,7 +7747,7 @@ static void ggml_compute_forward_mul_mat_f16_f32(
6768
7747
  // nb01 >= nb00 - src0 is not transposed
6769
7748
  // compute by src0 rows
6770
7749
 
6771
- #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
7750
+ #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
6772
7751
  if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
6773
7752
  GGML_ASSERT(nb10 == sizeof(float));
6774
7753
 
@@ -6784,10 +7763,37 @@ static void ggml_compute_forward_mul_mat_f16_f32(
6784
7763
  return;
6785
7764
  }
6786
7765
 
6787
- float * const wdata = params->wdata;
7766
+ #if defined(GGML_USE_CUBLAS)
7767
+ ggml_fp16_t * const wdata = params->wdata;
6788
7768
 
7769
+ float *d_X = NULL;
7770
+ float *d_Y = NULL;
7771
+ float *d_D = NULL;
7772
+ const float alpha = 1.0f;
7773
+ const float beta = 0.0f;
7774
+ const int x_ne = ne01 * ne10;
7775
+ const int y_ne = ne11 * ne10;
7776
+ const int d_ne = ne11 * ne01;
7777
+
7778
+ CUDA_CHECK(cudaMalloc((void **)(&d_X), sizeof(ggml_fp16_t) * x_ne));
7779
+ CUDA_CHECK(cudaMalloc((void **)(&d_Y), sizeof(float) * y_ne));
7780
+ CUDA_CHECK(cudaMalloc((void **)(&d_D), sizeof(float) * d_ne));
7781
+ #else
7782
+ float * const wdata = params->wdata;
7783
+ #endif
6789
7784
  for (int64_t i03 = 0; i03 < ne03; i03++) {
6790
7785
  for (int64_t i02 = 0; i02 < ne02; i02++) {
7786
+ #if defined(GGML_USE_CUBLAS)
7787
+ // with cuBLAS, instead of converting src0 to fp32, we convert src1 to fp16
7788
+ {
7789
+ size_t id = 0;
7790
+ for (int64_t i01 = 0; i01 < ne11; ++i01) {
7791
+ for (int64_t i00 = 0; i00 < ne10; ++i00) {
7792
+ wdata[id++] = GGML_FP32_TO_FP16(*(float *) ((char *) src1->data + i03*nb13 + i02*nb12 + i01*nb11 + i00*nb10));
7793
+ }
7794
+ }
7795
+ }
7796
+ #else
6791
7797
  {
6792
7798
  size_t id = 0;
6793
7799
  for (int64_t i01 = 0; i01 < ne01; ++i01) {
@@ -6796,7 +7802,31 @@ static void ggml_compute_forward_mul_mat_f16_f32(
6796
7802
  }
6797
7803
  }
6798
7804
  }
7805
+ #endif
6799
7806
 
7807
+ #if defined(GGML_USE_CUBLAS)
7808
+ const ggml_fp16_t * x = (ggml_fp16_t *) ((char *) src0->data + i02*nb02 + i03*nb03);
7809
+ const ggml_fp16_t * y = (ggml_fp16_t *) wdata;
7810
+
7811
+ float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
7812
+
7813
+ // copy data to device
7814
+ CUDA_CHECK(cudaMemcpyAsync(d_X, x, sizeof(ggml_fp16_t) * x_ne, cudaMemcpyHostToDevice, cudaStream));
7815
+ CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(ggml_fp16_t) * y_ne, cudaMemcpyHostToDevice, cudaStream));
7816
+
7817
+ // compute
7818
+ CUBLAS_CHECK(
7819
+ cublasGemmEx(cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
7820
+ ne01, ne11, ne10,
7821
+ &alpha, d_X, CUDA_R_16F, ne00,
7822
+ d_Y, CUDA_R_16F, ne10,
7823
+ &beta, d_D, CUDA_R_32F, ne01,
7824
+ CUBLAS_COMPUTE_32F,
7825
+ CUBLAS_GEMM_DEFAULT));
7826
+
7827
+ // copy data to host
7828
+ CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
7829
+ #else
6800
7830
  const float * x = wdata;
6801
7831
  const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
6802
7832
 
@@ -6808,9 +7838,16 @@ static void ggml_compute_forward_mul_mat_f16_f32(
6808
7838
  1.0f, y, ne10,
6809
7839
  x, ne00,
6810
7840
  0.0f, d, ne01);
7841
+ #endif
6811
7842
  }
6812
7843
  }
6813
7844
 
7845
+ #if defined(GGML_USE_CUBLAS)
7846
+ CUDA_CHECK(cudaStreamSynchronize(cudaStream));
7847
+ CUDA_CHECK(cudaFree(d_X));
7848
+ CUDA_CHECK(cudaFree(d_Y));
7849
+ CUDA_CHECK(cudaFree(d_D));
7850
+ #endif
6814
7851
  /*printf("CBLAS F16 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);*/
6815
7852
 
6816
7853
  return;
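
For the F16 weight path, a slice of the f32 src1 is packed into a contiguous f16 work buffer on the host so that cublasGemmEx can run an F16×F16→F32 GEMM without dequantizing src0. A sketch of that packing step using only public ggml_tensor fields (a hypothetical helper, not the kernel itself):

    #include <stddef.h>
    #include <stdint.h>
    #include "ggml.h"

    // Gather one (i02,i03) slice of an f32 tensor, which may have arbitrary
    // byte strides nb[0..3], into a contiguous f16 buffer suitable for a
    // CUDA_R_16F GEMM operand. Mirrors the packing loop in the kernel above.
    static void pack_f32_slice_to_f16(const struct ggml_tensor * src1,
                                      ggml_fp16_t * wdata,
                                      int64_t i02, int64_t i03) {
        const int64_t ne10 = src1->ne[0];
        const int64_t ne11 = src1->ne[1];
        size_t id = 0;
        for (int64_t i11 = 0; i11 < ne11; ++i11) {
            for (int64_t i10 = 0; i10 < ne10; ++i10) {
                const float * v = (const float *) ((const char *) src1->data
                    + i03*src1->nb[3] + i02*src1->nb[2]
                    + i11*src1->nb[1] + i10*src1->nb[0]);
                wdata[id++] = ggml_fp32_to_fp16(*v);
            }
        }
    }
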
@@ -6894,27 +7931,6 @@ static void ggml_compute_forward_mul_mat_f16_f32(
6894
7931
  //}
6895
7932
  }
6896
7933
 
6897
- static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
6898
- [GGML_TYPE_Q4_0] = {
6899
- .dequantize_row_q = dequantize_row_q4_0,
6900
- .quantize_row_q = quantize_row_q4_0,
6901
- .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_0_reference,
6902
- .vec_dot_q = ggml_vec_dot_q4_0,
6903
- },
6904
- [GGML_TYPE_Q4_1] = {
6905
- .dequantize_row_q = dequantize_row_q4_1,
6906
- .quantize_row_q = quantize_row_q4_1,
6907
- .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_1_reference,
6908
- .vec_dot_q = ggml_vec_dot_q4_1,
6909
- },
6910
- };
6911
-
6912
- // For internal test use
6913
- quantize_fns_t ggml_internal_get_quantize_fn(size_t i) {
6914
- GGML_ASSERT(i < GGML_TYPE_COUNT);
6915
- return quantize_fns[i];
6916
- }
6917
-
6918
7934
  static void ggml_compute_forward_mul_mat_q_f32(
6919
7935
  const struct ggml_compute_params * params,
6920
7936
  const struct ggml_tensor * src0,
@@ -6962,8 +7978,8 @@ static void ggml_compute_forward_mul_mat_q_f32(
6962
7978
  GGML_ASSERT(ne3 == ne13);
6963
7979
 
6964
7980
  const enum ggml_type type = src0->type;
6965
- quantize_row_q_t const quantize_row_q = quantize_fns[type].quantize_row_q;
6966
- vec_dot_q_t const vec_dot_q = quantize_fns[type].vec_dot_q;
7981
+ quantize_row_q_t const quantize_row_q_dot = quantize_fns[type].quantize_row_q_dot;
7982
+ vec_dot_q_t const vec_dot_q = quantize_fns[type].vec_dot_q;
6967
7983
 
6968
7984
  // we don't support permuted src0 or src1
6969
7985
  GGML_ASSERT(nb00 == (int) GGML_TYPE_SIZE[type]);
@@ -6983,7 +7999,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
6983
7999
  // nb01 >= nb00 - src0 is not transposed
6984
8000
  // compute by src0 rows
6985
8001
 
6986
- #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
8002
+ #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
6987
8003
  if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
6988
8004
  if (params->ith != 0) {
6989
8005
  return;
@@ -6997,11 +8013,55 @@ static void ggml_compute_forward_mul_mat_q_f32(
6997
8013
  return;
6998
8014
  }
6999
8015
 
8016
+ #if defined(GGML_USE_CUBLAS)
8017
+ float *d_X = NULL;
8018
+ float *d_Y = NULL;
8019
+ float *d_D = NULL;
8020
+ float *d_Q = NULL;
8021
+ const float alpha = 1.0f;
8022
+ const float beta = 0.0f;
8023
+ const int x_ne = ne01 * ne10;
8024
+ const int y_ne = ne11 * ne10;
8025
+ const int d_ne = ne11 * ne01;
8026
+
8027
+ CUDA_CHECK(cudaMalloc((void **)(&d_X), sizeof(float) * x_ne));
8028
+ CUDA_CHECK(cudaMalloc((void **)(&d_Y), sizeof(float) * y_ne));
8029
+ CUDA_CHECK(cudaMalloc((void **)(&d_D), sizeof(float) * d_ne));
8030
+ CUDA_CHECK(cudaMalloc((void **)(&d_Q), GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type]));
8031
+
8032
+ void (*dequantize_row_q_cuda)(const void * x, float * y, int k, cudaStream_t stream) = NULL;
8033
+ if (type == GGML_TYPE_Q4_0) {
8034
+ dequantize_row_q_cuda = dequantize_row_q4_0_cuda;
8035
+ }
8036
+ else if (type == GGML_TYPE_Q4_1) {
8037
+ dequantize_row_q_cuda = dequantize_row_q4_1_cuda;
8038
+ }
8039
+ else if (type == GGML_TYPE_Q4_2) {
8040
+ dequantize_row_q_cuda = dequantize_row_q4_2_cuda;
8041
+ }
8042
+ else {
8043
+ GGML_ASSERT(false);
8044
+ }
8045
+ #else
7000
8046
  float * const wdata = params->wdata;
7001
8047
  dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q;
8048
+ #endif
7002
8049
 
7003
8050
  for (int64_t i03 = 0; i03 < ne03; i03++) {
7004
8051
  for (int64_t i02 = 0; i02 < ne02; i02++) {
8052
+ const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
8053
+
8054
+ float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
8055
+
8056
+ #if defined(GGML_USE_CUBLAS)
8057
+ // copy and dequantize on device
8058
+ CUDA_CHECK(
8059
+ cudaMemcpyAsync(d_Q, (char *) src0->data + i03*nb03 + i02*nb02,
8060
+ GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type], cudaMemcpyHostToDevice, cudaStream));
8061
+
8062
+ dequantize_row_q_cuda(d_Q, d_X, ne01 * ne00, cudaStream);
8063
+ CUDA_CHECK(cudaGetLastError());
8064
+ #else
7005
8065
  {
7006
8066
  size_t id = 0;
7007
8067
  for (int64_t i01 = 0; i01 < ne01; ++i01) {
@@ -7009,21 +8069,42 @@ static void ggml_compute_forward_mul_mat_q_f32(
7009
8069
  id += ne00;
7010
8070
  }
7011
8071
  }
7012
-
7013
8072
  const float * x = wdata;
7014
- const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
8073
+ #endif
7015
8074
 
7016
- float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
7017
8075
 
8076
+ #if defined(GGML_USE_CUBLAS)
8077
+ // copy data to device
8078
+ CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(float) * y_ne, cudaMemcpyHostToDevice, cudaStream));
8079
+
8080
+ // compute
8081
+ CUBLAS_CHECK(
8082
+ cublasSgemm(cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
8083
+ ne01, ne11, ne10,
8084
+ &alpha, d_X, ne00,
8085
+ d_Y, ne10,
8086
+ &beta, d_D, ne01));
8087
+
8088
+ // copy data to host
8089
+ CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
8090
+ #else
7018
8091
  // zT = y * xT
7019
8092
  cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
7020
8093
  ne11, ne01, ne10,
7021
8094
  1.0f, y, ne10,
7022
8095
  x, ne00,
7023
8096
  0.0f, d, ne01);
8097
+ #endif
7024
8098
  }
7025
8099
  }
7026
8100
 
8101
+ #if defined(GGML_USE_CUBLAS)
8102
+ CUDA_CHECK(cudaStreamSynchronize(cudaStream));
8103
+ CUDA_CHECK(cudaFree(d_X));
8104
+ CUDA_CHECK(cudaFree(d_Y));
8105
+ CUDA_CHECK(cudaFree(d_D));
8106
+ CUDA_CHECK(cudaFree(d_Q));
8107
+ #endif
7027
8108
  //printf("CBLAS = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);
7028
8109
 
7029
8110
  return;
@@ -7032,12 +8113,12 @@ static void ggml_compute_forward_mul_mat_q_f32(
7032
8113
 
7033
8114
  if (params->type == GGML_TASK_INIT) {
7034
8115
  char * wdata = params->wdata;
7035
- const size_t row_size = ne10*GGML_TYPE_SIZE[type]/GGML_BLCK_SIZE[type];
8116
+ const size_t row_size = ne10*GGML_TYPE_SIZE[GGML_TYPE_Q8_0]/GGML_BLCK_SIZE[GGML_TYPE_Q8_0];
7036
8117
 
7037
8118
  for (int64_t i13 = 0; i13 < ne13; ++i13) {
7038
8119
  for (int64_t i12 = 0; i12 < ne12; ++i12) {
7039
8120
  for (int64_t i11 = 0; i11 < ne11; ++i11) {
7040
- quantize_row_q((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10);
8121
+ quantize_row_q_dot((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10);
7041
8122
  wdata += row_size;
7042
8123
  }
7043
8124
  }
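
In the non-BLAS branch the INIT phase now quantizes src1 rows to Q8_0 (via quantize_row_q_dot), so each vec_dot_q call multiplies two quantized blocks of 32 and needs only one scale product per block pair. A simplified illustration with stand-in block layouts (not ggml's exact structs):

    #include <stdint.h>

    #define QK 32

    typedef struct { float d; uint8_t qs[QK/2]; } blk_q4_0;  // 4-bit quants, offset by 8
    typedef struct { float d; int8_t  qs[QK];   } blk_q8_0;  // 8-bit quants

    // Dot product of one Q4_0 block with one Q8_0 block: integer multiply-adds
    // inside the block, a single d0*d1 scale at the end.
    static float dot_q4_0_q8_0(const blk_q4_0 * x, const blk_q8_0 * y) {
        int sumi = 0;
        for (int j = 0; j < QK/2; j++) {
            const int v0 = (x->qs[j] & 0x0F) - 8;  // low nibble  -> element 2*j
            const int v1 = (x->qs[j] >>   4) - 8;  // high nibble -> element 2*j + 1
            sumi += v0 * y->qs[2*j + 0] + v1 * y->qs[2*j + 1];
        }
        return x->d * y->d * (float) sumi;
    }
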
@@ -7063,7 +8144,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
7063
8144
  const int ir1 = MIN(ir0 + dr, nr);
7064
8145
 
7065
8146
  void * wdata = params->wdata;
7066
- const size_t row_size = ne00*GGML_TYPE_SIZE[type]/GGML_BLCK_SIZE[type];
8147
+ const size_t row_size = ne00*GGML_TYPE_SIZE[GGML_TYPE_Q8_0]/GGML_BLCK_SIZE[GGML_TYPE_Q8_0];
7067
8148
 
7068
8149
  for (int ir = ir0; ir < ir1; ++ir) {
7069
8150
  // src0 indices
@@ -7111,6 +8192,9 @@ static void ggml_compute_forward_mul_mat(
7111
8192
  switch (src0->type) {
7112
8193
  case GGML_TYPE_Q4_0:
7113
8194
  case GGML_TYPE_Q4_1:
8195
+ case GGML_TYPE_Q4_2:
8196
+ case GGML_TYPE_Q4_3:
8197
+ case GGML_TYPE_Q8_0:
7114
8198
  {
7115
8199
  ggml_compute_forward_mul_mat_q_f32(params, src0, src1, dst);
7116
8200
  } break;
@@ -7122,42 +8206,11 @@ static void ggml_compute_forward_mul_mat(
7122
8206
  {
7123
8207
  ggml_compute_forward_mul_mat_f32(params, src0, src1, dst);
7124
8208
  } break;
7125
- case GGML_TYPE_I8:
7126
- case GGML_TYPE_I16:
7127
- case GGML_TYPE_I32:
7128
- case GGML_TYPE_COUNT:
8209
+ default:
7129
8210
  {
7130
8211
  GGML_ASSERT(false);
7131
8212
  } break;
7132
8213
  }
7133
-
7134
- #if 0
7135
- if (src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_Q4_1) {
7136
- static int first = 8;
7137
- printf("src0: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", src0->ne[0], src0->ne[1], src0->ne[2]);
7138
- printf("src1: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", src1->ne[0], src1->ne[1], src1->ne[2]);
7139
- printf("dst: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", dst->ne[0], dst->ne[1], dst->ne[2]);
7140
- if (first) {
7141
- --first;
7142
- } else {
7143
- for (int k = 0; k < dst->ne[1]; ++k) {
7144
- for (int j = 0; j < dst->ne[0]/16; ++j) {
7145
- for (int i = 0; i < 16; ++i) {
7146
- printf("%8.4f ", ((float *) dst->data)[k*dst->ne[0] + j*16 + i]);
7147
- }
7148
- printf("\n");
7149
- }
7150
- printf("\n");
7151
- }
7152
- printf("\n");
7153
- exit(0);
7154
- }
7155
- } else {
7156
- printf("aaaa src0: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", src0->ne[0], src0->ne[1], src0->ne[2]);
7157
- printf("aaaa src1: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", src1->ne[0], src1->ne[1], src1->ne[2]);
7158
- printf("aaaa dst: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", dst->ne[0], dst->ne[1], dst->ne[2]);
7159
- }
7160
- #endif
7161
8214
  }
7162
8215
 
7163
8216
  // ggml_compute_forward_scale
@@ -7207,13 +8260,7 @@ static void ggml_compute_forward_scale(
7207
8260
  {
7208
8261
  ggml_compute_forward_scale_f32(params, src0, src1, dst);
7209
8262
  } break;
7210
- case GGML_TYPE_Q4_0:
7211
- case GGML_TYPE_Q4_1:
7212
- case GGML_TYPE_I8:
7213
- case GGML_TYPE_I16:
7214
- case GGML_TYPE_I32:
7215
- case GGML_TYPE_F16:
7216
- case GGML_TYPE_COUNT:
8263
+ default:
7217
8264
  {
7218
8265
  GGML_ASSERT(false);
7219
8266
  } break;
@@ -7374,6 +8421,9 @@ static void ggml_compute_forward_get_rows(
7374
8421
  switch (src0->type) {
7375
8422
  case GGML_TYPE_Q4_0:
7376
8423
  case GGML_TYPE_Q4_1:
8424
+ case GGML_TYPE_Q4_2:
8425
+ case GGML_TYPE_Q4_3:
8426
+ case GGML_TYPE_Q8_0:
7377
8427
  {
7378
8428
  ggml_compute_forward_get_rows_q(params, src0, src1, dst);
7379
8429
  } break;
@@ -7385,10 +8435,7 @@ static void ggml_compute_forward_get_rows(
7385
8435
  {
7386
8436
  ggml_compute_forward_get_rows_f32(params, src0, src1, dst);
7387
8437
  } break;
7388
- case GGML_TYPE_I8:
7389
- case GGML_TYPE_I16:
7390
- case GGML_TYPE_I32:
7391
- case GGML_TYPE_COUNT:
8438
+ default:
7392
8439
  {
7393
8440
  GGML_ASSERT(false);
7394
8441
  } break;
@@ -7461,13 +8508,7 @@ static void ggml_compute_forward_diag_mask_inf(
7461
8508
  {
7462
8509
  ggml_compute_forward_diag_mask_inf_f32(params, src0, src1, dst);
7463
8510
  } break;
7464
- case GGML_TYPE_Q4_0:
7465
- case GGML_TYPE_Q4_1:
7466
- case GGML_TYPE_I8:
7467
- case GGML_TYPE_I16:
7468
- case GGML_TYPE_I32:
7469
- case GGML_TYPE_F16:
7470
- case GGML_TYPE_COUNT:
8511
+ default:
7471
8512
  {
7472
8513
  GGML_ASSERT(false);
7473
8514
  } break;
@@ -7555,13 +8596,7 @@ static void ggml_compute_forward_soft_max(
7555
8596
  {
7556
8597
  ggml_compute_forward_soft_max_f32(params, src0, dst);
7557
8598
  } break;
7558
- case GGML_TYPE_Q4_0:
7559
- case GGML_TYPE_Q4_1:
7560
- case GGML_TYPE_I8:
7561
- case GGML_TYPE_I16:
7562
- case GGML_TYPE_I32:
7563
- case GGML_TYPE_F16:
7564
- case GGML_TYPE_COUNT:
8599
+ default:
7565
8600
  {
7566
8601
  GGML_ASSERT(false);
7567
8602
  } break;
@@ -7618,9 +8653,11 @@ static void ggml_compute_forward_rope_f32(
7618
8653
 
7619
8654
  const float theta_scale = powf(10000.0, -2.0f/n_dims);
7620
8655
 
8656
+ const bool is_neox = mode & 2;
8657
+
7621
8658
  for (int64_t i3 = 0; i3 < ne3; i3++) {
7622
- for (int64_t i2 = (mode == 0 ? 0 : n_past); i2 < ne2; i2++) {
7623
- const int p = (mode == 0 ? n_past + i2 : i2);
8659
+ for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) {
8660
+ const int p = ((mode & 1) == 0 ? n_past + i2 : i2);
7624
8661
  for (int64_t i1 = 0; i1 < ne1; i1++) {
7625
8662
  if (ir++ < ir0) continue;
7626
8663
  if (ir > ir1) break;
@@ -7633,14 +8670,25 @@ static void ggml_compute_forward_rope_f32(
7633
8670
 
7634
8671
  theta *= theta_scale;
7635
8672
 
7636
- const float * const src = (float *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
7637
- float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
8673
+ if (!is_neox) {
8674
+ const float * const src = (float *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
8675
+ float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
8676
+
8677
+ const float x0 = src[0];
8678
+ const float x1 = src[1];
7638
8679
 
7639
- const float x0 = src[0];
7640
- const float x1 = src[1];
8680
+ dst_data[0] = x0*cos_theta - x1*sin_theta;
8681
+ dst_data[1] = x0*sin_theta + x1*cos_theta;
8682
+ } else {
8683
+ const float * const src = (float *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + (i0/2)*nb0);
8684
+ float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + (i0/2)*nb0);
7641
8685
 
7642
- dst_data[0] = x0*cos_theta - x1*sin_theta;
7643
- dst_data[1] = x0*sin_theta + x1*cos_theta;
8686
+ const float x0 = src[0];
8687
+ const float x1 = src[n_dims/2];
8688
+
8689
+ dst_data[0] = x0*cos_theta - x1*sin_theta;
8690
+ dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
8691
+ }
7644
8692
  }
7645
8693
  }
7646
8694
  }
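
The RoPE change adds the GPT-NeoX layout: the rotation itself is unchanged, only the pairing of elements differs (adjacent pairs vs. the two halves of the head). A sketch of one row in f32, with theta starting at the position p and decaying by theta_scale per pair as in the kernel above:

    #include <math.h>

    // For each pair the rotation is
    //   x0' = x0*cos(theta) - x1*sin(theta)
    //   x1' = x0*sin(theta) + x1*cos(theta)
    // In the original mode the pair is (row[i0], row[i0+1]); in GPT-NeoX mode
    // (mode & 2) the pair is (row[i0/2], row[i0/2 + n_dims/2]).
    static void rope_row(float * row, int n_dims, int p, float theta_scale, int is_neox) {
        float theta = (float) p;
        for (int i0 = 0; i0 < n_dims; i0 += 2) {
            const float cos_theta = cosf(theta);
            const float sin_theta = sinf(theta);
            theta *= theta_scale;

            const int ia = is_neox ? i0/2            : i0;
            const int ib = is_neox ? i0/2 + n_dims/2 : i0 + 1;

            const float x0 = row[ia];
            const float x1 = row[ib];
            row[ia] = x0*cos_theta - x1*sin_theta;
            row[ib] = x0*sin_theta + x1*cos_theta;
        }
    }
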
@@ -7695,9 +8743,11 @@ static void ggml_compute_forward_rope_f16(
7695
8743
 
7696
8744
  const float theta_scale = powf(10000.0, -2.0f/n_dims);
7697
8745
 
8746
+ const bool is_neox = mode & 2;
8747
+
7698
8748
  for (int64_t i3 = 0; i3 < ne3; i3++) {
7699
- for (int64_t i2 = (mode == 0 ? 0 : n_past); i2 < ne2; i2++) {
7700
- const int p = (mode == 0 ? n_past + i2 : i2);
8749
+ for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) {
8750
+ const int p = ((mode & 1) == 0 ? n_past + i2 : i2);
7701
8751
  for (int64_t i1 = 0; i1 < ne1; i1++) {
7702
8752
  if (ir++ < ir0) continue;
7703
8753
  if (ir > ir1) break;
@@ -7710,14 +8760,25 @@ static void ggml_compute_forward_rope_f16(
7710
8760
 
7711
8761
  theta *= theta_scale;
7712
8762
 
7713
- const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
7714
- ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
8763
+ if (!is_neox) {
8764
+ const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
8765
+ ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
8766
+
8767
+ const float x0 = GGML_FP16_TO_FP32(src[0]);
8768
+ const float x1 = GGML_FP16_TO_FP32(src[1]);
7715
8769
 
7716
- const float x0 = ggml_fp16_to_fp32(src[0]);
7717
- const float x1 = ggml_fp16_to_fp32(src[1]);
8770
+ dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
8771
+ dst_data[1] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
8772
+ } else {
8773
+ const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + (i0/2)*nb0);
8774
+ ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + (i0/2)*nb0);
7718
8775
 
7719
- dst_data[0] = ggml_fp32_to_fp16(x0*cos_theta - x1*sin_theta);
7720
- dst_data[1] = ggml_fp32_to_fp16(x0*sin_theta + x1*cos_theta);
8776
+ const float x0 = GGML_FP16_TO_FP32(src[0]);
8777
+ const float x1 = GGML_FP16_TO_FP32(src[n_dims/2]);
8778
+
8779
+ dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
8780
+ dst_data[n_dims/2] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
8781
+ }
7721
8782
  }
7722
8783
  }
7723
8784
  }
@@ -7738,12 +8799,7 @@ static void ggml_compute_forward_rope(
7738
8799
  {
7739
8800
  ggml_compute_forward_rope_f32(params, src0, src1, dst);
7740
8801
  } break;
7741
- case GGML_TYPE_Q4_0:
7742
- case GGML_TYPE_Q4_1:
7743
- case GGML_TYPE_I8:
7744
- case GGML_TYPE_I16:
7745
- case GGML_TYPE_I32:
7746
- case GGML_TYPE_COUNT:
8802
+ default:
7747
8803
  {
7748
8804
  GGML_ASSERT(false);
7749
8805
  } break;
@@ -8006,12 +9062,7 @@ static void ggml_compute_forward_conv_1d_1s(
8006
9062
  {
8007
9063
  ggml_compute_forward_conv_1d_1s_f32(params, src0, src1, dst);
8008
9064
  } break;
8009
- case GGML_TYPE_Q4_0:
8010
- case GGML_TYPE_Q4_1:
8011
- case GGML_TYPE_I8:
8012
- case GGML_TYPE_I16:
8013
- case GGML_TYPE_I32:
8014
- case GGML_TYPE_COUNT:
9065
+ default:
8015
9066
  {
8016
9067
  GGML_ASSERT(false);
8017
9068
  } break;
@@ -8274,12 +9325,7 @@ static void ggml_compute_forward_conv_1d_2s(
8274
9325
  {
8275
9326
  ggml_compute_forward_conv_1d_2s_f32(params, src0, src1, dst);
8276
9327
  } break;
8277
- case GGML_TYPE_Q4_0:
8278
- case GGML_TYPE_Q4_1:
8279
- case GGML_TYPE_I8:
8280
- case GGML_TYPE_I16:
8281
- case GGML_TYPE_I32:
8282
- case GGML_TYPE_COUNT:
9328
+ default:
8283
9329
  {
8284
9330
  GGML_ASSERT(false);
8285
9331
  } break;
@@ -8759,12 +9805,7 @@ static void ggml_compute_forward_flash_attn(
8759
9805
  {
8760
9806
  ggml_compute_forward_flash_attn_f32(params, q, k, v, masked, dst);
8761
9807
  } break;
8762
- case GGML_TYPE_Q4_0:
8763
- case GGML_TYPE_Q4_1:
8764
- case GGML_TYPE_I8:
8765
- case GGML_TYPE_I16:
8766
- case GGML_TYPE_I32:
8767
- case GGML_TYPE_COUNT:
9808
+ default:
8768
9809
  {
8769
9810
  GGML_ASSERT(false);
8770
9811
  } break;
@@ -8970,12 +10011,7 @@ static void ggml_compute_forward_flash_ff(
8970
10011
  {
8971
10012
  GGML_ASSERT(false); // TODO
8972
10013
  } break;
8973
- case GGML_TYPE_Q4_0:
8974
- case GGML_TYPE_Q4_1:
8975
- case GGML_TYPE_I8:
8976
- case GGML_TYPE_I16:
8977
- case GGML_TYPE_I32:
8978
- case GGML_TYPE_COUNT:
10014
+ default:
8979
10015
  {
8980
10016
  GGML_ASSERT(false);
8981
10017
  } break;
@@ -9019,13 +10055,7 @@ static void ggml_compute_forward_map_unary(
9019
10055
  {
9020
10056
  ggml_compute_forward_map_unary_f32(params, src0, dst, fun);
9021
10057
  } break;
9022
- case GGML_TYPE_Q4_0:
9023
- case GGML_TYPE_Q4_1:
9024
- case GGML_TYPE_I8:
9025
- case GGML_TYPE_I16:
9026
- case GGML_TYPE_I32:
9027
- case GGML_TYPE_F16:
9028
- case GGML_TYPE_COUNT:
10058
+ default:
9029
10059
  {
9030
10060
  GGML_ASSERT(false);
9031
10061
  } break;
@@ -9074,13 +10104,7 @@ static void ggml_compute_forward_map_binary(
9074
10104
  {
9075
10105
  ggml_compute_forward_map_binary_f32(params, src0, src1, dst, fun);
9076
10106
  } break;
9077
- case GGML_TYPE_Q4_0:
9078
- case GGML_TYPE_Q4_1:
9079
- case GGML_TYPE_I8:
9080
- case GGML_TYPE_I16:
9081
- case GGML_TYPE_I32:
9082
- case GGML_TYPE_F16:
9083
- case GGML_TYPE_COUNT:
10107
+ default:
9084
10108
  {
9085
10109
  GGML_ASSERT(false);
9086
10110
  } break;
@@ -9830,13 +10854,29 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
9830
10854
  struct ggml_tensor * node = cgraph->nodes[i];
9831
10855
 
9832
10856
  switch (node->op) {
10857
+ case GGML_OP_CPY:
9833
10858
  case GGML_OP_DUP:
9834
10859
  {
9835
- node->n_tasks = 1;
10860
+ node->n_tasks = n_threads;
10861
+
10862
+ size_t cur = 0;
10863
+ if (ggml_is_quantized(node->type)) {
10864
+ cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->ne[0] * n_threads;
10865
+ }
10866
+
10867
+ work_size = MAX(work_size, cur);
9836
10868
  } break;
9837
10869
  case GGML_OP_ADD:
9838
10870
  {
9839
10871
  node->n_tasks = n_threads;
10872
+
10873
+ size_t cur = 0;
10874
+
10875
+ if (ggml_is_quantized(node->src0->type)) {
10876
+ cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->src0->ne[0] * n_threads;
10877
+ }
10878
+
10879
+ work_size = MAX(work_size, cur);
9840
10880
  } break;
9841
10881
  case GGML_OP_SUB:
9842
10882
  case GGML_OP_MUL:
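
For quantized CPY/DUP and ADD, the planner now reserves one dequantized f32 row per worker thread in the shared work buffer. A hypothetical helper expressing that sizing rule (not the library API):

    #include <stddef.h>

    // Work bytes needed so each of n_threads workers can dequantize one row of
    // ne0 elements into its own f32 scratch slice; the planner takes the max
    // of this over all nodes in the graph.
    static size_t rowop_work_size(size_t ne0, int n_threads, size_t work_size) {
        const size_t cur = sizeof(float) * ne0 * (size_t) n_threads;
        return cur > work_size ? cur : work_size;
    }
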
@@ -9881,7 +10921,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
9881
10921
  size_t cur = 0;
9882
10922
 
9883
10923
  if (node->src0->type == GGML_TYPE_F16 && node->src1->type == GGML_TYPE_F32) {
9884
- #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
10924
+ #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
9885
10925
  if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
9886
10926
  node->n_tasks = 1; // TODO: this actually is doing nothing
9887
10927
  // the threads are still spinning
@@ -9897,15 +10937,15 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
9897
10937
  #endif
9898
10938
  } else if (node->src0->type == GGML_TYPE_F32 && node->src1->type == GGML_TYPE_F32) {
9899
10939
  cur = 0;
9900
- } else if (quantize_fns[node->src0->type].vec_dot_q && node->src1->type == GGML_TYPE_F32) {
9901
- #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
10940
+ } else if (ggml_is_quantized(node->src0->type) && node->src1->type == GGML_TYPE_F32) {
10941
+ #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
9902
10942
  if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
9903
10943
  node->n_tasks = 1;
9904
10944
  cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]);
9905
10945
  } else
9906
10946
  #endif
9907
10947
  {
9908
- cur = GGML_TYPE_SIZE[node->src0->type]*ggml_nelements(node->src1)/GGML_BLCK_SIZE[node->src0->type];
10948
+ cur = GGML_TYPE_SIZE[GGML_TYPE_Q8_0]*ggml_nelements(node->src1)/GGML_BLCK_SIZE[GGML_TYPE_Q8_0];
9909
10949
  }
9910
10950
  } else {
9911
10951
  GGML_ASSERT(false);
@@ -9917,7 +10957,6 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
9917
10957
  {
9918
10958
  node->n_tasks = n_threads;
9919
10959
  } break;
9920
- case GGML_OP_CPY:
9921
10960
  case GGML_OP_CONT:
9922
10961
  case GGML_OP_RESHAPE:
9923
10962
  case GGML_OP_VIEW:
@@ -11080,16 +12119,16 @@ enum ggml_opt_result ggml_opt(
11080
12119
  ////////////////////////////////////////////////////////////////////////////////
11081
12120
 
11082
12121
  size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t * hist) {
11083
- assert(k % QK == 0);
11084
- const int nb = k / QK;
12122
+ assert(k % QK4_0 == 0);
12123
+ const int nb = k / QK4_0;
11085
12124
 
11086
12125
  for (int j = 0; j < n; j += k) {
11087
- block_q4_0 * restrict y = (block_q4_0 *)dst + j/QK;
12126
+ block_q4_0 * restrict y = (block_q4_0 *)dst + j/QK4_0;
11088
12127
 
11089
12128
  quantize_row_q4_0_reference(src + j, y, k);
11090
12129
 
11091
12130
  for (int i = 0; i < nb; i++) {
11092
- for (int l = 0; l < QK; l += 2) {
12131
+ for (int l = 0; l < QK4_0; l += 2) {
11093
12132
  const uint8_t vi0 = y[i].qs[l/2] & 0xF;
11094
12133
  const uint8_t vi1 = y[i].qs[l/2] >> 4;
11095
12134
 
@@ -11099,20 +12138,67 @@ size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t *
11099
12138
  }
11100
12139
  }
11101
12140
 
11102
- return (n/QK*sizeof(block_q4_0));
12141
+ return (n/QK4_0*sizeof(block_q4_0));
11103
12142
  }
11104
12143
 
11105
12144
  size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t * hist) {
11106
- assert(k % QK == 0);
11107
- const int nb = k / QK;
12145
+ assert(k % QK4_1 == 0);
12146
+ const int nb = k / QK4_1;
11108
12147
 
11109
12148
  for (int j = 0; j < n; j += k) {
11110
- block_q4_1 * restrict y = (block_q4_1 *)dst + j/QK;
12149
+ block_q4_1 * restrict y = (block_q4_1 *)dst + j/QK4_1;
11111
12150
 
11112
12151
  quantize_row_q4_1_reference(src + j, y, k);
11113
12152
 
11114
12153
  for (int i = 0; i < nb; i++) {
11115
- for (int l = 0; l < QK; l += 2) {
12154
+ for (int l = 0; l < QK4_1; l += 2) {
12155
+ const uint8_t vi0 = y[i].qs[l/2] & 0xF;
12156
+ const uint8_t vi1 = y[i].qs[l/2] >> 4;
12157
+
12158
+ hist[vi0]++;
12159
+ hist[vi1]++;
12160
+ }
12161
+ }
12162
+ }
12163
+
12164
+ return (n/QK4_1*sizeof(block_q4_1));
12165
+ }
12166
+
12167
+ size_t ggml_quantize_q4_2(const float * src, void * dst, int n, int k, int64_t * hist) {
12168
+ assert(k % QK4_2 == 0);
12169
+ const int nb = k / QK4_2;
12170
+
12171
+ for (int j = 0; j < n; j += k) {
12172
+ block_q4_2 * restrict y = (block_q4_2 *)dst + j/QK4_2;
12173
+
12174
+ //quantize_row_q4_2_reference(src + j, y, k);
12175
+ quantize_row_q4_2_rmse(src + j, y, k);
12176
+
12177
+ for (int i = 0; i < nb; i++) {
12178
+ for (int l = 0; l < QK4_2; l += 2) {
12179
+ const uint8_t vi0 = y[i].qs[l/2] & 0xF;
12180
+ const uint8_t vi1 = y[i].qs[l/2] >> 4;
12181
+
12182
+ hist[vi0]++;
12183
+ hist[vi1]++;
12184
+ }
12185
+ }
12186
+ }
12187
+
12188
+ return (n/QK4_2*sizeof(block_q4_2));
12189
+ }
12190
+
12191
+ size_t ggml_quantize_q4_3(const float * src, void * dst, int n, int k, int64_t * hist) {
12192
+ assert(k % QK4_3 == 0);
12193
+ const int nb = k / QK4_3;
12194
+
12195
+ for (int j = 0; j < n; j += k) {
12196
+ block_q4_3 * restrict y = (block_q4_3 *)dst + j/QK4_3;
12197
+
12198
+ quantize_row_q4_3_reference(src + j, y, k);
12199
+
12200
+ for (int i = 0; i < nb; i++) {
12201
+ for (int l = 0; l < QK4_3; l += 2) {
11116
12202
  const uint8_t vi0 = y[i].qs[l/2] & 0xF;
11117
12203
  const uint8_t vi1 = y[i].qs[l/2] >> 4;
11118
12204
 
@@ -11122,7 +12208,40 @@ size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t *
11122
12208
  }
11123
12209
  }
11124
12210
 
11125
- return (n/QK*sizeof(block_q4_1));
12211
+ return (n/QK4_3*sizeof(block_q4_3));
12212
+ }
12213
+
12214
+ size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, int start, int n, int64_t * hist) {
12215
+ size_t result = 0;
12216
+ switch (type) {
12217
+ case GGML_TYPE_Q4_0:
12218
+ {
12219
+ GGML_ASSERT(start % QK4_0 == 0);
12220
+ block_q4_0 * block = (block_q4_0*)dst + start / QK4_0;
12221
+ result = ggml_quantize_q4_0(src + start, block, n, n, hist);
12222
+ } break;
12223
+ case GGML_TYPE_Q4_1:
12224
+ {
12225
+ GGML_ASSERT(start % QK4_1 == 0);
12226
+ block_q4_1 * block = (block_q4_1*)dst + start / QK4_1;
12227
+ result = ggml_quantize_q4_1(src + start, block, n, n, hist);
12228
+ } break;
12229
+ case GGML_TYPE_Q4_2:
12230
+ {
12231
+ GGML_ASSERT(start % QK4_2 == 0);
12232
+ block_q4_2 * block = (block_q4_2*)dst + start / QK4_2;
12233
+ result = ggml_quantize_q4_2(src + start, block, n, n, hist);
12234
+ } break;
12235
+ case GGML_TYPE_Q4_3:
12236
+ {
12237
+ GGML_ASSERT(start % QK4_3 == 0);
12238
+ block_q4_3 * block = (block_q4_3*)dst + start / QK4_3;
12239
+ result = ggml_quantize_q4_3(src + start, block, n, n, hist);
12240
+ } break;
12241
+ default:
12242
+ assert(false);
12243
+ }
12244
+ return result;
11126
12245
  }
11127
12246
 
11128
12247
  ////////////////////////////////////////////////////////////////////////////////
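
ggml_quantize_chunk above dispatches on the block type and quantizes a block-aligned [start, start + n) range, accumulating the 16-bucket nibble histogram in hist. A usage sketch, assuming the function is exposed through this version's ggml.h:

    #include <string.h>
    #include <stdint.h>
    #include "ggml.h"

    // Quantize an entire f32 buffer to Q4_0 in one call; callers may instead
    // split the [start, start + n) range across threads, as long as each chunk
    // is a multiple of the block size.
    static size_t quantize_all_q4_0(const float * src, void * dst, int nelements) {
        int64_t hist[16];
        memset(hist, 0, sizeof(hist));
        return ggml_quantize_chunk(GGML_TYPE_Q4_0, src, dst, 0, nelements, hist);
    }
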
@@ -11151,6 +12270,22 @@ int ggml_cpu_has_avx512(void) {
11151
12270
  #endif
11152
12271
  }
11153
12272
 
12273
+ int ggml_cpu_has_avx512_vbmi(void) {
12274
+ #if defined(__AVX512VBMI__)
12275
+ return 1;
12276
+ #else
12277
+ return 0;
12278
+ #endif
12279
+ }
12280
+
12281
+ int ggml_cpu_has_avx512_vnni(void) {
12282
+ #if defined(__AVX512VNNI__)
12283
+ return 1;
12284
+ #else
12285
+ return 0;
12286
+ #endif
12287
+ }
12288
+
11154
12289
  int ggml_cpu_has_fma(void) {
11155
12290
  #if defined(__FMA__)
11156
12291
  return 1;
@@ -11200,7 +12335,15 @@ int ggml_cpu_has_wasm_simd(void) {
11200
12335
  }
11201
12336
 
11202
12337
  int ggml_cpu_has_blas(void) {
11203
- #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
12338
+ #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
12339
+ return 1;
12340
+ #else
12341
+ return 0;
12342
+ #endif
12343
+ }
12344
+
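
The new probes (ggml_cpu_has_avx512_vbmi, ggml_cpu_has_avx512_vnni, and ggml_cpu_has_cublas just below) are compile-time flags surfaced as functions. A usage sketch that reports them at runtime, assuming they are declared in this version's ggml.h alongside the existing ggml_cpu_has_blas:

    #include <stdio.h>
    #include "ggml.h"

    int main(void) {
        printf("AVX512 VBMI : %d\n", ggml_cpu_has_avx512_vbmi());
        printf("AVX512 VNNI : %d\n", ggml_cpu_has_avx512_vnni());
        printf("BLAS        : %d\n", ggml_cpu_has_blas());
        printf("cuBLAS      : %d\n", ggml_cpu_has_cublas());
        return 0;
    }
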
12345
+ int ggml_cpu_has_cublas(void) {
12346
+ #if defined(GGML_USE_CUBLAS)
11204
12347
  return 1;
11205
12348
  #else
11206
12349
  return 0;