@fugood/llama.node 0.2.0 → 0.2.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (94)
  1. package/CMakeLists.txt +9 -0
  2. package/README.md +1 -1
  3. package/bin/darwin/arm64/default.metallib +0 -0
  4. package/bin/darwin/arm64/llama-node.node +0 -0
  5. package/bin/darwin/x64/default.metallib +0 -0
  6. package/bin/darwin/x64/llama-node.node +0 -0
  7. package/bin/linux/arm64/llama-node.node +0 -0
  8. package/bin/linux/x64/llama-node.node +0 -0
  9. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  10. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  11. package/bin/win32/arm64/llama-node.node +0 -0
  12. package/bin/win32/arm64/node.lib +0 -0
  13. package/bin/win32/x64/llama-node.node +0 -0
  14. package/bin/win32/x64/node.lib +0 -0
  15. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  16. package/bin/win32-vulkan/arm64/node.lib +0 -0
  17. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  18. package/bin/win32-vulkan/x64/node.lib +0 -0
  19. package/lib/binding.ts +1 -1
  20. package/package.json +2 -1
  21. package/patches/llama.patch +22 -0
  22. package/src/LlamaContext.cpp +2 -2
  23. package/src/TokenizeWorker.cpp +1 -1
  24. package/src/llama.cpp/CMakeLists.txt +82 -54
  25. package/src/llama.cpp/cmake/arm64-windows-llvm.cmake +16 -0
  26. package/src/llama.cpp/cmake/arm64-windows-msvc.cmake +6 -0
  27. package/src/llama.cpp/common/common.cpp +748 -754
  28. package/src/llama.cpp/common/common.h +49 -41
  29. package/src/llama.cpp/common/grammar-parser.cpp +10 -1
  30. package/src/llama.cpp/common/json-schema-to-grammar.cpp +6 -6
  31. package/src/llama.cpp/common/log.h +5 -5
  32. package/src/llama.cpp/common/sampling.cpp +92 -10
  33. package/src/llama.cpp/common/sampling.h +6 -1
  34. package/src/llama.cpp/common/train.cpp +2 -2
  35. package/src/llama.cpp/examples/CMakeLists.txt +3 -0
  36. package/src/llama.cpp/examples/batched/batched.cpp +1 -1
  37. package/src/llama.cpp/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +1 -1
  38. package/src/llama.cpp/examples/embedding/embedding.cpp +13 -4
  39. package/src/llama.cpp/examples/eval-callback/eval-callback.cpp +2 -2
  40. package/src/llama.cpp/examples/finetune/finetune.cpp +4 -3
  41. package/src/llama.cpp/examples/imatrix/imatrix.cpp +2 -2
  42. package/src/llama.cpp/examples/infill/infill.cpp +8 -8
  43. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +57 -8
  44. package/src/llama.cpp/examples/llama.android/llama/CMakeLists.txt +55 -0
  45. package/src/llama.cpp/examples/llama.android/{app → llama}/src/main/cpp/CMakeLists.txt +7 -8
  46. package/src/llama.cpp/examples/llama.android/{app → llama}/src/main/cpp/llama-android.cpp +14 -14
  47. package/src/llama.cpp/examples/llava/clip.h +1 -1
  48. package/src/llama.cpp/examples/llava/llava-cli.cpp +27 -7
  49. package/src/llama.cpp/examples/llava/llava.cpp +0 -15
  50. package/src/llama.cpp/examples/lookahead/lookahead.cpp +1 -1
  51. package/src/llama.cpp/examples/lookup/lookup.cpp +1 -1
  52. package/src/llama.cpp/examples/main/main.cpp +29 -17
  53. package/src/llama.cpp/examples/parallel/parallel.cpp +1 -1
  54. package/src/llama.cpp/examples/perplexity/perplexity.cpp +9 -9
  55. package/src/llama.cpp/examples/quantize/quantize.cpp +2 -2
  56. package/src/llama.cpp/examples/retrieval/retrieval.cpp +2 -2
  57. package/src/llama.cpp/examples/rpc/CMakeLists.txt +2 -0
  58. package/src/llama.cpp/examples/rpc/rpc-server.cpp +134 -0
  59. package/src/llama.cpp/examples/server/server.cpp +33 -25
  60. package/src/llama.cpp/examples/server/utils.hpp +1 -1
  61. package/src/llama.cpp/examples/tokenize/tokenize.cpp +359 -9
  62. package/src/llama.cpp/examples/train-text-from-scratch/train-text-from-scratch.cpp +4 -3
  63. package/src/llama.cpp/ggml-backend.c +2 -3
  64. package/src/llama.cpp/ggml-common.h +0 -54
  65. package/src/llama.cpp/ggml-cuda.h +1 -0
  66. package/src/llama.cpp/ggml-impl.h +51 -0
  67. package/src/llama.cpp/ggml-kompute.cpp +13 -3
  68. package/src/llama.cpp/ggml-opencl.cpp +4 -1
  69. package/src/llama.cpp/ggml-quants.c +3715 -2050
  70. package/src/llama.cpp/ggml-rpc.cpp +1155 -0
  71. package/src/llama.cpp/ggml-rpc.h +24 -0
  72. package/src/llama.cpp/ggml-sycl.cpp +119 -673
  73. package/src/llama.cpp/ggml-vulkan-shaders.hpp +9351 -5627
  74. package/src/llama.cpp/ggml-vulkan.cpp +203 -224
  75. package/src/llama.cpp/ggml.c +1208 -1483
  76. package/src/llama.cpp/ggml.h +71 -46
  77. package/src/llama.cpp/llama.cpp +1374 -938
  78. package/src/llama.cpp/llama.h +22 -6
  79. package/src/llama.cpp/requirements.txt +0 -2
  80. package/src/llama.cpp/tests/CMakeLists.txt +1 -1
  81. package/src/llama.cpp/tests/test-backend-ops.cpp +120 -57
  82. package/src/llama.cpp/tests/test-chat-template.cpp +16 -4
  83. package/src/llama.cpp/tests/test-grad0.cpp +43 -83
  84. package/src/llama.cpp/tests/test-grammar-integration.cpp +46 -0
  85. package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +27 -3
  86. package/src/llama.cpp/unicode-data.cpp +6969 -2169
  87. package/src/llama.cpp/unicode-data.h +15 -12
  88. package/src/llama.cpp/unicode.cpp +89 -111
  89. package/src/llama.cpp/unicode.h +44 -12
  90. package/src/llama.cpp/build.zig +0 -172
  91. package/src/llama.cpp/ggml-mpi.c +0 -216
  92. package/src/llama.cpp/ggml-mpi.h +0 -39
  93. package/src/llama.cpp/requirements/requirements-convert-lora-to-ggml.txt +0 -2
  94. package/src/llama.cpp/requirements/requirements-convert-persimmon-to-gguf.txt +0 -2
@@ -4,7 +4,6 @@
  #include "ggml-impl.h"
  #include "ggml-quants.h"
  #include "ggml.h"
- #include "sgemm.h"
 
  #if defined(_MSC_VER) || defined(__MINGW32__)
  #include <malloc.h> // using malloc.h with MSC/MINGW
@@ -37,6 +36,10 @@
  #undef GGML_USE_LLAMAFILE
  #endif
 
+ #ifdef GGML_USE_LLAMAFILE
+ #include "sgemm.h"
+ #endif
+
  #if defined(_MSC_VER)
  // disable "possible loss of data" to avoid hundreds of casts
  // we should just be careful :)
@@ -109,6 +112,8 @@ typedef void * thread_ret_t;
 
  #endif
 
+ typedef pthread_t ggml_thread_t;
+
  #ifdef GGML_USE_CPU_HBM
  #include <hbwmalloc.h>
  #endif
@@ -160,9 +165,6 @@ void ggml_print_backtrace(void) {
  #define GGML_DEBUG 0
  #define GGML_GELU_FP16
  #define GGML_GELU_QUICK_FP16
- #define GGML_SILU_FP16
- // #define GGML_CROSS_ENTROPY_EXP_FP16
- // #define GGML_FLASH_ATTN_EXP_FP16
 
  #define GGML_SOFT_MAX_UNROLL 4
  #define GGML_VEC_DOT_UNROLL 2
@@ -313,12 +315,6 @@ static ggml_fp16_t ggml_table_gelu_f16[1 << 16];
  // precomputed quick gelu table for f16 (128 KB)
  static ggml_fp16_t ggml_table_gelu_quick_f16[1 << 16];
 
- // precomputed silu table for f16 (128 KB)
- static ggml_fp16_t ggml_table_silu_f16[1 << 16];
-
- // precomputed exp table for f16 (128 KB)
- static ggml_fp16_t ggml_table_exp_f16[1 << 16];
-
  // precomputed f32 table for f16 (256 KB) (ggml-impl.h)
  float ggml_table_f32_f16[1 << 16];
 
@@ -410,10 +406,10 @@ void ggml_fp32_to_bf16_row(const float * x, ggml_bf16_t * y, int64_t n) {
  int i = 0;
  #if defined(__AVX512BF16__)
  for (; i + 32 <= n; i += 32) {
- _mm512_storeu_ps(
- (__m512 *)(y + i),
- (__m512)_mm512_cvtne2ps_pbh(_mm512_loadu_ps(x + i + 16),
- _mm512_loadu_ps(x + i)));
+ _mm512_storeu_si512(
+ (__m512i *)(y + i),
+ m512i(_mm512_cvtne2ps_pbh(_mm512_loadu_ps(x + i + 16),
+ _mm512_loadu_ps(x + i))));
  }
  #endif
  for (; i < n; i++) {
@@ -875,22 +871,14 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
  },
  [GGML_TYPE_IQ4_XS] = {
  .type_name = "iq4_xs",
- #if QK_K == 64
- .blck_size = QK4_NL,
- #else
  .blck_size = QK_K,
- #endif
  .type_size = sizeof(block_iq4_xs),
  .is_quantized = true,
  .to_float = (ggml_to_float_t) dequantize_row_iq4_xs,
  .from_float = quantize_row_iq4_xs,
  .from_float_reference = (ggml_from_float_t)quantize_row_iq4_xs_reference,
  .vec_dot = ggml_vec_dot_iq4_xs_q8_K,
- #if QK_K == 64
- .vec_dot_type = GGML_TYPE_Q8_0,
- #else
  .vec_dot_type = GGML_TYPE_Q8_K,
- #endif
  .nrows = 1,
  },
  [GGML_TYPE_Q8_K] = {
@@ -1303,6 +1291,8 @@ static inline void __avx_f32cx8_store(ggml_fp16_t *x, __m256 y) {
  #define GGML_F16_VEC_ZERO GGML_F32x4_ZERO
  #define GGML_F16_VEC_SET1 GGML_F32x4_SET1
  #define GGML_F16_VEC_FMA GGML_F32x4_FMA
+ #define GGML_F16_VEC_ADD GGML_F32x4_ADD
+ #define GGML_F16_VEC_MUL GGML_F32x4_MUL
  #define GGML_F16_VEC_REDUCE GGML_F32x4_REDUCE
  // Use vec_xl, not vec_ld, in case the load address is not aligned.
  #define GGML_F16_VEC_LOAD(p, i) (i & 0x1) ? \
@@ -1525,6 +1515,195 @@ static inline void __sse_f16x4_store(ggml_fp16_t *x, __m128 y) {
1525
1515
  #define GGML_F16_VEC_MUL GGML_F32Cx4_MUL
1526
1516
  #define GGML_F16_VEC_REDUCE GGML_F32Cx4_REDUCE
1527
1517
 
1518
+ #elif defined(__loongarch_asx)
1519
+
1520
+ #define GGML_SIMD
1521
+
1522
+ // F32 LASX
1523
+ #define GGML_F32_STEP 32
1524
+ #define GGML_F32_EPR 8
1525
+
1526
+ #define GGML_F32x8 __m256
1527
+ #define GGML_F32x8_ZERO (__m256)__lasx_xvldi(0)
1528
+ #define GGML_F32x8_SET1(x) (__m256)__lasx_xvreplfr2vr_s((x))
1529
+ #define GGML_F32x8_LOAD(x) (__m256)__lasx_xvld((x), 0)
1530
+ #define GGML_F32x8_STORE(x,y) __lasx_xvst((y), (x), 0)
1531
+ #define GGML_F32x8_FMA(a, b, c) __lasx_xvfmadd_s(b, c, a)
1532
+ #define GGML_F32x8_ADD __lasx_xvfadd_s
1533
+ #define GGML_F32x8_MUL __lasx_xvfmul_s
1534
+ #define GGML_F32x8_REDUCE(res, x) \
1535
+ do { \
1536
+ int offset = GGML_F32_ARR >> 1; \
1537
+ for (int i = 0; i < offset; ++i) { \
1538
+ x[i] = __lasx_xvfadd_s(x[i], x[offset+i]); \
1539
+ } \
1540
+ offset >>= 1; \
1541
+ for (int i = 0; i < offset; ++i) { \
1542
+ x[i] = __lasx_xvfadd_s(x[i], x[offset+i]); \
1543
+ } \
1544
+ offset >>= 1; \
1545
+ for (int i = 0; i < offset; ++i) { \
1546
+ x[i] = __lasx_xvfadd_s(x[i], x[offset+i]); \
1547
+ } \
1548
+ float *tmp_p = (float *)&x[0]; \
1549
+ res = tmp_p[0] + tmp_p[1] + tmp_p[2] + tmp_p[3] + tmp_p[4] + tmp_p[5] + tmp_p[6] + tmp_p[7]; \
1550
+ } while (0)
1551
+ // TODO: is this optimal ?
1552
+
1553
+ #define GGML_F32_VEC GGML_F32x8
1554
+ #define GGML_F32_VEC_ZERO GGML_F32x8_ZERO
1555
+ #define GGML_F32_VEC_SET1 GGML_F32x8_SET1
1556
+ #define GGML_F32_VEC_LOAD GGML_F32x8_LOAD
1557
+ #define GGML_F32_VEC_STORE GGML_F32x8_STORE
1558
+ #define GGML_F32_VEC_FMA GGML_F32x8_FMA
1559
+ #define GGML_F32_VEC_ADD GGML_F32x8_ADD
1560
+ #define GGML_F32_VEC_MUL GGML_F32x8_MUL
1561
+ #define GGML_F32_VEC_REDUCE GGML_F32x8_REDUCE
1562
+
1563
+ // F16 LASX
1564
+
1565
+ #define GGML_F16_STEP 32
1566
+ #define GGML_F16_EPR 8
1567
+
1568
+ // F16 arithmetic is not supported by AVX, so we use F32 instead
1569
+
1570
+ #define GGML_F32Cx8 __m256
1571
+ #define GGML_F32Cx8_ZERO (__m256)__lasx_xvldi(0)
1572
+ #define GGML_F32Cx8_SET1(x) (__m256)__lasx_xvreplgr2vr_w((x))
1573
+
1574
+ static inline __m256 __lasx_f32cx8_load(ggml_fp16_t *x) {
1575
+ float tmp[8];
1576
+
1577
+ for (int i = 0; i < 8; i++) {
1578
+ tmp[i] = GGML_FP16_TO_FP32(x[i]);
1579
+ }
1580
+
1581
+ return (__m256)__lasx_xvld(tmp, 0);
1582
+ }
1583
+ static inline void __lasx_f32cx8_store(ggml_fp16_t *x, __m256 y) {
1584
+ float arr[8];
1585
+
1586
+ __lasx_xvst(y, arr, 0);
1587
+
1588
+ for (int i = 0; i < 8; i++)
1589
+ x[i] = GGML_FP32_TO_FP16(arr[i]);
1590
+ }
1591
+ #define GGML_F32Cx8_LOAD(x) __lasx_f32cx8_load(x)
1592
+ #define GGML_F32Cx8_STORE(x, y) __lasx_f32cx8_store(x, y)
1593
+
1594
+ #define GGML_F32Cx8_FMA GGML_F32x8_FMA
1595
+ #define GGML_F32Cx8_ADD __lasx_xvfadd_s
1596
+ #define GGML_F32Cx8_MUL __lasx_xvfmul_s
1597
+ #define GGML_F32Cx8_REDUCE GGML_F32x8_REDUCE
1598
+
1599
+ #define GGML_F16_VEC GGML_F32Cx8
1600
+ #define GGML_F16_VEC_ZERO GGML_F32Cx8_ZERO
1601
+ #define GGML_F16_VEC_SET1 GGML_F32Cx8_SET1
1602
+ #define GGML_F16_VEC_LOAD(p, i) GGML_F32Cx8_LOAD(p)
1603
+ #define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx8_STORE(p, r[i])
1604
+ #define GGML_F16_VEC_FMA GGML_F32Cx8_FMA
1605
+ #define GGML_F16_VEC_ADD GGML_F32Cx8_ADD
1606
+ #define GGML_F16_VEC_MUL GGML_F32Cx8_MUL
1607
+ #define GGML_F16_VEC_REDUCE GGML_F32Cx8_REDUCE
1608
+
1609
+ #elif defined(__loongarch_sx)
1610
+
1611
+ #define GGML_SIMD
1612
+
1613
+ // F32 LSX
1614
+
1615
+ #define GGML_F32_STEP 32
1616
+ #define GGML_F32_EPR 4
1617
+
1618
+ #define GGML_F32x4 __m128
1619
+ #define GGML_F32x4_ZERO __lsx_vldi(0)
1620
+ #define GGML_F32x4_SET1(x) __lsx_vinsgr2vr_w(__lsx_vldi(0),(x), 0)
1621
+ #define GGML_F32x4_LOAD(x) __lsx_vld((x), 0)
1622
+ #define GGML_F32x4_STORE((x),(y)) __lsx_vst((y), (x), 0)
1623
+ #define GGML_F32x4_FMA(a, b, c) __lsx_vfmadd_s(b, c, a)
1624
+ #define GGML_F32x4_ADD __lsx_vfadd_s
1625
+ #define GGML_F32x4_MUL __lsx_vfmul_s
1626
+ #define GGML_F32x4_REDUCE(res, x) \
1627
+ { \
1628
+ int offset = GGML_F32_ARR >> 1; \
1629
+ for (int i = 0; i < offset; ++i) { \
1630
+ x[i] = __lsx_vfadd_s(x[i], x[offset+i]); \
1631
+ } \
1632
+ offset >>= 1; \
1633
+ for (int i = 0; i < offset; ++i) { \
1634
+ x[i] = __lsx_vfadd_s(x[i], x[offset+i]); \
1635
+ } \
1636
+ offset >>= 1; \
1637
+ for (int i = 0; i < offset; ++i) { \
1638
+ x[i] = __lsx_vfadd_s(x[i], x[offset+i]); \
1639
+ } \
1640
+ __m128i tmp = __lsx_vsrli_d((__m128i)x[0], 32); \
1641
+ tmp = (__m128i)__lsx_vfadd_s((__m128)tmp, x[0]); \
1642
+ tmp = __lsx_vpickev_w(__lsx_vldi(0), tmp); \
1643
+ const __m128 t0 = __lsx_vshuf4i_w(tmp, 0x88); \
1644
+ tmp = __lsx_vsrli_d((__m128i)t0, 32); \
1645
+ tmp = (__m128i)__lsx_vfadd_s((__m128)tmp, t0); \
1646
+ tmp = __lsx_vpickev_w(__lsx_vldi(0), tmp); \
1647
+ res = (ggml_float) __lsx_vpickve2gr_w(__lsx_vshuf4i_w(tmp, 0x88), 0); \
1648
+ }
1649
+
1650
+ #define GGML_F32_VEC GGML_F32x4
1651
+ #define GGML_F32_VEC_ZERO GGML_F32x4_ZERO
1652
+ #define GGML_F32_VEC_SET1 GGML_F32x4_SET1
1653
+ #define GGML_F32_VEC_LOAD GGML_F32x4_LOAD
1654
+ #define GGML_F32_VEC_STORE GGML_F32x4_STORE
1655
+ #define GGML_F32_VEC_FMA GGML_F32x4_FMA
1656
+ #define GGML_F32_VEC_ADD GGML_F32x4_ADD
1657
+ #define GGML_F32_VEC_MUL GGML_F32x4_MUL
1658
+ #define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE
1659
+
1660
+ // F16 LSX
1661
+
1662
+ #define GGML_F16_STEP 32
1663
+ #define GGML_F16_EPR 4
1664
+
1665
+ static inline __m128 __lsx_f16x4_load(ggml_fp16_t *x) {
1666
+ float tmp[4];
1667
+
1668
+ tmp[0] = GGML_FP16_TO_FP32(x[0]);
1669
+ tmp[1] = GGML_FP16_TO_FP32(x[1]);
1670
+ tmp[2] = GGML_FP16_TO_FP32(x[2]);
1671
+ tmp[3] = GGML_FP16_TO_FP32(x[3]);
1672
+
1673
+ return __lsx_vld(tmp, 0);
1674
+ }
1675
+
1676
+ static inline void __lsx_f16x4_store(ggml_fp16_t *x, __m128 y) {
1677
+ float arr[4];
1678
+
1679
+ __lsx_vst(y, arr, 0);
1680
+
1681
+ x[0] = GGML_FP32_TO_FP16(arr[0]);
1682
+ x[1] = GGML_FP32_TO_FP16(arr[1]);
1683
+ x[2] = GGML_FP32_TO_FP16(arr[2]);
1684
+ x[3] = GGML_FP32_TO_FP16(arr[3]);
1685
+ }
1686
+
1687
+ #define GGML_F32Cx4 __m128
1688
+ #define GGML_F32Cx4_ZERO __lsx_vldi(0)
1689
+ #define GGML_F32Cx4_SET1(x) __lsx_vinsgr2vr_w(__lsx_vldi(0),(x), 0)
1690
+ #define GGML_F32Cx4_LOAD(x) __lsx_f16x4_load(x)
1691
+ #define GGML_F32Cx4_STORE(x, y) __lsx_f16x4_store(x, y)
1692
+ #define GGML_F32Cx4_FMA GGML_F32x4_FMA
1693
+ #define GGML_F32Cx4_ADD __lsx_vfadd_s
1694
+ #define GGML_F32Cx4_MUL __lsx_vfmul_s
1695
+ #define GGML_F32Cx4_REDUCE GGML_F32x4_REDUCE
1696
+
1697
+ #define GGML_F16_VEC GGML_F32Cx4
1698
+ #define GGML_F16_VEC_ZERO GGML_F32Cx4_ZERO
1699
+ #define GGML_F16_VEC_SET1 GGML_F32Cx4_SET1
1700
+ #define GGML_F16_VEC_LOAD(p, i) GGML_F32Cx4_LOAD(p)
1701
+ #define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx4_STORE(p, r[i])
1702
+ #define GGML_F16_VEC_FMA GGML_F32Cx4_FMA
1703
+ #define GGML_F16_VEC_ADD GGML_F32Cx4_ADD
1704
+ #define GGML_F16_VEC_MUL GGML_F32Cx4_MUL
1705
+ #define GGML_F16_VEC_REDUCE GGML_F32Cx4_REDUCE
1706
+
1528
1707
  #endif
1529
1708
 
1530
1709
  // GGML_F32_ARR / GGML_F16_ARR
@@ -1534,6 +1713,59 @@ static inline void __sse_f16x4_store(ggml_fp16_t *x, __m128 y) {
  #define GGML_F16_ARR (GGML_F16_STEP/GGML_F16_EPR)
  #endif
 
+ //
+ // ggml context
+ //
+
+ struct ggml_context {
+ size_t mem_size;
+ void* mem_buffer;
+ bool mem_buffer_owned;
+ bool no_alloc;
+ bool no_alloc_save; // this is used to save the no_alloc state when using scratch buffers
+
+ int n_objects;
+
+ struct ggml_object* objects_begin;
+ struct ggml_object* objects_end;
+
+ struct ggml_scratch scratch;
+ struct ggml_scratch scratch_save;
+ };
+
+ struct ggml_context_container {
+ bool used;
+
+ struct ggml_context context;
+ };
+
+ struct ggml_compute_state_shared {
+ const struct ggml_cgraph* cgraph;
+ const struct ggml_cplan* cplan;
+
+ int64_t perf_node_start_cycles;
+ int64_t perf_node_start_time_us;
+
+ const int n_threads;
+
+ // synchronization primitives
+ atomic_int n_active; // num active threads
+ atomic_int node_n; // active graph node
+ atomic_int node_task; // active graph node task phase
+
+ ggml_abort_callback abort_callback; // abort ggml_graph_compute when true
+ void* abort_callback_data;
+
+ atomic_int current_chunk; // currently processing chunk during Mat_Mul, shared between all the threads.
+ };
+
+ struct ggml_compute_state {
+ ggml_thread_t thrd;
+ int ith;
+ struct ggml_compute_state_shared* shared;
+ enum ggml_status ec;
+ };
+
  //
  // fundamental operations
  //
@@ -1615,10 +1847,10 @@ static void ggml_vec_dot_bf16(int n, float * restrict s, size_t bs, ggml_bf16_t
  __m512 c1 = _mm512_setzero_ps();
  __m512 c2 = _mm512_setzero_ps();
  for (; i + 64 <= n; i += 64) {
- c1 = _mm512_dpbf16_ps(c1, (__m512bh)_mm512_loadu_ps((const float *)(x + i)),
- (__m512bh)_mm512_loadu_ps((const float *)(y + i)));
- c2 = _mm512_dpbf16_ps(c2, (__m512bh)_mm512_loadu_ps((const float *)(x + i + 32)),
- (__m512bh)_mm512_loadu_ps((const float *)(y + i + 32)));
+ c1 = _mm512_dpbf16_ps(c1, m512bh(_mm512_loadu_si512((x + i))),
+ m512bh(_mm512_loadu_si512((y + i))));
+ c2 = _mm512_dpbf16_ps(c2, m512bh(_mm512_loadu_si512((x + i + 32))),
+ m512bh(_mm512_loadu_si512((y + i + 32))));
  }
  sumf += (ggml_float)_mm512_reduce_add_ps(c1);
  sumf += (ggml_float)_mm512_reduce_add_ps(c2);
@@ -1949,6 +2181,7 @@ inline static void ggml_vec_tanh_f32 (const int n, float * y, const float * x) {
  inline static void ggml_vec_elu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : expf(x[i])-1; }
  inline static void ggml_vec_relu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : 0.f; }
  inline static void ggml_vec_leaky_relu_f32 (const int n, float * y, const float * x, const float ns) { for (int i = 0; i < n; ++i) y[i] = ((x[i] > 0.f) ? x[i] : 0.f) + ns * ((x[i] < 0.0f) ? x[i] : 0.f); }
+ inline static void ggml_vec_sigmoid_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = 1.f / (1.f + expf(-x[i])); }
  // TODO: optimize performance
  inline static void ggml_vec_hardswish_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i] * fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f)); }
  inline static void ggml_vec_hardsigmoid_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f)); }
@@ -2024,52 +2257,291 @@ inline static float ggml_silu_f32(float x) {
2024
2257
  return x/(1.0f + expf(-x));
2025
2258
  }
2026
2259
 
2027
- //inline static void ggml_vec_silu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
2028
- // const uint16_t * i16 = (const uint16_t *) x;
2029
- // for (int i = 0; i < n; ++i) {
2030
- // y[i] = ggml_table_silu_f16[i16[i]];
2031
- // }
2032
- //}
2260
+ #if defined(__ARM_NEON) && defined(__aarch64__)
2261
+
2262
+ // adapted from arm limited optimized routine
2263
+ // the maximum error is 1.45358 plus 0.5 ulps
2264
+ // numbers above 88.38 will flush to infinity
2265
+ // numbers beneath -103.97 will flush to zero
2266
+ inline static float32x4_t ggml_v_expf(float32x4_t x) {
2267
+ const float32x4_t r = vdupq_n_f32(0x1.8p23f);
2268
+ const float32x4_t z = vfmaq_f32(r, x, vdupq_n_f32(0x1.715476p+0f));
2269
+ const float32x4_t n = vsubq_f32(z, r);
2270
+ const float32x4_t b = vfmsq_f32(vfmsq_f32(x, n, vdupq_n_f32(0x1.62e4p-1f)), n,
2271
+ vdupq_n_f32(0x1.7f7d1cp-20f));
2272
+ const uint32x4_t e = vshlq_n_u32(vreinterpretq_u32_f32(z), 23);
2273
+ const float32x4_t k = vreinterpretq_f32_u32(vaddq_u32(e, vreinterpretq_u32_f32(vdupq_n_f32(1))));
2274
+ const uint32x4_t c = vcagtq_f32(n, vdupq_n_f32(126));
2275
+ const float32x4_t u = vmulq_f32(b, b);
2276
+ const float32x4_t j = vfmaq_f32(
2277
+ vmulq_f32(vdupq_n_f32(0x1.ffffecp-1f), b),
2278
+ vfmaq_f32(vfmaq_f32(vdupq_n_f32(0x1.fffdb6p-2f), vdupq_n_f32(0x1.555e66p-3f), b),
2279
+ vfmaq_f32(vdupq_n_f32(0x1.573e2ep-5f), vdupq_n_f32(0x1.0e4020p-7f), b), u), u);
2280
+ if (!vpaddd_u64(vreinterpretq_u64_u32(c)))
2281
+ return vfmaq_f32(k, j, k);
2282
+ const uint32x4_t d = vandq_u32(vclezq_f32(n), vdupq_n_u32(0x82000000));
2283
+ const float32x4_t s1 = vreinterpretq_f32_u32(vaddq_u32(d, vdupq_n_u32(0x7f000000)));
2284
+ const float32x4_t s2 = vreinterpretq_f32_u32(vsubq_u32(e, d));
2285
+ return vbslq_f32(vcagtq_f32(n, vdupq_n_f32(192)), vmulq_f32(s1, s1),
2286
+ vbslq_f32(c, vmulq_f32(vfmaq_f32(s2, s2, j), s1), vfmaq_f32(k, k, j)));
2287
+ }
2288
+
2289
+ // computes silu x/(1+exp(-x)) in single precision vector
2290
+ inline static float32x4_t ggml_v_silu(float32x4_t x) {
2291
+ const float32x4_t one = vdupq_n_f32(1.0f);
2292
+ const float32x4_t zero = vdupq_n_f32(0.0f);
2293
+ const float32x4_t neg_x = vsubq_f32(zero, x);
2294
+ const float32x4_t exp_neg_x = ggml_v_expf(neg_x);
2295
+ const float32x4_t one_plus_exp_neg_x = vaddq_f32(one, exp_neg_x);
2296
+ return vdivq_f32(x, one_plus_exp_neg_x);
2297
+ }
2298
+
2299
+ #elif defined(__AVX512F__) && defined(__AVX512DQ__)
2300
+
2301
+ // adapted from arm limited optimized routine
2302
+ // the maximum error is 1.45358 plus 0.5 ulps
2303
+ // numbers above 88.38 will flush to infinity
2304
+ // numbers beneath -103.97 will flush to zero
2305
+ inline static __m512 ggml_v_expf(__m512 x) {
2306
+ const __m512 r = _mm512_set1_ps(0x1.8p23f);
2307
+ const __m512 z = _mm512_fmadd_ps(x, _mm512_set1_ps(0x1.715476p+0f), r);
2308
+ const __m512 n = _mm512_sub_ps(z, r);
2309
+ const __m512 b = _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.7f7d1cp-20f),
2310
+ _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.62e4p-1f), x));
2311
+ const __m512i e = _mm512_slli_epi32(_mm512_castps_si512(z), 23);
2312
+ const __m512 k = _mm512_castsi512_ps(_mm512_add_epi32(e, _mm512_castps_si512(_mm512_set1_ps(1))));
2313
+ const __mmask16 c = _mm512_cmp_ps_mask(_mm512_abs_ps(n), _mm512_set1_ps(126), _CMP_GT_OQ);
2314
+ const __m512 u = _mm512_mul_ps(b, b);
2315
+ const __m512 j = _mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_set1_ps(0x1.0e4020p-7f), b,
2316
+ _mm512_set1_ps(0x1.573e2ep-5f)), u,
2317
+ _mm512_fmadd_ps(_mm512_set1_ps(0x1.555e66p-3f), b,
2318
+ _mm512_set1_ps(0x1.fffdb6p-2f))),
2319
+ u, _mm512_mul_ps(_mm512_set1_ps(0x1.ffffecp-1f), b));
2320
+ if (_mm512_kortestz(c, c))
2321
+ return _mm512_fmadd_ps(j, k, k);
2322
+ const __m512i g = _mm512_and_si512(
2323
+ _mm512_movm_epi32(_mm512_cmp_ps_mask(n, _mm512_setzero_ps(), _CMP_LE_OQ)),
2324
+ _mm512_set1_epi32(0x82000000u));
2325
+ const __m512 s1 =
2326
+ _mm512_castsi512_ps(_mm512_add_epi32(g, _mm512_set1_epi32(0x7f000000u)));
2327
+ const __m512 s2 = _mm512_castsi512_ps(_mm512_sub_epi32(e, g));
2328
+ const __mmask16 d =
2329
+ _mm512_cmp_ps_mask(_mm512_abs_ps(n), _mm512_set1_ps(192), _CMP_GT_OQ);
2330
+ return _mm512_mask_blend_ps(
2331
+ d, _mm512_mask_blend_ps(
2332
+ c, _mm512_fmadd_ps(k, j, k),
2333
+ _mm512_mul_ps(_mm512_fmadd_ps(s2, j, s2), s1)),
2334
+ _mm512_mul_ps(s1, s1));
2335
+ }
2336
+
2337
+ // computes silu x/(1+exp(-x)) in single precision vector
2338
+ inline static __m512 ggml_v_silu(__m512 x) {
2339
+ const __m512 one = _mm512_set1_ps(1);
2340
+ const __m512 zero = _mm512_setzero_ps();
2341
+ const __m512 neg_x = _mm512_sub_ps(zero, x);
2342
+ const __m512 exp_neg_x = ggml_v_expf(neg_x);
2343
+ const __m512 one_plus_exp_neg_x = _mm512_add_ps(one, exp_neg_x);
2344
+ return _mm512_div_ps(x, one_plus_exp_neg_x);
2345
+ }
2346
+
2347
+ #elif defined(__AVX2__) && defined(__FMA__)
2348
+
2349
+ // adapted from arm limited optimized routine
2350
+ // the maximum error is 1.45358 plus 0.5 ulps
2351
+ // numbers above 88.38 will flush to infinity
2352
+ // numbers beneath -103.97 will flush to zero
2353
+ inline static __m256 ggml_v_expf(__m256 x) {
2354
+ const __m256 r = _mm256_set1_ps(0x1.8p23f);
2355
+ const __m256 z = _mm256_fmadd_ps(x, _mm256_set1_ps(0x1.715476p+0f), r);
2356
+ const __m256 n = _mm256_sub_ps(z, r);
2357
+ const __m256 b = _mm256_fnmadd_ps(n, _mm256_set1_ps(0x1.7f7d1cp-20f),
2358
+ _mm256_fnmadd_ps(n, _mm256_set1_ps(0x1.62e4p-1f), x));
2359
+ const __m256i e = _mm256_slli_epi32(_mm256_castps_si256(z), 23);
2360
+ const __m256 k = _mm256_castsi256_ps(
2361
+ _mm256_add_epi32(e, _mm256_castps_si256(_mm256_set1_ps(1))));
2362
+ const __m256i c = _mm256_castps_si256(
2363
+ _mm256_cmp_ps(_mm256_andnot_ps(_mm256_set1_ps(-0.f), n),
2364
+ _mm256_set1_ps(126), _CMP_GT_OQ));
2365
+ const __m256 u = _mm256_mul_ps(b, b);
2366
+ const __m256 j = _mm256_fmadd_ps(_mm256_fmadd_ps(_mm256_fmadd_ps(_mm256_set1_ps(0x1.0e4020p-7f), b,
2367
+ _mm256_set1_ps(0x1.573e2ep-5f)), u,
2368
+ _mm256_fmadd_ps(_mm256_set1_ps(0x1.555e66p-3f), b,
2369
+ _mm256_set1_ps(0x1.fffdb6p-2f))),
2370
+ u, _mm256_mul_ps(_mm256_set1_ps(0x1.ffffecp-1f), b));
2371
+ if (!_mm256_movemask_ps(_mm256_castsi256_ps(c)))
2372
+ return _mm256_fmadd_ps(j, k, k);
2373
+ const __m256i g = _mm256_and_si256(
2374
+ _mm256_castps_si256(_mm256_cmp_ps(n, _mm256_setzero_ps(), _CMP_LE_OQ)),
2375
+ _mm256_set1_epi32(0x82000000u));
2376
+ const __m256 s1 =
2377
+ _mm256_castsi256_ps(_mm256_add_epi32(g, _mm256_set1_epi32(0x7f000000u)));
2378
+ const __m256 s2 = _mm256_castsi256_ps(_mm256_sub_epi32(e, g));
2379
+ const __m256i d = _mm256_castps_si256(
2380
+ _mm256_cmp_ps(_mm256_andnot_ps(_mm256_set1_ps(-0.f), n),
2381
+ _mm256_set1_ps(192), _CMP_GT_OQ));
2382
+ return _mm256_or_ps(
2383
+ _mm256_and_ps(_mm256_castsi256_ps(d), _mm256_mul_ps(s1, s1)),
2384
+ _mm256_andnot_ps(
2385
+ _mm256_castsi256_ps(d),
2386
+ _mm256_or_ps(
2387
+ _mm256_and_ps(_mm256_castsi256_ps(c),
2388
+ _mm256_mul_ps(_mm256_fmadd_ps(s2, j, s2), s1)),
2389
+ _mm256_andnot_ps(_mm256_castsi256_ps(c), _mm256_fmadd_ps(k, j, k)))));
2390
+ }
2391
+
2392
+ // computes silu x/(1+exp(-x)) in single precision vector
2393
+ inline static __m256 ggml_v_silu(__m256 x) {
2394
+ const __m256 one = _mm256_set1_ps(1);
2395
+ const __m256 zero = _mm256_setzero_ps();
2396
+ const __m256 neg_x = _mm256_sub_ps(zero, x);
2397
+ const __m256 exp_neg_x = ggml_v_expf(neg_x);
2398
+ const __m256 one_plus_exp_neg_x = _mm256_add_ps(one, exp_neg_x);
2399
+ return _mm256_div_ps(x, one_plus_exp_neg_x);
2400
+ }
2401
+
2402
+ #elif defined(__SSE2__) // __AVX2__ / __ARM_NEON
2033
2403
 
2034
- #ifdef GGML_SILU_FP16
2035
- inline static void ggml_vec_silu_f32(const int n, float * y, const float * x) {
2036
- uint16_t t;
2037
- for (int i = 0; i < n; ++i) {
2038
- ggml_fp16_t fp16 = GGML_FP32_TO_FP16(x[i]);
2039
- memcpy(&t, &fp16, sizeof(uint16_t));
2040
- y[i] = GGML_FP16_TO_FP32(ggml_table_silu_f16[t]);
2041
- }
2042
- }
2404
+ #if defined(__FMA__)
2405
+ #define MADD128(x, y, z) _mm_fmadd_ps(x, y, z)
2406
+ #define NMADD128(x, y, z) _mm_fnmadd_ps(x, y, z)
2043
2407
  #else
2044
- inline static void ggml_vec_silu_f32(const int n, float * y, const float * x) {
2045
- for (int i = 0; i < n; ++i) {
2408
+ #define MADD128(x, y, z) _mm_add_ps(_mm_mul_ps(x, y), z)
2409
+ #define NMADD128(x, y, z) _mm_sub_ps(z, _mm_mul_ps(x, y))
2410
+ #endif
2411
+
2412
+ // adapted from arm limited optimized routine
2413
+ // the maximum error is 1.45358 plus 0.5 ulps
2414
+ // numbers above 88.38 will flush to infinity
2415
+ // numbers beneath -103.97 will flush to zero
2416
+ inline static __m128 ggml_v_expf(__m128 x) {
2417
+ const __m128 r = _mm_set1_ps(0x1.8p23f);
2418
+ const __m128 z = MADD128(x, _mm_set1_ps(0x1.715476p+0f), r);
2419
+ const __m128 n = _mm_sub_ps(z, r);
2420
+ const __m128 b =
2421
+ NMADD128(n, _mm_set1_ps(0x1.7f7d1cp-20f), NMADD128(n, _mm_set1_ps(0x1.62e4p-1f), x));
2422
+ const __m128i e = _mm_slli_epi32(_mm_castps_si128(z), 23);
2423
+ const __m128 k = _mm_castsi128_ps(_mm_add_epi32(e, _mm_castps_si128(_mm_set1_ps(1))));
2424
+ const __m128i c =
2425
+ _mm_castps_si128(_mm_cmpgt_ps(_mm_andnot_ps(_mm_set1_ps(-0.f), n), _mm_set1_ps(126)));
2426
+ const __m128 u = _mm_mul_ps(b, b);
2427
+ const __m128 j =
2428
+ MADD128(MADD128(MADD128(_mm_set1_ps(0x1.0e4020p-7f), b, _mm_set1_ps(0x1.573e2ep-5f)), u,
2429
+ MADD128(_mm_set1_ps(0x1.555e66p-3f), b, _mm_set1_ps(0x1.fffdb6p-2f))),
2430
+ u, _mm_mul_ps(_mm_set1_ps(0x1.ffffecp-1f), b));
2431
+ if (!_mm_movemask_epi8(c))
2432
+ return MADD128(j, k, k);
2433
+ const __m128i g = _mm_and_si128(_mm_castps_si128(_mm_cmple_ps(n, _mm_setzero_ps())),
2434
+ _mm_set1_epi32(0x82000000u));
2435
+ const __m128 s1 = _mm_castsi128_ps(_mm_add_epi32(g, _mm_set1_epi32(0x7f000000u)));
2436
+ const __m128 s2 = _mm_castsi128_ps(_mm_sub_epi32(e, g));
2437
+ const __m128i d =
2438
+ _mm_castps_si128(_mm_cmpgt_ps(_mm_andnot_ps(_mm_set1_ps(-0.f), n), _mm_set1_ps(192)));
2439
+ return _mm_or_ps(
2440
+ _mm_and_ps(_mm_castsi128_ps(d), _mm_mul_ps(s1, s1)),
2441
+ _mm_andnot_ps(_mm_castsi128_ps(d),
2442
+ _mm_or_ps(_mm_and_ps(_mm_castsi128_ps(c), _mm_mul_ps(MADD128(s2, j, s2), s1)),
2443
+ _mm_andnot_ps(_mm_castsi128_ps(c), MADD128(k, j, k)))));
2444
+ }
2445
+
2446
+ // computes silu x/(1+exp(-x)) in single precision vector
2447
+ inline static __m128 ggml_v_silu(__m128 x) {
2448
+ const __m128 one = _mm_set1_ps(1);
2449
+ const __m128 zero = _mm_setzero_ps();
2450
+ const __m128 neg_x = _mm_sub_ps(zero, x);
2451
+ const __m128 exp_neg_x = ggml_v_expf(neg_x);
2452
+ const __m128 one_plus_exp_neg_x = _mm_add_ps(one, exp_neg_x);
2453
+ return _mm_div_ps(x, one_plus_exp_neg_x);
2454
+ }
2455
+
2456
+ #endif // __ARM_NEON / __AVX2__ / __SSE2__
2457
+
2458
+ static void ggml_vec_silu_f32(const int n, float * y, const float * x) {
2459
+ int i = 0;
2460
+ #if defined(__AVX512F__) && defined(__AVX512DQ__)
2461
+ for (; i + 15 < n; i += 16) {
2462
+ _mm512_storeu_ps(y + i, ggml_v_silu(_mm512_loadu_ps(x + i)));
2463
+ }
2464
+ #elif defined(__AVX2__) && defined(__FMA__)
2465
+ for (; i + 7 < n; i += 8) {
2466
+ _mm256_storeu_ps(y + i, ggml_v_silu(_mm256_loadu_ps(x + i)));
2467
+ }
2468
+ #elif defined(__SSE2__)
2469
+ for (; i + 3 < n; i += 4) {
2470
+ _mm_storeu_ps(y + i, ggml_v_silu(_mm_loadu_ps(x + i)));
2471
+ }
2472
+ #elif defined(__ARM_NEON) && defined(__aarch64__)
2473
+ for (; i + 3 < n; i += 4) {
2474
+ vst1q_f32(y + i, ggml_v_silu(vld1q_f32(x + i)));
2475
+ }
2476
+ #endif
2477
+ for (; i < n; ++i) {
2046
2478
  y[i] = ggml_silu_f32(x[i]);
2047
2479
  }
2048
2480
  }
2481
+
2482
+ static ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x, float max) {
2483
+ int i = 0;
2484
+ ggml_float sum = 0;
2485
+ #if defined(__AVX512F__) && defined(__AVX512DQ__)
2486
+ for (; i + 15 < n; i += 16) {
2487
+ __m512 val = ggml_v_expf(_mm512_sub_ps(_mm512_loadu_ps(x + i),
2488
+ _mm512_set1_ps(max)));
2489
+ _mm512_storeu_ps(y + i, val);
2490
+ sum += (ggml_float)_mm512_reduce_add_ps(val);
2491
+ }
2492
+ #elif defined(__AVX2__) && defined(__FMA__)
2493
+ for (; i + 7 < n; i += 8) {
2494
+ __m256 val = ggml_v_expf(_mm256_sub_ps(_mm256_loadu_ps(x + i),
2495
+ _mm256_set1_ps(max)));
2496
+ _mm256_storeu_ps(y + i, val);
2497
+ __m128 val2 = _mm_add_ps(_mm256_extractf128_ps(val, 1),
2498
+ _mm256_castps256_ps128(val));
2499
+ val2 = _mm_add_ps(val2, _mm_movehl_ps(val2, val2));
2500
+ val2 = _mm_add_ss(val2, _mm_movehdup_ps(val2));
2501
+ sum += (ggml_float)_mm_cvtss_f32(val2);
2502
+ }
2503
+ #elif defined(__SSE2__)
2504
+ for (; i + 3 < n; i += 4) {
2505
+ __m128 val = ggml_v_expf(_mm_sub_ps(_mm_loadu_ps(x + i),
2506
+ _mm_set1_ps(max)));
2507
+ _mm_storeu_ps(y + i, val);
2508
+ #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
2509
+ val = _mm_add_ps(val, _mm_movehl_ps(val, val));
2510
+ val = _mm_add_ss(val, _mm_movehdup_ps(val));
2511
+ #else
2512
+ __m128 tmp = _mm_shuffle_ps(val, val, _MM_SHUFFLE(2, 3, 0, 1));
2513
+ val = _mm_add_ps(val, tmp);
2514
+ tmp = _mm_movehl_ps(tmp, val);
2515
+ val = _mm_add_ss(val, tmp);
2516
+ #endif
2517
+ sum += (ggml_float)_mm_cvtss_f32(val);
2518
+ }
2519
+ #elif defined(__ARM_NEON) && defined(__aarch64__)
2520
+ for (; i + 3 < n; i += 4) {
2521
+ float32x4_t val = ggml_v_expf(vsubq_f32(vld1q_f32(x + i),
2522
+ vdupq_n_f32(max)));
2523
+ vst1q_f32(y + i, val);
2524
+ sum += (ggml_float)vaddvq_f32(val);
2525
+ }
2049
2526
  #endif
2527
+ for (; i < n; ++i) {
2528
+ float val = expf(x[i] - max);
2529
+ sum += (ggml_float)val;
2530
+ y[i] = val;
2531
+ }
2532
+ return sum;
2533
+ }
2050
2534
 
2051
2535
  inline static float ggml_silu_backward_f32(float x, float dy) {
2052
2536
  const float s = 1.0f/(1.0f + expf(-x));
2053
2537
  return dy*s*(1.0f + x*(1.0f - s));
2054
2538
  }
2055
2539
 
2056
- #ifdef GGML_SILU_FP16
2057
- inline static void ggml_vec_silu_backward_f32(const int n, float * dx, const float * x, const float * dy) {
2058
- for (int i = 0; i < n; ++i) {
2059
- // we did not use x[i] to compute forward silu but its f16 equivalent
2060
- // take derivative at f16 of x[i]:
2061
- ggml_fp16_t fp16 = GGML_FP32_TO_FP16(x[i]);
2062
- float usedx = GGML_FP16_TO_FP32(fp16);
2063
- dx[i] = ggml_silu_backward_f32(usedx, dy[i]);
2064
- }
2065
- }
2066
- #else
2067
2540
  inline static void ggml_vec_silu_backward_f32(const int n, float * dx, const float * x, const float * dy) {
2068
2541
  for (int i = 0; i < n; ++i) {
2069
2542
  dx[i] = ggml_silu_backward_f32(x[i], dy[i]);
2070
2543
  }
2071
2544
  }
2072
- #endif
2073
2545
 
2074
2546
  inline static void ggml_vec_sum_f32(const int n, float * s, const float * x) {
2075
2547
  #ifndef GGML_USE_ACCELERATE
@@ -2185,7 +2657,6 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
  "SOFT_MAX_BACK",
  "ROPE",
  "ROPE_BACK",
- "ALIBI",
  "CLAMP",
  "CONV_TRANSPOSE_1D",
  "IM2COL",
@@ -2199,9 +2670,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
  "ARGSORT",
  "LEAKY_RELU",
 
- "FLASH_ATTN",
  "FLASH_ATTN_EXT",
- "FLASH_FF",
  "FLASH_ATTN_BACK",
  "SSM_CONV",
  "SSM_SCAN",
@@ -2227,7 +2696,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
  "CROSS_ENTROPY_LOSS_BACK",
  };
 
- static_assert(GGML_OP_COUNT == 77, "GGML_OP_COUNT != 77");
+ static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74");
 
  static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
  "none",
@@ -2276,7 +2745,6 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
  "soft_max_back(x)",
  "rope(x)",
  "rope_back(x)",
- "alibi(x)",
  "clamp(x)",
  "conv_transpose_1d(x)",
  "im2col(x)",
@@ -2290,9 +2758,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
  "argsort(x)",
  "leaky_relu(x)",
 
- "flash_attn(x)",
  "flash_attn_ext(x)",
- "flash_ff(x)",
  "flash_attn_back(x)",
  "ssm_conv(x)",
  "ssm_scan(x)",
@@ -2318,7 +2784,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
  "cross_entropy_loss_back(x,y)",
  };
 
- static_assert(GGML_OP_COUNT == 77, "GGML_OP_COUNT != 77");
+ static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74");
 
  static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
 
@@ -2331,6 +2797,7 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = {
  "TANH",
  "ELU",
  "RELU",
+ "SIGMOID",
  "GELU",
  "GELU_QUICK",
  "SILU",
@@ -2338,7 +2805,7 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = {
  "HARDSIGMOID",
  };
 
- static_assert(GGML_UNARY_OP_COUNT == 12, "GGML_UNARY_OP_COUNT != 12");
+ static_assert(GGML_UNARY_OP_COUNT == 13, "GGML_UNARY_OP_COUNT != 13");
 
 
  static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
@@ -2380,32 +2847,6 @@ static void ggml_setup_op_has_task_pass(void) {
  }
  }
 
- //
- // ggml context
- //
-
- struct ggml_context {
- size_t mem_size;
- void * mem_buffer;
- bool mem_buffer_owned;
- bool no_alloc;
- bool no_alloc_save; // this is used to save the no_alloc state when using scratch buffers
-
- int n_objects;
-
- struct ggml_object * objects_begin;
- struct ggml_object * objects_end;
-
- struct ggml_scratch scratch;
- struct ggml_scratch scratch_save;
- };
-
- struct ggml_context_container {
- bool used;
-
- struct ggml_context context;
- };
-
  //
  // NUMA support
  //
@@ -2819,8 +3260,18 @@ bool ggml_are_same_shape(const struct ggml_tensor * t0, const struct ggml_tensor
  (t0->ne[3] == t1->ne[3] );
  }
 
- // check if t1 can be represented as a repeatition of t0
- static inline bool ggml_can_repeat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
+ bool ggml_are_same_stride(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
+ static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+ return
+ (t0->nb[0] == t1->nb[0] ) &&
+ (t0->nb[1] == t1->nb[1] ) &&
+ (t0->nb[2] == t1->nb[2] ) &&
+ (t0->nb[3] == t1->nb[3] );
+ }
+
+ // check if t1 can be represented as a repeatition of t0
+ static inline bool ggml_can_repeat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
  static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
 
  return ggml_is_empty(t0) ? ggml_is_empty(t1) :
@@ -2878,8 +3329,6 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
  float f = ggml_table_f32_f16[i] = GGML_COMPUTE_FP16_TO_FP32(u.fp16);
  ggml_table_gelu_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_f32(f));
  ggml_table_gelu_quick_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_quick_f32(f));
- ggml_table_silu_f16[i] = GGML_FP32_TO_FP16(ggml_silu_f32(f));
- ggml_table_exp_f16[i] = GGML_FP32_TO_FP16(expf(f));
  }
 
  const uint64_t t_end = ggml_time_us(); UNUSED(t_end);
@@ -3163,6 +3612,12 @@ static struct ggml_tensor * ggml_new_tensor_impl(
 
  struct ggml_tensor * const result = (struct ggml_tensor *)((char *)ctx->mem_buffer + obj_new->offs);
 
+ #ifdef __clang__
+ // temporary until ggml_tensor::backend is removed
+ #pragma clang diagnostic push
+ #pragma clang diagnostic ignored "-Wdeprecated-declarations"
+ #endif
+
  *result = (struct ggml_tensor) {
  /*.type =*/ type,
  /*.backend =*/ GGML_BACKEND_TYPE_CPU,
@@ -3185,6 +3640,10 @@ static struct ggml_tensor * ggml_new_tensor_impl(
  /*.padding =*/ { 0 },
  };
 
+ #ifdef __clang__
+ #pragma clang diagnostic pop
+ #endif
+
  // TODO: this should not be needed as long as we don't rely on aligned SIMD loads
  //ggml_assert_aligned(result->data);
 
@@ -4423,10 +4882,21 @@ struct ggml_tensor * ggml_repeat_back(
  // ggml_concat
 
  struct ggml_tensor * ggml_concat(
- struct ggml_context* ctx,
- struct ggml_tensor* a,
- struct ggml_tensor* b) {
- GGML_ASSERT(a->ne[0] == b->ne[0] && a->ne[1] == b->ne[1] && a->ne[3] == b->ne[3]);
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ int dim) {
+ GGML_ASSERT(dim >= 0 && dim < GGML_MAX_DIMS);
+
+ int64_t ne[GGML_MAX_DIMS];
+ for (int d = 0; d < GGML_MAX_DIMS; ++d) {
+ if (d == dim) {
+ ne[d] = a->ne[d] + b->ne[d];
+ continue;
+ }
+ GGML_ASSERT(a->ne[d] == b->ne[d]);
+ ne[d] = a->ne[d];
+ }
 
  bool is_node = false;
 
@@ -4434,7 +4904,9 @@ struct ggml_tensor * ggml_concat(
  is_node = true;
  }
 
- struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, a->ne[0], a->ne[1], a->ne[2] + b->ne[2], a->ne[3]);
+ struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, GGML_MAX_DIMS, ne);
+
+ ggml_set_op_params_i32(result, 0, dim);
 
  result->op = GGML_OP_CONCAT;
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@@ -4554,6 +5026,7 @@ struct ggml_tensor * ggml_leaky_relu(
  }
 
  struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
  ggml_set_op_params(result, &negative_slope, sizeof(negative_slope));
 
  result->op = GGML_OP_LEAKY_RELU;
@@ -4563,6 +5036,20 @@ struct ggml_tensor * ggml_leaky_relu(
  return result;
  }
 
+ // ggml_sigmoid
+
+ struct ggml_tensor * ggml_sigmoid(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a) {
+ return ggml_unary(ctx, a, GGML_UNARY_OP_SIGMOID);
+ }
+
+ struct ggml_tensor * ggml_sigmoid_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a) {
+ return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_SIGMOID);
+ }
+
  // ggml_gelu
 
  struct ggml_tensor * ggml_gelu(
@@ -5646,7 +6133,6 @@ static struct ggml_tensor * ggml_soft_max_impl(
  struct ggml_context * ctx,
  struct ggml_tensor * a,
  struct ggml_tensor * mask,
- struct ggml_tensor * pos,
  float scale,
  float max_bias,
  bool inplace) {
@@ -5660,18 +6146,8 @@ static struct ggml_tensor * ggml_soft_max_impl(
  GGML_ASSERT(mask->ne[1] >= a->ne[1]);
  }
 
- if (pos) {
- GGML_ASSERT(ggml_is_vector(pos));
- GGML_ASSERT(pos->type == GGML_TYPE_F16 || pos->type == GGML_TYPE_F32);
- GGML_ASSERT(pos->ne[0] == a->ne[0]);
- }
-
- if (pos && mask) {
- GGML_ASSERT(pos->type == mask->type);
- }
-
  if (max_bias > 0.0f) {
- GGML_ASSERT(pos);
+ GGML_ASSERT(mask);
  }
 
  bool is_node = false;
@@ -5689,7 +6165,6 @@ static struct ggml_tensor * ggml_soft_max_impl(
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
  result->src[0] = a;
  result->src[1] = mask;
- result->src[2] = pos;
 
  return result;
  }
@@ -5697,23 +6172,22 @@ static struct ggml_tensor * ggml_soft_max_impl(
  struct ggml_tensor * ggml_soft_max(
  struct ggml_context * ctx,
  struct ggml_tensor * a) {
- return ggml_soft_max_impl(ctx, a, NULL, NULL, 1.0f, 0.0f, false);
+ return ggml_soft_max_impl(ctx, a, NULL, 1.0f, 0.0f, false);
  }
 
  struct ggml_tensor * ggml_soft_max_inplace(
  struct ggml_context * ctx,
  struct ggml_tensor * a) {
- return ggml_soft_max_impl(ctx, a, NULL, NULL, 1.0f, 0.0f, true);
+ return ggml_soft_max_impl(ctx, a, NULL, 1.0f, 0.0f, true);
  }
 
  struct ggml_tensor * ggml_soft_max_ext(
  struct ggml_context * ctx,
  struct ggml_tensor * a,
  struct ggml_tensor * mask,
- struct ggml_tensor * pos,
  float scale,
  float max_bias) {
- return ggml_soft_max_impl(ctx, a, mask, pos, scale, max_bias, false);
+ return ggml_soft_max_impl(ctx, a, mask, scale, max_bias, false);
  }
 
  // ggml_soft_max_back
@@ -5759,6 +6233,7 @@ static struct ggml_tensor * ggml_rope_impl(
5759
6233
  struct ggml_context * ctx,
5760
6234
  struct ggml_tensor * a,
5761
6235
  struct ggml_tensor * b,
6236
+ struct ggml_tensor * c,
5762
6237
  int n_dims,
5763
6238
  int mode,
5764
6239
  int n_ctx,
@@ -5772,10 +6247,17 @@ static struct ggml_tensor * ggml_rope_impl(
5772
6247
  float xpos_base,
5773
6248
  bool xpos_down,
5774
6249
  bool inplace) {
6250
+ GGML_ASSERT((mode & 1) == 0 && "mode & 1 == 1 is no longer supported");
6251
+
5775
6252
  GGML_ASSERT(ggml_is_vector(b));
5776
6253
  GGML_ASSERT(b->type == GGML_TYPE_I32);
5777
6254
  GGML_ASSERT(a->ne[2] == b->ne[0]);
5778
6255
 
6256
+ if (c) {
6257
+ GGML_ASSERT(c->type == GGML_TYPE_F32);
6258
+ GGML_ASSERT(c->ne[0] >= n_dims / 2);
6259
+ }
6260
+
5779
6261
  bool is_node = false;
5780
6262
 
5781
6263
  if (a->grad) {
@@ -5799,6 +6281,7 @@ static struct ggml_tensor * ggml_rope_impl(
5799
6281
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5800
6282
  result->src[0] = a;
5801
6283
  result->src[1] = b;
6284
+ result->src[2] = c;
5802
6285
 
5803
6286
  return result;
5804
6287
  }
@@ -5811,7 +6294,7 @@ struct ggml_tensor * ggml_rope(
5811
6294
  int mode,
5812
6295
  int n_ctx) {
5813
6296
  return ggml_rope_impl(
5814
- ctx, a, b, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, false
6297
+ ctx, a, b, NULL, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, false
5815
6298
  );
5816
6299
  }
5817
6300
 
@@ -5823,14 +6306,15 @@ struct ggml_tensor * ggml_rope_inplace(
5823
6306
  int mode,
5824
6307
  int n_ctx) {
5825
6308
  return ggml_rope_impl(
5826
- ctx, a, b, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, true
6309
+ ctx, a, b, NULL, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, true
5827
6310
  );
5828
6311
  }
5829
6312
 
5830
- struct ggml_tensor * ggml_rope_custom(
6313
+ struct ggml_tensor * ggml_rope_ext(
5831
6314
  struct ggml_context * ctx,
5832
6315
  struct ggml_tensor * a,
5833
6316
  struct ggml_tensor * b,
6317
+ struct ggml_tensor * c,
5834
6318
  int n_dims,
5835
6319
  int mode,
5836
6320
  int n_ctx,
@@ -5842,15 +6326,16 @@ struct ggml_tensor * ggml_rope_custom(
5842
6326
  float beta_fast,
5843
6327
  float beta_slow) {
5844
6328
  return ggml_rope_impl(
5845
- ctx, a, b, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
6329
+ ctx, a, b, c, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
5846
6330
  ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, false
5847
6331
  );
5848
6332
  }
5849
6333
 
5850
- struct ggml_tensor * ggml_rope_custom_inplace(
6334
+ struct ggml_tensor * ggml_rope_ext_inplace(
5851
6335
  struct ggml_context * ctx,
5852
6336
  struct ggml_tensor * a,
5853
6337
  struct ggml_tensor * b,
6338
+ struct ggml_tensor * c,
5854
6339
  int n_dims,
5855
6340
  int mode,
5856
6341
  int n_ctx,
@@ -5862,19 +6347,49 @@ struct ggml_tensor * ggml_rope_custom_inplace(
5862
6347
  float beta_fast,
5863
6348
  float beta_slow) {
5864
6349
  return ggml_rope_impl(
5865
- ctx, a, b, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
6350
+ ctx, a, b, c, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
5866
6351
  ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, true
5867
6352
  );
5868
6353
  }
5869
6354
 
5870
- struct ggml_tensor * ggml_rope_xpos_inplace(
6355
+ struct ggml_tensor * ggml_rope_custom(
6356
+ struct ggml_context * ctx,
6357
+ struct ggml_tensor * a,
6358
+ struct ggml_tensor * b,
6359
+ int n_dims,
6360
+ int mode,
6361
+ int n_ctx,
6362
+ int n_orig_ctx,
6363
+ float freq_base,
6364
+ float freq_scale,
6365
+ float ext_factor,
6366
+ float attn_factor,
6367
+ float beta_fast,
6368
+ float beta_slow) {
6369
+ return ggml_rope_impl(
6370
+ ctx, a, b, NULL, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
6371
+ ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, false
6372
+ );
6373
+ }
6374
+
6375
+ struct ggml_tensor * ggml_rope_custom_inplace(
5871
6376
  struct ggml_context * ctx,
5872
6377
  struct ggml_tensor * a,
5873
6378
  struct ggml_tensor * b,
5874
6379
  int n_dims,
5875
- float base,
5876
- bool down) {
5877
- return ggml_rope_impl(ctx, a, b, n_dims, 0, 0, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, base, down, true);
6380
+ int mode,
6381
+ int n_ctx,
6382
+ int n_orig_ctx,
6383
+ float freq_base,
6384
+ float freq_scale,
6385
+ float ext_factor,
6386
+ float attn_factor,
6387
+ float beta_fast,
6388
+ float beta_slow) {
6389
+ return ggml_rope_impl(
6390
+ ctx, a, b, NULL, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
6391
+ ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, true
6392
+ );
5878
6393
  }
5879
6394
 
5880
6395
  // ggml_rope_back
@@ -5883,6 +6398,7 @@ struct ggml_tensor * ggml_rope_back(
5883
6398
  struct ggml_context * ctx,
5884
6399
  struct ggml_tensor * a,
5885
6400
  struct ggml_tensor * b,
6401
+ struct ggml_tensor * c,
5886
6402
  int n_dims,
5887
6403
  int mode,
5888
6404
  int n_ctx,
@@ -5898,6 +6414,7 @@ struct ggml_tensor * ggml_rope_back(
5898
6414
  GGML_ASSERT(ggml_is_vector(b));
5899
6415
  GGML_ASSERT(b->type == GGML_TYPE_I32);
5900
6416
  GGML_ASSERT(a->ne[2] == b->ne[0]);
6417
+ GGML_ASSERT(c == NULL && "freq factors not implemented yet");
5901
6418
 
5902
6419
  GGML_ASSERT((mode & 4) == 0 && "ggml_rope_back() for ChatGLM not implemented yet");
5903
6420
 
@@ -5928,37 +6445,6 @@ struct ggml_tensor * ggml_rope_back(
5928
6445
  return result;
5929
6446
  }
5930
6447
 
5931
- // ggml_alibi
5932
-
5933
- struct ggml_tensor * ggml_alibi(
5934
- struct ggml_context * ctx,
5935
- struct ggml_tensor * a,
5936
- int n_past,
5937
- int n_head,
5938
- float bias_max) {
5939
- GGML_ASSERT(n_past >= 0);
5940
- bool is_node = false;
5941
-
5942
- if (a->grad) {
5943
- GGML_ASSERT(false); // TODO: implement backward
5944
- is_node = true;
5945
- }
5946
-
5947
- // TODO: when implement backward, fix this:
5948
- //struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
5949
- struct ggml_tensor * result = ggml_view_tensor(ctx, a);
5950
-
5951
- int32_t op_params[3] = { n_past, n_head };
5952
- memcpy(op_params + 2, &bias_max, sizeof(float));
5953
- ggml_set_op_params(result, op_params, sizeof(op_params));
5954
-
5955
- result->op = GGML_OP_ALIBI;
5956
- result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5957
- result->src[0] = a;
5958
-
5959
- return result;
5960
- }
5961
-
5962
6448
  // ggml_clamp
5963
6449
 
5964
6450
  struct ggml_tensor * ggml_clamp(
@@ -6308,7 +6794,10 @@ struct ggml_tensor * ggml_pool_2d(
6308
6794
  static struct ggml_tensor * ggml_upscale_impl(
6309
6795
  struct ggml_context * ctx,
6310
6796
  struct ggml_tensor * a,
6311
- int scale_factor) {
6797
+ int ne0,
6798
+ int ne1,
6799
+ int ne2,
6800
+ int ne3) {
6312
6801
  bool is_node = false;
6313
6802
 
6314
6803
  if (a->grad) {
@@ -6316,19 +6805,45 @@ static struct ggml_tensor * ggml_upscale_impl(
6316
6805
  is_node = true;
6317
6806
  }
6318
6807
 
6808
+ GGML_ASSERT(a->ne[0] <= ne0);
6809
+ GGML_ASSERT(a->ne[1] <= ne1);
6810
+ GGML_ASSERT(a->ne[2] <= ne2);
6811
+ GGML_ASSERT(a->ne[3] <= ne3);
6812
+
6319
6813
  struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type,
6320
- a->ne[0] * scale_factor,
6321
- a->ne[1] * scale_factor,
6322
- a->ne[2], a->ne[3]);
6814
+ ne0,
6815
+ ne1,
6816
+ ne2,
6817
+ ne3
6818
+ );
6323
6819
 
6324
6820
  result->op = GGML_OP_UPSCALE;
6325
- result->op_params[0] = scale_factor;
6821
+
6326
6822
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6327
6823
  result->src[0] = a;
6328
6824
 
6329
6825
  return result;
6330
6826
  }
6331
6827
 
6828
+ struct ggml_tensor * ggml_upscale(
6829
+ struct ggml_context * ctx,
6830
+ struct ggml_tensor * a,
6831
+ int scale_factor) {
6832
+ return ggml_upscale_impl(ctx, a, a->ne[0] * scale_factor, a->ne[1] * scale_factor, a->ne[2], a->ne[3]);
6833
+ }
6834
+
6835
+ struct ggml_tensor * ggml_upscale_ext(
6836
+ struct ggml_context * ctx,
6837
+ struct ggml_tensor * a,
6838
+ int ne0,
6839
+ int ne1,
6840
+ int ne2,
6841
+ int ne3) {
6842
+ return ggml_upscale_impl(ctx, a, ne0, ne1, ne2, ne3);
6843
+ }
6844
+
6845
+ // ggml_pad
6846
+
6332
6847
  struct ggml_tensor * ggml_pad(
6333
6848
  struct ggml_context * ctx,
6334
6849
  struct ggml_tensor * a,
@@ -6353,12 +6868,7 @@ struct ggml_tensor * ggml_pad(
6353
6868
  return result;
6354
6869
  }
6355
6870
 
6356
- struct ggml_tensor * ggml_upscale(
6357
- struct ggml_context * ctx,
6358
- struct ggml_tensor * a,
6359
- int scale_factor) {
6360
- return ggml_upscale_impl(ctx, a, scale_factor);
6361
- }
6871
+ // ggml_arange
6362
6872
 
6363
6873
  struct ggml_tensor * ggml_arange(
6364
6874
  struct ggml_context * ctx,
@@ -6380,6 +6890,8 @@ struct ggml_tensor * ggml_arange(
6380
6890
  return result;
6381
6891
  }
6382
6892
 
6893
+ // ggml_timestep_embedding
6894
+
6383
6895
  struct ggml_tensor * ggml_timestep_embedding(
6384
6896
  struct ggml_context * ctx,
6385
6897
  struct ggml_tensor * timesteps,
@@ -6446,38 +6958,6 @@ struct ggml_tensor * ggml_top_k(
6446
6958
  return result;
6447
6959
  }
6448
6960
 
6449
- // ggml_flash_attn
6450
-
6451
- struct ggml_tensor * ggml_flash_attn(
6452
- struct ggml_context * ctx,
6453
- struct ggml_tensor * q,
6454
- struct ggml_tensor * k,
6455
- struct ggml_tensor * v,
6456
- bool masked) {
6457
- GGML_ASSERT(ggml_can_mul_mat(k, q));
6458
- // TODO: check if vT can be multiplied by (k*qT)
6459
-
6460
- bool is_node = false;
6461
-
6462
- if (q->grad || k->grad || v->grad) {
6463
- is_node = true;
6464
- }
6465
-
6466
- //struct ggml_tensor * result = ggml_dup_tensor(ctx, q);
6467
- struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, GGML_MAX_DIMS, q->ne);
6468
-
6469
- int32_t t = masked ? 1 : 0;
6470
- ggml_set_op_params(result, &t, sizeof(t));
6471
-
6472
- result->op = GGML_OP_FLASH_ATTN;
6473
- result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6474
- result->src[0] = q;
6475
- result->src[1] = k;
6476
- result->src[2] = v;
6477
-
6478
- return result;
6479
- }
6480
-
6481
6961
  // ggml_flash_attn_ext
6482
6962
 
6483
6963
  struct ggml_tensor * ggml_flash_attn_ext(
@@ -6486,9 +6966,11 @@ struct ggml_tensor * ggml_flash_attn_ext(
6486
6966
  struct ggml_tensor * k,
6487
6967
  struct ggml_tensor * v,
6488
6968
  struct ggml_tensor * mask,
6489
- float scale) {
6969
+ float scale,
6970
+ float max_bias) {
6490
6971
  GGML_ASSERT(ggml_can_mul_mat(k, q));
6491
6972
  // TODO: check if vT can be multiplied by (k*qT)
6973
+
6492
6974
  if (mask) {
6493
6975
  GGML_ASSERT(ggml_is_contiguous(mask));
6494
6976
  GGML_ASSERT(mask->ne[2] == 1);
@@ -6498,6 +6980,10 @@ struct ggml_tensor * ggml_flash_attn_ext(
6498
6980
  //GGML_ASSERT(ggml_can_repeat_rows(mask, qk));
6499
6981
  }
6500
6982
 
6983
+ if (max_bias > 0.0f) {
6984
+ GGML_ASSERT(mask);
6985
+ }
6986
+
6501
6987
  bool is_node = false;
6502
6988
 
6503
6989
  if (q->grad || k->grad || v->grad) {
@@ -6508,7 +6994,7 @@ struct ggml_tensor * ggml_flash_attn_ext(
6508
6994
  int64_t ne[4] = { q->ne[0], q->ne[2], q->ne[1], q->ne[3] };
6509
6995
  struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
6510
6996
 
6511
- float params[] = { scale };
6997
+ float params[] = { scale, max_bias };
6512
6998
  ggml_set_op_params(result, params, sizeof(params));
6513
6999
 
6514
7000
  result->op = GGML_OP_FLASH_ATTN_EXT;
@@ -6528,39 +7014,7 @@ void ggml_flash_attn_ext_set_prec(
6528
7014
 
6529
7015
  const int32_t prec_i32 = (int32_t) prec;
6530
7016
 
6531
- ggml_set_op_params_i32(a, 1, prec_i32); // scale is on first pos
6532
- }
6533
-
6534
- // ggml_flash_ff
6535
-
6536
- struct ggml_tensor * ggml_flash_ff(
6537
- struct ggml_context * ctx,
6538
- struct ggml_tensor * a,
6539
- struct ggml_tensor * b0,
6540
- struct ggml_tensor * b1,
6541
- struct ggml_tensor * c0,
6542
- struct ggml_tensor * c1) {
6543
- GGML_ASSERT(ggml_can_mul_mat(b0, a));
6544
- // TODO: more checks
6545
-
6546
- bool is_node = false;
6547
-
6548
- if (a->grad || b0->grad || b1->grad || c0->grad || c1->grad) {
6549
- is_node = true;
6550
- }
6551
-
6552
- //struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
6553
- struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, GGML_MAX_DIMS, a->ne);
6554
-
6555
- result->op = GGML_OP_FLASH_FF;
6556
- result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6557
- result->src[0] = a;
6558
- result->src[1] = b0;
6559
- result->src[2] = b1;
6560
- result->src[3] = c0;
6561
- result->src[4] = c1;
6562
-
6563
- return result;
7017
+ ggml_set_op_params_i32(a, 2, prec_i32); // scale is on first pos, max_bias on second
6564
7018
  }
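ggml_flash_attn_ext now carries two floats in op_params (scale, then max_bias), which is why the precision flag above moved to index 2; a non-zero max_bias enables ALiBi and requires a mask. A hedged call-site sketch with illustrative shapes and the usual 1/sqrt(head_dim) scaling:

    #include <math.h>
    #include "ggml.h"

    // sketch only: q is F32 [D, N, n_head, 1], k/v are [D, M, n_head, 1], mask is the padded KQ mask
    static struct ggml_tensor * attn_example(struct ggml_context * ctx,
                                             struct ggml_tensor * q,
                                             struct ggml_tensor * k,
                                             struct ggml_tensor * v,
                                             struct ggml_tensor * mask,
                                             int head_dim) {
        const float scale    = 1.0f / sqrtf((float) head_dim);
        const float max_bias = 0.0f; // > 0.0f turns on ALiBi; mask must then be non-NULL

        struct ggml_tensor * cur = ggml_flash_attn_ext(ctx, q, k, v, mask, scale, max_bias);

        // op_params layout after this change: [0] = scale, [1] = max_bias, [2] = precision
        ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);

        return cur;
    }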
6565
7019
 
6566
7020
  // ggml_flash_attn_back
@@ -6572,6 +7026,8 @@ struct ggml_tensor * ggml_flash_attn_back(
6572
7026
  struct ggml_tensor * v,
6573
7027
  struct ggml_tensor * d,
6574
7028
  bool masked) {
7029
+ GGML_ASSERT(false && "TODO: adapt to ggml_flash_attn_ext() changes");
7030
+
6575
7031
  GGML_ASSERT(ggml_can_mul_mat(k, q));
6576
7032
  // TODO: check if vT can be multiplied by (k*qT)
6577
7033
 
@@ -10525,26 +10981,29 @@ static void ggml_compute_forward_concat_f32(
10525
10981
  GGML_ASSERT(nb00 == sizeof(float));
10526
10982
  GGML_ASSERT(nb10 == sizeof(float));
10527
10983
 
10984
+ const int32_t dim = ggml_get_op_params_i32(dst, 0);
10985
+
10986
+ GGML_ASSERT(dim >= 0 && dim < 4);
10987
+
10988
+ int64_t o[4] = {0, 0, 0, 0};
10989
+ o[dim] = src0->ne[dim];
10990
+
10991
+ const float * x;
10992
+
10993
+ // TODO: smarter multi-threading
10528
10994
  for (int i3 = 0; i3 < ne3; i3++) {
10529
10995
  for (int i2 = ith; i2 < ne2; i2 += nth) {
10530
- if (i2 < ne02) { // src0
10531
- for (int i1 = 0; i1 < ne1; i1++) {
10532
- for (int i0 = 0; i0 < ne0; i0++) {
10533
- const float * x = (float *)((char *) src0->data + i0 * nb00 + i1 * nb01 + i2 * nb02 + i3 * nb03);
10534
-
10535
- float * y = (float *)((char *)dst->data + i0 * nb0 + i1 * nb1 + i2 * nb2 + i3 * nb3);
10536
- *y = *x;
10537
- }
10538
- }
10539
- } // src1
10540
- else {
10541
- for (int i1 = 0; i1 < ne1; i1++) {
10542
- for (int i0 = 0; i0 < ne0; i0++) {
10543
- const float * x = (float *)((char *) src1->data + i0 * nb10 + i1 * nb11 + (i2 - ne02) * nb12 + i3 * nb13);
10544
-
10545
- float * y = (float *)((char *)dst->data + i0 * nb0 + i1 * nb1 + i2 * nb2 + i3 * nb3);
10546
- *y = *x;
10996
+ for (int i1 = 0; i1 < ne1; i1++) {
10997
+ for (int i0 = 0; i0 < ne0; i0++) {
10998
+ if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
10999
+ x = (const float *) ((const char *)src0->data + (i0 )*nb00 + (i1 )*nb01 + (i2 )*nb02 + (i3 )*nb03);
11000
+ } else {
11001
+ x = (const float *) ((const char *)src1->data + (i0 - o[0])*nb10 + (i1 - o[1])*nb11 + (i2 - o[2])*nb12 + (i3 - o[3])*nb13);
10547
11002
  }
11003
+
11004
+ float * y = (float *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
11005
+
11006
+ *y = *x;
10548
11007
  }
10549
11008
  }
10550
11009
  }
@@ -10552,7 +11011,7 @@ static void ggml_compute_forward_concat_f32(
10552
11011
  }
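The rewritten concat kernel above reads the concat dimension from op_params and builds the offset vector o[] so that any output coordinate outside src0's extent is remapped into src1 by subtracting src0's size along that dimension. A standalone sketch of that index mapping (plain C, detached from the ggml types):

    #include <stdio.h>
    #include <stdint.h>
    #include <stdbool.h>

    int main(void) {
        const int64_t ne_src0[4] = {4, 3, 2, 1};  // src0 extents
        const int     dim        = 1;             // concat along dim 1

        int64_t o[4] = {0, 0, 0, 0};
        o[dim] = ne_src0[dim];

        // an output coordinate past src0 along dim 1 reads from src1 at (i1 - o[1])
        const int64_t i[4] = {2, 4, 0, 0};
        const bool from_src0 = i[0] < ne_src0[0] && i[1] < ne_src0[1] &&
                               i[2] < ne_src0[2] && i[3] < ne_src0[3];

        printf("read from %s at i1=%lld\n",
               from_src0 ? "src0" : "src1",
               (long long)(from_src0 ? i[1] : i[1] - o[1]));
        return 0;
    }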
10553
11012
 
10554
11013
  static void ggml_compute_forward_concat(
10555
- const struct ggml_compute_params* params,
11014
+ const struct ggml_compute_params * params,
10556
11015
  struct ggml_tensor* dst) {
10557
11016
 
10558
11017
  const struct ggml_tensor * src0 = dst->src[0];
@@ -10892,6 +11351,52 @@ static void ggml_compute_forward_relu(
10892
11351
  }
10893
11352
  }
10894
11353
 
11354
+ // ggml_compute_forward_sigmoid
11355
+
11356
+ static void ggml_compute_forward_sigmoid_f32(
11357
+ const struct ggml_compute_params * params,
11358
+ struct ggml_tensor * dst) {
11359
+
11360
+ const struct ggml_tensor * src0 = dst->src[0];
11361
+
11362
+ assert(params->ith == 0);
11363
+ assert(ggml_are_same_shape(src0, dst));
11364
+
11365
+ if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
11366
+ return;
11367
+ }
11368
+
11369
+ const int n = ggml_nrows(src0);
11370
+ const int nc = src0->ne[0];
11371
+
11372
+ assert(dst->nb[0] == sizeof(float));
11373
+ assert(src0->nb[0] == sizeof(float));
11374
+
11375
+ for (int i = 0; i < n; i++) {
11376
+ ggml_vec_sigmoid_f32(nc,
11377
+ (float *) ((char *) dst->data + i*( dst->nb[1])),
11378
+ (float *) ((char *) src0->data + i*(src0->nb[1])));
11379
+ }
11380
+ }
11381
+
11382
+ static void ggml_compute_forward_sigmoid(
11383
+ const struct ggml_compute_params * params,
11384
+ struct ggml_tensor * dst) {
11385
+
11386
+ const struct ggml_tensor * src0 = dst->src[0];
11387
+
11388
+ switch (src0->type) {
11389
+ case GGML_TYPE_F32:
11390
+ {
11391
+ ggml_compute_forward_sigmoid_f32(params, dst);
11392
+ } break;
11393
+ default:
11394
+ {
11395
+ GGML_ASSERT(false);
11396
+ } break;
11397
+ }
11398
+ }
11399
+
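The new sigmoid op dispatches to ggml_vec_sigmoid_f32 one row at a time; the vector helper itself is defined outside this hunk and is assumed here to compute the standard logistic 1 / (1 + exp(-x)). A minimal standalone sketch of that per-element operation:

    #include <math.h>

    // assumed semantics of the row-wise helper: y[i] = 1 / (1 + exp(-x[i]))
    static void vec_sigmoid_f32_sketch(const int n, float * y, const float * x) {
        for (int i = 0; i < n; ++i) {
            y[i] = 1.0f / (1.0f + expf(-x[i]));
        }
    }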
10895
11400
  // ggml_compute_forward_gelu
10896
11401
 
10897
11402
  static void ggml_compute_forward_gelu_f32(
@@ -11742,50 +12247,141 @@ static bool ggml_compute_forward_mul_mat_use_blas(struct ggml_tensor * dst) {
11742
12247
  }
11743
12248
  #endif
11744
12249
 
11745
- static void ggml_compute_forward_mul_mat(
11746
- const struct ggml_compute_params * params,
11747
- struct ggml_tensor * dst) {
12250
+ static void ggml_compute_forward_mul_mat_one_chunk(
12251
+ const struct ggml_compute_params * params,
12252
+ struct ggml_tensor * dst,
12253
+ const int64_t num_rows_per_vec_dot,
12254
+ const int64_t ir0_start,
12255
+ const int64_t ir0_end,
12256
+ const int64_t ir1_start,
12257
+ const int64_t ir1_end) {
11748
12258
 
11749
12259
  const struct ggml_tensor * src0 = dst->src[0];
11750
12260
  const struct ggml_tensor * src1 = dst->src[1];
11751
12261
 
11752
- int64_t t0 = ggml_perf_time_us();
11753
- UNUSED(t0);
11754
-
11755
12262
  GGML_TENSOR_BINARY_OP_LOCALS
11756
12263
 
11757
- const int ith = params->ith;
11758
- const int nth = params->nth;
11759
-
11760
12264
  const enum ggml_type type = src0->type;
11761
12265
 
11762
12266
  const bool src1_cont = ggml_is_contiguous(src1);
11763
12267
 
11764
- ggml_vec_dot_t const vec_dot = type_traits[type].vec_dot;
11765
- enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type;
11766
- ggml_from_float_t const from_float_to_vec_dot = type_traits[vec_dot_type].from_float;
11767
- int64_t const vec_dot_num_rows = type_traits[type].nrows;
12268
+ ggml_vec_dot_t const vec_dot = type_traits[type].vec_dot;
12269
+ enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type;
11768
12270
 
11769
- GGML_ASSERT(ne0 == ne01);
11770
- GGML_ASSERT(ne1 == ne11);
11771
- GGML_ASSERT(ne2 == ne12);
11772
- GGML_ASSERT(ne3 == ne13);
12271
+ // broadcast factors
12272
+ const int64_t r2 = ne12 / ne02;
12273
+ const int64_t r3 = ne13 / ne03;
11773
12274
 
11774
- // we don't support permuted src0 or src1
11775
- GGML_ASSERT(nb00 == ggml_type_size(type));
11776
- GGML_ASSERT(nb10 == ggml_type_size(src1->type));
12275
+ //printf("ir0_start = %6lld, ir0_end = %6lld, ir1_start = %6lld, ir1_end = %6lld\n", ir0_start, ir0_end, ir1_start, ir1_end);
11777
12276
 
11778
- // dst cannot be transposed or permuted
11779
- GGML_ASSERT(nb0 == sizeof(float));
11780
- GGML_ASSERT(nb0 <= nb1);
11781
- GGML_ASSERT(nb1 <= nb2);
11782
- GGML_ASSERT(nb2 <= nb3);
12277
+ // threads with no work simply yield (not sure if it helps)
12278
+ if (ir0_start >= ir0_end || ir1_start >= ir1_end) {
12279
+ return;
12280
+ }
11783
12281
 
11784
- // broadcast factors
11785
- const int64_t r2 = ne12/ne02;
11786
- const int64_t r3 = ne13/ne03;
12282
+ const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
12283
+ const size_t row_size = ggml_row_size(vec_dot_type, ne10);
11787
12284
 
11788
- // nb01 >= nb00 - src0 is not transposed
12285
+ assert(ne12 % ne02 == 0);
12286
+ assert(ne13 % ne03 == 0);
12287
+
12288
+ // block-tiling attempt
12289
+ const int64_t blck_0 = 16;
12290
+ const int64_t blck_1 = 16;
12291
+
12292
+ const size_t src1_col_stride = src1_cont || src1->type != vec_dot_type ? row_size : nb11;
12293
+
12294
+ // attempt to reduce false-sharing (does not seem to make a difference)
12295
+ // 16 * 2, accounting for mmla kernels
12296
+ float tmp[32];
12297
+
12298
+ for (int64_t iir1 = ir1_start; iir1 < ir1_end; iir1 += blck_1) {
12299
+ for (int64_t iir0 = ir0_start; iir0 < ir0_end; iir0 += blck_0) {
12300
+ for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir1_end; ir1 += num_rows_per_vec_dot) {
12301
+ const int64_t i13 = (ir1 / (ne12 * ne1));
12302
+ const int64_t i12 = (ir1 - i13 * ne12 * ne1) / ne1;
12303
+ const int64_t i11 = (ir1 - i13 * ne12 * ne1 - i12 * ne1);
12304
+
12305
+ // broadcast src0 into src1
12306
+ const int64_t i03 = i13 / r3;
12307
+ const int64_t i02 = i12 / r2;
12308
+
12309
+ const int64_t i1 = i11;
12310
+ const int64_t i2 = i12;
12311
+ const int64_t i3 = i13;
12312
+
12313
+ const char * src0_row = (const char*)src0->data + (0 + i02 * nb02 + i03 * nb03);
12314
+
12315
+ // desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides
12316
+ // if it is, then we have either copied the data to params->wdata and made it contiguous or we are using
12317
+ // the original src1 data pointer, so we should index using the indices directly
12318
+ // TODO: this is a bit of a hack, we should probably have a better way to handle this
12319
+ const char * src1_col = (const char*)wdata +
12320
+ (src1_cont || src1->type != vec_dot_type
12321
+ ? (i11 + i12 * ne11 + i13 * ne12 * ne11) * row_size
12322
+ : (i11 * nb11 + i12 * nb12 + i13 * nb13));
12323
+ float * dst_col = (float*)((char*)dst->data + (i1 * nb1 + i2 * nb2 + i3 * nb3));
12324
+
12325
+ //for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ++ir0) {
12326
+ // vec_dot(ne00, &dst_col[ir0], src0_row + ir0*nb01, src1_col);
12327
+ //}
12328
+
12329
+ for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ir0 += num_rows_per_vec_dot) {
12330
+ vec_dot(ne00, &tmp[ir0 - iir0], (num_rows_per_vec_dot > 1 ? 16 : 0), src0_row + ir0 * nb01, (num_rows_per_vec_dot > 1 ? nb01 : 0), src1_col, (num_rows_per_vec_dot > 1 ? src1_col_stride : 0), num_rows_per_vec_dot);
12331
+ }
12332
+
12333
+ for (int cn = 0; cn < num_rows_per_vec_dot; ++cn) {
12334
+ memcpy(&dst_col[iir0 + cn * nb1 / nb0], tmp + (cn * 16), (MIN(iir0 + blck_0, ir0_end) - iir0) * sizeof(float));
12335
+ }
12336
+ }
12337
+ }
12338
+ }
12339
+ }
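Within a chunk, the flattened src1 row index ir1 is decomposed into (i13, i12, i11), and the broadcast factors r2 and r3 then pick the src0 slice (i02 = i12 / r2, i03 = i13 / r3). A standalone sketch of that arithmetic with illustrative extents:

    #include <stdio.h>
    #include <stdint.h>

    int main(void) {
        const int64_t ne1 = 5, ne12 = 4, ne02 = 2, ne13 = 3, ne03 = 3;
        const int64_t r2 = ne12 / ne02, r3 = ne13 / ne03; // broadcast factors

        const int64_t ir1 = 37; // some row in [0, ne1*ne12*ne13)
        const int64_t i13 = ir1 / (ne12 * ne1);
        const int64_t i12 = (ir1 - i13 * ne12 * ne1) / ne1;
        const int64_t i11 = ir1 - i13 * ne12 * ne1 - i12 * ne1;

        printf("i13=%lld i12=%lld i11=%lld -> src0 slice (i02=%lld, i03=%lld)\n",
               (long long)i13, (long long)i12, (long long)i11,
               (long long)(i12 / r2), (long long)(i13 / r3));
        return 0;
    }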
12340
+
12341
+ static void ggml_compute_forward_mul_mat(
12342
+ const struct ggml_compute_params * params,
12343
+ struct ggml_tensor * dst,
12344
+ struct ggml_compute_state * state) {
12345
+
12346
+ const struct ggml_tensor * src0 = dst->src[0];
12347
+ const struct ggml_tensor * src1 = dst->src[1];
12348
+
12349
+ int64_t t0 = ggml_perf_time_us();
12350
+ UNUSED(t0);
12351
+
12352
+ GGML_TENSOR_BINARY_OP_LOCALS
12353
+
12354
+ const int ith = params->ith;
12355
+ const int nth = params->nth;
12356
+
12357
+ const enum ggml_type type = src0->type;
12358
+
12359
+ enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type;
12360
+ ggml_from_float_t const from_float_to_vec_dot = type_traits[vec_dot_type].from_float;
12361
+ int64_t const vec_dot_num_rows = type_traits[type].nrows;
12362
+
12363
+ GGML_ASSERT(ne0 == ne01);
12364
+ GGML_ASSERT(ne1 == ne11);
12365
+ GGML_ASSERT(ne2 == ne12);
12366
+ GGML_ASSERT(ne3 == ne13);
12367
+
12368
+ // we don't support permuted src0 or src1
12369
+ GGML_ASSERT(nb00 == ggml_type_size(type));
12370
+ GGML_ASSERT(nb10 == ggml_type_size(src1->type));
12371
+
12372
+ // dst cannot be transposed or permuted
12373
+ GGML_ASSERT(nb0 == sizeof(float));
12374
+ GGML_ASSERT(nb0 <= nb1);
12375
+ GGML_ASSERT(nb1 <= nb2);
12376
+ GGML_ASSERT(nb2 <= nb3);
12377
+
12378
+ // broadcast factors
12379
+ const int64_t r2 = ne12 / ne02;
12380
+ const int64_t r3 = ne13 / ne03;
12381
+ UNUSED(r2);
12382
+ UNUSED(r3);
12383
+
12384
+ // nb01 >= nb00 - src0 is not transposed
11789
12385
  // compute by src0 rows
11790
12386
 
11791
12387
  #if defined(GGML_USE_CLBLAST)
@@ -11865,6 +12461,8 @@ static void ggml_compute_forward_mul_mat(
11865
12461
  #endif
11866
12462
 
11867
12463
  #if GGML_USE_LLAMAFILE
12464
+ const bool src1_cont = ggml_is_contiguous(src1);
12465
+
11868
12466
  if (src1_cont) {
11869
12467
  for (int64_t i13 = 0; i13 < ne13; i13++)
11870
12468
  for (int64_t i12 = 0; i12 < ne12; i12++)
@@ -11890,6 +12488,8 @@ UseGgmlGemm1:;
11890
12488
  if (ith != 0) {
11891
12489
  return;
11892
12490
  }
12491
+ // Every thread starts at ith, so the first unprocessed chunk is nth. This saves a bit of coordination right at the start.
12492
+ atomic_store(&state->shared->current_chunk, nth);
11893
12493
  if (src1->type != vec_dot_type) {
11894
12494
  char * wdata = params->wdata;
11895
12495
  const size_t row_size = ggml_row_size(vec_dot_type, ne10);
@@ -11914,11 +12514,11 @@ UseGgmlGemm1:;
11914
12514
  return;
11915
12515
  }
11916
12516
 
11917
- const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
11918
- const size_t row_size = ggml_row_size(vec_dot_type, ne10);
11919
-
11920
12517
  #if GGML_USE_LLAMAFILE
11921
12518
  if (src1->type != vec_dot_type) {
12519
+ const void* wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
12520
+ const size_t row_size = ggml_row_size(vec_dot_type, ne10);
12521
+
11922
12522
  for (int64_t i13 = 0; i13 < ne13; i13++)
11923
12523
  for (int64_t i12 = 0; i12 < ne12; i12++)
11924
12524
  if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(src0->type),
@@ -11939,98 +12539,87 @@ UseGgmlGemm1:;
11939
12539
  UseGgmlGemm2:;
11940
12540
  #endif
11941
12541
 
11942
- const int64_t nr0 = ne01; // src0 rows
11943
- const int64_t nr1 = ne1*ne12*ne13; // src1 rows
11944
-
11945
- //printf("nr0 = %lld, nr1 = %lld\n", nr0, nr1);
11946
-
11947
- // distribute the thread work across the inner or outer loop based on which one is larger
11948
-
11949
- const int64_t nth0 = nr0 > nr1 ? nth : 1; // parallelize by src0 rows
11950
- const int64_t nth1 = nr0 > nr1 ? 1 : nth; // parallelize by src1 rows
11951
-
11952
- const int64_t ith0 = ith % nth0;
11953
- const int64_t ith1 = ith / nth0;
11954
-
11955
- const int64_t dr0 = (nr0 + nth0 - 1)/nth0;
11956
- const int64_t dr1 = (nr1 + nth1 - 1)/nth1;
11957
-
11958
- const int64_t ir010 = dr0*ith0;
11959
- const int64_t ir011 = MIN(ir010 + dr0, nr0);
11960
-
11961
- const int64_t ir110 = dr1*ith1;
11962
- const int64_t ir111 = MIN(ir110 + dr1, nr1);
11963
-
11964
- //printf("ir010 = %6lld, ir011 = %6lld, ir110 = %6lld, ir111 = %6lld\n", ir010, ir011, ir110, ir111);
11965
-
11966
- // threads with no work simply yield (not sure if it helps)
11967
- if (ir010 >= ir011 || ir110 >= ir111) {
11968
- sched_yield();
11969
- return;
11970
- }
12542
+ #ifdef GGML_PERF
12543
+ int chunks_executed = 0;
12544
+ UNUSED(chunks_executed);
12545
+ #endif
11971
12546
 
11972
- assert(ne12 % ne02 == 0);
11973
- assert(ne13 % ne03 == 0);
12547
+ // This is the size of the first dimension of the result, so we can iterate that way. (see the ASSERT above, these are the same numbers)
12548
+ const int64_t nr0 = ne0;
11974
12549
 
11975
- // block-tiling attempt
11976
- const int64_t blck_0 = 16;
11977
- const int64_t blck_1 = 16;
12550
+ // This is the size of the rest of the dimensions of the result
12551
+ const int64_t nr1 = ne1 * ne2 * ne3;
11978
12552
 
11979
12553
  // dot kernels can handle 1 row and col at a time, but mmla kernels can process 2 rows and cols
11980
- int64_t nrc = vec_dot_num_rows;
12554
+ int64_t num_rows_per_vec_dot = vec_dot_num_rows;
11981
12555
  // TODO: currently the mmla kernels support only even numbered rows/cols.
11982
12556
  // this check can be removed once they are extended to support odd numbered rows/cols too
11983
12557
  if ((nr0 % 2 != 0) || (ne11 % 2 != 0)) {
11984
- nrc = 1;
12558
+ num_rows_per_vec_dot = 1;
11985
12559
  }
11986
12560
 
11987
- const size_t src1_col_stride = src1_cont || src1->type != vec_dot_type ? row_size : nb11;
12561
+ // Now select a reasonable chunk size.
12562
+ int chunk_size = 16;
11988
12563
 
11989
- // attempt to reduce false-sharing (does not seem to make a difference)
11990
- // 16 * 2, accounting for mmla kernels
11991
- float tmp[32];
12564
+ // We need to step up the size if it's small
12565
+ if (nr0 == 1 || nr1 == 1) {
12566
+ chunk_size = 64;
12567
+ }
11992
12568
 
11993
- for (int64_t iir1 = ir110; iir1 < ir111; iir1 += blck_1) {
11994
- for (int64_t iir0 = ir010; iir0 < ir011; iir0 += blck_0) {
11995
- for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir111; ir1 += nrc) {
11996
- const int64_t i13 = (ir1/(ne12*ne1));
11997
- const int64_t i12 = (ir1 - i13*ne12*ne1)/ne1;
11998
- const int64_t i11 = (ir1 - i13*ne12*ne1 - i12*ne1);
12569
+ // distribute the work across the inner or outer loop based on which one is larger
12570
+ // The number of chunks in the 0/1 dim.
12571
+ // CEIL(nr0/chunk_size)
12572
+ int64_t nchunk0 = (nr0 + chunk_size - 1) / chunk_size;
12573
+ int64_t nchunk1 = (nr1 + chunk_size - 1) / chunk_size;
11999
12574
 
12000
- // broadcast src0 into src1
12001
- const int64_t i03 = i13/r3;
12002
- const int64_t i02 = i12/r2;
12575
+ // If the chunking is poor for the number of threads on this setup, scrap the whole plan. Re-chunk it by thread.
12576
+ // Also, chunking by thread was measured to have perform better on NUMA systems. See https://github.com/ggerganov/llama.cpp/pull/6915
12577
+ // In theory, chunking should be just as useful on NUMA and non NUMA systems, but testing disagreed with that.
12578
+ if (nchunk0 * nchunk1 < nth * 4 || ggml_is_numa()) {
12579
+ // distribute the thread work across the inner or outer loop based on which one is larger
12580
+ nchunk0 = nr0 > nr1 ? nth : 1; // parallelize by src0 rows
12581
+ nchunk1 = nr0 > nr1 ? 1 : nth; // parallelize by src1 rows
12582
+ }
12003
12583
 
12004
- const int64_t i1 = i11;
12005
- const int64_t i2 = i12;
12006
- const int64_t i3 = i13;
12584
+ // The number of elements in each chunk
12585
+ const int64_t dr0 = (nr0 + nchunk0 - 1) / nchunk0;
12586
+ const int64_t dr1 = (nr1 + nchunk1 - 1) / nchunk1;
12007
12587
 
12008
- const char * src0_row = (const char *) src0->data + (0 + i02*nb02 + i03*nb03);
12588
+ //if (ith == 0)
12589
+ // printf("MUL_MAT = [%d, %d, %d, %d] x [%d, %d, %d, %d] = %d x %d = %d. Fp Ops/Ch %d\n", ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, nchunk0, nchunk1, nchunk0 * nchunk1, ne00 * nr0 * nr1 / nchunk0 / nchunk1);
12009
12590
 
12010
- // desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides
12011
- // if it is, then we have either copied the data to params->wdata and made it contiguous or we are using
12012
- // the original src1 data pointer, so we should index using the indices directly
12013
- // TODO: this is a bit of a hack, we should probably have a better way to handle this
12014
- const char * src1_col = (const char *) wdata +
12015
- (src1_cont || src1->type != vec_dot_type
12016
- ? (i11 + i12*ne11 + i13*ne12*ne11)*row_size
12017
- : (i11*nb11 + i12*nb12 + i13*nb13));
12018
- float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3));
12591
+ // The first chunk comes from our thread_id, the rest will get auto-assigned.
12592
+ int current_chunk = ith;
12019
12593
 
12020
- //for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) {
12021
- // vec_dot(ne00, &dst_col[ir0], src0_row + ir0*nb01, src1_col);
12022
- //}
12594
+ while (current_chunk < nchunk0 * nchunk1) {
12595
+ const int64_t ith0 = current_chunk % nchunk0;
12596
+ const int64_t ith1 = current_chunk / nchunk0;
12023
12597
 
12024
- for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ir0 += nrc) {
12025
- vec_dot(ne00, &tmp[ir0 - iir0], (nrc>1 ? 16 : 0), src0_row + ir0*nb01, (nrc>1 ? nb01 : 0), src1_col, (nrc>1 ? src1_col_stride : 0), nrc);
12026
- }
12598
+ const int64_t ir0_start = dr0 * ith0;
12599
+ const int64_t ir0_end = MIN(ir0_start + dr0, nr0);
12027
12600
 
12028
- for (int cn = 0; cn < nrc; ++cn) {
12029
- memcpy(&dst_col[iir0 + cn*nb1/nb0], tmp + (cn*16), (MIN(iir0 + blck_0, ir011) - iir0)*sizeof(float));
12030
- }
12031
- }
12601
+ const int64_t ir1_start = dr1 * ith1;
12602
+ const int64_t ir1_end = MIN(ir1_start + dr1, nr1);
12603
+
12604
+ ggml_compute_forward_mul_mat_one_chunk(params, dst, num_rows_per_vec_dot, ir0_start, ir0_end, ir1_start, ir1_end);
12605
+
12606
+ #ifdef GGML_PERF
12607
+ chunks_executed++;
12608
+ #endif
12609
+
12610
+ if (nth >= nchunk0 * nchunk1) {
12611
+ break;
12032
12612
  }
12613
+
12614
+ current_chunk = atomic_fetch_add(&state->shared->current_chunk, 1);
12033
12615
  }
12616
+
12617
+ #ifdef GGML_PERF
12618
+ // These numbers are useful when trying to measure how well the thread scheduling works.
12619
+ //int64_t workSize = (ne01 * ne11 * ne12 * ne13 * ne00) / nchunk0 / nchunk1;
12620
+ //float time = (ggml_perf_time_us() - t0);
12621
+ //printf("MUL_MAT = %f ms, [%d, %d, %d, %d] x [%d, %d, %d, %d] = %I64u, %f ops/usec in %d chunks.\n", time / 1000.0, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, workSize, (float)workSize/time, chunks_executed);
12622
+ #endif
12034
12623
  }
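The rewritten scheduler splits the output into nchunk0 x nchunk1 tiles, seeds every thread with its own thread id as its first chunk (the shared counter was pre-set to nth above), and lets threads pull further chunk ids from the atomic counter until the tiles run out. A standalone sketch of that claiming loop using C11 atomics; the per-chunk work is stubbed out and the loop is driven sequentially here for clarity:

    #include <stdatomic.h>
    #include <stdio.h>
    #include <stdint.h>

    #define MIN(a, b) ((a) < (b) ? (a) : (b))

    static atomic_int current_chunk;

    static void thread_body(int ith, int nth, int64_t nr0, int64_t nr1, int64_t nchunk0, int64_t nchunk1) {
        const int64_t dr0 = (nr0 + nchunk0 - 1) / nchunk0;
        const int64_t dr1 = (nr1 + nchunk1 - 1) / nchunk1;

        int chunk = ith; // first chunk is the thread id; the counter starts at nth
        while (chunk < nchunk0 * nchunk1) {
            const int64_t ith0 = chunk % nchunk0;
            const int64_t ith1 = chunk / nchunk0;

            const int64_t ir0_start = dr0 * ith0, ir0_end = MIN(ir0_start + dr0, nr0);
            const int64_t ir1_start = dr1 * ith1, ir1_end = MIN(ir1_start + dr1, nr1);

            printf("thread %d/%d: rows0 [%lld,%lld) x rows1 [%lld,%lld)\n", ith, nth,
                   (long long)ir0_start, (long long)ir0_end, (long long)ir1_start, (long long)ir1_end);

            if (nth >= nchunk0 * nchunk1) {
                break; // at most one chunk per thread, no need to poll the counter
            }
            chunk = atomic_fetch_add(&current_chunk, 1); // claim the next unprocessed chunk
        }
    }

    int main(void) {
        const int nth = 2;
        atomic_store(&current_chunk, nth);
        for (int ith = 0; ith < nth; ++ith) {
            thread_body(ith, nth, /*nr0=*/64, /*nr1=*/48, /*nchunk0=*/4, /*nchunk1=*/3);
        }
        return 0;
    }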
12035
12624
 
12036
12625
  // ggml_compute_forward_mul_mat_id
@@ -13333,7 +13922,6 @@ static void ggml_compute_forward_soft_max_f32(
13333
13922
 
13334
13923
  const struct ggml_tensor * src0 = dst->src[0];
13335
13924
  const struct ggml_tensor * src1 = dst->src[1];
13336
- const struct ggml_tensor * src2 = dst->src[2];
13337
13925
 
13338
13926
  assert(ggml_is_contiguous(dst));
13339
13927
  assert(ggml_are_same_shape(src0, dst));
@@ -13359,8 +13947,8 @@ static void ggml_compute_forward_soft_max_f32(
13359
13947
 
13360
13948
  // TODO: is this supposed to be ceil instead of floor?
13361
13949
  // https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L370
13362
- const uint32_t n_head_kv = ne02;
13363
- const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head_kv));
13950
+ const uint32_t n_head = ne02;
13951
+ const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
13364
13952
 
13365
13953
  const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
13366
13954
  const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
@@ -13377,13 +13965,13 @@ static void ggml_compute_forward_soft_max_f32(
13377
13965
 
13378
13966
  float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith;
13379
13967
 
13380
- // when max_bias <= 0.0f, src2 is not used and we default it to src0 to avoid branching
13381
- ggml_fp16_t * pos_f16 = src2 ? (ggml_fp16_t *) src2->data : src0->data;
13382
- float * pos_f32 = src2 ? (float *) src2->data : src0->data;
13383
-
13384
- const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16) || (src2 && src2->type == GGML_TYPE_F16);
13968
+ const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
13385
13969
 
13386
13970
  for (int i1 = ir0; i1 < ir1; i1++) {
13971
+ // ALiBi
13972
+ const uint32_t h = (i1/ne01)%ne02; // head
13973
+ const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
13974
+
13387
13975
  float * sp = (float *)((char *) src0->data + i1*src0->nb[1]);
13388
13976
  float * dp = (float *)((char *) dst->data + i1*dst->nb[1]);
13389
13977
 
@@ -13396,27 +13984,11 @@ static void ggml_compute_forward_soft_max_f32(
13396
13984
  if (mp_f32) {
13397
13985
  if (use_f16) {
13398
13986
  for (int i = 0; i < nc; ++i) {
13399
- wp[i] += GGML_FP16_TO_FP32(mp_f16[i]);
13400
- }
13401
- } else {
13402
- for (int i = 0; i < nc; ++i) {
13403
- wp[i] += mp_f32[i];
13404
- }
13405
- }
13406
- }
13407
-
13408
- // ALiBi bias
13409
- if (max_bias > 0.0f) {
13410
- const uint32_t h = (i1/ne01)%ne02; // head
13411
- const float slope = h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1);
13412
-
13413
- if (use_f16) {
13414
- for (int i = 0; i < nc; ++i) {
13415
- wp[i] += slope*GGML_FP16_TO_FP32(pos_f16[i]);
13987
+ wp[i] += slope*GGML_FP16_TO_FP32(mp_f16[i]);
13416
13988
  }
13417
13989
  } else {
13418
13990
  for (int i = 0; i < nc; ++i) {
13419
- wp[i] += slope*pos_f32[i];
13991
+ wp[i] += slope*mp_f32[i];
13420
13992
  }
13421
13993
  }
13422
13994
  }
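With the dedicated ALIBI op gone, the slope is now folded directly into the mask addition: head h uses m0^(h+1) while h < n_head_log2 and m1^(2*(h - n_head_log2) + 1) above that, degenerating to 1.0 when max_bias is zero. A standalone sketch that prints the slopes for an illustrative 8-head, max_bias = 8 configuration:

    #include <math.h>
    #include <stdio.h>
    #include <stdint.h>

    int main(void) {
        const uint32_t n_head      = 8;
        const float    max_bias    = 8.0f;
        const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));

        const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
        const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

        for (uint32_t h = 0; h < n_head; ++h) {
            const float slope = (max_bias > 0.0f)
                ? (h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1))
                : 1.0f;
            printf("head %u: slope = %f\n", h, slope);
        }
        return 0;
    }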
@@ -13431,22 +14003,7 @@ static void ggml_compute_forward_soft_max_f32(
13431
14003
  float max = -INFINITY;
13432
14004
  ggml_vec_max_f32(nc, &max, wp);
13433
14005
 
13434
- ggml_float sum = 0.0;
13435
-
13436
- uint16_t scvt;
13437
- for (int i = 0; i < nc; i++) {
13438
- if (wp[i] == -INFINITY) {
13439
- dp[i] = 0.0f;
13440
- } else {
13441
- // const float val = (wp[i] == -INFINITY) ? 0.0 : exp(wp[i] - max);
13442
- ggml_fp16_t s = GGML_FP32_TO_FP16(wp[i] - max);
13443
- memcpy(&scvt, &s, sizeof(scvt));
13444
- const float val = GGML_FP16_TO_FP32(ggml_table_exp_f16[scvt]);
13445
- sum += (ggml_float)val;
13446
- dp[i] = val;
13447
- }
13448
- }
13449
-
14006
+ ggml_float sum = ggml_vec_soft_max_f32(nc, dp, wp, max);
13450
14007
  assert(sum > 0.0);
13451
14008
 
13452
14009
  sum = 1.0/sum;
@@ -13578,68 +14135,9 @@ static void ggml_compute_forward_soft_max_back(
13578
14135
  }
13579
14136
  }
13580
14137
 
13581
- // ggml_compute_forward_alibi
13582
-
13583
- static void ggml_compute_forward_alibi_f32(
13584
- const struct ggml_compute_params * params,
13585
- struct ggml_tensor * dst) {
13586
-
13587
- const struct ggml_tensor * src0 = dst->src[0];
13588
-
13589
- assert(params->ith == 0);
13590
-
13591
- if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
13592
- return;
13593
- }
13594
-
13595
- //const int n_past = ((int32_t *) dst->op_params)[0];
13596
- const int n_head = ((int32_t *) dst->op_params)[1];
13597
- float max_bias;
13598
- memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
13599
-
13600
- const int64_t ne0 = src0->ne[0]; // all_seq_len = n_past + ne1
13601
- const int64_t ne1 = src0->ne[1]; // seq_len_without_past
13602
- const int64_t ne2 = src0->ne[2]; // n_head -> this is k
13603
- //const int64_t ne3 = src0->ne[3]; // 1 -> bsz
13604
-
13605
- const int64_t n = ggml_nrows(src0);
13606
- const int64_t ne2_ne3 = n/ne1; // ne2*ne3
13607
-
13608
- const size_t nb0 = src0->nb[0];
13609
- const size_t nb1 = src0->nb[1];
13610
- const size_t nb2 = src0->nb[2];
13611
- //const int nb3 = src0->nb[3];
13612
-
13613
- GGML_ASSERT(nb0 == sizeof(float));
13614
- GGML_ASSERT(n_head == ne2);
13615
-
13616
- // add alibi to src0 (KQ_scaled)
13617
- const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
13618
-
13619
- const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
13620
- const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
13621
-
13622
- for (int64_t k = 0; k < ne2_ne3; k++) {
13623
- // TODO: k*nb2 or k*nb3
13624
- float m_k;
13625
-
13626
- if (k < n_heads_log2_floor) {
13627
- m_k = powf(m0, k + 1);
13628
- } else {
13629
- m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1);
13630
- }
13631
-
13632
- for (int64_t i = 0; i < ne0; i++) {
13633
- for (int64_t j = 0; j < ne1; j++) {
13634
- float * const src = (float *)((char *) src0->data + i*nb0 + j*nb1 + k*nb2);
13635
- float * pdst = (float *)((char *) dst->data + i*nb0 + j*nb1 + k*nb2);
13636
- pdst[0] = i * m_k + src[0];
13637
- }
13638
- }
13639
- }
13640
- }
14138
+ // ggml_compute_forward_clamp
13641
14139
 
13642
- static void ggml_compute_forward_alibi_f16(
14140
+ static void ggml_compute_forward_clamp_f32(
13643
14141
  const struct ggml_compute_params * params,
13644
14142
  struct ggml_tensor * dst) {
13645
14143
 
@@ -13651,71 +14149,48 @@ static void ggml_compute_forward_alibi_f16(
13651
14149
  return;
13652
14150
  }
13653
14151
 
13654
- //const int n_past = ((int32_t *) dst->op_params)[0];
13655
- const int n_head = ((int32_t *) dst->op_params)[1];
13656
- float max_bias;
13657
- memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
14152
+ float min;
14153
+ float max;
14154
+ memcpy(&min, (float *) dst->op_params + 0, sizeof(float));
14155
+ memcpy(&max, (float *) dst->op_params + 1, sizeof(float));
13658
14156
 
13659
- const int ne0 = src0->ne[0]; // all_seq_len = n_past + ne1
13660
- const int ne1 = src0->ne[1]; // seq_len_without_past
13661
- const int ne2 = src0->ne[2]; // n_head -> this is k
13662
- //const int ne3 = src0->ne[3]; // 1 -> bsz
14157
+ const int ith = params->ith;
14158
+ const int nth = params->nth;
13663
14159
 
13664
14160
  const int n = ggml_nrows(src0);
13665
- const int ne2_ne3 = n/ne1; // ne2*ne3
13666
-
13667
- const int nb0 = src0->nb[0];
13668
- const int nb1 = src0->nb[1];
13669
- const int nb2 = src0->nb[2];
13670
- //const int nb3 = src0->nb[3];
13671
-
13672
- GGML_ASSERT(nb0 == sizeof(ggml_fp16_t));
13673
- //GGML_ASSERT(ne1 + n_past == ne0); (void) n_past;
13674
- GGML_ASSERT(n_head == ne2);
13675
-
13676
- // add alibi to src0 (KQ_scaled)
13677
- const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
14161
+ const int nc = src0->ne[0];
13678
14162
 
13679
- const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
13680
- const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
14163
+ const size_t nb00 = src0->nb[0];
14164
+ const size_t nb01 = src0->nb[1];
13681
14165
 
13682
- for (int k = 0; k < ne2_ne3; k++) {
13683
- // TODO: k*nb2 or k*nb3
13684
- float m_k;
14166
+ const size_t nb0 = dst->nb[0];
14167
+ const size_t nb1 = dst->nb[1];
13685
14168
 
13686
- if (k < n_heads_log2_floor) {
13687
- m_k = powf(m0, k + 1);
13688
- } else {
13689
- m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1);
13690
- }
14169
+ GGML_ASSERT( nb0 == sizeof(float));
14170
+ GGML_ASSERT(nb00 == sizeof(float));
13691
14171
 
13692
- for (int i = 0; i < ne0; i++) {
13693
- for (int j = 0; j < ne1; j++) {
13694
- ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i*nb0 + j*nb1 + k*nb2);
13695
- float * pdst = (float *)((char *) dst->data + i*nb0 + j*nb1 + k*nb2);
14172
+ for (int j = ith; j < n; j += nth) {
14173
+ float * dst_ptr = (float *) ((char *) dst->data + j*nb1);
14174
+ float * src0_ptr = (float *) ((char *) src0->data + j*nb01);
13696
14175
 
13697
- // we return F32
13698
- pdst[0] = i * m_k + GGML_FP16_TO_FP32(src[0]);
13699
- }
14176
+ for (int i = 0; i < nc; i++) {
14177
+ dst_ptr[i] = MAX(MIN(src0_ptr[i], max), min);
13700
14178
  }
13701
14179
  }
13702
14180
  }
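The clamp kernel reads its bounds from op_params[0] and op_params[1]; on the graph-building side those are normally written by the public clamp op. A hedged sketch of the producing call, assuming ggml_clamp from the ggml.h in this tree and illustrative bounds:

    #include "ggml.h"

    // sketch: stores min at op_params[0] and max at op_params[1], which the kernel above memcpy's back out
    static struct ggml_tensor * clamp_example(struct ggml_context * ctx, struct ggml_tensor * x) {
        return ggml_clamp(ctx, x, -30.0f, 30.0f);
    }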
13703
14181
 
13704
- static void ggml_compute_forward_alibi(
14182
+ static void ggml_compute_forward_clamp(
13705
14183
  const struct ggml_compute_params * params,
13706
14184
  struct ggml_tensor * dst) {
13707
14185
 
13708
14186
  const struct ggml_tensor * src0 = dst->src[0];
13709
14187
 
13710
14188
  switch (src0->type) {
13711
- case GGML_TYPE_F16:
13712
- {
13713
- ggml_compute_forward_alibi_f16(params, dst);
13714
- } break;
13715
14189
  case GGML_TYPE_F32:
13716
14190
  {
13717
- ggml_compute_forward_alibi_f32(params, dst);
14191
+ ggml_compute_forward_clamp_f32(params, dst);
13718
14192
  } break;
14193
+ case GGML_TYPE_F16:
13719
14194
  case GGML_TYPE_BF16:
13720
14195
  case GGML_TYPE_Q4_0:
13721
14196
  case GGML_TYPE_Q4_1:
@@ -13750,128 +14225,38 @@ static void ggml_compute_forward_alibi(
13750
14225
  }
13751
14226
  }
13752
14227
 
13753
- // ggml_compute_forward_clamp
13754
-
13755
- static void ggml_compute_forward_clamp_f32(
13756
- const struct ggml_compute_params * params,
13757
- struct ggml_tensor * dst) {
14228
+ // ggml_compute_forward_rope
13758
14229
 
13759
- const struct ggml_tensor * src0 = dst->src[0];
14230
+ static float rope_yarn_ramp(const float low, const float high, const int i0) {
14231
+ const float y = (i0 / 2 - low) / MAX(0.001f, high - low);
14232
+ return 1 - MIN(1, MAX(0, y));
14233
+ }
13760
14234
 
13761
- assert(params->ith == 0);
14235
+ // YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
14236
+ // MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
14237
+ static void rope_yarn(
14238
+ float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale,
14239
+ float * cos_theta, float * sin_theta
14240
+ ) {
14241
+ // Get n-d rotational scaling corrected for extrapolation
14242
+ float theta_interp = freq_scale * theta_extrap;
14243
+ float theta = theta_interp;
14244
+ if (ext_factor != 0.0f) {
14245
+ float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
14246
+ theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
13762
14247
 
13763
- if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
13764
- return;
14248
+ // Get n-d magnitude scaling corrected for interpolation
14249
+ mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale);
13765
14250
  }
14251
+ *cos_theta = cosf(theta) * mscale;
14252
+ *sin_theta = sinf(theta) * mscale;
14253
+ }
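rope_yarn blends the interpolated angle freq_scale * theta_extrap with the extrapolated one according to the ramp, and compensates the magnitude with mscale *= 1 + 0.1 * ln(1 / freq_scale). A standalone numerical sketch of that blend for a single rotary dimension, with illustrative inputs:

    #include <math.h>
    #include <stdio.h>

    #define MIN(a, b) ((a) < (b) ? (a) : (b))
    #define MAX(a, b) ((a) > (b) ? (a) : (b))

    static float ramp(float low, float high, int i0) {
        const float y = (i0 / 2 - low) / MAX(0.001f, high - low);
        return 1 - MIN(1, MAX(0, y));
    }

    int main(void) {
        // illustrative inputs for one rotary dimension
        const float theta_extrap = 0.8f;            // original (extrapolated) angle
        const float freq_scale   = 0.25f;           // context stretched 4x
        const float ext_factor   = 1.0f;
        const float corr_dims[2] = {8.0f, 24.0f};
        const int   i0           = 20;

        const float theta_interp = freq_scale * theta_extrap;
        float theta  = theta_interp;
        float mscale = 1.0f;                        // attn_factor in the kernel

        if (ext_factor != 0.0f) {
            const float ramp_mix = ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
            theta   = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
            mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale);
        }

        printf("theta = %f, cos = %f, sin = %f\n", theta, cosf(theta) * mscale, sinf(theta) * mscale);
        return 0;
    }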
13766
14254
 
13767
- float min;
13768
- float max;
13769
- memcpy(&min, (float *) dst->op_params + 0, sizeof(float));
13770
- memcpy(&max, (float *) dst->op_params + 1, sizeof(float));
13771
-
13772
- const int ith = params->ith;
13773
- const int nth = params->nth;
13774
-
13775
- const int n = ggml_nrows(src0);
13776
- const int nc = src0->ne[0];
13777
-
13778
- const size_t nb00 = src0->nb[0];
13779
- const size_t nb01 = src0->nb[1];
13780
-
13781
- const size_t nb0 = dst->nb[0];
13782
- const size_t nb1 = dst->nb[1];
13783
-
13784
- GGML_ASSERT( nb0 == sizeof(float));
13785
- GGML_ASSERT(nb00 == sizeof(float));
13786
-
13787
- for (int j = ith; j < n; j += nth) {
13788
- float * dst_ptr = (float *) ((char *) dst->data + j*nb1);
13789
- float * src0_ptr = (float *) ((char *) src0->data + j*nb01);
13790
-
13791
- for (int i = 0; i < nc; i++) {
13792
- dst_ptr[i] = MAX(MIN(src0_ptr[i], max), min);
13793
- }
13794
- }
13795
- }
13796
-
13797
- static void ggml_compute_forward_clamp(
13798
- const struct ggml_compute_params * params,
13799
- struct ggml_tensor * dst) {
13800
-
13801
- const struct ggml_tensor * src0 = dst->src[0];
13802
-
13803
- switch (src0->type) {
13804
- case GGML_TYPE_F32:
13805
- {
13806
- ggml_compute_forward_clamp_f32(params, dst);
13807
- } break;
13808
- case GGML_TYPE_F16:
13809
- case GGML_TYPE_BF16:
13810
- case GGML_TYPE_Q4_0:
13811
- case GGML_TYPE_Q4_1:
13812
- case GGML_TYPE_Q5_0:
13813
- case GGML_TYPE_Q5_1:
13814
- case GGML_TYPE_Q8_0:
13815
- case GGML_TYPE_Q8_1:
13816
- case GGML_TYPE_Q2_K:
13817
- case GGML_TYPE_Q3_K:
13818
- case GGML_TYPE_Q4_K:
13819
- case GGML_TYPE_Q5_K:
13820
- case GGML_TYPE_Q6_K:
13821
- case GGML_TYPE_IQ2_XXS:
13822
- case GGML_TYPE_IQ2_XS:
13823
- case GGML_TYPE_IQ3_XXS:
13824
- case GGML_TYPE_IQ1_S:
13825
- case GGML_TYPE_IQ1_M:
13826
- case GGML_TYPE_IQ4_NL:
13827
- case GGML_TYPE_IQ4_XS:
13828
- case GGML_TYPE_IQ3_S:
13829
- case GGML_TYPE_IQ2_S:
13830
- case GGML_TYPE_Q8_K:
13831
- case GGML_TYPE_I8:
13832
- case GGML_TYPE_I16:
13833
- case GGML_TYPE_I32:
13834
- case GGML_TYPE_I64:
13835
- case GGML_TYPE_F64:
13836
- case GGML_TYPE_COUNT:
13837
- {
13838
- GGML_ASSERT(false);
13839
- } break;
13840
- }
13841
- }
13842
-
13843
- // ggml_compute_forward_rope
13844
-
13845
- static float rope_yarn_ramp(const float low, const float high, const int i0) {
13846
- const float y = (i0 / 2 - low) / MAX(0.001f, high - low);
13847
- return 1 - MIN(1, MAX(0, y));
13848
- }
13849
-
13850
- // YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
13851
- // MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
13852
- static void rope_yarn(
13853
- float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale,
13854
- float * cos_theta, float * sin_theta
13855
- ) {
13856
- // Get n-d rotational scaling corrected for extrapolation
13857
- float theta_interp = freq_scale * theta_extrap;
13858
- float theta = theta_interp;
13859
- if (ext_factor != 0.0f) {
13860
- float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
13861
- theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
13862
-
13863
- // Get n-d magnitude scaling corrected for interpolation
13864
- mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale);
13865
- }
13866
- *cos_theta = cosf(theta) * mscale;
13867
- *sin_theta = sinf(theta) * mscale;
13868
- }
13869
-
13870
- // Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
13871
- // `corr_dim(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
13872
- static float ggml_rope_yarn_corr_dim(int n_dims, int n_orig_ctx, float n_rot, float base) {
13873
- return n_dims * logf(n_orig_ctx / (n_rot * 2 * (float)M_PI)) / (2 * logf(base));
13874
- }
14255
+ // Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
14256
+ // `corr_dim(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
14257
+ static float ggml_rope_yarn_corr_dim(int n_dims, int n_orig_ctx, float n_rot, float base) {
14258
+ return n_dims * logf(n_orig_ctx / (n_rot * 2 * (float)M_PI)) / (2 * logf(base));
14259
+ }
13875
14260
 
13876
14261
  static void ggml_rope_cache_init(
13877
14262
  float theta_base, float freq_scale, float corr_dims[2], int64_t ne0, float ext_factor, float mscale,
@@ -13905,6 +14290,7 @@ static void ggml_compute_forward_rope_f32(
13905
14290
 
13906
14291
  const struct ggml_tensor * src0 = dst->src[0];
13907
14292
  const struct ggml_tensor * src1 = dst->src[1];
14293
+ const struct ggml_tensor * src2 = dst->src[2];
13908
14294
 
13909
14295
  if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
13910
14296
  return;
@@ -13964,6 +14350,17 @@ static void ggml_compute_forward_rope_f32(
13964
14350
  const bool is_neox = mode & 2;
13965
14351
  const bool is_glm = mode & 4;
13966
14352
 
14353
+ const float * freq_factors = NULL;
14354
+ if (is_neox) {
14355
+ if (src2 != NULL) {
14356
+ GGML_ASSERT(src2->type == GGML_TYPE_F32);
14357
+ GGML_ASSERT(src2->ne[0] >= n_dims / 2);
14358
+ freq_factors = (const float *) src2->data;
14359
+ }
14360
+ } else {
14361
+ GGML_ASSERT(src2 == NULL && "TODO: freq_factors not implemented for !is_neox");
14362
+ }
14363
+
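The optional src2 tensor supplies per-dimension frequency factors (neox mode only); further down this hunk each pair of rotary dims ic divides its base angle by freq_factors[ic/2] before the YaRN blend, and a missing tensor behaves like all ones. A tiny standalone sketch of that adjustment with made-up factors:

    #include <stdio.h>

    int main(void) {
        const int   n_dims          = 8;
        const float theta_base      = 0.5f;                      // illustrative angle for one position
        const float freq_factors[4] = {1.0f, 1.0f, 4.0f, 8.0f};  // e.g. long-context scaling of later dims
        const float * ff            = freq_factors;              // NULL would mean "all ones"

        for (int ic = 0; ic < n_dims; ic += 2) {
            const float freq_factor = ff ? ff[ic/2] : 1.0f;
            printf("ic=%d: theta = %f\n", ic, theta_base / freq_factor);
        }
        return 0;
    }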
13967
14364
  // backward process uses inverse rotation by cos and sin.
13968
14365
  // cos and sin build a rotation matrix, where the inverse is the transpose.
13969
14366
  // this essentially just switches the sign of sin.
@@ -14040,10 +14437,11 @@ static void ggml_compute_forward_rope_f32(
14040
14437
 
14041
14438
  // simplified from `(ib * n_dims + ic) * inv_ndims`
14042
14439
  float cur_rot = inv_ndims * ic - ib;
14440
+ float freq_factor = freq_factors ? freq_factors[ic/2] : 1.0f;
14043
14441
 
14044
14442
  float cos_theta, sin_theta;
14045
14443
  rope_yarn(
14046
- theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
14444
+ theta_base/freq_factor, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
14047
14445
  &cos_theta, &sin_theta
14048
14446
  );
14049
14447
  sin_theta *= sin_sign;
@@ -14076,6 +14474,7 @@ static void ggml_compute_forward_rope_f32(
14076
14474
  }
14077
14475
  }
14078
14476
 
14477
+ // TODO: deduplicate f16/f32 code
14079
14478
  static void ggml_compute_forward_rope_f16(
14080
14479
  const struct ggml_compute_params * params,
14081
14480
  struct ggml_tensor * dst,
@@ -14083,6 +14482,7 @@ static void ggml_compute_forward_rope_f16(
14083
14482
 
14084
14483
  const struct ggml_tensor * src0 = dst->src[0];
14085
14484
  const struct ggml_tensor * src1 = dst->src[1];
14485
+ const struct ggml_tensor * src2 = dst->src[2];
14086
14486
 
14087
14487
  if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
14088
14488
  return;
@@ -14135,6 +14535,17 @@ static void ggml_compute_forward_rope_f16(
14135
14535
  const bool is_neox = mode & 2;
14136
14536
  const bool is_glm = mode & 4;
14137
14537
 
14538
+ const float * freq_factors = NULL;
14539
+ if (is_neox) {
14540
+ if (src2 != NULL) {
14541
+ GGML_ASSERT(src2->type == GGML_TYPE_F32);
14542
+ GGML_ASSERT(src2->ne[0] >= n_dims / 2);
14543
+ freq_factors = (const float *) src2->data;
14544
+ }
14545
+ } else {
14546
+ GGML_ASSERT(src2 == NULL && "TODO: freq_factors not implemented for !is_neox");
14547
+ }
14548
+
14138
14549
  // backward process uses inverse rotation by cos and sin.
14139
14550
  // cos and sin build a rotation matrix, where the inverse is the transpose.
14140
14551
  // this essentially just switches the sign of sin.
@@ -14207,10 +14618,11 @@ static void ggml_compute_forward_rope_f16(
14207
14618
 
14208
14619
  // simplified from `(ib * n_dims + ic) * inv_ndims`
14209
14620
  float cur_rot = inv_ndims * ic - ib;
14621
+ float freq_factor = freq_factors ? freq_factors[ic/2] : 1.0f;
14210
14622
 
14211
14623
  float cos_theta, sin_theta;
14212
14624
  rope_yarn(
14213
- theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
14625
+ theta_base/freq_factor, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
14214
14626
  &cos_theta, &sin_theta
14215
14627
  );
14216
14628
  sin_theta *= sin_sign;
@@ -14972,25 +15384,28 @@ static void ggml_compute_forward_upscale_f32(
14972
15384
  return;
14973
15385
  }
14974
15386
 
14975
- GGML_ASSERT(src0->nb[0] == sizeof(float));
15387
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
14976
15388
 
14977
15389
  const int ith = params->ith;
14978
15390
  const int nth = params->nth;
14979
15391
 
14980
15392
  GGML_TENSOR_UNARY_OP_LOCALS
14981
15393
 
14982
- const int scale_factor = dst->op_params[0];
15394
+ const float sf0 = (float)ne0/src0->ne[0];
15395
+ const float sf1 = (float)ne1/src0->ne[1];
15396
+ const float sf2 = (float)ne2/src0->ne[2];
15397
+ const float sf3 = (float)ne3/src0->ne[3];
14983
15398
 
14984
15399
  // TODO: optimize
14985
15400
 
14986
15401
  for (int64_t i3 = 0; i3 < ne3; i3++) {
14987
- const int64_t i03 = i3;
15402
+ const int64_t i03 = i3 / sf3;
14988
15403
  for (int64_t i2 = ith; i2 < ne2; i2 += nth) {
14989
- const int64_t i02 = i2;
15404
+ const int64_t i02 = i2 / sf2;
14990
15405
  for (int64_t i1 = 0; i1 < ne1; i1++) {
14991
- const int64_t i01 = i1 / scale_factor;
15406
+ const int64_t i01 = i1 / sf1;
14992
15407
  for (int64_t i0 = 0; i0 < ne0; i0++) {
14993
- const int64_t i00 = i0 / scale_factor;
15408
+ const int64_t i00 = i0 / sf0;
14994
15409
 
14995
15410
  const float * x = (float *)((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
14996
15411
  float * y = (float *)((char *) dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
@@ -15020,6 +15435,7 @@ static void ggml_compute_forward_upscale(
15020
15435
  }
15021
15436
  }
15022
15437
 
15438
+
15023
15439
  // ggml_compute_forward_pad
15024
15440
 
15025
15441
  static void ggml_compute_forward_pad_f32(
@@ -15206,481 +15622,36 @@ static void ggml_compute_forward_argsort_f32(
15206
15622
 
15207
15623
  for (int64_t i = ith; i < nr; i += nth) {
15208
15624
  int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1);
15209
- const float * src_data = (float *)((char *) src0->data + i*nb01);
15210
-
15211
- for (int64_t j = 0; j < ne0; j++) {
15212
- dst_data[j] = j;
15213
- }
15214
-
15215
- // C doesn't have a functional sort, so we do a bubble sort instead
15216
- for (int64_t j = 0; j < ne0; j++) {
15217
- for (int64_t k = j + 1; k < ne0; k++) {
15218
- if ((order == GGML_SORT_ORDER_ASC && src_data[dst_data[j]] > src_data[dst_data[k]]) ||
15219
- (order == GGML_SORT_ORDER_DESC && src_data[dst_data[j]] < src_data[dst_data[k]])) {
15220
- int32_t tmp = dst_data[j];
15221
- dst_data[j] = dst_data[k];
15222
- dst_data[k] = tmp;
15223
- }
15224
- }
15225
- }
15226
- }
15227
- }
15228
-
15229
- static void ggml_compute_forward_argsort(
15230
- const struct ggml_compute_params * params,
15231
- struct ggml_tensor * dst) {
15232
-
15233
- const struct ggml_tensor * src0 = dst->src[0];
15234
-
15235
- switch (src0->type) {
15236
- case GGML_TYPE_F32:
15237
- {
15238
- ggml_compute_forward_argsort_f32(params, dst);
15239
- } break;
15240
- default:
15241
- {
15242
- GGML_ASSERT(false);
15243
- } break;
15244
- }
15245
- }
15246
-
15247
- // ggml_compute_forward_flash_attn
15248
-
15249
- static void ggml_compute_forward_flash_attn_f32(
15250
- const struct ggml_compute_params * params,
15251
- const bool masked,
15252
- struct ggml_tensor * dst) {
15253
-
15254
- const struct ggml_tensor * q = dst->src[0];
15255
- const struct ggml_tensor * k = dst->src[1];
15256
- const struct ggml_tensor * v = dst->src[2];
15257
-
15258
- int64_t t0 = ggml_perf_time_us();
15259
- UNUSED(t0);
15260
-
15261
- GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
15262
- GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
15263
- GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
15264
- GGML_TENSOR_LOCALS(size_t, nbk, k, nb)
15265
- GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
15266
- GGML_TENSOR_LOCALS(size_t, nbv, v, nb)
15267
- GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
15268
- GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
15269
-
15270
- const int ith = params->ith;
15271
- const int nth = params->nth;
15272
-
15273
- const int64_t D = neq0;
15274
- const int64_t N = neq1;
15275
- const int64_t P = nek1 - N;
15276
- const int64_t M = P + N;
15277
-
15278
- const int Mup = ggml_up(M, GGML_SOFT_MAX_UNROLL);
15279
-
15280
- GGML_ASSERT(ne0 == D);
15281
- GGML_ASSERT(ne1 == N);
15282
- GGML_ASSERT(P >= 0);
15283
-
15284
- GGML_ASSERT(nbq0 == sizeof(float));
15285
- GGML_ASSERT(nbk0 == sizeof(float));
15286
- GGML_ASSERT(nbv0 == sizeof(float));
15287
-
15288
- GGML_ASSERT(neq0 == D);
15289
- GGML_ASSERT(nek0 == D);
15290
- GGML_ASSERT(nev1 == D);
15291
-
15292
- GGML_ASSERT(neq1 == N);
15293
- GGML_ASSERT(nek1 == N + P);
15294
- GGML_ASSERT(nev1 == D);
15295
-
15296
- // dst cannot be transposed or permuted
15297
- GGML_ASSERT(nb0 == sizeof(float));
15298
- GGML_ASSERT(nb0 <= nb1);
15299
- GGML_ASSERT(nb1 <= nb2);
15300
- GGML_ASSERT(nb2 <= nb3);
15301
-
15302
- if (params->type == GGML_TASK_TYPE_INIT) {
15303
- return;
15304
- }
15305
-
15306
- if (params->type == GGML_TASK_TYPE_FINALIZE) {
15307
- return;
15308
- }
15309
-
15310
- // parallelize by q rows using ggml_vec_dot_f32
15311
-
15312
- // total rows in q
15313
- const int nr = neq1*neq2*neq3;
15314
-
15315
- // rows per thread
15316
- const int dr = (nr + nth - 1)/nth;
15317
-
15318
- // row range for this thread
15319
- const int ir0 = dr*ith;
15320
- const int ir1 = MIN(ir0 + dr, nr);
15321
-
15322
- const float scale = 1.0f/sqrtf(D);
15323
-
15324
- //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale);
15325
-
15326
- for (int ir = ir0; ir < ir1; ++ir) {
15327
- // q indices
15328
- const int iq3 = ir/(neq2*neq1);
15329
- const int iq2 = (ir - iq3*neq2*neq1)/neq1;
15330
- const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
15331
-
15332
- float * S = (float *) params->wdata + ith*(Mup + CACHE_LINE_SIZE_F32);
15333
-
15334
- for (int i = M; i < Mup; ++i) {
15335
- S[i] = -INFINITY;
15336
- }
15337
-
15338
- const int64_t masked_begin = masked ? (P + iq1 + 1) : M;
15339
- for (int64_t ic = 0; ic < masked_begin; ++ic) {
15340
- // k indices
15341
- const int ik3 = iq3;
15342
- const int ik2 = iq2 % nek2;
15343
- const int ik1 = ic;
15344
-
15345
- // S indices
15346
- const int i1 = ik1;
15347
-
15348
- ggml_vec_dot_f32(neq0,
15349
- S + i1, 0,
15350
- (float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), 0,
15351
- (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)), 0, 1);
15352
- }
15353
-
15354
- // scale
15355
- ggml_vec_scale_f32(masked_begin, S, scale);
15356
-
15357
- for (int64_t i = masked_begin; i < M; i++) {
15358
- S[i] = -INFINITY;
15359
- }
15360
-
15361
- // softmax
15362
- // exclude known -INF S[..] values from max and loop
15363
- // dont forget to set their SW values to zero
15364
- {
15365
- float max = -INFINITY;
15366
- ggml_vec_max_f32(masked_begin, &max, S);
15367
-
15368
- ggml_float sum = 0.0;
15369
- {
15370
- #ifdef GGML_SOFT_MAX_ACCELERATE
15371
- max = -max;
15372
- vDSP_vsadd(S, 1, &max, S, 1, Mup);
15373
- vvexpf(S, S, &Mup);
15374
- ggml_vec_sum_f32(Mup, &sum, S);
15375
- #else
15376
- uint16_t scvt[GGML_SOFT_MAX_UNROLL]; UNUSED(scvt);
15377
- ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 };
15378
-
15379
- for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) {
15380
- if (i >= masked_begin) {
15381
- break;
15382
- }
15383
- float * SS = S + i;
15384
-
15385
- for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) {
15386
- if (i + j >= masked_begin) {
15387
- break;
15388
- } else if (SS[j] == -INFINITY) {
15389
- SS[j] = 0.0f;
15390
- } else {
15391
- #ifndef GGML_FLASH_ATTN_EXP_FP16
15392
- const float val = expf(SS[j] - max);
15393
- #else
15394
- ggml_fp16_t s = GGML_FP32_TO_FP16(SS[j] - max);
15395
- memcpy(&scvt[j], &s, sizeof(uint16_t));
15396
- const float val = GGML_FP16_TO_FP32(ggml_table_exp_f16[scvt[j]]);
15397
- #endif
15398
- sump[j] += (ggml_float)val;
15399
- SS[j] = val;
15400
- }
15401
- }
15402
- }
15403
-
15404
- for (int i = 0; i < GGML_SOFT_MAX_UNROLL; i++) {
15405
- sum += sump[i];
15406
- }
15407
- #endif
15408
- }
15409
-
15410
- assert(sum > 0.0);
15411
-
15412
- sum = 1.0/sum;
15413
- ggml_vec_scale_f32(masked_begin, S, sum);
15414
-
15415
- #ifndef NDEBUG
15416
- for (int i = 0; i < masked_begin; ++i) {
15417
- assert(!isnan(S[i]));
15418
- assert(!isinf(S[i]));
15419
- }
15420
- #endif
15421
- }
15422
-
15423
- for (int64_t ic = 0; ic < nev1; ++ic) {
15424
- // dst indices
15425
- const int i1 = iq1;
15426
- const int i2 = iq2;
15427
- const int i3 = iq3;
15428
-
15429
- // v indices
15430
- const int iv2 = iq2 % nev2;
15431
- const int iv3 = iq3;
15432
-
15433
- ggml_vec_dot_f32(masked_begin,
15434
- (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), 0,
15435
- (float *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)), 0,
15436
- S, 0, 1);
15437
- }
15438
- }
15439
- }
15440
-
15441
- static void ggml_compute_forward_flash_attn_f16(
15442
- const struct ggml_compute_params * params,
15443
- const bool masked,
15444
- struct ggml_tensor * dst) {
15445
-
15446
- const struct ggml_tensor * q = dst->src[0];
15447
- const struct ggml_tensor * k = dst->src[1];
15448
- const struct ggml_tensor * v = dst->src[2];
15449
-
15450
- int64_t t0 = ggml_perf_time_us();
15451
- UNUSED(t0);
15452
-
15453
- GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
15454
- GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
15455
- GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
15456
- GGML_TENSOR_LOCALS(size_t, nbk, k, nb)
15457
- GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
15458
- GGML_TENSOR_LOCALS(size_t, nbv, v, nb)
15459
- GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
15460
- GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
15461
-
15462
- const int ith = params->ith;
15463
- const int nth = params->nth;
15464
-
15465
- const int64_t D = neq0;
15466
- const int64_t N = neq1;
15467
- const int64_t P = nek1 - N;
15468
- const int64_t M = P + N;
15469
-
15470
- const int Mup = ggml_up(M, GGML_SOFT_MAX_UNROLL);
15471
-
15472
- GGML_ASSERT(ne0 == D);
15473
- GGML_ASSERT(ne1 == N);
15474
- GGML_ASSERT(P >= 0);
15475
-
15476
- GGML_ASSERT(nbq0 == sizeof(ggml_fp16_t));
15477
- GGML_ASSERT(nbk0 == sizeof(ggml_fp16_t));
15478
- GGML_ASSERT(nbv0 == sizeof(ggml_fp16_t));
15479
-
15480
- GGML_ASSERT(neq0 == D);
15481
- GGML_ASSERT(nek0 == D);
15482
- GGML_ASSERT(nev1 == D);
15483
-
15484
- GGML_ASSERT(neq1 == N);
15485
- GGML_ASSERT(nek1 == N + P);
15486
- GGML_ASSERT(nev1 == D);
15487
-
15488
- // dst cannot be transposed or permuted
15489
- GGML_ASSERT(nb0 == sizeof(float));
15490
- GGML_ASSERT(nb0 <= nb1);
15491
- GGML_ASSERT(nb1 <= nb2);
15492
- GGML_ASSERT(nb2 <= nb3);
15493
-
15494
- if (params->type == GGML_TASK_TYPE_INIT) {
15495
- return;
15496
- }
15497
-
15498
- if (params->type == GGML_TASK_TYPE_FINALIZE) {
15499
- return;
15500
- }
15501
-
15502
- // parallelize by q rows using ggml_vec_dot_f32
15503
-
15504
- // total rows in q
15505
- const int nr = neq1*neq2*neq3;
15506
-
15507
- // rows per thread
15508
- const int dr = (nr + nth - 1)/nth;
15509
-
15510
- // row range for this thread
15511
- const int ir0 = dr*ith;
15512
- const int ir1 = MIN(ir0 + dr, nr);
15513
-
15514
- const float scale = 1.0f/sqrtf(D);
15515
-
15516
- //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale);
15517
-
15518
- for (int ir = ir0; ir < ir1; ++ir) {
15519
- // q indices
15520
- const int iq3 = ir/(neq2*neq1);
15521
- const int iq2 = (ir - iq3*neq2*neq1)/neq1;
15522
- const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
15523
-
15524
- float * S = (float *) params->wdata + ith*(2*Mup + CACHE_LINE_SIZE_F32);
15525
-
15526
- for (int i = M; i < Mup; ++i) {
15527
- S[i] = -INFINITY;
15528
- }
15529
-
15530
- if (GGML_VEC_DOT_UNROLL > 2 || nek1 % GGML_VEC_DOT_UNROLL != 0) {
15531
- for (int64_t ic = 0; ic < nek1; ++ic) {
15532
- // k indices
15533
- const int ik3 = iq3;
15534
- const int ik2 = iq2 % nek2;
15535
- const int ik1 = ic;
15536
-
15537
- // S indices
15538
- const int i1 = ik1;
15539
-
15540
- ggml_vec_dot_f16(neq0,
15541
- S + i1, 0,
15542
- (ggml_fp16_t *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), 0,
15543
- (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)), 0, 1);
15544
- }
15545
- } else {
15546
- for (int64_t ic = 0; ic < nek1; ic += GGML_VEC_DOT_UNROLL) {
15547
- // k indices
15548
- const int ik3 = iq3;
15549
- const int ik2 = iq2 % nek2;
15550
- const int ik1 = ic;
15551
-
15552
- // S indices
15553
- const int i1 = ik1;
15554
-
15555
- ggml_vec_dot_f16_unroll(neq0, nbk1,
15556
- S + i1,
15557
- ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)),
15558
- (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
15559
- }
15560
- }
15561
-
15562
- // scale
15563
- ggml_vec_scale_f32(nek1, S, scale);
15564
-
15565
- if (masked) {
15566
- for (int64_t i = P; i < M; i++) {
15567
- if (i > P + iq1) {
15568
- S[i] = -INFINITY;
15569
- }
15570
- }
15571
- }
15572
-
15573
- // softmax
15574
- // todo: exclude known -INF S[..] values from max and loop, assuming their results to be zero.
15575
- // dont forget to set their S values to zero
15576
- {
15577
- float max = -INFINITY;
15578
- ggml_vec_max_f32(M, &max, S);
15579
-
15580
- ggml_float sum = 0.0;
15581
- {
15582
- #ifdef GGML_SOFT_MAX_ACCELERATE
15583
- max = -max;
15584
- vDSP_vsadd(S, 1, &max, S, 1, Mup);
15585
- vvexpf(S, S, &Mup);
15586
- ggml_vec_sum_f32(Mup, &sum, S);
15587
- #else
15588
- uint16_t scvt[GGML_SOFT_MAX_UNROLL];
15589
- ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 };
15590
-
15591
- for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) {
15592
- float * SS = S + i;
15593
-
15594
- for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) {
15595
- if (SS[j] == -INFINITY) {
15596
- SS[j] = 0.0f;
15597
- } else {
15598
- ggml_fp16_t s = GGML_FP32_TO_FP16(SS[j] - max);
15599
- memcpy(&scvt[j], &s, sizeof(uint16_t));
15600
- const float val = GGML_FP16_TO_FP32(ggml_table_exp_f16[scvt[j]]);
15601
- sump[j] += (ggml_float)val;
15602
- SS[j] = val;
15603
- }
15604
- }
15605
- }
15606
-
15607
- for (int i = 0; i < GGML_SOFT_MAX_UNROLL; i++) {
15608
- sum += sump[i];
15609
- }
15610
- #endif
15611
- }
15612
-
15613
- assert(sum > 0.0);
15614
-
15615
- sum = 1.0/sum;
15616
- ggml_vec_scale_f32(M, S, sum);
15617
-
15618
- #ifndef NDEBUG
15619
- for (int i = 0; i < M; ++i) {
15620
- assert(!isnan(S[i]));
15621
- assert(!isinf(S[i]));
15622
- }
15623
- #endif
15624
- }
15625
-
15626
- ggml_fp16_t * S16 = (ggml_fp16_t *) ((float *) params->wdata + ith*(2*Mup + CACHE_LINE_SIZE_F32) + Mup);
15625
+ const float * src_data = (float *)((char *) src0->data + i*nb01);
15627
15626
 
15628
- for (int64_t i = 0; i < M; i++) {
15629
- S16[i] = GGML_FP32_TO_FP16(S[i]);
15627
+ for (int64_t j = 0; j < ne0; j++) {
15628
+ dst_data[j] = j;
15630
15629
  }
15631
15630
 
15632
- // todo: exclude known zero S[..] values from dot (reducing nev0 and increasing begin of v and S16).
15633
- if (GGML_VEC_DOT_UNROLL == 1 || (nev1 % GGML_VEC_DOT_UNROLL != 0)) {
15634
- for (int64_t ic = 0; ic < nev1; ++ic) {
15635
- // dst indices
15636
- const int i1 = iq1;
15637
- const int i2 = iq2;
15638
- const int i3 = iq3;
15639
-
15640
- // v indices
15641
- const int iv2 = iq2 % nev2;
15642
- const int iv3 = iq3;
15643
-
15644
- ggml_vec_dot_f16(nev0,
15645
- (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), 0,
15646
- (ggml_fp16_t *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)), 0,
15647
- S16, 0, 1);
15648
- }
15649
- } else {
15650
- for (int64_t ic = 0; ic < nev1; ic += GGML_VEC_DOT_UNROLL) {
15651
- // dst indices
15652
- const int i1 = iq1;
15653
- const int i2 = iq2;
15654
- const int i3 = iq3;
15655
-
15656
- // v indices
15657
- const int iv2 = iq2 % nev2;
15658
- const int iv3 = iq3;
15659
-
15660
- ggml_vec_dot_f16_unroll(nev0, nbv1,
15661
- (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
15662
- ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)),
15663
- S16);
15631
+ // C doesn't have a functional sort, so we do a bubble sort instead
15632
+ for (int64_t j = 0; j < ne0; j++) {
15633
+ for (int64_t k = j + 1; k < ne0; k++) {
15634
+ if ((order == GGML_SORT_ORDER_ASC && src_data[dst_data[j]] > src_data[dst_data[k]]) ||
15635
+ (order == GGML_SORT_ORDER_DESC && src_data[dst_data[j]] < src_data[dst_data[k]])) {
15636
+ int32_t tmp = dst_data[j];
15637
+ dst_data[j] = dst_data[k];
15638
+ dst_data[k] = tmp;
15639
+ }
15664
15640
  }
15665
15641
  }
15666
15642
  }
15667
15643
  }
15668
15644
 
15669
- static void ggml_compute_forward_flash_attn(
15670
- const struct ggml_compute_params * params,
15671
- const bool masked,
15672
- struct ggml_tensor * dst) {
15645
+ static void ggml_compute_forward_argsort(
15646
+ const struct ggml_compute_params * params,
15647
+ struct ggml_tensor * dst) {
15673
15648
 
15674
- const struct ggml_tensor * q = dst->src[0];
15649
+ const struct ggml_tensor * src0 = dst->src[0];
15675
15650
 
15676
- switch (q->type) {
15677
- case GGML_TYPE_F16:
15678
- {
15679
- ggml_compute_forward_flash_attn_f16(params, masked, dst);
15680
- } break;
15651
+ switch (src0->type) {
15681
15652
  case GGML_TYPE_F32:
15682
15653
  {
15683
- ggml_compute_forward_flash_attn_f32(params, masked, dst);
15654
+ ggml_compute_forward_argsort_f32(params, dst);
15684
15655
  } break;
15685
15656
  default:
15686
15657
  {
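
The argsort path added above sorts row indices rather than values: dst_data starts out as the identity permutation and is reordered so that src_data[dst_data[0]] .. src_data[dst_data[ne0-1]] ends up monotone. A minimal standalone sketch of the same index-sorting idea (plain C, illustrative names, not the library code):

    #include <stdint.h>
    #include <stdio.h>

    // Sort indices so that src[idx[0]] <= src[idx[1]] <= ... (ascending argsort),
    // mirroring the bubble sort over dst_data in the hunk above.
    static void argsort_f32_asc(const float * src, int32_t * idx, int64_t n) {
        for (int64_t j = 0; j < n; j++) {
            idx[j] = (int32_t) j;              // identity permutation
        }
        for (int64_t j = 0; j < n; j++) {
            for (int64_t k = j + 1; k < n; k++) {
                if (src[idx[j]] > src[idx[k]]) {
                    int32_t tmp = idx[j];      // swap indices, not values
                    idx[j] = idx[k];
                    idx[k] = tmp;
                }
            }
        }
    }

    int main(void) {
        const float row[5] = { 0.3f, -1.0f, 2.5f, 0.0f, 1.0f };
        int32_t idx[5];
        argsort_f32_asc(row, idx, 5);
        for (int i = 0; i < 5; i++) {
            printf("%d ", idx[i]);             // prints: 1 3 0 4 2
        }
        printf("\n");
        return 0;
    }
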
@@ -15719,9 +15690,10 @@ static void ggml_compute_forward_flash_attn_ext_f16(
15719
15690
  GGML_ASSERT(ne0 == D);
15720
15691
  GGML_ASSERT(ne2 == N);
15721
15692
 
15722
- GGML_ASSERT(nbq0 == sizeof(float));
15723
- GGML_ASSERT(nbk0 == sizeof(ggml_fp16_t));
15724
- GGML_ASSERT(nbv0 == sizeof(ggml_fp16_t));
15693
+ // input tensor rows must be contiguous
15694
+ GGML_ASSERT(nbq0 == ggml_type_size(q->type));
15695
+ GGML_ASSERT(nbk0 == ggml_type_size(k->type));
15696
+ GGML_ASSERT(nbv0 == ggml_type_size(v->type));
15725
15697
 
15726
15698
  GGML_ASSERT(neq0 == D);
15727
15699
  GGML_ASSERT(nek0 == D);
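
The asserts in the hunk above are generalized from hard-coded sizeof(float)/sizeof(ggml_fp16_t) to ggml_type_size(...): the kernel now only requires that each input row be stored contiguously, whatever the element type. A toy illustration of that condition (hypothetical struct, not ggml's tensor type):

    #include <stddef.h>
    #include <stdint.h>
    #include <stdio.h>

    // Toy 2-D tensor descriptor: ne = extents, nb = strides in bytes
    // (nb[0] is the byte step between consecutive elements of a row).
    struct toy_tensor {
        int64_t ne[2];
        size_t  nb[2];
        size_t  type_size;   // bytes per element for this tensor's type
    };

    // Rows are contiguous exactly when the innermost stride equals the element
    // size -- the generalized form of the asserts in the hunk above.
    static int rows_are_contiguous(const struct toy_tensor * t) {
        return t->nb[0] == t->type_size;
    }

    int main(void) {
        struct toy_tensor q  = { { 128, 32 }, { sizeof(float), 128*sizeof(float) }, sizeof(float) };
        struct toy_tensor qt = { { 32, 128 }, { 128*sizeof(float), sizeof(float) }, sizeof(float) }; // transposed view
        printf("q  contiguous rows: %d\n", rows_are_contiguous(&q));   // 1
        printf("qt contiguous rows: %d\n", rows_are_contiguous(&qt));  // 0
        return 0;
    }
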
@@ -15763,8 +15735,22 @@ static void ggml_compute_forward_flash_attn_ext_f16(
15763
15735
  const int ir0 = dr*ith;
15764
15736
  const int ir1 = MIN(ir0 + dr, nr);
15765
15737
 
15766
- float scale = 1.0f;
15767
- memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
15738
+ float scale = 1.0f;
15739
+ float max_bias = 0.0f;
15740
+
15741
+ memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
15742
+ memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
15743
+
15744
+ const uint32_t n_head = neq2;
15745
+ const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
15746
+
15747
+ const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
15748
+ const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
15749
+
15750
+ enum ggml_type const k_vec_dot_type = type_traits[k->type].vec_dot_type;
15751
+ ggml_from_float_t const q_to_vec_dot = type_traits[k_vec_dot_type].from_float;
15752
+ ggml_vec_dot_t const kq_vec_dot = type_traits[k->type].vec_dot;
15753
+ ggml_to_float_t const v_to_float = type_traits[v->type].to_float;
15768
15754
 
15769
15755
  // loop over n_batch and n_head
15770
15756
  for (int ir = ir0; ir < ir1; ++ir) {
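
The m0/m1 values introduced above feed the per-head ALiBi slope used later in the loop: heads below the largest power of two get successive powers of m0, the remainder get odd powers of m1. A self-contained sketch of that slope schedule (illustrative helper, not the library code):

    #include <math.h>
    #include <stdint.h>
    #include <stdio.h>

    // ALiBi head slopes in the same scheme as the hunk above.
    static float alibi_slope(float max_bias, uint32_t n_head, uint32_t h) {
        if (max_bias <= 0.0f) {
            return 1.0f;                       // bias disabled
        }
        const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
        const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
        const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
        return h < n_head_log2 ? powf(m0, h + 1)
                               : powf(m1, 2*(h - n_head_log2) + 1);
    }

    int main(void) {
        const uint32_t n_head = 12;            // deliberately not a power of two
        for (uint32_t h = 0; h < n_head; h++) {
            printf("head %2u: slope %.5f\n", h, alibi_slope(8.0f, n_head, h));
        }
        return 0;
    }
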
@@ -15773,14 +15759,22 @@ static void ggml_compute_forward_flash_attn_ext_f16(
15773
15759
  const int iq2 = (ir - iq3*neq2*neq1)/neq1;
15774
15760
  const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
15775
15761
 
15776
- float S = 0.0f;
15777
- float M = -INFINITY;
15762
+ const uint32_t h = iq2; // head index
15763
+ const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
15764
+
15765
+ float S = 0.0f; // sum
15766
+ float M = -INFINITY; // maximum KQ value
15778
15767
 
15779
- float * V32 = (float *) params->wdata + ith*(2*D + CACHE_LINE_SIZE_F32);
15780
- ggml_fp16_t * Q16 = (ggml_fp16_t *) (V32); // reuse memory
15781
- ggml_fp16_t * V16 = (ggml_fp16_t *) (V32 + D);
15768
+ float * VKQ32 = (float *) params->wdata + ith*(3*D + CACHE_LINE_SIZE_F32); // FP32 VKQ accumulator
15769
+ float * V32 = (VKQ32 + 1*D); // (temporary) FP32 V buffer
15770
+ ggml_fp16_t * VKQ16 = (ggml_fp16_t *) (VKQ32 + 1*D); // (temporary) FP16 VKQ accumulator
15771
+ ggml_fp16_t * Q_q = (ggml_fp16_t *) (VKQ32 + 2*D); // (temporary) buffer for Q converted to quantized/FP16
15782
15772
 
15783
- memset(V16, 0, D*sizeof(ggml_fp16_t));
15773
+ if (v->type == GGML_TYPE_F16) {
15774
+ memset(VKQ16, 0, D*sizeof(ggml_fp16_t));
15775
+ } else {
15776
+ memset(VKQ32, 0, D*sizeof(float));
15777
+ }
15784
15778
 
15785
15779
  const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL;
15786
15780
 
@@ -15792,61 +15786,79 @@ static void ggml_compute_forward_flash_attn_ext_f16(
15792
15786
  const int iv3 = iq3 / rv3;
15793
15787
  const int iv2 = iq2 / rv2;
15794
15788
 
15789
+ const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3));
15790
+ q_to_vec_dot(pq, Q_q, D);
15791
+
15795
15792
  // online softmax / attention
15796
15793
  // loop over n_kv and n_head_kv
15797
15794
  // ref: https://arxiv.org/pdf/2112.05682.pdf
15798
15795
  for (int64_t ic = 0; ic < nek1; ++ic) {
15799
- const float mv = mp ? GGML_FP16_TO_FP32(mp[ic]) : 0.0f;
15796
+ const float mv = mp ? slope*GGML_FP16_TO_FP32(mp[ic]) : 0.0f;
15800
15797
  if (mv == -INFINITY) {
15801
15798
  continue;
15802
15799
  }
15803
15800
 
15804
- float s;
15801
+ float s; // KQ value
15805
15802
 
15806
- // convert Q to F16 in V32
15807
- {
15808
- const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3));
15803
+ const char * k_data = (const char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3);
15804
+ kq_vec_dot(D, &s, 0, k_data, 0, Q_q, 0, 1);
15809
15805
 
15810
- for (int64_t d = 0; d < D; ++d) {
15811
- Q16[d] = GGML_FP32_TO_FP16(pq[d]);
15812
- }
15813
- }
15806
+ s = s*scale + mv; // scale KQ value and apply mask
15814
15807
 
15815
- ggml_vec_dot_f16(D,
15816
- &s, 0,
15817
- (ggml_fp16_t *) ((char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3)), 0,
15818
- Q16, 0, 1);
15808
+ const float Mold = M;
15819
15809
 
15820
- s = s*scale + mv;
15810
+ float ms = 1.0f; // upon new higher max val, scale VKQ and KQ sum with this value
15811
+ float vs = 1.0f; // post-softmax KQ value, expf(s - M)
15821
15812
 
15822
- const float Mold = M;
15813
+ const char * v_data = ((const char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3));
15823
15814
 
15824
- float ms = 1.0f;
15825
- float vs = 1.0f;
15815
+ if (v->type == GGML_TYPE_F16) {
15816
+ if (s > M) {
15817
+ // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
15818
+ M = s;
15819
+ ms = expf(Mold - M);
15826
15820
 
15827
- if (s > M) {
15828
- M = s;
15829
- ms = expf(Mold - M);
15821
+ // V = V*expf(Mold - M)
15822
+ ggml_vec_scale_f16(D, VKQ16, ms);
15823
+ } else {
15824
+ // no new maximum, ms == 1.0f, vs != 1.0f
15825
+ vs = expf(s - M);
15826
+ }
15830
15827
 
15831
- // V = V*expf(Mold - M)
15832
- ggml_vec_scale_f16(D, V16, ms);
15828
+ // V += v*expf(s - M)
15829
+ ggml_vec_mad_f16(D, VKQ16, (const ggml_fp16_t *) v_data, vs);
15833
15830
  } else {
15834
- vs = expf(s - M);
15835
- }
15831
+ if (s > M) {
15832
+ // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
15833
+ M = s;
15834
+ ms = expf(Mold - M);
15835
+
15836
+ // V = V*expf(Mold - M)
15837
+ ggml_vec_scale_f32(D, VKQ32, ms);
15838
+ } else {
15839
+ // no new maximum, ms == 1.0f, vs != 1.0f
15840
+ vs = expf(s - M);
15841
+ }
15836
15842
 
15837
- const ggml_fp16_t * v16 = (const ggml_fp16_t *) ((char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3));
15843
+ v_to_float(v_data, V32, D);
15838
15844
 
15839
- // V += v*expf(s - M)
15840
- ggml_vec_mad_f16(D, V16, v16, vs);
15845
+ // V += v*expf(s - M)
15846
+ ggml_vec_mad_f32(D, VKQ32, V32, vs);
15847
+ }
15841
15848
 
15842
- S = S*ms + vs;
15849
+ S = S*ms + vs; // scale and increment sum with partial sum
15843
15850
  }
15844
15851
 
15845
- // V /= S
15846
- for (int64_t d = 0; d < D; ++d) {
15847
- V32[d] = GGML_FP16_TO_FP32(V16[d])/S;
15852
+ if (v->type == GGML_TYPE_F16) {
15853
+ for (int64_t d = 0; d < D; ++d) {
15854
+ VKQ32[d] = GGML_FP16_TO_FP32(VKQ16[d]);
15855
+ }
15848
15856
  }
15849
15857
 
15858
+ // V /= S
15859
+ const float S_inv = 1.0f/S;
15860
+ ggml_vec_scale_f32(D, VKQ32, S_inv);
15861
+
15850
15862
  // dst indices
15851
15863
  const int i1 = iq1;
15852
15864
  const int i2 = iq2;
@@ -15856,7 +15868,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
15856
15868
  //memcpy((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3), V, nev0*sizeof(float));
15857
15869
 
15858
15870
  // permute(0, 2, 1, 3)
15859
- memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, V32, nb1);
15871
+ memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32, nb1);
15860
15872
  }
15861
15873
  }
15862
15874
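
The loop above is an online (single-pass) softmax: it keeps a running maximum M and running sum S, rescales the accumulated V by expf(Mold - M) whenever a larger score arrives, and only divides by S at the end (ref: arXiv:2112.05682). A compact standalone version of the same accumulation for scalar values, under illustrative names:

    #include <math.h>
    #include <stdio.h>

    // One-pass softmax-weighted accumulation: returns sum_i softmax(s_i) * v_i
    // without a separate pass to find the maximum first.
    static float online_softmax_dot(const float * s, const float * v, int n) {
        float M   = -INFINITY; // running maximum of the scores seen so far
        float S   = 0.0f;      // running sum of expf(s_i - M)
        float vkq = 0.0f;      // running weighted sum of v_i (scaled to current M)

        for (int i = 0; i < n; i++) {
            float ms = 1.0f;               // rescale factor for the old accumulator
            float vs = 1.0f;               // weight of the current element
            if (s[i] > M) {
                ms = expf(M - s[i]);       // shrink what was accumulated under the old max
                M  = s[i];
            } else {
                vs = expf(s[i] - M);
            }
            vkq = vkq*ms + vs*v[i];
            S   = S*ms + vs;
        }
        return vkq / S;                    // final normalization, like V /= S above
    }

    int main(void) {
        const float s[4] = { 0.1f, 2.0f, -1.0f, 0.5f };
        const float v[4] = { 1.0f, 2.0f,  3.0f, 4.0f };
        printf("online : %f\n", online_softmax_dot(s, v, 4));

        // two-pass reference for comparison
        float max = s[0], sum = 0.0f, out = 0.0f;
        for (int i = 1; i < 4; i++) if (s[i] > max) max = s[i];
        for (int i = 0; i < 4; i++) { float w = expf(s[i] - max); sum += w; out += w*v[i]; }
        printf("2-pass : %f\n", out / sum);
        return 0;
    }
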
 
@@ -15867,7 +15879,7 @@ static void ggml_compute_forward_flash_attn_ext(
15867
15879
  const struct ggml_tensor * v,
15868
15880
  const struct ggml_tensor * mask,
15869
15881
  struct ggml_tensor * dst) {
15870
- switch (dst->op_params[1]) {
15882
+ switch (dst->op_params[2]) {
15871
15883
  case GGML_PREC_DEFAULT:
15872
15884
  case GGML_PREC_F32:
15873
15885
  {
@@ -15881,165 +15893,6 @@ static void ggml_compute_forward_flash_attn_ext(
15881
15893
  }
15882
15894
  }
15883
15895
 
15884
- // ggml_compute_forward_flash_ff
15885
-
15886
- static void ggml_compute_forward_flash_ff_f16(
15887
- const struct ggml_compute_params * params,
15888
- struct ggml_tensor * dst) {
15889
-
15890
- const struct ggml_tensor * a = dst->src[0]; // F16
15891
- const struct ggml_tensor * b0 = dst->src[1]; // F16 fc_w
15892
- const struct ggml_tensor * b1 = dst->src[2]; // F32 fc_b
15893
- const struct ggml_tensor * c0 = dst->src[3]; // F16 proj_w
15894
- const struct ggml_tensor * c1 = dst->src[4]; // F32 proj_b
15895
-
15896
- int64_t t0 = ggml_perf_time_us();
15897
- UNUSED(t0);
15898
-
15899
- GGML_TENSOR_LOCALS(int64_t, nea, a, ne)
15900
- GGML_TENSOR_LOCALS(size_t, nba, a, nb)
15901
- GGML_TENSOR_LOCALS(int64_t, neb0, b0, ne)
15902
- GGML_TENSOR_LOCALS(size_t, nbb0, b0, nb)
15903
- GGML_TENSOR_LOCALS(int64_t, neb1, b1, ne)
15904
- GGML_TENSOR_LOCALS(size_t, nbb1, b1, nb)
15905
- GGML_TENSOR_LOCALS(int64_t, nec0, c0, ne)
15906
- GGML_TENSOR_LOCALS(size_t, nbc0, c0, nb)
15907
- GGML_TENSOR_LOCALS(int64_t, nec1, c1, ne)
15908
- GGML_TENSOR_LOCALS(size_t, nbc1, c1, nb)
15909
- GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
15910
- GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
15911
-
15912
- const int ith = params->ith;
15913
- const int nth = params->nth;
15914
-
15915
- const int64_t D = nea0;
15916
- //const int64_t N = nea1;
15917
- const int64_t M = neb01;
15918
-
15919
- GGML_ASSERT(ne0 == nea0);
15920
- GGML_ASSERT(ne1 == nea1);
15921
- GGML_ASSERT(ne2 == nea2);
15922
-
15923
- GGML_ASSERT(nba0 == sizeof(ggml_fp16_t));
15924
- GGML_ASSERT(nbb00 == sizeof(ggml_fp16_t));
15925
- GGML_ASSERT(nbb10 == sizeof(float));
15926
- GGML_ASSERT(nbc00 == sizeof(ggml_fp16_t));
15927
- GGML_ASSERT(nbc10 == sizeof(float));
15928
-
15929
- GGML_ASSERT(neb00 == D);
15930
- GGML_ASSERT(neb01 == M);
15931
- GGML_ASSERT(neb10 == M);
15932
- GGML_ASSERT(neb11 == 1);
15933
-
15934
- GGML_ASSERT(nec00 == M);
15935
- GGML_ASSERT(nec01 == D);
15936
- GGML_ASSERT(nec10 == D);
15937
- GGML_ASSERT(nec11 == 1);
15938
-
15939
- // dst cannot be transposed or permuted
15940
- GGML_ASSERT(nb0 == sizeof(float));
15941
- GGML_ASSERT(nb0 <= nb1);
15942
- GGML_ASSERT(nb1 <= nb2);
15943
- GGML_ASSERT(nb2 <= nb3);
15944
-
15945
- if (params->type == GGML_TASK_TYPE_INIT) {
15946
- return;
15947
- }
15948
-
15949
- if (params->type == GGML_TASK_TYPE_FINALIZE) {
15950
- return;
15951
- }
15952
-
15953
- // parallelize by a rows using ggml_vec_dot_f32
15954
-
15955
- // total rows in a
15956
- const int nr = nea1*nea2*nea3;
15957
-
15958
- // rows per thread
15959
- const int dr = (nr + nth - 1)/nth;
15960
-
15961
- // row range for this thread
15962
- const int ir0 = dr*ith;
15963
- const int ir1 = MIN(ir0 + dr, nr);
15964
-
15965
- for (int ir = ir0; ir < ir1; ++ir) {
15966
- // a indices
15967
- const int ia3 = ir/(nea2*nea1);
15968
- const int ia2 = (ir - ia3*nea2*nea1)/nea1;
15969
- const int ia1 = (ir - ia3*nea2*nea1 - ia2*nea1);
15970
-
15971
- float * S = (float *) params->wdata + ith*(2*M + CACHE_LINE_SIZE_F32);
15972
-
15973
- for (int64_t ic = 0; ic < neb01; ++ic) {
15974
- // b0 indices
15975
- const int ib03 = ia3;
15976
- const int ib02 = ia2;
15977
- const int ib01 = ic;
15978
-
15979
- // S indices
15980
- const int i1 = ib01;
15981
-
15982
- ggml_vec_dot_f16(nea0,
15983
- S + i1, 0,
15984
- (ggml_fp16_t *) ((char *) b0->data + (ib01*nbb01 + ib02*nbb02 + ib03*nbb03)), 0,
15985
- (ggml_fp16_t *) ((char *) a->data + ( ia1*nba1 + ia2*nba2 + ia3*nba3)), 0, 1);
15986
- }
15987
-
15988
- ggml_vec_add_f32(neb01, S, S, (float *) b1->data);
15989
- //ggml_vec_gelu_f32(neb01, S, S);
15990
-
15991
- ggml_fp16_t * S16 = (ggml_fp16_t *) ((float *) params->wdata + ith*(2*M + CACHE_LINE_SIZE_F32) + M);
15992
-
15993
- for (int64_t i = 0; i < M; i++) {
15994
- S16[i] = GGML_FP32_TO_FP16(S[i]);
15995
- }
15996
-
15997
- ggml_vec_gelu_f16(neb01, S16, S16);
15998
-
15999
- {
16000
- // dst indices
16001
- const int i1 = ia1;
16002
- const int i2 = ia2;
16003
- const int i3 = ia3;
16004
-
16005
- for (int64_t ic = 0; ic < nec01; ++ic) {
16006
-
16007
- ggml_vec_dot_f16(neb01,
16008
- (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), 0,
16009
- (ggml_fp16_t *) ((char *) c0->data + ( ic*nbc01 + i2*nbc02 + i3*nbc03)), 0,
16010
- S16, 0, 1);
16011
- }
16012
-
16013
- ggml_vec_add_f32(nec01,
16014
- (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3)),
16015
- (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3)),
16016
- (float *) c1->data);
16017
- }
16018
- }
16019
- }
16020
-
16021
- static void ggml_compute_forward_flash_ff(
16022
- const struct ggml_compute_params * params,
16023
- struct ggml_tensor * dst) {
16024
-
16025
- const struct ggml_tensor * b0 = dst->src[1];
16026
-
16027
- switch (b0->type) {
16028
- case GGML_TYPE_F16:
16029
- {
16030
- ggml_compute_forward_flash_ff_f16(params, dst);
16031
- } break;
16032
- case GGML_TYPE_F32:
16033
- {
16034
- GGML_ASSERT(false); // TODO
16035
- } break;
16036
- default:
16037
- {
16038
- GGML_ASSERT(false);
16039
- } break;
16040
- }
16041
- }
16042
-
16043
15896
  // ggml_compute_forward_flash_attn_back
16044
15897
 
16045
15898
  static void ggml_compute_forward_flash_attn_back_f32(
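
The GGML_OP_FLASH_FF removal above drops a fused feed-forward kernel that computed dst = proj_w * gelu(fc_w * a + fc_b) + proj_b. For reference, a plain two-loop version of that block under illustrative names (the GELU constants follow the usual tanh approximation, close to but not taken from the library source):

    #include <math.h>
    #include <stdio.h>

    static float gelu(float x) {
        // tanh approximation of GELU: 0.5*x*(1 + tanh(sqrt(2/pi)*(x + 0.044715*x^3)))
        return 0.5f*x*(1.0f + tanhf(0.79788456f*(x + 0.044715f*x*x*x)));
    }

    // dst = W1 * gelu(W0 * x + b0) + b1, row-major W0 is [M x D], W1 is [D x M].
    static void ffn_ref(const float * x, const float * W0, const float * b0,
                        const float * W1, const float * b1,
                        float * dst, int D, int M) {
        float h[64];                               // hidden activations, M <= 64 here
        for (int i = 0; i < M; i++) {
            float acc = b0[i];
            for (int d = 0; d < D; d++) acc += W0[i*D + d]*x[d];
            h[i] = gelu(acc);
        }
        for (int i = 0; i < D; i++) {
            float acc = b1[i];
            for (int m = 0; m < M; m++) acc += W1[i*M + m]*h[m];
            dst[i] = acc;
        }
    }

    int main(void) {
        const float x[2]  = { 1.0f, -0.5f };
        const float W0[6] = { 0.5f, 0.1f,  -0.2f, 0.3f,  0.7f, -0.7f }; // 3x2
        const float b0[3] = { 0.0f, 0.1f, -0.1f };
        const float W1[6] = { 0.2f, -0.1f, 0.4f,  0.0f, 0.3f, -0.3f };  // 2x3
        const float b1[2] = { 0.05f, -0.05f };
        float y[2];
        ffn_ref(x, W0, b0, W1, b1, y, 2, 3);
        printf("%f %f\n", y[0], y[1]);
        return 0;
    }
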
@@ -16221,38 +16074,7 @@ static void ggml_compute_forward_flash_attn_back_f32(
16221
16074
  vvexpf(SM, SM, &Mup);
16222
16075
  ggml_vec_sum_f32(Mup, &sum, SM);
16223
16076
  #else
16224
- uint16_t scvt[GGML_SOFT_MAX_UNROLL]; UNUSED(scvt);
16225
- ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 };
16226
-
16227
- for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) {
16228
- if (i >= masked_begin) {
16229
- break;
16230
- }
16231
- float * SR = S + i;
16232
- float * SW = SM + i;
16233
-
16234
- for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) {
16235
- if (i + j >= masked_begin) {
16236
- break;
16237
- } else if (SR[j] == -INFINITY) {
16238
- SW[j] = 0.0f;
16239
- } else {
16240
- #ifndef GGML_FLASH_ATTN_EXP_FP16
16241
- const float val = expf(SR[j] - max);
16242
- #else
16243
- ggml_fp16_t s = GGML_FP32_TO_FP16(SR[j] - max);
16244
- memcpy(&scvt[j], &s, sizeof(uint16_t));
16245
- const float val = GGML_FP16_TO_FP32(ggml_table_exp_f16[scvt[j]]);
16246
- #endif
16247
- sump[j] += (ggml_float)val;
16248
- SW[j] = val;
16249
- }
16250
- }
16251
- }
16252
-
16253
- for (int i = 0; i < GGML_SOFT_MAX_UNROLL; i++) {
16254
- sum += sump[i];
16255
- }
16077
+ sum = ggml_vec_soft_max_f32(Mup, SM, S, max);
16256
16078
  #endif
16257
16079
  }
16258
16080
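
The unrolled FP16-table softmax loop above collapses into a single ggml_vec_soft_max_f32 call. A sketch of what such a helper plausibly does (assumed semantics, not copied from the library): exponentiate the max-shifted inputs into the output array, treat -INFINITY as an exact zero, and return the sum for later normalization.

    #include <math.h>
    #include <stdio.h>

    static double vec_soft_max_f32(int n, float * y, const float * x, float max) {
        double sum = 0.0;
        for (int i = 0; i < n; i++) {
            if (x[i] == -INFINITY) {
                y[i] = 0.0f;                   // masked entries contribute nothing
            } else {
                y[i] = expf(x[i] - max);
                sum += (double) y[i];
            }
        }
        return sum;
    }

    int main(void) {
        const float x[4] = { 1.0f, 2.0f, -INFINITY, 0.5f };
        float y[4];
        const float  max = 2.0f;
        const double sum = vec_soft_max_f32(4, y, x, max);
        for (int i = 0; i < 4; i++) y[i] = (float) (y[i]/sum);   // normalized softmax
        printf("%f %f %f %f\n", y[0], y[1], y[2], y[3]);
        return 0;
    }
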
 
@@ -16834,6 +16656,10 @@ static void ggml_compute_forward_unary(
16834
16656
  {
16835
16657
  ggml_compute_forward_relu(params, dst);
16836
16658
  } break;
16659
+ case GGML_UNARY_OP_SIGMOID:
16660
+ {
16661
+ ggml_compute_forward_sigmoid(params, dst);
16662
+ } break;
16837
16663
  case GGML_UNARY_OP_GELU:
16838
16664
  {
16839
16665
  ggml_compute_forward_gelu(params, dst);
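
The new GGML_UNARY_OP_SIGMOID case dispatches an element-wise logistic sigmoid, y = 1 / (1 + e^-x). A minimal standalone version of that operation (illustrative helper, not the library code):

    #include <math.h>
    #include <stdio.h>

    static void vec_sigmoid_f32(int n, float * y, const float * x) {
        for (int i = 0; i < n; i++) {
            y[i] = 1.0f / (1.0f + expf(-x[i]));
        }
    }

    int main(void) {
        const float x[5] = { -4.0f, -1.0f, 0.0f, 1.0f, 4.0f };
        float y[5];
        vec_sigmoid_f32(5, y, x);
        for (int i = 0; i < 5; i++) {
            printf("sigmoid(%+.1f) = %.4f\n", x[i], y[i]);
        }
        return 0;
    }
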
@@ -17274,35 +17100,15 @@ static void ggml_compute_forward_cross_entropy_loss_f32(
17274
17100
  assert(!isnan(s1[i]));
17275
17101
  }
17276
17102
  #endif
17277
- // soft_max
17278
- ggml_float sum = 0.0;
17279
- {
17280
- float max = -INFINITY;
17281
- ggml_vec_max_f32(nc, &max, s0);
17282
17103
 
17283
- uint16_t scvt; UNUSED(scvt);
17284
- for (int i = 0; i < nc; i++) {
17285
- if (s0[i] == -INFINITY) {
17286
- st[i] = 0.0f;
17287
- } else {
17288
- #ifndef GGML_CROSS_ENTROPY_EXP_FP16
17289
- const float s = s0[i] - max;
17290
- const float val = expf(s);
17291
- #else
17292
- ggml_fp16_t s = GGML_FP32_TO_FP16(s0[i] - max);
17293
- memcpy(&scvt, &s, sizeof(scvt));
17294
- const float val = GGML_FP16_TO_FP32(ggml_table_exp_f16[scvt]);
17295
- #endif
17296
- sum += (ggml_float)val;
17297
- st[i] = val;
17298
- }
17299
- }
17104
+ // soft_max
17105
+ float max = -INFINITY;
17106
+ ggml_vec_max_f32(nc, &max, s0);
17107
+ ggml_float sum = ggml_vec_soft_max_f32(nc, st, s0, max);
17108
+ assert(sum > 0.0);
17109
+ sum = (1.0 - eps) / sum;
17300
17110
 
17301
- assert(sum > 0.0);
17302
- // sum = 1.0/sum;
17303
- }
17304
17111
  // avoid log(0) by rescaling from [0..1] to [eps..1]
17305
- sum = (1.0 - eps) / sum;
17306
17112
  ggml_vec_scale_f32(nc, st, sum);
17307
17113
  ggml_vec_add1_f32(nc, st, st, eps);
17308
17114
  ggml_vec_log_f32(nc, st, st);
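
The refactored forward loss above keeps the same numerical trick as before: take softmax(s0), then rescale the probabilities from [0..1] into [eps..1] so that log() never sees zero. A reference computation in that spirit (eps value and names are illustrative):

    #include <math.h>
    #include <stdio.h>

    // Cross entropy -sum(target * log(p)) with p = softmax(s0) rescaled to [eps, 1].
    static float cross_entropy(const float * s0, const float * s1, int n) {
        const float eps = 1e-9f;

        float max = -INFINITY;
        for (int i = 0; i < n; i++) if (s0[i] > max) max = s0[i];

        double sum = 0.0;
        float  st[8];                               // n <= 8 in this sketch
        for (int i = 0; i < n; i++) { st[i] = expf(s0[i] - max); sum += st[i]; }

        const float scale = (float) ((1.0 - eps)/sum);
        float loss = 0.0f;
        for (int i = 0; i < n; i++) {
            const float p = st[i]*scale + eps;      // softmax, rescaled into [eps, 1]
            loss -= s1[i]*logf(p);
        }
        return loss;
    }

    int main(void) {
        const float logits[3] = { 2.0f, 0.5f, -1.0f };
        const float target[3] = { 1.0f, 0.0f,  0.0f }; // one-hot label
        printf("loss = %f\n", cross_entropy(logits, target, 3));
        return 0;
    }
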
@@ -17392,32 +17198,11 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32(
17392
17198
  #endif
17393
17199
 
17394
17200
  // soft_max
17395
- ggml_float sum = 0.0;
17396
- {
17397
- float max = -INFINITY;
17398
- ggml_vec_max_f32(nc, &max, s0);
17399
-
17400
- uint16_t scvt; UNUSED(scvt);
17401
- for (int i = 0; i < nc; i++) {
17402
- if (s0[i] == -INFINITY) {
17403
- ds0[i] = 0.0f;
17404
- } else {
17405
- #ifndef GGML_CROSS_ENTROPY_EXP_FP16
17406
- const float s = s0[i] - max;
17407
- const float val = expf(s);
17408
- #else
17409
- ggml_fp16_t s = GGML_FP32_TO_FP16(s0[i] - max);
17410
- memcpy(&scvt, &s, sizeof(scvt));
17411
- const float val = GGML_FP16_TO_FP32(ggml_table_exp_f16[scvt]);
17412
- #endif
17413
- sum += (ggml_float)val;
17414
- ds0[i] = val;
17415
- }
17416
- }
17417
-
17418
- assert(sum > 0.0);
17419
- sum = (1.0 - eps)/sum;
17420
- }
17201
+ float max = -INFINITY;
17202
+ ggml_vec_max_f32(nc, &max, s0);
17203
+ ggml_float sum = ggml_vec_soft_max_f32(nc, ds0, s0, max);
17204
+ assert(sum > 0.0);
17205
+ sum = (1.0 - eps) / sum;
17421
17206
 
17422
17207
  // grad(src0) = (softmax(src0) - src1) * grad(cross_entropy_loss(src0, src1)) / nr
17423
17208
  ggml_vec_scale_f32(nc, ds0, sum);
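
The comment above states the gradient being computed: grad(src0) = (softmax(src0) - src1) * grad(loss) / nr. A standalone version of that formula (the eps rescaling used by the kernel is omitted here for brevity; names are illustrative):

    #include <math.h>
    #include <stdio.h>

    static void cross_entropy_grad(float * ds0, const float * s0, const float * s1,
                                   int n, float dloss, float nr) {
        float max = -INFINITY;
        for (int i = 0; i < n; i++) if (s0[i] > max) max = s0[i];

        double sum = 0.0;
        for (int i = 0; i < n; i++) { ds0[i] = expf(s0[i] - max); sum += ds0[i]; }

        for (int i = 0; i < n; i++) {
            ds0[i] = (float) (ds0[i]/sum - s1[i]) * dloss / nr;  // (softmax - target), scaled
        }
    }

    int main(void) {
        const float s0[3] = { 2.0f, 0.5f, -1.0f };
        const float s1[3] = { 1.0f, 0.0f,  0.0f };
        float ds0[3];
        cross_entropy_grad(ds0, s0, s1, 3, 1.0f, 1.0f);
        printf("%f %f %f\n", ds0[0], ds0[1], ds0[2]);
        return 0;
    }
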
@@ -17454,7 +17239,7 @@ static void ggml_compute_forward_cross_entropy_loss_back(
17454
17239
 
17455
17240
  /////////////////////////////////
17456
17241
 
17457
- static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
17242
+ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor, struct ggml_compute_state * state) {
17458
17243
  GGML_ASSERT(params);
17459
17244
 
17460
17245
  if (tensor->op == GGML_OP_NONE || ggml_is_empty(tensor)) {
@@ -17552,7 +17337,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
17552
17337
  } break;
17553
17338
  case GGML_OP_MUL_MAT:
17554
17339
  {
17555
- ggml_compute_forward_mul_mat(params, tensor);
17340
+ ggml_compute_forward_mul_mat(params, tensor, state);
17556
17341
  } break;
17557
17342
  case GGML_OP_MUL_MAT_ID:
17558
17343
  {
@@ -17630,10 +17415,6 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
17630
17415
  {
17631
17416
  ggml_compute_forward_rope_back(params, tensor);
17632
17417
  } break;
17633
- case GGML_OP_ALIBI:
17634
- {
17635
- ggml_compute_forward_alibi(params, tensor);
17636
- } break;
17637
17418
  case GGML_OP_CLAMP:
17638
17419
  {
17639
17420
  ggml_compute_forward_clamp(params, tensor);
@@ -17682,21 +17463,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
17682
17463
  {
17683
17464
  ggml_compute_forward_leaky_relu(params, tensor);
17684
17465
  } break;
17685
- case GGML_OP_FLASH_ATTN:
17686
- {
17687
- const int32_t t = ggml_get_op_params_i32(tensor, 0);
17688
- GGML_ASSERT(t == 0 || t == 1);
17689
- const bool masked = t != 0;
17690
- ggml_compute_forward_flash_attn(params, masked, tensor);
17691
- } break;
17692
17466
  case GGML_OP_FLASH_ATTN_EXT:
17693
17467
  {
17694
17468
  ggml_compute_forward_flash_attn_ext(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor);
17695
17469
  } break;
17696
- case GGML_OP_FLASH_FF:
17697
- {
17698
- ggml_compute_forward_flash_ff(params, tensor);
17699
- } break;
17700
17470
  case GGML_OP_FLASH_ATTN_BACK:
17701
17471
  {
17702
17472
  int32_t t = ggml_get_op_params_i32(tensor, 0);
@@ -18066,6 +17836,7 @@ static struct ggml_tensor * ggml_sub_or_set(struct ggml_context * ctx, struct gg
18066
17836
  static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor * tensor, struct ggml_hash_set zero_table) {
18067
17837
  struct ggml_tensor * src0 = tensor->src[0];
18068
17838
  struct ggml_tensor * src1 = tensor->src[1];
17839
+ struct ggml_tensor * src2 = tensor->src[2];
18069
17840
 
18070
17841
  switch (tensor->op) {
18071
17842
  case GGML_OP_DUP:
@@ -18597,6 +18368,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
18597
18368
  ggml_rope_back(ctx,
18598
18369
  tensor->grad,
18599
18370
  src1,
18371
+ src2,
18600
18372
  n_dims,
18601
18373
  mode,
18602
18374
  n_ctx,
@@ -18636,6 +18408,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
18636
18408
  ggml_rope_impl(ctx,
18637
18409
  tensor->grad,
18638
18410
  src1,
18411
+ src2,
18639
18412
  n_dims,
18640
18413
  mode,
18641
18414
  n_ctx,
@@ -18652,10 +18425,6 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
18652
18425
  zero_table);
18653
18426
  }
18654
18427
  } break;
18655
- case GGML_OP_ALIBI:
18656
- {
18657
- GGML_ASSERT(false); // TODO: not implemented
18658
- } break;
18659
18428
  case GGML_OP_CLAMP:
18660
18429
  {
18661
18430
  GGML_ASSERT(false); // TODO: not implemented
@@ -18704,7 +18473,6 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
18704
18473
  {
18705
18474
  GGML_ASSERT(false); // TODO: not implemented
18706
18475
  } break;
18707
- case GGML_OP_FLASH_ATTN:
18708
18476
  case GGML_OP_FLASH_ATTN_EXT:
18709
18477
  {
18710
18478
  struct ggml_tensor * flash_grad = NULL;
@@ -18721,7 +18489,6 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
18721
18489
  masked);
18722
18490
  }
18723
18491
 
18724
- struct ggml_tensor * src2 = tensor->src[2];
18725
18492
  const int64_t elem_q = ggml_nelements(src0);
18726
18493
  const int64_t elem_k = ggml_nelements(src1);
18727
18494
  const int64_t elem_v = ggml_nelements(src2);
@@ -18759,10 +18526,6 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
18759
18526
  zero_table);
18760
18527
  }
18761
18528
  } break;
18762
- case GGML_OP_FLASH_FF:
18763
- {
18764
- GGML_ASSERT(false); // not supported
18765
- } break;
18766
18529
  case GGML_OP_FLASH_ATTN_BACK:
18767
18530
  {
18768
18531
  GGML_ASSERT(false); // not supported
@@ -18826,6 +18589,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
18826
18589
  zero_table);
18827
18590
  }
18828
18591
  } break;
18592
+ case GGML_UNARY_OP_SIGMOID:
18593
+ {
18594
+ GGML_ASSERT(false); // TODO: not implemented
18595
+ } break;
18829
18596
  case GGML_UNARY_OP_GELU:
18830
18597
  {
18831
18598
  GGML_ASSERT(false); // TODO: not implemented
@@ -19172,8 +18939,6 @@ typedef int ggml_lock_t;
19172
18939
 
19173
18940
  #define GGML_LOCK_INITIALIZER 0
19174
18941
 
19175
- typedef pthread_t ggml_thread_t;
19176
-
19177
18942
  #define ggml_thread_create pthread_create
19178
18943
  #define ggml_thread_join pthread_join
19179
18944
 
@@ -19199,8 +18964,6 @@ typedef int ggml_lock_t;
19199
18964
 
19200
18965
  #define GGML_LOCK_INITIALIZER 0
19201
18966
 
19202
- typedef pthread_t ggml_thread_t;
19203
-
19204
18967
  #define ggml_thread_create pthread_create
19205
18968
  #define ggml_thread_join pthread_join
19206
18969
 
@@ -19280,31 +19043,6 @@ static void set_numa_thread_affinity(int thread_n) { UNUSED(thread_n); }
19280
19043
  static void clear_numa_thread_affinity(void) {}
19281
19044
  #endif
19282
19045
 
19283
- struct ggml_compute_state_shared {
19284
- const struct ggml_cgraph * cgraph;
19285
- const struct ggml_cplan * cplan;
19286
-
19287
- int64_t perf_node_start_cycles;
19288
- int64_t perf_node_start_time_us;
19289
-
19290
- const int n_threads;
19291
-
19292
- // synchronization primitives
19293
- atomic_int n_active; // num active threads
19294
- atomic_int node_n; // active graph node
19295
- atomic_int node_task; // active graph node task phase
19296
-
19297
- ggml_abort_callback abort_callback; // abort ggml_graph_compute when true
19298
- void * abort_callback_data;
19299
- };
19300
-
19301
- struct ggml_compute_state {
19302
- ggml_thread_t thrd;
19303
- int ith;
19304
- struct ggml_compute_state_shared * shared;
19305
- enum ggml_status ec;
19306
- };
19307
-
19308
19046
  static void ggml_graph_compute_perf_stats_node(struct ggml_tensor * node, const struct ggml_compute_state_shared * st) {
19309
19047
  int64_t cycles_cur = ggml_perf_cycles() - st->perf_node_start_cycles;
19310
19048
  int64_t time_us_cur = ggml_perf_time_us() - st->perf_node_start_time_us;
@@ -19355,6 +19093,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads, int n_cur_
19355
19093
  case GGML_UNARY_OP_TANH:
19356
19094
  case GGML_UNARY_OP_ELU:
19357
19095
  case GGML_UNARY_OP_RELU:
19096
+ case GGML_UNARY_OP_SIGMOID:
19358
19097
  case GGML_UNARY_OP_HARDSWISH: // to opt for multiple threads
19359
19098
  case GGML_UNARY_OP_HARDSIGMOID: // to opt for multiple threads
19360
19099
  {
@@ -19428,10 +19167,6 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads, int n_cur_
19428
19167
  {
19429
19168
  n_tasks = n_threads;
19430
19169
  } break;
19431
- case GGML_OP_ALIBI:
19432
- {
19433
- n_tasks = 1; //TODO
19434
- } break;
19435
19170
  case GGML_OP_CLAMP:
19436
19171
  {
19437
19172
  n_tasks = 1; //TODO
@@ -19477,15 +19212,10 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads, int n_cur_
19477
19212
  {
19478
19213
  n_tasks = n_threads;
19479
19214
  } break;
19480
- case GGML_OP_FLASH_ATTN:
19481
19215
  case GGML_OP_FLASH_ATTN_EXT:
19482
19216
  {
19483
19217
  n_tasks = n_threads;
19484
19218
  } break;
19485
- case GGML_OP_FLASH_FF:
19486
- {
19487
- n_tasks = n_threads;
19488
- } break;
19489
19219
  case GGML_OP_FLASH_ATTN_BACK:
19490
19220
  {
19491
19221
  n_tasks = n_threads;
@@ -19580,6 +19310,10 @@ static void ggml_graph_compute_thread_sync_node(int * node_n, struct ggml_comput
19580
19310
 
19581
19311
  * node_n = atomic_load(&state->shared->node_n);
19582
19312
  if (* node_n != last_node_n) break;
19313
+ #if defined(__SSE3__)
19314
+ // Tell the processor we're spinning. It's a processor hint for spinlocks.
19315
+ _mm_pause();
19316
+ #endif
19583
19317
  }
19584
19318
  }
19585
19319
 
@@ -19594,6 +19328,10 @@ static void ggml_graph_compute_thread_sync_task(int * task_phase, struct ggml_co
19594
19328
 
19595
19329
  * task_phase = atomic_load(&state->shared->node_task);
19596
19330
  if (* task_phase != last_task_phase) break;
19331
+ #if defined(__SSE3__)
19332
+ // Tell the processor we're spinning. It's a processor hint for spinlocks.
19333
+ _mm_pause();
19334
+ #endif
19597
19335
  }
19598
19336
  }
19599
19337
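
The two hunks above add a pause hint to the scheduler's busy-wait loops: on x86 builds with SSE3, _mm_pause() tells the core it is spinning so it can back off. A minimal spin-wait in the same style (C11 atomics; the pause compiles away on other targets, as in the #if above):

    #include <stdatomic.h>
    #include <stdio.h>
    #if defined(__SSE3__)
    #include <immintrin.h>
    #endif

    // Poll an atomic until it changes from last_value, pausing while spinning.
    static void spin_until_changed(atomic_int * flag, int last_value) {
        while (atomic_load(flag) == last_value) {
    #if defined(__SSE3__)
            _mm_pause();   // CPU hint: spinlock-style busy wait in progress
    #endif
        }
    }

    int main(void) {
        atomic_int node_n = 3;
        // In the real scheduler another thread would bump node_n; here it is
        // pre-changed so the call returns immediately.
        atomic_store(&node_n, 4);
        spin_until_changed(&node_n, 3);
        printf("observed new value: %d\n", atomic_load(&node_n));
        return 0;
    }
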
 
@@ -19633,7 +19371,7 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
19633
19371
  struct ggml_tensor * node = cgraph->nodes[node_n];
19634
19372
  if (GGML_OP_HAS_FINALIZE[node->op]) {
19635
19373
  params.nth = ggml_get_n_tasks(node, n_threads, state->shared->n_threads);
19636
- ggml_compute_forward(&params, node);
19374
+ ggml_compute_forward(&params, node, state);
19637
19375
  }
19638
19376
  ggml_graph_compute_perf_stats_node(node, state->shared);
19639
19377
  }
@@ -19653,17 +19391,17 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
19653
19391
  /* INIT */
19654
19392
  if (GGML_OP_HAS_INIT[node->op]) {
19655
19393
  params.type = GGML_TASK_TYPE_INIT;
19656
- ggml_compute_forward(&params, node);
19394
+ ggml_compute_forward(&params, node, state);
19657
19395
  }
19658
19396
 
19659
19397
  // TODO: maybe push node_n to the atomic but if other threads see n_tasks is 1,
19660
19398
  // they do something more efficient than spinning (?)
19661
19399
  params.type = GGML_TASK_TYPE_COMPUTE;
19662
- ggml_compute_forward(&params, node);
19400
+ ggml_compute_forward(&params, node, state);
19663
19401
 
19664
19402
  if (GGML_OP_HAS_FINALIZE[node->op]) {
19665
19403
  params.type = GGML_TASK_TYPE_FINALIZE;
19666
- ggml_compute_forward(&params, node);
19404
+ ggml_compute_forward(&params, node, state);
19667
19405
  }
19668
19406
 
19669
19407
  ggml_graph_compute_perf_stats_node(node, state->shared);
@@ -19702,7 +19440,7 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
19702
19440
 
19703
19441
  if (state->ith < n_tasks) {
19704
19442
  if (GGML_OP_HAS_INIT[node->op]) {
19705
- ggml_compute_forward(&params, node);
19443
+ ggml_compute_forward(&params, node, state);
19706
19444
  }
19707
19445
  }
19708
19446
 
@@ -19723,7 +19461,7 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
19723
19461
 
19724
19462
  if (state->ith < n_tasks) {
19725
19463
  params.type = GGML_TASK_TYPE_COMPUTE;
19726
- ggml_compute_forward(&params, node);
19464
+ ggml_compute_forward(&params, node, state);
19727
19465
  }
19728
19466
 
19729
19467
  if (atomic_fetch_sub(&state->shared->n_active, 1) == 1) {
@@ -19874,39 +19612,11 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
19874
19612
  cur += sizeof(ggml_fp16_t)*ne00*ne01*ne02*ne03;
19875
19613
  cur += sizeof(ggml_fp16_t)*ne10*ne11*ne12;
19876
19614
  } break;
19877
- case GGML_OP_FLASH_ATTN:
19878
- {
19879
- const int64_t ne11 = ggml_up(node->src[1]->ne[1], GGML_SOFT_MAX_UNROLL);
19880
-
19881
- if (node->src[1]->type == GGML_TYPE_F32) {
19882
- cur = sizeof(float)*ne11*n_tasks; // TODO: this can become (n_tasks-1)
19883
- cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2
19884
- } else if (node->src[1]->type == GGML_TYPE_F16) {
19885
- cur = sizeof(float)*ne11*n_tasks; // TODO: this can become (n_tasks-1)
19886
- cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2
19887
- } else if (node->src[1]->type == GGML_TYPE_BF16) {
19888
- cur = sizeof(float)*ne11*n_tasks; // TODO: this can become (n_tasks-1)
19889
- cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2
19890
- }
19891
- } break;
19892
19615
  case GGML_OP_FLASH_ATTN_EXT:
19893
19616
  {
19894
19617
  const int64_t ne00 = node->src[0]->ne[0]; // D
19895
19618
 
19896
- cur = 2*sizeof(float)*ne00*n_tasks; // 2x head size
19897
- } break;
19898
- case GGML_OP_FLASH_FF:
19899
- {
19900
- if (node->src[1]->type == GGML_TYPE_F32) {
19901
- cur = sizeof(float)*node->src[1]->ne[1]*n_tasks; // TODO: this can become (n_tasks-1)
19902
- cur += sizeof(float)*node->src[1]->ne[1]*n_tasks; // this is overestimated by x2
19903
- } else if (node->src[1]->type == GGML_TYPE_F16) {
19904
- cur = sizeof(float)*node->src[1]->ne[1]*n_tasks; // TODO: this can become (n_tasks-1)
19905
- cur += sizeof(float)*node->src[1]->ne[1]*n_tasks; // this is overestimated by x2
19906
- } else if (node->src[1]->type == GGML_TYPE_BF16) {
19907
- cur = sizeof(float)*node->src[1]->ne[1]*n_tasks; // TODO: this can become (n_tasks-1)
19908
- cur += sizeof(float)*node->src[1]->ne[1]*n_tasks; // this is overestimated by x2
19909
- }
19619
+ cur = 3*sizeof(float)*ne00*n_tasks; // 3x head size/thread
19910
19620
  } break;
19911
19621
  case GGML_OP_FLASH_ATTN_BACK:
19912
19622
  {
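
The new FLASH_ATTN_EXT work-size line reserves 3*sizeof(float)*D per task, matching the three D-sized buffers the kernel carves out of wdata (FP32 VKQ accumulator, temporary V/FP16 accumulator, converted Q). A sketch of that per-thread carve-up (the CACHE_LINE_SIZE_F32 padding used by the kernel is left out, and the pointer types are simplified to float for illustration):

    #include <stdint.h>
    #include <stdio.h>
    #include <stdlib.h>

    int main(void) {
        const int64_t D       = 128;  // head size (ne00)
        const int     n_tasks = 8;

        const size_t work_size = 3*sizeof(float)*D*n_tasks;  // as in the cplan hunk
        float * wdata = malloc(work_size);
        if (!wdata) return 1;

        for (int ith = 0; ith < n_tasks; ith++) {
            float * VKQ32 = wdata + ith*3*D;  // per-thread slice
            float * V32   = VKQ32 + 1*D;      // temporary V row / FP16 accumulator region
            float * Q_q   = VKQ32 + 2*D;      // converted Q row
            printf("thread %d: VKQ32=%p V32=%p Q_q=%p\n",
                   ith, (void *) VKQ32, (void *) V32, (void *) Q_q);
        }

        free(wdata);
        return 0;
    }
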
@@ -19974,6 +19684,7 @@ enum ggml_status ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cpl
19974
19684
  /*.node_task =*/ GGML_TASK_TYPE_FINALIZE,
19975
19685
  /*.abort_callback =*/ NULL,
19976
19686
  /*.abort_callback_data =*/ NULL,
19687
+ /*.current_chunk; =*/ 0,
19977
19688
  };
19978
19689
  struct ggml_compute_state * workers = alloca(sizeof(struct ggml_compute_state)*n_threads);
19979
19690
 
@@ -21747,11 +21458,7 @@ size_t ggml_quantize_chunk(
21747
21458
  case GGML_TYPE_IQ1_S: result = quantize_iq1_s (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
21748
21459
  case GGML_TYPE_IQ1_M: result = quantize_iq1_m (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
21749
21460
  case GGML_TYPE_IQ4_NL: result = quantize_iq4_nl (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
21750
- #if QK_K == 64
21751
- case GGML_TYPE_IQ4_XS: result = quantize_iq4_nl (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
21752
- #else
21753
21461
  case GGML_TYPE_IQ4_XS: result = quantize_iq4_xs (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
21754
- #endif
21755
21462
  case GGML_TYPE_F16:
21756
21463
  {
21757
21464
  size_t elemsize = sizeof(ggml_fp16_t);
@@ -23028,6 +22735,14 @@ int ggml_cpu_has_avx512_vnni(void) {
23028
22735
  #endif
23029
22736
  }
23030
22737
 
22738
+ int ggml_cpu_has_avx512_bf16(void) {
22739
+ #if defined(__AVX512BF16__)
22740
+ return 1;
22741
+ #else
22742
+ return 0;
22743
+ #endif
22744
+ }
22745
+
23031
22746
  int ggml_cpu_has_fma(void) {
23032
22747
  #if defined(__FMA__)
23033
22748
  return 1;
@@ -23044,6 +22759,16 @@ int ggml_cpu_has_neon(void) {
23044
22759
  #endif
23045
22760
  }
23046
22761
 
22762
+ int ggml_cpu_has_sve(void) {
22763
+ #if defined(__ARM_FEATURE_SVE)
22764
+ // TODO: Currently, SVE 256 bit is only supported.
22765
+ GGML_ASSERT(svcntb() == QK8_0);
22766
+ return 1;
22767
+ #else
22768
+ return 0;
22769
+ #endif
22770
+ }
22771
+
23047
22772
  int ggml_cpu_has_arm_fma(void) {
23048
22773
  #if defined(__ARM_FEATURE_FMA)
23049
22774
  return 1;
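
The new ggml_cpu_has_avx512_bf16() and ggml_cpu_has_sve() probes follow the existing pattern: each function collapses to a constant 0 or 1 depending on the flags the translation unit was compiled with. A standalone report in the same style (local helpers, not the library functions):

    #include <stdio.h>

    static int has_avx512_bf16(void) {
    #if defined(__AVX512BF16__)
        return 1;
    #else
        return 0;
    #endif
    }

    static int has_sve(void) {
    #if defined(__ARM_FEATURE_SVE)
        return 1;
    #else
        return 0;
    #endif
    }

    int main(void) {
        printf("AVX512_BF16 = %d | SVE = %d\n", has_avx512_bf16(), has_sve());
        return 0;
    }
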