whisper.rn 0.4.0-rc.9 → 0.4.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (183) hide show
  1. package/README.md +5 -1
  2. package/android/build.gradle +12 -3
  3. package/android/src/main/CMakeLists.txt +43 -13
  4. package/android/src/main/java/com/rnwhisper/WhisperContext.java +33 -35
  5. package/android/src/main/jni.cpp +9 -0
  6. package/android/src/main/jniLibs/arm64-v8a/librnwhisper.so +0 -0
  7. package/android/src/main/jniLibs/arm64-v8a/librnwhisper_v8fp16_va_2.so +0 -0
  8. package/android/src/main/jniLibs/armeabi-v7a/librnwhisper.so +0 -0
  9. package/android/src/main/jniLibs/armeabi-v7a/librnwhisper_vfpv4.so +0 -0
  10. package/android/src/main/jniLibs/x86_64/librnwhisper.so +0 -0
  11. package/android/src/main/jniLibs/x86_64/librnwhisper_x86_64.so +0 -0
  12. package/cpp/coreml/whisper-compat.h +10 -0
  13. package/cpp/coreml/whisper-compat.m +35 -0
  14. package/cpp/coreml/whisper-decoder-impl.h +27 -15
  15. package/cpp/coreml/whisper-decoder-impl.m +36 -10
  16. package/cpp/coreml/whisper-encoder-impl.h +21 -9
  17. package/cpp/coreml/whisper-encoder-impl.m +29 -3
  18. package/cpp/ggml-alloc.c +39 -37
  19. package/cpp/ggml-alloc.h +1 -1
  20. package/cpp/ggml-backend-impl.h +55 -27
  21. package/cpp/ggml-backend-reg.cpp +591 -0
  22. package/cpp/ggml-backend.cpp +336 -955
  23. package/cpp/ggml-backend.h +70 -42
  24. package/cpp/ggml-common.h +57 -49
  25. package/cpp/ggml-cpp.h +39 -0
  26. package/cpp/ggml-cpu/amx/amx.cpp +221 -0
  27. package/cpp/ggml-cpu/amx/amx.h +8 -0
  28. package/cpp/ggml-cpu/amx/common.h +91 -0
  29. package/cpp/ggml-cpu/amx/mmq.cpp +2511 -0
  30. package/cpp/ggml-cpu/amx/mmq.h +10 -0
  31. package/cpp/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
  32. package/cpp/ggml-cpu/arch/arm/quants.c +4113 -0
  33. package/cpp/ggml-cpu/arch/arm/repack.cpp +2162 -0
  34. package/cpp/ggml-cpu/arch/x86/cpu-feats.cpp +327 -0
  35. package/cpp/ggml-cpu/arch/x86/quants.c +4310 -0
  36. package/cpp/ggml-cpu/arch/x86/repack.cpp +3284 -0
  37. package/cpp/ggml-cpu/arch-fallback.h +184 -0
  38. package/cpp/ggml-cpu/binary-ops.cpp +158 -0
  39. package/cpp/ggml-cpu/binary-ops.h +16 -0
  40. package/cpp/ggml-cpu/common.h +72 -0
  41. package/cpp/ggml-cpu/ggml-cpu-impl.h +511 -0
  42. package/cpp/ggml-cpu/ggml-cpu.c +3473 -0
  43. package/cpp/ggml-cpu/ggml-cpu.cpp +671 -0
  44. package/cpp/ggml-cpu/ops.cpp +9085 -0
  45. package/cpp/ggml-cpu/ops.h +111 -0
  46. package/cpp/ggml-cpu/quants.c +1157 -0
  47. package/cpp/ggml-cpu/quants.h +89 -0
  48. package/cpp/ggml-cpu/repack.cpp +1570 -0
  49. package/cpp/ggml-cpu/repack.h +98 -0
  50. package/cpp/ggml-cpu/simd-mappings.h +1006 -0
  51. package/cpp/ggml-cpu/traits.cpp +36 -0
  52. package/cpp/ggml-cpu/traits.h +38 -0
  53. package/cpp/ggml-cpu/unary-ops.cpp +186 -0
  54. package/cpp/ggml-cpu/unary-ops.h +28 -0
  55. package/cpp/ggml-cpu/vec.cpp +321 -0
  56. package/cpp/ggml-cpu/vec.h +973 -0
  57. package/cpp/ggml-cpu.h +143 -0
  58. package/cpp/ggml-impl.h +417 -23
  59. package/cpp/ggml-metal-impl.h +622 -0
  60. package/cpp/ggml-metal.h +9 -9
  61. package/cpp/ggml-metal.m +3451 -1344
  62. package/cpp/ggml-opt.cpp +1037 -0
  63. package/cpp/ggml-opt.h +237 -0
  64. package/cpp/ggml-quants.c +296 -10818
  65. package/cpp/ggml-quants.h +78 -125
  66. package/cpp/ggml-threading.cpp +12 -0
  67. package/cpp/ggml-threading.h +14 -0
  68. package/cpp/ggml-whisper-sim.metallib +0 -0
  69. package/cpp/ggml-whisper.metallib +0 -0
  70. package/cpp/ggml.c +4633 -21450
  71. package/cpp/ggml.h +320 -661
  72. package/cpp/gguf.cpp +1347 -0
  73. package/cpp/gguf.h +202 -0
  74. package/cpp/rn-whisper.cpp +4 -11
  75. package/cpp/whisper-arch.h +197 -0
  76. package/cpp/whisper.cpp +2022 -495
  77. package/cpp/whisper.h +75 -18
  78. package/ios/CMakeLists.txt +95 -0
  79. package/ios/RNWhisper.h +5 -0
  80. package/ios/RNWhisperAudioUtils.m +4 -0
  81. package/ios/RNWhisperContext.h +5 -0
  82. package/ios/RNWhisperContext.mm +4 -2
  83. package/ios/rnwhisper.xcframework/Info.plist +74 -0
  84. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-alloc.h +76 -0
  85. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend-impl.h +255 -0
  86. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend.h +354 -0
  87. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-common.h +1861 -0
  88. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-cpp.h +39 -0
  89. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-cpu.h +143 -0
  90. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-impl.h +603 -0
  91. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-metal-impl.h +622 -0
  92. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-metal.h +66 -0
  93. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-opt.h +237 -0
  94. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-quants.h +100 -0
  95. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-threading.h +14 -0
  96. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml.h +2221 -0
  97. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/gguf.h +202 -0
  98. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/rn-audioutils.h +14 -0
  99. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/rn-whisper-log.h +11 -0
  100. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/rn-whisper.h +52 -0
  101. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/whisper-arch.h +197 -0
  102. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/whisper.h +739 -0
  103. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Info.plist +0 -0
  104. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
  105. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/rnwhisper +0 -0
  106. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-alloc.h +76 -0
  107. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend-impl.h +255 -0
  108. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +354 -0
  109. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-common.h +1861 -0
  110. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpp.h +39 -0
  111. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpu.h +143 -0
  112. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +603 -0
  113. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal-impl.h +622 -0
  114. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal.h +66 -0
  115. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-opt.h +237 -0
  116. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-quants.h +100 -0
  117. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-threading.h +14 -0
  118. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +2221 -0
  119. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/gguf.h +202 -0
  120. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-audioutils.h +14 -0
  121. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-whisper-log.h +11 -0
  122. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-whisper.h +52 -0
  123. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/whisper-arch.h +197 -0
  124. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/whisper.h +739 -0
  125. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
  126. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +101 -0
  127. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
  128. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
  129. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-alloc.h +76 -0
  130. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend-impl.h +255 -0
  131. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend.h +354 -0
  132. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-common.h +1861 -0
  133. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-cpp.h +39 -0
  134. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-cpu.h +143 -0
  135. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-impl.h +603 -0
  136. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-metal-impl.h +622 -0
  137. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-metal.h +66 -0
  138. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-opt.h +237 -0
  139. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-quants.h +100 -0
  140. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-threading.h +14 -0
  141. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml.h +2221 -0
  142. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/gguf.h +202 -0
  143. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/rn-audioutils.h +14 -0
  144. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/rn-whisper-log.h +11 -0
  145. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/rn-whisper.h +52 -0
  146. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/whisper-arch.h +197 -0
  147. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/whisper.h +739 -0
  148. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Info.plist +0 -0
  149. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
  150. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/rnwhisper +0 -0
  151. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-alloc.h +76 -0
  152. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend-impl.h +255 -0
  153. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +354 -0
  154. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-common.h +1861 -0
  155. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpp.h +39 -0
  156. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpu.h +143 -0
  157. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +603 -0
  158. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal-impl.h +622 -0
  159. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal.h +66 -0
  160. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-opt.h +237 -0
  161. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-quants.h +100 -0
  162. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-threading.h +14 -0
  163. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +2221 -0
  164. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/gguf.h +202 -0
  165. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-audioutils.h +14 -0
  166. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-whisper-log.h +11 -0
  167. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-whisper.h +52 -0
  168. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/whisper-arch.h +197 -0
  169. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/whisper.h +739 -0
  170. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
  171. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +101 -0
  172. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
  173. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
  174. package/jest/mock.js +5 -0
  175. package/lib/commonjs/version.json +1 -1
  176. package/lib/module/version.json +1 -1
  177. package/package.json +10 -6
  178. package/src/version.json +1 -1
  179. package/whisper-rn.podspec +11 -18
  180. package/cpp/README.md +0 -4
  181. package/cpp/ggml-aarch64.c +0 -3209
  182. package/cpp/ggml-aarch64.h +0 -39
  183. package/cpp/ggml-cpu-impl.h +0 -614
package/cpp/ggml.h CHANGED
@@ -176,15 +176,15 @@
176
176
  #ifdef WSP_GGML_SHARED
177
177
  # if defined(_WIN32) && !defined(__MINGW32__)
178
178
  # ifdef WSP_GGML_BUILD
179
- # define WSP_GGML_API __declspec(dllexport)
179
+ # define WSP_GGML_API __declspec(dllexport) extern
180
180
  # else
181
- # define WSP_GGML_API __declspec(dllimport)
181
+ # define WSP_GGML_API __declspec(dllimport) extern
182
182
  # endif
183
183
  # else
184
- # define WSP_GGML_API __attribute__ ((visibility ("default")))
184
+ # define WSP_GGML_API __attribute__ ((visibility ("default"))) extern
185
185
  # endif
186
186
  #else
187
- # define WSP_GGML_API
187
+ # define WSP_GGML_API extern
188
188
  #endif
189
189
 
190
190
  // TODO: support for clang
@@ -198,7 +198,7 @@
198
198
 
199
199
  #ifndef __GNUC__
200
200
  # define WSP_GGML_ATTRIBUTE_FORMAT(...)
201
- #elif defined(__MINGW32__)
201
+ #elif defined(__MINGW32__) && !defined(__clang__)
202
202
  # define WSP_GGML_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__)))
203
203
  #else
204
204
  # define WSP_GGML_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__)))
@@ -237,13 +237,9 @@
237
237
  #define WSP_GGML_EXIT_SUCCESS 0
238
238
  #define WSP_GGML_EXIT_ABORTED 1
239
239
 
