whisper.rn 0.4.0-rc.9 → 0.4.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (183) hide show
  1. package/README.md +5 -1
  2. package/android/build.gradle +12 -3
  3. package/android/src/main/CMakeLists.txt +43 -13
  4. package/android/src/main/java/com/rnwhisper/WhisperContext.java +33 -35
  5. package/android/src/main/jni.cpp +9 -0
  6. package/android/src/main/jniLibs/arm64-v8a/librnwhisper.so +0 -0
  7. package/android/src/main/jniLibs/arm64-v8a/librnwhisper_v8fp16_va_2.so +0 -0
  8. package/android/src/main/jniLibs/armeabi-v7a/librnwhisper.so +0 -0
  9. package/android/src/main/jniLibs/armeabi-v7a/librnwhisper_vfpv4.so +0 -0
  10. package/android/src/main/jniLibs/x86_64/librnwhisper.so +0 -0
  11. package/android/src/main/jniLibs/x86_64/librnwhisper_x86_64.so +0 -0
  12. package/cpp/coreml/whisper-compat.h +10 -0
  13. package/cpp/coreml/whisper-compat.m +35 -0
  14. package/cpp/coreml/whisper-decoder-impl.h +27 -15
  15. package/cpp/coreml/whisper-decoder-impl.m +36 -10
  16. package/cpp/coreml/whisper-encoder-impl.h +21 -9
  17. package/cpp/coreml/whisper-encoder-impl.m +29 -3
  18. package/cpp/ggml-alloc.c +39 -37
  19. package/cpp/ggml-alloc.h +1 -1
  20. package/cpp/ggml-backend-impl.h +55 -27
  21. package/cpp/ggml-backend-reg.cpp +591 -0
  22. package/cpp/ggml-backend.cpp +336 -955
  23. package/cpp/ggml-backend.h +70 -42
  24. package/cpp/ggml-common.h +57 -49
  25. package/cpp/ggml-cpp.h +39 -0
  26. package/cpp/ggml-cpu/amx/amx.cpp +221 -0
  27. package/cpp/ggml-cpu/amx/amx.h +8 -0
  28. package/cpp/ggml-cpu/amx/common.h +91 -0
  29. package/cpp/ggml-cpu/amx/mmq.cpp +2511 -0
  30. package/cpp/ggml-cpu/amx/mmq.h +10 -0
  31. package/cpp/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
  32. package/cpp/ggml-cpu/arch/arm/quants.c +4113 -0
  33. package/cpp/ggml-cpu/arch/arm/repack.cpp +2162 -0
  34. package/cpp/ggml-cpu/arch/x86/cpu-feats.cpp +327 -0
  35. package/cpp/ggml-cpu/arch/x86/quants.c +4310 -0
  36. package/cpp/ggml-cpu/arch/x86/repack.cpp +3284 -0
  37. package/cpp/ggml-cpu/arch-fallback.h +184 -0
  38. package/cpp/ggml-cpu/binary-ops.cpp +158 -0
  39. package/cpp/ggml-cpu/binary-ops.h +16 -0
  40. package/cpp/ggml-cpu/common.h +72 -0
  41. package/cpp/ggml-cpu/ggml-cpu-impl.h +511 -0
  42. package/cpp/ggml-cpu/ggml-cpu.c +3473 -0
  43. package/cpp/ggml-cpu/ggml-cpu.cpp +671 -0
  44. package/cpp/ggml-cpu/ops.cpp +9085 -0
  45. package/cpp/ggml-cpu/ops.h +111 -0
  46. package/cpp/ggml-cpu/quants.c +1157 -0
  47. package/cpp/ggml-cpu/quants.h +89 -0
  48. package/cpp/ggml-cpu/repack.cpp +1570 -0
  49. package/cpp/ggml-cpu/repack.h +98 -0
  50. package/cpp/ggml-cpu/simd-mappings.h +1006 -0
  51. package/cpp/ggml-cpu/traits.cpp +36 -0
  52. package/cpp/ggml-cpu/traits.h +38 -0
  53. package/cpp/ggml-cpu/unary-ops.cpp +186 -0
  54. package/cpp/ggml-cpu/unary-ops.h +28 -0
  55. package/cpp/ggml-cpu/vec.cpp +321 -0
  56. package/cpp/ggml-cpu/vec.h +973 -0
  57. package/cpp/ggml-cpu.h +143 -0
  58. package/cpp/ggml-impl.h +417 -23
  59. package/cpp/ggml-metal-impl.h +622 -0
  60. package/cpp/ggml-metal.h +9 -9
  61. package/cpp/ggml-metal.m +3451 -1344
  62. package/cpp/ggml-opt.cpp +1037 -0
  63. package/cpp/ggml-opt.h +237 -0
  64. package/cpp/ggml-quants.c +296 -10818
  65. package/cpp/ggml-quants.h +78 -125
  66. package/cpp/ggml-threading.cpp +12 -0
  67. package/cpp/ggml-threading.h +14 -0
  68. package/cpp/ggml-whisper-sim.metallib +0 -0
  69. package/cpp/ggml-whisper.metallib +0 -0
  70. package/cpp/ggml.c +4633 -21450
  71. package/cpp/ggml.h +320 -661
  72. package/cpp/gguf.cpp +1347 -0
  73. package/cpp/gguf.h +202 -0
  74. package/cpp/rn-whisper.cpp +4 -11
  75. package/cpp/whisper-arch.h +197 -0
  76. package/cpp/whisper.cpp +2022 -495
  77. package/cpp/whisper.h +75 -18
  78. package/ios/CMakeLists.txt +95 -0
  79. package/ios/RNWhisper.h +5 -0
  80. package/ios/RNWhisperAudioUtils.m +4 -0
  81. package/ios/RNWhisperContext.h +5 -0
  82. package/ios/RNWhisperContext.mm +4 -2
  83. package/ios/rnwhisper.xcframework/Info.plist +74 -0
  84. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-alloc.h +76 -0
  85. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend-impl.h +255 -0
  86. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend.h +354 -0
  87. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-common.h +1861 -0
  88. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-cpp.h +39 -0
  89. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-cpu.h +143 -0
  90. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-impl.h +603 -0
  91. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-metal-impl.h +622 -0
  92. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-metal.h +66 -0
  93. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-opt.h +237 -0
  94. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-quants.h +100 -0
  95. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-threading.h +14 -0
  96. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml.h +2221 -0
  97. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/gguf.h +202 -0
  98. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/rn-audioutils.h +14 -0
  99. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/rn-whisper-log.h +11 -0
  100. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/rn-whisper.h +52 -0
  101. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/whisper-arch.h +197 -0
  102. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/whisper.h +739 -0
  103. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Info.plist +0 -0
  104. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
  105. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/rnwhisper +0 -0
  106. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-alloc.h +76 -0
  107. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend-impl.h +255 -0
  108. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +354 -0
  109. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-common.h +1861 -0
  110. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpp.h +39 -0
  111. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpu.h +143 -0
  112. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +603 -0
  113. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal-impl.h +622 -0
  114. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal.h +66 -0
  115. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-opt.h +237 -0
  116. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-quants.h +100 -0
  117. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-threading.h +14 -0
  118. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +2221 -0
  119. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/gguf.h +202 -0
  120. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-audioutils.h +14 -0
  121. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-whisper-log.h +11 -0
  122. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-whisper.h +52 -0
  123. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/whisper-arch.h +197 -0
  124. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/whisper.h +739 -0
  125. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
  126. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +101 -0
  127. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
  128. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
  129. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-alloc.h +76 -0
  130. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend-impl.h +255 -0
  131. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend.h +354 -0
  132. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-common.h +1861 -0
  133. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-cpp.h +39 -0
  134. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-cpu.h +143 -0
  135. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-impl.h +603 -0
  136. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-metal-impl.h +622 -0
  137. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-metal.h +66 -0
  138. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-opt.h +237 -0
  139. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-quants.h +100 -0
  140. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-threading.h +14 -0
  141. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml.h +2221 -0
  142. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/gguf.h +202 -0
  143. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/rn-audioutils.h +14 -0
  144. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/rn-whisper-log.h +11 -0
  145. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/rn-whisper.h +52 -0
  146. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/whisper-arch.h +197 -0
  147. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/whisper.h +739 -0
  148. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Info.plist +0 -0
  149. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
  150. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/rnwhisper +0 -0
  151. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-alloc.h +76 -0
  152. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend-impl.h +255 -0
  153. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +354 -0
  154. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-common.h +1861 -0
  155. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpp.h +39 -0
  156. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpu.h +143 -0
  157. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +603 -0
  158. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal-impl.h +622 -0
  159. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal.h +66 -0
  160. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-opt.h +237 -0
  161. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-quants.h +100 -0
  162. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-threading.h +14 -0
  163. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +2221 -0
  164. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/gguf.h +202 -0
  165. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-audioutils.h +14 -0
  166. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-whisper-log.h +11 -0
  167. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-whisper.h +52 -0
  168. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/whisper-arch.h +197 -0
  169. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/whisper.h +739 -0
  170. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
  171. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +101 -0
  172. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
  173. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
  174. package/jest/mock.js +5 -0
  175. package/lib/commonjs/version.json +1 -1
  176. package/lib/module/version.json +1 -1
  177. package/package.json +10 -6
  178. package/src/version.json +1 -1
  179. package/whisper-rn.podspec +11 -18
  180. package/cpp/README.md +0 -4
  181. package/cpp/ggml-aarch64.c +0 -3209
  182. package/cpp/ggml-aarch64.h +0 -39
  183. package/cpp/ggml-cpu-impl.h +0 -614
