@fugood/llama.node 1.1.5 → 1.1.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35)
  1. package/package.json +14 -14
  2. package/scripts/llama.cpp.patch +17 -13
  3. package/src/LlamaCompletionWorker.cpp +2 -0
  4. package/src/llama.cpp/common/arg.cpp +28 -11
  5. package/src/llama.cpp/common/chat.cpp +46 -2
  6. package/src/llama.cpp/common/chat.h +7 -2
  7. package/src/llama.cpp/common/common.h +3 -2
  8. package/src/llama.cpp/ggml/CMakeLists.txt +3 -2
  9. package/src/llama.cpp/ggml/include/ggml.h +37 -1
  10. package/src/llama.cpp/ggml/src/CMakeLists.txt +12 -1
  11. package/src/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +61 -0
  12. package/src/llama.cpp/ggml/src/ggml-cpu/arch/x86/quants.c +96 -8
  13. package/src/llama.cpp/ggml/src/ggml-cpu/arch-fallback.h +6 -0
  14. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +14 -1
  15. package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +207 -9
  16. package/src/llama.cpp/ggml/src/ggml-cpu/ops.h +2 -7
  17. package/src/llama.cpp/ggml/src/ggml-cpu/quants.c +35 -0
  18. package/src/llama.cpp/ggml/src/ggml-cpu/quants.h +8 -0
  19. package/src/llama.cpp/ggml/src/ggml-cpu/vec.h +19 -4
  20. package/src/llama.cpp/include/llama.h +1 -0
  21. package/src/llama.cpp/src/llama-arch.cpp +65 -0
  22. package/src/llama.cpp/src/llama-arch.h +10 -0
  23. package/src/llama.cpp/src/llama-chat.cpp +13 -0
  24. package/src/llama.cpp/src/llama-chat.h +1 -0
  25. package/src/llama.cpp/src/llama-context.cpp +8 -8
  26. package/src/llama.cpp/src/llama-graph.cpp +118 -9
  27. package/src/llama.cpp/src/llama-graph.h +38 -0
  28. package/src/llama.cpp/src/llama-hparams.h +5 -3
  29. package/src/llama.cpp/src/llama-kv-cache-unified.cpp +4 -0
  30. package/src/llama.cpp/src/llama-model-loader.cpp +1 -0
  31. package/src/llama.cpp/src/llama-model-loader.h +3 -2
  32. package/src/llama.cpp/src/llama-model.cpp +499 -4
  33. package/src/llama.cpp/src/llama-model.h +24 -4
  34. package/src/llama.cpp/src/llama-quant.cpp +37 -1
  35. package/src/llama.cpp/src/llama-vocab.cpp +42 -0
@@ -66,6 +66,12 @@ static inline int hsum_i32_4(const __m128i a) {
 }
 
 #if defined(__AVX2__) || defined(__AVX512F__)
+static inline __m256i mul_add_epi8(const __m256i x, const __m256i y) {
+    const __m256i ax = _mm256_sign_epi8(x, x);
+    const __m256i sy = _mm256_sign_epi8(y, x);
+    return _mm256_maddubs_epi16(ax, sy);
+}
+
 // spread 32 bits to 32 bytes { 0x00, 0xFF }
 static inline __m256i bytes_from_bits_32(const uint8_t * x) {
     uint32_t x32;
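
Note (not from the package sources): mul_add_epi8 above exists because _mm256_maddubs_epi16 multiplies unsigned bytes by signed bytes. _mm256_sign_epi8(x, x) yields |x| (zero where x is zero) and _mm256_sign_epi8(y, x) copies x's sign onto y, so each byte product is still x*y. A scalar sketch of what one 16-bit output lane ends up holding (illustrative name only):

#include <cstdint>

// Each 16-bit lane of mul_add_epi8 is the sum of two adjacent byte products.
// Since |x| * (y with x's sign) == x * y, the lane is simply:
static inline int16_t mul_add_epi8_lane_sketch(int8_t x0, int8_t y0, int8_t x1, int8_t y1) {
    return (int16_t) (x0 * y0 + x1 * y1);
}
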
@@ -261,6 +267,11 @@ static inline __m256 quad_fp16_delta_float(const float x0, const float y0, const
     return _mm256_set_m128(_mm_set1_ps(GGML_CPU_FP16_TO_FP32(x1) * GGML_CPU_FP16_TO_FP32(y1)),
                            _mm_set1_ps(GGML_CPU_FP16_TO_FP32(x0) * GGML_CPU_FP16_TO_FP32(y0)));
 }
+
+static inline __m256 quad_mx_delta_float(const int8_t x0, const float y0, const int8_t x1, const float y1) {
+    return _mm256_set_m128(_mm_set1_ps(GGML_E8M0_TO_FP32_HALF(x1) * GGML_CPU_FP16_TO_FP32(y1)),
+                           _mm_set1_ps(GGML_E8M0_TO_FP32_HALF(x0) * GGML_CPU_FP16_TO_FP32(y0)));
+}
 #endif
 #elif defined(__SSSE3__)
 // horizontally add 4x4 floats
@@ -746,6 +757,91 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const voi
 #endif
 }
 
+void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+    assert(n % QK_MXFP4 == 0);
+    static_assert(QK_MXFP4 == QK8_0, "QK_MXFP4 and QK8_0 must be the same");
+
+    const block_mxfp4 * GGML_RESTRICT x = vx;
+    const block_q8_0 * GGML_RESTRICT y = vy;
+
+    const int nb = n / QK_MXFP4;
+
+    int ib = 0;
+    float sumf = 0;
+
+#if defined __AVX2__
+
+    const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_mxfp4);
+    const __m128i m4b = _mm_set1_epi8(0x0f);
+    const __m256i mone = _mm256_set1_epi16(1);
+
+    __m256 accum1 = _mm256_setzero_ps();
+    __m256 accum2 = _mm256_setzero_ps();
+    for (; ib + 1 < nb; ib += 2) {
+        const __m128i q4bits_1 = _mm_loadu_si128((const __m128i*)x[ib + 0].qs);
+        const __m128i q4bits_2 = _mm_loadu_si128((const __m128i*)x[ib + 1].qs);
+        const __m256i q8b_1 = _mm256_loadu_si256((const __m256i *)y[ib + 0].qs);
+        const __m256i q8b_2 = _mm256_loadu_si256((const __m256i *)y[ib + 1].qs);
+        const __m256i q4b_1 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b)),
+                                              _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b)));
+        const __m256i q4b_2 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b)),
+                                              _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b)));
+        const __m256i p16_1 = mul_add_epi8(q4b_1, q8b_1);
+        const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2);
+        const __m256i p_1 = _mm256_madd_epi16(p16_1, mone);
+        const __m256i p_2 = _mm256_madd_epi16(p16_2, mone);
+        accum1 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y[ib + 0].d)*GGML_E8M0_TO_FP32_HALF(x[ib + 0].e)),
+                                 _mm256_cvtepi32_ps(p_1), accum1);
+        accum2 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y[ib + 1].d)*GGML_E8M0_TO_FP32_HALF(x[ib + 1].e)),
+                                 _mm256_cvtepi32_ps(p_2), accum2);
+    }
+
+    sumf = hsum_float_8(_mm256_add_ps(accum1, accum2));
+
+#elif defined __AVX__
+    const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_mxfp4);
+    const __m128i m4b = _mm_set1_epi8(0x0f);
+
+    __m256 accum = _mm256_setzero_ps();
+    for (; ib + 1 < nb; ib += 2) {
+        const __m128i q4bits_1 = _mm_loadu_si128((const __m128i *)x[ib + 0].qs);
+        const __m128i q4bits_2 = _mm_loadu_si128((const __m128i *)x[ib + 1].qs);
+        const __m128i q8b_1_0 = _mm_loadu_si128((const __m128i *)y[ib + 0].qs);
+        const __m128i q8b_1_1 = _mm_loadu_si128((const __m128i *)y[ib + 0].qs + 1);
+        const __m128i q8b_2_0 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs);
+        const __m128i q8b_2_1 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs + 1);
+
+        const __m128i q4b_1_0 = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b));
+        const __m128i q4b_1_1 = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b));
+        const __m128i q4b_2_0 = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b));
+        const __m128i q4b_2_1 = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b));
+
+        const __m256 p = mul_sum_i8_quad_float(q4b_1_0, q4b_1_1, q4b_2_0, q4b_2_1, q8b_1_0, q8b_1_1, q8b_2_0, q8b_2_1);
+        const __m256 deltas = quad_mx_delta_float(x[ib].e, y[ib].d, x[ib + 1].e, y[ib + 1].d);
+        accum = _mm256_add_ps(_mm256_mul_ps(deltas, p), accum);
+    }
+
+    sumf = hsum_float_8(accum);
+
+#endif
+    for (; ib < nb; ++ib) {
+        const float d = GGML_CPU_FP16_TO_FP32(y[ib].d)*GGML_E8M0_TO_FP32_HALF(x[ib].e);
+        int sumi1 = 0;
+        int sumi2 = 0;
+        for (int j = 0; j < QK_MXFP4/2; ++j) {
+            sumi1 += y[ib].qs[j + 0] * kvalues_mxfp4[x[ib].qs[j] & 0xf];
+            sumi2 += y[ib].qs[j + QK_MXFP4/2] * kvalues_mxfp4[x[ib].qs[j] >> 4];
+        }
+        sumf += d * (sumi1 + sumi2);
+    }
+    *s = sumf;
+}
+
 void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
     const int qk = QK8_0;
     const int nb = n / qk;
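
Note (not from the package sources): QK_MXFP4 is asserted equal to QK8_0, i.e. 32 weights per block, stored as 16 packed nibbles (qs) that index the 16-entry kvalues_mxfp4 table, plus one shared E8M0 exponent byte (e). The sketch below only illustrates that layout and the scalar decode used in the tail loop above; the struct name, field order, and helper are hypothetical, not the ggml definitions:

#include <cstdint>

struct mxfp4_block_sketch {
    uint8_t e;           // shared block scale, decoded via GGML_E8M0_TO_FP32_HALF in the real code
    uint8_t qs[32 / 2];  // 32 four-bit indices into a 16-entry value table
};

// Dequantize one block, given the already-decoded block scale and the value table.
static void mxfp4_dequant_block_sketch(const mxfp4_block_sketch & b, float scale,
                                       const int8_t values[16], float out[32]) {
    for (int j = 0; j < 16; ++j) {
        out[j]      = scale * values[b.qs[j] & 0xf];  // low nibbles: first half of the block
        out[j + 16] = scale * values[b.qs[j] >> 4];   // high nibbles: second half of the block
    }
}
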
@@ -3206,14 +3302,6 @@ void ggml_vec_dot_iq3_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
 #endif
 }
 
-#if defined(__AVX2__)
-static inline __m256i mul_add_epi8(const __m256i x, const __m256i y) {
-    const __m256i ax = _mm256_sign_epi8(x, x);
-    const __m256i sy = _mm256_sign_epi8(y, x);
-    return _mm256_maddubs_epi16(ax, sy);
-}
-#endif
-
 void ggml_vec_dot_iq1_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
     assert(n % QK_K == 0);
     assert(nrc == 1);
@@ -13,6 +13,7 @@
 #define ggml_vec_dot_q5_0_q8_0_generic ggml_vec_dot_q5_0_q8_0
 #define ggml_vec_dot_q5_1_q8_1_generic ggml_vec_dot_q5_1_q8_1
 #define ggml_vec_dot_q8_0_q8_0_generic ggml_vec_dot_q8_0_q8_0
+#define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0
 #define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K
 #define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K
 #define ggml_vec_dot_q2_K_q8_K_generic ggml_vec_dot_q2_K_q8_K
@@ -68,6 +69,7 @@
 #define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K
 #define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K
 #define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K
+#define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0
 // repack.cpp
 #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
 #define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
@@ -90,6 +92,7 @@
 #define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K
 #define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K
 #define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K
+#define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0
 // repack.cpp
 #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
 #define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
@@ -120,6 +123,7 @@
 #define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K
 #define ggml_vec_dot_iq4_nl_q8_0_generic ggml_vec_dot_iq4_nl_q8_0
 #define ggml_vec_dot_iq4_xs_q8_K_generic ggml_vec_dot_iq4_xs_q8_K
+#define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0
 // repack.cpp
 #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
 #define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
@@ -149,6 +153,7 @@
 #define ggml_vec_dot_iq3_s_q8_K_generic ggml_vec_dot_iq3_s_q8_K
 #define ggml_vec_dot_iq1_s_q8_K_generic ggml_vec_dot_iq1_s_q8_K
 #define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K
+#define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0
 // repack.cpp
 #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
 #define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
@@ -179,6 +184,7 @@
 #define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K
 #define ggml_vec_dot_iq4_nl_q8_0_generic ggml_vec_dot_iq4_nl_q8_0
 #define ggml_vec_dot_iq4_xs_q8_K_generic ggml_vec_dot_iq4_xs_q8_K
+#define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0
 // repack.cpp
 #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
 #define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
@@ -253,6 +253,12 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = {
         .vec_dot_type = GGML_TYPE_Q8_1,
         .nrows = 1,
     },
+    [GGML_TYPE_MXFP4] = {
+        .from_float = quantize_row_mxfp4,
+        .vec_dot = ggml_vec_dot_mxfp4_q8_0,
+        .vec_dot_type = GGML_TYPE_Q8_0,
+        .nrows = 1,
+    },
     [GGML_TYPE_Q2_K] = {
         .from_float = quantize_row_q2_K,
         .vec_dot = ggml_vec_dot_q2_K_q8_K,
@@ -1670,6 +1676,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_add(params, tensor);
             } break;
+        case GGML_OP_ADD_ID:
+            {
+                ggml_compute_forward_add_id(params, tensor);
+            } break;
         case GGML_OP_ADD1:
             {
                 ggml_compute_forward_add1(params, tensor);
@@ -1924,7 +1934,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             } break;
         case GGML_OP_FLASH_ATTN_EXT:
             {
-                ggml_compute_forward_flash_attn_ext(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor);
+                ggml_compute_forward_flash_attn_ext(params, tensor);
             } break;
        case GGML_OP_FLASH_ATTN_BACK:
            {
@@ -2111,6 +2121,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
        case GGML_OP_DUP:
        case GGML_OP_CONT:
        case GGML_OP_ADD:
+       case GGML_OP_ADD_ID:
        case GGML_OP_ADD1:
        case GGML_OP_ACC:
            {
@@ -2172,6 +2183,7 @@
        case GGML_GLU_OP_REGLU:
        case GGML_GLU_OP_GEGLU:
        case GGML_GLU_OP_SWIGLU:
+       case GGML_GLU_OP_SWIGLU_OAI:
        case GGML_GLU_OP_GEGLU_ERF:
        case GGML_GLU_OP_GEGLU_QUICK:
            {
@@ -2673,6 +2685,7 @@ struct ggml_cplan ggml_graph_plan(
                }
            } break;
        case GGML_OP_ADD:
+       case GGML_OP_ADD_ID:
        case GGML_OP_ADD1:
            {
                if (ggml_is_quantized(node->src[0]->type)) {
@@ -8,6 +8,7 @@
 #include "vec.h"
 
 #include <float.h>
+#include <algorithm>
 
 // ggml_compute_forward_dup
 
@@ -1283,6 +1284,7 @@ void ggml_compute_forward_add(
        case GGML_TYPE_Q5_0:
        case GGML_TYPE_Q5_1:
        case GGML_TYPE_Q8_0:
+       case GGML_TYPE_MXFP4:
        case GGML_TYPE_Q2_K:
        case GGML_TYPE_Q3_K:
        case GGML_TYPE_Q4_K:
@@ -1309,6 +1311,77 @@ void ggml_compute_forward_add(
     }
 }
 
+// ggml_compute_forward_add_id
+
+static void ggml_compute_forward_add_id_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    const ggml_tensor * src2 = dst->src[2];
+
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(src2->type == GGML_TYPE_I32);
+
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+    GGML_ASSERT(src1->nb[0] == sizeof(float));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nr = ggml_nrows(src0);
+
+    GGML_TENSOR_TERNARY_OP_LOCALS
+
+    GGML_ASSERT( nb0 == sizeof(float));
+    GGML_ASSERT(nb10 == sizeof(float));
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        // src0 indices
+        const int i3 = ir/(ne2*ne1);
+        const int i2 = (ir - i3*ne2*ne1)/ne1;
+        const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+
+        // src1 indices
+        const int i11 = *(int32_t *) ((char *) src2->data + i1*nb20 + i2*nb21);
+
+        GGML_ASSERT(i11 >= 0 && i11 < ne11);
+
+        ggml_vec_add_f32(ne0,
+                (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 ),
+                (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01),
+                (float *) ((char *) src1->data + i11*nb11));
+    }
+}
+
+void ggml_compute_forward_add_id(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_add_id_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("unsupported type for ggml_compute_forward_add_id: %s", ggml_type_name(src0->type));
+            }
+    }
+}
+
 // ggml_compute_forward_add1
 
 static void ggml_compute_forward_add1_f32(
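
Note (not from the package sources): the new GGML_OP_ADD_ID forward pass above adds, to each row of src0, one row of src1 selected per output row by an int32 id tensor (src2). A minimal standalone sketch over plain arrays, ignoring the 4-D strides and thread partitioning (names are illustrative):

#include <cstdint>

// dst[r] = src0[r] + src1[ids[r]] for every row r: a per-row gathered add.
void add_id_sketch(float * dst, const float * src0, const float * src1,
                   const int32_t * ids, int n_rows, int n_cols) {
    for (int r = 0; r < n_rows; ++r) {
        const int32_t id = ids[r];  // which src1 row to add to row r
        for (int c = 0; c < n_cols; ++c) {
            dst[r * n_cols + c] = src0[r * n_cols + c] + src1[id * n_cols + c];
        }
    }
}
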
@@ -1660,6 +1733,7 @@ void ggml_compute_forward_add1(
        case GGML_TYPE_Q5_1:
        case GGML_TYPE_Q8_0:
        case GGML_TYPE_Q8_1:
+       case GGML_TYPE_MXFP4:
        case GGML_TYPE_Q2_K:
        case GGML_TYPE_Q3_K:
        case GGML_TYPE_Q4_K:
@@ -1787,6 +1861,7 @@ void ggml_compute_forward_acc(
        case GGML_TYPE_Q5_1:
        case GGML_TYPE_Q8_0:
        case GGML_TYPE_Q8_1:
+       case GGML_TYPE_MXFP4:
        case GGML_TYPE_Q2_K:
        case GGML_TYPE_Q3_K:
        case GGML_TYPE_Q4_K:
@@ -3614,6 +3689,93 @@ static void ggml_compute_forward_swiglu(
     }
 }
 
+// ggml_compute_forward_swiglu_oai
+
+static void ggml_compute_forward_swiglu_oai_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    char * src0_d = (char *) src0->data;
+    char * src1_d = (char *) (src1 ? src1->data : src0->data);
+    const size_t src0_o = src0->nb[1];
+    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
+
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+    GGML_ASSERT(ggml_is_contiguous_1(dst));
+
+    if (src1) {
+        GGML_ASSERT(ggml_is_contiguous_1(src1));
+        GGML_ASSERT(src0->type == src1->type);
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
+    const int nr = ggml_nrows(src0);
+
+    GGML_ASSERT(dst->ne[0] == nc);
+    GGML_ASSERT(ggml_nrows(dst) == nr);
+
+    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
+    const float alpha = ggml_get_op_params_f32(dst, 2);
+    const float limit = ggml_get_op_params_f32(dst, 3);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        float * src0_p = (float *) (src0_d + i1*src0_o);
+        float * src1_p = (float *) (src1_d + i1*src1_o);
+        float * dst_p = (float *) ((char *) dst->data + i1*(dst->nb[1]));
+
+        if (!src1) {
+            src0_p += swapped ? nc : 0;
+            src1_p += swapped ? 0 : nc;
+        }
+
+        for (int k = 0; k < nc; k++) {
+            const float x = std::min(src0_p[k], limit);
+            const float y = std::clamp(src1_p[k], -limit, limit);
+            const float out_glu = x / (1.f + expf(alpha * (-x)));
+            dst_p[k] = out_glu * (y + 1.f);
+        }
+
+#ifndef NDEBUG
+        for (int k = 0; k < nc; k++) {
+            const float x = dst_p[k];
+            GGML_UNUSED(x);
+            assert(!isnan(x));
+            assert(!isinf(x));
+        }
+#endif
+    }
+}
+
+static void ggml_compute_forward_swiglu_oai(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_swiglu_oai_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
 // ggml_compute_forward_geglu_erf
 
 static void ggml_compute_forward_geglu_erf_f32(
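
Note (not from the package sources): the per-element math of the SWIGLU_OAI variant added above, written out as a standalone scalar helper. alpha and limit correspond to the op params read via ggml_get_op_params_f32; the function and argument names are illustrative:

#include <algorithm>
#include <cmath>

inline float swiglu_oai_sketch(float gate, float up, float alpha, float limit) {
    const float x = std::min(gate, limit);                   // clamp the gated branch from above
    const float y = std::clamp(up, -limit, limit);           // clamp the linear branch on both sides
    const float silu = x / (1.0f + std::exp(alpha * (-x)));  // alpha-scaled sigmoid gate
    return silu * (y + 1.0f);                                // note the +1 offset vs. plain SwiGLU
}
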
@@ -4599,6 +4761,7 @@ void ggml_compute_forward_out_prod(
        case GGML_TYPE_Q5_0:
        case GGML_TYPE_Q5_1:
        case GGML_TYPE_Q8_0:
+       case GGML_TYPE_MXFP4:
        case GGML_TYPE_Q2_K:
        case GGML_TYPE_Q3_K:
        case GGML_TYPE_Q4_K:
@@ -4873,6 +5036,7 @@ void ggml_compute_forward_set(
        case GGML_TYPE_Q5_1:
        case GGML_TYPE_Q8_0:
        case GGML_TYPE_Q8_1:
+       case GGML_TYPE_MXFP4:
        case GGML_TYPE_Q2_K:
        case GGML_TYPE_Q3_K:
        case GGML_TYPE_Q4_K:
@@ -5134,6 +5298,7 @@ void ggml_compute_forward_get_rows(
        case GGML_TYPE_Q5_1:
        case GGML_TYPE_Q8_0:
        case GGML_TYPE_Q8_1:
+       case GGML_TYPE_MXFP4:
        case GGML_TYPE_Q2_K:
        case GGML_TYPE_Q3_K:
        case GGML_TYPE_Q4_K:
@@ -5523,6 +5688,7 @@ static void ggml_compute_forward_soft_max_f32(
 
     const ggml_tensor * src0 = dst->src[0];
     const ggml_tensor * src1 = dst->src[1];
+    const ggml_tensor * src2 = dst->src[2];
 
     assert(ggml_is_contiguous(dst));
     assert(ggml_are_same_shape(src0, dst));
@@ -5557,6 +5723,9 @@
 
     const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
 
+    // sinks
+    const float * sk = src2 ? (float *)((char *) src2->data) : nullptr;
+
     for (int64_t i03 = 0; i03 < ne03; i03++) {
         for (int64_t i02 = 0; i02 < ne02; i02++) {
             for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
@@ -5599,9 +5768,18 @@
                 float max = -INFINITY;
                 ggml_vec_max_f32(ne00, &max, wp);
 
+                // if we have sinks, make a correction as if they were included in the softmax
+                if (sk) {
+                    max = MAX(max, sk[i02]);
+                }
+
                 ggml_float sum = ggml_vec_soft_max_f32(ne00, dp, wp, max);
                 assert(sum > 0.0);
 
+                if (sk) {
+                    sum += (ggml_float) expf(sk[i02] - max);
+                }
+
                 sum = 1.0/sum;
                 ggml_vec_scale_f32(ne00, dp, sum);
 
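
Note (not from the package sources): the soft_max hunks above let an optional per-head "sink" value join the max/normalizer computation without producing an output element, so the emitted probabilities sum to slightly less than 1. A standalone sketch of that behaviour (illustrative names):

#include <algorithm>
#include <cmath>

void softmax_with_sink_sketch(float * p, const float * logits, int n, float sink) {
    float max = sink;
    for (int i = 0; i < n; ++i) max = std::max(max, logits[i]);

    float sum = std::exp(sink - max);  // the sink's share of the denominator
    for (int i = 0; i < n; ++i) {
        p[i] = std::exp(logits[i] - max);
        sum += p[i];
    }
    for (int i = 0; i < n; ++i) p[i] /= sum;  // visible entries now sum to < 1
}
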
@@ -5836,6 +6014,7 @@ void ggml_compute_forward_clamp(
        case GGML_TYPE_Q5_1:
        case GGML_TYPE_Q8_0:
        case GGML_TYPE_Q8_1:
+       case GGML_TYPE_MXFP4:
        case GGML_TYPE_Q2_K:
        case GGML_TYPE_Q3_K:
        case GGML_TYPE_Q4_K:
@@ -7989,12 +8168,14 @@ void ggml_compute_forward_argsort(
 
 static void ggml_compute_forward_flash_attn_ext_f16(
         const ggml_compute_params * params,
-        const ggml_tensor * q,
-        const ggml_tensor * k,
-        const ggml_tensor * v,
-        const ggml_tensor * mask,
         ggml_tensor * dst) {
 
+    const ggml_tensor * q = dst->src[0];
+    const ggml_tensor * k = dst->src[1];
+    const ggml_tensor * v = dst->src[2];
+    const ggml_tensor * mask = dst->src[3];
+    const ggml_tensor * sinks = dst->src[4];
+
     GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
     GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
     GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
@@ -8189,6 +8370,23 @@ static void ggml_compute_forward_flash_attn_ext_f16(
            }
        }
 
+        // sinks
+        if (sinks) {
+            const float s = ((float *)((char *) sinks->data))[h];
+
+            float ms = 1.0f;
+            float vs = 1.0f;
+
+            if (s > M) {
+                ms = expf(M - s);
+                ggml_vec_scale_f32(DV, VKQ32, ms);
+            } else {
+                vs = expf(s - M);
+            }
+
+            S = S*ms + vs;
+        }
+
        // V /= S
        const float S_inv = 1.0f/S;
        ggml_vec_scale_f32(DV, VKQ32, S_inv);
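
Note (not from the package sources): in the flash-attention kernel above the sink is folded in with the usual online-softmax rescaling: it acts like one extra logit with no value vector, so only the running denominator S and the already-accumulated output need adjusting. A standalone sketch (illustrative names; M is the running max, V the accumulated output of length dv):

#include <cmath>

void merge_sink_sketch(float & M, float & S, float * V, int dv, float sink) {
    if (sink > M) {
        const float ms = std::exp(M - sink);      // rescale old contributions to the new max
        for (int i = 0; i < dv; ++i) V[i] *= ms;
        S = S * ms + 1.0f;                        // exp(sink - sink) == 1
        M = sink;
    } else {
        S += std::exp(sink - M);                  // the sink's share of the denominator
    }
}
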
@@ -8208,17 +8406,13 @@ static void ggml_compute_forward_flash_attn_ext_f16(
 
 void ggml_compute_forward_flash_attn_ext(
         const ggml_compute_params * params,
-        const ggml_tensor * q,
-        const ggml_tensor * k,
-        const ggml_tensor * v,
-        const ggml_tensor * mask,
         ggml_tensor * dst) {
     switch (dst->op_params[3]) {
         case GGML_PREC_DEFAULT:
         case GGML_PREC_F32:
             {
                 // uses F32 accumulators
-                ggml_compute_forward_flash_attn_ext_f16(params, q, k, v, mask, dst);
+                ggml_compute_forward_flash_attn_ext_f16(params, dst);
             } break;
        default:
            {
@@ -9080,6 +9274,10 @@ void ggml_compute_forward_glu(
            {
                ggml_compute_forward_swiglu(params, dst);
            } break;
+       case GGML_GLU_OP_SWIGLU_OAI:
+           {
+               ggml_compute_forward_swiglu_oai(params, dst);
+           } break;
        case GGML_GLU_OP_GEGLU_ERF:
            {
                ggml_compute_forward_geglu_erf(params, dst);
@@ -29,6 +29,7 @@ extern "C" {
 
 void ggml_compute_forward_dup(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_add(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_add_id(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_add1(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_acc(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_sum(const struct ggml_compute_params * params, struct ggml_tensor * dst);
@@ -82,13 +83,7 @@ void ggml_compute_forward_arange(const struct ggml_compute_params * params, stru
 void ggml_compute_forward_timestep_embedding(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_argsort(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_leaky_relu(const struct ggml_compute_params * params, struct ggml_tensor * dst);
-void ggml_compute_forward_flash_attn_ext(
-        const struct ggml_compute_params * params,
-        const struct ggml_tensor * q,
-        const struct ggml_tensor * k,
-        const struct ggml_tensor * v,
-        const struct ggml_tensor * mask,
-        struct ggml_tensor * dst);
+void ggml_compute_forward_flash_attn_ext(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_flash_attn_back(
         const struct ggml_compute_params * params,
         const bool masked,
@@ -46,6 +46,10 @@ void quantize_row_q8_1_generic(const float * GGML_RESTRICT x, void * GGML_RESTRI
     quantize_row_q8_1_ref(x, y, k);
 }
 
+void quantize_row_mxfp4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) {
+    quantize_row_mxfp4_ref(x, y, k);
+}
+
 //
 // 2-6 bit quantization in super-blocks
 //
@@ -181,6 +185,37 @@ void ggml_vec_dot_q4_1_q8_1_generic(int n, float * GGML_RESTRICT s, size_t bs, c
     *s = sumf;
 }
 
+void ggml_vec_dot_mxfp4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+    assert(n % QK_MXFP4 == 0);
+    static_assert(QK_MXFP4 == QK8_0, "QK_MXFP4 and QK8_0 must be the same");
+
+    const block_mxfp4 * GGML_RESTRICT x = vx;
+    const block_q8_0 * GGML_RESTRICT y = vy;
+
+    const int nb = n / QK_MXFP4;
+
+    int ib = 0;
+    float sumf = 0;
+
+    for (; ib < nb; ++ib) {
+        const float d = GGML_CPU_FP16_TO_FP32(y[ib].d)*GGML_E8M0_TO_FP32_HALF(x[ib].e);
+
+        int sumi1 = 0;
+        int sumi2 = 0;
+        for (int j = 0; j < QK_MXFP4/2; ++j) {
+            sumi1 += y[ib].qs[j + 0] * kvalues_mxfp4[x[ib].qs[j] & 0xf];
+            sumi2 += y[ib].qs[j + QK_MXFP4/2] * kvalues_mxfp4[x[ib].qs[j] >> 4];
+        }
+        sumf += d * (sumi1 + sumi2);
+    }
+    *s = sumf;
+}
+
 void ggml_vec_dot_q5_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
     const int qk = QK8_0;
     const int nb = n / qk;
@@ -19,6 +19,8 @@ void quantize_row_q5_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, in
 void quantize_row_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
 void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
 
+void quantize_row_mxfp4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+
 void quantize_row_q2_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
 void quantize_row_q3_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
 void quantize_row_q4_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
@@ -39,6 +41,8 @@ void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi
 void ggml_vec_dot_q5_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 
+void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+
 void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
@@ -67,8 +71,12 @@ void ggml_vec_dot_q4_1_q8_1_generic(int n, float * GGML_RESTRICT s, size_t bs, c
 void ggml_vec_dot_q5_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 void ggml_vec_dot_q5_1_q8_1_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 void ggml_vec_dot_q8_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+
+void ggml_vec_dot_mxfp4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+
 void ggml_vec_dot_tq1_0_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 void ggml_vec_dot_tq2_0_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+
 void ggml_vec_dot_q2_K_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 void ggml_vec_dot_q3_K_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 void ggml_vec_dot_q4_K_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);