llama_cpp 0.15.2 → 0.15.3

This diff shows the changes between two publicly released versions of the package as they appear in the public registry, and is provided for informational purposes only.
@@ -406,10 +406,10 @@ void ggml_fp32_to_bf16_row(const float * x, ggml_bf16_t * y, int64_t n) {
  int i = 0;
  #if defined(__AVX512BF16__)
  for (; i + 32 <= n; i += 32) {
- _mm512_storeu_ps(
- (__m512 *)(y + i),
- (__m512)_mm512_cvtne2ps_pbh(_mm512_loadu_ps(x + i + 16),
- _mm512_loadu_ps(x + i)));
+ _mm512_storeu_si512(
+ (__m512i *)(y + i),
+ m512i(_mm512_cvtne2ps_pbh(_mm512_loadu_ps(x + i + 16),
+ _mm512_loadu_ps(x + i))));
  }
  #endif
  for (; i < n; i++) {
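The new store goes through the integer intrinsic plus the m512i() helper, presumably because the direct __m512bh-to-__m512 cast is not accepted by every compiler; the values written are unchanged. For reference, the scalar tail loop below this hunk performs the same conversion one element at a time; a minimal stand-alone sketch of that round-to-nearest-even FP32 to BF16 conversion (plain C, not the ggml helpers):

// Stand-alone sketch (not part of the diff): scalar round-to-nearest-even
// FP32 -> BF16 conversion, equivalent to one lane of _mm512_cvtne2ps_pbh.
#include <stdint.h>
#include <string.h>

static uint16_t fp32_to_bf16_rne(float f) {
    uint32_t u;
    memcpy(&u, &f, sizeof(u));
    if ((u & 0x7fffffff) > 0x7f800000) {
        return (uint16_t) ((u >> 16) | 64); // NaN: keep the sign, force a quiet NaN
    }
    return (uint16_t) ((u + (0x7fff + ((u >> 16) & 1))) >> 16); // round to nearest even
}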
@@ -871,22 +871,14 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
  },
  [GGML_TYPE_IQ4_XS] = {
  .type_name = "iq4_xs",
- #if QK_K == 64
- .blck_size = QK4_NL,
- #else
  .blck_size = QK_K,
- #endif
  .type_size = sizeof(block_iq4_xs),
  .is_quantized = true,
  .to_float = (ggml_to_float_t) dequantize_row_iq4_xs,
  .from_float = quantize_row_iq4_xs,
  .from_float_reference = (ggml_from_float_t)quantize_row_iq4_xs_reference,
  .vec_dot = ggml_vec_dot_iq4_xs_q8_K,
- #if QK_K == 64
- .vec_dot_type = GGML_TYPE_Q8_0,
- #else
  .vec_dot_type = GGML_TYPE_Q8_K,
- #endif
  .nrows = 1,
  },
  [GGML_TYPE_Q8_K] = {
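Dropping the QK_K == 64 branches makes the IQ4_XS entry unconditional: block size QK_K and GGML_TYPE_Q8_K as the dot-product companion type. Callers do not index type_traits[] directly; they go through the accessor in ggml.h. A small sketch of inspecting the entry (assuming the ggml_internal_get_type_traits() accessor and the public ggml_type_traits_t struct declared in ggml.h):

// Sketch (not part of the diff): reading the IQ4_XS traits through the public API.
#include "ggml.h"
#include <stdio.h>

int main(void) {
    ggml_type_traits_t tt = ggml_internal_get_type_traits(GGML_TYPE_IQ4_XS);
    // With this change these values no longer depend on how QK_K was configured.
    printf("%s: blck_size=%d type_size=%zu is_quantized=%d vec_dot_type=%d\n",
           tt.type_name, tt.blck_size, tt.type_size, (int) tt.is_quantized, (int) tt.vec_dot_type);
    return 0;
}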
@@ -1523,6 +1515,195 @@ static inline void __sse_f16x4_store(ggml_fp16_t *x, __m128 y) {
1523
1515
  #define GGML_F16_VEC_MUL GGML_F32Cx4_MUL
1524
1516
  #define GGML_F16_VEC_REDUCE GGML_F32Cx4_REDUCE
1525
1517
 
1518
+ #elif defined(__loongarch_asx)
1519
+
1520
+ #define GGML_SIMD
1521
+
1522
+ // F32 LASX
1523
+ #define GGML_F32_STEP 32
1524
+ #define GGML_F32_EPR 8
1525
+
1526
+ #define GGML_F32x8 __m256
1527
+ #define GGML_F32x8_ZERO (__m256)__lasx_xvldi(0)
1528
+ #define GGML_F32x8_SET1(x) (__m256)__lasx_xvreplfr2vr_s((x))
1529
+ #define GGML_F32x8_LOAD(x) (__m256)__lasx_xvld((x), 0)
1530
+ #define GGML_F32x8_STORE(x,y) __lasx_xvst((y), (x), 0)
1531
+ #define GGML_F32x8_FMA(a, b, c) __lasx_xvfmadd_s(b, c, a)
1532
+ #define GGML_F32x8_ADD __lasx_xvfadd_s
1533
+ #define GGML_F32x8_MUL __lasx_xvfmul_s
1534
+ #define GGML_F32x8_REDUCE(res, x) \
1535
+ do { \
1536
+ int offset = GGML_F32_ARR >> 1; \
1537
+ for (int i = 0; i < offset; ++i) { \
1538
+ x[i] = __lasx_xvfadd_s(x[i], x[offset+i]); \
1539
+ } \
1540
+ offset >>= 1; \
1541
+ for (int i = 0; i < offset; ++i) { \
1542
+ x[i] = __lasx_xvfadd_s(x[i], x[offset+i]); \
1543
+ } \
1544
+ offset >>= 1; \
1545
+ for (int i = 0; i < offset; ++i) { \
1546
+ x[i] = __lasx_xvfadd_s(x[i], x[offset+i]); \
1547
+ } \
1548
+ float *tmp_p = (float *)&x[0]; \
1549
+ res = tmp_p[0] + tmp_p[1] + tmp_p[2] + tmp_p[3] + tmp_p[4] + tmp_p[5] + tmp_p[6] + tmp_p[7]; \
1550
+ } while (0)
1551
+ // TODO: is this optimal ?
1552
+
1553
+ #define GGML_F32_VEC GGML_F32x8
1554
+ #define GGML_F32_VEC_ZERO GGML_F32x8_ZERO
1555
+ #define GGML_F32_VEC_SET1 GGML_F32x8_SET1
1556
+ #define GGML_F32_VEC_LOAD GGML_F32x8_LOAD
1557
+ #define GGML_F32_VEC_STORE GGML_F32x8_STORE
1558
+ #define GGML_F32_VEC_FMA GGML_F32x8_FMA
1559
+ #define GGML_F32_VEC_ADD GGML_F32x8_ADD
1560
+ #define GGML_F32_VEC_MUL GGML_F32x8_MUL
1561
+ #define GGML_F32_VEC_REDUCE GGML_F32x8_REDUCE
1562
+
1563
+ // F16 LASX
1564
+
1565
+ #define GGML_F16_STEP 32
1566
+ #define GGML_F16_EPR 8
1567
+
1568
+ // F16 arithmetic is not supported by AVX, so we use F32 instead
1569
+
1570
+ #define GGML_F32Cx8 __m256
1571
+ #define GGML_F32Cx8_ZERO (__m256)__lasx_xvldi(0)
1572
+ #define GGML_F32Cx8_SET1(x) (__m256)__lasx_xvreplgr2vr_w((x))
1573
+
1574
+ static inline __m256 __lasx_f32cx8_load(ggml_fp16_t *x) {
1575
+ float tmp[8];
1576
+
1577
+ for (int i = 0; i < 8; i++) {
1578
+ tmp[i] = GGML_FP16_TO_FP32(x[i]);
1579
+ }
1580
+
1581
+ return (__m256)__lasx_xvld(tmp, 0);
1582
+ }
1583
+ static inline void __lasx_f32cx8_store(ggml_fp16_t *x, __m256 y) {
1584
+ float arr[8];
1585
+
1586
+ __lasx_xvst(y, arr, 0);
1587
+
1588
+ for (int i = 0; i < 8; i++)
1589
+ x[i] = GGML_FP32_TO_FP16(arr[i]);
1590
+ }
1591
+ #define GGML_F32Cx8_LOAD(x) __lasx_f32cx8_load(x)
1592
+ #define GGML_F32Cx8_STORE(x, y) __lasx_f32cx8_store(x, y)
1593
+
1594
+ #define GGML_F32Cx8_FMA GGML_F32x8_FMA
1595
+ #define GGML_F32Cx8_ADD __lasx_xvfadd_s
1596
+ #define GGML_F32Cx8_MUL __lasx_xvfmul_s
1597
+ #define GGML_F32Cx8_REDUCE GGML_F32x8_REDUCE
1598
+
1599
+ #define GGML_F16_VEC GGML_F32Cx8
1600
+ #define GGML_F16_VEC_ZERO GGML_F32Cx8_ZERO
1601
+ #define GGML_F16_VEC_SET1 GGML_F32Cx8_SET1
1602
+ #define GGML_F16_VEC_LOAD(p, i) GGML_F32Cx8_LOAD(p)
1603
+ #define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx8_STORE(p, r[i])
1604
+ #define GGML_F16_VEC_FMA GGML_F32Cx8_FMA
1605
+ #define GGML_F16_VEC_ADD GGML_F32Cx8_ADD
1606
+ #define GGML_F16_VEC_MUL GGML_F32Cx8_MUL
1607
+ #define GGML_F16_VEC_REDUCE GGML_F32Cx8_REDUCE
1608
+
1609
+ #elif defined(__loongarch_sx)
1610
+
1611
+ #define GGML_SIMD
1612
+
1613
+ // F32 LSX
1614
+
1615
+ #define GGML_F32_STEP 32
1616
+ #define GGML_F32_EPR 4
1617
+
1618
+ #define GGML_F32x4 __m128
1619
+ #define GGML_F32x4_ZERO __lsx_vldi(0)
1620
+ #define GGML_F32x4_SET1(x) __lsx_vinsgr2vr_w(__lsx_vldi(0),(x), 0)
1621
+ #define GGML_F32x4_LOAD(x) __lsx_vld((x), 0)
1622
+ #define GGML_F32x4_STORE((x),(y)) __lsx_vst((y), (x), 0)
1623
+ #define GGML_F32x4_FMA(a, b, c) __lsx_vfmadd_s(b, c, a)
1624
+ #define GGML_F32x4_ADD __lsx_vfadd_s
1625
+ #define GGML_F32x4_MUL __lsx_vfmul_s
1626
+ #define GGML_F32x4_REDUCE(res, x) \
1627
+ { \
1628
+ int offset = GGML_F32_ARR >> 1; \
1629
+ for (int i = 0; i < offset; ++i) { \
1630
+ x[i] = __lsx_vfadd_s(x[i], x[offset+i]); \
1631
+ } \
1632
+ offset >>= 1; \
1633
+ for (int i = 0; i < offset; ++i) { \
1634
+ x[i] = __lsx_vfadd_s(x[i], x[offset+i]); \
1635
+ } \
1636
+ offset >>= 1; \
1637
+ for (int i = 0; i < offset; ++i) { \
1638
+ x[i] = __lsx_vfadd_s(x[i], x[offset+i]); \
1639
+ } \
1640
+ __m128i tmp = __lsx_vsrli_d((__m128i)x[0], 32); \
1641
+ tmp = (__m128i)__lsx_vfadd_s((__m128)tmp, x[0]); \
1642
+ tmp = __lsx_vpickev_w(__lsx_vldi(0), tmp); \
1643
+ const __m128 t0 = __lsx_vshuf4i_w(tmp, 0x88); \
1644
+ tmp = __lsx_vsrli_d((__m128i)t0, 32); \
1645
+ tmp = (__m128i)__lsx_vfadd_s((__m128)tmp, t0); \
1646
+ tmp = __lsx_vpickev_w(__lsx_vldi(0), tmp); \
1647
+ res = (ggml_float) __lsx_vpickve2gr_w(__lsx_vshuf4i_w(tmp, 0x88), 0); \
1648
+ }
1649
+
1650
+ #define GGML_F32_VEC GGML_F32x4
1651
+ #define GGML_F32_VEC_ZERO GGML_F32x4_ZERO
1652
+ #define GGML_F32_VEC_SET1 GGML_F32x4_SET1
1653
+ #define GGML_F32_VEC_LOAD GGML_F32x4_LOAD
1654
+ #define GGML_F32_VEC_STORE GGML_F32x4_STORE
1655
+ #define GGML_F32_VEC_FMA GGML_F32x4_FMA
1656
+ #define GGML_F32_VEC_ADD GGML_F32x4_ADD
1657
+ #define GGML_F32_VEC_MUL GGML_F32x4_MUL
1658
+ #define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE
1659
+
1660
+ // F16 LSX
1661
+
1662
+ #define GGML_F16_STEP 32
1663
+ #define GGML_F16_EPR 4
1664
+
1665
+ static inline __m128 __lsx_f16x4_load(ggml_fp16_t *x) {
1666
+ float tmp[4];
1667
+
1668
+ tmp[0] = GGML_FP16_TO_FP32(x[0]);
1669
+ tmp[1] = GGML_FP16_TO_FP32(x[1]);
1670
+ tmp[2] = GGML_FP16_TO_FP32(x[2]);
1671
+ tmp[3] = GGML_FP16_TO_FP32(x[3]);
1672
+
1673
+ return __lsx_vld(tmp, 0);
1674
+ }
1675
+
1676
+ static inline void __lsx_f16x4_store(ggml_fp16_t *x, __m128 y) {
1677
+ float arr[4];
1678
+
1679
+ __lsx_vst(y, arr, 0);
1680
+
1681
+ x[0] = GGML_FP32_TO_FP16(arr[0]);
1682
+ x[1] = GGML_FP32_TO_FP16(arr[1]);
1683
+ x[2] = GGML_FP32_TO_FP16(arr[2]);
1684
+ x[3] = GGML_FP32_TO_FP16(arr[3]);
1685
+ }
1686
+
1687
+ #define GGML_F32Cx4 __m128
1688
+ #define GGML_F32Cx4_ZERO __lsx_vldi(0)
1689
+ #define GGML_F32Cx4_SET1(x) __lsx_vinsgr2vr_w(__lsx_vldi(0),(x), 0)
1690
+ #define GGML_F32Cx4_LOAD(x) __lsx_f16x4_load(x)
1691
+ #define GGML_F32Cx4_STORE(x, y) __lsx_f16x4_store(x, y)
1692
+ #define GGML_F32Cx4_FMA GGML_F32x4_FMA
1693
+ #define GGML_F32Cx4_ADD __lsx_vfadd_s
1694
+ #define GGML_F32Cx4_MUL __lsx_vfmul_s
1695
+ #define GGML_F32Cx4_REDUCE GGML_F32x4_REDUCE
1696
+
1697
+ #define GGML_F16_VEC GGML_F32Cx4
1698
+ #define GGML_F16_VEC_ZERO GGML_F32Cx4_ZERO
1699
+ #define GGML_F16_VEC_SET1 GGML_F32Cx4_SET1
1700
+ #define GGML_F16_VEC_LOAD(p, i) GGML_F32Cx4_LOAD(p)
1701
+ #define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx4_STORE(p, r[i])
1702
+ #define GGML_F16_VEC_FMA GGML_F32Cx4_FMA
1703
+ #define GGML_F16_VEC_ADD GGML_F32Cx4_ADD
1704
+ #define GGML_F16_VEC_MUL GGML_F32Cx4_MUL
1705
+ #define GGML_F16_VEC_REDUCE GGML_F32Cx4_REDUCE
1706
+
1526
1707
  #endif
1527
1708
 
1528
1709
  // GGML_F32_ARR / GGML_F16_ARR
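The two new blocks only teach the generic GGML_F32_* / GGML_F16_* macro layer about LoongArch LASX (256-bit) and LSX (128-bit); the kernels themselves are ISA-agnostic and are written against these macros. A sketch of the pattern they follow, modelled on ggml_vec_dot_f32 (GGML_F32_ARR is GGML_F32_STEP/GGML_F32_EPR and is defined right below this block):

// Sketch (not part of the diff) of the macro-driven loop shape used by the f32 kernels.
static float vec_dot_f32_sketch(int n, const float * x, const float * y) {
    float sumf = 0.0f;
#if defined(GGML_SIMD)
    const int np = (n & ~(GGML_F32_STEP - 1));

    GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
    GGML_F32_VEC ax[GGML_F32_ARR];
    GGML_F32_VEC ay[GGML_F32_ARR];

    for (int i = 0; i < np; i += GGML_F32_STEP) {
        for (int j = 0; j < GGML_F32_ARR; j++) {
            ax[j]  = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR);
            ay[j]  = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR);
            sum[j] = GGML_F32_VEC_FMA(sum[j], ax[j], ay[j]); // sum += ax*ay
        }
    }

    GGML_F32_VEC_REDUCE(sumf, sum); // horizontal add of all partial vectors

    for (int i = np; i < n; ++i) {  // leftovers
        sumf += x[i]*y[i];
    }
#else
    for (int i = 0; i < n; ++i) {
        sumf += x[i]*y[i];
    }
#endif
    return sumf;
}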
@@ -1666,10 +1847,10 @@ static void ggml_vec_dot_bf16(int n, float * restrict s, size_t bs, ggml_bf16_t
  __m512 c1 = _mm512_setzero_ps();
  __m512 c2 = _mm512_setzero_ps();
  for (; i + 64 <= n; i += 64) {
- c1 = _mm512_dpbf16_ps(c1, (__m512bh)_mm512_loadu_ps((const float *)(x + i)),
- (__m512bh)_mm512_loadu_ps((const float *)(y + i)));
- c2 = _mm512_dpbf16_ps(c2, (__m512bh)_mm512_loadu_ps((const float *)(x + i + 32)),
- (__m512bh)_mm512_loadu_ps((const float *)(y + i + 32)));
+ c1 = _mm512_dpbf16_ps(c1, m512bh(_mm512_loadu_si512((x + i))),
+ m512bh(_mm512_loadu_si512((y + i))));
+ c2 = _mm512_dpbf16_ps(c2, m512bh(_mm512_loadu_si512((x + i + 32))),
+ m512bh(_mm512_loadu_si512((y + i + 32))));
  }
  sumf += (ggml_float)_mm512_reduce_add_ps(c1);
  sumf += (ggml_float)_mm512_reduce_add_ps(c2);
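As with the conversion routine above, the change here is only in how the 512-bit registers are loaded and reinterpreted (_mm512_loadu_si512 plus the m512bh() helper instead of float loads and casts); each _mm512_dpbf16_ps still accumulates 32 BF16 products into FP32 lanes. A scalar sketch of what the loop computes:

// Sketch (not part of the diff): scalar reference for the BF16 dot product.
#include <stdint.h>

static float bf16_to_fp32_sketch(uint16_t h) {
    union { uint32_t u; float f; } v = { .u = (uint32_t) h << 16 }; // BF16 = high half of FP32
    return v.f;
}

static float vec_dot_bf16_sketch(int n, const uint16_t * x, const uint16_t * y) {
    float sumf = 0.0f;
    for (int i = 0; i < n; ++i) {
        sumf += bf16_to_fp32_sketch(x[i]) * bf16_to_fp32_sketch(y[i]);
    }
    return sumf;
}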
@@ -2076,7 +2257,7 @@ inline static float ggml_silu_f32(float x) {
  return x/(1.0f + expf(-x));
  }

- #if defined(__ARM_NEON)
+ #if defined(__ARM_NEON) && defined(__aarch64__)

  // adapted from arm limited optimized routine
  // the maximum error is 1.45358 plus 0.5 ulps
@@ -2288,7 +2469,7 @@ static void ggml_vec_silu_f32(const int n, float * y, const float * x) {
  for (; i + 3 < n; i += 4) {
  _mm_storeu_ps(y + i, ggml_v_silu(_mm_loadu_ps(x + i)));
  }
- #elif defined(__ARM_NEON)
+ #elif defined(__ARM_NEON) && defined(__aarch64__)
  for (; i + 3 < n; i += 4) {
  vst1q_f32(y + i, ggml_v_silu(vld1q_f32(x + i)));
  }
- #elif defined(__ARM_NEON)
+ #elif defined(__ARM_NEON) && defined(__aarch64__)
  for (; i + 3 < n; i += 4) {
  float32x4_t val = ggml_v_expf(vsubq_f32(vld1q_f32(x + i),
  vdupq_n_f32(max)));
@@ -2489,9 +2670,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
  "ARGSORT",
  "LEAKY_RELU",

- "FLASH_ATTN",
  "FLASH_ATTN_EXT",
- "FLASH_FF",
  "FLASH_ATTN_BACK",
  "SSM_CONV",
  "SSM_SCAN",
@@ -2517,7 +2696,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
  "CROSS_ENTROPY_LOSS_BACK",
  };

- static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76");
+ static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74");

  static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
  "none",
@@ -2579,9 +2758,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
  "argsort(x)",
  "leaky_relu(x)",

- "flash_attn(x)",
  "flash_attn_ext(x)",
- "flash_ff(x)",
  "flash_attn_back(x)",
  "ssm_conv(x)",
  "ssm_scan(x)",
@@ -2607,7 +2784,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
  "cross_entropy_loss_back(x,y)",
  };

- static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76");
+ static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74");

  static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");

@@ -6042,6 +6219,7 @@ static struct ggml_tensor * ggml_rope_impl(
  struct ggml_context * ctx,
  struct ggml_tensor * a,
  struct ggml_tensor * b,
+ struct ggml_tensor * c,
  int n_dims,
  int mode,
  int n_ctx,
@@ -6055,10 +6233,17 @@ static struct ggml_tensor * ggml_rope_impl(
  float xpos_base,
  bool xpos_down,
  bool inplace) {
+ GGML_ASSERT((mode & 1) == 0 && "mode & 1 == 1 is no longer supported");
+
  GGML_ASSERT(ggml_is_vector(b));
  GGML_ASSERT(b->type == GGML_TYPE_I32);
  GGML_ASSERT(a->ne[2] == b->ne[0]);

+ if (c) {
+ GGML_ASSERT(c->type == GGML_TYPE_F32);
+ GGML_ASSERT(c->ne[0] >= n_dims / 2);
+ }
+
  bool is_node = false;

  if (a->grad) {
@@ -6082,6 +6267,7 @@ static struct ggml_tensor * ggml_rope_impl(
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
  result->src[0] = a;
  result->src[1] = b;
+ result->src[2] = c;

  return result;
  }
@@ -6094,7 +6280,7 @@ struct ggml_tensor * ggml_rope(
6094
6280
  int mode,
6095
6281
  int n_ctx) {
6096
6282
  return ggml_rope_impl(
6097
- ctx, a, b, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, false
6283
+ ctx, a, b, NULL, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, false
6098
6284
  );
6099
6285
  }
6100
6286
 
@@ -6106,14 +6292,15 @@ struct ggml_tensor * ggml_rope_inplace(
6106
6292
  int mode,
6107
6293
  int n_ctx) {
6108
6294
  return ggml_rope_impl(
6109
- ctx, a, b, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, true
6295
+ ctx, a, b, NULL, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, true
6110
6296
  );
6111
6297
  }
6112
6298
 
6113
- struct ggml_tensor * ggml_rope_custom(
6299
+ struct ggml_tensor * ggml_rope_ext(
6114
6300
  struct ggml_context * ctx,
6115
6301
  struct ggml_tensor * a,
6116
6302
  struct ggml_tensor * b,
6303
+ struct ggml_tensor * c,
6117
6304
  int n_dims,
6118
6305
  int mode,
6119
6306
  int n_ctx,
@@ -6125,15 +6312,16 @@ struct ggml_tensor * ggml_rope_custom(
6125
6312
  float beta_fast,
6126
6313
  float beta_slow) {
6127
6314
  return ggml_rope_impl(
6128
- ctx, a, b, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
6315
+ ctx, a, b, c, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
6129
6316
  ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, false
6130
6317
  );
6131
6318
  }
6132
6319
 
6133
- struct ggml_tensor * ggml_rope_custom_inplace(
6320
+ struct ggml_tensor * ggml_rope_ext_inplace(
6134
6321
  struct ggml_context * ctx,
6135
6322
  struct ggml_tensor * a,
6136
6323
  struct ggml_tensor * b,
6324
+ struct ggml_tensor * c,
6137
6325
  int n_dims,
6138
6326
  int mode,
6139
6327
  int n_ctx,
@@ -6145,19 +6333,49 @@ struct ggml_tensor * ggml_rope_custom_inplace(
6145
6333
  float beta_fast,
6146
6334
  float beta_slow) {
6147
6335
  return ggml_rope_impl(
6148
- ctx, a, b, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
6336
+ ctx, a, b, c, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
6149
6337
  ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, true
6150
6338
  );
6151
6339
  }
6152
6340
 
6153
- struct ggml_tensor * ggml_rope_xpos_inplace(
6341
+ struct ggml_tensor * ggml_rope_custom(
6342
+ struct ggml_context * ctx,
6343
+ struct ggml_tensor * a,
6344
+ struct ggml_tensor * b,
6345
+ int n_dims,
6346
+ int mode,
6347
+ int n_ctx,
6348
+ int n_orig_ctx,
6349
+ float freq_base,
6350
+ float freq_scale,
6351
+ float ext_factor,
6352
+ float attn_factor,
6353
+ float beta_fast,
6354
+ float beta_slow) {
6355
+ return ggml_rope_impl(
6356
+ ctx, a, b, NULL, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
6357
+ ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, false
6358
+ );
6359
+ }
6360
+
6361
+ struct ggml_tensor * ggml_rope_custom_inplace(
6154
6362
  struct ggml_context * ctx,
6155
6363
  struct ggml_tensor * a,
6156
6364
  struct ggml_tensor * b,
6157
6365
  int n_dims,
6158
- float base,
6159
- bool down) {
6160
- return ggml_rope_impl(ctx, a, b, n_dims, 0, 0, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, base, down, true);
6366
+ int mode,
6367
+ int n_ctx,
6368
+ int n_orig_ctx,
6369
+ float freq_base,
6370
+ float freq_scale,
6371
+ float ext_factor,
6372
+ float attn_factor,
6373
+ float beta_fast,
6374
+ float beta_slow) {
6375
+ return ggml_rope_impl(
6376
+ ctx, a, b, NULL, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
6377
+ ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, true
6378
+ );
6161
6379
  }
6162
6380
 
6163
6381
  // ggml_rope_back
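With the wrappers above, ggml_rope_custom() and ggml_rope_custom_inplace() keep their old signatures and simply forward with c == NULL, while the new ggml_rope_ext() variants expose the extra tensor. A sketch of migrating a call site (tensor names such as cur, inp_pos and freq_factors are placeholders; a non-NULL c must be F32 with at least n_dims/2 values, per the asserts in ggml_rope_impl()):

// before: no per-dimension frequency factors (still compiles, forwards with c == NULL)
cur = ggml_rope_custom(ctx, cur, inp_pos,
        n_rot, 2 /* neox mode */, n_ctx, n_orig_ctx,
        freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);

// after: pass an optional F32 tensor of n_dims/2 frequency factors (or NULL)
cur = ggml_rope_ext(ctx, cur, inp_pos, freq_factors,
        n_rot, 2 /* neox mode */, n_ctx, n_orig_ctx,
        freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);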
@@ -6166,6 +6384,7 @@ struct ggml_tensor * ggml_rope_back(
6166
6384
  struct ggml_context * ctx,
6167
6385
  struct ggml_tensor * a,
6168
6386
  struct ggml_tensor * b,
6387
+ struct ggml_tensor * c,
6169
6388
  int n_dims,
6170
6389
  int mode,
6171
6390
  int n_ctx,
@@ -6181,6 +6400,7 @@ struct ggml_tensor * ggml_rope_back(
6181
6400
  GGML_ASSERT(ggml_is_vector(b));
6182
6401
  GGML_ASSERT(b->type == GGML_TYPE_I32);
6183
6402
  GGML_ASSERT(a->ne[2] == b->ne[0]);
6403
+ GGML_ASSERT(c == NULL && "freq factors not implemented yet");
6184
6404
 
6185
6405
  GGML_ASSERT((mode & 4) == 0 && "ggml_rope_back() for ChatGLM not implemented yet");
6186
6406
 
@@ -6724,38 +6944,6 @@ struct ggml_tensor * ggml_top_k(
6724
6944
  return result;
6725
6945
  }
6726
6946
 
6727
- // ggml_flash_attn
6728
-
6729
- struct ggml_tensor * ggml_flash_attn(
6730
- struct ggml_context * ctx,
6731
- struct ggml_tensor * q,
6732
- struct ggml_tensor * k,
6733
- struct ggml_tensor * v,
6734
- bool masked) {
6735
- GGML_ASSERT(ggml_can_mul_mat(k, q));
6736
- // TODO: check if vT can be multiplied by (k*qT)
6737
-
6738
- bool is_node = false;
6739
-
6740
- if (q->grad || k->grad || v->grad) {
6741
- is_node = true;
6742
- }
6743
-
6744
- //struct ggml_tensor * result = ggml_dup_tensor(ctx, q);
6745
- struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, GGML_MAX_DIMS, q->ne);
6746
-
6747
- int32_t t = masked ? 1 : 0;
6748
- ggml_set_op_params(result, &t, sizeof(t));
6749
-
6750
- result->op = GGML_OP_FLASH_ATTN;
6751
- result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6752
- result->src[0] = q;
6753
- result->src[1] = k;
6754
- result->src[2] = v;
6755
-
6756
- return result;
6757
- }
6758
-
6759
6947
  // ggml_flash_attn_ext
6760
6948
 
6761
6949
  struct ggml_tensor * ggml_flash_attn_ext(
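The graph builder for GGML_OP_FLASH_ATTN is gone; callers are expected to build GGML_OP_FLASH_ATTN_EXT nodes instead, whose constructor follows. A hedged sketch of the replacement call, assuming a (ctx, q, k, v, mask, scale, max_bias) parameter list as suggested by this release's op params (names such as kq_mask and n_embd_head are placeholders; check ggml.h for the exact declaration):

// Sketch (not part of the diff): replacing the removed ggml_flash_attn(ctx, q, k, v, masked).
struct ggml_tensor * kqv = ggml_flash_attn_ext(ctx, q, k, v,
        kq_mask,                          // optional F16 mask, or NULL
        1.0f/sqrtf((float) n_embd_head),  // scale applied to K*Q
        0.0f);                            // max_bias for ALiBi; 0.0f disables it

// optionally force full F32 accumulation for this node:
ggml_flash_attn_ext_set_prec(kqv, GGML_PREC_F32);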
@@ -6815,38 +7003,6 @@ void ggml_flash_attn_ext_set_prec(
6815
7003
  ggml_set_op_params_i32(a, 2, prec_i32); // scale is on first pos, max_bias on second
6816
7004
  }
6817
7005
 
6818
- // ggml_flash_ff
6819
-
6820
- struct ggml_tensor * ggml_flash_ff(
6821
- struct ggml_context * ctx,
6822
- struct ggml_tensor * a,
6823
- struct ggml_tensor * b0,
6824
- struct ggml_tensor * b1,
6825
- struct ggml_tensor * c0,
6826
- struct ggml_tensor * c1) {
6827
- GGML_ASSERT(ggml_can_mul_mat(b0, a));
6828
- // TODO: more checks
6829
-
6830
- bool is_node = false;
6831
-
6832
- if (a->grad || b0->grad || b1->grad || c0->grad || c1->grad) {
6833
- is_node = true;
6834
- }
6835
-
6836
- //struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
6837
- struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, GGML_MAX_DIMS, a->ne);
6838
-
6839
- result->op = GGML_OP_FLASH_FF;
6840
- result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6841
- result->src[0] = a;
6842
- result->src[1] = b0;
6843
- result->src[2] = b1;
6844
- result->src[3] = c0;
6845
- result->src[4] = c1;
6846
-
6847
- return result;
6848
- }
6849
-
6850
7006
  // ggml_flash_attn_back
6851
7007
 
6852
7008
  struct ggml_tensor * ggml_flash_attn_back(
@@ -6856,6 +7012,8 @@ struct ggml_tensor * ggml_flash_attn_back(
6856
7012
  struct ggml_tensor * v,
6857
7013
  struct ggml_tensor * d,
6858
7014
  bool masked) {
7015
+ GGML_ASSERT(false && "TODO: adapt to ggml_flash_attn_ext() changes");
7016
+
6859
7017
  GGML_ASSERT(ggml_can_mul_mat(k, q));
6860
7018
  // TODO: check if vT can be multiplied by (k*qT)
6861
7019
 
@@ -14115,6 +14273,7 @@ static void ggml_compute_forward_rope_f32(

  const struct ggml_tensor * src0 = dst->src[0];
  const struct ggml_tensor * src1 = dst->src[1];
+ const struct ggml_tensor * src2 = dst->src[2];

  if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
  return;
@@ -14174,6 +14333,17 @@ static void ggml_compute_forward_rope_f32(
  const bool is_neox = mode & 2;
  const bool is_glm = mode & 4;

+ const float * freq_factors = NULL;
+ if (is_neox) {
+ if (src2 != NULL) {
+ GGML_ASSERT(src2->type == GGML_TYPE_F32);
+ GGML_ASSERT(src2->ne[0] >= n_dims / 2);
+ freq_factors = (const float *) src2->data;
+ }
+ } else {
+ GGML_ASSERT(src2 == NULL && "TODO: freq_factors not implemented for !is_neox");
+ }
+
  // backward process uses inverse rotation by cos and sin.
  // cos and sin build a rotation matrix, where the inverse is the transpose.
  // this essentially just switches the sign of sin.
@@ -14250,10 +14420,11 @@ static void ggml_compute_forward_rope_f32(

  // simplified from `(ib * n_dims + ic) * inv_ndims`
  float cur_rot = inv_ndims * ic - ib;
+ float freq_factor = freq_factors ? freq_factors[ic/2] : 1.0f;

  float cos_theta, sin_theta;
  rope_yarn(
- theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
+ theta_base/freq_factor, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
  &cos_theta, &sin_theta
  );
  sin_theta *= sin_sign;
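In the NeoX branch the factor simply divides the running angle before it reaches rope_yarn(), so each rotary pair can be slowed down independently, and a missing src2 leaves the result identical to the previous release. Schematically (a stand-alone sketch using the loop variables from the surrounding code):

// Sketch (not part of the diff): the angle handed to rope_yarn() for the rotary
// pair starting at dimension ic (ic = 0, 2, 4, ...) at position p.
#include <math.h>

static float rope_theta_sketch(float p, int ic, int n_dims,
                               float freq_base, const float * freq_factors) {
    const float theta_base  = p * powf(freq_base, -(float) ic / n_dims);
    const float freq_factor = freq_factors ? freq_factors[ic/2] : 1.0f;
    return theta_base / freq_factor; // freq_factor == 1.0f reproduces the old behaviour
}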
@@ -14286,6 +14457,7 @@ static void ggml_compute_forward_rope_f32(
14286
14457
  }
14287
14458
  }
14288
14459
 
14460
+ // TODO: deduplicate f16/f32 code
14289
14461
  static void ggml_compute_forward_rope_f16(
14290
14462
  const struct ggml_compute_params * params,
14291
14463
  struct ggml_tensor * dst,
@@ -14293,6 +14465,7 @@ static void ggml_compute_forward_rope_f16(
14293
14465
 
14294
14466
  const struct ggml_tensor * src0 = dst->src[0];
14295
14467
  const struct ggml_tensor * src1 = dst->src[1];
14468
+ const struct ggml_tensor * src2 = dst->src[2];
14296
14469
 
14297
14470
  if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
14298
14471
  return;
@@ -14345,6 +14518,17 @@ static void ggml_compute_forward_rope_f16(
14345
14518
  const bool is_neox = mode & 2;
14346
14519
  const bool is_glm = mode & 4;
14347
14520
 
14521
+ const float * freq_factors = NULL;
14522
+ if (is_neox) {
14523
+ if (src2 != NULL) {
14524
+ GGML_ASSERT(src2->type == GGML_TYPE_F32);
14525
+ GGML_ASSERT(src2->ne[0] >= n_dims / 2);
14526
+ freq_factors = (const float *) src2->data;
14527
+ }
14528
+ } else {
14529
+ GGML_ASSERT(src2 == NULL && "TODO: freq_factors not implemented for !is_neox");
14530
+ }
14531
+
14348
14532
  // backward process uses inverse rotation by cos and sin.
14349
14533
  // cos and sin build a rotation matrix, where the inverse is the transpose.
14350
14534
  // this essentially just switches the sign of sin.
@@ -14417,10 +14601,11 @@ static void ggml_compute_forward_rope_f16(
14417
14601
 
14418
14602
  // simplified from `(ib * n_dims + ic) * inv_ndims`
14419
14603
  float cur_rot = inv_ndims * ic - ib;
14604
+ float freq_factor = freq_factors ? freq_factors[ic/2] : 1.0f;
14420
14605
 
14421
14606
  float cos_theta, sin_theta;
14422
14607
  rope_yarn(
14423
- theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
14608
+ theta_base/freq_factor, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
14424
14609
  &cos_theta, &sin_theta
14425
14610
  );
14426
14611
  sin_theta *= sin_sign;
@@ -15458,17 +15643,15 @@ static void ggml_compute_forward_argsort(
15458
15643
  }
15459
15644
  }
15460
15645
 
15461
- // ggml_compute_forward_flash_attn
15646
+ // ggml_compute_forward_flash_attn_ext
15462
15647
 
15463
- static void ggml_compute_forward_flash_attn_f32(
15648
+ static void ggml_compute_forward_flash_attn_ext_f16(
15464
15649
  const struct ggml_compute_params * params,
15465
- const bool masked,
15650
+ const struct ggml_tensor * q,
15651
+ const struct ggml_tensor * k,
15652
+ const struct ggml_tensor * v,
15653
+ const struct ggml_tensor * mask,
15466
15654
  struct ggml_tensor * dst) {
15467
-
15468
- const struct ggml_tensor * q = dst->src[0];
15469
- const struct ggml_tensor * k = dst->src[1];
15470
- const struct ggml_tensor * v = dst->src[2];
15471
-
15472
15655
  int64_t t0 = ggml_perf_time_us();
15473
15656
  UNUSED(t0);
15474
15657
 
@@ -15486,409 +15669,18 @@ static void ggml_compute_forward_flash_attn_f32(
15486
15669
 
15487
15670
  const int64_t D = neq0;
15488
15671
  const int64_t N = neq1;
15489
- const int64_t P = nek1 - N;
15490
- const int64_t M = P + N;
15491
-
15492
- const int Mup = ggml_up(M, GGML_SOFT_MAX_UNROLL);
15493
15672
 
15494
15673
  GGML_ASSERT(ne0 == D);
15495
- GGML_ASSERT(ne1 == N);
15496
- GGML_ASSERT(P >= 0);
15674
+ GGML_ASSERT(ne2 == N);
15497
15675
 
15498
- GGML_ASSERT(nbq0 == sizeof(float));
15499
- GGML_ASSERT(nbk0 == sizeof(float));
15500
- GGML_ASSERT(nbv0 == sizeof(float));
15676
+ // input tensor rows must be contiguous
15677
+ GGML_ASSERT(nbq0 == ggml_type_size(q->type));
15678
+ GGML_ASSERT(nbk0 == ggml_type_size(k->type));
15679
+ GGML_ASSERT(nbv0 == ggml_type_size(v->type));
15501
15680
 
15502
15681
  GGML_ASSERT(neq0 == D);
15503
15682
  GGML_ASSERT(nek0 == D);
15504
- GGML_ASSERT(nev1 == D);
15505
-
15506
- GGML_ASSERT(neq1 == N);
15507
- GGML_ASSERT(nek1 == N + P);
15508
- GGML_ASSERT(nev1 == D);
15509
-
15510
- // dst cannot be transposed or permuted
15511
- GGML_ASSERT(nb0 == sizeof(float));
15512
- GGML_ASSERT(nb0 <= nb1);
15513
- GGML_ASSERT(nb1 <= nb2);
15514
- GGML_ASSERT(nb2 <= nb3);
15515
-
15516
- if (params->type == GGML_TASK_TYPE_INIT) {
15517
- return;
15518
- }
15519
-
15520
- if (params->type == GGML_TASK_TYPE_FINALIZE) {
15521
- return;
15522
- }
15523
-
15524
- // parallelize by q rows using ggml_vec_dot_f32
15525
-
15526
- // total rows in q
15527
- const int nr = neq1*neq2*neq3;
15528
-
15529
- // rows per thread
15530
- const int dr = (nr + nth - 1)/nth;
15531
-
15532
- // row range for this thread
15533
- const int ir0 = dr*ith;
15534
- const int ir1 = MIN(ir0 + dr, nr);
15535
-
15536
- const float scale = 1.0f/sqrtf(D);
15537
-
15538
- //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale);
15539
-
15540
- for (int ir = ir0; ir < ir1; ++ir) {
15541
- // q indices
15542
- const int iq3 = ir/(neq2*neq1);
15543
- const int iq2 = (ir - iq3*neq2*neq1)/neq1;
15544
- const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
15545
-
15546
- float * S = (float *) params->wdata + ith*(Mup + CACHE_LINE_SIZE_F32);
15547
-
15548
- for (int i = M; i < Mup; ++i) {
15549
- S[i] = -INFINITY;
15550
- }
15551
-
15552
- const int64_t masked_begin = masked ? (P + iq1 + 1) : M;
15553
- for (int64_t ic = 0; ic < masked_begin; ++ic) {
15554
- // k indices
15555
- const int ik3 = iq3;
15556
- const int ik2 = iq2 % nek2;
15557
- const int ik1 = ic;
15558
-
15559
- // S indices
15560
- const int i1 = ik1;
15561
-
15562
- ggml_vec_dot_f32(neq0,
15563
- S + i1, 0,
15564
- (float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), 0,
15565
- (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)), 0, 1);
15566
- }
15567
-
15568
- // scale
15569
- ggml_vec_scale_f32(masked_begin, S, scale);
15570
-
15571
- for (int64_t i = masked_begin; i < M; i++) {
15572
- S[i] = -INFINITY;
15573
- }
15574
-
15575
- // softmax
15576
- // exclude known -INF S[..] values from max and loop
15577
- // dont forget to set their SW values to zero
15578
- {
15579
- float max = -INFINITY;
15580
- ggml_vec_max_f32(masked_begin, &max, S);
15581
-
15582
- ggml_float sum = 0.0;
15583
- {
15584
- #ifdef GGML_SOFT_MAX_ACCELERATE
15585
- max = -max;
15586
- vDSP_vsadd(S, 1, &max, S, 1, Mup);
15587
- vvexpf(S, S, &Mup);
15588
- ggml_vec_sum_f32(Mup, &sum, S);
15589
- #else
15590
- sum = ggml_vec_soft_max_f32(Mup, S, S, max);
15591
- #endif
15592
- }
15593
-
15594
- assert(sum > 0.0);
15595
-
15596
- sum = 1.0/sum;
15597
- ggml_vec_scale_f32(masked_begin, S, sum);
15598
-
15599
- #ifndef NDEBUG
15600
- for (int i = 0; i < masked_begin; ++i) {
15601
- assert(!isnan(S[i]));
15602
- assert(!isinf(S[i]));
15603
- }
15604
- #endif
15605
- }
15606
-
15607
- for (int64_t ic = 0; ic < nev1; ++ic) {
15608
- // dst indices
15609
- const int i1 = iq1;
15610
- const int i2 = iq2;
15611
- const int i3 = iq3;
15612
-
15613
- // v indices
15614
- const int iv2 = iq2 % nev2;
15615
- const int iv3 = iq3;
15616
-
15617
- ggml_vec_dot_f32(masked_begin,
15618
- (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), 0,
15619
- (float *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)), 0,
15620
- S, 0, 1);
15621
- }
15622
- }
15623
- }
15624
-
15625
- static void ggml_compute_forward_flash_attn_f16(
15626
- const struct ggml_compute_params * params,
15627
- const bool masked,
15628
- struct ggml_tensor * dst) {
15629
-
15630
- const struct ggml_tensor * q = dst->src[0];
15631
- const struct ggml_tensor * k = dst->src[1];
15632
- const struct ggml_tensor * v = dst->src[2];
15633
-
15634
- int64_t t0 = ggml_perf_time_us();
15635
- UNUSED(t0);
15636
-
15637
- GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
15638
- GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
15639
- GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
15640
- GGML_TENSOR_LOCALS(size_t, nbk, k, nb)
15641
- GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
15642
- GGML_TENSOR_LOCALS(size_t, nbv, v, nb)
15643
- GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
15644
- GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
15645
-
15646
- const int ith = params->ith;
15647
- const int nth = params->nth;
15648
-
15649
- const int64_t D = neq0;
15650
- const int64_t N = neq1;
15651
- const int64_t P = nek1 - N;
15652
- const int64_t M = P + N;
15653
-
15654
- const int Mup = ggml_up(M, GGML_SOFT_MAX_UNROLL);
15655
-
15656
- GGML_ASSERT(ne0 == D);
15657
- GGML_ASSERT(ne1 == N);
15658
- GGML_ASSERT(P >= 0);
15659
-
15660
- GGML_ASSERT(nbq0 == sizeof(ggml_fp16_t));
15661
- GGML_ASSERT(nbk0 == sizeof(ggml_fp16_t));
15662
- GGML_ASSERT(nbv0 == sizeof(ggml_fp16_t));
15663
-
15664
- GGML_ASSERT(neq0 == D);
15665
- GGML_ASSERT(nek0 == D);
15666
- GGML_ASSERT(nev1 == D);
15667
-
15668
- GGML_ASSERT(neq1 == N);
15669
- GGML_ASSERT(nek1 == N + P);
15670
- GGML_ASSERT(nev1 == D);
15671
-
15672
- // dst cannot be transposed or permuted
15673
- GGML_ASSERT(nb0 == sizeof(float));
15674
- GGML_ASSERT(nb0 <= nb1);
15675
- GGML_ASSERT(nb1 <= nb2);
15676
- GGML_ASSERT(nb2 <= nb3);
15677
-
15678
- if (params->type == GGML_TASK_TYPE_INIT) {
15679
- return;
15680
- }
15681
-
15682
- if (params->type == GGML_TASK_TYPE_FINALIZE) {
15683
- return;
15684
- }
15685
-
15686
- // parallelize by q rows using ggml_vec_dot_f32
15687
-
15688
- // total rows in q
15689
- const int nr = neq1*neq2*neq3;
15690
-
15691
- // rows per thread
15692
- const int dr = (nr + nth - 1)/nth;
15693
-
15694
- // row range for this thread
15695
- const int ir0 = dr*ith;
15696
- const int ir1 = MIN(ir0 + dr, nr);
15697
-
15698
- const float scale = 1.0f/sqrtf(D);
15699
-
15700
- //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale);
15701
-
15702
- for (int ir = ir0; ir < ir1; ++ir) {
15703
- // q indices
15704
- const int iq3 = ir/(neq2*neq1);
15705
- const int iq2 = (ir - iq3*neq2*neq1)/neq1;
15706
- const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
15707
-
15708
- float * S = (float *) params->wdata + ith*(2*Mup + CACHE_LINE_SIZE_F32);
15709
-
15710
- for (int i = M; i < Mup; ++i) {
15711
- S[i] = -INFINITY;
15712
- }
15713
-
15714
- if (GGML_VEC_DOT_UNROLL > 2 || nek1 % GGML_VEC_DOT_UNROLL != 0) {
15715
- for (int64_t ic = 0; ic < nek1; ++ic) {
15716
- // k indices
15717
- const int ik3 = iq3;
15718
- const int ik2 = iq2 % nek2;
15719
- const int ik1 = ic;
15720
-
15721
- // S indices
15722
- const int i1 = ik1;
15723
-
15724
- ggml_vec_dot_f16(neq0,
15725
- S + i1, 0,
15726
- (ggml_fp16_t *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), 0,
15727
- (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)), 0, 1);
15728
- }
15729
- } else {
15730
- for (int64_t ic = 0; ic < nek1; ic += GGML_VEC_DOT_UNROLL) {
15731
- // k indices
15732
- const int ik3 = iq3;
15733
- const int ik2 = iq2 % nek2;
15734
- const int ik1 = ic;
15735
-
15736
- // S indices
15737
- const int i1 = ik1;
15738
-
15739
- ggml_vec_dot_f16_unroll(neq0, nbk1,
15740
- S + i1,
15741
- ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)),
15742
- (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
15743
- }
15744
- }
15745
-
15746
- // scale
15747
- ggml_vec_scale_f32(nek1, S, scale);
15748
-
15749
- if (masked) {
15750
- for (int64_t i = P; i < M; i++) {
15751
- if (i > P + iq1) {
15752
- S[i] = -INFINITY;
15753
- }
15754
- }
15755
- }
15756
-
15757
- // softmax
15758
- // todo: exclude known -INF S[..] values from max and loop, assuming their results to be zero.
15759
- // dont forget to set their S values to zero
15760
- {
15761
- float max = -INFINITY;
15762
- ggml_vec_max_f32(M, &max, S);
15763
-
15764
- ggml_float sum = 0.0;
15765
- {
15766
- #ifdef GGML_SOFT_MAX_ACCELERATE
15767
- max = -max;
15768
- vDSP_vsadd(S, 1, &max, S, 1, Mup);
15769
- vvexpf(S, S, &Mup);
15770
- ggml_vec_sum_f32(Mup, &sum, S);
15771
- #else
15772
- sum = ggml_vec_soft_max_f32(Mup, S, S, max);
15773
- #endif
15774
- }
15775
-
15776
- assert(sum > 0.0);
15777
-
15778
- sum = 1.0/sum;
15779
- ggml_vec_scale_f32(M, S, sum);
15780
-
15781
- #ifndef NDEBUG
15782
- for (int i = 0; i < M; ++i) {
15783
- assert(!isnan(S[i]));
15784
- assert(!isinf(S[i]));
15785
- }
15786
- #endif
15787
- }
15788
-
15789
- ggml_fp16_t * S16 = (ggml_fp16_t *) ((float *) params->wdata + ith*(2*Mup + CACHE_LINE_SIZE_F32) + Mup);
15790
-
15791
- for (int64_t i = 0; i < M; i++) {
15792
- S16[i] = GGML_FP32_TO_FP16(S[i]);
15793
- }
15794
-
15795
- // todo: exclude known zero S[..] values from dot (reducing nev0 and increasing begin of v and S16).
15796
- if (GGML_VEC_DOT_UNROLL == 1 || (nev1 % GGML_VEC_DOT_UNROLL != 0)) {
15797
- for (int64_t ic = 0; ic < nev1; ++ic) {
15798
- // dst indices
15799
- const int i1 = iq1;
15800
- const int i2 = iq2;
15801
- const int i3 = iq3;
15802
-
15803
- // v indices
15804
- const int iv2 = iq2 % nev2;
15805
- const int iv3 = iq3;
15806
-
15807
- ggml_vec_dot_f16(nev0,
15808
- (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), 0,
15809
- (ggml_fp16_t *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)), 0,
15810
- S16, 0, 1);
15811
- }
15812
- } else {
15813
- for (int64_t ic = 0; ic < nev1; ic += GGML_VEC_DOT_UNROLL) {
15814
- // dst indices
15815
- const int i1 = iq1;
15816
- const int i2 = iq2;
15817
- const int i3 = iq3;
15818
-
15819
- // v indices
15820
- const int iv2 = iq2 % nev2;
15821
- const int iv3 = iq3;
15822
-
15823
- ggml_vec_dot_f16_unroll(nev0, nbv1,
15824
- (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
15825
- ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)),
15826
- S16);
15827
- }
15828
- }
15829
- }
15830
- }
15831
-
15832
- static void ggml_compute_forward_flash_attn(
15833
- const struct ggml_compute_params * params,
15834
- const bool masked,
15835
- struct ggml_tensor * dst) {
15836
-
15837
- const struct ggml_tensor * q = dst->src[0];
15838
-
15839
- switch (q->type) {
15840
- case GGML_TYPE_F16:
15841
- {
15842
- ggml_compute_forward_flash_attn_f16(params, masked, dst);
15843
- } break;
15844
- case GGML_TYPE_F32:
15845
- {
15846
- ggml_compute_forward_flash_attn_f32(params, masked, dst);
15847
- } break;
15848
- default:
15849
- {
15850
- GGML_ASSERT(false);
15851
- } break;
15852
- }
15853
- }
15854
-
15855
- // ggml_compute_forward_flash_attn_ext
15856
-
15857
- static void ggml_compute_forward_flash_attn_ext_f16(
15858
- const struct ggml_compute_params * params,
15859
- const struct ggml_tensor * q,
15860
- const struct ggml_tensor * k,
15861
- const struct ggml_tensor * v,
15862
- const struct ggml_tensor * mask,
15863
- struct ggml_tensor * dst) {
15864
- int64_t t0 = ggml_perf_time_us();
15865
- UNUSED(t0);
15866
-
15867
- GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
15868
- GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
15869
- GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
15870
- GGML_TENSOR_LOCALS(size_t, nbk, k, nb)
15871
- GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
15872
- GGML_TENSOR_LOCALS(size_t, nbv, v, nb)
15873
- GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
15874
- GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
15875
-
15876
- const int ith = params->ith;
15877
- const int nth = params->nth;
15878
-
15879
- const int64_t D = neq0;
15880
- const int64_t N = neq1;
15881
-
15882
- GGML_ASSERT(ne0 == D);
15883
- GGML_ASSERT(ne2 == N);
15884
-
15885
- GGML_ASSERT(nbq0 == sizeof(float));
15886
- GGML_ASSERT(nbk0 == sizeof(ggml_fp16_t));
15887
- GGML_ASSERT(nbv0 == sizeof(ggml_fp16_t));
15888
-
15889
- GGML_ASSERT(neq0 == D);
15890
- GGML_ASSERT(nek0 == D);
15891
- GGML_ASSERT(nev0 == D);
15683
+ GGML_ASSERT(nev0 == D);
15892
15684
 
15893
15685
  GGML_ASSERT(neq1 == N);
15894
15686
  GGML_ASSERT(nev0 == D);
@@ -15938,6 +15730,11 @@ static void ggml_compute_forward_flash_attn_ext_f16(
15938
15730
  const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
15939
15731
  const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
15940
15732
 
15733
+ enum ggml_type const k_vec_dot_type = type_traits[k->type].vec_dot_type;
15734
+ ggml_from_float_t const q_to_vec_dot = type_traits[k_vec_dot_type].from_float;
15735
+ ggml_vec_dot_t const kq_vec_dot = type_traits[k->type].vec_dot;
15736
+ ggml_to_float_t const v_to_float = type_traits[v->type].to_float;
15737
+
15941
15738
  // loop over n_batch and n_head
15942
15739
  for (int ir = ir0; ir < ir1; ++ir) {
15943
15740
  // q indices
@@ -15945,17 +15742,22 @@ static void ggml_compute_forward_flash_attn_ext_f16(
15945
15742
  const int iq2 = (ir - iq3*neq2*neq1)/neq1;
15946
15743
  const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
15947
15744
 
15948
- const uint32_t h = iq2; // head
15745
+ const uint32_t h = iq2; // head index
15949
15746
  const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
15950
15747
 
15951
- float S = 0.0f;
15952
- float M = -INFINITY;
15748
+ float S = 0.0f; // sum
15749
+ float M = -INFINITY; // maximum KQ value
15953
15750
 
15954
- float * V32 = (float *) params->wdata + ith*(2*D + CACHE_LINE_SIZE_F32);
15955
- ggml_fp16_t * Q16 = (ggml_fp16_t *) (V32); // reuse memory
15956
- ggml_fp16_t * V16 = (ggml_fp16_t *) (V32 + D);
15751
+ float * VKQ32 = (float *) params->wdata + ith*(3*D + CACHE_LINE_SIZE_F32); // FP32 VKQ accumulator
15752
+ float * V32 = (VKQ32 + 1*D); // (temporary) FP32 V buffer
15753
+ ggml_fp16_t * VKQ16 = (ggml_fp16_t *) (VKQ32 + 1*D); // (temporary) FP16 VKQ accumulator
15754
+ ggml_fp16_t * Q_q = (ggml_fp16_t *) (VKQ32 + 2*D); // (temporary) buffer for Q converted to quantized/FP16
15957
15755
 
15958
- memset(V16, 0, D*sizeof(ggml_fp16_t));
15756
+ if (v->type == GGML_TYPE_F16) {
15757
+ memset(VKQ16, 0, D*sizeof(ggml_fp16_t));
15758
+ } else {
15759
+ memset(VKQ32, 0, D*sizeof(float));
15760
+ }
15959
15761
 
15960
15762
  const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL;
15961
15763
 
@@ -15967,6 +15769,9 @@ static void ggml_compute_forward_flash_attn_ext_f16(
15967
15769
  const int iv3 = iq3 / rv3;
15968
15770
  const int iv2 = iq2 / rv2;
15969
15771
 
15772
+ const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3));
15773
+ q_to_vec_dot(pq, Q_q, D);
15774
+
15970
15775
  // online softmax / attention
15971
15776
  // loop over n_kv and n_head_kv
15972
15777
  // ref: https://arxiv.org/pdf/2112.05682.pdf
@@ -15976,52 +15781,67 @@ static void ggml_compute_forward_flash_attn_ext_f16(
15976
15781
  continue;
15977
15782
  }
15978
15783
 
15979
- float s;
15784
+ float s; // KQ value
15980
15785
 
15981
- // convert Q to F16 in V32
15982
- {
15983
- const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3));
15786
+ const char * k_data = (const char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3);
15787
+ kq_vec_dot(D, &s, 0, k_data, 0, Q_q, 0, 1);
15984
15788
 
15985
- for (int64_t d = 0; d < D; ++d) {
15986
- Q16[d] = GGML_FP32_TO_FP16(pq[d]);
15987
- }
15988
- }
15789
+ s = s*scale + mv; // scale KQ value and apply mask
15989
15790
 
15990
- ggml_vec_dot_f16(D,
15991
- &s, 0,
15992
- (ggml_fp16_t *) ((char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3)), 0,
15993
- Q16, 0, 1);
15791
+ const float Mold = M;
15994
15792
 
15995
- s = s*scale + mv;
15793
+ float ms = 1.0f; // upon new higher max val, scale VKQ and KQ sum with this value
15794
+ float vs = 1.0f; // post-softmax KQ value, expf(s - M)
15996
15795
 
15997
- const float Mold = M;
15796
+ const char * v_data = ((const char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3));
15998
15797
 
15999
- float ms = 1.0f;
16000
- float vs = 1.0f;
15798
+ if (v->type== GGML_TYPE_F16) {
15799
+ if (s > M) {
15800
+ // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
15801
+ M = s;
15802
+ ms = expf(Mold - M);
16001
15803
 
16002
- if (s > M) {
16003
- M = s;
16004
- ms = expf(Mold - M);
15804
+ // V = V*expf(Mold - M)
15805
+ ggml_vec_scale_f16(D, VKQ16, ms);
15806
+ } else {
15807
+ // no new maximum, ms == 1.0f, vs != 1.0f
15808
+ vs = expf(s - M);
15809
+ }
16005
15810
 
16006
- // V = V*expf(Mold - M)
16007
- ggml_vec_scale_f16(D, V16, ms);
15811
+ // V += v*expf(s - M)
15812
+ ggml_vec_mad_f16(D, VKQ16, (const ggml_fp16_t *) v_data, vs);
16008
15813
  } else {
16009
- vs = expf(s - M);
16010
- }
15814
+ if (s > M) {
15815
+ // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
15816
+ M = s;
15817
+ ms = expf(Mold - M);
16011
15818
 
16012
- const ggml_fp16_t * v16 = (const ggml_fp16_t *) ((char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3));
15819
+ // V = V*expf(Mold - M)
15820
+ ggml_vec_scale_f32(D, VKQ32, ms);
15821
+ } else {
15822
+ // no new maximum, ms == 1.0f, vs != 1.0f
15823
+ vs = expf(s - M);
15824
+ }
16013
15825
 
16014
- // V += v*expf(s - M)
16015
- ggml_vec_mad_f16(D, V16, v16, vs);
15826
+ v_to_float(v_data, V32, D);
16016
15827
 
16017
- S = S*ms + vs;
15828
+ // V += v*expf(s - M)
15829
+ ggml_vec_mad_f32(D, VKQ32, V32, vs);
15830
+ }
15831
+
15832
+ S = S*ms + vs; // scale and increment sum with partial sum
16018
15833
  }
16019
15834
 
16020
- // V /= S
16021
- for (int64_t d = 0; d < D; ++d) {
16022
- V32[d] = GGML_FP16_TO_FP32(V16[d])/S;
15835
+ if (v->type == GGML_TYPE_F16) {
15836
+ for (int64_t d = 0; d < D; ++d) {
15837
+ VKQ32[d] = GGML_FP16_TO_FP32(VKQ16[d]);
15838
+ }
16023
15839
  }
16024
15840
 
15841
+ // V /= S
15842
+ const float S_inv = 1.0f/S;
15843
+ ggml_vec_scale_f32(D, VKQ32, S_inv);
15844
+
16025
15845
  // dst indices
16026
15846
  const int i1 = iq1;
16027
15847
  const int i2 = iq2;
@@ -16031,7 +15851,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
16031
15851
  //memcpy((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3), V, nev0*sizeof(float));
16032
15852
 
16033
15853
  // permute(0, 2, 1, 3)
16034
- memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, V32, nb1);
15854
+ memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32, nb1);
16035
15855
  }
16036
15856
  }
16037
15857
 
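Functionally the rewritten loop keeps the same one-pass (online) softmax recurrence; what changes is that Q is converted once per row via the K type's from_float, K·Q goes through the type-trait vec_dot, and the accumulator is either F16 (VKQ16) or F32 (VKQ32) depending on the V type. Stripped of SIMD, masking and types, the recurrence is (a sketch):

// Sketch (not part of the diff) of the online softmax-weighted sum used above.
#include <math.h>
#include <stdint.h>
#include <string.h>

// scores[ic] : already scaled + masked K*Q value for KV position ic
// V          : row-major [n_kv][D] value rows; out : length D
static void online_softmax_weighted_sum(int64_t n_kv, int64_t D,
                                        const float * scores, const float * V, float * out) {
    float M = -INFINITY;  // running maximum of the scores
    float S = 0.0f;       // running sum of expf(score - M)
    memset(out, 0, (size_t) D*sizeof(float));

    for (int64_t ic = 0; ic < n_kv; ++ic) {
        const float s  = scores[ic];
        float       ms = 1.0f;        // rescale factor when a new maximum appears
        float       vs = 1.0f;        // expf(s - M) for this position
        if (s > M) {
            ms = expf(M - s);         // previously accumulated mass shrinks
            M  = s;
            for (int64_t d = 0; d < D; ++d) out[d] *= ms;
        } else {
            vs = expf(s - M);
        }
        for (int64_t d = 0; d < D; ++d) out[d] += V[ic*D + d]*vs;
        S = S*ms + vs;
    }

    const float S_inv = 1.0f/S;       // final V /= S normalization
    for (int64_t d = 0; d < D; ++d) out[d] *= S_inv;
}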
@@ -16056,165 +15876,6 @@ static void ggml_compute_forward_flash_attn_ext(
16056
15876
  }
16057
15877
  }
16058
15878
 
16059
- // ggml_compute_forward_flash_ff
16060
-
16061
- static void ggml_compute_forward_flash_ff_f16(
16062
- const struct ggml_compute_params * params,
16063
- struct ggml_tensor * dst) {
16064
-
16065
- const struct ggml_tensor * a = dst->src[0]; // F16
16066
- const struct ggml_tensor * b0 = dst->src[1]; // F16 fc_w
16067
- const struct ggml_tensor * b1 = dst->src[2]; // F32 fc_b
16068
- const struct ggml_tensor * c0 = dst->src[3]; // F16 proj_w
16069
- const struct ggml_tensor * c1 = dst->src[4]; // F32 proj_b
16070
-
16071
- int64_t t0 = ggml_perf_time_us();
16072
- UNUSED(t0);
16073
-
16074
- GGML_TENSOR_LOCALS(int64_t, nea, a, ne)
16075
- GGML_TENSOR_LOCALS(size_t, nba, a, nb)
16076
- GGML_TENSOR_LOCALS(int64_t, neb0, b0, ne)
16077
- GGML_TENSOR_LOCALS(size_t, nbb0, b0, nb)
16078
- GGML_TENSOR_LOCALS(int64_t, neb1, b1, ne)
16079
- GGML_TENSOR_LOCALS(size_t, nbb1, b1, nb)
16080
- GGML_TENSOR_LOCALS(int64_t, nec0, c0, ne)
16081
- GGML_TENSOR_LOCALS(size_t, nbc0, c0, nb)
16082
- GGML_TENSOR_LOCALS(int64_t, nec1, c1, ne)
16083
- GGML_TENSOR_LOCALS(size_t, nbc1, c1, nb)
16084
- GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
16085
- GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
16086
-
16087
- const int ith = params->ith;
16088
- const int nth = params->nth;
16089
-
16090
- const int64_t D = nea0;
16091
- //const int64_t N = nea1;
16092
- const int64_t M = neb01;
16093
-
16094
- GGML_ASSERT(ne0 == nea0);
16095
- GGML_ASSERT(ne1 == nea1);
16096
- GGML_ASSERT(ne2 == nea2);
16097
-
16098
- GGML_ASSERT(nba0 == sizeof(ggml_fp16_t));
16099
- GGML_ASSERT(nbb00 == sizeof(ggml_fp16_t));
16100
- GGML_ASSERT(nbb10 == sizeof(float));
16101
- GGML_ASSERT(nbc00 == sizeof(ggml_fp16_t));
16102
- GGML_ASSERT(nbc10 == sizeof(float));
16103
-
16104
- GGML_ASSERT(neb00 == D);
16105
- GGML_ASSERT(neb01 == M);
16106
- GGML_ASSERT(neb10 == M);
16107
- GGML_ASSERT(neb11 == 1);
16108
-
16109
- GGML_ASSERT(nec00 == M);
16110
- GGML_ASSERT(nec01 == D);
16111
- GGML_ASSERT(nec10 == D);
16112
- GGML_ASSERT(nec11 == 1);
16113
-
16114
- // dst cannot be transposed or permuted
16115
- GGML_ASSERT(nb0 == sizeof(float));
16116
- GGML_ASSERT(nb0 <= nb1);
16117
- GGML_ASSERT(nb1 <= nb2);
16118
- GGML_ASSERT(nb2 <= nb3);
16119
-
16120
- if (params->type == GGML_TASK_TYPE_INIT) {
16121
- return;
16122
- }
16123
-
16124
- if (params->type == GGML_TASK_TYPE_FINALIZE) {
16125
- return;
16126
- }
16127
-
16128
- // parallelize by a rows using ggml_vec_dot_f32
16129
-
16130
- // total rows in a
16131
- const int nr = nea1*nea2*nea3;
16132
-
16133
- // rows per thread
16134
- const int dr = (nr + nth - 1)/nth;
16135
-
16136
- // row range for this thread
16137
- const int ir0 = dr*ith;
16138
- const int ir1 = MIN(ir0 + dr, nr);
16139
-
16140
- for (int ir = ir0; ir < ir1; ++ir) {
16141
- // a indices
16142
- const int ia3 = ir/(nea2*nea1);
16143
- const int ia2 = (ir - ia3*nea2*nea1)/nea1;
16144
- const int ia1 = (ir - ia3*nea2*nea1 - ia2*nea1);
16145
-
16146
- float * S = (float *) params->wdata + ith*(2*M + CACHE_LINE_SIZE_F32);
16147
-
16148
- for (int64_t ic = 0; ic < neb01; ++ic) {
16149
- // b0 indices
16150
- const int ib03 = ia3;
16151
- const int ib02 = ia2;
16152
- const int ib01 = ic;
16153
-
16154
- // S indices
16155
- const int i1 = ib01;
16156
-
16157
- ggml_vec_dot_f16(nea0,
16158
- S + i1, 0,
16159
- (ggml_fp16_t *) ((char *) b0->data + (ib01*nbb01 + ib02*nbb02 + ib03*nbb03)), 0,
16160
- (ggml_fp16_t *) ((char *) a->data + ( ia1*nba1 + ia2*nba2 + ia3*nba3)), 0, 1);
16161
- }
16162
-
16163
- ggml_vec_add_f32(neb01, S, S, (float *) b1->data);
16164
- //ggml_vec_gelu_f32(neb01, S, S);
16165
-
16166
- ggml_fp16_t * S16 = (ggml_fp16_t *) ((float *) params->wdata + ith*(2*M + CACHE_LINE_SIZE_F32) + M);
16167
-
16168
- for (int64_t i = 0; i < M; i++) {
16169
- S16[i] = GGML_FP32_TO_FP16(S[i]);
16170
- }
16171
-
16172
- ggml_vec_gelu_f16(neb01, S16, S16);
16173
-
16174
- {
16175
- // dst indices
16176
- const int i1 = ia1;
16177
- const int i2 = ia2;
16178
- const int i3 = ia3;
16179
-
16180
- for (int64_t ic = 0; ic < nec01; ++ic) {
16181
-
16182
- ggml_vec_dot_f16(neb01,
16183
- (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), 0,
16184
- (ggml_fp16_t *) ((char *) c0->data + ( ic*nbc01 + i2*nbc02 + i3*nbc03)), 0,
16185
- S16, 0, 1);
16186
- }
16187
-
16188
- ggml_vec_add_f32(nec01,
16189
- (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3)),
16190
- (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3)),
16191
- (float *) c1->data);
16192
- }
16193
- }
16194
- }
16195
-
16196
- static void ggml_compute_forward_flash_ff(
16197
- const struct ggml_compute_params * params,
16198
- struct ggml_tensor * dst) {
16199
-
16200
- const struct ggml_tensor * b0 = dst->src[1];
16201
-
16202
- switch (b0->type) {
16203
- case GGML_TYPE_F16:
16204
- {
16205
- ggml_compute_forward_flash_ff_f16(params, dst);
16206
- } break;
16207
- case GGML_TYPE_F32:
16208
- {
16209
- GGML_ASSERT(false); // TODO
16210
- } break;
16211
- default:
16212
- {
16213
- GGML_ASSERT(false);
16214
- } break;
16215
- }
16216
- }
16217
-
16218
15879
  // ggml_compute_forward_flash_attn_back
16219
15880
 
16220
15881
  static void ggml_compute_forward_flash_attn_back_f32(
@@ -17785,21 +17446,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
17785
17446
  {
17786
17447
  ggml_compute_forward_leaky_relu(params, tensor);
17787
17448
  } break;
17788
- case GGML_OP_FLASH_ATTN:
17789
- {
17790
- const int32_t t = ggml_get_op_params_i32(tensor, 0);
17791
- GGML_ASSERT(t == 0 || t == 1);
17792
- const bool masked = t != 0;
17793
- ggml_compute_forward_flash_attn(params, masked, tensor);
17794
- } break;
17795
17449
  case GGML_OP_FLASH_ATTN_EXT:
17796
17450
  {
17797
17451
  ggml_compute_forward_flash_attn_ext(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor);
17798
17452
  } break;
17799
- case GGML_OP_FLASH_FF:
17800
- {
17801
- ggml_compute_forward_flash_ff(params, tensor);
17802
- } break;
17803
17453
  case GGML_OP_FLASH_ATTN_BACK:
17804
17454
  {
17805
17455
  int32_t t = ggml_get_op_params_i32(tensor, 0);
@@ -18169,6 +17819,7 @@ static struct ggml_tensor * ggml_sub_or_set(struct ggml_context * ctx, struct gg
18169
17819
  static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor * tensor, struct ggml_hash_set zero_table) {
18170
17820
  struct ggml_tensor * src0 = tensor->src[0];
18171
17821
  struct ggml_tensor * src1 = tensor->src[1];
17822
+ struct ggml_tensor * src2 = tensor->src[2];
18172
17823
 
18173
17824
  switch (tensor->op) {
18174
17825
  case GGML_OP_DUP:
@@ -18700,6 +18351,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
18700
18351
  ggml_rope_back(ctx,
18701
18352
  tensor->grad,
18702
18353
  src1,
18354
+ src2,
18703
18355
  n_dims,
18704
18356
  mode,
18705
18357
  n_ctx,
@@ -18739,6 +18391,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
18739
18391
  ggml_rope_impl(ctx,
18740
18392
  tensor->grad,
18741
18393
  src1,
18394
+ src2,
18742
18395
  n_dims,
18743
18396
  mode,
18744
18397
  n_ctx,
@@ -18803,7 +18456,6 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
18803
18456
  {
18804
18457
  GGML_ASSERT(false); // TODO: not implemented
18805
18458
  } break;
18806
- case GGML_OP_FLASH_ATTN:
18807
18459
  case GGML_OP_FLASH_ATTN_EXT:
18808
18460
  {
18809
18461
  struct ggml_tensor * flash_grad = NULL;
@@ -18820,7 +18472,6 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
18820
18472
  masked);
18821
18473
  }
18822
18474
 
18823
- struct ggml_tensor * src2 = tensor->src[2];
18824
18475
  const int64_t elem_q = ggml_nelements(src0);
18825
18476
  const int64_t elem_k = ggml_nelements(src1);
18826
18477
  const int64_t elem_v = ggml_nelements(src2);
@@ -18858,10 +18509,6 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
18858
18509
  zero_table);
18859
18510
  }
18860
18511
  } break;
18861
- case GGML_OP_FLASH_FF:
18862
- {
18863
- GGML_ASSERT(false); // not supported
18864
- } break;
18865
18512
  case GGML_OP_FLASH_ATTN_BACK:
18866
18513
  {
18867
18514
  GGML_ASSERT(false); // not supported
@@ -19548,15 +19195,10 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads, int n_cur_
19548
19195
  {
19549
19196
  n_tasks = n_threads;
19550
19197
  } break;
19551
- case GGML_OP_FLASH_ATTN:
19552
19198
  case GGML_OP_FLASH_ATTN_EXT:
19553
19199
  {
19554
19200
  n_tasks = n_threads;
19555
19201
  } break;
19556
- case GGML_OP_FLASH_FF:
19557
- {
19558
- n_tasks = n_threads;
19559
- } break;
19560
19202
  case GGML_OP_FLASH_ATTN_BACK:
19561
19203
  {
19562
19204
  n_tasks = n_threads;
@@ -19953,39 +19595,11 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
19953
19595
  cur += sizeof(ggml_fp16_t)*ne00*ne01*ne02*ne03;
19954
19596
  cur += sizeof(ggml_fp16_t)*ne10*ne11*ne12;
19955
19597
  } break;
19956
- case GGML_OP_FLASH_ATTN:
19957
- {
19958
- const int64_t ne11 = ggml_up(node->src[1]->ne[1], GGML_SOFT_MAX_UNROLL);
19959
-
19960
- if (node->src[1]->type == GGML_TYPE_F32) {
19961
- cur = sizeof(float)*ne11*n_tasks; // TODO: this can become (n_tasks-1)
19962
- cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2
19963
- } else if (node->src[1]->type == GGML_TYPE_F16) {
19964
- cur = sizeof(float)*ne11*n_tasks; // TODO: this can become (n_tasks-1)
19965
- cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2
19966
- } else if (node->src[1]->type == GGML_TYPE_BF16) {
19967
- cur = sizeof(float)*ne11*n_tasks; // TODO: this can become (n_tasks-1)
19968
- cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2
19969
- }
19970
- } break;
19971
19598
  case GGML_OP_FLASH_ATTN_EXT:
19972
19599
  {
19973
19600
  const int64_t ne00 = node->src[0]->ne[0]; // D
19974
19601
 
19975
- cur = 2*sizeof(float)*ne00*n_tasks; // 2x head size
19976
- } break;
19977
- case GGML_OP_FLASH_FF:
19978
- {
19979
- if (node->src[1]->type == GGML_TYPE_F32) {
19980
- cur = sizeof(float)*node->src[1]->ne[1]*n_tasks; // TODO: this can become (n_tasks-1)
19981
- cur += sizeof(float)*node->src[1]->ne[1]*n_tasks; // this is overestimated by x2
19982
- } else if (node->src[1]->type == GGML_TYPE_F16) {
19983
- cur = sizeof(float)*node->src[1]->ne[1]*n_tasks; // TODO: this can become (n_tasks-1)
19984
- cur += sizeof(float)*node->src[1]->ne[1]*n_tasks; // this is overestimated by x2
19985
- } else if (node->src[1]->type == GGML_TYPE_BF16) {
19986
- cur = sizeof(float)*node->src[1]->ne[1]*n_tasks; // TODO: this can become (n_tasks-1)
19987
- cur += sizeof(float)*node->src[1]->ne[1]*n_tasks; // this is overestimated by x2
19988
- }
19602
+ cur = 3*sizeof(float)*ne00*n_tasks; // 3x head size/thread
19989
19603
  } break;
19990
19604
  case GGML_OP_FLASH_ATTN_BACK:
19991
19605
  {
@@ -21827,11 +21441,7 @@ size_t ggml_quantize_chunk(
  case GGML_TYPE_IQ1_S: result = quantize_iq1_s (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
  case GGML_TYPE_IQ1_M: result = quantize_iq1_m (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
  case GGML_TYPE_IQ4_NL: result = quantize_iq4_nl (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
- #if QK_K == 64
- case GGML_TYPE_IQ4_XS: result = quantize_iq4_nl (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
- #else
  case GGML_TYPE_IQ4_XS: result = quantize_iq4_xs (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
- #endif
  case GGML_TYPE_F16:
  {
  size_t elemsize = sizeof(ggml_fp16_t);
@@ -23108,6 +22718,14 @@ int ggml_cpu_has_avx512_vnni(void) {
  #endif
  }

+ int ggml_cpu_has_avx512_bf16(void) {
+ #if defined(__AVX512BF16__)
+ return 1;
+ #else
+ return 0;
+ #endif
+ }
+
  int ggml_cpu_has_fma(void) {
  #if defined(__FMA__)
  return 1;
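The new probe mirrors the existing ggml_cpu_has_* helpers and reports a compile-time decision, not a runtime CPUID check. A sketch of using it (assuming the declaration is exported from ggml.h next to the other probes, as the definition here suggests):

// Sketch (not part of the diff): reporting which SIMD paths this ggml build contains.
#include "ggml.h"
#include <stdio.h>

int main(void) {
    printf("AVX512-VNNI: %d\n", ggml_cpu_has_avx512_vnni());
    printf("AVX512-BF16: %d\n", ggml_cpu_has_avx512_bf16());
    printf("FMA:         %d\n", ggml_cpu_has_fma());
    return 0;
}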