@@ -8,6 +8,7 @@
8
8
  #error This file must be compiled with automatic reference counting enabled (-fobjc-arc)
9
9
  #endif
10
10
 
11
+ #import "whisper-compat.h"
11
12
  #import "whisper-encoder-impl.h"
12
13
 
13
14
  @implementation whisper_encoder_implInput
@@ -76,10 +77,13 @@
76
77
  Such application may want to use `-[MLModel initWithContentsOfURL:configuration:error:]` and `+URLOfModelInThisBundle` to create a MLModel object to pass-in.
77
78
  */
78
79
  - (instancetype)initWithMLModel:(MLModel *)model {
80
+ if (model == nil) {
81
+ return nil;
82
+ }
79
83
  self = [super init];
80
- if (!self) { return nil; }
81
- _model = model;
82
- if (_model == nil) { return nil; }
84
+ if (self != nil) {
85
+ _model = model;
86
+ }
83
87
  return self;
84
88
  }
85
89
 
@@ -176,6 +180,28 @@
176
180
  return [[whisper_encoder_implOutput alloc] initWithOutput:(MLMultiArray *)[outFeatures featureValueForName:@"output"].multiArrayValue];
177
181
  }
178
182
 
183
+ - (void)predictionFromFeatures:(whisper_encoder_implInput *)input completionHandler:(void (^)(whisper_encoder_implOutput * _Nullable output, NSError * _Nullable error))completionHandler {
184
+ [self.model predictionFromFeatures:input completionHandler:^(id<MLFeatureProvider> prediction, NSError *predictionError) {
185
+ if (prediction != nil) {
186
+ whisper_encoder_implOutput *output = [[whisper_encoder_implOutput alloc] initWithOutput:(MLMultiArray *)[prediction featureValueForName:@"output"].multiArrayValue];
187
+ completionHandler(output, predictionError);
188
+ } else {
189
+ completionHandler(nil, predictionError);
190
+ }
191
+ }];
192
+ }
193
+
194
+ - (void)predictionFromFeatures:(whisper_encoder_implInput *)input options:(MLPredictionOptions *)options completionHandler:(void (^)(whisper_encoder_implOutput * _Nullable output, NSError * _Nullable error))completionHandler {
195
+ [self.model predictionFromFeatures:input options:options completionHandler:^(id<MLFeatureProvider> prediction, NSError *predictionError) {
196
+ if (prediction != nil) {
197
+ whisper_encoder_implOutput *output = [[whisper_encoder_implOutput alloc] initWithOutput:(MLMultiArray *)[prediction featureValueForName:@"output"].multiArrayValue];
198
+ completionHandler(output, predictionError);
199
+ } else {
200
+ completionHandler(nil, predictionError);
201
+ }
202
+ }];
203
+ }
204
+
179
205
  - (nullable whisper_encoder_implOutput *)predictionFromLogmel_data:(MLMultiArray *)logmel_data error:(NSError * _Nullable __autoreleasing * _Nullable)error {
180
206
  whisper_encoder_implInput *input_ = [[whisper_encoder_implInput alloc] initWithLogmel_data:logmel_data];
181
207
  return [self predictionFromFeatures:input_ error:error];
package/cpp/ggml-alloc.c CHANGED
@@ -37,6 +37,7 @@ static bool wsp_ggml_are_same_layout(const struct wsp_ggml_tensor * a, const str
37
37
  return true;
38
38
  }
39
39
 
40
+ // ops that return true for this function must not use restrict pointers for their backend implementations
40
41
  static bool wsp_ggml_op_can_inplace(enum wsp_ggml_op op) {
41
42
  switch (op) {
42
43
  case WSP_GGML_OP_SCALE:
@@ -52,8 +53,12 @@ static bool wsp_ggml_op_can_inplace(enum wsp_ggml_op op) {
52
53
  case WSP_GGML_OP_LOG:
53
54
  case WSP_GGML_OP_UNARY:
54
55
  case WSP_GGML_OP_ROPE:
56
+ case WSP_GGML_OP_ROPE_BACK:
57
+ case WSP_GGML_OP_SILU_BACK:
55
58
  case WSP_GGML_OP_RMS_NORM:
59
+ case WSP_GGML_OP_RMS_NORM_BACK:
56
60
  case WSP_GGML_OP_SOFT_MAX:
61
+ case WSP_GGML_OP_SOFT_MAX_BACK:
57
62
  return true;
58
63
 
59
64
  default:
@@ -84,7 +89,7 @@ struct wsp_ggml_tallocr wsp_ggml_tallocr_new(wsp_ggml_backend_buffer_t buffer) {
84
89
  return talloc;
85
90
  }
86
91
 
87
- void wsp_ggml_tallocr_alloc(struct wsp_ggml_tallocr * talloc, struct wsp_ggml_tensor * tensor) {
92
+ enum wsp_ggml_status wsp_ggml_tallocr_alloc(struct wsp_ggml_tallocr * talloc, struct wsp_ggml_tensor * tensor) {
88
93
  size_t size = wsp_ggml_backend_buffer_get_alloc_size(talloc->buffer, tensor);
89
94
  size = WSP_GGML_PAD(size, talloc->alignment);
90
95
 
@@ -99,7 +104,7 @@ void wsp_ggml_tallocr_alloc(struct wsp_ggml_tallocr * talloc, struct wsp_ggml_te
99
104
 
100
105
  assert(((uintptr_t)addr % talloc->alignment) == 0);
101
106
 
102
- wsp_ggml_backend_tensor_alloc(talloc->buffer, tensor, addr);
107
+ return wsp_ggml_backend_tensor_alloc(talloc->buffer, tensor, addr);
103
108
  }
104
109
 
105
110
  // dynamic tensor allocator
@@ -466,18 +471,12 @@ static bool wsp_ggml_gallocr_is_own(wsp_ggml_gallocr_t galloc, struct wsp_ggml_t
466
471
  return wsp_ggml_gallocr_hash_get(galloc, t)->allocated;
467
472
  }
468
473
 
469
- static void wsp_ggml_gallocr_set_node_offset(wsp_ggml_gallocr_t galloc, struct wsp_ggml_tensor * node, int buffer_id, size_t offset) {
470
- struct hash_node * hn = wsp_ggml_gallocr_hash_get(galloc, node);
471
- hn->buffer_id = buffer_id;
472
- hn->offset = offset;
473
- hn->allocated = true;
474
- }
475
-
476
474
  static bool wsp_ggml_gallocr_is_allocated(wsp_ggml_gallocr_t galloc, struct wsp_ggml_tensor * t) {
477
475
  return t->data != NULL || wsp_ggml_gallocr_hash_get(galloc, t)->allocated;
478
476
  }
479
477
 
480
478
  static void wsp_ggml_gallocr_allocate_node(wsp_ggml_gallocr_t galloc, struct wsp_ggml_tensor * node, int buffer_id) {
479
+ WSP_GGML_ASSERT(buffer_id >= 0);
481
480
  struct hash_node * hn = wsp_ggml_gallocr_hash_get(galloc, node);
482
481
 
483
482
  if (!wsp_ggml_gallocr_is_allocated(galloc, node) && !wsp_ggml_is_view(node)) {
@@ -540,7 +539,6 @@ static void wsp_ggml_gallocr_allocate_node(wsp_ggml_gallocr_t galloc, struct wsp
540
539
  size_t offset = wsp_ggml_dyn_tallocr_alloc(alloc, size, node);
541
540
  hn->buffer_id = buffer_id;
542
541
  hn->offset = offset;
543
- return;
544
542
  }
545
543
  }
546
544
 
@@ -816,7 +814,14 @@ static void wsp_ggml_gallocr_init_tensor(wsp_ggml_gallocr_t galloc, struct wsp_g
816
814
  }
817
815
 
818
816
  static bool wsp_ggml_gallocr_node_needs_realloc(wsp_ggml_gallocr_t galloc, struct wsp_ggml_tensor * node, struct tensor_alloc * talloc) {
819
- size_t node_size = (node->data || node->view_src) ? 0 : wsp_ggml_backend_buft_get_alloc_size(galloc->bufts[talloc->buffer_id], node);
817
+ size_t node_size = 0;
818
+ if (!node->data && !node->view_src) {
819
+ // If we previously had data but don't now then reallocate
820
+ if (talloc->buffer_id < 0) {
821
+ return false;
822
+ }
823
+ node_size = wsp_ggml_backend_buft_get_alloc_size(galloc->bufts[talloc->buffer_id], node);
824
+ }
820
825
  return talloc->size_max >= node_size;
821
826
  }
822
827
 
@@ -931,42 +936,51 @@ size_t wsp_ggml_gallocr_get_buffer_size(wsp_ggml_gallocr_t galloc, int buffer_id
931
936
 
932
937
  // utils
933
938
 
939
+ static void free_buffers(wsp_ggml_backend_buffer_t ** buffers, const size_t * n_buffers) {
940
+ for (size_t i = 0; i < *n_buffers; i++) {
941
+ wsp_ggml_backend_buffer_free((*buffers)[i]);
942
+ }
943
+ free(*buffers);
944
+ }
945
+
934
946
  static bool alloc_tensor_range(struct wsp_ggml_context * ctx,
935
947
  struct wsp_ggml_tensor * first, struct wsp_ggml_tensor * last,
936
948
  wsp_ggml_backend_buffer_type_t buft, size_t size,
937
949
  wsp_ggml_backend_buffer_t ** buffers, size_t * n_buffers) {
950
+
938
951
  wsp_ggml_backend_buffer_t buffer = wsp_ggml_backend_buft_alloc_buffer(buft, size);
939
952
  if (buffer == NULL) {
940
- #ifndef NDEBUG
941
- WSP_GGML_LOG_DEBUG("%s: failed to allocate %s buffer of size %zu\n", __func__, wsp_ggml_backend_buft_name(buft), size);
942
- #endif
943
- for (size_t i = 0; i < *n_buffers; i++) {
944
- wsp_ggml_backend_buffer_free((*buffers)[i]);
945
- }
946
- free(*buffers);
953
+ WSP_GGML_LOG_ERROR("%s: failed to allocate %s buffer of size %zu\n", __func__, wsp_ggml_backend_buft_name(buft), size);
954
+ free_buffers(buffers, n_buffers);
947
955
  return false;
948
956
  }
949
957
 
958
+ *buffers = realloc(*buffers, sizeof(wsp_ggml_backend_buffer_t) * (*n_buffers + 1));
959
+ (*buffers)[(*n_buffers)++] = buffer;
960
+
950
961
  struct wsp_ggml_tallocr tallocr = wsp_ggml_tallocr_new(buffer);
951
962
 
952
963
  for (struct wsp_ggml_tensor * t = first; t != last; t = wsp_ggml_get_next_tensor(ctx, t)) {
964
+ enum wsp_ggml_status status = WSP_GGML_STATUS_SUCCESS;
953
965
  if (t->data == NULL) {
954
966
  if (t->view_src == NULL) {
955
- wsp_ggml_tallocr_alloc(&tallocr, t);
967
+ status = wsp_ggml_tallocr_alloc(&tallocr, t);
956
968
  } else if (t->buffer == NULL) {
957
- wsp_ggml_backend_view_init(t);
969
+ status = wsp_ggml_backend_view_init(t);
958
970
  }
959
971
  } else {
960
972
  if (t->view_src != NULL && t->buffer == NULL) {
961
973
  // view of a pre-allocated tensor
962
- wsp_ggml_backend_view_init(t);
974
+ status = wsp_ggml_backend_view_init(t);
963
975
  }
964
976
  }
977
+ if (status != WSP_GGML_STATUS_SUCCESS) {
978
+ WSP_GGML_LOG_ERROR("%s: failed to initialize tensor %s\n", __func__, t->name);
979
+ free_buffers(buffers, n_buffers);
980
+ return false;
981
+ }
965
982
  }
966
983
 
967
- *buffers = realloc(*buffers, sizeof(wsp_ggml_backend_buffer_t) * (*n_buffers + 1));
968
- (*buffers)[(*n_buffers)++] = buffer;
969
-
970
984
  return true;
971
985
  }
972
986
 
@@ -987,19 +1001,7 @@ wsp_ggml_backend_buffer_t wsp_ggml_backend_alloc_ctx_tensors_from_buft(struct ws
987
1001
  this_size = WSP_GGML_PAD(wsp_ggml_backend_buft_get_alloc_size(buft, t), alignment);
988
1002
  }
989
1003
 
990
- if (this_size > max_size) {
991
- WSP_GGML_LOG_ERROR("%s: tensor %s is too large to fit in a %s buffer (tensor size: %zu, max buffer size: %zu)\n",
992
- __func__, t->name,
993
- wsp_ggml_backend_buft_name(buft),
994
- this_size, max_size);
995
- for (size_t i = 0; i < n_buffers; i++) {
996
- wsp_ggml_backend_buffer_free(buffers[i]);
997
- }
998
- free(buffers);
999
- return NULL;
1000
- }
1001
-
1002
- if ((cur_buf_size + this_size) > max_size) {
1004
+ if (cur_buf_size > 0 && (cur_buf_size + this_size) > max_size) {
1003
1005
  // allocate tensors in the current buffer
1004
1006
  if (!alloc_tensor_range(ctx, first, t, buft, cur_buf_size, &buffers, &n_buffers)) {
1005
1007
  return NULL;
package/cpp/ggml-alloc.h CHANGED
@@ -19,7 +19,7 @@ struct wsp_ggml_tallocr {
19
19
  };
20
20
 
21
21
  WSP_GGML_API struct wsp_ggml_tallocr wsp_ggml_tallocr_new(wsp_ggml_backend_buffer_t buffer);
22
- WSP_GGML_API void wsp_ggml_tallocr_alloc(struct wsp_ggml_tallocr * talloc, struct wsp_ggml_tensor * tensor);
22
+ WSP_GGML_API enum wsp_ggml_status wsp_ggml_tallocr_alloc(struct wsp_ggml_tallocr * talloc, struct wsp_ggml_tensor * tensor);
23
23
 
24
24
  // Graph allocator
25
25
  /*
@@ -8,6 +8,8 @@
8
8
  extern "C" {
9
9
  #endif
10
10
 
11
+ #define WSP_GGML_BACKEND_API_VERSION 1
12
+
11
13
  //
12
14
  // Backend buffer type
13
15
  //
@@ -22,7 +24,7 @@ extern "C" {
22
24
  size_t (*get_max_size) (wsp_ggml_backend_buffer_type_t buft);
23
25
  // (optional) data size needed to allocate the tensor, including padding (defaults to wsp_ggml_nbytes)
24
26
  size_t (*get_alloc_size)(wsp_ggml_backend_buffer_type_t buft, const struct wsp_ggml_tensor * tensor);
25
- // (optional) check if tensor data is in host memory (defaults to false)
27
+ // (optional) check if tensor data is in host memory and uses standard ggml tensor layout (defaults to false)
26
28
  bool (*is_host) (wsp_ggml_backend_buffer_type_t buft);
27
29
  };
28
30
 
@@ -37,13 +39,12 @@ extern "C" {
37
39
  //
38
40
 
39
41
  struct wsp_ggml_backend_buffer_i {
40
- const char * (*get_name) (wsp_ggml_backend_buffer_t buffer);
41
42
  // (optional) free the buffer
42
43
  void (*free_buffer) (wsp_ggml_backend_buffer_t buffer);
43
44
  // base address of the buffer
44
45
  void * (*get_base) (wsp_ggml_backend_buffer_t buffer);
45
46
  // (optional) initialize a tensor in the buffer (eg. add tensor extras)
46
- void (*init_tensor) (wsp_ggml_backend_buffer_t buffer, struct wsp_ggml_tensor * tensor);
47
+ enum wsp_ggml_status (*init_tensor)(wsp_ggml_backend_buffer_t buffer, struct wsp_ggml_tensor * tensor);
47
48
  // tensor data access
48
49
  void (*memset_tensor)(wsp_ggml_backend_buffer_t buffer, struct wsp_ggml_tensor * tensor, uint8_t value, size_t offset, size_t size);
49
50
  void (*set_tensor) (wsp_ggml_backend_buffer_t buffer, struct wsp_ggml_tensor * tensor, const void * data, size_t offset, size_t size);
@@ -64,20 +65,20 @@ extern "C" {
64
65
  enum wsp_ggml_backend_buffer_usage usage;
65
66
  };
66
67
 
67
- wsp_ggml_backend_buffer_t wsp_ggml_backend_buffer_init(
68
+ WSP_GGML_API wsp_ggml_backend_buffer_t wsp_ggml_backend_buffer_init(
68
69
  wsp_ggml_backend_buffer_type_t buft,
69
70
  struct wsp_ggml_backend_buffer_i iface,
70
71
  void * context,
71
72
  size_t size);
72
73
 
73
74
  // do not use directly, use wsp_ggml_backend_tensor_copy instead
74
- bool wsp_ggml_backend_buffer_copy_tensor(const struct wsp_ggml_tensor * src, struct wsp_ggml_tensor * dst);
75
+ WSP_GGML_API bool wsp_ggml_backend_buffer_copy_tensor(const struct wsp_ggml_tensor * src, struct wsp_ggml_tensor * dst);
75
76
 
76
77
  // multi-buffer
77
78
  // buffer that contains a collection of buffers
78
- wsp_ggml_backend_buffer_t wsp_ggml_backend_multi_buffer_alloc_buffer(wsp_ggml_backend_buffer_t * buffers, size_t n_buffers);
79
- bool wsp_ggml_backend_buffer_is_multi_buffer(wsp_ggml_backend_buffer_t buffer);
80
- void wsp_ggml_backend_multi_buffer_set_usage(wsp_ggml_backend_buffer_t buffer, enum wsp_ggml_backend_buffer_usage usage);
79
+ WSP_GGML_API wsp_ggml_backend_buffer_t wsp_ggml_backend_multi_buffer_alloc_buffer(wsp_ggml_backend_buffer_t * buffers, size_t n_buffers);
80
+ WSP_GGML_API bool wsp_ggml_backend_buffer_is_multi_buffer(wsp_ggml_backend_buffer_t buffer);
81
+ WSP_GGML_API void wsp_ggml_backend_multi_buffer_set_usage(wsp_ggml_backend_buffer_t buffer, enum wsp_ggml_backend_buffer_usage usage);
81
82
 
82
83
  //
83
84
  // Backend (stream)
@@ -88,19 +89,16 @@ extern "C" {
88
89
 
89
90
  void (*free)(wsp_ggml_backend_t backend);
90
91
 
91
- // Will be moved to the device interface
92
- // buffer allocation
93
- wsp_ggml_backend_buffer_type_t (*get_default_buffer_type)(wsp_ggml_backend_t backend);
94
-
95
92
  // (optional) asynchronous tensor data access
96
93
  void (*set_tensor_async)(wsp_ggml_backend_t backend, struct wsp_ggml_tensor * tensor, const void * data, size_t offset, size_t size);
97
94
  void (*get_tensor_async)(wsp_ggml_backend_t backend, const struct wsp_ggml_tensor * tensor, void * data, size_t offset, size_t size);
98
95
  bool (*cpy_tensor_async)(wsp_ggml_backend_t backend_src, wsp_ggml_backend_t backend_dst, const struct wsp_ggml_tensor * src, struct wsp_ggml_tensor * dst);
99
96
 
100
- // (optional) complete all pending operations
97
+ // (optional) complete all pending operations (required if the backend supports async operations)
101
98
  void (*synchronize)(wsp_ggml_backend_t backend);
102
99
 
103
- // (optional) compute graph with a plan (not used currently)
100
+ // (optional) graph plans (not used currently)
101
+ // compute graph with a plan
104
102
  wsp_ggml_backend_graph_plan_t (*graph_plan_create) (wsp_ggml_backend_t backend, const struct wsp_ggml_cgraph * cgraph);
105
103
  void (*graph_plan_free) (wsp_ggml_backend_t backend, wsp_ggml_backend_graph_plan_t plan);
106
104
  // update the plan with a new graph - this should be faster than creating a new plan when the graph has the same topology
@@ -111,13 +109,6 @@ extern "C" {
111
109
  // compute graph (always async if supported by the backend)
112
110
  enum wsp_ggml_status (*graph_compute) (wsp_ggml_backend_t backend, struct wsp_ggml_cgraph * cgraph);
113
111
 
114
- // IMPORTANT: these functions have been moved to the device interface and will be removed from the backend interface
115
- // new backends should implement the device interface instead
116
- // These functions are being moved to the device interface
117
- bool (*supports_op) (wsp_ggml_backend_t backend, const struct wsp_ggml_tensor * op);
118
- bool (*supports_buft)(wsp_ggml_backend_t backend, wsp_ggml_backend_buffer_type_t buft);
119
- bool (*offload_op) (wsp_ggml_backend_t backend, const struct wsp_ggml_tensor * op);
120
-
121
112
  // (optional) event synchronization
122
113
  // record an event on this stream
123
114
  void (*event_record)(wsp_ggml_backend_t backend, wsp_ggml_backend_event_t event);
@@ -210,17 +201,54 @@ extern "C" {
210
201
  };
211
202
 
212
203
  struct wsp_ggml_backend_reg {
213
- // int api_version; // TODO: for dynamic loading
204
+ int api_version; // initialize to WSP_GGML_BACKEND_API_VERSION
214
205
  struct wsp_ggml_backend_reg_i iface;
215
206
  void * context;
216
207
  };
217
208
 
218
-
219
209
  // Internal backend registry API
220
- void wsp_ggml_backend_register(wsp_ggml_backend_reg_t reg);
221
- void wsp_ggml_backend_device_register(wsp_ggml_backend_dev_t device);
222
- // TODO: backends can be loaded as a dynamic library, in which case it needs to export this function
223
- // typedef wsp_ggml_backend_register_t * (*wsp_ggml_backend_init)(void);
210
+ WSP_GGML_API void wsp_ggml_backend_register(wsp_ggml_backend_reg_t reg);
211
+
212
+ // Add backend dynamic loading support to the backend
213
+
214
+ // Initialize the backend
215
+ typedef wsp_ggml_backend_reg_t (*wsp_ggml_backend_init_t)(void);
216
+ // Optional: obtain a score for the backend based on the system configuration
217
+ // Higher scores are preferred, 0 means the backend is not supported in the current system
218
+ typedef int (*wsp_ggml_backend_score_t)(void);
219
+
220
+ #ifdef WSP_GGML_BACKEND_DL
221
+ # ifdef __cplusplus
222
+ # define WSP_GGML_BACKEND_DL_IMPL(reg_fn) \
223
+ extern "C" { \
224
+ WSP_GGML_BACKEND_API wsp_ggml_backend_reg_t wsp_ggml_backend_init(void); \
225
+ } \
226
+ wsp_ggml_backend_reg_t wsp_ggml_backend_init(void) { \
227
+ return reg_fn(); \
228
+ }
229
+ # define WSP_GGML_BACKEND_DL_SCORE_IMPL(score_fn) \
230
+ extern "C" { \
231
+ WSP_GGML_BACKEND_API int wsp_ggml_backend_score(void); \
232
+ } \
233
+ int wsp_ggml_backend_score(void) { \
234
+ return score_fn(); \
235
+ }
236
+ # else
237
+ # define WSP_GGML_BACKEND_DL_IMPL(reg_fn) \
238
+ WSP_GGML_BACKEND_API wsp_ggml_backend_reg_t wsp_ggml_backend_init(void); \
239
+ wsp_ggml_backend_reg_t wsp_ggml_backend_init(void) { \
240
+ return reg_fn(); \
241
+ }
242
+ # define WSP_GGML_BACKEND_DL_SCORE_IMPL(score_fn) \
243
+ WSP_GGML_BACKEND_API int wsp_ggml_backend_score(void); \
244
+ int wsp_ggml_backend_score(void) { \
245
+ return score_fn(); \
246
+ }
247
+ # endif
248
+ #else
249
+ # define WSP_GGML_BACKEND_DL_IMPL(reg_fn)
250
+ # define WSP_GGML_BACKEND_DL_SCORE_IMPL(score_fn)
251
+ #endif
224
252
 
225
253
  #ifdef __cplusplus
226
254
  }