whisper.rn 0.4.0-rc.1 → 0.4.0-rc.10

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (75)
  1. package/README.md +6 -6
  2. package/android/build.gradle +4 -0
  3. package/android/src/main/CMakeLists.txt +14 -0
  4. package/android/src/main/java/com/rnwhisper/AudioUtils.java +27 -92
  5. package/android/src/main/java/com/rnwhisper/RNWhisper.java +86 -40
  6. package/android/src/main/java/com/rnwhisper/WhisperContext.java +85 -131
  7. package/android/src/main/jni-utils.h +76 -0
  8. package/android/src/main/jni.cpp +226 -109
  9. package/android/src/newarch/java/com/rnwhisper/RNWhisperModule.java +10 -0
  10. package/android/src/oldarch/java/com/rnwhisper/RNWhisperModule.java +10 -0
  11. package/cpp/README.md +1 -1
  12. package/cpp/coreml/whisper-encoder-impl.h +1 -1
  13. package/cpp/coreml/whisper-encoder.h +4 -0
  14. package/cpp/coreml/whisper-encoder.mm +5 -3
  15. package/cpp/ggml-aarch64.c +129 -0
  16. package/cpp/ggml-aarch64.h +19 -0
  17. package/cpp/ggml-alloc.c +805 -400
  18. package/cpp/ggml-alloc.h +60 -10
  19. package/cpp/ggml-backend-impl.h +216 -0
  20. package/cpp/ggml-backend-reg.cpp +204 -0
  21. package/cpp/ggml-backend.cpp +1996 -0
  22. package/cpp/ggml-backend.cpp.rej +12 -0
  23. package/cpp/ggml-backend.h +336 -0
  24. package/cpp/ggml-common.h +1853 -0
  25. package/cpp/ggml-cpp.h +38 -0
  26. package/cpp/ggml-cpu-aarch64.c +3560 -0
  27. package/cpp/ggml-cpu-aarch64.h +30 -0
  28. package/cpp/ggml-cpu-impl.h +371 -0
  29. package/cpp/ggml-cpu-quants.c +10822 -0
  30. package/cpp/ggml-cpu-quants.h +63 -0
  31. package/cpp/ggml-cpu.c +13970 -0
  32. package/cpp/ggml-cpu.cpp +663 -0
  33. package/cpp/ggml-cpu.h +177 -0
  34. package/cpp/ggml-impl.h +551 -0
  35. package/cpp/ggml-metal-impl.h +249 -0
  36. package/cpp/ggml-metal.h +24 -43
  37. package/cpp/ggml-metal.m +4190 -1075
  38. package/cpp/ggml-quants.c +5247 -0
  39. package/cpp/ggml-quants.h +100 -0
  40. package/cpp/ggml-threading.cpp +12 -0
  41. package/cpp/ggml-threading.h +12 -0
  42. package/cpp/ggml-whisper.metallib +0 -0
  43. package/cpp/ggml.c +5474 -18763
  44. package/cpp/ggml.h +833 -628
  45. package/cpp/rn-audioutils.cpp +68 -0
  46. package/cpp/rn-audioutils.h +14 -0
  47. package/cpp/rn-whisper-log.h +11 -0
  48. package/cpp/rn-whisper.cpp +221 -52
  49. package/cpp/rn-whisper.h +50 -15
  50. package/cpp/whisper.cpp +2872 -1371
  51. package/cpp/whisper.h +170 -41
  52. package/ios/RNWhisper.mm +139 -46
  53. package/ios/RNWhisperAudioUtils.h +1 -2
  54. package/ios/RNWhisperAudioUtils.m +18 -67
  55. package/ios/RNWhisperContext.h +11 -8
  56. package/ios/RNWhisperContext.mm +195 -150
  57. package/jest/mock.js +15 -2
  58. package/lib/commonjs/NativeRNWhisper.js.map +1 -1
  59. package/lib/commonjs/index.js +76 -28
  60. package/lib/commonjs/index.js.map +1 -1
  61. package/lib/commonjs/version.json +1 -1
  62. package/lib/module/NativeRNWhisper.js.map +1 -1
  63. package/lib/module/index.js +76 -28
  64. package/lib/module/index.js.map +1 -1
  65. package/lib/module/version.json +1 -1
  66. package/lib/typescript/NativeRNWhisper.d.ts +13 -4
  67. package/lib/typescript/NativeRNWhisper.d.ts.map +1 -1
  68. package/lib/typescript/index.d.ts +37 -5
  69. package/lib/typescript/index.d.ts.map +1 -1
  70. package/package.json +9 -7
  71. package/src/NativeRNWhisper.ts +20 -4
  72. package/src/index.ts +98 -42
  73. package/src/version.json +1 -1
  74. package/whisper-rn.podspec +11 -18
  75. package/cpp/ggml-metal.metal +0 -2353
package/cpp/ggml.h CHANGED
@@ -58,7 +58,8 @@
  // {
  // ...
  //
- // struct wsp_ggml_cgraph gf = wsp_ggml_build_forward(f);
+ // struct wsp_ggml_cgraph * gf = wsp_ggml_new_graph(ctx);
+ // wsp_ggml_build_forward_expand(gf, f);
  //
  // // set the input variable and parameter values
  // wsp_ggml_set_f32(x, 2.0f);
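Note on the hunk above: graph construction is no longer by-value. wsp_ggml_build_forward() returned a struct wsp_ggml_cgraph; graphs are now allocated from the context and expanded in place. A minimal migration sketch in C, assuming a context ctx and a result tensor f built with this header (wsp_ggml_new_graph is referenced by the updated comment but not declared in this hunk):

    // before (0.4.0-rc.1): graph returned by value
    // struct wsp_ggml_cgraph gf = wsp_ggml_build_forward(f);

    // after (0.4.0-rc.10): allocate the graph from the context, then expand it
    struct wsp_ggml_cgraph * gf = wsp_ggml_new_graph(ctx);
    wsp_ggml_build_forward_expand(gf, f);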
@@ -175,15 +176,15 @@
  #ifdef WSP_GGML_SHARED
  # if defined(_WIN32) && !defined(__MINGW32__)
  # ifdef WSP_GGML_BUILD
- # define WSP_GGML_API __declspec(dllexport)
+ # define WSP_GGML_API __declspec(dllexport) extern
  # else
- # define WSP_GGML_API __declspec(dllimport)
+ # define WSP_GGML_API __declspec(dllimport) extern
  # endif
  # else
- # define WSP_GGML_API __attribute__ ((visibility ("default")))
+ # define WSP_GGML_API __attribute__ ((visibility ("default"))) extern
  # endif
  #else
- # define WSP_GGML_API
+ # define WSP_GGML_API extern
  #endif

  // TODO: support for clang
@@ -203,24 +204,29 @@
  # define WSP_GGML_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__)))
  #endif

- #include <stdint.h>
- #include <stddef.h>
  #include <stdbool.h>
+ #include <stddef.h>
+ #include <stdint.h>
+ #include <stdio.h>

  #define WSP_GGML_FILE_MAGIC 0x67676d6c // "ggml"
- #define WSP_GGML_FILE_VERSION 1
+ #define WSP_GGML_FILE_VERSION 2

  #define WSP_GGML_QNT_VERSION 2 // bump this on quantization format changes
  #define WSP_GGML_QNT_VERSION_FACTOR 1000 // do not change this

- #define WSP_GGML_MAX_DIMS 4
- #define WSP_GGML_MAX_NODES 4096
- #define WSP_GGML_MAX_PARAMS 256
- #define WSP_GGML_MAX_CONTEXTS 64
- #define WSP_GGML_MAX_SRC 6
- #define WSP_GGML_MAX_NAME 64
- #define WSP_GGML_MAX_OP_PARAMS 32
- #define WSP_GGML_DEFAULT_N_THREADS 4
+ #define WSP_GGML_MAX_DIMS 4
+ #define WSP_GGML_MAX_PARAMS 2048
+ #define WSP_GGML_MAX_SRC 10
+ #define WSP_GGML_MAX_N_THREADS 512
+ #define WSP_GGML_MAX_OP_PARAMS 64
+
+ #ifndef WSP_GGML_MAX_NAME
+ # define WSP_GGML_MAX_NAME 64
+ #endif
+
+ #define WSP_GGML_DEFAULT_N_THREADS 4
+ #define WSP_GGML_DEFAULT_GRAPH_SIZE 2048

  #if UINTPTR_MAX == 0xFFFFFFFF
  #define WSP_GGML_MEM_ALIGN 4
@@ -231,22 +237,38 @@
  #define WSP_GGML_EXIT_SUCCESS 0
  #define WSP_GGML_EXIT_ABORTED 1

- #define GGUF_MAGIC 0x46554747 // "GGUF"
- #define GGUF_VERSION 2
+ #define WSP_GGML_ROPE_TYPE_NEOX 2
+
+ #define WSP_GGUF_MAGIC "GGUF"

- #define GGUF_DEFAULT_ALIGNMENT 32
+ #define WSP_GGUF_VERSION 3
+
+ #define WSP_GGUF_DEFAULT_ALIGNMENT 32

  #define WSP_GGML_UNUSED(x) (void)(x)

  #define WSP_GGML_PAD(x, n) (((x) + (n) - 1) & ~((n) - 1))

- #define WSP_GGML_ASSERT(x) \
- do { \
- if (!(x)) { \
- fprintf(stderr, "WSP_GGML_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \
- abort(); \
- } \
- } while (0)
+ #ifndef NDEBUG
+ # define WSP_GGML_UNREACHABLE() do { fprintf(stderr, "statement should be unreachable\n"); abort(); } while(0)
+ #elif defined(__GNUC__)
+ # define WSP_GGML_UNREACHABLE() __builtin_unreachable()
+ #elif defined(_MSC_VER)
+ # define WSP_GGML_UNREACHABLE() __assume(0)
+ #else
+ # define WSP_GGML_UNREACHABLE() ((void) 0)
+ #endif
+
+ #ifdef __cplusplus
+ # define WSP_GGML_NORETURN [[noreturn]]
+ #elif defined(_MSC_VER)
+ # define WSP_GGML_NORETURN __declspec(noreturn)
+ #else
+ # define WSP_GGML_NORETURN _Noreturn
+ #endif
+
+ #define WSP_GGML_ABORT(...) wsp_ggml_abort(__FILE__, __LINE__, __VA_ARGS__)
+ #define WSP_GGML_ASSERT(x) if (!(x)) WSP_GGML_ABORT("WSP_GGML_ASSERT(%s) failed", #x)

  // used to copy the number of elements and stride in bytes of tensors into local variables.
  // main purpose is to reduce code duplication and improve readability.
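The assertion machinery above is a behavior change worth calling out: the old WSP_GGML_ASSERT expanded to an inline fprintf/abort, while the new one routes through wsp_ggml_abort() (declared in a later hunk) via WSP_GGML_ABORT, so failures carry a printf-style message. A sketch of the expansion, under the macro definitions shown above:

    // WSP_GGML_ASSERT(ne0 > 0) now expands to:
    if (!(ne0 > 0)) WSP_GGML_ABORT("WSP_GGML_ASSERT(%s) failed", "ne0 > 0");
    // ...which calls wsp_ggml_abort(__FILE__, __LINE__, ...) and never returns.

    // callers can also abort directly with their own formatted message:
    WSP_GGML_ABORT("unsupported n_dims: %d", n_dims);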
@@ -272,74 +294,145 @@
  const type prefix##3 = (pointer)->array[3]; \
  WSP_GGML_UNUSED(prefix##3);

+ #define WSP_GGML_TENSOR_UNARY_OP_LOCALS \
+ WSP_GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \
+ WSP_GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) \
+ WSP_GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) \
+ WSP_GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
+
+ #define WSP_GGML_TENSOR_BINARY_OP_LOCALS \
+ WSP_GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \
+ WSP_GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) \
+ WSP_GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne) \
+ WSP_GGML_TENSOR_LOCALS(size_t, nb1, src1, nb) \
+ WSP_GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) \
+ WSP_GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
+
+ #define WSP_GGML_TENSOR_BINARY_OP_LOCALS01 \
+ WSP_GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \
+ WSP_GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) \
+ WSP_GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne) \
+ WSP_GGML_TENSOR_LOCALS(size_t, nb1, src1, nb)
+
  #ifdef __cplusplus
  extern "C" {
  #endif

- #if defined(__ARM_NEON) && defined(__CUDACC__)
- typedef half wsp_ggml_fp16_t;
- #elif defined(__ARM_NEON)
- typedef __fp16 wsp_ggml_fp16_t;
- #else
- typedef uint16_t wsp_ggml_fp16_t;
- #endif
+ WSP_GGML_NORETURN WSP_GGML_ATTRIBUTE_FORMAT(3, 4)
+ WSP_GGML_API void wsp_ggml_abort(const char * file, int line, const char * fmt, ...);

- // convert FP16 <-> FP32
- WSP_GGML_API float wsp_ggml_fp16_to_fp32(wsp_ggml_fp16_t x);
- WSP_GGML_API wsp_ggml_fp16_t wsp_ggml_fp32_to_fp16(float x);
+ enum wsp_ggml_status {
+ WSP_GGML_STATUS_ALLOC_FAILED = -2,
+ WSP_GGML_STATUS_FAILED = -1,
+ WSP_GGML_STATUS_SUCCESS = 0,
+ WSP_GGML_STATUS_ABORTED = 1,
+ };

- WSP_GGML_API void wsp_ggml_fp16_to_fp32_row(const wsp_ggml_fp16_t * x, float * y, int n);
- WSP_GGML_API void wsp_ggml_fp32_to_fp16_row(const float * x, wsp_ggml_fp16_t * y, int n);
+ // get wsp_ggml_status name string
+ WSP_GGML_API const char * wsp_ggml_status_to_string(enum wsp_ggml_status status);
+
+ // ieee 754-2008 half-precision float16
+ // todo: make this not an integral type
+ typedef uint16_t wsp_ggml_fp16_t;
+ WSP_GGML_API float wsp_ggml_fp16_to_fp32(wsp_ggml_fp16_t);
+ WSP_GGML_API wsp_ggml_fp16_t wsp_ggml_fp32_to_fp16(float);
+ WSP_GGML_API void wsp_ggml_fp16_to_fp32_row(const wsp_ggml_fp16_t *, float *, int64_t);
+ WSP_GGML_API void wsp_ggml_fp32_to_fp16_row(const float *, wsp_ggml_fp16_t *, int64_t);
+
+ // google brain half-precision bfloat16
+ typedef struct { uint16_t bits; } wsp_ggml_bf16_t;
+ WSP_GGML_API wsp_ggml_bf16_t wsp_ggml_fp32_to_bf16(float);
+ WSP_GGML_API float wsp_ggml_bf16_to_fp32(wsp_ggml_bf16_t); // consider just doing << 16
+ WSP_GGML_API void wsp_ggml_bf16_to_fp32_row(const wsp_ggml_bf16_t *, float *, int64_t);
+ WSP_GGML_API void wsp_ggml_fp32_to_bf16_row_ref(const float *, wsp_ggml_bf16_t *, int64_t);
+ WSP_GGML_API void wsp_ggml_fp32_to_bf16_row(const float *, wsp_ggml_bf16_t *, int64_t);

  struct wsp_ggml_object;
  struct wsp_ggml_context;
+ struct wsp_ggml_cgraph;

+ // NOTE: always add types at the end of the enum to keep backward compatibility
  enum wsp_ggml_type {
- WSP_GGML_TYPE_F32 = 0,
- WSP_GGML_TYPE_F16 = 1,
- WSP_GGML_TYPE_Q4_0 = 2,
- WSP_GGML_TYPE_Q4_1 = 3,
+ WSP_GGML_TYPE_F32 = 0,
+ WSP_GGML_TYPE_F16 = 1,
+ WSP_GGML_TYPE_Q4_0 = 2,
+ WSP_GGML_TYPE_Q4_1 = 3,
  // WSP_GGML_TYPE_Q4_2 = 4, support has been removed
- // WSP_GGML_TYPE_Q4_3 (5) support has been removed
- WSP_GGML_TYPE_Q5_0 = 6,
- WSP_GGML_TYPE_Q5_1 = 7,
- WSP_GGML_TYPE_Q8_0 = 8,
- WSP_GGML_TYPE_Q8_1 = 9,
- // k-quantizations
- WSP_GGML_TYPE_Q2_K = 10,
- WSP_GGML_TYPE_Q3_K = 11,
- WSP_GGML_TYPE_Q4_K = 12,
- WSP_GGML_TYPE_Q5_K = 13,
- WSP_GGML_TYPE_Q6_K = 14,
- WSP_GGML_TYPE_Q8_K = 15,
- WSP_GGML_TYPE_I8,
- WSP_GGML_TYPE_I16,
- WSP_GGML_TYPE_I32,
+ // WSP_GGML_TYPE_Q4_3 = 5, support has been removed
+ WSP_GGML_TYPE_Q5_0 = 6,
+ WSP_GGML_TYPE_Q5_1 = 7,
+ WSP_GGML_TYPE_Q8_0 = 8,
+ WSP_GGML_TYPE_Q8_1 = 9,
+ WSP_GGML_TYPE_Q2_K = 10,
+ WSP_GGML_TYPE_Q3_K = 11,
+ WSP_GGML_TYPE_Q4_K = 12,
+ WSP_GGML_TYPE_Q5_K = 13,
+ WSP_GGML_TYPE_Q6_K = 14,
+ WSP_GGML_TYPE_Q8_K = 15,
+ WSP_GGML_TYPE_IQ2_XXS = 16,
+ WSP_GGML_TYPE_IQ2_XS = 17,
+ WSP_GGML_TYPE_IQ3_XXS = 18,
+ WSP_GGML_TYPE_IQ1_S = 19,
+ WSP_GGML_TYPE_IQ4_NL = 20,
+ WSP_GGML_TYPE_IQ3_S = 21,
+ WSP_GGML_TYPE_IQ2_S = 22,
+ WSP_GGML_TYPE_IQ4_XS = 23,
+ WSP_GGML_TYPE_I8 = 24,
+ WSP_GGML_TYPE_I16 = 25,
+ WSP_GGML_TYPE_I32 = 26,
+ WSP_GGML_TYPE_I64 = 27,
+ WSP_GGML_TYPE_F64 = 28,
+ WSP_GGML_TYPE_IQ1_M = 29,
+ WSP_GGML_TYPE_BF16 = 30,
+ WSP_GGML_TYPE_Q4_0_4_4 = 31,
+ WSP_GGML_TYPE_Q4_0_4_8 = 32,
+ WSP_GGML_TYPE_Q4_0_8_8 = 33,
+ WSP_GGML_TYPE_TQ1_0 = 34,
+ WSP_GGML_TYPE_TQ2_0 = 35,
  WSP_GGML_TYPE_COUNT,
  };

- enum wsp_ggml_backend {
- WSP_GGML_BACKEND_CPU = 0,
- WSP_GGML_BACKEND_GPU = 10,
- WSP_GGML_BACKEND_GPU_SPLIT = 20,
+ // precision
+ enum wsp_ggml_prec {
+ WSP_GGML_PREC_DEFAULT,
+ WSP_GGML_PREC_F32,
+ };
+
+ enum wsp_ggml_backend_type {
+ WSP_GGML_BACKEND_TYPE_CPU = 0,
+ WSP_GGML_BACKEND_TYPE_GPU = 10,
+ WSP_GGML_BACKEND_TYPE_GPU_SPLIT = 20,
  };

  // model file types
  enum wsp_ggml_ftype {
- WSP_GGML_FTYPE_UNKNOWN = -1,
- WSP_GGML_FTYPE_ALL_F32 = 0,
- WSP_GGML_FTYPE_MOSTLY_F16 = 1, // except 1d tensors
- WSP_GGML_FTYPE_MOSTLY_Q4_0 = 2, // except 1d tensors
- WSP_GGML_FTYPE_MOSTLY_Q4_1 = 3, // except 1d tensors
+ WSP_GGML_FTYPE_UNKNOWN = -1,
+ WSP_GGML_FTYPE_ALL_F32 = 0,
+ WSP_GGML_FTYPE_MOSTLY_F16 = 1, // except 1d tensors
+ WSP_GGML_FTYPE_MOSTLY_Q4_0 = 2, // except 1d tensors
+ WSP_GGML_FTYPE_MOSTLY_Q4_1 = 3, // except 1d tensors
  WSP_GGML_FTYPE_MOSTLY_Q4_1_SOME_F16 = 4, // tok_embeddings.weight and output.weight are F16
- WSP_GGML_FTYPE_MOSTLY_Q8_0 = 7, // except 1d tensors
- WSP_GGML_FTYPE_MOSTLY_Q5_0 = 8, // except 1d tensors
- WSP_GGML_FTYPE_MOSTLY_Q5_1 = 9, // except 1d tensors
- WSP_GGML_FTYPE_MOSTLY_Q2_K = 10, // except 1d tensors
- WSP_GGML_FTYPE_MOSTLY_Q3_K = 11, // except 1d tensors
- WSP_GGML_FTYPE_MOSTLY_Q4_K = 12, // except 1d tensors
- WSP_GGML_FTYPE_MOSTLY_Q5_K = 13, // except 1d tensors
- WSP_GGML_FTYPE_MOSTLY_Q6_K = 14, // except 1d tensors
+ WSP_GGML_FTYPE_MOSTLY_Q8_0 = 7, // except 1d tensors
+ WSP_GGML_FTYPE_MOSTLY_Q5_0 = 8, // except 1d tensors
+ WSP_GGML_FTYPE_MOSTLY_Q5_1 = 9, // except 1d tensors
+ WSP_GGML_FTYPE_MOSTLY_Q2_K = 10, // except 1d tensors
+ WSP_GGML_FTYPE_MOSTLY_Q3_K = 11, // except 1d tensors
+ WSP_GGML_FTYPE_MOSTLY_Q4_K = 12, // except 1d tensors
+ WSP_GGML_FTYPE_MOSTLY_Q5_K = 13, // except 1d tensors
+ WSP_GGML_FTYPE_MOSTLY_Q6_K = 14, // except 1d tensors
+ WSP_GGML_FTYPE_MOSTLY_IQ2_XXS = 15, // except 1d tensors
+ WSP_GGML_FTYPE_MOSTLY_IQ2_XS = 16, // except 1d tensors
+ WSP_GGML_FTYPE_MOSTLY_IQ3_XXS = 17, // except 1d tensors
+ WSP_GGML_FTYPE_MOSTLY_IQ1_S = 18, // except 1d tensors
+ WSP_GGML_FTYPE_MOSTLY_IQ4_NL = 19, // except 1d tensors
+ WSP_GGML_FTYPE_MOSTLY_IQ3_S = 20, // except 1d tensors
+ WSP_GGML_FTYPE_MOSTLY_IQ2_S = 21, // except 1d tensors
+ WSP_GGML_FTYPE_MOSTLY_IQ4_XS = 22, // except 1d tensors
+ WSP_GGML_FTYPE_MOSTLY_IQ1_M = 23, // except 1d tensors
+ WSP_GGML_FTYPE_MOSTLY_BF16 = 24, // except 1d tensors
+ WSP_GGML_FTYPE_MOSTLY_Q4_0_4_4 = 25, // except 1d tensors
+ WSP_GGML_FTYPE_MOSTLY_Q4_0_4_8 = 26, // except 1d tensors
+ WSP_GGML_FTYPE_MOSTLY_Q4_0_8_8 = 27, // except 1d tensors
  };

  // available tensor operations:
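Two migration points from the hunk above: wsp_ggml_fp16_t is now unconditionally a uint16_t storage type (the __fp16/half typedefs are gone), so arithmetic must round-trip through the conversion helpers, and the row converters now take int64_t counts instead of int. A small sketch, assuming only declarations from this header:

    // scalar round-trip through the fp16 storage type
    wsp_ggml_fp16_t h = wsp_ggml_fp32_to_fp16(1.5f);
    float f = wsp_ggml_fp16_to_fp32(h); // 1.5f

    // row conversion: counts are int64_t in the new API
    float src[8] = {0};
    wsp_ggml_bf16_t dst[8];
    wsp_ggml_fp32_to_bf16_row(src, dst, (int64_t) 8); // bf16 is roughly fp32 truncated to the high 16 bits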
@@ -356,10 +449,13 @@ extern "C" {
  WSP_GGML_OP_SQR,
  WSP_GGML_OP_SQRT,
  WSP_GGML_OP_LOG,
+ WSP_GGML_OP_SIN,
+ WSP_GGML_OP_COS,
  WSP_GGML_OP_SUM,
  WSP_GGML_OP_SUM_ROWS,
  WSP_GGML_OP_MEAN,
  WSP_GGML_OP_ARGMAX,
+ WSP_GGML_OP_COUNT_EQUAL,
  WSP_GGML_OP_REPEAT,
  WSP_GGML_OP_REPEAT_BACK,
  WSP_GGML_OP_CONCAT,
@@ -370,6 +466,7 @@ extern "C" {
  WSP_GGML_OP_GROUP_NORM,

  WSP_GGML_OP_MUL_MAT,
+ WSP_GGML_OP_MUL_MAT_ID,
  WSP_GGML_OP_OUT_PROD,

  WSP_GGML_OP_SCALE,
@@ -389,23 +486,30 @@ extern "C" {
  WSP_GGML_OP_SOFT_MAX_BACK,
  WSP_GGML_OP_ROPE,
  WSP_GGML_OP_ROPE_BACK,
- WSP_GGML_OP_ALIBI,
  WSP_GGML_OP_CLAMP,
- WSP_GGML_OP_CONV_1D,
- WSP_GGML_OP_CONV_2D,
+ WSP_GGML_OP_CONV_TRANSPOSE_1D,
+ WSP_GGML_OP_IM2COL,
+ WSP_GGML_OP_IM2COL_BACK,
  WSP_GGML_OP_CONV_TRANSPOSE_2D,
  WSP_GGML_OP_POOL_1D,
  WSP_GGML_OP_POOL_2D,
-
+ WSP_GGML_OP_POOL_2D_BACK,
  WSP_GGML_OP_UPSCALE, // nearest interpolate
+ WSP_GGML_OP_PAD,
+ WSP_GGML_OP_ARANGE,
+ WSP_GGML_OP_TIMESTEP_EMBEDDING,
+ WSP_GGML_OP_ARGSORT,
+ WSP_GGML_OP_LEAKY_RELU,

- WSP_GGML_OP_FLASH_ATTN,
- WSP_GGML_OP_FLASH_FF,
+ WSP_GGML_OP_FLASH_ATTN_EXT,
  WSP_GGML_OP_FLASH_ATTN_BACK,
+ WSP_GGML_OP_SSM_CONV,
+ WSP_GGML_OP_SSM_SCAN,
  WSP_GGML_OP_WIN_PART,
  WSP_GGML_OP_WIN_UNPART,
  WSP_GGML_OP_GET_REL_POS,
  WSP_GGML_OP_ADD_REL_POS,
+ WSP_GGML_OP_RWKV_WKV6,

  WSP_GGML_OP_UNARY,
@@ -422,6 +526,7 @@ extern "C" {

  WSP_GGML_OP_CROSS_ENTROPY_LOSS,
  WSP_GGML_OP_CROSS_ENTROPY_LOSS_BACK,
+ WSP_GGML_OP_OPT_STEP_ADAMW,

  WSP_GGML_OP_COUNT,
  };
@@ -434,41 +539,59 @@ extern "C" {
  WSP_GGML_UNARY_OP_TANH,
  WSP_GGML_UNARY_OP_ELU,
  WSP_GGML_UNARY_OP_RELU,
+ WSP_GGML_UNARY_OP_SIGMOID,
  WSP_GGML_UNARY_OP_GELU,
  WSP_GGML_UNARY_OP_GELU_QUICK,
  WSP_GGML_UNARY_OP_SILU,
+ WSP_GGML_UNARY_OP_HARDSWISH,
+ WSP_GGML_UNARY_OP_HARDSIGMOID,
+ WSP_GGML_UNARY_OP_EXP,
+
+ WSP_GGML_UNARY_OP_COUNT,
  };

  enum wsp_ggml_object_type {
- WSP_GGML_OBJECT_TENSOR,
- WSP_GGML_OBJECT_GRAPH,
- WSP_GGML_OBJECT_WORK_BUFFER
+ WSP_GGML_OBJECT_TYPE_TENSOR,
+ WSP_GGML_OBJECT_TYPE_GRAPH,
+ WSP_GGML_OBJECT_TYPE_WORK_BUFFER
  };

- // ggml object
- struct wsp_ggml_object {
- size_t offs;
- size_t size;
-
- struct wsp_ggml_object * next;
-
- enum wsp_ggml_object_type type;
+ enum wsp_ggml_log_level {
+ WSP_GGML_LOG_LEVEL_NONE = 0,
+ WSP_GGML_LOG_LEVEL_DEBUG = 1,
+ WSP_GGML_LOG_LEVEL_INFO = 2,
+ WSP_GGML_LOG_LEVEL_WARN = 3,
+ WSP_GGML_LOG_LEVEL_ERROR = 4,
+ WSP_GGML_LOG_LEVEL_CONT = 5, // continue previous log
+ };

- char padding[4];
+ // this tensor...
+ enum wsp_ggml_tensor_flag {
+ WSP_GGML_TENSOR_FLAG_INPUT = 1, // ...is an input for the GGML compute graph
+ WSP_GGML_TENSOR_FLAG_OUTPUT = 2, // ...is an output for the GGML compute graph
+ WSP_GGML_TENSOR_FLAG_PARAM = 4, // ...contains trainable parameters
+ WSP_GGML_TENSOR_FLAG_LOSS = 8, // ...defines loss for numerical optimization (multiple loss tensors add up)
  };

- static const size_t WSP_GGML_OBJECT_SIZE = sizeof(struct wsp_ggml_object);
+ struct wsp_ggml_init_params {
+ // memory pool
+ size_t mem_size; // bytes
+ void * mem_buffer; // if NULL, memory will be allocated internally
+ bool no_alloc; // don't allocate memory for the tensor data
+ };

  // n-dimensional tensor
  struct wsp_ggml_tensor {
- enum wsp_ggml_type type;
- enum wsp_ggml_backend backend;
+ enum wsp_ggml_type type;
+
+ WSP_GGML_DEPRECATED(enum wsp_ggml_backend_type backend, "use the buffer type to find the storage location of the tensor");
+
+ struct wsp_ggml_backend_buffer * buffer;

- int n_dims;
  int64_t ne[WSP_GGML_MAX_DIMS]; // number of elements
  size_t nb[WSP_GGML_MAX_DIMS]; // stride in bytes:
- // nb[0] = sizeof(type)
- // nb[1] = nb[0] * ne[0] + padding
+ // nb[0] = wsp_ggml_type_size(type)
+ // nb[1] = nb[0] * (ne[0] / wsp_ggml_blck_size(type)) + padding
  // nb[i] = nb[i-1] * ne[i-1]

  // compute data
@@ -477,16 +600,11 @@ extern "C" {
  // op params - allocated as int32_t for alignment
  int32_t op_params[WSP_GGML_MAX_OP_PARAMS / sizeof(int32_t)];

- bool is_param;
+ int32_t flags;

- struct wsp_ggml_tensor * grad;
  struct wsp_ggml_tensor * src[WSP_GGML_MAX_SRC];

- // performance
- int perf_runs;
- int64_t perf_cycles;
- int64_t perf_time_us;
-
+ // source tensor and offset for views
  struct wsp_ggml_tensor * view_src;
  size_t view_offs;

@@ -496,86 +614,26 @@ extern "C" {

  void * extra; // extra things e.g. for ggml-cuda.cu

- char padding[4];
+ char padding[8];
  };

  static const size_t WSP_GGML_TENSOR_SIZE = sizeof(struct wsp_ggml_tensor);

- // the compute plan that needs to be prepared for wsp_ggml_graph_compute()
- // since https://github.com/ggerganov/ggml/issues/287
- struct wsp_ggml_cplan {
- size_t work_size; // size of work buffer, calculated by `wsp_ggml_graph_plan()`
- uint8_t * work_data; // work buffer, to be allocated by caller before calling to `wsp_ggml_graph_compute()`
-
- int n_threads;
-
- // the `n_tasks` of nodes, 1:1 mapping to cgraph nodes
- int n_tasks[WSP_GGML_MAX_NODES];
-
- // abort wsp_ggml_graph_compute when true
- bool (*abort_callback)(void * data);
- void * abort_callback_data;
- };
-
- // next prime after WSP_GGML_MAX_NODES
- // #define WSP_GGML_GRAPH_HASHTABLE_SIZE 4099
- // next prime after WSP_GGML_MAX_NODES * 2 (nodes + leafs)
- #define WSP_GGML_GRAPH_HASHTABLE_SIZE 8273
-
- // computation graph
- struct wsp_ggml_cgraph {
- int n_nodes;
- int n_leafs;
+ // Abort callback
+ // If not NULL, called before ggml computation
+ // If it returns true, the computation is aborted
+ typedef bool (*wsp_ggml_abort_callback)(void * data);

- struct wsp_ggml_tensor * nodes[WSP_GGML_MAX_NODES];
- struct wsp_ggml_tensor * grads[WSP_GGML_MAX_NODES];
- struct wsp_ggml_tensor * leafs[WSP_GGML_MAX_NODES];

- void * visited_hash_table[WSP_GGML_GRAPH_HASHTABLE_SIZE];
-
- // performance
- int perf_runs;
- int64_t perf_cycles;
- int64_t perf_time_us;
- };
-
- static const size_t WSP_GGML_GRAPH_SIZE = sizeof(struct wsp_ggml_cgraph);
-
- // scratch buffer
- struct wsp_ggml_scratch {
- size_t offs;
- size_t size;
- void * data;
- };
-
- struct wsp_ggml_init_params {
- // memory pool
- size_t mem_size; // bytes
- void * mem_buffer; // if NULL, memory will be allocated internally
- bool no_alloc; // don't allocate memory for the tensor data
- };
-
-
- // compute types
-
- // NOTE: the INIT or FINALIZE pass is not scheduled unless explicitly enabled.
- // This behavior was changed since https://github.com/ggerganov/llama.cpp/pull/1995.
- enum wsp_ggml_task_type {
- WSP_GGML_TASK_INIT = 0,
- WSP_GGML_TASK_COMPUTE,
- WSP_GGML_TASK_FINALIZE,
- };
-
- struct wsp_ggml_compute_params {
- enum wsp_ggml_task_type type;
+ //
+ // GUID
+ //

- // ith = thread index, nth = number of threads
- int ith, nth;
+ // GUID types
+ typedef uint8_t wsp_ggml_guid[16];
+ typedef wsp_ggml_guid * wsp_ggml_guid_t;

- // work buffer for all threads
- size_t wsize;
- void * wdata;
- };
+ WSP_GGML_API bool wsp_ggml_guid_matches(wsp_ggml_guid_t guid_a, wsp_ggml_guid_t guid_b);

  // misc

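The hunks above remove the public cplan/cgraph structs and the scratch-buffer API from the header; the only survivor of the old cplan is the abort hook, now a standalone wsp_ggml_abort_callback typedef. A hedged sketch of a compatible callback that stops computation after a deadline (the deadline struct is illustrative, not part of the API):

    #include <stdbool.h>
    #include <time.h>

    struct deadline { time_t stop_at; }; // hypothetical user data

    // signature matches: typedef bool (*wsp_ggml_abort_callback)(void * data);
    static bool abort_after_deadline(void * data) {
        const struct deadline * d = (const struct deadline *) data;
        return time(NULL) >= d->stop_at; // returning true aborts the computation
    }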
@@ -585,26 +643,32 @@ extern "C" {
  WSP_GGML_API int64_t wsp_ggml_cycles(void);
  WSP_GGML_API int64_t wsp_ggml_cycles_per_ms(void);

- WSP_GGML_API void wsp_ggml_numa_init(void); // call once for better performance on NUMA systems
- WSP_GGML_API bool wsp_ggml_is_numa(void); // true if init detected that system has >1 NUMA node
+ // accepts a UTF-8 path, even on Windows
+ WSP_GGML_API FILE * wsp_ggml_fopen(const char * fname, const char * mode);

  WSP_GGML_API void wsp_ggml_print_object (const struct wsp_ggml_object * obj);
  WSP_GGML_API void wsp_ggml_print_objects(const struct wsp_ggml_context * ctx);

- WSP_GGML_API int64_t wsp_ggml_nelements (const struct wsp_ggml_tensor * tensor);
- WSP_GGML_API int64_t wsp_ggml_nrows (const struct wsp_ggml_tensor * tensor);
- WSP_GGML_API size_t wsp_ggml_nbytes (const struct wsp_ggml_tensor * tensor);
- WSP_GGML_API size_t wsp_ggml_nbytes_pad (const struct wsp_ggml_tensor * tensor); // same as wsp_ggml_nbytes() but padded to WSP_GGML_MEM_ALIGN
- WSP_GGML_API size_t wsp_ggml_nbytes_split(const struct wsp_ggml_tensor * tensor, int nrows_split);
+ WSP_GGML_API int64_t wsp_ggml_nelements (const struct wsp_ggml_tensor * tensor);
+ WSP_GGML_API int64_t wsp_ggml_nrows (const struct wsp_ggml_tensor * tensor);
+ WSP_GGML_API size_t wsp_ggml_nbytes (const struct wsp_ggml_tensor * tensor);
+ WSP_GGML_API size_t wsp_ggml_nbytes_pad(const struct wsp_ggml_tensor * tensor); // same as wsp_ggml_nbytes() but padded to WSP_GGML_MEM_ALIGN
+
+ WSP_GGML_API int64_t wsp_ggml_blck_size(enum wsp_ggml_type type);
+ WSP_GGML_API size_t wsp_ggml_type_size(enum wsp_ggml_type type); // size in bytes for all elements in a block
+ WSP_GGML_API size_t wsp_ggml_row_size (enum wsp_ggml_type type, int64_t ne); // size in bytes for all elements in a row

- WSP_GGML_API int wsp_ggml_blck_size (enum wsp_ggml_type type);
- WSP_GGML_API size_t wsp_ggml_type_size (enum wsp_ggml_type type); // size in bytes for all elements in a block
- WSP_GGML_API float wsp_ggml_type_sizef(enum wsp_ggml_type type); // wsp_ggml_type_size()/wsp_ggml_blck_size() as float
+ WSP_GGML_DEPRECATED(
+ WSP_GGML_API double wsp_ggml_type_sizef(enum wsp_ggml_type type), // wsp_ggml_type_size()/wsp_ggml_blck_size() as float
+ "use wsp_ggml_row_size() instead");

  WSP_GGML_API const char * wsp_ggml_type_name(enum wsp_ggml_type type);
  WSP_GGML_API const char * wsp_ggml_op_name (enum wsp_ggml_op op);
  WSP_GGML_API const char * wsp_ggml_op_symbol(enum wsp_ggml_op op);

+ WSP_GGML_API const char * wsp_ggml_unary_op_name(enum wsp_ggml_unary_op op);
+ WSP_GGML_API const char * wsp_ggml_op_desc(const struct wsp_ggml_tensor * t); // unary or op name
+
  WSP_GGML_API size_t wsp_ggml_element_size(const struct wsp_ggml_tensor * tensor);

  WSP_GGML_API bool wsp_ggml_is_quantized(enum wsp_ggml_type type);
@@ -613,22 +677,37 @@ extern "C" {
  WSP_GGML_API enum wsp_ggml_type wsp_ggml_ftype_to_wsp_ggml_type(enum wsp_ggml_ftype ftype);

  WSP_GGML_API bool wsp_ggml_is_transposed(const struct wsp_ggml_tensor * tensor);
- WSP_GGML_API bool wsp_ggml_is_contiguous(const struct wsp_ggml_tensor * tensor);
  WSP_GGML_API bool wsp_ggml_is_permuted (const struct wsp_ggml_tensor * tensor);
+ WSP_GGML_API bool wsp_ggml_is_empty (const struct wsp_ggml_tensor * tensor);
+ WSP_GGML_API bool wsp_ggml_is_scalar (const struct wsp_ggml_tensor * tensor);
+ WSP_GGML_API bool wsp_ggml_is_vector (const struct wsp_ggml_tensor * tensor);
+ WSP_GGML_API bool wsp_ggml_is_matrix (const struct wsp_ggml_tensor * tensor);
+ WSP_GGML_API bool wsp_ggml_is_3d (const struct wsp_ggml_tensor * tensor);
+ WSP_GGML_API int wsp_ggml_n_dims (const struct wsp_ggml_tensor * tensor); // returns 1 for scalars
+
+ WSP_GGML_API bool wsp_ggml_is_contiguous (const struct wsp_ggml_tensor * tensor);
+ WSP_GGML_API bool wsp_ggml_is_contiguous_0(const struct wsp_ggml_tensor * tensor); // same as wsp_ggml_is_contiguous()
+ WSP_GGML_API bool wsp_ggml_is_contiguous_1(const struct wsp_ggml_tensor * tensor); // contiguous for dims >= 1
+ WSP_GGML_API bool wsp_ggml_is_contiguous_2(const struct wsp_ggml_tensor * tensor); // contiguous for dims >= 2
+
+ WSP_GGML_API bool wsp_ggml_are_same_shape (const struct wsp_ggml_tensor * t0, const struct wsp_ggml_tensor * t1);
+ WSP_GGML_API bool wsp_ggml_are_same_stride(const struct wsp_ggml_tensor * t0, const struct wsp_ggml_tensor * t1);

- WSP_GGML_API bool wsp_ggml_are_same_shape(const struct wsp_ggml_tensor * t0, const struct wsp_ggml_tensor * t1);
+ WSP_GGML_API bool wsp_ggml_can_repeat(const struct wsp_ggml_tensor * t0, const struct wsp_ggml_tensor * t1);

  // use this to compute the memory overhead of a tensor
  WSP_GGML_API size_t wsp_ggml_tensor_overhead(void);

+ WSP_GGML_API bool wsp_ggml_validate_row_data(enum wsp_ggml_type type, const void * data, size_t nbytes);
+
  // main

- WSP_GGML_API struct wsp_ggml_context * wsp_ggml_init(struct wsp_ggml_init_params params);
- WSP_GGML_API void wsp_ggml_free(struct wsp_ggml_context * ctx);
+ WSP_GGML_API struct wsp_ggml_context * wsp_ggml_init (struct wsp_ggml_init_params params);
+ WSP_GGML_API void wsp_ggml_reset(struct wsp_ggml_context * ctx);
+ WSP_GGML_API void wsp_ggml_free (struct wsp_ggml_context * ctx);

  WSP_GGML_API size_t wsp_ggml_used_mem(const struct wsp_ggml_context * ctx);

- WSP_GGML_API size_t wsp_ggml_set_scratch (struct wsp_ggml_context * ctx, struct wsp_ggml_scratch scratch);
  WSP_GGML_API bool wsp_ggml_get_no_alloc(struct wsp_ggml_context * ctx);
  WSP_GGML_API void wsp_ggml_set_no_alloc(struct wsp_ggml_context * ctx, bool no_alloc);
@@ -668,34 +747,35 @@ extern "C" {
  int64_t ne2,
  int64_t ne3);

- WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_new_i32(struct wsp_ggml_context * ctx, int32_t value);
- WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_new_f32(struct wsp_ggml_context * ctx, float value);
+ WSP_GGML_API void * wsp_ggml_new_buffer(struct wsp_ggml_context * ctx, size_t nbytes);

  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_dup_tensor (struct wsp_ggml_context * ctx, const struct wsp_ggml_tensor * src);
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_view_tensor(struct wsp_ggml_context * ctx, struct wsp_ggml_tensor * src);

+ // Context tensor enumeration and lookup
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_get_first_tensor(const struct wsp_ggml_context * ctx);
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_get_next_tensor (const struct wsp_ggml_context * ctx, struct wsp_ggml_tensor * tensor);
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_get_tensor(struct wsp_ggml_context * ctx, const char * name);

- WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_set_zero(struct wsp_ggml_tensor * tensor);
- WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_set_i32 (struct wsp_ggml_tensor * tensor, int32_t value);
- WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_set_f32 (struct wsp_ggml_tensor * tensor, float value);
-
- WSP_GGML_API int32_t wsp_ggml_get_i32_1d(const struct wsp_ggml_tensor * tensor, int i);
- WSP_GGML_API void wsp_ggml_set_i32_1d(const struct wsp_ggml_tensor * tensor, int i, int32_t value);
+ // Converts a flat index into coordinates
+ WSP_GGML_API void wsp_ggml_unravel_index(const struct wsp_ggml_tensor * tensor, int64_t i, int64_t * i0, int64_t * i1, int64_t * i2, int64_t * i3);

- WSP_GGML_API float wsp_ggml_get_f32_1d(const struct wsp_ggml_tensor * tensor, int i);
- WSP_GGML_API void wsp_ggml_set_f32_1d(const struct wsp_ggml_tensor * tensor, int i, float value);
+ WSP_GGML_API enum wsp_ggml_unary_op wsp_ggml_get_unary_op(const struct wsp_ggml_tensor * tensor);

  WSP_GGML_API void * wsp_ggml_get_data (const struct wsp_ggml_tensor * tensor);
  WSP_GGML_API float * wsp_ggml_get_data_f32(const struct wsp_ggml_tensor * tensor);

- WSP_GGML_API enum wsp_ggml_unary_op wsp_ggml_get_unary_op(const struct wsp_ggml_tensor * tensor);
-
  WSP_GGML_API const char * wsp_ggml_get_name (const struct wsp_ggml_tensor * tensor);
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_set_name ( struct wsp_ggml_tensor * tensor, const char * name);
  WSP_GGML_ATTRIBUTE_FORMAT(2, 3)
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_format_name( struct wsp_ggml_tensor * tensor, const char * fmt, ...);

+ // Tensor flags
+ WSP_GGML_API void wsp_ggml_set_input(struct wsp_ggml_tensor * tensor);
+ WSP_GGML_API void wsp_ggml_set_output(struct wsp_ggml_tensor * tensor);
+ WSP_GGML_API void wsp_ggml_set_param(struct wsp_ggml_context * ctx, struct wsp_ggml_tensor * tensor);
+ WSP_GGML_API void wsp_ggml_set_loss(struct wsp_ggml_tensor * tensor);
+
  //
  // operations on tensors with backpropagation
  //
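The per-element accessors (wsp_ggml_set_i32_1d and friends) are gone from the hunk above; for discovery-style code the new enumeration pair walks every tensor allocated in a context. A minimal sketch (printf is available since the header now includes <stdio.h>):

    // print name and size of every tensor allocated in ctx
    for (struct wsp_ggml_tensor * t = wsp_ggml_get_first_tensor(ctx);
         t != NULL;
         t = wsp_ggml_get_next_tensor(ctx, t)) {
        printf("%s: %zu bytes\n", wsp_ggml_get_name(t), wsp_ggml_nbytes(t));
    }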
@@ -719,6 +799,12 @@ extern "C" {
  struct wsp_ggml_tensor * a,
  struct wsp_ggml_tensor * b);

+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_add_cast(
+ struct wsp_ggml_context * ctx,
+ struct wsp_ggml_tensor * a,
+ struct wsp_ggml_tensor * b,
+ enum wsp_ggml_type type);
+
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_add1(
  struct wsp_ggml_context * ctx,
  struct wsp_ggml_tensor * a,
@@ -729,6 +815,9 @@ extern "C" {
  struct wsp_ggml_tensor * a,
  struct wsp_ggml_tensor * b);

+ // dst = a
+ // view(dst, nb1, nb2, nb3, offset) += b
+ // return dst
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_acc(
  struct wsp_ggml_context * ctx,
  struct wsp_ggml_tensor * a,
@@ -801,6 +890,22 @@ extern "C" {
  struct wsp_ggml_context * ctx,
  struct wsp_ggml_tensor * a);

+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_sin(
+ struct wsp_ggml_context * ctx,
+ struct wsp_ggml_tensor * a);
+
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_sin_inplace(
+ struct wsp_ggml_context * ctx,
+ struct wsp_ggml_tensor * a);
+
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_cos(
+ struct wsp_ggml_context * ctx,
+ struct wsp_ggml_tensor * a);
+
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_cos_inplace(
+ struct wsp_ggml_context * ctx,
+ struct wsp_ggml_tensor * a);
+
  // return scalar
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_sum(
  struct wsp_ggml_context * ctx,
@@ -821,6 +926,12 @@ extern "C" {
  struct wsp_ggml_context * ctx,
  struct wsp_ggml_tensor * a);

+ // count number of equal elements in a and b
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_count_equal(
+ struct wsp_ggml_context * ctx,
+ struct wsp_ggml_tensor * a,
+ struct wsp_ggml_tensor * b);
+
  // if a is the same shape as b, and a is not parameter, return a
  // otherwise, return a new tensor: repeat(a) to fit in b
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_repeat(
@@ -828,17 +939,19 @@ extern "C" {
  struct wsp_ggml_tensor * a,
  struct wsp_ggml_tensor * b);

+ // sums repetitions in a into shape of b
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_repeat_back(
  struct wsp_ggml_context * ctx,
  struct wsp_ggml_tensor * a,
  struct wsp_ggml_tensor * b);

- // concat a and b on dim 2
+ // concat a and b along dim
  // used in stable-diffusion
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_concat(
  struct wsp_ggml_context * ctx,
  struct wsp_ggml_tensor * a,
- struct wsp_ggml_tensor * b);
+ struct wsp_ggml_tensor * b,
+ int dim);

  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_abs(
  struct wsp_ggml_context * ctx,
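wsp_ggml_concat in the hunk above gains an explicit dim argument instead of always concatenating on dim 2, so existing call sites keep their behavior by passing 2. Sketch, assuming tensors a and b from the same context:

    // old call (implicit dim 2): wsp_ggml_concat(ctx, a, b);
    struct wsp_ggml_tensor * c = wsp_ggml_concat(ctx, a, b, 2); // same result as before
    struct wsp_ggml_tensor * r = wsp_ggml_concat(ctx, a, b, 0); // now possible: concat along dim 0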
@@ -892,11 +1005,22 @@ extern "C" {
  struct wsp_ggml_context * ctx,
  struct wsp_ggml_tensor * a);

+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_leaky_relu(
+ struct wsp_ggml_context * ctx,
+ struct wsp_ggml_tensor * a, float negative_slope, bool inplace);
+
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_relu_inplace(
  struct wsp_ggml_context * ctx,
  struct wsp_ggml_tensor * a);

- // TODO: double-check this computation is correct
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_sigmoid(
+ struct wsp_ggml_context * ctx,
+ struct wsp_ggml_tensor * a);
+
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_sigmoid_inplace(
+ struct wsp_ggml_context * ctx,
+ struct wsp_ggml_tensor * a);
+
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_gelu(
  struct wsp_ggml_context * ctx,
  struct wsp_ggml_tensor * a);
@@ -928,6 +1052,24 @@ extern "C" {
  struct wsp_ggml_tensor * a,
  struct wsp_ggml_tensor * b);

+ // hardswish(x) = x * relu6(x + 3) / 6
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_hardswish(
+ struct wsp_ggml_context * ctx,
+ struct wsp_ggml_tensor * a);
+
+ // hardsigmoid(x) = relu6(x + 3) / 6
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_hardsigmoid(
+ struct wsp_ggml_context * ctx,
+ struct wsp_ggml_tensor * a);
+
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_exp(
+ struct wsp_ggml_context * ctx,
+ struct wsp_ggml_tensor * a);
+
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_exp_inplace(
+ struct wsp_ggml_context * ctx,
+ struct wsp_ggml_tensor * a);
+
  // normalize along rows
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_norm(
  struct wsp_ggml_context * ctx,
@@ -951,16 +1093,17 @@ extern "C" {

  // group normalize along ne0*ne1*n_groups
  // used in stable-diffusion
- // TODO: eps is hardcoded to 1e-6 for now
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_group_norm(
  struct wsp_ggml_context * ctx,
  struct wsp_ggml_tensor * a,
- int n_groups);
+ int n_groups,
+ float eps);

  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_group_norm_inplace(
  struct wsp_ggml_context * ctx,
  struct wsp_ggml_tensor * a,
- int n_groups);
+ int n_groups,
+ float eps);

  // a - x
  // b - dy
@@ -970,14 +1113,27 @@ extern "C" {
  struct wsp_ggml_tensor * b,
  float eps);

- // A: n columns, m rows
- // B: n columns, p rows (i.e. we transpose it internally)
- // result is m columns, p rows
+ // A: k columns, n rows => [ne03, ne02, n, k]
+ // B: k columns, m rows (i.e. we transpose it internally) => [ne03 * x, ne02 * y, m, k]
+ // result is n columns, m rows => [ne03 * x, ne02 * y, m, n]
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_mul_mat(
  struct wsp_ggml_context * ctx,
  struct wsp_ggml_tensor * a,
  struct wsp_ggml_tensor * b);

+ // change the precision of a matrix multiplication
+ // set to WSP_GGML_PREC_F32 for higher precision (useful for phi-2)
+ WSP_GGML_API void wsp_ggml_mul_mat_set_prec(
+ struct wsp_ggml_tensor * a,
+ enum wsp_ggml_prec prec);
+
+ // indirect matrix multiplication
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_mul_mat_id(
+ struct wsp_ggml_context * ctx,
+ struct wsp_ggml_tensor * as,
+ struct wsp_ggml_tensor * b,
+ struct wsp_ggml_tensor * ids);
+
  // A: m columns, n rows,
  // B: p columns, n rows,
  // result is m columns, p rows
@@ -993,13 +1149,13 @@ extern "C" {
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_scale(
  struct wsp_ggml_context * ctx,
  struct wsp_ggml_tensor * a,
- struct wsp_ggml_tensor * b);
+ float s);

  // in-place, returns view(a)
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_scale_inplace(
  struct wsp_ggml_context * ctx,
  struct wsp_ggml_tensor * a,
- struct wsp_ggml_tensor * b);
+ float s);

  // b -> view(a,offset,nb1,nb2,3), return modified a
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_set(
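wsp_ggml_scale above now takes the factor as a plain float rather than a 1-element tensor, which is why wsp_ggml_new_f32 could be dropped earlier in this diff. Migration sketch (cur and n_embd stand in for surrounding model code; sqrtf needs <math.h>):

    // before: the factor had to be materialized as a tensor
    // cur = wsp_ggml_scale(ctx, cur, wsp_ggml_new_f32(ctx, 1.0f / sqrtf((float) n_embd)));

    // after: pass the scalar directly
    cur = wsp_ggml_scale(ctx, cur, 1.0f / sqrtf((float) n_embd));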
@@ -1009,7 +1165,7 @@ extern "C" {
  size_t nb1,
  size_t nb2,
  size_t nb3,
- size_t offset);
+ size_t offset); // in bytes

  // b -> view(a,offset,nb1,nb2,3), return view(a)
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_set_inplace(
@@ -1019,19 +1175,19 @@ extern "C" {
  size_t nb1,
  size_t nb2,
  size_t nb3,
- size_t offset);
+ size_t offset); // in bytes

  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_set_1d(
  struct wsp_ggml_context * ctx,
  struct wsp_ggml_tensor * a,
  struct wsp_ggml_tensor * b,
- size_t offset);
+ size_t offset); // in bytes

  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_set_1d_inplace(
  struct wsp_ggml_context * ctx,
  struct wsp_ggml_tensor * a,
  struct wsp_ggml_tensor * b,
- size_t offset);
+ size_t offset); // in bytes

  // b -> view(a,offset,nb1,nb2,3), return modified a
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_set_2d(
@@ -1039,7 +1195,7 @@ extern "C" {
  struct wsp_ggml_tensor * a,
  struct wsp_ggml_tensor * b,
  size_t nb1,
- size_t offset);
+ size_t offset); // in bytes

  // b -> view(a,offset,nb1,nb2,3), return view(a)
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_set_2d_inplace(
@@ -1047,8 +1203,7 @@ extern "C" {
  struct wsp_ggml_tensor * a,
  struct wsp_ggml_tensor * b,
  size_t nb1,
- size_t offset);
-
+ size_t offset); // in bytes

  // a -> b, return view(b)
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_cpy(
@@ -1056,21 +1211,42 @@ extern "C" {
  struct wsp_ggml_tensor * a,
  struct wsp_ggml_tensor * b);

- // a -> b, in-place, return view(b)
- WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_cpy_inplace(
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_cast(
  struct wsp_ggml_context * ctx,
  struct wsp_ggml_tensor * a,
- struct wsp_ggml_tensor * b);
+ enum wsp_ggml_type type);

  // make contiguous
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_cont(
  struct wsp_ggml_context * ctx,
  struct wsp_ggml_tensor * a);

- // make contiguous, in-place
- WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_cont_inplace(
+ // make contiguous, with new shape
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_cont_1d(
  struct wsp_ggml_context * ctx,
- struct wsp_ggml_tensor * a);
+ struct wsp_ggml_tensor * a,
+ int64_t ne0);
+
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_cont_2d(
+ struct wsp_ggml_context * ctx,
+ struct wsp_ggml_tensor * a,
+ int64_t ne0,
+ int64_t ne1);
+
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_cont_3d(
+ struct wsp_ggml_context * ctx,
+ struct wsp_ggml_tensor * a,
+ int64_t ne0,
+ int64_t ne1,
+ int64_t ne2);
+
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_cont_4d(
+ struct wsp_ggml_context * ctx,
+ struct wsp_ggml_tensor * a,
+ int64_t ne0,
+ int64_t ne1,
+ int64_t ne2,
+ int64_t ne3);

  // return view(a), b specifies the new shape
  // TODO: when we start computing gradient, make a copy instead of view
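In exchange for the removed cpy/cont _inplace variants above, wsp_ggml_cast converts type without the caller allocating a destination, and the wsp_ggml_cont_1d..4d family makes a tensor contiguous while reshaping in one op. Sketch (view, n_embd and n_tokens are hypothetical):

    // make a permuted view contiguous and flatten it to 2-D in one step
    struct wsp_ggml_tensor * flat = wsp_ggml_cont_2d(ctx, view, n_embd, n_tokens);

    // cast to F16 without an explicit destination tensor
    struct wsp_ggml_tensor * half = wsp_ggml_cast(ctx, flat, WSP_GGML_TYPE_F16);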
@@ -1159,16 +1335,17 @@ extern "C" {
  struct wsp_ggml_context * ctx,
  struct wsp_ggml_tensor * a);

+ // supports 3D: a->ne[2] == b->ne[1]
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_get_rows(
  struct wsp_ggml_context * ctx,
- struct wsp_ggml_tensor * a,
- struct wsp_ggml_tensor * b);
+ struct wsp_ggml_tensor * a, // data
+ struct wsp_ggml_tensor * b); // row indices

  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_get_rows_back(
  struct wsp_ggml_context * ctx,
- struct wsp_ggml_tensor * a,
- struct wsp_ggml_tensor * b,
- struct wsp_ggml_tensor * c);
+ struct wsp_ggml_tensor * a, // gradients of wsp_ggml_get_rows result
+ struct wsp_ggml_tensor * b, // row indices
+ struct wsp_ggml_tensor * c); // data for wsp_ggml_get_rows, only used for its shape

  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_diag(
  struct wsp_ggml_context * ctx,
@@ -1207,6 +1384,16 @@ extern "C" {
  struct wsp_ggml_context * ctx,
  struct wsp_ggml_tensor * a);

+ // fused soft_max(a*scale + mask*(ALiBi slope))
+ // mask is optional
+ // max_bias = 0.0f for no ALiBi
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_soft_max_ext(
+ struct wsp_ggml_context * ctx,
+ struct wsp_ggml_tensor * a,
+ struct wsp_ggml_tensor * mask,
+ float scale,
+ float max_bias);
+
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_soft_max_back(
  struct wsp_ggml_context * ctx,
  struct wsp_ggml_tensor * a,
@@ -1219,93 +1406,160 @@ extern "C" {
  struct wsp_ggml_tensor * b);

  // rotary position embedding
- // if mode & 1 == 1, skip n_past elements
- // if mode & 2 == 1, GPT-NeoX style
- // if mode & 4 == 1, ChatGLM style
- // TODO: avoid creating a new tensor every time
+ // if (mode & 1) - skip n_past elements (NOT SUPPORTED)
+ // if (mode & WSP_GGML_ROPE_TYPE_NEOX) - GPT-NeoX style
+ //
+ // b is an int32 vector with size a->ne[2], it contains the positions
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_rope(
  struct wsp_ggml_context * ctx,
  struct wsp_ggml_tensor * a,
- int n_past,
+ struct wsp_ggml_tensor * b,
  int n_dims,
- int mode,
- int n_ctx);
+ int mode);

  // in-place, returns view(a)
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_rope_inplace(
  struct wsp_ggml_context * ctx,
  struct wsp_ggml_tensor * a,
- int n_past,
+ struct wsp_ggml_tensor * b,
  int n_dims,
- int mode,
- int n_ctx);
+ int mode);

  // custom RoPE
- WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_rope_custom(
+ // c is freq factors (e.g. phi3-128k), (optional)
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_rope_ext(
  struct wsp_ggml_context * ctx,
  struct wsp_ggml_tensor * a,
- int n_past,
+ struct wsp_ggml_tensor * b,
+ struct wsp_ggml_tensor * c,
  int n_dims,
  int mode,
- int n_ctx,
+ int n_ctx_orig,
  float freq_base,
- float freq_scale);
+ float freq_scale,
+ float ext_factor,
+ float attn_factor,
+ float beta_fast,
+ float beta_slow);

  // in-place, returns view(a)
- WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_rope_custom_inplace(
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_rope_ext_inplace(
  struct wsp_ggml_context * ctx,
  struct wsp_ggml_tensor * a,
- int n_past,
+ struct wsp_ggml_tensor * b,
+ struct wsp_ggml_tensor * c,
  int n_dims,
  int mode,
- int n_ctx,
+ int n_ctx_orig,
  float freq_base,
- float freq_scale);
+ float freq_scale,
+ float ext_factor,
+ float attn_factor,
+ float beta_fast,
+ float beta_slow);

- // xPos RoPE, in-place, returns view(a)
- WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_rope_xpos_inplace(
+ WSP_GGML_DEPRECATED(WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_rope_custom(
  struct wsp_ggml_context * ctx,
  struct wsp_ggml_tensor * a,
- int n_past,
+ struct wsp_ggml_tensor * b,
  int n_dims,
- float base,
- bool down);
+ int mode,
+ int n_ctx_orig,
+ float freq_base,
+ float freq_scale,
+ float ext_factor,
+ float attn_factor,
+ float beta_fast,
+ float beta_slow),
+ "use wsp_ggml_rope_ext instead");

- // rotary position embedding backward, i.e compute dx from dy
- // a - dy
- WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_rope_back(
+ WSP_GGML_DEPRECATED(WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_rope_custom_inplace(
  struct wsp_ggml_context * ctx,
  struct wsp_ggml_tensor * a,
- int n_past,
+ struct wsp_ggml_tensor * b,
  int n_dims,
  int mode,
- int n_ctx,
+ int n_ctx_orig,
  float freq_base,
  float freq_scale,
- float xpos_base,
- bool xpos_down);
+ float ext_factor,
+ float attn_factor,
+ float beta_fast,
+ float beta_slow),
+ "use wsp_ggml_rope_ext_inplace instead");

- // alibi position embedding
- // in-place, returns view(a)
- struct wsp_ggml_tensor * wsp_ggml_alibi(
+ // compute correction dims for YaRN RoPE scaling
+ WSP_GGML_API void wsp_ggml_rope_yarn_corr_dims(
+ int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, float dims[2]);
+
+ // rotary position embedding backward, i.e compute dx from dy
+ // a - dy
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_rope_back(
  struct wsp_ggml_context * ctx,
- struct wsp_ggml_tensor * a,
- int n_past,
- int n_head,
- float bias_max);
+ struct wsp_ggml_tensor * a, // gradients of wsp_ggml_rope result
+ struct wsp_ggml_tensor * b, // positions
+ struct wsp_ggml_tensor * c, // freq factors
+ int n_dims,
+ int mode,
+ int n_ctx_orig,
+ float freq_base,
+ float freq_scale,
+ float ext_factor,
+ float attn_factor,
+ float beta_fast,
+ float beta_slow);

  // clamp
  // in-place, returns view(a)
- struct wsp_ggml_tensor * wsp_ggml_clamp(
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_clamp(
  struct wsp_ggml_context * ctx,
  struct wsp_ggml_tensor * a,
  float min,
  float max);

+ // im2col
+ // converts data into a format that effectively results in a convolution when combined with matrix multiplication
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_im2col(
+ struct wsp_ggml_context * ctx,
+ struct wsp_ggml_tensor * a, // convolution kernel
+ struct wsp_ggml_tensor * b, // data
+ int s0, // stride dimension 0
+ int s1, // stride dimension 1
+ int p0, // padding dimension 0
+ int p1, // padding dimension 1
+ int d0, // dilation dimension 0
+ int d1, // dilation dimension 1
+ bool is_2D,
+ enum wsp_ggml_type dst_type);
+
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_im2col_back(
+ struct wsp_ggml_context * ctx,
+ struct wsp_ggml_tensor * a, // convolution kernel
+ struct wsp_ggml_tensor * b, // gradient of im2col output
+ int64_t * ne, // shape of im2col input
+ int s0, // stride dimension 0
+ int s1, // stride dimension 1
+ int p0, // padding dimension 0
+ int p1, // padding dimension 1
+ int d0, // dilation dimension 0
+ int d1, // dilation dimension 1
+ bool is_2D);
+
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_conv_depthwise_2d(
+ struct wsp_ggml_context * ctx,
+ struct wsp_ggml_tensor * a, // convolution kernel
+ struct wsp_ggml_tensor * b, // data
+ int s0, // stride dimension 0
+ int s1, // stride dimension 1
+ int p0, // padding dimension 0
+ int p1, // padding dimension 1
+ int d0, // dilation dimension 0
+ int d1); // dilation dimension 1
+
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_conv_1d(
  struct wsp_ggml_context * ctx,
- struct wsp_ggml_tensor * a,
- struct wsp_ggml_tensor * b,
+ struct wsp_ggml_tensor * a, // convolution kernel
+ struct wsp_ggml_tensor * b, // data
  int s0, // stride
  int p0, // padding
  int d0); // dilation
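The biggest call-site change in the hunk above is RoPE: the scalar n_past/n_ctx arguments are gone, positions travel as an int32 tensor b of length a->ne[2], and the custom variants are deprecated in favor of wsp_ggml_rope_ext. Migration sketch (pos is a hypothetical I32 tensor holding one position per entry):

    // before: wsp_ggml_rope(ctx, a, n_past, n_dims, mode, n_ctx);
    // after: positions are an explicit tensor; WSP_GGML_ROPE_TYPE_NEOX selects NeoX style
    struct wsp_ggml_tensor * rotated = wsp_ggml_rope(ctx, a, pos, n_dims, WSP_GGML_ROPE_TYPE_NEOX);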
@@ -1314,21 +1568,29 @@ extern "C" {
1314
1568
  // alias for wsp_ggml_conv_1d(a, b, s, a->ne[0]/2, d)
1315
1569
  WSP_GGML_API struct wsp_ggml_tensor* wsp_ggml_conv_1d_ph(
1316
1570
  struct wsp_ggml_context * ctx,
1317
- struct wsp_ggml_tensor * a,
1318
- struct wsp_ggml_tensor * b,
1319
- int s,
1320
- int d);
1571
+ struct wsp_ggml_tensor * a, // convolution kernel
1572
+ struct wsp_ggml_tensor * b, // data
1573
+ int s, // stride
1574
+ int d); // dilation
1575
+
1576
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_conv_transpose_1d(
1577
+ struct wsp_ggml_context * ctx,
1578
+ struct wsp_ggml_tensor * a, // convolution kernel
1579
+ struct wsp_ggml_tensor * b, // data
1580
+ int s0, // stride
1581
+ int p0, // padding
1582
+ int d0); // dilation
1321
1583
 
1322
1584
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_conv_2d(
1323
1585
  struct wsp_ggml_context * ctx,
1324
- struct wsp_ggml_tensor * a,
1325
- struct wsp_ggml_tensor * b,
1326
- int s0,
1327
- int s1,
1328
- int p0,
1329
- int p1,
1330
- int d0,
1331
- int d1);
1586
+ struct wsp_ggml_tensor * a, // convolution kernel
1587
+ struct wsp_ggml_tensor * b, // data
1588
+ int s0, // stride dimension 0
1589
+ int s1, // stride dimension 1
1590
+ int p0, // padding dimension 0
1591
+ int p1, // padding dimension 1
1592
+ int d0, // dilation dimension 0
1593
+ int d1); // dilation dimension 1
1332
1594
 
1333
1595
 
1334
1596
  // kernel size is a->ne[0] x a->ne[1]
@@ -1377,6 +1639,8 @@ extern "C" {
1377
1639
  int s0, // stride
1378
1640
  int p0); // padding
1379
1641
 
1642
+ // the result will have 2*p0 padding for the first dimension
1643
+ // and 2*p1 padding for the second dimension
1380
1644
  WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_pool_2d(
1381
1645
  struct wsp_ggml_context * ctx,
1382
1646
  struct wsp_ggml_tensor * a,
@@ -1385,23 +1649,106 @@ extern "C" {
1385
1649
  int k1,
1386
1650
  int s0,
1387
1651
  int s1,
1388
- int p0,
1389
- int p1);
1652
+ float p0,
1653
+ float p1);
1654
+
1655
+ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_pool_2d_back(
1656
+ struct wsp_ggml_context * ctx,
1657
+ struct wsp_ggml_tensor * a,
1658
+ struct wsp_ggml_tensor * af, // "a"/input used in forward pass
1659
+ enum wsp_ggml_op_pool op,
1660
+ int k0,
1661
+ int k1,
1662
+ int s0,
1663
+ int s1,
1664
+ float p0,
1665
+ float p1);
1390
1666
 
     // nearest interpolate
+    // multiplies ne0 and ne1 by scale factor
     // used in stable-diffusion
     WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_upscale(
             struct wsp_ggml_context * ctx,
             struct wsp_ggml_tensor * a,
             int scale_factor);

-    WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_flash_attn(
+    // nearest interpolate
+    // nearest interpolate to specified dimensions
+    // used in tortoise.cpp
+    WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_upscale_ext(
+            struct wsp_ggml_context * ctx,
+            struct wsp_ggml_tensor * a,
+            int ne0,
+            int ne1,
+            int ne2,
+            int ne3);
+
+    // pad each dimension with zeros: [x, ..., x] -> [x, ..., x, 0, ..., 0]
+    WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_pad(
+            struct wsp_ggml_context * ctx,
+            struct wsp_ggml_tensor  * a,
+            int                  p0,
+            int                  p1,
+            int                  p2,
+            int                  p3);
+
+    // Ref: https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/diffusionmodules/util.py#L151
+    // timesteps: [N,]
+    // return: [N, dim]
+    WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_timestep_embedding(
+            struct wsp_ggml_context * ctx,
+            struct wsp_ggml_tensor  * timesteps,
+            int                       dim,
+            int                       max_period);
+
+    // sort rows
+    enum wsp_ggml_sort_order {
+        WSP_GGML_SORT_ORDER_ASC,
+        WSP_GGML_SORT_ORDER_DESC,
+    };
+
+    WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_argsort(
+            struct wsp_ggml_context * ctx,
+            struct wsp_ggml_tensor  * a,
+            enum wsp_ggml_sort_order  order);
+
+    WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_arange(
+            struct wsp_ggml_context * ctx,
+            float start,
+            float stop,
+            float step);
+
+    // top k elements per row
+    WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_top_k(
+            struct wsp_ggml_context * ctx,
+            struct wsp_ggml_tensor  * a,
+            int                       k);
+
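A sketch combining the new sort/select helpers; like every other op in this header, both return lazy graph nodes that are only materialized when a graph containing them is computed:

    // Per-row descending sort indices plus the 3 largest elements of each row.
    static void sort_sketch(struct wsp_ggml_context * ctx, struct wsp_ggml_tensor * rows) {
        struct wsp_ggml_tensor * idx  = wsp_ggml_argsort(ctx, rows, WSP_GGML_SORT_ORDER_DESC);
        struct wsp_ggml_tensor * top3 = wsp_ggml_top_k(ctx, rows, 3);
        (void) idx; (void) top3; // build into a graph and compute to evaluate
    }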
+    #define WSP_GGML_KQ_MASK_PAD 32
+
+    // q:    [n_embd, n_batch,     n_head,    1]
+    // k:    [n_embd, n_kv,        n_head_kv, 1]
+    // v:    [n_embd, n_kv,        n_head_kv, 1] !! not transposed !!
+    // mask: [n_kv,   n_batch_pad, 1,         1] !! n_batch_pad = WSP_GGML_PAD(n_batch, WSP_GGML_KQ_MASK_PAD) !!
+    // res:  [n_embd, n_head,      n_batch,   1] !! permuted !!
+    WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_flash_attn_ext(
             struct wsp_ggml_context * ctx,
             struct wsp_ggml_tensor  * q,
             struct wsp_ggml_tensor  * k,
             struct wsp_ggml_tensor  * v,
-            bool                      masked);
+            struct wsp_ggml_tensor  * mask,
+            float                     scale,
+            float                     max_bias,
+            float                     logit_softcap);
+
+    WSP_GGML_API void wsp_ggml_flash_attn_ext_set_prec(
+            struct wsp_ggml_tensor * a,
+            enum wsp_ggml_prec       prec);
+
+    WSP_GGML_API enum wsp_ggml_prec wsp_ggml_flash_attn_ext_get_prec(
+            const struct wsp_ggml_tensor * a);

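A sketch of the fused-attention entry point that replaces wsp_ggml_flash_attn, using the shapes documented above. The 1/sqrt(head_dim) scale is the conventional choice, and WSP_GGML_PREC_F32 is assumed from the wsp_ggml_prec enum earlier in this header:

    #include <math.h>

    // max_bias and logit_softcap stay 0.0f when ALiBi / soft-capping are unused.
    static struct wsp_ggml_tensor * attn_sketch(struct wsp_ggml_context * ctx,
                                                struct wsp_ggml_tensor  * q,    // [n_embd, n_batch, n_head,    1]
                                                struct wsp_ggml_tensor  * k,    // [n_embd, n_kv,    n_head_kv, 1]
                                                struct wsp_ggml_tensor  * v,    // [n_embd, n_kv,    n_head_kv, 1]
                                                struct wsp_ggml_tensor  * mask, // padded per the comment above
                                                int head_dim) {
        struct wsp_ggml_tensor * out = wsp_ggml_flash_attn_ext(
                ctx, q, k, v, mask,
                1.0f/sqrtf((float) head_dim), /*max_bias*/ 0.0f, /*logit_softcap*/ 0.0f);
        wsp_ggml_flash_attn_ext_set_prec(out, WSP_GGML_PREC_F32); // optional: force f32 accumulation
        return out;
    }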
+    // TODO: needs to be adapted to wsp_ggml_flash_attn_ext
     WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_flash_attn_back(
             struct wsp_ggml_context * ctx,
             struct wsp_ggml_tensor  * q,
@@ -1410,13 +1757,19 @@ extern "C" {
             struct wsp_ggml_tensor  * d,
             bool                      masked);

-    WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_flash_ff(
+    WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_ssm_conv(
             struct wsp_ggml_context * ctx,
-            struct wsp_ggml_tensor  * a,
-            struct wsp_ggml_tensor  * b0,
-            struct wsp_ggml_tensor  * b1,
-            struct wsp_ggml_tensor  * c0,
-            struct wsp_ggml_tensor  * c1);
+            struct wsp_ggml_tensor  * sx,
+            struct wsp_ggml_tensor  * c);
+
+    WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_ssm_scan(
+            struct wsp_ggml_context * ctx,
+            struct wsp_ggml_tensor  * s,
+            struct wsp_ggml_tensor  * x,
+            struct wsp_ggml_tensor  * dt,
+            struct wsp_ggml_tensor  * A,
+            struct wsp_ggml_tensor  * B,
+            struct wsp_ggml_tensor  * C);

     // partition into non-overlapping windows with padding if needed
     // example:
@@ -1456,7 +1809,6 @@ extern "C" {
             int                       kh);

     // used in sam
-
     WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_add_rel_pos(
             struct wsp_ggml_context * ctx,
             struct wsp_ggml_tensor  * a,
@@ -1469,6 +1821,15 @@ extern "C" {
             struct wsp_ggml_tensor  * pw,
             struct wsp_ggml_tensor  * ph);

+    WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_rwkv_wkv6(
+            struct wsp_ggml_context * ctx,
+            struct wsp_ggml_tensor  * k,
+            struct wsp_ggml_tensor  * v,
+            struct wsp_ggml_tensor  * r,
+            struct wsp_ggml_tensor  * tf,
+            struct wsp_ggml_tensor  * td,
+            struct wsp_ggml_tensor  * state);
+
     // custom operators

     typedef void (*wsp_ggml_unary_op_f32_t) (const int, float *, const float *);
@@ -1552,7 +1913,8 @@ extern "C" {
     typedef void (*wsp_ggml_custom2_op_t)(struct wsp_ggml_tensor * dst , const struct wsp_ggml_tensor * a, const struct wsp_ggml_tensor * b, int ith, int nth, void * userdata);
     typedef void (*wsp_ggml_custom3_op_t)(struct wsp_ggml_tensor * dst , const struct wsp_ggml_tensor * a, const struct wsp_ggml_tensor * b, const struct wsp_ggml_tensor * c, int ith, int nth, void * userdata);

-    #define WSP_GGML_N_TASKS_MAX -1
+    #define WSP_GGML_N_TASKS_MAX (-1)
+    // n_tasks == WSP_GGML_N_TASKS_MAX means to use max number of tasks

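A sketch of a threaded custom op. The wsp_ggml_custom1_op_t callback shape is assumed to mirror the custom2/custom3 typedefs shown above, and the trailing parameters of wsp_ggml_map_custom1 (callback, n_tasks, userdata) fall outside this hunk, so the call itself is shown only as a comment:

    // Each of the nth tasks handles the slice of rows congruent to ith.
    static void scale_rows(struct wsp_ggml_tensor * dst, const struct wsp_ggml_tensor * a,
                           int ith, int nth, void * userdata) {
        (void) dst; (void) a; (void) userdata;
        (void) ith; (void) nth; // process rows r with r % nth == ith ... (omitted)
    }

    // WSP_GGML_N_TASKS_MAX lets the scheduler pick the task count:
    //     wsp_ggml_map_custom1(ctx, a, scale_rows, WSP_GGML_N_TASKS_MAX, NULL);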
     WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_map_custom1(
             struct wsp_ggml_context * ctx,
@@ -1605,50 +1967,62 @@ extern "C" {
     // loss function

     WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_cross_entropy_loss(
-            struct wsp_ggml_context * ctx,
-            struct wsp_ggml_tensor  * a,
-            struct wsp_ggml_tensor  * b);
+            struct wsp_ggml_context * ctx,
+            struct wsp_ggml_tensor  * a,  // logits
+            struct wsp_ggml_tensor  * b); // labels

     WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_cross_entropy_loss_back(
-            struct wsp_ggml_context * ctx,
-            struct wsp_ggml_tensor  * a,
-            struct wsp_ggml_tensor  * b,
-            struct wsp_ggml_tensor  * c);
+            struct wsp_ggml_context * ctx,
+            struct wsp_ggml_tensor  * a,  // logits
+            struct wsp_ggml_tensor  * b,  // labels
+            struct wsp_ggml_tensor  * c); // gradients of cross_entropy_loss result
+
+    // AdamW optimizer step
+    // Paper: https://arxiv.org/pdf/1711.05101v3.pdf
+    // PyTorch: https://pytorch.org/docs/stable/generated/torch.optim.AdamW.html
+    WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_opt_step_adamw(
+            struct wsp_ggml_context * ctx,
+            struct wsp_ggml_tensor  * a,
+            struct wsp_ggml_tensor  * grad,
+            struct wsp_ggml_tensor  * m,
+            struct wsp_ggml_tensor  * v,
+            struct wsp_ggml_tensor  * adamw_params); // parameters such as the learning rate

     //
     // automatic differentiation
     //

-    WSP_GGML_API void wsp_ggml_set_param(
-            struct wsp_ggml_context * ctx,
-            struct wsp_ggml_tensor  * tensor);
+    WSP_GGML_API void wsp_ggml_build_forward_expand(struct wsp_ggml_cgraph * cgraph, struct wsp_ggml_tensor * tensor);
+    WSP_GGML_API void wsp_ggml_build_backward_expand(
+        struct wsp_ggml_context * ctx_static,  // context for static gradients (loss + gradient accumulation)
+        struct wsp_ggml_context * ctx_compute, // context for gradient computation
+        struct wsp_ggml_cgraph  * cgraph,
+        bool                      accumulate); // whether or not gradients should be accumulated, requires static allocation of tensors in ctx_static

+    // graph allocation in a context
+    WSP_GGML_API struct wsp_ggml_cgraph * wsp_ggml_new_graph       (struct wsp_ggml_context * ctx); // size = WSP_GGML_DEFAULT_GRAPH_SIZE, grads = false
+    WSP_GGML_API struct wsp_ggml_cgraph * wsp_ggml_new_graph_custom(struct wsp_ggml_context * ctx, size_t size, bool grads);
+    WSP_GGML_API struct wsp_ggml_cgraph * wsp_ggml_graph_dup       (struct wsp_ggml_context * ctx, struct wsp_ggml_cgraph * cgraph);
+    WSP_GGML_API void                     wsp_ggml_graph_cpy       (struct wsp_ggml_cgraph * src, struct wsp_ggml_cgraph * dst);
+    WSP_GGML_API void                     wsp_ggml_graph_reset     (struct wsp_ggml_cgraph * cgraph); // set regular grads + optimizer momenta to 0, set loss grad to 1
+    WSP_GGML_API void                     wsp_ggml_graph_clear     (struct wsp_ggml_cgraph * cgraph);

-    WSP_GGML_API void wsp_ggml_build_forward_expand (struct wsp_ggml_cgraph * cgraph, struct wsp_ggml_tensor * tensor);
-    WSP_GGML_API void wsp_ggml_build_backward_expand(struct wsp_ggml_context * ctx, struct wsp_ggml_cgraph * gf, struct wsp_ggml_cgraph * gb, bool keep);
+    WSP_GGML_API int                       wsp_ggml_graph_size   (struct wsp_ggml_cgraph * cgraph);
+    WSP_GGML_API struct wsp_ggml_tensor *  wsp_ggml_graph_node   (struct wsp_ggml_cgraph * cgraph, int i); // if i < 0, returns nodes[n_nodes + i]
+    WSP_GGML_API struct wsp_ggml_tensor ** wsp_ggml_graph_nodes  (struct wsp_ggml_cgraph * cgraph);
+    WSP_GGML_API int                       wsp_ggml_graph_n_nodes(struct wsp_ggml_cgraph * cgraph);

-    WSP_GGML_API struct wsp_ggml_cgraph wsp_ggml_build_forward (struct wsp_ggml_tensor * tensor);
-    WSP_GGML_API struct wsp_ggml_cgraph wsp_ggml_build_backward(struct wsp_ggml_context * ctx, struct wsp_ggml_cgraph * gf, bool keep);
+    WSP_GGML_API void wsp_ggml_graph_add_node(struct wsp_ggml_cgraph * cgraph, struct wsp_ggml_tensor * tensor);

-    // graph allocation in a context
-    WSP_GGML_API struct wsp_ggml_cgraph * wsp_ggml_new_graph        (struct wsp_ggml_context * ctx);
-    WSP_GGML_API struct wsp_ggml_cgraph * wsp_ggml_build_forward_ctx(struct wsp_ggml_context * ctx, struct wsp_ggml_tensor * tensor);
     WSP_GGML_API size_t wsp_ggml_graph_overhead(void);
+    WSP_GGML_API size_t wsp_ggml_graph_overhead_custom(size_t size, bool grads);

-    // wsp_ggml_graph_plan() has to be called before wsp_ggml_graph_compute()
-    // when plan.work_size > 0, caller must allocate memory for plan.work_data
-    WSP_GGML_API struct wsp_ggml_cplan wsp_ggml_graph_plan   (struct wsp_ggml_cgraph * cgraph, int n_threads /*= WSP_GGML_DEFAULT_N_THREADS*/);
-    WSP_GGML_API int                   wsp_ggml_graph_compute(struct wsp_ggml_cgraph * cgraph, struct wsp_ggml_cplan * cplan);
-    WSP_GGML_API void                  wsp_ggml_graph_reset  (struct wsp_ggml_cgraph * cgraph);
-
-    // same as wsp_ggml_graph_compute() but the work data is allocated as a part of the context
-    // note: the drawback of this API is that you must have ensured that the context has enough memory for the work data
-    WSP_GGML_API void wsp_ggml_graph_compute_with_ctx(struct wsp_ggml_context * ctx, struct wsp_ggml_cgraph * cgraph, int n_threads);
-
-    WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_graph_get_tensor(struct wsp_ggml_cgraph * cgraph, const char * name);
+    WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_graph_get_tensor  (const struct wsp_ggml_cgraph * cgraph, const char * name);
+    WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_graph_get_grad    (const struct wsp_ggml_cgraph * cgraph, const struct wsp_ggml_tensor * node);
+    WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_graph_get_grad_acc(const struct wsp_ggml_cgraph * cgraph, const struct wsp_ggml_tensor * node);

-    WSP_GGML_API void                   wsp_ggml_graph_export(const struct wsp_ggml_cgraph * cgraph, const char * fname);
-    WSP_GGML_API struct wsp_ggml_cgraph wsp_ggml_graph_import(const char * fname, struct wsp_ggml_context ** ctx_data, struct wsp_ggml_context ** ctx_eval);
+    WSP_GGML_API void                     wsp_ggml_graph_export(const struct wsp_ggml_cgraph * cgraph, const char * fname);
+    WSP_GGML_API struct wsp_ggml_cgraph * wsp_ggml_graph_import(const char * fname, struct wsp_ggml_context ** ctx_data, struct wsp_ggml_context ** ctx_eval);

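A sketch of the reworked graph API. Note wsp_ggml_graph_plan/wsp_ggml_graph_compute are removed from this header above; execution has moved to the CPU backend (see the new package/cpp/ggml-cpu.h in this release). wsp_ggml_add is assumed from earlier in the header:

    // Record c = a + b into a freshly allocated graph and walk its nodes.
    static void graph_sketch(struct wsp_ggml_context * ctx,
                             struct wsp_ggml_tensor  * a,
                             struct wsp_ggml_tensor  * b) {
        struct wsp_ggml_cgraph * gf = wsp_ggml_new_graph(ctx); // default size, grads = false
        struct wsp_ggml_tensor * c  = wsp_ggml_add(ctx, a, b);
        wsp_ggml_build_forward_expand(gf, c);
        for (int i = 0; i < wsp_ggml_graph_n_nodes(gf); ++i) {
            struct wsp_ggml_tensor * node = wsp_ggml_graph_node(gf, i); // i < 0 indexes from the back
            (void) node;
        }
    }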
     // print info and performance information for the graph
     WSP_GGML_API void wsp_ggml_graph_print(const struct wsp_ggml_cgraph * cgraph);
@@ -1656,335 +2030,169 @@ extern "C" {
     // dump the graph into a file using the dot format
     WSP_GGML_API void wsp_ggml_graph_dump_dot(const struct wsp_ggml_cgraph * gb, const struct wsp_ggml_cgraph * gf, const char * filename);

-    //
-    // optimization
-    //
-
-    // optimization methods
-    enum wsp_ggml_opt_type {
-        WSP_GGML_OPT_ADAM,
-        WSP_GGML_OPT_LBFGS,
-    };
+    // TODO these functions were sandwiched in the old optimization interface, is there a better place for them?
+    typedef void (*wsp_ggml_log_callback)(enum wsp_ggml_log_level level, const char * text, void * user_data);

-    // linesearch methods
-    enum wsp_ggml_linesearch {
-        WSP_GGML_LINESEARCH_DEFAULT = 1,
+    // Set callback for all future logging events.
+    // If this is not called, or NULL is supplied, everything is output on stderr.
+    WSP_GGML_API void wsp_ggml_log_set(wsp_ggml_log_callback log_callback, void * user_data);

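A sketch of routing log output through the new callback, whose signature is the typedef above; per the comment, passing NULL restores the default stderr behavior:

    #include <stdio.h>

    static void log_to_file(enum wsp_ggml_log_level level, const char * text, void * user_data) {
        fprintf((FILE *) user_data, "[ggml %d] %s", (int) level, text); // text carries its own newlines
    }

    // Install once at startup, e.g.:
    //     wsp_ggml_log_set(log_to_file, fopen("ggml.log", "w"));
    // and wsp_ggml_log_set(NULL, NULL) to go back to stderr.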
-        WSP_GGML_LINESEARCH_BACKTRACKING_ARMIJO       = 0,
-        WSP_GGML_LINESEARCH_BACKTRACKING_WOLFE        = 1,
-        WSP_GGML_LINESEARCH_BACKTRACKING_STRONG_WOLFE = 2,
-    };
-
-    // optimization return values
-    enum wsp_ggml_opt_result {
-        WSP_GGML_OPT_OK = 0,
-        WSP_GGML_OPT_DID_NOT_CONVERGE,
-        WSP_GGML_OPT_NO_CONTEXT,
-        WSP_GGML_OPT_INVALID_WOLFE,
-        WSP_GGML_OPT_FAIL,
-
-        WSP_GGML_LINESEARCH_FAIL = -128,
-        WSP_GGML_LINESEARCH_MINIMUM_STEP,
-        WSP_GGML_LINESEARCH_MAXIMUM_STEP,
-        WSP_GGML_LINESEARCH_MAXIMUM_ITERATIONS,
-        WSP_GGML_LINESEARCH_INVALID_PARAMETERS,
-    };
-
-    typedef void (*wsp_ggml_opt_callback)(void * data, float * sched);
+    WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_set_zero(struct wsp_ggml_tensor * tensor);

-    // optimization parameters
     //
-    //   see ggml.c (wsp_ggml_opt_default_params) for default values
+    // quantization
     //
-    struct wsp_ggml_opt_params {
-        enum wsp_ggml_opt_type type;
-
-        int n_threads;
-
-        // delta-based convergence test
-        //
-        //   if past == 0 - disabled
-        //   if past > 0:
-        //     stop if |f(x) - f(x_past)| < delta * max(1, |f(x)|)
-        //
-        int past;
-        float delta;
-
-        // maximum number of iterations without improvement
-        //
-        //   if 0 - disabled
-        //   if > 0:
-        //     assume convergence if no cost improvement in this number of iterations
-        //
-        int max_no_improvement;
-
-        bool print_forward_graph;
-        bool print_backward_graph;
-
-        // ADAM parameters
-        struct {
-            int n_iter;
-
-            float sched; // schedule multiplier (fixed, decay or warmup)
-            float decay; // weight decay for AdamW, use 0.0f to disable
-            int   decay_min_ndim; // minimum number of tensor dimension to apply weight decay
-            float alpha; // learning rate
-            float beta1;
-            float beta2;
-            float eps;   // epsilon for numerical stability
-            float eps_f; // epsilon for convergence test
-            float eps_g; // epsilon for convergence test
-            float gclip; // gradient clipping
-        } adam;
-
-        // LBFGS parameters
-        struct {
-            int m; // number of corrections to approximate the inv. Hessian
-            int n_iter;
-            int max_linesearch;
-
-            float eps;      // convergence tolerance
-            float ftol;     // line search tolerance
-            float wolfe;
-            float min_step;
-            float max_step;
-
-            enum wsp_ggml_linesearch linesearch;
-        } lbfgs;
-    };
-
-    struct wsp_ggml_opt_context {
-        struct wsp_ggml_context * ctx;
-        struct wsp_ggml_opt_params params;
-
-        int iter;
-        int64_t nx; // number of parameter elements
-
-        bool just_initialized;
-
-        float loss_before;
-        float loss_after;
-
-        struct {
-            struct wsp_ggml_tensor * m;  // first moment
-            struct wsp_ggml_tensor * v;  // second moment
-            struct wsp_ggml_tensor * pf; // past function values
-            float fx_best;
-            float fx_prev;
-            int n_no_improvement;
-        } adam;
-
-        struct {
-            struct wsp_ggml_tensor * x;    // current parameters
-            struct wsp_ggml_tensor * xp;   // previous parameters
-            struct wsp_ggml_tensor * g;    // current gradient
-            struct wsp_ggml_tensor * gp;   // previous gradient
-            struct wsp_ggml_tensor * d;    // search direction
-            struct wsp_ggml_tensor * pf;   // past function values
-            struct wsp_ggml_tensor * lmal; // the L-BFGS memory alpha
-            struct wsp_ggml_tensor * lmys; // the L-BFGS memory ys
-            struct wsp_ggml_tensor * lms;  // the L-BFGS memory s
-            struct wsp_ggml_tensor * lmy;  // the L-BFGS memory y
-            float fx_best;
-            float step;
-            int j;
-            int k;
-            int end;
-            int n_no_improvement;
-        } lbfgs;
-    };
-
-    WSP_GGML_API struct wsp_ggml_opt_params wsp_ggml_opt_default_params(enum wsp_ggml_opt_type type);
-
-    // optimize the function defined by the tensor f
-    WSP_GGML_API enum wsp_ggml_opt_result wsp_ggml_opt(
-            struct wsp_ggml_context * ctx,
-            struct wsp_ggml_opt_params params,
-            struct wsp_ggml_tensor * f);
-
-    // initialize optimizer context
-    WSP_GGML_API void wsp_ggml_opt_init(
-            struct wsp_ggml_context * ctx,
-            struct wsp_ggml_opt_context * opt,
-            struct wsp_ggml_opt_params params,
-            int64_t nx);
-
-    // continue optimizing the function defined by the tensor f
-    WSP_GGML_API enum wsp_ggml_opt_result wsp_ggml_opt_resume(
-            struct wsp_ggml_context * ctx,
-            struct wsp_ggml_opt_context * opt,
-            struct wsp_ggml_tensor * f);
-
-    // continue optimizing the function defined by the tensor f
-    WSP_GGML_API enum wsp_ggml_opt_result wsp_ggml_opt_resume_g(
-            struct wsp_ggml_context * ctx,
-            struct wsp_ggml_opt_context * opt,
-            struct wsp_ggml_tensor * f,
-            struct wsp_ggml_cgraph * gf,
-            struct wsp_ggml_cgraph * gb,
-            wsp_ggml_opt_callback callback,
-            void * callback_data);

+    // - wsp_ggml_wsp_quantize_init can be called multiple times with the same type
+    //   it will only initialize the quantization tables for the first call or after wsp_ggml_wsp_quantize_free
+    //   automatically called by wsp_ggml_wsp_quantize_chunk for convenience
     //
-    // quantization
+    // - wsp_ggml_wsp_quantize_free will free any memory allocated by wsp_ggml_wsp_quantize_init
+    //   call this at the end of the program to avoid memory leaks
     //
-
-    WSP_GGML_API size_t wsp_ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t * hist);
-    WSP_GGML_API size_t wsp_ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t * hist);
-    WSP_GGML_API size_t wsp_ggml_quantize_q5_0(const float * src, void * dst, int n, int k, int64_t * hist);
-    WSP_GGML_API size_t wsp_ggml_quantize_q5_1(const float * src, void * dst, int n, int k, int64_t * hist);
-    WSP_GGML_API size_t wsp_ggml_quantize_q8_0(const float * src, void * dst, int n, int k, int64_t * hist);
-
-    WSP_GGML_API size_t wsp_ggml_quantize_chunk(enum wsp_ggml_type type, const float * src, void * dst, int start, int n, int64_t * hist);
+    // note: these are thread-safe
+    //
+    WSP_GGML_API void wsp_ggml_wsp_quantize_init(enum wsp_ggml_type type);
+    WSP_GGML_API void wsp_ggml_wsp_quantize_free(void);
+
+    // some quantization types cannot be used without an importance matrix
+    WSP_GGML_API bool wsp_ggml_wsp_quantize_requires_imatrix(enum wsp_ggml_type type);
+
+    // calls wsp_ggml_wsp_quantize_init internally (i.e. can allocate memory)
+    WSP_GGML_API size_t wsp_ggml_wsp_quantize_chunk(
+            enum wsp_ggml_type   type,
+            const float        * src,
+            void               * dst,
+            int64_t              start,
+            int64_t              nrows,
+            int64_t              n_per_row,
+            const float        * imatrix);

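A sketch of the chunked quantization entry point that replaces the per-type wsp_ggml_quantize_q*_* functions. Per the comments above it initializes its own tables; treating start as the element offset of the first row (and starting from row 0 here) is an assumption carried over from upstream ggml:

    #include <stddef.h>

    // Quantize nrows of f32 data laid out row-major with n_per_row elements per row.
    // Returns the number of bytes written to dst, or 0 in the unsupported case.
    static size_t quantize_sketch(enum wsp_ggml_type type, const float * src, void * dst,
                                  int64_t nrows, int64_t n_per_row) {
        if (wsp_ggml_wsp_quantize_requires_imatrix(type)) {
            return 0; // this sketch supplies no importance matrix
        }
        return wsp_ggml_wsp_quantize_chunk(type, src, dst,
                                           /*start*/ 0, nrows, n_per_row, /*imatrix*/ NULL);
        // call wsp_ggml_wsp_quantize_free() once at program exit
    }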
     //
     // gguf
     //

-    enum gguf_type {
-        GGUF_TYPE_UINT8   = 0,
-        GGUF_TYPE_INT8    = 1,
-        GGUF_TYPE_UINT16  = 2,
-        GGUF_TYPE_INT16   = 3,
-        GGUF_TYPE_UINT32  = 4,
-        GGUF_TYPE_INT32   = 5,
-        GGUF_TYPE_FLOAT32 = 6,
-        GGUF_TYPE_BOOL    = 7,
-        GGUF_TYPE_STRING  = 8,
-        GGUF_TYPE_ARRAY   = 9,
-        GGUF_TYPE_UINT64  = 10,
-        GGUF_TYPE_INT64   = 11,
-        GGUF_TYPE_FLOAT64 = 12,
-        GGUF_TYPE_COUNT,       // marks the end of the enum
+    enum wsp_gguf_type {
+        WSP_GGUF_TYPE_UINT8   = 0,
+        WSP_GGUF_TYPE_INT8    = 1,
+        WSP_GGUF_TYPE_UINT16  = 2,
+        WSP_GGUF_TYPE_INT16   = 3,
+        WSP_GGUF_TYPE_UINT32  = 4,
+        WSP_GGUF_TYPE_INT32   = 5,
+        WSP_GGUF_TYPE_FLOAT32 = 6,
+        WSP_GGUF_TYPE_BOOL    = 7,
+        WSP_GGUF_TYPE_STRING  = 8,
+        WSP_GGUF_TYPE_ARRAY   = 9,
+        WSP_GGUF_TYPE_UINT64  = 10,
+        WSP_GGUF_TYPE_INT64   = 11,
+        WSP_GGUF_TYPE_FLOAT64 = 12,
+        WSP_GGUF_TYPE_COUNT,       // marks the end of the enum
     };

-    struct gguf_context;
+    struct wsp_gguf_context;

-    struct gguf_init_params {
+    struct wsp_gguf_init_params {
         bool no_alloc;

         // if not NULL, create a wsp_ggml_context and allocate the tensor data in it
         struct wsp_ggml_context ** ctx;
     };

-    WSP_GGML_API struct gguf_context * gguf_init_empty(void);
-    WSP_GGML_API struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_params params);
-    //WSP_GGML_API struct gguf_context * gguf_init_from_buffer(..);
-
-    WSP_GGML_API void gguf_free(struct gguf_context * ctx);
-
-    WSP_GGML_API const char * gguf_type_name(enum gguf_type type);
-
-    WSP_GGML_API int    gguf_get_version    (const struct gguf_context * ctx);
-    WSP_GGML_API size_t gguf_get_alignment  (const struct gguf_context * ctx);
-    WSP_GGML_API size_t gguf_get_data_offset(const struct gguf_context * ctx);
-    WSP_GGML_API void * gguf_get_data       (const struct gguf_context * ctx);
-
-    WSP_GGML_API int          gguf_get_n_kv(const struct gguf_context * ctx);
-    WSP_GGML_API int          gguf_find_key(const struct gguf_context * ctx, const char * key);
-    WSP_GGML_API const char * gguf_get_key (const struct gguf_context * ctx, int i);
-
-    WSP_GGML_API enum gguf_type gguf_get_kv_type (const struct gguf_context * ctx, int i);
-    WSP_GGML_API enum gguf_type gguf_get_arr_type(const struct gguf_context * ctx, int i);
-
-    // results are undefined if the wrong type is used for the key
-    WSP_GGML_API uint8_t      gguf_get_val_u8  (const struct gguf_context * ctx, int i);
-    WSP_GGML_API int8_t       gguf_get_val_i8  (const struct gguf_context * ctx, int i);
-    WSP_GGML_API uint16_t     gguf_get_val_u16 (const struct gguf_context * ctx, int i);
-    WSP_GGML_API int16_t      gguf_get_val_i16 (const struct gguf_context * ctx, int i);
-    WSP_GGML_API uint32_t     gguf_get_val_u32 (const struct gguf_context * ctx, int i);
-    WSP_GGML_API int32_t      gguf_get_val_i32 (const struct gguf_context * ctx, int i);
-    WSP_GGML_API float        gguf_get_val_f32 (const struct gguf_context * ctx, int i);
-    WSP_GGML_API uint64_t     gguf_get_val_u64 (const struct gguf_context * ctx, int i);
-    WSP_GGML_API int64_t      gguf_get_val_i64 (const struct gguf_context * ctx, int i);
-    WSP_GGML_API double       gguf_get_val_f64 (const struct gguf_context * ctx, int i);
-    WSP_GGML_API bool         gguf_get_val_bool(const struct gguf_context * ctx, int i);
-    WSP_GGML_API const char * gguf_get_val_str (const struct gguf_context * ctx, int i);
-    WSP_GGML_API int          gguf_get_arr_n   (const struct gguf_context * ctx, int i);
-    WSP_GGML_API const void * gguf_get_arr_data(const struct gguf_context * ctx, int i);
-    WSP_GGML_API const char * gguf_get_arr_str (const struct gguf_context * ctx, int key_id, int i);
-
-    WSP_GGML_API int    gguf_get_n_tensors    (const struct gguf_context * ctx);
-    WSP_GGML_API int    gguf_find_tensor      (const struct gguf_context * ctx, const char * name);
-    WSP_GGML_API size_t gguf_get_tensor_offset(const struct gguf_context * ctx, int i);
-    WSP_GGML_API char * gguf_get_tensor_name  (const struct gguf_context * ctx, int i);
+    WSP_GGML_API struct wsp_gguf_context * wsp_gguf_init_empty(void);
+    WSP_GGML_API struct wsp_gguf_context * wsp_gguf_init_from_file(const char * fname, struct wsp_gguf_init_params params);
+    //WSP_GGML_API struct wsp_gguf_context * wsp_gguf_init_from_buffer(..);
+
+    WSP_GGML_API void wsp_gguf_free(struct wsp_gguf_context * ctx);
+
+    WSP_GGML_API const char * wsp_gguf_type_name(enum wsp_gguf_type type);
+
+    WSP_GGML_API int    wsp_gguf_get_version    (const struct wsp_gguf_context * ctx);
+    WSP_GGML_API size_t wsp_gguf_get_alignment  (const struct wsp_gguf_context * ctx);
+    WSP_GGML_API size_t wsp_gguf_get_data_offset(const struct wsp_gguf_context * ctx);
+    WSP_GGML_API void * wsp_gguf_get_data       (const struct wsp_gguf_context * ctx);
+
+    WSP_GGML_API int          wsp_gguf_get_n_kv(const struct wsp_gguf_context * ctx);
+    WSP_GGML_API int          wsp_gguf_find_key(const struct wsp_gguf_context * ctx, const char * key);
+    WSP_GGML_API const char * wsp_gguf_get_key (const struct wsp_gguf_context * ctx, int key_id);
+
+    WSP_GGML_API enum wsp_gguf_type wsp_gguf_get_kv_type (const struct wsp_gguf_context * ctx, int key_id);
+    WSP_GGML_API enum wsp_gguf_type wsp_gguf_get_arr_type(const struct wsp_gguf_context * ctx, int key_id);
+
+    // will abort if the wrong type is used for the key
+    WSP_GGML_API uint8_t      wsp_gguf_get_val_u8  (const struct wsp_gguf_context * ctx, int key_id);
+    WSP_GGML_API int8_t       wsp_gguf_get_val_i8  (const struct wsp_gguf_context * ctx, int key_id);
+    WSP_GGML_API uint16_t     wsp_gguf_get_val_u16 (const struct wsp_gguf_context * ctx, int key_id);
+    WSP_GGML_API int16_t      wsp_gguf_get_val_i16 (const struct wsp_gguf_context * ctx, int key_id);
+    WSP_GGML_API uint32_t     wsp_gguf_get_val_u32 (const struct wsp_gguf_context * ctx, int key_id);
+    WSP_GGML_API int32_t      wsp_gguf_get_val_i32 (const struct wsp_gguf_context * ctx, int key_id);
+    WSP_GGML_API float        wsp_gguf_get_val_f32 (const struct wsp_gguf_context * ctx, int key_id);
+    WSP_GGML_API uint64_t     wsp_gguf_get_val_u64 (const struct wsp_gguf_context * ctx, int key_id);
+    WSP_GGML_API int64_t      wsp_gguf_get_val_i64 (const struct wsp_gguf_context * ctx, int key_id);
+    WSP_GGML_API double       wsp_gguf_get_val_f64 (const struct wsp_gguf_context * ctx, int key_id);
+    WSP_GGML_API bool         wsp_gguf_get_val_bool(const struct wsp_gguf_context * ctx, int key_id);
+    WSP_GGML_API const char * wsp_gguf_get_val_str (const struct wsp_gguf_context * ctx, int key_id);
+    WSP_GGML_API const void * wsp_gguf_get_val_data(const struct wsp_gguf_context * ctx, int key_id);
+    WSP_GGML_API int          wsp_gguf_get_arr_n   (const struct wsp_gguf_context * ctx, int key_id);
+    WSP_GGML_API const void * wsp_gguf_get_arr_data(const struct wsp_gguf_context * ctx, int key_id);
+    WSP_GGML_API const char * wsp_gguf_get_arr_str (const struct wsp_gguf_context * ctx, int key_id, int i);
+
+    WSP_GGML_API int    wsp_gguf_get_n_tensors    (const struct wsp_gguf_context * ctx);
+    WSP_GGML_API int    wsp_gguf_find_tensor      (const struct wsp_gguf_context * ctx, const char * name);
+    WSP_GGML_API size_t wsp_gguf_get_tensor_offset(const struct wsp_gguf_context * ctx, int i);
+    WSP_GGML_API char * wsp_gguf_get_tensor_name  (const struct wsp_gguf_context * ctx, int i);
+    WSP_GGML_API enum wsp_ggml_type wsp_gguf_get_tensor_type(const struct wsp_gguf_context * ctx, int i);
+
+    // removes key if it exists
+    WSP_GGML_API void wsp_gguf_remove_key(struct wsp_gguf_context * ctx, const char * key);

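A read-side sketch of the renamed API. "general.architecture" is just an illustrative key following the gguf naming convention, and the explicit type check matters because, per the comment above, the getters now abort on a type mismatch instead of returning undefined results:

    #include <stdio.h>

    static void gguf_read_sketch(const char * fname) {
        struct wsp_gguf_init_params params = { /*.no_alloc =*/ true, /*.ctx =*/ NULL };
        struct wsp_gguf_context * ctx = wsp_gguf_init_from_file(fname, params);
        if (ctx == NULL) return;
        const int key_id = wsp_gguf_find_key(ctx, "general.architecture"); // < 0 when missing
        if (key_id >= 0 && wsp_gguf_get_kv_type(ctx, key_id) == WSP_GGUF_TYPE_STRING) {
            printf("arch: %s\n", wsp_gguf_get_val_str(ctx, key_id));
        }
        wsp_gguf_free(ctx);
    }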
     // overrides existing values or adds a new one
-    WSP_GGML_API void gguf_set_val_u8  (struct gguf_context * ctx, const char * key, uint8_t  val);
-    WSP_GGML_API void gguf_set_val_i8  (struct gguf_context * ctx, const char * key, int8_t   val);
-    WSP_GGML_API void gguf_set_val_u16 (struct gguf_context * ctx, const char * key, uint16_t val);
-    WSP_GGML_API void gguf_set_val_i16 (struct gguf_context * ctx, const char * key, int16_t  val);
-    WSP_GGML_API void gguf_set_val_u32 (struct gguf_context * ctx, const char * key, uint32_t val);
-    WSP_GGML_API void gguf_set_val_i32 (struct gguf_context * ctx, const char * key, int32_t  val);
-    WSP_GGML_API void gguf_set_val_f32 (struct gguf_context * ctx, const char * key, float    val);
-    WSP_GGML_API void gguf_set_val_u64 (struct gguf_context * ctx, const char * key, uint64_t val);
-    WSP_GGML_API void gguf_set_val_i64 (struct gguf_context * ctx, const char * key, int64_t  val);
-    WSP_GGML_API void gguf_set_val_f64 (struct gguf_context * ctx, const char * key, double   val);
-    WSP_GGML_API void gguf_set_val_bool(struct gguf_context * ctx, const char * key, bool     val);
-    WSP_GGML_API void gguf_set_val_str (struct gguf_context * ctx, const char * key, const char * val);
-    WSP_GGML_API void gguf_set_arr_data(struct gguf_context * ctx, const char * key, enum gguf_type type, const void * data, int n);
-    WSP_GGML_API void gguf_set_arr_str (struct gguf_context * ctx, const char * key, const char ** data, int n);
+    WSP_GGML_API void wsp_gguf_set_val_u8  (struct wsp_gguf_context * ctx, const char * key, uint8_t  val);
+    WSP_GGML_API void wsp_gguf_set_val_i8  (struct wsp_gguf_context * ctx, const char * key, int8_t   val);
+    WSP_GGML_API void wsp_gguf_set_val_u16 (struct wsp_gguf_context * ctx, const char * key, uint16_t val);
+    WSP_GGML_API void wsp_gguf_set_val_i16 (struct wsp_gguf_context * ctx, const char * key, int16_t  val);
+    WSP_GGML_API void wsp_gguf_set_val_u32 (struct wsp_gguf_context * ctx, const char * key, uint32_t val);
+    WSP_GGML_API void wsp_gguf_set_val_i32 (struct wsp_gguf_context * ctx, const char * key, int32_t  val);
+    WSP_GGML_API void wsp_gguf_set_val_f32 (struct wsp_gguf_context * ctx, const char * key, float    val);
+    WSP_GGML_API void wsp_gguf_set_val_u64 (struct wsp_gguf_context * ctx, const char * key, uint64_t val);
+    WSP_GGML_API void wsp_gguf_set_val_i64 (struct wsp_gguf_context * ctx, const char * key, int64_t  val);
+    WSP_GGML_API void wsp_gguf_set_val_f64 (struct wsp_gguf_context * ctx, const char * key, double   val);
+    WSP_GGML_API void wsp_gguf_set_val_bool(struct wsp_gguf_context * ctx, const char * key, bool     val);
+    WSP_GGML_API void wsp_gguf_set_val_str (struct wsp_gguf_context * ctx, const char * key, const char * val);
+    WSP_GGML_API void wsp_gguf_set_arr_data(struct wsp_gguf_context * ctx, const char * key, enum wsp_gguf_type type, const void * data, int n);
+    WSP_GGML_API void wsp_gguf_set_arr_str (struct wsp_gguf_context * ctx, const char * key, const char ** data, int n);

     // set or add KV pairs from another context
-    WSP_GGML_API void gguf_set_kv(struct gguf_context * ctx, struct gguf_context * src);
+    WSP_GGML_API void wsp_gguf_set_kv(struct wsp_gguf_context * ctx, struct wsp_gguf_context * src);

     // manage tensor info
-    WSP_GGML_API void gguf_add_tensor(struct gguf_context * ctx, const struct wsp_ggml_tensor * tensor);
-    WSP_GGML_API void gguf_set_tensor_type(struct gguf_context * ctx, const char * name, enum wsp_ggml_type type);
-    WSP_GGML_API void gguf_set_tensor_data(struct gguf_context * ctx, const char * name, const void * data, size_t size);
+    WSP_GGML_API void wsp_gguf_add_tensor(struct wsp_gguf_context * ctx, const struct wsp_ggml_tensor * tensor);
+    WSP_GGML_API void wsp_gguf_set_tensor_type(struct wsp_gguf_context * ctx, const char * name, enum wsp_ggml_type type);
+    WSP_GGML_API void wsp_gguf_set_tensor_data(struct wsp_gguf_context * ctx, const char * name, const void * data, size_t size);

     // writing gguf files can be done in 2 ways:
     //
-    // - write the entire gguf_context to a binary file in a single pass:
+    // - write the entire wsp_gguf_context to a binary file in a single pass:
     //
-    //   gguf_write_to_file(ctx, fname);
+    //   wsp_gguf_write_to_file(ctx, fname);
     //
     // - first prepare a file with a placeholder for the meta data, write the tensor data, then write the meta data:
     //
     //   FILE * f = fopen(fname, "wb");
-    //   fseek(f, gguf_get_meta_size(ctx), SEEK_SET);
+    //   fseek(f, wsp_gguf_get_meta_size(ctx), SEEK_SET);
     //   fwrite(f, ...);
-    //   void * data = gguf_meta_get_meta_data(ctx);
+    //   void * data = wsp_gguf_meta_get_meta_data(ctx);
     //   fseek(f, 0, SEEK_SET);
-    //   fwrite(f, data, gguf_get_meta_size(ctx));
+    //   fwrite(f, data, wsp_gguf_get_meta_size(ctx));
     //   free(data);
     //   fclose(f);
     //

     // write the entire context to a binary file
-    WSP_GGML_API void gguf_write_to_file(const struct gguf_context * ctx, const char * fname, bool only_meta);
+    WSP_GGML_API void wsp_gguf_write_to_file(const struct wsp_gguf_context * ctx, const char * fname, bool only_meta);

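A sketch of the single-pass write path described in the comment above; the key name is illustrative:

    static void gguf_write_sketch(const struct wsp_ggml_tensor * t, const char * fname) {
        struct wsp_gguf_context * ctx = wsp_gguf_init_empty();
        wsp_gguf_set_val_str(ctx, "general.name", "example"); // illustrative key
        wsp_gguf_add_tensor(ctx, t);                          // records tensor info + data
        wsp_gguf_write_to_file(ctx, fname, /*only_meta*/ false);
        wsp_gguf_free(ctx);
    }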
     // get the size in bytes of the meta data (header, kv pairs, tensor info) including padding
-    WSP_GGML_API size_t gguf_get_meta_size(const struct gguf_context * ctx);
-    WSP_GGML_API void   gguf_get_meta_data(const struct gguf_context * ctx, void * data);
-
-    //
-    // system info
-    //
-
-    WSP_GGML_API int wsp_ggml_cpu_has_avx        (void);
-    WSP_GGML_API int wsp_ggml_cpu_has_avx2       (void);
-    WSP_GGML_API int wsp_ggml_cpu_has_avx512     (void);
-    WSP_GGML_API int wsp_ggml_cpu_has_avx512_vbmi(void);
-    WSP_GGML_API int wsp_ggml_cpu_has_avx512_vnni(void);
-    WSP_GGML_API int wsp_ggml_cpu_has_fma        (void);
-    WSP_GGML_API int wsp_ggml_cpu_has_neon       (void);
-    WSP_GGML_API int wsp_ggml_cpu_has_arm_fma    (void);
-    WSP_GGML_API int wsp_ggml_cpu_has_metal      (void);
-    WSP_GGML_API int wsp_ggml_cpu_has_f16c       (void);
-    WSP_GGML_API int wsp_ggml_cpu_has_fp16_va    (void);
-    WSP_GGML_API int wsp_ggml_cpu_has_wasm_simd  (void);
-    WSP_GGML_API int wsp_ggml_cpu_has_blas       (void);
-    WSP_GGML_API int wsp_ggml_cpu_has_cublas     (void);
-    WSP_GGML_API int wsp_ggml_cpu_has_clblast    (void);
-    WSP_GGML_API int wsp_ggml_cpu_has_gpublas    (void);
-    WSP_GGML_API int wsp_ggml_cpu_has_sse3       (void);
-    WSP_GGML_API int wsp_ggml_cpu_has_ssse3      (void);
-    WSP_GGML_API int wsp_ggml_cpu_has_vsx        (void);
-
-    //
-    // Internal types and functions exposed for tests and benchmarks
-    //
+    WSP_GGML_API size_t wsp_gguf_get_meta_size(const struct wsp_gguf_context * ctx);
+    WSP_GGML_API void   wsp_gguf_get_meta_data(const struct wsp_gguf_context * ctx, void * data);

     #ifdef __cplusplus
     // restrict not standard in C++
@@ -1992,23 +2200,20 @@ extern "C" {
     #else
     #define WSP_GGML_RESTRICT restrict
     #endif
-    typedef void (*wsp_ggml_to_float_t)  (const void  * WSP_GGML_RESTRICT x, float * WSP_GGML_RESTRICT y, int k);
-    typedef void (*wsp_ggml_from_float_t)(const float * WSP_GGML_RESTRICT x, void  * WSP_GGML_RESTRICT y, int k);
-    typedef void (*wsp_ggml_vec_dot_t)   (const int n, float * WSP_GGML_RESTRICT s, const void * WSP_GGML_RESTRICT x, const void * WSP_GGML_RESTRICT y);
-
-    typedef struct {
-        const char * type_name;
-        int          blck_size;
-        size_t       type_size;
-        bool         is_quantized;
-        wsp_ggml_to_float_t   to_float;
-        wsp_ggml_from_float_t from_float;
-        wsp_ggml_from_float_t from_float_reference;
-        wsp_ggml_vec_dot_t    vec_dot;
-        enum wsp_ggml_type    vec_dot_type;
-    } wsp_ggml_type_traits_t;
-
-    wsp_ggml_type_traits_t wsp_ggml_internal_get_type_traits(enum wsp_ggml_type type);
+    typedef void (*wsp_ggml_to_float_t)  (const void  * WSP_GGML_RESTRICT x, float * WSP_GGML_RESTRICT y, int64_t k);
+    typedef void (*wsp_ggml_from_float_t)(const float * WSP_GGML_RESTRICT x, void  * WSP_GGML_RESTRICT y, int64_t k);
+
+    struct wsp_ggml_type_traits {
+        const char          * type_name;
+        int64_t               blck_size;
+        int64_t               blck_size_interleave; // interleave elements in blocks
+        size_t                type_size;
+        bool                  is_quantized;
+        wsp_ggml_to_float_t   to_float;
+        wsp_ggml_from_float_t from_float_ref;
+    };
+
+    WSP_GGML_API const struct wsp_ggml_type_traits * wsp_ggml_get_type_traits(enum wsp_ggml_type type);

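A sketch of querying the traits table, which is now a public API returning a const pointer rather than the old internal struct-by-value accessor:

    #include <stdio.h>

    static void traits_sketch(enum wsp_ggml_type type) {
        const struct wsp_ggml_type_traits * tt = wsp_ggml_get_type_traits(type);
        printf("%s: block size %lld (%zu bytes/block), quantized: %s\n",
               tt->type_name, (long long) tt->blck_size, tt->type_size,
               tt->is_quantized ? "yes" : "no");
    }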
    #ifdef __cplusplus
    }