whisper.rn 0.4.0-rc.9 → 0.4.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (202)
  1. package/README.md +74 -1
  2. package/android/build.gradle +12 -3
  3. package/android/src/main/CMakeLists.txt +43 -13
  4. package/android/src/main/java/com/rnwhisper/RNWhisper.java +211 -0
  5. package/android/src/main/java/com/rnwhisper/WhisperContext.java +64 -36
  6. package/android/src/main/java/com/rnwhisper/WhisperVadContext.java +157 -0
  7. package/android/src/main/jni.cpp +205 -0
  8. package/android/src/main/jniLibs/arm64-v8a/librnwhisper.so +0 -0
  9. package/android/src/main/jniLibs/arm64-v8a/librnwhisper_v8fp16_va_2.so +0 -0
  10. package/android/src/main/jniLibs/armeabi-v7a/librnwhisper.so +0 -0
  11. package/android/src/main/jniLibs/armeabi-v7a/librnwhisper_vfpv4.so +0 -0
  12. package/android/src/main/jniLibs/x86_64/librnwhisper.so +0 -0
  13. package/android/src/main/jniLibs/x86_64/librnwhisper_x86_64.so +0 -0
  14. package/android/src/newarch/java/com/rnwhisper/RNWhisperModule.java +26 -0
  15. package/android/src/oldarch/java/com/rnwhisper/RNWhisperModule.java +26 -0
  16. package/cpp/coreml/whisper-compat.h +10 -0
  17. package/cpp/coreml/whisper-compat.m +35 -0
  18. package/cpp/coreml/whisper-decoder-impl.h +27 -15
  19. package/cpp/coreml/whisper-decoder-impl.m +36 -10
  20. package/cpp/coreml/whisper-encoder-impl.h +21 -9
  21. package/cpp/coreml/whisper-encoder-impl.m +29 -3
  22. package/cpp/ggml-alloc.c +39 -37
  23. package/cpp/ggml-alloc.h +1 -1
  24. package/cpp/ggml-backend-impl.h +55 -27
  25. package/cpp/ggml-backend-reg.cpp +591 -0
  26. package/cpp/ggml-backend.cpp +336 -955
  27. package/cpp/ggml-backend.h +70 -42
  28. package/cpp/ggml-common.h +57 -49
  29. package/cpp/ggml-cpp.h +39 -0
  30. package/cpp/ggml-cpu/amx/amx.cpp +221 -0
  31. package/cpp/ggml-cpu/amx/amx.h +8 -0
  32. package/cpp/ggml-cpu/amx/common.h +91 -0
  33. package/cpp/ggml-cpu/amx/mmq.cpp +2511 -0
  34. package/cpp/ggml-cpu/amx/mmq.h +10 -0
  35. package/cpp/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
  36. package/cpp/ggml-cpu/arch/arm/quants.c +4113 -0
  37. package/cpp/ggml-cpu/arch/arm/repack.cpp +2162 -0
  38. package/cpp/ggml-cpu/arch/x86/cpu-feats.cpp +327 -0
  39. package/cpp/ggml-cpu/arch/x86/quants.c +4310 -0
  40. package/cpp/ggml-cpu/arch/x86/repack.cpp +3284 -0
  41. package/cpp/ggml-cpu/arch-fallback.h +184 -0
  42. package/cpp/ggml-cpu/binary-ops.cpp +158 -0
  43. package/cpp/ggml-cpu/binary-ops.h +16 -0
  44. package/cpp/ggml-cpu/common.h +72 -0
  45. package/cpp/ggml-cpu/ggml-cpu-impl.h +511 -0
  46. package/cpp/ggml-cpu/ggml-cpu.c +3473 -0
  47. package/cpp/ggml-cpu/ggml-cpu.cpp +671 -0
  48. package/cpp/ggml-cpu/ops.cpp +9085 -0
  49. package/cpp/ggml-cpu/ops.h +111 -0
  50. package/cpp/ggml-cpu/quants.c +1157 -0
  51. package/cpp/ggml-cpu/quants.h +89 -0
  52. package/cpp/ggml-cpu/repack.cpp +1570 -0
  53. package/cpp/ggml-cpu/repack.h +98 -0
  54. package/cpp/ggml-cpu/simd-mappings.h +1006 -0
  55. package/cpp/ggml-cpu/traits.cpp +36 -0
  56. package/cpp/ggml-cpu/traits.h +38 -0
  57. package/cpp/ggml-cpu/unary-ops.cpp +186 -0
  58. package/cpp/ggml-cpu/unary-ops.h +28 -0
  59. package/cpp/ggml-cpu/vec.cpp +321 -0
  60. package/cpp/ggml-cpu/vec.h +973 -0
  61. package/cpp/ggml-cpu.h +143 -0
  62. package/cpp/ggml-impl.h +417 -23
  63. package/cpp/ggml-metal-impl.h +622 -0
  64. package/cpp/ggml-metal.h +9 -9
  65. package/cpp/ggml-metal.m +3451 -1344
  66. package/cpp/ggml-opt.cpp +1037 -0
  67. package/cpp/ggml-opt.h +237 -0
  68. package/cpp/ggml-quants.c +296 -10818
  69. package/cpp/ggml-quants.h +78 -125
  70. package/cpp/ggml-threading.cpp +12 -0
  71. package/cpp/ggml-threading.h +14 -0
  72. package/cpp/ggml-whisper-sim.metallib +0 -0
  73. package/cpp/ggml-whisper.metallib +0 -0
  74. package/cpp/ggml.c +4633 -21450
  75. package/cpp/ggml.h +320 -661
  76. package/cpp/gguf.cpp +1347 -0
  77. package/cpp/gguf.h +202 -0
  78. package/cpp/rn-whisper.cpp +4 -11
  79. package/cpp/whisper-arch.h +197 -0
  80. package/cpp/whisper.cpp +2022 -495
  81. package/cpp/whisper.h +75 -18
  82. package/ios/CMakeLists.txt +95 -0
  83. package/ios/RNWhisper.h +5 -0
  84. package/ios/RNWhisper.mm +147 -0
  85. package/ios/RNWhisperAudioUtils.m +4 -0
  86. package/ios/RNWhisperContext.h +5 -0
  87. package/ios/RNWhisperContext.mm +22 -26
  88. package/ios/RNWhisperVadContext.h +29 -0
  89. package/ios/RNWhisperVadContext.mm +152 -0
  90. package/ios/rnwhisper.xcframework/Info.plist +74 -0
  91. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-alloc.h +76 -0
  92. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend-impl.h +255 -0
  93. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend.h +354 -0
  94. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-common.h +1861 -0
  95. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-cpp.h +39 -0
  96. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-cpu.h +143 -0
  97. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-impl.h +603 -0
  98. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-metal-impl.h +622 -0
  99. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-metal.h +66 -0
  100. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-opt.h +237 -0
  101. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-quants.h +100 -0
  102. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-threading.h +14 -0
  103. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml.h +2221 -0
  104. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/gguf.h +202 -0
  105. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/rn-audioutils.h +14 -0
  106. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/rn-whisper-log.h +11 -0
  107. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/rn-whisper.h +52 -0
  108. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/whisper-arch.h +197 -0
  109. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/whisper.h +739 -0
  110. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Info.plist +0 -0
  111. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
  112. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/rnwhisper +0 -0
  113. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-alloc.h +76 -0
  114. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend-impl.h +255 -0
  115. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +354 -0
  116. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-common.h +1861 -0
  117. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpp.h +39 -0
  118. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpu.h +143 -0
  119. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +603 -0
  120. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal-impl.h +622 -0
  121. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal.h +66 -0
  122. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-opt.h +237 -0
  123. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-quants.h +100 -0
  124. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-threading.h +14 -0
  125. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +2221 -0
  126. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/gguf.h +202 -0
  127. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-audioutils.h +14 -0
  128. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-whisper-log.h +11 -0
  129. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-whisper.h +52 -0
  130. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/whisper-arch.h +197 -0
  131. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/whisper.h +739 -0
  132. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
  133. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +101 -0
  134. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
  135. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
  136. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-alloc.h +76 -0
  137. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend-impl.h +255 -0
  138. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend.h +354 -0
  139. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-common.h +1861 -0
  140. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-cpp.h +39 -0
  141. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-cpu.h +143 -0
  142. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-impl.h +603 -0
  143. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-metal-impl.h +622 -0
  144. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-metal.h +66 -0
  145. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-opt.h +237 -0
  146. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-quants.h +100 -0
  147. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-threading.h +14 -0
  148. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml.h +2221 -0
  149. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/gguf.h +202 -0
  150. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/rn-audioutils.h +14 -0
  151. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/rn-whisper-log.h +11 -0
  152. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/rn-whisper.h +52 -0
  153. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/whisper-arch.h +197 -0
  154. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/whisper.h +739 -0
  155. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Info.plist +0 -0
  156. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
  157. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/rnwhisper +0 -0
  158. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-alloc.h +76 -0
  159. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend-impl.h +255 -0
  160. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +354 -0
  161. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-common.h +1861 -0
  162. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpp.h +39 -0
  163. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpu.h +143 -0
  164. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +603 -0
  165. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal-impl.h +622 -0
  166. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal.h +66 -0
  167. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-opt.h +237 -0
  168. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-quants.h +100 -0
  169. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-threading.h +14 -0
  170. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +2221 -0
  171. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/gguf.h +202 -0
  172. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-audioutils.h +14 -0
  173. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-whisper-log.h +11 -0
  174. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-whisper.h +52 -0
  175. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/whisper-arch.h +197 -0
  176. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/whisper.h +739 -0
  177. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
  178. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +101 -0
  179. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
  180. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
  181. package/jest/mock.js +24 -0
  182. package/lib/commonjs/NativeRNWhisper.js.map +1 -1
  183. package/lib/commonjs/index.js +111 -1
  184. package/lib/commonjs/index.js.map +1 -1
  185. package/lib/commonjs/version.json +1 -1
  186. package/lib/module/NativeRNWhisper.js.map +1 -1
  187. package/lib/module/index.js +112 -0
  188. package/lib/module/index.js.map +1 -1
  189. package/lib/module/version.json +1 -1
  190. package/lib/typescript/NativeRNWhisper.d.ts +35 -0
  191. package/lib/typescript/NativeRNWhisper.d.ts.map +1 -1
  192. package/lib/typescript/index.d.ts +39 -3
  193. package/lib/typescript/index.d.ts.map +1 -1
  194. package/package.json +10 -6
  195. package/src/NativeRNWhisper.ts +48 -0
  196. package/src/index.ts +132 -1
  197. package/src/version.json +1 -1
  198. package/whisper-rn.podspec +11 -18
  199. package/cpp/README.md +0 -4
  200. package/cpp/ggml-aarch64.c +0 -3209
  201. package/cpp/ggml-aarch64.h +0 -39
  202. package/cpp/ggml-cpu-impl.h +0 -614