240
- #define WSP_GGML_ROPE_TYPE_NEOX 2
241
-
242
- #define WSP_GGUF_MAGIC "GGUF"
243
-
244
- #define WSP_GGUF_VERSION 3
245
-
246
- #define WSP_GGUF_DEFAULT_ALIGNMENT 32
240
+ #define WSP_GGML_ROPE_TYPE_NEOX 2
241
+ #define WSP_GGML_ROPE_TYPE_MROPE 8
242
+ #define WSP_GGML_ROPE_TYPE_VISION 24
247
243
 
248
244
  #define WSP_GGML_UNUSED(x) (void)(x)
249
245
 
@@ -384,24 +380,21 @@ extern "C" {
384
380
  WSP_GGML_TYPE_F64 = 28,
385
381
  WSP_GGML_TYPE_IQ1_M = 29,
386
382
  WSP_GGML_TYPE_BF16 = 30,
387
- WSP_GGML_TYPE_Q4_0_4_4 = 31,
388
- WSP_GGML_TYPE_Q4_0_4_8 = 32,
389
- WSP_GGML_TYPE_Q4_0_8_8 = 33,
383
+ // WSP_GGML_TYPE_Q4_0_4_4 = 31, support has been removed from gguf files
384
+ // WSP_GGML_TYPE_Q4_0_4_8 = 32,
385
+ // WSP_GGML_TYPE_Q4_0_8_8 = 33,
390
386
  WSP_GGML_TYPE_TQ1_0 = 34,
391
387
  WSP_GGML_TYPE_TQ2_0 = 35,
392
- WSP_GGML_TYPE_COUNT,
388
+ // WSP_GGML_TYPE_IQ4_NL_4_4 = 36,
389
+ // WSP_GGML_TYPE_IQ4_NL_4_8 = 37,
390
+ // WSP_GGML_TYPE_IQ4_NL_8_8 = 38,
391
+ WSP_GGML_TYPE_COUNT = 39,
393
392
  };
394
393
 
395
394
  // precision
396
395
  enum wsp_ggml_prec {
397
- WSP_GGML_PREC_DEFAULT,
398
- WSP_GGML_PREC_F32,
399
- };
400
-
401
- enum wsp_ggml_backend_type {
402
- WSP_GGML_BACKEND_TYPE_CPU = 0,
403
- WSP_GGML_BACKEND_TYPE_GPU = 10,
404
- WSP_GGML_BACKEND_TYPE_GPU_SPLIT = 20,
396
+ WSP_GGML_PREC_DEFAULT = 0, // stored as wsp_ggml_tensor.op_params, 0 by default
397
+ WSP_GGML_PREC_F32 = 10,
405
398
  };
406
399
 
407
400
  // model file types
@@ -430,9 +423,6 @@ extern "C" {
430
423
  WSP_GGML_FTYPE_MOSTLY_IQ4_XS = 22, // except 1d tensors
431
424
  WSP_GGML_FTYPE_MOSTLY_IQ1_M = 23, // except 1d tensors
432
425
  WSP_GGML_FTYPE_MOSTLY_BF16 = 24, // except 1d tensors
433
- WSP_GGML_FTYPE_MOSTLY_Q4_0_4_4 = 25, // except 1d tensors
434
- WSP_GGML_FTYPE_MOSTLY_Q4_0_4_8 = 26, // except 1d tensors
435
- WSP_GGML_FTYPE_MOSTLY_Q4_0_8_8 = 27, // except 1d tensors
436
426
  };
437
427
 
438
428
  // available tensor operations:
