@fugood/llama.node 0.2.1 → 0.2.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (79)
  1. package/bin/darwin/arm64/default.metallib +0 -0
  2. package/bin/darwin/arm64/llama-node.node +0 -0
  3. package/bin/darwin/x64/default.metallib +0 -0
  4. package/bin/darwin/x64/llama-node.node +0 -0
  5. package/bin/linux/arm64/llama-node.node +0 -0
  6. package/bin/linux/x64/llama-node.node +0 -0
  7. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  8. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  9. package/bin/win32/arm64/llama-node.node +0 -0
  10. package/bin/win32/arm64/node.lib +0 -0
  11. package/bin/win32/x64/llama-node.node +0 -0
  12. package/bin/win32/x64/node.lib +0 -0
  13. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  14. package/bin/win32-vulkan/arm64/node.lib +0 -0
  15. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  16. package/bin/win32-vulkan/x64/node.lib +0 -0
  17. package/package.json +1 -1
  18. package/src/LlamaContext.cpp +2 -2
  19. package/src/llama.cpp/CMakeLists.txt +72 -46
  20. package/src/llama.cpp/cmake/arm64-windows-llvm.cmake +16 -0
  21. package/src/llama.cpp/cmake/arm64-windows-msvc.cmake +6 -0
  22. package/src/llama.cpp/common/common.cpp +732 -752
  23. package/src/llama.cpp/common/common.h +47 -41
  24. package/src/llama.cpp/common/grammar-parser.cpp +1 -1
  25. package/src/llama.cpp/common/json-schema-to-grammar.cpp +6 -6
  26. package/src/llama.cpp/common/log.h +5 -5
  27. package/src/llama.cpp/common/sampling.cpp +89 -7
  28. package/src/llama.cpp/common/sampling.h +5 -0
  29. package/src/llama.cpp/common/train.cpp +2 -2
  30. package/src/llama.cpp/examples/batched/batched.cpp +1 -1
  31. package/src/llama.cpp/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +1 -1
  32. package/src/llama.cpp/examples/embedding/embedding.cpp +3 -2
  33. package/src/llama.cpp/examples/eval-callback/eval-callback.cpp +2 -2
  34. package/src/llama.cpp/examples/finetune/finetune.cpp +4 -3
  35. package/src/llama.cpp/examples/imatrix/imatrix.cpp +2 -2
  36. package/src/llama.cpp/examples/infill/infill.cpp +8 -8
  37. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +2 -2
  38. package/src/llama.cpp/examples/llama.android/llama/CMakeLists.txt +13 -8
  39. package/src/llama.cpp/examples/llava/clip.h +1 -1
  40. package/src/llama.cpp/examples/llava/llava-cli.cpp +1 -1
  41. package/src/llama.cpp/examples/llava/llava.cpp +0 -15
  42. package/src/llama.cpp/examples/lookahead/lookahead.cpp +1 -1
  43. package/src/llama.cpp/examples/lookup/lookup.cpp +1 -1
  44. package/src/llama.cpp/examples/main/main.cpp +24 -16
  45. package/src/llama.cpp/examples/parallel/parallel.cpp +1 -1
  46. package/src/llama.cpp/examples/perplexity/perplexity.cpp +9 -9
  47. package/src/llama.cpp/examples/quantize/quantize.cpp +2 -2
  48. package/src/llama.cpp/examples/retrieval/retrieval.cpp +2 -2
  49. package/src/llama.cpp/examples/rpc/rpc-server.cpp +78 -14
  50. package/src/llama.cpp/examples/server/server.cpp +21 -9
  51. package/src/llama.cpp/examples/tokenize/tokenize.cpp +359 -9
  52. package/src/llama.cpp/examples/train-text-from-scratch/train-text-from-scratch.cpp +4 -3
  53. package/src/llama.cpp/ggml-backend.c +0 -1
  54. package/src/llama.cpp/ggml-common.h +0 -54
  55. package/src/llama.cpp/ggml-cuda.h +1 -0
  56. package/src/llama.cpp/ggml-impl.h +51 -0
  57. package/src/llama.cpp/ggml-kompute.cpp +4 -0
  58. package/src/llama.cpp/ggml-opencl.cpp +4 -1
  59. package/src/llama.cpp/ggml-quants.c +3700 -2041
  60. package/src/llama.cpp/ggml-rpc.cpp +188 -56
  61. package/src/llama.cpp/ggml-sycl.cpp +99 -530
  62. package/src/llama.cpp/ggml-vulkan-shaders.hpp +9351 -5627
  63. package/src/llama.cpp/ggml-vulkan.cpp +202 -225
  64. package/src/llama.cpp/ggml.c +1034 -1154
  65. package/src/llama.cpp/ggml.h +59 -31
  66. package/src/llama.cpp/llama.cpp +859 -609
  67. package/src/llama.cpp/llama.h +19 -6
  68. package/src/llama.cpp/requirements.txt +0 -1
  69. package/src/llama.cpp/tests/test-backend-ops.cpp +113 -47
  70. package/src/llama.cpp/tests/test-chat-template.cpp +16 -4
  71. package/src/llama.cpp/tests/test-grad0.cpp +43 -83
  72. package/src/llama.cpp/unicode-data.cpp +6969 -2169
  73. package/src/llama.cpp/unicode-data.h +15 -12
  74. package/src/llama.cpp/unicode.cpp +89 -111
  75. package/src/llama.cpp/unicode.h +44 -12
  76. package/src/llama.cpp/build.zig +0 -172
  77. package/src/llama.cpp/ggml-mpi.c +0 -216
  78. package/src/llama.cpp/ggml-mpi.h +0 -39
  79. package/src/llama.cpp/requirements/requirements-convert-persimmon-to-gguf.txt +0 -2
@@ -112,6 +112,8 @@ typedef void * thread_ret_t;
112
112
 
113
113
  #endif
114
114
 
115
+ typedef pthread_t ggml_thread_t;
116
+
115
117
  #ifdef GGML_USE_CPU_HBM
116
118
  #include <hbwmalloc.h>
117
119
  #endif
@@ -163,9 +165,6 @@ void ggml_print_backtrace(void) {
163
165
  #define GGML_DEBUG 0
164
166
  #define GGML_GELU_FP16
165
167
  #define GGML_GELU_QUICK_FP16
166
- #define GGML_SILU_FP16
167
- // #define GGML_CROSS_ENTROPY_EXP_FP16
168
- // #define GGML_FLASH_ATTN_EXP_FP16
169
168
 
170
169
  #define GGML_SOFT_MAX_UNROLL 4
171
170
  #define GGML_VEC_DOT_UNROLL 2
@@ -316,12 +315,6 @@ static ggml_fp16_t ggml_table_gelu_f16[1 << 16];
316
315
  // precomputed quick gelu table for f16 (128 KB)
317
316
  static ggml_fp16_t ggml_table_gelu_quick_f16[1 << 16];
318
317
 
319
- // precomputed silu table for f16 (128 KB)
320
- static ggml_fp16_t ggml_table_silu_f16[1 << 16];
321
-
322
- // precomputed exp table for f16 (128 KB)
323
- static ggml_fp16_t ggml_table_exp_f16[1 << 16];
324
-
325
318
  // precomputed f32 table for f16 (256 KB) (ggml-impl.h)
326
319
  float ggml_table_f32_f16[1 << 16];
327
320
 
@@ -413,10 +406,10 @@ void ggml_fp32_to_bf16_row(const float * x, ggml_bf16_t * y, int64_t n) {
413
406
  int i = 0;
414
407
  #if defined(__AVX512BF16__)
415
408
  for (; i + 32 <= n; i += 32) {
416
- _mm512_storeu_ps(
417
- (__m512 *)(y + i),
418
- (__m512)_mm512_cvtne2ps_pbh(_mm512_loadu_ps(x + i + 16),
419
- _mm512_loadu_ps(x + i)));
409
+ _mm512_storeu_si512(
410
+ (__m512i *)(y + i),
411
+ m512i(_mm512_cvtne2ps_pbh(_mm512_loadu_ps(x + i + 16),
412
+ _mm512_loadu_ps(x + i))));
420
413
  }
421
414
  #endif
422
415
  for (; i < n; i++) {
@@ -878,22 +871,14 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
878
871
  },
879
872
  [GGML_TYPE_IQ4_XS] = {
880
873
  .type_name = "iq4_xs",
881
- #if QK_K == 64
882
- .blck_size = QK4_NL,
883
- #else
884
874
  .blck_size = QK_K,
885
- #endif
886
875
  .type_size = sizeof(block_iq4_xs),
887
876
  .is_quantized = true,
888
877
  .to_float = (ggml_to_float_t) dequantize_row_iq4_xs,
889
878
  .from_float = quantize_row_iq4_xs,
890
879
  .from_float_reference = (ggml_from_float_t)quantize_row_iq4_xs_reference,
891
880
  .vec_dot = ggml_vec_dot_iq4_xs_q8_K,
892
- #if QK_K == 64
893
- .vec_dot_type = GGML_TYPE_Q8_0,
894
- #else
895
881
  .vec_dot_type = GGML_TYPE_Q8_K,
896
- #endif
897
882
  .nrows = 1,
898
883
  },
899
884
  [GGML_TYPE_Q8_K] = {
@@ -1306,6 +1291,8 @@ static inline void __avx_f32cx8_store(ggml_fp16_t *x, __m256 y) {
1306
1291
  #define GGML_F16_VEC_ZERO GGML_F32x4_ZERO
1307
1292
  #define GGML_F16_VEC_SET1 GGML_F32x4_SET1
1308
1293
  #define GGML_F16_VEC_FMA GGML_F32x4_FMA
1294
+ #define GGML_F16_VEC_ADD GGML_F32x4_ADD
1295
+ #define GGML_F16_VEC_MUL GGML_F32x4_MUL
1309
1296
  #define GGML_F16_VEC_REDUCE GGML_F32x4_REDUCE
1310
1297
  // Use vec_xl, not vec_ld, in case the load address is not aligned.
1311
1298
  #define GGML_F16_VEC_LOAD(p, i) (i & 0x1) ? \
@@ -1528,6 +1515,195 @@ static inline void __sse_f16x4_store(ggml_fp16_t *x, __m128 y) {
1528
1515
  #define GGML_F16_VEC_MUL GGML_F32Cx4_MUL
1529
1516
  #define GGML_F16_VEC_REDUCE GGML_F32Cx4_REDUCE
1530
1517
 
1518
+ #elif defined(__loongarch_asx)
1519
+
1520
+ #define GGML_SIMD
1521
+
1522
+ // F32 LASX
1523
+ #define GGML_F32_STEP 32
1524
+ #define GGML_F32_EPR 8
1525
+
1526
+ #define GGML_F32x8 __m256
1527
+ #define GGML_F32x8_ZERO (__m256)__lasx_xvldi(0)
1528
+ #define GGML_F32x8_SET1(x) (__m256)__lasx_xvreplfr2vr_s((x))
1529
+ #define GGML_F32x8_LOAD(x) (__m256)__lasx_xvld((x), 0)
1530
+ #define GGML_F32x8_STORE(x,y) __lasx_xvst((y), (x), 0)
1531
+ #define GGML_F32x8_FMA(a, b, c) __lasx_xvfmadd_s(b, c, a)
1532
+ #define GGML_F32x8_ADD __lasx_xvfadd_s
1533
+ #define GGML_F32x8_MUL __lasx_xvfmul_s
1534
+ #define GGML_F32x8_REDUCE(res, x) \
1535
+ do { \
1536
+ int offset = GGML_F32_ARR >> 1; \
1537
+ for (int i = 0; i < offset; ++i) { \
1538
+ x[i] = __lasx_xvfadd_s(x[i], x[offset+i]); \
1539
+ } \
1540
+ offset >>= 1; \
1541
+ for (int i = 0; i < offset; ++i) { \
1542
+ x[i] = __lasx_xvfadd_s(x[i], x[offset+i]); \
1543
+ } \
1544
+ offset >>= 1; \
1545
+ for (int i = 0; i < offset; ++i) { \
1546
+ x[i] = __lasx_xvfadd_s(x[i], x[offset+i]); \
1547
+ } \
1548
+ float *tmp_p = (float *)&x[0]; \
1549
+ res = tmp_p[0] + tmp_p[1] + tmp_p[2] + tmp_p[3] + tmp_p[4] + tmp_p[5] + tmp_p[6] + tmp_p[7]; \
1550
+ } while (0)
1551
+ // TODO: is this optimal ?
1552
+
1553
+ #define GGML_F32_VEC GGML_F32x8
1554
+ #define GGML_F32_VEC_ZERO GGML_F32x8_ZERO
1555
+ #define GGML_F32_VEC_SET1 GGML_F32x8_SET1
1556
+ #define GGML_F32_VEC_LOAD GGML_F32x8_LOAD
1557
+ #define GGML_F32_VEC_STORE GGML_F32x8_STORE
1558
+ #define GGML_F32_VEC_FMA GGML_F32x8_FMA
1559
+ #define GGML_F32_VEC_ADD GGML_F32x8_ADD
1560
+ #define GGML_F32_VEC_MUL GGML_F32x8_MUL
1561
+ #define GGML_F32_VEC_REDUCE GGML_F32x8_REDUCE
1562
+
1563
+ // F16 LASX
1564
+
1565
+ #define GGML_F16_STEP 32
1566
+ #define GGML_F16_EPR 8
1567
+
1568
+ // F16 arithmetic is not supported by AVX, so we use F32 instead
1569
+
1570
+ #define GGML_F32Cx8 __m256
1571
+ #define GGML_F32Cx8_ZERO (__m256)__lasx_xvldi(0)
1572
+ #define GGML_F32Cx8_SET1(x) (__m256)__lasx_xvreplgr2vr_w((x))
1573
+
1574
+ static inline __m256 __lasx_f32cx8_load(ggml_fp16_t *x) {
1575
+ float tmp[8];
1576
+
1577
+ for (int i = 0; i < 8; i++) {
1578
+ tmp[i] = GGML_FP16_TO_FP32(x[i]);
1579
+ }
1580
+
1581
+ return (__m256)__lasx_xvld(tmp, 0);
1582
+ }
1583
+ static inline void __lasx_f32cx8_store(ggml_fp16_t *x, __m256 y) {
1584
+ float arr[8];
1585
+
1586
+ __lasx_xvst(y, arr, 0);
1587
+
1588
+ for (int i = 0; i < 8; i++)
1589
+ x[i] = GGML_FP32_TO_FP16(arr[i]);
1590
+ }
1591
+ #define GGML_F32Cx8_LOAD(x) __lasx_f32cx8_load(x)
1592
+ #define GGML_F32Cx8_STORE(x, y) __lasx_f32cx8_store(x, y)
1593
+
1594
+ #define GGML_F32Cx8_FMA GGML_F32x8_FMA
1595
+ #define GGML_F32Cx8_ADD __lasx_xvfadd_s
1596
+ #define GGML_F32Cx8_MUL __lasx_xvfmul_s
1597
+ #define GGML_F32Cx8_REDUCE GGML_F32x8_REDUCE
1598
+
1599
+ #define GGML_F16_VEC GGML_F32Cx8
1600
+ #define GGML_F16_VEC_ZERO GGML_F32Cx8_ZERO
1601
+ #define GGML_F16_VEC_SET1 GGML_F32Cx8_SET1
1602
+ #define GGML_F16_VEC_LOAD(p, i) GGML_F32Cx8_LOAD(p)
1603
+ #define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx8_STORE(p, r[i])
1604
+ #define GGML_F16_VEC_FMA GGML_F32Cx8_FMA
1605
+ #define GGML_F16_VEC_ADD GGML_F32Cx8_ADD
1606
+ #define GGML_F16_VEC_MUL GGML_F32Cx8_MUL
1607
+ #define GGML_F16_VEC_REDUCE GGML_F32Cx8_REDUCE
1608
+
1609
+ #elif defined(__loongarch_sx)
1610
+
1611
+ #define GGML_SIMD
1612
+
1613
+ // F32 LSX
1614
+
1615
+ #define GGML_F32_STEP 32
1616
+ #define GGML_F32_EPR 4
1617
+
1618
+ #define GGML_F32x4 __m128
1619
+ #define GGML_F32x4_ZERO __lsx_vldi(0)
1620
+ #define GGML_F32x4_SET1(x) __lsx_vinsgr2vr_w(__lsx_vldi(0),(x), 0)
1621
+ #define GGML_F32x4_LOAD(x) __lsx_vld((x), 0)
1622
+ #define GGML_F32x4_STORE((x),(y)) __lsx_vst((y), (x), 0)
1623
+ #define GGML_F32x4_FMA(a, b, c) __lsx_vfmadd_s(b, c, a)
1624
+ #define GGML_F32x4_ADD __lsx_vfadd_s
1625
+ #define GGML_F32x4_MUL __lsx_vfmul_s
1626
+ #define GGML_F32x4_REDUCE(res, x) \
1627
+ { \
1628
+ int offset = GGML_F32_ARR >> 1; \
1629
+ for (int i = 0; i < offset; ++i) { \
1630
+ x[i] = __lsx_vfadd_s(x[i], x[offset+i]); \
1631
+ } \
1632
+ offset >>= 1; \
1633
+ for (int i = 0; i < offset; ++i) { \
1634
+ x[i] = __lsx_vfadd_s(x[i], x[offset+i]); \
1635
+ } \
1636
+ offset >>= 1; \
1637
+ for (int i = 0; i < offset; ++i) { \
1638
+ x[i] = __lsx_vfadd_s(x[i], x[offset+i]); \
1639
+ } \
1640
+ __m128i tmp = __lsx_vsrli_d((__m128i)x[0], 32); \
1641
+ tmp = (__m128i)__lsx_vfadd_s((__m128)tmp, x[0]); \
1642
+ tmp = __lsx_vpickev_w(__lsx_vldi(0), tmp); \
1643
+ const __m128 t0 = __lsx_vshuf4i_w(tmp, 0x88); \
1644
+ tmp = __lsx_vsrli_d((__m128i)t0, 32); \
1645
+ tmp = (__m128i)__lsx_vfadd_s((__m128)tmp, t0); \
1646
+ tmp = __lsx_vpickev_w(__lsx_vldi(0), tmp); \
1647
+ res = (ggml_float) __lsx_vpickve2gr_w(__lsx_vshuf4i_w(tmp, 0x88), 0); \
1648
+ }
1649
+
1650
+ #define GGML_F32_VEC GGML_F32x4
1651
+ #define GGML_F32_VEC_ZERO GGML_F32x4_ZERO
1652
+ #define GGML_F32_VEC_SET1 GGML_F32x4_SET1
1653
+ #define GGML_F32_VEC_LOAD GGML_F32x4_LOAD
1654
+ #define GGML_F32_VEC_STORE GGML_F32x4_STORE
1655
+ #define GGML_F32_VEC_FMA GGML_F32x4_FMA
1656
+ #define GGML_F32_VEC_ADD GGML_F32x4_ADD
1657
+ #define GGML_F32_VEC_MUL GGML_F32x4_MUL
1658
+ #define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE
1659
+
1660
+ // F16 LSX
1661
+
1662
+ #define GGML_F16_STEP 32
1663
+ #define GGML_F16_EPR 4
1664
+
1665
+ static inline __m128 __lsx_f16x4_load(ggml_fp16_t *x) {
1666
+ float tmp[4];
1667
+
1668
+ tmp[0] = GGML_FP16_TO_FP32(x[0]);
1669
+ tmp[1] = GGML_FP16_TO_FP32(x[1]);
1670
+ tmp[2] = GGML_FP16_TO_FP32(x[2]);
1671
+ tmp[3] = GGML_FP16_TO_FP32(x[3]);
1672
+
1673
+ return __lsx_vld(tmp, 0);
1674
+ }
1675
+
1676
+ static inline void __lsx_f16x4_store(ggml_fp16_t *x, __m128 y) {
1677
+ float arr[4];
1678
+
1679
+ __lsx_vst(y, arr, 0);
1680
+
1681
+ x[0] = GGML_FP32_TO_FP16(arr[0]);
1682
+ x[1] = GGML_FP32_TO_FP16(arr[1]);
1683
+ x[2] = GGML_FP32_TO_FP16(arr[2]);
1684
+ x[3] = GGML_FP32_TO_FP16(arr[3]);
1685
+ }
1686
+
1687
+ #define GGML_F32Cx4 __m128
1688
+ #define GGML_F32Cx4_ZERO __lsx_vldi(0)
1689
+ #define GGML_F32Cx4_SET1(x) __lsx_vinsgr2vr_w(__lsx_vldi(0),(x), 0)
1690
+ #define GGML_F32Cx4_LOAD(x) __lsx_f16x4_load(x)
1691
+ #define GGML_F32Cx4_STORE(x, y) __lsx_f16x4_store(x, y)
1692
+ #define GGML_F32Cx4_FMA GGML_F32x4_FMA
1693
+ #define GGML_F32Cx4_ADD __lsx_vfadd_s
1694
+ #define GGML_F32Cx4_MUL __lsx_vfmul_s
1695
+ #define GGML_F32Cx4_REDUCE GGML_F32x4_REDUCE
1696
+
1697
+ #define GGML_F16_VEC GGML_F32Cx4
1698
+ #define GGML_F16_VEC_ZERO GGML_F32Cx4_ZERO
1699
+ #define GGML_F16_VEC_SET1 GGML_F32Cx4_SET1
1700
+ #define GGML_F16_VEC_LOAD(p, i) GGML_F32Cx4_LOAD(p)
1701
+ #define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx4_STORE(p, r[i])
1702
+ #define GGML_F16_VEC_FMA GGML_F32Cx4_FMA
1703
+ #define GGML_F16_VEC_ADD GGML_F32Cx4_ADD
1704
+ #define GGML_F16_VEC_MUL GGML_F32Cx4_MUL
1705
+ #define GGML_F16_VEC_REDUCE GGML_F32Cx4_REDUCE
1706
+
1531
1707
  #endif
1532
1708
 
1533
1709
  // GGML_F32_ARR / GGML_F16_ARR
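Note: the new LoongArch LASX/LSX blocks above only have to supply the GGML_F32x*/GGML_F16_VEC macro set; ggml.c's generic vector loops are written against those macros. A rough sketch of how such a loop consumes them (simplified from ggml's own f32 dot product; not standalone, it assumes the macro definitions above):

    // simplified from ggml_vec_dot_f32: backend-specific macros, shared loop shape
    static float vec_dot_f32_sketch(const int n, const float * x, const float * y) {
        float sumf = 0.0f;
    #if defined(GGML_SIMD)
        const int np = (n & ~(GGML_F32_STEP - 1));

        GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };

        for (int i = 0; i < np; i += GGML_F32_STEP) {
            for (int j = 0; j < GGML_F32_ARR; j++) {
                GGML_F32_VEC ax = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR);
                GGML_F32_VEC ay = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR);
                sum[j] = GGML_F32_VEC_FMA(sum[j], ax, ay);
            }
        }

        // reduce the GGML_F32_ARR partial accumulators into a scalar
        GGML_F32_VEC_REDUCE(sumf, sum);

        // leftover elements that do not fill a full SIMD step
        for (int i = np; i < n; ++i) {
            sumf += x[i]*y[i];
        }
    #else
        for (int i = 0; i < n; ++i) {
            sumf += x[i]*y[i];
        }
    #endif
        return sumf;
    }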
@@ -1537,6 +1713,59 @@ static inline void __sse_f16x4_store(ggml_fp16_t *x, __m128 y) {
1537
1713
  #define GGML_F16_ARR (GGML_F16_STEP/GGML_F16_EPR)
1538
1714
  #endif
1539
1715
 
1716
+ //
1717
+ // ggml context
1718
+ //
1719
+
1720
+ struct ggml_context {
1721
+ size_t mem_size;
1722
+ void* mem_buffer;
1723
+ bool mem_buffer_owned;
1724
+ bool no_alloc;
1725
+ bool no_alloc_save; // this is used to save the no_alloc state when using scratch buffers
1726
+
1727
+ int n_objects;
1728
+
1729
+ struct ggml_object* objects_begin;
1730
+ struct ggml_object* objects_end;
1731
+
1732
+ struct ggml_scratch scratch;
1733
+ struct ggml_scratch scratch_save;
1734
+ };
1735
+
1736
+ struct ggml_context_container {
1737
+ bool used;
1738
+
1739
+ struct ggml_context context;
1740
+ };
1741
+
1742
+ struct ggml_compute_state_shared {
1743
+ const struct ggml_cgraph* cgraph;
1744
+ const struct ggml_cplan* cplan;
1745
+
1746
+ int64_t perf_node_start_cycles;
1747
+ int64_t perf_node_start_time_us;
1748
+
1749
+ const int n_threads;
1750
+
1751
+ // synchronization primitives
1752
+ atomic_int n_active; // num active threads
1753
+ atomic_int node_n; // active graph node
1754
+ atomic_int node_task; // active graph node task phase
1755
+
1756
+ ggml_abort_callback abort_callback; // abort ggml_graph_compute when true
1757
+ void* abort_callback_data;
1758
+
1759
+ atomic_int current_chunk; // currently processing chunk during Mat_Mul, shared between all the threads.
1760
+ };
1761
+
1762
+ struct ggml_compute_state {
1763
+ ggml_thread_t thrd;
1764
+ int ith;
1765
+ struct ggml_compute_state_shared* shared;
1766
+ enum ggml_status ec;
1767
+ };
1768
+
1540
1769
  //
1541
1770
  // fundamental operations
1542
1771
  //
@@ -1618,10 +1847,10 @@ static void ggml_vec_dot_bf16(int n, float * restrict s, size_t bs, ggml_bf16_t
1618
1847
  __m512 c1 = _mm512_setzero_ps();
1619
1848
  __m512 c2 = _mm512_setzero_ps();
1620
1849
  for (; i + 64 <= n; i += 64) {
1621
- c1 = _mm512_dpbf16_ps(c1, (__m512bh)_mm512_loadu_ps((const float *)(x + i)),
1622
- (__m512bh)_mm512_loadu_ps((const float *)(y + i)));
1623
- c2 = _mm512_dpbf16_ps(c2, (__m512bh)_mm512_loadu_ps((const float *)(x + i + 32)),
1624
- (__m512bh)_mm512_loadu_ps((const float *)(y + i + 32)));
1850
+ c1 = _mm512_dpbf16_ps(c1, m512bh(_mm512_loadu_si512((x + i))),
1851
+ m512bh(_mm512_loadu_si512((y + i))));
1852
+ c2 = _mm512_dpbf16_ps(c2, m512bh(_mm512_loadu_si512((x + i + 32))),
1853
+ m512bh(_mm512_loadu_si512((y + i + 32))));
1625
1854
  }
1626
1855
  sumf += (ggml_float)_mm512_reduce_add_ps(c1);
1627
1856
  sumf += (ggml_float)_mm512_reduce_add_ps(c2);
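For readers unfamiliar with bf16: the AVX512-BF16 path above computes an ordinary dot product, just with bf16 inputs. A minimal scalar reference, assuming nothing beyond the fact that a bf16 value is the upper 16 bits of the corresponding IEEE-754 float (plain uint16_t is used here instead of ggml_bf16_t):

    #include <stdint.h>
    #include <string.h>

    // a bf16 value is the high 16 bits of a float32; shift it back and reinterpret
    static inline float bf16_to_f32(uint16_t h) {
        uint32_t u = (uint32_t) h << 16;
        float f;
        memcpy(&f, &u, sizeof(f));
        return f;
    }

    // scalar reference for the vectorized bf16 dot product above
    static float vec_dot_bf16_ref(int n, const uint16_t * x, const uint16_t * y) {
        double sum = 0.0; // accumulate in double, as the surrounding code does via ggml_float
        for (int i = 0; i < n; ++i) {
            sum += (double) (bf16_to_f32(x[i]) * bf16_to_f32(y[i]));
        }
        return (float) sum;
    }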
@@ -2028,52 +2257,291 @@ inline static float ggml_silu_f32(float x) {
2028
2257
  return x/(1.0f + expf(-x));
2029
2258
  }
2030
2259
 
2031
- //inline static void ggml_vec_silu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
2032
- // const uint16_t * i16 = (const uint16_t *) x;
2033
- // for (int i = 0; i < n; ++i) {
2034
- // y[i] = ggml_table_silu_f16[i16[i]];
2035
- // }
2036
- //}
2260
+ #if defined(__ARM_NEON) && defined(__aarch64__)
2261
+
2262
+ // adapted from arm limited optimized routine
2263
+ // the maximum error is 1.45358 plus 0.5 ulps
2264
+ // numbers above 88.38 will flush to infinity
2265
+ // numbers beneath -103.97 will flush to zero
2266
+ inline static float32x4_t ggml_v_expf(float32x4_t x) {
2267
+ const float32x4_t r = vdupq_n_f32(0x1.8p23f);
2268
+ const float32x4_t z = vfmaq_f32(r, x, vdupq_n_f32(0x1.715476p+0f));
2269
+ const float32x4_t n = vsubq_f32(z, r);
2270
+ const float32x4_t b = vfmsq_f32(vfmsq_f32(x, n, vdupq_n_f32(0x1.62e4p-1f)), n,
2271
+ vdupq_n_f32(0x1.7f7d1cp-20f));
2272
+ const uint32x4_t e = vshlq_n_u32(vreinterpretq_u32_f32(z), 23);
2273
+ const float32x4_t k = vreinterpretq_f32_u32(vaddq_u32(e, vreinterpretq_u32_f32(vdupq_n_f32(1))));
2274
+ const uint32x4_t c = vcagtq_f32(n, vdupq_n_f32(126));
2275
+ const float32x4_t u = vmulq_f32(b, b);
2276
+ const float32x4_t j = vfmaq_f32(
2277
+ vmulq_f32(vdupq_n_f32(0x1.ffffecp-1f), b),
2278
+ vfmaq_f32(vfmaq_f32(vdupq_n_f32(0x1.fffdb6p-2f), vdupq_n_f32(0x1.555e66p-3f), b),
2279
+ vfmaq_f32(vdupq_n_f32(0x1.573e2ep-5f), vdupq_n_f32(0x1.0e4020p-7f), b), u), u);
2280
+ if (!vpaddd_u64(vreinterpretq_u64_u32(c)))
2281
+ return vfmaq_f32(k, j, k);
2282
+ const uint32x4_t d = vandq_u32(vclezq_f32(n), vdupq_n_u32(0x82000000));
2283
+ const float32x4_t s1 = vreinterpretq_f32_u32(vaddq_u32(d, vdupq_n_u32(0x7f000000)));
2284
+ const float32x4_t s2 = vreinterpretq_f32_u32(vsubq_u32(e, d));
2285
+ return vbslq_f32(vcagtq_f32(n, vdupq_n_f32(192)), vmulq_f32(s1, s1),
2286
+ vbslq_f32(c, vmulq_f32(vfmaq_f32(s2, s2, j), s1), vfmaq_f32(k, k, j)));
2287
+ }
2288
+
2289
+ // computes silu x/(1+exp(-x)) in single precision vector
2290
+ inline static float32x4_t ggml_v_silu(float32x4_t x) {
2291
+ const float32x4_t one = vdupq_n_f32(1.0f);
2292
+ const float32x4_t zero = vdupq_n_f32(0.0f);
2293
+ const float32x4_t neg_x = vsubq_f32(zero, x);
2294
+ const float32x4_t exp_neg_x = ggml_v_expf(neg_x);
2295
+ const float32x4_t one_plus_exp_neg_x = vaddq_f32(one, exp_neg_x);
2296
+ return vdivq_f32(x, one_plus_exp_neg_x);
2297
+ }
2298
+
2299
+ #elif defined(__AVX512F__) && defined(__AVX512DQ__)
2300
+
2301
+ // adapted from arm limited optimized routine
2302
+ // the maximum error is 1.45358 plus 0.5 ulps
2303
+ // numbers above 88.38 will flush to infinity
2304
+ // numbers beneath -103.97 will flush to zero
2305
+ inline static __m512 ggml_v_expf(__m512 x) {
2306
+ const __m512 r = _mm512_set1_ps(0x1.8p23f);
2307
+ const __m512 z = _mm512_fmadd_ps(x, _mm512_set1_ps(0x1.715476p+0f), r);
2308
+ const __m512 n = _mm512_sub_ps(z, r);
2309
+ const __m512 b = _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.7f7d1cp-20f),
2310
+ _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.62e4p-1f), x));
2311
+ const __m512i e = _mm512_slli_epi32(_mm512_castps_si512(z), 23);
2312
+ const __m512 k = _mm512_castsi512_ps(_mm512_add_epi32(e, _mm512_castps_si512(_mm512_set1_ps(1))));
2313
+ const __mmask16 c = _mm512_cmp_ps_mask(_mm512_abs_ps(n), _mm512_set1_ps(126), _CMP_GT_OQ);
2314
+ const __m512 u = _mm512_mul_ps(b, b);
2315
+ const __m512 j = _mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_set1_ps(0x1.0e4020p-7f), b,
2316
+ _mm512_set1_ps(0x1.573e2ep-5f)), u,
2317
+ _mm512_fmadd_ps(_mm512_set1_ps(0x1.555e66p-3f), b,
2318
+ _mm512_set1_ps(0x1.fffdb6p-2f))),
2319
+ u, _mm512_mul_ps(_mm512_set1_ps(0x1.ffffecp-1f), b));
2320
+ if (_mm512_kortestz(c, c))
2321
+ return _mm512_fmadd_ps(j, k, k);
2322
+ const __m512i g = _mm512_and_si512(
2323
+ _mm512_movm_epi32(_mm512_cmp_ps_mask(n, _mm512_setzero_ps(), _CMP_LE_OQ)),
2324
+ _mm512_set1_epi32(0x82000000u));
2325
+ const __m512 s1 =
2326
+ _mm512_castsi512_ps(_mm512_add_epi32(g, _mm512_set1_epi32(0x7f000000u)));
2327
+ const __m512 s2 = _mm512_castsi512_ps(_mm512_sub_epi32(e, g));
2328
+ const __mmask16 d =
2329
+ _mm512_cmp_ps_mask(_mm512_abs_ps(n), _mm512_set1_ps(192), _CMP_GT_OQ);
2330
+ return _mm512_mask_blend_ps(
2331
+ d, _mm512_mask_blend_ps(
2332
+ c, _mm512_fmadd_ps(k, j, k),
2333
+ _mm512_mul_ps(_mm512_fmadd_ps(s2, j, s2), s1)),
2334
+ _mm512_mul_ps(s1, s1));
2335
+ }
2336
+
2337
+ // computes silu x/(1+exp(-x)) in single precision vector
2338
+ inline static __m512 ggml_v_silu(__m512 x) {
2339
+ const __m512 one = _mm512_set1_ps(1);
2340
+ const __m512 zero = _mm512_setzero_ps();
2341
+ const __m512 neg_x = _mm512_sub_ps(zero, x);
2342
+ const __m512 exp_neg_x = ggml_v_expf(neg_x);
2343
+ const __m512 one_plus_exp_neg_x = _mm512_add_ps(one, exp_neg_x);
2344
+ return _mm512_div_ps(x, one_plus_exp_neg_x);
2345
+ }
2346
+
2347
+ #elif defined(__AVX2__) && defined(__FMA__)
2348
+
2349
+ // adapted from arm limited optimized routine
2350
+ // the maximum error is 1.45358 plus 0.5 ulps
2351
+ // numbers above 88.38 will flush to infinity
2352
+ // numbers beneath -103.97 will flush to zero
2353
+ inline static __m256 ggml_v_expf(__m256 x) {
2354
+ const __m256 r = _mm256_set1_ps(0x1.8p23f);
2355
+ const __m256 z = _mm256_fmadd_ps(x, _mm256_set1_ps(0x1.715476p+0f), r);
2356
+ const __m256 n = _mm256_sub_ps(z, r);
2357
+ const __m256 b = _mm256_fnmadd_ps(n, _mm256_set1_ps(0x1.7f7d1cp-20f),
2358
+ _mm256_fnmadd_ps(n, _mm256_set1_ps(0x1.62e4p-1f), x));
2359
+ const __m256i e = _mm256_slli_epi32(_mm256_castps_si256(z), 23);
2360
+ const __m256 k = _mm256_castsi256_ps(
2361
+ _mm256_add_epi32(e, _mm256_castps_si256(_mm256_set1_ps(1))));
2362
+ const __m256i c = _mm256_castps_si256(
2363
+ _mm256_cmp_ps(_mm256_andnot_ps(_mm256_set1_ps(-0.f), n),
2364
+ _mm256_set1_ps(126), _CMP_GT_OQ));
2365
+ const __m256 u = _mm256_mul_ps(b, b);
2366
+ const __m256 j = _mm256_fmadd_ps(_mm256_fmadd_ps(_mm256_fmadd_ps(_mm256_set1_ps(0x1.0e4020p-7f), b,
2367
+ _mm256_set1_ps(0x1.573e2ep-5f)), u,
2368
+ _mm256_fmadd_ps(_mm256_set1_ps(0x1.555e66p-3f), b,
2369
+ _mm256_set1_ps(0x1.fffdb6p-2f))),
2370
+ u, _mm256_mul_ps(_mm256_set1_ps(0x1.ffffecp-1f), b));
2371
+ if (!_mm256_movemask_ps(_mm256_castsi256_ps(c)))
2372
+ return _mm256_fmadd_ps(j, k, k);
2373
+ const __m256i g = _mm256_and_si256(
2374
+ _mm256_castps_si256(_mm256_cmp_ps(n, _mm256_setzero_ps(), _CMP_LE_OQ)),
2375
+ _mm256_set1_epi32(0x82000000u));
2376
+ const __m256 s1 =
2377
+ _mm256_castsi256_ps(_mm256_add_epi32(g, _mm256_set1_epi32(0x7f000000u)));
2378
+ const __m256 s2 = _mm256_castsi256_ps(_mm256_sub_epi32(e, g));
2379
+ const __m256i d = _mm256_castps_si256(
2380
+ _mm256_cmp_ps(_mm256_andnot_ps(_mm256_set1_ps(-0.f), n),
2381
+ _mm256_set1_ps(192), _CMP_GT_OQ));
2382
+ return _mm256_or_ps(
2383
+ _mm256_and_ps(_mm256_castsi256_ps(d), _mm256_mul_ps(s1, s1)),
2384
+ _mm256_andnot_ps(
2385
+ _mm256_castsi256_ps(d),
2386
+ _mm256_or_ps(
2387
+ _mm256_and_ps(_mm256_castsi256_ps(c),
2388
+ _mm256_mul_ps(_mm256_fmadd_ps(s2, j, s2), s1)),
2389
+ _mm256_andnot_ps(_mm256_castsi256_ps(c), _mm256_fmadd_ps(k, j, k)))));
2390
+ }
2391
+
2392
+ // computes silu x/(1+exp(-x)) in single precision vector
2393
+ inline static __m256 ggml_v_silu(__m256 x) {
2394
+ const __m256 one = _mm256_set1_ps(1);
2395
+ const __m256 zero = _mm256_setzero_ps();
2396
+ const __m256 neg_x = _mm256_sub_ps(zero, x);
2397
+ const __m256 exp_neg_x = ggml_v_expf(neg_x);
2398
+ const __m256 one_plus_exp_neg_x = _mm256_add_ps(one, exp_neg_x);
2399
+ return _mm256_div_ps(x, one_plus_exp_neg_x);
2400
+ }
2401
+
2402
+ #elif defined(__SSE2__) // __AVX2__ / __ARM_NEON
2037
2403
 
2038
- #ifdef GGML_SILU_FP16
2039
- inline static void ggml_vec_silu_f32(const int n, float * y, const float * x) {
2040
- uint16_t t;
2041
- for (int i = 0; i < n; ++i) {
2042
- ggml_fp16_t fp16 = GGML_FP32_TO_FP16(x[i]);
2043
- memcpy(&t, &fp16, sizeof(uint16_t));
2044
- y[i] = GGML_FP16_TO_FP32(ggml_table_silu_f16[t]);
2045
- }
2046
- }
2404
+ #if defined(__FMA__)
2405
+ #define MADD128(x, y, z) _mm_fmadd_ps(x, y, z)
2406
+ #define NMADD128(x, y, z) _mm_fnmadd_ps(x, y, z)
2047
2407
  #else
2048
- inline static void ggml_vec_silu_f32(const int n, float * y, const float * x) {
2049
- for (int i = 0; i < n; ++i) {
2408
+ #define MADD128(x, y, z) _mm_add_ps(_mm_mul_ps(x, y), z)
2409
+ #define NMADD128(x, y, z) _mm_sub_ps(z, _mm_mul_ps(x, y))
2410
+ #endif
2411
+
2412
+ // adapted from arm limited optimized routine
2413
+ // the maximum error is 1.45358 plus 0.5 ulps
2414
+ // numbers above 88.38 will flush to infinity
2415
+ // numbers beneath -103.97 will flush to zero
2416
+ inline static __m128 ggml_v_expf(__m128 x) {
2417
+ const __m128 r = _mm_set1_ps(0x1.8p23f);
2418
+ const __m128 z = MADD128(x, _mm_set1_ps(0x1.715476p+0f), r);
2419
+ const __m128 n = _mm_sub_ps(z, r);
2420
+ const __m128 b =
2421
+ NMADD128(n, _mm_set1_ps(0x1.7f7d1cp-20f), NMADD128(n, _mm_set1_ps(0x1.62e4p-1f), x));
2422
+ const __m128i e = _mm_slli_epi32(_mm_castps_si128(z), 23);
2423
+ const __m128 k = _mm_castsi128_ps(_mm_add_epi32(e, _mm_castps_si128(_mm_set1_ps(1))));
2424
+ const __m128i c =
2425
+ _mm_castps_si128(_mm_cmpgt_ps(_mm_andnot_ps(_mm_set1_ps(-0.f), n), _mm_set1_ps(126)));
2426
+ const __m128 u = _mm_mul_ps(b, b);
2427
+ const __m128 j =
2428
+ MADD128(MADD128(MADD128(_mm_set1_ps(0x1.0e4020p-7f), b, _mm_set1_ps(0x1.573e2ep-5f)), u,
2429
+ MADD128(_mm_set1_ps(0x1.555e66p-3f), b, _mm_set1_ps(0x1.fffdb6p-2f))),
2430
+ u, _mm_mul_ps(_mm_set1_ps(0x1.ffffecp-1f), b));
2431
+ if (!_mm_movemask_epi8(c))
2432
+ return MADD128(j, k, k);
2433
+ const __m128i g = _mm_and_si128(_mm_castps_si128(_mm_cmple_ps(n, _mm_setzero_ps())),
2434
+ _mm_set1_epi32(0x82000000u));
2435
+ const __m128 s1 = _mm_castsi128_ps(_mm_add_epi32(g, _mm_set1_epi32(0x7f000000u)));
2436
+ const __m128 s2 = _mm_castsi128_ps(_mm_sub_epi32(e, g));
2437
+ const __m128i d =
2438
+ _mm_castps_si128(_mm_cmpgt_ps(_mm_andnot_ps(_mm_set1_ps(-0.f), n), _mm_set1_ps(192)));
2439
+ return _mm_or_ps(
2440
+ _mm_and_ps(_mm_castsi128_ps(d), _mm_mul_ps(s1, s1)),
2441
+ _mm_andnot_ps(_mm_castsi128_ps(d),
2442
+ _mm_or_ps(_mm_and_ps(_mm_castsi128_ps(c), _mm_mul_ps(MADD128(s2, j, s2), s1)),
2443
+ _mm_andnot_ps(_mm_castsi128_ps(c), MADD128(k, j, k)))));
2444
+ }
2445
+
2446
+ // computes silu x/(1+exp(-x)) in single precision vector
2447
+ inline static __m128 ggml_v_silu(__m128 x) {
2448
+ const __m128 one = _mm_set1_ps(1);
2449
+ const __m128 zero = _mm_setzero_ps();
2450
+ const __m128 neg_x = _mm_sub_ps(zero, x);
2451
+ const __m128 exp_neg_x = ggml_v_expf(neg_x);
2452
+ const __m128 one_plus_exp_neg_x = _mm_add_ps(one, exp_neg_x);
2453
+ return _mm_div_ps(x, one_plus_exp_neg_x);
2454
+ }
2455
+
2456
+ #endif // __ARM_NEON / __AVX2__ / __SSE2__
2457
+
2458
+ static void ggml_vec_silu_f32(const int n, float * y, const float * x) {
2459
+ int i = 0;
2460
+ #if defined(__AVX512F__) && defined(__AVX512DQ__)
2461
+ for (; i + 15 < n; i += 16) {
2462
+ _mm512_storeu_ps(y + i, ggml_v_silu(_mm512_loadu_ps(x + i)));
2463
+ }
2464
+ #elif defined(__AVX2__) && defined(__FMA__)
2465
+ for (; i + 7 < n; i += 8) {
2466
+ _mm256_storeu_ps(y + i, ggml_v_silu(_mm256_loadu_ps(x + i)));
2467
+ }
2468
+ #elif defined(__SSE2__)
2469
+ for (; i + 3 < n; i += 4) {
2470
+ _mm_storeu_ps(y + i, ggml_v_silu(_mm_loadu_ps(x + i)));
2471
+ }
2472
+ #elif defined(__ARM_NEON) && defined(__aarch64__)
2473
+ for (; i + 3 < n; i += 4) {
2474
+ vst1q_f32(y + i, ggml_v_silu(vld1q_f32(x + i)));
2475
+ }
2476
+ #endif
2477
+ for (; i < n; ++i) {
2050
2478
  y[i] = ggml_silu_f32(x[i]);
2051
2479
  }
2052
2480
  }
2481
+
2482
+ static ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x, float max) {
2483
+ int i = 0;
2484
+ ggml_float sum = 0;
2485
+ #if defined(__AVX512F__) && defined(__AVX512DQ__)
2486
+ for (; i + 15 < n; i += 16) {
2487
+ __m512 val = ggml_v_expf(_mm512_sub_ps(_mm512_loadu_ps(x + i),
2488
+ _mm512_set1_ps(max)));
2489
+ _mm512_storeu_ps(y + i, val);
2490
+ sum += (ggml_float)_mm512_reduce_add_ps(val);
2491
+ }
2492
+ #elif defined(__AVX2__) && defined(__FMA__)
2493
+ for (; i + 7 < n; i += 8) {
2494
+ __m256 val = ggml_v_expf(_mm256_sub_ps(_mm256_loadu_ps(x + i),
2495
+ _mm256_set1_ps(max)));
2496
+ _mm256_storeu_ps(y + i, val);
2497
+ __m128 val2 = _mm_add_ps(_mm256_extractf128_ps(val, 1),
2498
+ _mm256_castps256_ps128(val));
2499
+ val2 = _mm_add_ps(val2, _mm_movehl_ps(val2, val2));
2500
+ val2 = _mm_add_ss(val2, _mm_movehdup_ps(val2));
2501
+ sum += (ggml_float)_mm_cvtss_f32(val2);
2502
+ }
2503
+ #elif defined(__SSE2__)
2504
+ for (; i + 3 < n; i += 4) {
2505
+ __m128 val = ggml_v_expf(_mm_sub_ps(_mm_loadu_ps(x + i),
2506
+ _mm_set1_ps(max)));
2507
+ _mm_storeu_ps(y + i, val);
2508
+ #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
2509
+ val = _mm_add_ps(val, _mm_movehl_ps(val, val));
2510
+ val = _mm_add_ss(val, _mm_movehdup_ps(val));
2511
+ #else
2512
+ __m128 tmp = _mm_shuffle_ps(val, val, _MM_SHUFFLE(2, 3, 0, 1));
2513
+ val = _mm_add_ps(val, tmp);
2514
+ tmp = _mm_movehl_ps(tmp, val);
2515
+ val = _mm_add_ss(val, tmp);
2516
+ #endif
2517
+ sum += (ggml_float)_mm_cvtss_f32(val);
2518
+ }
2519
+ #elif defined(__ARM_NEON) && defined(__aarch64__)
2520
+ for (; i + 3 < n; i += 4) {
2521
+ float32x4_t val = ggml_v_expf(vsubq_f32(vld1q_f32(x + i),
2522
+ vdupq_n_f32(max)));
2523
+ vst1q_f32(y + i, val);
2524
+ sum += (ggml_float)vaddvq_f32(val);
2525
+ }
2053
2526
  #endif
2527
+ for (; i < n; ++i) {
2528
+ float val = expf(x[i] - max);
2529
+ sum += (ggml_float)val;
2530
+ y[i] = val;
2531
+ }
2532
+ return sum;
2533
+ }
2054
2534
 
2055
2535
  inline static float ggml_silu_backward_f32(float x, float dy) {
2056
2536
  const float s = 1.0f/(1.0f + expf(-x));
2057
2537
  return dy*s*(1.0f + x*(1.0f - s));
2058
2538
  }
2059
2539
 
2060
- #ifdef GGML_SILU_FP16
2061
- inline static void ggml_vec_silu_backward_f32(const int n, float * dx, const float * x, const float * dy) {
2062
- for (int i = 0; i < n; ++i) {
2063
- // we did not use x[i] to compute forward silu but its f16 equivalent
2064
- // take derivative at f16 of x[i]:
2065
- ggml_fp16_t fp16 = GGML_FP32_TO_FP16(x[i]);
2066
- float usedx = GGML_FP16_TO_FP32(fp16);
2067
- dx[i] = ggml_silu_backward_f32(usedx, dy[i]);
2068
- }
2069
- }
2070
- #else
2071
2540
  inline static void ggml_vec_silu_backward_f32(const int n, float * dx, const float * x, const float * dy) {
2072
2541
  for (int i = 0; i < n; ++i) {
2073
2542
  dx[i] = ggml_silu_backward_f32(x[i], dy[i]);
2074
2543
  }
2075
2544
  }
2076
- #endif
2077
2545
 
2078
2546
  inline static void ggml_vec_sum_f32(const int n, float * s, const float * x) {
2079
2547
  #ifndef GGML_USE_ACCELERATE
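All of the architecture-specific branches above (NEON, AVX512, AVX2, SSE2) implement the same scalar math as the trailing fallback loops. A standalone reference of that math, useful when checking a new SIMD branch (double is used for the running sum, matching ggml_float):

    #include <math.h>

    // silu(x) = x / (1 + exp(-x)) -- what ggml_vec_silu_f32 computes per element
    static void vec_silu_ref(const int n, float * y, const float * x) {
        for (int i = 0; i < n; ++i) {
            y[i] = x[i] / (1.0f + expf(-x[i]));
        }
    }

    // writes exp(x[i] - max) into y and returns the sum of the written values,
    // mirroring ggml_vec_soft_max_f32
    static double vec_soft_max_ref(const int n, float * y, const float * x, float max) {
        double sum = 0.0;
        for (int i = 0; i < n; ++i) {
            y[i] = expf(x[i] - max);
            sum += (double) y[i];
        }
        return sum;
    }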
@@ -2202,9 +2670,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
2202
2670
  "ARGSORT",
2203
2671
  "LEAKY_RELU",
2204
2672
 
2205
- "FLASH_ATTN",
2206
2673
  "FLASH_ATTN_EXT",
2207
- "FLASH_FF",
2208
2674
  "FLASH_ATTN_BACK",
2209
2675
  "SSM_CONV",
2210
2676
  "SSM_SCAN",
@@ -2230,7 +2696,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
2230
2696
  "CROSS_ENTROPY_LOSS_BACK",
2231
2697
  };
2232
2698
 
2233
- static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76");
2699
+ static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74");
2234
2700
 
2235
2701
  static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
2236
2702
  "none",
@@ -2292,9 +2758,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
2292
2758
  "argsort(x)",
2293
2759
  "leaky_relu(x)",
2294
2760
 
2295
- "flash_attn(x)",
2296
2761
  "flash_attn_ext(x)",
2297
- "flash_ff(x)",
2298
2762
  "flash_attn_back(x)",
2299
2763
  "ssm_conv(x)",
2300
2764
  "ssm_scan(x)",
@@ -2320,7 +2784,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
2320
2784
  "cross_entropy_loss_back(x,y)",
2321
2785
  };
2322
2786
 
2323
- static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76");
2787
+ static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74");
2324
2788
 
2325
2789
  static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
2326
2790
 
@@ -2383,32 +2847,6 @@ static void ggml_setup_op_has_task_pass(void) {
2383
2847
  }
2384
2848
  }
2385
2849
 
2386
- //
2387
- // ggml context
2388
- //
2389
-
2390
- struct ggml_context {
2391
- size_t mem_size;
2392
- void * mem_buffer;
2393
- bool mem_buffer_owned;
2394
- bool no_alloc;
2395
- bool no_alloc_save; // this is used to save the no_alloc state when using scratch buffers
2396
-
2397
- int n_objects;
2398
-
2399
- struct ggml_object * objects_begin;
2400
- struct ggml_object * objects_end;
2401
-
2402
- struct ggml_scratch scratch;
2403
- struct ggml_scratch scratch_save;
2404
- };
2405
-
2406
- struct ggml_context_container {
2407
- bool used;
2408
-
2409
- struct ggml_context context;
2410
- };
2411
-
2412
2850
  //
2413
2851
  // NUMA support
2414
2852
  //
@@ -2822,6 +3260,16 @@ bool ggml_are_same_shape(const struct ggml_tensor * t0, const struct ggml_tensor
2822
3260
  (t0->ne[3] == t1->ne[3] );
2823
3261
  }
2824
3262
 
3263
+ bool ggml_are_same_stride(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
3264
+ static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
3265
+
3266
+ return
3267
+ (t0->nb[0] == t1->nb[0] ) &&
3268
+ (t0->nb[1] == t1->nb[1] ) &&
3269
+ (t0->nb[2] == t1->nb[2] ) &&
3270
+ (t0->nb[3] == t1->nb[3] );
3271
+ }
3272
+
2825
3273
  // check if t1 can be represented as a repeatition of t0
2826
3274
  static inline bool ggml_can_repeat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
2827
3275
  static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
@@ -2881,8 +3329,6 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
2881
3329
  float f = ggml_table_f32_f16[i] = GGML_COMPUTE_FP16_TO_FP32(u.fp16);
2882
3330
  ggml_table_gelu_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_f32(f));
2883
3331
  ggml_table_gelu_quick_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_quick_f32(f));
2884
- ggml_table_silu_f16[i] = GGML_FP32_TO_FP16(ggml_silu_f32(f));
2885
- ggml_table_exp_f16[i] = GGML_FP32_TO_FP16(expf(f));
2886
3332
  }
2887
3333
 
2888
3334
  const uint64_t t_end = ggml_time_us(); UNUSED(t_end);
@@ -3166,7 +3612,13 @@ static struct ggml_tensor * ggml_new_tensor_impl(
3166
3612
 
3167
3613
  struct ggml_tensor * const result = (struct ggml_tensor *)((char *)ctx->mem_buffer + obj_new->offs);
3168
3614
 
3169
- *result = (struct ggml_tensor) {
3615
+ #ifdef __clang__
3616
+ // temporary until ggml_tensor::backend is removed
3617
+ #pragma clang diagnostic push
3618
+ #pragma clang diagnostic ignored "-Wdeprecated-declarations"
3619
+ #endif
3620
+
3621
+ *result = (struct ggml_tensor) {
3170
3622
  /*.type =*/ type,
3171
3623
  /*.backend =*/ GGML_BACKEND_TYPE_CPU,
3172
3624
  /*.buffer =*/ NULL,
@@ -3188,6 +3640,10 @@ static struct ggml_tensor * ggml_new_tensor_impl(
3188
3640
  /*.padding =*/ { 0 },
3189
3641
  };
3190
3642
 
3643
+ #ifdef __clang__
3644
+ #pragma clang diagnostic pop
3645
+ #endif
3646
+
3191
3647
  // TODO: this should not be needed as long as we don't rely on aligned SIMD loads
3192
3648
  //ggml_assert_aligned(result->data);
3193
3649
 
@@ -4426,10 +4882,21 @@ struct ggml_tensor * ggml_repeat_back(
4426
4882
  // ggml_concat
4427
4883
 
4428
4884
  struct ggml_tensor * ggml_concat(
4429
- struct ggml_context* ctx,
4430
- struct ggml_tensor* a,
4431
- struct ggml_tensor* b) {
4432
- GGML_ASSERT(a->ne[0] == b->ne[0] && a->ne[1] == b->ne[1] && a->ne[3] == b->ne[3]);
4885
+ struct ggml_context * ctx,
4886
+ struct ggml_tensor * a,
4887
+ struct ggml_tensor * b,
4888
+ int dim) {
4889
+ GGML_ASSERT(dim >= 0 && dim < GGML_MAX_DIMS);
4890
+
4891
+ int64_t ne[GGML_MAX_DIMS];
4892
+ for (int d = 0; d < GGML_MAX_DIMS; ++d) {
4893
+ if (d == dim) {
4894
+ ne[d] = a->ne[d] + b->ne[d];
4895
+ continue;
4896
+ }
4897
+ GGML_ASSERT(a->ne[d] == b->ne[d]);
4898
+ ne[d] = a->ne[d];
4899
+ }
4433
4900
 
4434
4901
  bool is_node = false;
4435
4902
 
@@ -4437,7 +4904,9 @@ struct ggml_tensor * ggml_concat(
4437
4904
  is_node = true;
4438
4905
  }
4439
4906
 
4440
- struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, a->ne[0], a->ne[1], a->ne[2] + b->ne[2], a->ne[3]);
4907
+ struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, GGML_MAX_DIMS, ne);
4908
+
4909
+ ggml_set_op_params_i32(result, 0, dim);
4441
4910
 
4442
4911
  result->op = GGML_OP_CONCAT;
4443
4912
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
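ggml_concat() now takes the concatenation dimension explicitly instead of always concatenating along ne[2]. A minimal usage sketch (the context and input tensors are assumed to exist; dim = 2 reproduces the previous behavior):

    #include "ggml.h"

    // a: [ne0, ne1, ne2a, ne3], b: [ne0, ne1, ne2b, ne3]
    // concat_along(ctx, a, b, 2) -> [ne0, ne1, ne2a + ne2b, ne3],
    // equivalent to the old two-argument-plus-context ggml_concat(ctx, a, b)
    static struct ggml_tensor * concat_along(struct ggml_context * ctx,
                                             struct ggml_tensor  * a,
                                             struct ggml_tensor  * b,
                                             int dim) {
        // the inputs must match on every dimension except `dim`
        return ggml_concat(ctx, a, b, dim);
    }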
@@ -4557,6 +5026,7 @@ struct ggml_tensor * ggml_leaky_relu(
4557
5026
  }
4558
5027
 
4559
5028
  struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
5029
+
4560
5030
  ggml_set_op_params(result, &negative_slope, sizeof(negative_slope));
4561
5031
 
4562
5032
  result->op = GGML_OP_LEAKY_RELU;
@@ -5763,6 +6233,7 @@ static struct ggml_tensor * ggml_rope_impl(
5763
6233
  struct ggml_context * ctx,
5764
6234
  struct ggml_tensor * a,
5765
6235
  struct ggml_tensor * b,
6236
+ struct ggml_tensor * c,
5766
6237
  int n_dims,
5767
6238
  int mode,
5768
6239
  int n_ctx,
@@ -5776,10 +6247,17 @@ static struct ggml_tensor * ggml_rope_impl(
5776
6247
  float xpos_base,
5777
6248
  bool xpos_down,
5778
6249
  bool inplace) {
6250
+ GGML_ASSERT((mode & 1) == 0 && "mode & 1 == 1 is no longer supported");
6251
+
5779
6252
  GGML_ASSERT(ggml_is_vector(b));
5780
6253
  GGML_ASSERT(b->type == GGML_TYPE_I32);
5781
6254
  GGML_ASSERT(a->ne[2] == b->ne[0]);
5782
6255
 
6256
+ if (c) {
6257
+ GGML_ASSERT(c->type == GGML_TYPE_F32);
6258
+ GGML_ASSERT(c->ne[0] >= n_dims / 2);
6259
+ }
6260
+
5783
6261
  bool is_node = false;
5784
6262
 
5785
6263
  if (a->grad) {
@@ -5803,6 +6281,7 @@ static struct ggml_tensor * ggml_rope_impl(
5803
6281
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5804
6282
  result->src[0] = a;
5805
6283
  result->src[1] = b;
6284
+ result->src[2] = c;
5806
6285
 
5807
6286
  return result;
5808
6287
  }
@@ -5815,7 +6294,7 @@ struct ggml_tensor * ggml_rope(
5815
6294
  int mode,
5816
6295
  int n_ctx) {
5817
6296
  return ggml_rope_impl(
5818
- ctx, a, b, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, false
6297
+ ctx, a, b, NULL, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, false
5819
6298
  );
5820
6299
  }
5821
6300
 
@@ -5827,14 +6306,15 @@ struct ggml_tensor * ggml_rope_inplace(
5827
6306
  int mode,
5828
6307
  int n_ctx) {
5829
6308
  return ggml_rope_impl(
5830
- ctx, a, b, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, true
6309
+ ctx, a, b, NULL, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, true
5831
6310
  );
5832
6311
  }
5833
6312
 
5834
- struct ggml_tensor * ggml_rope_custom(
6313
+ struct ggml_tensor * ggml_rope_ext(
5835
6314
  struct ggml_context * ctx,
5836
6315
  struct ggml_tensor * a,
5837
6316
  struct ggml_tensor * b,
6317
+ struct ggml_tensor * c,
5838
6318
  int n_dims,
5839
6319
  int mode,
5840
6320
  int n_ctx,
@@ -5846,15 +6326,16 @@ struct ggml_tensor * ggml_rope_custom(
5846
6326
  float beta_fast,
5847
6327
  float beta_slow) {
5848
6328
  return ggml_rope_impl(
5849
- ctx, a, b, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
6329
+ ctx, a, b, c, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
5850
6330
  ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, false
5851
6331
  );
5852
6332
  }
5853
6333
 
5854
- struct ggml_tensor * ggml_rope_custom_inplace(
6334
+ struct ggml_tensor * ggml_rope_ext_inplace(
5855
6335
  struct ggml_context * ctx,
5856
6336
  struct ggml_tensor * a,
5857
6337
  struct ggml_tensor * b,
6338
+ struct ggml_tensor * c,
5858
6339
  int n_dims,
5859
6340
  int mode,
5860
6341
  int n_ctx,
@@ -5866,19 +6347,49 @@ struct ggml_tensor * ggml_rope_custom_inplace(
5866
6347
  float beta_fast,
5867
6348
  float beta_slow) {
5868
6349
  return ggml_rope_impl(
5869
- ctx, a, b, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
6350
+ ctx, a, b, c, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
5870
6351
  ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, true
5871
6352
  );
5872
6353
  }
5873
6354
 
5874
- struct ggml_tensor * ggml_rope_xpos_inplace(
6355
+ struct ggml_tensor * ggml_rope_custom(
6356
+ struct ggml_context * ctx,
6357
+ struct ggml_tensor * a,
6358
+ struct ggml_tensor * b,
6359
+ int n_dims,
6360
+ int mode,
6361
+ int n_ctx,
6362
+ int n_orig_ctx,
6363
+ float freq_base,
6364
+ float freq_scale,
6365
+ float ext_factor,
6366
+ float attn_factor,
6367
+ float beta_fast,
6368
+ float beta_slow) {
6369
+ return ggml_rope_impl(
6370
+ ctx, a, b, NULL, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
6371
+ ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, false
6372
+ );
6373
+ }
6374
+
6375
+ struct ggml_tensor * ggml_rope_custom_inplace(
5875
6376
  struct ggml_context * ctx,
5876
6377
  struct ggml_tensor * a,
5877
6378
  struct ggml_tensor * b,
5878
6379
  int n_dims,
5879
- float base,
5880
- bool down) {
5881
- return ggml_rope_impl(ctx, a, b, n_dims, 0, 0, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, base, down, true);
6380
+ int mode,
6381
+ int n_ctx,
6382
+ int n_orig_ctx,
6383
+ float freq_base,
6384
+ float freq_scale,
6385
+ float ext_factor,
6386
+ float attn_factor,
6387
+ float beta_fast,
6388
+ float beta_slow) {
6389
+ return ggml_rope_impl(
6390
+ ctx, a, b, NULL, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
6391
+ ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, true
6392
+ );
5882
6393
  }
5883
6394
 
5884
6395
  // ggml_rope_back
@@ -5887,6 +6398,7 @@ struct ggml_tensor * ggml_rope_back(
5887
6398
  struct ggml_context * ctx,
5888
6399
  struct ggml_tensor * a,
5889
6400
  struct ggml_tensor * b,
6401
+ struct ggml_tensor * c,
5890
6402
  int n_dims,
5891
6403
  int mode,
5892
6404
  int n_ctx,
@@ -5902,6 +6414,7 @@ struct ggml_tensor * ggml_rope_back(
5902
6414
  GGML_ASSERT(ggml_is_vector(b));
5903
6415
  GGML_ASSERT(b->type == GGML_TYPE_I32);
5904
6416
  GGML_ASSERT(a->ne[2] == b->ne[0]);
6417
+ GGML_ASSERT(c == NULL && "freq factors not implemented yet");
5905
6418
 
5906
6419
  GGML_ASSERT((mode & 4) == 0 && "ggml_rope_back() for ChatGLM not implemented yet");
5907
6420
 
@@ -6281,7 +6794,10 @@ struct ggml_tensor * ggml_pool_2d(
6281
6794
  static struct ggml_tensor * ggml_upscale_impl(
6282
6795
  struct ggml_context * ctx,
6283
6796
  struct ggml_tensor * a,
6284
- int scale_factor) {
6797
+ int ne0,
6798
+ int ne1,
6799
+ int ne2,
6800
+ int ne3) {
6285
6801
  bool is_node = false;
6286
6802
 
6287
6803
  if (a->grad) {
@@ -6289,19 +6805,45 @@ static struct ggml_tensor * ggml_upscale_impl(
6289
6805
  is_node = true;
6290
6806
  }
6291
6807
 
6808
+ GGML_ASSERT(a->ne[0] <= ne0);
6809
+ GGML_ASSERT(a->ne[1] <= ne1);
6810
+ GGML_ASSERT(a->ne[2] <= ne2);
6811
+ GGML_ASSERT(a->ne[3] <= ne3);
6812
+
6292
6813
  struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type,
6293
- a->ne[0] * scale_factor,
6294
- a->ne[1] * scale_factor,
6295
- a->ne[2], a->ne[3]);
6814
+ ne0,
6815
+ ne1,
6816
+ ne2,
6817
+ ne3
6818
+ );
6296
6819
 
6297
6820
  result->op = GGML_OP_UPSCALE;
6298
- result->op_params[0] = scale_factor;
6821
+
6299
6822
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6300
6823
  result->src[0] = a;
6301
6824
 
6302
6825
  return result;
6303
6826
  }
6304
6827
 
6828
+ struct ggml_tensor * ggml_upscale(
6829
+ struct ggml_context * ctx,
6830
+ struct ggml_tensor * a,
6831
+ int scale_factor) {
6832
+ return ggml_upscale_impl(ctx, a, a->ne[0] * scale_factor, a->ne[1] * scale_factor, a->ne[2], a->ne[3]);
6833
+ }
6834
+
6835
+ struct ggml_tensor * ggml_upscale_ext(
6836
+ struct ggml_context * ctx,
6837
+ struct ggml_tensor * a,
6838
+ int ne0,
6839
+ int ne1,
6840
+ int ne2,
6841
+ int ne3) {
6842
+ return ggml_upscale_impl(ctx, a, ne0, ne1, ne2, ne3);
6843
+ }
6844
+
6845
+ // ggml_pad
6846
+
6305
6847
  struct ggml_tensor * ggml_pad(
6306
6848
  struct ggml_context * ctx,
6307
6849
  struct ggml_tensor * a,
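ggml_upscale_impl() now takes explicit target sizes, with the old integer-scale ggml_upscale() kept as a wrapper and ggml_upscale_ext() exposing the general form. A small usage sketch (context and tensor are assumed to exist):

    #include "ggml.h"

    // old-style 2x upscale of the first two dimensions
    static struct ggml_tensor * upscale_2x(struct ggml_context * ctx, struct ggml_tensor * a) {
        return ggml_upscale(ctx, a, 2); // result: [2*ne0, 2*ne1, ne2, ne3]
    }

    // new-style upscale to explicit sizes; each target must be >= the source size,
    // as asserted in ggml_upscale_impl()
    static struct ggml_tensor * upscale_to(struct ggml_context * ctx, struct ggml_tensor * a,
                                           int ne0, int ne1) {
        return ggml_upscale_ext(ctx, a, ne0, ne1, (int) a->ne[2], (int) a->ne[3]);
    }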
@@ -6326,12 +6868,7 @@ struct ggml_tensor * ggml_pad(
6326
6868
  return result;
6327
6869
  }
6328
6870
 
6329
- struct ggml_tensor * ggml_upscale(
6330
- struct ggml_context * ctx,
6331
- struct ggml_tensor * a,
6332
- int scale_factor) {
6333
- return ggml_upscale_impl(ctx, a, scale_factor);
6334
- }
6871
+ // ggml_arange
6335
6872
 
6336
6873
  struct ggml_tensor * ggml_arange(
6337
6874
  struct ggml_context * ctx,
@@ -6353,6 +6890,8 @@ struct ggml_tensor * ggml_arange(
6353
6890
  return result;
6354
6891
  }
6355
6892
 
6893
+ // ggml_timestep_embedding
6894
+
6356
6895
  struct ggml_tensor * ggml_timestep_embedding(
6357
6896
  struct ggml_context * ctx,
6358
6897
  struct ggml_tensor * timesteps,
@@ -6419,38 +6958,6 @@ struct ggml_tensor * ggml_top_k(
6419
6958
  return result;
6420
6959
  }
6421
6960
 
6422
- // ggml_flash_attn
6423
-
6424
- struct ggml_tensor * ggml_flash_attn(
6425
- struct ggml_context * ctx,
6426
- struct ggml_tensor * q,
6427
- struct ggml_tensor * k,
6428
- struct ggml_tensor * v,
6429
- bool masked) {
6430
- GGML_ASSERT(ggml_can_mul_mat(k, q));
6431
- // TODO: check if vT can be multiplied by (k*qT)
6432
-
6433
- bool is_node = false;
6434
-
6435
- if (q->grad || k->grad || v->grad) {
6436
- is_node = true;
6437
- }
6438
-
6439
- //struct ggml_tensor * result = ggml_dup_tensor(ctx, q);
6440
- struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, GGML_MAX_DIMS, q->ne);
6441
-
6442
- int32_t t = masked ? 1 : 0;
6443
- ggml_set_op_params(result, &t, sizeof(t));
6444
-
6445
- result->op = GGML_OP_FLASH_ATTN;
6446
- result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6447
- result->src[0] = q;
6448
- result->src[1] = k;
6449
- result->src[2] = v;
6450
-
6451
- return result;
6452
- }
6453
-
6454
6961
  // ggml_flash_attn_ext
6455
6962
 
6456
6963
  struct ggml_tensor * ggml_flash_attn_ext(
@@ -6510,38 +7017,6 @@ void ggml_flash_attn_ext_set_prec(
6510
7017
  ggml_set_op_params_i32(a, 2, prec_i32); // scale is on first pos, max_bias on second
6511
7018
  }
6512
7019
 
6513
- // ggml_flash_ff
6514
-
6515
- struct ggml_tensor * ggml_flash_ff(
6516
- struct ggml_context * ctx,
6517
- struct ggml_tensor * a,
6518
- struct ggml_tensor * b0,
6519
- struct ggml_tensor * b1,
6520
- struct ggml_tensor * c0,
6521
- struct ggml_tensor * c1) {
6522
- GGML_ASSERT(ggml_can_mul_mat(b0, a));
6523
- // TODO: more checks
6524
-
6525
- bool is_node = false;
6526
-
6527
- if (a->grad || b0->grad || b1->grad || c0->grad || c1->grad) {
6528
- is_node = true;
6529
- }
6530
-
6531
- //struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
6532
- struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, GGML_MAX_DIMS, a->ne);
6533
-
6534
- result->op = GGML_OP_FLASH_FF;
6535
- result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6536
- result->src[0] = a;
6537
- result->src[1] = b0;
6538
- result->src[2] = b1;
6539
- result->src[3] = c0;
6540
- result->src[4] = c1;
6541
-
6542
- return result;
6543
- }
6544
-
6545
7020
  // ggml_flash_attn_back
6546
7021
 
6547
7022
  struct ggml_tensor * ggml_flash_attn_back(
@@ -6551,6 +7026,8 @@ struct ggml_tensor * ggml_flash_attn_back(
6551
7026
  struct ggml_tensor * v,
6552
7027
  struct ggml_tensor * d,
6553
7028
  bool masked) {
7029
+ GGML_ASSERT(false && "TODO: adapt to ggml_flash_attn_ext() changes");
7030
+
6554
7031
  GGML_ASSERT(ggml_can_mul_mat(k, q));
6555
7032
  // TODO: check if vT can be multiplied by (k*qT)
6556
7033
 
@@ -10504,26 +10981,29 @@ static void ggml_compute_forward_concat_f32(
10504
10981
  GGML_ASSERT(nb00 == sizeof(float));
10505
10982
  GGML_ASSERT(nb10 == sizeof(float));
10506
10983
 
10984
+ const int32_t dim = ggml_get_op_params_i32(dst, 0);
10985
+
10986
+ GGML_ASSERT(dim >= 0 && dim < 4);
10987
+
10988
+ int64_t o[4] = {0, 0, 0, 0};
10989
+ o[dim] = src0->ne[dim];
10990
+
10991
+ const float * x;
10992
+
10993
+ // TODO: smarter multi-theading
10507
10994
  for (int i3 = 0; i3 < ne3; i3++) {
10508
10995
  for (int i2 = ith; i2 < ne2; i2 += nth) {
10509
- if (i2 < ne02) { // src0
10510
- for (int i1 = 0; i1 < ne1; i1++) {
10511
- for (int i0 = 0; i0 < ne0; i0++) {
10512
- const float * x = (float *)((char *) src0->data + i0 * nb00 + i1 * nb01 + i2 * nb02 + i3 * nb03);
10513
-
10514
- float * y = (float *)((char *)dst->data + i0 * nb0 + i1 * nb1 + i2 * nb2 + i3 * nb3);
10515
- *y = *x;
10516
- }
10517
- }
10518
- } // src1
10519
- else {
10520
- for (int i1 = 0; i1 < ne1; i1++) {
10521
- for (int i0 = 0; i0 < ne0; i0++) {
10522
- const float * x = (float *)((char *) src1->data + i0 * nb10 + i1 * nb11 + (i2 - ne02) * nb12 + i3 * nb13);
10523
-
10524
- float * y = (float *)((char *)dst->data + i0 * nb0 + i1 * nb1 + i2 * nb2 + i3 * nb3);
10525
- *y = *x;
10996
+ for (int i1 = 0; i1 < ne1; i1++) {
10997
+ for (int i0 = 0; i0 < ne0; i0++) {
10998
+ if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
10999
+ x = (const float *) ((const char *)src0->data + (i0 )*nb00 + (i1 )*nb01 + (i2 )*nb02 + (i3 )*nb03);
11000
+ } else {
11001
+ x = (const float *) ((const char *)src1->data + (i0 - o[0])*nb10 + (i1 - o[1])*nb11 + (i2 - o[2])*nb12 + (i3 - o[3])*nb13);
10526
11002
  }
11003
+
11004
+ float * y = (float *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
11005
+
11006
+ *y = *x;
10527
11007
  }
10528
11008
  }
10529
11009
  }
@@ -10531,7 +11011,7 @@ static void ggml_compute_forward_concat_f32(
10531
11011
  }
10532
11012
 
10533
11013
  static void ggml_compute_forward_concat(
10534
- const struct ggml_compute_params* params,
11014
+ const struct ggml_compute_params * params,
10535
11015
  struct ggml_tensor* dst) {
10536
11016
 
10537
11017
  const struct ggml_tensor * src0 = dst->src[0];
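The rewritten ggml_compute_forward_concat_f32() loop above picks, per destination element, whether to read from src0 or from src1 shifted back by src0's extent along the concat dimension. The index arithmetic in isolation, as a standalone sketch:

    #include <stdint.h>

    // for a dst element at index i[0..3]: returns 1 and leaves src_idx == i when the
    // element lies inside src0's box, otherwise returns 0 and shifts src_idx back by
    // src0's size along `dim` so it indexes into src1 (ne0[] is src0's shape)
    static int concat_src_index(const int64_t i[4], const int64_t ne0[4], int dim,
                                int64_t src_idx[4]) {
        int from_src0 = 1;
        for (int d = 0; d < 4; ++d) {
            if (i[d] >= ne0[d]) {
                from_src0 = 0;
            }
        }
        for (int d = 0; d < 4; ++d) {
            src_idx[d] = i[d] - ((!from_src0 && d == dim) ? ne0[d] : 0);
        }
        return from_src0;
    }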
@@ -11767,9 +12247,101 @@ static bool ggml_compute_forward_mul_mat_use_blas(struct ggml_tensor * dst) {
11767
12247
  }
11768
12248
  #endif
11769
12249
 
12250
+ static void ggml_compute_forward_mul_mat_one_chunk(
12251
+ const struct ggml_compute_params * params,
12252
+ struct ggml_tensor * dst,
12253
+ const int64_t num_rows_per_vec_dot,
12254
+ const int64_t ir0_start,
12255
+ const int64_t ir0_end,
12256
+ const int64_t ir1_start,
12257
+ const int64_t ir1_end) {
12258
+
12259
+ const struct ggml_tensor * src0 = dst->src[0];
12260
+ const struct ggml_tensor * src1 = dst->src[1];
12261
+
12262
+ GGML_TENSOR_BINARY_OP_LOCALS
12263
+
12264
+ const enum ggml_type type = src0->type;
12265
+
12266
+ const bool src1_cont = ggml_is_contiguous(src1);
12267
+
12268
+ ggml_vec_dot_t const vec_dot = type_traits[type].vec_dot;
12269
+ enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type;
12270
+
12271
+ // broadcast factors
12272
+ const int64_t r2 = ne12 / ne02;
12273
+ const int64_t r3 = ne13 / ne03;
12274
+
12275
+ //printf("ir0_start = %6lld, ir0_end = %6lld, ir1_start = %6lld, ir1_end = %6lld\n", ir0_start, ir0_end, ir1_start, ir1_end);
12276
+
12277
+ // threads with no work simply yield (not sure if it helps)
12278
+ if (ir0_start >= ir0_end || ir1_start >= ir1_end) {
12279
+ return;
12280
+ }
12281
+
12282
+ const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
12283
+ const size_t row_size = ggml_row_size(vec_dot_type, ne10);
12284
+
12285
+ assert(ne12 % ne02 == 0);
12286
+ assert(ne13 % ne03 == 0);
12287
+
12288
+ // block-tiling attempt
12289
+ const int64_t blck_0 = 16;
12290
+ const int64_t blck_1 = 16;
12291
+
12292
+ const size_t src1_col_stride = src1_cont || src1->type != vec_dot_type ? row_size : nb11;
12293
+
12294
+ // attempt to reduce false-sharing (does not seem to make a difference)
12295
+ // 16 * 2, accounting for mmla kernels
12296
+ float tmp[32];
12297
+
12298
+ for (int64_t iir1 = ir1_start; iir1 < ir1_end; iir1 += blck_1) {
12299
+ for (int64_t iir0 = ir0_start; iir0 < ir0_end; iir0 += blck_0) {
12300
+ for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir1_end; ir1 += num_rows_per_vec_dot) {
12301
+ const int64_t i13 = (ir1 / (ne12 * ne1));
12302
+ const int64_t i12 = (ir1 - i13 * ne12 * ne1) / ne1;
12303
+ const int64_t i11 = (ir1 - i13 * ne12 * ne1 - i12 * ne1);
12304
+
12305
+ // broadcast src0 into src1
12306
+ const int64_t i03 = i13 / r3;
12307
+ const int64_t i02 = i12 / r2;
12308
+
12309
+ const int64_t i1 = i11;
12310
+ const int64_t i2 = i12;
12311
+ const int64_t i3 = i13;
12312
+
12313
+ const char * src0_row = (const char*)src0->data + (0 + i02 * nb02 + i03 * nb03);
12314
+
12315
+ // desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides
12316
+ // if it is, then we have either copied the data to params->wdata and made it contiguous or we are using
12317
+ // the original src1 data pointer, so we should index using the indices directly
12318
+ // TODO: this is a bit of a hack, we should probably have a better way to handle this
12319
+ const char * src1_col = (const char*)wdata +
12320
+ (src1_cont || src1->type != vec_dot_type
12321
+ ? (i11 + i12 * ne11 + i13 * ne12 * ne11) * row_size
12322
+ : (i11 * nb11 + i12 * nb12 + i13 * nb13));
12323
+ float * dst_col = (float*)((char*)dst->data + (i1 * nb1 + i2 * nb2 + i3 * nb3));
12324
+
12325
+ //for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ++ir0) {
12326
+ // vec_dot(ne00, &dst_col[ir0], src0_row + ir0*nb01, src1_col);
12327
+ //}
12328
+
12329
+ for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ir0 += num_rows_per_vec_dot) {
12330
+ vec_dot(ne00, &tmp[ir0 - iir0], (num_rows_per_vec_dot > 1 ? 16 : 0), src0_row + ir0 * nb01, (num_rows_per_vec_dot > 1 ? nb01 : 0), src1_col, (num_rows_per_vec_dot > 1 ? src1_col_stride : 0), num_rows_per_vec_dot);
12331
+ }
12332
+
12333
+ for (int cn = 0; cn < num_rows_per_vec_dot; ++cn) {
12334
+ memcpy(&dst_col[iir0 + cn * nb1 / nb0], tmp + (cn * 16), (MIN(iir0 + blck_0, ir0_end) - iir0) * sizeof(float));
12335
+ }
12336
+ }
12337
+ }
12338
+ }
12339
+ }
12340
+
11770
12341
  static void ggml_compute_forward_mul_mat(
11771
12342
  const struct ggml_compute_params * params,
11772
- struct ggml_tensor * dst) {
12343
+ struct ggml_tensor * dst,
12344
+ struct ggml_compute_state * state) {
11773
12345
 
11774
12346
  const struct ggml_tensor * src0 = dst->src[0];
11775
12347
  const struct ggml_tensor * src1 = dst->src[1];
@@ -11784,9 +12356,6 @@ static void ggml_compute_forward_mul_mat(
11784
12356
 
11785
12357
  const enum ggml_type type = src0->type;
11786
12358
 
11787
- const bool src1_cont = ggml_is_contiguous(src1);
11788
-
11789
- ggml_vec_dot_t const vec_dot = type_traits[type].vec_dot;
11790
12359
  enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type;
11791
12360
  ggml_from_float_t const from_float_to_vec_dot = type_traits[vec_dot_type].from_float;
11792
12361
  int64_t const vec_dot_num_rows = type_traits[type].nrows;
@@ -11807,8 +12376,10 @@ static void ggml_compute_forward_mul_mat(
11807
12376
  GGML_ASSERT(nb2 <= nb3);
11808
12377
 
11809
12378
  // broadcast factors
11810
- const int64_t r2 = ne12/ne02;
11811
- const int64_t r3 = ne13/ne03;
12379
+ const int64_t r2 = ne12 / ne02;
12380
+ const int64_t r3 = ne13 / ne03;
12381
+ UNUSED(r2);
12382
+ UNUSED(r3);
11812
12383
 
11813
12384
  // nb01 >= nb00 - src0 is not transposed
11814
12385
  // compute by src0 rows
@@ -11890,6 +12461,8 @@ static void ggml_compute_forward_mul_mat(
11890
12461
  #endif
11891
12462
 
11892
12463
  #if GGML_USE_LLAMAFILE
12464
+ const bool src1_cont = ggml_is_contiguous(src1);
12465
+
11893
12466
  if (src1_cont) {
11894
12467
  for (int64_t i13 = 0; i13 < ne13; i13++)
11895
12468
  for (int64_t i12 = 0; i12 < ne12; i12++)
@@ -11915,6 +12488,8 @@ UseGgmlGemm1:;
11915
12488
  if (ith != 0) {
11916
12489
  return;
11917
12490
  }
12491
+ // Every thread starts at ith, so the first unprocessed chunk is nth. This saves a bit of coordination right at the start.
12492
+ atomic_store(&state->shared->current_chunk, nth);
11918
12493
  if (src1->type != vec_dot_type) {
11919
12494
  char * wdata = params->wdata;
11920
12495
  const size_t row_size = ggml_row_size(vec_dot_type, ne10);
@@ -11939,11 +12514,11 @@ UseGgmlGemm1:;
11939
12514
  return;
11940
12515
  }
11941
12516
 
11942
- const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
11943
- const size_t row_size = ggml_row_size(vec_dot_type, ne10);
11944
-
11945
12517
  #if GGML_USE_LLAMAFILE
11946
12518
  if (src1->type != vec_dot_type) {
12519
+ const void* wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
12520
+ const size_t row_size = ggml_row_size(vec_dot_type, ne10);
12521
+
11947
12522
  for (int64_t i13 = 0; i13 < ne13; i13++)
11948
12523
  for (int64_t i12 = 0; i12 < ne12; i12++)
11949
12524
  if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(src0->type),
@@ -11964,98 +12539,87 @@ UseGgmlGemm1:;
11964
12539
  UseGgmlGemm2:;
11965
12540
  #endif
11966
12541
 
11967
- const int64_t nr0 = ne01; // src0 rows
11968
- const int64_t nr1 = ne1*ne12*ne13; // src1 rows
11969
-
11970
- //printf("nr0 = %lld, nr1 = %lld\n", nr0, nr1);
11971
-
11972
- // distribute the thread work across the inner or outer loop based on which one is larger
11973
-
11974
- const int64_t nth0 = nr0 > nr1 ? nth : 1; // parallelize by src0 rows
11975
- const int64_t nth1 = nr0 > nr1 ? 1 : nth; // parallelize by src1 rows
11976
-
11977
- const int64_t ith0 = ith % nth0;
11978
- const int64_t ith1 = ith / nth0;
11979
-
11980
- const int64_t dr0 = (nr0 + nth0 - 1)/nth0;
11981
- const int64_t dr1 = (nr1 + nth1 - 1)/nth1;
11982
-
11983
- const int64_t ir010 = dr0*ith0;
11984
- const int64_t ir011 = MIN(ir010 + dr0, nr0);
11985
-
11986
- const int64_t ir110 = dr1*ith1;
11987
- const int64_t ir111 = MIN(ir110 + dr1, nr1);
11988
-
11989
- //printf("ir010 = %6lld, ir011 = %6lld, ir110 = %6lld, ir111 = %6lld\n", ir010, ir011, ir110, ir111);
11990
-
11991
- // threads with no work simply yield (not sure if it helps)
11992
- if (ir010 >= ir011 || ir110 >= ir111) {
11993
- sched_yield();
11994
- return;
11995
- }
12542
+ #ifdef GGML_PERF
12543
+ int chunks_executed = 0;
12544
+ UNUSED(chunks_executed);
12545
+ #endif
11996
12546
 
11997
- assert(ne12 % ne02 == 0);
11998
- assert(ne13 % ne03 == 0);
12547
+ // This is the size of the first dimension of the result, so we can iterate that way. (see the ASSERT above, these are the same numbers)
12548
+ const int64_t nr0 = ne0;
11999
12549
 
12000
- // block-tiling attempt
12001
- const int64_t blck_0 = 16;
12002
- const int64_t blck_1 = 16;
12550
+ // This is the size of the rest of the dimensions of the result
12551
+ const int64_t nr1 = ne1 * ne2 * ne3;
12003
12552
 
12004
12553
  // dot kernels can handle 1 row and col at a time, but mmla kernels can process 2 rows and cols
12005
- int64_t nrc = vec_dot_num_rows;
12554
+ int64_t num_rows_per_vec_dot = vec_dot_num_rows;
12006
12555
  // TODO: currently the mmla kernels support only even numbered rows/cols.
12007
12556
  // this check can be removed once they are extended to support odd numbered rows/cols too
12008
12557
  if ((nr0 % 2 != 0) || (ne11 % 2 != 0)) {
12009
- nrc = 1;
12558
+ num_rows_per_vec_dot = 1;
12010
12559
  }
12011
12560
 
12012
- const size_t src1_col_stride = src1_cont || src1->type != vec_dot_type ? row_size : nb11;
12561
+ // Now select a reasonable chunk size.
12562
+ int chunk_size = 16;
12013
12563
 
12014
- // attempt to reduce false-sharing (does not seem to make a difference)
12015
- // 16 * 2, accounting for mmla kernels
12016
- float tmp[32];
12564
+ // We need to step up the size if it's small
12565
+ if (nr0 == 1 || nr1 == 1) {
12566
+ chunk_size = 64;
12567
+ }
12017
12568
 
12018
- for (int64_t iir1 = ir110; iir1 < ir111; iir1 += blck_1) {
12019
- for (int64_t iir0 = ir010; iir0 < ir011; iir0 += blck_0) {
12020
- for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir111; ir1 += nrc) {
12021
- const int64_t i13 = (ir1/(ne12*ne1));
12022
- const int64_t i12 = (ir1 - i13*ne12*ne1)/ne1;
12023
- const int64_t i11 = (ir1 - i13*ne12*ne1 - i12*ne1);
12569
+ // distribute the work across the inner or outer loop based on which one is larger
12570
+ // The number of chunks in the 0/1 dim.
12571
+ // CEIL(nr0/chunk_size)
12572
+ int64_t nchunk0 = (nr0 + chunk_size - 1) / chunk_size;
12573
+ int64_t nchunk1 = (nr1 + chunk_size - 1) / chunk_size;
12024
12574
 
12025
- // broadcast src0 into src1
12026
- const int64_t i03 = i13/r3;
12027
- const int64_t i02 = i12/r2;
12575
+ // If the chunking is poor for the number of threads on this setup, scrap the whole plan. Re-chunk it by thread.
12576
+ // Also, chunking by thread was measured to perform better on NUMA systems. See https://github.com/ggerganov/llama.cpp/pull/6915
12577
+ // In theory, chunking should be just as useful on NUMA and non-NUMA systems, but testing disagreed with that.
12578
+ if (nchunk0 * nchunk1 < nth * 4 || ggml_is_numa()) {
12579
+ // distribute the thread work across the inner or outer loop based on which one is larger
12580
+ nchunk0 = nr0 > nr1 ? nth : 1; // parallelize by src0 rows
12581
+ nchunk1 = nr0 > nr1 ? 1 : nth; // parallelize by src1 rows
12582
+ }
12028
12583
 
12029
- const int64_t i1 = i11;
12030
- const int64_t i2 = i12;
12031
- const int64_t i3 = i13;
12584
+ // The number of elements in each chunk
12585
+ const int64_t dr0 = (nr0 + nchunk0 - 1) / nchunk0;
12586
+ const int64_t dr1 = (nr1 + nchunk1 - 1) / nchunk1;
12032
12587
 
12033
- const char * src0_row = (const char *) src0->data + (0 + i02*nb02 + i03*nb03);
12588
+ //if (ith == 0)
12589
+ // printf("MUL_MAT = [%d, %d, %d, %d] x [%d, %d, %d, %d] = %d x %d = %d. Fp Ops/Ch %d\n", ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, nchunk0, nchunk1, nchunk0 * nchunk1, ne00 * nr0 * nr1 / nchunk0 / nchunk1);
12034
12590
 
12035
- // desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides
12036
- // if it is, then we have either copied the data to params->wdata and made it contiguous or we are using
12037
- // the original src1 data pointer, so we should index using the indices directly
12038
- // TODO: this is a bit of a hack, we should probably have a better way to handle this
12039
- const char * src1_col = (const char *) wdata +
12040
- (src1_cont || src1->type != vec_dot_type
12041
- ? (i11 + i12*ne11 + i13*ne12*ne11)*row_size
12042
- : (i11*nb11 + i12*nb12 + i13*nb13));
12043
- float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3));
12591
+ // The first chunk comes from our thread_id, the rest will get auto-assigned.
12592
+ int current_chunk = ith;
12044
12593
 
12045
- //for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) {
12046
- // vec_dot(ne00, &dst_col[ir0], src0_row + ir0*nb01, src1_col);
12047
- //}
12594
+ while (current_chunk < nchunk0 * nchunk1) {
12595
+ const int64_t ith0 = current_chunk % nchunk0;
12596
+ const int64_t ith1 = current_chunk / nchunk0;
12048
12597
 
12049
- for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ir0 += nrc) {
12050
- vec_dot(ne00, &tmp[ir0 - iir0], (nrc>1 ? 16 : 0), src0_row + ir0*nb01, (nrc>1 ? nb01 : 0), src1_col, (nrc>1 ? src1_col_stride : 0), nrc);
12051
- }
12598
+ const int64_t ir0_start = dr0 * ith0;
12599
+ const int64_t ir0_end = MIN(ir0_start + dr0, nr0);
12052
12600
 
12053
- for (int cn = 0; cn < nrc; ++cn) {
12054
- memcpy(&dst_col[iir0 + cn*nb1/nb0], tmp + (cn*16), (MIN(iir0 + blck_0, ir011) - iir0)*sizeof(float));
12055
- }
12056
- }
12601
+ const int64_t ir1_start = dr1 * ith1;
12602
+ const int64_t ir1_end = MIN(ir1_start + dr1, nr1);
12603
+
12604
+ ggml_compute_forward_mul_mat_one_chunk(params, dst, num_rows_per_vec_dot, ir0_start, ir0_end, ir1_start, ir1_end);
12605
+
12606
+ #ifdef GGML_PERF
12607
+ chunks_executed++;
12608
+ #endif
12609
+
12610
+ if (nth >= nchunk0 * nchunk1) {
12611
+ break;
12057
12612
  }
12613
+
12614
+ current_chunk = atomic_fetch_add(&state->shared->current_chunk, 1);
12058
12615
  }
12616
+
12617
+ #ifdef GGML_PERF
12618
+ // These numbers are useful when trying to measure how well the thread scheduling works.
12619
+ //int64_t workSize = (ne01 * ne11 * ne12 * ne13 * ne00) / nchunk0 / nchunk1;
12620
+ //float time = (ggml_perf_time_us() - t0);
12621
+ //printf("MUL_MAT = %f ms, [%d, %d, %d, %d] x [%d, %d, %d, %d] = %I64u, %f ops/usec in %d chunks.\n", time / 1000.0, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, workSize, (float)workSize/time, chunks_executed);
12622
+ #endif
12059
12623
  }
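The new scheduler above replaces the old static per-thread row split with dynamically assigned chunks: every thread starts on chunk ith, and once done it grabs the next unprocessed chunk from the shared atomic counter (pre-seeded to nth). A self-contained sketch of the same work-stealing pattern, using C11 <threads.h> and <stdatomic.h> (swap in pthreads where C11 threads are unavailable); the counter and worker names here are placeholders, not the diff's API:

// Work-stealing chunk loop, a simplified sketch of the scheduling used above.
#include <stdatomic.h>
#include <stdint.h>
#include <stdio.h>
#include <threads.h>

#define NTHREADS 4
#define NCHUNKS  16

static atomic_int current_chunk;            // shared "next unprocessed chunk" counter
static atomic_int processed[NCHUNKS];       // how many times each chunk was handled

static void process_chunk(int chunk) {
    atomic_fetch_add(&processed[chunk], 1); // stand-in for the per-chunk matmul work
}

static int worker(void * arg) {
    int chunk = (int)(intptr_t) arg;        // the first chunk comes from the thread id
    while (chunk < NCHUNKS) {
        process_chunk(chunk);
        chunk = atomic_fetch_add(&current_chunk, 1); // steal the next chunk
    }
    return 0;
}

int main(void) {
    atomic_store(&current_chunk, NTHREADS); // chunks 0..NTHREADS-1 are taken implicitly
    thrd_t th[NTHREADS];
    for (int i = 0; i < NTHREADS; ++i) thrd_create(&th[i], worker, (void *)(intptr_t) i);
    for (int i = 0; i < NTHREADS; ++i) thrd_join(th[i], NULL);

    int ok = 1;
    for (int i = 0; i < NCHUNKS; ++i) ok &= (atomic_load(&processed[i]) == 1);
    printf("every chunk handled exactly once: %s\n", ok ? "yes" : "no");
    return 0;
}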
12060
12624
 
12061
12625
  // ggml_compute_forward_mul_mat_id
@@ -13439,23 +14003,8 @@ static void ggml_compute_forward_soft_max_f32(
13439
14003
  float max = -INFINITY;
13440
14004
  ggml_vec_max_f32(nc, &max, wp);
13441
14005
 
13442
- ggml_float sum = 0.0;
13443
-
13444
- uint16_t scvt;
13445
- for (int i = 0; i < nc; i++) {
13446
- if (wp[i] == -INFINITY) {
13447
- dp[i] = 0.0f;
13448
- } else {
13449
- // const float val = (wp[i] == -INFINITY) ? 0.0 : exp(wp[i] - max);
13450
- ggml_fp16_t s = GGML_FP32_TO_FP16(wp[i] - max);
13451
- memcpy(&scvt, &s, sizeof(scvt));
13452
- const float val = GGML_FP16_TO_FP32(ggml_table_exp_f16[scvt]);
13453
- sum += (ggml_float)val;
13454
- dp[i] = val;
13455
- }
13456
- }
13457
-
13458
- assert(sum > 0.0);
14006
+ ggml_float sum = ggml_vec_soft_max_f32(nc, dp, wp, max);
14007
+ assert(sum > 0.0);
13459
14008
 
13460
14009
  sum = 1.0/sum;
13461
14010
  ggml_vec_scale_f32(nc, dp, sum);
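The inline FP16-table exponentiation loop is folded into a single ggml_vec_soft_max_f32 call whose body is not part of this diff. A plausible plain-float sketch of what such a helper has to compute (exponentiate against the row maximum, zero out -INFINITY entries, and return the sum for later normalization); the real version presumably uses the FP16 exp table and SIMD instead of plain expf:

// Plausible shape of the soft-max helper (plain-float semantics, not the actual ggml body).
#include <math.h>
#include <stdio.h>

static double vec_soft_max_f32(const int n, float * y, const float * x, const float max) {
    double sum = 0.0;
    for (int i = 0; i < n; ++i) {
        if (x[i] == -INFINITY) {
            y[i] = 0.0f;                        // masked entries contribute nothing
        } else {
            const float val = expf(x[i] - max); // subtract the row max for stability
            y[i]  = val;
            sum  += (double) val;
        }
    }
    return sum;                                 // caller rescales by 1/sum afterwards
}

int main(void) {
    float x[4] = { 1.0f, 2.0f, -INFINITY, 0.5f };
    float y[4];
    const double sum = vec_soft_max_f32(4, y, x, 2.0f /* row maximum */);
    for (int i = 0; i < 4; ++i) printf("%.6f ", y[i]/sum);
    printf("\n");
    return 0;
}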
@@ -13741,6 +14290,7 @@ static void ggml_compute_forward_rope_f32(
13741
14290
 
13742
14291
  const struct ggml_tensor * src0 = dst->src[0];
13743
14292
  const struct ggml_tensor * src1 = dst->src[1];
14293
+ const struct ggml_tensor * src2 = dst->src[2];
13744
14294
 
13745
14295
  if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
13746
14296
  return;
@@ -13800,6 +14350,17 @@ static void ggml_compute_forward_rope_f32(
13800
14350
  const bool is_neox = mode & 2;
13801
14351
  const bool is_glm = mode & 4;
13802
14352
 
14353
+ const float * freq_factors = NULL;
14354
+ if (is_neox) {
14355
+ if (src2 != NULL) {
14356
+ GGML_ASSERT(src2->type == GGML_TYPE_F32);
14357
+ GGML_ASSERT(src2->ne[0] >= n_dims / 2);
14358
+ freq_factors = (const float *) src2->data;
14359
+ }
14360
+ } else {
14361
+ GGML_ASSERT(src2 == NULL && "TODO: freq_factors not implemented for !is_neox");
14362
+ }
14363
+
13803
14364
  // backward process uses inverse rotation by cos and sin.
13804
14365
  // cos and sin build a rotation matrix, where the inverse is the transpose.
13805
14366
  // this essentially just switches the sign of sin.
@@ -13876,10 +14437,11 @@ static void ggml_compute_forward_rope_f32(
13876
14437
 
13877
14438
  // simplified from `(ib * n_dims + ic) * inv_ndims`
13878
14439
  float cur_rot = inv_ndims * ic - ib;
14440
+ float freq_factor = freq_factors ? freq_factors[ic/2] : 1.0f;
13879
14441
 
13880
14442
  float cos_theta, sin_theta;
13881
14443
  rope_yarn(
13882
- theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
14444
+ theta_base/freq_factor, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
13883
14445
  &cos_theta, &sin_theta
13884
14446
  );
13885
14447
  sin_theta *= sin_sign;
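The optional src2 tensor carries per-frequency divisors ("freq_factors"): for NeoX-style RoPE the base angle of dimension pair ic is divided by freq_factors[ic/2] before rope_yarn runs, and a missing tensor falls back to a factor of 1.0f. A small sketch of just that angle computation, with made-up position, base, and factors, and with the yarn correction itself left out:

// Sketch: how freq_factors modify the per-dimension RoPE angle (rotation itself elided).
#include <math.h>
#include <stdio.h>

int main(void) {
    const int   n_dims    = 8;
    const float freq_base = 10000.0f;
    const int   pos       = 3;                                // token position (made up)

    const float factors[4]    = { 1.0f, 1.0f, 4.0f, 8.0f };   // hypothetical per-pair divisors
    const float * freq_factors = factors;                     // NULL would mean "no scaling"

    for (int ic = 0; ic < n_dims; ic += 2) {
        const float theta_base  = pos * powf(freq_base, -(float) ic / n_dims);
        const float freq_factor = freq_factors ? freq_factors[ic/2] : 1.0f;
        const float theta       = theta_base / freq_factor;   // value handed to rope_yarn above
        printf("pair %d: theta_base = %.6f  theta = %.6f  (cos %.4f, sin %.4f)\n",
               ic/2, theta_base, theta, cosf(theta), sinf(theta));
    }
    return 0;
}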
@@ -13912,6 +14474,7 @@ static void ggml_compute_forward_rope_f32(
13912
14474
  }
13913
14475
  }
13914
14476
 
14477
+ // TODO: deduplicate f16/f32 code
13915
14478
  static void ggml_compute_forward_rope_f16(
13916
14479
  const struct ggml_compute_params * params,
13917
14480
  struct ggml_tensor * dst,
@@ -13919,6 +14482,7 @@ static void ggml_compute_forward_rope_f16(
13919
14482
 
13920
14483
  const struct ggml_tensor * src0 = dst->src[0];
13921
14484
  const struct ggml_tensor * src1 = dst->src[1];
14485
+ const struct ggml_tensor * src2 = dst->src[2];
13922
14486
 
13923
14487
  if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
13924
14488
  return;
@@ -13971,6 +14535,17 @@ static void ggml_compute_forward_rope_f16(
13971
14535
  const bool is_neox = mode & 2;
13972
14536
  const bool is_glm = mode & 4;
13973
14537
 
14538
+ const float * freq_factors = NULL;
14539
+ if (is_neox) {
14540
+ if (src2 != NULL) {
14541
+ GGML_ASSERT(src2->type == GGML_TYPE_F32);
14542
+ GGML_ASSERT(src2->ne[0] >= n_dims / 2);
14543
+ freq_factors = (const float *) src2->data;
14544
+ }
14545
+ } else {
14546
+ GGML_ASSERT(src2 == NULL && "TODO: freq_factors not implemented for !is_neox");
14547
+ }
14548
+
13974
14549
  // backward process uses inverse rotation by cos and sin.
13975
14550
  // cos and sin build a rotation matrix, where the inverse is the transpose.
13976
14551
  // this essentially just switches the sign of sin.
@@ -14043,10 +14618,11 @@ static void ggml_compute_forward_rope_f16(
14043
14618
 
14044
14619
  // simplified from `(ib * n_dims + ic) * inv_ndims`
14045
14620
  float cur_rot = inv_ndims * ic - ib;
14621
+ float freq_factor = freq_factors ? freq_factors[ic/2] : 1.0f;
14046
14622
 
14047
14623
  float cos_theta, sin_theta;
14048
14624
  rope_yarn(
14049
- theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
14625
+ theta_base/freq_factor, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
14050
14626
  &cos_theta, &sin_theta
14051
14627
  );
14052
14628
  sin_theta *= sin_sign;
@@ -14808,25 +15384,28 @@ static void ggml_compute_forward_upscale_f32(
14808
15384
  return;
14809
15385
  }
14810
15386
 
14811
- GGML_ASSERT(src0->nb[0] == sizeof(float));
15387
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
14812
15388
 
14813
15389
  const int ith = params->ith;
14814
15390
  const int nth = params->nth;
14815
15391
 
14816
15392
  GGML_TENSOR_UNARY_OP_LOCALS
14817
15393
 
14818
- const int scale_factor = dst->op_params[0];
15394
+ const float sf0 = (float)ne0/src0->ne[0];
15395
+ const float sf1 = (float)ne1/src0->ne[1];
15396
+ const float sf2 = (float)ne2/src0->ne[2];
15397
+ const float sf3 = (float)ne3/src0->ne[3];
14819
15398
 
14820
15399
  // TODO: optimize
14821
15400
 
14822
15401
  for (int64_t i3 = 0; i3 < ne3; i3++) {
14823
- const int64_t i03 = i3;
15402
+ const int64_t i03 = i3 / sf3;
14824
15403
  for (int64_t i2 = ith; i2 < ne2; i2 += nth) {
14825
- const int64_t i02 = i2;
15404
+ const int64_t i02 = i2 / sf2;
14826
15405
  for (int64_t i1 = 0; i1 < ne1; i1++) {
14827
- const int64_t i01 = i1 / scale_factor;
15406
+ const int64_t i01 = i1 / sf1;
14828
15407
  for (int64_t i0 = 0; i0 < ne0; i0++) {
14829
- const int64_t i00 = i0 / scale_factor;
15408
+ const int64_t i00 = i0 / sf0;
14830
15409
 
14831
15410
  const float * x = (float *)((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
14832
15411
  float * y = (float *)((char *) dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
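Upscale now computes an independent float scale factor per dimension from the output/input sizes instead of reading a single integer op parameter, and maps every destination index back to a source index by dividing through that factor. A tiny 1-D sketch of the index mapping with made-up sizes:

// Sketch of the dst -> src index mapping used by the generalized upscale (1-D case).
#include <stdio.h>

int main(void) {
    const int   src_ne = 3;
    const int   dst_ne = 7;                        // deliberately a non-integer ratio
    const float sf     = (float) dst_ne / src_ne;  // plays the role of sf0..sf3 above

    const float src[3] = { 10.0f, 20.0f, 30.0f };
    for (int i0 = 0; i0 < dst_ne; ++i0) {
        const int i00 = (int) (i0 / sf);           // truncating division, like i0 / sf0
        printf("dst[%d] <- src[%d] = %.1f\n", i0, i00, src[i00]);
    }
    return 0;
}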
@@ -14856,6 +15435,7 @@ static void ggml_compute_forward_upscale(
14856
15435
  }
14857
15436
  }
14858
15437
 
15438
+
14859
15439
  // ggml_compute_forward_pad
14860
15440
 
14861
15441
  static void ggml_compute_forward_pad_f32(
@@ -15023,500 +15603,55 @@ static void ggml_compute_forward_argsort_f32(
15023
15603
  const struct ggml_compute_params * params,
15024
15604
  struct ggml_tensor * dst) {
15025
15605
 
15026
- const struct ggml_tensor * src0 = dst->src[0];
15027
-
15028
- if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
15029
- return;
15030
- }
15031
-
15032
- GGML_TENSOR_UNARY_OP_LOCALS
15033
-
15034
- GGML_ASSERT(nb0 == sizeof(float));
15035
-
15036
- const int ith = params->ith;
15037
- const int nth = params->nth;
15038
-
15039
- const int64_t nr = ggml_nrows(src0);
15040
-
15041
- enum ggml_sort_order order = (enum ggml_sort_order) ggml_get_op_params_i32(dst, 0);
15042
-
15043
- for (int64_t i = ith; i < nr; i += nth) {
15044
- int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1);
15045
- const float * src_data = (float *)((char *) src0->data + i*nb01);
15046
-
15047
- for (int64_t j = 0; j < ne0; j++) {
15048
- dst_data[j] = j;
15049
- }
15050
-
15051
- // C doesn't have a functional sort, so we do a bubble sort instead
15052
- for (int64_t j = 0; j < ne0; j++) {
15053
- for (int64_t k = j + 1; k < ne0; k++) {
15054
- if ((order == GGML_SORT_ORDER_ASC && src_data[dst_data[j]] > src_data[dst_data[k]]) ||
15055
- (order == GGML_SORT_ORDER_DESC && src_data[dst_data[j]] < src_data[dst_data[k]])) {
15056
- int32_t tmp = dst_data[j];
15057
- dst_data[j] = dst_data[k];
15058
- dst_data[k] = tmp;
15059
- }
15060
- }
15061
- }
15062
- }
15063
- }
15064
-
15065
- static void ggml_compute_forward_argsort(
15066
- const struct ggml_compute_params * params,
15067
- struct ggml_tensor * dst) {
15068
-
15069
- const struct ggml_tensor * src0 = dst->src[0];
15070
-
15071
- switch (src0->type) {
15072
- case GGML_TYPE_F32:
15073
- {
15074
- ggml_compute_forward_argsort_f32(params, dst);
15075
- } break;
15076
- default:
15077
- {
15078
- GGML_ASSERT(false);
15079
- } break;
15080
- }
15081
- }
15082
-
15083
- // ggml_compute_forward_flash_attn
15084
-
15085
- static void ggml_compute_forward_flash_attn_f32(
15086
- const struct ggml_compute_params * params,
15087
- const bool masked,
15088
- struct ggml_tensor * dst) {
15089
-
15090
- const struct ggml_tensor * q = dst->src[0];
15091
- const struct ggml_tensor * k = dst->src[1];
15092
- const struct ggml_tensor * v = dst->src[2];
15093
-
15094
- int64_t t0 = ggml_perf_time_us();
15095
- UNUSED(t0);
15096
-
15097
- GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
15098
- GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
15099
- GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
15100
- GGML_TENSOR_LOCALS(size_t, nbk, k, nb)
15101
- GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
15102
- GGML_TENSOR_LOCALS(size_t, nbv, v, nb)
15103
- GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
15104
- GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
15105
-
15106
- const int ith = params->ith;
15107
- const int nth = params->nth;
15108
-
15109
- const int64_t D = neq0;
15110
- const int64_t N = neq1;
15111
- const int64_t P = nek1 - N;
15112
- const int64_t M = P + N;
15113
-
15114
- const int Mup = ggml_up(M, GGML_SOFT_MAX_UNROLL);
15115
-
15116
- GGML_ASSERT(ne0 == D);
15117
- GGML_ASSERT(ne1 == N);
15118
- GGML_ASSERT(P >= 0);
15119
-
15120
- GGML_ASSERT(nbq0 == sizeof(float));
15121
- GGML_ASSERT(nbk0 == sizeof(float));
15122
- GGML_ASSERT(nbv0 == sizeof(float));
15123
-
15124
- GGML_ASSERT(neq0 == D);
15125
- GGML_ASSERT(nek0 == D);
15126
- GGML_ASSERT(nev1 == D);
15127
-
15128
- GGML_ASSERT(neq1 == N);
15129
- GGML_ASSERT(nek1 == N + P);
15130
- GGML_ASSERT(nev1 == D);
15131
-
15132
- // dst cannot be transposed or permuted
15133
- GGML_ASSERT(nb0 == sizeof(float));
15134
- GGML_ASSERT(nb0 <= nb1);
15135
- GGML_ASSERT(nb1 <= nb2);
15136
- GGML_ASSERT(nb2 <= nb3);
15137
-
15138
- if (params->type == GGML_TASK_TYPE_INIT) {
15139
- return;
15140
- }
15141
-
15142
- if (params->type == GGML_TASK_TYPE_FINALIZE) {
15143
- return;
15144
- }
15145
-
15146
- // parallelize by q rows using ggml_vec_dot_f32
15147
-
15148
- // total rows in q
15149
- const int nr = neq1*neq2*neq3;
15150
-
15151
- // rows per thread
15152
- const int dr = (nr + nth - 1)/nth;
15153
-
15154
- // row range for this thread
15155
- const int ir0 = dr*ith;
15156
- const int ir1 = MIN(ir0 + dr, nr);
15157
-
15158
- const float scale = 1.0f/sqrtf(D);
15159
-
15160
- //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale);
15161
-
15162
- for (int ir = ir0; ir < ir1; ++ir) {
15163
- // q indices
15164
- const int iq3 = ir/(neq2*neq1);
15165
- const int iq2 = (ir - iq3*neq2*neq1)/neq1;
15166
- const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
15167
-
15168
- float * S = (float *) params->wdata + ith*(Mup + CACHE_LINE_SIZE_F32);
15169
-
15170
- for (int i = M; i < Mup; ++i) {
15171
- S[i] = -INFINITY;
15172
- }
15173
-
15174
- const int64_t masked_begin = masked ? (P + iq1 + 1) : M;
15175
- for (int64_t ic = 0; ic < masked_begin; ++ic) {
15176
- // k indices
15177
- const int ik3 = iq3;
15178
- const int ik2 = iq2 % nek2;
15179
- const int ik1 = ic;
15180
-
15181
- // S indices
15182
- const int i1 = ik1;
15183
-
15184
- ggml_vec_dot_f32(neq0,
15185
- S + i1, 0,
15186
- (float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), 0,
15187
- (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)), 0, 1);
15188
- }
15189
-
15190
- // scale
15191
- ggml_vec_scale_f32(masked_begin, S, scale);
15192
-
15193
- for (int64_t i = masked_begin; i < M; i++) {
15194
- S[i] = -INFINITY;
15195
- }
15196
-
15197
- // softmax
15198
- // exclude known -INF S[..] values from max and loop
15199
- // dont forget to set their SW values to zero
15200
- {
15201
- float max = -INFINITY;
15202
- ggml_vec_max_f32(masked_begin, &max, S);
15203
-
15204
- ggml_float sum = 0.0;
15205
- {
15206
- #ifdef GGML_SOFT_MAX_ACCELERATE
15207
- max = -max;
15208
- vDSP_vsadd(S, 1, &max, S, 1, Mup);
15209
- vvexpf(S, S, &Mup);
15210
- ggml_vec_sum_f32(Mup, &sum, S);
15211
- #else
15212
- uint16_t scvt[GGML_SOFT_MAX_UNROLL]; UNUSED(scvt);
15213
- ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 };
15214
-
15215
- for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) {
15216
- if (i >= masked_begin) {
15217
- break;
15218
- }
15219
- float * SS = S + i;
15220
-
15221
- for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) {
15222
- if (i + j >= masked_begin) {
15223
- break;
15224
- } else if (SS[j] == -INFINITY) {
15225
- SS[j] = 0.0f;
15226
- } else {
15227
- #ifndef GGML_FLASH_ATTN_EXP_FP16
15228
- const float val = expf(SS[j] - max);
15229
- #else
15230
- ggml_fp16_t s = GGML_FP32_TO_FP16(SS[j] - max);
15231
- memcpy(&scvt[j], &s, sizeof(uint16_t));
15232
- const float val = GGML_FP16_TO_FP32(ggml_table_exp_f16[scvt[j]]);
15233
- #endif
15234
- sump[j] += (ggml_float)val;
15235
- SS[j] = val;
15236
- }
15237
- }
15238
- }
15239
-
15240
- for (int i = 0; i < GGML_SOFT_MAX_UNROLL; i++) {
15241
- sum += sump[i];
15242
- }
15243
- #endif
15244
- }
15245
-
15246
- assert(sum > 0.0);
15247
-
15248
- sum = 1.0/sum;
15249
- ggml_vec_scale_f32(masked_begin, S, sum);
15250
-
15251
- #ifndef NDEBUG
15252
- for (int i = 0; i < masked_begin; ++i) {
15253
- assert(!isnan(S[i]));
15254
- assert(!isinf(S[i]));
15255
- }
15256
- #endif
15257
- }
15258
-
15259
- for (int64_t ic = 0; ic < nev1; ++ic) {
15260
- // dst indices
15261
- const int i1 = iq1;
15262
- const int i2 = iq2;
15263
- const int i3 = iq3;
15264
-
15265
- // v indices
15266
- const int iv2 = iq2 % nev2;
15267
- const int iv3 = iq3;
15268
-
15269
- ggml_vec_dot_f32(masked_begin,
15270
- (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), 0,
15271
- (float *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)), 0,
15272
- S, 0, 1);
15273
- }
15274
- }
15275
- }
15276
-
15277
- static void ggml_compute_forward_flash_attn_f16(
15278
- const struct ggml_compute_params * params,
15279
- const bool masked,
15280
- struct ggml_tensor * dst) {
15281
-
15282
- const struct ggml_tensor * q = dst->src[0];
15283
- const struct ggml_tensor * k = dst->src[1];
15284
- const struct ggml_tensor * v = dst->src[2];
15285
-
15286
- int64_t t0 = ggml_perf_time_us();
15287
- UNUSED(t0);
15288
-
15289
- GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
15290
- GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
15291
- GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
15292
- GGML_TENSOR_LOCALS(size_t, nbk, k, nb)
15293
- GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
15294
- GGML_TENSOR_LOCALS(size_t, nbv, v, nb)
15295
- GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
15296
- GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
15297
-
15298
- const int ith = params->ith;
15299
- const int nth = params->nth;
15300
-
15301
- const int64_t D = neq0;
15302
- const int64_t N = neq1;
15303
- const int64_t P = nek1 - N;
15304
- const int64_t M = P + N;
15305
-
15306
- const int Mup = ggml_up(M, GGML_SOFT_MAX_UNROLL);
15307
-
15308
- GGML_ASSERT(ne0 == D);
15309
- GGML_ASSERT(ne1 == N);
15310
- GGML_ASSERT(P >= 0);
15311
-
15312
- GGML_ASSERT(nbq0 == sizeof(ggml_fp16_t));
15313
- GGML_ASSERT(nbk0 == sizeof(ggml_fp16_t));
15314
- GGML_ASSERT(nbv0 == sizeof(ggml_fp16_t));
15315
-
15316
- GGML_ASSERT(neq0 == D);
15317
- GGML_ASSERT(nek0 == D);
15318
- GGML_ASSERT(nev1 == D);
15319
-
15320
- GGML_ASSERT(neq1 == N);
15321
- GGML_ASSERT(nek1 == N + P);
15322
- GGML_ASSERT(nev1 == D);
15323
-
15324
- // dst cannot be transposed or permuted
15325
- GGML_ASSERT(nb0 == sizeof(float));
15326
- GGML_ASSERT(nb0 <= nb1);
15327
- GGML_ASSERT(nb1 <= nb2);
15328
- GGML_ASSERT(nb2 <= nb3);
15329
-
15330
- if (params->type == GGML_TASK_TYPE_INIT) {
15331
- return;
15332
- }
15333
-
15334
- if (params->type == GGML_TASK_TYPE_FINALIZE) {
15335
- return;
15336
- }
15337
-
15338
- // parallelize by q rows using ggml_vec_dot_f32
15339
-
15340
- // total rows in q
15341
- const int nr = neq1*neq2*neq3;
15342
-
15343
- // rows per thread
15344
- const int dr = (nr + nth - 1)/nth;
15345
-
15346
- // row range for this thread
15347
- const int ir0 = dr*ith;
15348
- const int ir1 = MIN(ir0 + dr, nr);
15349
-
15350
- const float scale = 1.0f/sqrtf(D);
15351
-
15352
- //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale);
15353
-
15354
- for (int ir = ir0; ir < ir1; ++ir) {
15355
- // q indices
15356
- const int iq3 = ir/(neq2*neq1);
15357
- const int iq2 = (ir - iq3*neq2*neq1)/neq1;
15358
- const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
15359
-
15360
- float * S = (float *) params->wdata + ith*(2*Mup + CACHE_LINE_SIZE_F32);
15361
-
15362
- for (int i = M; i < Mup; ++i) {
15363
- S[i] = -INFINITY;
15364
- }
15365
-
15366
- if (GGML_VEC_DOT_UNROLL > 2 || nek1 % GGML_VEC_DOT_UNROLL != 0) {
15367
- for (int64_t ic = 0; ic < nek1; ++ic) {
15368
- // k indices
15369
- const int ik3 = iq3;
15370
- const int ik2 = iq2 % nek2;
15371
- const int ik1 = ic;
15372
-
15373
- // S indices
15374
- const int i1 = ik1;
15375
-
15376
- ggml_vec_dot_f16(neq0,
15377
- S + i1, 0,
15378
- (ggml_fp16_t *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), 0,
15379
- (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)), 0, 1);
15380
- }
15381
- } else {
15382
- for (int64_t ic = 0; ic < nek1; ic += GGML_VEC_DOT_UNROLL) {
15383
- // k indices
15384
- const int ik3 = iq3;
15385
- const int ik2 = iq2 % nek2;
15386
- const int ik1 = ic;
15387
-
15388
- // S indices
15389
- const int i1 = ik1;
15390
-
15391
- ggml_vec_dot_f16_unroll(neq0, nbk1,
15392
- S + i1,
15393
- ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)),
15394
- (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
15395
- }
15396
- }
15397
-
15398
- // scale
15399
- ggml_vec_scale_f32(nek1, S, scale);
15400
-
15401
- if (masked) {
15402
- for (int64_t i = P; i < M; i++) {
15403
- if (i > P + iq1) {
15404
- S[i] = -INFINITY;
15405
- }
15406
- }
15407
- }
15408
-
15409
- // softmax
15410
- // todo: exclude known -INF S[..] values from max and loop, assuming their results to be zero.
15411
- // dont forget to set their S values to zero
15412
- {
15413
- float max = -INFINITY;
15414
- ggml_vec_max_f32(M, &max, S);
15415
-
15416
- ggml_float sum = 0.0;
15417
- {
15418
- #ifdef GGML_SOFT_MAX_ACCELERATE
15419
- max = -max;
15420
- vDSP_vsadd(S, 1, &max, S, 1, Mup);
15421
- vvexpf(S, S, &Mup);
15422
- ggml_vec_sum_f32(Mup, &sum, S);
15423
- #else
15424
- uint16_t scvt[GGML_SOFT_MAX_UNROLL];
15425
- ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 };
15606
+ const struct ggml_tensor * src0 = dst->src[0];
15426
15607
 
15427
- for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) {
15428
- float * SS = S + i;
15608
+ if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
15609
+ return;
15610
+ }
15429
15611
 
15430
- for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) {
15431
- if (SS[j] == -INFINITY) {
15432
- SS[j] = 0.0f;
15433
- } else {
15434
- ggml_fp16_t s = GGML_FP32_TO_FP16(SS[j] - max);
15435
- memcpy(&scvt[j], &s, sizeof(uint16_t));
15436
- const float val = GGML_FP16_TO_FP32(ggml_table_exp_f16[scvt[j]]);
15437
- sump[j] += (ggml_float)val;
15438
- SS[j] = val;
15439
- }
15440
- }
15441
- }
15612
+ GGML_TENSOR_UNARY_OP_LOCALS
15442
15613
 
15443
- for (int i = 0; i < GGML_SOFT_MAX_UNROLL; i++) {
15444
- sum += sump[i];
15445
- }
15446
- #endif
15447
- }
15614
+ GGML_ASSERT(nb0 == sizeof(float));
15448
15615
 
15449
- assert(sum > 0.0);
15616
+ const int ith = params->ith;
15617
+ const int nth = params->nth;
15450
15618
 
15451
- sum = 1.0/sum;
15452
- ggml_vec_scale_f32(M, S, sum);
15619
+ const int64_t nr = ggml_nrows(src0);
15453
15620
 
15454
- #ifndef NDEBUG
15455
- for (int i = 0; i < M; ++i) {
15456
- assert(!isnan(S[i]));
15457
- assert(!isinf(S[i]));
15458
- }
15459
- #endif
15460
- }
15621
+ enum ggml_sort_order order = (enum ggml_sort_order) ggml_get_op_params_i32(dst, 0);
15461
15622
 
15462
- ggml_fp16_t * S16 = (ggml_fp16_t *) ((float *) params->wdata + ith*(2*Mup + CACHE_LINE_SIZE_F32) + Mup);
15623
+ for (int64_t i = ith; i < nr; i += nth) {
15624
+ int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1);
15625
+ const float * src_data = (float *)((char *) src0->data + i*nb01);
15463
15626
 
15464
- for (int64_t i = 0; i < M; i++) {
15465
- S16[i] = GGML_FP32_TO_FP16(S[i]);
15627
+ for (int64_t j = 0; j < ne0; j++) {
15628
+ dst_data[j] = j;
15466
15629
  }
15467
15630
 
15468
- // todo: exclude known zero S[..] values from dot (reducing nev0 and increasing begin of v and S16).
15469
- if (GGML_VEC_DOT_UNROLL == 1 || (nev1 % GGML_VEC_DOT_UNROLL != 0)) {
15470
- for (int64_t ic = 0; ic < nev1; ++ic) {
15471
- // dst indices
15472
- const int i1 = iq1;
15473
- const int i2 = iq2;
15474
- const int i3 = iq3;
15475
-
15476
- // v indices
15477
- const int iv2 = iq2 % nev2;
15478
- const int iv3 = iq3;
15479
-
15480
- ggml_vec_dot_f16(nev0,
15481
- (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), 0,
15482
- (ggml_fp16_t *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)), 0,
15483
- S16, 0, 1);
15484
- }
15485
- } else {
15486
- for (int64_t ic = 0; ic < nev1; ic += GGML_VEC_DOT_UNROLL) {
15487
- // dst indices
15488
- const int i1 = iq1;
15489
- const int i2 = iq2;
15490
- const int i3 = iq3;
15491
-
15492
- // v indices
15493
- const int iv2 = iq2 % nev2;
15494
- const int iv3 = iq3;
15495
-
15496
- ggml_vec_dot_f16_unroll(nev0, nbv1,
15497
- (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
15498
- ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)),
15499
- S16);
15631
+ // C doesn't have a functional sort, so we do a bubble sort instead
15632
+ for (int64_t j = 0; j < ne0; j++) {
15633
+ for (int64_t k = j + 1; k < ne0; k++) {
15634
+ if ((order == GGML_SORT_ORDER_ASC && src_data[dst_data[j]] > src_data[dst_data[k]]) ||
15635
+ (order == GGML_SORT_ORDER_DESC && src_data[dst_data[j]] < src_data[dst_data[k]])) {
15636
+ int32_t tmp = dst_data[j];
15637
+ dst_data[j] = dst_data[k];
15638
+ dst_data[k] = tmp;
15639
+ }
15500
15640
  }
15501
15641
  }
15502
15642
  }
15503
15643
  }
15504
15644
 
15505
- static void ggml_compute_forward_flash_attn(
15506
- const struct ggml_compute_params * params,
15507
- const bool masked,
15508
- struct ggml_tensor * dst) {
15645
+ static void ggml_compute_forward_argsort(
15646
+ const struct ggml_compute_params * params,
15647
+ struct ggml_tensor * dst) {
15509
15648
 
15510
- const struct ggml_tensor * q = dst->src[0];
15649
+ const struct ggml_tensor * src0 = dst->src[0];
15511
15650
 
15512
- switch (q->type) {
15513
- case GGML_TYPE_F16:
15514
- {
15515
- ggml_compute_forward_flash_attn_f16(params, masked, dst);
15516
- } break;
15651
+ switch (src0->type) {
15517
15652
  case GGML_TYPE_F32:
15518
15653
  {
15519
- ggml_compute_forward_flash_attn_f32(params, masked, dst);
15654
+ ggml_compute_forward_argsort_f32(params, dst);
15520
15655
  } break;
15521
15656
  default:
15522
15657
  {
@@ -15555,9 +15690,10 @@ static void ggml_compute_forward_flash_attn_ext_f16(
15555
15690
  GGML_ASSERT(ne0 == D);
15556
15691
  GGML_ASSERT(ne2 == N);
15557
15692
 
15558
- GGML_ASSERT(nbq0 == sizeof(float));
15559
- GGML_ASSERT(nbk0 == sizeof(ggml_fp16_t));
15560
- GGML_ASSERT(nbv0 == sizeof(ggml_fp16_t));
15693
+ // input tensor rows must be contiguous
15694
+ GGML_ASSERT(nbq0 == ggml_type_size(q->type));
15695
+ GGML_ASSERT(nbk0 == ggml_type_size(k->type));
15696
+ GGML_ASSERT(nbv0 == ggml_type_size(v->type));
15561
15697
 
15562
15698
  GGML_ASSERT(neq0 == D);
15563
15699
  GGML_ASSERT(nek0 == D);
@@ -15611,6 +15747,11 @@ static void ggml_compute_forward_flash_attn_ext_f16(
15611
15747
  const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
15612
15748
  const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
15613
15749
 
15750
+ enum ggml_type const k_vec_dot_type = type_traits[k->type].vec_dot_type;
15751
+ ggml_from_float_t const q_to_vec_dot = type_traits[k_vec_dot_type].from_float;
15752
+ ggml_vec_dot_t const kq_vec_dot = type_traits[k->type].vec_dot;
15753
+ ggml_to_float_t const v_to_float = type_traits[v->type].to_float;
15754
+
15614
15755
  // loop over n_batch and n_head
15615
15756
  for (int ir = ir0; ir < ir1; ++ir) {
15616
15757
  // q indices
@@ -15618,17 +15759,22 @@ static void ggml_compute_forward_flash_attn_ext_f16(
15618
15759
  const int iq2 = (ir - iq3*neq2*neq1)/neq1;
15619
15760
  const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
15620
15761
 
15621
- const uint32_t h = iq2; // head
15762
+ const uint32_t h = iq2; // head index
15622
15763
  const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
15623
15764
 
15624
- float S = 0.0f;
15625
- float M = -INFINITY;
15765
+ float S = 0.0f; // sum
15766
+ float M = -INFINITY; // maximum KQ value
15626
15767
 
15627
- float * V32 = (float *) params->wdata + ith*(2*D + CACHE_LINE_SIZE_F32);
15628
- ggml_fp16_t * Q16 = (ggml_fp16_t *) (V32); // reuse memory
15629
- ggml_fp16_t * V16 = (ggml_fp16_t *) (V32 + D);
15768
+ float * VKQ32 = (float *) params->wdata + ith*(3*D + CACHE_LINE_SIZE_F32); // FP32 VKQ accumulator
15769
+ float * V32 = (VKQ32 + 1*D); // (temporary) FP32 V buffer
15770
+ ggml_fp16_t * VKQ16 = (ggml_fp16_t *) (VKQ32 + 1*D); // (temporary) FP16 VKQ accumulator
15771
+ ggml_fp16_t * Q_q = (ggml_fp16_t *) (VKQ32 + 2*D); // (temporary) buffer for Q converted to quantized/FP16
15630
15772
 
15631
- memset(V16, 0, D*sizeof(ggml_fp16_t));
15773
+ if (v->type == GGML_TYPE_F16) {
15774
+ memset(VKQ16, 0, D*sizeof(ggml_fp16_t));
15775
+ } else {
15776
+ memset(VKQ32, 0, D*sizeof(float));
15777
+ }
15632
15778
 
15633
15779
  const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL;
15634
15780
 
@@ -15640,6 +15786,9 @@ static void ggml_compute_forward_flash_attn_ext_f16(
15640
15786
  const int iv3 = iq3 / rv3;
15641
15787
  const int iv2 = iq2 / rv2;
15642
15788
 
15789
+ const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3));
15790
+ q_to_vec_dot(pq, Q_q, D);
15791
+
15643
15792
  // online softmax / attention
15644
15793
  // loop over n_kv and n_head_kv
15645
15794
  // ref: https://arxiv.org/pdf/2112.05682.pdf
@@ -15649,52 +15798,67 @@ static void ggml_compute_forward_flash_attn_ext_f16(
15649
15798
  continue;
15650
15799
  }
15651
15800
 
15652
- float s;
15801
+ float s; // KQ value
15653
15802
 
15654
- // convert Q to F16 in V32
15655
- {
15656
- const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3));
15803
+ const char * k_data = (const char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3);
15804
+ kq_vec_dot(D, &s, 0, k_data, 0, Q_q, 0, 1);
15657
15805
 
15658
- for (int64_t d = 0; d < D; ++d) {
15659
- Q16[d] = GGML_FP32_TO_FP16(pq[d]);
15660
- }
15661
- }
15806
+ s = s*scale + mv; // scale KQ value and apply mask
15662
15807
 
15663
- ggml_vec_dot_f16(D,
15664
- &s, 0,
15665
- (ggml_fp16_t *) ((char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3)), 0,
15666
- Q16, 0, 1);
15808
+ const float Mold = M;
15667
15809
 
15668
- s = s*scale + mv;
15810
+ float ms = 1.0f; // upon new higher max val, scale VKQ and KQ sum with this value
15811
+ float vs = 1.0f; // post-softmax KQ value, expf(s - M)
15669
15812
 
15670
- const float Mold = M;
15813
+ const char * v_data = ((const char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3));
15671
15814
 
15672
- float ms = 1.0f;
15673
- float vs = 1.0f;
15815
+ if (v->type== GGML_TYPE_F16) {
15816
+ if (s > M) {
15817
+ // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
15818
+ M = s;
15819
+ ms = expf(Mold - M);
15674
15820
 
15675
- if (s > M) {
15676
- M = s;
15677
- ms = expf(Mold - M);
15821
+ // V = V*expf(Mold - M)
15822
+ ggml_vec_scale_f16(D, VKQ16, ms);
15823
+ } else {
15824
+ // no new maximum, ms == 1.0f, vs != 1.0f
15825
+ vs = expf(s - M);
15826
+ }
15678
15827
 
15679
- // V = V*expf(Mold - M)
15680
- ggml_vec_scale_f16(D, V16, ms);
15828
+ // V += v*expf(s - M)
15829
+ ggml_vec_mad_f16(D, VKQ16, (const ggml_fp16_t *) v_data, vs);
15681
15830
  } else {
15682
- vs = expf(s - M);
15683
- }
15831
+ if (s > M) {
15832
+ // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
15833
+ M = s;
15834
+ ms = expf(Mold - M);
15835
+
15836
+ // V = V*expf(Mold - M)
15837
+ ggml_vec_scale_f32(D, VKQ32, ms);
15838
+ } else {
15839
+ // no new maximum, ms == 1.0f, vs != 1.0f
15840
+ vs = expf(s - M);
15841
+ }
15684
15842
 
15685
- const ggml_fp16_t * v16 = (const ggml_fp16_t *) ((char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3));
15843
+ v_to_float(v_data, V32, D);
15686
15844
 
15687
- // V += v*expf(s - M)
15688
- ggml_vec_mad_f16(D, V16, v16, vs);
15845
+ // V += v*expf(s - M)
15846
+ ggml_vec_mad_f32(D, VKQ32, V32, vs);
15847
+ }
15689
15848
 
15690
- S = S*ms + vs;
15849
+ S = S*ms + vs; // scale and increment sum with partial sum
15691
15850
  }
15692
15851
 
15693
- // V /= S
15694
- for (int64_t d = 0; d < D; ++d) {
15695
- V32[d] = GGML_FP16_TO_FP32(V16[d])/S;
15852
+ if (v->type == GGML_TYPE_F16) {
15853
+ for (int64_t d = 0; d < D; ++d) {
15854
+ VKQ32[d] = GGML_FP16_TO_FP32(VKQ16[d]);
15855
+ }
15696
15856
  }
15697
15857
 
15858
+ // V /= S
15859
+ const float S_inv = 1.0f/S;
15860
+ ggml_vec_scale_f32(D, VKQ32, S_inv);
15861
+
15698
15862
  // dst indices
15699
15863
  const int i1 = iq1;
15700
15864
  const int i2 = iq2;
@@ -15704,7 +15868,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
15704
15868
  //memcpy((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3), V, nev0*sizeof(float));
15705
15869
 
15706
15870
  // permute(0, 2, 1, 3)
15707
- memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, V32, nb1);
15871
+ memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32, nb1);
15708
15872
  }
15709
15873
  }
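The reworked flash_attn_ext_f16 keeps a running maximum M and running sum S, rescaling the accumulated V*softmax(KQ) by expf(Mold - M) whenever a larger score shows up; this is the online-softmax scheme from the paper referenced in the code (arxiv.org/pdf/2112.05682.pdf). A scalar sketch of that update rule, detached from the tensor layout and the F16/F32 split, with a plain two-pass softmax as the cross-check:

// Scalar sketch of the online-softmax accumulation used in flash_attn_ext above.
#include <math.h>
#include <stdio.h>

int main(void) {
    // made-up attention scores s[i] and values v[i] for one query/output element
    const float s[5] = { 0.2f, 1.7f, -0.3f, 2.4f, 0.9f };
    const float v[5] = { 1.0f, 2.0f,  3.0f, 4.0f, 5.0f };

    float M   = -INFINITY; // running maximum score
    float S   = 0.0f;      // running sum of exp(s - M)
    float VKQ = 0.0f;      // running sum of v * exp(s - M)

    for (int i = 0; i < 5; ++i) {
        float ms = 1.0f;   // rescale factor for the old accumulators
        float vs = 1.0f;   // weight of the incoming value
        if (s[i] > M) {
            const float Mold = M;
            M   = s[i];
            ms  = expf(Mold - M); // shrink what was accumulated under the old maximum
            VKQ *= ms;
        } else {
            vs = expf(s[i] - M);
        }
        VKQ += v[i]*vs;
        S    = S*ms + vs;
    }
    printf("online   : %f\n", VKQ/S);

    // reference: two-pass softmax-weighted average, should print the same value
    float max = -INFINITY, sum = 0.0f, ref = 0.0f;
    for (int i = 0; i < 5; ++i) if (s[i] > max) max = s[i];
    for (int i = 0; i < 5; ++i) { const float w = expf(s[i] - max); sum += w; ref += w*v[i]; }
    printf("reference: %f\n", ref/sum);
    return 0;
}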
15710
15874
 
@@ -15729,165 +15893,6 @@ static void ggml_compute_forward_flash_attn_ext(
15729
15893
  }
15730
15894
  }
15731
15895
 
15732
- // ggml_compute_forward_flash_ff
15733
-
15734
- static void ggml_compute_forward_flash_ff_f16(
15735
- const struct ggml_compute_params * params,
15736
- struct ggml_tensor * dst) {
15737
-
15738
- const struct ggml_tensor * a = dst->src[0]; // F16
15739
- const struct ggml_tensor * b0 = dst->src[1]; // F16 fc_w
15740
- const struct ggml_tensor * b1 = dst->src[2]; // F32 fc_b
15741
- const struct ggml_tensor * c0 = dst->src[3]; // F16 proj_w
15742
- const struct ggml_tensor * c1 = dst->src[4]; // F32 proj_b
15743
-
15744
- int64_t t0 = ggml_perf_time_us();
15745
- UNUSED(t0);
15746
-
15747
- GGML_TENSOR_LOCALS(int64_t, nea, a, ne)
15748
- GGML_TENSOR_LOCALS(size_t, nba, a, nb)
15749
- GGML_TENSOR_LOCALS(int64_t, neb0, b0, ne)
15750
- GGML_TENSOR_LOCALS(size_t, nbb0, b0, nb)
15751
- GGML_TENSOR_LOCALS(int64_t, neb1, b1, ne)
15752
- GGML_TENSOR_LOCALS(size_t, nbb1, b1, nb)
15753
- GGML_TENSOR_LOCALS(int64_t, nec0, c0, ne)
15754
- GGML_TENSOR_LOCALS(size_t, nbc0, c0, nb)
15755
- GGML_TENSOR_LOCALS(int64_t, nec1, c1, ne)
15756
- GGML_TENSOR_LOCALS(size_t, nbc1, c1, nb)
15757
- GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
15758
- GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
15759
-
15760
- const int ith = params->ith;
15761
- const int nth = params->nth;
15762
-
15763
- const int64_t D = nea0;
15764
- //const int64_t N = nea1;
15765
- const int64_t M = neb01;
15766
-
15767
- GGML_ASSERT(ne0 == nea0);
15768
- GGML_ASSERT(ne1 == nea1);
15769
- GGML_ASSERT(ne2 == nea2);
15770
-
15771
- GGML_ASSERT(nba0 == sizeof(ggml_fp16_t));
15772
- GGML_ASSERT(nbb00 == sizeof(ggml_fp16_t));
15773
- GGML_ASSERT(nbb10 == sizeof(float));
15774
- GGML_ASSERT(nbc00 == sizeof(ggml_fp16_t));
15775
- GGML_ASSERT(nbc10 == sizeof(float));
15776
-
15777
- GGML_ASSERT(neb00 == D);
15778
- GGML_ASSERT(neb01 == M);
15779
- GGML_ASSERT(neb10 == M);
15780
- GGML_ASSERT(neb11 == 1);
15781
-
15782
- GGML_ASSERT(nec00 == M);
15783
- GGML_ASSERT(nec01 == D);
15784
- GGML_ASSERT(nec10 == D);
15785
- GGML_ASSERT(nec11 == 1);
15786
-
15787
- // dst cannot be transposed or permuted
15788
- GGML_ASSERT(nb0 == sizeof(float));
15789
- GGML_ASSERT(nb0 <= nb1);
15790
- GGML_ASSERT(nb1 <= nb2);
15791
- GGML_ASSERT(nb2 <= nb3);
15792
-
15793
- if (params->type == GGML_TASK_TYPE_INIT) {
15794
- return;
15795
- }
15796
-
15797
- if (params->type == GGML_TASK_TYPE_FINALIZE) {
15798
- return;
15799
- }
15800
-
15801
- // parallelize by a rows using ggml_vec_dot_f32
15802
-
15803
- // total rows in a
15804
- const int nr = nea1*nea2*nea3;
15805
-
15806
- // rows per thread
15807
- const int dr = (nr + nth - 1)/nth;
15808
-
15809
- // row range for this thread
15810
- const int ir0 = dr*ith;
15811
- const int ir1 = MIN(ir0 + dr, nr);
15812
-
15813
- for (int ir = ir0; ir < ir1; ++ir) {
15814
- // a indices
15815
- const int ia3 = ir/(nea2*nea1);
15816
- const int ia2 = (ir - ia3*nea2*nea1)/nea1;
15817
- const int ia1 = (ir - ia3*nea2*nea1 - ia2*nea1);
15818
-
15819
- float * S = (float *) params->wdata + ith*(2*M + CACHE_LINE_SIZE_F32);
15820
-
15821
- for (int64_t ic = 0; ic < neb01; ++ic) {
15822
- // b0 indices
15823
- const int ib03 = ia3;
15824
- const int ib02 = ia2;
15825
- const int ib01 = ic;
15826
-
15827
- // S indices
15828
- const int i1 = ib01;
15829
-
15830
- ggml_vec_dot_f16(nea0,
15831
- S + i1, 0,
15832
- (ggml_fp16_t *) ((char *) b0->data + (ib01*nbb01 + ib02*nbb02 + ib03*nbb03)), 0,
15833
- (ggml_fp16_t *) ((char *) a->data + ( ia1*nba1 + ia2*nba2 + ia3*nba3)), 0, 1);
15834
- }
15835
-
15836
- ggml_vec_add_f32(neb01, S, S, (float *) b1->data);
15837
- //ggml_vec_gelu_f32(neb01, S, S);
15838
-
15839
- ggml_fp16_t * S16 = (ggml_fp16_t *) ((float *) params->wdata + ith*(2*M + CACHE_LINE_SIZE_F32) + M);
15840
-
15841
- for (int64_t i = 0; i < M; i++) {
15842
- S16[i] = GGML_FP32_TO_FP16(S[i]);
15843
- }
15844
-
15845
- ggml_vec_gelu_f16(neb01, S16, S16);
15846
-
15847
- {
15848
- // dst indices
15849
- const int i1 = ia1;
15850
- const int i2 = ia2;
15851
- const int i3 = ia3;
15852
-
15853
- for (int64_t ic = 0; ic < nec01; ++ic) {
15854
-
15855
- ggml_vec_dot_f16(neb01,
15856
- (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), 0,
15857
- (ggml_fp16_t *) ((char *) c0->data + ( ic*nbc01 + i2*nbc02 + i3*nbc03)), 0,
15858
- S16, 0, 1);
15859
- }
15860
-
15861
- ggml_vec_add_f32(nec01,
15862
- (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3)),
15863
- (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3)),
15864
- (float *) c1->data);
15865
- }
15866
- }
15867
- }
15868
-
15869
- static void ggml_compute_forward_flash_ff(
15870
- const struct ggml_compute_params * params,
15871
- struct ggml_tensor * dst) {
15872
-
15873
- const struct ggml_tensor * b0 = dst->src[1];
15874
-
15875
- switch (b0->type) {
15876
- case GGML_TYPE_F16:
15877
- {
15878
- ggml_compute_forward_flash_ff_f16(params, dst);
15879
- } break;
15880
- case GGML_TYPE_F32:
15881
- {
15882
- GGML_ASSERT(false); // TODO
15883
- } break;
15884
- default:
15885
- {
15886
- GGML_ASSERT(false);
15887
- } break;
15888
- }
15889
- }
15890
-
15891
15896
  // ggml_compute_forward_flash_attn_back
15892
15897
 
15893
15898
  static void ggml_compute_forward_flash_attn_back_f32(
@@ -16069,38 +16074,7 @@ static void ggml_compute_forward_flash_attn_back_f32(
16069
16074
  vvexpf(SM, SM, &Mup);
16070
16075
  ggml_vec_sum_f32(Mup, &sum, SM);
16071
16076
  #else
16072
- uint16_t scvt[GGML_SOFT_MAX_UNROLL]; UNUSED(scvt);
16073
- ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 };
16074
-
16075
- for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) {
16076
- if (i >= masked_begin) {
16077
- break;
16078
- }
16079
- float * SR = S + i;
16080
- float * SW = SM + i;
16081
-
16082
- for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) {
16083
- if (i + j >= masked_begin) {
16084
- break;
16085
- } else if (SR[j] == -INFINITY) {
16086
- SW[j] = 0.0f;
16087
- } else {
16088
- #ifndef GGML_FLASH_ATTN_EXP_FP16
16089
- const float val = expf(SR[j] - max);
16090
- #else
16091
- ggml_fp16_t s = GGML_FP32_TO_FP16(SR[j] - max);
16092
- memcpy(&scvt[j], &s, sizeof(uint16_t));
16093
- const float val = GGML_FP16_TO_FP32(ggml_table_exp_f16[scvt[j]]);
16094
- #endif
16095
- sump[j] += (ggml_float)val;
16096
- SW[j] = val;
16097
- }
16098
- }
16099
- }
16100
-
16101
- for (int i = 0; i < GGML_SOFT_MAX_UNROLL; i++) {
16102
- sum += sump[i];
16103
- }
16077
+ sum = ggml_vec_soft_max_f32(Mup, SM, S, max);
16104
16078
  #endif
16105
16079
  }
16106
16080
 
@@ -17126,35 +17100,15 @@ static void ggml_compute_forward_cross_entropy_loss_f32(
17126
17100
  assert(!isnan(s1[i]));
17127
17101
  }
17128
17102
  #endif
17129
- // soft_max
17130
- ggml_float sum = 0.0;
17131
- {
17132
- float max = -INFINITY;
17133
- ggml_vec_max_f32(nc, &max, s0);
17134
17103
 
17135
- uint16_t scvt; UNUSED(scvt);
17136
- for (int i = 0; i < nc; i++) {
17137
- if (s0[i] == -INFINITY) {
17138
- st[i] = 0.0f;
17139
- } else {
17140
- #ifndef GGML_CROSS_ENTROPY_EXP_FP16
17141
- const float s = s0[i] - max;
17142
- const float val = expf(s);
17143
- #else
17144
- ggml_fp16_t s = GGML_FP32_TO_FP16(s0[i] - max);
17145
- memcpy(&scvt, &s, sizeof(scvt));
17146
- const float val = GGML_FP16_TO_FP32(ggml_table_exp_f16[scvt]);
17147
- #endif
17148
- sum += (ggml_float)val;
17149
- st[i] = val;
17150
- }
17151
- }
17104
+ // soft_max
17105
+ float max = -INFINITY;
17106
+ ggml_vec_max_f32(nc, &max, s0);
17107
+ ggml_float sum = ggml_vec_soft_max_f32(nc, st, s0, max);
17108
+ assert(sum > 0.0);
17109
+ sum = (1.0 - eps) / sum;
17152
17110
 
17153
- assert(sum > 0.0);
17154
- // sum = 1.0/sum;
17155
- }
17156
17111
  // avoid log(0) by rescaling from [0..1] to [eps..1]
17157
- sum = (1.0 - eps) / sum;
17158
17112
  ggml_vec_scale_f32(nc, st, sum);
17159
17113
  ggml_vec_add1_f32(nc, st, st, eps);
17160
17114
  ggml_vec_log_f32(nc, st, st);
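With ggml_vec_soft_max_f32 producing the exponentials and their sum, the loss code rescales the softmax from [0,1] to [eps,1] (scale by (1-eps)/sum, then add eps) so the following log can never see zero. A tiny numeric check of that rescaling; the eps value here is illustrative, not the one ggml uses:

// Numeric sketch: rescaling the softmax from [0,1] to [eps,1] before taking the log.
#include <math.h>
#include <stdio.h>

int main(void) {
    const float eps   = 1e-9f;                        // illustrative, not the value ggml uses
    const float s0[3] = { 2.0f, -INFINITY, 0.0f };    // one fully masked logit

    float max = -INFINITY;
    for (int i = 0; i < 3; ++i) if (s0[i] > max) max = s0[i];

    float  st[3];
    double sum = 0.0;
    for (int i = 0; i < 3; ++i) {
        st[i] = (s0[i] == -INFINITY) ? 0.0f : expf(s0[i] - max);
        sum  += st[i];
    }

    const float scale = (float) ((1.0 - eps)/sum);
    for (int i = 0; i < 3; ++i) {
        st[i] = st[i]*scale + eps;                    // now in [eps, 1]: log() stays finite
        printf("st[%d] = %g  log = %g\n", i, st[i], logf(st[i]));
    }
    return 0;
}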
@@ -17244,32 +17198,11 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32(
17244
17198
  #endif
17245
17199
 
17246
17200
  // soft_max
17247
- ggml_float sum = 0.0;
17248
- {
17249
- float max = -INFINITY;
17250
- ggml_vec_max_f32(nc, &max, s0);
17251
-
17252
- uint16_t scvt; UNUSED(scvt);
17253
- for (int i = 0; i < nc; i++) {
17254
- if (s0[i] == -INFINITY) {
17255
- ds0[i] = 0.0f;
17256
- } else {
17257
- #ifndef GGML_CROSS_ENTROPY_EXP_FP16
17258
- const float s = s0[i] - max;
17259
- const float val = expf(s);
17260
- #else
17261
- ggml_fp16_t s = GGML_FP32_TO_FP16(s0[i] - max);
17262
- memcpy(&scvt, &s, sizeof(scvt));
17263
- const float val = GGML_FP16_TO_FP32(ggml_table_exp_f16[scvt]);
17264
- #endif
17265
- sum += (ggml_float)val;
17266
- ds0[i] = val;
17267
- }
17268
- }
17269
-
17270
- assert(sum > 0.0);
17271
- sum = (1.0 - eps)/sum;
17272
- }
17201
+ float max = -INFINITY;
17202
+ ggml_vec_max_f32(nc, &max, s0);
17203
+ ggml_float sum = ggml_vec_soft_max_f32(nc, ds0, s0, max);
17204
+ assert(sum > 0.0);
17205
+ sum = (1.0 - eps) / sum;
17273
17206
 
17274
17207
  // grad(src0) = (softmax(src0) - src1) * grad(cross_entropy_loss(src0, src1)) / nr
17275
17208
  ggml_vec_scale_f32(nc, ds0, sum);
@@ -17306,7 +17239,7 @@ static void ggml_compute_forward_cross_entropy_loss_back(
17306
17239
 
17307
17240
  /////////////////////////////////
17308
17241
 
17309
- static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
17242
+ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor, struct ggml_compute_state * state) {
17310
17243
  GGML_ASSERT(params);
17311
17244
 
17312
17245
  if (tensor->op == GGML_OP_NONE || ggml_is_empty(tensor)) {
@@ -17404,7 +17337,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
17404
17337
  } break;
17405
17338
  case GGML_OP_MUL_MAT:
17406
17339
  {
17407
- ggml_compute_forward_mul_mat(params, tensor);
17340
+ ggml_compute_forward_mul_mat(params, tensor, state);
17408
17341
  } break;
17409
17342
  case GGML_OP_MUL_MAT_ID:
17410
17343
  {
@@ -17530,21 +17463,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
17530
17463
  {
17531
17464
  ggml_compute_forward_leaky_relu(params, tensor);
17532
17465
  } break;
17533
- case GGML_OP_FLASH_ATTN:
17534
- {
17535
- const int32_t t = ggml_get_op_params_i32(tensor, 0);
17536
- GGML_ASSERT(t == 0 || t == 1);
17537
- const bool masked = t != 0;
17538
- ggml_compute_forward_flash_attn(params, masked, tensor);
17539
- } break;
17540
17466
  case GGML_OP_FLASH_ATTN_EXT:
17541
17467
  {
17542
17468
  ggml_compute_forward_flash_attn_ext(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor);
17543
17469
  } break;
17544
- case GGML_OP_FLASH_FF:
17545
- {
17546
- ggml_compute_forward_flash_ff(params, tensor);
17547
- } break;
17548
17470
  case GGML_OP_FLASH_ATTN_BACK:
17549
17471
  {
17550
17472
  int32_t t = ggml_get_op_params_i32(tensor, 0);
@@ -17914,6 +17836,7 @@ static struct ggml_tensor * ggml_sub_or_set(struct ggml_context * ctx, struct gg
17914
17836
  static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor * tensor, struct ggml_hash_set zero_table) {
17915
17837
  struct ggml_tensor * src0 = tensor->src[0];
17916
17838
  struct ggml_tensor * src1 = tensor->src[1];
17839
+ struct ggml_tensor * src2 = tensor->src[2];
17917
17840
 
17918
17841
  switch (tensor->op) {
17919
17842
  case GGML_OP_DUP:
@@ -18445,6 +18368,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
18445
18368
  ggml_rope_back(ctx,
18446
18369
  tensor->grad,
18447
18370
  src1,
18371
+ src2,
18448
18372
  n_dims,
18449
18373
  mode,
18450
18374
  n_ctx,
@@ -18484,6 +18408,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
18484
18408
  ggml_rope_impl(ctx,
18485
18409
  tensor->grad,
18486
18410
  src1,
18411
+ src2,
18487
18412
  n_dims,
18488
18413
  mode,
18489
18414
  n_ctx,
@@ -18548,7 +18473,6 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
18548
18473
  {
18549
18474
  GGML_ASSERT(false); // TODO: not implemented
18550
18475
  } break;
18551
- case GGML_OP_FLASH_ATTN:
18552
18476
  case GGML_OP_FLASH_ATTN_EXT:
18553
18477
  {
18554
18478
  struct ggml_tensor * flash_grad = NULL;
@@ -18565,7 +18489,6 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
18565
18489
  masked);
18566
18490
  }
18567
18491
 
18568
- struct ggml_tensor * src2 = tensor->src[2];
18569
18492
  const int64_t elem_q = ggml_nelements(src0);
18570
18493
  const int64_t elem_k = ggml_nelements(src1);
18571
18494
  const int64_t elem_v = ggml_nelements(src2);
@@ -18603,10 +18526,6 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
18603
18526
  zero_table);
18604
18527
  }
18605
18528
  } break;
18606
- case GGML_OP_FLASH_FF:
18607
- {
18608
- GGML_ASSERT(false); // not supported
18609
- } break;
18610
18529
  case GGML_OP_FLASH_ATTN_BACK:
18611
18530
  {
18612
18531
  GGML_ASSERT(false); // not supported
@@ -19020,8 +18939,6 @@ typedef int ggml_lock_t;
19020
18939
 
19021
18940
  #define GGML_LOCK_INITIALIZER 0
19022
18941
 
19023
- typedef pthread_t ggml_thread_t;
19024
-
19025
18942
  #define ggml_thread_create pthread_create
19026
18943
  #define ggml_thread_join pthread_join
19027
18944
 
@@ -19047,8 +18964,6 @@ typedef int ggml_lock_t;
19047
18964
 
19048
18965
  #define GGML_LOCK_INITIALIZER 0
19049
18966
 
19050
- typedef pthread_t ggml_thread_t;
19051
-
19052
18967
  #define ggml_thread_create pthread_create
19053
18968
  #define ggml_thread_join pthread_join
19054
18969
 
@@ -19128,31 +19043,6 @@ static void set_numa_thread_affinity(int thread_n) { UNUSED(thread_n); }
19128
19043
  static void clear_numa_thread_affinity(void) {}
19129
19044
  #endif
19130
19045
 
19131
- struct ggml_compute_state_shared {
19132
- const struct ggml_cgraph * cgraph;
19133
- const struct ggml_cplan * cplan;
19134
-
19135
- int64_t perf_node_start_cycles;
19136
- int64_t perf_node_start_time_us;
19137
-
19138
- const int n_threads;
19139
-
19140
- // synchronization primitives
19141
- atomic_int n_active; // num active threads
19142
- atomic_int node_n; // active graph node
19143
- atomic_int node_task; // active graph node task phase
19144
-
19145
- ggml_abort_callback abort_callback; // abort ggml_graph_compute when true
19146
- void * abort_callback_data;
19147
- };
19148
-
19149
- struct ggml_compute_state {
19150
- ggml_thread_t thrd;
19151
- int ith;
19152
- struct ggml_compute_state_shared * shared;
19153
- enum ggml_status ec;
19154
- };
19155
-
19156
19046
  static void ggml_graph_compute_perf_stats_node(struct ggml_tensor * node, const struct ggml_compute_state_shared * st) {
19157
19047
  int64_t cycles_cur = ggml_perf_cycles() - st->perf_node_start_cycles;
19158
19048
  int64_t time_us_cur = ggml_perf_time_us() - st->perf_node_start_time_us;
@@ -19322,15 +19212,10 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads, int n_cur_
19322
19212
  {
19323
19213
  n_tasks = n_threads;
19324
19214
  } break;
19325
- case GGML_OP_FLASH_ATTN:
19326
19215
  case GGML_OP_FLASH_ATTN_EXT:
19327
19216
  {
19328
19217
  n_tasks = n_threads;
19329
19218
  } break;
19330
- case GGML_OP_FLASH_FF:
19331
- {
19332
- n_tasks = n_threads;
19333
- } break;
19334
19219
  case GGML_OP_FLASH_ATTN_BACK:
19335
19220
  {
19336
19221
  n_tasks = n_threads;
@@ -19425,6 +19310,10 @@ static void ggml_graph_compute_thread_sync_node(int * node_n, struct ggml_comput
19425
19310
 
19426
19311
  * node_n = atomic_load(&state->shared->node_n);
19427
19312
  if (* node_n != last_node_n) break;
19313
+ #if defined(__SSE3__)
19314
+ // Tell the processor we're spinning. It's a processor hint for spinlocks.
19315
+ _mm_pause();
19316
+ #endif
19428
19317
  }
19429
19318
  }
19430
19319
 
@@ -19439,6 +19328,10 @@ static void ggml_graph_compute_thread_sync_task(int * task_phase, struct ggml_co
19439
19328
 
19440
19329
  * task_phase = atomic_load(&state->shared->node_task);
19441
19330
  if (* task_phase != last_task_phase) break;
19331
+ #if defined(__SSE3__)
19332
+ // Tell the processor we're spinning. It's a processor hint for spinlocks.
19333
+ _mm_pause();
19334
+ #endif
19442
19335
  }
19443
19336
  }
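Both sync spin loops now issue _mm_pause() on SSE3-capable x86, hinting to the core that the thread is busy-waiting on a shared atomic. A minimal standalone spin-wait in the same shape (C11 <threads.h> is used only for the demo publisher thread; on non-x86 targets the hint would be a no-op or the platform's own yield instruction):

// Minimal spin-wait with a pause hint, mirroring the pattern added above.
#include <stdatomic.h>
#include <stdbool.h>
#include <stdio.h>
#include <threads.h>
#if defined(__SSE3__)
#include <immintrin.h>
#endif

static atomic_int node_n;                 // value published by another thread

static int spin_wait_for_next(const int last_node_n) {
    int cur;
    while (true) {
        cur = atomic_load(&node_n);
        if (cur != last_node_n) break;    // someone moved on to the next node
#if defined(__SSE3__)
        _mm_pause();                      // CPU hint: we are spinning, back off a little
#endif
    }
    return cur;
}

static int publisher(void * arg) {
    (void) arg;
    atomic_store(&node_n, 1);             // "advance to the next graph node"
    return 0;
}

int main(void) {
    atomic_store(&node_n, 0);
    thrd_t th;
    thrd_create(&th, publisher, NULL);
    printf("observed node_n = %d\n", spin_wait_for_next(0));
    thrd_join(th, NULL);
    return 0;
}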
19444
19337
 
@@ -19478,7 +19371,7 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
19478
19371
  struct ggml_tensor * node = cgraph->nodes[node_n];
19479
19372
  if (GGML_OP_HAS_FINALIZE[node->op]) {
19480
19373
  params.nth = ggml_get_n_tasks(node, n_threads, state->shared->n_threads);
19481
- ggml_compute_forward(&params, node);
19374
+ ggml_compute_forward(&params, node, state);
19482
19375
  }
19483
19376
  ggml_graph_compute_perf_stats_node(node, state->shared);
19484
19377
  }
@@ -19498,17 +19391,17 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
  /* INIT */
  if (GGML_OP_HAS_INIT[node->op]) {
  params.type = GGML_TASK_TYPE_INIT;
- ggml_compute_forward(&params, node);
+ ggml_compute_forward(&params, node, state);
  }
 
  // TODO: maybe push node_n to the atomic but if other threads see n_tasks is 1,
  // they do something more efficient than spinning (?)
  params.type = GGML_TASK_TYPE_COMPUTE;
- ggml_compute_forward(&params, node);
+ ggml_compute_forward(&params, node, state);
 
  if (GGML_OP_HAS_FINALIZE[node->op]) {
  params.type = GGML_TASK_TYPE_FINALIZE;
- ggml_compute_forward(&params, node);
+ ggml_compute_forward(&params, node, state);
  }
 
  ggml_graph_compute_perf_stats_node(node, state->shared);
@@ -19547,7 +19440,7 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
 
  if (state->ith < n_tasks) {
  if (GGML_OP_HAS_INIT[node->op]) {
- ggml_compute_forward(&params, node);
+ ggml_compute_forward(&params, node, state);
  }
  }
 
@@ -19568,7 +19461,7 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
 
  if (state->ith < n_tasks) {
  params.type = GGML_TASK_TYPE_COMPUTE;
- ggml_compute_forward(&params, node);
+ ggml_compute_forward(&params, node, state);
  }
 
  if (atomic_fetch_sub(&state->shared->n_active, 1) == 1) {
@@ -19719,39 +19612,11 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
  cur += sizeof(ggml_fp16_t)*ne00*ne01*ne02*ne03;
  cur += sizeof(ggml_fp16_t)*ne10*ne11*ne12;
  } break;
- case GGML_OP_FLASH_ATTN:
- {
- const int64_t ne11 = ggml_up(node->src[1]->ne[1], GGML_SOFT_MAX_UNROLL);
-
- if (node->src[1]->type == GGML_TYPE_F32) {
- cur = sizeof(float)*ne11*n_tasks; // TODO: this can become (n_tasks-1)
- cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2
- } else if (node->src[1]->type == GGML_TYPE_F16) {
- cur = sizeof(float)*ne11*n_tasks; // TODO: this can become (n_tasks-1)
- cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2
- } else if (node->src[1]->type == GGML_TYPE_BF16) {
- cur = sizeof(float)*ne11*n_tasks; // TODO: this can become (n_tasks-1)
- cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2
- }
- } break;
  case GGML_OP_FLASH_ATTN_EXT:
  {
  const int64_t ne00 = node->src[0]->ne[0]; // D
 
- cur = 2*sizeof(float)*ne00*n_tasks; // 2x head size
- } break;
- case GGML_OP_FLASH_FF:
- {
- if (node->src[1]->type == GGML_TYPE_F32) {
- cur = sizeof(float)*node->src[1]->ne[1]*n_tasks; // TODO: this can become (n_tasks-1)
- cur += sizeof(float)*node->src[1]->ne[1]*n_tasks; // this is overestimated by x2
- } else if (node->src[1]->type == GGML_TYPE_F16) {
- cur = sizeof(float)*node->src[1]->ne[1]*n_tasks; // TODO: this can become (n_tasks-1)
- cur += sizeof(float)*node->src[1]->ne[1]*n_tasks; // this is overestimated by x2
- } else if (node->src[1]->type == GGML_TYPE_BF16) {
- cur = sizeof(float)*node->src[1]->ne[1]*n_tasks; // TODO: this can become (n_tasks-1)
- cur += sizeof(float)*node->src[1]->ne[1]*n_tasks; // this is overestimated by x2
- }
+ cur = 3*sizeof(float)*ne00*n_tasks; // 3x head size/thread
  } break;
  case GGML_OP_FLASH_ATTN_BACK:
  {
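The FLASH_ATTN_EXT planning change above replaces the old 2x head-size scratch estimate with 3 float buffers of head size D per task. A quick stand-alone arithmetic check (my own numbers, picked only for illustration):

    #include <stdio.h>

    int main(void) {
        // Hypothetical sizes: head dimension D = 128 and 8 tasks (threads).
        const long long ne00    = 128; // D = node->src[0]->ne[0]
        const long long n_tasks = 8;

        const long long cur_old = 2 * sizeof(float) * ne00 * n_tasks; //  8192 bytes
        const long long cur_new = 3 * sizeof(float) * ne00 * n_tasks; // 12288 bytes

        printf("old: %lld bytes, new: %lld bytes\n", cur_old, cur_new);
        return 0;
    }

So the per-graph work buffer for this op grows by one head-sized float vector per task (512 bytes per task in this example).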
@@ -19819,6 +19684,7 @@ enum ggml_status ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cpl
  /*.node_task =*/ GGML_TASK_TYPE_FINALIZE,
  /*.abort_callback =*/ NULL,
  /*.abort_callback_data =*/ NULL,
+ /*.current_chunk; =*/ 0,
  };
  struct ggml_compute_state * workers = alloca(sizeof(struct ggml_compute_state)*n_threads);
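The new current_chunk field initialized here, together with ggml_compute_forward now receiving the per-thread state, suggests that work is handed out dynamically from a shared atomic counter rather than by a fixed per-thread split. A minimal sketch of that scheduling pattern, under that assumption (shared_sched and claim_next_chunk are illustrative names, not ggml API):

    #include <stdatomic.h>
    #include <stdbool.h>

    // Shared scheduling state, analogous in spirit to the fields above.
    struct shared_sched {
        atomic_int current_chunk; // next unclaimed chunk index
        int        n_chunks;      // total chunks for the current op
    };

    // Each worker repeatedly claims the next chunk; atomic_fetch_add
    // guarantees every chunk is processed by exactly one thread.
    static bool claim_next_chunk(struct shared_sched * s, int * out_chunk) {
        const int chunk = atomic_fetch_add(&s->current_chunk, 1);
        if (chunk >= s->n_chunks) {
            return false; // all work has been handed out
        }
        *out_chunk = chunk;
        return true;
    }

With this scheme, faster threads simply claim more chunks, which evens out load when chunks differ in cost.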
 
@@ -21592,11 +21458,7 @@ size_t ggml_quantize_chunk(
  case GGML_TYPE_IQ1_S: result = quantize_iq1_s (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
  case GGML_TYPE_IQ1_M: result = quantize_iq1_m (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
  case GGML_TYPE_IQ4_NL: result = quantize_iq4_nl (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
- #if QK_K == 64
- case GGML_TYPE_IQ4_XS: result = quantize_iq4_nl (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
- #else
  case GGML_TYPE_IQ4_XS: result = quantize_iq4_xs (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
- #endif
  case GGML_TYPE_F16:
  {
  size_t elemsize = sizeof(ggml_fp16_t);
@@ -22873,6 +22735,14 @@ int ggml_cpu_has_avx512_vnni(void) {
  #endif
  }
 
+ int ggml_cpu_has_avx512_bf16(void) {
+ #if defined(__AVX512BF16__)
+ return 1;
+ #else
+ return 0;
+ #endif
+ }
+
  int ggml_cpu_has_fma(void) {
  #if defined(__FMA__)
  return 1;
@@ -22889,6 +22759,16 @@ int ggml_cpu_has_neon(void) {
  #endif
  }
 
+ int ggml_cpu_has_sve(void) {
+ #if defined(__ARM_FEATURE_SVE)
+ // TODO: Currently, SVE 256 bit is only supported.
+ GGML_ASSERT(svcntb() == QK8_0);
+ return 1;
+ #else
+ return 0;
+ #endif
+ }
+
  int ggml_cpu_has_arm_fma(void) {
  #if defined(__ARM_FEATURE_FMA)
  return 1;
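Like the surrounding helpers, the two new functions report compile-time capability: they return 1 only when the binary was built with the corresponding instruction set, regardless of the CPU it runs on (and the SVE path additionally asserts a 256-bit vector length). A small usage sketch, assuming the matching declarations were also added to ggml.h in this release (report_simd_features is an invented name):

    #include <stdio.h>
    #include "ggml.h"

    // Print the newly reported SIMD capabilities of this build.
    // These reflect compiler flags (-march=...), not runtime detection,
    // so a generic build reports 0 even on capable hardware.
    static void report_simd_features(void) {
        printf("AVX512-BF16: %d\n", ggml_cpu_has_avx512_bf16());
        printf("SVE:         %d\n", ggml_cpu_has_sve());
    }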