whisper.rn 0.5.0-rc.9 → 0.5.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (136) hide show
  1. package/android/build.gradle +2 -1
  2. package/android/gradle.properties +1 -1
  3. package/cpp/ggml-alloc.c +265 -141
  4. package/cpp/ggml-backend-impl.h +4 -1
  5. package/cpp/ggml-backend-reg.cpp +30 -13
  6. package/cpp/ggml-backend.cpp +221 -38
  7. package/cpp/ggml-backend.h +17 -1
  8. package/cpp/ggml-common.h +17 -0
  9. package/cpp/ggml-cpu/amx/amx.cpp +4 -2
  10. package/cpp/ggml-cpu/arch/arm/quants.c +132 -596
  11. package/cpp/ggml-cpu/arch/arm/repack.cpp +14 -286
  12. package/cpp/ggml-cpu/arch/x86/quants.c +184 -675
  13. package/cpp/ggml-cpu/arch/x86/repack.cpp +4679 -1657
  14. package/cpp/ggml-cpu/arch-fallback.h +32 -2
  15. package/cpp/ggml-cpu/common.h +14 -0
  16. package/cpp/ggml-cpu/ggml-cpu-impl.h +13 -6
  17. package/cpp/ggml-cpu/ggml-cpu.c +70 -42
  18. package/cpp/ggml-cpu/ggml-cpu.cpp +35 -28
  19. package/cpp/ggml-cpu/ops.cpp +1587 -1177
  20. package/cpp/ggml-cpu/ops.h +5 -8
  21. package/cpp/ggml-cpu/quants.c +35 -0
  22. package/cpp/ggml-cpu/quants.h +8 -0
  23. package/cpp/ggml-cpu/repack.cpp +458 -47
  24. package/cpp/ggml-cpu/repack.h +22 -0
  25. package/cpp/ggml-cpu/simd-mappings.h +89 -60
  26. package/cpp/ggml-cpu/traits.cpp +2 -2
  27. package/cpp/ggml-cpu/traits.h +1 -1
  28. package/cpp/ggml-cpu/vec.cpp +170 -26
  29. package/cpp/ggml-cpu/vec.h +506 -63
  30. package/cpp/ggml-cpu.h +1 -1
  31. package/cpp/ggml-impl.h +119 -9
  32. package/cpp/ggml-metal/ggml-metal-common.cpp +446 -0
  33. package/cpp/ggml-metal/ggml-metal-common.h +52 -0
  34. package/cpp/ggml-metal/ggml-metal-context.h +33 -0
  35. package/cpp/ggml-metal/ggml-metal-context.m +600 -0
  36. package/cpp/ggml-metal/ggml-metal-device.cpp +1376 -0
  37. package/cpp/ggml-metal/ggml-metal-device.h +226 -0
  38. package/cpp/ggml-metal/ggml-metal-device.m +1312 -0
  39. package/cpp/ggml-metal/ggml-metal-impl.h +722 -0
  40. package/cpp/ggml-metal/ggml-metal-ops.cpp +3158 -0
  41. package/cpp/ggml-metal/ggml-metal-ops.h +82 -0
  42. package/cpp/ggml-metal/ggml-metal.cpp +718 -0
  43. package/cpp/ggml-metal/ggml-whisper-sim.metallib +0 -0
  44. package/cpp/ggml-metal/ggml-whisper.metallib +0 -0
  45. package/cpp/ggml-metal-impl.h +90 -51
  46. package/cpp/ggml-metal.h +1 -6
  47. package/cpp/ggml-opt.cpp +97 -41
  48. package/cpp/ggml-opt.h +25 -6
  49. package/cpp/ggml-quants.c +111 -16
  50. package/cpp/ggml-quants.h +6 -0
  51. package/cpp/ggml.c +486 -98
  52. package/cpp/ggml.h +221 -16
  53. package/cpp/gguf.cpp +8 -1
  54. package/cpp/jsi/RNWhisperJSI.cpp +25 -6
  55. package/cpp/jsi/ThreadPool.h +3 -3
  56. package/cpp/whisper.cpp +100 -76
  57. package/cpp/whisper.h +1 -0
  58. package/ios/CMakeLists.txt +6 -1
  59. package/ios/RNWhisper.mm +6 -6
  60. package/ios/RNWhisperContext.mm +2 -0
  61. package/ios/RNWhisperVadContext.mm +16 -13
  62. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend-impl.h +4 -1
  63. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend.h +17 -1
  64. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-common.h +17 -0
  65. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-cpu.h +1 -1
  66. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-impl.h +119 -9
  67. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-metal-impl.h +90 -51
  68. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-metal.h +1 -6
  69. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-opt.h +25 -6
  70. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-quants.h +6 -0
  71. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml.h +221 -16
  72. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/whisper.h +1 -0
  73. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Info.plist +0 -0
  74. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
  75. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/rnwhisper +0 -0
  76. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend-impl.h +4 -1
  77. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +17 -1
  78. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-common.h +17 -0
  79. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpu.h +1 -1
  80. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +119 -9
  81. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal-impl.h +90 -51
  82. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal.h +1 -6
  83. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-opt.h +25 -6
  84. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-quants.h +6 -0
  85. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +221 -16
  86. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/whisper.h +1 -0
  87. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
  88. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +1 -1
  89. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
  90. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
  91. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend-impl.h +4 -1
  92. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend.h +17 -1
  93. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-common.h +17 -0
  94. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-cpu.h +1 -1
  95. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-impl.h +119 -9
  96. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-metal-impl.h +90 -51
  97. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-metal.h +1 -6
  98. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-opt.h +25 -6
  99. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-quants.h +6 -0
  100. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml.h +221 -16
  101. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/whisper.h +1 -0
  102. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Info.plist +0 -0
  103. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
  104. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/rnwhisper +0 -0
  105. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend-impl.h +4 -1
  106. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +17 -1
  107. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-common.h +17 -0
  108. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpu.h +1 -1
  109. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +119 -9
  110. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal-impl.h +90 -51
  111. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal.h +1 -6
  112. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-opt.h +25 -6
  113. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-quants.h +6 -0
  114. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +221 -16
  115. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/whisper.h +1 -0
  116. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
  117. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +1 -1
  118. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
  119. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
  120. package/lib/commonjs/realtime-transcription/RealtimeTranscriber.js +13 -0
  121. package/lib/commonjs/realtime-transcription/RealtimeTranscriber.js.map +1 -1
  122. package/lib/commonjs/version.json +1 -1
  123. package/lib/module/realtime-transcription/RealtimeTranscriber.js +13 -0
  124. package/lib/module/realtime-transcription/RealtimeTranscriber.js.map +1 -1
  125. package/lib/module/version.json +1 -1
  126. package/lib/typescript/realtime-transcription/RealtimeTranscriber.d.ts.map +1 -1
  127. package/lib/typescript/realtime-transcription/types.d.ts +6 -0
  128. package/lib/typescript/realtime-transcription/types.d.ts.map +1 -1
  129. package/package.json +1 -1
  130. package/src/realtime-transcription/RealtimeTranscriber.ts +17 -0
  131. package/src/realtime-transcription/types.ts +6 -0
  132. package/src/version.json +1 -1
  133. package/whisper-rn.podspec +8 -9
  134. package/cpp/ggml-metal.m +0 -6284
  135. package/cpp/ggml-whisper-sim.metallib +0 -0
  136. package/cpp/ggml-whisper.metallib +0 -0
