llama_cpp 0.15.2 → 0.15.3

@@ -406,10 +406,10 @@ void ggml_fp32_to_bf16_row(const float * x, ggml_bf16_t * y, int64_t n) {
  int i = 0;
  #if defined(__AVX512BF16__)
  for (; i + 32 <= n; i += 32) {
- _mm512_storeu_ps(
- (__m512 *)(y + i),
- (__m512)_mm512_cvtne2ps_pbh(_mm512_loadu_ps(x + i + 16),
- _mm512_loadu_ps(x + i)));
+ _mm512_storeu_si512(
+ (__m512i *)(y + i),
+ m512i(_mm512_cvtne2ps_pbh(_mm512_loadu_ps(x + i + 16),
+ _mm512_loadu_ps(x + i))));
  }
  #endif
  for (; i < n; i++) {
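The loop now stores the `__m512bh` result through an integer store via the `m512i` helper instead of punning through `__m512`. For reference, a minimal scalar sketch of the same FP32→BF16 conversion the VCVTNE2PS2BF16 instruction performs 32 elements at a time (round-to-nearest-even, NaN kept quiet); standalone code, not the ggml macro itself:

```c
#include <stdint.h>
#include <string.h>

static uint16_t fp32_to_bf16_rne(float f) {
    uint32_t u;
    memcpy(&u, &f, sizeof(u));           // bit-level view of the float
    if ((u & 0x7fffffff) > 0x7f800000)   // NaN: keep it quiet, drop payload bits
        return (uint16_t)((u >> 16) | 64);
    // add the round-to-nearest-even bias, then keep the upper 16 bits
    return (uint16_t)((u + (0x7fff + ((u >> 16) & 1))) >> 16);
}
```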
@@ -871,22 +871,14 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
  },
  [GGML_TYPE_IQ4_XS] = {
  .type_name = "iq4_xs",
- #if QK_K == 64
- .blck_size = QK4_NL,
- #else
  .blck_size = QK_K,
- #endif
  .type_size = sizeof(block_iq4_xs),
  .is_quantized = true,
  .to_float = (ggml_to_float_t) dequantize_row_iq4_xs,
  .from_float = quantize_row_iq4_xs,
  .from_float_reference = (ggml_from_float_t)quantize_row_iq4_xs_reference,
  .vec_dot = ggml_vec_dot_iq4_xs_q8_K,
- #if QK_K == 64
- .vec_dot_type = GGML_TYPE_Q8_0,
- #else
  .vec_dot_type = GGML_TYPE_Q8_K,
- #endif
  .nrows = 1,
  },
  [GGML_TYPE_Q8_K] = {
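This hunk drops the `QK_K == 64` special case, so `iq4_xs` always uses the full QK_K super-block and dots against `GGML_TYPE_Q8_K`. A toy sketch of the trait-table pattern these fields feed (illustrative names only, not the real ggml structs):

```c
#include <stdio.h>

enum toy_type { TOY_F32, TOY_IQ4_XS, TOY_Q8_K, TOY_TYPE_COUNT };

struct toy_traits {
    const char   *type_name;
    int           blck_size;
    enum toy_type vec_dot_type;  // operand type the dot-product kernel expects
};

static const struct toy_traits toy_traits_table[TOY_TYPE_COUNT] = {
    [TOY_F32]    = { "f32",    1,   TOY_F32  },
    [TOY_IQ4_XS] = { "iq4_xs", 256, TOY_Q8_K },  // QK_K, now that the 64 variant is gone
    [TOY_Q8_K]   = { "q8_k",   256, TOY_Q8_K },
};

int main(void) {
    const struct toy_traits *t = &toy_traits_table[TOY_IQ4_XS];
    printf("%s: blck_size=%d, dot against type %d\n",
           t->type_name, t->blck_size, (int) t->vec_dot_type);
    return 0;
}
```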
@@ -1523,6 +1515,195 @@ static inline void __sse_f16x4_store(ggml_fp16_t *x, __m128 y) {
  #define GGML_F16_VEC_MUL GGML_F32Cx4_MUL
  #define GGML_F16_VEC_REDUCE GGML_F32Cx4_REDUCE
 
+ #elif defined(__loongarch_asx)
+
+ #define GGML_SIMD
+
+ // F32 LASX
+ #define GGML_F32_STEP 32
+ #define GGML_F32_EPR 8
+
+ #define GGML_F32x8 __m256
+ #define GGML_F32x8_ZERO (__m256)__lasx_xvldi(0)
+ #define GGML_F32x8_SET1(x) (__m256)__lasx_xvreplfr2vr_s((x))
+ #define GGML_F32x8_LOAD(x) (__m256)__lasx_xvld((x), 0)
+ #define GGML_F32x8_STORE(x,y) __lasx_xvst((y), (x), 0)
+ #define GGML_F32x8_FMA(a, b, c) __lasx_xvfmadd_s(b, c, a)
+ #define GGML_F32x8_ADD __lasx_xvfadd_s
+ #define GGML_F32x8_MUL __lasx_xvfmul_s
+ #define GGML_F32x8_REDUCE(res, x) \
+ do { \
+ int offset = GGML_F32_ARR >> 1; \
+ for (int i = 0; i < offset; ++i) { \
+ x[i] = __lasx_xvfadd_s(x[i], x[offset+i]); \
+ } \
+ offset >>= 1; \
+ for (int i = 0; i < offset; ++i) { \
+ x[i] = __lasx_xvfadd_s(x[i], x[offset+i]); \
+ } \
+ offset >>= 1; \
+ for (int i = 0; i < offset; ++i) { \
+ x[i] = __lasx_xvfadd_s(x[i], x[offset+i]); \
+ } \
+ float *tmp_p = (float *)&x[0]; \
+ res = tmp_p[0] + tmp_p[1] + tmp_p[2] + tmp_p[3] + tmp_p[4] + tmp_p[5] + tmp_p[6] + tmp_p[7]; \
+ } while (0)
+ // TODO: is this optimal ?
+
+ #define GGML_F32_VEC GGML_F32x8
+ #define GGML_F32_VEC_ZERO GGML_F32x8_ZERO
+ #define GGML_F32_VEC_SET1 GGML_F32x8_SET1
+ #define GGML_F32_VEC_LOAD GGML_F32x8_LOAD
+ #define GGML_F32_VEC_STORE GGML_F32x8_STORE
+ #define GGML_F32_VEC_FMA GGML_F32x8_FMA
+ #define GGML_F32_VEC_ADD GGML_F32x8_ADD
+ #define GGML_F32_VEC_MUL GGML_F32x8_MUL
+ #define GGML_F32_VEC_REDUCE GGML_F32x8_REDUCE
+
+ // F16 LASX
+
+ #define GGML_F16_STEP 32
+ #define GGML_F16_EPR 8
+
+ // F16 arithmetic is not supported by AVX, so we use F32 instead
+
+ #define GGML_F32Cx8 __m256
+ #define GGML_F32Cx8_ZERO (__m256)__lasx_xvldi(0)
+ #define GGML_F32Cx8_SET1(x) (__m256)__lasx_xvreplgr2vr_w((x))
+
+ static inline __m256 __lasx_f32cx8_load(ggml_fp16_t *x) {
+ float tmp[8];
+
+ for (int i = 0; i < 8; i++) {
+ tmp[i] = GGML_FP16_TO_FP32(x[i]);
+ }
+
+ return (__m256)__lasx_xvld(tmp, 0);
+ }
+ static inline void __lasx_f32cx8_store(ggml_fp16_t *x, __m256 y) {
+ float arr[8];
+
+ __lasx_xvst(y, arr, 0);
+
+ for (int i = 0; i < 8; i++)
+ x[i] = GGML_FP32_TO_FP16(arr[i]);
+ }
+ #define GGML_F32Cx8_LOAD(x) __lasx_f32cx8_load(x)
+ #define GGML_F32Cx8_STORE(x, y) __lasx_f32cx8_store(x, y)
+
+ #define GGML_F32Cx8_FMA GGML_F32x8_FMA
+ #define GGML_F32Cx8_ADD __lasx_xvfadd_s
+ #define GGML_F32Cx8_MUL __lasx_xvfmul_s
+ #define GGML_F32Cx8_REDUCE GGML_F32x8_REDUCE
+
+ #define GGML_F16_VEC GGML_F32Cx8
+ #define GGML_F16_VEC_ZERO GGML_F32Cx8_ZERO
+ #define GGML_F16_VEC_SET1 GGML_F32Cx8_SET1
+ #define GGML_F16_VEC_LOAD(p, i) GGML_F32Cx8_LOAD(p)
+ #define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx8_STORE(p, r[i])
+ #define GGML_F16_VEC_FMA GGML_F32Cx8_FMA
+ #define GGML_F16_VEC_ADD GGML_F32Cx8_ADD
+ #define GGML_F16_VEC_MUL GGML_F32Cx8_MUL
+ #define GGML_F16_VEC_REDUCE GGML_F32Cx8_REDUCE
+
+ #elif defined(__loongarch_sx)
+
+ #define GGML_SIMD
+
+ // F32 LSX
+
+ #define GGML_F32_STEP 32
+ #define GGML_F32_EPR 4
+
+ #define GGML_F32x4 __m128
+ #define GGML_F32x4_ZERO __lsx_vldi(0)
+ #define GGML_F32x4_SET1(x) __lsx_vinsgr2vr_w(__lsx_vldi(0),(x), 0)
+ #define GGML_F32x4_LOAD(x) __lsx_vld((x), 0)
+ #define GGML_F32x4_STORE((x),(y)) __lsx_vst((y), (x), 0)
+ #define GGML_F32x4_FMA(a, b, c) __lsx_vfmadd_s(b, c, a)
+ #define GGML_F32x4_ADD __lsx_vfadd_s
+ #define GGML_F32x4_MUL __lsx_vfmul_s
+ #define GGML_F32x4_REDUCE(res, x) \
+ { \
+ int offset = GGML_F32_ARR >> 1; \
+ for (int i = 0; i < offset; ++i) { \
+ x[i] = __lsx_vfadd_s(x[i], x[offset+i]); \
+ } \
+ offset >>= 1; \
+ for (int i = 0; i < offset; ++i) { \
+ x[i] = __lsx_vfadd_s(x[i], x[offset+i]); \
+ } \
+ offset >>= 1; \
+ for (int i = 0; i < offset; ++i) { \
+ x[i] = __lsx_vfadd_s(x[i], x[offset+i]); \
+ } \
+ __m128i tmp = __lsx_vsrli_d((__m128i)x[0], 32); \
+ tmp = (__m128i)__lsx_vfadd_s((__m128)tmp, x[0]); \
+ tmp = __lsx_vpickev_w(__lsx_vldi(0), tmp); \
+ const __m128 t0 = __lsx_vshuf4i_w(tmp, 0x88); \
+ tmp = __lsx_vsrli_d((__m128i)t0, 32); \
+ tmp = (__m128i)__lsx_vfadd_s((__m128)tmp, t0); \
+ tmp = __lsx_vpickev_w(__lsx_vldi(0), tmp); \
+ res = (ggml_float) __lsx_vpickve2gr_w(__lsx_vshuf4i_w(tmp, 0x88), 0); \
+ }
+
+ #define GGML_F32_VEC GGML_F32x4
+ #define GGML_F32_VEC_ZERO GGML_F32x4_ZERO
+ #define GGML_F32_VEC_SET1 GGML_F32x4_SET1
+ #define GGML_F32_VEC_LOAD GGML_F32x4_LOAD
+ #define GGML_F32_VEC_STORE GGML_F32x4_STORE
+ #define GGML_F32_VEC_FMA GGML_F32x4_FMA
+ #define GGML_F32_VEC_ADD GGML_F32x4_ADD
+ #define GGML_F32_VEC_MUL GGML_F32x4_MUL
+ #define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE
+
+ // F16 LSX
+
+ #define GGML_F16_STEP 32
+ #define GGML_F16_EPR 4
+
+ static inline __m128 __lsx_f16x4_load(ggml_fp16_t *x) {
+ float tmp[4];
+
+ tmp[0] = GGML_FP16_TO_FP32(x[0]);
+ tmp[1] = GGML_FP16_TO_FP32(x[1]);
+ tmp[2] = GGML_FP16_TO_FP32(x[2]);
+ tmp[3] = GGML_FP16_TO_FP32(x[3]);
+
+ return __lsx_vld(tmp, 0);
+ }
+
+ static inline void __lsx_f16x4_store(ggml_fp16_t *x, __m128 y) {
+ float arr[4];
+
+ __lsx_vst(y, arr, 0);
+
+ x[0] = GGML_FP32_TO_FP16(arr[0]);
+ x[1] = GGML_FP32_TO_FP16(arr[1]);
+ x[2] = GGML_FP32_TO_FP16(arr[2]);
+ x[3] = GGML_FP32_TO_FP16(arr[3]);
+ }
+
+ #define GGML_F32Cx4 __m128
+ #define GGML_F32Cx4_ZERO __lsx_vldi(0)
+ #define GGML_F32Cx4_SET1(x) __lsx_vinsgr2vr_w(__lsx_vldi(0),(x), 0)
+ #define GGML_F32Cx4_LOAD(x) __lsx_f16x4_load(x)
+ #define GGML_F32Cx4_STORE(x, y) __lsx_f16x4_store(x, y)
+ #define GGML_F32Cx4_FMA GGML_F32x4_FMA
+ #define GGML_F32Cx4_ADD __lsx_vfadd_s
+ #define GGML_F32Cx4_MUL __lsx_vfmul_s
+ #define GGML_F32Cx4_REDUCE GGML_F32x4_REDUCE
+
+ #define GGML_F16_VEC GGML_F32Cx4
+ #define GGML_F16_VEC_ZERO GGML_F32Cx4_ZERO
+ #define GGML_F16_VEC_SET1 GGML_F32Cx4_SET1
+ #define GGML_F16_VEC_LOAD(p, i) GGML_F32Cx4_LOAD(p)
+ #define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx4_STORE(p, r[i])
+ #define GGML_F16_VEC_FMA GGML_F32Cx4_FMA
+ #define GGML_F16_VEC_ADD GGML_F32Cx4_ADD
+ #define GGML_F16_VEC_MUL GGML_F32Cx4_MUL
+ #define GGML_F16_VEC_REDUCE GGML_F32Cx4_REDUCE
+
  #endif
 
  // GGML_F32_ARR / GGML_F16_ARR
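The REDUCE macros introduced above fold the per-loop accumulator array pairwise and then horizontally sum the lanes of the first vector. A plain-C sketch of what the LASX variant computes, assuming GGML_F32_ARR == 4 (GGML_F32_STEP / GGML_F32_EPR); lane order does not matter because only the total survives:

```c
static float f32x8_reduce_scalar(float x[4][8]) {
    // pairwise add: 4 accumulators -> 2 -> 1 (the macro's three halving loops)
    for (int step = 4 >> 1; step > 0; step >>= 1) {
        for (int i = 0; i < step; ++i)
            for (int l = 0; l < 8; ++l)
                x[i][l] += x[i + step][l];
    }
    // horizontal sum of the 8 lanes of x[0]
    float res = 0.0f;
    for (int l = 0; l < 8; ++l)
        res += x[0][l];
    return res;
}
```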
@@ -1666,10 +1847,10 @@ static void ggml_vec_dot_bf16(int n, float * restrict s, size_t bs, ggml_bf16_t
  __m512 c1 = _mm512_setzero_ps();
  __m512 c2 = _mm512_setzero_ps();
  for (; i + 64 <= n; i += 64) {
- c1 = _mm512_dpbf16_ps(c1, (__m512bh)_mm512_loadu_ps((const float *)(x + i)),
- (__m512bh)_mm512_loadu_ps((const float *)(y + i)));
- c2 = _mm512_dpbf16_ps(c2, (__m512bh)_mm512_loadu_ps((const float *)(x + i + 32)),
- (__m512bh)_mm512_loadu_ps((const float *)(y + i + 32)));
+ c1 = _mm512_dpbf16_ps(c1, m512bh(_mm512_loadu_si512((x + i))),
+ m512bh(_mm512_loadu_si512((y + i))));
+ c2 = _mm512_dpbf16_ps(c2, m512bh(_mm512_loadu_si512((x + i + 32))),
+ m512bh(_mm512_loadu_si512((y + i + 32))));
  }
  sumf += (ggml_float)_mm512_reduce_add_ps(c1);
  sumf += (ggml_float)_mm512_reduce_add_ps(c2);
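The BF16 dot product now loads through `_mm512_loadu_si512` and the `m512bh` helper instead of float loads and casts. A scalar reference for what each `_mm512_dpbf16_ps` step accumulates (sketch only; a bf16 value is treated as the top half of an f32, and the sum is kept in double to mirror ggml_float):

```c
#include <stdint.h>
#include <string.h>

static float bf16_to_fp32(uint16_t h) {
    uint32_t u = (uint32_t)h << 16;   // bf16 occupies the high 16 bits of an f32
    float f;
    memcpy(&f, &u, sizeof(f));
    return f;
}

static double vec_dot_bf16_ref(int n, const uint16_t *x, const uint16_t *y) {
    double sum = 0.0;
    for (int i = 0; i < n; ++i)
        sum += (double)(bf16_to_fp32(x[i]) * bf16_to_fp32(y[i]));
    return sum;
}
```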
@@ -2076,7 +2257,7 @@ inline static float ggml_silu_f32(float x) {
  return x/(1.0f + expf(-x));
  }
 
- #if defined(__ARM_NEON)
+ #if defined(__ARM_NEON) && defined(__aarch64__)
 
  // adapted from arm limited optimized routine
  // the maximum error is 1.45358 plus 0.5 ulps
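The vectorized exp/SiLU routines are now compiled only on AArch64, so 32-bit ARM takes the scalar path instead. That fallback is just this loop shape (reference sketch, not a new ggml function):

```c
#include <math.h>

static void vec_silu_f32_ref(int n, float *y, const float *x) {
    for (int i = 0; i < n; ++i) {
        y[i] = x[i] / (1.0f + expf(-x[i]));   // silu(x) = x * sigmoid(x)
    }
}
```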
@@ -2288,7 +2469,7 @@ static void ggml_vec_silu_f32(const int n, float * y, const float * x) {
  for (; i + 3 < n; i += 4) {
  _mm_storeu_ps(y + i, ggml_v_silu(_mm_loadu_ps(x + i)));
  }
- #elif defined(__ARM_NEON)
+ #elif defined(__ARM_NEON) && defined(__aarch64__)
  for (; i + 3 < n; i += 4) {
  vst1q_f32(y + i, ggml_v_silu(vld1q_f32(x + i)));
  }
@@ -2335,7 +2516,7 @@ static ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x,
  #endif
  sum += (ggml_float)_mm_cvtss_f32(val);
  }
- #elif defined(__ARM_NEON)
+ #elif defined(__ARM_NEON) && defined(__aarch64__)
  for (; i + 3 < n; i += 4) {
  float32x4_t val = ggml_v_expf(vsubq_f32(vld1q_f32(x + i),
  vdupq_n_f32(max)));
@@ -2489,9 +2670,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
  "ARGSORT",
  "LEAKY_RELU",
 
- "FLASH_ATTN",
  "FLASH_ATTN_EXT",
- "FLASH_FF",
  "FLASH_ATTN_BACK",
  "SSM_CONV",
  "SSM_SCAN",
@@ -2517,7 +2696,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
  "CROSS_ENTROPY_LOSS_BACK",
  };
 
- static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76");
+ static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74");
 
  static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
  "none",
@@ -2579,9 +2758,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
  "argsort(x)",
  "leaky_relu(x)",
 
- "flash_attn(x)",
  "flash_attn_ext(x)",
- "flash_ff(x)",
  "flash_attn_back(x)",
  "ssm_conv(x)",
  "ssm_scan(x)",
@@ -2607,7 +2784,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
  "cross_entropy_loss_back(x,y)",
  };
 
- static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76");
+ static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74");
 
  static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
 
@@ -6042,6 +6219,7 @@ static struct ggml_tensor * ggml_rope_impl(
  struct ggml_context * ctx,
  struct ggml_tensor * a,
  struct ggml_tensor * b,
+ struct ggml_tensor * c,
  int n_dims,
  int mode,
  int n_ctx,
@@ -6055,10 +6233,17 @@ static struct ggml_tensor * ggml_rope_impl(
  float xpos_base,
  bool xpos_down,
  bool inplace) {
+ GGML_ASSERT((mode & 1) == 0 && "mode & 1 == 1 is no longer supported");
+
  GGML_ASSERT(ggml_is_vector(b));
  GGML_ASSERT(b->type == GGML_TYPE_I32);
  GGML_ASSERT(a->ne[2] == b->ne[0]);
 
+ if (c) {
+ GGML_ASSERT(c->type == GGML_TYPE_F32);
+ GGML_ASSERT(c->ne[0] >= n_dims / 2);
+ }
+
  bool is_node = false;
 
  if (a->grad) {
@@ -6082,6 +6267,7 @@ static struct ggml_tensor * ggml_rope_impl(
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
  result->src[0] = a;
  result->src[1] = b;
+ result->src[2] = c;
 
  return result;
  }
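ggml_rope_impl() now takes an optional third tensor `c` holding per-dimension frequency factors (F32, at least n_dims/2 values) and records it as src[2]. A hedged sketch of what a caller would allocate for it; `make_freq_factors` is a hypothetical helper, not code from this diff:

```c
static struct ggml_tensor * make_freq_factors(struct ggml_context * ctx, int n_dims) {
    // one scaling factor per rotated dimension pair, as the asserts above require
    struct ggml_tensor * c = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_dims / 2);
    float * f = (float *) c->data;
    for (int i = 0; i < n_dims / 2; ++i) {
        f[i] = 1.0f;   // 1.0 reproduces the old behaviour; long-context models override this
    }
    return c;
}
```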
@@ -6094,7 +6280,7 @@ struct ggml_tensor * ggml_rope(
  int mode,
  int n_ctx) {
  return ggml_rope_impl(
- ctx, a, b, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, false
+ ctx, a, b, NULL, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, false
  );
  }
 
@@ -6106,14 +6292,15 @@ struct ggml_tensor * ggml_rope_inplace(
  int mode,
  int n_ctx) {
  return ggml_rope_impl(
- ctx, a, b, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, true
+ ctx, a, b, NULL, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, true
  );
  }
 
- struct ggml_tensor * ggml_rope_custom(
+ struct ggml_tensor * ggml_rope_ext(
  struct ggml_context * ctx,
  struct ggml_tensor * a,
  struct ggml_tensor * b,
+ struct ggml_tensor * c,
  int n_dims,
  int mode,
  int n_ctx,
@@ -6125,15 +6312,16 @@ struct ggml_tensor * ggml_rope_custom(
  float beta_fast,
  float beta_slow) {
  return ggml_rope_impl(
- ctx, a, b, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
+ ctx, a, b, c, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
  ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, false
  );
  }
 
- struct ggml_tensor * ggml_rope_custom_inplace(
+ struct ggml_tensor * ggml_rope_ext_inplace(
  struct ggml_context * ctx,
  struct ggml_tensor * a,
  struct ggml_tensor * b,
+ struct ggml_tensor * c,
  int n_dims,
  int mode,
  int n_ctx,
@@ -6145,19 +6333,49 @@ struct ggml_tensor * ggml_rope_custom_inplace(
  float beta_fast,
  float beta_slow) {
  return ggml_rope_impl(
- ctx, a, b, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
+ ctx, a, b, c, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
  ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, true
  );
  }
 
- struct ggml_tensor * ggml_rope_xpos_inplace(
+ struct ggml_tensor * ggml_rope_custom(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ int n_dims,
+ int mode,
+ int n_ctx,
+ int n_orig_ctx,
+ float freq_base,
+ float freq_scale,
+ float ext_factor,
+ float attn_factor,
+ float beta_fast,
+ float beta_slow) {
+ return ggml_rope_impl(
+ ctx, a, b, NULL, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
+ ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, false
+ );
+ }
+
+ struct ggml_tensor * ggml_rope_custom_inplace(
  struct ggml_context * ctx,
  struct ggml_tensor * a,
  struct ggml_tensor * b,
  int n_dims,
- float base,
- bool down) {
- return ggml_rope_impl(ctx, a, b, n_dims, 0, 0, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, base, down, true);
+ int mode,
+ int n_ctx,
+ int n_orig_ctx,
+ float freq_base,
+ float freq_scale,
+ float ext_factor,
+ float attn_factor,
+ float beta_fast,
+ float beta_slow) {
+ return ggml_rope_impl(
+ ctx, a, b, NULL, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
+ ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, true
+ );
  }
 
  // ggml_rope_back
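ggml_rope_custom() keeps its old signature and simply forwards `c == NULL`, so existing callers are unchanged; callers that want frequency factors switch to the new ggml_rope_ext(). A hedged usage sketch using the signature shown in the hunk above — `cur`, `pos` and `freq_factors` are assumed to be tensors built elsewhere and the argument values are placeholders:

```c
static struct ggml_tensor * apply_rope_ext(struct ggml_context * ctx,
                                           struct ggml_tensor * cur,
                                           struct ggml_tensor * pos,
                                           struct ggml_tensor * freq_factors) {
    return ggml_rope_ext(
        ctx, cur, pos, freq_factors,
        /*n_dims     =*/ 128,
        /*mode       =*/ 2,        // NeoX-style rotation; the CPU kernel only applies freq factors on this path
        /*n_ctx      =*/ 0,
        /*n_orig_ctx =*/ 4096,
        /*freq_base  =*/ 10000.0f,
        /*freq_scale =*/ 1.0f,
        /*ext_factor =*/ 0.0f,
        /*attn_factor=*/ 1.0f,
        /*beta_fast  =*/ 32.0f,
        /*beta_slow  =*/ 1.0f);
}
```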
@@ -6166,6 +6384,7 @@ struct ggml_tensor * ggml_rope_back(
  struct ggml_context * ctx,
  struct ggml_tensor * a,
  struct ggml_tensor * b,
+ struct ggml_tensor * c,
  int n_dims,
  int mode,
  int n_ctx,
@@ -6181,6 +6400,7 @@ struct ggml_tensor * ggml_rope_back(
  GGML_ASSERT(ggml_is_vector(b));
  GGML_ASSERT(b->type == GGML_TYPE_I32);
  GGML_ASSERT(a->ne[2] == b->ne[0]);
+ GGML_ASSERT(c == NULL && "freq factors not implemented yet");
 
  GGML_ASSERT((mode & 4) == 0 && "ggml_rope_back() for ChatGLM not implemented yet");
 
@@ -6724,38 +6944,6 @@ struct ggml_tensor * ggml_top_k(
  return result;
  }
 
- // ggml_flash_attn
-
- struct ggml_tensor * ggml_flash_attn(
- struct ggml_context * ctx,
- struct ggml_tensor * q,
- struct ggml_tensor * k,
- struct ggml_tensor * v,
- bool masked) {
- GGML_ASSERT(ggml_can_mul_mat(k, q));
- // TODO: check if vT can be multiplied by (k*qT)
-
- bool is_node = false;
-
- if (q->grad || k->grad || v->grad) {
- is_node = true;
- }
-
- //struct ggml_tensor * result = ggml_dup_tensor(ctx, q);
- struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, GGML_MAX_DIMS, q->ne);
-
- int32_t t = masked ? 1 : 0;
- ggml_set_op_params(result, &t, sizeof(t));
-
- result->op = GGML_OP_FLASH_ATTN;
- result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src[0] = q;
- result->src[1] = k;
- result->src[2] = v;
-
- return result;
- }
-
  // ggml_flash_attn_ext
 
  struct ggml_tensor * ggml_flash_attn_ext(
@@ -6815,38 +7003,6 @@ void ggml_flash_attn_ext_set_prec(
  ggml_set_op_params_i32(a, 2, prec_i32); // scale is on first pos, max_bias on second
  }
 
- // ggml_flash_ff
-
- struct ggml_tensor * ggml_flash_ff(
- struct ggml_context * ctx,
- struct ggml_tensor * a,
- struct ggml_tensor * b0,
- struct ggml_tensor * b1,
- struct ggml_tensor * c0,
- struct ggml_tensor * c1) {
- GGML_ASSERT(ggml_can_mul_mat(b0, a));
- // TODO: more checks
-
- bool is_node = false;
-
- if (a->grad || b0->grad || b1->grad || c0->grad || c1->grad) {
- is_node = true;
- }
-
- //struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
- struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, GGML_MAX_DIMS, a->ne);
-
- result->op = GGML_OP_FLASH_FF;
- result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src[0] = a;
- result->src[1] = b0;
- result->src[2] = b1;
- result->src[3] = c0;
- result->src[4] = c1;
-
- return result;
- }
-
  // ggml_flash_attn_back
 
  struct ggml_tensor * ggml_flash_attn_back(
@@ -6856,6 +7012,8 @@ struct ggml_tensor * ggml_flash_attn_back(
  struct ggml_tensor * v,
  struct ggml_tensor * d,
  bool masked) {
+ GGML_ASSERT(false && "TODO: adapt to ggml_flash_attn_ext() changes");
+
  GGML_ASSERT(ggml_can_mul_mat(k, q));
  // TODO: check if vT can be multiplied by (k*qT)
 
@@ -14115,6 +14273,7 @@ static void ggml_compute_forward_rope_f32(
 
  const struct ggml_tensor * src0 = dst->src[0];
  const struct ggml_tensor * src1 = dst->src[1];
+ const struct ggml_tensor * src2 = dst->src[2];
 
  if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
  return;
@@ -14174,6 +14333,17 @@ static void ggml_compute_forward_rope_f32(
  const bool is_neox = mode & 2;
  const bool is_glm = mode & 4;
 
+ const float * freq_factors = NULL;
+ if (is_neox) {
+ if (src2 != NULL) {
+ GGML_ASSERT(src2->type == GGML_TYPE_F32);
+ GGML_ASSERT(src2->ne[0] >= n_dims / 2);
+ freq_factors = (const float *) src2->data;
+ }
+ } else {
+ GGML_ASSERT(src2 == NULL && "TODO: freq_factors not implemented for !is_neox");
+ }
+
  // backward process uses inverse rotation by cos and sin.
  // cos and sin build a rotation matrix, where the inverse is the transpose.
  // this essentially just switches the sign of sin.
@@ -14250,10 +14420,11 @@ static void ggml_compute_forward_rope_f32(
 
  // simplified from `(ib * n_dims + ic) * inv_ndims`
  float cur_rot = inv_ndims * ic - ib;
+ float freq_factor = freq_factors ? freq_factors[ic/2] : 1.0f;
 
  float cos_theta, sin_theta;
  rope_yarn(
- theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
+ theta_base/freq_factor, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
  &cos_theta, &sin_theta
  );
  sin_theta *= sin_sign;
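With freq factors present (NeoX path only), each rotated pair's base angle is divided by its factor before rope_yarn() applies the YaRN correction. An equivalent scalar sketch of the angle computation — illustrative helper, not code from ggml.c:

```c
#include <math.h>

static float rope_theta(int pos, int i_pair, int n_dims, float freq_base,
                        const float *freq_factors /* may be NULL */) {
    // theta_i = pos * base^(-2*i / n_dims), the usual RoPE schedule
    float theta = (float)pos * powf(freq_base, -2.0f*(float)i_pair/(float)n_dims);
    if (freq_factors) {
        theta /= freq_factors[i_pair];   // matches theta_base/freq_factor above (i_pair == ic/2)
    }
    return theta;
}
```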
@@ -14286,6 +14457,7 @@ static void ggml_compute_forward_rope_f32(
  }
  }
 
+ // TODO: deduplicate f16/f32 code
  static void ggml_compute_forward_rope_f16(
  const struct ggml_compute_params * params,
  struct ggml_tensor * dst,
@@ -14293,6 +14465,7 @@ static void ggml_compute_forward_rope_f16(
 
  const struct ggml_tensor * src0 = dst->src[0];
  const struct ggml_tensor * src1 = dst->src[1];
+ const struct ggml_tensor * src2 = dst->src[2];
 
  if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
  return;
@@ -14345,6 +14518,17 @@ static void ggml_compute_forward_rope_f16(
  const bool is_neox = mode & 2;
  const bool is_glm = mode & 4;
 
+ const float * freq_factors = NULL;
+ if (is_neox) {
+ if (src2 != NULL) {
+ GGML_ASSERT(src2->type == GGML_TYPE_F32);
+ GGML_ASSERT(src2->ne[0] >= n_dims / 2);
+ freq_factors = (const float *) src2->data;
+ }
+ } else {
+ GGML_ASSERT(src2 == NULL && "TODO: freq_factors not implemented for !is_neox");
+ }
+
  // backward process uses inverse rotation by cos and sin.
  // cos and sin build a rotation matrix, where the inverse is the transpose.
  // this essentially just switches the sign of sin.
@@ -14417,10 +14601,11 @@ static void ggml_compute_forward_rope_f16(
 
  // simplified from `(ib * n_dims + ic) * inv_ndims`
  float cur_rot = inv_ndims * ic - ib;
+ float freq_factor = freq_factors ? freq_factors[ic/2] : 1.0f;
 
  float cos_theta, sin_theta;
  rope_yarn(
- theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
+ theta_base/freq_factor, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
  &cos_theta, &sin_theta
  );
  sin_theta *= sin_sign;
@@ -15458,17 +15643,15 @@ static void ggml_compute_forward_argsort(
  }
  }
 
- // ggml_compute_forward_flash_attn
+ // ggml_compute_forward_flash_attn_ext
 
- static void ggml_compute_forward_flash_attn_f32(
+ static void ggml_compute_forward_flash_attn_ext_f16(
  const struct ggml_compute_params * params,
- const bool masked,
+ const struct ggml_tensor * q,
+ const struct ggml_tensor * k,
+ const struct ggml_tensor * v,
+ const struct ggml_tensor * mask,
  struct ggml_tensor * dst) {
-
- const struct ggml_tensor * q = dst->src[0];
- const struct ggml_tensor * k = dst->src[1];
- const struct ggml_tensor * v = dst->src[2];
-
  int64_t t0 = ggml_perf_time_us();
  UNUSED(t0);
 
@@ -15486,409 +15669,18 @@ static void ggml_compute_forward_flash_attn_f32(
15486
15669
 
15487
15670
  const int64_t D = neq0;
15488
15671
  const int64_t N = neq1;
15489
- const int64_t P = nek1 - N;
15490
- const int64_t M = P + N;
15491
-
15492
- const int Mup = ggml_up(M, GGML_SOFT_MAX_UNROLL);
15493
15672
 
15494
15673
  GGML_ASSERT(ne0 == D);
15495
- GGML_ASSERT(ne1 == N);
15496
- GGML_ASSERT(P >= 0);
15674
+ GGML_ASSERT(ne2 == N);
15497
15675
 
15498
- GGML_ASSERT(nbq0 == sizeof(float));
15499
- GGML_ASSERT(nbk0 == sizeof(float));
15500
- GGML_ASSERT(nbv0 == sizeof(float));
15676
+ // input tensor rows must be contiguous
15677
+ GGML_ASSERT(nbq0 == ggml_type_size(q->type));
15678
+ GGML_ASSERT(nbk0 == ggml_type_size(k->type));
15679
+ GGML_ASSERT(nbv0 == ggml_type_size(v->type));
15501
15680
 
15502
15681
  GGML_ASSERT(neq0 == D);
15503
15682
  GGML_ASSERT(nek0 == D);
15504
- GGML_ASSERT(nev1 == D);
15505
-
15506
- GGML_ASSERT(neq1 == N);
15507
- GGML_ASSERT(nek1 == N + P);
15508
- GGML_ASSERT(nev1 == D);
15509
-
15510
- // dst cannot be transposed or permuted
15511
- GGML_ASSERT(nb0 == sizeof(float));
15512
- GGML_ASSERT(nb0 <= nb1);
15513
- GGML_ASSERT(nb1 <= nb2);
15514
- GGML_ASSERT(nb2 <= nb3);
15515
-
15516
- if (params->type == GGML_TASK_TYPE_INIT) {
15517
- return;
15518
- }
15519
-
15520
- if (params->type == GGML_TASK_TYPE_FINALIZE) {
15521
- return;
15522
- }
15523
-
15524
- // parallelize by q rows using ggml_vec_dot_f32
15525
-
15526
- // total rows in q
15527
- const int nr = neq1*neq2*neq3;
15528
-
15529
- // rows per thread
15530
- const int dr = (nr + nth - 1)/nth;
15531
-
15532
- // row range for this thread
15533
- const int ir0 = dr*ith;
15534
- const int ir1 = MIN(ir0 + dr, nr);
15535
-
15536
- const float scale = 1.0f/sqrtf(D);
15537
-
15538
- //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale);
15539
-
15540
- for (int ir = ir0; ir < ir1; ++ir) {
15541
- // q indices
15542
- const int iq3 = ir/(neq2*neq1);
15543
- const int iq2 = (ir - iq3*neq2*neq1)/neq1;
15544
- const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
15545
-
15546
- float * S = (float *) params->wdata + ith*(Mup + CACHE_LINE_SIZE_F32);
15547
-
15548
- for (int i = M; i < Mup; ++i) {
15549
- S[i] = -INFINITY;
15550
- }
15551
-
15552
- const int64_t masked_begin = masked ? (P + iq1 + 1) : M;
15553
- for (int64_t ic = 0; ic < masked_begin; ++ic) {
15554
- // k indices
15555
- const int ik3 = iq3;
15556
- const int ik2 = iq2 % nek2;
15557
- const int ik1 = ic;
15558
-
15559
- // S indices
15560
- const int i1 = ik1;
15561
-
15562
- ggml_vec_dot_f32(neq0,
15563
- S + i1, 0,
15564
- (float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), 0,
15565
- (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)), 0, 1);
15566
- }
15567
-
15568
- // scale
15569
- ggml_vec_scale_f32(masked_begin, S, scale);
15570
-
15571
- for (int64_t i = masked_begin; i < M; i++) {
15572
- S[i] = -INFINITY;
15573
- }
15574
-
15575
- // softmax
15576
- // exclude known -INF S[..] values from max and loop
15577
- // dont forget to set their SW values to zero
15578
- {
15579
- float max = -INFINITY;
15580
- ggml_vec_max_f32(masked_begin, &max, S);
15581
-
15582
- ggml_float sum = 0.0;
15583
- {
15584
- #ifdef GGML_SOFT_MAX_ACCELERATE
15585
- max = -max;
15586
- vDSP_vsadd(S, 1, &max, S, 1, Mup);
15587
- vvexpf(S, S, &Mup);
15588
- ggml_vec_sum_f32(Mup, &sum, S);
15589
- #else
15590
- sum = ggml_vec_soft_max_f32(Mup, S, S, max);
15591
- #endif
15592
- }
15593
-
15594
- assert(sum > 0.0);
15595
-
15596
- sum = 1.0/sum;
15597
- ggml_vec_scale_f32(masked_begin, S, sum);
15598
-
15599
- #ifndef NDEBUG
15600
- for (int i = 0; i < masked_begin; ++i) {
15601
- assert(!isnan(S[i]));
15602
- assert(!isinf(S[i]));
15603
- }
15604
- #endif
15605
- }
15606
-
15607
- for (int64_t ic = 0; ic < nev1; ++ic) {
15608
- // dst indices
15609
- const int i1 = iq1;
15610
- const int i2 = iq2;
15611
- const int i3 = iq3;
15612
-
15613
- // v indices
15614
- const int iv2 = iq2 % nev2;
15615
- const int iv3 = iq3;
15616
-
15617
- ggml_vec_dot_f32(masked_begin,
15618
- (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), 0,
15619
- (float *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)), 0,
15620
- S, 0, 1);
15621
- }
15622
- }
15623
- }
15624
-
15625
- static void ggml_compute_forward_flash_attn_f16(
15626
- const struct ggml_compute_params * params,
15627
- const bool masked,
15628
- struct ggml_tensor * dst) {
15629
-
15630
- const struct ggml_tensor * q = dst->src[0];
15631
- const struct ggml_tensor * k = dst->src[1];
15632
- const struct ggml_tensor * v = dst->src[2];
15633
-
15634
- int64_t t0 = ggml_perf_time_us();
15635
- UNUSED(t0);
15636
-
15637
- GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
15638
- GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
15639
- GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
15640
- GGML_TENSOR_LOCALS(size_t, nbk, k, nb)
15641
- GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
15642
- GGML_TENSOR_LOCALS(size_t, nbv, v, nb)
15643
- GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
15644
- GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
15645
-
15646
- const int ith = params->ith;
15647
- const int nth = params->nth;
15648
-
15649
- const int64_t D = neq0;
15650
- const int64_t N = neq1;
15651
- const int64_t P = nek1 - N;
15652
- const int64_t M = P + N;
15653
-
15654
- const int Mup = ggml_up(M, GGML_SOFT_MAX_UNROLL);
15655
-
15656
- GGML_ASSERT(ne0 == D);
15657
- GGML_ASSERT(ne1 == N);
15658
- GGML_ASSERT(P >= 0);
15659
-
15660
- GGML_ASSERT(nbq0 == sizeof(ggml_fp16_t));
15661
- GGML_ASSERT(nbk0 == sizeof(ggml_fp16_t));
15662
- GGML_ASSERT(nbv0 == sizeof(ggml_fp16_t));
15663
-
15664
- GGML_ASSERT(neq0 == D);
15665
- GGML_ASSERT(nek0 == D);
15666
- GGML_ASSERT(nev1 == D);
15667
-
15668
- GGML_ASSERT(neq1 == N);
15669
- GGML_ASSERT(nek1 == N + P);
15670
- GGML_ASSERT(nev1 == D);
15671
-
15672
- // dst cannot be transposed or permuted
15673
- GGML_ASSERT(nb0 == sizeof(float));
15674
- GGML_ASSERT(nb0 <= nb1);
15675
- GGML_ASSERT(nb1 <= nb2);
15676
- GGML_ASSERT(nb2 <= nb3);
15677
-
15678
- if (params->type == GGML_TASK_TYPE_INIT) {
15679
- return;
15680
- }
15681
-
15682
- if (params->type == GGML_TASK_TYPE_FINALIZE) {
15683
- return;
15684
- }
15685
-
15686
- // parallelize by q rows using ggml_vec_dot_f32
15687
-
15688
- // total rows in q
15689
- const int nr = neq1*neq2*neq3;
15690
-
15691
- // rows per thread
15692
- const int dr = (nr + nth - 1)/nth;
15693
-
15694
- // row range for this thread
15695
- const int ir0 = dr*ith;
15696
- const int ir1 = MIN(ir0 + dr, nr);
15697
-
15698
- const float scale = 1.0f/sqrtf(D);
15699
-
15700
- //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale);
15701
-
15702
- for (int ir = ir0; ir < ir1; ++ir) {
15703
- // q indices
15704
- const int iq3 = ir/(neq2*neq1);
15705
- const int iq2 = (ir - iq3*neq2*neq1)/neq1;
15706
- const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
15707
-
15708
- float * S = (float *) params->wdata + ith*(2*Mup + CACHE_LINE_SIZE_F32);
15709
-
15710
- for (int i = M; i < Mup; ++i) {
15711
- S[i] = -INFINITY;
15712
- }
15713
-
15714
- if (GGML_VEC_DOT_UNROLL > 2 || nek1 % GGML_VEC_DOT_UNROLL != 0) {
15715
- for (int64_t ic = 0; ic < nek1; ++ic) {
15716
- // k indices
15717
- const int ik3 = iq3;
15718
- const int ik2 = iq2 % nek2;
15719
- const int ik1 = ic;
15720
-
15721
- // S indices
15722
- const int i1 = ik1;
15723
-
15724
- ggml_vec_dot_f16(neq0,
15725
- S + i1, 0,
15726
- (ggml_fp16_t *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), 0,
15727
- (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)), 0, 1);
15728
- }
15729
- } else {
15730
- for (int64_t ic = 0; ic < nek1; ic += GGML_VEC_DOT_UNROLL) {
15731
- // k indices
15732
- const int ik3 = iq3;
15733
- const int ik2 = iq2 % nek2;
15734
- const int ik1 = ic;
15735
-
15736
- // S indices
15737
- const int i1 = ik1;
15738
-
15739
- ggml_vec_dot_f16_unroll(neq0, nbk1,
15740
- S + i1,
15741
- ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)),
15742
- (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
15743
- }
15744
- }
15745
-
15746
- // scale
15747
- ggml_vec_scale_f32(nek1, S, scale);
15748
-
15749
- if (masked) {
15750
- for (int64_t i = P; i < M; i++) {
15751
- if (i > P + iq1) {
15752
- S[i] = -INFINITY;
15753
- }
15754
- }
15755
- }
15756
-
15757
- // softmax
15758
- // todo: exclude known -INF S[..] values from max and loop, assuming their results to be zero.
15759
- // dont forget to set their S values to zero
15760
- {
15761
- float max = -INFINITY;
15762
- ggml_vec_max_f32(M, &max, S);
15763
-
15764
- ggml_float sum = 0.0;
15765
- {
15766
- #ifdef GGML_SOFT_MAX_ACCELERATE
15767
- max = -max;
15768
- vDSP_vsadd(S, 1, &max, S, 1, Mup);
15769
- vvexpf(S, S, &Mup);
15770
- ggml_vec_sum_f32(Mup, &sum, S);
15771
- #else
15772
- sum = ggml_vec_soft_max_f32(Mup, S, S, max);
15773
- #endif
15774
- }
15775
-
15776
- assert(sum > 0.0);
15777
-
15778
- sum = 1.0/sum;
15779
- ggml_vec_scale_f32(M, S, sum);
15780
-
15781
- #ifndef NDEBUG
15782
- for (int i = 0; i < M; ++i) {
15783
- assert(!isnan(S[i]));
15784
- assert(!isinf(S[i]));
15785
- }
15786
- #endif
15787
- }
15788
-
15789
- ggml_fp16_t * S16 = (ggml_fp16_t *) ((float *) params->wdata + ith*(2*Mup + CACHE_LINE_SIZE_F32) + Mup);
15790
-
15791
- for (int64_t i = 0; i < M; i++) {
15792
- S16[i] = GGML_FP32_TO_FP16(S[i]);
15793
- }
15794
-
15795
- // todo: exclude known zero S[..] values from dot (reducing nev0 and increasing begin of v and S16).
15796
- if (GGML_VEC_DOT_UNROLL == 1 || (nev1 % GGML_VEC_DOT_UNROLL != 0)) {
15797
- for (int64_t ic = 0; ic < nev1; ++ic) {
15798
- // dst indices
15799
- const int i1 = iq1;
15800
- const int i2 = iq2;
15801
- const int i3 = iq3;
15802
-
15803
- // v indices
15804
- const int iv2 = iq2 % nev2;
15805
- const int iv3 = iq3;
15806
-
15807
- ggml_vec_dot_f16(nev0,
15808
- (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), 0,
15809
- (ggml_fp16_t *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)), 0,
15810
- S16, 0, 1);
15811
- }
15812
- } else {
15813
- for (int64_t ic = 0; ic < nev1; ic += GGML_VEC_DOT_UNROLL) {
15814
- // dst indices
15815
- const int i1 = iq1;
15816
- const int i2 = iq2;
15817
- const int i3 = iq3;
15818
-
15819
- // v indices
15820
- const int iv2 = iq2 % nev2;
15821
- const int iv3 = iq3;
15822
-
15823
- ggml_vec_dot_f16_unroll(nev0, nbv1,
15824
- (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
15825
- ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)),
15826
- S16);
15827
- }
15828
- }
15829
- }
15830
- }
15831
-
15832
- static void ggml_compute_forward_flash_attn(
15833
- const struct ggml_compute_params * params,
15834
- const bool masked,
15835
- struct ggml_tensor * dst) {
15836
-
15837
- const struct ggml_tensor * q = dst->src[0];
15838
-
15839
- switch (q->type) {
15840
- case GGML_TYPE_F16:
15841
- {
15842
- ggml_compute_forward_flash_attn_f16(params, masked, dst);
15843
- } break;
15844
- case GGML_TYPE_F32:
15845
- {
15846
- ggml_compute_forward_flash_attn_f32(params, masked, dst);
15847
- } break;
15848
- default:
15849
- {
15850
- GGML_ASSERT(false);
15851
- } break;
15852
- }
15853
- }
15854
-
15855
- // ggml_compute_forward_flash_attn_ext
15856
-
15857
- static void ggml_compute_forward_flash_attn_ext_f16(
15858
- const struct ggml_compute_params * params,
15859
- const struct ggml_tensor * q,
15860
- const struct ggml_tensor * k,
15861
- const struct ggml_tensor * v,
15862
- const struct ggml_tensor * mask,
15863
- struct ggml_tensor * dst) {
15864
- int64_t t0 = ggml_perf_time_us();
15865
- UNUSED(t0);
15866
-
15867
- GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
15868
- GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
15869
- GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
15870
- GGML_TENSOR_LOCALS(size_t, nbk, k, nb)
15871
- GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
15872
- GGML_TENSOR_LOCALS(size_t, nbv, v, nb)
15873
- GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
15874
- GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
15875
-
15876
- const int ith = params->ith;
15877
- const int nth = params->nth;
15878
-
15879
- const int64_t D = neq0;
15880
- const int64_t N = neq1;
15881
-
15882
- GGML_ASSERT(ne0 == D);
15883
- GGML_ASSERT(ne2 == N);
15884
-
15885
- GGML_ASSERT(nbq0 == sizeof(float));
15886
- GGML_ASSERT(nbk0 == sizeof(ggml_fp16_t));
15887
- GGML_ASSERT(nbv0 == sizeof(ggml_fp16_t));
15888
-
15889
- GGML_ASSERT(neq0 == D);
15890
- GGML_ASSERT(nek0 == D);
15891
- GGML_ASSERT(nev0 == D);
15683
+ GGML_ASSERT(nev0 == D);
15892
15684
 
15893
15685
  GGML_ASSERT(neq1 == N);
15894
15686
  GGML_ASSERT(nev0 == D);
@@ -15938,6 +15730,11 @@ static void ggml_compute_forward_flash_attn_ext_f16(
  const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
  const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
+ enum ggml_type const k_vec_dot_type = type_traits[k->type].vec_dot_type;
+ ggml_from_float_t const q_to_vec_dot = type_traits[k_vec_dot_type].from_float;
+ ggml_vec_dot_t const kq_vec_dot = type_traits[k->type].vec_dot;
+ ggml_to_float_t const v_to_float = type_traits[v->type].to_float;
+
  // loop over n_batch and n_head
  for (int ir = ir0; ir < ir1; ++ir) {
  // q indices
@@ -15945,17 +15742,22 @@ static void ggml_compute_forward_flash_attn_ext_f16(
  const int iq2 = (ir - iq3*neq2*neq1)/neq1;
  const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
 
- const uint32_t h = iq2; // head
+ const uint32_t h = iq2; // head index
  const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
 
- float S = 0.0f;
- float M = -INFINITY;
+ float S = 0.0f; // sum
+ float M = -INFINITY; // maximum KQ value
 
- float * V32 = (float *) params->wdata + ith*(2*D + CACHE_LINE_SIZE_F32);
- ggml_fp16_t * Q16 = (ggml_fp16_t *) (V32); // reuse memory
- ggml_fp16_t * V16 = (ggml_fp16_t *) (V32 + D);
+ float * VKQ32 = (float *) params->wdata + ith*(3*D + CACHE_LINE_SIZE_F32); // FP32 VKQ accumulator
+ float * V32 = (VKQ32 + 1*D); // (temporary) FP32 V buffer
+ ggml_fp16_t * VKQ16 = (ggml_fp16_t *) (VKQ32 + 1*D); // (temporary) FP16 VKQ accumulator
+ ggml_fp16_t * Q_q = (ggml_fp16_t *) (VKQ32 + 2*D); // (temporary) buffer for Q converted to quantized/FP16
 
- memset(V16, 0, D*sizeof(ggml_fp16_t));
+ if (v->type == GGML_TYPE_F16) {
+ memset(VKQ16, 0, D*sizeof(ggml_fp16_t));
+ } else {
+ memset(VKQ32, 0, D*sizeof(float));
+ }
 
  const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL;
 
@@ -15967,6 +15769,9 @@ static void ggml_compute_forward_flash_attn_ext_f16(
  const int iv3 = iq3 / rv3;
  const int iv2 = iq2 / rv2;
 
+ const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3));
+ q_to_vec_dot(pq, Q_q, D);
+
  // online softmax / attention
  // loop over n_kv and n_head_kv
  // ref: https://arxiv.org/pdf/2112.05682.pdf
@@ -15976,52 +15781,67 @@ static void ggml_compute_forward_flash_attn_ext_f16(
  continue;
  }
 
- float s;
+ float s; // KQ value
 
- // convert Q to F16 in V32
- {
- const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3));
+ const char * k_data = (const char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3);
+ kq_vec_dot(D, &s, 0, k_data, 0, Q_q, 0, 1);
 
- for (int64_t d = 0; d < D; ++d) {
- Q16[d] = GGML_FP32_TO_FP16(pq[d]);
- }
- }
+ s = s*scale + mv; // scale KQ value and apply mask
 
- ggml_vec_dot_f16(D,
- &s, 0,
- (ggml_fp16_t *) ((char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3)), 0,
- Q16, 0, 1);
+ const float Mold = M;
 
- s = s*scale + mv;
+ float ms = 1.0f; // upon new higher max val, scale VKQ and KQ sum with this value
+ float vs = 1.0f; // post-softmax KQ value, expf(s - M)
 
- const float Mold = M;
+ const char * v_data = ((const char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3));
 
- float ms = 1.0f;
- float vs = 1.0f;
+ if (v->type== GGML_TYPE_F16) {
+ if (s > M) {
+ // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
+ M = s;
+ ms = expf(Mold - M);
 
- if (s > M) {
- M = s;
- ms = expf(Mold - M);
+ // V = V*expf(Mold - M)
+ ggml_vec_scale_f16(D, VKQ16, ms);
+ } else {
+ // no new maximum, ms == 1.0f, vs != 1.0f
+ vs = expf(s - M);
+ }
 
- // V = V*expf(Mold - M)
- ggml_vec_scale_f16(D, V16, ms);
+ // V += v*expf(s - M)
+ ggml_vec_mad_f16(D, VKQ16, (const ggml_fp16_t *) v_data, vs);
  } else {
- vs = expf(s - M);
- }
+ if (s > M) {
+ // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
+ M = s;
+ ms = expf(Mold - M);
 
- const ggml_fp16_t * v16 = (const ggml_fp16_t *) ((char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3));
+ // V = V*expf(Mold - M)
+ ggml_vec_scale_f32(D, VKQ32, ms);
+ } else {
+ // no new maximum, ms == 1.0f, vs != 1.0f
+ vs = expf(s - M);
+ }
 
- // V += v*expf(s - M)
- ggml_vec_mad_f16(D, V16, v16, vs);
+ v_to_float(v_data, V32, D);
 
- S = S*ms + vs;
+ // V += v*expf(s - M)
+ ggml_vec_mad_f32(D, VKQ32, V32, vs);
+ }
+
+ S = S*ms + vs; // scale and increment sum with partial sum
  }
 
- // V /= S
- for (int64_t d = 0; d < D; ++d) {
- V32[d] = GGML_FP16_TO_FP32(V16[d])/S;
+ if (v->type == GGML_TYPE_F16) {
+ for (int64_t d = 0; d < D; ++d) {
+ VKQ32[d] = GGML_FP16_TO_FP32(VKQ16[d]);
+ }
  }
 
+ // V /= S
+ const float S_inv = 1.0f/S;
+ ggml_vec_scale_f32(D, VKQ32, S_inv);
+
  // dst indices
  const int i1 = iq1;
  const int i2 = iq2;
@@ -16031,7 +15851,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
  //memcpy((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3), V, nev0*sizeof(float));
 
  // permute(0, 2, 1, 3)
- memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, V32, nb1);
+ memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32, nb1);
  }
  }
 
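The rewritten kernel keeps the same streaming ("online") softmax as before, but now converts Q once per row with q_to_vec_dot and accumulates in FP16 or FP32 depending on the V type. Stripped of SIMD and type handling, the accumulation it performs per query row looks like this (reference sketch for a single head/row):

```c
#include <math.h>
#include <stddef.h>

static void online_softmax_attend(int n_kv, int D,
                                  const float *s,      // KQ scores, already scaled/masked
                                  const float *V,      // n_kv rows of dimension D
                                  float *out) {        // D accumulators, caller-zeroed
    float S = 0.0f;            // running sum of exp(s - M)
    float M = -INFINITY;       // running maximum score
    for (int ic = 0; ic < n_kv; ++ic) {
        float ms = 1.0f, vs = 1.0f;
        if (s[ic] > M) {
            const float Mold = M;
            M  = s[ic];
            ms = expf(Mold - M);                     // rescale what was accumulated so far
            for (int d = 0; d < D; ++d) out[d] *= ms;
        } else {
            vs = expf(s[ic] - M);
        }
        for (int d = 0; d < D; ++d) out[d] += vs * V[(size_t)ic*D + d];
        S = S*ms + vs;
    }
    for (int d = 0; d < D; ++d) out[d] /= S;         // final normalization, V /= S
}
```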
@@ -16056,165 +15876,6 @@ static void ggml_compute_forward_flash_attn_ext(
  }
  }
 
- // ggml_compute_forward_flash_ff
-
- static void ggml_compute_forward_flash_ff_f16(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * a = dst->src[0]; // F16
- const struct ggml_tensor * b0 = dst->src[1]; // F16 fc_w
- const struct ggml_tensor * b1 = dst->src[2]; // F32 fc_b
- const struct ggml_tensor * c0 = dst->src[3]; // F16 proj_w
- const struct ggml_tensor * c1 = dst->src[4]; // F32 proj_b
-
- int64_t t0 = ggml_perf_time_us();
- UNUSED(t0);
-
- GGML_TENSOR_LOCALS(int64_t, nea, a, ne)
- GGML_TENSOR_LOCALS(size_t, nba, a, nb)
- GGML_TENSOR_LOCALS(int64_t, neb0, b0, ne)
- GGML_TENSOR_LOCALS(size_t, nbb0, b0, nb)
- GGML_TENSOR_LOCALS(int64_t, neb1, b1, ne)
- GGML_TENSOR_LOCALS(size_t, nbb1, b1, nb)
- GGML_TENSOR_LOCALS(int64_t, nec0, c0, ne)
- GGML_TENSOR_LOCALS(size_t, nbc0, c0, nb)
- GGML_TENSOR_LOCALS(int64_t, nec1, c1, ne)
- GGML_TENSOR_LOCALS(size_t, nbc1, c1, nb)
- GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
- GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
-
- const int ith = params->ith;
- const int nth = params->nth;
-
- const int64_t D = nea0;
- //const int64_t N = nea1;
- const int64_t M = neb01;
-
- GGML_ASSERT(ne0 == nea0);
- GGML_ASSERT(ne1 == nea1);
- GGML_ASSERT(ne2 == nea2);
-
- GGML_ASSERT(nba0 == sizeof(ggml_fp16_t));
- GGML_ASSERT(nbb00 == sizeof(ggml_fp16_t));
- GGML_ASSERT(nbb10 == sizeof(float));
- GGML_ASSERT(nbc00 == sizeof(ggml_fp16_t));
- GGML_ASSERT(nbc10 == sizeof(float));
-
- GGML_ASSERT(neb00 == D);
- GGML_ASSERT(neb01 == M);
- GGML_ASSERT(neb10 == M);
- GGML_ASSERT(neb11 == 1);
-
- GGML_ASSERT(nec00 == M);
- GGML_ASSERT(nec01 == D);
- GGML_ASSERT(nec10 == D);
- GGML_ASSERT(nec11 == 1);
-
- // dst cannot be transposed or permuted
- GGML_ASSERT(nb0 == sizeof(float));
- GGML_ASSERT(nb0 <= nb1);
- GGML_ASSERT(nb1 <= nb2);
- GGML_ASSERT(nb2 <= nb3);
-
- if (params->type == GGML_TASK_TYPE_INIT) {
- return;
- }
-
- if (params->type == GGML_TASK_TYPE_FINALIZE) {
- return;
- }
-
- // parallelize by a rows using ggml_vec_dot_f32
-
- // total rows in a
- const int nr = nea1*nea2*nea3;
-
- // rows per thread
- const int dr = (nr + nth - 1)/nth;
-
- // row range for this thread
- const int ir0 = dr*ith;
- const int ir1 = MIN(ir0 + dr, nr);
-
- for (int ir = ir0; ir < ir1; ++ir) {
- // a indices
- const int ia3 = ir/(nea2*nea1);
- const int ia2 = (ir - ia3*nea2*nea1)/nea1;
- const int ia1 = (ir - ia3*nea2*nea1 - ia2*nea1);
-
- float * S = (float *) params->wdata + ith*(2*M + CACHE_LINE_SIZE_F32);
-
- for (int64_t ic = 0; ic < neb01; ++ic) {
- // b0 indices
- const int ib03 = ia3;
- const int ib02 = ia2;
- const int ib01 = ic;
-
- // S indices
- const int i1 = ib01;
-
- ggml_vec_dot_f16(nea0,
- S + i1, 0,
- (ggml_fp16_t *) ((char *) b0->data + (ib01*nbb01 + ib02*nbb02 + ib03*nbb03)), 0,
- (ggml_fp16_t *) ((char *) a->data + ( ia1*nba1 + ia2*nba2 + ia3*nba3)), 0, 1);
- }
-
- ggml_vec_add_f32(neb01, S, S, (float *) b1->data);
- //ggml_vec_gelu_f32(neb01, S, S);
-
- ggml_fp16_t * S16 = (ggml_fp16_t *) ((float *) params->wdata + ith*(2*M + CACHE_LINE_SIZE_F32) + M);
-
- for (int64_t i = 0; i < M; i++) {
- S16[i] = GGML_FP32_TO_FP16(S[i]);
- }
-
- ggml_vec_gelu_f16(neb01, S16, S16);
-
- {
- // dst indices
- const int i1 = ia1;
- const int i2 = ia2;
- const int i3 = ia3;
-
- for (int64_t ic = 0; ic < nec01; ++ic) {
-
- ggml_vec_dot_f16(neb01,
- (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), 0,
- (ggml_fp16_t *) ((char *) c0->data + ( ic*nbc01 + i2*nbc02 + i3*nbc03)), 0,
- S16, 0, 1);
- }
-
- ggml_vec_add_f32(nec01,
- (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3)),
- (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3)),
- (float *) c1->data);
- }
- }
- }
-
- static void ggml_compute_forward_flash_ff(
- const struct ggml_compute_params * params,
- struct ggml_tensor * dst) {
-
- const struct ggml_tensor * b0 = dst->src[1];
-
- switch (b0->type) {
- case GGML_TYPE_F16:
- {
- ggml_compute_forward_flash_ff_f16(params, dst);
- } break;
- case GGML_TYPE_F32:
- {
- GGML_ASSERT(false); // TODO
- } break;
- default:
- {
- GGML_ASSERT(false);
- } break;
- }
- }
-
  // ggml_compute_forward_flash_attn_back
 
  static void ggml_compute_forward_flash_attn_back_f32(
@@ -17785,21 +17446,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
  {
  ggml_compute_forward_leaky_relu(params, tensor);
  } break;
- case GGML_OP_FLASH_ATTN:
- {
- const int32_t t = ggml_get_op_params_i32(tensor, 0);
- GGML_ASSERT(t == 0 || t == 1);
- const bool masked = t != 0;
- ggml_compute_forward_flash_attn(params, masked, tensor);
- } break;
  case GGML_OP_FLASH_ATTN_EXT:
  {
  ggml_compute_forward_flash_attn_ext(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor);
  } break;
- case GGML_OP_FLASH_FF:
- {
- ggml_compute_forward_flash_ff(params, tensor);
- } break;
  case GGML_OP_FLASH_ATTN_BACK:
  {
  int32_t t = ggml_get_op_params_i32(tensor, 0);
@@ -18169,6 +17819,7 @@ static struct ggml_tensor * ggml_sub_or_set(struct ggml_context * ctx, struct gg
  static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor * tensor, struct ggml_hash_set zero_table) {
  struct ggml_tensor * src0 = tensor->src[0];
  struct ggml_tensor * src1 = tensor->src[1];
+ struct ggml_tensor * src2 = tensor->src[2];
 
  switch (tensor->op) {
  case GGML_OP_DUP:
@@ -18700,6 +18351,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
  ggml_rope_back(ctx,
  tensor->grad,
  src1,
+ src2,
  n_dims,
  mode,
  n_ctx,
@@ -18739,6 +18391,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
  ggml_rope_impl(ctx,
  tensor->grad,
  src1,
+ src2,
  n_dims,
  mode,
  n_ctx,
@@ -18803,7 +18456,6 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
  {
  GGML_ASSERT(false); // TODO: not implemented
  } break;
- case GGML_OP_FLASH_ATTN:
  case GGML_OP_FLASH_ATTN_EXT:
  {
  struct ggml_tensor * flash_grad = NULL;
@@ -18820,7 +18472,6 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
  masked);
  }
 
- struct ggml_tensor * src2 = tensor->src[2];
  const int64_t elem_q = ggml_nelements(src0);
  const int64_t elem_k = ggml_nelements(src1);
  const int64_t elem_v = ggml_nelements(src2);
@@ -18858,10 +18509,6 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
  zero_table);
  }
  } break;
- case GGML_OP_FLASH_FF:
- {
- GGML_ASSERT(false); // not supported
- } break;
  case GGML_OP_FLASH_ATTN_BACK:
  {
  GGML_ASSERT(false); // not supported
@@ -19548,15 +19195,10 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads, int n_cur_
  {
  n_tasks = n_threads;
  } break;
- case GGML_OP_FLASH_ATTN:
  case GGML_OP_FLASH_ATTN_EXT:
  {
  n_tasks = n_threads;
  } break;
- case GGML_OP_FLASH_FF:
- {
- n_tasks = n_threads;
- } break;
  case GGML_OP_FLASH_ATTN_BACK:
  {
  n_tasks = n_threads;
@@ -19953,39 +19595,11 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
  cur += sizeof(ggml_fp16_t)*ne00*ne01*ne02*ne03;
  cur += sizeof(ggml_fp16_t)*ne10*ne11*ne12;
  } break;
- case GGML_OP_FLASH_ATTN:
- {
- const int64_t ne11 = ggml_up(node->src[1]->ne[1], GGML_SOFT_MAX_UNROLL);
-
- if (node->src[1]->type == GGML_TYPE_F32) {
- cur = sizeof(float)*ne11*n_tasks; // TODO: this can become (n_tasks-1)
- cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2
- } else if (node->src[1]->type == GGML_TYPE_F16) {
- cur = sizeof(float)*ne11*n_tasks; // TODO: this can become (n_tasks-1)
- cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2
- } else if (node->src[1]->type == GGML_TYPE_BF16) {
- cur = sizeof(float)*ne11*n_tasks; // TODO: this can become (n_tasks-1)
- cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2
- }
- } break;
  case GGML_OP_FLASH_ATTN_EXT:
  {
  const int64_t ne00 = node->src[0]->ne[0]; // D
 
- cur = 2*sizeof(float)*ne00*n_tasks; // 2x head size
- } break;
- case GGML_OP_FLASH_FF:
- {
- if (node->src[1]->type == GGML_TYPE_F32) {
- cur = sizeof(float)*node->src[1]->ne[1]*n_tasks; // TODO: this can become (n_tasks-1)
- cur += sizeof(float)*node->src[1]->ne[1]*n_tasks; // this is overestimated by x2
- } else if (node->src[1]->type == GGML_TYPE_F16) {
- cur = sizeof(float)*node->src[1]->ne[1]*n_tasks; // TODO: this can become (n_tasks-1)
- cur += sizeof(float)*node->src[1]->ne[1]*n_tasks; // this is overestimated by x2
- } else if (node->src[1]->type == GGML_TYPE_BF16) {
- cur = sizeof(float)*node->src[1]->ne[1]*n_tasks; // TODO: this can become (n_tasks-1)
- cur += sizeof(float)*node->src[1]->ne[1]*n_tasks; // this is overestimated by x2
- }
+ cur = 3*sizeof(float)*ne00*n_tasks; // 3x head size/thread
  } break;
  case GGML_OP_FLASH_ATTN_BACK:
  {
@@ -21827,11 +21441,7 @@ size_t ggml_quantize_chunk(
  case GGML_TYPE_IQ1_S: result = quantize_iq1_s (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
  case GGML_TYPE_IQ1_M: result = quantize_iq1_m (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
  case GGML_TYPE_IQ4_NL: result = quantize_iq4_nl (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
- #if QK_K == 64
- case GGML_TYPE_IQ4_XS: result = quantize_iq4_nl (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
- #else
  case GGML_TYPE_IQ4_XS: result = quantize_iq4_xs (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
- #endif
  case GGML_TYPE_F16:
  {
  size_t elemsize = sizeof(ggml_fp16_t);
@@ -23108,6 +22718,14 @@ int ggml_cpu_has_avx512_vnni(void) {
  #endif
  }
 
+ int ggml_cpu_has_avx512_bf16(void) {
+ #if defined(__AVX512BF16__)
+ return 1;
+ #else
+ return 0;
+ #endif
+ }
+
  int ggml_cpu_has_fma(void) {
  #if defined(__FMA__)
  return 1;
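ggml_cpu_has_avx512_bf16() follows the existing ggml_cpu_has_* pattern and reports whether the library was built with __AVX512BF16__, i.e. whether the BF16 kernels earlier in this diff were compiled in. Hedged usage sketch (assumes the declaration is exposed through ggml.h in this release):

```c
#include <stdio.h>
#include "ggml.h"   // assumed to declare ggml_cpu_has_avx512_bf16() alongside the other predicates

int main(void) {
    printf("AVX512-BF16 support compiled in: %d\n", ggml_cpu_has_avx512_bf16());
    printf("AVX512-VNNI support compiled in: %d\n", ggml_cpu_has_avx512_vnni());
    return 0;
}
```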