llama_cpp 0.0.3 → 0.0.5

This diff shows the changes between publicly released versions of the package as published to the supported registries. It is provided for informational purposes only.
@@ -1,4 +1,4 @@
1
- // Defines CLOCK_MONOTONIC and asprintf on Linux
1
+ // Defines CLOCK_MONOTONIC on Linux
2
2
  #define _GNU_SOURCE
3
3
 
4
4
  #include "ggml.h"
@@ -26,14 +26,9 @@
26
26
  #define static_assert(cond, msg) struct global_scope_noop_trick
27
27
  #endif
28
28
 
29
- #if defined _MSC_VER || defined(__MINGW32__)
29
+ #if defined(_WIN32)
30
30
 
31
- #if !defined(__MINGW32__)
32
- #include <Windows.h>
33
- #else
34
- // ref: https://github.com/ggerganov/whisper.cpp/issues/168
35
31
  #include <windows.h>
36
- #endif
37
32
 
38
33
  typedef volatile LONG atomic_int;
39
34
  typedef atomic_int atomic_bool;
@@ -55,6 +50,7 @@ typedef HANDLE pthread_t;
55
50
 
56
51
  typedef DWORD thread_ret_t;
57
52
  static int pthread_create(pthread_t* out, void* unused, thread_ret_t(*func)(void*), void* arg) {
53
+ (void) unused;
58
54
  HANDLE handle = CreateThread(NULL, 0, (LPTHREAD_START_ROUTINE) func, arg, 0, NULL);
59
55
  if (handle == NULL)
60
56
  {
@@ -66,6 +62,7 @@ static int pthread_create(pthread_t* out, void* unused, thread_ret_t(*func)(void
66
62
  }
67
63
 
68
64
  static int pthread_join(pthread_t thread, void* unused) {
65
+ (void) unused;
69
66
  return (int) WaitForSingleObject(thread, INFINITE);
70
67
  }
71
68
 
@@ -97,17 +94,6 @@ typedef void* thread_ret_t;
97
94
  #define static_assert(cond, msg) _Static_assert(cond, msg)
98
95
  #endif
99
96
 
100
- #define GGML_MLOCK_SUPPORT 0
101
-
102
- #ifdef __has_include
103
- #if __has_include(<sys/mman.h>)
104
- #undef GGML_MLOCK_SUPPORT
105
- #define GGML_MLOCK_SUPPORT 1
106
- #include <sys/mman.h>
107
- #endif
108
- #endif
109
-
110
-
111
97
  /*#define GGML_PERF*/
112
98
  #define GGML_DEBUG 0
113
99
  #define GGML_GELU_FP16
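
Note: the block removed in the hunk above used the __has_include preprocessor operator to detect <sys/mman.h> before enabling mlock support. For readers unfamiliar with the idiom, here is a minimal standalone sketch of the same feature-detection pattern (illustrative code, not part of ggml):

    // __has_include feature detection, as in the removed GGML_MLOCK_SUPPORT block.
    #ifdef __has_include
    #if __has_include(<sys/mman.h>)
    #include <sys/mman.h>
    #define HAVE_SYS_MMAN_H 1
    #endif
    #endif

    #include <stdio.h>

    int main(void) {
    #ifdef HAVE_SYS_MMAN_H
        puts("sys/mman.h found");
    #else
        puts("sys/mman.h not found");
    #endif
        return 0;
    }
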
@@ -128,6 +114,23 @@ typedef void* thread_ret_t;
128
114
  #define GGML_MEM_ALIGN 16
129
115
  #endif
130
116
 
117
+ #if defined(_MSC_VER) || defined(__MINGW32__)
118
+ #define GGML_ALIGNED_MALLOC(size) _aligned_malloc(size, GGML_MEM_ALIGN)
119
+ #define GGML_ALIGNED_FREE(ptr) _aligned_free(ptr)
120
+ #else
121
+ inline static void* ggml_aligned_malloc(size_t size) {
122
+ void* aligned_memory = NULL;
123
+ int result = posix_memalign(&aligned_memory, GGML_MEM_ALIGN, size);
124
+ if (result != 0) {
125
+ // Handle allocation failure
126
+ return NULL;
127
+ }
128
+ return aligned_memory;
129
+ }
130
+ #define GGML_ALIGNED_MALLOC(size) ggml_aligned_malloc(size)
131
+ #define GGML_ALIGNED_FREE(ptr) free(ptr)
132
+ #endif
133
+
131
134
  #define UNUSED(x) (void)(x)
132
135
  #define SWAP(x, y, T) do { T SWAP = x; x = y; y = SWAP; } while (0)
133
136
 
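Note: the new GGML_ALIGNED_MALLOC/GGML_ALIGNED_FREE pair added above uses _aligned_malloc on MSVC/MinGW and posix_memalign everywhere else. A minimal standalone sketch of the posix_memalign path (the names outside the allocation call are illustrative only):

    #define _POSIX_C_SOURCE 200112L  /* for posix_memalign */
    #include <stdint.h>
    #include <stdio.h>
    #include <stdlib.h>

    #define MEM_ALIGN 16  /* stands in for GGML_MEM_ALIGN */

    /* Same shape as ggml_aligned_malloc above: NULL on failure, plain free() to release. */
    static void * aligned_malloc_sketch(size_t size) {
        void * ptr = NULL;
        if (posix_memalign(&ptr, MEM_ALIGN, size) != 0) {
            return NULL;
        }
        return ptr;
    }

    int main(void) {
        void * p = aligned_malloc_sketch(1024);
        if (p != NULL) {
            printf("aligned to %d bytes: %s\n", MEM_ALIGN,
                   ((uintptr_t) p % MEM_ALIGN) == 0 ? "yes" : "no");
            free(p);
        }
        return 0;
    }
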
@@ -242,12 +245,12 @@ static inline float fp32_from_bits(uint32_t w) {
242
245
  }
243
246
 
244
247
  static inline uint32_t fp32_to_bits(float f) {
245
- union {
246
- float as_value;
247
- uint32_t as_bits;
248
- } fp32;
249
- fp32.as_value = f;
250
- return fp32.as_bits;
248
+ union {
249
+ float as_value;
250
+ uint32_t as_bits;
251
+ } fp32;
252
+ fp32.as_value = f;
253
+ return fp32.as_bits;
251
254
  }
252
255
 
253
256
  static inline float ggml_compute_fp16_to_fp32(ggml_fp16_t h) {
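
Note: the fp32_to_bits change in this hunk is indentation only; the function still reinterprets a float through a union, which is well defined in C. A tiny usage sketch (illustrative code, not from the package):

    #include <stdint.h>
    #include <stdio.h>

    /* Same union trick as fp32_to_bits: read the bytes of a float as a uint32_t. */
    static uint32_t fp32_to_bits_sketch(float f) {
        union { float as_value; uint32_t as_bits; } fp32;
        fp32.as_value = f;
        return fp32.as_bits;
    }

    int main(void) {
        printf("0x%08x\n", (unsigned) fp32_to_bits_sketch(1.0f)); /* 0x3f800000 in IEEE-754 */
        return 0;
    }
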
@@ -424,8 +427,6 @@ static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
424
427
  // quantization
425
428
  //
426
429
 
427
- #define QK 32
428
-
429
430
  // AVX routines provided by GH user Const-me
430
431
  // ref: https://github.com/ggerganov/ggml/pull/27#issuecomment-1464934600
431
432
  #if __AVX2__ || __AVX512F__
@@ -497,37 +498,113 @@ static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
497
498
  }
498
499
  #endif
499
500
 
500
- // method 5
501
- // blocks of QK elements
502
- // represented with a single float (delta) and QK/2 8-bit ints (i.e QK 4-bit signed integer factors)
501
+ #if __ARM_NEON
502
+
503
+ #if !defined(__aarch64__)
504
+
505
+ inline static uint16_t vaddvq_u8(uint8x16_t v) {
506
+ return
507
+ (uint16_t)vgetq_lane_u8(v, 0) + (uint16_t)vgetq_lane_u8(v, 1) +
508
+ (uint16_t)vgetq_lane_u8(v, 2) + (uint16_t)vgetq_lane_u8(v, 3) +
509
+ (uint16_t)vgetq_lane_u8(v, 4) + (uint16_t)vgetq_lane_u8(v, 5) +
510
+ (uint16_t)vgetq_lane_u8(v, 6) + (uint16_t)vgetq_lane_u8(v, 7) +
511
+ (uint16_t)vgetq_lane_u8(v, 8) + (uint16_t)vgetq_lane_u8(v, 9) +
512
+ (uint16_t)vgetq_lane_u8(v, 10) + (uint16_t)vgetq_lane_u8(v, 11) +
513
+ (uint16_t)vgetq_lane_u8(v, 12) + (uint16_t)vgetq_lane_u8(v, 13) +
514
+ (uint16_t)vgetq_lane_u8(v, 14) + (uint16_t)vgetq_lane_u8(v, 15);
515
+ }
516
+
517
+ inline static int32_t vaddvq_s16(int16x8_t v) {
518
+ return
519
+ (int32_t)vgetq_lane_s16(v, 0) + (int32_t)vgetq_lane_s16(v, 1) +
520
+ (int32_t)vgetq_lane_s16(v, 2) + (int32_t)vgetq_lane_s16(v, 3) +
521
+ (int32_t)vgetq_lane_s16(v, 4) + (int32_t)vgetq_lane_s16(v, 5) +
522
+ (int32_t)vgetq_lane_s16(v, 6) + (int32_t)vgetq_lane_s16(v, 7);
523
+ }
524
+
525
+ inline static uint32_t vaddvq_u16(uint16x8_t v) {
526
+ return
527
+ (uint32_t)vgetq_lane_u16(v, 0) + (uint32_t)vgetq_lane_u16(v, 1) +
528
+ (uint32_t)vgetq_lane_u16(v, 2) + (uint32_t)vgetq_lane_u16(v, 3) +
529
+ (uint32_t)vgetq_lane_u16(v, 4) + (uint32_t)vgetq_lane_u16(v, 5) +
530
+ (uint32_t)vgetq_lane_u16(v, 6) + (uint32_t)vgetq_lane_u16(v, 7);
531
+ }
532
+
533
+ inline static int32_t vaddvq_s32(int32x4_t v) {
534
+ return vgetq_lane_s32(v, 0) + vgetq_lane_s32(v, 1) + vgetq_lane_s32(v, 2) + vgetq_lane_s32(v, 3);
535
+ }
536
+
537
+ inline static float vaddvq_f32(float32x4_t v) {
538
+ return vgetq_lane_f32(v, 0) + vgetq_lane_f32(v, 1) + vgetq_lane_f32(v, 2) + vgetq_lane_f32(v, 3);
539
+ }
540
+
541
+ float vminvq_f32(float32x4_t v) {
542
+ return
543
+ MIN(MIN(vgetq_lane_f32(v, 0), vgetq_lane_f32(v, 1)),
544
+ MIN(vgetq_lane_f32(v, 2), vgetq_lane_f32(v, 3)));
545
+ }
546
+
547
+ float vmaxvq_f32(float32x4_t v) {
548
+ return
549
+ MAX(MAX(vgetq_lane_f32(v, 0), vgetq_lane_f32(v, 1)),
550
+ MAX(vgetq_lane_f32(v, 2), vgetq_lane_f32(v, 3)));
551
+ }
552
+
553
+ int8x8_t vzip1_s8(int8x8_t a, int8x8_t b) {
554
+ return vget_low_s8(vcombine_s8(a, b));
555
+ }
556
+
557
+ int8x8_t vzip2_s8(int8x8_t a, int8x8_t b) {
558
+ return vget_high_s8(vcombine_s8(a, b));
559
+ }
560
+
561
+ uint8x8_t vzip1_u8(uint8x8_t a, uint8x8_t b) {
562
+ return vget_low_u8(vcombine_u8(a, b));
563
+ }
564
+
565
+ uint8x8_t vzip2_u8(uint8x8_t a, uint8x8_t b) {
566
+ return vget_high_u8(vcombine_u8(a, b));
567
+ }
568
+
569
+ #endif
570
+ #endif
571
+
572
+
573
+ #define QK4_0 32
503
574
  typedef struct {
504
- float d; // delta
505
- uint8_t qs[QK / 2]; // nibbles / quants
575
+ float d; // delta
576
+ uint8_t qs[QK4_0 / 2]; // nibbles / quants
506
577
  } block_q4_0;
507
- static_assert(sizeof(block_q4_0) == sizeof(float) + QK / 2, "wrong q4_0 block size/padding");
578
+ static_assert(sizeof(block_q4_0) == sizeof(float) + QK4_0 / 2, "wrong q4_0 block size/padding");
508
579
 
509
- // method 4
510
- // blocks of QK elements
511
- // represented with 2 floats (delta + min) and QK/2 8-bit ints (i.e QK 4-bit unsigned integer factors)
580
+ #define QK4_1 32
512
581
  typedef struct {
513
- float d;
514
- float m;
515
- uint8_t qs[QK / 2]; // nibbles / quants
582
+ float d; // delta
583
+ float m; // min
584
+ uint8_t qs[QK4_1 / 2]; // nibbles / quants
516
585
  } block_q4_1;
517
- static_assert(sizeof(block_q4_1) == sizeof(float) * 2 + QK / 2, "wrong q4_1 block size/padding");
586
+ static_assert(sizeof(block_q4_1) == sizeof(float) * 2 + QK4_1 / 2, "wrong q4_1 block size/padding");
587
+
588
+ #define QK8_0 32
589
+ typedef struct {
590
+ float d; // delta
591
+ int8_t qs[QK8_0]; // quants
592
+ } block_q8_0;
593
+ static_assert(sizeof(block_q8_0) == sizeof(float) + QK8_0, "wrong q8_0 block size/padding");
594
+
518
595
 
519
596
  // reference implementation for deterministic creation of model files
520
597
  static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k) {
521
- assert(k % QK == 0);
522
- const int nb = k / QK;
598
+ assert(k % QK4_0 == 0);
599
+ const int nb = k / QK4_0;
523
600
 
524
- uint8_t pp[QK/2];
601
+ uint8_t pp[QK4_0/2];
525
602
 
526
603
  for (int i = 0; i < nb; i++) {
527
604
  float amax = 0.0f; // absolute max
528
605
 
529
- for (int l = 0; l < QK; l++) {
530
- const float v = x[i*QK + l];
606
+ for (int l = 0; l < QK4_0; l++) {
607
+ const float v = x[i*QK4_0 + l];
531
608
  amax = MAX(amax, fabsf(v));
532
609
  }
533
610
 
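Note: with the per-format constants introduced above, each 32-element block costs 20 bytes in Q4_0 (one float scale plus 16 packed nibbles), 24 bytes in Q4_1 (two floats plus nibbles) and 36 bytes in Q8_0 (one float plus 32 signed bytes), which is exactly what the static_asserts check. A small sketch that turns this into bytes per row, assuming (as the asserts do) that the row length is a multiple of 32; everything outside the struct layouts is illustrative:

    #include <stddef.h>
    #include <stdint.h>
    #include <stdio.h>

    #define QK4_0 32
    #define QK4_1 32
    #define QK8_0 32

    typedef struct { float d;          uint8_t qs[QK4_0 / 2]; } block_q4_0; /* 20 bytes */
    typedef struct { float d; float m; uint8_t qs[QK4_1 / 2]; } block_q4_1; /* 24 bytes */
    typedef struct { float d;          int8_t  qs[QK8_0];     } block_q8_0; /* 36 bytes */

    /* Bytes needed to store a row of n floats in each format (n % 32 == 0). */
    static size_t row_size_q4_0(int n) { return (size_t)(n / QK4_0) * sizeof(block_q4_0); }
    static size_t row_size_q4_1(int n) { return (size_t)(n / QK4_1) * sizeof(block_q4_1); }
    static size_t row_size_q8_0(int n) { return (size_t)(n / QK8_0) * sizeof(block_q8_0); }

    int main(void) {
        const int n = 4096; /* one hypothetical row */
        printf("f32: %zu  q8_0: %zu  q4_1: %zu  q4_0: %zu bytes\n",
               n * sizeof(float), row_size_q8_0(n), row_size_q4_1(n), row_size_q4_0(n));
        return 0;
    }
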
@@ -536,9 +613,9 @@ static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * r
536
613
 
537
614
  y[i].d = d;
538
615
 
539
- for (int l = 0; l < QK; l += 2) {
540
- const float v0 = x[i*QK + l + 0]*id;
541
- const float v1 = x[i*QK + l + 1]*id;
616
+ for (int l = 0; l < QK4_0; l += 2) {
617
+ const float v0 = x[i*QK4_0 + l + 0]*id;
618
+ const float v1 = x[i*QK4_0 + l + 1]*id;
542
619
 
543
620
  const uint8_t vi0 = (int8_t)roundf(v0) + 8;
544
621
  const uint8_t vi1 = (int8_t)roundf(v1) + 8;
@@ -554,8 +631,8 @@ static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * r
554
631
  }
555
632
 
556
633
  static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int k) {
557
- assert(k % QK == 0);
558
- const int nb = k / QK;
634
+ assert(k % QK4_0 == 0);
635
+ const int nb = k / QK4_0;
559
636
 
560
637
  block_q4_0 * restrict y = vy;
561
638
 
@@ -610,10 +687,7 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
610
687
  for (int l = 0; l < 2; l++) amaxv[4*l] = vmaxq_f32(amaxv[4*l], amaxv[4*l+2]);
611
688
  for (int l = 0; l < 1; l++) amaxv[8*l] = vmaxq_f32(amaxv[8*l], amaxv[8*l+4]);
612
689
 
613
- // absolute max
614
- const float amax = MAX(
615
- MAX(vgetq_lane_f32(amaxv[0], 0), vgetq_lane_f32(amaxv[0], 1)),
616
- MAX(vgetq_lane_f32(amaxv[0], 2), vgetq_lane_f32(amaxv[0], 3)));
690
+ const float amax = vmaxvq_f32(amaxv[0]);
617
691
 
618
692
  const float d = amax / ((1 << 3) - 1);
619
693
  const float id = d ? 1.0f/d : 0.0f;
@@ -808,19 +882,19 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
808
882
  }
809
883
 
810
884
  static void quantize_row_q4_1_reference(const float * restrict x, void * restrict vy, int k) {
811
- assert(k % QK == 0);
812
- const int nb = k / QK;
885
+ assert(k % QK4_1 == 0);
886
+ const int nb = k / QK4_1;
813
887
 
814
888
  block_q4_1 * restrict y = vy;
815
889
 
816
- uint8_t pp[QK/2];
890
+ uint8_t pp[QK4_1/2];
817
891
 
818
892
  for (int i = 0; i < nb; i++) {
819
893
  float min = FLT_MAX;
820
894
  float max = -FLT_MAX;
821
895
 
822
- for (int l = 0; l < QK; l++) {
823
- const float v = x[i*QK + l];
896
+ for (int l = 0; l < QK4_1; l++) {
897
+ const float v = x[i*QK4_1 + l];
824
898
  if (v < min) min = v;
825
899
  if (v > max) max = v;
826
900
  }
@@ -831,9 +905,9 @@ static void quantize_row_q4_1_reference(const float * restrict x, void * restric
831
905
  y[i].d = d;
832
906
  y[i].m = min;
833
907
 
834
- for (int l = 0; l < QK; l += 2) {
835
- const float v0 = (x[i*QK + l + 0] - min)*id;
836
- const float v1 = (x[i*QK + l + 1] - min)*id;
908
+ for (int l = 0; l < QK4_1; l += 2) {
909
+ const float v0 = (x[i*QK4_1 + l + 0] - min)*id;
910
+ const float v1 = (x[i*QK4_1 + l + 1] - min)*id;
837
911
 
838
912
  const uint8_t vi0 = roundf(v0);
839
913
  const uint8_t vi1 = roundf(v1);
@@ -849,9 +923,9 @@ static void quantize_row_q4_1_reference(const float * restrict x, void * restric
849
923
  }
850
924
 
851
925
  static void quantize_row_q4_1(const float * restrict x, void * restrict vy, int k) {
852
- assert(k % QK == 0);
926
+ assert(k % QK4_1 == 0);
853
927
 
854
- const int nb = k / QK;
928
+ const int nb = k / QK4_1;
855
929
 
856
930
  block_q4_1 * restrict y = vy;
857
931
 
@@ -935,7 +1009,7 @@ static void quantize_row_q4_1(const float * restrict x, void * restrict vy, int
935
1009
  float32x4_t minv[8];
936
1010
  float32x4_t maxv[8];
937
1011
 
938
- for (int l = 0; l < 8; l++) srcv[l] = vld1q_f32(x + i*32 + 4*l);
1012
+ for (int l = 0; l < 8; l++) srcv[l] = vld1q_f32(x + i*QK4_1 + 4*l);
939
1013
 
940
1014
  for (int l = 0; l < 4; l++) minv[2*l] = vminq_f32(srcv[2*l], srcv[2*l + 1]);
941
1015
  for (int l = 0; l < 2; l++) minv[4*l] = vminq_f32(minv[4*l], minv[4*l + 2]);
@@ -958,7 +1032,8 @@ static void quantize_row_q4_1(const float * restrict x, void * restrict vy, int
958
1032
 
959
1033
  for (int l = 0; l < 8; l++) {
960
1034
  const float32x4_t v = vmulq_n_f32(vsubq_f32(srcv[l], minv0), id);
961
- const int32x4_t vi = vcvtq_s32_f32(v);
1035
+ const float32x4_t vf = vaddq_f32(v, vdupq_n_f32(0.5f)); // needed to round to nearest
1036
+ const int32x4_t vi = vcvtq_s32_f32(vf);
962
1037
 
963
1038
  y[i].qs[2*l + 0] = vgetq_lane_s32(vi, 0) | (vgetq_lane_s32(vi, 1) << 4);
964
1039
  y[i].qs[2*l + 1] = vgetq_lane_s32(vi, 2) | (vgetq_lane_s32(vi, 3) << 4);
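
Note: the added vaddq_f32(v, vdupq_n_f32(0.5f)) exists because vcvtq_s32_f32 truncates toward zero, while the Q4_1 values converted here are non-negative and should be rounded to the nearest integer, matching the roundf in the reference path. A scalar illustration of the difference (illustrative code):

    #include <stdio.h>

    int main(void) {
        const float v = 6.7f;
        const int truncated = (int) v;           /* 6: what a plain float->int conversion does */
        const int rounded   = (int) (v + 0.5f);  /* 7: round-to-nearest for v >= 0 */
        printf("truncated=%d rounded=%d\n", truncated, rounded);
        return 0;
    }
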
@@ -970,9 +1045,160 @@ static void quantize_row_q4_1(const float * restrict x, void * restrict vy, int
970
1045
  #endif
971
1046
  }
972
1047
 
1048
+ // reference implementation for deterministic creation of model files
1049
+ static void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * restrict y, int k) {
1050
+ assert(k % QK8_0 == 0);
1051
+ const int nb = k / QK8_0;
1052
+
1053
+ for (int i = 0; i < nb; i++) {
1054
+ float amax = 0.0f; // absolute max
1055
+
1056
+ for (int l = 0; l < QK8_0; l++) {
1057
+ const float v = x[i*QK8_0 + l];
1058
+ amax = MAX(amax, fabsf(v));
1059
+ }
1060
+
1061
+ const float d = amax / ((1 << 7) - 1);
1062
+ const float id = d ? 1.0f/d : 0.0f;
1063
+
1064
+ y[i].d = d;
1065
+
1066
+ for (int l = 0; l < QK8_0; ++l) {
1067
+ const float v = x[i*QK8_0 + l]*id;
1068
+ y[i].qs[l] = roundf(v);
1069
+ }
1070
+ }
1071
+ }
1072
+
1073
+ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int k) {
1074
+ assert(k % QK8_0 == 0);
1075
+ const int nb = k / QK8_0;
1076
+
1077
+ block_q8_0 * restrict y = vy;
1078
+
1079
+ #if defined(__ARM_NEON)
1080
+ for (int i = 0; i < nb; i++) {
1081
+ float32x4_t srcv [8];
1082
+ float32x4_t asrcv[8];
1083
+ float32x4_t amaxv[8];
1084
+
1085
+ for (int l = 0; l < 8; l++) srcv[l] = vld1q_f32(x + i*32 + 4*l);
1086
+ for (int l = 0; l < 8; l++) asrcv[l] = vabsq_f32(srcv[l]);
1087
+
1088
+ for (int l = 0; l < 4; l++) amaxv[2*l] = vmaxq_f32(asrcv[2*l], asrcv[2*l+1]);
1089
+ for (int l = 0; l < 2; l++) amaxv[4*l] = vmaxq_f32(amaxv[4*l], amaxv[4*l+2]);
1090
+ for (int l = 0; l < 1; l++) amaxv[8*l] = vmaxq_f32(amaxv[8*l], amaxv[8*l+4]);
1091
+
1092
+ const float amax = vmaxvq_f32(amaxv[0]);
1093
+
1094
+ const float d = amax / ((1 << 7) - 1);
1095
+ const float id = d ? 1.0f/d : 0.0f;
1096
+
1097
+ y[i].d = d;
1098
+
1099
+ for (int l = 0; l < 8; l++) {
1100
+ const float32x4_t v = vmulq_n_f32(srcv[l], id);
1101
+ const int32x4_t vi = vcvtnq_s32_f32(v);
1102
+
1103
+ y[i].qs[4*l + 0] = vgetq_lane_s32(vi, 0);
1104
+ y[i].qs[4*l + 1] = vgetq_lane_s32(vi, 1);
1105
+ y[i].qs[4*l + 2] = vgetq_lane_s32(vi, 2);
1106
+ y[i].qs[4*l + 3] = vgetq_lane_s32(vi, 3);
1107
+ }
1108
+ }
1109
+ #elif defined(__AVX2__) || defined(__AVX__)
1110
+ for (int i = 0; i < nb; i++) {
1111
+ // Load elements into 4 AVX vectors
1112
+ __m256 v0 = _mm256_loadu_ps( x );
1113
+ __m256 v1 = _mm256_loadu_ps( x + 8 );
1114
+ __m256 v2 = _mm256_loadu_ps( x + 16 );
1115
+ __m256 v3 = _mm256_loadu_ps( x + 24 );
1116
+ x += 32;
1117
+
1118
+ // Compute max(abs(e)) for the block
1119
+ const __m256 signBit = _mm256_set1_ps( -0.0f );
1120
+ __m256 maxAbs = _mm256_andnot_ps( signBit, v0 );
1121
+ maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) );
1122
+ maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );
1123
+ maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) );
1124
+
1125
+ __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
1126
+ max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
1127
+ max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
1128
+ const float maxScalar = _mm_cvtss_f32( max4 );
1129
+
1130
+ // Quantize these floats
1131
+ const float d = maxScalar / 127.f;
1132
+ y[i].d = d;
1133
+ const float id = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f;
1134
+ const __m256 mul = _mm256_set1_ps( id );
1135
+
1136
+ // Apply the multiplier
1137
+ v0 = _mm256_mul_ps( v0, mul );
1138
+ v1 = _mm256_mul_ps( v1, mul );
1139
+ v2 = _mm256_mul_ps( v2, mul );
1140
+ v3 = _mm256_mul_ps( v3, mul );
1141
+
1142
+ // Round to nearest integer
1143
+ v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST );
1144
+ v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST );
1145
+ v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST );
1146
+ v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST );
1147
+
1148
+ // Convert floats to integers
1149
+ __m256i i0 = _mm256_cvtps_epi32( v0 );
1150
+ __m256i i1 = _mm256_cvtps_epi32( v1 );
1151
+ __m256i i2 = _mm256_cvtps_epi32( v2 );
1152
+ __m256i i3 = _mm256_cvtps_epi32( v3 );
1153
+
1154
+ #if defined(__AVX2__)
1155
+ // Convert int32 to int16
1156
+ i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15
1157
+ i2 = _mm256_packs_epi32( i2, i3 ); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31
1158
+ // Convert int16 to int8
1159
+ i0 = _mm256_packs_epi16( i0, i2 ); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31
1160
+
1161
+ // We got our precious signed bytes, but the order is now wrong
1162
+ // These AVX2 pack instructions process 16-byte pieces independently
1163
+ // The following instruction is fixing the order
1164
+ const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
1165
+ i0 = _mm256_permutevar8x32_epi32( i0, perm );
1166
+
1167
+ _mm256_storeu_si256((__m256i *)y[i].qs, i0);
1168
+ #else
1169
+ // Since we don't have in AVX some necessary functions,
1170
+ // we split the registers in half and call AVX2 analogs from SSE
1171
+ __m128i ni0 = _mm256_castsi256_si128( i0 );
1172
+ __m128i ni1 = _mm256_extractf128_si256( i0, 1);
1173
+ __m128i ni2 = _mm256_castsi256_si128( i1 );
1174
+ __m128i ni3 = _mm256_extractf128_si256( i1, 1);
1175
+ __m128i ni4 = _mm256_castsi256_si128( i2 );
1176
+ __m128i ni5 = _mm256_extractf128_si256( i2, 1);
1177
+ __m128i ni6 = _mm256_castsi256_si128( i3 );
1178
+ __m128i ni7 = _mm256_extractf128_si256( i3, 1);
1179
+
1180
+ // Convert int32 to int16
1181
+ ni0 = _mm_packs_epi32( ni0, ni1 );
1182
+ ni2 = _mm_packs_epi32( ni2, ni3 );
1183
+ ni4 = _mm_packs_epi32( ni4, ni5 );
1184
+ ni6 = _mm_packs_epi32( ni6, ni7 );
1185
+ // Convert int16 to int8
1186
+ ni0 = _mm_packs_epi16( ni0, ni2 );
1187
+ ni4 = _mm_packs_epi16( ni4, ni6 );
1188
+
1189
+ _mm_storeu_si128((__m128i *)(y[i].qs + 0), ni0);
1190
+ _mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4);
1191
+ #endif
1192
+ }
1193
+ #else
1194
+ // scalar
1195
+ quantize_row_q8_0_reference(x, y, k);
1196
+ #endif
1197
+ }
1198
+
973
1199
  static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, int k) {
974
- assert(k % QK == 0);
975
- const int nb = k / QK;
1200
+ assert(k % QK4_0 == 0);
1201
+ const int nb = k / QK4_0;
976
1202
 
977
1203
  const block_q4_0 * restrict x = vx;
978
1204
 
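Note: quantize_row_q8_0_reference above maps each 32-element block to a single scale d = amax/127 plus 32 signed bytes, so reconstruction is simply x ≈ d*q with a worst-case error of d/2 per element. A scalar round-trip sketch of that math (dequantization of Q8_0 is not shown in this diff; the code below is illustrative only):

    #include <math.h>
    #include <stdint.h>
    #include <stdio.h>

    #define QK8_0 32

    /* Scalar Q8_0 round trip for one block: quantize, then reconstruct. */
    int main(void) {
        float x[QK8_0], y[QK8_0];
        for (int l = 0; l < QK8_0; l++) x[l] = sinf(0.3f * l);

        float amax = 0.0f;
        for (int l = 0; l < QK8_0; l++) amax = fmaxf(amax, fabsf(x[l]));

        const float d  = amax / ((1 << 7) - 1);   /* same scale as the reference kernel */
        const float id = d ? 1.0f / d : 0.0f;

        int8_t qs[QK8_0];
        for (int l = 0; l < QK8_0; l++) qs[l] = (int8_t) roundf(x[l] * id);
        for (int l = 0; l < QK8_0; l++) y[l]  = d * qs[l];  /* dequantize: x ~ d * q */

        float max_err = 0.0f;
        for (int l = 0; l < QK8_0; l++) max_err = fmaxf(max_err, fabsf(x[l] - y[l]));
        printf("max abs error: %f (bound: %f)\n", max_err, 0.5f * d);
        return 0;
    }
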
@@ -983,7 +1209,7 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in
983
1209
 
984
1210
  const uint8_t * restrict pp = x[i].qs;
985
1211
 
986
- for (int l = 0; l < QK; l += 32) {
1212
+ for (int l = 0; l < QK4_0; l += 32) {
987
1213
  // Load 32x4-bit integers into 32x8-bit integers
988
1214
  __m256i vx8 = bytesFromNibbles(pp+l/2);
989
1215
 
@@ -1005,7 +1231,7 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in
1005
1231
  // Scale and store
1006
1232
  for (int j = 0; j < 4; j++) {
1007
1233
  const __m256 result = _mm256_mul_ps(vf[j], d_v);
1008
- _mm256_storeu_ps(y + i * QK + l + j*8, result);
1234
+ _mm256_storeu_ps(y + i * QK4_0 + l + j*8, result);
1009
1235
  }
1010
1236
  }
1011
1237
  }
@@ -1015,7 +1241,7 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in
1015
1241
 
1016
1242
  const uint8_t * restrict pp = x[i].qs;
1017
1243
 
1018
- for (int l = 0; l < QK; l += 16) {
1244
+ for (int l = 0; l < QK4_0; l += 16) {
1019
1245
  // Load 16x4-bit integers into 8x8-bit integers
1020
1246
  const uint8x8_t v8 = vld1_u8(pp + l/2);
1021
1247
 
@@ -1054,10 +1280,10 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in
1054
1280
  const float32x4_t r3 = vmulq_f32(vf_3, vd);
1055
1281
 
1056
1282
  // Store
1057
- vst1q_f32(y + i*QK + l + 0, r0);
1058
- vst1q_f32(y + i*QK + l + 4, r1);
1059
- vst1q_f32(y + i*QK + l + 8, r2);
1060
- vst1q_f32(y + i*QK + l + 12, r3);
1283
+ vst1q_f32(y + i*QK4_0 + l + 0, r0);
1284
+ vst1q_f32(y + i*QK4_0 + l + 4, r1);
1285
+ vst1q_f32(y + i*QK4_0 + l + 8, r2);
1286
+ vst1q_f32(y + i*QK4_0 + l + 12, r3);
1061
1287
  }
1062
1288
  }
1063
1289
  #else
@@ -1067,7 +1293,7 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in
1067
1293
 
1068
1294
  const uint8_t * restrict pp = x[i].qs;
1069
1295
 
1070
- for (int l = 0; l < QK; l += 2) {
1296
+ for (int l = 0; l < QK4_0; l += 2) {
1071
1297
  const uint8_t vi = pp[l/2];
1072
1298
 
1073
1299
  const int8_t vi0 = vi & 0xf;
@@ -1078,19 +1304,19 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in
1078
1304
 
1079
1305
  //printf("d = %f, vi = %d, vi0 = %d, vi1 = %d, v0 = %f, v1 = %f\n", d, vi, vi0, vi1, v0, v1);
1080
1306
 
1081
- y[i*QK + l + 0] = v0;
1082
- y[i*QK + l + 1] = v1;
1307
+ y[i*QK4_0 + l + 0] = v0;
1308
+ y[i*QK4_0 + l + 1] = v1;
1083
1309
 
1084
- assert(!isnan(y[i*QK + l + 0]));
1085
- assert(!isnan(y[i*QK + l + 1]));
1310
+ assert(!isnan(y[i*QK4_0 + l + 0]));
1311
+ assert(!isnan(y[i*QK4_0 + l + 1]));
1086
1312
  }
1087
1313
  }
1088
1314
  #endif
1089
1315
  }
1090
1316
 
1091
1317
  static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, int k) {
1092
- assert(k % QK == 0);
1093
- const int nb = k / QK;
1318
+ assert(k % QK4_1 == 0);
1319
+ const int nb = k / QK4_1;
1094
1320
 
1095
1321
  const block_q4_1 * restrict x = vx;
1096
1322
 
@@ -1101,7 +1327,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
1101
1327
 
1102
1328
  const uint8_t * restrict pp = x[i].qs;
1103
1329
 
1104
- for (int l = 0; l < QK; l += 32) {
1330
+ for (int l = 0; l < QK4_1; l += 32) {
1105
1331
  // Load 32x4-bit integers into 32x8-bit integers
1106
1332
  __m256i vx8 = bytesFromNibbles(pp+l/2);
1107
1333
 
@@ -1120,7 +1346,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
1120
1346
  // Scale, add m and store
1121
1347
  for (int j = 0; j < 4; j++) {
1122
1348
  const __m256 result = _mm256_add_ps(_mm256_mul_ps(vf[j], d_v), d_m);
1123
- _mm256_storeu_ps(y + i * QK + l + j*8, result);
1349
+ _mm256_storeu_ps(y + i * QK4_1 + l + j*8, result);
1124
1350
  }
1125
1351
  }
1126
1352
  }
@@ -1131,7 +1357,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
1131
1357
 
1132
1358
  const uint8_t * restrict pp = x[i].qs;
1133
1359
 
1134
- for (int l = 0; l < QK; l += 16) {
1360
+ for (int l = 0; l < QK4_1; l += 16) {
1135
1361
  // Load 16x4-bit integers into 8x8-bit integers
1136
1362
  const uint8x8_t v8 = vld1_u8(pp + l/2);
1137
1363
 
@@ -1162,10 +1388,10 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
1162
1388
  const float32x4_t r3 = vmlaq_f32(vm, vf_3, vd);
1163
1389
 
1164
1390
  // Store
1165
- vst1q_f32(y + i*QK + l + 0, r0);
1166
- vst1q_f32(y + i*QK + l + 4, r1);
1167
- vst1q_f32(y + i*QK + l + 8, r2);
1168
- vst1q_f32(y + i*QK + l + 12, r3);
1391
+ vst1q_f32(y + i*QK4_1 + l + 0, r0);
1392
+ vst1q_f32(y + i*QK4_1 + l + 4, r1);
1393
+ vst1q_f32(y + i*QK4_1 + l + 8, r2);
1394
+ vst1q_f32(y + i*QK4_1 + l + 12, r3);
1169
1395
  }
1170
1396
  }
1171
1397
  #else
@@ -1175,7 +1401,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
1175
1401
 
1176
1402
  const uint8_t * restrict pp = x[i].qs;
1177
1403
 
1178
- for (int l = 0; l < QK; l += 2) {
1404
+ for (int l = 0; l < QK4_1; l += 2) {
1179
1405
  const uint8_t vi = pp[l/2];
1180
1406
 
1181
1407
  const int8_t vi0 = vi & 0xf;
@@ -1184,16 +1410,44 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
1184
1410
  const float v0 = vi0*d + m;
1185
1411
  const float v1 = vi1*d + m;
1186
1412
 
1187
- y[i*QK + l + 0] = v0;
1188
- y[i*QK + l + 1] = v1;
1413
+ y[i*QK4_1 + l + 0] = v0;
1414
+ y[i*QK4_1 + l + 1] = v1;
1189
1415
 
1190
- assert(!isnan(y[i*QK + l + 0]));
1191
- assert(!isnan(y[i*QK + l + 1]));
1416
+ assert(!isnan(y[i*QK4_1 + l + 0]));
1417
+ assert(!isnan(y[i*QK4_1 + l + 1]));
1192
1418
  }
1193
1419
  }
1194
1420
  #endif
1195
1421
  }
1196
1422
 
1423
+ static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
1424
+ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
1425
+
1426
+ static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
1427
+ [GGML_TYPE_Q4_0] = {
1428
+ .dequantize_row_q = dequantize_row_q4_0,
1429
+ .quantize_row_q = quantize_row_q4_0,
1430
+ .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_0_reference,
1431
+ .quantize_row_q_dot = quantize_row_q8_0,
1432
+ .vec_dot_q = ggml_vec_dot_q4_0_q8_0,
1433
+ },
1434
+ [GGML_TYPE_Q4_1] = {
1435
+ .dequantize_row_q = dequantize_row_q4_1,
1436
+ .quantize_row_q = quantize_row_q4_1,
1437
+ .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_1_reference,
1438
+ .quantize_row_q_dot = quantize_row_q4_1,
1439
+ .vec_dot_q = ggml_vec_dot_q4_1,
1440
+ },
1441
+ // TODO: GGML_TYPE_Q8_0
1442
+ };
1443
+
1444
+ // For internal test use
1445
+ quantize_fns_t ggml_internal_get_quantize_fn(size_t i) {
1446
+ GGML_ASSERT(i < GGML_TYPE_COUNT);
1447
+ return quantize_fns[i];
1448
+ }
1449
+
1450
+
1197
1451
  //
1198
1452
  // simd mappings
1199
1453
  //
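
Note: the new quantize_fns_t table separates the storage format from the format used on the other side of the dot product: for Q4_0 it points quantize_row_q_dot at quantize_row_q8_0 and vec_dot_q at ggml_vec_dot_q4_0_q8_0, so activations are quantized to Q8_0 before being dotted against Q4_0 weights, while Q4_1 still dots Q4_1 against Q4_1. A hedged sketch of how such a table might be consumed; the struct fields mirror the diff, but the typedefs and the surrounding loop are reconstructed for illustration only:

    #include <stddef.h>

    /* Function-pointer types reconstructed from the signatures visible in this diff. */
    typedef void (*dequantize_row_q_t)(const void * x, float * y, int k);
    typedef void (*quantize_row_q_t)  (const float * x, void * y, int k);
    typedef void (*vec_dot_q_t)       (const int n, float * s, const void * x, const void * y);

    typedef struct {
        dequantize_row_q_t dequantize_row_q;
        quantize_row_q_t   quantize_row_q;
        quantize_row_q_t   quantize_row_q_reference;
        quantize_row_q_t   quantize_row_q_dot;   /* format the activations are converted to */
        vec_dot_q_t        vec_dot_q;            /* weights dotted against converted activations */
    } quantize_fns_t;

    /* Hypothetical inner loop: convert one activation row once, reuse it for every weight row. */
    static void matvec_row_sketch(const quantize_fns_t * fns,
                                  const float * act, void * act_q,
                                  const void * w, size_t w_row_size,
                                  float * out, int n, int nrows) {
        fns->quantize_row_q_dot(act, act_q, n);
        for (int r = 0; r < nrows; r++) {
            const void * w_row = (const char *) w + (size_t) r * w_row_size;
            fns->vec_dot_q(n, &out[r], w_row, act_q);
        }
    }
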
@@ -1226,15 +1480,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
1226
1480
  #define GGML_F32x4_FMA(a, b, c) vfmaq_f32(a, b, c)
1227
1481
  #define GGML_F32x4_ADD vaddq_f32
1228
1482
  #define GGML_F32x4_MUL vmulq_f32
1229
- #if defined(__ARM_FEATURE_QRDMX)
1230
- #define GGML_F32x4_REDUCE_ONE(x) vaddvq_f32(x)
1231
- #else
1232
- #define GGML_F32x4_REDUCE_ONE(x) \
1233
- (vgetq_lane_f32(x, 0) + \
1234
- vgetq_lane_f32(x, 1) + \
1235
- vgetq_lane_f32(x, 2) + \
1236
- vgetq_lane_f32(x, 3))
1237
- #endif
1483
+ #define GGML_F32x4_REDUCE_ONE(x) vaddvq_f32(x)
1238
1484
  #define GGML_F32x4_REDUCE(res, x) \
1239
1485
  { \
1240
1486
  for (int i = 0; i < GGML_F32_ARR/2; ++i) { \
@@ -1758,34 +2004,188 @@ inline static void ggml_vec_dot_f32(const int n, float * restrict s, const float
1758
2004
  *s = sumf;
1759
2005
  }
1760
2006
 
1761
- #if __AVX512F__ && QK == 32
1762
- static inline __m512 dot_q4_0_oneblock_avx512(
2007
+ #if __AVX512F__ && QK4_0 == 32
2008
+ static inline __m512i bytes_from_q4_0_twoblocks_avx512( const __m512i blocks ) {
2009
+ // The 64 bytes of `blocks` contain two consecutive Q4_0 blocks loaded from memory:
2010
+ // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
2011
+ // |63 62 61 60 59 58 57 56 55 54 53 52 51 50 49 48 47 46 45 44 43 42 41 40 39 38 37 36 35 34 33 32|
2012
+ // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
2013
+ // | :. =_ () [] <> () Zz Yy|
2014
+ // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
2015
+ // |31 30 29 28 27 26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10 09 08 07 06 05 04 03 02 01 00|
2016
+ // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
2017
+ // |Xx Ww Vv Uu Tt Ss Rr Qq Pp Oo Nn Mm Ll Kk Jj Ii Hh Gg Ff Ee Dd Cc Bb Aa |
2018
+ // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
2019
+ //
2020
+ // Bytes 04..19 (block #0) and 24..39 (block #1) both contain 32 nibbles (4-bit unsigned integers).
2021
+ // We have exactly 64 nibbles, so we want to place each nibble into a separate byte.
2022
+ // Bytes 00..03 and 20..23 contain scales, which are irrelevant to this function.
2023
+ // Bytes 40..63 are masked when loading the data, so they are zeroed out.
2024
+ #ifdef __AVX512VBMI__
2025
+ const __m512i byte_perm = _mm512_set_epi8(
2026
+ 39, 38, 39, 38, 37, 36, 37, 36, 35, 34, 35, 34, 33, 32, 33, 32,
2027
+ 31, 30, 31, 30, 29, 28, 29, 28, 27, 26, 27, 26, 25, 24, 25, 24,
2028
+ 19, 18, 19, 18, 17, 16, 17, 16, 15, 14, 15, 14, 13, 12, 13, 12,
2029
+ 11, 10, 11, 10, 9, 8, 9, 8, 7, 6, 7, 6, 5, 4, 5, 4
2030
+ );
2031
+ const __m512i permuted = _mm512_permutexvar_epi8( byte_perm, blocks );
2032
+ // After applying VPERMB, `permuted` looks like this:
2033
+ // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
2034
+ // |63 62 61 60 59 58 57 56 55 54 53 52 51 50 49 48 47 46 45 44 43 42 41 40 39 38 37 36 35 34 33 32|
2035
+ // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
2036
+ // |:. =_ :. =_ () [] () [] <> () <> () Zz Yy Zz Yy Xx Ww Xx Ww Vv Uu Vv Uu Tt Ss Tt Ss Rr Qq Rr Qq|
2037
+ // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
2038
+ // |31 30 29 28 27 26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10 09 08 07 06 05 04 03 02 01 00|
2039
+ // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
2040
+ // |Pp Oo Pp Oo Nn Mm Nn Mm Ll Kk Ll Kk Jj Ii Jj Ii Hh Gg Hh Gg Ff Ee Ff Ee Dd Cc Dd Cc Bb Aa Bb Aa|
2041
+ // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
2042
+ #else
2043
+ const __m512i word_perm = _mm512_set_epi16(
2044
+ 19, 19, 18, 18, 17, 17, 16, 16, 15, 15, 14, 14, 13, 13, 12, 12,
2045
+ 9, 9, 8, 8, 7, 7, 6, 6, 5, 5, 4, 4, 3, 3, 2, 2
2046
+ );
2047
+ const __m512i permuted = _mm512_permutexvar_epi16( word_perm, blocks );
2048
+ // This is the fallback path for CPUs that don't support VPERMB. Since we permute 16-bit groups only,
2049
+ // VPERMB can be replaced with VPERMW. We could always use VPERMW, but at least on Tiger Lake and
2050
+ // Ice Lake VPERMW followed by a right shift is quite noticeably slower than VPERMB.
2051
+ #endif
2052
+
2053
+ // Shift every odd-numbered 16-bit group to the right by 4 bits.
2054
+ const __mmask32 shift_mask = 0xaaaaaaaa;
2055
+ const __m512i shifted = _mm512_mask_srai_epi16( permuted, shift_mask, permuted, 4 );
2056
+ // After applying VPSRAW, `shifted` looks like this (the "empty" nibbles are filled with zeroes):
2057
+ // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
2058
+ // |63 62 61 60 59 58 57 56 55 54 53 52 51 50 49 48 47 46 45 44 43 42 41 40 39 38 37 36 35 34 33 32
2059
+ // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
2060
+ // | : .= :. =_ ( )[ () [] < >( <> () Z zY Zz Yy X xW Xx Ww V vU Vv Uu T tS Tt Ss R rQ Rr Qq
2061
+ // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
2062
+ // |31 30 29 28 27 26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10 09 08 07 06 05 04 03 02 01 00|
2063
+ // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
2064
+ // | P pO Pp Oo N nM Nn Mm L lK Ll Kk J jI Jj Ii H hG Hh Gg F fE Ff Ee D dC Dd Cc B bA Bb Aa|
2065
+ // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
2066
+
2067
+ // Now we just need to zero out the higher nibble in each byte, and we're done.
2068
+ const __m512i low_nibble_mask = _mm512_set1_epi8( 0xf );
2069
+ return _mm512_and_si512( low_nibble_mask, shifted );
2070
+ // The final result looks like this:
2071
+ // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
2072
+ // |63 62 61 60 59 58 57 56 55 54 53 52 51 50 49 48 47 46 45 44 43 42 41 40 39 38 37 36 35 34 33 32|
2073
+ // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
2074
+ // | : = . _ ( [ ) ] < ( > ) Z Y z y X W x w V U v u T S t s R Q r q|
2075
+ // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
2076
+ // |31 30 29 28 27 26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10 09 08 07 06 05 04 03 02 01 00|
2077
+ // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
2078
+ // | P O p o N M n m L K l k J I j i H G h g F E f e D C d c B A b a|
2079
+ // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
2080
+ }
2081
+
2082
+ static inline __m512 dot_q4_0_twoblocks_avx512(
1763
2083
  __m512 acc,
1764
2084
  const block_q4_0 * restrict x,
1765
2085
  const block_q4_0 * restrict y,
1766
2086
  int i
1767
2087
  ) {
1768
- // Compute combined scale for the block
1769
- __m512 d = _mm512_set1_ps( x[i].d * y[i].d );
1770
-
1771
- __m256i bx = bytesFromNibbles( x[i].qs );
1772
- __m256i by = bytesFromNibbles( y[i].qs );
1773
-
1774
- // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
1775
- const __m256i off = _mm256_set1_epi8( 8 );
1776
- bx = _mm256_sub_epi8( bx, off );
1777
- by = _mm256_sub_epi8( by, off );
1778
-
1779
- // Sign-extend 16 signed bytes into int16_t
1780
- __m512i x32 = _mm512_cvtepi8_epi16( bx );
1781
- __m512i y32 = _mm512_cvtepi8_epi16( by );
1782
- // Compute products of int16_t integers, add pairwise
1783
- __m512i i64 = _mm512_madd_epi16( x32, y32 );
2088
+ // A pair of Q4_0 blocks spans 40 bytes, while an AVX-512 register has 64. The remaining 24 bytes
2089
+ // can potentially be unaddressable, so we make sure to mask them out before the load, even though
2090
+ // we don't use them at all. This might hurt the performance slightly, since the compiler is forced
2091
+ // to use e.g. `VMOVDQU64 REG, MASK, [ADDR] + VPERMB ..., REG` instead of just `VPERMB ..., [ADDR]`.
2092
+ const __mmask8 load_mask = 0x1f;
2093
+ const __m512i blocks_0 = _mm512_maskz_loadu_epi64( load_mask, &x[i] );
2094
+ const __m512i blocks_1 = _mm512_maskz_loadu_epi64( load_mask, &y[i] );
2095
+
2096
+ // We want to multiply the scales, so we interpret both registers as 16 32-bit floats:
2097
+ // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
2098
+ // | 15 | 14 | 13 | 12 | 11 | 10 | 09 | 08 | 07 | 06 | 05 | 04 | 03 | 02 | 01 | 00 |
2099
+ // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
2100
+ // blocks_0_float
2101
+ // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
2102
+ // | | | | | | | xx | xx | xx | xx | B | xx | xx | xx | xx | A |
2103
+ // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
2104
+ // blocks_1_float
2105
+ // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
2106
+ // | | | | | | | xx | xx | xx | xx | D | xx | xx | xx | xx | C |
2107
+ // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
2108
+ const __m512 blocks_0_float = _mm512_castsi512_ps( blocks_0 );
2109
+ const __m512 blocks_1_float = _mm512_castsi512_ps( blocks_1 );
2110
+ // We absolutely shouldn't touch the floats marked with `xx`: they contain some
2111
+ // random data, which might very well underflow. At least on Intel, this leads
2112
+ // to a huge penalty that can't be ignored (easily 100x or more) unless you
2113
+ // compile your code with something like `-ffast-math` to enable FTZ/DAZ flags.
2114
+ // (and ggml can't assume that you do)...
2115
+ const __mmask16 scale_mul_mask = 0x21;
2116
+ #ifdef __clang__
2117
+ // ...however, clang decides to optimize the multiplication mask away:
2118
+ // https://godbolt.org/z/P8PqdsfvW
2119
+ // gcc and MSVC do the sane thing. This horrible workaround forces clang to emit the mask.
2120
+ __m512i scales;
2121
+ __asm__(
2122
+ "vmulps %1, %2, %0%{%3%}"
2123
+ : "=v" ( scales )
2124
+ : "vm" ( blocks_0_float ), "v" ( blocks_1_float ), "Yk" ( scale_mul_mask )
2125
+ );
2126
+ #else
2127
+ const __m512 scales = _mm512_maskz_mul_ps( scale_mul_mask, blocks_0_float, blocks_1_float );
2128
+ #endif
2129
+ const __m512i scale_perm = _mm512_set_epi32(
2130
+ 5, 5, 5, 5, 5, 5, 5, 5,
2131
+ 0, 0, 0, 0, 0, 0, 0, 0
2132
+ );
2133
+ const __m512 permuted_scales = _mm512_permutexvar_ps( scale_perm, scales );
2134
+ // After VMULPS and VPERMPS, `permuted_scales` looks like this:
2135
+ // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
2136
+ // | 15 | 14 | 13 | 12 | 11 | 10 | 09 | 08 | 07 | 06 | 05 | 04 | 03 | 02 | 01 | 00 |
2137
+ // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
2138
+ // | B*D| B*D| B*D| B*D| B*D| B*D| B*D| B*D| A*C| A*C| A*C| A*C| A*C| A*C| A*C| A*C|
2139
+ // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
2140
+
2141
+ const __m512i bytes_0 = bytes_from_q4_0_twoblocks_avx512( blocks_0 );
2142
+ const __m512i bytes_1 = bytes_from_q4_0_twoblocks_avx512( blocks_1 );
2143
+
2144
+ // Now we want to compute dot products of 4-element byte vectors and store them in
2145
+ // 32-bit integers. That is (only one 4-element vector is shown for clarity):
2146
+ // +----+----+----+----+
2147
+ // ... | 03 | 02 | 01 | 00 |
2148
+ // +----+----+----+----+
2149
+ // bytes_0
2150
+ // +----+----+----+----+
2151
+ // ... | D | C | B | A |
2152
+ // +----+----+----+----+
2153
+ // bytes_1
2154
+ // +----+----+----+----+
2155
+ // ... | H | G | F | E |
2156
+ // +----+----+----+----+
2157
+ // final_res_int
2158
+ // +----+----+----+----+
2159
+ // ... | A*E+B*F+C*G+D*H |
2160
+ // +----+----+----+----+
2161
+ const __m512i plus_8 = _mm512_set1_epi8( 8 );
2162
+ const __m512i bytes_1_minus_8 = _mm512_sub_epi8( bytes_1, plus_8 );
2163
+
2164
+ #ifdef __AVX512VNNI__
2165
+ // We have VPDPBUSDS in AVX512-VNNI, which does exactly what we want, but with a catch:
2166
+ // the *left* operand is supposed to be unsigned, while Q4_0 quantization subtracts 8
2167
+ // from each nibble, so they can be negative. So, instead of `(bytes_0 - 8) * (bytes_1 - 8)`,
2168
+ // we compute `bytes_0 * (bytes_1 - 8) + bytes_1 * (-8) + 64`. VPDPBUSDS uses an accumulator,
2169
+ // which means we only need 2 instructions.
2170
+ const __m512i dot_init = _mm512_set1_epi32( 4 * 64 );
2171
+ const __m512i minus_8 = _mm512_set1_epi8( -8 );
2172
+ const __m512i prod_0 = _mm512_dpbusds_epi32( dot_init, bytes_1, minus_8 );
2173
+ const __m512i final_res_int = _mm512_dpbusds_epi32( prod_0, bytes_0, bytes_1_minus_8 );
2174
+ #else
2175
+ // As a fallback, we have VPMADDUBSW in AVX512-BW, which uses 16-bit products instead of 32-bit ones.
2176
+ // It has the same catch as VPDPBUSDS: the left operand should be unsigned.
2177
+ // This is essentially the AVX-512 version of the AVX-2 trick used by GH user Const-me
2178
+ // ref: https://gist.github.com/Const-me/4d30e1fc767ab314596e16e90f53b6f4#file-matmultest-cpp-L119
2179
+ const __m512i one = _mm512_set1_epi16( 1 );
2180
+ const __m512i prod_0 = _mm512_maddubs_epi16( bytes_0, bytes_1_minus_8 );
2181
+ const __m512i prod_1 = _mm512_maddubs_epi16( plus_8, bytes_1_minus_8 );
2182
+ const __m512i diff = _mm512_sub_epi16( prod_0, prod_1 );
2183
+ const __m512i final_res_int = _mm512_madd_epi16( diff, one );
2184
+ #endif
1784
2185
 
1785
- // Convert int32_t to float
1786
- __m512 p = _mm512_cvtepi32_ps( i64 );
1787
- // Apply the scale, and accumulate
1788
- return _mm512_fmadd_ps( d, p, acc );
2186
+ // Finally, we multiply the permuted scales and the 32-bit dot products, then accumulate.
2187
+ const __m512 final_res_float = _mm512_cvtepi32_ps( final_res_int );
2188
+ return _mm512_fmadd_ps( permuted_scales, final_res_float, acc );
1789
2189
  }
1790
2190
  #endif
1791
2191
 
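Note: the AVX512-VNNI branch above leans on a small identity. VPDPBUSDS wants an unsigned left operand, but Q4_0 nibbles are used as q - 8, and expanding (q0 - 8)(q1 - 8) = q0*(q1 - 8) - 8*q1 + 64 shows why the code accumulates bytes_0*(bytes_1 - 8) plus bytes_1*(-8) and seeds each 32-bit lane with 4*64 (four byte products land in every lane). A brute-force check of the identity (illustrative code):

    #include <assert.h>
    #include <stdio.h>

    int main(void) {
        /* Verify the rewrite used by the AVX512-VNNI path for every nibble pair. */
        for (int q0 = 0; q0 < 16; q0++) {
            for (int q1 = 0; q1 < 16; q1++) {
                const int direct    = (q0 - 8) * (q1 - 8);
                const int rewritten = q0 * (q1 - 8) + q1 * (-8) + 64;
                assert(direct == rewritten);
            }
        }
        printf("identity holds for all 256 nibble pairs\n");
        return 0;
    }
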
@@ -1826,9 +2226,9 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t
1826
2226
  }
1827
2227
 
1828
2228
  static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
1829
- const int nb = n / QK;
2229
+ const int nb = n / QK4_0;
1830
2230
 
1831
- assert(n % QK == 0);
2231
+ assert(n % QK4_0 == 0);
1832
2232
  assert(nb % 2 == 0);
1833
2233
 
1834
2234
  const block_q4_0 * restrict x = vx;
@@ -1857,55 +2257,43 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
1857
2257
  // 4-bit -> 8-bit
1858
2258
  const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8(v0_0, m4b));
1859
2259
  const int8x16_t v1_0l = vreinterpretq_s8_u8(vandq_u8(v1_0, m4b));
1860
-
1861
2260
  const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
1862
2261
  const int8x16_t v1_0h = vreinterpretq_s8_u8(vshrq_n_u8(v1_0, 4));
1863
2262
 
1864
2263
  const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8(v0_1, m4b));
1865
2264
  const int8x16_t v1_1l = vreinterpretq_s8_u8(vandq_u8(v1_1, m4b));
1866
-
1867
2265
  const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
1868
2266
  const int8x16_t v1_1h = vreinterpretq_s8_u8(vshrq_n_u8(v1_1, 4));
1869
2267
 
1870
2268
  // sub 8
1871
2269
  const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
1872
2270
  const int8x16_t v1_0ls = vsubq_s8(v1_0l, s8b);
1873
-
1874
2271
  const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
1875
2272
  const int8x16_t v1_0hs = vsubq_s8(v1_0h, s8b);
1876
2273
 
1877
2274
  const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
1878
2275
  const int8x16_t v1_1ls = vsubq_s8(v1_1l, s8b);
1879
-
1880
2276
  const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
1881
2277
  const int8x16_t v1_1hs = vsubq_s8(v1_1h, s8b);
1882
2278
 
1883
2279
  #if defined(__ARM_FEATURE_DOTPROD)
1884
- // dot product into int16x8_t
2280
+ // dot product into int32x4_t
1885
2281
  int32x4_t p_0 = vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls);
1886
2282
  int32x4_t p_1 = vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls);
1887
2283
 
1888
2284
  p_0 = vdotq_s32(p_0, v0_0hs, v1_0hs);
1889
2285
  p_1 = vdotq_s32(p_1, v0_1hs, v1_1hs);
1890
2286
 
1891
- // scalar
1892
- #if defined(__ARM_FEATURE_QRDMX)
1893
- sum0 += x0->d * y0->d * vaddvq_s32(p_0);
1894
- sum1 += x1->d * y1->d * vaddvq_s32(p_1);
2287
+ sum0 += x0->d*y0->d*vaddvq_s32(p_0);
2288
+ sum1 += x1->d*y1->d*vaddvq_s32(p_1);
1895
2289
  #else
1896
- sum0 += x0->d * y0->d * (vgetq_lane_s32(p_0, 0) + vgetq_lane_s32(p_0, 1) + vgetq_lane_s32(p_0, 2) + vgetq_lane_s32(p_0, 3));
1897
- sum1 += x1->d * y1->d * (vgetq_lane_s32(p_1, 0) + vgetq_lane_s32(p_1, 1) + vgetq_lane_s32(p_1, 2) + vgetq_lane_s32(p_1, 3));
1898
- #endif
1899
- #else
1900
- const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls));
2290
+ const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls));
1901
2291
  const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls));
1902
-
1903
2292
  const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0hs));
1904
2293
  const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0hs));
1905
2294
 
1906
2295
  const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1ls));
1907
2296
  const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1ls));
1908
-
1909
2297
  const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1hs));
1910
2298
  const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1hs));
1911
2299
 
@@ -1918,14 +2306,8 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
1918
2306
  const int16x8_t p_0 = vaddq_s16(pl_0, ph_0);
1919
2307
  const int16x8_t p_1 = vaddq_s16(pl_1, ph_1);
1920
2308
 
1921
- // scalar
1922
- #if defined(__ARM_FEATURE_QRDMX)
1923
- sum0 += x0->d * y0->d * vaddvq_s16(p_0);
1924
- sum1 += x1->d * y1->d * vaddvq_s16(p_1);
1925
- #else
1926
- sum0 += x0->d * y0->d * (vgetq_lane_s16(p_0, 0) + vgetq_lane_s16(p_0, 1) + vgetq_lane_s16(p_0, 2) + vgetq_lane_s16(p_0, 3) + vgetq_lane_s16(p_0, 4) + vgetq_lane_s16(p_0, 5) + vgetq_lane_s16(p_0, 6) + vgetq_lane_s16(p_0, 7));
1927
- sum1 += x1->d * y1->d * (vgetq_lane_s16(p_1, 0) + vgetq_lane_s16(p_1, 1) + vgetq_lane_s16(p_1, 2) + vgetq_lane_s16(p_1, 3) + vgetq_lane_s16(p_1, 4) + vgetq_lane_s16(p_1, 5) + vgetq_lane_s16(p_1, 6) + vgetq_lane_s16(p_1, 7));
1928
- #endif
2309
+ sum0 += x0->d*y0->d*vaddvq_s16(p_0);
2310
+ sum1 += x1->d*y1->d*vaddvq_s16(p_1);
1929
2311
  #endif
1930
2312
  }
1931
2313
 
@@ -1935,25 +2317,26 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
1935
2317
  __m512 acc0 = _mm512_setzero_ps();
1936
2318
  __m512 acc1 = _mm512_setzero_ps();
1937
2319
 
1938
- const int superblock_size = 8;
2320
+ const int superblock_size = 16;
2321
+
1939
2322
  const int superblock_count = nb / superblock_size;
1940
2323
 
1941
2324
  for (int superblock_ix = 0; superblock_ix < superblock_count; superblock_ix += 1) {
1942
2325
  int i = superblock_ix * superblock_size;
1943
2326
 
1944
- acc0 = dot_q4_0_oneblock_avx512( acc0, x, y, i+0 );
1945
- acc1 = dot_q4_0_oneblock_avx512( acc1, x, y, i+1 );
1946
- acc0 = dot_q4_0_oneblock_avx512( acc0, x, y, i+2 );
1947
- acc1 = dot_q4_0_oneblock_avx512( acc1, x, y, i+3 );
1948
- acc0 = dot_q4_0_oneblock_avx512( acc0, x, y, i+4 );
1949
- acc1 = dot_q4_0_oneblock_avx512( acc1, x, y, i+5 );
1950
- acc0 = dot_q4_0_oneblock_avx512( acc0, x, y, i+6 );
1951
- acc1 = dot_q4_0_oneblock_avx512( acc1, x, y, i+7 );
2327
+ acc0 = dot_q4_0_twoblocks_avx512( acc0, x, y, i+0 );
2328
+ acc1 = dot_q4_0_twoblocks_avx512( acc1, x, y, i+2 );
2329
+ acc0 = dot_q4_0_twoblocks_avx512( acc0, x, y, i+4 );
2330
+ acc1 = dot_q4_0_twoblocks_avx512( acc1, x, y, i+6 );
2331
+ acc0 = dot_q4_0_twoblocks_avx512( acc0, x, y, i+8 );
2332
+ acc1 = dot_q4_0_twoblocks_avx512( acc1, x, y, i+10 );
2333
+ acc0 = dot_q4_0_twoblocks_avx512( acc0, x, y, i+12 );
2334
+ acc1 = dot_q4_0_twoblocks_avx512( acc1, x, y, i+14 );
1952
2335
  }
1953
2336
 
1954
2337
  // Remainders
1955
- for (int i = superblock_count * superblock_size; i < nb; ++i) {
1956
- acc0 = dot_q4_0_oneblock_avx512( acc0, x, y, i );
2338
+ for (int i = superblock_count * superblock_size; i < nb; i += 2) {
2339
+ acc0 = dot_q4_0_twoblocks_avx512( acc0, x, y, i );
1957
2340
  }
1958
2341
 
1959
2342
  // Horizontal sum of all lanes of the accumulator
@@ -1962,7 +2345,7 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
1962
2345
  // Initialize accumulator with zeros
1963
2346
  __m256 acc = _mm256_setzero_ps();
1964
2347
 
1965
- /* Prepare the constants we will need during execution */
2348
+ /* Prepare the constants we will need during execution */
1966
2349
  const __m256i lowMask = _mm256_set1_epi8( 0xF );
1967
2350
  const __m256i offset_8 = _mm256_set1_epi16( 8 );
1968
2351
 
@@ -1972,61 +2355,59 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
1972
2355
 
1973
2356
  // Main loop
1974
2357
  for (int i = 0; i < nb; i+=UNROLL_COUNT) {
1975
-
1976
- // This loop will be unrolled by the compiler
2358
+ // This loop will be unrolled by the compiler
1977
2359
  for (int u=0;u<UNROLL_COUNT;u++) {
1978
- /* Compute combined scale for the block */
1979
- const __m256 scale = _mm256_mul_ps(
1980
- _mm256_broadcast_ss( &x[i+u].d ),
1981
- _mm256_broadcast_ss( &y[i+u].d ) );
1982
-
1983
- /* get input from x
1984
- Input: 32 Nibbles (16 bytes) at *x[i+u]
1985
- Output: 2 vectors with 16 values of type int16_t (x_high_q, x_low_q) */
1986
-
1987
- /* Load 16 bytes from memory */
1988
- const __m128i tmp_x = _mm_loadu_si128( ( const __m128i* ) x[i+u].qs);
1989
- /* Expand bytes into uint16_t values */
1990
- const __m256i bytes_x = _mm256_cvtepu8_epi16(tmp_x);
2360
+ /* Compute combined scale for the block */
2361
+ const __m256 scale = _mm256_mul_ps(
2362
+ _mm256_broadcast_ss( &x[i+u].d ),
2363
+ _mm256_broadcast_ss( &y[i+u].d ) );
2364
+
2365
+ /* get input from x
2366
+ Input: 32 Nibbles (16 bytes) at *x[i+u]
2367
+ Output: 2 vectors with 16 values of type int16_t (x_high_q, x_low_q) */
2368
+
2369
+ /* Load 16 bytes from memory */
2370
+ const __m128i tmp_x = _mm_loadu_si128( ( const __m128i* ) x[i+u].qs);
2371
+ /* Expand bytes into uint16_t values */
2372
+ const __m256i bytes_x = _mm256_cvtepu8_epi16(tmp_x);
1991
2373
  /* Unpack values into individual bytes */
1992
2374
  __m256i x_low_q = _mm256_and_si256( lowMask, bytes_x );
1993
2375
  const __m256i pre_shift_x_high_q = _mm256_andnot_si256( lowMask, bytes_x );
1994
- __m256i x_high_q = _mm256_srli_epi16( pre_shift_x_high_q, 4 );
2376
+ __m256i x_high_q = _mm256_srli_epi16( pre_shift_x_high_q, 4 );
1995
2377
  /* Now we have two vectors with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval. */
1996
- x_high_q = _mm256_sub_epi16( x_high_q, offset_8 );
1997
- x_low_q = _mm256_sub_epi16( x_low_q, offset_8 );
2378
+ x_high_q = _mm256_sub_epi16( x_high_q, offset_8 );
2379
+ x_low_q = _mm256_sub_epi16( x_low_q, offset_8 );
1998
2380
 
1999
- /* get input from y
2000
- Input: 32 Nibbles (16 bytes) at *y[i+u]
2001
- Output: 2 vectors with 16 values of type int16_t (y_high_q, y_low_q) */
2381
+ /* get input from y
2382
+ Input: 32 Nibbles (16 bytes) at *y[i+u]
2383
+ Output: 2 vectors with 16 values of type int16_t (y_high_q, y_low_q) */
2002
2384
 
2003
- /* Load 16 bytes from memory */
2004
- const __m128i tmp_y = _mm_loadu_si128( (const __m128i* ) y[i+u].qs);
2005
- /* Expand bytes into uint16_t values */
2006
- const __m256i bytes_y = _mm256_cvtepu8_epi16(tmp_y);
2385
+ /* Load 16 bytes from memory */
2386
+ const __m128i tmp_y = _mm_loadu_si128( (const __m128i* ) y[i+u].qs);
2387
+ /* Expand bytes into uint16_t values */
2388
+ const __m256i bytes_y = _mm256_cvtepu8_epi16(tmp_y);
2007
2389
  /* Unpack values into individual bytes */
2008
- const __m256i pre_shift_y_high_q = _mm256_andnot_si256( lowMask, bytes_y );
2009
- __m256i y_high_q = _mm256_srli_epi16( pre_shift_y_high_q, 4 );
2010
- __m256i y_low_q = _mm256_and_si256( lowMask, bytes_y );
2390
+ const __m256i pre_shift_y_high_q = _mm256_andnot_si256( lowMask, bytes_y );
2391
+ __m256i y_high_q = _mm256_srli_epi16( pre_shift_y_high_q, 4 );
2392
+ __m256i y_low_q = _mm256_and_si256( lowMask, bytes_y );
2011
2393
  /* Now we have two vectors with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval. */
2012
- y_high_q = _mm256_sub_epi16( y_high_q, offset_8 );
2013
- y_low_q = _mm256_sub_epi16( y_low_q, offset_8 );
2394
+ y_high_q = _mm256_sub_epi16( y_high_q, offset_8 );
2395
+ y_low_q = _mm256_sub_epi16( y_low_q, offset_8 );
2014
2396
 
2015
- /* Compute products of int16_t integers, add pairwise, store as int32_t */
2016
- __m256i xy_high_q = _mm256_madd_epi16( x_high_q, y_high_q );
2017
- __m256i xy_low_q = _mm256_madd_epi16( x_low_q, y_low_q );
2397
+ /* Compute products of int16_t integers, add pairwise, store as int32_t */
2398
+ __m256i xy_high_q = _mm256_madd_epi16( x_high_q, y_high_q );
2399
+ __m256i xy_low_q = _mm256_madd_epi16( x_low_q, y_low_q );
2018
2400
 
2019
- /* Accumulate the products of int32_t integers -> we now have a vector of 8 int_32t */
2020
- __m256i xy_q = _mm256_add_epi32( xy_high_q, xy_low_q );
2401
+ /* Accumulate the products of int32_t integers -> we now have a vector of 8 int_32t */
2402
+ __m256i xy_q = _mm256_add_epi32( xy_high_q, xy_low_q );
2021
2403
 
2022
- /* Convert to vectore of 8 int32_t to 8 floats */
2023
- __m256 q = _mm256_cvtepi32_ps( xy_q );
2404
+ /* Convert to vectore of 8 int32_t to 8 floats */
2405
+ __m256 q = _mm256_cvtepi32_ps( xy_q );
2024
2406
 
2025
- /* Multiply q with scale and accumulate */
2026
- acc = _mm256_fmadd_ps( scale, q, acc );
2407
+ /* Multiply q with scale and accumulate */
2408
+ acc = _mm256_fmadd_ps( scale, q, acc );
2027
2409
  }
2028
-
2029
- }
2410
+ }
2030
2411
 
2031
2412
  // Return horizontal sum of the acc vector
2032
2413
  __m128 res = _mm256_extractf128_ps( acc, 1 );
@@ -2087,18 +2468,18 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
2087
2468
  float sum1 = 0.0f;
2088
2469
 
2089
2470
  for (int i = 0; i < nb; i += 2) {
2090
- const block_q4_0 * restrict x0 = &px[i + 0];
2091
- const block_q4_0 * restrict y0 = &py[i + 0];
2092
- const block_q4_0 * restrict x1 = &px[i + 1];
2093
- const block_q4_0 * restrict y1 = &py[i + 1];
2471
+ const block_q4_0 * restrict x0 = &x[i + 0];
2472
+ const block_q4_0 * restrict y0 = &y[i + 0];
2473
+ const block_q4_0 * restrict x1 = &x[i + 1];
2474
+ const block_q4_0 * restrict y1 = &y[i + 1];
2094
2475
 
2095
2476
  const v128_t m4b = wasm_u8x16_splat(0xf);
2096
2477
  const v128_t s8b = wasm_i8x16_splat(0x8);
2097
2478
 
2098
- const v128_t v0_0 = wasm_v128_load(x0.qs);
2099
- const v128_t v0_1 = wasm_v128_load(y0.qs);
2100
- const v128_t v1_0 = wasm_v128_load(x1.qs);
2101
- const v128_t v1_1 = wasm_v128_load(y1.qs);
2479
+ const v128_t v0_0 = wasm_v128_load(x0->qs);
2480
+ const v128_t v0_1 = wasm_v128_load(y0->qs);
2481
+ const v128_t v1_0 = wasm_v128_load(x1->qs);
2482
+ const v128_t v1_1 = wasm_v128_load(y1->qs);
2102
2483
 
2103
2484
  // 4-bit -> 8-bit
2104
2485
  const v128_t v0_0l = wasm_v128_and(v0_0, m4b);
@@ -2170,18 +2551,20 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
2170
2551
  const uint8_t * restrict p0 = x[i].qs;
2171
2552
  const uint8_t * restrict p1 = y[i].qs;
2172
2553
 
2173
- for (int j = 0; j < QK/2; j++) {
2554
+ int sumi = 0;
2555
+ for (int j = 0; j < QK4_0/2; j++) {
2174
2556
  const uint8_t v0 = p0[j];
2175
2557
  const uint8_t v1 = p1[j];
2176
2558
 
2177
- const float f0 = d0*((int8_t) (v0 & 0xf) - 8);
2178
- const float f1 = d0*((int8_t) (v0 >> 4) - 8);
2559
+ const int i0 = (v0 & 0xf) - 8;
2560
+ const int i1 = (v0 >> 4) - 8;
2179
2561
 
2180
- const float f2 = d1*((int8_t) (v1 & 0xf) - 8);
2181
- const float f3 = d1*((int8_t) (v1 >> 4) - 8);
2562
+ const int i2 = (v1 & 0xf) - 8;
2563
+ const int i3 = (v1 >> 4) - 8;
2182
2564
 
2183
- sumf += f0*f2 + f1*f3;
2565
+ sumi += i0*i2 + i1*i3;
2184
2566
  }
2567
+ sumf += d0 * d1 * sumi;
2185
2568
  }
2186
2569
  #endif
2187
2570
 
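Note: hoisting the scales out of the scalar inner loop above is exact, not an approximation, because d0 and d1 are constant within a block:

    sum_j (d0*i0_j) * (d1*i2_j)  =  d0 * d1 * sum_j (i0_j * i2_j)

so the loop now accumulates a plain int and the two float multiplies happen once per block instead of once per nibble pair.
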
@@ -2189,7 +2572,7 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
2189
2572
  }
2190
2573
 
2191
2574
  static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
2192
- const int nb = n / QK;
2575
+ const int nb = n / QK4_1;
2193
2576
 
2194
2577
  const block_q4_1 * restrict x = vx;
2195
2578
  const block_q4_1 * restrict y = vy;
@@ -2266,46 +2649,81 @@ static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * rest
2266
2649
  res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
2267
2650
  res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
2268
2651
 
2269
- sumf = _mm_cvtss_f32( res ) + acc_offset * QK;
2652
+ sumf = _mm_cvtss_f32( res ) + acc_offset * QK4_1;
2270
2653
  #elif defined(__ARM_NEON)
2271
2654
  float sum00 = 0.0f;
2272
2655
  float sum01 = 0.0f;
2273
2656
  float sum10 = 0.0f;
2274
2657
  float sum11 = 0.0f;
2275
2658
 
2276
- for (int i = 0; i < nb; ++i) {
2659
+ for (int i = 0; i < nb; i += 2) {
2277
2660
  const block_q4_1 * restrict x0 = &x[i + 0];
2278
2661
  const block_q4_1 * restrict y0 = &y[i + 0];
2662
+ const block_q4_1 * restrict x1 = &x[i + 1];
2663
+ const block_q4_1 * restrict y1 = &y[i + 1];
2279
2664
 
2280
2665
  const uint8x16_t m4b = vdupq_n_u8(0xf);
2281
2666
 
2282
2667
  const uint8x16_t v0_0 = vld1q_u8(x0->qs);
2283
2668
  const uint8x16_t v1_0 = vld1q_u8(y0->qs);
2669
+ const uint8x16_t v0_1 = vld1q_u8(x1->qs);
2670
+ const uint8x16_t v1_1 = vld1q_u8(y1->qs);
2284
2671
 
2285
- // and with 0xf
2672
+ // 4-bit -> 8-bit
2286
2673
  const uint8x16_t v0_0l = vandq_u8(v0_0, m4b);
2287
2674
  const uint8x16_t v1_0l = vandq_u8(v1_0, m4b);
2288
-
2289
2675
  const uint8x16_t v0_0h = vshrq_n_u8(v0_0, 4);
2290
2676
  const uint8x16_t v1_0h = vshrq_n_u8(v1_0, 4);
2291
2677
 
2292
- // dot product into uint16x8_t
2678
+ const uint8x16_t v0_1l = vandq_u8(v0_1, m4b);
2679
+ const uint8x16_t v1_1l = vandq_u8(v1_1, m4b);
2680
+ const uint8x16_t v0_1h = vshrq_n_u8(v0_1, 4);
2681
+ const uint8x16_t v1_1h = vshrq_n_u8(v1_1, 4);
2682
+
2683
+ sum00 += x0->m*y0->m;
2684
+ sum01 += y0->m*x0->d*((uint16_t)vaddvq_u8(v0_0l) + (uint16_t)vaddvq_u8(v0_0h));
2685
+ sum10 += x0->m*y0->d*((uint16_t)vaddvq_u8(v1_0l) + (uint16_t)vaddvq_u8(v1_0h));
2686
+
2687
+ sum00 += x1->m*y1->m;
2688
+ sum01 += y1->m*x1->d*((uint16_t)vaddvq_u8(v0_1l) + (uint16_t)vaddvq_u8(v0_1h));
2689
+ sum10 += x1->m*y1->d*((uint16_t)vaddvq_u8(v1_1l) + (uint16_t)vaddvq_u8(v1_1h));
2690
+
2691
+ #if defined(__ARM_FEATURE_DOTPROD)
2692
+ // dot product into int32x4_t
2693
+ uint32x4_t p_0 = vdotq_u32(vdupq_n_u32(0), v0_0l, v1_0l);
2694
+ uint32x4_t p_1 = vdotq_u32(vdupq_n_u32(0), v0_1l, v1_1l);
2695
+
2696
+ p_0 = vdotq_u32(p_0, v0_0h, v1_0h);
2697
+ p_1 = vdotq_u32(p_1, v0_1h, v1_1h);
2698
+
2699
+ sum11 += x0->d*y0->d*vaddvq_u32(p_0);
2700
+ sum11 += x1->d*y1->d*vaddvq_u32(p_1);
2701
+ #else
2293
2702
  const uint16x8_t pl0l = vmull_u8(vget_low_u8 (v0_0l), vget_low_u8 (v1_0l));
2294
2703
  const uint16x8_t pl0h = vmull_u8(vget_high_u8(v0_0l), vget_high_u8(v1_0l));
2295
-
2296
2704
  const uint16x8_t ph0l = vmull_u8(vget_low_u8 (v0_0h), vget_low_u8 (v1_0h));
2297
2705
  const uint16x8_t ph0h = vmull_u8(vget_high_u8(v0_0h), vget_high_u8(v1_0h));
2298
2706
 
2299
- const uint16x8_t pl0 = vaddq_u16(pl0l, pl0h);
2300
- const uint16x8_t ph0 = vaddq_u16(ph0l, ph0h);
2707
+ const uint16x8_t pl1l = vmull_u8(vget_low_u8 (v0_1l), vget_low_u8 (v1_1l));
2708
+ const uint16x8_t pl1h = vmull_u8(vget_high_u8(v0_1l), vget_high_u8(v1_1l));
2709
+ const uint16x8_t ph1l = vmull_u8(vget_low_u8 (v0_1h), vget_low_u8 (v1_1h));
2710
+ const uint16x8_t ph1h = vmull_u8(vget_high_u8(v0_1h), vget_high_u8(v1_1h));
2301
2711
 
2302
- sum00 += x0->m*y0->m;
2303
- sum01 += y0->m*x0->d*(vaddvq_u8(v0_0l) + vaddvq_u8(v0_0h));
2304
- sum10 += x0->m*y0->d*(vaddvq_u8(v1_0l) + vaddvq_u8(v1_0h));
2305
- sum11 += x0->d*y0->d*vaddvq_u16(vaddq_u16(pl0, ph0));
2712
+ const uint16x8_t pl_0 = vaddq_u16(pl0l, pl0h);
2713
+ const uint16x8_t ph_0 = vaddq_u16(ph0l, ph0h);
2714
+
2715
+ const uint16x8_t pl_1 = vaddq_u16(pl1l, pl1h);
2716
+ const uint16x8_t ph_1 = vaddq_u16(ph1l, ph1h);
2717
+
2718
+ const uint16x8_t p_0 = vaddq_u16(pl_0, ph_0);
2719
+ const uint16x8_t p_1 = vaddq_u16(pl_1, ph_1);
2720
+
2721
+ sum11 += x0->d*y0->d*vaddvq_u16(p_0);
2722
+ sum11 += x1->d*y1->d*vaddvq_u16(p_1);
2723
+ #endif
2306
2724
  }
2307
2725
 
2308
- sumf = QK*sum00 + sum01 + sum10 + sum11;
2726
+ sumf = QK4_1*sum00 + sum01 + sum10 + sum11;
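For context on the four accumulators above: a q4_1 element decodes as x_j = d_x*q_j + m_x (and y_j = d_y*p_j + m_y), so the per-block dot product expands as

    \sum_j x_j y_j = QK_{4\_1}\, m_x m_y + m_y d_x \sum_j q_j + m_x d_y \sum_j p_j + d_x d_y \sum_j q_j p_j

which is exactly what sum00, sum01, sum10 and sum11 accumulate; only the last term needs the widening byte multiplies, or vdotq_u32 when __ARM_FEATURE_DOTPROD is available.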
2309
2727
  #else
2310
2728
  // scalar
2311
2729
  for (int i = 0; i < nb; i++) {
@@ -2318,7 +2736,7 @@ static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * rest
2318
2736
  const uint8_t * restrict p0 = x[i].qs;
2319
2737
  const uint8_t * restrict p1 = y[i].qs;
2320
2738
 
2321
- for (int j = 0; j < QK/2; j++) {
2739
+ for (int j = 0; j < QK4_1/2; j++) {
2322
2740
  const uint8_t v0 = p0[j];
2323
2741
  const uint8_t v1 = p1[j];
2324
2742
 
@@ -2336,21 +2754,224 @@ static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * rest
2336
2754
  *s = sumf;
2337
2755
  }
2338
2756
 
2339
- // compute GGML_VEC_DOT_UNROLL dot products at once
2340
- // xs - x row stride in bytes
2341
- inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * restrict s, void * restrict xv, ggml_fp16_t * restrict y) {
2342
- ggml_float sumf[GGML_VEC_DOT_UNROLL] = { 0.0 };
2757
+ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
2758
+ const int nb = n / QK8_0;
2343
2759
 
2344
- ggml_fp16_t * restrict x[GGML_VEC_DOT_UNROLL];
2760
+ assert(n % QK8_0 == 0);
2761
+ assert(nb % 2 == 0);
2345
2762
 
2346
- for (int i = 0; i < GGML_VEC_DOT_UNROLL; ++i) {
2347
- x[i] = (ggml_fp16_t *) ((char *) xv + i*xs);
2348
- }
2763
+ const block_q4_0 * restrict x = vx;
2764
+ const block_q8_0 * restrict y = vy;
2349
2765
 
2350
- #if defined(GGML_SIMD)
2351
- const int np = (n & ~(GGML_F16_STEP - 1));
2766
+ float sumf = 0.0;
2352
2767
 
2353
- GGML_F16_VEC sum[GGML_VEC_DOT_UNROLL][GGML_F16_ARR] = { { GGML_F16_VEC_ZERO } };
2768
+ #if defined(__ARM_NEON)
2769
+ float sum0 = 0.0f;
2770
+ float sum1 = 0.0f;
2771
+
2772
+ for (int i = 0; i < nb; i += 2) {
2773
+ const block_q4_0 * restrict x0 = &x[i + 0];
2774
+ const block_q4_0 * restrict x1 = &x[i + 1];
2775
+ const block_q8_0 * restrict y0 = &y[i + 0];
2776
+ const block_q8_0 * restrict y1 = &y[i + 1];
2777
+
2778
+ const uint8x16_t m4b = vdupq_n_u8(0xf);
2779
+ const int8x16_t s8b = vdupq_n_s8(0x8);
2780
+
2781
+ const uint8x16_t v0_0 = vld1q_u8(x0->qs);
2782
+ const uint8x16_t v0_1 = vld1q_u8(x1->qs);
2783
+
2784
+ // 4-bit -> 8-bit
2785
+ const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
2786
+ const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
2787
+ const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
2788
+ const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
2789
+
2790
+ // sub 8
2791
+ const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
2792
+ const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
2793
+ const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
2794
+ const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
2795
+
2796
+ // load y
2797
+ const int8x16_t v1_0l = vld1q_s8(y0->qs);
2798
+ const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
2799
+ const int8x16_t v1_1l = vld1q_s8(y1->qs);
2800
+ const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
2801
+
2802
+ // interleave
2803
+ const int8x16_t v1_0ls = vuzp1q_s8(v1_0l, v1_0h);
2804
+ const int8x16_t v1_0hs = vuzp2q_s8(v1_0l, v1_0h);
2805
+ const int8x16_t v1_1ls = vuzp1q_s8(v1_1l, v1_1h);
2806
+ const int8x16_t v1_1hs = vuzp2q_s8(v1_1l, v1_1h);
2807
+
2808
+ #if defined(__ARM_FEATURE_DOTPROD)
2809
+ // dot product into int32x4_t
2810
+ int32x4_t p_0 = vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls);
2811
+ int32x4_t p_1 = vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls);
2812
+
2813
+ p_0 = vdotq_s32(p_0, v0_0hs, v1_0hs);
2814
+ p_1 = vdotq_s32(p_1, v0_1hs, v1_1hs);
2815
+
2816
+ sum0 += x0->d*y0->d*vaddvq_s32(p_0);
2817
+ sum1 += x1->d*y1->d*vaddvq_s32(p_1);
2818
+ #else
2819
+ const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls));
2820
+ const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls));
2821
+ const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0hs));
2822
+ const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0hs));
2823
+
2824
+ const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1ls));
2825
+ const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1ls));
2826
+ const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1hs));
2827
+ const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1hs));
2828
+
2829
+ const int16x8_t pl_0 = vaddq_s16(pl0l, pl0h);
2830
+ const int16x8_t ph_0 = vaddq_s16(ph0l, ph0h);
2831
+
2832
+ const int16x8_t pl_1 = vaddq_s16(pl1l, pl1h);
2833
+ const int16x8_t ph_1 = vaddq_s16(ph1l, ph1h);
2834
+
2835
+ const int16x8_t p_0 = vaddq_s16(pl_0, ph_0);
2836
+ const int16x8_t p_1 = vaddq_s16(pl_1, ph_1);
2837
+
2838
+ sum0 += x0->d*y0->d*vaddvq_s16(p_0);
2839
+ sum1 += x1->d*y1->d*vaddvq_s16(p_1);
2840
+ #endif
2841
+ }
2842
+
2843
+ sumf = sum0 + sum1;
2844
+ #elif defined(__AVX2__)
2845
+ // Initialize accumulator with zeros
2846
+ __m256 acc = _mm256_setzero_ps();
2847
+
2848
+ // Main loop
2849
+ for (int i = 0; i < nb; ++i) {
2850
+ /* Compute combined scale for the block */
2851
+ const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );
2852
+
2853
+ __m256i bx = bytesFromNibbles(x[i].qs);
2854
+
2855
+ // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
2856
+ const __m256i off = _mm256_set1_epi8( 8 );
2857
+ bx = _mm256_sub_epi8( bx, off );
2858
+
2859
+ __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
2860
+
2861
+ // Get absolute values of x vectors
2862
+ const __m256i ax = _mm256_sign_epi8(bx, bx);
2863
+
2864
+ // Sign the values of the y vectors
2865
+ const __m256i sy = _mm256_sign_epi8(by, bx);
2866
+
2867
+ // Perform multiplication and create 16-bit values
2868
+ const __m256i dot = _mm256_maddubs_epi16(ax, sy);
2869
+
2870
+ const __m256i ones = _mm256_set1_epi16(1);
2871
+ __m256i xy_q = _mm256_madd_epi16(ones, dot);
2872
+
2873
+ /* Convert vector of 8 int32_t to 8 floats */
2874
+ __m256 q = _mm256_cvtepi32_ps( xy_q );
2875
+
2876
+ /* Multiply q with scale and accumulate */
2877
+ acc = _mm256_fmadd_ps( d, q, acc );
2878
+ }
2879
+
2880
+ // Return horizontal sum of the acc vector
2881
+ __m128 res = _mm256_extractf128_ps( acc, 1 );
2882
+ res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
2883
+ res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
2884
+ res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
2885
+
2886
+ sumf = _mm_cvtss_f32( res );
2887
+ #elif defined(__AVX__)
2888
+ // Initialize accumulator with zeros
2889
+ __m256 acc = _mm256_setzero_ps();
2890
+
2891
+ // Main loop
2892
+ for (int i = 0; i < nb; ++i) {
2893
+ // Compute combined scale for the block
2894
+ const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );
2895
+
2896
+ __m128i i32[2];
2897
+ for (int j = 0; j < 2; ++j) {
2898
+ // Load 8 bytes, and unpack 4 bit fields into bytes, making 16 bytes
2899
+ __m128i bx = bytesFromNibbles( x[i].qs + 8*j );
2900
+ __m128i by = _mm_loadu_si128((const __m128i *)(y[i].qs + 16*j));
2901
+
2902
+ // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
2903
+ const __m128i off = _mm_set1_epi8( 8 );
2904
+ bx = _mm_sub_epi8( bx, off );
2905
+
2906
+ // Get absolute values of x vectors
2907
+ const __m128i ax = _mm_sign_epi8(bx, bx);
2908
+
2909
+ // Sign the values of the y vectors
2910
+ const __m128i sy = _mm_sign_epi8(by, bx);
2911
+
2912
+ // Perform multiplication and create 16-bit values
2913
+ const __m128i dot = _mm_maddubs_epi16(ax, sy);
2914
+
2915
+ const __m128i ones = _mm_set1_epi16(1);
2916
+ i32[j] = _mm_madd_epi16(ones, dot);
2917
+ }
2918
+
2919
+ // Convert int32_t to float
2920
+ __m256 p = _mm256_cvtepi32_ps( _mm256_set_m128i( i32[0], i32[1] ));
2921
+ // Apply the scale, and accumulate
2922
+ acc = _mm256_add_ps(_mm256_mul_ps( d, p ), acc);
2923
+ }
2924
+
2925
+ // Return horizontal sum of the acc vector
2926
+ __m128 res = _mm256_extractf128_ps( acc, 1 );
2927
+ res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
2928
+ res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
2929
+ res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
2930
+
2931
+ sumf = _mm_cvtss_f32( res );
2932
+ #else
2933
+ // scalar
2934
+ for (int i = 0; i < nb; i++) {
2935
+ const float d0 = x[i].d;
2936
+ const float d1 = y[i].d;
2937
+
2938
+ const uint8_t * restrict p0 = x[i].qs;
2939
+ const int8_t * restrict p1 = y[i].qs;
2940
+
2941
+ int sumi = 0;
2942
+ for (int j = 0; j < QK8_0/2; j++) {
2943
+ const uint8_t v0 = p0[j];
2944
+
2945
+ const int i0 = (int8_t) (v0 & 0xf) - 8;
2946
+ const int i1 = (int8_t) (v0 >> 4) - 8;
2947
+
2948
+ const int i2 = p1[2*j + 0];
2949
+ const int i3 = p1[2*j + 1];
2950
+
2951
+ sumi += i0*i2 + i1*i3;
2952
+ }
2953
+ sumf += d0*d1*sumi;
2954
+ }
2955
+ #endif
2956
+
2957
+ *s = sumf;
2958
+ }
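A note on the AVX2 path above: _mm256_maddubs_epi16 multiplies an unsigned byte vector by a signed one, so the code feeds it ax = |bx| together with by carrying bx's sign (via _mm256_sign_epi8), which leaves every product unchanged. A scalar model of the identity being exploited, as a self-contained check (hypothetical helper, plain C):

    #include <assert.h>
    #include <stdint.h>

    // Move x's sign onto y and take |x|: the product is unchanged, and it is
    // forced to 0 when x == 0, matching the psignb/_mm256_sign_epi8 behavior.
    static int signed_dot_term(int8_t x, int8_t y) {
        const int ax = x < 0 ? -x : x;                  // |x|, fits an unsigned byte
        const int sy = x < 0 ? -y : (x == 0 ? 0 : y);   // y with x's sign applied
        return ax * sy;                                 // == (int) x * (int) y
    }

    int main(void) {
        for (int x = -8; x <= 7; x++)       // q4 values after the -8 offset
            for (int y = -128; y <= 127; y++)
                assert(signed_dot_term((int8_t) x, (int8_t) y) == x * y);
        return 0;
    }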
2959
+
2960
+ // compute GGML_VEC_DOT_UNROLL dot products at once
2961
+ // xs - x row stride in bytes
2962
+ inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * restrict s, void * restrict xv, ggml_fp16_t * restrict y) {
2963
+ ggml_float sumf[GGML_VEC_DOT_UNROLL] = { 0.0 };
2964
+
2965
+ ggml_fp16_t * restrict x[GGML_VEC_DOT_UNROLL];
2966
+
2967
+ for (int i = 0; i < GGML_VEC_DOT_UNROLL; ++i) {
2968
+ x[i] = (ggml_fp16_t *) ((char *) xv + i*xs);
2969
+ }
2970
+
2971
+ #if defined(GGML_SIMD)
2972
+ const int np = (n & ~(GGML_F16_STEP - 1));
2973
+
2974
+ GGML_F16_VEC sum[GGML_VEC_DOT_UNROLL][GGML_F16_ARR] = { { GGML_F16_VEC_ZERO } };
2354
2975
 
2355
2976
  GGML_F16_VEC ax[GGML_F16_ARR];
2356
2977
  GGML_F16_VEC ay[GGML_F16_ARR];
@@ -2578,29 +3199,41 @@ inline static void ggml_vec_norm_inv_f32(const int n, float * s, const float * x
2578
3199
  //
2579
3200
 
2580
3201
  static const int GGML_BLCK_SIZE[GGML_TYPE_COUNT] = {
2581
- QK,
2582
- QK,
2583
- 1,
2584
- 1,
2585
- 1,
2586
- 1,
2587
- 1,
3202
+ [GGML_TYPE_F32] = 1,
3203
+ [GGML_TYPE_F16] = 1,
3204
+ [GGML_TYPE_Q4_0] = QK4_0,
3205
+ [GGML_TYPE_Q4_1] = QK4_1,
3206
+ [GGML_TYPE_Q8_0] = QK8_0,
3207
+ [GGML_TYPE_I8] = 1,
3208
+ [GGML_TYPE_I16] = 1,
3209
+ [GGML_TYPE_I32] = 1,
2588
3210
  };
2589
-
2590
- static_assert(GGML_TYPE_COUNT == 7, "GGML_TYPE_COUNT != 5");
3211
+ static_assert(GGML_TYPE_COUNT == 8, "GGML_BLCK_SIZE is outdated");
2591
3212
 
2592
3213
  static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = {
2593
- sizeof(block_q4_0),
2594
- sizeof(block_q4_1),
2595
- sizeof(int8_t ),
2596
- sizeof(int16_t),
2597
- sizeof(int32_t),
2598
- sizeof(ggml_fp16_t),
2599
- sizeof(float ),
3214
+ [GGML_TYPE_F32] = sizeof(float),
3215
+ [GGML_TYPE_F16] = sizeof(ggml_fp16_t),
3216
+ [GGML_TYPE_Q4_0] = sizeof(block_q4_0),
3217
+ [GGML_TYPE_Q4_1] = sizeof(block_q4_1),
3218
+ [GGML_TYPE_Q8_0] = sizeof(block_q8_0),
3219
+ [GGML_TYPE_I8] = sizeof(int8_t),
3220
+ [GGML_TYPE_I16] = sizeof(int16_t),
3221
+ [GGML_TYPE_I32] = sizeof(int32_t),
2600
3222
  };
2601
-
2602
- // don't forget to update the array above when adding new types
2603
- static_assert(GGML_TYPE_COUNT == 7, "GGML_TYPE_COUNT != 5");
3223
+ static_assert(GGML_TYPE_COUNT == 8, "GGML_TYPE_SIZE is outdated");
3224
+
3225
+
3226
+ static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = {
3227
+ [GGML_TYPE_F32] = "f32",
3228
+ [GGML_TYPE_F16] = "f16",
3229
+ [GGML_TYPE_Q4_0] = "q4_0",
3230
+ [GGML_TYPE_Q4_1] = "q4_1",
3231
+ [GGML_TYPE_Q8_0] = "q8_0",
3232
+ [GGML_TYPE_I8] = "i8",
3233
+ [GGML_TYPE_I16] = "i16",
3234
+ [GGML_TYPE_I32] = "i32",
3235
+ };
3236
+ static_assert(GGML_TYPE_COUNT == 8, "GGML_TYPE_NAME is outdated");
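The three tables above now use designated initializers keyed on the ggml_type enum, each guarded by a static_assert on GGML_TYPE_COUNT, so adding a type can no longer silently shift entries. A hedged sketch of how block size and element size combine (row_size here is a hypothetical helper; the same expression appears in the mul_mat code further down):

    // Bytes occupied by one row of n elements of a given type: n/BLCK_SIZE
    // blocks of TYPE_SIZE bytes each (for f32/f16/int types the block size is 1).
    static size_t row_size(enum ggml_type type, int n) {
        return (size_t) GGML_TYPE_SIZE[type] * (n / GGML_BLCK_SIZE[type]);
    }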
2604
3237
 
2605
3238
  static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
2606
3239
  "NONE",
@@ -2629,6 +3262,7 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
2629
3262
 
2630
3263
  "SCALE",
2631
3264
  "CPY",
3265
+ "CONT",
2632
3266
  "RESHAPE",
2633
3267
  "VIEW",
2634
3268
  "PERMUTE",
@@ -2642,9 +3276,12 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
2642
3276
 
2643
3277
  "FLASH_ATTN",
2644
3278
  "FLASH_FF",
3279
+
3280
+ "MAP_UNARY",
3281
+ "MAP_BINARY",
2645
3282
  };
2646
3283
 
2647
- static_assert(GGML_OP_COUNT == 35, "GGML_OP_COUNT != 35");
3284
+ static_assert(GGML_OP_COUNT == 38, "GGML_OP_COUNT != 38");
2648
3285
 
2649
3286
  static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
2650
3287
  "none",
@@ -2673,6 +3310,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
2673
3310
 
2674
3311
  "x*v",
2675
3312
  "x-\\>y",
3313
+ "cont(x)",
2676
3314
  "reshape(x)",
2677
3315
  "view(x)",
2678
3316
  "permute(x)",
@@ -2686,24 +3324,12 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
2686
3324
 
2687
3325
  "flash_attn(x)",
2688
3326
  "flash_ff(x)",
2689
- };
2690
-
2691
- static_assert(GGML_OP_COUNT == 35, "GGML_OP_COUNT != 35");
2692
-
2693
- //
2694
- // ggml object
2695
- //
2696
-
2697
- struct ggml_object {
2698
- size_t offs;
2699
- size_t size;
2700
3327
 
2701
- struct ggml_object * next;
2702
-
2703
- char padding[8];
3328
+ "f(x)",
3329
+ "f(x,y)",
2704
3330
  };
2705
3331
 
2706
- static const size_t GGML_OBJECT_SIZE = sizeof(struct ggml_object);
3332
+ static_assert(GGML_OP_COUNT == 38, "GGML_OP_COUNT != 38");
2707
3333
 
2708
3334
  static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
2709
3335
  static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN");
@@ -2716,7 +3342,6 @@ struct ggml_context {
2716
3342
  size_t mem_size;
2717
3343
  void * mem_buffer;
2718
3344
  bool mem_buffer_owned;
2719
- bool mem_buffer_mlocked;
2720
3345
  bool no_alloc;
2721
3346
 
2722
3347
  int n_objects;
@@ -2834,6 +3459,11 @@ float ggml_type_sizef(enum ggml_type type) {
2834
3459
  return ((float)(GGML_TYPE_SIZE[type]))/GGML_BLCK_SIZE[type];
2835
3460
  }
2836
3461
 
3462
+ const char * ggml_type_name(enum ggml_type type) {
3463
+ return GGML_TYPE_NAME[type];
3464
+ }
3465
+
3466
+
2837
3467
  size_t ggml_element_size(const struct ggml_tensor * tensor) {
2838
3468
  return GGML_TYPE_SIZE[tensor->type];
2839
3469
  }
@@ -2999,11 +3629,12 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
2999
3629
  return NULL;
3000
3630
  }
3001
3631
 
3632
+ const size_t mem_size = (params.mem_size + GGML_MEM_ALIGN - 1) & ~(GGML_MEM_ALIGN - 1);
3633
+
3002
3634
  *ctx = (struct ggml_context) {
3003
- /*.mem_size =*/ params.mem_size,
3004
- /*.mem_buffer =*/ params.mem_buffer ? params.mem_buffer : malloc(params.mem_size),
3635
+ /*.mem_size =*/ mem_size,
3636
+ /*.mem_buffer =*/ params.mem_buffer ? params.mem_buffer : GGML_ALIGNED_MALLOC(mem_size),
3005
3637
  /*.mem_buffer_owned =*/ params.mem_buffer ? false : true,
3006
- /*.mem_buffer_mlocked =*/ false,
3007
3638
  /*.no_alloc =*/ params.no_alloc,
3008
3639
  /*.n_objects =*/ 0,
3009
3640
  /*.objects_begin =*/ NULL,
@@ -3012,7 +3643,7 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
3012
3643
  /*.scratch_save =*/ { 0, 0, NULL, },
3013
3644
  };
3014
3645
 
3015
- GGML_ASSERT(ctx->mem_buffer != NULL); // check for allocation failure
3646
+ GGML_ASSERT(ctx->mem_buffer != NULL);
3016
3647
 
3017
3648
  ggml_assert_aligned(ctx->mem_buffer);
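ggml_init now rounds the requested size up to GGML_MEM_ALIGN before allocating through GGML_ALIGNED_MALLOC, so the buffer both starts and ends on an aligned boundary. The rounding is the usual power-of-two mask idiom; a minimal sketch (align_up is a hypothetical name):

    #include <stddef.h>

    // Round sz up to the next multiple of align; align must be a power of two.
    // With align = 16: 1 -> 16, 16 -> 16, 17 -> 32.
    static size_t align_up(size_t sz, size_t align) {
        return (sz + align - 1) & ~(align - 1);
    }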
3018
3649
 
@@ -3036,16 +3667,8 @@ void ggml_free(struct ggml_context * ctx) {
3036
3667
  GGML_PRINT_DEBUG("%s: context %d with %d objects has been freed. memory used = %zu\n",
3037
3668
  __func__, i, ctx->n_objects, ctx->objects_end->offs + ctx->objects_end->size);
3038
3669
 
3039
- #if GGML_MLOCK_SUPPORT
3040
- if (ctx->mem_buffer_mlocked) {
3041
- if (munlock(ctx->mem_buffer, ctx->mem_size)) {
3042
- fprintf(stderr, "%s: failed to munlock buffer: %s\n", __func__, strerror(errno));
3043
- }
3044
- }
3045
- #endif
3046
-
3047
3670
  if (ctx->mem_buffer_owned) {
3048
- free(ctx->mem_buffer);
3671
+ GGML_ALIGNED_FREE(ctx->mem_buffer);
3049
3672
  }
3050
3673
 
3051
3674
  found = true;
@@ -3072,48 +3695,6 @@ size_t ggml_set_scratch(struct ggml_context * ctx, struct ggml_scratch scratch)
3072
3695
  return result;
3073
3696
  }
3074
3697
 
3075
- #ifdef __APPLE__
3076
- #define MLOCK_SUGGESTION \
3077
- "Try increasing the sysctl values 'vm.user_wire_limit' and 'vm.global_user_wire_limit' and/or " \
3078
- "decreasing 'vm.global_no_user_wire_amount'. Also try increasing RLIMIT_MLOCK (ulimit -l).\n"
3079
- #else
3080
- #define MLOCK_SUGGESTION \
3081
- "Try increasing RLIMIT_MLOCK ('ulimit -l' as root).\n"
3082
- #endif
3083
-
3084
- bool ggml_mlock_supported(void) {
3085
- return GGML_MLOCK_SUPPORT;
3086
- }
3087
-
3088
- bool ggml_mlock(
3089
- struct ggml_context * ctx,
3090
- const void *opt_extra_addr,
3091
- size_t opt_extra_len,
3092
- char **err_p) {
3093
- // TODO: Use SetProcessWorkingSetSize() + VirtualLock() on WIN32
3094
- #if GGML_MLOCK_SUPPORT
3095
- if (ctx->mem_buffer_mlocked) {
3096
- return true;
3097
- }
3098
- if (mlock(ctx->mem_buffer, ctx->mem_size) ||
3099
- (opt_extra_len &&
3100
- mlock(opt_extra_addr, opt_extra_len))) {
3101
- if ((*err_p = malloc(1024))) {
3102
- snprintf(*err_p, 1024,
3103
- "failed to mlock %zu-byte buffer: %s\n" MLOCK_SUGGESTION,
3104
- ctx->mem_size + opt_extra_len,
3105
- strerror(errno));
3106
- }
3107
- return false;
3108
- }
3109
- ctx->mem_buffer_mlocked = true;
3110
- return true;
3111
- #else // GGML_MLOCK_SUPPORT
3112
- *err_p = strdup("can't mlock because it's not supported on this system");
3113
- return false;
3114
- #endif // GGML_MLOCK_SUPPORT
3115
- }
3116
-
3117
3698
  ////////////////////////////////////////////////////////////////////////////////
3118
3699
 
3119
3700
  struct ggml_tensor * ggml_new_tensor_impl(
@@ -3325,14 +3906,6 @@ struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value) {
3325
3906
  char * const data = tensor->data;
3326
3907
 
3327
3908
  switch (tensor->type) {
3328
- case GGML_TYPE_Q4_0:
3329
- {
3330
- GGML_ASSERT(false);
3331
- } break;
3332
- case GGML_TYPE_Q4_1:
3333
- {
3334
- GGML_ASSERT(false);
3335
- } break;
3336
3909
  case GGML_TYPE_I8:
3337
3910
  {
3338
3911
  assert(tensor->nb[0] == sizeof(int8_t));
@@ -3368,7 +3941,7 @@ struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value) {
3368
3941
  ggml_vec_set_f32(nc, (float *)(data + i*n1), value);
3369
3942
  }
3370
3943
  } break;
3371
- case GGML_TYPE_COUNT:
3944
+ default:
3372
3945
  {
3373
3946
  GGML_ASSERT(false);
3374
3947
  } break;
@@ -3385,14 +3958,6 @@ struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value) {
3385
3958
  char * const data = tensor->data;
3386
3959
 
3387
3960
  switch (tensor->type) {
3388
- case GGML_TYPE_Q4_0:
3389
- {
3390
- GGML_ASSERT(false);
3391
- } break;
3392
- case GGML_TYPE_Q4_1:
3393
- {
3394
- GGML_ASSERT(false);
3395
- } break;
3396
3961
  case GGML_TYPE_I8:
3397
3962
  {
3398
3963
  assert(tensor->nb[0] == sizeof(int8_t));
@@ -3428,7 +3993,7 @@ struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value) {
3428
3993
  ggml_vec_set_f32(nc, (float *)(data + i*n1), value);
3429
3994
  }
3430
3995
  } break;
3431
- case GGML_TYPE_COUNT:
3996
+ default:
3432
3997
  {
3433
3998
  GGML_ASSERT(false);
3434
3999
  } break;
@@ -3439,14 +4004,6 @@ struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value) {
3439
4004
 
3440
4005
  int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i) {
3441
4006
  switch (tensor->type) {
3442
- case GGML_TYPE_Q4_0:
3443
- {
3444
- GGML_ASSERT(false);
3445
- } break;
3446
- case GGML_TYPE_Q4_1:
3447
- {
3448
- GGML_ASSERT(false);
3449
- } break;
3450
4007
  case GGML_TYPE_I8:
3451
4008
  {
3452
4009
  GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
@@ -3472,7 +4029,7 @@ int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i) {
3472
4029
  GGML_ASSERT(tensor->nb[0] == sizeof(float));
3473
4030
  return ((float *)(tensor->data))[i];
3474
4031
  } break;
3475
- case GGML_TYPE_COUNT:
4032
+ default:
3476
4033
  {
3477
4034
  GGML_ASSERT(false);
3478
4035
  } break;
@@ -3483,14 +4040,6 @@ int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i) {
3483
4040
 
3484
4041
  void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value) {
3485
4042
  switch (tensor->type) {
3486
- case GGML_TYPE_Q4_0:
3487
- {
3488
- GGML_ASSERT(false);
3489
- } break;
3490
- case GGML_TYPE_Q4_1:
3491
- {
3492
- GGML_ASSERT(false);
3493
- } break;
3494
4043
  case GGML_TYPE_I8:
3495
4044
  {
3496
4045
  GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
@@ -3516,7 +4065,7 @@ void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value) {
3516
4065
  GGML_ASSERT(tensor->nb[0] == sizeof(float));
3517
4066
  ((float *)(tensor->data))[i] = value;
3518
4067
  } break;
3519
- case GGML_TYPE_COUNT:
4068
+ default:
3520
4069
  {
3521
4070
  GGML_ASSERT(false);
3522
4071
  } break;
@@ -3525,14 +4074,6 @@ void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value) {
3525
4074
 
3526
4075
  float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i) {
3527
4076
  switch (tensor->type) {
3528
- case GGML_TYPE_Q4_0:
3529
- {
3530
- GGML_ASSERT(false);
3531
- } break;
3532
- case GGML_TYPE_Q4_1:
3533
- {
3534
- GGML_ASSERT(false);
3535
- } break;
3536
4077
  case GGML_TYPE_I8:
3537
4078
  {
3538
4079
  GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
@@ -3558,7 +4099,7 @@ float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i) {
3558
4099
  GGML_ASSERT(tensor->nb[0] == sizeof(float));
3559
4100
  return ((float *)(tensor->data))[i];
3560
4101
  } break;
3561
- case GGML_TYPE_COUNT:
4102
+ default:
3562
4103
  {
3563
4104
  GGML_ASSERT(false);
3564
4105
  } break;
@@ -3569,14 +4110,6 @@ float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i) {
3569
4110
 
3570
4111
  void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value) {
3571
4112
  switch (tensor->type) {
3572
- case GGML_TYPE_Q4_0:
3573
- {
3574
- GGML_ASSERT(false);
3575
- } break;
3576
- case GGML_TYPE_Q4_1:
3577
- {
3578
- GGML_ASSERT(false);
3579
- } break;
3580
4113
  case GGML_TYPE_I8:
3581
4114
  {
3582
4115
  GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
@@ -3602,7 +4135,7 @@ void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value) {
3602
4135
  GGML_ASSERT(tensor->nb[0] == sizeof(float));
3603
4136
  ((float *)(tensor->data))[i] = value;
3604
4137
  } break;
3605
- case GGML_TYPE_COUNT:
4138
+ default:
3606
4139
  {
3607
4140
  GGML_ASSERT(false);
3608
4141
  } break;
@@ -4388,26 +4921,22 @@ struct ggml_tensor * ggml_cpy_inplace(
4388
4921
  return ggml_cpy_impl(ctx, a, b, true);
4389
4922
  }
4390
4923
 
4391
- // ggml_reshape
4924
+ // ggml_cont
4392
4925
 
4393
- struct ggml_tensor * ggml_reshape(
4926
+ struct ggml_tensor * ggml_cont_impl(
4394
4927
  struct ggml_context * ctx,
4395
- struct ggml_tensor * a,
4396
- struct ggml_tensor * b) {
4397
- GGML_ASSERT(ggml_is_contiguous(a));
4398
- GGML_ASSERT(ggml_is_contiguous(b));
4399
- GGML_ASSERT(ggml_nelements(a) == ggml_nelements(b));
4400
-
4928
+ struct ggml_tensor * a,
4929
+ bool inplace) {
4401
4930
  bool is_node = false;
4402
4931
 
4403
- if (a->grad || b->grad) {
4932
+ if (!inplace && a->grad) {
4404
4933
  GGML_ASSERT(false); // TODO: implement backward
4405
4934
  is_node = true;
4406
4935
  }
4407
4936
 
4408
- struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, b->n_dims, b->ne, a->data);
4937
+ struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
4409
4938
 
4410
- result->op = GGML_OP_RESHAPE;
4939
+ result->op = GGML_OP_CONT;
4411
4940
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
4412
4941
  result->src0 = a;
4413
4942
  result->src1 = NULL;
@@ -4415,12 +4944,51 @@ struct ggml_tensor * ggml_reshape(
4415
4944
  return result;
4416
4945
  }
4417
4946
 
4418
- struct ggml_tensor * ggml_reshape_2d(
4947
+ struct ggml_tensor * ggml_cont(
4419
4948
  struct ggml_context * ctx,
4420
- struct ggml_tensor * a,
4421
- int64_t ne0,
4422
- int64_t ne1) {
4423
- GGML_ASSERT(ggml_is_contiguous(a));
4949
+ struct ggml_tensor * a) {
4950
+ return ggml_cont_impl(ctx, a, false);
4951
+ }
4952
+
4953
+ struct ggml_tensor * ggml_cont_inplace(
4954
+ struct ggml_context * ctx,
4955
+ struct ggml_tensor * a) {
4956
+ return ggml_cont_impl(ctx, a, true);
4957
+ }
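ggml_cont materializes a (possibly non-contiguous) tensor into a contiguous one through the new GGML_OP_CONT op, with ggml_cont_inplace reusing the source buffer as a view. A hedged usage sketch, assuming the ggml_permute/ggml_reshape_2d signatures from ggml.h, of the typical pattern of making a permuted view contiguous before reshaping (ggml_reshape asserts contiguity):

    // Swap the first two axes of x, then flatten the result to 2-D.
    // ggml_permute returns a strided view, so ggml_cont copies it into a
    // contiguous tensor that ggml_reshape_2d will accept.
    static struct ggml_tensor * transpose_and_flatten(struct ggml_context * ctx,
                                                      struct ggml_tensor * x) {
        struct ggml_tensor * t = ggml_permute(ctx, x, 1, 0, 2, 3);
        t = ggml_cont(ctx, t);
        return ggml_reshape_2d(ctx, t, t->ne[0]*t->ne[1], t->ne[2]);
    }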
4958
+
4959
+ // ggml_reshape
4960
+
4961
+ struct ggml_tensor * ggml_reshape(
4962
+ struct ggml_context * ctx,
4963
+ struct ggml_tensor * a,
4964
+ struct ggml_tensor * b) {
4965
+ GGML_ASSERT(ggml_is_contiguous(a));
4966
+ GGML_ASSERT(ggml_is_contiguous(b));
4967
+ GGML_ASSERT(ggml_nelements(a) == ggml_nelements(b));
4968
+
4969
+ bool is_node = false;
4970
+
4971
+ if (a->grad || b->grad) {
4972
+ GGML_ASSERT(false); // TODO: implement backward
4973
+ is_node = true;
4974
+ }
4975
+
4976
+ struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, b->n_dims, b->ne, a->data);
4977
+
4978
+ result->op = GGML_OP_RESHAPE;
4979
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
4980
+ result->src0 = a;
4981
+ result->src1 = NULL;
4982
+
4983
+ return result;
4984
+ }
4985
+
4986
+ struct ggml_tensor * ggml_reshape_2d(
4987
+ struct ggml_context * ctx,
4988
+ struct ggml_tensor * a,
4989
+ int64_t ne0,
4990
+ int64_t ne1) {
4991
+ GGML_ASSERT(ggml_is_contiguous(a));
4424
4992
  GGML_ASSERT(ggml_nelements(a) == ne0*ne1);
4425
4993
 
4426
4994
  bool is_node = false;
@@ -4866,6 +5434,90 @@ struct ggml_tensor * ggml_flash_ff(
4866
5434
  return result;
4867
5435
  }
4868
5436
 
5437
+ // ggml_map_unary
5438
+
5439
+ struct ggml_tensor * ggml_map_unary_impl_f32(
5440
+ struct ggml_context * ctx,
5441
+ struct ggml_tensor * a,
5442
+ const ggml_unary_op_f32_t fun,
5443
+ bool inplace) {
5444
+ bool is_node = false;
5445
+
5446
+ if (!inplace && a->grad) {
5447
+ is_node = true;
5448
+ }
5449
+
5450
+ struct ggml_tensor * addr_tensor = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sizeof(void *) / sizeof(int32_t));
5451
+ *((void (**)(void))addr_tensor->data) = (void (*)(void))fun;
5452
+ struct ggml_tensor *result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
5453
+
5454
+ result->op = GGML_OP_MAP_UNARY;
5455
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5456
+ result->src0 = a;
5457
+ result->opt[0] = addr_tensor;
5458
+
5459
+ return result;
5460
+ }
5461
+
5462
+ struct ggml_tensor * ggml_map_unary_f32(
5463
+ struct ggml_context * ctx,
5464
+ struct ggml_tensor * a,
5465
+ const ggml_unary_op_f32_t fun) {
5466
+ return ggml_map_unary_impl_f32(ctx, a, fun, false);
5467
+ }
5468
+
5469
+ struct ggml_tensor * ggml_map_unary_inplace_f32(
5470
+ struct ggml_context * ctx,
5471
+ struct ggml_tensor * a,
5472
+ const ggml_unary_op_f32_t fun) {
5473
+ return ggml_map_unary_impl_f32(ctx, a, fun, true);
5474
+ }
5475
+
5476
+ // ggml_map_binary
5477
+
5478
+ struct ggml_tensor * ggml_map_binary_impl_f32(
5479
+ struct ggml_context * ctx,
5480
+ struct ggml_tensor * a,
5481
+ struct ggml_tensor * b,
5482
+ const ggml_binary_op_f32_t fun,
5483
+ bool inplace) {
5484
+ GGML_ASSERT(ggml_are_same_shape(a, b));
5485
+
5486
+ bool is_node = false;
5487
+
5488
+ if (!inplace && (a->grad || b->grad)) {
5489
+ is_node = true;
5490
+ }
5491
+
5492
+ struct ggml_tensor * addr_tensor = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sizeof(void *) / sizeof(int32_t));
5493
+ *((void (**)(void))addr_tensor->data) = (void (*)(void))fun;
5494
+ struct ggml_tensor *result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
5495
+
5496
+ result->op = GGML_OP_MAP_BINARY;
5497
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5498
+ result->src0 = a;
5499
+ result->src1 = b;
5500
+ result->opt[0] = addr_tensor;
5501
+
5502
+ return result;
5503
+ }
5504
+
5505
+ struct ggml_tensor * ggml_map_binary_f32(
5506
+ struct ggml_context * ctx,
5507
+ struct ggml_tensor * a,
5508
+ struct ggml_tensor * b,
5509
+ const ggml_binary_op_f32_t fun) {
5510
+ return ggml_map_binary_impl_f32(ctx, a, b, fun, false);
5511
+ }
5512
+
5513
+ struct ggml_tensor * ggml_map_binary_inplace_f32(
5514
+ struct ggml_context * ctx,
5515
+ struct ggml_tensor * a,
5516
+ struct ggml_tensor * b,
5517
+ const ggml_binary_op_f32_t fun) {
5518
+ return ggml_map_binary_impl_f32(ctx, a, b, fun, true);
5519
+ }
5520
+
4869
5521
  ////////////////////////////////////////////////////////////////////////////////
4870
5522
 
4871
5523
  void ggml_set_param(
@@ -4930,6 +5582,105 @@ static void ggml_compute_forward_dup_f16(
4930
5582
 
4931
5583
  // TODO: add more special-case implementations for tensor shapes/strides that can benefit from memcpy
4932
5584
 
5585
+ if (ggml_is_contiguous(dst)) {
5586
+ if (src0->nb[0] == sizeof(ggml_fp16_t)) {
5587
+ if (dst->type == GGML_TYPE_F16) {
5588
+ size_t id = 0;
5589
+ const size_t rs = ne00*nb00;
5590
+
5591
+ for (int i03 = 0; i03 < ne03; i03++) {
5592
+ for (int i02 = 0; i02 < ne02; i02++) {
5593
+ for (int i01 = 0; i01 < ne01; i01++) {
5594
+ const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
5595
+ char * dst_ptr = (char *) dst->data + id*rs;
5596
+
5597
+ memcpy(dst_ptr, src0_ptr, rs);
5598
+
5599
+ id++;
5600
+ }
5601
+ }
5602
+ }
5603
+ } else if (dst->type == GGML_TYPE_F32) {
5604
+ size_t id = 0;
5605
+ float * dst_ptr = (float *) dst->data;
5606
+
5607
+ for (int i03 = 0; i03 < ne03; i03++) {
5608
+ for (int i02 = 0; i02 < ne02; i02++) {
5609
+ for (int i01 = 0; i01 < ne01; i01++) {
5610
+ for (int i00 = 0; i00 < ne00; i00++) {
5611
+ const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
5612
+
5613
+ dst_ptr[id] = GGML_FP16_TO_FP32(*src0_ptr);
5614
+ id++;
5615
+ }
5616
+ }
5617
+ }
5618
+ }
5619
+ } else if (dst->type == GGML_TYPE_Q4_0 || dst->type == GGML_TYPE_Q4_1) {
5620
+ quantize_row_q_t const quantize_row_q = quantize_fns[dst->type].quantize_row_q;
5621
+ size_t id = 0;
5622
+ uint8_t * dst_ptr = (uint8_t *) dst->data;
5623
+ size_t dst_row_size = nb0 * (ne00 / GGML_BLCK_SIZE[dst->type]);
5624
+ float * src0_f32 = (float *) params->wdata;
5625
+
5626
+ for (int i03 = 0; i03 < ne03; i03++) {
5627
+ for (int i02 = 0; i02 < ne02; i02++) {
5628
+ for (int i01 = 0; i01 < ne01; i01++) {
5629
+ const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
5630
+ // convert to f32 and quantize
5631
+ for (int i00 = 0; i00 < ne00; i00++) {
5632
+ src0_f32[i00] = GGML_FP16_TO_FP32(src0_ptr[i00]);
5633
+ }
5634
+ quantize_row_q(src0_f32, dst_ptr + id, ne00);
5635
+ id += dst_row_size;
5636
+ }
5637
+ }
5638
+ }
5639
+ } else {
5640
+ GGML_ASSERT(false); // TODO: implement
5641
+ }
5642
+ } else {
5643
+ //printf("%s: this is not optimal - fix me\n", __func__);
5644
+
5645
+ if (dst->type == GGML_TYPE_F32) {
5646
+ size_t id = 0;
5647
+ float * dst_ptr = (float *) dst->data;
5648
+
5649
+ for (int i03 = 0; i03 < ne03; i03++) {
5650
+ for (int i02 = 0; i02 < ne02; i02++) {
5651
+ for (int i01 = 0; i01 < ne01; i01++) {
5652
+ for (int i00 = 0; i00 < ne00; i00++) {
5653
+ const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
5654
+
5655
+ dst_ptr[id] = GGML_FP16_TO_FP32(*src0_ptr);
5656
+ id++;
5657
+ }
5658
+ }
5659
+ }
5660
+ }
5661
+ } else if (dst->type == GGML_TYPE_F16) {
5662
+ size_t id = 0;
5663
+ ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
5664
+
5665
+ for (int i03 = 0; i03 < ne03; i03++) {
5666
+ for (int i02 = 0; i02 < ne02; i02++) {
5667
+ for (int i01 = 0; i01 < ne01; i01++) {
5668
+ for (int i00 = 0; i00 < ne00; i00++) {
5669
+ const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
5670
+
5671
+ dst_ptr[id] = *src0_ptr;
5672
+ id++;
5673
+ }
5674
+ }
5675
+ }
5676
+ }
5677
+ } else {
5678
+ GGML_ASSERT(false); // TODO: implement
5679
+ }
5680
+ }
5681
+ return;
5682
+ }
5683
+
4933
5684
  // dst counters
4934
5685
  int64_t i10 = 0;
4935
5686
  int64_t i11 = 0;
@@ -5024,6 +5775,120 @@ static void ggml_compute_forward_dup_f32(
5024
5775
  return;
5025
5776
  }
5026
5777
 
5778
+ if (src0->type == dst->type &&
5779
+ src0->ne[0] == dst->ne[0] &&
5780
+ src0->nb[0] == GGML_TYPE_SIZE[src0->type] && dst->nb[0] == GGML_TYPE_SIZE[dst->type]) {
5781
+ // copy by rows
5782
+ const size_t rs = ne00*nb00;
5783
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
5784
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
5785
+ for (int64_t i01 = 0; i01 < ne01; i01++) {
5786
+ memcpy(
5787
+ ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3),
5788
+ ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
5789
+ rs);
5790
+ }
5791
+ }
5792
+ }
5793
+ return;
5794
+ }
5795
+
5796
+ if (ggml_is_contiguous(dst)) {
5797
+ // TODO: simplify
5798
+ if (src0->nb[0] == sizeof(float)) {
5799
+ if (dst->type == GGML_TYPE_F32) {
5800
+ size_t id = 0;
5801
+ const size_t rs = ne00*nb00;
5802
+
5803
+ for (int i03 = 0; i03 < ne03; i03++) {
5804
+ for (int i02 = 0; i02 < ne02; i02++) {
5805
+ for (int i01 = 0; i01 < ne01; i01++) {
5806
+ const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
5807
+ char * dst_ptr = (char *) dst->data + id*rs;
5808
+
5809
+ memcpy(dst_ptr, src0_ptr, rs);
5810
+
5811
+ id++;
5812
+ }
5813
+ }
5814
+ }
5815
+ } else if (dst->type == GGML_TYPE_F16) {
5816
+ size_t id = 0;
5817
+ ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
5818
+
5819
+ for (int i03 = 0; i03 < ne03; i03++) {
5820
+ for (int i02 = 0; i02 < ne02; i02++) {
5821
+ for (int i01 = 0; i01 < ne01; i01++) {
5822
+ for (int i00 = 0; i00 < ne00; i00++) {
5823
+ const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
5824
+
5825
+ dst_ptr[id] = GGML_FP32_TO_FP16(*src0_ptr);
5826
+ id++;
5827
+ }
5828
+ }
5829
+ }
5830
+ }
5831
+ } else if (dst->type == GGML_TYPE_Q4_0 || dst->type == GGML_TYPE_Q4_1) {
5832
+ quantize_row_q_t const quantize_row_q = quantize_fns[dst->type].quantize_row_q;
5833
+ size_t id = 0;
5834
+ uint8_t * dst_ptr = (uint8_t *) dst->data;
5835
+ size_t dst_row_size = nb0 * (ne00 / GGML_BLCK_SIZE[dst->type]);
5836
+
5837
+ for (int i03 = 0; i03 < ne03; i03++) {
5838
+ for (int i02 = 0; i02 < ne02; i02++) {
5839
+ for (int i01 = 0; i01 < ne01; i01++) {
5840
+ const float * src0_ptr = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
5841
+ quantize_row_q(src0_ptr, dst_ptr + id, ne00);
5842
+ id += dst_row_size;
5843
+ }
5844
+ }
5845
+ }
5846
+ } else {
5847
+ GGML_ASSERT(false); // TODO: implement
5848
+ }
5849
+ } else {
5850
+ //printf("%s: this is not optimal - fix me\n", __func__);
5851
+
5852
+ if (dst->type == GGML_TYPE_F32) {
5853
+ size_t id = 0;
5854
+ float * dst_ptr = (float *) dst->data;
5855
+
5856
+ for (int i03 = 0; i03 < ne03; i03++) {
5857
+ for (int i02 = 0; i02 < ne02; i02++) {
5858
+ for (int i01 = 0; i01 < ne01; i01++) {
5859
+ for (int i00 = 0; i00 < ne00; i00++) {
5860
+ const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
5861
+
5862
+ dst_ptr[id] = *src0_ptr;
5863
+ id++;
5864
+ }
5865
+ }
5866
+ }
5867
+ }
5868
+ } else if (dst->type == GGML_TYPE_F16) {
5869
+ size_t id = 0;
5870
+ ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
5871
+
5872
+ for (int i03 = 0; i03 < ne03; i03++) {
5873
+ for (int i02 = 0; i02 < ne02; i02++) {
5874
+ for (int i01 = 0; i01 < ne01; i01++) {
5875
+ for (int i00 = 0; i00 < ne00; i00++) {
5876
+ const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
5877
+
5878
+ dst_ptr[id] = GGML_FP32_TO_FP16(*src0_ptr);
5879
+ id++;
5880
+ }
5881
+ }
5882
+ }
5883
+ }
5884
+ } else {
5885
+ GGML_ASSERT(false); // TODO: implement
5886
+ }
5887
+ }
5888
+
5889
+ return;
5890
+ }
5891
+
5027
5892
  // dst counters
5028
5893
  int64_t i10 = 0;
5029
5894
  int64_t i11 = 0;
@@ -5100,12 +5965,7 @@ static void ggml_compute_forward_dup(
5100
5965
  {
5101
5966
  ggml_compute_forward_dup_f32(params, src0, dst);
5102
5967
  } break;
5103
- case GGML_TYPE_Q4_0:
5104
- case GGML_TYPE_Q4_1:
5105
- case GGML_TYPE_I8:
5106
- case GGML_TYPE_I16:
5107
- case GGML_TYPE_I32:
5108
- case GGML_TYPE_COUNT:
5968
+ default:
5109
5969
  {
5110
5970
  GGML_ASSERT(false);
5111
5971
  } break;
@@ -5144,14 +6004,18 @@ static void ggml_compute_forward_add_f32(
5144
6004
  GGML_ASSERT(nb00 == sizeof(float));
5145
6005
 
5146
6006
  if (nb10 == sizeof(float)) {
5147
- const int j0 = (n/nth)*ith;
5148
- const int j1 = ith == nth - 1 ? n : (n/nth)*(ith + 1);
5149
-
5150
- for (int j = j0; j < j1; j++) {
6007
+ for (int j = ith; j < n; j += nth) {
6008
+ #ifdef GGML_USE_ACCELERATE
6009
+ vDSP_vadd(
6010
+ (float *) ((char *) src0->data + j*nb01), 1,
6011
+ (float *) ((char *) src1->data + j*nb11), 1,
6012
+ (float *) ((char *) dst->data + j*nb1), 1, nc);
6013
+ #else
5151
6014
  ggml_vec_add_f32(nc,
5152
6015
  (float *) ((char *) dst->data + j*nb1),
5153
6016
  (float *) ((char *) src0->data + j*nb01),
5154
6017
  (float *) ((char *) src1->data + j*nb11));
6018
+ #endif
5155
6019
  }
5156
6020
  } else {
5157
6021
  // src1 is not contiguous
@@ -5167,6 +6031,212 @@ static void ggml_compute_forward_add_f32(
5167
6031
  }
5168
6032
  }
5169
6033
 
6034
+ static void ggml_compute_forward_add_f16_f32(
6035
+ const struct ggml_compute_params * params,
6036
+ const struct ggml_tensor * src0,
6037
+ const struct ggml_tensor * src1,
6038
+ struct ggml_tensor * dst) {
6039
+ GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
6040
+
6041
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
6042
+ return;
6043
+ }
6044
+
6045
+ const int ith = params->ith;
6046
+ const int nth = params->nth;
6047
+
6048
+ const int n = ggml_nrows(src0);
6049
+ const int nc = src0->ne[0];
6050
+
6051
+ const size_t nb00 = src0->nb[0];
6052
+ const size_t nb01 = src0->nb[1];
6053
+
6054
+ const size_t nb10 = src1->nb[0];
6055
+ const size_t nb11 = src1->nb[1];
6056
+
6057
+ const size_t nb0 = dst->nb[0];
6058
+ const size_t nb1 = dst->nb[1];
6059
+
6060
+ GGML_ASSERT(src0->type == GGML_TYPE_F16);
6061
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
6062
+ GGML_ASSERT(dst->type == GGML_TYPE_F16);
6063
+
6064
+ GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
6065
+ GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
6066
+
6067
+ if (nb10 == sizeof(float)) {
6068
+ for (int j = ith; j < n; j += nth) {
6069
+ ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + j*nb1);
6070
+ ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + j*nb01);
6071
+ for (int i = 0; i < nc; i++) {
6072
+ float * src1_ptr = (float *) ((char *) src1->data + j*nb11 + i*nb10);
6073
+ dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + *src1_ptr);
6074
+ }
6075
+ }
6076
+ }
6077
+ else {
6078
+ // src1 is not contiguous
6079
+ GGML_ASSERT(false);
6080
+ }
6081
+ }
6082
+
6083
+ static void ggml_compute_forward_add_f16_f16(
6084
+ const struct ggml_compute_params * params,
6085
+ const struct ggml_tensor * src0,
6086
+ const struct ggml_tensor * src1,
6087
+ struct ggml_tensor * dst) {
6088
+ GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
6089
+
6090
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
6091
+ return;
6092
+ }
6093
+
6094
+ const int ith = params->ith;
6095
+ const int nth = params->nth;
6096
+
6097
+ const int n = ggml_nrows(src0);
6098
+ const int nc = src0->ne[0];
6099
+
6100
+ const size_t nb00 = src0->nb[0];
6101
+ const size_t nb01 = src0->nb[1];
6102
+
6103
+ const size_t nb10 = src1->nb[0];
6104
+ const size_t nb11 = src1->nb[1];
6105
+
6106
+ const size_t nb0 = dst->nb[0];
6107
+ const size_t nb1 = dst->nb[1];
6108
+
6109
+ GGML_ASSERT(src0->type == GGML_TYPE_F16);
6110
+ GGML_ASSERT(src1->type == GGML_TYPE_F16);
6111
+ GGML_ASSERT(dst->type == GGML_TYPE_F16);
6112
+
6113
+ GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
6114
+ GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
6115
+
6116
+ if (nb10 == sizeof(ggml_fp16_t)) {
6117
+ for (int j = ith; j < n; j += nth) {
6118
+ ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + j*nb1);
6119
+ ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + j*nb01);
6120
+ for (int i = 0; i < nc; i++) {
6121
+ ggml_fp16_t * src1_ptr = (ggml_fp16_t *) ((char *) src1->data + j*nb11 + i*nb10);
6122
+ dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + GGML_FP16_TO_FP32(*src1_ptr));
6123
+ }
6124
+ }
6125
+ }
6126
+ else {
6127
+ // src1 is not contiguous
6128
+ GGML_ASSERT(false);
6129
+ }
6130
+ }
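Both f16 add paths above do the arithmetic in f32: each operand is widened with GGML_FP16_TO_FP32, added, and the sum narrowed back with GGML_FP32_TO_FP16. The per-element pattern, as a one-line sketch (hypothetical helper name):

    // Widen to f32, add, narrow back to f16 -- no native f16 add is assumed.
    static inline ggml_fp16_t add_f16_scalar(ggml_fp16_t a, ggml_fp16_t b) {
        return GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(a) + GGML_FP16_TO_FP32(b));
    }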
6131
+
6132
+ static void ggml_compute_forward_add_q_f32(
6133
+ const struct ggml_compute_params * params,
6134
+ const struct ggml_tensor * src0,
6135
+ const struct ggml_tensor * src1,
6136
+ struct ggml_tensor * dst) {
6137
+ GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
6138
+
6139
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
6140
+ return;
6141
+ }
6142
+
6143
+ const int64_t ne00 = src0->ne[0];
6144
+ const int64_t ne01 = src0->ne[1];
6145
+ const int64_t ne02 = src0->ne[2];
6146
+ const int64_t ne03 = src0->ne[3];
6147
+
6148
+ //const int64_t ne10 = src1->ne[0];
6149
+ //const int64_t ne11 = src1->ne[1];
6150
+ const int64_t ne12 = src1->ne[2];
6151
+ const int64_t ne13 = src1->ne[3];
6152
+
6153
+ //const int64_t ne0 = dst->ne[0];
6154
+ //const int64_t ne1 = dst->ne[1];
6155
+ const int64_t ne2 = dst->ne[2];
6156
+ const int64_t ne3 = dst->ne[3];
6157
+
6158
+ const int nb00 = src0->nb[0];
6159
+ const int nb01 = src0->nb[1];
6160
+ const int nb02 = src0->nb[2];
6161
+ const int nb03 = src0->nb[3];
6162
+
6163
+ const int nb10 = src1->nb[0];
6164
+ const int nb11 = src1->nb[1];
6165
+ const int nb12 = src1->nb[2];
6166
+ const int nb13 = src1->nb[3];
6167
+
6168
+ const int nb0 = dst->nb[0];
6169
+ const int nb1 = dst->nb[1];
6170
+ const int nb2 = dst->nb[2];
6171
+ const int nb3 = dst->nb[3];
6172
+
6173
+ const int ith = params->ith;
6174
+ const int nth = params->nth;
6175
+
6176
+ GGML_ASSERT(ne02 == ne12);
6177
+ GGML_ASSERT(ne03 == ne13);
6178
+ GGML_ASSERT(ne2 == ne12);
6179
+ GGML_ASSERT(ne3 == ne13);
6180
+
6181
+ const enum ggml_type type = src0->type;
6182
+ dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q;
6183
+ quantize_row_q_t const quantize_row_q = quantize_fns[type].quantize_row_q;
6184
+
6185
+ // we don't support permuted src0 or src1
6186
+ GGML_ASSERT(nb00 == (int) GGML_TYPE_SIZE[type]);
6187
+ GGML_ASSERT(nb10 == sizeof(float));
6188
+
6189
+ // dst cannot be transposed or permuted
6190
+ GGML_ASSERT(nb0 <= nb1);
6191
+ GGML_ASSERT(nb1 <= nb2);
6192
+ GGML_ASSERT(nb2 <= nb3);
6193
+
6194
+ GGML_ASSERT(src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1);
6195
+ GGML_ASSERT(dst->type == src0->type);
6196
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
6197
+
6198
+ // total rows in src0
6199
+ const int nr = ne01*ne02*ne03;
6200
+
6201
+ // rows per thread
6202
+ const int dr = (nr + nth - 1)/nth;
6203
+
6204
+ // row range for this thread
6205
+ const int ir0 = dr*ith;
6206
+ const int ir1 = MIN(ir0 + dr, nr);
6207
+
6208
+ float * wdata = (float*) params->wdata + ne00 * ith;
6209
+
6210
+ for (int ir = ir0; ir < ir1; ++ir) {
6211
+ // src0 indices
6212
+ const int i03 = ir/(ne02*ne01);
6213
+ const int i02 = (ir - i03*ne02*ne01)/ne01;
6214
+ const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
6215
+
6216
+ // src1 and dst are same shape as src0 => same indices
6217
+ const int i13 = i03;
6218
+ const int i12 = i02;
6219
+ const int i11 = i01;
6220
+
6221
+ const int i3 = i03;
6222
+ const int i2 = i02;
6223
+ const int i1 = i01;
6224
+
6225
+ void * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
6226
+ float * src1_row = (float *)((char *) src1->data + (i11*nb11 + i12*nb12 + i13*nb13));
6227
+ void * dst_row = (void *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
6228
+
6229
+ assert(ne00 % 32 == 0);
6230
+
6231
+ // unquantize row from src0 to temp buffer
6232
+ dequantize_row_q(src0_row, wdata, ne00);
6233
+ // add src1
6234
+ ggml_vec_acc_f32(ne00, wdata, src1_row);
6235
+ // quantize row to dst
6236
+ quantize_row_q(wdata, dst_row, ne00);
6237
+ }
6238
+ }
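ggml_compute_forward_add_q_f32 handles "quantized += f32" by round-tripping each row through a per-thread float scratch region inside params->wdata: dequantize the row, accumulate src1, requantize into dst. A hedged sketch of that per-row pattern using the quantize_fns table (the helper name is hypothetical; ggml_vec_acc_f32 in the code above is just the accumulate loop):

    // Dequantize one quantized row, add a float row into it, requantize the result.
    static void add_row_q_f32(enum ggml_type type,
                              const void * src0_row, const float * src1_row,
                              void * dst_row, float * scratch, int ne) {
        dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q;
        quantize_row_q_t   const quantize_row_q   = quantize_fns[type].quantize_row_q;

        dequantize_row_q(src0_row, scratch, ne);   // q -> f32
        for (int i = 0; i < ne; ++i) {
            scratch[i] += src1_row[i];             // accumulate in f32
        }
        quantize_row_q(scratch, dst_row, ne);      // f32 -> q
    }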
6239
+
5170
6240
  static void ggml_compute_forward_add(
5171
6241
  const struct ggml_compute_params * params,
5172
6242
  const struct ggml_tensor * src0,
@@ -5177,13 +6247,24 @@ static void ggml_compute_forward_add(
5177
6247
  {
5178
6248
  ggml_compute_forward_add_f32(params, src0, src1, dst);
5179
6249
  } break;
6250
+ case GGML_TYPE_F16:
6251
+ {
6252
+ if (src1->type == GGML_TYPE_F16) {
6253
+ ggml_compute_forward_add_f16_f16(params, src0, src1, dst);
6254
+ }
6255
+ else if (src1->type == GGML_TYPE_F32) {
6256
+ ggml_compute_forward_add_f16_f32(params, src0, src1, dst);
6257
+ }
6258
+ else {
6259
+ GGML_ASSERT(false);
6260
+ }
6261
+ } break;
5180
6262
  case GGML_TYPE_Q4_0:
5181
6263
  case GGML_TYPE_Q4_1:
5182
- case GGML_TYPE_I8:
5183
- case GGML_TYPE_I16:
5184
- case GGML_TYPE_I32:
5185
- case GGML_TYPE_F16:
5186
- case GGML_TYPE_COUNT:
6264
+ {
6265
+ ggml_compute_forward_add_q_f32(params, src0, src1, dst);
6266
+ } break;
6267
+ default:
5187
6268
  {
5188
6269
  GGML_ASSERT(false);
5189
6270
  } break;
@@ -5229,13 +6310,7 @@ static void ggml_compute_forward_sub(
5229
6310
  {
5230
6311
  ggml_compute_forward_sub_f32(params, src0, src1, dst);
5231
6312
  } break;
5232
- case GGML_TYPE_Q4_0:
5233
- case GGML_TYPE_Q4_1:
5234
- case GGML_TYPE_I8:
5235
- case GGML_TYPE_I16:
5236
- case GGML_TYPE_I32:
5237
- case GGML_TYPE_F16:
5238
- case GGML_TYPE_COUNT:
6313
+ default:
5239
6314
  {
5240
6315
  GGML_ASSERT(false);
5241
6316
  } break;
@@ -5281,13 +6356,7 @@ static void ggml_compute_forward_mul(
5281
6356
  {
5282
6357
  ggml_compute_forward_mul_f32(params, src0, src1, dst);
5283
6358
  } break;
5284
- case GGML_TYPE_Q4_0:
5285
- case GGML_TYPE_Q4_1:
5286
- case GGML_TYPE_I8:
5287
- case GGML_TYPE_I16:
5288
- case GGML_TYPE_I32:
5289
- case GGML_TYPE_F16:
5290
- case GGML_TYPE_COUNT:
6359
+ default:
5291
6360
  {
5292
6361
  GGML_ASSERT(false);
5293
6362
  } break;
@@ -5333,13 +6402,7 @@ static void ggml_compute_forward_div(
5333
6402
  {
5334
6403
  ggml_compute_forward_div_f32(params, src0, src1, dst);
5335
6404
  } break;
5336
- case GGML_TYPE_Q4_0:
5337
- case GGML_TYPE_Q4_1:
5338
- case GGML_TYPE_I8:
5339
- case GGML_TYPE_I16:
5340
- case GGML_TYPE_I32:
5341
- case GGML_TYPE_F16:
5342
- case GGML_TYPE_COUNT:
6405
+ default:
5343
6406
  {
5344
6407
  GGML_ASSERT(false);
5345
6408
  } break;
@@ -5381,13 +6444,7 @@ static void ggml_compute_forward_sqr(
5381
6444
  {
5382
6445
  ggml_compute_forward_sqr_f32(params, src0, dst);
5383
6446
  } break;
5384
- case GGML_TYPE_Q4_0:
5385
- case GGML_TYPE_Q4_1:
5386
- case GGML_TYPE_I8:
5387
- case GGML_TYPE_I16:
5388
- case GGML_TYPE_I32:
5389
- case GGML_TYPE_F16:
5390
- case GGML_TYPE_COUNT:
6447
+ default:
5391
6448
  {
5392
6449
  GGML_ASSERT(false);
5393
6450
  } break;
@@ -5429,13 +6486,7 @@ static void ggml_compute_forward_sqrt(
5429
6486
  {
5430
6487
  ggml_compute_forward_sqrt_f32(params, src0, dst);
5431
6488
  } break;
5432
- case GGML_TYPE_Q4_0:
5433
- case GGML_TYPE_Q4_1:
5434
- case GGML_TYPE_I8:
5435
- case GGML_TYPE_I16:
5436
- case GGML_TYPE_I32:
5437
- case GGML_TYPE_F16:
5438
- case GGML_TYPE_COUNT:
6489
+ default:
5439
6490
  {
5440
6491
  GGML_ASSERT(false);
5441
6492
  } break;
@@ -5485,16 +6536,10 @@ static void ggml_compute_forward_sum(
5485
6536
  switch (src0->type) {
5486
6537
  case GGML_TYPE_F32:
5487
6538
  {
5488
- ggml_compute_forward_sum_f32(params, src0, dst);
5489
- } break;
5490
- case GGML_TYPE_Q4_0:
5491
- case GGML_TYPE_Q4_1:
5492
- case GGML_TYPE_I8:
5493
- case GGML_TYPE_I16:
5494
- case GGML_TYPE_I32:
5495
- case GGML_TYPE_F16:
5496
- case GGML_TYPE_COUNT:
5497
- {
6539
+ ggml_compute_forward_sum_f32(params, src0, dst);
6540
+ } break;
6541
+ default:
6542
+ {
5498
6543
  GGML_ASSERT(false);
5499
6544
  } break;
5500
6545
  }
@@ -5564,13 +6609,7 @@ static void ggml_compute_forward_mean(
5564
6609
  {
5565
6610
  ggml_compute_forward_mean_f32(params, src0, dst);
5566
6611
  } break;
5567
- case GGML_TYPE_Q4_0:
5568
- case GGML_TYPE_Q4_1:
5569
- case GGML_TYPE_I8:
5570
- case GGML_TYPE_I16:
5571
- case GGML_TYPE_I32:
5572
- case GGML_TYPE_F16:
5573
- case GGML_TYPE_COUNT:
6612
+ default:
5574
6613
  {
5575
6614
  GGML_ASSERT(false);
5576
6615
  } break;
@@ -5628,13 +6667,7 @@ static void ggml_compute_forward_repeat(
5628
6667
  {
5629
6668
  ggml_compute_forward_repeat_f32(params, src0, dst);
5630
6669
  } break;
5631
- case GGML_TYPE_Q4_0:
5632
- case GGML_TYPE_Q4_1:
5633
- case GGML_TYPE_I8:
5634
- case GGML_TYPE_I16:
5635
- case GGML_TYPE_I32:
5636
- case GGML_TYPE_F16:
5637
- case GGML_TYPE_COUNT:
6670
+ default:
5638
6671
  {
5639
6672
  GGML_ASSERT(false);
5640
6673
  } break;
@@ -5676,13 +6709,7 @@ static void ggml_compute_forward_abs(
5676
6709
  {
5677
6710
  ggml_compute_forward_abs_f32(params, src0, dst);
5678
6711
  } break;
5679
- case GGML_TYPE_Q4_0:
5680
- case GGML_TYPE_Q4_1:
5681
- case GGML_TYPE_I8:
5682
- case GGML_TYPE_I16:
5683
- case GGML_TYPE_I32:
5684
- case GGML_TYPE_F16:
5685
- case GGML_TYPE_COUNT:
6712
+ default:
5686
6713
  {
5687
6714
  GGML_ASSERT(false);
5688
6715
  } break;
@@ -5724,13 +6751,7 @@ static void ggml_compute_forward_sgn(
5724
6751
  {
5725
6752
  ggml_compute_forward_sgn_f32(params, src0, dst);
5726
6753
  } break;
5727
- case GGML_TYPE_Q4_0:
5728
- case GGML_TYPE_Q4_1:
5729
- case GGML_TYPE_I8:
5730
- case GGML_TYPE_I16:
5731
- case GGML_TYPE_I32:
5732
- case GGML_TYPE_F16:
5733
- case GGML_TYPE_COUNT:
6754
+ default:
5734
6755
  {
5735
6756
  GGML_ASSERT(false);
5736
6757
  } break;
@@ -5772,13 +6793,7 @@ static void ggml_compute_forward_neg(
5772
6793
  {
5773
6794
  ggml_compute_forward_neg_f32(params, src0, dst);
5774
6795
  } break;
5775
- case GGML_TYPE_Q4_0:
5776
- case GGML_TYPE_Q4_1:
5777
- case GGML_TYPE_I8:
5778
- case GGML_TYPE_I16:
5779
- case GGML_TYPE_I32:
5780
- case GGML_TYPE_F16:
5781
- case GGML_TYPE_COUNT:
6796
+ default:
5782
6797
  {
5783
6798
  GGML_ASSERT(false);
5784
6799
  } break;
@@ -5820,13 +6835,7 @@ static void ggml_compute_forward_step(
5820
6835
  {
5821
6836
  ggml_compute_forward_step_f32(params, src0, dst);
5822
6837
  } break;
5823
- case GGML_TYPE_Q4_0:
5824
- case GGML_TYPE_Q4_1:
5825
- case GGML_TYPE_I8:
5826
- case GGML_TYPE_I16:
5827
- case GGML_TYPE_I32:
5828
- case GGML_TYPE_F16:
5829
- case GGML_TYPE_COUNT:
6838
+ default:
5830
6839
  {
5831
6840
  GGML_ASSERT(false);
5832
6841
  } break;
@@ -5868,13 +6877,7 @@ static void ggml_compute_forward_relu(
5868
6877
  {
5869
6878
  ggml_compute_forward_relu_f32(params, src0, dst);
5870
6879
  } break;
5871
- case GGML_TYPE_Q4_0:
5872
- case GGML_TYPE_Q4_1:
5873
- case GGML_TYPE_I8:
5874
- case GGML_TYPE_I16:
5875
- case GGML_TYPE_I32:
5876
- case GGML_TYPE_F16:
5877
- case GGML_TYPE_COUNT:
6880
+ default:
5878
6881
  {
5879
6882
  GGML_ASSERT(false);
5880
6883
  } break;
@@ -5933,13 +6936,7 @@ static void ggml_compute_forward_gelu(
5933
6936
  {
5934
6937
  ggml_compute_forward_gelu_f32(params, src0, dst);
5935
6938
  } break;
5936
- case GGML_TYPE_Q4_0:
5937
- case GGML_TYPE_Q4_1:
5938
- case GGML_TYPE_I8:
5939
- case GGML_TYPE_I16:
5940
- case GGML_TYPE_I32:
5941
- case GGML_TYPE_F16:
5942
- case GGML_TYPE_COUNT:
6939
+ default:
5943
6940
  {
5944
6941
  GGML_ASSERT(false);
5945
6942
  } break;
@@ -6000,13 +6997,7 @@ static void ggml_compute_forward_silu(
6000
6997
  {
6001
6998
  ggml_compute_forward_silu_f32(params, src0, dst);
6002
6999
  } break;
6003
- case GGML_TYPE_Q4_0:
6004
- case GGML_TYPE_Q4_1:
6005
- case GGML_TYPE_I8:
6006
- case GGML_TYPE_I16:
6007
- case GGML_TYPE_I32:
6008
- case GGML_TYPE_F16:
6009
- case GGML_TYPE_COUNT:
7000
+ default:
6010
7001
  {
6011
7002
  GGML_ASSERT(false);
6012
7003
  } break;
@@ -6086,13 +7077,7 @@ static void ggml_compute_forward_norm(
6086
7077
  {
6087
7078
  ggml_compute_forward_norm_f32(params, src0, dst);
6088
7079
  } break;
6089
- case GGML_TYPE_Q4_0:
6090
- case GGML_TYPE_Q4_1:
6091
- case GGML_TYPE_I8:
6092
- case GGML_TYPE_I16:
6093
- case GGML_TYPE_I32:
6094
- case GGML_TYPE_F16:
6095
- case GGML_TYPE_COUNT:
7080
+ default:
6096
7081
  {
6097
7082
  GGML_ASSERT(false);
6098
7083
  } break;
@@ -6166,13 +7151,7 @@ static void ggml_compute_forward_rms_norm(
6166
7151
  {
6167
7152
  ggml_compute_forward_rms_norm_f32(params, src0, dst);
6168
7153
  } break;
6169
- case GGML_TYPE_Q4_0:
6170
- case GGML_TYPE_Q4_1:
6171
- case GGML_TYPE_I8:
6172
- case GGML_TYPE_I16:
6173
- case GGML_TYPE_I32:
6174
- case GGML_TYPE_F16:
6175
- case GGML_TYPE_COUNT:
7154
+ default:
6176
7155
  {
6177
7156
  GGML_ASSERT(false);
6178
7157
  } break;
@@ -6304,7 +7283,7 @@ static void ggml_compute_forward_mul_mat_f32(
6304
7283
  cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
6305
7284
  ne11, ne01, ne10,
6306
7285
  1.0f, y, ne10,
6307
- x, ne10,
7286
+ x, ne00,
6308
7287
  0.0f, d, ne01);
6309
7288
  }
6310
7289
  }
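The one-character change above (and the matching changes in the f16 and quantized sgemm calls below) makes the leading dimension of the x argument track src0's own row length ne00 instead of src1's ne10; for a valid mul_mat the two are equal, so this is about passing the stride that actually describes x. An annotated copy of the call, for reference:

    // Row-major, B transposed:  d[ne11 x ne01] = y[ne11 x ne10] * x[ne01 x ne00]^T
    cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
                ne11, ne01, ne10,
                1.0f, y, ne10,   // A = y, lda = row stride of y (src1)
                      x, ne00,   // B = x, ldb = row stride of x (src0)
                0.0f, d, ne01);  // C = d, ldc = row stride of d (dst)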
@@ -6476,7 +7455,7 @@ static void ggml_compute_forward_mul_mat_f16_f32(
6476
7455
  cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
6477
7456
  ne11, ne01, ne10,
6478
7457
  1.0f, y, ne10,
6479
- x, ne10,
7458
+ x, ne00,
6480
7459
  0.0f, d, ne01);
6481
7460
  }
6482
7461
  }
@@ -6564,29 +7543,6 @@ static void ggml_compute_forward_mul_mat_f16_f32(
6564
7543
  //}
6565
7544
  }
6566
7545
 
6567
- typedef void (*dequantize_row_q_t)(const void * restrict x, float * restrict y, int k);
6568
- typedef void (*quantize_row_q_t)(const float * restrict x, void * restrict y, int k);
6569
- typedef void (*vec_dot_q_t)(const int n, float * restrict s, const void * restrict x, const void * restrict y);
6570
-
6571
- typedef struct {
6572
- dequantize_row_q_t dequantize_row_q;
6573
- quantize_row_q_t quantize_row_q;
6574
- vec_dot_q_t vec_dot_q;
6575
- } quantize_fns_t;
6576
-
6577
- static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
6578
- [GGML_TYPE_Q4_0] = {
6579
- .dequantize_row_q = dequantize_row_q4_0,
6580
- .quantize_row_q = quantize_row_q4_0,
6581
- .vec_dot_q = ggml_vec_dot_q4_0,
6582
- },
6583
- [GGML_TYPE_Q4_1] = {
6584
- .dequantize_row_q = dequantize_row_q4_1,
6585
- .quantize_row_q = quantize_row_q4_1,
6586
- .vec_dot_q = ggml_vec_dot_q4_1,
6587
- },
6588
- };
6589
-
6590
7546
  static void ggml_compute_forward_mul_mat_q_f32(
6591
7547
  const struct ggml_compute_params * params,
6592
7548
  const struct ggml_tensor * src0,
@@ -6634,8 +7590,8 @@ static void ggml_compute_forward_mul_mat_q_f32(
6634
7590
  GGML_ASSERT(ne3 == ne13);
6635
7591
 
6636
7592
  const enum ggml_type type = src0->type;
6637
- quantize_row_q_t const quantize_row_q = quantize_fns[type].quantize_row_q;
6638
- vec_dot_q_t const vec_dot_q = quantize_fns[type].vec_dot_q;
7593
+ quantize_row_q_t const quantize_row_q_dot = quantize_fns[type].quantize_row_q_dot;
7594
+ vec_dot_q_t const vec_dot_q = quantize_fns[type].vec_dot_q;
6639
7595
 
6640
7596
  // we don't support permuted src0 or src1
6641
7597
  GGML_ASSERT(nb00 == (int) GGML_TYPE_SIZE[type]);
@@ -6691,7 +7647,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
6691
7647
  cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
6692
7648
  ne11, ne01, ne10,
6693
7649
  1.0f, y, ne10,
6694
- x, ne10,
7650
+ x, ne00,
6695
7651
  0.0f, d, ne01);
6696
7652
  }
6697
7653
  }
@@ -6704,12 +7660,12 @@ static void ggml_compute_forward_mul_mat_q_f32(
6704
7660
 
6705
7661
  if (params->type == GGML_TASK_INIT) {
6706
7662
  char * wdata = params->wdata;
6707
- const size_t row_size = ne10*GGML_TYPE_SIZE[type]/GGML_BLCK_SIZE[type];
7663
+ const size_t row_size = ne10*GGML_TYPE_SIZE[GGML_TYPE_Q8_0]/GGML_BLCK_SIZE[GGML_TYPE_Q8_0];
6708
7664
 
6709
7665
  for (int64_t i13 = 0; i13 < ne13; ++i13) {
6710
7666
  for (int64_t i12 = 0; i12 < ne12; ++i12) {
6711
7667
  for (int64_t i11 = 0; i11 < ne11; ++i11) {
6712
- quantize_row_q((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10);
7668
+ quantize_row_q_dot((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10);
6713
7669
  wdata += row_size;
6714
7670
  }
6715
7671
  }
@@ -6735,7 +7691,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
6735
7691
  const int ir1 = MIN(ir0 + dr, nr);
6736
7692
 
6737
7693
  void * wdata = params->wdata;
6738
- const size_t row_size = ne00*GGML_TYPE_SIZE[type]/GGML_BLCK_SIZE[type];
7694
+ const size_t row_size = ne00*GGML_TYPE_SIZE[GGML_TYPE_Q8_0]/GGML_BLCK_SIZE[GGML_TYPE_Q8_0];
6739
7695
 
6740
7696
  for (int ir = ir0; ir < ir1; ++ir) {
6741
7697
  // src0 indices
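
Both row-size computations above are now keyed to GGML_TYPE_Q8_0 because src1 rows are re-quantized into Q8_0 scratch data (via quantize_row_q_dot) before the dot products, regardless of whether src0 is Q4_0 or Q4_1. A minimal sketch of the arithmetic, assuming QK8_0 == 32 and a block_q8_0 layout of one float scale plus 32 int8 quants (36 bytes); neither constant is visible in this hunk:

    #include <stddef.h>

    /* Sketch only: scratch bytes for one src1 row of ne10 floats after it is
     * re-quantized to Q8_0. For ne10 == 4096: 4096/32 * 36 = 4608 bytes. */
    static size_t q8_0_row_size(size_t ne10) {
        const size_t qk         = 32;                  /* assumed QK8_0            */
        const size_t block_size = sizeof(float) + qk;  /* assumed sizeof(block_q8_0) */
        return ne10/qk * block_size;
    }
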
@@ -6783,6 +7739,7 @@ static void ggml_compute_forward_mul_mat(
  switch (src0->type) {
  case GGML_TYPE_Q4_0:
  case GGML_TYPE_Q4_1:
+ case GGML_TYPE_Q8_0:
  {
  ggml_compute_forward_mul_mat_q_f32(params, src0, src1, dst);
  } break;
@@ -6794,10 +7751,7 @@ static void ggml_compute_forward_mul_mat(
  {
  ggml_compute_forward_mul_mat_f32(params, src0, src1, dst);
  } break;
- case GGML_TYPE_I8:
- case GGML_TYPE_I16:
- case GGML_TYPE_I32:
- case GGML_TYPE_COUNT:
+ default:
  {
  GGML_ASSERT(false);
  } break;
@@ -6879,13 +7833,7 @@ static void ggml_compute_forward_scale(
  {
  ggml_compute_forward_scale_f32(params, src0, src1, dst);
  } break;
- case GGML_TYPE_Q4_0:
- case GGML_TYPE_Q4_1:
- case GGML_TYPE_I8:
- case GGML_TYPE_I16:
- case GGML_TYPE_I32:
- case GGML_TYPE_F16:
- case GGML_TYPE_COUNT:
+ default:
  {
  GGML_ASSERT(false);
  } break;
@@ -6901,6 +7849,15 @@ static void ggml_compute_forward_cpy(
  ggml_compute_forward_dup(params, src0, dst);
  }

+ // ggml_compute_forward_cont
+
+ static void ggml_compute_forward_cont(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ struct ggml_tensor * dst) {
+ ggml_compute_forward_dup(params, src0, dst);
+ }
+
  // ggml_compute_forward_reshape

  static void ggml_compute_forward_reshape(
@@ -7037,6 +7994,7 @@ static void ggml_compute_forward_get_rows(
  switch (src0->type) {
  case GGML_TYPE_Q4_0:
  case GGML_TYPE_Q4_1:
+ case GGML_TYPE_Q8_0:
  {
  ggml_compute_forward_get_rows_q(params, src0, src1, dst);
  } break;
@@ -7048,10 +8006,7 @@ static void ggml_compute_forward_get_rows(
  {
  ggml_compute_forward_get_rows_f32(params, src0, src1, dst);
  } break;
- case GGML_TYPE_I8:
- case GGML_TYPE_I16:
- case GGML_TYPE_I32:
- case GGML_TYPE_COUNT:
+ default:
  {
  GGML_ASSERT(false);
  } break;
@@ -7124,13 +8079,7 @@ static void ggml_compute_forward_diag_mask_inf(
  {
  ggml_compute_forward_diag_mask_inf_f32(params, src0, src1, dst);
  } break;
- case GGML_TYPE_Q4_0:
- case GGML_TYPE_Q4_1:
- case GGML_TYPE_I8:
- case GGML_TYPE_I16:
- case GGML_TYPE_I32:
- case GGML_TYPE_F16:
- case GGML_TYPE_COUNT:
+ default:
  {
  GGML_ASSERT(false);
  } break;
@@ -7218,13 +8167,7 @@ static void ggml_compute_forward_soft_max(
  {
  ggml_compute_forward_soft_max_f32(params, src0, dst);
  } break;
- case GGML_TYPE_Q4_0:
- case GGML_TYPE_Q4_1:
- case GGML_TYPE_I8:
- case GGML_TYPE_I16:
- case GGML_TYPE_I32:
- case GGML_TYPE_F16:
- case GGML_TYPE_COUNT:
+ default:
  {
  GGML_ASSERT(false);
  } break;
@@ -7279,6 +8222,8 @@ static void ggml_compute_forward_rope_f32(
  // row index used to determine which thread to use
  int ir = 0;

+ const float theta_scale = powf(10000.0, -2.0f/n_dims);
+
  for (int64_t i3 = 0; i3 < ne3; i3++) {
  for (int64_t i2 = (mode == 0 ? 0 : n_past); i2 < ne2; i2++) {
  const int p = (mode == 0 ? n_past + i2 : i2);
@@ -7286,11 +8231,13 @@ static void ggml_compute_forward_rope_f32(
  if (ir++ < ir0) continue;
  if (ir > ir1) break;

+ float theta = (float)p;
+
  for (int i0 = 0; i0 < n_dims; i0 += 2) {
- const float theta = powf(10000.0, ((float)-i0)/n_dims);
+ const float cos_theta = cosf(theta);
+ const float sin_theta = sinf(theta);

- const float cos_theta = cosf(p*theta);
- const float sin_theta = sinf(p*theta);
+ theta *= theta_scale;

  const float * const src = (float *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
  float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
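
The rewritten RoPE loop hoists the expensive powf() out of the inner dimension: theta starts at the position p and is multiplied by theta_scale = 10000^(-2/n_dims) once per pair, so at offset i0 it equals p * 10000^(-i0/n_dims), the same value the removed closed form computed for every pair (up to float rounding). A standalone sketch of that equivalence, not taken from the diff:

    #include <math.h>
    #include <stdio.h>

    /* Compare the incremental theta update with the old closed form. */
    int main(void) {
        const int   n_dims = 128;
        const int   p      = 5;   /* example token position */
        const float theta_scale = powf(10000.0f, -2.0f/n_dims);

        float theta = (float)p;
        for (int i0 = 0; i0 < n_dims; i0 += 2) {
            const float closed_form = p*powf(10000.0f, ((float)-i0)/n_dims);
            printf("i0=%3d incremental=%.6f closed-form=%.6f\n", i0, theta, closed_form);
            theta *= theta_scale;
        }
        return 0;
    }
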
@@ -7352,6 +8299,8 @@ static void ggml_compute_forward_rope_f16(
  // row index used to determine which thread to use
  int ir = 0;

+ const float theta_scale = powf(10000.0, -2.0f/n_dims);
+
  for (int64_t i3 = 0; i3 < ne3; i3++) {
  for (int64_t i2 = (mode == 0 ? 0 : n_past); i2 < ne2; i2++) {
  const int p = (mode == 0 ? n_past + i2 : i2);
@@ -7359,20 +8308,22 @@ static void ggml_compute_forward_rope_f16(
  if (ir++ < ir0) continue;
  if (ir > ir1) break;

+ float theta = (float)p;
+
  for (int i0 = 0; i0 < n_dims; i0 += 2) {
- const float theta = powf(10000.0, ((float)-i0)/n_dims);
+ const float cos_theta = cosf(theta);
+ const float sin_theta = sinf(theta);

- const float cos_theta = cosf(p*theta);
- const float sin_theta = sinf(p*theta);
+ theta *= theta_scale;

  const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
  ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);

- const float x0 = ggml_fp16_to_fp32(src[0]);
- const float x1 = ggml_fp16_to_fp32(src[1]);
+ const float x0 = GGML_FP16_TO_FP32(src[0]);
+ const float x1 = GGML_FP16_TO_FP32(src[1]);

- dst_data[0] = ggml_fp32_to_fp16(x0*cos_theta - x1*sin_theta);
- dst_data[1] = ggml_fp32_to_fp16(x0*sin_theta + x1*cos_theta);
+ dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
+ dst_data[1] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
  }
  }
  }
@@ -7393,12 +8344,7 @@ static void ggml_compute_forward_rope(
  {
  ggml_compute_forward_rope_f32(params, src0, src1, dst);
  } break;
- case GGML_TYPE_Q4_0:
- case GGML_TYPE_Q4_1:
- case GGML_TYPE_I8:
- case GGML_TYPE_I16:
- case GGML_TYPE_I32:
- case GGML_TYPE_COUNT:
+ default:
  {
  GGML_ASSERT(false);
  } break;
@@ -7661,12 +8607,7 @@ static void ggml_compute_forward_conv_1d_1s(
  {
  ggml_compute_forward_conv_1d_1s_f32(params, src0, src1, dst);
  } break;
- case GGML_TYPE_Q4_0:
- case GGML_TYPE_Q4_1:
- case GGML_TYPE_I8:
- case GGML_TYPE_I16:
- case GGML_TYPE_I32:
- case GGML_TYPE_COUNT:
+ default:
  {
  GGML_ASSERT(false);
  } break;
@@ -7929,12 +8870,7 @@ static void ggml_compute_forward_conv_1d_2s(
  {
  ggml_compute_forward_conv_1d_2s_f32(params, src0, src1, dst);
  } break;
- case GGML_TYPE_Q4_0:
- case GGML_TYPE_Q4_1:
- case GGML_TYPE_I8:
- case GGML_TYPE_I16:
- case GGML_TYPE_I32:
- case GGML_TYPE_COUNT:
+ default:
  {
  GGML_ASSERT(false);
  } break;
@@ -8414,12 +9350,7 @@ static void ggml_compute_forward_flash_attn(
  {
  ggml_compute_forward_flash_attn_f32(params, q, k, v, masked, dst);
  } break;
- case GGML_TYPE_Q4_0:
- case GGML_TYPE_Q4_1:
- case GGML_TYPE_I8:
- case GGML_TYPE_I16:
- case GGML_TYPE_I32:
- case GGML_TYPE_COUNT:
+ default:
  {
  GGML_ASSERT(false);
  } break;
@@ -8625,12 +9556,100 @@ static void ggml_compute_forward_flash_ff(
  {
  GGML_ASSERT(false); // TODO
  } break;
- case GGML_TYPE_Q4_0:
- case GGML_TYPE_Q4_1:
- case GGML_TYPE_I8:
- case GGML_TYPE_I16:
- case GGML_TYPE_I32:
- case GGML_TYPE_COUNT:
+ default:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+ }
+
+ // ggml_compute_forward_map_unary
+
+ static void ggml_compute_forward_map_unary_f32(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ struct ggml_tensor * dst,
+ const ggml_unary_op_f32_t fun) {
+ GGML_ASSERT(ggml_are_same_shape(src0, dst));
+
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+ return;
+ }
+
+ const int n = ggml_nrows(src0);
+ const int nc = src0->ne[0];
+
+ assert( dst->nb[0] == sizeof(float));
+ assert(src0->nb[0] == sizeof(float));
+
+ for (int i = 0; i < n; i++) {
+ fun(nc,
+ (float *) ((char *) dst->data + i*( dst->nb[1])),
+ (float *) ((char *) src0->data + i*(src0->nb[1])));
+ }
+ }
+
+
+ static void ggml_compute_forward_map_unary(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ struct ggml_tensor * dst,
+ const ggml_unary_op_f32_t fun) {
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_map_unary_f32(params, src0, dst, fun);
+ } break;
+ default:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+ }
+
+ // ggml_compute_forward_map_binary
+
+ static void ggml_compute_forward_map_binary_f32(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ const struct ggml_tensor * src1,
+ struct ggml_tensor * dst,
+ const ggml_binary_op_f32_t fun) {
+ assert(params->ith == 0);
+ assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+ return;
+ }
+
+ const int n = ggml_nrows(src0);
+ const int nc = src0->ne[0];
+
+ assert( dst->nb[0] == sizeof(float));
+ assert(src0->nb[0] == sizeof(float));
+ assert(src1->nb[0] == sizeof(float));
+
+ for (int i = 0; i < n; i++) {
+ fun(nc,
+ (float *) ((char *) dst->data + i*( dst->nb[1])),
+ (float *) ((char *) src0->data + i*(src0->nb[1])),
+ (float *) ((char *) src1->data + i*(src1->nb[1])));
+ }
+ }
+
+
+ static void ggml_compute_forward_map_binary(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ const struct ggml_tensor * src1,
+ struct ggml_tensor * dst,
+ const ggml_binary_op_f32_t fun) {
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_map_binary_f32(params, src0, src1, dst, fun);
+ } break;
+ default:
  {
  GGML_ASSERT(false);
  } break;
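
The new map kernels apply a caller-supplied float function row by row; as the dispatch hunk further down shows, the function pointer reaches the worker through tensor->opt[0]->data. A usage sketch, assuming this version also exports builder wrappers such as ggml_map_unary_f32()/ggml_map_binary_f32() (they are not shown in this diff):

    /* Custom element-wise op routed through GGML_OP_MAP_UNARY. The callback
     * signature matches how ggml_compute_forward_map_unary_f32() invokes fun:
     * (row length, dst row, src row). */
    static void my_square(const int n, float * dst, const float * src) {
        for (int i = 0; i < n; ++i) {
            dst[i] = src[i]*src[i];
        }
    }

    /* During graph construction (hypothetical wrapper, not shown in this diff): */
    /* struct ggml_tensor * y = ggml_map_unary_f32(ctx, x, my_square);          */
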
@@ -8731,6 +9750,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
  {
  ggml_compute_forward_cpy(params, tensor->src0, tensor);
  } break;
+ case GGML_OP_CONT:
+ {
+ ggml_compute_forward_cont(params, tensor->src0, tensor);
+ } break;
  case GGML_OP_RESHAPE:
  {
  ggml_compute_forward_reshape(params, tensor->src0, tensor);
@@ -8782,6 +9805,18 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
  {
  ggml_compute_forward_flash_ff(params, tensor->src0, tensor->src1, tensor->opt[0], tensor->opt[1], tensor->opt[2], tensor);
  } break;
+ case GGML_OP_MAP_UNARY:
+ {
+ const ggml_unary_op_f32_t fun = *((ggml_unary_op_f32_t *)tensor->opt[0]->data);
+ ggml_compute_forward_map_unary(params, tensor->src0, tensor, fun);
+ }
+ break;
+ case GGML_OP_MAP_BINARY:
+ {
+ const ggml_binary_op_f32_t fun = *((ggml_binary_op_f32_t *)tensor->opt[0]->data);
+ ggml_compute_forward_map_binary(params, tensor->src0, tensor->src1, tensor, fun);
+ }
+ break;
  case GGML_OP_NONE:
  {
  // nop
@@ -8975,8 +10010,9 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
  src1->grad =
  ggml_add_impl(ctx,
  src1->grad,
- // TODO: fix transpose, the node will break the graph connections
- ggml_mul_mat(ctx, ggml_transpose(ctx, src0), tensor->grad),
+ ggml_mul_mat(ctx,
+ ggml_cont(ctx, ggml_transpose(ctx, src0)),
+ tensor->grad),
  inplace);
  }
  } break;
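
The backward rule for src1 now wraps the transposed src0 in ggml_cont() before the matmul: ggml_transpose() only permutes strides, and the removed TODO noted that feeding the raw view into ggml_mul_mat() broke the graph connections. GGML_OP_CONT (added earlier in this diff) materializes the view by running the dup kernel. The same pattern in user code, as a sketch under the usual ggml graph-building flow:

    /* ggml_transpose() returns a strided view; ggml_cont() copies it into a
     * contiguous tensor so ggml_mul_mat() sees contiguous rows. */
    struct ggml_tensor * a_t  = ggml_cont(ctx, ggml_transpose(ctx, a));
    struct ggml_tensor * grad = ggml_mul_mat(ctx, a_t, dy);
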
@@ -8988,6 +10024,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
  {
  GGML_ASSERT(false); // TODO: not implemented
  } break;
+ case GGML_OP_CONT:
+ {
+ GGML_ASSERT(false); // TODO: not implemented
+ } break;
  case GGML_OP_RESHAPE:
  {
  GGML_ASSERT(false); // TODO: not implemented
@@ -9036,6 +10076,11 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
  {
  GGML_ASSERT(false); // not supported
  } break;
+ case GGML_OP_MAP_UNARY:
+ case GGML_OP_MAP_BINARY:
+ {
+ GGML_ASSERT(false); // not supported
+ } break;
  case GGML_OP_NONE:
  {
  // nop
@@ -9126,7 +10171,7 @@ struct ggml_cgraph ggml_build_forward(struct ggml_tensor * tensor) {
  struct ggml_cgraph result = {
  /*.n_nodes =*/ 0,
  /*.n_leafs =*/ 0,
- /*.n_threads =*/ 0,
+ /*.n_threads =*/ GGML_DEFAULT_N_THREADS,
  /*.work_size =*/ 0,
  /*.work =*/ NULL,
  /*.nodes =*/ { NULL },
@@ -9354,13 +10399,29 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
  struct ggml_tensor * node = cgraph->nodes[i];

  switch (node->op) {
+ case GGML_OP_CPY:
  case GGML_OP_DUP:
  {
  node->n_tasks = 1;
+
+ size_t cur = 0;
+ if (node->type == GGML_TYPE_Q4_0 || node->type == GGML_TYPE_Q4_1) {
+ cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->ne[0];
+ }
+
+ work_size = MAX(work_size, cur);
  } break;
  case GGML_OP_ADD:
  {
  node->n_tasks = n_threads;
+
+ size_t cur = 0;
+
+ if (node->src0->type == GGML_TYPE_Q4_0 || node->src0->type == GGML_TYPE_Q4_1) {
+ cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->src0->ne[0] * n_threads;
+ }
+
+ work_size = MAX(work_size, cur);
  } break;
  case GGML_OP_SUB:
  case GGML_OP_MUL:
@@ -9429,7 +10490,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
  } else
  #endif
  {
- cur = GGML_TYPE_SIZE[node->src0->type]*ggml_nelements(node->src1)/GGML_BLCK_SIZE[node->src0->type];
+ cur = GGML_TYPE_SIZE[GGML_TYPE_Q8_0]*ggml_nelements(node->src1)/GGML_BLCK_SIZE[GGML_TYPE_Q8_0];
  }
  } else {
  GGML_ASSERT(false);
@@ -9441,7 +10502,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
  {
  node->n_tasks = n_threads;
  } break;
- case GGML_OP_CPY:
+ case GGML_OP_CONT:
  case GGML_OP_RESHAPE:
  case GGML_OP_VIEW:
  case GGML_OP_PERMUTE:
@@ -9527,6 +10588,11 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)

  work_size = MAX(work_size, cur);
  } break;
+ case GGML_OP_MAP_UNARY:
+ case GGML_OP_MAP_BINARY:
+ {
+ node->n_tasks = 1;
+ } break;
  case GGML_OP_NONE:
  {
  node->n_tasks = 1;
@@ -9745,8 +10811,8 @@ void ggml_graph_print(const struct ggml_cgraph * cgraph) {

  GGML_PRINT("=== GRAPH ===\n");

- GGML_PRINT_DEBUG("n_threads = %d\n", cgraph->n_threads);
- GGML_PRINT_DEBUG("total work size = %zu bytes\n",cgraph->work_size);
+ GGML_PRINT_DEBUG("n_threads = %d\n", cgraph->n_threads);
+ GGML_PRINT_DEBUG("total work size = %zu bytes\n", cgraph->work_size);

  GGML_PRINT("n_nodes = %d\n", cgraph->n_nodes);
  for (int i = 0; i < cgraph->n_nodes; i++) {
@@ -10598,16 +11664,16 @@ enum ggml_opt_result ggml_opt(
  ////////////////////////////////////////////////////////////////////////////////

  size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t * hist) {
- assert(k % QK == 0);
- const int nb = k / QK;
+ assert(k % QK4_0 == 0);
+ const int nb = k / QK4_0;

  for (int j = 0; j < n; j += k) {
- block_q4_0 * restrict y = (block_q4_0 *)dst + j/QK;
+ block_q4_0 * restrict y = (block_q4_0 *)dst + j/QK4_0;

  quantize_row_q4_0_reference(src + j, y, k);

  for (int i = 0; i < nb; i++) {
- for (int l = 0; l < QK; l += 2) {
+ for (int l = 0; l < QK4_0; l += 2) {
  const uint8_t vi0 = y[i].qs[l/2] & 0xF;
  const uint8_t vi1 = y[i].qs[l/2] >> 4;

@@ -10617,20 +11683,20 @@ size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t *
  }
  }

- return (n/QK*sizeof(block_q4_0));
+ return (n/QK4_0*sizeof(block_q4_0));
  }

  size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t * hist) {
- assert(k % QK == 0);
- const int nb = k / QK;
+ assert(k % QK4_1 == 0);
+ const int nb = k / QK4_1;

  for (int j = 0; j < n; j += k) {
- block_q4_1 * restrict y = (block_q4_1 *)dst + j/QK;
+ block_q4_1 * restrict y = (block_q4_1 *)dst + j/QK4_1;

  quantize_row_q4_1_reference(src + j, y, k);

  for (int i = 0; i < nb; i++) {
- for (int l = 0; l < QK; l += 2) {
+ for (int l = 0; l < QK4_1; l += 2) {
  const uint8_t vi0 = y[i].qs[l/2] & 0xF;
  const uint8_t vi1 = y[i].qs[l/2] >> 4;

@@ -10640,7 +11706,7 @@ size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t *
  }
  }

- return (n/QK*sizeof(block_q4_1));
+ return (n/QK4_1*sizeof(block_q4_1));
  }

  ////////////////////////////////////////////////////////////////////////////////
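
ggml_quantize_q4_0()/ggml_quantize_q4_1() now take their block math from the per-format constants QK4_0/QK4_1 instead of the old shared QK, and after quantizing they walk the packed nibbles to fill the caller's histogram. A usage sketch, assuming QK4_0 == 32 and a 16-bucket histogram (one bucket per 4-bit value); the bucket count is implied by the nibble loop above but not spelled out in this diff:

    #include <stdint.h>
    #include "ggml.h"

    /* Quantize a flat row-major weight matrix with the API shown above:
     * n is the total element count, k the row length, and the return value
     * is the number of block_q4_0 bytes written to 'out' (which the caller
     * must have sized to at least nrows*ncols/QK4_0 * sizeof(block_q4_0)). */
    static size_t quantize_weights_q4_0(const float * weights, int nrows, int ncols, void * out) {
        int64_t hist[16] = {0};   /* assumed 1 << 4 buckets */
        return ggml_quantize_q4_0(weights, out, nrows*ncols, ncols, hist);
    }
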
@@ -10669,6 +11735,22 @@ int ggml_cpu_has_avx512(void) {
  #endif
  }

+ int ggml_cpu_has_avx512_vbmi(void) {
+ #if defined(__AVX512VBMI__)
+ return 1;
+ #else
+ return 0;
+ #endif
+ }
+
+ int ggml_cpu_has_avx512_vnni(void) {
+ #if defined(__AVX512VNNI__)
+ return 1;
+ #else
+ return 0;
+ #endif
+ }
+
  int ggml_cpu_has_fma(void) {
  #if defined(__FMA__)
  return 1;
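
Like the neighbouring helpers, the two new probes report whether the binary was compiled with AVX512-VBMI/VNNI support; they are compile-time checks on predefined macros, not runtime CPUID queries. A small sketch that logs the flags at startup, assuming the declarations live in ggml.h alongside the existing ggml_cpu_has_* functions:

    #include <stdio.h>
    #include "ggml.h"

    /* Print the compile-time SIMD feature flags exposed above. */
    int main(void) {
        printf("AVX512      : %d\n", ggml_cpu_has_avx512());
        printf("AVX512 VBMI : %d\n", ggml_cpu_has_avx512_vbmi());
        printf("AVX512 VNNI : %d\n", ggml_cpu_has_avx512_vnni());
        printf("FMA         : %d\n", ggml_cpu_has_fma());
        return 0;
    }
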