whisper.rn 0.4.0-rc.9 → 0.4.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (183)
  1. package/README.md +5 -1
  2. package/android/build.gradle +12 -3
  3. package/android/src/main/CMakeLists.txt +43 -13
  4. package/android/src/main/java/com/rnwhisper/WhisperContext.java +33 -35
  5. package/android/src/main/jni.cpp +9 -0
  6. package/android/src/main/jniLibs/arm64-v8a/librnwhisper.so +0 -0
  7. package/android/src/main/jniLibs/arm64-v8a/librnwhisper_v8fp16_va_2.so +0 -0
  8. package/android/src/main/jniLibs/armeabi-v7a/librnwhisper.so +0 -0
  9. package/android/src/main/jniLibs/armeabi-v7a/librnwhisper_vfpv4.so +0 -0
  10. package/android/src/main/jniLibs/x86_64/librnwhisper.so +0 -0
  11. package/android/src/main/jniLibs/x86_64/librnwhisper_x86_64.so +0 -0
  12. package/cpp/coreml/whisper-compat.h +10 -0
  13. package/cpp/coreml/whisper-compat.m +35 -0
  14. package/cpp/coreml/whisper-decoder-impl.h +27 -15
  15. package/cpp/coreml/whisper-decoder-impl.m +36 -10
  16. package/cpp/coreml/whisper-encoder-impl.h +21 -9
  17. package/cpp/coreml/whisper-encoder-impl.m +29 -3
  18. package/cpp/ggml-alloc.c +39 -37
  19. package/cpp/ggml-alloc.h +1 -1
  20. package/cpp/ggml-backend-impl.h +55 -27
  21. package/cpp/ggml-backend-reg.cpp +591 -0
  22. package/cpp/ggml-backend.cpp +336 -955
  23. package/cpp/ggml-backend.h +70 -42
  24. package/cpp/ggml-common.h +57 -49
  25. package/cpp/ggml-cpp.h +39 -0
  26. package/cpp/ggml-cpu/amx/amx.cpp +221 -0
  27. package/cpp/ggml-cpu/amx/amx.h +8 -0
  28. package/cpp/ggml-cpu/amx/common.h +91 -0
  29. package/cpp/ggml-cpu/amx/mmq.cpp +2511 -0
  30. package/cpp/ggml-cpu/amx/mmq.h +10 -0
  31. package/cpp/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
  32. package/cpp/ggml-cpu/arch/arm/quants.c +4113 -0
  33. package/cpp/ggml-cpu/arch/arm/repack.cpp +2162 -0
  34. package/cpp/ggml-cpu/arch/x86/cpu-feats.cpp +327 -0
  35. package/cpp/ggml-cpu/arch/x86/quants.c +4310 -0
  36. package/cpp/ggml-cpu/arch/x86/repack.cpp +3284 -0
  37. package/cpp/ggml-cpu/arch-fallback.h +184 -0
  38. package/cpp/ggml-cpu/binary-ops.cpp +158 -0
  39. package/cpp/ggml-cpu/binary-ops.h +16 -0
  40. package/cpp/ggml-cpu/common.h +72 -0
  41. package/cpp/ggml-cpu/ggml-cpu-impl.h +511 -0
  42. package/cpp/ggml-cpu/ggml-cpu.c +3473 -0
  43. package/cpp/ggml-cpu/ggml-cpu.cpp +671 -0
  44. package/cpp/ggml-cpu/ops.cpp +9085 -0
  45. package/cpp/ggml-cpu/ops.h +111 -0
  46. package/cpp/ggml-cpu/quants.c +1157 -0
  47. package/cpp/ggml-cpu/quants.h +89 -0
  48. package/cpp/ggml-cpu/repack.cpp +1570 -0
  49. package/cpp/ggml-cpu/repack.h +98 -0
  50. package/cpp/ggml-cpu/simd-mappings.h +1006 -0
  51. package/cpp/ggml-cpu/traits.cpp +36 -0
  52. package/cpp/ggml-cpu/traits.h +38 -0
  53. package/cpp/ggml-cpu/unary-ops.cpp +186 -0
  54. package/cpp/ggml-cpu/unary-ops.h +28 -0
  55. package/cpp/ggml-cpu/vec.cpp +321 -0
  56. package/cpp/ggml-cpu/vec.h +973 -0
  57. package/cpp/ggml-cpu.h +143 -0
  58. package/cpp/ggml-impl.h +417 -23
  59. package/cpp/ggml-metal-impl.h +622 -0
  60. package/cpp/ggml-metal.h +9 -9
  61. package/cpp/ggml-metal.m +3451 -1344
  62. package/cpp/ggml-opt.cpp +1037 -0
  63. package/cpp/ggml-opt.h +237 -0
  64. package/cpp/ggml-quants.c +296 -10818
  65. package/cpp/ggml-quants.h +78 -125
  66. package/cpp/ggml-threading.cpp +12 -0
  67. package/cpp/ggml-threading.h +14 -0
  68. package/cpp/ggml-whisper-sim.metallib +0 -0
  69. package/cpp/ggml-whisper.metallib +0 -0
  70. package/cpp/ggml.c +4633 -21450
  71. package/cpp/ggml.h +320 -661
  72. package/cpp/gguf.cpp +1347 -0
  73. package/cpp/gguf.h +202 -0
  74. package/cpp/rn-whisper.cpp +4 -11
  75. package/cpp/whisper-arch.h +197 -0
  76. package/cpp/whisper.cpp +2022 -495
  77. package/cpp/whisper.h +75 -18
  78. package/ios/CMakeLists.txt +95 -0
  79. package/ios/RNWhisper.h +5 -0
  80. package/ios/RNWhisperAudioUtils.m +4 -0
  81. package/ios/RNWhisperContext.h +5 -0
  82. package/ios/RNWhisperContext.mm +4 -2
  83. package/ios/rnwhisper.xcframework/Info.plist +74 -0
  84. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-alloc.h +76 -0
  85. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend-impl.h +255 -0
  86. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend.h +354 -0
  87. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-common.h +1861 -0
  88. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-cpp.h +39 -0
  89. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-cpu.h +143 -0
  90. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-impl.h +603 -0
  91. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-metal-impl.h +622 -0
  92. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-metal.h +66 -0
  93. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-opt.h +237 -0
  94. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-quants.h +100 -0
  95. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-threading.h +14 -0
  96. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml.h +2221 -0
  97. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/gguf.h +202 -0
  98. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/rn-audioutils.h +14 -0
  99. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/rn-whisper-log.h +11 -0
  100. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/rn-whisper.h +52 -0
  101. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/whisper-arch.h +197 -0
  102. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/whisper.h +739 -0
  103. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Info.plist +0 -0
  104. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
  105. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/rnwhisper +0 -0
  106. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-alloc.h +76 -0
  107. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend-impl.h +255 -0
  108. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +354 -0
  109. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-common.h +1861 -0
  110. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpp.h +39 -0
  111. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpu.h +143 -0
  112. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +603 -0
  113. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal-impl.h +622 -0
  114. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal.h +66 -0
  115. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-opt.h +237 -0
  116. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-quants.h +100 -0
  117. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-threading.h +14 -0
  118. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +2221 -0
  119. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/gguf.h +202 -0
  120. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-audioutils.h +14 -0
  121. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-whisper-log.h +11 -0
  122. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-whisper.h +52 -0
  123. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/whisper-arch.h +197 -0
  124. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/whisper.h +739 -0
  125. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
  126. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +101 -0
  127. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
  128. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
  129. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-alloc.h +76 -0
  130. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend-impl.h +255 -0
  131. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend.h +354 -0
  132. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-common.h +1861 -0
  133. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-cpp.h +39 -0
  134. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-cpu.h +143 -0
  135. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-impl.h +603 -0
  136. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-metal-impl.h +622 -0
  137. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-metal.h +66 -0
  138. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-opt.h +237 -0
  139. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-quants.h +100 -0
  140. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-threading.h +14 -0
  141. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml.h +2221 -0
  142. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/gguf.h +202 -0
  143. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/rn-audioutils.h +14 -0
  144. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/rn-whisper-log.h +11 -0
  145. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/rn-whisper.h +52 -0
  146. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/whisper-arch.h +197 -0
  147. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/whisper.h +739 -0
  148. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Info.plist +0 -0
  149. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
  150. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/rnwhisper +0 -0
  151. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-alloc.h +76 -0
  152. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend-impl.h +255 -0
  153. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +354 -0
  154. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-common.h +1861 -0
  155. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpp.h +39 -0
  156. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpu.h +143 -0
  157. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +603 -0
  158. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal-impl.h +622 -0
  159. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal.h +66 -0
  160. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-opt.h +237 -0
  161. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-quants.h +100 -0
  162. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-threading.h +14 -0
  163. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +2221 -0
  164. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/gguf.h +202 -0
  165. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-audioutils.h +14 -0
  166. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-whisper-log.h +11 -0
  167. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-whisper.h +52 -0
  168. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/whisper-arch.h +197 -0
  169. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/whisper.h +739 -0
  170. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
  171. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +101 -0
  172. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
  173. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
  174. package/jest/mock.js +5 -0
  175. package/lib/commonjs/version.json +1 -1
  176. package/lib/module/version.json +1 -1
  177. package/package.json +10 -6
  178. package/src/version.json +1 -1
  179. package/whisper-rn.podspec +11 -18
  180. package/cpp/README.md +0 -4
  181. package/cpp/ggml-aarch64.c +0 -3209
  182. package/cpp/ggml-aarch64.h +0 -39
  183. package/cpp/ggml-cpu-impl.h +0 -614
