@fugood/llama.node 1.0.0-beta.4 → 1.0.0-beta.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (111)
  1. package/CMakeLists.txt +7 -4
  2. package/lib/binding.ts +1 -1
  3. package/package.json +14 -14
  4. package/scripts/llama.cpp.patch +27 -26
  5. package/src/LlamaCompletionWorker.cpp +21 -4
  6. package/src/LlamaCompletionWorker.h +2 -0
  7. package/src/LlamaContext.cpp +3 -12
  8. package/src/common.hpp +6 -5
  9. package/src/llama.cpp/CMakeLists.txt +15 -4
  10. package/src/llama.cpp/common/CMakeLists.txt +15 -24
  11. package/src/llama.cpp/common/arg.cpp +172 -110
  12. package/src/llama.cpp/common/chat-parser.cpp +385 -0
  13. package/src/llama.cpp/common/chat-parser.h +120 -0
  14. package/src/llama.cpp/common/chat.cpp +726 -596
  15. package/src/llama.cpp/common/chat.h +74 -8
  16. package/src/llama.cpp/common/common.cpp +56 -38
  17. package/src/llama.cpp/common/common.h +9 -3
  18. package/src/llama.cpp/common/json-partial.cpp +256 -0
  19. package/src/llama.cpp/common/json-partial.h +38 -0
  20. package/src/llama.cpp/common/json-schema-to-grammar.cpp +2 -1
  21. package/src/llama.cpp/common/json-schema-to-grammar.h +4 -4
  22. package/src/llama.cpp/common/sampling.cpp +7 -8
  23. package/src/llama.cpp/common/speculative.cpp +6 -4
  24. package/src/llama.cpp/ggml/CMakeLists.txt +48 -3
  25. package/src/llama.cpp/ggml/include/ggml.h +22 -3
  26. package/src/llama.cpp/ggml/src/CMakeLists.txt +81 -22
  27. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +131 -49
  28. package/src/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +1 -1
  29. package/src/llama.cpp/ggml/src/ggml-cpu/amx/mmq.cpp +1 -1
  30. package/src/llama.cpp/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
  31. package/src/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +4113 -0
  32. package/src/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +2162 -0
  33. package/src/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +2638 -0
  34. package/src/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
  35. package/src/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +2731 -0
  36. package/src/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +2068 -0
  37. package/src/llama.cpp/ggml/src/ggml-cpu/arch/riscv/repack.cpp +396 -0
  38. package/src/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +1299 -0
  39. package/src/llama.cpp/ggml/src/ggml-cpu/arch/wasm/quants.c +1480 -0
  40. package/src/llama.cpp/ggml/src/ggml-cpu/arch/x86/quants.c +4310 -0
  41. package/src/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-aarch64.cpp → arch/x86/repack.cpp} +59 -3206
  42. package/src/llama.cpp/ggml/src/ggml-cpu/arch-fallback.h +184 -0
  43. package/src/llama.cpp/ggml/src/ggml-cpu/common.h +1 -1
  44. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +12 -13
  45. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +64 -88
  46. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +8 -8
  47. package/src/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-hbm.cpp → hbm.cpp} +1 -1
  48. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +1 -1
  49. package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +56 -7
  50. package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.h +5 -0
  51. package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +282 -100
  52. package/src/llama.cpp/ggml/src/ggml-cpu/ops.h +1 -0
  53. package/src/llama.cpp/ggml/src/ggml-cpu/quants.c +1157 -0
  54. package/src/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-quants.h → quants.h} +26 -0
  55. package/src/llama.cpp/ggml/src/ggml-cpu/repack.cpp +1570 -0
  56. package/src/llama.cpp/ggml/src/ggml-cpu/repack.h +98 -0
  57. package/src/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +119 -5
  58. package/src/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-traits.cpp → traits.cpp} +1 -1
  59. package/src/llama.cpp/ggml/src/ggml-cpu/vec.cpp +85 -16
  60. package/src/llama.cpp/ggml/src/ggml-cpu/vec.h +204 -49
  61. package/src/llama.cpp/include/llama.h +145 -40
  62. package/src/llama.cpp/src/CMakeLists.txt +5 -1
  63. package/src/llama.cpp/src/llama-arch.cpp +99 -3
  64. package/src/llama.cpp/src/llama-arch.h +10 -1
  65. package/src/llama.cpp/src/llama-batch.cpp +728 -272
  66. package/src/llama.cpp/src/llama-batch.h +112 -54
  67. package/src/llama.cpp/src/llama-chat.cpp +19 -2
  68. package/src/llama.cpp/src/llama-chat.h +1 -0
  69. package/src/llama.cpp/src/llama-context.cpp +525 -339
  70. package/src/llama.cpp/src/llama-context.h +38 -17
  71. package/src/llama.cpp/src/llama-cparams.cpp +4 -0
  72. package/src/llama.cpp/src/llama-cparams.h +2 -0
  73. package/src/llama.cpp/src/llama-grammar.cpp +12 -2
  74. package/src/llama.cpp/src/llama-graph.cpp +413 -353
  75. package/src/llama.cpp/src/llama-graph.h +112 -56
  76. package/src/llama.cpp/src/llama-hparams.cpp +10 -2
  77. package/src/llama.cpp/src/llama-hparams.h +13 -2
  78. package/src/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +279 -0
  79. package/src/llama.cpp/src/llama-kv-cache-unified-iswa.h +128 -0
  80. package/src/llama.cpp/src/llama-kv-cache-unified.cpp +1815 -0
  81. package/src/llama.cpp/src/llama-kv-cache-unified.h +303 -0
  82. package/src/llama.cpp/src/llama-kv-cells.h +415 -0
  83. package/src/llama.cpp/src/llama-memory-hybrid.cpp +246 -0
  84. package/src/llama.cpp/src/llama-memory-hybrid.h +138 -0
  85. package/src/llama.cpp/src/llama-memory-recurrent.cpp +1112 -0
  86. package/src/llama.cpp/src/llama-memory-recurrent.h +183 -0
  87. package/src/llama.cpp/src/llama-memory.cpp +41 -0
  88. package/src/llama.cpp/src/llama-memory.h +86 -5
  89. package/src/llama.cpp/src/llama-mmap.cpp +1 -1
  90. package/src/llama.cpp/src/llama-model-loader.cpp +42 -17
  91. package/src/llama.cpp/src/llama-model-saver.cpp +1 -0
  92. package/src/llama.cpp/src/llama-model.cpp +1137 -528
  93. package/src/llama.cpp/src/llama-model.h +4 -0
  94. package/src/llama.cpp/src/llama-quant.cpp +2 -1
  95. package/src/llama.cpp/src/llama-sampling.cpp +2 -2
  96. package/src/llama.cpp/src/llama-vocab.cpp +69 -32
  97. package/src/llama.cpp/src/llama-vocab.h +1 -0
  98. package/src/llama.cpp/src/llama.cpp +11 -7
  99. package/src/llama.cpp/src/unicode.cpp +5 -0
  100. package/src/tts_utils.h +1 -1
  101. package/src/llama.cpp/common/json.hpp +0 -24766
  102. package/src/llama.cpp/common/minja/chat-template.hpp +0 -541
  103. package/src/llama.cpp/common/minja/minja.hpp +0 -2974
  104. package/src/llama.cpp/common/stb_image.h +0 -7988
  105. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +0 -8
  106. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +0 -13326
  107. package/src/llama.cpp/src/llama-kv-cache.cpp +0 -2827
  108. package/src/llama.cpp/src/llama-kv-cache.h +0 -515
  109. /package/src/llama.cpp/ggml/src/ggml-cpu/{cpu-feats-x86.cpp → arch/x86/cpu-feats.cpp} +0 -0
  110. /package/src/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-hbm.h → hbm.h} +0 -0
  111. /package/src/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-traits.h → traits.h} +0 -0
@@ -0,0 +1,98 @@
+ #pragma once
+
+ #define GGML_COMMON_DECL_CPP
+ #include "ggml-common.h"
+
+ #include "traits.h"
+ #include "ggml.h"
+
+ // GGML internal header
+
+ ggml_backend_buffer_type_t ggml_backend_cpu_repack_buffer_type(void);
+
+ template <int K> constexpr int QK_0() {
+ if constexpr (K == 4) {
+ return QK4_0;
+ }
+ if constexpr (K == 8) {
+ return QK8_0;
+ }
+ return -1;
+ }
+
+ template <int K, int N> struct block {
+ ggml_half d[N]; // deltas for N qK_0 blocks
+ int8_t qs[(QK_0<K>() * N * K) / 8]; // quants for N qK_0 blocks
+ };
+
+ // control size
+ static_assert(sizeof(block<4, 4>) == 4 * sizeof(ggml_half) + QK8_0 * 2, "wrong block<4,4> size/padding");
+ static_assert(sizeof(block<4, 8>) == 8 * sizeof(ggml_half) + QK8_0 * 4, "wrong block<4,8> size/padding");
+ static_assert(sizeof(block<8, 4>) == 4 * sizeof(ggml_half) + QK8_0 * 4, "wrong block<8,4> size/padding");
+ static_assert(sizeof(block<8, 8>) == 8 * sizeof(ggml_half) + QK8_0 * 8, "wrong block<8,8> size/padding");
+
+ using block_q4_0x4 = block<4, 4>;
+ using block_q4_0x8 = block<4, 8>;
+ using block_q8_0x4 = block<8, 4>;
+ using block_q8_0x8 = block<8, 8>;
+
+ struct block_q4_Kx8 {
+ ggml_half d[8]; // super-block scale for quantized scales
+ ggml_half dmin[8]; // super-block scale for quantized mins
+ uint8_t scales[96]; // scales and mins, quantized with 6 bits
+ uint8_t qs[1024]; // 4--bit quants
+ };
+
+ static_assert(sizeof(block_q4_Kx8) == sizeof(ggml_half) * 16 + K_SCALE_SIZE * 8 + QK_K * 4, "wrong q4_K block size/padding");
+
+ struct block_q8_Kx4 {
+ float d[4]; // delta
+ int8_t qs[QK_K * 4]; // quants
+ int16_t bsums[QK_K / 4]; // sum of quants in groups of 16
+ };
+
+ static_assert(sizeof(block_q8_Kx4) == sizeof(float) * 4 + QK_K * 4 + (QK_K / 4) * sizeof(int16_t), "wrong q8_K block size/padding");
+
+ struct block_iq4_nlx4 {
+ ggml_half d[4]; // deltas for 4 iq4_nl blocks
+ uint8_t qs[QK4_NL * 2]; // nibbles / quants for 4 iq4_nl blocks
+ };
+
+ static_assert(sizeof(block_iq4_nlx4) == 4 * sizeof(ggml_half) + QK4_NL * 2, "wrong iq4_nlx4 block size/padding");
+
+ #if defined(__cplusplus)
+ extern "C" {
+ #endif
+
+ void ggml_quantize_mat_q8_0_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
+ void ggml_quantize_mat_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
+ void ggml_quantize_mat_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
+ void ggml_gemv_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+ void ggml_gemv_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+ void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+ void ggml_gemv_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+ void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+ void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+ void ggml_gemm_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+ void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+ void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+ void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+
+ // Native implementations
+ void ggml_quantize_mat_q8_0_4x4_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
+ void ggml_quantize_mat_q8_0_4x8_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
+ void ggml_quantize_mat_q8_K_4x8_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
+ void ggml_gemv_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+ void ggml_gemv_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+ void ggml_gemv_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+ void ggml_gemv_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+ void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+ void ggml_gemm_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+ void ggml_gemm_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+ void ggml_gemm_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+ void ggml_gemm_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+ void ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+
+ #if defined(__cplusplus)
+ } // extern "C"
+ #endif
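
For orientation, here is a standalone C++17 sketch (not part of the package diff) of the interleaved-block layout that repack.h encodes. QK4_0, QK8_0 and ggml_half are stubbed locally with their usual llama.cpp values, so treat the exact numbers as assumptions rather than the library's definitions:

    // Standalone sketch: N quantized blocks packed as N fp16 deltas followed by
    // the interleaved K-bit quants of all N blocks.
    #include <cstdint>

    using ggml_half = uint16_t;   // stub: fp16 storage type
    constexpr int QK4_0 = 32;     // stub: elements per q4_0 block
    constexpr int QK8_0 = 32;     // stub: elements per q8_0 block

    template <int K> constexpr int QK_0() {
        if constexpr (K == 4) { return QK4_0; }
        if constexpr (K == 8) { return QK8_0; }
        return -1;
    }

    template <int K, int N> struct block {
        ggml_half d[N];                        // one delta per packed block
        int8_t    qs[(QK_0<K>() * N * K) / 8]; // interleaved K-bit quants
    };

    // e.g. block<4,4>: 4 deltas (8 bytes) + 4 * (32 * 4 / 8) = 64 bytes of nibbles
    static_assert(sizeof(block<4, 4>) == 4 * sizeof(ggml_half) + 64, "unexpected packed size");
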
@@ -17,7 +17,123 @@
  // number of elements to fit in a single register
  //

- #if defined(__ARM_NEON) && defined(__ARM_FEATURE_FMA)
+ #if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_FMA)
+
+ #define GGML_SIMD
+
+ // F32 SVE
+ #define GGML_F32_EPR 8
+ #define DEFAULT_PG svptrue_b32()
+
+ #define GGML_F32xt svfloat32_t
+ #define GGML_F32xt_ZERO svdup_n_f32(0.0f)
+ #define GGML_F32xt_SET1(x) svdup_n_f32(x)
+ #define GGML_F32xt_LOAD_IMPL(pg, a, ...) svld1_f32(pg, a)
+ #define GGML_F32xt_LOAD(...) GGML_F32xt_LOAD_IMPL(DEFAULT_PG, __VA_ARGS__)
+ #define GGML_F32xt_STORE_IMPL(pg,a,b) svst1_f32(pg, a, b)
+ #define GGML_F32xt_STORE(...) GGML_F32xt_STORE_IMPL(DEFAULT_PG, __VA_ARGS__)
+ #define GGML_F32xt_FMA_IMPL(pg, a, b, c) svmad_f32_m(pg, a, b, c)
+ #define GGML_F32xt_FMA(...) GGML_F32xt_FMA_IMPL(DEFAULT_PG, __VA_ARGS__)
+ #define GGML_F32xt_ADD_IMPL(pg, a, b) svadd_f32_m(pg, a, b)
+ #define GGML_F32xt_ADD(...) GGML_F32xt_ADD_IMPL(DEFAULT_PG, __VA_ARGS__)
+ #define GGML_F32xt_MUL_IMPL(pg, a, b) svmul_f32_m(pg, a, b)
+ #define GGML_F32xt_MUL(...) GGML_F32xt_MUL_IMPL(DEFAULT_PG, __VA_ARGS__)
+ #define GGML_F32xt_REDUCE_ONE_IMPL(pg, a) svaddv(pg, a)
+ #define GGML_F32xt_REDUCE_ONE(...) GGML_F32xt_REDUCE_ONE_IMPL(DEFAULT_PG, __VA_ARGS__)
+ #define GGML_F32xt_REDUCE_IMPL(pg, res, sum1, sum2, sum3, sum4, sum5, sum6, sum7, sum8) \
+ { \
+ sum1 = svadd_f32_m(DEFAULT_PG, sum1, sum2); \
+ sum3 = svadd_f32_m(DEFAULT_PG, sum3, sum4); \
+ sum5 = svadd_f32_m(DEFAULT_PG, sum5, sum6); \
+ sum7 = svadd_f32_m(DEFAULT_PG, sum7, sum8); \
+ sum1 = svadd_f32_m(DEFAULT_PG, sum1, sum3); \
+ sum5 = svadd_f32_m(DEFAULT_PG, sum5, sum7); \
+ sum1 = svadd_f32_m(DEFAULT_PG, sum1, sum5); \
+ (res) = (ggml_float) GGML_F32xt_REDUCE_ONE(sum1); \
+ }
+ #define GGML_F32xt_REDUCE(...) GGML_F32xt_REDUCE_IMPL(DEFAULT_PG, __VA_ARGS__)
+
+ #define GGML_F32_VEC GGML_F32xt
+ #define GGML_F32_VEC_ZERO GGML_F32xt_ZERO
+ #define GGML_F32_VEC_SET1 GGML_F32xt_SET1
+ #define GGML_F32_VEC_LOAD GGML_F32xt_LOAD
+ #define GGML_F32_VEC_STORE GGML_F32xt_STORE
+ #define GGML_F32_VEC_FMA GGML_F32xt_FMA
+ #define GGML_F32_VEC_ADD GGML_F32xt_ADD
+ #define GGML_F32_VEC_MUL GGML_F32xt_MUL
+ #define GGML_F32_VEC_REDUCE GGML_F32xt_REDUCE
+
+ // F16 NEON
+
+ #if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
+ #define GGML_F16_STEP 32
+ #define GGML_F16_EPR 8
+
+ #define GGML_F16x8 float16x8_t
+ #define GGML_F16x8_ZERO vdupq_n_f16(0.0f)
+ #define GGML_F16x8_SET1(x) vdupq_n_f16(x)
+ #define GGML_F16x8_LOAD(x) vld1q_f16((const __fp16 *)(x))
+ #define GGML_F16x8_STORE vst1q_f16
+ #define GGML_F16x8_FMA(a, b, c) vfmaq_f16(a, b, c)
+ #define GGML_F16x8_ADD vaddq_f16
+ #define GGML_F16x8_MUL vmulq_f16
+ #define GGML_F16x8_REDUCE(res, x) \
+ do { \
+ int offset = GGML_F16_ARR >> 1; \
+ for (int i = 0; i < offset; ++i) { \
+ (x)[i] = vaddq_f16((x)[i], (x)[offset+i]); \
+ } \
+ offset >>= 1; \
+ for (int i = 0; i < offset; ++i) { \
+ (x)[i] = vaddq_f16((x)[i], (x)[offset+i]); \
+ } \
+ offset >>= 1; \
+ for (int i = 0; i < offset; ++i) { \
+ (x)[i] = vaddq_f16((x)[i], (x)[offset+i]); \
+ } \
+ const float32x4_t t0 = vcvt_f32_f16(vget_low_f16 ((x)[0])); \
+ const float32x4_t t1 = vcvt_f32_f16(vget_high_f16((x)[0])); \
+ (res) = (ggml_float) vaddvq_f32(vaddq_f32(t0, t1)); \
+ } while (0)
+
+ #define GGML_F16_VEC GGML_F16x8
+ #define GGML_F16_VEC_ZERO GGML_F16x8_ZERO
+ #define GGML_F16_VEC_SET1 GGML_F16x8_SET1
+ #define GGML_F16_VEC_LOAD(p, i) GGML_F16x8_LOAD(p)
+ #define GGML_F16_VEC_STORE(p, r, i) GGML_F16x8_STORE((__fp16 *)(p), (r)[i])
+ #define GGML_F16_VEC_FMA GGML_F16x8_FMA
+ #define GGML_F16_VEC_ADD GGML_F16x8_ADD
+ #define GGML_F16_VEC_MUL GGML_F16x8_MUL
+ #define GGML_F16_VEC_REDUCE GGML_F16x8_REDUCE
+ #else
+ // if FP16 vector arithmetic is not supported, we use FP32 instead
+ // and take advantage of the vcvt_ functions to convert to/from FP16
+
+ #define GGML_F16_STEP 16
+ #define GGML_F16_EPR 4
+
+ #define GGML_F32Cx4 float32x4_t
+ #define GGML_F32Cx4_ZERO vdupq_n_f32(0.0f)
+ #define GGML_F32Cx4_SET1(x) vdupq_n_f32(x)
+ #define GGML_F32Cx4_LOAD(x) vcvt_f32_f16(vld1_f16((const __fp16 *)(x)))
+ #define GGML_F32Cx4_STORE(x, y) vst1_f16(x, vcvt_f16_f32(y))
+ #define GGML_F32Cx4_FMA(a, b, c) vfmaq_f32(a, b, c)
+ #define GGML_F32Cx4_ADD vaddq_f32
+ #define GGML_F32Cx4_MUL vmulq_f32
+ #define GGML_F32Cx4_REDUCE GGML_F32x4_REDUCE
+
+ #define GGML_F16_VEC GGML_F32Cx4
+ #define GGML_F16_VEC_ZERO GGML_F32Cx4_ZERO
+ #define GGML_F16_VEC_SET1 GGML_F32Cx4_SET1
+ #define GGML_F16_VEC_LOAD(p, i) GGML_F32Cx4_LOAD(p)
+ #define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx4_STORE((__fp16 *)(p), r[i])
+ #define GGML_F16_VEC_FMA GGML_F32Cx4_FMA
+ #define GGML_F16_VEC_ADD GGML_F32Cx4_ADD
+ #define GGML_F16_VEC_MUL GGML_F32Cx4_MUL
+ #define GGML_F16_VEC_REDUCE GGML_F32Cx4_REDUCE
+ #endif
+
+ #elif defined(__ARM_NEON) && defined(__ARM_FEATURE_FMA)

  #define GGML_SIMD

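
The GGML_F32xt_* mappings above all funnel a governing predicate into the underlying SVE intrinsic, with svptrue_b32() supplied as the default. A minimal standalone sketch of that pattern (not part of the diff; assumes an AArch64 compiler with SVE enabled, e.g. -march=armv8-a+sve, and that a and b hold at least one full vector of floats):

    #include <arm_sve.h>

    static float fma_then_reduce(const float *a, const float *b, float acc_scalar) {
        const svbool_t pg = svptrue_b32();          // DEFAULT_PG in the mapping above
        svfloat32_t acc = svdup_n_f32(acc_scalar);  // GGML_F32xt_SET1
        svfloat32_t va  = svld1_f32(pg, a);         // GGML_F32xt_LOAD
        svfloat32_t vb  = svld1_f32(pg, b);         // GGML_F32xt_LOAD
        acc = svmad_f32_m(pg, va, vb, acc);         // GGML_F32xt_FMA: acc = va * vb + acc
        return svaddv_f32(pg, acc);                 // GGML_F32xt_REDUCE_ONE
    }
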
@@ -828,10 +944,8 @@ static inline void __lsx_f16x4_store(ggml_fp16_t * x, __m128 y) {
  for (int i = 0; i < offset; ++i) { \
  x[i] = vec_add(x[i], x[offset + i]); \
  } \
- res = vec_extract(x[0], 0) + \
- vec_extract(x[0], 1) + \
- vec_extract(x[0], 2) + \
- vec_extract(x[0], 3); \
+ float32x4_t tmp = x[0] + vec_reve(x[0]); \
+ res = tmp[0] + tmp[1]; \
  }

  #define GGML_F32_VEC GGML_F32x4
@@ -1,4 +1,4 @@
- #include "ggml-cpu-traits.h"
+ #include "traits.h"

  #include "ggml-backend-impl.h"
  #include "ggml-backend.h"
@@ -17,29 +17,98 @@ void ggml_vec_dot_f32(int n, float * GGML_RESTRICT s, size_t bs, const float * G

  #if defined(GGML_SIMD)
  float sumf = 0.0f;
- const int np = (n & ~(GGML_F32_STEP - 1));

- GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
+ #if defined(__ARM_FEATURE_SVE)
+ const int sve_register_length = ggml_cpu_get_sve_cnt() * 8;
+ const int ggml_f32_epr = sve_register_length / 32;//8;//svcntw(); // SVE128:4, SVE256:8, SVE512:16
+ const int ggml_f32_step = 8 * ggml_f32_epr; // choose 8 SVE registers
+
+ const int np = (n & ~(ggml_f32_step - 1));
+ svfloat32_t sum1 = svdup_n_f32(0.0f);
+ svfloat32_t sum2 = svdup_n_f32(0.0f);
+ svfloat32_t sum3 = svdup_n_f32(0.0f);
+ svfloat32_t sum4 = svdup_n_f32(0.0f);
+ svfloat32_t sum5 = svdup_n_f32(0.0f);
+ svfloat32_t sum6 = svdup_n_f32(0.0f);
+ svfloat32_t sum7 = svdup_n_f32(0.0f);
+ svfloat32_t sum8 = svdup_n_f32(0.0f);
+ svfloat32_t ax1,ax2,ax3,ax4,ax5,ax6,ax7,ax8;
+ svfloat32_t ay1,ay2,ay3,ay4,ay5,ay6,ay7,ay8;
+ for (int i = 0; i < np; i += ggml_f32_step) {
+ ax1 = GGML_F32_VEC_LOAD(x + i);
+ ay1 = GGML_F32_VEC_LOAD(y + i);
+ sum1 = GGML_F32_VEC_FMA(ax1, ay1, sum1);
+
+ ax2 = GGML_F32_VEC_LOAD(x + i + 1*ggml_f32_epr);
+ ay2 = GGML_F32_VEC_LOAD(y + i + 1*ggml_f32_epr);
+ sum2 = GGML_F32_VEC_FMA(ax2, ay2, sum2);
+
+ ax3 = GGML_F32_VEC_LOAD(x + i + 2*ggml_f32_epr);
+ ay3 = GGML_F32_VEC_LOAD(y + i + 2*ggml_f32_epr);
+ sum3 = GGML_F32_VEC_FMA(ax3, ay3, sum3);
+
+ ax4 = GGML_F32_VEC_LOAD(x + i + 3*ggml_f32_epr);
+ ay4 = GGML_F32_VEC_LOAD(y + i + 3*ggml_f32_epr);
+ sum4 = GGML_F32_VEC_FMA(ax4, ay4, sum4);
+
+ ax5 = GGML_F32_VEC_LOAD(x + i + 4*ggml_f32_epr);
+ ay5 = GGML_F32_VEC_LOAD(y + i + 4*ggml_f32_epr);
+ sum5 = GGML_F32_VEC_FMA(ax5, ay5, sum5);
+
+ ax6 = GGML_F32_VEC_LOAD(x + i + 5*ggml_f32_epr);
+ ay6 = GGML_F32_VEC_LOAD(y + i + 5*ggml_f32_epr);
+ sum6 = GGML_F32_VEC_FMA(ax6, ay6, sum6);
+
+ ax7 = GGML_F32_VEC_LOAD(x + i + 6*ggml_f32_epr);
+ ay7 = GGML_F32_VEC_LOAD(y + i + 6*ggml_f32_epr);
+ sum7 = GGML_F32_VEC_FMA(ax7, ay7, sum7);
+
+ ax8 = GGML_F32_VEC_LOAD(x + i + 7*ggml_f32_epr);
+ ay8 = GGML_F32_VEC_LOAD(y + i + 7*ggml_f32_epr);
+ sum8 = GGML_F32_VEC_FMA(ax8, ay8, sum8);
+ }
+ // leftovers
+ // Since 8 unrolls are done in above loop, leftovers lie in range [0, ggml_f32_step] which is handled in below loop
+ const int np2 = (n & ~(ggml_f32_epr - 1));
+ for (int i = np; i < np2; i += ggml_f32_epr) {
+ ax1 = GGML_F32_VEC_LOAD(x + i);
+ ay1 = GGML_F32_VEC_LOAD(y + i);
+ sum1 = GGML_F32_VEC_FMA(ax1, ay1, sum1);
+ }
+ // maximum number of leftover elements will be less that ggml_f32_epr. Apply predicated svmad on available elements only
+ if (np2 < n) {
+ svbool_t pg = svwhilelt_b32(np2, n);
+ ax1 = svld1_f32(pg, x + np2);
+ ay1 = svld1_f32(pg, y + np2);
+ sum1 = svmad_f32_m(pg, ax1, ay1, sum1);
+ }
+ // reduce sum1,sum2 to sum1
+ GGML_F32_VEC_REDUCE(sumf, sum1, sum2, sum3, sum4, sum5, sum6, sum7, sum8);
+ #else
+ const int np = (n & ~(GGML_F32_STEP - 1));

- GGML_F32_VEC ax[GGML_F32_ARR];
- GGML_F32_VEC ay[GGML_F32_ARR];
+ GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };

- for (int i = 0; i < np; i += GGML_F32_STEP) {
- for (int j = 0; j < GGML_F32_ARR; j++) {
- ax[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR);
- ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR);
+ GGML_F32_VEC ax[GGML_F32_ARR];
+ GGML_F32_VEC ay[GGML_F32_ARR];

- sum[j] = GGML_F32_VEC_FMA(sum[j], ax[j], ay[j]);
+ for (int i = 0; i < np; i += GGML_F32_STEP) {
+ for (int j = 0; j < GGML_F32_ARR; j++) {
+ ax[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR);
+ ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR);
+
+ sum[j] = GGML_F32_VEC_FMA(sum[j], ax[j], ay[j]);
+ }
  }
- }

- // reduce sum0..sum3 to sum0
- GGML_F32_VEC_REDUCE(sumf, sum);
+ // reduce sum0..sum3 to sum0
+ GGML_F32_VEC_REDUCE(sumf, sum);

- // leftovers
- for (int i = np; i < n; ++i) {
- sumf += x[i]*y[i];
- }
+ // leftovers
+ for (int i = np; i < n; ++i) {
+ sumf += x[i]*y[i];
+ }
+ #endif
  #else
  // scalar
  ggml_float sumf = 0.0;
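
The new SVE path above, and the analogous ggml_vec_mad_f32 and ggml_vec_scale_f32 paths later in this diff, follow a three-stage shape: an 8-way unrolled main loop over full vectors, a single-vector loop for what the unrolling missed, and a predicated tail via svwhilelt so no scalar cleanup loop is needed. A simplified single-accumulator sketch of that structure (not part of the diff; same SVE toolchain assumptions as the earlier sketch):

    #include <arm_sve.h>

    static float dot_f32_sve_sketch(int n, const float *x, const float *y) {
        const int epr = (int) svcntw();              // f32 lanes per SVE register
        const int np  = n & ~(epr - 1);              // elements covered by full vectors

        const svbool_t pg_all = svptrue_b32();
        svfloat32_t sum = svdup_n_f32(0.0f);

        for (int i = 0; i < np; i += epr) {          // full-width vector loop
            svfloat32_t ax = svld1_f32(pg_all, x + i);
            svfloat32_t ay = svld1_f32(pg_all, y + i);
            sum = svmla_f32_m(pg_all, sum, ax, ay);  // sum += ax * ay
        }
        if (np < n) {                                // predicated tail: only valid lanes active
            svbool_t pg = svwhilelt_b32(np, n);
            svfloat32_t ax = svld1_f32(pg, x + np);
            svfloat32_t ay = svld1_f32(pg, y + np);
            sum = svmla_f32_m(pg, sum, ax, ay);      // inactive lanes keep their old value
        }
        return svaddv_f32(pg_all, sum);              // horizontal reduction
    }
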
@@ -5,6 +5,7 @@
  #include "ggml-impl.h"
  #include "simd-mappings.h"
  #include "ggml.h"
+ #include "ggml-cpu.h"

  #if defined(GGML_USE_ACCELERATE)
  #include <Accelerate/Accelerate.h>
@@ -148,27 +149,108 @@ inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * GG

  inline static void ggml_vec_mad_f32(const int n, float * GGML_RESTRICT y, const float * GGML_RESTRICT x, const float v) {
  #if defined(GGML_SIMD)
- const int np = (n & ~(GGML_F32_STEP - 1));
+ #if defined(__ARM_FEATURE_SVE)

- GGML_F32_VEC vx = GGML_F32_VEC_SET1(v);
+ const int sve_register_length = ggml_cpu_get_sve_cnt() * 8;
+ const int ggml_f32_epr = sve_register_length / 32;//8;//svcntw(); // SVE128:4, SVE256:8, SVE512:16
+ const int ggml_f32_step = 8 * ggml_f32_epr; // choose 8 SVE registers
+ GGML_F32_VEC vx = GGML_F32_VEC_SET1(v);

- GGML_F32_VEC ax[GGML_F32_ARR];
- GGML_F32_VEC ay[GGML_F32_ARR];
+ const int np = (n & ~(ggml_f32_step - 1));
+ svfloat32_t ax1, ax2, ax3, ax4, ax5, ax6, ax7, ax8;
+ svfloat32_t ay1, ay2, ay3, ay4, ay5, ay6, ay7, ay8;
+ for (int i = 0; i < np; i += ggml_f32_step) {

- for (int i = 0; i < np; i += GGML_F32_STEP) {
- for (int j = 0; j < GGML_F32_ARR; j++) {
- ax[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR);
- ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR);
- ay[j] = GGML_F32_VEC_FMA(ay[j], ax[j], vx);
+ ax1 = GGML_F32_VEC_LOAD(x + i);
+ ay1 = GGML_F32_VEC_LOAD(y + i);
+ ay1 = GGML_F32_VEC_FMA(ax1, vx, ay1);

- GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]);
+ GGML_F32_VEC_STORE(y + i, ay1);
+
+ ax2 = GGML_F32_VEC_LOAD(x + i + 1*ggml_f32_epr);
+ ay2 = GGML_F32_VEC_LOAD(y + i + 1*ggml_f32_epr);
+ ay2 = GGML_F32_VEC_FMA(ax2, vx, ay2);
+
+ GGML_F32_VEC_STORE(y + i + 1*ggml_f32_epr, ay2);
+
+ ax3 = GGML_F32_VEC_LOAD(x + i + 2*ggml_f32_epr);
+ ay3 = GGML_F32_VEC_LOAD(y + i + 2*ggml_f32_epr);
+ ay3 = GGML_F32_VEC_FMA(ax3, vx, ay3);
+
+ GGML_F32_VEC_STORE(y + i + 2*ggml_f32_epr, ay3);
+
+ ax4 = GGML_F32_VEC_LOAD(x + i + 3*ggml_f32_epr);
+ ay4 = GGML_F32_VEC_LOAD(y + i + 3*ggml_f32_epr);
+ ay4 = GGML_F32_VEC_FMA(ax4, vx, ay4);
+
+ GGML_F32_VEC_STORE(y + i + 3*ggml_f32_epr, ay4);
+
+ ax5 = GGML_F32_VEC_LOAD(x + i + 4*ggml_f32_epr);
+ ay5 = GGML_F32_VEC_LOAD(y + i + 4*ggml_f32_epr);
+ ay5 = GGML_F32_VEC_FMA(ax5, vx, ay5);
+
+ GGML_F32_VEC_STORE(y + i + 4*ggml_f32_epr, ay5);
+
+ ax6 = GGML_F32_VEC_LOAD(x + i + 5*ggml_f32_epr);
+ ay6 = GGML_F32_VEC_LOAD(y + i + 5*ggml_f32_epr);
+ ay6 = GGML_F32_VEC_FMA(ax6, vx, ay6);
+
+ GGML_F32_VEC_STORE(y + i + 5*ggml_f32_epr, ay6);
+
+ ax7 = GGML_F32_VEC_LOAD(x + i + 6*ggml_f32_epr);
+ ay7 = GGML_F32_VEC_LOAD(y + i + 6*ggml_f32_epr);
+ ay7 = GGML_F32_VEC_FMA(ax7, vx, ay7);
+
+ GGML_F32_VEC_STORE(y + i + 6*ggml_f32_epr, ay7);
+
+ ax8 = GGML_F32_VEC_LOAD(x + i + 7*ggml_f32_epr);
+ ay8 = GGML_F32_VEC_LOAD(y + i + 7*ggml_f32_epr);
+ ay8 = GGML_F32_VEC_FMA(ax8, vx, ay8);
+
+ GGML_F32_VEC_STORE(y + i + 7*ggml_f32_epr, ay8);
  }
- }
+ // leftovers
+ // Since 8 unrolls are done in above loop, leftovers lie in range [0, ggml_f32_step] which is handled in below loop
+ const int np2 = (n & ~(ggml_f32_epr - 1));
+ for (int i = np; i < np2; i += ggml_f32_epr) {
+ ax1 = GGML_F32_VEC_LOAD(x + i);
+ ay1 = GGML_F32_VEC_LOAD(y + i);
+ ay1 = GGML_F32_VEC_FMA(ax1, vx, ay1);
+
+ GGML_F32_VEC_STORE(y + i, ay1);
+ }
+ // maximum number of leftover elements will be less that ggml_f32_epr. Apply predicated svmad on available elements only
+ if (np2 < n) {
+ svbool_t pg =svwhilelt_b32(np2, n);
+ ax1 = svld1_f32(pg, x + np2);
+ ay1 = svld1_f32(pg, y + np2);
+ ay1 = svmad_f32_m(pg, ax1, vx, ay1);
+
+ svst1_f32(pg, y + np2, ay1);
+ }
+ #else
+ const int np = (n & ~(GGML_F32_STEP - 1));

- // leftovers
- for (int i = np; i < n; ++i) {
- y[i] += x[i]*v;
- }
+ GGML_F32_VEC vx = GGML_F32_VEC_SET1(v);
+
+ GGML_F32_VEC ax[GGML_F32_ARR];
+ GGML_F32_VEC ay[GGML_F32_ARR];
+
+ for (int i = 0; i < np; i += GGML_F32_STEP) {
+ for (int j = 0; j < GGML_F32_ARR; j++) {
+ ax[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR);
+ ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR);
+ ay[j] = GGML_F32_VEC_FMA(ay[j], ax[j], vx);
+
+ GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]);
+ }
+ }
+
+ // leftovers
+ for (int i = np; i < n; ++i) {
+ y[i] += x[i]*v;
+ }
+ #endif
  #else
  // scalar
  for (int i = 0; i < n; ++i) {
@@ -220,36 +302,45 @@ inline static void ggml_vec_mad_f32_unroll(const int n, const int xs, const int
  }

  #if defined(GGML_SIMD)
- const int np = (n & ~(GGML_F32_STEP - 1));
+ #if defined(__ARM_FEATURE_SVE)
+ // scalar Route to scalar implementation //TODO: Write SVE code
+ for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) {
+ for (int i = 0; i < n; ++i) {
+ y[i] += x[k][i]*v[k][0];
+ }
+ }
+ #else
+ const int np = (n & ~(GGML_F32_STEP - 1));

- GGML_F32_VEC vx[GGML_VEC_MAD_UNROLL];
+ GGML_F32_VEC vx[GGML_VEC_MAD_UNROLL];

- for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) {
- vx[k] = GGML_F32_VEC_SET1(v[k][0]);
- }
+ for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) {
+ vx[k] = GGML_F32_VEC_SET1(v[k][0]);
+ }

- GGML_F32_VEC ax[GGML_VEC_MAD_UNROLL][GGML_F32_ARR];
- GGML_F32_VEC ay[GGML_F32_ARR];
+ GGML_F32_VEC ax[GGML_VEC_MAD_UNROLL][GGML_F32_ARR];
+ GGML_F32_VEC ay[GGML_F32_ARR];

- for (int i = 0; i < np; i += GGML_F32_STEP) {
- for (int j = 0; j < GGML_F32_ARR; j++) {
- ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR);
+ for (int i = 0; i < np; i += GGML_F32_STEP) {
+ for (int j = 0; j < GGML_F32_ARR; j++) {
+ ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR);

- for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) {
- ax[k][j] = GGML_F32_VEC_LOAD(x[k] + i + j*GGML_F32_EPR);
- ay[j] = GGML_F32_VEC_FMA(ay[j], ax[k][j], vx[k]);
- }
+ for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) {
+ ax[k][j] = GGML_F32_VEC_LOAD(x[k] + i + j*GGML_F32_EPR);
+ ay[j] = GGML_F32_VEC_FMA(ay[j], ax[k][j], vx[k]);
+ }

- GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]);
+ GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]);
+ }
  }
- }

- // leftovers
- for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) {
- for (int i = np; i < n; ++i) {
- y[i] += x[k][i]*v[k][0];
+ // leftovers
+ for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) {
+ for (int i = np; i < n; ++i) {
+ y[i] += x[k][i]*v[k][0];
+ }
  }
- }
+ #endif
  #else
  // scalar
  for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) {
@@ -265,25 +356,53 @@ inline static void ggml_vec_scale_f32(const int n, float * y, const float v) {
  #if defined(GGML_USE_ACCELERATE)
  vDSP_vsmul(y, 1, &v, y, 1, n);
  #elif defined(GGML_SIMD)
- const int np = (n & ~(GGML_F32_STEP - 1));
+ #if defined(__ARM_FEATURE_SVE)
+ const int sve_register_length = ggml_cpu_get_sve_cnt() * 8;
+ const int ggml_f32_epr = sve_register_length / 32;//8;//svcntw(); // SVE128:4, SVE256:8, SVE512:16
+ const int ggml_f32_step = 2 * ggml_f32_epr;
+
+ GGML_F32_VEC vx = GGML_F32_VEC_SET1(v);
+ const int np = (n & ~(ggml_f32_step - 1));
+ svfloat32_t ay1;
+ svfloat32_t ay2;
+ for (int i = 0; i < np; i += ggml_f32_step) {
+ ay1 = GGML_F32_VEC_LOAD(y + i);
+ ay1 = GGML_F32_VEC_MUL(ay1, vx);
+ GGML_F32_VEC_STORE(y + i, ay1);
+
+ ay2 = GGML_F32_VEC_LOAD(y + i + 1*ggml_f32_epr);
+ ay2 = GGML_F32_VEC_MUL(ay2, vx);
+ GGML_F32_VEC_STORE(y + i + 1*ggml_f32_epr, ay2);
+ }
+ // leftovers
+ // maximum number of leftover elements will be less that ggml_f32_epr. Apply predicated svmad on available elements only
+ if (np < n) {
+ svbool_t pg = svwhilelt_b32(np, n);
+ ay1 = svld1_f32(pg, y + np);
+ ay1 = svmul_f32_m(pg, ay1, vx);
+ svst1_f32(pg, y + np, ay1);
+ }
+ #else
+ const int np = (n & ~(GGML_F32_STEP - 1));

- GGML_F32_VEC vx = GGML_F32_VEC_SET1(v);
+ GGML_F32_VEC vx = GGML_F32_VEC_SET1(v);

- GGML_F32_VEC ay[GGML_F32_ARR];
+ GGML_F32_VEC ay[GGML_F32_ARR];

- for (int i = 0; i < np; i += GGML_F32_STEP) {
- for (int j = 0; j < GGML_F32_ARR; j++) {
- ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR);
- ay[j] = GGML_F32_VEC_MUL(ay[j], vx);
+ for (int i = 0; i < np; i += GGML_F32_STEP) {
+ for (int j = 0; j < GGML_F32_ARR; j++) {
+ ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR);
+ ay[j] = GGML_F32_VEC_MUL(ay[j], vx);

- GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]);
+ GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]);
+ }
  }
- }

- // leftovers
- for (int i = np; i < n; ++i) {
- y[i] *= v;
- }
+ // leftovers
+ for (int i = np; i < n; ++i) {
+ y[i] *= v;
+ }
+ #endif
  #else
  // scalar
  for (int i = 0; i < n; ++i) {
@@ -528,6 +647,42 @@ inline static ggml_fp16_t ggml_silu_f16(ggml_fp16_t x) {
  #error "ref: https://github.com/ggml-org/llama.cpp/pull/7154#issuecomment-2143844461"
  #endif

+ /* Below function was borrowed from the GitHub repository:
+ https://github.com/openvinotoolkit/openvino/blob/master/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/common.hpp */
+ #if defined(__ARM_FEATURE_SVE) && defined(__aarch64__)
+ inline static svfloat32_t exp_ps_sve(svbool_t pg, svfloat32_t src) {
+ // Constants
+ const svfloat32_t log2_e = svdup_n_f32(1.4426950409f);
+ const svfloat32_t ln2 = svdup_n_f32(0.6931473921f);
+ const svfloat32_t half_ln2_sq = svdup_n_f32(0.2413862043f);
+ const svuint32_t not_mask17 = svdup_n_u32(~((1u << 17) - 1));
+ const svfloat32_t one = svdup_n_f32(1.0f);
+ const svfloat32_t inactive1 = svdup_n_f32(0.0f);
+ const svint32_t inactive2 = svdup_n_s32(0);
+
+ // Algorithm starts here
+ svfloat32_t t0 = svmul_f32_m(pg, src, log2_e); // y = x * log2(e)
+ svfloat32_t t1 = svrintm_f32_m(inactive1, pg, t0); // rount to int (float)
+ svint32_t t2 = svcvt_s32_f32_m(inactive2, pg, t1); // n
+
+ t1 = svsub_f32_m(pg, t0, t1); // a = y - floor(y)
+ t1 = svadd_f32_m(pg, t1, one); // b = a + 1
+
+ svuint32_t t3 = svlsr_n_u32_m(pg, svreinterpret_u32_f32(t1), 17); // v = b >> 17 (u32)
+ svfloat32_t t4 = svexpa_f32(t3); // c = fexpa(v)
+ t4 = svscale_f32_m(pg, t4, t2); // fexpa(v) * 2^(n)
+
+ // and_(t2.d, t1.d, not_mask17.d)
+ svfloat32_t t5 = svreinterpret_f32_u32(svand_u32_m(pg, svreinterpret_u32_f32(t1), not_mask17));
+ t5 = svsub_f32_m(pg, t1, t5); // z
+ t0 = svmla_f32_m(pg, ln2, t5, half_ln2_sq); // ln2 + half_ln2_sq * z
+ t0 = svmla_f32_m(pg, one, t5, t0); // 1 + (ln2 * z) + (half_ln2_sq * z * z)
+ t0 = svmul_f32_m(pg, t0, t4); // Final result
+
+ return t0;
+ }
+ #endif
+
  #if defined(__ARM_NEON) && defined(__aarch64__)

  // adapted from arm limited optimized routine
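
For reference, the range reduction behind exp_ps_sve is the usual one, sketched here; the ln2 and half_ln2_sq constants in the code are slightly tuned rather than exact Taylor coefficients:

    e^{x} = 2^{x \log_2 e} = 2^{n} \cdot 2^{a}, \qquad n = \lfloor x \log_2 e \rfloor, \quad a = x \log_2 e - n \in [0, 1)

    2^{a} \approx \mathrm{FEXPA}(v) \cdot \left( 1 + z \ln 2 + \tfrac{(\ln 2)^2}{2} z^{2} \right)

where v is a table index taken from the high mantissa bits of a + 1, z is the small residual those bits leave behind, and the final svscale_f32_m applies the 2^n factor.
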