llama_cpp 0.0.3 → 0.0.5

@@ -1,4 +1,4 @@
- // Defines CLOCK_MONOTONIC and asprintf on Linux
+ // Defines CLOCK_MONOTONIC on Linux
  #define _GNU_SOURCE

  #include "ggml.h"
@@ -26,14 +26,9 @@
  #define static_assert(cond, msg) struct global_scope_noop_trick
  #endif

- #if defined _MSC_VER || defined(__MINGW32__)
+ #if defined(_WIN32)

- #if !defined(__MINGW32__)
- #include <Windows.h>
- #else
- // ref: https://github.com/ggerganov/whisper.cpp/issues/168
  #include <windows.h>
- #endif

  typedef volatile LONG atomic_int;
  typedef atomic_int atomic_bool;
@@ -55,6 +50,7 @@ typedef HANDLE pthread_t;

  typedef DWORD thread_ret_t;
  static int pthread_create(pthread_t* out, void* unused, thread_ret_t(*func)(void*), void* arg) {
+ (void) unused;
  HANDLE handle = CreateThread(NULL, 0, (LPTHREAD_START_ROUTINE) func, arg, 0, NULL);
  if (handle == NULL)
  {
@@ -66,6 +62,7 @@ static int pthread_create(pthread_t* out, void* unused, thread_ret_t(*func)(void
  }

  static int pthread_join(pthread_t thread, void* unused) {
+ (void) unused;
  return (int) WaitForSingleObject(thread, INFINITE);
  }

@@ -97,17 +94,6 @@ typedef void* thread_ret_t;
  #define static_assert(cond, msg) _Static_assert(cond, msg)
  #endif

- #define GGML_MLOCK_SUPPORT 0
-
- #ifdef __has_include
- #if __has_include(<sys/mman.h>)
- #undef GGML_MLOCK_SUPPORT
- #define GGML_MLOCK_SUPPORT 1
- #include <sys/mman.h>
- #endif
- #endif
-
-
  /*#define GGML_PERF*/
  #define GGML_DEBUG 0
  #define GGML_GELU_FP16
@@ -128,6 +114,23 @@ typedef void* thread_ret_t;
  #define GGML_MEM_ALIGN 16
  #endif

+ #if defined(_MSC_VER) || defined(__MINGW32__)
+ #define GGML_ALIGNED_MALLOC(size) _aligned_malloc(size, GGML_MEM_ALIGN)
+ #define GGML_ALIGNED_FREE(ptr) _aligned_free(ptr)
+ #else
+ inline static void* ggml_aligned_malloc(size_t size) {
+ void* aligned_memory = NULL;
+ int result = posix_memalign(&aligned_memory, GGML_MEM_ALIGN, size);
+ if (result != 0) {
+ // Handle allocation failure
+ return NULL;
+ }
+ return aligned_memory;
+ }
+ #define GGML_ALIGNED_MALLOC(size) ggml_aligned_malloc(size)
+ #define GGML_ALIGNED_FREE(ptr) free(ptr)
+ #endif
+
  #define UNUSED(x) (void)(x)
  #define SWAP(x, y, T) do { T SWAP = x; x = y; y = SWAP; } while (0)
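
The hunk above introduces the GGML_ALIGNED_MALLOC/GGML_ALIGNED_FREE pair, which behaves like malloc/free but guarantees GGML_MEM_ALIGN-byte alignment (_aligned_malloc on MSVC/MinGW, posix_memalign elsewhere). A minimal usage sketch, not part of the diff (the function name and size below are made up):

    // Illustrative only (not from ggml.c).
    static int example_use_aligned_buffer(void) {
        void * buf = GGML_ALIGNED_MALLOC(4096);   // GGML_MEM_ALIGN-aligned, or NULL on failure
        if (buf == NULL) {
            return -1;                            // posix_memalign/_aligned_malloc failed
        }
        // ... use buf, e.g. as context scratch memory ...
        GGML_ALIGNED_FREE(buf);                   // pair with the macro; plain free() is wrong on the MSVC/MinGW path
        return 0;
    }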

@@ -242,12 +245,12 @@ static inline float fp32_from_bits(uint32_t w) {
  }

  static inline uint32_t fp32_to_bits(float f) {
- union {
- float as_value;
- uint32_t as_bits;
- } fp32;
- fp32.as_value = f;
- return fp32.as_bits;
+ union {
+ float as_value;
+ uint32_t as_bits;
+ } fp32;
+ fp32.as_value = f;
+ return fp32.as_bits;
  }

  static inline float ggml_compute_fp16_to_fp32(ggml_fp16_t h) {
@@ -424,8 +427,6 @@ static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
  // quantization
  //

- #define QK 32
-
  // AVX routines provided by GH user Const-me
  // ref: https://github.com/ggerganov/ggml/pull/27#issuecomment-1464934600
  #if __AVX2__ || __AVX512F__
@@ -497,37 +498,113 @@ static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
  }
  #endif

- // method 5
- // blocks of QK elements
- // represented with a single float (delta) and QK/2 8-bit ints (i.e QK 4-bit signed integer factors)
+ #if __ARM_NEON
+
+ #if !defined(__aarch64__)
+
+ inline static uint16_t vaddvq_u8(uint8x16_t v) {
+ return
+ (uint16_t)vgetq_lane_u8(v, 0) + (uint16_t)vgetq_lane_u8(v, 1) +
+ (uint16_t)vgetq_lane_u8(v, 2) + (uint16_t)vgetq_lane_u8(v, 3) +
+ (uint16_t)vgetq_lane_u8(v, 4) + (uint16_t)vgetq_lane_u8(v, 5) +
+ (uint16_t)vgetq_lane_u8(v, 6) + (uint16_t)vgetq_lane_u8(v, 7) +
+ (uint16_t)vgetq_lane_u8(v, 8) + (uint16_t)vgetq_lane_u8(v, 9) +
+ (uint16_t)vgetq_lane_u8(v, 10) + (uint16_t)vgetq_lane_u8(v, 11) +
+ (uint16_t)vgetq_lane_u8(v, 12) + (uint16_t)vgetq_lane_u8(v, 13) +
+ (uint16_t)vgetq_lane_u8(v, 14) + (uint16_t)vgetq_lane_u8(v, 15);
+ }
+
+ inline static int32_t vaddvq_s16(int16x8_t v) {
+ return
+ (int32_t)vgetq_lane_s16(v, 0) + (int32_t)vgetq_lane_s16(v, 1) +
+ (int32_t)vgetq_lane_s16(v, 2) + (int32_t)vgetq_lane_s16(v, 3) +
+ (int32_t)vgetq_lane_s16(v, 4) + (int32_t)vgetq_lane_s16(v, 5) +
+ (int32_t)vgetq_lane_s16(v, 6) + (int32_t)vgetq_lane_s16(v, 7);
+ }
+
+ inline static uint32_t vaddvq_u16(uint16x8_t v) {
+ return
+ (uint32_t)vgetq_lane_u16(v, 0) + (uint32_t)vgetq_lane_u16(v, 1) +
+ (uint32_t)vgetq_lane_u16(v, 2) + (uint32_t)vgetq_lane_u16(v, 3) +
+ (uint32_t)vgetq_lane_u16(v, 4) + (uint32_t)vgetq_lane_u16(v, 5) +
+ (uint32_t)vgetq_lane_u16(v, 6) + (uint32_t)vgetq_lane_u16(v, 7);
+ }
+
+ inline static int32_t vaddvq_s32(int32x4_t v) {
+ return vgetq_lane_s32(v, 0) + vgetq_lane_s32(v, 1) + vgetq_lane_s32(v, 2) + vgetq_lane_s32(v, 3);
+ }
+
+ inline static float vaddvq_f32(float32x4_t v) {
+ return vgetq_lane_f32(v, 0) + vgetq_lane_f32(v, 1) + vgetq_lane_f32(v, 2) + vgetq_lane_f32(v, 3);
+ }
+
+ float vminvq_f32(float32x4_t v) {
+ return
+ MIN(MIN(vgetq_lane_f32(v, 0), vgetq_lane_f32(v, 1)),
+ MIN(vgetq_lane_f32(v, 2), vgetq_lane_f32(v, 3)));
+ }
+
+ float vmaxvq_f32(float32x4_t v) {
+ return
+ MAX(MAX(vgetq_lane_f32(v, 0), vgetq_lane_f32(v, 1)),
+ MAX(vgetq_lane_f32(v, 2), vgetq_lane_f32(v, 3)));
+ }
+
+ int8x8_t vzip1_s8(int8x8_t a, int8x8_t b) {
+ return vget_low_s8(vcombine_s8(a, b));
+ }
+
+ int8x8_t vzip2_s8(int8x8_t a, int8x8_t b) {
+ return vget_high_s8(vcombine_s8(a, b));
+ }
+
+ uint8x8_t vzip1_u8(uint8x8_t a, uint8x8_t b) {
+ return vget_low_u8(vcombine_u8(a, b));
+ }
+
+ uint8x8_t vzip2_u8(uint8x8_t a, uint8x8_t b) {
+ return vget_high_u8(vcombine_u8(a, b));
+ }
+
+ #endif
+ #endif
+
+
+ #define QK4_0 32
  typedef struct {
- float d; // delta
- uint8_t qs[QK / 2]; // nibbles / quants
+ float d; // delta
+ uint8_t qs[QK4_0 / 2]; // nibbles / quants
  } block_q4_0;
- static_assert(sizeof(block_q4_0) == sizeof(float) + QK / 2, "wrong q4_0 block size/padding");
+ static_assert(sizeof(block_q4_0) == sizeof(float) + QK4_0 / 2, "wrong q4_0 block size/padding");

- // method 4
- // blocks of QK elements
- // represented with 2 floats (delta + min) and QK/2 8-bit ints (i.e QK 4-bit unsigned integer factors)
+ #define QK4_1 32
  typedef struct {
- float d;
- float m;
- uint8_t qs[QK / 2]; // nibbles / quants
+ float d; // delta
+ float m; // min
+ uint8_t qs[QK4_1 / 2]; // nibbles / quants
  } block_q4_1;
- static_assert(sizeof(block_q4_1) == sizeof(float) * 2 + QK / 2, "wrong q4_1 block size/padding");
+ static_assert(sizeof(block_q4_1) == sizeof(float) * 2 + QK4_1 / 2, "wrong q4_1 block size/padding");
+
+ #define QK8_0 32
+ typedef struct {
+ float d; // delta
+ int8_t qs[QK8_0]; // quants
+ } block_q8_0;
+ static_assert(sizeof(block_q8_0) == sizeof(float) + QK8_0, "wrong q8_0 block size/padding");
+
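
To make the new block formats concrete (an illustrative sketch, not part of the diff): a block_q4_0 holds one float scale d plus 32 weights packed two per byte. The reference quantizer that follows uses d = amax/7 and stores round(v/d) + 8, so each nibble lies in [0, 15] and dequantizes to (nibble - 8)*d; for example, with amax = 1.4 the scale is d = 0.2, a weight of 0.61 is stored as round(0.61/0.2) + 8 = 11 and comes back as (11 - 8)*0.2 = 0.6. A minimal scalar pack/unpack of one nibble pair, mirroring that scheme:

    // Illustrative only; assumes <stdint.h> and <math.h>. d is the block scale (amax/7).
    static uint8_t q4_0_pack_pair(float v0, float v1, float d) {
        const float id = d ? 1.0f/d : 0.0f;
        const uint8_t q0 = (uint8_t)((int8_t)roundf(v0*id) + 8);   // each quant ends up in [0, 15]
        const uint8_t q1 = (uint8_t)((int8_t)roundf(v1*id) + 8);
        return q0 | (q1 << 4);                                     // low nibble first, as in block_q4_0.qs
    }

    static void q4_0_unpack_pair(uint8_t b, float d, float * v0, float * v1) {
        *v0 = ((int8_t)(b & 0xf) - 8)*d;
        *v1 = ((int8_t)(b >> 4) - 8)*d;
    }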

  // reference implementation for deterministic creation of model files
  static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k) {
- assert(k % QK == 0);
- const int nb = k / QK;
+ assert(k % QK4_0 == 0);
+ const int nb = k / QK4_0;

- uint8_t pp[QK/2];
+ uint8_t pp[QK4_0/2];

  for (int i = 0; i < nb; i++) {
  float amax = 0.0f; // absolute max

- for (int l = 0; l < QK; l++) {
- const float v = x[i*QK + l];
+ for (int l = 0; l < QK4_0; l++) {
+ const float v = x[i*QK4_0 + l];
  amax = MAX(amax, fabsf(v));
  }

@@ -536,9 +613,9 @@ static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * r

  y[i].d = d;

- for (int l = 0; l < QK; l += 2) {
- const float v0 = x[i*QK + l + 0]*id;
- const float v1 = x[i*QK + l + 1]*id;
+ for (int l = 0; l < QK4_0; l += 2) {
+ const float v0 = x[i*QK4_0 + l + 0]*id;
+ const float v1 = x[i*QK4_0 + l + 1]*id;

  const uint8_t vi0 = (int8_t)roundf(v0) + 8;
  const uint8_t vi1 = (int8_t)roundf(v1) + 8;
@@ -554,8 +631,8 @@ static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * r
  }

  static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int k) {
- assert(k % QK == 0);
- const int nb = k / QK;
+ assert(k % QK4_0 == 0);
+ const int nb = k / QK4_0;

  block_q4_0 * restrict y = vy;

@@ -610,10 +687,7 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
  for (int l = 0; l < 2; l++) amaxv[4*l] = vmaxq_f32(amaxv[4*l], amaxv[4*l+2]);
  for (int l = 0; l < 1; l++) amaxv[8*l] = vmaxq_f32(amaxv[8*l], amaxv[8*l+4]);

- // absolute max
- const float amax = MAX(
- MAX(vgetq_lane_f32(amaxv[0], 0), vgetq_lane_f32(amaxv[0], 1)),
- MAX(vgetq_lane_f32(amaxv[0], 2), vgetq_lane_f32(amaxv[0], 3)));
+ const float amax = vmaxvq_f32(amaxv[0]);

  const float d = amax / ((1 << 3) - 1);
  const float id = d ? 1.0f/d : 0.0f;
@@ -808,19 +882,19 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
808
882
  }
809
883
 
810
884
  static void quantize_row_q4_1_reference(const float * restrict x, void * restrict vy, int k) {
811
- assert(k % QK == 0);
812
- const int nb = k / QK;
885
+ assert(k % QK4_1 == 0);
886
+ const int nb = k / QK4_1;
813
887
 
814
888
  block_q4_1 * restrict y = vy;
815
889
 
816
- uint8_t pp[QK/2];
890
+ uint8_t pp[QK4_1/2];
817
891
 
818
892
  for (int i = 0; i < nb; i++) {
819
893
  float min = FLT_MAX;
820
894
  float max = -FLT_MAX;
821
895
 
822
- for (int l = 0; l < QK; l++) {
823
- const float v = x[i*QK + l];
896
+ for (int l = 0; l < QK4_1; l++) {
897
+ const float v = x[i*QK4_1 + l];
824
898
  if (v < min) min = v;
825
899
  if (v > max) max = v;
826
900
  }
@@ -831,9 +905,9 @@ static void quantize_row_q4_1_reference(const float * restrict x, void * restric
831
905
  y[i].d = d;
832
906
  y[i].m = min;
833
907
 
834
- for (int l = 0; l < QK; l += 2) {
835
- const float v0 = (x[i*QK + l + 0] - min)*id;
836
- const float v1 = (x[i*QK + l + 1] - min)*id;
908
+ for (int l = 0; l < QK4_1; l += 2) {
909
+ const float v0 = (x[i*QK4_1 + l + 0] - min)*id;
910
+ const float v1 = (x[i*QK4_1 + l + 1] - min)*id;
837
911
 
838
912
  const uint8_t vi0 = roundf(v0);
839
913
  const uint8_t vi1 = roundf(v1);
@@ -849,9 +923,9 @@ static void quantize_row_q4_1_reference(const float * restrict x, void * restric
849
923
  }
850
924
 
851
925
  static void quantize_row_q4_1(const float * restrict x, void * restrict vy, int k) {
852
- assert(k % QK == 0);
926
+ assert(k % QK4_1 == 0);
853
927
 
854
- const int nb = k / QK;
928
+ const int nb = k / QK4_1;
855
929
 
856
930
  block_q4_1 * restrict y = vy;
857
931
 
@@ -935,7 +1009,7 @@ static void quantize_row_q4_1(const float * restrict x, void * restrict vy, int
935
1009
  float32x4_t minv[8];
936
1010
  float32x4_t maxv[8];
937
1011
 
938
- for (int l = 0; l < 8; l++) srcv[l] = vld1q_f32(x + i*32 + 4*l);
1012
+ for (int l = 0; l < 8; l++) srcv[l] = vld1q_f32(x + i*QK4_1 + 4*l);
939
1013
 
940
1014
  for (int l = 0; l < 4; l++) minv[2*l] = vminq_f32(srcv[2*l], srcv[2*l + 1]);
941
1015
  for (int l = 0; l < 2; l++) minv[4*l] = vminq_f32(minv[4*l], minv[4*l + 2]);
@@ -958,7 +1032,8 @@ static void quantize_row_q4_1(const float * restrict x, void * restrict vy, int
958
1032
 
959
1033
  for (int l = 0; l < 8; l++) {
960
1034
  const float32x4_t v = vmulq_n_f32(vsubq_f32(srcv[l], minv0), id);
961
- const int32x4_t vi = vcvtq_s32_f32(v);
1035
+ const float32x4_t vf = vaddq_f32(v, vdupq_n_f32(0.5f)); // needed to round to nearest
1036
+ const int32x4_t vi = vcvtq_s32_f32(vf);
962
1037
 
963
1038
  y[i].qs[2*l + 0] = vgetq_lane_s32(vi, 0) | (vgetq_lane_s32(vi, 1) << 4);
964
1039
  y[i].qs[2*l + 1] = vgetq_lane_s32(vi, 2) | (vgetq_lane_s32(vi, 3) << 4);
@@ -970,9 +1045,160 @@ static void quantize_row_q4_1(const float * restrict x, void * restrict vy, int
970
1045
  #endif
971
1046
  }
972
1047
 
+ // reference implementation for deterministic creation of model files
+ static void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * restrict y, int k) {
+ assert(k % QK8_0 == 0);
+ const int nb = k / QK8_0;
+
+ for (int i = 0; i < nb; i++) {
+ float amax = 0.0f; // absolute max
+
+ for (int l = 0; l < QK8_0; l++) {
+ const float v = x[i*QK8_0 + l];
+ amax = MAX(amax, fabsf(v));
+ }
+
+ const float d = amax / ((1 << 7) - 1);
+ const float id = d ? 1.0f/d : 0.0f;
+
+ y[i].d = d;
+
+ for (int l = 0; l < QK8_0; ++l) {
+ const float v = x[i*QK8_0 + l]*id;
+ y[i].qs[l] = roundf(v);
+ }
+ }
+ }
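
As a concrete illustration of the reference above (not part of the diff): with amax = 2.54 the scale is d = 2.54/127 = 0.02, a value of 0.5 is stored as round(0.5/0.02) = 25, and it dequantizes back to 25*0.02 = 0.5, so the worst-case round-trip error per element is d/2. A minimal sketch that round-trips one 32-element block through the reference function:

    // Illustrative only; block_q8_0, QK8_0 and quantize_row_q8_0_reference are as defined above.
    static void q8_0_roundtrip_one_block(const float * x, float * y) {
        block_q8_0 b;
        quantize_row_q8_0_reference(x, &b, QK8_0);   // one block: k == QK8_0, so nb == 1
        for (int l = 0; l < QK8_0; ++l) {
            y[l] = b.qs[l]*b.d;                      // dequantize: quant times block scale
        }
    }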
1072
+
1073
+ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int k) {
1074
+ assert(k % QK8_0 == 0);
1075
+ const int nb = k / QK8_0;
1076
+
1077
+ block_q8_0 * restrict y = vy;
1078
+
1079
+ #if defined(__ARM_NEON)
1080
+ for (int i = 0; i < nb; i++) {
1081
+ float32x4_t srcv [8];
1082
+ float32x4_t asrcv[8];
1083
+ float32x4_t amaxv[8];
1084
+
1085
+ for (int l = 0; l < 8; l++) srcv[l] = vld1q_f32(x + i*32 + 4*l);
1086
+ for (int l = 0; l < 8; l++) asrcv[l] = vabsq_f32(srcv[l]);
1087
+
1088
+ for (int l = 0; l < 4; l++) amaxv[2*l] = vmaxq_f32(asrcv[2*l], asrcv[2*l+1]);
1089
+ for (int l = 0; l < 2; l++) amaxv[4*l] = vmaxq_f32(amaxv[4*l], amaxv[4*l+2]);
1090
+ for (int l = 0; l < 1; l++) amaxv[8*l] = vmaxq_f32(amaxv[8*l], amaxv[8*l+4]);
1091
+
1092
+ const float amax = vmaxvq_f32(amaxv[0]);
1093
+
1094
+ const float d = amax / ((1 << 7) - 1);
1095
+ const float id = d ? 1.0f/d : 0.0f;
1096
+
1097
+ y[i].d = d;
1098
+
1099
+ for (int l = 0; l < 8; l++) {
1100
+ const float32x4_t v = vmulq_n_f32(srcv[l], id);
1101
+ const int32x4_t vi = vcvtnq_s32_f32(v);
1102
+
1103
+ y[i].qs[4*l + 0] = vgetq_lane_s32(vi, 0);
1104
+ y[i].qs[4*l + 1] = vgetq_lane_s32(vi, 1);
1105
+ y[i].qs[4*l + 2] = vgetq_lane_s32(vi, 2);
1106
+ y[i].qs[4*l + 3] = vgetq_lane_s32(vi, 3);
1107
+ }
1108
+ }
1109
+ #elif defined(__AVX2__) || defined(__AVX__)
1110
+ for (int i = 0; i < nb; i++) {
1111
+ // Load elements into 4 AVX vectors
1112
+ __m256 v0 = _mm256_loadu_ps( x );
1113
+ __m256 v1 = _mm256_loadu_ps( x + 8 );
1114
+ __m256 v2 = _mm256_loadu_ps( x + 16 );
1115
+ __m256 v3 = _mm256_loadu_ps( x + 24 );
1116
+ x += 32;
1117
+
1118
+ // Compute max(abs(e)) for the block
1119
+ const __m256 signBit = _mm256_set1_ps( -0.0f );
1120
+ __m256 maxAbs = _mm256_andnot_ps( signBit, v0 );
1121
+ maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) );
1122
+ maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );
1123
+ maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) );
1124
+
1125
+ __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
1126
+ max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
1127
+ max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
1128
+ const float maxScalar = _mm_cvtss_f32( max4 );
1129
+
1130
+ // Quantize these floats
1131
+ const float d = maxScalar / 127.f;
1132
+ y[i].d = d;
1133
+ const float id = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f;
1134
+ const __m256 mul = _mm256_set1_ps( id );
1135
+
1136
+ // Apply the multiplier
1137
+ v0 = _mm256_mul_ps( v0, mul );
1138
+ v1 = _mm256_mul_ps( v1, mul );
1139
+ v2 = _mm256_mul_ps( v2, mul );
1140
+ v3 = _mm256_mul_ps( v3, mul );
1141
+
1142
+ // Round to nearest integer
1143
+ v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST );
1144
+ v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST );
1145
+ v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST );
1146
+ v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST );
1147
+
1148
+ // Convert floats to integers
1149
+ __m256i i0 = _mm256_cvtps_epi32( v0 );
1150
+ __m256i i1 = _mm256_cvtps_epi32( v1 );
1151
+ __m256i i2 = _mm256_cvtps_epi32( v2 );
1152
+ __m256i i3 = _mm256_cvtps_epi32( v3 );
1153
+
1154
+ #if defined(__AVX2__)
1155
+ // Convert int32 to int16
1156
+ i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15
1157
+ i2 = _mm256_packs_epi32( i2, i3 ); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31
1158
+ // Convert int16 to int8
1159
+ i0 = _mm256_packs_epi16( i0, i2 ); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31
1160
+
1161
+ // We got our precious signed bytes, but the order is now wrong
1162
+ // These AVX2 pack instructions process 16-byte pieces independently
1163
+ // The following instruction is fixing the order
1164
+ const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
1165
+ i0 = _mm256_permutevar8x32_epi32( i0, perm );
1166
+
1167
+ _mm256_storeu_si256((__m256i *)y[i].qs, i0);
1168
+ #else
1169
+ // Since we don't have in AVX some necessary functions,
1170
+ // we split the registers in half and call AVX2 analogs from SSE
1171
+ __m128i ni0 = _mm256_castsi256_si128( i0 );
1172
+ __m128i ni1 = _mm256_extractf128_si256( i0, 1);
1173
+ __m128i ni2 = _mm256_castsi256_si128( i1 );
1174
+ __m128i ni3 = _mm256_extractf128_si256( i1, 1);
1175
+ __m128i ni4 = _mm256_castsi256_si128( i2 );
1176
+ __m128i ni5 = _mm256_extractf128_si256( i2, 1);
1177
+ __m128i ni6 = _mm256_castsi256_si128( i3 );
1178
+ __m128i ni7 = _mm256_extractf128_si256( i3, 1);
1179
+
1180
+ // Convert int32 to int16
1181
+ ni0 = _mm_packs_epi32( ni0, ni1 );
1182
+ ni2 = _mm_packs_epi32( ni2, ni3 );
1183
+ ni4 = _mm_packs_epi32( ni4, ni5 );
1184
+ ni6 = _mm_packs_epi32( ni6, ni7 );
1185
+ // Convert int16 to int8
1186
+ ni0 = _mm_packs_epi16( ni0, ni2 );
1187
+ ni4 = _mm_packs_epi16( ni4, ni6 );
1188
+
1189
+ _mm_storeu_si128((__m128i *)(y[i].qs + 0), ni0);
1190
+ _mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4);
1191
+ #endif
1192
+ }
1193
+ #else
1194
+ // scalar
1195
+ quantize_row_q8_0_reference(x, y, k);
1196
+ #endif
1197
+ }
1198
+
973
1199
  static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, int k) {
974
- assert(k % QK == 0);
975
- const int nb = k / QK;
1200
+ assert(k % QK4_0 == 0);
1201
+ const int nb = k / QK4_0;
976
1202
 
977
1203
  const block_q4_0 * restrict x = vx;
978
1204
 
@@ -983,7 +1209,7 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in
983
1209
 
984
1210
  const uint8_t * restrict pp = x[i].qs;
985
1211
 
986
- for (int l = 0; l < QK; l += 32) {
1212
+ for (int l = 0; l < QK4_0; l += 32) {
987
1213
  // Load 32x4-bit integers into 32x8-bit integers
988
1214
  __m256i vx8 = bytesFromNibbles(pp+l/2);
989
1215
 
@@ -1005,7 +1231,7 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in
1005
1231
  // Scale and store
1006
1232
  for (int j = 0; j < 4; j++) {
1007
1233
  const __m256 result = _mm256_mul_ps(vf[j], d_v);
1008
- _mm256_storeu_ps(y + i * QK + l + j*8, result);
1234
+ _mm256_storeu_ps(y + i * QK4_0 + l + j*8, result);
1009
1235
  }
1010
1236
  }
1011
1237
  }
@@ -1015,7 +1241,7 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in
1015
1241
 
1016
1242
  const uint8_t * restrict pp = x[i].qs;
1017
1243
 
1018
- for (int l = 0; l < QK; l += 16) {
1244
+ for (int l = 0; l < QK4_0; l += 16) {
1019
1245
  // Load 16x4-bit integers into 8x8-bit integers
1020
1246
  const uint8x8_t v8 = vld1_u8(pp + l/2);
1021
1247
 
@@ -1054,10 +1280,10 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in
1054
1280
  const float32x4_t r3 = vmulq_f32(vf_3, vd);
1055
1281
 
1056
1282
  // Store
1057
- vst1q_f32(y + i*QK + l + 0, r0);
1058
- vst1q_f32(y + i*QK + l + 4, r1);
1059
- vst1q_f32(y + i*QK + l + 8, r2);
1060
- vst1q_f32(y + i*QK + l + 12, r3);
1283
+ vst1q_f32(y + i*QK4_0 + l + 0, r0);
1284
+ vst1q_f32(y + i*QK4_0 + l + 4, r1);
1285
+ vst1q_f32(y + i*QK4_0 + l + 8, r2);
1286
+ vst1q_f32(y + i*QK4_0 + l + 12, r3);
1061
1287
  }
1062
1288
  }
1063
1289
  #else
@@ -1067,7 +1293,7 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in
1067
1293
 
1068
1294
  const uint8_t * restrict pp = x[i].qs;
1069
1295
 
1070
- for (int l = 0; l < QK; l += 2) {
1296
+ for (int l = 0; l < QK4_0; l += 2) {
1071
1297
  const uint8_t vi = pp[l/2];
1072
1298
 
1073
1299
  const int8_t vi0 = vi & 0xf;
@@ -1078,19 +1304,19 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in
1078
1304
 
1079
1305
  //printf("d = %f, vi = %d, vi0 = %d, vi1 = %d, v0 = %f, v1 = %f\n", d, vi, vi0, vi1, v0, v1);
1080
1306
 
1081
- y[i*QK + l + 0] = v0;
1082
- y[i*QK + l + 1] = v1;
1307
+ y[i*QK4_0 + l + 0] = v0;
1308
+ y[i*QK4_0 + l + 1] = v1;
1083
1309
 
1084
- assert(!isnan(y[i*QK + l + 0]));
1085
- assert(!isnan(y[i*QK + l + 1]));
1310
+ assert(!isnan(y[i*QK4_0 + l + 0]));
1311
+ assert(!isnan(y[i*QK4_0 + l + 1]));
1086
1312
  }
1087
1313
  }
1088
1314
  #endif
1089
1315
  }
1090
1316
 
1091
1317
  static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, int k) {
1092
- assert(k % QK == 0);
1093
- const int nb = k / QK;
1318
+ assert(k % QK4_1 == 0);
1319
+ const int nb = k / QK4_1;
1094
1320
 
1095
1321
  const block_q4_1 * restrict x = vx;
1096
1322
 
@@ -1101,7 +1327,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
1101
1327
 
1102
1328
  const uint8_t * restrict pp = x[i].qs;
1103
1329
 
1104
- for (int l = 0; l < QK; l += 32) {
1330
+ for (int l = 0; l < QK4_1; l += 32) {
1105
1331
  // Load 32x4-bit integers into 32x8-bit integers
1106
1332
  __m256i vx8 = bytesFromNibbles(pp+l/2);
1107
1333
 
@@ -1120,7 +1346,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
1120
1346
  // Scale, add m and store
1121
1347
  for (int j = 0; j < 4; j++) {
1122
1348
  const __m256 result = _mm256_add_ps(_mm256_mul_ps(vf[j], d_v), d_m);
1123
- _mm256_storeu_ps(y + i * QK + l + j*8, result);
1349
+ _mm256_storeu_ps(y + i * QK4_1 + l + j*8, result);
1124
1350
  }
1125
1351
  }
1126
1352
  }
@@ -1131,7 +1357,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
1131
1357
 
1132
1358
  const uint8_t * restrict pp = x[i].qs;
1133
1359
 
1134
- for (int l = 0; l < QK; l += 16) {
1360
+ for (int l = 0; l < QK4_1; l += 16) {
1135
1361
  // Load 16x4-bit integers into 8x8-bit integers
1136
1362
  const uint8x8_t v8 = vld1_u8(pp + l/2);
1137
1363
 
@@ -1162,10 +1388,10 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
1162
1388
  const float32x4_t r3 = vmlaq_f32(vm, vf_3, vd);
1163
1389
 
1164
1390
  // Store
1165
- vst1q_f32(y + i*QK + l + 0, r0);
1166
- vst1q_f32(y + i*QK + l + 4, r1);
1167
- vst1q_f32(y + i*QK + l + 8, r2);
1168
- vst1q_f32(y + i*QK + l + 12, r3);
1391
+ vst1q_f32(y + i*QK4_1 + l + 0, r0);
1392
+ vst1q_f32(y + i*QK4_1 + l + 4, r1);
1393
+ vst1q_f32(y + i*QK4_1 + l + 8, r2);
1394
+ vst1q_f32(y + i*QK4_1 + l + 12, r3);
1169
1395
  }
1170
1396
  }
1171
1397
  #else
@@ -1175,7 +1401,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
1175
1401
 
1176
1402
  const uint8_t * restrict pp = x[i].qs;
1177
1403
 
1178
- for (int l = 0; l < QK; l += 2) {
1404
+ for (int l = 0; l < QK4_1; l += 2) {
1179
1405
  const uint8_t vi = pp[l/2];
1180
1406
 
1181
1407
  const int8_t vi0 = vi & 0xf;
@@ -1184,16 +1410,44 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
1184
1410
  const float v0 = vi0*d + m;
1185
1411
  const float v1 = vi1*d + m;
1186
1412
 
1187
- y[i*QK + l + 0] = v0;
1188
- y[i*QK + l + 1] = v1;
1413
+ y[i*QK4_1 + l + 0] = v0;
1414
+ y[i*QK4_1 + l + 1] = v1;
1189
1415
 
1190
- assert(!isnan(y[i*QK + l + 0]));
1191
- assert(!isnan(y[i*QK + l + 1]));
1416
+ assert(!isnan(y[i*QK4_1 + l + 0]));
1417
+ assert(!isnan(y[i*QK4_1 + l + 1]));
1192
1418
  }
1193
1419
  }
1194
1420
  #endif
1195
1421
  }
1196
1422
 
+ static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
+ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
+
+ static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
+ [GGML_TYPE_Q4_0] = {
+ .dequantize_row_q = dequantize_row_q4_0,
+ .quantize_row_q = quantize_row_q4_0,
+ .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_0_reference,
+ .quantize_row_q_dot = quantize_row_q8_0,
+ .vec_dot_q = ggml_vec_dot_q4_0_q8_0,
+ },
+ [GGML_TYPE_Q4_1] = {
+ .dequantize_row_q = dequantize_row_q4_1,
+ .quantize_row_q = quantize_row_q4_1,
+ .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_1_reference,
+ .quantize_row_q_dot = quantize_row_q4_1,
+ .vec_dot_q = ggml_vec_dot_q4_1,
+ },
+ // TODO: GGML_TYPE_Q8_0
+ };
+
+ // For internal test use
+ quantize_fns_t ggml_internal_get_quantize_fn(size_t i) {
+ GGML_ASSERT(i < GGML_TYPE_COUNT);
+ return quantize_fns[i];
+ }
+
+
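
For orientation, the quantize_fns table added above maps each quantized ggml type to its kernels: a row-wise dot product can quantize the activation row with quantize_row_q_dot and then call vec_dot_q against the quantized weight row. The sketch below is an assumption about how a caller would wire this up (the helper name and the sizing of wdata are hypothetical), not code from the diff:

    // Hypothetical caller; wdata must hold enough blocks for the quantized copy of act_row
    // (n/QK8_0 block_q8_0 blocks on the Q4_0 path).
    static float example_row_dot(enum ggml_type type, const void * w_row, const float * act_row, void * wdata, int n) {
        const quantize_fns_t fns = ggml_internal_get_quantize_fn(type);
        fns.quantize_row_q_dot(act_row, wdata, n);   // e.g. f32 -> q8_0 when type == GGML_TYPE_Q4_0
        float result = 0.0f;
        fns.vec_dot_q(n, &result, w_row, wdata);     // e.g. ggml_vec_dot_q4_0_q8_0
        return result;
    }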
1197
1451
  //
1198
1452
  // simd mappings
1199
1453
  //
@@ -1226,15 +1480,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
1226
1480
  #define GGML_F32x4_FMA(a, b, c) vfmaq_f32(a, b, c)
1227
1481
  #define GGML_F32x4_ADD vaddq_f32
1228
1482
  #define GGML_F32x4_MUL vmulq_f32
1229
- #if defined(__ARM_FEATURE_QRDMX)
1230
- #define GGML_F32x4_REDUCE_ONE(x) vaddvq_f32(x)
1231
- #else
1232
- #define GGML_F32x4_REDUCE_ONE(x) \
1233
- (vgetq_lane_f32(x, 0) + \
1234
- vgetq_lane_f32(x, 1) + \
1235
- vgetq_lane_f32(x, 2) + \
1236
- vgetq_lane_f32(x, 3))
1237
- #endif
1483
+ #define GGML_F32x4_REDUCE_ONE(x) vaddvq_f32(x)
1238
1484
  #define GGML_F32x4_REDUCE(res, x) \
1239
1485
  { \
1240
1486
  for (int i = 0; i < GGML_F32_ARR/2; ++i) { \
@@ -1758,34 +2004,188 @@ inline static void ggml_vec_dot_f32(const int n, float * restrict s, const float
1758
2004
  *s = sumf;
1759
2005
  }
1760
2006
 
1761
- #if __AVX512F__ && QK == 32
1762
- static inline __m512 dot_q4_0_oneblock_avx512(
2007
+ #if __AVX512F__ && QK4_0 == 32
2008
+ static inline __m512i bytes_from_q4_0_twoblocks_avx512( const __m512i blocks ) {
2009
+ // The 64 bytes of `blocks` contain two consecutive Q4_0 blocks loaded from memory:
2010
+ // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
2011
+ // |63 62 61 60 59 58 57 56 55 54 53 52 51 50 49 48 47 46 45 44 43 42 41 40 39 38 37 36 35 34 33 32|
2012
+ // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
2013
+ // | :. =_ () [] <> () Zz Yy|
2014
+ // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
2015
+ // |31 30 29 28 27 26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10 09 08 07 06 05 04 03 02 01 00|
2016
+ // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
2017
+ // |Xx Ww Vv Uu Tt Ss Rr Qq Pp Oo Nn Mm Ll Kk Jj Ii Hh Gg Ff Ee Dd Cc Bb Aa |
2018
+ // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
2019
+ //
2020
+ // Bytes 04..19 (block #0) and 24..39 (block #1) both contain 32 nibbles (4-bit unsigned integers).
2021
+ // We have exactly 64 nibbles, so we want to place each nibble into a separate byte.
2022
+ // Bytes 00..03 and 20..23 contain scales, which are irrelevant to this function.
2023
+ // Bytes 40..63 are masked when loading the data, so they are zeroed out.
2024
+ #ifdef __AVX512VBMI__
2025
+ const __m512i byte_perm = _mm512_set_epi8(
2026
+ 39, 38, 39, 38, 37, 36, 37, 36, 35, 34, 35, 34, 33, 32, 33, 32,
2027
+ 31, 30, 31, 30, 29, 28, 29, 28, 27, 26, 27, 26, 25, 24, 25, 24,
2028
+ 19, 18, 19, 18, 17, 16, 17, 16, 15, 14, 15, 14, 13, 12, 13, 12,
2029
+ 11, 10, 11, 10, 9, 8, 9, 8, 7, 6, 7, 6, 5, 4, 5, 4
2030
+ );
2031
+ const __m512i permuted = _mm512_permutexvar_epi8( byte_perm, blocks );
2032
+ // After applying VPERMB, `permuted` looks like this:
2033
+ // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
2034
+ // |63 62 61 60 59 58 57 56 55 54 53 52 51 50 49 48 47 46 45 44 43 42 41 40 39 38 37 36 35 34 33 32|
2035
+ // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
2036
+ // |:. =_ :. =_ () [] () [] <> () <> () Zz Yy Zz Yy Xx Ww Xx Ww Vv Uu Vv Uu Tt Ss Tt Ss Rr Qq Rr Qq|
2037
+ // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
2038
+ // |31 30 29 28 27 26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10 09 08 07 06 05 04 03 02 01 00|
2039
+ // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
2040
+ // |Pp Oo Pp Oo Nn Mm Nn Mm Ll Kk Ll Kk Jj Ii Jj Ii Hh Gg Hh Gg Ff Ee Ff Ee Dd Cc Dd Cc Bb Aa Bb Aa|
2041
+ // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
2042
+ #else
2043
+ const __m512i word_perm = _mm512_set_epi16(
2044
+ 19, 19, 18, 18, 17, 17, 16, 16, 15, 15, 14, 14, 13, 13, 12, 12,
2045
+ 9, 9, 8, 8, 7, 7, 6, 6, 5, 5, 4, 4, 3, 3, 2, 2
2046
+ );
2047
+ const __m512i permuted = _mm512_permutexvar_epi16( word_perm, blocks );
2048
+ // This is the fallback path for CPUs that don't support VPERMB. Since we permute 16-bit groups only,
2049
+ // VPERMB can be replaced with VPERMW. We could always use VPERMW, but at least on Tiger Lake and
2050
+ // Ice Lake VPERMW followed by a right shift is quite noticeably slower than VPERMB.
2051
+ #endif
2052
+
2053
+ // Shift every odd-numbered 16-bit group to the right by 4 bits.
2054
+ const __mmask32 shift_mask = 0xaaaaaaaa;
2055
+ const __m512i shifted = _mm512_mask_srai_epi16( permuted, shift_mask, permuted, 4 );
2056
+ // After applying VPSRAW, `shifted` looks like this (the "empty" nibbles are filled with zeroes):
2057
+ // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
2058
+ // |63 62 61 60 59 58 57 56 55 54 53 52 51 50 49 48 47 46 45 44 43 42 41 40 39 38 37 36 35 34 33 32
2059
+ // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
2060
+ // | : .= :. =_ ( )[ () [] < >( <> () Z zY Zz Yy X xW Xx Ww V vU Vv Uu T tS Tt Ss R rQ Rr Qq
2061
+ // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
2062
+ // |31 30 29 28 27 26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10 09 08 07 06 05 04 03 02 01 00|
2063
+ // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
2064
+ // | P pO Pp Oo N nM Nn Mm L lK Ll Kk J jI Jj Ii H hG Hh Gg F fE Ff Ee D dC Dd Cc B bA Bb Aa|
2065
+ // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
2066
+
2067
+ // Now we just need to zero out the higher nibble in each byte, and we're done.
2068
+ const __m512i low_nibble_mask = _mm512_set1_epi8( 0xf );
2069
+ return _mm512_and_si512( low_nibble_mask, shifted );
2070
+ // The final result looks like this:
2071
+ // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
2072
+ // |63 62 61 60 59 58 57 56 55 54 53 52 51 50 49 48 47 46 45 44 43 42 41 40 39 38 37 36 35 34 33 32|
2073
+ // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
2074
+ // | : = . _ ( [ ) ] < ( > ) Z Y z y X W x w V U v u T S t s R Q r q|
2075
+ // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
2076
+ // |31 30 29 28 27 26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10 09 08 07 06 05 04 03 02 01 00|
2077
+ // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
2078
+ // | P O p o N M n m L K l k J I j i H G h g F E f e D C d c B A b a|
2079
+ // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
2080
+ }
2081
+
2082
+ static inline __m512 dot_q4_0_twoblocks_avx512(
1763
2083
  __m512 acc,
1764
2084
  const block_q4_0 * restrict x,
1765
2085
  const block_q4_0 * restrict y,
1766
2086
  int i
1767
2087
  ) {
1768
- // Compute combined scale for the block
1769
- __m512 d = _mm512_set1_ps( x[i].d * y[i].d );
1770
-
1771
- __m256i bx = bytesFromNibbles( x[i].qs );
1772
- __m256i by = bytesFromNibbles( y[i].qs );
1773
-
1774
- // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
1775
- const __m256i off = _mm256_set1_epi8( 8 );
1776
- bx = _mm256_sub_epi8( bx, off );
1777
- by = _mm256_sub_epi8( by, off );
1778
-
1779
- // Sign-extend 16 signed bytes into int16_t
1780
- __m512i x32 = _mm512_cvtepi8_epi16( bx );
1781
- __m512i y32 = _mm512_cvtepi8_epi16( by );
1782
- // Compute products of int16_t integers, add pairwise
1783
- __m512i i64 = _mm512_madd_epi16( x32, y32 );
2088
+ // A pair of Q4_0 blocks spans 40 bytes, while an AVX-512 register has 64. The remaining 24 bytes
2089
+ // can potentially be unaddressable, so we make sure to mask them out before the load, even though
2090
+ // we don't use them at all. This might hurt the performance slightly, since the compiler is forced
2091
+ // to use e.g. `VMOVDQU64 REG, MASK, [ADDR] + VPERMB ..., REG` instead of just `VPERMB ..., [ADDR]`.
2092
+ const __mmask8 load_mask = 0x1f;
2093
+ const __m512i blocks_0 = _mm512_maskz_loadu_epi64( load_mask, &x[i] );
2094
+ const __m512i blocks_1 = _mm512_maskz_loadu_epi64( load_mask, &y[i] );
2095
+
2096
+ // We want to multiply the scales, so we interpret both registers as 16 32-bit floats:
2097
+ // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
2098
+ // | 15 | 14 | 13 | 12 | 11 | 10 | 09 | 08 | 07 | 06 | 05 | 04 | 03 | 02 | 01 | 00 |
2099
+ // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
2100
+ // blocks_0_float
2101
+ // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
2102
+ // | | | | | | | xx | xx | xx | xx | B | xx | xx | xx | xx | A |
2103
+ // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
2104
+ // blocks_1_float
2105
+ // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
2106
+ // | | | | | | | xx | xx | xx | xx | D | xx | xx | xx | xx | C |
2107
+ // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
2108
+ const __m512 blocks_0_float = _mm512_castsi512_ps( blocks_0 );
2109
+ const __m512 blocks_1_float = _mm512_castsi512_ps( blocks_1 );
2110
+ // We absolutely shouldn't touch the floats marked with `xx`: they contain some
2111
+ // random data, which might very well underflow. At least on Intel, this leads
2112
+ // to a huge penalty that can't be ignored (easily 100x or more) unless you
2113
+ // compile your code with something like `-ffast-math` to enable FTZ/DAZ flags.
2114
+ // (and ggml can't assume that you do)...
2115
+ const __mmask16 scale_mul_mask = 0x21;
2116
+ #ifdef __clang__
2117
+ // ...however, clang decides to optimize the multiplication mask away:
2118
+ // https://godbolt.org/z/P8PqdsfvW
2119
+ // gcc and MSVC do the sane thing. This horrible workaround forces clang to emit the mask.
2120
+ __m512i scales;
2121
+ __asm__(
2122
+ "vmulps %1, %2, %0%{%3%}"
2123
+ : "=v" ( scales )
2124
+ : "vm" ( blocks_0_float ), "v" ( blocks_1_float ), "Yk" ( scale_mul_mask )
2125
+ );
2126
+ #else
2127
+ const __m512 scales = _mm512_maskz_mul_ps( scale_mul_mask, blocks_0_float, blocks_1_float );
2128
+ #endif
2129
+ const __m512i scale_perm = _mm512_set_epi32(
2130
+ 5, 5, 5, 5, 5, 5, 5, 5,
2131
+ 0, 0, 0, 0, 0, 0, 0, 0
2132
+ );
2133
+ const __m512 permuted_scales = _mm512_permutexvar_ps( scale_perm, scales );
2134
+ // After VMULPS and VPERMPS, `permuted_scales` looks like this:
2135
+ // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
2136
+ // | 15 | 14 | 13 | 12 | 11 | 10 | 09 | 08 | 07 | 06 | 05 | 04 | 03 | 02 | 01 | 00 |
2137
+ // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
2138
+ // | B*D| B*D| B*D| B*D| B*D| B*D| B*D| B*D| A*C| A*C| A*C| A*C| A*C| A*C| A*C| A*C|
2139
+ // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
2140
+
2141
+ const __m512i bytes_0 = bytes_from_q4_0_twoblocks_avx512( blocks_0 );
2142
+ const __m512i bytes_1 = bytes_from_q4_0_twoblocks_avx512( blocks_1 );
2143
+
2144
+ // Now we want to compute dot products of 4-element byte vectors and store them in
2145
+ // 32-bit integers. That is (only one 4-element vector is shown for clarity):
2146
+ // +----+----+----+----+
2147
+ // ... | 03 | 02 | 01 | 00 |
2148
+ // +----+----+----+----+
2149
+ // bytes_0
2150
+ // +----+----+----+----+
2151
+ // ... | D | C | B | A |
2152
+ // +----+----+----+----+
2153
+ // bytes_1
2154
+ // +----+----+----+----+
2155
+ // ... | H | G | F | E |
2156
+ // +----+----+----+----+
2157
+ // final_res_int
2158
+ // +----+----+----+----+
2159
+ // ... | A*E+B*F+C*G+D*H |
2160
+ // +----+----+----+----+
2161
+ const __m512i plus_8 = _mm512_set1_epi8( 8 );
2162
+ const __m512i bytes_1_minus_8 = _mm512_sub_epi8( bytes_1, plus_8 );
2163
+
2164
+ #ifdef __AVX512VNNI__
2165
+ // We have VPDPBUSDS in AVX512-VNNI, which does exactly what we want, but with a catch:
2166
+ // the *left* operand is supposed to be unsigned, while Q4_0 quantization subtracts 8
2167
+ // from each nibble, so they can be negative. So, instead of `(bytes_0 - 8) * (bytes_1 - 8)`,
2168
+ // we compute `bytes_0 * (bytes_1 - 8) + bytes_1 * (-8) + 64`. VPDPBUSDS uses an accumulator,
2169
+ // which means we only need 2 instructions.
2170
+ const __m512i dot_init = _mm512_set1_epi32( 4 * 64 );
2171
+ const __m512i minus_8 = _mm512_set1_epi8( -8 );
2172
+ const __m512i prod_0 = _mm512_dpbusds_epi32( dot_init, bytes_1, minus_8 );
2173
+ const __m512i final_res_int = _mm512_dpbusds_epi32( prod_0, bytes_0, bytes_1_minus_8 );
2174
+ #else
2175
+ // As a fallback, we have VPMADDUBSW in AVX512-BW, which uses 16-bit products instead of 32-bit ones.
2176
+ // It has the same catch as VPDPBUSDS: the left operand should be unsigned.
2177
+ // This is essentially the AVX-512 version of the AVX-2 trick used by GH user Const-me
2178
+ // ref: https://gist.github.com/Const-me/4d30e1fc767ab314596e16e90f53b6f4#file-matmultest-cpp-L119
2179
+ const __m512i one = _mm512_set1_epi16( 1 );
2180
+ const __m512i prod_0 = _mm512_maddubs_epi16( bytes_0, bytes_1_minus_8 );
2181
+ const __m512i prod_1 = _mm512_maddubs_epi16( plus_8, bytes_1_minus_8 );
2182
+ const __m512i diff = _mm512_sub_epi16( prod_0, prod_1 );
2183
+ const __m512i final_res_int = _mm512_madd_epi16( diff, one );
2184
+ #endif
1784
2185
 
1785
- // Convert int32_t to float
1786
- __m512 p = _mm512_cvtepi32_ps( i64 );
1787
- // Apply the scale, and accumulate
1788
- return _mm512_fmadd_ps( d, p, acc );
2186
+ // Finally, we multiply the permuted scales and the 32-bit dot products, then accumulate.
2187
+ const __m512 final_res_float = _mm512_cvtepi32_ps( final_res_int );
2188
+ return _mm512_fmadd_ps( permuted_scales, final_res_float, acc );
1789
2189
  }
1790
2190
  #endif
1791
2191
 
@@ -1826,9 +2226,9 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t
1826
2226
  }
1827
2227
 
1828
2228
  static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
1829
- const int nb = n / QK;
2229
+ const int nb = n / QK4_0;
1830
2230
 
1831
- assert(n % QK == 0);
2231
+ assert(n % QK4_0 == 0);
1832
2232
  assert(nb % 2 == 0);
1833
2233
 
1834
2234
  const block_q4_0 * restrict x = vx;
@@ -1857,55 +2257,43 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
1857
2257
  // 4-bit -> 8-bit
1858
2258
  const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8(v0_0, m4b));
1859
2259
  const int8x16_t v1_0l = vreinterpretq_s8_u8(vandq_u8(v1_0, m4b));
1860
-
1861
2260
  const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
1862
2261
  const int8x16_t v1_0h = vreinterpretq_s8_u8(vshrq_n_u8(v1_0, 4));
1863
2262
 
1864
2263
  const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8(v0_1, m4b));
1865
2264
  const int8x16_t v1_1l = vreinterpretq_s8_u8(vandq_u8(v1_1, m4b));
1866
-
1867
2265
  const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
1868
2266
  const int8x16_t v1_1h = vreinterpretq_s8_u8(vshrq_n_u8(v1_1, 4));
1869
2267
 
1870
2268
  // sub 8
1871
2269
  const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
1872
2270
  const int8x16_t v1_0ls = vsubq_s8(v1_0l, s8b);
1873
-
1874
2271
  const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
1875
2272
  const int8x16_t v1_0hs = vsubq_s8(v1_0h, s8b);
1876
2273
 
1877
2274
  const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
1878
2275
  const int8x16_t v1_1ls = vsubq_s8(v1_1l, s8b);
1879
-
1880
2276
  const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
1881
2277
  const int8x16_t v1_1hs = vsubq_s8(v1_1h, s8b);
1882
2278
 
1883
2279
  #if defined(__ARM_FEATURE_DOTPROD)
1884
- // dot product into int16x8_t
2280
+ // dot product into int32x4_t
1885
2281
  int32x4_t p_0 = vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls);
1886
2282
  int32x4_t p_1 = vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls);
1887
2283
 
1888
2284
  p_0 = vdotq_s32(p_0, v0_0hs, v1_0hs);
1889
2285
  p_1 = vdotq_s32(p_1, v0_1hs, v1_1hs);
1890
2286
 
1891
- // scalar
1892
- #if defined(__ARM_FEATURE_QRDMX)
1893
- sum0 += x0->d * y0->d * vaddvq_s32(p_0);
1894
- sum1 += x1->d * y1->d * vaddvq_s32(p_1);
2287
+ sum0 += x0->d*y0->d*vaddvq_s32(p_0);
2288
+ sum1 += x1->d*y1->d*vaddvq_s32(p_1);
1895
2289
  #else
1896
- sum0 += x0->d * y0->d * (vgetq_lane_s32(p_0, 0) + vgetq_lane_s32(p_0, 1) + vgetq_lane_s32(p_0, 2) + vgetq_lane_s32(p_0, 3));
1897
- sum1 += x1->d * y1->d * (vgetq_lane_s32(p_1, 0) + vgetq_lane_s32(p_1, 1) + vgetq_lane_s32(p_1, 2) + vgetq_lane_s32(p_1, 3));
1898
- #endif
1899
- #else
1900
- const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls));
2290
+ const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls));
1901
2291
  const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls));
1902
-
1903
2292
  const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0hs));
1904
2293
  const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0hs));
1905
2294
 
1906
2295
  const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1ls));
1907
2296
  const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1ls));
1908
-
1909
2297
  const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1hs));
1910
2298
  const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1hs));
1911
2299
 
@@ -1918,14 +2306,8 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
1918
2306
  const int16x8_t p_0 = vaddq_s16(pl_0, ph_0);
1919
2307
  const int16x8_t p_1 = vaddq_s16(pl_1, ph_1);
1920
2308
 
1921
- // scalar
1922
- #if defined(__ARM_FEATURE_QRDMX)
1923
- sum0 += x0->d * y0->d * vaddvq_s16(p_0);
1924
- sum1 += x1->d * y1->d * vaddvq_s16(p_1);
1925
- #else
1926
- sum0 += x0->d * y0->d * (vgetq_lane_s16(p_0, 0) + vgetq_lane_s16(p_0, 1) + vgetq_lane_s16(p_0, 2) + vgetq_lane_s16(p_0, 3) + vgetq_lane_s16(p_0, 4) + vgetq_lane_s16(p_0, 5) + vgetq_lane_s16(p_0, 6) + vgetq_lane_s16(p_0, 7));
1927
- sum1 += x1->d * y1->d * (vgetq_lane_s16(p_1, 0) + vgetq_lane_s16(p_1, 1) + vgetq_lane_s16(p_1, 2) + vgetq_lane_s16(p_1, 3) + vgetq_lane_s16(p_1, 4) + vgetq_lane_s16(p_1, 5) + vgetq_lane_s16(p_1, 6) + vgetq_lane_s16(p_1, 7));
1928
- #endif
2309
+ sum0 += x0->d*y0->d*vaddvq_s16(p_0);
2310
+ sum1 += x1->d*y1->d*vaddvq_s16(p_1);
1929
2311
  #endif
1930
2312
  }
1931
2313
 
@@ -1935,25 +2317,26 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
1935
2317
  __m512 acc0 = _mm512_setzero_ps();
1936
2318
  __m512 acc1 = _mm512_setzero_ps();
1937
2319
 
1938
- const int superblock_size = 8;
2320
+ const int superblock_size = 16;
2321
+
1939
2322
  const int superblock_count = nb / superblock_size;
1940
2323
 
1941
2324
  for (int superblock_ix = 0; superblock_ix < superblock_count; superblock_ix += 1) {
1942
2325
  int i = superblock_ix * superblock_size;
1943
2326
 
1944
- acc0 = dot_q4_0_oneblock_avx512( acc0, x, y, i+0 );
1945
- acc1 = dot_q4_0_oneblock_avx512( acc1, x, y, i+1 );
1946
- acc0 = dot_q4_0_oneblock_avx512( acc0, x, y, i+2 );
1947
- acc1 = dot_q4_0_oneblock_avx512( acc1, x, y, i+3 );
1948
- acc0 = dot_q4_0_oneblock_avx512( acc0, x, y, i+4 );
1949
- acc1 = dot_q4_0_oneblock_avx512( acc1, x, y, i+5 );
1950
- acc0 = dot_q4_0_oneblock_avx512( acc0, x, y, i+6 );
1951
- acc1 = dot_q4_0_oneblock_avx512( acc1, x, y, i+7 );
2327
+ acc0 = dot_q4_0_twoblocks_avx512( acc0, x, y, i+0 );
2328
+ acc1 = dot_q4_0_twoblocks_avx512( acc1, x, y, i+2 );
2329
+ acc0 = dot_q4_0_twoblocks_avx512( acc0, x, y, i+4 );
2330
+ acc1 = dot_q4_0_twoblocks_avx512( acc1, x, y, i+6 );
2331
+ acc0 = dot_q4_0_twoblocks_avx512( acc0, x, y, i+8 );
2332
+ acc1 = dot_q4_0_twoblocks_avx512( acc1, x, y, i+10 );
2333
+ acc0 = dot_q4_0_twoblocks_avx512( acc0, x, y, i+12 );
2334
+ acc1 = dot_q4_0_twoblocks_avx512( acc1, x, y, i+14 );
1952
2335
  }
1953
2336
 
1954
2337
  // Remainders
1955
- for (int i = superblock_count * superblock_size; i < nb; ++i) {
1956
- acc0 = dot_q4_0_oneblock_avx512( acc0, x, y, i );
2338
+ for (int i = superblock_count * superblock_size; i < nb; i += 2) {
2339
+ acc0 = dot_q4_0_twoblocks_avx512( acc0, x, y, i );
1957
2340
  }
1958
2341
 
1959
2342
  // Horizontal sum of all lanes of the accumulator
@@ -1962,7 +2345,7 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
1962
2345
  // Initialize accumulator with zeros
1963
2346
  __m256 acc = _mm256_setzero_ps();
1964
2347
 
1965
- /* Prepare the constants we will need during execution */
2348
+ /* Prepare the constants we will need during execution */
1966
2349
  const __m256i lowMask = _mm256_set1_epi8( 0xF );
1967
2350
  const __m256i offset_8 = _mm256_set1_epi16( 8 );
1968
2351
 
@@ -1972,61 +2355,59 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
1972
2355
 
1973
2356
  // Main loop
1974
2357
  for (int i = 0; i < nb; i+=UNROLL_COUNT) {
1975
-
1976
- // This loop will be unrolled by the compiler
2358
+ // This loop will be unrolled by the compiler
1977
2359
  for (int u=0;u<UNROLL_COUNT;u++) {
1978
- /* Compute combined scale for the block */
1979
- const __m256 scale = _mm256_mul_ps(
1980
- _mm256_broadcast_ss( &x[i+u].d ),
1981
- _mm256_broadcast_ss( &y[i+u].d ) );
1982
-
1983
- /* get input from x
1984
- Input: 32 Nibbles (16 bytes) at *x[i+u]
1985
- Output: 2 vectors with 16 values of type int16_t (x_high_q, x_low_q) */
1986
-
1987
- /* Load 16 bytes from memory */
1988
- const __m128i tmp_x = _mm_loadu_si128( ( const __m128i* ) x[i+u].qs);
1989
- /* Expand bytes into uint16_t values */
1990
- const __m256i bytes_x = _mm256_cvtepu8_epi16(tmp_x);
2360
+ /* Compute combined scale for the block */
2361
+ const __m256 scale = _mm256_mul_ps(
2362
+ _mm256_broadcast_ss( &x[i+u].d ),
2363
+ _mm256_broadcast_ss( &y[i+u].d ) );
2364
+
2365
+ /* get input from x
2366
+ Input: 32 Nibbles (16 bytes) at *x[i+u]
2367
+ Output: 2 vectors with 16 values of type int16_t (x_high_q, x_low_q) */
2368
+
2369
+ /* Load 16 bytes from memory */
2370
+ const __m128i tmp_x = _mm_loadu_si128( ( const __m128i* ) x[i+u].qs);
2371
+ /* Expand bytes into uint16_t values */
2372
+ const __m256i bytes_x = _mm256_cvtepu8_epi16(tmp_x);
1991
2373
  /* Unpack values into individual bytes */
1992
2374
  __m256i x_low_q = _mm256_and_si256( lowMask, bytes_x );
1993
2375
  const __m256i pre_shift_x_high_q = _mm256_andnot_si256( lowMask, bytes_x );
1994
- __m256i x_high_q = _mm256_srli_epi16( pre_shift_x_high_q, 4 );
2376
+ __m256i x_high_q = _mm256_srli_epi16( pre_shift_x_high_q, 4 );
1995
2377
  /* Now we have two vectors with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval. */
1996
- x_high_q = _mm256_sub_epi16( x_high_q, offset_8 );
1997
- x_low_q = _mm256_sub_epi16( x_low_q, offset_8 );
2378
+ x_high_q = _mm256_sub_epi16( x_high_q, offset_8 );
2379
+ x_low_q = _mm256_sub_epi16( x_low_q, offset_8 );
1998
2380
 
1999
- /* get input from y
2000
- Input: 32 Nibbles (16 bytes) at *y[i+u]
2001
- Output: 2 vectors with 16 values of type int16_t (y_high_q, y_low_q) */
2381
+ /* get input from y
2382
+ Input: 32 Nibbles (16 bytes) at *y[i+u]
2383
+ Output: 2 vectors with 16 values of type int16_t (y_high_q, y_low_q) */
2002
2384
 
2003
- /* Load 16 bytes from memory */
2004
- const __m128i tmp_y = _mm_loadu_si128( (const __m128i* ) y[i+u].qs);
2005
- /* Expand bytes into uint16_t values */
2006
- const __m256i bytes_y = _mm256_cvtepu8_epi16(tmp_y);
2385
+ /* Load 16 bytes from memory */
2386
+ const __m128i tmp_y = _mm_loadu_si128( (const __m128i* ) y[i+u].qs);
2387
+ /* Expand bytes into uint16_t values */
2388
+ const __m256i bytes_y = _mm256_cvtepu8_epi16(tmp_y);
2007
2389
  /* Unpack values into individual bytes */
2008
- const __m256i pre_shift_y_high_q = _mm256_andnot_si256( lowMask, bytes_y );
2009
- __m256i y_high_q = _mm256_srli_epi16( pre_shift_y_high_q, 4 );
2010
- __m256i y_low_q = _mm256_and_si256( lowMask, bytes_y );
2390
+ const __m256i pre_shift_y_high_q = _mm256_andnot_si256( lowMask, bytes_y );
2391
+ __m256i y_high_q = _mm256_srli_epi16( pre_shift_y_high_q, 4 );
2392
+ __m256i y_low_q = _mm256_and_si256( lowMask, bytes_y );
2011
2393
  /* Now we have two vectors with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval. */
2012
- y_high_q = _mm256_sub_epi16( y_high_q, offset_8 );
2013
- y_low_q = _mm256_sub_epi16( y_low_q, offset_8 );
2394
+ y_high_q = _mm256_sub_epi16( y_high_q, offset_8 );
2395
+ y_low_q = _mm256_sub_epi16( y_low_q, offset_8 );
2014
2396
 
2015
- /* Compute products of int16_t integers, add pairwise, store as int32_t */
2016
- __m256i xy_high_q = _mm256_madd_epi16( x_high_q, y_high_q );
2017
- __m256i xy_low_q = _mm256_madd_epi16( x_low_q, y_low_q );
2397
+ /* Compute products of int16_t integers, add pairwise, store as int32_t */
2398
+ __m256i xy_high_q = _mm256_madd_epi16( x_high_q, y_high_q );
2399
+ __m256i xy_low_q = _mm256_madd_epi16( x_low_q, y_low_q );
2018
2400
 
2019
- /* Accumulate the products of int32_t integers -> we now have a vector of 8 int_32t */
2020
- __m256i xy_q = _mm256_add_epi32( xy_high_q, xy_low_q );
2401
+ /* Accumulate the products of int32_t integers -> we now have a vector of 8 int_32t */
2402
+ __m256i xy_q = _mm256_add_epi32( xy_high_q, xy_low_q );
2021
2403
 
2022
- /* Convert to vectore of 8 int32_t to 8 floats */
2023
- __m256 q = _mm256_cvtepi32_ps( xy_q );
2404
+ /* Convert to vectore of 8 int32_t to 8 floats */
2405
+ __m256 q = _mm256_cvtepi32_ps( xy_q );
2024
2406
 
2025
- /* Multiply q with scale and accumulate */
2026
- acc = _mm256_fmadd_ps( scale, q, acc );
2407
+ /* Multiply q with scale and accumulate */
2408
+ acc = _mm256_fmadd_ps( scale, q, acc );
2027
2409
  }
2028
-
2029
- }
2410
+ }
2030
2411
 
2031
2412
  // Return horizontal sum of the acc vector
2032
2413
  __m128 res = _mm256_extractf128_ps( acc, 1 );
@@ -2087,18 +2468,18 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
2087
2468
  float sum1 = 0.0f;
2088
2469
 
2089
2470
  for (int i = 0; i < nb; i += 2) {
2090
- const block_q4_0 * restrict x0 = &px[i + 0];
2091
- const block_q4_0 * restrict y0 = &py[i + 0];
2092
- const block_q4_0 * restrict x1 = &px[i + 1];
2093
- const block_q4_0 * restrict y1 = &py[i + 1];
2471
+ const block_q4_0 * restrict x0 = &x[i + 0];
2472
+ const block_q4_0 * restrict y0 = &y[i + 0];
2473
+ const block_q4_0 * restrict x1 = &x[i + 1];
2474
+ const block_q4_0 * restrict y1 = &y[i + 1];
2094
2475
 
2095
2476
  const v128_t m4b = wasm_u8x16_splat(0xf);
2096
2477
  const v128_t s8b = wasm_i8x16_splat(0x8);
2097
2478
 
2098
- const v128_t v0_0 = wasm_v128_load(x0.qs);
2099
- const v128_t v0_1 = wasm_v128_load(y0.qs);
2100
- const v128_t v1_0 = wasm_v128_load(x1.qs);
2101
- const v128_t v1_1 = wasm_v128_load(y1.qs);
2479
+ const v128_t v0_0 = wasm_v128_load(x0->qs);
2480
+ const v128_t v0_1 = wasm_v128_load(y0->qs);
2481
+ const v128_t v1_0 = wasm_v128_load(x1->qs);
2482
+ const v128_t v1_1 = wasm_v128_load(y1->qs);
2102
2483
 
2103
2484
  // 4-bit -> 8-bit
2104
2485
  const v128_t v0_0l = wasm_v128_and(v0_0, m4b);
@@ -2170,18 +2551,20 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
2170
2551
  const uint8_t * restrict p0 = x[i].qs;
2171
2552
  const uint8_t * restrict p1 = y[i].qs;
2172
2553
 
2173
- for (int j = 0; j < QK/2; j++) {
2554
+ int sumi = 0;
2555
+ for (int j = 0; j < QK4_0/2; j++) {
2174
2556
  const uint8_t v0 = p0[j];
2175
2557
  const uint8_t v1 = p1[j];
2176
2558
 
2177
- const float f0 = d0*((int8_t) (v0 & 0xf) - 8);
2178
- const float f1 = d0*((int8_t) (v0 >> 4) - 8);
2559
+ const int i0 = (v0 & 0xf) - 8;
2560
+ const int i1 = (v0 >> 4) - 8;
2179
2561
 
2180
- const float f2 = d1*((int8_t) (v1 & 0xf) - 8);
2181
- const float f3 = d1*((int8_t) (v1 >> 4) - 8);
2562
+ const int i2 = (v1 & 0xf) - 8;
2563
+ const int i3 = (v1 >> 4) - 8;
2182
2564
 
2183
- sumf += f0*f2 + f1*f3;
2565
+ sumi += i0*i2 + i1*i3;
2184
2566
  }
2567
+ sumf += d0 * d1 * sumi;
2185
2568
  }
2186
2569
  #endif
2187
2570
 
@@ -2189,7 +2572,7 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
2189
2572
  }
2190
2573
 
2191
2574
  static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
2192
- const int nb = n / QK;
2575
+ const int nb = n / QK4_1;
2193
2576
 
2194
2577
  const block_q4_1 * restrict x = vx;
2195
2578
  const block_q4_1 * restrict y = vy;
@@ -2266,46 +2649,81 @@ static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * rest
2266
2649
  res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
2267
2650
  res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
2268
2651
 
2269
- sumf = _mm_cvtss_f32( res ) + acc_offset * QK;
2652
+ sumf = _mm_cvtss_f32( res ) + acc_offset * QK4_1;
2270
2653
  #elif defined(__ARM_NEON)
2271
2654
  float sum00 = 0.0f;
2272
2655
  float sum01 = 0.0f;
2273
2656
  float sum10 = 0.0f;
2274
2657
  float sum11 = 0.0f;
2275
2658
 
2276
- for (int i = 0; i < nb; ++i) {
2659
+ for (int i = 0; i < nb; i += 2) {
2277
2660
  const block_q4_1 * restrict x0 = &x[i + 0];
2278
2661
  const block_q4_1 * restrict y0 = &y[i + 0];
2662
+ const block_q4_1 * restrict x1 = &x[i + 1];
2663
+ const block_q4_1 * restrict y1 = &y[i + 1];
2279
2664
 
2280
2665
  const uint8x16_t m4b = vdupq_n_u8(0xf);
2281
2666
 
2282
2667
  const uint8x16_t v0_0 = vld1q_u8(x0->qs);
2283
2668
  const uint8x16_t v1_0 = vld1q_u8(y0->qs);
2669
+ const uint8x16_t v0_1 = vld1q_u8(x1->qs);
2670
+ const uint8x16_t v1_1 = vld1q_u8(y1->qs);
2284
2671
 
2285
- // and with 0xf
2672
+ // 4-bit -> 8-bit
2286
2673
  const uint8x16_t v0_0l = vandq_u8(v0_0, m4b);
2287
2674
  const uint8x16_t v1_0l = vandq_u8(v1_0, m4b);
2288
-
2289
2675
  const uint8x16_t v0_0h = vshrq_n_u8(v0_0, 4);
2290
2676
  const uint8x16_t v1_0h = vshrq_n_u8(v1_0, 4);
2291
2677
 
2292
- // dot product into uint16x8_t
2678
+ const uint8x16_t v0_1l = vandq_u8(v0_1, m4b);
2679
+ const uint8x16_t v1_1l = vandq_u8(v1_1, m4b);
2680
+ const uint8x16_t v0_1h = vshrq_n_u8(v0_1, 4);
2681
+ const uint8x16_t v1_1h = vshrq_n_u8(v1_1, 4);
2682
+
2683
+ sum00 += x0->m*y0->m;
2684
+ sum01 += y0->m*x0->d*((uint16_t)vaddvq_u8(v0_0l) + (uint16_t)vaddvq_u8(v0_0h));
2685
+ sum10 += x0->m*y0->d*((uint16_t)vaddvq_u8(v1_0l) + (uint16_t)vaddvq_u8(v1_0h));
2686
+
2687
+ sum00 += x1->m*y1->m;
2688
+ sum01 += y1->m*x1->d*((uint16_t)vaddvq_u8(v0_1l) + (uint16_t)vaddvq_u8(v0_1h));
2689
+ sum10 += x1->m*y1->d*((uint16_t)vaddvq_u8(v1_1l) + (uint16_t)vaddvq_u8(v1_1h));
2690
+
2691
+ #if defined(__ARM_FEATURE_DOTPROD)
2692
+ // dot product into int32x4_t
2693
+ uint32x4_t p_0 = vdotq_u32(vdupq_n_u32(0), v0_0l, v1_0l);
2694
+ uint32x4_t p_1 = vdotq_u32(vdupq_n_u32(0), v0_1l, v1_1l);
2695
+
2696
+ p_0 = vdotq_u32(p_0, v0_0h, v1_0h);
2697
+ p_1 = vdotq_u32(p_1, v0_1h, v1_1h);
2698
+
2699
+ sum11 += x0->d*y0->d*vaddvq_u32(p_0);
2700
+ sum11 += x1->d*y1->d*vaddvq_u32(p_1);
2701
+ #else
2293
2702
  const uint16x8_t pl0l = vmull_u8(vget_low_u8 (v0_0l), vget_low_u8 (v1_0l));
2294
2703
  const uint16x8_t pl0h = vmull_u8(vget_high_u8(v0_0l), vget_high_u8(v1_0l));
2295
-
2296
2704
  const uint16x8_t ph0l = vmull_u8(vget_low_u8 (v0_0h), vget_low_u8 (v1_0h));
2297
2705
  const uint16x8_t ph0h = vmull_u8(vget_high_u8(v0_0h), vget_high_u8(v1_0h));
2298
2706
 
2299
- const uint16x8_t pl0 = vaddq_u16(pl0l, pl0h);
2300
- const uint16x8_t ph0 = vaddq_u16(ph0l, ph0h);
2707
+ const uint16x8_t pl1l = vmull_u8(vget_low_u8 (v0_1l), vget_low_u8 (v1_1l));
2708
+ const uint16x8_t pl1h = vmull_u8(vget_high_u8(v0_1l), vget_high_u8(v1_1l));
2709
+ const uint16x8_t ph1l = vmull_u8(vget_low_u8 (v0_1h), vget_low_u8 (v1_1h));
2710
+ const uint16x8_t ph1h = vmull_u8(vget_high_u8(v0_1h), vget_high_u8(v1_1h));
2301
2711
 
2302
- sum00 += x0->m*y0->m;
2303
- sum01 += y0->m*x0->d*(vaddvq_u8(v0_0l) + vaddvq_u8(v0_0h));
2304
- sum10 += x0->m*y0->d*(vaddvq_u8(v1_0l) + vaddvq_u8(v1_0h));
2305
- sum11 += x0->d*y0->d*vaddvq_u16(vaddq_u16(pl0, ph0));
2712
+ const uint16x8_t pl_0 = vaddq_u16(pl0l, pl0h);
2713
+ const uint16x8_t ph_0 = vaddq_u16(ph0l, ph0h);
2714
+
2715
+ const uint16x8_t pl_1 = vaddq_u16(pl1l, pl1h);
2716
+ const uint16x8_t ph_1 = vaddq_u16(ph1l, ph1h);
2717
+
2718
+ const uint16x8_t p_0 = vaddq_u16(pl_0, ph_0);
2719
+ const uint16x8_t p_1 = vaddq_u16(pl_1, ph_1);
2720
+
2721
+ sum11 += x0->d*y0->d*vaddvq_u16(p_0);
2722
+ sum11 += x1->d*y1->d*vaddvq_u16(p_1);
2723
+ #endif
2306
2724
  }
2307
2725
 
2308
- sumf = QK*sum00 + sum01 + sum10 + sum11;
2726
+ sumf = QK4_1*sum00 + sum01 + sum10 + sum11;
2309
2727
  #else
2310
2728
  // scalar
2311
2729
  for (int i = 0; i < nb; i++) {
@@ -2318,7 +2736,7 @@ static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * rest
2318
2736
  const uint8_t * restrict p0 = x[i].qs;
2319
2737
  const uint8_t * restrict p1 = y[i].qs;
2320
2738
 
2321
- for (int j = 0; j < QK/2; j++) {
2739
+ for (int j = 0; j < QK4_1/2; j++) {
2322
2740
  const uint8_t v0 = p0[j];
2323
2741
  const uint8_t v1 = p1[j];
2324
2742
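The four accumulators in the NEON path above (sum00..sum11) come from expanding each affine-quantized value, x_i = d_x q_{x,i} + m_x and y_i = d_y q_{y,i} + m_y, over one block of QK4_1 elements:

    \sum_i x_i y_i
      = QK_{4\_1}\, m_x m_y
      + d_x m_y \sum_i q_{x,i}
      + d_y m_x \sum_i q_{y,i}
      + d_x d_y \sum_i q_{x,i} q_{y,i}

sum00 collects the m_x m_y products (hence the final QK4_1 factor in sumf), sum01 and sum10 collect the two cross terms via the byte-sum reductions, and sum11 collects the quantized dot product itself.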
 
@@ -2336,21 +2754,224 @@ static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * rest
2336
2754
  *s = sumf;
2337
2755
  }
2338
2756
 
2339
- // compute GGML_VEC_DOT_UNROLL dot products at once
2340
- // xs - x row stride in bytes
2341
- inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * restrict s, void * restrict xv, ggml_fp16_t * restrict y) {
2342
- ggml_float sumf[GGML_VEC_DOT_UNROLL] = { 0.0 };
2757
+ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
2758
+ const int nb = n / QK8_0;
2343
2759
 
2344
- ggml_fp16_t * restrict x[GGML_VEC_DOT_UNROLL];
2760
+ assert(n % QK8_0 == 0);
2761
+ assert(nb % 2 == 0);
2345
2762
 
2346
- for (int i = 0; i < GGML_VEC_DOT_UNROLL; ++i) {
2347
- x[i] = (ggml_fp16_t *) ((char *) xv + i*xs);
2348
- }
2763
+ const block_q4_0 * restrict x = vx;
2764
+ const block_q8_0 * restrict y = vy;
2349
2765
 
2350
- #if defined(GGML_SIMD)
2351
- const int np = (n & ~(GGML_F16_STEP - 1));
2766
+ float sumf = 0.0;
2352
2767
 
2353
- GGML_F16_VEC sum[GGML_VEC_DOT_UNROLL][GGML_F16_ARR] = { { GGML_F16_VEC_ZERO } };
2768
+ #if defined(__ARM_NEON)
2769
+ float sum0 = 0.0f;
2770
+ float sum1 = 0.0f;
2771
+
2772
+ for (int i = 0; i < nb; i += 2) {
2773
+ const block_q4_0 * restrict x0 = &x[i + 0];
2774
+ const block_q4_0 * restrict x1 = &x[i + 1];
2775
+ const block_q8_0 * restrict y0 = &y[i + 0];
2776
+ const block_q8_0 * restrict y1 = &y[i + 1];
2777
+
2778
+ const uint8x16_t m4b = vdupq_n_u8(0xf);
2779
+ const int8x16_t s8b = vdupq_n_s8(0x8);
2780
+
2781
+ const uint8x16_t v0_0 = vld1q_u8(x0->qs);
2782
+ const uint8x16_t v0_1 = vld1q_u8(x1->qs);
2783
+
2784
+ // 4-bit -> 8-bit
2785
+ const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
2786
+ const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
2787
+ const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
2788
+ const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
2789
+
2790
+ // sub 8
2791
+ const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
2792
+ const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
2793
+ const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
2794
+ const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
2795
+
2796
+ // load y
2797
+ const int8x16_t v1_0l = vld1q_s8(y0->qs);
2798
+ const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
2799
+ const int8x16_t v1_1l = vld1q_s8(y1->qs);
2800
+ const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
2801
+
2802
+ // interleave
2803
+ const int8x16_t v1_0ls = vuzp1q_s8(v1_0l, v1_0h);
2804
+ const int8x16_t v1_0hs = vuzp2q_s8(v1_0l, v1_0h);
2805
+ const int8x16_t v1_1ls = vuzp1q_s8(v1_1l, v1_1h);
2806
+ const int8x16_t v1_1hs = vuzp2q_s8(v1_1l, v1_1h);
2807
+
2808
+ #if defined(__ARM_FEATURE_DOTPROD)
2809
+ // dot product into int32x4_t
2810
+ int32x4_t p_0 = vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls);
2811
+ int32x4_t p_1 = vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls);
2812
+
2813
+ p_0 = vdotq_s32(p_0, v0_0hs, v1_0hs);
2814
+ p_1 = vdotq_s32(p_1, v0_1hs, v1_1hs);
2815
+
2816
+ sum0 += x0->d*y0->d*vaddvq_s32(p_0);
2817
+ sum1 += x1->d*y1->d*vaddvq_s32(p_1);
2818
+ #else
2819
+ const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls));
2820
+ const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls));
2821
+ const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0hs));
2822
+ const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0hs));
2823
+
2824
+ const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1ls));
2825
+ const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1ls));
2826
+ const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1hs));
2827
+ const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1hs));
2828
+
2829
+ const int16x8_t pl_0 = vaddq_s16(pl0l, pl0h);
2830
+ const int16x8_t ph_0 = vaddq_s16(ph0l, ph0h);
2831
+
2832
+ const int16x8_t pl_1 = vaddq_s16(pl1l, pl1h);
2833
+ const int16x8_t ph_1 = vaddq_s16(ph1l, ph1h);
2834
+
2835
+ const int16x8_t p_0 = vaddq_s16(pl_0, ph_0);
2836
+ const int16x8_t p_1 = vaddq_s16(pl_1, ph_1);
2837
+
2838
+ sum0 += x0->d*y0->d*vaddvq_s16(p_0);
2839
+ sum1 += x1->d*y1->d*vaddvq_s16(p_1);
2840
+ #endif
2841
+ }
2842
+
2843
+ sumf = sum0 + sum1;
2844
+ #elif defined(__AVX2__)
2845
+ // Initialize accumulator with zeros
2846
+ __m256 acc = _mm256_setzero_ps();
2847
+
2848
+ // Main loop
2849
+ for (int i = 0; i < nb; ++i) {
2850
+ /* Compute combined scale for the block */
2851
+ const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );
2852
+
2853
+ __m256i bx = bytesFromNibbles(x[i].qs);
2854
+
2855
+ // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
2856
+ const __m256i off = _mm256_set1_epi8( 8 );
2857
+ bx = _mm256_sub_epi8( bx, off );
2858
+
2859
+ __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
2860
+
2861
+ // Get absolute values of x vectors
2862
+ const __m256i ax = _mm256_sign_epi8(bx, bx);
2863
+
2864
+ // Sign the values of the y vectors
2865
+ const __m256i sy = _mm256_sign_epi8(by, bx);
2866
+
2867
+ // Perform multiplication and create 16-bit values
2868
+ const __m256i dot = _mm256_maddubs_epi16(ax, sy);
2869
+
2870
+ const __m256i ones = _mm256_set1_epi16(1);
2871
+ __m256i xy_q = _mm256_madd_epi16(ones, dot);
2872
+
2873
+ /* Convert the vector of 8 int32_t to 8 floats */
2874
+ __m256 q = _mm256_cvtepi32_ps( xy_q );
2875
+
2876
+ /* Multiply q with scale and accumulate */
2877
+ acc = _mm256_fmadd_ps( d, q, acc );
2878
+ }
2879
+
2880
+ // Return horizontal sum of the acc vector
2881
+ __m128 res = _mm256_extractf128_ps( acc, 1 );
2882
+ res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
2883
+ res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
2884
+ res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
2885
+
2886
+ sumf = _mm_cvtss_f32( res );
2887
+ #elif defined(__AVX__)
2888
+ // Initialize accumulator with zeros
2889
+ __m256 acc = _mm256_setzero_ps();
2890
+
2891
+ // Main loop
2892
+ for (int i = 0; i < nb; ++i) {
2893
+ // Compute combined scale for the block
2894
+ const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );
2895
+
2896
+ __m128i i32[2];
2897
+ for (int j = 0; j < 2; ++j) {
2898
+ // Load 8 bytes, and unpack 4 bit fields into bytes, making 16 bytes
2899
+ __m128i bx = bytesFromNibbles( x[i].qs + 8*j );
2900
+ __m128i by = _mm_loadu_si128((const __m128i *)(y[i].qs + 16*j));
2901
+
2902
+ // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
2903
+ const __m128i off = _mm_set1_epi8( 8 );
2904
+ bx = _mm_sub_epi8( bx, off );
2905
+
2906
+ // Get absolute values of x vectors
2907
+ const __m128i ax = _mm_sign_epi8(bx, bx);
2908
+
2909
+ // Sign the values of the y vectors
2910
+ const __m128i sy = _mm_sign_epi8(by, bx);
2911
+
2912
+ // Perform multiplication and create 16-bit values
2913
+ const __m128i dot = _mm_maddubs_epi16(ax, sy);
2914
+
2915
+ const __m128i ones = _mm_set1_epi16(1);
2916
+ i32[j] = _mm_madd_epi16(ones, dot);
2917
+ }
2918
+
2919
+ // Convert int32_t to float
2920
+ __m256 p = _mm256_cvtepi32_ps( _mm256_set_m128i( i32[0], i32[1] ));
2921
+ // Apply the scale, and accumulate
2922
+ acc = _mm256_add_ps(_mm256_mul_ps( d, p ), acc);
2923
+ }
2924
+
2925
+ // Return horizontal sum of the acc vector
2926
+ __m128 res = _mm256_extractf128_ps( acc, 1 );
2927
+ res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
2928
+ res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
2929
+ res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
2930
+
2931
+ sumf = _mm_cvtss_f32( res );
2932
+ #else
2933
+ // scalar
2934
+ for (int i = 0; i < nb; i++) {
2935
+ const float d0 = x[i].d;
2936
+ const float d1 = y[i].d;
2937
+
2938
+ const uint8_t * restrict p0 = x[i].qs;
2939
+ const int8_t * restrict p1 = y[i].qs;
2940
+
2941
+ int sumi = 0;
2942
+ for (int j = 0; j < QK8_0/2; j++) {
2943
+ const uint8_t v0 = p0[j];
2944
+
2945
+ const int i0 = (int8_t) (v0 & 0xf) - 8;
2946
+ const int i1 = (int8_t) (v0 >> 4) - 8;
2947
+
2948
+ const int i2 = p1[2*j + 0];
2949
+ const int i3 = p1[2*j + 1];
2950
+
2951
+ sumi += i0*i2 + i1*i3;
2952
+ }
2953
+ sumf += d0*d1*sumi;
2954
+ }
2955
+ #endif
2956
+
2957
+ *s = sumf;
2958
+ }
2959
+
2960
+ // compute GGML_VEC_DOT_UNROLL dot products at once
2961
+ // xs - x row stride in bytes
2962
+ inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * restrict s, void * restrict xv, ggml_fp16_t * restrict y) {
2963
+ ggml_float sumf[GGML_VEC_DOT_UNROLL] = { 0.0 };
2964
+
2965
+ ggml_fp16_t * restrict x[GGML_VEC_DOT_UNROLL];
2966
+
2967
+ for (int i = 0; i < GGML_VEC_DOT_UNROLL; ++i) {
2968
+ x[i] = (ggml_fp16_t *) ((char *) xv + i*xs);
2969
+ }
2970
+
2971
+ #if defined(GGML_SIMD)
2972
+ const int np = (n & ~(GGML_F16_STEP - 1));
2973
+
2974
+ GGML_F16_VEC sum[GGML_VEC_DOT_UNROLL][GGML_F16_ARR] = { { GGML_F16_VEC_ZERO } };
2354
2975
 
2355
2976
  GGML_F16_VEC ax[GGML_F16_ARR];
2356
2977
  GGML_F16_VEC ay[GGML_F16_ARR];
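In the AVX2 path of the new ggml_vec_dot_q4_0_q8_0 above, _mm256_maddubs_epi16 multiplies unsigned bytes by signed bytes, so the code first strips the sign from the 4-bit side (ax) and transfers it onto the 8-bit side (sy). A scalar model of the identity this relies on (a sketch, not part of the file; for q4_0 the x values stay in [-8, 7]):

    #include <stdint.h>
    #include <assert.h>

    // |x| * sign(y, x) == x * y for every byte pair:
    // sign(y, x) negates y where x < 0 and zeroes it where x == 0.
    static int byte_prod_via_sign_trick(int8_t x, int8_t y) {
        const int ax = x < 0 ? -x : x;                  // _mm256_sign_epi8(bx, bx)
        const int sy = x < 0 ? -y : (x == 0 ? 0 : y);   // _mm256_sign_epi8(by, bx)
        return ax * sy;                                  // == x * y
    }

    int main(void) {
        for (int x = -8; x <= 7; ++x)
            for (int y = -128; y <= 127; ++y)
                assert(byte_prod_via_sign_trick((int8_t)x, (int8_t)y) == x*y);
        return 0;
    }

maddubs then adds adjacent byte products into 16-bit lanes, and the madd against a vector of ones widens those to 32-bit sums before the combined d scale is applied.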
@@ -2578,29 +3199,41 @@ inline static void ggml_vec_norm_inv_f32(const int n, float * s, const float * x
2578
3199
  //
2579
3200
 
2580
3201
  static const int GGML_BLCK_SIZE[GGML_TYPE_COUNT] = {
2581
- QK,
2582
- QK,
2583
- 1,
2584
- 1,
2585
- 1,
2586
- 1,
2587
- 1,
3202
+ [GGML_TYPE_F32] = 1,
3203
+ [GGML_TYPE_F16] = 1,
3204
+ [GGML_TYPE_Q4_0] = QK4_0,
3205
+ [GGML_TYPE_Q4_1] = QK4_1,
3206
+ [GGML_TYPE_Q8_0] = QK8_0,
3207
+ [GGML_TYPE_I8] = 1,
3208
+ [GGML_TYPE_I16] = 1,
3209
+ [GGML_TYPE_I32] = 1,
2588
3210
  };
2589
-
2590
- static_assert(GGML_TYPE_COUNT == 7, "GGML_TYPE_COUNT != 5");
3211
+ static_assert(GGML_TYPE_COUNT == 8, "GGML_BLCK_SIZE is outdated");
2591
3212
 
2592
3213
  static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = {
2593
- sizeof(block_q4_0),
2594
- sizeof(block_q4_1),
2595
- sizeof(int8_t ),
2596
- sizeof(int16_t),
2597
- sizeof(int32_t),
2598
- sizeof(ggml_fp16_t),
2599
- sizeof(float ),
3214
+ [GGML_TYPE_F32] = sizeof(float),
3215
+ [GGML_TYPE_F16] = sizeof(ggml_fp16_t),
3216
+ [GGML_TYPE_Q4_0] = sizeof(block_q4_0),
3217
+ [GGML_TYPE_Q4_1] = sizeof(block_q4_1),
3218
+ [GGML_TYPE_Q8_0] = sizeof(block_q8_0),
3219
+ [GGML_TYPE_I8] = sizeof(int8_t),
3220
+ [GGML_TYPE_I16] = sizeof(int16_t),
3221
+ [GGML_TYPE_I32] = sizeof(int32_t),
2600
3222
  };
2601
-
2602
- // don't forget to update the array above when adding new types
2603
- static_assert(GGML_TYPE_COUNT == 7, "GGML_TYPE_COUNT != 5");
3223
+ static_assert(GGML_TYPE_COUNT == 8, "GGML_TYPE_SIZE is outdated");
3224
+
3225
+
3226
+ static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = {
3227
+ [GGML_TYPE_F32] = "f32",
3228
+ [GGML_TYPE_F16] = "f16",
3229
+ [GGML_TYPE_Q4_0] = "q4_0",
3230
+ [GGML_TYPE_Q4_1] = "q4_1",
3231
+ [GGML_TYPE_Q8_0] = "q8_0",
3232
+ [GGML_TYPE_I8] = "i8",
3233
+ [GGML_TYPE_I16] = "i16",
3234
+ [GGML_TYPE_I32] = "i32",
3235
+ };
3236
+ static_assert(GGML_TYPE_COUNT == 8, "GGML_TYPE_NAME is outdated");
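These tables are what the rest of the file uses to turn element counts into byte sizes: bytes per row = ne * GGML_TYPE_SIZE[type] / GGML_BLCK_SIZE[type] (the same ratio ggml_type_sizef returns). A tiny self-contained check, with the Q4_0 numbers assumed for this version (32 values per 20-byte block):

    #include <assert.h>

    int main(void) {
        const int ne    = 4096;  // elements in one row
        const int blck  = 32;    // GGML_BLCK_SIZE[GGML_TYPE_Q4_0], i.e. QK4_0
        const int tsize = 20;    // sizeof(block_q4_0): f32 scale + 16 packed bytes (assumed)
        assert(ne * tsize / blck == 2560);  // 128 blocks of 20 bytes per row
        return 0;
    }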
2604
3237
 
2605
3238
  static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
2606
3239
  "NONE",
@@ -2629,6 +3262,7 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
2629
3262
 
2630
3263
  "SCALE",
2631
3264
  "CPY",
3265
+ "CONT",
2632
3266
  "RESHAPE",
2633
3267
  "VIEW",
2634
3268
  "PERMUTE",
@@ -2642,9 +3276,12 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
2642
3276
 
2643
3277
  "FLASH_ATTN",
2644
3278
  "FLASH_FF",
3279
+
3280
+ "MAP_UNARY",
3281
+ "MAP_BINARY",
2645
3282
  };
2646
3283
 
2647
- static_assert(GGML_OP_COUNT == 35, "GGML_OP_COUNT != 35");
3284
+ static_assert(GGML_OP_COUNT == 38, "GGML_OP_COUNT != 38");
2648
3285
 
2649
3286
  static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
2650
3287
  "none",
@@ -2673,6 +3310,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
2673
3310
 
2674
3311
  "x*v",
2675
3312
  "x-\\>y",
3313
+ "cont(x)",
2676
3314
  "reshape(x)",
2677
3315
  "view(x)",
2678
3316
  "permute(x)",
@@ -2686,24 +3324,12 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
2686
3324
 
2687
3325
  "flash_attn(x)",
2688
3326
  "flash_ff(x)",
2689
- };
2690
-
2691
- static_assert(GGML_OP_COUNT == 35, "GGML_OP_COUNT != 35");
2692
-
2693
- //
2694
- // ggml object
2695
- //
2696
-
2697
- struct ggml_object {
2698
- size_t offs;
2699
- size_t size;
2700
3327
 
2701
- struct ggml_object * next;
2702
-
2703
- char padding[8];
3328
+ "f(x)",
3329
+ "f(x,y)",
2704
3330
  };
2705
3331
 
2706
- static const size_t GGML_OBJECT_SIZE = sizeof(struct ggml_object);
3332
+ static_assert(GGML_OP_COUNT == 38, "GGML_OP_COUNT != 38");
2707
3333
 
2708
3334
  static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
2709
3335
  static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN");
@@ -2716,7 +3342,6 @@ struct ggml_context {
2716
3342
  size_t mem_size;
2717
3343
  void * mem_buffer;
2718
3344
  bool mem_buffer_owned;
2719
- bool mem_buffer_mlocked;
2720
3345
  bool no_alloc;
2721
3346
 
2722
3347
  int n_objects;
@@ -2834,6 +3459,11 @@ float ggml_type_sizef(enum ggml_type type) {
2834
3459
  return ((float)(GGML_TYPE_SIZE[type]))/GGML_BLCK_SIZE[type];
2835
3460
  }
2836
3461
 
3462
+ const char * ggml_type_name(enum ggml_type type) {
3463
+ return GGML_TYPE_NAME[type];
3464
+ }
3465
+
3466
+
2837
3467
  size_t ggml_element_size(const struct ggml_tensor * tensor) {
2838
3468
  return GGML_TYPE_SIZE[tensor->type];
2839
3469
  }
@@ -2999,11 +3629,12 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
2999
3629
  return NULL;
3000
3630
  }
3001
3631
 
3632
+ const size_t mem_size = (params.mem_size + GGML_MEM_ALIGN - 1) & ~(GGML_MEM_ALIGN - 1);
3633
+
3002
3634
  *ctx = (struct ggml_context) {
3003
- /*.mem_size =*/ params.mem_size,
3004
- /*.mem_buffer =*/ params.mem_buffer ? params.mem_buffer : malloc(params.mem_size),
3635
+ /*.mem_size =*/ mem_size,
3636
+ /*.mem_buffer =*/ params.mem_buffer ? params.mem_buffer : GGML_ALIGNED_MALLOC(mem_size),
3005
3637
  /*.mem_buffer_owned =*/ params.mem_buffer ? false : true,
3006
- /*.mem_buffer_mlocked =*/ false,
3007
3638
  /*.no_alloc =*/ params.no_alloc,
3008
3639
  /*.n_objects =*/ 0,
3009
3640
  /*.objects_begin =*/ NULL,
@@ -3012,7 +3643,7 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
3012
3643
  /*.scratch_save =*/ { 0, 0, NULL, },
3013
3644
  };
3014
3645
 
3015
- GGML_ASSERT(ctx->mem_buffer != NULL); // check for allocation failure
3646
+ GGML_ASSERT(ctx->mem_buffer != NULL);
3016
3647
 
3017
3648
  ggml_assert_aligned(ctx->mem_buffer);
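The mem_size expression above rounds the requested size up to the next multiple of GGML_MEM_ALIGN, which must be a power of two for the mask trick to work. A minimal sketch of the same idiom:

    #include <stddef.h>
    #include <assert.h>

    // round n up to a multiple of align (align must be a power of two)
    static size_t align_up(size_t n, size_t align) {
        return (n + align - 1) & ~(align - 1);
    }

    int main(void) {
        assert(align_up(1000, 16) == 1008);  // 1000 -> next multiple of 16
        assert(align_up(1024, 16) == 1024);  // already aligned, unchanged
        return 0;
    }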
3018
3649
 
@@ -3036,16 +3667,8 @@ void ggml_free(struct ggml_context * ctx) {
3036
3667
  GGML_PRINT_DEBUG("%s: context %d with %d objects has been freed. memory used = %zu\n",
3037
3668
  __func__, i, ctx->n_objects, ctx->objects_end->offs + ctx->objects_end->size);
3038
3669
 
3039
- #if GGML_MLOCK_SUPPORT
3040
- if (ctx->mem_buffer_mlocked) {
3041
- if (munlock(ctx->mem_buffer, ctx->mem_size)) {
3042
- fprintf(stderr, "%s: failed to munlock buffer: %s\n", __func__, strerror(errno));
3043
- }
3044
- }
3045
- #endif
3046
-
3047
3670
  if (ctx->mem_buffer_owned) {
3048
- free(ctx->mem_buffer);
3671
+ GGML_ALIGNED_FREE(ctx->mem_buffer);
3049
3672
  }
3050
3673
 
3051
3674
  found = true;
@@ -3072,48 +3695,6 @@ size_t ggml_set_scratch(struct ggml_context * ctx, struct ggml_scratch scratch)
3072
3695
  return result;
3073
3696
  }
3074
3697
 
3075
- #ifdef __APPLE__
3076
- #define MLOCK_SUGGESTION \
3077
- "Try increasing the sysctl values 'vm.user_wire_limit' and 'vm.global_user_wire_limit' and/or " \
3078
- "decreasing 'vm.global_no_user_wire_amount'. Also try increasing RLIMIT_MLOCK (ulimit -l).\n"
3079
- #else
3080
- #define MLOCK_SUGGESTION \
3081
- "Try increasing RLIMIT_MLOCK ('ulimit -l' as root).\n"
3082
- #endif
3083
-
3084
- bool ggml_mlock_supported(void) {
3085
- return GGML_MLOCK_SUPPORT;
3086
- }
3087
-
3088
- bool ggml_mlock(
3089
- struct ggml_context * ctx,
3090
- const void *opt_extra_addr,
3091
- size_t opt_extra_len,
3092
- char **err_p) {
3093
- // TODO: Use SetProcessWorkingSetSize() + VirtualLock() on WIN32
3094
- #if GGML_MLOCK_SUPPORT
3095
- if (ctx->mem_buffer_mlocked) {
3096
- return true;
3097
- }
3098
- if (mlock(ctx->mem_buffer, ctx->mem_size) ||
3099
- (opt_extra_len &&
3100
- mlock(opt_extra_addr, opt_extra_len))) {
3101
- if ((*err_p = malloc(1024))) {
3102
- snprintf(*err_p, 1024,
3103
- "failed to mlock %zu-byte buffer: %s\n" MLOCK_SUGGESTION,
3104
- ctx->mem_size + opt_extra_len,
3105
- strerror(errno));
3106
- }
3107
- return false;
3108
- }
3109
- ctx->mem_buffer_mlocked = true;
3110
- return true;
3111
- #else // GGML_MLOCK_SUPPORT
3112
- *err_p = strdup("can't mlock because it's not supported on this system");
3113
- return false;
3114
- #endif // GGML_MLOCK_SUPPORT
3115
- }
3116
-
3117
3698
  ////////////////////////////////////////////////////////////////////////////////
3118
3699
 
3119
3700
  struct ggml_tensor * ggml_new_tensor_impl(
@@ -3325,14 +3906,6 @@ struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value) {
3325
3906
  char * const data = tensor->data;
3326
3907
 
3327
3908
  switch (tensor->type) {
3328
- case GGML_TYPE_Q4_0:
3329
- {
3330
- GGML_ASSERT(false);
3331
- } break;
3332
- case GGML_TYPE_Q4_1:
3333
- {
3334
- GGML_ASSERT(false);
3335
- } break;
3336
3909
  case GGML_TYPE_I8:
3337
3910
  {
3338
3911
  assert(tensor->nb[0] == sizeof(int8_t));
@@ -3368,7 +3941,7 @@ struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value) {
3368
3941
  ggml_vec_set_f32(nc, (float *)(data + i*n1), value);
3369
3942
  }
3370
3943
  } break;
3371
- case GGML_TYPE_COUNT:
3944
+ default:
3372
3945
  {
3373
3946
  GGML_ASSERT(false);
3374
3947
  } break;
@@ -3385,14 +3958,6 @@ struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value) {
3385
3958
  char * const data = tensor->data;
3386
3959
 
3387
3960
  switch (tensor->type) {
3388
- case GGML_TYPE_Q4_0:
3389
- {
3390
- GGML_ASSERT(false);
3391
- } break;
3392
- case GGML_TYPE_Q4_1:
3393
- {
3394
- GGML_ASSERT(false);
3395
- } break;
3396
3961
  case GGML_TYPE_I8:
3397
3962
  {
3398
3963
  assert(tensor->nb[0] == sizeof(int8_t));
@@ -3428,7 +3993,7 @@ struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value) {
3428
3993
  ggml_vec_set_f32(nc, (float *)(data + i*n1), value);
3429
3994
  }
3430
3995
  } break;
3431
- case GGML_TYPE_COUNT:
3996
+ default:
3432
3997
  {
3433
3998
  GGML_ASSERT(false);
3434
3999
  } break;
@@ -3439,14 +4004,6 @@ struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value) {
3439
4004
 
3440
4005
  int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i) {
3441
4006
  switch (tensor->type) {
3442
- case GGML_TYPE_Q4_0:
3443
- {
3444
- GGML_ASSERT(false);
3445
- } break;
3446
- case GGML_TYPE_Q4_1:
3447
- {
3448
- GGML_ASSERT(false);
3449
- } break;
3450
4007
  case GGML_TYPE_I8:
3451
4008
  {
3452
4009
  GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
@@ -3472,7 +4029,7 @@ int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i) {
3472
4029
  GGML_ASSERT(tensor->nb[0] == sizeof(float));
3473
4030
  return ((float *)(tensor->data))[i];
3474
4031
  } break;
3475
- case GGML_TYPE_COUNT:
4032
+ default:
3476
4033
  {
3477
4034
  GGML_ASSERT(false);
3478
4035
  } break;
@@ -3483,14 +4040,6 @@ int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i) {
3483
4040
 
3484
4041
  void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value) {
3485
4042
  switch (tensor->type) {
3486
- case GGML_TYPE_Q4_0:
3487
- {
3488
- GGML_ASSERT(false);
3489
- } break;
3490
- case GGML_TYPE_Q4_1:
3491
- {
3492
- GGML_ASSERT(false);
3493
- } break;
3494
4043
  case GGML_TYPE_I8:
3495
4044
  {
3496
4045
  GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
@@ -3516,7 +4065,7 @@ void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value) {
3516
4065
  GGML_ASSERT(tensor->nb[0] == sizeof(float));
3517
4066
  ((float *)(tensor->data))[i] = value;
3518
4067
  } break;
3519
- case GGML_TYPE_COUNT:
4068
+ default:
3520
4069
  {
3521
4070
  GGML_ASSERT(false);
3522
4071
  } break;
@@ -3525,14 +4074,6 @@ void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value) {
3525
4074
 
3526
4075
  float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i) {
3527
4076
  switch (tensor->type) {
3528
- case GGML_TYPE_Q4_0:
3529
- {
3530
- GGML_ASSERT(false);
3531
- } break;
3532
- case GGML_TYPE_Q4_1:
3533
- {
3534
- GGML_ASSERT(false);
3535
- } break;
3536
4077
  case GGML_TYPE_I8:
3537
4078
  {
3538
4079
  GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
@@ -3558,7 +4099,7 @@ float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i) {
3558
4099
  GGML_ASSERT(tensor->nb[0] == sizeof(float));
3559
4100
  return ((float *)(tensor->data))[i];
3560
4101
  } break;
3561
- case GGML_TYPE_COUNT:
4102
+ default:
3562
4103
  {
3563
4104
  GGML_ASSERT(false);
3564
4105
  } break;
@@ -3569,14 +4110,6 @@ float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i) {
3569
4110
 
3570
4111
  void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value) {
3571
4112
  switch (tensor->type) {
3572
- case GGML_TYPE_Q4_0:
3573
- {
3574
- GGML_ASSERT(false);
3575
- } break;
3576
- case GGML_TYPE_Q4_1:
3577
- {
3578
- GGML_ASSERT(false);
3579
- } break;
3580
4113
  case GGML_TYPE_I8:
3581
4114
  {
3582
4115
  GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
@@ -3602,7 +4135,7 @@ void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value) {
3602
4135
  GGML_ASSERT(tensor->nb[0] == sizeof(float));
3603
4136
  ((float *)(tensor->data))[i] = value;
3604
4137
  } break;
3605
- case GGML_TYPE_COUNT:
4138
+ default:
3606
4139
  {
3607
4140
  GGML_ASSERT(false);
3608
4141
  } break;
@@ -4388,26 +4921,22 @@ struct ggml_tensor * ggml_cpy_inplace(
4388
4921
  return ggml_cpy_impl(ctx, a, b, true);
4389
4922
  }
4390
4923
 
4391
- // ggml_reshape
4924
+ // ggml_cont
4392
4925
 
4393
- struct ggml_tensor * ggml_reshape(
4926
+ struct ggml_tensor * ggml_cont_impl(
4394
4927
  struct ggml_context * ctx,
4395
- struct ggml_tensor * a,
4396
- struct ggml_tensor * b) {
4397
- GGML_ASSERT(ggml_is_contiguous(a));
4398
- GGML_ASSERT(ggml_is_contiguous(b));
4399
- GGML_ASSERT(ggml_nelements(a) == ggml_nelements(b));
4400
-
4928
+ struct ggml_tensor * a,
4929
+ bool inplace) {
4401
4930
  bool is_node = false;
4402
4931
 
4403
- if (a->grad || b->grad) {
4932
+ if (!inplace && a->grad) {
4404
4933
  GGML_ASSERT(false); // TODO: implement backward
4405
4934
  is_node = true;
4406
4935
  }
4407
4936
 
4408
- struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, b->n_dims, b->ne, a->data);
4937
+ struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
4409
4938
 
4410
- result->op = GGML_OP_RESHAPE;
4939
+ result->op = GGML_OP_CONT;
4411
4940
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
4412
4941
  result->src0 = a;
4413
4942
  result->src1 = NULL;
@@ -4415,12 +4944,51 @@ struct ggml_tensor * ggml_reshape(
4415
4944
  return result;
4416
4945
  }
4417
4946
 
4418
- struct ggml_tensor * ggml_reshape_2d(
4947
+ struct ggml_tensor * ggml_cont(
4419
4948
  struct ggml_context * ctx,
4420
- struct ggml_tensor * a,
4421
- int64_t ne0,
4422
- int64_t ne1) {
4423
- GGML_ASSERT(ggml_is_contiguous(a));
4949
+ struct ggml_tensor * a) {
4950
+ return ggml_cont_impl(ctx, a, false);
4951
+ }
4952
+
4953
+ struct ggml_tensor * ggml_cont_inplace(
4954
+ struct ggml_context * ctx,
4955
+ struct ggml_tensor * a) {
4956
+ return ggml_cont_impl(ctx, a, true);
4957
+ }
4958
+
4959
+ // ggml_reshape
4960
+
4961
+ struct ggml_tensor * ggml_reshape(
4962
+ struct ggml_context * ctx,
4963
+ struct ggml_tensor * a,
4964
+ struct ggml_tensor * b) {
4965
+ GGML_ASSERT(ggml_is_contiguous(a));
4966
+ GGML_ASSERT(ggml_is_contiguous(b));
4967
+ GGML_ASSERT(ggml_nelements(a) == ggml_nelements(b));
4968
+
4969
+ bool is_node = false;
4970
+
4971
+ if (a->grad || b->grad) {
4972
+ GGML_ASSERT(false); // TODO: implement backward
4973
+ is_node = true;
4974
+ }
4975
+
4976
+ struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, b->n_dims, b->ne, a->data);
4977
+
4978
+ result->op = GGML_OP_RESHAPE;
4979
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
4980
+ result->src0 = a;
4981
+ result->src1 = NULL;
4982
+
4983
+ return result;
4984
+ }
4985
+
4986
+ struct ggml_tensor * ggml_reshape_2d(
4987
+ struct ggml_context * ctx,
4988
+ struct ggml_tensor * a,
4989
+ int64_t ne0,
4990
+ int64_t ne1) {
4991
+ GGML_ASSERT(ggml_is_contiguous(a));
4424
4992
  GGML_ASSERT(ggml_nelements(a) == ne0*ne1);
4425
4993
 
4426
4994
  bool is_node = false;
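ggml_cont exists because ggml_permute, ggml_transpose and ggml_view only return strided views, while ops such as ggml_reshape above assert ggml_is_contiguous; ggml_cont inserts an op that copies the view into dense memory. A hedged usage sketch (the shapes and tensor names are illustrative, and ctx is assumed to be an initialized ggml context):

    // a: 64 x 128 x 8 f32 tensor
    struct ggml_tensor * a   = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 64, 128, 8);
    // swap the first two axes: a strided view, no data is copied yet
    struct ggml_tensor * ap  = ggml_permute(ctx, a, 1, 0, 2, 3);
    // materialize the permuted data so downstream ops see contiguous rows
    struct ggml_tensor * apc = ggml_cont(ctx, ap);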
@@ -4866,6 +5434,90 @@ struct ggml_tensor * ggml_flash_ff(
4866
5434
  return result;
4867
5435
  }
4868
5436
 
5437
+ // ggml_map_unary
5438
+
5439
+ struct ggml_tensor * ggml_map_unary_impl_f32(
5440
+ struct ggml_context * ctx,
5441
+ struct ggml_tensor * a,
5442
+ const ggml_unary_op_f32_t fun,
5443
+ bool inplace) {
5444
+ bool is_node = false;
5445
+
5446
+ if (!inplace && a->grad) {
5447
+ is_node = true;
5448
+ }
5449
+
5450
+ struct ggml_tensor * addr_tensor = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sizeof(void *) / sizeof(int32_t));
5451
+ *((void (**)(void))addr_tensor->data) = (void (*)(void))fun;
5452
+ struct ggml_tensor *result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
5453
+
5454
+ result->op = GGML_OP_MAP_UNARY;
5455
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5456
+ result->src0 = a;
5457
+ result->opt[0] = addr_tensor;
5458
+
5459
+ return result;
5460
+ }
5461
+
5462
+ struct ggml_tensor * ggml_map_unary_f32(
5463
+ struct ggml_context * ctx,
5464
+ struct ggml_tensor * a,
5465
+ const ggml_unary_op_f32_t fun) {
5466
+ return ggml_map_unary_impl_f32(ctx, a, fun, false);
5467
+ }
5468
+
5469
+ struct ggml_tensor * ggml_map_unary_inplace_f32(
5470
+ struct ggml_context * ctx,
5471
+ struct ggml_tensor * a,
5472
+ const ggml_unary_op_f32_t fun) {
5473
+ return ggml_map_unary_impl_f32(ctx, a, fun, true);
5474
+ }
5475
+
5476
+ // ggml_map_binary
5477
+
5478
+ struct ggml_tensor * ggml_map_binary_impl_f32(
5479
+ struct ggml_context * ctx,
5480
+ struct ggml_tensor * a,
5481
+ struct ggml_tensor * b,
5482
+ const ggml_binary_op_f32_t fun,
5483
+ bool inplace) {
5484
+ GGML_ASSERT(ggml_are_same_shape(a, b));
5485
+
5486
+ bool is_node = false;
5487
+
5488
+ if (!inplace && (a->grad || b->grad)) {
5489
+ is_node = true;
5490
+ }
5491
+
5492
+ struct ggml_tensor * addr_tensor = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sizeof(void *) / sizeof(int32_t));
5493
+ *((void (**)(void))addr_tensor->data) = (void (*)(void))fun;
5494
+ struct ggml_tensor *result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
5495
+
5496
+ result->op = GGML_OP_MAP_BINARY;
5497
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5498
+ result->src0 = a;
5499
+ result->src1 = b;
5500
+ result->opt[0] = addr_tensor;
5501
+
5502
+ return result;
5503
+ }
5504
+
5505
+ struct ggml_tensor * ggml_map_binary_f32(
5506
+ struct ggml_context * ctx,
5507
+ struct ggml_tensor * a,
5508
+ struct ggml_tensor * b,
5509
+ const ggml_binary_op_f32_t fun) {
5510
+ return ggml_map_binary_impl_f32(ctx, a, b, fun, false);
5511
+ }
5512
+
5513
+ struct ggml_tensor * ggml_map_binary_inplace_f32(
5514
+ struct ggml_context * ctx,
5515
+ struct ggml_tensor * a,
5516
+ struct ggml_tensor * b,
5517
+ const ggml_binary_op_f32_t fun) {
5518
+ return ggml_map_binary_impl_f32(ctx, a, b, fun, true);
5519
+ }
5520
+
4869
5521
  ////////////////////////////////////////////////////////////////////////////////
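The new map ops above stash the user's function pointer inside a small I32 tensor (opt[0]) so an arbitrary element-wise callback can travel through the graph like any other operand. A usage sketch, assuming the callback signature from this version's header is (n, dst, src):

    // assumed: typedef void (*ggml_unary_op_f32_t)(const int, float *, const float *);
    static void my_clip_f32(const int n, float * dst, const float * src) {
        for (int i = 0; i < n; ++i) {
            // clamp to [0, 6] — an arbitrary custom element-wise op
            dst[i] = src[i] < 0.0f ? 0.0f : (src[i] > 6.0f ? 6.0f : src[i]);
        }
    }

    // later, while building the graph (ctx and x assumed to exist):
    struct ggml_tensor * y = ggml_map_unary_f32(ctx, x, my_clip_f32);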
4870
5522
 
4871
5523
  void ggml_set_param(
@@ -4930,6 +5582,105 @@ static void ggml_compute_forward_dup_f16(
4930
5582
 
4931
5583
  // TODO: add more special-case implementations for tensor shapes/strides that can benefit from memcpy
4932
5584
 
5585
+ if (ggml_is_contiguous(dst)) {
5586
+ if (src0->nb[0] == sizeof(ggml_fp16_t)) {
5587
+ if (dst->type == GGML_TYPE_F16) {
5588
+ size_t id = 0;
5589
+ const size_t rs = ne00*nb00;
5590
+
5591
+ for (int i03 = 0; i03 < ne03; i03++) {
5592
+ for (int i02 = 0; i02 < ne02; i02++) {
5593
+ for (int i01 = 0; i01 < ne01; i01++) {
5594
+ const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
5595
+ char * dst_ptr = (char *) dst->data + id*rs;
5596
+
5597
+ memcpy(dst_ptr, src0_ptr, rs);
5598
+
5599
+ id++;
5600
+ }
5601
+ }
5602
+ }
5603
+ } else if (dst->type == GGML_TYPE_F32) {
5604
+ size_t id = 0;
5605
+ float * dst_ptr = (float *) dst->data;
5606
+
5607
+ for (int i03 = 0; i03 < ne03; i03++) {
5608
+ for (int i02 = 0; i02 < ne02; i02++) {
5609
+ for (int i01 = 0; i01 < ne01; i01++) {
5610
+ for (int i00 = 0; i00 < ne00; i00++) {
5611
+ const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
5612
+
5613
+ dst_ptr[id] = GGML_FP16_TO_FP32(*src0_ptr);
5614
+ id++;
5615
+ }
5616
+ }
5617
+ }
5618
+ }
5619
+ } else if (dst->type == GGML_TYPE_Q4_0 || dst->type == GGML_TYPE_Q4_1) {
5620
+ quantize_row_q_t const quantize_row_q = quantize_fns[dst->type].quantize_row_q;
5621
+ size_t id = 0;
5622
+ uint8_t * dst_ptr = (uint8_t *) dst->data;
5623
+ size_t dst_row_size = nb0 * (ne00 / GGML_BLCK_SIZE[dst->type]);
5624
+ float * src0_f32 = (float *) params->wdata;
5625
+
5626
+ for (int i03 = 0; i03 < ne03; i03++) {
5627
+ for (int i02 = 0; i02 < ne02; i02++) {
5628
+ for (int i01 = 0; i01 < ne01; i01++) {
5629
+ const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
5630
+ // convert to f32 and quantize
5631
+ for (int i00 = 0; i00 < ne00; i00++) {
5632
+ src0_f32[i00] = GGML_FP16_TO_FP32(src0_ptr[i00]);
5633
+ }
5634
+ quantize_row_q(src0_f32, dst_ptr + id, ne00);
5635
+ id += dst_row_size;
5636
+ }
5637
+ }
5638
+ }
5639
+ } else {
5640
+ GGML_ASSERT(false); // TODO: implement
5641
+ }
5642
+ } else {
5643
+ //printf("%s: this is not optimal - fix me\n", __func__);
5644
+
5645
+ if (dst->type == GGML_TYPE_F32) {
5646
+ size_t id = 0;
5647
+ float * dst_ptr = (float *) dst->data;
5648
+
5649
+ for (int i03 = 0; i03 < ne03; i03++) {
5650
+ for (int i02 = 0; i02 < ne02; i02++) {
5651
+ for (int i01 = 0; i01 < ne01; i01++) {
5652
+ for (int i00 = 0; i00 < ne00; i00++) {
5653
+ const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
5654
+
5655
+ dst_ptr[id] = GGML_FP16_TO_FP32(*src0_ptr);
5656
+ id++;
5657
+ }
5658
+ }
5659
+ }
5660
+ }
5661
+ } else if (dst->type == GGML_TYPE_F16) {
5662
+ size_t id = 0;
5663
+ ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
5664
+
5665
+ for (int i03 = 0; i03 < ne03; i03++) {
5666
+ for (int i02 = 0; i02 < ne02; i02++) {
5667
+ for (int i01 = 0; i01 < ne01; i01++) {
5668
+ for (int i00 = 0; i00 < ne00; i00++) {
5669
+ const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
5670
+
5671
+ dst_ptr[id] = *src0_ptr;
5672
+ id++;
5673
+ }
5674
+ }
5675
+ }
5676
+ }
5677
+ } else {
5678
+ GGML_ASSERT(false); // TODO: implement
5679
+ }
5680
+ }
5681
+ return;
5682
+ }
5683
+
4933
5684
  // dst counters
4934
5685
  int64_t i10 = 0;
4935
5686
  int64_t i11 = 0;
@@ -5024,6 +5775,120 @@ static void ggml_compute_forward_dup_f32(
5024
5775
  return;
5025
5776
  }
5026
5777
 
5778
+ if (src0->type == dst->type &&
5779
+ src0->ne[0] == dst->ne[0] &&
5780
+ src0->nb[0] == GGML_TYPE_SIZE[src0->type] && dst->nb[0] == GGML_TYPE_SIZE[dst->type]) {
5781
+ // copy by rows
5782
+ const size_t rs = ne00*nb00;
5783
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
5784
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
5785
+ for (int64_t i01 = 0; i01 < ne01; i01++) {
5786
+ memcpy(
5787
+ ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3),
5788
+ ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
5789
+ rs);
5790
+ }
5791
+ }
5792
+ }
5793
+ return;
5794
+ }
5795
+
5796
+ if (ggml_is_contiguous(dst)) {
5797
+ // TODO: simplify
5798
+ if (src0->nb[0] == sizeof(float)) {
5799
+ if (dst->type == GGML_TYPE_F32) {
5800
+ size_t id = 0;
5801
+ const size_t rs = ne00*nb00;
5802
+
5803
+ for (int i03 = 0; i03 < ne03; i03++) {
5804
+ for (int i02 = 0; i02 < ne02; i02++) {
5805
+ for (int i01 = 0; i01 < ne01; i01++) {
5806
+ const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
5807
+ char * dst_ptr = (char *) dst->data + id*rs;
5808
+
5809
+ memcpy(dst_ptr, src0_ptr, rs);
5810
+
5811
+ id++;
5812
+ }
5813
+ }
5814
+ }
5815
+ } else if (dst->type == GGML_TYPE_F16) {
5816
+ size_t id = 0;
5817
+ ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
5818
+
5819
+ for (int i03 = 0; i03 < ne03; i03++) {
5820
+ for (int i02 = 0; i02 < ne02; i02++) {
5821
+ for (int i01 = 0; i01 < ne01; i01++) {
5822
+ for (int i00 = 0; i00 < ne00; i00++) {
5823
+ const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
5824
+
5825
+ dst_ptr[id] = GGML_FP32_TO_FP16(*src0_ptr);
5826
+ id++;
5827
+ }
5828
+ }
5829
+ }
5830
+ }
5831
+ } else if (dst->type == GGML_TYPE_Q4_0 || dst->type == GGML_TYPE_Q4_1) {
5832
+ quantize_row_q_t const quantize_row_q = quantize_fns[dst->type].quantize_row_q;
5833
+ size_t id = 0;
5834
+ uint8_t * dst_ptr = (uint8_t *) dst->data;
5835
+ size_t dst_row_size = nb0 * (ne00 / GGML_BLCK_SIZE[dst->type]);
5836
+
5837
+ for (int i03 = 0; i03 < ne03; i03++) {
5838
+ for (int i02 = 0; i02 < ne02; i02++) {
5839
+ for (int i01 = 0; i01 < ne01; i01++) {
5840
+ const float * src0_ptr = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
5841
+ quantize_row_q(src0_ptr, dst_ptr + id, ne00);
5842
+ id += dst_row_size;
5843
+ }
5844
+ }
5845
+ }
5846
+ } else {
5847
+ GGML_ASSERT(false); // TODO: implement
5848
+ }
5849
+ } else {
5850
+ //printf("%s: this is not optimal - fix me\n", __func__);
5851
+
5852
+ if (dst->type == GGML_TYPE_F32) {
5853
+ size_t id = 0;
5854
+ float * dst_ptr = (float *) dst->data;
5855
+
5856
+ for (int i03 = 0; i03 < ne03; i03++) {
5857
+ for (int i02 = 0; i02 < ne02; i02++) {
5858
+ for (int i01 = 0; i01 < ne01; i01++) {
5859
+ for (int i00 = 0; i00 < ne00; i00++) {
5860
+ const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
5861
+
5862
+ dst_ptr[id] = *src0_ptr;
5863
+ id++;
5864
+ }
5865
+ }
5866
+ }
5867
+ }
5868
+ } else if (dst->type == GGML_TYPE_F16) {
5869
+ size_t id = 0;
5870
+ ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
5871
+
5872
+ for (int i03 = 0; i03 < ne03; i03++) {
5873
+ for (int i02 = 0; i02 < ne02; i02++) {
5874
+ for (int i01 = 0; i01 < ne01; i01++) {
5875
+ for (int i00 = 0; i00 < ne00; i00++) {
5876
+ const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
5877
+
5878
+ dst_ptr[id] = GGML_FP32_TO_FP16(*src0_ptr);
5879
+ id++;
5880
+ }
5881
+ }
5882
+ }
5883
+ }
5884
+ } else {
5885
+ GGML_ASSERT(false); // TODO: implement
5886
+ }
5887
+ }
5888
+
5889
+ return;
5890
+ }
5891
+
5027
5892
  // dst counters
5028
5893
  int64_t i10 = 0;
5029
5894
  int64_t i11 = 0;
@@ -5100,12 +5965,7 @@ static void ggml_compute_forward_dup(
5100
5965
  {
5101
5966
  ggml_compute_forward_dup_f32(params, src0, dst);
5102
5967
  } break;
5103
- case GGML_TYPE_Q4_0:
5104
- case GGML_TYPE_Q4_1:
5105
- case GGML_TYPE_I8:
5106
- case GGML_TYPE_I16:
5107
- case GGML_TYPE_I32:
5108
- case GGML_TYPE_COUNT:
5968
+ default:
5109
5969
  {
5110
5970
  GGML_ASSERT(false);
5111
5971
  } break;
@@ -5144,14 +6004,18 @@ static void ggml_compute_forward_add_f32(
5144
6004
  GGML_ASSERT(nb00 == sizeof(float));
5145
6005
 
5146
6006
  if (nb10 == sizeof(float)) {
5147
- const int j0 = (n/nth)*ith;
5148
- const int j1 = ith == nth - 1 ? n : (n/nth)*(ith + 1);
5149
-
5150
- for (int j = j0; j < j1; j++) {
6007
+ for (int j = ith; j < n; j += nth) {
6008
+ #ifdef GGML_USE_ACCELERATE
6009
+ vDSP_vadd(
6010
+ (float *) ((char *) src0->data + j*nb01), 1,
6011
+ (float *) ((char *) src1->data + j*nb11), 1,
6012
+ (float *) ((char *) dst->data + j*nb1), 1, nc);
6013
+ #else
5151
6014
  ggml_vec_add_f32(nc,
5152
6015
  (float *) ((char *) dst->data + j*nb1),
5153
6016
  (float *) ((char *) src0->data + j*nb01),
5154
6017
  (float *) ((char *) src1->data + j*nb11));
6018
+ #endif
5155
6019
  }
5156
6020
  } else {
5157
6021
  // src1 is not contiguous
@@ -5167,6 +6031,212 @@ static void ggml_compute_forward_add_f32(
5167
6031
  }
5168
6032
  }
5169
6033
 
6034
+ static void ggml_compute_forward_add_f16_f32(
6035
+ const struct ggml_compute_params * params,
6036
+ const struct ggml_tensor * src0,
6037
+ const struct ggml_tensor * src1,
6038
+ struct ggml_tensor * dst) {
6039
+ GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
6040
+
6041
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
6042
+ return;
6043
+ }
6044
+
6045
+ const int ith = params->ith;
6046
+ const int nth = params->nth;
6047
+
6048
+ const int n = ggml_nrows(src0);
6049
+ const int nc = src0->ne[0];
6050
+
6051
+ const size_t nb00 = src0->nb[0];
6052
+ const size_t nb01 = src0->nb[1];
6053
+
6054
+ const size_t nb10 = src1->nb[0];
6055
+ const size_t nb11 = src1->nb[1];
6056
+
6057
+ const size_t nb0 = dst->nb[0];
6058
+ const size_t nb1 = dst->nb[1];
6059
+
6060
+ GGML_ASSERT(src0->type == GGML_TYPE_F16);
6061
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
6062
+ GGML_ASSERT(dst->type == GGML_TYPE_F16);
6063
+
6064
+ GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
6065
+ GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
6066
+
6067
+ if (nb10 == sizeof(float)) {
6068
+ for (int j = ith; j < n; j += nth) {
6069
+ ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + j*nb1);
6070
+ ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + j*nb01);
6071
+ for (int i = 0; i < nc; i++) {
6072
+ float * src1_ptr = (float *) ((char *) src1->data + j*nb11 + i*nb10);
6073
+ dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + *src1_ptr);
6074
+ }
6075
+ }
6076
+ }
6077
+ else {
6078
+ // src1 is not contiguous
6079
+ GGML_ASSERT(false);
6080
+ }
6081
+ }
6082
+
6083
+ static void ggml_compute_forward_add_f16_f16(
6084
+ const struct ggml_compute_params * params,
6085
+ const struct ggml_tensor * src0,
6086
+ const struct ggml_tensor * src1,
6087
+ struct ggml_tensor * dst) {
6088
+ GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
6089
+
6090
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
6091
+ return;
6092
+ }
6093
+
6094
+ const int ith = params->ith;
6095
+ const int nth = params->nth;
6096
+
6097
+ const int n = ggml_nrows(src0);
6098
+ const int nc = src0->ne[0];
6099
+
6100
+ const size_t nb00 = src0->nb[0];
6101
+ const size_t nb01 = src0->nb[1];
6102
+
6103
+ const size_t nb10 = src1->nb[0];
6104
+ const size_t nb11 = src1->nb[1];
6105
+
6106
+ const size_t nb0 = dst->nb[0];
6107
+ const size_t nb1 = dst->nb[1];
6108
+
6109
+ GGML_ASSERT(src0->type == GGML_TYPE_F16);
6110
+ GGML_ASSERT(src1->type == GGML_TYPE_F16);
6111
+ GGML_ASSERT(dst->type == GGML_TYPE_F16);
6112
+
6113
+ GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
6114
+ GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
6115
+
6116
+ if (nb10 == sizeof(ggml_fp16_t)) {
6117
+ for (int j = ith; j < n; j += nth) {
6118
+ ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + j*nb1);
6119
+ ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + j*nb01);
6120
+ for (int i = 0; i < nc; i++) {
6121
+ ggml_fp16_t * src1_ptr = (ggml_fp16_t *) ((char *) src1->data + j*nb11 + i*nb10);
6122
+ dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + GGML_FP16_TO_FP32(*src1_ptr));
6123
+ }
6124
+ }
6125
+ }
6126
+ else {
6127
+ // src1 is not contiguous
6128
+ GGML_ASSERT(false);
6129
+ }
6130
+ }
6131
+
6132
+ static void ggml_compute_forward_add_q_f32(
6133
+ const struct ggml_compute_params * params,
6134
+ const struct ggml_tensor * src0,
6135
+ const struct ggml_tensor * src1,
6136
+ struct ggml_tensor * dst) {
6137
+ GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
6138
+
6139
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
6140
+ return;
6141
+ }
6142
+
6143
+ const int64_t ne00 = src0->ne[0];
6144
+ const int64_t ne01 = src0->ne[1];
6145
+ const int64_t ne02 = src0->ne[2];
6146
+ const int64_t ne03 = src0->ne[3];
6147
+
6148
+ //const int64_t ne10 = src1->ne[0];
6149
+ //const int64_t ne11 = src1->ne[1];
6150
+ const int64_t ne12 = src1->ne[2];
6151
+ const int64_t ne13 = src1->ne[3];
6152
+
6153
+ //const int64_t ne0 = dst->ne[0];
6154
+ //const int64_t ne1 = dst->ne[1];
6155
+ const int64_t ne2 = dst->ne[2];
6156
+ const int64_t ne3 = dst->ne[3];
6157
+
6158
+ const int nb00 = src0->nb[0];
6159
+ const int nb01 = src0->nb[1];
6160
+ const int nb02 = src0->nb[2];
6161
+ const int nb03 = src0->nb[3];
6162
+
6163
+ const int nb10 = src1->nb[0];
6164
+ const int nb11 = src1->nb[1];
6165
+ const int nb12 = src1->nb[2];
6166
+ const int nb13 = src1->nb[3];
6167
+
6168
+ const int nb0 = dst->nb[0];
6169
+ const int nb1 = dst->nb[1];
6170
+ const int nb2 = dst->nb[2];
6171
+ const int nb3 = dst->nb[3];
6172
+
6173
+ const int ith = params->ith;
6174
+ const int nth = params->nth;
6175
+
6176
+ GGML_ASSERT(ne02 == ne12);
6177
+ GGML_ASSERT(ne03 == ne13);
6178
+ GGML_ASSERT(ne2 == ne12);
6179
+ GGML_ASSERT(ne3 == ne13);
6180
+
6181
+ const enum ggml_type type = src0->type;
6182
+ dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q;
6183
+ quantize_row_q_t const quantize_row_q = quantize_fns[type].quantize_row_q;
6184
+
6185
+ // we don't support permuted src0 or src1
6186
+ GGML_ASSERT(nb00 == (int) GGML_TYPE_SIZE[type]);
6187
+ GGML_ASSERT(nb10 == sizeof(float));
6188
+
6189
+ // dst cannot be transposed or permuted
6190
+ GGML_ASSERT(nb0 <= nb1);
6191
+ GGML_ASSERT(nb1 <= nb2);
6192
+ GGML_ASSERT(nb2 <= nb3);
6193
+
6194
+ GGML_ASSERT(src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1);
6195
+ GGML_ASSERT(dst->type == src0->type);
6196
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
6197
+
6198
+ // total rows in src0
6199
+ const int nr = ne01*ne02*ne03;
6200
+
6201
+ // rows per thread
6202
+ const int dr = (nr + nth - 1)/nth;
6203
+
6204
+ // row range for this thread
6205
+ const int ir0 = dr*ith;
6206
+ const int ir1 = MIN(ir0 + dr, nr);
6207
+
6208
+ float * wdata = (float*) params->wdata + ne00 * ith;
6209
+
6210
+ for (int ir = ir0; ir < ir1; ++ir) {
6211
+ // src0 indices
6212
+ const int i03 = ir/(ne02*ne01);
6213
+ const int i02 = (ir - i03*ne02*ne01)/ne01;
6214
+ const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
6215
+
6216
+ // src1 and dst are same shape as src0 => same indices
6217
+ const int i13 = i03;
6218
+ const int i12 = i02;
6219
+ const int i11 = i01;
6220
+
6221
+ const int i3 = i03;
6222
+ const int i2 = i02;
6223
+ const int i1 = i01;
6224
+
6225
+ void * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
6226
+ float * src1_row = (float *)((char *) src1->data + (i11*nb11 + i12*nb12 + i13*nb13));
6227
+ void * dst_row = (void *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
6228
+
6229
+ assert(ne00 % 32 == 0);
6230
+
6231
+ // unquantize row from src0 to temp buffer
6232
+ dequantize_row_q(src0_row, wdata, ne00);
6233
+ // add src1
6234
+ ggml_vec_acc_f32(ne00, wdata, src1_row);
6235
+ // quantize row to dst
6236
+ quantize_row_q(wdata, dst_row, ne00);
6237
+ }
6238
+ }
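The dr/ir0/ir1 computation in ggml_compute_forward_add_q_f32 above is the ceil-divide row split used by most threaded kernels in this file; the MIN clamp keeps the last thread inside the row count. Worked numbers (hypothetical):

    // nr = 100 rows, nth = 8 threads
    // dr  = (100 + 8 - 1)/8 = 13
    // ith = 0 -> rows [ 0, 13)
    // ith = 6 -> rows [78, 91)
    // ith = 7 -> rows [91, 100)   (91 + 13 = 104, clamped by MIN(ir0 + dr, nr))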
6239
+
5170
6240
  static void ggml_compute_forward_add(
5171
6241
  const struct ggml_compute_params * params,
5172
6242
  const struct ggml_tensor * src0,
@@ -5177,13 +6247,24 @@ static void ggml_compute_forward_add(
5177
6247
  {
5178
6248
  ggml_compute_forward_add_f32(params, src0, src1, dst);
5179
6249
  } break;
6250
+ case GGML_TYPE_F16:
6251
+ {
6252
+ if (src1->type == GGML_TYPE_F16) {
6253
+ ggml_compute_forward_add_f16_f16(params, src0, src1, dst);
6254
+ }
6255
+ else if (src1->type == GGML_TYPE_F32) {
6256
+ ggml_compute_forward_add_f16_f32(params, src0, src1, dst);
6257
+ }
6258
+ else {
6259
+ GGML_ASSERT(false);
6260
+ }
6261
+ } break;
5180
6262
  case GGML_TYPE_Q4_0:
5181
6263
  case GGML_TYPE_Q4_1:
5182
- case GGML_TYPE_I8:
5183
- case GGML_TYPE_I16:
5184
- case GGML_TYPE_I32:
5185
- case GGML_TYPE_F16:
5186
- case GGML_TYPE_COUNT:
6264
+ {
6265
+ ggml_compute_forward_add_q_f32(params, src0, src1, dst);
6266
+ } break;
6267
+ default:
5187
6268
  {
5188
6269
  GGML_ASSERT(false);
5189
6270
  } break;
@@ -5229,13 +6310,7 @@ static void ggml_compute_forward_sub(
5229
6310
  {
5230
6311
  ggml_compute_forward_sub_f32(params, src0, src1, dst);
5231
6312
  } break;
5232
- case GGML_TYPE_Q4_0:
5233
- case GGML_TYPE_Q4_1:
5234
- case GGML_TYPE_I8:
5235
- case GGML_TYPE_I16:
5236
- case GGML_TYPE_I32:
5237
- case GGML_TYPE_F16:
5238
- case GGML_TYPE_COUNT:
6313
+ default:
5239
6314
  {
5240
6315
  GGML_ASSERT(false);
5241
6316
  } break;
@@ -5281,13 +6356,7 @@ static void ggml_compute_forward_mul(
5281
6356
  {
5282
6357
  ggml_compute_forward_mul_f32(params, src0, src1, dst);
5283
6358
  } break;
5284
- case GGML_TYPE_Q4_0:
5285
- case GGML_TYPE_Q4_1:
5286
- case GGML_TYPE_I8:
5287
- case GGML_TYPE_I16:
5288
- case GGML_TYPE_I32:
5289
- case GGML_TYPE_F16:
5290
- case GGML_TYPE_COUNT:
6359
+ default:
5291
6360
  {
5292
6361
  GGML_ASSERT(false);
5293
6362
  } break;
@@ -5333,13 +6402,7 @@ static void ggml_compute_forward_div(
5333
6402
  {
5334
6403
  ggml_compute_forward_div_f32(params, src0, src1, dst);
5335
6404
  } break;
5336
- case GGML_TYPE_Q4_0:
5337
- case GGML_TYPE_Q4_1:
5338
- case GGML_TYPE_I8:
5339
- case GGML_TYPE_I16:
5340
- case GGML_TYPE_I32:
5341
- case GGML_TYPE_F16:
5342
- case GGML_TYPE_COUNT:
6405
+ default:
5343
6406
  {
5344
6407
  GGML_ASSERT(false);
5345
6408
  } break;
@@ -5381,13 +6444,7 @@ static void ggml_compute_forward_sqr(
5381
6444
  {
5382
6445
  ggml_compute_forward_sqr_f32(params, src0, dst);
5383
6446
  } break;
5384
- case GGML_TYPE_Q4_0:
5385
- case GGML_TYPE_Q4_1:
5386
- case GGML_TYPE_I8:
5387
- case GGML_TYPE_I16:
5388
- case GGML_TYPE_I32:
5389
- case GGML_TYPE_F16:
5390
- case GGML_TYPE_COUNT:
6447
+ default:
5391
6448
  {
5392
6449
  GGML_ASSERT(false);
5393
6450
  } break;
@@ -5429,13 +6486,7 @@ static void ggml_compute_forward_sqrt(
5429
6486
  {
5430
6487
  ggml_compute_forward_sqrt_f32(params, src0, dst);
5431
6488
  } break;
5432
- case GGML_TYPE_Q4_0:
5433
- case GGML_TYPE_Q4_1:
5434
- case GGML_TYPE_I8:
5435
- case GGML_TYPE_I16:
5436
- case GGML_TYPE_I32:
5437
- case GGML_TYPE_F16:
5438
- case GGML_TYPE_COUNT:
6489
+ default:
5439
6490
  {
5440
6491
  GGML_ASSERT(false);
5441
6492
  } break;
@@ -5485,16 +6536,10 @@ static void ggml_compute_forward_sum(
5485
6536
  switch (src0->type) {
5486
6537
  case GGML_TYPE_F32:
5487
6538
  {
5488
- ggml_compute_forward_sum_f32(params, src0, dst);
5489
- } break;
5490
- case GGML_TYPE_Q4_0:
5491
- case GGML_TYPE_Q4_1:
5492
- case GGML_TYPE_I8:
5493
- case GGML_TYPE_I16:
5494
- case GGML_TYPE_I32:
5495
- case GGML_TYPE_F16:
5496
- case GGML_TYPE_COUNT:
5497
- {
6539
+ ggml_compute_forward_sum_f32(params, src0, dst);
6540
+ } break;
6541
+ default:
6542
+ {
5498
6543
  GGML_ASSERT(false);
5499
6544
  } break;
5500
6545
  }
@@ -5564,13 +6609,7 @@ static void ggml_compute_forward_mean(
5564
6609
  {
5565
6610
  ggml_compute_forward_mean_f32(params, src0, dst);
5566
6611
  } break;
5567
- case GGML_TYPE_Q4_0:
5568
- case GGML_TYPE_Q4_1:
5569
- case GGML_TYPE_I8:
5570
- case GGML_TYPE_I16:
5571
- case GGML_TYPE_I32:
5572
- case GGML_TYPE_F16:
5573
- case GGML_TYPE_COUNT:
6612
+ default:
5574
6613
  {
5575
6614
  GGML_ASSERT(false);
5576
6615
  } break;
@@ -5628,13 +6667,7 @@ static void ggml_compute_forward_repeat(
5628
6667
  {
5629
6668
  ggml_compute_forward_repeat_f32(params, src0, dst);
5630
6669
  } break;
5631
- case GGML_TYPE_Q4_0:
5632
- case GGML_TYPE_Q4_1:
5633
- case GGML_TYPE_I8:
5634
- case GGML_TYPE_I16:
5635
- case GGML_TYPE_I32:
5636
- case GGML_TYPE_F16:
5637
- case GGML_TYPE_COUNT:
6670
+ default:
5638
6671
  {
5639
6672
  GGML_ASSERT(false);
5640
6673
  } break;
@@ -5676,13 +6709,7 @@ static void ggml_compute_forward_abs(
5676
6709
  {
5677
6710
  ggml_compute_forward_abs_f32(params, src0, dst);
5678
6711
  } break;
5679
- case GGML_TYPE_Q4_0:
5680
- case GGML_TYPE_Q4_1:
5681
- case GGML_TYPE_I8:
5682
- case GGML_TYPE_I16:
5683
- case GGML_TYPE_I32:
5684
- case GGML_TYPE_F16:
5685
- case GGML_TYPE_COUNT:
6712
+ default:
5686
6713
  {
5687
6714
  GGML_ASSERT(false);
5688
6715
  } break;
@@ -5724,13 +6751,7 @@ static void ggml_compute_forward_sgn(
5724
6751
  {
5725
6752
  ggml_compute_forward_sgn_f32(params, src0, dst);
5726
6753
  } break;
5727
- case GGML_TYPE_Q4_0:
5728
- case GGML_TYPE_Q4_1:
5729
- case GGML_TYPE_I8:
5730
- case GGML_TYPE_I16:
5731
- case GGML_TYPE_I32:
5732
- case GGML_TYPE_F16:
5733
- case GGML_TYPE_COUNT:
6754
+ default:
5734
6755
  {
5735
6756
  GGML_ASSERT(false);
5736
6757
  } break;
@@ -5772,13 +6793,7 @@ static void ggml_compute_forward_neg(
5772
6793
  {
5773
6794
  ggml_compute_forward_neg_f32(params, src0, dst);
5774
6795
  } break;
5775
- case GGML_TYPE_Q4_0:
5776
- case GGML_TYPE_Q4_1:
5777
- case GGML_TYPE_I8:
5778
- case GGML_TYPE_I16:
5779
- case GGML_TYPE_I32:
5780
- case GGML_TYPE_F16:
5781
- case GGML_TYPE_COUNT:
6796
+ default:
5782
6797
  {
5783
6798
  GGML_ASSERT(false);
5784
6799
  } break;
@@ -5820,13 +6835,7 @@ static void ggml_compute_forward_step(
5820
6835
  {
5821
6836
  ggml_compute_forward_step_f32(params, src0, dst);
5822
6837
  } break;
5823
- case GGML_TYPE_Q4_0:
5824
- case GGML_TYPE_Q4_1:
5825
- case GGML_TYPE_I8:
5826
- case GGML_TYPE_I16:
5827
- case GGML_TYPE_I32:
5828
- case GGML_TYPE_F16:
5829
- case GGML_TYPE_COUNT:
6838
+ default:
5830
6839
  {
5831
6840
  GGML_ASSERT(false);
5832
6841
  } break;
@@ -5868,13 +6877,7 @@ static void ggml_compute_forward_relu(
5868
6877
  {
5869
6878
  ggml_compute_forward_relu_f32(params, src0, dst);
5870
6879
  } break;
5871
- case GGML_TYPE_Q4_0:
5872
- case GGML_TYPE_Q4_1:
5873
- case GGML_TYPE_I8:
5874
- case GGML_TYPE_I16:
5875
- case GGML_TYPE_I32:
5876
- case GGML_TYPE_F16:
5877
- case GGML_TYPE_COUNT:
6880
+ default:
5878
6881
  {
5879
6882
  GGML_ASSERT(false);
5880
6883
  } break;
@@ -5933,13 +6936,7 @@ static void ggml_compute_forward_gelu(
5933
6936
  {
5934
6937
  ggml_compute_forward_gelu_f32(params, src0, dst);
5935
6938
  } break;
5936
- case GGML_TYPE_Q4_0:
5937
- case GGML_TYPE_Q4_1:
5938
- case GGML_TYPE_I8:
5939
- case GGML_TYPE_I16:
5940
- case GGML_TYPE_I32:
5941
- case GGML_TYPE_F16:
5942
- case GGML_TYPE_COUNT:
6939
+ default:
5943
6940
  {
5944
6941
  GGML_ASSERT(false);
5945
6942
  } break;
@@ -6000,13 +6997,7 @@ static void ggml_compute_forward_silu(
6000
6997
  {
6001
6998
  ggml_compute_forward_silu_f32(params, src0, dst);
6002
6999
  } break;
6003
- case GGML_TYPE_Q4_0:
6004
- case GGML_TYPE_Q4_1:
6005
- case GGML_TYPE_I8:
6006
- case GGML_TYPE_I16:
6007
- case GGML_TYPE_I32:
6008
- case GGML_TYPE_F16:
6009
- case GGML_TYPE_COUNT:
7000
+ default:
6010
7001
  {
6011
7002
  GGML_ASSERT(false);
6012
7003
  } break;
@@ -6086,13 +7077,7 @@ static void ggml_compute_forward_norm(
6086
7077
  {
6087
7078
  ggml_compute_forward_norm_f32(params, src0, dst);
6088
7079
  } break;
6089
- case GGML_TYPE_Q4_0:
6090
- case GGML_TYPE_Q4_1:
6091
- case GGML_TYPE_I8:
6092
- case GGML_TYPE_I16:
6093
- case GGML_TYPE_I32:
6094
- case GGML_TYPE_F16:
6095
- case GGML_TYPE_COUNT:
7080
+ default:
6096
7081
  {
6097
7082
  GGML_ASSERT(false);
6098
7083
  } break;
@@ -6166,13 +7151,7 @@ static void ggml_compute_forward_rms_norm(
6166
7151
  {
6167
7152
  ggml_compute_forward_rms_norm_f32(params, src0, dst);
6168
7153
  } break;
6169
- case GGML_TYPE_Q4_0:
6170
- case GGML_TYPE_Q4_1:
6171
- case GGML_TYPE_I8:
6172
- case GGML_TYPE_I16:
6173
- case GGML_TYPE_I32:
6174
- case GGML_TYPE_F16:
6175
- case GGML_TYPE_COUNT:
7154
+ default:
6176
7155
  {
6177
7156
  GGML_ASSERT(false);
6178
7157
  } break;
@@ -6304,7 +7283,7 @@ static void ggml_compute_forward_mul_mat_f32(
6304
7283
  cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
6305
7284
  ne11, ne01, ne10,
6306
7285
  1.0f, y, ne10,
6307
- x, ne10,
7286
+ x, ne00,
6308
7287
  0.0f, d, ne01);
6309
7288
  }
6310
7289
  }
@@ -6476,7 +7455,7 @@ static void ggml_compute_forward_mul_mat_f16_f32(
6476
7455
  cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
6477
7456
  ne11, ne01, ne10,
6478
7457
  1.0f, y, ne10,
6479
- x, ne10,
7458
+ x, ne00,
6480
7459
  0.0f, d, ne01);
6481
7460
  }
6482
7461
  }
@@ -6564,29 +7543,6 @@ static void ggml_compute_forward_mul_mat_f16_f32(
  //}
  }

- typedef void (*dequantize_row_q_t)(const void * restrict x, float * restrict y, int k);
- typedef void (*quantize_row_q_t)(const float * restrict x, void * restrict y, int k);
- typedef void (*vec_dot_q_t)(const int n, float * restrict s, const void * restrict x, const void * restrict y);
-
- typedef struct {
- dequantize_row_q_t dequantize_row_q;
- quantize_row_q_t quantize_row_q;
- vec_dot_q_t vec_dot_q;
- } quantize_fns_t;
-
- static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
- [GGML_TYPE_Q4_0] = {
- .dequantize_row_q = dequantize_row_q4_0,
- .quantize_row_q = quantize_row_q4_0,
- .vec_dot_q = ggml_vec_dot_q4_0,
- },
- [GGML_TYPE_Q4_1] = {
- .dequantize_row_q = dequantize_row_q4_1,
- .quantize_row_q = quantize_row_q4_1,
- .vec_dot_q = ggml_vec_dot_q4_1,
- },
- };
-
  static void ggml_compute_forward_mul_mat_q_f32(
  const struct ggml_compute_params * params,
  const struct ggml_tensor * src0,
@@ -6634,8 +7590,8 @@ static void ggml_compute_forward_mul_mat_q_f32(
  GGML_ASSERT(ne3 == ne13);

  const enum ggml_type type = src0->type;
- quantize_row_q_t const quantize_row_q = quantize_fns[type].quantize_row_q;
- vec_dot_q_t const vec_dot_q = quantize_fns[type].vec_dot_q;
+ quantize_row_q_t const quantize_row_q_dot = quantize_fns[type].quantize_row_q_dot;
+ vec_dot_q_t const vec_dot_q = quantize_fns[type].vec_dot_q;

  // we don't support permuted src0 or src1
  GGML_ASSERT(nb00 == (int) GGML_TYPE_SIZE[type]);
@@ -6691,7 +7647,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
  cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
  ne11, ne01, ne10,
  1.0f, y, ne10,
- x, ne10,
+ x, ne00,
  0.0f, d, ne01);
  }
  }
@@ -6704,12 +7660,12 @@ static void ggml_compute_forward_mul_mat_q_f32(

  if (params->type == GGML_TASK_INIT) {
  char * wdata = params->wdata;
- const size_t row_size = ne10*GGML_TYPE_SIZE[type]/GGML_BLCK_SIZE[type];
+ const size_t row_size = ne10*GGML_TYPE_SIZE[GGML_TYPE_Q8_0]/GGML_BLCK_SIZE[GGML_TYPE_Q8_0];

  for (int64_t i13 = 0; i13 < ne13; ++i13) {
  for (int64_t i12 = 0; i12 < ne12; ++i12) {
  for (int64_t i11 = 0; i11 < ne11; ++i11) {
- quantize_row_q((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10);
+ quantize_row_q_dot((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10);
  wdata += row_size;
  }
  }
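The INIT phase now quantizes each row of src1 with quantize_row_q_dot into a Q8_0 scratch buffer, so the per-row scratch size is taken from the Q8_0 entry of GGML_TYPE_SIZE/GGML_BLCK_SIZE rather than from src0's type. A rough sketch of that arithmetic; the 36-byte block (one float scale plus 32 int8 values per 32 elements) is an assumption about this revision's Q8_0 layout, not code from it.

#include <stddef.h>
#include <stdio.h>

static size_t quantized_row_bytes(int ne, size_t type_size, int block_len) {
    // same shape as: ne10*GGML_TYPE_SIZE[GGML_TYPE_Q8_0]/GGML_BLCK_SIZE[GGML_TYPE_Q8_0]
    return (size_t) ne * type_size / block_len;
}

int main(void) {
    // e.g. a 4096-wide row of src1 quantized to Q8_0: 128 blocks * 36 bytes
    printf("%zu\n", quantized_row_bytes(4096, 36, 32)); // 4608
    return 0;
}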
@@ -6735,7 +7691,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
  const int ir1 = MIN(ir0 + dr, nr);

  void * wdata = params->wdata;
- const size_t row_size = ne00*GGML_TYPE_SIZE[type]/GGML_BLCK_SIZE[type];
+ const size_t row_size = ne00*GGML_TYPE_SIZE[GGML_TYPE_Q8_0]/GGML_BLCK_SIZE[GGML_TYPE_Q8_0];

  for (int ir = ir0; ir < ir1; ++ir) {
  // src0 indices
@@ -6783,6 +7739,7 @@ static void ggml_compute_forward_mul_mat(
  switch (src0->type) {
  case GGML_TYPE_Q4_0:
  case GGML_TYPE_Q4_1:
+ case GGML_TYPE_Q8_0:
  {
  ggml_compute_forward_mul_mat_q_f32(params, src0, src1, dst);
  } break;
@@ -6794,10 +7751,7 @@ static void ggml_compute_forward_mul_mat(
  {
  ggml_compute_forward_mul_mat_f32(params, src0, src1, dst);
  } break;
- case GGML_TYPE_I8:
- case GGML_TYPE_I16:
- case GGML_TYPE_I32:
- case GGML_TYPE_COUNT:
+ default:
  {
  GGML_ASSERT(false);
  } break;
@@ -6879,13 +7833,7 @@ static void ggml_compute_forward_scale(
  {
  ggml_compute_forward_scale_f32(params, src0, src1, dst);
  } break;
- case GGML_TYPE_Q4_0:
- case GGML_TYPE_Q4_1:
- case GGML_TYPE_I8:
- case GGML_TYPE_I16:
- case GGML_TYPE_I32:
- case GGML_TYPE_F16:
- case GGML_TYPE_COUNT:
+ default:
  {
  GGML_ASSERT(false);
  } break;
@@ -6901,6 +7849,15 @@ static void ggml_compute_forward_cpy(
  ggml_compute_forward_dup(params, src0, dst);
  }

+ // ggml_compute_forward_cont
+
+ static void ggml_compute_forward_cont(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ struct ggml_tensor * dst) {
+ ggml_compute_forward_dup(params, src0, dst);
+ }
+
  // ggml_compute_forward_reshape

  static void ggml_compute_forward_reshape(
@@ -7037,6 +7994,7 @@ static void ggml_compute_forward_get_rows(
  switch (src0->type) {
  case GGML_TYPE_Q4_0:
  case GGML_TYPE_Q4_1:
+ case GGML_TYPE_Q8_0:
  {
  ggml_compute_forward_get_rows_q(params, src0, src1, dst);
  } break;
@@ -7048,10 +8006,7 @@ static void ggml_compute_forward_get_rows(
  {
  ggml_compute_forward_get_rows_f32(params, src0, src1, dst);
  } break;
- case GGML_TYPE_I8:
- case GGML_TYPE_I16:
- case GGML_TYPE_I32:
- case GGML_TYPE_COUNT:
+ default:
  {
  GGML_ASSERT(false);
  } break;
@@ -7124,13 +8079,7 @@ static void ggml_compute_forward_diag_mask_inf(
  {
  ggml_compute_forward_diag_mask_inf_f32(params, src0, src1, dst);
  } break;
- case GGML_TYPE_Q4_0:
- case GGML_TYPE_Q4_1:
- case GGML_TYPE_I8:
- case GGML_TYPE_I16:
- case GGML_TYPE_I32:
- case GGML_TYPE_F16:
- case GGML_TYPE_COUNT:
+ default:
  {
  GGML_ASSERT(false);
  } break;
@@ -7218,13 +8167,7 @@ static void ggml_compute_forward_soft_max(
  {
  ggml_compute_forward_soft_max_f32(params, src0, dst);
  } break;
- case GGML_TYPE_Q4_0:
- case GGML_TYPE_Q4_1:
- case GGML_TYPE_I8:
- case GGML_TYPE_I16:
- case GGML_TYPE_I32:
- case GGML_TYPE_F16:
- case GGML_TYPE_COUNT:
+ default:
  {
  GGML_ASSERT(false);
  } break;
@@ -7279,6 +8222,8 @@ static void ggml_compute_forward_rope_f32(
  // row index used to determine which thread to use
  int ir = 0;

+ const float theta_scale = powf(10000.0, -2.0f/n_dims);
+
  for (int64_t i3 = 0; i3 < ne3; i3++) {
  for (int64_t i2 = (mode == 0 ? 0 : n_past); i2 < ne2; i2++) {
  const int p = (mode == 0 ? n_past + i2 : i2);
@@ -7286,11 +8231,13 @@ static void ggml_compute_forward_rope_f32(
  if (ir++ < ir0) continue;
  if (ir > ir1) break;

+ float theta = (float)p;
+
  for (int i0 = 0; i0 < n_dims; i0 += 2) {
- const float theta = powf(10000.0, ((float)-i0)/n_dims);
+ const float cos_theta = cosf(theta);
+ const float sin_theta = sinf(theta);

- const float cos_theta = cosf(p*theta);
- const float sin_theta = sinf(p*theta);
+ theta *= theta_scale;

  const float * const src = (float *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
  float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
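The rotation angle for even dimension i0 is p * 10000^(-i0/n_dims); the rewrite computes it incrementally by starting at theta = p and multiplying by theta_scale = 10000^(-2/n_dims) once per pair, replacing one powf per element with one powf per row. A small self-contained check of that equivalence (not from the diff):

#include <math.h>
#include <stdio.h>

int main(void) {
    const int n_dims = 128;
    const int p      = 7; // arbitrary position
    const float theta_scale = powf(10000.0f, -2.0f/n_dims);

    float theta = (float) p;
    for (int i0 = 0; i0 < n_dims; i0 += 2) {
        // reference: the value the old per-element powf would have produced
        const float ref = (float) p * powf(10000.0f, ((float) -i0)/n_dims);
        printf("i0=%3d incremental=%.6f reference=%.6f\n", i0, theta, ref);
        theta *= theta_scale;
    }
    return 0;
}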
@@ -7352,6 +8299,8 @@ static void ggml_compute_forward_rope_f16(
  // row index used to determine which thread to use
  int ir = 0;

+ const float theta_scale = powf(10000.0, -2.0f/n_dims);
+
  for (int64_t i3 = 0; i3 < ne3; i3++) {
  for (int64_t i2 = (mode == 0 ? 0 : n_past); i2 < ne2; i2++) {
  const int p = (mode == 0 ? n_past + i2 : i2);
@@ -7359,20 +8308,22 @@ static void ggml_compute_forward_rope_f16(
  if (ir++ < ir0) continue;
  if (ir > ir1) break;

+ float theta = (float)p;
+
  for (int i0 = 0; i0 < n_dims; i0 += 2) {
- const float theta = powf(10000.0, ((float)-i0)/n_dims);
+ const float cos_theta = cosf(theta);
+ const float sin_theta = sinf(theta);

- const float cos_theta = cosf(p*theta);
- const float sin_theta = sinf(p*theta);
+ theta *= theta_scale;

  const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
  ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);

- const float x0 = ggml_fp16_to_fp32(src[0]);
- const float x1 = ggml_fp16_to_fp32(src[1]);
+ const float x0 = GGML_FP16_TO_FP32(src[0]);
+ const float x1 = GGML_FP16_TO_FP32(src[1]);

- dst_data[0] = ggml_fp32_to_fp16(x0*cos_theta - x1*sin_theta);
- dst_data[1] = ggml_fp32_to_fp16(x0*sin_theta + x1*cos_theta);
+ dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
+ dst_data[1] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
  }
  }
  }
@@ -7393,12 +8344,7 @@ static void ggml_compute_forward_rope(
  {
  ggml_compute_forward_rope_f32(params, src0, src1, dst);
  } break;
- case GGML_TYPE_Q4_0:
- case GGML_TYPE_Q4_1:
- case GGML_TYPE_I8:
- case GGML_TYPE_I16:
- case GGML_TYPE_I32:
- case GGML_TYPE_COUNT:
+ default:
  {
  GGML_ASSERT(false);
  } break;
@@ -7661,12 +8607,7 @@ static void ggml_compute_forward_conv_1d_1s(
  {
  ggml_compute_forward_conv_1d_1s_f32(params, src0, src1, dst);
  } break;
- case GGML_TYPE_Q4_0:
- case GGML_TYPE_Q4_1:
- case GGML_TYPE_I8:
- case GGML_TYPE_I16:
- case GGML_TYPE_I32:
- case GGML_TYPE_COUNT:
+ default:
  {
  GGML_ASSERT(false);
  } break;
@@ -7929,12 +8870,7 @@ static void ggml_compute_forward_conv_1d_2s(
  {
  ggml_compute_forward_conv_1d_2s_f32(params, src0, src1, dst);
  } break;
- case GGML_TYPE_Q4_0:
- case GGML_TYPE_Q4_1:
- case GGML_TYPE_I8:
- case GGML_TYPE_I16:
- case GGML_TYPE_I32:
- case GGML_TYPE_COUNT:
+ default:
  {
  GGML_ASSERT(false);
  } break;
@@ -8414,12 +9350,7 @@ static void ggml_compute_forward_flash_attn(
  {
  ggml_compute_forward_flash_attn_f32(params, q, k, v, masked, dst);
  } break;
- case GGML_TYPE_Q4_0:
- case GGML_TYPE_Q4_1:
- case GGML_TYPE_I8:
- case GGML_TYPE_I16:
- case GGML_TYPE_I32:
- case GGML_TYPE_COUNT:
+ default:
  {
  GGML_ASSERT(false);
  } break;
@@ -8625,12 +9556,100 @@ static void ggml_compute_forward_flash_ff(
  {
  GGML_ASSERT(false); // TODO
  } break;
- case GGML_TYPE_Q4_0:
- case GGML_TYPE_Q4_1:
- case GGML_TYPE_I8:
- case GGML_TYPE_I16:
- case GGML_TYPE_I32:
- case GGML_TYPE_COUNT:
+ default:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+ }
+
+ // ggml_compute_forward_map_unary
+
+ static void ggml_compute_forward_map_unary_f32(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ struct ggml_tensor * dst,
+ const ggml_unary_op_f32_t fun) {
+ GGML_ASSERT(ggml_are_same_shape(src0, dst));
+
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+ return;
+ }
+
+ const int n = ggml_nrows(src0);
+ const int nc = src0->ne[0];
+
+ assert( dst->nb[0] == sizeof(float));
+ assert(src0->nb[0] == sizeof(float));
+
+ for (int i = 0; i < n; i++) {
+ fun(nc,
+ (float *) ((char *) dst->data + i*( dst->nb[1])),
+ (float *) ((char *) src0->data + i*(src0->nb[1])));
+ }
+ }
+
+
+ static void ggml_compute_forward_map_unary(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ struct ggml_tensor * dst,
+ const ggml_unary_op_f32_t fun) {
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_map_unary_f32(params, src0, dst, fun);
+ } break;
+ default:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+ }
+
+ // ggml_compute_forward_map_binary
+
+ static void ggml_compute_forward_map_binary_f32(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ const struct ggml_tensor * src1,
+ struct ggml_tensor * dst,
+ const ggml_binary_op_f32_t fun) {
+ assert(params->ith == 0);
+ assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+ return;
+ }
+
+ const int n = ggml_nrows(src0);
+ const int nc = src0->ne[0];
+
+ assert( dst->nb[0] == sizeof(float));
+ assert(src0->nb[0] == sizeof(float));
+ assert(src1->nb[0] == sizeof(float));
+
+ for (int i = 0; i < n; i++) {
+ fun(nc,
+ (float *) ((char *) dst->data + i*( dst->nb[1])),
+ (float *) ((char *) src0->data + i*(src0->nb[1])),
+ (float *) ((char *) src1->data + i*(src1->nb[1])));
+ }
+ }
+
+
+ static void ggml_compute_forward_map_binary(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ const struct ggml_tensor * src1,
+ struct ggml_tensor * dst,
+ const ggml_binary_op_f32_t fun) {
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_map_binary_f32(params, src0, src1, dst, fun);
+ } break;
+ default:
  {
  GGML_ASSERT(false);
  } break;
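The new map ops hand each row to a user-supplied callback as fun(nc, dst_row, src_row) (plus a src1_row in the binary case). A standalone sketch of that row-wise contract; the typedef below mirrors ggml_unary_op_f32_t, but the whole example is illustrative and does not use the ggml API itself.

#include <stdio.h>

typedef void (*unary_op_f32_t)(const int n, float * dst, const float * src);

// an arbitrary element-wise function with the expected signature
static void clamp_to_unit(const int n, float * dst, const float * src) {
    for (int i = 0; i < n; i++) {
        dst[i] = src[i] < 0.0f ? 0.0f : (src[i] > 1.0f ? 1.0f : src[i]);
    }
}

int main(void) {
    // two rows of three elements, processed one row at a time the way
    // ggml_compute_forward_map_unary_f32 walks its rows
    float src[2][3] = { {-1.0f, 0.5f, 2.0f}, {0.1f, 1.5f, -0.2f} };
    float dst[2][3];

    const unary_op_f32_t fun = clamp_to_unit;
    for (int i = 0; i < 2; i++) {
        fun(3, dst[i], src[i]);
    }

    for (int i = 0; i < 2; i++) {
        printf("%g %g %g\n", dst[i][0], dst[i][1], dst[i][2]);
    }
    return 0;
}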
@@ -8731,6 +9750,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
  {
  ggml_compute_forward_cpy(params, tensor->src0, tensor);
  } break;
+ case GGML_OP_CONT:
+ {
+ ggml_compute_forward_cont(params, tensor->src0, tensor);
+ } break;
  case GGML_OP_RESHAPE:
  {
  ggml_compute_forward_reshape(params, tensor->src0, tensor);
@@ -8782,6 +9805,18 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
  {
  ggml_compute_forward_flash_ff(params, tensor->src0, tensor->src1, tensor->opt[0], tensor->opt[1], tensor->opt[2], tensor);
  } break;
+ case GGML_OP_MAP_UNARY:
+ {
+ const ggml_unary_op_f32_t fun = *((ggml_unary_op_f32_t *)tensor->opt[0]->data);
+ ggml_compute_forward_map_unary(params, tensor->src0, tensor, fun);
+ }
+ break;
+ case GGML_OP_MAP_BINARY:
+ {
+ const ggml_binary_op_f32_t fun = *((ggml_binary_op_f32_t *)tensor->opt[0]->data);
+ ggml_compute_forward_map_binary(params, tensor->src0, tensor->src1, tensor, fun);
+ }
+ break;
  case GGML_OP_NONE:
  {
  // nop
@@ -8975,8 +10010,9 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
  src1->grad =
  ggml_add_impl(ctx,
  src1->grad,
- // TODO: fix transpose, the node will break the graph connections
- ggml_mul_mat(ctx, ggml_transpose(ctx, src0), tensor->grad),
+ ggml_mul_mat(ctx,
+ ggml_cont(ctx, ggml_transpose(ctx, src0)),
+ tensor->grad),
  inplace);
  }
  } break;
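ggml_transpose only permutes strides, so the backward pass now wraps it in ggml_cont (backed by the new GGML_OP_CONT, which reuses the dup kernel) to hand ggml_mul_mat a contiguous operand. A rough usage sketch of that pattern, assuming this revision's public ggml.h; the tensor sizes and the 16 MiB arena are arbitrary and nothing is computed, only the graph nodes are built.

#include "ggml.h"

int main(void) {
    struct ggml_init_params params = { .mem_size = 16*1024*1024, .mem_buffer = NULL };
    struct ggml_context * ctx = ggml_init(params);

    struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 32, 8); // K = 32, M = 8
    struct ggml_tensor * g = ggml_new_tensor_2d(ctx, GGML_TYPE_F32,  8, 4); // gradient of the M x 4 output

    // transpose is a strided view; cont materializes it so mul_mat sees
    // contiguous memory, mirroring the backward-pass change above
    struct ggml_tensor * at = ggml_cont(ctx, ggml_transpose(ctx, a));
    struct ggml_tensor * d  = ggml_mul_mat(ctx, at, g); // shape (32, 4), i.e. grad w.r.t. src1
    (void) d;

    ggml_free(ctx);
    return 0;
}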
@@ -8988,6 +10024,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
  {
  GGML_ASSERT(false); // TODO: not implemented
  } break;
+ case GGML_OP_CONT:
+ {
+ GGML_ASSERT(false); // TODO: not implemented
+ } break;
  case GGML_OP_RESHAPE:
  {
  GGML_ASSERT(false); // TODO: not implemented
@@ -9036,6 +10076,11 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
  {
  GGML_ASSERT(false); // not supported
  } break;
+ case GGML_OP_MAP_UNARY:
+ case GGML_OP_MAP_BINARY:
+ {
+ GGML_ASSERT(false); // not supported
+ } break;
  case GGML_OP_NONE:
  {
  // nop
@@ -9126,7 +10171,7 @@ struct ggml_cgraph ggml_build_forward(struct ggml_tensor * tensor) {
  struct ggml_cgraph result = {
  /*.n_nodes =*/ 0,
  /*.n_leafs =*/ 0,
- /*.n_threads =*/ 0,
+ /*.n_threads =*/ GGML_DEFAULT_N_THREADS,
  /*.work_size =*/ 0,
  /*.work =*/ NULL,
  /*.nodes =*/ { NULL },
@@ -9354,13 +10399,29 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
  struct ggml_tensor * node = cgraph->nodes[i];

  switch (node->op) {
+ case GGML_OP_CPY:
  case GGML_OP_DUP:
  {
  node->n_tasks = 1;
+
+ size_t cur = 0;
+ if (node->type == GGML_TYPE_Q4_0 || node->type == GGML_TYPE_Q4_1) {
+ cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->ne[0];
+ }
+
+ work_size = MAX(work_size, cur);
  } break;
  case GGML_OP_ADD:
  {
  node->n_tasks = n_threads;
+
+ size_t cur = 0;
+
+ if (node->src0->type == GGML_TYPE_Q4_0 || node->src0->type == GGML_TYPE_Q4_1) {
+ cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->src0->ne[0] * n_threads;
+ }
+
+ work_size = MAX(work_size, cur);
  } break;
  case GGML_OP_SUB:
  case GGML_OP_MUL:
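CPY/DUP and ADD on Q4_0/Q4_1 operands now reserve scratch space for dequantizing a row to F32, and ADD needs one row per thread since each thread handles its own rows. A worked example of the sizing with hypothetical values (ne[0] = 4096, n_threads = 8, and 4 standing in for GGML_TYPE_SIZE[GGML_TYPE_F32]):

#include <stdio.h>

int main(void) {
    const int ne0       = 4096;
    const int n_threads = 8;

    const unsigned long cur_dup = 4ul * ne0;             // one F32 row
    const unsigned long cur_add = 4ul * ne0 * n_threads; // one F32 row per thread

    printf("DUP/CPY scratch: %lu bytes\n", cur_dup); // 16384
    printf("ADD scratch:     %lu bytes\n", cur_add); // 131072
    return 0;
}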
@@ -9429,7 +10490,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
  } else
  #endif
  {
- cur = GGML_TYPE_SIZE[node->src0->type]*ggml_nelements(node->src1)/GGML_BLCK_SIZE[node->src0->type];
+ cur = GGML_TYPE_SIZE[GGML_TYPE_Q8_0]*ggml_nelements(node->src1)/GGML_BLCK_SIZE[GGML_TYPE_Q8_0];
  }
  } else {
  GGML_ASSERT(false);
@@ -9441,7 +10502,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
  {
  node->n_tasks = n_threads;
  } break;
- case GGML_OP_CPY:
+ case GGML_OP_CONT:
  case GGML_OP_RESHAPE:
  case GGML_OP_VIEW:
  case GGML_OP_PERMUTE:
@@ -9527,6 +10588,11 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)

  work_size = MAX(work_size, cur);
  } break;
+ case GGML_OP_MAP_UNARY:
+ case GGML_OP_MAP_BINARY:
+ {
+ node->n_tasks = 1;
+ } break;
  case GGML_OP_NONE:
  {
  node->n_tasks = 1;
@@ -9745,8 +10811,8 @@ void ggml_graph_print(const struct ggml_cgraph * cgraph) {

  GGML_PRINT("=== GRAPH ===\n");

- GGML_PRINT_DEBUG("n_threads = %d\n", cgraph->n_threads);
- GGML_PRINT_DEBUG("total work size = %zu bytes\n",cgraph->work_size);
+ GGML_PRINT_DEBUG("n_threads = %d\n", cgraph->n_threads);
+ GGML_PRINT_DEBUG("total work size = %zu bytes\n", cgraph->work_size);

  GGML_PRINT("n_nodes = %d\n", cgraph->n_nodes);
  for (int i = 0; i < cgraph->n_nodes; i++) {
@@ -10598,16 +11664,16 @@ enum ggml_opt_result ggml_opt(
  ////////////////////////////////////////////////////////////////////////////////

  size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t * hist) {
- assert(k % QK == 0);
- const int nb = k / QK;
+ assert(k % QK4_0 == 0);
+ const int nb = k / QK4_0;

  for (int j = 0; j < n; j += k) {
- block_q4_0 * restrict y = (block_q4_0 *)dst + j/QK;
+ block_q4_0 * restrict y = (block_q4_0 *)dst + j/QK4_0;

  quantize_row_q4_0_reference(src + j, y, k);

  for (int i = 0; i < nb; i++) {
- for (int l = 0; l < QK; l += 2) {
+ for (int l = 0; l < QK4_0; l += 2) {
  const uint8_t vi0 = y[i].qs[l/2] & 0xF;
  const uint8_t vi1 = y[i].qs[l/2] >> 4;

@@ -10617,20 +11683,20 @@ size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t *
  }
  }

- return (n/QK*sizeof(block_q4_0));
+ return (n/QK4_0*sizeof(block_q4_0));
  }

  size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t * hist) {
- assert(k % QK == 0);
- const int nb = k / QK;
+ assert(k % QK4_1 == 0);
+ const int nb = k / QK4_1;

  for (int j = 0; j < n; j += k) {
- block_q4_1 * restrict y = (block_q4_1 *)dst + j/QK;
+ block_q4_1 * restrict y = (block_q4_1 *)dst + j/QK4_1;

  quantize_row_q4_1_reference(src + j, y, k);

  for (int i = 0; i < nb; i++) {
- for (int l = 0; l < QK; l += 2) {
+ for (int l = 0; l < QK4_1; l += 2) {
  const uint8_t vi0 = y[i].qs[l/2] & 0xF;
  const uint8_t vi1 = y[i].qs[l/2] >> 4;

@@ -10640,7 +11706,7 @@ size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t *
  }
  }

- return (n/QK*sizeof(block_q4_1));
+ return (n/QK4_1*sizeof(block_q4_1));
  }

  ////////////////////////////////////////////////////////////////////////////////
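The return value counts the bytes of Q4 blocks written: n elements form n/QK4_x blocks of sizeof(block_q4_x) bytes each. A worked example of that arithmetic, assuming QK4_0 = 32 and a 20-byte block_q4_0 (one float scale plus QK4_0/2 bytes of packed nibbles), which are assumptions about this ggml revision's block layout.

#include <stddef.h>
#include <stdio.h>

int main(void) {
    const int    n          = 4096 * 4096;              // hypothetical element count
    const int    QK4_0      = 32;
    const size_t block_size = sizeof(float) + QK4_0/2;  // 20 bytes per block

    // n/QK4_0 blocks * 20 bytes each
    printf("%zu bytes\n", (size_t) n / QK4_0 * block_size); // 10485760
    return 0;
}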
@@ -10669,6 +11735,22 @@ int ggml_cpu_has_avx512(void) {
  #endif
  }

+ int ggml_cpu_has_avx512_vbmi(void) {
+ #if defined(__AVX512VBMI__)
+ return 1;
+ #else
+ return 0;
+ #endif
+ }
+
+ int ggml_cpu_has_avx512_vnni(void) {
+ #if defined(__AVX512VNNI__)
+ return 1;
+ #else
+ return 0;
+ #endif
+ }
+
  int ggml_cpu_has_fma(void) {
  #if defined(__FMA__)
  return 1;