whisper.rn 0.5.3 → 0.5.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (102)
  1. package/README.md +1 -1
  2. package/android/src/main/java/com/rnwhisper/WhisperContext.java +5 -0
  3. package/android/src/main/jni.cpp +13 -0
  4. package/cpp/ggml-alloc.c +78 -26
  5. package/cpp/ggml-alloc.h +9 -0
  6. package/cpp/ggml-backend-impl.h +1 -1
  7. package/cpp/ggml-backend-reg.cpp +19 -3
  8. package/cpp/ggml-backend.cpp +72 -20
  9. package/cpp/ggml-backend.h +2 -1
  10. package/cpp/ggml-cpu/arch/arm/cpu-feats.cpp +4 -0
  11. package/cpp/ggml-cpu/arch/arm/repack.cpp +1004 -0
  12. package/cpp/ggml-cpu/arch/x86/repack.cpp +6 -6
  13. package/cpp/ggml-cpu/arch-fallback.h +50 -2
  14. package/cpp/ggml-cpu/ggml-cpu-impl.h +1 -1
  15. package/cpp/ggml-cpu/ggml-cpu.c +139 -58
  16. package/cpp/ggml-cpu/ggml-cpu.cpp +4 -0
  17. package/cpp/ggml-cpu/ops.cpp +170 -18
  18. package/cpp/ggml-cpu/ops.h +1 -0
  19. package/cpp/ggml-cpu/repack.cpp +531 -5
  20. package/cpp/ggml-cpu/repack.h +14 -0
  21. package/cpp/ggml-cpu/simd-mappings.h +16 -18
  22. package/cpp/ggml-cpu/vec.cpp +41 -1
  23. package/cpp/ggml-cpu/vec.h +241 -138
  24. package/cpp/ggml-cpu.h +1 -0
  25. package/cpp/ggml-impl.h +0 -4
  26. package/cpp/ggml-metal/ggml-metal-context.m +26 -16
  27. package/cpp/ggml-metal/ggml-metal-device.cpp +452 -371
  28. package/cpp/ggml-metal/ggml-metal-device.h +87 -65
  29. package/cpp/ggml-metal/ggml-metal-device.m +263 -104
  30. package/cpp/ggml-metal/ggml-metal-impl.h +58 -4
  31. package/cpp/ggml-metal/ggml-metal-ops.cpp +415 -98
  32. package/cpp/ggml-metal/ggml-metal-ops.h +4 -0
  33. package/cpp/ggml-metal/ggml-metal.cpp +6 -5
  34. package/cpp/ggml-metal/ggml-metal.metal +404 -34
  35. package/cpp/ggml.c +110 -31
  36. package/cpp/ggml.h +51 -12
  37. package/cpp/jsi/RNWhisperJSI.cpp +1 -0
  38. package/cpp/whisper.cpp +17 -4
  39. package/ios/CMakeLists.txt +21 -1
  40. package/ios/RNWhisperContext.mm +5 -0
  41. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-alloc.h +9 -0
  42. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend-impl.h +1 -1
  43. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend.h +2 -1
  44. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-cpu.h +1 -0
  45. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-impl.h +0 -4
  46. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml.h +51 -12
  47. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Info.plist +0 -0
  48. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/ggml-metal.metal +404 -34
  49. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/rnwhisper +0 -0
  50. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-alloc.h +9 -0
  51. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend-impl.h +1 -1
  52. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +2 -1
  53. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpu.h +1 -0
  54. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +0 -4
  55. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +51 -12
  56. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
  57. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +1 -1
  58. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/ggml-metal.metal +404 -34
  59. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
  60. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-alloc.h +9 -0
  61. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend-impl.h +1 -1
  62. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend.h +2 -1
  63. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-cpu.h +1 -0
  64. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-impl.h +0 -4
  65. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml.h +51 -12
  66. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Info.plist +0 -0
  67. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/ggml-metal.metal +404 -34
  68. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/rnwhisper +0 -0
  69. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-alloc.h +9 -0
  70. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend-impl.h +1 -1
  71. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +2 -1
  72. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpu.h +1 -0
  73. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +0 -4
  74. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +51 -12
  75. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
  76. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +1 -1
  77. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/ggml-metal.metal +404 -34
  78. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
  79. package/lib/commonjs/NativeRNWhisper.js.map +1 -1
  80. package/lib/commonjs/jest-mock.js +2 -0
  81. package/lib/commonjs/jest-mock.js.map +1 -1
  82. package/lib/commonjs/realtime-transcription/RealtimeTranscriber.js +156 -12
  83. package/lib/commonjs/realtime-transcription/RealtimeTranscriber.js.map +1 -1
  84. package/lib/commonjs/version.json +1 -1
  85. package/lib/module/NativeRNWhisper.js.map +1 -1
  86. package/lib/module/jest-mock.js +2 -0
  87. package/lib/module/jest-mock.js.map +1 -1
  88. package/lib/module/realtime-transcription/RealtimeTranscriber.js +155 -12
  89. package/lib/module/realtime-transcription/RealtimeTranscriber.js.map +1 -1
  90. package/lib/module/version.json +1 -1
  91. package/lib/typescript/NativeRNWhisper.d.ts +1 -0
  92. package/lib/typescript/NativeRNWhisper.d.ts.map +1 -1
  93. package/lib/typescript/realtime-transcription/RealtimeTranscriber.d.ts +29 -0
  94. package/lib/typescript/realtime-transcription/RealtimeTranscriber.d.ts.map +1 -1
  95. package/lib/typescript/realtime-transcription/types.d.ts +7 -0
  96. package/lib/typescript/realtime-transcription/types.d.ts.map +1 -1
  97. package/package.json +1 -1
  98. package/src/NativeRNWhisper.ts +1 -0
  99. package/src/jest-mock.ts +2 -0
  100. package/src/realtime-transcription/RealtimeTranscriber.ts +179 -9
  101. package/src/realtime-transcription/types.ts +9 -0
  102. package/src/version.json +1 -1
@@ -80,10 +80,12 @@ extern "C" {
80
80
 
81
81
  void wsp_ggml_wsp_quantize_mat_q8_0_4x4(const float * WSP_GGML_RESTRICT x, void * WSP_GGML_RESTRICT vy, int64_t k);
82
82
  void wsp_ggml_wsp_quantize_mat_q8_0_4x8(const float * WSP_GGML_RESTRICT x, void * WSP_GGML_RESTRICT vy, int64_t k);
83
+ void wsp_ggml_wsp_quantize_mat_q8_K_4x4(const float * WSP_GGML_RESTRICT x, void * WSP_GGML_RESTRICT vy, int64_t k);
83
84
  void wsp_ggml_wsp_quantize_mat_q8_K_4x8(const float * WSP_GGML_RESTRICT x, void * WSP_GGML_RESTRICT vy, int64_t k);
84
85
  void wsp_ggml_gemv_q4_0_4x4_q8_0(int n, float * WSP_GGML_RESTRICT s, size_t bs, const void * WSP_GGML_RESTRICT vx, const void * WSP_GGML_RESTRICT vy, int nr, int nc);
85
86
  void wsp_ggml_gemv_q4_0_4x8_q8_0(int n, float * WSP_GGML_RESTRICT s, size_t bs, const void * WSP_GGML_RESTRICT vx, const void * WSP_GGML_RESTRICT vy, int nr, int nc);
86
87
  void wsp_ggml_gemv_q4_0_8x8_q8_0(int n, float * WSP_GGML_RESTRICT s, size_t bs, const void * WSP_GGML_RESTRICT vx, const void * WSP_GGML_RESTRICT vy, int nr, int nc);
88
+ void wsp_ggml_gemv_q4_K_8x4_q8_K(int n, float * WSP_GGML_RESTRICT s, size_t bs, const void * WSP_GGML_RESTRICT vx, const void * WSP_GGML_RESTRICT vy, int nr, int nc);
87
89
  void wsp_ggml_gemv_q4_K_8x8_q8_K(int n, float * WSP_GGML_RESTRICT s, size_t bs, const void * WSP_GGML_RESTRICT vx, const void * WSP_GGML_RESTRICT vy, int nr, int nc);
88
90
  void wsp_ggml_gemv_q2_K_8x8_q8_K(int n, float * WSP_GGML_RESTRICT s, size_t bs, const void * WSP_GGML_RESTRICT vx, const void * WSP_GGML_RESTRICT vy, int nr, int nc);
89
91
  void wsp_ggml_gemv_iq4_nl_4x4_q8_0(int n, float * WSP_GGML_RESTRICT s, size_t bs, const void * WSP_GGML_RESTRICT vx, const void * WSP_GGML_RESTRICT vy, int nr, int nc);
@@ -91,18 +93,25 @@ void wsp_ggml_gemv_iq4_nl_8x8_q8_0(int n, float * WSP_GGML_RESTRICT s, size_t bs
91
93
  void wsp_ggml_gemm_q4_0_4x4_q8_0(int n, float * WSP_GGML_RESTRICT s, size_t bs, const void * WSP_GGML_RESTRICT vx, const void * WSP_GGML_RESTRICT vy, int nr, int nc);
92
94
  void wsp_ggml_gemm_q4_0_4x8_q8_0(int n, float * WSP_GGML_RESTRICT s, size_t bs, const void * WSP_GGML_RESTRICT vx, const void * WSP_GGML_RESTRICT vy, int nr, int nc);
93
95
  void wsp_ggml_gemm_q4_0_8x8_q8_0(int n, float * WSP_GGML_RESTRICT s, size_t bs, const void * WSP_GGML_RESTRICT vx, const void * WSP_GGML_RESTRICT vy, int nr, int nc);
96
+ void wsp_ggml_gemm_q4_K_8x4_q8_K(int n, float * WSP_GGML_RESTRICT s, size_t bs, const void * WSP_GGML_RESTRICT vx, const void * WSP_GGML_RESTRICT vy, int nr, int nc);
94
97
  void wsp_ggml_gemm_q4_K_8x8_q8_K(int n, float * WSP_GGML_RESTRICT s, size_t bs, const void * WSP_GGML_RESTRICT vx, const void * WSP_GGML_RESTRICT vy, int nr, int nc);
95
98
  void wsp_ggml_gemm_q2_K_8x8_q8_K(int n, float * WSP_GGML_RESTRICT s, size_t bs, const void * WSP_GGML_RESTRICT vx, const void * WSP_GGML_RESTRICT vy, int nr, int nc);
96
99
  void wsp_ggml_gemm_iq4_nl_4x4_q8_0(int n, float * WSP_GGML_RESTRICT s, size_t bs, const void * WSP_GGML_RESTRICT vx, const void * WSP_GGML_RESTRICT vy, int nr, int nc);
97
100
  void wsp_ggml_gemm_iq4_nl_8x8_q8_0(int n, float * WSP_GGML_RESTRICT s, size_t bs, const void * WSP_GGML_RESTRICT vx, const void * WSP_GGML_RESTRICT vy, int nr, int nc);
101
+ void wsp_ggml_gemv_q8_0_4x4_q8_0(int n, float * WSP_GGML_RESTRICT s, size_t bs, const void * WSP_GGML_RESTRICT vx, const void * WSP_GGML_RESTRICT vy, int nr, int nc);
102
+ void wsp_ggml_gemv_q8_0_4x8_q8_0(int n, float * WSP_GGML_RESTRICT s, size_t bs, const void * WSP_GGML_RESTRICT vx, const void * WSP_GGML_RESTRICT vy, int nr, int nc);
103
+ void wsp_ggml_gemm_q8_0_4x4_q8_0(int n, float * WSP_GGML_RESTRICT s, size_t bs, const void * WSP_GGML_RESTRICT vx, const void * WSP_GGML_RESTRICT vy, int nr, int nc);
104
+ void wsp_ggml_gemm_q8_0_4x8_q8_0(int n, float * WSP_GGML_RESTRICT s, size_t bs, const void * WSP_GGML_RESTRICT vx, const void * WSP_GGML_RESTRICT vy, int nr, int nc);
98
105
 
99
106
  // Native implementations
100
107
  void wsp_ggml_wsp_quantize_mat_q8_0_4x4_generic(const float * WSP_GGML_RESTRICT x, void * WSP_GGML_RESTRICT vy, int64_t k);
101
108
  void wsp_ggml_wsp_quantize_mat_q8_0_4x8_generic(const float * WSP_GGML_RESTRICT x, void * WSP_GGML_RESTRICT vy, int64_t k);
109
+ void wsp_ggml_wsp_quantize_mat_q8_K_4x4_generic(const float * WSP_GGML_RESTRICT x, void * WSP_GGML_RESTRICT vy, int64_t k);
102
110
  void wsp_ggml_wsp_quantize_mat_q8_K_4x8_generic(const float * WSP_GGML_RESTRICT x, void * WSP_GGML_RESTRICT vy, int64_t k);
103
111
  void wsp_ggml_gemv_q4_0_4x4_q8_0_generic(int n, float * WSP_GGML_RESTRICT s, size_t bs, const void * WSP_GGML_RESTRICT vx, const void * WSP_GGML_RESTRICT vy, int nr, int nc);
104
112
  void wsp_ggml_gemv_q4_0_4x8_q8_0_generic(int n, float * WSP_GGML_RESTRICT s, size_t bs, const void * WSP_GGML_RESTRICT vx, const void * WSP_GGML_RESTRICT vy, int nr, int nc);
105
113
  void wsp_ggml_gemv_q4_0_8x8_q8_0_generic(int n, float * WSP_GGML_RESTRICT s, size_t bs, const void * WSP_GGML_RESTRICT vx, const void * WSP_GGML_RESTRICT vy, int nr, int nc);
114
+ void wsp_ggml_gemv_q4_K_8x4_q8_K_generic(int n, float * WSP_GGML_RESTRICT s, size_t bs, const void * WSP_GGML_RESTRICT vx, const void * WSP_GGML_RESTRICT vy, int nr, int nc);
106
115
  void wsp_ggml_gemv_q4_K_8x8_q8_K_generic(int n, float * WSP_GGML_RESTRICT s, size_t bs, const void * WSP_GGML_RESTRICT vx, const void * WSP_GGML_RESTRICT vy, int nr, int nc);
107
116
  void wsp_ggml_gemv_q2_K_8x8_q8_K_generic(int n, float * WSP_GGML_RESTRICT s, size_t bs, const void * WSP_GGML_RESTRICT vx, const void * WSP_GGML_RESTRICT vy, int nr, int nc);
108
117
  void wsp_ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * WSP_GGML_RESTRICT s, size_t bs, const void * WSP_GGML_RESTRICT vx, const void * WSP_GGML_RESTRICT vy, int nr, int nc);
@@ -110,10 +119,15 @@ void wsp_ggml_gemv_iq4_nl_8x8_q8_0_generic(int n, float * WSP_GGML_RESTRICT s, s
110
119
  void wsp_ggml_gemm_q4_0_4x4_q8_0_generic(int n, float * WSP_GGML_RESTRICT s, size_t bs, const void * WSP_GGML_RESTRICT vx, const void * WSP_GGML_RESTRICT vy, int nr, int nc);
111
120
  void wsp_ggml_gemm_q4_0_4x8_q8_0_generic(int n, float * WSP_GGML_RESTRICT s, size_t bs, const void * WSP_GGML_RESTRICT vx, const void * WSP_GGML_RESTRICT vy, int nr, int nc);
112
121
  void wsp_ggml_gemm_q4_0_8x8_q8_0_generic(int n, float * WSP_GGML_RESTRICT s, size_t bs, const void * WSP_GGML_RESTRICT vx, const void * WSP_GGML_RESTRICT vy, int nr, int nc);
122
+ void wsp_ggml_gemm_q4_K_8x4_q8_K_generic(int n, float * WSP_GGML_RESTRICT s, size_t bs, const void * WSP_GGML_RESTRICT vx, const void * WSP_GGML_RESTRICT vy, int nr, int nc);
113
123
  void wsp_ggml_gemm_q4_K_8x8_q8_K_generic(int n, float * WSP_GGML_RESTRICT s, size_t bs, const void * WSP_GGML_RESTRICT vx, const void * WSP_GGML_RESTRICT vy, int nr, int nc);
114
124
  void wsp_ggml_gemm_q2_K_8x8_q8_K_generic(int n, float * WSP_GGML_RESTRICT s, size_t bs, const void * WSP_GGML_RESTRICT vx, const void * WSP_GGML_RESTRICT vy, int nr, int nc);
115
125
  void wsp_ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * WSP_GGML_RESTRICT s, size_t bs, const void * WSP_GGML_RESTRICT vx, const void * WSP_GGML_RESTRICT vy, int nr, int nc);
116
126
  void wsp_ggml_gemm_iq4_nl_8x8_q8_0_generic(int n, float * WSP_GGML_RESTRICT s, size_t bs, const void * WSP_GGML_RESTRICT vx, const void * WSP_GGML_RESTRICT vy, int nr, int nc);
127
+ void wsp_ggml_gemv_q8_0_4x4_q8_0_generic(int n, float * WSP_GGML_RESTRICT s, size_t bs, const void * WSP_GGML_RESTRICT vx, const void * WSP_GGML_RESTRICT vy, int nr, int nc);
128
+ void wsp_ggml_gemv_q8_0_4x8_q8_0_generic(int n, float * WSP_GGML_RESTRICT s, size_t bs, const void * WSP_GGML_RESTRICT vx, const void * WSP_GGML_RESTRICT vy, int nr, int nc);
129
+ void wsp_ggml_gemm_q8_0_4x4_q8_0_generic(int n, float * WSP_GGML_RESTRICT s, size_t bs, const void * WSP_GGML_RESTRICT vx, const void * WSP_GGML_RESTRICT vy, int nr, int nc);
130
+ void wsp_ggml_gemm_q8_0_4x8_q8_0_generic(int n, float * WSP_GGML_RESTRICT s, size_t bs, const void * WSP_GGML_RESTRICT vx, const void * WSP_GGML_RESTRICT vy, int nr, int nc);
117
131
 
118
132
  #if defined(__cplusplus)
119
133
  } // extern "C"
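
Note: the three hunks above arrive without a file banner in this capture; their totals (14 added lines, none removed) match the package/cpp/ggml-cpu/repack.h entry in the file list, so they are best read as the new repack GEMV/GEMM declarations. Each dispatch name is paired with a portable _generic twin; on architectures without a tuned kernel, ggml's arch-fallback.h (also touched in this release) aliases one to the other. A hedged sketch of that pattern in C — illustrative only, not the verbatim arch-fallback.h contents:

// Sketch: with no tuned q8_0 repack kernels for the target,
// the dispatch names collapse onto the portable implementations.
#define wsp_ggml_gemv_q8_0_4x4_q8_0 wsp_ggml_gemv_q8_0_4x4_q8_0_generic
#define wsp_ggml_gemv_q8_0_4x8_q8_0 wsp_ggml_gemv_q8_0_4x8_q8_0_generic
#define wsp_ggml_gemm_q8_0_4x4_q8_0 wsp_ggml_gemm_q8_0_4x4_q8_0_generic
#define wsp_ggml_gemm_q8_0_4x8_q8_0 wsp_ggml_gemm_q8_0_4x8_q8_0_generic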
@@ -14,10 +14,6 @@
 #include <arm_neon.h>
 #endif
 
-#if defined(__F16C__)
-#include <immintrin.h>
-#endif
-
 #if defined(__riscv_v_intrinsic)
 #include <riscv_vector.h>
 #endif
@@ -160,18 +156,18 @@ inline static float wsp_ggml_lookup_fp16_to_fp32(wsp_ggml_fp16_t f) {
 #define WSP_GGML_F32xt svfloat32_t
 #define WSP_GGML_F32xt_ZERO svdup_n_f32(0.0f)
 #define WSP_GGML_F32xt_SET1(x) svdup_n_f32(x)
-#define WSP_GGML_F32xt_LOAD_IMPL(pg, a, ...) svld1_f32(pg, a)
-#define WSP_GGML_F32xt_LOAD(...) WSP_GGML_F32xt_LOAD_IMPL(DEFAULT_PG, __VA_ARGS__)
-#define WSP_GGML_F32xt_STORE_IMPL(pg,a,b) svst1_f32(pg, a, b)
-#define WSP_GGML_F32xt_STORE(...) WSP_GGML_F32xt_STORE_IMPL(DEFAULT_PG, __VA_ARGS__)
+#define WSP_GGML_F32xt_LOAD_IMPL(pg, a) svld1_f32(pg, a)
+#define WSP_GGML_F32xt_LOAD(a) WSP_GGML_F32xt_LOAD_IMPL(DEFAULT_PG, a)
+#define WSP_GGML_F32xt_STORE_IMPL(pg, a, b) svst1_f32(pg, a, b)
+#define WSP_GGML_F32xt_STORE(a, b) WSP_GGML_F32xt_STORE_IMPL(DEFAULT_PG, a, b)
 #define WSP_GGML_F32xt_FMA_IMPL(pg, a, b, c) svmad_f32_m(pg, b, c, a)
-#define WSP_GGML_F32xt_FMA(...) WSP_GGML_F32xt_FMA_IMPL(DEFAULT_PG, __VA_ARGS__)
+#define WSP_GGML_F32xt_FMA(a, b, c) WSP_GGML_F32xt_FMA_IMPL(DEFAULT_PG, a, b, c)
 #define WSP_GGML_F32xt_ADD_IMPL(pg, a, b) svadd_f32_m(pg, a, b)
-#define WSP_GGML_F32xt_ADD(...) WSP_GGML_F32xt_ADD_IMPL(DEFAULT_PG, __VA_ARGS__)
+#define WSP_GGML_F32xt_ADD(a, b) WSP_GGML_F32xt_ADD_IMPL(DEFAULT_PG, a, b)
 #define WSP_GGML_F32xt_MUL_IMPL(pg, a, b) svmul_f32_m(pg, a, b)
-#define WSP_GGML_F32xt_MUL(...) WSP_GGML_F32xt_MUL_IMPL(DEFAULT_PG, __VA_ARGS__)
+#define WSP_GGML_F32xt_MUL(a, b) WSP_GGML_F32xt_MUL_IMPL(DEFAULT_PG, a, b)
 #define WSP_GGML_F32xt_REDUCE_ONE_IMPL(pg, a) svaddv(pg, a)
-#define WSP_GGML_F32xt_REDUCE_ONE(...) WSP_GGML_F32xt_REDUCE_ONE_IMPL(DEFAULT_PG, __VA_ARGS__)
+#define WSP_GGML_F32xt_REDUCE_ONE(a) WSP_GGML_F32xt_REDUCE_ONE_IMPL(DEFAULT_PG, a)
 #define WSP_GGML_F32xt_REDUCE_IMPL(pg, res, sum1, sum2, sum3, sum4, sum5, sum6, sum7, sum8) \
 { \
     sum1 = svadd_f32_m(DEFAULT_PG, sum1, sum2); \
@@ -183,7 +179,8 @@ inline static float wsp_ggml_lookup_fp16_to_fp32(wsp_ggml_fp16_t f) {
     sum1 = svadd_f32_m(DEFAULT_PG, sum1, sum5); \
     (res) = (wsp_ggml_float) WSP_GGML_F32xt_REDUCE_ONE(sum1); \
 }
-#define WSP_GGML_F32xt_REDUCE(...) WSP_GGML_F32xt_REDUCE_IMPL(DEFAULT_PG, __VA_ARGS__)
+#define WSP_GGML_F32xt_REDUCE(res, sum1, sum2, sum3, sum4, sum5, sum6, sum7, sum8) \
+        WSP_GGML_F32xt_REDUCE_IMPL(DEFAULT_PG, res, sum1, sum2, sum3, sum4, sum5, sum6, sum7, sum8)
 
 #define WSP_GGML_F32_VEC WSP_GGML_F32xt
 #define WSP_GGML_F32_VEC_ZERO WSP_GGML_F32xt_ZERO
@@ -206,11 +203,11 @@ inline static float wsp_ggml_lookup_fp16_to_fp32(wsp_ggml_fp16_t f) {
 #define WSP_GGML_F32Cxt_STORE(dst_ptr, src_vec) svst1_f16(DEFAULT_PG16, (__fp16 *)(dst_ptr), (src_vec))
 
 #define WSP_GGML_F32Cxt_FMA_IMPL(pg, a, b, c) svmad_f16_x(pg, b, c, a)
-#define WSP_GGML_F32Cxt_FMA(...) WSP_GGML_F32Cxt_FMA_IMPL(DEFAULT_PG16, __VA_ARGS__)
+#define WSP_GGML_F32Cxt_FMA(a, b, c) WSP_GGML_F32Cxt_FMA_IMPL(DEFAULT_PG16, a, b, c)
 #define WSP_GGML_F32Cxt_ADD_IMPL(pg, a, b) svadd_f16_x(pg, a, b)
-#define WSP_GGML_F32Cxt_ADD(...) WSP_GGML_F32Cxt_ADD_IMPL(DEFAULT_PG16, __VA_ARGS__)
+#define WSP_GGML_F32Cxt_ADD(a, b) WSP_GGML_F32Cxt_ADD_IMPL(DEFAULT_PG16, a, b)
 #define WSP_GGML_F32Cxt_MUL_IMPL(pg, a, b) svmul_f16_x(pg, a, b)
-#define WSP_GGML_F32Cxt_MUL(...) WSP_GGML_F32Cxt_MUL_IMPL(DEFAULT_PG16, __VA_ARGS__)
+#define WSP_GGML_F32Cxt_MUL(a, b) WSP_GGML_F32Cxt_MUL_IMPL(DEFAULT_PG16, a, b)
 #define WSP_GGML_F32Cxt_REDUCE WSP_GGML_F16xt_REDUCE_MIXED
 
 #define WSP_GGML_F16x_VEC WSP_GGML_F32Cxt
@@ -224,7 +221,7 @@ inline static float wsp_ggml_lookup_fp16_to_fp32(wsp_ggml_fp16_t f) {
 #define WSP_GGML_F16x_VEC_REDUCE WSP_GGML_F32Cxt_REDUCE
 
 #define WSP_GGML_F16xt_REDUCE_ONE_IMPL(pg, a) svaddv_f16(pg, a)
-#define WSP_GGML_F16xt_REDUCE_ONE(...) WSP_GGML_F16xt_REDUCE_ONE_IMPL(DEFAULT_PG16, __VA_ARGS__)
+#define WSP_GGML_F16xt_REDUCE_ONE(a) WSP_GGML_F16xt_REDUCE_ONE_IMPL(DEFAULT_PG16, a)
 
 #define WSP_GGML_F16xt_REDUCE_MIXED_IMPL(pg16, res, sum1, sum2, sum3, sum4) \
 { \
@@ -234,7 +231,8 @@ inline static float wsp_ggml_lookup_fp16_to_fp32(wsp_ggml_fp16_t f) {
     __fp16 sum_f16 = svaddv_f16(pg16, sum1); \
     (res) = (wsp_ggml_float) sum_f16; \
 }
-#define WSP_GGML_F16xt_REDUCE_MIXED(...) WSP_GGML_F16xt_REDUCE_MIXED_IMPL(DEFAULT_PG16, __VA_ARGS__)
+#define WSP_GGML_F16xt_REDUCE_MIXED(res, sum1, sum2, sum3, sum4) \
+        WSP_GGML_F16xt_REDUCE_MIXED_IMPL(DEFAULT_PG16, res, sum1, sum2, sum3, sum4)
 
 // F16 NEON
 
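The simd-mappings.h hunks above trade __VA_ARGS__ forwarding for explicitly named macro parameters. Runtime behaviour is identical; the payoff is diagnostics, since a call with the wrong argument count now fails at the macro itself instead of expanding into a malformed intrinsic call. A minimal C illustration (the DEMO_* macro names are ours, not from the header):

#include <arm_sve.h>   // assumes an SVE-enabled toolchain

// Variadic forwarding: DEMO_ADD_VA(a) still expands, and the compiler
// only complains later that svadd_f32_m() received two arguments.
#define DEMO_ADD_VA(...)  svadd_f32_m(svptrue_b32(), __VA_ARGS__)

// Named parameters: DEMO_ADD(a) is rejected immediately with
// "macro requires 2 arguments, but only 1 given".
#define DEMO_ADD(a, b)    svadd_f32_m(svptrue_b32(), (a), (b))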
@@ -195,8 +195,48 @@ void wsp_ggml_vec_dot_bf16(int n, float * WSP_GGML_RESTRICT s, size_t bs, wsp_gg
     sumf += (wsp_ggml_float)_mm_cvtss_f32(g);
 
 #undef LOAD
-#endif
+#elif defined(__riscv_v_intrinsic) && defined(__riscv_zvfbfwma)
+    size_t vl = __riscv_vsetvlmax_e32m4();
+
+    // initialize accumulators to all zeroes
+    vfloat32m4_t vsum0 = __riscv_vfmv_v_f_f32m4(0.0f, vl);
+    vfloat32m4_t vsum1 = __riscv_vfmv_v_f_f32m4(0.0f, vl);
+
+    // calculate step size
+    const size_t epr = __riscv_vsetvlmax_e16m2();
+    const size_t step = epr * 2;
+    const int np = (n & ~(step - 1));
+
+    // unroll by 2
+    for (; i < np; i += step) {
+        vbfloat16m2_t ax0 = __riscv_vle16_v_bf16m2((const __bf16 *)&x[i], epr);
+        vbfloat16m2_t ay0 = __riscv_vle16_v_bf16m2((const __bf16 *)&y[i], epr);
+        vsum0 = __riscv_vfwmaccbf16_vv_f32m4(vsum0, ax0, ay0, epr);
+        __asm__ __volatile__ ("" ::: "memory");
+
+        vbfloat16m2_t ax1 = __riscv_vle16_v_bf16m2((const __bf16 *)&x[i + epr], epr);
+        vbfloat16m2_t ay1 = __riscv_vle16_v_bf16m2((const __bf16 *)&y[i + epr], epr);
+        vsum1 = __riscv_vfwmaccbf16_vv_f32m4(vsum1, ax1, ay1, epr);
+        __asm__ __volatile__ ("" ::: "memory");
+    }
 
+    // accumulate in 1 register
+    vsum0 = __riscv_vfadd_vv_f32m4(vsum0, vsum1, vl);
+
+    // leftovers
+    for (i = np; i < n; i += vl) {
+        vl = __riscv_vsetvl_e16m2(n - i);
+        vbfloat16m2_t ax0 = __riscv_vle16_v_bf16m2((const __bf16 *)&x[i], vl);
+        vbfloat16m2_t ay0 = __riscv_vle16_v_bf16m2((const __bf16 *)&y[i], vl);
+        vsum0 = __riscv_vfwmaccbf16_vv_f32m4(vsum0, ax0, ay0, vl);
+    }
+
+    // reduce
+    vl = __riscv_vsetvlmax_e32m4();
+    vfloat32m1_t redsum = __riscv_vfredusum_vs_f32m4_f32m1(vsum0, __riscv_vfmv_v_f_f32m1(0.0f, 1), vl);
+    sumf += __riscv_vfmv_f_s_f32m1_f32(redsum);
+
+#endif
     for (; i < n; ++i) {
         sumf += (wsp_ggml_float)(WSP_GGML_BF16_TO_FP32(x[i]) *
                                  WSP_GGML_BF16_TO_FP32(y[i]));
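
The new bf16 dot-product branch (gated on both __riscv_v_intrinsic and __riscv_zvfbfwma) uses the canonical RVV strip-mining idiom for its tail: __riscv_vsetvl_e16m2(n - i) lets the hardware choose how many lanes each residual pass covers, so the vector branch needs no scalar epilogue, and the empty __asm__ __volatile__ ("" ::: "memory") statements act as compiler barriers that keep the two unrolled streams from being re-fused. A stand-alone sketch of the strip-mining pattern, f32 for brevity (the function and its framing are ours):

#include <riscv_vector.h>  // requires an RVV-intrinsics toolchain

// Scale y[0..n) in place; vl is recomputed each pass, so the final
// iteration automatically narrows to the elements that remain.
static void scale_f32_rvv(float *y, float s, size_t n) {
    for (size_t i = 0, vl; i < n; i += vl) {
        vl = __riscv_vsetvl_e32m4(n - i);           // lanes this pass
        vfloat32m4_t v = __riscv_vle32_v_f32m4(y + i, vl);
        v = __riscv_vfmul_vf_f32m4(v, s, vl);
        __riscv_vse32_v_f32m4(y + i, v, vl);
    }
}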
@@ -224,13 +224,71 @@ inline static void wsp_ggml_vec_dot_f16_unroll(const int n, const int xs, float
     }
     WSP_GGML_F16x_VEC_REDUCE(sumf[0], sum_00, sum_01, sum_02, sum_03);
     WSP_GGML_F16x_VEC_REDUCE(sumf[1], sum_10, sum_11, sum_12, sum_13);
-#elif defined(__riscv_v_intrinsic)
-    // todo: RVV impl
-    for (int i = 0; i < n; ++i) {
-        for (int j = 0; j < WSP_GGML_VEC_DOT_UNROLL; ++j) {
-            sumf[j] += (wsp_ggml_float)(WSP_GGML_CPU_FP16_TO_FP32(x[j][i])*WSP_GGML_CPU_FP16_TO_FP32(y[i]));
-        }
-    }
+
+#elif defined(__riscv_v_intrinsic) && defined(__riscv_zvfh)
+    size_t vl = __riscv_vsetvlmax_e32m4();
+
+    // initialize accumulators to all zeroes
+    vfloat32m4_t vsum0_0 = __riscv_vfmv_v_f_f32m4(0.0f, vl);
+    vfloat32m4_t vsum0_1 = __riscv_vfmv_v_f_f32m4(0.0f, vl);
+    vfloat32m4_t vsum1_0 = __riscv_vfmv_v_f_f32m4(0.0f, vl);
+    vfloat32m4_t vsum1_1 = __riscv_vfmv_v_f_f32m4(0.0f, vl);
+
+    // calculate step size
+    const size_t epr = __riscv_vsetvlmax_e16m2();
+    const size_t step = epr * 2;
+    const int np = (n & ~(step - 1));
+
+    // unroll by 2 along the row dimension
+    for (int i = 0; i < np; i += step) {
+        vfloat16m2_t ay0   = __riscv_vle16_v_f16m2((const _Float16 *)(y + i), epr);
+        vfloat16m2_t ax0_0 = __riscv_vle16_v_f16m2((const _Float16 *)(x[0] + i), epr);
+        vfloat16m2_t ax1_0 = __riscv_vle16_v_f16m2((const _Float16 *)(x[1] + i), epr);
+        vsum0_0 = __riscv_vfwmacc_vv_f32m4(vsum0_0, ax0_0, ay0, epr);
+        vsum1_0 = __riscv_vfwmacc_vv_f32m4(vsum1_0, ax1_0, ay0, epr);
+
+        vfloat16m2_t ay1   = __riscv_vle16_v_f16m2((const _Float16 *)(y + i + epr), epr);
+        vfloat16m2_t ax0_1 = __riscv_vle16_v_f16m2((const _Float16 *)(x[0] + i + epr), epr);
+        vfloat16m2_t ax1_1 = __riscv_vle16_v_f16m2((const _Float16 *)(x[1] + i + epr), epr);
+        vsum0_1 = __riscv_vfwmacc_vv_f32m4(vsum0_1, ax0_1, ay1, epr);
+        vsum1_1 = __riscv_vfwmacc_vv_f32m4(vsum1_1, ax1_1, ay1, epr);
+    }
+
+    vfloat32m4_t vsum0 = __riscv_vfadd_vv_f32m4(vsum0_0, vsum0_1, vl);
+    vfloat32m4_t vsum1 = __riscv_vfadd_vv_f32m4(vsum1_0, vsum1_1, vl);
+
+    // leftovers
+    for (int i = np; i < n; i += vl) {
+        vl = __riscv_vsetvl_e16m2(n - i);
+        vfloat16m2_t ay  = __riscv_vle16_v_f16m2((const _Float16 *)(y + i), vl);
+        vfloat16m2_t ax0 = __riscv_vle16_v_f16m2((const _Float16 *)(x[0] + i), vl);
+        vfloat16m2_t ax1 = __riscv_vle16_v_f16m2((const _Float16 *)(x[1] + i), vl);
+
+        vsum0 = __riscv_vfwmacc_vv_f32m4(vsum0, ax0, ay, vl);
+        vsum1 = __riscv_vfwmacc_vv_f32m4(vsum1, ax1, ay, vl);
+    }
+
+    // reduce
+    vl = __riscv_vsetvlmax_e32m2();
+    vfloat32m2_t acc0_0 = __riscv_vfadd_vv_f32m2(__riscv_vget_v_f32m4_f32m2(vsum0, 0),
+                                                 __riscv_vget_v_f32m4_f32m2(vsum0, 1), vl);
+    vl = __riscv_vsetvlmax_e32m1();
+    vfloat32m1_t acc0_1 = __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m2_f32m1(acc0_0, 0),
+                                                 __riscv_vget_v_f32m2_f32m1(acc0_0, 1), vl);
+    vfloat32m1_t redsum0 = __riscv_vfredusum_vs_f32m1_f32m1(
+        acc0_1, __riscv_vfmv_v_f_f32m1(0.0f, 1), vl);
+
+    vl = __riscv_vsetvlmax_e32m2();
+    vfloat32m2_t acc1_0 = __riscv_vfadd_vv_f32m2(__riscv_vget_v_f32m4_f32m2(vsum1, 0),
+                                                 __riscv_vget_v_f32m4_f32m2(vsum1, 1), vl);
+    vl = __riscv_vsetvlmax_e32m1();
+    vfloat32m1_t acc1_1 = __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m2_f32m1(acc1_0, 0),
+                                                 __riscv_vget_v_f32m2_f32m1(acc1_0, 1), vl);
+    vfloat32m1_t redsum1 = __riscv_vfredusum_vs_f32m1_f32m1(
+        acc1_1, __riscv_vfmv_v_f_f32m1(0.0f, 1), vl);
+    sumf[0] = __riscv_vfmv_f_s_f32m1_f32(redsum0);
+    sumf[1] = __riscv_vfmv_f_s_f32m1_f32(redsum1);
+
 #else
     const int np = (n & ~(WSP_GGML_F16_STEP - 1));
 
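This hunk replaces the old "todo: RVV impl" scalar fallback in wsp_ggml_vec_dot_f16_unroll with a real zvfh path. Each of the two output rows keeps two widening f32m4 accumulators, which are finally folded with the usual LMUL-halving tree: vget splits m4 into two m2 halves, the halves are added, the step repeats down to m1, and vfredusum produces the scalar. The same tree isolated into a helper (our framing; the intrinsic sequence mirrors the diff):

#include <riscv_vector.h>

// Horizontal sum of an f32m4 accumulator: halve LMUL twice, then reduce.
static float hsum_f32m4(vfloat32m4_t acc) {
    size_t vl = __riscv_vsetvlmax_e32m2();
    vfloat32m2_t h2 = __riscv_vfadd_vv_f32m2(
        __riscv_vget_v_f32m4_f32m2(acc, 0),
        __riscv_vget_v_f32m4_f32m2(acc, 1), vl);
    vl = __riscv_vsetvlmax_e32m1();
    vfloat32m1_t h1 = __riscv_vfadd_vv_f32m1(
        __riscv_vget_v_f32m2_f32m1(h2, 0),
        __riscv_vget_v_f32m2_f32m1(h2, 1), vl);
    vfloat32m1_t zero = __riscv_vfmv_v_f_f32m1(0.0f, 1);
    return __riscv_vfmv_f_s_f32m1_f32(
        __riscv_vfredusum_vs_f32m1_f32m1(h1, zero, vl));
}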
@@ -397,119 +455,142 @@ inline static void wsp_ggml_vec_mad_f32(const int n, float * WSP_GGML_RESTRICT y
 }
 
 inline static void wsp_ggml_vec_mad_f16(const int n, wsp_ggml_fp16_t * WSP_GGML_RESTRICT y, const wsp_ggml_fp16_t * WSP_GGML_RESTRICT x, const float v) {
-#if defined(WSP_GGML_SIMD)
-    #if defined(__ARM_FEATURE_SVE)
-        const int sve_register_length = svcntb() * 8;
-        const int wsp_ggml_f16_epr = sve_register_length / 16;
-        const int wsp_ggml_f16_step = 8 * wsp_ggml_f16_epr;
+#if defined(WSP_GGML_SIMD) && defined(__ARM_FEATURE_SVE)
+    const int sve_register_length = svcntb() * 8;
+    const int wsp_ggml_f16_epr = sve_register_length / 16;
+    const int wsp_ggml_f16_step = 8 * wsp_ggml_f16_epr;
 
-        WSP_GGML_F16x_VEC vx = WSP_GGML_F16x_VEC_SET1(v);
+    WSP_GGML_F16x_VEC vx = WSP_GGML_F16x_VEC_SET1(v);
 
-        const int np= (n & ~(wsp_ggml_f16_step - 1));
+    int np = (n & ~(wsp_ggml_f16_step - 1));
 
-        svfloat16_t ax1, ax2, ax3, ax4, ax5, ax6, ax7, ax8;
-        svfloat16_t ay1, ay2, ay3, ay4, ay5, ay6, ay7, ay8;
-        for (int i = 0; i < np; i += wsp_ggml_f16_step) {
-            ax1 = WSP_GGML_F16x_VEC_LOAD(x + i + 0 * wsp_ggml_f16_epr, 0);
-            ay1 = WSP_GGML_F16x_VEC_LOAD(y + i + 0 * wsp_ggml_f16_epr, 0);
-            ay1 = WSP_GGML_F16x_VEC_FMA(ay1, ax1, vx);
+    svfloat16_t ax1, ax2, ax3, ax4, ax5, ax6, ax7, ax8;
+    svfloat16_t ay1, ay2, ay3, ay4, ay5, ay6, ay7, ay8;
+    for (int i = 0; i < np; i += wsp_ggml_f16_step) {
+        ax1 = WSP_GGML_F16x_VEC_LOAD(x + i + 0 * wsp_ggml_f16_epr, 0);
+        ay1 = WSP_GGML_F16x_VEC_LOAD(y + i + 0 * wsp_ggml_f16_epr, 0);
+        ay1 = WSP_GGML_F16x_VEC_FMA(ay1, ax1, vx);
 
-            WSP_GGML_F16x_VEC_STORE(y + i + 0 * wsp_ggml_f16_epr, ay1, 0);
+        WSP_GGML_F16x_VEC_STORE(y + i + 0 * wsp_ggml_f16_epr, ay1, 0);
 
-            ax2 = WSP_GGML_F16x_VEC_LOAD(x + i + 1 * wsp_ggml_f16_epr, 1);
-            ay2 = WSP_GGML_F16x_VEC_LOAD(y + i + 1 * wsp_ggml_f16_epr, 1);
-            ay2 = WSP_GGML_F16x_VEC_FMA(ay2, ax2, vx);
+        ax2 = WSP_GGML_F16x_VEC_LOAD(x + i + 1 * wsp_ggml_f16_epr, 1);
+        ay2 = WSP_GGML_F16x_VEC_LOAD(y + i + 1 * wsp_ggml_f16_epr, 1);
+        ay2 = WSP_GGML_F16x_VEC_FMA(ay2, ax2, vx);
 
-            WSP_GGML_F16x_VEC_STORE(y + i + 1 * wsp_ggml_f16_epr, ay2, 1);
+        WSP_GGML_F16x_VEC_STORE(y + i + 1 * wsp_ggml_f16_epr, ay2, 1);
 
-            ax3 = WSP_GGML_F16x_VEC_LOAD(x + i + 2 * wsp_ggml_f16_epr, 2);
-            ay3 = WSP_GGML_F16x_VEC_LOAD(y + i + 2 * wsp_ggml_f16_epr, 2);
-            ay3 = WSP_GGML_F16x_VEC_FMA(ay3, ax3, vx);
+        ax3 = WSP_GGML_F16x_VEC_LOAD(x + i + 2 * wsp_ggml_f16_epr, 2);
+        ay3 = WSP_GGML_F16x_VEC_LOAD(y + i + 2 * wsp_ggml_f16_epr, 2);
+        ay3 = WSP_GGML_F16x_VEC_FMA(ay3, ax3, vx);
 
-            WSP_GGML_F16x_VEC_STORE(y + i + 2 * wsp_ggml_f16_epr, ay3, 2);
+        WSP_GGML_F16x_VEC_STORE(y + i + 2 * wsp_ggml_f16_epr, ay3, 2);
 
-            ax4 = WSP_GGML_F16x_VEC_LOAD(x + i + 3 * wsp_ggml_f16_epr, 3);
-            ay4 = WSP_GGML_F16x_VEC_LOAD(y + i + 3 * wsp_ggml_f16_epr, 3);
-            ay4 = WSP_GGML_F16x_VEC_FMA(ay4, ax4, vx);
+        ax4 = WSP_GGML_F16x_VEC_LOAD(x + i + 3 * wsp_ggml_f16_epr, 3);
+        ay4 = WSP_GGML_F16x_VEC_LOAD(y + i + 3 * wsp_ggml_f16_epr, 3);
+        ay4 = WSP_GGML_F16x_VEC_FMA(ay4, ax4, vx);
 
-            WSP_GGML_F16x_VEC_STORE(y + i + 3 * wsp_ggml_f16_epr, ay4, 3);
+        WSP_GGML_F16x_VEC_STORE(y + i + 3 * wsp_ggml_f16_epr, ay4, 3);
 
-            ax5 = WSP_GGML_F16x_VEC_LOAD(x + i + 4 * wsp_ggml_f16_epr, 4);
-            ay5 = WSP_GGML_F16x_VEC_LOAD(y + i + 4 * wsp_ggml_f16_epr, 4);
-            ay5 = WSP_GGML_F16x_VEC_FMA(ay5, ax5, vx);
+        ax5 = WSP_GGML_F16x_VEC_LOAD(x + i + 4 * wsp_ggml_f16_epr, 4);
+        ay5 = WSP_GGML_F16x_VEC_LOAD(y + i + 4 * wsp_ggml_f16_epr, 4);
+        ay5 = WSP_GGML_F16x_VEC_FMA(ay5, ax5, vx);
 
-            WSP_GGML_F16x_VEC_STORE(y + i + 4 * wsp_ggml_f16_epr, ay5, 4);
+        WSP_GGML_F16x_VEC_STORE(y + i + 4 * wsp_ggml_f16_epr, ay5, 4);
 
-            ax6 = WSP_GGML_F16x_VEC_LOAD(x + i + 5 * wsp_ggml_f16_epr, 5);
-            ay6 = WSP_GGML_F16x_VEC_LOAD(y + i + 5 * wsp_ggml_f16_epr, 5);
-            ay6 = WSP_GGML_F16x_VEC_FMA(ay6, ax6, vx);
+        ax6 = WSP_GGML_F16x_VEC_LOAD(x + i + 5 * wsp_ggml_f16_epr, 5);
+        ay6 = WSP_GGML_F16x_VEC_LOAD(y + i + 5 * wsp_ggml_f16_epr, 5);
+        ay6 = WSP_GGML_F16x_VEC_FMA(ay6, ax6, vx);
 
-            WSP_GGML_F16x_VEC_STORE(y + i + 5 * wsp_ggml_f16_epr, ay6, 5);
+        WSP_GGML_F16x_VEC_STORE(y + i + 5 * wsp_ggml_f16_epr, ay6, 5);
 
-            ax7 = WSP_GGML_F16x_VEC_LOAD(x + i + 6 * wsp_ggml_f16_epr, 6);
-            ay7 = WSP_GGML_F16x_VEC_LOAD(y + i + 6 * wsp_ggml_f16_epr, 6);
-            ay7 = WSP_GGML_F16x_VEC_FMA(ay7, ax7, vx);
+        ax7 = WSP_GGML_F16x_VEC_LOAD(x + i + 6 * wsp_ggml_f16_epr, 6);
+        ay7 = WSP_GGML_F16x_VEC_LOAD(y + i + 6 * wsp_ggml_f16_epr, 6);
+        ay7 = WSP_GGML_F16x_VEC_FMA(ay7, ax7, vx);
 
-            WSP_GGML_F16x_VEC_STORE(y + i + 6 * wsp_ggml_f16_epr, ay7, 6);
+        WSP_GGML_F16x_VEC_STORE(y + i + 6 * wsp_ggml_f16_epr, ay7, 6);
 
-            ax8 = WSP_GGML_F16x_VEC_LOAD(x + i + 7 * wsp_ggml_f16_epr, 7);
-            ay8 = WSP_GGML_F16x_VEC_LOAD(y + i + 7 * wsp_ggml_f16_epr, 7);
-            ay8 = WSP_GGML_F16x_VEC_FMA(ay8, ax8, vx);
+        ax8 = WSP_GGML_F16x_VEC_LOAD(x + i + 7 * wsp_ggml_f16_epr, 7);
+        ay8 = WSP_GGML_F16x_VEC_LOAD(y + i + 7 * wsp_ggml_f16_epr, 7);
+        ay8 = WSP_GGML_F16x_VEC_FMA(ay8, ax8, vx);
 
-            WSP_GGML_F16x_VEC_STORE(y + i + 7 * wsp_ggml_f16_epr, ay8, 7);
-        }
-        const int np2 = (n & ~(wsp_ggml_f16_epr - 1));
-        for (int k = np; k < np2; k += wsp_ggml_f16_epr) {
-            svfloat16_t rx = WSP_GGML_F16x_VEC_LOAD(x + k, 0);
-            svfloat16_t ry = WSP_GGML_F16x_VEC_LOAD(y + k, 0);
-            ry = WSP_GGML_F16x_VEC_FMA(ry, rx, vx);
-
-            WSP_GGML_F16x_VEC_STORE(y + k, ry, 0);
-        }
+        WSP_GGML_F16x_VEC_STORE(y + i + 7 * wsp_ggml_f16_epr, ay8, 7);
+    }
+    const int np2 = (n & ~(wsp_ggml_f16_epr - 1));
+    for (int k = np; k < np2; k += wsp_ggml_f16_epr) {
+        svfloat16_t rx = WSP_GGML_F16x_VEC_LOAD(x + k, 0);
+        svfloat16_t ry = WSP_GGML_F16x_VEC_LOAD(y + k, 0);
+        ry = WSP_GGML_F16x_VEC_FMA(ry, rx, vx);
 
-        if (np2 < n) {
-            svbool_t pg = svwhilelt_b16(np2, n);
-            svfloat16_t hx = svld1_f16(pg, (const __fp16 *)(x + np2));
-            svfloat16_t hy = svld1_f16(pg, (const __fp16 *)(y + np2));
-            hy = svmad_f16_x(pg, hx, vx, hy);
-            svst1_f16(pg, (__fp16 *)(y + np2), hy);
-        }
+        WSP_GGML_F16x_VEC_STORE(y + k, ry, 0);
+    }
 
-    #elif defined(__riscv_v_intrinsic)
-        // todo: RVV impl
-        // scalar
-        for (int i = 0; i < n; ++i) {
-            y[i] = WSP_GGML_CPU_FP32_TO_FP16(WSP_GGML_CPU_FP16_TO_FP32(y[i]) + WSP_GGML_CPU_FP16_TO_FP32(x[i])*v);
-        }
-    #else
-        const int np = (n & ~(WSP_GGML_F16_STEP - 1));
+    if (np2 < n) {
+        svbool_t pg = svwhilelt_b16(np2, n);
+        svfloat16_t hx = svld1_f16(pg, (const __fp16 *)(x + np2));
+        svfloat16_t hy = svld1_f16(pg, (const __fp16 *)(y + np2));
+        hy = svmad_f16_x(pg, hx, vx, hy);
+        svst1_f16(pg, (__fp16 *)(y + np2), hy);
+    }
+    np = n;
+#elif defined(__riscv_zvfh) // implies __riscv_v_intrinsic
+    const wsp_ggml_fp16_t s = WSP_GGML_CPU_FP32_TO_FP16(v);
+    const _Float16 scale = *(const _Float16*)(&s);
+
+    // calculate step size
+    const int epr = __riscv_vsetvlmax_e16m4();
+    const int step = epr * 2;
+    int np = (n & ~(step - 1));
+
+    // unroll by 2
+    for (int i = 0; i < np; i += step) {
+        vfloat16m4_t ax0 = __riscv_vle16_v_f16m4((const _Float16*)x + i, epr);
+        vfloat16m4_t ay0 = __riscv_vle16_v_f16m4((const _Float16*)y + i, epr);
+        ay0 = __riscv_vfmacc_vf_f16m4(ay0, scale, ax0, epr);
+        __riscv_vse16_v_f16m4((_Float16*)y + i, ay0, epr);
+        __asm__ __volatile__ ("" ::: "memory");
+
+        vfloat16m4_t ax1 = __riscv_vle16_v_f16m4((const _Float16*)x + i + epr, epr);
+        vfloat16m4_t ay1 = __riscv_vle16_v_f16m4((const _Float16*)y + i + epr, epr);
+        ay1 = __riscv_vfmacc_vf_f16m4(ay1, scale, ax1, epr);
+        __riscv_vse16_v_f16m4((_Float16*)y + i + epr, ay1, epr);
+        __asm__ __volatile__ ("" ::: "memory");
+    }
 
-        WSP_GGML_F16_VEC vx = WSP_GGML_F16_VEC_SET1(v);
+    // leftovers
+    int vl;
+    for (int i = np; i < n; i += vl) {
+        vl = __riscv_vsetvl_e16m4(n - i);
+        vfloat16m4_t ax0 = __riscv_vle16_v_f16m4((const _Float16*)x + i, vl);
+        vfloat16m4_t ay0 = __riscv_vle16_v_f16m4((const _Float16*)y + i, vl);
+        ay0 = __riscv_vfmacc_vf_f16m4(ay0, scale, ax0, vl);
+        __riscv_vse16_v_f16m4((_Float16*)y + i, ay0, vl);
+    }
+    np = n;
+#elif defined(WSP_GGML_SIMD)
+    const int np = (n & ~(WSP_GGML_F16_STEP - 1));
 
-        WSP_GGML_F16_VEC ax[WSP_GGML_F16_ARR];
-        WSP_GGML_F16_VEC ay[WSP_GGML_F16_ARR];
+    WSP_GGML_F16_VEC vx = WSP_GGML_F16_VEC_SET1(v);
 
-        for (int i = 0; i < np; i += WSP_GGML_F16_STEP) {
-            for (int j = 0; j < WSP_GGML_F16_ARR; j++) {
-                ax[j] = WSP_GGML_F16_VEC_LOAD(x + i + j*WSP_GGML_F16_EPR, j);
-                ay[j] = WSP_GGML_F16_VEC_LOAD(y + i + j*WSP_GGML_F16_EPR, j);
-                ay[j] = WSP_GGML_F16_VEC_FMA(ay[j], ax[j], vx);
+    WSP_GGML_F16_VEC ax[WSP_GGML_F16_ARR];
+    WSP_GGML_F16_VEC ay[WSP_GGML_F16_ARR];
 
-                WSP_GGML_F16_VEC_STORE(y + i + j*WSP_GGML_F16_EPR, ay, j);
-            }
-        }
+    for (int i = 0; i < np; i += WSP_GGML_F16_STEP) {
+        for (int j = 0; j < WSP_GGML_F16_ARR; j++) {
+            ax[j] = WSP_GGML_F16_VEC_LOAD(x + i + j*WSP_GGML_F16_EPR, j);
+            ay[j] = WSP_GGML_F16_VEC_LOAD(y + i + j*WSP_GGML_F16_EPR, j);
+            ay[j] = WSP_GGML_F16_VEC_FMA(ay[j], ax[j], vx);
 
-        // leftovers
-        for (int i = np; i < n; ++i) {
-            y[i] = WSP_GGML_CPU_FP32_TO_FP16(WSP_GGML_CPU_FP16_TO_FP32(y[i]) + WSP_GGML_CPU_FP16_TO_FP32(x[i])*v);
+            WSP_GGML_F16_VEC_STORE(y + i + j*WSP_GGML_F16_EPR, ay, j);
         }
-    #endif
+    }
 #else
-    // scalar
-    for (int i = 0; i < n; ++i) {
+    const int np = 0;
+#endif
+
+    // leftovers
+    for (int i = np; i < n; ++i) {
         y[i] = WSP_GGML_CPU_FP32_TO_FP16(WSP_GGML_CPU_FP16_TO_FP32(y[i]) + WSP_GGML_CPU_FP16_TO_FP32(x[i])*v);
     }
-#endif
 }
 
 // xs and vs are byte strides of x and v
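
Besides adding a zvfh branch, this hunk restructures wsp_ggml_vec_mad_f16 so that every branch reports how much it consumed: the SVE and RVV paths handle their own tails and set np = n, the generic SIMD path leaves np at the last whole block, and the no-SIMD build sets np = 0, after which a single shared scalar loop finishes the leftovers. The control-flow skeleton, compilable on its own (the DEMO_* feature macros are hypothetical placeholders):

#define DEMO_STEP 16                       /* stand-in for WSP_GGML_F16_STEP */

static void demo_axpy(int n, float *y, const float *x, float v) {
#if defined(DEMO_FULL_TAIL_SIMD)           /* SVE/RVV-style: tail done in-branch */
    int np = n;                            /* vector code consumed everything */
#elif defined(DEMO_BLOCK_SIMD)             /* fixed-width SIMD: whole blocks only */
    const int np = (n & ~(DEMO_STEP - 1));
#else
    const int np = 0;                      /* scalar-only build */
#endif
    for (int i = np; i < n; ++i) {         /* one shared leftover loop */
        y[i] += x[i]*v;
    }
}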
@@ -698,60 +779,82 @@ inline static void wsp_ggml_vec_scale_f32(const int n, float * y, const float
 }
 
 inline static void wsp_ggml_vec_scale_f16(const int n, wsp_ggml_fp16_t * y, const float v) {
-#if defined(WSP_GGML_SIMD)
-    #if defined(__ARM_FEATURE_SVE)
-        const int sve_register_length = svcntb() * 8;
-        const int wsp_ggml_f16_epr = sve_register_length / 16;
-        const int wsp_ggml_f16_step = 2 * wsp_ggml_f16_epr;
-
-        WSP_GGML_F16x_VEC vx = WSP_GGML_F16x_VEC_SET1(v);
-        const int np = (n & ~(wsp_ggml_f16_step - 1));
-        svfloat16_t ay1, ay2;
-
-        for (int i = 0; i < np; i += wsp_ggml_f16_step) {
-            ay1 = WSP_GGML_F16x_VEC_LOAD(y + i + 0*wsp_ggml_f16_epr, 0);
-            ay1 = WSP_GGML_F16x_VEC_MUL(ay1, vx);
-            WSP_GGML_F16x_VEC_STORE(y + i + 0*wsp_ggml_f16_epr, ay1, 0);
+#if defined(WSP_GGML_SIMD) && defined(__ARM_FEATURE_SVE)
+    const int sve_register_length = svcntb() * 8;
+    const int wsp_ggml_f16_epr = sve_register_length / 16;
+    const int wsp_ggml_f16_step = 2 * wsp_ggml_f16_epr;
+
+    WSP_GGML_F16x_VEC vx = WSP_GGML_F16x_VEC_SET1(v);
+    const int np = (n & ~(wsp_ggml_f16_step - 1));
+    svfloat16_t ay1, ay2;
+
+    for (int i = 0; i < np; i += wsp_ggml_f16_step) {
+        ay1 = WSP_GGML_F16x_VEC_LOAD(y + i + 0*wsp_ggml_f16_epr, 0);
+        ay1 = WSP_GGML_F16x_VEC_MUL(ay1, vx);
+        WSP_GGML_F16x_VEC_STORE(y + i + 0*wsp_ggml_f16_epr, ay1, 0);
+
+        ay2 = WSP_GGML_F16x_VEC_LOAD(y + i + 1*wsp_ggml_f16_epr, 1);
+        ay2 = WSP_GGML_F16x_VEC_MUL(ay2, vx);
+        WSP_GGML_F16x_VEC_STORE(y + i + 1*wsp_ggml_f16_epr, ay2, 1);
+    }
+    // leftovers
+    // maximum number of leftover elements will be less that ggmlF_16x_epr. Apply predicated svmad on available elements only
+    if (np < n) {
+        svbool_t pg = svwhilelt_b16(np, n);
+        svfloat16_t hy = svld1_f16(pg, (__fp16 *)(y + np));
+        svfloat16_t out = svmul_f16_m(pg, hy, vx);
+        svst1_f16(pg, (__fp16 *)(y + np), out);
+    }
+#elif defined(__riscv_v_intrinsic) && defined(__riscv_zvfh)
+    const wsp_ggml_fp16_t s = WSP_GGML_CPU_FP32_TO_FP16(v);
+    const _Float16 scale = *(const _Float16*)(&s);
+
+    // calculate step size
+    const int epr = __riscv_vsetvlmax_e16m4();
+    const int step = epr * 2;
+    const int np = (n & ~(step - 1));
+
+    // unroll by 2
+    for (int i = 0; i < np; i += step) {
+        vfloat16m4_t ay0 = __riscv_vle16_v_f16m4((const _Float16*)y + i, epr);
+        ay0 = __riscv_vfmul_vf_f16m4(ay0, scale, epr);
+        __riscv_vse16_v_f16m4((_Float16*)y + i, ay0, epr);
+        __asm__ __volatile__ ("" ::: "memory");
+
+        vfloat16m4_t ay1 = __riscv_vle16_v_f16m4((const _Float16*)y + i + epr, epr);
+        ay1 = __riscv_vfmul_vf_f16m4(ay1, scale, epr);
+        __riscv_vse16_v_f16m4((_Float16*)y + i + epr, ay1, epr);
+        __asm__ __volatile__ ("" ::: "memory");
+    }
 
-            ay2 = WSP_GGML_F16x_VEC_LOAD(y + i + 1*wsp_ggml_f16_epr, 1);
-            ay2 = WSP_GGML_F16x_VEC_MUL(ay2, vx);
-            WSP_GGML_F16x_VEC_STORE(y + i + 1*wsp_ggml_f16_epr, ay2, 1);
-        }
-        // leftovers
-        // maximum number of leftover elements will be less that ggmlF_16x_epr. Apply predicated svmad on available elements only
-        if (np < n) {
-            svbool_t pg = svwhilelt_b16(np, n);
-            svfloat16_t hy = svld1_f16(pg, (__fp16 *)(y + np));
-            svfloat16_t out = svmul_f16_m(pg, hy, vx);
-            svst1_f16(pg, (__fp16 *)(y + np), out);
-        }
-    #elif defined(__riscv_v_intrinsic)
-        // todo: RVV impl
-        // scalar
-        for (int i = 0; i < n; ++i) {
-            y[i] = WSP_GGML_CPU_FP32_TO_FP16(WSP_GGML_CPU_FP16_TO_FP32(y[i])*v);
-        }
-    #else
-        const int np = (n & ~(WSP_GGML_F16_STEP - 1));
+    // leftovers
+    int vl;
+    for (int i = np; i < n; i += vl) {
+        vl = __riscv_vsetvl_e16m4(n - i);
+        vfloat16m4_t ay0 = __riscv_vle16_v_f16m4((const _Float16*)y + i, vl);
+        ay0 = __riscv_vfmul_vf_f16m4(ay0, scale, vl);
+        __riscv_vse16_v_f16m4((_Float16*)y + i, ay0, vl);
+    }
+#elif defined(WSP_GGML_SIMD)
+    const int np = (n & ~(WSP_GGML_F16_STEP - 1));
 
-        WSP_GGML_F16_VEC vx = WSP_GGML_F16_VEC_SET1(v);
+    WSP_GGML_F16_VEC vx = WSP_GGML_F16_VEC_SET1(v);
 
-        WSP_GGML_F16_VEC ay[WSP_GGML_F16_ARR];
+    WSP_GGML_F16_VEC ay[WSP_GGML_F16_ARR];
 
-        for (int i = 0; i < np; i += WSP_GGML_F16_STEP) {
-            for (int j = 0; j < WSP_GGML_F16_ARR; j++) {
-                ay[j] = WSP_GGML_F16_VEC_LOAD(y + i + j*WSP_GGML_F16_EPR, j);
-                ay[j] = WSP_GGML_F16_VEC_MUL(ay[j], vx);
+    for (int i = 0; i < np; i += WSP_GGML_F16_STEP) {
+        for (int j = 0; j < WSP_GGML_F16_ARR; j++) {
+            ay[j] = WSP_GGML_F16_VEC_LOAD(y + i + j*WSP_GGML_F16_EPR, j);
+            ay[j] = WSP_GGML_F16_VEC_MUL(ay[j], vx);
 
-                WSP_GGML_F16_VEC_STORE(y + i + j*WSP_GGML_F16_EPR, ay, j);
-            }
+            WSP_GGML_F16_VEC_STORE(y + i + j*WSP_GGML_F16_EPR, ay, j);
         }
+    }
 
-        // leftovers
-        for (int i = np; i < n; ++i) {
-            y[i] = WSP_GGML_CPU_FP32_TO_FP16(WSP_GGML_CPU_FP16_TO_FP32(y[i])*v);
-        }
-    #endif
+    // leftovers
+    for (int i = np; i < n; ++i) {
+        y[i] = WSP_GGML_CPU_FP32_TO_FP16(WSP_GGML_CPU_FP16_TO_FP32(y[i])*v);
+    }
 #else
     // scalar
     for (int i = 0; i < n; ++i) {
package/cpp/ggml-cpu.h CHANGED
@@ -99,6 +99,7 @@ extern "C" {
     WSP_GGML_BACKEND_API int wsp_ggml_cpu_has_sme (void);
     // other
     WSP_GGML_BACKEND_API int wsp_ggml_cpu_has_riscv_v (void);
+    WSP_GGML_BACKEND_API int wsp_ggml_cpu_get_rvv_vlen (void); // risc-v vector length in bytes
    WSP_GGML_BACKEND_API int wsp_ggml_cpu_has_vsx (void);
     WSP_GGML_BACKEND_API int wsp_ggml_cpu_has_vxe (void);
     WSP_GGML_BACKEND_API int wsp_ggml_cpu_has_wasm_simd (void);
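
The new probe wsp_ggml_cpu_get_rvv_vlen() joins the existing runtime feature checks and, per the header comment, reports the RISC-V vector register length in bytes. A usage sketch (only the two probe calls come from the header; the include path and printf framing are ours):

#include <stdio.h>
#include "ggml-cpu.h"

int main(void) {
    if (wsp_ggml_cpu_has_riscv_v()) {
        const int vlen = wsp_ggml_cpu_get_rvv_vlen();  // bytes
        printf("RVV available, VLEN = %d bytes (%d bits)\n", vlen, vlen * 8);
    } else {
        printf("no runtime RVV support\n");
    }
    return 0;
}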
package/cpp/ggml-impl.h CHANGED
@@ -24,10 +24,6 @@
 #include <arm_neon.h>
 #endif
 
-#if defined(__F16C__)
-#include <immintrin.h>
-#endif
-
 #ifdef __cplusplus
 extern "C" {
 #endif