whisper.rn 0.4.0-rc.8 → 0.4.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (201)
  1. package/README.md +5 -1
  2. package/android/build.gradle +12 -3
  3. package/android/src/main/CMakeLists.txt +44 -13
  4. package/android/src/main/java/com/rnwhisper/AudioUtils.java +27 -12
  5. package/android/src/main/java/com/rnwhisper/RNWhisper.java +75 -34
  6. package/android/src/main/java/com/rnwhisper/WhisperContext.java +53 -38
  7. package/android/src/main/jni.cpp +38 -1
  8. package/android/src/main/jniLibs/arm64-v8a/librnwhisper.so +0 -0
  9. package/android/src/main/jniLibs/arm64-v8a/librnwhisper_v8fp16_va_2.so +0 -0
  10. package/android/src/main/jniLibs/armeabi-v7a/librnwhisper.so +0 -0
  11. package/android/src/main/jniLibs/armeabi-v7a/librnwhisper_vfpv4.so +0 -0
  12. package/android/src/main/jniLibs/x86_64/librnwhisper.so +0 -0
  13. package/android/src/main/jniLibs/x86_64/librnwhisper_x86_64.so +0 -0
  14. package/android/src/newarch/java/com/rnwhisper/RNWhisperModule.java +10 -0
  15. package/android/src/oldarch/java/com/rnwhisper/RNWhisperModule.java +10 -0
  16. package/cpp/coreml/whisper-compat.h +10 -0
  17. package/cpp/coreml/whisper-compat.m +35 -0
  18. package/cpp/coreml/whisper-decoder-impl.h +27 -15
  19. package/cpp/coreml/whisper-decoder-impl.m +36 -10
  20. package/cpp/coreml/whisper-encoder-impl.h +21 -9
  21. package/cpp/coreml/whisper-encoder-impl.m +29 -3
  22. package/cpp/ggml-alloc.c +727 -517
  23. package/cpp/ggml-alloc.h +47 -65
  24. package/cpp/ggml-backend-impl.h +196 -57
  25. package/cpp/ggml-backend-reg.cpp +591 -0
  26. package/cpp/ggml-backend.cpp +2016 -0
  27. package/cpp/ggml-backend.h +234 -89
  28. package/cpp/ggml-common.h +1861 -0
  29. package/cpp/ggml-cpp.h +39 -0
  30. package/cpp/ggml-cpu/amx/amx.cpp +221 -0
  31. package/cpp/ggml-cpu/amx/amx.h +8 -0
  32. package/cpp/ggml-cpu/amx/common.h +91 -0
  33. package/cpp/ggml-cpu/amx/mmq.cpp +2511 -0
  34. package/cpp/ggml-cpu/amx/mmq.h +10 -0
  35. package/cpp/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
  36. package/cpp/ggml-cpu/arch/arm/quants.c +4113 -0
  37. package/cpp/ggml-cpu/arch/arm/repack.cpp +2162 -0
  38. package/cpp/ggml-cpu/arch/x86/cpu-feats.cpp +327 -0
  39. package/cpp/ggml-cpu/arch/x86/quants.c +4310 -0
  40. package/cpp/ggml-cpu/arch/x86/repack.cpp +3284 -0
  41. package/cpp/ggml-cpu/arch-fallback.h +184 -0
  42. package/cpp/ggml-cpu/binary-ops.cpp +158 -0
  43. package/cpp/ggml-cpu/binary-ops.h +16 -0
  44. package/cpp/ggml-cpu/common.h +72 -0
  45. package/cpp/ggml-cpu/ggml-cpu-impl.h +511 -0
  46. package/cpp/ggml-cpu/ggml-cpu.c +3473 -0
  47. package/cpp/ggml-cpu/ggml-cpu.cpp +671 -0
  48. package/cpp/ggml-cpu/ops.cpp +9085 -0
  49. package/cpp/ggml-cpu/ops.h +111 -0
  50. package/cpp/ggml-cpu/quants.c +1157 -0
  51. package/cpp/ggml-cpu/quants.h +89 -0
  52. package/cpp/ggml-cpu/repack.cpp +1570 -0
  53. package/cpp/ggml-cpu/repack.h +98 -0
  54. package/cpp/ggml-cpu/simd-mappings.h +1006 -0
  55. package/cpp/ggml-cpu/traits.cpp +36 -0
  56. package/cpp/ggml-cpu/traits.h +38 -0
  57. package/cpp/ggml-cpu/unary-ops.cpp +186 -0
  58. package/cpp/ggml-cpu/unary-ops.h +28 -0
  59. package/cpp/ggml-cpu/vec.cpp +321 -0
  60. package/cpp/ggml-cpu/vec.h +973 -0
  61. package/cpp/ggml-cpu.h +143 -0
  62. package/cpp/ggml-impl.h +525 -168
  63. package/cpp/ggml-metal-impl.h +622 -0
  64. package/cpp/ggml-metal.h +16 -14
  65. package/cpp/ggml-metal.m +5289 -1859
  66. package/cpp/ggml-opt.cpp +1037 -0
  67. package/cpp/ggml-opt.h +237 -0
  68. package/cpp/ggml-quants.c +2916 -6877
  69. package/cpp/ggml-quants.h +87 -249
  70. package/cpp/ggml-threading.cpp +12 -0
  71. package/cpp/ggml-threading.h +14 -0
  72. package/cpp/ggml-whisper-sim.metallib +0 -0
  73. package/cpp/ggml-whisper.metallib +0 -0
  74. package/cpp/ggml.c +3293 -16770
  75. package/cpp/ggml.h +778 -835
  76. package/cpp/gguf.cpp +1347 -0
  77. package/cpp/gguf.h +202 -0
  78. package/cpp/rn-whisper.cpp +84 -0
  79. package/cpp/rn-whisper.h +2 -0
  80. package/cpp/whisper-arch.h +197 -0
  81. package/cpp/whisper.cpp +3240 -944
  82. package/cpp/whisper.h +144 -31
  83. package/ios/CMakeLists.txt +95 -0
  84. package/ios/RNWhisper.h +5 -0
  85. package/ios/RNWhisper.mm +124 -37
  86. package/ios/RNWhisperAudioUtils.h +1 -0
  87. package/ios/RNWhisperAudioUtils.m +24 -13
  88. package/ios/RNWhisperContext.h +8 -2
  89. package/ios/RNWhisperContext.mm +42 -8
  90. package/ios/rnwhisper.xcframework/Info.plist +74 -0
  91. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-alloc.h +76 -0
  92. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend-impl.h +255 -0
  93. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend.h +354 -0
  94. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-common.h +1861 -0
  95. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-cpp.h +39 -0
  96. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-cpu.h +143 -0
  97. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-impl.h +603 -0
  98. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-metal-impl.h +622 -0
  99. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-metal.h +66 -0
  100. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-opt.h +237 -0
  101. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-quants.h +100 -0
  102. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-threading.h +14 -0
  103. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml.h +2221 -0
  104. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/gguf.h +202 -0
  105. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/rn-audioutils.h +14 -0
  106. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/rn-whisper-log.h +11 -0
  107. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/rn-whisper.h +52 -0
  108. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/whisper-arch.h +197 -0
  109. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/whisper.h +739 -0
  110. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Info.plist +0 -0
  111. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
  112. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/rnwhisper +0 -0
  113. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-alloc.h +76 -0
  114. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend-impl.h +255 -0
  115. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +354 -0
  116. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-common.h +1861 -0
  117. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpp.h +39 -0
  118. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpu.h +143 -0
  119. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +603 -0
  120. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal-impl.h +622 -0
  121. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal.h +66 -0
  122. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-opt.h +237 -0
  123. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-quants.h +100 -0
  124. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-threading.h +14 -0
  125. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +2221 -0
  126. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/gguf.h +202 -0
  127. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-audioutils.h +14 -0
  128. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-whisper-log.h +11 -0
  129. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-whisper.h +52 -0
  130. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/whisper-arch.h +197 -0
  131. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/whisper.h +739 -0
  132. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
  133. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +101 -0
  134. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
  135. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
  136. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-alloc.h +76 -0
  137. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend-impl.h +255 -0
  138. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend.h +354 -0
  139. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-common.h +1861 -0
  140. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-cpp.h +39 -0
  141. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-cpu.h +143 -0
  142. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-impl.h +603 -0
  143. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-metal-impl.h +622 -0
  144. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-metal.h +66 -0
  145. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-opt.h +237 -0
  146. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-quants.h +100 -0
  147. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-threading.h +14 -0
  148. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml.h +2221 -0
  149. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/gguf.h +202 -0
  150. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/rn-audioutils.h +14 -0
  151. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/rn-whisper-log.h +11 -0
  152. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/rn-whisper.h +52 -0
  153. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/whisper-arch.h +197 -0
  154. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/whisper.h +739 -0
  155. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Info.plist +0 -0
  156. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
  157. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/rnwhisper +0 -0
  158. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-alloc.h +76 -0
  159. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend-impl.h +255 -0
  160. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +354 -0
  161. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-common.h +1861 -0
  162. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpp.h +39 -0
  163. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpu.h +143 -0
  164. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +603 -0
  165. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal-impl.h +622 -0
  166. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal.h +66 -0
  167. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-opt.h +237 -0
  168. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-quants.h +100 -0
  169. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-threading.h +14 -0
  170. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +2221 -0
  171. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/gguf.h +202 -0
  172. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-audioutils.h +14 -0
  173. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-whisper-log.h +11 -0
  174. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-whisper.h +52 -0
  175. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/whisper-arch.h +197 -0
  176. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/whisper.h +739 -0
  177. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
  178. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +101 -0
  179. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
  180. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
  181. package/jest/mock.js +14 -1
  182. package/lib/commonjs/NativeRNWhisper.js.map +1 -1
  183. package/lib/commonjs/index.js +48 -19
  184. package/lib/commonjs/index.js.map +1 -1
  185. package/lib/commonjs/version.json +1 -1
  186. package/lib/module/NativeRNWhisper.js.map +1 -1
  187. package/lib/module/index.js +48 -19
  188. package/lib/module/index.js.map +1 -1
  189. package/lib/module/version.json +1 -1
  190. package/lib/typescript/NativeRNWhisper.d.ts +6 -3
  191. package/lib/typescript/NativeRNWhisper.d.ts.map +1 -1
  192. package/lib/typescript/index.d.ts +25 -3
  193. package/lib/typescript/index.d.ts.map +1 -1
  194. package/package.json +15 -10
  195. package/src/NativeRNWhisper.ts +12 -3
  196. package/src/index.ts +63 -24
  197. package/src/version.json +1 -1
  198. package/whisper-rn.podspec +18 -18
  199. package/cpp/README.md +0 -4
  200. package/cpp/ggml-backend.c +0 -1718
  201. package/cpp/ggml-metal-whisper.metal +0 -5820
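For context on the largest removal in this diff (item 200, package/cpp/ggml-backend.c, shown in full below): the plain-C backend layer is deleted in 0.4.0, with ggml-backend.cpp and ggml-backend-reg.cpp added in its place (items 25–26 above). The snippet that follows is a minimal, illustrative sketch of how the CPU-backend API implemented by the deleted file was typically driven; the wsp_ggml_* symbols are taken from the removed code below, but the surrounding wiring is an assumption, not code from this package.

```c
// Illustrative sketch only (not from this package): exercising the CPU backend
// API that the deleted ggml-backend.c below implemented. The wsp_ggml_* names
// come from the removed file; the wiring around them is assumed.
#include <stdbool.h>
#include "ggml.h"
#include "ggml-backend.h"

static bool run_graph_on_cpu(struct wsp_ggml_cgraph * graph) {
    wsp_ggml_backend_t backend = wsp_ggml_backend_cpu_init();   // CPU backend
    wsp_ggml_backend_cpu_set_n_threads(backend, 4);             // worker threads

    // Buffer from the backend's default (host) buffer type; graph tensors
    // would be allocated out of this buffer before computing.
    wsp_ggml_backend_buffer_t buf =
        wsp_ggml_backend_alloc_buffer(backend, 16 * 1024 * 1024);

    bool ok = wsp_ggml_backend_graph_compute(backend, graph);

    wsp_ggml_backend_buffer_free(buf);
    wsp_ggml_backend_free(backend);
    return ok;
}
```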
package/cpp/ggml-backend.c
@@ -1,1718 +0,0 @@
1
- #include "ggml-backend-impl.h"
2
- #include "ggml-alloc.h"
3
- #include "ggml-impl.h"
4
-
5
- #include <assert.h>
6
- #include <limits.h>
7
- #include <stdarg.h>
8
- #include <stdio.h>
9
- #include <stdlib.h>
10
- #include <string.h>
11
-
12
-
13
- #define MAX(a, b) ((a) > (b) ? (a) : (b))
14
-
15
-
16
- // backend buffer type
17
-
18
- const char * wsp_ggml_backend_buft_name(wsp_ggml_backend_buffer_type_t buft) {
19
- return buft->iface.get_name(buft);
20
- }
21
-
22
- WSP_GGML_CALL wsp_ggml_backend_buffer_t wsp_ggml_backend_buft_alloc_buffer(wsp_ggml_backend_buffer_type_t buft, size_t size) {
23
- return buft->iface.alloc_buffer(buft, size);
24
- }
25
-
26
- size_t wsp_ggml_backend_buft_get_alignment(wsp_ggml_backend_buffer_type_t buft) {
27
- return buft->iface.get_alignment(buft);
28
- }
29
-
30
- WSP_GGML_CALL size_t wsp_ggml_backend_buft_get_alloc_size(wsp_ggml_backend_buffer_type_t buft, struct wsp_ggml_tensor * tensor) {
31
- // get_alloc_size is optional, defaults to wsp_ggml_nbytes
32
- if (buft->iface.get_alloc_size) {
33
- return buft->iface.get_alloc_size(buft, tensor);
34
- }
35
- return wsp_ggml_nbytes(tensor);
36
- }
37
-
38
- bool wsp_ggml_backend_buft_supports_backend(wsp_ggml_backend_buffer_type_t buft, wsp_ggml_backend_t backend) {
39
- return buft->iface.supports_backend(buft, backend);
40
- }
41
-
42
- bool wsp_ggml_backend_buft_is_host(wsp_ggml_backend_buffer_type_t buft) {
43
- if (buft->iface.is_host) {
44
- return buft->iface.is_host(buft);
45
- }
46
- return false;
47
- }
48
-
49
- // backend buffer
50
-
51
- WSP_GGML_CALL wsp_ggml_backend_buffer_t wsp_ggml_backend_buffer_init(
52
- wsp_ggml_backend_buffer_type_t buft,
53
- struct wsp_ggml_backend_buffer_i iface,
54
- wsp_ggml_backend_buffer_context_t context,
55
- size_t size) {
56
- wsp_ggml_backend_buffer_t buffer = malloc(sizeof(struct wsp_ggml_backend_buffer));
57
-
58
- WSP_GGML_ASSERT(iface.get_base != NULL);
59
-
60
- (*buffer) = (struct wsp_ggml_backend_buffer) {
61
- /* .interface = */ iface,
62
- /* .buft = */ buft,
63
- /* .context = */ context,
64
- /* .size = */ size,
65
- /* .usage = */ WSP_GGML_BACKEND_BUFFER_USAGE_ANY
66
- };
67
-
68
- return buffer;
69
- }
70
-
71
- const char * wsp_ggml_backend_buffer_name(wsp_ggml_backend_buffer_t buffer) {
72
- return buffer->iface.get_name(buffer);
73
- }
74
-
75
- void wsp_ggml_backend_buffer_free(wsp_ggml_backend_buffer_t buffer) {
76
- if (buffer == NULL) {
77
- return;
78
- }
79
-
80
- if (buffer->iface.free_buffer != NULL) {
81
- buffer->iface.free_buffer(buffer);
82
- }
83
- free(buffer);
84
- }
85
-
86
- size_t wsp_ggml_backend_buffer_get_size(wsp_ggml_backend_buffer_t buffer) {
87
- return buffer->size;
88
- }
89
-
90
- void * wsp_ggml_backend_buffer_get_base(wsp_ggml_backend_buffer_t buffer) {
91
- void * base = buffer->iface.get_base(buffer);
92
-
93
- WSP_GGML_ASSERT(base != NULL && "backend buffer base cannot be NULL");
94
-
95
- return base;
96
- }
97
-
98
- WSP_GGML_CALL void wsp_ggml_backend_buffer_init_tensor(wsp_ggml_backend_buffer_t buffer, struct wsp_ggml_tensor * tensor) {
99
- // init_tensor is optional
100
- if (buffer->iface.init_tensor) {
101
- buffer->iface.init_tensor(buffer, tensor);
102
- }
103
- }
104
-
105
- size_t wsp_ggml_backend_buffer_get_alignment (wsp_ggml_backend_buffer_t buffer) {
106
- return wsp_ggml_backend_buft_get_alignment(wsp_ggml_backend_buffer_get_type(buffer));
107
- }
108
-
109
- size_t wsp_ggml_backend_buffer_get_alloc_size(wsp_ggml_backend_buffer_t buffer, struct wsp_ggml_tensor * tensor) {
110
- return wsp_ggml_backend_buft_get_alloc_size(wsp_ggml_backend_buffer_get_type(buffer), tensor);
111
- }
112
-
113
- void wsp_ggml_backend_buffer_clear(wsp_ggml_backend_buffer_t buffer, uint8_t value) {
114
- buffer->iface.clear(buffer, value);
115
- }
116
-
117
- bool wsp_ggml_backend_buffer_is_host(wsp_ggml_backend_buffer_t buffer) {
118
- return wsp_ggml_backend_buft_is_host(wsp_ggml_backend_buffer_get_type(buffer));
119
- }
120
-
121
- void wsp_ggml_backend_buffer_set_usage(wsp_ggml_backend_buffer_t buffer, enum wsp_ggml_backend_buffer_usage usage) {
122
- buffer->usage = usage;
123
- }
124
-
125
- wsp_ggml_backend_buffer_type_t wsp_ggml_backend_buffer_get_type(wsp_ggml_backend_buffer_t buffer) {
126
- return buffer->buft;
127
- }
128
-
129
- void wsp_ggml_backend_buffer_reset(wsp_ggml_backend_buffer_t buffer) {
130
- if (buffer->iface.reset) {
131
- buffer->iface.reset(buffer);
132
- }
133
- }
134
-
135
- bool wsp_ggml_backend_buffer_copy_tensor(const struct wsp_ggml_tensor * src, struct wsp_ggml_tensor * dst) {
136
- wsp_ggml_backend_buffer_t dst_buf = dst->view_src ? dst->view_src->buffer : dst->buffer;
137
- if (dst_buf->iface.cpy_tensor) {
138
- return src->buffer->iface.cpy_tensor(dst_buf, src, dst);
139
- }
140
- return false;
141
- }
142
-
143
- // backend
144
-
145
- const char * wsp_ggml_backend_name(wsp_ggml_backend_t backend) {
146
- if (backend == NULL) {
147
- return "NULL";
148
- }
149
- return backend->iface.get_name(backend);
150
- }
151
-
152
- void wsp_ggml_backend_free(wsp_ggml_backend_t backend) {
153
- if (backend == NULL) {
154
- return;
155
- }
156
-
157
- backend->iface.free(backend);
158
- }
159
-
160
- wsp_ggml_backend_buffer_type_t wsp_ggml_backend_get_default_buffer_type(wsp_ggml_backend_t backend) {
161
- return backend->iface.get_default_buffer_type(backend);
162
- }
163
-
164
- wsp_ggml_backend_buffer_t wsp_ggml_backend_alloc_buffer(wsp_ggml_backend_t backend, size_t size) {
165
- return wsp_ggml_backend_buft_alloc_buffer(wsp_ggml_backend_get_default_buffer_type(backend), size);
166
- }
167
-
168
- size_t wsp_ggml_backend_get_alignment(wsp_ggml_backend_t backend) {
169
- return wsp_ggml_backend_buft_get_alignment(wsp_ggml_backend_get_default_buffer_type(backend));
170
- }
171
-
172
- void wsp_ggml_backend_tensor_set_async(wsp_ggml_backend_t backend, struct wsp_ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
173
- WSP_GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
174
- WSP_GGML_ASSERT(offset + size <= wsp_ggml_nbytes(tensor) && "tensor write out of bounds");
175
-
176
- if (backend->iface.set_tensor_async == NULL) {
177
- wsp_ggml_backend_tensor_set(tensor, data, offset, size);
178
- } else {
179
- backend->iface.set_tensor_async(backend, tensor, data, offset, size);
180
- }
181
- }
182
-
183
- void wsp_ggml_backend_tensor_get_async(wsp_ggml_backend_t backend, const struct wsp_ggml_tensor * tensor, void * data, size_t offset, size_t size) {
184
- WSP_GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
185
- WSP_GGML_ASSERT(offset + size <= wsp_ggml_nbytes(tensor) && "tensor read out of bounds");
186
-
187
- if (backend->iface.get_tensor_async == NULL) {
188
- wsp_ggml_backend_tensor_get(tensor, data, offset, size);
189
- } else {
190
- backend->iface.get_tensor_async(backend, tensor, data, offset, size);
191
- }
192
- }
193
-
194
- WSP_GGML_CALL void wsp_ggml_backend_tensor_set(struct wsp_ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
195
- wsp_ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
196
-
197
- WSP_GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
198
- WSP_GGML_ASSERT(buf != NULL && "tensor buffer not set");
199
- WSP_GGML_ASSERT(offset + size <= wsp_ggml_nbytes(tensor) && "tensor write out of bounds");
200
-
201
- tensor->buffer->iface.set_tensor(buf, tensor, data, offset, size);
202
- }
203
-
204
- WSP_GGML_CALL void wsp_ggml_backend_tensor_get(const struct wsp_ggml_tensor * tensor, void * data, size_t offset, size_t size) {
205
- wsp_ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
206
-
207
- WSP_GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
208
- WSP_GGML_ASSERT(tensor->buffer != NULL && "tensor buffer not set");
209
- WSP_GGML_ASSERT(offset + size <= wsp_ggml_nbytes(tensor) && "tensor read out of bounds");
210
-
211
- tensor->buffer->iface.get_tensor(buf, tensor, data, offset, size);
212
- }
213
-
214
- void wsp_ggml_backend_synchronize(wsp_ggml_backend_t backend) {
215
- if (backend->iface.synchronize == NULL) {
216
- return;
217
- }
218
-
219
- backend->iface.synchronize(backend);
220
- }
221
-
222
- wsp_ggml_backend_graph_plan_t wsp_ggml_backend_graph_plan_create(wsp_ggml_backend_t backend, struct wsp_ggml_cgraph * cgraph) {
223
- return backend->iface.graph_plan_create(backend, cgraph);
224
- }
225
-
226
- void wsp_ggml_backend_graph_plan_free(wsp_ggml_backend_t backend, wsp_ggml_backend_graph_plan_t plan) {
227
- backend->iface.graph_plan_free(backend, plan);
228
- }
229
-
230
- void wsp_ggml_backend_graph_plan_compute(wsp_ggml_backend_t backend, wsp_ggml_backend_graph_plan_t plan) {
231
- backend->iface.graph_plan_compute(backend, plan);
232
- }
233
-
234
- bool wsp_ggml_backend_graph_compute(wsp_ggml_backend_t backend, struct wsp_ggml_cgraph * cgraph) {
235
- return backend->iface.graph_compute(backend, cgraph);
236
- }
237
-
238
- bool wsp_ggml_backend_supports_op(wsp_ggml_backend_t backend, const struct wsp_ggml_tensor * op) {
239
- return backend->iface.supports_op(backend, op);
240
- }
241
-
242
- // backend copy
243
-
244
- static bool wsp_ggml_are_same_layout(const struct wsp_ggml_tensor * a, const struct wsp_ggml_tensor * b) {
245
- if (a->type != b->type) {
246
- return false;
247
- }
248
- for (int i = 0; i < WSP_GGML_MAX_DIMS; i++) {
249
- if (a->ne[i] != b->ne[i]) {
250
- return false;
251
- }
252
- if (a->nb[i] != b->nb[i]) {
253
- return false;
254
- }
255
- }
256
- return true;
257
- }
258
-
259
- void wsp_ggml_backend_tensor_copy(struct wsp_ggml_tensor * src, struct wsp_ggml_tensor * dst) {
260
- WSP_GGML_ASSERT(wsp_ggml_are_same_layout(src, dst) && "cannot copy tensors with different layouts");
261
-
262
- if (src == dst) {
263
- return;
264
- }
265
-
266
- if (wsp_ggml_backend_buffer_is_host(src->buffer)) {
267
- wsp_ggml_backend_tensor_set(dst, src->data, 0, wsp_ggml_nbytes(src));
268
- } else if (wsp_ggml_backend_buffer_is_host(dst->buffer)) {
269
- wsp_ggml_backend_tensor_get(src, dst->data, 0, wsp_ggml_nbytes(src));
270
- } else if (!wsp_ggml_backend_buffer_copy_tensor(src, dst)) {
271
- #ifndef NDEBUG
272
- fprintf(stderr, "%s: warning: slow copy from %s to %s\n", __func__, wsp_ggml_backend_buffer_name(src->buffer), wsp_ggml_backend_buffer_name(dst->buffer));
273
- #endif
274
- size_t nbytes = wsp_ggml_nbytes(src);
275
- void * data = malloc(nbytes);
276
- wsp_ggml_backend_tensor_get(src, data, 0, nbytes);
277
- wsp_ggml_backend_tensor_set(dst, data, 0, nbytes);
278
- free(data);
279
- }
280
- }
281
-
282
- void wsp_ggml_backend_tensor_copy_async(wsp_ggml_backend_t backend, struct wsp_ggml_tensor * src, struct wsp_ggml_tensor * dst) {
283
- WSP_GGML_ASSERT(wsp_ggml_are_same_layout(src, dst) && "cannot copy tensors with different layouts");
284
-
285
- if (src == dst) {
286
- return;
287
- }
288
-
289
- if (wsp_ggml_backend_buft_supports_backend(src->buffer->buft, backend) && wsp_ggml_backend_buft_supports_backend(dst->buffer->buft, backend)) {
290
- if (backend->iface.cpy_tensor_async != NULL) {
291
- if (backend->iface.cpy_tensor_async(backend, src, dst)) {
292
- return;
293
- }
294
- }
295
- }
296
-
297
- size_t nbytes = wsp_ggml_nbytes(src);
298
- if (wsp_ggml_backend_buffer_is_host(src->buffer)) {
299
- wsp_ggml_backend_tensor_set_async(backend, dst, src->data, 0, nbytes);
300
- }
301
- else {
302
- wsp_ggml_backend_tensor_copy(src, dst);
303
- }
304
- }
305
-
306
-
307
- // backend registry
308
-
309
- #define WSP_GGML_MAX_BACKENDS_REG 16
310
-
311
- struct wsp_ggml_backend_reg {
312
- char name[128];
313
- wsp_ggml_backend_init_fn init_fn;
314
- wsp_ggml_backend_buffer_type_t default_buffer_type;
315
- void * user_data;
316
- };
317
-
318
- static struct wsp_ggml_backend_reg wsp_ggml_backend_registry[WSP_GGML_MAX_BACKENDS_REG];
319
- static size_t wsp_ggml_backend_registry_count = 0;
320
-
321
- WSP_GGML_CALL static wsp_ggml_backend_t wsp_ggml_backend_reg_cpu_init(const char * params, void * user_data);
322
-
323
- WSP_GGML_CALL static void wsp_ggml_backend_registry_init(void) {
324
- static bool initialized = false;
325
-
326
- if (initialized) {
327
- return;
328
- }
329
-
330
- initialized = true;
331
-
332
- wsp_ggml_backend_register("CPU", wsp_ggml_backend_reg_cpu_init, wsp_ggml_backend_cpu_buffer_type(), NULL);
333
-
334
- // add forward decls here to avoid including the backend headers
335
- #ifdef WSP_GGML_USE_CUBLAS
336
- extern WSP_GGML_CALL void wsp_ggml_backend_cuda_reg_devices(void);
337
- wsp_ggml_backend_cuda_reg_devices();
338
- #endif
339
-
340
- #ifdef WSP_GGML_USE_METAL
341
- extern WSP_GGML_CALL wsp_ggml_backend_t wsp_ggml_backend_reg_metal_init(const char * params, void * user_data);
342
- extern WSP_GGML_CALL wsp_ggml_backend_buffer_type_t wsp_ggml_backend_metal_buffer_type(void);
343
- wsp_ggml_backend_register("Metal", wsp_ggml_backend_reg_metal_init, wsp_ggml_backend_metal_buffer_type(), NULL);
344
- #endif
345
- }
346
-
347
- WSP_GGML_CALL void wsp_ggml_backend_register(const char * name, wsp_ggml_backend_init_fn init_fn, wsp_ggml_backend_buffer_type_t default_buffer_type, void * user_data) {
348
- WSP_GGML_ASSERT(wsp_ggml_backend_registry_count < WSP_GGML_MAX_BACKENDS_REG);
349
-
350
- size_t id = wsp_ggml_backend_registry_count;
351
-
352
- wsp_ggml_backend_registry[id] = (struct wsp_ggml_backend_reg) {
353
- /* .name = */ {0},
354
- /* .fn = */ init_fn,
355
- /* .default_buffer_type = */ default_buffer_type,
356
- /* .user_data = */ user_data,
357
- };
358
-
359
- snprintf(wsp_ggml_backend_registry[id].name, sizeof(wsp_ggml_backend_registry[id].name), "%s", name);
360
-
361
- #ifndef NDEBUG
362
- fprintf(stderr, "%s: registered backend %s\n", __func__, name);
363
- #endif
364
-
365
- wsp_ggml_backend_registry_count++;
366
- }
367
-
368
- size_t wsp_ggml_backend_reg_get_count(void) {
369
- wsp_ggml_backend_registry_init();
370
-
371
- return wsp_ggml_backend_registry_count;
372
- }
373
-
374
- size_t wsp_ggml_backend_reg_find_by_name(const char * name) {
375
- wsp_ggml_backend_registry_init();
376
-
377
- for (size_t i = 0; i < wsp_ggml_backend_registry_count; i++) {
378
- // TODO: case insensitive in a portable way
379
- if (strcmp(wsp_ggml_backend_registry[i].name, name) == 0) {
380
- return i;
381
- }
382
- }
383
-
384
- // not found
385
- return SIZE_MAX;
386
- }
387
-
388
- // init from backend:params string
389
- wsp_ggml_backend_t wsp_ggml_backend_reg_init_backend_from_str(const char * backend_str) {
390
- wsp_ggml_backend_registry_init();
391
-
392
- const char * params = strchr(backend_str, ':');
393
- char backend_name[128];
394
- if (params == NULL) {
395
- snprintf(backend_name, sizeof(backend_name), "%s", backend_str);
396
- params = "";
397
- } else {
398
- snprintf(backend_name, sizeof(backend_name), "%.*s", (int)(params - backend_str), backend_str);
399
- params++;
400
- }
401
-
402
- size_t backend_i = wsp_ggml_backend_reg_find_by_name(backend_name);
403
-
404
- if (backend_i == SIZE_MAX) {
405
- fprintf(stderr, "%s: backend %s not found\n", __func__, backend_name);
406
- return NULL;
407
- }
408
-
409
- return wsp_ggml_backend_reg_init_backend(backend_i, params);
410
- }
411
-
412
- const char * wsp_ggml_backend_reg_get_name(size_t i) {
413
- wsp_ggml_backend_registry_init();
414
-
415
- WSP_GGML_ASSERT(i < wsp_ggml_backend_registry_count);
416
- return wsp_ggml_backend_registry[i].name;
417
- }
418
-
419
- wsp_ggml_backend_t wsp_ggml_backend_reg_init_backend(size_t i, const char * params) {
420
- wsp_ggml_backend_registry_init();
421
-
422
- WSP_GGML_ASSERT(i < wsp_ggml_backend_registry_count);
423
- return wsp_ggml_backend_registry[i].init_fn(params, wsp_ggml_backend_registry[i].user_data);
424
- }
425
-
426
- wsp_ggml_backend_buffer_type_t wsp_ggml_backend_reg_get_default_buffer_type(size_t i) {
427
- wsp_ggml_backend_registry_init();
428
-
429
- WSP_GGML_ASSERT(i < wsp_ggml_backend_registry_count);
430
- return wsp_ggml_backend_registry[i].default_buffer_type;
431
- }
432
-
433
- wsp_ggml_backend_buffer_t wsp_ggml_backend_reg_alloc_buffer(size_t i, size_t size) {
434
- wsp_ggml_backend_registry_init();
435
-
436
- WSP_GGML_ASSERT(i < wsp_ggml_backend_registry_count);
437
- return wsp_ggml_backend_buft_alloc_buffer(wsp_ggml_backend_registry[i].default_buffer_type, size);
438
- }
439
-
440
- // backend CPU
441
-
442
- WSP_GGML_CALL static const char * wsp_ggml_backend_cpu_buffer_name(wsp_ggml_backend_buffer_t buffer) {
443
- return "CPU";
444
-
445
- WSP_GGML_UNUSED(buffer);
446
- }
447
-
448
- WSP_GGML_CALL static void * wsp_ggml_backend_cpu_buffer_get_base(wsp_ggml_backend_buffer_t buffer) {
449
- return (void *)buffer->context;
450
- }
451
-
452
- WSP_GGML_CALL static void wsp_ggml_backend_cpu_buffer_free_buffer(wsp_ggml_backend_buffer_t buffer) {
453
- free(buffer->context);
454
- }
455
-
456
- WSP_GGML_CALL static void wsp_ggml_backend_cpu_buffer_set_tensor(wsp_ggml_backend_buffer_t buffer, struct wsp_ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
457
- memcpy((char *)tensor->data + offset, data, size);
458
-
459
- WSP_GGML_UNUSED(buffer);
460
- }
461
-
462
- WSP_GGML_CALL static void wsp_ggml_backend_cpu_buffer_get_tensor(wsp_ggml_backend_buffer_t buffer, const struct wsp_ggml_tensor * tensor, void * data, size_t offset, size_t size) {
463
- memcpy(data, (const char *)tensor->data + offset, size);
464
-
465
- WSP_GGML_UNUSED(buffer);
466
- }
467
-
468
- WSP_GGML_CALL static bool wsp_ggml_backend_cpu_buffer_cpy_tensor(wsp_ggml_backend_buffer_t buffer, const struct wsp_ggml_tensor * src, struct wsp_ggml_tensor * dst) {
469
- if (wsp_ggml_backend_buffer_is_host(src->buffer)) {
470
- memcpy(dst->data, src->data, wsp_ggml_nbytes(src));
471
- return true;
472
- }
473
- return false;
474
-
475
- WSP_GGML_UNUSED(buffer);
476
- }
477
-
478
- WSP_GGML_CALL static void wsp_ggml_backend_cpu_buffer_clear(wsp_ggml_backend_buffer_t buffer, uint8_t value) {
479
- memset(buffer->context, value, buffer->size);
480
- }
481
-
482
- static struct wsp_ggml_backend_buffer_i cpu_backend_buffer_i = {
483
- /* .get_name = */ wsp_ggml_backend_cpu_buffer_name,
484
- /* .free_buffer = */ wsp_ggml_backend_cpu_buffer_free_buffer,
485
- /* .get_base = */ wsp_ggml_backend_cpu_buffer_get_base,
486
- /* .init_tensor = */ NULL, // no initialization required
487
- /* .set_tensor = */ wsp_ggml_backend_cpu_buffer_set_tensor,
488
- /* .get_tensor = */ wsp_ggml_backend_cpu_buffer_get_tensor,
489
- /* .cpy_tensor = */ wsp_ggml_backend_cpu_buffer_cpy_tensor,
490
- /* .clear = */ wsp_ggml_backend_cpu_buffer_clear,
491
- /* .reset = */ NULL,
492
- };
493
-
494
- // for buffers from ptr, free is not called
495
- static struct wsp_ggml_backend_buffer_i cpu_backend_buffer_i_from_ptr = {
496
- /* .get_name = */ wsp_ggml_backend_cpu_buffer_name,
497
- /* .free_buffer = */ NULL, // ptr is not owned by the buffer, so it does not need to be freed
498
- /* .get_base = */ wsp_ggml_backend_cpu_buffer_get_base,
499
- /* .init_tensor = */ NULL, // no initialization required
500
- /* .set_tensor = */ wsp_ggml_backend_cpu_buffer_set_tensor,
501
- /* .get_tensor = */ wsp_ggml_backend_cpu_buffer_get_tensor,
502
- /* .cpy_tensor = */ wsp_ggml_backend_cpu_buffer_cpy_tensor,
503
- /* .clear = */ wsp_ggml_backend_cpu_buffer_clear,
504
- /* .reset = */ NULL,
505
- };
506
-
507
- static const size_t TENSOR_ALIGNMENT = 64; // should be enough for AVX 512
508
-
509
- WSP_GGML_CALL static const char * wsp_ggml_backend_cpu_buffer_type_get_name(wsp_ggml_backend_buffer_type_t buft) {
510
- return "CPU";
511
-
512
- WSP_GGML_UNUSED(buft);
513
- }
514
-
515
- WSP_GGML_CALL static wsp_ggml_backend_buffer_t wsp_ggml_backend_cpu_buffer_type_alloc_buffer(wsp_ggml_backend_buffer_type_t buft, size_t size) {
516
- size += TENSOR_ALIGNMENT; // malloc may return an address that is not aligned
517
- void * data = malloc(size); // TODO: maybe use WSP_GGML_ALIGNED_MALLOC?
518
-
519
- WSP_GGML_ASSERT(data != NULL && "failed to allocate buffer");
520
-
521
- return wsp_ggml_backend_buffer_init(buft, cpu_backend_buffer_i, data, size);
522
- }
523
-
524
- WSP_GGML_CALL static size_t wsp_ggml_backend_cpu_buffer_type_get_alignment(wsp_ggml_backend_buffer_type_t buft) {
525
- return TENSOR_ALIGNMENT;
526
-
527
- WSP_GGML_UNUSED(buft);
528
- }
529
-
530
- WSP_GGML_CALL static bool wsp_ggml_backend_cpu_buffer_type_supports_backend(wsp_ggml_backend_buffer_type_t buft, wsp_ggml_backend_t backend) {
531
- return wsp_ggml_backend_is_cpu(backend);
532
-
533
- WSP_GGML_UNUSED(buft);
534
- }
535
-
536
- WSP_GGML_CALL static bool wsp_ggml_backend_cpu_buffer_type_is_host(wsp_ggml_backend_buffer_type_t buft) {
537
- return true;
538
-
539
- WSP_GGML_UNUSED(buft);
540
- }
541
-
542
- WSP_GGML_CALL wsp_ggml_backend_buffer_type_t wsp_ggml_backend_cpu_buffer_type(void) {
543
- static struct wsp_ggml_backend_buffer_type wsp_ggml_backend_cpu_buffer_type = {
544
- /* .iface = */ {
545
- /* .get_name = */ wsp_ggml_backend_cpu_buffer_type_get_name,
546
- /* .alloc_buffer = */ wsp_ggml_backend_cpu_buffer_type_alloc_buffer,
547
- /* .get_alignment = */ wsp_ggml_backend_cpu_buffer_type_get_alignment,
548
- /* .get_alloc_size = */ NULL, // defaults to wsp_ggml_nbytes
549
- /* .supports_backend = */ wsp_ggml_backend_cpu_buffer_type_supports_backend,
550
- /* .is_host = */ wsp_ggml_backend_cpu_buffer_type_is_host,
551
- },
552
- /* .context = */ NULL,
553
- };
554
-
555
- return &wsp_ggml_backend_cpu_buffer_type;
556
- }
557
-
558
- #ifdef WSP_GGML_USE_CPU_HBM
559
-
560
- // buffer type HBM
561
-
562
- #include <hbwmalloc.h>
563
-
564
- WSP_GGML_CALL static const char * wsp_ggml_backend_cpu_hbm_buffer_type_get_name(wsp_ggml_backend_buffer_type_t buft) {
565
- return "CPU_HBM";
566
-
567
- WSP_GGML_UNUSED(buft);
568
- }
569
-
570
- WSP_GGML_CALL static const char * wsp_ggml_backend_cpu_hbm_buffer_get_name(wsp_ggml_backend_buffer_t buf) {
571
- return "CPU_HBM";
572
-
573
- WSP_GGML_UNUSED(buf);
574
- }
575
-
576
- WSP_GGML_CALL static void wsp_ggml_backend_cpu_hbm_buffer_free_buffer(wsp_ggml_backend_buffer_t buffer) {
577
- hbw_free(buffer->context);
578
- }
579
-
580
- WSP_GGML_CALL static wsp_ggml_backend_buffer_t wsp_ggml_backend_cpu_hbm_buffer_type_alloc_buffer(wsp_ggml_backend_buffer_type_t buft, size_t size) {
581
- //void * ptr = hbw_malloc(size);
582
- void * ptr;
583
- int result = hbw_posix_memalign(&ptr, wsp_ggml_backend_cpu_buffer_type_get_alignment(buft), size);
584
- if (result != 0) {
585
- fprintf(stderr, "failed to allocate HBM buffer of size %zu\n", size);
586
- return NULL;
587
- }
588
-
589
- wsp_ggml_backend_buffer_t buffer = wsp_ggml_backend_cpu_buffer_from_ptr(ptr, size);
590
- buffer->buft = buft;
591
- buffer->iface.get_name = wsp_ggml_backend_cpu_hbm_buffer_get_name;
592
- buffer->iface.free_buffer = wsp_ggml_backend_cpu_hbm_buffer_free_buffer;
593
-
594
- return buffer;
595
- }
596
-
597
- wsp_ggml_backend_buffer_type_t wsp_ggml_backend_cpu_hbm_buffer_type(void) {
598
- static struct wsp_ggml_backend_buffer_type wsp_ggml_backend_cpu_buffer_type_hbm = {
599
- /* .iface = */ {
600
- /* .get_name = */ wsp_ggml_backend_cpu_hbm_buffer_type_get_name,
601
- /* .alloc_buffer = */ wsp_ggml_backend_cpu_hbm_buffer_type_alloc_buffer,
602
- /* .get_alignment = */ wsp_ggml_backend_cpu_buffer_type_get_alignment,
603
- /* .get_alloc_size = */ NULL, // defaults to wsp_ggml_nbytes
604
- /* .supports_backend = */ wsp_ggml_backend_cpu_buffer_type_supports_backend,
605
- /* .is_host = */ wsp_ggml_backend_cpu_buffer_type_is_host,
606
- },
607
- /* .context = */ NULL,
608
- };
609
-
610
- return &wsp_ggml_backend_cpu_buffer_type_hbm;
611
- }
612
- #endif
613
-
614
- struct wsp_ggml_backend_cpu_context {
615
- int n_threads;
616
- void * work_data;
617
- size_t work_size;
618
- };
619
-
620
- WSP_GGML_CALL static const char * wsp_ggml_backend_cpu_name(wsp_ggml_backend_t backend) {
621
- return "CPU";
622
-
623
- WSP_GGML_UNUSED(backend);
624
- }
625
-
626
- WSP_GGML_CALL static void wsp_ggml_backend_cpu_free(wsp_ggml_backend_t backend) {
627
- struct wsp_ggml_backend_cpu_context * cpu_ctx = (struct wsp_ggml_backend_cpu_context *)backend->context;
628
- free(cpu_ctx->work_data);
629
- free(cpu_ctx);
630
- free(backend);
631
- }
632
-
633
- WSP_GGML_CALL static wsp_ggml_backend_buffer_type_t wsp_ggml_backend_cpu_get_default_buffer_type(wsp_ggml_backend_t backend) {
634
- return wsp_ggml_backend_cpu_buffer_type();
635
-
636
- WSP_GGML_UNUSED(backend);
637
- }
638
-
639
- struct wsp_ggml_backend_plan_cpu {
640
- struct wsp_ggml_cplan cplan;
641
- struct wsp_ggml_cgraph cgraph;
642
- };
643
-
644
- WSP_GGML_CALL static wsp_ggml_backend_graph_plan_t wsp_ggml_backend_cpu_graph_plan_create(wsp_ggml_backend_t backend, const struct wsp_ggml_cgraph * cgraph) {
645
- struct wsp_ggml_backend_cpu_context * cpu_ctx = (struct wsp_ggml_backend_cpu_context *)backend->context;
646
-
647
- struct wsp_ggml_backend_plan_cpu * cpu_plan = malloc(sizeof(struct wsp_ggml_backend_plan_cpu));
648
-
649
- cpu_plan->cplan = wsp_ggml_graph_plan(cgraph, cpu_ctx->n_threads);
650
- cpu_plan->cgraph = *cgraph; // FIXME: deep copy
651
-
652
- if (cpu_plan->cplan.work_size > 0) {
653
- cpu_plan->cplan.work_data = malloc(cpu_plan->cplan.work_size);
654
- }
655
-
656
- return cpu_plan;
657
- }
658
-
659
- WSP_GGML_CALL static void wsp_ggml_backend_cpu_graph_plan_free(wsp_ggml_backend_t backend, wsp_ggml_backend_graph_plan_t plan) {
660
- struct wsp_ggml_backend_plan_cpu * cpu_plan = (struct wsp_ggml_backend_plan_cpu *)plan;
661
-
662
- free(cpu_plan->cplan.work_data);
663
- free(cpu_plan);
664
-
665
- WSP_GGML_UNUSED(backend);
666
- }
667
-
668
- WSP_GGML_CALL static void wsp_ggml_backend_cpu_graph_plan_compute(wsp_ggml_backend_t backend, wsp_ggml_backend_graph_plan_t plan) {
669
- struct wsp_ggml_backend_plan_cpu * cpu_plan = (struct wsp_ggml_backend_plan_cpu *)plan;
670
-
671
- wsp_ggml_graph_compute(&cpu_plan->cgraph, &cpu_plan->cplan);
672
-
673
- WSP_GGML_UNUSED(backend);
674
- }
675
-
676
- WSP_GGML_CALL static bool wsp_ggml_backend_cpu_graph_compute(wsp_ggml_backend_t backend, struct wsp_ggml_cgraph * cgraph) {
677
- struct wsp_ggml_backend_cpu_context * cpu_ctx = (struct wsp_ggml_backend_cpu_context *)backend->context;
678
-
679
- struct wsp_ggml_cplan cplan = wsp_ggml_graph_plan(cgraph, cpu_ctx->n_threads);
680
-
681
- if (cpu_ctx->work_size < cplan.work_size) {
682
- // TODO: may be faster to free and use malloc to avoid the copy
683
- cpu_ctx->work_data = realloc(cpu_ctx->work_data, cplan.work_size);
684
- cpu_ctx->work_size = cplan.work_size;
685
- }
686
-
687
- cplan.work_data = cpu_ctx->work_data;
688
-
689
- wsp_ggml_graph_compute(cgraph, &cplan);
690
- return true;
691
- }
692
-
693
- WSP_GGML_CALL static bool wsp_ggml_backend_cpu_supports_op(wsp_ggml_backend_t backend, const struct wsp_ggml_tensor * op) {
694
- switch (op->op) {
695
- case WSP_GGML_OP_CPY:
696
- return op->type != WSP_GGML_TYPE_IQ2_XXS && op->type != WSP_GGML_TYPE_IQ2_XS; // missing type_traits.from_float
697
- case WSP_GGML_OP_MUL_MAT:
698
- return op->src[1]->type == WSP_GGML_TYPE_F32 || op->src[1]->type == wsp_ggml_internal_get_type_traits(op->src[0]->type).vec_dot_type;
699
- default:
700
- return true;
701
- }
702
-
703
- WSP_GGML_UNUSED(backend);
704
- }
705
-
706
- static struct wsp_ggml_backend_i cpu_backend_i = {
707
- /* .get_name = */ wsp_ggml_backend_cpu_name,
708
- /* .free = */ wsp_ggml_backend_cpu_free,
709
- /* .get_default_buffer_type = */ wsp_ggml_backend_cpu_get_default_buffer_type,
710
- /* .set_tensor_async = */ NULL,
711
- /* .get_tensor_async = */ NULL,
712
- /* .cpy_tensor_async = */ NULL,
713
- /* .synchronize = */ NULL,
714
- /* .graph_plan_create = */ wsp_ggml_backend_cpu_graph_plan_create,
715
- /* .graph_plan_free = */ wsp_ggml_backend_cpu_graph_plan_free,
716
- /* .graph_plan_compute = */ wsp_ggml_backend_cpu_graph_plan_compute,
717
- /* .graph_compute = */ wsp_ggml_backend_cpu_graph_compute,
718
- /* .supports_op = */ wsp_ggml_backend_cpu_supports_op,
719
- };
720
-
721
- wsp_ggml_backend_t wsp_ggml_backend_cpu_init(void) {
722
- struct wsp_ggml_backend_cpu_context * ctx = malloc(sizeof(struct wsp_ggml_backend_cpu_context));
723
-
724
- ctx->n_threads = WSP_GGML_DEFAULT_N_THREADS;
725
- ctx->work_data = NULL;
726
- ctx->work_size = 0;
727
-
728
- wsp_ggml_backend_t cpu_backend = malloc(sizeof(struct wsp_ggml_backend));
729
-
730
- *cpu_backend = (struct wsp_ggml_backend) {
731
- /* .interface = */ cpu_backend_i,
732
- /* .context = */ ctx
733
- };
734
- return cpu_backend;
735
- }
736
-
737
- WSP_GGML_CALL bool wsp_ggml_backend_is_cpu(wsp_ggml_backend_t backend) {
738
- return backend && backend->iface.get_name == wsp_ggml_backend_cpu_name;
739
- }
740
-
741
- void wsp_ggml_backend_cpu_set_n_threads(wsp_ggml_backend_t backend_cpu, int n_threads) {
742
- WSP_GGML_ASSERT(wsp_ggml_backend_is_cpu(backend_cpu));
743
-
744
- struct wsp_ggml_backend_cpu_context * ctx = (struct wsp_ggml_backend_cpu_context *)backend_cpu->context;
745
- ctx->n_threads = n_threads;
746
- }
747
-
748
- WSP_GGML_CALL wsp_ggml_backend_buffer_t wsp_ggml_backend_cpu_buffer_from_ptr(void * ptr, size_t size) {
749
- return wsp_ggml_backend_buffer_init(wsp_ggml_backend_cpu_buffer_type(), cpu_backend_buffer_i_from_ptr, ptr, size);
750
- }
751
-
752
- WSP_GGML_CALL static wsp_ggml_backend_t wsp_ggml_backend_reg_cpu_init(const char * params, void * user_data) {
753
- return wsp_ggml_backend_cpu_init();
754
-
755
- WSP_GGML_UNUSED(params);
756
- WSP_GGML_UNUSED(user_data);
757
- }
758
-
759
-
760
- // scheduler
761
-
762
- #define WSP_GGML_MAX_BACKENDS 16
763
- #define WSP_GGML_MAX_SPLITS 256
764
- #define WSP_GGML_MAX_SPLIT_INPUTS 16
765
-
766
- struct wsp_ggml_backend_sched_split {
767
- wsp_ggml_tallocr_t tallocr;
768
- int i_start;
769
- int i_end;
770
- struct wsp_ggml_tensor * inputs[WSP_GGML_MAX_SPLIT_INPUTS];
771
- int n_inputs;
772
- // graph view of this split
773
- struct wsp_ggml_cgraph graph;
774
- };
775
-
776
- struct wsp_ggml_backend_sched {
777
- bool is_reset; // true if the scheduler has been reset since the last graph split
778
-
779
- int n_backends;
780
- wsp_ggml_backend_t backends[WSP_GGML_MAX_BACKENDS];
781
- wsp_ggml_backend_buffer_type_t bufts[WSP_GGML_MAX_BACKENDS];
782
- wsp_ggml_tallocr_t tallocs[WSP_GGML_MAX_BACKENDS];
783
-
784
- wsp_ggml_gallocr_t galloc;
785
-
786
- // hash keys of the nodes in the graph
787
- struct wsp_ggml_hash_set hash_set;
788
- // hash values (arrays of [hash_set.size])
789
- wsp_ggml_tallocr_t * node_talloc; // tallocr assigned to each node (indirectly this is the backend)
790
- struct wsp_ggml_tensor * (* node_copies)[WSP_GGML_MAX_BACKENDS]; // copies of each node for each destination backend
791
-
792
- // copy of the graph with modified inputs
793
- struct wsp_ggml_cgraph * graph;
794
-
795
- struct wsp_ggml_backend_sched_split splits[WSP_GGML_MAX_SPLITS];
796
- int n_splits;
797
-
798
- struct wsp_ggml_context * ctx;
799
-
800
- // align context_buffer to WSP_GGML_MEM_ALIGN
801
- #ifdef _MSC_VER
802
- __declspec(align(WSP_GGML_MEM_ALIGN))
803
- #else
804
- __attribute__((aligned(WSP_GGML_MEM_ALIGN)))
805
- #endif
806
- char context_buffer[WSP_GGML_MAX_SPLITS*WSP_GGML_MAX_SPLIT_INPUTS*sizeof(struct wsp_ggml_tensor) + sizeof(struct wsp_ggml_cgraph)];
807
-
808
- wsp_ggml_backend_sched_eval_callback callback_eval;
809
- void * callback_eval_user_data;
810
- };
811
-
812
- #define hash_id(node) wsp_ggml_hash_find_or_insert(sched->hash_set, node)
813
- #define node_allocr(node) sched->node_talloc[hash_id(node)]
814
-
815
- static bool wsp_ggml_is_view_op(enum wsp_ggml_op op) {
816
- return op == WSP_GGML_OP_VIEW || op == WSP_GGML_OP_RESHAPE || op == WSP_GGML_OP_PERMUTE || op == WSP_GGML_OP_TRANSPOSE;
817
- }
818
-
819
- // returns the priority of the backend, lower is better
820
- static int sched_backend_prio(wsp_ggml_backend_sched_t sched, wsp_ggml_backend_t backend) {
821
- for (int i = 0; i < sched->n_backends; i++) {
822
- if (sched->backends[i] == backend) {
823
- return i;
824
- }
825
- }
826
- return INT_MAX;
827
- }
828
-
829
- static int sched_allocr_prio(wsp_ggml_backend_sched_t sched, wsp_ggml_tallocr_t allocr) {
830
- for (int i = 0; i < sched->n_backends; i++) {
831
- if (sched->tallocs[i] == allocr) {
832
- return i;
833
- }
834
- }
835
- return INT_MAX;
836
- }
837
-
838
- static wsp_ggml_tallocr_t sched_allocr_from_buffer(wsp_ggml_backend_sched_t sched, wsp_ggml_backend_buffer_t buffer) {
839
- if (buffer == NULL) {
840
- return NULL;
841
- }
842
-
843
- // check if this is already allocate in a allocr buffer (from user manual allocations)
844
- for (int i = 0; i < sched->n_backends; i++) {
845
- if (wsp_ggml_tallocr_get_buffer(sched->tallocs[i]) == buffer) {
846
- return sched->tallocs[i];
847
- }
848
- }
849
-
850
- // find highest prio backend that supports the buffer type
851
- for (int i = 0; i < sched->n_backends; i++) {
852
- if (wsp_ggml_backend_buft_supports_backend(buffer->buft, sched->backends[i])) {
853
- return sched->tallocs[i];
854
- }
855
- }
856
- WSP_GGML_ASSERT(false && "tensor buffer type not supported by any backend");
857
- }
858
-
859
- static wsp_ggml_backend_t get_allocr_backend(wsp_ggml_backend_sched_t sched, wsp_ggml_tallocr_t allocr) {
860
- if (allocr == NULL) {
861
- return NULL;
862
- }
863
- for (int i = 0; i < sched->n_backends; i++) {
864
- if (sched->tallocs[i] == allocr) {
865
- return sched->backends[i];
866
- }
867
- }
868
- WSP_GGML_UNREACHABLE();
869
- }
870
-
871
- #if 0
872
- static char causes[WSP_GGML_DEFAULT_GRAPH_SIZE*16 + WSP_GGML_MAX_SPLITS*WSP_GGML_MAX_SPLIT_INPUTS][128]; // debug only
873
- #define SET_CAUSE(node, ...) sprintf(causes[hash_id(node)], __VA_ARGS__)
874
- #define GET_CAUSE(node) causes[hash_id(node)]
875
- #else
876
- #define SET_CAUSE(node, ...)
877
- #define GET_CAUSE(node) ""
878
- #endif
879
-
880
- // returns the backend that should be used for the node based on the current locations
881
- static wsp_ggml_tallocr_t sched_allocr_from_cur(wsp_ggml_backend_sched_t sched, struct wsp_ggml_tensor * node) {
882
- // assign pre-allocated nodes to their backend
883
- // dst
884
- wsp_ggml_tallocr_t cur_allocr = sched_allocr_from_buffer(sched, node->buffer);
885
- if (cur_allocr != NULL) {
886
- SET_CAUSE(node, "1.dst");
887
- return cur_allocr;
888
- }
889
- // view_src
890
- if (node->view_src != NULL) {
891
- cur_allocr = sched_allocr_from_buffer(sched, node->view_src->buffer);
892
- if (cur_allocr != NULL) {
893
- SET_CAUSE(node, "1.vsrc");
894
- return cur_allocr;
895
- }
896
- }
897
- // assign nodes that use weights to the backend of the weights
898
- for (int i = 0; i < WSP_GGML_MAX_SRC; i++) {
899
- const struct wsp_ggml_tensor * src = node->src[i];
900
- if (src == NULL) {
901
- break;
902
- }
903
- if (src->buffer != NULL && src->buffer->usage == WSP_GGML_BACKEND_BUFFER_USAGE_WEIGHTS) {
904
- wsp_ggml_tallocr_t src_allocr = sched_allocr_from_buffer(sched, src->buffer);
905
- // operations with weights are always run on the same backend as the weights
906
- SET_CAUSE(node, "1.wgt%d", i);
907
- return src_allocr;
908
- }
909
- }
910
-
911
- return NULL;
912
- }
913
-
914
- static char * fmt_size(size_t size) {
915
- static char buffer[128];
916
- if (size >= 1024*1024) {
917
- sprintf(buffer, "%zuM", size/1024/1024);
918
- } else {
919
- sprintf(buffer, "%zuK", size/1024);
920
- }
921
- return buffer;
922
- }
923
-
924
- static void sched_print_assignments(wsp_ggml_backend_sched_t sched, struct wsp_ggml_cgraph * graph) {
925
- int cur_split = 0;
926
- for (int i = 0; i < graph->n_nodes; i++) {
927
- if (cur_split < sched->n_splits && i == sched->splits[cur_split].i_start) {
928
- wsp_ggml_backend_t split_backend = get_allocr_backend(sched, sched->splits[cur_split].tallocr);
929
- fprintf(stderr, "\n## SPLIT #%d: %s # %d inputs: ", cur_split, wsp_ggml_backend_name(split_backend),
930
- sched->splits[cur_split].n_inputs);
931
- for (int j = 0; j < sched->splits[cur_split].n_inputs; j++) {
932
- fprintf(stderr, "[%s (%5.5s)] ", sched->splits[cur_split].inputs[j]->name,
933
- fmt_size(wsp_ggml_nbytes(sched->splits[cur_split].inputs[j])));
934
- }
935
- fprintf(stderr, "\n");
936
- cur_split++;
937
- }
938
- struct wsp_ggml_tensor * node = graph->nodes[i];
939
- if (wsp_ggml_is_view_op(node->op)) {
940
- continue;
941
- }
942
- wsp_ggml_tallocr_t node_allocr = node_allocr(node);
943
- wsp_ggml_backend_t node_backend = node_allocr ? get_allocr_backend(sched, node_allocr) : NULL; // FIXME:
944
- fprintf(stderr, "node #%3d (%10.10s): %20.20s (%5.5s) [%5.5s %8.8s]:", i, wsp_ggml_op_name(node->op), node->name,
945
- fmt_size(wsp_ggml_nbytes(node)), node_allocr ? wsp_ggml_backend_name(node_backend) : "NULL", GET_CAUSE(node));
946
- for (int j = 0; j < WSP_GGML_MAX_SRC; j++) {
947
- struct wsp_ggml_tensor * src = node->src[j];
948
- if (src == NULL) {
949
- break;
950
- }
951
- wsp_ggml_tallocr_t src_allocr = node_allocr(src);
952
- wsp_ggml_backend_t src_backend = src_allocr ? get_allocr_backend(sched, src_allocr) : NULL;
953
- fprintf(stderr, " %20.20s (%5.5s) [%5.5s %8.8s]", src->name,
954
- fmt_size(wsp_ggml_nbytes(src)), src_backend ? wsp_ggml_backend_name(src_backend) : "NULL", GET_CAUSE(src));
955
- }
956
- fprintf(stderr, "\n");
957
- }
958
- }
959
-
960
- // creates a copy of the tensor with the same memory layout
961
- static struct wsp_ggml_tensor * wsp_ggml_dup_tensor_layout(struct wsp_ggml_context * ctx, const struct wsp_ggml_tensor * tensor) {
962
- struct wsp_ggml_tensor * dup = wsp_ggml_dup_tensor(ctx, tensor);
963
- for (int i = 0; i < WSP_GGML_MAX_DIMS; i++) {
964
- dup->nb[i] = tensor->nb[i];
965
- }
966
- return dup;
967
- }
968
-
969
-
970
- //#define DEBUG_PASS1
971
- //#define DEBUG_PASS2
972
- //#define DEBUG_PASS3
973
- //#define DEBUG_PASS4
974
-
975
- // assigns backends to ops and splits the graph into subgraphs that can be computed on the same backend
976
- static void sched_split_graph(wsp_ggml_backend_sched_t sched, struct wsp_ggml_cgraph * graph) {
977
- // reset splits
978
- sched->n_splits = 0;
979
- sched->is_reset = false;
980
-
981
- struct wsp_ggml_init_params params = {
982
- /* .mem_size = */ sizeof(sched->context_buffer),
983
- /* .mem_buffer = */ sched->context_buffer,
984
- /* .no_alloc = */ true
985
- };
986
-
987
- wsp_ggml_free(sched->ctx);
988
-
989
- sched->ctx = wsp_ggml_init(params);
990
- if (sched->ctx == NULL) {
991
- fprintf(stderr, "%s: failed to initialize context\n", __func__);
992
- WSP_GGML_ASSERT(false);
993
- }
994
-
995
- // pass 1: assign backends to ops with pre-allocated inputs
996
- for (int i = 0; i < graph->n_leafs; i++) {
997
- struct wsp_ggml_tensor * leaf = graph->leafs[i];
998
- if (node_allocr(leaf) != NULL) {
999
- // do not overwrite user assignments
1000
- continue;
1001
- }
1002
- node_allocr(leaf) = sched_allocr_from_cur(sched, leaf);
1003
- }
1004
-
1005
- for (int i = 0; i < graph->n_nodes; i++) {
1006
- struct wsp_ggml_tensor * node = graph->nodes[i];
1007
- if (node_allocr(node) != NULL) {
1008
- // do not overwrite user assignments
1009
- continue;
1010
- }
1011
- node_allocr(node) = sched_allocr_from_cur(sched, node);
1012
- // src
1013
- for (int j = 0; j < WSP_GGML_MAX_SRC; j++) {
1014
- struct wsp_ggml_tensor * src = node->src[j];
1015
- if (src == NULL) {
1016
- break;
1017
- }
1018
- if (node_allocr(src) == NULL) {
1019
- node_allocr(src) = sched_allocr_from_cur(sched, src);
1020
- }
1021
- }
1022
- }
1023
- #ifdef DEBUG_PASS1
1024
- fprintf(stderr, "PASS 1 ASSIGNMENTS\n"); sched_print_assignments(sched, graph);
1025
- #endif
1026
-
-     // pass 2: expand current backend assignments
-     // assign the same backend to adjacent nodes
-     // expand gpu backends (i.e. non last prio) up and down, ignoring cpu (the lowest priority backend)
-     // thus, cpu will never be used unless weights are on cpu, or there are no gpu ops between cpu ops
-
-     // pass 2.1 expand gpu up
-     {
-         wsp_ggml_tallocr_t cur_allocr = NULL;
-         for (int i = graph->n_nodes - 1; i >= 0; i--) {
-             struct wsp_ggml_tensor * node = graph->nodes[i];
-             if (wsp_ggml_is_view_op(node->op)) {
-                 continue;
-             }
-             wsp_ggml_tallocr_t node_allocr = node_allocr(node);
-             if (node_allocr != NULL) {
-                 if (sched_allocr_prio(sched, node_allocr) == sched->n_backends - 1) {
-                     // skip cpu (lowest prio backend)
-                     cur_allocr = NULL;
-                 } else {
-                     cur_allocr = node_allocr;
-                 }
-             } else {
-                 node_allocr(node) = cur_allocr;
-                 SET_CAUSE(node, "2.1");
-             }
-         }
-     }
-
-     // pass 2.2 expand gpu down
-     {
-         wsp_ggml_tallocr_t cur_allocr = NULL;
-         for (int i = 0; i < graph->n_nodes; i++) {
-             struct wsp_ggml_tensor * node = graph->nodes[i];
-             if (wsp_ggml_is_view_op(node->op)) {
-                 continue;
-             }
-             wsp_ggml_tallocr_t node_allocr = node_allocr(node);
-             if (node_allocr != NULL) {
-                 if (sched_allocr_prio(sched, node_allocr) == sched->n_backends - 1) {
-                     // skip cpu (lowest prio backend)
-                     cur_allocr = NULL;
-                 } else {
-                     cur_allocr = node_allocr;
-                 }
-             } else {
-                 node_allocr(node) = cur_allocr;
-                 SET_CAUSE(node, "2.2");
-             }
-         }
-     }
-
-     // pass 2.3 expand rest up
-     {
-         wsp_ggml_tallocr_t cur_allocr = NULL;
-         for (int i = graph->n_nodes - 1; i >= 0; i--) {
-             struct wsp_ggml_tensor * node = graph->nodes[i];
-             if (wsp_ggml_is_view_op(node->op)) {
-                 continue;
-             }
-             wsp_ggml_tallocr_t node_allocr = node_allocr(node);
-             if (node_allocr != NULL) {
-                 cur_allocr = node_allocr;
-             } else {
-                 node_allocr(node) = cur_allocr;
-                 SET_CAUSE(node, "2.3");
-             }
-         }
-     }
-
-     // pass 2.4 expand rest down
-     {
-         wsp_ggml_tallocr_t cur_allocr = NULL;
-         for (int i = 0; i < graph->n_nodes; i++) {
-             struct wsp_ggml_tensor * node = graph->nodes[i];
-             if (wsp_ggml_is_view_op(node->op)) {
-                 continue;
-             }
-             wsp_ggml_tallocr_t node_allocr = node_allocr(node);
-             if (node_allocr != NULL) {
-                 cur_allocr = node_allocr;
-             } else {
-                 node_allocr(node) = cur_allocr;
-                 SET_CAUSE(node, "2.4");
-             }
-         }
-     }
- #ifdef DEBUG_PASS2
-     fprintf(stderr, "PASS 2 ASSIGNMENTS\n"); sched_print_assignments(sched, graph);
- #endif
-
-     // pass 3: assign backends to remaining src from dst and view_src
-     for (int i = 0; i < graph->n_nodes; i++) {
-         struct wsp_ggml_tensor * node = graph->nodes[i];
-         wsp_ggml_tallocr_t cur_allocr = node_allocr(node);
-         if (node->view_src != NULL && cur_allocr == NULL) {
-             cur_allocr = node_allocr(node) = node_allocr(node->view_src);
-             SET_CAUSE(node, "3.vsrc");
-         }
-         for (int j = 0; j < WSP_GGML_MAX_SRC; j++) {
-             struct wsp_ggml_tensor * src = node->src[j];
-             if (src == NULL) {
-                 break;
-             }
-             wsp_ggml_tallocr_t src_allocr = node_allocr(src);
-             if (src_allocr == NULL) {
-                 if (src->view_src != NULL) {
-                     // views are always on the same backend as the source
-                     node_allocr(src) = node_allocr(src->view_src);
-                     SET_CAUSE(src, "3.vsrc");
-                 } else {
-                     node_allocr(src) = cur_allocr;
-                     SET_CAUSE(src, "3.cur");
-                 }
-             }
-         }
-     }
- #ifdef DEBUG_PASS3
-     fprintf(stderr, "PASS 3 ASSIGNMENTS\n"); sched_print_assignments(sched, graph);
- #endif
-
-     // pass 4: split graph, find tensors that need to be copied
-     {
-         int cur_split = 0;
-         // find the backend of the first split, skipping view ops
-         for (int i = 0; i < graph->n_nodes; i++) {
-             struct wsp_ggml_tensor * node = graph->nodes[i];
-             if (!wsp_ggml_is_view_op(node->op)) {
-                 sched->splits[0].tallocr = node_allocr(node);
-                 break;
-             }
-         }
-         sched->splits[0].i_start = 0;
-         sched->splits[0].n_inputs = 0;
-         memset(sched->splits[0].inputs, 0, sizeof(sched->splits[0].inputs)); //HACK
-         wsp_ggml_tallocr_t cur_allocr = sched->splits[0].tallocr;
-         size_t cur_backend_id = sched_allocr_prio(sched, cur_allocr);
-         for (int i = 0; i < graph->n_nodes; i++) {
-             struct wsp_ggml_tensor * node = graph->nodes[i];
-
-             if (wsp_ggml_is_view_op(node->op)) {
-                 continue;
-             }
-
-             wsp_ggml_tallocr_t node_allocr = node_allocr(node);
-
-             WSP_GGML_ASSERT(node_allocr != NULL); // all nodes should be assigned by now
-
-             if (node_allocr != cur_allocr) {
-                 sched->splits[cur_split].i_end = i;
-                 cur_split++;
-                 WSP_GGML_ASSERT(cur_split < WSP_GGML_MAX_SPLITS);
-                 sched->splits[cur_split].tallocr = node_allocr;
-                 sched->splits[cur_split].i_start = i;
-                 sched->splits[cur_split].n_inputs = 0;
-                 cur_allocr = node_allocr;
-                 cur_backend_id = sched_allocr_prio(sched, cur_allocr);
-             }
-
-             // find inputs that are not on the same backend
-             for (int j = 0; j < WSP_GGML_MAX_SRC; j++) {
-                 struct wsp_ggml_tensor * src = node->src[j];
-                 if (src == NULL) {
-                     break;
-                 }
-                 wsp_ggml_tallocr_t src_allocr = node_allocr(src);
-                 WSP_GGML_ASSERT(src_allocr != NULL); // all inputs should be assigned by now
-                 if (src_allocr != node_allocr) {
-                     // check if the input is already in the split
-                     bool found = false;
-                     for (int k = 0; k < sched->splits[cur_split].n_inputs; k++) {
-                         if (sched->splits[cur_split].inputs[k] == src) {
-                             found = true;
-                             break;
-                         }
-                     }
-
-                     if (!found) {
-                         int n_inputs = sched->splits[cur_split].n_inputs++;
-                         //printf("split %d input %d: %s (%s)\n", cur_split, n_inputs, src->name, wsp_ggml_backend_name(get_allocr_backend(sched, src_allocr)));
-                         WSP_GGML_ASSERT(n_inputs < WSP_GGML_MAX_SPLIT_INPUTS);
-                         sched->splits[cur_split].inputs[n_inputs] = src;
-                     }
-
-                     // create a copy of the input in the split's backend
-                     size_t id = hash_id(src);
-                     if (sched->node_copies[id][cur_backend_id] == NULL) {
-                         wsp_ggml_backend_t backend = get_allocr_backend(sched, cur_allocr);
-                         struct wsp_ggml_tensor * tensor_copy = wsp_ggml_dup_tensor_layout(sched->ctx, src);
-                         wsp_ggml_format_name(tensor_copy, "%s#%s", wsp_ggml_backend_name(backend), src->name);
-
-                         sched->node_copies[id][cur_backend_id] = tensor_copy;
-                         node_allocr(tensor_copy) = cur_allocr;
-                         SET_CAUSE(tensor_copy, "4.cpy");
-                     }
-                     node->src[j] = sched->node_copies[id][cur_backend_id];
-                 }
-             }
-         }
-         sched->splits[cur_split].i_end = graph->n_nodes;
-         sched->n_splits = cur_split + 1;
-     }
- #ifdef DEBUG_PASS4
-     fprintf(stderr, "PASS 4 ASSIGNMENTS\n"); sched_print_assignments(sched, graph);
- #endif
-
- #ifndef NDEBUG
-     // sanity check: all sources should have the same backend as the node
-     for (int i = 0; i < graph->n_nodes; i++) {
-         struct wsp_ggml_tensor * node = graph->nodes[i];
-         wsp_ggml_tallocr_t node_allocr = node_allocr(node);
-         if (node_allocr == NULL) {
-             fprintf(stderr, "!!!!!!! %s has no backend\n", node->name);
-         }
-         if (node->view_src != NULL && node_allocr != node_allocr(node->view_src)) {
-             fprintf(stderr, "!!!!!!! %s has backend %s, view_src %s has backend %s\n",
-                 node->name, node_allocr ? wsp_ggml_backend_name(get_allocr_backend(sched, node_allocr)) : "NULL",
-                 node->view_src->name, node_allocr(node->view_src) ? wsp_ggml_backend_name(get_allocr_backend(sched, node_allocr(node->view_src))) : "NULL");
-         }
-         for (int j = 0; j < WSP_GGML_MAX_SRC; j++) {
-             struct wsp_ggml_tensor * src = node->src[j];
-             if (src == NULL) {
-                 break;
-             }
-             wsp_ggml_tallocr_t src_allocr = node_allocr(src);
-             if (src_allocr != node_allocr /* && src_backend != NULL */) { // ignore nulls for now
-                 fprintf(stderr, "!!!! %s has backend %s, src %d (%s) has backend %s\n",
-                     node->name, node_allocr ? wsp_ggml_backend_name(get_allocr_backend(sched, node_allocr)) : "NULL",
-                     j, src->name, src_allocr ? wsp_ggml_backend_name(get_allocr_backend(sched, src_allocr)) : "NULL");
-             }
-             if (src->view_src != NULL && src_allocr != node_allocr(src->view_src)) {
-                 fprintf(stderr, "!!!!!!! [src] %s has backend %s, view_src %s has backend %s\n",
-                     src->name, src_allocr ? wsp_ggml_backend_name(get_allocr_backend(sched, src_allocr)) : "NULL",
-                     src->view_src->name, node_allocr(src->view_src) ? wsp_ggml_backend_name(get_allocr_backend(sched, node_allocr(src->view_src))) : "NULL");
-             }
-         }
-     }
-     fflush(stderr);
- #endif
-
-     // create copies of the graph for each split
-     // FIXME: avoid this copy, pass split inputs to wsp_ggml_gallocr_alloc_graph_n in some other way
-     struct wsp_ggml_cgraph * graph_copy = wsp_ggml_new_graph_custom(sched->ctx, graph->n_nodes + sched->n_splits*WSP_GGML_MAX_SPLIT_INPUTS, false);
-     for (int i = 0; i < sched->n_splits; i++) {
-         struct wsp_ggml_backend_sched_split * split = &sched->splits[i];
-         split->graph = wsp_ggml_graph_view(graph, split->i_start, split->i_end);
-
-         // add inputs to the graph copy so that they are allocated by ggml-alloc at the start of the split
-         for (int j = 0; j < split->n_inputs; j++) {
-             struct wsp_ggml_tensor * input = split->inputs[j];
-             struct wsp_ggml_tensor * input_cpy = sched->node_copies[hash_id(input)][sched_allocr_prio(sched, split->tallocr)];
-             // add a dependency to the input source so that it is not freed before the copy is done
-             WSP_GGML_ASSERT(input_cpy->src[0] == NULL || input_cpy->src[0] == input);
-             input_cpy->src[0] = input;
-             graph_copy->nodes[graph_copy->n_nodes++] = input_cpy;
-         }
-
-         for (int j = split->i_start; j < split->i_end; j++) {
-             graph_copy->nodes[graph_copy->n_nodes++] = graph->nodes[j];
-         }
-     }
-     sched->graph = graph_copy;
- }
-
- static void sched_alloc_splits(wsp_ggml_backend_sched_t sched) {
-     wsp_ggml_gallocr_alloc_graph_n(
-         sched->galloc,
-         sched->graph,
-         sched->hash_set,
-         sched->node_talloc);
- }
-
- static void sched_compute_splits(wsp_ggml_backend_sched_t sched) {
-     uint64_t copy_us[WSP_GGML_MAX_BACKENDS] = {0};
-     uint64_t compute_us[WSP_GGML_MAX_BACKENDS] = {0};
-
-     struct wsp_ggml_backend_sched_split * splits = sched->splits;
-
-     for (int i = 0; i < sched->n_splits; i++) {
-         struct wsp_ggml_backend_sched_split * split = &splits[i];
-         wsp_ggml_backend_t split_backend = get_allocr_backend(sched, split->tallocr);
-         int split_backend_id = sched_backend_prio(sched, split_backend);
-
-         // copy the input tensors to the split backend
-         uint64_t copy_start_us = wsp_ggml_time_us();
-         for (int j = 0; j < split->n_inputs; j++) {
-             struct wsp_ggml_tensor * input = split->inputs[j];
-             struct wsp_ggml_tensor * input_cpy = sched->node_copies[hash_id(input)][split_backend_id];
-
-             WSP_GGML_ASSERT(input->buffer != NULL);
-             WSP_GGML_ASSERT(input_cpy->buffer != NULL);
-
-             // TODO: avoid this copy if it was already copied in a previous split, and the input didn't change
-             // this is important to avoid copying constants such as KQ_mask and inp_pos multiple times
-             wsp_ggml_backend_tensor_copy_async(split_backend, input, input_cpy);
-         }
-         //wsp_ggml_backend_synchronize(split_backend); // necessary to measure copy time
-         int64_t copy_end_us = wsp_ggml_time_us();
-         copy_us[split_backend_id] += copy_end_us - copy_start_us;
-
- #if 0
-         char split_filename[WSP_GGML_MAX_NAME];
-         snprintf(split_filename, WSP_GGML_MAX_NAME, "split_%i_%s.dot", i, wsp_ggml_backend_name(split_backend));
-         wsp_ggml_graph_dump_dot(split->graph, NULL, split_filename);
- #endif
-
-
-         uint64_t compute_start_us = wsp_ggml_time_us();
-         if (!sched->callback_eval) {
-             wsp_ggml_backend_graph_compute(split_backend, &split->graph);
-             //wsp_ggml_backend_synchronize(split_backend); // necessary to measure compute time
-         } else {
-             // similar to wsp_ggml_backend_compare_graph_backend
-             for (int j0 = 0; j0 < split->graph.n_nodes; j0++) {
-                 struct wsp_ggml_tensor * t = split->graph.nodes[j0];
-
-                 // check if the user needs data from this node
-                 bool need = sched->callback_eval(t, true, sched->callback_eval_user_data);
-
-                 int j1 = j0;
-
-                 // determine the range [j0, j1] of nodes that can be computed together
-                 while (!need && j1 < split->graph.n_nodes - 1) {
-                     t = split->graph.nodes[++j1];
-                     need = sched->callback_eval(t, true, sched->callback_eval_user_data);
-                 }
-
-                 struct wsp_ggml_cgraph gv = wsp_ggml_graph_view(&split->graph, j0, j1 + 1);
-
-                 wsp_ggml_backend_graph_compute(split_backend, &gv);
-
-                 if (need && !sched->callback_eval(t, false, sched->callback_eval_user_data)) {
-                     break;
-                 }
-
-                 j0 = j1;
-             }
-         }
-         uint64_t compute_end_us = wsp_ggml_time_us();
-         compute_us[split_backend_id] += compute_end_us - compute_start_us;
-     }
-
- #if 0
-     // per-backend timings
-     fprintf(stderr, "sched_compute_splits times (%d splits):\n", sched->n_splits);
-     for (int i = 0; i < sched->n_backends; i++) {
-         if (copy_us[i] > 0 || compute_us[i] > 0) {
-             fprintf(stderr, "\t%5.5s: %lu us copy, %lu us compute\n", wsp_ggml_backend_name(sched->backends[i]), copy_us[i], compute_us[i]);
-         }
-     }
- #endif
- }
-
- static void sched_reset(wsp_ggml_backend_sched_t sched) {
-     for (int i = 0; i < sched->n_backends; i++) {
-         wsp_ggml_tallocr_reset(sched->tallocs[i]);
-     }
-     // reset state for the next run
-     size_t hash_size = sched->hash_set.size;
-     memset(sched->hash_set.keys, 0, sizeof(sched->hash_set.keys[0]) * hash_size);
-     memset(sched->node_talloc, 0, sizeof(sched->node_talloc[0]) * hash_size);
-     memset(sched->node_copies, 0, sizeof(sched->node_copies[0]) * hash_size);
-
-     sched->is_reset = true;
- }
-
- wsp_ggml_backend_sched_t wsp_ggml_backend_sched_new(wsp_ggml_backend_t * backends, wsp_ggml_backend_buffer_type_t * bufts, int n_backends, size_t graph_size) {
-     WSP_GGML_ASSERT(n_backends > 0);
-     WSP_GGML_ASSERT(n_backends <= WSP_GGML_MAX_BACKENDS);
-
-     struct wsp_ggml_backend_sched * sched = calloc(sizeof(struct wsp_ggml_backend_sched), 1);
-
-     // initialize hash table
-     sched->hash_set = wsp_ggml_hash_set_new(graph_size + WSP_GGML_MAX_SPLITS*WSP_GGML_MAX_SPLIT_INPUTS);
-     sched->node_talloc = calloc(sizeof(sched->node_talloc[0]) * sched->hash_set.size, 1);
-     sched->node_copies = calloc(sizeof(sched->node_copies[0]) * sched->hash_set.size, 1);
-
-     sched->n_backends = n_backends;
-     for (int i = 0; i < n_backends; i++) {
-         sched->backends[i] = backends[i];
-         sched->bufts[i] = bufts ? bufts[i] : wsp_ggml_backend_get_default_buffer_type(backends[i]);
-     }
-
-     sched->galloc = wsp_ggml_gallocr_new();
-
-     // init measure allocs for each backend
-     for (int i = 0; i < n_backends; i++) {
-         sched->tallocs[i] = wsp_ggml_tallocr_new_measure_from_buft(sched->bufts[i]);
-     }
-
-     sched_reset(sched);
-
-     return sched;
- }
-
- void wsp_ggml_backend_sched_free(wsp_ggml_backend_sched_t sched) {
-     if (sched == NULL) {
-         return;
-     }
-     for (int i = 0; i < sched->n_backends; i++) {
-         wsp_ggml_tallocr_free(sched->tallocs[i]);
-     }
-     wsp_ggml_gallocr_free(sched->galloc);
-     wsp_ggml_free(sched->ctx);
-     free(sched->hash_set.keys);
-     free(sched->node_talloc);
-     free(sched->node_copies);
-     free(sched);
- }
-
- void wsp_ggml_backend_sched_init_measure(wsp_ggml_backend_sched_t sched, struct wsp_ggml_cgraph * measure_graph) {
-     WSP_GGML_ASSERT(wsp_ggml_tallocr_is_measure(sched->tallocs[0])); // can only be initialized once
-
-     sched_split_graph(sched, measure_graph);
-     sched_alloc_splits(sched);
-
-     // allocate buffers and reset allocators
-     for (int i = 0; i < sched->n_backends; i++) {
-         size_t size = wsp_ggml_tallocr_max_size(sched->tallocs[i]);
-         wsp_ggml_tallocr_free(sched->tallocs[i]);
-         sched->tallocs[i] = wsp_ggml_tallocr_new_from_buft(sched->bufts[i], size);
-     }
-
-     sched_reset(sched);
- }
-
- void wsp_ggml_backend_sched_graph_compute(wsp_ggml_backend_sched_t sched, struct wsp_ggml_cgraph * graph) {
-     WSP_GGML_ASSERT((int)sched->hash_set.size >= graph->n_nodes + WSP_GGML_MAX_SPLITS*WSP_GGML_MAX_SPLIT_INPUTS);
-
-     if (!sched->is_reset) {
-         sched_reset(sched);
-     }
-
-     sched_split_graph(sched, graph);
-     sched_alloc_splits(sched);
-     sched_compute_splits(sched);
- }
-
- void wsp_ggml_backend_sched_reset(wsp_ggml_backend_sched_t sched) {
-     sched_reset(sched);
- }
-
-
- void wsp_ggml_backend_sched_set_eval_callback(wsp_ggml_backend_sched_t sched, wsp_ggml_backend_sched_eval_callback callback, void * user_data) {
-     sched->callback_eval = callback;
-     sched->callback_eval_user_data = user_data;
- }
-
- int wsp_ggml_backend_sched_get_n_splits(wsp_ggml_backend_sched_t sched) {
-     return sched->n_splits;
- }
-
- wsp_ggml_tallocr_t wsp_ggml_backend_sched_get_tallocr(wsp_ggml_backend_sched_t sched, wsp_ggml_backend_t backend) {
-     int backend_index = sched_backend_prio(sched, backend);
-     WSP_GGML_ASSERT(backend_index >= 0 && backend_index < sched->n_backends);
-     return sched->tallocs[backend_index];
- }
-
- wsp_ggml_backend_buffer_t wsp_ggml_backend_sched_get_buffer(wsp_ggml_backend_sched_t sched, wsp_ggml_backend_t backend) {
-     int backend_index = sched_backend_prio(sched, backend);
-     WSP_GGML_ASSERT(backend_index >= 0 && backend_index < sched->n_backends);
-     return wsp_ggml_tallocr_get_buffer(sched->tallocs[backend_index]);
- }
-
- void wsp_ggml_backend_sched_set_node_backend(wsp_ggml_backend_sched_t sched, struct wsp_ggml_tensor * node, wsp_ggml_backend_t backend) {
-     int backend_index = sched_backend_prio(sched, backend);
-     WSP_GGML_ASSERT(backend_index >= 0 && backend_index < sched->n_backends);
-     node_allocr(node) = sched->tallocs[backend_index];
- }
-
- wsp_ggml_backend_t wsp_ggml_backend_sched_get_node_backend(wsp_ggml_backend_sched_t sched, struct wsp_ggml_tensor * node) {
-     wsp_ggml_tallocr_t allocr = node_allocr(node);
-     if (allocr == NULL) {
-         return NULL;
-     }
-     return get_allocr_backend(sched, allocr);
- }
-
- // utils
-
- void wsp_ggml_backend_view_init(wsp_ggml_backend_buffer_t buffer, struct wsp_ggml_tensor * tensor) {
-     WSP_GGML_ASSERT(tensor->buffer == NULL);
-     //WSP_GGML_ASSERT(tensor->data == NULL); // views of pre-allocated tensors may have the data set in wsp_ggml_new_tensor, but still need to be initialized by the backend
-     WSP_GGML_ASSERT(tensor->view_src != NULL);
-     WSP_GGML_ASSERT(tensor->view_src->buffer != NULL);
-     WSP_GGML_ASSERT(tensor->view_src->data != NULL);
-
-     tensor->buffer = buffer;
-     tensor->data = (char *)tensor->view_src->data + tensor->view_offs;
-     tensor->backend = tensor->view_src->backend;
-     wsp_ggml_backend_buffer_init_tensor(buffer, tensor);
- }
-
- void wsp_ggml_backend_tensor_alloc(wsp_ggml_backend_buffer_t buffer, struct wsp_ggml_tensor * tensor, void * addr) {
-     WSP_GGML_ASSERT(tensor->buffer == NULL);
-     WSP_GGML_ASSERT(tensor->data == NULL);
-     WSP_GGML_ASSERT(tensor->view_src == NULL);
-     WSP_GGML_ASSERT(addr >= wsp_ggml_backend_buffer_get_base(buffer));
-     WSP_GGML_ASSERT((char *)addr + wsp_ggml_backend_buffer_get_alloc_size(buffer, tensor) <=
-                     (char *)wsp_ggml_backend_buffer_get_base(buffer) + wsp_ggml_backend_buffer_get_size(buffer));
-
-     tensor->buffer = buffer;
-     tensor->data = addr;
-     wsp_ggml_backend_buffer_init_tensor(buffer, tensor);
- }
-
- static struct wsp_ggml_tensor * graph_dup_tensor(struct wsp_ggml_hash_set hash_set, struct wsp_ggml_tensor ** node_copies,
-         struct wsp_ggml_context * ctx_allocated, struct wsp_ggml_context * ctx_unallocated, struct wsp_ggml_tensor * src) {
-
-     WSP_GGML_ASSERT(src != NULL);
-     WSP_GGML_ASSERT(src->data && "graph must be allocated");
-
-     size_t id = wsp_ggml_hash_insert(hash_set, src);
-     if (id == WSP_GGML_HASHTABLE_ALREADY_EXISTS) {
-         return node_copies[wsp_ggml_hash_find(hash_set, src)];
-     }
-
-     struct wsp_ggml_tensor * dst = wsp_ggml_dup_tensor_layout(src->data && !src->view_src ? ctx_allocated : ctx_unallocated, src);
-     if (src->view_src != NULL) {
-         dst->view_src = graph_dup_tensor(hash_set, node_copies, ctx_allocated, ctx_unallocated, src->view_src);
-         dst->view_offs = src->view_offs;
-     }
-     dst->op = src->op;
-     memcpy(dst->op_params, src->op_params, sizeof(dst->op_params));
-     wsp_ggml_set_name(dst, src->name);
-
-     // copy src
-     for (int i = 0; i < WSP_GGML_MAX_SRC; i++) {
-         struct wsp_ggml_tensor * s = src->src[i];
-         if (s == NULL) {
-             break;
-         }
-         dst->src[i] = graph_dup_tensor(hash_set, node_copies, ctx_allocated, ctx_unallocated, s);
-     }
-
-     node_copies[id] = dst;
-     return dst;
- }
-
- static void graph_init_tensor(struct wsp_ggml_hash_set hash_set, struct wsp_ggml_tensor ** node_copies, bool * node_init, struct wsp_ggml_tensor * src) {
-     size_t id = wsp_ggml_hash_find(hash_set, src);
-     if (node_init[id]) {
-         return;
-     }
-     node_init[id] = true;
-
-     struct wsp_ggml_tensor * dst = node_copies[id];
-     if (dst->view_src != NULL) {
-         graph_init_tensor(hash_set, node_copies, node_init, src->view_src);
-         wsp_ggml_backend_view_init(dst->view_src->buffer, dst);
-     }
-     else {
-         wsp_ggml_backend_tensor_copy(src, dst);
-     }
-
-     // init src
-     for (int i = 0; i < WSP_GGML_MAX_SRC; i++) {
-         struct wsp_ggml_tensor * s = src->src[i];
-         if (s == NULL) {
-             break;
-         }
-         graph_init_tensor(hash_set, node_copies, node_init, s);
-     }
- }
-
- struct wsp_ggml_backend_graph_copy wsp_ggml_backend_graph_copy(wsp_ggml_backend_t backend, struct wsp_ggml_cgraph * graph) {
-     struct wsp_ggml_hash_set hash_set = {
-         /* .size = */ graph->visited_hash_table.size,
-         /* .keys = */ calloc(sizeof(hash_set.keys[0]) * graph->visited_hash_table.size, 1)
-     };
-     struct wsp_ggml_tensor ** node_copies = calloc(sizeof(node_copies[0]) * hash_set.size, 1);
-     bool * node_init = calloc(sizeof(node_init[0]) * hash_set.size, 1);
-
-     struct wsp_ggml_init_params params = {
-         /* .mem_size = */ wsp_ggml_tensor_overhead()*hash_set.size + wsp_ggml_graph_overhead_custom(graph->size, false),
-         /* .mem_buffer = */ NULL,
-         /* .no_alloc = */ true
-     };
-
-     struct wsp_ggml_context * ctx_allocated = wsp_ggml_init(params);
-     struct wsp_ggml_context * ctx_unallocated = wsp_ggml_init(params);
-
-     if (ctx_allocated == NULL || ctx_unallocated == NULL) {
-         fprintf(stderr, "failed to allocate context for graph copy\n");
-         free(hash_set.keys);
-         free(node_copies);
-         free(node_init);
-         wsp_ggml_free(ctx_allocated);
-         wsp_ggml_free(ctx_unallocated);
-         return (struct wsp_ggml_backend_graph_copy) {
-             /* .buffer = */ NULL,
-             /* .ctx_allocated = */ NULL,
-             /* .ctx_unallocated = */ NULL,
-             /* .graph = */ NULL,
-         };
-     }
-
-     // dup nodes
-     for (int i = 0; i < graph->n_nodes; i++) {
-         struct wsp_ggml_tensor * node = graph->nodes[i];
-         graph_dup_tensor(hash_set, node_copies, ctx_allocated, ctx_unallocated, node);
-     }
-
-     // allocate nodes
-     wsp_ggml_backend_buffer_t buffer = wsp_ggml_backend_alloc_ctx_tensors(ctx_allocated, backend);
-     if (buffer == NULL) {
-         fprintf(stderr, "failed to allocate buffer for graph copy\n");
-         free(hash_set.keys);
-         free(node_copies);
-         free(node_init);
-         wsp_ggml_free(ctx_allocated);
-         wsp_ggml_free(ctx_unallocated);
-         return (struct wsp_ggml_backend_graph_copy) {
-             /* .buffer = */ NULL,
-             /* .ctx_allocated = */ NULL,
-             /* .ctx_unallocated = */ NULL,
-             /* .graph = */ NULL,
-         };
-     }
-
-     //printf("copy buffer size: %zu MB\n", wsp_ggml_backend_buffer_get_size(buffer) / 1024 / 1024);
-
-     // copy data and init views
-     for (int i = 0; i < graph->n_nodes; i++) {
-         struct wsp_ggml_tensor * node = graph->nodes[i];
-         graph_init_tensor(hash_set, node_copies, node_init, node);
-     }
-
-     // build graph copy
-     struct wsp_ggml_cgraph * graph_copy = wsp_ggml_new_graph_custom(ctx_allocated, graph->size, false);
-     for (int i = 0; i < graph->n_nodes; i++) {
-         struct wsp_ggml_tensor * node = graph->nodes[i];
-         struct wsp_ggml_tensor * node_copy = node_copies[wsp_ggml_hash_find(hash_set, node)];
-         graph_copy->nodes[i] = node_copy;
-     }
-     graph_copy->n_nodes = graph->n_nodes;
-
-     free(hash_set.keys);
-     free(node_copies);
-     free(node_init);
-
-     return (struct wsp_ggml_backend_graph_copy) {
-         /* .buffer = */ buffer,
-         /* .ctx_allocated = */ ctx_allocated,
-         /* .ctx_unallocated = */ ctx_unallocated,
-         /* .graph = */ graph_copy,
-     };
- }
-
- void wsp_ggml_backend_graph_copy_free(struct wsp_ggml_backend_graph_copy copy) {
-     wsp_ggml_backend_buffer_free(copy.buffer);
-     wsp_ggml_free(copy.ctx_allocated);
-     wsp_ggml_free(copy.ctx_unallocated);
- }
-
- bool wsp_ggml_backend_compare_graph_backend(wsp_ggml_backend_t backend1, wsp_ggml_backend_t backend2, struct wsp_ggml_cgraph * graph, wsp_ggml_backend_eval_callback callback, void * user_data) {
-     struct wsp_ggml_backend_graph_copy copy = wsp_ggml_backend_graph_copy(backend2, graph);
-     if (copy.buffer == NULL) {
-         return false;
-     }
-
-     struct wsp_ggml_cgraph * g1 = graph;
-     struct wsp_ggml_cgraph * g2 = copy.graph;
-
-     assert(g1->n_nodes == g2->n_nodes);
-
-     for (int i = 0; i < g1->n_nodes; i++) {
-         //printf("eval %d/%d\n", i, g1->n_nodes);
-         struct wsp_ggml_tensor * t1 = g1->nodes[i];
-         struct wsp_ggml_tensor * t2 = g2->nodes[i];
-
-         assert(t1->op == t2->op && wsp_ggml_are_same_layout(t1, t2));
-
-         struct wsp_ggml_cgraph g1v = wsp_ggml_graph_view(g1, i, i + 1);
-         struct wsp_ggml_cgraph g2v = wsp_ggml_graph_view(g2, i, i + 1);
-
-         wsp_ggml_backend_graph_compute(backend1, &g1v);
-         wsp_ggml_backend_graph_compute(backend2, &g2v);
-
-         if (wsp_ggml_is_view_op(t1->op)) {
-             continue;
-         }
-
-         // compare results, calculate rms etc
-         if (!callback(i, t1, t2, user_data)) {
-             break;
-         }
-     }
-
-     wsp_ggml_backend_graph_copy_free(copy);
-
-     return true;
- }