@@ -464,6 +454,7 @@ extern "C" {
464
454
  WSP_GGML_OP_RMS_NORM,
465
455
  WSP_GGML_OP_RMS_NORM_BACK,
466
456
  WSP_GGML_OP_GROUP_NORM,
457
+ WSP_GGML_OP_L2_NORM,
467
458
 
468
459
  WSP_GGML_OP_MUL_MAT,
469
460
  WSP_GGML_OP_MUL_MAT_ID,
@@ -490,12 +481,15 @@ extern "C" {
490
481
  WSP_GGML_OP_CONV_TRANSPOSE_1D,
491
482
  WSP_GGML_OP_IM2COL,
492
483
  WSP_GGML_OP_IM2COL_BACK,
484
+ WSP_GGML_OP_CONV_2D_DW,
493
485
  WSP_GGML_OP_CONV_TRANSPOSE_2D,
494
486
  WSP_GGML_OP_POOL_1D,
495
487
  WSP_GGML_OP_POOL_2D,
496
488
  WSP_GGML_OP_POOL_2D_BACK,
497
489
  WSP_GGML_OP_UPSCALE, // nearest interpolate
498
490
  WSP_GGML_OP_PAD,
491
+ WSP_GGML_OP_PAD_REFLECT_1D,
492
+ WSP_GGML_OP_ROLL,
499
493
  WSP_GGML_OP_ARANGE,
500
494
  WSP_GGML_OP_TIMESTEP_EMBEDDING,
501
495
  WSP_GGML_OP_ARGSORT,
@@ -509,21 +503,18 @@ extern "C" {
509
503
  WSP_GGML_OP_WIN_UNPART,
510
504
  WSP_GGML_OP_GET_REL_POS,
511
505
  WSP_GGML_OP_ADD_REL_POS,
512
- WSP_GGML_OP_RWKV_WKV,
506
+ WSP_GGML_OP_RWKV_WKV6,
507
+ WSP_GGML_OP_GATED_LINEAR_ATTN,
508
+ WSP_GGML_OP_RWKV_WKV7,
513
509
 
514
510
  WSP_GGML_OP_UNARY,
515
511
 
516
- WSP_GGML_OP_MAP_UNARY,
517
- WSP_GGML_OP_MAP_BINARY,
518
-
519
- WSP_GGML_OP_MAP_CUSTOM1_F32,
520
- WSP_GGML_OP_MAP_CUSTOM2_F32,
521
- WSP_GGML_OP_MAP_CUSTOM3_F32,
522
-
523
512
  WSP_GGML_OP_MAP_CUSTOM1,
524
513
  WSP_GGML_OP_MAP_CUSTOM2,
525
514
  WSP_GGML_OP_MAP_CUSTOM3,
526
515
 
516
+ WSP_GGML_OP_CUSTOM,
517
+
527
518
  WSP_GGML_OP_CROSS_ENTROPY_LOSS,
528
519
  WSP_GGML_OP_CROSS_ENTROPY_LOSS_BACK,
529
520
  WSP_GGML_OP_OPT_STEP_ADAMW,
@@ -546,6 +537,7 @@ extern "C" {
546
537
  WSP_GGML_UNARY_OP_HARDSWISH,
547
538
  WSP_GGML_UNARY_OP_HARDSIGMOID,
548
539
  WSP_GGML_UNARY_OP_EXP,
540
+ WSP_GGML_UNARY_OP_GELU_ERF,
549
541
 
550
542
  WSP_GGML_UNARY_OP_COUNT,
551
543
  };
@@ -558,10 +550,10 @@ extern "C" {
558
550
 
559
551
  enum wsp_ggml_log_level {
560
552
  WSP_GGML_LOG_LEVEL_NONE = 0,
561
- WSP_GGML_LOG_LEVEL_INFO = 1,
562
- WSP_GGML_LOG_LEVEL_WARN = 2,
563
- WSP_GGML_LOG_LEVEL_ERROR = 3,
564
- WSP_GGML_LOG_LEVEL_DEBUG = 4,
553
+ WSP_GGML_LOG_LEVEL_DEBUG = 1,
554
+ WSP_GGML_LOG_LEVEL_INFO = 2,
555
+ WSP_GGML_LOG_LEVEL_WARN = 3,
556
+ WSP_GGML_LOG_LEVEL_ERROR = 4,
565
557
  WSP_GGML_LOG_LEVEL_CONT = 5, // continue previous log
566
558
  };
567
559
 
@@ -573,12 +565,17 @@ extern "C" {
573
565
  WSP_GGML_TENSOR_FLAG_LOSS = 8, // ...defines loss for numerical optimization (multiple loss tensors add up)
574
566
  };
575
567
 
568
+ struct wsp_ggml_init_params {
569
+ // memory pool
570
+ size_t mem_size; // bytes
571
+ void * mem_buffer; // if NULL, memory will be allocated internally
572
+ bool no_alloc; // don't allocate memory for the tensor data
573
+ };
574
+
576
575
  // n-dimensional tensor
577
576
  struct wsp_ggml_tensor {
578
577
  enum wsp_ggml_type type;
579
578
 
580
- WSP_GGML_DEPRECATED(enum wsp_ggml_backend_type backend, "use the buffer type to find the storage location of the tensor");
581
-
582
579
  struct wsp_ggml_backend_buffer * buffer;
583
580
 
584
581
  int64_t ne[WSP_GGML_MAX_DIMS]; // number of elements
@@ -595,7 +592,6 @@ extern "C" {
595
592
 
596
593
  int32_t flags;
597
594
 
598
- struct wsp_ggml_tensor * grad;
599
595
  struct wsp_ggml_tensor * src[WSP_GGML_MAX_SRC];
600
596
 
601
597
  // source tensor and offset for views
@@ -608,7 +604,7 @@ extern "C" {
608
604
 
609
605
  void * extra; // extra things e.g. for ggml-cuda.cu
610
606
 
611
- // char padding[4];
607
+ char padding[8];
612
608
  };
613
609
 
614
610
  static const size_t WSP_GGML_TENSOR_SIZE = sizeof(struct wsp_ggml_tensor);
@@ -618,67 +614,6 @@ extern "C" {
618
614
  // If it returns true, the computation is aborted
619
615
  typedef bool (*wsp_ggml_abort_callback)(void * data);
620
616
 
621
- // Scheduling priorities
622
- enum wsp_ggml_sched_priority {
623
- WSP_GGML_SCHED_PRIO_NORMAL,
624
- WSP_GGML_SCHED_PRIO_MEDIUM,
625
- WSP_GGML_SCHED_PRIO_HIGH,
626
- WSP_GGML_SCHED_PRIO_REALTIME
627
- };
628
-
629
- // Threadpool params
630
- // Use wsp_ggml_threadpool_params_default() or wsp_ggml_threadpool_params_init() to populate the defaults
631
- struct wsp_ggml_threadpool_params {
632
- bool cpumask[WSP_GGML_MAX_N_THREADS]; // mask of cpu cores (all-zeros means use default affinity settings)
633
- int n_threads; // number of threads
634
- enum wsp_ggml_sched_priority prio; // thread priority
635
- uint32_t poll; // polling level (0 - no polling, 100 - aggressive polling)
636
- bool strict_cpu; // strict cpu placement
637
- bool paused; // start in paused state
638
- };
639
-
640
- struct wsp_ggml_threadpool; // forward declaration, see ggml.c
641
-
642
- typedef struct wsp_ggml_threadpool * wsp_ggml_threadpool_t;
643
-
644
- // the compute plan that needs to be prepared for wsp_ggml_graph_compute()
645
- // since https://github.com/ggerganov/ggml/issues/287
646
- struct wsp_ggml_cplan {
647
- size_t work_size; // size of work buffer, calculated by `wsp_ggml_graph_plan()`
648
- uint8_t * work_data; // work buffer, to be allocated by caller before calling to `wsp_ggml_graph_compute()`
649
-
650
- int n_threads;
651
- struct wsp_ggml_threadpool * threadpool;
652
-
653
- // abort wsp_ggml_graph_compute when true
654
- wsp_ggml_abort_callback abort_callback;
655
- void * abort_callback_data;
656
- };
657
-
658
- // scratch buffer
659
- // TODO: deprecate and remove
660
- struct wsp_ggml_scratch {
661
- size_t offs;
662
- size_t size;
663
- void * data;
664
- };
665
-
666
- struct wsp_ggml_init_params {
667
- // memory pool
668
- size_t mem_size; // bytes
669
- void * mem_buffer; // if NULL, memory will be allocated internally
670
- bool no_alloc; // don't allocate memory for the tensor data
671
- };
672
-
673
- // numa strategies
674
- enum wsp_ggml_numa_strategy {
675
- WSP_GGML_NUMA_STRATEGY_DISABLED = 0,
676
- WSP_GGML_NUMA_STRATEGY_DISTRIBUTE = 1,
677
- WSP_GGML_NUMA_STRATEGY_ISOLATE = 2,
678
- WSP_GGML_NUMA_STRATEGY_NUMACTL = 3,
679
- WSP_GGML_NUMA_STRATEGY_MIRROR = 4,
680
- WSP_GGML_NUMA_STRATEGY_COUNT
681
- };
682
617
 
683
618
  //
684
619
  // GUID
@@ -701,9 +636,6 @@ extern "C" {
701
636
  // accepts a UTF-8 path, even on Windows
702
637
  WSP_GGML_API FILE * wsp_ggml_fopen(const char * fname, const char * mode);
703
638
 
704
- WSP_GGML_API void wsp_ggml_numa_init(enum wsp_ggml_numa_strategy numa); // call once for better performance on NUMA systems
705
- WSP_GGML_API bool wsp_ggml_is_numa(void); // true if init detected that system has >1 NUMA node
706
-
707
639
  WSP_GGML_API void wsp_ggml_print_object (const struct wsp_ggml_object * obj);
708
640
  WSP_GGML_API void wsp_ggml_print_objects(const struct wsp_ggml_context * ctx);
709
641
 
@@ -743,11 +675,18 @@ extern "C" {
743
675
  WSP_GGML_API bool wsp_ggml_is_3d (const struct wsp_ggml_tensor * tensor);
744
676
  WSP_GGML_API int wsp_ggml_n_dims (const struct wsp_ggml_tensor * tensor); // returns 1 for scalars
745
677
 
678
+ // returns whether the tensor elements can be iterated over with a flattened index (no gaps, no permutation)
746
679
  WSP_GGML_API bool wsp_ggml_is_contiguous (const struct wsp_ggml_tensor * tensor);
747
680
  WSP_GGML_API bool wsp_ggml_is_contiguous_0(const struct wsp_ggml_tensor * tensor); // same as wsp_ggml_is_contiguous()
748
681
  WSP_GGML_API bool wsp_ggml_is_contiguous_1(const struct wsp_ggml_tensor * tensor); // contiguous for dims >= 1
749
682
  WSP_GGML_API bool wsp_ggml_is_contiguous_2(const struct wsp_ggml_tensor * tensor); // contiguous for dims >= 2
750
683
 
684
+ // returns whether the tensor elements are allocated as one contiguous block of memory (no gaps, but permutation ok)
685
+ WSP_GGML_API bool wsp_ggml_is_contiguously_allocated(const struct wsp_ggml_tensor * tensor);
686
+
687
+ // true for tensor that is stored in memory as CxWxHxN and has been permuted to WxHxCxN
688
+ WSP_GGML_API bool wsp_ggml_is_contiguous_channels(const struct wsp_ggml_tensor * tensor);
689
+
751
690
  WSP_GGML_API bool wsp_ggml_are_same_shape (const struct wsp_ggml_tensor * t0, const struct wsp_ggml_tensor * t1);
752
691
  WSP_GGML_API bool wsp_ggml_are_same_stride(const struct wsp_ggml_tensor * t0, const struct wsp_ggml_tensor * t1);
753
692
 
@@ -766,7 +705,6 @@ extern "C" {
766
705
 
767
706
  WSP_GGML_API size_t wsp_ggml_used_mem(const struct wsp_ggml_context * ctx);
768
707
 
769
- WSP_GGML_API size_t wsp_ggml_set_scratch (struct wsp_ggml_context * ctx, struct wsp_ggml_scratch scratch);
770
708
  WSP_GGML_API bool wsp_ggml_get_no_alloc(struct wsp_ggml_context * ctx);
771
709
  WSP_GGML_API void wsp_ggml_set_no_alloc(struct wsp_ggml_context * ctx, bool no_alloc);
772
710
 
@@ -806,8 +744,7 @@ extern "C" {
806
744
  int64_t ne2,
807
745
  int64_t ne3);
808
746
 
809
- WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_new_i32(struct wsp_ggml_context * ctx, int32_t value);
810
- WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_new_f32(struct wsp_ggml_context * ctx, float value);
747
+ WSP_GGML_API void * wsp_ggml_new_buffer(struct wsp_ggml_context * ctx, size_t nbytes);
811
748
 
812
749
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_dup_tensor (struct wsp_ggml_context * ctx, const struct wsp_ggml_tensor * src);
813
750
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_view_tensor(struct wsp_ggml_context * ctx, struct wsp_ggml_tensor * src);
@@ -817,35 +754,25 @@ extern "C" {
817
754
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_get_next_tensor (const struct wsp_ggml_context * ctx, struct wsp_ggml_tensor * tensor);
818
755
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_get_tensor(struct wsp_ggml_context * ctx, const char * name);
819
756
 
820
- WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_set_zero(struct wsp_ggml_tensor * tensor);
821
- WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_set_i32 (struct wsp_ggml_tensor * tensor, int32_t value);
822
- WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_set_f32 (struct wsp_ggml_tensor * tensor, float value);
823
-
824
757
  // Converts a flat index into coordinates
825
- WSP_GGML_API void wsp_ggml_unravel_index(const struct wsp_ggml_tensor * tensor, int64_t i, int64_t * i0, int64_t * i1, int64_t * i2, int64_t * i3);
826
-
827
- WSP_GGML_API int32_t wsp_ggml_get_i32_1d(const struct wsp_ggml_tensor * tensor, int i);
828
- WSP_GGML_API void wsp_ggml_set_i32_1d(const struct wsp_ggml_tensor * tensor, int i, int32_t value);
829
-
830
- WSP_GGML_API int32_t wsp_ggml_get_i32_nd(const struct wsp_ggml_tensor * tensor, int i0, int i1, int i2, int i3);
831
- WSP_GGML_API void wsp_ggml_set_i32_nd(const struct wsp_ggml_tensor * tensor, int i0, int i1, int i2, int i3, int32_t value);
758
+ WSP_GGML_API void wsp_ggml_unravel_index(const struct wsp_ggml_tensor * tensor, int64_t i, int64_t * i0, int64_t * i1, int64_t * i2, int64_t * i3);
832
759
 
833
- WSP_GGML_API float wsp_ggml_get_f32_1d(const struct wsp_ggml_tensor * tensor, int i);
834
- WSP_GGML_API void wsp_ggml_set_f32_1d(const struct wsp_ggml_tensor * tensor, int i, float value);
835
-
836
- WSP_GGML_API float wsp_ggml_get_f32_nd(const struct wsp_ggml_tensor * tensor, int i0, int i1, int i2, int i3);
837
- WSP_GGML_API void wsp_ggml_set_f32_nd(const struct wsp_ggml_tensor * tensor, int i0, int i1, int i2, int i3, float value);
760
+ WSP_GGML_API enum wsp_ggml_unary_op wsp_ggml_get_unary_op(const struct wsp_ggml_tensor * tensor);
838
761
 
839
762
  WSP_GGML_API void * wsp_ggml_get_data (const struct wsp_ggml_tensor * tensor);
840
763
  WSP_GGML_API float * wsp_ggml_get_data_f32(const struct wsp_ggml_tensor * tensor);
841
764
 
842
- WSP_GGML_API enum wsp_ggml_unary_op wsp_ggml_get_unary_op(const struct wsp_ggml_tensor * tensor);
843
-
844
765
  WSP_GGML_API const char * wsp_ggml_get_name (const struct wsp_ggml_tensor * tensor);
845
766
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_set_name ( struct wsp_ggml_tensor * tensor, const char * name);
846
767
  WSP_GGML_ATTRIBUTE_FORMAT(2, 3)
847
768
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_format_name( struct wsp_ggml_tensor * tensor, const char * fmt, ...);
848
769
 
770
+ // Tensor flags
771
+ WSP_GGML_API void wsp_ggml_set_input(struct wsp_ggml_tensor * tensor);
772
+ WSP_GGML_API void wsp_ggml_set_output(struct wsp_ggml_tensor * tensor);
773
+ WSP_GGML_API void wsp_ggml_set_param(struct wsp_ggml_tensor * tensor);
774
+ WSP_GGML_API void wsp_ggml_set_loss(struct wsp_ggml_tensor * tensor);
775
+
849
776
  //
850
777
  // operations on tensors with backpropagation
851
778
  //
@@ -1009,11 +936,20 @@ extern "C" {
1009
936
  struct wsp_ggml_tensor * a,
1010
937
  struct wsp_ggml_tensor * b);
1011
938
 
939
+ // repeat a to the specified shape
940
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_repeat_4d(
941
+ struct wsp_ggml_context * ctx,
942
+ struct wsp_ggml_tensor * a,
943
+ int64_t ne0,
944
+ int64_t ne1,
945
+ int64_t ne2,
946
+ int64_t ne3);
947
+
1012
948
  // sums repetitions in a into shape of b
1013
949
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_repeat_back(
1014
950
  struct wsp_ggml_context * ctx,
1015
951
  struct wsp_ggml_tensor * a,
1016
- struct wsp_ggml_tensor * b);
952
+ struct wsp_ggml_tensor * b); // sum up values that are adjacent in dims > 0 instead of repeated with same stride
1017
953
 
1018
954
  // concat a and b along dim
1019
955
  // used in stable-diffusion
@@ -1099,6 +1035,16 @@ extern "C" {
1099
1035
  struct wsp_ggml_context * ctx,
1100
1036
  struct wsp_ggml_tensor * a);
1101
1037
 
1038
+ // GELU using erf (error function) when possible
1039
+ // some backends may fallback to approximation based on Abramowitz and Stegun formula
1040
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_gelu_erf(
1041
+ struct wsp_ggml_context * ctx,
1042
+ struct wsp_ggml_tensor * a);
1043
+
1044
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_gelu_erf_inplace(
1045
+ struct wsp_ggml_context * ctx,
1046
+ struct wsp_ggml_tensor * a);
1047
+
1102
1048
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_gelu_quick(
1103
1049
  struct wsp_ggml_context * ctx,
1104
1050
  struct wsp_ggml_tensor * a);
@@ -1175,6 +1121,18 @@ extern "C" {
1175
1121
  int n_groups,
1176
1122
  float eps);
1177
1123
 
1124
+ // l2 normalize along rows
1125
+ // used in rwkv v7
1126
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_l2_norm(
1127
+ struct wsp_ggml_context * ctx,
1128
+ struct wsp_ggml_tensor * a,
1129
+ float eps);
1130
+
1131
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_l2_norm_inplace(
1132
+ struct wsp_ggml_context * ctx,
1133
+ struct wsp_ggml_tensor * a,
1134
+ float eps);
1135
+
1178
1136
  // a - x
1179
1137
  // b - dy
1180
1138
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_rms_norm_back(
@@ -1464,16 +1422,20 @@ extern "C" {
1464
1422
  float scale,
1465
1423
  float max_bias);
1466
1424
 
1467
- WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_soft_max_back(
1425
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_soft_max_ext_back(
1468
1426
  struct wsp_ggml_context * ctx,
1469
1427
  struct wsp_ggml_tensor * a,
1470
- struct wsp_ggml_tensor * b);
1428
+ struct wsp_ggml_tensor * b,
1429
+ float scale,
1430
+ float max_bias);
1471
1431
 
1472
1432
  // in-place, returns view(a)
1473
- WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_soft_max_back_inplace(
1433
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_soft_max_ext_back_inplace(
1474
1434
  struct wsp_ggml_context * ctx,
1475
1435
  struct wsp_ggml_tensor * a,
1476
- struct wsp_ggml_tensor * b);
1436
+ struct wsp_ggml_tensor * b,
1437
+ float scale,
1438
+ float max_bias);
1477
1439
 
1478
1440
  // rotary position embedding
1479
1441
  // if (mode & 1) - skip n_past elements (NOT SUPPORTED)
@@ -1512,6 +1474,22 @@ extern "C" {
1512
1474
  float beta_fast,
1513
1475
  float beta_slow);
1514
1476
 
1477
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_rope_multi(
1478
+ struct wsp_ggml_context * ctx,
1479
+ struct wsp_ggml_tensor * a,
1480
+ struct wsp_ggml_tensor * b,
1481
+ struct wsp_ggml_tensor * c,
1482
+ int n_dims,
1483
+ int sections[4],
1484
+ int mode,
1485
+ int n_ctx_orig,
1486
+ float freq_base,
1487
+ float freq_scale,
1488
+ float ext_factor,
1489
+ float attn_factor,
1490
+ float beta_fast,
1491
+ float beta_slow);
1492
+
1515
1493
  // in-place, returns view(a)
1516
1494
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_rope_ext_inplace(
1517
1495
  struct wsp_ggml_context * ctx,
@@ -1559,12 +1537,12 @@ extern "C" {
1559
1537
  "use wsp_ggml_rope_ext_inplace instead");
1560
1538
 
1561
1539
  // compute correction dims for YaRN RoPE scaling
1562
- void wsp_ggml_rope_yarn_corr_dims(
1540
+ WSP_GGML_API void wsp_ggml_rope_yarn_corr_dims(
1563
1541
  int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, float dims[2]);
1564
1542
 
1565
1543
  // rotary position embedding backward, i.e compute dx from dy
1566
1544
  // a - dy
1567
- WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_rope_back(
1545
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_rope_ext_back(
1568
1546
  struct wsp_ggml_context * ctx,
1569
1547
  struct wsp_ggml_tensor * a, // gradients of wsp_ggml_rope result
1570
1548
  struct wsp_ggml_tensor * b, // positions
@@ -1579,6 +1557,23 @@ extern "C" {
1579
1557
  float beta_fast,
1580
1558
  float beta_slow);
1581
1559
 
1560
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_rope_multi_back(
1561
+ struct wsp_ggml_context * ctx,
1562
+ struct wsp_ggml_tensor * a,
1563
+ struct wsp_ggml_tensor * b,
1564
+ struct wsp_ggml_tensor * c,
1565
+ int n_dims,
1566
+ int sections[4],
1567
+ int mode,
1568
+ int n_ctx_orig,
1569
+ float freq_base,
1570
+ float freq_scale,
1571
+ float ext_factor,
1572
+ float attn_factor,
1573
+ float beta_fast,
1574
+ float beta_slow);
1575
+
1576
+
1582
1577
  // clamp
1583
1578
  // in-place, returns view(a)
1584
1579
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_clamp(
@@ -1615,17 +1610,6 @@ extern "C" {
1615
1610
  int d1, // dilation dimension 1
1616
1611
  bool is_2D);
1617
1612
 
1618
- WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_conv_depthwise_2d(
1619
- struct wsp_ggml_context * ctx,
1620
- struct wsp_ggml_tensor * a, // convolution kernel
1621
- struct wsp_ggml_tensor * b, // data
1622
- int s0, // stride dimension 0
1623
- int s1, // stride dimension 1
1624
- int p0, // padding dimension 0
1625
- int p1, // padding dimension 1
1626
- int d0, // dilation dimension 0
1627
- int d1); // dilation dimension 1
1628
-
1629
1613
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_conv_1d(
1630
1614
  struct wsp_ggml_context * ctx,
1631
1615
  struct wsp_ggml_tensor * a, // convolution kernel
@@ -1643,6 +1627,23 @@ extern "C" {
1643
1627
  int s, // stride
1644
1628
  int d); // dilation
1645
1629
 
1630
+ // depthwise
1631
+ // TODO: this is very likely wrong for some cases! - needs more testing
1632
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_conv_1d_dw(
1633
+ struct wsp_ggml_context * ctx,
1634
+ struct wsp_ggml_tensor * a, // convolution kernel
1635
+ struct wsp_ggml_tensor * b, // data
1636
+ int s0, // stride
1637
+ int p0, // padding
1638
+ int d0); // dilation
1639
+
1640
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_conv_1d_dw_ph(
1641
+ struct wsp_ggml_context * ctx,
1642
+ struct wsp_ggml_tensor * a, // convolution kernel
1643
+ struct wsp_ggml_tensor * b, // data
1644
+ int s0, // stride
1645
+ int d0); // dilation
1646
+
1646
1647
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_conv_transpose_1d(
1647
1648
  struct wsp_ggml_context * ctx,
1648
1649
  struct wsp_ggml_tensor * a, // convolution kernel
@@ -1662,7 +1663,6 @@ extern "C" {
1662
1663
  int d0, // dilation dimension 0
1663
1664
  int d1); // dilation dimension 1
1664
1665
 
1665
-
1666
1666
  // kernel size is a->ne[0] x a->ne[1]
1667
1667
  // stride is equal to kernel size
1668
1668
  // padding is zero
@@ -1689,6 +1689,34 @@ extern "C" {
1689
1689
  struct wsp_ggml_tensor * a,
1690
1690
  struct wsp_ggml_tensor * b);
1691
1691
 
1692
+ // depthwise (via im2col and mul_mat)
1693
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_conv_2d_dw(
1694
+ struct wsp_ggml_context * ctx,
1695
+ struct wsp_ggml_tensor * a, // convolution kernel
1696
+ struct wsp_ggml_tensor * b, // data
1697
+ int s0, // stride dimension 0
1698
+ int s1, // stride dimension 1
1699
+ int p0, // padding dimension 0
1700
+ int p1, // padding dimension 1
1701
+ int d0, // dilation dimension 0
1702
+ int d1); // dilation dimension 1
1703
+
1704
+ // Depthwise 2D convolution
1705
+ // may be faster than wsp_ggml_conv_2d_dw, but not available in all backends
1706
+ // a: KW KH 1 C convolution kernel
1707
+ // b: W H C N input data
1708
+ // res: W_out H_out C N
1709
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_conv_2d_dw_direct(
1710
+ struct wsp_ggml_context * ctx,
1711
+ struct wsp_ggml_tensor * a,
1712
+ struct wsp_ggml_tensor * b,
1713
+ int stride0,
1714
+ int stride1,
1715
+ int pad0,
1716
+ int pad1,
1717
+ int dilation0,
1718
+ int dilation1);
1719
+
1692
1720
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_conv_transpose_2d_p0(
1693
1721
  struct wsp_ggml_context * ctx,
1694
1722
  struct wsp_ggml_tensor * a,
@@ -1734,24 +1762,29 @@ extern "C" {
1734
1762
  float p0,
1735
1763
  float p1);
1736
1764
 
1737
- // nearest interpolate
1765
+ enum wsp_ggml_scale_mode {
1766
+ WSP_GGML_SCALE_MODE_NEAREST = 0,
1767
+ WSP_GGML_SCALE_MODE_BILINEAR = 1,
1768
+ };
1769
+
1770
+ // interpolate
1738
1771
  // multiplies ne0 and ne1 by scale factor
1739
- // used in stable-diffusion
1740
1772
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_upscale(
1741
1773
  struct wsp_ggml_context * ctx,
1742
1774
  struct wsp_ggml_tensor * a,
1743
- int scale_factor);
1775
+ int scale_factor,
1776
+ enum wsp_ggml_scale_mode mode);
1744
1777
 
1745
- // nearest interpolate
1746
- // nearest interpolate to specified dimensions
1747
- // used in tortoise.cpp
1778
+ // interpolate
1779
+ // interpolate scale to specified dimensions
1748
1780
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_upscale_ext(
1749
1781
  struct wsp_ggml_context * ctx,
1750
1782
  struct wsp_ggml_tensor * a,
1751
1783
  int ne0,
1752
1784
  int ne1,
1753
1785
  int ne2,
1754
- int ne3);
1786
+ int ne3,
1787
+ enum wsp_ggml_scale_mode mode);
1755
1788
 
1756
1789
  // pad each dimension with zeros: [x, ..., x] -> [x, ..., x, 0, ..., 0]
1757
1790
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_pad(
@@ -1762,6 +1795,24 @@ extern "C" {
1762
1795
  int p2,
1763
1796
  int p3);
1764
1797
 
1798
+ // pad each dimension with reflection: [a, b, c, d] -> [b, a, b, c, d, c]
1799
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_pad_reflect_1d(
1800
+ struct wsp_ggml_context * ctx,
1801
+ struct wsp_ggml_tensor * a,
1802
+ int p0,
1803
+ int p1);
1804
+
1805
+ // Move tensor elements by an offset given for each dimension. Elements that
1806
+ // are shifted beyond the last position are wrapped around to the beginning.
1807
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_roll(
1808
+ struct wsp_ggml_context * ctx,
1809
+ struct wsp_ggml_tensor * a,
1810
+ int shift0,
1811
+ int shift1,
1812
+ int shift2,
1813
+ int shift3);
1814
+
1815
+
1765
1816
  // Ref: https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/diffusionmodules/util.py#L151
1766
1817
  // timesteps: [N,]
1767
1818
  // return: [N, dim]
@@ -1794,13 +1845,13 @@ extern "C" {
1794
1845
  struct wsp_ggml_tensor * a,
1795
1846
  int k);
1796
1847
 
1797
- #define WSP_GGML_KQ_MASK_PAD 32
1848
+ #define WSP_GGML_KQ_MASK_PAD 64
1798
1849
 
1799
- // q: [n_embd, n_batch, n_head, 1]
1800
- // k: [n_embd, n_kv, n_head_kv, 1]
1801
- // v: [n_embd, n_kv, n_head_kv, 1] !! not transposed !!
1802
- // mask: [n_kv, n_batch_pad, 1, 1] !! n_batch_pad = WSP_GGML_PAD(n_batch, WSP_GGML_KQ_MASK_PAD) !!
1803
- // res: [n_embd, n_head, n_batch, 1] !! permuted !!
1850
+ // q: [n_embd_k, n_batch, n_head, 1]
1851
+ // k: [n_embd_k, n_kv, n_head_kv, 1]
1852
+ // v: [n_embd_v, n_kv, n_head_kv, 1] !! not transposed !!
1853
+ // mask: [n_kv, n_batch_pad, 1, 1] !! n_batch_pad = WSP_GGML_PAD(n_batch, WSP_GGML_KQ_MASK_PAD) !!
1854
+ // res: [n_embd_v, n_head, n_batch, 1] !! permuted !!
1804
1855
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_flash_attn_ext(
1805
1856
  struct wsp_ggml_context * ctx,
1806
1857
  struct wsp_ggml_tensor * q,
@@ -1815,6 +1866,9 @@ extern "C" {
1815
1866
  struct wsp_ggml_tensor * a,
1816
1867
  enum wsp_ggml_prec prec);
1817
1868
 
1869
+ WSP_GGML_API enum wsp_ggml_prec wsp_ggml_flash_attn_ext_get_prec(
1870
+ const struct wsp_ggml_tensor * a);
1871
+
1818
1872
  // TODO: needs to be adapted to wsp_ggml_flash_attn_ext
1819
1873
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_flash_attn_back(
1820
1874
  struct wsp_ggml_context * ctx,
@@ -1888,7 +1942,7 @@ extern "C" {
1888
1942
  struct wsp_ggml_tensor * pw,
1889
1943
  struct wsp_ggml_tensor * ph);
1890
1944
 
1891
- WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_rwkv_wkv(
1945
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_rwkv_wkv6(
1892
1946
  struct wsp_ggml_context * ctx,
1893
1947
  struct wsp_ggml_tensor * k,
1894
1948
  struct wsp_ggml_tensor * v,
@@ -1897,84 +1951,26 @@ extern "C" {
1897
1951
  struct wsp_ggml_tensor * td,
1898
1952
  struct wsp_ggml_tensor * state);
1899
1953
 
1900
- // custom operators
1954
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_gated_linear_attn(
1955
+ struct wsp_ggml_context * ctx,
1956
+ struct wsp_ggml_tensor * k,
1957
+ struct wsp_ggml_tensor * v,
1958
+ struct wsp_ggml_tensor * q,
1959
+ struct wsp_ggml_tensor * g,
1960
+ struct wsp_ggml_tensor * state,
1961
+ float scale);
1901
1962
 
1902
- typedef void (*wsp_ggml_unary_op_f32_t) (const int, float *, const float *);
1903
- typedef void (*wsp_ggml_binary_op_f32_t)(const int, float *, const float *, const float *);
1904
-
1905
- typedef void (*wsp_ggml_custom1_op_f32_t)(struct wsp_ggml_tensor *, const struct wsp_ggml_tensor *);
1906
- typedef void (*wsp_ggml_custom2_op_f32_t)(struct wsp_ggml_tensor *, const struct wsp_ggml_tensor *, const struct wsp_ggml_tensor *);
1907
- typedef void (*wsp_ggml_custom3_op_f32_t)(struct wsp_ggml_tensor *, const struct wsp_ggml_tensor *, const struct wsp_ggml_tensor *, const struct wsp_ggml_tensor *);
1908
-
1909
- WSP_GGML_DEPRECATED(WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_map_unary_f32(
1910
- struct wsp_ggml_context * ctx,
1911
- struct wsp_ggml_tensor * a,
1912
- wsp_ggml_unary_op_f32_t fun),
1913
- "use wsp_ggml_map_custom1 instead");
1914
-
1915
- WSP_GGML_DEPRECATED(WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_map_unary_inplace_f32(
1916
- struct wsp_ggml_context * ctx,
1917
- struct wsp_ggml_tensor * a,
1918
- wsp_ggml_unary_op_f32_t fun),
1919
- "use wsp_ggml_map_custom1_inplace instead");
1920
-
1921
- WSP_GGML_DEPRECATED(WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_map_binary_f32(
1922
- struct wsp_ggml_context * ctx,
1923
- struct wsp_ggml_tensor * a,
1924
- struct wsp_ggml_tensor * b,
1925
- wsp_ggml_binary_op_f32_t fun),
1926
- "use wsp_ggml_map_custom2 instead");
1927
-
1928
- WSP_GGML_DEPRECATED(WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_map_binary_inplace_f32(
1929
- struct wsp_ggml_context * ctx,
1930
- struct wsp_ggml_tensor * a,
1931
- struct wsp_ggml_tensor * b,
1932
- wsp_ggml_binary_op_f32_t fun),
1933
- "use wsp_ggml_map_custom2_inplace instead");
1934
-
1935
- WSP_GGML_DEPRECATED(WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_map_custom1_f32(
1936
- struct wsp_ggml_context * ctx,
1937
- struct wsp_ggml_tensor * a,
1938
- wsp_ggml_custom1_op_f32_t fun),
1939
- "use wsp_ggml_map_custom1 instead");
1940
-
1941
- WSP_GGML_DEPRECATED(WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_map_custom1_inplace_f32(
1942
- struct wsp_ggml_context * ctx,
1943
- struct wsp_ggml_tensor * a,
1944
- wsp_ggml_custom1_op_f32_t fun),
1945
- "use wsp_ggml_map_custom1_inplace instead");
1946
-
1947
- WSP_GGML_DEPRECATED(WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_map_custom2_f32(
1948
- struct wsp_ggml_context * ctx,
1949
- struct wsp_ggml_tensor * a,
1950
- struct wsp_ggml_tensor * b,
1951
- wsp_ggml_custom2_op_f32_t fun),
1952
- "use wsp_ggml_map_custom2 instead");
1953
-
1954
- WSP_GGML_DEPRECATED(WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_map_custom2_inplace_f32(
1955
- struct wsp_ggml_context * ctx,
1956
- struct wsp_ggml_tensor * a,
1957
- struct wsp_ggml_tensor * b,
1958
- wsp_ggml_custom2_op_f32_t fun),
1959
- "use wsp_ggml_map_custom2_inplace instead");
1960
-
1961
- WSP_GGML_DEPRECATED(WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_map_custom3_f32(
1962
- struct wsp_ggml_context * ctx,
1963
- struct wsp_ggml_tensor * a,
1964
- struct wsp_ggml_tensor * b,
1965
- struct wsp_ggml_tensor * c,
1966
- wsp_ggml_custom3_op_f32_t fun),
1967
- "use wsp_ggml_map_custom3 instead");
1968
-
1969
- WSP_GGML_DEPRECATED(WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_map_custom3_inplace_f32(
1970
- struct wsp_ggml_context * ctx,
1971
- struct wsp_ggml_tensor * a,
1972
- struct wsp_ggml_tensor * b,
1973
- struct wsp_ggml_tensor * c,
1974
- wsp_ggml_custom3_op_f32_t fun),
1975
- "use wsp_ggml_map_custom3_inplace instead");
1976
-
1977
- // custom operators v2
1963
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_rwkv_wkv7(
1964
+ struct wsp_ggml_context * ctx,
1965
+ struct wsp_ggml_tensor * r,
1966
+ struct wsp_ggml_tensor * w,
1967
+ struct wsp_ggml_tensor * k,
1968
+ struct wsp_ggml_tensor * v,
1969
+ struct wsp_ggml_tensor * a,
1970
+ struct wsp_ggml_tensor * b,
1971
+ struct wsp_ggml_tensor * state);
1972
+
1973
+ // custom operators
1978
1974
 
1979
1975
  typedef void (*wsp_ggml_custom1_op_t)(struct wsp_ggml_tensor * dst , const struct wsp_ggml_tensor * a, int ith, int nth, void * userdata);
1980
1976
  typedef void (*wsp_ggml_custom2_op_t)(struct wsp_ggml_tensor * dst , const struct wsp_ggml_tensor * a, const struct wsp_ggml_tensor * b, int ith, int nth, void * userdata);
@@ -2031,6 +2027,30 @@ extern "C" {
2031
2027
  int n_tasks,
2032
2028
  void * userdata);
2033
2029
 
2030
+ typedef void (*wsp_ggml_custom_op_t)(struct wsp_ggml_tensor * dst , int ith, int nth, void * userdata);
2031
+
2032
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_custom_4d(
2033
+ struct wsp_ggml_context * ctx,
2034
+ enum wsp_ggml_type type,
2035
+ int64_t ne0,
2036
+ int64_t ne1,
2037
+ int64_t ne2,
2038
+ int64_t ne3,
2039
+ struct wsp_ggml_tensor ** args,
2040
+ int n_args,
2041
+ wsp_ggml_custom_op_t fun,
2042
+ int n_tasks,
2043
+ void * userdata);
2044
+
2045
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_custom_inplace(
2046
+ struct wsp_ggml_context * ctx,
2047
+ struct wsp_ggml_tensor * a,
2048
+ struct wsp_ggml_tensor ** args,
2049
+ int n_args,
2050
+ wsp_ggml_custom_op_t fun,
2051
+ int n_tasks,
2052
+ void * userdata);
2053
+
2034
2054
  // loss function
2035
2055
 
2036
2056
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_cross_entropy_loss(
@@ -2051,36 +2071,24 @@ extern "C" {
2051
2071
  struct wsp_ggml_context * ctx,
2052
2072
  struct wsp_ggml_tensor * a,
2053
2073
  struct wsp_ggml_tensor * grad,
2054
- float alpha,
2055
- float beta1,
2056
- float beta2,
2057
- float eps,
2058
- float wd); // weight decay
2074
+ struct wsp_ggml_tensor * m,
2075
+ struct wsp_ggml_tensor * v,
2076
+ struct wsp_ggml_tensor * adamw_params); // parameters such a the learning rate
2059
2077
 
2060
2078
  //
2061
2079
  // automatic differentiation
2062
2080
  //
2063
2081
 
2064
- WSP_GGML_API void wsp_ggml_set_param(struct wsp_ggml_context * ctx, struct wsp_ggml_tensor * tensor);
2065
- WSP_GGML_API void wsp_ggml_set_loss(struct wsp_ggml_tensor * tensor);
2066
-
2067
- WSP_GGML_API void wsp_ggml_build_forward_expand (struct wsp_ggml_cgraph * cgraph, struct wsp_ggml_tensor * tensor);
2068
- WSP_GGML_API void wsp_ggml_build_backward_expand(struct wsp_ggml_context * ctx, struct wsp_ggml_cgraph * gf, struct wsp_ggml_cgraph * gb, bool accumulate);
2069
-
2070
- WSP_GGML_API void wsp_ggml_build_opt_adamw(
2071
- struct wsp_ggml_context * ctx,
2072
- struct wsp_ggml_cgraph * gf,
2073
- struct wsp_ggml_cgraph * gb,
2074
- float alpha,
2075
- float beta1,
2076
- float beta2,
2077
- float eps,
2078
- float wd); // weight decay
2082
+ WSP_GGML_API void wsp_ggml_build_forward_expand(struct wsp_ggml_cgraph * cgraph, struct wsp_ggml_tensor * tensor);
2083
+ WSP_GGML_API void wsp_ggml_build_backward_expand(
2084
+ struct wsp_ggml_context * ctx, // context for gradient computation
2085
+ struct wsp_ggml_cgraph * cgraph,
2086
+ struct wsp_ggml_tensor ** grad_accs);
2079
2087
 
2080
2088
  // graph allocation in a context
2081
2089
  WSP_GGML_API struct wsp_ggml_cgraph * wsp_ggml_new_graph (struct wsp_ggml_context * ctx); // size = WSP_GGML_DEFAULT_GRAPH_SIZE, grads = false
2082
2090
  WSP_GGML_API struct wsp_ggml_cgraph * wsp_ggml_new_graph_custom(struct wsp_ggml_context * ctx, size_t size, bool grads);
2083
- WSP_GGML_API struct wsp_ggml_cgraph * wsp_ggml_graph_dup (struct wsp_ggml_context * ctx, struct wsp_ggml_cgraph * cgraph);
2091
+ WSP_GGML_API struct wsp_ggml_cgraph * wsp_ggml_graph_dup (struct wsp_ggml_context * ctx, struct wsp_ggml_cgraph * cgraph, bool force_grads);
2084
2092
  WSP_GGML_API void wsp_ggml_graph_cpy (struct wsp_ggml_cgraph * src, struct wsp_ggml_cgraph * dst);
2085
2093
  WSP_GGML_API void wsp_ggml_graph_reset (struct wsp_ggml_cgraph * cgraph); // set regular grads + optimizer momenta to 0, set loss grad to 1
2086
2094
  WSP_GGML_API void wsp_ggml_graph_clear (struct wsp_ggml_cgraph * cgraph);
@@ -2095,31 +2103,9 @@ extern "C" {
2095
2103
  WSP_GGML_API size_t wsp_ggml_graph_overhead(void);
2096
2104
  WSP_GGML_API size_t wsp_ggml_graph_overhead_custom(size_t size, bool grads);
2097
2105
 
2098
- WSP_GGML_API struct wsp_ggml_threadpool_params wsp_ggml_threadpool_params_default(int n_threads);
2099
- WSP_GGML_API void wsp_ggml_threadpool_params_init (struct wsp_ggml_threadpool_params * p, int n_threads);
2100
- WSP_GGML_API bool wsp_ggml_threadpool_params_match (const struct wsp_ggml_threadpool_params * p0, const struct wsp_ggml_threadpool_params * p1);
2101
- WSP_GGML_API struct wsp_ggml_threadpool * wsp_ggml_threadpool_new (struct wsp_ggml_threadpool_params * params);
2102
- WSP_GGML_API void wsp_ggml_threadpool_free (struct wsp_ggml_threadpool * threadpool);
2103
- WSP_GGML_API int wsp_ggml_threadpool_get_n_threads(struct wsp_ggml_threadpool * threadpool);
2104
- WSP_GGML_API void wsp_ggml_threadpool_pause (struct wsp_ggml_threadpool * threadpool);
2105
- WSP_GGML_API void wsp_ggml_threadpool_resume (struct wsp_ggml_threadpool * threadpool);
2106
-
2107
- // wsp_ggml_graph_plan() has to be called before wsp_ggml_graph_compute()
2108
- // when plan.work_size > 0, caller must allocate memory for plan.work_data
2109
- WSP_GGML_API struct wsp_ggml_cplan wsp_ggml_graph_plan(
2110
- const struct wsp_ggml_cgraph * cgraph,
2111
- int n_threads, /* = WSP_GGML_DEFAULT_N_THREADS */
2112
- struct wsp_ggml_threadpool * threadpool /* = NULL */ );
2113
- WSP_GGML_API enum wsp_ggml_status wsp_ggml_graph_compute(struct wsp_ggml_cgraph * cgraph, struct wsp_ggml_cplan * cplan);
2114
-
2115
- // same as wsp_ggml_graph_compute() but the work data is allocated as a part of the context
2116
- // note: the drawback of this API is that you must have ensured that the context has enough memory for the work data
2117
- WSP_GGML_API enum wsp_ggml_status wsp_ggml_graph_compute_with_ctx(struct wsp_ggml_context * ctx, struct wsp_ggml_cgraph * cgraph, int n_threads);
2118
-
2119
- WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_graph_get_tensor(struct wsp_ggml_cgraph * cgraph, const char * name);
2120
-
2121
- WSP_GGML_API void wsp_ggml_graph_export(const struct wsp_ggml_cgraph * cgraph, const char * fname);
2122
- WSP_GGML_API struct wsp_ggml_cgraph * wsp_ggml_graph_import(const char * fname, struct wsp_ggml_context ** ctx_data, struct wsp_ggml_context ** ctx_eval);
2106
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_graph_get_tensor (const struct wsp_ggml_cgraph * cgraph, const char * name);
2107
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_graph_get_grad (const struct wsp_ggml_cgraph * cgraph, const struct wsp_ggml_tensor * node);
2108
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_graph_get_grad_acc(const struct wsp_ggml_cgraph * cgraph, const struct wsp_ggml_tensor * node);
2123
2109
 
2124
2110
  // print info and performance information for the graph
2125
2111
  WSP_GGML_API void wsp_ggml_graph_print(const struct wsp_ggml_cgraph * cgraph);
@@ -2127,201 +2113,14 @@ extern "C" {
2127
2113
  // dump the graph into a file using the dot format
2128
2114
  WSP_GGML_API void wsp_ggml_graph_dump_dot(const struct wsp_ggml_cgraph * gb, const struct wsp_ggml_cgraph * gf, const char * filename);
2129
2115
 
2130
- // build gradient checkpointing backward graph gb for gf using provided checkpoints
2131
- // gb_tmp will contain original backward graph with rewritten backward process nodes,
2132
- // but without the second forward pass nodes.
2133
- WSP_GGML_API void wsp_ggml_build_backward_gradient_checkpointing(
2134
- struct wsp_ggml_context * ctx,
2135
- struct wsp_ggml_cgraph * gf,
2136
- struct wsp_ggml_cgraph * gb,
2137
- struct wsp_ggml_cgraph * gb_tmp,
2138
- struct wsp_ggml_tensor * * checkpoints,
2139
- int n_checkpoints);
2140
- //
2141
- // optimization
2142
- //
2143
-
2144
- // optimization methods
2145
- enum wsp_ggml_opt_type {
2146
- WSP_GGML_OPT_TYPE_ADAM,
2147
- WSP_GGML_OPT_TYPE_LBFGS,
2148
- };
2149
-
2150
- // linesearch methods
2151
- enum wsp_ggml_linesearch {
2152
- WSP_GGML_LINESEARCH_DEFAULT = 1,
2153
-
2154
- WSP_GGML_LINESEARCH_BACKTRACKING_ARMIJO = 0,
2155
- WSP_GGML_LINESEARCH_BACKTRACKING_WOLFE = 1,
2156
- WSP_GGML_LINESEARCH_BACKTRACKING_STRONG_WOLFE = 2,
2157
- };
2158
-
2159
- // optimization return values
2160
- enum wsp_ggml_opt_result {
2161
- WSP_GGML_OPT_RESULT_OK = 0,
2162
- WSP_GGML_OPT_RESULT_DID_NOT_CONVERGE,
2163
- WSP_GGML_OPT_RESULT_NO_CONTEXT,
2164
- WSP_GGML_OPT_RESULT_INVALID_WOLFE,
2165
- WSP_GGML_OPT_RESULT_FAIL,
2166
- WSP_GGML_OPT_RESULT_CANCEL,
2167
-
2168
- WSP_GGML_LINESEARCH_FAIL = -128,
2169
- WSP_GGML_LINESEARCH_MINIMUM_STEP,
2170
- WSP_GGML_LINESEARCH_MAXIMUM_STEP,
2171
- WSP_GGML_LINESEARCH_MAXIMUM_ITERATIONS,
2172
- WSP_GGML_LINESEARCH_INVALID_PARAMETERS,
2173
- };
2174
-
2175
- typedef void (*wsp_ggml_opt_callback)(void * data, int accum_step, float * sched, bool * cancel);
2116
+ // TODO these functions were sandwiched in the old optimization interface, is there a better place for them?
2176
2117
  typedef void (*wsp_ggml_log_callback)(enum wsp_ggml_log_level level, const char * text, void * user_data);
2177
2118
 
2178
2119
  // Set callback for all future logging events.
2179
2120
  // If this is not called, or NULL is supplied, everything is output on stderr.
2180
2121
  WSP_GGML_API void wsp_ggml_log_set(wsp_ggml_log_callback log_callback, void * user_data);
2181
2122
 
2182
- // optimization parameters
2183
- //
2184
- // see ggml.c (wsp_ggml_opt_default_params) for default values
2185
- //
2186
- struct wsp_ggml_opt_params {
2187
- enum wsp_ggml_opt_type type;
2188
-
2189
- size_t graph_size;
2190
-
2191
- int n_threads;
2192
-
2193
- // delta-based convergence test
2194
- //
2195
- // if past == 0 - disabled
2196
- // if past > 0:
2197
- // stop if |f(x) - f(x_past)| < delta * max(1, |f(x)|)
2198
- //
2199
- int past;
2200
- float delta;
2201
-
2202
- // maximum number of iterations without improvement
2203
- //
2204
- // if 0 - disabled
2205
- // if > 0:
2206
- // assume convergence if no cost improvement in this number of iterations
2207
- //
2208
- int max_no_improvement;
2209
-
2210
- bool print_forward_graph;
2211
- bool print_backward_graph;
2212
-
2213
- int n_gradient_accumulation;
2214
-
2215
- // ADAM parameters
2216
- struct {
2217
- int n_iter;
2218
-
2219
- float sched; // schedule multiplier (fixed, decay or warmup)
2220
- float decay; // weight decay for AdamW, use 0.0f to disable
2221
- int decay_min_ndim; // minimum number of tensor dimension to apply weight decay
2222
- float alpha; // learning rate
2223
- float beta1;
2224
- float beta2;
2225
- float eps; // epsilon for numerical stability
2226
- float eps_f; // epsilon for convergence test
2227
- float eps_g; // epsilon for convergence test
2228
- float gclip; // gradient clipping
2229
- } adam;
2230
-
2231
- // LBFGS parameters
2232
- struct {
2233
- int m; // number of corrections to approximate the inv. Hessian
2234
- int n_iter;
2235
- int max_linesearch;
2236
-
2237
- float eps; // convergence tolerance
2238
- float ftol; // line search tolerance
2239
- float wolfe;
2240
- float min_step;
2241
- float max_step;
2242
-
2243
- enum wsp_ggml_linesearch linesearch;
2244
- } lbfgs;
2245
- };
2246
-
2247
- struct wsp_ggml_opt_context {
2248
- struct wsp_ggml_context * ctx;
2249
- struct wsp_ggml_opt_params params;
2250
-
2251
- int iter;
2252
- int64_t nx; // number of parameter elements
2253
-
2254
- bool just_initialized;
2255
-
2256
- float loss_before;
2257
- float loss_after;
2258
-
2259
- struct {
2260
- struct wsp_ggml_tensor * g; // current gradient
2261
- struct wsp_ggml_tensor * m; // first moment
2262
- struct wsp_ggml_tensor * v; // second moment
2263
- struct wsp_ggml_tensor * pf; // past function values
2264
- float fx_best;
2265
- float fx_prev;
2266
- int n_no_improvement;
2267
- } adam;
2268
-
2269
- struct {
2270
- struct wsp_ggml_tensor * x; // current parameters
2271
- struct wsp_ggml_tensor * xp; // previous parameters
2272
- struct wsp_ggml_tensor * g; // current gradient
2273
- struct wsp_ggml_tensor * gp; // previous gradient
2274
- struct wsp_ggml_tensor * d; // search direction
2275
- struct wsp_ggml_tensor * pf; // past function values
2276
- struct wsp_ggml_tensor * lmal; // the L-BFGS memory alpha
2277
- struct wsp_ggml_tensor * lmys; // the L-BFGS memory ys
2278
- struct wsp_ggml_tensor * lms; // the L-BFGS memory s
2279
- struct wsp_ggml_tensor * lmy; // the L-BFGS memory y
2280
- float fx_best;
2281
- float step;
2282
- int j;
2283
- int k;
2284
- int end;
2285
- int n_no_improvement;
2286
- } lbfgs;
2287
- };
2288
-
2289
- WSP_GGML_API struct wsp_ggml_opt_params wsp_ggml_opt_default_params(enum wsp_ggml_opt_type type);
2290
-
2291
- // optimize the function defined by the tensor f
2292
- WSP_GGML_API enum wsp_ggml_opt_result wsp_ggml_opt(
2293
- struct wsp_ggml_context * ctx,
2294
- struct wsp_ggml_opt_params params,
2295
- struct wsp_ggml_tensor * f);
2296
-
2297
- // initialize optimizer context
2298
- WSP_GGML_API void wsp_ggml_opt_init(
2299
- struct wsp_ggml_context * ctx,
2300
- struct wsp_ggml_opt_context * opt,
2301
- struct wsp_ggml_opt_params params,
2302
- int64_t nx);
2303
-
2304
- // continue optimizing the function defined by the tensor f
2305
- WSP_GGML_API enum wsp_ggml_opt_result wsp_ggml_opt_resume(
2306
- struct wsp_ggml_context * ctx,
2307
- struct wsp_ggml_opt_context * opt,
2308
- struct wsp_ggml_tensor * f);
2309
-
2310
- // continue optimizing the function defined by the tensor f
2311
- WSP_GGML_API enum wsp_ggml_opt_result wsp_ggml_opt_resume_g(
2312
- struct wsp_ggml_context * ctx,
2313
- struct wsp_ggml_opt_context * opt,
2314
- struct wsp_ggml_tensor * f,
2315
- struct wsp_ggml_cgraph * gf,
2316
- struct wsp_ggml_cgraph * gb,
2317
- wsp_ggml_opt_callback callback,
2318
- void * callback_data);
2319
-
2320
- //
2321
- // tensor flags
2322
- //
2323
- WSP_GGML_API void wsp_ggml_set_input(struct wsp_ggml_tensor * tensor);
2324
- WSP_GGML_API void wsp_ggml_set_output(struct wsp_ggml_tensor * tensor);
2123
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_set_zero(struct wsp_ggml_tensor * tensor);
2325
2124
 
2326
2125
  //
2327
2126
  // quantization
@@ -2352,190 +2151,26 @@ extern "C" {
2352
2151
  int64_t n_per_row,
2353
2152
  const float * imatrix);
2354
2153
 
2355
- //
2356
- // gguf
2357
- //
2358
-
2359
- enum wsp_gguf_type {
2360
- WSP_GGUF_TYPE_UINT8 = 0,
2361
- WSP_GGUF_TYPE_INT8 = 1,
2362
- WSP_GGUF_TYPE_UINT16 = 2,
2363
- WSP_GGUF_TYPE_INT16 = 3,
2364
- WSP_GGUF_TYPE_UINT32 = 4,
2365
- WSP_GGUF_TYPE_INT32 = 5,
2366
- WSP_GGUF_TYPE_FLOAT32 = 6,
2367
- WSP_GGUF_TYPE_BOOL = 7,
2368
- WSP_GGUF_TYPE_STRING = 8,
2369
- WSP_GGUF_TYPE_ARRAY = 9,
2370
- WSP_GGUF_TYPE_UINT64 = 10,
2371
- WSP_GGUF_TYPE_INT64 = 11,
2372
- WSP_GGUF_TYPE_FLOAT64 = 12,
2373
- WSP_GGUF_TYPE_COUNT, // marks the end of the enum
2374
- };
2375
-
2376
- struct wsp_gguf_context;
2377
-
2378
- struct wsp_gguf_init_params {
2379
- bool no_alloc;
2380
-
2381
- // if not NULL, create a wsp_ggml_context and allocate the tensor data in it
2382
- struct wsp_ggml_context ** ctx;
2383
- };
2384
-
2385
- WSP_GGML_API struct wsp_gguf_context * wsp_gguf_init_empty(void);
2386
- WSP_GGML_API struct wsp_gguf_context * wsp_gguf_init_from_file(const char * fname, struct wsp_gguf_init_params params);
2387
- //WSP_GGML_API struct wsp_gguf_context * wsp_gguf_init_from_buffer(..);
2388
-
2389
- WSP_GGML_API void wsp_gguf_free(struct wsp_gguf_context * ctx);
2390
-
2391
- WSP_GGML_API const char * wsp_gguf_type_name(enum wsp_gguf_type type);
2392
-
2393
- WSP_GGML_API int wsp_gguf_get_version (const struct wsp_gguf_context * ctx);
2394
- WSP_GGML_API size_t wsp_gguf_get_alignment (const struct wsp_gguf_context * ctx);
2395
- WSP_GGML_API size_t wsp_gguf_get_data_offset(const struct wsp_gguf_context * ctx);
2396
- WSP_GGML_API void * wsp_gguf_get_data (const struct wsp_gguf_context * ctx);
2397
-
2398
- WSP_GGML_API int wsp_gguf_get_n_kv(const struct wsp_gguf_context * ctx);
2399
- WSP_GGML_API int wsp_gguf_find_key(const struct wsp_gguf_context * ctx, const char * key);
2400
- WSP_GGML_API const char * wsp_gguf_get_key (const struct wsp_gguf_context * ctx, int key_id);
2401
-
2402
- WSP_GGML_API enum wsp_gguf_type wsp_gguf_get_kv_type (const struct wsp_gguf_context * ctx, int key_id);
2403
- WSP_GGML_API enum wsp_gguf_type wsp_gguf_get_arr_type(const struct wsp_gguf_context * ctx, int key_id);
2404
-
2405
- // will abort if the wrong type is used for the key
2406
- WSP_GGML_API uint8_t wsp_gguf_get_val_u8 (const struct wsp_gguf_context * ctx, int key_id);
2407
- WSP_GGML_API int8_t wsp_gguf_get_val_i8 (const struct wsp_gguf_context * ctx, int key_id);
2408
- WSP_GGML_API uint16_t wsp_gguf_get_val_u16 (const struct wsp_gguf_context * ctx, int key_id);
2409
- WSP_GGML_API int16_t wsp_gguf_get_val_i16 (const struct wsp_gguf_context * ctx, int key_id);
2410
- WSP_GGML_API uint32_t wsp_gguf_get_val_u32 (const struct wsp_gguf_context * ctx, int key_id);
2411
- WSP_GGML_API int32_t wsp_gguf_get_val_i32 (const struct wsp_gguf_context * ctx, int key_id);
2412
- WSP_GGML_API float wsp_gguf_get_val_f32 (const struct wsp_gguf_context * ctx, int key_id);
2413
- WSP_GGML_API uint64_t wsp_gguf_get_val_u64 (const struct wsp_gguf_context * ctx, int key_id);
2414
- WSP_GGML_API int64_t wsp_gguf_get_val_i64 (const struct wsp_gguf_context * ctx, int key_id);
2415
- WSP_GGML_API double wsp_gguf_get_val_f64 (const struct wsp_gguf_context * ctx, int key_id);
2416
- WSP_GGML_API bool wsp_gguf_get_val_bool(const struct wsp_gguf_context * ctx, int key_id);
2417
- WSP_GGML_API const char * wsp_gguf_get_val_str (const struct wsp_gguf_context * ctx, int key_id);
2418
- WSP_GGML_API const void * wsp_gguf_get_val_data(const struct wsp_gguf_context * ctx, int key_id);
2419
- WSP_GGML_API int wsp_gguf_get_arr_n (const struct wsp_gguf_context * ctx, int key_id);
2420
- WSP_GGML_API const void * wsp_gguf_get_arr_data(const struct wsp_gguf_context * ctx, int key_id);
2421
- WSP_GGML_API const char * wsp_gguf_get_arr_str (const struct wsp_gguf_context * ctx, int key_id, int i);
2422
-
2423
- WSP_GGML_API int wsp_gguf_get_n_tensors (const struct wsp_gguf_context * ctx);
2424
- WSP_GGML_API int wsp_gguf_find_tensor (const struct wsp_gguf_context * ctx, const char * name);
2425
- WSP_GGML_API size_t wsp_gguf_get_tensor_offset(const struct wsp_gguf_context * ctx, int i);
2426
- WSP_GGML_API char * wsp_gguf_get_tensor_name (const struct wsp_gguf_context * ctx, int i);
2427
- WSP_GGML_API enum wsp_ggml_type wsp_gguf_get_tensor_type (const struct wsp_gguf_context * ctx, int i);
2428
-
2429
- // removes key if it exists
2430
- WSP_GGML_API void wsp_gguf_remove_key(struct wsp_gguf_context * ctx, const char * key);
2431
-
2432
- // overrides existing values or adds a new one
2433
- WSP_GGML_API void wsp_gguf_set_val_u8 (struct wsp_gguf_context * ctx, const char * key, uint8_t val);
2434
- WSP_GGML_API void wsp_gguf_set_val_i8 (struct wsp_gguf_context * ctx, const char * key, int8_t val);
2435
- WSP_GGML_API void wsp_gguf_set_val_u16 (struct wsp_gguf_context * ctx, const char * key, uint16_t val);
2436
- WSP_GGML_API void wsp_gguf_set_val_i16 (struct wsp_gguf_context * ctx, const char * key, int16_t val);
2437
- WSP_GGML_API void wsp_gguf_set_val_u32 (struct wsp_gguf_context * ctx, const char * key, uint32_t val);
2438
- WSP_GGML_API void wsp_gguf_set_val_i32 (struct wsp_gguf_context * ctx, const char * key, int32_t val);
2439
- WSP_GGML_API void wsp_gguf_set_val_f32 (struct wsp_gguf_context * ctx, const char * key, float val);
2440
- WSP_GGML_API void wsp_gguf_set_val_u64 (struct wsp_gguf_context * ctx, const char * key, uint64_t val);
2441
- WSP_GGML_API void wsp_gguf_set_val_i64 (struct wsp_gguf_context * ctx, const char * key, int64_t val);
2442
- WSP_GGML_API void wsp_gguf_set_val_f64 (struct wsp_gguf_context * ctx, const char * key, double val);
2443
- WSP_GGML_API void wsp_gguf_set_val_bool(struct wsp_gguf_context * ctx, const char * key, bool val);
2444
- WSP_GGML_API void wsp_gguf_set_val_str (struct wsp_gguf_context * ctx, const char * key, const char * val);
2445
- WSP_GGML_API void wsp_gguf_set_arr_data(struct wsp_gguf_context * ctx, const char * key, enum wsp_gguf_type type, const void * data, int n);
2446
- WSP_GGML_API void wsp_gguf_set_arr_str (struct wsp_gguf_context * ctx, const char * key, const char ** data, int n);
2447
-
2448
- // set or add KV pairs from another context
2449
- WSP_GGML_API void wsp_gguf_set_kv(struct wsp_gguf_context * ctx, struct wsp_gguf_context * src);
2450
-
2451
- // manage tensor info
2452
- WSP_GGML_API void wsp_gguf_add_tensor(struct wsp_gguf_context * ctx, const struct wsp_ggml_tensor * tensor);
2453
- WSP_GGML_API void wsp_gguf_set_tensor_type(struct wsp_gguf_context * ctx, const char * name, enum wsp_ggml_type type);
2454
- WSP_GGML_API void wsp_gguf_set_tensor_data(struct wsp_gguf_context * ctx, const char * name, const void * data, size_t size);
2455
-
2456
- // writing gguf files can be done in 2 ways:
2457
- //
2458
- // - write the entire wsp_gguf_context to a binary file in a single pass:
2459
- //
2460
- // wsp_gguf_write_to_file(ctx, fname);
2461
- //
2462
- // - first prepare a file with a placeholder for the meta data, write the tensor data, then write the meta data:
2463
- //
2464
- // FILE * f = fopen(fname, "wb");
2465
- // fseek(f, wsp_gguf_get_meta_size(ctx), SEEK_SET);
2466
- // fwrite(f, ...);
2467
- // void * data = wsp_gguf_meta_get_meta_data(ctx);
2468
- // fseek(f, 0, SEEK_SET);
2469
- // fwrite(f, data, wsp_gguf_get_meta_size(ctx));
2470
- // free(data);
2471
- // fclose(f);
2472
- //
2473
-
2474
- // write the entire context to a binary file
2475
- WSP_GGML_API void wsp_gguf_write_to_file(const struct wsp_gguf_context * ctx, const char * fname, bool only_meta);
2476
-
2477
- // get the size in bytes of the meta data (header, kv pairs, tensor info) including padding
2478
- WSP_GGML_API size_t wsp_gguf_get_meta_size(const struct wsp_gguf_context * ctx);
2479
- WSP_GGML_API void wsp_gguf_get_meta_data(const struct wsp_gguf_context * ctx, void * data);
2480
-
2481
- //
2482
- // system info
2483
- //
2484
-
2485
- WSP_GGML_API int wsp_ggml_cpu_has_avx (void);
2486
- WSP_GGML_API int wsp_ggml_cpu_has_avx_vnni (void);
2487
- WSP_GGML_API int wsp_ggml_cpu_has_avx2 (void);
2488
- WSP_GGML_API int wsp_ggml_cpu_has_avx512 (void);
2489
- WSP_GGML_API int wsp_ggml_cpu_has_avx512_vbmi(void);
2490
- WSP_GGML_API int wsp_ggml_cpu_has_avx512_vnni(void);
2491
- WSP_GGML_API int wsp_ggml_cpu_has_avx512_bf16(void);
2492
- WSP_GGML_API int wsp_ggml_cpu_has_amx_int8 (void);
2493
- WSP_GGML_API int wsp_ggml_cpu_has_fma (void);
2494
- WSP_GGML_API int wsp_ggml_cpu_has_neon (void);
2495
- WSP_GGML_API int wsp_ggml_cpu_has_sve (void);
2496
- WSP_GGML_API int wsp_ggml_cpu_has_arm_fma (void);
2497
- WSP_GGML_API int wsp_ggml_cpu_has_metal (void);
2498
- WSP_GGML_API int wsp_ggml_cpu_has_f16c (void);
2499
- WSP_GGML_API int wsp_ggml_cpu_has_fp16_va (void);
2500
- WSP_GGML_API int wsp_ggml_cpu_has_wasm_simd (void);
2501
- WSP_GGML_API int wsp_ggml_cpu_has_blas (void);
2502
- WSP_GGML_API int wsp_ggml_cpu_has_cuda (void);
2503
- WSP_GGML_API int wsp_ggml_cpu_has_vulkan (void);
2504
- WSP_GGML_API int wsp_ggml_cpu_has_kompute (void);
2505
- WSP_GGML_API int wsp_ggml_cpu_has_gpublas (void);
2506
- WSP_GGML_API int wsp_ggml_cpu_has_sse3 (void);
2507
- WSP_GGML_API int wsp_ggml_cpu_has_ssse3 (void);
2508
- WSP_GGML_API int wsp_ggml_cpu_has_riscv_v (void);
2509
- WSP_GGML_API int wsp_ggml_cpu_has_sycl (void);
2510
- WSP_GGML_API int wsp_ggml_cpu_has_rpc (void);
2511
- WSP_GGML_API int wsp_ggml_cpu_has_vsx (void);
2512
- WSP_GGML_API int wsp_ggml_cpu_has_matmul_int8(void);
2513
- WSP_GGML_API int wsp_ggml_cpu_has_cann (void);
2514
- WSP_GGML_API int wsp_ggml_cpu_has_llamafile (void);
2515
-
2516
- // get the sve vector length in bytes
2517
- WSP_GGML_API int wsp_ggml_cpu_get_sve_cnt(void);
2518
-
2519
- //
2520
- // Internal types and functions exposed for tests and benchmarks
2521
- //
2522
-
2523
- #ifdef __cplusplus
2524
- // restrict not standard in C++
2525
- #define WSP_GGML_RESTRICT
2154
+ #ifdef __cplusplus
2155
+ // restrict not standard in C++
2156
+ # if defined(__GNUC__)
2157
+ # define WSP_GGML_RESTRICT __restrict__
2158
+ # elif defined(__clang__)
2159
+ # define WSP_GGML_RESTRICT __restrict
2160
+ # elif defined(_MSC_VER)
2161
+ # define WSP_GGML_RESTRICT __restrict
2162
+ # else
2163
+ # define WSP_GGML_RESTRICT
2164
+ # endif
2526
2165
  #else
2527
- #define WSP_GGML_RESTRICT restrict
2166
+ # if defined (_MSC_VER) && (__STDC_VERSION__ < 201112L)
2167
+ # define WSP_GGML_RESTRICT __restrict
2168
+ # else
2169
+ # define WSP_GGML_RESTRICT restrict
2170
+ # endif
2528
2171
  #endif
2529
2172
  typedef void (*wsp_ggml_to_float_t) (const void * WSP_GGML_RESTRICT x, float * WSP_GGML_RESTRICT y, int64_t k);
2530
2173
  typedef void (*wsp_ggml_from_float_t)(const float * WSP_GGML_RESTRICT x, void * WSP_GGML_RESTRICT y, int64_t k);
2531
- typedef void (*wsp_ggml_from_float_to_mat_t)
2532
- (const float * WSP_GGML_RESTRICT x, void * WSP_GGML_RESTRICT y, int64_t nr, int64_t k, int64_t bs);
2533
- typedef void (*wsp_ggml_vec_dot_t) (int n, float * WSP_GGML_RESTRICT s, size_t bs, const void * WSP_GGML_RESTRICT x, size_t bx,
2534
- const void * WSP_GGML_RESTRICT y, size_t by, int nrc);
2535
- typedef void (*wsp_ggml_gemv_t) (int n, float * WSP_GGML_RESTRICT s, size_t bs, const void * WSP_GGML_RESTRICT x,
2536
- const void * WSP_GGML_RESTRICT y, int nr, int nc);
2537
- typedef void (*wsp_ggml_gemm_t) (int n, float * WSP_GGML_RESTRICT s, size_t bs, const void * WSP_GGML_RESTRICT x,
2538
- const void * WSP_GGML_RESTRICT y, int nr, int nc);
2539
2174
 
2540
2175
  struct wsp_ggml_type_traits {
2541
2176
  const char * type_name;
@@ -2544,19 +2179,43 @@ extern "C" {
2544
2179
  size_t type_size;
2545
2180
  bool is_quantized;
2546
2181
  wsp_ggml_to_float_t to_float;
2547
- wsp_ggml_from_float_t from_float;
2548
2182
  wsp_ggml_from_float_t from_float_ref;
2549
- wsp_ggml_from_float_to_mat_t from_float_to_mat;
2550
- wsp_ggml_vec_dot_t vec_dot;
2551
- enum wsp_ggml_type vec_dot_type;
2552
- int64_t nrows; // number of rows to process simultaneously
2553
- int64_t ncols; // number of columns to process simultaneously
2554
- wsp_ggml_gemv_t gemv;
2555
- wsp_ggml_gemm_t gemm;
2556
2183
  };
2557
2184
 
2558
2185
  WSP_GGML_API const struct wsp_ggml_type_traits * wsp_ggml_get_type_traits(enum wsp_ggml_type type);
2559
2186
 
2187
+ // ggml threadpool
2188
+ // TODO: currently, only a few functions are in the base ggml API, while the rest are in the CPU backend
2189
+ // the goal should be to create an API that other backends can use move everything to the ggml base
2190
+
2191
+ // scheduling priorities
2192
+ enum wsp_ggml_sched_priority {
2193
+ WSP_GGML_SCHED_PRIO_LOW = -1,
2194
+ WSP_GGML_SCHED_PRIO_NORMAL,
2195
+ WSP_GGML_SCHED_PRIO_MEDIUM,
2196
+ WSP_GGML_SCHED_PRIO_HIGH,
2197
+ WSP_GGML_SCHED_PRIO_REALTIME
2198
+ };
2199
+
2200
+ // threadpool params
2201
+ // Use wsp_ggml_threadpool_params_default() or wsp_ggml_threadpool_params_init() to populate the defaults
2202
+ struct wsp_ggml_threadpool_params {
2203
+ bool cpumask[WSP_GGML_MAX_N_THREADS]; // mask of cpu cores (all-zeros means use default affinity settings)
2204
+ int n_threads; // number of threads
2205
+ enum wsp_ggml_sched_priority prio; // thread priority
2206
+ uint32_t poll; // polling level (0 - no polling, 100 - aggressive polling)
2207
+ bool strict_cpu; // strict cpu placement
2208
+ bool paused; // start in paused state
2209
+ };
2210
+
2211
+ struct wsp_ggml_threadpool; // forward declaration, see ggml.c
2212
+
2213
+ typedef struct wsp_ggml_threadpool * wsp_ggml_threadpool_t;
2214
+
2215
+ WSP_GGML_API struct wsp_ggml_threadpool_params wsp_ggml_threadpool_params_default(int n_threads);
2216
+ WSP_GGML_API void wsp_ggml_threadpool_params_init (struct wsp_ggml_threadpool_params * p, int n_threads);
2217
+ WSP_GGML_API bool wsp_ggml_threadpool_params_match (const struct wsp_ggml_threadpool_params * p0, const struct wsp_ggml_threadpool_params * p1);
2218
+
2560
2219
  #ifdef __cplusplus
2561
2220
  }
2562
2221
  #endif