llama_cpp 0.15.2 → 0.15.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -60,6 +60,9 @@
60
60
 
61
61
  typedef volatile LONG atomic_int;
62
62
  typedef atomic_int atomic_bool;
63
+ typedef atomic_int atomic_flag;
64
+
65
+ #define ATOMIC_FLAG_INIT 0
63
66
 
64
67
  static void atomic_store(atomic_int * ptr, LONG val) {
65
68
  InterlockedExchange(ptr, val);
@@ -73,6 +76,12 @@ static LONG atomic_fetch_add(atomic_int * ptr, LONG inc) {
73
76
  static LONG atomic_fetch_sub(atomic_int * ptr, LONG dec) {
74
77
  return atomic_fetch_add(ptr, -(dec));
75
78
  }
79
+ static atomic_bool atomic_flag_test_and_set(atomic_flag * ptr) {
80
+ return InterlockedExchange(ptr, 1);
81
+ }
82
+ static void atomic_flag_clear(atomic_flag * ptr) {
83
+ InterlockedExchange(ptr, 0);
84
+ }
76
85
 
77
86
  typedef HANDLE pthread_t;
78
87
 
@@ -406,10 +415,10 @@ void ggml_fp32_to_bf16_row(const float * x, ggml_bf16_t * y, int64_t n) {
406
415
  int i = 0;
407
416
  #if defined(__AVX512BF16__)
408
417
  for (; i + 32 <= n; i += 32) {
409
- _mm512_storeu_ps(
410
- (__m512 *)(y + i),
411
- (__m512)_mm512_cvtne2ps_pbh(_mm512_loadu_ps(x + i + 16),
412
- _mm512_loadu_ps(x + i)));
418
+ _mm512_storeu_si512(
419
+ (__m512i *)(y + i),
420
+ m512i(_mm512_cvtne2ps_pbh(_mm512_loadu_ps(x + i + 16),
421
+ _mm512_loadu_ps(x + i))));
413
422
  }
414
423
  #endif
415
424
  for (; i < n; i++) {
@@ -871,22 +880,14 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
871
880
  },
872
881
  [GGML_TYPE_IQ4_XS] = {
873
882
  .type_name = "iq4_xs",
874
- #if QK_K == 64
875
- .blck_size = QK4_NL,
876
- #else
877
883
  .blck_size = QK_K,
878
- #endif
879
884
  .type_size = sizeof(block_iq4_xs),
880
885
  .is_quantized = true,
881
886
  .to_float = (ggml_to_float_t) dequantize_row_iq4_xs,
882
887
  .from_float = quantize_row_iq4_xs,
883
888
  .from_float_reference = (ggml_from_float_t)quantize_row_iq4_xs_reference,
884
889
  .vec_dot = ggml_vec_dot_iq4_xs_q8_K,
885
- #if QK_K == 64
886
- .vec_dot_type = GGML_TYPE_Q8_0,
887
- #else
888
890
  .vec_dot_type = GGML_TYPE_Q8_K,
889
- #endif
890
891
  .nrows = 1,
891
892
  },
892
893
  [GGML_TYPE_Q8_K] = {
@@ -1523,6 +1524,196 @@ static inline void __sse_f16x4_store(ggml_fp16_t *x, __m128 y) {
1523
1524
  #define GGML_F16_VEC_MUL GGML_F32Cx4_MUL
1524
1525
  #define GGML_F16_VEC_REDUCE GGML_F32Cx4_REDUCE
1525
1526
 
1527
+ #elif defined(__loongarch_asx)
1528
+
1529
+ #define GGML_SIMD
1530
+
1531
+ // F32 LASX
1532
+ #define GGML_F32_STEP 32
1533
+ #define GGML_F32_EPR 8
1534
+
1535
+ #define GGML_F32x8 __m256
1536
+ #define GGML_F32x8_ZERO (__m256)__lasx_xvldi(0)
1537
+ #define GGML_F32x8_SET1(x) (__m256)__lasx_xvreplfr2vr_s((x))
1538
+ #define GGML_F32x8_LOAD(x) (__m256)__lasx_xvld((x), 0)
1539
+ #define GGML_F32x8_STORE(x,y) __lasx_xvst((y), (x), 0)
1540
+ #define GGML_F32x8_FMA(a, b, c) __lasx_xvfmadd_s(b, c, a)
1541
+ #define GGML_F32x8_ADD __lasx_xvfadd_s
1542
+ #define GGML_F32x8_MUL __lasx_xvfmul_s
1543
+ #define GGML_F32x8_REDUCE(res, x) \
1544
+ do { \
1545
+ int offset = GGML_F32_ARR >> 1; \
1546
+ for (int i = 0; i < offset; ++i) { \
1547
+ x[i] = __lasx_xvfadd_s(x[i], x[offset+i]); \
1548
+ } \
1549
+ offset >>= 1; \
1550
+ for (int i = 0; i < offset; ++i) { \
1551
+ x[i] = __lasx_xvfadd_s(x[i], x[offset+i]); \
1552
+ } \
1553
+ offset >>= 1; \
1554
+ for (int i = 0; i < offset; ++i) { \
1555
+ x[i] = __lasx_xvfadd_s(x[i], x[offset+i]); \
1556
+ } \
1557
+ float *tmp_p = (float *)&x[0]; \
1558
+ res = tmp_p[0] + tmp_p[1] + tmp_p[2] + tmp_p[3] + tmp_p[4] + tmp_p[5] + tmp_p[6] + tmp_p[7]; \
1559
+ } while (0)
1560
+ // TODO: is this optimal ?
1561
+
1562
+ #define GGML_F32_VEC GGML_F32x8
1563
+ #define GGML_F32_VEC_ZERO GGML_F32x8_ZERO
1564
+ #define GGML_F32_VEC_SET1 GGML_F32x8_SET1
1565
+ #define GGML_F32_VEC_LOAD GGML_F32x8_LOAD
1566
+ #define GGML_F32_VEC_STORE GGML_F32x8_STORE
1567
+ #define GGML_F32_VEC_FMA GGML_F32x8_FMA
1568
+ #define GGML_F32_VEC_ADD GGML_F32x8_ADD
1569
+ #define GGML_F32_VEC_MUL GGML_F32x8_MUL
1570
+ #define GGML_F32_VEC_REDUCE GGML_F32x8_REDUCE
1571
+
1572
+ // F16 LASX
1573
+
1574
+ #define GGML_F16_STEP 32
1575
+ #define GGML_F16_EPR 8
1576
+
1577
+ // F16 arithmetic is not supported by AVX, so we use F32 instead
1578
+
1579
+ #define GGML_F32Cx8 __m256
1580
+ #define GGML_F32Cx8_ZERO (__m256)__lasx_xvldi(0)
1581
+ #define GGML_F32Cx8_SET1(x) (__m256)__lasx_xvreplgr2vr_w((x))
1582
+
1583
+ static inline __m256 __lasx_f32cx8_load(const ggml_fp16_t * x) {
1584
+ float tmp[8];
1585
+
1586
+ for (int i = 0; i < 8; i++) {
1587
+ tmp[i] = GGML_FP16_TO_FP32(x[i]);
1588
+ }
1589
+
1590
+ return (__m256)__lasx_xvld(tmp, 0);
1591
+ }
1592
+ static inline void __lasx_f32cx8_store(ggml_fp16_t * x, __m256 y) {
1593
+ float arr[8];
1594
+
1595
+ __lasx_xvst(y, arr, 0);
1596
+
1597
+ for (int i = 0; i < 8; i++) {
1598
+ x[i] = GGML_FP32_TO_FP16(arr[i]);
1599
+ }
1600
+ }
1601
+ #define GGML_F32Cx8_LOAD(x) __lasx_f32cx8_load(x)
1602
+ #define GGML_F32Cx8_STORE(x, y) __lasx_f32cx8_store(x, y)
1603
+
1604
+ #define GGML_F32Cx8_FMA GGML_F32x8_FMA
1605
+ #define GGML_F32Cx8_ADD __lasx_xvfadd_s
1606
+ #define GGML_F32Cx8_MUL __lasx_xvfmul_s
1607
+ #define GGML_F32Cx8_REDUCE GGML_F32x8_REDUCE
1608
+
1609
+ #define GGML_F16_VEC GGML_F32Cx8
1610
+ #define GGML_F16_VEC_ZERO GGML_F32Cx8_ZERO
1611
+ #define GGML_F16_VEC_SET1 GGML_F32Cx8_SET1
1612
+ #define GGML_F16_VEC_LOAD(p, i) GGML_F32Cx8_LOAD(p)
1613
+ #define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx8_STORE(p, r[i])
1614
+ #define GGML_F16_VEC_FMA GGML_F32Cx8_FMA
1615
+ #define GGML_F16_VEC_ADD GGML_F32Cx8_ADD
1616
+ #define GGML_F16_VEC_MUL GGML_F32Cx8_MUL
1617
+ #define GGML_F16_VEC_REDUCE GGML_F32Cx8_REDUCE
1618
+
1619
+ #elif defined(__loongarch_sx)
1620
+
1621
+ #define GGML_SIMD
1622
+
1623
+ // F32 LSX
1624
+
1625
+ #define GGML_F32_STEP 32
1626
+ #define GGML_F32_EPR 4
1627
+
1628
+ #define GGML_F32x4 __m128
1629
+ #define GGML_F32x4_ZERO __lsx_vldi(0)
1630
+ #define GGML_F32x4_SET1(x) __lsx_vinsgr2vr_w(__lsx_vldi(0),(x), 0)
1631
+ #define GGML_F32x4_LOAD(x) __lsx_vld((x), 0)
1632
+ #define GGML_F32x4_STORE((x),(y)) __lsx_vst((y), (x), 0)
1633
+ #define GGML_F32x4_FMA(a, b, c) __lsx_vfmadd_s(b, c, a)
1634
+ #define GGML_F32x4_ADD __lsx_vfadd_s
1635
+ #define GGML_F32x4_MUL __lsx_vfmul_s
1636
+ #define GGML_F32x4_REDUCE(res, x) \
1637
+ { \
1638
+ int offset = GGML_F32_ARR >> 1; \
1639
+ for (int i = 0; i < offset; ++i) { \
1640
+ x[i] = __lsx_vfadd_s(x[i], x[offset+i]); \
1641
+ } \
1642
+ offset >>= 1; \
1643
+ for (int i = 0; i < offset; ++i) { \
1644
+ x[i] = __lsx_vfadd_s(x[i], x[offset+i]); \
1645
+ } \
1646
+ offset >>= 1; \
1647
+ for (int i = 0; i < offset; ++i) { \
1648
+ x[i] = __lsx_vfadd_s(x[i], x[offset+i]); \
1649
+ } \
1650
+ __m128i tmp = __lsx_vsrli_d((__m128i)x[0], 32); \
1651
+ tmp = (__m128i)__lsx_vfadd_s((__m128)tmp, x[0]); \
1652
+ tmp = __lsx_vpickev_w(__lsx_vldi(0), tmp); \
1653
+ const __m128 t0 = __lsx_vshuf4i_w(tmp, 0x88); \
1654
+ tmp = __lsx_vsrli_d((__m128i)t0, 32); \
1655
+ tmp = (__m128i)__lsx_vfadd_s((__m128)tmp, t0); \
1656
+ tmp = __lsx_vpickev_w(__lsx_vldi(0), tmp); \
1657
+ res = (ggml_float) __lsx_vpickve2gr_w(__lsx_vshuf4i_w(tmp, 0x88), 0); \
1658
+ }
1659
+
1660
+ #define GGML_F32_VEC GGML_F32x4
1661
+ #define GGML_F32_VEC_ZERO GGML_F32x4_ZERO
1662
+ #define GGML_F32_VEC_SET1 GGML_F32x4_SET1
1663
+ #define GGML_F32_VEC_LOAD GGML_F32x4_LOAD
1664
+ #define GGML_F32_VEC_STORE GGML_F32x4_STORE
1665
+ #define GGML_F32_VEC_FMA GGML_F32x4_FMA
1666
+ #define GGML_F32_VEC_ADD GGML_F32x4_ADD
1667
+ #define GGML_F32_VEC_MUL GGML_F32x4_MUL
1668
+ #define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE
1669
+
1670
+ // F16 LSX
1671
+
1672
+ #define GGML_F16_STEP 32
1673
+ #define GGML_F16_EPR 4
1674
+
1675
+ static inline __m128 __lsx_f16x4_load(const ggml_fp16_t * x) {
1676
+ float tmp[4];
1677
+
1678
+ tmp[0] = GGML_FP16_TO_FP32(x[0]);
1679
+ tmp[1] = GGML_FP16_TO_FP32(x[1]);
1680
+ tmp[2] = GGML_FP16_TO_FP32(x[2]);
1681
+ tmp[3] = GGML_FP16_TO_FP32(x[3]);
1682
+
1683
+ return __lsx_vld(tmp, 0);
1684
+ }
1685
+
1686
+ static inline void __lsx_f16x4_store(ggml_fp16_t * x, __m128 y) {
1687
+ float arr[4];
1688
+
1689
+ __lsx_vst(y, arr, 0);
1690
+
1691
+ x[0] = GGML_FP32_TO_FP16(arr[0]);
1692
+ x[1] = GGML_FP32_TO_FP16(arr[1]);
1693
+ x[2] = GGML_FP32_TO_FP16(arr[2]);
1694
+ x[3] = GGML_FP32_TO_FP16(arr[3]);
1695
+ }
1696
+
1697
+ #define GGML_F32Cx4 __m128
1698
+ #define GGML_F32Cx4_ZERO __lsx_vldi(0)
1699
+ #define GGML_F32Cx4_SET1(x) __lsx_vinsgr2vr_w(__lsx_vldi(0),(x), 0)
1700
+ #define GGML_F32Cx4_LOAD(x) __lsx_f16x4_load(x)
1701
+ #define GGML_F32Cx4_STORE(x, y) __lsx_f16x4_store(x, y)
1702
+ #define GGML_F32Cx4_FMA GGML_F32x4_FMA
1703
+ #define GGML_F32Cx4_ADD __lsx_vfadd_s
1704
+ #define GGML_F32Cx4_MUL __lsx_vfmul_s
1705
+ #define GGML_F32Cx4_REDUCE GGML_F32x4_REDUCE
1706
+
1707
+ #define GGML_F16_VEC GGML_F32Cx4
1708
+ #define GGML_F16_VEC_ZERO GGML_F32Cx4_ZERO
1709
+ #define GGML_F16_VEC_SET1 GGML_F32Cx4_SET1
1710
+ #define GGML_F16_VEC_LOAD(p, i) GGML_F32Cx4_LOAD(p)
1711
+ #define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx4_STORE(p, r[i])
1712
+ #define GGML_F16_VEC_FMA GGML_F32Cx4_FMA
1713
+ #define GGML_F16_VEC_ADD GGML_F32Cx4_ADD
1714
+ #define GGML_F16_VEC_MUL GGML_F32Cx4_MUL
1715
+ #define GGML_F16_VEC_REDUCE GGML_F32Cx4_REDUCE
1716
+
1526
1717
  #endif
1527
1718
 
1528
1719
  // GGML_F32_ARR / GGML_F16_ARR
@@ -1666,10 +1857,10 @@ static void ggml_vec_dot_bf16(int n, float * restrict s, size_t bs, ggml_bf16_t
1666
1857
  __m512 c1 = _mm512_setzero_ps();
1667
1858
  __m512 c2 = _mm512_setzero_ps();
1668
1859
  for (; i + 64 <= n; i += 64) {
1669
- c1 = _mm512_dpbf16_ps(c1, (__m512bh)_mm512_loadu_ps((const float *)(x + i)),
1670
- (__m512bh)_mm512_loadu_ps((const float *)(y + i)));
1671
- c2 = _mm512_dpbf16_ps(c2, (__m512bh)_mm512_loadu_ps((const float *)(x + i + 32)),
1672
- (__m512bh)_mm512_loadu_ps((const float *)(y + i + 32)));
1860
+ c1 = _mm512_dpbf16_ps(c1, m512bh(_mm512_loadu_si512((x + i))),
1861
+ m512bh(_mm512_loadu_si512((y + i))));
1862
+ c2 = _mm512_dpbf16_ps(c2, m512bh(_mm512_loadu_si512((x + i + 32))),
1863
+ m512bh(_mm512_loadu_si512((y + i + 32))));
1673
1864
  }
1674
1865
  sumf += (ggml_float)_mm512_reduce_add_ps(c1);
1675
1866
  sumf += (ggml_float)_mm512_reduce_add_ps(c2);
@@ -2076,7 +2267,7 @@ inline static float ggml_silu_f32(float x) {
2076
2267
  return x/(1.0f + expf(-x));
2077
2268
  }
2078
2269
 
2079
- #if defined(__ARM_NEON)
2270
+ #if defined(__ARM_NEON) && defined(__aarch64__)
2080
2271
 
2081
2272
  // adapted from arm limited optimized routine
2082
2273
  // the maximum error is 1.45358 plus 0.5 ulps
@@ -2125,32 +2316,27 @@ inline static __m512 ggml_v_expf(__m512 x) {
2125
2316
  const __m512 r = _mm512_set1_ps(0x1.8p23f);
2126
2317
  const __m512 z = _mm512_fmadd_ps(x, _mm512_set1_ps(0x1.715476p+0f), r);
2127
2318
  const __m512 n = _mm512_sub_ps(z, r);
2128
- const __m512 b = _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.7f7d1cp-20f),
2129
- _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.62e4p-1f), x));
2130
- const __m512i e = _mm512_slli_epi32(_mm512_castps_si512(z), 23);
2131
- const __m512 k = _mm512_castsi512_ps(_mm512_add_epi32(e, _mm512_castps_si512(_mm512_set1_ps(1))));
2132
- const __mmask16 c = _mm512_cmp_ps_mask(_mm512_abs_ps(n), _mm512_set1_ps(126), _CMP_GT_OQ);
2133
- const __m512 u = _mm512_mul_ps(b, b);
2134
- const __m512 j = _mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_set1_ps(0x1.0e4020p-7f), b,
2135
- _mm512_set1_ps(0x1.573e2ep-5f)), u,
2136
- _mm512_fmadd_ps(_mm512_set1_ps(0x1.555e66p-3f), b,
2137
- _mm512_set1_ps(0x1.fffdb6p-2f))),
2138
- u, _mm512_mul_ps(_mm512_set1_ps(0x1.ffffecp-1f), b));
2139
- if (_mm512_kortestz(c, c))
2140
- return _mm512_fmadd_ps(j, k, k);
2141
- const __m512i g = _mm512_and_si512(
2142
- _mm512_movm_epi32(_mm512_cmp_ps_mask(n, _mm512_setzero_ps(), _CMP_LE_OQ)),
2143
- _mm512_set1_epi32(0x82000000u));
2144
- const __m512 s1 =
2145
- _mm512_castsi512_ps(_mm512_add_epi32(g, _mm512_set1_epi32(0x7f000000u)));
2146
- const __m512 s2 = _mm512_castsi512_ps(_mm512_sub_epi32(e, g));
2319
+ const __m512 b =
2320
+ _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.7f7d1cp-20f),
2321
+ _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.62e4p-1f), x));
2147
2322
  const __mmask16 d =
2148
2323
  _mm512_cmp_ps_mask(_mm512_abs_ps(n), _mm512_set1_ps(192), _CMP_GT_OQ);
2149
- return _mm512_mask_blend_ps(
2150
- d, _mm512_mask_blend_ps(
2151
- c, _mm512_fmadd_ps(k, j, k),
2152
- _mm512_mul_ps(_mm512_fmadd_ps(s2, j, s2), s1)),
2153
- _mm512_mul_ps(s1, s1));
2324
+ const __m512 u = _mm512_mul_ps(b, b);
2325
+ const __m512 j = _mm512_fmadd_ps(
2326
+ _mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_set1_ps(0x1.0e4020p-7f), b,
2327
+ _mm512_set1_ps(0x1.573e2ep-5f)),
2328
+ u,
2329
+ _mm512_fmadd_ps(_mm512_set1_ps(0x1.555e66p-3f), b,
2330
+ _mm512_set1_ps(0x1.fffdb6p-2f))),
2331
+ u,
2332
+ _mm512_fmadd_ps(_mm512_set1_ps(0x1.ffffecp-1f), b, _mm512_set1_ps(1.0F)));
2333
+ const __m512 res = _mm512_scalef_ps(j, n);
2334
+ if (_mm512_kortestz(d, d))
2335
+ return res;
2336
+ const __m512 zero = _mm512_setzero_ps();
2337
+ const __m512 alt = _mm512_mask_blend_ps(
2338
+ _mm512_cmp_ps_mask(n, zero, _CMP_LE_OQ), _mm512_set1_ps(INFINITY), zero);
2339
+ return _mm512_mask_blend_ps(d, res, alt);
2154
2340
  }
2155
2341
 
2156
2342
  // computes silu x/(1+exp(-x)) in single precision vector
@@ -2288,7 +2474,7 @@ static void ggml_vec_silu_f32(const int n, float * y, const float * x) {
2288
2474
  for (; i + 3 < n; i += 4) {
2289
2475
  _mm_storeu_ps(y + i, ggml_v_silu(_mm_loadu_ps(x + i)));
2290
2476
  }
2291
- #elif defined(__ARM_NEON)
2477
+ #elif defined(__ARM_NEON) && defined(__aarch64__)
2292
2478
  for (; i + 3 < n; i += 4) {
2293
2479
  vst1q_f32(y + i, ggml_v_silu(vld1q_f32(x + i)));
2294
2480
  }
@@ -2335,7 +2521,7 @@ static ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x,
2335
2521
  #endif
2336
2522
  sum += (ggml_float)_mm_cvtss_f32(val);
2337
2523
  }
2338
- #elif defined(__ARM_NEON)
2524
+ #elif defined(__ARM_NEON) && defined(__aarch64__)
2339
2525
  for (; i + 3 < n; i += 4) {
2340
2526
  float32x4_t val = ggml_v_expf(vsubq_f32(vld1q_f32(x + i),
2341
2527
  vdupq_n_f32(max)));
@@ -2489,9 +2675,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
2489
2675
  "ARGSORT",
2490
2676
  "LEAKY_RELU",
2491
2677
 
2492
- "FLASH_ATTN",
2493
2678
  "FLASH_ATTN_EXT",
2494
- "FLASH_FF",
2495
2679
  "FLASH_ATTN_BACK",
2496
2680
  "SSM_CONV",
2497
2681
  "SSM_SCAN",
@@ -2517,7 +2701,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
2517
2701
  "CROSS_ENTROPY_LOSS_BACK",
2518
2702
  };
2519
2703
 
2520
- static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76");
2704
+ static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74");
2521
2705
 
2522
2706
  static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
2523
2707
  "none",
@@ -2579,9 +2763,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
2579
2763
  "argsort(x)",
2580
2764
  "leaky_relu(x)",
2581
2765
 
2582
- "flash_attn(x)",
2583
2766
  "flash_attn_ext(x)",
2584
- "flash_ff(x)",
2585
2767
  "flash_attn_back(x)",
2586
2768
  "ssm_conv(x)",
2587
2769
  "ssm_scan(x)",
@@ -2607,7 +2789,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
2607
2789
  "cross_entropy_loss_back(x,y)",
2608
2790
  };
2609
2791
 
2610
- static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76");
2792
+ static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74");
2611
2793
 
2612
2794
  static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
2613
2795
 
@@ -2706,24 +2888,20 @@ struct ggml_state {
2706
2888
 
2707
2889
  // global state
2708
2890
  static struct ggml_state g_state;
2709
- static atomic_int g_state_barrier = 0;
2891
+ static atomic_flag g_state_critical = ATOMIC_FLAG_INIT;
2710
2892
 
2711
2893
  // barrier via spin lock
2712
2894
  inline static void ggml_critical_section_start(void) {
2713
- int processing = atomic_fetch_add(&g_state_barrier, 1);
2714
-
2715
- while (processing > 0) {
2716
- // wait for other threads to finish
2717
- atomic_fetch_sub(&g_state_barrier, 1);
2718
- sched_yield(); // TODO: reconsider this
2719
- processing = atomic_fetch_add(&g_state_barrier, 1);
2895
+ while (atomic_flag_test_and_set(&g_state_critical)) {
2896
+ // spin
2897
+ sched_yield();
2720
2898
  }
2721
2899
  }
2722
2900
 
2723
2901
  // TODO: make this somehow automatically executed
2724
2902
  // some sort of "sentry" mechanism
2725
2903
  inline static void ggml_critical_section_end(void) {
2726
- atomic_fetch_sub(&g_state_barrier, 1);
2904
+ atomic_flag_clear(&g_state_critical);
2727
2905
  }
2728
2906
 
2729
2907
  #if defined(__gnu_linux__)
@@ -3039,7 +3217,11 @@ GGML_CALL bool ggml_is_contiguous(const struct ggml_tensor * tensor) {
3039
3217
  tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
3040
3218
  }
3041
3219
 
3042
- static inline bool ggml_is_contiguous_except_dim_1(const struct ggml_tensor * tensor) {
3220
+ GGML_CALL bool ggml_is_contiguous_0(const struct ggml_tensor * tensor) {
3221
+ return ggml_is_contiguous(tensor);
3222
+ }
3223
+
3224
+ GGML_CALL bool ggml_is_contiguous_1(const struct ggml_tensor * tensor) {
3043
3225
  static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
3044
3226
 
3045
3227
  return
@@ -3048,6 +3230,14 @@ static inline bool ggml_is_contiguous_except_dim_1(const struct ggml_tensor * te
3048
3230
  tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
3049
3231
  }
3050
3232
 
3233
+ GGML_CALL bool ggml_is_contiguous_2(const struct ggml_tensor * tensor) {
3234
+ static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
3235
+
3236
+ return
3237
+ tensor->nb[0] == ggml_type_size(tensor->type) &&
3238
+ tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
3239
+ }
3240
+
3051
3241
  GGML_CALL bool ggml_is_permuted(const struct ggml_tensor * tensor) {
3052
3242
  static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
3053
3243
 
@@ -4705,10 +4895,21 @@ struct ggml_tensor * ggml_repeat_back(
4705
4895
  // ggml_concat
4706
4896
 
4707
4897
  struct ggml_tensor * ggml_concat(
4708
- struct ggml_context* ctx,
4709
- struct ggml_tensor* a,
4710
- struct ggml_tensor* b) {
4711
- GGML_ASSERT(a->ne[0] == b->ne[0] && a->ne[1] == b->ne[1] && a->ne[3] == b->ne[3]);
4898
+ struct ggml_context * ctx,
4899
+ struct ggml_tensor * a,
4900
+ struct ggml_tensor * b,
4901
+ int dim) {
4902
+ GGML_ASSERT(dim >= 0 && dim < GGML_MAX_DIMS);
4903
+
4904
+ int64_t ne[GGML_MAX_DIMS];
4905
+ for (int d = 0; d < GGML_MAX_DIMS; ++d) {
4906
+ if (d == dim) {
4907
+ ne[d] = a->ne[d] + b->ne[d];
4908
+ continue;
4909
+ }
4910
+ GGML_ASSERT(a->ne[d] == b->ne[d]);
4911
+ ne[d] = a->ne[d];
4912
+ }
4712
4913
 
4713
4914
  bool is_node = false;
4714
4915
 
@@ -4716,7 +4917,9 @@ struct ggml_tensor * ggml_concat(
4716
4917
  is_node = true;
4717
4918
  }
4718
4919
 
4719
- struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, a->ne[0], a->ne[1], a->ne[2] + b->ne[2], a->ne[3]);
4920
+ struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, GGML_MAX_DIMS, ne);
4921
+
4922
+ ggml_set_op_params_i32(result, 0, dim);
4720
4923
 
4721
4924
  result->op = GGML_OP_CONCAT;
4722
4925
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@@ -4836,6 +5039,7 @@ struct ggml_tensor * ggml_leaky_relu(
4836
5039
  }
4837
5040
 
4838
5041
  struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
5042
+
4839
5043
  ggml_set_op_params(result, &negative_slope, sizeof(negative_slope));
4840
5044
 
4841
5045
  result->op = GGML_OP_LEAKY_RELU;
@@ -6042,6 +6246,7 @@ static struct ggml_tensor * ggml_rope_impl(
6042
6246
  struct ggml_context * ctx,
6043
6247
  struct ggml_tensor * a,
6044
6248
  struct ggml_tensor * b,
6249
+ struct ggml_tensor * c,
6045
6250
  int n_dims,
6046
6251
  int mode,
6047
6252
  int n_ctx,
@@ -6055,10 +6260,17 @@ static struct ggml_tensor * ggml_rope_impl(
6055
6260
  float xpos_base,
6056
6261
  bool xpos_down,
6057
6262
  bool inplace) {
6263
+ GGML_ASSERT((mode & 1) == 0 && "mode & 1 == 1 is no longer supported");
6264
+
6058
6265
  GGML_ASSERT(ggml_is_vector(b));
6059
6266
  GGML_ASSERT(b->type == GGML_TYPE_I32);
6060
6267
  GGML_ASSERT(a->ne[2] == b->ne[0]);
6061
6268
 
6269
+ if (c) {
6270
+ GGML_ASSERT(c->type == GGML_TYPE_F32);
6271
+ GGML_ASSERT(c->ne[0] >= n_dims / 2);
6272
+ }
6273
+
6062
6274
  bool is_node = false;
6063
6275
 
6064
6276
  if (a->grad) {
@@ -6082,6 +6294,7 @@ static struct ggml_tensor * ggml_rope_impl(
6082
6294
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6083
6295
  result->src[0] = a;
6084
6296
  result->src[1] = b;
6297
+ result->src[2] = c;
6085
6298
 
6086
6299
  return result;
6087
6300
  }
@@ -6094,7 +6307,7 @@ struct ggml_tensor * ggml_rope(
6094
6307
  int mode,
6095
6308
  int n_ctx) {
6096
6309
  return ggml_rope_impl(
6097
- ctx, a, b, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, false
6310
+ ctx, a, b, NULL, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, false
6098
6311
  );
6099
6312
  }
6100
6313
 
@@ -6106,7 +6319,49 @@ struct ggml_tensor * ggml_rope_inplace(
6106
6319
  int mode,
6107
6320
  int n_ctx) {
6108
6321
  return ggml_rope_impl(
6109
- ctx, a, b, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, true
6322
+ ctx, a, b, NULL, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, true
6323
+ );
6324
+ }
6325
+
6326
+ struct ggml_tensor * ggml_rope_ext(
6327
+ struct ggml_context * ctx,
6328
+ struct ggml_tensor * a,
6329
+ struct ggml_tensor * b,
6330
+ struct ggml_tensor * c,
6331
+ int n_dims,
6332
+ int mode,
6333
+ int n_ctx,
6334
+ int n_orig_ctx,
6335
+ float freq_base,
6336
+ float freq_scale,
6337
+ float ext_factor,
6338
+ float attn_factor,
6339
+ float beta_fast,
6340
+ float beta_slow) {
6341
+ return ggml_rope_impl(
6342
+ ctx, a, b, c, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
6343
+ ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, false
6344
+ );
6345
+ }
6346
+
6347
+ struct ggml_tensor * ggml_rope_ext_inplace(
6348
+ struct ggml_context * ctx,
6349
+ struct ggml_tensor * a,
6350
+ struct ggml_tensor * b,
6351
+ struct ggml_tensor * c,
6352
+ int n_dims,
6353
+ int mode,
6354
+ int n_ctx,
6355
+ int n_orig_ctx,
6356
+ float freq_base,
6357
+ float freq_scale,
6358
+ float ext_factor,
6359
+ float attn_factor,
6360
+ float beta_fast,
6361
+ float beta_slow) {
6362
+ return ggml_rope_impl(
6363
+ ctx, a, b, c, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
6364
+ ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, true
6110
6365
  );
6111
6366
  }
6112
6367
 
@@ -6125,7 +6380,7 @@ struct ggml_tensor * ggml_rope_custom(
6125
6380
  float beta_fast,
6126
6381
  float beta_slow) {
6127
6382
  return ggml_rope_impl(
6128
- ctx, a, b, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
6383
+ ctx, a, b, NULL, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
6129
6384
  ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, false
6130
6385
  );
6131
6386
  }
@@ -6145,7 +6400,7 @@ struct ggml_tensor * ggml_rope_custom_inplace(
6145
6400
  float beta_fast,
6146
6401
  float beta_slow) {
6147
6402
  return ggml_rope_impl(
6148
- ctx, a, b, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
6403
+ ctx, a, b, NULL, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
6149
6404
  ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, true
6150
6405
  );
6151
6406
  }
@@ -6157,7 +6412,7 @@ struct ggml_tensor * ggml_rope_xpos_inplace(
6157
6412
  int n_dims,
6158
6413
  float base,
6159
6414
  bool down) {
6160
- return ggml_rope_impl(ctx, a, b, n_dims, 0, 0, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, base, down, true);
6415
+ return ggml_rope_impl(ctx, a, b, NULL, n_dims, 0, 0, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, base, down, true);
6161
6416
  }
6162
6417
 
6163
6418
  // ggml_rope_back
@@ -6166,6 +6421,7 @@ struct ggml_tensor * ggml_rope_back(
6166
6421
  struct ggml_context * ctx,
6167
6422
  struct ggml_tensor * a,
6168
6423
  struct ggml_tensor * b,
6424
+ struct ggml_tensor * c,
6169
6425
  int n_dims,
6170
6426
  int mode,
6171
6427
  int n_ctx,
@@ -6181,6 +6437,7 @@ struct ggml_tensor * ggml_rope_back(
6181
6437
  GGML_ASSERT(ggml_is_vector(b));
6182
6438
  GGML_ASSERT(b->type == GGML_TYPE_I32);
6183
6439
  GGML_ASSERT(a->ne[2] == b->ne[0]);
6440
+ GGML_ASSERT(c == NULL && "freq factors not implemented yet");
6184
6441
 
6185
6442
  GGML_ASSERT((mode & 4) == 0 && "ggml_rope_back() for ChatGLM not implemented yet");
6186
6443
 
@@ -6724,38 +6981,6 @@ struct ggml_tensor * ggml_top_k(
6724
6981
  return result;
6725
6982
  }
6726
6983
 
6727
- // ggml_flash_attn
6728
-
6729
- struct ggml_tensor * ggml_flash_attn(
6730
- struct ggml_context * ctx,
6731
- struct ggml_tensor * q,
6732
- struct ggml_tensor * k,
6733
- struct ggml_tensor * v,
6734
- bool masked) {
6735
- GGML_ASSERT(ggml_can_mul_mat(k, q));
6736
- // TODO: check if vT can be multiplied by (k*qT)
6737
-
6738
- bool is_node = false;
6739
-
6740
- if (q->grad || k->grad || v->grad) {
6741
- is_node = true;
6742
- }
6743
-
6744
- //struct ggml_tensor * result = ggml_dup_tensor(ctx, q);
6745
- struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, GGML_MAX_DIMS, q->ne);
6746
-
6747
- int32_t t = masked ? 1 : 0;
6748
- ggml_set_op_params(result, &t, sizeof(t));
6749
-
6750
- result->op = GGML_OP_FLASH_ATTN;
6751
- result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6752
- result->src[0] = q;
6753
- result->src[1] = k;
6754
- result->src[2] = v;
6755
-
6756
- return result;
6757
- }
6758
-
6759
6984
  // ggml_flash_attn_ext
6760
6985
 
6761
6986
  struct ggml_tensor * ggml_flash_attn_ext(
@@ -6815,38 +7040,6 @@ void ggml_flash_attn_ext_set_prec(
6815
7040
  ggml_set_op_params_i32(a, 2, prec_i32); // scale is on first pos, max_bias on second
6816
7041
  }
6817
7042
 
6818
- // ggml_flash_ff
6819
-
6820
- struct ggml_tensor * ggml_flash_ff(
6821
- struct ggml_context * ctx,
6822
- struct ggml_tensor * a,
6823
- struct ggml_tensor * b0,
6824
- struct ggml_tensor * b1,
6825
- struct ggml_tensor * c0,
6826
- struct ggml_tensor * c1) {
6827
- GGML_ASSERT(ggml_can_mul_mat(b0, a));
6828
- // TODO: more checks
6829
-
6830
- bool is_node = false;
6831
-
6832
- if (a->grad || b0->grad || b1->grad || c0->grad || c1->grad) {
6833
- is_node = true;
6834
- }
6835
-
6836
- //struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
6837
- struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, GGML_MAX_DIMS, a->ne);
6838
-
6839
- result->op = GGML_OP_FLASH_FF;
6840
- result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6841
- result->src[0] = a;
6842
- result->src[1] = b0;
6843
- result->src[2] = b1;
6844
- result->src[3] = c0;
6845
- result->src[4] = c1;
6846
-
6847
- return result;
6848
- }
6849
-
6850
7043
  // ggml_flash_attn_back
6851
7044
 
6852
7045
  struct ggml_tensor * ggml_flash_attn_back(
@@ -6856,6 +7049,8 @@ struct ggml_tensor * ggml_flash_attn_back(
6856
7049
  struct ggml_tensor * v,
6857
7050
  struct ggml_tensor * d,
6858
7051
  bool masked) {
7052
+ GGML_ASSERT(false && "TODO: adapt to ggml_flash_attn_ext() changes");
7053
+
6859
7054
  GGML_ASSERT(ggml_can_mul_mat(k, q));
6860
7055
  // TODO: check if vT can be multiplied by (k*qT)
6861
7056
 
@@ -10809,26 +11004,29 @@ static void ggml_compute_forward_concat_f32(
10809
11004
  GGML_ASSERT(nb00 == sizeof(float));
10810
11005
  GGML_ASSERT(nb10 == sizeof(float));
10811
11006
 
11007
+ const int32_t dim = ggml_get_op_params_i32(dst, 0);
11008
+
11009
+ GGML_ASSERT(dim >= 0 && dim < 4);
11010
+
11011
+ int64_t o[4] = {0, 0, 0, 0};
11012
+ o[dim] = src0->ne[dim];
11013
+
11014
+ const float * x;
11015
+
11016
+ // TODO: smarter multi-theading
10812
11017
  for (int i3 = 0; i3 < ne3; i3++) {
10813
11018
  for (int i2 = ith; i2 < ne2; i2 += nth) {
10814
- if (i2 < ne02) { // src0
10815
- for (int i1 = 0; i1 < ne1; i1++) {
10816
- for (int i0 = 0; i0 < ne0; i0++) {
10817
- const float * x = (float *)((char *) src0->data + i0 * nb00 + i1 * nb01 + i2 * nb02 + i3 * nb03);
10818
-
10819
- float * y = (float *)((char *)dst->data + i0 * nb0 + i1 * nb1 + i2 * nb2 + i3 * nb3);
10820
- *y = *x;
10821
- }
10822
- }
10823
- } // src1
10824
- else {
10825
- for (int i1 = 0; i1 < ne1; i1++) {
10826
- for (int i0 = 0; i0 < ne0; i0++) {
10827
- const float * x = (float *)((char *) src1->data + i0 * nb10 + i1 * nb11 + (i2 - ne02) * nb12 + i3 * nb13);
10828
-
10829
- float * y = (float *)((char *)dst->data + i0 * nb0 + i1 * nb1 + i2 * nb2 + i3 * nb3);
10830
- *y = *x;
11019
+ for (int i1 = 0; i1 < ne1; i1++) {
11020
+ for (int i0 = 0; i0 < ne0; i0++) {
11021
+ if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
11022
+ x = (const float *) ((const char *)src0->data + (i0 )*nb00 + (i1 )*nb01 + (i2 )*nb02 + (i3 )*nb03);
11023
+ } else {
11024
+ x = (const float *) ((const char *)src1->data + (i0 - o[0])*nb10 + (i1 - o[1])*nb11 + (i2 - o[2])*nb12 + (i3 - o[3])*nb13);
10831
11025
  }
11026
+
11027
+ float * y = (float *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
11028
+
11029
+ *y = *x;
10832
11030
  }
10833
11031
  }
10834
11032
  }
@@ -10836,8 +11034,8 @@ static void ggml_compute_forward_concat_f32(
10836
11034
  }
10837
11035
 
10838
11036
  static void ggml_compute_forward_concat(
10839
- const struct ggml_compute_params* params,
10840
- struct ggml_tensor* dst) {
11037
+ const struct ggml_compute_params * params,
11038
+ struct ggml_tensor * dst) {
10841
11039
 
10842
11040
  const struct ggml_tensor * src0 = dst->src[0];
10843
11041
 
@@ -11230,8 +11428,8 @@ static void ggml_compute_forward_gelu_f32(
11230
11428
 
11231
11429
  const struct ggml_tensor * src0 = dst->src[0];
11232
11430
 
11233
- GGML_ASSERT(ggml_is_contiguous_except_dim_1(src0));
11234
- GGML_ASSERT(ggml_is_contiguous_except_dim_1(dst));
11431
+ GGML_ASSERT(ggml_is_contiguous_1(src0));
11432
+ GGML_ASSERT(ggml_is_contiguous_1(dst));
11235
11433
  GGML_ASSERT(ggml_are_same_shape(src0, dst));
11236
11434
 
11237
11435
  if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
@@ -11293,8 +11491,8 @@ static void ggml_compute_forward_gelu_quick_f32(
11293
11491
 
11294
11492
  const struct ggml_tensor * src0 = dst->src[0];
11295
11493
 
11296
- GGML_ASSERT(ggml_is_contiguous_except_dim_1(src0));
11297
- GGML_ASSERT(ggml_is_contiguous_except_dim_1(dst));
11494
+ GGML_ASSERT(ggml_is_contiguous_1(src0));
11495
+ GGML_ASSERT(ggml_is_contiguous_1(dst));
11298
11496
  GGML_ASSERT(ggml_are_same_shape(src0, dst));
11299
11497
 
11300
11498
  if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
@@ -11356,8 +11554,8 @@ static void ggml_compute_forward_silu_f32(
11356
11554
 
11357
11555
  const struct ggml_tensor * src0 = dst->src[0];
11358
11556
 
11359
- GGML_ASSERT(ggml_is_contiguous_except_dim_1(src0));
11360
- GGML_ASSERT(ggml_is_contiguous_except_dim_1(dst));
11557
+ GGML_ASSERT(ggml_is_contiguous_1(src0));
11558
+ GGML_ASSERT(ggml_is_contiguous_1(dst));
11361
11559
  GGML_ASSERT(ggml_are_same_shape(src0, dst));
11362
11560
 
11363
11561
  if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
@@ -11468,9 +11666,9 @@ static void ggml_compute_forward_silu_back_f32(
11468
11666
  const struct ggml_tensor * src0 = dst->src[0];
11469
11667
  const struct ggml_tensor * grad = dst->src[1];
11470
11668
 
11471
- GGML_ASSERT(ggml_is_contiguous_except_dim_1(grad));
11472
- GGML_ASSERT(ggml_is_contiguous_except_dim_1(src0));
11473
- GGML_ASSERT(ggml_is_contiguous_except_dim_1(dst));
11669
+ GGML_ASSERT(ggml_is_contiguous_1(grad));
11670
+ GGML_ASSERT(ggml_is_contiguous_1(src0));
11671
+ GGML_ASSERT(ggml_is_contiguous_1(dst));
11474
11672
  GGML_ASSERT(ggml_are_same_shape(src0, dst));
11475
11673
  GGML_ASSERT(ggml_are_same_shape(src0, grad));
11476
11674
 
@@ -14115,6 +14313,7 @@ static void ggml_compute_forward_rope_f32(
14115
14313
 
14116
14314
  const struct ggml_tensor * src0 = dst->src[0];
14117
14315
  const struct ggml_tensor * src1 = dst->src[1];
14316
+ const struct ggml_tensor * src2 = dst->src[2];
14118
14317
 
14119
14318
  if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
14120
14319
  return;
@@ -14167,13 +14366,24 @@ static void ggml_compute_forward_rope_f32(
14167
14366
  int ir = 0;
14168
14367
 
14169
14368
  const float theta_scale = powf(freq_base, -2.0f/n_dims);
14170
- const float inv_ndims = -1.f/n_dims;
14369
+
14171
14370
  float corr_dims[2];
14172
14371
  ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
14173
14372
 
14174
14373
  const bool is_neox = mode & 2;
14175
14374
  const bool is_glm = mode & 4;
14176
14375
 
14376
+ const float * freq_factors = NULL;
14377
+ if (is_neox) {
14378
+ if (src2 != NULL) {
14379
+ GGML_ASSERT(src2->type == GGML_TYPE_F32);
14380
+ GGML_ASSERT(src2->ne[0] >= n_dims / 2);
14381
+ freq_factors = (const float *) src2->data;
14382
+ }
14383
+ } else {
14384
+ GGML_ASSERT(src2 == NULL && "TODO: freq_factors not implemented for !is_neox");
14385
+ }
14386
+
14177
14387
  // backward process uses inverse rotation by cos and sin.
14178
14388
  // cos and sin build a rotation matrix, where the inverse is the transpose.
14179
14389
  // this essentially just switches the sign of sin.
@@ -14205,7 +14415,7 @@ static void ggml_compute_forward_rope_f32(
14205
14415
  const float cos_block_theta = cosf(block_theta);
14206
14416
  const float sin_block_theta = sinf(block_theta) * sin_sign;
14207
14417
 
14208
- theta_base *= theta_scale;
14418
+ theta_base *= theta_scale;
14209
14419
  block_theta *= theta_scale;
14210
14420
 
14211
14421
  const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
@@ -14240,28 +14450,22 @@ static void ggml_compute_forward_rope_f32(
14240
14450
  dst_data[1] = x0*sin_theta*zeta + x1*cos_theta*zeta;
14241
14451
  }
14242
14452
  } else {
14243
- // TODO: this might be wrong for ne0 != n_dims - need double check
14244
- // it seems we have to rope just the first n_dims elements and do nothing with the rest
14245
- // ref: https://github.com/ml-explore/mlx/blob/dc2edc762c797e3b8de50b1dad4dc0a131691033/benchmarks/python/llama_jax_bench.py#L11-L26
14246
- theta_base *= freq_scale;
14453
+ // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
14247
14454
  for (int64_t ic = 0; ic < ne0; ic += 2) {
14248
14455
  if (ic < n_dims) {
14249
- const int64_t ib = 0;
14456
+ const int64_t i0 = ic/2;
14250
14457
 
14251
- // simplified from `(ib * n_dims + ic) * inv_ndims`
14252
- float cur_rot = inv_ndims * ic - ib;
14458
+ const float freq_factor = freq_factors ? freq_factors[i0] : 1.0f;
14253
14459
 
14254
14460
  float cos_theta, sin_theta;
14255
14461
  rope_yarn(
14256
- theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
14462
+ theta_base/freq_factor, freq_scale, corr_dims, ic, ext_factor, attn_factor,
14257
14463
  &cos_theta, &sin_theta
14258
14464
  );
14259
- sin_theta *= sin_sign;
14260
14465
 
14466
+ sin_theta *= sin_sign;
14261
14467
  theta_base *= theta_scale;
14262
14468
 
14263
- const int64_t i0 = ib*n_dims + ic/2;
14264
-
14265
14469
  const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
14266
14470
  float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
14267
14471
 
@@ -14286,6 +14490,7 @@ static void ggml_compute_forward_rope_f32(
14286
14490
  }
14287
14491
  }
14288
14492
 
14493
+ // TODO: deduplicate f16/f32 code
14289
14494
  static void ggml_compute_forward_rope_f16(
14290
14495
  const struct ggml_compute_params * params,
14291
14496
  struct ggml_tensor * dst,
@@ -14293,6 +14498,7 @@ static void ggml_compute_forward_rope_f16(
14293
14498
 
14294
14499
  const struct ggml_tensor * src0 = dst->src[0];
14295
14500
  const struct ggml_tensor * src1 = dst->src[1];
14501
+ const struct ggml_tensor * src2 = dst->src[2];
14296
14502
 
14297
14503
  if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
14298
14504
  return;
@@ -14338,13 +14544,24 @@ static void ggml_compute_forward_rope_f16(
14338
14544
  int ir = 0;
14339
14545
 
14340
14546
  const float theta_scale = powf(freq_base, -2.0f/n_dims);
14341
- const float inv_ndims = -1.f/n_dims;
14547
+
14342
14548
  float corr_dims[2];
14343
14549
  ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
14344
14550
 
14345
14551
  const bool is_neox = mode & 2;
14346
14552
  const bool is_glm = mode & 4;
14347
14553
 
14554
+ const float * freq_factors = NULL;
14555
+ if (is_neox) {
14556
+ if (src2 != NULL) {
14557
+ GGML_ASSERT(src2->type == GGML_TYPE_F32);
14558
+ GGML_ASSERT(src2->ne[0] >= n_dims / 2);
14559
+ freq_factors = (const float *) src2->data;
14560
+ }
14561
+ } else {
14562
+ GGML_ASSERT(src2 == NULL && "TODO: freq_factors not implemented for !is_neox");
14563
+ }
14564
+
14348
14565
  // backward process uses inverse rotation by cos and sin.
14349
14566
  // cos and sin build a rotation matrix, where the inverse is the transpose.
14350
14567
  // this essentially just switches the sign of sin.
@@ -14376,7 +14593,7 @@ static void ggml_compute_forward_rope_f16(
14376
14593
  const float cos_block_theta = cosf(block_theta);
14377
14594
  const float sin_block_theta = sinf(block_theta) * sin_sign;
14378
14595
 
14379
- theta_base *= theta_scale;
14596
+ theta_base *= theta_scale;
14380
14597
  block_theta *= theta_scale;
14381
14598
 
14382
14599
  const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
@@ -14407,28 +14624,22 @@ static void ggml_compute_forward_rope_f16(
14407
14624
  dst_data[1] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
14408
14625
  }
14409
14626
  } else {
14410
- // TODO: this might be wrong for ne0 != n_dims - need double check
14411
- // it seems we have to rope just the first n_dims elements and do nothing with the rest
14412
- // ref: https://github.com/ml-explore/mlx/blob/dc2edc762c797e3b8de50b1dad4dc0a131691033/benchmarks/python/llama_jax_bench.py#L11-L26
14413
- theta_base *= freq_scale;
14627
+ // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
14414
14628
  for (int64_t ic = 0; ic < ne0; ic += 2) {
14415
14629
  if (ic < n_dims) {
14416
- const int64_t ib = 0;
14630
+ const int64_t i0 = ic/2;
14417
14631
 
14418
- // simplified from `(ib * n_dims + ic) * inv_ndims`
14419
- float cur_rot = inv_ndims * ic - ib;
14632
+ const float freq_factor = freq_factors ? freq_factors[i0] : 1.0f;
14420
14633
 
14421
14634
  float cos_theta, sin_theta;
14422
14635
  rope_yarn(
14423
- theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
14636
+ theta_base/freq_factor, freq_scale, corr_dims, ic, ext_factor, attn_factor,
14424
14637
  &cos_theta, &sin_theta
14425
14638
  );
14426
- sin_theta *= sin_sign;
14427
14639
 
14640
+ sin_theta *= sin_sign;
14428
14641
  theta_base *= theta_scale;
14429
14642
 
14430
- const int64_t i0 = ib*n_dims + ic/2;
14431
-
14432
14643
  const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
14433
14644
  ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
14434
14645
 
@@ -15458,400 +15669,6 @@ static void ggml_compute_forward_argsort(
15458
15669
  }
15459
15670
  }
15460
15671
 
15461
- // ggml_compute_forward_flash_attn
15462
-
15463
- static void ggml_compute_forward_flash_attn_f32(
15464
- const struct ggml_compute_params * params,
15465
- const bool masked,
15466
- struct ggml_tensor * dst) {
15467
-
15468
- const struct ggml_tensor * q = dst->src[0];
15469
- const struct ggml_tensor * k = dst->src[1];
15470
- const struct ggml_tensor * v = dst->src[2];
15471
-
15472
- int64_t t0 = ggml_perf_time_us();
15473
- UNUSED(t0);
15474
-
15475
- GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
15476
- GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
15477
- GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
15478
- GGML_TENSOR_LOCALS(size_t, nbk, k, nb)
15479
- GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
15480
- GGML_TENSOR_LOCALS(size_t, nbv, v, nb)
15481
- GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
15482
- GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
15483
-
15484
- const int ith = params->ith;
15485
- const int nth = params->nth;
15486
-
15487
- const int64_t D = neq0;
15488
- const int64_t N = neq1;
15489
- const int64_t P = nek1 - N;
15490
- const int64_t M = P + N;
15491
-
15492
- const int Mup = ggml_up(M, GGML_SOFT_MAX_UNROLL);
15493
-
15494
- GGML_ASSERT(ne0 == D);
15495
- GGML_ASSERT(ne1 == N);
15496
- GGML_ASSERT(P >= 0);
15497
-
15498
- GGML_ASSERT(nbq0 == sizeof(float));
15499
- GGML_ASSERT(nbk0 == sizeof(float));
15500
- GGML_ASSERT(nbv0 == sizeof(float));
15501
-
15502
- GGML_ASSERT(neq0 == D);
15503
- GGML_ASSERT(nek0 == D);
15504
- GGML_ASSERT(nev1 == D);
15505
-
15506
- GGML_ASSERT(neq1 == N);
15507
- GGML_ASSERT(nek1 == N + P);
15508
- GGML_ASSERT(nev1 == D);
15509
-
15510
- // dst cannot be transposed or permuted
15511
- GGML_ASSERT(nb0 == sizeof(float));
15512
- GGML_ASSERT(nb0 <= nb1);
15513
- GGML_ASSERT(nb1 <= nb2);
15514
- GGML_ASSERT(nb2 <= nb3);
15515
-
15516
- if (params->type == GGML_TASK_TYPE_INIT) {
15517
- return;
15518
- }
15519
-
15520
- if (params->type == GGML_TASK_TYPE_FINALIZE) {
15521
- return;
15522
- }
15523
-
15524
- // parallelize by q rows using ggml_vec_dot_f32
15525
-
15526
- // total rows in q
15527
- const int nr = neq1*neq2*neq3;
15528
-
15529
- // rows per thread
15530
- const int dr = (nr + nth - 1)/nth;
15531
-
15532
- // row range for this thread
15533
- const int ir0 = dr*ith;
15534
- const int ir1 = MIN(ir0 + dr, nr);
15535
-
15536
- const float scale = 1.0f/sqrtf(D);
15537
-
15538
- //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale);
15539
-
15540
- for (int ir = ir0; ir < ir1; ++ir) {
15541
- // q indices
15542
- const int iq3 = ir/(neq2*neq1);
15543
- const int iq2 = (ir - iq3*neq2*neq1)/neq1;
15544
- const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
15545
-
15546
- float * S = (float *) params->wdata + ith*(Mup + CACHE_LINE_SIZE_F32);
15547
-
15548
- for (int i = M; i < Mup; ++i) {
15549
- S[i] = -INFINITY;
15550
- }
15551
-
15552
- const int64_t masked_begin = masked ? (P + iq1 + 1) : M;
15553
- for (int64_t ic = 0; ic < masked_begin; ++ic) {
15554
- // k indices
15555
- const int ik3 = iq3;
15556
- const int ik2 = iq2 % nek2;
15557
- const int ik1 = ic;
15558
-
15559
- // S indices
15560
- const int i1 = ik1;
15561
-
15562
- ggml_vec_dot_f32(neq0,
15563
- S + i1, 0,
15564
- (float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), 0,
15565
- (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)), 0, 1);
15566
- }
15567
-
15568
- // scale
15569
- ggml_vec_scale_f32(masked_begin, S, scale);
15570
-
15571
- for (int64_t i = masked_begin; i < M; i++) {
15572
- S[i] = -INFINITY;
15573
- }
15574
-
15575
- // softmax
15576
- // exclude known -INF S[..] values from max and loop
15577
- // dont forget to set their SW values to zero
15578
- {
15579
- float max = -INFINITY;
15580
- ggml_vec_max_f32(masked_begin, &max, S);
15581
-
15582
- ggml_float sum = 0.0;
15583
- {
15584
- #ifdef GGML_SOFT_MAX_ACCELERATE
15585
- max = -max;
15586
- vDSP_vsadd(S, 1, &max, S, 1, Mup);
15587
- vvexpf(S, S, &Mup);
15588
- ggml_vec_sum_f32(Mup, &sum, S);
15589
- #else
15590
- sum = ggml_vec_soft_max_f32(Mup, S, S, max);
15591
- #endif
15592
- }
15593
-
15594
- assert(sum > 0.0);
15595
-
15596
- sum = 1.0/sum;
15597
- ggml_vec_scale_f32(masked_begin, S, sum);
15598
-
15599
- #ifndef NDEBUG
15600
- for (int i = 0; i < masked_begin; ++i) {
15601
- assert(!isnan(S[i]));
15602
- assert(!isinf(S[i]));
15603
- }
15604
- #endif
15605
- }
15606
-
15607
- for (int64_t ic = 0; ic < nev1; ++ic) {
15608
- // dst indices
15609
- const int i1 = iq1;
15610
- const int i2 = iq2;
15611
- const int i3 = iq3;
15612
-
15613
- // v indices
15614
- const int iv2 = iq2 % nev2;
15615
- const int iv3 = iq3;
15616
-
15617
- ggml_vec_dot_f32(masked_begin,
15618
- (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), 0,
15619
- (float *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)), 0,
15620
- S, 0, 1);
15621
- }
15622
- }
15623
- }
15624
-
15625
- static void ggml_compute_forward_flash_attn_f16(
15626
- const struct ggml_compute_params * params,
15627
- const bool masked,
15628
- struct ggml_tensor * dst) {
15629
-
15630
- const struct ggml_tensor * q = dst->src[0];
15631
- const struct ggml_tensor * k = dst->src[1];
15632
- const struct ggml_tensor * v = dst->src[2];
15633
-
15634
- int64_t t0 = ggml_perf_time_us();
15635
- UNUSED(t0);
15636
-
15637
- GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
15638
- GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
15639
- GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
15640
- GGML_TENSOR_LOCALS(size_t, nbk, k, nb)
15641
- GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
15642
- GGML_TENSOR_LOCALS(size_t, nbv, v, nb)
15643
- GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
15644
- GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
15645
-
15646
- const int ith = params->ith;
15647
- const int nth = params->nth;
15648
-
15649
- const int64_t D = neq0;
15650
- const int64_t N = neq1;
15651
- const int64_t P = nek1 - N;
15652
- const int64_t M = P + N;
15653
-
15654
- const int Mup = ggml_up(M, GGML_SOFT_MAX_UNROLL);
15655
-
15656
- GGML_ASSERT(ne0 == D);
15657
- GGML_ASSERT(ne1 == N);
15658
- GGML_ASSERT(P >= 0);
15659
-
15660
- GGML_ASSERT(nbq0 == sizeof(ggml_fp16_t));
15661
- GGML_ASSERT(nbk0 == sizeof(ggml_fp16_t));
15662
- GGML_ASSERT(nbv0 == sizeof(ggml_fp16_t));
15663
-
15664
- GGML_ASSERT(neq0 == D);
15665
- GGML_ASSERT(nek0 == D);
15666
- GGML_ASSERT(nev1 == D);
15667
-
15668
- GGML_ASSERT(neq1 == N);
15669
- GGML_ASSERT(nek1 == N + P);
15670
- GGML_ASSERT(nev1 == D);
15671
-
15672
- // dst cannot be transposed or permuted
15673
- GGML_ASSERT(nb0 == sizeof(float));
15674
- GGML_ASSERT(nb0 <= nb1);
15675
- GGML_ASSERT(nb1 <= nb2);
15676
- GGML_ASSERT(nb2 <= nb3);
15677
-
15678
- if (params->type == GGML_TASK_TYPE_INIT) {
15679
- return;
15680
- }
15681
-
15682
- if (params->type == GGML_TASK_TYPE_FINALIZE) {
15683
- return;
15684
- }
15685
-
15686
- // parallelize by q rows using ggml_vec_dot_f32
15687
-
15688
- // total rows in q
15689
- const int nr = neq1*neq2*neq3;
15690
-
15691
- // rows per thread
15692
- const int dr = (nr + nth - 1)/nth;
15693
-
15694
- // row range for this thread
15695
- const int ir0 = dr*ith;
15696
- const int ir1 = MIN(ir0 + dr, nr);
15697
-
15698
- const float scale = 1.0f/sqrtf(D);
15699
-
15700
- //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale);
15701
-
15702
- for (int ir = ir0; ir < ir1; ++ir) {
15703
- // q indices
15704
- const int iq3 = ir/(neq2*neq1);
15705
- const int iq2 = (ir - iq3*neq2*neq1)/neq1;
15706
- const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
15707
-
15708
- float * S = (float *) params->wdata + ith*(2*Mup + CACHE_LINE_SIZE_F32);
15709
-
15710
- for (int i = M; i < Mup; ++i) {
15711
- S[i] = -INFINITY;
15712
- }
15713
-
15714
- if (GGML_VEC_DOT_UNROLL > 2 || nek1 % GGML_VEC_DOT_UNROLL != 0) {
15715
- for (int64_t ic = 0; ic < nek1; ++ic) {
15716
- // k indices
15717
- const int ik3 = iq3;
15718
- const int ik2 = iq2 % nek2;
15719
- const int ik1 = ic;
15720
-
15721
- // S indices
15722
- const int i1 = ik1;
15723
-
15724
- ggml_vec_dot_f16(neq0,
15725
- S + i1, 0,
15726
- (ggml_fp16_t *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), 0,
15727
- (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)), 0, 1);
15728
- }
15729
- } else {
15730
- for (int64_t ic = 0; ic < nek1; ic += GGML_VEC_DOT_UNROLL) {
15731
- // k indices
15732
- const int ik3 = iq3;
15733
- const int ik2 = iq2 % nek2;
15734
- const int ik1 = ic;
15735
-
15736
- // S indices
15737
- const int i1 = ik1;
15738
-
15739
- ggml_vec_dot_f16_unroll(neq0, nbk1,
15740
- S + i1,
15741
- ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)),
15742
- (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
15743
- }
15744
- }
15745
-
15746
- // scale
15747
- ggml_vec_scale_f32(nek1, S, scale);
15748
-
15749
- if (masked) {
15750
- for (int64_t i = P; i < M; i++) {
15751
- if (i > P + iq1) {
15752
- S[i] = -INFINITY;
15753
- }
15754
- }
15755
- }
15756
-
15757
- // softmax
15758
- // todo: exclude known -INF S[..] values from max and loop, assuming their results to be zero.
15759
- // dont forget to set their S values to zero
15760
- {
15761
- float max = -INFINITY;
15762
- ggml_vec_max_f32(M, &max, S);
15763
-
15764
- ggml_float sum = 0.0;
15765
- {
15766
- #ifdef GGML_SOFT_MAX_ACCELERATE
15767
- max = -max;
15768
- vDSP_vsadd(S, 1, &max, S, 1, Mup);
15769
- vvexpf(S, S, &Mup);
15770
- ggml_vec_sum_f32(Mup, &sum, S);
15771
- #else
15772
- sum = ggml_vec_soft_max_f32(Mup, S, S, max);
15773
- #endif
15774
- }
15775
-
15776
- assert(sum > 0.0);
15777
-
15778
- sum = 1.0/sum;
15779
- ggml_vec_scale_f32(M, S, sum);
15780
-
15781
- #ifndef NDEBUG
15782
- for (int i = 0; i < M; ++i) {
15783
- assert(!isnan(S[i]));
15784
- assert(!isinf(S[i]));
15785
- }
15786
- #endif
15787
- }
15788
-
15789
- ggml_fp16_t * S16 = (ggml_fp16_t *) ((float *) params->wdata + ith*(2*Mup + CACHE_LINE_SIZE_F32) + Mup);
15790
-
15791
- for (int64_t i = 0; i < M; i++) {
15792
- S16[i] = GGML_FP32_TO_FP16(S[i]);
15793
- }
15794
-
15795
- // todo: exclude known zero S[..] values from dot (reducing nev0 and increasing begin of v and S16).
15796
- if (GGML_VEC_DOT_UNROLL == 1 || (nev1 % GGML_VEC_DOT_UNROLL != 0)) {
15797
- for (int64_t ic = 0; ic < nev1; ++ic) {
15798
- // dst indices
15799
- const int i1 = iq1;
15800
- const int i2 = iq2;
15801
- const int i3 = iq3;
15802
-
15803
- // v indices
15804
- const int iv2 = iq2 % nev2;
15805
- const int iv3 = iq3;
15806
-
15807
- ggml_vec_dot_f16(nev0,
15808
- (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), 0,
15809
- (ggml_fp16_t *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)), 0,
15810
- S16, 0, 1);
15811
- }
15812
- } else {
15813
- for (int64_t ic = 0; ic < nev1; ic += GGML_VEC_DOT_UNROLL) {
15814
- // dst indices
15815
- const int i1 = iq1;
15816
- const int i2 = iq2;
15817
- const int i3 = iq3;
15818
-
15819
- // v indices
15820
- const int iv2 = iq2 % nev2;
15821
- const int iv3 = iq3;
15822
-
15823
- ggml_vec_dot_f16_unroll(nev0, nbv1,
15824
- (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
15825
- ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)),
15826
- S16);
15827
- }
15828
- }
15829
- }
15830
- }
15831
-
15832
- static void ggml_compute_forward_flash_attn(
15833
- const struct ggml_compute_params * params,
15834
- const bool masked,
15835
- struct ggml_tensor * dst) {
15836
-
15837
- const struct ggml_tensor * q = dst->src[0];
15838
-
15839
- switch (q->type) {
15840
- case GGML_TYPE_F16:
15841
- {
15842
- ggml_compute_forward_flash_attn_f16(params, masked, dst);
15843
- } break;
15844
- case GGML_TYPE_F32:
15845
- {
15846
- ggml_compute_forward_flash_attn_f32(params, masked, dst);
15847
- } break;
15848
- default:
15849
- {
15850
- GGML_ASSERT(false);
15851
- } break;
15852
- }
15853
- }
15854
-
15855
15672
  // ggml_compute_forward_flash_attn_ext
15856
15673
 
15857
15674
  static void ggml_compute_forward_flash_attn_ext_f16(
@@ -15882,9 +15699,10 @@ static void ggml_compute_forward_flash_attn_ext_f16(
15882
15699
  GGML_ASSERT(ne0 == D);
15883
15700
  GGML_ASSERT(ne2 == N);
15884
15701
 
15885
- GGML_ASSERT(nbq0 == sizeof(float));
15886
- GGML_ASSERT(nbk0 == sizeof(ggml_fp16_t));
15887
- GGML_ASSERT(nbv0 == sizeof(ggml_fp16_t));
15702
+ // input tensor rows must be contiguous
15703
+ GGML_ASSERT(nbq0 == ggml_type_size(q->type));
15704
+ GGML_ASSERT(nbk0 == ggml_type_size(k->type));
15705
+ GGML_ASSERT(nbv0 == ggml_type_size(v->type));
15888
15706
 
15889
15707
  GGML_ASSERT(neq0 == D);
15890
15708
  GGML_ASSERT(nek0 == D);
@@ -15938,6 +15756,11 @@ static void ggml_compute_forward_flash_attn_ext_f16(
15938
15756
  const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
15939
15757
  const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
15940
15758
 
15759
+ enum ggml_type const k_vec_dot_type = type_traits[k->type].vec_dot_type;
15760
+ ggml_from_float_t const q_to_vec_dot = type_traits[k_vec_dot_type].from_float;
15761
+ ggml_vec_dot_t const kq_vec_dot = type_traits[k->type].vec_dot;
15762
+ ggml_to_float_t const v_to_float = type_traits[v->type].to_float;
15763
+
15941
15764
  // loop over n_batch and n_head
15942
15765
  for (int ir = ir0; ir < ir1; ++ir) {
15943
15766
  // q indices
@@ -15945,17 +15768,22 @@ static void ggml_compute_forward_flash_attn_ext_f16(
15945
15768
  const int iq2 = (ir - iq3*neq2*neq1)/neq1;
15946
15769
  const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
15947
15770
 
15948
- const uint32_t h = iq2; // head
15771
+ const uint32_t h = iq2; // head index
15949
15772
  const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
15950
15773
 
15951
- float S = 0.0f;
15952
- float M = -INFINITY;
15774
+ float S = 0.0f; // sum
15775
+ float M = -INFINITY; // maximum KQ value
15953
15776
 
15954
- float * V32 = (float *) params->wdata + ith*(2*D + CACHE_LINE_SIZE_F32);
15955
- ggml_fp16_t * Q16 = (ggml_fp16_t *) (V32); // reuse memory
15956
- ggml_fp16_t * V16 = (ggml_fp16_t *) (V32 + D);
15777
+ float * VKQ32 = (float *) params->wdata + ith*(3*D + CACHE_LINE_SIZE_F32); // FP32 VKQ accumulator
15778
+ float * V32 = (VKQ32 + 1*D); // (temporary) FP32 V buffer
15779
+ ggml_fp16_t * VKQ16 = (ggml_fp16_t *) (VKQ32 + 1*D); // (temporary) FP16 VKQ accumulator
15780
+ ggml_fp16_t * Q_q = (ggml_fp16_t *) (VKQ32 + 2*D); // (temporary) buffer for Q converted to quantized/FP16
15957
15781
 
15958
- memset(V16, 0, D*sizeof(ggml_fp16_t));
15782
+ if (v->type == GGML_TYPE_F16) {
15783
+ memset(VKQ16, 0, D*sizeof(ggml_fp16_t));
15784
+ } else {
15785
+ memset(VKQ32, 0, D*sizeof(float));
15786
+ }
15959
15787
 
15960
15788
  const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL;
15961
15789
 
@@ -15967,6 +15795,9 @@ static void ggml_compute_forward_flash_attn_ext_f16(
15967
15795
  const int iv3 = iq3 / rv3;
15968
15796
  const int iv2 = iq2 / rv2;
15969
15797
 
15798
+ const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3));
15799
+ q_to_vec_dot(pq, Q_q, D);
15800
+
15970
15801
  // online softmax / attention
15971
15802
  // loop over n_kv and n_head_kv
15972
15803
  // ref: https://arxiv.org/pdf/2112.05682.pdf
@@ -15976,52 +15807,67 @@ static void ggml_compute_forward_flash_attn_ext_f16(
15976
15807
  continue;
15977
15808
  }
15978
15809
 
15979
- float s;
15810
+ float s; // KQ value
15980
15811
 
15981
- // convert Q to F16 in V32
15982
- {
15983
- const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3));
15812
+ const char * k_data = (const char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3);
15813
+ kq_vec_dot(D, &s, 0, k_data, 0, Q_q, 0, 1);
15984
15814
 
15985
- for (int64_t d = 0; d < D; ++d) {
15986
- Q16[d] = GGML_FP32_TO_FP16(pq[d]);
15987
- }
15988
- }
15815
+ s = s*scale + mv; // scale KQ value and apply mask
15989
15816
 
15990
- ggml_vec_dot_f16(D,
15991
- &s, 0,
15992
- (ggml_fp16_t *) ((char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3)), 0,
15993
- Q16, 0, 1);
15817
+ const float Mold = M;
15994
15818
 
15995
- s = s*scale + mv;
15819
+ float ms = 1.0f; // upon new higher max val, scale VKQ and KQ sum with this value
15820
+ float vs = 1.0f; // post-softmax KQ value, expf(s - M)
15996
15821
 
15997
- const float Mold = M;
15822
+ const char * v_data = ((const char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3));
15998
15823
 
15999
- float ms = 1.0f;
16000
- float vs = 1.0f;
15824
+ if (v->type== GGML_TYPE_F16) {
15825
+ if (s > M) {
15826
+ // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
15827
+ M = s;
15828
+ ms = expf(Mold - M);
16001
15829
 
16002
- if (s > M) {
16003
- M = s;
16004
- ms = expf(Mold - M);
15830
+ // V = V*expf(Mold - M)
15831
+ ggml_vec_scale_f16(D, VKQ16, ms);
15832
+ } else {
15833
+ // no new maximum, ms == 1.0f, vs != 1.0f
15834
+ vs = expf(s - M);
15835
+ }
16005
15836
 
16006
- // V = V*expf(Mold - M)
16007
- ggml_vec_scale_f16(D, V16, ms);
15837
+ // V += v*expf(s - M)
15838
+ ggml_vec_mad_f16(D, VKQ16, (const ggml_fp16_t *) v_data, vs);
16008
15839
  } else {
16009
- vs = expf(s - M);
16010
- }
15840
+ if (s > M) {
15841
+ // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
15842
+ M = s;
15843
+ ms = expf(Mold - M);
16011
15844
 
16012
- const ggml_fp16_t * v16 = (const ggml_fp16_t *) ((char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3));
15845
+ // V = V*expf(Mold - M)
15846
+ ggml_vec_scale_f32(D, VKQ32, ms);
15847
+ } else {
15848
+ // no new maximum, ms == 1.0f, vs != 1.0f
15849
+ vs = expf(s - M);
15850
+ }
16013
15851
 
16014
- // V += v*expf(s - M)
16015
- ggml_vec_mad_f16(D, V16, v16, vs);
15852
+ v_to_float(v_data, V32, D);
15853
+
15854
+ // V += v*expf(s - M)
15855
+ ggml_vec_mad_f32(D, VKQ32, V32, vs);
15856
+ }
16016
15857
 
16017
- S = S*ms + vs;
15858
+ S = S*ms + vs; // scale and increment sum with partial sum
16018
15859
  }
16019
15860
 
16020
- // V /= S
16021
- for (int64_t d = 0; d < D; ++d) {
16022
- V32[d] = GGML_FP16_TO_FP32(V16[d])/S;
15861
+ if (v->type == GGML_TYPE_F16) {
15862
+ for (int64_t d = 0; d < D; ++d) {
15863
+ VKQ32[d] = GGML_FP16_TO_FP32(VKQ16[d]);
15864
+ }
16023
15865
  }
16024
15866
 
15867
+ // V /= S
15868
+ const float S_inv = 1.0f/S;
15869
+ ggml_vec_scale_f32(D, VKQ32, S_inv);
15870
+
16025
15871
  // dst indices
16026
15872
  const int i1 = iq1;
16027
15873
  const int i2 = iq2;
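For readers new to the streaming-softmax trick cited above (arxiv.org/pdf/2112.05682.pdf), here is a minimal scalar reference of the recurrence that both the F16 branch and the dequantizing F32 branch implement; all names below are illustrative and not from ggml:

    #include <math.h>

    // scores[] are the already scaled and masked KQ values for one output row,
    // V is row-major [n_kv][D]; out[] receives softmax(scores) . V
    static void online_softmax_v(const float * scores, const float * V,
                                 int n_kv, int D, float * out) {
        float M = -INFINITY; // running maximum score
        float S = 0.0f;      // running sum of expf(score - M)
        for (int d = 0; d < D; ++d) out[d] = 0.0f;
        for (int i = 0; i < n_kv; ++i) {
            const float s    = scores[i];
            const float Mnew = s > M ? s : M;
            const float ms   = expf(M - Mnew); // rescales the old accumulators (1.0f if no new max)
            const float vs   = expf(s - Mnew); // weight of the incoming V row (1.0f if s is the new max)
            for (int d = 0; d < D; ++d) {
                out[d] = out[d]*ms + vs*V[i*D + d];
            }
            S = S*ms + vs;
            M = Mnew;
        }
        for (int d = 0; d < D; ++d) out[d] /= S; // the final "V /= S" normalization
    }

The two branches in the diff are this same recurrence with the ms == 1.0f / vs == 1.0f shortcuts made explicit, and with the accumulator kept in FP16 (VKQ16) or FP32 (VKQ32) depending on v->type.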
@@ -16031,7 +15877,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
16031
15877
  //memcpy((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3), V, nev0*sizeof(float));
16032
15878
 
16033
15879
  // permute(0, 2, 1, 3)
16034
- memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, V32, nb1);
15880
+ memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32, nb1);
16035
15881
  }
16036
15882
  }
16037
15883
 
@@ -16056,165 +15902,6 @@ static void ggml_compute_forward_flash_attn_ext(
16056
15902
  }
16057
15903
  }
16058
15904
 
16059
- // ggml_compute_forward_flash_ff
16060
-
16061
- static void ggml_compute_forward_flash_ff_f16(
16062
- const struct ggml_compute_params * params,
16063
- struct ggml_tensor * dst) {
16064
-
16065
- const struct ggml_tensor * a = dst->src[0]; // F16
16066
- const struct ggml_tensor * b0 = dst->src[1]; // F16 fc_w
16067
- const struct ggml_tensor * b1 = dst->src[2]; // F32 fc_b
16068
- const struct ggml_tensor * c0 = dst->src[3]; // F16 proj_w
16069
- const struct ggml_tensor * c1 = dst->src[4]; // F32 proj_b
16070
-
16071
- int64_t t0 = ggml_perf_time_us();
16072
- UNUSED(t0);
16073
-
16074
- GGML_TENSOR_LOCALS(int64_t, nea, a, ne)
16075
- GGML_TENSOR_LOCALS(size_t, nba, a, nb)
16076
- GGML_TENSOR_LOCALS(int64_t, neb0, b0, ne)
16077
- GGML_TENSOR_LOCALS(size_t, nbb0, b0, nb)
16078
- GGML_TENSOR_LOCALS(int64_t, neb1, b1, ne)
16079
- GGML_TENSOR_LOCALS(size_t, nbb1, b1, nb)
16080
- GGML_TENSOR_LOCALS(int64_t, nec0, c0, ne)
16081
- GGML_TENSOR_LOCALS(size_t, nbc0, c0, nb)
16082
- GGML_TENSOR_LOCALS(int64_t, nec1, c1, ne)
16083
- GGML_TENSOR_LOCALS(size_t, nbc1, c1, nb)
16084
- GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
16085
- GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
16086
-
16087
- const int ith = params->ith;
16088
- const int nth = params->nth;
16089
-
16090
- const int64_t D = nea0;
16091
- //const int64_t N = nea1;
16092
- const int64_t M = neb01;
16093
-
16094
- GGML_ASSERT(ne0 == nea0);
16095
- GGML_ASSERT(ne1 == nea1);
16096
- GGML_ASSERT(ne2 == nea2);
16097
-
16098
- GGML_ASSERT(nba0 == sizeof(ggml_fp16_t));
16099
- GGML_ASSERT(nbb00 == sizeof(ggml_fp16_t));
16100
- GGML_ASSERT(nbb10 == sizeof(float));
16101
- GGML_ASSERT(nbc00 == sizeof(ggml_fp16_t));
16102
- GGML_ASSERT(nbc10 == sizeof(float));
16103
-
16104
- GGML_ASSERT(neb00 == D);
16105
- GGML_ASSERT(neb01 == M);
16106
- GGML_ASSERT(neb10 == M);
16107
- GGML_ASSERT(neb11 == 1);
16108
-
16109
- GGML_ASSERT(nec00 == M);
16110
- GGML_ASSERT(nec01 == D);
16111
- GGML_ASSERT(nec10 == D);
16112
- GGML_ASSERT(nec11 == 1);
16113
-
16114
- // dst cannot be transposed or permuted
16115
- GGML_ASSERT(nb0 == sizeof(float));
16116
- GGML_ASSERT(nb0 <= nb1);
16117
- GGML_ASSERT(nb1 <= nb2);
16118
- GGML_ASSERT(nb2 <= nb3);
16119
-
16120
- if (params->type == GGML_TASK_TYPE_INIT) {
16121
- return;
16122
- }
16123
-
16124
- if (params->type == GGML_TASK_TYPE_FINALIZE) {
16125
- return;
16126
- }
16127
-
16128
- // parallelize by a rows using ggml_vec_dot_f32
16129
-
16130
- // total rows in a
16131
- const int nr = nea1*nea2*nea3;
16132
-
16133
- // rows per thread
16134
- const int dr = (nr + nth - 1)/nth;
16135
-
16136
- // row range for this thread
16137
- const int ir0 = dr*ith;
16138
- const int ir1 = MIN(ir0 + dr, nr);
16139
-
16140
- for (int ir = ir0; ir < ir1; ++ir) {
16141
- // a indices
16142
- const int ia3 = ir/(nea2*nea1);
16143
- const int ia2 = (ir - ia3*nea2*nea1)/nea1;
16144
- const int ia1 = (ir - ia3*nea2*nea1 - ia2*nea1);
16145
-
16146
- float * S = (float *) params->wdata + ith*(2*M + CACHE_LINE_SIZE_F32);
16147
-
16148
- for (int64_t ic = 0; ic < neb01; ++ic) {
16149
- // b0 indices
16150
- const int ib03 = ia3;
16151
- const int ib02 = ia2;
16152
- const int ib01 = ic;
16153
-
16154
- // S indices
16155
- const int i1 = ib01;
16156
-
16157
- ggml_vec_dot_f16(nea0,
16158
- S + i1, 0,
16159
- (ggml_fp16_t *) ((char *) b0->data + (ib01*nbb01 + ib02*nbb02 + ib03*nbb03)), 0,
16160
- (ggml_fp16_t *) ((char *) a->data + ( ia1*nba1 + ia2*nba2 + ia3*nba3)), 0, 1);
16161
- }
16162
-
16163
- ggml_vec_add_f32(neb01, S, S, (float *) b1->data);
16164
- //ggml_vec_gelu_f32(neb01, S, S);
16165
-
16166
- ggml_fp16_t * S16 = (ggml_fp16_t *) ((float *) params->wdata + ith*(2*M + CACHE_LINE_SIZE_F32) + M);
16167
-
16168
- for (int64_t i = 0; i < M; i++) {
16169
- S16[i] = GGML_FP32_TO_FP16(S[i]);
16170
- }
16171
-
16172
- ggml_vec_gelu_f16(neb01, S16, S16);
16173
-
16174
- {
16175
- // dst indices
16176
- const int i1 = ia1;
16177
- const int i2 = ia2;
16178
- const int i3 = ia3;
16179
-
16180
- for (int64_t ic = 0; ic < nec01; ++ic) {
16181
-
16182
- ggml_vec_dot_f16(neb01,
16183
- (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), 0,
16184
- (ggml_fp16_t *) ((char *) c0->data + ( ic*nbc01 + i2*nbc02 + i3*nbc03)), 0,
16185
- S16, 0, 1);
16186
- }
16187
-
16188
- ggml_vec_add_f32(nec01,
16189
- (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3)),
16190
- (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3)),
16191
- (float *) c1->data);
16192
- }
16193
- }
16194
- }
16195
-
16196
- static void ggml_compute_forward_flash_ff(
16197
- const struct ggml_compute_params * params,
16198
- struct ggml_tensor * dst) {
16199
-
16200
- const struct ggml_tensor * b0 = dst->src[1];
16201
-
16202
- switch (b0->type) {
16203
- case GGML_TYPE_F16:
16204
- {
16205
- ggml_compute_forward_flash_ff_f16(params, dst);
16206
- } break;
16207
- case GGML_TYPE_F32:
16208
- {
16209
- GGML_ASSERT(false); // TODO
16210
- } break;
16211
- default:
16212
- {
16213
- GGML_ASSERT(false);
16214
- } break;
16215
- }
16216
- }
16217
-
16218
15905
  // ggml_compute_forward_flash_attn_back
16219
15906
 
16220
15907
  static void ggml_compute_forward_flash_attn_back_f32(
@@ -17785,21 +17472,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
17785
17472
  {
17786
17473
  ggml_compute_forward_leaky_relu(params, tensor);
17787
17474
  } break;
17788
- case GGML_OP_FLASH_ATTN:
17789
- {
17790
- const int32_t t = ggml_get_op_params_i32(tensor, 0);
17791
- GGML_ASSERT(t == 0 || t == 1);
17792
- const bool masked = t != 0;
17793
- ggml_compute_forward_flash_attn(params, masked, tensor);
17794
- } break;
17795
17475
  case GGML_OP_FLASH_ATTN_EXT:
17796
17476
  {
17797
17477
  ggml_compute_forward_flash_attn_ext(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor);
17798
17478
  } break;
17799
- case GGML_OP_FLASH_FF:
17800
- {
17801
- ggml_compute_forward_flash_ff(params, tensor);
17802
- } break;
17803
17479
  case GGML_OP_FLASH_ATTN_BACK:
17804
17480
  {
17805
17481
  int32_t t = ggml_get_op_params_i32(tensor, 0);
@@ -18169,6 +17845,7 @@ static struct ggml_tensor * ggml_sub_or_set(struct ggml_context * ctx, struct gg
18169
17845
  static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor * tensor, struct ggml_hash_set zero_table) {
18170
17846
  struct ggml_tensor * src0 = tensor->src[0];
18171
17847
  struct ggml_tensor * src1 = tensor->src[1];
17848
+ struct ggml_tensor * src2 = tensor->src[2];
18172
17849
 
18173
17850
  switch (tensor->op) {
18174
17851
  case GGML_OP_DUP:
@@ -18700,6 +18377,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
18700
18377
  ggml_rope_back(ctx,
18701
18378
  tensor->grad,
18702
18379
  src1,
18380
+ src2,
18703
18381
  n_dims,
18704
18382
  mode,
18705
18383
  n_ctx,
@@ -18739,6 +18417,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
18739
18417
  ggml_rope_impl(ctx,
18740
18418
  tensor->grad,
18741
18419
  src1,
18420
+ src2,
18742
18421
  n_dims,
18743
18422
  mode,
18744
18423
  n_ctx,
@@ -18803,7 +18482,6 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
18803
18482
  {
18804
18483
  GGML_ASSERT(false); // TODO: not implemented
18805
18484
  } break;
18806
- case GGML_OP_FLASH_ATTN:
18807
18485
  case GGML_OP_FLASH_ATTN_EXT:
18808
18486
  {
18809
18487
  struct ggml_tensor * flash_grad = NULL;
@@ -18820,7 +18498,6 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
18820
18498
  masked);
18821
18499
  }
18822
18500
 
18823
- struct ggml_tensor * src2 = tensor->src[2];
18824
18501
  const int64_t elem_q = ggml_nelements(src0);
18825
18502
  const int64_t elem_k = ggml_nelements(src1);
18826
18503
  const int64_t elem_v = ggml_nelements(src2);
@@ -18858,10 +18535,6 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
18858
18535
  zero_table);
18859
18536
  }
18860
18537
  } break;
18861
- case GGML_OP_FLASH_FF:
18862
- {
18863
- GGML_ASSERT(false); // not supported
18864
- } break;
18865
18538
  case GGML_OP_FLASH_ATTN_BACK:
18866
18539
  {
18867
18540
  GGML_ASSERT(false); // not supported
@@ -19548,15 +19221,10 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads, int n_cur_
19548
19221
  {
19549
19222
  n_tasks = n_threads;
19550
19223
  } break;
19551
- case GGML_OP_FLASH_ATTN:
19552
19224
  case GGML_OP_FLASH_ATTN_EXT:
19553
19225
  {
19554
19226
  n_tasks = n_threads;
19555
19227
  } break;
19556
- case GGML_OP_FLASH_FF:
19557
- {
19558
- n_tasks = n_threads;
19559
- } break;
19560
19228
  case GGML_OP_FLASH_ATTN_BACK:
19561
19229
  {
19562
19230
  n_tasks = n_threads;
@@ -19953,39 +19621,11 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
19953
19621
  cur += sizeof(ggml_fp16_t)*ne00*ne01*ne02*ne03;
19954
19622
  cur += sizeof(ggml_fp16_t)*ne10*ne11*ne12;
19955
19623
  } break;
19956
- case GGML_OP_FLASH_ATTN:
19957
- {
19958
- const int64_t ne11 = ggml_up(node->src[1]->ne[1], GGML_SOFT_MAX_UNROLL);
19959
-
19960
- if (node->src[1]->type == GGML_TYPE_F32) {
19961
- cur = sizeof(float)*ne11*n_tasks; // TODO: this can become (n_tasks-1)
19962
- cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2
19963
- } else if (node->src[1]->type == GGML_TYPE_F16) {
19964
- cur = sizeof(float)*ne11*n_tasks; // TODO: this can become (n_tasks-1)
19965
- cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2
19966
- } else if (node->src[1]->type == GGML_TYPE_BF16) {
19967
- cur = sizeof(float)*ne11*n_tasks; // TODO: this can become (n_tasks-1)
19968
- cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2
19969
- }
19970
- } break;
19971
19624
  case GGML_OP_FLASH_ATTN_EXT:
19972
19625
  {
19973
19626
  const int64_t ne00 = node->src[0]->ne[0]; // D
19974
19627
 
19975
- cur = 2*sizeof(float)*ne00*n_tasks; // 2x head size
19976
- } break;
19977
- case GGML_OP_FLASH_FF:
19978
- {
19979
- if (node->src[1]->type == GGML_TYPE_F32) {
19980
- cur = sizeof(float)*node->src[1]->ne[1]*n_tasks; // TODO: this can become (n_tasks-1)
19981
- cur += sizeof(float)*node->src[1]->ne[1]*n_tasks; // this is overestimated by x2
19982
- } else if (node->src[1]->type == GGML_TYPE_F16) {
19983
- cur = sizeof(float)*node->src[1]->ne[1]*n_tasks; // TODO: this can become (n_tasks-1)
19984
- cur += sizeof(float)*node->src[1]->ne[1]*n_tasks; // this is overestimated by x2
19985
- } else if (node->src[1]->type == GGML_TYPE_BF16) {
19986
- cur = sizeof(float)*node->src[1]->ne[1]*n_tasks; // TODO: this can become (n_tasks-1)
19987
- cur += sizeof(float)*node->src[1]->ne[1]*n_tasks; // this is overestimated by x2
19988
- }
19628
+ cur = 3*sizeof(float)*ne00*n_tasks; // 3x head size/thread
19989
19629
  } break;
19990
19630
  case GGML_OP_FLASH_ATTN_BACK:
19991
19631
  {
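As a quick sanity check on the new FLASH_ATTN_EXT work-size estimate above (three float vectors of head size per task, matching the VKQ32 / V32+VKQ16 / Q_q split in the kernel), assuming a head size of ne00 = D = 128 and n_tasks = 8:

    cur = 3 * sizeof(float) * ne00 * n_tasks
        = 3 * 4 * 128 * 8
        = 12288 bytes   // 12 KiB of scratch; per-thread cache-line padding is
                        // presumably accounted for elsewhere in ggml_graph_plan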
@@ -21827,11 +21467,7 @@ size_t ggml_quantize_chunk(
21827
21467
  case GGML_TYPE_IQ1_S: result = quantize_iq1_s (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
21828
21468
  case GGML_TYPE_IQ1_M: result = quantize_iq1_m (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
21829
21469
  case GGML_TYPE_IQ4_NL: result = quantize_iq4_nl (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
21830
- #if QK_K == 64
21831
- case GGML_TYPE_IQ4_XS: result = quantize_iq4_nl (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
21832
- #else
21833
21470
  case GGML_TYPE_IQ4_XS: result = quantize_iq4_xs (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
21834
- #endif
21835
21471
  case GGML_TYPE_F16:

21836
21472
  {
21837
21473
  size_t elemsize = sizeof(ggml_fp16_t);
@@ -23108,6 +22744,14 @@ int ggml_cpu_has_avx512_vnni(void) {
23108
22744
  #endif
23109
22745
  }
23110
22746
 
22747
+ int ggml_cpu_has_avx512_bf16(void) {
22748
+ #if defined(__AVX512BF16__)
22749
+ return 1;
22750
+ #else
22751
+ return 0;
22752
+ #endif
22753
+ }
22754
+
23111
22755
  int ggml_cpu_has_fma(void) {
23112
22756
  #if defined(__FMA__)
23113
22757
  return 1;
@@ -23124,6 +22768,16 @@ int ggml_cpu_has_neon(void) {
23124
22768
  #endif
23125
22769
  }
23126
22770
 
22771
+ int ggml_cpu_has_sve(void) {
22772
+ #if defined(__ARM_FEATURE_SVE)
22773
+ // TODO: currently only 256-bit SVE is supported.
22774
+ GGML_ASSERT(svcntb() == QK8_0);
22775
+ return 1;
22776
+ #else
22777
+ return 0;
22778
+ #endif
22779
+ }
22780
+
23127
22781
  int ggml_cpu_has_arm_fma(void) {
23128
22782
  #if defined(__ARM_FEATURE_FMA)
23129
22783
  return 1;
@@ -23212,6 +22866,14 @@ int ggml_cpu_has_sycl(void) {
23212
22866
  #endif
23213
22867
  }
23214
22868
 
22869
+ int ggml_cpu_has_rpc(void) {
22870
+ #if defined(GGML_USE_RPC)
22871
+ return 1;
22872
+ #else
22873
+ return 0;
22874
+ #endif
22875
+ }
22876
+
23215
22877
  int ggml_cpu_has_gpublas(void) {
23216
22878
  return ggml_cpu_has_cuda() || ggml_cpu_has_clblast() || ggml_cpu_has_vulkan() || ggml_cpu_has_kompute() ||
23217
22879
  ggml_cpu_has_sycl();
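Taken together, this release adds three new capability probes: ggml_cpu_has_avx512_bf16(), ggml_cpu_has_sve() and ggml_cpu_has_rpc(). A minimal caller-side sketch, assuming they are declared in ggml.h alongside the rest of the ggml_cpu_has_* family:

    #include <stdio.h>
    #include "ggml.h"   // assumed to declare the ggml_cpu_has_* probes

    int main(void) {
        // each probe returns 1 when the feature was compiled in, 0 otherwise;
        // note that ggml_cpu_has_sve() currently asserts 256-bit SVE vectors
        printf("AVX512-BF16 : %d\n", ggml_cpu_has_avx512_bf16());
        printf("SVE         : %d\n", ggml_cpu_has_sve());
        printf("RPC backend : %d\n", ggml_cpu_has_rpc());
        return 0;
    }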