@@ -0,0 +1,184 @@
1
#pragma once

// Rename `_generic` functions if no native implementation is available.
// This effectively selects the generic implementation.
//
// Each branch below lists, for one target, the kernels that have NO
// architecture-specific implementation there. `#define foo_generic foo` makes the
// portable `foo_generic` definition (in the file named by the section comment,
// e.g. quants.c / repack.cpp) compile under the public name `foo`. Kernels not
// listed for a target are expected to be implemented natively for that target.
// The branches are tested in order, so WSP_GGML_CPU_GENERIC overrides every
// architecture check below it.

#if defined(WSP_GGML_CPU_GENERIC)
// Forced generic build: every quants.c / repack.cpp kernel uses its portable version.
// quants.c
#define wsp_quantize_row_q8_0_generic wsp_quantize_row_q8_0
#define wsp_quantize_row_q8_1_generic wsp_quantize_row_q8_1
#define wsp_quantize_row_q8_K_generic wsp_quantize_row_q8_K
#define wsp_ggml_vec_dot_q4_0_q8_0_generic wsp_ggml_vec_dot_q4_0_q8_0
#define wsp_ggml_vec_dot_q4_1_q8_1_generic wsp_ggml_vec_dot_q4_1_q8_1
#define wsp_ggml_vec_dot_q5_0_q8_0_generic wsp_ggml_vec_dot_q5_0_q8_0
#define wsp_ggml_vec_dot_q5_1_q8_1_generic wsp_ggml_vec_dot_q5_1_q8_1
#define wsp_ggml_vec_dot_q8_0_q8_0_generic wsp_ggml_vec_dot_q8_0_q8_0
#define wsp_ggml_vec_dot_tq1_0_q8_K_generic wsp_ggml_vec_dot_tq1_0_q8_K
#define wsp_ggml_vec_dot_tq2_0_q8_K_generic wsp_ggml_vec_dot_tq2_0_q8_K
#define wsp_ggml_vec_dot_q2_K_q8_K_generic wsp_ggml_vec_dot_q2_K_q8_K
#define wsp_ggml_vec_dot_q3_K_q8_K_generic wsp_ggml_vec_dot_q3_K_q8_K
#define wsp_ggml_vec_dot_q4_K_q8_K_generic wsp_ggml_vec_dot_q4_K_q8_K
#define wsp_ggml_vec_dot_q5_K_q8_K_generic wsp_ggml_vec_dot_q5_K_q8_K
#define wsp_ggml_vec_dot_q6_K_q8_K_generic wsp_ggml_vec_dot_q6_K_q8_K
#define wsp_ggml_vec_dot_iq2_xxs_q8_K_generic wsp_ggml_vec_dot_iq2_xxs_q8_K
#define wsp_ggml_vec_dot_iq2_xs_q8_K_generic wsp_ggml_vec_dot_iq2_xs_q8_K
#define wsp_ggml_vec_dot_iq2_s_q8_K_generic wsp_ggml_vec_dot_iq2_s_q8_K
#define wsp_ggml_vec_dot_iq3_xxs_q8_K_generic wsp_ggml_vec_dot_iq3_xxs_q8_K
#define wsp_ggml_vec_dot_iq3_s_q8_K_generic wsp_ggml_vec_dot_iq3_s_q8_K
#define wsp_ggml_vec_dot_iq1_s_q8_K_generic wsp_ggml_vec_dot_iq1_s_q8_K
#define wsp_ggml_vec_dot_iq1_m_q8_K_generic wsp_ggml_vec_dot_iq1_m_q8_K
#define wsp_ggml_vec_dot_iq4_nl_q8_0_generic wsp_ggml_vec_dot_iq4_nl_q8_0
#define wsp_ggml_vec_dot_iq4_xs_q8_K_generic wsp_ggml_vec_dot_iq4_xs_q8_K
// repack.cpp
#define wsp_ggml_wsp_quantize_mat_q8_0_4x4_generic wsp_ggml_wsp_quantize_mat_q8_0_4x4
#define wsp_ggml_wsp_quantize_mat_q8_0_4x8_generic wsp_ggml_wsp_quantize_mat_q8_0_4x8
#define wsp_ggml_wsp_quantize_mat_q8_K_4x8_generic wsp_ggml_wsp_quantize_mat_q8_K_4x8
#define wsp_ggml_gemv_q4_0_4x4_q8_0_generic wsp_ggml_gemv_q4_0_4x4_q8_0
#define wsp_ggml_gemv_q4_0_4x8_q8_0_generic wsp_ggml_gemv_q4_0_4x8_q8_0
#define wsp_ggml_gemv_q4_0_8x8_q8_0_generic wsp_ggml_gemv_q4_0_8x8_q8_0
#define wsp_ggml_gemv_q4_K_8x8_q8_K_generic wsp_ggml_gemv_q4_K_8x8_q8_K
#define wsp_ggml_gemv_iq4_nl_4x4_q8_0_generic wsp_ggml_gemv_iq4_nl_4x4_q8_0
#define wsp_ggml_gemm_q4_0_4x4_q8_0_generic wsp_ggml_gemm_q4_0_4x4_q8_0
#define wsp_ggml_gemm_q4_0_4x8_q8_0_generic wsp_ggml_gemm_q4_0_4x8_q8_0
#define wsp_ggml_gemm_q4_0_8x8_q8_0_generic wsp_ggml_gemm_q4_0_8x8_q8_0
#define wsp_ggml_gemm_q4_K_8x8_q8_K_generic wsp_ggml_gemm_q4_K_8x8_q8_K
#define wsp_ggml_gemm_iq4_nl_4x4_q8_0_generic wsp_ggml_gemm_iq4_nl_4x4_q8_0
#elif defined(__aarch64__) || defined(__arm__) || defined(_M_ARM) || defined(_M_ARM64)
// Arm (32/64-bit): only the q8_K / q4_K repack kernels fall back to generic.
// repack.cpp
#define wsp_ggml_wsp_quantize_mat_q8_K_4x8_generic wsp_ggml_wsp_quantize_mat_q8_K_4x8
#define wsp_ggml_gemv_q4_K_8x8_q8_K_generic wsp_ggml_gemv_q4_K_8x8_q8_K
#define wsp_ggml_gemm_q4_K_8x8_q8_K_generic wsp_ggml_gemm_q4_K_8x8_q8_K
#elif defined(__x86_64__) || defined(__i386__) || defined(_M_IX86) || defined(_M_X64)
// x86 / x86-64: only the repack kernels listed here fall back to generic.
// repack.cpp
#define wsp_ggml_wsp_quantize_mat_q8_0_4x4_generic wsp_ggml_wsp_quantize_mat_q8_0_4x4
#define wsp_ggml_gemv_q4_0_4x4_q8_0_generic wsp_ggml_gemv_q4_0_4x4_q8_0
#define wsp_ggml_gemv_q4_0_4x8_q8_0_generic wsp_ggml_gemv_q4_0_4x8_q8_0
#define wsp_ggml_gemv_iq4_nl_4x4_q8_0_generic wsp_ggml_gemv_iq4_nl_4x4_q8_0
#define wsp_ggml_gemm_q4_0_4x4_q8_0_generic wsp_ggml_gemm_q4_0_4x4_q8_0
#define wsp_ggml_gemm_q4_0_4x8_q8_0_generic wsp_ggml_gemm_q4_0_4x8_q8_0
#define wsp_ggml_gemm_iq4_nl_4x4_q8_0_generic wsp_ggml_gemm_iq4_nl_4x4_q8_0
#elif defined(__POWERPC__) || defined(__powerpc__)
// ref: https://github.com/ggml-org/llama.cpp/pull/14146#issuecomment-2972561679
// quants.c
#define wsp_quantize_row_q8_K_generic wsp_quantize_row_q8_K
#define wsp_ggml_vec_dot_tq1_0_q8_K_generic wsp_ggml_vec_dot_tq1_0_q8_K
#define wsp_ggml_vec_dot_tq2_0_q8_K_generic wsp_ggml_vec_dot_tq2_0_q8_K
#define wsp_ggml_vec_dot_iq1_m_q8_K_generic wsp_ggml_vec_dot_iq1_m_q8_K
// repack.cpp
#define wsp_ggml_wsp_quantize_mat_q8_0_4x4_generic wsp_ggml_wsp_quantize_mat_q8_0_4x4
#define wsp_ggml_wsp_quantize_mat_q8_0_4x8_generic wsp_ggml_wsp_quantize_mat_q8_0_4x8
#define wsp_ggml_wsp_quantize_mat_q8_K_4x8_generic wsp_ggml_wsp_quantize_mat_q8_K_4x8
#define wsp_ggml_gemv_q4_0_4x4_q8_0_generic wsp_ggml_gemv_q4_0_4x4_q8_0
#define wsp_ggml_gemv_q4_0_4x8_q8_0_generic wsp_ggml_gemv_q4_0_4x8_q8_0
#define wsp_ggml_gemv_q4_0_8x8_q8_0_generic wsp_ggml_gemv_q4_0_8x8_q8_0
#define wsp_ggml_gemv_q4_K_8x8_q8_K_generic wsp_ggml_gemv_q4_K_8x8_q8_K
#define wsp_ggml_gemv_iq4_nl_4x4_q8_0_generic wsp_ggml_gemv_iq4_nl_4x4_q8_0
#define wsp_ggml_gemm_q4_0_4x4_q8_0_generic wsp_ggml_gemm_q4_0_4x4_q8_0
#define wsp_ggml_gemm_q4_0_4x8_q8_0_generic wsp_ggml_gemm_q4_0_4x8_q8_0
#define wsp_ggml_gemm_q4_0_8x8_q8_0_generic wsp_ggml_gemm_q4_0_8x8_q8_0
#define wsp_ggml_gemm_q4_K_8x8_q8_K_generic wsp_ggml_gemm_q4_K_8x8_q8_K
#define wsp_ggml_gemm_iq4_nl_4x4_q8_0_generic wsp_ggml_gemm_iq4_nl_4x4_q8_0
#elif defined(__loongarch64)
// LoongArch64: same fallback set as PowerPC.
// quants.c
#define wsp_quantize_row_q8_K_generic wsp_quantize_row_q8_K
#define wsp_ggml_vec_dot_tq1_0_q8_K_generic wsp_ggml_vec_dot_tq1_0_q8_K
#define wsp_ggml_vec_dot_tq2_0_q8_K_generic wsp_ggml_vec_dot_tq2_0_q8_K
#define wsp_ggml_vec_dot_iq1_m_q8_K_generic wsp_ggml_vec_dot_iq1_m_q8_K
// repack.cpp
#define wsp_ggml_wsp_quantize_mat_q8_0_4x4_generic wsp_ggml_wsp_quantize_mat_q8_0_4x4
#define wsp_ggml_wsp_quantize_mat_q8_0_4x8_generic wsp_ggml_wsp_quantize_mat_q8_0_4x8
#define wsp_ggml_wsp_quantize_mat_q8_K_4x8_generic wsp_ggml_wsp_quantize_mat_q8_K_4x8
#define wsp_ggml_gemv_q4_0_4x4_q8_0_generic wsp_ggml_gemv_q4_0_4x4_q8_0
#define wsp_ggml_gemv_q4_0_4x8_q8_0_generic wsp_ggml_gemv_q4_0_4x8_q8_0
#define wsp_ggml_gemv_q4_0_8x8_q8_0_generic wsp_ggml_gemv_q4_0_8x8_q8_0
#define wsp_ggml_gemv_q4_K_8x8_q8_K_generic wsp_ggml_gemv_q4_K_8x8_q8_K
#define wsp_ggml_gemv_iq4_nl_4x4_q8_0_generic wsp_ggml_gemv_iq4_nl_4x4_q8_0
#define wsp_ggml_gemm_q4_0_4x4_q8_0_generic wsp_ggml_gemm_q4_0_4x4_q8_0
#define wsp_ggml_gemm_q4_0_4x8_q8_0_generic wsp_ggml_gemm_q4_0_4x8_q8_0
#define wsp_ggml_gemm_q4_0_8x8_q8_0_generic wsp_ggml_gemm_q4_0_8x8_q8_0
#define wsp_ggml_gemm_q4_K_8x8_q8_K_generic wsp_ggml_gemm_q4_K_8x8_q8_K
#define wsp_ggml_gemm_iq4_nl_4x4_q8_0_generic wsp_ggml_gemm_iq4_nl_4x4_q8_0
#elif defined(__riscv)
// RISC-V: note the q4_0 8x8 gemv/gemm kernels are NOT listed, i.e. not renamed here.
// quants.c
#define wsp_quantize_row_q8_K_generic wsp_quantize_row_q8_K
#define wsp_ggml_vec_dot_tq1_0_q8_K_generic wsp_ggml_vec_dot_tq1_0_q8_K
#define wsp_ggml_vec_dot_tq2_0_q8_K_generic wsp_ggml_vec_dot_tq2_0_q8_K
#define wsp_ggml_vec_dot_iq2_xxs_q8_K_generic wsp_ggml_vec_dot_iq2_xxs_q8_K
#define wsp_ggml_vec_dot_iq2_xs_q8_K_generic wsp_ggml_vec_dot_iq2_xs_q8_K
#define wsp_ggml_vec_dot_iq2_s_q8_K_generic wsp_ggml_vec_dot_iq2_s_q8_K
#define wsp_ggml_vec_dot_iq3_xxs_q8_K_generic wsp_ggml_vec_dot_iq3_xxs_q8_K
#define wsp_ggml_vec_dot_iq3_s_q8_K_generic wsp_ggml_vec_dot_iq3_s_q8_K
#define wsp_ggml_vec_dot_iq1_s_q8_K_generic wsp_ggml_vec_dot_iq1_s_q8_K
#define wsp_ggml_vec_dot_iq1_m_q8_K_generic wsp_ggml_vec_dot_iq1_m_q8_K
#define wsp_ggml_vec_dot_iq4_nl_q8_0_generic wsp_ggml_vec_dot_iq4_nl_q8_0
#define wsp_ggml_vec_dot_iq4_xs_q8_K_generic wsp_ggml_vec_dot_iq4_xs_q8_K
// repack.cpp
#define wsp_ggml_wsp_quantize_mat_q8_0_4x4_generic wsp_ggml_wsp_quantize_mat_q8_0_4x4
#define wsp_ggml_wsp_quantize_mat_q8_0_4x8_generic wsp_ggml_wsp_quantize_mat_q8_0_4x8
#define wsp_ggml_wsp_quantize_mat_q8_K_4x8_generic wsp_ggml_wsp_quantize_mat_q8_K_4x8
#define wsp_ggml_gemv_q4_0_4x4_q8_0_generic wsp_ggml_gemv_q4_0_4x4_q8_0
#define wsp_ggml_gemv_q4_0_4x8_q8_0_generic wsp_ggml_gemv_q4_0_4x8_q8_0
#define wsp_ggml_gemv_q4_K_8x8_q8_K_generic wsp_ggml_gemv_q4_K_8x8_q8_K
#define wsp_ggml_gemv_iq4_nl_4x4_q8_0_generic wsp_ggml_gemv_iq4_nl_4x4_q8_0
#define wsp_ggml_gemm_q4_0_4x4_q8_0_generic wsp_ggml_gemm_q4_0_4x4_q8_0
#define wsp_ggml_gemm_q4_0_4x8_q8_0_generic wsp_ggml_gemm_q4_0_4x8_q8_0
#define wsp_ggml_gemm_q4_K_8x8_q8_K_generic wsp_ggml_gemm_q4_K_8x8_q8_K
#define wsp_ggml_gemm_iq4_nl_4x4_q8_0_generic wsp_ggml_gemm_iq4_nl_4x4_q8_0
#elif defined(__s390x__)
// IBM Z (s390x)
// quants.c
#define wsp_quantize_row_q8_K_generic wsp_quantize_row_q8_K
#define wsp_ggml_vec_dot_q5_0_q8_0_generic wsp_ggml_vec_dot_q5_0_q8_0
#define wsp_ggml_vec_dot_q5_1_q8_1_generic wsp_ggml_vec_dot_q5_1_q8_1
#define wsp_ggml_vec_dot_tq1_0_q8_K_generic wsp_ggml_vec_dot_tq1_0_q8_K
#define wsp_ggml_vec_dot_tq2_0_q8_K_generic wsp_ggml_vec_dot_tq2_0_q8_K
#define wsp_ggml_vec_dot_q2_K_q8_K_generic wsp_ggml_vec_dot_q2_K_q8_K
#define wsp_ggml_vec_dot_iq2_xxs_q8_K_generic wsp_ggml_vec_dot_iq2_xxs_q8_K
#define wsp_ggml_vec_dot_iq2_xs_q8_K_generic wsp_ggml_vec_dot_iq2_xs_q8_K
#define wsp_ggml_vec_dot_iq2_s_q8_K_generic wsp_ggml_vec_dot_iq2_s_q8_K
#define wsp_ggml_vec_dot_iq3_xxs_q8_K_generic wsp_ggml_vec_dot_iq3_xxs_q8_K
#define wsp_ggml_vec_dot_iq3_s_q8_K_generic wsp_ggml_vec_dot_iq3_s_q8_K
#define wsp_ggml_vec_dot_iq1_s_q8_K_generic wsp_ggml_vec_dot_iq1_s_q8_K
#define wsp_ggml_vec_dot_iq1_m_q8_K_generic wsp_ggml_vec_dot_iq1_m_q8_K
// repack.cpp
#define wsp_ggml_wsp_quantize_mat_q8_0_4x4_generic wsp_ggml_wsp_quantize_mat_q8_0_4x4
#define wsp_ggml_wsp_quantize_mat_q8_0_4x8_generic wsp_ggml_wsp_quantize_mat_q8_0_4x8
#define wsp_ggml_wsp_quantize_mat_q8_K_4x8_generic wsp_ggml_wsp_quantize_mat_q8_K_4x8
#define wsp_ggml_gemv_q4_0_4x4_q8_0_generic wsp_ggml_gemv_q4_0_4x4_q8_0
#define wsp_ggml_gemv_q4_0_4x8_q8_0_generic wsp_ggml_gemv_q4_0_4x8_q8_0
#define wsp_ggml_gemv_q4_0_8x8_q8_0_generic wsp_ggml_gemv_q4_0_8x8_q8_0
#define wsp_ggml_gemv_q4_K_8x8_q8_K_generic wsp_ggml_gemv_q4_K_8x8_q8_K
#define wsp_ggml_gemv_iq4_nl_4x4_q8_0_generic wsp_ggml_gemv_iq4_nl_4x4_q8_0
#define wsp_ggml_gemm_q4_0_4x4_q8_0_generic wsp_ggml_gemm_q4_0_4x4_q8_0
#define wsp_ggml_gemm_q4_0_4x8_q8_0_generic wsp_ggml_gemm_q4_0_4x8_q8_0
#define wsp_ggml_gemm_q4_0_8x8_q8_0_generic wsp_ggml_gemm_q4_0_8x8_q8_0
#define wsp_ggml_gemm_q4_K_8x8_q8_K_generic wsp_ggml_gemm_q4_K_8x8_q8_K
#define wsp_ggml_gemm_iq4_nl_4x4_q8_0_generic wsp_ggml_gemm_iq4_nl_4x4_q8_0
#elif defined(__wasm__)
// WebAssembly
// quants.c
#define wsp_ggml_vec_dot_q4_1_q8_1_generic wsp_ggml_vec_dot_q4_1_q8_1
#define wsp_ggml_vec_dot_tq1_0_q8_K_generic wsp_ggml_vec_dot_tq1_0_q8_K
#define wsp_ggml_vec_dot_tq2_0_q8_K_generic wsp_ggml_vec_dot_tq2_0_q8_K
#define wsp_ggml_vec_dot_iq2_xxs_q8_K_generic wsp_ggml_vec_dot_iq2_xxs_q8_K
#define wsp_ggml_vec_dot_iq2_xs_q8_K_generic wsp_ggml_vec_dot_iq2_xs_q8_K
#define wsp_ggml_vec_dot_iq2_s_q8_K_generic wsp_ggml_vec_dot_iq2_s_q8_K
#define wsp_ggml_vec_dot_iq3_xxs_q8_K_generic wsp_ggml_vec_dot_iq3_xxs_q8_K
#define wsp_ggml_vec_dot_iq3_s_q8_K_generic wsp_ggml_vec_dot_iq3_s_q8_K
#define wsp_ggml_vec_dot_iq1_s_q8_K_generic wsp_ggml_vec_dot_iq1_s_q8_K
#define wsp_ggml_vec_dot_iq1_m_q8_K_generic wsp_ggml_vec_dot_iq1_m_q8_K
#define wsp_ggml_vec_dot_iq4_nl_q8_0_generic wsp_ggml_vec_dot_iq4_nl_q8_0
#define wsp_ggml_vec_dot_iq4_xs_q8_K_generic wsp_ggml_vec_dot_iq4_xs_q8_K
// repack.cpp
#define wsp_ggml_wsp_quantize_mat_q8_0_4x4_generic wsp_ggml_wsp_quantize_mat_q8_0_4x4
#define wsp_ggml_wsp_quantize_mat_q8_0_4x8_generic wsp_ggml_wsp_quantize_mat_q8_0_4x8
#define wsp_ggml_wsp_quantize_mat_q8_K_4x8_generic wsp_ggml_wsp_quantize_mat_q8_K_4x8
#define wsp_ggml_gemv_q4_0_4x4_q8_0_generic wsp_ggml_gemv_q4_0_4x4_q8_0
#define wsp_ggml_gemv_q4_0_4x8_q8_0_generic wsp_ggml_gemv_q4_0_4x8_q8_0
#define wsp_ggml_gemv_q4_0_8x8_q8_0_generic wsp_ggml_gemv_q4_0_8x8_q8_0
#define wsp_ggml_gemv_q4_K_8x8_q8_K_generic wsp_ggml_gemv_q4_K_8x8_q8_K
#define wsp_ggml_gemv_iq4_nl_4x4_q8_0_generic wsp_ggml_gemv_iq4_nl_4x4_q8_0
#define wsp_ggml_gemm_q4_0_4x4_q8_0_generic wsp_ggml_gemm_q4_0_4x4_q8_0
#define wsp_ggml_gemm_q4_0_4x8_q8_0_generic wsp_ggml_gemm_q4_0_4x8_q8_0
#define wsp_ggml_gemm_q4_0_8x8_q8_0_generic wsp_ggml_gemm_q4_0_8x8_q8_0
#define wsp_ggml_gemm_q4_K_8x8_q8_K_generic wsp_ggml_gemm_q4_K_8x8_q8_K
#define wsp_ggml_gemm_iq4_nl_4x4_q8_0_generic wsp_ggml_gemm_iq4_nl_4x4_q8_0
#endif
@@ -0,0 +1,158 @@
1
+ #include "binary-ops.h"
2
+
3
#if defined(WSP_GGML_USE_ACCELERATE)
#include <Accelerate/Accelerate.h>

