whisper.rn 0.4.0-rc.8 → 0.4.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (201)
  1. package/README.md +5 -1
  2. package/android/build.gradle +12 -3
  3. package/android/src/main/CMakeLists.txt +44 -13
  4. package/android/src/main/java/com/rnwhisper/AudioUtils.java +27 -12
  5. package/android/src/main/java/com/rnwhisper/RNWhisper.java +75 -34
  6. package/android/src/main/java/com/rnwhisper/WhisperContext.java +53 -38
  7. package/android/src/main/jni.cpp +38 -1
  8. package/android/src/main/jniLibs/arm64-v8a/librnwhisper.so +0 -0
  9. package/android/src/main/jniLibs/arm64-v8a/librnwhisper_v8fp16_va_2.so +0 -0
  10. package/android/src/main/jniLibs/armeabi-v7a/librnwhisper.so +0 -0
  11. package/android/src/main/jniLibs/armeabi-v7a/librnwhisper_vfpv4.so +0 -0
  12. package/android/src/main/jniLibs/x86_64/librnwhisper.so +0 -0
  13. package/android/src/main/jniLibs/x86_64/librnwhisper_x86_64.so +0 -0
  14. package/android/src/newarch/java/com/rnwhisper/RNWhisperModule.java +10 -0
  15. package/android/src/oldarch/java/com/rnwhisper/RNWhisperModule.java +10 -0
  16. package/cpp/coreml/whisper-compat.h +10 -0
  17. package/cpp/coreml/whisper-compat.m +35 -0
  18. package/cpp/coreml/whisper-decoder-impl.h +27 -15
  19. package/cpp/coreml/whisper-decoder-impl.m +36 -10
  20. package/cpp/coreml/whisper-encoder-impl.h +21 -9
  21. package/cpp/coreml/whisper-encoder-impl.m +29 -3
  22. package/cpp/ggml-alloc.c +727 -517
  23. package/cpp/ggml-alloc.h +47 -65
  24. package/cpp/ggml-backend-impl.h +196 -57
  25. package/cpp/ggml-backend-reg.cpp +591 -0
  26. package/cpp/ggml-backend.cpp +2016 -0
  27. package/cpp/ggml-backend.h +234 -89
  28. package/cpp/ggml-common.h +1861 -0
  29. package/cpp/ggml-cpp.h +39 -0
  30. package/cpp/ggml-cpu/amx/amx.cpp +221 -0
  31. package/cpp/ggml-cpu/amx/amx.h +8 -0
  32. package/cpp/ggml-cpu/amx/common.h +91 -0
  33. package/cpp/ggml-cpu/amx/mmq.cpp +2511 -0
  34. package/cpp/ggml-cpu/amx/mmq.h +10 -0
  35. package/cpp/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
  36. package/cpp/ggml-cpu/arch/arm/quants.c +4113 -0
  37. package/cpp/ggml-cpu/arch/arm/repack.cpp +2162 -0
  38. package/cpp/ggml-cpu/arch/x86/cpu-feats.cpp +327 -0
  39. package/cpp/ggml-cpu/arch/x86/quants.c +4310 -0
  40. package/cpp/ggml-cpu/arch/x86/repack.cpp +3284 -0
  41. package/cpp/ggml-cpu/arch-fallback.h +184 -0
  42. package/cpp/ggml-cpu/binary-ops.cpp +158 -0
  43. package/cpp/ggml-cpu/binary-ops.h +16 -0
  44. package/cpp/ggml-cpu/common.h +72 -0
  45. package/cpp/ggml-cpu/ggml-cpu-impl.h +511 -0
  46. package/cpp/ggml-cpu/ggml-cpu.c +3473 -0
  47. package/cpp/ggml-cpu/ggml-cpu.cpp +671 -0
  48. package/cpp/ggml-cpu/ops.cpp +9085 -0
  49. package/cpp/ggml-cpu/ops.h +111 -0
  50. package/cpp/ggml-cpu/quants.c +1157 -0
  51. package/cpp/ggml-cpu/quants.h +89 -0
  52. package/cpp/ggml-cpu/repack.cpp +1570 -0
  53. package/cpp/ggml-cpu/repack.h +98 -0
  54. package/cpp/ggml-cpu/simd-mappings.h +1006 -0
  55. package/cpp/ggml-cpu/traits.cpp +36 -0
  56. package/cpp/ggml-cpu/traits.h +38 -0
  57. package/cpp/ggml-cpu/unary-ops.cpp +186 -0
  58. package/cpp/ggml-cpu/unary-ops.h +28 -0
  59. package/cpp/ggml-cpu/vec.cpp +321 -0
  60. package/cpp/ggml-cpu/vec.h +973 -0
  61. package/cpp/ggml-cpu.h +143 -0
  62. package/cpp/ggml-impl.h +525 -168
  63. package/cpp/ggml-metal-impl.h +622 -0
  64. package/cpp/ggml-metal.h +16 -14
  65. package/cpp/ggml-metal.m +5289 -1859
  66. package/cpp/ggml-opt.cpp +1037 -0
  67. package/cpp/ggml-opt.h +237 -0
  68. package/cpp/ggml-quants.c +2916 -6877
  69. package/cpp/ggml-quants.h +87 -249
  70. package/cpp/ggml-threading.cpp +12 -0
  71. package/cpp/ggml-threading.h +14 -0
  72. package/cpp/ggml-whisper-sim.metallib +0 -0
  73. package/cpp/ggml-whisper.metallib +0 -0
  74. package/cpp/ggml.c +3293 -16770
  75. package/cpp/ggml.h +778 -835
  76. package/cpp/gguf.cpp +1347 -0
  77. package/cpp/gguf.h +202 -0
  78. package/cpp/rn-whisper.cpp +84 -0
  79. package/cpp/rn-whisper.h +2 -0
  80. package/cpp/whisper-arch.h +197 -0
  81. package/cpp/whisper.cpp +3240 -944
  82. package/cpp/whisper.h +144 -31
  83. package/ios/CMakeLists.txt +95 -0
  84. package/ios/RNWhisper.h +5 -0
  85. package/ios/RNWhisper.mm +124 -37
  86. package/ios/RNWhisperAudioUtils.h +1 -0
  87. package/ios/RNWhisperAudioUtils.m +24 -13
  88. package/ios/RNWhisperContext.h +8 -2
  89. package/ios/RNWhisperContext.mm +42 -8
  90. package/ios/rnwhisper.xcframework/Info.plist +74 -0
  91. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-alloc.h +76 -0
  92. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend-impl.h +255 -0
  93. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend.h +354 -0
  94. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-common.h +1861 -0
  95. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-cpp.h +39 -0
  96. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-cpu.h +143 -0
  97. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-impl.h +603 -0
  98. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-metal-impl.h +622 -0
  99. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-metal.h +66 -0
  100. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-opt.h +237 -0
  101. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-quants.h +100 -0
  102. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-threading.h +14 -0
  103. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml.h +2221 -0
  104. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/gguf.h +202 -0
  105. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/rn-audioutils.h +14 -0
  106. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/rn-whisper-log.h +11 -0
  107. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/rn-whisper.h +52 -0
  108. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/whisper-arch.h +197 -0
  109. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/whisper.h +739 -0
  110. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Info.plist +0 -0
  111. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
  112. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/rnwhisper +0 -0
  113. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-alloc.h +76 -0
  114. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend-impl.h +255 -0
  115. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +354 -0
  116. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-common.h +1861 -0
  117. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpp.h +39 -0
  118. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpu.h +143 -0
  119. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +603 -0
  120. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal-impl.h +622 -0
  121. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal.h +66 -0
  122. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-opt.h +237 -0
  123. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-quants.h +100 -0
  124. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-threading.h +14 -0
  125. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +2221 -0
  126. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/gguf.h +202 -0
  127. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-audioutils.h +14 -0
  128. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-whisper-log.h +11 -0
  129. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-whisper.h +52 -0
  130. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/whisper-arch.h +197 -0
  131. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/whisper.h +739 -0
  132. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
  133. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +101 -0
  134. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
  135. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
  136. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-alloc.h +76 -0
  137. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend-impl.h +255 -0
  138. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend.h +354 -0
  139. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-common.h +1861 -0
  140. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-cpp.h +39 -0
  141. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-cpu.h +143 -0
  142. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-impl.h +603 -0
  143. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-metal-impl.h +622 -0
  144. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-metal.h +66 -0
  145. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-opt.h +237 -0
  146. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-quants.h +100 -0
  147. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-threading.h +14 -0
  148. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml.h +2221 -0
  149. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/gguf.h +202 -0
  150. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/rn-audioutils.h +14 -0
  151. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/rn-whisper-log.h +11 -0
  152. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/rn-whisper.h +52 -0
  153. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/whisper-arch.h +197 -0
  154. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/whisper.h +739 -0
  155. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Info.plist +0 -0
  156. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
  157. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/rnwhisper +0 -0
  158. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-alloc.h +76 -0
  159. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend-impl.h +255 -0
  160. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +354 -0
  161. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-common.h +1861 -0
  162. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpp.h +39 -0
  163. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpu.h +143 -0
  164. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +603 -0
  165. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal-impl.h +622 -0
  166. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal.h +66 -0
  167. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-opt.h +237 -0
  168. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-quants.h +100 -0
  169. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-threading.h +14 -0
  170. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +2221 -0
  171. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/gguf.h +202 -0
  172. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-audioutils.h +14 -0
  173. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-whisper-log.h +11 -0
  174. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-whisper.h +52 -0
  175. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/whisper-arch.h +197 -0
  176. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/whisper.h +739 -0
  177. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
  178. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +101 -0
  179. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
  180. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
  181. package/jest/mock.js +14 -1
  182. package/lib/commonjs/NativeRNWhisper.js.map +1 -1
  183. package/lib/commonjs/index.js +48 -19
  184. package/lib/commonjs/index.js.map +1 -1
  185. package/lib/commonjs/version.json +1 -1
  186. package/lib/module/NativeRNWhisper.js.map +1 -1
  187. package/lib/module/index.js +48 -19
  188. package/lib/module/index.js.map +1 -1
  189. package/lib/module/version.json +1 -1
  190. package/lib/typescript/NativeRNWhisper.d.ts +6 -3
  191. package/lib/typescript/NativeRNWhisper.d.ts.map +1 -1
  192. package/lib/typescript/index.d.ts +25 -3
  193. package/lib/typescript/index.d.ts.map +1 -1
  194. package/package.json +15 -10
  195. package/src/NativeRNWhisper.ts +12 -3
  196. package/src/index.ts +63 -24
  197. package/src/version.json +1 -1
  198. package/whisper-rn.podspec +18 -18
  199. package/cpp/README.md +0 -4
  200. package/cpp/ggml-backend.c +0 -1718
  201. package/cpp/ggml-metal-whisper.metal +0 -5820
package/cpp/ggml.h CHANGED
@@ -176,25 +176,15 @@
  #ifdef WSP_GGML_SHARED
  # if defined(_WIN32) && !defined(__MINGW32__)
  # ifdef WSP_GGML_BUILD
- # define WSP_GGML_API __declspec(dllexport)
+ # define WSP_GGML_API __declspec(dllexport) extern
  # else
- # define WSP_GGML_API __declspec(dllimport)
+ # define WSP_GGML_API __declspec(dllimport) extern
  # endif
  # else
- # define WSP_GGML_API __attribute__ ((visibility ("default")))
+ # define WSP_GGML_API __attribute__ ((visibility ("default"))) extern
  # endif
  #else
- # define WSP_GGML_API
- #endif
-
- #ifdef WSP_GGML_MULTIPLATFORM
- # if defined(_WIN32)
- # define WSP_GGML_CALL
- # else
- # define WSP_GGML_CALL __attribute__((__ms_abi__))
- # endif
- #else
- # define WSP_GGML_CALL
+ # define WSP_GGML_API extern
  #endif

  // TODO: support for clang
@@ -208,32 +198,36 @@

  #ifndef __GNUC__
  # define WSP_GGML_ATTRIBUTE_FORMAT(...)
- #elif defined(__MINGW32__)
+ #elif defined(__MINGW32__) && !defined(__clang__)
  # define WSP_GGML_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__)))
  #else
  # define WSP_GGML_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__)))
  #endif

- #include <stdint.h>
- #include <stddef.h>
  #include <stdbool.h>
+ #include <stddef.h>
+ #include <stdint.h>
+ #include <stdio.h>

  #define WSP_GGML_FILE_MAGIC 0x67676d6c // "ggml"
- #define WSP_GGML_FILE_VERSION 1
+ #define WSP_GGML_FILE_VERSION 2

  #define WSP_GGML_QNT_VERSION 2 // bump this on quantization format changes
  #define WSP_GGML_QNT_VERSION_FACTOR 1000 // do not change this

  #define WSP_GGML_MAX_DIMS 4
  #define WSP_GGML_MAX_PARAMS 2048
- #define WSP_GGML_MAX_CONTEXTS 64
  #define WSP_GGML_MAX_SRC 10
+ #define WSP_GGML_MAX_N_THREADS 512
+ #define WSP_GGML_MAX_OP_PARAMS 64
+
  #ifndef WSP_GGML_MAX_NAME
- #define WSP_GGML_MAX_NAME 64
+ # define WSP_GGML_MAX_NAME 64
  #endif
- #define WSP_GGML_MAX_OP_PARAMS 64
+
  #define WSP_GGML_DEFAULT_N_THREADS 4
  #define WSP_GGML_DEFAULT_GRAPH_SIZE 2048
+
  #if UINTPTR_MAX == 0xFFFFFFFF
  #define WSP_GGML_MEM_ALIGN 4
  #else
@@ -243,36 +237,35 @@
  #define WSP_GGML_EXIT_SUCCESS 0
  #define WSP_GGML_EXIT_ABORTED 1

- #define WSP_GGUF_MAGIC "GGUF"
-
- #define WSP_GGUF_VERSION 3
-
- #define WSP_GGUF_DEFAULT_ALIGNMENT 32
+ #define WSP_GGML_ROPE_TYPE_NEOX 2
+ #define WSP_GGML_ROPE_TYPE_MROPE 8
+ #define WSP_GGML_ROPE_TYPE_VISION 24

  #define WSP_GGML_UNUSED(x) (void)(x)

  #define WSP_GGML_PAD(x, n) (((x) + (n) - 1) & ~((n) - 1))

- #define WSP_GGML_ASSERT(x) \
- do { \
- if (!(x)) { \
- fflush(stdout); \
- fprintf(stderr, "WSP_GGML_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \
- wsp_ggml_print_backtrace(); \
- abort(); \
- } \
- } while (0)
-
  #ifndef NDEBUG
- #define WSP_GGML_UNREACHABLE() WSP_GGML_ASSERT(!"statement should not be reached")
+ # define WSP_GGML_UNREACHABLE() do { fprintf(stderr, "statement should be unreachable\n"); abort(); } while(0)
  #elif defined(__GNUC__)
- #define WSP_GGML_UNREACHABLE() __builtin_unreachable()
+ # define WSP_GGML_UNREACHABLE() __builtin_unreachable()
+ #elif defined(_MSC_VER)
+ # define WSP_GGML_UNREACHABLE() __assume(0)
+ #else
+ # define WSP_GGML_UNREACHABLE() ((void) 0)
+ #endif
+
+ #ifdef __cplusplus
+ # define WSP_GGML_NORETURN [[noreturn]]
  #elif defined(_MSC_VER)
- #define WSP_GGML_UNREACHABLE() __assume(0)
+ # define WSP_GGML_NORETURN __declspec(noreturn)
  #else
- #define WSP_GGML_UNREACHABLE() ((void) 0)
+ # define WSP_GGML_NORETURN _Noreturn
  #endif

+ #define WSP_GGML_ABORT(...) wsp_ggml_abort(__FILE__, __LINE__, __VA_ARGS__)
+ #define WSP_GGML_ASSERT(x) if (!(x)) WSP_GGML_ABORT("WSP_GGML_ASSERT(%s) failed", #x)
+
  // used to copy the number of elements and stride in bytes of tensors into local variables.
  // main purpose is to reduce code duplication and improve readability.
  //
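
Note: WSP_GGML_ASSERT no longer expands to an inline fprintf/abort block; failures now route through the new wsp_ggml_abort(), which records file and line. A minimal caller-side sketch (hypothetical helper name; only the macros declared above are assumed):

    #include "ggml.h"   // whisper.rn's vendored, wsp_-prefixed ggml header

    static void must_be_positive(int64_t n) {
        WSP_GGML_ASSERT(n >= 0);   // aborts via wsp_ggml_abort() on failure
        if (n == 0) {
            // printf-style message; wsp_ggml_abort() does not return
            WSP_GGML_ABORT("n must be > 0, got %lld", (long long) n);
        }
    }
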
@@ -311,84 +304,125 @@
  WSP_GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) \
  WSP_GGML_TENSOR_LOCALS(size_t, nb, dst, nb)

+ #define WSP_GGML_TENSOR_BINARY_OP_LOCALS01 \
+ WSP_GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \
+ WSP_GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) \
+ WSP_GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne) \
+ WSP_GGML_TENSOR_LOCALS(size_t, nb1, src1, nb)
+
  #ifdef __cplusplus
  extern "C" {
  #endif

- #if defined(__ARM_NEON) && defined(__CUDACC__)
- typedef half wsp_ggml_fp16_t;
- #elif defined(__ARM_NEON) && !defined(_MSC_VER)
- typedef __fp16 wsp_ggml_fp16_t;
- #else
- typedef uint16_t wsp_ggml_fp16_t;
- #endif
+ WSP_GGML_NORETURN WSP_GGML_ATTRIBUTE_FORMAT(3, 4)
+ WSP_GGML_API void wsp_ggml_abort(const char * file, int line, const char * fmt, ...);
+
+ enum wsp_ggml_status {
+ WSP_GGML_STATUS_ALLOC_FAILED = -2,
+ WSP_GGML_STATUS_FAILED = -1,
+ WSP_GGML_STATUS_SUCCESS = 0,
+ WSP_GGML_STATUS_ABORTED = 1,
+ };

- // convert FP16 <-> FP32
- WSP_GGML_API float wsp_ggml_fp16_to_fp32(wsp_ggml_fp16_t x);
- WSP_GGML_API wsp_ggml_fp16_t wsp_ggml_fp32_to_fp16(float x);
+ // get wsp_ggml_status name string
+ WSP_GGML_API const char * wsp_ggml_status_to_string(enum wsp_ggml_status status);

- WSP_GGML_API void wsp_ggml_fp16_to_fp32_row(const wsp_ggml_fp16_t * x, float * y, int n);
- WSP_GGML_API void wsp_ggml_fp32_to_fp16_row(const float * x, wsp_ggml_fp16_t * y, int n);
+ // ieee 754-2008 half-precision float16
+ // todo: make this not an integral type
+ typedef uint16_t wsp_ggml_fp16_t;
+ WSP_GGML_API float wsp_ggml_fp16_to_fp32(wsp_ggml_fp16_t);
+ WSP_GGML_API wsp_ggml_fp16_t wsp_ggml_fp32_to_fp16(float);
+ WSP_GGML_API void wsp_ggml_fp16_to_fp32_row(const wsp_ggml_fp16_t *, float *, int64_t);
+ WSP_GGML_API void wsp_ggml_fp32_to_fp16_row(const float *, wsp_ggml_fp16_t *, int64_t);
+
+ // google brain half-precision bfloat16
+ typedef struct { uint16_t bits; } wsp_ggml_bf16_t;
+ WSP_GGML_API wsp_ggml_bf16_t wsp_ggml_fp32_to_bf16(float);
+ WSP_GGML_API float wsp_ggml_bf16_to_fp32(wsp_ggml_bf16_t); // consider just doing << 16
+ WSP_GGML_API void wsp_ggml_bf16_to_fp32_row(const wsp_ggml_bf16_t *, float *, int64_t);
+ WSP_GGML_API void wsp_ggml_fp32_to_bf16_row_ref(const float *, wsp_ggml_bf16_t *, int64_t);
+ WSP_GGML_API void wsp_ggml_fp32_to_bf16_row(const float *, wsp_ggml_bf16_t *, int64_t);

  struct wsp_ggml_object;
  struct wsp_ggml_context;
+ struct wsp_ggml_cgraph;

+ // NOTE: always add types at the end of the enum to keep backward compatibility
  enum wsp_ggml_type {
- WSP_GGML_TYPE_F32 = 0,
- WSP_GGML_TYPE_F16 = 1,
- WSP_GGML_TYPE_Q4_0 = 2,
- WSP_GGML_TYPE_Q4_1 = 3,
+ WSP_GGML_TYPE_F32 = 0,
+ WSP_GGML_TYPE_F16 = 1,
+ WSP_GGML_TYPE_Q4_0 = 2,
+ WSP_GGML_TYPE_Q4_1 = 3,
  // WSP_GGML_TYPE_Q4_2 = 4, support has been removed
- // WSP_GGML_TYPE_Q4_3 (5) support has been removed
- WSP_GGML_TYPE_Q5_0 = 6,
- WSP_GGML_TYPE_Q5_1 = 7,
- WSP_GGML_TYPE_Q8_0 = 8,
- WSP_GGML_TYPE_Q8_1 = 9,
- // k-quantizations
- WSP_GGML_TYPE_Q2_K = 10,
- WSP_GGML_TYPE_Q3_K = 11,
- WSP_GGML_TYPE_Q4_K = 12,
- WSP_GGML_TYPE_Q5_K = 13,
- WSP_GGML_TYPE_Q6_K = 14,
- WSP_GGML_TYPE_Q8_K = 15,
+ // WSP_GGML_TYPE_Q4_3 = 5, support has been removed
+ WSP_GGML_TYPE_Q5_0 = 6,
+ WSP_GGML_TYPE_Q5_1 = 7,
+ WSP_GGML_TYPE_Q8_0 = 8,
+ WSP_GGML_TYPE_Q8_1 = 9,
+ WSP_GGML_TYPE_Q2_K = 10,
+ WSP_GGML_TYPE_Q3_K = 11,
+ WSP_GGML_TYPE_Q4_K = 12,
+ WSP_GGML_TYPE_Q5_K = 13,
+ WSP_GGML_TYPE_Q6_K = 14,
+ WSP_GGML_TYPE_Q8_K = 15,
  WSP_GGML_TYPE_IQ2_XXS = 16,
  WSP_GGML_TYPE_IQ2_XS = 17,
- WSP_GGML_TYPE_I8,
- WSP_GGML_TYPE_I16,
- WSP_GGML_TYPE_I32,
- WSP_GGML_TYPE_COUNT,
+ WSP_GGML_TYPE_IQ3_XXS = 18,
+ WSP_GGML_TYPE_IQ1_S = 19,
+ WSP_GGML_TYPE_IQ4_NL = 20,
+ WSP_GGML_TYPE_IQ3_S = 21,
+ WSP_GGML_TYPE_IQ2_S = 22,
+ WSP_GGML_TYPE_IQ4_XS = 23,
+ WSP_GGML_TYPE_I8 = 24,
+ WSP_GGML_TYPE_I16 = 25,
+ WSP_GGML_TYPE_I32 = 26,
+ WSP_GGML_TYPE_I64 = 27,
+ WSP_GGML_TYPE_F64 = 28,
+ WSP_GGML_TYPE_IQ1_M = 29,
+ WSP_GGML_TYPE_BF16 = 30,
+ // WSP_GGML_TYPE_Q4_0_4_4 = 31, support has been removed from gguf files
+ // WSP_GGML_TYPE_Q4_0_4_8 = 32,
+ // WSP_GGML_TYPE_Q4_0_8_8 = 33,
+ WSP_GGML_TYPE_TQ1_0 = 34,
+ WSP_GGML_TYPE_TQ2_0 = 35,
+ // WSP_GGML_TYPE_IQ4_NL_4_4 = 36,
+ // WSP_GGML_TYPE_IQ4_NL_4_8 = 37,
+ // WSP_GGML_TYPE_IQ4_NL_8_8 = 38,
+ WSP_GGML_TYPE_COUNT = 39,
  };

  // precision
  enum wsp_ggml_prec {
- WSP_GGML_PREC_DEFAULT,
- WSP_GGML_PREC_F32,
- };
-
- enum wsp_ggml_backend_type {
- WSP_GGML_BACKEND_CPU = 0,
- WSP_GGML_BACKEND_GPU = 10,
- WSP_GGML_BACKEND_GPU_SPLIT = 20,
+ WSP_GGML_PREC_DEFAULT = 0, // stored as wsp_ggml_tensor.op_params, 0 by default
+ WSP_GGML_PREC_F32 = 10,
  };

  // model file types
  enum wsp_ggml_ftype {
- WSP_GGML_FTYPE_UNKNOWN = -1,
- WSP_GGML_FTYPE_ALL_F32 = 0,
- WSP_GGML_FTYPE_MOSTLY_F16 = 1, // except 1d tensors
- WSP_GGML_FTYPE_MOSTLY_Q4_0 = 2, // except 1d tensors
- WSP_GGML_FTYPE_MOSTLY_Q4_1 = 3, // except 1d tensors
+ WSP_GGML_FTYPE_UNKNOWN = -1,
+ WSP_GGML_FTYPE_ALL_F32 = 0,
+ WSP_GGML_FTYPE_MOSTLY_F16 = 1, // except 1d tensors
+ WSP_GGML_FTYPE_MOSTLY_Q4_0 = 2, // except 1d tensors
+ WSP_GGML_FTYPE_MOSTLY_Q4_1 = 3, // except 1d tensors
  WSP_GGML_FTYPE_MOSTLY_Q4_1_SOME_F16 = 4, // tok_embeddings.weight and output.weight are F16
- WSP_GGML_FTYPE_MOSTLY_Q8_0 = 7, // except 1d tensors
- WSP_GGML_FTYPE_MOSTLY_Q5_0 = 8, // except 1d tensors
- WSP_GGML_FTYPE_MOSTLY_Q5_1 = 9, // except 1d tensors
- WSP_GGML_FTYPE_MOSTLY_Q2_K = 10, // except 1d tensors
- WSP_GGML_FTYPE_MOSTLY_Q3_K = 11, // except 1d tensors
- WSP_GGML_FTYPE_MOSTLY_Q4_K = 12, // except 1d tensors
- WSP_GGML_FTYPE_MOSTLY_Q5_K = 13, // except 1d tensors
- WSP_GGML_FTYPE_MOSTLY_Q6_K = 14, // except 1d tensors
+ WSP_GGML_FTYPE_MOSTLY_Q8_0 = 7, // except 1d tensors
+ WSP_GGML_FTYPE_MOSTLY_Q5_0 = 8, // except 1d tensors
+ WSP_GGML_FTYPE_MOSTLY_Q5_1 = 9, // except 1d tensors
+ WSP_GGML_FTYPE_MOSTLY_Q2_K = 10, // except 1d tensors
+ WSP_GGML_FTYPE_MOSTLY_Q3_K = 11, // except 1d tensors
+ WSP_GGML_FTYPE_MOSTLY_Q4_K = 12, // except 1d tensors
+ WSP_GGML_FTYPE_MOSTLY_Q5_K = 13, // except 1d tensors
+ WSP_GGML_FTYPE_MOSTLY_Q6_K = 14, // except 1d tensors
  WSP_GGML_FTYPE_MOSTLY_IQ2_XXS = 15, // except 1d tensors
  WSP_GGML_FTYPE_MOSTLY_IQ2_XS = 16, // except 1d tensors
+ WSP_GGML_FTYPE_MOSTLY_IQ3_XXS = 17, // except 1d tensors
+ WSP_GGML_FTYPE_MOSTLY_IQ1_S = 18, // except 1d tensors
+ WSP_GGML_FTYPE_MOSTLY_IQ4_NL = 19, // except 1d tensors
+ WSP_GGML_FTYPE_MOSTLY_IQ3_S = 20, // except 1d tensors
+ WSP_GGML_FTYPE_MOSTLY_IQ2_S = 21, // except 1d tensors
+ WSP_GGML_FTYPE_MOSTLY_IQ4_XS = 22, // except 1d tensors
+ WSP_GGML_FTYPE_MOSTLY_IQ1_M = 23, // except 1d tensors
+ WSP_GGML_FTYPE_MOSTLY_BF16 = 24, // except 1d tensors
  };

  // available tensor operations:
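
Note: the fp16 row converters now take int64_t element counts, and bfloat16 support is new in this release. A minimal sketch of a lossy fp32 -> fp16 -> fp32 round trip (hypothetical function, assuming only the declarations above):

    #include "ggml.h"

    void fp16_roundtrip(const float * src, float * dst, int64_t n) {
        wsp_ggml_fp16_t half[64];
        WSP_GGML_ASSERT(n <= 64);                 // keep the scratch buffer simple
        wsp_ggml_fp32_to_fp16_row(src, half, n);  // fp32 -> fp16 (lossy)
        wsp_ggml_fp16_to_fp32_row(half, dst, n);  // fp16 -> fp32
    }
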
@@ -405,10 +439,13 @@ extern "C" {
  WSP_GGML_OP_SQR,
  WSP_GGML_OP_SQRT,
  WSP_GGML_OP_LOG,
+ WSP_GGML_OP_SIN,
+ WSP_GGML_OP_COS,
  WSP_GGML_OP_SUM,
  WSP_GGML_OP_SUM_ROWS,
  WSP_GGML_OP_MEAN,
  WSP_GGML_OP_ARGMAX,
+ WSP_GGML_OP_COUNT_EQUAL,
  WSP_GGML_OP_REPEAT,
  WSP_GGML_OP_REPEAT_BACK,
  WSP_GGML_OP_CONCAT,
@@ -417,6 +454,7 @@ extern "C" {
  WSP_GGML_OP_RMS_NORM,
  WSP_GGML_OP_RMS_NORM_BACK,
  WSP_GGML_OP_GROUP_NORM,
+ WSP_GGML_OP_L2_NORM,

  WSP_GGML_OP_MUL_MAT,
  WSP_GGML_OP_MUL_MAT_ID,
@@ -439,41 +477,47 @@ extern "C" {
  WSP_GGML_OP_SOFT_MAX_BACK,
  WSP_GGML_OP_ROPE,
  WSP_GGML_OP_ROPE_BACK,
- WSP_GGML_OP_ALIBI,
  WSP_GGML_OP_CLAMP,
  WSP_GGML_OP_CONV_TRANSPOSE_1D,
  WSP_GGML_OP_IM2COL,
+ WSP_GGML_OP_IM2COL_BACK,
+ WSP_GGML_OP_CONV_2D_DW,
  WSP_GGML_OP_CONV_TRANSPOSE_2D,
  WSP_GGML_OP_POOL_1D,
  WSP_GGML_OP_POOL_2D,
+ WSP_GGML_OP_POOL_2D_BACK,
  WSP_GGML_OP_UPSCALE, // nearest interpolate
  WSP_GGML_OP_PAD,
+ WSP_GGML_OP_PAD_REFLECT_1D,
+ WSP_GGML_OP_ROLL,
+ WSP_GGML_OP_ARANGE,
+ WSP_GGML_OP_TIMESTEP_EMBEDDING,
  WSP_GGML_OP_ARGSORT,
  WSP_GGML_OP_LEAKY_RELU,

- WSP_GGML_OP_FLASH_ATTN,
- WSP_GGML_OP_FLASH_FF,
+ WSP_GGML_OP_FLASH_ATTN_EXT,
  WSP_GGML_OP_FLASH_ATTN_BACK,
+ WSP_GGML_OP_SSM_CONV,
+ WSP_GGML_OP_SSM_SCAN,
  WSP_GGML_OP_WIN_PART,
  WSP_GGML_OP_WIN_UNPART,
  WSP_GGML_OP_GET_REL_POS,
  WSP_GGML_OP_ADD_REL_POS,
+ WSP_GGML_OP_RWKV_WKV6,
+ WSP_GGML_OP_GATED_LINEAR_ATTN,
+ WSP_GGML_OP_RWKV_WKV7,

  WSP_GGML_OP_UNARY,

- WSP_GGML_OP_MAP_UNARY,
- WSP_GGML_OP_MAP_BINARY,
-
- WSP_GGML_OP_MAP_CUSTOM1_F32,
- WSP_GGML_OP_MAP_CUSTOM2_F32,
- WSP_GGML_OP_MAP_CUSTOM3_F32,
-
  WSP_GGML_OP_MAP_CUSTOM1,
  WSP_GGML_OP_MAP_CUSTOM2,
  WSP_GGML_OP_MAP_CUSTOM3,

+ WSP_GGML_OP_CUSTOM,
+
  WSP_GGML_OP_CROSS_ENTROPY_LOSS,
  WSP_GGML_OP_CROSS_ENTROPY_LOSS_BACK,
+ WSP_GGML_OP_OPT_STEP_ADAMW,

  WSP_GGML_OP_COUNT,
  };
@@ -486,44 +530,51 @@ extern "C" {
  WSP_GGML_UNARY_OP_TANH,
  WSP_GGML_UNARY_OP_ELU,
  WSP_GGML_UNARY_OP_RELU,
+ WSP_GGML_UNARY_OP_SIGMOID,
  WSP_GGML_UNARY_OP_GELU,
  WSP_GGML_UNARY_OP_GELU_QUICK,
  WSP_GGML_UNARY_OP_SILU,
+ WSP_GGML_UNARY_OP_HARDSWISH,
+ WSP_GGML_UNARY_OP_HARDSIGMOID,
+ WSP_GGML_UNARY_OP_EXP,
+ WSP_GGML_UNARY_OP_GELU_ERF,

  WSP_GGML_UNARY_OP_COUNT,
  };

  enum wsp_ggml_object_type {
- WSP_GGML_OBJECT_TENSOR,
- WSP_GGML_OBJECT_GRAPH,
- WSP_GGML_OBJECT_WORK_BUFFER
+ WSP_GGML_OBJECT_TYPE_TENSOR,
+ WSP_GGML_OBJECT_TYPE_GRAPH,
+ WSP_GGML_OBJECT_TYPE_WORK_BUFFER
  };

  enum wsp_ggml_log_level {
- WSP_GGML_LOG_LEVEL_ERROR = 2,
- WSP_GGML_LOG_LEVEL_WARN = 3,
- WSP_GGML_LOG_LEVEL_INFO = 4,
- WSP_GGML_LOG_LEVEL_DEBUG = 5
+ WSP_GGML_LOG_LEVEL_NONE = 0,
+ WSP_GGML_LOG_LEVEL_DEBUG = 1,
+ WSP_GGML_LOG_LEVEL_INFO = 2,
+ WSP_GGML_LOG_LEVEL_WARN = 3,
+ WSP_GGML_LOG_LEVEL_ERROR = 4,
+ WSP_GGML_LOG_LEVEL_CONT = 5, // continue previous log
  };

- // ggml object
- struct wsp_ggml_object {
- size_t offs;
- size_t size;
-
- struct wsp_ggml_object * next;
-
- enum wsp_ggml_object_type type;
-
- char padding[4];
+ // this tensor...
+ enum wsp_ggml_tensor_flag {
+ WSP_GGML_TENSOR_FLAG_INPUT = 1, // ...is an input for the GGML compute graph
+ WSP_GGML_TENSOR_FLAG_OUTPUT = 2, // ...is an output for the GGML compute graph
+ WSP_GGML_TENSOR_FLAG_PARAM = 4, // ...contains trainable parameters
+ WSP_GGML_TENSOR_FLAG_LOSS = 8, // ...defines loss for numerical optimization (multiple loss tensors add up)
  };

- static const size_t WSP_GGML_OBJECT_SIZE = sizeof(struct wsp_ggml_object);
+ struct wsp_ggml_init_params {
+ // memory pool
+ size_t mem_size; // bytes
+ void * mem_buffer; // if NULL, memory will be allocated internally
+ bool no_alloc; // don't allocate memory for the tensor data
+ };

  // n-dimensional tensor
  struct wsp_ggml_tensor {
- enum wsp_ggml_type type;
- enum wsp_ggml_backend_type backend;
+ enum wsp_ggml_type type;

  struct wsp_ggml_backend_buffer * buffer;

@@ -539,16 +590,11 @@ extern "C" {
  // op params - allocated as int32_t for alignment
  int32_t op_params[WSP_GGML_MAX_OP_PARAMS / sizeof(int32_t)];

- bool is_param;
+ int32_t flags;

- struct wsp_ggml_tensor * grad;
  struct wsp_ggml_tensor * src[WSP_GGML_MAX_SRC];

- // performance
- int perf_runs;
- int64_t perf_cycles;
- int64_t perf_time_us;
-
+ // source tensor and offset for views
  struct wsp_ggml_tensor * view_src;
  size_t view_offs;

@@ -563,85 +609,21 @@ extern "C" {

  static const size_t WSP_GGML_TENSOR_SIZE = sizeof(struct wsp_ggml_tensor);

- // the compute plan that needs to be prepared for wsp_ggml_graph_compute()
- // since https://github.com/ggerganov/ggml/issues/287
- struct wsp_ggml_cplan {
- size_t work_size; // size of work buffer, calculated by `wsp_ggml_graph_plan()`
- uint8_t * work_data; // work buffer, to be allocated by caller before calling to `wsp_ggml_graph_compute()`
+ // Abort callback
+ // If not NULL, called before ggml computation
+ // If it returns true, the computation is aborted
+ typedef bool (*wsp_ggml_abort_callback)(void * data);

- int n_threads;

- // abort wsp_ggml_graph_compute when true
- bool (*abort_callback)(void * data);
- void * abort_callback_data;
- };
-
- enum wsp_ggml_cgraph_eval_order {
- WSP_GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT = 0,
- WSP_GGML_CGRAPH_EVAL_ORDER_RIGHT_TO_LEFT,
- WSP_GGML_CGRAPH_EVAL_ORDER_COUNT
- };
-
- struct wsp_ggml_hash_set {
- size_t size;
- struct wsp_ggml_tensor ** keys;
- };
-
- // computation graph
- struct wsp_ggml_cgraph {
- int size;
- int n_nodes;
- int n_leafs;
-
- struct wsp_ggml_tensor ** nodes;
- struct wsp_ggml_tensor ** grads;
- struct wsp_ggml_tensor ** leafs;
-
- struct wsp_ggml_hash_set visited_hash_table;
-
- enum wsp_ggml_cgraph_eval_order order;
-
- // performance
- int perf_runs;
- int64_t perf_cycles;
- int64_t perf_time_us;
- };
-
- // scratch buffer
- struct wsp_ggml_scratch {
- size_t offs;
- size_t size;
- void * data;
- };
-
- struct wsp_ggml_init_params {
- // memory pool
- size_t mem_size; // bytes
- void * mem_buffer; // if NULL, memory will be allocated internally
- bool no_alloc; // don't allocate memory for the tensor data
- };
-
-
- // compute types
-
- // NOTE: the INIT or FINALIZE pass is not scheduled unless explicitly enabled.
- // This behavior was changed since https://github.com/ggerganov/llama.cpp/pull/1995.
- enum wsp_ggml_task_type {
- WSP_GGML_TASK_INIT = 0,
- WSP_GGML_TASK_COMPUTE,
- WSP_GGML_TASK_FINALIZE,
- };
-
- struct wsp_ggml_compute_params {
- enum wsp_ggml_task_type type;
+ //
+ // GUID
+ //

- // ith = thread index, nth = number of threads
- int ith, nth;
+ // GUID types
+ typedef uint8_t wsp_ggml_guid[16];
+ typedef wsp_ggml_guid * wsp_ggml_guid_t;

- // work buffer for all threads
- size_t wsize;
- void * wdata;
- };
+ WSP_GGML_API bool wsp_ggml_guid_matches(wsp_ggml_guid_t guid_a, wsp_ggml_guid_t guid_b);

  // misc
@@ -651,63 +633,78 @@ extern "C" {
  WSP_GGML_API int64_t wsp_ggml_cycles(void);
  WSP_GGML_API int64_t wsp_ggml_cycles_per_ms(void);

- WSP_GGML_API void wsp_ggml_print_backtrace(void);
-
- WSP_GGML_API void wsp_ggml_numa_init(void); // call once for better performance on NUMA systems
- WSP_GGML_API bool wsp_ggml_is_numa(void); // true if init detected that system has >1 NUMA node
+ // accepts a UTF-8 path, even on Windows
+ WSP_GGML_API FILE * wsp_ggml_fopen(const char * fname, const char * mode);

  WSP_GGML_API void wsp_ggml_print_object (const struct wsp_ggml_object * obj);
  WSP_GGML_API void wsp_ggml_print_objects(const struct wsp_ggml_context * ctx);

- WSP_GGML_API WSP_GGML_CALL int64_t wsp_ggml_nelements (const struct wsp_ggml_tensor * tensor);
- WSP_GGML_API WSP_GGML_CALL int64_t wsp_ggml_nrows (const struct wsp_ggml_tensor * tensor);
- WSP_GGML_API WSP_GGML_CALL size_t wsp_ggml_nbytes (const struct wsp_ggml_tensor * tensor);
- WSP_GGML_API size_t wsp_ggml_nbytes_pad (const struct wsp_ggml_tensor * tensor); // same as wsp_ggml_nbytes() but padded to WSP_GGML_MEM_ALIGN
+ WSP_GGML_API int64_t wsp_ggml_nelements (const struct wsp_ggml_tensor * tensor);
+ WSP_GGML_API int64_t wsp_ggml_nrows (const struct wsp_ggml_tensor * tensor);
+ WSP_GGML_API size_t wsp_ggml_nbytes (const struct wsp_ggml_tensor * tensor);
+ WSP_GGML_API size_t wsp_ggml_nbytes_pad(const struct wsp_ggml_tensor * tensor); // same as wsp_ggml_nbytes() but padded to WSP_GGML_MEM_ALIGN

- WSP_GGML_API WSP_GGML_CALL int wsp_ggml_blck_size(enum wsp_ggml_type type);
- WSP_GGML_API WSP_GGML_CALL size_t wsp_ggml_type_size(enum wsp_ggml_type type); // size in bytes for all elements in a block
- WSP_GGML_API WSP_GGML_CALL size_t wsp_ggml_row_size (enum wsp_ggml_type type, int64_t ne); // size in bytes for all elements in a row
+ WSP_GGML_API int64_t wsp_ggml_blck_size(enum wsp_ggml_type type);
+ WSP_GGML_API size_t wsp_ggml_type_size(enum wsp_ggml_type type); // size in bytes for all elements in a block
+ WSP_GGML_API size_t wsp_ggml_row_size (enum wsp_ggml_type type, int64_t ne); // size in bytes for all elements in a row

  WSP_GGML_DEPRECATED(
  WSP_GGML_API double wsp_ggml_type_sizef(enum wsp_ggml_type type), // wsp_ggml_type_size()/wsp_ggml_blck_size() as float
  "use wsp_ggml_row_size() instead");

- WSP_GGML_API WSP_GGML_CALL const char * wsp_ggml_type_name(enum wsp_ggml_type type);
- WSP_GGML_API WSP_GGML_CALL const char * wsp_ggml_op_name (enum wsp_ggml_op op);
- WSP_GGML_API const char * wsp_ggml_op_symbol(enum wsp_ggml_op op);
+ WSP_GGML_API const char * wsp_ggml_type_name(enum wsp_ggml_type type);
+ WSP_GGML_API const char * wsp_ggml_op_name (enum wsp_ggml_op op);
+ WSP_GGML_API const char * wsp_ggml_op_symbol(enum wsp_ggml_op op);

- WSP_GGML_API const char * wsp_ggml_unary_op_name(enum wsp_ggml_unary_op op);
- WSP_GGML_API WSP_GGML_CALL const char * wsp_ggml_op_desc(const struct wsp_ggml_tensor * t); // unary or op name
+ WSP_GGML_API const char * wsp_ggml_unary_op_name(enum wsp_ggml_unary_op op);
+ WSP_GGML_API const char * wsp_ggml_op_desc(const struct wsp_ggml_tensor * t); // unary or op name

- WSP_GGML_API WSP_GGML_CALL size_t wsp_ggml_element_size(const struct wsp_ggml_tensor * tensor);
+ WSP_GGML_API size_t wsp_ggml_element_size(const struct wsp_ggml_tensor * tensor);

- WSP_GGML_API WSP_GGML_CALL bool wsp_ggml_is_quantized(enum wsp_ggml_type type);
+ WSP_GGML_API bool wsp_ggml_is_quantized(enum wsp_ggml_type type);

  // TODO: temporary until model loading of ggml examples is refactored
  WSP_GGML_API enum wsp_ggml_type wsp_ggml_ftype_to_wsp_ggml_type(enum wsp_ggml_ftype ftype);

- WSP_GGML_API WSP_GGML_CALL bool wsp_ggml_is_transposed(const struct wsp_ggml_tensor * tensor);
- WSP_GGML_API WSP_GGML_CALL bool wsp_ggml_is_contiguous(const struct wsp_ggml_tensor * tensor);
- WSP_GGML_API WSP_GGML_CALL bool wsp_ggml_is_permuted (const struct wsp_ggml_tensor * tensor);
- WSP_GGML_API bool wsp_ggml_is_scalar (const struct wsp_ggml_tensor * tensor);
- WSP_GGML_API bool wsp_ggml_is_vector (const struct wsp_ggml_tensor * tensor);
- WSP_GGML_API bool wsp_ggml_is_matrix (const struct wsp_ggml_tensor * tensor);
- WSP_GGML_API bool wsp_ggml_is_3d (const struct wsp_ggml_tensor * tensor);
- WSP_GGML_API int wsp_ggml_n_dims (const struct wsp_ggml_tensor * tensor); // returns 1 for scalars
+ WSP_GGML_API bool wsp_ggml_is_transposed(const struct wsp_ggml_tensor * tensor);
+ WSP_GGML_API bool wsp_ggml_is_permuted (const struct wsp_ggml_tensor * tensor);
+ WSP_GGML_API bool wsp_ggml_is_empty (const struct wsp_ggml_tensor * tensor);
+ WSP_GGML_API bool wsp_ggml_is_scalar (const struct wsp_ggml_tensor * tensor);
+ WSP_GGML_API bool wsp_ggml_is_vector (const struct wsp_ggml_tensor * tensor);
+ WSP_GGML_API bool wsp_ggml_is_matrix (const struct wsp_ggml_tensor * tensor);
+ WSP_GGML_API bool wsp_ggml_is_3d (const struct wsp_ggml_tensor * tensor);
+ WSP_GGML_API int wsp_ggml_n_dims (const struct wsp_ggml_tensor * tensor); // returns 1 for scalars

- WSP_GGML_API bool wsp_ggml_are_same_shape(const struct wsp_ggml_tensor * t0, const struct wsp_ggml_tensor * t1);
+ // returns whether the tensor elements can be iterated over with a flattened index (no gaps, no permutation)
+ WSP_GGML_API bool wsp_ggml_is_contiguous (const struct wsp_ggml_tensor * tensor);
+ WSP_GGML_API bool wsp_ggml_is_contiguous_0(const struct wsp_ggml_tensor * tensor); // same as wsp_ggml_is_contiguous()
+ WSP_GGML_API bool wsp_ggml_is_contiguous_1(const struct wsp_ggml_tensor * tensor); // contiguous for dims >= 1
+ WSP_GGML_API bool wsp_ggml_is_contiguous_2(const struct wsp_ggml_tensor * tensor); // contiguous for dims >= 2
+
+ // returns whether the tensor elements are allocated as one contiguous block of memory (no gaps, but permutation ok)
+ WSP_GGML_API bool wsp_ggml_is_contiguously_allocated(const struct wsp_ggml_tensor * tensor);
+
+ // true for tensor that is stored in memory as CxWxHxN and has been permuted to WxHxCxN
+ WSP_GGML_API bool wsp_ggml_is_contiguous_channels(const struct wsp_ggml_tensor * tensor);
+
+ WSP_GGML_API bool wsp_ggml_are_same_shape (const struct wsp_ggml_tensor * t0, const struct wsp_ggml_tensor * t1);
+ WSP_GGML_API bool wsp_ggml_are_same_stride(const struct wsp_ggml_tensor * t0, const struct wsp_ggml_tensor * t1);
+
+ WSP_GGML_API bool wsp_ggml_can_repeat(const struct wsp_ggml_tensor * t0, const struct wsp_ggml_tensor * t1);

  // use this to compute the memory overhead of a tensor
  WSP_GGML_API size_t wsp_ggml_tensor_overhead(void);

+ WSP_GGML_API bool wsp_ggml_validate_row_data(enum wsp_ggml_type type, const void * data, size_t nbytes);
+
  // main

- WSP_GGML_API struct wsp_ggml_context * wsp_ggml_init(struct wsp_ggml_init_params params);
- WSP_GGML_API void wsp_ggml_free(struct wsp_ggml_context * ctx);
+ WSP_GGML_API struct wsp_ggml_context * wsp_ggml_init (struct wsp_ggml_init_params params);
+ WSP_GGML_API void wsp_ggml_reset(struct wsp_ggml_context * ctx);
+ WSP_GGML_API void wsp_ggml_free (struct wsp_ggml_context * ctx);

  WSP_GGML_API size_t wsp_ggml_used_mem(const struct wsp_ggml_context * ctx);

- WSP_GGML_API size_t wsp_ggml_set_scratch (struct wsp_ggml_context * ctx, struct wsp_ggml_scratch scratch);
  WSP_GGML_API bool wsp_ggml_get_no_alloc(struct wsp_ggml_context * ctx);
  WSP_GGML_API void wsp_ggml_set_no_alloc(struct wsp_ggml_context * ctx, bool no_alloc);

@@ -747,8 +744,7 @@ extern "C" {
  int64_t ne2,
  int64_t ne3);

- WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_new_i32(struct wsp_ggml_context * ctx, int32_t value);
- WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_new_f32(struct wsp_ggml_context * ctx, float value);
+ WSP_GGML_API void * wsp_ggml_new_buffer(struct wsp_ggml_context * ctx, size_t nbytes);

  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_dup_tensor (struct wsp_ggml_context * ctx, const struct wsp_ggml_tensor * src);
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_view_tensor(struct wsp_ggml_context * ctx, struct wsp_ggml_tensor * src);
@@ -758,35 +754,25 @@ extern "C" {
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_get_next_tensor (const struct wsp_ggml_context * ctx, struct wsp_ggml_tensor * tensor);
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_get_tensor(struct wsp_ggml_context * ctx, const char * name);

- WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_set_zero(struct wsp_ggml_tensor * tensor);
- WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_set_i32 (struct wsp_ggml_tensor * tensor, int32_t value);
- WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_set_f32 (struct wsp_ggml_tensor * tensor, float value);
-
  // Converts a flat index into coordinates
- WSP_GGML_API void wsp_ggml_unravel_index(const struct wsp_ggml_tensor * tensor, int64_t i, int64_t * i0, int64_t * i1, int64_t * i2, int64_t * i3);
-
- WSP_GGML_API int32_t wsp_ggml_get_i32_1d(const struct wsp_ggml_tensor * tensor, int i);
- WSP_GGML_API void wsp_ggml_set_i32_1d(const struct wsp_ggml_tensor * tensor, int i, int32_t value);
-
- WSP_GGML_API int32_t wsp_ggml_get_i32_nd(const struct wsp_ggml_tensor * tensor, int i0, int i1, int i2, int i3);
- WSP_GGML_API void wsp_ggml_set_i32_nd(const struct wsp_ggml_tensor * tensor, int i0, int i1, int i2, int i3, int32_t value);
+ WSP_GGML_API void wsp_ggml_unravel_index(const struct wsp_ggml_tensor * tensor, int64_t i, int64_t * i0, int64_t * i1, int64_t * i2, int64_t * i3);

- WSP_GGML_API float wsp_ggml_get_f32_1d(const struct wsp_ggml_tensor * tensor, int i);
- WSP_GGML_API void wsp_ggml_set_f32_1d(const struct wsp_ggml_tensor * tensor, int i, float value);
-
- WSP_GGML_API float wsp_ggml_get_f32_nd(const struct wsp_ggml_tensor * tensor, int i0, int i1, int i2, int i3);
- WSP_GGML_API void wsp_ggml_set_f32_nd(const struct wsp_ggml_tensor * tensor, int i0, int i1, int i2, int i3, float value);
+ WSP_GGML_API enum wsp_ggml_unary_op wsp_ggml_get_unary_op(const struct wsp_ggml_tensor * tensor);

  WSP_GGML_API void * wsp_ggml_get_data (const struct wsp_ggml_tensor * tensor);
  WSP_GGML_API float * wsp_ggml_get_data_f32(const struct wsp_ggml_tensor * tensor);

- WSP_GGML_API WSP_GGML_CALL enum wsp_ggml_unary_op wsp_ggml_get_unary_op(const struct wsp_ggml_tensor * tensor);
-
  WSP_GGML_API const char * wsp_ggml_get_name (const struct wsp_ggml_tensor * tensor);
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_set_name ( struct wsp_ggml_tensor * tensor, const char * name);
  WSP_GGML_ATTRIBUTE_FORMAT(2, 3)
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_format_name( struct wsp_ggml_tensor * tensor, const char * fmt, ...);

+ // Tensor flags
+ WSP_GGML_API void wsp_ggml_set_input(struct wsp_ggml_tensor * tensor);
+ WSP_GGML_API void wsp_ggml_set_output(struct wsp_ggml_tensor * tensor);
+ WSP_GGML_API void wsp_ggml_set_param(struct wsp_ggml_tensor * tensor);
+ WSP_GGML_API void wsp_ggml_set_loss(struct wsp_ggml_tensor * tensor);
+
  //
  // operations on tensors with backpropagation
  //
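
Note: the per-tensor bool is_param and grad fields are gone; graph roles are now expressed through the flag setters added above. A minimal sketch (hypothetical tensors built in some wsp_ggml_context):

    void mark_graph_io(struct wsp_ggml_tensor * inp, struct wsp_ggml_tensor * out) {
        wsp_ggml_set_input(inp);   // sets WSP_GGML_TENSOR_FLAG_INPUT
        wsp_ggml_set_output(out);  // sets WSP_GGML_TENSOR_FLAG_OUTPUT
    }
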
@@ -901,6 +887,22 @@ extern "C" {
  struct wsp_ggml_context * ctx,
  struct wsp_ggml_tensor * a);

+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_sin(
+ struct wsp_ggml_context * ctx,
+ struct wsp_ggml_tensor * a);
+
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_sin_inplace(
+ struct wsp_ggml_context * ctx,
+ struct wsp_ggml_tensor * a);
+
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_cos(
+ struct wsp_ggml_context * ctx,
+ struct wsp_ggml_tensor * a);
+
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_cos_inplace(
+ struct wsp_ggml_context * ctx,
+ struct wsp_ggml_tensor * a);
+
  // return scalar
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_sum(
  struct wsp_ggml_context * ctx,
@@ -921,6 +923,12 @@ extern "C" {
  struct wsp_ggml_context * ctx,
  struct wsp_ggml_tensor * a);

+ // count number of equal elements in a and b
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_count_equal(
+ struct wsp_ggml_context * ctx,
+ struct wsp_ggml_tensor * a,
+ struct wsp_ggml_tensor * b);
+
  // if a is the same shape as b, and a is not parameter, return a
  // otherwise, return a new tensor: repeat(a) to fit in b
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_repeat(
@@ -928,18 +936,28 @@ extern "C" {
  struct wsp_ggml_tensor * a,
  struct wsp_ggml_tensor * b);

+ // repeat a to the specified shape
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_repeat_4d(
+ struct wsp_ggml_context * ctx,
+ struct wsp_ggml_tensor * a,
+ int64_t ne0,
+ int64_t ne1,
+ int64_t ne2,
+ int64_t ne3);
+
  // sums repetitions in a into shape of b
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_repeat_back(
  struct wsp_ggml_context * ctx,
  struct wsp_ggml_tensor * a,
- struct wsp_ggml_tensor * b);
+ struct wsp_ggml_tensor * b); // sum up values that are adjacent in dims > 0 instead of repeated with same stride

- // concat a and b on dim 2
+ // concat a and b along dim
  // used in stable-diffusion
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_concat(
  struct wsp_ggml_context * ctx,
  struct wsp_ggml_tensor * a,
- struct wsp_ggml_tensor * b);
+ struct wsp_ggml_tensor * b,
+ int dim);

  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_abs(
  struct wsp_ggml_context * ctx,
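
Note: wsp_ggml_concat now takes the concatenation dimension explicitly instead of always using dim 2. A minimal migration sketch (hypothetical wrapper):

    struct wsp_ggml_tensor * concat_like_before(struct wsp_ggml_context * ctx,
                                                struct wsp_ggml_tensor * a,
                                                struct wsp_ggml_tensor * b) {
        // rc.8: wsp_ggml_concat(ctx, a, b) concatenated on dim 2 implicitly;
        // 0.4.0: pass dim = 2 to keep the old behavior
        return wsp_ggml_concat(ctx, a, b, 2);
    }
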
@@ -1001,6 +1019,14 @@ extern "C" {
  struct wsp_ggml_context * ctx,
  struct wsp_ggml_tensor * a);

+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_sigmoid(
+ struct wsp_ggml_context * ctx,
+ struct wsp_ggml_tensor * a);
+
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_sigmoid_inplace(
+ struct wsp_ggml_context * ctx,
+ struct wsp_ggml_tensor * a);
+
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_gelu(
  struct wsp_ggml_context * ctx,
  struct wsp_ggml_tensor * a);
@@ -1009,6 +1035,16 @@ extern "C" {
  struct wsp_ggml_context * ctx,
  struct wsp_ggml_tensor * a);

+ // GELU using erf (error function) when possible
+ // some backends may fallback to approximation based on Abramowitz and Stegun formula
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_gelu_erf(
+ struct wsp_ggml_context * ctx,
+ struct wsp_ggml_tensor * a);
+
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_gelu_erf_inplace(
+ struct wsp_ggml_context * ctx,
+ struct wsp_ggml_tensor * a);
+
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_gelu_quick(
  struct wsp_ggml_context * ctx,
  struct wsp_ggml_tensor * a);
@@ -1032,6 +1068,24 @@ extern "C" {
  struct wsp_ggml_tensor * a,
  struct wsp_ggml_tensor * b);

+ // hardswish(x) = x * relu6(x + 3) / 6
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_hardswish(
+ struct wsp_ggml_context * ctx,
+ struct wsp_ggml_tensor * a);
+
+ // hardsigmoid(x) = relu6(x + 3) / 6
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_hardsigmoid(
+ struct wsp_ggml_context * ctx,
+ struct wsp_ggml_tensor * a);
+
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_exp(
+ struct wsp_ggml_context * ctx,
+ struct wsp_ggml_tensor * a);
+
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_exp_inplace(
+ struct wsp_ggml_context * ctx,
+ struct wsp_ggml_tensor * a);
+
  // normalize along rows
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_norm(
  struct wsp_ggml_context * ctx,
@@ -1055,16 +1109,29 @@ extern "C" {

  // group normalize along ne0*ne1*n_groups
  // used in stable-diffusion
- // TODO: eps is hardcoded to 1e-6 for now
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_group_norm(
  struct wsp_ggml_context * ctx,
  struct wsp_ggml_tensor * a,
- int n_groups);
+ int n_groups,
+ float eps);

  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_group_norm_inplace(
  struct wsp_ggml_context * ctx,
  struct wsp_ggml_tensor * a,
- int n_groups);
+ int n_groups,
+ float eps);
+
+ // l2 normalize along rows
+ // used in rwkv v7
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_l2_norm(
+ struct wsp_ggml_context * ctx,
+ struct wsp_ggml_tensor * a,
+ float eps);
+
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_l2_norm_inplace(
+ struct wsp_ggml_context * ctx,
+ struct wsp_ggml_tensor * a,
+ float eps);

  // a - x
  // b - dy
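
Note: group_norm's epsilon is no longer hardcoded to 1e-6; it is now an argument. A minimal migration sketch (hypothetical wrapper; n_groups is whatever the caller already used):

    struct wsp_ggml_tensor * group_norm_like_before(struct wsp_ggml_context * ctx,
                                                    struct wsp_ggml_tensor * a,
                                                    int n_groups) {
        // 1e-6f matches the previously hardcoded eps
        return wsp_ggml_group_norm(ctx, a, n_groups, 1e-6f);
    }
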
@@ -1089,14 +1156,11 @@ extern "C" {
  enum wsp_ggml_prec prec);

  // indirect matrix multiplication
- // wsp_ggml_mul_mat_id(ctx, as, ids, id, b) ~= wsp_ggml_mul_mat(as[ids[id]], b)
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_mul_mat_id(
  struct wsp_ggml_context * ctx,
- struct wsp_ggml_tensor * const as[],
- int n_as,
- struct wsp_ggml_tensor * ids,
- int id,
- struct wsp_ggml_tensor * b);
+ struct wsp_ggml_tensor * as,
+ struct wsp_ggml_tensor * b,
+ struct wsp_ggml_tensor * ids);

  // A: m columns, n rows,
  // B: p columns, n rows,
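
Note: wsp_ggml_mul_mat_id no longer takes an array of expert tensors plus an id; the experts are a single stacked tensor and ids selects per row. A minimal sketch of the new call shape (hypothetical names):

    // as:  experts stacked into one tensor
    // ids: int32 tensor of expert indices
    struct wsp_ggml_tensor * expert_mul_mat(struct wsp_ggml_context * ctx,
                                            struct wsp_ggml_tensor * as,
                                            struct wsp_ggml_tensor * b,
                                            struct wsp_ggml_tensor * ids) {
        return wsp_ggml_mul_mat_id(ctx, as, b, ids); // note the argument order: as, b, ids
    }
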
@@ -1129,7 +1193,7 @@ extern "C" {
  size_t nb1,
  size_t nb2,
  size_t nb3,
- size_t offset);
+ size_t offset); // in bytes

  // b -> view(a,offset,nb1,nb2,3), return view(a)
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_set_inplace(
@@ -1139,19 +1203,19 @@ extern "C" {
  size_t nb1,
  size_t nb2,
  size_t nb3,
- size_t offset);
+ size_t offset); // in bytes

  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_set_1d(
  struct wsp_ggml_context * ctx,
  struct wsp_ggml_tensor * a,
  struct wsp_ggml_tensor * b,
- size_t offset);
+ size_t offset); // in bytes

  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_set_1d_inplace(
  struct wsp_ggml_context * ctx,
  struct wsp_ggml_tensor * a,
  struct wsp_ggml_tensor * b,
- size_t offset);
+ size_t offset); // in bytes

  // b -> view(a,offset,nb1,nb2,3), return modified a
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_set_2d(
@@ -1159,7 +1223,7 @@ extern "C" {
  struct wsp_ggml_tensor * a,
  struct wsp_ggml_tensor * b,
  size_t nb1,
- size_t offset);
+ size_t offset); // in bytes

  // b -> view(a,offset,nb1,nb2,3), return view(a)
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_set_2d_inplace(
@@ -1167,7 +1231,7 @@ extern "C" {
  struct wsp_ggml_tensor * a,
  struct wsp_ggml_tensor * b,
  size_t nb1,
- size_t offset);
+ size_t offset); // in bytes

  // a -> b, return view(b)
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_cpy(
@@ -1302,14 +1366,14 @@ extern "C" {
  // supports 3D: a->ne[2] == b->ne[1]
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_get_rows(
  struct wsp_ggml_context * ctx,
- struct wsp_ggml_tensor * a,
- struct wsp_ggml_tensor * b);
+ struct wsp_ggml_tensor * a, // data
+ struct wsp_ggml_tensor * b); // row indices

  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_get_rows_back(
  struct wsp_ggml_context * ctx,
- struct wsp_ggml_tensor * a,
- struct wsp_ggml_tensor * b,
- struct wsp_ggml_tensor * c);
+ struct wsp_ggml_tensor * a, // gradients of wsp_ggml_get_rows result
+ struct wsp_ggml_tensor * b, // row indices
+ struct wsp_ggml_tensor * c); // data for wsp_ggml_get_rows, only used for its shape

  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_diag(
  struct wsp_ggml_context * ctx,
@@ -1348,29 +1412,34 @@ extern "C" {
1348
1412
  struct wsp_ggml_context * ctx,
1349
1413
  struct wsp_ggml_tensor * a);
1350
1414
 
1351
- // fused soft_max(a*scale + mask)
1415
+ // fused soft_max(a*scale + mask*(ALiBi slope))
1352
1416
  // mask is optional
1417
+ // max_bias = 0.0f for no ALiBi
1353
1418
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_soft_max_ext(
1354
1419
  struct wsp_ggml_context * ctx,
1355
1420
  struct wsp_ggml_tensor * a,
1356
1421
  struct wsp_ggml_tensor * mask,
1357
- float scale);
1422
+ float scale,
1423
+ float max_bias);
1358
1424
 
1359
- WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_soft_max_back(
1425
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_soft_max_ext_back(
1360
1426
  struct wsp_ggml_context * ctx,
1361
1427
  struct wsp_ggml_tensor * a,
1362
- struct wsp_ggml_tensor * b);
1428
+ struct wsp_ggml_tensor * b,
1429
+ float scale,
1430
+ float max_bias);
1363
1431
 
1364
1432
  // in-place, returns view(a)
1365
- WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_soft_max_back_inplace(
1433
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_soft_max_ext_back_inplace(
1366
1434
  struct wsp_ggml_context * ctx,
1367
1435
  struct wsp_ggml_tensor * a,
1368
- struct wsp_ggml_tensor * b);
1436
+ struct wsp_ggml_tensor * b,
1437
+ float scale,
1438
+ float max_bias);
1369
1439
 
1370
1440
  // rotary position embedding
1371
- // if mode & 1 == 1, skip n_past elements (DEPRECATED)
1372
- // if mode & 2 == 1, GPT-NeoX style
1373
- // if mode & 4 == 1, ChatGLM style
1441
+ // if (mode & 1) - skip n_past elements (NOT SUPPORTED)
1442
+ // if (mode & WSP_GGML_ROPE_TYPE_NEOX) - GPT-NeoX style
1374
1443
  //
1375
1444
  // b is an int32 vector with size a->ne[2], it contains the positions
1376
1445
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_rope(
@@ -1378,8 +1447,7 @@ extern "C" {
1378
1447
  struct wsp_ggml_tensor * a,
1379
1448
  struct wsp_ggml_tensor * b,
1380
1449
  int n_dims,
1381
- int mode,
1382
- int n_ctx);
1450
+ int mode);
1383
1451
 
1384
1452
  // in-place, returns view(a)
1385
1453
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_rope_inplace(
@@ -1387,18 +1455,34 @@ extern "C" {
1387
1455
  struct wsp_ggml_tensor * a,
1388
1456
  struct wsp_ggml_tensor * b,
1389
1457
  int n_dims,
1390
- int mode,
1391
- int n_ctx);
1458
+ int mode);
1392
1459
 
1393
1460
  // custom RoPE
1394
- WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_rope_custom(
1461
+ // c is freq factors (e.g. phi3-128k), (optional)
1462
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_rope_ext(
1395
1463
  struct wsp_ggml_context * ctx,
1396
1464
  struct wsp_ggml_tensor * a,
1397
1465
  struct wsp_ggml_tensor * b,
1466
+ struct wsp_ggml_tensor * c,
1398
1467
  int n_dims,
1399
1468
  int mode,
1400
- int n_ctx,
1401
- int n_orig_ctx,
1469
+ int n_ctx_orig,
1470
+ float freq_base,
1471
+ float freq_scale,
1472
+ float ext_factor,
1473
+ float attn_factor,
1474
+ float beta_fast,
1475
+ float beta_slow);
1476
+
1477
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_rope_multi(
1478
+ struct wsp_ggml_context * ctx,
1479
+ struct wsp_ggml_tensor * a,
1480
+ struct wsp_ggml_tensor * b,
1481
+ struct wsp_ggml_tensor * c,
1482
+ int n_dims,
1483
+ int sections[4],
1484
+ int mode,
1485
+ int n_ctx_orig,
1402
1486
  float freq_base,
1403
1487
  float freq_scale,
1404
1488
  float ext_factor,
@@ -1407,14 +1491,14 @@ extern "C" {
1407
1491
  float beta_slow);
1408
1492
 
1409
1493
  // in-place, returns view(a)
1410
- WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_rope_custom_inplace(
1494
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_rope_ext_inplace(
1411
1495
  struct wsp_ggml_context * ctx,
1412
1496
  struct wsp_ggml_tensor * a,
1413
1497
  struct wsp_ggml_tensor * b,
1498
+ struct wsp_ggml_tensor * c,
1414
1499
  int n_dims,
1415
1500
  int mode,
1416
- int n_ctx,
1417
- int n_orig_ctx,
1501
+ int n_ctx_orig,
1418
1502
  float freq_base,
1419
1503
  float freq_scale,
1420
1504
  float ext_factor,
@@ -1422,46 +1506,73 @@ extern "C" {
1422
1506
  float beta_fast,
1423
1507
  float beta_slow);
1424
1508
 
1425
- // compute correction dims for YaRN RoPE scaling
1426
- WSP_GGML_CALL void wsp_ggml_rope_yarn_corr_dims(
1427
- int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]);
1509
+ WSP_GGML_DEPRECATED(WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_rope_custom(
1510
+ struct wsp_ggml_context * ctx,
1511
+ struct wsp_ggml_tensor * a,
1512
+ struct wsp_ggml_tensor * b,
1513
+ int n_dims,
1514
+ int mode,
1515
+ int n_ctx_orig,
1516
+ float freq_base,
1517
+ float freq_scale,
1518
+ float ext_factor,
1519
+ float attn_factor,
1520
+ float beta_fast,
1521
+ float beta_slow),
1522
+ "use wsp_ggml_rope_ext instead");
1428
1523
 
1429
- // xPos RoPE, in-place, returns view(a)
1430
- WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_rope_xpos_inplace(
1524
+ WSP_GGML_DEPRECATED(WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_rope_custom_inplace(
1431
1525
  struct wsp_ggml_context * ctx,
1432
1526
  struct wsp_ggml_tensor * a,
1433
1527
  struct wsp_ggml_tensor * b,
1434
1528
  int n_dims,
1435
- float base,
1436
- bool down);
1529
+ int mode,
1530
+ int n_ctx_orig,
1531
+ float freq_base,
1532
+ float freq_scale,
1533
+ float ext_factor,
1534
+ float attn_factor,
1535
+ float beta_fast,
1536
+ float beta_slow),
1537
+ "use wsp_ggml_rope_ext_inplace instead");
1538
+
1539
+ // compute correction dims for YaRN RoPE scaling
1540
+ WSP_GGML_API void wsp_ggml_rope_yarn_corr_dims(
1541
+ int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, float dims[2]);
1437
1542
 
1438
1543
  // rotary position embedding backward, i.e compute dx from dy
1439
1544
  // a - dy
1440
- WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_rope_back(
1545
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_rope_ext_back(
1441
1546
  struct wsp_ggml_context * ctx,
1442
- struct wsp_ggml_tensor * a,
1443
- struct wsp_ggml_tensor * b,
1547
+ struct wsp_ggml_tensor * a, // gradients of wsp_ggml_rope result
1548
+ struct wsp_ggml_tensor * b, // positions
1549
+ struct wsp_ggml_tensor * c, // freq factors
1444
1550
  int n_dims,
1445
1551
  int mode,
1446
- int n_ctx,
1447
- int n_orig_ctx,
1552
+ int n_ctx_orig,
1448
1553
  float freq_base,
1449
1554
  float freq_scale,
1450
1555
  float ext_factor,
1451
1556
  float attn_factor,
1452
1557
  float beta_fast,
1453
- float beta_slow,
1454
- float xpos_base,
1455
- bool xpos_down);
1558
+ float beta_slow);
1456
1559
 
1457
- // alibi position embedding
1458
- // in-place, returns view(a)
1459
- WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_alibi(
1560
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_rope_multi_back(
1460
1561
  struct wsp_ggml_context * ctx,
1461
1562
  struct wsp_ggml_tensor * a,
1462
- int n_past,
1463
- int n_head,
1464
- float bias_max);
1563
+ struct wsp_ggml_tensor * b,
1564
+ struct wsp_ggml_tensor * c,
1565
+ int n_dims,
1566
+ int sections[4],
1567
+ int mode,
1568
+ int n_ctx_orig,
1569
+ float freq_base,
1570
+ float freq_scale,
1571
+ float ext_factor,
1572
+ float attn_factor,
1573
+ float beta_fast,
1574
+ float beta_slow);
1575
+
1465
1576
 
1466
1577
  // clamp
1467
1578
  // in-place, returns view(a)
@@ -1471,22 +1582,38 @@ extern "C" {
1471
1582
  float min,
1472
1583
  float max);
1473
1584
 
1585
+ // im2col
1586
+ // converts data into a format that effectively results in a convolution when combined with matrix multiplication
1474
1587
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_im2col(
1475
1588
  struct wsp_ggml_context * ctx,
1476
- struct wsp_ggml_tensor * a,
1477
- struct wsp_ggml_tensor * b,
1478
- int s0,
1479
- int s1,
1480
- int p0,
1481
- int p1,
1482
- int d0,
1483
- int d1,
1484
- bool is_2D);
1589
+ struct wsp_ggml_tensor * a, // convolution kernel
1590
+ struct wsp_ggml_tensor * b, // data
1591
+ int s0, // stride dimension 0
1592
+ int s1, // stride dimension 1
1593
+ int p0, // padding dimension 0
1594
+ int p1, // padding dimension 1
1595
+ int d0, // dilation dimension 0
1596
+ int d1, // dilation dimension 1
1597
+ bool is_2D,
1598
+ enum wsp_ggml_type dst_type);
1599
+
1600
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_im2col_back(
1601
+ struct wsp_ggml_context * ctx,
1602
+ struct wsp_ggml_tensor * a, // convolution kernel
1603
+ struct wsp_ggml_tensor * b, // gradient of im2col output
1604
+ int64_t * ne, // shape of im2col input
1605
+ int s0, // stride dimension 0
1606
+ int s1, // stride dimension 1
1607
+ int p0, // padding dimension 0
1608
+ int p1, // padding dimension 1
1609
+ int d0, // dilation dimension 0
1610
+ int d1, // dilation dimension 1
1611
+ bool is_2D);
1485
1612
 
1486
1613
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_conv_1d(
1487
1614
  struct wsp_ggml_context * ctx,
1488
- struct wsp_ggml_tensor * a,
1489
- struct wsp_ggml_tensor * b,
1615
+ struct wsp_ggml_tensor * a, // convolution kernel
1616
+ struct wsp_ggml_tensor * b, // data
1490
1617
  int s0, // stride
1491
1618
  int p0, // padding
1492
1619
  int d0); // dilation
@@ -1495,30 +1622,46 @@ extern "C" {
1495
1622
  // alias for wsp_ggml_conv_1d(a, b, s, a->ne[0]/2, d)
1496
1623
  WSP_GGML_API struct wsp_ggml_tensor* wsp_ggml_conv_1d_ph(
1497
1624
  struct wsp_ggml_context * ctx,
1498
- struct wsp_ggml_tensor * a,
1499
- struct wsp_ggml_tensor * b,
1500
- int s,
1501
- int d);
1625
+ struct wsp_ggml_tensor * a, // convolution kernel
1626
+ struct wsp_ggml_tensor * b, // data
1627
+ int s, // stride
1628
+ int d); // dilation
1629
+
1630
+ // depthwise
1631
+ // TODO: this is very likely wrong for some cases! - needs more testing
1632
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_conv_1d_dw(
1633
+ struct wsp_ggml_context * ctx,
1634
+ struct wsp_ggml_tensor * a, // convolution kernel
1635
+ struct wsp_ggml_tensor * b, // data
1636
+ int s0, // stride
1637
+ int p0, // padding
1638
+ int d0); // dilation
1639
+
1640
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_conv_1d_dw_ph(
1641
+ struct wsp_ggml_context * ctx,
1642
+ struct wsp_ggml_tensor * a, // convolution kernel
1643
+ struct wsp_ggml_tensor * b, // data
1644
+ int s0, // stride
1645
+ int d0); // dilation
1502
1646
 
1503
1647
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_conv_transpose_1d(
1504
1648
  struct wsp_ggml_context * ctx,
1505
- struct wsp_ggml_tensor * a,
1506
- struct wsp_ggml_tensor * b,
1507
- int s0,
1508
- int p0,
1509
- int d0);
1649
+ struct wsp_ggml_tensor * a, // convolution kernel
1650
+ struct wsp_ggml_tensor * b, // data
1651
+ int s0, // stride
1652
+ int p0, // padding
1653
+ int d0); // dilation
1510
1654
 
1511
1655
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_conv_2d(
1512
1656
  struct wsp_ggml_context * ctx,
1513
- struct wsp_ggml_tensor * a,
1514
- struct wsp_ggml_tensor * b,
1515
- int s0,
1516
- int s1,
1517
- int p0,
1518
- int p1,
1519
- int d0,
1520
- int d1);
1521
-
1657
+ struct wsp_ggml_tensor * a, // convolution kernel
1658
+ struct wsp_ggml_tensor * b, // data
1659
+ int s0, // stride dimension 0
1660
+ int s1, // stride dimension 1
1661
+ int p0, // padding dimension 0
1662
+ int p1, // padding dimension 1
1663
+ int d0, // dilation dimension 0
1664
+ int d1); // dilation dimension 1
1522
1665
 
1523
1666
  // kernel size is a->ne[0] x a->ne[1]
1524
1667
  // stride is equal to kernel size
@@ -1546,6 +1689,34 @@ extern "C" {
1546
1689
  struct wsp_ggml_tensor * a,
1547
1690
  struct wsp_ggml_tensor * b);
1548
1691
 
1692
+ // depthwise (via im2col and mul_mat)
1693
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_conv_2d_dw(
1694
+ struct wsp_ggml_context * ctx,
1695
+ struct wsp_ggml_tensor * a, // convolution kernel
1696
+ struct wsp_ggml_tensor * b, // data
1697
+ int s0, // stride dimension 0
1698
+ int s1, // stride dimension 1
1699
+ int p0, // padding dimension 0
1700
+ int p1, // padding dimension 1
1701
+ int d0, // dilation dimension 0
1702
+ int d1); // dilation dimension 1
1703
+
1704
+ // Depthwise 2D convolution
1705
+ // may be faster than wsp_ggml_conv_2d_dw, but not available in all backends
1706
+ // a: KW KH 1 C convolution kernel
1707
+ // b: W H C N input data
1708
+ // res: W_out H_out C N
1709
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_conv_2d_dw_direct(
1710
+ struct wsp_ggml_context * ctx,
1711
+ struct wsp_ggml_tensor * a,
1712
+ struct wsp_ggml_tensor * b,
1713
+ int stride0,
1714
+ int stride1,
1715
+ int pad0,
1716
+ int pad1,
1717
+ int dilation0,
1718
+ int dilation1);
1719
+
1549
1720
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_conv_transpose_2d_p0(
1550
1721
  struct wsp_ggml_context * ctx,
1551
1722
  struct wsp_ggml_tensor * a,
@@ -1579,12 +1750,41 @@ extern "C" {
1579
1750
  float p0,
1580
1751
  float p1);
1581
1752
 
1582
- // nearest interpolate
1583
- // used in stable-diffusion
1753
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_pool_2d_back(
1754
+ struct wsp_ggml_context * ctx,
1755
+ struct wsp_ggml_tensor * a,
1756
+ struct wsp_ggml_tensor * af, // "a"/input used in forward pass
1757
+ enum wsp_ggml_op_pool op,
1758
+ int k0,
1759
+ int k1,
1760
+ int s0,
1761
+ int s1,
1762
+ float p0,
1763
+ float p1);
1764
+
1765
+ enum wsp_ggml_scale_mode {
1766
+ WSP_GGML_SCALE_MODE_NEAREST = 0,
1767
+ WSP_GGML_SCALE_MODE_BILINEAR = 1,
1768
+ };
1769
+
1770
+ // interpolate
1771
+ // multiplies ne0 and ne1 by scale factor
1584
1772
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_upscale(
1585
1773
  struct wsp_ggml_context * ctx,
1586
1774
  struct wsp_ggml_tensor * a,
1587
- int scale_factor);
1775
+ int scale_factor,
1776
+ enum wsp_ggml_scale_mode mode);
1777
+
1778
+ // interpolate
1779
+ // interpolate scale to specified dimensions
1780
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_upscale_ext(
1781
+ struct wsp_ggml_context * ctx,
1782
+ struct wsp_ggml_tensor * a,
1783
+ int ne0,
1784
+ int ne1,
1785
+ int ne2,
1786
+ int ne3,
1787
+ enum wsp_ggml_scale_mode mode);
1588
1788
 
1589
1789
  // pad each dimension with zeros: [x, ..., x] -> [x, ..., x, 0, ..., 0]
1590
1790
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_pad(
@@ -1595,10 +1795,37 @@ extern "C" {
1595
1795
  int p2,
1596
1796
  int p3);
1597
1797
 
1798
+ // pad each dimension with reflection: [a, b, c, d] -> [b, a, b, c, d, c]
1799
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_pad_reflect_1d(
1800
+ struct wsp_ggml_context * ctx,
1801
+ struct wsp_ggml_tensor * a,
1802
+ int p0,
1803
+ int p1);
1804
+
1805
+ // Move tensor elements by an offset given for each dimension. Elements that
1806
+ // are shifted beyond the last position are wrapped around to the beginning.
1807
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_roll(
1808
+ struct wsp_ggml_context * ctx,
1809
+ struct wsp_ggml_tensor * a,
1810
+ int shift0,
1811
+ int shift1,
1812
+ int shift2,
1813
+ int shift3);
1814
+
1815
+
1816
+ // Ref: https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/diffusionmodules/util.py#L151
1817
+ // timesteps: [N,]
1818
+ // return: [N, dim]
1819
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_timestep_embedding(
1820
+ struct wsp_ggml_context * ctx,
1821
+ struct wsp_ggml_tensor * timesteps,
1822
+ int dim,
1823
+ int max_period);
1824
+
1598
1825
  // sort rows
1599
1826
  enum wsp_ggml_sort_order {
1600
- WSP_GGML_SORT_ASC,
1601
- WSP_GGML_SORT_DESC,
1827
+ WSP_GGML_SORT_ORDER_ASC,
1828
+ WSP_GGML_SORT_ORDER_DESC,
1602
1829
  };
1603
1830
 
1604
1831
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_argsort(
@@ -1606,19 +1833,43 @@ extern "C" {
1606
1833
  struct wsp_ggml_tensor * a,
1607
1834
  enum wsp_ggml_sort_order order);
1608
1835
 
1836
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_arange(
1837
+ struct wsp_ggml_context * ctx,
1838
+ float start,
1839
+ float stop,
1840
+ float step);
1841
+
1609
1842
  // top k elements per row
1610
1843
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_top_k(
1611
1844
  struct wsp_ggml_context * ctx,
1612
1845
  struct wsp_ggml_tensor * a,
1613
1846
  int k);
1614
1847
 
1615
- WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_flash_attn(
1848
+ #define WSP_GGML_KQ_MASK_PAD 64
1849
+
1850
+ // q: [n_embd_k, n_batch, n_head, 1]
1851
+ // k: [n_embd_k, n_kv, n_head_kv, 1]
1852
+ // v: [n_embd_v, n_kv, n_head_kv, 1] !! not transposed !!
1853
+ // mask: [n_kv, n_batch_pad, 1, 1] !! n_batch_pad = WSP_GGML_PAD(n_batch, WSP_GGML_KQ_MASK_PAD) !!
1854
+ // res: [n_embd_v, n_head, n_batch, 1] !! permuted !!
1855
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_flash_attn_ext(
1616
1856
  struct wsp_ggml_context * ctx,
1617
1857
  struct wsp_ggml_tensor * q,
1618
1858
  struct wsp_ggml_tensor * k,
1619
1859
  struct wsp_ggml_tensor * v,
1620
- bool masked);
1860
+ struct wsp_ggml_tensor * mask,
1861
+ float scale,
1862
+ float max_bias,
1863
+ float logit_softcap);
1621
1864
 
1865
+ WSP_GGML_API void wsp_ggml_flash_attn_ext_set_prec(
1866
+ struct wsp_ggml_tensor * a,
1867
+ enum wsp_ggml_prec prec);
1868
+
1869
+ WSP_GGML_API enum wsp_ggml_prec wsp_ggml_flash_attn_ext_get_prec(
1870
+ const struct wsp_ggml_tensor * a);
1871
+
1872
+ // TODO: needs to be adapted to wsp_ggml_flash_attn_ext
1622
1873
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_flash_attn_back(
1623
1874
  struct wsp_ggml_context * ctx,
1624
1875
  struct wsp_ggml_tensor * q,
@@ -1627,13 +1878,19 @@ extern "C" {
1627
1878
  struct wsp_ggml_tensor * d,
1628
1879
  bool masked);
1629
1880
 
1630
- WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_flash_ff(
1881
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_ssm_conv(
1631
1882
  struct wsp_ggml_context * ctx,
1632
- struct wsp_ggml_tensor * a,
1633
- struct wsp_ggml_tensor * b0,
1634
- struct wsp_ggml_tensor * b1,
1635
- struct wsp_ggml_tensor * c0,
1636
- struct wsp_ggml_tensor * c1);
1883
+ struct wsp_ggml_tensor * sx,
1884
+ struct wsp_ggml_tensor * c);
1885
+
1886
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_ssm_scan(
1887
+ struct wsp_ggml_context * ctx,
1888
+ struct wsp_ggml_tensor * s,
1889
+ struct wsp_ggml_tensor * x,
1890
+ struct wsp_ggml_tensor * dt,
1891
+ struct wsp_ggml_tensor * A,
1892
+ struct wsp_ggml_tensor * B,
1893
+ struct wsp_ggml_tensor * C);
1637
1894
 
1638
1895
  // partition into non-overlapping windows with padding if needed
1639
1896
  // example:
@@ -1685,90 +1942,42 @@ extern "C" {
1685
1942
  struct wsp_ggml_tensor * pw,
1686
1943
  struct wsp_ggml_tensor * ph);
1687
1944
 
1688
- // custom operators
1945
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_rwkv_wkv6(
1946
+ struct wsp_ggml_context * ctx,
1947
+ struct wsp_ggml_tensor * k,
1948
+ struct wsp_ggml_tensor * v,
1949
+ struct wsp_ggml_tensor * r,
1950
+ struct wsp_ggml_tensor * tf,
1951
+ struct wsp_ggml_tensor * td,
1952
+ struct wsp_ggml_tensor * state);
1953
+
1954
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_gated_linear_attn(
1955
+ struct wsp_ggml_context * ctx,
1956
+ struct wsp_ggml_tensor * k,
1957
+ struct wsp_ggml_tensor * v,
1958
+ struct wsp_ggml_tensor * q,
1959
+ struct wsp_ggml_tensor * g,
1960
+ struct wsp_ggml_tensor * state,
1961
+ float scale);
1689
1962
 
1690
- typedef void (*wsp_ggml_unary_op_f32_t) (const int, float *, const float *);
1691
- typedef void (*wsp_ggml_binary_op_f32_t)(const int, float *, const float *, const float *);
1692
-
1693
- typedef void (*wsp_ggml_custom1_op_f32_t)(struct wsp_ggml_tensor *, const struct wsp_ggml_tensor *);
1694
- typedef void (*wsp_ggml_custom2_op_f32_t)(struct wsp_ggml_tensor *, const struct wsp_ggml_tensor *, const struct wsp_ggml_tensor *);
1695
- typedef void (*wsp_ggml_custom3_op_f32_t)(struct wsp_ggml_tensor *, const struct wsp_ggml_tensor *, const struct wsp_ggml_tensor *, const struct wsp_ggml_tensor *);
1696
-
1697
- WSP_GGML_DEPRECATED(WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_map_unary_f32(
1698
- struct wsp_ggml_context * ctx,
1699
- struct wsp_ggml_tensor * a,
1700
- wsp_ggml_unary_op_f32_t fun),
1701
- "use wsp_ggml_map_custom1 instead");
1702
-
1703
- WSP_GGML_DEPRECATED(WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_map_unary_inplace_f32(
1704
- struct wsp_ggml_context * ctx,
1705
- struct wsp_ggml_tensor * a,
1706
- wsp_ggml_unary_op_f32_t fun),
1707
- "use wsp_ggml_map_custom1_inplace instead");
1708
-
1709
- WSP_GGML_DEPRECATED(WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_map_binary_f32(
1710
- struct wsp_ggml_context * ctx,
1711
- struct wsp_ggml_tensor * a,
1712
- struct wsp_ggml_tensor * b,
1713
- wsp_ggml_binary_op_f32_t fun),
1714
- "use wsp_ggml_map_custom2 instead");
1715
-
1716
- WSP_GGML_DEPRECATED(WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_map_binary_inplace_f32(
1717
- struct wsp_ggml_context * ctx,
1718
- struct wsp_ggml_tensor * a,
1719
- struct wsp_ggml_tensor * b,
1720
- wsp_ggml_binary_op_f32_t fun),
1721
- "use wsp_ggml_map_custom2_inplace instead");
1722
-
1723
- WSP_GGML_DEPRECATED(WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_map_custom1_f32(
1724
- struct wsp_ggml_context * ctx,
1725
- struct wsp_ggml_tensor * a,
1726
- wsp_ggml_custom1_op_f32_t fun),
1727
- "use wsp_ggml_map_custom1 instead");
1728
-
1729
- WSP_GGML_DEPRECATED(WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_map_custom1_inplace_f32(
1730
- struct wsp_ggml_context * ctx,
1731
- struct wsp_ggml_tensor * a,
1732
- wsp_ggml_custom1_op_f32_t fun),
1733
- "use wsp_ggml_map_custom1_inplace instead");
1734
-
1735
- WSP_GGML_DEPRECATED(WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_map_custom2_f32(
1736
- struct wsp_ggml_context * ctx,
1737
- struct wsp_ggml_tensor * a,
1738
- struct wsp_ggml_tensor * b,
1739
- wsp_ggml_custom2_op_f32_t fun),
1740
- "use wsp_ggml_map_custom2 instead");
1741
-
1742
- WSP_GGML_DEPRECATED(WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_map_custom2_inplace_f32(
1743
- struct wsp_ggml_context * ctx,
1744
- struct wsp_ggml_tensor * a,
1745
- struct wsp_ggml_tensor * b,
1746
- wsp_ggml_custom2_op_f32_t fun),
1747
- "use wsp_ggml_map_custom2_inplace instead");
1748
-
1749
- WSP_GGML_DEPRECATED(WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_map_custom3_f32(
1750
- struct wsp_ggml_context * ctx,
1751
- struct wsp_ggml_tensor * a,
1752
- struct wsp_ggml_tensor * b,
1753
- struct wsp_ggml_tensor * c,
1754
- wsp_ggml_custom3_op_f32_t fun),
1755
- "use wsp_ggml_map_custom3 instead");
1756
-
1757
- WSP_GGML_DEPRECATED(WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_map_custom3_inplace_f32(
1758
- struct wsp_ggml_context * ctx,
1759
- struct wsp_ggml_tensor * a,
1760
- struct wsp_ggml_tensor * b,
1761
- struct wsp_ggml_tensor * c,
1762
- wsp_ggml_custom3_op_f32_t fun),
1763
- "use wsp_ggml_map_custom3_inplace instead");
1764
-
1765
- // custom operators v2
1963
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_rwkv_wkv7(
1964
+ struct wsp_ggml_context * ctx,
1965
+ struct wsp_ggml_tensor * r,
1966
+ struct wsp_ggml_tensor * w,
1967
+ struct wsp_ggml_tensor * k,
1968
+ struct wsp_ggml_tensor * v,
1969
+ struct wsp_ggml_tensor * a,
1970
+ struct wsp_ggml_tensor * b,
1971
+ struct wsp_ggml_tensor * state);
1972
+
1973
+ // custom operators
1766
1974
 
1767
1975
  typedef void (*wsp_ggml_custom1_op_t)(struct wsp_ggml_tensor * dst , const struct wsp_ggml_tensor * a, int ith, int nth, void * userdata);
1768
1976
  typedef void (*wsp_ggml_custom2_op_t)(struct wsp_ggml_tensor * dst , const struct wsp_ggml_tensor * a, const struct wsp_ggml_tensor * b, int ith, int nth, void * userdata);
1769
1977
  typedef void (*wsp_ggml_custom3_op_t)(struct wsp_ggml_tensor * dst , const struct wsp_ggml_tensor * a, const struct wsp_ggml_tensor * b, const struct wsp_ggml_tensor * c, int ith, int nth, void * userdata);
1770
1978
 
1771
- #define WSP_GGML_N_TASKS_MAX -1
1979
+ #define WSP_GGML_N_TASKS_MAX (-1)
1980
+ // n_tasks == WSP_GGML_N_TASKS_MAX means to use max number of tasks
1772
1981
 
1773
1982
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_map_custom1(
1774
1983
  struct wsp_ggml_context * ctx,
@@ -1818,56 +2027,85 @@ extern "C" {
1818
2027
  int n_tasks,
1819
2028
  void * userdata);
1820
2029
 
2030
+ typedef void (*wsp_ggml_custom_op_t)(struct wsp_ggml_tensor * dst , int ith, int nth, void * userdata);
2031
+
2032
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_custom_4d(
2033
+ struct wsp_ggml_context * ctx,
2034
+ enum wsp_ggml_type type,
2035
+ int64_t ne0,
2036
+ int64_t ne1,
2037
+ int64_t ne2,
2038
+ int64_t ne3,
2039
+ struct wsp_ggml_tensor ** args,
2040
+ int n_args,
2041
+ wsp_ggml_custom_op_t fun,
2042
+ int n_tasks,
2043
+ void * userdata);
2044
+
2045
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_custom_inplace(
2046
+ struct wsp_ggml_context * ctx,
2047
+ struct wsp_ggml_tensor * a,
2048
+ struct wsp_ggml_tensor ** args,
2049
+ int n_args,
2050
+ wsp_ggml_custom_op_t fun,
2051
+ int n_tasks,
2052
+ void * userdata);
2053
+
1821
2054
  // loss function
1822
2055
 
1823
2056
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_cross_entropy_loss(
1824
- struct wsp_ggml_context * ctx,
1825
- struct wsp_ggml_tensor * a,
1826
- struct wsp_ggml_tensor * b);
2057
+ struct wsp_ggml_context * ctx,
2058
+ struct wsp_ggml_tensor * a, // logits
2059
+ struct wsp_ggml_tensor * b); // labels
1827
2060
 
1828
2061
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_cross_entropy_loss_back(
1829
- struct wsp_ggml_context * ctx,
1830
- struct wsp_ggml_tensor * a,
1831
- struct wsp_ggml_tensor * b,
1832
- struct wsp_ggml_tensor * c);
2062
+ struct wsp_ggml_context * ctx,
2063
+ struct wsp_ggml_tensor * a, // logits
2064
+ struct wsp_ggml_tensor * b, // labels
2065
+ struct wsp_ggml_tensor * c); // gradients of cross_entropy_loss result
2066
+
2067
+ // AdamW optimizer step
2068
+ // Paper: https://arxiv.org/pdf/1711.05101v3.pdf
2069
+ // PyTorch: https://pytorch.org/docs/stable/generated/torch.optim.AdamW.html
2070
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_opt_step_adamw(
2071
+ struct wsp_ggml_context * ctx,
2072
+ struct wsp_ggml_tensor * a,
2073
+ struct wsp_ggml_tensor * grad,
2074
+ struct wsp_ggml_tensor * m,
2075
+ struct wsp_ggml_tensor * v,
2076
+ struct wsp_ggml_tensor * adamw_params); // parameters such a the learning rate
1833
2077
 
1834
2078
  //
1835
2079
  // automatic differentiation
1836
2080
  //
1837
2081
 
1838
- WSP_GGML_API void wsp_ggml_set_param(
1839
- struct wsp_ggml_context * ctx,
1840
- struct wsp_ggml_tensor * tensor);
2082
+ WSP_GGML_API void wsp_ggml_build_forward_expand(struct wsp_ggml_cgraph * cgraph, struct wsp_ggml_tensor * tensor);
2083
+ WSP_GGML_API void wsp_ggml_build_backward_expand(
2084
+ struct wsp_ggml_context * ctx, // context for gradient computation
2085
+ struct wsp_ggml_cgraph * cgraph,
2086
+ struct wsp_ggml_tensor ** grad_accs);
1841
2087
 
2088
+ // graph allocation in a context
2089
+ WSP_GGML_API struct wsp_ggml_cgraph * wsp_ggml_new_graph (struct wsp_ggml_context * ctx); // size = WSP_GGML_DEFAULT_GRAPH_SIZE, grads = false
2090
+ WSP_GGML_API struct wsp_ggml_cgraph * wsp_ggml_new_graph_custom(struct wsp_ggml_context * ctx, size_t size, bool grads);
2091
+ WSP_GGML_API struct wsp_ggml_cgraph * wsp_ggml_graph_dup (struct wsp_ggml_context * ctx, struct wsp_ggml_cgraph * cgraph, bool force_grads);
2092
+ WSP_GGML_API void wsp_ggml_graph_cpy (struct wsp_ggml_cgraph * src, struct wsp_ggml_cgraph * dst);
2093
+ WSP_GGML_API void wsp_ggml_graph_reset (struct wsp_ggml_cgraph * cgraph); // set regular grads + optimizer momenta to 0, set loss grad to 1
2094
+ WSP_GGML_API void wsp_ggml_graph_clear (struct wsp_ggml_cgraph * cgraph);
1842
2095
 
1843
- WSP_GGML_API void wsp_ggml_build_forward_expand (struct wsp_ggml_cgraph * cgraph, struct wsp_ggml_tensor * tensor);
1844
- WSP_GGML_API void wsp_ggml_build_backward_expand(struct wsp_ggml_context * ctx, struct wsp_ggml_cgraph * gf, struct wsp_ggml_cgraph * gb, bool keep);
2096
+ WSP_GGML_API int wsp_ggml_graph_size (struct wsp_ggml_cgraph * cgraph);
2097
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_graph_node (struct wsp_ggml_cgraph * cgraph, int i); // if i < 0, returns nodes[n_nodes + i]
2098
+ WSP_GGML_API struct wsp_ggml_tensor ** wsp_ggml_graph_nodes (struct wsp_ggml_cgraph * cgraph);
2099
+ WSP_GGML_API int wsp_ggml_graph_n_nodes(struct wsp_ggml_cgraph * cgraph);
1845
2100
 
1846
- // graph allocation in a context
1847
- WSP_GGML_API struct wsp_ggml_cgraph * wsp_ggml_new_graph (struct wsp_ggml_context * ctx); // size = WSP_GGML_DEFAULT_GRAPH_SIZE, grads = false
1848
- WSP_GGML_API struct wsp_ggml_cgraph * wsp_ggml_new_graph_custom (struct wsp_ggml_context * ctx, size_t size, bool grads);
1849
- WSP_GGML_API struct wsp_ggml_cgraph * wsp_ggml_graph_dup (struct wsp_ggml_context * ctx, struct wsp_ggml_cgraph * cgraph);
1850
- WSP_GGML_API struct wsp_ggml_cgraph wsp_ggml_graph_view (struct wsp_ggml_cgraph * cgraph, int i0, int i1);
1851
- WSP_GGML_API void wsp_ggml_graph_cpy (struct wsp_ggml_cgraph * src, struct wsp_ggml_cgraph * dst);
1852
- WSP_GGML_API void wsp_ggml_graph_reset (struct wsp_ggml_cgraph * cgraph); // zero grads
1853
- WSP_GGML_API void wsp_ggml_graph_clear (struct wsp_ggml_cgraph * cgraph);
2101
+ WSP_GGML_API void wsp_ggml_graph_add_node(struct wsp_ggml_cgraph * cgraph, struct wsp_ggml_tensor * tensor);
1854
2102
 
1855
2103
  WSP_GGML_API size_t wsp_ggml_graph_overhead(void);
1856
2104
  WSP_GGML_API size_t wsp_ggml_graph_overhead_custom(size_t size, bool grads);
1857
2105
 
1858
- // wsp_ggml_graph_plan() has to be called before wsp_ggml_graph_compute()
1859
- // when plan.work_size > 0, caller must allocate memory for plan.work_data
1860
- WSP_GGML_API struct wsp_ggml_cplan wsp_ggml_graph_plan (const struct wsp_ggml_cgraph * cgraph, int n_threads /*= WSP_GGML_DEFAULT_N_THREADS*/);
1861
- WSP_GGML_API int wsp_ggml_graph_compute( struct wsp_ggml_cgraph * cgraph, struct wsp_ggml_cplan * cplan);
1862
-
1863
- // same as wsp_ggml_graph_compute() but the work data is allocated as a part of the context
1864
- // note: the drawback of this API is that you must have ensured that the context has enough memory for the work data
1865
- WSP_GGML_API void wsp_ggml_graph_compute_with_ctx(struct wsp_ggml_context * ctx, struct wsp_ggml_cgraph * cgraph, int n_threads);
1866
-
1867
- WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_graph_get_tensor(struct wsp_ggml_cgraph * cgraph, const char * name);
1868
-
1869
- WSP_GGML_API void wsp_ggml_graph_export(const struct wsp_ggml_cgraph * cgraph, const char * fname);
1870
- WSP_GGML_API struct wsp_ggml_cgraph * wsp_ggml_graph_import(const char * fname, struct wsp_ggml_context ** ctx_data, struct wsp_ggml_context ** ctx_eval);
2106
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_graph_get_tensor (const struct wsp_ggml_cgraph * cgraph, const char * name);
2107
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_graph_get_grad (const struct wsp_ggml_cgraph * cgraph, const struct wsp_ggml_tensor * node);
2108
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_graph_get_grad_acc(const struct wsp_ggml_cgraph * cgraph, const struct wsp_ggml_tensor * node);
1871
2109
 
1872
2110
  // print info and performance information for the graph
1873
2111
  WSP_GGML_API void wsp_ggml_graph_print(const struct wsp_ggml_cgraph * cgraph);
@@ -1875,191 +2113,14 @@ extern "C" {
1875
2113
  // dump the graph into a file using the dot format
1876
2114
  WSP_GGML_API void wsp_ggml_graph_dump_dot(const struct wsp_ggml_cgraph * gb, const struct wsp_ggml_cgraph * gf, const char * filename);
1877
2115
 
1878
- // build gradient checkpointing backward graph gb for gf using provided checkpoints
1879
- // gb_tmp will contain original backward graph with rewritten backward process nodes,
1880
- // but without the second forward pass nodes.
1881
- WSP_GGML_API void wsp_ggml_build_backward_gradient_checkpointing(
1882
- struct wsp_ggml_context * ctx,
1883
- struct wsp_ggml_cgraph * gf,
1884
- struct wsp_ggml_cgraph * gb,
1885
- struct wsp_ggml_cgraph * gb_tmp,
1886
- struct wsp_ggml_tensor * * checkpoints,
1887
- int n_checkpoints);
1888
- //
1889
- // optimization
1890
- //
1891
-
1892
- // optimization methods
1893
- enum wsp_ggml_opt_type {
1894
- WSP_GGML_OPT_ADAM,
1895
- WSP_GGML_OPT_LBFGS,
1896
- };
1897
-
1898
- // linesearch methods
1899
- enum wsp_ggml_linesearch {
1900
- WSP_GGML_LINESEARCH_DEFAULT = 1,
1901
-
1902
- WSP_GGML_LINESEARCH_BACKTRACKING_ARMIJO = 0,
1903
- WSP_GGML_LINESEARCH_BACKTRACKING_WOLFE = 1,
1904
- WSP_GGML_LINESEARCH_BACKTRACKING_STRONG_WOLFE = 2,
1905
- };
1906
-
1907
- // optimization return values
1908
- enum wsp_ggml_opt_result {
1909
- WSP_GGML_OPT_OK = 0,
1910
- WSP_GGML_OPT_DID_NOT_CONVERGE,
1911
- WSP_GGML_OPT_NO_CONTEXT,
1912
- WSP_GGML_OPT_INVALID_WOLFE,
1913
- WSP_GGML_OPT_FAIL,
1914
- WSP_GGML_OPT_CANCEL,
1915
-
1916
- WSP_GGML_LINESEARCH_FAIL = -128,
1917
- WSP_GGML_LINESEARCH_MINIMUM_STEP,
1918
- WSP_GGML_LINESEARCH_MAXIMUM_STEP,
1919
- WSP_GGML_LINESEARCH_MAXIMUM_ITERATIONS,
1920
- WSP_GGML_LINESEARCH_INVALID_PARAMETERS,
1921
- };
1922
-
1923
- typedef void (*wsp_ggml_opt_callback)(void * data, int accum_step, float * sched, bool * cancel);
2116
+ // TODO these functions were sandwiched in the old optimization interface, is there a better place for them?
1924
2117
  typedef void (*wsp_ggml_log_callback)(enum wsp_ggml_log_level level, const char * text, void * user_data);
1925
2118
 
1926
- // optimization parameters
1927
- //
1928
- // see ggml.c (wsp_ggml_opt_default_params) for default values
1929
- //
1930
- struct wsp_ggml_opt_params {
1931
- enum wsp_ggml_opt_type type;
1932
-
1933
- size_t graph_size;
1934
-
1935
- int n_threads;
1936
-
1937
- // delta-based convergence test
1938
- //
1939
- // if past == 0 - disabled
1940
- // if past > 0:
1941
- // stop if |f(x) - f(x_past)| < delta * max(1, |f(x)|)
1942
- //
1943
- int past;
1944
- float delta;
1945
-
1946
- // maximum number of iterations without improvement
1947
- //
1948
- // if 0 - disabled
1949
- // if > 0:
1950
- // assume convergence if no cost improvement in this number of iterations
1951
- //
1952
- int max_no_improvement;
1953
-
1954
- bool print_forward_graph;
1955
- bool print_backward_graph;
1956
-
1957
- int n_gradient_accumulation;
1958
-
1959
- // ADAM parameters
1960
- struct {
1961
- int n_iter;
1962
-
1963
- float sched; // schedule multiplier (fixed, decay or warmup)
1964
- float decay; // weight decay for AdamW, use 0.0f to disable
1965
- int decay_min_ndim; // minimum number of tensor dimension to apply weight decay
1966
- float alpha; // learning rate
1967
- float beta1;
1968
- float beta2;
1969
- float eps; // epsilon for numerical stability
1970
- float eps_f; // epsilon for convergence test
1971
- float eps_g; // epsilon for convergence test
1972
- float gclip; // gradient clipping
1973
- } adam;
1974
-
1975
- // LBFGS parameters
1976
- struct {
1977
- int m; // number of corrections to approximate the inv. Hessian
1978
- int n_iter;
1979
- int max_linesearch;
1980
-
1981
- float eps; // convergence tolerance
1982
- float ftol; // line search tolerance
1983
- float wolfe;
1984
- float min_step;
1985
- float max_step;
1986
-
1987
- enum wsp_ggml_linesearch linesearch;
1988
- } lbfgs;
1989
- };
1990
-
1991
- struct wsp_ggml_opt_context {
1992
- struct wsp_ggml_context * ctx;
1993
- struct wsp_ggml_opt_params params;
1994
-
1995
- int iter;
1996
- int64_t nx; // number of parameter elements
1997
-
1998
- bool just_initialized;
1999
-
2000
- float loss_before;
2001
- float loss_after;
2002
-
2003
- struct {
2004
- struct wsp_ggml_tensor * g; // current gradient
2005
- struct wsp_ggml_tensor * m; // first moment
2006
- struct wsp_ggml_tensor * v; // second moment
2007
- struct wsp_ggml_tensor * pf; // past function values
2008
- float fx_best;
2009
- float fx_prev;
2010
- int n_no_improvement;
2011
- } adam;
2012
-
2013
- struct {
2014
- struct wsp_ggml_tensor * x; // current parameters
2015
- struct wsp_ggml_tensor * xp; // previous parameters
2016
- struct wsp_ggml_tensor * g; // current gradient
2017
- struct wsp_ggml_tensor * gp; // previous gradient
2018
- struct wsp_ggml_tensor * d; // search direction
2019
- struct wsp_ggml_tensor * pf; // past function values
2020
- struct wsp_ggml_tensor * lmal; // the L-BFGS memory alpha
2021
- struct wsp_ggml_tensor * lmys; // the L-BFGS memory ys
2022
- struct wsp_ggml_tensor * lms; // the L-BFGS memory s
2023
- struct wsp_ggml_tensor * lmy; // the L-BFGS memory y
2024
- float fx_best;
2025
- float step;
2026
- int j;
2027
- int k;
2028
- int end;
2029
- int n_no_improvement;
2030
- } lbfgs;
2031
- };
2032
-
2033
- WSP_GGML_API struct wsp_ggml_opt_params wsp_ggml_opt_default_params(enum wsp_ggml_opt_type type);
2119
+ // Set callback for all future logging events.
2120
+ // If this is not called, or NULL is supplied, everything is output on stderr.
2121
+ WSP_GGML_API void wsp_ggml_log_set(wsp_ggml_log_callback log_callback, void * user_data);
2034
2122
 
2035
- // optimize the function defined by the tensor f
2036
- WSP_GGML_API enum wsp_ggml_opt_result wsp_ggml_opt(
2037
- struct wsp_ggml_context * ctx,
2038
- struct wsp_ggml_opt_params params,
2039
- struct wsp_ggml_tensor * f);
2040
-
2041
- // initialize optimizer context
2042
- WSP_GGML_API void wsp_ggml_opt_init(
2043
- struct wsp_ggml_context * ctx,
2044
- struct wsp_ggml_opt_context * opt,
2045
- struct wsp_ggml_opt_params params,
2046
- int64_t nx);
2047
-
2048
- // continue optimizing the function defined by the tensor f
2049
- WSP_GGML_API enum wsp_ggml_opt_result wsp_ggml_opt_resume(
2050
- struct wsp_ggml_context * ctx,
2051
- struct wsp_ggml_opt_context * opt,
2052
- struct wsp_ggml_tensor * f);
2053
-
2054
- // continue optimizing the function defined by the tensor f
2055
- WSP_GGML_API enum wsp_ggml_opt_result wsp_ggml_opt_resume_g(
2056
- struct wsp_ggml_context * ctx,
2057
- struct wsp_ggml_opt_context * opt,
2058
- struct wsp_ggml_tensor * f,
2059
- struct wsp_ggml_cgraph * gf,
2060
- struct wsp_ggml_cgraph * gb,
2061
- wsp_ggml_opt_callback callback,
2062
- void * callback_data);
2123
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_set_zero(struct wsp_ggml_tensor * tensor);
2063
2124
 
2064
2125
  //
2065
2126
  // quantization
@@ -2077,201 +2138,83 @@ extern "C" {
2077
2138
  WSP_GGML_API void wsp_ggml_wsp_quantize_init(enum wsp_ggml_type type);
2078
2139
  WSP_GGML_API void wsp_ggml_wsp_quantize_free(void);
2079
2140
 
2080
- // TODO: these would probably get removed in favor of the more general wsp_ggml_wsp_quantize_chunk
2081
- WSP_GGML_API size_t wsp_ggml_wsp_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t * hist);
2082
- WSP_GGML_API size_t wsp_ggml_wsp_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t * hist);
2083
- WSP_GGML_API size_t wsp_ggml_wsp_quantize_q5_0(const float * src, void * dst, int n, int k, int64_t * hist);
2084
- WSP_GGML_API size_t wsp_ggml_wsp_quantize_q5_1(const float * src, void * dst, int n, int k, int64_t * hist);
2085
- WSP_GGML_API size_t wsp_ggml_wsp_quantize_q8_0(const float * src, void * dst, int n, int k, int64_t * hist);
2086
-
2087
- WSP_GGML_API size_t wsp_ggml_wsp_quantize_q2_K(const float * src, void * dst, int n, int k, int64_t * hist);
2088
- WSP_GGML_API size_t wsp_ggml_wsp_quantize_q3_K(const float * src, void * dst, int n, int k, int64_t * hist);
2089
- WSP_GGML_API size_t wsp_ggml_wsp_quantize_q4_K(const float * src, void * dst, int n, int k, int64_t * hist);
2090
- WSP_GGML_API size_t wsp_ggml_wsp_quantize_q5_K(const float * src, void * dst, int n, int k, int64_t * hist);
2091
- WSP_GGML_API size_t wsp_ggml_wsp_quantize_q6_K(const float * src, void * dst, int n, int k, int64_t * hist);
2092
-
2093
2141
  // some quantization type cannot be used without an importance matrix
2094
2142
  WSP_GGML_API bool wsp_ggml_wsp_quantize_requires_imatrix(enum wsp_ggml_type type);
2095
2143
 
2096
2144
  // calls wsp_ggml_wsp_quantize_init internally (i.e. can allocate memory)
2097
- WSP_GGML_API size_t wsp_ggml_wsp_quantize_chunk(enum wsp_ggml_type type, const float * src, void * dst,
2098
- int start, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
2099
-
2100
- //
2101
- // gguf
2102
- //
2103
-
2104
- enum wsp_gguf_type {
2105
- WSP_GGUF_TYPE_UINT8 = 0,
2106
- WSP_GGUF_TYPE_INT8 = 1,
2107
- WSP_GGUF_TYPE_UINT16 = 2,
2108
- WSP_GGUF_TYPE_INT16 = 3,
2109
- WSP_GGUF_TYPE_UINT32 = 4,
2110
- WSP_GGUF_TYPE_INT32 = 5,
2111
- WSP_GGUF_TYPE_FLOAT32 = 6,
2112
- WSP_GGUF_TYPE_BOOL = 7,
2113
- WSP_GGUF_TYPE_STRING = 8,
2114
- WSP_GGUF_TYPE_ARRAY = 9,
2115
- WSP_GGUF_TYPE_UINT64 = 10,
2116
- WSP_GGUF_TYPE_INT64 = 11,
2117
- WSP_GGUF_TYPE_FLOAT64 = 12,
2118
- WSP_GGUF_TYPE_COUNT, // marks the end of the enum
2145
+ WSP_GGML_API size_t wsp_ggml_wsp_quantize_chunk(
2146
+ enum wsp_ggml_type type,
2147
+ const float * src,
2148
+ void * dst,
2149
+ int64_t start,
2150
+ int64_t nrows,
2151
+ int64_t n_per_row,
2152
+ const float * imatrix);
2153
+
2154
+ #ifdef __cplusplus
2155
+ // restrict not standard in C++
2156
+ # if defined(__GNUC__)
2157
+ # define WSP_GGML_RESTRICT __restrict__
2158
+ # elif defined(__clang__)
2159
+ # define WSP_GGML_RESTRICT __restrict
2160
+ # elif defined(_MSC_VER)
2161
+ # define WSP_GGML_RESTRICT __restrict
2162
+ # else
2163
+ # define WSP_GGML_RESTRICT
2164
+ # endif
2165
+ #else
2166
+ # if defined (_MSC_VER) && (__STDC_VERSION__ < 201112L)
2167
+ # define WSP_GGML_RESTRICT __restrict
2168
+ # else
2169
+ # define WSP_GGML_RESTRICT restrict
2170
+ # endif
2171
+ #endif
2172
+ typedef void (*wsp_ggml_to_float_t) (const void * WSP_GGML_RESTRICT x, float * WSP_GGML_RESTRICT y, int64_t k);
2173
+ typedef void (*wsp_ggml_from_float_t)(const float * WSP_GGML_RESTRICT x, void * WSP_GGML_RESTRICT y, int64_t k);
2174
+
2175
+ struct wsp_ggml_type_traits {
2176
+ const char * type_name;
2177
+ int64_t blck_size;
2178
+ int64_t blck_size_interleave; // interleave elements in blocks
2179
+ size_t type_size;
2180
+ bool is_quantized;
2181
+ wsp_ggml_to_float_t to_float;
2182
+ wsp_ggml_from_float_t from_float_ref;
2119
2183
  };
2120
2184
 
2121
- struct wsp_gguf_context;
2185
+ WSP_GGML_API const struct wsp_ggml_type_traits * wsp_ggml_get_type_traits(enum wsp_ggml_type type);
2122
2186
 
2123
- struct wsp_gguf_init_params {
2124
- bool no_alloc;
2187
+ // ggml threadpool
2188
+ // TODO: currently, only a few functions are in the base ggml API, while the rest are in the CPU backend
2189
+ // the goal should be to create an API that other backends can use move everything to the ggml base
2125
2190
 
2126
- // if not NULL, create a wsp_ggml_context and allocate the tensor data in it
2127
- struct wsp_ggml_context ** ctx;
2191
+ // scheduling priorities
2192
+ enum wsp_ggml_sched_priority {
2193
+ WSP_GGML_SCHED_PRIO_LOW = -1,
2194
+ WSP_GGML_SCHED_PRIO_NORMAL,
2195
+ WSP_GGML_SCHED_PRIO_MEDIUM,
2196
+ WSP_GGML_SCHED_PRIO_HIGH,
2197
+ WSP_GGML_SCHED_PRIO_REALTIME
2128
2198
  };
2129
2199
 
2130
- WSP_GGML_API struct wsp_gguf_context * wsp_gguf_init_empty(void);
2131
- WSP_GGML_API struct wsp_gguf_context * wsp_gguf_init_from_file(const char * fname, struct wsp_gguf_init_params params);
2132
- //WSP_GGML_API struct wsp_gguf_context * wsp_gguf_init_from_buffer(..);
2133
-
2134
- WSP_GGML_API void wsp_gguf_free(struct wsp_gguf_context * ctx);
2135
-
2136
- WSP_GGML_API const char * wsp_gguf_type_name(enum wsp_gguf_type type);
2137
-
2138
- WSP_GGML_API int wsp_gguf_get_version (const struct wsp_gguf_context * ctx);
2139
- WSP_GGML_API size_t wsp_gguf_get_alignment (const struct wsp_gguf_context * ctx);
2140
- WSP_GGML_API size_t wsp_gguf_get_data_offset(const struct wsp_gguf_context * ctx);
2141
- WSP_GGML_API void * wsp_gguf_get_data (const struct wsp_gguf_context * ctx);
2142
-
2143
- WSP_GGML_API int wsp_gguf_get_n_kv(const struct wsp_gguf_context * ctx);
2144
- WSP_GGML_API int wsp_gguf_find_key(const struct wsp_gguf_context * ctx, const char * key);
2145
- WSP_GGML_API const char * wsp_gguf_get_key (const struct wsp_gguf_context * ctx, int key_id);
2146
-
2147
- WSP_GGML_API enum wsp_gguf_type wsp_gguf_get_kv_type (const struct wsp_gguf_context * ctx, int key_id);
2148
- WSP_GGML_API enum wsp_gguf_type wsp_gguf_get_arr_type(const struct wsp_gguf_context * ctx, int key_id);
2149
-
2150
- // will abort if the wrong type is used for the key
2151
- WSP_GGML_API uint8_t wsp_gguf_get_val_u8 (const struct wsp_gguf_context * ctx, int key_id);
2152
- WSP_GGML_API int8_t wsp_gguf_get_val_i8 (const struct wsp_gguf_context * ctx, int key_id);
2153
- WSP_GGML_API uint16_t wsp_gguf_get_val_u16 (const struct wsp_gguf_context * ctx, int key_id);
2154
- WSP_GGML_API int16_t wsp_gguf_get_val_i16 (const struct wsp_gguf_context * ctx, int key_id);
2155
- WSP_GGML_API uint32_t wsp_gguf_get_val_u32 (const struct wsp_gguf_context * ctx, int key_id);
2156
- WSP_GGML_API int32_t wsp_gguf_get_val_i32 (const struct wsp_gguf_context * ctx, int key_id);
2157
- WSP_GGML_API float wsp_gguf_get_val_f32 (const struct wsp_gguf_context * ctx, int key_id);
2158
- WSP_GGML_API uint64_t wsp_gguf_get_val_u64 (const struct wsp_gguf_context * ctx, int key_id);
2159
- WSP_GGML_API int64_t wsp_gguf_get_val_i64 (const struct wsp_gguf_context * ctx, int key_id);
2160
- WSP_GGML_API double wsp_gguf_get_val_f64 (const struct wsp_gguf_context * ctx, int key_id);
2161
- WSP_GGML_API bool wsp_gguf_get_val_bool(const struct wsp_gguf_context * ctx, int key_id);
2162
- WSP_GGML_API const char * wsp_gguf_get_val_str (const struct wsp_gguf_context * ctx, int key_id);
2163
- WSP_GGML_API const void * wsp_gguf_get_val_data(const struct wsp_gguf_context * ctx, int key_id);
2164
- WSP_GGML_API int wsp_gguf_get_arr_n (const struct wsp_gguf_context * ctx, int key_id);
2165
- WSP_GGML_API const void * wsp_gguf_get_arr_data(const struct wsp_gguf_context * ctx, int key_id);
2166
- WSP_GGML_API const char * wsp_gguf_get_arr_str (const struct wsp_gguf_context * ctx, int key_id, int i);
2167
-
2168
- WSP_GGML_API int wsp_gguf_get_n_tensors (const struct wsp_gguf_context * ctx);
2169
- WSP_GGML_API int wsp_gguf_find_tensor (const struct wsp_gguf_context * ctx, const char * name);
2170
- WSP_GGML_API size_t wsp_gguf_get_tensor_offset(const struct wsp_gguf_context * ctx, int i);
2171
- WSP_GGML_API char * wsp_gguf_get_tensor_name (const struct wsp_gguf_context * ctx, int i);
2172
- WSP_GGML_API enum wsp_ggml_type wsp_gguf_get_tensor_type (const struct wsp_gguf_context * ctx, int i);
2173
-
2174
- // overrides existing values or adds a new one
2175
- WSP_GGML_API void wsp_gguf_set_val_u8 (struct wsp_gguf_context * ctx, const char * key, uint8_t val);
2176
- WSP_GGML_API void wsp_gguf_set_val_i8 (struct wsp_gguf_context * ctx, const char * key, int8_t val);
2177
- WSP_GGML_API void wsp_gguf_set_val_u16 (struct wsp_gguf_context * ctx, const char * key, uint16_t val);
2178
- WSP_GGML_API void wsp_gguf_set_val_i16 (struct wsp_gguf_context * ctx, const char * key, int16_t val);
2179
- WSP_GGML_API void wsp_gguf_set_val_u32 (struct wsp_gguf_context * ctx, const char * key, uint32_t val);
2180
- WSP_GGML_API void wsp_gguf_set_val_i32 (struct wsp_gguf_context * ctx, const char * key, int32_t val);
2181
- WSP_GGML_API void wsp_gguf_set_val_f32 (struct wsp_gguf_context * ctx, const char * key, float val);
2182
- WSP_GGML_API void wsp_gguf_set_val_u64 (struct wsp_gguf_context * ctx, const char * key, uint64_t val);
2183
- WSP_GGML_API void wsp_gguf_set_val_i64 (struct wsp_gguf_context * ctx, const char * key, int64_t val);
2184
- WSP_GGML_API void wsp_gguf_set_val_f64 (struct wsp_gguf_context * ctx, const char * key, double val);
2185
- WSP_GGML_API void wsp_gguf_set_val_bool(struct wsp_gguf_context * ctx, const char * key, bool val);
2186
- WSP_GGML_API void wsp_gguf_set_val_str (struct wsp_gguf_context * ctx, const char * key, const char * val);
2187
- WSP_GGML_API void wsp_gguf_set_arr_data(struct wsp_gguf_context * ctx, const char * key, enum wsp_gguf_type type, const void * data, int n);
2188
- WSP_GGML_API void wsp_gguf_set_arr_str (struct wsp_gguf_context * ctx, const char * key, const char ** data, int n);
2189
-
2190
- // set or add KV pairs from another context
2191
- WSP_GGML_API void wsp_gguf_set_kv(struct wsp_gguf_context * ctx, struct wsp_gguf_context * src);
2192
-
2193
- // manage tensor info
2194
- WSP_GGML_API void wsp_gguf_add_tensor(struct wsp_gguf_context * ctx, const struct wsp_ggml_tensor * tensor);
2195
- WSP_GGML_API void wsp_gguf_set_tensor_type(struct wsp_gguf_context * ctx, const char * name, enum wsp_ggml_type type);
2196
- WSP_GGML_API void wsp_gguf_set_tensor_data(struct wsp_gguf_context * ctx, const char * name, const void * data, size_t size);
2197
-
2198
- // writing gguf files can be done in 2 ways:
2199
- //
2200
- // - write the entire wsp_gguf_context to a binary file in a single pass:
2201
- //
2202
- // wsp_gguf_write_to_file(ctx, fname);
2203
- //
2204
- // - first prepare a file with a placeholder for the meta data, write the tensor data, then write the meta data:
2205
- //
2206
- // FILE * f = fopen(fname, "wb");
2207
- // fseek(f, wsp_gguf_get_meta_size(ctx), SEEK_SET);
2208
- // fwrite(f, ...);
2209
- // void * data = wsp_gguf_meta_get_meta_data(ctx);
2210
- // fseek(f, 0, SEEK_SET);
2211
- // fwrite(f, data, wsp_gguf_get_meta_size(ctx));
2212
- // free(data);
2213
- // fclose(f);
2214
- //
2215
-
2216
- // write the entire context to a binary file
2217
- WSP_GGML_API void wsp_gguf_write_to_file(const struct wsp_gguf_context * ctx, const char * fname, bool only_meta);
2218
-
2219
- // get the size in bytes of the meta data (header, kv pairs, tensor info) including padding
2220
- WSP_GGML_API size_t wsp_gguf_get_meta_size(const struct wsp_gguf_context * ctx);
2221
- WSP_GGML_API void wsp_gguf_get_meta_data(const struct wsp_gguf_context * ctx, void * data);
2200
+ // threadpool params
2201
+ // Use wsp_ggml_threadpool_params_default() or wsp_ggml_threadpool_params_init() to populate the defaults
2202
+ struct wsp_ggml_threadpool_params {
2203
+ bool cpumask[WSP_GGML_MAX_N_THREADS]; // mask of cpu cores (all-zeros means use default affinity settings)
2204
+ int n_threads; // number of threads
2205
+ enum wsp_ggml_sched_priority prio; // thread priority
2206
+ uint32_t poll; // polling level (0 - no polling, 100 - aggressive polling)
2207
+ bool strict_cpu; // strict cpu placement
2208
+ bool paused; // start in paused state
2209
+ };
2222
2210
 
2223
- //
2224
- // system info
2225
- //
2211
+ struct wsp_ggml_threadpool; // forward declaration, see ggml.c
2226
2212
 
2227
- WSP_GGML_API int wsp_ggml_cpu_has_avx (void);
2228
- WSP_GGML_API int wsp_ggml_cpu_has_avx_vnni (void);
2229
- WSP_GGML_API int wsp_ggml_cpu_has_avx2 (void);
2230
- WSP_GGML_API int wsp_ggml_cpu_has_avx512 (void);
2231
- WSP_GGML_API int wsp_ggml_cpu_has_avx512_vbmi(void);
2232
- WSP_GGML_API int wsp_ggml_cpu_has_avx512_vnni(void);
2233
- WSP_GGML_API int wsp_ggml_cpu_has_fma (void);
2234
- WSP_GGML_API int wsp_ggml_cpu_has_neon (void);
2235
- WSP_GGML_API int wsp_ggml_cpu_has_arm_fma (void);
2236
- WSP_GGML_API int wsp_ggml_cpu_has_metal (void);
2237
- WSP_GGML_API int wsp_ggml_cpu_has_f16c (void);
2238
- WSP_GGML_API int wsp_ggml_cpu_has_fp16_va (void);
2239
- WSP_GGML_API int wsp_ggml_cpu_has_wasm_simd (void);
2240
- WSP_GGML_API int wsp_ggml_cpu_has_blas (void);
2241
- WSP_GGML_API int wsp_ggml_cpu_has_cublas (void);
2242
- WSP_GGML_API int wsp_ggml_cpu_has_clblast (void);
2243
- WSP_GGML_API int wsp_ggml_cpu_has_gpublas (void);
2244
- WSP_GGML_API int wsp_ggml_cpu_has_sse3 (void);
2245
- WSP_GGML_API int wsp_ggml_cpu_has_ssse3 (void);
2246
- WSP_GGML_API int wsp_ggml_cpu_has_vsx (void);
2213
+ typedef struct wsp_ggml_threadpool * wsp_ggml_threadpool_t;
2247
2214
 
2248
- //
2249
- // Internal types and functions exposed for tests and benchmarks
2250
- //
2251
-
2252
- #ifdef __cplusplus
2253
- // restrict not standard in C++
2254
- #define WSP_GGML_RESTRICT
2255
- #else
2256
- #define WSP_GGML_RESTRICT restrict
2257
- #endif
2258
- typedef void (*wsp_ggml_to_float_t) (const void * WSP_GGML_RESTRICT x, float * WSP_GGML_RESTRICT y, int k);
2259
- typedef void (*wsp_ggml_from_float_t)(const float * WSP_GGML_RESTRICT x, void * WSP_GGML_RESTRICT y, int k);
2260
- typedef void (*wsp_ggml_vec_dot_t) (const int n, float * WSP_GGML_RESTRICT s, const void * WSP_GGML_RESTRICT x, const void * WSP_GGML_RESTRICT y);
2261
-
2262
- typedef struct {
2263
- const char * type_name;
2264
- int blck_size;
2265
- size_t type_size;
2266
- bool is_quantized;
2267
- wsp_ggml_to_float_t to_float;
2268
- wsp_ggml_from_float_t from_float;
2269
- wsp_ggml_from_float_t from_float_reference;
2270
- wsp_ggml_vec_dot_t vec_dot;
2271
- enum wsp_ggml_type vec_dot_type;
2272
- } wsp_ggml_type_traits_t;
2273
-
2274
- WSP_GGML_API wsp_ggml_type_traits_t wsp_ggml_internal_get_type_traits(enum wsp_ggml_type type);
2215
+ WSP_GGML_API struct wsp_ggml_threadpool_params wsp_ggml_threadpool_params_default(int n_threads);
2216
+ WSP_GGML_API void wsp_ggml_threadpool_params_init (struct wsp_ggml_threadpool_params * p, int n_threads);
2217
+ WSP_GGML_API bool wsp_ggml_threadpool_params_match (const struct wsp_ggml_threadpool_params * p0, const struct wsp_ggml_threadpool_params * p1);
2275
2218
 
2276
2219
  #ifdef __cplusplus
2277
2220
  }