whisper.rn 0.4.0-rc.1 → 0.4.0-rc.11

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (78)
  1. package/README.md +6 -6
  2. package/android/build.gradle +4 -0
  3. package/android/src/main/CMakeLists.txt +21 -1
  4. package/android/src/main/java/com/rnwhisper/AudioUtils.java +27 -92
  5. package/android/src/main/java/com/rnwhisper/RNWhisper.java +86 -40
  6. package/android/src/main/java/com/rnwhisper/WhisperContext.java +85 -131
  7. package/android/src/main/jni-utils.h +76 -0
  8. package/android/src/main/jni.cpp +226 -109
  9. package/android/src/newarch/java/com/rnwhisper/RNWhisperModule.java +10 -0
  10. package/android/src/oldarch/java/com/rnwhisper/RNWhisperModule.java +10 -0
  11. package/cpp/coreml/whisper-encoder-impl.h +1 -1
  12. package/cpp/coreml/whisper-encoder.h +4 -0
  13. package/cpp/coreml/whisper-encoder.mm +5 -3
  14. package/cpp/ggml-alloc.c +797 -400
  15. package/cpp/ggml-alloc.h +60 -10
  16. package/cpp/ggml-backend-impl.h +255 -0
  17. package/cpp/ggml-backend-reg.cpp +582 -0
  18. package/cpp/ggml-backend.cpp +2002 -0
  19. package/cpp/ggml-backend.h +354 -0
  20. package/cpp/ggml-common.h +1851 -0
  21. package/cpp/ggml-cpp.h +39 -0
  22. package/cpp/ggml-cpu-aarch64.cpp +4247 -0
  23. package/cpp/ggml-cpu-aarch64.h +8 -0
  24. package/cpp/ggml-cpu-impl.h +531 -0
  25. package/cpp/ggml-cpu-quants.c +12245 -0
  26. package/cpp/ggml-cpu-quants.h +63 -0
  27. package/cpp/ggml-cpu-traits.cpp +36 -0
  28. package/cpp/ggml-cpu-traits.h +38 -0
  29. package/cpp/ggml-cpu.c +14792 -0
  30. package/cpp/ggml-cpu.cpp +653 -0
  31. package/cpp/ggml-cpu.h +137 -0
  32. package/cpp/ggml-impl.h +567 -0
  33. package/cpp/ggml-metal-impl.h +288 -0
  34. package/cpp/ggml-metal.h +24 -43
  35. package/cpp/ggml-metal.m +4867 -1080
  36. package/cpp/ggml-opt.cpp +854 -0
  37. package/cpp/ggml-opt.h +216 -0
  38. package/cpp/ggml-quants.c +5238 -0
  39. package/cpp/ggml-quants.h +100 -0
  40. package/cpp/ggml-threading.cpp +12 -0
  41. package/cpp/ggml-threading.h +14 -0
  42. package/cpp/ggml-whisper.metallib +0 -0
  43. package/cpp/ggml.c +5106 -19431
  44. package/cpp/ggml.h +847 -669
  45. package/cpp/gguf.cpp +1329 -0
  46. package/cpp/gguf.h +202 -0
  47. package/cpp/rn-audioutils.cpp +68 -0
  48. package/cpp/rn-audioutils.h +14 -0
  49. package/cpp/rn-whisper-log.h +11 -0
  50. package/cpp/rn-whisper.cpp +221 -52
  51. package/cpp/rn-whisper.h +50 -15
  52. package/cpp/whisper.cpp +3174 -1533
  53. package/cpp/whisper.h +176 -44
  54. package/ios/RNWhisper.mm +139 -46
  55. package/ios/RNWhisperAudioUtils.h +1 -2
  56. package/ios/RNWhisperAudioUtils.m +18 -67
  57. package/ios/RNWhisperContext.h +11 -8
  58. package/ios/RNWhisperContext.mm +195 -150
  59. package/jest/mock.js +15 -2
  60. package/lib/commonjs/NativeRNWhisper.js.map +1 -1
  61. package/lib/commonjs/index.js +76 -28
  62. package/lib/commonjs/index.js.map +1 -1
  63. package/lib/commonjs/version.json +1 -1
  64. package/lib/module/NativeRNWhisper.js.map +1 -1
  65. package/lib/module/index.js +76 -28
  66. package/lib/module/index.js.map +1 -1
  67. package/lib/module/version.json +1 -1
  68. package/lib/typescript/NativeRNWhisper.d.ts +13 -4
  69. package/lib/typescript/NativeRNWhisper.d.ts.map +1 -1
  70. package/lib/typescript/index.d.ts +37 -5
  71. package/lib/typescript/index.d.ts.map +1 -1
  72. package/package.json +9 -7
  73. package/src/NativeRNWhisper.ts +20 -4
  74. package/src/index.ts +98 -42
  75. package/src/version.json +1 -1
  76. package/whisper-rn.podspec +13 -20
  77. package/cpp/README.md +0 -4
  78. package/cpp/ggml-metal.metal +0 -2353
package/cpp/ggml.h CHANGED
@@ -58,7 +58,8 @@
 //   {
 //       ...
 //
-//       struct wsp_ggml_cgraph gf = wsp_ggml_build_forward(f);
+//       struct wsp_ggml_cgraph * gf = wsp_ggml_new_graph(ctx);
+//       wsp_ggml_build_forward_expand(gf, f);
 //
 //       // set the input variable and parameter values
 //       wsp_ggml_set_f32(x, 2.0f);
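The updated usage comment above reflects the reworked graph API in this release: a computation graph is now allocated from the context with wsp_ggml_new_graph() and populated with wsp_ggml_build_forward_expand(), instead of the old by-value wsp_ggml_build_forward(). Below is a minimal sketch (not taken from the package) of the new flow; wsp_ggml_new_tensor_1d() and wsp_ggml_add() are assumed helpers from the rest of this header, and the actual compute/backend setup (now handled by the new ggml-backend files in this release) is omitted.

    // Sketch only: builds a forward graph for f = a + b with the new graph API.
    #include "ggml.h"

    static void build_forward_example(void) {
        struct wsp_ggml_init_params params = {
            /*.mem_size   =*/ 16 * 1024 * 1024,  // 16 MiB memory pool
            /*.mem_buffer =*/ NULL,              // let ggml allocate the pool internally
            /*.no_alloc   =*/ false,
        };
        struct wsp_ggml_context * ctx = wsp_ggml_init(params);

        // assumed tensor constructors/ops from ggml.h
        struct wsp_ggml_tensor * a = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, 8);
        struct wsp_ggml_tensor * b = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, 8);
        struct wsp_ggml_tensor * f = wsp_ggml_add(ctx, a, b);

        // old API: struct wsp_ggml_cgraph gf = wsp_ggml_build_forward(f);
        struct wsp_ggml_cgraph * gf = wsp_ggml_new_graph(ctx);
        wsp_ggml_build_forward_expand(gf, f);

        wsp_ggml_free(ctx);
    }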
@@ -175,15 +176,15 @@
 #ifdef WSP_GGML_SHARED
 #    if defined(_WIN32) && !defined(__MINGW32__)
 #        ifdef WSP_GGML_BUILD
-#            define WSP_GGML_API __declspec(dllexport)
+#            define WSP_GGML_API __declspec(dllexport) extern
 #        else
-#            define WSP_GGML_API __declspec(dllimport)
+#            define WSP_GGML_API __declspec(dllimport) extern
 #        endif
 #    else
-#        define WSP_GGML_API __attribute__ ((visibility ("default")))
+#        define WSP_GGML_API __attribute__ ((visibility ("default"))) extern
 #    endif
 #else
-#    define WSP_GGML_API
+#    define WSP_GGML_API extern
 #endif
 
 // TODO: support for clang
@@ -197,30 +198,35 @@
 
 #ifndef __GNUC__
 #    define WSP_GGML_ATTRIBUTE_FORMAT(...)
-#elif defined(__MINGW32__)
+#elif defined(__MINGW32__) && !defined(__clang__)
 #    define WSP_GGML_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__)))
 #else
 #    define WSP_GGML_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__)))
 #endif
 
-#include <stdint.h>
-#include <stddef.h>
 #include <stdbool.h>
+#include <stddef.h>
+#include <stdint.h>
+#include <stdio.h>
 
 #define WSP_GGML_FILE_MAGIC   0x67676d6c // "ggml"
-#define WSP_GGML_FILE_VERSION 1
+#define WSP_GGML_FILE_VERSION 2
 
 #define WSP_GGML_QNT_VERSION        2    // bump this on quantization format changes
 #define WSP_GGML_QNT_VERSION_FACTOR 1000 // do not change this
 
-#define WSP_GGML_MAX_DIMS          4
-#define WSP_GGML_MAX_NODES         4096
-#define WSP_GGML_MAX_PARAMS        256
-#define WSP_GGML_MAX_CONTEXTS      64
-#define WSP_GGML_MAX_SRC           6
-#define WSP_GGML_MAX_NAME          64
-#define WSP_GGML_MAX_OP_PARAMS     32
-#define WSP_GGML_DEFAULT_N_THREADS 4
+#define WSP_GGML_MAX_DIMS           4
+#define WSP_GGML_MAX_PARAMS         2048
+#define WSP_GGML_MAX_SRC            10
+#define WSP_GGML_MAX_N_THREADS      512
+#define WSP_GGML_MAX_OP_PARAMS      64
+
+#ifndef WSP_GGML_MAX_NAME
+#    define WSP_GGML_MAX_NAME       64
+#endif
+
+#define WSP_GGML_DEFAULT_N_THREADS  4
+#define WSP_GGML_DEFAULT_GRAPH_SIZE 2048
 
 #if UINTPTR_MAX == 0xFFFFFFFF
     #define WSP_GGML_MEM_ALIGN 4
@@ -231,22 +237,34 @@
 #define WSP_GGML_EXIT_SUCCESS 0
 #define WSP_GGML_EXIT_ABORTED 1
 
-#define GGUF_MAGIC 0x46554747 // "GGUF"
-#define GGUF_VERSION 2
-
-#define GGUF_DEFAULT_ALIGNMENT 32
+#define WSP_GGML_ROPE_TYPE_NEOX   2
+#define WSP_GGML_ROPE_TYPE_MROPE  8
+#define WSP_GGML_ROPE_TYPE_VISION 24
 
 #define WSP_GGML_UNUSED(x) (void)(x)
 
 #define WSP_GGML_PAD(x, n) (((x) + (n) - 1) & ~((n) - 1))
 
-#define WSP_GGML_ASSERT(x) \
-    do { \
-        if (!(x)) { \
-            fprintf(stderr, "WSP_GGML_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \
-            abort(); \
-        } \
-    } while (0)
+#ifndef NDEBUG
+#    define WSP_GGML_UNREACHABLE() do { fprintf(stderr, "statement should be unreachable\n"); abort(); } while(0)
+#elif defined(__GNUC__)
+#    define WSP_GGML_UNREACHABLE() __builtin_unreachable()
+#elif defined(_MSC_VER)
+#    define WSP_GGML_UNREACHABLE() __assume(0)
+#else
+#    define WSP_GGML_UNREACHABLE() ((void) 0)
+#endif
+
+#ifdef __cplusplus
+#    define WSP_GGML_NORETURN [[noreturn]]
+#elif defined(_MSC_VER)
+#    define WSP_GGML_NORETURN __declspec(noreturn)
+#else
+#    define WSP_GGML_NORETURN _Noreturn
+#endif
+
+#define WSP_GGML_ABORT(...) wsp_ggml_abort(__FILE__, __LINE__, __VA_ARGS__)
+#define WSP_GGML_ASSERT(x) if (!(x)) WSP_GGML_ABORT("WSP_GGML_ASSERT(%s) failed", #x)
 
 // used to copy the number of elements and stride in bytes of tensors into local variables.
 // main purpose is to reduce code duplication and improve readability.
@@ -272,74 +290,139 @@
     const type prefix##3 = (pointer)->array[3]; \
     WSP_GGML_UNUSED(prefix##3);
 
+#define WSP_GGML_TENSOR_UNARY_OP_LOCALS \
+    WSP_GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \
+    WSP_GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb) \
+    WSP_GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne) \
+    WSP_GGML_TENSOR_LOCALS(size_t,  nb,  dst,  nb)
+
+#define WSP_GGML_TENSOR_BINARY_OP_LOCALS \
+    WSP_GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \
+    WSP_GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb) \
+    WSP_GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne) \
+    WSP_GGML_TENSOR_LOCALS(size_t,  nb1, src1, nb) \
+    WSP_GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne) \
+    WSP_GGML_TENSOR_LOCALS(size_t,  nb,  dst,  nb)
+
+#define WSP_GGML_TENSOR_BINARY_OP_LOCALS01 \
+    WSP_GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \
+    WSP_GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb) \
+    WSP_GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne) \
+    WSP_GGML_TENSOR_LOCALS(size_t,  nb1, src1, nb)
+
 #ifdef __cplusplus
 extern "C" {
 #endif
 
-#if defined(__ARM_NEON) && defined(__CUDACC__)
-    typedef half wsp_ggml_fp16_t;
-#elif defined(__ARM_NEON)
-    typedef __fp16 wsp_ggml_fp16_t;
-#else
-    typedef uint16_t wsp_ggml_fp16_t;
-#endif
+    WSP_GGML_NORETURN WSP_GGML_ATTRIBUTE_FORMAT(3, 4)
+    WSP_GGML_API void wsp_ggml_abort(const char * file, int line, const char * fmt, ...);
+
+    enum wsp_ggml_status {
+        WSP_GGML_STATUS_ALLOC_FAILED = -2,
+        WSP_GGML_STATUS_FAILED       = -1,
+        WSP_GGML_STATUS_SUCCESS      = 0,
+        WSP_GGML_STATUS_ABORTED      = 1,
+    };
 
-    // convert FP16 <-> FP32
-    WSP_GGML_API float           wsp_ggml_fp16_to_fp32(wsp_ggml_fp16_t x);
-    WSP_GGML_API wsp_ggml_fp16_t wsp_ggml_fp32_to_fp16(float x);
+    // get wsp_ggml_status name string
+    WSP_GGML_API const char * wsp_ggml_status_to_string(enum wsp_ggml_status status);
 
-    WSP_GGML_API void wsp_ggml_fp16_to_fp32_row(const wsp_ggml_fp16_t * x, float * y, int n);
-    WSP_GGML_API void wsp_ggml_fp32_to_fp16_row(const float * x, wsp_ggml_fp16_t * y, int n);
+    // ieee 754-2008 half-precision float16
+    // todo: make this not an integral type
+    typedef uint16_t wsp_ggml_fp16_t;
+    WSP_GGML_API float           wsp_ggml_fp16_to_fp32(wsp_ggml_fp16_t);
+    WSP_GGML_API wsp_ggml_fp16_t wsp_ggml_fp32_to_fp16(float);
+    WSP_GGML_API void            wsp_ggml_fp16_to_fp32_row(const wsp_ggml_fp16_t *, float *, int64_t);
+    WSP_GGML_API void            wsp_ggml_fp32_to_fp16_row(const float *, wsp_ggml_fp16_t *, int64_t);
+
+    // google brain half-precision bfloat16
+    typedef struct { uint16_t bits; } wsp_ggml_bf16_t;
+    WSP_GGML_API wsp_ggml_bf16_t wsp_ggml_fp32_to_bf16(float);
+    WSP_GGML_API float           wsp_ggml_bf16_to_fp32(wsp_ggml_bf16_t); // consider just doing << 16
+    WSP_GGML_API void            wsp_ggml_bf16_to_fp32_row(const wsp_ggml_bf16_t *, float *, int64_t);
+    WSP_GGML_API void            wsp_ggml_fp32_to_bf16_row_ref(const float *, wsp_ggml_bf16_t *, int64_t);
+    WSP_GGML_API void            wsp_ggml_fp32_to_bf16_row(const float *, wsp_ggml_bf16_t *, int64_t);
 
     struct wsp_ggml_object;
     struct wsp_ggml_context;
+    struct wsp_ggml_cgraph;
 
+    // NOTE: always add types at the end of the enum to keep backward compatibility
     enum wsp_ggml_type {
-        WSP_GGML_TYPE_F32 = 0,
-        WSP_GGML_TYPE_F16 = 1,
-        WSP_GGML_TYPE_Q4_0 = 2,
-        WSP_GGML_TYPE_Q4_1 = 3,
+        WSP_GGML_TYPE_F32     = 0,
+        WSP_GGML_TYPE_F16     = 1,
+        WSP_GGML_TYPE_Q4_0    = 2,
+        WSP_GGML_TYPE_Q4_1    = 3,
         // WSP_GGML_TYPE_Q4_2 = 4, support has been removed
-        // WSP_GGML_TYPE_Q4_3 (5) support has been removed
-        WSP_GGML_TYPE_Q5_0 = 6,
-        WSP_GGML_TYPE_Q5_1 = 7,
-        WSP_GGML_TYPE_Q8_0 = 8,
-        WSP_GGML_TYPE_Q8_1 = 9,
-        // k-quantizations
-        WSP_GGML_TYPE_Q2_K = 10,
-        WSP_GGML_TYPE_Q3_K = 11,
-        WSP_GGML_TYPE_Q4_K = 12,
-        WSP_GGML_TYPE_Q5_K = 13,
-        WSP_GGML_TYPE_Q6_K = 14,
-        WSP_GGML_TYPE_Q8_K = 15,
-        WSP_GGML_TYPE_I8,
-        WSP_GGML_TYPE_I16,
-        WSP_GGML_TYPE_I32,
-        WSP_GGML_TYPE_COUNT,
+        // WSP_GGML_TYPE_Q4_3 = 5, support has been removed
+        WSP_GGML_TYPE_Q5_0    = 6,
+        WSP_GGML_TYPE_Q5_1    = 7,
+        WSP_GGML_TYPE_Q8_0    = 8,
+        WSP_GGML_TYPE_Q8_1    = 9,
+        WSP_GGML_TYPE_Q2_K    = 10,
+        WSP_GGML_TYPE_Q3_K    = 11,
+        WSP_GGML_TYPE_Q4_K    = 12,
+        WSP_GGML_TYPE_Q5_K    = 13,
+        WSP_GGML_TYPE_Q6_K    = 14,
+        WSP_GGML_TYPE_Q8_K    = 15,
+        WSP_GGML_TYPE_IQ2_XXS = 16,
+        WSP_GGML_TYPE_IQ2_XS  = 17,
+        WSP_GGML_TYPE_IQ3_XXS = 18,
+        WSP_GGML_TYPE_IQ1_S   = 19,
+        WSP_GGML_TYPE_IQ4_NL  = 20,
+        WSP_GGML_TYPE_IQ3_S   = 21,
+        WSP_GGML_TYPE_IQ2_S   = 22,
+        WSP_GGML_TYPE_IQ4_XS  = 23,
+        WSP_GGML_TYPE_I8      = 24,
+        WSP_GGML_TYPE_I16     = 25,
+        WSP_GGML_TYPE_I32     = 26,
+        WSP_GGML_TYPE_I64     = 27,
+        WSP_GGML_TYPE_F64     = 28,
+        WSP_GGML_TYPE_IQ1_M   = 29,
+        WSP_GGML_TYPE_BF16    = 30,
+        // WSP_GGML_TYPE_Q4_0_4_4 = 31, support has been removed from gguf files
+        // WSP_GGML_TYPE_Q4_0_4_8 = 32,
+        // WSP_GGML_TYPE_Q4_0_8_8 = 33,
+        WSP_GGML_TYPE_TQ1_0   = 34,
+        WSP_GGML_TYPE_TQ2_0   = 35,
+        // WSP_GGML_TYPE_IQ4_NL_4_4 = 36,
+        // WSP_GGML_TYPE_IQ4_NL_4_8 = 37,
+        // WSP_GGML_TYPE_IQ4_NL_8_8 = 38,
+        WSP_GGML_TYPE_COUNT   = 39,
     };
 
-    enum wsp_ggml_backend {
-        WSP_GGML_BACKEND_CPU = 0,
-        WSP_GGML_BACKEND_GPU = 10,
-        WSP_GGML_BACKEND_GPU_SPLIT = 20,
+    // precision
+    enum wsp_ggml_prec {
+        WSP_GGML_PREC_DEFAULT,
+        WSP_GGML_PREC_F32,
     };
 
     // model file types
     enum wsp_ggml_ftype {
-        WSP_GGML_FTYPE_UNKNOWN = -1,
-        WSP_GGML_FTYPE_ALL_F32 = 0,
-        WSP_GGML_FTYPE_MOSTLY_F16 = 1,  // except 1d tensors
-        WSP_GGML_FTYPE_MOSTLY_Q4_0 = 2, // except 1d tensors
-        WSP_GGML_FTYPE_MOSTLY_Q4_1 = 3, // except 1d tensors
+        WSP_GGML_FTYPE_UNKNOWN        = -1,
+        WSP_GGML_FTYPE_ALL_F32        = 0,
+        WSP_GGML_FTYPE_MOSTLY_F16     = 1,  // except 1d tensors
+        WSP_GGML_FTYPE_MOSTLY_Q4_0    = 2,  // except 1d tensors
+        WSP_GGML_FTYPE_MOSTLY_Q4_1    = 3,  // except 1d tensors
         WSP_GGML_FTYPE_MOSTLY_Q4_1_SOME_F16 = 4, // tok_embeddings.weight and output.weight are F16
-        WSP_GGML_FTYPE_MOSTLY_Q8_0 = 7, // except 1d tensors
-        WSP_GGML_FTYPE_MOSTLY_Q5_0 = 8, // except 1d tensors
-        WSP_GGML_FTYPE_MOSTLY_Q5_1 = 9, // except 1d tensors
-        WSP_GGML_FTYPE_MOSTLY_Q2_K = 10, // except 1d tensors
-        WSP_GGML_FTYPE_MOSTLY_Q3_K = 11, // except 1d tensors
-        WSP_GGML_FTYPE_MOSTLY_Q4_K = 12, // except 1d tensors
-        WSP_GGML_FTYPE_MOSTLY_Q5_K = 13, // except 1d tensors
-        WSP_GGML_FTYPE_MOSTLY_Q6_K = 14, // except 1d tensors
+        WSP_GGML_FTYPE_MOSTLY_Q8_0    = 7,  // except 1d tensors
+        WSP_GGML_FTYPE_MOSTLY_Q5_0    = 8,  // except 1d tensors
+        WSP_GGML_FTYPE_MOSTLY_Q5_1    = 9,  // except 1d tensors
+        WSP_GGML_FTYPE_MOSTLY_Q2_K    = 10, // except 1d tensors
+        WSP_GGML_FTYPE_MOSTLY_Q3_K    = 11, // except 1d tensors
+        WSP_GGML_FTYPE_MOSTLY_Q4_K    = 12, // except 1d tensors
+        WSP_GGML_FTYPE_MOSTLY_Q5_K    = 13, // except 1d tensors
+        WSP_GGML_FTYPE_MOSTLY_Q6_K    = 14, // except 1d tensors
+        WSP_GGML_FTYPE_MOSTLY_IQ2_XXS = 15, // except 1d tensors
+        WSP_GGML_FTYPE_MOSTLY_IQ2_XS  = 16, // except 1d tensors
+        WSP_GGML_FTYPE_MOSTLY_IQ3_XXS = 17, // except 1d tensors
+        WSP_GGML_FTYPE_MOSTLY_IQ1_S   = 18, // except 1d tensors
+        WSP_GGML_FTYPE_MOSTLY_IQ4_NL  = 19, // except 1d tensors
+        WSP_GGML_FTYPE_MOSTLY_IQ3_S   = 20, // except 1d tensors
+        WSP_GGML_FTYPE_MOSTLY_IQ2_S   = 21, // except 1d tensors
+        WSP_GGML_FTYPE_MOSTLY_IQ4_XS  = 22, // except 1d tensors
+        WSP_GGML_FTYPE_MOSTLY_IQ1_M   = 23, // except 1d tensors
+        WSP_GGML_FTYPE_MOSTLY_BF16    = 24, // except 1d tensors
     };
 
     // available tensor operations:
@@ -356,10 +439,13 @@ extern "C" {
         WSP_GGML_OP_SQR,
         WSP_GGML_OP_SQRT,
         WSP_GGML_OP_LOG,
+        WSP_GGML_OP_SIN,
+        WSP_GGML_OP_COS,
         WSP_GGML_OP_SUM,
         WSP_GGML_OP_SUM_ROWS,
         WSP_GGML_OP_MEAN,
         WSP_GGML_OP_ARGMAX,
+        WSP_GGML_OP_COUNT_EQUAL,
        WSP_GGML_OP_REPEAT,
         WSP_GGML_OP_REPEAT_BACK,
         WSP_GGML_OP_CONCAT,
@@ -370,6 +456,7 @@ extern "C" {
         WSP_GGML_OP_GROUP_NORM,
 
         WSP_GGML_OP_MUL_MAT,
+        WSP_GGML_OP_MUL_MAT_ID,
         WSP_GGML_OP_OUT_PROD,
 
         WSP_GGML_OP_SCALE,
@@ -389,23 +476,32 @@ extern "C" {
         WSP_GGML_OP_SOFT_MAX_BACK,
         WSP_GGML_OP_ROPE,
         WSP_GGML_OP_ROPE_BACK,
-        WSP_GGML_OP_ALIBI,
         WSP_GGML_OP_CLAMP,
-        WSP_GGML_OP_CONV_1D,
-        WSP_GGML_OP_CONV_2D,
+        WSP_GGML_OP_CONV_TRANSPOSE_1D,
+        WSP_GGML_OP_IM2COL,
+        WSP_GGML_OP_IM2COL_BACK,
         WSP_GGML_OP_CONV_TRANSPOSE_2D,
         WSP_GGML_OP_POOL_1D,
         WSP_GGML_OP_POOL_2D,
-
+        WSP_GGML_OP_POOL_2D_BACK,
         WSP_GGML_OP_UPSCALE, // nearest interpolate
-
-        WSP_GGML_OP_FLASH_ATTN,
-        WSP_GGML_OP_FLASH_FF,
+        WSP_GGML_OP_PAD,
+        WSP_GGML_OP_PAD_REFLECT_1D,
+        WSP_GGML_OP_ARANGE,
+        WSP_GGML_OP_TIMESTEP_EMBEDDING,
+        WSP_GGML_OP_ARGSORT,
+        WSP_GGML_OP_LEAKY_RELU,
+
+        WSP_GGML_OP_FLASH_ATTN_EXT,
         WSP_GGML_OP_FLASH_ATTN_BACK,
+        WSP_GGML_OP_SSM_CONV,
+        WSP_GGML_OP_SSM_SCAN,
         WSP_GGML_OP_WIN_PART,
         WSP_GGML_OP_WIN_UNPART,
         WSP_GGML_OP_GET_REL_POS,
         WSP_GGML_OP_ADD_REL_POS,
+        WSP_GGML_OP_RWKV_WKV6,
+        WSP_GGML_OP_GATED_LINEAR_ATTN,
 
         WSP_GGML_OP_UNARY,
 
@@ -422,6 +518,7 @@ extern "C" {
 
         WSP_GGML_OP_CROSS_ENTROPY_LOSS,
         WSP_GGML_OP_CROSS_ENTROPY_LOSS_BACK,
+        WSP_GGML_OP_OPT_STEP_ADAMW,
 
         WSP_GGML_OP_COUNT,
     };
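Several of the op-enum additions above (SIN, COS, COUNT_EQUAL, ...) correspond to new builder functions declared further down in this header. A hedged sketch of how they compose, assuming a context `ctx` and two same-shape F32 tensors `a` and `b` already exist:

    // Sketch only: element-wise sin/cos and equality count, using the new
    // wsp_ggml_sin(), wsp_ggml_cos() and wsp_ggml_count_equal() builders declared below.
    struct wsp_ggml_tensor * s  = wsp_ggml_sin(ctx, a);           // sin(a), element-wise
    struct wsp_ggml_tensor * c  = wsp_ggml_cos(ctx, a);           // cos(a), element-wise
    struct wsp_ggml_tensor * eq = wsp_ggml_count_equal(ctx, a, b); // count of equal elements in a and b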
@@ -434,41 +531,57 @@ extern "C" {
         WSP_GGML_UNARY_OP_TANH,
         WSP_GGML_UNARY_OP_ELU,
         WSP_GGML_UNARY_OP_RELU,
+        WSP_GGML_UNARY_OP_SIGMOID,
         WSP_GGML_UNARY_OP_GELU,
         WSP_GGML_UNARY_OP_GELU_QUICK,
         WSP_GGML_UNARY_OP_SILU,
+        WSP_GGML_UNARY_OP_HARDSWISH,
+        WSP_GGML_UNARY_OP_HARDSIGMOID,
+        WSP_GGML_UNARY_OP_EXP,
+
+        WSP_GGML_UNARY_OP_COUNT,
     };
 
     enum wsp_ggml_object_type {
-        WSP_GGML_OBJECT_TENSOR,
-        WSP_GGML_OBJECT_GRAPH,
-        WSP_GGML_OBJECT_WORK_BUFFER
+        WSP_GGML_OBJECT_TYPE_TENSOR,
+        WSP_GGML_OBJECT_TYPE_GRAPH,
+        WSP_GGML_OBJECT_TYPE_WORK_BUFFER
     };
 
-    // ggml object
-    struct wsp_ggml_object {
-        size_t offs;
-        size_t size;
-
-        struct wsp_ggml_object * next;
-
-        enum wsp_ggml_object_type type;
+    enum wsp_ggml_log_level {
+        WSP_GGML_LOG_LEVEL_NONE  = 0,
+        WSP_GGML_LOG_LEVEL_DEBUG = 1,
+        WSP_GGML_LOG_LEVEL_INFO  = 2,
+        WSP_GGML_LOG_LEVEL_WARN  = 3,
+        WSP_GGML_LOG_LEVEL_ERROR = 4,
+        WSP_GGML_LOG_LEVEL_CONT  = 5, // continue previous log
+    };
 
-        char padding[4];
+    // this tensor...
+    enum wsp_ggml_tensor_flag {
+        WSP_GGML_TENSOR_FLAG_INPUT  = 1, // ...is an input for the GGML compute graph
+        WSP_GGML_TENSOR_FLAG_OUTPUT = 2, // ...is an output for the GGML compute graph
+        WSP_GGML_TENSOR_FLAG_PARAM  = 4, // ...contains trainable parameters
+        WSP_GGML_TENSOR_FLAG_LOSS   = 8, // ...defines loss for numerical optimization (multiple loss tensors add up)
     };
 
-    static const size_t WSP_GGML_OBJECT_SIZE = sizeof(struct wsp_ggml_object);
+    struct wsp_ggml_init_params {
+        // memory pool
+        size_t mem_size;   // bytes
+        void * mem_buffer; // if NULL, memory will be allocated internally
+        bool   no_alloc;   // don't allocate memory for the tensor data
+    };
 
     // n-dimensional tensor
     struct wsp_ggml_tensor {
-        enum wsp_ggml_type type;
-        enum wsp_ggml_backend backend;
+        enum wsp_ggml_type type;
+
+        struct wsp_ggml_backend_buffer * buffer;
 
-        int n_dims;
         int64_t ne[WSP_GGML_MAX_DIMS]; // number of elements
         size_t  nb[WSP_GGML_MAX_DIMS]; // stride in bytes:
-                                       // nb[0] = sizeof(type)
-                                       // nb[1] = nb[0] * ne[0] + padding
+                                       // nb[0] = wsp_ggml_type_size(type)
+                                       // nb[1] = nb[0] * (ne[0] / wsp_ggml_blck_size(type)) + padding
                                       // nb[i] = nb[i-1] * ne[i-1]
 
         // compute data
@@ -477,16 +590,11 @@ extern "C" {
477
590
  // op params - allocated as int32_t for alignment
478
591
  int32_t op_params[WSP_GGML_MAX_OP_PARAMS / sizeof(int32_t)];
479
592
 
480
- bool is_param;
593
+ int32_t flags;
481
594
 
482
- struct wsp_ggml_tensor * grad;
483
595
  struct wsp_ggml_tensor * src[WSP_GGML_MAX_SRC];
484
596
 
485
- // performance
486
- int perf_runs;
487
- int64_t perf_cycles;
488
- int64_t perf_time_us;
489
-
597
+ // source tensor and offset for views
490
598
  struct wsp_ggml_tensor * view_src;
491
599
  size_t view_offs;
492
600
 
@@ -496,86 +604,26 @@ extern "C" {
496
604
 
497
605
  void * extra; // extra things e.g. for ggml-cuda.cu
498
606
 
499
- char padding[4];
607
+ char padding[8];
500
608
  };
501
609
 
502
610
  static const size_t WSP_GGML_TENSOR_SIZE = sizeof(struct wsp_ggml_tensor);
503
611
 
504
- // the compute plan that needs to be prepared for wsp_ggml_graph_compute()
505
- // since https://github.com/ggerganov/ggml/issues/287
506
- struct wsp_ggml_cplan {
507
- size_t work_size; // size of work buffer, calculated by `wsp_ggml_graph_plan()`
508
- uint8_t * work_data; // work buffer, to be allocated by caller before calling to `wsp_ggml_graph_compute()`
509
-
510
- int n_threads;
511
-
512
- // the `n_tasks` of nodes, 1:1 mapping to cgraph nodes
513
- int n_tasks[WSP_GGML_MAX_NODES];
514
-
515
- // abort wsp_ggml_graph_compute when true
516
- bool (*abort_callback)(void * data);
517
- void * abort_callback_data;
518
- };
519
-
520
- // next prime after WSP_GGML_MAX_NODES
521
- // #define WSP_GGML_GRAPH_HASHTABLE_SIZE 4099
522
- // next prime after WSP_GGML_MAX_NODES * 2 (nodes + leafs)
523
- #define WSP_GGML_GRAPH_HASHTABLE_SIZE 8273
524
-
525
- // computation graph
526
- struct wsp_ggml_cgraph {
527
- int n_nodes;
528
- int n_leafs;
529
-
530
- struct wsp_ggml_tensor * nodes[WSP_GGML_MAX_NODES];
531
- struct wsp_ggml_tensor * grads[WSP_GGML_MAX_NODES];
532
- struct wsp_ggml_tensor * leafs[WSP_GGML_MAX_NODES];
612
+ // Abort callback
613
+ // If not NULL, called before ggml computation
614
+ // If it returns true, the computation is aborted
615
+ typedef bool (*wsp_ggml_abort_callback)(void * data);
533
616
 
534
- void * visited_hash_table[WSP_GGML_GRAPH_HASHTABLE_SIZE];
535
617
 
536
- // performance
537
- int perf_runs;
538
- int64_t perf_cycles;
539
- int64_t perf_time_us;
540
- };
541
-
542
- static const size_t WSP_GGML_GRAPH_SIZE = sizeof(struct wsp_ggml_cgraph);
543
-
544
- // scratch buffer
545
- struct wsp_ggml_scratch {
546
- size_t offs;
547
- size_t size;
548
- void * data;
549
- };
550
-
551
- struct wsp_ggml_init_params {
552
- // memory pool
553
- size_t mem_size; // bytes
554
- void * mem_buffer; // if NULL, memory will be allocated internally
555
- bool no_alloc; // don't allocate memory for the tensor data
556
- };
557
-
558
-
559
- // compute types
560
-
561
- // NOTE: the INIT or FINALIZE pass is not scheduled unless explicitly enabled.
562
- // This behavior was changed since https://github.com/ggerganov/llama.cpp/pull/1995.
563
- enum wsp_ggml_task_type {
564
- WSP_GGML_TASK_INIT = 0,
565
- WSP_GGML_TASK_COMPUTE,
566
- WSP_GGML_TASK_FINALIZE,
567
- };
568
-
569
- struct wsp_ggml_compute_params {
570
- enum wsp_ggml_task_type type;
618
+ //
619
+ // GUID
620
+ //
571
621
 
572
- // ith = thread index, nth = number of threads
573
- int ith, nth;
622
+ // GUID types
623
+ typedef uint8_t wsp_ggml_guid[16];
624
+ typedef wsp_ggml_guid * wsp_ggml_guid_t;
574
625
 
575
- // work buffer for all threads
576
- size_t wsize;
577
- void * wdata;
578
- };
626
+ WSP_GGML_API bool wsp_ggml_guid_matches(wsp_ggml_guid_t guid_a, wsp_ggml_guid_t guid_b);
579
627
 
580
628
  // misc
581
629
 
@@ -585,26 +633,32 @@ extern "C" {
585
633
  WSP_GGML_API int64_t wsp_ggml_cycles(void);
586
634
  WSP_GGML_API int64_t wsp_ggml_cycles_per_ms(void);
587
635
 
588
- WSP_GGML_API void wsp_ggml_numa_init(void); // call once for better performance on NUMA systems
589
- WSP_GGML_API bool wsp_ggml_is_numa(void); // true if init detected that system has >1 NUMA node
636
+ // accepts a UTF-8 path, even on Windows
637
+ WSP_GGML_API FILE * wsp_ggml_fopen(const char * fname, const char * mode);
590
638
 
591
639
  WSP_GGML_API void wsp_ggml_print_object (const struct wsp_ggml_object * obj);
592
640
  WSP_GGML_API void wsp_ggml_print_objects(const struct wsp_ggml_context * ctx);
593
641
 
594
- WSP_GGML_API int64_t wsp_ggml_nelements (const struct wsp_ggml_tensor * tensor);
595
- WSP_GGML_API int64_t wsp_ggml_nrows (const struct wsp_ggml_tensor * tensor);
596
- WSP_GGML_API size_t wsp_ggml_nbytes (const struct wsp_ggml_tensor * tensor);
597
- WSP_GGML_API size_t wsp_ggml_nbytes_pad (const struct wsp_ggml_tensor * tensor); // same as wsp_ggml_nbytes() but padded to WSP_GGML_MEM_ALIGN
598
- WSP_GGML_API size_t wsp_ggml_nbytes_split(const struct wsp_ggml_tensor * tensor, int nrows_split);
642
+ WSP_GGML_API int64_t wsp_ggml_nelements (const struct wsp_ggml_tensor * tensor);
643
+ WSP_GGML_API int64_t wsp_ggml_nrows (const struct wsp_ggml_tensor * tensor);
644
+ WSP_GGML_API size_t wsp_ggml_nbytes (const struct wsp_ggml_tensor * tensor);
645
+ WSP_GGML_API size_t wsp_ggml_nbytes_pad(const struct wsp_ggml_tensor * tensor); // same as wsp_ggml_nbytes() but padded to WSP_GGML_MEM_ALIGN
599
646
 
600
- WSP_GGML_API int wsp_ggml_blck_size (enum wsp_ggml_type type);
601
- WSP_GGML_API size_t wsp_ggml_type_size (enum wsp_ggml_type type); // size in bytes for all elements in a block
602
- WSP_GGML_API float wsp_ggml_type_sizef(enum wsp_ggml_type type); // wsp_ggml_type_size()/wsp_ggml_blck_size() as float
647
+ WSP_GGML_API int64_t wsp_ggml_blck_size(enum wsp_ggml_type type);
648
+ WSP_GGML_API size_t wsp_ggml_type_size(enum wsp_ggml_type type); // size in bytes for all elements in a block
649
+ WSP_GGML_API size_t wsp_ggml_row_size (enum wsp_ggml_type type, int64_t ne); // size in bytes for all elements in a row
650
+
651
+ WSP_GGML_DEPRECATED(
652
+ WSP_GGML_API double wsp_ggml_type_sizef(enum wsp_ggml_type type), // wsp_ggml_type_size()/wsp_ggml_blck_size() as float
653
+ "use wsp_ggml_row_size() instead");
603
654
 
604
655
  WSP_GGML_API const char * wsp_ggml_type_name(enum wsp_ggml_type type);
605
656
  WSP_GGML_API const char * wsp_ggml_op_name (enum wsp_ggml_op op);
606
657
  WSP_GGML_API const char * wsp_ggml_op_symbol(enum wsp_ggml_op op);
607
658
 
659
+ WSP_GGML_API const char * wsp_ggml_unary_op_name(enum wsp_ggml_unary_op op);
660
+ WSP_GGML_API const char * wsp_ggml_op_desc(const struct wsp_ggml_tensor * t); // unary or op name
661
+
608
662
  WSP_GGML_API size_t wsp_ggml_element_size(const struct wsp_ggml_tensor * tensor);
609
663
 
610
664
  WSP_GGML_API bool wsp_ggml_is_quantized(enum wsp_ggml_type type);
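The hunk above deprecates wsp_ggml_type_sizef() in favour of wsp_ggml_row_size(), which returns an exact byte count for a row of ne elements instead of a fractional per-element size. A hedged sketch of the migration, assuming a `type` and an element count `n_elems` are already in scope:

    // Sketch only: byte size of one row of n_elems elements of a given type.
    // old (deprecated): size_t row_bytes = (size_t) (wsp_ggml_type_sizef(type) * n_elems);
    size_t row_bytes = wsp_ggml_row_size(type, n_elems); // exact, also correct for block-quantized types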
@@ -613,22 +667,37 @@ extern "C" {
613
667
  WSP_GGML_API enum wsp_ggml_type wsp_ggml_ftype_to_wsp_ggml_type(enum wsp_ggml_ftype ftype);
614
668
 
615
669
  WSP_GGML_API bool wsp_ggml_is_transposed(const struct wsp_ggml_tensor * tensor);
616
- WSP_GGML_API bool wsp_ggml_is_contiguous(const struct wsp_ggml_tensor * tensor);
617
670
  WSP_GGML_API bool wsp_ggml_is_permuted (const struct wsp_ggml_tensor * tensor);
671
+ WSP_GGML_API bool wsp_ggml_is_empty (const struct wsp_ggml_tensor * tensor);
672
+ WSP_GGML_API bool wsp_ggml_is_scalar (const struct wsp_ggml_tensor * tensor);
673
+ WSP_GGML_API bool wsp_ggml_is_vector (const struct wsp_ggml_tensor * tensor);
674
+ WSP_GGML_API bool wsp_ggml_is_matrix (const struct wsp_ggml_tensor * tensor);
675
+ WSP_GGML_API bool wsp_ggml_is_3d (const struct wsp_ggml_tensor * tensor);
676
+ WSP_GGML_API int wsp_ggml_n_dims (const struct wsp_ggml_tensor * tensor); // returns 1 for scalars
618
677
 
619
- WSP_GGML_API bool wsp_ggml_are_same_shape(const struct wsp_ggml_tensor * t0, const struct wsp_ggml_tensor * t1);
678
+ WSP_GGML_API bool wsp_ggml_is_contiguous (const struct wsp_ggml_tensor * tensor);
679
+ WSP_GGML_API bool wsp_ggml_is_contiguous_0(const struct wsp_ggml_tensor * tensor); // same as wsp_ggml_is_contiguous()
680
+ WSP_GGML_API bool wsp_ggml_is_contiguous_1(const struct wsp_ggml_tensor * tensor); // contiguous for dims >= 1
681
+ WSP_GGML_API bool wsp_ggml_is_contiguous_2(const struct wsp_ggml_tensor * tensor); // contiguous for dims >= 2
682
+
683
+ WSP_GGML_API bool wsp_ggml_are_same_shape (const struct wsp_ggml_tensor * t0, const struct wsp_ggml_tensor * t1);
684
+ WSP_GGML_API bool wsp_ggml_are_same_stride(const struct wsp_ggml_tensor * t0, const struct wsp_ggml_tensor * t1);
685
+
686
+ WSP_GGML_API bool wsp_ggml_can_repeat(const struct wsp_ggml_tensor * t0, const struct wsp_ggml_tensor * t1);
620
687
 
621
688
  // use this to compute the memory overhead of a tensor
622
689
  WSP_GGML_API size_t wsp_ggml_tensor_overhead(void);
623
690
 
691
+ WSP_GGML_API bool wsp_ggml_validate_row_data(enum wsp_ggml_type type, const void * data, size_t nbytes);
692
+
624
693
  // main
625
694
 
626
- WSP_GGML_API struct wsp_ggml_context * wsp_ggml_init(struct wsp_ggml_init_params params);
627
- WSP_GGML_API void wsp_ggml_free(struct wsp_ggml_context * ctx);
695
+ WSP_GGML_API struct wsp_ggml_context * wsp_ggml_init (struct wsp_ggml_init_params params);
696
+ WSP_GGML_API void wsp_ggml_reset(struct wsp_ggml_context * ctx);
697
+ WSP_GGML_API void wsp_ggml_free (struct wsp_ggml_context * ctx);
628
698
 
629
699
  WSP_GGML_API size_t wsp_ggml_used_mem(const struct wsp_ggml_context * ctx);
630
700
 
631
- WSP_GGML_API size_t wsp_ggml_set_scratch (struct wsp_ggml_context * ctx, struct wsp_ggml_scratch scratch);
632
701
  WSP_GGML_API bool wsp_ggml_get_no_alloc(struct wsp_ggml_context * ctx);
633
702
  WSP_GGML_API void wsp_ggml_set_no_alloc(struct wsp_ggml_context * ctx, bool no_alloc);
634
703
 
@@ -668,34 +737,35 @@ extern "C" {
668
737
  int64_t ne2,
669
738
  int64_t ne3);
670
739
 
671
- WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_new_i32(struct wsp_ggml_context * ctx, int32_t value);
672
- WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_new_f32(struct wsp_ggml_context * ctx, float value);
740
+ WSP_GGML_API void * wsp_ggml_new_buffer(struct wsp_ggml_context * ctx, size_t nbytes);
673
741
 
674
742
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_dup_tensor (struct wsp_ggml_context * ctx, const struct wsp_ggml_tensor * src);
675
743
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_view_tensor(struct wsp_ggml_context * ctx, struct wsp_ggml_tensor * src);
676
744
 
745
+ // Context tensor enumeration and lookup
746
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_get_first_tensor(const struct wsp_ggml_context * ctx);
747
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_get_next_tensor (const struct wsp_ggml_context * ctx, struct wsp_ggml_tensor * tensor);
677
748
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_get_tensor(struct wsp_ggml_context * ctx, const char * name);
678
749
 
679
- WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_set_zero(struct wsp_ggml_tensor * tensor);
680
- WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_set_i32 (struct wsp_ggml_tensor * tensor, int32_t value);
681
- WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_set_f32 (struct wsp_ggml_tensor * tensor, float value);
682
-
683
- WSP_GGML_API int32_t wsp_ggml_get_i32_1d(const struct wsp_ggml_tensor * tensor, int i);
684
- WSP_GGML_API void wsp_ggml_set_i32_1d(const struct wsp_ggml_tensor * tensor, int i, int32_t value);
750
+ // Converts a flat index into coordinates
751
+ WSP_GGML_API void wsp_ggml_unravel_index(const struct wsp_ggml_tensor * tensor, int64_t i, int64_t * i0, int64_t * i1, int64_t * i2, int64_t * i3);
685
752
 
686
- WSP_GGML_API float wsp_ggml_get_f32_1d(const struct wsp_ggml_tensor * tensor, int i);
687
- WSP_GGML_API void wsp_ggml_set_f32_1d(const struct wsp_ggml_tensor * tensor, int i, float value);
753
+ WSP_GGML_API enum wsp_ggml_unary_op wsp_ggml_get_unary_op(const struct wsp_ggml_tensor * tensor);
688
754
 
689
755
  WSP_GGML_API void * wsp_ggml_get_data (const struct wsp_ggml_tensor * tensor);
690
756
  WSP_GGML_API float * wsp_ggml_get_data_f32(const struct wsp_ggml_tensor * tensor);
691
757
 
692
- WSP_GGML_API enum wsp_ggml_unary_op wsp_ggml_get_unary_op(const struct wsp_ggml_tensor * tensor);
693
-
694
758
  WSP_GGML_API const char * wsp_ggml_get_name (const struct wsp_ggml_tensor * tensor);
695
759
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_set_name ( struct wsp_ggml_tensor * tensor, const char * name);
696
760
  WSP_GGML_ATTRIBUTE_FORMAT(2, 3)
697
761
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_format_name( struct wsp_ggml_tensor * tensor, const char * fmt, ...);
698
762
 
763
+ // Tensor flags
764
+ WSP_GGML_API void wsp_ggml_set_input(struct wsp_ggml_tensor * tensor);
765
+ WSP_GGML_API void wsp_ggml_set_output(struct wsp_ggml_tensor * tensor);
766
+ WSP_GGML_API void wsp_ggml_set_param(struct wsp_ggml_context * ctx, struct wsp_ggml_tensor * tensor);
767
+ WSP_GGML_API void wsp_ggml_set_loss(struct wsp_ggml_tensor * tensor);
768
+
699
769
  //
700
770
  // operations on tensors with backpropagation
701
771
  //
@@ -719,6 +789,12 @@ extern "C" {
719
789
  struct wsp_ggml_tensor * a,
720
790
  struct wsp_ggml_tensor * b);
721
791
 
792
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_add_cast(
793
+ struct wsp_ggml_context * ctx,
794
+ struct wsp_ggml_tensor * a,
795
+ struct wsp_ggml_tensor * b,
796
+ enum wsp_ggml_type type);
797
+
722
798
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_add1(
723
799
  struct wsp_ggml_context * ctx,
724
800
  struct wsp_ggml_tensor * a,
@@ -729,6 +805,9 @@ extern "C" {
729
805
  struct wsp_ggml_tensor * a,
730
806
  struct wsp_ggml_tensor * b);
731
807
 
808
+ // dst = a
809
+ // view(dst, nb1, nb2, nb3, offset) += b
810
+ // return dst
732
811
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_acc(
733
812
  struct wsp_ggml_context * ctx,
734
813
  struct wsp_ggml_tensor * a,
@@ -801,6 +880,22 @@ extern "C" {
801
880
  struct wsp_ggml_context * ctx,
802
881
  struct wsp_ggml_tensor * a);
803
882
 
883
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_sin(
884
+ struct wsp_ggml_context * ctx,
885
+ struct wsp_ggml_tensor * a);
886
+
887
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_sin_inplace(
888
+ struct wsp_ggml_context * ctx,
889
+ struct wsp_ggml_tensor * a);
890
+
891
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_cos(
892
+ struct wsp_ggml_context * ctx,
893
+ struct wsp_ggml_tensor * a);
894
+
895
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_cos_inplace(
896
+ struct wsp_ggml_context * ctx,
897
+ struct wsp_ggml_tensor * a);
898
+
804
899
  // return scalar
805
900
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_sum(
806
901
  struct wsp_ggml_context * ctx,
@@ -821,6 +916,12 @@ extern "C" {
821
916
  struct wsp_ggml_context * ctx,
822
917
  struct wsp_ggml_tensor * a);
823
918
 
919
+ // count number of equal elements in a and b
920
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_count_equal(
921
+ struct wsp_ggml_context * ctx,
922
+ struct wsp_ggml_tensor * a,
923
+ struct wsp_ggml_tensor * b);
924
+
824
925
  // if a is the same shape as b, and a is not parameter, return a
825
926
  // otherwise, return a new tensor: repeat(a) to fit in b
826
927
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_repeat(
@@ -828,17 +929,19 @@ extern "C" {
828
929
  struct wsp_ggml_tensor * a,
829
930
  struct wsp_ggml_tensor * b);
830
931
 
932
+ // sums repetitions in a into shape of b
831
933
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_repeat_back(
832
934
  struct wsp_ggml_context * ctx,
833
935
  struct wsp_ggml_tensor * a,
834
936
  struct wsp_ggml_tensor * b);
835
937
 
836
- // concat a and b on dim 2
938
+ // concat a and b along dim
837
939
  // used in stable-diffusion
838
940
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_concat(
839
941
  struct wsp_ggml_context * ctx,
840
942
  struct wsp_ggml_tensor * a,
841
- struct wsp_ggml_tensor * b);
943
+ struct wsp_ggml_tensor * b,
944
+ int dim);
842
945
 
843
946
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_abs(
844
947
  struct wsp_ggml_context * ctx,
@@ -892,11 +995,22 @@ extern "C" {
892
995
  struct wsp_ggml_context * ctx,
893
996
  struct wsp_ggml_tensor * a);
894
997
 
998
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_leaky_relu(
999
+ struct wsp_ggml_context * ctx,
1000
+ struct wsp_ggml_tensor * a, float negative_slope, bool inplace);
1001
+
895
1002
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_relu_inplace(
896
1003
  struct wsp_ggml_context * ctx,
897
1004
  struct wsp_ggml_tensor * a);
898
1005
 
899
- // TODO: double-check this computation is correct
1006
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_sigmoid(
1007
+ struct wsp_ggml_context * ctx,
1008
+ struct wsp_ggml_tensor * a);
1009
+
1010
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_sigmoid_inplace(
1011
+ struct wsp_ggml_context * ctx,
1012
+ struct wsp_ggml_tensor * a);
1013
+
900
1014
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_gelu(
901
1015
  struct wsp_ggml_context * ctx,
902
1016
  struct wsp_ggml_tensor * a);
@@ -928,6 +1042,24 @@ extern "C" {
928
1042
  struct wsp_ggml_tensor * a,
929
1043
  struct wsp_ggml_tensor * b);
930
1044
 
1045
+ // hardswish(x) = x * relu6(x + 3) / 6
1046
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_hardswish(
1047
+ struct wsp_ggml_context * ctx,
1048
+ struct wsp_ggml_tensor * a);
1049
+
1050
+ // hardsigmoid(x) = relu6(x + 3) / 6
1051
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_hardsigmoid(
1052
+ struct wsp_ggml_context * ctx,
1053
+ struct wsp_ggml_tensor * a);
1054
+
1055
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_exp(
1056
+ struct wsp_ggml_context * ctx,
1057
+ struct wsp_ggml_tensor * a);
1058
+
1059
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_exp_inplace(
1060
+ struct wsp_ggml_context * ctx,
1061
+ struct wsp_ggml_tensor * a);
1062
+
931
1063
  // normalize along rows
932
1064
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_norm(
933
1065
  struct wsp_ggml_context * ctx,
@@ -951,16 +1083,17 @@ extern "C" {
951
1083
 
952
1084
  // group normalize along ne0*ne1*n_groups
953
1085
  // used in stable-diffusion
954
- // TODO: eps is hardcoded to 1e-6 for now
955
1086
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_group_norm(
956
1087
  struct wsp_ggml_context * ctx,
957
1088
  struct wsp_ggml_tensor * a,
958
- int n_groups);
1089
+ int n_groups,
1090
+ float eps);
959
1091
 
960
1092
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_group_norm_inplace(
961
1093
  struct wsp_ggml_context * ctx,
962
1094
  struct wsp_ggml_tensor * a,
963
- int n_groups);
1095
+ int n_groups,
1096
+ float eps);
964
1097
 
965
1098
  // a - x
966
1099
  // b - dy
@@ -970,14 +1103,27 @@ extern "C" {
970
1103
  struct wsp_ggml_tensor * b,
971
1104
  float eps);
972
1105
 
973
- // A: n columns, m rows
974
- // B: n columns, p rows (i.e. we transpose it internally)
975
- // result is m columns, p rows
1106
+ // A: k columns, n rows => [ne03, ne02, n, k]
1107
+ // B: k columns, m rows (i.e. we transpose it internally) => [ne03 * x, ne02 * y, m, k]
1108
+ // result is n columns, m rows => [ne03 * x, ne02 * y, m, n]
976
1109
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_mul_mat(
977
1110
  struct wsp_ggml_context * ctx,
978
1111
  struct wsp_ggml_tensor * a,
979
1112
  struct wsp_ggml_tensor * b);
980
1113
 
1114
+ // change the precision of a matrix multiplication
1115
+ // set to WSP_GGML_PREC_F32 for higher precision (useful for phi-2)
1116
+ WSP_GGML_API void wsp_ggml_mul_mat_set_prec(
1117
+ struct wsp_ggml_tensor * a,
1118
+ enum wsp_ggml_prec prec);
1119
+
1120
+ // indirect matrix multiplication
1121
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_mul_mat_id(
1122
+ struct wsp_ggml_context * ctx,
1123
+ struct wsp_ggml_tensor * as,
1124
+ struct wsp_ggml_tensor * b,
1125
+ struct wsp_ggml_tensor * ids);
1126
+
981
1127
  // A: m columns, n rows,
982
1128
  // B: p columns, n rows,
983
1129
  // result is m columns, p rows
@@ -993,13 +1139,13 @@ extern "C" {
993
1139
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_scale(
994
1140
  struct wsp_ggml_context * ctx,
995
1141
  struct wsp_ggml_tensor * a,
996
- struct wsp_ggml_tensor * b);
1142
+ float s);
997
1143
 
998
1144
  // in-place, returns view(a)
999
1145
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_scale_inplace(
1000
1146
  struct wsp_ggml_context * ctx,
1001
1147
  struct wsp_ggml_tensor * a,
1002
- struct wsp_ggml_tensor * b);
1148
+ float s);
1003
1149
 
1004
1150
  // b -> view(a,offset,nb1,nb2,3), return modified a
1005
1151
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_set(
@@ -1009,7 +1155,7 @@ extern "C" {
1009
1155
  size_t nb1,
1010
1156
  size_t nb2,
1011
1157
  size_t nb3,
1012
- size_t offset);
1158
+ size_t offset); // in bytes
1013
1159
 
1014
1160
  // b -> view(a,offset,nb1,nb2,3), return view(a)
1015
1161
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_set_inplace(
@@ -1019,19 +1165,19 @@ extern "C" {
1019
1165
  size_t nb1,
1020
1166
  size_t nb2,
1021
1167
  size_t nb3,
1022
- size_t offset);
1168
+ size_t offset); // in bytes
1023
1169
 
1024
1170
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_set_1d(
1025
1171
  struct wsp_ggml_context * ctx,
1026
1172
  struct wsp_ggml_tensor * a,
1027
1173
  struct wsp_ggml_tensor * b,
1028
- size_t offset);
1174
+ size_t offset); // in bytes
1029
1175
 
1030
1176
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_set_1d_inplace(
1031
1177
  struct wsp_ggml_context * ctx,
1032
1178
  struct wsp_ggml_tensor * a,
1033
1179
  struct wsp_ggml_tensor * b,
1034
- size_t offset);
1180
+ size_t offset); // in bytes
1035
1181
 
1036
1182
  // b -> view(a,offset,nb1,nb2,3), return modified a
1037
1183
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_set_2d(
@@ -1039,7 +1185,7 @@ extern "C" {
1039
1185
  struct wsp_ggml_tensor * a,
1040
1186
  struct wsp_ggml_tensor * b,
1041
1187
  size_t nb1,
1042
- size_t offset);
1188
+ size_t offset); // in bytes
1043
1189
 
1044
1190
  // b -> view(a,offset,nb1,nb2,3), return view(a)
1045
1191
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_set_2d_inplace(
@@ -1047,8 +1193,7 @@ extern "C" {
1047
1193
  struct wsp_ggml_tensor * a,
1048
1194
  struct wsp_ggml_tensor * b,
1049
1195
  size_t nb1,
1050
- size_t offset);
1051
-
1196
+ size_t offset); // in bytes
1052
1197
 
1053
1198
  // a -> b, return view(b)
1054
1199
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_cpy(
@@ -1056,21 +1201,42 @@ extern "C" {
1056
1201
  struct wsp_ggml_tensor * a,
1057
1202
  struct wsp_ggml_tensor * b);
1058
1203
 
1059
- // a -> b, in-place, return view(b)
1060
- WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_cpy_inplace(
1204
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_cast(
1061
1205
  struct wsp_ggml_context * ctx,
1062
1206
  struct wsp_ggml_tensor * a,
1063
- struct wsp_ggml_tensor * b);
1207
+ enum wsp_ggml_type type);
1064
1208
 
1065
1209
  // make contiguous
1066
1210
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_cont(
1067
1211
  struct wsp_ggml_context * ctx,
1068
1212
  struct wsp_ggml_tensor * a);
1069
1213
 
1070
- // make contiguous, in-place
1071
- WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_cont_inplace(
1214
+ // make contiguous, with new shape
1215
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_cont_1d(
1072
1216
  struct wsp_ggml_context * ctx,
1073
- struct wsp_ggml_tensor * a);
1217
+ struct wsp_ggml_tensor * a,
1218
+ int64_t ne0);
1219
+
1220
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_cont_2d(
1221
+ struct wsp_ggml_context * ctx,
1222
+ struct wsp_ggml_tensor * a,
1223
+ int64_t ne0,
1224
+ int64_t ne1);
1225
+
1226
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_cont_3d(
1227
+ struct wsp_ggml_context * ctx,
1228
+ struct wsp_ggml_tensor * a,
1229
+ int64_t ne0,
1230
+ int64_t ne1,
1231
+ int64_t ne2);
1232
+
1233
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_cont_4d(
1234
+ struct wsp_ggml_context * ctx,
1235
+ struct wsp_ggml_tensor * a,
1236
+ int64_t ne0,
1237
+ int64_t ne1,
1238
+ int64_t ne2,
1239
+ int64_t ne3);
1074
1240
 
1075
1241
  // return view(a), b specifies the new shape
1076
1242
  // TODO: when we start computing gradient, make a copy instead of view
@@ -1159,16 +1325,17 @@ extern "C" {
1159
1325
  struct wsp_ggml_context * ctx,
1160
1326
  struct wsp_ggml_tensor * a);
1161
1327
 
1328
+ // supports 3D: a->ne[2] == b->ne[1]
1162
1329
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_get_rows(
1163
1330
  struct wsp_ggml_context * ctx,
1164
- struct wsp_ggml_tensor * a,
1165
- struct wsp_ggml_tensor * b);
1331
+ struct wsp_ggml_tensor * a, // data
1332
+ struct wsp_ggml_tensor * b); // row indices
1166
1333
 
1167
1334
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_get_rows_back(
1168
1335
  struct wsp_ggml_context * ctx,
1169
- struct wsp_ggml_tensor * a,
1170
- struct wsp_ggml_tensor * b,
1171
- struct wsp_ggml_tensor * c);
1336
+ struct wsp_ggml_tensor * a, // gradients of wsp_ggml_get_rows result
1337
+ struct wsp_ggml_tensor * b, // row indices
1338
+ struct wsp_ggml_tensor * c); // data for wsp_ggml_get_rows, only used for its shape
1172
1339
 
1173
1340
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_diag(
1174
1341
  struct wsp_ggml_context * ctx,
@@ -1207,105 +1374,208 @@ extern "C" {
1207
1374
  struct wsp_ggml_context * ctx,
1208
1375
  struct wsp_ggml_tensor * a);
1209
1376
 
1210
- WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_soft_max_back(
1377
+ // fused soft_max(a*scale + mask*(ALiBi slope))
1378
+ // mask is optional
1379
+ // max_bias = 0.0f for no ALiBi
1380
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_soft_max_ext(
1211
1381
  struct wsp_ggml_context * ctx,
1212
1382
  struct wsp_ggml_tensor * a,
1213
- struct wsp_ggml_tensor * b);
1383
+ struct wsp_ggml_tensor * mask,
1384
+ float scale,
1385
+ float max_bias);
1386
+
1387
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_soft_max_ext_back(
1388
+ struct wsp_ggml_context * ctx,
1389
+ struct wsp_ggml_tensor * a,
1390
+ struct wsp_ggml_tensor * b,
1391
+ float scale,
1392
+ float max_bias);
1214
1393
 
1215
1394
  // in-place, returns view(a)
1216
- WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_soft_max_back_inplace(
1395
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_soft_max_ext_back_inplace(
1217
1396
  struct wsp_ggml_context * ctx,
1218
1397
  struct wsp_ggml_tensor * a,
1219
- struct wsp_ggml_tensor * b);
1398
+ struct wsp_ggml_tensor * b,
1399
+ float scale,
1400
+ float max_bias);
1220
1401
 
1221
1402
  // rotary position embedding
1222
- // if mode & 1 == 1, skip n_past elements
1223
- // if mode & 2 == 1, GPT-NeoX style
1224
- // if mode & 4 == 1, ChatGLM style
1225
- // TODO: avoid creating a new tensor every time
1403
+ // if (mode & 1) - skip n_past elements (NOT SUPPORTED)
1404
+ // if (mode & WSP_GGML_ROPE_TYPE_NEOX) - GPT-NeoX style
1405
+ //
1406
+ // b is an int32 vector with size a->ne[2], it contains the positions
1226
1407
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_rope(
1227
1408
  struct wsp_ggml_context * ctx,
1228
1409
  struct wsp_ggml_tensor * a,
1229
- int n_past,
1410
+ struct wsp_ggml_tensor * b,
1230
1411
  int n_dims,
1231
- int mode,
1232
- int n_ctx);
1412
+ int mode);
1233
1413
 
1234
1414
  // in-place, returns view(a)
1235
1415
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_rope_inplace(
1236
1416
  struct wsp_ggml_context * ctx,
1237
1417
  struct wsp_ggml_tensor * a,
1238
- int n_past,
1418
+ struct wsp_ggml_tensor * b,
1239
1419
  int n_dims,
1240
- int mode,
1241
- int n_ctx);
1420
+ int mode);
1242
1421
 
1243
1422
  // custom RoPE
1244
- WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_rope_custom(
1423
+ // c is freq factors (e.g. phi3-128k), (optional)
1424
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_rope_ext(
1245
1425
  struct wsp_ggml_context * ctx,
1246
1426
  struct wsp_ggml_tensor * a,
1247
- int n_past,
1427
+ struct wsp_ggml_tensor * b,
1428
+ struct wsp_ggml_tensor * c,
1248
1429
  int n_dims,
1249
1430
  int mode,
1250
- int n_ctx,
1431
+ int n_ctx_orig,
1251
1432
  float freq_base,
1252
- float freq_scale);
1433
+ float freq_scale,
1434
+ float ext_factor,
1435
+ float attn_factor,
1436
+ float beta_fast,
1437
+ float beta_slow);
1438
+
1439
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_rope_multi(
1440
+ struct wsp_ggml_context * ctx,
1441
+ struct wsp_ggml_tensor * a,
1442
+ struct wsp_ggml_tensor * b,
1443
+ struct wsp_ggml_tensor * c,
1444
+ int n_dims,
1445
+ int sections[4],
1446
+ int mode,
1447
+ int n_ctx_orig,
1448
+ float freq_base,
1449
+ float freq_scale,
1450
+ float ext_factor,
1451
+ float attn_factor,
1452
+ float beta_fast,
1453
+ float beta_slow);
1253
1454
 
1254
1455
  // in-place, returns view(a)
1255
- WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_rope_custom_inplace(
1456
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_rope_ext_inplace(
1457
+ struct wsp_ggml_context * ctx,
1458
+ struct wsp_ggml_tensor * a,
1459
+ struct wsp_ggml_tensor * b,
1460
+ struct wsp_ggml_tensor * c,
1461
+ int n_dims,
1462
+ int mode,
1463
+ int n_ctx_orig,
1464
+ float freq_base,
1465
+ float freq_scale,
1466
+ float ext_factor,
1467
+ float attn_factor,
1468
+ float beta_fast,
1469
+ float beta_slow);
1470
+
1471
+ WSP_GGML_DEPRECATED(WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_rope_custom(
1256
1472
  struct wsp_ggml_context * ctx,
1257
1473
  struct wsp_ggml_tensor * a,
1258
- int n_past,
1474
+ struct wsp_ggml_tensor * b,
1259
1475
  int n_dims,
1260
1476
  int mode,
1261
- int n_ctx,
1477
+ int n_ctx_orig,
1262
1478
  float freq_base,
1263
- float freq_scale);
1479
+ float freq_scale,
1480
+ float ext_factor,
1481
+ float attn_factor,
1482
+ float beta_fast,
1483
+ float beta_slow),
1484
+ "use wsp_ggml_rope_ext instead");
1264
1485
 
1265
- // xPos RoPE, in-place, returns view(a)
1266
- WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_rope_xpos_inplace(
1486
+ WSP_GGML_DEPRECATED(WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_rope_custom_inplace(
1267
1487
  struct wsp_ggml_context * ctx,
1268
1488
  struct wsp_ggml_tensor * a,
1269
- int n_past,
1489
+ struct wsp_ggml_tensor * b,
1270
1490
  int n_dims,
1271
- float base,
1272
- bool down);
1491
+ int mode,
1492
+ int n_ctx_orig,
1493
+ float freq_base,
1494
+ float freq_scale,
1495
+ float ext_factor,
1496
+ float attn_factor,
1497
+ float beta_fast,
1498
+ float beta_slow),
1499
+ "use wsp_ggml_rope_ext_inplace instead");
1500
+
1501
+ // compute correction dims for YaRN RoPE scaling
1502
+ WSP_GGML_API void wsp_ggml_rope_yarn_corr_dims(
1503
+ int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, float dims[2]);
1273
1504
 
1274
1505
  // rotary position embedding backward, i.e compute dx from dy
1275
1506
  // a - dy
1276
- WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_rope_back(
1507
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_rope_ext_back(
1277
1508
  struct wsp_ggml_context * ctx,
1278
- struct wsp_ggml_tensor * a,
1279
- int n_past,
1509
+ struct wsp_ggml_tensor * a, // gradients of wsp_ggml_rope result
1510
+ struct wsp_ggml_tensor * b, // positions
1511
+ struct wsp_ggml_tensor * c, // freq factors
1280
1512
  int n_dims,
1281
1513
  int mode,
1282
- int n_ctx,
1514
+ int n_ctx_orig,
1283
1515
  float freq_base,
1284
1516
  float freq_scale,
1285
- float xpos_base,
1286
- bool xpos_down);
1517
+ float ext_factor,
1518
+ float attn_factor,
1519
+ float beta_fast,
1520
+ float beta_slow);
1287
1521
 
1288
- // alibi position embedding
1289
- // in-place, returns view(a)
1290
- struct wsp_ggml_tensor * wsp_ggml_alibi(
1522
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_rope_multi_back(
1291
1523
  struct wsp_ggml_context * ctx,
1292
1524
  struct wsp_ggml_tensor * a,
1293
- int n_past,
1294
- int n_head,
1295
- float bias_max);
1525
+ struct wsp_ggml_tensor * b,
1526
+ struct wsp_ggml_tensor * c,
1527
+ int n_dims,
1528
+ int sections[4],
1529
+ int mode,
1530
+ int n_ctx_orig,
1531
+ float freq_base,
1532
+ float freq_scale,
1533
+ float ext_factor,
1534
+ float attn_factor,
1535
+ float beta_fast,
1536
+ float beta_slow);
1537
+
1296
1538
 
1297
1539
  // clamp
1298
1540
  // in-place, returns view(a)
1299
- struct wsp_ggml_tensor * wsp_ggml_clamp(
1541
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_clamp(
1300
1542
  struct wsp_ggml_context * ctx,
1301
1543
  struct wsp_ggml_tensor * a,
1302
1544
  float min,
1303
1545
  float max);
1304
1546
 
1547
+ // im2col
1548
+ // converts data into a format that effectively results in a convolution when combined with matrix multiplication
1549
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_im2col(
1550
+ struct wsp_ggml_context * ctx,
1551
+ struct wsp_ggml_tensor * a, // convolution kernel
1552
+ struct wsp_ggml_tensor * b, // data
1553
+ int s0, // stride dimension 0
1554
+ int s1, // stride dimension 1
1555
+ int p0, // padding dimension 0
1556
+ int p1, // padding dimension 1
1557
+ int d0, // dilation dimension 0
1558
+ int d1, // dilation dimension 1
1559
+ bool is_2D,
1560
+ enum wsp_ggml_type dst_type);
1561
+
1562
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_im2col_back(
1563
+ struct wsp_ggml_context * ctx,
1564
+ struct wsp_ggml_tensor * a, // convolution kernel
1565
+ struct wsp_ggml_tensor * b, // gradient of im2col output
1566
+ int64_t * ne, // shape of im2col input
1567
+ int s0, // stride dimension 0
1568
+ int s1, // stride dimension 1
1569
+ int p0, // padding dimension 0
1570
+ int p1, // padding dimension 1
1571
+ int d0, // dilation dimension 0
1572
+ int d1, // dilation dimension 1
1573
+ bool is_2D);
1574
+
1305
1575
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_conv_1d(
1306
1576
  struct wsp_ggml_context * ctx,
1307
- struct wsp_ggml_tensor * a,
1308
- struct wsp_ggml_tensor * b,
1577
+ struct wsp_ggml_tensor * a, // convolution kernel
1578
+ struct wsp_ggml_tensor * b, // data
1309
1579
  int s0, // stride
1310
1580
  int p0, // padding
1311
1581
  int d0); // dilation
@@ -1314,22 +1584,46 @@ extern "C" {
1314
1584
  // alias for wsp_ggml_conv_1d(a, b, s, a->ne[0]/2, d)
1315
1585
  WSP_GGML_API struct wsp_ggml_tensor* wsp_ggml_conv_1d_ph(
1316
1586
  struct wsp_ggml_context * ctx,
1317
- struct wsp_ggml_tensor * a,
1318
- struct wsp_ggml_tensor * b,
1319
- int s,
1320
- int d);
1587
+ struct wsp_ggml_tensor * a, // convolution kernel
1588
+ struct wsp_ggml_tensor * b, // data
1589
+ int s, // stride
1590
+ int d); // dilation
1321
1591
 
1322
- WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_conv_2d(
1592
+ // depthwise
1593
+ // TODO: this is very likely wrong for some cases! - needs more testing
1594
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_conv_1d_dw(
1323
1595
  struct wsp_ggml_context * ctx,
1324
- struct wsp_ggml_tensor * a,
1325
- struct wsp_ggml_tensor * b,
1326
- int s0,
1327
- int s1,
1328
- int p0,
1329
- int p1,
1330
- int d0,
1331
- int d1);
1596
+ struct wsp_ggml_tensor * a, // convolution kernel
1597
+ struct wsp_ggml_tensor * b, // data
1598
+ int s0, // stride
1599
+ int p0, // padding
1600
+ int d0); // dilation
1601
+
1602
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_conv_1d_dw_ph(
1603
+ struct wsp_ggml_context * ctx,
1604
+ struct wsp_ggml_tensor * a, // convolution kernel
1605
+ struct wsp_ggml_tensor * b, // data
1606
+ int s0, // stride
1607
+ int d0); // dilation
1608
+
1609
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_conv_transpose_1d(
1610
+ struct wsp_ggml_context * ctx,
1611
+ struct wsp_ggml_tensor * a, // convolution kernel
1612
+ struct wsp_ggml_tensor * b, // data
1613
+ int s0, // stride
1614
+ int p0, // padding
1615
+ int d0); // dilation
1332
1616
 
1617
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_conv_2d(
1618
+ struct wsp_ggml_context * ctx,
1619
+ struct wsp_ggml_tensor * a, // convolution kernel
1620
+ struct wsp_ggml_tensor * b, // data
1621
+ int s0, // stride dimension 0
1622
+ int s1, // stride dimension 1
1623
+ int p0, // padding dimension 0
1624
+ int p1, // padding dimension 1
1625
+ int d0, // dilation dimension 0
1626
+ int d1); // dilation dimension 1

  // kernel size is a->ne[0] x a->ne[1]
  // stride is equal to kernel size
@@ -1357,6 +1651,18 @@ extern "C" {
  struct wsp_ggml_tensor * a,
  struct wsp_ggml_tensor * b);

+ // depthwise
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_conv_2d_dw(
+ struct wsp_ggml_context * ctx,
+ struct wsp_ggml_tensor * a, // convolution kernel
+ struct wsp_ggml_tensor * b, // data
+ int s0, // stride dimension 0
+ int s1, // stride dimension 1
+ int p0, // padding dimension 0
+ int p1, // padding dimension 1
+ int d0, // dilation dimension 0
+ int d1); // dilation dimension 1
+
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_conv_transpose_2d_p0(
  struct wsp_ggml_context * ctx,
  struct wsp_ggml_tensor * a,
@@ -1377,6 +1683,8 @@ extern "C" {
  int s0, // stride
  int p0); // padding

+ // the result will have 2*p0 padding for the first dimension
+ // and 2*p1 padding for the second dimension
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_pool_2d(
  struct wsp_ggml_context * ctx,
  struct wsp_ggml_tensor * a,
@@ -1385,23 +1693,113 @@ extern "C" {
  int k1,
  int s0,
  int s1,
- int p0,
- int p1);
+ float p0,
+ float p1);
+
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_pool_2d_back(
+ struct wsp_ggml_context * ctx,
+ struct wsp_ggml_tensor * a,
+ struct wsp_ggml_tensor * af, // "a"/input used in forward pass
+ enum wsp_ggml_op_pool op,
+ int k0,
+ int k1,
+ int s0,
+ int s1,
+ float p0,
+ float p1);

  // nearest interpolate
+ // multiplies ne0 and ne1 by scale factor
  // used in stable-diffusion
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_upscale(
  struct wsp_ggml_context * ctx,
  struct wsp_ggml_tensor * a,
  int scale_factor);

- WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_flash_attn(
+ // nearest interpolate
+ // nearest interpolate to specified dimensions
+ // used in tortoise.cpp
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_upscale_ext(
+ struct wsp_ggml_context * ctx,
+ struct wsp_ggml_tensor * a,
+ int ne0,
+ int ne1,
+ int ne2,
+ int ne3);
+
+ // pad each dimension with zeros: [x, ..., x] -> [x, ..., x, 0, ..., 0]
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_pad(
+ struct wsp_ggml_context * ctx,
+ struct wsp_ggml_tensor * a,
+ int p0,
+ int p1,
+ int p2,
+ int p3);
+
+ // pad each dimension with reflection: [a, b, c, d] -> [b, a, b, c, d, c]
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_pad_reflect_1d(
+ struct wsp_ggml_context * ctx,
+ struct wsp_ggml_tensor * a,
+ int p0,
+ int p1);
+
+ // Ref: https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/diffusionmodules/util.py#L151
+ // timesteps: [N,]
+ // return: [N, dim]
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_timestep_embedding(
+ struct wsp_ggml_context * ctx,
+ struct wsp_ggml_tensor * timesteps,
+ int dim,
+ int max_period);
+
+ // sort rows
+ enum wsp_ggml_sort_order {
+ WSP_GGML_SORT_ORDER_ASC,
+ WSP_GGML_SORT_ORDER_DESC,
+ };
+
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_argsort(
+ struct wsp_ggml_context * ctx,
+ struct wsp_ggml_tensor * a,
+ enum wsp_ggml_sort_order order);
+
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_arange(
+ struct wsp_ggml_context * ctx,
+ float start,
+ float stop,
+ float step);
+
+ // top k elements per row
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_top_k(
+ struct wsp_ggml_context * ctx,
+ struct wsp_ggml_tensor * a,
+ int k);
+
+ #define WSP_GGML_KQ_MASK_PAD 64
1779
+
1780
+ // q: [n_embd, n_batch, n_head, 1]
1781
+ // k: [n_embd, n_kv, n_head_kv, 1]
1782
+ // v: [n_embd, n_kv, n_head_kv, 1] !! not transposed !!
1783
+ // mask: [n_kv, n_batch_pad, 1, 1] !! n_batch_pad = WSP_GGML_PAD(n_batch, WSP_GGML_KQ_MASK_PAD) !!
1784
+ // res: [n_embd, n_head, n_batch, 1] !! permuted !!
1785
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_flash_attn_ext(
1399
1786
  struct wsp_ggml_context * ctx,
1400
1787
  struct wsp_ggml_tensor * q,
1401
1788
  struct wsp_ggml_tensor * k,
1402
1789
  struct wsp_ggml_tensor * v,
1403
- bool masked);
1790
+ struct wsp_ggml_tensor * mask,
1791
+ float scale,
1792
+ float max_bias,
1793
+ float logit_softcap);
1794
+
1795
+ WSP_GGML_API void wsp_ggml_flash_attn_ext_set_prec(
1796
+ struct wsp_ggml_tensor * a,
1797
+ enum wsp_ggml_prec prec);
1798
+
1799
+ WSP_GGML_API enum wsp_ggml_prec wsp_ggml_flash_attn_ext_get_prec(
1800
+ const struct wsp_ggml_tensor * a);
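The shape comments above are the contract for the extended call. A hypothetical sketch (not part of the diff) with invented sizes, per-head n_embd = 64, n_head = n_head_kv = 8, n_batch = 32, n_kv = 256, padding the mask's second dimension with WSP_GGML_PAD as the comment requires; tensor types are only illustrative.

// hypothetical usage sketch -- not part of the diffed header
struct wsp_ggml_tensor * q    = wsp_ggml_new_tensor_4d(ctx, WSP_GGML_TYPE_F32, 64, 32, 8, 1);
struct wsp_ggml_tensor * k    = wsp_ggml_new_tensor_4d(ctx, WSP_GGML_TYPE_F16, 64, 256, 8, 1);
struct wsp_ggml_tensor * v    = wsp_ggml_new_tensor_4d(ctx, WSP_GGML_TYPE_F16, 64, 256, 8, 1);
struct wsp_ggml_tensor * mask = wsp_ggml_new_tensor_2d(ctx, WSP_GGML_TYPE_F16,
        256, WSP_GGML_PAD(32, WSP_GGML_KQ_MASK_PAD));                    // n_kv x n_batch_pad
struct wsp_ggml_tensor * res  = wsp_ggml_flash_attn_ext(ctx, q, k, v, mask,
        /*scale=*/1.0f/8.0f, /*max_bias=*/0.0f, /*logit_softcap=*/0.0f); // scale = 1/sqrt(64)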

+ // TODO: needs to be adapted to wsp_ggml_flash_attn_ext
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_flash_attn_back(
  struct wsp_ggml_context * ctx,
  struct wsp_ggml_tensor * q,
@@ -1410,13 +1808,19 @@ extern "C" {
  struct wsp_ggml_tensor * d,
  bool masked);

- WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_flash_ff(
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_ssm_conv(
  struct wsp_ggml_context * ctx,
- struct wsp_ggml_tensor * a,
- struct wsp_ggml_tensor * b0,
- struct wsp_ggml_tensor * b1,
- struct wsp_ggml_tensor * c0,
- struct wsp_ggml_tensor * c1);
+ struct wsp_ggml_tensor * sx,
+ struct wsp_ggml_tensor * c);
+
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_ssm_scan(
+ struct wsp_ggml_context * ctx,
+ struct wsp_ggml_tensor * s,
+ struct wsp_ggml_tensor * x,
+ struct wsp_ggml_tensor * dt,
+ struct wsp_ggml_tensor * A,
+ struct wsp_ggml_tensor * B,
+ struct wsp_ggml_tensor * C);

  // partition into non-overlapping windows with padding if needed
  // example:
@@ -1456,7 +1860,6 @@ extern "C" {
  int kh);

  // used in sam
-
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_add_rel_pos(
  struct wsp_ggml_context * ctx,
  struct wsp_ggml_tensor * a,
@@ -1469,6 +1872,24 @@ extern "C" {
  struct wsp_ggml_tensor * pw,
  struct wsp_ggml_tensor * ph);

+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_rwkv_wkv6(
+ struct wsp_ggml_context * ctx,
+ struct wsp_ggml_tensor * k,
+ struct wsp_ggml_tensor * v,
+ struct wsp_ggml_tensor * r,
+ struct wsp_ggml_tensor * tf,
+ struct wsp_ggml_tensor * td,
+ struct wsp_ggml_tensor * state);
+
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_gated_linear_attn(
+ struct wsp_ggml_context * ctx,
+ struct wsp_ggml_tensor * k,
+ struct wsp_ggml_tensor * v,
+ struct wsp_ggml_tensor * q,
+ struct wsp_ggml_tensor * g,
+ struct wsp_ggml_tensor * state,
+ float scale);
+
  // custom operators

  typedef void (*wsp_ggml_unary_op_f32_t) (const int, float *, const float *);
@@ -1552,7 +1973,8 @@ extern "C" {
  typedef void (*wsp_ggml_custom2_op_t)(struct wsp_ggml_tensor * dst , const struct wsp_ggml_tensor * a, const struct wsp_ggml_tensor * b, int ith, int nth, void * userdata);
  typedef void (*wsp_ggml_custom3_op_t)(struct wsp_ggml_tensor * dst , const struct wsp_ggml_tensor * a, const struct wsp_ggml_tensor * b, const struct wsp_ggml_tensor * c, int ith, int nth, void * userdata);

- #define WSP_GGML_N_TASKS_MAX -1
+ #define WSP_GGML_N_TASKS_MAX (-1)
+ // n_tasks == WSP_GGML_N_TASKS_MAX means to use max number of tasks

  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_map_custom1(
  struct wsp_ggml_context * ctx,
@@ -1605,50 +2027,62 @@ extern "C" {
  // loss function

  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_cross_entropy_loss(
- struct wsp_ggml_context * ctx,
- struct wsp_ggml_tensor * a,
- struct wsp_ggml_tensor * b);
+ struct wsp_ggml_context * ctx,
+ struct wsp_ggml_tensor * a, // logits
+ struct wsp_ggml_tensor * b); // labels

  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_cross_entropy_loss_back(
- struct wsp_ggml_context * ctx,
- struct wsp_ggml_tensor * a,
- struct wsp_ggml_tensor * b,
- struct wsp_ggml_tensor * c);
+ struct wsp_ggml_context * ctx,
+ struct wsp_ggml_tensor * a, // logits
+ struct wsp_ggml_tensor * b, // labels
+ struct wsp_ggml_tensor * c); // gradients of cross_entropy_loss result
+
+ // AdamW optimizer step
+ // Paper: https://arxiv.org/pdf/1711.05101v3.pdf
+ // PyTorch: https://pytorch.org/docs/stable/generated/torch.optim.AdamW.html
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_opt_step_adamw(
+ struct wsp_ggml_context * ctx,
+ struct wsp_ggml_tensor * a,
+ struct wsp_ggml_tensor * grad,
+ struct wsp_ggml_tensor * m,
+ struct wsp_ggml_tensor * v,
+ struct wsp_ggml_tensor * adamw_params); // parameters such a the learning rate

  //
  // automatic differentiation
  //

- WSP_GGML_API void wsp_ggml_set_param(
- struct wsp_ggml_context * ctx,
- struct wsp_ggml_tensor * tensor);
+ WSP_GGML_API void wsp_ggml_build_forward_expand(struct wsp_ggml_cgraph * cgraph, struct wsp_ggml_tensor * tensor);
+ WSP_GGML_API void wsp_ggml_build_backward_expand(
+ struct wsp_ggml_context * ctx_static, // context for static gradients (loss + gradient accumulation)
+ struct wsp_ggml_context * ctx_compute, // context for gradient computation
+ struct wsp_ggml_cgraph * cgraph,
+ bool accumulate); // whether or not gradients should be accumulated, requires static allocation of tensors in ctx_static

+ // graph allocation in a context
+ WSP_GGML_API struct wsp_ggml_cgraph * wsp_ggml_new_graph (struct wsp_ggml_context * ctx); // size = WSP_GGML_DEFAULT_GRAPH_SIZE, grads = false
+ WSP_GGML_API struct wsp_ggml_cgraph * wsp_ggml_new_graph_custom(struct wsp_ggml_context * ctx, size_t size, bool grads);
+ WSP_GGML_API struct wsp_ggml_cgraph * wsp_ggml_graph_dup (struct wsp_ggml_context * ctx, struct wsp_ggml_cgraph * cgraph);
+ WSP_GGML_API void wsp_ggml_graph_cpy (struct wsp_ggml_cgraph * src, struct wsp_ggml_cgraph * dst);
+ WSP_GGML_API void wsp_ggml_graph_reset (struct wsp_ggml_cgraph * cgraph); // set regular grads + optimizer momenta to 0, set loss grad to 1
+ WSP_GGML_API void wsp_ggml_graph_clear (struct wsp_ggml_cgraph * cgraph);

- WSP_GGML_API void wsp_ggml_build_forward_expand (struct wsp_ggml_cgraph * cgraph, struct wsp_ggml_tensor * tensor);
- WSP_GGML_API void wsp_ggml_build_backward_expand(struct wsp_ggml_context * ctx, struct wsp_ggml_cgraph * gf, struct wsp_ggml_cgraph * gb, bool keep);
+ WSP_GGML_API int wsp_ggml_graph_size (struct wsp_ggml_cgraph * cgraph);
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_graph_node (struct wsp_ggml_cgraph * cgraph, int i); // if i < 0, returns nodes[n_nodes + i]
+ WSP_GGML_API struct wsp_ggml_tensor ** wsp_ggml_graph_nodes (struct wsp_ggml_cgraph * cgraph);
+ WSP_GGML_API int wsp_ggml_graph_n_nodes(struct wsp_ggml_cgraph * cgraph);

- WSP_GGML_API struct wsp_ggml_cgraph wsp_ggml_build_forward (struct wsp_ggml_tensor * tensor);
- WSP_GGML_API struct wsp_ggml_cgraph wsp_ggml_build_backward(struct wsp_ggml_context * ctx, struct wsp_ggml_cgraph * gf, bool keep);
+ WSP_GGML_API void wsp_ggml_graph_add_node(struct wsp_ggml_cgraph * cgraph, struct wsp_ggml_tensor * tensor);

- // graph allocation in a context
- WSP_GGML_API struct wsp_ggml_cgraph * wsp_ggml_new_graph (struct wsp_ggml_context * ctx);
- WSP_GGML_API struct wsp_ggml_cgraph * wsp_ggml_build_forward_ctx(struct wsp_ggml_context * ctx, struct wsp_ggml_tensor * tensor);
  WSP_GGML_API size_t wsp_ggml_graph_overhead(void);
+ WSP_GGML_API size_t wsp_ggml_graph_overhead_custom(size_t size, bool grads);

- // wsp_ggml_graph_plan() has to be called before wsp_ggml_graph_compute()
- // when plan.work_size > 0, caller must allocate memory for plan.work_data
- WSP_GGML_API struct wsp_ggml_cplan wsp_ggml_graph_plan (struct wsp_ggml_cgraph * cgraph, int n_threads /*= WSP_GGML_DEFAULT_N_THREADS*/);
- WSP_GGML_API int wsp_ggml_graph_compute(struct wsp_ggml_cgraph * cgraph, struct wsp_ggml_cplan * cplan);
- WSP_GGML_API void wsp_ggml_graph_reset (struct wsp_ggml_cgraph * cgraph);
-
- // same as wsp_ggml_graph_compute() but the work data is allocated as a part of the context
- // note: the drawback of this API is that you must have ensured that the context has enough memory for the work data
- WSP_GGML_API void wsp_ggml_graph_compute_with_ctx(struct wsp_ggml_context * ctx, struct wsp_ggml_cgraph * cgraph, int n_threads);
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_graph_get_tensor (const struct wsp_ggml_cgraph * cgraph, const char * name);
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_graph_get_grad (const struct wsp_ggml_cgraph * cgraph, const struct wsp_ggml_tensor * node);
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_graph_get_grad_acc(const struct wsp_ggml_cgraph * cgraph, const struct wsp_ggml_tensor * node);

- WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_graph_get_tensor(struct wsp_ggml_cgraph * cgraph, const char * name);
-
- WSP_GGML_API void wsp_ggml_graph_export(const struct wsp_ggml_cgraph * cgraph, const char * fname);
- WSP_GGML_API struct wsp_ggml_cgraph wsp_ggml_graph_import(const char * fname, struct wsp_ggml_context ** ctx_data, struct wsp_ggml_context ** ctx_eval);
+ WSP_GGML_API void wsp_ggml_graph_export(const struct wsp_ggml_cgraph * cgraph, const char * fname);
+ WSP_GGML_API struct wsp_ggml_cgraph * wsp_ggml_graph_import(const char * fname, struct wsp_ggml_context ** ctx_data, struct wsp_ggml_context ** ctx_eval);
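A minimal hypothetical sketch (not part of the diff) of the reworked graph API: allocate a graph in a context, expand it from a result tensor, then walk its nodes. Executing the graph now goes through the backend API (see ggml-backend.h in this package) rather than the removed wsp_ggml_graph_compute_with_ctx; 'cur' is the tensor built in the earlier conv sketch.

// hypothetical usage sketch -- not part of the diffed header
struct wsp_ggml_cgraph * gf = wsp_ggml_new_graph(ctx);           // default size, no grads
wsp_ggml_build_forward_expand(gf, cur);                          // record everything needed to compute 'cur'
for (int i = 0; i < wsp_ggml_graph_n_nodes(gf); ++i) {
    struct wsp_ggml_tensor * node = wsp_ggml_graph_node(gf, i);  // i < 0 would index from the end
    (void) node;                                                 // e.g. inspect node->op / node->name here
}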

  // print info and performance information for the graph
  WSP_GGML_API void wsp_ggml_graph_print(const struct wsp_ggml_cgraph * cgraph);
@@ -1656,359 +2090,103 @@ extern "C" {
  // dump the graph into a file using the dot format
  WSP_GGML_API void wsp_ggml_graph_dump_dot(const struct wsp_ggml_cgraph * gb, const struct wsp_ggml_cgraph * gf, const char * filename);

- //
- // optimization
- //
-
- // optimization methods
- enum wsp_ggml_opt_type {
- WSP_GGML_OPT_ADAM,
- WSP_GGML_OPT_LBFGS,
- };
-
- // linesearch methods
- enum wsp_ggml_linesearch {
- WSP_GGML_LINESEARCH_DEFAULT = 1,
-
- WSP_GGML_LINESEARCH_BACKTRACKING_ARMIJO = 0,
- WSP_GGML_LINESEARCH_BACKTRACKING_WOLFE = 1,
- WSP_GGML_LINESEARCH_BACKTRACKING_STRONG_WOLFE = 2,
- };
-
- // optimization return values
- enum wsp_ggml_opt_result {
- WSP_GGML_OPT_OK = 0,
- WSP_GGML_OPT_DID_NOT_CONVERGE,
- WSP_GGML_OPT_NO_CONTEXT,
- WSP_GGML_OPT_INVALID_WOLFE,
- WSP_GGML_OPT_FAIL,
-
- WSP_GGML_LINESEARCH_FAIL = -128,
- WSP_GGML_LINESEARCH_MINIMUM_STEP,
- WSP_GGML_LINESEARCH_MAXIMUM_STEP,
- WSP_GGML_LINESEARCH_MAXIMUM_ITERATIONS,
- WSP_GGML_LINESEARCH_INVALID_PARAMETERS,
- };
-
- typedef void (*wsp_ggml_opt_callback)(void * data, float * sched);
-
- // optimization parameters
- //
- // see ggml.c (wsp_ggml_opt_default_params) for default values
- //
- struct wsp_ggml_opt_params {
- enum wsp_ggml_opt_type type;
-
- int n_threads;
-
- // delta-based convergence test
- //
- // if past == 0 - disabled
- // if past > 0:
- // stop if |f(x) - f(x_past)| < delta * max(1, |f(x)|)
- //
- int past;
- float delta;
-
- // maximum number of iterations without improvement
- //
- // if 0 - disabled
- // if > 0:
- // assume convergence if no cost improvement in this number of iterations
- //
- int max_no_improvement;
-
- bool print_forward_graph;
- bool print_backward_graph;
-
- // ADAM parameters
- struct {
- int n_iter;
-
- float sched; // schedule multiplier (fixed, decay or warmup)
- float decay; // weight decay for AdamW, use 0.0f to disable
- int decay_min_ndim; // minimum number of tensor dimension to apply weight decay
- float alpha; // learning rate
- float beta1;
- float beta2;
- float eps; // epsilon for numerical stability
- float eps_f; // epsilon for convergence test
- float eps_g; // epsilon for convergence test
- float gclip; // gradient clipping
- } adam;
-
- // LBFGS parameters
- struct {
- int m; // number of corrections to approximate the inv. Hessian
- int n_iter;
- int max_linesearch;
-
- float eps; // convergence tolerance
- float ftol; // line search tolerance
- float wolfe;
- float min_step;
- float max_step;
-
- enum wsp_ggml_linesearch linesearch;
- } lbfgs;
- };
-
- struct wsp_ggml_opt_context {
- struct wsp_ggml_context * ctx;
- struct wsp_ggml_opt_params params;
-
- int iter;
- int64_t nx; // number of parameter elements
-
- bool just_initialized;
-
- float loss_before;
- float loss_after;
-
- struct {
- struct wsp_ggml_tensor * m; // first moment
- struct wsp_ggml_tensor * v; // second moment
- struct wsp_ggml_tensor * pf; // past function values
- float fx_best;
- float fx_prev;
- int n_no_improvement;
- } adam;
-
- struct {
- struct wsp_ggml_tensor * x; // current parameters
- struct wsp_ggml_tensor * xp; // previous parameters
- struct wsp_ggml_tensor * g; // current gradient
- struct wsp_ggml_tensor * gp; // previous gradient
- struct wsp_ggml_tensor * d; // search direction
- struct wsp_ggml_tensor * pf; // past function values
- struct wsp_ggml_tensor * lmal; // the L-BFGS memory alpha
- struct wsp_ggml_tensor * lmys; // the L-BFGS memory ys
- struct wsp_ggml_tensor * lms; // the L-BFGS memory s
- struct wsp_ggml_tensor * lmy; // the L-BFGS memory y
- float fx_best;
- float step;
- int j;
- int k;
- int end;
- int n_no_improvement;
- } lbfgs;
- };
-
- WSP_GGML_API struct wsp_ggml_opt_params wsp_ggml_opt_default_params(enum wsp_ggml_opt_type type);
-
- // optimize the function defined by the tensor f
- WSP_GGML_API enum wsp_ggml_opt_result wsp_ggml_opt(
- struct wsp_ggml_context * ctx,
- struct wsp_ggml_opt_params params,
- struct wsp_ggml_tensor * f);
-
- // initialize optimizer context
- WSP_GGML_API void wsp_ggml_opt_init(
- struct wsp_ggml_context * ctx,
- struct wsp_ggml_opt_context * opt,
- struct wsp_ggml_opt_params params,
- int64_t nx);
+ // TODO these functions were sandwiched in the old optimization interface, is there a better place for them?
+ typedef void (*wsp_ggml_log_callback)(enum wsp_ggml_log_level level, const char * text, void * user_data);

- // continue optimizing the function defined by the tensor f
- WSP_GGML_API enum wsp_ggml_opt_result wsp_ggml_opt_resume(
- struct wsp_ggml_context * ctx,
- struct wsp_ggml_opt_context * opt,
- struct wsp_ggml_tensor * f);
+ // Set callback for all future logging events.
+ // If this is not called, or NULL is supplied, everything is output on stderr.
+ WSP_GGML_API void wsp_ggml_log_set(wsp_ggml_log_callback log_callback, void * user_data);
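A hypothetical sketch (not part of the diff) of routing ggml log output through a custom callback instead of stderr, using the typedef and setter declared above; it only assumes the standard <stdio.h>.

// hypothetical usage sketch -- not part of the diffed header
#include <stdio.h>

static void my_log_cb(enum wsp_ggml_log_level level, const char * text, void * user_data) {
    (void) level; (void) user_data;
    fprintf(stderr, "[ggml] %s", text);   // forward the message to stderr (or a file/logger of choice)
}

// somewhere during initialization:
wsp_ggml_log_set(my_log_cb, /*user_data=*/NULL);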

- // continue optimizing the function defined by the tensor f
- WSP_GGML_API enum wsp_ggml_opt_result wsp_ggml_opt_resume_g(
- struct wsp_ggml_context * ctx,
- struct wsp_ggml_opt_context * opt,
- struct wsp_ggml_tensor * f,
- struct wsp_ggml_cgraph * gf,
- struct wsp_ggml_cgraph * gb,
- wsp_ggml_opt_callback callback,
- void * callback_data);
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_set_zero(struct wsp_ggml_tensor * tensor);

  //
  // quantization
  //

- WSP_GGML_API size_t wsp_ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t * hist);
- WSP_GGML_API size_t wsp_ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t * hist);
- WSP_GGML_API size_t wsp_ggml_quantize_q5_0(const float * src, void * dst, int n, int k, int64_t * hist);
- WSP_GGML_API size_t wsp_ggml_quantize_q5_1(const float * src, void * dst, int n, int k, int64_t * hist);
- WSP_GGML_API size_t wsp_ggml_quantize_q8_0(const float * src, void * dst, int n, int k, int64_t * hist);
-
- WSP_GGML_API size_t wsp_ggml_quantize_chunk(enum wsp_ggml_type type, const float * src, void * dst, int start, int n, int64_t * hist);
-
+ // - wsp_ggml_wsp_quantize_init can be called multiple times with the same type
+ // it will only initialize the quantization tables for the first call or after wsp_ggml_wsp_quantize_free
+ // automatically called by wsp_ggml_wsp_quantize_chunk for convenience
  //
- // gguf
+ // - wsp_ggml_wsp_quantize_free will free any memory allocated by wsp_ggml_wsp_quantize_init
+ // call this at the end of the program to avoid memory leaks
  //
-
- enum gguf_type {
- GGUF_TYPE_UINT8 = 0,
- GGUF_TYPE_INT8 = 1,
- GGUF_TYPE_UINT16 = 2,
- GGUF_TYPE_INT16 = 3,
- GGUF_TYPE_UINT32 = 4,
- GGUF_TYPE_INT32 = 5,
- GGUF_TYPE_FLOAT32 = 6,
- GGUF_TYPE_BOOL = 7,
- GGUF_TYPE_STRING = 8,
- GGUF_TYPE_ARRAY = 9,
- GGUF_TYPE_UINT64 = 10,
- GGUF_TYPE_INT64 = 11,
- GGUF_TYPE_FLOAT64 = 12,
- GGUF_TYPE_COUNT, // marks the end of the enum
+ // note: these are thread-safe
+ //
+ WSP_GGML_API void wsp_ggml_wsp_quantize_init(enum wsp_ggml_type type);
+ WSP_GGML_API void wsp_ggml_wsp_quantize_free(void);
+
+ // some quantization type cannot be used without an importance matrix
+ WSP_GGML_API bool wsp_ggml_wsp_quantize_requires_imatrix(enum wsp_ggml_type type);
+
+ // calls wsp_ggml_wsp_quantize_init internally (i.e. can allocate memory)
+ WSP_GGML_API size_t wsp_ggml_wsp_quantize_chunk(
+ enum wsp_ggml_type type,
+ const float * src,
+ void * dst,
+ int64_t start,
+ int64_t nrows,
+ int64_t n_per_row,
+ const float * imatrix);
+
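A hypothetical sketch (not part of the diff) of quantizing a row-major F32 weight matrix with the chunk API declared above. Sizes are invented; wsp_ggml_row_size (declared elsewhere in ggml.h) is assumed here to size the destination buffer, and Q4_0 is used because it needs no importance matrix.

// hypothetical usage sketch -- not part of the diffed header
#include <stdlib.h>

const int64_t n_per_row = 4096, nrows = 128;
float * src = calloc((size_t)(n_per_row * nrows), sizeof(float));                    // weights to quantize (zeros here)
void  * dst = malloc(wsp_ggml_row_size(WSP_GGML_TYPE_Q4_0, n_per_row) * (size_t)nrows); // room for the quantized rows

size_t written = wsp_ggml_wsp_quantize_chunk(WSP_GGML_TYPE_Q4_0, src, dst,
        /*start=*/0, nrows, n_per_row, /*imatrix=*/NULL);
(void) written;                 // number of bytes produced
wsp_ggml_wsp_quantize_free();   // release any tables allocated by the implicit init
free(src); free(dst);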
+ #ifdef __cplusplus
+ // restrict not standard in C++
+ # if defined(__GNUC__)
+ # define WSP_GGML_RESTRICT __restrict__
+ # elif defined(__clang__)
+ # define WSP_GGML_RESTRICT __restrict
+ # elif defined(_MSC_VER)
+ # define WSP_GGML_RESTRICT __restrict
+ # else
+ # define WSP_GGML_RESTRICT
+ # endif
+ #else
+ # define WSP_GGML_RESTRICT restrict
+ #endif
+ typedef void (*wsp_ggml_to_float_t) (const void * WSP_GGML_RESTRICT x, float * WSP_GGML_RESTRICT y, int64_t k);
+ typedef void (*wsp_ggml_from_float_t)(const float * WSP_GGML_RESTRICT x, void * WSP_GGML_RESTRICT y, int64_t k);
+
+ struct wsp_ggml_type_traits {
+ const char * type_name;
+ int64_t blck_size;
+ int64_t blck_size_interleave; // interleave elements in blocks
+ size_t type_size;
+ bool is_quantized;
+ wsp_ggml_to_float_t to_float;
+ wsp_ggml_from_float_t from_float_ref;
  };

- struct gguf_context;
+ WSP_GGML_API const struct wsp_ggml_type_traits * wsp_ggml_get_type_traits(enum wsp_ggml_type type);
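The traits struct is now returned by pointer from wsp_ggml_get_type_traits rather than by value from the old wsp_ggml_internal_get_type_traits. A hypothetical sketch (not part of the diff; assumes <stdio.h>) of reading the fields declared above:

// hypothetical usage sketch -- not part of the diffed header
#include <stdio.h>

const struct wsp_ggml_type_traits * tt = wsp_ggml_get_type_traits(WSP_GGML_TYPE_Q4_0);
printf("%s: block size %lld, %zu bytes per block, quantized=%d\n",
       tt->type_name, (long long) tt->blck_size, tt->type_size, (int) tt->is_quantized);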

- struct gguf_init_params {
- bool no_alloc;
+ // ggml threadpool
+ // TODO: currently, only a few functions are in the base ggml API, while the rest are in the CPU backend
+ // the goal should be to create an API that other backends can use move everything to the ggml base

- // if not NULL, create a wsp_ggml_context and allocate the tensor data in it
- struct wsp_ggml_context ** ctx;
+ // scheduling priorities
+ enum wsp_ggml_sched_priority {
+ WSP_GGML_SCHED_PRIO_NORMAL,
+ WSP_GGML_SCHED_PRIO_MEDIUM,
+ WSP_GGML_SCHED_PRIO_HIGH,
+ WSP_GGML_SCHED_PRIO_REALTIME
  };

- WSP_GGML_API struct gguf_context * gguf_init_empty(void);
- WSP_GGML_API struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_params params);
- //WSP_GGML_API struct gguf_context * gguf_init_from_buffer(..);
-
- WSP_GGML_API void gguf_free(struct gguf_context * ctx);
-
- WSP_GGML_API const char * gguf_type_name(enum gguf_type type);
-
- WSP_GGML_API int gguf_get_version (const struct gguf_context * ctx);
- WSP_GGML_API size_t gguf_get_alignment (const struct gguf_context * ctx);
- WSP_GGML_API size_t gguf_get_data_offset(const struct gguf_context * ctx);
- WSP_GGML_API void * gguf_get_data (const struct gguf_context * ctx);
-
- WSP_GGML_API int gguf_get_n_kv(const struct gguf_context * ctx);
- WSP_GGML_API int gguf_find_key(const struct gguf_context * ctx, const char * key);
- WSP_GGML_API const char * gguf_get_key (const struct gguf_context * ctx, int i);
-
- WSP_GGML_API enum gguf_type gguf_get_kv_type (const struct gguf_context * ctx, int i);
- WSP_GGML_API enum gguf_type gguf_get_arr_type(const struct gguf_context * ctx, int i);
-
- // results are undefined if the wrong type is used for the key
- WSP_GGML_API uint8_t gguf_get_val_u8 (const struct gguf_context * ctx, int i);
- WSP_GGML_API int8_t gguf_get_val_i8 (const struct gguf_context * ctx, int i);
- WSP_GGML_API uint16_t gguf_get_val_u16 (const struct gguf_context * ctx, int i);
- WSP_GGML_API int16_t gguf_get_val_i16 (const struct gguf_context * ctx, int i);
- WSP_GGML_API uint32_t gguf_get_val_u32 (const struct gguf_context * ctx, int i);
- WSP_GGML_API int32_t gguf_get_val_i32 (const struct gguf_context * ctx, int i);
- WSP_GGML_API float gguf_get_val_f32 (const struct gguf_context * ctx, int i);
- WSP_GGML_API uint64_t gguf_get_val_u64 (const struct gguf_context * ctx, int i);
- WSP_GGML_API int64_t gguf_get_val_i64 (const struct gguf_context * ctx, int i);
- WSP_GGML_API double gguf_get_val_f64 (const struct gguf_context * ctx, int i);
- WSP_GGML_API bool gguf_get_val_bool(const struct gguf_context * ctx, int i);
- WSP_GGML_API const char * gguf_get_val_str (const struct gguf_context * ctx, int i);
- WSP_GGML_API int gguf_get_arr_n (const struct gguf_context * ctx, int i);
- WSP_GGML_API const void * gguf_get_arr_data(const struct gguf_context * ctx, int i);
- WSP_GGML_API const char * gguf_get_arr_str (const struct gguf_context * ctx, int key_id, int i);
-
- WSP_GGML_API int gguf_get_n_tensors (const struct gguf_context * ctx);
- WSP_GGML_API int gguf_find_tensor (const struct gguf_context * ctx, const char * name);
- WSP_GGML_API size_t gguf_get_tensor_offset(const struct gguf_context * ctx, int i);
- WSP_GGML_API char * gguf_get_tensor_name (const struct gguf_context * ctx, int i);
-
- // overrides existing values or adds a new one
- WSP_GGML_API void gguf_set_val_u8 (struct gguf_context * ctx, const char * key, uint8_t val);
- WSP_GGML_API void gguf_set_val_i8 (struct gguf_context * ctx, const char * key, int8_t val);
- WSP_GGML_API void gguf_set_val_u16 (struct gguf_context * ctx, const char * key, uint16_t val);
- WSP_GGML_API void gguf_set_val_i16 (struct gguf_context * ctx, const char * key, int16_t val);
- WSP_GGML_API void gguf_set_val_u32 (struct gguf_context * ctx, const char * key, uint32_t val);
- WSP_GGML_API void gguf_set_val_i32 (struct gguf_context * ctx, const char * key, int32_t val);
- WSP_GGML_API void gguf_set_val_f32 (struct gguf_context * ctx, const char * key, float val);
- WSP_GGML_API void gguf_set_val_u64 (struct gguf_context * ctx, const char * key, uint64_t val);
- WSP_GGML_API void gguf_set_val_i64 (struct gguf_context * ctx, const char * key, int64_t val);
- WSP_GGML_API void gguf_set_val_f64 (struct gguf_context * ctx, const char * key, double val);
- WSP_GGML_API void gguf_set_val_bool(struct gguf_context * ctx, const char * key, bool val);
- WSP_GGML_API void gguf_set_val_str (struct gguf_context * ctx, const char * key, const char * val);
- WSP_GGML_API void gguf_set_arr_data(struct gguf_context * ctx, const char * key, enum gguf_type type, const void * data, int n);
- WSP_GGML_API void gguf_set_arr_str (struct gguf_context * ctx, const char * key, const char ** data, int n);
-
- // set or add KV pairs from another context
- WSP_GGML_API void gguf_set_kv(struct gguf_context * ctx, struct gguf_context * src);
-
- // manage tensor info
- WSP_GGML_API void gguf_add_tensor(struct gguf_context * ctx, const struct wsp_ggml_tensor * tensor);
- WSP_GGML_API void gguf_set_tensor_type(struct gguf_context * ctx, const char * name, enum wsp_ggml_type type);
- WSP_GGML_API void gguf_set_tensor_data(struct gguf_context * ctx, const char * name, const void * data, size_t size);
-
- // writing gguf files can be done in 2 ways:
- //
- // - write the entire gguf_context to a binary file in a single pass:
- //
- // gguf_write_to_file(ctx, fname);
- //
- // - first prepare a file with a placeholder for the meta data, write the tensor data, then write the meta data:
- //
- // FILE * f = fopen(fname, "wb");
- // fseek(f, gguf_get_meta_size(ctx), SEEK_SET);
- // fwrite(f, ...);
- // void * data = gguf_meta_get_meta_data(ctx);
- // fseek(f, 0, SEEK_SET);
- // fwrite(f, data, gguf_get_meta_size(ctx));
- // free(data);
- // fclose(f);
- //
-
- // write the entire context to a binary file
- WSP_GGML_API void gguf_write_to_file(const struct gguf_context * ctx, const char * fname, bool only_meta);
-
- // get the size in bytes of the meta data (header, kv pairs, tensor info) including padding
- WSP_GGML_API size_t gguf_get_meta_size(const struct gguf_context * ctx);
- WSP_GGML_API void gguf_get_meta_data(const struct gguf_context * ctx, void * data);
-
- //
- // system info
- //
+ // threadpool params
+ // Use wsp_ggml_threadpool_params_default() or wsp_ggml_threadpool_params_init() to populate the defaults
+ struct wsp_ggml_threadpool_params {
+ bool cpumask[WSP_GGML_MAX_N_THREADS]; // mask of cpu cores (all-zeros means use default affinity settings)
+ int n_threads; // number of threads
+ enum wsp_ggml_sched_priority prio; // thread priority
+ uint32_t poll; // polling level (0 - no polling, 100 - aggressive polling)
+ bool strict_cpu; // strict cpu placement
+ bool paused; // start in paused state
+ };

- WSP_GGML_API int wsp_ggml_cpu_has_avx (void);
- WSP_GGML_API int wsp_ggml_cpu_has_avx2 (void);
- WSP_GGML_API int wsp_ggml_cpu_has_avx512 (void);
- WSP_GGML_API int wsp_ggml_cpu_has_avx512_vbmi(void);
- WSP_GGML_API int wsp_ggml_cpu_has_avx512_vnni(void);
- WSP_GGML_API int wsp_ggml_cpu_has_fma (void);
- WSP_GGML_API int wsp_ggml_cpu_has_neon (void);
- WSP_GGML_API int wsp_ggml_cpu_has_arm_fma (void);
- WSP_GGML_API int wsp_ggml_cpu_has_metal (void);
- WSP_GGML_API int wsp_ggml_cpu_has_f16c (void);
- WSP_GGML_API int wsp_ggml_cpu_has_fp16_va (void);
- WSP_GGML_API int wsp_ggml_cpu_has_wasm_simd (void);
- WSP_GGML_API int wsp_ggml_cpu_has_blas (void);
- WSP_GGML_API int wsp_ggml_cpu_has_cublas (void);
- WSP_GGML_API int wsp_ggml_cpu_has_clblast (void);
- WSP_GGML_API int wsp_ggml_cpu_has_gpublas (void);
- WSP_GGML_API int wsp_ggml_cpu_has_sse3 (void);
- WSP_GGML_API int wsp_ggml_cpu_has_ssse3 (void);
- WSP_GGML_API int wsp_ggml_cpu_has_vsx (void);
+ struct wsp_ggml_threadpool; // forward declaration, see ggml.c

- //
- // Internal types and functions exposed for tests and benchmarks
- //
+ typedef struct wsp_ggml_threadpool * wsp_ggml_threadpool_t;

- #ifdef __cplusplus
- // restrict not standard in C++
- #define WSP_GGML_RESTRICT
- #else
- #define WSP_GGML_RESTRICT restrict
- #endif
- typedef void (*wsp_ggml_to_float_t) (const void * WSP_GGML_RESTRICT x, float * WSP_GGML_RESTRICT y, int k);
- typedef void (*wsp_ggml_from_float_t)(const float * WSP_GGML_RESTRICT x, void * WSP_GGML_RESTRICT y, int k);
- typedef void (*wsp_ggml_vec_dot_t) (const int n, float * WSP_GGML_RESTRICT s, const void * WSP_GGML_RESTRICT x, const void * WSP_GGML_RESTRICT y);
-
- typedef struct {
- const char * type_name;
- int blck_size;
- size_t type_size;
- bool is_quantized;
- wsp_ggml_to_float_t to_float;
- wsp_ggml_from_float_t from_float;
- wsp_ggml_from_float_t from_float_reference;
- wsp_ggml_vec_dot_t vec_dot;
- enum wsp_ggml_type vec_dot_type;
- } wsp_ggml_type_traits_t;
-
- wsp_ggml_type_traits_t wsp_ggml_internal_get_type_traits(enum wsp_ggml_type type);
+ WSP_GGML_API struct wsp_ggml_threadpool_params wsp_ggml_threadpool_params_default(int n_threads);
+ WSP_GGML_API void wsp_ggml_threadpool_params_init (struct wsp_ggml_threadpool_params * p, int n_threads);
+ WSP_GGML_API bool wsp_ggml_threadpool_params_match (const struct wsp_ggml_threadpool_params * p0, const struct wsp_ggml_threadpool_params * p1);
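A hypothetical sketch (not part of the diff) of populating the new threadpool params and tweaking a few of the fields declared above; creating and attaching the pool itself lives in the CPU backend API (ggml-cpu.h in this package), per the TODO note.

// hypothetical usage sketch -- not part of the diffed header
struct wsp_ggml_threadpool_params tpp = wsp_ggml_threadpool_params_default(8); // 8 threads, default cpumask
tpp.prio       = WSP_GGML_SCHED_PRIO_HIGH;  // raise thread priority
tpp.poll       = 50;                        // moderate polling (0 = none, 100 = aggressive)
tpp.strict_cpu = true;                      // request strict cpu placement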

  #ifdef __cplusplus
  }