// Function-pointer type matching the vDSP binary routines that apply_binary_op
// selects (vDSP_vadd / vDSP_vsub / vDSP_vmul / vDSP_vdiv). The parameters are:
// two strided float inputs, one strided float output, and an element count.
using vDSP_fn_t = void (*)(const float *, vDSP_Stride, const float *, vDSP_Stride, float *, vDSP_Stride, vDSP_Length);
#endif
8
+
9
// Scalar kernel for elementwise addition: returns lhs + rhs.
static inline float op_add(float lhs, float rhs) { return lhs + rhs; }
12
+
13
// Scalar kernel for elementwise subtraction: returns lhs - rhs (order matters).
static inline float op_sub(float lhs, float rhs) { return lhs - rhs; }
16
+
17
// Scalar kernel for elementwise multiplication: returns lhs * rhs.
static inline float op_mul(float lhs, float rhs) { return lhs * rhs; }
20
+
21
// Scalar kernel for elementwise division: returns lhs / rhs (IEEE semantics, no zero check).
static inline float op_div(float lhs, float rhs) { return lhs / rhs; }
24
+
25
+ template <float (*op)(float, float), typename src0_t, typename src1_t, typename dst_t>
26
+ static inline void vec_binary_op_contiguous(const int64_t n, dst_t * z, const src0_t * x, const src1_t * y) {
27
+ constexpr auto src0_to_f32 = type_conversion_table<src0_t>::to_f32;
28
+ constexpr auto src1_to_f32 = type_conversion_table<src1_t>::to_f32;
29
+ constexpr auto f32_to_dst = type_conversion_table<dst_t >::from_f32;
30
+
31
+ for (int i = 0; i < n; i++) {
32
+ z[i] = f32_to_dst(op(src0_to_f32(x[i]), src1_to_f32(y[i])));
33
+ }
34
+ }
35
+
36
+ template <float (*op)(float, float), typename src0_t, typename src1_t, typename dst_t>
37
+ static inline void vec_binary_op_non_contiguous(const int64_t n, const int64_t ne10, const int64_t nb10, dst_t * z, const src0_t * x, const src1_t * y) {
38
+ constexpr auto src0_to_f32 = type_conversion_table<src0_t>::to_f32;
39
+ constexpr auto src1_to_f32 = type_conversion_table<src1_t>::to_f32;
40
+ constexpr auto f32_to_dst = type_conversion_table<dst_t >::from_f32;
41
+
42
+ for (int i = 0; i < n; i++) {
43
+ int i10 = i % ne10;
44
+ const src1_t * y_ptr = (const src1_t *)((const char *)y + i10*nb10);
45
+ z[i] = f32_to_dst(op(src0_to_f32(x[i]), src1_to_f32(*y_ptr)));
46
+ }
47
+ }
48
+
49
// Row-parallel driver for an elementwise binary op: dst = op(src0, src1), where
// src1 may be repeated (broadcast) over src0 as permitted by wsp_ggml_can_repeat.
//
// Template params:
//   op                    - scalar float kernel (op_add / op_sub / op_mul / op_div).
//   src0_t, src1_t, dst_t - element storage types, converted via type_conversion_table.
//
// params - compute params; get_thread_range(params, src0) yields the [ir0, ir1)
//          row slice this call processes.
// dst    - output tensor; its operands are dst->src[0] (src0) and dst->src[1] (src1).
template <float (*op)(float, float), typename src0_t, typename src1_t, typename dst_t>
static void apply_binary_op(const wsp_ggml_compute_params * params, wsp_ggml_tensor * dst) {
    const wsp_ggml_tensor * src0 = dst->src[0];
    const wsp_ggml_tensor * src1 = dst->src[1];

    WSP_GGML_ASSERT(wsp_ggml_can_repeat(src1, src0) && wsp_ggml_are_same_shape(src0, dst));

    // NOTE(review): macro defined elsewhere; judging by the uses below it declares the
    // ne0*/nb0* (src0), ne1*/nb1* (src1) and ne*/nb* (dst) extent/stride locals.
    WSP_GGML_TENSOR_BINARY_OP_LOCALS

    // dst and src0 must be densely packed along dim 0 (element stride == element size).
    WSP_GGML_ASSERT( nb0 == sizeof(dst_t));
    WSP_GGML_ASSERT(nb00 == sizeof(src0_t));

    const auto [ir0, ir1] = get_thread_range(params, src0);
    const bool is_src1_contiguous = (nb10 == sizeof(src1_t));

    if (!is_src1_contiguous) { // broadcast not implemented yet for non-contiguous
        WSP_GGML_ASSERT(wsp_ggml_are_same_shape(src0, src1));
    }

#ifdef WSP_GGML_USE_ACCELERATE
    // Select a vDSP routine only when all three tensors are f32 and `op` is one of
    // the four known kernels; otherwise vDSP_op stays nullptr and rows take the
    // generic path below.
    vDSP_fn_t vDSP_op = nullptr;
    // TODO - avoid the f32-only check using type 'trait' lookup tables and row-based src-to-float conversion functions
    if (src0->type == WSP_GGML_TYPE_F32 && src1->type == WSP_GGML_TYPE_F32 && dst->type == WSP_GGML_TYPE_F32) {
        if (op == op_add) {
            vDSP_op = vDSP_vadd;
        } else if (op == op_sub) {
            vDSP_op = vDSP_vsub;
        } else if (op == op_mul) {
            vDSP_op = vDSP_vmul;
        } else if (op == op_div) {
            vDSP_op = vDSP_vdiv;
        }
    }
#endif

    for (int64_t ir = ir0; ir < ir1; ++ir) {
        // Split the flat row index into (i01, i02, i03) over src0's dims 1..3.
        const int64_t i03 = ir/(ne02*ne01);
        const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
        const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);

        // Matching src1 row: each index wraps so a smaller src1 is repeated.
        const int64_t i13 = i03 % ne13;
        const int64_t i12 = i02 % ne12;
        const int64_t i11 = i01 % ne11;

        dst_t * dst_ptr = (dst_t *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
        const src0_t * src0_ptr = (const src0_t *) ((const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
        const src1_t * src1_ptr = (const src1_t *) ((const char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);

        if (is_src1_contiguous) {
            // src1 is broadcastable across src0 and dst in i1, i2, i3
            // Along dim 0 the src0 row is processed as nr0 chunks of ne10 elements,
            // each combined with the same src1 row (assumes ne00 is a multiple of
            // ne10, presumably guaranteed by wsp_ggml_can_repeat — confirm).
            const int64_t nr0 = ne00 / ne10;

            for (int64_t r = 0; r < nr0; ++r) {
#ifdef WSP_GGML_USE_ACCELERATE
                if constexpr (std::is_same_v<src0_t, float> && std::is_same_v<src1_t, float> && std::is_same_v<dst_t, float>) {
                    if (vDSP_op != nullptr) {
                        // NOTE(review): src1 is passed first because vDSP_vsub/vDSP_vdiv take
                        // the subtrahend/divisor as their first operand (per Apple's vDSP
                        // docs), yielding src0 - src1 and src0 / src1 — confirm.
                        vDSP_op(src1_ptr, 1, src0_ptr + r*ne10, 1, dst_ptr + r*ne10, 1, ne10);
                        continue;
                    }
                }
#endif
                vec_binary_op_contiguous<op>(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr);
            }
        } else {
            // Strided src1 with the same shape as src0 (asserted above): walk the
            // full ne0-element row, stepping src1 by nb10 bytes.
            vec_binary_op_non_contiguous<op>(ne0, ne10, nb10, dst_ptr, src0_ptr, src1_ptr);
        }
    }
}
117
+
118
+ // TODO: Use the 'traits' lookup table (for type conversion fns), instead of a mass of 'if' conditions with long templates
119
+ template <float (*op)(float, float)>
120
+ static void binary_op(const wsp_ggml_compute_params * params, wsp_ggml_tensor * dst) {
121
+ const wsp_ggml_tensor * src0 = dst->src[0];
122
+ const wsp_ggml_tensor * src1 = dst->src[1];
123
+
124
+ /* */ if (src0->type == WSP_GGML_TYPE_F32 && src1->type == WSP_GGML_TYPE_F32 && dst->type == WSP_GGML_TYPE_F32) { // all f32
125
+ apply_binary_op<op, float, float, float>(params, dst);
126
+ } else if (src0->type == WSP_GGML_TYPE_F16 && src1->type == WSP_GGML_TYPE_F16 && dst->type == WSP_GGML_TYPE_F16) { // all f16
127
+ apply_binary_op<op, wsp_ggml_fp16_t, wsp_ggml_fp16_t, wsp_ggml_fp16_t>(params, dst);
128
+ } else if (src0->type == WSP_GGML_TYPE_BF16 && src1->type == WSP_GGML_TYPE_BF16 && dst->type == WSP_GGML_TYPE_BF16) { // all bf16
129
+ apply_binary_op<op, wsp_ggml_bf16_t, wsp_ggml_bf16_t, wsp_ggml_bf16_t>(params, dst);
130
+ } else if (src0->type == WSP_GGML_TYPE_BF16 && src1->type == WSP_GGML_TYPE_F32 && dst->type == WSP_GGML_TYPE_BF16) {
131
+ apply_binary_op<op, wsp_ggml_bf16_t, float, wsp_ggml_bf16_t>(params, dst);
132
+ } else if (src0->type == WSP_GGML_TYPE_BF16 && src1->type == WSP_GGML_TYPE_F32 && dst->type == WSP_GGML_TYPE_F32) {
133
+ apply_binary_op<op, wsp_ggml_bf16_t, float, float>(params, dst);
134
+ } else if (src0->type == WSP_GGML_TYPE_F16 && src1->type == WSP_GGML_TYPE_F32 && dst->type == WSP_GGML_TYPE_F16) {
135
+ apply_binary_op<op, wsp_ggml_fp16_t, float, wsp_ggml_fp16_t>(params, dst);
136
+ } else if (src0->type == WSP_GGML_TYPE_F16 && src1->type == WSP_GGML_TYPE_F32 && dst->type == WSP_GGML_TYPE_F32) {
137
+ apply_binary_op<op, wsp_ggml_fp16_t, float, float>(params, dst);
138
+ } else {
139
+ WSP_GGML_ABORT("%s: unsupported types: dst: %s, src0: %s, src1: %s\n", __func__,
140
+ wsp_ggml_type_name(dst->type), wsp_ggml_type_name(src0->type), wsp_ggml_type_name(src1->type));
141
+ }
142
+ }
143
+
144
+ void wsp_ggml_compute_forward_add_non_quantized(const wsp_ggml_compute_params * params, wsp_ggml_tensor * dst) {
145
+ binary_op<op_add>(params, dst);
146
+ }
147
+
148
+ void wsp_ggml_compute_forward_sub(const wsp_ggml_compute_params * params, wsp_ggml_tensor * dst) {
149
+ binary_op<op_sub>(params, dst);
150
+ }
151
+
152
+ void wsp_ggml_compute_forward_mul(const wsp_ggml_compute_params * params, wsp_ggml_tensor * dst) {
153
+ binary_op<op_mul>(params, dst);
154
+ }
155
+
156
+ void wsp_ggml_compute_forward_div(const wsp_ggml_compute_params * params, wsp_ggml_tensor * dst) {
157
+ binary_op<op_div>(params, dst);
158
+ }
@@ -0,0 +1,16 @@
1
+ #pragma once
2
+
3
+ #include "common.h"
4
+
5
#ifdef __cplusplus
extern "C" {
#endif

// CPU compute entry points for the element-wise binary operators.
//
// Declared with C linkage (via the extern "C" guard) so they can be called
// from C translation units as well as C++ ones. Each function takes the
// per-thread compute parameters and the destination tensor node. The
// definitions forward to the templated binary_op dispatcher.

void wsp_ggml_compute_forward_add_non_quantized(
        const struct wsp_ggml_compute_params * params,
        struct wsp_ggml_tensor               * dst);

void wsp_ggml_compute_forward_sub(
        const struct wsp_ggml_compute_params * params,
        struct wsp_ggml_tensor               * dst);

void wsp_ggml_compute_forward_mul(
        const struct wsp_ggml_compute_params * params,
        struct wsp_ggml_tensor               * dst);

void wsp_ggml_compute_forward_div(
        const struct wsp_ggml_compute_params * params,
        struct wsp_ggml_tensor               * dst);

#ifdef __cplusplus
}
#endif
@@ -0,0 +1,72 @@
1
+ #pragma once
2
+
3
+ #include "ggml.h"
4
+ #include "traits.h"
5
+ #include "ggml-cpu-impl.h"
6
+ #include "ggml-impl.h"
7
+
8
+ #ifdef __cplusplus
9
+
10
+ #include <utility>
11
+
12
+ // convenience functions/macros for use in template calls
13
+ // note: these won't be required after the 'traits' lookup table is used.
14
+ static inline wsp_ggml_fp16_t f32_to_f16(float x) {
15
+ return WSP_GGML_FP32_TO_FP16(x);
16
+ }
17
+
18
+ static inline float f16_to_f32(wsp_ggml_fp16_t x) {
19
+ return WSP_GGML_FP16_TO_FP32(x);
20
+ }
21
+
22
+ static inline wsp_ggml_bf16_t f32_to_bf16(float x) {
23
+ return WSP_GGML_FP32_TO_BF16(x);
24
+ }
25
+
26
+ static inline float bf16_to_f32(wsp_ggml_bf16_t x) {
27
+ return WSP_GGML_BF16_TO_FP32(x);
28
+ }
29
+
30
// Identity conversion. It gives `float` the same to/from-f32 converter shape
// as the 16-bit types, so type_conversion_table<float> can use it for both
// directions.
static inline float f32_to_f32(float x) {
    const float unchanged = x;
    return unchanged;
}
33
+
34
+ // TODO - merge this into the traits table, after using row-based conversions
35
+ template <class T>
36
+ struct type_conversion_table;
37
+
38
+ template <>
39
+ struct type_conversion_table<wsp_ggml_fp16_t> {
40
+ static constexpr float (*to_f32)(wsp_ggml_fp16_t) = f16_to_f32;
41
+ static constexpr wsp_ggml_fp16_t (*from_f32)(float) = f32_to_f16;
42
+ };
43
+
44
+ template <>
45
+ struct type_conversion_table<float> {
46
+ static constexpr float (*to_f32)(float) = f32_to_f32;
47
+ static constexpr float (*from_f32)(float) = f32_to_f32;
48
+ };
49
+
50
+ template <>
51
+ struct type_conversion_table<wsp_ggml_bf16_t> {
52
+ static constexpr float (*to_f32)(wsp_ggml_bf16_t) = bf16_to_f32;
53
+ static constexpr wsp_ggml_bf16_t (*from_f32)(float) = f32_to_bf16;
54
+ };
55
+
56
+ static std::pair<int64_t, int64_t> get_thread_range(const struct wsp_ggml_compute_params * params, const struct wsp_ggml_tensor * src0) {
57
+ const int64_t ith = params->ith;
58
+ const int64_t nth = params->nth;
59
+
60
+ const int64_t nr = wsp_ggml_nrows(src0);
61
+
62
+ // rows per thread
63
+ const int64_t dr = (nr + nth - 1)/nth;
64
+
65
+ // row range for this thread
66
+ const int64_t ir0 = dr*ith;
67
+ const int64_t ir1 = MIN(ir0 + dr, nr);
68
+
69
+ return {ir0, ir1};
70
+ }
71
+
72
+ #endif