llama_cpp 0.15.1 → 0.15.3

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry, and is provided for informational purposes only.
@@ -4,7 +4,6 @@
4
4
  #include "ggml-impl.h"
5
5
  #include "ggml-quants.h"
6
6
  #include "ggml.h"
7
- #include "sgemm.h"
8
7
 
9
8
  #if defined(_MSC_VER) || defined(__MINGW32__)
10
9
  #include <malloc.h> // using malloc.h with MSC/MINGW
@@ -37,6 +36,10 @@
37
36
  #undef GGML_USE_LLAMAFILE
38
37
  #endif
39
38
 
39
+ #ifdef GGML_USE_LLAMAFILE
40
+ #include "sgemm.h"
41
+ #endif
42
+
40
43
  #if defined(_MSC_VER)
41
44
  // disable "possible loss of data" to avoid hundreds of casts
42
45
  // we should just be careful :)
@@ -109,6 +112,8 @@ typedef void * thread_ret_t;
109
112
 
110
113
  #endif
111
114
 
115
+ typedef pthread_t ggml_thread_t;
116
+
112
117
  #ifdef GGML_USE_CPU_HBM
113
118
  #include <hbwmalloc.h>
114
119
  #endif
@@ -160,9 +165,6 @@ void ggml_print_backtrace(void) {
160
165
  #define GGML_DEBUG 0
161
166
  #define GGML_GELU_FP16
162
167
  #define GGML_GELU_QUICK_FP16
163
- #define GGML_SILU_FP16
164
- // #define GGML_CROSS_ENTROPY_EXP_FP16
165
- // #define GGML_FLASH_ATTN_EXP_FP16
166
168
 
167
169
  #define GGML_SOFT_MAX_UNROLL 4
168
170
  #define GGML_VEC_DOT_UNROLL 2
@@ -313,12 +315,6 @@ static ggml_fp16_t ggml_table_gelu_f16[1 << 16];
313
315
  // precomputed quick gelu table for f16 (128 KB)
314
316
  static ggml_fp16_t ggml_table_gelu_quick_f16[1 << 16];
315
317
 
316
- // precomputed silu table for f16 (128 KB)
317
- static ggml_fp16_t ggml_table_silu_f16[1 << 16];
318
-
319
- // precomputed exp table for f16 (128 KB)
320
- static ggml_fp16_t ggml_table_exp_f16[1 << 16];
321
-
322
318
  // precomputed f32 table for f16 (256 KB) (ggml-impl.h)
323
319
  float ggml_table_f32_f16[1 << 16];
324
320
 
@@ -410,10 +406,10 @@ void ggml_fp32_to_bf16_row(const float * x, ggml_bf16_t * y, int64_t n) {
410
406
  int i = 0;
411
407
  #if defined(__AVX512BF16__)
412
408
  for (; i + 32 <= n; i += 32) {
413
- _mm512_storeu_ps(
414
- (__m512 *)(y + i),
415
- (__m512)_mm512_cvtne2ps_pbh(_mm512_loadu_ps(x + i + 16),
416
- _mm512_loadu_ps(x + i)));
409
+ _mm512_storeu_si512(
410
+ (__m512i *)(y + i),
411
+ m512i(_mm512_cvtne2ps_pbh(_mm512_loadu_ps(x + i + 16),
412
+ _mm512_loadu_ps(x + i))));
417
413
  }
418
414
  #endif
419
415
  for (; i < n; i++) {
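A scalar sketch of the conversion the AVX512-BF16 loop above vectorizes: keep the upper 16 bits of the f32 bit pattern with round-to-nearest-even. This ignores ggml's exact NaN handling and uses plain uint16_t instead of ggml_bf16_t, so it is illustrative only, not code from the diff.

    #include <stdint.h>
    #include <string.h>

    // round-to-nearest-even fp32 -> bf16, keeping the upper 16 bits
    static inline uint16_t fp32_to_bf16_rne(float f) {
        uint32_t u;
        memcpy(&u, &f, sizeof(u));
        u += 0x7fff + ((u >> 16) & 1); // round to nearest, ties to even
        return (uint16_t)(u >> 16);
    }
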
@@ -875,22 +871,14 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
875
871
  },
876
872
  [GGML_TYPE_IQ4_XS] = {
877
873
  .type_name = "iq4_xs",
878
- #if QK_K == 64
879
- .blck_size = QK4_NL,
880
- #else
881
874
  .blck_size = QK_K,
882
- #endif
883
875
  .type_size = sizeof(block_iq4_xs),
884
876
  .is_quantized = true,
885
877
  .to_float = (ggml_to_float_t) dequantize_row_iq4_xs,
886
878
  .from_float = quantize_row_iq4_xs,
887
879
  .from_float_reference = (ggml_from_float_t)quantize_row_iq4_xs_reference,
888
880
  .vec_dot = ggml_vec_dot_iq4_xs_q8_K,
889
- #if QK_K == 64
890
- .vec_dot_type = GGML_TYPE_Q8_0,
891
- #else
892
881
  .vec_dot_type = GGML_TYPE_Q8_K,
893
- #endif
894
882
  .nrows = 1,
895
883
  },
896
884
  [GGML_TYPE_Q8_K] = {
@@ -1303,6 +1291,8 @@ static inline void __avx_f32cx8_store(ggml_fp16_t *x, __m256 y) {
1303
1291
  #define GGML_F16_VEC_ZERO GGML_F32x4_ZERO
1304
1292
  #define GGML_F16_VEC_SET1 GGML_F32x4_SET1
1305
1293
  #define GGML_F16_VEC_FMA GGML_F32x4_FMA
1294
+ #define GGML_F16_VEC_ADD GGML_F32x4_ADD
1295
+ #define GGML_F16_VEC_MUL GGML_F32x4_MUL
1306
1296
  #define GGML_F16_VEC_REDUCE GGML_F32x4_REDUCE
1307
1297
  // Use vec_xl, not vec_ld, in case the load address is not aligned.
1308
1298
  #define GGML_F16_VEC_LOAD(p, i) (i & 0x1) ? \
@@ -1525,6 +1515,195 @@ static inline void __sse_f16x4_store(ggml_fp16_t *x, __m128 y) {
1525
1515
  #define GGML_F16_VEC_MUL GGML_F32Cx4_MUL
1526
1516
  #define GGML_F16_VEC_REDUCE GGML_F32Cx4_REDUCE
1527
1517
 
1518
+ #elif defined(__loongarch_asx)
1519
+
1520
+ #define GGML_SIMD
1521
+
1522
+ // F32 LASX
1523
+ #define GGML_F32_STEP 32
1524
+ #define GGML_F32_EPR 8
1525
+
1526
+ #define GGML_F32x8 __m256
1527
+ #define GGML_F32x8_ZERO (__m256)__lasx_xvldi(0)
1528
+ #define GGML_F32x8_SET1(x) (__m256)__lasx_xvreplfr2vr_s((x))
1529
+ #define GGML_F32x8_LOAD(x) (__m256)__lasx_xvld((x), 0)
1530
+ #define GGML_F32x8_STORE(x,y) __lasx_xvst((y), (x), 0)
1531
+ #define GGML_F32x8_FMA(a, b, c) __lasx_xvfmadd_s(b, c, a)
1532
+ #define GGML_F32x8_ADD __lasx_xvfadd_s
1533
+ #define GGML_F32x8_MUL __lasx_xvfmul_s
1534
+ #define GGML_F32x8_REDUCE(res, x) \
1535
+ do { \
1536
+ int offset = GGML_F32_ARR >> 1; \
1537
+ for (int i = 0; i < offset; ++i) { \
1538
+ x[i] = __lasx_xvfadd_s(x[i], x[offset+i]); \
1539
+ } \
1540
+ offset >>= 1; \
1541
+ for (int i = 0; i < offset; ++i) { \
1542
+ x[i] = __lasx_xvfadd_s(x[i], x[offset+i]); \
1543
+ } \
1544
+ offset >>= 1; \
1545
+ for (int i = 0; i < offset; ++i) { \
1546
+ x[i] = __lasx_xvfadd_s(x[i], x[offset+i]); \
1547
+ } \
1548
+ float *tmp_p = (float *)&x[0]; \
1549
+ res = tmp_p[0] + tmp_p[1] + tmp_p[2] + tmp_p[3] + tmp_p[4] + tmp_p[5] + tmp_p[6] + tmp_p[7]; \
1550
+ } while (0)
1551
+ // TODO: is this optimal ?
1552
+
1553
+ #define GGML_F32_VEC GGML_F32x8
1554
+ #define GGML_F32_VEC_ZERO GGML_F32x8_ZERO
1555
+ #define GGML_F32_VEC_SET1 GGML_F32x8_SET1
1556
+ #define GGML_F32_VEC_LOAD GGML_F32x8_LOAD
1557
+ #define GGML_F32_VEC_STORE GGML_F32x8_STORE
1558
+ #define GGML_F32_VEC_FMA GGML_F32x8_FMA
1559
+ #define GGML_F32_VEC_ADD GGML_F32x8_ADD
1560
+ #define GGML_F32_VEC_MUL GGML_F32x8_MUL
1561
+ #define GGML_F32_VEC_REDUCE GGML_F32x8_REDUCE
1562
+
1563
+ // F16 LASX
1564
+
1565
+ #define GGML_F16_STEP 32
1566
+ #define GGML_F16_EPR 8
1567
+
1568
+ // F16 arithmetic is not supported by AVX, so we use F32 instead
1569
+
1570
+ #define GGML_F32Cx8 __m256
1571
+ #define GGML_F32Cx8_ZERO (__m256)__lasx_xvldi(0)
1572
+ #define GGML_F32Cx8_SET1(x) (__m256)__lasx_xvreplgr2vr_w((x))
1573
+
1574
+ static inline __m256 __lasx_f32cx8_load(ggml_fp16_t *x) {
1575
+ float tmp[8];
1576
+
1577
+ for (int i = 0; i < 8; i++) {
1578
+ tmp[i] = GGML_FP16_TO_FP32(x[i]);
1579
+ }
1580
+
1581
+ return (__m256)__lasx_xvld(tmp, 0);
1582
+ }
1583
+ static inline void __lasx_f32cx8_store(ggml_fp16_t *x, __m256 y) {
1584
+ float arr[8];
1585
+
1586
+ __lasx_xvst(y, arr, 0);
1587
+
1588
+ for (int i = 0; i < 8; i++)
1589
+ x[i] = GGML_FP32_TO_FP16(arr[i]);
1590
+ }
1591
+ #define GGML_F32Cx8_LOAD(x) __lasx_f32cx8_load(x)
1592
+ #define GGML_F32Cx8_STORE(x, y) __lasx_f32cx8_store(x, y)
1593
+
1594
+ #define GGML_F32Cx8_FMA GGML_F32x8_FMA
1595
+ #define GGML_F32Cx8_ADD __lasx_xvfadd_s
1596
+ #define GGML_F32Cx8_MUL __lasx_xvfmul_s
1597
+ #define GGML_F32Cx8_REDUCE GGML_F32x8_REDUCE
1598
+
1599
+ #define GGML_F16_VEC GGML_F32Cx8
1600
+ #define GGML_F16_VEC_ZERO GGML_F32Cx8_ZERO
1601
+ #define GGML_F16_VEC_SET1 GGML_F32Cx8_SET1
1602
+ #define GGML_F16_VEC_LOAD(p, i) GGML_F32Cx8_LOAD(p)
1603
+ #define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx8_STORE(p, r[i])
1604
+ #define GGML_F16_VEC_FMA GGML_F32Cx8_FMA
1605
+ #define GGML_F16_VEC_ADD GGML_F32Cx8_ADD
1606
+ #define GGML_F16_VEC_MUL GGML_F32Cx8_MUL
1607
+ #define GGML_F16_VEC_REDUCE GGML_F32Cx8_REDUCE
1608
+
1609
+ #elif defined(__loongarch_sx)
1610
+
1611
+ #define GGML_SIMD
1612
+
1613
+ // F32 LSX
1614
+
1615
+ #define GGML_F32_STEP 32
1616
+ #define GGML_F32_EPR 4
1617
+
1618
+ #define GGML_F32x4 __m128
1619
+ #define GGML_F32x4_ZERO __lsx_vldi(0)
1620
+ #define GGML_F32x4_SET1(x) __lsx_vinsgr2vr_w(__lsx_vldi(0),(x), 0)
1621
+ #define GGML_F32x4_LOAD(x) __lsx_vld((x), 0)
1622
+ #define GGML_F32x4_STORE((x),(y)) __lsx_vst((y), (x), 0)
1623
+ #define GGML_F32x4_FMA(a, b, c) __lsx_vfmadd_s(b, c, a)
1624
+ #define GGML_F32x4_ADD __lsx_vfadd_s
1625
+ #define GGML_F32x4_MUL __lsx_vfmul_s
1626
+ #define GGML_F32x4_REDUCE(res, x) \
1627
+ { \
1628
+ int offset = GGML_F32_ARR >> 1; \
1629
+ for (int i = 0; i < offset; ++i) { \
1630
+ x[i] = __lsx_vfadd_s(x[i], x[offset+i]); \
1631
+ } \
1632
+ offset >>= 1; \
1633
+ for (int i = 0; i < offset; ++i) { \
1634
+ x[i] = __lsx_vfadd_s(x[i], x[offset+i]); \
1635
+ } \
1636
+ offset >>= 1; \
1637
+ for (int i = 0; i < offset; ++i) { \
1638
+ x[i] = __lsx_vfadd_s(x[i], x[offset+i]); \
1639
+ } \
1640
+ __m128i tmp = __lsx_vsrli_d((__m128i)x[0], 32); \
1641
+ tmp = (__m128i)__lsx_vfadd_s((__m128)tmp, x[0]); \
1642
+ tmp = __lsx_vpickev_w(__lsx_vldi(0), tmp); \
1643
+ const __m128 t0 = __lsx_vshuf4i_w(tmp, 0x88); \
1644
+ tmp = __lsx_vsrli_d((__m128i)t0, 32); \
1645
+ tmp = (__m128i)__lsx_vfadd_s((__m128)tmp, t0); \
1646
+ tmp = __lsx_vpickev_w(__lsx_vldi(0), tmp); \
1647
+ res = (ggml_float) __lsx_vpickve2gr_w(__lsx_vshuf4i_w(tmp, 0x88), 0); \
1648
+ }
1649
+
1650
+ #define GGML_F32_VEC GGML_F32x4
1651
+ #define GGML_F32_VEC_ZERO GGML_F32x4_ZERO
1652
+ #define GGML_F32_VEC_SET1 GGML_F32x4_SET1
1653
+ #define GGML_F32_VEC_LOAD GGML_F32x4_LOAD
1654
+ #define GGML_F32_VEC_STORE GGML_F32x4_STORE
1655
+ #define GGML_F32_VEC_FMA GGML_F32x4_FMA
1656
+ #define GGML_F32_VEC_ADD GGML_F32x4_ADD
1657
+ #define GGML_F32_VEC_MUL GGML_F32x4_MUL
1658
+ #define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE
1659
+
1660
+ // F16 LSX
1661
+
1662
+ #define GGML_F16_STEP 32
1663
+ #define GGML_F16_EPR 4
1664
+
1665
+ static inline __m128 __lsx_f16x4_load(ggml_fp16_t *x) {
1666
+ float tmp[4];
1667
+
1668
+ tmp[0] = GGML_FP16_TO_FP32(x[0]);
1669
+ tmp[1] = GGML_FP16_TO_FP32(x[1]);
1670
+ tmp[2] = GGML_FP16_TO_FP32(x[2]);
1671
+ tmp[3] = GGML_FP16_TO_FP32(x[3]);
1672
+
1673
+ return __lsx_vld(tmp, 0);
1674
+ }
1675
+
1676
+ static inline void __lsx_f16x4_store(ggml_fp16_t *x, __m128 y) {
1677
+ float arr[4];
1678
+
1679
+ __lsx_vst(y, arr, 0);
1680
+
1681
+ x[0] = GGML_FP32_TO_FP16(arr[0]);
1682
+ x[1] = GGML_FP32_TO_FP16(arr[1]);
1683
+ x[2] = GGML_FP32_TO_FP16(arr[2]);
1684
+ x[3] = GGML_FP32_TO_FP16(arr[3]);
1685
+ }
1686
+
1687
+ #define GGML_F32Cx4 __m128
1688
+ #define GGML_F32Cx4_ZERO __lsx_vldi(0)
1689
+ #define GGML_F32Cx4_SET1(x) __lsx_vinsgr2vr_w(__lsx_vldi(0),(x), 0)
1690
+ #define GGML_F32Cx4_LOAD(x) __lsx_f16x4_load(x)
1691
+ #define GGML_F32Cx4_STORE(x, y) __lsx_f16x4_store(x, y)
1692
+ #define GGML_F32Cx4_FMA GGML_F32x4_FMA
1693
+ #define GGML_F32Cx4_ADD __lsx_vfadd_s
1694
+ #define GGML_F32Cx4_MUL __lsx_vfmul_s
1695
+ #define GGML_F32Cx4_REDUCE GGML_F32x4_REDUCE
1696
+
1697
+ #define GGML_F16_VEC GGML_F32Cx4
1698
+ #define GGML_F16_VEC_ZERO GGML_F32Cx4_ZERO
1699
+ #define GGML_F16_VEC_SET1 GGML_F32Cx4_SET1
1700
+ #define GGML_F16_VEC_LOAD(p, i) GGML_F32Cx4_LOAD(p)
1701
+ #define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx4_STORE(p, r[i])
1702
+ #define GGML_F16_VEC_FMA GGML_F32Cx4_FMA
1703
+ #define GGML_F16_VEC_ADD GGML_F32Cx4_ADD
1704
+ #define GGML_F16_VEC_MUL GGML_F32Cx4_MUL
1705
+ #define GGML_F16_VEC_REDUCE GGML_F32Cx4_REDUCE
1706
+
1528
1707
  #endif
1529
1708
 
1530
1709
  // GGML_F32_ARR / GGML_F16_ARR
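For reference, what the GGML_F32x8_REDUCE / GGML_F32x4_REDUCE macros in the LoongArch hunk above compute, written as plain C (equal up to floating-point summation order); "acc", "n_vec" and "lanes" are illustrative names, not identifiers used in this file.

    // sum every lane of every accumulator vector into one float
    static float reduce_acc_ref(const float * acc, int n_vec, int lanes) {
        float res = 0.0f;
        for (int v = 0; v < n_vec; ++v) {
            for (int l = 0; l < lanes; ++l) {
                res += acc[v*lanes + l];
            }
        }
        return res;
    }
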
@@ -1534,6 +1713,59 @@ static inline void __sse_f16x4_store(ggml_fp16_t *x, __m128 y) {
1534
1713
  #define GGML_F16_ARR (GGML_F16_STEP/GGML_F16_EPR)
1535
1714
  #endif
1536
1715
 
1716
+ //
1717
+ // ggml context
1718
+ //
1719
+
1720
+ struct ggml_context {
1721
+ size_t mem_size;
1722
+ void* mem_buffer;
1723
+ bool mem_buffer_owned;
1724
+ bool no_alloc;
1725
+ bool no_alloc_save; // this is used to save the no_alloc state when using scratch buffers
1726
+
1727
+ int n_objects;
1728
+
1729
+ struct ggml_object* objects_begin;
1730
+ struct ggml_object* objects_end;
1731
+
1732
+ struct ggml_scratch scratch;
1733
+ struct ggml_scratch scratch_save;
1734
+ };
1735
+
1736
+ struct ggml_context_container {
1737
+ bool used;
1738
+
1739
+ struct ggml_context context;
1740
+ };
1741
+
1742
+ struct ggml_compute_state_shared {
1743
+ const struct ggml_cgraph* cgraph;
1744
+ const struct ggml_cplan* cplan;
1745
+
1746
+ int64_t perf_node_start_cycles;
1747
+ int64_t perf_node_start_time_us;
1748
+
1749
+ const int n_threads;
1750
+
1751
+ // synchronization primitives
1752
+ atomic_int n_active; // num active threads
1753
+ atomic_int node_n; // active graph node
1754
+ atomic_int node_task; // active graph node task phase
1755
+
1756
+ ggml_abort_callback abort_callback; // abort ggml_graph_compute when true
1757
+ void* abort_callback_data;
1758
+
1759
+ atomic_int current_chunk; // currently processing chunk during Mat_Mul, shared between all the threads.
1760
+ };
1761
+
1762
+ struct ggml_compute_state {
1763
+ ggml_thread_t thrd;
1764
+ int ith;
1765
+ struct ggml_compute_state_shared* shared;
1766
+ enum ggml_status ec;
1767
+ };
1768
+
1537
1769
  //
1538
1770
  // fundamental operations
1539
1771
  //
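The new atomic current_chunk field in ggml_compute_state_shared enables dynamic chunk hand-out during matrix multiplication. A minimal sketch of that pattern is below; "nchunk" and "process_chunk" are hypothetical names and this is not the actual scheduler in this file.

    #include <stdatomic.h>

    // each worker atomically claims the next unprocessed chunk until none remain
    static void worker_loop(struct ggml_compute_state_shared * shared,
                            int nchunk,
                            void (*process_chunk)(int)) {
        for (;;) {
            const int chunk = atomic_fetch_add(&shared->current_chunk, 1);
            if (chunk >= nchunk) {
                break;
            }
            process_chunk(chunk);
        }
    }
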
@@ -1615,10 +1847,10 @@ static void ggml_vec_dot_bf16(int n, float * restrict s, size_t bs, ggml_bf16_t
1615
1847
  __m512 c1 = _mm512_setzero_ps();
1616
1848
  __m512 c2 = _mm512_setzero_ps();
1617
1849
  for (; i + 64 <= n; i += 64) {
1618
- c1 = _mm512_dpbf16_ps(c1, (__m512bh)_mm512_loadu_ps((const float *)(x + i)),
1619
- (__m512bh)_mm512_loadu_ps((const float *)(y + i)));
1620
- c2 = _mm512_dpbf16_ps(c2, (__m512bh)_mm512_loadu_ps((const float *)(x + i + 32)),
1621
- (__m512bh)_mm512_loadu_ps((const float *)(y + i + 32)));
1850
+ c1 = _mm512_dpbf16_ps(c1, m512bh(_mm512_loadu_si512((x + i))),
1851
+ m512bh(_mm512_loadu_si512((y + i))));
1852
+ c2 = _mm512_dpbf16_ps(c2, m512bh(_mm512_loadu_si512((x + i + 32))),
1853
+ m512bh(_mm512_loadu_si512((y + i + 32))));
1622
1854
  }
1623
1855
  sumf += (ggml_float)_mm512_reduce_add_ps(c1);
1624
1856
  sumf += (ggml_float)_mm512_reduce_add_ps(c2);
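A scalar reference for the AVX512-BF16 dot product above: a bf16 value is the upper half of an f32 bit pattern, so widening is a 16-bit shift. uint16_t stands in for ggml_bf16_t here; the snippet is illustrative only.

    #include <stdint.h>
    #include <string.h>

    static float bf16_dot_ref(const uint16_t * x, const uint16_t * y, int n) {
        double sum = 0.0;
        for (int i = 0; i < n; ++i) {
            float xf, yf;
            const uint32_t xu = (uint32_t)x[i] << 16;
            const uint32_t yu = (uint32_t)y[i] << 16;
            memcpy(&xf, &xu, sizeof(xf));
            memcpy(&yf, &yu, sizeof(yf));
            sum += (double)xf * (double)yf;
        }
        return (float)sum;
    }
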
@@ -1949,6 +2181,7 @@ inline static void ggml_vec_tanh_f32 (const int n, float * y, const float * x) {
1949
2181
  inline static void ggml_vec_elu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : expf(x[i])-1; }
1950
2182
  inline static void ggml_vec_relu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : 0.f; }
1951
2183
  inline static void ggml_vec_leaky_relu_f32 (const int n, float * y, const float * x, const float ns) { for (int i = 0; i < n; ++i) y[i] = ((x[i] > 0.f) ? x[i] : 0.f) + ns * ((x[i] < 0.0f) ? x[i] : 0.f); }
2184
+ inline static void ggml_vec_sigmoid_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = 1.f / (1.f + expf(-x[i])); }
1952
2185
  // TODO: optimize performance
1953
2186
  inline static void ggml_vec_hardswish_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i] * fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f)); }
1954
2187
  inline static void ggml_vec_hardsigmoid_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f)); }
@@ -2024,52 +2257,291 @@ inline static float ggml_silu_f32(float x) {
2024
2257
  return x/(1.0f + expf(-x));
2025
2258
  }
2026
2259
 
2027
- //inline static void ggml_vec_silu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
2028
- // const uint16_t * i16 = (const uint16_t *) x;
2029
- // for (int i = 0; i < n; ++i) {
2030
- // y[i] = ggml_table_silu_f16[i16[i]];
2031
- // }
2032
- //}
2260
+ #if defined(__ARM_NEON) && defined(__aarch64__)
2261
+
2262
+ // adapted from arm limited optimized routine
2263
+ // the maximum error is 1.45358 plus 0.5 ulps
2264
+ // numbers above 88.38 will flush to infinity
2265
+ // numbers beneath -103.97 will flush to zero
2266
+ inline static float32x4_t ggml_v_expf(float32x4_t x) {
2267
+ const float32x4_t r = vdupq_n_f32(0x1.8p23f);
2268
+ const float32x4_t z = vfmaq_f32(r, x, vdupq_n_f32(0x1.715476p+0f));
2269
+ const float32x4_t n = vsubq_f32(z, r);
2270
+ const float32x4_t b = vfmsq_f32(vfmsq_f32(x, n, vdupq_n_f32(0x1.62e4p-1f)), n,
2271
+ vdupq_n_f32(0x1.7f7d1cp-20f));
2272
+ const uint32x4_t e = vshlq_n_u32(vreinterpretq_u32_f32(z), 23);
2273
+ const float32x4_t k = vreinterpretq_f32_u32(vaddq_u32(e, vreinterpretq_u32_f32(vdupq_n_f32(1))));
2274
+ const uint32x4_t c = vcagtq_f32(n, vdupq_n_f32(126));
2275
+ const float32x4_t u = vmulq_f32(b, b);
2276
+ const float32x4_t j = vfmaq_f32(
2277
+ vmulq_f32(vdupq_n_f32(0x1.ffffecp-1f), b),
2278
+ vfmaq_f32(vfmaq_f32(vdupq_n_f32(0x1.fffdb6p-2f), vdupq_n_f32(0x1.555e66p-3f), b),
2279
+ vfmaq_f32(vdupq_n_f32(0x1.573e2ep-5f), vdupq_n_f32(0x1.0e4020p-7f), b), u), u);
2280
+ if (!vpaddd_u64(vreinterpretq_u64_u32(c)))
2281
+ return vfmaq_f32(k, j, k);
2282
+ const uint32x4_t d = vandq_u32(vclezq_f32(n), vdupq_n_u32(0x82000000));
2283
+ const float32x4_t s1 = vreinterpretq_f32_u32(vaddq_u32(d, vdupq_n_u32(0x7f000000)));
2284
+ const float32x4_t s2 = vreinterpretq_f32_u32(vsubq_u32(e, d));
2285
+ return vbslq_f32(vcagtq_f32(n, vdupq_n_f32(192)), vmulq_f32(s1, s1),
2286
+ vbslq_f32(c, vmulq_f32(vfmaq_f32(s2, s2, j), s1), vfmaq_f32(k, k, j)));
2287
+ }
2288
+
2289
+ // computes silu x/(1+exp(-x)) in single precision vector
2290
+ inline static float32x4_t ggml_v_silu(float32x4_t x) {
2291
+ const float32x4_t one = vdupq_n_f32(1.0f);
2292
+ const float32x4_t zero = vdupq_n_f32(0.0f);
2293
+ const float32x4_t neg_x = vsubq_f32(zero, x);
2294
+ const float32x4_t exp_neg_x = ggml_v_expf(neg_x);
2295
+ const float32x4_t one_plus_exp_neg_x = vaddq_f32(one, exp_neg_x);
2296
+ return vdivq_f32(x, one_plus_exp_neg_x);
2297
+ }
2298
+
2299
+ #elif defined(__AVX512F__) && defined(__AVX512DQ__)
2300
+
2301
+ // adapted from arm limited optimized routine
2302
+ // the maximum error is 1.45358 plus 0.5 ulps
2303
+ // numbers above 88.38 will flush to infinity
2304
+ // numbers beneath -103.97 will flush to zero
2305
+ inline static __m512 ggml_v_expf(__m512 x) {
2306
+ const __m512 r = _mm512_set1_ps(0x1.8p23f);
2307
+ const __m512 z = _mm512_fmadd_ps(x, _mm512_set1_ps(0x1.715476p+0f), r);
2308
+ const __m512 n = _mm512_sub_ps(z, r);
2309
+ const __m512 b = _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.7f7d1cp-20f),
2310
+ _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.62e4p-1f), x));
2311
+ const __m512i e = _mm512_slli_epi32(_mm512_castps_si512(z), 23);
2312
+ const __m512 k = _mm512_castsi512_ps(_mm512_add_epi32(e, _mm512_castps_si512(_mm512_set1_ps(1))));
2313
+ const __mmask16 c = _mm512_cmp_ps_mask(_mm512_abs_ps(n), _mm512_set1_ps(126), _CMP_GT_OQ);
2314
+ const __m512 u = _mm512_mul_ps(b, b);
2315
+ const __m512 j = _mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_set1_ps(0x1.0e4020p-7f), b,
2316
+ _mm512_set1_ps(0x1.573e2ep-5f)), u,
2317
+ _mm512_fmadd_ps(_mm512_set1_ps(0x1.555e66p-3f), b,
2318
+ _mm512_set1_ps(0x1.fffdb6p-2f))),
2319
+ u, _mm512_mul_ps(_mm512_set1_ps(0x1.ffffecp-1f), b));
2320
+ if (_mm512_kortestz(c, c))
2321
+ return _mm512_fmadd_ps(j, k, k);
2322
+ const __m512i g = _mm512_and_si512(
2323
+ _mm512_movm_epi32(_mm512_cmp_ps_mask(n, _mm512_setzero_ps(), _CMP_LE_OQ)),
2324
+ _mm512_set1_epi32(0x82000000u));
2325
+ const __m512 s1 =
2326
+ _mm512_castsi512_ps(_mm512_add_epi32(g, _mm512_set1_epi32(0x7f000000u)));
2327
+ const __m512 s2 = _mm512_castsi512_ps(_mm512_sub_epi32(e, g));
2328
+ const __mmask16 d =
2329
+ _mm512_cmp_ps_mask(_mm512_abs_ps(n), _mm512_set1_ps(192), _CMP_GT_OQ);
2330
+ return _mm512_mask_blend_ps(
2331
+ d, _mm512_mask_blend_ps(
2332
+ c, _mm512_fmadd_ps(k, j, k),
2333
+ _mm512_mul_ps(_mm512_fmadd_ps(s2, j, s2), s1)),
2334
+ _mm512_mul_ps(s1, s1));
2335
+ }
2336
+
2337
+ // computes silu x/(1+exp(-x)) in single precision vector
2338
+ inline static __m512 ggml_v_silu(__m512 x) {
2339
+ const __m512 one = _mm512_set1_ps(1);
2340
+ const __m512 zero = _mm512_setzero_ps();
2341
+ const __m512 neg_x = _mm512_sub_ps(zero, x);
2342
+ const __m512 exp_neg_x = ggml_v_expf(neg_x);
2343
+ const __m512 one_plus_exp_neg_x = _mm512_add_ps(one, exp_neg_x);
2344
+ return _mm512_div_ps(x, one_plus_exp_neg_x);
2345
+ }
2346
+
2347
+ #elif defined(__AVX2__) && defined(__FMA__)
2348
+
2349
+ // adapted from arm limited optimized routine
2350
+ // the maximum error is 1.45358 plus 0.5 ulps
2351
+ // numbers above 88.38 will flush to infinity
2352
+ // numbers beneath -103.97 will flush to zero
2353
+ inline static __m256 ggml_v_expf(__m256 x) {
2354
+ const __m256 r = _mm256_set1_ps(0x1.8p23f);
2355
+ const __m256 z = _mm256_fmadd_ps(x, _mm256_set1_ps(0x1.715476p+0f), r);
2356
+ const __m256 n = _mm256_sub_ps(z, r);
2357
+ const __m256 b = _mm256_fnmadd_ps(n, _mm256_set1_ps(0x1.7f7d1cp-20f),
2358
+ _mm256_fnmadd_ps(n, _mm256_set1_ps(0x1.62e4p-1f), x));
2359
+ const __m256i e = _mm256_slli_epi32(_mm256_castps_si256(z), 23);
2360
+ const __m256 k = _mm256_castsi256_ps(
2361
+ _mm256_add_epi32(e, _mm256_castps_si256(_mm256_set1_ps(1))));
2362
+ const __m256i c = _mm256_castps_si256(
2363
+ _mm256_cmp_ps(_mm256_andnot_ps(_mm256_set1_ps(-0.f), n),
2364
+ _mm256_set1_ps(126), _CMP_GT_OQ));
2365
+ const __m256 u = _mm256_mul_ps(b, b);
2366
+ const __m256 j = _mm256_fmadd_ps(_mm256_fmadd_ps(_mm256_fmadd_ps(_mm256_set1_ps(0x1.0e4020p-7f), b,
2367
+ _mm256_set1_ps(0x1.573e2ep-5f)), u,
2368
+ _mm256_fmadd_ps(_mm256_set1_ps(0x1.555e66p-3f), b,
2369
+ _mm256_set1_ps(0x1.fffdb6p-2f))),
2370
+ u, _mm256_mul_ps(_mm256_set1_ps(0x1.ffffecp-1f), b));
2371
+ if (!_mm256_movemask_ps(_mm256_castsi256_ps(c)))
2372
+ return _mm256_fmadd_ps(j, k, k);
2373
+ const __m256i g = _mm256_and_si256(
2374
+ _mm256_castps_si256(_mm256_cmp_ps(n, _mm256_setzero_ps(), _CMP_LE_OQ)),
2375
+ _mm256_set1_epi32(0x82000000u));
2376
+ const __m256 s1 =
2377
+ _mm256_castsi256_ps(_mm256_add_epi32(g, _mm256_set1_epi32(0x7f000000u)));
2378
+ const __m256 s2 = _mm256_castsi256_ps(_mm256_sub_epi32(e, g));
2379
+ const __m256i d = _mm256_castps_si256(
2380
+ _mm256_cmp_ps(_mm256_andnot_ps(_mm256_set1_ps(-0.f), n),
2381
+ _mm256_set1_ps(192), _CMP_GT_OQ));
2382
+ return _mm256_or_ps(
2383
+ _mm256_and_ps(_mm256_castsi256_ps(d), _mm256_mul_ps(s1, s1)),
2384
+ _mm256_andnot_ps(
2385
+ _mm256_castsi256_ps(d),
2386
+ _mm256_or_ps(
2387
+ _mm256_and_ps(_mm256_castsi256_ps(c),
2388
+ _mm256_mul_ps(_mm256_fmadd_ps(s2, j, s2), s1)),
2389
+ _mm256_andnot_ps(_mm256_castsi256_ps(c), _mm256_fmadd_ps(k, j, k)))));
2390
+ }
2391
+
2392
+ // computes silu x/(1+exp(-x)) in single precision vector
2393
+ inline static __m256 ggml_v_silu(__m256 x) {
2394
+ const __m256 one = _mm256_set1_ps(1);
2395
+ const __m256 zero = _mm256_setzero_ps();
2396
+ const __m256 neg_x = _mm256_sub_ps(zero, x);
2397
+ const __m256 exp_neg_x = ggml_v_expf(neg_x);
2398
+ const __m256 one_plus_exp_neg_x = _mm256_add_ps(one, exp_neg_x);
2399
+ return _mm256_div_ps(x, one_plus_exp_neg_x);
2400
+ }
2401
+
2402
+ #elif defined(__SSE2__) // __AVX2__ / __ARM_NEON
2033
2403
 
2034
- #ifdef GGML_SILU_FP16
2035
- inline static void ggml_vec_silu_f32(const int n, float * y, const float * x) {
2036
- uint16_t t;
2037
- for (int i = 0; i < n; ++i) {
2038
- ggml_fp16_t fp16 = GGML_FP32_TO_FP16(x[i]);
2039
- memcpy(&t, &fp16, sizeof(uint16_t));
2040
- y[i] = GGML_FP16_TO_FP32(ggml_table_silu_f16[t]);
2041
- }
2042
- }
2404
+ #if defined(__FMA__)
2405
+ #define MADD128(x, y, z) _mm_fmadd_ps(x, y, z)
2406
+ #define NMADD128(x, y, z) _mm_fnmadd_ps(x, y, z)
2043
2407
  #else
2044
- inline static void ggml_vec_silu_f32(const int n, float * y, const float * x) {
2045
- for (int i = 0; i < n; ++i) {
2408
+ #define MADD128(x, y, z) _mm_add_ps(_mm_mul_ps(x, y), z)
2409
+ #define NMADD128(x, y, z) _mm_sub_ps(z, _mm_mul_ps(x, y))
2410
+ #endif
2411
+
2412
+ // adapted from arm limited optimized routine
2413
+ // the maximum error is 1.45358 plus 0.5 ulps
2414
+ // numbers above 88.38 will flush to infinity
2415
+ // numbers beneath -103.97 will flush to zero
2416
+ inline static __m128 ggml_v_expf(__m128 x) {
2417
+ const __m128 r = _mm_set1_ps(0x1.8p23f);
2418
+ const __m128 z = MADD128(x, _mm_set1_ps(0x1.715476p+0f), r);
2419
+ const __m128 n = _mm_sub_ps(z, r);
2420
+ const __m128 b =
2421
+ NMADD128(n, _mm_set1_ps(0x1.7f7d1cp-20f), NMADD128(n, _mm_set1_ps(0x1.62e4p-1f), x));
2422
+ const __m128i e = _mm_slli_epi32(_mm_castps_si128(z), 23);
2423
+ const __m128 k = _mm_castsi128_ps(_mm_add_epi32(e, _mm_castps_si128(_mm_set1_ps(1))));
2424
+ const __m128i c =
2425
+ _mm_castps_si128(_mm_cmpgt_ps(_mm_andnot_ps(_mm_set1_ps(-0.f), n), _mm_set1_ps(126)));
2426
+ const __m128 u = _mm_mul_ps(b, b);
2427
+ const __m128 j =
2428
+ MADD128(MADD128(MADD128(_mm_set1_ps(0x1.0e4020p-7f), b, _mm_set1_ps(0x1.573e2ep-5f)), u,
2429
+ MADD128(_mm_set1_ps(0x1.555e66p-3f), b, _mm_set1_ps(0x1.fffdb6p-2f))),
2430
+ u, _mm_mul_ps(_mm_set1_ps(0x1.ffffecp-1f), b));
2431
+ if (!_mm_movemask_epi8(c))
2432
+ return MADD128(j, k, k);
2433
+ const __m128i g = _mm_and_si128(_mm_castps_si128(_mm_cmple_ps(n, _mm_setzero_ps())),
2434
+ _mm_set1_epi32(0x82000000u));
2435
+ const __m128 s1 = _mm_castsi128_ps(_mm_add_epi32(g, _mm_set1_epi32(0x7f000000u)));
2436
+ const __m128 s2 = _mm_castsi128_ps(_mm_sub_epi32(e, g));
2437
+ const __m128i d =
2438
+ _mm_castps_si128(_mm_cmpgt_ps(_mm_andnot_ps(_mm_set1_ps(-0.f), n), _mm_set1_ps(192)));
2439
+ return _mm_or_ps(
2440
+ _mm_and_ps(_mm_castsi128_ps(d), _mm_mul_ps(s1, s1)),
2441
+ _mm_andnot_ps(_mm_castsi128_ps(d),
2442
+ _mm_or_ps(_mm_and_ps(_mm_castsi128_ps(c), _mm_mul_ps(MADD128(s2, j, s2), s1)),
2443
+ _mm_andnot_ps(_mm_castsi128_ps(c), MADD128(k, j, k)))));
2444
+ }
2445
+
2446
+ // computes silu x/(1+exp(-x)) in single precision vector
2447
+ inline static __m128 ggml_v_silu(__m128 x) {
2448
+ const __m128 one = _mm_set1_ps(1);
2449
+ const __m128 zero = _mm_setzero_ps();
2450
+ const __m128 neg_x = _mm_sub_ps(zero, x);
2451
+ const __m128 exp_neg_x = ggml_v_expf(neg_x);
2452
+ const __m128 one_plus_exp_neg_x = _mm_add_ps(one, exp_neg_x);
2453
+ return _mm_div_ps(x, one_plus_exp_neg_x);
2454
+ }
2455
+
2456
+ #endif // __ARM_NEON / __AVX2__ / __SSE2__
2457
+
2458
+ static void ggml_vec_silu_f32(const int n, float * y, const float * x) {
2459
+ int i = 0;
2460
+ #if defined(__AVX512F__) && defined(__AVX512DQ__)
2461
+ for (; i + 15 < n; i += 16) {
2462
+ _mm512_storeu_ps(y + i, ggml_v_silu(_mm512_loadu_ps(x + i)));
2463
+ }
2464
+ #elif defined(__AVX2__) && defined(__FMA__)
2465
+ for (; i + 7 < n; i += 8) {
2466
+ _mm256_storeu_ps(y + i, ggml_v_silu(_mm256_loadu_ps(x + i)));
2467
+ }
2468
+ #elif defined(__SSE2__)
2469
+ for (; i + 3 < n; i += 4) {
2470
+ _mm_storeu_ps(y + i, ggml_v_silu(_mm_loadu_ps(x + i)));
2471
+ }
2472
+ #elif defined(__ARM_NEON) && defined(__aarch64__)
2473
+ for (; i + 3 < n; i += 4) {
2474
+ vst1q_f32(y + i, ggml_v_silu(vld1q_f32(x + i)));
2475
+ }
2476
+ #endif
2477
+ for (; i < n; ++i) {
2046
2478
  y[i] = ggml_silu_f32(x[i]);
2047
2479
  }
2048
2480
  }
2481
+
2482
+ static ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x, float max) {
2483
+ int i = 0;
2484
+ ggml_float sum = 0;
2485
+ #if defined(__AVX512F__) && defined(__AVX512DQ__)
2486
+ for (; i + 15 < n; i += 16) {
2487
+ __m512 val = ggml_v_expf(_mm512_sub_ps(_mm512_loadu_ps(x + i),
2488
+ _mm512_set1_ps(max)));
2489
+ _mm512_storeu_ps(y + i, val);
2490
+ sum += (ggml_float)_mm512_reduce_add_ps(val);
2491
+ }
2492
+ #elif defined(__AVX2__) && defined(__FMA__)
2493
+ for (; i + 7 < n; i += 8) {
2494
+ __m256 val = ggml_v_expf(_mm256_sub_ps(_mm256_loadu_ps(x + i),
2495
+ _mm256_set1_ps(max)));
2496
+ _mm256_storeu_ps(y + i, val);
2497
+ __m128 val2 = _mm_add_ps(_mm256_extractf128_ps(val, 1),
2498
+ _mm256_castps256_ps128(val));
2499
+ val2 = _mm_add_ps(val2, _mm_movehl_ps(val2, val2));
2500
+ val2 = _mm_add_ss(val2, _mm_movehdup_ps(val2));
2501
+ sum += (ggml_float)_mm_cvtss_f32(val2);
2502
+ }
2503
+ #elif defined(__SSE2__)
2504
+ for (; i + 3 < n; i += 4) {
2505
+ __m128 val = ggml_v_expf(_mm_sub_ps(_mm_loadu_ps(x + i),
2506
+ _mm_set1_ps(max)));
2507
+ _mm_storeu_ps(y + i, val);
2508
+ #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
2509
+ val = _mm_add_ps(val, _mm_movehl_ps(val, val));
2510
+ val = _mm_add_ss(val, _mm_movehdup_ps(val));
2511
+ #else
2512
+ __m128 tmp = _mm_shuffle_ps(val, val, _MM_SHUFFLE(2, 3, 0, 1));
2513
+ val = _mm_add_ps(val, tmp);
2514
+ tmp = _mm_movehl_ps(tmp, val);
2515
+ val = _mm_add_ss(val, tmp);
2516
+ #endif
2517
+ sum += (ggml_float)_mm_cvtss_f32(val);
2518
+ }
2519
+ #elif defined(__ARM_NEON) && defined(__aarch64__)
2520
+ for (; i + 3 < n; i += 4) {
2521
+ float32x4_t val = ggml_v_expf(vsubq_f32(vld1q_f32(x + i),
2522
+ vdupq_n_f32(max)));
2523
+ vst1q_f32(y + i, val);
2524
+ sum += (ggml_float)vaddvq_f32(val);
2525
+ }
2049
2526
  #endif
2527
+ for (; i < n; ++i) {
2528
+ float val = expf(x[i] - max);
2529
+ sum += (ggml_float)val;
2530
+ y[i] = val;
2531
+ }
2532
+ return sum;
2533
+ }
2050
2534
 
2051
2535
  inline static float ggml_silu_backward_f32(float x, float dy) {
2052
2536
  const float s = 1.0f/(1.0f + expf(-x));
2053
2537
  return dy*s*(1.0f + x*(1.0f - s));
2054
2538
  }
2055
2539
 
2056
- #ifdef GGML_SILU_FP16
2057
- inline static void ggml_vec_silu_backward_f32(const int n, float * dx, const float * x, const float * dy) {
2058
- for (int i = 0; i < n; ++i) {
2059
- // we did not use x[i] to compute forward silu but its f16 equivalent
2060
- // take derivative at f16 of x[i]:
2061
- ggml_fp16_t fp16 = GGML_FP32_TO_FP16(x[i]);
2062
- float usedx = GGML_FP16_TO_FP32(fp16);
2063
- dx[i] = ggml_silu_backward_f32(usedx, dy[i]);
2064
- }
2065
- }
2066
- #else
2067
2540
  inline static void ggml_vec_silu_backward_f32(const int n, float * dx, const float * x, const float * dy) {
2068
2541
  for (int i = 0; i < n; ++i) {
2069
2542
  dx[i] = ggml_silu_backward_f32(x[i], dy[i]);
2070
2543
  }
2071
2544
  }
2072
- #endif
2073
2545
 
2074
2546
  inline static void ggml_vec_sum_f32(const int n, float * s, const float * x) {
2075
2547
  #ifndef GGML_USE_ACCELERATE
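The new ggml_vec_soft_max_f32 helper added above writes exp(x[i] - max) into y and returns the running sum, leaving normalization to the caller. A sketch of how it composes into a full soft-max pass; "softmax_ref" is a hypothetical wrapper, not part of the diff.

    static void softmax_ref(const int n, float * y, const float * x) {
        float max = x[0];
        for (int i = 1; i < n; ++i) {
            if (x[i] > max) max = x[i];
        }
        const ggml_float sum = ggml_vec_soft_max_f32(n, y, x, max);
        ggml_vec_scale_f32(n, y, 1.0f/(float)sum); // y[i] = exp(x[i]-max)/sum
    }
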
@@ -2185,7 +2657,6 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
2185
2657
  "SOFT_MAX_BACK",
2186
2658
  "ROPE",
2187
2659
  "ROPE_BACK",
2188
- "ALIBI",
2189
2660
  "CLAMP",
2190
2661
  "CONV_TRANSPOSE_1D",
2191
2662
  "IM2COL",
@@ -2199,9 +2670,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
2199
2670
  "ARGSORT",
2200
2671
  "LEAKY_RELU",
2201
2672
 
2202
- "FLASH_ATTN",
2203
2673
  "FLASH_ATTN_EXT",
2204
- "FLASH_FF",
2205
2674
  "FLASH_ATTN_BACK",
2206
2675
  "SSM_CONV",
2207
2676
  "SSM_SCAN",
@@ -2227,7 +2696,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
2227
2696
  "CROSS_ENTROPY_LOSS_BACK",
2228
2697
  };
2229
2698
 
2230
- static_assert(GGML_OP_COUNT == 77, "GGML_OP_COUNT != 77");
2699
+ static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74");
2231
2700
 
2232
2701
  static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
2233
2702
  "none",
@@ -2276,7 +2745,6 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
2276
2745
  "soft_max_back(x)",
2277
2746
  "rope(x)",
2278
2747
  "rope_back(x)",
2279
- "alibi(x)",
2280
2748
  "clamp(x)",
2281
2749
  "conv_transpose_1d(x)",
2282
2750
  "im2col(x)",
@@ -2290,9 +2758,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
2290
2758
  "argsort(x)",
2291
2759
  "leaky_relu(x)",
2292
2760
 
2293
- "flash_attn(x)",
2294
2761
  "flash_attn_ext(x)",
2295
- "flash_ff(x)",
2296
2762
  "flash_attn_back(x)",
2297
2763
  "ssm_conv(x)",
2298
2764
  "ssm_scan(x)",
@@ -2318,7 +2784,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
2318
2784
  "cross_entropy_loss_back(x,y)",
2319
2785
  };
2320
2786
 
2321
- static_assert(GGML_OP_COUNT == 77, "GGML_OP_COUNT != 77");
2787
+ static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74");
2322
2788
 
2323
2789
  static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
2324
2790
 
@@ -2331,6 +2797,7 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = {
2331
2797
  "TANH",
2332
2798
  "ELU",
2333
2799
  "RELU",
2800
+ "SIGMOID",
2334
2801
  "GELU",
2335
2802
  "GELU_QUICK",
2336
2803
  "SILU",
@@ -2338,7 +2805,7 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = {
2338
2805
  "HARDSIGMOID",
2339
2806
  };
2340
2807
 
2341
- static_assert(GGML_UNARY_OP_COUNT == 12, "GGML_UNARY_OP_COUNT != 12");
2808
+ static_assert(GGML_UNARY_OP_COUNT == 13, "GGML_UNARY_OP_COUNT != 13");
2342
2809
 
2343
2810
 
2344
2811
  static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
@@ -2380,32 +2847,6 @@ static void ggml_setup_op_has_task_pass(void) {
2380
2847
  }
2381
2848
  }
2382
2849
 
2383
- //
2384
- // ggml context
2385
- //
2386
-
2387
- struct ggml_context {
2388
- size_t mem_size;
2389
- void * mem_buffer;
2390
- bool mem_buffer_owned;
2391
- bool no_alloc;
2392
- bool no_alloc_save; // this is used to save the no_alloc state when using scratch buffers
2393
-
2394
- int n_objects;
2395
-
2396
- struct ggml_object * objects_begin;
2397
- struct ggml_object * objects_end;
2398
-
2399
- struct ggml_scratch scratch;
2400
- struct ggml_scratch scratch_save;
2401
- };
2402
-
2403
- struct ggml_context_container {
2404
- bool used;
2405
-
2406
- struct ggml_context context;
2407
- };
2408
-
2409
2850
  //
2410
2851
  // NUMA support
2411
2852
  //
@@ -2819,8 +3260,18 @@ bool ggml_are_same_shape(const struct ggml_tensor * t0, const struct ggml_tensor
2819
3260
  (t0->ne[3] == t1->ne[3] );
2820
3261
  }
2821
3262
 
2822
- // check if t1 can be represented as a repeatition of t0
2823
- static inline bool ggml_can_repeat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
3263
+ bool ggml_are_same_stride(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
3264
+ static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
3265
+
3266
+ return
3267
+ (t0->nb[0] == t1->nb[0] ) &&
3268
+ (t0->nb[1] == t1->nb[1] ) &&
3269
+ (t0->nb[2] == t1->nb[2] ) &&
3270
+ (t0->nb[3] == t1->nb[3] );
3271
+ }
3272
+
3273
+ // check if t1 can be represented as a repeatition of t0
3274
+ static inline bool ggml_can_repeat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
2824
3275
  static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
2825
3276
 
2826
3277
  return ggml_is_empty(t0) ? ggml_is_empty(t1) :
@@ -2878,8 +3329,6 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
2878
3329
  float f = ggml_table_f32_f16[i] = GGML_COMPUTE_FP16_TO_FP32(u.fp16);
2879
3330
  ggml_table_gelu_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_f32(f));
2880
3331
  ggml_table_gelu_quick_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_quick_f32(f));
2881
- ggml_table_silu_f16[i] = GGML_FP32_TO_FP16(ggml_silu_f32(f));
2882
- ggml_table_exp_f16[i] = GGML_FP32_TO_FP16(expf(f));
2883
3332
  }
2884
3333
 
2885
3334
  const uint64_t t_end = ggml_time_us(); UNUSED(t_end);
@@ -3163,6 +3612,12 @@ static struct ggml_tensor * ggml_new_tensor_impl(
3163
3612
 
3164
3613
  struct ggml_tensor * const result = (struct ggml_tensor *)((char *)ctx->mem_buffer + obj_new->offs);
3165
3614
 
3615
+ #ifdef __clang__
3616
+ // temporary until ggml_tensor::backend is removed
3617
+ #pragma clang diagnostic push
3618
+ #pragma clang diagnostic ignored "-Wdeprecated-declarations"
3619
+ #endif
3620
+
3166
3621
  *result = (struct ggml_tensor) {
3167
3622
  /*.type =*/ type,
3168
3623
  /*.backend =*/ GGML_BACKEND_TYPE_CPU,
@@ -3185,6 +3640,10 @@ static struct ggml_tensor * ggml_new_tensor_impl(
3185
3640
  /*.padding =*/ { 0 },
3186
3641
  };
3187
3642
 
3643
+ #ifdef __clang__
3644
+ #pragma clang diagnostic pop
3645
+ #endif
3646
+
3188
3647
  // TODO: this should not be needed as long as we don't rely on aligned SIMD loads
3189
3648
  //ggml_assert_aligned(result->data);
3190
3649
 
@@ -4563,6 +5022,20 @@ struct ggml_tensor * ggml_leaky_relu(
4563
5022
  return result;
4564
5023
  }
4565
5024
 
5025
+ // ggml_sigmoid
5026
+
5027
+ struct ggml_tensor * ggml_sigmoid(
5028
+ struct ggml_context * ctx,
5029
+ struct ggml_tensor * a) {
5030
+ return ggml_unary(ctx, a, GGML_UNARY_OP_SIGMOID);
5031
+ }
5032
+
5033
+ struct ggml_tensor * ggml_sigmoid_inplace(
5034
+ struct ggml_context * ctx,
5035
+ struct ggml_tensor * a) {
5036
+ return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_SIGMOID);
5037
+ }
5038
+
4566
5039
  // ggml_gelu
4567
5040
 
4568
5041
  struct ggml_tensor * ggml_gelu(
@@ -5646,7 +6119,6 @@ static struct ggml_tensor * ggml_soft_max_impl(
5646
6119
  struct ggml_context * ctx,
5647
6120
  struct ggml_tensor * a,
5648
6121
  struct ggml_tensor * mask,
5649
- struct ggml_tensor * pos,
5650
6122
  float scale,
5651
6123
  float max_bias,
5652
6124
  bool inplace) {
@@ -5660,18 +6132,8 @@ static struct ggml_tensor * ggml_soft_max_impl(
5660
6132
  GGML_ASSERT(mask->ne[1] >= a->ne[1]);
5661
6133
  }
5662
6134
 
5663
- if (pos) {
5664
- GGML_ASSERT(ggml_is_vector(pos));
5665
- GGML_ASSERT(pos->type == GGML_TYPE_F16 || pos->type == GGML_TYPE_F32);
5666
- GGML_ASSERT(pos->ne[0] == a->ne[0]);
5667
- }
5668
-
5669
- if (pos && mask) {
5670
- GGML_ASSERT(pos->type == mask->type);
5671
- }
5672
-
5673
6135
  if (max_bias > 0.0f) {
5674
- GGML_ASSERT(pos);
6136
+ GGML_ASSERT(mask);
5675
6137
  }
5676
6138
 
5677
6139
  bool is_node = false;
@@ -5689,7 +6151,6 @@ static struct ggml_tensor * ggml_soft_max_impl(
5689
6151
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5690
6152
  result->src[0] = a;
5691
6153
  result->src[1] = mask;
5692
- result->src[2] = pos;
5693
6154
 
5694
6155
  return result;
5695
6156
  }
@@ -5697,23 +6158,22 @@ static struct ggml_tensor * ggml_soft_max_impl(
5697
6158
  struct ggml_tensor * ggml_soft_max(
5698
6159
  struct ggml_context * ctx,
5699
6160
  struct ggml_tensor * a) {
5700
- return ggml_soft_max_impl(ctx, a, NULL, NULL, 1.0f, 0.0f, false);
6161
+ return ggml_soft_max_impl(ctx, a, NULL, 1.0f, 0.0f, false);
5701
6162
  }
5702
6163
 
5703
6164
  struct ggml_tensor * ggml_soft_max_inplace(
5704
6165
  struct ggml_context * ctx,
5705
6166
  struct ggml_tensor * a) {
5706
- return ggml_soft_max_impl(ctx, a, NULL, NULL, 1.0f, 0.0f, true);
6167
+ return ggml_soft_max_impl(ctx, a, NULL, 1.0f, 0.0f, true);
5707
6168
  }
5708
6169
 
5709
6170
  struct ggml_tensor * ggml_soft_max_ext(
5710
6171
  struct ggml_context * ctx,
5711
6172
  struct ggml_tensor * a,
5712
6173
  struct ggml_tensor * mask,
5713
- struct ggml_tensor * pos,
5714
6174
  float scale,
5715
6175
  float max_bias) {
5716
- return ggml_soft_max_impl(ctx, a, mask, pos, scale, max_bias, false);
6176
+ return ggml_soft_max_impl(ctx, a, mask, scale, max_bias, false);
5717
6177
  }
5718
6178
 
5719
6179
  // ggml_soft_max_back
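A call-site sketch for the narrowed ggml_soft_max_ext() signature: the separate "pos" tensor is gone, and ALiBi-style bias is now requested via max_bias (the standalone ggml_alibi op is removed later in this diff) and applied against the mask. kq, kq_mask, kq_scale and max_alibi_bias are hypothetical caller-side names.

    static struct ggml_tensor * attn_softmax(struct ggml_context * ctx,
                                             struct ggml_tensor * kq,
                                             struct ggml_tensor * kq_mask,
                                             float kq_scale,
                                             float max_alibi_bias) {
        // 0.15.1 took an extra tensor between mask and scale:
        //   ggml_soft_max_ext(ctx, kq, kq_mask, kq_pos, kq_scale, max_alibi_bias);
        return ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale, max_alibi_bias);
    }
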
@@ -5759,6 +6219,7 @@ static struct ggml_tensor * ggml_rope_impl(
5759
6219
  struct ggml_context * ctx,
5760
6220
  struct ggml_tensor * a,
5761
6221
  struct ggml_tensor * b,
6222
+ struct ggml_tensor * c,
5762
6223
  int n_dims,
5763
6224
  int mode,
5764
6225
  int n_ctx,
@@ -5772,10 +6233,17 @@ static struct ggml_tensor * ggml_rope_impl(
5772
6233
  float xpos_base,
5773
6234
  bool xpos_down,
5774
6235
  bool inplace) {
6236
+ GGML_ASSERT((mode & 1) == 0 && "mode & 1 == 1 is no longer supported");
6237
+
5775
6238
  GGML_ASSERT(ggml_is_vector(b));
5776
6239
  GGML_ASSERT(b->type == GGML_TYPE_I32);
5777
6240
  GGML_ASSERT(a->ne[2] == b->ne[0]);
5778
6241
 
6242
+ if (c) {
6243
+ GGML_ASSERT(c->type == GGML_TYPE_F32);
6244
+ GGML_ASSERT(c->ne[0] >= n_dims / 2);
6245
+ }
6246
+
5779
6247
  bool is_node = false;
5780
6248
 
5781
6249
  if (a->grad) {
@@ -5799,6 +6267,7 @@ static struct ggml_tensor * ggml_rope_impl(
5799
6267
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5800
6268
  result->src[0] = a;
5801
6269
  result->src[1] = b;
6270
+ result->src[2] = c;
5802
6271
 
5803
6272
  return result;
5804
6273
  }
@@ -5811,7 +6280,7 @@ struct ggml_tensor * ggml_rope(
5811
6280
  int mode,
5812
6281
  int n_ctx) {
5813
6282
  return ggml_rope_impl(
5814
- ctx, a, b, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, false
6283
+ ctx, a, b, NULL, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, false
5815
6284
  );
5816
6285
  }
5817
6286
 
@@ -5823,14 +6292,15 @@ struct ggml_tensor * ggml_rope_inplace(
5823
6292
  int mode,
5824
6293
  int n_ctx) {
5825
6294
  return ggml_rope_impl(
5826
- ctx, a, b, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, true
6295
+ ctx, a, b, NULL, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, true
5827
6296
  );
5828
6297
  }
5829
6298
 
5830
- struct ggml_tensor * ggml_rope_custom(
6299
+ struct ggml_tensor * ggml_rope_ext(
5831
6300
  struct ggml_context * ctx,
5832
6301
  struct ggml_tensor * a,
5833
6302
  struct ggml_tensor * b,
6303
+ struct ggml_tensor * c,
5834
6304
  int n_dims,
5835
6305
  int mode,
5836
6306
  int n_ctx,
@@ -5842,15 +6312,16 @@ struct ggml_tensor * ggml_rope_custom(
5842
6312
  float beta_fast,
5843
6313
  float beta_slow) {
5844
6314
  return ggml_rope_impl(
5845
- ctx, a, b, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
6315
+ ctx, a, b, c, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
5846
6316
  ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, false
5847
6317
  );
5848
6318
  }
5849
6319
 
5850
- struct ggml_tensor * ggml_rope_custom_inplace(
6320
+ struct ggml_tensor * ggml_rope_ext_inplace(
5851
6321
  struct ggml_context * ctx,
5852
6322
  struct ggml_tensor * a,
5853
6323
  struct ggml_tensor * b,
6324
+ struct ggml_tensor * c,
5854
6325
  int n_dims,
5855
6326
  int mode,
5856
6327
  int n_ctx,
@@ -5862,19 +6333,49 @@ struct ggml_tensor * ggml_rope_custom_inplace(
5862
6333
  float beta_fast,
5863
6334
  float beta_slow) {
5864
6335
  return ggml_rope_impl(
5865
- ctx, a, b, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
6336
+ ctx, a, b, c, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
5866
6337
  ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, true
5867
6338
  );
5868
6339
  }
5869
6340
 
5870
- struct ggml_tensor * ggml_rope_xpos_inplace(
6341
+ struct ggml_tensor * ggml_rope_custom(
6342
+ struct ggml_context * ctx,
6343
+ struct ggml_tensor * a,
6344
+ struct ggml_tensor * b,
6345
+ int n_dims,
6346
+ int mode,
6347
+ int n_ctx,
6348
+ int n_orig_ctx,
6349
+ float freq_base,
6350
+ float freq_scale,
6351
+ float ext_factor,
6352
+ float attn_factor,
6353
+ float beta_fast,
6354
+ float beta_slow) {
6355
+ return ggml_rope_impl(
6356
+ ctx, a, b, NULL, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
6357
+ ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, false
6358
+ );
6359
+ }
6360
+
6361
+ struct ggml_tensor * ggml_rope_custom_inplace(
5871
6362
  struct ggml_context * ctx,
5872
6363
  struct ggml_tensor * a,
5873
6364
  struct ggml_tensor * b,
5874
6365
  int n_dims,
5875
- float base,
5876
- bool down) {
5877
- return ggml_rope_impl(ctx, a, b, n_dims, 0, 0, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, base, down, true);
6366
+ int mode,
6367
+ int n_ctx,
6368
+ int n_orig_ctx,
6369
+ float freq_base,
6370
+ float freq_scale,
6371
+ float ext_factor,
6372
+ float attn_factor,
6373
+ float beta_fast,
6374
+ float beta_slow) {
6375
+ return ggml_rope_impl(
6376
+ ctx, a, b, NULL, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
6377
+ ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, true
6378
+ );
5878
6379
  }
5879
6380
 
5880
6381
  // ggml_rope_back
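A call-site sketch for the new ggml_rope_ext(): it inserts an optional frequency-factors tensor after the positions tensor, and ggml_rope_custom() remains as a wrapper that forwards NULL for it. The wrapper below and all parameter names are hypothetical caller-side values.

    static struct ggml_tensor * rope_layer(struct ggml_context * ctx,
                                           struct ggml_tensor * cur,
                                           struct ggml_tensor * inp_pos,
                                           int n_rot, int mode, int n_ctx, int n_orig_ctx,
                                           float freq_base, float freq_scale,
                                           float ext_factor, float attn_factor,
                                           float beta_fast, float beta_slow) {
        // passing NULL for the new freq-factors tensor keeps the previous behaviour
        return ggml_rope_ext(ctx, cur, inp_pos, /* freq factors */ NULL,
                             n_rot, mode, n_ctx, n_orig_ctx,
                             freq_base, freq_scale, ext_factor, attn_factor,
                             beta_fast, beta_slow);
    }
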
@@ -5883,6 +6384,7 @@ struct ggml_tensor * ggml_rope_back(
5883
6384
  struct ggml_context * ctx,
5884
6385
  struct ggml_tensor * a,
5885
6386
  struct ggml_tensor * b,
6387
+ struct ggml_tensor * c,
5886
6388
  int n_dims,
5887
6389
  int mode,
5888
6390
  int n_ctx,
@@ -5898,6 +6400,7 @@ struct ggml_tensor * ggml_rope_back(
5898
6400
  GGML_ASSERT(ggml_is_vector(b));
5899
6401
  GGML_ASSERT(b->type == GGML_TYPE_I32);
5900
6402
  GGML_ASSERT(a->ne[2] == b->ne[0]);
6403
+ GGML_ASSERT(c == NULL && "freq factors not implemented yet");
5901
6404
 
5902
6405
  GGML_ASSERT((mode & 4) == 0 && "ggml_rope_back() for ChatGLM not implemented yet");
5903
6406
 
@@ -5928,37 +6431,6 @@ struct ggml_tensor * ggml_rope_back(
5928
6431
  return result;
5929
6432
  }
5930
6433
 
5931
- // ggml_alibi
5932
-
5933
- struct ggml_tensor * ggml_alibi(
5934
- struct ggml_context * ctx,
5935
- struct ggml_tensor * a,
5936
- int n_past,
5937
- int n_head,
5938
- float bias_max) {
5939
- GGML_ASSERT(n_past >= 0);
5940
- bool is_node = false;
5941
-
5942
- if (a->grad) {
5943
- GGML_ASSERT(false); // TODO: implement backward
5944
- is_node = true;
5945
- }
5946
-
5947
- // TODO: when implement backward, fix this:
5948
- //struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
5949
- struct ggml_tensor * result = ggml_view_tensor(ctx, a);
5950
-
5951
- int32_t op_params[3] = { n_past, n_head };
5952
- memcpy(op_params + 2, &bias_max, sizeof(float));
5953
- ggml_set_op_params(result, op_params, sizeof(op_params));
5954
-
5955
- result->op = GGML_OP_ALIBI;
5956
- result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5957
- result->src[0] = a;
5958
-
5959
- return result;
5960
- }
5961
-
5962
6434
  // ggml_clamp
5963
6435
 
5964
6436
  struct ggml_tensor * ggml_clamp(
@@ -6308,7 +6780,10 @@ struct ggml_tensor * ggml_pool_2d(
6308
6780
  static struct ggml_tensor * ggml_upscale_impl(
6309
6781
  struct ggml_context * ctx,
6310
6782
  struct ggml_tensor * a,
6311
- int scale_factor) {
6783
+ int ne0,
6784
+ int ne1,
6785
+ int ne2,
6786
+ int ne3) {
6312
6787
  bool is_node = false;
6313
6788
 
6314
6789
  if (a->grad) {
@@ -6316,19 +6791,45 @@ static struct ggml_tensor * ggml_upscale_impl(
6316
6791
  is_node = true;
6317
6792
  }
6318
6793
 
6794
+ GGML_ASSERT(a->ne[0] <= ne0);
6795
+ GGML_ASSERT(a->ne[1] <= ne1);
6796
+ GGML_ASSERT(a->ne[2] <= ne2);
6797
+ GGML_ASSERT(a->ne[3] <= ne3);
6798
+
6319
6799
  struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type,
6320
- a->ne[0] * scale_factor,
6321
- a->ne[1] * scale_factor,
6322
- a->ne[2], a->ne[3]);
6800
+ ne0,
6801
+ ne1,
6802
+ ne2,
6803
+ ne3
6804
+ );
6323
6805
 
6324
6806
  result->op = GGML_OP_UPSCALE;
6325
- result->op_params[0] = scale_factor;
6807
+
6326
6808
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6327
6809
  result->src[0] = a;
6328
6810
 
6329
6811
  return result;
6330
6812
  }
6331
6813
 
6814
+ struct ggml_tensor * ggml_upscale(
6815
+ struct ggml_context * ctx,
6816
+ struct ggml_tensor * a,
6817
+ int scale_factor) {
6818
+ return ggml_upscale_impl(ctx, a, a->ne[0] * scale_factor, a->ne[1] * scale_factor, a->ne[2], a->ne[3]);
6819
+ }
6820
+
6821
+ struct ggml_tensor * ggml_upscale_ext(
6822
+ struct ggml_context * ctx,
6823
+ struct ggml_tensor * a,
6824
+ int ne0,
6825
+ int ne1,
6826
+ int ne2,
6827
+ int ne3) {
6828
+ return ggml_upscale_impl(ctx, a, ne0, ne1, ne2, ne3);
6829
+ }
6830
+
6831
+ // ggml_pad
6832
+
6332
6833
  struct ggml_tensor * ggml_pad(
6333
6834
  struct ggml_context * ctx,
6334
6835
  struct ggml_tensor * a,
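The upscale refactor above keeps ggml_upscale() as an integer-factor wrapper and adds ggml_upscale_ext() with explicit per-dimension target sizes. A sketch of both entry points; "a" is any tensor from the same context and the target sizes are arbitrary (they must be at least the source sizes, per the new asserts).

    static void upscale_examples(struct ggml_context * ctx, struct ggml_tensor * a) {
        // old entry point, unchanged semantics: scale ne0 and ne1 by an integer factor
        struct ggml_tensor * up2  = ggml_upscale(ctx, a, 2);
        // new entry point: explicit target sizes, assuming a->ne[0] <= 128 and a->ne[1] <= 64
        struct ggml_tensor * upwh = ggml_upscale_ext(ctx, a, 128, 64, a->ne[2], a->ne[3]);
        (void) up2; (void) upwh;
    }
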
@@ -6353,12 +6854,7 @@ struct ggml_tensor * ggml_pad(
6353
6854
  return result;
6354
6855
  }
6355
6856
 
6356
- struct ggml_tensor * ggml_upscale(
6357
- struct ggml_context * ctx,
6358
- struct ggml_tensor * a,
6359
- int scale_factor) {
6360
- return ggml_upscale_impl(ctx, a, scale_factor);
6361
- }
6857
+ // ggml_arange
6362
6858
 
6363
6859
  struct ggml_tensor * ggml_arange(
6364
6860
  struct ggml_context * ctx,
@@ -6380,6 +6876,8 @@ struct ggml_tensor * ggml_arange(
6380
6876
  return result;
6381
6877
  }
6382
6878
 
6879
+ // ggml_timestep_embedding
6880
+
6383
6881
  struct ggml_tensor * ggml_timestep_embedding(
6384
6882
  struct ggml_context * ctx,
6385
6883
  struct ggml_tensor * timesteps,
@@ -6446,38 +6944,6 @@ struct ggml_tensor * ggml_top_k(
6446
6944
  return result;
6447
6945
  }
6448
6946
 
6449
- // ggml_flash_attn
6450
-
6451
- struct ggml_tensor * ggml_flash_attn(
6452
- struct ggml_context * ctx,
6453
- struct ggml_tensor * q,
6454
- struct ggml_tensor * k,
6455
- struct ggml_tensor * v,
6456
- bool masked) {
6457
- GGML_ASSERT(ggml_can_mul_mat(k, q));
6458
- // TODO: check if vT can be multiplied by (k*qT)
6459
-
6460
- bool is_node = false;
6461
-
6462
- if (q->grad || k->grad || v->grad) {
6463
- is_node = true;
6464
- }
6465
-
6466
- //struct ggml_tensor * result = ggml_dup_tensor(ctx, q);
6467
- struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, GGML_MAX_DIMS, q->ne);
6468
-
6469
- int32_t t = masked ? 1 : 0;
6470
- ggml_set_op_params(result, &t, sizeof(t));
6471
-
6472
- result->op = GGML_OP_FLASH_ATTN;
6473
- result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6474
- result->src[0] = q;
6475
- result->src[1] = k;
6476
- result->src[2] = v;
6477
-
6478
- return result;
6479
- }
6480
-
6481
6947
  // ggml_flash_attn_ext
6482
6948
 
6483
6949
  struct ggml_tensor * ggml_flash_attn_ext(
@@ -6486,9 +6952,11 @@ struct ggml_tensor * ggml_flash_attn_ext(
6486
6952
  struct ggml_tensor * k,
6487
6953
  struct ggml_tensor * v,
6488
6954
  struct ggml_tensor * mask,
6489
- float scale) {
6955
+ float scale,
6956
+ float max_bias) {
6490
6957
  GGML_ASSERT(ggml_can_mul_mat(k, q));
6491
6958
  // TODO: check if vT can be multiplied by (k*qT)
6959
+
6492
6960
  if (mask) {
6493
6961
  GGML_ASSERT(ggml_is_contiguous(mask));
6494
6962
  GGML_ASSERT(mask->ne[2] == 1);
@@ -6498,6 +6966,10 @@ struct ggml_tensor * ggml_flash_attn_ext(
6498
6966
  //GGML_ASSERT(ggml_can_repeat_rows(mask, qk));
6499
6967
  }
6500
6968
 
6969
+ if (max_bias > 0.0f) {
6970
+ GGML_ASSERT(mask);
6971
+ }
6972
+
6501
6973
  bool is_node = false;
6502
6974
 
6503
6975
  if (q->grad || k->grad || v->grad) {
@@ -6508,7 +6980,7 @@ struct ggml_tensor * ggml_flash_attn_ext(
6508
6980
  int64_t ne[4] = { q->ne[0], q->ne[2], q->ne[1], q->ne[3] };
6509
6981
  struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
6510
6982
 
6511
- float params[] = { scale };
6983
+ float params[] = { scale, max_bias };
6512
6984
  ggml_set_op_params(result, params, sizeof(params));
6513
6985
 
6514
6986
  result->op = GGML_OP_FLASH_ATTN_EXT;
@@ -6528,39 +7000,7 @@ void ggml_flash_attn_ext_set_prec(
6528
7000
 
6529
7001
  const int32_t prec_i32 = (int32_t) prec;
6530
7002
 
6531
- ggml_set_op_params_i32(a, 1, prec_i32); // scale is on first pos
6532
- }
6533
-
6534
- // ggml_flash_ff
6535
-
6536
- struct ggml_tensor * ggml_flash_ff(
6537
- struct ggml_context * ctx,
6538
- struct ggml_tensor * a,
6539
- struct ggml_tensor * b0,
6540
- struct ggml_tensor * b1,
6541
- struct ggml_tensor * c0,
6542
- struct ggml_tensor * c1) {
6543
- GGML_ASSERT(ggml_can_mul_mat(b0, a));
6544
- // TODO: more checks
6545
-
6546
- bool is_node = false;
6547
-
6548
- if (a->grad || b0->grad || b1->grad || c0->grad || c1->grad) {
6549
- is_node = true;
6550
- }
6551
-
6552
- //struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
6553
- struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, GGML_MAX_DIMS, a->ne);
6554
-
6555
- result->op = GGML_OP_FLASH_FF;
6556
- result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6557
- result->src[0] = a;
6558
- result->src[1] = b0;
6559
- result->src[2] = b1;
6560
- result->src[3] = c0;
6561
- result->src[4] = c1;
6562
-
6563
- return result;
7003
+ ggml_set_op_params_i32(a, 2, prec_i32); // scale is on first pos, max_bias on second
6564
7004
  }
6565
7005
 
6566
7006
  // ggml_flash_attn_back
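A call-site sketch for the updated ggml_flash_attn_ext(): it now takes max_bias after scale, and ggml_flash_attn_ext_set_prec() writes op_params index 2 because index 1 holds max_bias. q, k, v, kq_mask and the scalar values are hypothetical caller-side names.

    static struct ggml_tensor * fused_attn(struct ggml_context * ctx,
                                           struct ggml_tensor * q,
                                           struct ggml_tensor * k,
                                           struct ggml_tensor * v,
                                           struct ggml_tensor * kq_mask,
                                           float kq_scale,
                                           float max_alibi_bias) {
        struct ggml_tensor * cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale, max_alibi_bias);
        ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32); // optional: request f32 accumulation
        return cur;
    }
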
@@ -6572,6 +7012,8 @@ struct ggml_tensor * ggml_flash_attn_back(
6572
7012
  struct ggml_tensor * v,
6573
7013
  struct ggml_tensor * d,
6574
7014
  bool masked) {
7015
+ GGML_ASSERT(false && "TODO: adapt to ggml_flash_attn_ext() changes");
7016
+
6575
7017
  GGML_ASSERT(ggml_can_mul_mat(k, q));
6576
7018
  // TODO: check if vT can be multiplied by (k*qT)
6577
7019
 
@@ -10892,6 +11334,52 @@ static void ggml_compute_forward_relu(
10892
11334
  }
10893
11335
  }
10894
11336
 
11337
+ // ggml_compute_forward_sigmoid
11338
+
11339
+ static void ggml_compute_forward_sigmoid_f32(
11340
+ const struct ggml_compute_params * params,
11341
+ struct ggml_tensor * dst) {
11342
+
11343
+ const struct ggml_tensor * src0 = dst->src[0];
11344
+
11345
+ assert(params->ith == 0);
11346
+ assert(ggml_are_same_shape(src0, dst));
11347
+
11348
+ if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
11349
+ return;
11350
+ }
11351
+
11352
+ const int n = ggml_nrows(src0);
11353
+ const int nc = src0->ne[0];
11354
+
11355
+ assert(dst->nb[0] == sizeof(float));
11356
+ assert(src0->nb[0] == sizeof(float));
11357
+
11358
+ for (int i = 0; i < n; i++) {
11359
+ ggml_vec_sigmoid_f32(nc,
11360
+ (float *) ((char *) dst->data + i*( dst->nb[1])),
11361
+ (float *) ((char *) src0->data + i*(src0->nb[1])));
11362
+ }
11363
+ }
11364
+
11365
+ static void ggml_compute_forward_sigmoid(
11366
+ const struct ggml_compute_params * params,
11367
+ struct ggml_tensor * dst) {
11368
+
11369
+ const struct ggml_tensor * src0 = dst->src[0];
11370
+
11371
+ switch (src0->type) {
11372
+ case GGML_TYPE_F32:
11373
+ {
11374
+ ggml_compute_forward_sigmoid_f32(params, dst);
11375
+ } break;
11376
+ default:
11377
+ {
11378
+ GGML_ASSERT(false);
11379
+ } break;
11380
+ }
11381
+ }
11382
+
10895
11383
  // ggml_compute_forward_gelu
10896
11384
 
10897
11385
  static void ggml_compute_forward_gelu_f32(
@@ -11742,80 +12230,171 @@ static bool ggml_compute_forward_mul_mat_use_blas(struct ggml_tensor * dst) {
11742
12230
  }
11743
12231
  #endif
11744
12232
 
11745
- static void ggml_compute_forward_mul_mat(
11746
- const struct ggml_compute_params * params,
11747
- struct ggml_tensor * dst) {
12233
+ static void ggml_compute_forward_mul_mat_one_chunk(
12234
+ const struct ggml_compute_params * params,
12235
+ struct ggml_tensor * dst,
12236
+ const int64_t num_rows_per_vec_dot,
12237
+ const int64_t ir0_start,
12238
+ const int64_t ir0_end,
12239
+ const int64_t ir1_start,
12240
+ const int64_t ir1_end) {
11748
12241
 
11749
12242
  const struct ggml_tensor * src0 = dst->src[0];
11750
12243
  const struct ggml_tensor * src1 = dst->src[1];
11751
12244
 
11752
- int64_t t0 = ggml_perf_time_us();
11753
- UNUSED(t0);
11754
-
11755
12245
  GGML_TENSOR_BINARY_OP_LOCALS
11756
12246
 
11757
- const int ith = params->ith;
11758
- const int nth = params->nth;
11759
-
11760
12247
  const enum ggml_type type = src0->type;
11761
12248
 
11762
12249
  const bool src1_cont = ggml_is_contiguous(src1);
11763
12250
 
11764
- ggml_vec_dot_t const vec_dot = type_traits[type].vec_dot;
11765
- enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type;
11766
- ggml_from_float_t const from_float_to_vec_dot = type_traits[vec_dot_type].from_float;
11767
- int64_t const vec_dot_num_rows = type_traits[type].nrows;
11768
-
11769
- GGML_ASSERT(ne0 == ne01);
11770
- GGML_ASSERT(ne1 == ne11);
11771
- GGML_ASSERT(ne2 == ne12);
11772
- GGML_ASSERT(ne3 == ne13);
11773
-
11774
- // we don't support permuted src0 or src1
11775
- GGML_ASSERT(nb00 == ggml_type_size(type));
11776
- GGML_ASSERT(nb10 == ggml_type_size(src1->type));
11777
-
11778
- // dst cannot be transposed or permuted
11779
- GGML_ASSERT(nb0 == sizeof(float));
11780
- GGML_ASSERT(nb0 <= nb1);
11781
- GGML_ASSERT(nb1 <= nb2);
11782
- GGML_ASSERT(nb2 <= nb3);
12251
+ ggml_vec_dot_t const vec_dot = type_traits[type].vec_dot;
12252
+ enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type;
11783
12253
 
11784
12254
  // broadcast factors
11785
- const int64_t r2 = ne12/ne02;
11786
- const int64_t r3 = ne13/ne03;
12255
+ const int64_t r2 = ne12 / ne02;
12256
+ const int64_t r3 = ne13 / ne03;
11787
12257
 
11788
- // nb01 >= nb00 - src0 is not transposed
11789
- // compute by src0 rows
12258
+ //printf("ir0_start = %6lld, ir0_end = %6lld, ir1_start = %6lld, ir1_end = %6lld\n", ir0_start, ir0_end, ir1_start, ir1_end);
11790
12259
 
11791
- #if defined(GGML_USE_CLBLAST)
11792
- if (ggml_cl_can_mul_mat(src0, src1, dst)) {
11793
- if (params->ith == 0 && params->type == GGML_TASK_TYPE_COMPUTE) {
11794
- ggml_cl_mul_mat(src0, src1, dst, params->wdata, params->wsize);
11795
- }
12260
+ // threads with no work simply yield (not sure if it helps)
12261
+ if (ir0_start >= ir0_end || ir1_start >= ir1_end) {
11796
12262
  return;
11797
12263
  }
11798
- #endif
11799
12264
 
11800
- #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
11801
- if (ggml_compute_forward_mul_mat_use_blas(dst)) {
11802
- const int64_t ne_plane = ne01*ne00;
11803
- const size_t desired_wsize = ne13*ne12*ne_plane*sizeof(float);
11804
- UNUSED(desired_wsize);
12265
+ const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
12266
+ const size_t row_size = ggml_row_size(vec_dot_type, ne10);
11805
12267
 
11806
- if (params->type == GGML_TASK_TYPE_INIT) {
11807
- if (type != GGML_TYPE_F32) {
11808
- assert(params->wsize >= desired_wsize);
11809
- // parallelize by src0 rows
11810
- for (int64_t i13 = 0; i13 < ne13; i13++) {
11811
- for (int64_t i12 = 0; i12 < ne12; i12++) {
11812
- // broadcast src0 into src1 across 2nd,3rd dimension
11813
- const int64_t i03 = i13/r3;
11814
- const int64_t i02 = i12/r2;
12268
+ assert(ne12 % ne02 == 0);
12269
+ assert(ne13 % ne03 == 0);
11815
12270
 
11816
- const void * x = (char *) src0->data + i02*nb02 + i03*nb03;
11817
- float * const wdata = (float *) params->wdata + i13*ne12*ne_plane + i12*ne_plane;
11818
- ggml_to_float_t const to_float = type_traits[type].to_float;
12271
+ // block-tiling attempt
12272
+ const int64_t blck_0 = 16;
12273
+ const int64_t blck_1 = 16;
12274
+
12275
+ const size_t src1_col_stride = src1_cont || src1->type != vec_dot_type ? row_size : nb11;
12276
+
12277
+ // attempt to reduce false-sharing (does not seem to make a difference)
12278
+ // 16 * 2, accounting for mmla kernels
12279
+ float tmp[32];
12280
+
12281
+ for (int64_t iir1 = ir1_start; iir1 < ir1_end; iir1 += blck_1) {
12282
+ for (int64_t iir0 = ir0_start; iir0 < ir0_end; iir0 += blck_0) {
12283
+ for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir1_end; ir1 += num_rows_per_vec_dot) {
12284
+ const int64_t i13 = (ir1 / (ne12 * ne1));
12285
+ const int64_t i12 = (ir1 - i13 * ne12 * ne1) / ne1;
12286
+ const int64_t i11 = (ir1 - i13 * ne12 * ne1 - i12 * ne1);
12287
+
12288
+ // broadcast src0 into src1
12289
+ const int64_t i03 = i13 / r3;
12290
+ const int64_t i02 = i12 / r2;
12291
+
12292
+ const int64_t i1 = i11;
12293
+ const int64_t i2 = i12;
12294
+ const int64_t i3 = i13;
12295
+
12296
+ const char * src0_row = (const char*)src0->data + (0 + i02 * nb02 + i03 * nb03);
12297
+
12298
+ // desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides
12299
+ // if it is, then we have either copied the data to params->wdata and made it contiguous or we are using
12300
+ // the original src1 data pointer, so we should index using the indices directly
12301
+ // TODO: this is a bit of a hack, we should probably have a better way to handle this
12302
+ const char * src1_col = (const char*)wdata +
12303
+ (src1_cont || src1->type != vec_dot_type
12304
+ ? (i11 + i12 * ne11 + i13 * ne12 * ne11) * row_size
12305
+ : (i11 * nb11 + i12 * nb12 + i13 * nb13));
12306
+ float * dst_col = (float*)((char*)dst->data + (i1 * nb1 + i2 * nb2 + i3 * nb3));
12307
+
12308
+ //for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ++ir0) {
12309
+ // vec_dot(ne00, &dst_col[ir0], src0_row + ir0*nb01, src1_col);
12310
+ //}
12311
+
12312
+ for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ir0 += num_rows_per_vec_dot) {
12313
+ vec_dot(ne00, &tmp[ir0 - iir0], (num_rows_per_vec_dot > 1 ? 16 : 0), src0_row + ir0 * nb01, (num_rows_per_vec_dot > 1 ? nb01 : 0), src1_col, (num_rows_per_vec_dot > 1 ? src1_col_stride : 0), num_rows_per_vec_dot);
12314
+ }
12315
+
12316
+ for (int cn = 0; cn < num_rows_per_vec_dot; ++cn) {
12317
+ memcpy(&dst_col[iir0 + cn * nb1 / nb0], tmp + (cn * 16), (MIN(iir0 + blck_0, ir0_end) - iir0) * sizeof(float));
12318
+ }
12319
+ }
12320
+ }
12321
+ }
12322
+ }
12323
+
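For reference, the per-chunk worker above recovers the (i11, i12, i13) coordinates of a src1 row from the flattened index ir1 and maps them onto src0 through the broadcast factors r2 and r3. A standalone restatement of that index arithmetic, using made-up example shapes (not part of the package):

    #include <assert.h>
    #include <stdint.h>
    #include <stdio.h>

    int main(void) {
        // example shapes: src1 contributes ne1 x ne12 x ne13 rows,
        // src0 has ne02 x ne03 planes that are broadcast over src1
        const int64_t ne1 = 4, ne12 = 2, ne13 = 3;
        const int64_t ne02 = 2, ne03 = 1;
        const int64_t r2 = ne12 / ne02; // broadcast factor in dim 2
        const int64_t r3 = ne13 / ne03; // broadcast factor in dim 3

        for (int64_t ir1 = 0; ir1 < ne1*ne12*ne13; ++ir1) {
            // same decomposition as in ggml_compute_forward_mul_mat_one_chunk
            const int64_t i13 = ir1 / (ne12*ne1);
            const int64_t i12 = (ir1 - i13*ne12*ne1) / ne1;
            const int64_t i11 = ir1 - i13*ne12*ne1 - i12*ne1;

            // broadcast: several src1 planes reuse the same src0 plane
            const int64_t i03 = i13 / r3;
            const int64_t i02 = i12 / r2;

            assert(i11 < ne1 && i12 < ne12 && i13 < ne13);
            assert(i02 < ne02 && i03 < ne03);
            printf("ir1=%2lld -> (i11=%lld, i12=%lld, i13=%lld), src0 plane (i02=%lld, i03=%lld)\n",
                   (long long) ir1, (long long) i11, (long long) i12, (long long) i13,
                   (long long) i02, (long long) i03);
        }
        return 0;
    }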
12324
+ static void ggml_compute_forward_mul_mat(
12325
+ const struct ggml_compute_params * params,
12326
+ struct ggml_tensor * dst,
12327
+ struct ggml_compute_state * state) {
12328
+
12329
+ const struct ggml_tensor * src0 = dst->src[0];
12330
+ const struct ggml_tensor * src1 = dst->src[1];
12331
+
12332
+ int64_t t0 = ggml_perf_time_us();
12333
+ UNUSED(t0);
12334
+
12335
+ GGML_TENSOR_BINARY_OP_LOCALS
12336
+
12337
+ const int ith = params->ith;
12338
+ const int nth = params->nth;
12339
+
12340
+ const enum ggml_type type = src0->type;
12341
+
12342
+ enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type;
12343
+ ggml_from_float_t const from_float_to_vec_dot = type_traits[vec_dot_type].from_float;
12344
+ int64_t const vec_dot_num_rows = type_traits[type].nrows;
12345
+
12346
+ GGML_ASSERT(ne0 == ne01);
12347
+ GGML_ASSERT(ne1 == ne11);
12348
+ GGML_ASSERT(ne2 == ne12);
12349
+ GGML_ASSERT(ne3 == ne13);
12350
+
12351
+ // we don't support permuted src0 or src1
12352
+ GGML_ASSERT(nb00 == ggml_type_size(type));
12353
+ GGML_ASSERT(nb10 == ggml_type_size(src1->type));
12354
+
12355
+ // dst cannot be transposed or permuted
12356
+ GGML_ASSERT(nb0 == sizeof(float));
12357
+ GGML_ASSERT(nb0 <= nb1);
12358
+ GGML_ASSERT(nb1 <= nb2);
12359
+ GGML_ASSERT(nb2 <= nb3);
12360
+
12361
+ // broadcast factors
12362
+ const int64_t r2 = ne12 / ne02;
12363
+ const int64_t r3 = ne13 / ne03;
12364
+ UNUSED(r2);
12365
+ UNUSED(r3);
12366
+
12367
+ // nb01 >= nb00 - src0 is not transposed
12368
+ // compute by src0 rows
12369
+
12370
+ #if defined(GGML_USE_CLBLAST)
12371
+ if (ggml_cl_can_mul_mat(src0, src1, dst)) {
12372
+ if (params->ith == 0 && params->type == GGML_TASK_TYPE_COMPUTE) {
12373
+ ggml_cl_mul_mat(src0, src1, dst, params->wdata, params->wsize);
12374
+ }
12375
+ return;
12376
+ }
12377
+ #endif
12378
+
12379
+ #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
12380
+ if (ggml_compute_forward_mul_mat_use_blas(dst)) {
12381
+ const int64_t ne_plane = ne01*ne00;
12382
+ const size_t desired_wsize = ne13*ne12*ne_plane*sizeof(float);
12383
+ UNUSED(desired_wsize);
12384
+
12385
+ if (params->type == GGML_TASK_TYPE_INIT) {
12386
+ if (type != GGML_TYPE_F32) {
12387
+ assert(params->wsize >= desired_wsize);
12388
+ // parallelize by src0 rows
12389
+ for (int64_t i13 = 0; i13 < ne13; i13++) {
12390
+ for (int64_t i12 = 0; i12 < ne12; i12++) {
12391
+ // broadcast src0 into src1 across 2nd,3rd dimension
12392
+ const int64_t i03 = i13/r3;
12393
+ const int64_t i02 = i12/r2;
12394
+
12395
+ const void * x = (char *) src0->data + i02*nb02 + i03*nb03;
12396
+ float * const wdata = (float *) params->wdata + i13*ne12*ne_plane + i12*ne_plane;
12397
+ ggml_to_float_t const to_float = type_traits[type].to_float;
11819
12398
 
11820
12399
  for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
11821
12400
  to_float((const char *) x + i01*nb01, wdata + i01*ne00, ne00);
@@ -11865,6 +12444,8 @@ static void ggml_compute_forward_mul_mat(
11865
12444
  #endif
11866
12445
 
11867
12446
  #if GGML_USE_LLAMAFILE
12447
+ const bool src1_cont = ggml_is_contiguous(src1);
12448
+
11868
12449
  if (src1_cont) {
11869
12450
  for (int64_t i13 = 0; i13 < ne13; i13++)
11870
12451
  for (int64_t i12 = 0; i12 < ne12; i12++)
@@ -11890,6 +12471,8 @@ UseGgmlGemm1:;
11890
12471
  if (ith != 0) {
11891
12472
  return;
11892
12473
  }
12474
+ // Every thread starts at ith, so the first unprocessed chunk is nth. This saves a bit of coordination right at the start.
12475
+ atomic_store(&state->shared->current_chunk, nth);
11893
12476
  if (src1->type != vec_dot_type) {
11894
12477
  char * wdata = params->wdata;
11895
12478
  const size_t row_size = ggml_row_size(vec_dot_type, ne10);
@@ -11914,11 +12497,11 @@ UseGgmlGemm1:;
11914
12497
  return;
11915
12498
  }
11916
12499
 
11917
- const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
11918
- const size_t row_size = ggml_row_size(vec_dot_type, ne10);
11919
-
11920
12500
  #if GGML_USE_LLAMAFILE
11921
12501
  if (src1->type != vec_dot_type) {
12502
+ const void* wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
12503
+ const size_t row_size = ggml_row_size(vec_dot_type, ne10);
12504
+
11922
12505
  for (int64_t i13 = 0; i13 < ne13; i13++)
11923
12506
  for (int64_t i12 = 0; i12 < ne12; i12++)
11924
12507
  if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(src0->type),
@@ -11939,98 +12522,87 @@ UseGgmlGemm1:;
11939
12522
  UseGgmlGemm2:;
11940
12523
  #endif
11941
12524
 
11942
- const int64_t nr0 = ne01; // src0 rows
11943
- const int64_t nr1 = ne1*ne12*ne13; // src1 rows
11944
-
11945
- //printf("nr0 = %lld, nr1 = %lld\n", nr0, nr1);
11946
-
11947
- // distribute the thread work across the inner or outer loop based on which one is larger
11948
-
11949
- const int64_t nth0 = nr0 > nr1 ? nth : 1; // parallelize by src0 rows
11950
- const int64_t nth1 = nr0 > nr1 ? 1 : nth; // parallelize by src1 rows
11951
-
11952
- const int64_t ith0 = ith % nth0;
11953
- const int64_t ith1 = ith / nth0;
11954
-
11955
- const int64_t dr0 = (nr0 + nth0 - 1)/nth0;
11956
- const int64_t dr1 = (nr1 + nth1 - 1)/nth1;
11957
-
11958
- const int64_t ir010 = dr0*ith0;
11959
- const int64_t ir011 = MIN(ir010 + dr0, nr0);
11960
-
11961
- const int64_t ir110 = dr1*ith1;
11962
- const int64_t ir111 = MIN(ir110 + dr1, nr1);
11963
-
11964
- //printf("ir010 = %6lld, ir011 = %6lld, ir110 = %6lld, ir111 = %6lld\n", ir010, ir011, ir110, ir111);
11965
-
11966
- // threads with no work simply yield (not sure if it helps)
11967
- if (ir010 >= ir011 || ir110 >= ir111) {
11968
- sched_yield();
11969
- return;
11970
- }
12525
+ #ifdef GGML_PERF
12526
+ int chunks_executed = 0;
12527
+ UNUSED(chunks_executed);
12528
+ #endif
11971
12529
 
11972
- assert(ne12 % ne02 == 0);
11973
- assert(ne13 % ne03 == 0);
12530
+ // This is the size of the first dimension of the result, so we can iterate that way. (see the ASSERT above, these are the same numbers)
12531
+ const int64_t nr0 = ne0;
11974
12532
 
11975
- // block-tiling attempt
11976
- const int64_t blck_0 = 16;
11977
- const int64_t blck_1 = 16;
12533
+ // This is the size of the rest of the dimensions of the result
12534
+ const int64_t nr1 = ne1 * ne2 * ne3;
11978
12535
 
11979
12536
  // dot kernels can handle 1 row and col at a time, but mmla kernels can process 2 rows and cols
11980
- int64_t nrc = vec_dot_num_rows;
12537
+ int64_t num_rows_per_vec_dot = vec_dot_num_rows;
11981
12538
  // TODO: currently the mmla kernels support only even numbered rows/cols.
11982
12539
  // this check can be removed once they are extended to support odd numbered rows/cols too
11983
12540
  if ((nr0 % 2 != 0) || (ne11 % 2 != 0)) {
11984
- nrc = 1;
12541
+ num_rows_per_vec_dot = 1;
11985
12542
  }
11986
12543
 
11987
- const size_t src1_col_stride = src1_cont || src1->type != vec_dot_type ? row_size : nb11;
12544
+ // Now select a reasonable chunk size.
12545
+ int chunk_size = 16;
11988
12546
 
11989
- // attempt to reduce false-sharing (does not seem to make a difference)
11990
- // 16 * 2, accounting for mmla kernels
11991
- float tmp[32];
12547
+ // We need to step up the size if it's small
12548
+ if (nr0 == 1 || nr1 == 1) {
12549
+ chunk_size = 64;
12550
+ }
11992
12551
 
11993
- for (int64_t iir1 = ir110; iir1 < ir111; iir1 += blck_1) {
11994
- for (int64_t iir0 = ir010; iir0 < ir011; iir0 += blck_0) {
11995
- for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir111; ir1 += nrc) {
11996
- const int64_t i13 = (ir1/(ne12*ne1));
11997
- const int64_t i12 = (ir1 - i13*ne12*ne1)/ne1;
11998
- const int64_t i11 = (ir1 - i13*ne12*ne1 - i12*ne1);
12552
+ // distribute the work across the inner or outer loop based on which one is larger
12553
+ // The number of chunks in the 0/1 dim.
12554
+ // CEIL(nr0/chunk_size)
12555
+ int64_t nchunk0 = (nr0 + chunk_size - 1) / chunk_size;
12556
+ int64_t nchunk1 = (nr1 + chunk_size - 1) / chunk_size;
11999
12557
 
12000
- // broadcast src0 into src1
12001
- const int64_t i03 = i13/r3;
12002
- const int64_t i02 = i12/r2;
12558
+ // If the chunking is poor for the number of threads on this setup, scrap the whole plan. Re-chunk it by thread.
12559
+ // Also, chunking by thread was measured to perform better on NUMA systems. See https://github.com/ggerganov/llama.cpp/pull/6915
12560
+ // In theory, chunking should be just as useful on NUMA and non-NUMA systems, but testing disagreed with that.
12561
+ if (nchunk0 * nchunk1 < nth * 4 || ggml_is_numa()) {
12562
+ // distribute the thread work across the inner or outer loop based on which one is larger
12563
+ nchunk0 = nr0 > nr1 ? nth : 1; // parallelize by src0 rows
12564
+ nchunk1 = nr0 > nr1 ? 1 : nth; // parallelize by src1 rows
12565
+ }
12003
12566
 
12004
- const int64_t i1 = i11;
12005
- const int64_t i2 = i12;
12006
- const int64_t i3 = i13;
12567
+ // The number of elements in each chunk
12568
+ const int64_t dr0 = (nr0 + nchunk0 - 1) / nchunk0;
12569
+ const int64_t dr1 = (nr1 + nchunk1 - 1) / nchunk1;
12007
12570
 
12008
- const char * src0_row = (const char *) src0->data + (0 + i02*nb02 + i03*nb03);
12571
+ //if (ith == 0)
12572
+ // printf("MUL_MAT = [%d, %d, %d, %d] x [%d, %d, %d, %d] = %d x %d = %d. Fp Ops/Ch %d\n", ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, nchunk0, nchunk1, nchunk0 * nchunk1, ne00 * nr0 * nr1 / nchunk0 / nchunk1);
12009
12573
 
12010
- // desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides
12011
- // if it is, then we have either copied the data to params->wdata and made it contiguous or we are using
12012
- // the original src1 data pointer, so we should index using the indices directly
12013
- // TODO: this is a bit of a hack, we should probably have a better way to handle this
12014
- const char * src1_col = (const char *) wdata +
12015
- (src1_cont || src1->type != vec_dot_type
12016
- ? (i11 + i12*ne11 + i13*ne12*ne11)*row_size
12017
- : (i11*nb11 + i12*nb12 + i13*nb13));
12018
- float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3));
12574
+ // The first chunk comes from our thread_id, the rest will get auto-assigned.
12575
+ int current_chunk = ith;
12019
12576
 
12020
- //for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) {
12021
- // vec_dot(ne00, &dst_col[ir0], src0_row + ir0*nb01, src1_col);
12022
- //}
12577
+ while (current_chunk < nchunk0 * nchunk1) {
12578
+ const int64_t ith0 = current_chunk % nchunk0;
12579
+ const int64_t ith1 = current_chunk / nchunk0;
12023
12580
 
12024
- for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ir0 += nrc) {
12025
- vec_dot(ne00, &tmp[ir0 - iir0], (nrc>1 ? 16 : 0), src0_row + ir0*nb01, (nrc>1 ? nb01 : 0), src1_col, (nrc>1 ? src1_col_stride : 0), nrc);
12026
- }
12581
+ const int64_t ir0_start = dr0 * ith0;
12582
+ const int64_t ir0_end = MIN(ir0_start + dr0, nr0);
12027
12583
 
12028
- for (int cn = 0; cn < nrc; ++cn) {
12029
- memcpy(&dst_col[iir0 + cn*nb1/nb0], tmp + (cn*16), (MIN(iir0 + blck_0, ir011) - iir0)*sizeof(float));
12030
- }
12031
- }
12584
+ const int64_t ir1_start = dr1 * ith1;
12585
+ const int64_t ir1_end = MIN(ir1_start + dr1, nr1);
12586
+
12587
+ ggml_compute_forward_mul_mat_one_chunk(params, dst, num_rows_per_vec_dot, ir0_start, ir0_end, ir1_start, ir1_end);
12588
+
12589
+ #ifdef GGML_PERF
12590
+ chunks_executed++;
12591
+ #endif
12592
+
12593
+ if (nth >= nchunk0 * nchunk1) {
12594
+ break;
12032
12595
  }
12596
+
12597
+ current_chunk = atomic_fetch_add(&state->shared->current_chunk, 1);
12033
12598
  }
12599
+
12600
+ #ifdef GGML_PERF
12601
+ // These numbers are useful when trying to measure how well the thread scheduling works.
12602
+ //int64_t workSize = (ne01 * ne11 * ne12 * ne13 * ne00) / nchunk0 / nchunk1;
12603
+ //float time = (ggml_perf_time_us() - t0);
12604
+ //printf("MUL_MAT = %f ms, [%d, %d, %d, %d] x [%d, %d, %d, %d] = %I64u, %f ops/usec in %d chunks.\n", time / 1000.0, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, workSize, (float)workSize/time, chunks_executed);
12605
+ #endif
12034
12606
  }
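The scheduling that feeds ggml_compute_forward_mul_mat_one_chunk is a simple work-stealing loop: every thread starts on the chunk equal to its thread id and then claims further chunks from a shared atomic counter. A self-contained sketch of that pattern using C11 atomics and pthreads; process_chunk and the constants are illustrative stand-ins, not ggml symbols:

    #include <pthread.h>
    #include <stdatomic.h>
    #include <stdint.h>
    #include <stdio.h>

    #define NTHREADS 4
    #define NCHUNKS  32

    static atomic_int current_chunk; // next unclaimed chunk
    static int processed[NCHUNKS];   // which thread handled each chunk

    static void process_chunk(int chunk, int tid) {
        processed[chunk] = tid; // stand-in for ggml_compute_forward_mul_mat_one_chunk
    }

    static void * worker(void * arg) {
        const int ith = (int)(intptr_t) arg;
        // the first chunk equals the thread id, so the first unclaimed chunk is NTHREADS
        int chunk = ith;
        while (chunk < NCHUNKS) {
            process_chunk(chunk, ith);
            if (NTHREADS >= NCHUNKS) {
                break; // at most one chunk per thread, the counter is not needed
            }
            chunk = atomic_fetch_add(&current_chunk, 1);
        }
        return NULL;
    }

    int main(void) {
        atomic_store(&current_chunk, NTHREADS);
        pthread_t th[NTHREADS];
        for (int i = 0; i < NTHREADS; ++i) {
            pthread_create(&th[i], NULL, worker, (void *)(intptr_t) i);
        }
        for (int i = 0; i < NTHREADS; ++i) {
            pthread_join(th[i], NULL);
        }
        for (int i = 0; i < NCHUNKS; ++i) {
            printf("chunk %2d handled by thread %d\n", i, processed[i]);
        }
        return 0;
    }

Build the sketch with -pthread; it only demonstrates the claim/steal order, not the matrix arithmetic.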
12035
12607
 
12036
12608
  // ggml_compute_forward_mul_mat_id
@@ -13333,7 +13905,6 @@ static void ggml_compute_forward_soft_max_f32(
13333
13905
 
13334
13906
  const struct ggml_tensor * src0 = dst->src[0];
13335
13907
  const struct ggml_tensor * src1 = dst->src[1];
13336
- const struct ggml_tensor * src2 = dst->src[2];
13337
13908
 
13338
13909
  assert(ggml_is_contiguous(dst));
13339
13910
  assert(ggml_are_same_shape(src0, dst));
@@ -13359,8 +13930,8 @@ static void ggml_compute_forward_soft_max_f32(
13359
13930
 
13360
13931
  // TODO: is this supposed to be ceil instead of floor?
13361
13932
  // https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L370
13362
- const uint32_t n_head_kv = ne02;
13363
- const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head_kv));
13933
+ const uint32_t n_head = ne02;
13934
+ const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
13364
13935
 
13365
13936
  const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
13366
13937
  const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
@@ -13377,13 +13948,13 @@ static void ggml_compute_forward_soft_max_f32(
13377
13948
 
13378
13949
  float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith;
13379
13950
 
13380
- // when max_bias <= 0.0f, src2 is not used and we default it to src0 to avoid branching
13381
- ggml_fp16_t * pos_f16 = src2 ? (ggml_fp16_t *) src2->data : src0->data;
13382
- float * pos_f32 = src2 ? (float *) src2->data : src0->data;
13383
-
13384
- const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16) || (src2 && src2->type == GGML_TYPE_F16);
13951
+ const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
13385
13952
 
13386
13953
  for (int i1 = ir0; i1 < ir1; i1++) {
13954
+ // ALiBi
13955
+ const uint32_t h = (i1/ne01)%ne02; // head
13956
+ const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
13957
+
13387
13958
  float * sp = (float *)((char *) src0->data + i1*src0->nb[1]);
13388
13959
  float * dp = (float *)((char *) dst->data + i1*dst->nb[1]);
13389
13960
 
@@ -13396,27 +13967,11 @@ static void ggml_compute_forward_soft_max_f32(
13396
13967
  if (mp_f32) {
13397
13968
  if (use_f16) {
13398
13969
  for (int i = 0; i < nc; ++i) {
13399
- wp[i] += GGML_FP16_TO_FP32(mp_f16[i]);
13400
- }
13401
- } else {
13402
- for (int i = 0; i < nc; ++i) {
13403
- wp[i] += mp_f32[i];
13404
- }
13405
- }
13406
- }
13407
-
13408
- // ALiBi bias
13409
- if (max_bias > 0.0f) {
13410
- const uint32_t h = (i1/ne01)%ne02; // head
13411
- const float slope = h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1);
13412
-
13413
- if (use_f16) {
13414
- for (int i = 0; i < nc; ++i) {
13415
- wp[i] += slope*GGML_FP16_TO_FP32(pos_f16[i]);
13970
+ wp[i] += slope*GGML_FP16_TO_FP32(mp_f16[i]);
13416
13971
  }
13417
13972
  } else {
13418
13973
  for (int i = 0; i < nc; ++i) {
13419
- wp[i] += slope*pos_f32[i];
13974
+ wp[i] += slope*mp_f32[i];
13420
13975
  }
13421
13976
  }
13422
13977
  }
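With the separate ALiBi op removed, the per-head bias is folded into soft_max: head h gets a slope derived from max_bias and the nearest power of two below the head count, and the mask row is scaled by that slope before the softmax. The slope expression, restated as a small helper (a sketch; alibi_slope is an illustrative name, and the example values are arbitrary):

    #include <math.h>
    #include <stdint.h>
    #include <stdio.h>

    // slope for head h, matching the expression used in ggml_compute_forward_soft_max_f32
    static float alibi_slope(uint32_t h, uint32_t n_head, float max_bias) {
        if (max_bias <= 0.0f) {
            return 1.0f; // ALiBi disabled: the mask is applied unscaled
        }
        const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
        const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
        const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
        return h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1);
    }

    int main(void) {
        for (uint32_t h = 0; h < 8; ++h) {
            printf("head %u: slope %.6f\n", h, alibi_slope(h, 8, 8.0f));
        }
        return 0;
    }

Link with -lm when building the sketch.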
@@ -13431,22 +13986,7 @@ static void ggml_compute_forward_soft_max_f32(
13431
13986
  float max = -INFINITY;
13432
13987
  ggml_vec_max_f32(nc, &max, wp);
13433
13988
 
13434
- ggml_float sum = 0.0;
13435
-
13436
- uint16_t scvt;
13437
- for (int i = 0; i < nc; i++) {
13438
- if (wp[i] == -INFINITY) {
13439
- dp[i] = 0.0f;
13440
- } else {
13441
- // const float val = (wp[i] == -INFINITY) ? 0.0 : exp(wp[i] - max);
13442
- ggml_fp16_t s = GGML_FP32_TO_FP16(wp[i] - max);
13443
- memcpy(&scvt, &s, sizeof(scvt));
13444
- const float val = GGML_FP16_TO_FP32(ggml_table_exp_f16[scvt]);
13445
- sum += (ggml_float)val;
13446
- dp[i] = val;
13447
- }
13448
- }
13449
-
13989
+ ggml_float sum = ggml_vec_soft_max_f32(nc, dp, wp, max);
13450
13990
  assert(sum > 0.0);
13451
13991
 
13452
13992
  sum = 1.0/sum;
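The hand-rolled exp-table loop is replaced by a single call to ggml_vec_soft_max_f32. Judging from the call site, the helper writes expf(wp[i] - max) into dp and returns the accumulated sum; the scalar sketch below only shows that semantics (the shipped helper is vectorized, and the names here are placeholders):

    #include <math.h>
    #include <stdio.h>

    // scalar sketch of ggml_vec_soft_max_f32(nc, dp, wp, max), inferred from the call site
    static double vec_soft_max_sketch(int n, float * dp, const float * wp, float max) {
        double sum = 0.0;
        for (int i = 0; i < n; ++i) {
            const float val = expf(wp[i] - max); // -INFINITY inputs end up as 0.0f
            dp[i] = val;
            sum  += (double) val;
        }
        return sum;
    }

    int main(void) {
        float wp[4] = { 0.5f, 2.0f, -1.0f, 2.0f };
        float dp[4];
        const double sum = vec_soft_max_sketch(4, dp, wp, 2.0f /* row max */);
        for (int i = 0; i < 4; ++i) {
            printf("%f ", dp[i] / sum); // normalized softmax row
        }
        printf("\n");
        return 0;
    }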
@@ -13578,68 +14118,9 @@ static void ggml_compute_forward_soft_max_back(
13578
14118
  }
13579
14119
  }
13580
14120
 
13581
- // ggml_compute_forward_alibi
13582
-
13583
- static void ggml_compute_forward_alibi_f32(
13584
- const struct ggml_compute_params * params,
13585
- struct ggml_tensor * dst) {
13586
-
13587
- const struct ggml_tensor * src0 = dst->src[0];
13588
-
13589
- assert(params->ith == 0);
13590
-
13591
- if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
13592
- return;
13593
- }
13594
-
13595
- //const int n_past = ((int32_t *) dst->op_params)[0];
13596
- const int n_head = ((int32_t *) dst->op_params)[1];
13597
- float max_bias;
13598
- memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
13599
-
13600
- const int64_t ne0 = src0->ne[0]; // all_seq_len = n_past + ne1
13601
- const int64_t ne1 = src0->ne[1]; // seq_len_without_past
13602
- const int64_t ne2 = src0->ne[2]; // n_head -> this is k
13603
- //const int64_t ne3 = src0->ne[3]; // 1 -> bsz
13604
-
13605
- const int64_t n = ggml_nrows(src0);
13606
- const int64_t ne2_ne3 = n/ne1; // ne2*ne3
13607
-
13608
- const size_t nb0 = src0->nb[0];
13609
- const size_t nb1 = src0->nb[1];
13610
- const size_t nb2 = src0->nb[2];
13611
- //const int nb3 = src0->nb[3];
13612
-
13613
- GGML_ASSERT(nb0 == sizeof(float));
13614
- GGML_ASSERT(n_head == ne2);
13615
-
13616
- // add alibi to src0 (KQ_scaled)
13617
- const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
13618
-
13619
- const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
13620
- const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
13621
-
13622
- for (int64_t k = 0; k < ne2_ne3; k++) {
13623
- // TODO: k*nb2 or k*nb3
13624
- float m_k;
13625
-
13626
- if (k < n_heads_log2_floor) {
13627
- m_k = powf(m0, k + 1);
13628
- } else {
13629
- m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1);
13630
- }
13631
-
13632
- for (int64_t i = 0; i < ne0; i++) {
13633
- for (int64_t j = 0; j < ne1; j++) {
13634
- float * const src = (float *)((char *) src0->data + i*nb0 + j*nb1 + k*nb2);
13635
- float * pdst = (float *)((char *) dst->data + i*nb0 + j*nb1 + k*nb2);
13636
- pdst[0] = i * m_k + src[0];
13637
- }
13638
- }
13639
- }
13640
- }
14121
+ // ggml_compute_forward_clamp
13641
14122
 
13642
- static void ggml_compute_forward_alibi_f16(
14123
+ static void ggml_compute_forward_clamp_f32(
13643
14124
  const struct ggml_compute_params * params,
13644
14125
  struct ggml_tensor * dst) {
13645
14126
 
@@ -13651,71 +14132,48 @@ static void ggml_compute_forward_alibi_f16(
13651
14132
  return;
13652
14133
  }
13653
14134
 
13654
- //const int n_past = ((int32_t *) dst->op_params)[0];
13655
- const int n_head = ((int32_t *) dst->op_params)[1];
13656
- float max_bias;
13657
- memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
14135
+ float min;
14136
+ float max;
14137
+ memcpy(&min, (float *) dst->op_params + 0, sizeof(float));
14138
+ memcpy(&max, (float *) dst->op_params + 1, sizeof(float));
13658
14139
 
13659
- const int ne0 = src0->ne[0]; // all_seq_len = n_past + ne1
13660
- const int ne1 = src0->ne[1]; // seq_len_without_past
13661
- const int ne2 = src0->ne[2]; // n_head -> this is k
13662
- //const int ne3 = src0->ne[3]; // 1 -> bsz
14140
+ const int ith = params->ith;
14141
+ const int nth = params->nth;
13663
14142
 
13664
14143
  const int n = ggml_nrows(src0);
13665
- const int ne2_ne3 = n/ne1; // ne2*ne3
13666
-
13667
- const int nb0 = src0->nb[0];
13668
- const int nb1 = src0->nb[1];
13669
- const int nb2 = src0->nb[2];
13670
- //const int nb3 = src0->nb[3];
13671
-
13672
- GGML_ASSERT(nb0 == sizeof(ggml_fp16_t));
13673
- //GGML_ASSERT(ne1 + n_past == ne0); (void) n_past;
13674
- GGML_ASSERT(n_head == ne2);
13675
-
13676
- // add alibi to src0 (KQ_scaled)
13677
- const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
14144
+ const int nc = src0->ne[0];
13678
14145
 
13679
- const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
13680
- const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
14146
+ const size_t nb00 = src0->nb[0];
14147
+ const size_t nb01 = src0->nb[1];
13681
14148
 
13682
- for (int k = 0; k < ne2_ne3; k++) {
13683
- // TODO: k*nb2 or k*nb3
13684
- float m_k;
14149
+ const size_t nb0 = dst->nb[0];
14150
+ const size_t nb1 = dst->nb[1];
13685
14151
 
13686
- if (k < n_heads_log2_floor) {
13687
- m_k = powf(m0, k + 1);
13688
- } else {
13689
- m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1);
13690
- }
14152
+ GGML_ASSERT( nb0 == sizeof(float));
14153
+ GGML_ASSERT(nb00 == sizeof(float));
13691
14154
 
13692
- for (int i = 0; i < ne0; i++) {
13693
- for (int j = 0; j < ne1; j++) {
13694
- ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i*nb0 + j*nb1 + k*nb2);
13695
- float * pdst = (float *)((char *) dst->data + i*nb0 + j*nb1 + k*nb2);
14155
+ for (int j = ith; j < n; j += nth) {
14156
+ float * dst_ptr = (float *) ((char *) dst->data + j*nb1);
14157
+ float * src0_ptr = (float *) ((char *) src0->data + j*nb01);
13696
14158
 
13697
- // we return F32
13698
- pdst[0] = i * m_k + GGML_FP16_TO_FP32(src[0]);
13699
- }
14159
+ for (int i = 0; i < nc; i++) {
14160
+ dst_ptr[i] = MAX(MIN(src0_ptr[i], max), min);
13700
14161
  }
13701
14162
  }
13702
14163
  }
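The clamp kernel is now threaded with the usual ggml row striping: thread ith processes rows ith, ith+nth, ith+2*nth, and so on. A tiny standalone illustration of that partitioning (the row and thread counts are arbitrary):

    #include <stdio.h>

    int main(void) {
        const int n   = 10; // number of rows, as returned by ggml_nrows(src0)
        const int nth = 3;  // number of threads

        for (int ith = 0; ith < nth; ++ith) {
            printf("thread %d handles rows:", ith);
            for (int j = ith; j < n; j += nth) {
                printf(" %d", j);
            }
            printf("\n");
        }
        return 0;
    }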
13703
14164
 
13704
- static void ggml_compute_forward_alibi(
14165
+ static void ggml_compute_forward_clamp(
13705
14166
  const struct ggml_compute_params * params,
13706
14167
  struct ggml_tensor * dst) {
13707
14168
 
13708
14169
  const struct ggml_tensor * src0 = dst->src[0];
13709
14170
 
13710
14171
  switch (src0->type) {
13711
- case GGML_TYPE_F16:
13712
- {
13713
- ggml_compute_forward_alibi_f16(params, dst);
13714
- } break;
13715
14172
  case GGML_TYPE_F32:
13716
14173
  {
13717
- ggml_compute_forward_alibi_f32(params, dst);
14174
+ ggml_compute_forward_clamp_f32(params, dst);
13718
14175
  } break;
14176
+ case GGML_TYPE_F16:
13719
14177
  case GGML_TYPE_BF16:
13720
14178
  case GGML_TYPE_Q4_0:
13721
14179
  case GGML_TYPE_Q4_1:
@@ -13750,102 +14208,12 @@ static void ggml_compute_forward_alibi(
13750
14208
  }
13751
14209
  }
13752
14210
 
13753
- // ggml_compute_forward_clamp
13754
-
13755
- static void ggml_compute_forward_clamp_f32(
13756
- const struct ggml_compute_params * params,
13757
- struct ggml_tensor * dst) {
13758
-
13759
- const struct ggml_tensor * src0 = dst->src[0];
14211
+ // ggml_compute_forward_rope
13760
14212
 
13761
- assert(params->ith == 0);
13762
-
13763
- if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
13764
- return;
13765
- }
13766
-
13767
- float min;
13768
- float max;
13769
- memcpy(&min, (float *) dst->op_params + 0, sizeof(float));
13770
- memcpy(&max, (float *) dst->op_params + 1, sizeof(float));
13771
-
13772
- const int ith = params->ith;
13773
- const int nth = params->nth;
13774
-
13775
- const int n = ggml_nrows(src0);
13776
- const int nc = src0->ne[0];
13777
-
13778
- const size_t nb00 = src0->nb[0];
13779
- const size_t nb01 = src0->nb[1];
13780
-
13781
- const size_t nb0 = dst->nb[0];
13782
- const size_t nb1 = dst->nb[1];
13783
-
13784
- GGML_ASSERT( nb0 == sizeof(float));
13785
- GGML_ASSERT(nb00 == sizeof(float));
13786
-
13787
- for (int j = ith; j < n; j += nth) {
13788
- float * dst_ptr = (float *) ((char *) dst->data + j*nb1);
13789
- float * src0_ptr = (float *) ((char *) src0->data + j*nb01);
13790
-
13791
- for (int i = 0; i < nc; i++) {
13792
- dst_ptr[i] = MAX(MIN(src0_ptr[i], max), min);
13793
- }
13794
- }
13795
- }
13796
-
13797
- static void ggml_compute_forward_clamp(
13798
- const struct ggml_compute_params * params,
13799
- struct ggml_tensor * dst) {
13800
-
13801
- const struct ggml_tensor * src0 = dst->src[0];
13802
-
13803
- switch (src0->type) {
13804
- case GGML_TYPE_F32:
13805
- {
13806
- ggml_compute_forward_clamp_f32(params, dst);
13807
- } break;
13808
- case GGML_TYPE_F16:
13809
- case GGML_TYPE_BF16:
13810
- case GGML_TYPE_Q4_0:
13811
- case GGML_TYPE_Q4_1:
13812
- case GGML_TYPE_Q5_0:
13813
- case GGML_TYPE_Q5_1:
13814
- case GGML_TYPE_Q8_0:
13815
- case GGML_TYPE_Q8_1:
13816
- case GGML_TYPE_Q2_K:
13817
- case GGML_TYPE_Q3_K:
13818
- case GGML_TYPE_Q4_K:
13819
- case GGML_TYPE_Q5_K:
13820
- case GGML_TYPE_Q6_K:
13821
- case GGML_TYPE_IQ2_XXS:
13822
- case GGML_TYPE_IQ2_XS:
13823
- case GGML_TYPE_IQ3_XXS:
13824
- case GGML_TYPE_IQ1_S:
13825
- case GGML_TYPE_IQ1_M:
13826
- case GGML_TYPE_IQ4_NL:
13827
- case GGML_TYPE_IQ4_XS:
13828
- case GGML_TYPE_IQ3_S:
13829
- case GGML_TYPE_IQ2_S:
13830
- case GGML_TYPE_Q8_K:
13831
- case GGML_TYPE_I8:
13832
- case GGML_TYPE_I16:
13833
- case GGML_TYPE_I32:
13834
- case GGML_TYPE_I64:
13835
- case GGML_TYPE_F64:
13836
- case GGML_TYPE_COUNT:
13837
- {
13838
- GGML_ASSERT(false);
13839
- } break;
13840
- }
13841
- }
13842
-
13843
- // ggml_compute_forward_rope
13844
-
13845
- static float rope_yarn_ramp(const float low, const float high, const int i0) {
13846
- const float y = (i0 / 2 - low) / MAX(0.001f, high - low);
13847
- return 1 - MIN(1, MAX(0, y));
13848
- }
14213
+ static float rope_yarn_ramp(const float low, const float high, const int i0) {
14214
+ const float y = (i0 / 2 - low) / MAX(0.001f, high - low);
14215
+ return 1 - MIN(1, MAX(0, y));
14216
+ }
13849
14217
 
13850
14218
  // YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
13851
14219
  // MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
@@ -13905,6 +14273,7 @@ static void ggml_compute_forward_rope_f32(
13905
14273
 
13906
14274
  const struct ggml_tensor * src0 = dst->src[0];
13907
14275
  const struct ggml_tensor * src1 = dst->src[1];
14276
+ const struct ggml_tensor * src2 = dst->src[2];
13908
14277
 
13909
14278
  if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
13910
14279
  return;
@@ -13964,6 +14333,17 @@ static void ggml_compute_forward_rope_f32(
13964
14333
  const bool is_neox = mode & 2;
13965
14334
  const bool is_glm = mode & 4;
13966
14335
 
14336
+ const float * freq_factors = NULL;
14337
+ if (is_neox) {
14338
+ if (src2 != NULL) {
14339
+ GGML_ASSERT(src2->type == GGML_TYPE_F32);
14340
+ GGML_ASSERT(src2->ne[0] >= n_dims / 2);
14341
+ freq_factors = (const float *) src2->data;
14342
+ }
14343
+ } else {
14344
+ GGML_ASSERT(src2 == NULL && "TODO: freq_factors not implemented for !is_neox");
14345
+ }
14346
+
13967
14347
  // backward process uses inverse rotation by cos and sin.
13968
14348
  // cos and sin build a rotation matrix, where the inverse is the transpose.
13969
14349
  // this essentially just switches the sign of sin.
@@ -14040,10 +14420,11 @@ static void ggml_compute_forward_rope_f32(
14040
14420
 
14041
14421
  // simplified from `(ib * n_dims + ic) * inv_ndims`
14042
14422
  float cur_rot = inv_ndims * ic - ib;
14423
+ float freq_factor = freq_factors ? freq_factors[ic/2] : 1.0f;
14043
14424
 
14044
14425
  float cos_theta, sin_theta;
14045
14426
  rope_yarn(
14046
- theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
14427
+ theta_base/freq_factor, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
14047
14428
  &cos_theta, &sin_theta
14048
14429
  );
14049
14430
  sin_theta *= sin_sign;
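In the NeoX branch, RoPE can now read an optional per-frequency factor tensor (dst->src[2]): the base angle of rotation pair ic is divided by freq_factors[ic/2] before it is passed to rope_yarn. A simplified sketch of that adjustment, assuming the standard theta_base progression and skipping the YaRN correction (the factor values are made up; when src2 is absent every factor is effectively 1.0f):

    #include <math.h>
    #include <stdio.h>

    int main(void) {
        const int   n_dims    = 8;
        const float freq_base = 10000.0f;
        const int   pos       = 42; // example token position

        // hypothetical per-frequency factors for the n_dims/2 rotation pairs
        const float freq_factors[4] = { 1.0f, 1.0f, 4.0f, 8.0f };

        const float theta_scale = powf(freq_base, -2.0f/n_dims);

        float theta_base = (float) pos;
        for (int ic = 0; ic < n_dims; ic += 2) {
            const float freq_factor = freq_factors[ic/2];
            const float theta = theta_base / freq_factor; // the value handed to rope_yarn(...)
            printf("pair %d: theta_base %.4f -> theta %.4f (cos %.4f, sin %.4f)\n",
                   ic/2, theta_base, theta, cosf(theta), sinf(theta));
            theta_base *= theta_scale;
        }
        return 0;
    }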
@@ -14076,6 +14457,7 @@ static void ggml_compute_forward_rope_f32(
14076
14457
  }
14077
14458
  }
14078
14459
 
14460
+ // TODO: deduplicate f16/f32 code
14079
14461
  static void ggml_compute_forward_rope_f16(
14080
14462
  const struct ggml_compute_params * params,
14081
14463
  struct ggml_tensor * dst,
@@ -14083,6 +14465,7 @@ static void ggml_compute_forward_rope_f16(
14083
14465
 
14084
14466
  const struct ggml_tensor * src0 = dst->src[0];
14085
14467
  const struct ggml_tensor * src1 = dst->src[1];
14468
+ const struct ggml_tensor * src2 = dst->src[2];
14086
14469
 
14087
14470
  if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
14088
14471
  return;
@@ -14135,6 +14518,17 @@ static void ggml_compute_forward_rope_f16(
14135
14518
  const bool is_neox = mode & 2;
14136
14519
  const bool is_glm = mode & 4;
14137
14520
 
14521
+ const float * freq_factors = NULL;
14522
+ if (is_neox) {
14523
+ if (src2 != NULL) {
14524
+ GGML_ASSERT(src2->type == GGML_TYPE_F32);
14525
+ GGML_ASSERT(src2->ne[0] >= n_dims / 2);
14526
+ freq_factors = (const float *) src2->data;
14527
+ }
14528
+ } else {
14529
+ GGML_ASSERT(src2 == NULL && "TODO: freq_factors not implemented for !is_neox");
14530
+ }
14531
+
14138
14532
  // backward process uses inverse rotation by cos and sin.
14139
14533
  // cos and sin build a rotation matrix, where the inverse is the transpose.
14140
14534
  // this essentially just switches the sign of sin.
@@ -14207,10 +14601,11 @@ static void ggml_compute_forward_rope_f16(
14207
14601
 
14208
14602
  // simplified from `(ib * n_dims + ic) * inv_ndims`
14209
14603
  float cur_rot = inv_ndims * ic - ib;
14604
+ float freq_factor = freq_factors ? freq_factors[ic/2] : 1.0f;
14210
14605
 
14211
14606
  float cos_theta, sin_theta;
14212
14607
  rope_yarn(
14213
- theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
14608
+ theta_base/freq_factor, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
14214
14609
  &cos_theta, &sin_theta
14215
14610
  );
14216
14611
  sin_theta *= sin_sign;
@@ -14972,25 +15367,28 @@ static void ggml_compute_forward_upscale_f32(
14972
15367
  return;
14973
15368
  }
14974
15369
 
14975
- GGML_ASSERT(src0->nb[0] == sizeof(float));
15370
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
14976
15371
 
14977
15372
  const int ith = params->ith;
14978
15373
  const int nth = params->nth;
14979
15374
 
14980
15375
  GGML_TENSOR_UNARY_OP_LOCALS
14981
15376
 
14982
- const int scale_factor = dst->op_params[0];
15377
+ const float sf0 = (float)ne0/src0->ne[0];
15378
+ const float sf1 = (float)ne1/src0->ne[1];
15379
+ const float sf2 = (float)ne2/src0->ne[2];
15380
+ const float sf3 = (float)ne3/src0->ne[3];
14983
15381
 
14984
15382
  // TODO: optimize
14985
15383
 
14986
15384
  for (int64_t i3 = 0; i3 < ne3; i3++) {
14987
- const int64_t i03 = i3;
15385
+ const int64_t i03 = i3 / sf3;
14988
15386
  for (int64_t i2 = ith; i2 < ne2; i2 += nth) {
14989
- const int64_t i02 = i2;
15387
+ const int64_t i02 = i2 / sf2;
14990
15388
  for (int64_t i1 = 0; i1 < ne1; i1++) {
14991
- const int64_t i01 = i1 / scale_factor;
15389
+ const int64_t i01 = i1 / sf1;
14992
15390
  for (int64_t i0 = 0; i0 < ne0; i0++) {
14993
- const int64_t i00 = i0 / scale_factor;
15391
+ const int64_t i00 = i0 / sf0;
14994
15392
 
14995
15393
  const float * x = (float *)((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
14996
15394
  float * y = (float *)((char *) dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
@@ -15020,6 +15418,7 @@ static void ggml_compute_forward_upscale(
15020
15418
  }
15021
15419
  }
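Upscale no longer takes a single integer scale factor; each destination dimension carries its own float ratio sfN = dst_ne[N]/src_ne[N], and output indices map back to the source by division (nearest neighbor, floor). A quick restatement of that mapping for one dimension with example sizes:

    #include <stdint.h>
    #include <stdio.h>

    int main(void) {
        const int64_t src_ne0 = 4;
        const int64_t dst_ne0 = 10; // a non-integer ratio is now allowed

        const float sf0 = (float) dst_ne0 / src_ne0;

        for (int64_t i0 = 0; i0 < dst_ne0; ++i0) {
            const int64_t i00 = i0 / sf0; // source column, truncated as in the kernel above
            printf("dst[%2lld] <- src[%lld]\n", (long long) i0, (long long) i00);
        }
        return 0;
    }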
15022
15420
 
15421
+
15023
15422
  // ggml_compute_forward_pad
15024
15423
 
15025
15424
  static void ggml_compute_forward_pad_f32(
@@ -15200,487 +15599,42 @@ static void ggml_compute_forward_argsort_f32(
15200
15599
  const int ith = params->ith;
15201
15600
  const int nth = params->nth;
15202
15601
 
15203
- const int64_t nr = ggml_nrows(src0);
15204
-
15205
- enum ggml_sort_order order = (enum ggml_sort_order) ggml_get_op_params_i32(dst, 0);
15206
-
15207
- for (int64_t i = ith; i < nr; i += nth) {
15208
- int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1);
15209
- const float * src_data = (float *)((char *) src0->data + i*nb01);
15210
-
15211
- for (int64_t j = 0; j < ne0; j++) {
15212
- dst_data[j] = j;
15213
- }
15214
-
15215
- // C doesn't have a functional sort, so we do a bubble sort instead
15216
- for (int64_t j = 0; j < ne0; j++) {
15217
- for (int64_t k = j + 1; k < ne0; k++) {
15218
- if ((order == GGML_SORT_ORDER_ASC && src_data[dst_data[j]] > src_data[dst_data[k]]) ||
15219
- (order == GGML_SORT_ORDER_DESC && src_data[dst_data[j]] < src_data[dst_data[k]])) {
15220
- int32_t tmp = dst_data[j];
15221
- dst_data[j] = dst_data[k];
15222
- dst_data[k] = tmp;
15223
- }
15224
- }
15225
- }
15226
- }
15227
- }
15228
-
15229
- static void ggml_compute_forward_argsort(
15230
- const struct ggml_compute_params * params,
15231
- struct ggml_tensor * dst) {
15232
-
15233
- const struct ggml_tensor * src0 = dst->src[0];
15234
-
15235
- switch (src0->type) {
15236
- case GGML_TYPE_F32:
15237
- {
15238
- ggml_compute_forward_argsort_f32(params, dst);
15239
- } break;
15240
- default:
15241
- {
15242
- GGML_ASSERT(false);
15243
- } break;
15244
- }
15245
- }
15246
-
15247
- // ggml_compute_forward_flash_attn
15248
-
15249
- static void ggml_compute_forward_flash_attn_f32(
15250
- const struct ggml_compute_params * params,
15251
- const bool masked,
15252
- struct ggml_tensor * dst) {
15253
-
15254
- const struct ggml_tensor * q = dst->src[0];
15255
- const struct ggml_tensor * k = dst->src[1];
15256
- const struct ggml_tensor * v = dst->src[2];
15257
-
15258
- int64_t t0 = ggml_perf_time_us();
15259
- UNUSED(t0);
15260
-
15261
- GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
15262
- GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
15263
- GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
15264
- GGML_TENSOR_LOCALS(size_t, nbk, k, nb)
15265
- GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
15266
- GGML_TENSOR_LOCALS(size_t, nbv, v, nb)
15267
- GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
15268
- GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
15269
-
15270
- const int ith = params->ith;
15271
- const int nth = params->nth;
15272
-
15273
- const int64_t D = neq0;
15274
- const int64_t N = neq1;
15275
- const int64_t P = nek1 - N;
15276
- const int64_t M = P + N;
15277
-
15278
- const int Mup = ggml_up(M, GGML_SOFT_MAX_UNROLL);
15279
-
15280
- GGML_ASSERT(ne0 == D);
15281
- GGML_ASSERT(ne1 == N);
15282
- GGML_ASSERT(P >= 0);
15283
-
15284
- GGML_ASSERT(nbq0 == sizeof(float));
15285
- GGML_ASSERT(nbk0 == sizeof(float));
15286
- GGML_ASSERT(nbv0 == sizeof(float));
15287
-
15288
- GGML_ASSERT(neq0 == D);
15289
- GGML_ASSERT(nek0 == D);
15290
- GGML_ASSERT(nev1 == D);
15291
-
15292
- GGML_ASSERT(neq1 == N);
15293
- GGML_ASSERT(nek1 == N + P);
15294
- GGML_ASSERT(nev1 == D);
15295
-
15296
- // dst cannot be transposed or permuted
15297
- GGML_ASSERT(nb0 == sizeof(float));
15298
- GGML_ASSERT(nb0 <= nb1);
15299
- GGML_ASSERT(nb1 <= nb2);
15300
- GGML_ASSERT(nb2 <= nb3);
15301
-
15302
- if (params->type == GGML_TASK_TYPE_INIT) {
15303
- return;
15304
- }
15305
-
15306
- if (params->type == GGML_TASK_TYPE_FINALIZE) {
15307
- return;
15308
- }
15309
-
15310
- // parallelize by q rows using ggml_vec_dot_f32
15311
-
15312
- // total rows in q
15313
- const int nr = neq1*neq2*neq3;
15314
-
15315
- // rows per thread
15316
- const int dr = (nr + nth - 1)/nth;
15317
-
15318
- // row range for this thread
15319
- const int ir0 = dr*ith;
15320
- const int ir1 = MIN(ir0 + dr, nr);
15321
-
15322
- const float scale = 1.0f/sqrtf(D);
15323
-
15324
- //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale);
15325
-
15326
- for (int ir = ir0; ir < ir1; ++ir) {
15327
- // q indices
15328
- const int iq3 = ir/(neq2*neq1);
15329
- const int iq2 = (ir - iq3*neq2*neq1)/neq1;
15330
- const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
15331
-
15332
- float * S = (float *) params->wdata + ith*(Mup + CACHE_LINE_SIZE_F32);
15333
-
15334
- for (int i = M; i < Mup; ++i) {
15335
- S[i] = -INFINITY;
15336
- }
15337
-
15338
- const int64_t masked_begin = masked ? (P + iq1 + 1) : M;
15339
- for (int64_t ic = 0; ic < masked_begin; ++ic) {
15340
- // k indices
15341
- const int ik3 = iq3;
15342
- const int ik2 = iq2 % nek2;
15343
- const int ik1 = ic;
15344
-
15345
- // S indices
15346
- const int i1 = ik1;
15347
-
15348
- ggml_vec_dot_f32(neq0,
15349
- S + i1, 0,
15350
- (float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), 0,
15351
- (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)), 0, 1);
15352
- }
15353
-
15354
- // scale
15355
- ggml_vec_scale_f32(masked_begin, S, scale);
15356
-
15357
- for (int64_t i = masked_begin; i < M; i++) {
15358
- S[i] = -INFINITY;
15359
- }
15360
-
15361
- // softmax
15362
- // exclude known -INF S[..] values from max and loop
15363
- // dont forget to set their SW values to zero
15364
- {
15365
- float max = -INFINITY;
15366
- ggml_vec_max_f32(masked_begin, &max, S);
15367
-
15368
- ggml_float sum = 0.0;
15369
- {
15370
- #ifdef GGML_SOFT_MAX_ACCELERATE
15371
- max = -max;
15372
- vDSP_vsadd(S, 1, &max, S, 1, Mup);
15373
- vvexpf(S, S, &Mup);
15374
- ggml_vec_sum_f32(Mup, &sum, S);
15375
- #else
15376
- uint16_t scvt[GGML_SOFT_MAX_UNROLL]; UNUSED(scvt);
15377
- ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 };
15378
-
15379
- for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) {
15380
- if (i >= masked_begin) {
15381
- break;
15382
- }
15383
- float * SS = S + i;
15384
-
15385
- for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) {
15386
- if (i + j >= masked_begin) {
15387
- break;
15388
- } else if (SS[j] == -INFINITY) {
15389
- SS[j] = 0.0f;
15390
- } else {
15391
- #ifndef GGML_FLASH_ATTN_EXP_FP16
15392
- const float val = expf(SS[j] - max);
15393
- #else
15394
- ggml_fp16_t s = GGML_FP32_TO_FP16(SS[j] - max);
15395
- memcpy(&scvt[j], &s, sizeof(uint16_t));
15396
- const float val = GGML_FP16_TO_FP32(ggml_table_exp_f16[scvt[j]]);
15397
- #endif
15398
- sump[j] += (ggml_float)val;
15399
- SS[j] = val;
15400
- }
15401
- }
15402
- }
15403
-
15404
- for (int i = 0; i < GGML_SOFT_MAX_UNROLL; i++) {
15405
- sum += sump[i];
15406
- }
15407
- #endif
15408
- }
15409
-
15410
- assert(sum > 0.0);
15411
-
15412
- sum = 1.0/sum;
15413
- ggml_vec_scale_f32(masked_begin, S, sum);
15414
-
15415
- #ifndef NDEBUG
15416
- for (int i = 0; i < masked_begin; ++i) {
15417
- assert(!isnan(S[i]));
15418
- assert(!isinf(S[i]));
15419
- }
15420
- #endif
15421
- }
15422
-
15423
- for (int64_t ic = 0; ic < nev1; ++ic) {
15424
- // dst indices
15425
- const int i1 = iq1;
15426
- const int i2 = iq2;
15427
- const int i3 = iq3;
15428
-
15429
- // v indices
15430
- const int iv2 = iq2 % nev2;
15431
- const int iv3 = iq3;
15432
-
15433
- ggml_vec_dot_f32(masked_begin,
15434
- (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), 0,
15435
- (float *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)), 0,
15436
- S, 0, 1);
15437
- }
15438
- }
15439
- }
15440
-
15441
- static void ggml_compute_forward_flash_attn_f16(
15442
- const struct ggml_compute_params * params,
15443
- const bool masked,
15444
- struct ggml_tensor * dst) {
15445
-
15446
- const struct ggml_tensor * q = dst->src[0];
15447
- const struct ggml_tensor * k = dst->src[1];
15448
- const struct ggml_tensor * v = dst->src[2];
15449
-
15450
- int64_t t0 = ggml_perf_time_us();
15451
- UNUSED(t0);
15452
-
15453
- GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
15454
- GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
15455
- GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
15456
- GGML_TENSOR_LOCALS(size_t, nbk, k, nb)
15457
- GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
15458
- GGML_TENSOR_LOCALS(size_t, nbv, v, nb)
15459
- GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
15460
- GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
15461
-
15462
- const int ith = params->ith;
15463
- const int nth = params->nth;
15464
-
15465
- const int64_t D = neq0;
15466
- const int64_t N = neq1;
15467
- const int64_t P = nek1 - N;
15468
- const int64_t M = P + N;
15469
-
15470
- const int Mup = ggml_up(M, GGML_SOFT_MAX_UNROLL);
15471
-
15472
- GGML_ASSERT(ne0 == D);
15473
- GGML_ASSERT(ne1 == N);
15474
- GGML_ASSERT(P >= 0);
15475
-
15476
- GGML_ASSERT(nbq0 == sizeof(ggml_fp16_t));
15477
- GGML_ASSERT(nbk0 == sizeof(ggml_fp16_t));
15478
- GGML_ASSERT(nbv0 == sizeof(ggml_fp16_t));
15479
-
15480
- GGML_ASSERT(neq0 == D);
15481
- GGML_ASSERT(nek0 == D);
15482
- GGML_ASSERT(nev1 == D);
15483
-
15484
- GGML_ASSERT(neq1 == N);
15485
- GGML_ASSERT(nek1 == N + P);
15486
- GGML_ASSERT(nev1 == D);
15487
-
15488
- // dst cannot be transposed or permuted
15489
- GGML_ASSERT(nb0 == sizeof(float));
15490
- GGML_ASSERT(nb0 <= nb1);
15491
- GGML_ASSERT(nb1 <= nb2);
15492
- GGML_ASSERT(nb2 <= nb3);
15493
-
15494
- if (params->type == GGML_TASK_TYPE_INIT) {
15495
- return;
15496
- }
15497
-
15498
- if (params->type == GGML_TASK_TYPE_FINALIZE) {
15499
- return;
15500
- }
15501
-
15502
- // parallelize by q rows using ggml_vec_dot_f32
15503
-
15504
- // total rows in q
15505
- const int nr = neq1*neq2*neq3;
15506
-
15507
- // rows per thread
15508
- const int dr = (nr + nth - 1)/nth;
15509
-
15510
- // row range for this thread
15511
- const int ir0 = dr*ith;
15512
- const int ir1 = MIN(ir0 + dr, nr);
15513
-
15514
- const float scale = 1.0f/sqrtf(D);
15515
-
15516
- //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale);
15517
-
15518
- for (int ir = ir0; ir < ir1; ++ir) {
15519
- // q indices
15520
- const int iq3 = ir/(neq2*neq1);
15521
- const int iq2 = (ir - iq3*neq2*neq1)/neq1;
15522
- const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
15523
-
15524
- float * S = (float *) params->wdata + ith*(2*Mup + CACHE_LINE_SIZE_F32);
15525
-
15526
- for (int i = M; i < Mup; ++i) {
15527
- S[i] = -INFINITY;
15528
- }
15529
-
15530
- if (GGML_VEC_DOT_UNROLL > 2 || nek1 % GGML_VEC_DOT_UNROLL != 0) {
15531
- for (int64_t ic = 0; ic < nek1; ++ic) {
15532
- // k indices
15533
- const int ik3 = iq3;
15534
- const int ik2 = iq2 % nek2;
15535
- const int ik1 = ic;
15536
-
15537
- // S indices
15538
- const int i1 = ik1;
15539
-
15540
- ggml_vec_dot_f16(neq0,
15541
- S + i1, 0,
15542
- (ggml_fp16_t *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), 0,
15543
- (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)), 0, 1);
15544
- }
15545
- } else {
15546
- for (int64_t ic = 0; ic < nek1; ic += GGML_VEC_DOT_UNROLL) {
15547
- // k indices
15548
- const int ik3 = iq3;
15549
- const int ik2 = iq2 % nek2;
15550
- const int ik1 = ic;
15551
-
15552
- // S indices
15553
- const int i1 = ik1;
15554
-
15555
- ggml_vec_dot_f16_unroll(neq0, nbk1,
15556
- S + i1,
15557
- ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)),
15558
- (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
15559
- }
15560
- }
15561
-
15562
- // scale
15563
- ggml_vec_scale_f32(nek1, S, scale);
15564
-
15565
- if (masked) {
15566
- for (int64_t i = P; i < M; i++) {
15567
- if (i > P + iq1) {
15568
- S[i] = -INFINITY;
15569
- }
15570
- }
15571
- }
15572
-
15573
- // softmax
15574
- // todo: exclude known -INF S[..] values from max and loop, assuming their results to be zero.
15575
- // dont forget to set their S values to zero
15576
- {
15577
- float max = -INFINITY;
15578
- ggml_vec_max_f32(M, &max, S);
15579
-
15580
- ggml_float sum = 0.0;
15581
- {
15582
- #ifdef GGML_SOFT_MAX_ACCELERATE
15583
- max = -max;
15584
- vDSP_vsadd(S, 1, &max, S, 1, Mup);
15585
- vvexpf(S, S, &Mup);
15586
- ggml_vec_sum_f32(Mup, &sum, S);
15587
- #else
15588
- uint16_t scvt[GGML_SOFT_MAX_UNROLL];
15589
- ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 };
15590
-
15591
- for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) {
15592
- float * SS = S + i;
15593
-
15594
- for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) {
15595
- if (SS[j] == -INFINITY) {
15596
- SS[j] = 0.0f;
15597
- } else {
15598
- ggml_fp16_t s = GGML_FP32_TO_FP16(SS[j] - max);
15599
- memcpy(&scvt[j], &s, sizeof(uint16_t));
15600
- const float val = GGML_FP16_TO_FP32(ggml_table_exp_f16[scvt[j]]);
15601
- sump[j] += (ggml_float)val;
15602
- SS[j] = val;
15603
- }
15604
- }
15605
- }
15606
-
15607
- for (int i = 0; i < GGML_SOFT_MAX_UNROLL; i++) {
15608
- sum += sump[i];
15609
- }
15610
- #endif
15611
- }
15612
-
15613
- assert(sum > 0.0);
15614
-
15615
- sum = 1.0/sum;
15616
- ggml_vec_scale_f32(M, S, sum);
15617
-
15618
- #ifndef NDEBUG
15619
- for (int i = 0; i < M; ++i) {
15620
- assert(!isnan(S[i]));
15621
- assert(!isinf(S[i]));
15622
- }
15623
- #endif
15624
- }
15625
-
15626
- ggml_fp16_t * S16 = (ggml_fp16_t *) ((float *) params->wdata + ith*(2*Mup + CACHE_LINE_SIZE_F32) + Mup);
15602
+ const int64_t nr = ggml_nrows(src0);
15603
+
15604
+ enum ggml_sort_order order = (enum ggml_sort_order) ggml_get_op_params_i32(dst, 0);
15605
+
15606
+ for (int64_t i = ith; i < nr; i += nth) {
15607
+ int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1);
15608
+ const float * src_data = (float *)((char *) src0->data + i*nb01);
15627
15609
 
15628
- for (int64_t i = 0; i < M; i++) {
15629
- S16[i] = GGML_FP32_TO_FP16(S[i]);
15610
+ for (int64_t j = 0; j < ne0; j++) {
15611
+ dst_data[j] = j;
15630
15612
  }
15631
15613
 
15632
- // todo: exclude known zero S[..] values from dot (reducing nev0 and increasing begin of v and S16).
15633
- if (GGML_VEC_DOT_UNROLL == 1 || (nev1 % GGML_VEC_DOT_UNROLL != 0)) {
15634
- for (int64_t ic = 0; ic < nev1; ++ic) {
15635
- // dst indices
15636
- const int i1 = iq1;
15637
- const int i2 = iq2;
15638
- const int i3 = iq3;
15639
-
15640
- // v indices
15641
- const int iv2 = iq2 % nev2;
15642
- const int iv3 = iq3;
15643
-
15644
- ggml_vec_dot_f16(nev0,
15645
- (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), 0,
15646
- (ggml_fp16_t *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)), 0,
15647
- S16, 0, 1);
15648
- }
15649
- } else {
15650
- for (int64_t ic = 0; ic < nev1; ic += GGML_VEC_DOT_UNROLL) {
15651
- // dst indices
15652
- const int i1 = iq1;
15653
- const int i2 = iq2;
15654
- const int i3 = iq3;
15655
-
15656
- // v indices
15657
- const int iv2 = iq2 % nev2;
15658
- const int iv3 = iq3;
15659
-
15660
- ggml_vec_dot_f16_unroll(nev0, nbv1,
15661
- (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
15662
- ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)),
15663
- S16);
15614
+ // C doesn't have a functional sort, so we do a bubble sort instead
15615
+ for (int64_t j = 0; j < ne0; j++) {
15616
+ for (int64_t k = j + 1; k < ne0; k++) {
15617
+ if ((order == GGML_SORT_ORDER_ASC && src_data[dst_data[j]] > src_data[dst_data[k]]) ||
15618
+ (order == GGML_SORT_ORDER_DESC && src_data[dst_data[j]] < src_data[dst_data[k]])) {
15619
+ int32_t tmp = dst_data[j];
15620
+ dst_data[j] = dst_data[k];
15621
+ dst_data[k] = tmp;
15622
+ }
15664
15623
  }
15665
15624
  }
15666
15625
  }
15667
15626
  }
15668
15627
 
15669
- static void ggml_compute_forward_flash_attn(
15670
- const struct ggml_compute_params * params,
15671
- const bool masked,
15672
- struct ggml_tensor * dst) {
15628
+ static void ggml_compute_forward_argsort(
15629
+ const struct ggml_compute_params * params,
15630
+ struct ggml_tensor * dst) {
15673
15631
 
15674
- const struct ggml_tensor * q = dst->src[0];
15632
+ const struct ggml_tensor * src0 = dst->src[0];
15675
15633
 
15676
- switch (q->type) {
15677
- case GGML_TYPE_F16:
15678
- {
15679
- ggml_compute_forward_flash_attn_f16(params, masked, dst);
15680
- } break;
15634
+ switch (src0->type) {
15681
15635
  case GGML_TYPE_F32:
15682
15636
  {
15683
- ggml_compute_forward_flash_attn_f32(params, masked, dst);
15637
+ ggml_compute_forward_argsort_f32(params, dst);
15684
15638
  } break;
15685
15639
  default:
15686
15640
  {
@@ -15719,9 +15673,10 @@ static void ggml_compute_forward_flash_attn_ext_f16(
15719
15673
  GGML_ASSERT(ne0 == D);
15720
15674
  GGML_ASSERT(ne2 == N);
15721
15675
 
15722
- GGML_ASSERT(nbq0 == sizeof(float));
15723
- GGML_ASSERT(nbk0 == sizeof(ggml_fp16_t));
15724
- GGML_ASSERT(nbv0 == sizeof(ggml_fp16_t));
15676
+ // input tensor rows must be contiguous
15677
+ GGML_ASSERT(nbq0 == ggml_type_size(q->type));
15678
+ GGML_ASSERT(nbk0 == ggml_type_size(k->type));
15679
+ GGML_ASSERT(nbv0 == ggml_type_size(v->type));
15725
15680
 
15726
15681
  GGML_ASSERT(neq0 == D);
15727
15682
  GGML_ASSERT(nek0 == D);
@@ -15763,8 +15718,22 @@ static void ggml_compute_forward_flash_attn_ext_f16(
15763
15718
  const int ir0 = dr*ith;
15764
15719
  const int ir1 = MIN(ir0 + dr, nr);
15765
15720
 
15766
- float scale = 1.0f;
15767
- memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
15721
+ float scale = 1.0f;
15722
+ float max_bias = 0.0f;
15723
+
15724
+ memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
15725
+ memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
15726
+
15727
+ const uint32_t n_head = neq2;
15728
+ const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
15729
+
15730
+ const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
15731
+ const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
15732
+
15733
+ enum ggml_type const k_vec_dot_type = type_traits[k->type].vec_dot_type;
15734
+ ggml_from_float_t const q_to_vec_dot = type_traits[k_vec_dot_type].from_float;
15735
+ ggml_vec_dot_t const kq_vec_dot = type_traits[k->type].vec_dot;
15736
+ ggml_to_float_t const v_to_float = type_traits[v->type].to_float;
15768
15737
 
15769
15738
  // loop over n_batch and n_head
15770
15739
  for (int ir = ir0; ir < ir1; ++ir) {
@@ -15773,14 +15742,22 @@ static void ggml_compute_forward_flash_attn_ext_f16(
15773
15742
  const int iq2 = (ir - iq3*neq2*neq1)/neq1;
15774
15743
  const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
15775
15744
 
15776
- float S = 0.0f;
15777
- float M = -INFINITY;
15745
+ const uint32_t h = iq2; // head index
15746
+ const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
15747
+
15748
+ float S = 0.0f; // sum
15749
+ float M = -INFINITY; // maximum KQ value
15778
15750
 
15779
- float * V32 = (float *) params->wdata + ith*(2*D + CACHE_LINE_SIZE_F32);
15780
- ggml_fp16_t * Q16 = (ggml_fp16_t *) (V32); // reuse memory
15781
- ggml_fp16_t * V16 = (ggml_fp16_t *) (V32 + D);
15751
+ float * VKQ32 = (float *) params->wdata + ith*(3*D + CACHE_LINE_SIZE_F32); // FP32 VKQ accumulator
15752
+ float * V32 = (VKQ32 + 1*D); // (temporary) FP32 V buffer
15753
+ ggml_fp16_t * VKQ16 = (ggml_fp16_t *) (VKQ32 + 1*D); // (temporary) FP16 VKQ accumulator
15754
+ ggml_fp16_t * Q_q = (ggml_fp16_t *) (VKQ32 + 2*D); // (temporary) buffer for Q converted to quantized/FP16
15782
15755
 
15783
- memset(V16, 0, D*sizeof(ggml_fp16_t));
15756
+ if (v->type == GGML_TYPE_F16) {
15757
+ memset(VKQ16, 0, D*sizeof(ggml_fp16_t));
15758
+ } else {
15759
+ memset(VKQ32, 0, D*sizeof(float));
15760
+ }
15784
15761
 
15785
15762
  const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL;
15786
15763
 
@@ -15792,61 +15769,79 @@ static void ggml_compute_forward_flash_attn_ext_f16(
15792
15769
  const int iv3 = iq3 / rv3;
15793
15770
  const int iv2 = iq2 / rv2;
15794
15771
 
15772
+ const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3));
15773
+ q_to_vec_dot(pq, Q_q, D);
15774
+
15795
15775
  // online softmax / attention
15796
15776
  // loop over n_kv and n_head_kv
15797
15777
  // ref: https://arxiv.org/pdf/2112.05682.pdf
15798
15778
  for (int64_t ic = 0; ic < nek1; ++ic) {
15799
- const float mv = mp ? GGML_FP16_TO_FP32(mp[ic]) : 0.0f;
15779
+ const float mv = mp ? slope*GGML_FP16_TO_FP32(mp[ic]) : 0.0f;
15800
15780
  if (mv == -INFINITY) {
15801
15781
  continue;
15802
15782
  }
15803
15783
 
15804
- float s;
15784
+ float s; // KQ value
15805
15785
 
15806
- // convert Q to F16 in V32
15807
- {
15808
- const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3));
15786
+ const char * k_data = (const char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3);
15787
+ kq_vec_dot(D, &s, 0, k_data, 0, Q_q, 0, 1);
15809
15788
 
15810
- for (int64_t d = 0; d < D; ++d) {
15811
- Q16[d] = GGML_FP32_TO_FP16(pq[d]);
15812
- }
15813
- }
15789
+ s = s*scale + mv; // scale KQ value and apply mask
15814
15790
 
15815
- ggml_vec_dot_f16(D,
15816
- &s, 0,
15817
- (ggml_fp16_t *) ((char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3)), 0,
15818
- Q16, 0, 1);
15791
+ const float Mold = M;
15819
15792
 
15820
- s = s*scale + mv;
15793
+ float ms = 1.0f; // upon new higher max val, scale VKQ and KQ sum with this value
15794
+ float vs = 1.0f; // post-softmax KQ value, expf(s - M)
15821
15795
 
15822
- const float Mold = M;
15796
+ const char * v_data = ((const char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3));
15823
15797
 
15824
- float ms = 1.0f;
15825
- float vs = 1.0f;
15798
+ if (v->type == GGML_TYPE_F16) {
15799
+ if (s > M) {
15800
+ // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
15801
+ M = s;
15802
+ ms = expf(Mold - M);
15826
15803
 
15827
- if (s > M) {
15828
- M = s;
15829
- ms = expf(Mold - M);
15804
+ // V = V*expf(Mold - M)
15805
+ ggml_vec_scale_f16(D, VKQ16, ms);
15806
+ } else {
15807
+ // no new maximum, ms == 1.0f, vs != 1.0f
15808
+ vs = expf(s - M);
15809
+ }
15830
15810
 
15831
- // V = V*expf(Mold - M)
15832
- ggml_vec_scale_f16(D, V16, ms);
15811
+ // V += v*expf(s - M)
15812
+ ggml_vec_mad_f16(D, VKQ16, (const ggml_fp16_t *) v_data, vs);
15833
15813
  } else {
15834
- vs = expf(s - M);
15835
- }
15814
+ if (s > M) {
15815
+ // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
15816
+ M = s;
15817
+ ms = expf(Mold - M);
15818
+
15819
+ // V = V*expf(Mold - M)
15820
+ ggml_vec_scale_f32(D, VKQ32, ms);
15821
+ } else {
15822
+ // no new maximum, ms == 1.0f, vs != 1.0f
15823
+ vs = expf(s - M);
15824
+ }
15836
15825
 
15837
- const ggml_fp16_t * v16 = (const ggml_fp16_t *) ((char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3));
15826
+ v_to_float(v_data, V32, D);
15838
15827
 
15839
- // V += v*expf(s - M)
15840
- ggml_vec_mad_f16(D, V16, v16, vs);
15828
+ // V += v*expf(s - M)
15829
+ ggml_vec_mad_f32(D, VKQ32, V32, vs);
15830
+ }
15841
15831
 
15842
- S = S*ms + vs;
15832
+ S = S*ms + vs; // scale and increment sum with partial sum
15843
15833
  }
15844
15834
 
15845
- // V /= S
15846
- for (int64_t d = 0; d < D; ++d) {
15847
- V32[d] = GGML_FP16_TO_FP32(V16[d])/S;
15835
+ if (v->type == GGML_TYPE_F16) {
15836
+ for (int64_t d = 0; d < D; ++d) {
15837
+ VKQ32[d] = GGML_FP16_TO_FP32(VKQ16[d]);
15838
+ }
15848
15839
  }
15849
15840
 
15841
+ // V /= S
15842
+ const float S_inv = 1.0f/S;
15843
+ ggml_vec_scale_f32(D, VKQ32, S_inv);
15844
+
15850
15845
  // dst indices
15851
15846
  const int i1 = iq1;
15852
15847
  const int i2 = iq2;
@@ -15856,7 +15851,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
15856
15851
  //memcpy((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3), V, nev0*sizeof(float));
15857
15852
 
15858
15853
  // permute(0, 2, 1, 3)
15859
- memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, V32, nb1);
15854
+ memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32, nb1);
15860
15855
  }
15861
15856
  }
15862
15857
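The rewritten kernel keeps the online softmax recurrence from the paper referenced above: a running maximum M, a running denominator S, and an unnormalized V accumulator are updated per key, and the accumulator is divided by S only once per query row. A minimal scalar sketch of that recurrence (plain C with illustrative names, not the ggml kernel itself):

    #include <math.h>

    // Streaming softmax-weighted sum over n keys with d-dimensional values,
    // single pass, no buffer of all scores (same recurrence as above).
    static void online_softmax_attn(const float *s, const float *v,
                                    int n, int d, float *out /* d floats */) {
        float M = -INFINITY; // running max of the scores seen so far
        float S = 0.0f;      // running sum of expf(s - M)
        for (int j = 0; j < d; ++j) out[j] = 0.0f;

        for (int i = 0; i < n; ++i) {
            float ms = 1.0f;          // rescale factor when the max grows
            float vs = 1.0f;          // expf(s[i] - M) for the current key
            if (s[i] > M) {
                ms = expf(M - s[i]);  // shrink the old accumulator and sum
                M  = s[i];
            } else {
                vs = expf(s[i] - M);
            }
            for (int j = 0; j < d; ++j) {
                out[j] = out[j]*ms + vs*v[i*d + j];
            }
            S = S*ms + vs;
        }
        for (int j = 0; j < d; ++j) out[j] /= S; // final normalization
    }

The ggml version differs only in bookkeeping: when V is F16 the accumulator stays in F16 (VKQ16) and is converted to F32 at the end, otherwise each V row is converted to F32 (V32) and accumulated into VKQ32, and the ALiBi slope is folded into the mask term before the max/sum update.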
 
@@ -15867,7 +15862,7 @@ static void ggml_compute_forward_flash_attn_ext(
15867
15862
  const struct ggml_tensor * v,
15868
15863
  const struct ggml_tensor * mask,
15869
15864
  struct ggml_tensor * dst) {
15870
- switch (dst->op_params[1]) {
15865
+ switch (dst->op_params[2]) {
15871
15866
  case GGML_PREC_DEFAULT:
15872
15867
  case GGML_PREC_F32:
15873
15868
  {
@@ -15881,165 +15876,6 @@ static void ggml_compute_forward_flash_attn_ext(
15881
15876
  }
15882
15877
  }
15883
15878
 
15884
- // ggml_compute_forward_flash_ff
15885
-
15886
- static void ggml_compute_forward_flash_ff_f16(
15887
- const struct ggml_compute_params * params,
15888
- struct ggml_tensor * dst) {
15889
-
15890
- const struct ggml_tensor * a = dst->src[0]; // F16
15891
- const struct ggml_tensor * b0 = dst->src[1]; // F16 fc_w
15892
- const struct ggml_tensor * b1 = dst->src[2]; // F32 fc_b
15893
- const struct ggml_tensor * c0 = dst->src[3]; // F16 proj_w
15894
- const struct ggml_tensor * c1 = dst->src[4]; // F32 proj_b
15895
-
15896
- int64_t t0 = ggml_perf_time_us();
15897
- UNUSED(t0);
15898
-
15899
- GGML_TENSOR_LOCALS(int64_t, nea, a, ne)
15900
- GGML_TENSOR_LOCALS(size_t, nba, a, nb)
15901
- GGML_TENSOR_LOCALS(int64_t, neb0, b0, ne)
15902
- GGML_TENSOR_LOCALS(size_t, nbb0, b0, nb)
15903
- GGML_TENSOR_LOCALS(int64_t, neb1, b1, ne)
15904
- GGML_TENSOR_LOCALS(size_t, nbb1, b1, nb)
15905
- GGML_TENSOR_LOCALS(int64_t, nec0, c0, ne)
15906
- GGML_TENSOR_LOCALS(size_t, nbc0, c0, nb)
15907
- GGML_TENSOR_LOCALS(int64_t, nec1, c1, ne)
15908
- GGML_TENSOR_LOCALS(size_t, nbc1, c1, nb)
15909
- GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
15910
- GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
15911
-
15912
- const int ith = params->ith;
15913
- const int nth = params->nth;
15914
-
15915
- const int64_t D = nea0;
15916
- //const int64_t N = nea1;
15917
- const int64_t M = neb01;
15918
-
15919
- GGML_ASSERT(ne0 == nea0);
15920
- GGML_ASSERT(ne1 == nea1);
15921
- GGML_ASSERT(ne2 == nea2);
15922
-
15923
- GGML_ASSERT(nba0 == sizeof(ggml_fp16_t));
15924
- GGML_ASSERT(nbb00 == sizeof(ggml_fp16_t));
15925
- GGML_ASSERT(nbb10 == sizeof(float));
15926
- GGML_ASSERT(nbc00 == sizeof(ggml_fp16_t));
15927
- GGML_ASSERT(nbc10 == sizeof(float));
15928
-
15929
- GGML_ASSERT(neb00 == D);
15930
- GGML_ASSERT(neb01 == M);
15931
- GGML_ASSERT(neb10 == M);
15932
- GGML_ASSERT(neb11 == 1);
15933
-
15934
- GGML_ASSERT(nec00 == M);
15935
- GGML_ASSERT(nec01 == D);
15936
- GGML_ASSERT(nec10 == D);
15937
- GGML_ASSERT(nec11 == 1);
15938
-
15939
- // dst cannot be transposed or permuted
15940
- GGML_ASSERT(nb0 == sizeof(float));
15941
- GGML_ASSERT(nb0 <= nb1);
15942
- GGML_ASSERT(nb1 <= nb2);
15943
- GGML_ASSERT(nb2 <= nb3);
15944
-
15945
- if (params->type == GGML_TASK_TYPE_INIT) {
15946
- return;
15947
- }
15948
-
15949
- if (params->type == GGML_TASK_TYPE_FINALIZE) {
15950
- return;
15951
- }
15952
-
15953
- // parallelize by a rows using ggml_vec_dot_f32
15954
-
15955
- // total rows in a
15956
- const int nr = nea1*nea2*nea3;
15957
-
15958
- // rows per thread
15959
- const int dr = (nr + nth - 1)/nth;
15960
-
15961
- // row range for this thread
15962
- const int ir0 = dr*ith;
15963
- const int ir1 = MIN(ir0 + dr, nr);
15964
-
15965
- for (int ir = ir0; ir < ir1; ++ir) {
15966
- // a indices
15967
- const int ia3 = ir/(nea2*nea1);
15968
- const int ia2 = (ir - ia3*nea2*nea1)/nea1;
15969
- const int ia1 = (ir - ia3*nea2*nea1 - ia2*nea1);
15970
-
15971
- float * S = (float *) params->wdata + ith*(2*M + CACHE_LINE_SIZE_F32);
15972
-
15973
- for (int64_t ic = 0; ic < neb01; ++ic) {
15974
- // b0 indices
15975
- const int ib03 = ia3;
15976
- const int ib02 = ia2;
15977
- const int ib01 = ic;
15978
-
15979
- // S indices
15980
- const int i1 = ib01;
15981
-
15982
- ggml_vec_dot_f16(nea0,
15983
- S + i1, 0,
15984
- (ggml_fp16_t *) ((char *) b0->data + (ib01*nbb01 + ib02*nbb02 + ib03*nbb03)), 0,
15985
- (ggml_fp16_t *) ((char *) a->data + ( ia1*nba1 + ia2*nba2 + ia3*nba3)), 0, 1);
15986
- }
15987
-
15988
- ggml_vec_add_f32(neb01, S, S, (float *) b1->data);
15989
- //ggml_vec_gelu_f32(neb01, S, S);
15990
-
15991
- ggml_fp16_t * S16 = (ggml_fp16_t *) ((float *) params->wdata + ith*(2*M + CACHE_LINE_SIZE_F32) + M);
15992
-
15993
- for (int64_t i = 0; i < M; i++) {
15994
- S16[i] = GGML_FP32_TO_FP16(S[i]);
15995
- }
15996
-
15997
- ggml_vec_gelu_f16(neb01, S16, S16);
15998
-
15999
- {
16000
- // dst indices
16001
- const int i1 = ia1;
16002
- const int i2 = ia2;
16003
- const int i3 = ia3;
16004
-
16005
- for (int64_t ic = 0; ic < nec01; ++ic) {
16006
-
16007
- ggml_vec_dot_f16(neb01,
16008
- (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), 0,
16009
- (ggml_fp16_t *) ((char *) c0->data + ( ic*nbc01 + i2*nbc02 + i3*nbc03)), 0,
16010
- S16, 0, 1);
16011
- }
16012
-
16013
- ggml_vec_add_f32(nec01,
16014
- (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3)),
16015
- (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3)),
16016
- (float *) c1->data);
16017
- }
16018
- }
16019
- }
16020
-
16021
- static void ggml_compute_forward_flash_ff(
16022
- const struct ggml_compute_params * params,
16023
- struct ggml_tensor * dst) {
16024
-
16025
- const struct ggml_tensor * b0 = dst->src[1];
16026
-
16027
- switch (b0->type) {
16028
- case GGML_TYPE_F16:
16029
- {
16030
- ggml_compute_forward_flash_ff_f16(params, dst);
16031
- } break;
16032
- case GGML_TYPE_F32:
16033
- {
16034
- GGML_ASSERT(false); // TODO
16035
- } break;
16036
- default:
16037
- {
16038
- GGML_ASSERT(false);
16039
- } break;
16040
- }
16041
- }
16042
-
16043
15879
  // ggml_compute_forward_flash_attn_back
16044
15880
 
16045
15881
  static void ggml_compute_forward_flash_attn_back_f32(
@@ -16221,38 +16057,7 @@ static void ggml_compute_forward_flash_attn_back_f32(
16221
16057
  vvexpf(SM, SM, &Mup);
16222
16058
  ggml_vec_sum_f32(Mup, &sum, SM);
16223
16059
  #else
16224
- uint16_t scvt[GGML_SOFT_MAX_UNROLL]; UNUSED(scvt);
16225
- ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 };
16226
-
16227
- for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) {
16228
- if (i >= masked_begin) {
16229
- break;
16230
- }
16231
- float * SR = S + i;
16232
- float * SW = SM + i;
16233
-
16234
- for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) {
16235
- if (i + j >= masked_begin) {
16236
- break;
16237
- } else if (SR[j] == -INFINITY) {
16238
- SW[j] = 0.0f;
16239
- } else {
16240
- #ifndef GGML_FLASH_ATTN_EXP_FP16
16241
- const float val = expf(SR[j] - max);
16242
- #else
16243
- ggml_fp16_t s = GGML_FP32_TO_FP16(SR[j] - max);
16244
- memcpy(&scvt[j], &s, sizeof(uint16_t));
16245
- const float val = GGML_FP16_TO_FP32(ggml_table_exp_f16[scvt[j]]);
16246
- #endif
16247
- sump[j] += (ggml_float)val;
16248
- SW[j] = val;
16249
- }
16250
- }
16251
- }
16252
-
16253
- for (int i = 0; i < GGML_SOFT_MAX_UNROLL; i++) {
16254
- sum += sump[i];
16255
- }
16060
+ sum = ggml_vec_soft_max_f32(Mup, SM, S, max);
16256
16061
  #endif
16257
16062
  }
16258
16063
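The hand-unrolled exponential loop (and its optional F16 lookup-table path) is replaced by a single call to ggml_vec_soft_max_f32. Judging from the removed code and the call sites in this diff, the helper writes expf(x[i] - max) into the destination and returns the accumulated sum; a plain-C reference of that presumed contract:

    #include <math.h>

    // Presumed semantics of ggml_vec_soft_max_f32, inferred from its call
    // sites: y[i] = expf(x[i] - max), returning the sum of all y[i].
    // (x[i] == -INFINITY naturally yields y[i] == 0.)
    static double vec_soft_max_f32_ref(int n, float *y, const float *x, float max) {
        double sum = 0.0;
        for (int i = 0; i < n; ++i) {
            const float val = expf(x[i] - max);
            y[i]  = val;
            sum  += (double) val;
        }
        return sum;
    }

The same helper replaces the equivalent loops in the cross-entropy forward and backward passes further down.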
 
@@ -16834,6 +16639,10 @@ static void ggml_compute_forward_unary(
16834
16639
  {
16835
16640
  ggml_compute_forward_relu(params, dst);
16836
16641
  } break;
16642
+ case GGML_UNARY_OP_SIGMOID:
16643
+ {
16644
+ ggml_compute_forward_sigmoid(params, dst);
16645
+ } break;
16837
16646
  case GGML_UNARY_OP_GELU:
16838
16647
  {
16839
16648
  ggml_compute_forward_gelu(params, dst);
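GGML_UNARY_OP_SIGMOID is wired into the forward dispatch here (and into the backward pass and task table later in this diff); the element-wise kernel itself is outside this section. For reference, a scalar logistic sigmoid of the usual form (an illustrative helper, not the ggml routine):

    #include <math.h>

    // y[i] = 1 / (1 + exp(-x[i]))
    static void vec_sigmoid_f32_ref(int n, float *y, const float *x) {
        for (int i = 0; i < n; ++i) {
            y[i] = 1.0f / (1.0f + expf(-x[i]));
        }
    }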
@@ -17274,35 +17083,15 @@ static void ggml_compute_forward_cross_entropy_loss_f32(
17274
17083
  assert(!isnan(s1[i]));
17275
17084
  }
17276
17085
  #endif
17277
- // soft_max
17278
- ggml_float sum = 0.0;
17279
- {
17280
- float max = -INFINITY;
17281
- ggml_vec_max_f32(nc, &max, s0);
17282
17086
 
17283
- uint16_t scvt; UNUSED(scvt);
17284
- for (int i = 0; i < nc; i++) {
17285
- if (s0[i] == -INFINITY) {
17286
- st[i] = 0.0f;
17287
- } else {
17288
- #ifndef GGML_CROSS_ENTROPY_EXP_FP16
17289
- const float s = s0[i] - max;
17290
- const float val = expf(s);
17291
- #else
17292
- ggml_fp16_t s = GGML_FP32_TO_FP16(s0[i] - max);
17293
- memcpy(&scvt, &s, sizeof(scvt));
17294
- const float val = GGML_FP16_TO_FP32(ggml_table_exp_f16[scvt]);
17295
- #endif
17296
- sum += (ggml_float)val;
17297
- st[i] = val;
17298
- }
17299
- }
17087
+ // soft_max
17088
+ float max = -INFINITY;
17089
+ ggml_vec_max_f32(nc, &max, s0);
17090
+ ggml_float sum = ggml_vec_soft_max_f32(nc, st, s0, max);
17091
+ assert(sum > 0.0);
17092
+ sum = (1.0 - eps) / sum;
17300
17093
 
17301
- assert(sum > 0.0);
17302
- // sum = 1.0/sum;
17303
- }
17304
17094
  // avoid log(0) by rescaling from [0..1] to [eps..1]
17305
- sum = (1.0 - eps) / sum;
17306
17095
  ggml_vec_scale_f32(nc, st, sum);
17307
17096
  ggml_vec_add1_f32(nc, st, st, eps);
17308
17097
  ggml_vec_log_f32(nc, st, st);
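The refactor preserves the original numerics: after the softmax, each probability is mapped to (1 - eps)*p + eps before the log, so log(0) cannot occur. A compact scalar equivalent of the sequence above (illustrative names; the real code uses the ggml_vec_* helpers):

    #include <math.h>

    // st[i] = logf((1 - eps) * softmax(s0)[i] + eps)
    static void log_softmax_eps_ref(int n, float *st, const float *s0, float eps) {
        float max = -INFINITY;
        for (int i = 0; i < n; ++i) if (s0[i] > max) max = s0[i];
        double sum = 0.0;
        for (int i = 0; i < n; ++i) { st[i] = expf(s0[i] - max); sum += st[i]; }
        const float k = (float)((1.0 - eps) / sum);
        for (int i = 0; i < n; ++i) {
            st[i] = logf(st[i]*k + eps); // rescale [0,1] -> [eps,1], then log
        }
    }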
@@ -17392,32 +17181,11 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32(
17392
17181
  #endif
17393
17182
 
17394
17183
  // soft_max
17395
- ggml_float sum = 0.0;
17396
- {
17397
- float max = -INFINITY;
17398
- ggml_vec_max_f32(nc, &max, s0);
17399
-
17400
- uint16_t scvt; UNUSED(scvt);
17401
- for (int i = 0; i < nc; i++) {
17402
- if (s0[i] == -INFINITY) {
17403
- ds0[i] = 0.0f;
17404
- } else {
17405
- #ifndef GGML_CROSS_ENTROPY_EXP_FP16
17406
- const float s = s0[i] - max;
17407
- const float val = expf(s);
17408
- #else
17409
- ggml_fp16_t s = GGML_FP32_TO_FP16(s0[i] - max);
17410
- memcpy(&scvt, &s, sizeof(scvt));
17411
- const float val = GGML_FP16_TO_FP32(ggml_table_exp_f16[scvt]);
17412
- #endif
17413
- sum += (ggml_float)val;
17414
- ds0[i] = val;
17415
- }
17416
- }
17417
-
17418
- assert(sum > 0.0);
17419
- sum = (1.0 - eps)/sum;
17420
- }
17184
+ float max = -INFINITY;
17185
+ ggml_vec_max_f32(nc, &max, s0);
17186
+ ggml_float sum = ggml_vec_soft_max_f32(nc, ds0, s0, max);
17187
+ assert(sum > 0.0);
17188
+ sum = (1.0 - eps) / sum;
17421
17189
 
17422
17190
  // grad(src0) = (softmax(src0) - src1) * grad(cross_entropy_loss(src0, src1)) / nr
17423
17191
  ggml_vec_scale_f32(nc, ds0, sum);
@@ -17454,7 +17222,7 @@ static void ggml_compute_forward_cross_entropy_loss_back(
17454
17222
 
17455
17223
  /////////////////////////////////
17456
17224
 
17457
- static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
17225
+ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor, struct ggml_compute_state * state) {
17458
17226
  GGML_ASSERT(params);
17459
17227
 
17460
17228
  if (tensor->op == GGML_OP_NONE || ggml_is_empty(tensor)) {
@@ -17552,7 +17320,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
17552
17320
  } break;
17553
17321
  case GGML_OP_MUL_MAT:
17554
17322
  {
17555
- ggml_compute_forward_mul_mat(params, tensor);
17323
+ ggml_compute_forward_mul_mat(params, tensor, state);
17556
17324
  } break;
17557
17325
  case GGML_OP_MUL_MAT_ID:
17558
17326
  {
@@ -17630,10 +17398,6 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
17630
17398
  {
17631
17399
  ggml_compute_forward_rope_back(params, tensor);
17632
17400
  } break;
17633
- case GGML_OP_ALIBI:
17634
- {
17635
- ggml_compute_forward_alibi(params, tensor);
17636
- } break;
17637
17401
  case GGML_OP_CLAMP:
17638
17402
  {
17639
17403
  ggml_compute_forward_clamp(params, tensor);
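The GGML_OP_ALIBI dispatch is removed here, and again from the backward pass and the task-count table below; the ALiBi bias now reaches the attention kernels as a per-head slope multiplied into the mask (see the slope*mask term in the flash-attention hunk above). For reference, the commonly used slope schedule for a power-of-two head count (an illustrative helper, not ggml's; non-power-of-two head counts are handled differently):

    #include <math.h>

    // ALiBi slope for head h in [0, n_head), standard geometric schedule.
    static float alibi_slope_ref(int h, int n_head, float max_bias /* typically 8.0f */) {
        return powf(2.0f, -max_bias * (float)(h + 1) / (float)n_head);
    }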
@@ -17682,21 +17446,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
17682
17446
  {
17683
17447
  ggml_compute_forward_leaky_relu(params, tensor);
17684
17448
  } break;
17685
- case GGML_OP_FLASH_ATTN:
17686
- {
17687
- const int32_t t = ggml_get_op_params_i32(tensor, 0);
17688
- GGML_ASSERT(t == 0 || t == 1);
17689
- const bool masked = t != 0;
17690
- ggml_compute_forward_flash_attn(params, masked, tensor);
17691
- } break;
17692
17449
  case GGML_OP_FLASH_ATTN_EXT:
17693
17450
  {
17694
17451
  ggml_compute_forward_flash_attn_ext(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor);
17695
17452
  } break;
17696
- case GGML_OP_FLASH_FF:
17697
- {
17698
- ggml_compute_forward_flash_ff(params, tensor);
17699
- } break;
17700
17453
  case GGML_OP_FLASH_ATTN_BACK:
17701
17454
  {
17702
17455
  int32_t t = ggml_get_op_params_i32(tensor, 0);
@@ -18066,6 +17819,7 @@ static struct ggml_tensor * ggml_sub_or_set(struct ggml_context * ctx, struct gg
18066
17819
  static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor * tensor, struct ggml_hash_set zero_table) {
18067
17820
  struct ggml_tensor * src0 = tensor->src[0];
18068
17821
  struct ggml_tensor * src1 = tensor->src[1];
17822
+ struct ggml_tensor * src2 = tensor->src[2];
18069
17823
 
18070
17824
  switch (tensor->op) {
18071
17825
  case GGML_OP_DUP:
@@ -18597,6 +18351,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
18597
18351
  ggml_rope_back(ctx,
18598
18352
  tensor->grad,
18599
18353
  src1,
18354
+ src2,
18600
18355
  n_dims,
18601
18356
  mode,
18602
18357
  n_ctx,
@@ -18636,6 +18391,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
18636
18391
  ggml_rope_impl(ctx,
18637
18392
  tensor->grad,
18638
18393
  src1,
18394
+ src2,
18639
18395
  n_dims,
18640
18396
  mode,
18641
18397
  n_ctx,
@@ -18652,10 +18408,6 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
18652
18408
  zero_table);
18653
18409
  }
18654
18410
  } break;
18655
- case GGML_OP_ALIBI:
18656
- {
18657
- GGML_ASSERT(false); // TODO: not implemented
18658
- } break;
18659
18411
  case GGML_OP_CLAMP:
18660
18412
  {
18661
18413
  GGML_ASSERT(false); // TODO: not implemented
@@ -18704,7 +18456,6 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
18704
18456
  {
18705
18457
  GGML_ASSERT(false); // TODO: not implemented
18706
18458
  } break;
18707
- case GGML_OP_FLASH_ATTN:
18708
18459
  case GGML_OP_FLASH_ATTN_EXT:
18709
18460
  {
18710
18461
  struct ggml_tensor * flash_grad = NULL;
@@ -18721,7 +18472,6 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
18721
18472
  masked);
18722
18473
  }
18723
18474
 
18724
- struct ggml_tensor * src2 = tensor->src[2];
18725
18475
  const int64_t elem_q = ggml_nelements(src0);
18726
18476
  const int64_t elem_k = ggml_nelements(src1);
18727
18477
  const int64_t elem_v = ggml_nelements(src2);
@@ -18759,10 +18509,6 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
18759
18509
  zero_table);
18760
18510
  }
18761
18511
  } break;
18762
- case GGML_OP_FLASH_FF:
18763
- {
18764
- GGML_ASSERT(false); // not supported
18765
- } break;
18766
18512
  case GGML_OP_FLASH_ATTN_BACK:
18767
18513
  {
18768
18514
  GGML_ASSERT(false); // not supported
@@ -18826,6 +18572,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
18826
18572
  zero_table);
18827
18573
  }
18828
18574
  } break;
18575
+ case GGML_UNARY_OP_SIGMOID:
18576
+ {
18577
+ GGML_ASSERT(false); // TODO: not implemented
18578
+ } break;
18829
18579
  case GGML_UNARY_OP_GELU:
18830
18580
  {
18831
18581
  GGML_ASSERT(false); // TODO: not implemented
@@ -19172,8 +18922,6 @@ typedef int ggml_lock_t;
19172
18922
 
19173
18923
  #define GGML_LOCK_INITIALIZER 0
19174
18924
 
19175
- typedef pthread_t ggml_thread_t;
19176
-
19177
18925
  #define ggml_thread_create pthread_create
19178
18926
  #define ggml_thread_join pthread_join
19179
18927
 
@@ -19199,8 +18947,6 @@ typedef int ggml_lock_t;
19199
18947
 
19200
18948
  #define GGML_LOCK_INITIALIZER 0
19201
18949
 
19202
- typedef pthread_t ggml_thread_t;
19203
-
19204
18950
  #define ggml_thread_create pthread_create
19205
18951
  #define ggml_thread_join pthread_join
19206
18952
 
@@ -19280,31 +19026,6 @@ static void set_numa_thread_affinity(int thread_n) { UNUSED(thread_n); }
19280
19026
  static void clear_numa_thread_affinity(void) {}
19281
19027
  #endif
19282
19028
 
19283
- struct ggml_compute_state_shared {
19284
- const struct ggml_cgraph * cgraph;
19285
- const struct ggml_cplan * cplan;
19286
-
19287
- int64_t perf_node_start_cycles;
19288
- int64_t perf_node_start_time_us;
19289
-
19290
- const int n_threads;
19291
-
19292
- // synchronization primitives
19293
- atomic_int n_active; // num active threads
19294
- atomic_int node_n; // active graph node
19295
- atomic_int node_task; // active graph node task phase
19296
-
19297
- ggml_abort_callback abort_callback; // abort ggml_graph_compute when true
19298
- void * abort_callback_data;
19299
- };
19300
-
19301
- struct ggml_compute_state {
19302
- ggml_thread_t thrd;
19303
- int ith;
19304
- struct ggml_compute_state_shared * shared;
19305
- enum ggml_status ec;
19306
- };
19307
-
19308
19029
  static void ggml_graph_compute_perf_stats_node(struct ggml_tensor * node, const struct ggml_compute_state_shared * st) {
19309
19030
  int64_t cycles_cur = ggml_perf_cycles() - st->perf_node_start_cycles;
19310
19031
  int64_t time_us_cur = ggml_perf_time_us() - st->perf_node_start_time_us;
@@ -19355,6 +19076,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads, int n_cur_
19355
19076
  case GGML_UNARY_OP_TANH:
19356
19077
  case GGML_UNARY_OP_ELU:
19357
19078
  case GGML_UNARY_OP_RELU:
19079
+ case GGML_UNARY_OP_SIGMOID:
19358
19080
  case GGML_UNARY_OP_HARDSWISH: // to opt for multiple threads
19359
19081
  case GGML_UNARY_OP_HARDSIGMOID: // to opt for multiple threads
19360
19082
  {
@@ -19428,10 +19150,6 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads, int n_cur_
19428
19150
  {
19429
19151
  n_tasks = n_threads;
19430
19152
  } break;
19431
- case GGML_OP_ALIBI:
19432
- {
19433
- n_tasks = 1; //TODO
19434
- } break;
19435
19153
  case GGML_OP_CLAMP:
19436
19154
  {
19437
19155
  n_tasks = 1; //TODO
@@ -19477,15 +19195,10 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads, int n_cur_
19477
19195
  {
19478
19196
  n_tasks = n_threads;
19479
19197
  } break;
19480
- case GGML_OP_FLASH_ATTN:
19481
19198
  case GGML_OP_FLASH_ATTN_EXT:
19482
19199
  {
19483
19200
  n_tasks = n_threads;
19484
19201
  } break;
19485
- case GGML_OP_FLASH_FF:
19486
- {
19487
- n_tasks = n_threads;
19488
- } break;
19489
19202
  case GGML_OP_FLASH_ATTN_BACK:
19490
19203
  {
19491
19204
  n_tasks = n_threads;
@@ -19580,6 +19293,10 @@ static void ggml_graph_compute_thread_sync_node(int * node_n, struct ggml_comput
19580
19293
 
19581
19294
  * node_n = atomic_load(&state->shared->node_n);
19582
19295
  if (* node_n != last_node_n) break;
19296
+ #if defined(__SSE3__)
19297
+ // Tell the processor we're spinning. It's a processor hint for spinlocks.
19298
+ _mm_pause();
19299
+ #endif
19583
19300
  }
19584
19301
  }
19585
19302
 
@@ -19594,6 +19311,10 @@ static void ggml_graph_compute_thread_sync_task(int * task_phase, struct ggml_co
19594
19311
 
19595
19312
  * task_phase = atomic_load(&state->shared->node_task);
19596
19313
  if (* task_phase != last_task_phase) break;
19314
+ #if defined(__SSE3__)
19315
+ // Tell the processor we're spinning. It's a processor hint for spinlocks.
19316
+ _mm_pause();
19317
+ #endif
19597
19318
  }
19598
19319
  }
19599
19320
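Both polling loops above now execute _mm_pause() between atomic loads. The pause instruction is a hint that the thread is in a spin-wait loop, which reduces power draw and avoids memory-order mis-speculation penalties on exit; it has no architectural effect. A self-contained spin-wait of the same shape (hypothetical flag, not the ggml compute state):

    #if defined(__SSE3__)
    #include <immintrin.h>   // _mm_pause
    #endif
    #include <stdatomic.h>

    // Spin until *flag differs from `last`, pausing between polls.
    static int spin_wait_change(atomic_int *flag, int last) {
        int cur;
        while ((cur = atomic_load(flag)) == last) {
    #if defined(__SSE3__)
            _mm_pause();     // CPU hint: this is a spin-wait loop
    #endif
        }
        return cur;
    }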
 
@@ -19633,7 +19354,7 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
19633
19354
  struct ggml_tensor * node = cgraph->nodes[node_n];
19634
19355
  if (GGML_OP_HAS_FINALIZE[node->op]) {
19635
19356
  params.nth = ggml_get_n_tasks(node, n_threads, state->shared->n_threads);
19636
- ggml_compute_forward(&params, node);
19357
+ ggml_compute_forward(&params, node, state);
19637
19358
  }
19638
19359
  ggml_graph_compute_perf_stats_node(node, state->shared);
19639
19360
  }
@@ -19653,17 +19374,17 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
19653
19374
  /* INIT */
19654
19375
  if (GGML_OP_HAS_INIT[node->op]) {
19655
19376
  params.type = GGML_TASK_TYPE_INIT;
19656
- ggml_compute_forward(&params, node);
19377
+ ggml_compute_forward(&params, node, state);
19657
19378
  }
19658
19379
 
19659
19380
  // TODO: maybe push node_n to the atomic but if other threads see n_tasks is 1,
19660
19381
  // they do something more efficient than spinning (?)
19661
19382
  params.type = GGML_TASK_TYPE_COMPUTE;
19662
- ggml_compute_forward(&params, node);
19383
+ ggml_compute_forward(&params, node, state);
19663
19384
 
19664
19385
  if (GGML_OP_HAS_FINALIZE[node->op]) {
19665
19386
  params.type = GGML_TASK_TYPE_FINALIZE;
19666
- ggml_compute_forward(&params, node);
19387
+ ggml_compute_forward(&params, node, state);
19667
19388
  }
19668
19389
 
19669
19390
  ggml_graph_compute_perf_stats_node(node, state->shared);
@@ -19702,7 +19423,7 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
19702
19423
 
19703
19424
  if (state->ith < n_tasks) {
19704
19425
  if (GGML_OP_HAS_INIT[node->op]) {
19705
- ggml_compute_forward(&params, node);
19426
+ ggml_compute_forward(&params, node, state);
19706
19427
  }
19707
19428
  }
19708
19429
 
@@ -19723,7 +19444,7 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
19723
19444
 
19724
19445
  if (state->ith < n_tasks) {
19725
19446
  params.type = GGML_TASK_TYPE_COMPUTE;
19726
- ggml_compute_forward(&params, node);
19447
+ ggml_compute_forward(&params, node, state);
19727
19448
  }
19728
19449
 
19729
19450
  if (atomic_fetch_sub(&state->shared->n_active, 1) == 1) {
@@ -19874,39 +19595,11 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
19874
19595
  cur += sizeof(ggml_fp16_t)*ne00*ne01*ne02*ne03;
19875
19596
  cur += sizeof(ggml_fp16_t)*ne10*ne11*ne12;
19876
19597
  } break;
19877
- case GGML_OP_FLASH_ATTN:
19878
- {
19879
- const int64_t ne11 = ggml_up(node->src[1]->ne[1], GGML_SOFT_MAX_UNROLL);
19880
-
19881
- if (node->src[1]->type == GGML_TYPE_F32) {
19882
- cur = sizeof(float)*ne11*n_tasks; // TODO: this can become (n_tasks-1)
19883
- cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2
19884
- } else if (node->src[1]->type == GGML_TYPE_F16) {
19885
- cur = sizeof(float)*ne11*n_tasks; // TODO: this can become (n_tasks-1)
19886
- cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2
19887
- } else if (node->src[1]->type == GGML_TYPE_BF16) {
19888
- cur = sizeof(float)*ne11*n_tasks; // TODO: this can become (n_tasks-1)
19889
- cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2
19890
- }
19891
- } break;
19892
19598
  case GGML_OP_FLASH_ATTN_EXT:
19893
19599
  {
19894
19600
  const int64_t ne00 = node->src[0]->ne[0]; // D
19895
19601
 
19896
- cur = 2*sizeof(float)*ne00*n_tasks; // 2x head size
19897
- } break;
19898
- case GGML_OP_FLASH_FF:
19899
- {
19900
- if (node->src[1]->type == GGML_TYPE_F32) {
19901
- cur = sizeof(float)*node->src[1]->ne[1]*n_tasks; // TODO: this can become (n_tasks-1)
19902
- cur += sizeof(float)*node->src[1]->ne[1]*n_tasks; // this is overestimated by x2
19903
- } else if (node->src[1]->type == GGML_TYPE_F16) {
19904
- cur = sizeof(float)*node->src[1]->ne[1]*n_tasks; // TODO: this can become (n_tasks-1)
19905
- cur += sizeof(float)*node->src[1]->ne[1]*n_tasks; // this is overestimated by x2
19906
- } else if (node->src[1]->type == GGML_TYPE_BF16) {
19907
- cur = sizeof(float)*node->src[1]->ne[1]*n_tasks; // TODO: this can become (n_tasks-1)
19908
- cur += sizeof(float)*node->src[1]->ne[1]*n_tasks; // this is overestimated by x2
19909
- }
19602
+ cur = 3*sizeof(float)*ne00*n_tasks; // 3x head size/thread
19910
19603
  } break;
19911
19604
  case GGML_OP_FLASH_ATTN_BACK:
19912
19605
  {
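The work-buffer estimate for GGML_OP_FLASH_ATTN_EXT grows from two to three head-sized float vectors per thread, which matches the buffers the kernel above works with: the F32 accumulator VKQ32, a scratch row for the converted V (shared with the F16 accumulator VKQ16), and the query converted for the K vec_dot (Q_q). The carve-up below is an assumption based on those names and the 3*D budget; cache-line padding is omitted:

    #include <stddef.h>
    #include <stdint.h>

    typedef uint16_t fp16_t;   // stand-in for ggml_fp16_t in this sketch

    // Presumed per-thread partition of the 3*D-float FLASH_ATTN_EXT work buffer.
    static void carve_flash_attn_wdata(float *wdata, int ith, int64_t D,
                                       float **VKQ32, float **V32,
                                       fp16_t **VKQ16, fp16_t **Q_q) {
        float *base = wdata + (size_t) ith * 3 * D;
        *VKQ32 = base;                        // F32 accumulator for softmax(KQ)*V
        *V32   = base + D;                    // converted V row (F32 path)
        *VKQ16 = (fp16_t *) (base + D);       // F16 accumulator, reuses the V32 slot
        *Q_q   = (fp16_t *) (base + 2*D);     // Q converted to K's vec_dot type
    }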
@@ -19974,6 +19667,7 @@ enum ggml_status ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cpl
19974
19667
  /*.node_task =*/ GGML_TASK_TYPE_FINALIZE,
19975
19668
  /*.abort_callback =*/ NULL,
19976
19669
  /*.abort_callback_data =*/ NULL,
19670
+ /*.current_chunk =*/ 0,
19977
19671
  };
19978
19672
  struct ggml_compute_state * workers = alloca(sizeof(struct ggml_compute_state)*n_threads);
19979
19673
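ggml_compute_forward (and the mul_mat path specifically) now receives the worker's ggml_compute_state, and the shared state initialized above gains a current_chunk field. The mat-mul kernel is outside this section, but a field like this typically backs dynamic work distribution: each worker atomically claims the next chunk of output until none remain. An illustrative pattern (hypothetical helper, not the ggml implementation):

    #include <stdatomic.h>

    // Workers call this concurrently; each atomic fetch-add hands out one chunk.
    static void process_chunks(atomic_int *current_chunk, int n_chunks,
                               void (*do_chunk)(int chunk, void *ctx), void *ctx) {
        for (;;) {
            const int chunk = atomic_fetch_add(current_chunk, 1);
            if (chunk >= n_chunks) {
                break;   // all chunks claimed
            }
            do_chunk(chunk, ctx);
        }
    }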
 
@@ -21747,11 +21441,7 @@ size_t ggml_quantize_chunk(
21747
21441
  case GGML_TYPE_IQ1_S: result = quantize_iq1_s (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
21748
21442
  case GGML_TYPE_IQ1_M: result = quantize_iq1_m (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
21749
21443
  case GGML_TYPE_IQ4_NL: result = quantize_iq4_nl (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
21750
- #if QK_K == 64
21751
- case GGML_TYPE_IQ4_XS: result = quantize_iq4_nl (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
21752
- #else
21753
21444
  case GGML_TYPE_IQ4_XS: result = quantize_iq4_xs (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
21754
- #endif
21755
21445
  case GGML_TYPE_F16:
21756
21446
  {
21757
21447
  size_t elemsize = sizeof(ggml_fp16_t);
@@ -23028,6 +22718,14 @@ int ggml_cpu_has_avx512_vnni(void) {
23028
22718
  #endif
23029
22719
  }
23030
22720
 
22721
+ int ggml_cpu_has_avx512_bf16(void) {
22722
+ #if defined(__AVX512BF16__)
22723
+ return 1;
22724
+ #else
22725
+ return 0;
22726
+ #endif
22727
+ }
22728
+
23031
22729
  int ggml_cpu_has_fma(void) {
23032
22730
  #if defined(__FMA__)
23033
22731
  return 1;