package/cpp/ggml-opt.h ADDED
@@ -0,0 +1,237 @@
1
+ // This file contains functionality for training models using GGML.
2
+ // It is not strictly needed vs. just vanilla GGML but it provides a more high-level interface for common needs such as datasets.
3
+ // At the bottom of this file especially there are relatively high-level functions that are suitable for use or adaptation in user code.
4
+ //
5
+ // Module maintainer: Johannes Gäßler (@JohannesGaessler, johannesg@5d6.de)
6
+
7
+ #pragma once
8
+
9
+ #include "ggml.h"
10
+ #include "ggml-backend.h"
11
+
12
+ #include <stdint.h>
13
+
14
+ #ifdef __cplusplus
15
+ extern "C" {
16
+ #endif
17
+
18
+ struct wsp_ggml_opt_dataset;
19
+ struct wsp_ggml_opt_context;
20
+ struct wsp_ggml_opt_result;
21
+
22
+ typedef struct wsp_ggml_opt_dataset * wsp_ggml_opt_dataset_t;
23
+ typedef struct wsp_ggml_opt_context * wsp_ggml_opt_context_t;
24
+ typedef struct wsp_ggml_opt_result * wsp_ggml_opt_result_t;
25
+
26
+ // ====== Loss ======
27
+
28
+ // built-in loss types, i.e. the built-in quantities minimized by the optimizer
29
+ // custom loss types can be defined via mean or sum which simply reduce the outputs for all datapoints to a single value
30
+ enum wsp_ggml_opt_loss_type {
31
+ WSP_GGML_OPT_LOSS_TYPE_MEAN,
32
+ WSP_GGML_OPT_LOSS_TYPE_SUM,
33
+ WSP_GGML_OPT_LOSS_TYPE_CROSS_ENTROPY,
34
+ WSP_GGML_OPT_LOSS_TYPE_MEAN_SQUARED_ERROR,
35
+ };
36
+
37
+ // ====== Dataset ======
38
+
39
+ WSP_GGML_API wsp_ggml_opt_dataset_t wsp_ggml_opt_dataset_init(
40
+ enum wsp_ggml_type type_data, // the type for the internal data tensor
41
+ enum wsp_ggml_type type_label, // the type for the internal labels tensor
42
+ int64_t ne_datapoint, // number of elements per datapoint
43
+ int64_t ne_label, // number of elements per label
44
+ int64_t ndata, // total number of datapoints/labels
45
+ int64_t ndata_shard); // number of datapoints/labels per shard (unit at which the dataset is shuffled/copied)
46
+ WSP_GGML_API void wsp_ggml_opt_dataset_free(wsp_ggml_opt_dataset_t dataset);
47
+
48
+ // get underlying tensors that store the data
49
+ WSP_GGML_API int64_t wsp_ggml_opt_dataset_ndata (wsp_ggml_opt_dataset_t dataset);
50
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_opt_dataset_data (wsp_ggml_opt_dataset_t dataset); // shape = [ne_datapoint, ndata]
51
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_opt_dataset_labels(wsp_ggml_opt_dataset_t dataset); // shape = [ne_label, ndata]
52
+
53
+ // shuffle idata first datapoints from dataset with RNG from opt_ctx, shuffle all datapoints if idata is negative
54
+ WSP_GGML_API void wsp_ggml_opt_dataset_shuffle(wsp_ggml_opt_context_t opt_ctx, wsp_ggml_opt_dataset_t dataset, int64_t idata);
55
+
56
+ // get batch at position ibatch from dataset and copy the data to data_batch and labels_batch
57
+ WSP_GGML_API void wsp_ggml_opt_dataset_get_batch(
58
+ wsp_ggml_opt_dataset_t dataset,
59
+ struct wsp_ggml_tensor * data_batch, // shape = [ne_datapoint, ndata_batch]
60
+ struct wsp_ggml_tensor * labels_batch, // shape = [ne_label, ndata_batch]
61
+ int64_t ibatch);
62
+ WSP_GGML_API void wsp_ggml_opt_dataset_get_batch_host(
63
+ wsp_ggml_opt_dataset_t dataset,
64
+ void * data_batch,
65
+ size_t nb_data_batch,
66
+ void * labels_batch,
67
+ int64_t ibatch);
68
+
69
+ // ====== Model / Context ======
70
+
71
+ enum wsp_ggml_opt_build_type {
72
+ WSP_GGML_OPT_BUILD_TYPE_FORWARD = 10,
73
+ WSP_GGML_OPT_BUILD_TYPE_GRAD = 20,
74
+ WSP_GGML_OPT_BUILD_TYPE_OPT = 30,
75
+ };
76
+
77
+ // parameters that control which optimizer is used and how said optimizer tries to find the minimal loss
78
+ struct wsp_ggml_opt_optimizer_params {
79
+ // AdamW optimizer parameters
80
+ struct {
81
+ float alpha; // learning rate
82
+ float beta1;
83
+ float beta2;
84
+ float eps; // epsilon for numerical stability
85
+ float wd; // weight decay for AdamW, use 0.0f to disable
86
+ } adamw;
87
+ };
88
+
89
+ // callback to calculate optimizer parameters prior to a backward pass
90
+ // userdata can be used to pass arbitrary data
91
+ typedef struct wsp_ggml_opt_optimizer_params (*wsp_ggml_opt_get_optimizer_params)(void * userdata);
92
+
93
+ // returns the default optimizer params (constant, hard-coded values)
94
+ // userdata is not used
95
+ WSP_GGML_API struct wsp_ggml_opt_optimizer_params wsp_ggml_opt_get_default_optimizer_params(void * userdata);
96
+
97
+ // casts userdata to wsp_ggml_opt_optimizer_params and returns it
98
+ WSP_GGML_API struct wsp_ggml_opt_optimizer_params wsp_ggml_opt_get_constant_optimizer_params(void * userdata);
99
+
100
+ // parameters for initializing a new optimization context
101
+ struct wsp_ggml_opt_params {
102
+ wsp_ggml_backend_sched_t backend_sched; // defines which backends are used to construct the compute graphs
103
+
104
+ // by default the forward graph needs to be reconstructed for each eval
105
+ // if ctx_compute, inputs, and outputs are set the graphs are instead allocated statically
106
+ struct wsp_ggml_context * ctx_compute;
107
+ struct wsp_ggml_tensor * inputs;
108
+ struct wsp_ggml_tensor * outputs;
109
+
110
+ enum wsp_ggml_opt_loss_type loss_type;
111
+ enum wsp_ggml_opt_build_type build_type;
112
+
113
+ int32_t opt_period; // after how many gradient accumulation steps an optimizer step should be done
114
+
115
+ wsp_ggml_opt_get_optimizer_params get_opt_pars; // callback for calculating optimizer parameters
116
+ void * get_opt_pars_ud; // userdata for calculating optimizer parameters
117
+ };
118
+
119
+ // get parameters for an optimization context with defaults set where possible
120
+ // parameters for which no sensible defaults exist are supplied as arguments to this function
121
+ WSP_GGML_API struct wsp_ggml_opt_params wsp_ggml_opt_default_params(
122
+ wsp_ggml_backend_sched_t backend_sched,
123
+ enum wsp_ggml_opt_loss_type loss_type);
124
+
125
+ WSP_GGML_API wsp_ggml_opt_context_t wsp_ggml_opt_init(struct wsp_ggml_opt_params params);
126
+ WSP_GGML_API void wsp_ggml_opt_free(wsp_ggml_opt_context_t opt_ctx);
127
+
128
+ // set gradients to zero, initialize loss, and optionally reset the optimizer
129
+ WSP_GGML_API void wsp_ggml_opt_reset(wsp_ggml_opt_context_t opt_ctx, bool optimizer);
130
+
131
+ WSP_GGML_API bool wsp_ggml_opt_static_graphs(wsp_ggml_opt_context_t opt_ctx); // whether the graphs are allocated statically
132
+
133
+ // get underlying tensors that store data
134
+ // if not using static graphs these pointers become invalid with the next call to wsp_ggml_opt_alloc
135
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_opt_inputs( wsp_ggml_opt_context_t opt_ctx); // forward graph input tensor
136
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_opt_outputs( wsp_ggml_opt_context_t opt_ctx); // forward graph output tensor
137
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_opt_labels( wsp_ggml_opt_context_t opt_ctx); // labels to compare outputs against
138
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_opt_loss( wsp_ggml_opt_context_t opt_ctx); // scalar tensor that contains the loss
139
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_opt_pred( wsp_ggml_opt_context_t opt_ctx); // predictions made by outputs
140
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_opt_ncorrect(wsp_ggml_opt_context_t opt_ctx); // number of matching predictions between outputs and labels
141
+
142
+ // get the gradient accumulator for a node from the forward graph
143
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_opt_grad_acc(wsp_ggml_opt_context_t opt_ctx, struct wsp_ggml_tensor * node);
144
+
145
+ // ====== Optimization Result ======
146
+
147
+ WSP_GGML_API wsp_ggml_opt_result_t wsp_ggml_opt_result_init(void);
148
+ WSP_GGML_API void wsp_ggml_opt_result_free(wsp_ggml_opt_result_t result);
149
+ WSP_GGML_API void wsp_ggml_opt_result_reset(wsp_ggml_opt_result_t result);
150
+
151
+ // get data from result, uncertainties are optional and can be ignored by passing NULL
152
+ WSP_GGML_API void wsp_ggml_opt_result_ndata( wsp_ggml_opt_result_t result, int64_t * ndata); // writes 1 value, number of datapoints
153
+ WSP_GGML_API void wsp_ggml_opt_result_loss( wsp_ggml_opt_result_t result, double * loss, double * unc); // writes 1 value
154
+ WSP_GGML_API void wsp_ggml_opt_result_pred( wsp_ggml_opt_result_t result, int32_t * pred); // writes ndata values
155
+ WSP_GGML_API void wsp_ggml_opt_result_accuracy(wsp_ggml_opt_result_t result, double * accuracy, double * unc); // writes 1 value
156
+
157
+ // ====== Computation ======
158
+
159
+ // if not using static graphs, this function must be called prior to wsp_ggml_opt_alloc
160
+ WSP_GGML_API void wsp_ggml_opt_prepare_alloc(
161
+ wsp_ggml_opt_context_t opt_ctx,
162
+ struct wsp_ggml_context * ctx_compute,
163
+ struct wsp_ggml_cgraph * gf,
164
+ struct wsp_ggml_tensor * inputs,
165
+ struct wsp_ggml_tensor * outputs);
166
+
167
+ // allocate the next graph for evaluation, either forward or forward + backward
168
+ // must be called exactly once prior to calling wsp_ggml_opt_eval
169
+ WSP_GGML_API void wsp_ggml_opt_alloc(wsp_ggml_opt_context_t opt_ctx, bool backward);
170
+
171
+ // do forward pass, increment result if not NULL, do backward pass if allocated
172
+ WSP_GGML_API void wsp_ggml_opt_eval(wsp_ggml_opt_context_t opt_ctx, wsp_ggml_opt_result_t result);
173
+
174
+ // ############################################################################
175
+ // ## The high-level functions start here. They do not depend on any private ##
176
+ // ## functions or structs and can be copied to and adapted for user code. ##
177
+ // ############################################################################
178
+
179
+ // ====== Intended Usage ======
180
+ //
181
+ // 1. Select the appropriate loss for your problem.
182
+ // 2. Create a dataset and set the data for the "data" tensor. Also set the "labels" tensor if your loss needs them.
183
+ // Setting the shard size to 1 will be fine, it's the granularity with which data is shuffled/loaded (bigger values are faster).
184
+ // 3. Create a GGML graph for your model with no_alloc == true. Use two separate contexts for the tensors.
185
+ // The first context should contain the model parameters and inputs and be allocated statically in user code.
186
+ // The second context should contain all other tensors and will be (re)allocated automatically.
187
+ // Due to this automated allocation the data of the second context is not defined when accessed in user code.
188
+ // Note that the second dimension of the inputs/outputs are interpreted as the number of datapoints in those tensors.
189
+ // 4. Call wsp_ggml_opt_fit. If you need more control you can use wsp_ggml_opt_epoch instead.
190
+
191
+ // signature for a callback while evaluating opt_ctx on dataset, called after an evaluation
192
+ typedef void (*wsp_ggml_opt_epoch_callback)(
193
+ bool train, // true after training evaluation, false after validation evaluation
194
+ wsp_ggml_opt_context_t opt_ctx,
195
+ wsp_ggml_opt_dataset_t dataset,
196
+ wsp_ggml_opt_result_t result, // result associated with the dataset subsection
197
+ int64_t ibatch, // number of batches that have been evaluated so far
198
+ int64_t ibatch_max, // total number of batches in this dataset subsection
199
+ int64_t t_start_us); // time at which the evaluation on the dataset subsection was started
200
+
201
+ // do training on front of dataset, do evaluation only on back of dataset
202
+ WSP_GGML_API void wsp_ggml_opt_epoch(
203
+ wsp_ggml_opt_context_t opt_ctx,
204
+ wsp_ggml_opt_dataset_t dataset,
205
+ wsp_ggml_opt_result_t result_train, // result to increment during training, ignored if NULL
206
+ wsp_ggml_opt_result_t result_eval, // result to increment during evaluation, ignored if NULL
207
+ int64_t idata_split, // data index at which to split training and evaluation
208
+ wsp_ggml_opt_epoch_callback callback_train,
209
+ wsp_ggml_opt_epoch_callback callback_eval);
210
+
211
+ // callback that prints a progress bar on stderr
212
+ WSP_GGML_API void wsp_ggml_opt_epoch_callback_progress_bar(
213
+ bool train,
214
+ wsp_ggml_opt_context_t opt_ctx,
215
+ wsp_ggml_opt_dataset_t dataset,
216
+ wsp_ggml_opt_result_t result,
217
+ int64_t ibatch,
218
+ int64_t ibatch_max,
219
+ int64_t t_start_us);
220
+
221
+ // fit model defined by inputs and outputs to dataset
222
+ WSP_GGML_API void wsp_ggml_opt_fit(
223
+ wsp_ggml_backend_sched_t backend_sched, // backend scheduler for constructing the compute graphs
224
+ struct wsp_ggml_context * ctx_compute, // context with temporarily allocated tensors to calculate the outputs
225
+ struct wsp_ggml_tensor * inputs, // input tensor with shape [ne_datapoint, ndata_batch]
226
+ struct wsp_ggml_tensor * outputs, // output tensor, must have shape [ne_label, ndata_batch] if labels are used
227
+ wsp_ggml_opt_dataset_t dataset, // dataset with data and optionally also labels
228
+ enum wsp_ggml_opt_loss_type loss_type, // loss to minimize
229
+ wsp_ggml_opt_get_optimizer_params get_opt_pars, // callback to get optimizer params, userdata is pointer to epoch (of type int64_t)
230
+ int64_t nepoch, // how many times the dataset should be iterated over
231
+ int64_t nbatch_logical, // datapoints per optimizer step, must be a multiple of ndata_batch in inputs/outputs
232
+ float val_split, // fraction of the dataset to use for validation, must be in [0.0f, 1.0f)
233
+ bool silent); // whether or not info prints to stderr should be suppressed
234
+
235
+ #ifdef __cplusplus
236
+ }
237
+ #endif