llama_cpp 0.15.2 → 0.15.4

@@ -60,6 +60,9 @@
 
  typedef volatile LONG atomic_int;
  typedef atomic_int atomic_bool;
+ typedef atomic_int atomic_flag;
+
+ #define ATOMIC_FLAG_INIT 0
 
  static void atomic_store(atomic_int * ptr, LONG val) {
  InterlockedExchange(ptr, val);
@@ -73,6 +76,12 @@ static LONG atomic_fetch_add(atomic_int * ptr, LONG inc) {
  static LONG atomic_fetch_sub(atomic_int * ptr, LONG dec) {
  return atomic_fetch_add(ptr, -(dec));
  }
+ static atomic_bool atomic_flag_test_and_set(atomic_flag * ptr) {
+ return InterlockedExchange(ptr, 1);
+ }
+ static void atomic_flag_clear(atomic_flag * ptr) {
+ InterlockedExchange(ptr, 0);
+ }
 
  typedef HANDLE pthread_t;
 
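The block above gives the Windows build a minimal atomic_flag shim on top of InterlockedExchange: test-and-set returns the previous value, so a non-zero return means another thread already owns the flag. For comparison, a sketch of the same spin-lock idiom written against standard C11 <stdatomic.h> (the lock_acquire/lock_release names are illustrative, not part of ggml):

    #include <stdatomic.h>
    #include <sched.h>

    // atomic_flag_test_and_set() returns the previous value, so the loop
    // spins for as long as some other thread already holds the flag.
    static atomic_flag lock = ATOMIC_FLAG_INIT;

    static void lock_acquire(void) {
        while (atomic_flag_test_and_set(&lock)) {
            sched_yield(); // yield instead of burning the core (POSIX)
        }
    }

    static void lock_release(void) {
        atomic_flag_clear(&lock);
    }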
@@ -406,10 +415,10 @@ void ggml_fp32_to_bf16_row(const float * x, ggml_bf16_t * y, int64_t n) {
  int i = 0;
  #if defined(__AVX512BF16__)
  for (; i + 32 <= n; i += 32) {
- _mm512_storeu_ps(
- (__m512 *)(y + i),
- (__m512)_mm512_cvtne2ps_pbh(_mm512_loadu_ps(x + i + 16),
- _mm512_loadu_ps(x + i)));
+ _mm512_storeu_si512(
+ (__m512i *)(y + i),
+ m512i(_mm512_cvtne2ps_pbh(_mm512_loadu_ps(x + i + 16),
+ _mm512_loadu_ps(x + i))));
  }
  #endif
  for (; i < n; i++) {
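The change above stores the 32 converted bf16 values through _mm512_storeu_si512 with an explicit integer-vector wrapper instead of a float store. Per element, the conversion keeps the top 16 bits of the IEEE-754 float with round-to-nearest-even; a scalar sketch of that rounding (NaN handling omitted, so this is illustrative rather than ggml's exact scalar path):

    #include <stdint.h>
    #include <string.h>

    // Truncate an FP32 value to BF16 (its top 16 bits), rounding to
    // nearest and breaking ties to even via the classic bias trick.
    static uint16_t fp32_to_bf16_sketch(float f) {
        uint32_t u;
        memcpy(&u, &f, sizeof(u));
        u += 0x7fff + ((u >> 16) & 1); // round to nearest, ties to even
        return (uint16_t)(u >> 16);
    }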
@@ -871,22 +880,14 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
  },
  [GGML_TYPE_IQ4_XS] = {
  .type_name = "iq4_xs",
- #if QK_K == 64
- .blck_size = QK4_NL,
- #else
  .blck_size = QK_K,
- #endif
  .type_size = sizeof(block_iq4_xs),
  .is_quantized = true,
  .to_float = (ggml_to_float_t) dequantize_row_iq4_xs,
  .from_float = quantize_row_iq4_xs,
  .from_float_reference = (ggml_from_float_t)quantize_row_iq4_xs_reference,
  .vec_dot = ggml_vec_dot_iq4_xs_q8_K,
- #if QK_K == 64
- .vec_dot_type = GGML_TYPE_Q8_0,
- #else
  .vec_dot_type = GGML_TYPE_Q8_K,
- #endif
  .nrows = 1,
  },
  [GGML_TYPE_Q8_K] = {
@@ -1523,6 +1524,196 @@ static inline void __sse_f16x4_store(ggml_fp16_t *x, __m128 y) {
1523
1524
  #define GGML_F16_VEC_MUL GGML_F32Cx4_MUL
1524
1525
  #define GGML_F16_VEC_REDUCE GGML_F32Cx4_REDUCE
1525
1526
 
1527
+ #elif defined(__loongarch_asx)
1528
+
1529
+ #define GGML_SIMD
1530
+
1531
+ // F32 LASX
1532
+ #define GGML_F32_STEP 32
1533
+ #define GGML_F32_EPR 8
1534
+
1535
+ #define GGML_F32x8 __m256
1536
+ #define GGML_F32x8_ZERO (__m256)__lasx_xvldi(0)
1537
+ #define GGML_F32x8_SET1(x) (__m256)__lasx_xvreplfr2vr_s((x))
1538
+ #define GGML_F32x8_LOAD(x) (__m256)__lasx_xvld((x), 0)
1539
+ #define GGML_F32x8_STORE(x,y) __lasx_xvst((y), (x), 0)
1540
+ #define GGML_F32x8_FMA(a, b, c) __lasx_xvfmadd_s(b, c, a)
1541
+ #define GGML_F32x8_ADD __lasx_xvfadd_s
1542
+ #define GGML_F32x8_MUL __lasx_xvfmul_s
1543
+ #define GGML_F32x8_REDUCE(res, x) \
1544
+ do { \
1545
+ int offset = GGML_F32_ARR >> 1; \
1546
+ for (int i = 0; i < offset; ++i) { \
1547
+ x[i] = __lasx_xvfadd_s(x[i], x[offset+i]); \
1548
+ } \
1549
+ offset >>= 1; \
1550
+ for (int i = 0; i < offset; ++i) { \
1551
+ x[i] = __lasx_xvfadd_s(x[i], x[offset+i]); \
1552
+ } \
1553
+ offset >>= 1; \
1554
+ for (int i = 0; i < offset; ++i) { \
1555
+ x[i] = __lasx_xvfadd_s(x[i], x[offset+i]); \
1556
+ } \
1557
+ float *tmp_p = (float *)&x[0]; \
1558
+ res = tmp_p[0] + tmp_p[1] + tmp_p[2] + tmp_p[3] + tmp_p[4] + tmp_p[5] + tmp_p[6] + tmp_p[7]; \
1559
+ } while (0)
1560
+ // TODO: is this optimal ?
1561
+
1562
+ #define GGML_F32_VEC GGML_F32x8
1563
+ #define GGML_F32_VEC_ZERO GGML_F32x8_ZERO
1564
+ #define GGML_F32_VEC_SET1 GGML_F32x8_SET1
1565
+ #define GGML_F32_VEC_LOAD GGML_F32x8_LOAD
1566
+ #define GGML_F32_VEC_STORE GGML_F32x8_STORE
1567
+ #define GGML_F32_VEC_FMA GGML_F32x8_FMA
1568
+ #define GGML_F32_VEC_ADD GGML_F32x8_ADD
1569
+ #define GGML_F32_VEC_MUL GGML_F32x8_MUL
1570
+ #define GGML_F32_VEC_REDUCE GGML_F32x8_REDUCE
1571
+
1572
+ // F16 LASX
1573
+
1574
+ #define GGML_F16_STEP 32
1575
+ #define GGML_F16_EPR 8
1576
+
1577
+ // F16 arithmetic is not supported by AVX, so we use F32 instead
1578
+
1579
+ #define GGML_F32Cx8 __m256
1580
+ #define GGML_F32Cx8_ZERO (__m256)__lasx_xvldi(0)
1581
+ #define GGML_F32Cx8_SET1(x) (__m256)__lasx_xvreplgr2vr_w((x))
1582
+
1583
+ static inline __m256 __lasx_f32cx8_load(const ggml_fp16_t * x) {
1584
+ float tmp[8];
1585
+
1586
+ for (int i = 0; i < 8; i++) {
1587
+ tmp[i] = GGML_FP16_TO_FP32(x[i]);
1588
+ }
1589
+
1590
+ return (__m256)__lasx_xvld(tmp, 0);
1591
+ }
1592
+ static inline void __lasx_f32cx8_store(ggml_fp16_t * x, __m256 y) {
1593
+ float arr[8];
1594
+
1595
+ __lasx_xvst(y, arr, 0);
1596
+
1597
+ for (int i = 0; i < 8; i++) {
1598
+ x[i] = GGML_FP32_TO_FP16(arr[i]);
1599
+ }
1600
+ }
1601
+ #define GGML_F32Cx8_LOAD(x) __lasx_f32cx8_load(x)
1602
+ #define GGML_F32Cx8_STORE(x, y) __lasx_f32cx8_store(x, y)
1603
+
1604
+ #define GGML_F32Cx8_FMA GGML_F32x8_FMA
1605
+ #define GGML_F32Cx8_ADD __lasx_xvfadd_s
1606
+ #define GGML_F32Cx8_MUL __lasx_xvfmul_s
1607
+ #define GGML_F32Cx8_REDUCE GGML_F32x8_REDUCE
1608
+
1609
+ #define GGML_F16_VEC GGML_F32Cx8
1610
+ #define GGML_F16_VEC_ZERO GGML_F32Cx8_ZERO
1611
+ #define GGML_F16_VEC_SET1 GGML_F32Cx8_SET1
1612
+ #define GGML_F16_VEC_LOAD(p, i) GGML_F32Cx8_LOAD(p)
1613
+ #define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx8_STORE(p, r[i])
1614
+ #define GGML_F16_VEC_FMA GGML_F32Cx8_FMA
1615
+ #define GGML_F16_VEC_ADD GGML_F32Cx8_ADD
1616
+ #define GGML_F16_VEC_MUL GGML_F32Cx8_MUL
1617
+ #define GGML_F16_VEC_REDUCE GGML_F32Cx8_REDUCE
1618
+
1619
+ #elif defined(__loongarch_sx)
1620
+
1621
+ #define GGML_SIMD
1622
+
1623
+ // F32 LSX
1624
+
1625
+ #define GGML_F32_STEP 32
1626
+ #define GGML_F32_EPR 4
1627
+
1628
+ #define GGML_F32x4 __m128
1629
+ #define GGML_F32x4_ZERO __lsx_vldi(0)
1630
+ #define GGML_F32x4_SET1(x) __lsx_vinsgr2vr_w(__lsx_vldi(0),(x), 0)
1631
+ #define GGML_F32x4_LOAD(x) __lsx_vld((x), 0)
1632
+ #define GGML_F32x4_STORE((x),(y)) __lsx_vst((y), (x), 0)
1633
+ #define GGML_F32x4_FMA(a, b, c) __lsx_vfmadd_s(b, c, a)
1634
+ #define GGML_F32x4_ADD __lsx_vfadd_s
1635
+ #define GGML_F32x4_MUL __lsx_vfmul_s
1636
+ #define GGML_F32x4_REDUCE(res, x) \
1637
+ { \
1638
+ int offset = GGML_F32_ARR >> 1; \
1639
+ for (int i = 0; i < offset; ++i) { \
1640
+ x[i] = __lsx_vfadd_s(x[i], x[offset+i]); \
1641
+ } \
1642
+ offset >>= 1; \
1643
+ for (int i = 0; i < offset; ++i) { \
1644
+ x[i] = __lsx_vfadd_s(x[i], x[offset+i]); \
1645
+ } \
1646
+ offset >>= 1; \
1647
+ for (int i = 0; i < offset; ++i) { \
1648
+ x[i] = __lsx_vfadd_s(x[i], x[offset+i]); \
1649
+ } \
1650
+ __m128i tmp = __lsx_vsrli_d((__m128i)x[0], 32); \
1651
+ tmp = (__m128i)__lsx_vfadd_s((__m128)tmp, x[0]); \
1652
+ tmp = __lsx_vpickev_w(__lsx_vldi(0), tmp); \
1653
+ const __m128 t0 = __lsx_vshuf4i_w(tmp, 0x88); \
1654
+ tmp = __lsx_vsrli_d((__m128i)t0, 32); \
1655
+ tmp = (__m128i)__lsx_vfadd_s((__m128)tmp, t0); \
1656
+ tmp = __lsx_vpickev_w(__lsx_vldi(0), tmp); \
1657
+ res = (ggml_float) __lsx_vpickve2gr_w(__lsx_vshuf4i_w(tmp, 0x88), 0); \
1658
+ }
1659
+
1660
+ #define GGML_F32_VEC GGML_F32x4
1661
+ #define GGML_F32_VEC_ZERO GGML_F32x4_ZERO
1662
+ #define GGML_F32_VEC_SET1 GGML_F32x4_SET1
1663
+ #define GGML_F32_VEC_LOAD GGML_F32x4_LOAD
1664
+ #define GGML_F32_VEC_STORE GGML_F32x4_STORE
1665
+ #define GGML_F32_VEC_FMA GGML_F32x4_FMA
1666
+ #define GGML_F32_VEC_ADD GGML_F32x4_ADD
1667
+ #define GGML_F32_VEC_MUL GGML_F32x4_MUL
1668
+ #define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE
1669
+
1670
+ // F16 LSX
1671
+
1672
+ #define GGML_F16_STEP 32
1673
+ #define GGML_F16_EPR 4
1674
+
1675
+ static inline __m128 __lsx_f16x4_load(const ggml_fp16_t * x) {
1676
+ float tmp[4];
1677
+
1678
+ tmp[0] = GGML_FP16_TO_FP32(x[0]);
1679
+ tmp[1] = GGML_FP16_TO_FP32(x[1]);
1680
+ tmp[2] = GGML_FP16_TO_FP32(x[2]);
1681
+ tmp[3] = GGML_FP16_TO_FP32(x[3]);
1682
+
1683
+ return __lsx_vld(tmp, 0);
1684
+ }
1685
+
1686
+ static inline void __lsx_f16x4_store(ggml_fp16_t * x, __m128 y) {
1687
+ float arr[4];
1688
+
1689
+ __lsx_vst(y, arr, 0);
1690
+
1691
+ x[0] = GGML_FP32_TO_FP16(arr[0]);
1692
+ x[1] = GGML_FP32_TO_FP16(arr[1]);
1693
+ x[2] = GGML_FP32_TO_FP16(arr[2]);
1694
+ x[3] = GGML_FP32_TO_FP16(arr[3]);
1695
+ }
1696
+
1697
+ #define GGML_F32Cx4 __m128
1698
+ #define GGML_F32Cx4_ZERO __lsx_vldi(0)
1699
+ #define GGML_F32Cx4_SET1(x) __lsx_vinsgr2vr_w(__lsx_vldi(0),(x), 0)
1700
+ #define GGML_F32Cx4_LOAD(x) __lsx_f16x4_load(x)
1701
+ #define GGML_F32Cx4_STORE(x, y) __lsx_f16x4_store(x, y)
1702
+ #define GGML_F32Cx4_FMA GGML_F32x4_FMA
1703
+ #define GGML_F32Cx4_ADD __lsx_vfadd_s
1704
+ #define GGML_F32Cx4_MUL __lsx_vfmul_s
1705
+ #define GGML_F32Cx4_REDUCE GGML_F32x4_REDUCE
1706
+
1707
+ #define GGML_F16_VEC GGML_F32Cx4
1708
+ #define GGML_F16_VEC_ZERO GGML_F32Cx4_ZERO
1709
+ #define GGML_F16_VEC_SET1 GGML_F32Cx4_SET1
1710
+ #define GGML_F16_VEC_LOAD(p, i) GGML_F32Cx4_LOAD(p)
1711
+ #define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx4_STORE(p, r[i])
1712
+ #define GGML_F16_VEC_FMA GGML_F32Cx4_FMA
1713
+ #define GGML_F16_VEC_ADD GGML_F32Cx4_ADD
1714
+ #define GGML_F16_VEC_MUL GGML_F32Cx4_MUL
1715
+ #define GGML_F16_VEC_REDUCE GGML_F32Cx4_REDUCE
1716
+
1526
1717
  #endif
1527
1718
 
1528
1719
  // GGML_F32_ARR / GGML_F16_ARR
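The GGML_F32x8_REDUCE / GGML_F32x4_REDUCE macros above first add the GGML_F32_ARR accumulator registers together pairwise and then fold the surviving register horizontally into a scalar. The same idea in plain C, assuming nv vectors of eight lanes each (a sketch of the reduction pattern, not the LASX intrinsics themselves):

    // Pairwise tree reduction over nv vectors of 8 floats, followed by a
    // horizontal fold of the last vector - mirrors GGML_F32x8_REDUCE.
    static float reduce_sum_sketch(float v[][8], int nv) {
        for (int step = nv / 2; step > 0; step /= 2) {
            for (int i = 0; i < step; i++) {
                for (int l = 0; l < 8; l++) {
                    v[i][l] += v[i + step][l]; // add register pairs
                }
            }
        }
        float sum = 0.0f;
        for (int l = 0; l < 8; l++) {
            sum += v[0][l]; // horizontal fold
        }
        return sum;
    }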
@@ -1666,10 +1857,10 @@ static void ggml_vec_dot_bf16(int n, float * restrict s, size_t bs, ggml_bf16_t
  __m512 c1 = _mm512_setzero_ps();
  __m512 c2 = _mm512_setzero_ps();
  for (; i + 64 <= n; i += 64) {
- c1 = _mm512_dpbf16_ps(c1, (__m512bh)_mm512_loadu_ps((const float *)(x + i)),
- (__m512bh)_mm512_loadu_ps((const float *)(y + i)));
- c2 = _mm512_dpbf16_ps(c2, (__m512bh)_mm512_loadu_ps((const float *)(x + i + 32)),
- (__m512bh)_mm512_loadu_ps((const float *)(y + i + 32)));
+ c1 = _mm512_dpbf16_ps(c1, m512bh(_mm512_loadu_si512((x + i))),
+ m512bh(_mm512_loadu_si512((y + i))));
+ c2 = _mm512_dpbf16_ps(c2, m512bh(_mm512_loadu_si512((x + i + 32))),
+ m512bh(_mm512_loadu_si512((y + i + 32))));
  }
  sumf += (ggml_float)_mm512_reduce_add_ps(c1);
  sumf += (ggml_float)_mm512_reduce_add_ps(c2);
@@ -2076,7 +2267,7 @@ inline static float ggml_silu_f32(float x) {
  return x/(1.0f + expf(-x));
  }
 
- #if defined(__ARM_NEON)
+ #if defined(__ARM_NEON) && defined(__aarch64__)
 
  // adapted from arm limited optimized routine
  // the maximum error is 1.45358 plus 0.5 ulps
@@ -2125,32 +2316,27 @@ inline static __m512 ggml_v_expf(__m512 x) {
  const __m512 r = _mm512_set1_ps(0x1.8p23f);
  const __m512 z = _mm512_fmadd_ps(x, _mm512_set1_ps(0x1.715476p+0f), r);
  const __m512 n = _mm512_sub_ps(z, r);
- const __m512 b = _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.7f7d1cp-20f),
- _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.62e4p-1f), x));
- const __m512i e = _mm512_slli_epi32(_mm512_castps_si512(z), 23);
- const __m512 k = _mm512_castsi512_ps(_mm512_add_epi32(e, _mm512_castps_si512(_mm512_set1_ps(1))));
- const __mmask16 c = _mm512_cmp_ps_mask(_mm512_abs_ps(n), _mm512_set1_ps(126), _CMP_GT_OQ);
- const __m512 u = _mm512_mul_ps(b, b);
- const __m512 j = _mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_set1_ps(0x1.0e4020p-7f), b,
- _mm512_set1_ps(0x1.573e2ep-5f)), u,
- _mm512_fmadd_ps(_mm512_set1_ps(0x1.555e66p-3f), b,
- _mm512_set1_ps(0x1.fffdb6p-2f))),
- u, _mm512_mul_ps(_mm512_set1_ps(0x1.ffffecp-1f), b));
- if (_mm512_kortestz(c, c))
- return _mm512_fmadd_ps(j, k, k);
- const __m512i g = _mm512_and_si512(
- _mm512_movm_epi32(_mm512_cmp_ps_mask(n, _mm512_setzero_ps(), _CMP_LE_OQ)),
- _mm512_set1_epi32(0x82000000u));
- const __m512 s1 =
- _mm512_castsi512_ps(_mm512_add_epi32(g, _mm512_set1_epi32(0x7f000000u)));
- const __m512 s2 = _mm512_castsi512_ps(_mm512_sub_epi32(e, g));
+ const __m512 b =
+ _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.7f7d1cp-20f),
+ _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.62e4p-1f), x));
  const __mmask16 d =
  _mm512_cmp_ps_mask(_mm512_abs_ps(n), _mm512_set1_ps(192), _CMP_GT_OQ);
- return _mm512_mask_blend_ps(
- d, _mm512_mask_blend_ps(
- c, _mm512_fmadd_ps(k, j, k),
- _mm512_mul_ps(_mm512_fmadd_ps(s2, j, s2), s1)),
- _mm512_mul_ps(s1, s1));
+ const __m512 u = _mm512_mul_ps(b, b);
+ const __m512 j = _mm512_fmadd_ps(
+ _mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_set1_ps(0x1.0e4020p-7f), b,
+ _mm512_set1_ps(0x1.573e2ep-5f)),
+ u,
+ _mm512_fmadd_ps(_mm512_set1_ps(0x1.555e66p-3f), b,
+ _mm512_set1_ps(0x1.fffdb6p-2f))),
+ u,
+ _mm512_fmadd_ps(_mm512_set1_ps(0x1.ffffecp-1f), b, _mm512_set1_ps(1.0F)));
+ const __m512 res = _mm512_scalef_ps(j, n);
+ if (_mm512_kortestz(d, d))
+ return res;
+ const __m512 zero = _mm512_setzero_ps();
+ const __m512 alt = _mm512_mask_blend_ps(
+ _mm512_cmp_ps_mask(n, zero, _CMP_LE_OQ), _mm512_set1_ps(INFINITY), zero);
+ return _mm512_mask_blend_ps(d, res, alt);
  }
 
  // computes silu x/(1+exp(-x)) in single precision vector
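The rewritten ggml_v_expf keeps the usual range-reduction scheme: split x = n*ln2 + b with n rounded to an integer, approximate e^b with a short polynomial, and scale the result by 2^n - now via _mm512_scalef_ps instead of hand-built exponent-bit arithmetic, with a single mask d handling the overflow/underflow lanes. A scalar sketch of the scheme (the coefficients here are plain Taylor terms, not the tuned constants used above):

    #include <math.h>

    // exp(x) ~= 2^n * p(b), with n = round(x / ln 2) and b = x - n*ln 2.
    static float expf_sketch(float x) {
        const float n = rintf(x * 1.44269504f);   // x / ln 2, rounded to nearest
        const float b = x - n * 0.69314718f;      // reduced argument, |b| <= ln2/2
        const float p = 1.0f + b*(1.0f + b*(0.5f + b*(1.0f/6.0f + b*(1.0f/24.0f))));
        return ldexpf(p, (int) n);                // scale by 2^n, like scalef
    }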
@@ -2288,7 +2474,7 @@ static void ggml_vec_silu_f32(const int n, float * y, const float * x) {
  for (; i + 3 < n; i += 4) {
  _mm_storeu_ps(y + i, ggml_v_silu(_mm_loadu_ps(x + i)));
  }
- #elif defined(__ARM_NEON)
+ #elif defined(__ARM_NEON) && defined(__aarch64__)
  for (; i + 3 < n; i += 4) {
  vst1q_f32(y + i, ggml_v_silu(vld1q_f32(x + i)));
  }
@@ -2335,7 +2521,7 @@ static ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x,
  #endif
  sum += (ggml_float)_mm_cvtss_f32(val);
  }
- #elif defined(__ARM_NEON)
+ #elif defined(__ARM_NEON) && defined(__aarch64__)
  for (; i + 3 < n; i += 4) {
  float32x4_t val = ggml_v_expf(vsubq_f32(vld1q_f32(x + i),
  vdupq_n_f32(max)));
@@ -2489,9 +2675,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
2489
2675
  "ARGSORT",
2490
2676
  "LEAKY_RELU",
2491
2677
 
2492
- "FLASH_ATTN",
2493
2678
  "FLASH_ATTN_EXT",
2494
- "FLASH_FF",
2495
2679
  "FLASH_ATTN_BACK",
2496
2680
  "SSM_CONV",
2497
2681
  "SSM_SCAN",
@@ -2517,7 +2701,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
2517
2701
  "CROSS_ENTROPY_LOSS_BACK",
2518
2702
  };
2519
2703
 
2520
- static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76");
2704
+ static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74");
2521
2705
 
2522
2706
  static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
2523
2707
  "none",
@@ -2579,9 +2763,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
2579
2763
  "argsort(x)",
2580
2764
  "leaky_relu(x)",
2581
2765
 
2582
- "flash_attn(x)",
2583
2766
  "flash_attn_ext(x)",
2584
- "flash_ff(x)",
2585
2767
  "flash_attn_back(x)",
2586
2768
  "ssm_conv(x)",
2587
2769
  "ssm_scan(x)",
@@ -2607,7 +2789,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
2607
2789
  "cross_entropy_loss_back(x,y)",
2608
2790
  };
2609
2791
 
2610
- static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76");
2792
+ static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74");
2611
2793
 
2612
2794
  static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
2613
2795
 
@@ -2706,24 +2888,20 @@ struct ggml_state {
 
  // global state
  static struct ggml_state g_state;
- static atomic_int g_state_barrier = 0;
+ static atomic_flag g_state_critical = ATOMIC_FLAG_INIT;
 
  // barrier via spin lock
  inline static void ggml_critical_section_start(void) {
- int processing = atomic_fetch_add(&g_state_barrier, 1);
-
- while (processing > 0) {
- // wait for other threads to finish
- atomic_fetch_sub(&g_state_barrier, 1);
- sched_yield(); // TODO: reconsider this
- processing = atomic_fetch_add(&g_state_barrier, 1);
+ while (atomic_flag_test_and_set(&g_state_critical)) {
+ // spin
+ sched_yield();
  }
  }
 
  // TODO: make this somehow automatically executed
  // some sort of "sentry" mechanism
  inline static void ggml_critical_section_end(void) {
- atomic_fetch_sub(&g_state_barrier, 1);
+ atomic_flag_clear(&g_state_critical);
  }
 
  #if defined(__gnu_linux__)
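The global barrier is now a plain test-and-set spin lock: ggml_critical_section_start() spins until it wins the flag, and ggml_critical_section_end() clears it, replacing the old fetch_add/fetch_sub counter dance. A usage sketch of the pattern it protects (the shared counter here is illustrative, not ggml's actual state):

    // Serialize rare, short global-state updates (e.g. context init/free).
    static int g_n_contexts = 0; // illustrative shared state

    static void register_context(void) {
        ggml_critical_section_start(); // spins on g_state_critical
        g_n_contexts++;                // at most one thread executes this at a time
        ggml_critical_section_end();   // clears the flag, releasing any waiter
    }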
@@ -3039,7 +3217,11 @@ GGML_CALL bool ggml_is_contiguous(const struct ggml_tensor * tensor) {
  tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
  }
 
- static inline bool ggml_is_contiguous_except_dim_1(const struct ggml_tensor * tensor) {
+ GGML_CALL bool ggml_is_contiguous_0(const struct ggml_tensor * tensor) {
+ return ggml_is_contiguous(tensor);
+ }
+
+ GGML_CALL bool ggml_is_contiguous_1(const struct ggml_tensor * tensor) {
  static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
 
  return
@@ -3048,6 +3230,14 @@ static inline bool ggml_is_contiguous_except_dim_1(const struct ggml_tensor * te
  tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
  }
 
+ GGML_CALL bool ggml_is_contiguous_2(const struct ggml_tensor * tensor) {
+ static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+ return
+ tensor->nb[0] == ggml_type_size(tensor->type) &&
+ tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
+ }
+
  GGML_CALL bool ggml_is_permuted(const struct ggml_tensor * tensor) {
  static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
 
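ggml_is_contiguous_0/1/2 relax the packed-stride requirement starting at a given dimension: _0 is full contiguity, _1 allows an arbitrary stride in dim 1 (row-sliced views), _2 additionally allows one in dim 2. A sketch of the underlying check against the same nb/ne layout, ignoring the block size of quantized types (names illustrative):

    #include <stdbool.h>
    #include <stdint.h>
    #include <stddef.h>

    // nb[] are byte strides, ne[] element counts, type_size the element size.
    // "Contiguous from dimension n": dims above n must be tightly packed;
    // dims at or below n (other than dim 0) may carry arbitrary strides.
    static bool is_contiguous_n_sketch(const int64_t ne[4], const size_t nb[4],
                                       size_t type_size, int n) {
        if (nb[0] != type_size) {
            return false; // innermost rows must always be dense
        }
        for (int d = n + 1; d < 4; d++) {
            if (nb[d] != nb[d-1]*ne[d-1]) {
                return false;
            }
        }
        return true;
    }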
@@ -4705,10 +4895,21 @@ struct ggml_tensor * ggml_repeat_back(
  // ggml_concat
 
  struct ggml_tensor * ggml_concat(
- struct ggml_context* ctx,
- struct ggml_tensor* a,
- struct ggml_tensor* b) {
- GGML_ASSERT(a->ne[0] == b->ne[0] && a->ne[1] == b->ne[1] && a->ne[3] == b->ne[3]);
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ int dim) {
+ GGML_ASSERT(dim >= 0 && dim < GGML_MAX_DIMS);
+
+ int64_t ne[GGML_MAX_DIMS];
+ for (int d = 0; d < GGML_MAX_DIMS; ++d) {
+ if (d == dim) {
+ ne[d] = a->ne[d] + b->ne[d];
+ continue;
+ }
+ GGML_ASSERT(a->ne[d] == b->ne[d]);
+ ne[d] = a->ne[d];
+ }
 
  bool is_node = false;
 
@@ -4716,7 +4917,9 @@ struct ggml_tensor * ggml_concat(
  is_node = true;
  }
 
- struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, a->ne[0], a->ne[1], a->ne[2] + b->ne[2], a->ne[3]);
+ struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, GGML_MAX_DIMS, ne);
+
+ ggml_set_op_params_i32(result, 0, dim);
 
  result->op = GGML_OP_CONCAT;
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
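ggml_concat now takes the concatenation dimension explicitly (it used to be hard-wired to dim 2); all other dimensions must match, and the chosen dim is stashed in op_params for the compute kernel. A call sketch with made-up shapes:

    // a: 64 x 10 x 4, b: 64 x 6 x 4  ->  r: 64 x 16 x 4 (concat along dim 1)
    struct ggml_tensor * a = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 64, 10, 4);
    struct ggml_tensor * b = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 64,  6, 4);
    struct ggml_tensor * r = ggml_concat(ctx, a, b, /*dim =*/ 1);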
@@ -4836,6 +5039,7 @@ struct ggml_tensor * ggml_leaky_relu(
4836
5039
  }
4837
5040
 
4838
5041
  struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
5042
+
4839
5043
  ggml_set_op_params(result, &negative_slope, sizeof(negative_slope));
4840
5044
 
4841
5045
  result->op = GGML_OP_LEAKY_RELU;
@@ -6042,6 +6246,7 @@ static struct ggml_tensor * ggml_rope_impl(
  struct ggml_context * ctx,
  struct ggml_tensor * a,
  struct ggml_tensor * b,
+ struct ggml_tensor * c,
  int n_dims,
  int mode,
  int n_ctx,
@@ -6055,10 +6260,17 @@ static struct ggml_tensor * ggml_rope_impl(
  float xpos_base,
  bool xpos_down,
  bool inplace) {
+ GGML_ASSERT((mode & 1) == 0 && "mode & 1 == 1 is no longer supported");
+
  GGML_ASSERT(ggml_is_vector(b));
  GGML_ASSERT(b->type == GGML_TYPE_I32);
  GGML_ASSERT(a->ne[2] == b->ne[0]);
 
+ if (c) {
+ GGML_ASSERT(c->type == GGML_TYPE_F32);
+ GGML_ASSERT(c->ne[0] >= n_dims / 2);
+ }
+
  bool is_node = false;
 
  if (a->grad) {
@@ -6082,6 +6294,7 @@ static struct ggml_tensor * ggml_rope_impl(
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
  result->src[0] = a;
  result->src[1] = b;
+ result->src[2] = c;
 
  return result;
  }
@@ -6094,7 +6307,7 @@ struct ggml_tensor * ggml_rope(
6094
6307
  int mode,
6095
6308
  int n_ctx) {
6096
6309
  return ggml_rope_impl(
6097
- ctx, a, b, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, false
6310
+ ctx, a, b, NULL, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, false
6098
6311
  );
6099
6312
  }
6100
6313
 
@@ -6106,7 +6319,49 @@ struct ggml_tensor * ggml_rope_inplace(
  int mode,
  int n_ctx) {
  return ggml_rope_impl(
- ctx, a, b, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, true
+ ctx, a, b, NULL, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, true
+ );
+ }
+
+ struct ggml_tensor * ggml_rope_ext(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ struct ggml_tensor * c,
+ int n_dims,
+ int mode,
+ int n_ctx,
+ int n_orig_ctx,
+ float freq_base,
+ float freq_scale,
+ float ext_factor,
+ float attn_factor,
+ float beta_fast,
+ float beta_slow) {
+ return ggml_rope_impl(
+ ctx, a, b, c, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
+ ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, false
+ );
+ }
+
+ struct ggml_tensor * ggml_rope_ext_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ struct ggml_tensor * c,
+ int n_dims,
+ int mode,
+ int n_ctx,
+ int n_orig_ctx,
+ float freq_base,
+ float freq_scale,
+ float ext_factor,
+ float attn_factor,
+ float beta_fast,
+ float beta_slow) {
+ return ggml_rope_impl(
+ ctx, a, b, c, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
+ ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, true
  );
  }
 
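ggml_rope_ext is ggml_rope_custom plus the extra tensor c of per-frequency factors (F32, at least n_dims/2 values, consumed by the NEOX path; pass NULL to disable them). A call sketch - the tensor names and the numeric hyperparameters below are placeholders, only the argument order comes from the code above:

    struct ggml_tensor * cur = ggml_rope_ext(
        ctx, x, positions, freq_factors,   // freq_factors may be NULL
        n_rot,
        /*mode       =*/ 2,                // bit value 2 selects NEOX-style RoPE here
        /*n_ctx      =*/ 0,
        /*n_orig_ctx =*/ n_ctx_orig,
        /*freq_base  =*/ 10000.0f, /*freq_scale =*/ 1.0f,
        /*ext_factor =*/ 0.0f,     /*attn_factor=*/ 1.0f,
        /*beta_fast  =*/ 32.0f,    /*beta_slow  =*/ 1.0f);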
@@ -6125,7 +6380,7 @@ struct ggml_tensor * ggml_rope_custom(
6125
6380
  float beta_fast,
6126
6381
  float beta_slow) {
6127
6382
  return ggml_rope_impl(
6128
- ctx, a, b, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
6383
+ ctx, a, b, NULL, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
6129
6384
  ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, false
6130
6385
  );
6131
6386
  }
@@ -6145,7 +6400,7 @@ struct ggml_tensor * ggml_rope_custom_inplace(
6145
6400
  float beta_fast,
6146
6401
  float beta_slow) {
6147
6402
  return ggml_rope_impl(
6148
- ctx, a, b, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
6403
+ ctx, a, b, NULL, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
6149
6404
  ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, true
6150
6405
  );
6151
6406
  }
@@ -6157,7 +6412,7 @@ struct ggml_tensor * ggml_rope_xpos_inplace(
6157
6412
  int n_dims,
6158
6413
  float base,
6159
6414
  bool down) {
6160
- return ggml_rope_impl(ctx, a, b, n_dims, 0, 0, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, base, down, true);
6415
+ return ggml_rope_impl(ctx, a, b, NULL, n_dims, 0, 0, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, base, down, true);
6161
6416
  }
6162
6417
 
6163
6418
  // ggml_rope_back
@@ -6166,6 +6421,7 @@ struct ggml_tensor * ggml_rope_back(
6166
6421
  struct ggml_context * ctx,
6167
6422
  struct ggml_tensor * a,
6168
6423
  struct ggml_tensor * b,
6424
+ struct ggml_tensor * c,
6169
6425
  int n_dims,
6170
6426
  int mode,
6171
6427
  int n_ctx,
@@ -6181,6 +6437,7 @@ struct ggml_tensor * ggml_rope_back(
6181
6437
  GGML_ASSERT(ggml_is_vector(b));
6182
6438
  GGML_ASSERT(b->type == GGML_TYPE_I32);
6183
6439
  GGML_ASSERT(a->ne[2] == b->ne[0]);
6440
+ GGML_ASSERT(c == NULL && "freq factors not implemented yet");
6184
6441
 
6185
6442
  GGML_ASSERT((mode & 4) == 0 && "ggml_rope_back() for ChatGLM not implemented yet");
6186
6443
 
@@ -6724,38 +6981,6 @@ struct ggml_tensor * ggml_top_k(
6724
6981
  return result;
6725
6982
  }
6726
6983
 
6727
- // ggml_flash_attn
6728
-
6729
- struct ggml_tensor * ggml_flash_attn(
6730
- struct ggml_context * ctx,
6731
- struct ggml_tensor * q,
6732
- struct ggml_tensor * k,
6733
- struct ggml_tensor * v,
6734
- bool masked) {
6735
- GGML_ASSERT(ggml_can_mul_mat(k, q));
6736
- // TODO: check if vT can be multiplied by (k*qT)
6737
-
6738
- bool is_node = false;
6739
-
6740
- if (q->grad || k->grad || v->grad) {
6741
- is_node = true;
6742
- }
6743
-
6744
- //struct ggml_tensor * result = ggml_dup_tensor(ctx, q);
6745
- struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, GGML_MAX_DIMS, q->ne);
6746
-
6747
- int32_t t = masked ? 1 : 0;
6748
- ggml_set_op_params(result, &t, sizeof(t));
6749
-
6750
- result->op = GGML_OP_FLASH_ATTN;
6751
- result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6752
- result->src[0] = q;
6753
- result->src[1] = k;
6754
- result->src[2] = v;
6755
-
6756
- return result;
6757
- }
6758
-
6759
6984
  // ggml_flash_attn_ext
6760
6985
 
6761
6986
  struct ggml_tensor * ggml_flash_attn_ext(
@@ -6815,38 +7040,6 @@ void ggml_flash_attn_ext_set_prec(
6815
7040
  ggml_set_op_params_i32(a, 2, prec_i32); // scale is on first pos, max_bias on second
6816
7041
  }
6817
7042
 
6818
- // ggml_flash_ff
6819
-
6820
- struct ggml_tensor * ggml_flash_ff(
6821
- struct ggml_context * ctx,
6822
- struct ggml_tensor * a,
6823
- struct ggml_tensor * b0,
6824
- struct ggml_tensor * b1,
6825
- struct ggml_tensor * c0,
6826
- struct ggml_tensor * c1) {
6827
- GGML_ASSERT(ggml_can_mul_mat(b0, a));
6828
- // TODO: more checks
6829
-
6830
- bool is_node = false;
6831
-
6832
- if (a->grad || b0->grad || b1->grad || c0->grad || c1->grad) {
6833
- is_node = true;
6834
- }
6835
-
6836
- //struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
6837
- struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, GGML_MAX_DIMS, a->ne);
6838
-
6839
- result->op = GGML_OP_FLASH_FF;
6840
- result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6841
- result->src[0] = a;
6842
- result->src[1] = b0;
6843
- result->src[2] = b1;
6844
- result->src[3] = c0;
6845
- result->src[4] = c1;
6846
-
6847
- return result;
6848
- }
6849
-
6850
7043
  // ggml_flash_attn_back
6851
7044
 
6852
7045
  struct ggml_tensor * ggml_flash_attn_back(
@@ -6856,6 +7049,8 @@ struct ggml_tensor * ggml_flash_attn_back(
6856
7049
  struct ggml_tensor * v,
6857
7050
  struct ggml_tensor * d,
6858
7051
  bool masked) {
7052
+ GGML_ASSERT(false && "TODO: adapt to ggml_flash_attn_ext() changes");
7053
+
6859
7054
  GGML_ASSERT(ggml_can_mul_mat(k, q));
6860
7055
  // TODO: check if vT can be multiplied by (k*qT)
6861
7056
 
@@ -10809,26 +11004,29 @@ static void ggml_compute_forward_concat_f32(
  GGML_ASSERT(nb00 == sizeof(float));
  GGML_ASSERT(nb10 == sizeof(float));
 
+ const int32_t dim = ggml_get_op_params_i32(dst, 0);
+
+ GGML_ASSERT(dim >= 0 && dim < 4);
+
+ int64_t o[4] = {0, 0, 0, 0};
+ o[dim] = src0->ne[dim];
+
+ const float * x;
+
+ // TODO: smarter multi-theading
  for (int i3 = 0; i3 < ne3; i3++) {
  for (int i2 = ith; i2 < ne2; i2 += nth) {
- if (i2 < ne02) { // src0
- for (int i1 = 0; i1 < ne1; i1++) {
- for (int i0 = 0; i0 < ne0; i0++) {
- const float * x = (float *)((char *) src0->data + i0 * nb00 + i1 * nb01 + i2 * nb02 + i3 * nb03);
-
- float * y = (float *)((char *)dst->data + i0 * nb0 + i1 * nb1 + i2 * nb2 + i3 * nb3);
- *y = *x;
- }
- }
- } // src1
- else {
- for (int i1 = 0; i1 < ne1; i1++) {
- for (int i0 = 0; i0 < ne0; i0++) {
- const float * x = (float *)((char *) src1->data + i0 * nb10 + i1 * nb11 + (i2 - ne02) * nb12 + i3 * nb13);
-
- float * y = (float *)((char *)dst->data + i0 * nb0 + i1 * nb1 + i2 * nb2 + i3 * nb3);
- *y = *x;
+ for (int i1 = 0; i1 < ne1; i1++) {
+ for (int i0 = 0; i0 < ne0; i0++) {
+ if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
+ x = (const float *) ((const char *)src0->data + (i0 )*nb00 + (i1 )*nb01 + (i2 )*nb02 + (i3 )*nb03);
+ } else {
+ x = (const float *) ((const char *)src1->data + (i0 - o[0])*nb10 + (i1 - o[1])*nb11 + (i2 - o[2])*nb12 + (i3 - o[3])*nb13);
  }
+
+ float * y = (float *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
+
+ *y = *x;
  }
  }
  }
@@ -10836,8 +11034,8 @@ static void ggml_compute_forward_concat_f32(
  }
 
  static void ggml_compute_forward_concat(
- const struct ggml_compute_params* params,
- struct ggml_tensor* dst) {
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
 
  const struct ggml_tensor * src0 = dst->src[0];
 
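The rewritten kernel drops the dim-2 special case: a destination index that falls inside src0's box is copied from src0, anything else is shifted back by o[] (non-zero only in the concat dimension) and copied from src1. The indexing rule in isolation, with element strides instead of ggml's byte strides (a sketch, not the kernel itself):

    #include <stdint.h>

    // Pick the source element for destination index i[4]; o[] is zero except
    // in the concat dimension, where it equals src0's extent.
    static float concat_read_sketch(const float *s0, const int64_t ne0[4],
                                    const float *s1, const int64_t o[4],
                                    const int64_t i[4],
                                    const int64_t st0[4], const int64_t st1[4]) {
        if (i[0] < ne0[0] && i[1] < ne0[1] && i[2] < ne0[2] && i[3] < ne0[3]) {
            return s0[i[0]*st0[0] + i[1]*st0[1] + i[2]*st0[2] + i[3]*st0[3]];
        }
        return s1[(i[0]-o[0])*st1[0] + (i[1]-o[1])*st1[1] +
                  (i[2]-o[2])*st1[2] + (i[3]-o[3])*st1[3]];
    }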
@@ -11230,8 +11428,8 @@ static void ggml_compute_forward_gelu_f32(
11230
11428
 
11231
11429
  const struct ggml_tensor * src0 = dst->src[0];
11232
11430
 
11233
- GGML_ASSERT(ggml_is_contiguous_except_dim_1(src0));
11234
- GGML_ASSERT(ggml_is_contiguous_except_dim_1(dst));
11431
+ GGML_ASSERT(ggml_is_contiguous_1(src0));
11432
+ GGML_ASSERT(ggml_is_contiguous_1(dst));
11235
11433
  GGML_ASSERT(ggml_are_same_shape(src0, dst));
11236
11434
 
11237
11435
  if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
@@ -11293,8 +11491,8 @@ static void ggml_compute_forward_gelu_quick_f32(
11293
11491
 
11294
11492
  const struct ggml_tensor * src0 = dst->src[0];
11295
11493
 
11296
- GGML_ASSERT(ggml_is_contiguous_except_dim_1(src0));
11297
- GGML_ASSERT(ggml_is_contiguous_except_dim_1(dst));
11494
+ GGML_ASSERT(ggml_is_contiguous_1(src0));
11495
+ GGML_ASSERT(ggml_is_contiguous_1(dst));
11298
11496
  GGML_ASSERT(ggml_are_same_shape(src0, dst));
11299
11497
 
11300
11498
  if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
@@ -11356,8 +11554,8 @@ static void ggml_compute_forward_silu_f32(
11356
11554
 
11357
11555
  const struct ggml_tensor * src0 = dst->src[0];
11358
11556
 
11359
- GGML_ASSERT(ggml_is_contiguous_except_dim_1(src0));
11360
- GGML_ASSERT(ggml_is_contiguous_except_dim_1(dst));
11557
+ GGML_ASSERT(ggml_is_contiguous_1(src0));
11558
+ GGML_ASSERT(ggml_is_contiguous_1(dst));
11361
11559
  GGML_ASSERT(ggml_are_same_shape(src0, dst));
11362
11560
 
11363
11561
  if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
@@ -11468,9 +11666,9 @@ static void ggml_compute_forward_silu_back_f32(
11468
11666
  const struct ggml_tensor * src0 = dst->src[0];
11469
11667
  const struct ggml_tensor * grad = dst->src[1];
11470
11668
 
11471
- GGML_ASSERT(ggml_is_contiguous_except_dim_1(grad));
11472
- GGML_ASSERT(ggml_is_contiguous_except_dim_1(src0));
11473
- GGML_ASSERT(ggml_is_contiguous_except_dim_1(dst));
11669
+ GGML_ASSERT(ggml_is_contiguous_1(grad));
11670
+ GGML_ASSERT(ggml_is_contiguous_1(src0));
11671
+ GGML_ASSERT(ggml_is_contiguous_1(dst));
11474
11672
  GGML_ASSERT(ggml_are_same_shape(src0, dst));
11475
11673
  GGML_ASSERT(ggml_are_same_shape(src0, grad));
11476
11674
 
@@ -14115,6 +14313,7 @@ static void ggml_compute_forward_rope_f32(
14115
14313
 
14116
14314
  const struct ggml_tensor * src0 = dst->src[0];
14117
14315
  const struct ggml_tensor * src1 = dst->src[1];
14316
+ const struct ggml_tensor * src2 = dst->src[2];
14118
14317
 
14119
14318
  if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
14120
14319
  return;
@@ -14167,13 +14366,24 @@ static void ggml_compute_forward_rope_f32(
14167
14366
  int ir = 0;
14168
14367
 
14169
14368
  const float theta_scale = powf(freq_base, -2.0f/n_dims);
14170
- const float inv_ndims = -1.f/n_dims;
14369
+
14171
14370
  float corr_dims[2];
14172
14371
  ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
14173
14372
 
14174
14373
  const bool is_neox = mode & 2;
14175
14374
  const bool is_glm = mode & 4;
14176
14375
 
14376
+ const float * freq_factors = NULL;
14377
+ if (is_neox) {
14378
+ if (src2 != NULL) {
14379
+ GGML_ASSERT(src2->type == GGML_TYPE_F32);
14380
+ GGML_ASSERT(src2->ne[0] >= n_dims / 2);
14381
+ freq_factors = (const float *) src2->data;
14382
+ }
14383
+ } else {
14384
+ GGML_ASSERT(src2 == NULL && "TODO: freq_factors not implemented for !is_neox");
14385
+ }
14386
+
14177
14387
  // backward process uses inverse rotation by cos and sin.
14178
14388
  // cos and sin build a rotation matrix, where the inverse is the transpose.
14179
14389
  // this essentially just switches the sign of sin.
@@ -14205,7 +14415,7 @@ static void ggml_compute_forward_rope_f32(
14205
14415
  const float cos_block_theta = cosf(block_theta);
14206
14416
  const float sin_block_theta = sinf(block_theta) * sin_sign;
14207
14417
 
14208
- theta_base *= theta_scale;
14418
+ theta_base *= theta_scale;
14209
14419
  block_theta *= theta_scale;
14210
14420
 
14211
14421
  const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
@@ -14240,28 +14450,22 @@ static void ggml_compute_forward_rope_f32(
  dst_data[1] = x0*sin_theta*zeta + x1*cos_theta*zeta;
  }
  } else {
- // TODO: this might be wrong for ne0 != n_dims - need double check
- // it seems we have to rope just the first n_dims elements and do nothing with the rest
- // ref: https://github.com/ml-explore/mlx/blob/dc2edc762c797e3b8de50b1dad4dc0a131691033/benchmarks/python/llama_jax_bench.py#L11-L26
- theta_base *= freq_scale;
+ // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
  for (int64_t ic = 0; ic < ne0; ic += 2) {
  if (ic < n_dims) {
- const int64_t ib = 0;
+ const int64_t i0 = ic/2;
 
- // simplified from `(ib * n_dims + ic) * inv_ndims`
- float cur_rot = inv_ndims * ic - ib;
+ const float freq_factor = freq_factors ? freq_factors[i0] : 1.0f;
 
  float cos_theta, sin_theta;
  rope_yarn(
- theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
+ theta_base/freq_factor, freq_scale, corr_dims, ic, ext_factor, attn_factor,
  &cos_theta, &sin_theta
  );
- sin_theta *= sin_sign;
 
+ sin_theta *= sin_sign;
  theta_base *= theta_scale;
 
- const int64_t i0 = ib*n_dims + ic/2;
-
  const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
  float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
 
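With freq_factors present, every rotated pair i0 = ic/2 divides its base angle by freq_factors[i0] before the YaRN correction; all-ones factors reproduce the old behaviour. A scalar sketch of a plain NEOX-style rotation with such factors (YaRN extrapolation and the sin_sign/backward handling above are omitted):

    #include <math.h>

    // Rotate the first n_dims channels of one row in place, NEOX pairing
    // (i0, i0 + n_dims/2); ff may be NULL for the unscaled case.
    static void rope_neox_row_sketch(float *row, int n_dims, int pos,
                                     float freq_base, const float *ff) {
        for (int i0 = 0; i0 < n_dims/2; i0++) {
            const float factor = ff ? ff[i0] : 1.0f;
            const float theta  = (float) pos * powf(freq_base, -2.0f*i0/n_dims) / factor;
            const float c = cosf(theta), s = sinf(theta);
            const float x0 = row[i0];
            const float x1 = row[i0 + n_dims/2];
            row[i0]            = x0*c - x1*s;
            row[i0 + n_dims/2] = x0*s + x1*c;
        }
    }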
@@ -14286,6 +14490,7 @@ static void ggml_compute_forward_rope_f32(
14286
14490
  }
14287
14491
  }
14288
14492
 
14493
+ // TODO: deduplicate f16/f32 code
14289
14494
  static void ggml_compute_forward_rope_f16(
14290
14495
  const struct ggml_compute_params * params,
14291
14496
  struct ggml_tensor * dst,
@@ -14293,6 +14498,7 @@ static void ggml_compute_forward_rope_f16(
14293
14498
 
14294
14499
  const struct ggml_tensor * src0 = dst->src[0];
14295
14500
  const struct ggml_tensor * src1 = dst->src[1];
14501
+ const struct ggml_tensor * src2 = dst->src[2];
14296
14502
 
14297
14503
  if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
14298
14504
  return;
@@ -14338,13 +14544,24 @@ static void ggml_compute_forward_rope_f16(
14338
14544
  int ir = 0;
14339
14545
 
14340
14546
  const float theta_scale = powf(freq_base, -2.0f/n_dims);
14341
- const float inv_ndims = -1.f/n_dims;
14547
+
14342
14548
  float corr_dims[2];
14343
14549
  ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
14344
14550
 
14345
14551
  const bool is_neox = mode & 2;
14346
14552
  const bool is_glm = mode & 4;
14347
14553
 
14554
+ const float * freq_factors = NULL;
14555
+ if (is_neox) {
14556
+ if (src2 != NULL) {
14557
+ GGML_ASSERT(src2->type == GGML_TYPE_F32);
14558
+ GGML_ASSERT(src2->ne[0] >= n_dims / 2);
14559
+ freq_factors = (const float *) src2->data;
14560
+ }
14561
+ } else {
14562
+ GGML_ASSERT(src2 == NULL && "TODO: freq_factors not implemented for !is_neox");
14563
+ }
14564
+
14348
14565
  // backward process uses inverse rotation by cos and sin.
14349
14566
  // cos and sin build a rotation matrix, where the inverse is the transpose.
14350
14567
  // this essentially just switches the sign of sin.
@@ -14376,7 +14593,7 @@ static void ggml_compute_forward_rope_f16(
14376
14593
  const float cos_block_theta = cosf(block_theta);
14377
14594
  const float sin_block_theta = sinf(block_theta) * sin_sign;
14378
14595
 
14379
- theta_base *= theta_scale;
14596
+ theta_base *= theta_scale;
14380
14597
  block_theta *= theta_scale;
14381
14598
 
14382
14599
  const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
@@ -14407,28 +14624,22 @@ static void ggml_compute_forward_rope_f16(
14407
14624
  dst_data[1] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
14408
14625
  }
14409
14626
  } else {
14410
- // TODO: this might be wrong for ne0 != n_dims - need double check
14411
- // it seems we have to rope just the first n_dims elements and do nothing with the rest
14412
- // ref: https://github.com/ml-explore/mlx/blob/dc2edc762c797e3b8de50b1dad4dc0a131691033/benchmarks/python/llama_jax_bench.py#L11-L26
14413
- theta_base *= freq_scale;
14627
+ // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
14414
14628
  for (int64_t ic = 0; ic < ne0; ic += 2) {
14415
14629
  if (ic < n_dims) {
14416
- const int64_t ib = 0;
14630
+ const int64_t i0 = ic/2;
14417
14631
 
14418
- // simplified from `(ib * n_dims + ic) * inv_ndims`
14419
- float cur_rot = inv_ndims * ic - ib;
14632
+ const float freq_factor = freq_factors ? freq_factors[i0] : 1.0f;
14420
14633
 
14421
14634
  float cos_theta, sin_theta;
14422
14635
  rope_yarn(
14423
- theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
14636
+ theta_base/freq_factor, freq_scale, corr_dims, ic, ext_factor, attn_factor,
14424
14637
  &cos_theta, &sin_theta
14425
14638
  );
14426
- sin_theta *= sin_sign;
14427
14639
 
14640
+ sin_theta *= sin_sign;
14428
14641
  theta_base *= theta_scale;
14429
14642
 
14430
- const int64_t i0 = ib*n_dims + ic/2;
14431
-
14432
14643
  const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
14433
14644
  ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
14434
14645
 
@@ -15458,400 +15669,6 @@ static void ggml_compute_forward_argsort(
15458
15669
  }
15459
15670
  }
15460
15671
 
15461
- // ggml_compute_forward_flash_attn
15462
-
15463
- static void ggml_compute_forward_flash_attn_f32(
15464
- const struct ggml_compute_params * params,
15465
- const bool masked,
15466
- struct ggml_tensor * dst) {
15467
-
15468
- const struct ggml_tensor * q = dst->src[0];
15469
- const struct ggml_tensor * k = dst->src[1];
15470
- const struct ggml_tensor * v = dst->src[2];
15471
-
15472
- int64_t t0 = ggml_perf_time_us();
15473
- UNUSED(t0);
15474
-
15475
- GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
15476
- GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
15477
- GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
15478
- GGML_TENSOR_LOCALS(size_t, nbk, k, nb)
15479
- GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
15480
- GGML_TENSOR_LOCALS(size_t, nbv, v, nb)
15481
- GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
15482
- GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
15483
-
15484
- const int ith = params->ith;
15485
- const int nth = params->nth;
15486
-
15487
- const int64_t D = neq0;
15488
- const int64_t N = neq1;
15489
- const int64_t P = nek1 - N;
15490
- const int64_t M = P + N;
15491
-
15492
- const int Mup = ggml_up(M, GGML_SOFT_MAX_UNROLL);
15493
-
15494
- GGML_ASSERT(ne0 == D);
15495
- GGML_ASSERT(ne1 == N);
15496
- GGML_ASSERT(P >= 0);
15497
-
15498
- GGML_ASSERT(nbq0 == sizeof(float));
15499
- GGML_ASSERT(nbk0 == sizeof(float));
15500
- GGML_ASSERT(nbv0 == sizeof(float));
15501
-
15502
- GGML_ASSERT(neq0 == D);
15503
- GGML_ASSERT(nek0 == D);
15504
- GGML_ASSERT(nev1 == D);
15505
-
15506
- GGML_ASSERT(neq1 == N);
15507
- GGML_ASSERT(nek1 == N + P);
15508
- GGML_ASSERT(nev1 == D);
15509
-
15510
- // dst cannot be transposed or permuted
15511
- GGML_ASSERT(nb0 == sizeof(float));
15512
- GGML_ASSERT(nb0 <= nb1);
15513
- GGML_ASSERT(nb1 <= nb2);
15514
- GGML_ASSERT(nb2 <= nb3);
15515
-
15516
- if (params->type == GGML_TASK_TYPE_INIT) {
15517
- return;
15518
- }
15519
-
15520
- if (params->type == GGML_TASK_TYPE_FINALIZE) {
15521
- return;
15522
- }
15523
-
15524
- // parallelize by q rows using ggml_vec_dot_f32
15525
-
15526
- // total rows in q
15527
- const int nr = neq1*neq2*neq3;
15528
-
15529
- // rows per thread
15530
- const int dr = (nr + nth - 1)/nth;
15531
-
15532
- // row range for this thread
15533
- const int ir0 = dr*ith;
15534
- const int ir1 = MIN(ir0 + dr, nr);
15535
-
15536
- const float scale = 1.0f/sqrtf(D);
15537
-
15538
- //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale);
15539
-
15540
- for (int ir = ir0; ir < ir1; ++ir) {
15541
- // q indices
15542
- const int iq3 = ir/(neq2*neq1);
15543
- const int iq2 = (ir - iq3*neq2*neq1)/neq1;
15544
- const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
15545
-
15546
- float * S = (float *) params->wdata + ith*(Mup + CACHE_LINE_SIZE_F32);
15547
-
15548
- for (int i = M; i < Mup; ++i) {
15549
- S[i] = -INFINITY;
15550
- }
15551
-
15552
- const int64_t masked_begin = masked ? (P + iq1 + 1) : M;
15553
- for (int64_t ic = 0; ic < masked_begin; ++ic) {
15554
- // k indices
15555
- const int ik3 = iq3;
15556
- const int ik2 = iq2 % nek2;
15557
- const int ik1 = ic;
15558
-
15559
- // S indices
15560
- const int i1 = ik1;
15561
-
15562
- ggml_vec_dot_f32(neq0,
15563
- S + i1, 0,
15564
- (float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), 0,
15565
- (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)), 0, 1);
15566
- }
15567
-
15568
- // scale
15569
- ggml_vec_scale_f32(masked_begin, S, scale);
15570
-
15571
- for (int64_t i = masked_begin; i < M; i++) {
15572
- S[i] = -INFINITY;
15573
- }
15574
-
15575
- // softmax
15576
- // exclude known -INF S[..] values from max and loop
15577
- // dont forget to set their SW values to zero
15578
- {
15579
- float max = -INFINITY;
15580
- ggml_vec_max_f32(masked_begin, &max, S);
15581
-
15582
- ggml_float sum = 0.0;
15583
- {
15584
- #ifdef GGML_SOFT_MAX_ACCELERATE
15585
- max = -max;
15586
- vDSP_vsadd(S, 1, &max, S, 1, Mup);
15587
- vvexpf(S, S, &Mup);
15588
- ggml_vec_sum_f32(Mup, &sum, S);
15589
- #else
15590
- sum = ggml_vec_soft_max_f32(Mup, S, S, max);
15591
- #endif
15592
- }
15593
-
15594
- assert(sum > 0.0);
15595
-
15596
- sum = 1.0/sum;
15597
- ggml_vec_scale_f32(masked_begin, S, sum);
15598
-
15599
- #ifndef NDEBUG
15600
- for (int i = 0; i < masked_begin; ++i) {
15601
- assert(!isnan(S[i]));
15602
- assert(!isinf(S[i]));
15603
- }
15604
- #endif
15605
- }
15606
-
15607
- for (int64_t ic = 0; ic < nev1; ++ic) {
15608
- // dst indices
15609
- const int i1 = iq1;
15610
- const int i2 = iq2;
15611
- const int i3 = iq3;
15612
-
15613
- // v indices
15614
- const int iv2 = iq2 % nev2;
15615
- const int iv3 = iq3;
15616
-
15617
- ggml_vec_dot_f32(masked_begin,
15618
- (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), 0,
15619
- (float *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)), 0,
15620
- S, 0, 1);
15621
- }
15622
- }
15623
- }
15624
-
15625
- static void ggml_compute_forward_flash_attn_f16(
15626
- const struct ggml_compute_params * params,
15627
- const bool masked,
15628
- struct ggml_tensor * dst) {
15629
-
15630
- const struct ggml_tensor * q = dst->src[0];
15631
- const struct ggml_tensor * k = dst->src[1];
15632
- const struct ggml_tensor * v = dst->src[2];
15633
-
15634
- int64_t t0 = ggml_perf_time_us();
15635
- UNUSED(t0);
15636
-
15637
- GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
15638
- GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
15639
- GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
15640
- GGML_TENSOR_LOCALS(size_t, nbk, k, nb)
15641
- GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
15642
- GGML_TENSOR_LOCALS(size_t, nbv, v, nb)
15643
- GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
15644
- GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
15645
-
15646
- const int ith = params->ith;
15647
- const int nth = params->nth;
15648
-
15649
- const int64_t D = neq0;
15650
- const int64_t N = neq1;
15651
- const int64_t P = nek1 - N;
15652
- const int64_t M = P + N;
15653
-
15654
- const int Mup = ggml_up(M, GGML_SOFT_MAX_UNROLL);
15655
-
15656
- GGML_ASSERT(ne0 == D);
15657
- GGML_ASSERT(ne1 == N);
15658
- GGML_ASSERT(P >= 0);
15659
-
15660
- GGML_ASSERT(nbq0 == sizeof(ggml_fp16_t));
15661
- GGML_ASSERT(nbk0 == sizeof(ggml_fp16_t));
15662
- GGML_ASSERT(nbv0 == sizeof(ggml_fp16_t));
15663
-
15664
- GGML_ASSERT(neq0 == D);
15665
- GGML_ASSERT(nek0 == D);
15666
- GGML_ASSERT(nev1 == D);
15667
-
15668
- GGML_ASSERT(neq1 == N);
15669
- GGML_ASSERT(nek1 == N + P);
15670
- GGML_ASSERT(nev1 == D);
15671
-
15672
- // dst cannot be transposed or permuted
15673
- GGML_ASSERT(nb0 == sizeof(float));
15674
- GGML_ASSERT(nb0 <= nb1);
15675
- GGML_ASSERT(nb1 <= nb2);
15676
- GGML_ASSERT(nb2 <= nb3);
15677
-
15678
- if (params->type == GGML_TASK_TYPE_INIT) {
15679
- return;
15680
- }
15681
-
15682
- if (params->type == GGML_TASK_TYPE_FINALIZE) {
15683
- return;
15684
- }
15685
-
15686
- // parallelize by q rows using ggml_vec_dot_f32
15687
-
15688
- // total rows in q
15689
- const int nr = neq1*neq2*neq3;
15690
-
15691
- // rows per thread
15692
- const int dr = (nr + nth - 1)/nth;
15693
-
15694
- // row range for this thread
15695
- const int ir0 = dr*ith;
15696
- const int ir1 = MIN(ir0 + dr, nr);
15697
-
15698
- const float scale = 1.0f/sqrtf(D);
15699
-
15700
- //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale);
15701
-
15702
- for (int ir = ir0; ir < ir1; ++ir) {
15703
- // q indices
15704
- const int iq3 = ir/(neq2*neq1);
15705
- const int iq2 = (ir - iq3*neq2*neq1)/neq1;
15706
- const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
15707
-
15708
- float * S = (float *) params->wdata + ith*(2*Mup + CACHE_LINE_SIZE_F32);
15709
-
15710
- for (int i = M; i < Mup; ++i) {
15711
- S[i] = -INFINITY;
15712
- }
15713
-
15714
- if (GGML_VEC_DOT_UNROLL > 2 || nek1 % GGML_VEC_DOT_UNROLL != 0) {
15715
- for (int64_t ic = 0; ic < nek1; ++ic) {
15716
- // k indices
15717
- const int ik3 = iq3;
15718
- const int ik2 = iq2 % nek2;
15719
- const int ik1 = ic;
15720
-
15721
- // S indices
15722
- const int i1 = ik1;
15723
-
15724
- ggml_vec_dot_f16(neq0,
15725
- S + i1, 0,
15726
- (ggml_fp16_t *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), 0,
15727
- (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)), 0, 1);
15728
- }
15729
- } else {
15730
- for (int64_t ic = 0; ic < nek1; ic += GGML_VEC_DOT_UNROLL) {
15731
- // k indices
15732
- const int ik3 = iq3;
15733
- const int ik2 = iq2 % nek2;
15734
- const int ik1 = ic;
15735
-
15736
- // S indices
15737
- const int i1 = ik1;
15738
-
15739
- ggml_vec_dot_f16_unroll(neq0, nbk1,
15740
- S + i1,
15741
- ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)),
15742
- (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
15743
- }
15744
- }
15745
-
15746
- // scale
15747
- ggml_vec_scale_f32(nek1, S, scale);
15748
-
15749
- if (masked) {
15750
- for (int64_t i = P; i < M; i++) {
15751
- if (i > P + iq1) {
15752
- S[i] = -INFINITY;
15753
- }
15754
- }
15755
- }
15756
-
15757
- // softmax
15758
- // todo: exclude known -INF S[..] values from max and loop, assuming their results to be zero.
15759
- // dont forget to set their S values to zero
15760
- {
15761
- float max = -INFINITY;
15762
- ggml_vec_max_f32(M, &max, S);
15763
-
15764
- ggml_float sum = 0.0;
15765
- {
15766
- #ifdef GGML_SOFT_MAX_ACCELERATE
15767
- max = -max;
15768
- vDSP_vsadd(S, 1, &max, S, 1, Mup);
15769
- vvexpf(S, S, &Mup);
15770
- ggml_vec_sum_f32(Mup, &sum, S);
15771
- #else
15772
- sum = ggml_vec_soft_max_f32(Mup, S, S, max);
15773
- #endif
15774
- }
15775
-
15776
- assert(sum > 0.0);
15777
-
15778
- sum = 1.0/sum;
15779
- ggml_vec_scale_f32(M, S, sum);
15780
-
15781
- #ifndef NDEBUG
15782
- for (int i = 0; i < M; ++i) {
15783
- assert(!isnan(S[i]));
15784
- assert(!isinf(S[i]));
15785
- }
15786
- #endif
15787
- }
15788
-
15789
- ggml_fp16_t * S16 = (ggml_fp16_t *) ((float *) params->wdata + ith*(2*Mup + CACHE_LINE_SIZE_F32) + Mup);
15790
-
15791
- for (int64_t i = 0; i < M; i++) {
15792
- S16[i] = GGML_FP32_TO_FP16(S[i]);
15793
- }
15794
-
15795
- // todo: exclude known zero S[..] values from dot (reducing nev0 and increasing begin of v and S16).
15796
- if (GGML_VEC_DOT_UNROLL == 1 || (nev1 % GGML_VEC_DOT_UNROLL != 0)) {
15797
- for (int64_t ic = 0; ic < nev1; ++ic) {
15798
- // dst indices
15799
- const int i1 = iq1;
15800
- const int i2 = iq2;
15801
- const int i3 = iq3;
15802
-
15803
- // v indices
15804
- const int iv2 = iq2 % nev2;
15805
- const int iv3 = iq3;
15806
-
15807
- ggml_vec_dot_f16(nev0,
15808
- (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), 0,
15809
- (ggml_fp16_t *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)), 0,
15810
- S16, 0, 1);
15811
- }
15812
- } else {
15813
- for (int64_t ic = 0; ic < nev1; ic += GGML_VEC_DOT_UNROLL) {
15814
- // dst indices
15815
- const int i1 = iq1;
15816
- const int i2 = iq2;
15817
- const int i3 = iq3;
15818
-
15819
- // v indices
15820
- const int iv2 = iq2 % nev2;
15821
- const int iv3 = iq3;
15822
-
15823
- ggml_vec_dot_f16_unroll(nev0, nbv1,
15824
- (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
15825
- ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)),
15826
- S16);
15827
- }
15828
- }
15829
- }
15830
- }
15831
-
15832
- static void ggml_compute_forward_flash_attn(
15833
- const struct ggml_compute_params * params,
15834
- const bool masked,
15835
- struct ggml_tensor * dst) {
15836
-
15837
- const struct ggml_tensor * q = dst->src[0];
15838
-
15839
- switch (q->type) {
15840
- case GGML_TYPE_F16:
15841
- {
15842
- ggml_compute_forward_flash_attn_f16(params, masked, dst);
15843
- } break;
15844
- case GGML_TYPE_F32:
15845
- {
15846
- ggml_compute_forward_flash_attn_f32(params, masked, dst);
15847
- } break;
15848
- default:
15849
- {
15850
- GGML_ASSERT(false);
15851
- } break;
15852
- }
15853
- }
15854
-
15855
15672
  // ggml_compute_forward_flash_attn_ext
15856
15673
 
15857
15674
  static void ggml_compute_forward_flash_attn_ext_f16(
@@ -15882,9 +15699,10 @@ static void ggml_compute_forward_flash_attn_ext_f16(
15882
15699
  GGML_ASSERT(ne0 == D);
15883
15700
  GGML_ASSERT(ne2 == N);
15884
15701
 
15885
- GGML_ASSERT(nbq0 == sizeof(float));
15886
- GGML_ASSERT(nbk0 == sizeof(ggml_fp16_t));
15887
- GGML_ASSERT(nbv0 == sizeof(ggml_fp16_t));
15702
+ // input tensor rows must be contiguous
15703
+ GGML_ASSERT(nbq0 == ggml_type_size(q->type));
15704
+ GGML_ASSERT(nbk0 == ggml_type_size(k->type));
15705
+ GGML_ASSERT(nbv0 == ggml_type_size(v->type));
15888
15706
 
15889
15707
  GGML_ASSERT(neq0 == D);
15890
15708
  GGML_ASSERT(nek0 == D);
@@ -15938,6 +15756,11 @@ static void ggml_compute_forward_flash_attn_ext_f16(
15938
15756
  const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
15939
15757
  const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
15940
15758
 
15759
+ enum ggml_type const k_vec_dot_type = type_traits[k->type].vec_dot_type;
15760
+ ggml_from_float_t const q_to_vec_dot = type_traits[k_vec_dot_type].from_float;
15761
+ ggml_vec_dot_t const kq_vec_dot = type_traits[k->type].vec_dot;
15762
+ ggml_to_float_t const v_to_float = type_traits[v->type].to_float;
15763
+
15941
15764
  // loop over n_batch and n_head
15942
15765
  for (int ir = ir0; ir < ir1; ++ir) {
15943
15766
  // q indices
@@ -15945,17 +15768,22 @@ static void ggml_compute_forward_flash_attn_ext_f16(
15945
15768
  const int iq2 = (ir - iq3*neq2*neq1)/neq1;
15946
15769
  const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
15947
15770
 
15948
- const uint32_t h = iq2; // head
15771
+ const uint32_t h = iq2; // head index
15949
15772
  const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
15950
15773
 
15951
- float S = 0.0f;
15952
- float M = -INFINITY;
15774
+ float S = 0.0f; // sum
15775
+ float M = -INFINITY; // maximum KQ value
15953
15776
 
15954
- float * V32 = (float *) params->wdata + ith*(2*D + CACHE_LINE_SIZE_F32);
15955
- ggml_fp16_t * Q16 = (ggml_fp16_t *) (V32); // reuse memory
15956
- ggml_fp16_t * V16 = (ggml_fp16_t *) (V32 + D);
15777
+ float * VKQ32 = (float *) params->wdata + ith*(3*D + CACHE_LINE_SIZE_F32); // FP32 VKQ accumulator
15778
+ float * V32 = (VKQ32 + 1*D); // (temporary) FP32 V buffer
15779
+ ggml_fp16_t * VKQ16 = (ggml_fp16_t *) (VKQ32 + 1*D); // (temporary) FP16 VKQ accumulator
15780
+ ggml_fp16_t * Q_q = (ggml_fp16_t *) (VKQ32 + 2*D); // (temporary) buffer for Q converted to quantized/FP16
15957
15781
 
15958
- memset(V16, 0, D*sizeof(ggml_fp16_t));
15782
+ if (v->type == GGML_TYPE_F16) {
15783
+ memset(VKQ16, 0, D*sizeof(ggml_fp16_t));
15784
+ } else {
15785
+ memset(VKQ32, 0, D*sizeof(float));
15786
+ }
15959
15787
 
15960
15788
  const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL;
15961
15789
 
@@ -15967,6 +15795,9 @@ static void ggml_compute_forward_flash_attn_ext_f16(
15967
15795
  const int iv3 = iq3 / rv3;
15968
15796
  const int iv2 = iq2 / rv2;
15969
15797
 
15798
+ const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3));
15799
+ q_to_vec_dot(pq, Q_q, D);
15800
+
15970
15801
  // online softmax / attention
15971
15802
  // loop over n_kv and n_head_kv
15972
15803
  // ref: https://arxiv.org/pdf/2112.05682.pdf
@@ -15976,52 +15807,67 @@ static void ggml_compute_forward_flash_attn_ext_f16(
15976
15807
  continue;
15977
15808
  }
15978
15809
 
15979
- float s;
15810
+ float s; // KQ value
15980
15811
 
15981
- // convert Q to F16 in V32
15982
- {
15983
- const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3));
15812
+ const char * k_data = (const char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3);
15813
+ kq_vec_dot(D, &s, 0, k_data, 0, Q_q, 0, 1);
15984
15814
 
15985
- for (int64_t d = 0; d < D; ++d) {
15986
- Q16[d] = GGML_FP32_TO_FP16(pq[d]);
15987
- }
15988
- }
15815
+ s = s*scale + mv; // scale KQ value and apply mask
15989
15816
 
15990
- ggml_vec_dot_f16(D,
15991
- &s, 0,
15992
- (ggml_fp16_t *) ((char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3)), 0,
15993
- Q16, 0, 1);
15817
+ const float Mold = M;
15994
15818
 
15995
- s = s*scale + mv;
15819
+ float ms = 1.0f; // upon new higher max val, scale VKQ and KQ sum with this value
15820
+ float vs = 1.0f; // post-softmax KQ value, expf(s - M)
15996
15821
 
15997
- const float Mold = M;
15822
+ const char * v_data = ((const char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3));
15998
15823
 
15999
- float ms = 1.0f;
16000
- float vs = 1.0f;
15824
+ if (v->type== GGML_TYPE_F16) {
15825
+ if (s > M) {
15826
+ // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
15827
+ M = s;
15828
+ ms = expf(Mold - M);
16001
15829
 
16002
- if (s > M) {
16003
- M = s;
16004
- ms = expf(Mold - M);
15830
+ // V = V*expf(Mold - M)
15831
+ ggml_vec_scale_f16(D, VKQ16, ms);
15832
+ } else {
15833
+ // no new maximum, ms == 1.0f, vs != 1.0f
15834
+ vs = expf(s - M);
15835
+ }
16005
15836
 
16006
- // V = V*expf(Mold - M)
16007
- ggml_vec_scale_f16(D, V16, ms);
15837
+ // V += v*expf(s - M)
15838
+ ggml_vec_mad_f16(D, VKQ16, (const ggml_fp16_t *) v_data, vs);
16008
15839
  } else {
16009
- vs = expf(s - M);
16010
- }
15840
+ if (s > M) {
15841
+ // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
15842
+ M = s;
15843
+ ms = expf(Mold - M);
16011
15844
 
16012
- const ggml_fp16_t * v16 = (const ggml_fp16_t *) ((char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3));
15845
+ // V = V*expf(Mold - M)
15846
+ ggml_vec_scale_f32(D, VKQ32, ms);
15847
+ } else {
15848
+ // no new maximum, ms == 1.0f, vs != 1.0f
15849
+ vs = expf(s - M);
15850
+ }
16013
15851
 
16014
- // V += v*expf(s - M)
16015
- ggml_vec_mad_f16(D, V16, v16, vs);
15852
+ v_to_float(v_data, V32, D);
15853
+
15854
+ // V += v*expf(s - M)
15855
+ ggml_vec_mad_f32(D, VKQ32, V32, vs);
15856
+ }
16016
15857
 
16017
- S = S*ms + vs;
15858
+ S = S*ms + vs; // scale and increment sum with partial sum
16018
15859
  }
16019
15860
 
16020
- // V /= S
16021
- for (int64_t d = 0; d < D; ++d) {
16022
- V32[d] = GGML_FP16_TO_FP32(V16[d])/S;
15861
+ if (v->type == GGML_TYPE_F16) {
15862
+ for (int64_t d = 0; d < D; ++d) {
15863
+ VKQ32[d] = GGML_FP16_TO_FP32(VKQ16[d]);
15864
+ }
16023
15865
  }
16024
15866
 
15867
+ // V /= S
15868
+ const float S_inv = 1.0f/S;
15869
+ ggml_vec_scale_f32(D, VKQ32, S_inv);
15870
+
16025
15871
  // dst indices
16026
15872
  const int i1 = iq1;
16027
15873
  const int i2 = iq2;
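Stripped of the ggml types and vec helpers, the loop above is the online-softmax recurrence from the arXiv reference in the comment, now duplicated for the F16 and non-F16 V paths. A float-only sketch of the per-row update (the plain-array interface and names are illustrative, not the library code):

    #include <math.h>
    #include <stdint.h>

    // Online softmax / attention for one output row.
    //   s_all : n_kv pre-softmax KQ values, already scaled and masked
    //   v_rows: n_kv rows of V, each of length D
    //   out   : D accumulators, receives softmax(s)*V
    static void attn_row_online(const float * s_all, const float * v_rows,
                                float * out, int64_t n_kv, int64_t D) {
        float S = 0.0f;       // running sum of expf(s - M)
        float M = -INFINITY;  // running maximum KQ value

        for (int64_t d = 0; d < D; ++d) out[d] = 0.0f;

        for (int64_t ic = 0; ic < n_kv; ++ic) {
            const float s = s_all[ic];

            float ms = 1.0f;  // rescale factor for the accumulator when M grows
            float vs = 1.0f;  // expf(s - M) for the current V row

            if (s > M) {
                ms = expf(M - s);  // == expf(Mold - Mnew)
                M  = s;
                for (int64_t d = 0; d < D; ++d) out[d] *= ms;   // V = V*expf(Mold - M)
            } else {
                vs = expf(s - M);
            }

            const float * v = v_rows + ic*D;
            for (int64_t d = 0; d < D; ++d) out[d] += v[d]*vs;  // V += v*expf(s - M)

            S = S*ms + vs;    // keep the softmax denominator in sync
        }

        const float S_inv = 1.0f/S;
        for (int64_t d = 0; d < D; ++d) out[d] *= S_inv;        // V /= S
    }

Rescaling the accumulator by expf(Mold - M) whenever the running maximum grows keeps every exponential bounded by 1, so the whole KV range can be processed in a single pass.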
@@ -16031,7 +15877,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
16031
15877
  //memcpy((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3), V, nev0*sizeof(float));
16032
15878
 
16033
15879
  // permute(0, 2, 1, 3)
16034
- memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, V32, nb1);
15880
+ memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32, nb1);
16035
15881
  }
16036
15882
  }
16037
15883
 
@@ -16056,165 +15902,6 @@ static void ggml_compute_forward_flash_attn_ext(
16056
15902
  }
16057
15903
  }
16058
15904
 
16059
- // ggml_compute_forward_flash_ff
16060
-
16061
- static void ggml_compute_forward_flash_ff_f16(
16062
- const struct ggml_compute_params * params,
16063
- struct ggml_tensor * dst) {
16064
-
16065
- const struct ggml_tensor * a = dst->src[0]; // F16
16066
- const struct ggml_tensor * b0 = dst->src[1]; // F16 fc_w
16067
- const struct ggml_tensor * b1 = dst->src[2]; // F32 fc_b
16068
- const struct ggml_tensor * c0 = dst->src[3]; // F16 proj_w
16069
- const struct ggml_tensor * c1 = dst->src[4]; // F32 proj_b
16070
-
16071
- int64_t t0 = ggml_perf_time_us();
16072
- UNUSED(t0);
16073
-
16074
- GGML_TENSOR_LOCALS(int64_t, nea, a, ne)
16075
- GGML_TENSOR_LOCALS(size_t, nba, a, nb)
16076
- GGML_TENSOR_LOCALS(int64_t, neb0, b0, ne)
16077
- GGML_TENSOR_LOCALS(size_t, nbb0, b0, nb)
16078
- GGML_TENSOR_LOCALS(int64_t, neb1, b1, ne)
16079
- GGML_TENSOR_LOCALS(size_t, nbb1, b1, nb)
16080
- GGML_TENSOR_LOCALS(int64_t, nec0, c0, ne)
16081
- GGML_TENSOR_LOCALS(size_t, nbc0, c0, nb)
16082
- GGML_TENSOR_LOCALS(int64_t, nec1, c1, ne)
16083
- GGML_TENSOR_LOCALS(size_t, nbc1, c1, nb)
16084
- GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
16085
- GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
16086
-
16087
- const int ith = params->ith;
16088
- const int nth = params->nth;
16089
-
16090
- const int64_t D = nea0;
16091
- //const int64_t N = nea1;
16092
- const int64_t M = neb01;
16093
-
16094
- GGML_ASSERT(ne0 == nea0);
16095
- GGML_ASSERT(ne1 == nea1);
16096
- GGML_ASSERT(ne2 == nea2);
16097
-
16098
- GGML_ASSERT(nba0 == sizeof(ggml_fp16_t));
16099
- GGML_ASSERT(nbb00 == sizeof(ggml_fp16_t));
16100
- GGML_ASSERT(nbb10 == sizeof(float));
16101
- GGML_ASSERT(nbc00 == sizeof(ggml_fp16_t));
16102
- GGML_ASSERT(nbc10 == sizeof(float));
16103
-
16104
- GGML_ASSERT(neb00 == D);
16105
- GGML_ASSERT(neb01 == M);
16106
- GGML_ASSERT(neb10 == M);
16107
- GGML_ASSERT(neb11 == 1);
16108
-
16109
- GGML_ASSERT(nec00 == M);
16110
- GGML_ASSERT(nec01 == D);
16111
- GGML_ASSERT(nec10 == D);
16112
- GGML_ASSERT(nec11 == 1);
16113
-
16114
- // dst cannot be transposed or permuted
16115
- GGML_ASSERT(nb0 == sizeof(float));
16116
- GGML_ASSERT(nb0 <= nb1);
16117
- GGML_ASSERT(nb1 <= nb2);
16118
- GGML_ASSERT(nb2 <= nb3);
16119
-
16120
- if (params->type == GGML_TASK_TYPE_INIT) {
16121
- return;
16122
- }
16123
-
16124
- if (params->type == GGML_TASK_TYPE_FINALIZE) {
16125
- return;
16126
- }
16127
-
16128
- // parallelize by a rows using ggml_vec_dot_f32
16129
-
16130
- // total rows in a
16131
- const int nr = nea1*nea2*nea3;
16132
-
16133
- // rows per thread
16134
- const int dr = (nr + nth - 1)/nth;
16135
-
16136
- // row range for this thread
16137
- const int ir0 = dr*ith;
16138
- const int ir1 = MIN(ir0 + dr, nr);
16139
-
16140
- for (int ir = ir0; ir < ir1; ++ir) {
16141
- // a indices
16142
- const int ia3 = ir/(nea2*nea1);
16143
- const int ia2 = (ir - ia3*nea2*nea1)/nea1;
16144
- const int ia1 = (ir - ia3*nea2*nea1 - ia2*nea1);
16145
-
16146
- float * S = (float *) params->wdata + ith*(2*M + CACHE_LINE_SIZE_F32);
16147
-
16148
- for (int64_t ic = 0; ic < neb01; ++ic) {
16149
- // b0 indices
16150
- const int ib03 = ia3;
16151
- const int ib02 = ia2;
16152
- const int ib01 = ic;
16153
-
16154
- // S indices
16155
- const int i1 = ib01;
16156
-
16157
- ggml_vec_dot_f16(nea0,
16158
- S + i1, 0,
16159
- (ggml_fp16_t *) ((char *) b0->data + (ib01*nbb01 + ib02*nbb02 + ib03*nbb03)), 0,
16160
- (ggml_fp16_t *) ((char *) a->data + ( ia1*nba1 + ia2*nba2 + ia3*nba3)), 0, 1);
16161
- }
16162
-
16163
- ggml_vec_add_f32(neb01, S, S, (float *) b1->data);
16164
- //ggml_vec_gelu_f32(neb01, S, S);
16165
-
16166
- ggml_fp16_t * S16 = (ggml_fp16_t *) ((float *) params->wdata + ith*(2*M + CACHE_LINE_SIZE_F32) + M);
16167
-
16168
- for (int64_t i = 0; i < M; i++) {
16169
- S16[i] = GGML_FP32_TO_FP16(S[i]);
16170
- }
16171
-
16172
- ggml_vec_gelu_f16(neb01, S16, S16);
16173
-
16174
- {
16175
- // dst indices
16176
- const int i1 = ia1;
16177
- const int i2 = ia2;
16178
- const int i3 = ia3;
16179
-
16180
- for (int64_t ic = 0; ic < nec01; ++ic) {
16181
-
16182
- ggml_vec_dot_f16(neb01,
16183
- (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), 0,
16184
- (ggml_fp16_t *) ((char *) c0->data + ( ic*nbc01 + i2*nbc02 + i3*nbc03)), 0,
16185
- S16, 0, 1);
16186
- }
16187
-
16188
- ggml_vec_add_f32(nec01,
16189
- (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3)),
16190
- (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3)),
16191
- (float *) c1->data);
16192
- }
16193
- }
16194
- }
16195
-
16196
- static void ggml_compute_forward_flash_ff(
16197
- const struct ggml_compute_params * params,
16198
- struct ggml_tensor * dst) {
16199
-
16200
- const struct ggml_tensor * b0 = dst->src[1];
16201
-
16202
- switch (b0->type) {
16203
- case GGML_TYPE_F16:
16204
- {
16205
- ggml_compute_forward_flash_ff_f16(params, dst);
16206
- } break;
16207
- case GGML_TYPE_F32:
16208
- {
16209
- GGML_ASSERT(false); // TODO
16210
- } break;
16211
- default:
16212
- {
16213
- GGML_ASSERT(false);
16214
- } break;
16215
- }
16216
- }
16217
-
16218
15905
  // ggml_compute_forward_flash_attn_back
16219
15906
 
16220
15907
  static void ggml_compute_forward_flash_attn_back_f32(
@@ -17785,21 +17472,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
17785
17472
  {
17786
17473
  ggml_compute_forward_leaky_relu(params, tensor);
17787
17474
  } break;
17788
- case GGML_OP_FLASH_ATTN:
17789
- {
17790
- const int32_t t = ggml_get_op_params_i32(tensor, 0);
17791
- GGML_ASSERT(t == 0 || t == 1);
17792
- const bool masked = t != 0;
17793
- ggml_compute_forward_flash_attn(params, masked, tensor);
17794
- } break;
17795
17475
  case GGML_OP_FLASH_ATTN_EXT:
17796
17476
  {
17797
17477
  ggml_compute_forward_flash_attn_ext(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor);
17798
17478
  } break;
17799
- case GGML_OP_FLASH_FF:
17800
- {
17801
- ggml_compute_forward_flash_ff(params, tensor);
17802
- } break;
17803
17479
  case GGML_OP_FLASH_ATTN_BACK:
17804
17480
  {
17805
17481
  int32_t t = ggml_get_op_params_i32(tensor, 0);
@@ -18169,6 +17845,7 @@ static struct ggml_tensor * ggml_sub_or_set(struct ggml_context * ctx, struct gg
18169
17845
  static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor * tensor, struct ggml_hash_set zero_table) {
18170
17846
  struct ggml_tensor * src0 = tensor->src[0];
18171
17847
  struct ggml_tensor * src1 = tensor->src[1];
17848
+ struct ggml_tensor * src2 = tensor->src[2];
18172
17849
 
18173
17850
  switch (tensor->op) {
18174
17851
  case GGML_OP_DUP:
@@ -18700,6 +18377,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
18700
18377
  ggml_rope_back(ctx,
18701
18378
  tensor->grad,
18702
18379
  src1,
18380
+ src2,
18703
18381
  n_dims,
18704
18382
  mode,
18705
18383
  n_ctx,
@@ -18739,6 +18417,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
18739
18417
  ggml_rope_impl(ctx,
18740
18418
  tensor->grad,
18741
18419
  src1,
18420
+ src2,
18742
18421
  n_dims,
18743
18422
  mode,
18744
18423
  n_ctx,
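The three rope hunks above forward tensor->src[2] into ggml_rope_back and ggml_rope_impl. In this release that extra source appears to carry the optional per-dimension frequency-factor tensor (NULL when unused); conceptually each factor divides the rotation angle of its dimension pair, roughly as in this illustrative sketch (the function name and plain-array interface are assumptions, not the ggml internals):

    #include <math.h>

    // Rough sketch of per-pair frequency factors applied to RoPE angles.
    // A NULL freq_factors pointer means "no scaling", matching the optional src2.
    static void rope_thetas(float * theta, int pos, int n_dims, float freq_base,
                            const float * freq_factors /* optional, n_dims/2 */) {
        const float theta_scale = powf(freq_base, -2.0f/n_dims);
        float t = (float) pos;
        for (int i = 0; i < n_dims/2; ++i) {
            const float ff = freq_factors ? freq_factors[i] : 1.0f;
            theta[i] = t/ff;   // a factor > 1 slows the rotation for this pair
            t *= theta_scale;
        }
    }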
@@ -18803,7 +18482,6 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
18803
18482
  {
18804
18483
  GGML_ASSERT(false); // TODO: not implemented
18805
18484
  } break;
18806
- case GGML_OP_FLASH_ATTN:
18807
18485
  case GGML_OP_FLASH_ATTN_EXT:
18808
18486
  {
18809
18487
  struct ggml_tensor * flash_grad = NULL;
@@ -18820,7 +18498,6 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
18820
18498
  masked);
18821
18499
  }
18822
18500
 
18823
- struct ggml_tensor * src2 = tensor->src[2];
18824
18501
  const int64_t elem_q = ggml_nelements(src0);
18825
18502
  const int64_t elem_k = ggml_nelements(src1);
18826
18503
  const int64_t elem_v = ggml_nelements(src2);
@@ -18858,10 +18535,6 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
18858
18535
  zero_table);
18859
18536
  }
18860
18537
  } break;
18861
- case GGML_OP_FLASH_FF:
18862
- {
18863
- GGML_ASSERT(false); // not supported
18864
- } break;
18865
18538
  case GGML_OP_FLASH_ATTN_BACK:
18866
18539
  {
18867
18540
  GGML_ASSERT(false); // not supported
@@ -19548,15 +19221,10 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads, int n_cur_
19548
19221
  {
19549
19222
  n_tasks = n_threads;
19550
19223
  } break;
19551
- case GGML_OP_FLASH_ATTN:
19552
19224
  case GGML_OP_FLASH_ATTN_EXT:
19553
19225
  {
19554
19226
  n_tasks = n_threads;
19555
19227
  } break;
19556
- case GGML_OP_FLASH_FF:
19557
- {
19558
- n_tasks = n_threads;
19559
- } break;
19560
19228
  case GGML_OP_FLASH_ATTN_BACK:
19561
19229
  {
19562
19230
  n_tasks = n_threads;
@@ -19953,39 +19621,11 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
19953
19621
  cur += sizeof(ggml_fp16_t)*ne00*ne01*ne02*ne03;
19954
19622
  cur += sizeof(ggml_fp16_t)*ne10*ne11*ne12;
19955
19623
  } break;
19956
- case GGML_OP_FLASH_ATTN:
19957
- {
19958
- const int64_t ne11 = ggml_up(node->src[1]->ne[1], GGML_SOFT_MAX_UNROLL);
19959
-
19960
- if (node->src[1]->type == GGML_TYPE_F32) {
19961
- cur = sizeof(float)*ne11*n_tasks; // TODO: this can become (n_tasks-1)
19962
- cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2
19963
- } else if (node->src[1]->type == GGML_TYPE_F16) {
19964
- cur = sizeof(float)*ne11*n_tasks; // TODO: this can become (n_tasks-1)
19965
- cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2
19966
- } else if (node->src[1]->type == GGML_TYPE_BF16) {
19967
- cur = sizeof(float)*ne11*n_tasks; // TODO: this can become (n_tasks-1)
19968
- cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2
19969
- }
19970
- } break;
19971
19624
  case GGML_OP_FLASH_ATTN_EXT:
19972
19625
  {
19973
19626
  const int64_t ne00 = node->src[0]->ne[0]; // D
19974
19627
 
19975
- cur = 2*sizeof(float)*ne00*n_tasks; // 2x head size
19976
- } break;
19977
- case GGML_OP_FLASH_FF:
19978
- {
19979
- if (node->src[1]->type == GGML_TYPE_F32) {
19980
- cur = sizeof(float)*node->src[1]->ne[1]*n_tasks; // TODO: this can become (n_tasks-1)
19981
- cur += sizeof(float)*node->src[1]->ne[1]*n_tasks; // this is overestimated by x2
19982
- } else if (node->src[1]->type == GGML_TYPE_F16) {
19983
- cur = sizeof(float)*node->src[1]->ne[1]*n_tasks; // TODO: this can become (n_tasks-1)
19984
- cur += sizeof(float)*node->src[1]->ne[1]*n_tasks; // this is overestimated by x2
19985
- } else if (node->src[1]->type == GGML_TYPE_BF16) {
19986
- cur = sizeof(float)*node->src[1]->ne[1]*n_tasks; // TODO: this can become (n_tasks-1)
19987
- cur += sizeof(float)*node->src[1]->ne[1]*n_tasks; // this is overestimated by x2
19988
- }
19628
+ cur = 3*sizeof(float)*ne00*n_tasks; // 3x head size/thread
19989
19629
  } break;
19990
19630
  case GGML_OP_FLASH_ATTN_BACK:
19991
19631
  {
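With GGML_OP_FLASH_ATTN and GGML_OP_FLASH_FF gone, the planner hunk above only sizes GGML_OP_FLASH_ATTN_EXT, and it grows the reservation from two to three head-size float regions per task to match the VKQ32 / V32-VKQ16 / Q_q layout carved in the kernel. A back-of-the-envelope sketch (the helper name and example numbers are illustrative only):

    #include <stddef.h>
    #include <stdint.h>

    // Three float regions of head size D per task:
    // FP32 accumulator, V staging row, converted Q row.
    static size_t flash_attn_ext_wsize(int64_t D /* ne00, head size */, int n_tasks) {
        return 3*sizeof(float)*(size_t)D*(size_t)n_tasks;
    }
    // e.g. D = 128, n_tasks = 8  ->  3*4*128*8 = 12288 bytes of scratch

The extra cache-line stride in the kernel's ith*(3*D + CACHE_LINE_SIZE_F32) indexing is not counted here; it relies on the cache-line slack the plan adds separately when the total work size is finalized.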
@@ -21827,11 +21467,7 @@ size_t ggml_quantize_chunk(
21827
21467
  case GGML_TYPE_IQ1_S: result = quantize_iq1_s (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
21828
21468
  case GGML_TYPE_IQ1_M: result = quantize_iq1_m (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
21829
21469
  case GGML_TYPE_IQ4_NL: result = quantize_iq4_nl (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
21830
- #if QK_K == 64
21831
- case GGML_TYPE_IQ4_XS: result = quantize_iq4_nl (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
21832
- #else
21833
21470
  case GGML_TYPE_IQ4_XS: result = quantize_iq4_xs (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
21834
- #endif
21835
21471
  case GGML_TYPE_F16:
21836
21472
  {
21837
21473
  size_t elemsize = sizeof(ggml_fp16_t);
@@ -23108,6 +22744,14 @@ int ggml_cpu_has_avx512_vnni(void) {
23108
22744
  #endif
23109
22745
  }
23110
22746
 
22747
+ int ggml_cpu_has_avx512_bf16(void) {
22748
+ #if defined(__AVX512BF16__)
22749
+ return 1;
22750
+ #else
22751
+ return 0;
22752
+ #endif
22753
+ }
22754
+
23111
22755
  int ggml_cpu_has_fma(void) {
23112
22756
  #if defined(__FMA__)
23113
22757
  return 1;
@@ -23124,6 +22768,16 @@ int ggml_cpu_has_neon(void) {
23124
22768
  #endif
23125
22769
  }
23126
22770
 
22771
+ int ggml_cpu_has_sve(void) {
22772
+ #if defined(__ARM_FEATURE_SVE)
22773
+ // TODO: Currently, SVE 256 bit is only supported.
22774
+ GGML_ASSERT(svcntb() == QK8_0);
22775
+ return 1;
22776
+ #else
22777
+ return 0;
22778
+ #endif
22779
+ }
22780
+
23127
22781
  int ggml_cpu_has_arm_fma(void) {
23128
22782
  #if defined(__ARM_FEATURE_FMA)
23129
22783
  return 1;
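The new probes above, ggml_cpu_has_avx512_bf16() and ggml_cpu_has_sve() (plus ggml_cpu_has_rpc() in the next hunk), follow the existing pattern: a compile-time feature macro surfaced as an argument-free int function. A frontend can log them alongside the older probes; a minimal sketch (the prototypes are copied from the hunks, the printing helper itself is illustrative):

    #include <stdio.h>

    // Prototypes as declared in the surrounding hunks.
    int ggml_cpu_has_avx512_bf16(void);
    int ggml_cpu_has_sve(void);
    int ggml_cpu_has_rpc(void);

    static void print_new_cpu_features(void) {
        printf("AVX512_BF16 = %d | SVE = %d | RPC = %d\n",
               ggml_cpu_has_avx512_bf16(), ggml_cpu_has_sve(), ggml_cpu_has_rpc());
    }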
@@ -23212,6 +22866,14 @@ int ggml_cpu_has_sycl(void) {
23212
22866
  #endif
23213
22867
  }
23214
22868
 
22869
+ int ggml_cpu_has_rpc(void) {
22870
+ #if defined(GGML_USE_RPC)
22871
+ return 1;
22872
+ #else
22873
+ return 0;
22874
+ #endif
22875
+ }
22876
+
23215
22877
  int ggml_cpu_has_gpublas(void) {
23216
22878
  return ggml_cpu_has_cuda() || ggml_cpu_has_clblast() || ggml_cpu_has_vulkan() || ggml_cpu_has_kompute() ||
23217
22879
  ggml_cpu_has_sycl();