@@ -55,7 +55,22 @@ inline static void wsp_ggml_vec_cpy_i32(const int n, int32_t * y, const int32_t
55
55
 
56
56
  inline static void wsp_ggml_vec_set_f16(const int n, wsp_ggml_fp16_t * x, const wsp_ggml_fp16_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
57
57
  inline static void wsp_ggml_vec_set_bf16(const int n, wsp_ggml_bf16_t * x, const wsp_ggml_bf16_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
58
- inline static void wsp_ggml_vec_add_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i] + y[i]; }
58
+
59
+ inline static void wsp_ggml_vec_add_f32 (const int n, float * z, const float * x, const float * y) {
60
+ int i = 0;
61
+ #if defined(__AVX2__)
62
+ for (; i + 7 < n; i += 8) {
63
+ __m256 vx = _mm256_loadu_ps(x + i);
64
+ __m256 vy = _mm256_loadu_ps(y + i);
65
+ __m256 vz = _mm256_add_ps(vx, vy);
66
+ _mm256_storeu_ps(z + i, vz);
67
+ }
68
+ #endif
69
+ for (; i < n; ++i) {
70
+ z[i] = x[i] + y[i];
71
+ }
72
+ }
73
+
59
74
  inline static void wsp_ggml_vec_add_f16 (const int n, wsp_ggml_fp16_t * z, const wsp_ggml_fp16_t * x, const wsp_ggml_fp16_t * y) {
60
75
  for (int i = 0; i < n; ++i) {
61
76
  z[i] = WSP_GGML_CPU_FP32_TO_FP16(WSP_GGML_CPU_FP16_TO_FP32(x[i]) + WSP_GGML_CPU_FP16_TO_FP32(y[i]));
@@ -104,36 +119,149 @@ inline static void wsp_ggml_vec_dot_f16_unroll(const int n, const int xs, float
104
119
  }
105
120
 
106
121
  #if defined(WSP_GGML_SIMD)
107
- const int np = (n & ~(WSP_GGML_F16_STEP - 1));
122
+ #if defined(__ARM_FEATURE_SVE)
123
+
124
+ const int sve_register_length = svcntb() * 8;
125
+ const int wsp_ggml_f16_epr = sve_register_length / 16; // running when 16
126
+ const int wsp_ggml_f16_step = 8 * wsp_ggml_f16_epr; // choose 8 SVE registers
127
+
128
+ const int np = (n & ~(wsp_ggml_f16_step - 1));
129
+
130
+ svfloat16_t sum_00 = svdup_n_f16(0.0f);
131
+ svfloat16_t sum_01 = svdup_n_f16(0.0f);
132
+ svfloat16_t sum_02 = svdup_n_f16(0.0f);
133
+ svfloat16_t sum_03 = svdup_n_f16(0.0f);
134
+
135
+ svfloat16_t sum_10 = svdup_n_f16(0.0f);
136
+ svfloat16_t sum_11 = svdup_n_f16(0.0f);
137
+ svfloat16_t sum_12 = svdup_n_f16(0.0f);
138
+ svfloat16_t sum_13 = svdup_n_f16(0.0f);
139
+
140
+ svfloat16_t ax1, ax2, ax3, ax4, ax5, ax6, ax7, ax8;
141
+ svfloat16_t ay1, ay2, ay3, ay4, ay5, ay6, ay7, ay8;
142
+
143
+ for (int i = 0; i < np; i += wsp_ggml_f16_step) {
144
+ ay1 = WSP_GGML_F16x_VEC_LOAD(y + i + 0 * wsp_ggml_f16_epr, 0); // 8 elements
145
+
146
+ ax1 = WSP_GGML_F16x_VEC_LOAD(x[0] + i + 0*wsp_ggml_f16_epr, 0); // 8 elements
147
+ sum_00 = WSP_GGML_F16x_VEC_FMA(sum_00, ax1, ay1); // sum_00 = sum_00+ax1*ay1
148
+ ax1 = WSP_GGML_F16x_VEC_LOAD(x[1] + i + 0*wsp_ggml_f16_epr, 0); // 8 elements
149
+ sum_10 = WSP_GGML_F16x_VEC_FMA(sum_10, ax1, ay1);
150
+
151
+ ay2 = WSP_GGML_F16x_VEC_LOAD(y + i + 1 * wsp_ggml_f16_epr, 1); // next 8 elements
152
+
153
+ ax2 = WSP_GGML_F16x_VEC_LOAD(x[0] + i + 1*wsp_ggml_f16_epr, 1); // next 8 elements
154
+ sum_01 = WSP_GGML_F16x_VEC_FMA(sum_01, ax2, ay2);
155
+ ax2 = WSP_GGML_F16x_VEC_LOAD(x[1] + i + 1*wsp_ggml_f16_epr, 1);
156
+ sum_11 = WSP_GGML_F16x_VEC_FMA(sum_11, ax2, ay2);
157
+
158
+ ay3 = WSP_GGML_F16x_VEC_LOAD(y + i + 2 * wsp_ggml_f16_epr, 2);
159
+
160
+ ax3 = WSP_GGML_F16x_VEC_LOAD(x[0] + i + 2*wsp_ggml_f16_epr, 2);
161
+ sum_02 = WSP_GGML_F16x_VEC_FMA(sum_02, ax3, ay3);
162
+ ax1 = WSP_GGML_F16x_VEC_LOAD(x[1] + i + 2*wsp_ggml_f16_epr, 2);
163
+ sum_12 = WSP_GGML_F16x_VEC_FMA(sum_12, ax3, ay3);
164
+
165
+ ay4 = WSP_GGML_F16x_VEC_LOAD(y + i + 3 * wsp_ggml_f16_epr, 3);
166
+
167
+ ax4 = WSP_GGML_F16x_VEC_LOAD(x[0] + i + 3*wsp_ggml_f16_epr, 3);
168
+ sum_03 = WSP_GGML_F16x_VEC_FMA(sum_03, ax4, ay4);
169
+ ax4 = WSP_GGML_F16x_VEC_LOAD(x[1] + i + 3*wsp_ggml_f16_epr, 3);
170
+ sum_13 = WSP_GGML_F16x_VEC_FMA(sum_13, ax4, ay4);
171
+
172
+ ay5 = WSP_GGML_F16x_VEC_LOAD(y + i + 4 * wsp_ggml_f16_epr, 4);
173
+
174
+ ax5 = WSP_GGML_F16x_VEC_LOAD(x[0] + i + 4*wsp_ggml_f16_epr, 4);
175
+
176
+ sum_00 = WSP_GGML_F16x_VEC_FMA(sum_00, ax5, ay5);
177
+ ax5 = WSP_GGML_F16x_VEC_LOAD(x[1] + i + 4*wsp_ggml_f16_epr, 4);
178
+ sum_10 = WSP_GGML_F16x_VEC_FMA(sum_10, ax5, ay5);
179
+
180
+ ay6 = WSP_GGML_F16x_VEC_LOAD(y + i + 5 * wsp_ggml_f16_epr, 5);
108
181
 
109
- WSP_GGML_F16_VEC sum[WSP_GGML_VEC_DOT_UNROLL][WSP_GGML_F16_ARR] = { { WSP_GGML_F16_VEC_ZERO } };
182
+ ax6 = WSP_GGML_F16x_VEC_LOAD(x[0] + i + 5*wsp_ggml_f16_epr, 5);
110
183
 
111
- WSP_GGML_F16_VEC ax[WSP_GGML_F16_ARR];
112
- WSP_GGML_F16_VEC ay[WSP_GGML_F16_ARR];
184
+ sum_01 = WSP_GGML_F16x_VEC_FMA(sum_01, ax6, ay6);
185
+ ax6 = WSP_GGML_F16x_VEC_LOAD(x[1] + i + 5*wsp_ggml_f16_epr, 5);
186
+ sum_11 = WSP_GGML_F16x_VEC_FMA(sum_11, ax6, ay6);
113
187
 
114
- for (int i = 0; i < np; i += WSP_GGML_F16_STEP) {
115
- for (int j = 0; j < WSP_GGML_F16_ARR; j++) {
116
- ay[j] = WSP_GGML_F16_VEC_LOAD(y + i + j*WSP_GGML_F16_EPR, j);
188
+ ay7 = WSP_GGML_F16x_VEC_LOAD(y + i + 6 * wsp_ggml_f16_epr, 6);
117
189
 
118
- for (int k = 0; k < WSP_GGML_VEC_DOT_UNROLL; ++k) {
119
- ax[j] = WSP_GGML_F16_VEC_LOAD(x[k] + i + j*WSP_GGML_F16_EPR, j);
190
+ ax7 = WSP_GGML_F16x_VEC_LOAD(x[0] + i + 6*wsp_ggml_f16_epr, 6);
120
191
 
121
- sum[k][j] = WSP_GGML_F16_VEC_FMA(sum[k][j], ax[j], ay[j]);
192
+ sum_02 = WSP_GGML_F16x_VEC_FMA(sum_02, ax7, ay7);
193
+ ax7 = WSP_GGML_F16x_VEC_LOAD(x[1] + i + 6*wsp_ggml_f16_epr, 6);
194
+ sum_12 = WSP_GGML_F16x_VEC_FMA(sum_12, ax7, ay7);
195
+
196
+ ay8 = WSP_GGML_F16x_VEC_LOAD(y + i + 7 * wsp_ggml_f16_epr, 7);
197
+
198
+ ax8 = WSP_GGML_F16x_VEC_LOAD(x[0] + i + 7*wsp_ggml_f16_epr, 7);
199
+
200
+ sum_03 = WSP_GGML_F16x_VEC_FMA(sum_03, ax8, ay8);
201
+ ax8 = WSP_GGML_F16x_VEC_LOAD(x[1] + i + 7*wsp_ggml_f16_epr, 7);
202
+ sum_13 = WSP_GGML_F16x_VEC_FMA(sum_13, ax8, ay8);
203
+ }
204
+
205
+ const int np2 = (n & ~(wsp_ggml_f16_epr - 1));
206
+ for (int k = np; k < np2; k += wsp_ggml_f16_epr) {
207
+ svfloat16_t ry = WSP_GGML_F16x_VEC_LOAD(y + k, 0);
208
+
209
+ svfloat16_t rx = WSP_GGML_F16x_VEC_LOAD(x[0] + k, 0);
210
+ sum_00 = WSP_GGML_F16x_VEC_FMA(sum_00, rx, ry);
211
+ rx = WSP_GGML_F16x_VEC_LOAD(x[1] + k, 0);
212
+ sum_10 = WSP_GGML_F16x_VEC_FMA(sum_10, rx, ry);
213
+ }
214
+
215
+ if (np2 < n) {
216
+ svbool_t pg = svwhilelt_b16(np2, n);
217
+ svfloat16_t hx_0 = svld1_f16(pg, (const __fp16 *)(x[0] + np2));
218
+ svfloat16_t hx_1 = svld1_f16(pg, (const __fp16 *)(x[1] + np2));
219
+ svfloat16_t hy = svld1_f16(pg, (const __fp16 *)(y + np2));
220
+
221
+ sum_00 = svmad_f16_x(pg, hx_0, hy, sum_00);
222
+ sum_10 = svmad_f16_x(pg, hx_1, hy, sum_10);
223
+ }
224
+ WSP_GGML_F16x_VEC_REDUCE(sumf[0], sum_00, sum_01, sum_02, sum_03);
225
+ WSP_GGML_F16x_VEC_REDUCE(sumf[1], sum_10, sum_11, sum_12, sum_13);
226
+ #elif defined(__riscv_v_intrinsic)
227
+ // todo: RVV impl
228
+ for (int i = 0; i < n; ++i) {
229
+ for (int j = 0; j < WSP_GGML_VEC_DOT_UNROLL; ++j) {
230
+ sumf[j] += (wsp_ggml_float)(WSP_GGML_CPU_FP16_TO_FP32(x[j][i])*WSP_GGML_CPU_FP16_TO_FP32(y[i]));
231
+ }
232
+ }
233
+ #else
234
+ const int np = (n & ~(WSP_GGML_F16_STEP - 1));
235
+
236
+ WSP_GGML_F16_VEC sum[WSP_GGML_VEC_DOT_UNROLL][WSP_GGML_F16_ARR] = { { WSP_GGML_F16_VEC_ZERO } };
237
+
238
+ WSP_GGML_F16_VEC ax[WSP_GGML_F16_ARR];
239
+ WSP_GGML_F16_VEC ay[WSP_GGML_F16_ARR];
240
+
241
+ for (int i = 0; i < np; i += WSP_GGML_F16_STEP) {
242
+ for (int j = 0; j < WSP_GGML_F16_ARR; j++) {
243
+ ay[j] = WSP_GGML_F16_VEC_LOAD(y + i + j*WSP_GGML_F16_EPR, j);
244
+
245
+ for (int k = 0; k < WSP_GGML_VEC_DOT_UNROLL; ++k) {
246
+ ax[j] = WSP_GGML_F16_VEC_LOAD(x[k] + i + j*WSP_GGML_F16_EPR, j);
247
+
248
+ sum[k][j] = WSP_GGML_F16_VEC_FMA(sum[k][j], ax[j], ay[j]);
249
+ }
122
250
  }
123
251
  }
124
- }
125
252
 
126
- // reduce sum0..sum3 to sum0
127
- for (int k = 0; k < WSP_GGML_VEC_DOT_UNROLL; ++k) {
128
- WSP_GGML_F16_VEC_REDUCE(sumf[k], sum[k]);
129
- }
253
+ // reduce sum0..sum3 to sum0
254
+ for (int k = 0; k < WSP_GGML_VEC_DOT_UNROLL; ++k) {
255
+ WSP_GGML_F16_VEC_REDUCE(sumf[k], sum[k]);
256
+ }
130
257
 
131
- // leftovers
132
- for (int i = np; i < n; ++i) {
133
- for (int j = 0; j < WSP_GGML_VEC_DOT_UNROLL; ++j) {
134
- sumf[j] += (wsp_ggml_float)(WSP_GGML_CPU_FP16_TO_FP32(x[j][i])*WSP_GGML_CPU_FP16_TO_FP32(y[i]));
258
+ // leftovers
259
+ for (int i = np; i < n; ++i) {
260
+ for (int j = 0; j < WSP_GGML_VEC_DOT_UNROLL; ++j) {
261
+ sumf[j] += (wsp_ggml_float)(WSP_GGML_CPU_FP16_TO_FP32(x[j][i])*WSP_GGML_CPU_FP16_TO_FP32(y[i]));
262
+ }
135
263
  }
136
- }
264
+ #endif
137
265
  #else
138
266
  for (int i = 0; i < n; ++i) {
139
267
  for (int j = 0; j < WSP_GGML_VEC_DOT_UNROLL; ++j) {
@@ -163,49 +291,49 @@ inline static void wsp_ggml_vec_mad_f32(const int n, float * WSP_GGML_RESTRICT y
163
291
 
164
292
  ax1 = WSP_GGML_F32_VEC_LOAD(x + i);
165
293
  ay1 = WSP_GGML_F32_VEC_LOAD(y + i);
166
- ay1 = WSP_GGML_F32_VEC_FMA(ax1, vx, ay1);
294
+ ay1 = WSP_GGML_F32_VEC_FMA(ay1, ax1, vx);
167
295
 
168
296
  WSP_GGML_F32_VEC_STORE(y + i, ay1);
169
297
 
170
298
  ax2 = WSP_GGML_F32_VEC_LOAD(x + i + 1*wsp_ggml_f32_epr);
171
299
  ay2 = WSP_GGML_F32_VEC_LOAD(y + i + 1*wsp_ggml_f32_epr);
172
- ay2 = WSP_GGML_F32_VEC_FMA(ax2, vx, ay2);
300
+ ay2 = WSP_GGML_F32_VEC_FMA(ay2, ax2, vx);
173
301
 
174
302
  WSP_GGML_F32_VEC_STORE(y + i + 1*wsp_ggml_f32_epr, ay2);
175
303
 
176
304
  ax3 = WSP_GGML_F32_VEC_LOAD(x + i + 2*wsp_ggml_f32_epr);
177
305
  ay3 = WSP_GGML_F32_VEC_LOAD(y + i + 2*wsp_ggml_f32_epr);
178
- ay3 = WSP_GGML_F32_VEC_FMA(ax3, vx, ay3);
306
+ ay3 = WSP_GGML_F32_VEC_FMA(ay3, ax3, vx);
179
307
 
180
308
  WSP_GGML_F32_VEC_STORE(y + i + 2*wsp_ggml_f32_epr, ay3);
181
309
 
182
310
  ax4 = WSP_GGML_F32_VEC_LOAD(x + i + 3*wsp_ggml_f32_epr);
183
311
  ay4 = WSP_GGML_F32_VEC_LOAD(y + i + 3*wsp_ggml_f32_epr);
184
- ay4 = WSP_GGML_F32_VEC_FMA(ax4, vx, ay4);
312
+ ay4 = WSP_GGML_F32_VEC_FMA(ay4, ax4, vx);
185
313
 
186
314
  WSP_GGML_F32_VEC_STORE(y + i + 3*wsp_ggml_f32_epr, ay4);
187
315
 
188
316
  ax5 = WSP_GGML_F32_VEC_LOAD(x + i + 4*wsp_ggml_f32_epr);
189
317
  ay5 = WSP_GGML_F32_VEC_LOAD(y + i + 4*wsp_ggml_f32_epr);
190
- ay5 = WSP_GGML_F32_VEC_FMA(ax5, vx, ay5);
318
+ ay5 = WSP_GGML_F32_VEC_FMA(ay5, ax5, vx);
191
319
 
192
320
  WSP_GGML_F32_VEC_STORE(y + i + 4*wsp_ggml_f32_epr, ay5);
193
321
 
194
322
  ax6 = WSP_GGML_F32_VEC_LOAD(x + i + 5*wsp_ggml_f32_epr);
195
323
  ay6 = WSP_GGML_F32_VEC_LOAD(y + i + 5*wsp_ggml_f32_epr);
196
- ay6 = WSP_GGML_F32_VEC_FMA(ax6, vx, ay6);
324
+ ay6 = WSP_GGML_F32_VEC_FMA(ay6, ax6, vx);
197
325
 
198
326
  WSP_GGML_F32_VEC_STORE(y + i + 5*wsp_ggml_f32_epr, ay6);
199
327
 
200
328
  ax7 = WSP_GGML_F32_VEC_LOAD(x + i + 6*wsp_ggml_f32_epr);
201
329
  ay7 = WSP_GGML_F32_VEC_LOAD(y + i + 6*wsp_ggml_f32_epr);
202
- ay7 = WSP_GGML_F32_VEC_FMA(ax7, vx, ay7);
330
+ ay7 = WSP_GGML_F32_VEC_FMA(ay7, ax7, vx);
203
331
 
204
332
  WSP_GGML_F32_VEC_STORE(y + i + 6*wsp_ggml_f32_epr, ay7);
205
333
 
206
334
  ax8 = WSP_GGML_F32_VEC_LOAD(x + i + 7*wsp_ggml_f32_epr);
207
335
  ay8 = WSP_GGML_F32_VEC_LOAD(y + i + 7*wsp_ggml_f32_epr);
208
- ay8 = WSP_GGML_F32_VEC_FMA(ax8, vx, ay8);
336
+ ay8 = WSP_GGML_F32_VEC_FMA(ay8, ax8, vx);
209
337
 
210
338
  WSP_GGML_F32_VEC_STORE(y + i + 7*wsp_ggml_f32_epr, ay8);
211
339
  }
@@ -215,7 +343,7 @@ inline static void wsp_ggml_vec_mad_f32(const int n, float * WSP_GGML_RESTRICT y
215
343
  for (int i = np; i < np2; i += wsp_ggml_f32_epr) {
216
344
  ax1 = WSP_GGML_F32_VEC_LOAD(x + i);
217
345
  ay1 = WSP_GGML_F32_VEC_LOAD(y + i);
218
- ay1 = WSP_GGML_F32_VEC_FMA(ax1, vx, ay1);
346
+ ay1 = WSP_GGML_F32_VEC_FMA(ay1, ax1, vx);
219
347
 
220
348
  WSP_GGML_F32_VEC_STORE(y + i, ay1);
221
349
  }
@@ -228,6 +356,14 @@ inline static void wsp_ggml_vec_mad_f32(const int n, float * WSP_GGML_RESTRICT y
228
356
 
229
357
  svst1_f32(pg, y + np2, ay1);
230
358
  }
359
+ #elif defined(__riscv_v_intrinsic)
360
+ for (int i = 0, avl; i < n; i += avl) {
361
+ avl = __riscv_vsetvl_e32m8(n - i);
362
+ vfloat32m8_t ax = __riscv_vle32_v_f32m8(&x[i], avl);
363
+ vfloat32m8_t ay = __riscv_vle32_v_f32m8(&y[i], avl);
364
+ vfloat32m8_t ny = __riscv_vfmadd_vf_f32m8(ax, v, ay, avl);
365
+ __riscv_vse32_v_f32m8(&y[i], ny, avl);
366
+ }
231
367
  #else
232
368
  const int np = (n & ~(WSP_GGML_F32_STEP - 1));
233
369
 
@@ -261,27 +397,112 @@ inline static void wsp_ggml_vec_mad_f32(const int n, float * WSP_GGML_RESTRICT y
261
397
 
262
398
  inline static void wsp_ggml_vec_mad_f16(const int n, wsp_ggml_fp16_t * WSP_GGML_RESTRICT y, const wsp_ggml_fp16_t * WSP_GGML_RESTRICT x, const float v) {
263
399
  #if defined(WSP_GGML_SIMD)
264
- const int np = (n & ~(WSP_GGML_F16_STEP - 1));
400
+ #if defined(__ARM_FEATURE_SVE)
401
+ const int sve_register_length = svcntb() * 8;
402
+ const int wsp_ggml_f16_epr = sve_register_length / 16;
403
+ const int wsp_ggml_f16_step = 8 * wsp_ggml_f16_epr;
404
+
405
+ WSP_GGML_F16x_VEC vx = WSP_GGML_F16x_VEC_SET1(v);
406
+
407
+ const int np= (n & ~(wsp_ggml_f16_step - 1));
408
+
409
+ svfloat16_t ax1, ax2, ax3, ax4, ax5, ax6, ax7, ax8;
410
+ svfloat16_t ay1, ay2, ay3, ay4, ay5, ay6, ay7, ay8;
411
+ for (int i = 0; i < np; i += wsp_ggml_f16_step) {
412
+ ax1 = WSP_GGML_F16x_VEC_LOAD(x + i + 0 * wsp_ggml_f16_epr, 0);
413
+ ay1 = WSP_GGML_F16x_VEC_LOAD(y + i + 0 * wsp_ggml_f16_epr, 0);
414
+ ay1 = WSP_GGML_F16x_VEC_FMA(ay1, ax1, vx);
415
+
416
+ WSP_GGML_F16x_VEC_STORE(y + i + 0 * wsp_ggml_f16_epr, ay1, 0);
417
+
418
+ ax2 = WSP_GGML_F16x_VEC_LOAD(x + i + 1 * wsp_ggml_f16_epr, 1);
419
+ ay2 = WSP_GGML_F16x_VEC_LOAD(y + i + 1 * wsp_ggml_f16_epr, 1);
420
+ ay2 = WSP_GGML_F16x_VEC_FMA(ay2, ax2, vx);
421
+
422
+ WSP_GGML_F16x_VEC_STORE(y + i + 1 * wsp_ggml_f16_epr, ay2, 1);
265
423
 
266
- WSP_GGML_F16_VEC vx = WSP_GGML_F16_VEC_SET1(v);
424
+ ax3 = WSP_GGML_F16x_VEC_LOAD(x + i + 2 * wsp_ggml_f16_epr, 2);
425
+ ay3 = WSP_GGML_F16x_VEC_LOAD(y + i + 2 * wsp_ggml_f16_epr, 2);
426
+ ay3 = WSP_GGML_F16x_VEC_FMA(ay3, ax3, vx);
267
427
 
268
- WSP_GGML_F16_VEC ax[WSP_GGML_F16_ARR];
269
- WSP_GGML_F16_VEC ay[WSP_GGML_F16_ARR];
428
+ WSP_GGML_F16x_VEC_STORE(y + i + 2 * wsp_ggml_f16_epr, ay3, 2);
270
429
 
271
- for (int i = 0; i < np; i += WSP_GGML_F16_STEP) {
272
- for (int j = 0; j < WSP_GGML_F16_ARR; j++) {
273
- ax[j] = WSP_GGML_F16_VEC_LOAD(x + i + j*WSP_GGML_F16_EPR, j);
274
- ay[j] = WSP_GGML_F16_VEC_LOAD(y + i + j*WSP_GGML_F16_EPR, j);
275
- ay[j] = WSP_GGML_F16_VEC_FMA(ay[j], ax[j], vx);
430
+ ax4 = WSP_GGML_F16x_VEC_LOAD(x + i + 3 * wsp_ggml_f16_epr, 3);
431
+ ay4 = WSP_GGML_F16x_VEC_LOAD(y + i + 3 * wsp_ggml_f16_epr, 3);
432
+ ay4 = WSP_GGML_F16x_VEC_FMA(ay4, ax4, vx);
276
433
 
277
- WSP_GGML_F16_VEC_STORE(y + i + j*WSP_GGML_F16_EPR, ay, j);
434
+ WSP_GGML_F16x_VEC_STORE(y + i + 3 * wsp_ggml_f16_epr, ay4, 3);
435
+
436
+ ax5 = WSP_GGML_F16x_VEC_LOAD(x + i + 4 * wsp_ggml_f16_epr, 4);
437
+ ay5 = WSP_GGML_F16x_VEC_LOAD(y + i + 4 * wsp_ggml_f16_epr, 4);
438
+ ay5 = WSP_GGML_F16x_VEC_FMA(ay5, ax5, vx);
439
+
440
+ WSP_GGML_F16x_VEC_STORE(y + i + 4 * wsp_ggml_f16_epr, ay5, 4);
441
+
442
+ ax6 = WSP_GGML_F16x_VEC_LOAD(x + i + 5 * wsp_ggml_f16_epr, 5);
443
+ ay6 = WSP_GGML_F16x_VEC_LOAD(y + i + 5 * wsp_ggml_f16_epr, 5);
444
+ ay6 = WSP_GGML_F16x_VEC_FMA(ay6, ax6, vx);
445
+
446
+ WSP_GGML_F16x_VEC_STORE(y + i + 5 * wsp_ggml_f16_epr, ay6, 5);
447
+
448
+ ax7 = WSP_GGML_F16x_VEC_LOAD(x + i + 6 * wsp_ggml_f16_epr, 6);
449
+ ay7 = WSP_GGML_F16x_VEC_LOAD(y + i + 6 * wsp_ggml_f16_epr, 6);
450
+ ay7 = WSP_GGML_F16x_VEC_FMA(ay7, ax7, vx);
451
+
452
+ WSP_GGML_F16x_VEC_STORE(y + i + 6 * wsp_ggml_f16_epr, ay7, 6);
453
+
454
+ ax8 = WSP_GGML_F16x_VEC_LOAD(x + i + 7 * wsp_ggml_f16_epr, 7);
455
+ ay8 = WSP_GGML_F16x_VEC_LOAD(y + i + 7 * wsp_ggml_f16_epr, 7);
456
+ ay8 = WSP_GGML_F16x_VEC_FMA(ay8, ax8, vx);
457
+
458
+ WSP_GGML_F16x_VEC_STORE(y + i + 7 * wsp_ggml_f16_epr, ay8, 7);
278
459
  }
279
- }
460
+ const int np2 = (n & ~(wsp_ggml_f16_epr - 1));
461
+ for (int k = np; k < np2; k += wsp_ggml_f16_epr) {
462
+ svfloat16_t rx = WSP_GGML_F16x_VEC_LOAD(x + k, 0);
463
+ svfloat16_t ry = WSP_GGML_F16x_VEC_LOAD(y + k, 0);
464
+ ry = WSP_GGML_F16x_VEC_FMA(ry, rx, vx);
280
465
 
281
- // leftovers
282
- for (int i = np; i < n; ++i) {
283
- y[i] = WSP_GGML_CPU_FP32_TO_FP16(WSP_GGML_CPU_FP16_TO_FP32(y[i]) + WSP_GGML_CPU_FP16_TO_FP32(x[i])*v);
284
- }
466
+ WSP_GGML_F16x_VEC_STORE(y + k, ry, 0);
467
+ }
468
+
469
+ if (np2 < n) {
470
+ svbool_t pg = svwhilelt_b16(np2, n);
471
+ svfloat16_t hx = svld1_f16(pg, (const __fp16 *)(x + np2));
472
+ svfloat16_t hy = svld1_f16(pg, (const __fp16 *)(y + np2));
473
+ hy = svmad_f16_x(pg, hx, vx, hy);
474
+ svst1_f16(pg, (__fp16 *)(y + np2), hy);
475
+ }
476
+
477
+ #elif defined(__riscv_v_intrinsic)
478
+ // todo: RVV impl
479
+ // scalar
480
+ for (int i = 0; i < n; ++i) {
481
+ y[i] = WSP_GGML_CPU_FP32_TO_FP16(WSP_GGML_CPU_FP16_TO_FP32(y[i]) + WSP_GGML_CPU_FP16_TO_FP32(x[i])*v);
482
+ }
483
+ #else
484
+ const int np = (n & ~(WSP_GGML_F16_STEP - 1));
485
+
486
+ WSP_GGML_F16_VEC vx = WSP_GGML_F16_VEC_SET1(v);
487
+
488
+ WSP_GGML_F16_VEC ax[WSP_GGML_F16_ARR];
489
+ WSP_GGML_F16_VEC ay[WSP_GGML_F16_ARR];
490
+
491
+ for (int i = 0; i < np; i += WSP_GGML_F16_STEP) {
492
+ for (int j = 0; j < WSP_GGML_F16_ARR; j++) {
493
+ ax[j] = WSP_GGML_F16_VEC_LOAD(x + i + j*WSP_GGML_F16_EPR, j);
494
+ ay[j] = WSP_GGML_F16_VEC_LOAD(y + i + j*WSP_GGML_F16_EPR, j);
495
+ ay[j] = WSP_GGML_F16_VEC_FMA(ay[j], ax[j], vx);
496
+
497
+ WSP_GGML_F16_VEC_STORE(y + i + j*WSP_GGML_F16_EPR, ay, j);
498
+ }
499
+ }
500
+
501
+ // leftovers
502
+ for (int i = np; i < n; ++i) {
503
+ y[i] = WSP_GGML_CPU_FP32_TO_FP16(WSP_GGML_CPU_FP16_TO_FP32(y[i]) + WSP_GGML_CPU_FP16_TO_FP32(x[i])*v);
504
+ }
505
+ #endif
285
506
  #else
286
507
  // scalar
287
508
  for (int i = 0; i < n; ++i) {
@@ -309,6 +530,16 @@ inline static void wsp_ggml_vec_mad_f32_unroll(const int n, const int xs, const
309
530
  y[i] += x[k][i]*v[k][0];
310
531
  }
311
532
  }
533
+ #elif defined(__riscv_v_intrinsic)
534
+ for (int i = 0, avl; i < n; i += avl) {
535
+ avl = __riscv_vsetvl_e32m8(n - i);
536
+ vfloat32m8_t ay = __riscv_vle32_v_f32m8(&y[i], avl);
537
+ for (int k = 0; k < WSP_GGML_VEC_MAD_UNROLL; k++) {
538
+ vfloat32m8_t ax = __riscv_vle32_v_f32m8(&x[k][i], avl);
539
+ ay = __riscv_vfmadd_vf_f32m8(ax, v[k][0], ay, avl);
540
+ }
541
+ __riscv_vse32_v_f32m8(&y[i], ay, avl);
542
+ }
312
543
  #else
313
544
  const int np = (n & ~(WSP_GGML_F32_STEP - 1));
314
545
 
@@ -351,6 +582,53 @@ inline static void wsp_ggml_vec_mad_f32_unroll(const int n, const int xs, const
351
582
  #endif
352
583
  }
353
584
 
585
+ inline static void wsp_ggml_vec_mad1_f32(const int n, float * y, const float * x, const float s, const float b) {
586
+ #if defined(WSP_GGML_USE_ACCELERATE)
587
+ vDSP_vsmsa(x, 1, &s, &b, y, 1, n);
588
+ #elif defined(WSP_GGML_SIMD)
589
+ #if defined(__ARM_FEATURE_SVE)
590
+ // scalar ; TODO: Write SVE code
591
+ for (int i = 0; i < n; ++i) {
592
+ y[i] = x[i]*s + b;
593
+ }
594
+ #elif defined(__riscv_v_intrinsic)
595
+ for (int i = 0, avl; i < n; i += avl) {
596
+ avl = __riscv_vsetvl_e32m8(n - i);
597
+ vfloat32m8_t ax = __riscv_vle32_v_f32m8(&x[i], avl);
598
+ vfloat32m8_t vb = __riscv_vfmv_v_f_f32m8(b, avl);
599
+ vfloat32m8_t ny = __riscv_vfmadd_vf_f32m8(ax, s, vb, avl);
600
+ __riscv_vse32_v_f32m8(&y[i], ny, avl);
601
+ }
602
+ #else
603
+ const int np = (n & ~(WSP_GGML_F32_STEP - 1));
604
+
605
+ WSP_GGML_F32_VEC vs = WSP_GGML_F32_VEC_SET1(s);
606
+ WSP_GGML_F32_VEC vb = WSP_GGML_F32_VEC_SET1(b);
607
+
608
+ WSP_GGML_F32_VEC ay[WSP_GGML_F32_ARR];
609
+
610
+ for (int i = 0; i < np; i += WSP_GGML_F32_STEP) {
611
+ for (int j = 0; j < WSP_GGML_F32_ARR; j++) {
612
+ ay[j] = WSP_GGML_F32_VEC_LOAD(x + i + j*WSP_GGML_F32_EPR);
613
+ ay[j] = WSP_GGML_F32_VEC_FMA(vb, ay[j], vs);
614
+
615
+ WSP_GGML_F32_VEC_STORE(y + i + j*WSP_GGML_F32_EPR, ay[j]);
616
+ }
617
+ }
618
+
619
+ // leftovers
620
+ for (int i = np; i < n; ++i) {
621
+ y[i] = x[i]*s + b;
622
+ }
623
+ #endif
624
+ #else
625
+ // scalar
626
+ for (int i = 0; i < n; ++i) {
627
+ y[i] = x[i]*s + b;
628
+ }
629
+ #endif
630
+ }
631
+
354
632
  //inline static void wsp_ggml_vec_scale_f32(const int n, float * y, const float v) { for (int i = 0; i < n; ++i) y[i] *= v; }
355
633
  inline static void wsp_ggml_vec_scale_f32(const int n, float * y, const float v) {
356
634
  #if defined(WSP_GGML_USE_ACCELERATE)
@@ -382,6 +660,13 @@ inline static void wsp_ggml_vec_scale_f32(const int n, float * y, const float
382
660
  ay1 = svmul_f32_m(pg, ay1, vx);
383
661
  svst1_f32(pg, y + np, ay1);
384
662
  }
663
+ #elif defined(__riscv_v_intrinsic)
664
+ for (int i = 0, avl; i < n; i += avl) {
665
+ avl = __riscv_vsetvl_e32m8(n - i);
666
+ vfloat32m8_t ay = __riscv_vle32_v_f32m8(&y[i], avl);
667
+ vfloat32m8_t ny = __riscv_vfmul_vf_f32m8(ay, v, avl);
668
+ __riscv_vse32_v_f32m8(&y[i], ny, avl);
669
+ }
385
670
  #else
386
671
  const int np = (n & ~(WSP_GGML_F32_STEP - 1));
387
672
 
@@ -413,25 +698,59 @@ inline static void wsp_ggml_vec_scale_f32(const int n, float * y, const float
413
698
 
414
699
  inline static void wsp_ggml_vec_scale_f16(const int n, wsp_ggml_fp16_t * y, const float v) {
415
700
  #if defined(WSP_GGML_SIMD)
416
- const int np = (n & ~(WSP_GGML_F16_STEP - 1));
701
+ #if defined(__ARM_FEATURE_SVE)
702
+ const int sve_register_length = svcntb() * 8;
703
+ const int wsp_ggml_f16_epr = sve_register_length / 16;
704
+ const int wsp_ggml_f16_step = 2 * wsp_ggml_f16_epr;
705
+
706
+ WSP_GGML_F16x_VEC vx = WSP_GGML_F16x_VEC_SET1(v);
707
+ const int np = (n & ~(wsp_ggml_f16_step - 1));
708
+ svfloat16_t ay1, ay2;
709
+
710
+ for (int i = 0; i < np; i += wsp_ggml_f16_step) {
711
+ ay1 = WSP_GGML_F16x_VEC_LOAD(y + i + 0*wsp_ggml_f16_epr, 0);
712
+ ay1 = WSP_GGML_F16x_VEC_MUL(ay1, vx);
713
+ WSP_GGML_F16x_VEC_STORE(y + i + 0*wsp_ggml_f16_epr, ay1, 0);
714
+
715
+ ay2 = WSP_GGML_F16x_VEC_LOAD(y + i + 1*wsp_ggml_f16_epr, 1);
716
+ ay2 = WSP_GGML_F16x_VEC_MUL(ay2, vx);
717
+ WSP_GGML_F16x_VEC_STORE(y + i + 1*wsp_ggml_f16_epr, ay2, 1);
718
+ }
719
+ // leftovers
720
+ // maximum number of leftover elements will be less that ggmlF_16x_epr. Apply predicated svmad on available elements only
721
+ if (np < n) {
722
+ svbool_t pg = svwhilelt_b16(np, n);
723
+ svfloat16_t hy = svld1_f16(pg, (__fp16 *)(y + np));
724
+ svfloat16_t out = svmul_f16_m(pg, hy, vx);
725
+ svst1_f16(pg, (__fp16 *)(y + np), out);
726
+ }
727
+ #elif defined(__riscv_v_intrinsic)
728
+ // todo: RVV impl
729
+ // scalar
730
+ for (int i = 0; i < n; ++i) {
731
+ y[i] = WSP_GGML_CPU_FP32_TO_FP16(WSP_GGML_CPU_FP16_TO_FP32(y[i])*v);
732
+ }
733
+ #else
734
+ const int np = (n & ~(WSP_GGML_F16_STEP - 1));
417
735
 
418
- WSP_GGML_F16_VEC vx = WSP_GGML_F16_VEC_SET1(v);
736
+ WSP_GGML_F16_VEC vx = WSP_GGML_F16_VEC_SET1(v);
419
737
 
420
- WSP_GGML_F16_VEC ay[WSP_GGML_F16_ARR];
738
+ WSP_GGML_F16_VEC ay[WSP_GGML_F16_ARR];
421
739
 
422
- for (int i = 0; i < np; i += WSP_GGML_F16_STEP) {
423
- for (int j = 0; j < WSP_GGML_F16_ARR; j++) {
424
- ay[j] = WSP_GGML_F16_VEC_LOAD(y + i + j*WSP_GGML_F16_EPR, j);
425
- ay[j] = WSP_GGML_F16_VEC_MUL(ay[j], vx);
740
+ for (int i = 0; i < np; i += WSP_GGML_F16_STEP) {
741
+ for (int j = 0; j < WSP_GGML_F16_ARR; j++) {
742
+ ay[j] = WSP_GGML_F16_VEC_LOAD(y + i + j*WSP_GGML_F16_EPR, j);
743
+ ay[j] = WSP_GGML_F16_VEC_MUL(ay[j], vx);
426
744
 
427
- WSP_GGML_F16_VEC_STORE(y + i + j*WSP_GGML_F16_EPR, ay, j);
745
+ WSP_GGML_F16_VEC_STORE(y + i + j*WSP_GGML_F16_EPR, ay, j);
746
+ }
428
747
  }
429
- }
430
748
 
431
- // leftovers
432
- for (int i = np; i < n; ++i) {
433
- y[i] = WSP_GGML_CPU_FP32_TO_FP16(WSP_GGML_CPU_FP16_TO_FP32(y[i])*v);
434
- }
749
+ // leftovers
750
+ for (int i = np; i < n; ++i) {
751
+ y[i] = WSP_GGML_CPU_FP32_TO_FP16(WSP_GGML_CPU_FP16_TO_FP32(y[i])*v);
752
+ }
753
+ #endif
435
754
  #else
436
755
  // scalar
437
756
  for (int i = 0; i < n; ++i) {
@@ -683,7 +1002,39 @@ https://github.com/openvinotoolkit/openvino/blob/master/src/plugins/intel_cpu/sr
683
1002
  }
684
1003
  #endif
685
1004
 
686
- #if defined(__ARM_NEON) && defined(__aarch64__)
1005
+ #if defined(__ARM_FEATURE_SVE) && defined(__aarch64__)
1006
+
1007
+ inline static svfloat32_t wsp_ggml_v_expf(svbool_t pg, svfloat32_t x) {
1008
+ const svfloat32_t r = svdup_n_f32_x(pg, 0x1.8p23f);
1009
+ const svfloat32_t z = svmla_n_f32_x(pg, r, x, 0x1.715476p+0f);
1010
+ const svfloat32_t n = svsub_f32_x(pg, z, r);
1011
+ const svfloat32_t b = svmls_n_f32_x(pg, svmls_n_f32_x(pg, x, n, 0x1.62e4p-1f), n, 0x1.7f7d1cp-20f);
1012
+ const svuint32_t e = svlsl_n_u32_x(pg, svreinterpret_u32_f32(z), 23);
1013
+ const svfloat32_t k = svreinterpret_f32_u32(svadd_u32_x(pg, e, svreinterpret_u32_f32(svdup_n_f32_x(pg, 1))));
1014
+ const svbool_t c = svacgt_n_f32(pg, n, 126);
1015
+ const svfloat32_t u = svmul_f32_x(pg, b, b);
1016
+ const svfloat32_t j = svmla_f32_x(pg,
1017
+ svmul_n_f32_x(pg, b, 0x1.ffffecp-1f),
1018
+ svmla_f32_x(pg, svmla_f32_x(pg, svdup_n_f32_x(pg, 0x1.fffdb6p-2f), svdup_n_f32_x(pg, 0x1.555e66p-3f), b),
1019
+ svmla_f32_x(pg, svdup_n_f32_x(pg, 0x1.573e2ep-5f), svdup_n_f32_x(pg, 0x1.0e4020p-7f), b), u), u);
1020
+ const svuint32_t d = svdup_n_u32_z(svcmple_n_f32(pg, n, 0.0), 0x82000000);
1021
+ const svfloat32_t s1 = svreinterpret_f32_u32(svadd_n_u32_x(pg, d, 0x7f000000));
1022
+ const svfloat32_t s2 = svreinterpret_f32_u32(svsub_u32_x(pg, e, d));
1023
+ return svsel_f32(svacgt_f32(pg, n, svdup_n_f32_x(pg, 192)), svmul_f32_x(pg, s1, s1),
1024
+ svsel_f32(c, svmul_f32_x(pg, svmla_f32_x(pg, s2, s2, j), s1), svmla_f32_x(pg, k, k, j)));
1025
+ }
1026
+
1027
+ // computes silu x/(1+exp(-x)) in single precision vector
1028
+ inline static svfloat32_t wsp_ggml_v_silu(svbool_t pg, svfloat32_t x) {
1029
+ const svfloat32_t one = svdup_n_f32_x(pg, 1.0f);
1030
+ const svfloat32_t zero = svdup_n_f32_x(pg, 0.0f);
1031
+ const svfloat32_t neg_x = svsub_f32_x(pg, zero, x);
1032
+ const svfloat32_t exp_neg_x = wsp_ggml_v_expf(pg, neg_x);
1033
+ const svfloat32_t one_plus_exp_neg_x = svadd_f32_x(pg, one, exp_neg_x);
1034
+ return svdiv_f32_x(pg, x, one_plus_exp_neg_x);
1035
+ }
1036
+
1037
+ #elif defined(__ARM_NEON) && defined(__aarch64__)
687
1038
 
688
1039
  // adapted from arm limited optimized routine
689
1040
  // the maximum error is 1.45358 plus 0.5 ulps
@@ -874,7 +1225,59 @@ inline static __m128 wsp_ggml_v_silu(__m128 x) {
874
1225
  return _mm_div_ps(x, one_plus_exp_neg_x);
875
1226
  }
876
1227
 
877
- #endif // __ARM_NEON / __AVX2__ / __SSE2__
1228
+ #elif defined(__riscv_v_intrinsic)
1229
+
1230
+ // adapted from arm limited optimized routine
1231
+ // the maximum error is 1.45358 plus 0.5 ulps
1232
+ // numbers above 88.38 will flush to infinity
1233
+ // numbers beneath -103.97 will flush to zero
1234
+ inline static vfloat32m2_t wsp_ggml_v_expf_m2(vfloat32m2_t x, int vl) {
1235
+ const vfloat32m2_t r = __riscv_vfmv_v_f_f32m2(0x1.8p23f, vl);
1236
+ #ifdef __riscv_xtheadvector
1237
+ // workaround for compiler bug (gcc 14.3.0: Error: unrecognized opcode `th.vmv1r.v v2,v4')
1238
+ vfloat32m2_t z = __riscv_vfadd_vf_f32m2(r, 0.0f, vl);
1239
+ z = __riscv_vfmacc_vf_f32m2(z, 0x1.715476p+0f, x, vl);
1240
+ #else
1241
+ const vfloat32m2_t z = __riscv_vfmacc_vf_f32m2(r, 0x1.715476p+0f, x, vl);
1242
+ #endif
1243
+ const vfloat32m2_t n = __riscv_vfsub_vv_f32m2(z, r, vl);
1244
+ const vfloat32m2_t b = __riscv_vfnmsac_vf_f32m2(__riscv_vfnmsac_vf_f32m2(x, 0x1.62e4p-1f, n, vl),
1245
+ 0x1.7f7d1cp-20f, n, vl);
1246
+ const vuint32m2_t e = __riscv_vsll_vx_u32m2(__riscv_vreinterpret_v_f32m2_u32m2(z), 23, vl);
1247
+ const vfloat32m2_t k = __riscv_vreinterpret_v_u32m2_f32m2(__riscv_vadd_vx_u32m2(e, 0x3f800000, vl)); // 1.0f
1248
+ const vbool16_t c = __riscv_vmfgt_vf_f32m2_b16(__riscv_vfabs_v_f32m2(n, vl), 126.0f, vl);
1249
+ const vfloat32m2_t u = __riscv_vfmul_vv_f32m2(b, b, vl);
1250
+ const vfloat32m2_t j = __riscv_vfmacc_vv_f32m2(
1251
+ __riscv_vfmul_vf_f32m2(b, 0x1.ffffecp-1f, vl),
1252
+ __riscv_vfmacc_vv_f32m2(
1253
+ __riscv_vfmacc_vf_f32m2(__riscv_vfmv_v_f_f32m2(0x1.fffdb6p-2f, vl), 0x1.555e66p-3f, b, vl),
1254
+ __riscv_vfmacc_vf_f32m2(__riscv_vfmv_v_f_f32m2(0x1.573e2ep-5f, vl), 0x1.0e4020p-7f, b, vl),
1255
+ u, vl), u, vl);
1256
+ if (!__riscv_vcpop_m_b16(c, vl))
1257
+ return __riscv_vfmacc_vv_f32m2(k, j, k, vl);
1258
+ const vbool16_t dm = __riscv_vmfle_vf_f32m2_b16(n, 0.0f, vl);
1259
+ const vuint32m2_t d = __riscv_vmerge_vxm_u32m2(__riscv_vmv_v_x_u32m2(0, vl), 0x82000000, dm, vl);
1260
+ const vfloat32m2_t s1 = __riscv_vreinterpret_v_u32m2_f32m2(__riscv_vadd_vx_u32m2(d, 0x7f000000, vl));
1261
+ const vfloat32m2_t s2 = __riscv_vreinterpret_v_u32m2_f32m2(__riscv_vsub_vv_u32m2(e, d, vl));
1262
+ const vfloat32m2_t r1 = __riscv_vmerge_vvm_f32m2(
1263
+ __riscv_vfmacc_vv_f32m2(k, k, j, vl),
1264
+ __riscv_vfmul_vv_f32m2(__riscv_vfmacc_vv_f32m2(s2, s2, j, vl), s1, vl),
1265
+ c, vl);
1266
+ return __riscv_vmerge_vvm_f32m2(
1267
+ r1, __riscv_vfmul_vv_f32m2(s1, s1, vl),
1268
+ __riscv_vmfgt_vf_f32m2_b16(__riscv_vfabs_v_f32m2(n, vl), 192.0f, vl),
1269
+ vl);
1270
+ }
1271
+
1272
+ // computes silu x/(1+exp(-x)) in single precision vector
1273
+ inline static vfloat32m2_t wsp_ggml_v_silu_m2(vfloat32m2_t x, int vl) {
1274
+ const vfloat32m2_t neg_x = __riscv_vfneg_v_f32m2(x, vl);
1275
+ const vfloat32m2_t exp_neg_x = wsp_ggml_v_expf_m2(neg_x, vl);
1276
+ const vfloat32m2_t one_plus_exp_neg_x = __riscv_vfadd_vf_f32m2(exp_neg_x, 1.0f, vl);
1277
+ return __riscv_vfdiv_vv_f32m2(x, one_plus_exp_neg_x, vl);
1278
+ }
1279
+
1280
+ #endif // __ARM_NEON / __AVX2__ / __SSE2__ / __riscv_v_intrinsic
878
1281
 
879
1282
  inline static void wsp_ggml_vec_silu_f16(const int n, wsp_ggml_fp16_t * y, const wsp_ggml_fp16_t * x) {
880
1283
  for (int i = 0; i < n; ++i) {
@@ -953,9 +1356,49 @@ void wsp_ggml_vec_swiglu_f32(const int n, float * y, const float * x, const floa
953
1356
 
954
1357
  inline static void wsp_ggml_vec_swiglu_f16(const int n, wsp_ggml_fp16_t * y, const wsp_ggml_fp16_t * x, const wsp_ggml_fp16_t * g) {
955
1358
  for (int i = 0; i < n; ++i) {
956
- float v = WSP_GGML_CPU_FP16_TO_FP32(x[i]);
957
- float w = WSP_GGML_CPU_FP16_TO_FP32(g[i]);
958
- y[i] = WSP_GGML_CPU_FP32_TO_FP16((v/(1.0f + expf(-v))) * w);
1359
+ float xi = WSP_GGML_CPU_FP16_TO_FP32(x[i]);
1360
+ float gi = WSP_GGML_CPU_FP16_TO_FP32(g[i]);
1361
+ y[i] = WSP_GGML_CPU_FP32_TO_FP16((xi/(1.0f + expf(-xi))) * gi);
1362
+ }
1363
+ }
1364
+
1365
+ inline static void wsp_ggml_vec_geglu_erf_f32(const int n, float * y, const float * x, const float * g) {
1366
+ for (int i = 0; i < n; ++i) {
1367
+ float xi = x[i];
1368
+ y[i] = 0.5f * xi * (1.0f + erff(xi*SQRT_2_INV)) * g[i];
1369
+ }
1370
+ }
1371
+
1372
+ inline static void wsp_ggml_vec_geglu_erf_f16(const int n, wsp_ggml_fp16_t * y, const wsp_ggml_fp16_t * x, const wsp_ggml_fp16_t * g) {
1373
+ for (int i = 0; i < n; ++i) {
1374
+ float xi = WSP_GGML_CPU_FP16_TO_FP32(x[i]);
1375
+ float gi = WSP_GGML_CPU_FP16_TO_FP32(g[i]);
1376
+ y[i] = WSP_GGML_CPU_FP32_TO_FP16(0.5f * xi * (1.0f + erff(xi*SQRT_2_INV)) * gi);
1377
+ }
1378
+ }
1379
+
1380
+ #ifdef WSP_GGML_GELU_QUICK_FP16
1381
+ inline static void wsp_ggml_vec_geglu_quick_f32(const int n, float * y, const float * x, const float * g) {
1382
+ uint16_t t;
1383
+ for (int i = 0; i < n; ++i) {
1384
+ wsp_ggml_fp16_t fp16 = WSP_GGML_CPU_FP32_TO_FP16(x[i]);
1385
+ memcpy(&t, &fp16, sizeof(uint16_t));
1386
+ y[i] = WSP_GGML_CPU_FP16_TO_FP32(wsp_ggml_table_gelu_quick_f16[t]) * g[i];
1387
+ }
1388
+ }
1389
+ #else
1390
+ inline static void wsp_ggml_vec_geglu_quick_f32(const int n, float * y, const float * x, const float * g) {
1391
+ for (int i = 0; i < n; ++i) {
1392
+ y[i] = wsp_ggml_gelu_quick_f32(x[i]) * g[i];
1393
+ }
1394
+ }
1395
+ #endif
1396
+
1397
+ inline static void wsp_ggml_vec_geglu_quick_f16(const int n, wsp_ggml_fp16_t * y, const wsp_ggml_fp16_t * x, const wsp_ggml_fp16_t * g) {
1398
+ const uint16_t * i16 = (const uint16_t *) x;
1399
+ for (int i = 0; i < n; ++i) {
1400
+ float v = WSP_GGML_CPU_FP16_TO_FP32(g[i]);
1401
+ y[i] = WSP_GGML_CPU_FP32_TO_FP16(WSP_GGML_CPU_FP16_TO_FP32(wsp_ggml_table_gelu_quick_f16[i16[i]]) * v);
959
1402
  }
960
1403
  }
961
1404