whisper.rn 0.4.0-rc.8 → 0.4.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (201)
  1. package/README.md +5 -1
  2. package/android/build.gradle +12 -3
  3. package/android/src/main/CMakeLists.txt +44 -13
  4. package/android/src/main/java/com/rnwhisper/AudioUtils.java +27 -12
  5. package/android/src/main/java/com/rnwhisper/RNWhisper.java +75 -34
  6. package/android/src/main/java/com/rnwhisper/WhisperContext.java +53 -38
  7. package/android/src/main/jni.cpp +38 -1
  8. package/android/src/main/jniLibs/arm64-v8a/librnwhisper.so +0 -0
  9. package/android/src/main/jniLibs/arm64-v8a/librnwhisper_v8fp16_va_2.so +0 -0
  10. package/android/src/main/jniLibs/armeabi-v7a/librnwhisper.so +0 -0
  11. package/android/src/main/jniLibs/armeabi-v7a/librnwhisper_vfpv4.so +0 -0
  12. package/android/src/main/jniLibs/x86_64/librnwhisper.so +0 -0
  13. package/android/src/main/jniLibs/x86_64/librnwhisper_x86_64.so +0 -0
  14. package/android/src/newarch/java/com/rnwhisper/RNWhisperModule.java +10 -0
  15. package/android/src/oldarch/java/com/rnwhisper/RNWhisperModule.java +10 -0
  16. package/cpp/coreml/whisper-compat.h +10 -0
  17. package/cpp/coreml/whisper-compat.m +35 -0
  18. package/cpp/coreml/whisper-decoder-impl.h +27 -15
  19. package/cpp/coreml/whisper-decoder-impl.m +36 -10
  20. package/cpp/coreml/whisper-encoder-impl.h +21 -9
  21. package/cpp/coreml/whisper-encoder-impl.m +29 -3
  22. package/cpp/ggml-alloc.c +727 -517
  23. package/cpp/ggml-alloc.h +47 -65
  24. package/cpp/ggml-backend-impl.h +196 -57
  25. package/cpp/ggml-backend-reg.cpp +591 -0
  26. package/cpp/ggml-backend.cpp +2016 -0
  27. package/cpp/ggml-backend.h +234 -89
  28. package/cpp/ggml-common.h +1861 -0
  29. package/cpp/ggml-cpp.h +39 -0
  30. package/cpp/ggml-cpu/amx/amx.cpp +221 -0
  31. package/cpp/ggml-cpu/amx/amx.h +8 -0
  32. package/cpp/ggml-cpu/amx/common.h +91 -0
  33. package/cpp/ggml-cpu/amx/mmq.cpp +2511 -0
  34. package/cpp/ggml-cpu/amx/mmq.h +10 -0
  35. package/cpp/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
  36. package/cpp/ggml-cpu/arch/arm/quants.c +4113 -0
  37. package/cpp/ggml-cpu/arch/arm/repack.cpp +2162 -0
  38. package/cpp/ggml-cpu/arch/x86/cpu-feats.cpp +327 -0
  39. package/cpp/ggml-cpu/arch/x86/quants.c +4310 -0
  40. package/cpp/ggml-cpu/arch/x86/repack.cpp +3284 -0
  41. package/cpp/ggml-cpu/arch-fallback.h +184 -0
  42. package/cpp/ggml-cpu/binary-ops.cpp +158 -0
  43. package/cpp/ggml-cpu/binary-ops.h +16 -0
  44. package/cpp/ggml-cpu/common.h +72 -0
  45. package/cpp/ggml-cpu/ggml-cpu-impl.h +511 -0
  46. package/cpp/ggml-cpu/ggml-cpu.c +3473 -0
  47. package/cpp/ggml-cpu/ggml-cpu.cpp +671 -0
  48. package/cpp/ggml-cpu/ops.cpp +9085 -0
  49. package/cpp/ggml-cpu/ops.h +111 -0
  50. package/cpp/ggml-cpu/quants.c +1157 -0
  51. package/cpp/ggml-cpu/quants.h +89 -0
  52. package/cpp/ggml-cpu/repack.cpp +1570 -0
  53. package/cpp/ggml-cpu/repack.h +98 -0
  54. package/cpp/ggml-cpu/simd-mappings.h +1006 -0
  55. package/cpp/ggml-cpu/traits.cpp +36 -0
  56. package/cpp/ggml-cpu/traits.h +38 -0
  57. package/cpp/ggml-cpu/unary-ops.cpp +186 -0
  58. package/cpp/ggml-cpu/unary-ops.h +28 -0
  59. package/cpp/ggml-cpu/vec.cpp +321 -0
  60. package/cpp/ggml-cpu/vec.h +973 -0
  61. package/cpp/ggml-cpu.h +143 -0
  62. package/cpp/ggml-impl.h +525 -168
  63. package/cpp/ggml-metal-impl.h +622 -0
  64. package/cpp/ggml-metal.h +16 -14
  65. package/cpp/ggml-metal.m +5289 -1859
  66. package/cpp/ggml-opt.cpp +1037 -0
  67. package/cpp/ggml-opt.h +237 -0
  68. package/cpp/ggml-quants.c +2916 -6877
  69. package/cpp/ggml-quants.h +87 -249
  70. package/cpp/ggml-threading.cpp +12 -0
  71. package/cpp/ggml-threading.h +14 -0
  72. package/cpp/ggml-whisper-sim.metallib +0 -0
  73. package/cpp/ggml-whisper.metallib +0 -0
  74. package/cpp/ggml.c +3293 -16770
  75. package/cpp/ggml.h +778 -835
  76. package/cpp/gguf.cpp +1347 -0
  77. package/cpp/gguf.h +202 -0
  78. package/cpp/rn-whisper.cpp +84 -0
  79. package/cpp/rn-whisper.h +2 -0
  80. package/cpp/whisper-arch.h +197 -0
  81. package/cpp/whisper.cpp +3240 -944
  82. package/cpp/whisper.h +144 -31
  83. package/ios/CMakeLists.txt +95 -0
  84. package/ios/RNWhisper.h +5 -0
  85. package/ios/RNWhisper.mm +124 -37
  86. package/ios/RNWhisperAudioUtils.h +1 -0
  87. package/ios/RNWhisperAudioUtils.m +24 -13
  88. package/ios/RNWhisperContext.h +8 -2
  89. package/ios/RNWhisperContext.mm +42 -8
  90. package/ios/rnwhisper.xcframework/Info.plist +74 -0
  91. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-alloc.h +76 -0
  92. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend-impl.h +255 -0
  93. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend.h +354 -0
  94. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-common.h +1861 -0
  95. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-cpp.h +39 -0
  96. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-cpu.h +143 -0
  97. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-impl.h +603 -0
  98. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-metal-impl.h +622 -0
  99. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-metal.h +66 -0
  100. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-opt.h +237 -0
  101. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-quants.h +100 -0
  102. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-threading.h +14 -0
  103. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml.h +2221 -0
  104. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/gguf.h +202 -0
  105. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/rn-audioutils.h +14 -0
  106. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/rn-whisper-log.h +11 -0
  107. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/rn-whisper.h +52 -0
  108. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/whisper-arch.h +197 -0
  109. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/whisper.h +739 -0
  110. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Info.plist +0 -0
  111. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
  112. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/rnwhisper +0 -0
  113. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-alloc.h +76 -0
  114. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend-impl.h +255 -0
  115. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +354 -0
  116. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-common.h +1861 -0
  117. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpp.h +39 -0
  118. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpu.h +143 -0
  119. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +603 -0
  120. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal-impl.h +622 -0
  121. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal.h +66 -0
  122. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-opt.h +237 -0
  123. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-quants.h +100 -0
  124. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-threading.h +14 -0
  125. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +2221 -0
  126. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/gguf.h +202 -0
  127. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-audioutils.h +14 -0
  128. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-whisper-log.h +11 -0
  129. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-whisper.h +52 -0
  130. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/whisper-arch.h +197 -0
  131. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/whisper.h +739 -0
  132. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
  133. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +101 -0
  134. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
  135. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
  136. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-alloc.h +76 -0
  137. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend-impl.h +255 -0
  138. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend.h +354 -0
  139. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-common.h +1861 -0
  140. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-cpp.h +39 -0
  141. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-cpu.h +143 -0
  142. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-impl.h +603 -0
  143. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-metal-impl.h +622 -0
  144. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-metal.h +66 -0
  145. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-opt.h +237 -0
  146. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-quants.h +100 -0
  147. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-threading.h +14 -0
  148. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml.h +2221 -0
  149. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/gguf.h +202 -0
  150. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/rn-audioutils.h +14 -0
  151. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/rn-whisper-log.h +11 -0
  152. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/rn-whisper.h +52 -0
  153. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/whisper-arch.h +197 -0
  154. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/whisper.h +739 -0
  155. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Info.plist +0 -0
  156. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
  157. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/rnwhisper +0 -0
  158. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-alloc.h +76 -0
  159. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend-impl.h +255 -0
  160. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +354 -0
  161. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-common.h +1861 -0
  162. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpp.h +39 -0
  163. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpu.h +143 -0
  164. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +603 -0
  165. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal-impl.h +622 -0
  166. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal.h +66 -0
  167. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-opt.h +237 -0
  168. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-quants.h +100 -0
  169. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-threading.h +14 -0
  170. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +2221 -0
  171. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/gguf.h +202 -0
  172. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-audioutils.h +14 -0
  173. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-whisper-log.h +11 -0
  174. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-whisper.h +52 -0
  175. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/whisper-arch.h +197 -0
  176. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/whisper.h +739 -0
  177. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
  178. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +101 -0
  179. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
  180. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
  181. package/jest/mock.js +14 -1
  182. package/lib/commonjs/NativeRNWhisper.js.map +1 -1
  183. package/lib/commonjs/index.js +48 -19
  184. package/lib/commonjs/index.js.map +1 -1
  185. package/lib/commonjs/version.json +1 -1
  186. package/lib/module/NativeRNWhisper.js.map +1 -1
  187. package/lib/module/index.js +48 -19
  188. package/lib/module/index.js.map +1 -1
  189. package/lib/module/version.json +1 -1
  190. package/lib/typescript/NativeRNWhisper.d.ts +6 -3
  191. package/lib/typescript/NativeRNWhisper.d.ts.map +1 -1
  192. package/lib/typescript/index.d.ts +25 -3
  193. package/lib/typescript/index.d.ts.map +1 -1
  194. package/package.json +15 -10
  195. package/src/NativeRNWhisper.ts +12 -3
  196. package/src/index.ts +63 -24
  197. package/src/version.json +1 -1
  198. package/whisper-rn.podspec +18 -18
  199. package/cpp/README.md +0 -4
  200. package/cpp/ggml-backend.c +0 -1718
  201. package/cpp/ggml-metal-whisper.metal +0 -5820
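Note: the largest source rewrite below is package/cpp/ggml-alloc.c, which replaces the old pointer-based tensor allocator with an offset-based dynamic allocator that keeps a free list sorted by offset and merges adjacent blocks when tensors are freed. As a standalone sketch of that best-fit/merge strategy (simplified, plain C; the types and names here are illustrative only and are not part of the whisper.rn or ggml API):

// Illustrative sketch of an offset-based best-fit arena with block merging.
// Not the whisper.rn/ggml implementation; names and types are hypothetical.
#include <assert.h>
#include <stddef.h>
#include <stdio.h>

#define MAX_BLOCKS 64

struct block { size_t offset, size; };

struct arena {
    int n;
    struct block blocks[MAX_BLOCKS]; // free blocks, kept sorted by offset
};

static void arena_init(struct arena *a, size_t capacity) {
    a->n = 1;
    a->blocks[0].offset = 0;
    a->blocks[0].size = capacity;
}

// best-fit: pick the smallest free block that can hold `size`
static size_t arena_alloc(struct arena *a, size_t size) {
    int best = -1;
    for (int i = 0; i < a->n; i++) {
        if (a->blocks[i].size >= size &&
            (best == -1 || a->blocks[i].size < a->blocks[best].size)) {
            best = i;
        }
    }
    assert(best != -1 && "out of memory");
    size_t offset = a->blocks[best].offset;
    a->blocks[best].offset += size;
    a->blocks[best].size -= size;
    if (a->blocks[best].size == 0) { // drop the emptied block
        for (int j = best; j < a->n - 1; j++) a->blocks[j] = a->blocks[j + 1];
        a->n--;
    }
    return offset;
}

// free: merge with touching neighbors, otherwise insert sorted by offset
static void arena_free(struct arena *a, size_t offset, size_t size) {
    int i = 0;
    while (i < a->n && a->blocks[i].offset < offset) i++;
    // freed range ends where the previous free block starts? grow it
    if (i > 0 && a->blocks[i-1].offset + a->blocks[i-1].size == offset) {
        a->blocks[i-1].size += size;
        // previous block now touches the next one: merge those too
        if (i < a->n && a->blocks[i-1].offset + a->blocks[i-1].size == a->blocks[i].offset) {
            a->blocks[i-1].size += a->blocks[i].size;
            for (int j = i; j < a->n - 1; j++) a->blocks[j] = a->blocks[j + 1];
            a->n--;
        }
        return;
    }
    // freed range touches the start of the next free block? extend it backwards
    if (i < a->n && offset + size == a->blocks[i].offset) {
        a->blocks[i].offset = offset;
        a->blocks[i].size += size;
        return;
    }
    // no neighbor touches: insert a new free block at the sorted position
    assert(a->n < MAX_BLOCKS);
    for (int j = a->n; j > i; j--) a->blocks[j] = a->blocks[j - 1];
    a->blocks[i].offset = offset;
    a->blocks[i].size = size;
    a->n++;
}

int main(void) {
    struct arena a;
    arena_init(&a, 1024);
    size_t x = arena_alloc(&a, 256);
    size_t y = arena_alloc(&a, 128);
    arena_free(&a, x, 256);
    arena_free(&a, y, 128); // re-merges into a single 1024-byte block
    printf("blocks: %d, size: %zu\n", a.n, a.blocks[0].size);
    return 0;
}

The real allocator in the diff additionally aligns every size to the buffer alignment and tracks a max_size high-water mark, so a graph's memory needs can be measured before a backend buffer is committed.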
package/cpp/ggml-alloc.c CHANGED
@@ -14,76 +14,143 @@
 
 //#define WSP_GGML_ALLOCATOR_DEBUG
 
-//#define AT_PRINTF(...) fprintf(stderr, __VA_ARGS__)
+//#define AT_PRINTF(...) WSP_GGML_LOG_DEBUG(__VA_ARGS__)
 #define AT_PRINTF(...)
 
-// TODO: WSP_GGML_PAD ?
+
+static bool wsp_ggml_is_view(const struct wsp_ggml_tensor * t) {
+    return t->view_src != NULL;
+}
+
+static bool wsp_ggml_are_same_layout(const struct wsp_ggml_tensor * a, const struct wsp_ggml_tensor * b) {
+    if (a->type != b->type) {
+        return false;
+    }
+    for (int i = 0; i < WSP_GGML_MAX_DIMS; i++) {
+        if (a->ne[i] != b->ne[i]) {
+            return false;
+        }
+        if (a->nb[i] != b->nb[i]) {
+            return false;
+        }
+    }
+    return true;
+}
+
+// ops that return true for this function must not use restrict pointers for their backend implementations
+static bool wsp_ggml_op_can_inplace(enum wsp_ggml_op op) {
+    switch (op) {
+        case WSP_GGML_OP_SCALE:
+        case WSP_GGML_OP_DIAG_MASK_ZERO:
+        case WSP_GGML_OP_DIAG_MASK_INF:
+        case WSP_GGML_OP_ADD:
+        case WSP_GGML_OP_ADD1:
+        case WSP_GGML_OP_SUB:
+        case WSP_GGML_OP_MUL:
+        case WSP_GGML_OP_DIV:
+        case WSP_GGML_OP_SQR:
+        case WSP_GGML_OP_SQRT:
+        case WSP_GGML_OP_LOG:
+        case WSP_GGML_OP_UNARY:
+        case WSP_GGML_OP_ROPE:
+        case WSP_GGML_OP_ROPE_BACK:
+        case WSP_GGML_OP_SILU_BACK:
+        case WSP_GGML_OP_RMS_NORM:
+        case WSP_GGML_OP_RMS_NORM_BACK:
+        case WSP_GGML_OP_SOFT_MAX:
+        case WSP_GGML_OP_SOFT_MAX_BACK:
+            return true;
+
+        default:
+            return false;
+    }
+}
+
 static size_t aligned_offset(const void * buffer, size_t offset, size_t alignment) {
     assert(alignment && !(alignment & (alignment - 1))); // power of 2
     size_t align = (alignment - (((uintptr_t)buffer + offset) % alignment)) % alignment;
     return offset + align;
 }
 
+// tallocr
+
+struct wsp_ggml_tallocr wsp_ggml_tallocr_new(wsp_ggml_backend_buffer_t buffer) {
+    void * base = wsp_ggml_backend_buffer_get_base(buffer);
+    size_t align = wsp_ggml_backend_buffer_get_alignment(buffer);
+
+    assert(align && !(align & (align - 1))); // power of 2
+
+    struct wsp_ggml_tallocr talloc = (struct wsp_ggml_tallocr) {
+        /*.buffer    = */ buffer,
+        /*.base      = */ base,
+        /*.alignment = */ align,
+        /*.offset    = */ aligned_offset(base, 0, align),
+    };
+    return talloc;
+}
+
+enum wsp_ggml_status wsp_ggml_tallocr_alloc(struct wsp_ggml_tallocr * talloc, struct wsp_ggml_tensor * tensor) {
+    size_t size = wsp_ggml_backend_buffer_get_alloc_size(talloc->buffer, tensor);
+    size = WSP_GGML_PAD(size, talloc->alignment);
+
+    if (talloc->offset + size > wsp_ggml_backend_buffer_get_size(talloc->buffer)) {
+        WSP_GGML_LOG_ERROR("%s: not enough space in the buffer to allocate %s (needed %zu, available %zu)\n",
+                __func__, tensor->name, size, wsp_ggml_backend_buffer_get_size(talloc->buffer) - talloc->offset);
+        WSP_GGML_ABORT("not enough space in the buffer");
+    }
+
+    void * addr = (char *)wsp_ggml_backend_buffer_get_base(talloc->buffer) + talloc->offset;
+    talloc->offset += size;
+
+    assert(((uintptr_t)addr % talloc->alignment) == 0);
+
+    return wsp_ggml_backend_tensor_alloc(talloc->buffer, tensor, addr);
+}
+
+// dynamic tensor allocator
+
 struct free_block {
-    void * addr;
+    size_t offset;
     size_t size;
 };
 
-struct wsp_ggml_tallocr {
-    struct wsp_ggml_backend_buffer * buffer;
-    bool buffer_owned;
-    void * base;
+struct wsp_ggml_dyn_tallocr {
     size_t alignment;
-
     int n_free_blocks;
     struct free_block free_blocks[MAX_FREE_BLOCKS];
-
     size_t max_size;
 
-    bool measure;
-
 #ifdef WSP_GGML_ALLOCATOR_DEBUG
-    struct wsp_ggml_tensor * allocated_tensors[1024];
+    struct {
+        const struct wsp_ggml_tensor * tensor;
+        size_t offset;
+    } allocated_tensors[1024];
 #endif
 };
 
 #ifdef WSP_GGML_ALLOCATOR_DEBUG
-static void add_allocated_tensor(wsp_ggml_tallocr_t alloc, struct wsp_ggml_tensor * tensor) {
+static void add_allocated_tensor(struct wsp_ggml_dyn_tallocr * alloc, size_t offset, const struct wsp_ggml_tensor * tensor) {
     for (int i = 0; i < 1024; i++) {
-        if (alloc->allocated_tensors[i] == NULL) {
-            alloc->allocated_tensors[i] = tensor;
+        if (alloc->allocated_tensors[i].tensor == NULL) {
+            alloc->allocated_tensors[i].tensor = tensor;
+            alloc->allocated_tensors[i].offset = offset;
             return;
         }
     }
-    WSP_GGML_ASSERT(!"out of allocated_tensors");
+    WSP_GGML_ABORT("out of allocated_tensors");
 }
-static void remove_allocated_tensor(wsp_ggml_tallocr_t alloc, struct wsp_ggml_tensor * tensor) {
+static void remove_allocated_tensor(struct wsp_ggml_dyn_tallocr * alloc, size_t offset, const struct wsp_ggml_tensor * tensor) {
     for (int i = 0; i < 1024; i++) {
-        if (alloc->allocated_tensors[i] == tensor ||
-            (alloc->allocated_tensors[i] != NULL && alloc->allocated_tensors[i]->data == tensor->data)) {
-            alloc->allocated_tensors[i] = NULL;
+        if (alloc->allocated_tensors[i].offset == offset) {
+            alloc->allocated_tensors[i].tensor = NULL;
             return;
         }
     }
-    printf("tried to free tensor %s not found\n", tensor->name);
-    WSP_GGML_ASSERT(!"tensor not found");
+    WSP_GGML_ABORT("tried to free tensor %s not found\n", tensor->name);
 }
 #endif
 
-// check if a tensor is allocated by this buffer
-static bool wsp_ggml_tallocr_is_own(wsp_ggml_tallocr_t alloc, const struct wsp_ggml_tensor * tensor) {
-    return tensor->buffer == alloc->buffer && (!tensor->view_src || tensor->view_src->buffer == alloc->buffer);
-}
-
-static bool wsp_ggml_is_view(struct wsp_ggml_tensor * t) {
-    return t->view_src != NULL;
-}
-
-void wsp_ggml_tallocr_alloc(wsp_ggml_tallocr_t alloc, struct wsp_ggml_tensor * tensor) {
-    WSP_GGML_ASSERT(!wsp_ggml_is_view(tensor)); // views generally get data pointer from one of their sources
-    WSP_GGML_ASSERT(tensor->data == NULL); // avoid allocating tensor which already has memory allocated
-
-    size_t size = wsp_ggml_backend_buffer_get_alloc_size(alloc->buffer, tensor);
+static size_t wsp_ggml_dyn_tallocr_alloc(struct wsp_ggml_dyn_tallocr * alloc, size_t size, const struct wsp_ggml_tensor * tensor) {
     size = aligned_offset(NULL, size, alloc->alignment);
 
     AT_PRINTF("%s: allocating %s (%zu bytes) - ", __func__, tensor->name, size);
@@ -109,16 +176,16 @@ void wsp_ggml_tallocr_alloc(wsp_ggml_tallocr_t alloc, struct wsp_ggml_tensor * t
         if (block->size >= size) {
             best_fit_block = alloc->n_free_blocks - 1;
         } else {
-            fprintf(stderr, "%s: not enough space in the buffer (needed %zu, largest block available %zu)\n",
+            // this should never happen
+            WSP_GGML_LOG_ERROR("%s: not enough space in the buffer to allocate %zu bytes, largest block available %zu bytes\n",
                     __func__, size, max_avail);
-            WSP_GGML_ASSERT(!"not enough space in the buffer");
-            return;
+            WSP_GGML_ABORT("not enough space in the buffer");
         }
     }
 
     struct free_block * block = &alloc->free_blocks[best_fit_block];
-    void * addr = block->addr;
-    block->addr = (char*)block->addr + size;
+    size_t offset = block->offset;
+    block->offset = offset + size;
     block->size -= size;
     if (block->size == 0) {
         // remove block if empty
@@ -128,59 +195,63 @@ void wsp_ggml_tallocr_alloc(wsp_ggml_tallocr_t alloc, struct wsp_ggml_tensor * t
         }
     }
 
-    AT_PRINTF("block %d, addr %p\n", best_fit_block, addr);
-
-    tensor->data = addr;
-    tensor->buffer = alloc->buffer;
-    if (!alloc->measure) {
-        wsp_ggml_backend_buffer_init_tensor(alloc->buffer, tensor);
-    }
+    AT_PRINTF("block %d, offset %zu\n", best_fit_block, offset);
 
 #ifdef WSP_GGML_ALLOCATOR_DEBUG
-    add_allocated_tensor(alloc, tensor);
-    size_t cur_max = (char*)addr - (char*)alloc->base + size;
+    add_allocated_tensor(alloc, offset, tensor);
+    size_t cur_max = offset + size;
     if (cur_max > alloc->max_size) {
-        printf("max_size = %.2f MB: tensors: ", cur_max / 1024.0 / 1024.0);
+        // sort allocated_tensors by offset
+        for (int i = 0; i < 1024; i++) {
+            for (int j = i + 1; j < 1024; j++) {
+                if (alloc->allocated_tensors[i].offset > alloc->allocated_tensors[j].offset) {
+                    const struct wsp_ggml_tensor * tmp_tensor = alloc->allocated_tensors[i].tensor;
+                    size_t tmp_offset = alloc->allocated_tensors[i].offset;
+                    alloc->allocated_tensors[i].tensor = alloc->allocated_tensors[j].tensor;
+                    alloc->allocated_tensors[i].offset = alloc->allocated_tensors[j].offset;
+                    alloc->allocated_tensors[j].tensor = tmp_tensor;
+                    alloc->allocated_tensors[j].offset = tmp_offset;
+                }
+            }
+        }
+        WSP_GGML_LOG_DEBUG("max_size = %.2f MB: tensors: ", cur_max / 1024.0 / 1024.0);
         for (int i = 0; i < 1024; i++) {
-            if (alloc->allocated_tensors[i]) {
-                printf("%s (%.2f MB) ", alloc->allocated_tensors[i]->name, wsp_ggml_nbytes(alloc->allocated_tensors[i]) / 1024.0 / 1024.0);
+            if (alloc->allocated_tensors[i].tensor) {
+                WSP_GGML_LOG_DEBUG("%s [%zx-%zx] (%.2f MB) ", alloc->allocated_tensors[i].tensor->name,
+                    alloc->allocated_tensors[i].offset,
+                    alloc->allocated_tensors[i].offset + wsp_ggml_nbytes(alloc->allocated_tensors[i].tensor),
+                    wsp_ggml_nbytes(alloc->allocated_tensors[i].tensor) / 1024.0 / 1024.0);
             }
         }
-        printf("\n");
+        WSP_GGML_LOG_DEBUG("\n");
     }
 #endif
 
-    alloc->max_size = MAX(alloc->max_size, (char*)addr - (char*)alloc->base + size);
-}
+    alloc->max_size = MAX(alloc->max_size, offset + size);
 
-// this is a very naive implementation, but for our case the number of free blocks should be very small
-static void wsp_ggml_tallocr_free_tensor(wsp_ggml_tallocr_t alloc, struct wsp_ggml_tensor * tensor) {
-    if (wsp_ggml_tallocr_is_own(alloc, tensor) == false) {
-        // the tensor was not allocated in this buffer
-        // this can happen because the graph allocator will try to free weights and other tensors from different buffers
-        // the easiest way to deal with this is just to ignore it
-        // AT_PRINTF("ignoring %s (their buffer: %p, our buffer: %p)\n", tensor->name, (void *)tensor->buffer, (void *)alloc->buffer);
-        return;
-    }
+    return offset;
 
-    void * ptr = tensor->data;
+    WSP_GGML_UNUSED(tensor);
+}
 
-    size_t size = wsp_ggml_backend_buffer_get_alloc_size(alloc->buffer, tensor);
+// this is a very naive implementation, but for our case the number of free blocks should be very small
+static void wsp_ggml_dyn_tallocr_free_tensor(struct wsp_ggml_dyn_tallocr * alloc, size_t offset, size_t size, const struct wsp_ggml_tensor * tensor) {
     size = aligned_offset(NULL, size, alloc->alignment);
-    AT_PRINTF("%s: freeing %s at %p (%zu bytes) - n_free_blocks = %d\n", __func__, tensor->name, ptr, size, alloc->n_free_blocks);
+
+    AT_PRINTF("%s: freeing %s at %zu (%zu bytes) - n_free_blocks = %d\n", __func__, tensor->name, offset, size, alloc->n_free_blocks);
 
 #ifdef WSP_GGML_ALLOCATOR_DEBUG
-    remove_allocated_tensor(alloc, tensor);
+    remove_allocated_tensor(alloc, offset, tensor);
 #endif
 
     // see if we can merge with an existing block
     for (int i = 0; i < alloc->n_free_blocks; i++) {
         struct free_block * block = &alloc->free_blocks[i];
         // check if ptr is at the end of the block
-        if ((char*)block->addr + block->size == ptr) {
+        if (block->offset + block->size == offset) {
             block->size += size;
             // check if we can merge with the next block
-            if (i < alloc->n_free_blocks - 1 && (char*)block->addr + block->size == alloc->free_blocks[i+1].addr) {
+            if (i < alloc->n_free_blocks - 1 && block->offset + block->size == alloc->free_blocks[i+1].offset) {
                 block->size += alloc->free_blocks[i+1].size;
                 alloc->n_free_blocks--;
                 for (int j = i+1; j < alloc->n_free_blocks; j++) {
@@ -190,11 +261,11 @@ static void wsp_ggml_tallocr_free_tensor(wsp_ggml_tallocr_t alloc, struct wsp_gg
             return;
         }
         // check if ptr is at the beginning of the block
-        if ((char*)ptr + size == block->addr) {
-            block->addr = ptr;
+        if (offset + size == block->offset) {
+            block->offset = offset;
             block->size += size;
             // check if we can merge with the previous block
-            if (i > 0 && (char*)alloc->free_blocks[i-1].addr + alloc->free_blocks[i-1].size == block->addr) {
+            if (i > 0 && alloc->free_blocks[i-1].offset + alloc->free_blocks[i-1].size == block->offset) {
                 alloc->free_blocks[i-1].size += block->size;
                 alloc->n_free_blocks--;
                 for (int j = i; j < alloc->n_free_blocks; j++) {
@@ -208,7 +279,7 @@ static void wsp_ggml_tallocr_free_tensor(wsp_ggml_tallocr_t alloc, struct wsp_gg
     WSP_GGML_ASSERT(alloc->n_free_blocks < MAX_FREE_BLOCKS && "out of free blocks");
     // insert the new block in the correct position to keep the array sorted by address (to make merging blocks faster)
     int insert_pos = 0;
-    while (insert_pos < alloc->n_free_blocks && alloc->free_blocks[insert_pos].addr < ptr) {
+    while (insert_pos < alloc->n_free_blocks && alloc->free_blocks[insert_pos].offset < offset) {
         insert_pos++;
     }
     // shift all blocks from insert_pos onward to make room for the new block
@@ -216,614 +287,753 @@ static void wsp_ggml_tallocr_free_tensor(wsp_ggml_tallocr_t alloc, struct wsp_gg
216
287
  alloc->free_blocks[i] = alloc->free_blocks[i-1];
217
288
  }
218
289
  // insert the new block
219
- alloc->free_blocks[insert_pos].addr = ptr;
290
+ alloc->free_blocks[insert_pos].offset = offset;
220
291
  alloc->free_blocks[insert_pos].size = size;
221
292
  alloc->n_free_blocks++;
293
+
294
+ WSP_GGML_UNUSED(tensor);
222
295
  }
223
296
 
224
- void wsp_ggml_tallocr_reset(wsp_ggml_tallocr_t alloc) {
297
+ static void wsp_ggml_dyn_tallocr_reset(struct wsp_ggml_dyn_tallocr * alloc) {
225
298
  alloc->n_free_blocks = 1;
226
- size_t align_offset = aligned_offset(alloc->base, 0, alloc->alignment);
227
- alloc->free_blocks[0].addr = (char *)alloc->base + align_offset;
299
+ alloc->free_blocks[0].offset = 0;
300
+ alloc->free_blocks[0].size = SIZE_MAX/2; // restrict maximum size of a measure allocator to half size_t max to avoid overflows
301
+ alloc->max_size = 0;
228
302
 
229
- if (alloc->measure) {
230
- alloc->free_blocks[0].size = SIZE_MAX/2; // restrict maximum size of a measure allocator to half size_t max to avoid overflows
231
- } else {
232
- alloc->free_blocks[0].size = wsp_ggml_backend_buffer_get_size(alloc->buffer) - align_offset;
233
- wsp_ggml_backend_buffer_reset(alloc->buffer);
303
+ #ifdef WSP_GGML_ALLOCATOR_DEBUG
304
+ for (int i = 0; i < 1024; i++) {
305
+ alloc->allocated_tensors[i].tensor = NULL;
234
306
  }
307
+ #endif
235
308
  }
236
309
 
237
- wsp_ggml_tallocr_t wsp_ggml_tallocr_new(void * data, size_t size, size_t alignment) {
238
- struct wsp_ggml_backend_buffer * buffer = wsp_ggml_backend_cpu_buffer_from_ptr(data, size);
310
+ static struct wsp_ggml_dyn_tallocr * wsp_ggml_dyn_tallocr_new(size_t alignment) {
311
+ struct wsp_ggml_dyn_tallocr * alloc = (struct wsp_ggml_dyn_tallocr *)malloc(sizeof(struct wsp_ggml_dyn_tallocr));
239
312
 
240
- wsp_ggml_tallocr_t alloc = (wsp_ggml_tallocr_t)malloc(sizeof(struct wsp_ggml_tallocr));
241
-
242
- *alloc = (struct wsp_ggml_tallocr) {
243
- /*.buffer = */ buffer,
244
- /*.buffer_owned = */ true,
245
- /*.base = */ wsp_ggml_backend_buffer_get_base(buffer),
313
+ *alloc = (struct wsp_ggml_dyn_tallocr) {
246
314
  /*.alignment = */ alignment,
247
315
  /*.n_free_blocks = */ 0,
248
316
  /*.free_blocks = */ {{0}},
249
317
  /*.max_size = */ 0,
250
- /*.measure = */ false,
251
318
  #ifdef WSP_GGML_ALLOCATOR_DEBUG
252
- /*.allocated_tensors = */ {0},
319
+ /*.allocated_tensors = */ {{0}},
253
320
  #endif
254
321
  };
255
322
 
256
- wsp_ggml_tallocr_reset(alloc);
323
+ wsp_ggml_dyn_tallocr_reset(alloc);
257
324
 
258
325
  return alloc;
259
326
  }
260
327
 
261
- wsp_ggml_tallocr_t wsp_ggml_tallocr_new_measure(size_t alignment) {
262
- wsp_ggml_tallocr_t alloc = wsp_ggml_tallocr_new((void *)0x1000, SIZE_MAX/2, alignment);
263
- alloc->measure = true;
328
+ static void wsp_ggml_dyn_tallocr_free(struct wsp_ggml_dyn_tallocr * alloc) {
329
+ free(alloc);
330
+ }
264
331
 
265
- return alloc;
332
+ static size_t wsp_ggml_dyn_tallocr_max_size(struct wsp_ggml_dyn_tallocr * alloc) {
333
+ return alloc->max_size;
266
334
  }
267
335
 
268
- wsp_ggml_tallocr_t wsp_ggml_tallocr_new_measure_from_buft(struct wsp_ggml_backend_buffer_type * buft) {
269
- // create a backend buffer to get the correct tensor allocation sizes
270
- wsp_ggml_backend_buffer_t buffer = wsp_ggml_backend_buft_alloc_buffer(buft, 1);
271
336
 
272
- // TODO: move alloc initialization to a common wsp_ggml_tallocr_new_impl function
273
- wsp_ggml_tallocr_t alloc = wsp_ggml_tallocr_new_from_buffer(buffer);
274
- alloc->buffer_owned = true;
275
- alloc->measure = true;
276
- wsp_ggml_tallocr_reset(alloc);
277
- return alloc;
278
- }
337
+ /////////////////////////////////////
279
338
 
280
- wsp_ggml_tallocr_t wsp_ggml_tallocr_new_measure_from_backend(struct wsp_ggml_backend * backend) {
281
- return wsp_ggml_tallocr_new_measure_from_buft(wsp_ggml_backend_get_default_buffer_type(backend));
282
- }
339
+ // graph allocator
283
340
 
284
- wsp_ggml_tallocr_t wsp_ggml_tallocr_new_from_buft(struct wsp_ggml_backend_buffer_type * buft, size_t size) {
285
- // create a backend buffer to get the correct tensor allocation sizes
286
- wsp_ggml_backend_buffer_t buffer = wsp_ggml_backend_buft_alloc_buffer(buft, size);
287
- wsp_ggml_tallocr_t alloc = wsp_ggml_tallocr_new_from_buffer(buffer);
288
- alloc->buffer_owned = true;
289
- return alloc;
290
- }
341
+ struct hash_node {
342
+ int n_children;
343
+ int n_views;
344
+ int buffer_id;
345
+ size_t offset; // offset within the buffer
346
+ bool allocated;
347
+ };
291
348
 
292
- wsp_ggml_tallocr_t wsp_ggml_tallocr_new_from_backend(struct wsp_ggml_backend * backend, size_t size) {
293
- return wsp_ggml_tallocr_new_from_buft(wsp_ggml_backend_get_default_buffer_type(backend), size);
294
- }
349
+ struct tensor_alloc {
350
+ int buffer_id;
351
+ size_t offset;
352
+ size_t size_max; // 0 = pre-allocated, unused, or view
353
+ };
295
354
 
296
- wsp_ggml_tallocr_t wsp_ggml_tallocr_new_from_buffer(struct wsp_ggml_backend_buffer * buffer) {
297
- wsp_ggml_tallocr_t alloc = (wsp_ggml_tallocr_t)malloc(sizeof(struct wsp_ggml_tallocr));
355
+ struct leaf_alloc {
356
+ struct tensor_alloc leaf;
357
+ };
298
358
 
299
- *alloc = (struct wsp_ggml_tallocr) {
300
- /*.buffer = */ buffer,
301
- /*.buffer_owned = */ false,
302
- /*.base = */ wsp_ggml_backend_buffer_get_base(buffer),
303
- /*.alignment = */ wsp_ggml_backend_buffer_get_alignment(buffer),
304
- /*.n_free_blocks = */ 0,
305
- /*.free_blocks = */ {{0}},
306
- /*.max_size = */ 0,
307
- /*.measure = */ false,
308
- #ifdef WSP_GGML_ALLOCATOR_DEBUG
309
- /*.allocated_tensors = */ {0},
310
- #endif
311
- };
359
+ struct node_alloc {
360
+ struct tensor_alloc dst;
361
+ struct tensor_alloc src[WSP_GGML_MAX_SRC];
362
+ };
312
363
 
313
- wsp_ggml_tallocr_reset(alloc);
364
+ struct wsp_ggml_gallocr {
365
+ wsp_ggml_backend_buffer_type_t * bufts; // [n_buffers]
366
+ wsp_ggml_backend_buffer_t * buffers; // [n_buffers]
367
+ struct wsp_ggml_dyn_tallocr ** buf_tallocs; // [n_buffers]
368
+ int n_buffers;
314
369
 
315
- return alloc;
316
- }
370
+ struct wsp_ggml_hash_set hash_set;
371
+ struct hash_node * hash_values; // [hash_set.size]
317
372
 
318
- struct wsp_ggml_backend_buffer * wsp_ggml_tallocr_get_buffer(wsp_ggml_tallocr_t alloc) {
319
- return alloc->buffer;
320
- }
373
+ struct node_alloc * node_allocs; // [n_nodes]
374
+ int n_nodes;
321
375
 
322
- void wsp_ggml_tallocr_free(wsp_ggml_tallocr_t alloc) {
323
- if (alloc == NULL) {
324
- return;
325
- }
376
+ struct leaf_alloc * leaf_allocs; // [n_leafs]
377
+ int n_leafs;
378
+ };
326
379
 
327
- if (alloc->buffer_owned) {
328
- wsp_ggml_backend_buffer_free(alloc->buffer);
329
- }
330
- free(alloc);
331
- }
380
+ wsp_ggml_gallocr_t wsp_ggml_gallocr_new_n(wsp_ggml_backend_buffer_type_t * bufts, int n_bufs) {
381
+ wsp_ggml_gallocr_t galloc = (wsp_ggml_gallocr_t)calloc(1, sizeof(struct wsp_ggml_gallocr));
382
+ WSP_GGML_ASSERT(galloc != NULL);
332
383
 
333
- bool wsp_ggml_tallocr_is_measure(wsp_ggml_tallocr_t alloc) {
334
- return alloc->measure;
335
- }
384
+ galloc->bufts = calloc(n_bufs, sizeof(wsp_ggml_backend_buffer_type_t));
385
+ WSP_GGML_ASSERT(galloc->bufts != NULL);
336
386
 
337
- size_t wsp_ggml_tallocr_max_size(wsp_ggml_tallocr_t alloc) {
338
- return alloc->max_size;
339
- }
387
+ galloc->buffers = calloc(n_bufs, sizeof(wsp_ggml_backend_buffer_t));
388
+ WSP_GGML_ASSERT(galloc->buffers != NULL);
340
389
 
341
- // graph allocator
390
+ galloc->buf_tallocs = calloc(n_bufs, sizeof(struct wsp_ggml_dyn_tallocr *));
391
+ WSP_GGML_ASSERT(galloc->buf_tallocs != NULL);
342
392
 
343
- struct hash_node {
344
- int n_children;
345
- int n_views;
346
- };
393
+ for (int i = 0; i < n_bufs; i++) {
394
+ galloc->bufts[i] = bufts[i];
395
+ galloc->buffers[i] = NULL;
347
396
 
348
- struct wsp_ggml_gallocr {
349
- wsp_ggml_tallocr_t talloc;
350
- struct wsp_ggml_hash_set hash_set;
351
- struct hash_node * hash_values;
352
- size_t hash_values_size;
353
- wsp_ggml_tallocr_t * hash_allocs;
354
- int * parse_seq;
355
- int parse_seq_len;
356
- };
397
+ // check if the same buffer type is used multiple times and reuse the same allocator
398
+ for (int j = 0; j < i; j++) {
399
+ if (bufts[i] == bufts[j]) {
400
+ galloc->buf_tallocs[i] = galloc->buf_tallocs[j];
401
+ break;
402
+ }
403
+ }
357
404
 
358
- wsp_ggml_gallocr_t wsp_ggml_gallocr_new(void) {
359
- wsp_ggml_gallocr_t galloc = (wsp_ggml_gallocr_t)malloc(sizeof(struct wsp_ggml_gallocr));
360
-
361
- *galloc = (struct wsp_ggml_gallocr) {
362
- /*.talloc = */ NULL,
363
- /*.hash_set = */ {0},
364
- /*.hash_values = */ NULL,
365
- /*.hash_values_size = */ 0,
366
- /*.hash_allocs = */ NULL,
367
- /*.parse_seq = */ NULL,
368
- /*.parse_seq_len = */ 0,
369
- };
405
+ if (galloc->buf_tallocs[i] == NULL) {
406
+ size_t alignment = wsp_ggml_backend_buft_get_alignment(bufts[i]);
407
+ galloc->buf_tallocs[i] = wsp_ggml_dyn_tallocr_new(alignment);
408
+ }
409
+ }
410
+ galloc->n_buffers = n_bufs;
370
411
 
371
412
  return galloc;
372
413
  }
373
414
 
415
+ wsp_ggml_gallocr_t wsp_ggml_gallocr_new(wsp_ggml_backend_buffer_type_t buft) {
416
+ return wsp_ggml_gallocr_new_n(&buft, 1);
417
+ }
418
+
374
419
  void wsp_ggml_gallocr_free(wsp_ggml_gallocr_t galloc) {
375
420
  if (galloc == NULL) {
376
421
  return;
377
422
  }
378
423
 
379
- if (galloc->hash_set.keys != NULL) {
380
- free(galloc->hash_set.keys);
381
- }
382
- if (galloc->hash_values != NULL) {
383
- free(galloc->hash_values);
384
- }
385
- if (galloc->hash_allocs != NULL) {
386
- free(galloc->hash_allocs);
387
- }
388
- if (galloc->parse_seq != NULL) {
389
- free(galloc->parse_seq);
424
+ for (int i = 0; i < galloc->n_buffers; i++) {
425
+ if (galloc->buffers != NULL) {
426
+ // skip if already freed
427
+ bool freed = false;
428
+ for (int j = 0; j < i; j++) {
429
+ if (galloc->buffers[j] == galloc->buffers[i]) {
430
+ freed = true;
431
+ break;
432
+ }
433
+ }
434
+ if (!freed) {
435
+ wsp_ggml_backend_buffer_free(galloc->buffers[i]);
436
+ }
437
+ }
438
+ if (galloc->buf_tallocs != NULL) {
439
+ // skip if already freed
440
+ bool freed = false;
441
+ for (int j = 0; j < i; j++) {
442
+ if (galloc->buf_tallocs[j] == galloc->buf_tallocs[i]) {
443
+ freed = true;
444
+ break;
445
+ }
446
+ }
447
+ if (!freed) {
448
+ wsp_ggml_dyn_tallocr_free(galloc->buf_tallocs[i]);
449
+ }
450
+ }
390
451
  }
452
+
453
+ wsp_ggml_hash_set_free(&galloc->hash_set);
454
+ free(galloc->hash_values);
455
+ free(galloc->bufts);
456
+ free(galloc->buffers);
457
+ free(galloc->buf_tallocs);
458
+ free(galloc->node_allocs);
459
+ free(galloc->leaf_allocs);
391
460
  free(galloc);
392
461
  }
393
462
 
394
- void wsp_ggml_gallocr_set_parse_seq(wsp_ggml_gallocr_t galloc, const int * list, int n) {
395
- free(galloc->parse_seq);
396
- galloc->parse_seq = malloc(sizeof(int) * n);
397
-
398
- for (int i = 0; i < n; i++) {
399
- galloc->parse_seq[i] = list[i];
400
- }
401
- galloc->parse_seq_len = n;
402
- }
463
+ typedef struct wsp_ggml_gallocr * wsp_ggml_gallocr_t;
403
464
 
404
- static struct hash_node * hash_get(wsp_ggml_gallocr_t galloc, struct wsp_ggml_tensor * t) {
405
- size_t i = wsp_ggml_hash_find_or_insert(galloc->hash_set, t);
465
+ static struct hash_node * wsp_ggml_gallocr_hash_get(wsp_ggml_gallocr_t galloc, struct wsp_ggml_tensor * t) {
466
+ size_t i = wsp_ggml_hash_find_or_insert(&galloc->hash_set, t);
406
467
  return &galloc->hash_values[i];
407
468
  }
408
469
 
409
- static bool wsp_ggml_are_same_layout(const struct wsp_ggml_tensor * a, const struct wsp_ggml_tensor * b) {
410
- if (a->type != b->type) {
411
- return false;
412
- }
413
- for (int i = 0; i < WSP_GGML_MAX_DIMS; i++) {
414
- if (a->ne[i] != b->ne[i]) {
415
- return false;
416
- }
417
- if (a->nb[i] != b->nb[i]) {
418
- return false;
419
- }
420
- }
421
- return true;
470
+ static bool wsp_ggml_gallocr_is_own(wsp_ggml_gallocr_t galloc, struct wsp_ggml_tensor * t) {
471
+ return wsp_ggml_gallocr_hash_get(galloc, t)->allocated;
422
472
  }
423
473
 
424
- static bool wsp_ggml_op_can_inplace(enum wsp_ggml_op op) {
425
- switch (op) {
426
- case WSP_GGML_OP_SCALE:
427
- case WSP_GGML_OP_DIAG_MASK_ZERO:
428
- case WSP_GGML_OP_DIAG_MASK_INF:
429
- case WSP_GGML_OP_ADD:
430
- case WSP_GGML_OP_ADD1:
431
- case WSP_GGML_OP_SUB:
432
- case WSP_GGML_OP_MUL:
433
- case WSP_GGML_OP_DIV:
434
- case WSP_GGML_OP_SQR:
435
- case WSP_GGML_OP_SQRT:
436
- case WSP_GGML_OP_LOG:
437
- case WSP_GGML_OP_UNARY:
438
- case WSP_GGML_OP_ROPE:
439
- case WSP_GGML_OP_RMS_NORM:
440
- case WSP_GGML_OP_SOFT_MAX:
441
- return true;
442
-
443
- default:
444
- return false;
445
- }
446
- }
447
-
448
- static wsp_ggml_tallocr_t node_tallocr(wsp_ggml_gallocr_t galloc, struct wsp_ggml_tensor * node) {
449
- if (galloc->talloc != NULL) {
450
- return galloc->talloc;
451
- }
452
-
453
- return galloc->hash_allocs[wsp_ggml_hash_find_or_insert(galloc->hash_set, node)];
474
+ static bool wsp_ggml_gallocr_is_allocated(wsp_ggml_gallocr_t galloc, struct wsp_ggml_tensor * t) {
475
+ return t->data != NULL || wsp_ggml_gallocr_hash_get(galloc, t)->allocated;
454
476
  }
455
477
 
456
- static void init_view(wsp_ggml_gallocr_t galloc, struct wsp_ggml_tensor * view, bool update_backend) {
457
- wsp_ggml_tallocr_t alloc = node_tallocr(galloc, view);
458
-
459
- WSP_GGML_ASSERT(view->view_src != NULL && view->view_src->data != NULL);
460
- if (update_backend) {
461
- view->backend = view->view_src->backend;
462
- }
463
- // views are initialized in the alloc buffer rather than the view_src buffer
464
- view->buffer = alloc->buffer;
465
- view->data = (char *)view->view_src->data + view->view_offs;
478
+ static void wsp_ggml_gallocr_allocate_node(wsp_ggml_gallocr_t galloc, struct wsp_ggml_tensor * node, int buffer_id) {
479
+ WSP_GGML_ASSERT(buffer_id >= 0);
480
+ struct hash_node * hn = wsp_ggml_gallocr_hash_get(galloc, node);
466
481
 
467
- assert(wsp_ggml_tallocr_is_measure(alloc) || !view->buffer || view->buffer->buft == alloc->buffer->buft);
482
+ if (!wsp_ggml_gallocr_is_allocated(galloc, node) && !wsp_ggml_is_view(node)) {
483
+ hn->allocated = true;
484
+ assert(hn->offset == 0);
468
485
 
469
- if (!alloc->measure) {
470
- wsp_ggml_backend_buffer_init_tensor(alloc->buffer, view);
471
- }
472
- }
486
+ // try to reuse a parent's buffer (inplace)
487
+ if (wsp_ggml_op_can_inplace(node->op)) {
488
+ for (int i = 0; i < WSP_GGML_MAX_SRC; i++) {
489
+ struct wsp_ggml_tensor * parent = node->src[i];
490
+ if (parent == NULL) {
491
+ continue;
492
+ }
473
493
 
474
- static void allocate_node(wsp_ggml_gallocr_t galloc, struct wsp_ggml_tensor * node) {
475
- wsp_ggml_tallocr_t alloc = node_tallocr(galloc, node);
494
+ // if the node's data is external, then we cannot re-use it
495
+ if (!wsp_ggml_gallocr_is_own(galloc, parent)) {
496
+ AT_PRINTF("not reusing parent %s for %s as %p is external\n", parent->name, node->name, parent->data);
497
+ continue;
498
+ }
476
499
 
477
- if (node->data == NULL) {
478
- if (wsp_ggml_is_view(node)) {
479
- init_view(galloc, node, true);
480
- } else {
481
- // see if we can reuse a parent's buffer (inplace)
482
- if (wsp_ggml_op_can_inplace(node->op)) {
483
- for (int i = 0; i < WSP_GGML_MAX_SRC; i++) {
484
- struct wsp_ggml_tensor * parent = node->src[i];
485
- if (parent == NULL) {
486
- break;
487
- }
500
+ // outputs cannot be reused
501
+ if (parent->flags & WSP_GGML_TENSOR_FLAG_OUTPUT || (parent->view_src != NULL && parent->view_src->flags & WSP_GGML_TENSOR_FLAG_OUTPUT)) {
502
+ AT_PRINTF("not reusing parent %s for %s as it is an output\n", parent->name, node->name);
503
+ continue;
504
+ }
488
505
 
489
- // if the node's data is external, then we cannot re-use it
490
- if (wsp_ggml_tallocr_is_own(alloc, parent) == false) {
491
- AT_PRINTF("not reusing parent %s for %s as %p is external\n", parent->name, node->name, parent->data);
492
- continue;
493
- }
506
+ if (!wsp_ggml_are_same_layout(node, parent)) {
507
+ AT_PRINTF("not reusing parent %s for %s as layouts are different\n", parent->name, node->name);
508
+ continue;
509
+ }
494
510
 
495
- struct hash_node * p_hn = hash_get(galloc, parent);
496
- if (parent->data != NULL && p_hn->n_children == 1 && p_hn->n_views == 0 && wsp_ggml_are_same_layout(node, parent)) {
497
- if (wsp_ggml_is_view(parent)) {
498
- struct wsp_ggml_tensor * view_src = parent->view_src;
499
- struct hash_node * view_src_hn = hash_get(galloc, view_src);
500
- if (view_src_hn->n_views == 1 && view_src_hn->n_children == 0 && view_src->data == parent->data) {
501
- // TODO: the offset of the view parent must be kept to ensure that the op doesn't overwrite
502
- // the parent's data that it will need later (same layout requirement). the problem is that then
503
- // we cannot free the tensor because the original address of the allocation is lost.
504
- // adding a view_src pointer to the tensor would solve this and simplify the code dealing with views
505
- // for now, we only reuse the parent's data if the offset is zero (view_src->data == parent->data)
506
- AT_PRINTF("reusing view parent %s (%s) for %s\n", parent->name, view_src->name, node->name);
507
- node->view_src = view_src;
508
- view_src_hn->n_views += 1;
509
- init_view(galloc, node, false);
510
- return;
511
- }
512
- } else {
513
- AT_PRINTF("reusing parent %s for %s\n", parent->name, node->name);
514
- node->view_src = parent;
515
- p_hn->n_views += 1;
516
- init_view(galloc, node, false);
511
+ struct hash_node * p_hn = wsp_ggml_gallocr_hash_get(galloc, parent);
512
+ if (p_hn->n_children == 1 && p_hn->n_views == 0) {
513
+ if (wsp_ggml_is_view(parent)) {
514
+ struct wsp_ggml_tensor * view_src = parent->view_src;
515
+ struct hash_node * view_src_hn = wsp_ggml_gallocr_hash_get(galloc, view_src);
516
+ if (view_src_hn->n_views == 1 && view_src_hn->n_children == 0 && view_src->data == parent->data) {
517
+ AT_PRINTF("reusing view parent %s (%s) for %s\n", parent->name, view_src->name, node->name);
518
+ assert(view_src_hn->offset == p_hn->offset);
519
+ hn->buffer_id = p_hn->buffer_id;
520
+ hn->offset = p_hn->offset;
521
+ p_hn->allocated = false; // avoid freeing the parent
522
+ view_src_hn->allocated = false;
517
523
  return;
518
524
  }
525
+ } else {
526
+ AT_PRINTF("reusing parent %s for %s\n", parent->name, node->name);
527
+ hn->buffer_id = p_hn->buffer_id;
528
+ hn->offset = p_hn->offset;
529
+ p_hn->allocated = false; // avoid freeing the parent
530
+ return;
519
531
  }
520
532
  }
521
533
  }
522
- wsp_ggml_tallocr_alloc(alloc, node);
523
534
  }
535
+ // allocate tensor from the buffer
536
+ struct wsp_ggml_dyn_tallocr * alloc = galloc->buf_tallocs[buffer_id];
537
+ wsp_ggml_backend_buffer_type_t buft = galloc->bufts[buffer_id];
538
+ size_t size = wsp_ggml_backend_buft_get_alloc_size(buft, node);
539
+ size_t offset = wsp_ggml_dyn_tallocr_alloc(alloc, size, node);
540
+ hn->buffer_id = buffer_id;
541
+ hn->offset = offset;
524
542
  }
525
543
  }
526
544
 
527
- static void free_node(wsp_ggml_gallocr_t galloc, struct wsp_ggml_tensor * node) {
528
- wsp_ggml_tallocr_t alloc = node_tallocr(galloc, node);
545
+ static void wsp_ggml_gallocr_free_node(wsp_ggml_gallocr_t galloc, struct wsp_ggml_tensor * node) {
546
+ // graph outputs are never freed
547
+ if (node->flags & WSP_GGML_TENSOR_FLAG_OUTPUT) {
548
+ AT_PRINTF("not freeing output %s\n", node->name);
549
+ return;
550
+ }
551
+
552
+ struct hash_node * hn = wsp_ggml_gallocr_hash_get(galloc, node);
553
+ size_t offset = hn->offset;
554
+ int buffer_id = hn->buffer_id;
555
+ struct wsp_ggml_dyn_tallocr * alloc = galloc->buf_tallocs[buffer_id];
556
+ wsp_ggml_backend_buffer_type_t buft = galloc->bufts[buffer_id];
557
+ size_t size = wsp_ggml_backend_buft_get_alloc_size(buft, node);
558
+ wsp_ggml_dyn_tallocr_free_tensor(alloc, offset, size, node);
559
+ hn->allocated = false;
560
+ }
529
561
 
530
- wsp_ggml_tallocr_free_tensor(alloc, node);
562
+ static int get_node_buffer_id(const int * node_buffer_ids, int i) {
563
+ return node_buffer_ids ? node_buffer_ids[i] : 0;
531
564
  }
532
565
 
533
- static void wsp_ggml_tallocr_alloc_graph_impl(wsp_ggml_gallocr_t galloc, struct wsp_ggml_cgraph * gf) {
534
- const int * parse_seq = galloc->parse_seq;
535
- int parse_seq_len = galloc->parse_seq_len;
566
+ static void wsp_ggml_gallocr_alloc_graph_impl(wsp_ggml_gallocr_t galloc, struct wsp_ggml_cgraph * graph, const int * node_buffer_ids, const int * leaf_buffer_ids) {
567
+ // clear hash tables
568
+ wsp_ggml_hash_set_reset(&galloc->hash_set);
569
+ memset(galloc->hash_values, 0, sizeof(struct hash_node) * galloc->hash_set.size);
536
570
 
537
- // count number of children and views
538
- for (int i = 0; i < gf->n_nodes; i++) {
539
- struct wsp_ggml_tensor * node = gf->nodes[i];
571
+ // allocate leafs
572
+ // these may be tensors that the application is not using in the graph, but may still want to allocate for other purposes
573
+ for (int i = 0; i < graph->n_leafs; i++) {
574
+ struct wsp_ggml_tensor * leaf = graph->leafs[i];
575
+ wsp_ggml_gallocr_allocate_node(galloc, leaf, get_node_buffer_id(leaf_buffer_ids, i));
576
+ }
540
577
 
541
- if (wsp_ggml_is_view(node)) {
578
+ // count number of children and views
579
+ // allocate other graph inputs and leafs first to avoid overwriting them
580
+ for (int i = 0; i < graph->n_nodes; i++) {
581
+ struct wsp_ggml_tensor * node = graph->nodes[i];
582
+
583
+ // TODO: better way to add external dependencies
584
+ // WSP_GGML_OP_NONE does not appear normally in the graph nodes, but is used by ggml-backend to add dependencies to
585
+ // control when some tensors are allocated and freed. in this case, the dependencies are in `src`, but the node
586
+ // itself is never used and should not be considered a dependency
587
+ if (wsp_ggml_is_view(node) && node->op != WSP_GGML_OP_NONE) {
542
588
  struct wsp_ggml_tensor * view_src = node->view_src;
543
- hash_get(galloc, view_src)->n_views += 1;
544
- if (node->buffer == NULL && node->data != NULL) {
545
- // view of a pre-allocated tensor, didn't call init_view() yet
546
- init_view(galloc, node, true);
547
- }
589
+ wsp_ggml_gallocr_hash_get(galloc, view_src)->n_views += 1;
590
+ }
591
+
592
+ if (node->flags & WSP_GGML_TENSOR_FLAG_INPUT) {
593
+ wsp_ggml_gallocr_allocate_node(galloc, graph->nodes[i], get_node_buffer_id(node_buffer_ids, i));
548
594
  }
549
595
 
550
596
  for (int j = 0; j < WSP_GGML_MAX_SRC; j++) {
551
- struct wsp_ggml_tensor * parent = node->src[j];
552
- if (parent == NULL) {
553
- break;
597
+ struct wsp_ggml_tensor * src = node->src[j];
598
+ if (src == NULL) {
599
+ continue;
554
600
  }
555
- hash_get(galloc, parent)->n_children += 1;
556
- if (wsp_ggml_is_view(parent) && parent->buffer == NULL && parent->data != NULL) {
557
- init_view(galloc, parent, true);
601
+
602
+ wsp_ggml_gallocr_hash_get(galloc, src)->n_children += 1;
603
+
604
+ // allocate explicit inputs
605
+ if (src->flags & WSP_GGML_TENSOR_FLAG_INPUT) {
606
+ wsp_ggml_gallocr_allocate_node(galloc, src, get_node_buffer_id(node_buffer_ids, i));
558
607
  }
559
608
  }
560
- }
609
+ }
561
610
 
562
611
  // allocate tensors
563
- // if we have parse_seq then we allocate nodes following the list, and we only free nodes at barriers
564
- int last_barrier_pos = 0;
565
- int n_nodes = parse_seq_len ? parse_seq_len : gf->n_nodes;
566
-
567
- for (int ind = 0; ind < n_nodes; ind++) {
568
- // allocate a node if there is no parse_seq or this is not a barrier
569
- if (parse_seq_len == 0 || parse_seq[ind] != -1) {
570
- int i = parse_seq_len ? parse_seq[ind] : ind;
571
- struct wsp_ggml_tensor * node = gf->nodes[i];
572
-
573
- // allocate parents (leafs)
574
- for (int j = 0; j < WSP_GGML_MAX_SRC; j++) {
575
- struct wsp_ggml_tensor * parent = node->src[j];
576
-            if (parent == NULL) {
-                break;
-            }
-            allocate_node(galloc, parent);
+    for (int i = 0; i < graph->n_nodes; i++) {
+        struct wsp_ggml_tensor * node = graph->nodes[i];
+        int buffer_id = get_node_buffer_id(node_buffer_ids, i);
+
+        // allocate parents (only leafs need to be allocated at this point)
+        for (int j = 0; j < WSP_GGML_MAX_SRC; j++) {
+            struct wsp_ggml_tensor * parent = node->src[j];
+            if (parent == NULL) {
+                continue;
             }
+            wsp_ggml_gallocr_allocate_node(galloc, parent, buffer_id);
+        }
 
-        // allocate node
-        allocate_node(galloc, node);
+        // allocate node
+        wsp_ggml_gallocr_allocate_node(galloc, node, buffer_id);
 
-        AT_PRINTF("exec: %s (%s) <= ", wsp_ggml_op_name(node->op), node->name);
-        for (int j = 0; j < WSP_GGML_MAX_SRC; j++) {
-            struct wsp_ggml_tensor * parent = node->src[j];
-            if (parent == NULL) {
-                break;
-            }
-            AT_PRINTF("%s", parent->name);
-            if (j < WSP_GGML_MAX_SRC - 1 && node->src[j + 1] != NULL) {
-                AT_PRINTF(", ");
-            }
+        AT_PRINTF("exec: %s (%s) <= ", wsp_ggml_op_desc(node), node->name);
+        for (int j = 0; j < WSP_GGML_MAX_SRC; j++) {
+            struct wsp_ggml_tensor * parent = node->src[j];
+            if (parent == NULL) {
+                continue;
+            }
+            AT_PRINTF("%s", parent->name);
+            if (j < WSP_GGML_MAX_SRC - 1 && node->src[j + 1] != NULL) {
+                AT_PRINTF(", ");
             }
-        AT_PRINTF("\n");
         }
+        AT_PRINTF("\n");
 
         // update parents
-        // update immediately if there is no parse_seq
-        // update only at barriers if there is parse_seq
-        if ((parse_seq_len == 0) || parse_seq[ind] == -1) {
-            int update_start = parse_seq_len ? last_barrier_pos : ind;
-            int update_end   = parse_seq_len ? ind              : ind + 1;
-            for (int i = update_start; i < update_end; i++) {
-                int node_i = parse_seq_len ? parse_seq[i] : i;
-                struct wsp_ggml_tensor * node = gf->nodes[node_i];
-
-                for (int j = 0; j < WSP_GGML_MAX_SRC; j++) {
-                    struct wsp_ggml_tensor * parent = node->src[j];
-                    if (parent == NULL) {
-                        break;
-                    }
-                    struct hash_node * p_hn = hash_get(galloc, parent);
-                    p_hn->n_children -= 1;
-
-                    //AT_PRINTF("parent %s: %d children, %d views\n", parent->name, parent->n_children, parent->n_views);
-
-                    if (p_hn->n_children == 0 && p_hn->n_views == 0) {
-                        if (wsp_ggml_is_view(parent)) {
-                            struct wsp_ggml_tensor * view_src = parent->view_src;
-                            struct hash_node * view_src_hn = hash_get(galloc, view_src);
-                            view_src_hn->n_views -= 1;
-                            AT_PRINTF("view_src %s: %d children, %d views\n", view_src->name, view_src_hn->n_children, view_src_hn->n_views);
-                            if (view_src_hn->n_views == 0 && view_src_hn->n_children == 0) {
-                                free_node(galloc, view_src);
-                            }
-                        }
-                        else {
-                            free_node(galloc, parent);
-                        }
+        for (int j = 0; j < WSP_GGML_MAX_SRC; j++) {
+            struct wsp_ggml_tensor * parent = node->src[j];
+            if (parent == NULL) {
+                continue;
+            }
+            struct hash_node * p_hn = wsp_ggml_gallocr_hash_get(galloc, parent);
+            p_hn->n_children -= 1;
+
+            AT_PRINTF("parent %s: %d children, %d views, allocated: %d\n",
+                parent->name, p_hn->n_children, p_hn->n_views, p_hn->allocated);
+
+            if (p_hn->n_children == 0 && p_hn->n_views == 0) {
+                if (wsp_ggml_is_view(parent)) {
+                    struct wsp_ggml_tensor * view_src = parent->view_src;
+                    struct hash_node * view_src_hn = wsp_ggml_gallocr_hash_get(galloc, view_src);
+                    view_src_hn->n_views -= 1;
+                    AT_PRINTF("view_src %s: %d children, %d views\n",
+                        view_src->name, view_src_hn->n_children, view_src_hn->n_views);
+                    if (view_src_hn->n_views == 0 && view_src_hn->n_children == 0 && view_src_hn->allocated) {
+                        wsp_ggml_gallocr_free_node(galloc, view_src);
                     }
                 }
+                else if (p_hn->allocated) {
+                    wsp_ggml_gallocr_free_node(galloc, parent);
+                }
             }
             AT_PRINTF("\n");
-            if (parse_seq_len) {
-                last_barrier_pos = ind + 1;
-            }
         }
     }
 }
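Reviewer note: the rewritten loop above frees a parent's block back to the dynamic allocator as soon as its reference counts (`n_children`/`n_views`, seeded earlier in the pass) drop to zero, which is what lets later nodes reuse the same offsets. A minimal standalone sketch of that counting idea, using a hypothetical simplified type rather than the library's hash-table bookkeeping:

#include <stdio.h>

// Hypothetical stand-in for the hash_node counters used above.
struct refcount { int n_children; };

// Called once per consumer that has finished; the block becomes
// reusable when the last consumer releases it.
static void release_parent(struct refcount * rc, const char * name) {
    rc->n_children -= 1;
    if (rc->n_children == 0) {
        printf("%s: last consumer done, block returned to the pool\n", name);
    }
}

int main(void) {
    struct refcount a = { /*.n_children =*/ 2 }; // "a" feeds two nodes
    release_parent(&a, "a"); // one consumer left: still live
    release_parent(&a, "a"); // freed: its offset can now be reused
    return 0;
}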
 
-size_t wsp_ggml_gallocr_alloc_graph(wsp_ggml_gallocr_t galloc, wsp_ggml_tallocr_t talloc, struct wsp_ggml_cgraph * graph) {
-    size_t hash_size = graph->visited_hash_table.size;
+bool wsp_ggml_gallocr_reserve_n(wsp_ggml_gallocr_t galloc, struct wsp_ggml_cgraph * graph, const int * node_buffer_ids, const int * leaf_buffer_ids) {
+    size_t min_hash_size = graph->n_nodes + graph->n_leafs;
+    // add 25% margin to avoid hash collisions
+    min_hash_size += min_hash_size / 4;
+
+    // initialize hash table
+    if (galloc->hash_set.size < min_hash_size) {
+        wsp_ggml_hash_set_free(&galloc->hash_set);
+        galloc->hash_set = wsp_ggml_hash_set_new(min_hash_size);
+        WSP_GGML_ASSERT(galloc->hash_set.keys != NULL);
 
-    // check if the hash table is initialized and large enough
-    if (galloc->hash_set.size < hash_size) {
-        if (galloc->hash_set.keys != NULL) {
-            free(galloc->hash_set.keys);
+        free(galloc->hash_values);
+        galloc->hash_values = malloc(sizeof(struct hash_node) * galloc->hash_set.size);
+        WSP_GGML_ASSERT(galloc->hash_values != NULL);
+    }
+
+    // reset allocators
+    for (int i = 0; i < galloc->n_buffers; i++) {
+        wsp_ggml_dyn_tallocr_reset(galloc->buf_tallocs[i]);
+    }
+
+    // allocate in hash table
+    wsp_ggml_gallocr_alloc_graph_impl(galloc, graph, node_buffer_ids, leaf_buffer_ids);
+
+    // set the node_allocs from the hash table
+    if (galloc->n_nodes < graph->n_nodes) {
+        free(galloc->node_allocs);
+        galloc->node_allocs = calloc(graph->n_nodes, sizeof(struct node_alloc));
+        WSP_GGML_ASSERT(galloc->node_allocs != NULL);
+    }
+    galloc->n_nodes = graph->n_nodes;
+    for (int i = 0; i < graph->n_nodes; i++) {
+        struct wsp_ggml_tensor * node = graph->nodes[i];
+        struct node_alloc * node_alloc = &galloc->node_allocs[i];
+        if (node->view_src || node->data) {
+            node_alloc->dst.buffer_id = -1;
+            node_alloc->dst.offset = SIZE_MAX;
+            node_alloc->dst.size_max = 0;
+        } else {
+            struct hash_node * hn = wsp_ggml_gallocr_hash_get(galloc, node);
+            node_alloc->dst.buffer_id = hn->buffer_id;
+            node_alloc->dst.offset = hn->offset;
+            node_alloc->dst.size_max = wsp_ggml_backend_buft_get_alloc_size(galloc->bufts[hn->buffer_id], node);
+        }
+        for (int j = 0; j < WSP_GGML_MAX_SRC; j++) {
+            struct wsp_ggml_tensor * src = node->src[j];
+            if (!src || src->view_src || src->data) {
+                node_alloc->src[j].buffer_id = -1;
+                node_alloc->src[j].offset = SIZE_MAX;
+                node_alloc->src[j].size_max = 0;
+            } else {
+                struct hash_node * hn = wsp_ggml_gallocr_hash_get(galloc, src);
+                node_alloc->src[j].buffer_id = hn->buffer_id;
+                node_alloc->src[j].offset = hn->offset;
+                node_alloc->src[j].size_max = wsp_ggml_backend_buft_get_alloc_size(galloc->bufts[hn->buffer_id], src);
+            }
         }
-        if (galloc->hash_values != NULL) {
-            free(galloc->hash_values);
+    }
+    if (galloc->n_leafs < graph->n_leafs) {
+        free(galloc->leaf_allocs);
+        galloc->leaf_allocs = calloc(graph->n_leafs, sizeof(galloc->leaf_allocs[0]));
+        WSP_GGML_ASSERT(galloc->leaf_allocs != NULL);
+    }
+    galloc->n_leafs = graph->n_leafs;
+    for (int i = 0; i < graph->n_leafs; i++) {
+        struct wsp_ggml_tensor * leaf = graph->leafs[i];
+        struct hash_node * hn = wsp_ggml_gallocr_hash_get(galloc, leaf);
+        if (leaf->view_src || leaf->data) {
+            galloc->leaf_allocs[i].leaf.buffer_id = -1;
+            galloc->leaf_allocs[i].leaf.offset = SIZE_MAX;
+            galloc->leaf_allocs[i].leaf.size_max = 0;
+        } else {
+            galloc->leaf_allocs[i].leaf.buffer_id = hn->buffer_id;
+            galloc->leaf_allocs[i].leaf.offset = hn->offset;
+            galloc->leaf_allocs[i].leaf.size_max = wsp_ggml_backend_buft_get_alloc_size(galloc->bufts[hn->buffer_id], leaf);
         }
-        galloc->hash_set.keys = malloc(sizeof(struct wsp_ggml_tensor *) * hash_size);
-        galloc->hash_set.size = hash_size;
-        galloc->hash_values = malloc(sizeof(struct hash_node) * hash_size);
     }
 
-    // reset hash table
-    memset(galloc->hash_set.keys, 0, sizeof(struct wsp_ggml_tensor *) * hash_size);
-    memset(galloc->hash_values, 0, sizeof(struct hash_node) * hash_size);
+    // reallocate buffers if needed
+    for (int i = 0; i < galloc->n_buffers; i++) {
+        // if the buffer type is used multiple times, we reuse the same buffer
+        for (int j = 0; j < i; j++) {
+            if (galloc->buf_tallocs[j] == galloc->buf_tallocs[i]) {
+                galloc->buffers[i] = galloc->buffers[j];
+                break;
+            }
+        }
+
+        size_t cur_size = galloc->buffers[i] ? wsp_ggml_backend_buffer_get_size(galloc->buffers[i]) : 0;
+        size_t new_size = wsp_ggml_dyn_tallocr_max_size(galloc->buf_tallocs[i]);
 
-    galloc->talloc = talloc;
-    wsp_ggml_tallocr_alloc_graph_impl(galloc, graph);
-    galloc->talloc = NULL;
+        // even if there are no tensors allocated in this buffer, we still need to allocate it to initialize views
+        if (new_size > cur_size || galloc->buffers[i] == NULL) {
+#ifndef NDEBUG
+            WSP_GGML_LOG_DEBUG("%s: reallocating %s buffer from size %.02f MiB to %.02f MiB\n", __func__, wsp_ggml_backend_buft_name(galloc->bufts[i]), cur_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0);
+#endif
 
-    size_t max_size = wsp_ggml_tallocr_max_size(talloc);
+            wsp_ggml_backend_buffer_free(galloc->buffers[i]);
+            galloc->buffers[i] = wsp_ggml_backend_buft_alloc_buffer(galloc->bufts[i], new_size);
+            if (galloc->buffers[i] == NULL) {
+                WSP_GGML_LOG_ERROR("%s: failed to allocate %s buffer of size %zu\n", __func__, wsp_ggml_backend_buft_name(galloc->bufts[i]), new_size);
+                return false;
+            }
+            wsp_ggml_backend_buffer_set_usage(galloc->buffers[i], WSP_GGML_BACKEND_BUFFER_USAGE_COMPUTE);
+        }
+    }
 
-    return max_size;
+    return true;
 }
 
-void wsp_ggml_gallocr_alloc_graph_n(wsp_ggml_gallocr_t galloc, struct wsp_ggml_cgraph * graph, struct wsp_ggml_hash_set hash_set, wsp_ggml_tallocr_t * hash_node_talloc) {
-    const size_t hash_size = hash_set.size;
+bool wsp_ggml_gallocr_reserve(wsp_ggml_gallocr_t galloc, struct wsp_ggml_cgraph *graph) {
+    return wsp_ggml_gallocr_reserve_n(galloc, graph, NULL, NULL);
+}
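Reviewer note: these two entry points replace the old measure-allocator flow. The intended usage, sketched below under the assumption of the usual wsp_-prefixed gallocr API from ggml-alloc.h (`wsp_ggml_gallocr_new`/`_free`); the `build_*` helpers are hypothetical stand-ins for the caller's graph construction:

#include <stdbool.h>

#include "ggml-alloc.h"
#include "ggml-backend.h"

// Hypothetical helpers: build the largest graph the app will ever run,
// and the graph for the current evaluation, respectively.
extern struct wsp_ggml_cgraph * build_worst_case_graph(void);
extern struct wsp_ggml_cgraph * build_current_graph(void);

bool setup_and_bind(wsp_ggml_backend_buffer_type_t buft) {
    wsp_ggml_gallocr_t galloc = wsp_ggml_gallocr_new(buft);

    // reserve once against the worst case so later calls never reallocate
    if (!wsp_ggml_gallocr_reserve(galloc, build_worst_case_graph())) {
        wsp_ggml_gallocr_free(galloc);
        return false;
    }

    // per evaluation: bind the current graph's tensors to the reserved buffers
    if (!wsp_ggml_gallocr_alloc_graph(galloc, build_current_graph())) {
        wsp_ggml_gallocr_free(galloc);
        return false;
    }

    // ... set inputs and compute the graph here ...

    wsp_ggml_gallocr_free(galloc);
    return true;
}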
 
-    WSP_GGML_ASSERT(hash_size >= (size_t)(graph->n_nodes + graph->n_leafs));
+static void wsp_ggml_gallocr_init_tensor(wsp_ggml_gallocr_t galloc, struct wsp_ggml_tensor * tensor, struct tensor_alloc * tensor_alloc) {
+    int buffer_id = tensor_alloc->buffer_id;
+    assert(tensor->data || tensor->view_src || wsp_ggml_backend_buffer_get_alloc_size(galloc->buffers[buffer_id], tensor) <= tensor_alloc->size_max);
 
-    galloc->talloc = NULL;
+    if (tensor->view_src != NULL) {
+        if (tensor->buffer == NULL) {
+            assert(tensor_alloc->offset == SIZE_MAX);
+            if (tensor->view_src->buffer == NULL) {
+                // this tensor was allocated without ggml-backend
+                return;
+            }
+            wsp_ggml_backend_view_init(tensor);
+        }
+    } else {
+        if (tensor->data == NULL) {
+            assert(tensor_alloc->offset != SIZE_MAX);
+            assert(wsp_ggml_backend_buffer_get_alloc_size(galloc->buffers[buffer_id], tensor) <= tensor_alloc->size_max);
+            void * base = wsp_ggml_backend_buffer_get_base(galloc->buffers[buffer_id]);
+            void * addr = (char *)base + tensor_alloc->offset;
+            wsp_ggml_backend_tensor_alloc(galloc->buffers[buffer_id], tensor, addr);
+        } else {
+            if (tensor->buffer == NULL) {
+                // this tensor was allocated without ggml-backend
+                return;
+            }
+        }
+    }
+}
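Reviewer note: re-binding a tensor here is plain pointer arithmetic over the reserved buffer, base plus the offset recorded at reserve time. A self-contained sketch of that addressing scheme, with a simplified arena type and illustrative offsets rather than the library's buffer objects:

#include <stddef.h>
#include <stdio.h>

// Simplified stand-in for a backend buffer: a base pointer plus a size.
struct arena { char * base; size_t size; };

// Re-binding is just base + offset; the offset was recorded during reserve.
static void * bind_at(const struct arena * a, size_t offset) {
    return a->base + offset;
}

int main(void) {
    char storage[1024];
    struct arena a = { storage, sizeof(storage) };
    void * t0 = bind_at(&a, 0);   // first tensor lives at offset 0
    void * t1 = bind_at(&a, 256); // second tensor at its recorded offset
    printf("t0=%p t1=%p\n", t0, t1);
    return 0;
}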
 
-    // alloc hash_values if needed
-    if (galloc->hash_values == NULL || galloc->hash_values_size < hash_size) {
-        free(galloc->hash_values);
-        galloc->hash_values = malloc(sizeof(struct hash_node) * hash_size);
-        galloc->hash_values_size = hash_size;
+static bool wsp_ggml_gallocr_node_needs_realloc(wsp_ggml_gallocr_t galloc, struct wsp_ggml_tensor * node, struct tensor_alloc * talloc) {
+    size_t node_size = 0;
+    if (!node->data && !node->view_src) {
+        // If we previously had data but don't now then reallocate
+        if (talloc->buffer_id < 0) {
+            return false;
+        }
+        node_size = wsp_ggml_backend_buft_get_alloc_size(galloc->bufts[talloc->buffer_id], node);
     }
+    return talloc->size_max >= node_size;
+}
 
-    // free hash_set.keys if needed
-    if (galloc->hash_set.keys != NULL) {
-        free(galloc->hash_set.keys);
+static bool wsp_ggml_gallocr_needs_realloc(wsp_ggml_gallocr_t galloc, struct wsp_ggml_cgraph * graph) {
+    if (galloc->n_nodes != graph->n_nodes) {
+#ifndef NDEBUG
+        WSP_GGML_LOG_DEBUG("%s: graph has different number of nodes\n", __func__);
+#endif
+        return true;
     }
-    galloc->hash_set = hash_set;
 
-    // reset hash values
-    memset(galloc->hash_values, 0, sizeof(struct hash_node) * hash_size);
+    if (galloc->n_leafs != graph->n_leafs) {
+#ifndef NDEBUG
+        WSP_GGML_LOG_DEBUG("%s: graph has different number of leafs\n", __func__);
+#endif
+        return true;
+    }
 
-    galloc->hash_allocs = hash_node_talloc;
+    for (int i = 0; i < graph->n_nodes; i++) {
+        struct wsp_ggml_tensor * node = graph->nodes[i];
+        struct node_alloc * node_alloc = &galloc->node_allocs[i];
 
-    wsp_ggml_tallocr_alloc_graph_impl(galloc, graph);
+        if (!wsp_ggml_gallocr_node_needs_realloc(galloc, node, &node_alloc->dst)) {
+#ifndef NDEBUG
+            WSP_GGML_LOG_DEBUG("%s: node %s is not valid\n", __func__, node->name);
+#endif
+            return true;
+        }
 
-    // remove unowned resources
-    galloc->hash_set.keys = NULL;
-    galloc->hash_allocs = NULL;
-}
+        for (int j = 0; j < WSP_GGML_MAX_SRC; j++) {
+            struct wsp_ggml_tensor * src = node->src[j];
+            if (src == NULL) {
+                continue;
+            }
+            if (!wsp_ggml_gallocr_node_needs_realloc(galloc, src, &node_alloc->src[j])) {
+#ifndef NDEBUG
+                WSP_GGML_LOG_DEBUG("%s: src %d (%s) of node %s is not valid\n", __func__, j, src->name, node->name);
+#endif
+                return true;
+            }
+        }
+    }
 
-// legacy API wrapper
+    return false;
+}
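Reviewer note: the validity check above is deliberately one-sided (`size_max >= node_size`), so a graph that shrinks relative to the reserved one still passes, while any tensor that outgrows its recorded slot forces a new reserve. A toy illustration of that predicate with made-up sizes:

#include <stddef.h>
#include <stdio.h>

// One-sided fit check mirroring the size_max comparison above:
// shrinking reuses the reservation, growing forces a new reserve.
static int fits(size_t reserved_max, size_t needed_now) {
    return reserved_max >= needed_now;
}

int main(void) {
    printf("batch 512 -> 128: %s\n", fits(512 * 4096, 128 * 4096) ? "reuse" : "realloc");
    printf("batch 512 -> 640: %s\n", fits(512 * 4096, 640 * 4096) ? "reuse" : "realloc");
    return 0;
}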
 
-struct wsp_ggml_allocr {
-    wsp_ggml_tallocr_t talloc;
-    wsp_ggml_gallocr_t galloc;
-};
+bool wsp_ggml_gallocr_alloc_graph(wsp_ggml_gallocr_t galloc, struct wsp_ggml_cgraph * graph) {
+    if (wsp_ggml_gallocr_needs_realloc(galloc, graph)) {
+        if (galloc->n_buffers == 1) {
+#ifndef NDEBUG
+            WSP_GGML_LOG_DEBUG("%s: reallocating buffers automatically\n", __func__);
+#endif
+            if (!wsp_ggml_gallocr_reserve(galloc, graph)) {
+                return false;
+            }
+        } else {
+#ifndef NDEBUG
+            WSP_GGML_LOG_DEBUG("%s: cannot reallocate multi buffer graph automatically, call reserve\n", __func__);
+#endif
+            return false;
+        }
+    }
 
-static wsp_ggml_allocr_t wsp_ggml_allocr_new_impl(wsp_ggml_tallocr_t talloc) {
-    wsp_ggml_allocr_t alloc = (wsp_ggml_allocr_t)malloc(sizeof(struct wsp_ggml_allocr));
-    *alloc = (struct wsp_ggml_allocr) {
-        /*.talloc = */ talloc,
-        /*.galloc = */ wsp_ggml_gallocr_new(),
-    };
-    return alloc;
-}
+    // reset buffers
+    for (int i = 0; i < galloc->n_buffers; i++) {
+        if (galloc->buffers[i] != NULL) {
+            wsp_ggml_backend_buffer_reset(galloc->buffers[i]);
+        }
+    }
 
-wsp_ggml_allocr_t wsp_ggml_allocr_new(void * data, size_t size, size_t alignment) {
-    return wsp_ggml_allocr_new_impl(wsp_ggml_tallocr_new(data, size, alignment));
-}
+    // allocate the graph tensors from the previous assignments
+    // leafs
+    for (int i = 0; i < graph->n_leafs; i++) {
+        struct wsp_ggml_tensor * leaf = graph->leafs[i];
+        struct leaf_alloc * leaf_alloc = &galloc->leaf_allocs[i];
+        wsp_ggml_gallocr_init_tensor(galloc, leaf, &leaf_alloc->leaf);
+    }
+    // nodes
+    for (int i = 0; i < graph->n_nodes; i++) {
+        struct wsp_ggml_tensor * node = graph->nodes[i];
+        struct node_alloc * node_alloc = &galloc->node_allocs[i];
+        for (int j = 0; j < WSP_GGML_MAX_SRC; j++) {
+            struct wsp_ggml_tensor * src = node->src[j];
+            if (src == NULL) {
+                continue;
+            }
+            wsp_ggml_gallocr_init_tensor(galloc, src, &node_alloc->src[j]);
+        }
+        wsp_ggml_gallocr_init_tensor(galloc, node, &node_alloc->dst);
+    }
 
-wsp_ggml_allocr_t wsp_ggml_allocr_new_measure(size_t alignment) {
-    return wsp_ggml_allocr_new_impl(wsp_ggml_tallocr_new_measure(alignment));
+    return true;
 }
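Reviewer note: because `wsp_ggml_gallocr_alloc_graph` resets the buffers and rebinds every tensor, data written to a graph tensor before the call targets stale addresses; inputs must be set afterwards. A hedged sketch of the per-evaluation order, assuming the usual wsp_-prefixed backend API (`wsp_ggml_backend_tensor_set`, `wsp_ggml_backend_graph_compute`); the graph and input tensor come from the caller:

#include <stdbool.h>

#include "ggml.h"
#include "ggml-alloc.h"
#include "ggml-backend.h"

// Sketch: allocate first, then write inputs, then compute.
bool eval_once(wsp_ggml_backend_t backend, wsp_ggml_gallocr_t galloc,
               struct wsp_ggml_cgraph * gf, struct wsp_ggml_tensor * input,
               const float * data, size_t nbytes) {
    if (!wsp_ggml_gallocr_alloc_graph(galloc, gf)) {
        return false; // multi-buffer gallocr: the caller must reserve first
    }
    // safe only now: input->data points into the freshly bound compute buffer
    wsp_ggml_backend_tensor_set(input, data, 0, nbytes);
    return wsp_ggml_backend_graph_compute(backend, gf) == WSP_GGML_STATUS_SUCCESS;
}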
 
-wsp_ggml_allocr_t wsp_ggml_allocr_new_from_buffer(struct wsp_ggml_backend_buffer * buffer) {
-    return wsp_ggml_allocr_new_impl(wsp_ggml_tallocr_new_from_buffer(buffer));
-}
+size_t wsp_ggml_gallocr_get_buffer_size(wsp_ggml_gallocr_t galloc, int buffer_id) {
+    WSP_GGML_ASSERT(buffer_id >= 0 && buffer_id < galloc->n_buffers);
 
-wsp_ggml_allocr_t wsp_ggml_allocr_new_from_backend(struct wsp_ggml_backend * backend, size_t size) {
-    return wsp_ggml_allocr_new_impl(wsp_ggml_tallocr_new_from_backend(backend, size));
-}
+    if (galloc->buffers[buffer_id] == NULL) {
+        return 0;
+    }
 
-wsp_ggml_allocr_t wsp_ggml_allocr_new_measure_from_backend(struct wsp_ggml_backend * backend) {
-    return wsp_ggml_allocr_new_impl(wsp_ggml_tallocr_new_measure_from_backend(backend));
-}
+    for (int i = 0; i < buffer_id; i++) {
+        if (galloc->buffers[i] == galloc->buffers[buffer_id]) {
+            // this buffer is the same as a previous one due to the same buffer type being used multiple times
+            // only return the buffer size the first time it appears to avoid double counting
+            return 0;
+        }
+    }
 
-struct wsp_ggml_backend_buffer * wsp_ggml_allocr_get_buffer(wsp_ggml_allocr_t alloc) {
-    return wsp_ggml_tallocr_get_buffer(alloc->talloc);
+    return wsp_ggml_backend_buffer_get_size(galloc->buffers[buffer_id]);
 }
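Reviewer note: since aliased buffer types report 0 after their first occurrence, total compute memory is obtainable by a plain sum over all buffer ids. A small sketch; the `n_buffers` count is assumed to come from the caller's gallocr setup (the number of buffer types passed at creation):

#include "ggml-alloc.h"

// Duplicate buffer types alias one buffer and report 0 after the first
// occurrence (see above), so a plain sum never double counts.
static size_t gallocr_total_size(wsp_ggml_gallocr_t galloc, int n_buffers) {
    size_t total = 0;
    for (int i = 0; i < n_buffers; i++) {
        total += wsp_ggml_gallocr_get_buffer_size(galloc, i);
    }
    return total;
}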
 
-void wsp_ggml_allocr_set_parse_seq(wsp_ggml_allocr_t alloc, const int * list, int n) {
-    wsp_ggml_gallocr_set_parse_seq(alloc->galloc, list, n);
-}
+// utils
 
-void wsp_ggml_allocr_free(wsp_ggml_allocr_t alloc) {
-    if (alloc == NULL) {
-        return;
+static void free_buffers(wsp_ggml_backend_buffer_t ** buffers, const size_t * n_buffers) {
+    for (size_t i = 0; i < *n_buffers; i++) {
+        wsp_ggml_backend_buffer_free((*buffers)[i]);
     }
-
-    wsp_ggml_gallocr_free(alloc->galloc);
-    wsp_ggml_tallocr_free(alloc->talloc);
-    free(alloc);
+    free(*buffers);
 }
 
-bool wsp_ggml_allocr_is_measure(wsp_ggml_allocr_t alloc) {
-    return wsp_ggml_tallocr_is_measure(alloc->talloc);
-}
+static bool alloc_tensor_range(struct wsp_ggml_context * ctx,
+        struct wsp_ggml_tensor * first, struct wsp_ggml_tensor * last,
+        wsp_ggml_backend_buffer_type_t buft, size_t size,
+        wsp_ggml_backend_buffer_t ** buffers, size_t * n_buffers) {
 
-void wsp_ggml_allocr_reset(wsp_ggml_allocr_t alloc) {
-    wsp_ggml_tallocr_reset(alloc->talloc);
-}
+    wsp_ggml_backend_buffer_t buffer = wsp_ggml_backend_buft_alloc_buffer(buft, size);
+    if (buffer == NULL) {
+        WSP_GGML_LOG_ERROR("%s: failed to allocate %s buffer of size %zu\n", __func__, wsp_ggml_backend_buft_name(buft), size);
+        free_buffers(buffers, n_buffers);
+        return false;
+    }
 
-void wsp_ggml_allocr_alloc(wsp_ggml_allocr_t alloc, struct wsp_ggml_tensor * tensor) {
-    wsp_ggml_tallocr_alloc(alloc->talloc, tensor);
-}
+    *buffers = realloc(*buffers, sizeof(wsp_ggml_backend_buffer_t) * (*n_buffers + 1));
+    (*buffers)[(*n_buffers)++] = buffer;
 
-size_t wsp_ggml_allocr_max_size(wsp_ggml_allocr_t alloc) {
-    return wsp_ggml_tallocr_max_size(alloc->talloc);
-}
+    struct wsp_ggml_tallocr tallocr = wsp_ggml_tallocr_new(buffer);
+
+    for (struct wsp_ggml_tensor * t = first; t != last; t = wsp_ggml_get_next_tensor(ctx, t)) {
+        enum wsp_ggml_status status = WSP_GGML_STATUS_SUCCESS;
+        if (t->data == NULL) {
+            if (t->view_src == NULL) {
+                status = wsp_ggml_tallocr_alloc(&tallocr, t);
+            } else if (t->buffer == NULL) {
+                status = wsp_ggml_backend_view_init(t);
+            }
+        } else {
+            if (t->view_src != NULL && t->buffer == NULL) {
+                // view of a pre-allocated tensor
+                status = wsp_ggml_backend_view_init(t);
+            }
+        }
+        if (status != WSP_GGML_STATUS_SUCCESS) {
+            WSP_GGML_LOG_ERROR("%s: failed to initialize tensor %s\n", __func__, t->name);
+            free_buffers(buffers, n_buffers);
+            return false;
+        }
+    }
 
-size_t wsp_ggml_allocr_alloc_graph(wsp_ggml_allocr_t alloc, struct wsp_ggml_cgraph * graph) {
-    return wsp_ggml_gallocr_alloc_graph(alloc->galloc, alloc->talloc, graph);
+    return true;
 }
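Reviewer note: alloc_tensor_range exists because some backends cap a single buffer's size (`wsp_ggml_backend_buft_get_max_size`), so a context's tensors are split greedily into ranges that each fit one buffer. The arithmetic in isolation, with hypothetical sizes for illustration:

#include <stdio.h>

// Greedy split: start a new buffer whenever the next padded tensor
// would push the current buffer past max_size.
int main(void) {
    size_t sizes[] = { 400, 400, 400, 300 }; // padded tensor sizes
    size_t max_size = 1000;
    size_t cur = 0;
    int n_buffers = 0;
    for (int i = 0; i < 4; i++) {
        if (cur > 0 && cur + sizes[i] > max_size) {
            printf("buffer %d: %zu bytes\n", n_buffers++, cur);
            cur = 0;
        }
        cur += sizes[i];
    }
    if (cur > 0) {
        printf("buffer %d: %zu bytes\n", n_buffers++, cur); // 800, then 700
    }
    return 0;
}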
 
-// utils
 wsp_ggml_backend_buffer_t wsp_ggml_backend_alloc_ctx_tensors_from_buft(struct wsp_ggml_context * ctx, wsp_ggml_backend_buffer_type_t buft) {
     WSP_GGML_ASSERT(wsp_ggml_get_no_alloc(ctx) == true);
 
     size_t alignment = wsp_ggml_backend_buft_get_alignment(buft);
+    size_t max_size = wsp_ggml_backend_buft_get_max_size(buft);
+
+    wsp_ggml_backend_buffer_t * buffers = NULL;
+    size_t n_buffers = 0;
 
-    size_t nbytes = 0;
-    for (struct wsp_ggml_tensor * t = wsp_ggml_get_first_tensor(ctx); t != NULL; t = wsp_ggml_get_next_tensor(ctx, t)) {
+    size_t cur_buf_size = 0;
+    struct wsp_ggml_tensor * first = wsp_ggml_get_first_tensor(ctx);
+    for (struct wsp_ggml_tensor * t = first; t != NULL; t = wsp_ggml_get_next_tensor(ctx, t)) {
+        size_t this_size = 0;
         if (t->data == NULL && t->view_src == NULL) {
-            nbytes += WSP_GGML_PAD(wsp_ggml_backend_buft_get_alloc_size(buft, t), alignment);
+            this_size = WSP_GGML_PAD(wsp_ggml_backend_buft_get_alloc_size(buft, t), alignment);
+        }
+
+        if (cur_buf_size > 0 && (cur_buf_size + this_size) > max_size) {
+            // allocate tensors in the current buffer
+            if (!alloc_tensor_range(ctx, first, t, buft, cur_buf_size, &buffers, &n_buffers)) {
+                return NULL;
+            }
+            first = t;
+            cur_buf_size = this_size;
+        } else {
+            cur_buf_size += this_size;
         }
     }
 
-    if (nbytes == 0) {
-        // all the tensors in the context are already allocated
-#ifndef NDEBUG
-        fprintf(stderr, "%s: all tensors in the context are already allocated\n", __func__);
-#endif
-        return NULL;
+    // allocate remaining tensors
+    if (cur_buf_size > 0) {
+        if (!alloc_tensor_range(ctx, first, NULL, buft, cur_buf_size, &buffers, &n_buffers)) {
+            return NULL;
+        }
     }
 
-    wsp_ggml_backend_buffer_t buffer = wsp_ggml_backend_buft_alloc_buffer(buft, nbytes);
-    if (buffer == NULL) {
-        // failed to allocate buffer
+    if (n_buffers == 0) {
 #ifndef NDEBUG
-        fprintf(stderr, "%s: failed to allocate buffer\n", __func__);
+        WSP_GGML_LOG_DEBUG("%s: all tensors in the context are already allocated\n", __func__);
 #endif
         return NULL;
     }
 
-    wsp_ggml_tallocr_t tallocr = wsp_ggml_tallocr_new_from_buffer(buffer);
-
-    for (struct wsp_ggml_tensor * t = wsp_ggml_get_first_tensor(ctx); t != NULL; t = wsp_ggml_get_next_tensor(ctx, t)) {
-        if (t->data == NULL) {
-            if (t->view_src == NULL) {
-                wsp_ggml_tallocr_alloc(tallocr, t);
-            } else {
-                wsp_ggml_backend_view_init(buffer, t);
-            }
-        } else {
-            if (t->view_src != NULL) {
-                // view of a pre-allocated tensor
-                wsp_ggml_backend_view_init(buffer, t);
-            }
-        }
+    wsp_ggml_backend_buffer_t buffer;
+    if (n_buffers == 1) {
+        buffer = buffers[0];
+    } else {
+        buffer = wsp_ggml_backend_multi_buffer_alloc_buffer(buffers, n_buffers);
     }
-
-    wsp_ggml_tallocr_free(tallocr);
-
+    free(buffers);
     return buffer;
 }
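Reviewer note: typical usage of this entry point is to create a no-alloc context, define tensor metadata, then let the helper size, allocate, and bind everything in one call. A hedged sketch; the shapes are illustrative, and the header location of `wsp_ggml_backend_cpu_buffer_type()` is an assumption (ggml-cpu.h in this package layout):

#include "ggml.h"
#include "ggml-alloc.h"
#include "ggml-backend.h"
#include "ggml-cpu.h" // assumed home of wsp_ggml_backend_cpu_buffer_type()

// Sketch: a no-alloc context only records tensor metadata; the helper
// above then sizes, allocates (splitting at the buffer type's max size)
// and binds every tensor in one call. The caller must keep ctx alive
// for as long as the tensors are used.
wsp_ggml_backend_buffer_t alloc_example_tensors(void) {
    struct wsp_ggml_init_params params = {
        /*.mem_size   =*/ 2 * wsp_ggml_tensor_overhead(),
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ true, // required by the WSP_GGML_ASSERT above
    };
    struct wsp_ggml_context * ctx = wsp_ggml_init(params);

    // illustrative shapes; no data is attached yet
    wsp_ggml_new_tensor_2d(ctx, WSP_GGML_TYPE_F32, 512, 512);
    wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, 512);

    return wsp_ggml_backend_alloc_ctx_tensors_from_buft(ctx, wsp_ggml_backend_cpu_buffer_type());
}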